276 Commits

Author SHA1 Message Date
hoshi-hiyouga
e2299e261b Merge pull request #7242 from hiyouga/hiyouga/release
[release] release v0.9.2

Former-commit-id: 6b25268990bf225d84e29d4067595cf720fa12d8
2025-03-11 15:28:45 +08:00
hoshi-hiyouga
8a44dce326 Merge pull request #7247 from hiyouga/hiyouga/commit
[misc] support print commit info

Former-commit-id: 0f7ec4f8529a5d7ea2153b881335821038307bb7
2025-03-11 15:28:04 +08:00
hoshi-hiyouga
6d9233833b Merge pull request #7244 from hiyouga/hiyouga/token
[data] avoid exit after saving preprocessed data

Former-commit-id: dcbf01b0035062fa14187e5bdbb925080d349501
2025-03-11 15:17:15 +08:00
hiyouga
d019603835 support commit info
Former-commit-id: a7d89a6dc10579deaf9f45825cc18405a27cade6
2025-03-11 15:13:59 +08:00
hiyouga
478e8194d9 remove exit in preprocess
Former-commit-id: f369b6ef41ffd9586ba568b88c5ff32a1af4bace
2025-03-11 15:08:25 +08:00
hiyouga
1890d3dafe release v0.9.2
Former-commit-id: e7ed1782d4a006400de6fc0f864abd01f7fadeea
2025-03-11 14:49:13 +08:00
hoshi-hiyouga
522a3e8493 [infer] fix vllm args (#7235)
Former-commit-id: 999be5b4512890b8cf4f45874a77e35cf35626f5
2025-03-11 01:15:35 +08:00
Ze-Yi LIN
18968405d0 [tracking] add swanlab_logdir param (#7219)
* feat: add swanlab_logdir param

* fix

Former-commit-id: 9215ad488b6ac6cd57fe8fa4acdacceb63f68ca5
2025-03-11 00:53:07 +08:00
hoshi-hiyouga
71a1c1321a [config] update args (#7231)
Former-commit-id: f71a901840811bf560df671ec63a146ff99140c6
2025-03-10 23:04:43 +08:00
hoshi-hiyouga
cf58a6d860 [config] fix export max len (#7230)
Former-commit-id: 211c0b3e8f3340acd2fae1762d9152a09f19ba34
2025-03-10 16:46:08 +08:00
hoshi-hiyouga
9adc0a2c3f [assets] update readme (#7209)
Former-commit-id: d1631b38dad9ba3d41aebbb00e3500eb79b9e8e9
2025-03-07 17:27:49 +08:00
hoshi-hiyouga
16419b2834 [data] fix loader (#7207)
* fix dataloader

* add test case

* fix type

* fix ci

* fix ci

* fix ci

* disable overwrite cache in ci

Former-commit-id: e84af0e140b1aafd1a6d6fe185a8e41c8fc5f831
2025-03-07 17:20:46 +08:00
hoshi-hiyouga
82a2bac866 [misc] fix ds config (#7205)
Former-commit-id: b478fa1d9de1858075769f86f57126fde92db813
2025-03-07 15:21:28 +08:00
ZhangChuanhui
151ef48b40 [data] fix function formatter (#7201)
Co-authored-by: zhangchuanhui <zhangchal@digitalchina.com>
Former-commit-id: 3efb32b986170d2839e526640f85ba230715879a
2025-03-07 15:17:23 +08:00
hoshi-hiyouga
a255c3a476 [misc] fix cli (#7204)
Former-commit-id: 999f57133ca163c7108d2d5ee8194eca9b2109b4
2025-03-07 15:01:18 +08:00
hoshi-hiyouga
f4ec4fa6ad [script] fix vllm version (#7193)
Former-commit-id: ababdde597b2b9bf0ab3f30f036bc8d97de07f03
2025-03-06 17:14:17 +08:00
hoshi-hiyouga
2635794727 [webui] support escape html (#7190)
Former-commit-id: cf9840374f171359c828b0d6f7a2aa9893c8f701
2025-03-06 16:52:21 +08:00
hoshi-hiyouga
d2f845d70d [deps] upgrade vllm (#7183)
Former-commit-id: 37678a3d64668c3b4a4bfefc054e3b9b40427c1a
2025-03-06 15:25:08 +08:00
hoshi-hiyouga
bb8aba5abf [data] fix mm template (#7181)
Former-commit-id: 648616d473c81d393592806307e3e25b159cb278
2025-03-06 15:18:32 +08:00
hoshi-hiyouga
9f16c50155 [model] add QwQ 32b (#7179)
Former-commit-id: 8897e48b8cd55407812453ddd4ff98ac7bdc4e91
2025-03-06 11:58:36 +08:00
Ze-Yi LIN
25bb9f5ad9 [trainer] fix swanlab callback (#7176)
Former-commit-id: 6d9acf4bd30db24499118aee16bd19cb19ba9e3d
2025-03-06 00:33:37 +08:00
hoshi-hiyouga
7b985f55db [trainer] update config (#7174)
Former-commit-id: 9f535d0e3c4ee3cd0f1b65218c2eee5d03f43c6f
2025-03-05 23:32:54 +08:00
sirui.li
fd0357a26d [data] fix qwen2audio plugin (#7166)
* Update pairwise.py

[data] Repair multimodal model DPO training

* Update pairwise.py

[data] repair multimodal model DPO training using deepcopy

* Update pairwise.py

* Update mm_plugin.py

Former-commit-id: 86763dfdb8e9e5668c1ddd7e924e4be76bf78368
2025-03-05 18:03:36 +08:00
hoshi-hiyouga
31f9daa362 [data] use bicubic resampler (#7143)
Former-commit-id: c708f19ab0ab57526134952afddaa90aae8decbf
2025-03-04 00:17:06 +08:00
hoshi-hiyouga
15ea576246 [webui] fix webui (#7142)
Former-commit-id: d07281f8a45ad8a38d390181d01dcadbcf9aa1b9
2025-03-04 00:01:49 +08:00
rabbit
19a6916d80 [data] bailing template (#7117)
* add bailing template

* add bailing template

* add bailing template

---------

Co-authored-by: chengshiwen.csw@antgroup.com <chengshiwen.csw@antgroup.com>
Former-commit-id: 4a36f5e0abb5a63f4b3b81560bb1ad0e6832d379
2025-03-03 15:33:22 +08:00
hoshi-hiyouga
585c475f71 [inference] fix hf_engine (#7120)
Former-commit-id: f8cf5319cb5d6e06a1b0d8b8db2b678627f2271e
2025-03-01 05:22:49 +08:00
hoshi-hiyouga
e62dae37fe [assets] update wechat (#7106)
Former-commit-id: 0ea430060994631e9fdb18fbbca0dd565a04fd66
2025-02-28 12:01:04 +08:00
Ze-Yi LIN
11672f760d [webui] display swanlab exp link (#7089)
* webui add swanlab link

* change callback name

* update

---------

Co-authored-by: hiyouga <hiyouga@buaa.edu.cn>
Former-commit-id: 27a4b93871c63b839c92940766bd7e0177972c9b
2025-02-27 19:40:54 +08:00
leo-pony
b9f84900ee [npu] update cann base image and torch 2.4 (#7061)
* Update base NPU container image version: the Python version required for Hugging Face Transformers is >= Python 3.10

* Fix the bug: the arg type of INSTALL_DEEPSPEED should now be a string.

* Update Ascend CANN, CANN-Kernel and the corresponding torch and torch-npu versions

* Upgrade the package versions required by torch-npu: torch==2.1.0 and torch-npu==2.4.0.post2

Former-commit-id: d6dafada58412b0c801e576ef4d8d96203f792af
2025-02-25 23:32:01 +08:00
hoshi-hiyouga
5f65558088 [misc] fix project toml (#7067)
Former-commit-id: 28a668ff4e0beebfe5387362f5518c1d9343666f
2025-02-25 23:22:48 +08:00
JieShen
0f54a78144 [script] add seed args (#7058)
* add seed args

* add seed args

* update seed

Former-commit-id: eb9770b2c01a840b6a0ac119210c22bdbb81e18b
2025-02-25 19:44:57 +08:00
Kingsley
2986bef530 [model] add paligemma2-mix series (#7060)
Former-commit-id: 0c0196306d343242ee5e6f22c55562f9a74aa782
2025-02-25 18:51:16 +08:00
hoshi-hiyouga
065f7fb5da [data] fix mllama (#7053)
* fix mllama

* fix test

Former-commit-id: f5af20a63f3d59a6a68d323a7c6f68e551edb3a3
2025-02-24 22:05:38 +08:00
hoshi-hiyouga
c1d5073bd3 [model] add models (#7054)
* add qwen25vl awq models

* add moonlight

Former-commit-id: ae3be2970fea8a35907202a313ab767381c44916
2025-02-24 22:05:13 +08:00
hoshi-hiyouga
ee46011b34 [assets] update readme (#7051)
Former-commit-id: c89a39bfc6a3f0aaa376cd1b221320f466aba617
2025-02-24 20:45:06 +08:00
hoshi-hiyouga
d55f420206 [assets] update wechat (#7019)
Former-commit-id: 3d102fe7e0bfc23db7d75f90ebaf53216c54cc85
2025-02-20 20:32:33 +08:00
Zhangchi Feng
fcf75633a0 [data] fix MiniCPMV plugin (#6998)
* fix template

* fix bug in messages processing

Former-commit-id: f98b828f53968fb9c72bff9e45510ad5586c4fab
2025-02-19 19:36:04 +08:00
hoshi-hiyouga
e77ced045d [webui] update css (#6985)
Former-commit-id: 760a1dfb8193de418d7aa1063c0d111a3a64ae0f
2025-02-18 18:27:57 +08:00
hoshi-hiyouga
331f53381f [data] add r1 distill dataset (#6983)
Former-commit-id: 1da5ee4edaa3896593b9cae488f0ac5917c3243e
2025-02-18 17:25:09 +08:00
hoshi-hiyouga
1d675a287d [version] support transformers 449 (#6982)
* support transformers 449

* fix mm plugin

Former-commit-id: e9118a9df0839d24f6ddff5a0b55ef101a1d3d22
2025-02-18 17:05:40 +08:00
hoshi-hiyouga
be33ef67fb [misc] fix script (#6977)
Former-commit-id: 775efa1d8cbdb1b7d122be2a986d47f85214e0a1
2025-02-18 17:00:46 +08:00
hoshi-hiyouga
f5cd17881e [data] update vlm args (#6976)
Former-commit-id: c28e710636a0286d4b8a1d494529b25168a8f3ab
2025-02-18 02:12:51 +08:00
hoshi-hiyouga
c09b648934 [data] add min resolution option (#6975)
Former-commit-id: 76bd9a98a2fb00f1a1d881e6e1364c02fd36d327
2025-02-18 01:40:46 +08:00
hoshi-hiyouga
f2fd9d1b25 [data] fix predict dataset (#6972)
Former-commit-id: f9a82e527877b1ed47cabb3d34f4d155705f4048
2025-02-17 20:29:40 +08:00
Zhangchi Feng
167342af8a [data] fix minicpmo template (#6946)
Former-commit-id: 09e4438b58d5c1a5fdde37ff781c3d79461c4743
2025-02-15 00:37:41 +08:00
Eric Tang
76f9bd1820 [ray] specify ray storage path (#6920)
Former-commit-id: 4be6b66b1eaa79955e936ce2b747a8837ecd1e49
2025-02-14 21:55:41 +08:00
hoshi-hiyouga
a893505924 [misc] fix lora regex (#6944)
* fix lora regex

* fix

Former-commit-id: 1d0ecbaee1b72f1e03154ddd4fcc8b7876e01f89
2025-02-14 21:38:43 +08:00
hoshi-hiyouga
ed25e051a9 [misc] fix grad ckpt (#6931)
Former-commit-id: deae1fc9a0bea5c8b8be1564cf9c81c9c02a0b3a
2025-02-13 23:27:51 +08:00
hoshi-hiyouga
5e5fc337f9 [model] add liger kernel to qwen2_5 vl (#6930)
* add liger kernel to qwen2_5 vl

* fix patch

* fix patch

Former-commit-id: 828776d155986166498dfc907194f64436571106
2025-02-13 23:05:54 +08:00
Billy Cao
58e9ca8aa0 [trainer] fix gen_kwarg to eval during training (#5451)
* Correctly pass gen_kwarg to eval during model runs

* fix

* fix

---------

Co-authored-by: hiyouga <hiyouga@buaa.edu.cn>
Former-commit-id: 845d16122496311e08263610a6a922f82604de7b
2025-02-13 02:35:06 +08:00
SrWYG
a4c4b8496f [data] evaluate on each dataset (#5522)
* [Update] loader.py: evaluate will run separate evaluations on each dataset.

`If you pass a dictionary with names of datasets as keys and datasets as values, evaluate will run separate evaluations on each dataset. This can be useful to monitor how training affects other datasets or simply to get a more fine-grained evaluation`

Seq2SeqTrainer supports eval_dataset as a Dict.

* fix format

* fix

* fix

---------

Co-authored-by: hiyouga <hiyouga@buaa.edu.cn>
Former-commit-id: cf00f78650a442c85678ce805e030d2b96cbecd7
2025-02-13 02:19:03 +08:00
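For readers unfamiliar with the Hugging Face behavior quoted in the commit above, here is a minimal, hedged sketch (not code from this repository; the tiny model and the dataset names are placeholders): passing `eval_dataset` as a dict of named datasets makes the Trainer run one evaluation per dataset and prefix each metric with the dataset name.

```python
# Minimal sketch of the behavior described above, assuming a recent transformers
# version; the tiny model and the dataset names ("alpaca", "sharegpt") are
# placeholders for illustration only.
from datasets import Dataset
from transformers import AutoModelForCausalLM, AutoTokenizer, Trainer, TrainingArguments

tokenizer = AutoTokenizer.from_pretrained("sshleifer/tiny-gpt2")
tokenizer.pad_token = tokenizer.eos_token  # GPT-2 tokenizers ship without a pad token
model = AutoModelForCausalLM.from_pretrained("sshleifer/tiny-gpt2")

def make_dataset(texts):
    enc = tokenizer(texts, truncation=True, padding="max_length", max_length=16)
    enc["labels"] = enc["input_ids"].copy()  # causal LM: labels mirror the inputs
    return Dataset.from_dict(dict(enc))

trainer = Trainer(
    model=model,
    args=TrainingArguments(output_dir="out", per_device_eval_batch_size=2, report_to="none"),
    train_dataset=make_dataset(["hello world"] * 8),
    # A dict-valued eval_dataset triggers one evaluation per entry.
    eval_dataset={"alpaca": make_dataset(["foo"] * 4), "sharegpt": make_dataset(["bar"] * 4)},
)
print(trainer.evaluate())  # metrics arrive as eval_alpaca_loss, eval_sharegpt_loss, ...
```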
Noah
38c9641777 [data] improve error handling (#6128)
* sync from upstream

* update

* update

* fix

---------

Co-authored-by: hiyouga <hiyouga@buaa.edu.cn>
Former-commit-id: 1569e6096fec07da5583f1a3435b0d23ae09b5ba
2025-02-13 01:39:41 +08:00
hoshi-hiyouga
8b8fdb3a85 [misc] update readme (#6918)
Former-commit-id: f5823479bd51c39db668b68056be749af09894d1
2025-02-13 01:01:41 +08:00
hoshi-hiyouga
290057069e [misc] update readme (#6917)
Former-commit-id: 6bbed1d8c4189fb7bea40230e278c40bb5336fbd
2025-02-13 00:58:10 +08:00
hoshi-hiyouga
46203856fc [breaking change] refactor data pipeline (#6901)
* refactor data

* rename file

Former-commit-id: 7a1a4ce6451cb782573d0bd9dd27a5e443e3a18b
2025-02-13 00:39:20 +08:00
Eric Tang
80b89978d9 [misc] support for launching LLaMA-Factory with uv run (#6907)
* yay

* uv with ray temporary commit

* remove ray specific code for now

* cleanup

Former-commit-id: 1a9cab6de49e300bf9c747eefbb11d693592b477
2025-02-13 00:38:44 +08:00
Eric Tang
5a221d91f9 [example] fix path to ray example (#6906)
Former-commit-id: e9bee3ef045d85051da04e6ad581a23a9e1a9551
2025-02-13 00:29:32 +08:00
hoshi-hiyouga
3a3f4072e5 [misc] fix grad ckpt func (#6916)
Former-commit-id: 35e069a52b3d7cfd9b0107574b09265eb2290f0b
2025-02-13 00:17:18 +08:00
marko1616
0c0cdc26bc [trainer] fix llama3.2 vision kto train (#6904)
Former-commit-id: 1563e89adc8988fc6e4250634a3f1e385979b0e5
2025-02-12 19:09:14 +08:00
hoshi-hiyouga
2581cc844b [data] feat: auto template (#6905)
* support auto template

* add unittest

Former-commit-id: 0c6c9150db6414a5a05527ea486dce6633dff4b3
2025-02-12 00:22:53 +08:00
hoshi-hiyouga
d58fcd094e [misc] update readme (#6903)
Former-commit-id: 830d028939149d54bc91b6bda110dfa5de949483
2025-02-11 22:51:26 +08:00
hoshi-hiyouga
86063e27ea [data] fix ollama template (#6902)
* fix ollama template

* add meta info

* use half precision

Former-commit-id: 1304bbea69d8c8ca57140017515dee7ae2ee6536
2025-02-11 22:43:09 +08:00
hoshi-hiyouga
88eafd865b [misc] support export ollama modelfile (#6899)
* support export ollama modelfile

* update config

* add system and num ctx

Former-commit-id: 8c2af7466f4015f300b51841db11bcd2505ebf20
2025-02-11 19:52:25 +08:00
hoshi-hiyouga
3f7bd98bfa [data] refactor template (#6896)
Former-commit-id: f78d5a3eca947ed965ca2f6c87d60441b1a59867
2025-02-11 17:59:25 +08:00
codingma
b72c4bd118 support ollama modelfile export (#4686)
Former-commit-id: 15cca102a7fc0d08b5d049cf264acc6fa576b104
2025-02-11 17:52:24 +08:00
hoshi-hiyouga
808ff89a2d [data] refactor mm plugin (#6895)
* refactor plugin

* lint

Former-commit-id: 1c8dcc3adca4a2e78f514f8bb70573dd1ca08746
2025-02-11 16:34:49 +08:00
HJ
6d7f1299bd [data] fix qwen_2_5_vl video processing (#6868)
* fix qwen_2_5_vl video processing

* Update mm_plugin.py

* Update mm_plugin.py

---------

Co-authored-by: hoshi-hiyouga <hiyouga@buaa.edu.cn>
Former-commit-id: 35f326dabdc8e84036296d2e3de1c84c67b8def8
2025-02-11 16:14:50 +08:00
hoshi-hiyouga
0420a608ca [assets] update wechat (#6892)
Former-commit-id: 0b268cc903a583ae78cb7e63d2bdc4602d7220fc
2025-02-11 13:56:26 +08:00
Zhangchi Feng
2047eab723 [data] fix minicpmv plugin (#6890)
* fix template name

* tiny fix

* support minicpm-o-2.6

* support inference of minicpmv

* update readme

* support dpo of minicpmv

* update init audio

* update init audio

* [model] fix image processing in minicpmo

* fix no mm inputs

Former-commit-id: cdd19ccd8cec460606b4545e886e932c1c5c5fe1
2025-02-11 13:30:44 +08:00
HJ
e11b40c344 [data] fix: sharegpt converter (#6879)
* fix-sharegpt-format

* fix

---------

Co-authored-by: hoshi-hiyouga <hiyouga@buaa.edu.cn>
Former-commit-id: ae8f8151ff750839998b50446f127061f240d41a
2025-02-10 21:59:12 +08:00
hoshi-hiyouga
b869506a57 [data] fix mllama collator (#6874)
Former-commit-id: c694fa3d66651c6ce547fa72c8260c46a406126b
2025-02-09 22:42:25 +08:00
hoshi-hiyouga
72d5b06b08 [test] align test cases (#6865)
* align test cases

* fix function formatter

Former-commit-id: a68f5e22d0391c80a9a826dc83967255be572032
2025-02-09 01:03:49 +08:00
hoshi-hiyouga
94726bdc8d [dataset] add openthought (#6866)
Former-commit-id: 20c748a4f108c0087f0d85377a4aa99126a0beb0
2025-02-09 00:53:01 +08:00
hoshi-hiyouga
4d1791e905 [deps] upgrade vllm (#6857)
Former-commit-id: 4bd50f65a3d62528768561019fda2723d045c7fd
2025-02-08 15:02:28 +08:00
hoshi-hiyouga
528e06ccaa fix qwen2vl plugin (#6855)
Former-commit-id: fd13b7138ab3f4da0a429a327b9d076bcb70b944
2025-02-08 10:59:10 +08:00
hoshi-hiyouga
fec641ec82 [misc] allow extra args (#6831)
Former-commit-id: 0fd3a5295cb4e08a4e57e860e82103364c28fba8
2025-02-06 12:38:08 +08:00
Zhangchi Feng
8f401e37f8 [model] support audio (#6701)
* support qwen2_audio

* improve code

* lint

* fix

* fix

* fix

---------

Co-authored-by: hiyouga <hiyouga@buaa.edu.cn>
Former-commit-id: 5eacb5629e4d7733cd992a63747a1335f2c6a929
2025-02-05 04:59:09 +08:00
Yueqi Song
9feb78e7b4 [data] allow thought in function call (#6797)
* Update template.py

* Update template.py

* use formatter

* fix regex

---------

Co-authored-by: hiyouga <hiyouga@buaa.edu.cn>
Former-commit-id: 3a31af6e920683ec074da93b1719e29f5d4cffd6
2025-02-05 02:26:23 +08:00
hoshi-hiyouga
c2022431aa [misc] update license year & fix llama pro (#6814)
* fix llamapro script

* change year

Former-commit-id: d9ae594178796994d400a5f207d6499712816f89
2025-02-05 01:53:33 +08:00
Yueqi Song
0817c24c04 [data] fix qwen tool template (#6796)
* Update tool_utils.py

* fix unittest

---------

Co-authored-by: hoshi-hiyouga <hiyouga@buaa.edu.cn>
Former-commit-id: 02bb78a792112f5151b3a96ddde2528823855288
2025-02-05 00:02:00 +08:00
Zhangchi Feng
cfb926fb84 [data] fix minicpmv plugin (#6801)
* fix template name

* tiny fix

* support minicpm-o-2.6

* support inference of minicpmv

* update readme

* support dpo of minicpmv

* update init audio

* update init audio

* [model] fix image processing in minicpmo

Former-commit-id: 8f704c8b6228ef50f828014f85dce67fda868660
2025-02-04 21:20:15 +08:00
neavo
34746d6151 [readme] update flash attention installation instruction on win platform (#6788)
* Update README_zh.md

* Update README.md

Former-commit-id: e48d1327fb39cc95f8fbfc746494f67a79471893
2025-02-01 12:43:29 +08:00
hoshi-hiyouga
5bb447b118 [misc] update workflows (#6787)
Former-commit-id: 15add6b250149e2aeabdc62d7dca69fc06054e01
2025-02-01 04:54:42 +08:00
hoshi-hiyouga
a28261a866 [model] add mistral small models (#6786)
Former-commit-id: e5e95c39bc4199fa89c67e34f9adaaa987058744
2025-02-01 04:31:38 +08:00
hoshi-hiyouga
800de98dc8 [model] add qwen2.5 vl models (#6779)
Former-commit-id: ed46fb4f6194c30060b908092464dded12e5787c
2025-01-31 03:00:29 +08:00
hoshi-hiyouga
222423bcef [breaking] support transformers 4.48 (#6628)
Former-commit-id: f154ab175c513a4d7bb866bf2cffc34b77b50508
2025-01-31 01:36:33 +08:00
hoshi-hiyouga
e71737351f [webui] improve webui & reasoning mode (#6778)
Former-commit-id: 3f17fc0d7163372e0446f1a38792ff761e99b739
2025-01-31 00:09:21 +08:00
qvlehao
4f298894da [model] add deepseek-R1 & show think process (#6767)
Former-commit-id: 4dccb724af51208a001c96fefbdbf226be09e50c
2025-01-29 12:16:26 +08:00
yinpu
a8fae3869d fix: avoid redundant normalization in DPO's SFT loss calculation (#6722)
Former-commit-id: 971a8ccbdacf130763d40c7ef82a711b2fc1292f
2025-01-21 13:38:02 +08:00
engchina
db9b977e4f [webui] support ja (#6698)
* add support for japanese language

* add support for japanese language

---------

Co-authored-by: engchina <atjapan2015@gmail.com>
Former-commit-id: 88692e403f9b5085dd0c7c2b2c68656c5da50dd4
2025-01-20 19:46:38 +08:00
hoshi-hiyouga
87d685b59f [model] support yarn (#6693)
Former-commit-id: 8c412abc44a4c61b683465e36c6288580d980250
2025-01-18 13:56:09 +08:00
hoshi-hiyouga
e4046bdd1f [assets] update wechat (#6692)
Former-commit-id: 70dba5fab6f4c9225758cafb646113d8e80ac084
2025-01-18 12:35:03 +08:00
hoshi-hiyouga
5baa3add8c [misc] update mm plugin (#6691)
Former-commit-id: 00303338d6927b1fda58b23340a31a8fa009f706
2025-01-17 23:04:26 +08:00
hoshi-hiyouga
332f637592 disable valset by default (#6690)
Former-commit-id: a1a94f364e33d1d73852f74eda4fa581e6b16533
2025-01-17 21:09:30 +08:00
hoshi-hiyouga
31daa6570b [webui] upgrade to gradio 5 (#6688)
Former-commit-id: 9df7721264ddef0008d7648e6ed173adef99bd74
2025-01-17 20:15:42 +08:00
hoshi-hiyouga
33525a34b6 fix qwen2 moe (#6684)
Former-commit-id: ab624419fa0ab23ef7a331a0ec14e393328772b5
2025-01-17 13:46:09 +08:00
Zhangchi Feng
3607caa2ad [data] Fix minicpmv/o dpo training (#6657)
* fix template name

* tiny fix

* support minicpm-o-2.6

* support inference of minicpmv

* update readme

* support dpo of minicpmv

Former-commit-id: 8d9f47b98047f370637d1c96c2f3440dcc738ef3
2025-01-15 17:30:37 +08:00
steveepreston
0fc2e19279 Update val_size english description (#6653)
* Update `val_size` Description in locales.py

* Update `val_size` Description in data_args.py

* Remove extra space in data_args.py

Former-commit-id: f1ba5158091446dce540dd796284037bdd724c38
2025-01-15 16:00:20 +08:00
hoshi-hiyouga
ef994600db update readme (#6648)
Former-commit-id: b47467276ab3174c50329b3c8b76823bc0a2249c
2025-01-15 11:06:19 +08:00
hoshi-hiyouga
7638f1070e [optim] clean apollo (#6645)
* clean apollo code

* update readme

Former-commit-id: 38b8ec4a99189483124b54df9d6bc6b0d318855a
2025-01-15 01:42:50 +08:00
zhuHQ
c2120432db [optim] add support to APOLLO (#6617)
Former-commit-id: 5a252e5a458457adbd19da3b68a3897ad2962824
2025-01-15 00:24:56 +08:00
Zhangchi Feng
66184762e8 update readme of MiniCPM-o (#6642)
* fix template name

* tiny fix

* support minicpm-o-2.6

* support inference of minicpmv

* update readme

Former-commit-id: 68604050ae2c98aeef5e9a6b4d2c11a4eb609bfa
2025-01-14 21:22:35 +08:00
hoshi-hiyouga
41a9e231cb lint (#6641)
Former-commit-id: 79731ae13ecd17eb8646fb53162c81dddfef3b00
2025-01-14 18:40:07 +08:00
Haian Huang(深度眸)
1bb06e06df Support InternLM3 Dense 8B Model (#6640)
* support internlm3

* update

* update

* update

* add hint

Former-commit-id: 24ab7ae0944c5f373e9cac60f0332e704824a057
2025-01-14 18:07:27 +08:00
Xiaosu Zhu
381f7120e6 Fix tokenizer max length (#6632)
Former-commit-id: 1807c7ba033985490aa7c8c39d880da6af983b92
2025-01-14 17:35:54 +08:00
Zhangchi Feng
f7857c83e1 Support Inference of MiniCPM-V-2.6 and MiniCPM-o-2.6 (#6631)
* fix template name

* tiny fix

* support minicpm-o-2.6

* support inference of minicpmv

Former-commit-id: 7f3c64e853a7cdd49d02bf85e237611941ac7fa8
2025-01-14 17:34:58 +08:00
hoshi-hiyouga
d0da6f40b0 [model] fix mllama any image (#6637)
* fix mllama any image

* reorder classes

Former-commit-id: 1242a1c4b4a465c06363fdc59302e80e5c4c96e6
2025-01-14 16:47:58 +08:00
hoshi-hiyouga
28d145a066 pin vllm version to 0.6.5 (#6629)
Former-commit-id: 26097ca0adf25ebb7d9e8eec2d2cef673c6cfe88
2025-01-14 02:44:02 +08:00
Zhangchi Feng
ae32c148d1 Support new features of MiniCPM-V (#6626)
* fix template name

* tiny fix

* support minicpm-o-2.6

Former-commit-id: 53034a61c7654358f46916cbc370910fb2aeff3b
2025-01-14 00:26:19 +08:00
hoshi-hiyouga
2a05941b14 [inference] fix stop token for object detection (#6624)
* fix stop token

* update minicpm data pipeline

* fix npu qlora examples

Former-commit-id: 844919fadaa8a61dfae47020971ea80730b2346f
2025-01-13 21:34:20 +08:00
codingma
11c38b9173 add nf4 qlora support on Ascend NPU (#6601)
* add nf4 qlora support on Ascend NPU

* add transformers version check

* add python>=3.10 requirement description for npu

* tiny fix

---------

Co-authored-by: hoshi-hiyouga <hiyouga@buaa.edu.cn>
Former-commit-id: 7912d1acac5f10dab22145fe729a90c57aad8d85
2025-01-13 19:43:36 +08:00
Zhangchi Feng
73c1c15b62 Fix template name of MiniCPM-V (#6620)
* fix template name

* tiny fix

Former-commit-id: 94dea52cef709a7e6f1cdc0b78e83e0422bd65d3
2025-01-13 16:46:48 +08:00
hoshi-hiyouga
7f58bf984f Merge pull request #6598 from BUAADreamer/minicpmv
[model] Support MiniCPM-V

Former-commit-id: 251e82bec12eaea6cf13608de191c096c63d1214
2025-01-13 15:24:02 +08:00
fzc8578
ec552372ba remove tests
Former-commit-id: 51addcd7ab81548a9952064dd8c95a8542252003
2025-01-13 15:08:35 +08:00
fzc8578
17d32fb5c7 fix tests
Former-commit-id: 582a17a12010943c7ca1cc0e25ebc8d125d10b45
2025-01-13 15:01:39 +08:00
fzc8578
4b61610b12 fix style
Former-commit-id: 76a36d9acecbf36b6959a14caacfed1d32bcee41
2025-01-13 14:19:38 +08:00
fzc8578
07798e4aad fix system prompt and tests
Former-commit-id: 955efca677b299749f3d40d587ee310951537543
2025-01-13 14:18:06 +08:00
fzc8578
6d6acd0213 add some
Former-commit-id: 5ad8ef3ec434f53f6fc494474becb034a3aca0ca
2025-01-11 15:03:20 +08:00
fzc8578
a789e0f263 add cpm_o test
Former-commit-id: 53cade69caed82b470fdb249274f03ee34af3100
2025-01-11 11:55:30 +08:00
fzc8578
f9ee00b6b6 add cpm_o test
Former-commit-id: 81dc0f678a7609c834581d956387bde42652755d
2025-01-11 11:49:03 +08:00
fzc8578
31bfdb08cd fix format
Former-commit-id: 964e18be5a824950164bc7232d35822a8b116d1a
2025-01-11 01:27:40 +08:00
fzc8578
12c83e00fc add some
Former-commit-id: 6233764d18f31365e9ba450408306fad55567ffc
2025-01-11 01:10:24 +08:00
fzc8578
9dc7b6c7ac adapt to new mllm_param
Former-commit-id: 0775b71965863c2618c117726a1046a36d6d85b8
2025-01-11 00:16:34 +08:00
Zhangchi Feng
627548bf7f Merge branch 'main' into minicpmv
Former-commit-id: 8a9c90759feda975faadc5858bd44b7ea116e7fb
2025-01-11 00:01:36 +08:00
hiyouga
dc65ecdf09 refactor mllm param logic
Former-commit-id: b895c190945cf5d991cb4e4dea2ae73cc9c8d246
2025-01-10 15:45:48 +00:00
fzc8578
e577990eb2 add minicpmv2.6
Former-commit-id: 1ab0aea54b54066cad500b7969b86a0e952d396d
2025-01-10 23:45:44 +08:00
fzc8578
1f3b729a4b add some
Former-commit-id: 58f50b8729083e9ea0fdcf07042b06261670ad57
2025-01-10 23:29:06 +08:00
fzc8578
0aa7ac210f add some
Former-commit-id: 3acd151a0f8efdd230c0b0980550795d204a69f7
2025-01-10 21:25:32 +08:00
fzc8578
40382f1387 fix some
Former-commit-id: 1eb7118db3ad6054cfd59d5f16a5d882e40e9057
2025-01-10 20:55:52 +08:00
fzc8578
75b3819e43 fix version
Former-commit-id: 834903fbf7a0fc8ac110f62f4df7c13819dd3c68
2025-01-10 20:31:04 +08:00
fzc8578
e63c2df0b1 fix some
Former-commit-id: cd5a1a8b9c6eb59d6e95f79573f60ad8668f1942
2025-01-10 20:27:06 +08:00
fzc8578
25d4889789 tiny fix
Former-commit-id: f088e580d3bacd0eecd0c3bf17e928eb49832ba1
2025-01-10 20:15:39 +08:00
Zhangchi Feng
8c0a721c4c Merge branch 'main' into minicpmv
Former-commit-id: d8840ae416660e23f1d615ffd404f519360151d9
2025-01-10 20:12:07 +08:00
fzc8578
9e972bc9ec add some
Former-commit-id: fede563aeb716ba5d1e368fd3e1182e4e580d248
2025-01-10 20:01:22 +08:00
hoshi-hiyouga
1675712a4c Merge pull request #6588 from hiyouga/hiyouga/upd_issue_temp
[gh] update issue template

Former-commit-id: 0a2626f996ce61559e93bedf19083aac5c861666
2025-01-10 03:03:48 +08:00
hiyouga
e0c9012f7f update issue template
Former-commit-id: 2bfca993588d8087dfd118f6f02486bbe752b166
2025-01-09 18:58:53 +00:00
hoshi-hiyouga
a25024bd0c Merge pull request #6585 from hiyouga/hiyouga/add_phi4
[model] add phi4 model

Former-commit-id: 0ae6a9b7bf9f1d6d844b97406b4795363bf75e78
2025-01-10 02:39:17 +08:00
hiyouga
867980196e improve template, add phi4 model
Former-commit-id: a785b6796e445a3adba45c5b6947166a2ff99871
2025-01-09 18:27:54 +00:00
hoshi-hiyouga
4e25d037c8 Merge pull request #6564 from stephen-nju/fix_ray
Fix ray

Former-commit-id: d4566839369726023f1b6e8f4b2332bda0c715cc
2025-01-08 18:14:18 +08:00
hoshi-hiyouga
6ba6926221 Merge pull request #6565 from hiyouga/hiyouga/improve_log
[misc] improve log

Former-commit-id: 538bf7b839c63d6a6758522fa08999d9b78e9db2
2025-01-08 18:08:21 +08:00
zhubin
b6b53b61f7 fix: get ray args when args is not a dict
Former-commit-id: 5e5398cd5b117b2378107172d3f91cfb0321e842
2025-01-08 10:06:02 +00:00
hiyouga
647c51a772 improve log
Former-commit-id: a6abf375975ffea3d51e1b944c9855b5f62ffac8
2025-01-08 09:56:10 +00:00
hoshi-hiyouga
3b843ac9d4 Merge pull request #6542 from erictang000/et/ray-integration
Ray Train integration with LLaMA-Factory

Former-commit-id: 4e34ee0a8e0aa90b535e53608b51c5c0804db34e
2025-01-08 11:46:03 +08:00
hiyouga
0ef1f981da fix llamaboard with ray
Former-commit-id: bd8a432d6a980b1b24a551626304fe3d394b1baf
2025-01-07 09:59:24 +00:00
hiyouga
944a2aec4d refactor ray integration, support save ckpt
Former-commit-id: 2f50b27e608b2092bfceab6c6e84e6631e973ee2
2025-01-07 09:39:10 +00:00
Eric Tang
4f31ad997c run style check
Former-commit-id: 5ec33baf5f95df9fa2afe5523c825d3eda8a076b
2025-01-07 08:55:44 +00:00
Kourosh Hakhamaneshi
8683582300 drafting ray integration
Signed-off-by: Kourosh Hakhamaneshi <kourosh@anyscale.com>

Former-commit-id: 19c12ddae9350f6e25a270fe3372f5b9094cf960
2025-01-07 08:55:44 +00:00
hoshi-hiyouga
5ccc607222 Merge pull request #6547 from hiyouga/hiyouga/fix_pixtral_dpo
[trainer] fix pixtral dpo

Former-commit-id: 920bb2a8922847fa544e2c260c67161e64cf5d50
2025-01-07 14:38:55 +08:00
hiyouga
d8bd46f1bf fix #6546
Former-commit-id: 6fcf2f10faf3b1614896b091591eeef96d717e64
2025-01-07 06:30:44 +00:00
fzc8578
8c2a712247 add some
Former-commit-id: b4790c66c126567bd193de52a564e3ce11c94769
2025-01-06 19:32:39 +08:00
hoshi-hiyouga
53e41bf2c7 Merge pull request #6528 from hiyouga/hiyouga/upd_wechat
[assets] update wechat

Former-commit-id: 3ceedf44896b5ebc406d6398b3f15e74e4710fbe
2025-01-04 16:01:21 +08:00
hiyouga
0eeae9061c update wechat
Former-commit-id: 11a9d96a042e8afd972e0bf2fa3e51f95e4799ec
2025-01-04 07:59:57 +00:00
Zhangchi Feng
08729dbefc Merge branch 'hiyouga:main' into minicpmv
Former-commit-id: 873b2d5888038e2328a12a6eb7c84099ba7ca1f3
2025-01-04 11:20:33 +08:00
fzc8578
2c120aa0df add some
Former-commit-id: 81176fe226da89eace89cb202bad68e73b7c2a02
2025-01-04 11:11:15 +08:00
hoshi-hiyouga
cca6286b6f Merge pull request #6524 from hiyouga/hiyouga/upd_scripts
[misc] update scripts

Former-commit-id: 6ba3ec45fc369c095ab9a1fbd9847dc66cf24ca4
2025-01-03 23:52:26 +08:00
hiyouga
8516054e4d update scripts
Former-commit-id: 05aa52adde8905ca892f1ed5847d6f90b1992848
2025-01-03 10:50:32 +00:00
hoshi-hiyouga
d1a8cd67d2 Merge pull request #6515 from hiyouga/hiyouga/misc
[misc] update model name

Former-commit-id: f92eea4090351dcd3c364e10a9eec0d17d480e12
2025-01-02 20:20:02 +08:00
hiyouga
8a5b4bdfd4 update model name
Former-commit-id: bf627d9f1ac117f040adbfd7630b5283f0db556a
2025-01-02 12:19:21 +00:00
hoshi-hiyouga
3bceef02ee Merge pull request #6514 from hiyouga/hiyouga/add_project
[readme] add project

Former-commit-id: 0bd0c373183731302f1af9f33a1f8ff70ba743e2
2025-01-02 20:16:15 +08:00
hoshi-hiyouga
166a830938 Merge pull request #6513 from hiyouga/hiyouga/add_gpt2
[model] add gpt2 model

Former-commit-id: 859c37f43c8a49eea4f118d0d00ee2a554f6bd4f
2025-01-02 20:15:55 +08:00
hiyouga
18767fe026 add project
Former-commit-id: 3b7e745d271e36b4cfe8826820b23254e1debfe9
2025-01-02 12:15:41 +00:00
hiyouga
18a1a4b9da add gpt2 model
Former-commit-id: 37d5e3639fcf5ae6e58cc435e0fa9dee0d6e4ead
2025-01-02 12:07:38 +00:00
hoshi-hiyouga
6015fe700e Merge pull request #6512 from hiyouga/hiyouga/fix_gen_logic
[trainer] fix generate logic

Former-commit-id: b97759421c535560ade631a7fa0a57b7c0da50f1
2025-01-02 19:36:54 +08:00
hoshi-hiyouga
369dae8dd3 Merge pull request #6462 from shibingli/main
Add ARG HTTP_PROXY in Dockerfile to support HTTP proxy during image building

Former-commit-id: 1e72bb24253bb07da874f3a37ccfa4fddaaf6978
2025-01-02 19:34:17 +08:00
hiyouga
2aaf3697d7 fix #6499
Former-commit-id: dffc607220ff6dac15cf501ac9a3cdbe80c25211
2025-01-02 11:28:54 +00:00
hoshi-hiyouga
5504b5254c Merge pull request #6492 from hiyouga/hiyouga/add_deepseek3
[model] add deepseek3 model

Former-commit-id: 0a6d1244a51f3cc8fe141b32f39bffce4c924a8c
2024-12-30 21:50:13 +08:00
hiyouga
b2e4f11602 add deepseek3 model
Former-commit-id: 611779d412f31e25b1ed38049050eee2da61dde5
2024-12-30 13:39:20 +00:00
hoshi-hiyouga
e3f95abca7 Merge pull request #5507 from piamo/main
Add deepseek-v2.5 template

Former-commit-id: 8a4911d201e219465fe0835a3ceb967f8b80dc0e
2024-12-30 21:08:25 +08:00
hoshi-hiyouga
2f44f70c2c Merge pull request #6483 from hiyouga/hiyouga/fix_paligemma_infer
[model] update vllm & fix paligemma dtype

Former-commit-id: 03ad6d44805a965764aaa51376964972b9b7da3d
2024-12-30 16:34:32 +08:00
hiyouga
f8f05a883b fix #6482
Former-commit-id: 8577f52b4152efe6cc7a8b5f6d37b4f9ba6684e7
2024-12-30 06:03:07 +00:00
hoshi-hiyouga
5f473e2696 Merge pull request #6465 from hiyouga/hiyouga/fix_eval_loss
[trainer] fix eval loss

Former-commit-id: fa8110b2052a74b4bd0dcf391a54207e1e31056d
2024-12-28 01:02:56 +08:00
hiyouga
88b1874c04 fix #6448
Former-commit-id: 04f78e85af5af14b4c195936623e426a6a128af2
2024-12-27 16:54:39 +00:00
shibingli@yeah.net
58bc6943dc Add ARG HTTP_PROXY in Dockerfile to support HTTP proxy during image building.
Former-commit-id: c46af4c45f96f1942dfaf77bdbdbe5d0fe85a387
2024-12-27 18:31:14 +08:00
shibingli@yeah.net
2dedf7b401 Add ARG HTTP_PROXY in Dockerfile to support HTTP proxy during image building. This commit introduces an ARG parameter named HTTP_PROXY in the Dockerfile. This addition allows for the configuration of an HTTP proxy, facilitating image building in environments with network restrictions.
Former-commit-id: d59fe30bca636bc2ca132d50172dba0032cecb6b
2024-12-27 18:17:17 +08:00
hoshi-hiyouga
5769a553d2 Merge pull request #6457 from youkaichao/module-run
[misc] enable module run

Former-commit-id: 813881a5d13dd1d5a526a85d41032196e0d46f04
2024-12-26 23:41:37 +08:00
youkaichao
552816e04b Update cli.py
Former-commit-id: 18e65bbd3ae07af3b9eed7f293c345815776c325
2024-12-26 23:22:09 +08:00
hoshi-hiyouga
b5fa1044b8 Merge pull request #6443 from hiyouga/hiyouga/add_qvq
[model] add qvq

Former-commit-id: 2010e80b1a939d21efa13d54df5f5d648ea640de
2024-12-25 15:53:19 +08:00
hiyouga
3c55976a0e add qvq #6439
Former-commit-id: 4dbfa142d899dd6e4d1a9d4db125765af5580a4f
2024-12-25 07:52:41 +00:00
hoshi-hiyouga
4611f67fae Merge pull request #6426 from hiyouga/hiyouga/update_readme
[assets] update readme

Former-commit-id: 2309c431090d1f3b573d113bbedeabee2b01fdf2
2024-12-23 22:17:19 +08:00
hiyouga
a5346041bb update readme
Former-commit-id: 1deda4750e0df6c46aeb33cf3f8b35baa537cc1d
2024-12-23 14:08:59 +00:00
hoshi-hiyouga
df42e438c1 Merge pull request #5922 from Tuyohai/main
support granite3 models

Former-commit-id: a9087bc0549f7f16e5b4c39e324043755b1618c8
2024-12-23 16:46:02 +08:00
hoshi-hiyouga
7dbfd7dff6 Merge pull request #6418 from hiyouga/hiyouga/add_report
[trainer] add custom args to experimental logger

Former-commit-id: 5e5a7ba73c1a386f025d75c10b102306bcb98674
2024-12-22 05:47:55 +08:00
hiyouga
a897d46049 support report custom args
Former-commit-id: d41254c40a1c5cacf9377096adb27efa9bdb79ea
2024-12-21 21:42:45 +00:00
hiyouga
adff887659 fix paligemma infer
Former-commit-id: d272455d6118c1d670c70cfe3458d8dab111da6c
2024-12-21 20:24:32 +00:00
hoshi-hiyouga
eba78f2159 Merge pull request #6416 from Zeyi-Lin/main
docs: use swanlab
Former-commit-id: 0759b576a36cde120ccb8cadd96fca4d871be130
2024-12-22 04:08:26 +08:00
ZeYi Lin
ec05c8cdb4 docs: use swanlab
Former-commit-id: 33509ea7bcd5f698a8393379bb3941c3c32f7fd6
2024-12-21 20:59:25 +08:00
hoshi-hiyouga
0a869c4ed4 Merge pull request #6401 from Zeyi-Lin/hiyouga/swanlab
feat: add swanlab for experiment tracking and visualization.
Former-commit-id: e65fe507f7643bf40b0fc462805c7b7f8ef6b738
2024-12-21 14:09:33 +08:00
ZeYi Lin
f792eaf8d4 fix: project blank
Former-commit-id: 3a0939572b0bfc7da0ee1a7244b6b3fbf567aba0
2024-12-20 18:26:02 +08:00
ZeYi Lin
8a41c96761 fix: by hiyouga suggestion
Former-commit-id: 41195f1bc69e4b5da7a265369d368b06754362cf
2024-12-20 16:43:03 +08:00
ZeYi Lin
e5d9d8c55d feat: ui improve
Former-commit-id: 6a1effb1741a13ae5238b0e9b429b4cbe3b6534f
2024-12-20 11:03:02 +08:00
ZeYi Lin
3e44c8fe3a fix: text
Former-commit-id: 52fe8d61eba7b7d8f66df09a03d40f25cc9c5b44
2024-12-19 21:26:02 +08:00
ZeYi Lin
925e421bde fix: bugs
Former-commit-id: a2297f97f7587c77d55fbce9ffa81dc60d0b04a1
2024-12-19 21:08:16 +08:00
hoshi-hiyouga
bbb636bdba Merge pull request #6395 from hiyouga/hiyouga/fix_genkwargs
[generate] fix generate kwargs

Former-commit-id: 1193594f2d06df38ec0aef7f591c74651cf1353c
2024-12-19 20:24:17 +08:00
ZeYi Lin
a30bdbb1c0 docs: config framework
Former-commit-id: 9cad21df82754170900e3ea74476f674754159b3
2024-12-19 20:22:36 +08:00
ZeYi Lin
95b7e10a06 fix: string
Former-commit-id: 73e1da5ab07c96a6faa9738e83c4dd9297f34b14
2024-12-19 20:18:59 +08:00
hiyouga
0385c60177 fix #6391
Former-commit-id: 067ba6e6cb4d8a1d95bba0a108f73008416a2865
2024-12-19 12:16:38 +00:00
ZeYi Lin
44895ebe36 feat: optimize frontend
Former-commit-id: 4a78603c141d9bd78bcaf81261b443cf082bf51f
2024-12-19 19:04:19 +08:00
ZeYi Lin
44dfbf9dbd feat: swanlab params
Former-commit-id: 761b3bdb03e27826fde2ca86d4e37b53c2bbc777
2024-12-19 18:47:27 +08:00
hoshi-hiyouga
0a465fc3ca Merge pull request #6388 from hiyouga/hiyouga/shuffle_control
[trainer] support disable shuffling

Former-commit-id: 3243e74a2ed3b1f7fa818842955f91386b591a9c
2024-12-19 17:00:12 +08:00
hiyouga
01eeae50b5 support disable shuffling
Former-commit-id: 9d8c35fd6b838ede0bd6827c6c6121f2cba2b11b
2024-12-19 08:53:21 +00:00
hiyouga
7eeeffdb8a add swanlab
Former-commit-id: c85a77c8a8824a56a67d56b97b4877fcd6edeb3d
2024-12-19 07:12:31 +00:00
hoshi-hiyouga
eca06531c3 Merge pull request #6384 from hiyouga/hiyouga/fix_webui
[webui] fix webui args

Former-commit-id: 94294c4e356b3ac5546f897d6e3255ee8c2a260f
2024-12-19 14:57:52 +08:00
hiyouga
d90b40b60f fix webui
Former-commit-id: 7152fde4a026e67f15885814c1900f3911d04ee8
2024-12-19 06:48:03 +00:00
hoshi-hiyouga
1898c1e9a6 Merge pull request #6379 from hiyouga/hiyouga/add_paligemma2
[model] add paligemma2

Former-commit-id: abe3ff3fe0b113e949bf6d2bd10e4c125fb8fe75
2024-12-18 17:03:11 +08:00
hiyouga
8d2f8b0dd8 add paligemma2
Former-commit-id: dafbc31684cb2566ef23c79e171cdfd02d6d396b
2024-12-18 08:57:26 +00:00
hoshi-hiyouga
df42281256 Merge pull request #6313 from ge-xing/main
support telechat2 model

Former-commit-id: 282d0619b1047ba48f9bc3ac837d2ed40b7df307
2024-12-18 16:16:17 +08:00
hoshi-hiyouga
896cf476d5 Merge pull request #6369 from hiyouga/hiyouga/template
[template] support qwen2 tool template

Former-commit-id: e1e133635f05f5b83869bc02340d6ea46976f318
2024-12-18 04:23:49 +08:00
hiyouga
37961d5f06 support qwen tool format
Former-commit-id: cbef4cb501fa1b50fa611e7054a856ce2c5ed10e
2024-12-17 20:12:06 +00:00
hiyouga
bb047bc844 change default replace jinja to false
Former-commit-id: bfe6625f6f6aa294933fa9056a4bfedee4fbe5e2
2024-12-17 19:27:10 +00:00
hoshi-hiyouga
448adedf6a Merge pull request #5473 from AlongWY/mistral
Support Mistral format tools

Former-commit-id: 4838427310d49e5942138e4578d2483baa005471
2024-12-18 03:23:24 +08:00
ylfeng
469c7cd462 Support Mistral format tools
Former-commit-id: e42d0e54b7a64a3f017a09e99846d174db7b438f
2024-12-17 19:13:26 +00:00
hoshi-hiyouga
ebf6a07681 Merge pull request #6368 from hiyouga/hiyouga/fix_llama_template
[template] fix llama3 tool template

Former-commit-id: 7c6763c4f3287f758077191361d5b0354741f84a
2024-12-18 01:10:48 +08:00
hiyouga
53f0fff513 fix llama3 tool template
Former-commit-id: 63f28a594a44c011f2e6d418f22ddbfc445db163
2024-12-17 17:05:10 +00:00
hoshi-hiyouga
ab7567693d Merge pull request #6367 from hiyouga/hiyouga/add_model
[model&template] add llama3.3 & support llama3 tool prompt

Former-commit-id: c32012c5e4943a30c3061716ed780d6124b6c90d
2024-12-18 00:13:28 +08:00
hiyouga
1b8aab0723 support llama3 tool prompt
Former-commit-id: dc45d2f56669fd99935a68cda1ec0e8f36229f7f
2024-12-17 15:52:37 +00:00
hoshi-hiyouga
30ebe61914 Merge pull request #5819 from yafshar/remote_code
Add trust_remote_code Parameter and Set Default to False

Former-commit-id: e82099350a2fb6d8ddf9c80ba0b18173057d4dcf
2024-12-17 21:10:24 +08:00
Yaser Afshar
6f1c8dacea Add missing key to init_kwargs
Former-commit-id: 03fc4621dad132164596a58d3e8693787b7e1aca
2024-12-17 12:34:05 +00:00
Yaser Afshar
8881237475 Add trust_remote_code parameter and remove True
- Introduced a new model parameter `trust_remote_code`
- Set the default value of `trust_remote_code` to `False` to enhance security


Former-commit-id: 4bf23f406cf5235c16f9f8139850c53354901814
2024-12-17 12:25:12 +00:00
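As background for the parameter introduced above, here is a hedged sketch (not code from this repository; the model id is a placeholder) of how `trust_remote_code` is typically forwarded to the underlying transformers loaders.

```python
# Illustration only: with the safer default of False, checkpoints that ship
# custom modeling code on the Hub refuse to load instead of executing that
# code implicitly; users must opt in explicitly.
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer

model_id = "some-org/model-with-custom-code"  # placeholder model id
trust_remote_code = False                     # safer default, as described in the commit above

config = AutoConfig.from_pretrained(model_id, trust_remote_code=trust_remote_code)
tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=trust_remote_code)
model = AutoModelForCausalLM.from_pretrained(model_id, trust_remote_code=trust_remote_code)
```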
zhaohu xing
584755be4b support telechat2 model
Former-commit-id: 15a069d85c07842cd28d65845af93c3cf70ef1f4
2024-12-17 12:15:33 +00:00
hoshi-hiyouga
3d3324be5c Merge pull request #6364 from hiyouga/hiyouga/control_reenterent_gc
[model] support non-reenterent-gc

Former-commit-id: a8a13cb360980bb4acd493e33ed405e07460fe73
2024-12-17 19:58:36 +08:00
hiyouga
4196d5b4d6 support non-reenterent-gc & fix #6358
Former-commit-id: 20446141e408885eb36d512bfb2dfb62bbc0c20d
2024-12-17 11:41:59 +00:00
hoshi-hiyouga
101c95ce65 Merge pull request #6363 from hiyouga/hiyouga/control_skip_eos
[infer] support control eos

Former-commit-id: 963640cff370be9f2fab649c88a120a645e6992e
2024-12-17 19:35:40 +08:00
hiyouga
19ebc0e7a2 support control eos, fix #6345
Former-commit-id: cb0f8399356bf372f3b7963f2565c3d504be0923
2024-12-17 10:42:05 +00:00
hoshi-hiyouga
1ce15b5d9e Merge pull request #6362 from hiyouga/hiyouga/mllm_packing
[model] generalized packing

Former-commit-id: b85f77a2687f7e0d11f7d2e49de54c544e39e3d5
2024-12-17 18:41:48 +08:00
hiyouga
d670d62a66 generalized packing & fix #6343
Former-commit-id: 3b1e4194616cacd5c24f08b328e31a008bddcf29
2024-12-17 10:26:19 +00:00
hoshi-hiyouga
6522467ddb Merge pull request #6359 from hiyouga/hiyouga/fix_qwen2vl_infer
[model] fix qwen2vl infer

Former-commit-id: 419cba5fae31a3c88305fe424b8aae9d59e3941a
2024-12-17 18:15:23 +08:00
hiyouga
aacd9642f5 fix #6348
Former-commit-id: 83e552320909f4775377889f1512994b7e638a7e
2024-12-17 10:06:46 +00:00
hoshi-hiyouga
4446c92517 Merge pull request #6334 from hiyouga/hiyouga/add_examples
[assets] update wechat and examples

Former-commit-id: 7725e7ac7d21ad844e8424a920e8bece6f38af19
2024-12-15 01:37:01 +08:00
hiyouga
8c65548b10 update assets
Former-commit-id: 7b9bd552b2bf97b72976511094eb51dfde5d1017
2024-12-14 17:36:03 +00:00
hiyouga
fb22651faf fix mrope
Former-commit-id: 55bee1d333549ca19858b3f5c1b7b86926e5fb09
2024-12-12 15:08:17 +00:00
hoshi-hiyouga
cfff136b2a Merge pull request #6253 from hiyouga/hiyouga/qwen2vl_mm_proj
[model] support qwen2vl train proj only

Former-commit-id: 0b0012142ab683da1e0558e6240310bf90f39150
2024-12-05 20:25:33 +08:00
hiyouga
bac2c64f87 support qwen2vl train proj only
Former-commit-id: 0e949ef03455726e907c6f1039e93ebe480c897a
2024-12-05 10:37:42 +00:00
hoshi-hiyouga
be1ec97c8e Merge pull request #6251 from hiyouga/hiyouga/vllm_qwen2vl_infer
[infer] support qwen2vl vllm infer

Former-commit-id: df76f7d6e124131ce7628c31cce01de4f8e6014c
2024-12-05 18:26:19 +08:00
hiyouga
bbd432415d support qwen2vl vllm infer
Former-commit-id: 03ddd2555fb97488cd4daab11e8b672d36150c5a
2024-12-05 10:17:26 +00:00
hoshi-hiyouga
1fef702382 Merge pull request #6246 from hiyouga/hiyouga/update_examples
[examples] update examples

Former-commit-id: ecb688bdb3e940651d64bc1edc85ce4568f3eabe
2024-12-05 16:49:30 +08:00
hiyouga
39865d8a1f update examples
Former-commit-id: bcb010be7732ae137f156932100ee4d02a93725c
2024-12-05 08:48:25 +00:00
hoshi-hiyouga
c7b27bd70b Merge pull request #6242 from hiyouga/hiyouga/fix_script
[script] fix scripts

Former-commit-id: cf254ea0891ea2e6522fdbefcccf409ff7aafd99
2024-12-05 11:54:46 +08:00
hiyouga
86e4fab0d5 fix scripts
Former-commit-id: f94f55d20283298cb7d90d0573992a62df414a8f
2024-12-05 03:47:32 +00:00
hoshi-hiyouga
ff3e40e4a5 Merge pull request #6160 from village-way/pr_dataloader
fix: loading gets stuck when tokenized_path is not None and load_from_disk returns a Dataset
Former-commit-id: 63de20970c8062aeebed5f366f1675beb12e05bf
2024-12-04 22:18:19 +08:00
hoshi-hiyouga
ea830cad0c lint
Former-commit-id: 191ccc585399ad4c6c2c4f280b144b2c0a4869f3
2024-12-04 22:08:27 +08:00
hoshi-hiyouga
225e270fd5 Merge pull request #6238 from hiyouga/hiyouga/vllm_batchinfer
[infer] feat: support batch infer in vllm

Former-commit-id: 886752801ba8a5bf6fc4853ed618817185950c11
2024-12-04 21:59:13 +08:00
hiyouga
c1768cfb14 support batch infer in vllm
Former-commit-id: 3ef5ed3b9a44eed2f7e3ff221dfc343d0a97c0b5
2024-12-04 13:50:00 +00:00
hoshi-hiyouga
53edd62f8b Merge pull request #6190 from JieShenAI/main
add vllm_infer script

Former-commit-id: 09c7ea700c83dcf8d75796a1e28a36197f62cab4
2024-12-04 21:19:23 +08:00
hoshi-hiyouga
41a7e128b6 Merge pull request #6170 from hykilpikonna/main
[+] Show the hostname in webui title

Former-commit-id: 1cb2f9da317a8db8f45e887ab57cdfdc0e8b9412
2024-12-04 18:07:29 +08:00
hoshi-hiyouga
6b8c41c3ac Merge pull request #6233 from hiyouga/hiyouga/vlm_zero3
[data] fix vlm zero3 training

Former-commit-id: b0cbd5e3464a8a1a0f1cf709fb107b23a61f34ff
2024-12-04 17:51:10 +08:00
hiyouga
2f09c34980 fix vlm zero3 training
Former-commit-id: 86fe7fe71b51077310357b7b1895522258f9bc7a
2024-12-04 09:40:39 +00:00
JieShen
76dc69ce36 add async call api
Former-commit-id: 0f728386d88cf8253250c6650555d41578114a0c
2024-12-01 22:18:05 +08:00
JieShen
6c9d05539a add vllm_infer script
Former-commit-id: 4daab843a3aa096b35e5d3832c01fac4271e4604
2024-11-29 14:22:20 +08:00
Azalea
b6bc17f730 [U] Compute hostname differently
Former-commit-id: fbc735972af6facdaba169603a4c77e613b2e8d7
2024-11-28 22:23:41 -05:00
hoshi-hiyouga
c07ba8ccc0 Merge pull request #6175 from hiyouga/hiyouga/add_qwq
[model] add QwQ

Former-commit-id: da8f565c359004d811481b8b85f2a36f30e95e23
2024-11-28 17:01:53 +08:00
hiyouga
ed86f621a0 add qwq
Former-commit-id: acad977356a7f2e729eb6f2cb919a416b18f8add
2024-11-28 08:50:57 +00:00
Azalea
c6a3175bbf [+] Show the hostname
Former-commit-id: 410847656a760fe4c2c310b0d770072392d7aefb
2024-11-28 12:25:02 +08:00
wangdepeng
452291417d fix: loading gets stuck when tokenized_path is not None and load_from_disk returns a Dataset
Former-commit-id: cbf9da35728daaf98d92e699e891e334c74af1e5
2024-11-27 16:44:42 +08:00
hoshi-hiyouga
ab9db8b7c7 Merge pull request #6156 from hiyouga/hiyouga/add_o1
[data&model] add marco-o1, skywork-o1 and openo1

Former-commit-id: fa8aa1a3bcb49357799ec30fbb3f143a015e5d58
2024-11-27 14:36:01 +08:00
hiyouga
877e2ea791 fix dataset
Former-commit-id: d4a2d299414984a4043d30034c5c95e2d717a49e
2024-11-27 06:27:44 +00:00
hiyouga
6ea42d5b63 add skywork o1
Former-commit-id: 272a6fe972de926e5841c1570995f4e6fed9f28d
2024-11-27 05:51:59 +00:00
hiyouga
31c117e696 Merge remote-tracking branch 'origin/main' into hiyouga/add_o1
Former-commit-id: 5da8c00b233f96e51cf3bac7f25e3e61659d0cb7
2024-11-27 05:36:41 +00:00
hoshi-hiyouga
04f057334f Merge pull request #6157 from hiyouga/hiyouga/fix_ci
[ci] pin tokenizers version

Former-commit-id: 0357d7530d16699e728bc648abd08ea309e84865
2024-11-27 13:33:04 +08:00
hiyouga
99a54d06ca pin tokenizers version
Former-commit-id: 2b747737f0be2caeb737fe87dad6bf5902b4a588
2024-11-27 05:24:58 +00:00
hiyouga
8332c85f37 add marco-o1 and openo1 dataset
Former-commit-id: 51d49e075470951f109bcdde136203f972450c2e
2024-11-27 04:20:23 +00:00
hoshi-hiyouga
fcf1a3df62 Merge pull request #6152 from hiyouga/hiyouga/add_num_proc_in_data_load
[data] add num_proc in load_dataset

Former-commit-id: d8258ba7e792d5f17ae80d5e8b303e8fa820f162
2024-11-27 00:16:15 +08:00
hoshi-hiyouga
f4f52ae67d Merge pull request #6151 from hiyouga/hiyouga/fix_mllama
[model] fix mllama cross mask

Former-commit-id: 7e64661c1fc53c4d3d9fd915162b762e403b1991
2024-11-27 00:07:54 +08:00
hiyouga
0b08d5882a fix #6149
Former-commit-id: b581b272793314a9602f4dc2fb646a988a6249df
2024-11-26 16:03:02 +00:00
hiyouga
62eeafaba6 fix mllama cross_mask
Former-commit-id: c33967308bebd99489d28bd5a879525cf304c1f9
2024-11-26 15:56:58 +00:00
hoshi-hiyouga
5a52e41399 Merge pull request #6141 from hiyouga/hiyouga-patch-1
[misc] chore: lint

Former-commit-id: ba2b94c68eb08798792be76f95b94b358ce69f44
2024-11-25 23:02:11 +08:00
hoshi-hiyouga
e8083f8f3f lint
Former-commit-id: 57c3cf1f498d5ffafdc8c06e0f8713f8ff77de81
2024-11-25 22:55:56 +08:00
hoshi-hiyouga
338b3a03f0 Merge pull request #6140 from hiyouga/hiyouga/fix_mllama
[data] fix mllama plugin

Former-commit-id: b7e220a7d82db26cbe7ced9ed30332418cc4fa20
2024-11-25 22:32:07 +08:00
hoshi-hiyouga
c8b01b41ac fix #6139
Former-commit-id: a4e9552b9ade6ebb22d782f0412003279ddca23c
2024-11-25 22:22:06 +08:00
hoshi-hiyouga
6d08a418ed Merge pull request #6137 from hiyouga/hiyouga/fix_mllama
[model] fix mllama hidden_size

Former-commit-id: 54f1d3f4064b9d37261883e8399c8e7909178857
2024-11-25 20:17:33 +08:00
hoshi-hiyouga
e3066d1489 fix visual patch
Former-commit-id: ac51fa37cc23518b30a6123e188964dce39be82f
2024-11-25 20:06:06 +08:00
hoshi-hiyouga
487e3f2507 fix #6136
Former-commit-id: b84e5d91a070c473ea820c379bf9b5abbca6df2c
2024-11-25 19:43:42 +08:00
hoshi-hiyouga
b82a53cad8 Merge pull request #6127 from hiyouga/hiyouga/dev_version
[misc] set dev version

Former-commit-id: cb0a51031324c9fdf0c1fedf237692a40c2091d9
2024-11-25 01:42:29 +08:00
hiyouga
5bec82ca9d set dev version
Former-commit-id: a0aea74100a9505664023f6a46fc290e332dfa40
2024-11-25 01:36:49 +08:00
steven
6ef0d13e42 support granite3 models
Former-commit-id: 8cff612e55eb7df116e51c4dd21e7a42543e7a1f
2024-11-04 10:35:03 +08:00
huangpan.foo
ed5c641e8b Add deepseek-v2.5 template
Former-commit-id: e80c1fe798fb2e076c0891a64300f1b6710176b6
2024-09-21 19:33:30 +08:00
221 changed files with 9049 additions and 4012 deletions


@@ -4,14 +4,17 @@ API_HOST=
API_PORT=
API_KEY=
API_MODEL_NAME=
API_VERBOSE=
FASTAPI_ROOT_PATH=
MAX_CONCURRENT=
# general
DISABLE_VERSION_CHECK=
FORCE_CHECK_IMPORTS=
ALLOW_EXTRA_ARGS=
LLAMAFACTORY_VERBOSITY=
USE_MODELSCOPE_HUB=
USE_OPENMIND_HUB=
USE_RAY=
RECORD_VRAM=
# torchrun
FORCE_TORCHRUN=
@@ -31,7 +34,7 @@ GRADIO_SERVER_PORT=
GRADIO_ROOT_PATH=
GRADIO_IPV6=
# setup
ENABLE_SHORT_CONSOLE=1
ENABLE_SHORT_CONSOLE=
# reserved (do not use)
LLAMABOARD_ENABLED=
LLAMABOARD_WORKDIR=

.github/ISSUE_TEMPLATE/1-bug-report.yml (new file, 63 lines)

@@ -0,0 +1,63 @@
name: "\U0001F41B Bug / help"
description: Create a report to help us improve the LLaMA Factory
labels: ["bug", "pending"]
body:
- type: markdown
attributes:
value: |
Issues included in **[FAQs](https://github.com/hiyouga/LLaMA-Factory/issues/4614)** or those with **insufficient** information may be closed without a response.
已经包含在 **[常见问题](https://github.com/hiyouga/LLaMA-Factory/issues/4614)** 内或提供信息**不完整**的 issues 可能不会被回复。
- type: markdown
attributes:
value: |
Please do not create issues that are not related to framework bugs under this category, use **[Discussions](https://github.com/hiyouga/LLaMA-Factory/discussions/categories/q-a)** instead.
请勿在此分类下创建和框架 bug 无关的 issues请使用 **[讨论区](https://github.com/hiyouga/LLaMA-Factory/discussions/categories/q-a)**。
- type: checkboxes
id: reminder
attributes:
label: Reminder
description: |
Please ensure you have read the above rules carefully and searched the existing issues (including FAQs).
请确保您已经认真阅读了上述规则并且搜索过现有的 issues包括常见问题
options:
- label: I have read the above rules and searched the existing issues.
required: true
- type: textarea
id: system-info
validations:
required: true
attributes:
label: System Info
description: |
Please share your system info with us. You can run the command **llamafactory-cli env** and copy-paste its output below.
请提供您的系统信息。您可以在命令行运行 **llamafactory-cli env** 并将其输出复制到该文本框中。
placeholder: llamafactory version, platform, python version, ...
- type: textarea
id: reproduction
validations:
required: true
attributes:
label: Reproduction
description: |
Please provide entry arguments, error messages and stack traces that reproduces the problem.
请提供入口参数,错误日志以及异常堆栈以便于我们复现问题。
Remember to wrap your log messages with \`\`\`.
请务必使用 Markdown 标签 \`\`\` 来包裹您的日志信息。
value: |
```text
Put your message here.
```
- type: textarea
id: others
validations:
required: false
attributes:
label: Others


@@ -0,0 +1,41 @@
name: "\U0001F680 Feature request"
description: Submit a request for a new feature
labels: ["enhancement", "pending"]
body:
- type: markdown
attributes:
value: |
Please do not create issues that are not related to new features under this category.
请勿在此分类下创建和新特性无关的 issues。
- type: checkboxes
id: reminder
attributes:
label: Reminder
description: |
Please ensure you have read the above rules carefully and searched the existing issues.
请确保您已经认真阅读了上述规则并且搜索过现有的 issues。
options:
- label: I have read the above rules and searched the existing issues.
required: true
- type: textarea
id: description
validations:
required: true
attributes:
label: Description
description: |
A clear and concise description of the feature proposal.
请详细描述您希望加入的新功能特性。
- type: textarea
id: contribution
validations:
required: false
attributes:
label: Pull Request
description: |
Have you already created the relevant PR and submitted the code?
您是否已经创建了相关 PR 并提交了代码?


@@ -1,66 +0,0 @@
name: "\U0001F41B Bug / Help"
description: Create a report to help us improve the LLaMA Factory
body:
- type: markdown
attributes:
value: |
Issues included in **FAQs** or those with **insufficient** information may be closed without a response.
包含在**常见问题**内或提供信息**不完整**的 issues 可能不会被回复。
- type: checkboxes
id: reminder
attributes:
label: Reminder
description: |
Please ensure you have read the README carefully and searched the existing issues (including FAQs).
请确保您已经认真阅读了 README 并且搜索过现有的 issues包括常见问题
options:
- label: I have read the README and searched the existing issues.
required: true
- type: textarea
id: system-info
validations:
required: true
attributes:
label: System Info
description: |
Please share your system info with us. You can run the command **llamafactory-cli env** and copy-paste its output below.
请提供您的系统信息。您可以在命令行运行 **llamafactory-cli env** 并将其输出复制到该文本框中。
placeholder: llamafactory version, platform, python version, ...
- type: textarea
id: reproduction
validations:
required: true
attributes:
label: Reproduction
description: |
Please provide code snippets, error messages and stack traces that reproduces the problem.
请提供运行参数,错误信息以及异常堆栈以便于我们复现该问题。
Remember to use Markdown tags to correctly format your code.
请合理使用 Markdown 标签来格式化您的文本。
placeholder: |
```bash
llamafactory-cli train ...
```
- type: textarea
id: expected-behavior
validations:
required: false
attributes:
label: Expected behavior
description: |
Please provide a clear and concise description of what you would expect to happen.
请提供您原本的目的,即这段代码的期望行为。
- type: textarea
id: others
validations:
required: false
attributes:
label: Others

.github/ISSUE_TEMPLATE/config.yml (new file, 1 line)

@@ -0,0 +1 @@
blank_issues_enabled: false


@@ -18,13 +18,15 @@ jobs:
ISSUE_URL: ${{ github.event.issue.html_url }}
ISSUE_TITLE: ${{ github.event.issue.title }}
run: |
LABEL=pending
LABEL=""
NPU_KEYWORDS=(npu huawei ascend 华为 昇腾)
ISSUE_TITLE_LOWER=$(echo $ISSUE_TITLE | tr '[:upper:]' '[:lower:]')
for KEYWORD in ${NPU_KEYWORDS[@]}; do
if [[ $ISSUE_TITLE_LOWER == *$KEYWORD* ]] && [[ $ISSUE_TITLE_LOWER != *input* ]]; then
LABEL=pending,npu
LABEL="npu"
break
fi
done
gh issue edit $ISSUE_URL --add-label $LABEL
if [ -n "$LABEL" ]; then
gh issue edit $ISSUE_URL --add-label $LABEL
fi


@@ -25,7 +25,7 @@ jobs:
- name: Set up Python
uses: actions/setup-python@v5
with:
python-version: "3.8"
python-version: "3.9"
- name: Install dependencies
run: |


@@ -22,10 +22,10 @@ jobs:
fail-fast: false
matrix:
python-version:
- "3.8" # TODO: remove py38 in next transformers release
- "3.9"
- "3.10"
- "3.11"
- "3.12"
os:
- "ubuntu-latest"
- "windows-latest"
@@ -33,9 +33,6 @@ jobs:
runs-on: ${{ matrix.os }}
environment:
name: tests
env:
HF_TOKEN: ${{ secrets.HF_TOKEN }}
OS_NAME: ${{ matrix.os }}

.gitignore (5 lines changed)

@@ -162,6 +162,9 @@ cython_debug/
# vscode
.vscode/
# uv
uv.lock
# custom .gitignore
ms_cache/
hf_cache/
@@ -171,3 +174,5 @@ config/
saves/
output/
wandb/
swanlog/
generated_predictions.jsonl

README.md (311 lines changed)

@@ -1,20 +1,31 @@
![# LLaMA Factory](assets/logo.png)
[![GitHub Repo stars](https://img.shields.io/github/stars/hiyouga/LLaMA-Factory?style=social)](https://github.com/hiyouga/LLaMA-Factory/stargazers)
[![GitHub Code License](https://img.shields.io/github/license/hiyouga/LLaMA-Factory)](LICENSE)
[![GitHub last commit](https://img.shields.io/github/last-commit/hiyouga/LLaMA-Factory)](https://github.com/hiyouga/LLaMA-Factory/commits/main)
[![GitHub contributors](https://img.shields.io/github/contributors/hiyouga/LLaMA-Factory?color=orange)](https://github.com/hiyouga/LLaMA-Factory/graphs/contributors)
[![GitHub workflow](https://github.com/hiyouga/LLaMA-Factory/actions/workflows/tests.yml/badge.svg)](https://github.com/hiyouga/LLaMA-Factory/actions/workflows/tests.yml)
[![PyPI](https://img.shields.io/pypi/v/llamafactory)](https://pypi.org/project/llamafactory/)
[![Citation](https://img.shields.io/badge/citation-93-green)](#projects-using-llama-factory)
[![Citation](https://img.shields.io/badge/citation-349-green)](https://scholar.google.com/scholar?cites=12620864006390196564)
[![GitHub pull request](https://img.shields.io/badge/PRs-welcome-blue)](https://github.com/hiyouga/LLaMA-Factory/pulls)
[![Discord](https://dcbadge.vercel.app/api/server/rKfvV9r9FK?compact=true&style=flat)](https://discord.gg/rKfvV9r9FK)
[![Twitter](https://img.shields.io/twitter/follow/llamafactory_ai)](https://twitter.com/llamafactory_ai)
[![Discord](https://dcbadge.vercel.app/api/server/rKfvV9r9FK?compact=true&style=flat)](https://discord.gg/rKfvV9r9FK)
[![GitCode](https://gitcode.com/zhengyaowei/LLaMA-Factory/star/badge.svg)](https://gitcode.com/zhengyaowei/LLaMA-Factory)
[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1eRTPn37ltBbYsISy9Aw2NuI2Aq5CQrD9?usp=sharing)
[![Open in DSW](https://gallery.pai-ml.com/assets/open-in-dsw.svg)](https://gallery.pai-ml.com/#/preview/deepLearning/nlp/llama_factory)
[![Spaces](https://img.shields.io/badge/🤗-Open%20in%20Spaces-blue)](https://huggingface.co/spaces/hiyouga/LLaMA-Board)
[![Studios](https://img.shields.io/badge/ModelScope-Open%20in%20Studios-blue)](https://modelscope.cn/studios/hiyouga/LLaMA-Board)
[![SageMaker](https://img.shields.io/badge/SageMaker-Open%20in%20AWS-blue)](https://aws.amazon.com/cn/blogs/china/a-one-stop-code-free-model-fine-tuning-deployment-platform-based-on-sagemaker-and-llama-factory/)
[![GitHub Tread](https://trendshift.io/api/badge/repositories/4535)](https://trendshift.io/repositories/4535)
<h3 align="center">
Easily fine-tune 100+ large language models with zero-code <a href="#quickstart">CLI</a> and <a href="#fine-tuning-with-llama-board-gui-powered-by-gradio">Web UI</a>
</h3>
<p align="center">
<picture>
<img alt="Github trend" src="https://trendshift.io/api/badge/repositories/4535">
</picture>
</p>
👋 Join our [WeChat](assets/wechat.jpg) or [NPU user group](assets/wechat_npu.jpg).
@@ -26,16 +37,12 @@ https://github.com/user-attachments/assets/7c96b465-9df7-45f4-8053-bf03e58386d3
Choose your path:
- **Documentation (WIP)**: https://llamafactory.readthedocs.io/zh-cn/latest/
- **Colab**: https://colab.research.google.com/drive/1eRTPn37ltBbYsISy9Aw2NuI2Aq5CQrD9?usp=sharing
- **Documentation**: https://llamafactory.readthedocs.io/en/latest/
- **Colab (free)**: https://colab.research.google.com/drive/1eRTPn37ltBbYsISy9Aw2NuI2Aq5CQrD9?usp=sharing
- **Local machine**: Please refer to [usage](#getting-started)
- **PAI-DSW**: [Llama3 Example](https://gallery.pai-ml.com/#/preview/deepLearning/nlp/llama_factory) | [Qwen2-VL Example](https://gallery.pai-ml.com/#/preview/deepLearning/nlp/llama_factory_qwen2vl)
- **PAI-DSW (free trial)**: [Llama3 Example](https://gallery.pai-ml.com/#/preview/deepLearning/nlp/llama_factory) | [Qwen2-VL Example](https://gallery.pai-ml.com/#/preview/deepLearning/nlp/llama_factory_qwen2vl) | [DeepSeek-R1-Distill Example](https://gallery.pai-ml.com/#/preview/deepLearning/nlp/llama_factory_deepseek_r1_distill_7b)
- **Amazon SageMaker**: [Blog](https://aws.amazon.com/cn/blogs/china/a-one-stop-code-free-model-fine-tuning-deployment-platform-based-on-sagemaker-and-llama-factory/)
Recent activities:
- **2024/10/18-2024/11/30**: Build a personal tour guide bot using PAI+LLaMA Factory. [[website]](https://developer.aliyun.com/topic/llamafactory2)
> [!NOTE]
> Except for the links above, all other websites are unauthorized third-party websites. Please use them with caution.
@@ -49,6 +56,16 @@ Recent activities:
- [Provided Datasets](#provided-datasets)
- [Requirement](#requirement)
- [Getting Started](#getting-started)
- [Installation](#installation)
- [Data Preparation](#data-preparation)
- [Quickstart](#quickstart)
- [Fine-Tuning with LLaMA Board GUI](#fine-tuning-with-llama-board-gui-powered-by-gradio)
- [Build Docker](#build-docker)
- [Deploy with OpenAI-style API and vLLM](#deploy-with-openai-style-api-and-vllm)
- [Download from ModelScope Hub](#download-from-modelscope-hub)
- [Download from Modelers Hub](#download-from-modelers-hub)
- [Use W&B Logger](#use-wb-logger)
- [Use SwanLab Logger](#use-swanlab-logger)
- [Projects using LLaMA Factory](#projects-using-llama-factory)
- [License](#license)
- [Citation](#citation)
@@ -56,14 +73,22 @@ Recent activities:
## Features
- **Various models**: LLaMA, LLaVA, Mistral, Mixtral-MoE, Qwen, Qwen2-VL, Yi, Gemma, Baichuan, ChatGLM, Phi, etc.
- **Various models**: LLaMA, LLaVA, Mistral, Mixtral-MoE, Qwen, Qwen2-VL, DeepSeek, Yi, Gemma, ChatGLM, Phi, etc.
- **Integrated methods**: (Continuous) pre-training, (multimodal) supervised fine-tuning, reward modeling, PPO, DPO, KTO, ORPO, etc.
- **Scalable resources**: 16-bit full-tuning, freeze-tuning, LoRA and 2/3/4/5/6/8-bit QLoRA via AQLM/AWQ/GPTQ/LLM.int8/HQQ/EETQ.
- **Advanced algorithms**: [GaLore](https://github.com/jiaweizzhao/GaLore), [BAdam](https://github.com/Ledzy/BAdam), [Adam-mini](https://github.com/zyushun/Adam-mini), DoRA, LongLoRA, LLaMA Pro, Mixture-of-Depths, LoRA+, LoftQ, PiSSA and Agent tuning.
- **Advanced algorithms**: [GaLore](https://github.com/jiaweizzhao/GaLore), [BAdam](https://github.com/Ledzy/BAdam), [APOLLO](https://github.com/zhuhanqing/APOLLO), [Adam-mini](https://github.com/zyushun/Adam-mini), DoRA, LongLoRA, LLaMA Pro, Mixture-of-Depths, LoRA+, LoftQ and PiSSA.
- **Practical tricks**: [FlashAttention-2](https://github.com/Dao-AILab/flash-attention), [Unsloth](https://github.com/unslothai/unsloth), [Liger Kernel](https://github.com/linkedin/Liger-Kernel), RoPE scaling, NEFTune and rsLoRA.
- **Experiment monitors**: LlamaBoard, TensorBoard, Wandb, MLflow, etc.
- **Wide tasks**: Multi-turn dialogue, tool use, image understanding, visual grounding, video recognition, audio understanding, etc.
- **Experiment monitors**: LlamaBoard, TensorBoard, Wandb, MLflow, SwanLab, etc.
- **Faster inference**: OpenAI-style API, Gradio UI and CLI with vLLM worker.
### Day-N Support for Fine-Tuning Cutting-Edge Models
| Support Date | Model Name |
| ------------ | ---------------------------------------------------------- |
| Day 0 | Qwen2.5 / Qwen2-VL / QwQ / QVQ / InternLM3 / MiniCPM-o-2.6 |
| Day 1 | Llama 3 / GLM-4 / Mistral Small / PaliGemma2 |
## Benchmark
Compared to ChatGLM's [P-Tuning](https://github.com/THUDM/ChatGLM2-6B/tree/main/ptuning), LLaMA Factory's LoRA tuning offers up to **3.7 times faster** training speed with a better Rouge score on the advertising text generation task. By leveraging the 4-bit quantization technique, LLaMA Factory's QLoRA further improves efficiency in terms of GPU memory usage.
@@ -81,21 +106,41 @@ Compared to ChatGLM's [P-Tuning](https://github.com/THUDM/ChatGLM2-6B/tree/main/
## Changelog
[24/10/09] We supported downloading pre-trained models and datasets from the **[Modelers Hub](https://modelers.cn/models)**. See [this tutorial](#download-from-modelers-hub) for usage.
[25/02/24] Announcing **[EasyR1](https://github.com/hiyouga/EasyR1)**, an efficient, scalable and multi-modality RL training framework for GRPO training.
[24/09/19] We support fine-tuning the **[Qwen2.5](https://qwenlm.github.io/blog/qwen2.5/)** models.
[25/02/11] We supported saving the **[Ollama](https://github.com/ollama/ollama)** modelfile when exporting the model checkpoints. See [examples](examples/README.md) for usage.
[24/08/30] We support fine-tuning the **[Qwen2-VL](https://qwenlm.github.io/blog/qwen2-vl/)** models. Thank [@simonJJJ](https://github.com/simonJJJ)'s PR.
[25/02/05] We supported fine-tuning the **[Qwen2-Audio](https://huggingface.co/Qwen/Qwen2-Audio-7B-Instruct)** and **[MiniCPM-o-2.6](https://huggingface.co/openbmb/MiniCPM-o-2_6)** models on audio understanding tasks.
[24/08/27] We support **[Liger Kernel](https://github.com/linkedin/Liger-Kernel)**. Try `enable_liger_kernel: true` for efficient training.
[24/08/09] We support **[Adam-mini](https://github.com/zyushun/Adam-mini)** optimizer. See [examples](examples/README.md) for usage. Thank [@relic-yuexi](https://github.com/relic-yuexi)'s PR.
[25/01/31] We supported fine-tuning the **[DeepSeek-R1](https://huggingface.co/deepseek-ai/DeepSeek-R1)** and **[Qwen2.5-VL](https://huggingface.co/Qwen/Qwen2.5-VL-7B-Instruct)** models.
<details><summary>Full Changelog</summary>
[24/07/04] We support [contamination-free packed training](https://github.com/MeetKai/functionary/tree/main/functionary/train/packing). Use `neat_packing: true` to activate it. Thank [@chuan298](https://github.com/chuan298)'s PR.
[25/01/15] We supported **[APOLLO](https://arxiv.org/abs/2412.05270)** optimizer. See [examples](examples/README.md) for usage.
[24/06/16] We support **[PiSSA](https://arxiv.org/abs/2404.02948)** algorithm. See [examples](examples/README.md) for usage.
[25/01/14] We supported fine-tuning the **[MiniCPM-o-2.6](https://huggingface.co/openbmb/MiniCPM-o-2_6)** and **[MiniCPM-V-2.6](https://huggingface.co/openbmb/MiniCPM-V-2_6)** models. Thank [@BUAADreamer](https://github.com/BUAADreamer)'s PR.
[25/01/14] We supported fine-tuning the **[InternLM3](https://huggingface.co/collections/internlm/)** models. Thank [@hhaAndroid](https://github.com/hhaAndroid)'s PR.
[25/01/10] We supported fine-tuning the **[Phi-4](https://huggingface.co/microsoft/phi-4)** model.
[24/12/21] We supported using **[SwanLab](https://github.com/SwanHubX/SwanLab)** for experiment tracking and visualization. See [this section](#use-swanlab-logger) for details.
[24/11/27] We supported fine-tuning the **[Skywork-o1](https://huggingface.co/Skywork/Skywork-o1-Open-Llama-3.1-8B)** model and the **[OpenO1](https://huggingface.co/datasets/O1-OPEN/OpenO1-SFT)** dataset.
[24/10/09] We supported downloading pre-trained models and datasets from the **[Modelers Hub](https://modelers.cn/models)**. See [this tutorial](#download-from-modelers-hub) for usage.
[24/09/19] We supported fine-tuning the **[Qwen2.5](https://qwenlm.github.io/blog/qwen2.5/)** models.
[24/08/30] We supported fine-tuning the **[Qwen2-VL](https://qwenlm.github.io/blog/qwen2-vl/)** models. Thank [@simonJJJ](https://github.com/simonJJJ)'s PR.
[24/08/27] We supported **[Liger Kernel](https://github.com/linkedin/Liger-Kernel)**. Try `enable_liger_kernel: true` for efficient training.
[24/08/09] We supported **[Adam-mini](https://github.com/zyushun/Adam-mini)** optimizer. See [examples](examples/README.md) for usage. Thank [@relic-yuexi](https://github.com/relic-yuexi)'s PR.
[24/07/04] We supported [contamination-free packed training](https://github.com/MeetKai/functionary/tree/main/functionary/train/packing). Use `neat_packing: true` to activate it. Thank [@chuan298](https://github.com/chuan298)'s PR.
[24/06/16] We supported **[PiSSA](https://arxiv.org/abs/2404.02948)** algorithm. See [examples](examples/README.md) for usage.
[24/06/07] We supported fine-tuning the **[Qwen2](https://qwenlm.github.io/blog/qwen2/)** and **[GLM-4](https://github.com/THUDM/GLM-4)** models.
@@ -173,40 +218,51 @@ Compared to ChatGLM's [P-Tuning](https://github.com/THUDM/ChatGLM2-6B/tree/main/
## Supported Models
| Model | Model size | Template |
| ----------------------------------------------------------------- | -------------------------------- | ---------------- |
| [Baichuan 2](https://huggingface.co/baichuan-inc) | 7B/13B | baichuan2 |
| [BLOOM/BLOOMZ](https://huggingface.co/bigscience) | 560M/1.1B/1.7B/3B/7.1B/176B | - |
| [ChatGLM3](https://huggingface.co/THUDM) | 6B | chatglm3 |
| [Command R](https://huggingface.co/CohereForAI) | 35B/104B | cohere |
| [DeepSeek (Code/MoE)](https://huggingface.co/deepseek-ai) | 7B/16B/67B/236B | deepseek |
| [Falcon](https://huggingface.co/tiiuae) | 7B/11B/40B/180B | falcon |
| [Gemma/Gemma 2/CodeGemma](https://huggingface.co/google) | 2B/7B/9B/27B | gemma |
| [GLM-4](https://huggingface.co/THUDM) | 9B | glm4 |
| [Index](https://huggingface.co/IndexTeam) | 1.9B | index |
| [InternLM2/InternLM2.5](https://huggingface.co/internlm) | 7B/20B | intern2 |
| [Llama](https://github.com/facebookresearch/llama) | 7B/13B/33B/65B | - |
| [Llama 2](https://huggingface.co/meta-llama) | 7B/13B/70B | llama2 |
| [Llama 3-3.2](https://huggingface.co/meta-llama) | 1B/3B/8B/70B | llama3 |
| [Llama 3.2 Vision](https://huggingface.co/meta-llama) | 11B/90B | mllama |
| [LLaVA-1.5](https://huggingface.co/llava-hf) | 7B/13B | llava |
| [LLaVA-NeXT](https://huggingface.co/llava-hf) | 7B/8B/13B/34B/72B/110B | llava_next |
| [LLaVA-NeXT-Video](https://huggingface.co/llava-hf) | 7B/34B | llava_next_video |
| [MiniCPM](https://huggingface.co/openbmb) | 1B/2B/4B | cpm/cpm3 |
| [Mistral/Mixtral](https://huggingface.co/mistralai) | 7B/8x7B/8x22B | mistral |
| [OLMo](https://huggingface.co/allenai) | 1B/7B | - |
| [PaliGemma](https://huggingface.co/google) | 3B | paligemma |
| [Phi-1.5/Phi-2](https://huggingface.co/microsoft) | 1.3B/2.7B | - |
| [Phi-3](https://huggingface.co/microsoft) | 4B/14B | phi |
| [Phi-3-small](https://huggingface.co/microsoft) | 7B | phi_small |
| [Pixtral](https://huggingface.co/mistralai) | 12B | pixtral |
| [Qwen (1-2.5) (Code/Math/MoE)](https://huggingface.co/Qwen) | 0.5B/1.5B/3B/7B/14B/32B/72B/110B | qwen |
| [Qwen2-VL](https://huggingface.co/Qwen) | 2B/7B/72B | qwen2_vl |
| [StarCoder 2](https://huggingface.co/bigcode) | 3B/7B/15B | - |
| [XVERSE](https://huggingface.co/xverse) | 7B/13B/65B | xverse |
| [Yi/Yi-1.5 (Code)](https://huggingface.co/01-ai) | 1.5B/6B/9B/34B | yi |
| [Yi-VL](https://huggingface.co/01-ai) | 6B/34B | yi_vl |
| [Yuan 2](https://huggingface.co/IEITYuan) | 2B/51B/102B | yuan |
| Model | Model size | Template |
| ----------------------------------------------------------------- | -------------------------------- | ------------------- |
| [Baichuan 2](https://huggingface.co/baichuan-inc) | 7B/13B | baichuan2 |
| [BLOOM/BLOOMZ](https://huggingface.co/bigscience) | 560M/1.1B/1.7B/3B/7.1B/176B | - |
| [ChatGLM3](https://huggingface.co/THUDM) | 6B | chatglm3 |
| [Command R](https://huggingface.co/CohereForAI) | 35B/104B | cohere |
| [DeepSeek (Code/MoE)](https://huggingface.co/deepseek-ai) | 7B/16B/67B/236B | deepseek |
| [DeepSeek 2.5/3](https://huggingface.co/deepseek-ai) | 236B/671B | deepseek3 |
| [DeepSeek R1 (Distill)](https://huggingface.co/deepseek-ai) | 1.5B/7B/8B/14B/32B/70B/671B | deepseek3 |
| [Falcon](https://huggingface.co/tiiuae) | 7B/11B/40B/180B | falcon |
| [Gemma/Gemma 2/CodeGemma](https://huggingface.co/google) | 2B/7B/9B/27B | gemma |
| [GLM-4](https://huggingface.co/THUDM) | 9B | glm4 |
| [GPT-2](https://huggingface.co/openai-community) | 0.1B/0.4B/0.8B/1.5B | - |
| [Granite 3.0-3.1](https://huggingface.co/ibm-granite) | 1B/2B/3B/8B | granite3 |
| [Index](https://huggingface.co/IndexTeam) | 1.9B | index |
| [InternLM 2-3](https://huggingface.co/internlm) | 7B/8B/20B | intern2 |
| [Llama](https://github.com/facebookresearch/llama) | 7B/13B/33B/65B | - |
| [Llama 2](https://huggingface.co/meta-llama) | 7B/13B/70B | llama2 |
| [Llama 3-3.3](https://huggingface.co/meta-llama) | 1B/3B/8B/70B | llama3 |
| [Llama 3.2 Vision](https://huggingface.co/meta-llama) | 11B/90B | mllama |
| [LLaVA-1.5](https://huggingface.co/llava-hf) | 7B/13B | llava |
| [LLaVA-NeXT](https://huggingface.co/llava-hf) | 7B/8B/13B/34B/72B/110B | llava_next |
| [LLaVA-NeXT-Video](https://huggingface.co/llava-hf) | 7B/34B | llava_next_video |
| [MiniCPM](https://huggingface.co/openbmb) | 1B/2B/4B | cpm/cpm3 |
| [MiniCPM-o-2.6/MiniCPM-V-2.6](https://huggingface.co/openbmb) | 8B | minicpm_o/minicpm_v |
| [Ministral/Mistral-Nemo](https://huggingface.co/mistralai) | 8B/12B | ministral |
| [Mistral/Mixtral](https://huggingface.co/mistralai) | 7B/8x7B/8x22B | mistral |
| [Mistral Small](https://huggingface.co/mistralai) | 24B | mistral_small |
| [OLMo](https://huggingface.co/allenai) | 1B/7B | - |
| [PaliGemma/PaliGemma2](https://huggingface.co/google) | 3B/10B/28B | paligemma |
| [Phi-1.5/Phi-2](https://huggingface.co/microsoft) | 1.3B/2.7B | - |
| [Phi-3/Phi-3.5](https://huggingface.co/microsoft) | 4B/14B | phi |
| [Phi-3-small](https://huggingface.co/microsoft) | 7B | phi_small |
| [Phi-4](https://huggingface.co/microsoft) | 14B | phi4 |
| [Pixtral](https://huggingface.co/mistralai) | 12B | pixtral |
| [Qwen/QwQ (1-2.5) (Code/Math/MoE)](https://huggingface.co/Qwen) | 0.5B/1.5B/3B/7B/14B/32B/72B/110B | qwen |
| [Qwen2-Audio](https://huggingface.co/Qwen) | 7B | qwen2_audio |
| [Qwen2-VL/Qwen2.5-VL/QVQ](https://huggingface.co/Qwen) | 2B/3B/7B/72B | qwen2_vl |
| [Skywork o1](https://huggingface.co/Skywork) | 8B | skywork_o1 |
| [StarCoder 2](https://huggingface.co/bigcode) | 3B/7B/15B | - |
| [TeleChat2](https://huggingface.co/Tele-AI) | 3B/7B/35B/115B | telechat2 |
| [XVERSE](https://huggingface.co/xverse) | 7B/13B/65B | xverse |
| [Yi/Yi-1.5 (Code)](https://huggingface.co/01-ai) | 1.5B/6B/9B/34B | yi |
| [Yi-VL](https://huggingface.co/01-ai) | 6B/34B | yi_vl |
| [Yuan 2](https://huggingface.co/IEITYuan) | 2B/51B/102B | yuan |
> [!NOTE]
> For the "base" models, the `template` argument can be chosen from `default`, `alpaca`, `vicuna` etc. But make sure to use the **corresponding template** for the "instruct/chat" models.
@@ -290,9 +346,13 @@ You also can add a custom chat template to [template.py](src/llamafactory/data/t
- [STEM (zh)](https://huggingface.co/datasets/hfl/stem_zh_instruction)
- [Ruozhiba (zh)](https://huggingface.co/datasets/hfl/ruozhiba_gpt4_turbo)
- [Neo-sft (zh)](https://huggingface.co/datasets/m-a-p/neo_sft_phase2)
- [WebInstructSub (en)](https://huggingface.co/datasets/TIGER-Lab/WebInstructSub)
- [Magpie-Pro-300K-Filtered (en)](https://huggingface.co/datasets/Magpie-Align/Magpie-Pro-300K-Filtered)
- [Magpie-ultra-v0.1 (en)](https://huggingface.co/datasets/argilla/magpie-ultra-v0.1)
- [WebInstructSub (en)](https://huggingface.co/datasets/TIGER-Lab/WebInstructSub)
- [OpenO1-SFT (en&zh)](https://huggingface.co/datasets/O1-OPEN/OpenO1-SFT)
- [Open-Thoughts (en)](https://huggingface.co/datasets/open-thoughts/OpenThoughts-114k)
- [Open-R1-Math (en)](https://huggingface.co/datasets/open-r1/OpenR1-Math-220k)
- [Chinese-DeepSeek-R1-Distill (zh)](https://huggingface.co/datasets/Congliu/Chinese-DeepSeek-R1-Distill-data-110k-SFT)
- [LLaVA mixed (en&zh)](https://huggingface.co/datasets/BUAADreamer/llava-en-zh-300k)
- [Pokemon-gpt4o-captions (en&zh)](https://huggingface.co/datasets/jugg1024/pokemon-gpt4o-captions)
- [Open Assistant (de)](https://huggingface.co/datasets/mayflowergmbh/oasst_de)
@@ -332,35 +392,34 @@ huggingface-cli login
| Mandatory | Minimum | Recommend |
| ------------ | ------- | --------- |
| python | 3.8 | 3.11 |
| torch | 1.13.1 | 2.4.0 |
| transformers | 4.41.2 | 4.43.4 |
| datasets | 2.16.0 | 2.20.0 |
| accelerate | 0.30.1 | 0.32.0 |
| python | 3.9 | 3.10 |
| torch | 1.13.1 | 2.5.1 |
| transformers | 4.41.2 | 4.49.0 |
| datasets | 2.16.0 | 3.2.0 |
| accelerate | 0.34.0 | 1.2.1 |
| peft | 0.11.1 | 0.12.0 |
| trl | 0.8.6 | 0.9.6 |
| Optional | Minimum | Recommend |
| ------------ | ------- | --------- |
| CUDA | 11.6 | 12.2 |
| deepspeed | 0.10.0 | 0.14.0 |
| deepspeed | 0.10.0 | 0.16.4 |
| bitsandbytes | 0.39.0 | 0.43.1 |
| vllm | 0.4.3 | 0.5.0 |
| flash-attn | 2.3.0 | 2.6.3 |
| vllm | 0.4.3 | 0.7.3 |
| flash-attn | 2.3.0 | 2.7.2 |
### Hardware Requirement
\* *estimated*
| Method | Bits | 7B | 13B | 30B | 70B | 110B | 8x7B | 8x22B |
| ----------------- | ---- | ----- | ----- | ----- | ------ | ------ | ----- | ------ |
| Full | AMP | 120GB | 240GB | 600GB | 1200GB | 2000GB | 900GB | 2400GB |
| Full | 16 | 60GB | 120GB | 300GB | 600GB | 900GB | 400GB | 1200GB |
| Freeze | 16 | 20GB | 40GB | 80GB | 200GB | 360GB | 160GB | 400GB |
| LoRA/GaLore/BAdam | 16 | 16GB | 32GB | 64GB | 160GB | 240GB | 120GB | 320GB |
| QLoRA | 8 | 10GB | 20GB | 40GB | 80GB | 140GB | 60GB | 160GB |
| QLoRA | 4 | 6GB | 12GB | 24GB | 48GB | 72GB | 30GB | 96GB |
| QLoRA | 2 | 4GB | 8GB | 16GB | 24GB | 48GB | 18GB | 48GB |
| Method | Bits | 7B | 14B | 30B | 70B | `x`B |
| ------------------------------- | ---- | ----- | ----- | ----- | ------ | ------- |
| Full (`bf16` or `fp16`) | 32 | 120GB | 240GB | 600GB | 1200GB | `18x`GB |
| Full (`pure_bf16`) | 16 | 60GB | 120GB | 300GB | 600GB | `8x`GB |
| Freeze/LoRA/GaLore/APOLLO/BAdam | 16 | 16GB | 32GB | 64GB | 160GB | `2x`GB |
| QLoRA | 8 | 10GB | 20GB | 40GB | 80GB | `x`GB |
| QLoRA | 4 | 6GB | 12GB | 24GB | 48GB | `x/2`GB |
| QLoRA | 2 | 4GB | 8GB | 16GB | 24GB | `x/4`GB |
## Getting Started
@@ -375,47 +434,67 @@ cd LLaMA-Factory
pip install -e ".[torch,metrics]"
```
Extra dependencies available: torch, torch-npu, metrics, deepspeed, liger-kernel, bitsandbytes, hqq, eetq, gptq, awq, aqlm, vllm, galore, badam, adam-mini, qwen, modelscope, openmind, quality
Extra dependencies available: torch, torch-npu, metrics, deepspeed, liger-kernel, bitsandbytes, hqq, eetq, gptq, awq, aqlm, vllm, galore, apollo, badam, adam-mini, qwen, minicpm_v, modelscope, openmind, swanlab, quality
> [!TIP]
> Use `pip install --no-deps -e .` to resolve package conflicts.
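For example, a hypothetical install that also pulls in the vLLM and SwanLab extras (pick only the extras you actually need from the list above) could look like:

```bash
# editable install with optional extras for vLLM-based inference and SwanLab tracking
pip install -e ".[torch,metrics,vllm,swanlab]"
```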
<details><summary>Setting up a virtual environment with <b>uv</b></summary>
Create an isolated Python environment with [uv](https://github.com/astral-sh/uv):
```bash
uv sync --extra torch --extra metrics --prerelease=allow
```
Run LLaMA-Factory in the isolated environment:
```bash
uv run --prerelease=allow llamafactory-cli train examples/train_lora/llama3_lora_pretrain.yaml
```
</details>
<details><summary>For Windows users</summary>
#### Install BitsAndBytes
If you want to enable the quantized LoRA (QLoRA) on the Windows platform, you need to install a pre-built version of the `bitsandbytes` library, which supports CUDA 11.1 to 12.2. Please select the appropriate [release version](https://github.com/jllllll/bitsandbytes-windows-webui/releases/tag/wheels) based on your CUDA version.
```bash
pip install https://github.com/jllllll/bitsandbytes-windows-webui/releases/download/wheels/bitsandbytes-0.41.2.post2-py3-none-win_amd64.whl
```
To enable FlashAttention-2 on the Windows platform, you need to install the precompiled `flash-attn` library, which supports CUDA 12.1 to 12.2. Please download the corresponding version from [flash-attention](https://github.com/bdashore3/flash-attention/releases) based on your requirements.
#### Install Flash Attention-2
To enable FlashAttention-2 on the Windows platform, please use the script from [flash-attention-windows-wheel](https://huggingface.co/lldacing/flash-attention-windows-wheel) to compile and install it by yourself.
</details>
<details><summary>For Ascend NPU users</summary>
To install LLaMA Factory on Ascend NPU devices, please specify extra dependencies: `pip install -e ".[torch-npu,metrics]"`. Additionally, you need to install the **[Ascend CANN Toolkit and Kernels](https://www.hiascend.com/developer/download/community/result?module=cann)**. Please follow the [installation tutorial](https://www.hiascend.com/document/detail/en/CANNCommunityEdition/600alphaX/softwareinstall/instg/atlasdeploy_03_0031.html) or use the following commands:
To install LLaMA Factory on Ascend NPU devices, please upgrade Python to version 3.10 or higher and specify extra dependencies: `pip install -e ".[torch-npu,metrics]"`. Additionally, you need to install the **[Ascend CANN Toolkit and Kernels](https://www.hiascend.com/developer/download/community/result?module=cann)**. Please follow the [installation tutorial](https://www.hiascend.com/document/detail/en/CANNCommunityEdition/600alphaX/softwareinstall/instg/atlasdeploy_03_0031.html) or use the following commands:
```bash
# replace the url according to your CANN version and devices
# install CANN Toolkit
wget https://ascend-repo.obs.cn-east-2.myhuaweicloud.com/Milan-ASL/Milan-ASL%20V100R001C17SPC701/Ascend-cann-toolkit_8.0.RC1.alpha001_linux-"$(uname -i)".run
bash Ascend-cann-toolkit_8.0.RC1.alpha001_linux-"$(uname -i)".run --install
wget https://ascend-repo.obs.cn-east-2.myhuaweicloud.com/Milan-ASL/Milan-ASL%20V100R001C20SPC702/Ascend-cann-toolkit_8.0.0.alpha002_linux-"$(uname -i)".run
bash Ascend-cann-toolkit_8.0.0.alpha002_linux-"$(uname -i)".run --install
# install CANN Kernels
wget https://ascend-repo.obs.cn-east-2.myhuaweicloud.com/Milan-ASL/Milan-ASL%20V100R001C17SPC701/Ascend-cann-kernels-910b_8.0.RC1.alpha001_linux.run
bash Ascend-cann-kernels-910b_8.0.RC1.alpha001_linux.run --install
wget https://ascend-repo.obs.cn-east-2.myhuaweicloud.com/Milan-ASL/Milan-ASL%20V100R001C20SPC702/Ascend-cann-kernels-910b_8.0.0.alpha002_linux-"$(uname -i)".run
bash Ascend-cann-kernels-910b_8.0.0.alpha002_linux-"$(uname -i)".run --install
# set env variables
source /usr/local/Ascend/ascend-toolkit/set_env.sh
```
| Requirement | Minimum | Recommend |
| ------------ | ------- | ----------- |
| CANN | 8.0.RC1 | 8.0.RC1 |
| torch | 2.1.0 | 2.1.0 |
| torch-npu | 2.1.0 | 2.1.0.post3 |
| deepspeed | 0.13.2 | 0.13.2 |
| Requirement | Minimum | Recommend |
| ------------ | ------- | -------------- |
| CANN | 8.0.RC1 | 8.0.0.alpha002 |
| torch | 2.1.0 | 2.4.0 |
| torch-npu | 2.1.0 | 2.4.0.post2 |
| deepspeed | 0.13.2 | 0.13.2 |
Remember to use `ASCEND_RT_VISIBLE_DEVICES` instead of `CUDA_VISIBLE_DEVICES` to specify the device to use.
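As a minimal illustration (the device IDs are placeholders and `examples/train_lora/llama3_lora_sft.yaml` refers to the bundled LoRA example config), pinning a run to specific NPUs might look like:

```bash
# run LoRA fine-tuning on the first two Ascend NPU devices only
ASCEND_RT_VISIBLE_DEVICES=0,1 llamafactory-cli train examples/train_lora/llama3_lora_sft.yaml
```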
@@ -423,6 +502,40 @@ If you cannot infer model on NPU devices, try setting `do_sample: false` in the
Download the pre-built Docker images: [32GB](http://mirrors.cn-central-221.ovaijisuan.com/detail/130.html) | [64GB](http://mirrors.cn-central-221.ovaijisuan.com/detail/131.html)
#### Install BitsAndBytes
To use QLoRA based on bitsandbytes on Ascend NPU, please follow these 3 steps:
1. Manually compile bitsandbytes: Refer to [the installation documentation](https://huggingface.co/docs/bitsandbytes/installation?backend=Ascend+NPU&platform=Ascend+NPU) for the NPU version of bitsandbytes to complete the compilation and installation. The compilation requires a cmake version of at least 3.22.1 and a g++ version of at least 12.x.
```bash
# Install bitsandbytes from source
# Clone the bitsandbytes repo; the Ascend NPU backend is currently enabled on the multi-backend-refactor branch
git clone -b multi-backend-refactor https://github.com/bitsandbytes-foundation/bitsandbytes.git
cd bitsandbytes/
# Install dependencies
pip install -r requirements-dev.txt
# Install the dependencies for the compilation tools. Note that the commands for this step may vary depending on the operating system. The following are provided for reference
apt-get install -y build-essential cmake
# Compile & install
cmake -DCOMPUTE_BACKEND=npu -S .
make
pip install .
```
2. Install transformers from the main branch.
```bash
git clone -b main https://github.com/huggingface/transformers.git
cd transformers
pip install .
```
3. Set `double_quantization: false` in the configuration. You can refer to the [example](examples/train_qlora/llama3_lora_sft_bnb_npu.yaml).
</details>
### Data Preparation
@@ -446,6 +559,8 @@ See [examples/README.md](examples/README.md) for advanced usage (including distr
> [!TIP]
> Use `llamafactory-cli help` to show help information.
>
> Read [FAQs](https://github.com/hiyouga/LLaMA-Factory/issues/4614) first if you encounter any problems.
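For reference, the basic quickstart workflow boils down to three commands, shown here as a sketch using the bundled Llama3 LoRA example configs:

```bash
# LoRA fine-tuning, interactive inference, and adapter merging with the example configs
llamafactory-cli train examples/train_lora/llama3_lora_sft.yaml
llamafactory-cli chat examples/inference/llama3_lora_sft.yaml
llamafactory-cli export examples/merge_lora/llama3_lora_sft.yaml
```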
### Fine-Tuning with LLaMA Board GUI (powered by [Gradio](https://github.com/gradio-app/gradio))
@@ -590,7 +705,7 @@ API_PORT=8000 llamafactory-cli api examples/inference/llama3_vllm.yaml
> [!TIP]
> Visit [this page](https://platform.openai.com/docs/api-reference/chat/create) for API document.
>
> Examples: [Image understanding](scripts/test_image.py) | [Function calling](scripts/test_toolcall.py)
> Examples: [Image understanding](scripts/api_example/test_image.py) | [Function calling](scripts/api_example/test_toolcall.py)
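Because the server speaks the OpenAI chat completions protocol, any OpenAI-compatible client can call it. A minimal curl sketch against the endpoint launched above (the model name is a placeholder; adjust host and port to your deployment):

```bash
# send a chat request to the OpenAI-style API started on port 8000
curl http://localhost:8000/v1/chat/completions \
  -H "Content-Type: application/json" \
  -d '{
        "model": "llama3",
        "messages": [{"role": "user", "content": "Hello!"}]
      }'
```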
### Download from ModelScope Hub
@@ -623,6 +738,21 @@ run_name: test_run # optional
Set `WANDB_API_KEY` to [your key](https://wandb.ai/authorize) when launching training tasks to log in with your W&B account.
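For instance, a one-off launch that supplies the key through the environment (the key value and config path are placeholders) might look like:

```bash
# pass the W&B API key via the environment for this training run
WANDB_API_KEY=your_api_key llamafactory-cli train examples/train_lora/llama3_lora_sft.yaml
```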
### Use SwanLab Logger
To use [SwanLab](https://github.com/SwanHubX/SwanLab) for logging experimental results, you need to add the following arguments to yaml files.
```yaml
use_swanlab: true
swanlab_run_name: test_run # optional
```
When launching training tasks, you can log in to SwanLab in one of three ways (a minimal shell example follows this list):
1. Add `swanlab_api_key=<your_api_key>` to the yaml file, and set it to your [API key](https://swanlab.cn/settings).
2. Set the environment variable `SWANLAB_API_KEY` to your [API key](https://swanlab.cn/settings).
3. Use the `swanlab login` command to complete the login.
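A minimal shell sketch of options 2 and 3 (the key value and config path are placeholders):

```bash
# option 2: supply the SwanLab API key via the environment when launching training
SWANLAB_API_KEY=your_api_key llamafactory-cli train examples/train_lora/llama3_lora_sft.yaml

# option 3: alternatively, log in once beforehand with the SwanLab CLI
swanlab login
```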
## Projects using LLaMA Factory
If you have a project that should be incorporated, please contact us via email or create a pull request.
@@ -722,7 +852,8 @@ If you have a project that should be incorporated, please contact via email or c
1. **[NVIDIA RTX AI Toolkit](https://github.com/NVIDIA/RTX-AI-Toolkit)**: SDKs for fine-tuning LLMs on Windows PC for NVIDIA RTX.
1. **[LazyLLM](https://github.com/LazyAGI/LazyLLM)**: An easy and lazy way for building multi-agent LLMs applications and supports model fine-tuning via LLaMA Factory.
1. **[RAG-Retrieval](https://github.com/NLPJCL/RAG-Retrieval)**: A full pipeline for RAG retrieval model fine-tuning, inference, and distillation. [[blog]](https://zhuanlan.zhihu.com/p/987727357)
1. **[360-LLaMA-Factory](https://github.com/Qihoo360/360-LLaMA-Factory)**: A modified library that supports long sequence SFT & DPO using ring attention.
1. **[Sky-T1](https://novasky-ai.github.io/posts/sky-t1/)**: An o1-like model fine-tuned by NovaSky AI with very small cost.
</details>
@@ -730,7 +861,7 @@ If you have a project that should be incorporated, please contact via email or c
This repository is licensed under the [Apache-2.0 License](LICENSE).
Please follow the model licenses to use the corresponding model weights: [Baichuan 2](https://huggingface.co/baichuan-inc/Baichuan2-7B-Base/blob/main/Community%20License%20for%20Baichuan%202%20Model.pdf) / [BLOOM](https://huggingface.co/spaces/bigscience/license) / [ChatGLM3](https://github.com/THUDM/ChatGLM3/blob/main/MODEL_LICENSE) / [Command R](https://cohere.com/c4ai-cc-by-nc-license) / [DeepSeek](https://github.com/deepseek-ai/DeepSeek-LLM/blob/main/LICENSE-MODEL) / [Falcon](https://huggingface.co/tiiuae/falcon-180B/blob/main/LICENSE.txt) / [Gemma](https://ai.google.dev/gemma/terms) / [GLM-4](https://huggingface.co/THUDM/glm-4-9b/blob/main/LICENSE) / [Index](https://huggingface.co/IndexTeam/Index-1.9B/blob/main/LICENSE) / [InternLM2](https://github.com/InternLM/InternLM#license) / [Llama](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) / [Llama 2 (LLaVA-1.5)](https://ai.meta.com/llama/license/) / [Llama 3](https://llama.meta.com/llama3/license/) / [MiniCPM](https://github.com/OpenBMB/MiniCPM/blob/main/MiniCPM%20Model%20License.md) / [Mistral/Mixtral/Pixtral](LICENSE) / [OLMo](LICENSE) / [Phi-1.5/Phi-2](https://huggingface.co/microsoft/phi-1_5/resolve/main/Research%20License.docx) / [Phi-3](https://huggingface.co/microsoft/Phi-3-mini-4k-instruct/blob/main/LICENSE) / [Qwen](https://github.com/QwenLM/Qwen/blob/main/Tongyi%20Qianwen%20LICENSE%20AGREEMENT) / [StarCoder 2](https://huggingface.co/spaces/bigcode/bigcode-model-license-agreement) / [XVERSE](https://github.com/xverse-ai/XVERSE-13B/blob/main/MODEL_LICENSE.pdf) / [Yi](https://huggingface.co/01-ai/Yi-6B/blob/main/LICENSE) / [Yi-1.5](LICENSE) / [Yuan 2](https://github.com/IEIT-Yuan/Yuan-2.0/blob/main/LICENSE-Yuan)
Please follow the model licenses to use the corresponding model weights: [Baichuan 2](https://huggingface.co/baichuan-inc/Baichuan2-7B-Base/blob/main/Community%20License%20for%20Baichuan%202%20Model.pdf) / [BLOOM](https://huggingface.co/spaces/bigscience/license) / [ChatGLM3](https://github.com/THUDM/ChatGLM3/blob/main/MODEL_LICENSE) / [Command R](https://cohere.com/c4ai-cc-by-nc-license) / [DeepSeek](https://github.com/deepseek-ai/DeepSeek-LLM/blob/main/LICENSE-MODEL) / [Falcon](https://huggingface.co/tiiuae/falcon-180B/blob/main/LICENSE.txt) / [Gemma](https://ai.google.dev/gemma/terms) / [GLM-4](https://huggingface.co/THUDM/glm-4-9b/blob/main/LICENSE) / [GPT-2](https://github.com/openai/gpt-2/blob/master/LICENSE) / [Granite](LICENSE) / [Index](https://huggingface.co/IndexTeam/Index-1.9B/blob/main/LICENSE) / [InternLM](https://github.com/InternLM/InternLM#license) / [Llama](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) / [Llama 2 (LLaVA-1.5)](https://ai.meta.com/llama/license/) / [Llama 3](https://llama.meta.com/llama3/license/) / [MiniCPM](https://github.com/OpenBMB/MiniCPM/blob/main/MiniCPM%20Model%20License.md) / [Mistral/Mixtral/Pixtral](LICENSE) / [OLMo](LICENSE) / [Phi-1.5/Phi-2](https://huggingface.co/microsoft/phi-1_5/resolve/main/Research%20License.docx) / [Phi-3/Phi-4](https://huggingface.co/microsoft/Phi-3-mini-4k-instruct/blob/main/LICENSE) / [Qwen](https://github.com/QwenLM/Qwen/blob/main/Tongyi%20Qianwen%20LICENSE%20AGREEMENT) / [Skywork](https://huggingface.co/Skywork/Skywork-13B-base/blob/main/Skywork%20Community%20License.pdf) / [StarCoder 2](https://huggingface.co/spaces/bigcode/bigcode-model-license-agreement) / [TeleChat2](https://huggingface.co/Tele-AI/telechat-7B/blob/main/TeleChat%E6%A8%A1%E5%9E%8B%E7%A4%BE%E5%8C%BA%E8%AE%B8%E5%8F%AF%E5%8D%8F%E8%AE%AE.pdf) / [XVERSE](https://github.com/xverse-ai/XVERSE-13B/blob/main/MODEL_LICENSE.pdf) / [Yi](https://huggingface.co/01-ai/Yi-6B/blob/main/LICENSE) / [Yi-1.5](LICENSE) / [Yuan 2](https://github.com/IEIT-Yuan/Yuan-2.0/blob/main/LICENSE-Yuan)
## Citation
@@ -1,20 +1,32 @@
![# LLaMA Factory](assets/logo.png)
[![GitHub Repo stars](https://img.shields.io/github/stars/hiyouga/LLaMA-Factory?style=social)](https://github.com/hiyouga/LLaMA-Factory/stargazers)
[![GitHub Code License](https://img.shields.io/github/license/hiyouga/LLaMA-Factory)](LICENSE)
[![GitHub last commit](https://img.shields.io/github/last-commit/hiyouga/LLaMA-Factory)](https://github.com/hiyouga/LLaMA-Factory/commits/main)
[![GitHub contributors](https://img.shields.io/github/contributors/hiyouga/LLaMA-Factory?color=orange)](https://github.com/hiyouga/LLaMA-Factory/graphs/contributors)
[![GitHub workflow](https://github.com/hiyouga/LLaMA-Factory/actions/workflows/tests.yml/badge.svg)](https://github.com/hiyouga/LLaMA-Factory/actions/workflows/tests.yml)
[![PyPI](https://img.shields.io/pypi/v/llamafactory)](https://pypi.org/project/llamafactory/)
[![Citation](https://img.shields.io/badge/citation-93-green)](#使用了-llama-factory-的项目)
[![Citation](https://img.shields.io/badge/citation-349-green)](https://scholar.google.com/scholar?cites=12620864006390196564)
[![GitHub pull request](https://img.shields.io/badge/PRs-welcome-blue)](https://github.com/hiyouga/LLaMA-Factory/pulls)
[![Discord](https://dcbadge.vercel.app/api/server/rKfvV9r9FK?compact=true&style=flat)](https://discord.gg/rKfvV9r9FK)
[![Twitter](https://img.shields.io/twitter/follow/llamafactory_ai)](https://twitter.com/llamafactory_ai)
[![GitCode](https://gitcode.com/zhengyaowei/LLaMA-Factory/star/badge.svg)](https://gitcode.com/zhengyaowei/LLaMA-Factory)
[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1d5KQtbemerlSDSxZIfAaWXhKr30QypiK?usp=sharing)
[![Open in DSW](https://gallery.pai-ml.com/assets/open-in-dsw.svg)](https://gallery.pai-ml.com/#/preview/deepLearning/nlp/llama_factory)
[![Spaces](https://img.shields.io/badge/🤗-Open%20in%20Spaces-blue)](https://huggingface.co/spaces/hiyouga/LLaMA-Board)
[![Studios](https://img.shields.io/badge/ModelScope-Open%20in%20Studios-blue)](https://modelscope.cn/studios/hiyouga/LLaMA-Board)
[![SageMaker](https://img.shields.io/badge/SageMaker-Open%20in%20AWS-blue)](https://aws.amazon.com/cn/blogs/china/a-one-stop-code-free-model-fine-tuning-deployment-platform-based-on-sagemaker-and-llama-factory/)
[![GitHub Trend](https://trendshift.io/api/badge/repositories/4535)](https://trendshift.io/repositories/4535)
<h3 align="center">
使用零代码<a href="#快速开始">命令行</a>与 <a href="#llama-board-可视化微调由-gradio-驱动">Web UI</a> 轻松微调百余种大模型
</h3>
<p align="center">
<picture>
<img alt="Github trend" src="https://trendshift.io/api/badge/repositories/4535">
</picture>
</p>
👋 加入我们的[微信群](assets/wechat.jpg)或 [NPU 用户群](assets/wechat_npu.jpg)。
@@ -28,15 +40,11 @@ https://github.com/user-attachments/assets/e6ce34b0-52d5-4f3e-a830-592106c4c272
- **入门教程**:https://zhuanlan.zhihu.com/p/695287607
- **框架文档**:https://llamafactory.readthedocs.io/zh-cn/latest/
- **Colab**:https://colab.research.google.com/drive/1d5KQtbemerlSDSxZIfAaWXhKr30QypiK?usp=sharing
- **Colab(免费)**:https://colab.research.google.com/drive/1d5KQtbemerlSDSxZIfAaWXhKr30QypiK?usp=sharing
- **本地机器**:请见[如何使用](#如何使用)
- **PAI-DSW**:[Llama3 案例](https://gallery.pai-ml.com/#/preview/deepLearning/nlp/llama_factory) | [Qwen2-VL 案例](https://gallery.pai-ml.com/#/preview/deepLearning/nlp/llama_factory_qwen2vl)
- **PAI-DSW(免费试用)**:[Llama3 案例](https://gallery.pai-ml.com/#/preview/deepLearning/nlp/llama_factory) | [Qwen2-VL 案例](https://gallery.pai-ml.com/#/preview/deepLearning/nlp/llama_factory_qwen2vl) | [DeepSeek-R1-Distill 案例](https://gallery.pai-ml.com/#/preview/deepLearning/nlp/llama_factory_deepseek_r1_distill_7b)
- **Amazon SageMaker**:[博客](https://aws.amazon.com/cn/blogs/china/a-one-stop-code-free-model-fine-tuning-deployment-platform-based-on-sagemaker-and-llama-factory/)
近期活动:
- **2024/10/18-2024/11/30**:使用 PAI+LLaMA Factory 构建个性化导游机器人。[[活动页面]](https://developer.aliyun.com/topic/llamafactory2)
> [!NOTE]
> 除上述链接以外的其他网站均为未经许可的第三方网站,请小心甄别。
@@ -50,6 +58,16 @@ https://github.com/user-attachments/assets/e6ce34b0-52d5-4f3e-a830-592106c4c272
- [数据集](#数据集)
- [软硬件依赖](#软硬件依赖)
- [如何使用](#如何使用)
- [安装 LLaMA Factory](#安装-llama-factory)
- [数据准备](#数据准备)
- [快速开始](#快速开始)
- [LLaMA Board 可视化微调](#llama-board-可视化微调由-gradio-驱动)
- [构建 Docker](#构建-docker)
- [利用 vLLM 部署 OpenAI API](#利用-vllm-部署-openai-api)
- [从魔搭社区下载](#从魔搭社区下载)
- [从魔乐社区下载](#从魔乐社区下载)
- [使用 W&B 面板](#使用-wb-面板)
- [使用 SwanLab 面板](#使用-swanlab-面板)
- [使用了 LLaMA Factory 的项目](#使用了-llama-factory-的项目)
- [协议](#协议)
- [引用](#引用)
@@ -57,14 +75,22 @@ https://github.com/user-attachments/assets/e6ce34b0-52d5-4f3e-a830-592106c4c272
## 项目特色
- **多种模型**:LLaMA、LLaVA、Mistral、Mixtral-MoE、Qwen、Qwen2-VL、Yi、Gemma、Baichuan、ChatGLM、Phi 等等。
- **多种模型**:LLaMA、LLaVA、Mistral、Mixtral-MoE、Qwen、Qwen2-VL、DeepSeek、Yi、Gemma、ChatGLM、Phi 等等。
- **集成方法**:增量预训练、多模态指令监督微调、奖励模型训练、PPO 训练、DPO 训练、KTO 训练、ORPO 训练等等。
- **多种精度**:16 比特全参数微调、冻结微调、LoRA 微调和基于 AQLM/AWQ/GPTQ/LLM.int8/HQQ/EETQ 的 2/3/4/5/6/8 比特 QLoRA 微调。
- **先进算法**:[GaLore](https://github.com/jiaweizzhao/GaLore)、[BAdam](https://github.com/Ledzy/BAdam)、[Adam-mini](https://github.com/zyushun/Adam-mini)、DoRA、LongLoRA、LLaMA Pro、Mixture-of-Depths、LoRA+、LoftQ、PiSSA 和 Agent 微调。
- **先进算法**:[GaLore](https://github.com/jiaweizzhao/GaLore)、[BAdam](https://github.com/Ledzy/BAdam)、[APOLLO](https://github.com/zhuhanqing/APOLLO)、[Adam-mini](https://github.com/zyushun/Adam-mini)、DoRA、LongLoRA、LLaMA Pro、Mixture-of-Depths、LoRA+、LoftQ 和 PiSSA。
- **实用技巧**:[FlashAttention-2](https://github.com/Dao-AILab/flash-attention)、[Unsloth](https://github.com/unslothai/unsloth)、[Liger Kernel](https://github.com/linkedin/Liger-Kernel)、RoPE scaling、NEFTune 和 rsLoRA。
- **实验监控**:LlamaBoard、TensorBoard、Wandb、MLflow 等等。
- **广泛任务**:多轮对话、工具调用、图像理解、视觉定位、视频识别和语音理解等等。
- **实验监控**:LlamaBoard、TensorBoard、Wandb、MLflow、SwanLab 等等。
- **极速推理**:基于 vLLM 的 OpenAI 风格 API、浏览器界面和命令行接口。
### 最新模型的 Day-N 微调适配
| 适配时间 | 模型名称 |
| ------------ | ---------------------------------------------------------- |
| Day 0 | Qwen2.5 / Qwen2-VL / QwQ / QVQ / InternLM3 / MiniCPM-o-2.6 |
| Day 1 | Llama 3 / GLM-4 / Mistral Small / PaliGemma2 |
## 性能指标
与 ChatGLM 官方的 [P-Tuning](https://github.com/THUDM/ChatGLM2-6B/tree/main/ptuning) 微调相比,LLaMA Factory 的 LoRA 微调提供了 **3.7 倍**的加速比,同时在广告文案生成任务上取得了更高的 Rouge 分数。结合 4 比特量化技术,LLaMA Factory 的 QLoRA 微调进一步降低了 GPU 显存消耗。
@@ -82,6 +108,28 @@ https://github.com/user-attachments/assets/e6ce34b0-52d5-4f3e-a830-592106c4c272
## 更新日志
[25/02/24] 我们宣布开源 **[EasyR1](https://github.com/hiyouga/EasyR1)**,一个高效可扩展的多模态强化学习框架,支持高效的 GRPO 训练。
[25/02/11] 我们支持了在导出模型时保存 **[Ollama](https://github.com/ollama/ollama)** 配置文件。详细用法请参照 [examples](examples/README_zh.md)。
[25/02/05] 我们支持了在语音理解任务上微调 **[Qwen2-Audio](https://huggingface.co/Qwen/Qwen2-Audio-7B-Instruct)** 和 **[MiniCPM-o-2.6](https://huggingface.co/openbmb/MiniCPM-o-2_6)** 模型。
[25/01/31] 我们支持了 **[DeepSeek-R1](https://huggingface.co/deepseek-ai/DeepSeek-R1)** 和 **[Qwen2.5-VL](https://huggingface.co/Qwen/Qwen2.5-VL-7B-Instruct)** 模型的微调。
<details><summary>展开日志</summary>
[25/01/15] 我们支持了 **[APOLLO](https://arxiv.org/abs/2412.05270)** 优化器。详细用法请参照 [examples](examples/README_zh.md)。
[25/01/14] 我们支持了 **[MiniCPM-o-2.6](https://huggingface.co/openbmb/MiniCPM-o-2_6)** 和 **[MiniCPM-V-2.6](https://huggingface.co/openbmb/MiniCPM-V-2_6)** 模型的微调。 感谢 [@BUAADreamer](https://github.com/BUAADreamer) 的 PR.
[25/01/14] 我们支持了 **[InternLM3](https://huggingface.co/collections/internlm/)** 模型的微调。感谢 [@hhaAndroid](https://github.com/hhaAndroid) 的 PR。
[25/01/10] 我们支持了 **[Phi-4](https://huggingface.co/microsoft/phi-4)** 模型的微调。
[24/12/21] 我们支持了使用 **[SwanLab](https://github.com/SwanHubX/SwanLab)** 跟踪与可视化实验。详细用法请参考 [此部分](#使用-swanlab-面板)。
[24/11/27] 我们支持了 **[Skywork-o1](https://huggingface.co/Skywork/Skywork-o1-Open-Llama-3.1-8B)** 模型的微调和 **[OpenO1](https://huggingface.co/datasets/O1-OPEN/OpenO1-SFT)** 数据集。
[24/10/09] 我们支持了从 **[魔乐社区](https://modelers.cn/models)** 下载预训练模型和数据集。详细用法请参照 [此教程](#从魔乐社区下载)。
[24/09/19] 我们支持了 **[Qwen2.5](https://qwenlm.github.io/blog/qwen2.5/)** 模型的微调。
@@ -92,8 +140,6 @@ https://github.com/user-attachments/assets/e6ce34b0-52d5-4f3e-a830-592106c4c272
[24/08/09] 我们支持了 **[Adam-mini](https://github.com/zyushun/Adam-mini)** 优化器。详细用法请参照 [examples](examples/README_zh.md)。感谢 [@relic-yuexi](https://github.com/relic-yuexi) 的 PR。
<details><summary>展开日志</summary>
[24/07/04] 我们支持了[无污染打包训练](https://github.com/MeetKai/functionary/tree/main/functionary/train/packing)。请使用 `neat_packing: true` 参数。感谢 [@chuan298](https://github.com/chuan298) 的 PR。
[24/06/16] 我们支持了 **[PiSSA](https://arxiv.org/abs/2404.02948)** 算法。详细用法请参照 [examples](examples/README_zh.md)。
@@ -174,39 +220,51 @@ https://github.com/user-attachments/assets/e6ce34b0-52d5-4f3e-a830-592106c4c272
## 模型
| 模型名 | 模型大小 | Template |
| ----------------------------------------------------------------- | -------------------------------- | ---------------- |
| [Baichuan 2](https://huggingface.co/baichuan-inc) | 7B/13B | baichuan2 |
| [BLOOM/BLOOMZ](https://huggingface.co/bigscience) | 560M/1.1B/1.7B/3B/7.1B/176B | - |
| [ChatGLM3](https://huggingface.co/THUDM) | 6B | chatglm3 |
| [Command R](https://huggingface.co/CohereForAI) | 35B/104B | cohere |
| [DeepSeek (Code/MoE)](https://huggingface.co/deepseek-ai) | 7B/16B/67B/236B | deepseek |
| [Falcon](https://huggingface.co/tiiuae) | 7B/11B/40B/180B | falcon |
| [Gemma/Gemma 2/CodeGemma](https://huggingface.co/google) | 2B/7B/9B/27B | gemma |
| [GLM-4](https://huggingface.co/THUDM) | 9B | glm4 |
| [Index](https://huggingface.co/IndexTeam) | 1.9B | index |
| [InternLM2/InternLM2.5](https://huggingface.co/internlm) | 7B/20B | intern2 |
| [Llama](https://github.com/facebookresearch/llama) | 7B/13B/33B/65B | - |
| [Llama 2](https://huggingface.co/meta-llama) | 7B/13B/70B | llama2 |
| [Llama 3-3.2](https://huggingface.co/meta-llama) | 1B/3B/8B/70B | llama3 |
| [Llama 3.2 Vision](https://huggingface.co/meta-llama) | 11B/90B | mllama |
| [LLaVA-1.5](https://huggingface.co/llava-hf) | 7B/13B | llava |
| [LLaVA-NeXT](https://huggingface.co/llava-hf) | 7B/8B/13B/34B/72B/110B | llava_next |
| [LLaVA-NeXT-Video](https://huggingface.co/llava-hf) | 7B/34B | llava_next_video |
| [MiniCPM](https://huggingface.co/openbmb) | 1B/2B/4B | cpm/cpm3 |
| [Mistral/Mixtral](https://huggingface.co/mistralai) | 7B/8x7B/8x22B | mistral |
| [OLMo](https://huggingface.co/allenai) | 1B/7B | - |
| [PaliGemma](https://huggingface.co/google) | 3B | paligemma |
| [Phi-1.5/Phi-2](https://huggingface.co/microsoft) | 1.3B/2.7B | - |
| [Phi-3](https://huggingface.co/microsoft) | 4B/7B/14B | phi |
| [Pixtral](https://huggingface.co/mistralai) | 12B | pixtral |
| [Qwen (1-2.5) (Code/Math/MoE)](https://huggingface.co/Qwen) | 0.5B/1.5B/3B/7B/14B/32B/72B/110B | qwen |
| [Qwen2-VL](https://huggingface.co/Qwen) | 2B/7B/72B | qwen2_vl |
| [StarCoder 2](https://huggingface.co/bigcode) | 3B/7B/15B | - |
| [XVERSE](https://huggingface.co/xverse) | 7B/13B/65B | xverse |
| [Yi/Yi-1.5 (Code)](https://huggingface.co/01-ai) | 1.5B/6B/9B/34B | yi |
| [Yi-VL](https://huggingface.co/01-ai) | 6B/34B | yi_vl |
| [Yuan 2](https://huggingface.co/IEITYuan) | 2B/51B/102B | yuan |
| 模型名 | 参数量 | Template |
| ----------------------------------------------------------------- | -------------------------------- | ------------------- |
| [Baichuan 2](https://huggingface.co/baichuan-inc) | 7B/13B | baichuan2 |
| [BLOOM/BLOOMZ](https://huggingface.co/bigscience) | 560M/1.1B/1.7B/3B/7.1B/176B | - |
| [ChatGLM3](https://huggingface.co/THUDM) | 6B | chatglm3 |
| [Command R](https://huggingface.co/CohereForAI) | 35B/104B | cohere |
| [DeepSeek (Code/MoE)](https://huggingface.co/deepseek-ai) | 7B/16B/67B/236B | deepseek |
| [DeepSeek 2.5/3](https://huggingface.co/deepseek-ai) | 236B/671B | deepseek3 |
| [DeepSeek R1 (Distill)](https://huggingface.co/deepseek-ai) | 1.5B/7B/8B/14B/32B/70B/671B | deepseek3 |
| [Falcon](https://huggingface.co/tiiuae) | 7B/11B/40B/180B | falcon |
| [Gemma/Gemma 2/CodeGemma](https://huggingface.co/google) | 2B/7B/9B/27B | gemma |
| [GLM-4](https://huggingface.co/THUDM) | 9B | glm4 |
| [GPT-2](https://huggingface.co/openai-community) | 0.1B/0.4B/0.8B/1.5B | - |
| [Granite 3.0-3.1](https://huggingface.co/ibm-granite) | 1B/2B/3B/8B | granite3 |
| [Index](https://huggingface.co/IndexTeam) | 1.9B | index |
| [InternLM 2-3](https://huggingface.co/internlm) | 7B/8B/20B | intern2 |
| [Llama](https://github.com/facebookresearch/llama) | 7B/13B/33B/65B | - |
| [Llama 2](https://huggingface.co/meta-llama) | 7B/13B/70B | llama2 |
| [Llama 3-3.3](https://huggingface.co/meta-llama) | 1B/3B/8B/70B | llama3 |
| [Llama 3.2 Vision](https://huggingface.co/meta-llama) | 11B/90B | mllama |
| [LLaVA-1.5](https://huggingface.co/llava-hf) | 7B/13B | llava |
| [LLaVA-NeXT](https://huggingface.co/llava-hf) | 7B/8B/13B/34B/72B/110B | llava_next |
| [LLaVA-NeXT-Video](https://huggingface.co/llava-hf) | 7B/34B | llava_next_video |
| [MiniCPM](https://huggingface.co/openbmb) | 1B/2B/4B | cpm/cpm3 |
| [MiniCPM-o-2.6/MiniCPM-V-2.6](https://huggingface.co/openbmb) | 8B | minicpm_o/minicpm_v |
| [Ministral/Mistral-Nemo](https://huggingface.co/mistralai) | 8B/12B | ministral |
| [Mistral/Mixtral](https://huggingface.co/mistralai) | 7B/8x7B/8x22B | mistral |
| [Mistral Small](https://huggingface.co/mistralai) | 24B | mistral_small |
| [OLMo](https://huggingface.co/allenai) | 1B/7B | - |
| [PaliGemma/PaliGemma2](https://huggingface.co/google) | 3B/10B/28B | paligemma |
| [Phi-1.5/Phi-2](https://huggingface.co/microsoft) | 1.3B/2.7B | - |
| [Phi-3/Phi-3.5](https://huggingface.co/microsoft) | 4B/14B | phi |
| [Phi-3-small](https://huggingface.co/microsoft) | 7B | phi_small |
| [Phi-4](https://huggingface.co/microsoft) | 14B | phi4 |
| [Pixtral](https://huggingface.co/mistralai) | 12B | pixtral |
| [Qwen/QwQ (1-2.5) (Code/Math/MoE)](https://huggingface.co/Qwen) | 0.5B/1.5B/3B/7B/14B/32B/72B/110B | qwen |
| [Qwen2-Audio](https://huggingface.co/Qwen) | 7B | qwen2_audio |
| [Qwen2-VL/Qwen2.5-VL/QVQ](https://huggingface.co/Qwen) | 2B/3B/7B/72B | qwen2_vl |
| [Skywork o1](https://huggingface.co/Skywork) | 8B | skywork_o1 |
| [StarCoder 2](https://huggingface.co/bigcode) | 3B/7B/15B | - |
| [TeleChat2](https://huggingface.co/Tele-AI) | 3B/7B/35B/115B | telechat2 |
| [XVERSE](https://huggingface.co/xverse) | 7B/13B/65B | xverse |
| [Yi/Yi-1.5 (Code)](https://huggingface.co/01-ai) | 1.5B/6B/9B/34B | yi |
| [Yi-VL](https://huggingface.co/01-ai) | 6B/34B | yi_vl |
| [Yuan 2](https://huggingface.co/IEITYuan) | 2B/51B/102B | yuan |
> [!NOTE]
> 对于所有“基座”(Base)模型,`template` 参数可以是 `default`, `alpaca`, `vicuna` 等任意值。但“对话”(Instruct/Chat)模型请务必使用**对应的模板**。
@@ -220,7 +278,7 @@ https://github.com/user-attachments/assets/e6ce34b0-52d5-4f3e-a830-592106c4c272
## 训练方法
| 方法 | 全参数训练 | 部分参数训练 | LoRA | QLoRA |
| ---------------------- | ------------------ | ------------------ | ------------------ | ------------------ |
| --------------------- | ------------------ | ------------------ | ------------------ | ------------------ |
| 预训练 | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
| 指令监督微调 | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
| 奖励模型训练 | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
@@ -290,9 +348,13 @@ https://github.com/user-attachments/assets/e6ce34b0-52d5-4f3e-a830-592106c4c272
- [STEM (zh)](https://huggingface.co/datasets/hfl/stem_zh_instruction)
- [Ruozhiba (zh)](https://huggingface.co/datasets/hfl/ruozhiba_gpt4_turbo)
- [Neo-sft (zh)](https://huggingface.co/datasets/m-a-p/neo_sft_phase2)
- [WebInstructSub (en)](https://huggingface.co/datasets/TIGER-Lab/WebInstructSub)
- [Magpie-Pro-300K-Filtered (en)](https://huggingface.co/datasets/Magpie-Align/Magpie-Pro-300K-Filtered)
- [Magpie-ultra-v0.1 (en)](https://huggingface.co/datasets/argilla/magpie-ultra-v0.1)
- [WebInstructSub (en)](https://huggingface.co/datasets/TIGER-Lab/WebInstructSub)
- [OpenO1-SFT (en&zh)](https://huggingface.co/datasets/O1-OPEN/OpenO1-SFT)
- [Open-Thoughts (en)](https://huggingface.co/datasets/open-thoughts/OpenThoughts-114k)
- [Open-R1-Math (en)](https://huggingface.co/datasets/open-r1/OpenR1-Math-220k)
- [Chinese-DeepSeek-R1-Distill (zh)](https://huggingface.co/datasets/Congliu/Chinese-DeepSeek-R1-Distill-data-110k-SFT)
- [LLaVA mixed (en&zh)](https://huggingface.co/datasets/BUAADreamer/llava-en-zh-300k)
- [Pokemon-gpt4o-captions (en&zh)](https://huggingface.co/datasets/jugg1024/pokemon-gpt4o-captions)
- [Open Assistant (de)](https://huggingface.co/datasets/mayflowergmbh/oasst_de)
@@ -332,35 +394,34 @@ huggingface-cli login
| 必需项 | 至少 | 推荐 |
| ------------ | ------- | --------- |
| python | 3.8 | 3.11 |
| torch | 1.13.1 | 2.4.0 |
| transformers | 4.41.2 | 4.43.4 |
| datasets | 2.16.0 | 2.20.0 |
| accelerate | 0.30.1 | 0.32.0 |
| python | 3.9 | 3.10 |
| torch | 1.13.1 | 2.5.1 |
| transformers | 4.41.2 | 4.49.0 |
| datasets | 2.16.0 | 3.2.0 |
| accelerate | 0.34.0 | 1.2.1 |
| peft | 0.11.1 | 0.12.0 |
| trl | 0.8.6 | 0.9.6 |
| 可选项 | 至少 | 推荐 |
| ------------ | ------- | --------- |
| CUDA | 11.6 | 12.2 |
| deepspeed | 0.10.0 | 0.14.0 |
| deepspeed | 0.10.0 | 0.16.4 |
| bitsandbytes | 0.39.0 | 0.43.1 |
| vllm | 0.4.3 | 0.5.0 |
| flash-attn | 2.3.0 | 2.6.3 |
| vllm | 0.4.3 | 0.7.3 |
| flash-attn | 2.3.0 | 2.7.2 |
### 硬件依赖
\* *估算值*
| 方法 | 精度 | 7B | 13B | 30B | 70B | 110B | 8x7B | 8x22B |
| ----------------- | ---- | ----- | ----- | ----- | ------ | ------ | ----- | ------ |
| Full | AMP | 120GB | 240GB | 600GB | 1200GB | 2000GB | 900GB | 2400GB |
| Full | 16 | 60GB | 120GB | 300GB | 600GB | 900GB | 400GB | 1200GB |
| Freeze | 16 | 20GB | 40GB | 80GB | 200GB | 360GB | 160GB | 400GB |
| LoRA/GaLore/BAdam | 16 | 16GB | 32GB | 64GB | 160GB | 240GB | 120GB | 320GB |
| QLoRA | 8 | 10GB | 20GB | 40GB | 80GB | 140GB | 60GB | 160GB |
| QLoRA | 4 | 6GB | 12GB | 24GB | 48GB | 72GB | 30GB | 96GB |
| QLoRA | 2 | 4GB | 8GB | 16GB | 24GB | 48GB | 18GB | 48GB |
| 方法 | 精度 | 7B | 14B | 30B | 70B | `x`B |
| ------------------------------- | ---- | ----- | ----- | ----- | ------ | ------- |
| Full (`bf16` or `fp16`) | 32 | 120GB | 240GB | 600GB | 1200GB | `18x`GB |
| Full (`pure_bf16`) | 16 | 60GB | 120GB | 300GB | 600GB | `8x`GB |
| Freeze/LoRA/GaLore/APOLLO/BAdam | 16 | 16GB | 32GB | 64GB | 160GB | `2x`GB |
| QLoRA | 8 | 10GB | 20GB | 40GB | 80GB | `x`GB |
| QLoRA | 4 | 6GB | 12GB | 24GB | 48GB | `x/2`GB |
| QLoRA | 2 | 4GB | 8GB | 16GB | 24GB | `x/4`GB |
## 如何使用
@@ -375,26 +436,47 @@ cd LLaMA-Factory
pip install -e ".[torch,metrics]"
```
可选的额外依赖项:torch、torch-npu、metrics、deepspeed、liger-kernel、bitsandbytes、hqq、eetq、gptq、awq、aqlm、vllm、galore、badam、adam-mini、qwen、modelscope、openmind、quality
可选的额外依赖项:torch、torch-npu、metrics、deepspeed、liger-kernel、bitsandbytes、hqq、eetq、gptq、awq、aqlm、vllm、galore、apollo、badam、adam-mini、qwen、minicpm_v、modelscope、openmind、swanlab、quality
> [!TIP]
> 遇到包冲突时,可使用 `pip install --no-deps -e .` 解决。
<details><summary>使用 <b>uv</b> 构建虚拟环境</summary>
使用 [uv](https://github.com/astral-sh/uv) 创建隔离的 Python 环境:
```bash
uv sync --extra torch --extra metrics --prerelease=allow
```
在环境中运行 LLaMA-Factory:
```bash
uv run --prerelease=allow llamafactory-cli train examples/train_lora/llama3_lora_pretrain.yaml
```
</details>
<details><summary>Windows 用户指南</summary>
#### 安装 BitsAndBytes
如果要在 Windows 平台上开启量化 LoRA(QLoRA),需要安装预编译的 `bitsandbytes` 库,支持 CUDA 11.1 到 12.2,请根据您的 CUDA 版本情况选择适合的[发布版本](https://github.com/jllllll/bitsandbytes-windows-webui/releases/tag/wheels)。
```bash
pip install https://github.com/jllllll/bitsandbytes-windows-webui/releases/download/wheels/bitsandbytes-0.41.2.post2-py3-none-win_amd64.whl
```
如果要在 Windows 平台上开启 FlashAttention-2,需要安装预编译的 `flash-attn` 库,支持 CUDA 12.1 到 12.2,请根据需求到 [flash-attention](https://github.com/bdashore3/flash-attention/releases) 下载对应版本安装。
#### 安装 Flash Attention-2
如果要在 Windows 平台上开启 FlashAttention-2,请使用 [flash-attention-windows-wheel](https://huggingface.co/lldacing/flash-attention-windows-wheel) 中的脚本自行编译与安装。
</details>
<details><summary>昇腾 NPU 用户指南</summary>
在昇腾 NPU 设备上安装 LLaMA Factory 时,需要指定额外依赖项,使用 `pip install -e ".[torch-npu,metrics]"` 命令安装。此外,还需要安装 **[Ascend CANN Toolkit 与 Kernels](https://www.hiascend.com/developer/download/community/result?module=cann)**,安装方法请参考[安装教程](https://www.hiascend.com/document/detail/zh/CANNCommunityEdition/80RC2alpha002/quickstart/quickstart/quickstart_18_0004.html)或使用以下命令:
在昇腾 NPU 设备上安装 LLaMA Factory 时,请升级 Python 到 3.10 及以上,并需要指定额外依赖项,使用 `pip install -e ".[torch-npu,metrics]"` 命令安装。此外,还需要安装 **[Ascend CANN Toolkit 与 Kernels](https://www.hiascend.com/developer/download/community/result?module=cann)**,安装方法请参考[安装教程](https://www.hiascend.com/document/detail/zh/CANNCommunityEdition/80RC2alpha002/quickstart/quickstart/quickstart_18_0004.html)或使用以下命令:
```bash
# 请替换 URL 为 CANN 版本和设备型号对应的 URL
@@ -410,12 +492,12 @@ bash Ascend-cann-kernels-910b_8.0.RC1.alpha001_linux.run --install
source /usr/local/Ascend/ascend-toolkit/set_env.sh
```
| 依赖项 | 至少 | 推荐 |
| ------------ | ------- | ----------- |
| CANN | 8.0.RC1 | 8.0.RC1 |
| torch | 2.1.0 | 2.1.0 |
| torch-npu | 2.1.0 | 2.1.0.post3 |
| deepspeed | 0.13.2 | 0.13.2 |
| 依赖项 | 至少 | 推荐 |
| ------------ | ------- | -------------- |
| CANN | 8.0.RC1 | 8.0.0.alpha002 |
| torch | 2.1.0 | 2.4.0 |
| torch-npu | 2.1.0 | 2.4.0.post2 |
| deepspeed | 0.13.2 | 0.13.2 |
请使用 `ASCEND_RT_VISIBLE_DEVICES` 而非 `CUDA_VISIBLE_DEVICES` 来指定运算设备。
@@ -423,6 +505,40 @@ source /usr/local/Ascend/ascend-toolkit/set_env.sh
下载预构建 Docker 镜像:[32GB](http://mirrors.cn-central-221.ovaijisuan.com/detail/130.html) | [64GB](http://mirrors.cn-central-221.ovaijisuan.com/detail/131.html)
#### 安装 BitsAndBytes
如果要在 Ascend NPU 上进行基于 bitsandbytes 的 QLoRA 量化微调,请执行如下步骤:
1. 手动编译 bitsandbytes:请参考[安装文档](https://huggingface.co/docs/bitsandbytes/installation?backend=Ascend+NPU&platform=Ascend+NPU)完成 NPU 版的 bitsandbytes 安装,编译要求环境 cmake 版本不低于 3.22.1,g++ 版本不低于 12.x。
```bash
# 从源码安装 bitsandbytes
# 克隆 bitsandbytes 仓库, Ascend NPU 目前在 multi-backend-refactor 中支持
git clone -b multi-backend-refactor https://github.com/bitsandbytes-foundation/bitsandbytes.git
cd bitsandbytes/
# 安装依赖
pip install -r requirements-dev.txt
# 安装编译工具依赖,该步骤在不同系统上命令有所不同,供参考
apt-get install -y build-essential cmake
# 编译 & 安装
cmake -DCOMPUTE_BACKEND=npu -S .
make
pip install .
```
2. 安装 transformers 的 main 分支版本。
```bash
git clone -b main https://github.com/huggingface/transformers.git
cd transformers
pip install .
```
3. 在训练参数中设置 `double_quantization: false`,可参考[示例](examples/train_qlora/llama3_lora_sft_bnb_npu.yaml)。
</details>
### 数据准备
@@ -446,6 +562,8 @@ llamafactory-cli export examples/merge_lora/llama3_lora_sft.yaml
> [!TIP]
> 使用 `llamafactory-cli help` 显示帮助信息。
>
> 遇到报错请先看[常见问题](https://github.com/hiyouga/LLaMA-Factory/issues/4614)。
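作为参考,上文的快速开始流程大致对应以下三条命令(使用仓库自带的 Llama3 LoRA 示例配置,仅供示意):

```bash
# 使用示例配置依次进行 LoRA 微调、推理对话和合并导出
llamafactory-cli train examples/train_lora/llama3_lora_sft.yaml
llamafactory-cli chat examples/inference/llama3_lora_sft.yaml
llamafactory-cli export examples/merge_lora/llama3_lora_sft.yaml
```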
### LLaMA Board 可视化微调(由 [Gradio](https://github.com/gradio-app/gradio) 驱动)
@@ -590,7 +708,7 @@ API_PORT=8000 llamafactory-cli api examples/inference/llama3_vllm.yaml
> [!TIP]
> API 文档请查阅[这里](https://platform.openai.com/docs/api-reference/chat/create)。
>
> 示例:[图像理解](scripts/test_image.py) | [工具调用](scripts/test_toolcall.py)
> 示例:[图像理解](scripts/api_example/test_image.py) | [工具调用](scripts/api_example/test_toolcall.py)
### 从魔搭社区下载
@@ -623,6 +741,21 @@ run_name: test_run # 可选
在启动训练任务时,将 `WANDB_API_KEY` 设置为[密钥](https://wandb.ai/authorize)来登录 W&B 账户。
### 使用 SwanLab 面板
若要使用 [SwanLab](https://github.com/SwanHubX/SwanLab) 记录实验数据,请在 yaml 文件中添加下面的参数。
```yaml
use_swanlab: true
swanlab_run_name: test_run # 可选
```
在启动训练任务时,登录 SwanLab 账户有以下三种方式(示例命令见列表之后):
方式一:在 yaml 文件中添加 `swanlab_api_key=<your_api_key>` ,并设置为你的 [API 密钥](https://swanlab.cn/settings)。
方式二:将环境变量 `SWANLAB_API_KEY` 设置为你的 [API 密钥](https://swanlab.cn/settings)。
方式三:启动前使用 `swanlab login` 命令完成登录。
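以方式二、方式三为例的最小示例(API 密钥与配置路径均为占位符):

```bash
# 方式二:通过环境变量传入 SwanLab API 密钥并启动训练
SWANLAB_API_KEY=your_api_key llamafactory-cli train examples/train_lora/llama3_lora_sft.yaml

# 方式三:也可以先使用 SwanLab 命令行完成登录
swanlab login
```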
## 使用了 LLaMA Factory 的项目
如果您有项目希望添加至下述列表,请通过邮件联系或者创建一个 PR。
@@ -722,6 +855,8 @@ run_name: test_run # 可选
1. **[NVIDIA RTX AI Toolkit](https://github.com/NVIDIA/RTX-AI-Toolkit)**:在 Windows 主机上利用英伟达 RTX 设备进行大型语言模型微调的开发包。
1. **[LazyLLM](https://github.com/LazyAGI/LazyLLM)**:一个低代码构建多 Agent 大模型应用的开发工具,支持基于 LLaMA Factory 的模型微调。
1. **[RAG-Retrieval](https://github.com/NLPJCL/RAG-Retrieval)**:一个全链路 RAG 检索模型微调、推理和蒸馏代码库。[[blog]](https://zhuanlan.zhihu.com/p/987727357)
1. **[360-LLaMA-Factory](https://github.com/Qihoo360/360-LLaMA-Factory)**:一个魔改后的代码库,通过 Ring Attention 支持长序列的 SFT 和 DPO 训练。
1. **[Sky-T1](https://novasky-ai.github.io/posts/sky-t1/)**:由 NovaSky AI 微调的低成本类 o1 长推理模型。
</details>
@@ -729,7 +864,7 @@ run_name: test_run # 可选
本仓库的代码依照 [Apache-2.0](LICENSE) 协议开源。
When using the model weights, please follow the corresponding model licenses: [Baichuan 2](https://huggingface.co/baichuan-inc/Baichuan2-7B-Base/blob/main/Community%20License%20for%20Baichuan%202%20Model.pdf) / [BLOOM](https://huggingface.co/spaces/bigscience/license) / [ChatGLM3](https://github.com/THUDM/ChatGLM3/blob/main/MODEL_LICENSE) / [Command R](https://cohere.com/c4ai-cc-by-nc-license) / [DeepSeek](https://github.com/deepseek-ai/DeepSeek-LLM/blob/main/LICENSE-MODEL) / [Falcon](https://huggingface.co/tiiuae/falcon-180B/blob/main/LICENSE.txt) / [Gemma](https://ai.google.dev/gemma/terms) / [GLM-4](https://huggingface.co/THUDM/glm-4-9b/blob/main/LICENSE) / [Index](https://huggingface.co/IndexTeam/Index-1.9B/blob/main/LICENSE) / [InternLM2](https://github.com/InternLM/InternLM#license) / [Llama](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) / [Llama 2 (LLaVA-1.5)](https://ai.meta.com/llama/license/) / [Llama 3](https://llama.meta.com/llama3/license/) / [MiniCPM](https://github.com/OpenBMB/MiniCPM/blob/main/MiniCPM%20Model%20License.md) / [Mistral/Mixtral/Pixtral](LICENSE) / [OLMo](LICENSE) / [Phi-1.5/Phi-2](https://huggingface.co/microsoft/phi-1_5/resolve/main/Research%20License.docx) / [Phi-3](https://huggingface.co/microsoft/Phi-3-mini-4k-instruct/blob/main/LICENSE) / [Qwen](https://github.com/QwenLM/Qwen/blob/main/Tongyi%20Qianwen%20LICENSE%20AGREEMENT) / [StarCoder 2](https://huggingface.co/spaces/bigcode/bigcode-model-license-agreement) / [XVERSE](https://github.com/xverse-ai/XVERSE-13B/blob/main/MODEL_LICENSE.pdf) / [Yi](https://huggingface.co/01-ai/Yi-6B/blob/main/LICENSE) / [Yi-1.5](LICENSE) / [Yuan 2](https://github.com/IEIT-Yuan/Yuan-2.0/blob/main/LICENSE-Yuan)
When using the model weights, please follow the corresponding model licenses: [Baichuan 2](https://huggingface.co/baichuan-inc/Baichuan2-7B-Base/blob/main/Community%20License%20for%20Baichuan%202%20Model.pdf) / [BLOOM](https://huggingface.co/spaces/bigscience/license) / [ChatGLM3](https://github.com/THUDM/ChatGLM3/blob/main/MODEL_LICENSE) / [Command R](https://cohere.com/c4ai-cc-by-nc-license) / [DeepSeek](https://github.com/deepseek-ai/DeepSeek-LLM/blob/main/LICENSE-MODEL) / [Falcon](https://huggingface.co/tiiuae/falcon-180B/blob/main/LICENSE.txt) / [Gemma](https://ai.google.dev/gemma/terms) / [GLM-4](https://huggingface.co/THUDM/glm-4-9b/blob/main/LICENSE) / [GPT-2](https://github.com/openai/gpt-2/blob/master/LICENSE) / [Granite](LICENSE) / [Index](https://huggingface.co/IndexTeam/Index-1.9B/blob/main/LICENSE) / [InternLM](https://github.com/InternLM/InternLM#license) / [Llama](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) / [Llama 2 (LLaVA-1.5)](https://ai.meta.com/llama/license/) / [Llama 3](https://llama.meta.com/llama3/license/) / [MiniCPM](https://github.com/OpenBMB/MiniCPM/blob/main/MiniCPM%20Model%20License.md) / [Mistral/Mixtral/Pixtral](LICENSE) / [OLMo](LICENSE) / [Phi-1.5/Phi-2](https://huggingface.co/microsoft/phi-1_5/resolve/main/Research%20License.docx) / [Phi-3/Phi-4](https://huggingface.co/microsoft/Phi-3-mini-4k-instruct/blob/main/LICENSE) / [Qwen](https://github.com/QwenLM/Qwen/blob/main/Tongyi%20Qianwen%20LICENSE%20AGREEMENT) / [Skywork](https://huggingface.co/Skywork/Skywork-13B-base/blob/main/Skywork%20Community%20License.pdf) / [StarCoder 2](https://huggingface.co/spaces/bigcode/bigcode-model-license-agreement) / [TeleChat2](https://huggingface.co/Tele-AI/telechat-7B/blob/main/TeleChat%E6%A8%A1%E5%9E%8B%E7%A4%BE%E5%8C%BA%E8%AE%B8%E5%8F%AF%E5%8D%8F%E8%AE%AE.pdf) / [XVERSE](https://github.com/xverse-ai/XVERSE-13B/blob/main/MODEL_LICENSE.pdf) / [Yi](https://huggingface.co/01-ai/Yi-6B/blob/main/LICENSE) / [Yi-1.5](LICENSE) / [Yuan 2](https://github.com/IEIT-Yuan/Yuan-2.0/blob/main/LICENSE-Yuan)
## Citation

View File

@@ -24,6 +24,7 @@ Currently we support datasets in **alpaca** and **sharegpt** format.
"tools": "the column name in the dataset containing the tool description. (default: None)",
"images": "the column name in the dataset containing the image inputs. (default: None)",
"videos": "the column name in the dataset containing the videos inputs. (default: None)",
"audios": "the column name in the dataset containing the audios inputs. (default: None)",
"chosen": "the column name in the dataset containing the chosen answers. (default: None)",
"rejected": "the column name in the dataset containing the rejected answers. (default: None)",
"kto_tag": "the column name in the dataset containing the kto tags. (default: None)"
@@ -150,6 +151,10 @@ An additional column `images` is required. Please refer to the [sharegpt](#share
An additional column `videos` is required. Please refer to the [sharegpt](#sharegpt-format) format for details.
### Multimodal Audio Dataset
An additional column `audios` is required. Please refer to the [sharegpt](#sharegpt-format) format for details.
## Sharegpt Format
### Supervised Fine-Tuning Dataset
@@ -296,7 +301,7 @@ Regarding the above dataset, the *dataset description* in `dataset_info.json` sh
- [Example dataset](mllm_demo.json)
Multimodal image datasets require a `images` column containing the paths to the input images.
Multimodal image datasets require an `images` column containing the paths to the input images.
The number of images should be identical to the number of `<image>` tokens in the conversations.
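As a point of reference, a minimal record that satisfies this rule might look like the following sketch (paths and text are placeholders):
```json
[
  {
    "conversations": [
      {
        "from": "human",
        "value": "<image>human instruction"
      },
      {
        "from": "gpt",
        "value": "model response"
      }
    ],
    "images": [
      "image path (required)"
    ]
  }
]
```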
@@ -374,6 +379,47 @@ Regarding the above dataset, the *dataset description* in `dataset_info.json` sh
}
```
### Multimodal Audio Dataset
- [Example dataset](mllm_audio_demo.json)
Multimodal audio datasets require an `audios` column containing the paths to the input audios.
The number of audios should be identical to the number of `<audio>` tokens in the conversations.
```json
[
{
"conversations": [
{
"from": "human",
"value": "<audio>human instruction"
},
{
"from": "gpt",
"value": "model response"
}
],
"audios": [
"audio path (required)"
]
}
]
```
Regarding the above dataset, the *dataset description* in `dataset_info.json` should be:
```json
"dataset_name": {
"file_name": "data.json",
"formatting": "sharegpt",
"columns": {
"messages": "conversations",
"audios": "audios"
}
}
```
### OpenAI Format
The openai format is simply a special case of the sharegpt format, where the first message may be a system prompt.
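For illustration, a minimal record in this format could look like the sketch below (the strings are placeholders); the optional system message, when present, comes first:
```json
[
  {
    "messages": [
      {
        "role": "system",
        "content": "system prompt (optional)"
      },
      {
        "role": "user",
        "content": "human instruction"
      },
      {
        "role": "assistant",
        "content": "model response"
      }
    ]
  }
]
```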

View File

@@ -24,6 +24,7 @@
"tools": "数据集代表工具描述的表头名称默认None",
"images": "数据集代表图像输入的表头名称默认None",
"videos": "数据集代表视频输入的表头名称默认None",
"audios": "数据集代表音频输入的表头名称默认None",
"chosen": "数据集代表更优回答的表头名称默认None",
"rejected": "数据集代表更差回答的表头名称默认None",
"kto_tag": "数据集代表 KTO 标签的表头名称默认None"
@@ -150,6 +151,10 @@ KTO 数据集需要提供额外的 `kto_tag` 列。详情请参阅 [sharegpt](#s
多模态视频数据集需要提供额外的 `videos` 列。详情请参阅 [sharegpt](#sharegpt-格式)。
### 多模态音频数据集
多模态音频数据集需要提供额外的 `audios` 列。详情请参阅 [sharegpt](#sharegpt-格式)。
## Sharegpt 格式
### 指令监督微调数据集
@@ -374,6 +379,48 @@ KTO 数据集需要额外添加一个 `kto_tag` 列,包含 bool 类型的人
}
```
### Multimodal Audio Dataset
- [Example dataset](mllm_audio_demo.json)
Multimodal audio datasets require an additional `audios` column containing the paths to the input audios.
Note that the number of audios must be identical to the number of `<audio>` tokens in the conversations.
```json
[
{
"conversations": [
{
"from": "human",
"value": "<audio>人类指令"
},
{
"from": "gpt",
"value": "模型回答"
}
],
"audios": [
"音频路径(必填)"
]
}
]
```
Regarding the above dataset, the *dataset description* in `dataset_info.json` should be:
```json
"数据集名称": {
"file_name": "data.json",
"formatting": "sharegpt",
"columns": {
"messages": "conversations",
"audios": "audios"
}
}
```
### OpenAI Format
The OpenAI format is simply a special case of the sharegpt format, where the first message may be a system prompt.

BIN
data/mllm_demo_data/1.mp3 Normal file

Binary file not shown.

BIN
data/mllm_demo_data/2.wav Normal file

Binary file not shown.

BIN
data/mllm_demo_data/3.flac Normal file

Binary file not shown.

View File

@@ -17,16 +17,28 @@ ARG INSTALL_LIGER_KERNEL=false
ARG INSTALL_HQQ=false
ARG INSTALL_EETQ=false
ARG PIP_INDEX=https://pypi.org/simple
ARG HTTP_PROXY=
# Set the working directory
WORKDIR /app
# Set http proxy
RUN if [ -n "$HTTP_PROXY" ]; then \
echo "Configuring proxy..."; \
export http_proxy=$HTTP_PROXY; \
export https_proxy=$HTTP_PROXY; \
fi
# Install the requirements
COPY requirements.txt /app
RUN pip config set global.index-url "$PIP_INDEX" && \
pip config set global.extra-index-url "$PIP_INDEX" && \
python -m pip install --upgrade pip && \
python -m pip install -r requirements.txt
if [ -n "$HTTP_PROXY" ]; then \
python -m pip install --proxy=$HTTP_PROXY -r requirements.txt; \
else \
python -m pip install -r requirements.txt; \
fi
# Copy the rest of the application into the image
COPY . /app
@@ -51,13 +63,30 @@ RUN EXTRA_PACKAGES="metrics"; \
if [ "$INSTALL_EETQ" == "true" ]; then \
EXTRA_PACKAGES="${EXTRA_PACKAGES},eetq"; \
fi; \
pip install -e ".[$EXTRA_PACKAGES]"
if [ -n "$HTTP_PROXY" ]; then \
pip install --proxy=$HTTP_PROXY -e ".[$EXTRA_PACKAGES]"; \
else \
pip install -e ".[$EXTRA_PACKAGES]"; \
fi
# Rebuild flash attention
RUN pip uninstall -y transformer-engine flash-attn && \
if [ "$INSTALL_FLASHATTN" == "true" ]; then \
pip uninstall -y ninja && pip install ninja && \
pip install --no-cache-dir flash-attn --no-build-isolation; \
pip uninstall -y ninja && \
if [ -n "$HTTP_PROXY" ]; then \
pip install --proxy=$HTTP_PROXY ninja && \
pip install --proxy=$HTTP_PROXY --no-cache-dir flash-attn --no-build-isolation; \
else \
pip install ninja && \
pip install --no-cache-dir flash-attn --no-build-isolation; \
fi; \
fi
# Unset http proxy
RUN if [ -n "$HTTP_PROXY" ]; then \
unset http_proxy; \
unset https_proxy; \
fi
# Set up volumes
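For reference, the new `HTTP_PROXY` build argument would typically be supplied at image build time; a sketch assuming a hypothetical proxy endpoint and the CUDA Dockerfile path used by the compose files:
```bash
# Hypothetical proxy address; replace with your own or drop the flag entirely.
docker build \
  --build-arg HTTP_PROXY=http://proxy.example.com:8080 \
  --build-arg PIP_INDEX=https://pypi.org/simple \
  -f docker/docker-cuda/Dockerfile \
  -t llamafactory:latest .
```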

View File

@@ -4,13 +4,13 @@ services:
dockerfile: ./docker/docker-cuda/Dockerfile
context: ../..
args:
INSTALL_BNB: false
INSTALL_VLLM: false
INSTALL_DEEPSPEED: false
INSTALL_FLASHATTN: false
INSTALL_LIGER_KERNEL: false
INSTALL_HQQ: false
INSTALL_EETQ: false
INSTALL_BNB: "false"
INSTALL_VLLM: "false"
INSTALL_DEEPSPEED: "false"
INSTALL_FLASHATTN: "false"
INSTALL_LIGER_KERNEL: "false"
INSTALL_HQQ: "false"
INSTALL_EETQ: "false"
PIP_INDEX: https://pypi.org/simple
container_name: llamafactory
volumes:
@@ -24,7 +24,7 @@ services:
- "8000:8000"
ipc: host
tty: true
shm_size: '16gb'
shm_size: "16gb"
stdin_open: true
command: bash
deploy:

View File

@@ -1,7 +1,7 @@
# Use the Ubuntu 22.04 image with CANN 8.0.rc1
# More versions can be found at https://hub.docker.com/r/ascendai/cann/tags
# FROM ascendai/cann:8.0.rc1-910-ubuntu22.04-py3.8
FROM ascendai/cann:8.0.rc1-910b-ubuntu22.04-py3.8
FROM ascendai/cann:8.0.0-910b-ubuntu22.04-py3.10
# FROM ascendai/cann:8.0.rc1-910-openeuler22.03-py3.8
# FROM ascendai/cann:8.0.rc1-910b-openeuler22.03-py3.8
@@ -12,16 +12,28 @@ ENV DEBIAN_FRONTEND=noninteractive
ARG INSTALL_DEEPSPEED=false
ARG PIP_INDEX=https://pypi.org/simple
ARG TORCH_INDEX=https://download.pytorch.org/whl/cpu
ARG HTTP_PROXY=
# Set the working directory
WORKDIR /app
# Set http proxy
RUN if [ -n "$HTTP_PROXY" ]; then \
echo "Configuring proxy..."; \
export http_proxy=$HTTP_PROXY; \
export https_proxy=$HTTP_PROXY; \
fi
# Install the requirements
COPY requirements.txt /app
RUN pip config set global.index-url "$PIP_INDEX" && \
pip config set global.extra-index-url "$TORCH_INDEX" && \
python -m pip install --upgrade pip && \
python -m pip install -r requirements.txt
if [ -n "$HTTP_PROXY" ]; then \
python -m pip install --proxy=$HTTP_PROXY -r requirements.txt; \
else \
python -m pip install -r requirements.txt; \
fi
# Copy the rest of the application into the image
COPY . /app
@@ -31,7 +43,17 @@ RUN EXTRA_PACKAGES="torch-npu,metrics"; \
if [ "$INSTALL_DEEPSPEED" == "true" ]; then \
EXTRA_PACKAGES="${EXTRA_PACKAGES},deepspeed"; \
fi; \
pip install -e ".[$EXTRA_PACKAGES]"
if [ -n "$HTTP_PROXY" ]; then \
pip install --proxy=$HTTP_PROXY -e ".[$EXTRA_PACKAGES]"; \
else \
pip install -e ".[$EXTRA_PACKAGES]"; \
fi
# Unset http proxy
RUN if [ -n "$HTTP_PROXY" ]; then \
unset http_proxy; \
unset https_proxy; \
fi
# Set up volumes
VOLUME [ "/root/.cache/huggingface", "/root/.cache/modelscope", "/app/data", "/app/output" ]

View File

@@ -4,7 +4,7 @@ services:
dockerfile: ./docker/docker-npu/Dockerfile
context: ../..
args:
INSTALL_DEEPSPEED: false
INSTALL_DEEPSPEED: "false"
PIP_INDEX: https://pypi.org/simple
container_name: llamafactory
volumes:
@@ -22,7 +22,7 @@ services:
- "8000:8000"
ipc: host
tty: true
shm_size: '16gb'
shm_size: "16gb"
stdin_open: true
command: bash
devices:

View File

@@ -13,16 +13,28 @@ ARG INSTALL_FLASHATTN=false
ARG INSTALL_LIGER_KERNEL=false
ARG INSTALL_HQQ=false
ARG PIP_INDEX=https://pypi.org/simple
ARG HTTP_PROXY=
# Set the working directory
WORKDIR /app
# Set http proxy
RUN if [ -n "$HTTP_PROXY" ]; then \
echo "Configuring proxy..."; \
export http_proxy=$HTTP_PROXY; \
export https_proxy=$HTTP_PROXY; \
fi
# Install the requirements
COPY requirements.txt /app
RUN pip config set global.index-url "$PIP_INDEX" && \
pip config set global.extra-index-url "$PIP_INDEX" && \
python -m pip install --upgrade pip && \
python -m pip install -r requirements.txt
if [ -n "$HTTP_PROXY" ]; then \
python -m pip install --proxy=$HTTP_PROXY -r requirements.txt; \
else \
python -m pip install -r requirements.txt; \
fi
# Copy the rest of the application into the image
COPY . /app
@@ -44,13 +56,29 @@ RUN EXTRA_PACKAGES="metrics"; \
if [ "$INSTALL_HQQ" == "true" ]; then \
EXTRA_PACKAGES="${EXTRA_PACKAGES},hqq"; \
fi; \
pip install -e ".[$EXTRA_PACKAGES]"
if [ -n "$HTTP_PROXY" ]; then \
pip install --proxy=$HTTP_PROXY -e ".[$EXTRA_PACKAGES]"; \
else \
pip install -e ".[$EXTRA_PACKAGES]"; \
fi
# Rebuild flash attention
RUN pip uninstall -y transformer-engine flash-attn && \
if [ "$INSTALL_FLASHATTN" == "true" ]; then \
pip uninstall -y ninja && pip install ninja && \
pip install --no-cache-dir flash-attn --no-build-isolation; \
pip uninstall -y ninja && \
if [ -n "$HTTP_PROXY" ]; then \
pip install --proxy=$HTTP_PROXY ninja && \
pip install --proxy=$HTTP_PROXY --no-cache-dir flash-attn --no-build-isolation; \
else \
pip install ninja && \
pip install --no-cache-dir flash-attn --no-build-isolation; \
fi; \
fi
# Unset http proxy
RUN if [ -n "$HTTP_PROXY" ]; then \
unset http_proxy; \
unset https_proxy; \
fi
# Set up volumes

View File

@@ -4,12 +4,12 @@ services:
dockerfile: ./docker/docker-rocm/Dockerfile
context: ../..
args:
INSTALL_BNB: false
INSTALL_VLLM: false
INSTALL_DEEPSPEED: false
INSTALL_FLASHATTN: false
INSTALL_LIGER_KERNEL: false
INSTALL_HQQ: false
INSTALL_BNB: "false"
INSTALL_VLLM: "false"
INSTALL_DEEPSPEED: "false"
INSTALL_FLASHATTN: "false"
INSTALL_LIGER_KERNEL: "false"
INSTALL_HQQ: "false"
PIP_INDEX: https://pypi.org/simple
container_name: llamafactory
volumes:
@@ -24,7 +24,7 @@ services:
- "8000:8000"
ipc: host
tty: true
shm_size: '16gb'
shm_size: "16gb"
stdin_open: true
command: bash
devices:

View File

@@ -13,6 +13,8 @@ Make sure to execute these commands in the `LLaMA-Factory` directory.
Use `CUDA_VISIBLE_DEVICES` (GPU) or `ASCEND_RT_VISIBLE_DEVICES` (NPU) to choose computing devices.
By default, LLaMA-Factory uses all visible computing devices.
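For example, restricting training to the first two GPUs can be done by prefixing any of the commands below:
```bash
CUDA_VISIBLE_DEVICES=0,1 llamafactory-cli train examples/train_lora/llama3_lora_sft.yaml
```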
## Examples
### LoRA Fine-Tuning
@@ -80,12 +82,6 @@ llamafactory-cli train examples/train_lora/llama3_preprocess.yaml
llamafactory-cli eval examples/train_lora/llama3_lora_eval.yaml
```
#### Batch Predicting and Computing BLEU and ROUGE Scores
```bash
llamafactory-cli train examples/train_lora/llama3_lora_predict.yaml
```
#### Supervised Fine-Tuning on Multiple Nodes
```bash
@@ -99,6 +95,12 @@ FORCE_TORCHRUN=1 NNODES=2 NODE_RANK=1 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500
FORCE_TORCHRUN=1 llamafactory-cli train examples/train_lora/llama3_lora_sft_ds3.yaml
```
#### Supervised Fine-Tuning with Ray on 4 GPUs
```bash
USE_RAY=1 llamafactory-cli train examples/train_lora/llama3_lora_sft_ray.yaml
```
### QLoRA Fine-Tuning
#### Supervised Fine-Tuning with 4/8-bit Bitsandbytes/HQQ/EETQ Quantization (Recommended)
@@ -107,6 +109,12 @@ FORCE_TORCHRUN=1 llamafactory-cli train examples/train_lora/llama3_lora_sft_ds3.
llamafactory-cli train examples/train_qlora/llama3_lora_sft_otfq.yaml
```
#### Supervised Fine-Tuning with 4-bit Bitsandbytes Quantization on Ascend NPU
```bash
llamafactory-cli train examples/train_qlora/llama3_lora_sft_bnb_npu.yaml
```
#### Supervised Fine-Tuning with 4/8-bit GPTQ Quantization
```bash
@@ -130,14 +138,14 @@ llamafactory-cli train examples/train_qlora/llama3_lora_sft_aqlm.yaml
#### Supervised Fine-Tuning on Single Node
```bash
FORCE_TORCHRUN=1 llamafactory-cli train examples/train_full/llama3_full_sft_ds3.yaml
FORCE_TORCHRUN=1 llamafactory-cli train examples/train_full/llama3_full_sft.yaml
```
#### Supervised Fine-Tuning on Multiple Nodes
```bash
FORCE_TORCHRUN=1 NNODES=2 RANK=0 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 llamafactory-cli train examples/train_full/llama3_full_sft_ds3.yaml
FORCE_TORCHRUN=1 NNODES=2 RANK=1 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 llamafactory-cli train examples/train_full/llama3_full_sft_ds3.yaml
FORCE_TORCHRUN=1 NNODES=2 NODE_RANK=0 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 llamafactory-cli train examples/train_full/llama3_full_sft.yaml
FORCE_TORCHRUN=1 NNODES=2 NODE_RANK=1 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 llamafactory-cli train examples/train_full/llama3_full_sft.yaml
```
#### Multimodal Supervised Fine-Tuning
@@ -146,12 +154,6 @@ FORCE_TORCHRUN=1 NNODES=2 RANK=1 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 llama
FORCE_TORCHRUN=1 llamafactory-cli train examples/train_full/qwen2vl_full_sft.yaml
```
#### Batch Predicting and Computing BLEU and ROUGE Scores
```bash
llamafactory-cli train examples/train_full/llama3_full_predict.yaml
```
### Merging LoRA Adapters and Quantization
#### Merge LoRA Adapters
@@ -168,15 +170,27 @@ llamafactory-cli export examples/merge_lora/llama3_lora_sft.yaml
llamafactory-cli export examples/merge_lora/llama3_gptq.yaml
```
### Save Ollama modelfile
```bash
llamafactory-cli export examples/merge_lora/llama3_full_sft.yaml
```
### Inferring LoRA Fine-Tuned Models
#### Use CLI
#### Batch Generation using vLLM Tensor Parallel
```bash
python scripts/vllm_infer.py --model_name_or_path path_to_merged_model --dataset alpaca_en_demo
```
#### Use CLI ChatBox
```bash
llamafactory-cli chat examples/inference/llama3_lora_sft.yaml
```
#### Use Web UI
#### Use Web UI ChatBox
```bash
llamafactory-cli webchat examples/inference/llama3_lora_sft.yaml
@@ -196,6 +210,12 @@ llamafactory-cli api examples/inference/llama3_lora_sft.yaml
llamafactory-cli train examples/extras/galore/llama3_full_sft.yaml
```
#### Full-Parameter Fine-Tuning using APOLLO
```bash
llamafactory-cli train examples/extras/apollo/llama3_full_sft.yaml
```
#### Full-Parameter Fine-Tuning using BAdam
```bash
@@ -238,3 +258,9 @@ llamafactory-cli train examples/extras/llama_pro/llama3_freeze_sft.yaml
```bash
bash examples/extras/fsdp_qlora/train.sh
```
#### Computing BLEU and ROUGE Scores
```bash
llamafactory-cli train examples/extras/nlg_eval/llama3_lora_predict.yaml
```

View File

@@ -13,6 +13,8 @@
Use `CUDA_VISIBLE_DEVICES` (GPU) or `ASCEND_RT_VISIBLE_DEVICES` (NPU) to choose computing devices.
By default, LLaMA-Factory uses all visible computing devices.
## Examples
### LoRA Fine-Tuning
@@ -80,12 +82,6 @@ llamafactory-cli train examples/train_lora/llama3_preprocess.yaml
llamafactory-cli eval examples/train_lora/llama3_lora_eval.yaml
```
#### Batch Predicting and Computing BLEU and ROUGE Scores
```bash
llamafactory-cli train examples/train_lora/llama3_lora_predict.yaml
```
#### Supervised Fine-Tuning on Multiple Nodes
```bash
@@ -99,6 +95,12 @@ FORCE_TORCHRUN=1 NNODES=2 NODE_RANK=1 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500
FORCE_TORCHRUN=1 llamafactory-cli train examples/train_lora/llama3_lora_sft_ds3.yaml
```
#### Supervised Fine-Tuning with Ray on 4 GPUs
```bash
USE_RAY=1 llamafactory-cli train examples/train_lora/llama3_lora_sft_ray.yaml
```
### QLoRA Fine-Tuning
#### Supervised Fine-Tuning with 4/8-bit Bitsandbytes/HQQ/EETQ Quantization (Recommended)
@@ -107,6 +109,12 @@ FORCE_TORCHRUN=1 llamafactory-cli train examples/train_lora/llama3_lora_sft_ds3.
llamafactory-cli train examples/train_qlora/llama3_lora_sft_otfq.yaml
```
#### Supervised Fine-Tuning with 4-bit Bitsandbytes Quantization on Ascend NPU
```bash
llamafactory-cli train examples/train_qlora/llama3_lora_sft_bnb_npu.yaml
```
#### Supervised Fine-Tuning with 4/8-bit GPTQ Quantization
```bash
@@ -130,14 +138,14 @@ llamafactory-cli train examples/train_qlora/llama3_lora_sft_aqlm.yaml
#### Supervised Fine-Tuning on Single Node
```bash
FORCE_TORCHRUN=1 llamafactory-cli train examples/train_full/llama3_full_sft_ds3.yaml
FORCE_TORCHRUN=1 llamafactory-cli train examples/train_full/llama3_full_sft.yaml
```
#### Supervised Fine-Tuning on Multiple Nodes
```bash
FORCE_TORCHRUN=1 NNODES=2 RANK=0 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 llamafactory-cli train examples/train_full/llama3_full_sft_ds3.yaml
FORCE_TORCHRUN=1 NNODES=2 RANK=1 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 llamafactory-cli train examples/train_full/llama3_full_sft_ds3.yaml
FORCE_TORCHRUN=1 NNODES=2 NODE_RANK=0 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 llamafactory-cli train examples/train_full/llama3_full_sft.yaml
FORCE_TORCHRUN=1 NNODES=2 NODE_RANK=1 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 llamafactory-cli train examples/train_full/llama3_full_sft.yaml
```
#### Multimodal Supervised Fine-Tuning
@@ -146,12 +154,6 @@ FORCE_TORCHRUN=1 NNODES=2 RANK=1 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 llama
FORCE_TORCHRUN=1 llamafactory-cli train examples/train_full/qwen2vl_full_sft.yaml
```
#### Batch Predicting and Computing BLEU and ROUGE Scores
```bash
llamafactory-cli train examples/train_full/llama3_full_predict.yaml
```
### Merging LoRA Adapters and Quantization
#### Merge LoRA Adapters
@@ -168,15 +170,27 @@ llamafactory-cli export examples/merge_lora/llama3_lora_sft.yaml
llamafactory-cli export examples/merge_lora/llama3_gptq.yaml
```
### Save Ollama modelfile
```bash
llamafactory-cli export examples/merge_lora/llama3_full_sft.yaml
```
### Inferring LoRA Fine-Tuned Models
#### Use CLI
#### Batch Generation using vLLM Tensor Parallel
```bash
python scripts/vllm_infer.py --model_name_or_path path_to_merged_model --dataset alpaca_en_demo
```
#### Use CLI ChatBox
```bash
llamafactory-cli chat examples/inference/llama3_lora_sft.yaml
```
#### Use Web UI
#### Use Web UI ChatBox
```bash
llamafactory-cli webchat examples/inference/llama3_lora_sft.yaml
@@ -196,6 +210,12 @@ llamafactory-cli api examples/inference/llama3_lora_sft.yaml
llamafactory-cli train examples/extras/galore/llama3_full_sft.yaml
```
#### Full-Parameter Fine-Tuning using APOLLO
```bash
llamafactory-cli train examples/extras/apollo/llama3_full_sft.yaml
```
#### Full-Parameter Fine-Tuning using BAdam
```bash
@@ -238,3 +258,9 @@ llamafactory-cli train examples/extras/llama_pro/llama3_freeze_sft.yaml
```bash
bash examples/extras/fsdp_qlora/train.sh
```
#### Computing BLEU and ROUGE Scores
```bash
llamafactory-cli train examples/extras/nlg_eval/llama3_lora_predict.yaml
```

View File

@@ -14,7 +14,7 @@ fsdp_config:
fsdp_use_orig_params: true
machine_rank: 0
main_training_function: main
mixed_precision: fp16 # or bf16
mixed_precision: bf16 # or fp16
num_machines: 1 # the number of nodes
num_processes: 2 # the number of GPUs in all nodes
rdzv_backend: static
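A sketch of how this accelerate config is typically consumed; the config path and training entry point are assumptions based on the repository layout, so adjust them to your checkout:
```bash
# Illustrative FSDP launch picking up the bf16 mixed-precision setting above.
accelerate launch \
  --config_file examples/accelerate/fsdp_config.yaml \
  src/train.py examples/extras/fsdp_qlora/llama3_lora_sft.yaml
```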

View File

@@ -1,5 +1,6 @@
### model
model_name_or_path: Qwen/Qwen2-1.5B-Instruct
trust_remote_code: true
### method
stage: sft
@@ -33,7 +34,7 @@ bf16: true
ddp_timeout: 180000000
### eval
val_size: 0.1
per_device_eval_batch_size: 1
eval_strategy: steps
eval_steps: 500
# val_size: 0.1
# per_device_eval_batch_size: 1
# eval_strategy: steps
# eval_steps: 500

View File

@@ -0,0 +1,45 @@
### model
model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct
trust_remote_code: true
### method
stage: sft
do_train: true
finetuning_type: full
use_apollo: true
apollo_layerwise: true # choices: [true, false], use false for DDP training
apollo_target: all
apollo_rank: 128
apollo_scale: 32.0
apollo_scale_type: channel
### dataset
dataset: identity,alpaca_en_demo
template: llama3
cutoff_len: 2048
max_samples: 1000
overwrite_cache: true
preprocessing_num_workers: 16
### output
output_dir: saves/llama3-8b/full/sft
logging_steps: 10
save_steps: 500
plot_loss: true
overwrite_output_dir: true
### train
per_device_train_batch_size: 1
gradient_accumulation_steps: 1 # use 1 for layerwise apollo
learning_rate: 1.0e-5
num_train_epochs: 3.0
lr_scheduler_type: cosine
warmup_ratio: 0.1
pure_bf16: true
ddp_timeout: 180000000
### eval
# val_size: 0.1
# per_device_eval_batch_size: 1
# eval_strategy: steps
# eval_steps: 500

View File

@@ -1,5 +1,6 @@
### model
model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct
trust_remote_code: true
### method
stage: sft
@@ -36,7 +37,7 @@ lr_scheduler_type: cosine
warmup_ratio: 0.1
### eval
val_size: 0.1
per_device_eval_batch_size: 1
eval_strategy: steps
eval_steps: 500
# val_size: 0.1
# per_device_eval_batch_size: 1
# eval_strategy: steps
# eval_steps: 500

View File

@@ -1,11 +1,13 @@
### model
model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct
quantization_bit: 4
trust_remote_code: true
### method
stage: sft
do_train: true
finetuning_type: lora
lora_rank: 8
lora_target: all
### dataset
@@ -34,7 +36,7 @@ bf16: true
ddp_timeout: 180000000
### eval
val_size: 0.1
per_device_eval_batch_size: 1
eval_strategy: steps
eval_steps: 500
# val_size: 0.1
# per_device_eval_batch_size: 1
# eval_strategy: steps
# eval_steps: 500

View File

@@ -1,13 +1,14 @@
### model
model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct
trust_remote_code: true
### method
stage: sft
do_train: true
finetuning_type: full
use_galore: true
galore_layerwise: true
galore_target: mlp,self_attn
galore_layerwise: true # choices: [true, false], use false for DDP training
galore_target: all
galore_rank: 128
galore_scale: 2.0
@@ -28,7 +29,7 @@ overwrite_output_dir: true
### train
per_device_train_batch_size: 1
gradient_accumulation_steps: 1
gradient_accumulation_steps: 1 # use 1 for layerwise galore
learning_rate: 1.0e-5
num_train_epochs: 3.0
lr_scheduler_type: cosine
@@ -37,7 +38,7 @@ pure_bf16: true
ddp_timeout: 180000000
### eval
val_size: 0.1
per_device_eval_batch_size: 1
eval_strategy: steps
eval_steps: 500
# val_size: 0.1
# per_device_eval_batch_size: 1
# eval_strategy: steps
# eval_steps: 500

View File

@@ -1,5 +1,6 @@
### model
model_name_or_path: models/llama3-8b-pro
trust_remote_code: true
### method
stage: sft
@@ -35,7 +36,7 @@ bf16: true
ddp_timeout: 180000000
### eval
val_size: 0.1
per_device_eval_batch_size: 1
eval_strategy: steps
eval_steps: 500
# val_size: 0.1
# per_device_eval_batch_size: 1
# eval_strategy: steps
# eval_steps: 500

View File

@@ -1,10 +1,12 @@
### model
model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct
trust_remote_code: true
### method
stage: sft
do_train: true
finetuning_type: lora
lora_rank: 8
lora_target: all
loraplus_lr_ratio: 16.0
@@ -34,7 +36,7 @@ bf16: true
ddp_timeout: 180000000
### eval
val_size: 0.1
per_device_eval_batch_size: 1
eval_strategy: steps
eval_steps: 500
# val_size: 0.1
# per_device_eval_batch_size: 1
# eval_strategy: steps
# eval_steps: 500

View File

@@ -1,5 +1,6 @@
### model
model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct
trust_remote_code: true
### method
stage: sft
@@ -34,7 +35,7 @@ pure_bf16: true
ddp_timeout: 180000000
### eval
val_size: 0.1
per_device_eval_batch_size: 1
eval_strategy: steps
eval_steps: 500
# val_size: 0.1
# per_device_eval_batch_size: 1
# eval_strategy: steps
# eval_steps: 500

View File

@@ -1,6 +1,10 @@
# The batch generation can be SLOW using this config.
# For faster inference, we recommend using `scripts/vllm_infer.py`.
### model
model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct
adapter_name_or_path: saves/llama3-8b/lora/sft
trust_remote_code: true
### method
stage: sft

View File

@@ -1,10 +1,12 @@
### model
model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct
trust_remote_code: true
### method
stage: sft
do_train: true
finetuning_type: lora
lora_rank: 8
lora_target: all
pissa_init: true
pissa_iter: 16
@@ -36,7 +38,7 @@ bf16: true
ddp_timeout: 180000000
### eval
val_size: 0.1
per_device_eval_batch_size: 1
eval_strategy: steps
eval_steps: 500
# val_size: 0.1
# per_device_eval_batch_size: 1
# eval_strategy: steps
# eval_steps: 500

View File

@@ -1,2 +1,4 @@
model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct
template: llama3
infer_backend: huggingface # choices: [huggingface, vllm]
trust_remote_code: true

View File

@@ -0,0 +1,4 @@
model_name_or_path: saves/llama3-8b/full/sft
template: llama3
infer_backend: huggingface # choices: [huggingface, vllm]
trust_remote_code: true

View File

@@ -1,4 +1,5 @@
model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct
adapter_name_or_path: saves/llama3-8b/lora/sft
template: llama3
finetuning_type: lora
infer_backend: huggingface # choices: [huggingface, vllm]
trust_remote_code: true

View File

@@ -2,3 +2,4 @@ model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct
template: llama3
infer_backend: vllm
vllm_enforce_eager: true
trust_remote_code: true

View File

@@ -1,2 +1,4 @@
model_name_or_path: llava-hf/llava-1.5-7b-hf
template: llava
infer_backend: huggingface # choices: [huggingface, vllm]
trust_remote_code: true

View File

@@ -1,2 +1,4 @@
model_name_or_path: Qwen/Qwen2-VL-7B-Instruct
template: qwen2_vl
infer_backend: huggingface # choices: [huggingface, vllm]
trust_remote_code: true

View File

@@ -0,0 +1,10 @@
### model
model_name_or_path: saves/llama3-8b/full/sft
template: llama3
trust_remote_code: true
### export
export_dir: output/llama3_full_sft
export_size: 5
export_device: cpu
export_legacy_format: false

View File

@@ -1,11 +1,12 @@
### model
model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct
template: llama3
trust_remote_code: true
### export
export_dir: models/llama3_gptq
export_dir: output/llama3_gptq
export_quantization_bit: 4
export_quantization_dataset: data/c4_demo.json
export_size: 2
export_size: 5
export_device: cpu
export_legacy_format: false

View File

@@ -4,10 +4,10 @@
model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct
adapter_name_or_path: saves/llama3-8b/lora/sft
template: llama3
finetuning_type: lora
trust_remote_code: true
### export
export_dir: models/llama3_lora_sft
export_size: 2
export_dir: output/llama3_lora_sft
export_size: 5
export_device: cpu
export_legacy_format: false

View File

@@ -4,10 +4,10 @@
model_name_or_path: Qwen/Qwen2-VL-7B-Instruct
adapter_name_or_path: saves/qwen2_vl-7b/lora/sft
template: qwen2_vl
finetuning_type: lora
trust_remote_code: true
### export
export_dir: models/qwen2_vl_lora_sft
export_size: 2
export_dir: output/qwen2_vl_lora_sft
export_size: 5
export_device: cpu
export_legacy_format: false

View File

@@ -1,23 +0,0 @@
### model
model_name_or_path: saves/llama3-8b/full/sft
### method
stage: sft
do_predict: true
finetuning_type: full
### dataset
eval_dataset: identity,alpaca_en_demo
template: llama3
cutoff_len: 2048
max_samples: 50
overwrite_cache: true
preprocessing_num_workers: 16
### output
output_dir: saves/llama3-8b/full/predict
overwrite_output_dir: true
### eval
per_device_eval_batch_size: 1
predict_with_generate: true

View File

@@ -1,11 +1,12 @@
### model
model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct
trust_remote_code: true
### method
stage: sft
do_train: true
finetuning_type: full
deepspeed: examples/deepspeed/ds_z3_config.json
deepspeed: examples/deepspeed/ds_z3_config.json # choices: [ds_z0_config.json, ds_z2_config.json, ds_z3_config.json]
### dataset
dataset: identity,alpaca_en_demo
@@ -14,6 +15,7 @@ cutoff_len: 2048
max_samples: 1000
overwrite_cache: true
preprocessing_num_workers: 16
dataloader_num_workers: 4
### output
output_dir: saves/llama3-8b/full/sft
@@ -21,6 +23,7 @@ logging_steps: 10
save_steps: 500
plot_loss: true
overwrite_output_dir: true
save_only_model: false
### train
per_device_train_batch_size: 1
@@ -31,9 +34,11 @@ lr_scheduler_type: cosine
warmup_ratio: 0.1
bf16: true
ddp_timeout: 180000000
resume_from_checkpoint: null
### eval
val_size: 0.1
per_device_eval_batch_size: 1
eval_strategy: steps
eval_steps: 500
# eval_dataset: alpaca_en_demo
# val_size: 0.1
# per_device_eval_batch_size: 1
# eval_strategy: steps
# eval_steps: 500

View File

@@ -1,19 +1,26 @@
### model
model_name_or_path: Qwen/Qwen2-VL-7B-Instruct
image_max_pixels: 262144
video_max_pixels: 16384
trust_remote_code: true
### method
stage: sft
do_train: true
finetuning_type: full
deepspeed: examples/deepspeed/ds_z3_config.json
freeze_vision_tower: true # choices: [true, false]
freeze_multi_modal_projector: true # choices: [true, false]
freeze_language_model: false # choices: [true, false]
deepspeed: examples/deepspeed/ds_z3_config.json # choices: [ds_z0_config.json, ds_z2_config.json, ds_z3_config.json]
### dataset
dataset: mllm_demo,identity
dataset: mllm_demo,identity,alpaca_en_demo
template: qwen2_vl
cutoff_len: 2048
max_samples: 1000
overwrite_cache: true
preprocessing_num_workers: 16
dataloader_num_workers: 4
### output
output_dir: saves/qwen2_vl-7b/full/sft
@@ -21,6 +28,7 @@ logging_steps: 10
save_steps: 500
plot_loss: true
overwrite_output_dir: true
save_only_model: false
### train
per_device_train_batch_size: 1
@@ -31,9 +39,10 @@ lr_scheduler_type: cosine
warmup_ratio: 0.1
bf16: true
ddp_timeout: 180000000
resume_from_checkpoint: null
### eval
val_size: 0.1
per_device_eval_batch_size: 1
eval_strategy: steps
eval_steps: 500
# val_size: 0.1
# per_device_eval_batch_size: 1
# eval_strategy: steps
# eval_steps: 500

View File

@@ -1,10 +1,12 @@
### model
model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct
trust_remote_code: true
### method
stage: dpo
do_train: true
finetuning_type: lora
lora_rank: 8
lora_target: all
pref_beta: 0.1
pref_loss: sigmoid # choices: [sigmoid (dpo), orpo, simpo]
@@ -16,6 +18,7 @@ cutoff_len: 2048
max_samples: 1000
overwrite_cache: true
preprocessing_num_workers: 16
dataloader_num_workers: 4
### output
output_dir: saves/llama3-8b/lora/dpo
@@ -23,6 +26,7 @@ logging_steps: 10
save_steps: 500
plot_loss: true
overwrite_output_dir: true
save_only_model: false
### train
per_device_train_batch_size: 1
@@ -33,9 +37,11 @@ lr_scheduler_type: cosine
warmup_ratio: 0.1
bf16: true
ddp_timeout: 180000000
resume_from_checkpoint: null
### eval
val_size: 0.1
per_device_eval_batch_size: 1
eval_strategy: steps
eval_steps: 500
# eval_dataset: dpo_en_demo
# val_size: 0.1
# per_device_eval_batch_size: 1
# eval_strategy: steps
# eval_steps: 500

View File

@@ -1,6 +1,7 @@
### model
model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct
adapter_name_or_path: saves/llama3-8b/lora/sft
trust_remote_code: true
### method
finetuning_type: lora

View File

@@ -1,10 +1,12 @@
### model
model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct
trust_remote_code: true
### method
stage: kto
do_train: true
finetuning_type: lora
lora_rank: 8
lora_target: all
pref_beta: 0.1
@@ -34,7 +36,7 @@ bf16: true
ddp_timeout: 180000000
### eval
val_size: 0.1
per_device_eval_batch_size: 1
eval_strategy: steps
eval_steps: 500
# val_size: 0.1
# per_device_eval_batch_size: 1
# eval_strategy: steps
# eval_steps: 500

View File

@@ -1,11 +1,13 @@
### model
model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct
reward_model: saves/llama3-8b/lora/reward
trust_remote_code: true
### method
stage: ppo
do_train: true
finetuning_type: lora
lora_rank: 8
lora_target: all
### dataset

View File

@@ -1,10 +1,12 @@
### model
model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct
trust_remote_code: true
### method
stage: pt
do_train: true
finetuning_type: lora
lora_rank: 8
lora_target: all
### dataset
@@ -13,6 +15,7 @@ cutoff_len: 2048
max_samples: 1000
overwrite_cache: true
preprocessing_num_workers: 16
dataloader_num_workers: 4
### output
output_dir: saves/llama3-8b/lora/pretrain
@@ -20,6 +23,7 @@ logging_steps: 10
save_steps: 500
plot_loss: true
overwrite_output_dir: true
save_only_model: false
### train
per_device_train_batch_size: 1
@@ -30,9 +34,11 @@ lr_scheduler_type: cosine
warmup_ratio: 0.1
bf16: true
ddp_timeout: 180000000
resume_from_checkpoint: null
### eval
val_size: 0.1
per_device_eval_batch_size: 1
eval_strategy: steps
eval_steps: 500
# eval_dataset: c4_demo
# val_size: 0.1
# per_device_eval_batch_size: 1
# eval_strategy: steps
# eval_steps: 500

View File

@@ -1,10 +1,12 @@
### model
model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct
trust_remote_code: true
### method
stage: rm
do_train: true
finetuning_type: lora
lora_rank: 8
lora_target: all
### dataset
@@ -14,6 +16,7 @@ cutoff_len: 2048
max_samples: 1000
overwrite_cache: true
preprocessing_num_workers: 16
dataloader_num_workers: 4
### output
output_dir: saves/llama3-8b/lora/reward
@@ -21,6 +24,7 @@ logging_steps: 10
save_steps: 500
plot_loss: true
overwrite_output_dir: true
save_only_model: false
### train
per_device_train_batch_size: 1
@@ -31,9 +35,11 @@ lr_scheduler_type: cosine
warmup_ratio: 0.1
bf16: true
ddp_timeout: 180000000
resume_from_checkpoint: null
### eval
val_size: 0.1
per_device_eval_batch_size: 1
eval_strategy: steps
eval_steps: 500
# eval_dataset: dpo_en_demo
# val_size: 0.1
# per_device_eval_batch_size: 1
# eval_strategy: steps
# eval_steps: 500

View File

@@ -1,10 +1,12 @@
### model
model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct
trust_remote_code: true
### method
stage: sft
do_train: true
finetuning_type: lora
lora_rank: 8
lora_target: all
### dataset
@@ -14,6 +16,7 @@ cutoff_len: 2048
max_samples: 1000
overwrite_cache: true
preprocessing_num_workers: 16
dataloader_num_workers: 4
### output
output_dir: saves/llama3-8b/lora/sft
@@ -21,6 +24,7 @@ logging_steps: 10
save_steps: 500
plot_loss: true
overwrite_output_dir: true
save_only_model: false
### train
per_device_train_batch_size: 1
@@ -31,9 +35,11 @@ lr_scheduler_type: cosine
warmup_ratio: 0.1
bf16: true
ddp_timeout: 180000000
resume_from_checkpoint: null
### eval
val_size: 0.1
per_device_eval_batch_size: 1
eval_strategy: steps
eval_steps: 500
# eval_dataset: alpaca_en_demo
# val_size: 0.1
# per_device_eval_batch_size: 1
# eval_strategy: steps
# eval_steps: 500

View File

@@ -1,12 +1,14 @@
### model
model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct
trust_remote_code: true
### method
stage: sft
do_train: true
finetuning_type: lora
lora_rank: 8
lora_target: all
deepspeed: examples/deepspeed/ds_z3_config.json
deepspeed: examples/deepspeed/ds_z3_config.json # choices: [ds_z0_config.json, ds_z2_config.json, ds_z3_config.json]
### dataset
dataset: identity,alpaca_en_demo
@@ -15,6 +17,7 @@ cutoff_len: 2048
max_samples: 1000
overwrite_cache: true
preprocessing_num_workers: 16
dataloader_num_workers: 4
### output
output_dir: saves/llama3-8b/lora/sft
@@ -22,6 +25,7 @@ logging_steps: 10
save_steps: 500
plot_loss: true
overwrite_output_dir: true
save_only_model: false
### train
per_device_train_batch_size: 1
@@ -32,9 +36,11 @@ lr_scheduler_type: cosine
warmup_ratio: 0.1
bf16: true
ddp_timeout: 180000000
resume_from_checkpoint: null
### eval
val_size: 0.1
per_device_eval_batch_size: 1
eval_strategy: steps
eval_steps: 500
# eval_dataset: alpaca_en_demo
# val_size: 0.1
# per_device_eval_batch_size: 1
# eval_strategy: steps
# eval_steps: 500

View File

@@ -0,0 +1,54 @@
### model
model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct # or use local absolute path
trust_remote_code: true
### method
stage: sft
do_train: true
finetuning_type: lora
lora_rank: 8
lora_target: all
### dataset
dataset: identity,alpaca_en_demo
dataset_dir: REMOTE:llamafactory/demo_data # or use local absolute path
template: llama3
cutoff_len: 2048
max_samples: 1000
overwrite_cache: true
preprocessing_num_workers: 16
dataloader_num_workers: 4
### output
output_dir: tmp_dir
logging_steps: 10
save_steps: 500
plot_loss: true
overwrite_output_dir: true
save_only_model: false
### ray
ray_run_name: llama3_8b_sft_lora
ray_storage_path: ./saves
ray_num_workers: 4 # number of GPUs to use
resources_per_worker:
GPU: 1
placement_strategy: PACK
### train
per_device_train_batch_size: 1
gradient_accumulation_steps: 8
learning_rate: 1.0e-4
num_train_epochs: 3.0
lr_scheduler_type: cosine
warmup_ratio: 0.1
bf16: true
ddp_timeout: 180000000
resume_from_checkpoint: null
### eval
# eval_dataset: alpaca_en_demo
# val_size: 0.1
# per_device_eval_batch_size: 1
# eval_strategy: steps
# eval_steps: 500

View File

@@ -1,10 +1,12 @@
### model
model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct
trust_remote_code: true
### method
stage: sft
do_train: true
finetuning_type: lora
lora_rank: 8
lora_target: all
### dataset

View File

@@ -1,10 +1,12 @@
### model
model_name_or_path: llava-hf/llava-1.5-7b-hf
trust_remote_code: true
### method
stage: sft
do_train: true
finetuning_type: lora
lora_rank: 8
lora_target: all
### dataset
@@ -14,6 +16,7 @@ cutoff_len: 2048
max_samples: 1000
overwrite_cache: true
preprocessing_num_workers: 16
dataloader_num_workers: 4
### output
output_dir: saves/llava1_5-7b/lora/sft
@@ -21,6 +24,7 @@ logging_steps: 10
save_steps: 500
plot_loss: true
overwrite_output_dir: true
save_only_model: false
### train
per_device_train_batch_size: 1
@@ -31,9 +35,10 @@ lr_scheduler_type: cosine
warmup_ratio: 0.1
bf16: true
ddp_timeout: 180000000
resume_from_checkpoint: null
### eval
val_size: 0.1
per_device_eval_batch_size: 1
eval_strategy: steps
eval_steps: 500
# val_size: 0.1
# per_device_eval_batch_size: 1
# eval_strategy: steps
# eval_steps: 500

View File

@@ -1,10 +1,14 @@
### model
model_name_or_path: Qwen/Qwen2-VL-7B-Instruct
image_max_pixels: 262144
video_max_pixels: 16384
trust_remote_code: true
### method
stage: dpo
do_train: true
finetuning_type: lora
lora_rank: 8
lora_target: all
pref_beta: 0.1
pref_loss: sigmoid # choices: [sigmoid (dpo), orpo, simpo]
@@ -16,6 +20,7 @@ cutoff_len: 2048
max_samples: 1000
overwrite_cache: true
preprocessing_num_workers: 16
dataloader_num_workers: 4
### output
output_dir: saves/qwen2_vl-7b/lora/dpo
@@ -23,6 +28,7 @@ logging_steps: 10
save_steps: 500
plot_loss: true
overwrite_output_dir: true
save_only_model: false
### train
per_device_train_batch_size: 1
@@ -33,9 +39,10 @@ lr_scheduler_type: cosine
warmup_ratio: 0.1
bf16: true
ddp_timeout: 180000000
resume_from_checkpoint: null
### eval
val_size: 0.1
per_device_eval_batch_size: 1
eval_strategy: steps
eval_steps: 500
# val_size: 0.1
# per_device_eval_batch_size: 1
# eval_strategy: steps
# eval_steps: 500

View File

@@ -1,19 +1,24 @@
### model
model_name_or_path: Qwen/Qwen2-VL-7B-Instruct
image_max_pixels: 262144
video_max_pixels: 16384
trust_remote_code: true
### method
stage: sft
do_train: true
finetuning_type: lora
lora_rank: 8
lora_target: all
### dataset
dataset: mllm_demo,identity # video: mllm_video_demo
dataset: mllm_demo,identity,alpaca_en_demo # video: mllm_video_demo
template: qwen2_vl
cutoff_len: 2048
max_samples: 1000
overwrite_cache: true
preprocessing_num_workers: 16
dataloader_num_workers: 4
### output
output_dir: saves/qwen2_vl-7b/lora/sft
@@ -21,6 +26,7 @@ logging_steps: 10
save_steps: 500
plot_loss: true
overwrite_output_dir: true
save_only_model: false
### train
per_device_train_batch_size: 1
@@ -31,9 +37,10 @@ lr_scheduler_type: cosine
warmup_ratio: 0.1
bf16: true
ddp_timeout: 180000000
resume_from_checkpoint: null
### eval
val_size: 0.1
per_device_eval_batch_size: 1
eval_strategy: steps
eval_steps: 500
# val_size: 0.1
# per_device_eval_batch_size: 1
# eval_strategy: steps
# eval_steps: 500

View File

@@ -1,10 +1,12 @@
### model
model_name_or_path: ISTA-DASLab/Meta-Llama-3-8B-Instruct-AQLM-2Bit-1x16
trust_remote_code: true
### method
stage: sft
do_train: true
finetuning_type: lora
lora_rank: 8
lora_target: all
### dataset
@@ -33,7 +35,7 @@ bf16: true
ddp_timeout: 180000000
### eval
val_size: 0.1
per_device_eval_batch_size: 1
eval_strategy: steps
eval_steps: 500
# val_size: 0.1
# per_device_eval_batch_size: 1
# eval_strategy: steps
# eval_steps: 500

View File

@@ -1,10 +1,12 @@
### model
model_name_or_path: TechxGenus/Meta-Llama-3-8B-Instruct-AWQ
trust_remote_code: true
### method
stage: sft
do_train: true
finetuning_type: lora
lora_rank: 8
lora_target: all
### dataset
@@ -33,7 +35,7 @@ bf16: true
ddp_timeout: 180000000
### eval
val_size: 0.1
per_device_eval_batch_size: 1
eval_strategy: steps
eval_steps: 500
# val_size: 0.1
# per_device_eval_batch_size: 1
# eval_strategy: steps
# eval_steps: 500

View File

@@ -1,12 +1,16 @@
### model
model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct
quantization_bit: 4
quantization_method: bitsandbytes
double_quantization: false
trust_remote_code: true
### method
stage: sft
do_train: true
finetuning_type: lora
lora_rank: 8
lora_target: all
deepspeed: examples/deepspeed/ds_z0_config.json
### dataset
dataset: identity,alpaca_en_demo
@@ -25,7 +29,7 @@ overwrite_output_dir: true
### train
per_device_train_batch_size: 1
gradient_accumulation_steps: 2
gradient_accumulation_steps: 8
learning_rate: 1.0e-4
num_train_epochs: 3.0
lr_scheduler_type: cosine
@@ -34,7 +38,7 @@ bf16: true
ddp_timeout: 180000000
### eval
val_size: 0.1
per_device_eval_batch_size: 1
eval_strategy: steps
eval_steps: 500
# val_size: 0.1
# per_device_eval_batch_size: 1
# eval_strategy: steps
# eval_steps: 500

View File

@@ -1,10 +1,12 @@
### model
model_name_or_path: TechxGenus/Meta-Llama-3-8B-Instruct-GPTQ
trust_remote_code: true
### method
stage: sft
do_train: true
finetuning_type: lora
lora_rank: 8
lora_target: all
### dataset
@@ -33,7 +35,7 @@ bf16: true
ddp_timeout: 180000000
### eval
val_size: 0.1
per_device_eval_batch_size: 1
eval_strategy: steps
eval_steps: 500
# val_size: 0.1
# per_device_eval_batch_size: 1
# eval_strategy: steps
# eval_steps: 500

View File

@@ -2,11 +2,13 @@
model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct
quantization_bit: 4
quantization_method: bitsandbytes # choices: [bitsandbytes (4/8), hqq (2/3/4/5/6/8), eetq (8)]
trust_remote_code: true
### method
stage: sft
do_train: true
finetuning_type: lora
lora_rank: 8
lora_target: all
### dataset
@@ -35,7 +37,7 @@ bf16: true
ddp_timeout: 180000000
### eval
val_size: 0.1
per_device_eval_batch_size: 1
eval_strategy: steps
eval_steps: 500
# val_size: 0.1
# per_device_eval_batch_size: 1
# eval_strategy: steps
# eval_steps: 500

View File

@@ -2,6 +2,22 @@
requires = ["setuptools>=61.0"]
build-backend = "setuptools.build_meta"
[project]
name = "llamafactory"
dynamic = [
"version",
"dependencies",
"optional-dependencies",
"requires-python",
"scripts",
"authors",
"description",
"readme",
"license",
"keywords",
"classifiers"
]
[tool.ruff]
target-version = "py38"
line-length = 119
@@ -31,3 +47,19 @@ indent-style = "space"
docstring-code-format = true
skip-magic-trailing-comma = false
line-ending = "auto"
[tool.uv]
conflicts = [
[
{ extra = "torch-npu" },
{ extra = "aqlm" },
],
[
{ extra = "torch-npu" },
{ extra = "liger-kernel" },
],
[
{ extra = "torch-npu" },
{ extra = "vllm" },
]
]
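Assuming a uv-managed environment, the `[tool.uv]` conflict table tells the resolver that these extras are mutually exclusive; a hedged usage sketch:
```bash
# Pick at most one of the conflicting extras per environment (illustrative).
uv sync --extra vllm        # GPU inference environment
uv sync --extra torch-npu   # Ascend NPU environment

# Enabling conflicting extras together is rejected by the resolver:
# uv sync --extra torch-npu --extra vllm
```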

View File

@@ -1,9 +1,11 @@
transformers>=4.41.2,<=4.46.1
datasets>=2.16.0,<=3.1.0
accelerate>=0.34.0,<=1.0.1
transformers>=4.41.2,<=4.49.0,!=4.46.*,!=4.47.*,!=4.48.*;python_version<'3.10'
transformers>=4.41.2,<=4.49.0,!=4.46.*,!=4.47.*,!=4.48.0;python_version>='3.10'
datasets>=2.16.0,<=3.2.0
accelerate>=0.34.0,<=1.2.1
peft>=0.11.1,<=0.12.0
trl>=0.8.6,<=0.9.6
gradio>=4.0.0,<5.0.0
tokenizers>=0.19.0,<=0.21.0
gradio>=4.38.0,<=5.21.0
pandas>=2.0.0
scipy
einops
@@ -20,4 +22,5 @@ packaging
pyyaml
numpy<2.0.0
av
librosa
tyro<0.9.0

View File

@@ -1,4 +1,4 @@
# Copyright 2024 the LlamaFactory team.
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.

View File

@@ -1,4 +1,4 @@
# Copyright 2024 the LlamaFactory team.
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.

View File

@@ -1,4 +1,4 @@
# Copyright 2024 the LlamaFactory team.
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -19,15 +19,10 @@ from typing import Any, Dict
import fire
import torch
from huggingface_hub import split_torch_state_dict_into_shards
from safetensors.torch import save_file
from tqdm import tqdm
from transformers.modeling_utils import (
SAFE_WEIGHTS_INDEX_NAME,
SAFE_WEIGHTS_NAME,
WEIGHTS_INDEX_NAME,
WEIGHTS_NAME,
shard_checkpoint,
)
from transformers.modeling_utils import SAFE_WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_NAME, WEIGHTS_INDEX_NAME, WEIGHTS_NAME
CONFIG_NAME = "config.json"
@@ -40,34 +35,42 @@ def save_weight(input_dir: str, output_dir: str, shard_size: str, save_safetenso
shard_weight = torch.load(os.path.join(input_dir, filepath), map_location="cpu")
baichuan2_state_dict.update(shard_weight)
llama2_state_dict: Dict[str, torch.Tensor] = OrderedDict()
llama_state_dict: Dict[str, torch.Tensor] = OrderedDict()
for key, value in tqdm(baichuan2_state_dict.items(), desc="Convert format"):
if "W_pack" in key:
proj_size = value.size(0) // 3
llama2_state_dict[key.replace("W_pack", "q_proj")] = value[:proj_size, :]
llama2_state_dict[key.replace("W_pack", "k_proj")] = value[proj_size : 2 * proj_size, :]
llama2_state_dict[key.replace("W_pack", "v_proj")] = value[2 * proj_size :, :]
llama_state_dict[key.replace("W_pack", "q_proj")] = value[:proj_size, :]
llama_state_dict[key.replace("W_pack", "k_proj")] = value[proj_size : 2 * proj_size, :]
llama_state_dict[key.replace("W_pack", "v_proj")] = value[2 * proj_size :, :]
elif "lm_head" in key:
llama2_state_dict[key] = torch.nn.functional.normalize(value)
llama_state_dict[key] = torch.nn.functional.normalize(value)
else:
llama2_state_dict[key] = value
llama_state_dict[key] = value
weights_name = SAFE_WEIGHTS_NAME if save_safetensors else WEIGHTS_NAME
shards, index = shard_checkpoint(llama2_state_dict, max_shard_size=shard_size, weights_name=weights_name)
for shard_file, shard in tqdm(shards.items(), desc="Save weights"):
filename_pattern = weights_name.replace(".bin", "{suffix}.bin").replace(".safetensors", "{suffix}.safetensors")
state_dict_split = split_torch_state_dict_into_shards(
llama_state_dict, filename_pattern=filename_pattern, max_shard_size=shard_size
)
for shard_file, tensors in tqdm(state_dict_split.filename_to_tensors.items(), desc="Save weights"):
shard = {tensor: llama_state_dict[tensor].contiguous() for tensor in tensors}
if save_safetensors:
save_file(shard, os.path.join(output_dir, shard_file), metadata={"format": "pt"})
else:
torch.save(shard, os.path.join(output_dir, shard_file))
if index is None:
print(f"Model weights saved in {os.path.join(output_dir, WEIGHTS_NAME)}")
if not state_dict_split.is_sharded:
print(f"Model weights saved in {os.path.join(output_dir, weights_name)}.")
else:
index = {
"metadata": state_dict_split.metadata,
"weight_map": state_dict_split.tensor_to_filename,
}
index_name = SAFE_WEIGHTS_INDEX_NAME if save_safetensors else WEIGHTS_INDEX_NAME
with open(os.path.join(output_dir, index_name), "w", encoding="utf-8") as f:
json.dump(index, f, indent=2, sort_keys=True)
print(f"Model weights saved in {output_dir}")
print(f"Model weights saved in {output_dir}.")
def save_config(input_dir: str, output_dir: str):
@@ -81,6 +84,7 @@ def save_config(input_dir: str, output_dir: str):
with open(os.path.join(output_dir, CONFIG_NAME), "w", encoding="utf-8") as f:
json.dump(llama2_config_dict, f, indent=2)
print(f"Model config saved in {os.path.join(output_dir, CONFIG_NAME)}")

View File

@@ -1,4 +1,4 @@
# Copyright 2024 the LlamaFactory team.
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -19,16 +19,11 @@ from typing import Any, Dict
import fire
import torch
from huggingface_hub import split_torch_state_dict_into_shards
from safetensors import safe_open
from safetensors.torch import save_file
from tqdm import tqdm
from transformers.modeling_utils import (
SAFE_WEIGHTS_INDEX_NAME,
SAFE_WEIGHTS_NAME,
WEIGHTS_INDEX_NAME,
WEIGHTS_NAME,
shard_checkpoint,
)
from transformers.modeling_utils import SAFE_WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_NAME, WEIGHTS_INDEX_NAME, WEIGHTS_NAME
from transformers.utils import check_min_version
@@ -49,60 +44,68 @@ def save_weight(input_dir: str, output_dir: str, shard_size: str, save_safetenso
for key in f.keys():
qwen_state_dict[key] = f.get_tensor(key)
llama2_state_dict: Dict[str, torch.Tensor] = OrderedDict()
llama_state_dict: Dict[str, torch.Tensor] = OrderedDict()
torch_dtype = None
for key, value in tqdm(qwen_state_dict.items(), desc="Convert format"):
if torch_dtype is None:
torch_dtype = value.dtype
if "wte" in key:
llama2_state_dict["model.embed_tokens.weight"] = value
llama_state_dict["model.embed_tokens.weight"] = value
elif "ln_f" in key:
llama2_state_dict["model.norm.weight"] = value
llama_state_dict["model.norm.weight"] = value
else:
key = key.replace("transformer.h", "model.layers")
if "attn.c_attn" in key:
proj_size = value.size(0) // 3
llama2_state_dict[key.replace("attn.c_attn", "self_attn.q_proj")] = value[:proj_size, ...]
llama2_state_dict[key.replace("attn.c_attn", "self_attn.k_proj")] = value[
llama_state_dict[key.replace("attn.c_attn", "self_attn.q_proj")] = value[:proj_size, ...]
llama_state_dict[key.replace("attn.c_attn", "self_attn.k_proj")] = value[
proj_size : 2 * proj_size, ...
]
llama2_state_dict[key.replace("attn.c_attn", "self_attn.v_proj")] = value[2 * proj_size :, ...]
llama_state_dict[key.replace("attn.c_attn", "self_attn.v_proj")] = value[2 * proj_size :, ...]
elif "attn.c_proj" in key:
llama2_state_dict[key.replace("attn.c_proj", "self_attn.o_proj")] = value
llama2_state_dict[key.replace("attn.c_proj.weight", "self_attn.o_proj.bias")] = torch.zeros_like(
llama_state_dict[key.replace("attn.c_proj", "self_attn.o_proj")] = value
llama_state_dict[key.replace("attn.c_proj.weight", "self_attn.o_proj.bias")] = torch.zeros_like(
value[:, 0]
).squeeze()
elif "ln_1" in key:
llama2_state_dict[key.replace("ln_1", "input_layernorm")] = value
llama_state_dict[key.replace("ln_1", "input_layernorm")] = value
elif "ln_2" in key:
llama2_state_dict[key.replace("ln_2", "post_attention_layernorm")] = value
llama_state_dict[key.replace("ln_2", "post_attention_layernorm")] = value
elif "mlp.w1" in key:
llama2_state_dict[key.replace("mlp.w1", "mlp.up_proj")] = value
llama_state_dict[key.replace("mlp.w1", "mlp.up_proj")] = value
elif "mlp.w2" in key:
llama2_state_dict[key.replace("mlp.w2", "mlp.gate_proj")] = value
llama_state_dict[key.replace("mlp.w2", "mlp.gate_proj")] = value
elif "mlp.c_proj" in key:
llama2_state_dict[key.replace("mlp.c_proj", "mlp.down_proj")] = value
llama_state_dict[key.replace("mlp.c_proj", "mlp.down_proj")] = value
elif "lm_head" in key:
llama2_state_dict[key] = value
llama_state_dict[key] = value
else:
raise KeyError(f"Unable to process key {key}")
weights_name = SAFE_WEIGHTS_NAME if save_safetensors else WEIGHTS_NAME
shards, index = shard_checkpoint(llama2_state_dict, max_shard_size=shard_size, weights_name=weights_name)
for shard_file, shard in tqdm(shards.items(), desc="Save weights"):
filename_pattern = weights_name.replace(".bin", "{suffix}.bin").replace(".safetensors", "{suffix}.safetensors")
state_dict_split = split_torch_state_dict_into_shards(
llama_state_dict, filename_pattern=filename_pattern, max_shard_size=shard_size
)
for shard_file, tensors in tqdm(state_dict_split.filename_to_tensors.items(), desc="Save weights"):
shard = {tensor: llama_state_dict[tensor].contiguous() for tensor in tensors}
if save_safetensors:
save_file(shard, os.path.join(output_dir, shard_file), metadata={"format": "pt"})
else:
torch.save(shard, os.path.join(output_dir, shard_file))
if index is None:
print(f"Model weights saved in {os.path.join(output_dir, weights_name)}")
if not state_dict_split.is_sharded:
print(f"Model weights saved in {os.path.join(output_dir, weights_name)}.")
else:
index = {
"metadata": state_dict_split.metadata,
"weight_map": state_dict_split.tensor_to_filename,
}
index_name = SAFE_WEIGHTS_INDEX_NAME if save_safetensors else WEIGHTS_INDEX_NAME
with open(os.path.join(output_dir, index_name), "w", encoding="utf-8") as f:
json.dump(index, f, indent=2, sort_keys=True)
print(f"Model weights saved in {output_dir}")
print(f"Model weights saved in {output_dir}.")
return str(torch_dtype).replace("torch.", "")
@@ -134,6 +137,7 @@ def save_config(input_dir: str, output_dir: str, torch_dtype: str):
with open(os.path.join(output_dir, CONFIG_NAME), "w", encoding="utf-8") as f:
json.dump(llama2_config_dict, f, indent=2)
print(f"Model config saved in {os.path.join(output_dir, CONFIG_NAME)}")

View File

@@ -1,4 +1,4 @@
# Copyright 2024 Tencent Inc. and the LlamaFactory team.
# Copyright 2025 Tencent Inc. and the LlamaFactory team.
#
# This code is inspired by the Tencent's LLaMA-Pro library.
# https://github.com/TencentARC/LLaMA-Pro/blob/main/scripts/block_expansion.py
@@ -18,24 +18,19 @@
import json
import os
from collections import OrderedDict
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Dict
import fire
import torch
from huggingface_hub import split_torch_state_dict_into_shards
from safetensors.torch import save_file
from tqdm import tqdm
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
from transformers.modeling_utils import (
SAFE_WEIGHTS_INDEX_NAME,
SAFE_WEIGHTS_NAME,
WEIGHTS_INDEX_NAME,
WEIGHTS_NAME,
shard_checkpoint,
)
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer, PreTrainedModel
from transformers.modeling_utils import SAFE_WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_NAME, WEIGHTS_INDEX_NAME, WEIGHTS_NAME
if TYPE_CHECKING:
from transformers import PretrainedConfig, PreTrainedModel
from transformers import PretrainedConfig
def change_name(name: str, old_index: int, new_index: int) -> str:
@@ -46,46 +41,42 @@ def block_expansion(
model_name_or_path: str,
output_dir: str,
num_expand: int,
shard_size: str = "2GB",
shard_size: str = "5GB",
save_safetensors: bool = True,
):
r"""
Performs block expansion for LLaMA, Mistral, Qwen1.5 or Yi models.
Performs block expansion for LLaMA, Mistral, Qwen2 or Yi models.
Usage: python llama_pro.py --model_name_or_path meta-llama/Llama-2-7b-hf --output_dir llama2_pro --num_expand 8
"""
config: "PretrainedConfig" = AutoConfig.from_pretrained(model_name_or_path)
config: "PretrainedConfig" = AutoConfig.from_pretrained(model_name_or_path, trust_remote_code=True)
num_layers = getattr(config, "num_hidden_layers")
setattr(config, "num_hidden_layers", num_layers + num_expand)
config.save_pretrained(output_dir)
tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
tokenizer.save_pretrained(output_dir)
config: "PretrainedConfig" = AutoConfig.from_pretrained(model_name_or_path) # load the original one
if save_safetensors:
setattr(config, "tie_word_embeddings", False) # safetensors does not allow shared weights
model: "PreTrainedModel" = AutoModelForCausalLM.from_pretrained(
model_name_or_path,
config=config,
torch_dtype="auto",
trust_remote_code=True,
low_cpu_mem_usage=True,
)
state_dict = model.state_dict()
if num_layers % num_expand != 0:
raise ValueError(f"`num_layers` {num_layers} should be divisible by `num_expand` {num_expand}.")
setattr(config, "num_hidden_layers", num_layers + num_expand)
config.save_pretrained(output_dir)
tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, trust_remote_code=True)
tokenizer.save_pretrained(output_dir)
print(f"Expanding model of {num_layers} layers to {num_layers + num_expand} layers.")
model = AutoModelForCausalLM.from_pretrained(
model_name_or_path, torch_dtype="auto", device_map="cpu", trust_remote_code=True, low_cpu_mem_usage=True
)
assert isinstance(model, PreTrainedModel) # type hint
if save_safetensors and getattr(model.config, "tie_word_embeddings", False):
del model.lm_head # safetensors does not allow shared weights
split = num_layers // num_expand
layer_cnt = 0
output_state_dict = OrderedDict()
state_dict = model.state_dict()
output_state_dict: Dict[str, "torch.Tensor"] = OrderedDict()
for i in range(num_layers):
for key, value in state_dict.items():
if f".{i:d}." in key:
output_state_dict[change_name(key, i, layer_cnt)] = value
print(f"Add layer {layer_cnt} copied from layer {i}")
print(f"Add layer {layer_cnt} copied from layer {i}.")
layer_cnt += 1
if (i + 1) % split == 0:
for key, value in state_dict.items():
@@ -95,7 +86,7 @@ def block_expansion(
else:
output_state_dict[change_name(key, i, layer_cnt)] = torch.clone(value)
print(f"Add layer {layer_cnt} expanded from layer {i}")
print(f"Add layer {layer_cnt} expanded from layer {i}.")
layer_cnt += 1
for key, value in state_dict.items():
@@ -103,21 +94,29 @@ def block_expansion(
output_state_dict[key] = value
weights_name = SAFE_WEIGHTS_NAME if save_safetensors else WEIGHTS_NAME
shards, index = shard_checkpoint(output_state_dict, max_shard_size=shard_size, weights_name=weights_name)
for shard_file, shard in tqdm(shards.items(), desc="Save weights"):
filename_pattern = weights_name.replace(".bin", "{suffix}.bin").replace(".safetensors", "{suffix}.safetensors")
state_dict_split = split_torch_state_dict_into_shards(
output_state_dict, filename_pattern=filename_pattern, max_shard_size=shard_size
)
for shard_file, tensors in tqdm(state_dict_split.filename_to_tensors.items(), desc="Save weights"):
shard = {tensor: output_state_dict[tensor].contiguous() for tensor in tensors}
if save_safetensors:
save_file(shard, os.path.join(output_dir, shard_file), metadata={"format": "pt"})
else:
torch.save(shard, os.path.join(output_dir, shard_file))
if index is None:
print(f"Model weights saved in {os.path.join(output_dir, weights_name)}")
if not state_dict_split.is_sharded:
print(f"Model weights saved in {os.path.join(output_dir, weights_name)}.")
else:
index = {
"metadata": state_dict_split.metadata,
"weight_map": state_dict_split.tensor_to_filename,
}
index_name = SAFE_WEIGHTS_INDEX_NAME if save_safetensors else WEIGHTS_INDEX_NAME
with open(os.path.join(output_dir, index_name), "w", encoding="utf-8") as f:
json.dump(index, f, indent=2, sort_keys=True)
print(f"Model weights saved in {output_dir}")
print(f"Model weights saved in {output_dir}.")
print("- Fine-tune this model with:")
print(f"model_name_or_path: {output_dir}")


@@ -1,4 +1,4 @@
# Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
# Copyright 2025 HuggingFace Inc. and the LlamaFactory team.
#
# This code is based on the HuggingFace's PEFT library.
# https://github.com/huggingface/peft/blob/v0.10.0/examples/loftq_finetuning/quantize_save_load.py


@@ -1,4 +1,4 @@
# Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
# Copyright 2025 HuggingFace Inc. and the LlamaFactory team.
#
# This code is based on the HuggingFace's PEFT library.
# https://github.com/huggingface/peft/blob/v0.11.0/examples/pissa_finetuning/preprocess.py


@@ -1,4 +1,4 @@
# Copyright 2024 Microsoft Corporation and the LlamaFactory team.
# Copyright 2025 Microsoft Corporation and the LlamaFactory team.
#
# This code is inspired by the Microsoft's DeepSpeed library.
# https://www.deepspeed.ai/tutorials/flops-profiler/


@@ -1,4 +1,4 @@
# Copyright 2024 imoneoi and the LlamaFactory team.
# Copyright 2025 imoneoi and the LlamaFactory team.
#
# This code is inspired by the imoneoi's OpenChat library.
# https://github.com/imoneoi/openchat/blob/3.6.0/ochat/training_deepspeed/train.py
@@ -22,9 +22,9 @@ import fire
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers import DataCollatorForLanguageModeling, DataCollatorForSeq2Seq
from transformers import DataCollatorForLanguageModeling
from llamafactory.data import get_dataset, get_template_and_fix_tokenizer
from llamafactory.data import MultiModalDataCollatorForSeq2Seq, get_dataset, get_template_and_fix_tokenizer
from llamafactory.extras.constants import IGNORE_INDEX
from llamafactory.hparams import get_train_args
from llamafactory.model import load_tokenizer
@@ -41,7 +41,7 @@ def calculate_lr(
dataset: str = "alpaca_en_demo",
dataset_dir: str = "data",
template: str = "default",
cutoff_len: int = 1024, # i.e. maximum input length during training
cutoff_len: int = 2048, # i.e. maximum input length during training
is_mistral_or_gemma: bool = False, # mistral and gemma models opt for a smaller learning rate,
packing: bool = False,
):
@@ -59,6 +59,7 @@ def calculate_lr(
template=template,
cutoff_len=cutoff_len,
packing=packing,
preprocessing_num_workers=16,
output_dir="dummy_dir",
overwrite_cache=True,
do_train=True,
@@ -71,24 +72,25 @@ def calculate_lr(
if stage == "pt":
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
elif stage == "sft":
data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, label_pad_token_id=IGNORE_INDEX)
data_collator = MultiModalDataCollatorForSeq2Seq(
template=template, tokenizer=tokenizer, label_pad_token_id=IGNORE_INDEX
)
else:
raise NotImplementedError(f"Stage does not supported: {stage}.")
dataloader = DataLoader(trainset, batch_size, shuffle=False, collate_fn=data_collator, pin_memory=True)
valid_tokens, total_tokens = 0, 0
for batch in tqdm(dataloader):
for batch in tqdm(dataloader, desc="Collecting valid tokens"):
valid_tokens += torch.sum(batch["labels"] != IGNORE_INDEX).item()
total_tokens += torch.numel(batch["labels"])
batch_max_len = cutoff_len * batch_size # max tokens in a batch
valid_ratio = valid_tokens / total_tokens
batch_valid_len = batch_max_len * valid_ratio
lr = BASE_LR * math.sqrt(batch_valid_len / BASE_BS) # lr ~ sqrt(batch_size)
token_batch_size = cutoff_len * batch_size * valid_ratio
lr = BASE_LR * math.sqrt(token_batch_size / BASE_BS) # lr ~ sqrt(batch_size)
lr = lr / 6.0 if is_mistral_or_gemma else lr
print(
"Optimal learning rate is {:.2e} for valid ratio% {:.2f} and effective batch size {:.2f}".format(
lr, valid_ratio * 100, batch_valid_len
"Optimal learning rate is {:.2e} for valid ratio% {:.2f} and effective token batch size {:.2f}".format(
lr, valid_ratio * 100, token_batch_size
)
)
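For reference, a back-of-the-envelope check of the square-root scaling rule printed above. The BASE_LR and BASE_BS constants are not shown in this diff, so the values below (3e-4 at 4M tokens per step) are assumptions for illustration only.

    import math

    BASE_LR, BASE_BS = 3e-4, 4_000_000  # assumed reference constants

    def optimal_lr(cutoff_len: int, batch_size: int, valid_ratio: float, smaller_lr_model: bool = False) -> float:
        token_batch_size = cutoff_len * batch_size * valid_ratio  # effective tokens per optimizer step
        lr = BASE_LR * math.sqrt(token_batch_size / BASE_BS)  # lr ~ sqrt(batch size)
        return lr / 6.0 if smaller_lr_model else lr

    print(f"{optimal_lr(2048, 8, 0.6):.2e}")  # ≈ 1.5e-05 for roughly 9.8k valid tokens per step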



@@ -1,4 +1,4 @@
# Copyright 2024 the LlamaFactory team.
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -142,21 +142,23 @@ def calculate_mfu(
args["deepspeed"] = f"examples/deepspeed/ds_z{deepspeed_stage}_config.json"
run_exp(args)
with open(os.path.join("saves", "test_mfu", "all_results.json"), encoding="utf-8") as f:
result = json.load(f)
if dist.is_initialized():
dist.barrier()
world_size = dist.get_world_size()
else:
world_size = 1
total_batch_size = batch_size * world_size
mfu_value = (
result["train_steps_per_second"]
* compute_model_flops(model_name_or_path, total_batch_size, seq_length)
/ compute_device_flops(world_size)
)
print(f"MFU: {mfu_value * 100:.2f}%")
if int(os.getenv("LOCAL_RANK", "0")) == 0:
with open(os.path.join("saves", "test_mfu", "all_results.json"), encoding="utf-8") as f:
result = json.load(f)
total_batch_size = batch_size * world_size
mfu_value = (
result["train_steps_per_second"]
* compute_model_flops(model_name_or_path, total_batch_size, seq_length)
/ compute_device_flops(world_size)
)
print(f"MFU: {mfu_value * 100:.2f}%")
if __name__ == "__main__":
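A quick illustration of the MFU formula applied above (achieved FLOPs divided by peak device FLOPs). The helpers below are hypothetical stand-ins, not the repo's compute_model_flops/compute_device_flops; the sketch uses the standard 6·N·tokens approximation for forward plus backward FLOPs.

    def model_flops_per_step(num_params: float, batch_size: int, seq_length: int) -> float:
        return 6.0 * num_params * batch_size * seq_length  # rough forward+backward estimate

    def mfu(steps_per_second: float, num_params: float, batch_size: int,
            seq_length: int, peak_flops_per_device: float, world_size: int) -> float:
        achieved = steps_per_second * model_flops_per_step(num_params, batch_size, seq_length)
        return achieved / (peak_flops_per_device * world_size)

    # e.g. a 7B model, batch 8, seq 2048, 0.5 steps/s on 8 devices at 312 TFLOPS (BF16) each:
    print(f"{mfu(0.5, 7e9, 8, 2048, 312e12, 8) * 100:.2f}%")  # ≈ 13.78%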


@@ -1,4 +1,4 @@
# Copyright 2024 the LlamaFactory team.
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -20,16 +20,16 @@ import fire
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers import DataCollatorForLanguageModeling, DataCollatorForSeq2Seq
from transformers import DataCollatorForLanguageModeling
from llamafactory.data import get_dataset, get_template_and_fix_tokenizer
from llamafactory.data import MultiModalDataCollatorForSeq2Seq, get_dataset, get_template_and_fix_tokenizer
from llamafactory.extras.constants import IGNORE_INDEX
from llamafactory.hparams import get_train_args
from llamafactory.model import load_model, load_tokenizer
@dataclass
class PairwiseDataCollatorWithPadding(DataCollatorForSeq2Seq):
class PairwiseDataCollatorWithPadding(MultiModalDataCollatorForSeq2Seq):
r"""
Data collator for pairwise data.
"""
@@ -39,36 +39,39 @@ class PairwiseDataCollatorWithPadding(DataCollatorForSeq2Seq):
def __call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, torch.Tensor]:
r"""
Pads batched data to the longest sequence in the batch.
We generate 2 * n examples where the first n examples represent chosen examples and
the last n examples represent rejected examples.
"""
chosen_features = []
for feature in features:
prompt_len, answer_len = len(feature["prompt_ids"]), len(feature["chosen_ids"])
input_ids = feature["prompt_ids"] + feature["chosen_ids"]
attention_mask = [1] * (prompt_len + answer_len)
labels = input_ids if self.train_on_prompt else [IGNORE_INDEX] * prompt_len + feature["chosen_ids"]
chosen_features.append({"input_ids": input_ids, "attention_mask": attention_mask, "labels": labels})
chosen_features.append(
{
"input_ids": feature["chosen_input_ids"],
"attention_mask": feature["chosen_attention_mask"],
"labels": feature["chosen_input_ids"] if self.train_on_prompt else feature["chosen_labels"],
"images": feature["images"],
"videos": feature["videos"],
"audios": feature["audios"],
}
)
return super().__call__(chosen_features)
def calculate_ppl(
model_name_or_path: str,
save_name: str,
save_name: str = "ppl.json",
batch_size: int = 4,
stage: Literal["pt", "sft", "rm"] = "sft",
dataset: str = "alpaca_en_demo",
dataset_dir: str = "data",
template: str = "default",
cutoff_len: int = 1024,
cutoff_len: int = 2048,
max_samples: Optional[int] = None,
train_on_prompt: bool = False,
):
r"""
Calculates the perplexity of a pre-trained model on the given dataset.
Usage: python cal_ppl.py --model_name_or_path path_to_model --dataset alpaca_en_demo --save_name ppl.json
Usage: export CUDA_VISIBLE_DEVICES=0
python cal_ppl.py --model_name_or_path path_to_model --dataset alpaca_en_demo --save_name ppl.json
"""
model_args, data_args, training_args, finetuning_args, _ = get_train_args(
dict(
@@ -80,6 +83,7 @@ def calculate_ppl(
cutoff_len=cutoff_len,
max_samples=max_samples,
train_on_prompt=train_on_prompt,
preprocessing_num_workers=16,
output_dir="dummy_dir",
overwrite_cache=True,
do_train=True,
@@ -93,10 +97,12 @@ def calculate_ppl(
if stage == "pt":
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
elif stage == "sft":
data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, label_pad_token_id=IGNORE_INDEX)
data_collator = MultiModalDataCollatorForSeq2Seq(
template=template, tokenizer=tokenizer, label_pad_token_id=IGNORE_INDEX
)
elif stage == "rm":
data_collator = PairwiseDataCollatorWithPadding(
tokenizer=tokenizer, label_pad_token_id=IGNORE_INDEX, train_on_prompt=train_on_prompt
template=template, tokenizer=tokenizer, label_pad_token_id=IGNORE_INDEX, train_on_prompt=train_on_prompt
)
else:
raise NotImplementedError(f"Stage does not supported: {stage}.")
@@ -107,7 +113,7 @@ def calculate_ppl(
perplexities = []
batch: Dict[str, "torch.Tensor"]
with torch.no_grad():
for batch in tqdm(dataloader):
for batch in tqdm(dataloader, desc="Computing perplexities"):
batch = batch.to(model.device)
outputs = model(**batch)
shift_logits: "torch.Tensor" = outputs["logits"][..., :-1, :]
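The rest of the perplexity loop is truncated in this diff. The standalone sketch below mirrors the shifted-logits step shown above and one reasonable way to finish it: per-token cross entropy, masked by -100 (the usual IGNORE_INDEX), averaged per sample, then exponentiated. Names are illustrative, not the script's own.

    import torch
    import torch.nn.functional as F

    IGNORE_INDEX = -100

    def sequence_perplexities(logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
        shift_logits = logits[..., :-1, :]   # predict token t+1 from position t
        shift_labels = labels[..., 1:]
        losses = F.cross_entropy(
            shift_logits.transpose(1, 2), shift_labels, ignore_index=IGNORE_INDEX, reduction="none"
        )  # shape (batch, seq_len - 1)
        valid = (shift_labels != IGNORE_INDEX).float()
        mean_nll = (losses * valid).sum(dim=-1) / valid.sum(dim=-1).clamp(min=1)
        return torch.exp(mean_nll)  # one perplexity value per sample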


@@ -1,4 +1,4 @@
# Copyright 2024 the LlamaFactory team.
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -31,7 +31,8 @@ def length_cdf(
):
r"""
Calculates the distribution of the input lengths in the dataset.
Usage: python length_cdf.py --model_name_or_path path_to_model --dataset alpaca_en_demo --template default
Usage: export CUDA_VISIBLE_DEVICES=0
python length_cdf.py --model_name_or_path path_to_model --dataset alpaca_en_demo --template default
"""
model_args, data_args, training_args, _, _ = get_train_args(
dict(
@@ -41,6 +42,7 @@ def length_cdf(
dataset_dir=dataset_dir,
template=template,
cutoff_len=1_000_000,
preprocessing_num_workers=16,
output_dir="dummy_dir",
overwrite_cache=True,
do_train=True,
@@ -51,7 +53,7 @@ def length_cdf(
trainset = get_dataset(template, model_args, data_args, training_args, "sft", **tokenizer_module)["train_dataset"]
total_num = len(trainset)
length_dict = defaultdict(int)
for sample in tqdm(trainset["input_ids"]):
for sample in tqdm(trainset["input_ids"], desc="Collecting lengths"):
length_dict[len(sample) // interval * interval] += 1
length_tuples = list(length_dict.items())
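The tail of length_cdf is cut off here. The self-contained sketch below reproduces the bucketing loop shown above and one way to print the cumulative distribution from it.

    from collections import defaultdict

    def length_cdf_from_lengths(lengths, interval: int = 1000) -> None:
        length_dict = defaultdict(int)
        for n in lengths:
            length_dict[n // interval * interval] += 1  # bucket by interval, as in the loop above
        total, cumulative = len(lengths), 0
        for bucket, count in sorted(length_dict.items()):
            cumulative += count
            print(f"{cumulative / total * 100:.2f}% of samples have length < {bucket + interval}")

    length_cdf_from_lengths([120, 480, 1100, 2300, 2400], interval=1000)
    # 40.00% < 1000, 60.00% < 2000, 100.00% < 3000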

scripts/vllm_infer.py (new file, 151 lines)

@@ -0,0 +1,151 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
from typing import Optional
import fire
from transformers import Seq2SeqTrainingArguments
from llamafactory.data import get_dataset, get_template_and_fix_tokenizer
from llamafactory.extras.constants import IGNORE_INDEX
from llamafactory.extras.misc import check_version, get_device_count
from llamafactory.extras.packages import is_vllm_available
from llamafactory.hparams import get_infer_args
from llamafactory.model import load_tokenizer
if is_vllm_available():
from vllm import LLM, SamplingParams
from vllm.lora.request import LoRARequest
def vllm_infer(
model_name_or_path: str,
adapter_name_or_path: str = None,
dataset: str = "alpaca_en_demo",
dataset_dir: str = "data",
template: str = "default",
cutoff_len: int = 2048,
max_samples: Optional[int] = None,
vllm_config: str = "{}",
save_name: str = "generated_predictions.jsonl",
temperature: float = 0.95,
top_p: float = 0.7,
top_k: int = 50,
max_new_tokens: int = 1024,
repetition_penalty: float = 1.0,
skip_special_tokens: bool = True,
seed: Optional[int] = None,
pipeline_parallel_size: int = 1,
image_max_pixels: int = 768 * 768,
image_min_pixels: int = 32 * 32,
):
r"""
Performs batch generation using vLLM engine, which supports tensor parallelism.
Usage: python vllm_infer.py --model_name_or_path meta-llama/Llama-2-7b-hf --template llama --dataset alpaca_en_demo
"""
check_version("vllm>=0.4.3,<=0.7.3")
if pipeline_parallel_size > get_device_count():
raise ValueError("Pipeline parallel size should be smaller than the number of gpus.")
model_args, data_args, _, generating_args = get_infer_args(
dict(
model_name_or_path=model_name_or_path,
adapter_name_or_path=adapter_name_or_path,
dataset=dataset,
dataset_dir=dataset_dir,
template=template,
cutoff_len=cutoff_len,
max_samples=max_samples,
preprocessing_num_workers=16,
vllm_config=vllm_config,
temperature=temperature,
top_p=top_p,
top_k=top_k,
max_new_tokens=max_new_tokens,
repetition_penalty=repetition_penalty,
)
)
training_args = Seq2SeqTrainingArguments(output_dir="dummy_dir")
tokenizer_module = load_tokenizer(model_args)
tokenizer = tokenizer_module["tokenizer"]
template_obj = get_template_and_fix_tokenizer(tokenizer, data_args)
template_obj.mm_plugin.expand_mm_tokens = False # for vllm generate
dataset_module = get_dataset(template_obj, model_args, data_args, training_args, "ppo", **tokenizer_module)
inputs, prompts, labels = [], [], []
for sample in dataset_module["train_dataset"]:
if sample["images"]:
multi_modal_data = {
"image": template_obj.mm_plugin._regularize_images(
sample["images"], image_max_pixels=image_max_pixels, image_min_pixels=image_min_pixels
)
}
else:
multi_modal_data = None
inputs.append({"prompt_token_ids": sample["input_ids"], "multi_modal_data": multi_modal_data})
prompts.append(tokenizer.decode(sample["input_ids"], skip_special_tokens=skip_special_tokens))
labels.append(
tokenizer.decode(
list(filter(lambda x: x != IGNORE_INDEX, sample["labels"])), skip_special_tokens=skip_special_tokens
)
)
sampling_params = SamplingParams(
repetition_penalty=generating_args.repetition_penalty or 1.0, # repetition_penalty must > 0
temperature=generating_args.temperature,
top_p=generating_args.top_p or 1.0, # top_p must > 0
top_k=generating_args.top_k or -1, # top_k must > 0
stop_token_ids=template_obj.get_stop_token_ids(tokenizer),
max_tokens=generating_args.max_new_tokens,
skip_special_tokens=skip_special_tokens,
seed=seed,
)
if model_args.adapter_name_or_path is not None:
lora_request = LoRARequest("default", 1, model_args.adapter_name_or_path[0])
else:
lora_request = None
engine_args = {
"model": model_args.model_name_or_path,
"trust_remote_code": True,
"dtype": model_args.infer_dtype,
"max_model_len": cutoff_len + max_new_tokens,
"tensor_parallel_size": (get_device_count() // pipeline_parallel_size) or 1,
"pipeline_parallel_size": pipeline_parallel_size,
"disable_log_stats": True,
"enable_lora": model_args.adapter_name_or_path is not None,
}
if template_obj.mm_plugin.__class__.__name__ != "BasePlugin":
engine_args["limit_mm_per_prompt"] = {"image": 4, "video": 2}
if isinstance(model_args.vllm_config, dict):
engine_args.update(model_args.vllm_config)
results = LLM(**engine_args).generate(inputs, sampling_params, lora_request=lora_request)
preds = [result.outputs[0].text for result in results]
with open(save_name, "w", encoding="utf-8") as f:
for text, pred, label in zip(prompts, preds, labels):
f.write(json.dumps({"prompt": text, "predict": pred, "label": label}, ensure_ascii=False) + "\n")
print("*" * 70)
print(f"{len(prompts)} generated results have been saved at {save_name}.")
print("*" * 70)
if __name__ == "__main__":
fire.Fire(vllm_infer)
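A small example of consuming the JSONL file written by vllm_infer above; the field names ("prompt", "predict", "label") match the json.dumps call in the script, while everything else is illustrative.

    import json
    import os

    def load_predictions(path: str = "generated_predictions.jsonl"):
        with open(path, encoding="utf-8") as f:
            return [json.loads(line) for line in f if line.strip()]

    if os.path.exists("generated_predictions.jsonl"):
        records = load_predictions()
        print(f"{len(records)} samples, e.g. predict: {records[0]['predict'][:80]}")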


@@ -1,4 +1,4 @@
# Copyright 2024 the LlamaFactory team.
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -36,7 +36,7 @@ def get_requires() -> List[str]:
def get_console_scripts() -> List[str]:
console_scripts = ["llamafactory-cli = llamafactory.cli:main"]
if os.environ.get("ENABLE_SHORT_CONSOLE", "1").lower() in ["true", "1"]:
if os.getenv("ENABLE_SHORT_CONSOLE", "1").lower() in ["true", "y", "1"]:
console_scripts.append("lmf = llamafactory.cli:main")
return console_scripts
@@ -44,9 +44,9 @@ def get_console_scripts() -> List[str]:
extra_require = {
"torch": ["torch>=1.13.1"],
"torch-npu": ["torch==2.1.0", "torch-npu==2.1.0.post3", "decorator"],
"torch-npu": ["torch==2.4.0", "torch-npu==2.4.0.post2", "decorator"],
"metrics": ["nltk", "jieba", "rouge-chinese"],
"deepspeed": ["deepspeed>=0.10.0,<=0.14.4"],
"deepspeed": ["deepspeed>=0.10.0,<=0.16.4"],
"liger-kernel": ["liger-kernel"],
"bitsandbytes": ["bitsandbytes>=0.39.0"],
"hqq": ["hqq"],
@@ -54,13 +54,25 @@ extra_require = {
"gptq": ["optimum>=1.17.0", "auto-gptq>=0.5.0"],
"awq": ["autoawq"],
"aqlm": ["aqlm[gpu]>=1.1.0"],
"vllm": ["vllm>=0.4.3,<0.6.4"],
"vllm": ["vllm>=0.4.3,<=0.7.3"],
"galore": ["galore-torch"],
"apollo": ["apollo-torch"],
"badam": ["badam>=1.2.1"],
"adam-mini": ["adam-mini"],
"qwen": ["transformers_stream_generator"],
"minicpm_v": [
"soundfile",
"torchvision",
"torchaudio",
"vector_quantize_pytorch",
"vocos",
"msgpack",
"referencing",
"jsonschema_specifications",
],
"modelscope": ["modelscope"],
"openmind": ["openmind"],
"swanlab": ["swanlab"],
"dev": ["pre-commit", "ruff", "pytest"],
}
@@ -70,7 +82,7 @@ def main():
name="llamafactory",
version=get_version(),
author="hiyouga",
author_email="hiyouga" "@" "buaa.edu.cn",
author_email="hiyouga AT buaa.edu.cn",
description="Easy-to-use LLM fine-tuning framework",
long_description=open("README.md", encoding="utf-8").read(),
long_description_content_type="text/markdown",
@@ -79,7 +91,7 @@ def main():
url="https://github.com/hiyouga/LLaMA-Factory",
package_dir={"": "src"},
packages=find_packages("src"),
python_requires=">=3.8.0",
python_requires=">=3.9.0",
install_requires=get_requires(),
extras_require=extra_require,
entry_points={"console_scripts": get_console_scripts()},
@@ -91,10 +103,10 @@ def main():
"License :: OSI Approved :: Apache Software License",
"Operating System :: OS Independent",
"Programming Language :: Python :: 3",
"Programming Language :: Python :: 3.8",
"Programming Language :: Python :: 3.9",
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.11",
"Programming Language :: Python :: 3.12",
"Topic :: Scientific/Engineering :: Artificial Intelligence",
],
)


@@ -1,4 +1,4 @@
# Copyright 2024 the LlamaFactory team.
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.


@@ -1,4 +1,4 @@
# Copyright 2024 the LlamaFactory team.
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -20,17 +20,17 @@ Level:
Dependency graph:
main:
transformers>=4.41.2,<=4.46.1
datasets>=2.16.0,<=3.1.0
accelerate>=0.34.0,<=1.0.1
transformers>=4.41.2,<=4.49.0,!=4.46.*,!=4.47.*,!=4.48.0
datasets>=2.16.0,<=3.2.0
accelerate>=0.34.0,<=1.2.1
peft>=0.11.1,<=0.12.0
trl>=0.8.6,<=0.9.6
attention:
transformers>=4.42.4 (gemma+fa2)
longlora:
transformers>=4.41.2,<=4.46.1
transformers>=4.41.2,<4.48.0
packing:
transformers>=4.41.2,<=4.46.1
transformers>=4.43.0
Disable version checking: DISABLE_VERSION_CHECK=1
Enable VRAM recording: RECORD_VRAM=1
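As a rough illustration of the version windows listed above, a plain range check can be done with the packaging library. This is only a sketch: the repo uses its own check_version helper, and the real constraint additionally excludes 4.46.*, 4.47.* and 4.48.0.

    from packaging.version import Version

    def in_range(current: str, lower: str, upper: str) -> bool:
        return Version(lower) <= Version(current) <= Version(upper)

    print(in_range("4.45.2", "4.41.2", "4.49.0"))  # True
    print(in_range("4.50.0", "4.41.2", "4.49.0"))  # False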


@@ -1,4 +1,4 @@
# Copyright 2024 the LlamaFactory team.
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -21,6 +21,7 @@ from typing import Optional
from typing_extensions import Annotated
from ..chat import ChatModel
from ..extras.constants import EngineName
from ..extras.misc import torch_gc
from ..extras.packages import is_fastapi_available, is_starlette_available, is_uvicorn_available
from .chat import (
@@ -60,7 +61,7 @@ async def sweeper() -> None:
@asynccontextmanager
async def lifespan(app: "FastAPI", chat_model: "ChatModel"): # collects GPU memory
if chat_model.engine_type == "huggingface":
if chat_model.engine.name == EngineName.HF:
asyncio.create_task(sweeper())
yield
@@ -106,7 +107,7 @@ def create_app(chat_model: "ChatModel") -> "FastAPI":
if request.stream:
generate = create_stream_chat_completion_response(request, chat_model)
return EventSourceResponse(generate, media_type="text/event-stream")
return EventSourceResponse(generate, media_type="text/event-stream", sep="\n")
else:
return await create_chat_completion_response(request, chat_model)
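A minimal client-side sketch for the streaming branch above (server-sent events with "\n" separators). The /v1/chat/completions path, port, and payload shape follow the OpenAI-style convention and are assumptions for illustration, not taken from this diff.

    import json
    import requests

    def stream_chat(prompt: str, base_url: str = "http://localhost:8000") -> None:
        payload = {"model": "default", "messages": [{"role": "user", "content": prompt}], "stream": True}
        with requests.post(f"{base_url}/v1/chat/completions", json=payload, stream=True) as resp:
            for raw in resp.iter_lines(decode_unicode=True):
                if not raw or not raw.startswith("data:"):
                    continue  # skip keep-alives and separators
                data = raw[len("data:"):].strip()
                if data == "[DONE]":
                    break
                delta = json.loads(data)["choices"][0]["delta"].get("content", "")
                print(delta, end="", flush=True)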


@@ -1,4 +1,4 @@
# Copyright 2024 the LlamaFactory team.
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -22,6 +22,8 @@ from typing import TYPE_CHECKING, AsyncGenerator, Dict, List, Optional, Tuple
from ..data import Role as DataRole
from ..extras import logging
from ..extras.constants import IMAGE_PLACEHOLDER
from ..extras.misc import is_env_enabled
from ..extras.packages import is_fastapi_available, is_pillow_available, is_requests_available
from .common import dictify, jsonify
from .protocol import (
@@ -70,7 +72,8 @@ ROLE_MAPPING = {
def _process_request(
request: "ChatCompletionRequest",
) -> Tuple[List[Dict[str, str]], Optional[str], Optional[str], Optional[List["ImageInput"]]]:
logger.info_rank0(f"==== request ====\n{json.dumps(dictify(request), indent=2, ensure_ascii=False)}")
if is_env_enabled("API_VERBOSE", "1"):
logger.info_rank0(f"==== request ====\n{json.dumps(dictify(request), indent=2, ensure_ascii=False)}")
if len(request.messages) == 0:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid length")
@@ -99,10 +102,12 @@ def _process_request(
content = json.dumps(tool_calls, ensure_ascii=False)
input_messages.append({"role": ROLE_MAPPING[Role.FUNCTION], "content": content})
elif isinstance(message.content, list):
text_content = ""
for input_item in message.content:
if input_item.type == "text":
input_messages.append({"role": ROLE_MAPPING[message.role], "content": input_item.text})
text_content += input_item.text
else:
text_content += IMAGE_PLACEHOLDER
image_url = input_item.image_url.url
if re.match(r"^data:image\/(png|jpg|jpeg|gif|bmp);base64,(.+)$", image_url): # base64 image
image_stream = io.BytesIO(base64.b64decode(image_url.split(",", maxsplit=1)[1]))
@@ -112,6 +117,8 @@ def _process_request(
image_stream = requests.get(image_url, stream=True).raw
images.append(Image.open(image_stream).convert("RGB"))
input_messages.append({"role": ROLE_MAPPING[message.role], "content": text_content})
else:
input_messages.append({"role": ROLE_MAPPING[message.role], "content": message.content})
@@ -168,7 +175,7 @@ async def create_chat_completion_response(
if isinstance(result, list):
tool_calls = []
for tool in result:
function = Function(name=tool[0], arguments=tool[1])
function = Function(name=tool.name, arguments=tool.arguments)
tool_calls.append(FunctionCall(id=f"call_{uuid.uuid4().hex}", function=function))
response_message = ChatCompletionMessage(role=Role.ASSISTANT, tool_calls=tool_calls)
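A standalone sketch of the image handling added above: a data URL is base64-decoded, anything else is treated as a plain http(s) URL and fetched, and the result is converted to RGB. The function name is illustrative.

    import base64
    import io
    import re

    import requests
    from PIL import Image

    def load_image(image_url: str) -> Image.Image:
        if re.match(r"^data:image\/(png|jpg|jpeg|gif|bmp);base64,(.+)$", image_url):  # base64 image
            image_stream = io.BytesIO(base64.b64decode(image_url.split(",", maxsplit=1)[1]))
        else:  # assume a reachable http(s) URL
            image_stream = requests.get(image_url, stream=True).raw
        return Image.open(image_stream).convert("RGB")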


@@ -1,4 +1,4 @@
# Copyright 2024 the LlamaFactory team.
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.


@@ -1,4 +1,4 @@
# Copyright 2024 the LlamaFactory team.
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.


@@ -1,4 +1,4 @@
# Copyright 2024 the LlamaFactory team.
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.


@@ -1,4 +1,4 @@
# Copyright 2024 the LlamaFactory team.
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -22,7 +22,8 @@ if TYPE_CHECKING:
from vllm import AsyncLLMEngine
from ..data import Template
from ..data.mm_plugin import ImageInput, VideoInput
from ..data.mm_plugin import AudioInput, ImageInput, VideoInput
from ..extras.constants import EngineName
from ..hparams import DataArguments, FinetuningArguments, GeneratingArguments, ModelArguments
@@ -41,6 +42,7 @@ class BaseEngine(ABC):
Must implement async methods: chat(), stream_chat() and get_scores().
"""
name: "EngineName"
model: Union["PreTrainedModel", "AsyncLLMEngine"]
tokenizer: "PreTrainedTokenizer"
can_generate: bool
@@ -68,6 +70,7 @@ class BaseEngine(ABC):
tools: Optional[str] = None,
images: Optional[Sequence["ImageInput"]] = None,
videos: Optional[Sequence["VideoInput"]] = None,
audios: Optional[Sequence["AudioInput"]] = None,
**input_kwargs,
) -> List["Response"]:
r"""
@@ -83,6 +86,7 @@ class BaseEngine(ABC):
tools: Optional[str] = None,
images: Optional[Sequence["ImageInput"]] = None,
videos: Optional[Sequence["VideoInput"]] = None,
audios: Optional[Sequence["AudioInput"]] = None,
**input_kwargs,
) -> AsyncGenerator[str, None]:
r"""


@@ -20,6 +20,7 @@ import os
from threading import Thread
from typing import TYPE_CHECKING, Any, AsyncGenerator, Dict, Generator, List, Optional, Sequence
from ..extras.constants import EngineName
from ..extras.misc import torch_gc
from ..hparams import get_infer_args
from .hf_engine import HuggingfaceEngine
@@ -27,7 +28,7 @@ from .vllm_engine import VllmEngine
if TYPE_CHECKING:
from ..data.mm_plugin import ImageInput, VideoInput
from ..data.mm_plugin import AudioInput, ImageInput, VideoInput
from .base_engine import BaseEngine, Response
@@ -47,10 +48,9 @@ class ChatModel:
def __init__(self, args: Optional[Dict[str, Any]] = None) -> None:
model_args, data_args, finetuning_args, generating_args = get_infer_args(args)
self.engine_type = model_args.infer_backend
if model_args.infer_backend == "huggingface":
if model_args.infer_backend == EngineName.HF:
self.engine: "BaseEngine" = HuggingfaceEngine(model_args, data_args, finetuning_args, generating_args)
elif model_args.infer_backend == "vllm":
elif model_args.infer_backend == EngineName.VLLM:
self.engine: "BaseEngine" = VllmEngine(model_args, data_args, finetuning_args, generating_args)
else:
raise NotImplementedError(f"Unknown backend: {model_args.infer_backend}")
@@ -66,13 +66,14 @@ class ChatModel:
tools: Optional[str] = None,
images: Optional[Sequence["ImageInput"]] = None,
videos: Optional[Sequence["VideoInput"]] = None,
audios: Optional[Sequence["AudioInput"]] = None,
**input_kwargs,
) -> List["Response"]:
r"""
Gets a list of responses of the chat model.
"""
task = asyncio.run_coroutine_threadsafe(
self.achat(messages, system, tools, images, videos, **input_kwargs), self._loop
self.achat(messages, system, tools, images, videos, audios, **input_kwargs), self._loop
)
return task.result()
@@ -83,12 +84,13 @@ class ChatModel:
tools: Optional[str] = None,
images: Optional[Sequence["ImageInput"]] = None,
videos: Optional[Sequence["VideoInput"]] = None,
audios: Optional[Sequence["AudioInput"]] = None,
**input_kwargs,
) -> List["Response"]:
r"""
Asynchronously gets a list of responses of the chat model.
"""
return await self.engine.chat(messages, system, tools, images, videos, **input_kwargs)
return await self.engine.chat(messages, system, tools, images, videos, audios, **input_kwargs)
def stream_chat(
self,
@@ -97,12 +99,13 @@ class ChatModel:
tools: Optional[str] = None,
images: Optional[Sequence["ImageInput"]] = None,
videos: Optional[Sequence["VideoInput"]] = None,
audios: Optional[Sequence["AudioInput"]] = None,
**input_kwargs,
) -> Generator[str, None, None]:
r"""
Gets the response token-by-token of the chat model.
"""
generator = self.astream_chat(messages, system, tools, images, videos, **input_kwargs)
generator = self.astream_chat(messages, system, tools, images, videos, audios, **input_kwargs)
while True:
try:
task = asyncio.run_coroutine_threadsafe(generator.__anext__(), self._loop)
@@ -117,12 +120,15 @@ class ChatModel:
tools: Optional[str] = None,
images: Optional[Sequence["ImageInput"]] = None,
videos: Optional[Sequence["VideoInput"]] = None,
audios: Optional[Sequence["AudioInput"]] = None,
**input_kwargs,
) -> AsyncGenerator[str, None]:
r"""
Asynchronously gets the response token-by-token of the chat model.
"""
async for new_token in self.engine.stream_chat(messages, system, tools, images, videos, **input_kwargs):
async for new_token in self.engine.stream_chat(
messages, system, tools, images, videos, audios, **input_kwargs
):
yield new_token
def get_scores(
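The chat/stream_chat wrappers above bridge synchronous callers into an async engine via run_coroutine_threadsafe. A generic, self-contained sketch of that pattern, with illustrative names only:

    import asyncio
    from threading import Thread

    class SyncBridge:
        def __init__(self) -> None:
            self._loop = asyncio.new_event_loop()
            Thread(target=self._loop.run_forever, daemon=True).start()  # background event loop

        def run(self, coro):
            # submit the coroutine to the background loop and block for its result
            return asyncio.run_coroutine_threadsafe(coro, self._loop).result()

    async def agreet(name: str) -> str:
        await asyncio.sleep(0.01)
        return f"hello {name}"

    bridge = SyncBridge()
    print(bridge.run(agreet("world")))  # blocking call into async code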


@@ -1,4 +1,4 @@
# Copyright 2024 the LlamaFactory team.
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -24,7 +24,7 @@ from typing_extensions import override
from ..data import get_template_and_fix_tokenizer
from ..extras import logging
from ..extras.constants import IMAGE_PLACEHOLDER, VIDEO_PLACEHOLDER
from ..extras.constants import AUDIO_PLACEHOLDER, IMAGE_PLACEHOLDER, VIDEO_PLACEHOLDER, EngineName
from ..extras.misc import get_logits_processor
from ..model import load_model, load_tokenizer
from .base_engine import BaseEngine, Response
@@ -35,7 +35,7 @@ if TYPE_CHECKING:
from trl import PreTrainedModelWrapper
from ..data import Template
from ..data.mm_plugin import ImageInput, VideoInput
from ..data.mm_plugin import AudioInput, ImageInput, VideoInput
from ..hparams import DataArguments, FinetuningArguments, GeneratingArguments, ModelArguments
@@ -50,6 +50,7 @@ class HuggingfaceEngine(BaseEngine):
finetuning_args: "FinetuningArguments",
generating_args: "GeneratingArguments",
) -> None:
self.name = EngineName.HF
self.can_generate = finetuning_args.stage == "sft"
tokenizer_module = load_tokenizer(model_args)
self.tokenizer = tokenizer_module["tokenizer"]
@@ -63,7 +64,7 @@ class HuggingfaceEngine(BaseEngine):
try:
asyncio.get_event_loop()
except RuntimeError:
logger.warning_once("There is no current event loop, creating a new one.")
logger.warning_rank0_once("There is no current event loop, creating a new one.")
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
@@ -81,9 +82,10 @@ class HuggingfaceEngine(BaseEngine):
tools: Optional[str] = None,
images: Optional[Sequence["ImageInput"]] = None,
videos: Optional[Sequence["VideoInput"]] = None,
audios: Optional[Sequence["AudioInput"]] = None,
input_kwargs: Optional[Dict[str, Any]] = {},
) -> Tuple[Dict[str, Any], int]:
mm_input_dict = {"images": [], "videos": [], "imglens": [0], "vidlens": [0]}
mm_input_dict = {"images": [], "videos": [], "audios": [], "imglens": [0], "vidlens": [0], "audlens": [0]}
if images is not None:
mm_input_dict.update({"images": images, "imglens": [len(images)]})
if not any(IMAGE_PLACEHOLDER in message["content"] for message in messages):
@@ -94,14 +96,25 @@ class HuggingfaceEngine(BaseEngine):
if not any(VIDEO_PLACEHOLDER in message["content"] for message in messages):
messages[0]["content"] = VIDEO_PLACEHOLDER * len(videos) + messages[0]["content"]
if audios is not None:
mm_input_dict.update({"audios": audios, "audlens": [len(audios)]})
if not any(AUDIO_PLACEHOLDER in message["content"] for message in messages):
messages[0]["content"] = AUDIO_PLACEHOLDER * len(audios) + messages[0]["content"]
messages = template.mm_plugin.process_messages(
messages, mm_input_dict["images"], mm_input_dict["videos"], processor
messages, mm_input_dict["images"], mm_input_dict["videos"], mm_input_dict["audios"], processor
)
paired_messages = messages + [{"role": "assistant", "content": ""}]
system = system or generating_args["default_system"]
prompt_ids, _ = template.encode_oneturn(tokenizer, paired_messages, system, tools)
prompt_ids, _ = template.mm_plugin.process_token_ids(
prompt_ids, None, mm_input_dict["images"], mm_input_dict["videos"], tokenizer, processor
prompt_ids,
None,
mm_input_dict["images"],
mm_input_dict["videos"],
mm_input_dict["audios"],
tokenizer,
processor,
)
prompt_length = len(prompt_ids)
inputs = torch.tensor([prompt_ids], device=model.device)
@@ -114,6 +127,7 @@ class HuggingfaceEngine(BaseEngine):
num_return_sequences: int = input_kwargs.pop("num_return_sequences", 1)
repetition_penalty: Optional[float] = input_kwargs.pop("repetition_penalty", None)
length_penalty: Optional[float] = input_kwargs.pop("length_penalty", None)
skip_special_tokens: Optional[bool] = input_kwargs.pop("skip_special_tokens", None)
max_length: Optional[int] = input_kwargs.pop("max_length", None)
max_new_tokens: Optional[int] = input_kwargs.pop("max_new_tokens", None)
stop: Optional[Union[str, List[str]]] = input_kwargs.pop("stop", None)
@@ -133,7 +147,10 @@ class HuggingfaceEngine(BaseEngine):
if repetition_penalty is not None
else generating_args["repetition_penalty"],
length_penalty=length_penalty if length_penalty is not None else generating_args["length_penalty"],
eos_token_id=[tokenizer.eos_token_id] + tokenizer.additional_special_tokens_ids,
skip_special_tokens=skip_special_tokens
if skip_special_tokens is not None
else generating_args["skip_special_tokens"],
eos_token_id=template.get_stop_token_ids(tokenizer),
pad_token_id=tokenizer.pad_token_id,
)
)
@@ -166,12 +183,30 @@ class HuggingfaceEngine(BaseEngine):
mm_inputs = template.mm_plugin.get_mm_inputs(**mm_input_dict, batch_ids=[prompt_ids], processor=processor)
for key, value in mm_inputs.items():
if isinstance(value, list) and all(isinstance(v, torch.Tensor) for v in value): # for pixtral inputs
if isinstance(value, list) and isinstance(value[0], torch.Tensor): # for pixtral inputs
value = torch.stack(value) # assume they have same sizes
elif (
isinstance(value, list) and isinstance(value[0], list) and isinstance(value[0][0], torch.Tensor)
): # for minicpmv inputs
value = torch.stack([torch.stack(v) for v in value])
elif not isinstance(value, torch.Tensor):
value = torch.tensor(value)
gen_kwargs[key] = value.to(model.device)
if torch.is_floating_point(value): # cast data dtype for paligemma
value = value.to(model.dtype)
if key == "second_per_grid_ts": # qwen2.5vl special case
gen_kwargs[key] = value.tolist()
else:
gen_kwargs[key] = value.to(model.device)
if getattr(model.config, "model_type", None) in ["minicpmv", "minicpmo"]:
gen_kwargs["input_ids"] = inputs
gen_kwargs["tokenizer"] = tokenizer
if "audio_feature_lens" in mm_inputs:
gen_kwargs["audio_feature_lens"] = mm_inputs["audio_feature_lens"]
gen_kwargs.pop("image_sizes", None)
return gen_kwargs, prompt_length
@@ -188,6 +223,7 @@ class HuggingfaceEngine(BaseEngine):
tools: Optional[str] = None,
images: Optional[Sequence["ImageInput"]] = None,
videos: Optional[Sequence["VideoInput"]] = None,
audios: Optional[Sequence["AudioInput"]] = None,
input_kwargs: Optional[Dict[str, Any]] = {},
) -> List["Response"]:
gen_kwargs, prompt_length = HuggingfaceEngine._process_args(
@@ -201,11 +237,19 @@ class HuggingfaceEngine(BaseEngine):
tools,
images,
videos,
audios,
input_kwargs,
)
generate_output = model.generate(**gen_kwargs)
if isinstance(generate_output, tuple):
generate_output = generate_output[1][0] # post-process the minicpm_o output
response_ids = generate_output[:, prompt_length:]
response = tokenizer.batch_decode(response_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True)
response = tokenizer.batch_decode(
response_ids,
skip_special_tokens=getattr(gen_kwargs["generation_config"], "skip_special_tokens", True),
clean_up_tokenization_spaces=True,
)
results = []
for i in range(len(response)):
eos_index = (response_ids[i] == tokenizer.eos_token_id).nonzero()
@@ -234,6 +278,7 @@ class HuggingfaceEngine(BaseEngine):
tools: Optional[str] = None,
images: Optional[Sequence["ImageInput"]] = None,
videos: Optional[Sequence["VideoInput"]] = None,
audios: Optional[Sequence["AudioInput"]] = None,
input_kwargs: Optional[Dict[str, Any]] = {},
) -> Callable[[], str]:
gen_kwargs, _ = HuggingfaceEngine._process_args(
@@ -247,9 +292,14 @@ class HuggingfaceEngine(BaseEngine):
tools,
images,
videos,
audios,
input_kwargs,
)
streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
streamer = TextIteratorStreamer(
tokenizer,
skip_prompt=True,
skip_special_tokens=getattr(gen_kwargs["generation_config"], "skip_special_tokens", True),
)
gen_kwargs["streamer"] = streamer
thread = Thread(target=model.generate, kwargs=gen_kwargs, daemon=True)
thread.start()
@@ -292,6 +342,7 @@ class HuggingfaceEngine(BaseEngine):
tools: Optional[str] = None,
images: Optional[Sequence["ImageInput"]] = None,
videos: Optional[Sequence["VideoInput"]] = None,
audios: Optional[Sequence["AudioInput"]] = None,
**input_kwargs,
) -> List["Response"]:
if not self.can_generate:
@@ -309,6 +360,7 @@ class HuggingfaceEngine(BaseEngine):
tools,
images,
videos,
audios,
input_kwargs,
)
async with self.semaphore:
@@ -323,6 +375,7 @@ class HuggingfaceEngine(BaseEngine):
tools: Optional[str] = None,
images: Optional[Sequence["ImageInput"]] = None,
videos: Optional[Sequence["VideoInput"]] = None,
audios: Optional[Sequence["AudioInput"]] = None,
**input_kwargs,
) -> AsyncGenerator[str, None]:
if not self.can_generate:
@@ -340,6 +393,7 @@ class HuggingfaceEngine(BaseEngine):
tools,
images,
videos,
audios,
input_kwargs,
)
async with self.semaphore:
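A compact sketch of the multimodal tensor normalization added in _process_args above: lists of tensors are stacked, nested lists are stacked twice, everything else becomes a tensor, and floating-point values follow the model dtype before moving to the device. The qwen2.5vl tolist() special case is omitted here for brevity.

    import torch

    def normalize_mm_value(value, device: torch.device, dtype: torch.dtype) -> torch.Tensor:
        if isinstance(value, list) and isinstance(value[0], torch.Tensor):  # e.g. pixtral inputs
            value = torch.stack(value)  # assumes equal sizes
        elif isinstance(value, list) and isinstance(value[0], list) and isinstance(value[0][0], torch.Tensor):
            value = torch.stack([torch.stack(v) for v in value])  # e.g. minicpmv inputs
        elif not isinstance(value, torch.Tensor):
            value = torch.tensor(value)
        if torch.is_floating_point(value):  # cast data dtype, e.g. for paligemma
            value = value.to(dtype)
        return value.to(device)

    print(normalize_mm_value([torch.ones(2, 2), torch.zeros(2, 2)], torch.device("cpu"), torch.float16).shape)
    # torch.Size([2, 2, 2])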


@@ -1,4 +1,4 @@
# Copyright 2024 the LlamaFactory team.
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -19,27 +19,22 @@ from typing_extensions import override
from ..data import get_template_and_fix_tokenizer
from ..extras import logging
from ..extras.constants import IMAGE_PLACEHOLDER
from ..extras.constants import AUDIO_PLACEHOLDER, IMAGE_PLACEHOLDER, VIDEO_PLACEHOLDER, EngineName
from ..extras.misc import get_device_count
from ..extras.packages import is_pillow_available, is_vllm_available
from ..extras.packages import is_vllm_available
from ..model import load_config, load_tokenizer
from ..model.model_utils.quantization import QuantizationMethod
from ..model.model_utils.visual import LlavaMultiModalProjectorForYiVLForVLLM
from .base_engine import BaseEngine, Response
if is_pillow_available():
from PIL import Image
from PIL.Image import Image as ImageObject
if is_vllm_available():
from vllm import AsyncEngineArgs, AsyncLLMEngine, RequestOutput, SamplingParams
from vllm.lora.request import LoRARequest
if TYPE_CHECKING:
from ..data.mm_plugin import ImageInput, VideoInput
from ..data.mm_plugin import AudioInput, ImageInput, VideoInput
from ..hparams import DataArguments, FinetuningArguments, GeneratingArguments, ModelArguments
@@ -54,6 +49,8 @@ class VllmEngine(BaseEngine):
finetuning_args: "FinetuningArguments",
generating_args: "GeneratingArguments",
) -> None:
self.name = EngineName.VLLM
self.model_args = model_args
config = load_config(model_args) # may download model from ms hub
if getattr(config, "quantization_config", None): # gptq models should use float16
quantization_config: Dict[str, Any] = getattr(config, "quantization_config", None)
@@ -67,11 +64,12 @@ class VllmEngine(BaseEngine):
self.processor = tokenizer_module["processor"]
self.tokenizer.padding_side = "left"
self.template = get_template_and_fix_tokenizer(self.tokenizer, data_args)
self.template.mm_plugin.expand_mm_tokens = False # for vllm generate
self.generating_args = generating_args.to_dict()
engine_args = {
"model": model_args.model_name_or_path,
"trust_remote_code": True,
"trust_remote_code": model_args.trust_remote_code,
"download_dir": model_args.cache_dir,
"dtype": model_args.infer_dtype,
"max_model_len": model_args.vllm_maxlen,
@@ -83,6 +81,9 @@ class VllmEngine(BaseEngine):
"enable_lora": model_args.adapter_name_or_path is not None,
"max_lora_rank": model_args.vllm_max_lora_rank,
}
if self.template.mm_plugin.__class__.__name__ != "BasePlugin":
engine_args["limit_mm_per_prompt"] = {"image": 4, "video": 2}
if isinstance(model_args.vllm_config, dict):
engine_args.update(model_args.vllm_config)
@@ -105,22 +106,30 @@ class VllmEngine(BaseEngine):
tools: Optional[str] = None,
images: Optional[Sequence["ImageInput"]] = None,
videos: Optional[Sequence["VideoInput"]] = None,
audios: Optional[Sequence["AudioInput"]] = None,
**input_kwargs,
) -> AsyncIterator["RequestOutput"]:
request_id = f"chatcmpl-{uuid.uuid4().hex}"
mm_input_dict = {"images": [], "videos": [], "audios": [], "imglens": [0], "vidlens": [0], "audlens": [0]}
if images is not None:
mm_input_dict.update({"images": images, "imglens": [len(images)]})
if not any(IMAGE_PLACEHOLDER in message["content"] for message in messages):
messages[0]["content"] = IMAGE_PLACEHOLDER * len(images) + messages[0]["content"]
if self.template.mm_plugin.__class__.__name__ == "Qwen2vlPlugin": # temporary solution
image_str = f"<|vision_start|>{self.template.mm_plugin.image_token}<|vision_end|>"
else:
image_str = self.template.mm_plugin.image_token or ""
if videos is not None:
mm_input_dict.update({"videos": videos, "vidlens": [len(videos)]})
if not any(VIDEO_PLACEHOLDER in message["content"] for message in messages):
messages[0]["content"] = VIDEO_PLACEHOLDER * len(videos) + messages[0]["content"]
paired_messages = [
{"role": message["role"], "content": message["content"].replace(IMAGE_PLACEHOLDER, image_str)}
for message in messages
] + [{"role": "assistant", "content": ""}]
if audios is not None:
mm_input_dict.update({"audios": audios, "audlens": [len(audios)]})
if not any(AUDIO_PLACEHOLDER in message["content"] for message in messages):
messages[0]["content"] = AUDIO_PLACEHOLDER * len(audios) + messages[0]["content"]
messages = self.template.mm_plugin.process_messages(
messages, mm_input_dict["images"], mm_input_dict["videos"], mm_input_dict["audios"], self.processor
)
paired_messages = messages + [{"role": "assistant", "content": ""}]
system = system or self.generating_args["default_system"]
prompt_ids, _ = self.template.encode_oneturn(self.tokenizer, paired_messages, system, tools)
prompt_length = len(prompt_ids)
@@ -131,6 +140,7 @@ class VllmEngine(BaseEngine):
num_return_sequences: int = input_kwargs.pop("num_return_sequences", 1)
repetition_penalty: Optional[float] = input_kwargs.pop("repetition_penalty", None)
length_penalty: Optional[float] = input_kwargs.pop("length_penalty", None)
skip_special_tokens: Optional[bool] = input_kwargs.pop("skip_special_tokens", None)
max_length: Optional[int] = input_kwargs.pop("max_length", None)
max_new_tokens: Optional[int] = input_kwargs.pop("max_new_tokens", None)
stop: Optional[Union[str, List[str]]] = input_kwargs.pop("stop", None)
@@ -160,25 +170,23 @@ class VllmEngine(BaseEngine):
or 1.0, # repetition_penalty must > 0
temperature=temperature if temperature is not None else self.generating_args["temperature"],
top_p=(top_p if top_p is not None else self.generating_args["top_p"]) or 1.0, # top_p must > 0
top_k=top_k if top_k is not None else self.generating_args["top_k"],
top_k=(top_k if top_k is not None else self.generating_args["top_k"]) or -1, # top_k must > 0
stop=stop,
stop_token_ids=[self.tokenizer.eos_token_id] + self.tokenizer.additional_special_tokens_ids,
stop_token_ids=self.template.get_stop_token_ids(self.tokenizer),
max_tokens=max_tokens,
skip_special_tokens=True,
skip_special_tokens=skip_special_tokens
if skip_special_tokens is not None
else self.generating_args["skip_special_tokens"],
)
if images is not None: # add image features
image_data = []
for image in images:
if not isinstance(image, (str, ImageObject)):
raise ValueError(f"Expected image input is a path or PIL.Image, but got {type(image)}.")
if isinstance(image, str):
image = Image.open(image).convert("RGB")
image_data.append(image)
multi_modal_data = {"image": image_data}
multi_modal_data = {
"image": self.template.mm_plugin._regularize_images(
images,
image_max_pixels=self.model_args.image_max_pixels,
image_min_pixels=self.model_args.image_min_pixels,
)
}
else:
multi_modal_data = None
@@ -198,10 +206,11 @@ class VllmEngine(BaseEngine):
tools: Optional[str] = None,
images: Optional[Sequence["ImageInput"]] = None,
videos: Optional[Sequence["VideoInput"]] = None,
audios: Optional[Sequence["AudioInput"]] = None,
**input_kwargs,
) -> List["Response"]:
final_output = None
generator = await self._generate(messages, system, tools, images, videos, **input_kwargs)
generator = await self._generate(messages, system, tools, images, videos, audios, **input_kwargs)
async for request_output in generator:
final_output = request_output
@@ -226,10 +235,11 @@ class VllmEngine(BaseEngine):
tools: Optional[str] = None,
images: Optional[Sequence["ImageInput"]] = None,
videos: Optional[Sequence["VideoInput"]] = None,
audios: Optional[Sequence["AudioInput"]] = None,
**input_kwargs,
) -> AsyncGenerator[str, None]:
generated_text = ""
generator = await self._generate(messages, system, tools, images, videos, **input_kwargs)
generator = await self._generate(messages, system, tools, images, videos, audios, **input_kwargs)
async for result in generator:
delta_text = result.outputs[0].text[len(generated_text) :]
generated_text = result.outputs[0].text
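The streaming loop above relies on each RequestOutput carrying the full text generated so far, so only the new suffix is yielded. A tiny standalone illustration of that delta extraction:

    def iter_deltas(full_texts):
        generated_text = ""
        for text in full_texts:
            delta = text[len(generated_text):]  # only the newly generated suffix
            generated_text = text
            yield delta

    print(list(iter_deltas(["He", "Hello", "Hello, world"])))  # ['He', 'llo', ', world']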


@@ -1,4 +1,4 @@
# Copyright 2024 the LlamaFactory team.
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -24,7 +24,7 @@ from .chat.chat_model import run_chat
from .eval.evaluator import run_eval
from .extras import logging
from .extras.env import VERSION, print_env
from .extras.misc import get_device_count
from .extras.misc import get_device_count, is_env_enabled, use_ray
from .train.tuner import export_model, run_exp
from .webui.interface import run_web_demo, run_web_ui
@@ -86,20 +86,26 @@ def main():
elif command == Command.EXPORT:
export_model()
elif command == Command.TRAIN:
force_torchrun = os.getenv("FORCE_TORCHRUN", "0").lower() in ["true", "1"]
if force_torchrun or get_device_count() > 1:
force_torchrun = is_env_enabled("FORCE_TORCHRUN")
if force_torchrun or (get_device_count() > 1 and not use_ray()):
nnodes = os.getenv("NNODES", "1")
node_rank = os.getenv("NODE_RANK", "0")
nproc_per_node = os.getenv("NPROC_PER_NODE", str(get_device_count()))
master_addr = os.getenv("MASTER_ADDR", "127.0.0.1")
master_port = os.getenv("MASTER_PORT", str(random.randint(20001, 29999)))
logger.info_rank0(f"Initializing distributed tasks at: {master_addr}:{master_port}")
logger.info_rank0(f"Initializing {nproc_per_node} distributed tasks at: {master_addr}:{master_port}")
if int(nnodes) > 1:
print(f"Multi-node training enabled: num nodes: {nnodes}, node rank: {node_rank}")
process = subprocess.run(
(
"torchrun --nnodes {nnodes} --node_rank {node_rank} --nproc_per_node {nproc_per_node} "
"--master_addr {master_addr} --master_port {master_port} {file_name} {args}"
)
.format(
nnodes=os.getenv("NNODES", "1"),
node_rank=os.getenv("NODE_RANK", "0"),
nproc_per_node=os.getenv("NPROC_PER_NODE", str(get_device_count())),
nnodes=nnodes,
node_rank=node_rank,
nproc_per_node=nproc_per_node,
master_addr=master_addr,
master_port=master_port,
file_name=launcher.__file__,
@@ -119,4 +125,8 @@ def main():
elif command == Command.HELP:
print(USAGE)
else:
raise NotImplementedError(f"Unknown command: {command}.")
print(f"Unknown command: {command}.\n{USAGE}")
if __name__ == "__main__":
main()
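For clarity, the torchrun command assembled by the TRAIN branch above can be isolated as below; environment-variable defaults match the diff, while the file name and argument string are illustrative.

    import os
    import random

    def build_torchrun_command(file_name: str, args: str, device_count: int) -> str:
        nnodes = os.getenv("NNODES", "1")
        node_rank = os.getenv("NODE_RANK", "0")
        nproc_per_node = os.getenv("NPROC_PER_NODE", str(device_count))
        master_addr = os.getenv("MASTER_ADDR", "127.0.0.1")
        master_port = os.getenv("MASTER_PORT", str(random.randint(20001, 29999)))
        return (
            f"torchrun --nnodes {nnodes} --node_rank {node_rank} --nproc_per_node {nproc_per_node} "
            f"--master_addr {master_addr} --master_port {master_port} {file_name} {args}"
        )

    print(build_torchrun_command("launcher.py", "train_config.yaml", device_count=8))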


@@ -1,4 +1,4 @@
# Copyright 2024 the LlamaFactory team.
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.


@@ -1,264 +0,0 @@
# Copyright 2024 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from functools import partial
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Union
from ..extras import logging
from .data_utils import Role
if TYPE_CHECKING:
from datasets import Dataset, IterableDataset
from transformers import Seq2SeqTrainingArguments
from ..hparams import DataArguments
from .mm_plugin import ImageInput, VideoInput
from .parser import DatasetAttr
logger = logging.get_logger(__name__)
def _convert_images(
images: Union["ImageInput", Sequence["ImageInput"]],
dataset_attr: "DatasetAttr",
data_args: "DataArguments",
) -> Optional[List["ImageInput"]]:
r"""
Optionally concatenates image path to dataset dir when loading from local disk.
"""
if not isinstance(images, list):
images = [images]
elif len(images) == 0:
return None
else:
images = images[:]
if dataset_attr.load_from in ["script", "file"]:
for i in range(len(images)):
if isinstance(images[i], str) and os.path.isfile(os.path.join(data_args.image_dir, images[i])):
images[i] = os.path.join(data_args.image_dir, images[i])
return images
def _convert_videos(
videos: Union["VideoInput", Sequence["VideoInput"]],
dataset_attr: "DatasetAttr",
data_args: "DataArguments",
) -> Optional[List["VideoInput"]]:
r"""
Optionally concatenates video path to dataset dir when loading from local disk.
"""
if not isinstance(videos, list):
videos = [videos]
elif len(videos) == 0:
return None
else:
videos = videos[:]
if dataset_attr.load_from in ["script", "file"]:
for i in range(len(videos)):
if isinstance(videos[i], str) and os.path.isfile(os.path.join(data_args.image_dir, videos[i])):
videos[i] = os.path.join(data_args.image_dir, videos[i])
return videos
def convert_alpaca(
example: Dict[str, Any],
dataset_attr: "DatasetAttr",
data_args: "DataArguments",
) -> Dict[str, Any]:
r"""
Converts alpaca format dataset to the standard format.
"""
prompt = []
if dataset_attr.history and isinstance(example[dataset_attr.history], list):
for old_prompt, old_response in example[dataset_attr.history]:
prompt.append({"role": Role.USER.value, "content": old_prompt})
prompt.append({"role": Role.ASSISTANT.value, "content": old_response})
query = []
if dataset_attr.prompt and example[dataset_attr.prompt]:
query.append(example[dataset_attr.prompt])
if dataset_attr.query and example[dataset_attr.query]:
query.append(example[dataset_attr.query])
prompt.append({"role": Role.USER.value, "content": "\n".join(query)}) # "prompt\nquery"
if dataset_attr.kto_tag and isinstance(example[dataset_attr.kto_tag], bool): # kto example
response = [{"role": Role.ASSISTANT.value, "content": example[dataset_attr.response]}]
if example[dataset_attr.kto_tag]:
response = response + [{"role": Role.ASSISTANT.value, "content": ""}]
else:
response = [{"role": Role.ASSISTANT.value, "content": ""}] + response
elif (
dataset_attr.ranking
and isinstance(example[dataset_attr.chosen], str)
and isinstance(example[dataset_attr.rejected], str)
): # pairwise example
response = [
{"role": Role.ASSISTANT.value, "content": example[dataset_attr.chosen]},
{"role": Role.ASSISTANT.value, "content": example[dataset_attr.rejected]},
]
elif dataset_attr.response and isinstance(example[dataset_attr.response], str): # normal example
response = [{"role": Role.ASSISTANT.value, "content": example[dataset_attr.response]}]
else: # unsupervised
response = []
convert_images = partial(_convert_images, dataset_attr=dataset_attr, data_args=data_args)
convert_videos = partial(_convert_videos, dataset_attr=dataset_attr, data_args=data_args)
output = {
"_prompt": prompt,
"_response": response,
"_system": example[dataset_attr.system] if dataset_attr.system else "",
"_tools": example[dataset_attr.tools] if dataset_attr.tools else "",
"_images": convert_images(example[dataset_attr.images]) if dataset_attr.images else None,
"_videos": convert_videos(example[dataset_attr.videos]) if dataset_attr.videos else None,
}
return output
def convert_sharegpt(
example: Dict[str, Any],
dataset_attr: "DatasetAttr",
data_args: "DataArguments",
) -> Dict[str, Any]:
r"""
Converts sharegpt format dataset to the standard format.
"""
tag_mapping = {
dataset_attr.user_tag: Role.USER.value,
dataset_attr.assistant_tag: Role.ASSISTANT.value,
dataset_attr.observation_tag: Role.OBSERVATION.value,
dataset_attr.function_tag: Role.FUNCTION.value,
dataset_attr.system_tag: Role.SYSTEM.value,
}
odd_tags = (dataset_attr.user_tag, dataset_attr.observation_tag)
even_tags = (dataset_attr.assistant_tag, dataset_attr.function_tag)
accept_tags = (odd_tags, even_tags)
messages = example[dataset_attr.messages]
if (
dataset_attr.system_tag
and len(messages) != 0
and messages[0][dataset_attr.role_tag] == dataset_attr.system_tag
):
system = messages[0][dataset_attr.content_tag]
messages = messages[1:]
else:
system = example[dataset_attr.system] if dataset_attr.system else ""
aligned_messages = []
broken_data = False
for turn_idx, message in enumerate(messages):
if message[dataset_attr.role_tag] not in accept_tags[turn_idx % 2]:
logger.warning_rank0(f"Invalid role tag in {messages}.")
broken_data = True
aligned_messages.append(
{"role": tag_mapping[message[dataset_attr.role_tag]], "content": message[dataset_attr.content_tag]}
)
if (not dataset_attr.ranking and len(aligned_messages) % 2 != 0) or (
dataset_attr.ranking and len(aligned_messages) % 2 == 0
):
logger.warning_rank0(f"Invalid message count in {messages}.")
broken_data = True
if dataset_attr.kto_tag and isinstance(example[dataset_attr.kto_tag], bool): # kto example
prompt = aligned_messages[:-1]
response = aligned_messages[-1:]
if example[dataset_attr.kto_tag]:
response = response + [{"role": Role.ASSISTANT.value, "content": ""}]
else:
response = [{"role": Role.ASSISTANT.value, "content": ""}] + response
elif (
dataset_attr.ranking
and isinstance(example[dataset_attr.chosen], dict)
and isinstance(example[dataset_attr.rejected], dict)
): # pairwise example
chosen = example[dataset_attr.chosen]
rejected = example[dataset_attr.rejected]
if (
chosen[dataset_attr.role_tag] not in accept_tags[-1]
or rejected[dataset_attr.role_tag] not in accept_tags[-1]
):
logger.warning_rank0(f"Invalid role tag in {[chosen, rejected]}.")
broken_data = True
prompt = aligned_messages
response = [
{"role": tag_mapping[chosen[dataset_attr.role_tag]], "content": chosen[dataset_attr.content_tag]},
{"role": tag_mapping[rejected[dataset_attr.role_tag]], "content": rejected[dataset_attr.content_tag]},
]
else: # normal example
prompt = aligned_messages[:-1]
response = aligned_messages[-1:]
if broken_data:
logger.warning_rank0("Skipping this abnormal example.")
prompt, response = [], []
convert_images = partial(_convert_images, dataset_attr=dataset_attr, data_args=data_args)
convert_videos = partial(_convert_videos, dataset_attr=dataset_attr, data_args=data_args)
output = {
"_prompt": prompt,
"_response": response,
"_system": system,
"_tools": example[dataset_attr.tools] if dataset_attr.tools else "",
"_images": convert_images(example[dataset_attr.images]) if dataset_attr.images else None,
"_videos": convert_videos(example[dataset_attr.videos]) if dataset_attr.videos else None,
}
return output
def align_dataset(
dataset: Union["Dataset", "IterableDataset"],
dataset_attr: "DatasetAttr",
data_args: "DataArguments",
training_args: "Seq2SeqTrainingArguments",
) -> Union["Dataset", "IterableDataset"]:
r"""
Aligned dataset:
_prompt: [{"role": "user", "content": "..."}] * (2T - 1)
_response: [{"role": "assistant", "content": "..."}] * N (N > 1 for ranking dataset)
_system: "..."
_tools: "...",
_images: [],
_videos: [],
"""
if dataset_attr.formatting == "alpaca":
convert_func = partial(convert_alpaca, dataset_attr=dataset_attr, data_args=data_args)
else:
convert_func = partial(convert_sharegpt, dataset_attr=dataset_attr, data_args=data_args)
column_names = list(next(iter(dataset)).keys())
kwargs = {}
if not data_args.streaming:
kwargs = dict(
num_proc=data_args.preprocessing_num_workers,
load_from_cache_file=(not data_args.overwrite_cache) or (training_args.local_process_index != 0),
desc="Converting format of dataset",
)
return dataset.map(
convert_func,
batched=False,
remove_columns=column_names,
**kwargs,
)

View File

@@ -18,9 +18,18 @@
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Dict, Literal, Optional, Sequence
import numpy as np
import torch
import torch.nn.functional as F
from transformers import DataCollatorForSeq2Seq
from ..extras.constants import AUDIO_PLACEHOLDER, IGNORE_INDEX, IMAGE_PLACEHOLDER
from ..extras.packages import is_pillow_available
if is_pillow_available():
from PIL import Image
if TYPE_CHECKING:
from transformers import ProcessorMixin
@@ -72,25 +81,85 @@ class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
r"""
Data collator that supports VLMs.
-Features should contain input_ids, attention_mask, labels and images.
+Features should contain input_ids, attention_mask, labels, and optionally contain images, videos and audios.
"""
template: Optional["Template"] = None
processor: Optional["ProcessorMixin"] = None
def __post_init__(self):
if self.template is None:
raise ValueError("Template is required for MultiModalDataCollator.")
def __call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, "torch.Tensor"]:
-batch_images, batch_videos, batch_imglens, batch_vidlens, batch_input_ids = [], [], [], [], []
+batch_images, batch_videos, batch_audios = [], [], []
+batch_imglens, batch_vidlens, batch_audlens, batch_input_ids = [], [], [], []
for feature in features:
images = feature.pop("images", None) or []
videos = feature.pop("videos", None) or []
audios = feature.pop("audios", None) or []
batch_images.extend(images)
batch_videos.extend(videos)
batch_audios.extend(audios)
batch_imglens.append(len(images))
batch_vidlens.append(len(videos))
batch_audlens.append(len(audios))
batch_input_ids.append(feature["input_ids"])
fake_input_ids = []
if (
self.template.mm_plugin.image_token is not None and sum(batch_imglens) == 0 and sum(batch_vidlens) == 0
): # avoid process hanging in zero3/fsdp case
fake_messages = [{"role": "user", "content": IMAGE_PLACEHOLDER}]
fake_images = [Image.new("RGB", (64, 64), (255, 255, 255))]
fake_messages = self.template.mm_plugin.process_messages(
fake_messages, fake_images, [], [], self.processor
)
_fake_input_ids = self.tokenizer.encode(fake_messages[0]["content"], add_special_tokens=False)
_fake_input_ids, _ = self.template.mm_plugin.process_token_ids(
_fake_input_ids, None, fake_images, [], [], self.tokenizer, self.processor
)
fake_input_ids.extend(_fake_input_ids)
batch_images = fake_images
batch_imglens[0] = 1
if (
self.template.mm_plugin.audio_token is not None and sum(batch_audlens) == 0
): # avoid process hanging in zero3/fsdp case
fake_messages = [{"role": "user", "content": AUDIO_PLACEHOLDER}]
fake_audios = [np.zeros(1600)]
fake_messages = self.template.mm_plugin.process_messages(
fake_messages, [], [], fake_audios, self.processor
)
_fake_input_ids = self.tokenizer.encode(fake_messages[0]["content"], add_special_tokens=False)
_fake_input_ids, _ = self.template.mm_plugin.process_token_ids(
_fake_input_ids, None, [], [], fake_audios, self.tokenizer, self.processor
)
fake_input_ids.extend(_fake_input_ids)
batch_audios = fake_audios
batch_audlens[0] = 1
if len(fake_input_ids) != 0:
if self.tokenizer.padding_side == "right":
features[0]["input_ids"] = features[0]["input_ids"] + fake_input_ids
features[0]["attention_mask"] = features[0]["attention_mask"] + [0] * len(fake_input_ids)
features[0]["labels"] = features[0]["labels"] + [IGNORE_INDEX] * len(fake_input_ids)
else:
features[0]["input_ids"] = fake_input_ids + features[0]["input_ids"]
features[0]["attention_mask"] = [0] * len(fake_input_ids) + features[0]["attention_mask"]
features[0]["labels"] = [IGNORE_INDEX] * len(fake_input_ids) + features[0]["labels"]
batch_input_ids[0] = features[0]["input_ids"]
mm_inputs = self.template.mm_plugin.get_mm_inputs(
-batch_images, batch_videos, batch_imglens, batch_vidlens, batch_input_ids, self.processor
+batch_images,
+batch_videos,
+batch_audios,
+batch_imglens,
+batch_vidlens,
+batch_audlens,
+batch_input_ids,
+self.processor,
)
if "token_type_ids" in mm_inputs:
token_type_ids = mm_inputs.pop("token_type_ids")
@@ -98,9 +167,31 @@ class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
feature["token_type_ids"] = token_type_ids[i]
features: Dict[str, "torch.Tensor"] = super().__call__(features)
if self.model is not None and hasattr(self.model, "get_rope_index"): # for qwen2vl mrope
rope_index_kwargs = {
"input_ids": features["input_ids"],
"image_grid_thw": mm_inputs.get("image_grid_thw"),
"video_grid_thw": mm_inputs.get("video_grid_thw"),
"attention_mask": features["attention_mask"],
}
if "second_per_grid_ts" in mm_inputs:
rope_index_kwargs["second_per_grid_ts"] = mm_inputs.get("second_per_grid_ts")
features["position_ids"], features["rope_deltas"] = self.model.get_rope_index(**rope_index_kwargs)
if "cross_attention_mask" in mm_inputs: # for mllama inputs when pad_to_multiple_of is enabled
cross_attention_mask = mm_inputs.pop("cross_attention_mask")
seq_len = features["input_ids"].size(1)
orig_len = cross_attention_mask.size(1)
mm_inputs["cross_attention_mask"] = F.pad(cross_attention_mask, (0, 0, 0, 0, 0, seq_len - orig_len))
features.update(mm_inputs)
if isinstance(features.get("pixel_values"), list): # for pixtral inputs
features = features.data # use default_collate() instead of BatchEncoding.to()
if "image_bound" in features: # for minicpmv inputs
bsz, seq_length = features["input_ids"].shape
features["position_ids"] = torch.arange(seq_length).long().repeat(bsz, 1)
return {"data": features, "input_ids": features["input_ids"], "labels": features["labels"]}
return features
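The fake-image and fake-audio branches above exist only so that every rank still runs the multimodal processor under ZeRO-3/FSDP; the extra tokens are masked out of both attention and loss, so they cannot change the training signal. A minimal sketch of that masking idea (right padding assumed, not the collator itself):

IGNORE_INDEX = -100  # label value that the cross-entropy loss ignores

def append_masked_tokens(feature, fake_input_ids):
    """Append placeholder tokens that are invisible to attention and to the loss."""
    feature["input_ids"] = feature["input_ids"] + fake_input_ids
    feature["attention_mask"] = feature["attention_mask"] + [0] * len(fake_input_ids)
    feature["labels"] = feature["labels"] + [IGNORE_INDEX] * len(fake_input_ids)
    return feature

feature = {"input_ids": [1, 2, 3], "attention_mask": [1, 1, 1], "labels": [1, 2, 3]}
print(append_masked_tokens(feature, [7, 8]))
# {'input_ids': [1, 2, 3, 7, 8], 'attention_mask': [1, 1, 1, 0, 0], 'labels': [1, 2, 3, -100, -100]}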
@@ -120,6 +211,10 @@ class SFTDataCollatorWith4DAttentionMask(MultiModalDataCollatorForSeq2Seq):
if self.block_diag_attn and self.attn_implementation != "flash_attention_2":
features["attention_mask"] = prepare_4d_attention_mask(features["attention_mask"], self.compute_dtype)
for key, value in features.items(): # cast data dtype for paligemma
if torch.is_tensor(value) and torch.is_floating_point(value):
features[key] = value.to(self.compute_dtype)
return features
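For context, `block_diag_attn` assumes the packed attention mask stores a sequence index per token (1, 1, 2, 2, ... with 0 for padding) and expands it into a 4D mask so tokens only attend causally within their own segment. A toy boolean version of the idea follows; the actual `prepare_4d_attention_mask` helper additionally produces an additive mask in `compute_dtype`, so treat this as an illustration rather than the library code:

import torch

def toy_block_diag_mask(mask_with_indices: torch.Tensor) -> torch.Tensor:
    """(bsz, seq_len) segment ids -> (bsz, 1, seq_len, seq_len) boolean attention mask."""
    same_segment = mask_with_indices.unsqueeze(1) == mask_with_indices.unsqueeze(2)
    non_padding = (mask_with_indices != 0).unsqueeze(1) & (mask_with_indices != 0).unsqueeze(2)
    causal = torch.tril(torch.ones_like(same_segment, dtype=torch.bool))
    return (same_segment & non_padding & causal).unsqueeze(1)

mask = torch.tensor([[1, 1, 2, 2, 0]])  # two packed samples followed by one pad token
print(toy_block_diag_mask(mask)[0, 0].int())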
@@ -145,6 +240,7 @@ class PairwiseDataCollatorWithPadding(MultiModalDataCollatorForSeq2Seq):
"labels": feature[f"{key}_labels"],
"images": feature["images"],
"videos": feature["videos"],
"audios": feature["audios"],
}
concatenated_features.append(target_feature)
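For orientation, the pairwise collator flattens each preference example into two ordinary rows, one built from the chosen_* fields and one from the rejected_* fields, with the shared multimodal fields (now including "audios") copied into both before the base collator pads them. A rough per-example sketch; the input_ids and attention_mask keys are assumed from the surrounding class and are not shown in this hunk:

def flatten_pairwise_example(feature):
    """Toy version of the per-example expansion performed by PairwiseDataCollatorWithPadding."""
    rows = []
    for key in ("chosen", "rejected"):
        rows.append(
            {
                "input_ids": feature[f"{key}_input_ids"],
                "attention_mask": feature[f"{key}_attention_mask"],
                "labels": feature[f"{key}_labels"],
                "images": feature["images"],
                "videos": feature["videos"],
                "audios": feature["audios"],
            }
        )
    return rows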
@@ -168,6 +264,7 @@ class KTODataCollatorWithPadding(MultiModalDataCollatorForSeq2Seq):
"labels": feature["labels"],
"images": feature["images"],
"videos": feature["videos"],
"audios": feature["audios"],
}
kl_feature = {
"input_ids": feature["kl_input_ids"],
@@ -175,6 +272,7 @@ class KTODataCollatorWithPadding(MultiModalDataCollatorForSeq2Seq):
"labels": feature["kl_labels"],
"images": feature["images"],
"videos": feature["videos"],
"audios": feature["audios"],
}
target_features.append(target_feature)
kl_features.append(kl_feature)
@@ -185,6 +283,8 @@ class KTODataCollatorWithPadding(MultiModalDataCollatorForSeq2Seq):
batch["kl_input_ids"] = kl_batch["input_ids"]
batch["kl_attention_mask"] = kl_batch["attention_mask"]
batch["kl_labels"] = kl_batch["labels"]
if "cross_attention_mask" in kl_batch: # for mllama inputs.
batch["kl_cross_attention_mask"] = kl_batch["cross_attention_mask"]
if "token_type_ids" in kl_batch:
batch["kl_token_type_ids"] = kl_batch["token_type_ids"]

View File

@@ -0,0 +1,271 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from abc import abstractmethod
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Type, Union
from ..extras import logging
from .data_utils import Role
if TYPE_CHECKING:
from datasets import Dataset, IterableDataset
from transformers import Seq2SeqTrainingArguments
from ..hparams import DataArguments
from .parser import DatasetAttr
logger = logging.get_logger(__name__)
@dataclass
class DatasetConverter:
dataset_attr: "DatasetAttr"
data_args: "DataArguments"
def _find_medias(self, medias: Union[Any, Sequence[Any]]) -> Optional[List[Any]]:
r"""
Optionally prepends the media directory to the media paths when loading from local disk.
"""
if not isinstance(medias, list):
medias = [medias] if medias is not None else []
elif len(medias) == 0:
return None
else:
medias = medias[:]
if self.dataset_attr.load_from in ["script", "file"] and isinstance(medias[0], str):
for i in range(len(medias)):
if os.path.isfile(os.path.join(self.data_args.media_dir, medias[i])):
medias[i] = os.path.join(self.data_args.media_dir, medias[i])
else:
logger.warning_rank0_once(f"Media {medias[i]} does not exist in `media_dir`. Use original path.")
return medias
@abstractmethod
def __call__(self, example: Dict[str, Any]) -> Dict[str, Any]:
r"""
Converts a single example in the dataset to the standard format.
"""
...
@dataclass
class AlpacaDatasetConverter(DatasetConverter):
def __call__(self, example: Dict[str, Any]) -> Dict[str, Any]:
prompt = []
if self.dataset_attr.history and isinstance(example[self.dataset_attr.history], list):
for old_prompt, old_response in example[self.dataset_attr.history]:
prompt.append({"role": Role.USER.value, "content": old_prompt})
prompt.append({"role": Role.ASSISTANT.value, "content": old_response})
query = []
if self.dataset_attr.prompt and example[self.dataset_attr.prompt]:
query.append(example[self.dataset_attr.prompt])
if self.dataset_attr.query and example[self.dataset_attr.query]:
query.append(example[self.dataset_attr.query])
prompt.append({"role": Role.USER.value, "content": "\n".join(query)}) # "prompt\nquery"
if self.dataset_attr.kto_tag and isinstance(example[self.dataset_attr.kto_tag], bool): # kto example
response = [{"role": Role.ASSISTANT.value, "content": example[self.dataset_attr.response]}]
if example[self.dataset_attr.kto_tag]:
response = response + [{"role": Role.ASSISTANT.value, "content": ""}]
else:
response = [{"role": Role.ASSISTANT.value, "content": ""}] + response
elif (
self.dataset_attr.ranking
and isinstance(example[self.dataset_attr.chosen], str)
and isinstance(example[self.dataset_attr.rejected], str)
): # pairwise example
response = [
{"role": Role.ASSISTANT.value, "content": example[self.dataset_attr.chosen]},
{"role": Role.ASSISTANT.value, "content": example[self.dataset_attr.rejected]},
]
elif self.dataset_attr.response and isinstance(example[self.dataset_attr.response], str): # normal example
response = [{"role": Role.ASSISTANT.value, "content": example[self.dataset_attr.response]}]
else: # unsupervised
response = []
output = {
"_prompt": prompt,
"_response": response,
"_system": example[self.dataset_attr.system] if self.dataset_attr.system else "",
"_tools": example[self.dataset_attr.tools] if self.dataset_attr.tools else "",
"_images": self._find_medias(example[self.dataset_attr.images]) if self.dataset_attr.images else None,
"_videos": self._find_medias(example[self.dataset_attr.videos]) if self.dataset_attr.videos else None,
"_audios": self._find_medias(example[self.dataset_attr.audios]) if self.dataset_attr.audios else None,
}
return output
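Concretely, with the usual alpaca column names (instruction / input / output) a plain supervised record maps as below; this is an illustrative sketch rather than captured output:

example = {"instruction": "Translate to French.", "input": "Good morning", "output": "Bonjour"}

# Prompt and query are joined with a newline; the answer becomes a single assistant turn.
aligned = {
    "_prompt": [{"role": "user", "content": "Translate to French.\nGood morning"}],
    "_response": [{"role": "assistant", "content": "Bonjour"}],
    "_system": "",
    "_tools": "",
    "_images": None,
    "_videos": None,
    "_audios": None,
}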
@dataclass
class SharegptDatasetConverter(DatasetConverter):
def __call__(self, example: Dict[str, Any]) -> Dict[str, Any]:
tag_mapping = {
self.dataset_attr.user_tag: Role.USER.value,
self.dataset_attr.assistant_tag: Role.ASSISTANT.value,
self.dataset_attr.observation_tag: Role.OBSERVATION.value,
self.dataset_attr.function_tag: Role.FUNCTION.value,
self.dataset_attr.system_tag: Role.SYSTEM.value,
}
odd_tags = (self.dataset_attr.user_tag, self.dataset_attr.observation_tag)
even_tags = (self.dataset_attr.assistant_tag, self.dataset_attr.function_tag)
accept_tags = (odd_tags, even_tags)
messages = example[self.dataset_attr.messages]
if (
self.dataset_attr.system_tag
and len(messages) != 0
and messages[0][self.dataset_attr.role_tag] == self.dataset_attr.system_tag
):
system = messages[0][self.dataset_attr.content_tag]
messages = messages[1:]
else:
system = example[self.dataset_attr.system] if self.dataset_attr.system else ""
aligned_messages = []
broken_data = False
for turn_idx, message in enumerate(messages):
if message[self.dataset_attr.role_tag] not in accept_tags[turn_idx % 2]:
logger.warning_rank0(f"Invalid role tag in {messages}.")
broken_data = True
break
aligned_messages.append(
{
"role": tag_mapping[message[self.dataset_attr.role_tag]],
"content": message[self.dataset_attr.content_tag],
}
)
if (not self.dataset_attr.ranking and len(aligned_messages) % 2 != 0) or (
self.dataset_attr.ranking and len(aligned_messages) % 2 == 0
):
logger.warning_rank0(f"Invalid message count in {messages}.")
broken_data = True
if broken_data:
logger.warning_rank0("Skipping this abnormal example.")
prompt, response = [], []
elif self.dataset_attr.kto_tag and isinstance(example[self.dataset_attr.kto_tag], bool): # kto example
prompt = aligned_messages[:-1]
response = aligned_messages[-1:]
if example[self.dataset_attr.kto_tag]:
response = response + [{"role": Role.ASSISTANT.value, "content": ""}]
else:
response = [{"role": Role.ASSISTANT.value, "content": ""}] + response
elif (
self.dataset_attr.ranking
and isinstance(example[self.dataset_attr.chosen], dict)
and isinstance(example[self.dataset_attr.rejected], dict)
): # pairwise example
chosen = example[self.dataset_attr.chosen]
rejected = example[self.dataset_attr.rejected]
if (
chosen[self.dataset_attr.role_tag] not in accept_tags[-1]
or rejected[self.dataset_attr.role_tag] not in accept_tags[-1]
):
logger.warning_rank0(f"Invalid role tag in {[chosen, rejected]}.")
broken_data = True
prompt = aligned_messages
response = [
{
"role": tag_mapping[chosen[self.dataset_attr.role_tag]],
"content": chosen[self.dataset_attr.content_tag],
},
{
"role": tag_mapping[rejected[self.dataset_attr.role_tag]],
"content": rejected[self.dataset_attr.content_tag],
},
]
else: # normal example
prompt = aligned_messages[:-1]
response = aligned_messages[-1:]
output = {
"_prompt": prompt,
"_response": response,
"_system": system,
"_tools": example[self.dataset_attr.tools] if self.dataset_attr.tools else "",
"_images": self._find_medias(example[self.dataset_attr.images]) if self.dataset_attr.images else None,
"_videos": self._find_medias(example[self.dataset_attr.videos]) if self.dataset_attr.videos else None,
"_audios": self._find_medias(example[self.dataset_attr.audios]) if self.dataset_attr.audios else None,
}
return output
DATASET_CONVERTERS = {
"alpaca": AlpacaDatasetConverter,
"sharegpt": SharegptDatasetConverter,
}
def register_dataset_converter(name: str, dataset_converter: Type["DatasetConverter"]) -> None:
r"""
Register a new dataset converter.
"""
if name in DATASET_CONVERTERS:
raise ValueError(f"Dataset converter {name} already exists.")
DATASET_CONVERTERS[name] = dataset_converter
def get_dataset_converter(name: str, dataset_attr: "DatasetAttr", data_args: "DataArguments") -> "DatasetConverter":
r"""
Gets a dataset converter.
"""
if name not in DATASET_CONVERTERS:
raise ValueError(f"Dataset converter {name} not found.")
return DATASET_CONVERTERS[name](dataset_attr, data_args)
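The registry makes the converter pluggable: a project-specific format only needs a DatasetConverter subclass registered before the dataset is loaded. A hedged sketch, reusing the imports of this module; the "custom" name and the question/answer columns are invented for illustration:

@dataclass
class MyJsonlConverter(DatasetConverter):
    def __call__(self, example: Dict[str, Any]) -> Dict[str, Any]:
        # Assumed record layout: {"question": ..., "answer": ...}
        return {
            "_prompt": [{"role": Role.USER.value, "content": example["question"]}],
            "_response": [{"role": Role.ASSISTANT.value, "content": example["answer"]}],
            "_system": "",
            "_tools": "",
            "_images": None,
            "_videos": None,
            "_audios": None,
        }

register_dataset_converter("custom", MyJsonlConverter)
# get_dataset_converter("custom", dataset_attr, data_args) now returns a MyJsonlConverter instance.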
def align_dataset(
dataset: Union["Dataset", "IterableDataset"],
dataset_attr: "DatasetAttr",
data_args: "DataArguments",
training_args: "Seq2SeqTrainingArguments",
) -> Union["Dataset", "IterableDataset"]:
r"""
Aligned dataset:
_prompt: [{"role": "user", "content": "..."}] * (2T - 1)
_response: [{"role": "assistant", "content": "..."}] * N (N > 1 for ranking dataset)
_system: "..."
_tools: "...",
_images: [],
_videos: [],
_audios: [],
"""
column_names = list(next(iter(dataset)).keys())
kwargs = {}
if not data_args.streaming:
kwargs = dict(
num_proc=data_args.preprocessing_num_workers,
load_from_cache_file=(not data_args.overwrite_cache) or (training_args.local_process_index != 0),
desc="Converting format of dataset",
)
dataset_converter = get_dataset_converter(dataset_attr.formatting, dataset_attr, data_args)
return dataset.map(
dataset_converter,
batched=False,
remove_columns=column_names,
**kwargs,
)

View File

@@ -1,4 +1,4 @@
-# Copyright 2024 the LlamaFactory team.
+# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -43,7 +43,7 @@ class Role(str, Enum):
class DatasetModule(TypedDict):
train_dataset: Optional[Union["Dataset", "IterableDataset"]]
eval_dataset: Optional[Union["Dataset", "IterableDataset"]]
eval_dataset: Optional[Union["Dataset", "IterableDataset", Dict[str, "Dataset"]]]
def merge_dataset(
@@ -54,14 +54,16 @@ def merge_dataset(
"""
if len(all_datasets) == 1:
return all_datasets[0]
elif data_args.mix_strategy == "concat":
if data_args.streaming:
-logger.warning_once("The samples between different datasets will not be mixed in streaming mode.")
+logger.warning_rank0_once("The samples between different datasets will not be mixed in streaming mode.")
return concatenate_datasets(all_datasets)
elif data_args.mix_strategy.startswith("interleave"):
if not data_args.streaming:
-logger.warning_once("We recommend using `mix_strategy=concat` in non-streaming mode.")
+logger.warning_rank0_once("We recommend using `mix_strategy=concat` in non-streaming mode.")
return interleave_datasets(
datasets=all_datasets,
@@ -69,24 +71,75 @@ def merge_dataset(
seed=seed,
stopping_strategy="first_exhausted" if data_args.mix_strategy.endswith("under") else "all_exhausted",
)
else:
raise ValueError(f"Unknown mixing strategy: {data_args.mix_strategy}.")
def split_dataset(
dataset: Union["Dataset", "IterableDataset"], data_args: "DataArguments", seed: int
dataset: Optional[Union["Dataset", "IterableDataset"]],
eval_dataset: Optional[Union["Dataset", "IterableDataset", Dict[str, "Dataset"]]],
data_args: "DataArguments",
seed: int,
) -> "DatasetDict":
r"""
Splits the dataset and returns a dataset dict containing train set and validation set.
Supports both map dataset and iterable dataset.
"""
-if data_args.streaming:
-dataset = dataset.shuffle(buffer_size=data_args.buffer_size, seed=seed)
-val_set = dataset.take(int(data_args.val_size))
-train_set = dataset.skip(int(data_args.val_size))
-return DatasetDict({"train": train_set, "validation": val_set})
-else:
-val_size = int(data_args.val_size) if data_args.val_size > 1 else data_args.val_size
-dataset = dataset.train_test_split(test_size=val_size, seed=seed)
-return DatasetDict({"train": dataset["train"], "validation": dataset["test"]})
if eval_dataset is not None and data_args.val_size > 1e-6:
raise ValueError("Cannot specify `val_size` if `eval_dataset` is not None.")
dataset_dict = {}
if dataset is not None:
if data_args.streaming:
dataset = dataset.shuffle(buffer_size=data_args.buffer_size, seed=seed)
if data_args.val_size > 1e-6:
if data_args.streaming:
dataset_dict["validation"] = dataset.take(int(data_args.val_size))
dataset_dict["train"] = dataset.skip(int(data_args.val_size))
else:
val_size = int(data_args.val_size) if data_args.val_size > 1 else data_args.val_size
dataset_dict = dataset.train_test_split(test_size=val_size, seed=seed)
dataset = dataset.train_test_split(test_size=val_size, seed=seed)
dataset_dict = {"train": dataset["train"], "validation": dataset["test"]}
else:
dataset_dict["train"] = dataset
if eval_dataset is not None:
if isinstance(eval_dataset, dict):
dataset_dict.update({f"validation_{name}": data for name, data in eval_dataset.items()})
else:
if data_args.streaming:
eval_dataset = eval_dataset.shuffle(buffer_size=data_args.buffer_size, seed=seed)
dataset_dict["validation"] = eval_dataset
return DatasetDict(dataset_dict)
def get_dataset_module(dataset: Union["Dataset", "DatasetDict"]) -> "DatasetModule":
r"""
Converts dataset or dataset dict to dataset module.
"""
dataset_module: "DatasetModule" = {}
if isinstance(dataset, DatasetDict): # dataset dict
if "train" in dataset:
dataset_module["train_dataset"] = dataset["train"]
if "validation" in dataset:
dataset_module["eval_dataset"] = dataset["validation"]
else:
eval_dataset = {}
for key in dataset.keys():
if key.startswith("validation_"):
eval_dataset[key[len("validation_") :]] = dataset[key]
if len(eval_dataset):
dataset_module["eval_dataset"] = eval_dataset
else: # single dataset
dataset_module["train_dataset"] = dataset
return dataset_module
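Taken together, split_dataset and get_dataset_module decide what the trainer receives: a single validation split when `val_size` is used, or a dict keyed by dataset name when explicit eval datasets are given. A small sketch of the second case (the dataset names are made up):

from datasets import Dataset, DatasetDict

toy = Dataset.from_dict({"text": ["hello"]})
dataset_dict = DatasetDict(
    {"train": toy, "validation_alpaca": toy, "validation_sharegpt": toy}
)
module = get_dataset_module(dataset_dict)
# module["train_dataset"] is the "train" split; module["eval_dataset"] is
# {"alpaca": ..., "sharegpt": ...} with the "validation_" prefix stripped.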

View File

@@ -1,4 +1,4 @@
# Copyright 2024 the LlamaFactory team.
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -16,16 +16,12 @@ import json
import re
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
-from typing import TYPE_CHECKING, List, Optional, Tuple, Union
+from typing import List, Optional, Union
from typing_extensions import override
from .data_utils import SLOTS
-from .tool_utils import get_tool_utils
-if TYPE_CHECKING:
-from .tool_utils import FunctionCall
+from .tool_utils import FunctionCall, get_tool_utils
@dataclass
@@ -90,43 +86,44 @@ class StringFormatter(Formatter):
elif isinstance(slot, (dict, set)):
elements.append(slot)
else:
raise RuntimeError(f"Input must be string, set[str] or dict[str, str], got {type(slot)}")
raise RuntimeError(f"Input must be string, set[str] or dict[str, str], got {type(slot)}.")
return elements
@dataclass
-class FunctionFormatter(Formatter):
+class FunctionFormatter(StringFormatter):
def __post_init__(self):
-self.slots = get_tool_utils(self.tool_format).get_function_slots() + self.slots
+super().__post_init__()
+self.tool_utils = get_tool_utils(self.tool_format)
@override
def apply(self, **kwargs) -> SLOTS:
content = kwargs.pop("content")
functions: List[Tuple[str, str]] = []
content: str = kwargs.pop("content")
regex = re.compile(r"<think>(.*)</think>", re.DOTALL)
thought = re.search(regex, content)
if thought:
content = content.replace(thought.group(0), "")
functions: List["FunctionCall"] = []
try:
tool_calls = json.loads(content)
if not isinstance(tool_calls, list): # parallel function call
tool_calls = [tool_calls]
for tool_call in tool_calls:
functions.append((tool_call["name"], json.dumps(tool_call["arguments"], ensure_ascii=False)))
functions.append(
FunctionCall(tool_call["name"], json.dumps(tool_call["arguments"], ensure_ascii=False))
)
except json.JSONDecodeError:
raise RuntimeError(f"Invalid JSON format in function message: {str([content])}") # flat string
raise RuntimeError(f"Invalid JSON format in function message: {str([content])}.") # flat string
-elements = []
-for name, arguments in functions:
-for slot in self.slots:
-if isinstance(slot, str):
-slot = slot.replace("{{name}}", name).replace("{{arguments}}", arguments)
-elements.append(slot)
-elif isinstance(slot, (dict, set)):
-elements.append(slot)
-else:
-raise RuntimeError(f"Input must be string, set[str] or dict[str, str], got {type(slot)}")
+function_str = self.tool_utils.function_formatter(functions)
+if thought:
+function_str = thought.group(0) + function_str
-return elements
+return super().apply(content=function_str)
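For reference, the refactored FunctionFormatter expects the assistant content to be a JSON-encoded tool call (or list of calls), optionally preceded by a <think>...</think> block that is carried over verbatim; the rendered string itself depends on the configured tool format, so only the input side is sketched here:

import json

content = "<think>I should check the weather first.</think>" + json.dumps(
    [{"name": "get_weather", "arguments": {"city": "Paris"}}], ensure_ascii=False
)
# FunctionFormatter.apply(content=content) strips the <think> block, parses the JSON into
# FunctionCall objects, renders them via tool_utils.function_formatter(...), and then
# prepends the <think> block to the rendered string.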
@dataclass
@@ -141,7 +138,7 @@ class ToolFormatter(Formatter):
tools = json.loads(content)
return [self.tool_utils.tool_formatter(tools) if len(tools) != 0 else ""]
except json.JSONDecodeError:
raise RuntimeError(f"Invalid JSON format in tool description: {str([content])}") # flat string
raise RuntimeError(f"Invalid JSON format in tool description: {str([content])}.") # flat string
@override
def extract(self, content: str) -> Union[str, List["FunctionCall"]]:

Some files were not shown because too many files have changed in this diff.