179 Commits

Author SHA1 Message Date
Yaowei Zheng
ca75f1edf3 [model] fix vlm utils (#8388) 2025-06-17 01:08:49 +08:00
Yaowei Zheng
3a3bae1cfe [data] fix qwen2vl pos ids (#8387) 2025-06-17 00:48:54 +08:00
Yaowei Zheng
31874e4f62 [version] release v0.9.3 (#8386) 2025-06-16 19:21:32 +08:00
Yaowei Zheng
9a2d1dec62 [assets] update wechat (#8385) 2025-06-16 18:23:22 +08:00
Aman Gupta
8e4ac78607 [trainer] Add LD-DPO objective (#8362) 2025-06-12 16:10:38 +08:00
Yaowei Zheng
44f1b9b5ad [misc] tiny fixes (#8348) 2025-06-10 15:30:58 +08:00
阿丹(adan)
b41697c9b6 [model] support MiniCPM4 (#8314) 2025-06-10 14:38:39 +08:00
Kingsley
31bca4d172 [model] support Mistral3.1 small 2503 (#8335) 2025-06-09 10:37:42 +08:00
Chenhao Zhang
fa4360dca7 [assets] Add awesome works used LLaMA-Factory (#8333) 2025-06-09 10:21:17 +08:00
Yaowei Zheng
9acab4949d [model] fix model generate (#8327) 2025-06-07 08:47:50 +08:00
Vivek Iyer
32b4574094 [model] pushing FFT with unsloth (#8325)
Co-authored-by: viyer <vivek_iyer2@apple.com>
2025-06-07 08:20:58 +08:00
Yaowei Zheng
03a93ec513 [data] fix empty template (#8312) 2025-06-06 13:50:50 +08:00
Yaowei Zheng
bcb6b94658 [setup] fix uv (#8311) 2025-06-06 11:54:15 +08:00
Yaowei Zheng
c0710be6d7 [assets] update readme (#8303) 2025-06-05 23:23:15 +08:00
Kingsley
212a8006dc [tests] add visual model save test (#8248)
Co-authored-by: Yaowei Zheng <hiyouga@buaa.edu.cn>
2025-06-05 20:38:01 +08:00
Yaowei Zheng
ed70f8d5a2 [assets] fix npu docker (#8298) 2025-06-05 19:09:20 +08:00
Butui Hu
1a33d65a56 [launcher] Add elastic and fault-tolerant training support (#8286)
Signed-off-by: Butui Hu <hot123tea123@gmail.com>
2025-06-05 16:40:03 +08:00
Kingsley
69c9e379d5 [script] add Script description for qwen_omni_merge (#8293) 2025-06-05 13:22:01 +08:00
Yaowei Zheng
e9fe9cee29 [assets] update docker files (#8291) 2025-06-04 23:30:46 +08:00
Yaowei Zheng
cb7ab69783 [assets] update readme (#8288) 2025-06-04 17:46:12 +08:00
Yaowei Zheng
c1ed76e109 [assets] add icon (#8276) 2025-06-03 20:36:21 +08:00
Kingsley
c224d17cb2 [data] support nested images input for videos (#8264) 2025-06-03 20:26:29 +08:00
Ze-Yi LIN
c4e51d40e0 [tracking] swanlab add llamafactory tag (#8258) 2025-06-03 18:42:29 +08:00
Kingsley
554e89ff02 [model] add MIMO_VL (#8249) 2025-06-01 03:54:54 +08:00
Yaowei Zheng
fee2122f09 [deps] upgrade transformers to 4.52.4 (#8245) 2025-05-31 16:51:40 +08:00
Akshat Sehgal
c7e63bead7 [model] add smollm2 support (#8220) 2025-05-31 16:29:01 +08:00
hoshi-hiyouga
3e1a7fcb9c [assets] update readme (#8235) 2025-05-30 16:52:12 +08:00
Kingsley
2aaede8ef4 [scripts] specify model class for qwen_omni merge (#8227) 2025-05-30 14:20:12 +08:00
hoshi-hiyouga
42bebc341d [model] add deepseek 0528 models (#8215) 2025-05-29 21:37:07 +08:00
hoshi-hiyouga
83a9ff5853 [assets] fix docker images (#8203) 2025-05-28 22:26:05 +08:00
yzoaim
519bab86e6 [workflow] auto push docker images (#8181)
Co-authored-by: hoshi-hiyouga <hiyouga@buaa.edu.cn>
2025-05-28 20:21:15 +08:00
hoshi-hiyouga
dbc9f5a5d9 [assets] update Dockerfile (#8201) 2025-05-28 20:20:59 +08:00
hoshi-hiyouga
9b152d9cb5 [webui] fix skip args (#8195) 2025-05-28 18:11:07 +08:00
Youngwoo Kim
6c3cd400b5 [data] Reading files from cloud is broken (#8182) (#8183) 2025-05-28 15:50:44 +08:00
hoshi-hiyouga
4d3ffa2ec4 [assets] fix docker image (#8180) 2025-05-27 19:01:31 +08:00
hoshi-hiyouga
2bf8e993ab [data] fix shared file system (#8179) 2025-05-27 18:36:03 +08:00
hoshi-hiyouga
d4a413eb37 [webui] add extra args to export (#8178) 2025-05-27 18:25:31 +08:00
hoshi-hiyouga
00974a3169 [assets] update docker files (#8176) 2025-05-27 18:15:23 +08:00
hoshi-hiyouga
46ccf84aaa [webui] add infer extra args (#8167) 2025-05-27 12:04:00 +08:00
hoshi-hiyouga
07343ca83d [webui] fix input args (#8162) 2025-05-27 02:05:54 +08:00
hoshi-hiyouga
3c7dc66a92 [model] add smollm2 and medgemma (#8161) 2025-05-26 23:19:58 +08:00
hoshi-hiyouga
ba032828e2 [deps] upgrade transformers (#8159) 2025-05-26 22:03:58 +08:00
Akshat Sehgal
501e7d8a8f feat: add smollm support (#8050) 2025-05-26 19:47:54 +08:00
wangzhan
12292e4283 [api] support repetition_penalty and align presence_penalty with OpenAI Client (#7958) 2025-05-26 18:45:11 +08:00
Kingsley
f08b748199 [data] fix internvl plugin when using PIL images (#8129) 2025-05-22 01:32:59 +08:00
hoshi-hiyouga
d2a3036a23 [misc] update data readme (#8128) 2025-05-21 22:41:18 +08:00
hoshi-hiyouga
9ae17cd173 [deps] update to transformers 4.52 (#8125) 2025-05-21 05:16:18 +08:00
hoshi-hiyouga
56926d76f9 [data] llama3 multi tool support (#8124) 2025-05-21 02:01:12 +08:00
hoshi-hiyouga
c2f6f2fa77 [assets] update readme (#8110) 2025-05-20 02:44:18 +08:00
hoshi-hiyouga
9b5baa97f0 [data] qwen3 fixes (#8109) 2025-05-20 02:00:30 +08:00
hoshi-hiyouga
45030ff803 [model] switch to gptqmodel (#8108) 2025-05-19 22:25:40 +08:00
piamo
bc7f00f2c7 [model] update rope kwargs for yarn (#8101) 2025-05-19 20:07:54 +08:00
hoshi-hiyouga
beae231af6 [doc] add no build isolation (#8103) 2025-05-19 19:25:13 +08:00
Ma, Xiaochen
a0b4b91577 [trainer] fix KeyError at end of pretrain (#8099) 2025-05-19 18:01:26 +08:00
Biao Wang
90492f3582 [misc] fix cli (#8095)
Co-authored-by: wangbiao11 <wangbiao11@baidu.com>
2025-05-19 17:59:39 +08:00
Saiya
ab41f7956c [infer] support lora adapter for SGLang backend (#8067) 2025-05-16 23:33:47 +08:00
Kingsley
52b23f9e56 [data] add forward compatibility for video_utils in Transformers 4.52.0 (#8077) 2025-05-16 17:41:04 +08:00
Eric Tang
a9aa392ba4 [data] support loading folder from remote (#8078) 2025-05-16 15:35:38 +08:00
Shawn Tao
0b773234e5 [infer] Modify vllm_infer.py to batch preprocess to avoid too much files opened error (#8051)
Co-authored-by: Kingsley <82590017+Kuangdd01@users.noreply.github.com>
2025-05-15 10:54:35 +08:00
hoshi-hiyouga
712c57f3b4 [assets] update windows installation (#8042) 2025-05-13 17:01:56 +08:00
hoshi-hiyouga
dc080399c6 [model] add seed coder and qwen3 quant models (#8039) 2025-05-13 15:59:55 +08:00
hoshi-hiyouga
68fc068cab [data] fix kimi vl template (#8015) 2025-05-11 20:45:19 +08:00
Kingsley
9620825892 [scripts] add video params for vllm infer (#7992) 2025-05-09 21:16:52 +08:00
yunhao-tech
26cbb03a5f [data] Avoid repetitive tool description warp (#8000)
Co-authored-by: chenyunhao <chenyunhao@wps.cn>
Co-authored-by: hoshi-hiyouga <hiyouga@buaa.edu.cn>
2025-05-09 21:16:37 +08:00
tpoisonooo
5f4b793e04 [docs] add GraphGen (#7974) 2025-05-07 12:23:11 +02:00
hoshi-hiyouga
994ab6424a [misc] update liger kernel patch (#7966) 2025-05-06 20:32:16 +02:00
hoshi-hiyouga
aa9ed4db59 [example] update examples (#7964) 2025-05-06 17:24:25 +02:00
Kingsley
ef86a53063 [model] add mimo7b (#7946) 2025-05-06 17:10:30 +02:00
hoshi-hiyouga
bf0286e1e3 [misc] fix qwen2 omni (#7962) 2025-05-06 15:39:13 +02:00
hoshi-hiyouga
ce7032e1b3 [model] add qwen2 omni 3b (#7945) 2025-05-03 16:36:51 +08:00
Eric Chen
5763017cea [assets] Warp Support README Update (#7887) 2025-05-02 00:08:48 +08:00
hoshi-hiyouga
13b05e74f1 [hparam] add enable think argument (#7928) 2025-04-30 17:21:30 +08:00
hoshi-hiyouga
c566e39b7d [data] fix base plugin (#7924) 2025-04-30 16:28:05 +08:00
hoshi-hiyouga
052ca871bd [data] optimize qwen3 loss computation (#7923) 2025-04-30 16:18:00 +08:00
hoshi-hiyouga
73198a6645 [misc] fix uv (#7913) 2025-04-30 07:45:03 +08:00
hoshi-hiyouga
d4ee44bdef [data] add eval_on_each_dataset arg (#7912) 2025-04-30 06:56:43 +08:00
hoshi-hiyouga
6d2cde43e7 [data] replace eos token for base models (#7911) 2025-04-30 06:52:28 +08:00
hoshi-hiyouga
11295cdea0 [data] improve mm plugin (#7910) 2025-04-30 06:34:28 +08:00
hoshi-hiyouga
98f23c6584 [model] add qwen3 (#7885) 2025-04-29 09:34:05 +08:00
Kingsley
db9559456c [data] fix qwen2.5 omni template (#7883) 2025-04-29 00:58:23 +08:00
hoshi-hiyouga
3ae5da2a04 [model] fix dsv3 leaf node (#7879) 2025-04-28 18:11:09 +08:00
hoshi-hiyouga
d173cb50f5 [data] fix qwen2 omni plugin (#7875) 2025-04-28 14:22:41 +08:00
zhaop-l
df27d7e48a [trainer] make projector trainable in freeze training (#7872)
Co-authored-by: hoshi-hiyouga <hiyouga@buaa.edu.cn>
2025-04-28 13:19:37 +08:00
hoshi-hiyouga
bb5b83352b [data] fix minicpmo vllm infer (#7870) 2025-04-28 01:59:53 +08:00
Kingsley
1157f4e246 fix attn patch for kimivl (#7867) 2025-04-27 23:12:28 +08:00
Eric Tang
ef03832cd4 [ray] add storage filesystem to ray config (#7854) 2025-04-27 22:12:40 +08:00
hoshi-hiyouga
2233b739fa [model] fix vit gradient checkpointing (#7830) 2025-04-23 22:48:48 +08:00
hoshi-hiyouga
091d2539e8 Merge commit from fork 2025-04-23 16:38:27 +08:00
hoshi-hiyouga
c1a7f2ebb2 [model] fix moe zero3 (#7826) 2025-04-23 15:30:49 +08:00
Kingsley
fa0eb91f1f [data] fix internvl plugin (#7817) 2025-04-23 00:58:22 +08:00
hoshi-hiyouga
49f9ed0232 [assets] update model readme (#7804) 2025-04-22 16:43:56 +08:00
Kingsley
2a564c25d1 [model] add arch check for InternVL (#7803) 2025-04-22 16:38:05 +08:00
Kingsley
7500e761d3 [misc] update internvl constants (#7801) 2025-04-22 15:53:08 +08:00
hoshi-hiyouga
fddcd43c88 [trainer] support early stop (#7797) 2025-04-22 01:59:33 +08:00
hoshi-hiyouga
0e4ce039ee [data] improve mmplugin (#7795) 2025-04-22 01:25:33 +08:00
hoshi-hiyouga
b07628dea5 [example] add bash usage (#7794) 2025-04-22 00:25:51 +08:00
Juanxi Tian
12ada72ed4 [trainer] Add Muon Optimizer (#7749)
Co-authored-by: hoshi-hiyouga <hiyouga@buaa.edu.cn>
2025-04-21 23:38:37 +08:00
hoshi-hiyouga
416853dd25 [parser] support omegaconf (#7793) 2025-04-21 23:30:30 +08:00
Changrui Chen
bd7bc31c79 [data] Fix wrong position ids with packed attention masks (#7754)
Co-authored-by: hoshi-hiyouga <hiyouga@buaa.edu.cn>
2025-04-21 23:19:36 +08:00
flashJd
0ac641326b [misc] fix new tokens adding (#7253)
Co-authored-by: hoshi-hiyouga <hiyouga@buaa.edu.cn>
2025-04-21 23:19:02 +08:00
ddddng
c5ba9106ec [model] fix gemma3 export (#7786)
Co-authored-by: hoshi-hiyouga <hiyouga@buaa.edu.cn>
2025-04-21 23:07:11 +08:00
Sachin Beldona
3b2d3794a5 [misc] fix bug in constant (#7765)
Co-authored-by: Sachin Beldona <sbeldona@cs.cmu.edu>
2025-04-21 23:06:31 +08:00
hoshi-hiyouga
b605c20768 [assets] update wechat (#7792) 2025-04-21 21:29:42 +08:00
hoshi-hiyouga
39169986ef [trainer] fix pt loss (#7748)
* fix pt loss

* robust

* fix

* test
2025-04-17 03:15:35 +08:00
hoshi-hiyouga
86ebb219d6 [breaking] bump transformers to 4.45.0 & improve ci (#7746)
* update ci

* fix

* fix

* fix

* fix

* fix
2025-04-17 02:36:48 +08:00
hoshi-hiyouga
d222f63cb7 [infer] set env for vllm ascend (#7745) 2025-04-17 01:08:55 +08:00
Kingsley
2e518f255f [model] support intern-VL 2.5-3 series (#7258)
* add internvl and rebase

* fix for internvl2&3

* remove lines

* fix video_inputs & lint

* nit

* add constants

* remove lines

* fix

* fix error

* pass ci

* pass ci

* skip internvl & nit
2025-04-17 00:31:30 +08:00
ENg-122
8f88a4e6a4 [misc] improve entrypoint (#7345)
* Simply clean up the entry code, since there were too many if/else branches

* Update cli.py

---------

Co-authored-by: hoshi-hiyouga <hiyouga@buaa.edu.cn>
2025-04-16 21:48:23 +08:00
leo-pony
b9263ff5ac [infer] support vllm-ascend (#7739) 2025-04-16 20:06:47 +08:00
hoshi-hiyouga
ee2ab093a7 [api] fix chat messages (#7732) 2025-04-15 16:39:08 +08:00
hoshi-hiyouga
3df021d4d7 [deps] upgrade vllm (#7728) 2025-04-15 14:57:40 +08:00
Joe Schoonover
e252abf051 [docker] patch docker-rocm (#7725)
* Update Dockerfile

* Fix typo

* Fix syntax for /bin/sh conditional

* Add build args to docker-compose

* Change shell to /bin/bash

This is required for "==" syntax in conditional string comparison
2025-04-15 13:36:39 +08:00
hoshi-hiyouga
1134baeedd [assets] update model readme (#7724) 2025-04-15 00:41:09 +08:00
Kingsley
2101399c94 [model] Support Kimi_VL thinking/instruct (#7719)
* add kimi_vl

* patch config

* check version

* Update mm_plugin.py

* Update mm_plugin.py

---------

Co-authored-by: hoshi-hiyouga <hiyouga@buaa.edu.cn>
2025-04-15 00:21:58 +08:00
hoshi-hiyouga
3f91a95250 [misc] fix env vars (#7715) 2025-04-14 16:04:04 +08:00
hoshi-hiyouga
7c61b35106 [misc] upgrade cli (#7714) 2025-04-14 15:41:22 +08:00
hoshi-hiyouga
f518bfba5b [deps] upgrade transformers (#7704) 2025-04-13 18:11:34 +08:00
Yuxuan Zhang
8162f94db5 [model] add GLM-4-0414 (#7695)
* Update README_zh.md

* update
2025-04-13 17:10:45 +08:00
hoshi-hiyouga
1f0c52b73c [deps] fix uv conflicts (#7686)
* fix #7678

* Update setup.py

* Update tests.yml

* Update publish.yml

* Update Makefile
2025-04-11 18:02:24 +08:00
Eric Tang
a8caf09c7f [data] support for specifying a dataset in cloud storage (#7567)
* add support for loading datasets from s3/gcs

* add comments to readme

* run linter and address comments

* add option to pass in kwargs to ray init (i.e. runtime env)

* address comment

* revert mixed up changes
2025-04-10 11:31:35 +08:00
Eric Tang
bb8d79bae2 [ray] allow for specifying ray.init kwargs (i.e. runtime_env) (#7647)
* ray init kwargs

* Update trainer_utils.py

* fix ray args

---------

Co-authored-by: hoshi-hiyouga <hiyouga@buaa.edu.cn>
2025-04-10 11:31:05 +08:00
Dain Kim
1c436c9f25 [bugfix] enable_gemma_liger_kernel (#7660)
- The `enable_liger_kernel` function for the Gemma model series was not executed due to the existing `if` statement in the code.
- Changed the line to an `elif` statement so that the `apply_liger_kernel` function is executed properly.

resolved: #7628
2025-04-10 11:27:30 +08:00
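A minimal, hypothetical sketch of the dispatch bug described in the commit message above (function names and branch layout are illustrative, not the actual LLaMA-Factory code): a model-type check written as a separate `if` after an `if/elif/else` chain can be skipped entirely when the chain's `else` returns first, and attaching it with `elif` restores the intended single dispatch.

```python
# Hypothetical sketch only -- not the actual LLaMA-Factory code.

def apply_liger_kernel_buggy(model_type: str) -> str:
    if model_type == "llama":
        return "patched: llama"
    elif model_type == "qwen2":
        return "patched: qwen2"
    else:
        return "liger kernel not applied"   # gemma falls into this branch ...

    if model_type == "gemma":               # ... so this check is never reached
        return "patched: gemma"


def apply_liger_kernel_fixed(model_type: str) -> str:
    if model_type == "llama":
        return "patched: llama"
    elif model_type == "qwen2":
        return "patched: qwen2"
    elif model_type == "gemma":             # now part of the same chain
        return "patched: gemma"
    else:
        return "liger kernel not applied"


print(apply_liger_kernel_buggy("gemma"))   # liger kernel not applied
print(apply_liger_kernel_fixed("gemma"))   # patched: gemma
```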
jilongW
1b0934bccb [misc] fix cuda warn on intel GPU (#7655) 2025-04-09 21:37:54 +08:00
hoshi-hiyouga
4eec541857 [data] add coig-p dataset (#7657) 2025-04-09 21:18:25 +08:00
hoshi-hiyouga
89a4f9ec7f [assets] update readme (#7654) 2025-04-09 18:27:38 +08:00
hoshi-hiyouga
1abd71b551 [assets] update readme (#7644) 2025-04-09 01:06:06 +08:00
Kingsley
349c56c51c [data] Fix bugs of use_audio_in_video in Qwen2.5 Omni (#7638)
* cache _mm_inputs

* nit

* support for use_audio_in_video

* remove cache

* fix data

* Update mllm_video_audio_demo.json
2025-04-08 18:40:10 +08:00
Shawn Tao
acb09fa3a3 [trainer] fix key error (#7635) 2025-04-08 18:39:50 +08:00
Adarsh Shirawalmath
f75b91077b [sglang] support transformers 4.51.0 (#7639) 2025-04-08 18:39:23 +08:00
hoshi-hiyouga
c3c0efbaa0 [misc] fix packing and eval plot (#7623) 2025-04-07 18:20:57 +08:00
hoshi-hiyouga
5115dc8c7f [assets] update readme (#7612) 2025-04-06 13:58:49 +08:00
hoshi-hiyouga
831e7f1cfd [model] add llama4 (#7611) 2025-04-06 13:42:31 +08:00
Kingsley
d4cfa9507e [data] fix qwen2.5 omni plugin (#7578)
* specific entry

* Update mm_plugin.py

* fix fps cal

---------

Co-authored-by: hoshi-hiyouga <hiyouga@buaa.edu.cn>
2025-04-02 23:58:39 +08:00
Kingsley
d32c6c014d [data] fix qwen2.5 omni plugin (#7573)
* align key with qwen2vl

* nit && change scripts
2025-04-02 21:28:52 +08:00
gechengze
7b9deb9410 [trainer] fix batch processing in PPO trainer (#7576) 2025-04-02 21:17:48 +08:00
hoshi-hiyouga
5e22597ff1 [infer] vllm video/audio inference (#7566) 2025-04-02 02:27:04 +08:00
hoshi-hiyouga
2bfcad2394 [model] fix kv cache (#7564) 2025-04-01 23:07:46 +08:00
Yu Shi Jie
a13b1bb49a [model] fix use_cache patching for gemma3 multimodal (#7500) 2025-04-01 16:06:48 +08:00
Ritesh Goru
d10467d178 [data] specify position_ids in PackedSupervisedDatasetProcessor for neat_packing (#7318)
* use position_ids for neat_packing with fa2

* revert fa2 changes
2025-04-01 16:03:13 +08:00
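For context on the packing commits above: when several short examples are packed into one training row, the position ids should restart at zero for each packed example so that RoPE and the attention mask treat the examples independently. A small illustrative sketch of that idea (not the repo's actual `PackedSupervisedDatasetProcessor` code):

```python
# Illustrative sketch of the neat-packing idea, not the actual implementation.

def packed_position_ids(example_lengths: list[int]) -> list[int]:
    """Position ids for one packed row: restart at 0 for every packed example."""
    position_ids: list[int] = []
    for length in example_lengths:
        position_ids.extend(range(length))  # 0..length-1 for each example
    return position_ids

# Packing examples of lengths 4, 3 and 2 into a single row of length 9:
print(packed_position_ids([4, 3, 2]))
# -> [0, 1, 2, 3, 0, 1, 2, 0, 1]
```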
taoharry
aac70663fd [webui] fix launch with proxy (#7332) 2025-04-01 15:52:56 +08:00
Billy Cao
00409ff28a [data] shard the dataset to allow multiprocessing when streaming is enabled (#7530)
* Shard the dataset when streaming to allow multiprocessing

* Allow user to not set dataset_shards to ensure backward compatibility
2025-04-01 15:36:23 +08:00
Hao
d70b3b4bc5 [trainer] new kto mismatch pair creation strategy (#7509) 2025-04-01 15:21:53 +08:00
hoshi-hiyouga
e76eba051d [data] fix qwen2.5 omni collator (#7553) 2025-04-01 00:15:12 +08:00
Kingsley
7eed496336 [model] add Qwen2.5-Omni model (#7537)
* preserve image_sizes

* preserve image_sizes

* init plugin

* support audio-text2text lora

* nit

* support image/video-text2text, audio-text2text

* remove args

* remove lines

* add docs && nit

* remove some comments

* fix && add merge part script

* add license
2025-03-31 20:39:35 +08:00
hoshi-hiyouga
0f8296626a [deps] pin pydantic to 2.10.6 (#7546) 2025-03-31 14:42:28 +08:00
Kingsley
8da1d2fa71 [data] fix pixtral plugin (#7505)
* preserve `image_sizes`

* add comments
2025-03-27 17:06:40 +08:00
Xu-pixel
b578a7d5b6 [3rdparty] support swanlab lark notification (#7481) 2025-03-27 01:52:01 +08:00
Kdump
24afceddb7 [trainer] fix wsd scheduler (#7304)
* [trainer] Warmup_stable_decay supports setting the number of stable and decay steps according to the warmup_ratio ratio

* Update trainer_utils.py

---------

Co-authored-by: hoshi-hiyouga <hiyouga@buaa.edu.cn>
2025-03-26 15:25:02 +08:00
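For context on the scheduler fix above: with a warmup-stable-decay (WSD) schedule, the stable and decay phases can be sized from ratios of the total step count rather than absolute step numbers. A hedged sketch of that bookkeeping follows; the function and argument names are illustrative, not the exact LLaMA-Factory or transformers parameters.

```python
# Hedged sketch of ratio-based WSD phase sizing; names are illustrative.

def wsd_phase_steps(total_steps: int, warmup_ratio: float, decay_ratio: float) -> tuple[int, int, int]:
    """Split a warmup-stable-decay schedule into phase lengths derived from ratios."""
    warmup_steps = int(total_steps * warmup_ratio)
    decay_steps = int(total_steps * decay_ratio)
    stable_steps = total_steps - warmup_steps - decay_steps  # remainder stays at peak LR
    return warmup_steps, stable_steps, decay_steps

print(wsd_phase_steps(total_steps=1000, warmup_ratio=0.1, decay_ratio=0.1))
# -> (100, 800, 100)
```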
hoshi-hiyouga
0583d06676 [model] add qwen2vl 32b & upgrade peft (#7469)
* add qwen2vl 32b

* fix ci

* upgrade peft to 0.15

* fix ci

* fix ci
2025-03-25 12:15:58 +08:00
GuoCoder
ec6a261568 [model] fix lora on quant models (#7456)
Co-authored-by: root <root@ai>
2025-03-25 11:59:46 +08:00
Xiaosu Zhu
6b3b97c738 [misc] update liger-kernel's monkey patch (#7453)
* Update liger_kernel.py

* Update setup.py
2025-03-25 11:58:52 +08:00
AbdelKarim ELJANDOUBI
6d3748f727 [misc] enable liger kernel for gemma3 text and paligemma (#7466)
* add gemma3 text

* add paligemma (1,2 and 2 mix)
2025-03-25 09:27:43 +08:00
Kenny Lam
7c890170e3 [misc] enable liger kernel for gemma3 (#7462) 2025-03-24 19:09:59 +08:00
hoshi-hiyouga
ca42c0c406 [assets] fix gemma3 readme (#7449) 2025-03-24 10:31:25 +08:00
hoshi-hiyouga
7203365b80 [trainer] fix vlm loss for transformers 4.49 (#7448) 2025-03-24 10:24:05 +08:00
rumichi
3612946dd9 [docker] upgrade to torch 2.6 (#7442) 2025-03-23 21:18:08 +08:00
hoshi-hiyouga
3aa4f32e9c [misc] fix ci (#7441)
* fix ci

* improve ci
2025-03-23 21:09:35 +08:00
hoshi-hiyouga
304796b803 [misc] fix license (#7440) 2025-03-23 19:31:56 +08:00
SnowFox4004
7cfd6e4bb0 [scripts] support compute score on vllm's predictions (#7419)
* enable manual bleu&rouge eval by adding `scripts/eval_bleu_rouge.py`

* added libraries check

* update: use the datasets library's multiprocessing to speed up processing

* update:
- use fire.Fire
- fix code formatting

* Update eval_bleu_rouge.py: correctly uses fire

Deleted the code of using sys.argv

* Update eval_bleu_rouge.py

---------

Co-authored-by: SnowFox4004 <manba@out>
Co-authored-by: hoshi-hiyouga <hiyouga@buaa.edu.cn>
2025-03-23 19:21:01 +08:00
hoshi-hiyouga
05b19d6952 [deps] upgrade transformers to 4.50.0 (#7437)
* upgrade transformers

* fix hf cache

* fix dpo trainer
2025-03-23 17:44:27 +08:00
hoshi-hiyouga
919415dba9 [deps] upgrade vllm to 0.8 (#7436) 2025-03-23 14:32:22 +08:00
Guo, Quan
a959c2a509 [misc] fix sglang deps (#7432)
* feat: Add transformer version requirement for sglang

* feat: add srt to sglang which is required for running sglang

Other options are srt_hip, srt_xpu, srt_npu, srt_hpu, srt_cpu, for different computation architectures.
2025-03-23 14:07:10 +08:00
Eric Tang
db0a08db6f [3rdparty] fix redundant process group destroy for ray (#7395)
* fix redundant process group destroy for ray

* Update tuner.py

---------

Co-authored-by: hoshi-hiyouga <hiyouga@buaa.edu.cn>
2025-03-21 10:56:47 +08:00
hoshi-hiyouga
a306f0f5a2 [version] fix minicpmo (#7378) 2025-03-20 16:59:31 +08:00
hoshi-hiyouga
63752fccf7 [assets] update wechat (#7361) 2025-03-18 21:31:09 +08:00
hoshi-hiyouga
1f9773395b [misc] set dev version (#7351) 2025-03-18 00:10:53 +08:00
hoshi-hiyouga
128b5b12b3 [data] fix template (#7349) 2025-03-17 23:45:20 +08:00
hoshi-hiyouga
d5915a7dd7 [assets] update videos (#7340)
* Update README.md

* Update README_zh.md
2025-03-17 15:48:02 +08:00
Hertz
ec1154662b [model] support hunyuan 7b (#7317)
* [Model]supported tencent-hunyuan model

* [Model]supported tencent-hunyuan model(fix)

* [Model]supported tencent-hunyuan model(fix)
2025-03-15 20:55:24 +08:00
Qiaolin Yu
a44a53ebec [inference] support sglang backend (#7278)
* Mimic SGLang offline Engine

* Add more tests and args

* Pass all current tests

* Clean Code

* fix sample_params

* clean code

* Fix Stream Chat

* change sglang from engine mode to server mode

* fix

* Fix Review Issues

* Use SGLang Built-In Utilities

* Fix test SGLang

* Some Doc Issue

* fix sglang engine

* add readme

---------

Co-authored-by: Jin Pan <jpan236@wisc.edu>
Co-authored-by: hiyouga <hiyouga@buaa.edu.cn>
2025-03-15 04:37:58 +08:00
hoshi-hiyouga
93e6184cbe [data] gemma3 plugin pan and scan (#7294)
* gemma3 pan and scan

* add test case

* fix test
2025-03-13 23:29:23 +08:00
hoshi-hiyouga
0be0d7796a [assets] update video (#7287) 2025-03-13 18:45:47 +08:00
Ritesh Goru
480369a9f2 [data] efficient 4d_attention_mask creation in neat_packing (#7272) 2025-03-13 03:31:12 +08:00
hoshi-hiyouga
650a9a9057 [misc] update format (#7277) 2025-03-13 02:53:08 +08:00
hoshi-hiyouga
4b9d8da5a4 [model] support gemma3 (#7273) 2025-03-13 01:35:23 +08:00
hoshi-hiyouga
e6159ad730 [misc] upgrade deps (#7257) 2025-03-12 00:33:47 +08:00
hoshi-hiyouga
264538cb26 [misc] upgrade format to py39 (#7256) 2025-03-12 00:08:41 +08:00
hoshi-hiyouga
5995800bce [ci] update workflow (#7255) 2025-03-11 22:57:49 +08:00
hoshi-hiyouga
bf8b483186 [core] release v0.9.2 (#7254) 2025-03-11 22:42:23 +08:00
218 changed files with 7334 additions and 4564 deletions

@@ -3,12 +3,12 @@
 .github
 .venv
 cache
-data
 docker
 saves
 hf_cache
 ms_cache
 om_cache
+shared_data
 output
 .dockerignore
 .gitattributes

@@ -16,6 +16,8 @@ USE_MODELSCOPE_HUB=
 USE_OPENMIND_HUB=
 USE_RAY=
 RECORD_VRAM=
+OPTIM_TORCH=
+NPU_JIT_COMPILE=
 # torchrun
 FORCE_TORCHRUN=
 MASTER_ADDR=

@@ -12,7 +12,7 @@
  attributes:
    value: |
      Please do not create issues that are not related to framework bugs under this category, use **[Discussions](https://github.com/hiyouga/LLaMA-Factory/discussions/categories/q-a)** instead.
-     请勿在此分类下创建和框架 bug 无关的 issues,请使用 **[讨论区](https://github.com/hiyouga/LLaMA-Factory/discussions/categories/q-a)**。
+     请勿在此分类下创建和框架 bug 无关的 issues,训练问题求助请使用 **[讨论区](https://github.com/hiyouga/LLaMA-Factory/discussions/categories/q-a)**。
 - type: checkboxes
   id: reminder
@@ -47,8 +47,6 @@
    description: |
      Please provide entry arguments, error messages and stack traces that reproduces the problem.
      请提供入口参数,错误日志以及异常堆栈以便于我们复现问题。
-     Remember to wrap your log messages with \`\`\`.
-     请务必使用 Markdown 标签 \`\`\` 来包裹您的日志信息。
    value: |
      ```text

.github/workflows/docker.yml (new file)

@@ -0,0 +1,66 @@
name: docker

on:
  workflow_dispatch:
  push:
    branches:
      - "main"
    paths:
      - "**/*.py"
      - "requirements.txt"
      - "docker/**"
      - ".github/workflows/*.yml"
  pull_request:
    branches:
      - "main"
    paths:
      - "**/*.py"
      - "requirements.txt"
      - "docker/**"
      - ".github/workflows/*.yml"

jobs:
  build:
    runs-on: ubuntu-latest
    concurrency:
      group: ${{ github.workflow }}-${{ github.ref }}
      cancel-in-progress: ${{ github.ref != 'refs/heads/main' }}
    environment:
      name: docker
      url: https://hub.docker.com/r/hiyouga/llamafactory
    steps:
      - name: Free up disk space
        run: |
          df -h
          sudo rm -rf /usr/share/dotnet
          sudo rm -rf /opt/ghc
          sudo rm -rf /opt/hostedtoolcache
          df -h
      - name: Checkout
        uses: actions/checkout@v4
      - name: Set up Docker Buildx
        uses: docker/setup-buildx-action@v3
      - name: Login to Docker Hub
        if: github.event_name != 'pull_request'
        uses: docker/login-action@v3
        with:
          username: ${{ vars.DOCKERHUB_USERNAME }}
          password: ${{ secrets.DOCKERHUB_TOKEN }}
      - name: Build and push Docker image
        uses: docker/build-push-action@v6
        with:
          context: .
          file: ./docker/docker-cuda/Dockerfile
          build-args: |
            EXTRAS=metrics,deepspeed,liger-kernel
          push: ${{ github.event_name != 'pull_request' }}
          tags: docker.io/hiyouga/llamafactory:latest
          cache-from: type=gha
          cache-to: type=gha,mode=max

@@ -1,6 +1,7 @@
 name: publish

 on:
+  workflow_dispatch:
   release:
     types:
       - published
@@ -27,14 +28,9 @@ jobs:
        with:
          python-version: "3.9"
-     - name: Install dependencies
-       run: |
-         python -m pip install --upgrade pip
-         python -m pip install build
      - name: Build package
        run: |
-         python -m build
+         make build
      - name: Publish package
        uses: pypa/gh-action-pypi-publish@release/v1

@@ -1,6 +1,7 @@
 name: tests

 on:
+  workflow_dispatch:
   push:
     branches:
       - "main"
@@ -21,7 +22,7 @@ jobs:
     strategy:
       fail-fast: false
       matrix:
-        python-version:
+        python:
          - "3.9"
          - "3.10"
          - "3.11"
@@ -30,9 +31,25 @@
          - "ubuntu-latest"
          - "windows-latest"
          - "macos-13"
+        transformers:
+         - null
+        include: # test backward compatibility
+         - python: "3.9"
+           os: "ubuntu-latest"
+           transformers: "4.45.0"
+         - python: "3.9"
+           os: "ubuntu-latest"
+           transformers: "4.49.0"
+         - python: "3.9"
+           os: "ubuntu-latest"
+           transformers: "4.51.0"
     runs-on: ${{ matrix.os }}
+    concurrency:
+      group: ${{ github.workflow }}-${{ github.ref }}-${{ matrix.os }}-${{ matrix.python }}-${{ matrix.transformers }}
+      cancel-in-progress: ${{ github.ref != 'refs/heads/main' }}
     env:
       HF_TOKEN: ${{ secrets.HF_TOKEN }}
       OS_NAME: ${{ matrix.os }}
@@ -44,19 +61,42 @@
      - name: Set up Python
        uses: actions/setup-python@v5
        with:
-         python-version: ${{ matrix.python-version }}
+         python-version: ${{ matrix.python }}
          cache: "pip"
-         cache-dependency-path: "setup.py"
+         cache-dependency-path: "**/requirements*.txt"
      - name: Install dependencies
        run: |
          python -m pip install --upgrade pip
          python -m pip install ".[torch,dev]"
+     - name: Install transformers
+       if: ${{ matrix.transformers }}
+       run: |
+         python -m pip install "transformers==${{ matrix.transformers }}"
+     - name: Cache files
+       id: hf-hub-cache
+       uses: actions/cache@v4
+       with:
+         path: ${{ runner.temp }}/huggingface
+         key: huggingface-${{ matrix.os }}-${{ matrix.python }}-${{ matrix.transformers }}-${{ hashFiles('tests/version.txt') }}
      - name: Check quality
        run: |
          make style && make quality
+     - name: Check license
+       run: |
+         make license
+     - name: Check build
+       run: |
+         make build
      - name: Test with pytest
        run: |
          make test
+       env:
+         HF_HOME: ${{ runner.temp }}/huggingface
+         HF_HUB_OFFLINE: "${{ steps.hf-hub-cache.outputs.cache-hit == 'true' && '1' || '0' }}"

.gitignore

@@ -166,8 +166,8 @@ cython_debug/
 uv.lock

 # custom .gitignore
-ms_cache/
 hf_cache/
+ms_cache/
 om_cache/
 cache/
 config/
@@ -176,3 +176,4 @@ output/
 wandb/
 swanlog/
 generated_predictions.jsonl
+predictions_score.json

@@ -1,14 +1,17 @@
-.PHONY: build commit quality style test
+.PHONY: build commit license quality style test

 check_dirs := scripts src tests setup.py

 build:
-	pip install build && python -m build
+	pip3 install build && python3 -m build

 commit:
 	pre-commit install
 	pre-commit run --all-files

+license:
+	python3 tests/check_license.py $(check_dirs)
+
 quality:
 	ruff check $(check_dirs)
 	ruff format --check $(check_dirs)

README.md

@@ -5,8 +5,8 @@
 [![GitHub contributors](https://img.shields.io/github/contributors/hiyouga/LLaMA-Factory?color=orange)](https://github.com/hiyouga/LLaMA-Factory/graphs/contributors)
 [![GitHub workflow](https://github.com/hiyouga/LLaMA-Factory/actions/workflows/tests.yml/badge.svg)](https://github.com/hiyouga/LLaMA-Factory/actions/workflows/tests.yml)
 [![PyPI](https://img.shields.io/pypi/v/llamafactory)](https://pypi.org/project/llamafactory/)
-[![Citation](https://img.shields.io/badge/citation-349-green)](https://scholar.google.com/scholar?cites=12620864006390196564)
+[![Citation](https://img.shields.io/badge/citation-614-green)](https://scholar.google.com/scholar?cites=12620864006390196564)
-[![GitHub pull request](https://img.shields.io/badge/PRs-welcome-blue)](https://github.com/hiyouga/LLaMA-Factory/pulls)
+[![Docker Pulls](https://img.shields.io/docker/pulls/hiyouga/llamafactory)](https://hub.docker.com/r/hiyouga/llamafactory/tags)
 [![Twitter](https://img.shields.io/twitter/follow/llamafactory_ai)](https://twitter.com/llamafactory_ai)
 [![Discord](https://dcbadge.vercel.app/api/server/rKfvV9r9FK?compact=true&style=flat)](https://discord.gg/rKfvV9r9FK)
@@ -14,34 +14,48 @@
 [![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1eRTPn37ltBbYsISy9Aw2NuI2Aq5CQrD9?usp=sharing)
 [![Open in DSW](https://gallery.pai-ml.com/assets/open-in-dsw.svg)](https://gallery.pai-ml.com/#/preview/deepLearning/nlp/llama_factory)
-[![Spaces](https://img.shields.io/badge/🤗-Open%20in%20Spaces-blue)](https://huggingface.co/spaces/hiyouga/LLaMA-Board)
-[![Studios](https://img.shields.io/badge/ModelScope-Open%20in%20Studios-blue)](https://modelscope.cn/studios/hiyouga/LLaMA-Board)
-[![SageMaker](https://img.shields.io/badge/SageMaker-Open%20in%20AWS-blue)](https://aws.amazon.com/cn/blogs/china/a-one-stop-code-free-model-fine-tuning-deployment-platform-based-on-sagemaker-and-llama-factory/)
-<h3 align="center">
-Easily fine-tune 100+ large language models with zero-code <a href="#quickstart">CLI</a> and <a href="#fine-tuning-with-llama-board-gui-powered-by-gradio">Web UI</a>
-</h3>
-<p align="center">
-<picture>
-<img alt="Github trend" src="https://trendshift.io/api/badge/repositories/4535">
-</picture>
-</p>
-👋 Join our [WeChat](assets/wechat.jpg) or [NPU user group](assets/wechat_npu.jpg).
+[![Open in Alaya](assets/alaya_new.svg)](https://docs.alayanew.com/docs/documents/newActivities/llamafactory/?utm_source=LLaMA-Factory)
+[![Open in Spaces](https://img.shields.io/badge/🤗-Open%20in%20Spaces-blue)](https://huggingface.co/spaces/hiyouga/LLaMA-Board)
+[![Open in Studios](https://img.shields.io/badge/ModelScope-Open%20in%20Studios-blue)](https://modelscope.cn/studios/hiyouga/LLaMA-Board)
+[![Open in Novita](https://img.shields.io/badge/Novita-Deploy%20Template-blue)](https://novita.ai/templates-library/105981?sharer=88115474-394e-4bda-968e-b88e123d0c47)
+### Used by [Amazon](https://aws.amazon.com/cn/blogs/machine-learning/how-apoidea-group-enhances-visual-information-extraction-from-banking-documents-with-multimodal-models-using-llama-factory-on-amazon-sagemaker-hyperpod/), [NVIDIA](https://developer.nvidia.com/rtx/ai-toolkit), [Aliyun](https://help.aliyun.com/zh/pai/use-cases/fine-tune-a-llama-3-model-with-llama-factory), etc.
+<div align="center" markdown="1">
+### Supporters ❤️
+<a href="https://warp.dev/llama-factory">
+<img alt="Warp sponsorship" width="400" src="https://github.com/user-attachments/assets/ab8dd143-b0fd-4904-bdc5-dd7ecac94eae">
+</a>
+#### [Warp, the agentic terminal for developers](https://warp.dev/llama-factory)
+[Available for MacOS, Linux, & Windows](https://warp.dev/llama-factory)
+----
+### Easily fine-tune 100+ large language models with zero-code [CLI](#quickstart) and [Web UI](#fine-tuning-with-llama-board-gui-powered-by-gradio)
+![GitHub Trend](https://trendshift.io/api/badge/repositories/4535)
+</div>
+👋 Join our [WeChat group](assets/wechat.jpg), [NPU user group](assets/wechat_npu.jpg) or [Alaya NeW user group](assets/wechat_alaya.png).
 \[ English | [中文](README_zh.md) \]
 **Fine-tuning a large language model can be easy as...**
-https://github.com/user-attachments/assets/7c96b465-9df7-45f4-8053-bf03e58386d3
+https://github.com/user-attachments/assets/3991a3a8-4276-4d30-9cab-4cb0c4b9b99e
 Choose your path:
 - **Documentation**: https://llamafactory.readthedocs.io/en/latest/
 - **Colab (free)**: https://colab.research.google.com/drive/1eRTPn37ltBbYsISy9Aw2NuI2Aq5CQrD9?usp=sharing
 - **Local machine**: Please refer to [usage](#getting-started)
-- **PAI-DSW (free trial)**: [Llama3 Example](https://gallery.pai-ml.com/#/preview/deepLearning/nlp/llama_factory) | [Qwen2-VL Example](https://gallery.pai-ml.com/#/preview/deepLearning/nlp/llama_factory_qwen2vl) | [DeepSeek-R1-Distill Example](https://gallery.pai-ml.com/#/preview/deepLearning/nlp/llama_factory_deepseek_r1_distill_7b)
-- **Amazon SageMaker**: [Blog](https://aws.amazon.com/cn/blogs/china/a-one-stop-code-free-model-fine-tuning-deployment-platform-based-on-sagemaker-and-llama-factory/)
+- **PAI-DSW (free trial)**: https://gallery.pai-ml.com/#/preview/deepLearning/nlp/llama_factory
+- **Alaya NeW (cloud GPU deal)**: https://docs.alayanew.com/docs/documents/useGuide/LLaMAFactory/mutiple/?utm_source=LLaMA-Factory
 > [!NOTE]
 > Except for the above links, all other websites are unauthorized third-party websites. Please carefully use them.
@@ -49,7 +63,7 @@ Choose your path:
 ## Table of Contents
 - [Features](#features)
-- [Benchmark](#benchmark)
+- [Blogs](#blogs)
 - [Changelog](#changelog)
 - [Supported Models](#supported-models)
 - [Supported Training Approaches](#supported-training-approaches)
@@ -76,51 +90,67 @@ Choose your path:
 - **Various models**: LLaMA, LLaVA, Mistral, Mixtral-MoE, Qwen, Qwen2-VL, DeepSeek, Yi, Gemma, ChatGLM, Phi, etc.
 - **Integrated methods**: (Continuous) pre-training, (multimodal) supervised fine-tuning, reward modeling, PPO, DPO, KTO, ORPO, etc.
 - **Scalable resources**: 16-bit full-tuning, freeze-tuning, LoRA and 2/3/4/5/6/8-bit QLoRA via AQLM/AWQ/GPTQ/LLM.int8/HQQ/EETQ.
-- **Advanced algorithms**: [GaLore](https://github.com/jiaweizzhao/GaLore), [BAdam](https://github.com/Ledzy/BAdam), [APOLLO](https://github.com/zhuhanqing/APOLLO), [Adam-mini](https://github.com/zyushun/Adam-mini), DoRA, LongLoRA, LLaMA Pro, Mixture-of-Depths, LoRA+, LoftQ and PiSSA.
+- **Advanced algorithms**: [GaLore](https://github.com/jiaweizzhao/GaLore), [BAdam](https://github.com/Ledzy/BAdam), [APOLLO](https://github.com/zhuhanqing/APOLLO), [Adam-mini](https://github.com/zyushun/Adam-mini), [Muon](https://github.com/KellerJordan/Muon), DoRA, LongLoRA, LLaMA Pro, Mixture-of-Depths, LoRA+, LoftQ and PiSSA.
 - **Practical tricks**: [FlashAttention-2](https://github.com/Dao-AILab/flash-attention), [Unsloth](https://github.com/unslothai/unsloth), [Liger Kernel](https://github.com/linkedin/Liger-Kernel), RoPE scaling, NEFTune and rsLoRA.
 - **Wide tasks**: Multi-turn dialogue, tool using, image understanding, visual grounding, video recognition, audio understanding, etc.
-- **Experiment monitors**: LlamaBoard, TensorBoard, Wandb, MLflow, SwanLab, etc.
+- **Experiment monitors**: LlamaBoard, TensorBoard, Wandb, MLflow, [SwanLab](https://github.com/SwanHubX/SwanLab), etc.
-- **Faster inference**: OpenAI-style API, Gradio UI and CLI with vLLM worker.
+- **Faster inference**: OpenAI-style API, Gradio UI and CLI with [vLLM worker](https://github.com/vllm-project/vllm) or [SGLang worker](https://github.com/sgl-project/sglang).
 ### Day-N Support for Fine-Tuning Cutting-Edge Models
 | Support Date | Model Name |
-| ------------ | ---------------------------------------------------------- |
+| ------------ | ------------------------------------------------------------ |
-| Day 0 | Qwen2.5 / Qwen2-VL / QwQ / QvQ / InternLM3 / MiniCPM-o-2.6 |
+| Day 0 | Qwen3 / Qwen2.5-VL / Gemma 3 / InternLM 3 / MiniCPM-o-2.6 |
-| Day 1 | Llama 3 / GLM-4 / Mistral Small / PaliGemma2 |
+| Day 1 | Llama 3 / GLM-4 / Mistral Small / PaliGemma2 / Llama 4 |
-## Benchmark
+## Blogs
-Compared to ChatGLM's [P-Tuning](https://github.com/THUDM/ChatGLM2-6B/tree/main/ptuning), LLaMA Factory's LoRA tuning offers up to **3.7 times faster** training speed with a better Rouge score on the advertising text generation task. By leveraging 4-bit quantization technique, LLaMA Factory's QLoRA further improves the efficiency regarding the GPU memory.
+- [Fine-tune Qwen2.5-VL for Autonomous Driving using LLaMA-Factory](https://docs.alayanew.com/docs/documents/useGuide/LLaMAFactory/mutiple/?utm_source=LLaMA-Factory) (Chinese)
+- [How Apoidea Group enhances visual information extraction from banking documents with multimodal models using LLaMA-Factory on Amazon SageMaker HyperPod](https://aws.amazon.com/cn/blogs/machine-learning/how-apoidea-group-enhances-visual-information-extraction-from-banking-documents-with-multimodal-models-using-llama-factory-on-amazon-sagemaker-hyperpod/) (English)
+- [Easy Dataset × LLaMA Factory: Enabling LLMs to Efficiently Learn Domain Knowledge](https://buaa-act.feishu.cn/wiki/GVzlwYcRFiR8OLkHbL6cQpYin7g) (English)
-![benchmark](assets/benchmark.svg)
+<details><summary>All Blogs</summary>
-<details><summary>Definitions</summary>
+- [LLaMA Factory: Fine-tuning the DeepSeek-R1-Distill-Qwen-7B Model for News Classifier](https://gallery.pai-ml.com/#/preview/deepLearning/nlp/llama_factory_deepseek_r1_distill_7b) (Chinese)
+- [A One-Stop Code-Free Model Fine-Tuning \& Deployment Platform based on SageMaker and LLaMA-Factory](https://aws.amazon.com/cn/blogs/china/a-one-stop-code-free-model-fine-tuning-deployment-platform-based-on-sagemaker-and-llama-factory/) (Chinese)
-- **Training Speed**: the number of training samples processed per second during the training. (bs=4, cutoff_len=1024)
+- [LLaMA Factory Multi-Modal Fine-Tuning Practice: Fine-Tuning Qwen2-VL for Personal Tourist Guide](https://gallery.pai-ml.com/#/preview/deepLearning/nlp/llama_factory_qwen2vl) (Chinese)
-- **Rouge Score**: Rouge-2 score on the development set of the [advertising text generation](https://aclanthology.org/D19-1321.pdf) task. (bs=4, cutoff_len=1024)
+- [LLaMA Factory: Fine-tuning the LLaMA3 Model for Role-Playing](https://gallery.pai-ml.com/#/preview/deepLearning/nlp/llama_factory) (Chinese)
-- **GPU Memory**: Peak GPU memory usage in 4-bit quantized training. (bs=1, cutoff_len=1024)
-- We adopt `pre_seq_len=128` for ChatGLM's P-Tuning and `lora_rank=32` for LLaMA Factory's LoRA tuning.
 </details>
 ## Changelog
+[25/04/28] We supported fine-tuning the **[Qwen3](https://qwenlm.github.io/blog/qwen3/)** model family.
+[25/04/21] We supported the **[Muon](https://github.com/KellerJordan/Muon)** optimizer. See [examples](examples/README.md) for usage. Thank [@tianshijing](https://github.com/tianshijing)'s PR.
+[25/04/16] We supported fine-tuning the **[InternVL3](https://huggingface.co/OpenGVLab/InternVL3-8B)** model. See [PR #7258](https://github.com/hiyouga/LLaMA-Factory/pull/7258) to get started.
+[25/04/14] We supported fine-tuning the **[GLM-Z1](https://huggingface.co/THUDM/GLM-Z1-9B-0414)** and **[Kimi-VL](https://huggingface.co/moonshotai/Kimi-VL-A3B-Instruct)** models.
+[25/04/06] We supported fine-tuning the **[Llama 4](https://ai.meta.com/blog/llama-4-multimodal-intelligence/)** model. See [PR #7611](https://github.com/hiyouga/LLaMA-Factory/pull/7611) to get started.
+<details><summary>Full Changelog</summary>
+[25/03/31] We supported fine-tuning the **[Qwen2.5 Omni](https://qwenlm.github.io/blog/qwen2.5-omni/)** model. See [PR #7537](https://github.com/hiyouga/LLaMA-Factory/pull/7537) to get started.
+[25/03/15] We supported **[SGLang](https://github.com/sgl-project/sglang)** as inference backend. Try `infer_backend: sglang` to accelerate inference.
+[25/03/12] We supported fine-tuning the **[Gemma 3](https://huggingface.co/blog/gemma3)** model.
 [25/02/24] Announcing **[EasyR1](https://github.com/hiyouga/EasyR1)**, an efficient, scalable and multi-modality RL training framework for efficient GRPO training.
 [25/02/11] We supported saving the **[Ollama](https://github.com/ollama/ollama)** modelfile when exporting the model checkpoints. See [examples](examples/README.md) for usage.
 [25/02/05] We supported fine-tuning the **[Qwen2-Audio](Qwen/Qwen2-Audio-7B-Instruct)** and **[MiniCPM-o-2.6](https://huggingface.co/openbmb/MiniCPM-o-2_6)** on audio understanding tasks.
-[25/01/31] We supported fine-tuning the **[DeepSeek-R1](https://huggingface.co/deepseek-ai/DeepSeek-R1)** and **[Qwen2.5-VL](https://huggingface.co/Qwen/Qwen2.5-VL-7B-Instruct)** model.
+[25/01/31] We supported fine-tuning the **[DeepSeek-R1](https://huggingface.co/deepseek-ai/DeepSeek-R1)** and **[Qwen2.5-VL](https://huggingface.co/Qwen/Qwen2.5-VL-7B-Instruct)** models.
-<details><summary>Full Changelog</summary>
 [25/01/15] We supported **[APOLLO](https://arxiv.org/abs/2412.05270)** optimizer. See [examples](examples/README.md) for usage.
 [25/01/14] We supported fine-tuning the **[MiniCPM-o-2.6](https://huggingface.co/openbmb/MiniCPM-o-2_6)** and **[MiniCPM-V-2.6](https://huggingface.co/openbmb/MiniCPM-V-2_6)** models. Thank [@BUAADreamer](https://github.com/BUAADreamer)'s PR.
-[25/01/14] We supported fine-tuning the **[InternLM3](https://huggingface.co/collections/internlm/)** models. Thank [@hhaAndroid](https://github.com/hhaAndroid)'s PR.
+[25/01/14] We supported fine-tuning the **[InternLM 3](https://huggingface.co/collections/internlm/)** models. Thank [@hhaAndroid](https://github.com/hhaAndroid)'s PR.
 [25/01/10] We supported fine-tuning the **[Phi-4](https://huggingface.co/microsoft/phi-4)** model.
@@ -216,6 +246,9 @@ Compared to ChatGLM's [P-Tuning](https://github.com/THUDM/ChatGLM2-6B/tree/main/
 </details>
+> [!TIP]
+> If you cannot use the latest feature, please pull the latest code and install LLaMA-Factory again.
 ## Supported Models
 | Model | Model size | Template |
@@ -226,22 +259,28 @@ Compared to ChatGLM's [P-Tuning](https://github.com/THUDM/ChatGLM2-6B/tree/main/
 | [Command R](https://huggingface.co/CohereForAI) | 35B/104B | cohere |
 | [DeepSeek (Code/MoE)](https://huggingface.co/deepseek-ai) | 7B/16B/67B/236B | deepseek |
 | [DeepSeek 2.5/3](https://huggingface.co/deepseek-ai) | 236B/671B | deepseek3 |
-| [DeepSeek R1 (Distill)](https://huggingface.co/deepseek-ai) | 1.5B/7B/8B/14B/32B/70B/671B | deepseek3 |
+| [DeepSeek R1 (Distill)](https://huggingface.co/deepseek-ai) | 1.5B/7B/8B/14B/32B/70B/671B | deepseekr1 |
 | [Falcon](https://huggingface.co/tiiuae) | 7B/11B/40B/180B | falcon |
 | [Gemma/Gemma 2/CodeGemma](https://huggingface.co/google) | 2B/7B/9B/27B | gemma |
-| [GLM-4](https://huggingface.co/THUDM) | 9B | glm4 |
+| [Gemma 3](https://huggingface.co/google) | 1B/4B/12B/27B | gemma3/gemma (1B) |
+| [GLM-4/GLM-4-0414/GLM-Z1](https://huggingface.co/THUDM) | 9B/32B | glm4/glmz1 |
 | [GPT-2](https://huggingface.co/openai-community) | 0.1B/0.4B/0.8B/1.5B | - |
-| [Granite 3.0-3.1](https://huggingface.co/ibm-granite) | 1B/2B/3B/8B | granite3 |
+| [Granite 3.0-3.3](https://huggingface.co/ibm-granite) | 1B/2B/3B/8B | granite3 |
+| [Hunyuan](https://huggingface.co/tencent/) | 7B | hunyuan |
 | [Index](https://huggingface.co/IndexTeam) | 1.9B | index |
 | [InternLM 2-3](https://huggingface.co/internlm) | 7B/8B/20B | intern2 |
+| [InternVL 2.5-3](https://huggingface.co/OpenGVLab) | 1B/2B/8B/14B/38B/78B | intern_vl |
+| [Kimi-VL](https://huggingface.co/moonshotai) | 16B | kimi_vl |
 | [Llama](https://github.com/facebookresearch/llama) | 7B/13B/33B/65B | - |
 | [Llama 2](https://huggingface.co/meta-llama) | 7B/13B/70B | llama2 |
 | [Llama 3-3.3](https://huggingface.co/meta-llama) | 1B/3B/8B/70B | llama3 |
+| [Llama 4](https://huggingface.co/meta-llama) | 109B/402B | llama4 |
 | [Llama 3.2 Vision](https://huggingface.co/meta-llama) | 11B/90B | mllama |
 | [LLaVA-1.5](https://huggingface.co/llava-hf) | 7B/13B | llava |
 | [LLaVA-NeXT](https://huggingface.co/llava-hf) | 7B/8B/13B/34B/72B/110B | llava_next |
 | [LLaVA-NeXT-Video](https://huggingface.co/llava-hf) | 7B/34B | llava_next_video |
-| [MiniCPM](https://huggingface.co/openbmb) | 1B/2B/4B | cpm/cpm3 |
+| [MiMo](https://huggingface.co/XiaomiMiMo) | 7B | mimo |
+| [MiniCPM](https://huggingface.co/openbmb) | 0.5B/1B/2B/4B/8B | cpm/cpm3/cpm4 |
 | [MiniCPM-o-2.6/MiniCPM-V-2.6](https://huggingface.co/openbmb) | 8B | minicpm_o/minicpm_v |
 | [Ministral/Mistral-Nemo](https://huggingface.co/mistralai) | 8B/12B | ministral |
 | [Mistral/Mixtral](https://huggingface.co/mistralai) | 7B/8x7B/8x22B | mistral |
@@ -253,9 +292,12 @@ Compared to ChatGLM's [P-Tuning](https://github.com/THUDM/ChatGLM2-6B/tree/main/
 | [Phi-3-small](https://huggingface.co/microsoft) | 7B | phi_small |
 | [Phi-4](https://huggingface.co/microsoft) | 14B | phi4 |
 | [Pixtral](https://huggingface.co/mistralai) | 12B | pixtral |
-| [Qwen/QwQ (1-2.5) (Code/Math/MoE)](https://huggingface.co/Qwen) | 0.5B/1.5B/3B/7B/14B/32B/72B/110B | qwen |
+| [Qwen (1-2.5) (Code/Math/MoE/QwQ)](https://huggingface.co/Qwen) | 0.5B/1.5B/3B/7B/14B/32B/72B/110B | qwen |
+| [Qwen3 (MoE)](https://huggingface.co/Qwen) | 0.6B/1.7B/4B/8B/14B/32B/235B | qwen3 |
 | [Qwen2-Audio](https://huggingface.co/Qwen) | 7B | qwen2_audio |
+| [Qwen2.5-Omni](https://huggingface.co/Qwen) | 3B/7B | qwen2_omni |
-| [Qwen2-VL/Qwen2.5-VL/QVQ](https://huggingface.co/Qwen) | 2B/3B/7B/72B | qwen2_vl |
+| [Qwen2-VL/Qwen2.5-VL/QVQ](https://huggingface.co/Qwen) | 2B/3B/7B/32B/72B | qwen2_vl |
+| [Seed Coder](https://huggingface.co/ByteDance-Seed) | 8B | seed_coder |
 | [Skywork o1](https://huggingface.co/Skywork) | 8B | skywork_o1 |
 | [StarCoder 2](https://huggingface.co/bigcode) | 3B/7B/15B | - |
 | [TeleChat2](https://huggingface.co/Tele-AI) | 3B/7B/35B/115B | telechat2 |
@@ -268,6 +310,10 @@ Compared to ChatGLM's [P-Tuning](https://github.com/THUDM/ChatGLM2-6B/tree/main/
 > For the "base" models, the `template` argument can be chosen from `default`, `alpaca`, `vicuna` etc. But make sure to use the **corresponding template** for the "instruct/chat" models.
 >
 > Remember to use the **SAME** template in training and inference.
+>
+> \*: You should install the `transformers` from main branch and use `DISABLE_VERSION_CHECK=1` to skip version check.
+>
+> \*\*: You need to install a specific version of `transformers` to use the corresponding model.
 Please refer to [constants.py](src/llamafactory/extras/constants.py) for a full list of models we supported.
@@ -371,8 +417,10 @@ You also can add a custom chat template to [template.py](src/llamafactory/data/t
 - [DPO mixed (en&zh)](https://huggingface.co/datasets/hiyouga/DPO-En-Zh-20k)
 - [UltraFeedback (en)](https://huggingface.co/datasets/HuggingFaceH4/ultrafeedback_binarized)
+- [COIG-P (zh)](https://huggingface.co/datasets/m-a-p/COIG-P)
 - [RLHF-V (en)](https://huggingface.co/datasets/openbmb/RLHF-V-Dataset)
 - [VLFeedback (en)](https://huggingface.co/datasets/Zhihui/VLFeedback)
+- [RLAIF-V (en)](https://huggingface.co/datasets/openbmb/RLAIF-V-Dataset)
 - [Orca DPO Pairs (en)](https://huggingface.co/datasets/Intel/orca_dpo_pairs)
 - [HH-RLHF (en)](https://huggingface.co/datasets/Anthropic/hh-rlhf)
 - [Nectar (en)](https://huggingface.co/datasets/berkeley-nest/Nectar)
@@ -393,11 +441,12 @@ huggingface-cli login
| Mandatory | Minimum | Recommend | | Mandatory | Minimum | Recommend |
| ------------ | ------- | --------- | | ------------ | ------- | --------- |
| python | 3.9 | 3.10 | | python | 3.9 | 3.10 |
| torch | 1.13.1 | 2.5.1 | | torch | 2.0.0 | 2.6.0 |
| transformers | 4.41.2 | 4.49.0 | | torchvision | 0.15.0 | 0.21.0 |
| transformers | 4.45.0 | 4.50.0 |
| datasets | 2.16.0 | 3.2.0 | | datasets | 2.16.0 | 3.2.0 |
| accelerate | 0.34.0 | 1.2.1 | | accelerate | 0.34.0 | 1.2.1 |
| peft | 0.11.1 | 0.12.0 | | peft | 0.14.0 | 0.15.1 |
| trl | 0.8.6 | 0.9.6 | | trl | 0.8.6 | 0.9.6 |
| Optional | Minimum | Recommend | | Optional | Minimum | Recommend |
@@ -405,8 +454,8 @@ huggingface-cli login
| CUDA | 11.6 | 12.2 | | CUDA | 11.6 | 12.2 |
| deepspeed | 0.10.0 | 0.16.4 | | deepspeed | 0.10.0 | 0.16.4 |
| bitsandbytes | 0.39.0 | 0.43.1 | | bitsandbytes | 0.39.0 | 0.43.1 |
| vllm | 0.4.3 | 0.7.3 | | vllm | 0.4.3 | 0.8.2 |
| flash-attn | 2.3.0 | 2.7.2 | | flash-attn | 2.5.6 | 2.7.2 |
### Hardware Requirement ### Hardware Requirement
@@ -428,16 +477,27 @@ huggingface-cli login
> [!IMPORTANT] > [!IMPORTANT]
> Installation is mandatory. > Installation is mandatory.
#### Install from Source
```bash ```bash
git clone --depth 1 https://github.com/hiyouga/LLaMA-Factory.git git clone --depth 1 https://github.com/hiyouga/LLaMA-Factory.git
cd LLaMA-Factory cd LLaMA-Factory
pip install -e ".[torch,metrics]" pip install -e ".[torch,metrics]" --no-build-isolation
``` ```
Extra dependencies available: torch, torch-npu, metrics, deepspeed, liger-kernel, bitsandbytes, hqq, eetq, gptq, awq, aqlm, vllm, galore, apollo, badam, adam-mini, qwen, minicpm_v, modelscope, openmind, swanlab, quality Extra dependencies available: torch, torch-npu, metrics, deepspeed, liger-kernel, bitsandbytes, hqq, eetq, gptq, aqlm, vllm, sglang, galore, apollo, badam, adam-mini, qwen, minicpm_v, modelscope, openmind, swanlab, dev
> [!TIP] #### Install from Docker Image
> Use `pip install --no-deps -e .` to resolve package conflicts.
```bash
docker run -it --rm --gpus=all --ipc=host hiyouga/llamafactory:latest
```
This image is built on Ubuntu 22.04 (x86\_64), CUDA 12.4, Python 3.11, PyTorch 2.6.0, and Flash-attn 2.7.4.
Find the pre-built images: https://hub.docker.com/r/hiyouga/llamafactory/tags
Please refer to [build docker](#build-docker) to build the image yourself.
<details><summary>Setting up a virtual environment with <b>uv</b></summary> <details><summary>Setting up a virtual environment with <b>uv</b></summary>
@@ -457,6 +517,20 @@ uv run --prerelease=allow llamafactory-cli train examples/train_lora/llama3_lora
<details><summary>For Windows users</summary> <details><summary>For Windows users</summary>
#### Install PyTorch
You need to manually install the GPU version of PyTorch on the Windows platform. Please refer to the [official website](https://pytorch.org/get-started/locally/) and the following command to install PyTorch with CUDA support:
```bash
pip uninstall torch torchvision torchaudio
pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu126
python -c "import torch; print(torch.cuda.is_available())"
```
If you see `True` then you have successfully installed PyTorch with CUDA support.
Try `dataloader_num_workers: 0` if you encounter `Can't pickle local object` error.
#### Install BitsAndBytes #### Install BitsAndBytes
If you want to enable quantized LoRA (QLoRA) on the Windows platform, you need to install a pre-built version of the `bitsandbytes` library, which supports CUDA 11.1 to 12.2. Please select the appropriate [release version](https://github.com/jllllll/bitsandbytes-windows-webui/releases/tag/wheels) based on your CUDA version.
@@ -495,6 +569,7 @@ source /usr/local/Ascend/ascend-toolkit/set_env.sh
| torch     | 2.1.0  | 2.4.0       |
| torch-npu | 2.1.0  | 2.4.0.post2 |
| deepspeed | 0.13.2 | 0.13.2      |
| vllm-ascend | - | 0.7.3 |
Remember to use `ASCEND_RT_VISIBLE_DEVICES` instead of `CUDA_VISIBLE_DEVICES` to specify the device to use.
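For example, to run a job on the first two NPUs only (the YAML path is the standard LoRA example and merely illustrative):

```bash
# restrict the run to NPU devices 0 and 1
ASCEND_RT_VISIBLE_DEVICES=0,1 llamafactory-cli train examples/train_lora/llama3_lora_sft.yaml
```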
@@ -540,11 +615,13 @@ pip install .
### Data Preparation
Please refer to [data/README.md](data/README.md) for details about the format of dataset files. You can use datasets on the HuggingFace / ModelScope / Modelers hub, load a dataset from local disk, or specify a path to s3/gcs cloud storage.
> [!NOTE]
> Please update `data/dataset_info.json` to use your custom dataset.
You can also use **[Easy Dataset](https://github.com/ConardLi/easy-dataset)** or **[GraphGen](https://github.com/open-sciencelab/GraphGen)** to create synthetic data for fine-tuning.
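As a quick sketch of what registering a custom dataset involves (the dataset and file names below are made up; see [data/README.md](data/README.md) for the full schema), the entry you add to `data/dataset_info.json` looks roughly like this:

```bash
# sketch: the shape of a dataset_info.json entry for a local alpaca-format file;
# merge this object into data/dataset_info.json (names here are illustrative)
cat <<'EOF' > my_dataset_entry.json
{
  "my_dataset": {
    "file_name": "my_dataset.json",
    "formatting": "alpaca"
  }
}
EOF
```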
### Quickstart
Use the following 3 commands to run LoRA **fine-tuning**, **inference** and **merging** of the Llama3-8B-Instruct model, respectively.
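The three commands are the stock LoRA SFT examples shipped with the repository; the YAML paths below reflect the example configs at the time of writing and may differ in your checkout:

```bash
llamafactory-cli train examples/train_lora/llama3_lora_sft.yaml   # fine-tuning
llamafactory-cli chat examples/inference/llama3_lora_sft.yaml     # inference
llamafactory-cli export examples/merge_lora/llama3_lora_sft.yaml  # merging
```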
@@ -600,22 +677,13 @@ For CUDA users:
```bash
docker build -f ./docker/docker-cuda/Dockerfile \
    --build-arg PIP_INDEX=https://pypi.org/simple \
    --build-arg EXTRAS=metrics \
    -t llamafactory:latest .

docker run -dit --ipc=host --gpus=all \
    -p 7860:7860 \
    -p 8000:8000 \
    --name llamafactory \
    llamafactory:latest
@@ -625,19 +693,12 @@ docker exec -it llamafactory bash
For Ascend NPU users:
```bash
docker build -f ./docker/docker-npu/Dockerfile \
    --build-arg PIP_INDEX=https://pypi.org/simple \
    --build-arg EXTRAS=torch-npu,metrics \
    -t llamafactory:latest .

docker run -dit --ipc=host \
    -v /usr/local/dcmi:/usr/local/dcmi \
    -v /usr/local/bin/npu-smi:/usr/local/bin/npu-smi \
    -v /usr/local/Ascend/driver:/usr/local/Ascend/driver \
@@ -648,7 +709,6 @@ docker run -dit \
    --device /dev/davinci_manager \
    --device /dev/devmm_svm \
    --device /dev/hisi_hdc \
    --name llamafactory \
    llamafactory:latest
@@ -659,25 +719,15 @@ For AMD ROCm users:
```bash
docker build -f ./docker/docker-rocm/Dockerfile \
    --build-arg PIP_INDEX=https://pypi.org/simple \
    --build-arg EXTRAS=metrics \
    -t llamafactory:latest .

docker run -dit --ipc=host \
    -p 7860:7860 \
    -p 8000:8000 \
    --device /dev/kfd \
    --device /dev/dri \
    --name llamafactory \
    llamafactory:latest
@@ -686,12 +736,14 @@ docker exec -it llamafactory bash
</details>
<details><summary>Use Docker volumes</summary>

You can uncomment `VOLUME [ "/root/.cache/huggingface", "/app/shared_data", "/app/output" ]` in the Dockerfile to use data volumes.

When starting the container, use the `-v ./hf_cache:/root/.cache/huggingface` argument to mount a local directory into the container. The following data volumes are available.

- `hf_cache`: Utilize Hugging Face cache on the host machine.
- `shared_data`: The directory to store datasets on the host machine.
- `output`: Set export dir to this location so that the merged result can be accessed directly on the host machine.

</details>
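Putting the pieces together, a CUDA run with the three volumes mounted might look like this (host-side paths are arbitrary examples):

```bash
docker run -dit --ipc=host --gpus=all \
    -v ./hf_cache:/root/.cache/huggingface \
    -v ./shared_data:/app/shared_data \
    -v ./output:/app/output \
    -p 7860:7860 \
    -p 8000:8000 \
    --name llamafactory \
    llamafactory:latest
```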
@@ -699,7 +751,7 @@ docker exec -it llamafactory bash
### Deploy with OpenAI-style API and vLLM
```bash
API_PORT=8000 llamafactory-cli api examples/inference/llama3.yaml infer_backend=vllm vllm_enforce_eager=true
```
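Since the server speaks the OpenAI chat-completions protocol, a quick smoke test can be a plain HTTP request like the one below (a sketch; the `model` value is a placeholder):

```bash
curl http://localhost:8000/v1/chat/completions \
    -H "Content-Type: application/json" \
    -d '{"model": "llama3", "messages": [{"role": "user", "content": "Hello!"}]}'
```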
> [!TIP]
@@ -841,6 +893,7 @@ If you have a project that should be incorporated, please contact via email or c
1. Xia et al. Using Pre-trained Language Model for Accurate ESG Prediction. FinNLP 2024. [[paper]](https://aclanthology.org/2024.finnlp-2.1/)
1. Liang et al. I-SHEEP: Self-Alignment of LLM from Scratch through an Iterative Self-Enhancement Paradigm. 2024. [[arxiv]](https://arxiv.org/abs/2408.08072)
1. Bai et al. Aligning Large Language Model with Direct Multi-Preference Optimization for Recommendation. CIKM 2024. [[paper]](https://dl.acm.org/doi/10.1145/3627673.3679611)
1. Zhang et al. CPsyCoun: A Report-based Multi-turn Dialogue Reconstruction and Evaluation Framework for Chinese Psychological Counseling. ACL 2024. [[paper]](https://aclanthology.org/2024.findings-acl.830.pdf)
1. **[StarWhisper](https://github.com/Yu-Yang-Li/StarWhisper)**: A large language model for Astronomy, based on ChatGLM2-6B and Qwen-14B.
1. **[DISC-LawLLM](https://github.com/FudanDISC/DISC-LawLLM)**: A large language model specialized in the Chinese legal domain, based on Baichuan-13B, that is capable of retrieving and reasoning on legal knowledge.
1. **[Sunsimiao](https://github.com/X-D-Lab/Sunsimiao)**: A large language model specialized in the Chinese medical domain, based on Baichuan-7B and ChatGLM-6B.
@@ -854,14 +907,15 @@ If you have a project that should be incorporated, please contact via email or c
1. **[RAG-Retrieval](https://github.com/NLPJCL/RAG-Retrieval)**: A full pipeline for RAG retrieval model fine-tuning, inference, and distillation. [[blog]](https://zhuanlan.zhihu.com/p/987727357)
1. **[360-LLaMA-Factory](https://github.com/Qihoo360/360-LLaMA-Factory)**: A modified library that supports long sequence SFT & DPO using ring attention.
1. **[Sky-T1](https://novasky-ai.github.io/posts/sky-t1/)**: An o1-like model fine-tuned by NovaSky AI at very low cost.
1. **[WeClone](https://github.com/xming521/WeClone)**: One-stop solution for creating your digital avatar from chat logs.
1. **[EmoLLM](https://github.com/SmartFlowAI/EmoLLM)**: A project about large language models (LLMs) and mental health.
</details>
## License
This repository is licensed under the [Apache-2.0 License](LICENSE).
Please follow the model licenses to use the corresponding model weights: [Baichuan 2](https://huggingface.co/baichuan-inc/Baichuan2-7B-Base/blob/main/Community%20License%20for%20Baichuan%202%20Model.pdf) / [BLOOM](https://huggingface.co/spaces/bigscience/license) / [ChatGLM3](https://github.com/THUDM/ChatGLM3/blob/main/MODEL_LICENSE) / [Command R](https://cohere.com/c4ai-cc-by-nc-license) / [DeepSeek](https://github.com/deepseek-ai/DeepSeek-LLM/blob/main/LICENSE-MODEL) / [Falcon](https://huggingface.co/tiiuae/falcon-180B/blob/main/LICENSE.txt) / [Gemma](https://ai.google.dev/gemma/terms) / [GLM-4](https://huggingface.co/THUDM/glm-4-9b/blob/main/LICENSE) / [GPT-2](https://github.com/openai/gpt-2/blob/master/LICENSE) / [Granite](LICENSE) / [Index](https://huggingface.co/IndexTeam/Index-1.9B/blob/main/LICENSE) / [InternLM](https://github.com/InternLM/InternLM#license) / [Llama](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) / [Llama 2](https://ai.meta.com/llama/license/) / [Llama 3](https://llama.meta.com/llama3/license/) / [Llama 4](https://github.com/meta-llama/llama-models/blob/main/models/llama4/LICENSE) / [MiniCPM](https://github.com/OpenBMB/MiniCPM/blob/main/MiniCPM%20Model%20License.md) / [Mistral/Mixtral/Pixtral](LICENSE) / [OLMo](LICENSE) / [Phi-1.5/Phi-2](https://huggingface.co/microsoft/phi-1_5/resolve/main/Research%20License.docx) / [Phi-3/Phi-4](https://huggingface.co/microsoft/Phi-3-mini-4k-instruct/blob/main/LICENSE) / [Qwen](https://github.com/QwenLM/Qwen/blob/main/Tongyi%20Qianwen%20LICENSE%20AGREEMENT) / [Skywork](https://huggingface.co/Skywork/Skywork-13B-base/blob/main/Skywork%20Community%20License.pdf) / [StarCoder 2](https://huggingface.co/spaces/bigcode/bigcode-model-license-agreement) / [TeleChat2](https://huggingface.co/Tele-AI/telechat-7B/blob/main/TeleChat%E6%A8%A1%E5%9E%8B%E7%A4%BE%E5%8C%BA%E8%AE%B8%E5%8F%AF%E5%8D%8F%E8%AE%AE.pdf) / [XVERSE](https://github.com/xverse-ai/XVERSE-13B/blob/main/MODEL_LICENSE.pdf) / [Yi](https://huggingface.co/01-ai/Yi-6B/blob/main/LICENSE) / [Yi-1.5](LICENSE) / [Yuan 2](https://github.com/IEIT-Yuan/Yuan-2.0/blob/main/LICENSE-Yuan)
## Citation

@@ -5,8 +5,8 @@
[![GitHub contributors](https://img.shields.io/github/contributors/hiyouga/LLaMA-Factory?color=orange)](https://github.com/hiyouga/LLaMA-Factory/graphs/contributors)
[![GitHub workflow](https://github.com/hiyouga/LLaMA-Factory/actions/workflows/tests.yml/badge.svg)](https://github.com/hiyouga/LLaMA-Factory/actions/workflows/tests.yml)
[![PyPI](https://img.shields.io/pypi/v/llamafactory)](https://pypi.org/project/llamafactory/)
[![Citation](https://img.shields.io/badge/citation-614-green)](https://scholar.google.com/scholar?cites=12620864006390196564)
[![Docker Pulls](https://img.shields.io/docker/pulls/hiyouga/llamafactory)](https://hub.docker.com/r/hiyouga/llamafactory/tags)
[![Twitter](https://img.shields.io/twitter/follow/llamafactory_ai)](https://twitter.com/llamafactory_ai)
[![Discord](https://dcbadge.vercel.app/api/server/rKfvV9r9FK?compact=true&style=flat)](https://discord.gg/rKfvV9r9FK)
@@ -14,36 +14,50 @@
[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1d5KQtbemerlSDSxZIfAaWXhKr30QypiK?usp=sharing)
[![Open in DSW](https://gallery.pai-ml.com/assets/open-in-dsw.svg)](https://gallery.pai-ml.com/#/preview/deepLearning/nlp/llama_factory)
[![Open in Alaya](assets/alaya_new.svg)](https://docs.alayanew.com/docs/documents/newActivities/llamafactory/?utm_source=LLaMA-Factory)
[![Open in Spaces](https://img.shields.io/badge/🤗-Open%20in%20Spaces-blue)](https://huggingface.co/spaces/hiyouga/LLaMA-Board)
[![Open in Studios](https://img.shields.io/badge/ModelScope-Open%20in%20Studios-blue)](https://modelscope.cn/studios/hiyouga/LLaMA-Board)
[![Open in Novita](https://img.shields.io/badge/Novita-Deploy%20Template-blue)](https://novita.ai/templates-library/105981?sharer=88115474-394e-4bda-968e-b88e123d0c47)

### Used by [Amazon](https://aws.amazon.com/cn/blogs/china/a-one-stop-code-free-model-fine-tuning-deployment-platform-based-on-sagemaker-and-llama-factory/), [NVIDIA](https://developer.nvidia.cn/rtx/ai-toolkit), [Aliyun](https://help.aliyun.com/zh/pai/use-cases/fine-tune-a-llama-3-model-with-llama-factory) and more.

<div align="center" markdown="1">

### Sponsors ❤️

<a href="https://warp.dev/llama-factory">
    <img alt="Warp sponsorship" width="400" src="https://github.com/user-attachments/assets/ab8dd143-b0fd-4904-bdc5-dd7ecac94eae">
</a>

#### [Warp, the intelligent terminal for developers](https://warp.dev/llama-factory)

[Available for MacOS, Linux and Windows](https://warp.dev/llama-factory)

----

### Easily fine-tune 100+ large language models with zero-code [CLI](#快速开始) and [Web UI](#llama-board-可视化微调由-gradio-驱动)

![GitHub Trend](https://trendshift.io/api/badge/repositories/4535)

</div>

👋 Join our [WeChat group](assets/wechat.jpg), [NPU user group](assets/wechat_npu.jpg) or [Alaya NeW (cloud compute deal) group](assets/wechat_alaya.png).
\[ [English](README.md) | 中文 \]
**Fine-tuning a large language model can be easy as…**
https://github.com/user-attachments/assets/43b700c6-a178-41db-b1f8-8190a5d3fcfc
Choose your path:

- **Beginner tutorial**: https://zhuanlan.zhihu.com/p/695287607
- **Documentation**: https://llamafactory.readthedocs.io/zh-cn/latest/
- **Documentation (Ascend NPU)**: https://ascend.github.io/docs/sources/llamafactory/
- **Colab (free)**: https://colab.research.google.com/drive/1d5KQtbemerlSDSxZIfAaWXhKr30QypiK?usp=sharing
- **Local machine**: please refer to [usage](#如何使用)
- **PAI-DSW (free trial)**: https://gallery.pai-ml.com/#/preview/deepLearning/nlp/llama_factory
- **Alaya NeW (cloud compute deal)**: https://docs.alayanew.com/docs/documents/useGuide/LLaMAFactory/mutiple/?utm_source=LLaMA-Factory
> [!NOTE]
> Except for the links above, all other websites are unauthorized third-party websites. Please use them with caution.
@@ -51,7 +65,7 @@ https://github.com/user-attachments/assets/e6ce34b0-52d5-4f3e-a830-592106c4c272
## Table of Contents

- [Features](#项目特色)
- [Official Blog](#官方博客)
- [Changelog](#更新日志)
- [Supported Models](#模型)
- [Supported Training Approaches](#训练方法)
@@ -78,36 +92,54 @@ https://github.com/user-attachments/assets/e6ce34b0-52d5-4f3e-a830-592106c4c272
- **Various models**: LLaMA, LLaVA, Mistral, Mixtral-MoE, Qwen, Qwen2-VL, DeepSeek, Yi, Gemma, ChatGLM, Phi, etc.
- **Integrated methods**: (continuous) pre-training, multimodal supervised fine-tuning, reward modeling, PPO, DPO, KTO, ORPO, etc.
- **Various precisions**: 16-bit full-parameter fine-tuning, freeze-tuning, LoRA fine-tuning, and 2/3/4/5/6/8-bit QLoRA via AQLM/AWQ/GPTQ/LLM.int8/HQQ/EETQ.
- **Advanced algorithms**: [GaLore](https://github.com/jiaweizzhao/GaLore), [BAdam](https://github.com/Ledzy/BAdam), [APOLLO](https://github.com/zhuhanqing/APOLLO), [Adam-mini](https://github.com/zyushun/Adam-mini), [Muon](https://github.com/KellerJordan/Muon), DoRA, LongLoRA, LLaMA Pro, Mixture-of-Depths, LoRA+, LoftQ and PiSSA.
- **Practical tricks**: [FlashAttention-2](https://github.com/Dao-AILab/flash-attention), [Unsloth](https://github.com/unslothai/unsloth), [Liger Kernel](https://github.com/linkedin/Liger-Kernel), RoPE scaling, NEFTune and rsLoRA.
- **Wide tasks**: multi-turn dialogue, tool using, image understanding, visual grounding, video recognition, audio understanding, etc.
- **Experiment monitors**: LlamaBoard, TensorBoard, Wandb, MLflow, [SwanLab](https://github.com/SwanHubX/SwanLab), etc.
- **Faster inference**: OpenAI-style API, browser UI and CLI backed by [vLLM](https://github.com/vllm-project/vllm) or [SGLang](https://github.com/sgl-project/sglang).
### Day-N Support for Fine-Tuning Cutting-Edge Models

| Support Date | Model Name                                                 |
| ------------ | ---------------------------------------------------------- |
| Day 0        | Qwen3 / Qwen2.5-VL / Gemma 3 / InternLM 3 / MiniCPM-o-2.6   |
| Day 1        | Llama 3 / GLM-4 / Mistral Small / PaliGemma2 / Llama 4      |
## Official Blog

- [Fine-tune Qwen2.5-VL for Autonomous Driving Scenarios with LLaMA-Factory](https://docs.alayanew.com/docs/documents/useGuide/LLaMAFactory/mutiple/?utm_source=LLaMA-Factory) (Chinese)
- [Enhance Visual Information Extraction from Banking Documents with Multimodal Models Using LLaMA-Factory on Amazon SageMaker HyperPod](https://aws.amazon.com/cn/blogs/machine-learning/how-apoidea-group-enhances-visual-information-extraction-from-banking-documents-with-multimodal-models-using-llama-factory-on-amazon-sagemaker-hyperpod/) (English)
- [Easy Dataset × LLaMA Factory: Enabling LLMs to Efficiently Learn Domain Knowledge](https://buaa-act.feishu.cn/wiki/KY9xwTGs1iqHrRkjXBwcZP9WnL9) (Chinese)

<details><summary>All blogs</summary>

- [Fine-tune DeepSeek-R1-Distill-Qwen-7B with LLaMA Factory to Build a News-Headline Classifier](https://gallery.pai-ml.com/#/preview/deepLearning/nlp/llama_factory_deepseek_r1_distill_7b) (Chinese)
- [Build a One-Stop, Code-Free Model Fine-Tuning and Deployment Platform Based on Amazon SageMaker and LLaMA-Factory](https://aws.amazon.com/cn/blogs/china/a-one-stop-code-free-model-fine-tuning-deployment-platform-based-on-sagemaker-and-llama-factory/) (Chinese)
- [LLaMA Factory Multimodal Fine-Tuning Practice: Fine-tune Qwen2-VL to Build a Culture-and-Tourism Model](https://gallery.pai-ml.com/#/preview/deepLearning/nlp/llama_factory_qwen2vl) (Chinese)
- [Fine-tune LLaMA3 for Role-Playing with LLaMA Factory](https://gallery.pai-ml.com/#/preview/deepLearning/nlp/llama_factory) (Chinese)

</details>
## Changelog
[25/04/28] We supported fine-tuning the **[Qwen3](https://qwenlm.github.io/blog/qwen3/)** model family.

[25/04/21] We supported the **[Muon](https://github.com/KellerJordan/Muon)** optimizer. See [examples](examples/README_zh.md) for usage. Thank [@tianshijing](https://github.com/tianshijing)'s PR.

[25/04/16] We supported fine-tuning the **[InternVL3](https://huggingface.co/OpenGVLab/InternVL3-8B)** model. See [PR #7258](https://github.com/hiyouga/LLaMA-Factory/pull/7258) to get started.

[25/04/14] We supported fine-tuning the **[GLM-Z1](https://huggingface.co/THUDM/GLM-Z1-9B-0414)** and **[Kimi-VL](https://huggingface.co/moonshotai/Kimi-VL-A3B-Instruct)** models.

[25/04/06] We supported fine-tuning the **[Llama 4](https://ai.meta.com/blog/llama-4-multimodal-intelligence/)** model. See [PR #7611](https://github.com/hiyouga/LLaMA-Factory/pull/7611) to get started.

<details><summary>Full Changelog</summary>

[25/03/31] We supported fine-tuning the **[Qwen2.5 Omni](https://qwenlm.github.io/blog/qwen2.5-omni/)** model. See [PR #7537](https://github.com/hiyouga/LLaMA-Factory/pull/7537) to get started.

[25/03/15] We supported the **[SGLang](https://github.com/sgl-project/sglang)** inference backend. Use `infer_backend: sglang` to enable it.

[25/03/12] We supported fine-tuning the **[Gemma 3](https://huggingface.co/blog/gemma3)** model.

[25/02/24] We open-sourced **[EasyR1](https://github.com/hiyouga/EasyR1)**, an efficient and scalable framework for multimodal reinforcement learning with support for efficient GRPO training.

[25/02/11] We supported saving the **[Ollama](https://github.com/ollama/ollama)** modelfile when exporting models. See [examples](examples/README_zh.md) for usage.
@@ -116,13 +148,11 @@ https://github.com/user-attachments/assets/e6ce34b0-52d5-4f3e-a830-592106c4c272
[25/01/31] We supported fine-tuning the **[DeepSeek-R1](https://huggingface.co/deepseek-ai/DeepSeek-R1)** and **[Qwen2.5-VL](https://huggingface.co/Qwen/Qwen2.5-VL-7B-Instruct)** models.

[25/01/15] We supported the **[APOLLO](https://arxiv.org/abs/2412.05270)** optimizer. See [examples](examples/README_zh.md) for usage.

[25/01/14] We supported fine-tuning the **[MiniCPM-o-2.6](https://huggingface.co/openbmb/MiniCPM-o-2_6)** and **[MiniCPM-V-2.6](https://huggingface.co/openbmb/MiniCPM-V-2_6)** models. Thank [@BUAADreamer](https://github.com/BUAADreamer)'s PR.

[25/01/14] We supported fine-tuning the **[InternLM 3](https://huggingface.co/collections/internlm/)** model. Thank [@hhaAndroid](https://github.com/hhaAndroid)'s PR.

[25/01/10] We supported fine-tuning the **[Phi-4](https://huggingface.co/microsoft/phi-4)** model.
@@ -218,6 +248,9 @@ https://github.com/user-attachments/assets/e6ce34b0-52d5-4f3e-a830-592106c4c272
</details>
> [!TIP]
> If you cannot use the latest features, try pulling the latest code and reinstalling LLaMA-Factory.
## Supported Models
| Model | Model size | Template |
@@ -228,22 +261,28 @@ https://github.com/user-attachments/assets/e6ce34b0-52d5-4f3e-a830-592106c4c272
| [Command R](https://huggingface.co/CohereForAI) | 35B/104B | cohere |
| [DeepSeek (Code/MoE)](https://huggingface.co/deepseek-ai) | 7B/16B/67B/236B | deepseek |
| [DeepSeek 2.5/3](https://huggingface.co/deepseek-ai) | 236B/671B | deepseek3 |
| [DeepSeek R1 (Distill)](https://huggingface.co/deepseek-ai) | 1.5B/7B/8B/14B/32B/70B/671B | deepseekr1 |
| [Falcon](https://huggingface.co/tiiuae) | 7B/11B/40B/180B | falcon |
| [Gemma/Gemma 2/CodeGemma](https://huggingface.co/google) | 2B/7B/9B/27B | gemma |
| [Gemma 3](https://huggingface.co/google) | 1B/4B/12B/27B | gemma3/gemma (1B) |
| [GLM-4/GLM-4-0414/GLM-Z1](https://huggingface.co/THUDM) | 9B/32B | glm4/glmz1 |
| [GPT-2](https://huggingface.co/openai-community) | 0.1B/0.4B/0.8B/1.5B | - |
| [Granite 3.0-3.3](https://huggingface.co/ibm-granite) | 1B/2B/3B/8B | granite3 |
| [Hunyuan](https://huggingface.co/tencent/) | 7B | hunyuan |
| [Index](https://huggingface.co/IndexTeam) | 1.9B | index |
| [InternLM 2-3](https://huggingface.co/internlm) | 7B/8B/20B | intern2 |
| [InternVL 2.5-3](https://huggingface.co/OpenGVLab) | 1B/2B/8B/14B/38B/78B | intern_vl |
| [Kimi-VL](https://huggingface.co/moonshotai) | 16B | kimi_vl |
| [Llama](https://github.com/facebookresearch/llama) | 7B/13B/33B/65B | - |
| [Llama 2](https://huggingface.co/meta-llama) | 7B/13B/70B | llama2 |
| [Llama 3-3.3](https://huggingface.co/meta-llama) | 1B/3B/8B/70B | llama3 |
| [Llama 4](https://huggingface.co/meta-llama) | 109B/402B | llama4 |
| [Llama 3.2 Vision](https://huggingface.co/meta-llama) | 11B/90B | mllama |
| [LLaVA-1.5](https://huggingface.co/llava-hf) | 7B/13B | llava |
| [LLaVA-NeXT](https://huggingface.co/llava-hf) | 7B/8B/13B/34B/72B/110B | llava_next |
| [LLaVA-NeXT-Video](https://huggingface.co/llava-hf) | 7B/34B | llava_next_video |
| [MiMo](https://huggingface.co/XiaomiMiMo) | 7B | mimo |
| [MiniCPM](https://huggingface.co/openbmb) | 0.5B/1B/2B/4B/8B | cpm/cpm3/cpm4 |
| [MiniCPM-o-2.6/MiniCPM-V-2.6](https://huggingface.co/openbmb) | 8B | minicpm_o/minicpm_v |
| [Ministral/Mistral-Nemo](https://huggingface.co/mistralai) | 8B/12B | ministral |
| [Mistral/Mixtral](https://huggingface.co/mistralai) | 7B/8x7B/8x22B | mistral |
@@ -255,9 +294,12 @@ https://github.com/user-attachments/assets/e6ce34b0-52d5-4f3e-a830-592106c4c272
| [Phi-3-small](https://huggingface.co/microsoft) | 7B | phi_small |
| [Phi-4](https://huggingface.co/microsoft) | 14B | phi4 |
| [Pixtral](https://huggingface.co/mistralai) | 12B | pixtral |
| [Qwen (1-2.5) (Code/Math/MoE/QwQ)](https://huggingface.co/Qwen) | 0.5B/1.5B/3B/7B/14B/32B/72B/110B | qwen |
| [Qwen3 (MoE)](https://huggingface.co/Qwen) | 0.6B/1.7B/4B/8B/14B/32B/235B | qwen3 |
| [Qwen2-Audio](https://huggingface.co/Qwen) | 7B | qwen2_audio |
| [Qwen2.5-Omni](https://huggingface.co/Qwen) | 3B/7B | qwen2_omni |
| [Qwen2-VL/Qwen2.5-VL/QVQ](https://huggingface.co/Qwen) | 2B/3B/7B/32B/72B | qwen2_vl |
| [Seed Coder](https://huggingface.co/ByteDance-Seed) | 8B | seed_coder |
| [Skywork o1](https://huggingface.co/Skywork) | 8B | skywork_o1 |
| [StarCoder 2](https://huggingface.co/bigcode) | 3B/7B/15B | - |
| [TeleChat2](https://huggingface.co/Tele-AI) | 3B/7B/35B/115B | telechat2 |
@@ -270,6 +312,10 @@ https://github.com/user-attachments/assets/e6ce34b0-52d5-4f3e-a830-592106c4c272
> For the "base" models, the `template` argument can be chosen from `default`, `alpaca`, `vicuna`, etc. But make sure to use the **corresponding template** for the "instruct/chat" models.
>
> Remember to use the **SAME** template in training and inference.
>
> \*: You need to install `transformers` from the main branch and set `DISABLE_VERSION_CHECK=1` to skip the version check.
>
> \*\*: You need to install a specific version of `transformers` to use the corresponding model.
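For models marked with \*, the two requirements above usually amount to something like the following (the training config path is only an example):

```bash
# install transformers from the main branch, then skip LLaMA-Factory's version check
pip install "git+https://github.com/huggingface/transformers.git"
DISABLE_VERSION_CHECK=1 llamafactory-cli train examples/train_lora/llama3_lora_sft.yaml
```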
For a full list of supported models, please refer to [constants.py](src/llamafactory/extras/constants.py).
@@ -373,8 +419,10 @@ https://github.com/user-attachments/assets/e6ce34b0-52d5-4f3e-a830-592106c4c272
- [DPO mixed (en&zh)](https://huggingface.co/datasets/hiyouga/DPO-En-Zh-20k)
- [UltraFeedback (en)](https://huggingface.co/datasets/HuggingFaceH4/ultrafeedback_binarized)
- [COIG-P (zh)](https://huggingface.co/datasets/m-a-p/COIG-P)
- [RLHF-V (en)](https://huggingface.co/datasets/openbmb/RLHF-V-Dataset)
- [VLFeedback (en)](https://huggingface.co/datasets/Zhihui/VLFeedback)
- [RLAIF-V (en)](https://huggingface.co/datasets/openbmb/RLAIF-V-Dataset)
- [Orca DPO Pairs (en)](https://huggingface.co/datasets/Intel/orca_dpo_pairs)
- [HH-RLHF (en)](https://huggingface.co/datasets/Anthropic/hh-rlhf)
- [Nectar (en)](https://huggingface.co/datasets/berkeley-nest/Nectar)
@@ -392,23 +440,24 @@ huggingface-cli login
## Requirements
| Mandatory    | Minimum | Recommend |
| ------------ | ------- | --------- |
| python       | 3.9     | 3.10      |
| torch        | 2.0.0   | 2.6.0     |
| torchvision  | 0.15.0  | 0.21.0    |
| transformers | 4.45.0  | 4.50.0    |
| datasets     | 2.16.0  | 3.2.0     |
| accelerate   | 0.34.0  | 1.2.1     |
| peft         | 0.14.0  | 0.15.1    |
| trl          | 0.8.6   | 0.9.6     |

| Optional     | Minimum | Recommend |
| ------------ | ------- | --------- |
| CUDA         | 11.6    | 12.2      |
| deepspeed    | 0.10.0  | 0.16.4    |
| bitsandbytes | 0.39.0  | 0.43.1    |
| vllm         | 0.4.3   | 0.8.2     |
| flash-attn   | 2.5.6   | 2.7.2     |
### Hardware Requirement
@@ -430,16 +479,27 @@ huggingface-cli login
> [!IMPORTANT]
> Installation is mandatory.
#### Install from Source
```bash
git clone --depth 1 https://github.com/hiyouga/LLaMA-Factory.git
cd LLaMA-Factory
pip install -e ".[torch,metrics]" --no-build-isolation
```
Extra dependencies available: torch, torch-npu, metrics, deepspeed, liger-kernel, bitsandbytes, hqq, eetq, gptq, aqlm, vllm, sglang, galore, apollo, badam, adam-mini, qwen, minicpm_v, modelscope, openmind, swanlab, dev
#### Install from Docker Image
```bash
docker run -it --rm --gpus=all --ipc=host hiyouga/llamafactory:latest
```
This image is built on Ubuntu 22.04 (x86\_64), CUDA 12.4, Python 3.11, PyTorch 2.6.0, and Flash-attn 2.7.4.

Find all pre-built images at: https://hub.docker.com/r/hiyouga/llamafactory/tags

Please refer to [build docker](#构建-docker) to build the image yourself.
<details><summary>Setting up a virtual environment with <b>uv</b></summary>
@@ -457,9 +517,22 @@ uv run --prerelease=allow llamafactory-cli train examples/train_lora/llama3_lora
</details>
<details><summary>For Windows users</summary>
#### Install PyTorch

You need to manually install the GPU version of PyTorch on the Windows platform. Please refer to the [official website](https://pytorch.org/get-started/locally/) and use the following commands to install and verify PyTorch with CUDA support:
```bash
pip uninstall torch torchvision torchaudio
pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu126
python -c "import torch; print(torch.cuda.is_available())"
```
If you see `True`, PyTorch with CUDA support has been installed successfully.

If you encounter an error like `Can't pickle local object`, set `dataloader_num_workers: 0`.
#### Install BitsAndBytes
If you want to enable quantized LoRA (QLoRA) on the Windows platform, you need to install a pre-built version of the `bitsandbytes` library, which supports CUDA 11.1 to 12.2. Please select the appropriate [release version](https://github.com/jllllll/bitsandbytes-windows-webui/releases/tag/wheels) based on your CUDA version.
@@ -498,6 +571,7 @@ source /usr/local/Ascend/ascend-toolkit/set_env.sh
| torch     | 2.1.0  | 2.4.0       |
| torch-npu | 2.1.0  | 2.4.0.post2 |
| deepspeed | 0.13.2 | 0.13.2      |
| vllm-ascend | - | 0.7.3 |
Please use `ASCEND_RT_VISIBLE_DEVICES` instead of `CUDA_VISIBLE_DEVICES` to specify the device to use.
@@ -548,6 +622,8 @@ pip install .
> [!NOTE]
> Please update `data/dataset_info.json` to use your custom dataset.

You can also use **[Easy Dataset](https://github.com/ConardLi/easy-dataset)** or **[GraphGen](https://github.com/open-sciencelab/GraphGen)** to create synthetic data for fine-tuning.
### Quickstart

Use the following three commands to run LoRA **fine-tuning**, **inference**, and **merging** of the Llama3-8B-Instruct model, respectively.
@@ -603,22 +679,13 @@ CUDA 用户:
```bash
docker build -f ./docker/docker-cuda/Dockerfile \
    --build-arg PIP_INDEX=https://pypi.org/simple \
    --build-arg EXTRAS=metrics \
    -t llamafactory:latest .

docker run -dit --ipc=host --gpus=all \
    -p 7860:7860 \
    -p 8000:8000 \
    --name llamafactory \
    llamafactory:latest
@@ -628,19 +695,12 @@ docker exec -it llamafactory bash
For Ascend NPU users:
```bash
docker build -f ./docker/docker-npu/Dockerfile \
    --build-arg PIP_INDEX=https://pypi.org/simple \
    --build-arg EXTRAS=torch-npu,metrics \
    -t llamafactory:latest .

docker run -dit --ipc=host \
    -v /usr/local/dcmi:/usr/local/dcmi \
    -v /usr/local/bin/npu-smi:/usr/local/bin/npu-smi \
    -v /usr/local/Ascend/driver:/usr/local/Ascend/driver \
@@ -651,7 +711,6 @@ docker run -dit \
    --device /dev/davinci_manager \
    --device /dev/devmm_svm \
    --device /dev/hisi_hdc \
    --name llamafactory \
    llamafactory:latest
@@ -662,25 +721,15 @@ AMD ROCm 用户:
```bash
docker build -f ./docker/docker-rocm/Dockerfile \
    --build-arg PIP_INDEX=https://pypi.org/simple \
    --build-arg EXTRAS=metrics \
    -t llamafactory:latest .

docker run -dit --ipc=host \
    -p 7860:7860 \
    -p 8000:8000 \
    --device /dev/kfd \
    --device /dev/dri \
    --name llamafactory \
    llamafactory:latest
@@ -689,12 +738,14 @@ docker exec -it llamafactory bash
</details>
<details><summary>Use Docker volumes</summary>

You can uncomment `VOLUME [ "/root/.cache/huggingface", "/app/shared_data", "/app/output" ]` in the Dockerfile to use data volumes.

When starting the container, use the `-v ./hf_cache:/root/.cache/huggingface` argument to mount a local directory into the container. The available data volumes are listed below.

- `hf_cache`: Utilize Hugging Face cache on the host machine.
- `shared_data`: The directory to store datasets on the host machine.
- `output`: Set export dir to this location so that the merged result can be accessed directly on the host machine.

</details>
@@ -702,7 +753,7 @@ docker exec -it llamafactory bash
### Deploy with OpenAI-style API and vLLM
```bash
API_PORT=8000 llamafactory-cli api examples/inference/llama3.yaml infer_backend=vllm vllm_enforce_eager=true
```
> [!TIP]
@@ -857,6 +908,7 @@ swanlab_run_name: test_run # 可选
1. **[RAG-Retrieval](https://github.com/NLPJCL/RAG-Retrieval)**: A full pipeline for RAG retrieval model fine-tuning, inference, and distillation. [[blog]](https://zhuanlan.zhihu.com/p/987727357)
1. **[360-LLaMA-Factory](https://github.com/Qihoo360/360-LLaMA-Factory)**: A modified library that supports long sequence SFT & DPO using ring attention.
1. **[Sky-T1](https://novasky-ai.github.io/posts/sky-t1/)**: An o1-like model fine-tuned by NovaSky AI at very low cost.
1. **[WeClone](https://github.com/xming521/WeClone)**: One-stop solution for creating your digital avatar from chat logs.
</details>
@@ -864,7 +916,7 @@ swanlab_run_name: test_run # 可选
This repository is licensed under the [Apache-2.0 License](LICENSE).
Please follow the model licenses to use the corresponding model weights: [Baichuan 2](https://huggingface.co/baichuan-inc/Baichuan2-7B-Base/blob/main/Community%20License%20for%20Baichuan%202%20Model.pdf) / [BLOOM](https://huggingface.co/spaces/bigscience/license) / [ChatGLM3](https://github.com/THUDM/ChatGLM3/blob/main/MODEL_LICENSE) / [Command R](https://cohere.com/c4ai-cc-by-nc-license) / [DeepSeek](https://github.com/deepseek-ai/DeepSeek-LLM/blob/main/LICENSE-MODEL) / [Falcon](https://huggingface.co/tiiuae/falcon-180B/blob/main/LICENSE.txt) / [Gemma](https://ai.google.dev/gemma/terms) / [GLM-4](https://huggingface.co/THUDM/glm-4-9b/blob/main/LICENSE) / [GPT-2](https://github.com/openai/gpt-2/blob/master/LICENSE) / [Granite](LICENSE) / [Index](https://huggingface.co/IndexTeam/Index-1.9B/blob/main/LICENSE) / [InternLM](https://github.com/InternLM/InternLM#license) / [Llama](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) / [Llama 2](https://ai.meta.com/llama/license/) / [Llama 3](https://llama.meta.com/llama3/license/) / [Llama 4](https://github.com/meta-llama/llama-models/blob/main/models/llama4/LICENSE) / [MiniCPM](https://github.com/OpenBMB/MiniCPM/blob/main/MiniCPM%20Model%20License.md) / [Mistral/Mixtral/Pixtral](LICENSE) / [OLMo](LICENSE) / [Phi-1.5/Phi-2](https://huggingface.co/microsoft/phi-1_5/resolve/main/Research%20License.docx) / [Phi-3/Phi-4](https://huggingface.co/microsoft/Phi-3-mini-4k-instruct/blob/main/LICENSE) / [Qwen](https://github.com/QwenLM/Qwen/blob/main/Tongyi%20Qianwen%20LICENSE%20AGREEMENT) / [Skywork](https://huggingface.co/Skywork/Skywork-13B-base/blob/main/Skywork%20Community%20License.pdf) / [StarCoder 2](https://huggingface.co/spaces/bigcode/bigcode-model-license-agreement) / [TeleChat2](https://huggingface.co/Tele-AI/telechat-7B/blob/main/TeleChat%E6%A8%A1%E5%9E%8B%E7%A4%BE%E5%8C%BA%E8%AE%B8%E5%8F%AF%E5%8D%8F%E8%AE%AE.pdf) / [XVERSE](https://github.com/xverse-ai/XVERSE-13B/blob/main/MODEL_LICENSE.pdf) / [Yi](https://huggingface.co/01-ai/Yi-6B/blob/main/LICENSE) / [Yi-1.5](LICENSE) / [Yuan 2](https://github.com/IEIT-Yuan/Yuan-2.0/blob/main/LICENSE-Yuan)
## Citation

@@ -1,12 +1,15 @@
The [dataset_info.json](dataset_info.json) contains all available datasets. If you are using a custom dataset, please **make sure** to add a *dataset description* in `dataset_info.json` and specify `dataset: dataset_name` before training to use it.
The `dataset_info.json` file should be put in the `dataset_dir` directory. You can change `dataset_dir` to use another directory. The default value is `./data`.

Currently we support datasets in **alpaca** and **sharegpt** format. Allowed file types include json, jsonl, csv, parquet, and arrow.
```json
"dataset_name": {
  "hf_hub_url": "the name of the dataset repository on the Hugging Face hub. (if specified, ignore script_url, file_name and cloud_file_name)",
  "ms_hub_url": "the name of the dataset repository on the Model Scope hub. (if specified, ignore script_url, file_name and cloud_file_name)",
  "script_url": "the name of the directory containing a dataset loading script. (if specified, ignore file_name and cloud_file_name)",
  "cloud_file_name": "the name of the dataset file in s3/gcs cloud storage. (if specified, ignore file_name)",
  "file_name": "the name of the dataset folder or dataset file in this directory. (required if above are not specified)",
  "formatting": "the format of the dataset. (optional, default: alpaca, can be chosen from {alpaca, sharegpt})",
  "ranking": "whether the dataset is a preference dataset or not. (default: False)",
@@ -47,7 +50,9 @@ Currently we support datasets in **alpaca** and **sharegpt** format.
* [Example dataset](alpaca_en_demo.json)
In supervised fine-tuning, the `instruction` column will be concatenated with the `input` column and used as the user prompt, i.e. the user prompt will be `instruction\ninput`. The `output` column represents the model response.
For reasoning models, if the dataset contains chain-of-thought (CoT), the CoT needs to be placed in the model responses, such as `<think>cot</think>output`.
The `system` column will be used as the system prompt if specified.
@@ -56,13 +61,13 @@ The `history` column is a list consisting of string tuples representing prompt-r
```json
[
  {
    "instruction": "user instruction (required)",
    "input": "user input (optional)",
    "output": "model response (required)",
    "system": "system prompt (optional)",
    "history": [
      ["user instruction in the first round (optional)", "model response in the first round (optional)"],
      ["user instruction in the second round (optional)", "model response in the second round (optional)"]
    ]
  }
]
@@ -83,9 +88,14 @@ Regarding the above dataset, the *dataset description* in `dataset_info.json` sh
}
```
> [!TIP]
> If the model has reasoning capabilities (e.g. Qwen3) but the dataset does not contain chain-of-thought (CoT), LLaMA-Factory will automatically add empty CoT to the data. When `enable_thinking` is `True` (slow thinking, by default), the empty CoT will be added to the model responses and loss computation will be considered; otherwise (fast thinking), it will be added to the user prompts and loss computation will be ignored. Please keep the `enable_thinking` parameter consistent during training and inference.
>
> If you want to train data containing CoT with slow thinking and data without CoT with fast thinking, you can set `enable_thinking` to `None`. However, this feature is relatively complicated and should be used with caution.
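As a concrete illustration, `enable_thinking` is an ordinary argument and the CLI accepts `key=value` overrides, so keeping it consistent can be as simple as the sketch below (the YAML paths are hypothetical examples):

```bash
# use fast thinking consistently for both training and inference
llamafactory-cli train examples/train_lora/qwen3_lora_sft.yaml enable_thinking=false
llamafactory-cli chat examples/inference/qwen3_lora_sft.yaml enable_thinking=false
```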
### Pre-training Dataset
- [Example dataset](c4_demo.jsonl)
In pre-training, only the `text` column will be used for model learning.
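For instance, a pre-training file is just a collection of records carrying a `text` field; a minimal illustrative jsonl sample (content made up) can be created like this:

```bash
# sketch: one record of a pre-training dataset in jsonl form
cat <<'EOF' > my_pretrain_demo.jsonl
{"text": "Raw document text used for continued pre-training goes here."}
EOF
```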
@@ -116,8 +126,8 @@ It requires a better response in `chosen` column and a worse response in `reject
```json
[
  {
    "instruction": "user instruction (required)",
    "input": "user input (optional)",
    "chosen": "chosen answer (required)",
    "rejected": "rejected answer (required)"
  }
@@ -171,7 +181,7 @@ Note that the human and observation should appear in odd positions, while gpt an
"conversations": [ "conversations": [
{ {
"from": "human", "from": "human",
"value": "human instruction" "value": "user instruction"
}, },
{ {
"from": "function_call", "from": "function_call",
@@ -222,7 +232,7 @@ Preference datasets in sharegpt format also require a better message in `chosen`
"conversations": [ "conversations": [
{ {
"from": "human", "from": "human",
"value": "human instruction" "value": "user instruction"
}, },
{ {
"from": "gpt", "from": "gpt",
@@ -230,7 +240,7 @@ Preference datasets in sharegpt format also require a better message in `chosen`
      },
      {
        "from": "human",
        "value": "user instruction"
      }
    ],
    "chosen": {
@@ -272,7 +282,7 @@ KTO datasets require an extra `kto_tag` column containing the boolean human feedb
"conversations": [ "conversations": [
{ {
"from": "human", "from": "human",
"value": "human instruction" "value": "user instruction"
}, },
{ {
"from": "gpt", "from": "gpt",
@@ -311,7 +321,7 @@ The number of images should be identical to the `<image>` tokens in the conversa
"conversations": [ "conversations": [
{ {
"from": "human", "from": "human",
"value": "<image>human instruction" "value": "<image>user instruction"
}, },
{ {
"from": "gpt", "from": "gpt",
@@ -352,7 +362,7 @@ The number of videos should be identical to the `<video>` tokens in the conversa
"conversations": [ "conversations": [
{ {
"from": "human", "from": "human",
"value": "<video>human instruction" "value": "<video>user instruction"
}, },
{ {
"from": "gpt", "from": "gpt",
@@ -393,7 +403,7 @@ The number of audios should be identical to the `<audio>` tokens in the conversa
"conversations": [ "conversations": [
{ {
"from": "human", "from": "human",
"value": "<audio>human instruction" "value": "<audio>user instruction"
}, },
{ {
"from": "gpt", "from": "gpt",
@@ -434,7 +444,7 @@ The openai format is simply a special case of the sharegpt format, where the fir
}, },
{ {
"role": "user", "role": "user",
"content": "human instruction" "content": "user instruction"
}, },
{ {
"role": "assistant", "role": "assistant",

View File

@@ -1,6 +1,8 @@
[dataset_info.json](dataset_info.json) 包含了所有可用的数据集。如果您希望使用自定义数据集,请**务必**在 `dataset_info.json` 文件中添加*数据集描述*,并通过修改 `dataset: 数据集名称` 配置来使用数据集。 [dataset_info.json](dataset_info.json) 包含了所有可用的数据集。如果您希望使用自定义数据集,请**务必**在 `dataset_info.json` 文件中添加*数据集描述*,并通过修改 `dataset: 数据集名称` 配置来使用数据集。
目前我们支持 **alpaca** 格式和 **sharegpt** 格式的数据集。 其中 `dataset_info.json` 文件应放置在 `dataset_dir` 目录下。您可以通过修改 `dataset_dir` 参数来使用其他目录。默认值为 `./data`。
目前我们支持 **alpaca** 格式和 **sharegpt** 格式的数据集。允许的文件类型包括 json、jsonl、csv、parquet 和 arrow。
```json ```json
"数据集名称": { "数据集名称": {
@@ -47,7 +49,9 @@
- [样例数据集](alpaca_zh_demo.json) - [样例数据集](alpaca_zh_demo.json)
在指令监督微调时,`instruction` 列对应的内容会与 `input` 列对应的内容拼接后作为人类指令,即人类指令为 `instruction\ninput`。而 `output` 列对应的内容为模型回答。 在指令监督微调时,`instruction` 列对应的内容会与 `input` 列对应的内容拼接后作为提示词,即提示词为 `instruction\ninput`。而 `output` 列对应的内容为模型回答。
对于推理类模型的微调,如果数据集包含思维链,则需要把思维链放在模型回答中,例如 `<think>cot</think>output`
如果指定,`system` 列对应的内容将被作为系统提示词。 如果指定,`system` 列对应的内容将被作为系统提示词。
@@ -56,8 +60,8 @@
```json ```json
[ [
{ {
"instruction": "人类指令(必填)", "instruction": "用户指令(必填)",
"input": "人类输入(选填)", "input": "用户输入(选填)",
"output": "模型回答(必填)", "output": "模型回答(必填)",
"system": "系统提示词(选填)", "system": "系统提示词(选填)",
"history": [ "history": [
@@ -83,9 +87,14 @@
} }
``` ```
> [!TIP]
> 如果模型本身具备推理能力(如 Qwen3而数据集不包含思维链LLaMA-Factory 会自动为数据添加空思维链。当 `enable_thinking` 为 `True` 时(慢思考,默认),空思维链会添加到模型回答中并且计算损失,否则会添加到用户指令中并且不计算损失(快思考)。请在训练和推理时保持 `enable_thinking` 参数一致。
>
> 如果您希望训练包含思维链的数据时使用慢思考,训练不包含思维链的数据时使用快思考,可以设置 `enable_thinking` 为 `None`。但该功能较为复杂,请谨慎使用。
### 预训练数据集 ### 预训练数据集
- [样例数据集](c4_demo.json) - [样例数据集](c4_demo.jsonl)
在预训练时,只有 `text` 列中的内容会用于模型学习。 在预训练时,只有 `text` 列中的内容会用于模型学习。
@@ -116,8 +125,8 @@
```json ```json
[ [
{ {
"instruction": "人类指令(必填)", "instruction": "用户指令(必填)",
"input": "人类输入(选填)", "input": "用户输入(选填)",
"chosen": "优质回答(必填)", "chosen": "优质回答(必填)",
"rejected": "劣质回答(必填)" "rejected": "劣质回答(必填)"
} }
@@ -171,7 +180,7 @@ KTO 数据集需要提供额外的 `kto_tag` 列。详情请参阅 [sharegpt](#s
"conversations": [ "conversations": [
{ {
"from": "human", "from": "human",
"value": "人类指令" "value": "用户指令"
}, },
{ {
"from": "function_call", "from": "function_call",
@@ -222,7 +231,7 @@ Sharegpt 格式的偏好数据集同样需要在 `chosen` 列中提供更优的
"conversations": [ "conversations": [
{ {
"from": "human", "from": "human",
"value": "人类指令" "value": "用户指令"
}, },
{ {
"from": "gpt", "from": "gpt",
@@ -230,7 +239,7 @@ Sharegpt 格式的偏好数据集同样需要在 `chosen` 列中提供更优的
}, },
{ {
"from": "human", "from": "human",
"value": "人类指令" "value": "用户指令"
} }
], ],
"chosen": { "chosen": {
@@ -272,7 +281,7 @@ KTO 数据集需要额外添加一个 `kto_tag` 列,包含 bool 类型的人
"conversations": [ "conversations": [
{ {
"from": "human", "from": "human",
"value": "人类指令" "value": "用户指令"
}, },
{ {
"from": "gpt", "from": "gpt",
@@ -311,7 +320,7 @@ KTO 数据集需要额外添加一个 `kto_tag` 列,包含 bool 类型的人
"conversations": [ "conversations": [
{ {
"from": "human", "from": "human",
"value": "<image>人类指令" "value": "<image><image>用户指令"
}, },
{ {
"from": "gpt", "from": "gpt",
@@ -319,6 +328,7 @@ KTO 数据集需要额外添加一个 `kto_tag` 列,包含 bool 类型的人
} }
], ],
"images": [ "images": [
"图像路径(必填)",
"图像路径(必填)" "图像路径(必填)"
] ]
} }
@@ -352,7 +362,7 @@ KTO 数据集需要额外添加一个 `kto_tag` 列,包含 bool 类型的人
"conversations": [ "conversations": [
{ {
"from": "human", "from": "human",
"value": "<video>人类指令" "value": "<video><video>用户指令"
}, },
{ {
"from": "gpt", "from": "gpt",
@@ -360,6 +370,7 @@ KTO 数据集需要额外添加一个 `kto_tag` 列,包含 bool 类型的人
} }
], ],
"videos": [ "videos": [
"视频路径(必填)",
"视频路径(必填)" "视频路径(必填)"
] ]
} }
@@ -393,7 +404,7 @@ KTO 数据集需要额外添加一个 `kto_tag` 列,包含 bool 类型的人
"conversations": [ "conversations": [
{ {
"from": "human", "from": "human",
"value": "<audio>人类指令" "value": "<audio><audio>用户指令"
}, },
{ {
"from": "gpt", "from": "gpt",
@@ -401,6 +412,7 @@ KTO 数据集需要额外添加一个 `kto_tag` 列,包含 bool 类型的人
} }
], ],
"audios": [ "audios": [
"音频路径(必填)",
"音频路径(必填)" "音频路径(必填)"
] ]
} }
@@ -435,7 +447,7 @@ OpenAI 格式仅仅是 sharegpt 格式的一种特殊情况,其中第一条消
}, },
{ {
"role": "user", "role": "user",
"content": "人类指令" "content": "用户指令"
}, },
{ {
"role": "assistant", "role": "assistant",

View File

@@ -1,3 +1,18 @@
# Copyright 2025 the LlamaFactory team.
# Copyright 2020 The HuggingFace Datasets Authors and the current dataset script contributor.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json import json
import os import os
@@ -10,7 +25,7 @@ _DESCRIPTION = "BELLE multiturn chat dataset."
_CITATION = """\ _CITATION = """\
@article{belle2023exploring, @article{belle2023exploring,
title={Exploring the Impact of Instruction Data Scaling on Large Language Models: An Empirical Study on Real-World Use Cases}, title={Exploring the Impact of Instruction Data Scaling on Large Language Models},
author={Yunjie Ji, Yong Deng, Yan Gong, Yiping Peng, Qiang Niu, Lei Zhang, Baochang Ma, Xiangang Li}, author={Yunjie Ji, Yong Deng, Yan Gong, Yiping Peng, Qiang Niu, Lei Zhang, Baochang Ma, Xiangang Li},
journal={arXiv preprint arXiv:2303.14742}, journal={arXiv preprint arXiv:2303.14742},
year={2023} year={2023}

300
data/c4_demo.jsonl Normal file

File diff suppressed because one or more lines are too long

View File

@@ -1,6 +1,20 @@
# Copyright 2025 the LlamaFactory team.
# Copyright 2020 The HuggingFace Datasets Authors and the current dataset script contributor.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json import json
import os import os
from typing import List
import datasets import datasets
@@ -50,7 +64,7 @@ class HhRlhfEn(datasets.GeneratorBasedBuilder):
datasets.SplitGenerator(name=datasets.Split.TEST, gen_kwargs={"filepaths": file_path["test"]}), datasets.SplitGenerator(name=datasets.Split.TEST, gen_kwargs={"filepaths": file_path["test"]}),
] ]
def _generate_examples(self, filepaths: List[str]): def _generate_examples(self, filepaths: list[str]):
key = 0 key = 0
for filepath in filepaths: for filepath in filepaths:
with open(filepath, encoding="utf-8") as f: with open(filepath, encoding="utf-8") as f:

BIN
data/mllm_demo_data/4.mp3 Normal file

Binary file not shown.

BIN
data/mllm_demo_data/4.mp4 Normal file

Binary file not shown.

View File

@@ -1,6 +1,20 @@
# Copyright 2025 the LlamaFactory team.
# Copyright 2020 The HuggingFace Datasets Authors and the current dataset script contributor.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json import json
import os import os
from typing import List
import datasets import datasets
@@ -11,7 +25,7 @@ _DESCRIPTION = "UltraChat: Large-scale, Informative, and Diverse Multi-round Dia
_CITATION = """\ _CITATION = """\
@misc{UltraChat, @misc{UltraChat,
author = {Ding, Ning and Chen, Yulin and Xu, Bokai and Hu, Shengding and Qin, Yujia and Liu, Zhiyuan and Sun, Maosong and Zhou, Bowen}, author = {Ding, Ning and Chen, Yulin and Xu, Bokai and Hu, Shengding and others},
title = {UltraChat: A Large-scale Auto-generated Multi-round Dialogue Data}, title = {UltraChat: A Large-scale Auto-generated Multi-round Dialogue Data},
year = {2023}, year = {2023},
publisher = {GitHub}, publisher = {GitHub},
@@ -40,7 +54,7 @@ class UltraChat(datasets.GeneratorBasedBuilder):
file_paths = [dl_manager.download(_BASE_DATA_URL.format(idx=idx)) for idx in range(10)] # multiple shards file_paths = [dl_manager.download(_BASE_DATA_URL.format(idx=idx)) for idx in range(10)] # multiple shards
return [datasets.SplitGenerator(name=datasets.Split.TRAIN, gen_kwargs={"filepaths": file_paths})] return [datasets.SplitGenerator(name=datasets.Split.TRAIN, gen_kwargs={"filepaths": file_paths})]
def _generate_examples(self, filepaths: List[str]): def _generate_examples(self, filepaths: list[str]):
for filepath in filepaths: for filepath in filepaths:
with open(filepath, encoding="utf-8") as f: with open(filepath, encoding="utf-8") as f:
for row in f: for row in f:
@@ -49,7 +63,7 @@ class UltraChat(datasets.GeneratorBasedBuilder):
except Exception: except Exception:
continue continue
key: int = data["id"] key: int = data["id"]
content: List[str] = data["data"] content: list[str] = data["data"]
if len(content) % 2 == 1: if len(content) % 2 == 1:
content.pop(-1) content.pop(-1)
if len(content) < 2: if len(content) < 2:

View File

@@ -1,101 +1,66 @@
# Default use the NVIDIA official image with PyTorch 2.3.0 # https://hub.docker.com/r/hiyouga/pytorch/tags
# https://docs.nvidia.com/deeplearning/frameworks/pytorch-release-notes/index.html ARG BASE_IMAGE=hiyouga/pytorch:th2.6.0-cu124-flashattn2.7.4-cxx11abi0-devel
ARG BASE_IMAGE=nvcr.io/nvidia/pytorch:24.02-py3
FROM ${BASE_IMAGE} FROM ${BASE_IMAGE}
# Installation arguments
ARG PIP_INDEX=https://pypi.org/simple
ARG EXTRAS=metrics
ARG INSTALL_FLASHATTN=false
ARG HTTP_PROXY=""
# Define environments # Define environments
ENV MAX_JOBS=4 ENV MAX_JOBS=16
ENV FLASH_ATTENTION_FORCE_BUILD=TRUE ENV FLASH_ATTENTION_FORCE_BUILD=TRUE
ENV VLLM_WORKER_MULTIPROC_METHOD=spawn ENV VLLM_WORKER_MULTIPROC_METHOD=spawn
ENV DEBIAN_FRONTEND=noninteractive
ENV NODE_OPTIONS=""
ENV PIP_ROOT_USER_ACTION=ignore
ENV http_proxy="${HTTP_PROXY}"
ENV https_proxy="${HTTP_PROXY}"
# Define installation arguments # Use Bash instead of default /bin/sh
ARG INSTALL_BNB=false SHELL ["/bin/bash", "-c"]
ARG INSTALL_VLLM=false
ARG INSTALL_DEEPSPEED=false
ARG INSTALL_FLASHATTN=false
ARG INSTALL_LIGER_KERNEL=false
ARG INSTALL_HQQ=false
ARG INSTALL_EETQ=false
ARG PIP_INDEX=https://pypi.org/simple
ARG HTTP_PROXY=
# Set the working directory # Set the working directory
WORKDIR /app WORKDIR /app
# Set http proxy # Change pip source
RUN if [ -n "$HTTP_PROXY" ]; then \ RUN pip config set global.index-url "${PIP_INDEX}" && \
echo "Configuring proxy..."; \ pip config set global.extra-index-url "${PIP_INDEX}" && \
export http_proxy=$HTTP_PROXY; \ pip install --no-cache-dir --upgrade pip packaging wheel setuptools
export https_proxy=$HTTP_PROXY; \
fi
# Install the requirements # Install the requirements
COPY requirements.txt /app COPY requirements.txt /app
RUN pip config set global.index-url "$PIP_INDEX" && \ RUN pip install --no-cache-dir -r requirements.txt
pip config set global.extra-index-url "$PIP_INDEX" && \
python -m pip install --upgrade pip && \
if [ -n "$HTTP_PROXY" ]; then \
python -m pip install --proxy=$HTTP_PROXY -r requirements.txt; \
else \
python -m pip install -r requirements.txt; \
fi
# Copy the rest of the application into the image # Copy the rest of the application into the image
COPY . /app COPY . /app
# Install the LLaMA Factory # Install LLaMA Factory
RUN EXTRA_PACKAGES="metrics"; \ RUN pip install --no-cache-dir -e ".[${EXTRAS}]" --no-build-isolation
if [ "$INSTALL_BNB" == "true" ]; then \
EXTRA_PACKAGES="${EXTRA_PACKAGES},bitsandbytes"; \
fi; \
if [ "$INSTALL_VLLM" == "true" ]; then \
EXTRA_PACKAGES="${EXTRA_PACKAGES},vllm"; \
fi; \
if [ "$INSTALL_DEEPSPEED" == "true" ]; then \
EXTRA_PACKAGES="${EXTRA_PACKAGES},deepspeed"; \
fi; \
if [ "$INSTALL_LIGER_KERNEL" == "true" ]; then \
EXTRA_PACKAGES="${EXTRA_PACKAGES},liger-kernel"; \
fi; \
if [ "$INSTALL_HQQ" == "true" ]; then \
EXTRA_PACKAGES="${EXTRA_PACKAGES},hqq"; \
fi; \
if [ "$INSTALL_EETQ" == "true" ]; then \
EXTRA_PACKAGES="${EXTRA_PACKAGES},eetq"; \
fi; \
if [ -n "$HTTP_PROXY" ]; then \
pip install --proxy=$HTTP_PROXY -e ".[$EXTRA_PACKAGES]"; \
else \
pip install -e ".[$EXTRA_PACKAGES]"; \
fi
# Rebuild flash attention # Rebuild flash attention
RUN pip uninstall -y transformer-engine flash-attn && \ RUN if [ "${INSTALL_FLASHATTN}" == "true" ]; then \
if [ "$INSTALL_FLASHATTN" == "true" ]; then \
pip uninstall -y ninja && \ pip uninstall -y ninja && \
if [ -n "$HTTP_PROXY" ]; then \ pip install --no-cache-dir ninja && \
pip install --proxy=$HTTP_PROXY ninja && \ pip install --no-cache-dir flash-attn --no-build-isolation; \
pip install --proxy=$HTTP_PROXY --no-cache-dir flash-attn --no-build-isolation; \
else \
pip install ninja && \
pip install --no-cache-dir flash-attn --no-build-isolation; \
fi; \
fi
# Unset http proxy
RUN if [ -n "$HTTP_PROXY" ]; then \
unset http_proxy; \
unset https_proxy; \
fi fi
# Set up volumes # Set up volumes
VOLUME [ "/root/.cache/huggingface", "/root/.cache/modelscope", "/app/data", "/app/output" ] # VOLUME [ "/root/.cache/huggingface", "/app/shared_data", "/app/output" ]
# Expose port 7860 for the LLaMA Board # Expose port 7860 for LLaMA Board
ENV GRADIO_SERVER_PORT 7860 ENV GRADIO_SERVER_PORT=7860
EXPOSE 7860 EXPOSE 7860
# Expose port 8000 for the API service # Expose port 8000 for API service
ENV API_PORT 8000 ENV API_PORT=8000
EXPOSE 8000 EXPOSE 8000
# unset proxy
ENV http_proxy=
ENV https_proxy=
# Reset pip config
RUN pip config unset global.index-url && \
pip config unset global.extra-index-url
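The rewritten Dockerfile replaces the per-feature `INSTALL_*` switches with a single `EXTRAS` build argument. A hedged build invocation from the repository root (the image tag and extras list below are illustrative, not prescribed by the Dockerfile):

```bash
docker build -f docker/docker-cuda/Dockerfile \
  --build-arg EXTRAS="metrics,deepspeed" \
  --build-arg INSTALL_FLASHATTN=true \
  -t llamafactory:latest .
```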

View File

@@ -0,0 +1,55 @@
# Start from the pytorch official image (ubuntu-22.04 + cuda-12.4.1 + python-3.11)
# https://hub.docker.com/r/pytorch/pytorch/tags
FROM pytorch/pytorch:2.6.0-cuda12.4-cudnn9-devel
# Define environments
ENV MAX_JOBS=16
ENV VLLM_WORKER_MULTIPROC_METHOD=spawn
ENV DEBIAN_FRONTEND=noninteractive
ENV NODE_OPTIONS=""
ENV PIP_ROOT_USER_ACTION=ignore
# Define installation arguments
ARG APT_SOURCE=https://mirrors.tuna.tsinghua.edu.cn/ubuntu/
ARG PIP_INDEX=https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple
# Set apt source
RUN cp /etc/apt/sources.list /etc/apt/sources.list.bak && \
{ \
echo "deb ${APT_SOURCE} jammy main restricted universe multiverse"; \
echo "deb ${APT_SOURCE} jammy-updates main restricted universe multiverse"; \
echo "deb ${APT_SOURCE} jammy-backports main restricted universe multiverse"; \
echo "deb ${APT_SOURCE} jammy-security main restricted universe multiverse"; \
} > /etc/apt/sources.list
# Install systemctl and wget
RUN apt-get update && \
apt-get install -y -o Dpkg::Options::="--force-confdef" systemd wget && \
apt-get clean
# Install git and vim
RUN apt-get update && \
apt-get install -y git vim && \
apt-get clean
# Install gcc and g++
RUN apt-get update && \
apt-get install -y gcc g++ && \
apt-get clean
# Change pip source
RUN pip config set global.index-url "${PIP_INDEX}" && \
pip config set global.extra-index-url "${PIP_INDEX}" && \
pip install --no-cache-dir --upgrade pip packaging wheel setuptools
# Install flash-attn-2.7.4.post1 (cxx11abi=False)
RUN wget -nv https://github.com/Dao-AILab/flash-attention/releases/download/v2.7.4.post1/flash_attn-2.7.4.post1+cu12torch2.6cxx11abiFALSE-cp311-cp311-linux_x86_64.whl && \
pip install --no-cache-dir flash_attn-2.7.4.post1+cu12torch2.6cxx11abiFALSE-cp311-cp311-linux_x86_64.whl
# Install flashinfer-0.2.2.post1+cu124 (cxx11abi=False)
RUN wget -nv https://github.com/flashinfer-ai/flashinfer/releases/download/v0.2.2.post1/flashinfer_python-0.2.2.post1+cu124torch2.6-cp38-abi3-linux_x86_64.whl && \
pip install --no-cache-dir flashinfer_python-0.2.2.post1+cu124torch2.6-cp38-abi3-linux_x86_64.whl
# Reset pip config
RUN pip config unset global.index-url && \
pip config unset global.extra-index-url

View File

@@ -4,27 +4,15 @@ services:
dockerfile: ./docker/docker-cuda/Dockerfile dockerfile: ./docker/docker-cuda/Dockerfile
context: ../.. context: ../..
args: args:
INSTALL_BNB: "false"
INSTALL_VLLM: "false"
INSTALL_DEEPSPEED: "false"
INSTALL_FLASHATTN: "false"
INSTALL_LIGER_KERNEL: "false"
INSTALL_HQQ: "false"
INSTALL_EETQ: "false"
PIP_INDEX: https://pypi.org/simple PIP_INDEX: https://pypi.org/simple
EXTRAS: metrics
container_name: llamafactory container_name: llamafactory
volumes:
- ../../hf_cache:/root/.cache/huggingface
- ../../ms_cache:/root/.cache/modelscope
- ../../om_cache:/root/.cache/openmind
- ../../data:/app/data
- ../../output:/app/output
ports: ports:
- "7860:7860" - "7860:7860"
- "8000:8000" - "8000:8000"
ipc: host ipc: host
tty: true tty: true
shm_size: "16gb" # shm_size: "16gb" # ipc: host is set
stdin_open: true stdin_open: true
command: bash command: bash
deploy: deploy:
@@ -33,5 +21,5 @@ services:
devices: devices:
- driver: nvidia - driver: nvidia
count: "all" count: "all"
capabilities: [gpu] capabilities: [ gpu ]
restart: unless-stopped restart: unless-stopped
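With the build arguments reduced to `PIP_INDEX` and `EXTRAS`, bringing the container up stays a two-step affair. A sketch, assuming the compose file sits next to the Dockerfile under `docker/docker-cuda/`:

```bash
cd docker/docker-cuda
docker compose up -d --build
docker compose exec llamafactory bash
```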

View File

@@ -1,67 +1,58 @@
# Use the Ubuntu 22.04 image with CANN 8.0.rc1 # https://hub.docker.com/r/ascendai/cann/tags
# More versions can be found at https://hub.docker.com/r/ascendai/cann/tags ARG BASE_IMAGE=ascendai/cann:8.0.0-910b-ubuntu22.04-py3.11
# FROM ascendai/cann:8.0.rc1-910-ubuntu22.04-py3.8 FROM ${BASE_IMAGE}
FROM ascendai/cann:8.0.0-910b-ubuntu22.04-py3.10
# FROM ascendai/cann:8.0.rc1-910-openeuler22.03-py3.8 # Installation arguments
# FROM ascendai/cann:8.0.rc1-910b-openeuler22.03-py3.8 ARG PIP_INDEX=https://pypi.org/simple
ARG EXTRAS=torch-npu,metrics
ARG HTTP_PROXY=""
# Define environments # Define environments
ENV MAX_JOBS=16
ENV FLASH_ATTENTION_FORCE_BUILD=TRUE
ENV VLLM_WORKER_MULTIPROC_METHOD=spawn
ENV DEBIAN_FRONTEND=noninteractive ENV DEBIAN_FRONTEND=noninteractive
ENV NODE_OPTIONS=""
ENV PIP_ROOT_USER_ACTION=ignore
ENV http_proxy="${HTTP_PROXY}"
ENV https_proxy="${HTTP_PROXY}"
# Define installation arguments # Use Bash instead of default /bin/sh
ARG INSTALL_DEEPSPEED=false SHELL ["/bin/bash", "-c"]
ARG PIP_INDEX=https://pypi.org/simple
ARG TORCH_INDEX=https://download.pytorch.org/whl/cpu
ARG HTTP_PROXY=
# Set the working directory # Set the working directory
WORKDIR /app WORKDIR /app
# Set http proxy # Change pip source
RUN if [ -n "$HTTP_PROXY" ]; then \ RUN pip config set global.index-url "${PIP_INDEX}" && \
echo "Configuring proxy..."; \ pip config set global.extra-index-url "${PIP_INDEX}" && \
export http_proxy=$HTTP_PROXY; \ pip install --no-cache-dir --upgrade pip packaging wheel setuptools
export https_proxy=$HTTP_PROXY; \
fi
# Install the requirements # Install the requirements
COPY requirements.txt /app COPY requirements.txt /app
RUN pip config set global.index-url "$PIP_INDEX" && \ RUN pip install --no-cache-dir -r requirements.txt
pip config set global.extra-index-url "$TORCH_INDEX" && \
python -m pip install --upgrade pip && \
if [ -n "$HTTP_PROXY" ]; then \
python -m pip install --proxy=$HTTP_PROXY -r requirements.txt; \
else \
python -m pip install -r requirements.txt; \
fi
# Copy the rest of the application into the image # Copy the rest of the application into the image
COPY . /app COPY . /app
# Install the LLaMA Factory # Install LLaMA Factory
RUN EXTRA_PACKAGES="torch-npu,metrics"; \ RUN pip install --no-cache-dir -e ".[${EXTRAS}]" --no-build-isolation
if [ "$INSTALL_DEEPSPEED" == "true" ]; then \
EXTRA_PACKAGES="${EXTRA_PACKAGES},deepspeed"; \
fi; \
if [ -n "$HTTP_PROXY" ]; then \
pip install --proxy=$HTTP_PROXY -e ".[$EXTRA_PACKAGES]"; \
else \
pip install -e ".[$EXTRA_PACKAGES]"; \
fi
# Unset http proxy
RUN if [ -n "$HTTP_PROXY" ]; then \
unset http_proxy; \
unset https_proxy; \
fi
# Set up volumes # Set up volumes
VOLUME [ "/root/.cache/huggingface", "/root/.cache/modelscope", "/app/data", "/app/output" ] # VOLUME [ "/root/.cache/huggingface", "/app/shared_data", "/app/output" ]
# Expose port 7860 for the LLaMA Board # Expose port 7860 for LLaMA Board
ENV GRADIO_SERVER_PORT 7860 ENV GRADIO_SERVER_PORT=7860
EXPOSE 7860 EXPOSE 7860
# Expose port 8000 for the API service # Expose port 8000 for API service
ENV API_PORT 8000 ENV API_PORT=8000
EXPOSE 8000 EXPOSE 8000
# unset proxy
ENV http_proxy=
ENV https_proxy=
# Reset pip config
RUN pip config unset global.index-url && \
pip config unset global.extra-index-url

View File

@@ -4,15 +4,10 @@ services:
dockerfile: ./docker/docker-npu/Dockerfile dockerfile: ./docker/docker-npu/Dockerfile
context: ../.. context: ../..
args: args:
INSTALL_DEEPSPEED: "false"
PIP_INDEX: https://pypi.org/simple PIP_INDEX: https://pypi.org/simple
EXTRAS: torch-npu,metrics
container_name: llamafactory container_name: llamafactory
volumes: volumes:
- ../../hf_cache:/root/.cache/huggingface
- ../../ms_cache:/root/.cache/modelscope
- ../../om_cache:/root/.cache/openmind
- ../../data:/app/data
- ../../output:/app/output
- /usr/local/dcmi:/usr/local/dcmi - /usr/local/dcmi:/usr/local/dcmi
- /usr/local/bin/npu-smi:/usr/local/bin/npu-smi - /usr/local/bin/npu-smi:/usr/local/bin/npu-smi
- /usr/local/Ascend/driver:/usr/local/Ascend/driver - /usr/local/Ascend/driver:/usr/local/Ascend/driver
@@ -22,7 +17,7 @@ services:
- "8000:8000" - "8000:8000"
ipc: host ipc: host
tty: true tty: true
shm_size: "16gb" # shm_size: "16gb" # ipc: host is set
stdin_open: true stdin_open: true
command: bash command: bash
devices: devices:

View File

@@ -1,93 +1,71 @@
FROM hardandheavy/transformers-rocm:2.2.0 # https://hub.docker.com/r/rocm/pytorch/tags
ARG BASE_IMAGE=rocm/pytorch:rocm6.4.1_ubuntu22.04_py3.10_pytorch_release_2.6.0
FROM ${BASE_IMAGE}
# Installation arguments
ARG PIP_INDEX=https://pypi.org/simple
ARG EXTRAS=metrics
ARG INSTALL_FLASHATTN=false
ARG HTTP_PROXY=""
ARG PYTORCH_INDEX=https://download.pytorch.org/whl/rocm6.3
# Define environments # Define environments
ENV MAX_JOBS=4 ENV MAX_JOBS=16
ENV FLASH_ATTENTION_FORCE_BUILD=TRUE ENV FLASH_ATTENTION_FORCE_BUILD=TRUE
ENV VLLM_WORKER_MULTIPROC_METHOD=spawn ENV VLLM_WORKER_MULTIPROC_METHOD=spawn
ENV DEBIAN_FRONTEND=noninteractive
ENV NODE_OPTIONS=""
ENV PIP_ROOT_USER_ACTION=ignore
ENV http_proxy="${HTTP_PROXY}"
ENV https_proxy="${HTTP_PROXY}"
# Define installation arguments # Use Bash instead of default /bin/sh
ARG INSTALL_BNB=false SHELL ["/bin/bash", "-c"]
ARG INSTALL_VLLM=false
ARG INSTALL_DEEPSPEED=false
ARG INSTALL_FLASHATTN=false
ARG INSTALL_LIGER_KERNEL=false
ARG INSTALL_HQQ=false
ARG PIP_INDEX=https://pypi.org/simple
ARG HTTP_PROXY=
# Set the working directory # Set the working directory
WORKDIR /app WORKDIR /app
# Set http proxy # Change pip source
RUN if [ -n "$HTTP_PROXY" ]; then \ RUN pip config set global.index-url "${PIP_INDEX}" && \
echo "Configuring proxy..."; \ pip config set global.extra-index-url "${PIP_INDEX}" && \
export http_proxy=$HTTP_PROXY; \ pip install --no-cache-dir --upgrade pip packaging wheel setuptools
export https_proxy=$HTTP_PROXY; \
fi # Reinstall pytorch rocm
RUN pip uninstall -y torch torchvision torchaudio && \
pip install --no-cache-dir --pre torch torchvision torchaudio --index-url "${PYTORCH_INDEX}"
# Install the requirements # Install the requirements
COPY requirements.txt /app COPY requirements.txt /app
RUN pip config set global.index-url "$PIP_INDEX" && \ RUN pip install --no-cache-dir -r requirements.txt
pip config set global.extra-index-url "$PIP_INDEX" && \
python -m pip install --upgrade pip && \
if [ -n "$HTTP_PROXY" ]; then \
python -m pip install --proxy=$HTTP_PROXY -r requirements.txt; \
else \
python -m pip install -r requirements.txt; \
fi
# Copy the rest of the application into the image # Copy the rest of the application into the image
COPY . /app COPY . /app
# Install the LLaMA Factory # Install LLaMA Factory
RUN EXTRA_PACKAGES="metrics"; \ RUN pip install --no-cache-dir -e ".[${EXTRAS}]" --no-build-isolation
if [ "$INSTALL_BNB" == "true" ]; then \
EXTRA_PACKAGES="${EXTRA_PACKAGES},bitsandbytes"; \
fi; \
if [ "$INSTALL_VLLM" == "true" ]; then \
EXTRA_PACKAGES="${EXTRA_PACKAGES},vllm"; \
fi; \
if [ "$INSTALL_DEEPSPEED" == "true" ]; then \
EXTRA_PACKAGES="${EXTRA_PACKAGES},deepspeed"; \
fi; \
if [ "$INSTALL_LIGER_KERNEL" == "true" ]; then \
EXTRA_PACKAGES="${EXTRA_PACKAGES},liger-kernel"; \
fi; \
if [ "$INSTALL_HQQ" == "true" ]; then \
EXTRA_PACKAGES="${EXTRA_PACKAGES},hqq"; \
fi; \
if [ -n "$HTTP_PROXY" ]; then \
pip install --proxy=$HTTP_PROXY -e ".[$EXTRA_PACKAGES]"; \
else \
pip install -e ".[$EXTRA_PACKAGES]"; \
fi
# Rebuild flash attention # Rebuild flash attention
RUN pip uninstall -y transformer-engine flash-attn && \ RUN if [ "${INSTALL_FLASHATTN}" == "true" ]; then \
if [ "$INSTALL_FLASHATTN" == "true" ]; then \
pip uninstall -y ninja && \ pip uninstall -y ninja && \
if [ -n "$HTTP_PROXY" ]; then \ pip install --no-cache-dir ninja && \
pip install --proxy=$HTTP_PROXY ninja && \ pip install --no-cache-dir flash-attn --no-build-isolation; \
pip install --proxy=$HTTP_PROXY --no-cache-dir flash-attn --no-build-isolation; \
else \
pip install ninja && \
pip install --no-cache-dir flash-attn --no-build-isolation; \
fi; \
fi
# Unset http proxy
RUN if [ -n "$HTTP_PROXY" ]; then \
unset http_proxy; \
unset https_proxy; \
fi fi
# Set up volumes # Set up volumes
VOLUME [ "/root/.cache/huggingface", "/root/.cache/modelscope", "/app/data", "/app/output" ] # VOLUME [ "/root/.cache/huggingface", "/app/shared_data", "/app/output" ]
# Expose port 7860 for the LLaMA Board # Expose port 7860 for LLaMA Board
ENV GRADIO_SERVER_PORT 7860 ENV GRADIO_SERVER_PORT=7860
EXPOSE 7860 EXPOSE 7860
# Expose port 8000 for the API service # Expose port 8000 for API service
ENV API_PORT 8000 ENV API_PORT=8000
EXPOSE 8000 EXPOSE 8000
# unset proxy
ENV http_proxy=
ENV https_proxy=
# Reset pip config
RUN pip config unset global.index-url && \
pip config unset global.extra-index-url

View File

@@ -4,27 +4,15 @@ services:
dockerfile: ./docker/docker-rocm/Dockerfile dockerfile: ./docker/docker-rocm/Dockerfile
context: ../.. context: ../..
args: args:
INSTALL_BNB: "false"
INSTALL_VLLM: "false"
INSTALL_DEEPSPEED: "false"
INSTALL_FLASHATTN: "false"
INSTALL_LIGER_KERNEL: "false"
INSTALL_HQQ: "false"
PIP_INDEX: https://pypi.org/simple PIP_INDEX: https://pypi.org/simple
EXTRAS: metrics
container_name: llamafactory container_name: llamafactory
volumes:
- ../../hf_cache:/root/.cache/huggingface
- ../../ms_cache:/root/.cache/modelscope
- ../../om_cache:/root/.cache/openmind
- ../../data:/app/data
- ../../output:/app/output
- ../../saves:/app/saves
ports: ports:
- "7860:7860" - "7860:7860"
- "8000:8000" - "8000:8000"
ipc: host ipc: host
tty: true tty: true
shm_size: "16gb" # shm_size: "16gb" # ipc: host is set
stdin_open: true stdin_open: true
command: bash command: bash
devices: devices:

View File

@@ -1,3 +1,4 @@
# Copyright 2025 the LlamaFactory team.
# Copyright 2020 The HuggingFace Datasets Authors and the current dataset script contributor. # Copyright 2020 The HuggingFace Datasets Authors and the current dataset script contributor.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
@@ -21,14 +22,15 @@ import pandas as pd
_CITATION = """\ _CITATION = """\
@article{huang2023ceval, @article{huang2023ceval,
title={C-Eval: A Multi-Level Multi-Discipline Chinese Evaluation Suite for Foundation Models}, title={C-Eval: A Multi-Level Multi-Discipline Chinese Evaluation Suite for Foundation Models},
author={Huang, Yuzhen and Bai, Yuzhuo and Zhu, Zhihao and Zhang, Junlei and Zhang, Jinghan and Su, Tangjun and Liu, Junteng and Lv, Chuancheng and Zhang, Yikai and Lei, Jiayi and Fu, Yao and Sun, Maosong and He, Junxian}, author={Huang, Yuzhen and Bai, Yuzhuo and Zhu, Zhihao and others},
journal={arXiv preprint arXiv:2305.08322}, journal={arXiv preprint arXiv:2305.08322},
year={2023} year={2023}
} }
""" """
_DESCRIPTION = """\ _DESCRIPTION = """\
C-Eval is a comprehensive Chinese evaluation suite for foundation models. It consists of 13948 multi-choice questions spanning 52 diverse disciplines and four difficulty levels. C-Eval is a comprehensive Chinese evaluation suite for foundation models.
It consists of 13948 multi-choice questions spanning 52 diverse disciplines and four difficulty levels.
""" """
_HOMEPAGE = "https://cevalbenchmark.com" _HOMEPAGE = "https://cevalbenchmark.com"

View File

@@ -1,3 +1,4 @@
# Copyright 2025 the LlamaFactory team.
# Copyright 2020 The HuggingFace Datasets Authors and the current dataset script contributor. # Copyright 2020 The HuggingFace Datasets Authors and the current dataset script contributor.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
@@ -21,14 +22,15 @@ import pandas as pd
_CITATION = """\ _CITATION = """\
@article{li2023cmmlu, @article{li2023cmmlu,
title={CMMLU: Measuring massive multitask language understanding in Chinese}, title={CMMLU: Measuring massive multitask language understanding in Chinese},
author={Haonan Li and Yixuan Zhang and Fajri Koto and Yifei Yang and Hai Zhao and Yeyun Gong and Nan Duan and Timothy Baldwin}, author={Haonan Li and Yixuan Zhang and Fajri Koto and Yifei Yang and others},
journal={arXiv preprint arXiv:2306.09212}, journal={arXiv preprint arXiv:2306.09212},
year={2023} year={2023}
} }
""" """
_DESCRIPTION = """\ _DESCRIPTION = """\
CMMLU is a comprehensive Chinese assessment suite specifically designed to evaluate the advanced knowledge and reasoning abilities of LLMs within the Chinese language and cultural context. CMMLU is a comprehensive Chinese assessment suite specifically designed to evaluate the advanced knowledge
and reasoning abilities of LLMs within the Chinese language and cultural context.
""" """
_HOMEPAGE = "https://github.com/haonan-li/CMMLU" _HOMEPAGE = "https://github.com/haonan-li/CMMLU"

View File

@@ -1,3 +1,4 @@
# Copyright 2025 the LlamaFactory team.
# Copyright 2020 The HuggingFace Datasets Authors and the current dataset script contributor. # Copyright 2020 The HuggingFace Datasets Authors and the current dataset script contributor.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
@@ -21,14 +22,15 @@ import pandas as pd
_CITATION = """\ _CITATION = """\
@article{hendryckstest2021, @article{hendryckstest2021,
title={Measuring Massive Multitask Language Understanding}, title={Measuring Massive Multitask Language Understanding},
author={Dan Hendrycks and Collin Burns and Steven Basart and Andy Zou and Mantas Mazeika and Dawn Song and Jacob Steinhardt}, author={Dan Hendrycks and Collin Burns and others},
journal={Proceedings of the International Conference on Learning Representations (ICLR)}, journal={Proceedings of the International Conference on Learning Representations (ICLR)},
year={2021} year={2021}
} }
""" """
_DESCRIPTION = """\ _DESCRIPTION = """\
Measuring Massive Multitask Language Understanding by Dan Hendrycks, Collin Burns, Steven Basart, Andy Zou, Mantas Mazeika, Dawn Song, and Jacob Steinhardt (ICLR 2021). Measuring Massive Multitask Language Understanding by Dan Hendrycks, Collin Burns, Steven Basart,
Andy Zou, Mantas Mazeika, Dawn Song, and Jacob Steinhardt (ICLR 2021).
""" """
_HOMEPAGE = "https://github.com/hendrycks/test" _HOMEPAGE = "https://github.com/hendrycks/test"

View File

@@ -15,6 +15,24 @@ Use `CUDA_VISIBLE_DEVICES` (GPU) or `ASCEND_RT_VISIBLE_DEVICES` (NPU) to choose
By default, LLaMA-Factory uses all visible computing devices. By default, LLaMA-Factory uses all visible computing devices.
Basic usage:
```bash
llamafactory-cli train examples/train_lora/llama3_lora_sft.yaml
```
Advanced usage:
```bash
CUDA_VISIBLE_DEVICES=0,1 llamafactory-cli train examples/train_lora/llama3_lora_sft.yaml \
learning_rate=1e-5 \
logging_steps=1
```
```bash
bash examples/train_lora/llama3_lora_sft.sh
```
## Examples ## Examples
### LoRA Fine-Tuning ### LoRA Fine-Tuning
@@ -34,8 +52,7 @@ llamafactory-cli train examples/train_lora/llama3_lora_sft.yaml
#### Multimodal Supervised Fine-Tuning #### Multimodal Supervised Fine-Tuning
```bash ```bash
llamafactory-cli train examples/train_lora/llava1_5_lora_sft.yaml llamafactory-cli train examples/train_lora/qwen2_5vl_lora_sft.yaml
llamafactory-cli train examples/train_lora/qwen2vl_lora_sft.yaml
``` ```
#### DPO/ORPO/SimPO Training #### DPO/ORPO/SimPO Training
@@ -47,7 +64,7 @@ llamafactory-cli train examples/train_lora/llama3_lora_dpo.yaml
#### Multimodal DPO/ORPO/SimPO Training #### Multimodal DPO/ORPO/SimPO Training
```bash ```bash
llamafactory-cli train examples/train_lora/qwen2vl_lora_dpo.yaml llamafactory-cli train examples/train_lora/qwen2_5vl_lora_dpo.yaml
``` ```
#### Reward Modeling #### Reward Modeling
@@ -148,10 +165,18 @@ FORCE_TORCHRUN=1 NNODES=2 NODE_RANK=0 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500
FORCE_TORCHRUN=1 NNODES=2 NODE_RANK=1 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 llamafactory-cli train examples/train_full/llama3_full_sft.yaml FORCE_TORCHRUN=1 NNODES=2 NODE_RANK=1 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 llamafactory-cli train examples/train_full/llama3_full_sft.yaml
``` ```
### Elastic and Fault-Tolerant Supervised Fine-Tuning on Multiple Nodes
To launch an elastic job with at most `MAX_RESTARTS` restarts on failure, run the following on at least `MIN_NNODES` nodes and at most `MAX_NNODES` nodes. `RDZV_ID` should be set as a unique job id (shared by all nodes participating in the job). See also [torchrun](https://docs.pytorch.org/docs/stable/elastic/run.html).
```bash
FORCE_TORCHRUN=1 MIN_NNODES=1 MAX_NNODES=3 MAX_RESTARTS=3 RDZV_ID=llamafactory MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 llamafactory-cli train examples/train_full/llama3_full_sft.yaml
```
#### Multimodal Supervised Fine-Tuning #### Multimodal Supervised Fine-Tuning
```bash ```bash
FORCE_TORCHRUN=1 llamafactory-cli train examples/train_full/qwen2vl_full_sft.yaml FORCE_TORCHRUN=1 llamafactory-cli train examples/train_full/qwen2_5vl_full_sft.yaml
``` ```
### Merging LoRA Adapters and Quantization ### Merging LoRA Adapters and Quantization
@@ -178,10 +203,11 @@ llamafactory-cli export examples/merge_lora/llama3_full_sft.yaml
### Inferring LoRA Fine-Tuned Models ### Inferring LoRA Fine-Tuned Models
#### Batch Generation using vLLM Tensor Parallel #### Evaluation using vLLM's Multi-GPU Inference
``` ```
python scripts/vllm_infer.py --model_name_or_path path_to_merged_model --dataset alpaca_en_demo python scripts/vllm_infer.py --model_name_or_path meta-llama/Meta-Llama-3-8B-Instruct --template llama3 --dataset alpaca_en_demo
python scripts/eval_bleu_rouge.py generated_predictions.jsonl
``` ```
#### Use CLI ChatBox #### Use CLI ChatBox
@@ -228,6 +254,12 @@ llamafactory-cli train examples/extras/badam/llama3_full_sft.yaml
llamafactory-cli train examples/extras/adam_mini/qwen2_full_sft.yaml llamafactory-cli train examples/extras/adam_mini/qwen2_full_sft.yaml
``` ```
#### Full-Parameter Fine-Tuning using Muon
```bash
llamafactory-cli train examples/extras/muon/qwen2_full_sft.yaml
```
#### LoRA+ Fine-Tuning #### LoRA+ Fine-Tuning
```bash ```bash
@@ -258,9 +290,3 @@ llamafactory-cli train examples/extras/llama_pro/llama3_freeze_sft.yaml
```bash ```bash
bash examples/extras/fsdp_qlora/train.sh bash examples/extras/fsdp_qlora/train.sh
``` ```
#### Computing BLEU and ROUGE Scores
```bash
llamafactory-cli train examples/extras/nlg_eval/llama3_lora_predict.yaml
```

View File

@@ -15,6 +15,24 @@
LLaMA-Factory 默认使用所有可见的计算设备。 LLaMA-Factory 默认使用所有可见的计算设备。
基础用法:
```bash
llamafactory-cli train examples/train_lora/llama3_lora_sft.yaml
```
高级用法:
```bash
CUDA_VISIBLE_DEVICES=0,1 llamafactory-cli train examples/train_lora/llama3_lora_sft.yaml \
learning_rate=1e-5 \
logging_steps=1
```
```bash
bash examples/train_lora/llama3_lora_sft.sh
```
## 示例 ## 示例
### LoRA 微调 ### LoRA 微调
@@ -34,8 +52,7 @@ llamafactory-cli train examples/train_lora/llama3_lora_sft.yaml
#### 多模态指令监督微调 #### 多模态指令监督微调
```bash ```bash
llamafactory-cli train examples/train_lora/llava1_5_lora_sft.yaml llamafactory-cli train examples/train_lora/qwen2_5vl_lora_sft.yaml
llamafactory-cli train examples/train_lora/qwen2vl_lora_sft.yaml
``` ```
#### DPO/ORPO/SimPO 训练 #### DPO/ORPO/SimPO 训练
@@ -47,7 +64,7 @@ llamafactory-cli train examples/train_lora/llama3_lora_dpo.yaml
#### 多模态 DPO/ORPO/SimPO 训练 #### 多模态 DPO/ORPO/SimPO 训练
```bash ```bash
llamafactory-cli train examples/train_lora/qwen2vl_lora_dpo.yaml llamafactory-cli train examples/train_lora/qwen2_5vl_lora_dpo.yaml
``` ```
#### 奖励模型训练 #### 奖励模型训练
@@ -89,6 +106,14 @@ FORCE_TORCHRUN=1 NNODES=2 NODE_RANK=0 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500
FORCE_TORCHRUN=1 NNODES=2 NODE_RANK=1 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 llamafactory-cli train examples/train_lora/llama3_lora_sft.yaml FORCE_TORCHRUN=1 NNODES=2 NODE_RANK=1 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 llamafactory-cli train examples/train_lora/llama3_lora_sft.yaml
``` ```
### 支持弹性和容错的多机指令监督微调
要启动一个支持弹性节点和容错的多机指令微调任务,在每个节点上执行以下命令。弹性节点数量范围为 `MIN_NNODES:MAX_NNODES`,每个作业最多允许因错误重启 `MAX_RESTARTS` 次。`RDZV_ID` 应设置为一个唯一的作业 ID,由参与该作业的所有节点共享。更多信息可以参考官方文档 [torchrun](https://docs.pytorch.org/docs/stable/elastic/run.html)。
```bash
FORCE_TORCHRUN=1 MIN_NNODES=1 MAX_NNODES=3 MAX_RESTARTS=3 RDZV_ID=llamafactory MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 llamafactory-cli train examples/train_full/llama3_full_sft.yaml
```
#### 使用 DeepSpeed ZeRO-3 平均分配显存 #### 使用 DeepSpeed ZeRO-3 平均分配显存
```bash ```bash
@@ -151,7 +176,7 @@ FORCE_TORCHRUN=1 NNODES=2 NODE_RANK=1 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500
#### 多模态指令监督微调 #### 多模态指令监督微调
```bash ```bash
FORCE_TORCHRUN=1 llamafactory-cli train examples/train_full/qwen2vl_full_sft.yaml FORCE_TORCHRUN=1 llamafactory-cli train examples/train_full/qwen2_5vl_full_sft.yaml
``` ```
### 合并 LoRA 适配器与模型量化 ### 合并 LoRA 适配器与模型量化
@@ -178,10 +203,11 @@ llamafactory-cli export examples/merge_lora/llama3_full_sft.yaml
### 推理 LoRA 模型 ### 推理 LoRA 模型
#### 使用 vLLM+TP 批量推理 #### 使用 vLLM 多卡推理评估
``` ```
python scripts/vllm_infer.py --model_name_or_path path_to_merged_model --dataset alpaca_en_demo python scripts/vllm_infer.py --model_name_or_path meta-llama/Meta-Llama-3-8B-Instruct --template llama3 --dataset alpaca_en_demo
python scripts/eval_bleu_rouge.py generated_predictions.jsonl
``` ```
#### 使用命令行对话框 #### 使用命令行对话框
@@ -228,6 +254,12 @@ llamafactory-cli train examples/extras/badam/llama3_full_sft.yaml
llamafactory-cli train examples/extras/adam_mini/qwen2_full_sft.yaml llamafactory-cli train examples/extras/adam_mini/qwen2_full_sft.yaml
``` ```
#### 使用 Muon 进行全参数训练
```bash
llamafactory-cli train examples/extras/muon/qwen2_full_sft.yaml
```
#### LoRA+ 微调 #### LoRA+ 微调
```bash ```bash
@@ -258,9 +290,3 @@ llamafactory-cli train examples/extras/llama_pro/llama3_freeze_sft.yaml
```bash ```bash
bash examples/extras/fsdp_qlora/train.sh bash examples/extras/fsdp_qlora/train.sh
``` ```
#### 计算 BLEU 和 ROUGE 分数
```bash
llamafactory-cli train examples/extras/nlg_eval/llama3_lora_predict.yaml
```

View File

@@ -7,16 +7,16 @@ fsdp_config:
fsdp_backward_prefetch: BACKWARD_PRE fsdp_backward_prefetch: BACKWARD_PRE
fsdp_forward_prefetch: false fsdp_forward_prefetch: false
fsdp_cpu_ram_efficient_loading: true fsdp_cpu_ram_efficient_loading: true
fsdp_offload_params: true # offload may affect training speed fsdp_offload_params: false
fsdp_sharding_strategy: FULL_SHARD fsdp_sharding_strategy: FULL_SHARD
fsdp_state_dict_type: FULL_STATE_DICT fsdp_state_dict_type: FULL_STATE_DICT
fsdp_sync_module_states: true fsdp_sync_module_states: true
fsdp_use_orig_params: true fsdp_use_orig_params: true
machine_rank: 0 machine_rank: 0
main_training_function: main main_training_function: main
mixed_precision: bf16 # or fp16 mixed_precision: bf16 # or fp16
num_machines: 1 # the number of nodes num_machines: 1 # the number of nodes
num_processes: 2 # the number of GPUs in all nodes num_processes: 2 # the number of GPUs in all nodes
rdzv_backend: static rdzv_backend: static
same_network: true same_network: true
tpu_env: [] tpu_env: []

View File

@@ -0,0 +1,25 @@
compute_environment: LOCAL_MACHINE
debug: false
distributed_type: FSDP
downcast_bf16: 'no'
fsdp_config:
fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
fsdp_backward_prefetch: BACKWARD_PRE
fsdp_forward_prefetch: false
fsdp_cpu_ram_efficient_loading: true
fsdp_offload_params: true # offload may affect training speed
fsdp_sharding_strategy: FULL_SHARD
fsdp_state_dict_type: FULL_STATE_DICT
fsdp_sync_module_states: true
fsdp_use_orig_params: true
machine_rank: 0
main_training_function: main
mixed_precision: bf16 # or fp16
num_machines: 1 # the number of nodes
num_processes: 2 # the number of GPUs in all nodes
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false
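This new variant re-enables `fsdp_offload_params` for memory-constrained setups. A sketch of launching with it via Accelerate; the config path and training entry point below are assumptions about the repository layout, not taken from the diff:

```bash
accelerate launch \
  --config_file examples/accelerate/fsdp_config_offload.yaml \
  src/train.py examples/extras/fsdp_qlora/llama3_lora_sft.yaml
```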

View File

@@ -15,6 +15,7 @@ cutoff_len: 2048
max_samples: 1000 max_samples: 1000
overwrite_cache: true overwrite_cache: true
preprocessing_num_workers: 16 preprocessing_num_workers: 16
dataloader_num_workers: 4
### output ### output
output_dir: saves/qwen2-1_5b/full/sft output_dir: saves/qwen2-1_5b/full/sft
@@ -22,6 +23,8 @@ logging_steps: 10
save_steps: 500 save_steps: 500
plot_loss: true plot_loss: true
overwrite_output_dir: true overwrite_output_dir: true
save_only_model: false
report_to: none # choices: [none, wandb, tensorboard, swanlab, mlflow]
### train ### train
per_device_train_batch_size: 1 per_device_train_batch_size: 1
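The newly added keys (`dataloader_num_workers`, `save_only_model`, `report_to`) behave like any other hyperparameter, so they can also be overridden at launch time with the `key=value` syntax from the advanced-usage example above (the path and values here are illustrative):

```bash
llamafactory-cli train path/to/your_config.yaml \
  dataloader_num_workers=8 \
  report_to=swanlab
```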

View File

@@ -20,6 +20,7 @@ cutoff_len: 2048
max_samples: 1000 max_samples: 1000
overwrite_cache: true overwrite_cache: true
preprocessing_num_workers: 16 preprocessing_num_workers: 16
dataloader_num_workers: 4
### output ### output
output_dir: saves/llama3-8b/full/sft output_dir: saves/llama3-8b/full/sft
@@ -27,6 +28,8 @@ logging_steps: 10
save_steps: 500 save_steps: 500
plot_loss: true plot_loss: true
overwrite_output_dir: true overwrite_output_dir: true
save_only_model: false
report_to: none # choices: [none, wandb, tensorboard, swanlab, mlflow]
### train ### train
per_device_train_batch_size: 1 per_device_train_batch_size: 1

View File

@@ -20,6 +20,7 @@ cutoff_len: 2048
max_samples: 1000 max_samples: 1000
overwrite_cache: true overwrite_cache: true
preprocessing_num_workers: 16 preprocessing_num_workers: 16
dataloader_num_workers: 4
### output ### output
output_dir: saves/llama3-8b/full/sft output_dir: saves/llama3-8b/full/sft
@@ -27,6 +28,8 @@ logging_steps: 10
save_steps: 500 save_steps: 500
plot_loss: true plot_loss: true
overwrite_output_dir: true overwrite_output_dir: true
save_only_model: false
report_to: none # choices: [none, wandb, tensorboard, swanlab, mlflow]
### train ### train
per_device_train_batch_size: 1 per_device_train_batch_size: 1

View File

@@ -17,6 +17,7 @@ cutoff_len: 2048
max_samples: 1000 max_samples: 1000
overwrite_cache: true overwrite_cache: true
preprocessing_num_workers: 16 preprocessing_num_workers: 16
dataloader_num_workers: 4
### output ### output
output_dir: saves/llama3-8b/lora/sft output_dir: saves/llama3-8b/lora/sft
@@ -24,6 +25,8 @@ logging_steps: 10
save_steps: 500 save_steps: 500
plot_loss: true plot_loss: true
overwrite_output_dir: true overwrite_output_dir: true
save_only_model: false
report_to: none # choices: [none, wandb, tensorboard, swanlab, mlflow]
### train ### train
per_device_train_batch_size: 1 per_device_train_batch_size: 1

View File

@@ -19,6 +19,7 @@ cutoff_len: 2048
max_samples: 1000 max_samples: 1000
overwrite_cache: true overwrite_cache: true
preprocessing_num_workers: 16 preprocessing_num_workers: 16
dataloader_num_workers: 4
### output ### output
output_dir: saves/llama3-8b/full/sft output_dir: saves/llama3-8b/full/sft
@@ -26,6 +27,8 @@ logging_steps: 10
save_steps: 500 save_steps: 500
plot_loss: true plot_loss: true
overwrite_output_dir: true overwrite_output_dir: true
save_only_model: false
report_to: none # choices: [none, wandb, tensorboard, swanlab, mlflow]
### train ### train
per_device_train_batch_size: 1 per_device_train_batch_size: 1

View File

@@ -17,6 +17,7 @@ cutoff_len: 2048
max_samples: 1000 max_samples: 1000
overwrite_cache: true overwrite_cache: true
preprocessing_num_workers: 16 preprocessing_num_workers: 16
dataloader_num_workers: 4
### output ### output
output_dir: saves/llama3-8b-pro/freeze/sft output_dir: saves/llama3-8b-pro/freeze/sft
@@ -24,6 +25,8 @@ logging_steps: 10
save_steps: 500 save_steps: 500
plot_loss: true plot_loss: true
overwrite_output_dir: true overwrite_output_dir: true
save_only_model: false
report_to: none # choices: [none, wandb, tensorboard, swanlab, mlflow]
### train ### train
per_device_train_batch_size: 1 per_device_train_batch_size: 1

View File

@@ -17,6 +17,7 @@ cutoff_len: 2048
max_samples: 1000 max_samples: 1000
overwrite_cache: true overwrite_cache: true
preprocessing_num_workers: 16 preprocessing_num_workers: 16
dataloader_num_workers: 4
### output ### output
output_dir: saves/llama3-8b/lora/sft output_dir: saves/llama3-8b/lora/sft
@@ -24,6 +25,8 @@ logging_steps: 10
save_steps: 500 save_steps: 500
plot_loss: true plot_loss: true
overwrite_output_dir: true overwrite_output_dir: true
save_only_model: false
report_to: none # choices: [none, wandb, tensorboard, swanlab, mlflow]
### train ### train
per_device_train_batch_size: 1 per_device_train_batch_size: 1

View File

@@ -15,6 +15,7 @@ cutoff_len: 2048
max_samples: 1000 max_samples: 1000
overwrite_cache: true overwrite_cache: true
preprocessing_num_workers: 16 preprocessing_num_workers: 16
dataloader_num_workers: 4
### output ### output
output_dir: saves/llama3-8b-mod/full/sft output_dir: saves/llama3-8b-mod/full/sft
@@ -22,6 +23,8 @@ logging_steps: 10
save_steps: 500 save_steps: 500
plot_loss: true plot_loss: true
overwrite_output_dir: true overwrite_output_dir: true
save_only_model: false
report_to: none # choices: [none, wandb, tensorboard, swanlab, mlflow]
### train ### train
per_device_train_batch_size: 1 per_device_train_batch_size: 1

View File

@@ -1,17 +1,16 @@
### model ### model
model_name_or_path: llava-hf/llava-1.5-7b-hf model_name_or_path: Qwen/Qwen2-1.5B-Instruct
trust_remote_code: true trust_remote_code: true
### method ### method
stage: sft stage: sft
do_train: true do_train: true
finetuning_type: lora finetuning_type: full
lora_rank: 8 use_muon: true
lora_target: all
### dataset ### dataset
dataset: mllm_demo dataset: identity,alpaca_en_demo
template: llava template: qwen
cutoff_len: 2048 cutoff_len: 2048
max_samples: 1000 max_samples: 1000
overwrite_cache: true overwrite_cache: true
@@ -19,23 +18,23 @@ preprocessing_num_workers: 16
dataloader_num_workers: 4 dataloader_num_workers: 4
### output ### output
output_dir: saves/llava1_5-7b/lora/sft output_dir: saves/qwen2-1_5b/full/sft
logging_steps: 10 logging_steps: 10
save_steps: 500 save_steps: 500
plot_loss: true plot_loss: true
overwrite_output_dir: true overwrite_output_dir: true
save_only_model: false save_only_model: false
report_to: none # choices: [none, wandb, tensorboard, swanlab, mlflow]
### train ### train
per_device_train_batch_size: 1 per_device_train_batch_size: 1
gradient_accumulation_steps: 8 gradient_accumulation_steps: 8
learning_rate: 1.0e-4 learning_rate: 1.0e-5
num_train_epochs: 3.0 num_train_epochs: 3.0
lr_scheduler_type: cosine lr_scheduler_type: cosine
warmup_ratio: 0.1 warmup_ratio: 0.1
bf16: true bf16: true
ddp_timeout: 180000000 ddp_timeout: 180000000
resume_from_checkpoint: null
### eval ### eval
# val_size: 0.1 # val_size: 0.1
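This appears to be the recipe behind the Muon example listed earlier; assuming it lives at `examples/extras/muon/qwen2_full_sft.yaml`, it is launched directly with:

```bash
llamafactory-cli train examples/extras/muon/qwen2_full_sft.yaml
```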

View File

@@ -18,10 +18,12 @@ cutoff_len: 2048
max_samples: 50 max_samples: 50
overwrite_cache: true overwrite_cache: true
preprocessing_num_workers: 16 preprocessing_num_workers: 16
dataloader_num_workers: 4
### output ### output
output_dir: saves/llama3-8b/lora/predict output_dir: saves/llama3-8b/lora/predict
overwrite_output_dir: true overwrite_output_dir: true
report_to: none # choices: [none, wandb, tensorboard, swanlab, mlflow]
### eval ### eval
per_device_eval_batch_size: 1 per_device_eval_batch_size: 1

View File

@@ -19,6 +19,7 @@ cutoff_len: 2048
max_samples: 1000 max_samples: 1000
overwrite_cache: true overwrite_cache: true
preprocessing_num_workers: 16 preprocessing_num_workers: 16
dataloader_num_workers: 4
### output ### output
output_dir: saves/llama3-8b/lora/sft output_dir: saves/llama3-8b/lora/sft
@@ -26,6 +27,8 @@ logging_steps: 10
save_steps: 500 save_steps: 500
plot_loss: true plot_loss: true
overwrite_output_dir: true overwrite_output_dir: true
save_only_model: false
report_to: none # choices: [none, wandb, tensorboard, swanlab, mlflow]
### train ### train
per_device_train_batch_size: 1 per_device_train_batch_size: 1

View File

@@ -1,4 +1,4 @@
model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct
template: llama3 template: llama3
infer_backend: huggingface # choices: [huggingface, vllm] infer_backend: huggingface # choices: [huggingface, vllm, sglang]
trust_remote_code: true trust_remote_code: true

View File

@@ -1,4 +1,4 @@
model_name_or_path: saves/llama3-8b/full/sft model_name_or_path: saves/llama3-8b/full/sft
template: llama3 template: llama3
infer_backend: huggingface # choices: [huggingface, vllm] infer_backend: huggingface # choices: [huggingface, vllm, sglang]
trust_remote_code: true trust_remote_code: true

View File

@@ -1,5 +1,5 @@
model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct
adapter_name_or_path: saves/llama3-8b/lora/sft adapter_name_or_path: saves/llama3-8b/lora/sft
template: llama3 template: llama3
infer_backend: huggingface # choices: [huggingface, vllm] infer_backend: huggingface # choices: [huggingface, vllm, sglang]
trust_remote_code: true trust_remote_code: true

View File

@@ -1,5 +0,0 @@
model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct
template: llama3
infer_backend: vllm
vllm_enforce_eager: true
trust_remote_code: true

View File

@@ -1,4 +0,0 @@
model_name_or_path: llava-hf/llava-1.5-7b-hf
template: llava
infer_backend: huggingface # choices: [huggingface, vllm]
trust_remote_code: true

View File

@@ -0,0 +1,4 @@
model_name_or_path: Qwen/Qwen2.5-VL-7B-Instruct
template: qwen2_vl
infer_backend: huggingface # choices: [huggingface, vllm, sglang]
trust_remote_code: true

View File

@@ -1,4 +0,0 @@
model_name_or_path: Qwen/Qwen2-VL-7B-Instruct
template: qwen2_vl
infer_backend: huggingface # choices: [huggingface, vllm]
trust_remote_code: true

View File

@@ -6,5 +6,5 @@ trust_remote_code: true
### export ### export
export_dir: output/llama3_full_sft export_dir: output/llama3_full_sft
export_size: 5 export_size: 5
export_device: cpu export_device: cpu # choices: [cpu, auto]
export_legacy_format: false export_legacy_format: false

View File

@@ -6,7 +6,7 @@ trust_remote_code: true
### export ### export
export_dir: output/llama3_gptq export_dir: output/llama3_gptq
export_quantization_bit: 4 export_quantization_bit: 4
export_quantization_dataset: data/c4_demo.json export_quantization_dataset: data/c4_demo.jsonl
export_size: 5 export_size: 5
export_device: cpu export_device: cpu # choices: [cpu, auto]
export_legacy_format: false export_legacy_format: false

View File

@@ -9,5 +9,5 @@ trust_remote_code: true
### export ### export
export_dir: output/llama3_lora_sft export_dir: output/llama3_lora_sft
export_size: 5 export_size: 5
export_device: cpu export_device: cpu # choices: [cpu, auto]
export_legacy_format: false export_legacy_format: false

View File

@@ -1,13 +1,13 @@
### Note: DO NOT use quantized model or quantization_bit when merging lora adapters ### Note: DO NOT use quantized model or quantization_bit when merging lora adapters
### model ### model
model_name_or_path: Qwen/Qwen2-VL-7B-Instruct model_name_or_path: Qwen/Qwen2.5-VL-7B-Instruct
adapter_name_or_path: saves/qwen2_vl-7b/lora/sft adapter_name_or_path: saves/qwen2_5vl-7b/lora/sft
template: qwen2_vl template: qwen2_vl
trust_remote_code: true trust_remote_code: true
### export ### export
export_dir: output/qwen2_vl_lora_sft export_dir: output/qwen2_5vl_lora_sft
export_size: 5 export_size: 5
export_device: cpu export_device: cpu # choices: [cpu, auto]
export_legacy_format: false export_legacy_format: false
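The updated merge config is used the same way as the other `merge_lora` recipes; the exact file name below is an assumption:

```bash
llamafactory-cli export examples/merge_lora/qwen2_5vl_lora_sft.yaml
```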

View File

@@ -24,6 +24,7 @@ save_steps: 500
plot_loss: true plot_loss: true
overwrite_output_dir: true overwrite_output_dir: true
save_only_model: false save_only_model: false
report_to: none # choices: [none, wandb, tensorboard, swanlab, mlflow]
### train ### train
per_device_train_batch_size: 1 per_device_train_batch_size: 1

View File

@@ -0,0 +1,49 @@
### model
model_name_or_path: Qwen/Qwen2.5-VL-7B-Instruct
image_max_pixels: 262144
video_max_pixels: 16384
trust_remote_code: true
### method
stage: sft
do_train: true
finetuning_type: full
freeze_vision_tower: true
freeze_multi_modal_projector: true
freeze_language_model: false
deepspeed: examples/deepspeed/ds_z3_config.json
### dataset
dataset: mllm_demo,identity,alpaca_en_demo
template: qwen2_vl
cutoff_len: 2048
max_samples: 1000
overwrite_cache: true
preprocessing_num_workers: 16
dataloader_num_workers: 4
### output
output_dir: saves/qwen2_5vl-7b/full/sft
logging_steps: 10
save_steps: 500
plot_loss: true
overwrite_output_dir: true
save_only_model: false
report_to: none # choices: [none, wandb, tensorboard, swanlab, mlflow]
### train
per_device_train_batch_size: 1
gradient_accumulation_steps: 2
learning_rate: 1.0e-5
num_train_epochs: 3.0
lr_scheduler_type: cosine
warmup_ratio: 0.1
bf16: true
ddp_timeout: 180000000
resume_from_checkpoint: null
### eval
# val_size: 0.1
# per_device_eval_batch_size: 1
# eval_strategy: steps
# eval_steps: 500
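This config backs the multimodal full fine-tuning command shown earlier; since it enables DeepSpeed ZeRO-3, it is launched through torchrun:

```bash
FORCE_TORCHRUN=1 llamafactory-cli train examples/train_full/qwen2_5vl_full_sft.yaml
```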

View File

@@ -27,6 +27,7 @@ save_steps: 500
plot_loss: true plot_loss: true
overwrite_output_dir: true overwrite_output_dir: true
save_only_model: false save_only_model: false
report_to: none # choices: [none, wandb, tensorboard, swanlab, mlflow]
### train ### train
per_device_train_batch_size: 1 per_device_train_batch_size: 1

View File

@@ -17,6 +17,7 @@ cutoff_len: 2048
max_samples: 1000 max_samples: 1000
overwrite_cache: true overwrite_cache: true
preprocessing_num_workers: 16 preprocessing_num_workers: 16
dataloader_num_workers: 4
### output ### output
output_dir: saves/llama3-8b/lora/kto output_dir: saves/llama3-8b/lora/kto
@@ -24,6 +25,7 @@ logging_steps: 10
save_steps: 500 save_steps: 500
plot_loss: true plot_loss: true
overwrite_output_dir: true overwrite_output_dir: true
report_to: none # choices: [none, wandb, tensorboard, swanlab, mlflow]
### train ### train
per_device_train_batch_size: 1 per_device_train_batch_size: 1

View File

@@ -17,6 +17,7 @@ cutoff_len: 2048
max_samples: 1000 max_samples: 1000
overwrite_cache: true overwrite_cache: true
preprocessing_num_workers: 16 preprocessing_num_workers: 16
dataloader_num_workers: 4
### output ### output
output_dir: saves/llama3-8b/lora/ppo output_dir: saves/llama3-8b/lora/ppo
@@ -24,6 +25,7 @@ logging_steps: 10
save_steps: 500 save_steps: 500
plot_loss: true plot_loss: true
overwrite_output_dir: true overwrite_output_dir: true
report_to: none # choices: [none, wandb, tensorboard, swanlab, mlflow]
### train ### train
per_device_train_batch_size: 1 per_device_train_batch_size: 1

View File

@@ -24,6 +24,7 @@ save_steps: 500
plot_loss: true
overwrite_output_dir: true
save_only_model: false
+report_to: none # choices: [none, wandb, tensorboard, swanlab, mlflow]
### train
per_device_train_batch_size: 1

View File

@@ -25,6 +25,7 @@ save_steps: 500
plot_loss: true
overwrite_output_dir: true
save_only_model: false
+report_to: none # choices: [none, wandb, tensorboard, swanlab, mlflow]
### train
per_device_train_batch_size: 1

View File

@@ -0,0 +1,36 @@
#!/bin/bash
set -x
MODEL_PATH=meta-llama/Meta-Llama-3-8B-Instruct
llamafactory-cli train \
--model_name_or_path ${MODEL_PATH} \
--trust_remote_code \
--stage sft \
--do_train \
--finetuning_type lora \
--lora_rank 8 \
--lora_target all \
--dataset identity,alpaca_en_demo \
--template llama3 \
--cutoff_len 2048 \
--max_samples 1000 \
--overwrite_cache \
--preprocessing_num_workers 16 \
--dataloader_num_workers 4 \
--output_dir saves/llama3-8b/lora/sft \
--logging_steps 10 \
--save_steps 500 \
--plot_loss \
--overwrite_output_dir \
--save_only_model false \
--report_to none \
--per_device_train_batch_size 1 \
--gradient_accumulation_steps 8 \
--learning_rate 1e-4 \
--num_train_epochs 3.0 \
--lr_scheduler_type cosine \
--warmup_ratio 0.1 \
--bf16 \
--ddp_timeout 180000000

View File

@@ -25,6 +25,7 @@ save_steps: 500
plot_loss: true
overwrite_output_dir: true
save_only_model: false
+report_to: none # choices: [none, wandb, tensorboard, swanlab, mlflow]
### train
per_device_train_batch_size: 1

View File

@@ -26,6 +26,7 @@ save_steps: 500
plot_loss: true
overwrite_output_dir: true
save_only_model: false
+report_to: none # choices: [none, wandb, tensorboard, swanlab, mlflow]
### train
per_device_train_batch_size: 1

View File

@@ -26,14 +26,21 @@ save_steps: 500
plot_loss: true
overwrite_output_dir: true
save_only_model: false
+report_to: none # choices: [none, wandb, tensorboard, swanlab, mlflow]
### ray
ray_run_name: llama3_8b_sft_lora
ray_storage_path: ./saves
-ray_num_workers: 4 # number of GPUs to use
+ray_num_workers: 4 # Number of GPUs to use.
+placement_strategy: PACK
resources_per_worker:
  GPU: 1
-placement_strategy: PACK
+# ray_init_kwargs:
+#   runtime_env:
+#     env_vars:
+#       <YOUR-ENV-VAR-HERE>: "<YOUR-ENV-VAR-HERE>"
+#     pip:
+#       - emoji
### train
per_device_train_batch_size: 1

View File

@@ -1,21 +1,20 @@
+# pip install git+https://github.com/hiyouga/transformers.git@llama4_train
### model
-model_name_or_path: Qwen/Qwen2-VL-7B-Instruct
+model_name_or_path: meta-llama/Llama-4-Scout-17B-16E-Instruct
-image_max_pixels: 262144
-video_max_pixels: 16384
trust_remote_code: true
### method
stage: sft
do_train: true
-finetuning_type: full
+finetuning_type: lora
-freeze_vision_tower: true # choices: [true, false]
-freeze_multi_modal_projector: true # choices: [true, false]
-freeze_language_model: false # choices: [true, false]
+lora_rank: 8
+lora_target: all
deepspeed: examples/deepspeed/ds_z3_config.json # choices: [ds_z0_config.json, ds_z2_config.json, ds_z3_config.json]
### dataset
dataset: mllm_demo,identity,alpaca_en_demo
-template: qwen2_vl
+template: llama4
cutoff_len: 2048
max_samples: 1000
overwrite_cache: true
@@ -23,17 +22,18 @@ preprocessing_num_workers: 16
dataloader_num_workers: 4
### output
-output_dir: saves/qwen2_vl-7b/full/sft
+output_dir: saves/llama4-8b/lora/sft
logging_steps: 10
save_steps: 500
plot_loss: true
overwrite_output_dir: true
save_only_model: false
+report_to: none # choices: [none, wandb, tensorboard, swanlab, mlflow]
### train
per_device_train_batch_size: 1
gradient_accumulation_steps: 2
-learning_rate: 1.0e-5
+learning_rate: 1.0e-4
num_train_epochs: 3.0
lr_scheduler_type: cosine
warmup_ratio: 0.1
@@ -42,6 +42,7 @@ ddp_timeout: 180000000
resume_from_checkpoint: null
### eval
+# eval_dataset: alpaca_en_demo
# val_size: 0.1
# per_device_eval_batch_size: 1
# eval_strategy: steps

View File

@@ -1,5 +1,5 @@
### model
-model_name_or_path: Qwen/Qwen2-VL-7B-Instruct
+model_name_or_path: Qwen/Qwen2.5-VL-7B-Instruct
image_max_pixels: 262144
video_max_pixels: 16384
trust_remote_code: true
@@ -23,12 +23,13 @@ preprocessing_num_workers: 16
dataloader_num_workers: 4
### output
-output_dir: saves/qwen2_vl-7b/lora/dpo
+output_dir: saves/qwen2_5vl-7b/lora/dpo
logging_steps: 10
save_steps: 500
plot_loss: true
overwrite_output_dir: true
save_only_model: false
+report_to: none # choices: [none, wandb, tensorboard, swanlab, mlflow]
### train
per_device_train_batch_size: 1

View File

@@ -1,5 +1,5 @@
### model
-model_name_or_path: Qwen/Qwen2-VL-7B-Instruct
+model_name_or_path: Qwen/Qwen2.5-VL-7B-Instruct
image_max_pixels: 262144
video_max_pixels: 16384
trust_remote_code: true
@@ -21,12 +21,13 @@ preprocessing_num_workers: 16
dataloader_num_workers: 4
### output
-output_dir: saves/qwen2_vl-7b/lora/sft
+output_dir: saves/qwen2_5vl-7b/lora/sft
logging_steps: 10
save_steps: 500
plot_loss: true
overwrite_output_dir: true
save_only_model: false
+report_to: none # choices: [none, wandb, tensorboard, swanlab, mlflow]
### train
per_device_train_batch_size: 1

View File

@@ -16,6 +16,7 @@ cutoff_len: 2048
max_samples: 1000
overwrite_cache: true
preprocessing_num_workers: 16
+dataloader_num_workers: 4
### output
output_dir: saves/llama3-8b/lora/sft
@@ -23,6 +24,8 @@ logging_steps: 10
save_steps: 500
plot_loss: true
overwrite_output_dir: true
+save_only_model: false
+report_to: none # choices: [none, wandb, tensorboard, swanlab, mlflow]
### train
per_device_train_batch_size: 1

View File

@@ -16,6 +16,7 @@ cutoff_len: 2048
max_samples: 1000
overwrite_cache: true
preprocessing_num_workers: 16
+dataloader_num_workers: 4
### output
output_dir: saves/llama3-8b/lora/sft
@@ -23,6 +24,8 @@ logging_steps: 10
save_steps: 500
plot_loss: true
overwrite_output_dir: true
+save_only_model: false
+report_to: none # choices: [none, wandb, tensorboard, swanlab, mlflow]
### train
per_device_train_batch_size: 1

View File

@@ -1,7 +1,7 @@
### model
model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct
quantization_bit: 4
-quantization_method: bitsandbytes
+quantization_method: bnb
double_quantization: false
trust_remote_code: true
@@ -19,6 +19,7 @@ cutoff_len: 2048
max_samples: 1000
overwrite_cache: true
preprocessing_num_workers: 16
+dataloader_num_workers: 4
### output
output_dir: saves/llama3-8b/lora/sft
@@ -26,6 +27,8 @@ logging_steps: 10
save_steps: 500
plot_loss: true
overwrite_output_dir: true
+save_only_model: false
+report_to: none # choices: [none, wandb, tensorboard, swanlab, mlflow]
### train
per_device_train_batch_size: 1

View File

@@ -16,6 +16,7 @@ cutoff_len: 2048
max_samples: 1000
overwrite_cache: true
preprocessing_num_workers: 16
+dataloader_num_workers: 4
### output
output_dir: saves/llama3-8b/lora/sft
@@ -23,6 +24,8 @@ logging_steps: 10
save_steps: 500
plot_loss: true
overwrite_output_dir: true
+save_only_model: false
+report_to: none # choices: [none, wandb, tensorboard, swanlab, mlflow]
### train
per_device_train_batch_size: 1

View File

@@ -1,7 +1,7 @@
### model
model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct
-quantization_bit: 4
+quantization_bit: 4 # choices: [8 (bnb/hqq/eetq), 4 (bnb/hqq), 3 (hqq), 2 (hqq)]
-quantization_method: bitsandbytes # choices: [bitsandbytes (4/8), hqq (2/3/4/5/6/8), eetq (8)]
+quantization_method: bnb # choices: [bnb, hqq, eetq]
trust_remote_code: true
### method
@@ -18,6 +18,7 @@ cutoff_len: 2048
max_samples: 1000
overwrite_cache: true
preprocessing_num_workers: 16
+dataloader_num_workers: 4
### output
output_dir: saves/llama3-8b/lora/sft
@@ -25,6 +26,8 @@ logging_steps: 10
save_steps: 500
plot_loss: true
overwrite_output_dir: true
+save_only_model: false
+report_to: none # choices: [none, wandb, tensorboard, swanlab, mlflow]
### train
per_device_train_batch_size: 1

View File

@@ -19,13 +19,36 @@ dynamic = [
]
[tool.ruff]
-target-version = "py38"
+target-version = "py39"
line-length = 119
indent-width = 4
[tool.ruff.lint]
-ignore = ["C408", "C901", "E501", "E731", "E741", "W605"]
-select = ["C", "E", "F", "I", "W"]
+ignore = [
+"C408", # collection
+"C901", # complex
+"E501", # line too long
+"E731", # lambda function
+"E741", # ambiguous var name
+"D100", # no doc public module
+"D101", # no doc public class
+"D102", # no doc public method
+"D103", # no doc public function
+"D104", # no doc public package
+"D105", # no doc magic method
+"D107", # no doc __init__
+]
+extend-select = [
+"C", # complexity
+"E", # error
+"F", # pyflakes
+"I", # isort
+"W", # warning
+"UP", # pyupgrade
+"D", # pydocstyle
+"PT009", # pytest assert
+"RUF022", # sort __all__
+]
[tool.ruff.lint.isort]
lines-after-imports = 2
@@ -38,9 +61,12 @@ known-third-party = [
"peft",
"torch",
"transformers",
-"trl"
+"trl",
]
+[tool.ruff.lint.pydocstyle]
+convention = "google"
[tool.ruff.format]
quote-style = "double"
indent-style = "space"
@@ -56,10 +82,14 @@ conflicts = [
],
[
{ extra = "torch-npu" },
-{ extra = "liger-kernel" },
+{ extra = "vllm" },
],
[
{ extra = "torch-npu" },
+{ extra = "sglang" },
+],
+[
{ extra = "vllm" },
-]
+{ extra = "sglang" },
+],
]

View File

@@ -1,26 +1,27 @@
-transformers>=4.41.2,<=4.49.0,!=4.46.*,!=4.47.*,!=4.48.*;python_version<'3.10'
-transformers>=4.41.2,<=4.49.0,!=4.46.*,!=4.47.*,!=4.48.0;python_version>='3.10'
+transformers>=4.45.0,<=4.52.4,!=4.46.*,!=4.47.*,!=4.48.0,!=4.52.0; sys_platform != 'darwin'
+transformers>=4.45.0,<=4.51.3,!=4.46.*,!=4.47.*,!=4.48.0,!=4.52.0; sys_platform == 'darwin'
-datasets>=2.16.0,<=3.2.0
+datasets>=2.16.0,<=3.6.0
-accelerate>=0.34.0,<=1.2.1
+accelerate>=0.34.0,<=1.7.0
-peft>=0.11.1,<=0.12.0
+peft>=0.14.0,<=0.15.2
trl>=0.8.6,<=0.9.6
-tokenizers>=0.19.0,<=0.21.0
+tokenizers>=0.19.0,<=0.21.1
-gradio>=4.38.0,<=5.21.0
+gradio>=4.38.0,<=5.31.0
+pandas>=2.0.0
scipy
einops
sentencepiece
tiktoken
protobuf
uvicorn
-pydantic
fastapi
sse-starlette
matplotlib>=3.7.0
fire
+omegaconf
packaging
pyyaml
numpy<2.0.0
+pydantic<=2.10.6
-pandas>=2.0.0
av
librosa
tyro<0.9.0

View File

@@ -23,8 +23,8 @@ require_version("openai>=1.5.0", "To fix: pip install openai>=1.5.0")
def main():
client = OpenAI(
-api_key="{}".format(os.environ.get("API_KEY", "0")),
-base_url="http://localhost:{}/v1".format(os.environ.get("API_PORT", 8000)),
+api_key="{}".format(os.getenv("API_KEY", "0")),
+base_url="http://localhost:{}/v1".format(os.getenv("API_PORT", 8000)),
)
messages = []
messages.append(

View File

@@ -14,7 +14,6 @@
import json
import os
-from typing import Sequence
from openai import OpenAI
from transformers.utils.versions import require_version
@@ -23,7 +22,7 @@ from transformers.utils.versions import require_version
require_version("openai>=1.5.0", "To fix: pip install openai>=1.5.0")
-def calculate_gpa(grades: Sequence[str], hours: Sequence[int]) -> float:
+def calculate_gpa(grades: list[str], hours: list[int]) -> float:
grade_to_score = {"A": 4, "B": 3, "C": 2}
total_score, total_hour = 0, 0
for grade, hour in zip(grades, hours):
@@ -34,8 +33,8 @@ def calculate_gpa(grades: Sequence[str], hours: Sequence[int]) -> float:
def main():
client = OpenAI(
-api_key="{}".format(os.environ.get("API_KEY", "0")),
-base_url="http://localhost:{}/v1".format(os.environ.get("API_PORT", 8000)),
+api_key="{}".format(os.getenv("API_KEY", "0")),
+base_url="http://localhost:{}/v1".format(os.getenv("API_PORT", 8000)),
)
tools = [
{

View File

@@ -15,7 +15,7 @@
import json
import os
from collections import OrderedDict
-from typing import Any, Dict
+from typing import Any
import fire
import torch
@@ -29,13 +29,13 @@ CONFIG_NAME = "config.json"
def save_weight(input_dir: str, output_dir: str, shard_size: str, save_safetensors: bool):
-baichuan2_state_dict: Dict[str, torch.Tensor] = OrderedDict()
+baichuan2_state_dict: dict[str, torch.Tensor] = OrderedDict()
for filepath in tqdm(os.listdir(input_dir), desc="Load weights"):
if os.path.isfile(os.path.join(input_dir, filepath)) and filepath.endswith(".bin"):
-shard_weight = torch.load(os.path.join(input_dir, filepath), map_location="cpu")
+shard_weight = torch.load(os.path.join(input_dir, filepath), map_location="cpu", weights_only=True)
baichuan2_state_dict.update(shard_weight)
-llama_state_dict: Dict[str, torch.Tensor] = OrderedDict()
+llama_state_dict: dict[str, torch.Tensor] = OrderedDict()
for key, value in tqdm(baichuan2_state_dict.items(), desc="Convert format"):
if "W_pack" in key:
proj_size = value.size(0) // 3
@@ -75,7 +75,7 @@ def save_weight(input_dir: str, output_dir: str, shard_size: str, save_safetenso
def save_config(input_dir: str, output_dir: str):
with open(os.path.join(input_dir, CONFIG_NAME), encoding="utf-8") as f:
-llama2_config_dict: Dict[str, Any] = json.load(f)
+llama2_config_dict: dict[str, Any] = json.load(f)
llama2_config_dict["architectures"] = ["LlamaForCausalLM"]
llama2_config_dict.pop("auto_map", None)
@@ -94,8 +94,8 @@ def llamafy_baichuan2(
shard_size: str = "2GB",
save_safetensors: bool = True,
):
-r"""
-Converts the Baichuan2-7B model in the same format as LLaMA2-7B.
+r"""Convert the Baichuan2-7B model in the same format as LLaMA2-7B.
Usage: python llamafy_baichuan2.py --input_dir input --output_dir output
Converted model: https://huggingface.co/hiyouga/Baichuan2-7B-Base-LLaMAfied
"""

View File

@@ -15,7 +15,7 @@
import json
import os
from collections import OrderedDict
-from typing import Any, Dict
+from typing import Any
import fire
import torch
@@ -37,14 +37,14 @@ CONFIG_NAME = "config.json"
def save_weight(input_dir: str, output_dir: str, shard_size: str, save_safetensors: bool) -> str:
-qwen_state_dict: Dict[str, torch.Tensor] = OrderedDict()
+qwen_state_dict: dict[str, torch.Tensor] = OrderedDict()
for filepath in tqdm(os.listdir(input_dir), desc="Load weights"):
if os.path.isfile(os.path.join(input_dir, filepath)) and filepath.endswith(".safetensors"):
with safe_open(os.path.join(input_dir, filepath), framework="pt", device="cpu") as f:
for key in f.keys():
qwen_state_dict[key] = f.get_tensor(key)
-llama_state_dict: Dict[str, torch.Tensor] = OrderedDict()
+llama_state_dict: dict[str, torch.Tensor] = OrderedDict()
torch_dtype = None
for key, value in tqdm(qwen_state_dict.items(), desc="Convert format"):
if torch_dtype is None:
@@ -112,9 +112,9 @@ def save_weight(input_dir: str, output_dir: str, shard_size: str, save_safetenso
def save_config(input_dir: str, output_dir: str, torch_dtype: str):
with open(os.path.join(input_dir, CONFIG_NAME), encoding="utf-8") as f:
-qwen_config_dict: Dict[str, Any] = json.load(f)
-llama2_config_dict: Dict[str, Any] = OrderedDict()
+qwen_config_dict: dict[str, Any] = json.load(f)
+llama2_config_dict: dict[str, Any] = OrderedDict()
llama2_config_dict["architectures"] = ["LlamaForCausalLM"]
llama2_config_dict["hidden_act"] = "silu"
llama2_config_dict["hidden_size"] = qwen_config_dict["hidden_size"]
@@ -147,8 +147,8 @@ def llamafy_qwen(
shard_size: str = "2GB",
save_safetensors: bool = False,
):
-r"""
-Converts the Qwen models in the same format as LLaMA2.
+r"""Convert the Qwen models in the same format as LLaMA2.
Usage: python llamafy_qwen.py --input_dir input --output_dir output
Converted model: https://huggingface.co/hiyouga/Qwen-14B-Chat-LLaMAfied
"""

View File

@@ -0,0 +1,39 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from transformers import Llama4Config, Llama4ForConditionalGeneration, Llama4TextConfig, Llama4VisionConfig
if __name__ == "__main__":
vision_config = Llama4VisionConfig(
hidden_size=1408,
image_size=336,
intermediate_size=5632,
num_attention_heads=16,
num_hidden_layers=4,
vision_output_dim=4096,
)
text_config = Llama4TextConfig(
hidden_size=512,
intermediate_size=1024,
intermediate_size_mlp=1024,
num_hidden_layers=4,
num_attention_heads=8,
num_key_value_heads=2,
head_dim=512 // 8,
num_local_experts=2,
)
config = Llama4Config(vision_config=vision_config, text_config=text_config)
model = Llama4ForConditionalGeneration._from_config(config)
model.save_pretrained("tiny-llama4")

View File

@@ -0,0 +1,79 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import logging
import time
import fire
from datasets import load_dataset
try:
import jieba # type: ignore
from nltk.translate.bleu_score import SmoothingFunction, sentence_bleu # type: ignore
from rouge_chinese import Rouge # type: ignore
jieba.setLogLevel(logging.CRITICAL)
jieba.initialize()
except ImportError:
print("Please install llamafactory with `pip install -e .[metrics]`.")
raise
def compute_metrics(sample):
hypothesis = list(jieba.cut(sample["predict"]))
reference = list(jieba.cut(sample["label"]))
bleu_score = sentence_bleu(
[list(sample["label"])],
list(sample["predict"]),
smoothing_function=SmoothingFunction().method3,
)
if len(" ".join(hypothesis).split()) == 0 or len(" ".join(reference).split()) == 0:
result = {"rouge-1": {"f": 0.0}, "rouge-2": {"f": 0.0}, "rouge-l": {"f": 0.0}}
else:
rouge = Rouge()
scores = rouge.get_scores(" ".join(hypothesis), " ".join(reference))
result = scores[0]
metric_result = {}
for k, v in result.items():
metric_result[k] = round(v["f"] * 100, 4)
metric_result["bleu-4"] = round(bleu_score * 100, 4)
return metric_result
def main(filename: str):
start_time = time.time()
dataset = load_dataset("json", data_files=filename, split="train")
dataset = dataset.map(compute_metrics, num_proc=8, remove_columns=dataset.column_names)
score_dict = dataset.to_dict()
average_score = {}
for task, scores in sorted(score_dict.items(), key=lambda x: x[0]):
print(f"{task}: {sum(scores) / len(scores):.4f}")
average_score[task] = sum(scores) / len(scores)
with open("predictions_score.json", "w", encoding="utf-8") as f:
json.dump(average_score, f, indent=4)
print(f"\nDone in {time.time() - start_time:.3f}s.\nScore file saved to predictions_score.json")
if __name__ == "__main__":
fire.Fire(main)
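The scoring script above maps compute_metrics over a JSON-lines file carrying "predict" and "label" fields, which is the format written by the vllm_infer script further down in this diff. A usage sketch, assuming the script is saved as scripts/eval_bleu_rouge.py and the predictions live in generated_predictions.jsonl (both names illustrative); the jieba/nltk/rouge-chinese dependencies come from the metrics extra, as the import guard states:

# Compute BLEU-4 and ROUGE scores for a prediction file and write predictions_score.json.
pip install -e ".[metrics]"
python scripts/eval_bleu_rouge.py generated_predictions.jsonl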

View File

@@ -18,7 +18,7 @@
import json
import os
from collections import OrderedDict
-from typing import TYPE_CHECKING, Dict
+from typing import TYPE_CHECKING
import fire
import torch
@@ -44,11 +44,11 @@ def block_expansion(
shard_size: str = "5GB",
save_safetensors: bool = True,
):
-r"""
-Performs block expansion for LLaMA, Mistral, Qwen2 or Yi models.
+r"""Perform block expansion for LLaMA, Mistral, Qwen2 or Yi models.
Usage: python llama_pro.py --model_name_or_path meta-llama/Llama-2-7b-hf --output_dir llama2_pro --num_expand 8
"""
-config: "PretrainedConfig" = AutoConfig.from_pretrained(model_name_or_path, trust_remote_code=True)
+config: PretrainedConfig = AutoConfig.from_pretrained(model_name_or_path, trust_remote_code=True)
num_layers = getattr(config, "num_hidden_layers")
if num_layers % num_expand != 0:
raise ValueError(f"`num_layers` {num_layers} should be divisible by `num_expand` {num_expand}.")
@@ -70,7 +70,7 @@ def block_expansion(
split = num_layers // num_expand
layer_cnt = 0
state_dict = model.state_dict()
-output_state_dict: Dict[str, "torch.Tensor"] = OrderedDict()
+output_state_dict: dict[str, torch.Tensor] = OrderedDict()
for i in range(num_layers):
for key, value in state_dict.items():
if f".{i:d}." in key:

View File

@@ -38,8 +38,8 @@ def quantize_loftq(
lora_target: tuple = ("q_proj", "v_proj"),
save_safetensors: bool = True,
):
-r"""
-Initializes LoRA weights with LoRA-fine-tuning-aware Quantization (LoftQ)
+r"""Initialize LoRA weights with LoRA-fine-tuning-aware Quantization (LoftQ).
Usage: python loftq_init.py --model_name_or_path path_to_model --output_dir output_dir
"""
if isinstance(lora_target, str):
@@ -72,7 +72,7 @@ def quantize_loftq(
print(f"Adapter weights saved in {loftq_dir}")
# Save base model
-base_model: "PreTrainedModel" = peft_model.unload()
+base_model: PreTrainedModel = peft_model.unload()
base_model.save_pretrained(output_dir, safe_serialization=save_safetensors)
tokenizer.save_pretrained(output_dir)
print(f"Model weights saved in {output_dir}")

View File

@@ -37,8 +37,8 @@ def quantize_pissa(
lora_target: tuple = ("q_proj", "v_proj"),
save_safetensors: bool = True,
):
-r"""
-Initializes LoRA weights with Principal Singular values and Singular vectors Adaptation (PiSSA)
+r"""Initialize LoRA weights with Principal Singular values and Singular vectors Adaptation (PiSSA).
Usage: python pissa_init.py --model_name_or_path path_to_model --output_dir output_dir
"""
if isinstance(lora_target, str):
@@ -67,7 +67,7 @@ def quantize_pissa(
print(f"Adapter weights saved in {pissa_dir}")
# Save base model
-base_model: "PreTrainedModel" = peft_model.unload()
+base_model: PreTrainedModel = peft_model.unload()
base_model.save_pretrained(output_dir, safe_serialization=save_safetensors)
tokenizer.save_pretrained(output_dir)
print(f"Model weights saved in {output_dir}")

scripts/qwen_omni_merge.py (new file, 136 lines)
View File

@@ -0,0 +1,136 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Why we need this script for qwen_omni?
Because the qwen_omni model is constructed by two parts:
1. [Thinker]:[audio_encoder, vision_encoder, LLM backbone], which our repository does support to post-training.
2. [Talker]: [audio_decoder, wave_model], which is not supported to post-training without specific tokenizer.
When we post-training the model, we exactly train the [Thinker] part, and the [Talker] part is dropped.
So, to get the complete model, we need to merge the [Talker] part back to the [Thinker] part.
LoRA mode: [Thinker + LoRA weights] + [Original Talker] -> [Omni model]
Full mode: [Thinker] + [Original Talker] -> [Omni model]
For Processor, we do saved the processor from trained model instead of the original model.
"""
import os
import shutil
import fire
from peft import PeftModel
from transformers import (
AutoProcessor,
Qwen2_5OmniForConditionalGeneration, # type: ignore
Qwen2_5OmniThinkerForConditionalGeneration,
)
def merge_lora(
base_model_path: str,
lora_checkpoint_path: str,
extra_file: str = "spk_dict.pt",
submodule_name: str = "thinker",
save_path: str = "./merged_model_checkpoint",
):
"""Load the original model, merge the LoRA weights.
For a specified submodule, and save the final merged model along with its configurations.
Args:
base_model_path (str): Path to the original model directory.
lora_checkpoint_path (str): Path to the directory containing LoRA weights.
extra_file (str): Name of the extra file to be copied (default: "spk_dict.pt").
submodule_name (str): Name of the submodule to merge (default: "thinker").
save_path (str): Directory where the merged model and configurations will be saved.
"""
# 1. Load the original model
model = Qwen2_5OmniForConditionalGeneration.from_pretrained(base_model_path, torch_dtype="auto", device_map="cpu")
print("Successfully loaded the original model.")
# 2. Extract the submodule to be merged (e.g., model.thinker)
if not hasattr(model, submodule_name):
raise AttributeError(f"The model does not have a submodule named '{submodule_name}'.")
base_submodule = getattr(model, submodule_name)
print(f"Successfully extracted submodule: {submodule_name}.")
# 3. Load the LoRA weights onto the extracted submodule
lora_model = PeftModel.from_pretrained(base_submodule, lora_checkpoint_path)
processor = AutoProcessor.from_pretrained(lora_checkpoint_path)
print("LoRA weights and processor loaded successfully.")
# 4. Merge the LoRA weights into the submodule and unload the LoRA modules
merged_submodule = lora_model.merge_and_unload()
print("LoRA weights merged successfully.")
# 5. Replace the original submodule with the merged submodule in the model
setattr(model, submodule_name, merged_submodule)
# 6. Save the final merged model along with the tokenizer and processor configuration
model.save_pretrained(save_path)
processor.save_pretrained(save_path)
print(f"Merged model and tokenizer saved to {save_path}.")
source_file = os.path.join(base_model_path, extra_file)
target_file = os.path.join(save_path, extra_file)
if os.path.exists(source_file):
shutil.copy(source_file, target_file)
print(f"File '{extra_file}' copied from {base_model_path} to {save_path}.")
else:
print(f"File '{extra_file}' not found in {base_model_path}, skipping copy.")
def save_full_model(
saved_thinker_path: str,
base_model_path: str,
save_path: str = "./merged_model_checkpoint",
extra_file: str = "spk_dict.pt",
):
"""Load the saved thinker module and the original model, replace the thinker in the original model.
Then save the complete model along with its tokenizer and processor configuration.
Args:
saved_thinker_path (str): Path to the saved thinker weights.
base_model_path (str): Directory path of the original model.
save_path (str): Directory where the merged model and configurations will be saved.
extra_file (str): Name of the extra file to be copied (default: "spk_dict.pt").
"""
# 1. Load the saved thinker module and the original model
thinker = Qwen2_5OmniThinkerForConditionalGeneration.from_pretrained(
saved_thinker_path, torch_dtype="auto", device_map="cpu"
)
base_model = Qwen2_5OmniForConditionalGeneration.from_pretrained(
base_model_path, torch_dtype="auto", device_map="cpu"
)
base_model.thinker = thinker
# 2. Save the complete model along with its tokenizer and processor configuration
processor = AutoProcessor.from_pretrained(saved_thinker_path)
base_model.save_pretrained(save_path)
processor.save_pretrained(save_path)
print(f"Merged model and processor saved to {save_path}.")
# 3. Copy the extra file from the base model directory to the save_path
source_file = os.path.join(base_model_path, extra_file)
target_file = os.path.join(save_path, extra_file)
if os.path.exists(source_file):
shutil.copy(source_file, target_file)
print(f"File '{extra_file}' copied from {base_model_path} to {save_path}.")
else:
print(f"File '{extra_file}' not found in {base_model_path}, skipping copy.")
if __name__ == "__main__":
fire.Fire({"save_full": save_full_model, "merge_lora": merge_lora})
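The script exposes the two entry points through fire, so each mode is a subcommand whose flags mirror the function arguments above. A usage sketch; the model id and checkpoint paths below are placeholders, not values taken from this diff:

# LoRA mode: merge the LoRA weights into the thinker, then reattach the original talker.
python scripts/qwen_omni_merge.py merge_lora \
    --base_model_path Qwen/Qwen2.5-Omni-7B \
    --lora_checkpoint_path saves/qwen2_5omni-7b/lora/sft \
    --save_path ./merged_model_checkpoint
# Full mode: drop a fully fine-tuned thinker back into the original Omni model.
python scripts/qwen_omni_merge.py save_full \
    --saved_thinker_path saves/qwen2_5omni-7b/full/sft \
    --base_model_path Qwen/Qwen2.5-Omni-7B \
    --save_path ./merged_model_checkpoint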

View File

@@ -29,8 +29,8 @@ def calculate_flops(
seq_length: int = 512,
flash_attn: str = "auto",
):
-r"""
-Calculates the flops of pre-trained models.
+r"""Calculate the flops of pre-trained models.
Usage: python cal_flops.py --model_name_or_path path_to_model --batch_size 1 --seq_length 512
"""
with get_accelerator().device(0):

View File

@@ -45,8 +45,8 @@ def calculate_lr(
is_mistral_or_gemma: bool = False, # mistral and gemma models opt for a smaller learning rate,
packing: bool = False,
):
-r"""
-Calculates the optimal learning rate for 7B/13B models using LLaMA's hyper-parameters.
+r"""Calculate the optimal learning rate for 7B/13B models using LLaMA's hyper-parameters.
Usage:
python cal_lr.py --model_name_or_path path_to_model --dataset alpaca_en_demo --cutoff_len 1024 --batch_size 16
"""
@@ -89,9 +89,8 @@ def calculate_lr(
lr = BASE_LR * math.sqrt(token_batch_size / BASE_BS) # lr ~ sqrt(batch_size)
lr = lr / 6.0 if is_mistral_or_gemma else lr
print(
-"Optimal learning rate is {:.2e} for valid ratio% {:.2f} and effective token batch size {:.2f}".format(
-lr, valid_ratio * 100, token_batch_size
-)
+f"Optimal learning rate is {lr:.2e} for valid ratio% {valid_ratio * 100:.2f} "
+f"and effective token batch size {token_batch_size:.2f}"
)

View File

@@ -34,9 +34,7 @@ def compute_model_flops(
include_recompute: bool = False,
include_flashattn: bool = False,
) -> int:
-r"""
-Calculates the FLOPs of model per forward/backward pass.
-"""
+r"""Calculate the FLOPs of model per forward/backward pass."""
config = AutoConfig.from_pretrained(model_name_or_path)
hidden_size = getattr(config, "hidden_size", None)
vocab_size = getattr(config, "vocab_size", None)
@@ -86,9 +84,7 @@ def compute_model_flops(
def compute_device_flops(world_size: int) -> float:
-r"""
-Calculates the FLOPs of the device capability per second.
-"""
+r"""Calculate the FLOPs of the device capability per second."""
device_name = torch.cuda.get_device_name()
if "H100" in device_name or "H800" in device_name:
return 989 * 1e12 * world_size
@@ -114,8 +110,8 @@ def calculate_mfu(
liger_kernel: bool = False,
unsloth_gc: bool = False,
) -> float:
-r"""
-Calculates MFU for given model and hyper-params.
+r"""Calculate MFU for given model and hyper-params.
Usage: python cal_mfu.py --model_name_or_path path_to_model --batch_size 1 --seq_length 1024
"""
args = {

View File

@@ -14,7 +14,7 @@
import json
from dataclasses import dataclass
-from typing import Any, Dict, Literal, Optional, Sequence
+from typing import Any, Literal, Optional
import fire
import torch
@@ -30,16 +30,12 @@ from llamafactory.model import load_model, load_tokenizer
@dataclass
class PairwiseDataCollatorWithPadding(MultiModalDataCollatorForSeq2Seq):
-r"""
-Data collator for pairwise data.
-"""
+r"""Data collator for pairwise data."""
train_on_prompt: bool = False
-def __call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, torch.Tensor]:
+def __call__(self, features: list[dict[str, Any]]) -> dict[str, torch.Tensor]:
-r"""
-Pads batched data to the longest sequence in the batch.
-"""
+r"""Pad batched data to the longest sequence in the batch."""
chosen_features = []
for feature in features:
chosen_features.append(
@@ -68,8 +64,8 @@ def calculate_ppl(
max_samples: Optional[int] = None,
train_on_prompt: bool = False,
):
-r"""
-Calculates the ppl on the dataset of the pre-trained models.
+r"""Calculate the ppl on the dataset of the pre-trained models.
Usage: export CUDA_VISIBLE_DEVICES=0
python cal_ppl.py --model_name_or_path path_to_model --dataset alpaca_en_demo --save_name ppl.json
"""
@@ -111,17 +107,17 @@ def calculate_ppl(
criterion = torch.nn.CrossEntropyLoss(reduction="none")
total_ppl = 0
perplexities = []
-batch: Dict[str, "torch.Tensor"]
+batch: dict[str, torch.Tensor]
with torch.no_grad():
for batch in tqdm(dataloader, desc="Computing perplexities"):
batch = batch.to(model.device)
outputs = model(**batch)
-shift_logits: "torch.Tensor" = outputs["logits"][..., :-1, :]
-shift_labels: "torch.Tensor" = batch["labels"][..., 1:]
+shift_logits: torch.Tensor = outputs["logits"][..., :-1, :]
+shift_labels: torch.Tensor = batch["labels"][..., 1:]
loss_mask = shift_labels != IGNORE_INDEX
flatten_logits = shift_logits.contiguous().view(shift_labels.size(0) * shift_labels.size(1), -1)
flatten_labels = shift_labels.contiguous().view(-1)
-token_logps: "torch.Tensor" = criterion(flatten_logits, flatten_labels)
+token_logps: torch.Tensor = criterion(flatten_logits, flatten_labels)
token_logps = token_logps.contiguous().view(shift_logits.size(0), -1)
sentence_logps = (token_logps * loss_mask).sum(-1) / loss_mask.sum(-1)
total_ppl += sentence_logps.exp().sum().item()

View File

@@ -29,8 +29,8 @@ def length_cdf(
template: str = "default",
interval: int = 1000,
):
-r"""
-Calculates the distribution of the input lengths in the dataset.
+r"""Calculate the distribution of the input lengths in the dataset.
Usage: export CUDA_VISIBLE_DEVICES=0
python length_cdf.py --model_name_or_path path_to_model --dataset alpaca_en_demo --template default
"""

View File

@@ -12,15 +12,17 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+import gc
import json
from typing import Optional
import fire
+from tqdm import tqdm
from transformers import Seq2SeqTrainingArguments
from llamafactory.data import get_dataset, get_template_and_fix_tokenizer
from llamafactory.extras.constants import IGNORE_INDEX
-from llamafactory.extras.misc import check_version, get_device_count
+from llamafactory.extras.misc import get_device_count
from llamafactory.extras.packages import is_vllm_available
from llamafactory.hparams import get_infer_args
from llamafactory.model import load_tokenizer
@@ -47,16 +49,20 @@ def vllm_infer(
max_new_tokens: int = 1024,
repetition_penalty: float = 1.0,
skip_special_tokens: bool = True,
+default_system: Optional[str] = None,
+enable_thinking: bool = True,
seed: Optional[int] = None,
pipeline_parallel_size: int = 1,
image_max_pixels: int = 768 * 768,
image_min_pixels: int = 32 * 32,
+video_fps: float = 2.0,
+video_maxlen: int = 128,
+batch_size: int = 1024,
):
-r"""
-Performs batch generation using vLLM engine, which supports tensor parallelism.
+r"""Perform batch generation using vLLM engine, which supports tensor parallelism.
Usage: python vllm_infer.py --model_name_or_path meta-llama/Llama-2-7b-hf --template llama --dataset alpaca_en_demo
"""
-check_version("vllm>=0.4.3,<=0.7.3")
if pipeline_parallel_size > get_device_count():
raise ValueError("Pipeline parallel size should be smaller than the number of gpus.")
@@ -70,6 +76,8 @@ def vllm_infer(
cutoff_len=cutoff_len,
max_samples=max_samples,
preprocessing_num_workers=16,
+default_system=default_system,
+enable_thinking=enable_thinking,
vllm_config=vllm_config,
temperature=temperature,
top_p=top_p,
@@ -84,26 +92,28 @@ def vllm_infer(
tokenizer = tokenizer_module["tokenizer"]
template_obj = get_template_and_fix_tokenizer(tokenizer, data_args)
template_obj.mm_plugin.expand_mm_tokens = False # for vllm generate
+engine_args = {
+"model": model_args.model_name_or_path,
+"trust_remote_code": True,
+"dtype": model_args.infer_dtype,
+"max_model_len": cutoff_len + max_new_tokens,
+"tensor_parallel_size": (get_device_count() // pipeline_parallel_size) or 1,
+"pipeline_parallel_size": pipeline_parallel_size,
+"disable_log_stats": True,
+"enable_lora": model_args.adapter_name_or_path is not None,
+}
+if template_obj.mm_plugin.__class__.__name__ != "BasePlugin":
+engine_args["limit_mm_per_prompt"] = {"image": 4, "video": 2, "audio": 2}
+if isinstance(model_args.vllm_config, dict):
+engine_args.update(model_args.vllm_config)
+llm = LLM(**engine_args)
+# load datasets
dataset_module = get_dataset(template_obj, model_args, data_args, training_args, "ppo", **tokenizer_module)
+train_dataset = dataset_module["train_dataset"]
-inputs, prompts, labels = [], [], []
-for sample in dataset_module["train_dataset"]:
-if sample["images"]:
-multi_modal_data = {
-"image": template_obj.mm_plugin._regularize_images(
-sample["images"], image_max_pixels=image_max_pixels, image_min_pixels=image_min_pixels
-)
-}
-else:
-multi_modal_data = None
-inputs.append({"prompt_token_ids": sample["input_ids"], "multi_modal_data": multi_modal_data})
-prompts.append(tokenizer.decode(sample["input_ids"], skip_special_tokens=skip_special_tokens))
-labels.append(
-tokenizer.decode(
-list(filter(lambda x: x != IGNORE_INDEX, sample["labels"])), skip_special_tokens=skip_special_tokens
-)
-)
sampling_params = SamplingParams(
repetition_penalty=generating_args.repetition_penalty or 1.0, # repetition_penalty must > 0
@@ -120,30 +130,68 @@ def vllm_infer(
else:
lora_request = None
-engine_args = {
-"model": model_args.model_name_or_path,
-"trust_remote_code": True,
-"dtype": model_args.infer_dtype,
-"max_model_len": cutoff_len + max_new_tokens,
-"tensor_parallel_size": (get_device_count() // pipeline_parallel_size) or 1,
-"pipeline_parallel_size": pipeline_parallel_size,
-"disable_log_stats": True,
-"enable_lora": model_args.adapter_name_or_path is not None,
-}
-if template_obj.mm_plugin.__class__.__name__ != "BasePlugin":
-engine_args["limit_mm_per_prompt"] = {"image": 4, "video": 2}
-if isinstance(model_args.vllm_config, dict):
-engine_args.update(model_args.vllm_config)
-results = LLM(**engine_args).generate(inputs, sampling_params, lora_request=lora_request)
-preds = [result.outputs[0].text for result in results]
+# Store all results in these lists
+all_prompts, all_preds, all_labels = [], [], []
+# Add batch process to avoid the issue of too many files opened
+for i in tqdm(range(0, len(train_dataset), batch_size), desc="Processing batched inference"):
+vllm_inputs, prompts, labels = [], [], []
+batch = train_dataset[i : min(i + batch_size, len(train_dataset))]
+for j in range(len(batch["input_ids"])):
+if batch["images"][j] is not None:
+image = batch["images"][j]
+multi_modal_data = {
+"image": template_obj.mm_plugin._regularize_images(
+image, image_max_pixels=image_max_pixels, image_min_pixels=image_min_pixels
+)["images"]
+}
+elif batch["videos"][j] is not None:
+video = batch["videos"][j]
+multi_modal_data = {
+"video": template_obj.mm_plugin._regularize_videos(
+video,
+image_max_pixels=image_max_pixels,
+image_min_pixels=image_min_pixels,
+video_fps=video_fps,
+video_maxlen=video_maxlen,
+)["videos"]
+}
+elif batch["audios"][j] is not None:
+audio = batch["audios"][j]
+audio_data = template_obj.mm_plugin._regularize_audios(
+audio,
+sampling_rate=16000,
+)
+multi_modal_data = {"audio": zip(audio_data["audios"], audio_data["sampling_rates"])}
+else:
+multi_modal_data = None
+vllm_inputs.append({"prompt_token_ids": batch["input_ids"][j], "multi_modal_data": multi_modal_data})
+prompts.append(tokenizer.decode(batch["input_ids"][j], skip_special_tokens=skip_special_tokens))
+labels.append(
+tokenizer.decode(
+list(filter(lambda x: x != IGNORE_INDEX, batch["labels"][j])),
+skip_special_tokens=skip_special_tokens,
+)
+)
+results = llm.generate(vllm_inputs, sampling_params, lora_request=lora_request)
+preds = [result.outputs[0].text for result in results]
+# Accumulate results
+all_prompts.extend(prompts)
+all_preds.extend(preds)
+all_labels.extend(labels)
+gc.collect()
+# Write all results at once outside the loop
with open(save_name, "w", encoding="utf-8") as f:
-for text, pred, label in zip(prompts, preds, labels):
+for text, pred, label in zip(all_prompts, all_preds, all_labels):
f.write(json.dumps({"prompt": text, "predict": pred, "label": label}, ensure_ascii=False) + "\n")
print("*" * 70)
-print(f"{len(prompts)} generated results have been saved at {save_name}.")
+print(f"{len(all_prompts)} total generated results have been saved at {save_name}.")
print("*" * 70)

View File

@@ -14,7 +14,6 @@
import os
import re
-from typing import List
from setuptools import find_packages, setup
@@ -27,14 +26,14 @@ def get_version() -> str:
return version
-def get_requires() -> List[str]:
+def get_requires() -> list[str]:
with open("requirements.txt", encoding="utf-8") as f:
file_content = f.read()
lines = [line.strip() for line in file_content.strip().split("\n") if not line.startswith("#")]
return lines
-def get_console_scripts() -> List[str]:
+def get_console_scripts() -> list[str]:
console_scripts = ["llamafactory-cli = llamafactory.cli:main"]
if os.getenv("ENABLE_SHORT_CONSOLE", "1").lower() in ["true", "y", "1"]:
console_scripts.append("lmf = llamafactory.cli:main")
@@ -43,23 +42,22 @@ def get_console_scripts() -> List[str]:
extra_require = {
-"torch": ["torch>=1.13.1"],
+"torch": ["torch>=2.0.0", "torchvision>=0.15.0"],
"torch-npu": ["torch==2.4.0", "torch-npu==2.4.0.post2", "decorator"],
"metrics": ["nltk", "jieba", "rouge-chinese"],
-"deepspeed": ["deepspeed>=0.10.0,<=0.16.4"],
+"deepspeed": ["deepspeed>=0.10.0,<=0.16.9"],
-"liger-kernel": ["liger-kernel"],
+"liger-kernel": ["liger-kernel>=0.5.5"],
"bitsandbytes": ["bitsandbytes>=0.39.0"],
"hqq": ["hqq"],
"eetq": ["eetq"],
-"gptq": ["optimum>=1.17.0", "auto-gptq>=0.5.0"],
+"gptq": ["optimum>=1.24.0", "gptqmodel>=2.0.0"],
-"awq": ["autoawq"],
"aqlm": ["aqlm[gpu]>=1.1.0"],
-"vllm": ["vllm>=0.4.3,<=0.7.3"],
+"vllm": ["vllm>=0.4.3,<=0.8.6"],
+"sglang": ["sglang[srt]>=0.4.5", "transformers==4.51.1"],
"galore": ["galore-torch"],
"apollo": ["apollo-torch"],
"badam": ["badam>=1.2.1"],
"adam-mini": ["adam-mini"],
-"qwen": ["transformers_stream_generator"],
"minicpm_v": [
"soundfile",
"torchvision",
@@ -73,7 +71,7 @@ extra_require = {
"modelscope": ["modelscope"],
"openmind": ["openmind"],
"swanlab": ["swanlab"],
-"dev": ["pre-commit", "ruff", "pytest"],
+"dev": ["pre-commit", "ruff", "pytest", "build"],
}
@@ -82,11 +80,11 @@ def main():
name="llamafactory",
version=get_version(),
author="hiyouga",
-author_email="hiyouga AT buaa.edu.cn",
+author_email="hiyouga@buaa.edu.cn",
-description="Easy-to-use LLM fine-tuning framework",
+description="Unified Efficient Fine-Tuning of 100+ LLMs",
long_description=open("README.md", encoding="utf-8").read(),
long_description_content_type="text/markdown",
-keywords=["LLaMA", "BLOOM", "Falcon", "LLM", "ChatGPT", "transformer", "pytorch", "deep learning"],
+keywords=["AI", "LLM", "GPT", "ChatGPT", "Llama", "Transformer", "DeepSeek", "Pytorch"],
license="Apache 2.0 License",
url="https://github.com/hiyouga/LLaMA-Factory",
package_dir={"": "src"},

View File

@@ -12,29 +12,13 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
r""" r"""Efficient fine-tuning of large language models.
Efficient fine-tuning of large language models.
Level: Level:
api, webui > chat, eval, train > data, model > hparams > extras api, webui > chat, eval, train > data, model > hparams > extras
Dependency graph:
main:
transformers>=4.41.2,<=4.49.0,!=4.46.*,!=4.47.*,!=4.48.0
datasets>=2.16.0,<=3.2.0
accelerate>=0.34.0,<=1.2.1
peft>=0.11.1,<=0.12.0
trl>=0.8.6,<=0.9.6
attention:
transformers>=4.42.4 (gemma+fa2)
longlora:
transformers>=4.41.2,<4.48.0
packing:
transformers>=4.43.0
Disable version checking: DISABLE_VERSION_CHECK=1 Disable version checking: DISABLE_VERSION_CHECK=1
Enable VRAM recording: RECORD_VRAM=1 Enable VRAM recording: RECORD_VRAM=1
Force check imports: FORCE_CHECK_IMPORTS=1
Force using torchrun: FORCE_TORCHRUN=1 Force using torchrun: FORCE_TORCHRUN=1
Set logging verbosity: LLAMAFACTORY_VERBOSITY=WARN Set logging verbosity: LLAMAFACTORY_VERBOSITY=WARN
Use modelscope: USE_MODELSCOPE_HUB=1 Use modelscope: USE_MODELSCOPE_HUB=1

View File

@@ -16,9 +16,7 @@ import asyncio
import os
from contextlib import asynccontextmanager
from functools import partial
-from typing import Optional
+from typing import Annotated, Optional
-from typing_extensions import Annotated
from ..chat import ChatModel
from ..extras.constants import EngineName

View File

@@ -18,11 +18,12 @@ import json
import os
import re
import uuid
-from typing import TYPE_CHECKING, AsyncGenerator, Dict, List, Optional, Tuple
+from collections.abc import AsyncGenerator
+from typing import TYPE_CHECKING, Optional
from ..data import Role as DataRole
from ..extras import logging
-from ..extras.constants import IMAGE_PLACEHOLDER
+from ..extras.constants import AUDIO_PLACEHOLDER, IMAGE_PLACEHOLDER, VIDEO_PLACEHOLDER
from ..extras.misc import is_env_enabled
from ..extras.packages import is_fastapi_available, is_pillow_available, is_requests_available
from .common import dictify, jsonify
@@ -55,7 +56,7 @@ if is_requests_available():
if TYPE_CHECKING:
from ..chat import ChatModel
-from ..data.mm_plugin import ImageInput
+from ..data.mm_plugin import AudioInput, ImageInput, VideoInput
from .protocol import ChatCompletionRequest, ScoreEvaluationRequest
@@ -71,7 +72,14 @@ ROLE_MAPPING = {
def _process_request( def _process_request(
request: "ChatCompletionRequest", request: "ChatCompletionRequest",
) -> Tuple[List[Dict[str, str]], Optional[str], Optional[str], Optional[List["ImageInput"]]]: ) -> tuple[
list[dict[str, str]],
Optional[str],
Optional[str],
Optional[list["ImageInput"]],
Optional[list["VideoInput"]],
Optional[list["AudioInput"]],
]:
if is_env_enabled("API_VERBOSE", "1"): if is_env_enabled("API_VERBOSE", "1"):
logger.info_rank0(f"==== request ====\n{json.dumps(dictify(request), indent=2, ensure_ascii=False)}") logger.info_rank0(f"==== request ====\n{json.dumps(dictify(request), indent=2, ensure_ascii=False)}")
@@ -79,7 +87,8 @@ def _process_request(
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid length") raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid length")
if request.messages[0].role == Role.SYSTEM: if request.messages[0].role == Role.SYSTEM:
system = request.messages.pop(0).content content = request.messages.pop(0).content
system = content[0].text if isinstance(content, list) else content
else: else:
system = None system = None
@@ -87,7 +96,7 @@ def _process_request(
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Only supports u/a/u/a/u...") raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Only supports u/a/u/a/u...")
input_messages = [] input_messages = []
images = [] images, videos, audios = [], [], []
for i, message in enumerate(request.messages): for i, message in enumerate(request.messages):
if i % 2 == 0 and message.role not in [Role.USER, Role.TOOL]: if i % 2 == 0 and message.role not in [Role.USER, Role.TOOL]:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid role") raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid role")
@@ -106,7 +115,7 @@ def _process_request(
for input_item in message.content: for input_item in message.content:
if input_item.type == "text": if input_item.type == "text":
text_content += input_item.text text_content += input_item.text
else: elif input_item.type == "image_url":
text_content += IMAGE_PLACEHOLDER text_content += IMAGE_PLACEHOLDER
image_url = input_item.image_url.url image_url = input_item.image_url.url
if re.match(r"^data:image\/(png|jpg|jpeg|gif|bmp);base64,(.+)$", image_url): # base64 image if re.match(r"^data:image\/(png|jpg|jpeg|gif|bmp);base64,(.+)$", image_url): # base64 image
@@ -117,6 +126,32 @@ def _process_request(
image_stream = requests.get(image_url, stream=True).raw image_stream = requests.get(image_url, stream=True).raw
images.append(Image.open(image_stream).convert("RGB")) images.append(Image.open(image_stream).convert("RGB"))
elif input_item.type == "video_url":
text_content += VIDEO_PLACEHOLDER
video_url = input_item.video_url.url
if re.match(r"^data:video\/(mp4|mkv|avi|mov);base64,(.+)$", video_url): # base64 video
video_stream = io.BytesIO(base64.b64decode(video_url.split(",", maxsplit=1)[1]))
elif os.path.isfile(video_url): # local file
video_stream = open(video_url, "rb")
else: # web uri
video_stream = requests.get(video_url, stream=True).raw
videos.append(video_stream)
elif input_item.type == "audio_url":
text_content += AUDIO_PLACEHOLDER
audio_url = input_item.audio_url.url
if re.match(r"^data:audio\/(mpeg|mp3|wav|ogg);base64,(.+)$", audio_url): # base64 audio
audio_stream = io.BytesIO(base64.b64decode(audio_url.split(",", maxsplit=1)[1]))
elif os.path.isfile(audio_url): # local file
audio_stream = open(audio_url, "rb")
else: # web uri
audio_stream = requests.get(audio_url, stream=True).raw
audios.append(audio_stream)
else:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST, detail=f"Invalid input type {input_item.type}."
)
input_messages.append({"role": ROLE_MAPPING[message.role], "content": text_content}) input_messages.append({"role": ROLE_MAPPING[message.role], "content": text_content})
else: else:
@@ -131,7 +166,7 @@ def _process_request(
else: else:
tools = None tools = None
return input_messages, system, tools, images or None return input_messages, system, tools, images or None, videos or None, audios or None
def _create_stream_chat_completion_chunk( def _create_stream_chat_completion_chunk(
@@ -150,17 +185,20 @@ async def create_chat_completion_response(
request: "ChatCompletionRequest", chat_model: "ChatModel" request: "ChatCompletionRequest", chat_model: "ChatModel"
) -> "ChatCompletionResponse": ) -> "ChatCompletionResponse":
completion_id = f"chatcmpl-{uuid.uuid4().hex}" completion_id = f"chatcmpl-{uuid.uuid4().hex}"
input_messages, system, tools, images = _process_request(request) input_messages, system, tools, images, videos, audios = _process_request(request)
responses = await chat_model.achat( responses = await chat_model.achat(
input_messages, input_messages,
system, system,
tools, tools,
images, images,
videos,
audios,
do_sample=request.do_sample, do_sample=request.do_sample,
temperature=request.temperature, temperature=request.temperature,
top_p=request.top_p, top_p=request.top_p,
max_new_tokens=request.max_tokens, max_new_tokens=request.max_tokens,
num_return_sequences=request.n, num_return_sequences=request.n,
repetition_penalty=request.presence_penalty,
stop=request.stop, stop=request.stop,
) )
@@ -201,7 +239,7 @@ async def create_stream_chat_completion_response(
request: "ChatCompletionRequest", chat_model: "ChatModel" request: "ChatCompletionRequest", chat_model: "ChatModel"
) -> AsyncGenerator[str, None]: ) -> AsyncGenerator[str, None]:
completion_id = f"chatcmpl-{uuid.uuid4().hex}" completion_id = f"chatcmpl-{uuid.uuid4().hex}"
input_messages, system, tools, images = _process_request(request) input_messages, system, tools, images, videos, audios = _process_request(request)
if tools: if tools:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Cannot stream function calls.") raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Cannot stream function calls.")
@@ -216,10 +254,13 @@ async def create_stream_chat_completion_response(
system, system,
tools, tools,
images, images,
videos,
audios,
do_sample=request.do_sample, do_sample=request.do_sample,
temperature=request.temperature, temperature=request.temperature,
top_p=request.top_p, top_p=request.top_p,
max_new_tokens=request.max_tokens, max_new_tokens=request.max_tokens,
repetition_penalty=request.presence_penalty,
stop=request.stop, stop=request.stop,
): ):
if len(new_token) != 0: if len(new_token) != 0:
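Taken together, these hunks let _process_request accept OpenAI-style content parts of type image_url, video_url and audio_url (given as data: URIs, local paths, or web URLs) and forward them to the chat engine, with the request's presence_penalty passed through as repetition_penalty. A hedged client-side sketch; the host, port, model name and media path are placeholders, and it assumes the server exposes the usual OpenAI-compatible /v1/chat/completions route:

import requests

payload = {
    "model": "qwen2_vl",  # placeholder model name
    "messages": [
        {
            "role": "user",
            "content": [
                {"type": "text", "text": "What is happening in this clip?"},
                {"type": "video_url", "video_url": {"url": "/data/demo.mp4"}},  # local file path
            ],
        }
    ],
    # Forwarded to the engine as repetition_penalty, so HF semantics apply
    # (values above 1.0 discourage repetition).
    "presence_penalty": 1.1,
    "max_tokens": 256,
}
resp = requests.post("http://localhost:8000/v1/chat/completions", json=payload, timeout=300)
print(resp.json()["choices"][0]["message"]["content"])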


@@ -13,14 +13,14 @@
 # limitations under the License.
 
 import json
-from typing import TYPE_CHECKING, Any, Dict
+from typing import TYPE_CHECKING, Any
 
 
 if TYPE_CHECKING:
     from pydantic import BaseModel
 
 
-def dictify(data: "BaseModel") -> Dict[str, Any]:
+def dictify(data: "BaseModel") -> dict[str, Any]:
     try:  # pydantic v2
         return data.model_dump(exclude_unset=True)
     except AttributeError:  # pydantic v1
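dictify itself is unchanged apart from the return annotation; the pydantic v2 path it relies on behaves as sketched below (Req is a made-up model used only for illustration):

from typing import Optional
from pydantic import BaseModel

class Req(BaseModel):  # hypothetical model, not part of this diff
    model: str
    temperature: Optional[float] = None

# exclude_unset drops fields the caller never provided, which keeps the
# logged request payloads compact.
print(Req(model="demo").model_dump(exclude_unset=True))  # {'model': 'demo'}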


@@ -14,7 +14,7 @@
 
 import time
 from enum import Enum, unique
-from typing import Any, Dict, List, Optional, Union
+from typing import Any, Optional, Union
 
 from pydantic import BaseModel, Field
 from typing_extensions import Literal
@@ -45,7 +45,7 @@ class ModelCard(BaseModel):
 
 class ModelList(BaseModel):
     object: Literal["list"] = "list"
-    data: List[ModelCard] = []
+    data: list[ModelCard] = []
 
 
 class Function(BaseModel):
@@ -56,7 +56,7 @@ class Function(BaseModel):
 class FunctionDefinition(BaseModel):
     name: str
     description: str
-    parameters: Dict[str, Any]
+    parameters: dict[str, Any]
 
 
 class FunctionAvailable(BaseModel):
@@ -70,38 +70,42 @@ class FunctionCall(BaseModel):
     function: Function
 
 
-class ImageURL(BaseModel):
+class URL(BaseModel):
     url: str
+    detail: Literal["auto", "low", "high"] = "auto"
 
 
 class MultimodalInputItem(BaseModel):
-    type: Literal["text", "image_url"]
+    type: Literal["text", "image_url", "video_url", "audio_url"]
     text: Optional[str] = None
-    image_url: Optional[ImageURL] = None
+    image_url: Optional[URL] = None
+    video_url: Optional[URL] = None
+    audio_url: Optional[URL] = None
 
 
 class ChatMessage(BaseModel):
     role: Role
-    content: Optional[Union[str, List[MultimodalInputItem]]] = None
-    tool_calls: Optional[List[FunctionCall]] = None
+    content: Optional[Union[str, list[MultimodalInputItem]]] = None
+    tool_calls: Optional[list[FunctionCall]] = None
 
 
 class ChatCompletionMessage(BaseModel):
     role: Optional[Role] = None
     content: Optional[str] = None
-    tool_calls: Optional[List[FunctionCall]] = None
+    tool_calls: Optional[list[FunctionCall]] = None
 
 
 class ChatCompletionRequest(BaseModel):
     model: str
-    messages: List[ChatMessage]
-    tools: Optional[List[FunctionAvailable]] = None
+    messages: list[ChatMessage]
+    tools: Optional[list[FunctionAvailable]] = None
     do_sample: Optional[bool] = None
     temperature: Optional[float] = None
     top_p: Optional[float] = None
     n: int = 1
+    presence_penalty: Optional[float] = None
     max_tokens: Optional[int] = None
-    stop: Optional[Union[str, List[str]]] = None
+    stop: Optional[Union[str, list[str]]] = None
     stream: bool = False
@@ -128,7 +132,7 @@ class ChatCompletionResponse(BaseModel):
     object: Literal["chat.completion"] = "chat.completion"
     created: int = Field(default_factory=lambda: int(time.time()))
     model: str
-    choices: List[ChatCompletionResponseChoice]
+    choices: list[ChatCompletionResponseChoice]
     usage: ChatCompletionResponseUsage
@@ -137,12 +141,12 @@ class ChatCompletionStreamResponse(BaseModel):
     object: Literal["chat.completion.chunk"] = "chat.completion.chunk"
     created: int = Field(default_factory=lambda: int(time.time()))
     model: str
-    choices: List[ChatCompletionStreamResponseChoice]
+    choices: list[ChatCompletionStreamResponseChoice]
 
 
 class ScoreEvaluationRequest(BaseModel):
     model: str
-    messages: List[str]
+    messages: list[str]
     max_length: Optional[int] = None
@@ -150,4 +154,4 @@ class ScoreEvaluationResponse(BaseModel):
     id: str
     object: Literal["score.evaluation"] = "score.evaluation"
     model: str
-    scores: List[float]
+    scores: list[float]
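The new video_url and audio_url fields accept the same three source forms the request handler above recognizes: a base64 data: URI, a local path, or a web URL. A small sketch of building the base64 form for an audio clip (the file path is a placeholder):

import base64
from pathlib import Path

audio_path = Path("sample.wav")  # placeholder file
audio_b64 = base64.b64encode(audio_path.read_bytes()).decode("utf-8")

# Matches the pattern checked by the server: data:audio/<format>;base64,<payload>
content_part = {
    "type": "audio_url",
    "audio_url": {"url": f"data:audio/wav;base64,{audio_b64}"},
}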


@@ -13,8 +13,9 @@
 # limitations under the License.
 
 from abc import ABC, abstractmethod
+from collections.abc import AsyncGenerator
 from dataclasses import dataclass
-from typing import TYPE_CHECKING, Any, AsyncGenerator, Dict, List, Literal, Optional, Sequence, Union
+from typing import TYPE_CHECKING, Any, Literal, Optional, Union
 
 
 if TYPE_CHECKING:
@@ -36,8 +37,7 @@ class Response:
 
 class BaseEngine(ABC):
-    r"""
-    Base class for inference engine of chat models.
+    r"""Base class for inference engine of chat models.
 
     Must implements async methods: chat(), stream_chat() and get_scores().
     """
@@ -47,7 +47,7 @@ class BaseEngine(ABC):
     tokenizer: "PreTrainedTokenizer"
     can_generate: bool
     template: "Template"
-    generating_args: Dict[str, Any]
+    generating_args: dict[str, Any]
 
     @abstractmethod
     def __init__(
@@ -57,50 +57,42 @@ class BaseEngine(ABC):
         finetuning_args: "FinetuningArguments",
         generating_args: "GeneratingArguments",
     ) -> None:
-        r"""
-        Initializes an inference engine.
-        """
+        r"""Initialize an inference engine."""
         ...
 
     @abstractmethod
     async def chat(
         self,
-        messages: Sequence[Dict[str, str]],
+        messages: list[dict[str, str]],
         system: Optional[str] = None,
         tools: Optional[str] = None,
-        images: Optional[Sequence["ImageInput"]] = None,
-        videos: Optional[Sequence["VideoInput"]] = None,
-        audios: Optional[Sequence["AudioInput"]] = None,
+        images: Optional[list["ImageInput"]] = None,
+        videos: Optional[list["VideoInput"]] = None,
+        audios: Optional[list["AudioInput"]] = None,
         **input_kwargs,
-    ) -> List["Response"]:
-        r"""
-        Gets a list of responses of the chat model.
-        """
+    ) -> list["Response"]:
+        r"""Get a list of responses of the chat model."""
         ...
 
     @abstractmethod
     async def stream_chat(
         self,
-        messages: Sequence[Dict[str, str]],
+        messages: list[dict[str, str]],
         system: Optional[str] = None,
         tools: Optional[str] = None,
-        images: Optional[Sequence["ImageInput"]] = None,
-        videos: Optional[Sequence["VideoInput"]] = None,
-        audios: Optional[Sequence["AudioInput"]] = None,
+        images: Optional[list["ImageInput"]] = None,
+        videos: Optional[list["VideoInput"]] = None,
+        audios: Optional[list["AudioInput"]] = None,
        **input_kwargs,
     ) -> AsyncGenerator[str, None]:
-        r"""
-        Gets the response token-by-token of the chat model.
-        """
+        r"""Get the response token-by-token of the chat model."""
         ...
 
     @abstractmethod
     async def get_scores(
         self,
-        batch_input: List[str],
+        batch_input: list[str],
         **input_kwargs,
-    ) -> List[float]:
-        r"""
-        Gets a list of scores of the reward model.
-        """
+    ) -> list[float]:
        r"""Get a list of scores of the reward model."""
         ...
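The recurring pattern in this file, and in the hunks above, is the move to PEP 585 built-in generics and collections.abc: list[...] and dict[...] replace typing.List/Dict, and AsyncGenerator is imported from collections.abc, all of which require Python 3.9 or newer. A minimal sketch of the style, with a made-up function name:

from collections.abc import AsyncGenerator
from typing import Optional

async def echo_stream(
    messages: list[dict[str, str]], stop: Optional[list[str]] = None
) -> AsyncGenerator[str, None]:
    # Yield each message's text; stands in for a token-by-token stream.
    for message in messages:
        yield message.get("content", "")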

Some files were not shown because too many files have changed in this diff.