187 Commits

Author SHA1 Message Date
hiyouga
c0c387e4db release v0.8.0
Former-commit-id: 004db680b9e3996ec511ee818df6c0c02bf13603
2024-06-08 05:20:54 +08:00
hiyouga
ae60ea15da add ultrafeedback and fineweb #4085 #4132
Former-commit-id: 968e4992e2f2a3ccba73e8668f1654ddc6eb0034
2024-06-08 02:42:34 +08:00
hiyouga
72cd1123a8 fix ci
Former-commit-id: 3f4d293fd861d765edb2040f80d16f99a5e1e3c6
2024-06-08 02:00:44 +08:00
hiyouga
1364190a66 fix ci
Former-commit-id: 95aceebd61d195be5c980a919c12c59b56722898
2024-06-08 01:57:36 +08:00
hiyouga
6d17c59090 add ci
Former-commit-id: 3ea3acdadaa54abe33d93538580196cfdd91ee56
2024-06-08 01:48:30 +08:00
hiyouga
e0f2c0b5dc init unittest
Former-commit-id: 1c6f21cb8878ced043fe0b27c72cad2ef6ee990e
2024-06-08 01:35:58 +08:00
hiyouga
073e34855d Delete .readthedocs.yaml
Former-commit-id: dd3ee514216a9a329519c58d79208040adcad126
2024-06-08 00:58:10 +08:00
hiyouga
ff9ba70bb8 reorganize adapter code
Former-commit-id: b26c2df9d97f4efffccbf7d28de13619b43f10dd
2024-06-08 00:47:23 +08:00
hoshi-hiyouga
adbebb0e3f fix #4139
Former-commit-id: c025a4d74f293c14c2705e68af20a82a84608520
2024-06-08 00:45:02 +08:00
hiyouga
3f6b3eed98 add resume args in webui
Former-commit-id: 1d86ad768b1f36e54b4c2a9f18f6ea5a7df04c90
2024-06-08 00:22:16 +08:00
hiyouga
f45e81e186 fix #4137
Former-commit-id: cdc0d6f5a2e5040e145c82c4801f37bd76529047
2024-06-07 19:16:06 +08:00
hiyouga
ba648fd003 tiny fix
Former-commit-id: 0621bcad1dfbe8ce2464f741d4256c5df2a8d1b6
2024-06-07 05:19:21 +08:00
hiyouga
b0e5a76f4c fix ppo trainer save zero3 model
accelerator.get_state_dict(ds_model) should be called at all ranks


Former-commit-id: 3a0f60f0aa072531e4ae5819ec00c8fa42aa0913
2024-06-07 05:14:19 +08:00
hiyouga
8692796c9b fix ppo in trl 0.8.6
Former-commit-id: 5e0d66a0d80b4bd4a8506e2317209d8fb9d25ff6
2024-06-07 04:48:29 +08:00
hiyouga
d0edcde4ea fix #4120
Former-commit-id: 2a44da678a5e360a9c0f9056397ac9e801329321
2024-06-07 04:18:05 +08:00
hiyouga
8c4c2e580c update data processors
Former-commit-id: 04b138cbcb8b9a72e4bbda6c65843bb459e525e7
2024-06-07 04:15:40 +08:00
hoshi-hiyouga
07f33e7641 Merge pull request #4009 from AlongWY/main
supervised packing with greedy knapsack algorithm

Former-commit-id: 5ded166b39a75a98ded5733678f5a1eab7d4cc71
2024-06-07 03:48:46 +08:00
hoshi-hiyouga
1998c641af Update supervised.py
Former-commit-id: 04b6c2a754e602e0b698cfe6c255c2f2486d8865
2024-06-07 03:42:08 +08:00
hoshi-hiyouga
be1e5f9d62 Update supervised.py
Former-commit-id: 49993c4f4e1f871a22ff0196afe60026b668a4dc
2024-06-07 03:38:23 +08:00
hoshi-hiyouga
fdeec6db52 Update supervised.py
Former-commit-id: 67625b5278a839c12a3e4245f9e90af67d8b11b4
2024-06-07 03:38:04 +08:00
hiyouga
a4d335b42f add qwen2 models
Former-commit-id: 49cb694d02c876e3740a003a8b332349f4310ad3
2024-06-07 00:22:57 +08:00
hiyouga
fcb134e144 rename files
Former-commit-id: e1a8431770fc36c0c9ee7fed4abbc3d7fdcc5efd
2024-06-07 00:09:06 +08:00
hiyouga
a47e24222a add DISABLE_TORCHRUN option
Former-commit-id: bcc574b479c2101438723aadead42743d4378776
2024-06-06 23:44:58 +08:00
hoshi-hiyouga
b96b995620 Merge pull request #4082 from MengqingCao/bugfix
Fix #4077

Former-commit-id: 288028c3fb6bb1b58d1b7f4e8b90108c9bbf27d1
2024-06-06 23:38:40 +08:00
hoshi-hiyouga
c231706aa5 Update cli.py
Former-commit-id: 32190507534adf5f505858b3af2b592ca6568ac7
2024-06-06 23:38:09 +08:00
hiyouga
35b5117a59 fix ppo+zero3 #3108
Former-commit-id: 33a93cc29e3e57bf001515000c0a70c112573dea
2024-06-06 23:30:07 +08:00
hiyouga
80f716bc10 fix torch gc
Former-commit-id: e173799d057598e5692a407601c30d8ce1513461
2024-06-06 20:30:25 +08:00
hiyouga
ca95e98ca0 fix ppo dataset bug #4012
Former-commit-id: 7fc51b2e93698ae5e012566af8481f4d861c873d
2024-06-06 19:03:20 +08:00
hiyouga
d5559461c1 update trainers
Former-commit-id: b7f6c4a171293cf4f3e88f15a811f847342f84ee
2024-06-06 18:45:49 +08:00
hiyouga
f4acd81e2f fix base64 image read #4061
Former-commit-id: 66ccb2a27a04296b4600f2c85f428071bf14eeb0
2024-06-06 17:29:19 +08:00
hiyouga
31feb6e26c update readme
Former-commit-id: cc331fa2d28afe081937c50ea83d63add21d4e3a
2024-06-06 16:59:18 +08:00
hiyouga
7d5c0a069c update readme
Former-commit-id: fb1f709af5199976e63d7188e088e33c75d19bfe
2024-06-06 16:25:42 +08:00
hiyouga
937f49ec3d lora modules: all by default
Former-commit-id: 52c4ae87c7f4312704c31ef26b079b2c5b95ea5f
2024-06-06 03:53:28 +08:00
hiyouga
abc2a73a33 add codestral 22B
Former-commit-id: b011c7f527a57cb1d21c4e2c9631c2fb62bb835e
2024-06-06 03:42:50 +08:00
hiyouga
5e1bf7572c lint
Former-commit-id: 9030501eaef97ea249347198272adf0d709503ec
2024-06-06 03:33:44 +08:00
hoshi-hiyouga
8fdb32d0a3 Merge pull request #4066 from injet-zhou/main
add throughput entry to training log

Former-commit-id: d2816f343f405f3fab09f2a8eade774b886e8f92
2024-06-06 03:32:04 +08:00
hoshi-hiyouga
c709d5f7db Merge pull request #4080 from MengqingCao/npu
Add npu option for model exporting

Former-commit-id: 07fc67193ef6bcb8e8a392aff0c57a2eb36832bf
2024-06-06 03:15:44 +08:00
hoshi-hiyouga
f5b2749ec2 Update export.py
Former-commit-id: 694833c1104d13929d4f181f014a121f25955dc5
2024-06-06 03:14:46 +08:00
hoshi-hiyouga
ee5853c565 Update model_args.py
Former-commit-id: 09c0afd94a8a5f5b45a61b32c983d50e1b9e2941
2024-06-06 03:14:23 +08:00
hoshi-hiyouga
6ec6df8a5f Merge pull request #4053 from hzhaoy/feature/add_select_config_file
Support selecting saved configuration files

Former-commit-id: 568ef3cf2a793f268cbe01c39dec418a13e61ecd
2024-06-06 03:06:03 +08:00
hiyouga
fc95800840 add vllm_dtype arg #3387 #3717
Former-commit-id: a0dd3a6351bb78541d40fec1d2fc457d803c86a4
2024-06-06 02:53:27 +08:00
hiyouga
765715af21 support train from scratch #4033 #4075
Former-commit-id: 1290b9d01077e62f8de7a23637daa2586cc82bfa
2024-06-06 02:43:19 +08:00
hiyouga
639a7f6796 support image input in api #3971 #4061
Former-commit-id: c70aaf763ef22fb83ce3635e8ffd5ec4c89c1cb0
2024-06-06 02:29:55 +08:00
hiyouga
35379c7c0e update train hparams
Former-commit-id: 1ca9fce55b55bf209f4b76152b586731932a3f39
2024-06-06 01:49:20 +08:00
hiyouga
d992f5353f fix setup
Former-commit-id: b2b80d434fcc0c3838d229098e1c21d26632204c
2024-06-06 01:39:02 +08:00
hiyouga
875eef45f3 add llamafactory-cli env
Former-commit-id: 1df077184845ff5f394b9324d46f8c382869e590
2024-06-06 01:28:14 +08:00
hiyouga
556a4aa972 fix #4090
Former-commit-id: d9f15f30a8f4bc64778a5c96baeb6801700d7a2c
2024-06-06 00:50:32 +08:00
MengqingCao
8dc1969111 modify export_device option
Former-commit-id: b2fc4a5499e21a5b9622c2285402efef6e27a74d
2024-06-05 09:37:36 +00:00
hiyouga
b74c229498 fix #4079
Former-commit-id: fda732d7f4616373844c97beff416880260f49db
2024-06-05 16:56:54 +08:00
hiyouga
3dbca466fd update readme
Former-commit-id: 02d34db29a7a35c25711d49e98fd3167a2f4dfe7
2024-06-05 16:32:32 +08:00
MengqingCao
ce6f7fdb82 fix #4077
Former-commit-id: fedbe92f3b56294acc6c49f9a51e369cf2de3ead
2024-06-05 08:03:30 +00:00
hiyouga
7528bc1bc0 support glm-4
Former-commit-id: a10f4718fbf3f3c89dc7eb31cb8e1a46ca6adda5
2024-06-05 15:16:38 +08:00
MengqingCao
9dd5f7d642 add npu for model export
Former-commit-id: ce020b6eb3f35c1db37ee4835e694eddcd0f59b0
2024-06-05 07:06:40 +00:00
faddddeout
99ecb0daaf add throughput entry to log
Former-commit-id: 691f999f64c7bac78761e4354f89816d2f0d46fc
2024-06-04 11:04:29 +00:00
hzhaoy
39d8d7995a add: support selecting saved configuration files and loading training parameters
Former-commit-id: 5c9b17c1dc9093da0ea813642bce9b5c9ae96274
2024-06-04 10:33:43 +08:00
hiyouga
2ac2cde03e tiny fix
Former-commit-id: f9d50501aac1f60a3b445ca3fee9aa60995461ee
2024-06-04 00:31:10 +08:00
hiyouga
aa6c3766de fix #3873
Former-commit-id: 1ac325b4d682bb493573c18bb0b67ceae8d0d372
2024-06-04 00:21:50 +08:00
hiyouga
f4f5d7e3ce fix #3992
Former-commit-id: a48321fbf5196b88a11106cf74a74fbcea2ea50b
2024-06-04 00:17:36 +08:00
hiyouga
efbf6018d3 fix abort in webui DDP mode
Former-commit-id: b90ac72d753b13a3eed9cb8b898fac2f2fe5153f
2024-06-04 00:10:24 +08:00
hoshi-hiyouga
1090bb8bf3 Merge pull request #3987 from injet-zhou/main
Fix cann't interrupt training when using multi GPUs in webui

Former-commit-id: 455bb158b0e600723d2afaa2070b71178f2f5188
2024-06-04 00:04:07 +08:00
hiyouga
26bc79f971 fix #4043
Former-commit-id: 67af68f4fc5232760c57b3a0ae780628da09db6a
2024-06-03 23:30:37 +08:00
hiyouga
4c1f015eca remove gc warnings in DPO&KTO
Former-commit-id: b649bdcbafb464a638387429b770fe258b41f8af
2024-06-03 22:53:54 +08:00
hoshi-hiyouga
0655a183d3 Merge pull request #4045 from enji-zhou/feature/add_kto
fix KTO Trainer Sampler

Former-commit-id: 8e235beb9cf4939c06ccb753b047326a9839e77f
2024-06-03 22:09:25 +08:00
hoshi-hiyouga
7754024e9b Update trainer.py
Former-commit-id: 8565d4b43db905374c328ae57c71fc226980d14f
2024-06-03 22:08:38 +08:00
enji.zhou
b4913569a8 fix KTO Trainer Sampler
Former-commit-id: 39eb1bfa272011554322e9bb2534f83b68282a70
2024-06-03 21:32:38 +08:00
hoshi-hiyouga
eae9f09ca8 Merge pull request #4006 from Uminosachi/scheduler-kwargs
Set scheduler_specific_kwargs to get_scheduler

Former-commit-id: c6ed1955fd8990ddb960750913c9d8b13fe0ace3
2024-06-03 19:27:53 +08:00
hiyouga
8264e5ceaa update placeholder in issue template
Former-commit-id: 5503a90d7e38273b67129e0b9eb62bd1fd23154f
2024-06-03 19:24:10 +08:00
hoshi-hiyouga
b76f319e45 Merge pull request #4011 from statelesshz/issue-template
Update bug-report.yml

Former-commit-id: 1fbc46f45ae4e673f0b20b5eacab3d81d1053807
2024-06-03 19:20:43 +08:00
hiyouga
82d744716a fix #4005 #4013
Former-commit-id: 8608fa268cde5cddf8d0c6c2eb2cb5fa246c1831
2024-06-03 19:12:29 +08:00
hoshi-hiyouga
1a3764ab8f Merge pull request #4007 from xu-song/patch-3
Update model_args.py

Former-commit-id: d88b3a0f2707bcc964f642d348295b99f7c796f8
2024-06-03 18:54:37 +08:00
hiyouga
d2ede9d393 fix #4022
Former-commit-id: 9541f2f1f1b7d7877eb734f051048e52003a3430
2024-06-03 18:38:36 +08:00
hiyouga
5690f513fc bump versions
transformers 4.37.2->4.41.2
datasets 2.14.3->2.16.0
accelerate 0.27.2->0.30.1
peft 0.10.0->0.11.1
trl 0.8.1->0.8.6


Former-commit-id: 5f1e041f7295bf42a41dd4d9e7f0c42fcc37fed2
2024-06-03 18:29:38 +08:00
hiyouga
123a845209 fix data loader hint
Former-commit-id: 25b56126a11591b0155e2f72b673dd8f45a6c8c9
2024-06-03 18:28:27 +08:00
ylfeng
b1b7d735b3 remove empty line
Former-commit-id: 3164710971a6d6545629f5bf133f98de5ff0991a
2024-05-31 21:43:08 +08:00
ylfeng
230c69f7ce fix eos
Former-commit-id: 6e236c952958cbfe50b5dcb7b8eff6aea8477922
2024-05-31 21:40:41 +08:00
ylfeng
bfc43558ef supervised packing with greedy knapsack algorithm
Former-commit-id: 24d12396c9aabd49da0b08719068f24679111cc6
2024-05-31 15:33:54 +08:00
Xu Song
f2ae2cc04d Update model_args.py
Former-commit-id: f1e018587e5722e41962abd60f74043a3e55f692
2024-05-31 14:35:48 +08:00
statelesshz
6e9c03f958 Update bug-report.yml
Former-commit-id: a8561502360c1e247eeacb46b77ffbcf3387c482
2024-05-31 13:18:18 +08:00
Uminosachi
2696f614a7 Set scheduler_specific_kwargs to get_scheduler
Former-commit-id: f04e70dfab44480ef4c015c06470443237f69ba9
2024-05-31 13:45:39 +09:00
hiyouga
070b944895 update readme
Former-commit-id: 3b92d8c2ddb288b849f38e573ca168cab23315d2
2024-05-30 16:40:17 +08:00
faddddeout
f5f091d390 fix cann't interrupt training when using multi GPUs in webui
Former-commit-id: a7fb02d52bc202c958490aa7081252be5d9eff50
2024-05-30 08:39:21 +00:00
hiyouga
14ab14a0e6 fix #3837
Former-commit-id: 72965aa3f13a9c085c29781b6790d80d00a545d8
2024-05-30 00:52:26 +08:00
hoshi-hiyouga
4f7c850115 Merge pull request #3829 from seanzhang-zhichen/add_dataset_sample_num
Add dataset sample num

Former-commit-id: ab38cf74ce48ea4f1800e077ca287f2eb9336135
2024-05-30 00:25:45 +08:00
hoshi-hiyouga
391eca66cf Update loader.py
Former-commit-id: 0aa59322906d91c5e385c9c02ebb5dd64ba060f3
2024-05-30 00:20:20 +08:00
hoshi-hiyouga
a67199246d Update loader.py
Former-commit-id: aa7f335e3ad5a78e4ed5f99c120be28e9733ea2e
2024-05-30 00:17:21 +08:00
hoshi-hiyouga
5f67fdaac9 Update loader.py
Former-commit-id: 19d8fd62c18ee3ba0e431fc241f7d315cb716fef
2024-05-30 00:12:12 +08:00
hoshi-hiyouga
05e6fe4287 Update parser.py
Former-commit-id: 310cc11e8c83f16fc5bccc349c38fea347ea9a97
2024-05-30 00:05:20 +08:00
hoshi-hiyouga
91cc571e6e Update README_zh.md
Former-commit-id: 3007d260ed45169583a74497a53b661337dd5f71
2024-05-30 00:04:47 +08:00
hoshi-hiyouga
890926e60c Update README.md
Former-commit-id: 65fb69e388c0a04c15ecd11441e567966f51fae5
2024-05-30 00:04:26 +08:00
hiyouga
87aa332583 better llamaboard
* easily resume from checkpoint
* support full and freeze checkpoints
* faster ui


Former-commit-id: 84cfb2452cc86b037ccddee6e833f8eb7c129fa4
2024-05-29 23:55:38 +08:00
hiyouga
f90c4ca672 fix cohere system
Former-commit-id: 5d629b29e705c8ff8dd4521719d9c0e67a3fe0a2
2024-05-29 20:58:23 +08:00
hiyouga
a922e85a5c fix #3965
Former-commit-id: 37d15ac55d0be0ff47d6a88f07e2d823117a4a36
2024-05-29 20:55:51 +08:00
hiyouga
9a65820592 update readme
Former-commit-id: 440e9de66986ef7736361ce8ec3e23ce68655a56
2024-05-29 18:39:11 +08:00
hoshi-hiyouga
f4e16ae373 Merge pull request #3930 from MengqingCao/npu
Add Ascend npu doc and dependency

Former-commit-id: 7210090e4fc6531b9f6122f104875811a8798185
2024-05-29 18:33:38 +08:00
MengqingCao
e2cfd34da0 update torch-npu version
Former-commit-id: a70d7fcf2967eb30280a1fb845b39db7878f535c
2024-05-29 10:05:11 +00:00
MengqingCao
668dea9706 update cann kernels url
Former-commit-id: 23c65e9d7e8817b5815264e44cbf4a7bcb88d3d7
2024-05-29 09:53:31 +00:00
hoshi-hiyouga
084be442f2 Merge pull request #3958 from hzhaoy/add_telechat_12b_support
add TeleChat-12B/TeleChat-12B-v2 models

Former-commit-id: c228546a09764423ae66966079802022185f7e86
2024-05-29 17:20:53 +08:00
hzhaoy
29cb4a1327 add TeleChat-12B/TeleChat-12B-v2 models
Former-commit-id: e0675385c88af03aaef8d51586c8a282829c4051
2024-05-29 15:00:37 +08:00
hiyouga
81a61134b8 fix hf chat engine
Former-commit-id: 76ce52911690ab0dd8ffa5587127afb4ec942abe
2024-05-29 01:20:07 +08:00
hiyouga
cb1a49aa02 add ds config to webui
Former-commit-id: 66d72b263d36dc81de9f6152077663b613035977
2024-05-29 01:13:17 +08:00
hiyouga
351b4efc6c 10x generate in ppo w/ zero3
https://github.com/huggingface/trl/pull/1483

Former-commit-id: 5dc43ba8b373d8803bc22d88b3d0d95ef8b9c7f8
2024-05-29 00:23:23 +08:00
hiyouga
9b551309de update dpo, kto trainer
Former-commit-id: 4a6cc3c7046f8b27d05ea53ef216bab6fa7ebfaf
2024-05-29 00:14:29 +08:00
hiyouga
9fed4a2ef4 clean kto trainer
Former-commit-id: 76402bd78cbd3a99a544f0ac019468b569b0e1d1
2024-05-28 21:43:26 +08:00
hiyouga
bceac4f554 bump vllm version to 0.4.1
Former-commit-id: a00fd39a4c2f270620711f2bfbad8d460fb4aa89
2024-05-28 21:27:27 +08:00
hiyouga
ae3a88d3a7 update readme
Former-commit-id: bc861f76706df3f643028f1dfc8ec2044b067a08
2024-05-28 19:35:52 +08:00
hiyouga
9138a7a5ba support DDP in webui
Former-commit-id: d059262ff8dc857f597d2657546ec625726a664a
2024-05-28 19:24:22 +08:00
hiyouga
9912b43fcc update readme
Former-commit-id: e2c7de1b5147801b301cfc5da0e2866273da18f5
2024-05-28 16:41:34 +08:00
hiyouga
5ac37555a4 update readme
Former-commit-id: 30ef8ee1e86136f38f105b67f70c417d20552f41
2024-05-28 16:19:56 +08:00
hiyouga
34bdc730a6 fix #3931
Former-commit-id: 47e0072416b545d9718af4fa266a83f747b9a4f7
2024-05-28 13:44:22 +08:00
MengqingCao
e45a9d70fc add Ascend npu doc and dependency
Former-commit-id: 803d9f142a294f8c1e0b4e2046c214b0857ccfd6
2024-05-28 01:33:54 +00:00
hoshi-hiyouga
232b36059c Merge pull request #3925 from Yimi81/feat-fix-yi-template
fix yi template

Former-commit-id: 6caee1eb868b9f7b00578c6608883e89aa232d17
2024-05-27 22:59:32 +08:00
Yimi81
d9fbd675d5 fix yi template
Former-commit-id: b3669c8989c3adda305416245e32e9e5a3b7caac
2024-05-27 13:11:25 +00:00
hiyouga
0206e7b9de tiny fix
Former-commit-id: 4c47b3dcef9e400a1c35fce1ad53619a0a86fe81
2024-05-27 20:54:26 +08:00
hoshi-hiyouga
a886544d3d Merge pull request #3921 from gusye1234/main
Add openchat-3.6-8B support

Former-commit-id: 92e6bba3cab22b7835a68f787caf7992a398978e
2024-05-27 20:52:37 +08:00
hoshi-hiyouga
8c9b929bb0 Update template.py
Former-commit-id: f4dabce0a71c9978e051e70886941b64b928ffe2
2024-05-27 20:51:56 +08:00
hoshi-hiyouga
1bb1ae834e Update template.py
Former-commit-id: af869e4c48eb426c4078415533f6dab89123a9d8
2024-05-27 20:51:26 +08:00
Jianbai Ye
0d9e364a90 add openchat-3.6-8B support
Former-commit-id: b66f39d50d896d7597a1506e67ec210b31c9b700
2024-05-27 20:42:08 +08:00
hiyouga
3b28c003dd fix full/freeze tuning for mllm
Former-commit-id: df5860ddb593d5b82163a585d12160b41dbce0f3
2024-05-27 20:37:57 +08:00
hoshi-hiyouga
48ff9fb150 Merge pull request #3835 from BUAADreamer/main
fix some features in llava-style training

Former-commit-id: fc8583bd17dfb088a52e4d8fa91356b918373b50
2024-05-27 20:23:45 +08:00
hiyouga
c43bc74fe6 support Aya23
Former-commit-id: 071935b90006e2c79e39bb9ee0c5d48c6c910501
2024-05-27 20:23:24 +08:00
BUAADreamer
eaf9cc2195 Merge branch 'hiyouga:main' into main
Former-commit-id: cc1b82bf49b060987392c455fdbfe125ad667ec5
2024-05-27 20:10:58 +08:00
hiyouga
4bd276f58f add llava 1k datasets
Former-commit-id: 345d3355752f4a4dc454696a39f1610fffbbf382
2024-05-27 19:57:33 +08:00
hiyouga
f8cf0d5e5d update dpo examples
Former-commit-id: 69e32a7cb6336ca9a953c379ec794818b3f169bd
2024-05-27 19:56:04 +08:00
BUAADreamer
79bc60db33 Merge branch 'hiyouga:main' into main
Former-commit-id: d89e1f8bf8bad1dd125b4de8fe6c0b2b16411cb5
2024-05-27 19:00:48 +08:00
BUAADreamer
dc7c54067e add only tune lm and mm_proj
Former-commit-id: ba12ca430ec527fbfe4cd1eace0adb5c7712146a
2024-05-27 19:00:15 +08:00
BUAADreamer
932f0d5c20 add regex of only tune lm and mm_proj
Former-commit-id: 38d540b3e69bceabafafab524fcfc78aeb05612d
2024-05-27 18:59:00 +08:00
hiyouga
9670f5e41a add phi-3 7b/14b, mistral v0.3 models
Former-commit-id: 86dab182f9710b063f518922ccb49b01aa71c576
2024-05-27 18:20:16 +08:00
hiyouga
97a23e1cbe update readme
Former-commit-id: b8d0170fe0d094acce85dcb5f91775e4685ee055
2024-05-27 18:14:02 +08:00
BUAADreamer
11fcd055ec Merge branch 'hiyouga:main' into main
Former-commit-id: 113be744b3d044fbea3a8654158aa83ddb4599eb
2024-05-27 11:54:01 +08:00
hiyouga
b0d9966663 support SimPO #3900
Former-commit-id: 6b954ce60155cf8334150b795cfc4bb63ca74c8b
2024-05-26 23:46:33 +08:00
BUAADreamer
5c51ab7e1f Merge branch 'hiyouga:main' into main
Former-commit-id: fd5420c43e1414bcd3fadb6239f4e5d42e6ac10e
2024-05-25 14:18:49 +08:00
hiyouga
26f293d587 fix #3853
Former-commit-id: 465a5500bae1f30744d4b9b3db40aaf9171da2cb
2024-05-24 23:29:45 +08:00
seanzhang-zhichen
a3b52fd380 Merge branch 'main' into add_dataset_sample_num
Former-commit-id: 26300127c45f24e63b91f1b0cc73e46c3a936a91
2024-05-24 15:57:47 +08:00
BUAADreamer
27d8706d6d Merge branch 'hiyouga:main' into main
Former-commit-id: a4ce5ee381fd59f6b254ab634af51b6bb54edd97
2024-05-24 09:50:00 +08:00
hiyouga
bf59383783 refactor data preprocessing, fix mllm rlhf
Former-commit-id: 53ff2dd24f9121ea30c95063bb72e49a9b31e980
2024-05-24 04:08:25 +08:00
hoshi-hiyouga
1078611259 Merge pull request #3876 from dongdongqiang2018/main
added adapted to 910B image

Former-commit-id: 0708cc8a24589b9f22ad3df6685e57d1da0336f2
2024-05-24 01:54:30 +08:00
hiyouga
e6fc0ac8fe fix paligemma sft
requires transformers>=4.41.1


Former-commit-id: 80b3030569cd606ac0de43e9a682478f5bd7b727
2024-05-24 00:23:40 +08:00
hiyouga
554ca3d8dc fix oom issues in export
Former-commit-id: b7ccc882a192aa1e25b1e5816f875ea304282412
2024-05-23 23:32:45 +08:00
donggang
86dfdf956d adapted to 910B image
Former-commit-id: e095254808aace63a1be878620f683902f51cfb3
2024-05-23 09:48:22 +00:00
BUAADreamer
c0e4475485 Merge branch 'hiyouga:main' into main
Former-commit-id: 4076f52c8ba7da4624a1fb3fa52a7170d1c3171e
2024-05-21 22:18:20 +08:00
hiyouga
2b65f8bd5c fix paligemma sft
Former-commit-id: 60682d04414be37e611d6470618a8d599703942b
2024-05-21 20:03:09 +08:00
hiyouga
09e78272c2 Update README_zh.md
Former-commit-id: 34c4ba6bf9bb89170446fb396aa06ae44d251de0
2024-05-21 18:30:59 +08:00
hiyouga
cccce564bd update wechat
Former-commit-id: 6613349562194b48c5fc57aa68e620b8fa83fc0a
2024-05-21 18:22:32 +08:00
hiyouga
4adec327de fix #3847
Former-commit-id: d206b306ca4eadc8b3d4feaf490ad12f9452e562
2024-05-21 17:53:06 +08:00
BUAADreamer
1f093334d1 support pretraining of llava
Former-commit-id: 6a4c8cf0a6a1674c693b9337f018ff8df7477f8f
2024-05-21 08:57:14 +08:00
hiyouga
e0e8507108 support paligemma
Former-commit-id: 11c27f9bf204d3d6a9ca5bd4f0a19a420160453f
2024-05-21 00:01:22 +08:00
hiyouga
f5962f8128 fix paligemma data preprocess
Former-commit-id: 71b85437301739d9d96d3881d4a34b37c0f69db8
2024-05-20 23:51:32 +08:00
hiyouga
b31d808655 fix paligemma inference
Former-commit-id: 46357b7a677e8ba2e0a7c9d4ec1974abd061569c
2024-05-20 23:36:43 +08:00
hiyouga
247cda4b68 fix #3818
Former-commit-id: 3f366e05a34be224f53c5bf8334e57ae5d316004
2024-05-20 21:43:19 +08:00
hiyouga
e30975e9a2 add kto to webui
Former-commit-id: 6c866f4dbd45e868860be8351d1a65c4e1a4e02b
2024-05-20 21:20:25 +08:00
zhangzc
de9f1583c2 fix conflict
Former-commit-id: 6922b23a748c2459147bf44b96d86daa89f2c96c
2024-05-20 17:10:01 +08:00
hiyouga
ab48653e63 fix chat engines
do not use pop(key, default) since api assigns None to dict values


Former-commit-id: 3ebbd0b55ea07de2897c27ca54eeab5c3b319419
2024-05-20 00:36:43 +08:00
hoshi-hiyouga
6d7a1e3f8f Merge pull request #3812 from ycjcl868/feat/chat-support-system-prompt
feat: cli chat support system_message
Former-commit-id: 96596990527403e910c81e95e38bf2638541cf31
2024-05-20 00:31:32 +08:00
hoshi-hiyouga
e093dad7cb Update vllm_engine.py
Former-commit-id: 0b8278bd21baf35d3f60c6ed24f110b391c92a47
2024-05-20 00:31:04 +08:00
hoshi-hiyouga
b103a121f0 Update hf_engine.py
Former-commit-id: ce8b902e538c69d89f207db8a43c85072cd70265
2024-05-20 00:30:45 +08:00
hoshi-hiyouga
3578abc7a4 Update generating_args.py
Former-commit-id: 861c146fa7d9cb5b99372464bd068c20fa36415d
2024-05-20 00:29:31 +08:00
hoshi-hiyouga
17d398f419 Update chat_model.py
Former-commit-id: 7736aafdc81d175e9fb484dbb7cae9263120a0fc
2024-05-20 00:29:12 +08:00
hiyouga
3453a8eebb fix jinja template
Former-commit-id: 353561f0e3914de3f81499c4e4b831ae0a6383b6
2024-05-19 23:38:30 +08:00
ycjcl868
77a089c35c feat: cli chat support system_message
Former-commit-id: e3982bff596d01992733687a580c4f41c558061c
2024-05-19 23:17:46 +08:00
hiyouga
516d83c946 fix zero2 high ram usage
Former-commit-id: 01797126eb173250250e31f8e76b69ae0047745d
2024-05-19 21:53:54 +08:00
hiyouga
fd02c9f973 fix hf gen args
Former-commit-id: 491a84976258cbb2a2647922420e2f84de1e38cd
2024-05-19 19:39:32 +08:00
hiyouga
351e80a656 fix envs
Former-commit-id: d5e150cfb98f8216713415564ab386b8320c88cb
2024-05-19 18:27:18 +08:00
hiyouga
4f04e2ed93 fix #3807
Former-commit-id: 08b695969049de8bf9bd3e90b9700736d90385ee
2024-05-19 17:07:57 +08:00
hiyouga
a810d1b98e update readme
Former-commit-id: e0beb67a417b13c818a09bd419d4e20dd44ca842
2024-05-18 23:09:03 +08:00
hiyouga
fbe963a96a safe output path in webui
Former-commit-id: 23f14262e0d54631630c084ba71e0433ea1d4640
2024-05-18 22:42:28 +08:00
hiyouga
d13b8bee8a fix jetmoe z3 block
Former-commit-id: cb00a14d905395c4b8fadb955f0424a4c56668de
2024-05-18 22:28:45 +08:00
hiyouga
0aa072a155 improve data process logger
Former-commit-id: 33d0b012b56dbafc9fff87b821c2d1bf1409dbb5
2024-05-18 22:02:42 +08:00
hiyouga
57dde7c3bc update data readme
Former-commit-id: 22c7335b496e4a673383d5a1e4e60bf2cb4e35b3
2024-05-18 21:37:38 +08:00
hiyouga
6b9003f781 update data readme
Former-commit-id: beb864a9367943d3274cb6057423d1eb9aaf85c4
2024-05-18 21:15:20 +08:00
hiyouga
9c1c59e481 fix #3803
Former-commit-id: 1ef12c95059d14a1717c82ce04e529e7ad6435ed
2024-05-18 16:13:14 +08:00
hoshi-hiyouga
31daec2749 Merge pull request #3799 from hiyouga/dev
improve KTO impl, replace datasets

Former-commit-id: b4cc207855aa1dbb120f7999165e176e649af338
2024-05-18 03:49:13 +08:00
hiyouga
2bff90719b improve KTO impl., replace datasets
Former-commit-id: e56a57ddcf061de6e4acc8679f7dbf0b68364986
2024-05-18 03:44:56 +08:00
hoshi-hiyouga
e4570e28a8 Merge pull request #3785 from enji-zhou/feature/add_kto
add kto

Former-commit-id: f60faa23e23022fd855dac6b1ecbd21e095bccb5
2024-05-18 03:07:18 +08:00
hoshi-hiyouga
d84a730daa Merge pull request #3794 from jue-jue-zi/main
feat: pass the `max_lora_rank` parameter to vLLM backend
Former-commit-id: be839961686a1845f00a56e398a7b3779df8b6e4
2024-05-17 16:17:30 +08:00
hoshi-hiyouga
0fd1a05cec Update model_args.py
Former-commit-id: f40a2fe5334865763e4d513292d359317b7a091b
2024-05-17 16:16:41 +08:00
juejuezi
6373d307ec feat: pass the max_lora_rank parameter to vLLM backend
Former-commit-id: a8756d839405ecb5deabe885cf11d1a61564deee
2024-05-17 16:07:39 +08:00
hiyouga
a32c3a50fc add deepseek v2 lite model
Former-commit-id: 5e864e6b721d8b891b1cc2ca2dcac41babb9eaaf
2024-05-17 13:25:36 +08:00
enji.zhou
66b5634ebf add kto
Former-commit-id: ec51986cf70b0bdd79b8141e45916670fb97a08e
2024-05-17 13:09:17 +08:00
hiyouga
92b3697e2c update badam example #3764
Former-commit-id: a3730fd0a96bab869be6d695031182dabaea8137
2024-05-17 02:21:10 +08:00
hiyouga
969e605c7e better dtype handle in loading
Former-commit-id: 663f0577dd61a1a31191db2c6fbb0c7cea533b21
2024-05-17 02:14:56 +08:00
hiyouga
a3320f26cf update examples
Former-commit-id: 3b5f138155d96b346bda18e465cf60ec7d99e19c
2024-05-17 01:02:00 +08:00
hiyouga
45329d9e3c enable inbrowser in webui
Former-commit-id: 71fdeedb64b2339eb1c740d670b87e0c03dada68
2024-05-17 00:08:56 +08:00
hiyouga
6481321470 add falcon 11b
Former-commit-id: 897acc725edc204fad393cc9616828431b4fa768
2024-05-17 00:08:33 +08:00
hiyouga
efcf5e050d fix examples #3769
Former-commit-id: 80c036beb8d9ddac8f844f1818c9488ded04e86e
2024-05-16 19:12:09 +08:00
hiyouga
dfa686b617 rename package
Former-commit-id: a07ff0c083558cfe6f474d13027642d3052fee08
2024-05-16 18:39:08 +08:00
hiyouga
fe638cf11f set dev version
Former-commit-id: 5e9c72d07c3793cdccbdb8a9f95f1bb5d714e0a3
2024-05-16 02:17:31 +08:00
zhangzc
7cdc16abdf Supports custom data set sampling quantity
Former-commit-id: fa8325401df27595de4611a89dfcc14644956abd
2024-03-27 14:22:50 +08:00
195 changed files with 4296 additions and 2422 deletions

View File

@@ -4,6 +4,8 @@
.venv .venv
cache cache
data data
hf_cache
output
examples examples
.dockerignore .dockerignore
.gitattributes .gitattributes

View File

@@ -13,6 +13,18 @@ body:
- label: I have read the README and searched the existing issues. - label: I have read the README and searched the existing issues.
required: true required: true
- type: textarea
id: system-info
validations:
required: true
attributes:
label: System Info
description: |
Please share your system info with us. You can run the command **llamafactory-cli env** and copy-paste its output below.
请提供您的系统信息。您可以在命令行运行 **llamafactory-cli env** 并将其输出复制到该文本框中。
placeholder: llamafactory version, platform, python version, ...
- type: textarea - type: textarea
id: reproduction id: reproduction
validations: validations:
@@ -26,7 +38,7 @@ body:
请合理使用 Markdown 标签来格式化您的文本。 请合理使用 Markdown 标签来格式化您的文本。
placeholder: | placeholder: |
python src/train_bash.py ... llamafactory-cli train ...
- type: textarea - type: textarea
id: expected-behavior id: expected-behavior
@@ -38,18 +50,6 @@ body:
Please provide a clear and concise description of what you would expect to happen. Please provide a clear and concise description of what you would expect to happen.
请提供您原本的目的,即这段代码的期望行为。 请提供您原本的目的,即这段代码的期望行为。
- type: textarea
id: system-info
validations:
required: false
attributes:
label: System Info
description: |
Please share your system info with us. You can run the command **transformers-cli env** and copy-paste its output below.
请提供您的系统信息。您可以在命令行运行 **transformers-cli env** 并将其输出复制到该文本框中。
placeholder: transformers version, platform, python version, ...
- type: textarea - type: textarea
id: others id: others
validations: validations:

View File

@@ -2,28 +2,38 @@ name: tests
on: on:
push: push:
branches: [ "main" ] branches:
- main
paths:
- "**.py"
- "requirements.txt"
- ".github/workflows/*.yml"
pull_request: pull_request:
branches: [ "main" ] branches:
- main
paths:
- "**.py"
- "requirements.txt"
- ".github/workflows/*.yml"
jobs: jobs:
check_code_quality: tests:
runs-on: ubuntu-latest runs-on: ubuntu-latest
steps: steps:
- uses: actions/checkout@v4 - uses: actions/checkout@v4
- name: Set up Python - name: Set up Python
uses: actions/setup-python@v5 uses: actions/setup-python@v5
with: with:
python-version: "3.8" python-version: "3.8"
cache: "pip"
cache-dependency-path: "setup.py"
- name: Install dependencies - name: Install dependencies
run: | run: |
python -m pip install --upgrade pip python -m pip install --upgrade pip
python -m pip install ruff python -m pip install .[torch,dev]
- name: Check quality - name: Check quality
run: | run: |
make style && make quality make style && make quality
- name: Test with pytest
run: |
make test

View File

@@ -6,7 +6,7 @@ COPY requirements.txt /app/
RUN pip install -r requirements.txt RUN pip install -r requirements.txt
COPY . /app/ COPY . /app/
RUN pip install -e .[deepspeed,metrics,bitsandbytes,qwen] RUN pip install -e .[metrics,bitsandbytes,qwen]
VOLUME [ "/root/.cache/huggingface/", "/app/data", "/app/output" ] VOLUME [ "/root/.cache/huggingface/", "/app/data", "/app/output" ]
EXPOSE 7860 EXPOSE 7860

View File

@@ -1,4 +1,4 @@
.PHONY: quality style .PHONY: quality style test
check_dirs := scripts src tests check_dirs := scripts src tests
@@ -9,3 +9,6 @@ quality:
style: style:
ruff check $(check_dirs) --fix ruff check $(check_dirs) --fix
ruff format $(check_dirs) ruff format $(check_dirs)
test:
pytest tests/

206
README.md
View File

@@ -3,15 +3,15 @@
[![GitHub Repo stars](https://img.shields.io/github/stars/hiyouga/LLaMA-Factory?style=social)](https://github.com/hiyouga/LLaMA-Factory/stargazers) [![GitHub Repo stars](https://img.shields.io/github/stars/hiyouga/LLaMA-Factory?style=social)](https://github.com/hiyouga/LLaMA-Factory/stargazers)
[![GitHub Code License](https://img.shields.io/github/license/hiyouga/LLaMA-Factory)](LICENSE) [![GitHub Code License](https://img.shields.io/github/license/hiyouga/LLaMA-Factory)](LICENSE)
[![GitHub last commit](https://img.shields.io/github/last-commit/hiyouga/LLaMA-Factory)](https://github.com/hiyouga/LLaMA-Factory/commits/main) [![GitHub last commit](https://img.shields.io/github/last-commit/hiyouga/LLaMA-Factory)](https://github.com/hiyouga/LLaMA-Factory/commits/main)
[![PyPI](https://img.shields.io/pypi/v/llmtuner)](https://pypi.org/project/llmtuner/) [![PyPI](https://img.shields.io/pypi/v/llamafactory)](https://pypi.org/project/llamafactory/)
[![Downloads](https://static.pepy.tech/badge/llmtuner)](https://pypi.org/project/llmtuner/)
[![Citation](https://img.shields.io/badge/citation-44-green)](#projects-using-llama-factory) [![Citation](https://img.shields.io/badge/citation-44-green)](#projects-using-llama-factory)
[![GitHub pull request](https://img.shields.io/badge/PRs-welcome-blue)](https://github.com/hiyouga/LLaMA-Factory/pulls) [![GitHub pull request](https://img.shields.io/badge/PRs-welcome-blue)](https://github.com/hiyouga/LLaMA-Factory/pulls)
[![Discord](https://dcbadge.vercel.app/api/server/rKfvV9r9FK?compact=true&style=flat)](https://discord.gg/rKfvV9r9FK) [![Discord](https://dcbadge.vercel.app/api/server/rKfvV9r9FK?compact=true&style=flat)](https://discord.gg/rKfvV9r9FK)
[![Twitter](https://img.shields.io/twitter/follow/llamafactory_ai)](https://twitter.com/llamafactory_ai) [![Twitter](https://img.shields.io/twitter/follow/llamafactory_ai)](https://twitter.com/llamafactory_ai)
[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1eRTPn37ltBbYsISy9Aw2NuI2Aq5CQrD9?usp=sharing)
[![Open in DSW](https://gallery.pai-ml.com/assets/open-in-dsw.svg)](https://gallery.pai-ml.com/#/preview/deepLearning/nlp/llama_factory)
[![Spaces](https://img.shields.io/badge/🤗-Open%20in%20Spaces-blue)](https://huggingface.co/spaces/hiyouga/LLaMA-Board) [![Spaces](https://img.shields.io/badge/🤗-Open%20in%20Spaces-blue)](https://huggingface.co/spaces/hiyouga/LLaMA-Board)
[![Studios](https://img.shields.io/badge/ModelScope-Open%20in%20Studios-blue)](https://modelscope.cn/studios/hiyouga/LLaMA-Board) [![Studios](https://img.shields.io/badge/ModelScope-Open%20in%20Studios-blue)](https://modelscope.cn/studios/hiyouga/LLaMA-Board)
[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1eRTPn37ltBbYsISy9Aw2NuI2Aq5CQrD9?usp=sharing)
[![GitHub Tread](https://trendshift.io/api/badge/repositories/4535)](https://trendshift.io/repositories/4535) [![GitHub Tread](https://trendshift.io/api/badge/repositories/4535)](https://trendshift.io/repositories/4535)
@@ -26,6 +26,7 @@ https://github.com/hiyouga/LLaMA-Factory/assets/16256802/9840a653-7e9c-41c8-ae89
Choose your path: Choose your path:
- **Colab**: https://colab.research.google.com/drive/1eRTPn37ltBbYsISy9Aw2NuI2Aq5CQrD9?usp=sharing - **Colab**: https://colab.research.google.com/drive/1eRTPn37ltBbYsISy9Aw2NuI2Aq5CQrD9?usp=sharing
- **PAI-DSW**: https://gallery.pai-ml.com/#/preview/deepLearning/nlp/llama_factory
- **Local machine**: Please refer to [usage](#getting-started) - **Local machine**: Please refer to [usage](#getting-started)
## Table of Contents ## Table of Contents
@@ -46,7 +47,7 @@ Choose your path:
## Features ## Features
- **Various models**: LLaMA, LLaVA, Mistral, Mixtral-MoE, Qwen, Yi, Gemma, Baichuan, ChatGLM, Phi, etc. - **Various models**: LLaMA, LLaVA, Mistral, Mixtral-MoE, Qwen, Yi, Gemma, Baichuan, ChatGLM, Phi, etc.
- **Integrated methods**: (Continuous) pre-training, (multimodal) supervised fine-tuning, reward modeling, PPO, DPO and ORPO. - **Integrated methods**: (Continuous) pre-training, (multimodal) supervised fine-tuning, reward modeling, PPO, DPO, KTO, ORPO, etc.
- **Scalable resources**: 32-bit full-tuning, 16-bit freeze-tuning, 16-bit LoRA and 2/4/8-bit QLoRA via AQLM/AWQ/GPTQ/LLM.int8. - **Scalable resources**: 32-bit full-tuning, 16-bit freeze-tuning, 16-bit LoRA and 2/4/8-bit QLoRA via AQLM/AWQ/GPTQ/LLM.int8.
- **Advanced algorithms**: GaLore, BAdam, DoRA, LongLoRA, LLaMA Pro, Mixture-of-Depths, LoRA+, LoftQ and Agent tuning. - **Advanced algorithms**: GaLore, BAdam, DoRA, LongLoRA, LLaMA Pro, Mixture-of-Depths, LoRA+, LoftQ and Agent tuning.
- **Practical tricks**: FlashAttention-2, Unsloth, RoPE scaling, NEFTune and rsLoRA. - **Practical tricks**: FlashAttention-2, Unsloth, RoPE scaling, NEFTune and rsLoRA.
@@ -70,14 +71,22 @@ Compared to ChatGLM's [P-Tuning](https://github.com/THUDM/ChatGLM2-6B/tree/main/
## Changelog ## Changelog
[24/05/14] We supported training and inference on the Ascend NPU devices. Check [installation](#installation) section for details. [24/06/07] We supported fine-tuning the **[Qwen-2](https://qwenlm.github.io/blog/qwen2/)** series models.
[24/05/13] We supported fine-tuning the **Yi-1.5** series models. [24/06/05] We supported fine-tuning the **[GLM-4-9B/GLM-4-9B-Chat](https://github.com/THUDM/GLM-4)** models.
[24/04/26] We supported fine-tuning the **LLaVA-1.5** multimodal LLMs. See [examples](examples/README.md) for usage. [24/05/26] We supported **[SimPO](https://arxiv.org/abs/2405.14734)** algorithm for preference learning. See [examples](examples/README.md) for usage.
<details><summary>Full Changelog</summary> <details><summary>Full Changelog</summary>
[24/05/20] We supported fine-tuning the **PaliGemma** series models. Note that the PaliGemma models are pre-trained models, you need to fine-tune them with `gemma` template for chat completion.
[24/05/18] We supported **[KTO](https://arxiv.org/abs/2402.01306)** algorithm for preference learning. See [examples](examples/README.md) for usage.
[24/05/14] We supported training and inference on the Ascend NPU devices. Check [installation](#installation) section for details.
[24/04/26] We supported fine-tuning the **LLaVA-1.5** multimodal LLMs. See [examples](examples/README.md) for usage.
[24/04/22] We provided a **[Colab notebook](https://colab.research.google.com/drive/1eRTPn37ltBbYsISy9Aw2NuI2Aq5CQrD9?usp=sharing)** for fine-tuning the Llama-3 model on a free T4 GPU. Two Llama-3-derived models fine-tuned using LLaMA Factory are available at Hugging Face, check [Llama3-8B-Chinese-Chat](https://huggingface.co/shenzhi-wang/Llama3-8B-Chinese-Chat) and [Llama3-Chinese](https://huggingface.co/zhichen/Llama3-Chinese) for details. [24/04/22] We provided a **[Colab notebook](https://colab.research.google.com/drive/1eRTPn37ltBbYsISy9Aw2NuI2Aq5CQrD9?usp=sharing)** for fine-tuning the Llama-3 model on a free T4 GPU. Two Llama-3-derived models fine-tuned using LLaMA Factory are available at Hugging Face, check [Llama3-8B-Chinese-Chat](https://huggingface.co/shenzhi-wang/Llama3-8B-Chinese-Chat) and [Llama3-Chinese](https://huggingface.co/zhichen/Llama3-Chinese) for details.
[24/04/21] We supported **[Mixture-of-Depths](https://arxiv.org/abs/2404.02258)** according to [AstraMindAI's implementation](https://github.com/astramind-ai/Mixture-of-depths). See [examples](examples/README.md) for usage. [24/04/21] We supported **[Mixture-of-Depths](https://arxiv.org/abs/2404.02258)** according to [AstraMindAI's implementation](https://github.com/astramind-ai/Mixture-of-depths). See [examples](examples/README.md) for usage.
@@ -104,7 +113,7 @@ Compared to ChatGLM's [P-Tuning](https://github.com/THUDM/ChatGLM2-6B/tree/main/
[24/02/05] Qwen1.5 (Qwen2 beta version) series models are supported in LLaMA-Factory. Check this [blog post](https://qwenlm.github.io/blog/qwen1.5/) for details. [24/02/05] Qwen1.5 (Qwen2 beta version) series models are supported in LLaMA-Factory. Check this [blog post](https://qwenlm.github.io/blog/qwen1.5/) for details.
[24/01/18] We supported **agent tuning** for most models, equipping model with tool using abilities by fine-tuning with `dataset: glaive_toolcall`. [24/01/18] We supported **agent tuning** for most models, equipping model with tool using abilities by fine-tuning with `dataset: glaive_toolcall_en`.
[23/12/23] We supported **[unsloth](https://github.com/unslothai/unsloth)**'s implementation to boost LoRA tuning for the LLaMA, Mistral and Yi models. Try `use_unsloth: true` argument to activate unsloth patch. It achieves **170%** speed in our benchmark, check [this page](https://github.com/hiyouga/LLaMA-Factory/wiki/Performance-comparison) for details. [23/12/23] We supported **[unsloth](https://github.com/unslothai/unsloth)**'s implementation to boost LoRA tuning for the LLaMA, Mistral and Yi models. Try `use_unsloth: true` argument to activate unsloth patch. It achieves **170%** speed in our benchmark, check [this page](https://github.com/hiyouga/LLaMA-Factory/wiki/Performance-comparison) for details.
@@ -142,43 +151,44 @@ Compared to ChatGLM's [P-Tuning](https://github.com/THUDM/ChatGLM2-6B/tree/main/
## Supported Models ## Supported Models
| Model | Model size | Default module | Template | | Model | Model size | Template |
| -------------------------------------------------------- | -------------------------------- | ----------------- | --------- | | -------------------------------------------------------- | -------------------------------- | --------- |
| [Baichuan2](https://huggingface.co/baichuan-inc) | 7B/13B | W_pack | baichuan2 | | [Baichuan2](https://huggingface.co/baichuan-inc) | 7B/13B | baichuan2 |
| [BLOOM](https://huggingface.co/bigscience) | 560M/1.1B/1.7B/3B/7.1B/176B | query_key_value | - | | [BLOOM](https://huggingface.co/bigscience) | 560M/1.1B/1.7B/3B/7.1B/176B | - |
| [BLOOMZ](https://huggingface.co/bigscience) | 560M/1.1B/1.7B/3B/7.1B/176B | query_key_value | - | | [BLOOMZ](https://huggingface.co/bigscience) | 560M/1.1B/1.7B/3B/7.1B/176B | - |
| [ChatGLM3](https://huggingface.co/THUDM) | 6B | query_key_value | chatglm3 | | [ChatGLM3](https://huggingface.co/THUDM) | 6B | chatglm3 |
| [Command-R](https://huggingface.co/CohereForAI) | 35B/104B | q_proj,v_proj | cohere | | [Command-R](https://huggingface.co/CohereForAI) | 35B/104B | cohere |
| [DeepSeek (MoE)](https://huggingface.co/deepseek-ai) | 7B/16B/67B/236B | q_proj,v_proj | deepseek | | [DeepSeek (MoE)](https://huggingface.co/deepseek-ai) | 7B/16B/67B/236B | deepseek |
| [Falcon](https://huggingface.co/tiiuae) | 7B/40B/180B | query_key_value | falcon | | [Falcon](https://huggingface.co/tiiuae) | 7B/11B/40B/180B | falcon |
| [Gemma/CodeGemma](https://huggingface.co/google) | 2B/7B | q_proj,v_proj | gemma | | [Gemma/CodeGemma](https://huggingface.co/google) | 2B/7B | gemma |
| [InternLM2](https://huggingface.co/internlm) | 7B/20B | wqkv | intern2 | | [GLM4](https://huggingface.co/THUDM) | 9B | glm4 |
| [LLaMA](https://github.com/facebookresearch/llama) | 7B/13B/33B/65B | q_proj,v_proj | - | | [InternLM2](https://huggingface.co/internlm) | 7B/20B | intern2 |
| [LLaMA-2](https://huggingface.co/meta-llama) | 7B/13B/70B | q_proj,v_proj | llama2 | | [LLaMA](https://github.com/facebookresearch/llama) | 7B/13B/33B/65B | - |
| [LLaMA-3](https://huggingface.co/meta-llama) | 8B/70B | q_proj,v_proj | llama3 | | [LLaMA-2](https://huggingface.co/meta-llama) | 7B/13B/70B | llama2 |
| [LLaVA-1.5](https://huggingface.co/llava-hf) | 7B/13B | q_proj,v_proj | vicuna | | [LLaMA-3](https://huggingface.co/meta-llama) | 8B/70B | llama3 |
| [Mistral/Mixtral](https://huggingface.co/mistralai) | 7B/8x7B/8x22B | q_proj,v_proj | mistral | | [LLaVA-1.5](https://huggingface.co/llava-hf) | 7B/13B | vicuna |
| [OLMo](https://huggingface.co/allenai) | 1B/7B | q_proj,v_proj | - | | [Mistral/Mixtral](https://huggingface.co/mistralai) | 7B/8x7B/8x22B | mistral |
| [Phi-1.5/2](https://huggingface.co/microsoft) | 1.3B/2.7B | q_proj,v_proj | - | | [OLMo](https://huggingface.co/allenai) | 1B/7B | - |
| [Phi-3](https://huggingface.co/microsoft) | 3.8B | qkv_proj | phi | | [PaliGemma](https://huggingface.co/google) | 3B | gemma |
| [Qwen](https://huggingface.co/Qwen) | 1.8B/7B/14B/72B | c_attn | qwen | | [Phi-1.5/2](https://huggingface.co/microsoft) | 1.3B/2.7B | - |
| [Qwen1.5 (Code/MoE)](https://huggingface.co/Qwen) | 0.5B/1.8B/4B/7B/14B/32B/72B/110B | q_proj,v_proj | qwen | | [Phi-3](https://huggingface.co/microsoft) | 4B/7B/14B | phi |
| [StarCoder2](https://huggingface.co/bigcode) | 3B/7B/15B | q_proj,v_proj | - | | [Qwen](https://huggingface.co/Qwen) | 1.8B/7B/14B/72B | qwen |
| [XVERSE](https://huggingface.co/xverse) | 7B/13B/65B | q_proj,v_proj | xverse | | [Qwen1.5 (Code/MoE)](https://huggingface.co/Qwen) | 0.5B/1.8B/4B/7B/14B/32B/72B/110B | qwen |
| [Yi (1/1.5)](https://huggingface.co/01-ai) | 6B/9B/34B | q_proj,v_proj | yi | | [Qwen2 (MoE)](https://huggingface.co/Qwen) | 0.5B/1.5B/7B/57B/72B | qwen |
| [Yi-VL](https://huggingface.co/01-ai) | 6B/34B | q_proj,v_proj | yi_vl | | [StarCoder2](https://huggingface.co/bigcode) | 3B/7B/15B | - |
| [Yuan](https://huggingface.co/IEITYuan) | 2B/51B/102B | q_proj,v_proj | yuan | | [XVERSE](https://huggingface.co/xverse) | 7B/13B/65B | xverse |
| [Yi (1/1.5)](https://huggingface.co/01-ai) | 6B/9B/34B | yi |
| [Yi-VL](https://huggingface.co/01-ai) | 6B/34B | yi_vl |
| [Yuan](https://huggingface.co/IEITYuan) | 2B/51B/102B | yuan |
> [!NOTE] > [!NOTE]
> **Default module** is used for the `--lora_target` argument, you can use `--lora_target all` to specify all the available modules for better convergence. > For the "base" models, the `template` argument can be chosen from `default`, `alpaca`, `vicuna` etc. But make sure to use the **corresponding template** for the "instruct/chat" models.
>
> For the "base" models, the `--template` argument can be chosen from `default`, `alpaca`, `vicuna` etc. But make sure to use the **corresponding template** for the "instruct/chat" models.
> >
> Remember to use the **SAME** template in training and inference. > Remember to use the **SAME** template in training and inference.
Please refer to [constants.py](src/llmtuner/extras/constants.py) for a full list of models we supported. Please refer to [constants.py](src/llamafactory/extras/constants.py) for a full list of models we supported.
You also can add a custom chat template to [template.py](src/llmtuner/data/template.py). You also can add a custom chat template to [template.py](src/llamafactory/data/template.py).
## Supported Training Approaches ## Supported Training Approaches
@@ -189,7 +199,9 @@ You also can add a custom chat template to [template.py](src/llmtuner/data/templ
| Reward Modeling | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: | | Reward Modeling | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
| PPO Training | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: | | PPO Training | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
| DPO Training | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: | | DPO Training | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
| KTO Training | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
| ORPO Training | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: | | ORPO Training | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
| SimPO Training | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
## Provided Datasets ## Provided Datasets
@@ -202,6 +214,8 @@ You also can add a custom chat template to [template.py](src/llmtuner/data/templ
- [Wikipedia (zh)](https://huggingface.co/datasets/pleisto/wikipedia-cn-20230720-filtered) - [Wikipedia (zh)](https://huggingface.co/datasets/pleisto/wikipedia-cn-20230720-filtered)
- [Pile (en)](https://huggingface.co/datasets/EleutherAI/pile) - [Pile (en)](https://huggingface.co/datasets/EleutherAI/pile)
- [SkyPile (zh)](https://huggingface.co/datasets/Skywork/SkyPile-150B) - [SkyPile (zh)](https://huggingface.co/datasets/Skywork/SkyPile-150B)
- [FineWeb (en)](https://huggingface.co/datasets/HuggingFaceFW/fineweb)
- [FineWeb-Edu (en)](https://huggingface.co/datasets/HuggingFaceFW/fineweb-edu)
- [The Stack (en)](https://huggingface.co/datasets/bigcode/the-stack) - [The Stack (en)](https://huggingface.co/datasets/bigcode/the-stack)
- [StarCoder (en)](https://huggingface.co/datasets/bigcode/starcoderdata) - [StarCoder (en)](https://huggingface.co/datasets/bigcode/starcoderdata)
@@ -209,12 +223,12 @@ You also can add a custom chat template to [template.py](src/llmtuner/data/templ
<details><summary>Supervised fine-tuning datasets</summary> <details><summary>Supervised fine-tuning datasets</summary>
- [Stanford Alpaca (en)](https://github.com/tatsu-lab/stanford_alpaca)
- [Stanford Alpaca (zh)](https://github.com/ymcui/Chinese-LLaMA-Alpaca)
- [Alpaca GPT4 (en&zh)](https://github.com/Instruction-Tuning-with-GPT-4/GPT-4-LLM)
- [Identity (en&zh)](data/identity.json) - [Identity (en&zh)](data/identity.json)
- [Open Assistant (zh)](https://huggingface.co/datasets/OpenAssistant/oasst1) - [Stanford Alpaca (en)](https://github.com/tatsu-lab/stanford_alpaca)
- [ShareGPT (zh)](https://huggingface.co/datasets/QingyiSi/Alpaca-CoT/tree/main/Chinese-instruction-collection) - [Stanford Alpaca (zh)](https://github.com/ymcui/Chinese-LLaMA-Alpaca-3)
- [Alpaca GPT4 (en&zh)](https://github.com/Instruction-Tuning-with-GPT-4/GPT-4-LLM)
- [Glaive Function Calling V2 (en&zh)](https://huggingface.co/datasets/glaiveai/glaive-function-calling-v2)
- [LIMA (en)](https://huggingface.co/datasets/GAIR/lima)
- [Guanaco Dataset (multilingual)](https://huggingface.co/datasets/JosephusCheung/GuanacoDataset) - [Guanaco Dataset (multilingual)](https://huggingface.co/datasets/JosephusCheung/GuanacoDataset)
- [BELLE 2M (zh)](https://huggingface.co/datasets/BelleGroup/train_2M_CN) - [BELLE 2M (zh)](https://huggingface.co/datasets/BelleGroup/train_2M_CN)
- [BELLE 1M (zh)](https://huggingface.co/datasets/BelleGroup/train_1M_CN) - [BELLE 1M (zh)](https://huggingface.co/datasets/BelleGroup/train_1M_CN)
@@ -223,7 +237,6 @@ You also can add a custom chat template to [template.py](src/llmtuner/data/templ
- [BELLE School Math 0.25M (zh)](https://huggingface.co/datasets/BelleGroup/school_math_0.25M) - [BELLE School Math 0.25M (zh)](https://huggingface.co/datasets/BelleGroup/school_math_0.25M)
- [BELLE Multiturn Chat 0.8M (zh)](https://huggingface.co/datasets/BelleGroup/multiturn_chat_0.8M) - [BELLE Multiturn Chat 0.8M (zh)](https://huggingface.co/datasets/BelleGroup/multiturn_chat_0.8M)
- [UltraChat (en)](https://github.com/thunlp/UltraChat) - [UltraChat (en)](https://github.com/thunlp/UltraChat)
- [LIMA (en)](https://huggingface.co/datasets/GAIR/lima)
- [OpenPlatypus (en)](https://huggingface.co/datasets/garage-bAInd/Open-Platypus) - [OpenPlatypus (en)](https://huggingface.co/datasets/garage-bAInd/Open-Platypus)
- [CodeAlpaca 20k (en)](https://huggingface.co/datasets/sahil2801/CodeAlpaca-20k) - [CodeAlpaca 20k (en)](https://huggingface.co/datasets/sahil2801/CodeAlpaca-20k)
- [Alpaca CoT (multilingual)](https://huggingface.co/datasets/QingyiSi/Alpaca-CoT) - [Alpaca CoT (multilingual)](https://huggingface.co/datasets/QingyiSi/Alpaca-CoT)
@@ -236,15 +249,16 @@ You also can add a custom chat template to [template.py](src/llmtuner/data/templ
- [WebNovel (zh)](https://huggingface.co/datasets/zxbsmk/webnovel_cn) - [WebNovel (zh)](https://huggingface.co/datasets/zxbsmk/webnovel_cn)
- [Nectar (en)](https://huggingface.co/datasets/berkeley-nest/Nectar) - [Nectar (en)](https://huggingface.co/datasets/berkeley-nest/Nectar)
- [deepctrl (en&zh)](https://www.modelscope.cn/datasets/deepctrl/deepctrl-sft-data) - [deepctrl (en&zh)](https://www.modelscope.cn/datasets/deepctrl/deepctrl-sft-data)
- [Ad Gen (zh)](https://huggingface.co/datasets/HasturOfficial/adgen) - [Advertisement Generation (zh)](https://huggingface.co/datasets/HasturOfficial/adgen)
- [ShareGPT Hyperfiltered (en)](https://huggingface.co/datasets/totally-not-an-llm/sharegpt-hyperfiltered-3k) - [ShareGPT Hyperfiltered (en)](https://huggingface.co/datasets/totally-not-an-llm/sharegpt-hyperfiltered-3k)
- [ShareGPT4 (en&zh)](https://huggingface.co/datasets/shibing624/sharegpt_gpt4) - [ShareGPT4 (en&zh)](https://huggingface.co/datasets/shibing624/sharegpt_gpt4)
- [UltraChat 200k (en)](https://huggingface.co/datasets/HuggingFaceH4/ultrachat_200k) - [UltraChat 200k (en)](https://huggingface.co/datasets/HuggingFaceH4/ultrachat_200k)
- [AgentInstruct (en)](https://huggingface.co/datasets/THUDM/AgentInstruct) - [AgentInstruct (en)](https://huggingface.co/datasets/THUDM/AgentInstruct)
- [LMSYS Chat 1M (en)](https://huggingface.co/datasets/lmsys/lmsys-chat-1m) - [LMSYS Chat 1M (en)](https://huggingface.co/datasets/lmsys/lmsys-chat-1m)
- [Evol Instruct V2 (en)](https://huggingface.co/datasets/WizardLM/WizardLM_evol_instruct_V2_196k) - [Evol Instruct V2 (en)](https://huggingface.co/datasets/WizardLM/WizardLM_evol_instruct_V2_196k)
- [Glaive Function Calling V2 (en)](https://huggingface.co/datasets/glaiveai/glaive-function-calling-v2)
- [Cosmopedia (en)](https://huggingface.co/datasets/HuggingFaceTB/cosmopedia) - [Cosmopedia (en)](https://huggingface.co/datasets/HuggingFaceTB/cosmopedia)
- [STEM (zh)](https://huggingface.co/datasets/hfl/stem_zh_instruction)
- [Ruozhiba (zh)](https://huggingface.co/datasets/hfl/ruozhiba_gpt4_turbo)
- [LLaVA mixed (en&zh)](https://huggingface.co/datasets/BUAADreamer/llava-en-zh-300k) - [LLaVA mixed (en&zh)](https://huggingface.co/datasets/BUAADreamer/llava-en-zh-300k)
- [Open Assistant (de)](https://huggingface.co/datasets/mayflowergmbh/oasst_de) - [Open Assistant (de)](https://huggingface.co/datasets/mayflowergmbh/oasst_de)
- [Dolly 15k (de)](https://huggingface.co/datasets/mayflowergmbh/dolly-15k_de) - [Dolly 15k (de)](https://huggingface.co/datasets/mayflowergmbh/dolly-15k_de)
@@ -260,13 +274,13 @@ You also can add a custom chat template to [template.py](src/llmtuner/data/templ
<details><summary>Preference datasets</summary> <details><summary>Preference datasets</summary>
- [HH-RLHF (en)](https://huggingface.co/datasets/Anthropic/hh-rlhf)
- [GPT-4 Generated Data (en&zh)](https://github.com/Instruction-Tuning-with-GPT-4/GPT-4-LLM)
- [Orca DPO (en)](https://huggingface.co/datasets/Intel/orca_dpo_pairs)
- [Nectar (en)](https://huggingface.co/datasets/berkeley-nest/Nectar)
- [DPO mixed (en&zh)](https://huggingface.co/datasets/hiyouga/DPO-En-Zh-20k) - [DPO mixed (en&zh)](https://huggingface.co/datasets/hiyouga/DPO-En-Zh-20k)
- [Open Assistant (zh)](https://huggingface.co/datasets/OpenAssistant/oasst1) - [UltraFeedback (en)](https://huggingface.co/datasets/HuggingFaceH4/ultrafeedback_binarized)
- [Orca DPO Pairs (en)](https://huggingface.co/datasets/Intel/orca_dpo_pairs)
- [HH-RLHF (en)](https://huggingface.co/datasets/Anthropic/hh-rlhf)
- [Nectar (en)](https://huggingface.co/datasets/berkeley-nest/Nectar)
- [Orca DPO (de)](https://huggingface.co/datasets/mayflowergmbh/intel_orca_dpo_pairs_de) - [Orca DPO (de)](https://huggingface.co/datasets/mayflowergmbh/intel_orca_dpo_pairs_de)
- [KTO mixed (en)](https://huggingface.co/datasets/argilla/kto-mix-15k)
</details> </details>
@@ -281,21 +295,21 @@ huggingface-cli login
| Mandatory | Minimum | Recommend | | Mandatory | Minimum | Recommend |
| ------------ | ------- | --------- | | ------------ | ------- | --------- |
| python | 3.8 | 3.10 | | python | 3.8 | 3.11 |
| torch | 1.13.1 | 2.2.0 | | torch | 1.13.1 | 2.3.0 |
| transformers | 4.37.2 | 4.40.1 | | transformers | 4.41.2 | 4.41.2 |
| datasets | 2.14.3 | 2.19.1 | | datasets | 2.16.0 | 2.19.2 |
| accelerate | 0.27.2 | 0.30.0 | | accelerate | 0.30.1 | 0.30.1 |
| peft | 0.9.0 | 0.10.0 | | peft | 0.11.1 | 0.11.1 |
| trl | 0.8.1 | 0.8.6 | | trl | 0.8.6 | 0.9.4 |
| Optional | Minimum | Recommend | | Optional | Minimum | Recommend |
| ------------ | ------- | --------- | | ------------ | ------- | --------- |
| CUDA | 11.6 | 12.2 | | CUDA | 11.6 | 12.2 |
| deepspeed | 0.10.0 | 0.14.0 | | deepspeed | 0.10.0 | 0.14.0 |
| bitsandbytes | 0.39.0 | 0.43.1 | | bitsandbytes | 0.39.0 | 0.43.1 |
| vllm | 0.4.0 | 0.4.2 | | vllm | 0.4.3 | 0.4.3 |
| flash-attn | 2.3.0 | 2.5.8 | | flash-attn | 2.3.0 | 2.5.9 |
### Hardware Requirement ### Hardware Requirement
@@ -319,12 +333,12 @@ huggingface-cli login
> Installation is mandatory. > Installation is mandatory.
```bash ```bash
git clone https://github.com/hiyouga/LLaMA-Factory.git git clone --depth 1 https://github.com/hiyouga/LLaMA-Factory.git
cd LLaMA-Factory cd LLaMA-Factory
pip install -e .[torch,metrics] pip install -e '.[torch,metrics]'
``` ```
Extra dependencies available: torch, metrics, deepspeed, bitsandbytes, vllm, galore, badam, gptq, awq, aqlm, qwen, modelscope, quality Extra dependencies available: torch, torch_npu, metrics, deepspeed, bitsandbytes, vllm, galore, badam, gptq, awq, aqlm, qwen, modelscope, quality
> [!TIP] > [!TIP]
> Use `pip install --no-deps -e .` to resolve package conflicts. > Use `pip install --no-deps -e .` to resolve package conflicts.
@@ -343,19 +357,35 @@ To enable FlashAttention-2 on the Windows platform, you need to install the prec
<details><summary>For Ascend NPU users</summary> <details><summary>For Ascend NPU users</summary>
To utilize Ascend NPU devices for (distributed) training and inference, you need to install the **[torch-npu](https://gitee.com/ascend/pytorch)** library and the **[Ascend CANN Kernels](https://www.hiascend.com/developer/download/community/result?module=cann)**. Join [NPU user group](assets/wechat_npu.jpg).
To install LLaMA Factory on Ascend NPU devices, please specify extra dependencies: `pip install -e '.[torch_npu,metrics]'`. Additionally, you need to install the **[Ascend CANN Toolkit and Kernels](https://www.hiascend.com/developer/download/community/result?module=cann)**. Please follow the [installation tutorial](https://www.hiascend.com/document/detail/en/CANNCommunityEdition/600alphaX/softwareinstall/instg/atlasdeploy_03_0031.html) or use the following commands:
```bash
# replace the url according to your CANN version and devices
# install CANN Toolkit
wget https://ascend-repo.obs.cn-east-2.myhuaweicloud.com/Milan-ASL/Milan-ASL%20V100R001C17SPC701/Ascend-cann-toolkit_8.0.RC1.alpha001_linux-"$(uname -i)".run
bash Ascend-cann-toolkit_8.0.RC1.alpha001_linux-"$(uname -i)".run --install
# install CANN Kernels
wget https://ascend-repo.obs.cn-east-2.myhuaweicloud.com/Milan-ASL/Milan-ASL%20V100R001C17SPC701/Ascend-cann-kernels-910b_8.0.RC1.alpha001_linux.run
bash Ascend-cann-kernels-910b_8.0.RC1.alpha001_linux.run --install
# set env variables
source /usr/local/Ascend/ascend-toolkit/set_env.sh
```
| Requirement | Minimum | Recommend | | Requirement | Minimum | Recommend |
| ------------ | ------- | --------- | | ------------ | ------- | ----------- |
| CANN | 8.0.RC1 | 8.0.RC1 | | CANN | 8.0.RC1 | 8.0.RC1 |
| torch | 2.2.0 | 2.2.0 | | torch | 2.1.0 | 2.1.0 |
| torch-npu | 2.2.0 | 2.2.0 | | torch-npu | 2.1.0 | 2.1.0.post3 |
| deepspeed | 0.13.2 | 0.13.2 | | deepspeed | 0.13.2 | 0.13.2 |
Docker image: Docker image:
- 32GB: [Download page](http://mirrors.cn-central-221.ovaijisuan.com/detail/130.html) - 32GB: [Download page](http://mirrors.cn-central-221.ovaijisuan.com/detail/130.html)
- 64GB: Coming soon - 64GB: [Download page](http://mirrors.cn-central-221.ovaijisuan.com/detail/131.html)
Remember to use `ASCEND_RT_VISIBLE_DEVICES` instead of `CUDA_VISIBLE_DEVICES` to specify the device to use. Remember to use `ASCEND_RT_VISIBLE_DEVICES` instead of `CUDA_VISIBLE_DEVICES` to specify the device to use.
@@ -387,29 +417,12 @@ See [examples/README.md](examples/README.md) for advanced usage (including distr
### Fine-Tuning with LLaMA Board GUI (powered by [Gradio](https://github.com/gradio-app/gradio)) ### Fine-Tuning with LLaMA Board GUI (powered by [Gradio](https://github.com/gradio-app/gradio))
> [!IMPORTANT]
> LLaMA Board GUI only supports training on a single GPU.
#### Use local environment #### Use local environment
```bash ```bash
CUDA_VISIBLE_DEVICES=0 GRADIO_SHARE=1 llamafactory-cli webui CUDA_VISIBLE_DEVICES=0 GRADIO_SHARE=1 llamafactory-cli webui
``` ```
<details><summary>For Alibaba Cloud PAI or AutoDL users</summary>
If you encountered display problems in LLaMA Board on Alibaba Cloud PAI, try using the following command to set environment variables before starting LLaMA Board:
```bash
export GRADIO_SERVER_PORT=7860 GRADIO_ROOT_PATH=/${JUPYTER_NAME}/proxy/7860/
```
If you are using AutoDL, please install a specific version of Gradio:
```bash
pip install gradio==4.10.0
```
</details> </details>
#### Use Docker #### Use Docker
@@ -420,7 +433,6 @@ docker run --gpus=all \
-v ./hf_cache:/root/.cache/huggingface/ \ -v ./hf_cache:/root/.cache/huggingface/ \
-v ./data:/app/data \ -v ./data:/app/data \
-v ./output:/app/output \ -v ./output:/app/output \
-e CUDA_VISIBLE_DEVICES=0 \
-p 7860:7860 \ -p 7860:7860 \
--shm-size 16G \ --shm-size 16G \
--name llama_factory \ --name llama_factory \
@@ -447,6 +459,9 @@ docker compose -f ./docker-compose.yml up -d
CUDA_VISIBLE_DEVICES=0,1 API_PORT=8000 llamafactory-cli api examples/inference/llama3_vllm.yaml CUDA_VISIBLE_DEVICES=0,1 API_PORT=8000 llamafactory-cli api examples/inference/llama3_vllm.yaml
``` ```
> [!TIP]
> Visit https://platform.openai.com/docs/api-reference/chat/create for API document.
### Download from ModelScope Hub ### Download from ModelScope Hub
If you have trouble with downloading models and datasets from Hugging Face, you can use ModelScope. If you have trouble with downloading models and datasets from Hugging Face, you can use ModelScope.
@@ -455,7 +470,18 @@ If you have trouble with downloading models and datasets from Hugging Face, you
export USE_MODELSCOPE_HUB=1 # `set USE_MODELSCOPE_HUB=1` for Windows export USE_MODELSCOPE_HUB=1 # `set USE_MODELSCOPE_HUB=1` for Windows
``` ```
Train the model by specifying a model ID of the ModelScope Hub as the `--model_name_or_path`. You can find a full list of model IDs at [ModelScope Hub](https://modelscope.cn/models), e.g., `LLM-Research/Meta-Llama-3-8B-Instruct`. Train the model by specifying a model ID of the ModelScope Hub as the `model_name_or_path`. You can find a full list of model IDs at [ModelScope Hub](https://modelscope.cn/models), e.g., `LLM-Research/Meta-Llama-3-8B-Instruct`.
### Use W&B Logger
To use [Weights & Biases](https://wandb.ai) for logging experimental results, you need to add the following arguments.
```yaml
report_to: wandb
run_name: test_run # optional
```
Set `WANDB_API_KEY` to [your key](https://wandb.ai/authorize) when launching training tasks to log in with your W&B account.
## Projects using LLaMA Factory ## Projects using LLaMA Factory
@@ -502,7 +528,7 @@ If you have a project that should be incorporated, please contact via email or c
1. Zhou et al. FREB-TQA: A Fine-Grained Robustness Evaluation Benchmark for Table Question Answering. 2024. [[arxiv]](https://arxiv.org/abs/2404.18585) 1. Zhou et al. FREB-TQA: A Fine-Grained Robustness Evaluation Benchmark for Table Question Answering. 2024. [[arxiv]](https://arxiv.org/abs/2404.18585)
1. **[StarWhisper](https://github.com/Yu-Yang-Li/StarWhisper)**: A large language model for Astronomy, based on ChatGLM2-6B and Qwen-14B. 1. **[StarWhisper](https://github.com/Yu-Yang-Li/StarWhisper)**: A large language model for Astronomy, based on ChatGLM2-6B and Qwen-14B.
1. **[DISC-LawLLM](https://github.com/FudanDISC/DISC-LawLLM)**: A large language model specialized in Chinese legal domain, based on Baichuan-13B, is capable of retrieving and reasoning on legal knowledge. 1. **[DISC-LawLLM](https://github.com/FudanDISC/DISC-LawLLM)**: A large language model specialized in Chinese legal domain, based on Baichuan-13B, is capable of retrieving and reasoning on legal knowledge.
1. **[Sunsimiao](https://github.com/thomas-yanxin/Sunsimiao)**: A large language model specialized in Chinese medical domain, based on Baichuan-7B and ChatGLM-6B. 1. **[Sunsimiao](https://github.com/X-D-Lab/Sunsimiao)**: A large language model specialized in Chinese medical domain, based on Baichuan-7B and ChatGLM-6B.
1. **[CareGPT](https://github.com/WangRongsheng/CareGPT)**: A series of large language models for Chinese medical domain, based on LLaMA2-7B and Baichuan-13B. 1. **[CareGPT](https://github.com/WangRongsheng/CareGPT)**: A series of large language models for Chinese medical domain, based on LLaMA2-7B and Baichuan-13B.
1. **[MachineMindset](https://github.com/PKU-YuanGroup/Machine-Mindset/)**: A series of MBTI Personality large language models, capable of giving any LLM 16 different personality types based on different datasets and training methods. 1. **[MachineMindset](https://github.com/PKU-YuanGroup/Machine-Mindset/)**: A series of MBTI Personality large language models, capable of giving any LLM 16 different personality types based on different datasets and training methods.
1. **[Luminia-13B-v3](https://huggingface.co/Nekochu/Luminia-13B-v3)**: A large language model specialized in generating metadata for stable diffusion. [[🤗Demo]](https://huggingface.co/spaces/Nekochu/Luminia-13B_SD_Prompt) 1. **[Luminia-13B-v3](https://huggingface.co/Nekochu/Luminia-13B-v3)**: A large language model specialized in generating metadata for stable diffusion. [[🤗Demo]](https://huggingface.co/spaces/Nekochu/Luminia-13B_SD_Prompt)
@@ -514,7 +540,7 @@ If you have a project that should be incorporated, please contact via email or c
This repository is licensed under the [Apache-2.0 License](LICENSE). This repository is licensed under the [Apache-2.0 License](LICENSE).
Please follow the model licenses to use the corresponding model weights: [Baichuan2](https://huggingface.co/baichuan-inc/Baichuan2-7B-Base/blob/main/Community%20License%20for%20Baichuan%202%20Model.pdf) / [BLOOM](https://huggingface.co/spaces/bigscience/license) / [ChatGLM3](https://github.com/THUDM/ChatGLM3/blob/main/MODEL_LICENSE) / [Command-R](https://cohere.com/c4ai-cc-by-nc-license) / [DeepSeek](https://github.com/deepseek-ai/DeepSeek-LLM/blob/main/LICENSE-MODEL) / [Falcon](https://huggingface.co/tiiuae/falcon-180B/blob/main/LICENSE.txt) / [Gemma](https://ai.google.dev/gemma/terms) / [InternLM2](https://github.com/InternLM/InternLM#license) / [LLaMA](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) / [LLaMA-2 (LLaVA-1.5)](https://ai.meta.com/llama/license/) / [LLaMA-3](https://llama.meta.com/llama3/license/) / [Mistral](LICENSE) / [OLMo](LICENSE) / [Phi-1.5/2](https://huggingface.co/microsoft/phi-1_5/resolve/main/Research%20License.docx) / [Phi-3](https://huggingface.co/microsoft/Phi-3-mini-4k-instruct/blob/main/LICENSE) / [Qwen](https://github.com/QwenLM/Qwen/blob/main/Tongyi%20Qianwen%20LICENSE%20AGREEMENT) / [StarCoder2](https://huggingface.co/spaces/bigcode/bigcode-model-license-agreement) / [XVERSE](https://github.com/xverse-ai/XVERSE-13B/blob/main/MODEL_LICENSE.pdf) / [Yi](https://huggingface.co/01-ai/Yi-6B/blob/main/LICENSE) / [Yi-1.5](LICENSE) / [Yuan](https://github.com/IEIT-Yuan/Yuan-2.0/blob/main/LICENSE-Yuan) Please follow the model licenses to use the corresponding model weights: [Baichuan2](https://huggingface.co/baichuan-inc/Baichuan2-7B-Base/blob/main/Community%20License%20for%20Baichuan%202%20Model.pdf) / [BLOOM](https://huggingface.co/spaces/bigscience/license) / [ChatGLM3](https://github.com/THUDM/ChatGLM3/blob/main/MODEL_LICENSE) / [Command-R](https://cohere.com/c4ai-cc-by-nc-license) / [DeepSeek](https://github.com/deepseek-ai/DeepSeek-LLM/blob/main/LICENSE-MODEL) / 
[Falcon](https://huggingface.co/tiiuae/falcon-180B/blob/main/LICENSE.txt) / [Gemma](https://ai.google.dev/gemma/terms) / [GLM4](https://huggingface.co/THUDM/glm-4-9b/blob/main/LICENSE) / [InternLM2](https://github.com/InternLM/InternLM#license) / [LLaMA](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) / [LLaMA-2 (LLaVA-1.5)](https://ai.meta.com/llama/license/) / [LLaMA-3](https://llama.meta.com/llama3/license/) / [Mistral](LICENSE) / [OLMo](LICENSE) / [Phi-1.5/2](https://huggingface.co/microsoft/phi-1_5/resolve/main/Research%20License.docx) / [Phi-3](https://huggingface.co/microsoft/Phi-3-mini-4k-instruct/blob/main/LICENSE) / [Qwen](https://github.com/QwenLM/Qwen/blob/main/Tongyi%20Qianwen%20LICENSE%20AGREEMENT) / [StarCoder2](https://huggingface.co/spaces/bigcode/bigcode-model-license-agreement) / [XVERSE](https://github.com/xverse-ai/XVERSE-13B/blob/main/MODEL_LICENSE.pdf) / [Yi](https://huggingface.co/01-ai/Yi-6B/blob/main/LICENSE) / [Yi-1.5](LICENSE) / [Yuan](https://github.com/IEIT-Yuan/Yuan-2.0/blob/main/LICENSE-Yuan)
## Citation ## Citation

View File

@@ -3,15 +3,15 @@
[![GitHub Repo stars](https://img.shields.io/github/stars/hiyouga/LLaMA-Factory?style=social)](https://github.com/hiyouga/LLaMA-Factory/stargazers) [![GitHub Repo stars](https://img.shields.io/github/stars/hiyouga/LLaMA-Factory?style=social)](https://github.com/hiyouga/LLaMA-Factory/stargazers)
[![GitHub Code License](https://img.shields.io/github/license/hiyouga/LLaMA-Factory)](LICENSE) [![GitHub Code License](https://img.shields.io/github/license/hiyouga/LLaMA-Factory)](LICENSE)
[![GitHub last commit](https://img.shields.io/github/last-commit/hiyouga/LLaMA-Factory)](https://github.com/hiyouga/LLaMA-Factory/commits/main) [![GitHub last commit](https://img.shields.io/github/last-commit/hiyouga/LLaMA-Factory)](https://github.com/hiyouga/LLaMA-Factory/commits/main)
[![PyPI](https://img.shields.io/pypi/v/llmtuner)](https://pypi.org/project/llmtuner/) [![PyPI](https://img.shields.io/pypi/v/llamafactory)](https://pypi.org/project/llamafactory/)
[![Downloads](https://static.pepy.tech/badge/llmtuner)](https://pypi.org/project/llmtuner/)
[![Citation](https://img.shields.io/badge/citation-44-green)](#使用了-llama-factory-的项目) [![Citation](https://img.shields.io/badge/citation-44-green)](#使用了-llama-factory-的项目)
[![GitHub pull request](https://img.shields.io/badge/PRs-welcome-blue)](https://github.com/hiyouga/LLaMA-Factory/pulls) [![GitHub pull request](https://img.shields.io/badge/PRs-welcome-blue)](https://github.com/hiyouga/LLaMA-Factory/pulls)
[![Discord](https://dcbadge.vercel.app/api/server/rKfvV9r9FK?compact=true&style=flat)](https://discord.gg/rKfvV9r9FK) [![Discord](https://dcbadge.vercel.app/api/server/rKfvV9r9FK?compact=true&style=flat)](https://discord.gg/rKfvV9r9FK)
[![Twitter](https://img.shields.io/twitter/follow/llamafactory_ai)](https://twitter.com/llamafactory_ai) [![Twitter](https://img.shields.io/twitter/follow/llamafactory_ai)](https://twitter.com/llamafactory_ai)
[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1d5KQtbemerlSDSxZIfAaWXhKr30QypiK?usp=sharing)
[![Open in DSW](https://gallery.pai-ml.com/assets/open-in-dsw.svg)](https://gallery.pai-ml.com/#/preview/deepLearning/nlp/llama_factory)
[![Spaces](https://img.shields.io/badge/🤗-Open%20in%20Spaces-blue)](https://huggingface.co/spaces/hiyouga/LLaMA-Board) [![Spaces](https://img.shields.io/badge/🤗-Open%20in%20Spaces-blue)](https://huggingface.co/spaces/hiyouga/LLaMA-Board)
[![Studios](https://img.shields.io/badge/ModelScope-Open%20in%20Studios-blue)](https://modelscope.cn/studios/hiyouga/LLaMA-Board) [![Studios](https://img.shields.io/badge/ModelScope-Open%20in%20Studios-blue)](https://modelscope.cn/studios/hiyouga/LLaMA-Board)
[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1d5KQtbemerlSDSxZIfAaWXhKr30QypiK?usp=sharing)
[![GitHub Trend](https://trendshift.io/api/badge/repositories/4535)](https://trendshift.io/repositories/4535) [![GitHub Trend](https://trendshift.io/api/badge/repositories/4535)](https://trendshift.io/repositories/4535)
@@ -26,6 +26,7 @@ https://github.com/hiyouga/LLaMA-Factory/assets/16256802/ec36a9dd-37f4-4f72-81bd
选择你的打开方式: 选择你的打开方式:
- **Colab**https://colab.research.google.com/drive/1d5KQtbemerlSDSxZIfAaWXhKr30QypiK?usp=sharing - **Colab**https://colab.research.google.com/drive/1d5KQtbemerlSDSxZIfAaWXhKr30QypiK?usp=sharing
- **PAI-DSW**: https://gallery.pai-ml.com/#/preview/deepLearning/nlp/llama_factory
- **本地机器**:请见[如何使用](#如何使用) - **本地机器**:请见[如何使用](#如何使用)
## 目录 ## 目录
@@ -46,7 +47,7 @@ https://github.com/hiyouga/LLaMA-Factory/assets/16256802/ec36a9dd-37f4-4f72-81bd
## 项目特色 ## 项目特色
- **多种模型**LLaMA、LLaVA、Mistral、Mixtral-MoE、Qwen、Yi、Gemma、Baichuan、ChatGLM、Phi 等等。 - **多种模型**LLaMA、LLaVA、Mistral、Mixtral-MoE、Qwen、Yi、Gemma、Baichuan、ChatGLM、Phi 等等。
- **集成方法**增量预训练、多模态指令监督微调、奖励模型训练、PPO 训练、DPO 训练ORPO 训练。 - **集成方法**增量预训练、多模态指令监督微调、奖励模型训练、PPO 训练、DPO 训练、KTO 训练、ORPO 训练等等
- **多种精度**32 比特全参数微调、16 比特冻结微调、16 比特 LoRA 微调和基于 AQLM/AWQ/GPTQ/LLM.int8 的 2/4/8 比特 QLoRA 微调。 - **多种精度**32 比特全参数微调、16 比特冻结微调、16 比特 LoRA 微调和基于 AQLM/AWQ/GPTQ/LLM.int8 的 2/4/8 比特 QLoRA 微调。
- **先进算法**GaLore、BAdam、DoRA、LongLoRA、LLaMA Pro、Mixture-of-Depths、LoRA+、LoftQ 和 Agent 微调。 - **先进算法**GaLore、BAdam、DoRA、LongLoRA、LLaMA Pro、Mixture-of-Depths、LoRA+、LoftQ 和 Agent 微调。
- **实用技巧**FlashAttention-2、Unsloth、RoPE scaling、NEFTune 和 rsLoRA。 - **实用技巧**FlashAttention-2、Unsloth、RoPE scaling、NEFTune 和 rsLoRA。
@@ -70,14 +71,22 @@ https://github.com/hiyouga/LLaMA-Factory/assets/16256802/ec36a9dd-37f4-4f72-81bd
## 更新日志 ## 更新日志
[24/05/14] 我们支持了昇腾 NPU 设备的训练和推理。详情请查阅[安装](#安装-llama-factory)部分 [24/06/07] 我们支持了 **[Qwen-2](https://qwenlm.github.io/blog/qwen2/)** 系列模型的微调
[24/05/13] 我们支持了 Yi-1.5 系列模型的微调。 [24/06/05] 我们支持了 **[GLM-4-9B/GLM-4-9B-Chat](https://github.com/THUDM/GLM-4)** 模型的微调。
[24/04/26] 我们支持了多模态模型 **LLaVA-1.5** 的微调。详细用法请参照 [examples](examples/README_zh.md)。 [24/05/26] 我们支持了 **[SimPO](https://arxiv.org/abs/2405.14734)** 偏好对齐算法。详细用法请参照 [examples](examples/README_zh.md)。
<details><summary>展开日志</summary> <details><summary>展开日志</summary>
[24/05/20] 我们支持了 **PaliGemma** 系列模型的微调。注意 PaliGemma 是预训练模型,你需要使用 `gemma` 模板进行微调使其获得对话能力。
[24/05/18] 我们支持了 **[KTO](https://arxiv.org/abs/2402.01306)** 偏好对齐算法。详细用法请参照 [examples](examples/README_zh.md)。
[24/05/14] 我们支持了昇腾 NPU 设备的训练和推理。详情请查阅[安装](#安装-llama-factory)部分。
[24/04/26] 我们支持了多模态模型 **LLaVA-1.5** 的微调。详细用法请参照 [examples](examples/README_zh.md)。
[24/04/22] 我们提供了在免费 T4 GPU 上微调 Llama-3 模型的 **[Colab 笔记本](https://colab.research.google.com/drive/1d5KQtbemerlSDSxZIfAaWXhKr30QypiK?usp=sharing)**。Hugging Face 社区公开了两个利用 LLaMA Factory 微调的 Llama-3 模型,详情请见 [Llama3-8B-Chinese-Chat](https://huggingface.co/shenzhi-wang/Llama3-8B-Chinese-Chat) 和 [Llama3-Chinese](https://huggingface.co/zhichen/Llama3-Chinese)。 [24/04/22] 我们提供了在免费 T4 GPU 上微调 Llama-3 模型的 **[Colab 笔记本](https://colab.research.google.com/drive/1d5KQtbemerlSDSxZIfAaWXhKr30QypiK?usp=sharing)**。Hugging Face 社区公开了两个利用 LLaMA Factory 微调的 Llama-3 模型,详情请见 [Llama3-8B-Chinese-Chat](https://huggingface.co/shenzhi-wang/Llama3-8B-Chinese-Chat) 和 [Llama3-Chinese](https://huggingface.co/zhichen/Llama3-Chinese)。
[24/04/21] 我们基于 [AstraMindAI 的仓库](https://github.com/astramind-ai/Mixture-of-depths)支持了 **[混合深度训练](https://arxiv.org/abs/2404.02258)**。详细用法请参照 [examples](examples/README_zh.md)。 [24/04/21] 我们基于 [AstraMindAI 的仓库](https://github.com/astramind-ai/Mixture-of-depths)支持了 **[混合深度训练](https://arxiv.org/abs/2404.02258)**。详细用法请参照 [examples](examples/README_zh.md)。
@@ -104,7 +113,7 @@ https://github.com/hiyouga/LLaMA-Factory/assets/16256802/ec36a9dd-37f4-4f72-81bd
[24/02/05] Qwen1.5Qwen2 测试版)系列模型已在 LLaMA-Factory 中实现微调支持。详情请查阅该[博客页面](https://qwenlm.github.io/zh/blog/qwen1.5/)。 [24/02/05] Qwen1.5Qwen2 测试版)系列模型已在 LLaMA-Factory 中实现微调支持。详情请查阅该[博客页面](https://qwenlm.github.io/zh/blog/qwen1.5/)。
[24/01/18] 我们针对绝大多数模型实现了 **Agent 微调**,微调时指定 `dataset: glaive_toolcall` 即可使模型获得工具调用能力。 [24/01/18] 我们针对绝大多数模型实现了 **Agent 微调**,微调时指定 `dataset: glaive_toolcall_zh` 即可使模型获得工具调用能力。
[23/12/23] 我们针对 LLaMA, Mistral 和 Yi 模型支持了 **[unsloth](https://github.com/unslothai/unsloth)** 的 LoRA 训练加速。请使用 `use_unsloth: true` 参数启用 unsloth 优化。该方法可提供 **170%** 的训练速度,详情请查阅[此页面](https://github.com/hiyouga/LLaMA-Factory/wiki/Performance-comparison)。 [23/12/23] 我们针对 LLaMA, Mistral 和 Yi 模型支持了 **[unsloth](https://github.com/unslothai/unsloth)** 的 LoRA 训练加速。请使用 `use_unsloth: true` 参数启用 unsloth 优化。该方法可提供 **170%** 的训练速度,详情请查阅[此页面](https://github.com/hiyouga/LLaMA-Factory/wiki/Performance-comparison)。
@@ -142,43 +151,44 @@ https://github.com/hiyouga/LLaMA-Factory/assets/16256802/ec36a9dd-37f4-4f72-81bd
## 模型 ## 模型
| 模型名 | 模型大小 | 默认模块 | Template | | 模型名 | 模型大小 | Template |
| -------------------------------------------------------- | -------------------------------- | ----------------- | --------- | | -------------------------------------------------------- | -------------------------------- | --------- |
| [Baichuan2](https://huggingface.co/baichuan-inc) | 7B/13B | W_pack | baichuan2 | | [Baichuan2](https://huggingface.co/baichuan-inc) | 7B/13B | baichuan2 |
| [BLOOM](https://huggingface.co/bigscience) | 560M/1.1B/1.7B/3B/7.1B/176B | query_key_value | - | | [BLOOM](https://huggingface.co/bigscience) | 560M/1.1B/1.7B/3B/7.1B/176B | - |
| [BLOOMZ](https://huggingface.co/bigscience) | 560M/1.1B/1.7B/3B/7.1B/176B | query_key_value | - | | [BLOOMZ](https://huggingface.co/bigscience) | 560M/1.1B/1.7B/3B/7.1B/176B | - |
| [ChatGLM3](https://huggingface.co/THUDM) | 6B | query_key_value | chatglm3 | | [ChatGLM3](https://huggingface.co/THUDM) | 6B | chatglm3 |
| [Command-R](https://huggingface.co/CohereForAI) | 35B/104B | q_proj,v_proj | cohere | | [Command-R](https://huggingface.co/CohereForAI) | 35B/104B | cohere |
| [DeepSeek (MoE)](https://huggingface.co/deepseek-ai) | 7B/16B/67B/236B | q_proj,v_proj | deepseek | | [DeepSeek (MoE)](https://huggingface.co/deepseek-ai) | 7B/16B/67B/236B | deepseek |
| [Falcon](https://huggingface.co/tiiuae) | 7B/40B/180B | query_key_value | falcon | | [Falcon](https://huggingface.co/tiiuae) | 7B/11B/40B/180B | falcon |
| [Gemma/CodeGemma](https://huggingface.co/google) | 2B/7B | q_proj,v_proj | gemma | | [Gemma/CodeGemma](https://huggingface.co/google) | 2B/7B | gemma |
| [InternLM2](https://huggingface.co/internlm) | 7B/20B | wqkv | intern2 | | [GLM4](https://huggingface.co/THUDM) | 9B | glm4 |
| [LLaMA](https://github.com/facebookresearch/llama) | 7B/13B/33B/65B | q_proj,v_proj | - | | [InternLM2](https://huggingface.co/internlm) | 7B/20B | intern2 |
| [LLaMA-2](https://huggingface.co/meta-llama) | 7B/13B/70B | q_proj,v_proj | llama2 | | [LLaMA](https://github.com/facebookresearch/llama) | 7B/13B/33B/65B | - |
| [LLaMA-3](https://huggingface.co/meta-llama) | 8B/70B | q_proj,v_proj | llama3 | | [LLaMA-2](https://huggingface.co/meta-llama) | 7B/13B/70B | llama2 |
| [LLaVA-1.5](https://huggingface.co/llava-hf) | 7B/13B | q_proj,v_proj | vicuna | | [LLaMA-3](https://huggingface.co/meta-llama) | 8B/70B | llama3 |
| [Mistral/Mixtral](https://huggingface.co/mistralai) | 7B/8x7B/8x22B | q_proj,v_proj | mistral | | [LLaVA-1.5](https://huggingface.co/llava-hf) | 7B/13B | vicuna |
| [OLMo](https://huggingface.co/allenai) | 1B/7B | q_proj,v_proj | - | | [Mistral/Mixtral](https://huggingface.co/mistralai) | 7B/8x7B/8x22B | mistral |
| [Phi-1.5/2](https://huggingface.co/microsoft) | 1.3B/2.7B | q_proj,v_proj | - | | [OLMo](https://huggingface.co/allenai) | 1B/7B | - |
| [Phi-3](https://huggingface.co/microsoft) | 3.8B | qkv_proj | phi | | [PaliGemma](https://huggingface.co/google) | 3B | gemma |
| [Qwen](https://huggingface.co/Qwen) | 1.8B/7B/14B/72B | c_attn | qwen | | [Phi-1.5/2](https://huggingface.co/microsoft) | 1.3B/2.7B | - |
| [Qwen1.5 (Code/MoE)](https://huggingface.co/Qwen) | 0.5B/1.8B/4B/7B/14B/32B/72B/110B | q_proj,v_proj | qwen | | [Phi-3](https://huggingface.co/microsoft) | 4B/7B/14B | phi |
| [StarCoder2](https://huggingface.co/bigcode) | 3B/7B/15B | q_proj,v_proj | - | | [Qwen](https://huggingface.co/Qwen) | 1.8B/7B/14B/72B | qwen |
| [XVERSE](https://huggingface.co/xverse) | 7B/13B/65B | q_proj,v_proj | xverse | | [Qwen1.5 (Code/MoE)](https://huggingface.co/Qwen) | 0.5B/1.8B/4B/7B/14B/32B/72B/110B | qwen |
| [Yi (1/1.5)](https://huggingface.co/01-ai) | 6B/9B/34B | q_proj,v_proj | yi | | [Qwen2 (MoE)](https://huggingface.co/Qwen) | 0.5B/1.5B/7B/57B/72B | qwen |
| [Yi-VL](https://huggingface.co/01-ai) | 6B/34B | q_proj,v_proj | yi_vl | | [StarCoder2](https://huggingface.co/bigcode) | 3B/7B/15B | - |
| [Yuan](https://huggingface.co/IEITYuan) | 2B/51B/102B | q_proj,v_proj | yuan | | [XVERSE](https://huggingface.co/xverse) | 7B/13B/65B | xverse |
| [Yi (1/1.5)](https://huggingface.co/01-ai) | 6B/9B/34B | yi |
| [Yi-VL](https://huggingface.co/01-ai) | 6B/34B | yi_vl |
| [Yuan](https://huggingface.co/IEITYuan) | 2B/51B/102B | yuan |
> [!NOTE] > [!NOTE]
> **默认模块**应作为 `--lora_target` 参数的默认值,可使用 `--lora_target all` 参数指定全部模块以取得更好的效果 > 对于所有“基座”Base模型`template` 参数可以是 `default`, `alpaca`, `vicuna` 等任意值。但“对话”Instruct/Chat模型请务必使用**对应的模板**
> >
> 对于所有“基座”Base模型`--template` 参数可以是 `default`, `alpaca`, `vicuna` 等任意值。但“对话”Instruct/Chat模型请务必使用**对应的模板** > 请务必在训练和推理时采用**完全一致**的模板。
>
> 请务必在训练和推理时使用**完全一致**的模板。
项目所支持模型的完整列表请参阅 [constants.py](src/llmtuner/extras/constants.py)。 项目所支持模型的完整列表请参阅 [constants.py](src/llamafactory/extras/constants.py)。
您也可以在 [template.py](src/llmtuner/data/template.py) 中添加自己的对话模板。 您也可以在 [template.py](src/llamafactory/data/template.py) 中添加自己的对话模板。
## 训练方法 ## 训练方法
@@ -189,7 +199,9 @@ https://github.com/hiyouga/LLaMA-Factory/assets/16256802/ec36a9dd-37f4-4f72-81bd
| 奖励模型训练 | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: | | 奖励模型训练 | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
| PPO 训练 | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: | | PPO 训练 | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
| DPO 训练 | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: | | DPO 训练 | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
| KTO 训练 | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
| ORPO 训练 | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: | | ORPO 训练 | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
| SimPO 训练 | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
## 数据集 ## 数据集
@@ -202,6 +214,8 @@ https://github.com/hiyouga/LLaMA-Factory/assets/16256802/ec36a9dd-37f4-4f72-81bd
- [Wikipedia (zh)](https://huggingface.co/datasets/pleisto/wikipedia-cn-20230720-filtered) - [Wikipedia (zh)](https://huggingface.co/datasets/pleisto/wikipedia-cn-20230720-filtered)
- [Pile (en)](https://huggingface.co/datasets/EleutherAI/pile) - [Pile (en)](https://huggingface.co/datasets/EleutherAI/pile)
- [SkyPile (zh)](https://huggingface.co/datasets/Skywork/SkyPile-150B) - [SkyPile (zh)](https://huggingface.co/datasets/Skywork/SkyPile-150B)
- [FineWeb (en)](https://huggingface.co/datasets/HuggingFaceFW/fineweb)
- [FineWeb-Edu (en)](https://huggingface.co/datasets/HuggingFaceFW/fineweb-edu)
- [The Stack (en)](https://huggingface.co/datasets/bigcode/the-stack) - [The Stack (en)](https://huggingface.co/datasets/bigcode/the-stack)
- [StarCoder (en)](https://huggingface.co/datasets/bigcode/starcoderdata) - [StarCoder (en)](https://huggingface.co/datasets/bigcode/starcoderdata)
@@ -209,12 +223,12 @@ https://github.com/hiyouga/LLaMA-Factory/assets/16256802/ec36a9dd-37f4-4f72-81bd
<details><summary>指令微调数据集</summary> <details><summary>指令微调数据集</summary>
- [Stanford Alpaca (en)](https://github.com/tatsu-lab/stanford_alpaca)
- [Stanford Alpaca (zh)](https://github.com/ymcui/Chinese-LLaMA-Alpaca)
- [Alpaca GPT4 (en&zh)](https://github.com/Instruction-Tuning-with-GPT-4/GPT-4-LLM)
- [Identity (en&zh)](data/identity.json) - [Identity (en&zh)](data/identity.json)
- [Open Assistant (zh)](https://huggingface.co/datasets/OpenAssistant/oasst1) - [Stanford Alpaca (en)](https://github.com/tatsu-lab/stanford_alpaca)
- [ShareGPT (zh)](https://huggingface.co/datasets/QingyiSi/Alpaca-CoT/tree/main/Chinese-instruction-collection) - [Stanford Alpaca (zh)](https://github.com/ymcui/Chinese-LLaMA-Alpaca-3)
- [Alpaca GPT4 (en&zh)](https://github.com/Instruction-Tuning-with-GPT-4/GPT-4-LLM)
- [Glaive Function Calling V2 (en&zh)](https://huggingface.co/datasets/glaiveai/glaive-function-calling-v2)
- [LIMA (en)](https://huggingface.co/datasets/GAIR/lima)
- [Guanaco Dataset (multilingual)](https://huggingface.co/datasets/JosephusCheung/GuanacoDataset) - [Guanaco Dataset (multilingual)](https://huggingface.co/datasets/JosephusCheung/GuanacoDataset)
- [BELLE 2M (zh)](https://huggingface.co/datasets/BelleGroup/train_2M_CN) - [BELLE 2M (zh)](https://huggingface.co/datasets/BelleGroup/train_2M_CN)
- [BELLE 1M (zh)](https://huggingface.co/datasets/BelleGroup/train_1M_CN) - [BELLE 1M (zh)](https://huggingface.co/datasets/BelleGroup/train_1M_CN)
@@ -223,7 +237,6 @@ https://github.com/hiyouga/LLaMA-Factory/assets/16256802/ec36a9dd-37f4-4f72-81bd
- [BELLE School Math 0.25M (zh)](https://huggingface.co/datasets/BelleGroup/school_math_0.25M) - [BELLE School Math 0.25M (zh)](https://huggingface.co/datasets/BelleGroup/school_math_0.25M)
- [BELLE Multiturn Chat 0.8M (zh)](https://huggingface.co/datasets/BelleGroup/multiturn_chat_0.8M) - [BELLE Multiturn Chat 0.8M (zh)](https://huggingface.co/datasets/BelleGroup/multiturn_chat_0.8M)
- [UltraChat (en)](https://github.com/thunlp/UltraChat) - [UltraChat (en)](https://github.com/thunlp/UltraChat)
- [LIMA (en)](https://huggingface.co/datasets/GAIR/lima)
- [OpenPlatypus (en)](https://huggingface.co/datasets/garage-bAInd/Open-Platypus) - [OpenPlatypus (en)](https://huggingface.co/datasets/garage-bAInd/Open-Platypus)
- [CodeAlpaca 20k (en)](https://huggingface.co/datasets/sahil2801/CodeAlpaca-20k) - [CodeAlpaca 20k (en)](https://huggingface.co/datasets/sahil2801/CodeAlpaca-20k)
- [Alpaca CoT (multilingual)](https://huggingface.co/datasets/QingyiSi/Alpaca-CoT) - [Alpaca CoT (multilingual)](https://huggingface.co/datasets/QingyiSi/Alpaca-CoT)
@@ -236,15 +249,16 @@ https://github.com/hiyouga/LLaMA-Factory/assets/16256802/ec36a9dd-37f4-4f72-81bd
- [WebNovel (zh)](https://huggingface.co/datasets/zxbsmk/webnovel_cn) - [WebNovel (zh)](https://huggingface.co/datasets/zxbsmk/webnovel_cn)
- [Nectar (en)](https://huggingface.co/datasets/berkeley-nest/Nectar) - [Nectar (en)](https://huggingface.co/datasets/berkeley-nest/Nectar)
- [deepctrl (en&zh)](https://www.modelscope.cn/datasets/deepctrl/deepctrl-sft-data) - [deepctrl (en&zh)](https://www.modelscope.cn/datasets/deepctrl/deepctrl-sft-data)
- [Ad Gen (zh)](https://huggingface.co/datasets/HasturOfficial/adgen) - [Advertise Generating (zh)](https://huggingface.co/datasets/HasturOfficial/adgen)
- [ShareGPT Hyperfiltered (en)](https://huggingface.co/datasets/totally-not-an-llm/sharegpt-hyperfiltered-3k) - [ShareGPT Hyperfiltered (en)](https://huggingface.co/datasets/totally-not-an-llm/sharegpt-hyperfiltered-3k)
- [ShareGPT4 (en&zh)](https://huggingface.co/datasets/shibing624/sharegpt_gpt4) - [ShareGPT4 (en&zh)](https://huggingface.co/datasets/shibing624/sharegpt_gpt4)
- [UltraChat 200k (en)](https://huggingface.co/datasets/HuggingFaceH4/ultrachat_200k) - [UltraChat 200k (en)](https://huggingface.co/datasets/HuggingFaceH4/ultrachat_200k)
- [AgentInstruct (en)](https://huggingface.co/datasets/THUDM/AgentInstruct) - [AgentInstruct (en)](https://huggingface.co/datasets/THUDM/AgentInstruct)
- [LMSYS Chat 1M (en)](https://huggingface.co/datasets/lmsys/lmsys-chat-1m) - [LMSYS Chat 1M (en)](https://huggingface.co/datasets/lmsys/lmsys-chat-1m)
- [Evol Instruct V2 (en)](https://huggingface.co/datasets/WizardLM/WizardLM_evol_instruct_V2_196k) - [Evol Instruct V2 (en)](https://huggingface.co/datasets/WizardLM/WizardLM_evol_instruct_V2_196k)
- [Glaive Function Calling V2 (en)](https://huggingface.co/datasets/glaiveai/glaive-function-calling-v2)
- [Cosmopedia (en)](https://huggingface.co/datasets/HuggingFaceTB/cosmopedia) - [Cosmopedia (en)](https://huggingface.co/datasets/HuggingFaceTB/cosmopedia)
- [STEM (zh)](https://huggingface.co/datasets/hfl/stem_zh_instruction)
- [Ruozhiba (zh)](https://huggingface.co/datasets/hfl/ruozhiba_gpt4_turbo)
- [LLaVA mixed (en&zh)](https://huggingface.co/datasets/BUAADreamer/llava-en-zh-300k) - [LLaVA mixed (en&zh)](https://huggingface.co/datasets/BUAADreamer/llava-en-zh-300k)
- [Open Assistant (de)](https://huggingface.co/datasets/mayflowergmbh/oasst_de) - [Open Assistant (de)](https://huggingface.co/datasets/mayflowergmbh/oasst_de)
- [Dolly 15k (de)](https://huggingface.co/datasets/mayflowergmbh/dolly-15k_de) - [Dolly 15k (de)](https://huggingface.co/datasets/mayflowergmbh/dolly-15k_de)
@@ -260,13 +274,13 @@ https://github.com/hiyouga/LLaMA-Factory/assets/16256802/ec36a9dd-37f4-4f72-81bd
<details><summary>偏好数据集</summary> <details><summary>偏好数据集</summary>
- [HH-RLHF (en)](https://huggingface.co/datasets/Anthropic/hh-rlhf)
- [GPT-4 Generated Data (en&zh)](https://github.com/Instruction-Tuning-with-GPT-4/GPT-4-LLM)
- [Orca DPO (en)](https://huggingface.co/datasets/Intel/orca_dpo_pairs)
- [Nectar (en)](https://huggingface.co/datasets/berkeley-nest/Nectar)
- [DPO mixed (en&zh)](https://huggingface.co/datasets/hiyouga/DPO-En-Zh-20k) - [DPO mixed (en&zh)](https://huggingface.co/datasets/hiyouga/DPO-En-Zh-20k)
- [Open Assistant (zh)](https://huggingface.co/datasets/OpenAssistant/oasst1) - [UltraFeedback (en)](https://huggingface.co/datasets/HuggingFaceH4/ultrafeedback_binarized)
- [Orca DPO Pairs (en)](https://huggingface.co/datasets/Intel/orca_dpo_pairs)
- [HH-RLHF (en)](https://huggingface.co/datasets/Anthropic/hh-rlhf)
- [Nectar (en)](https://huggingface.co/datasets/berkeley-nest/Nectar)
- [Orca DPO (de)](https://huggingface.co/datasets/mayflowergmbh/intel_orca_dpo_pairs_de) - [Orca DPO (de)](https://huggingface.co/datasets/mayflowergmbh/intel_orca_dpo_pairs_de)
- [KTO mixed (en)](https://huggingface.co/datasets/argilla/kto-mix-15k)
</details> </details>
@@ -281,21 +295,21 @@ huggingface-cli login
| 必需项 | 至少 | 推荐 | | 必需项 | 至少 | 推荐 |
| ------------ | ------- | --------- | | ------------ | ------- | --------- |
| python | 3.8 | 3.10 | | python | 3.8 | 3.11 |
| torch | 1.13.1 | 2.2.0 | | torch | 1.13.1 | 2.3.0 |
| transformers | 4.37.2 | 4.40.1 | | transformers | 4.41.2 | 4.41.2 |
| datasets | 2.14.3 | 2.19.1 | | datasets | 2.16.0 | 2.19.2 |
| accelerate | 0.27.2 | 0.30.0 | | accelerate | 0.30.1 | 0.30.1 |
| peft | 0.9.0 | 0.10.0 | | peft | 0.11.1 | 0.11.1 |
| trl | 0.8.1 | 0.8.6 | | trl | 0.8.6 | 0.9.4 |
| 可选项 | 至少 | 推荐 | | 可选项 | 至少 | 推荐 |
| ------------ | ------- | --------- | | ------------ | ------- | --------- |
| CUDA | 11.6 | 12.2 | | CUDA | 11.6 | 12.2 |
| deepspeed | 0.10.0 | 0.14.0 | | deepspeed | 0.10.0 | 0.14.0 |
| bitsandbytes | 0.39.0 | 0.43.1 | | bitsandbytes | 0.39.0 | 0.43.1 |
| vllm | 0.4.0 | 0.4.2 | | vllm | 0.4.3 | 0.4.3 |
| flash-attn | 2.3.0 | 2.5.8 | | flash-attn | 2.3.0 | 2.5.9 |
### 硬件依赖 ### 硬件依赖
@@ -319,12 +333,12 @@ huggingface-cli login
> 此步骤为必需。 > 此步骤为必需。
```bash ```bash
git clone https://github.com/hiyouga/LLaMA-Factory.git git clone --depth 1 https://github.com/hiyouga/LLaMA-Factory.git
cd LLaMA-Factory cd LLaMA-Factory
pip install -e .[torch,metrics] pip install -e '.[torch,metrics]'
``` ```
可选的额外依赖项torch、metrics、deepspeed、bitsandbytes、vllm、galore、badam、gptq、awq、aqlm、qwen、modelscope、quality 可选的额外依赖项torch、torch_npu、metrics、deepspeed、bitsandbytes、vllm、galore、badam、gptq、awq、aqlm、qwen、modelscope、quality
> [!TIP] > [!TIP]
> 遇到包冲突时,可使用 `pip install --no-deps -e .` 解决。 > 遇到包冲突时,可使用 `pip install --no-deps -e .` 解决。
@@ -343,21 +357,37 @@ pip install https://github.com/jllllll/bitsandbytes-windows-webui/releases/downl
<details><summary>昇腾 NPU 用户指南</summary> <details><summary>昇腾 NPU 用户指南</summary>
如果使用昇腾 NPU 设备进行(分布式)训练或推理,需要安装 **[torch-npu](https://gitee.com/ascend/pytorch)** 库和 **[Ascend CANN Kernels](https://www.hiascend.com/developer/download/community/result?module=cann)** 加入 [NPU 用户群](assets/wechat_npu.jpg)
在昇腾 NPU 设备上安装 LLaMA Factory 时,需要指定额外依赖项,使用 `pip install -e '.[torch-npu,metrics]'` 命令安装。此外,还需要安装 **[Ascend CANN Toolkit and Kernels](https://www.hiascend.com/developer/download/community/result?module=cann)**,安装方法请参考[安装教程](https://www.hiascend.com/document/detail/zh/CANNCommunityEdition/80RC2alpha002/quickstart/quickstart/quickstart_18_0004.html)或使用以下命令:
```bash
# 请替换 URL 为 CANN 版本和设备型号对应的 URL
# 安装 CANN Toolkit
wget https://ascend-repo.obs.cn-east-2.myhuaweicloud.com/Milan-ASL/Milan-ASL%20V100R001C17SPC701/Ascend-cann-toolkit_8.0.RC1.alpha001_linux-"$(uname -i)".run
bash Ascend-cann-toolkit_8.0.RC1.alpha001_linux-"$(uname -i)".run --install
# 安装 CANN Kernels
wget https://ascend-repo.obs.cn-east-2.myhuaweicloud.com/Milan-ASL/Milan-ASL%20V100R001C17SPC701/Ascend-cann-kernels-910b_8.0.RC1.alpha001_linux.run
bash Ascend-cann-kernels-910b_8.0.RC1.alpha001_linux.run --install
# 设置环境变量
source /usr/local/Ascend/ascend-toolkit/set_env.sh
```
| 依赖项 | 至少 | 推荐 | | 依赖项 | 至少 | 推荐 |
| ------------ | ------- | --------- | | ------------ | ------- | ----------- |
| CANN | 8.0.RC1 | 8.0.RC1 | | CANN | 8.0.RC1 | 8.0.RC1 |
| torch | 2.2.0 | 2.2.0 | | torch | 2.1.0 | 2.1.0 |
| torch-npu | 2.2.0 | 2.2.0 | | torch-npu | 2.1.0 | 2.1.0.post3 |
| deepspeed | 0.13.2 | 0.13.2 | | deepspeed | 0.13.2 | 0.13.2 |
Docker 镜像: Docker 镜像:
- 32GB[下载地址](http://mirrors.cn-central-221.ovaijisuan.com/detail/130.html) - 32GB[下载地址](http://mirrors.cn-central-221.ovaijisuan.com/detail/130.html)
- 64GB敬请期待 - 64GB[下载地址](http://mirrors.cn-central-221.ovaijisuan.com/detail/131.html)
记得使用 `ASCEND_RT_VISIBLE_DEVICES` 而非 `CUDA_VISIBLE_DEVICES` 来指定您使用的设备。 请使用 `ASCEND_RT_VISIBLE_DEVICES` 而非 `CUDA_VISIBLE_DEVICES` 来指定运算设备。
如果遇到无法正常推理的情况,请尝试设置 `do_sample: false` 如果遇到无法正常推理的情况,请尝试设置 `do_sample: false`
@@ -387,31 +417,12 @@ CUDA_VISIBLE_DEVICES=0 llamafactory-cli export examples/merge_lora/llama3_lora_s
### LLaMA Board 可视化微调(由 [Gradio](https://github.com/gradio-app/gradio) 驱动) ### LLaMA Board 可视化微调(由 [Gradio](https://github.com/gradio-app/gradio) 驱动)
> [!IMPORTANT]
> LLaMA Board 可视化界面目前仅支持单 GPU 训练。
#### 使用本地环境 #### 使用本地环境
```bash ```bash
CUDA_VISIBLE_DEVICES=0 GRADIO_SHARE=1 llamafactory-cli webui CUDA_VISIBLE_DEVICES=0 GRADIO_SHARE=1 llamafactory-cli webui
``` ```
<details><summary>阿里云 PAI 和 AutoDL 用户指南</summary>
如果您在阿里云 PAI 上使用 LLaMA Board 时遇到显示问题,请尝试在启动前使用以下命令设置环境变量:
```bash
export GRADIO_SERVER_PORT=7860 GRADIO_ROOT_PATH=/${JUPYTER_NAME}/proxy/7860/
```
如果您正在使用 AutoDL请安装下述 Gradio 版本:
```bash
pip install gradio==4.10.0
```
</details>
#### 使用 Docker #### 使用 Docker
```bash ```bash
@@ -420,7 +431,6 @@ docker run --gpus=all \
-v ./hf_cache:/root/.cache/huggingface/ \ -v ./hf_cache:/root/.cache/huggingface/ \
-v ./data:/app/data \ -v ./data:/app/data \
-v ./output:/app/output \ -v ./output:/app/output \
-e CUDA_VISIBLE_DEVICES=0 \
-p 7860:7860 \ -p 7860:7860 \
--shm-size 16G \ --shm-size 16G \
--name llama_factory \ --name llama_factory \
@@ -447,6 +457,9 @@ docker compose -f ./docker-compose.yml up -d
CUDA_VISIBLE_DEVICES=0,1 API_PORT=8000 llamafactory-cli api examples/inference/llama3_vllm.yaml CUDA_VISIBLE_DEVICES=0,1 API_PORT=8000 llamafactory-cli api examples/inference/llama3_vllm.yaml
``` ```
> [!TIP]
> API 文档请查阅 https://platform.openai.com/docs/api-reference/chat/create。
### 从魔搭社区下载 ### 从魔搭社区下载
如果您在 Hugging Face 模型和数据集的下载中遇到了问题,可以通过下述方法使用魔搭社区。 如果您在 Hugging Face 模型和数据集的下载中遇到了问题,可以通过下述方法使用魔搭社区。
@@ -455,7 +468,18 @@ CUDA_VISIBLE_DEVICES=0,1 API_PORT=8000 llamafactory-cli api examples/inference/l
export USE_MODELSCOPE_HUB=1 # Windows 使用 `set USE_MODELSCOPE_HUB=1` export USE_MODELSCOPE_HUB=1 # Windows 使用 `set USE_MODELSCOPE_HUB=1`
``` ```
`--model_name_or_path` 设置为模型 ID 来加载对应的模型。在[魔搭社区](https://modelscope.cn/models)查看所有可用的模型,例如 `LLM-Research/Meta-Llama-3-8B-Instruct` `model_name_or_path` 设置为模型 ID 来加载对应的模型。在[魔搭社区](https://modelscope.cn/models)查看所有可用的模型,例如 `LLM-Research/Meta-Llama-3-8B-Instruct`
### 使用 W&B 面板
若要使用 [Weights & Biases](https://wandb.ai) 记录实验数据,请添加下面的参数。
```yaml
report_to: wandb
run_name: test_run # 可选
```
在启动训练任务时,将 `WANDB_API_KEY` 设置为[密钥](https://wandb.ai/authorize)来登录 W&B 账户。
## 使用了 LLaMA Factory 的项目 ## 使用了 LLaMA Factory 的项目
@@ -502,7 +526,7 @@ export USE_MODELSCOPE_HUB=1 # Windows 使用 `set USE_MODELSCOPE_HUB=1`
1. Zhou et al. FREB-TQA: A Fine-Grained Robustness Evaluation Benchmark for Table Question Answering. 2024. [[arxiv]](https://arxiv.org/abs/2404.18585) 1. Zhou et al. FREB-TQA: A Fine-Grained Robustness Evaluation Benchmark for Table Question Answering. 2024. [[arxiv]](https://arxiv.org/abs/2404.18585)
1. **[StarWhisper](https://github.com/Yu-Yang-Li/StarWhisper)**: 天文大模型 StarWhisper基于 ChatGLM2-6B 和 Qwen-14B 在天文数据上微调而得。 1. **[StarWhisper](https://github.com/Yu-Yang-Li/StarWhisper)**: 天文大模型 StarWhisper基于 ChatGLM2-6B 和 Qwen-14B 在天文数据上微调而得。
1. **[DISC-LawLLM](https://github.com/FudanDISC/DISC-LawLLM)**: 中文法律领域大模型 DISC-LawLLM基于 Baichuan-13B 微调而得,具有法律推理和知识检索能力。 1. **[DISC-LawLLM](https://github.com/FudanDISC/DISC-LawLLM)**: 中文法律领域大模型 DISC-LawLLM基于 Baichuan-13B 微调而得,具有法律推理和知识检索能力。
1. **[Sunsimiao](https://github.com/thomas-yanxin/Sunsimiao)**: 孙思邈中文医疗大模型 Sunsimiao基于 Baichuan-7B 和 ChatGLM-6B 在中文医疗数据上微调而得。 1. **[Sunsimiao](https://github.com/X-D-Lab/Sunsimiao)**: 孙思邈中文医疗大模型 Sunsimiao基于 Baichuan-7B 和 ChatGLM-6B 在中文医疗数据上微调而得。
1. **[CareGPT](https://github.com/WangRongsheng/CareGPT)**: 医疗大模型项目 CareGPT基于 LLaMA2-7B 和 Baichuan-13B 在中文医疗数据上微调而得。 1. **[CareGPT](https://github.com/WangRongsheng/CareGPT)**: 医疗大模型项目 CareGPT基于 LLaMA2-7B 和 Baichuan-13B 在中文医疗数据上微调而得。
1. **[MachineMindset](https://github.com/PKU-YuanGroup/Machine-Mindset/)**MBTI性格大模型项目根据数据集与训练方式让任意 LLM 拥有 16 个不同的性格类型。 1. **[MachineMindset](https://github.com/PKU-YuanGroup/Machine-Mindset/)**MBTI性格大模型项目根据数据集与训练方式让任意 LLM 拥有 16 个不同的性格类型。
1. **[Luminia-13B-v3](https://huggingface.co/Nekochu/Luminia-13B-v3)**:一个用于生成 Stable Diffusion 提示词的大型语言模型。[[🤗Demo]](https://huggingface.co/spaces/Nekochu/Luminia-13B_SD_Prompt) 1. **[Luminia-13B-v3](https://huggingface.co/Nekochu/Luminia-13B-v3)**:一个用于生成 Stable Diffusion 提示词的大型语言模型。[[🤗Demo]](https://huggingface.co/spaces/Nekochu/Luminia-13B_SD_Prompt)
@@ -514,7 +538,7 @@ export USE_MODELSCOPE_HUB=1 # Windows 使用 `set USE_MODELSCOPE_HUB=1`
本仓库的代码依照 [Apache-2.0](LICENSE) 协议开源。 本仓库的代码依照 [Apache-2.0](LICENSE) 协议开源。
使用模型权重时,请遵循对应的模型协议:[Baichuan2](https://huggingface.co/baichuan-inc/Baichuan2-7B-Base/blob/main/Community%20License%20for%20Baichuan%202%20Model.pdf) / [BLOOM](https://huggingface.co/spaces/bigscience/license) / [ChatGLM3](https://github.com/THUDM/ChatGLM3/blob/main/MODEL_LICENSE) / [Command-R](https://cohere.com/c4ai-cc-by-nc-license) / [DeepSeek](https://github.com/deepseek-ai/DeepSeek-LLM/blob/main/LICENSE-MODEL) / [Falcon](https://huggingface.co/tiiuae/falcon-180B/blob/main/LICENSE.txt) / [Gemma](https://ai.google.dev/gemma/terms) / [InternLM2](https://github.com/InternLM/InternLM#license) / [LLaMA](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) / [LLaMA-2 (LLaVA-1.5)](https://ai.meta.com/llama/license/) / [LLaMA-3](https://llama.meta.com/llama3/license/) / [Mistral](LICENSE) / [OLMo](LICENSE) / [Phi-1.5/2](https://huggingface.co/microsoft/phi-1_5/resolve/main/Research%20License.docx) / [Phi-3](https://huggingface.co/microsoft/Phi-3-mini-4k-instruct/blob/main/LICENSE) / [Qwen](https://github.com/QwenLM/Qwen/blob/main/Tongyi%20Qianwen%20LICENSE%20AGREEMENT) / [StarCoder2](https://huggingface.co/spaces/bigcode/bigcode-model-license-agreement) / [XVERSE](https://github.com/xverse-ai/XVERSE-13B/blob/main/MODEL_LICENSE.pdf) / [Yi](https://huggingface.co/01-ai/Yi-6B/blob/main/LICENSE) / [Yi-1.5](LICENSE) / [Yuan](https://github.com/IEIT-Yuan/Yuan-2.0/blob/main/LICENSE-Yuan) 使用模型权重时,请遵循对应的模型协议:[Baichuan2](https://huggingface.co/baichuan-inc/Baichuan2-7B-Base/blob/main/Community%20License%20for%20Baichuan%202%20Model.pdf) / [BLOOM](https://huggingface.co/spaces/bigscience/license) / [ChatGLM3](https://github.com/THUDM/ChatGLM3/blob/main/MODEL_LICENSE) / [Command-R](https://cohere.com/c4ai-cc-by-nc-license) / [DeepSeek](https://github.com/deepseek-ai/DeepSeek-LLM/blob/main/LICENSE-MODEL) / [Falcon](https://huggingface.co/tiiuae/falcon-180B/blob/main/LICENSE.txt) / [Gemma](https://ai.google.dev/gemma/terms) / 
[GLM4](https://huggingface.co/THUDM/glm-4-9b/blob/main/LICENSE) / [InternLM2](https://github.com/InternLM/InternLM#license) / [LLaMA](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) / [LLaMA-2 (LLaVA-1.5)](https://ai.meta.com/llama/license/) / [LLaMA-3](https://llama.meta.com/llama3/license/) / [Mistral](LICENSE) / [OLMo](LICENSE) / [Phi-1.5/2](https://huggingface.co/microsoft/phi-1_5/resolve/main/Research%20License.docx) / [Phi-3](https://huggingface.co/microsoft/Phi-3-mini-4k-instruct/blob/main/LICENSE) / [Qwen](https://github.com/QwenLM/Qwen/blob/main/Tongyi%20Qianwen%20LICENSE%20AGREEMENT) / [StarCoder2](https://huggingface.co/spaces/bigcode/bigcode-model-license-agreement) / [XVERSE](https://github.com/xverse-ai/XVERSE-13B/blob/main/MODEL_LICENSE.pdf) / [Yi](https://huggingface.co/01-ai/Yi-6B/blob/main/LICENSE) / [Yi-1.5](LICENSE) / [Yuan](https://github.com/IEIT-Yuan/Yuan-2.0/blob/main/LICENSE-Yuan)
## 引用 ## 引用

View File

@@ -1,16 +1,18 @@
If you are using a custom dataset, please add your **dataset description** to `dataset_info.json` according to the following format. We also provide several examples in the next section. The [dataset_info.json](dataset_info.json) contains all available datasets. If you are using a custom dataset, please **make sure** to add a *dataset description* in `dataset_info.json` and specify `dataset: dataset_name` before training to use it.
Currently we support datasets in **alpaca** and **sharegpt** format.
```json ```json
"dataset_name": { "dataset_name": {
"hf_hub_url": "the name of the dataset repository on the Hugging Face hub. (if specified, ignore script_url and file_name)", "hf_hub_url": "the name of the dataset repository on the Hugging Face hub. (if specified, ignore script_url and file_name)",
"ms_hub_url": "the name of the dataset repository on the Model Scope hub. (if specified, ignore script_url and file_name)", "ms_hub_url": "the name of the dataset repository on the Model Scope hub. (if specified, ignore script_url and file_name)",
"script_url": "the name of the directory containing a dataset loading script. (if specified, ignore file_name)", "script_url": "the name of the directory containing a dataset loading script. (if specified, ignore file_name)",
"file_name": "the name of the dataset file in this directory. (required if above are not specified)", "file_name": "the name of the dataset folder or dataset file in this directory. (required if above are not specified)",
"file_sha1": "the SHA-1 hash value of the dataset file. (optional, does not affect training)", "formatting": "the format of the dataset. (optional, default: alpaca, can be chosen from {alpaca, sharegpt})",
"ranking": "whether the dataset is a preference dataset or not. (default: False)",
"subset": "the name of the subset. (optional, default: None)", "subset": "the name of the subset. (optional, default: None)",
"folder": "the name of the folder of the dataset repository on the Hugging Face hub. (optional, default: None)", "folder": "the name of the folder of the dataset repository on the Hugging Face hub. (optional, default: None)",
"ranking": "whether the dataset is a preference dataset or not. (default: false)", "num_samples": "the number of samples in the dataset used for training. (optional, default: None)",
"formatting": "the format of the dataset. (optional, default: alpaca, can be chosen from {alpaca, sharegpt})",
"columns (optional)": { "columns (optional)": {
"prompt": "the column name in the dataset containing the prompts. (default: instruction)", "prompt": "the column name in the dataset containing the prompts. (default: instruction)",
"query": "the column name in the dataset containing the queries. (default: input)", "query": "the column name in the dataset containing the queries. (default: input)",
@@ -19,7 +21,10 @@ If you are using a custom dataset, please add your **dataset description** to `d
"messages": "the column name in the dataset containing the messages. (default: conversations)", "messages": "the column name in the dataset containing the messages. (default: conversations)",
"system": "the column name in the dataset containing the system prompts. (default: None)", "system": "the column name in the dataset containing the system prompts. (default: None)",
"tools": "the column name in the dataset containing the tool description. (default: None)", "tools": "the column name in the dataset containing the tool description. (default: None)",
"images": "the column name in the dataset containing the image inputs. (default: None)" "images": "the column name in the dataset containing the image inputs. (default: None)",
"chosen": "the column name in the dataset containing the chosen answers. (default: None)",
"rejected": "the column name in the dataset containing the rejected answers. (default: None)",
"kto_tag": "the column name in the dataset containing the kto tags. (default: None)"
}, },
"tags (optional, used for the sharegpt format)": { "tags (optional, used for the sharegpt format)": {
"role_tag": "the key in the message represents the identity. (default: from)", "role_tag": "the key in the message represents the identity. (default: from)",
@@ -33,28 +38,34 @@ If you are using a custom dataset, please add your **dataset description** to `d
} }
``` ```
After that, you can load the custom dataset by specifying `--dataset dataset_name`. ## Alpaca Format
---- ### Supervised Fine-Tuning Dataset
Currently we support dataset in **alpaca** or **sharegpt** format, the dataset in alpaca format should follow the below format: * [Example dataset](alpaca_en_demo.json)
In supervised fine-tuning, the `instruction` column will be concatenated with the `input` column and used as the human prompt, then the human prompt would be `instruction\ninput`. The `output` column represents the model response.
The `system` column will be used as the system prompt if specified.
The `history` column is a list consisting of string tuples representing prompt-response pairs in the history messages. Note that the responses in the history **will also be learned by the model** in supervised fine-tuning.
```json ```json
[ [
{ {
"instruction": "user instruction (required)", "instruction": "human instruction (required)",
"input": "user input (optional)", "input": "human input (optional)",
"output": "model response (required)", "output": "model response (required)",
"system": "system prompt (optional)", "system": "system prompt (optional)",
"history": [ "history": [
["user instruction in the first round (optional)", "model response in the first round (optional)"], ["human instruction in the first round (optional)", "model response in the first round (optional)"],
["user instruction in the second round (optional)", "model response in the second round (optional)"] ["human instruction in the second round (optional)", "model response in the second round (optional)"]
] ]
} }
] ]
``` ```
Regarding the above dataset, the description in `dataset_info.json` should be: Regarding the above dataset, the *dataset description* in `dataset_info.json` should be:
```json ```json
"dataset_name": { "dataset_name": {
@@ -69,11 +80,11 @@ Regarding the above dataset, the description in `dataset_info.json` should be:
} }
``` ```
The `query` column will be concatenated with the `prompt` column and used as the user prompt, then the user prompt would be `prompt\nquery`. The `response` column represents the model response. ### Pre-training Dataset
The `system` column will be used as the system prompt. The `history` column is a list consisting string tuples representing prompt-response pairs in the history. Note that the responses in the history **will also be used for training** in supervised fine-tuning. - [Example dataset](c4_demo.json)
For the **pre-training datasets**, only the `prompt` column will be used for training, for example: In pre-training, only the `text` column will be used for model learning.
```json ```json
[ [
@@ -82,7 +93,7 @@ For the **pre-training datasets**, only the `prompt` column will be used for tra
] ]
``` ```
Regarding the above dataset, the description in `dataset_info.json` should be: Regarding the above dataset, the *dataset description* in `dataset_info.json` should be:
```json ```json
"dataset_name": { "dataset_name": {
@@ -93,22 +104,24 @@ Regarding the above dataset, the description in `dataset_info.json` should be:
} }
``` ```
For the **preference datasets**, the `response` column should be a string list whose length is 2, with the preferred answers appearing first, for example: ### Preference Dataset
Preference datasets are used for reward modeling, DPO training and ORPO training.
They require a better response in the `chosen` column and a worse response in the `rejected` column.
```json ```json
[ [
{ {
"instruction": "user instruction", "instruction": "human instruction (required)",
"input": "user input", "input": "human input (optional)",
"output": [ "chosen": "chosen answer (required)",
"chosen answer", "rejected": "rejected answer (required)"
"rejected answer"
]
} }
] ]
``` ```
Regarding the above dataset, the description in `dataset_info.json` should be: Regarding the above dataset, the *dataset description* in `dataset_info.json` should be:
```json ```json
"dataset_name": { "dataset_name": {
@@ -117,14 +130,85 @@ Regarding the above dataset, the description in `dataset_info.json` should be:
"columns": { "columns": {
"prompt": "instruction", "prompt": "instruction",
"query": "input", "query": "input",
"response": "output", "chosen": "chosen",
"rejected": "rejected"
} }
} }
``` ```
---- ### KTO Dataset
The dataset in **sharegpt** format should follow the below format: - [Example dataset](kto_en_demo.json)
KTO datasets require an extra `kto_tag` column containing the boolean human feedback.
```json
[
{
"instruction": "human instruction (required)",
"input": "human input (optional)",
"output": "model response (required)",
"kto_tag": "human feedback [true/false] (required)"
}
]
```
Regarding the above dataset, the *dataset description* in `dataset_info.json` should be:
```json
"dataset_name": {
"file_name": "data.json",
"columns": {
"prompt": "instruction",
"query": "input",
"response": "output",
"kto_tag": "kto_tag"
}
}
```
### Multimodal Dataset
- [Example dataset](mllm_demo.json)
Multimodal datasets require an `images` column containing the paths to the input images. Currently we only support one image.
```json
[
{
"instruction": "human instruction (required)",
"input": "human input (optional)",
"output": "model response (required)",
"images": [
"image path (required)"
]
}
]
```
Regarding the above dataset, the *dataset description* in `dataset_info.json` should be:
```json
"dataset_name": {
"file_name": "data.json",
"columns": {
"prompt": "instruction",
"query": "input",
"response": "output",
"images": "images"
}
}
```
## Sharegpt Format
### Supervised Fine-Tuning Dataset
- [Example dataset](glaive_toolcall_en_demo.json)
Compared to the alpaca format, the sharegpt format allows the datasets to have **more roles**, such as human, gpt, observation and function. They are presented in a list of objects in the `conversations` column.
Note that the human and observation should appear in odd positions, while gpt and function should appear in even positions.
```json ```json
[ [
@@ -132,7 +216,15 @@ The dataset in **sharegpt** format should follow the below format:
"conversations": [ "conversations": [
{ {
"from": "human", "from": "human",
"value": "user instruction" "value": "human instruction"
},
{
"from": "function_call",
"value": "tool arguments"
},
{
"from": "observation",
"value": "tool result"
}, },
{ {
"from": "gpt", "from": "gpt",
@@ -145,7 +237,7 @@ The dataset in **sharegpt** format should follow the below format:
] ]
``` ```
Regarding the above dataset, the description in `dataset_info.json` should be: Regarding the above dataset, the *dataset description* in `dataset_info.json` should be:
```json ```json
"dataset_name": { "dataset_name": {
@@ -155,19 +247,63 @@ Regarding the above dataset, the description in `dataset_info.json` should be:
"messages": "conversations", "messages": "conversations",
"system": "system", "system": "system",
"tools": "tools" "tools": "tools"
},
"tags": {
"role_tag": "from",
"content_tag": "value",
"user_tag": "human",
"assistant_tag": "gpt"
} }
} }
``` ```
where the `messages` column should be a list following the `u/a/u/a/u/a` order. ### Preference Dataset
We also supports the dataset in the **openai** format: - [Example dataset](dpo_en_demo.json)
Preference datasets in sharegpt format also require a better message in the `chosen` column and a worse message in the `rejected` column.
```json
[
{
"conversations": [
{
"from": "human",
"value": "human instruction"
},
{
"from": "gpt",
"value": "model response"
},
{
"from": "human",
"value": "human instruction"
}
],
"chosen": {
"from": "gpt",
"value": "chosen answer (required)"
},
"rejected": {
"from": "gpt",
"value": "rejected answer (required)"
}
}
]
```
Regarding the above dataset, the *dataset description* in `dataset_info.json` should be:
```json
"dataset_name": {
"file_name": "data.json",
"formatting": "sharegpt",
"ranking": true,
"columns": {
"messages": "conversations",
"chosen": "chosen",
"rejected": "rejected"
}
}
```
### OpenAI Format
The openai format is simply a special case of the sharegpt format, where the first message may be a system prompt.
```json ```json
[ [
@@ -179,7 +315,7 @@ We also supports the dataset in the **openai** format:
}, },
{ {
"role": "user", "role": "user",
"content": "user instruction" "content": "human instruction"
}, },
{ {
"role": "assistant", "role": "assistant",
@@ -190,7 +326,7 @@ We also supports the dataset in the **openai** format:
] ]
``` ```
Regarding the above dataset, the description in `dataset_info.json` should be: Regarding the above dataset, the *dataset description* in `dataset_info.json` should be:
```json ```json
"dataset_name": { "dataset_name": {
@@ -209,4 +345,6 @@ Regarding the above dataset, the description in `dataset_info.json` should be:
} }
``` ```
Pre-training datasets and preference datasets are **incompatible** with the sharegpt format yet. The KTO datasets and multimodal datasets in sharegpt format are similar to the alpaca format.
Pre-training datasets are **incompatible** with the sharegpt format.

View File

@@ -1,16 +1,18 @@
如果您使用自定义数据集,请务必按照以下格式`dataset_info.json` 文件中添加**数据集描述**。我们在下面也提供了一些例子 [dataset_info.json](dataset_info.json) 包含了所有可用的数据集。如果您希望使用自定义数据集,请**务必**`dataset_info.json` 文件中添加*数据集描述*,并通过修改 `dataset: 数据集名称` 配置来使用数据集
目前我们支持 **alpaca** 格式和 **sharegpt** 格式的数据集。
```json ```json
"数据集名称": { "数据集名称": {
"hf_hub_url": "Hugging Face 的数据集仓库地址(若指定,则忽略 script_url 和 file_name", "hf_hub_url": "Hugging Face 的数据集仓库地址(若指定,则忽略 script_url 和 file_name",
"ms_hub_url": "ModelScope 的数据集仓库地址(若指定,则忽略 script_url 和 file_name", "ms_hub_url": "ModelScope 的数据集仓库地址(若指定,则忽略 script_url 和 file_name",
"script_url": "包含数据加载脚本的本地文件夹名称(若指定,则忽略 file_name", "script_url": "包含数据加载脚本的本地文件夹名称(若指定,则忽略 file_name",
"file_name": "该目录下数据集文件的名称(若上述参数未指定,则此项必需)", "file_name": "该目录下数据集文件夹或文件的名称(若上述参数未指定,则此项必需)",
"file_sha1": "数据集文件的 SHA-1 哈希值(可选,留空不影响训练", "formatting": "数据集格式可选默认alpaca可以为 alpaca 或 sharegpt",
"ranking": "是否为偏好数据集可选默认False",
"subset": "数据集子集的名称可选默认None", "subset": "数据集子集的名称可选默认None",
"folder": "Hugging Face 仓库的文件夹名称可选默认None", "folder": "Hugging Face 仓库的文件夹名称可选默认None",
"ranking": "是否为偏好数据集(可选,默认:False", "num_samples": "该数据集中用于训练的样本数量。(可选,默认:None",
"formatting": "数据集格式可选默认alpaca可以为 alpaca 或 sharegpt",
"columns可选": { "columns可选": {
"prompt": "数据集代表提示词的表头名称默认instruction", "prompt": "数据集代表提示词的表头名称默认instruction",
"query": "数据集代表请求的表头名称默认input", "query": "数据集代表请求的表头名称默认input",
@@ -19,7 +21,10 @@
"messages": "数据集代表消息列表的表头名称默认conversations", "messages": "数据集代表消息列表的表头名称默认conversations",
"system": "数据集代表系统提示的表头名称默认None", "system": "数据集代表系统提示的表头名称默认None",
"tools": "数据集代表工具描述的表头名称默认None", "tools": "数据集代表工具描述的表头名称默认None",
"images": "数据集代表图像输入的表头名称默认None" "images": "数据集代表图像输入的表头名称默认None",
"chosen": "数据集代表更优回答的表头名称默认None",
"rejected": "数据集代表更差回答的表头名称默认None",
"kto_tag": "数据集代表 KTO 标签的表头名称默认None"
}, },
"tags可选用于 sharegpt 格式)": { "tags可选用于 sharegpt 格式)": {
"role_tag": "消息中代表发送者身份的键名默认from", "role_tag": "消息中代表发送者身份的键名默认from",
@@ -28,22 +33,28 @@
"assistant_tag": "消息中代表助手的 role_tag默认gpt", "assistant_tag": "消息中代表助手的 role_tag默认gpt",
"observation_tag": "消息中代表工具返回结果的 role_tag默认observation", "observation_tag": "消息中代表工具返回结果的 role_tag默认observation",
"function_tag": "消息中代表工具调用的 role_tag默认function_call", "function_tag": "消息中代表工具调用的 role_tag默认function_call",
"system_tag": "消息中代表系统提示的 role_tag默认system会覆盖 system " "system_tag": "消息中代表系统提示的 role_tag默认system会覆盖 system column"
} }
} }
``` ```
然后,可通过使用 `--dataset 数据集名称` 参数加载自定义数据集。 ## Alpaca 格式
---- ### 指令监督微调数据集
该项目目前支持两种格式的数据集:**alpaca** 和 **sharegpt**,其中 alpaca 格式的数据集按照以下方式组织: - [样例数据集](alpaca_zh_demo.json)
在指令监督微调时,`instruction` 列对应的内容会与 `input` 列对应的内容拼接后作为人类指令,即人类指令为 `instruction\ninput`。而 `output` 列对应的内容为模型回答。
如果指定,`system` 列对应的内容将被作为系统提示词。
`history` 列是由多个字符串二元组构成的列表,分别代表历史消息中每轮对话的指令和回答。注意在指令监督微调时,历史消息中的回答内容**也会被用于模型学习**。
```json ```json
[ [
{ {
"instruction": "用户指令(必填)", "instruction": "人类指令(必填)",
"input": "用户输入(选填)", "input": "人类输入(选填)",
"output": "模型回答(必填)", "output": "模型回答(必填)",
"system": "系统提示词(选填)", "system": "系统提示词(选填)",
"history": [ "history": [
@@ -54,7 +65,7 @@
] ]
``` ```
对于上述格式的数据,`dataset_info.json` 中的描述应为: 对于上述格式的数据,`dataset_info.json` 中的*数据集描述*应为:
```json ```json
"数据集名称": { "数据集名称": {
@@ -69,11 +80,11 @@
} }
``` ```
其中 `query` 列对应的内容会与 `prompt` 列对应的内容拼接后作为用户指令,即用户指令为 `prompt\nquery``response` 列对应的内容为模型回答。 ### 预训练数据集
`system` 列对应的内容将被作为系统提示词。`history` 列是由多个字符串二元组构成的列表,分别代表历史消息中每轮的指令和回答。注意在指令监督学习时,历史消息中的回答**也会被用于训练**。 - [样例数据集](c4_demo.json)
对于**预训练数据集**,仅 `prompt` 列中的内容会用于模型训练,例如: 在预训练时,只有 `text` 列中的内容会用于模型学习。
```json ```json
[ [
@@ -82,7 +93,7 @@
] ]
``` ```
对于上述格式的数据,`dataset_info.json` 中的描述应为: 对于上述格式的数据,`dataset_info.json` 中的*数据集描述*应为:
```json ```json
"数据集名称": { "数据集名称": {
@@ -93,22 +104,24 @@
} }
``` ```
对于**偏好数据集**`response` 列应当是一个长度为 2 的字符串列表,排在前面的代表更优的回答,例如: ### 偏好数据集
偏好数据集用于奖励模型训练、DPO 训练和 ORPO 训练。
它需要在 `chosen` 列中提供更优的回答,并在 `rejected` 列中提供更差的回答。
```json ```json
[ [
{ {
"instruction": "用户指令", "instruction": "人类指令(必填)",
"input": "用户输入", "input": "人类输入(选填)",
"output": [ "chosen": "优质回答(必填)",
"质回答", "rejected": "质回答(必填)"
"劣质回答"
]
} }
] ]
``` ```
对于上述格式的数据,`dataset_info.json` 中的描述应为: 对于上述格式的数据,`dataset_info.json` 中的*数据集描述*应为:
```json ```json
"数据集名称": { "数据集名称": {
@@ -117,14 +130,85 @@
"columns": { "columns": {
"prompt": "instruction", "prompt": "instruction",
"query": "input", "query": "input",
"response": "output", "chosen": "chosen",
"rejected": "rejected"
} }
} }
``` ```
---- ### KTO 数据集
**sharegpt** 格式的数据集按照以下方式组织: - [样例数据集](kto_en_demo.json)
KTO 数据集需要额外添加一个 `kto_tag` 列,包含 bool 类型的人类反馈。
```json
[
{
"instruction": "人类指令(必填)",
"input": "人类输入(选填)",
"output": "模型回答(必填)",
"kto_tag": "人类反馈 [true/false](必填)"
}
]
```
对于上述格式的数据,`dataset_info.json` 中的*数据集描述*应为:
```json
"数据集名称": {
"file_name": "data.json",
"columns": {
"prompt": "instruction",
"query": "input",
"response": "output",
"kto_tag": "kto_tag"
}
}
```
### 多模态数据集
- [样例数据集](mllm_demo.json)
多模态数据集需要额外添加一个 `images` 列,包含输入图像的路径。目前我们仅支持单张图像输入。
```json
[
{
"instruction": "人类指令(必填)",
"input": "人类输入(选填)",
"output": "模型回答(必填)",
"images": [
"图像路径(必填)"
]
}
]
```
对于上述格式的数据,`dataset_info.json` 中的*数据集描述*应为:
```json
"数据集名称": {
"file_name": "data.json",
"columns": {
"prompt": "instruction",
"query": "input",
"response": "output",
"images": "images"
}
}
```
## Sharegpt 格式
### 指令监督微调数据集
- [样例数据集](glaive_toolcall_zh_demo.json)
相比 alpaca 格式的数据集sharegpt 格式支持**更多的角色种类**,例如 human、gpt、observation、function 等等。它们构成一个对象列表呈现在 `conversations` 列中。
注意其中 human 和 observation 必须出现在奇数位置gpt 和 function 必须出现在偶数位置。
```json ```json
[ [
@@ -132,7 +216,15 @@
"conversations": [ "conversations": [
{ {
"from": "human", "from": "human",
"value": "用户指令" "value": "人类指令"
},
{
"from": "function_call",
"value": "工具参数"
},
{
"from": "observation",
"value": "工具结果"
}, },
{ {
"from": "gpt", "from": "gpt",
@@ -145,7 +237,7 @@
] ]
``` ```
对于上述格式的数据,`dataset_info.json` 中的描述应为: 对于上述格式的数据,`dataset_info.json` 中的*数据集描述*应为:
```json ```json
"数据集名称": { "数据集名称": {
@@ -155,19 +247,63 @@
"messages": "conversations", "messages": "conversations",
"system": "system", "system": "system",
"tools": "tools" "tools": "tools"
},
"tags": {
"role_tag": "from",
"content_tag": "value",
"user_tag": "human",
"assistant_tag": "gpt"
} }
} }
``` ```
其中 `messages` 列应当是一个列表,且符合 `用户/模型/用户/模型/用户/模型` 的顺序。 ### 偏好数据集
我们同样支持 **openai** 格式的数据集: - [样例数据集](dpo_zh_demo.json)
Sharegpt 格式的偏好数据集同样需要在 `chosen` 列中提供更优的消息,并在 `rejected` 列中提供更差的消息。
```json
[
{
"conversations": [
{
"from": "human",
"value": "人类指令"
},
{
"from": "gpt",
"value": "模型回答"
},
{
"from": "human",
"value": "人类指令"
}
],
"chosen": {
"from": "gpt",
"value": "优质回答"
},
"rejected": {
"from": "gpt",
"value": "劣质回答"
}
}
]
```
对于上述格式的数据,`dataset_info.json` 中的*数据集描述*应为:
```json
"数据集名称": {
"file_name": "data.json",
"formatting": "sharegpt",
"ranking": true,
"columns": {
"messages": "conversations",
"chosen": "chosen",
"rejected": "rejected"
}
}
```
### OpenAI 格式
OpenAI 格式仅仅是 sharegpt 格式的一种特殊情况,其中第一条消息可能是系统提示词。
```json ```json
[ [
@@ -179,7 +315,7 @@
}, },
{ {
"role": "user", "role": "user",
"content": "用户指令" "content": "人类指令"
}, },
{ {
"role": "assistant", "role": "assistant",
@@ -190,7 +326,7 @@
] ]
``` ```
对于上述格式的数据,`dataset_info.json` 中的描述应为: 对于上述格式的数据,`dataset_info.json` 中的*数据集描述*应为:
```json ```json
"数据集名称": { "数据集名称": {
@@ -209,4 +345,6 @@
} }
``` ```
预训练数据集和偏好数据集**尚不支持** sharegpt 格式 Sharegpt 格式中的 KTO 数据集和多模态数据集与 alpaca 格式的类似
预训练数据集**不支持** sharegpt 格式。

View File

@@ -1 +0,0 @@
3779ddbc040543ab1834ef216c983d6fcc06cc9a

View File

@@ -1 +0,0 @@
a97cf9475291591843976554878568e046d8a46d

View File

@@ -1 +0,0 @@
25508714b7879a1e5a6764ba7f979a980f549f1a

View File

@@ -1 +0,0 @@
7cb6a7d11455bddc3d495750a2392683d775b184

View File

@@ -1 +0,0 @@
f5cb08305ff5dc9c17a09809c54c8c8834aadc70

View File

@@ -1 +0,0 @@
aee47b7b443496e37808d7f34ef10403ff99bcc3

View File

@@ -1,37 +0,0 @@
import json
from typing import Any, Dict, Generator, List, Tuple
import datasets
# Metadata strings surfaced through datasets.DatasetInfo in ExampleDataset._info.
_DESCRIPTION = "An example of dataset."
_CITATION = ""
_HOMEPAGE = ""
_LICENSE = ""
# Path of the JSON file (relative to this script) holding the examples.
_URL = "examples.json"
class ExampleDataset(datasets.GeneratorBasedBuilder):
    """Example Hugging Face dataset builder.

    Loads instruction-tuning examples from the local JSON file named by
    ``_URL`` and exposes them as a single train split with
    instruction/input/output columns plus a ``history`` column of
    string-pair lists.
    """

    VERSION = datasets.Version("0.0.0")

    def _info(self) -> datasets.DatasetInfo:
        """Return the dataset metadata and column schema."""
        features = datasets.Features(
            {
                "instruction": datasets.Value("string"),
                "input": datasets.Value("string"),
                "output": datasets.Value("string"),
                "history": datasets.Sequence(datasets.Sequence(datasets.Value("string"))),
            }
        )
        return datasets.DatasetInfo(
            description=_DESCRIPTION, features=features, homepage=_HOMEPAGE, license=_LICENSE, citation=_CITATION
        )

    def _split_generators(self, dl_manager: datasets.DownloadManager) -> List[datasets.SplitGenerator]:
        """Resolve the JSON file path and declare it as the train split."""
        file_path = dl_manager.download(_URL)
        return [datasets.SplitGenerator(name=datasets.Split.TRAIN, gen_kwargs={"filepath": file_path})]

    def _generate_examples(self, filepath: str) -> Generator[Tuple[int, Dict[str, Any]], None, None]:
        """Yield ``(key, example)`` pairs from the JSON file at *filepath*.

        Fix: the original used ``json.load(open(...))``, which leaks the file
        handle; a ``with`` block closes the file deterministically.
        """
        with open(filepath, "r", encoding="utf-8") as f:
            example_dataset = json.load(f)
        for key, example in enumerate(example_dataset):
            yield key, example

View File

@@ -1 +0,0 @@
4748dff00d1dc42768a5b6cc772143c313017812

View File

@@ -34,7 +34,8 @@ class HhRlhfEn(datasets.GeneratorBasedBuilder):
features = datasets.Features( features = datasets.Features(
{ {
"instruction": datasets.Value("string"), "instruction": datasets.Value("string"),
"output": datasets.Sequence(datasets.Value("string")), "chosen": datasets.Value("string"),
"rejected": datasets.Value("string"),
"history": datasets.Sequence(datasets.Sequence(datasets.Value("string"))), "history": datasets.Sequence(datasets.Sequence(datasets.Value("string"))),
} }
) )
@@ -79,5 +80,5 @@ class HhRlhfEn(datasets.GeneratorBasedBuilder):
break break
prompt = prompt[:human_idx] prompt = prompt[:human_idx]
yield key, {"instruction": query, "output": [r_accept, r_reject], "history": history} yield key, {"instruction": query, "chosen": r_accept, "rejected": r_reject, "history": history}
key += 1 key += 1

View File

@@ -1 +0,0 @@
736bcedea2b24a1414765c6d69cbdafaea839f3c

30
data/wiki_demo.txt Normal file

File diff suppressed because one or more lines are too long

View File

@@ -1 +0,0 @@
c9cf509b7fdac5490cfd6dae72c2d7b8a60af6cb

View File

@@ -10,8 +10,6 @@ services:
- ./hf_cache:/root/.cache/huggingface/ - ./hf_cache:/root/.cache/huggingface/
- ./data:/app/data - ./data:/app/data
- ./output:/app/output - ./output:/app/output
environment:
- CUDA_VISIBLE_DEVICES=0
ports: ports:
- "7860:7860" - "7860:7860"
ipc: host ipc: host

View File

@@ -154,7 +154,7 @@ class MMLU(datasets.GeneratorBasedBuilder):
] ]
def _generate_examples(self, filepath): def _generate_examples(self, filepath):
df = pd.read_csv(filepath) df = pd.read_csv(filepath, header=None)
df.columns = ["question", "A", "B", "C", "D", "answer"] df.columns = ["question", "A", "B", "C", "D", "answer"]
for i, instance in enumerate(df.to_dict(orient="records")): for i, instance in enumerate(df.to_dict(orient="records")):

View File

@@ -47,16 +47,16 @@ CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/lora_single_gpu/llama3_lo
CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/lora_single_gpu/llama3_lora_ppo.yaml CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/lora_single_gpu/llama3_lora_ppo.yaml
``` ```
#### DPO Training #### DPO/ORPO/SimPO Training
```bash ```bash
CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/lora_single_gpu/llama3_lora_dpo.yaml CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/lora_single_gpu/llama3_lora_dpo.yaml
``` ```
#### ORPO Training #### KTO Training
```bash ```bash
CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/lora_single_gpu/llama3_lora_orpo.yaml CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/lora_single_gpu/llama3_lora_kto.yaml
``` ```
#### Preprocess Dataset #### Preprocess Dataset
@@ -107,22 +107,23 @@ CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/qlora_single_gpu/llama3_l
### LoRA Fine-Tuning on Multiple GPUs ### LoRA Fine-Tuning on Multiple GPUs
#### Supervised Fine-Tuning with Accelerate on Single Node #### Supervised Fine-Tuning on Single Node
```bash ```bash
bash examples/lora_multi_gpu/single_node.sh CUDA_VISIBLE_DEVICES=0,1,2,3 llamafactory-cli train examples/lora_multi_gpu/llama3_lora_sft.yaml
``` ```
#### Supervised Fine-Tuning with Accelerate on Multiple Nodes #### Supervised Fine-Tuning on Multiple Nodes
```bash ```bash
bash examples/lora_multi_gpu/multi_node.sh CUDA_VISIBLE_DEVICES=0,1,2,3 NNODES=2 RANK=0 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 llamafactory-cli train examples/lora_multi_gpu/llama3_lora_sft.yaml
CUDA_VISIBLE_DEVICES=0,1,2,3 NNODES=2 RANK=1 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 llamafactory-cli train examples/lora_multi_gpu/llama3_lora_sft.yaml
``` ```
#### Supervised Fine-Tuning with DeepSpeed ZeRO-3 (Weight Sharding) #### Supervised Fine-Tuning with DeepSpeed ZeRO-3 (Weight Sharding)
```bash ```bash
bash examples/lora_multi_gpu/ds_zero3.sh CUDA_VISIBLE_DEVICES=0,1,2,3 llamafactory-cli train examples/lora_multi_gpu/llama3_lora_sft_ds.yaml
``` ```
### LoRA Fine-Tuning on Multiple NPUs ### LoRA Fine-Tuning on Multiple NPUs
@@ -130,27 +131,28 @@ bash examples/lora_multi_gpu/ds_zero3.sh
#### Supervised Fine-Tuning with DeepSpeed ZeRO-0 #### Supervised Fine-Tuning with DeepSpeed ZeRO-0
```bash ```bash
bash examples/lora_multi_npu/ds_zero0.sh ASCEND_RT_VISIBLE_DEVICES=0,1,2,3 llamafactory-cli train examples/lora_multi_npu/llama3_lora_sft_ds.yaml
``` ```
### Full-Parameter Fine-Tuning on Multiple GPUs ### Full-Parameter Fine-Tuning on Multiple GPUs
#### Supervised Fine-Tuning with Accelerate on Single Node #### Supervised Fine-Tuning on Single Node
```bash ```bash
bash examples/full_multi_gpu/single_node.sh CUDA_VISIBLE_DEVICES=0,1,2,3 llamafactory-cli train examples/full_multi_gpu/llama3_full_sft.yaml
``` ```
#### Supervised Fine-Tuning with Accelerate on Multiple Nodes #### Supervised Fine-Tuning on Multiple Nodes
```bash ```bash
bash examples/full_multi_gpu/multi_node.sh CUDA_VISIBLE_DEVICES=0,1,2,3 NNODES=2 RANK=0 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 llamafactory-cli train examples/full_multi_gpu/llama3_full_sft.yaml
CUDA_VISIBLE_DEVICES=0,1,2,3 NNODES=2 RANK=1 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 llamafactory-cli train examples/full_multi_gpu/llama3_full_sft.yaml
``` ```
#### Batch Predicting and Computing BLEU and ROUGE Scores #### Batch Predicting and Computing BLEU and ROUGE Scores
```bash ```bash
bash examples/full_multi_gpu/predict.sh CUDA_VISIBLE_DEVICES=0,1,2,3 llamafactory-cli train examples/full_multi_gpu/llama3_full_predict.yaml
``` ```
### Merging LoRA Adapters and Quantization ### Merging LoRA Adapters and Quantization
@@ -171,22 +173,24 @@ CUDA_VISIBLE_DEVICES=0 llamafactory-cli export examples/merge_lora/llama3_gptq.y
### Inferring LoRA Fine-Tuned Models ### Inferring LoRA Fine-Tuned Models
Use `CUDA_VISIBLE_DEVICES=0,1` to infer models on multiple devices.
#### Use CLI #### Use CLI
```bash ```bash
CUDA_VISIBLE_DEVICES=0 llamafactory-cli chat examples/merge_lora/llama3_lora_sft.yaml CUDA_VISIBLE_DEVICES=0 llamafactory-cli chat examples/inference/llama3_lora_sft.yaml
``` ```
#### Use Web UI #### Use Web UI
```bash ```bash
CUDA_VISIBLE_DEVICES=0 llamafactory-cli webchat examples/merge_lora/llama3_lora_sft.yaml CUDA_VISIBLE_DEVICES=0 llamafactory-cli webchat examples/inference/llama3_lora_sft.yaml
``` ```
#### Launch OpenAI-style API #### Launch OpenAI-style API
```bash ```bash
CUDA_VISIBLE_DEVICES=0 llamafactory-cli api examples/merge_lora/llama3_lora_sft.yaml CUDA_VISIBLE_DEVICES=0 llamafactory-cli api examples/inference/llama3_lora_sft.yaml
``` ```
### Extras ### Extras

View File

@@ -47,16 +47,16 @@ CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/lora_single_gpu/llama3_lo
CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/lora_single_gpu/llama3_lora_ppo.yaml CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/lora_single_gpu/llama3_lora_ppo.yaml
``` ```
#### DPO 训练 #### DPO/ORPO/SimPO 训练
```bash ```bash
CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/lora_single_gpu/llama3_lora_dpo.yaml CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/lora_single_gpu/llama3_lora_dpo.yaml
``` ```
#### ORPO 训练 #### KTO 训练
```bash ```bash
CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/lora_single_gpu/llama3_lora_orpo.yaml CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/lora_single_gpu/llama3_lora_kto.yaml
``` ```
#### 预处理数据集 #### 预处理数据集
@@ -107,50 +107,52 @@ CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/qlora_single_gpu/llama3_l
### 多 GPU LoRA 微调 ### 多 GPU LoRA 微调
#### 使用 Accelerate 进行单节点训练 #### 在单机上进行指令监督微调
```bash ```bash
bash examples/lora_multi_gpu/single_node.sh CUDA_VISIBLE_DEVICES=0,1,2,3 llamafactory-cli train examples/lora_multi_gpu/llama3_lora_sft.yaml
``` ```
#### 使用 Accelerate 进行多节点训练 #### 在多机上进行指令监督微调
```bash ```bash
bash examples/lora_multi_gpu/multi_node.sh CUDA_VISIBLE_DEVICES=0,1,2,3 NNODES=2 RANK=0 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 llamafactory-cli train examples/lora_multi_gpu/llama3_lora_sft.yaml
CUDA_VISIBLE_DEVICES=0,1,2,3 NNODES=2 RANK=1 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 llamafactory-cli train examples/lora_multi_gpu/llama3_lora_sft.yaml
``` ```
#### 使用 DeepSpeed ZeRO-3 平均分配显存 #### 使用 DeepSpeed ZeRO-3 平均分配显存
```bash ```bash
bash examples/lora_multi_gpu/ds_zero3.sh CUDA_VISIBLE_DEVICES=0,1,2,3 llamafactory-cli train examples/lora_multi_gpu/llama3_lora_sft_ds.yaml
``` ```
### 多 NPU LoRA 微调 ### 多 NPU LoRA 微调
#### 使用 DeepSpeed ZeRO-0 训练 #### 使用 DeepSpeed ZeRO-0 进行指令监督微调
```bash ```bash
bash examples/lora_multi_npu/ds_zero0.sh ASCEND_RT_VISIBLE_DEVICES=0,1,2,3 llamafactory-cli train examples/lora_multi_npu/llama3_lora_sft_ds.yaml
``` ```
### 多 GPU 全参数微调 ### 多 GPU 全参数微调
#### 使用 DeepSpeed 进行单节点训练 #### 在单机上进行指令监督微调
```bash ```bash
bash examples/full_multi_gpu/single_node.sh CUDA_VISIBLE_DEVICES=0,1,2,3 llamafactory-cli train examples/full_multi_gpu/llama3_full_sft.yaml
``` ```
#### 使用 DeepSpeed 进行多节点训练 #### 在多机上进行指令监督微调
```bash ```bash
bash examples/full_multi_gpu/multi_node.sh CUDA_VISIBLE_DEVICES=0,1,2,3 NNODES=2 RANK=0 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 llamafactory-cli train examples/full_multi_gpu/llama3_full_sft.yaml
CUDA_VISIBLE_DEVICES=0,1,2,3 NNODES=2 RANK=1 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 llamafactory-cli train examples/full_multi_gpu/llama3_full_sft.yaml
``` ```
#### 批量预测并计算 BLEU 和 ROUGE 分数 #### 批量预测并计算 BLEU 和 ROUGE 分数
```bash ```bash
bash examples/full_multi_gpu/predict.sh CUDA_VISIBLE_DEVICES=0,1,2,3 llamafactory-cli train examples/full_multi_gpu/llama3_full_predict.yaml
``` ```
### 合并 LoRA 适配器与模型量化 ### 合并 LoRA 适配器与模型量化
@@ -171,22 +173,24 @@ CUDA_VISIBLE_DEVICES=0 llamafactory-cli export examples/merge_lora/llama3_gptq.y
### 推理 LoRA 模型 ### 推理 LoRA 模型
使用 `CUDA_VISIBLE_DEVICES=0,1` 进行多卡推理。
#### 使用命令行接口 #### 使用命令行接口
```bash ```bash
CUDA_VISIBLE_DEVICES=0 llamafactory-cli chat examples/merge_lora/llama3_lora_sft.yaml CUDA_VISIBLE_DEVICES=0 llamafactory-cli chat examples/inference/llama3_lora_sft.yaml
``` ```
#### 使用浏览器界面 #### 使用浏览器界面
```bash ```bash
CUDA_VISIBLE_DEVICES=0 llamafactory-cli webchat examples/merge_lora/llama3_lora_sft.yaml CUDA_VISIBLE_DEVICES=0 llamafactory-cli webchat examples/inference/llama3_lora_sft.yaml
``` ```
#### 启动 OpenAI 风格 API #### 启动 OpenAI 风格 API
```bash ```bash
CUDA_VISIBLE_DEVICES=0 llamafactory-cli api examples/merge_lora/llama3_lora_sft.yaml CUDA_VISIBLE_DEVICES=0 llamafactory-cli api examples/inference/llama3_lora_sft.yaml
``` ```
### 杂项 ### 杂项

View File

@@ -5,16 +5,16 @@ downcast_bf16: 'no'
fsdp_config: fsdp_config:
fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
fsdp_backward_prefetch: BACKWARD_PRE fsdp_backward_prefetch: BACKWARD_PRE
fsdp_cpu_ram_efficient_loading: true
fsdp_forward_prefetch: false fsdp_forward_prefetch: false
fsdp_offload_params: true fsdp_cpu_ram_efficient_loading: true
fsdp_offload_params: true # offload may affect training speed
fsdp_sharding_strategy: FULL_SHARD fsdp_sharding_strategy: FULL_SHARD
fsdp_state_dict_type: FULL_STATE_DICT fsdp_state_dict_type: FULL_STATE_DICT
fsdp_sync_module_states: true fsdp_sync_module_states: true
fsdp_use_orig_params: false fsdp_use_orig_params: true
machine_rank: 0 machine_rank: 0
main_training_function: main main_training_function: main
mixed_precision: fp16 mixed_precision: fp16 # or bf16
num_machines: 1 # the number of nodes num_machines: 1 # the number of nodes
num_processes: 2 # the number of GPUs in all nodes num_processes: 2 # the number of GPUs in all nodes
rdzv_backend: static rdzv_backend: static

View File

@@ -1,18 +0,0 @@
compute_environment: LOCAL_MACHINE
debug: false
distributed_type: MULTI_GPU
downcast_bf16: 'no'
gpu_ids: all
machine_rank: 0
main_process_ip: 192.168.0.1
main_process_port: 29555
main_training_function: main
mixed_precision: fp16
num_machines: 2 # the number of nodes
num_processes: 8 # the number of GPUs in all nodes
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false

View File

@@ -1,16 +0,0 @@
compute_environment: LOCAL_MACHINE
debug: false
distributed_type: MULTI_GPU
downcast_bf16: 'no'
gpu_ids: all
machine_rank: 0
main_training_function: main
mixed_precision: fp16
num_machines: 1 # the number of nodes
num_processes: 4 # the number of GPUs in all nodes
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false

View File

@@ -1,18 +0,0 @@
compute_environment: LOCAL_MACHINE
debug: false
distributed_type: MULTI_GPU
downcast_bf16: 'no'
gpu_ids: all
machine_rank: 1
main_process_ip: 192.168.0.1
main_process_port: 29555
main_training_function: main
mixed_precision: fp16
num_machines: 2 # the number of nodes
num_processes: 8 # the number of GPUs in all nodes
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false

View File

@@ -1,41 +1,41 @@
# model ### model
model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct
# method ### method
stage: sft stage: sft
do_train: true do_train: true
finetuning_type: full finetuning_type: full
use_badam: true use_badam: true
badam_switch_mode: descending badam_switch_mode: ascending
badam_switch_interval: 50 badam_switch_interval: 50
badam_verbose: 2 badam_verbose: 2
# dataset ### dataset
dataset: identity,alpaca_gpt4_en dataset: identity,alpaca_en_demo
template: llama3 template: llama3
cutoff_len: 1024 cutoff_len: 1024
max_samples: 1000 max_samples: 1000
overwrite_cache: true overwrite_cache: true
preprocessing_num_workers: 16 preprocessing_num_workers: 16
# output ### output
output_dir: saves/llama3-8b/full/sft output_dir: saves/llama3-8b/full/sft
logging_steps: 10 logging_steps: 10
save_steps: 500 save_steps: 500
plot_loss: true plot_loss: true
overwrite_output_dir: true overwrite_output_dir: true
# train ### train
per_device_train_batch_size: 1 per_device_train_batch_size: 1
gradient_accumulation_steps: 8 gradient_accumulation_steps: 8
learning_rate: 0.0001 learning_rate: 1.0e-4
num_train_epochs: 3.0 num_train_epochs: 3.0
lr_scheduler_type: cosine lr_scheduler_type: cosine
warmup_steps: 0.1 warmup_ratio: 0.1
pure_bf16: true pure_bf16: true
# eval ### eval
val_size: 0.1 val_size: 0.1
per_device_eval_batch_size: 1 per_device_eval_batch_size: 1
evaluation_strategy: steps eval_strategy: steps
eval_steps: 500 eval_steps: 500

View File

@@ -1,42 +1,42 @@
# model ### model
model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct
quantization_bit: 4 quantization_bit: 4
# method ### method
stage: sft stage: sft
do_train: true do_train: true
finetuning_type: lora finetuning_type: lora
lora_target: q_proj,v_proj lora_target: all
# ddp ### ddp
ddp_timeout: 180000000 ddp_timeout: 180000000
# dataset ### dataset
dataset: identity,alpaca_gpt4_en dataset: identity,alpaca_en_demo
template: llama3 template: llama3
cutoff_len: 1024 cutoff_len: 1024
max_samples: 1000 max_samples: 1000
overwrite_cache: true overwrite_cache: true
preprocessing_num_workers: 16 preprocessing_num_workers: 16
# output ### output
output_dir: saves/llama3-8b/lora/sft output_dir: saves/llama3-8b/lora/sft
logging_steps: 10 logging_steps: 10
save_steps: 500 save_steps: 500
plot_loss: true plot_loss: true
overwrite_output_dir: true overwrite_output_dir: true
# train ### train
per_device_train_batch_size: 1 per_device_train_batch_size: 1
gradient_accumulation_steps: 8 gradient_accumulation_steps: 8
learning_rate: 0.0001 learning_rate: 1.0e-4
num_train_epochs: 3.0 num_train_epochs: 3.0
lr_scheduler_type: cosine lr_scheduler_type: cosine
warmup_steps: 0.1 warmup_ratio: 0.1
fp16: true fp16: true
# eval ### eval
val_size: 0.1 val_size: 0.1
per_device_eval_batch_size: 1 per_device_eval_batch_size: 1
evaluation_strategy: steps eval_strategy: steps
eval_steps: 500 eval_steps: 500

View File

@@ -1,10 +1,6 @@
#!/bin/bash #!/bin/bash
# DO NOT use GPTQ/AWQ model in FSDP+QLoRA # DO NOT use GPTQ/AWQ model in FSDP+QLoRA
pip install "transformers>=4.39.1"
pip install "accelerate>=0.28.0"
pip install "bitsandbytes>=0.43.0"
CUDA_VISIBLE_DEVICES=0,1 accelerate launch \ CUDA_VISIBLE_DEVICES=0,1 accelerate launch \
--config_file examples/accelerate/fsdp_config.yaml \ --config_file examples/accelerate/fsdp_config.yaml \
src/train.py examples/extras/fsdp_qlora/llama3_lora_sft.yaml src/train.py examples/extras/fsdp_qlora/llama3_lora_sft.yaml

View File

@@ -1,7 +1,7 @@
# model ### model
model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct
# method ### method
stage: sft stage: sft
do_train: true do_train: true
finetuning_type: full finetuning_type: full
@@ -11,32 +11,32 @@ galore_target: mlp,self_attn
galore_rank: 128 galore_rank: 128
galore_scale: 2.0 galore_scale: 2.0
# dataset ### dataset
dataset: identity,alpaca_gpt4_en dataset: identity,alpaca_en_demo
template: llama3 template: llama3
cutoff_len: 1024 cutoff_len: 1024
max_samples: 1000 max_samples: 1000
overwrite_cache: true overwrite_cache: true
preprocessing_num_workers: 16 preprocessing_num_workers: 16
# output ### output
output_dir: saves/llama3-8b/full/sft output_dir: saves/llama3-8b/full/sft
logging_steps: 10 logging_steps: 10
save_steps: 500 save_steps: 500
plot_loss: true plot_loss: true
overwrite_output_dir: true overwrite_output_dir: true
# train ### train
per_device_train_batch_size: 1 per_device_train_batch_size: 1
gradient_accumulation_steps: 1 gradient_accumulation_steps: 1
learning_rate: 0.0001 learning_rate: 1.0e-4
num_train_epochs: 3.0 num_train_epochs: 3.0
lr_scheduler_type: cosine lr_scheduler_type: cosine
warmup_steps: 0.1 warmup_ratio: 0.1
pure_bf16: true pure_bf16: true
# eval ### eval
val_size: 0.1 val_size: 0.1
per_device_eval_batch_size: 1 per_device_eval_batch_size: 1
evaluation_strategy: steps eval_strategy: steps
eval_steps: 500 eval_steps: 500

View File

@@ -1,7 +1,7 @@
# model ### model
model_name_or_path: models/llama3-8b-instruct-pro model_name_or_path: models/llama3-8b-instruct-pro
# method ### method
stage: sft stage: sft
do_train: true do_train: true
finetuning_type: freeze finetuning_type: freeze
@@ -9,32 +9,32 @@ freeze_trainable_layers: 8
freeze_trainable_modules: all freeze_trainable_modules: all
use_llama_pro: true use_llama_pro: true
# dataset ### dataset
dataset: identity,alpaca_gpt4_en dataset: identity,alpaca_en_demo
template: llama3 template: llama3
cutoff_len: 1024 cutoff_len: 1024
max_samples: 1000 max_samples: 1000
overwrite_cache: true overwrite_cache: true
preprocessing_num_workers: 16 preprocessing_num_workers: 16
# output ### output
output_dir: saves/llama3-8b-instruct-pro/freeze/sft output_dir: saves/llama3-8b-instruct-pro/freeze/sft
logging_steps: 10 logging_steps: 10
save_steps: 500 save_steps: 500
plot_loss: true plot_loss: true
overwrite_output_dir: true overwrite_output_dir: true
# train ### train
per_device_train_batch_size: 1 per_device_train_batch_size: 1
gradient_accumulation_steps: 8 gradient_accumulation_steps: 8
learning_rate: 0.0001 learning_rate: 1.0e-4
num_train_epochs: 3.0 num_train_epochs: 3.0
lr_scheduler_type: cosine lr_scheduler_type: cosine
warmup_steps: 0.1 warmup_ratio: 0.1
fp16: true fp16: true
# eval ### eval
val_size: 0.1 val_size: 0.1
per_device_eval_batch_size: 1 per_device_eval_batch_size: 1
evaluation_strategy: steps eval_strategy: steps
eval_steps: 500 eval_steps: 500

View File

@@ -1,39 +1,39 @@
# model ### model
model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct
# method ### method
stage: sft stage: sft
do_train: true do_train: true
finetuning_type: lora finetuning_type: lora
lora_target: q_proj,v_proj lora_target: all
loraplus_lr_ratio: 16.0 loraplus_lr_ratio: 16.0
# dataset ### dataset
dataset: identity,alpaca_gpt4_en dataset: identity,alpaca_en_demo
template: llama3 template: llama3
cutoff_len: 1024 cutoff_len: 1024
max_samples: 1000 max_samples: 1000
overwrite_cache: true overwrite_cache: true
preprocessing_num_workers: 16 preprocessing_num_workers: 16
# output ### output
output_dir: saves/llama3-8b/lora/sft output_dir: saves/llama3-8b/lora/sft
logging_steps: 10 logging_steps: 10
save_steps: 500 save_steps: 500
plot_loss: true plot_loss: true
overwrite_output_dir: true overwrite_output_dir: true
# train ### train
per_device_train_batch_size: 1 per_device_train_batch_size: 1
gradient_accumulation_steps: 8 gradient_accumulation_steps: 8
learning_rate: 0.0001 learning_rate: 1.0e-4
num_train_epochs: 3.0 num_train_epochs: 3.0
lr_scheduler_type: cosine lr_scheduler_type: cosine
warmup_steps: 0.1 warmup_ratio: 0.1
fp16: true fp16: true
# eval ### eval
val_size: 0.1 val_size: 0.1
per_device_eval_batch_size: 1 per_device_eval_batch_size: 1
evaluation_strategy: steps eval_strategy: steps
eval_steps: 500 eval_steps: 500

View File

@@ -1,39 +1,39 @@
# model ### model
model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct
# method ### method
stage: sft stage: sft
do_train: true do_train: true
finetuning_type: full finetuning_type: full
mixture_of_depths: convert mixture_of_depths: convert
# dataset ### dataset
dataset: identity,alpaca_gpt4_en dataset: identity,alpaca_en_demo
template: llama3 template: llama3
cutoff_len: 1024 cutoff_len: 1024
max_samples: 1000 max_samples: 1000
overwrite_cache: true overwrite_cache: true
preprocessing_num_workers: 16 preprocessing_num_workers: 16
# output ### output
output_dir: saves/llama3-8b-mod/full/sft output_dir: saves/llama3-8b-mod/full/sft
logging_steps: 10 logging_steps: 10
save_steps: 500 save_steps: 500
plot_loss: true plot_loss: true
overwrite_output_dir: true overwrite_output_dir: true
# train ### train
per_device_train_batch_size: 1 per_device_train_batch_size: 1
gradient_accumulation_steps: 8 gradient_accumulation_steps: 8
optim: paged_adamw_8bit optim: paged_adamw_8bit
learning_rate: 0.0001 learning_rate: 1.0e-4
num_train_epochs: 3.0 num_train_epochs: 3.0
lr_scheduler_type: cosine lr_scheduler_type: cosine
warmup_steps: 0.1 warmup_ratio: 0.1
pure_bf16: true pure_bf16: true
# eval ### eval
val_size: 0.1 val_size: 0.1
per_device_eval_batch_size: 1 per_device_eval_batch_size: 1
evaluation_strategy: steps eval_strategy: steps
eval_steps: 500 eval_steps: 500

View File

@@ -1,23 +1,23 @@
# model ### model
model_name_or_path: saves/llama3-8b/full/sft model_name_or_path: saves/llama3-8b/full/sft
# method ### method
stage: sft stage: sft
do_predict: true do_predict: true
finetuning_type: full finetuning_type: full
# dataset ### dataset
dataset: identity,alpaca_gpt4_en dataset: identity,alpaca_en_demo
template: llama3 template: llama3
cutoff_len: 1024 cutoff_len: 1024
max_samples: 50 max_samples: 50
overwrite_cache: true overwrite_cache: true
preprocessing_num_workers: 16 preprocessing_num_workers: 16
# output ### output
output_dir: saves/llama3-8b/full/predict output_dir: saves/llama3-8b/full/predict
overwrite_output_dir: true overwrite_output_dir: true
# eval ### eval
per_device_eval_batch_size: 1 per_device_eval_batch_size: 1
predict_with_generate: true predict_with_generate: true

View File

@@ -1,41 +1,41 @@
# model ### model
model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct
# method ### method
stage: sft stage: sft
do_train: true do_train: true
finetuning_type: full finetuning_type: full
# ddp ### ddp
ddp_timeout: 180000000 ddp_timeout: 180000000
deepspeed: examples/deepspeed/ds_z3_config.json deepspeed: examples/deepspeed/ds_z3_config.json
# dataset ### dataset
dataset: identity,alpaca_gpt4_en dataset: identity,alpaca_en_demo
template: llama3 template: llama3
cutoff_len: 1024 cutoff_len: 1024
max_samples: 1000 max_samples: 1000
overwrite_cache: true overwrite_cache: true
preprocessing_num_workers: 16 preprocessing_num_workers: 16
# output ### output
output_dir: saves/llama3-8b/full/sft output_dir: saves/llama3-8b/full/sft
logging_steps: 10 logging_steps: 10
save_steps: 500 save_steps: 500
plot_loss: true plot_loss: true
overwrite_output_dir: true overwrite_output_dir: true
# train ### train
per_device_train_batch_size: 1 per_device_train_batch_size: 1
gradient_accumulation_steps: 2 gradient_accumulation_steps: 2
learning_rate: 0.0001 learning_rate: 1.0e-4
num_train_epochs: 3.0 num_train_epochs: 3.0
lr_scheduler_type: cosine lr_scheduler_type: cosine
warmup_steps: 0.1 warmup_ratio: 0.1
fp16: true fp16: true
# eval ### eval
val_size: 0.1 val_size: 0.1
per_device_eval_batch_size: 1 per_device_eval_batch_size: 1
evaluation_strategy: steps eval_strategy: steps
eval_steps: 500 eval_steps: 500

View File

@@ -1,15 +0,0 @@
#!/bin/bash
NPROC_PER_NODE=4
NNODES=2
RANK=0
MASTER_ADDR=192.168.0.1
MASTER_PORT=29500
CUDA_VISIBLE_DEVICES=0,1,2,3 torchrun \
--nproc_per_node $NPROC_PER_NODE \
--nnodes $NNODES \
--node_rank $RANK \
--master_addr $MASTER_ADDR \
--master_port $MASTER_PORT \
src/train.py examples/full_multi_gpu/llama3_full_sft.yaml

View File

@@ -1,5 +0,0 @@
#!/bin/bash
CUDA_VISIBLE_DEVICES=0,1,2,3 accelerate launch \
--config_file examples/accelerate/single_config.yaml \
src/train.py examples/full_multi_gpu/llama3_full_predict.yaml

View File

@@ -1,15 +0,0 @@
#!/bin/bash
NPROC_PER_NODE=4
NNODES=1
RANK=0
MASTER_ADDR=127.0.0.1
MASTER_PORT=29500
CUDA_VISIBLE_DEVICES=0,1,2,3 torchrun \
--nproc_per_node $NPROC_PER_NODE \
--nnodes $NNODES \
--node_rank $RANK \
--master_addr $MASTER_ADDR \
--master_port $MASTER_PORT \
src/train.py examples/full_multi_gpu/llama3_full_sft.yaml

View File

@@ -1,15 +0,0 @@
#!/bin/bash
NPROC_PER_NODE=4
NNODES=1
RANK=0
MASTER_ADDR=127.0.0.1
MASTER_PORT=29500
CUDA_VISIBLE_DEVICES=0,1,2,3 torchrun \
--nproc_per_node $NPROC_PER_NODE \
--nnodes $NNODES \
--node_rank $RANK \
--master_addr $MASTER_ADDR \
--master_port $MASTER_PORT \
src/train.py examples/lora_multi_gpu/llama3_lora_sft_ds.yaml

View File

@@ -1,41 +1,41 @@
# model ### model
model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct
# method ### method
stage: sft stage: sft
do_train: true do_train: true
finetuning_type: lora finetuning_type: lora
lora_target: q_proj,v_proj lora_target: all
# ddp ### ddp
ddp_timeout: 180000000 ddp_timeout: 180000000
# dataset ### dataset
dataset: identity,alpaca_gpt4_en dataset: identity,alpaca_en_demo
template: llama3 template: llama3
cutoff_len: 1024 cutoff_len: 1024
max_samples: 1000 max_samples: 1000
overwrite_cache: true overwrite_cache: true
preprocessing_num_workers: 16 preprocessing_num_workers: 16
# output ### output
output_dir: saves/llama3-8b/lora/sft output_dir: saves/llama3-8b/lora/sft
logging_steps: 10 logging_steps: 10
save_steps: 500 save_steps: 500
plot_loss: true plot_loss: true
overwrite_output_dir: true overwrite_output_dir: true
# train ### train
per_device_train_batch_size: 1 per_device_train_batch_size: 1
gradient_accumulation_steps: 2 gradient_accumulation_steps: 2
learning_rate: 0.0001 learning_rate: 1.0e-4
num_train_epochs: 3.0 num_train_epochs: 3.0
lr_scheduler_type: cosine lr_scheduler_type: cosine
warmup_steps: 0.1 warmup_ratio: 0.1
fp16: true fp16: true
# eval ### eval
val_size: 0.1 val_size: 0.1
per_device_eval_batch_size: 1 per_device_eval_batch_size: 1
evaluation_strategy: steps eval_strategy: steps
eval_steps: 500 eval_steps: 500

View File

@@ -1,42 +1,42 @@
# model ### model
model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct
# method ### method
stage: sft stage: sft
do_train: true do_train: true
finetuning_type: lora finetuning_type: lora
lora_target: q_proj,v_proj lora_target: all
# ddp ### ddp
ddp_timeout: 180000000 ddp_timeout: 180000000
deepspeed: examples/deepspeed/ds_z3_config.json deepspeed: examples/deepspeed/ds_z3_config.json
# dataset ### dataset
dataset: identity,alpaca_gpt4_en dataset: identity,alpaca_en_demo
template: llama3 template: llama3
cutoff_len: 1024 cutoff_len: 1024
max_samples: 1000 max_samples: 1000
overwrite_cache: true overwrite_cache: true
preprocessing_num_workers: 16 preprocessing_num_workers: 16
# output ### output
output_dir: saves/llama3-8b/lora/sft output_dir: saves/llama3-8b/lora/sft
logging_steps: 10 logging_steps: 10
save_steps: 500 save_steps: 500
plot_loss: true plot_loss: true
overwrite_output_dir: true overwrite_output_dir: true
# train ### train
per_device_train_batch_size: 1 per_device_train_batch_size: 1
gradient_accumulation_steps: 2 gradient_accumulation_steps: 2
learning_rate: 0.0001 learning_rate: 1.0e-4
num_train_epochs: 3.0 num_train_epochs: 3.0
lr_scheduler_type: cosine lr_scheduler_type: cosine
warmup_steps: 0.1 warmup_ratio: 0.1
fp16: true fp16: true
# eval ### eval
val_size: 0.1 val_size: 0.1
per_device_eval_batch_size: 1 per_device_eval_batch_size: 1
evaluation_strategy: steps eval_strategy: steps
eval_steps: 500 eval_steps: 500

View File

@@ -1,6 +0,0 @@
#!/bin/bash
# also launch it on slave machine using slave_config.yaml
CUDA_VISIBLE_DEVICES=0,1,2,3 accelerate launch \
--config_file examples/accelerate/master_config.yaml \
src/train.py examples/lora_multi_gpu/llama3_lora_sft.yaml

View File

@@ -1,5 +0,0 @@
#!/bin/bash
CUDA_VISIBLE_DEVICES=0,1,2,3 accelerate launch \
--config_file examples/accelerate/single_config.yaml \
src/train.py examples/lora_multi_gpu/llama3_lora_sft.yaml

View File

@@ -1,15 +0,0 @@
#!/bin/bash
NPROC_PER_NODE=4
NNODES=1
RANK=0
MASTER_ADDR=127.0.0.1
MASTER_PORT=29500
ASCEND_RT_VISIBLE_DEVICES=0,1,2,3 torchrun \
--nproc_per_node $NPROC_PER_NODE \
--nnodes $NNODES \
--node_rank $RANK \
--master_addr $MASTER_ADDR \
--master_port $MASTER_PORT \
src/train.py examples/lora_multi_npu/llama3_lora_sft_ds.yaml

View File

@@ -1,42 +1,42 @@
# model ### model
model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct
# method ### method
stage: sft stage: sft
do_train: true do_train: true
finetuning_type: lora finetuning_type: lora
lora_target: q_proj,v_proj lora_target: all
# ddp ### ddp
ddp_timeout: 180000000 ddp_timeout: 180000000
deepspeed: examples/deepspeed/ds_z0_config.json deepspeed: examples/deepspeed/ds_z0_config.json
# dataset ### dataset
dataset: identity,alpaca_gpt4_en dataset: identity,alpaca_en_demo
template: llama3 template: llama3
cutoff_len: 1024 cutoff_len: 1024
max_samples: 1000 max_samples: 1000
overwrite_cache: true overwrite_cache: true
preprocessing_num_workers: 16 preprocessing_num_workers: 16
# output ### output
output_dir: saves/llama3-8b/lora/sft output_dir: saves/llama3-8b/lora/sft
logging_steps: 10 logging_steps: 10
save_steps: 500 save_steps: 500
plot_loss: true plot_loss: true
overwrite_output_dir: true overwrite_output_dir: true
# train ### train
per_device_train_batch_size: 1 per_device_train_batch_size: 1
gradient_accumulation_steps: 2 gradient_accumulation_steps: 2
learning_rate: 0.0001 learning_rate: 1.0e-4
num_train_epochs: 3.0 num_train_epochs: 3.0
lr_scheduler_type: cosine lr_scheduler_type: cosine
warmup_steps: 0.1 warmup_ratio: 0.1
fp16: true fp16: true
# eval ### eval
val_size: 0.1 val_size: 0.1
per_device_eval_batch_size: 1 per_device_eval_batch_size: 1
evaluation_strategy: steps eval_strategy: steps
eval_steps: 500 eval_steps: 500

View File

@@ -1,39 +1,40 @@
# model ### model
model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct
# method ### method
stage: dpo stage: dpo
do_train: true do_train: true
finetuning_type: lora finetuning_type: lora
lora_target: q_proj,v_proj lora_target: all
dpo_ftx: 1.0 pref_beta: 0.1
pref_loss: sigmoid # [sigmoid (dpo), orpo, simpo]
# dataset ### dataset
dataset: orca_rlhf dataset: dpo_en_demo
template: llama3 template: llama3
cutoff_len: 1024 cutoff_len: 1024
max_samples: 1000 max_samples: 1000
overwrite_cache: true overwrite_cache: true
preprocessing_num_workers: 16 preprocessing_num_workers: 16
# output ### output
output_dir: saves/llama3-8b/lora/dpo output_dir: saves/llama3-8b/lora/dpo
logging_steps: 10 logging_steps: 10
save_steps: 500 save_steps: 500
plot_loss: true plot_loss: true
overwrite_output_dir: true overwrite_output_dir: true
# train ### train
per_device_train_batch_size: 1 per_device_train_batch_size: 1
gradient_accumulation_steps: 8 gradient_accumulation_steps: 8
learning_rate: 0.00001 learning_rate: 5.0e-6
num_train_epochs: 3.0 num_train_epochs: 3.0
lr_scheduler_type: cosine lr_scheduler_type: cosine
warmup_steps: 0.1 warmup_ratio: 0.1
fp16: true fp16: true
# eval ### eval
val_size: 0.1 val_size: 0.1
per_device_eval_batch_size: 1 per_device_eval_batch_size: 1
evaluation_strategy: steps eval_strategy: steps
eval_steps: 500 eval_steps: 500

View File

@@ -1,19 +1,19 @@
# model ### model
model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct
adapter_name_or_path: saves/llama3-8b/lora/sft adapter_name_or_path: saves/llama3-8b/lora/sft
# method ### method
finetuning_type: lora finetuning_type: lora
# dataset ### dataset
task: mmlu task: mmlu
split: test split: test
template: fewshot template: fewshot
lang: en lang: en
n_shot: 5 n_shot: 5
# output ### output
save_dir: saves/llama3-8b/lora/eval save_dir: saves/llama3-8b/lora/eval
# eval ### eval
batch_size: 4 batch_size: 4

View File

@@ -1,38 +1,38 @@
# model ### model
model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct
# method ### method
stage: orpo stage: kto
do_train: true do_train: true
finetuning_type: lora finetuning_type: lora
lora_target: q_proj,v_proj lora_target: all
# dataset ### dataset
dataset: orca_rlhf dataset: kto_en_demo
template: llama3 template: llama3
cutoff_len: 1024 cutoff_len: 1024
max_samples: 1000 max_samples: 1000
overwrite_cache: true overwrite_cache: true
preprocessing_num_workers: 16 preprocessing_num_workers: 16
# output ### output
output_dir: saves/llama3-8b/lora/orpo output_dir: saves/llama3-8b/lora/kto
logging_steps: 10 logging_steps: 10
save_steps: 500 save_steps: 500
plot_loss: true plot_loss: true
overwrite_output_dir: true overwrite_output_dir: true
# train ### train
per_device_train_batch_size: 1 per_device_train_batch_size: 1
gradient_accumulation_steps: 8 gradient_accumulation_steps: 8
learning_rate: 0.00001 learning_rate: 5.0e-6
num_train_epochs: 3.0 num_train_epochs: 3.0
lr_scheduler_type: cosine lr_scheduler_type: cosine
warmup_steps: 0.1 warmup_ratio: 0.1
fp16: true fp16: true
# eval ### eval
val_size: 0.1 val_size: 0.1
per_device_eval_batch_size: 1 per_device_eval_batch_size: 1
evaluation_strategy: steps eval_strategy: steps
eval_steps: 500 eval_steps: 500

View File

@@ -1,38 +1,38 @@
# model ### model
model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct
reward_model: saves/llama3-8b/lora/reward reward_model: saves/llama3-8b/lora/reward
# method ### method
stage: ppo stage: ppo
do_train: true do_train: true
finetuning_type: lora finetuning_type: lora
lora_target: q_proj,v_proj lora_target: all
# dataset ### dataset
dataset: identity,alpaca_gpt4_en dataset: identity,alpaca_en_demo
template: llama3 template: llama3
cutoff_len: 1024 cutoff_len: 1024
max_samples: 1000 max_samples: 1000
overwrite_cache: true overwrite_cache: true
preprocessing_num_workers: 16 preprocessing_num_workers: 16
# output ### output
output_dir: saves/llama3-8b/lora/ppo output_dir: saves/llama3-8b/lora/ppo
logging_steps: 10 logging_steps: 10
save_steps: 500 save_steps: 500
plot_loss: true plot_loss: true
overwrite_output_dir: true overwrite_output_dir: true
# train ### train
per_device_train_batch_size: 1 per_device_train_batch_size: 1
gradient_accumulation_steps: 8 gradient_accumulation_steps: 8
learning_rate: 0.00001 learning_rate: 1.0e-5
num_train_epochs: 3.0 num_train_epochs: 3.0
lr_scheduler_type: cosine lr_scheduler_type: cosine
warmup_steps: 0.1 warmup_ratio: 0.1
fp16: true fp16: true
# generate ### generate
max_new_tokens: 512 max_new_tokens: 512
top_k: 0 top_k: 0
top_p: 0.9 top_p: 0.9

View File

@@ -1,24 +1,24 @@
# model ### model
model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct
adapter_name_or_path: saves/llama3-8b/lora/sft adapter_name_or_path: saves/llama3-8b/lora/sft
# method ### method
stage: sft stage: sft
do_predict: true do_predict: true
finetuning_type: lora finetuning_type: lora
# dataset ### dataset
dataset: identity,alpaca_gpt4_en dataset: identity,alpaca_en_demo
template: llama3 template: llama3
cutoff_len: 1024 cutoff_len: 1024
max_samples: 50 max_samples: 50
overwrite_cache: true overwrite_cache: true
preprocessing_num_workers: 16 preprocessing_num_workers: 16
# output ### output
output_dir: saves/llama3-8b/lora/predict output_dir: saves/llama3-8b/lora/predict
overwrite_output_dir: true overwrite_output_dir: true
# eval ### eval
per_device_eval_batch_size: 1 per_device_eval_batch_size: 1
predict_with_generate: true predict_with_generate: true

View File

@@ -1,37 +1,37 @@
# model ### model
model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct
# method ### method
stage: pt stage: pt
do_train: true do_train: true
finetuning_type: lora finetuning_type: lora
lora_target: q_proj,v_proj lora_target: all
# dataset ### dataset
dataset: c4_demo dataset: c4_demo
cutoff_len: 1024 cutoff_len: 1024
max_samples: 1000 max_samples: 1000
overwrite_cache: true overwrite_cache: true
preprocessing_num_workers: 16 preprocessing_num_workers: 16
# output ### output
output_dir: saves/llama3-8b/lora/sft output_dir: saves/llama3-8b/lora/sft
logging_steps: 10 logging_steps: 10
save_steps: 500 save_steps: 500
plot_loss: true plot_loss: true
overwrite_output_dir: true overwrite_output_dir: true
# train ### train
per_device_train_batch_size: 1 per_device_train_batch_size: 1
gradient_accumulation_steps: 8 gradient_accumulation_steps: 8
learning_rate: 0.0001 learning_rate: 1.0e-4
num_train_epochs: 3.0 num_train_epochs: 3.0
lr_scheduler_type: cosine lr_scheduler_type: cosine
warmup_steps: 0.1 warmup_ratio: 0.1
fp16: true fp16: true
# eval ### eval
val_size: 0.1 val_size: 0.1
per_device_eval_batch_size: 1 per_device_eval_batch_size: 1
evaluation_strategy: steps eval_strategy: steps
eval_steps: 500 eval_steps: 500

View File

@@ -1,38 +1,38 @@
# model ### model
model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct
# method ### method
stage: rm stage: rm
do_train: true do_train: true
finetuning_type: lora finetuning_type: lora
lora_target: q_proj,v_proj lora_target: all
# dataset ### dataset
dataset: orca_rlhf dataset: dpo_en_demo
template: llama3 template: llama3
cutoff_len: 1024 cutoff_len: 1024
max_samples: 1000 max_samples: 1000
overwrite_cache: true overwrite_cache: true
preprocessing_num_workers: 16 preprocessing_num_workers: 16
# output ### output
output_dir: saves/llama3-8b/lora/reward output_dir: saves/llama3-8b/lora/reward
logging_steps: 10 logging_steps: 10
save_steps: 500 save_steps: 500
plot_loss: true plot_loss: true
overwrite_output_dir: true overwrite_output_dir: true
# train ### train
per_device_train_batch_size: 1 per_device_train_batch_size: 1
gradient_accumulation_steps: 8 gradient_accumulation_steps: 8
learning_rate: 0.00001 learning_rate: 1.0e-5
num_train_epochs: 3.0 num_train_epochs: 3.0
lr_scheduler_type: cosine lr_scheduler_type: cosine
warmup_steps: 0.1 warmup_ratio: 0.1
fp16: true fp16: true
# eval ### eval
val_size: 0.1 val_size: 0.1
per_device_eval_batch_size: 1 per_device_eval_batch_size: 1
evaluation_strategy: steps eval_strategy: steps
eval_steps: 500 eval_steps: 500

View File

@@ -1,38 +1,38 @@
# model ### model
model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct
# method ### method
stage: sft stage: sft
do_train: true do_train: true
finetuning_type: lora finetuning_type: lora
lora_target: q_proj,v_proj lora_target: all
# dataset ### dataset
dataset: identity,alpaca_gpt4_en dataset: identity,alpaca_en_demo
template: llama3 template: llama3
cutoff_len: 1024 cutoff_len: 1024
max_samples: 1000 max_samples: 1000
overwrite_cache: true overwrite_cache: true
preprocessing_num_workers: 16 preprocessing_num_workers: 16
# output ### output
output_dir: saves/llama3-8b/lora/sft output_dir: saves/llama3-8b/lora/sft
logging_steps: 10 logging_steps: 10
save_steps: 500 save_steps: 500
plot_loss: true plot_loss: true
overwrite_output_dir: true overwrite_output_dir: true
# train ### train
per_device_train_batch_size: 1 per_device_train_batch_size: 1
gradient_accumulation_steps: 8 gradient_accumulation_steps: 8
learning_rate: 0.0001 learning_rate: 1.0e-4
num_train_epochs: 3.0 num_train_epochs: 3.0
lr_scheduler_type: cosine lr_scheduler_type: cosine
warmup_steps: 0.1 warmup_ratio: 0.1
fp16: true fp16: true
# eval ### eval
val_size: 0.1 val_size: 0.1
per_device_eval_batch_size: 1 per_device_eval_batch_size: 1
evaluation_strategy: steps eval_strategy: steps
eval_steps: 500 eval_steps: 500

View File

@@ -1,14 +1,14 @@
# model ### model
model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct
# method ### method
stage: sft stage: sft
do_train: true do_train: true
finetuning_type: lora finetuning_type: lora
lora_target: q_proj,v_proj lora_target: all
# dataset ### dataset
dataset: identity,alpaca_gpt4_en dataset: identity,alpaca_en_demo
template: llama3 template: llama3
cutoff_len: 1024 cutoff_len: 1024
max_samples: 1000 max_samples: 1000
@@ -16,6 +16,6 @@ overwrite_cache: true
preprocessing_num_workers: 16 preprocessing_num_workers: 16
tokenized_path: saves/llama3-8b/dataset/sft tokenized_path: saves/llama3-8b/dataset/sft
# output ### output
output_dir: saves/llama3-8b/lora/sft output_dir: saves/llama3-8b/lora/sft
overwrite_output_dir: true overwrite_output_dir: true

View File

@@ -1,14 +1,14 @@
# model ### model
model_name_or_path: llava-hf/llava-1.5-7b-hf model_name_or_path: llava-hf/llava-1.5-7b-hf
visual_inputs: true visual_inputs: true
# method ### method
stage: sft stage: sft
do_train: true do_train: true
finetuning_type: lora finetuning_type: lora
lora_target: q_proj,v_proj lora_target: all
# dataset ### dataset
dataset: mllm_demo dataset: mllm_demo
template: vicuna template: vicuna
cutoff_len: 1024 cutoff_len: 1024
@@ -16,24 +16,24 @@ max_samples: 1000
overwrite_cache: true overwrite_cache: true
preprocessing_num_workers: 16 preprocessing_num_workers: 16
# output ### output
output_dir: saves/llava1_5-7b/lora/sft output_dir: saves/llava1_5-7b/lora/sft
logging_steps: 10 logging_steps: 10
save_steps: 500 save_steps: 500
plot_loss: true plot_loss: true
overwrite_output_dir: true overwrite_output_dir: true
# train ### train
per_device_train_batch_size: 1 per_device_train_batch_size: 1
gradient_accumulation_steps: 8 gradient_accumulation_steps: 8
learning_rate: 0.0001 learning_rate: 1.0e-4
num_train_epochs: 3.0 num_train_epochs: 3.0
lr_scheduler_type: cosine lr_scheduler_type: cosine
warmup_steps: 0.1 warmup_ratio: 0.1
fp16: true fp16: true
# eval ### eval
val_size: 0.1 val_size: 0.1
per_device_eval_batch_size: 1 per_device_eval_batch_size: 1
evaluation_strategy: steps eval_strategy: steps
eval_steps: 500 eval_steps: 500

View File

@@ -1,8 +1,8 @@
# model ### model
model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct
template: llama3 template: llama3
# export ### export
export_dir: models/llama3_gptq export_dir: models/llama3_gptq
export_quantization_bit: 4 export_quantization_bit: 4
export_quantization_dataset: data/c4_demo.json export_quantization_dataset: data/c4_demo.json

View File

@@ -1,12 +1,12 @@
# Note: DO NOT use quantized model or quantization_bit when merging lora adapters ### Note: DO NOT use quantized model or quantization_bit when merging lora adapters
# model ### model
model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct
adapter_name_or_path: saves/llama3-8b/lora/sft adapter_name_or_path: saves/llama3-8b/lora/sft
template: llama3 template: llama3
finetuning_type: lora finetuning_type: lora
# export ### export
export_dir: models/llama3_lora_sft export_dir: models/llama3_lora_sft
export_size: 2 export_size: 2
export_device: cpu export_device: cpu

View File

@@ -1,38 +1,38 @@
# model ### model
model_name_or_path: ISTA-DASLab/Meta-Llama-3-8B-Instruct-AQLM-2Bit-1x16 model_name_or_path: ISTA-DASLab/Meta-Llama-3-8B-Instruct-AQLM-2Bit-1x16
# method ### method
stage: sft stage: sft
do_train: true do_train: true
finetuning_type: lora finetuning_type: lora
lora_target: q_proj,v_proj lora_target: all
# dataset ### dataset
dataset: identity,alpaca_gpt4_en dataset: identity,alpaca_en_demo
template: llama3 template: llama3
cutoff_len: 1024 cutoff_len: 1024
max_samples: 1000 max_samples: 1000
overwrite_cache: true overwrite_cache: true
preprocessing_num_workers: 16 preprocessing_num_workers: 16
# output ### output
output_dir: saves/llama3-8b/lora/sft output_dir: saves/llama3-8b/lora/sft
logging_steps: 10 logging_steps: 10
save_steps: 500 save_steps: 500
plot_loss: true plot_loss: true
overwrite_output_dir: true overwrite_output_dir: true
# train ### train
per_device_train_batch_size: 1 per_device_train_batch_size: 1
gradient_accumulation_steps: 8 gradient_accumulation_steps: 8
learning_rate: 0.0001 learning_rate: 1.0e-4
num_train_epochs: 3.0 num_train_epochs: 3.0
lr_scheduler_type: cosine lr_scheduler_type: cosine
warmup_steps: 0.1 warmup_ratio: 0.1
fp16: true fp16: true
# eval ### eval
val_size: 0.1 val_size: 0.1
per_device_eval_batch_size: 1 per_device_eval_batch_size: 1
evaluation_strategy: steps eval_strategy: steps
eval_steps: 500 eval_steps: 500

View File

@@ -1,38 +1,38 @@
# model ### model
model_name_or_path: TechxGenus/Meta-Llama-3-8B-Instruct-AWQ model_name_or_path: TechxGenus/Meta-Llama-3-8B-Instruct-AWQ
# method ### method
stage: sft stage: sft
do_train: true do_train: true
finetuning_type: lora finetuning_type: lora
lora_target: q_proj,v_proj lora_target: all
# dataset ### dataset
dataset: identity,alpaca_gpt4_en dataset: identity,alpaca_en_demo
template: llama3 template: llama3
cutoff_len: 1024 cutoff_len: 1024
max_samples: 1000 max_samples: 1000
overwrite_cache: true overwrite_cache: true
preprocessing_num_workers: 16 preprocessing_num_workers: 16
# output ### output
output_dir: saves/llama3-8b/lora/sft output_dir: saves/llama3-8b/lora/sft
logging_steps: 10 logging_steps: 10
save_steps: 500 save_steps: 500
plot_loss: true plot_loss: true
overwrite_output_dir: true overwrite_output_dir: true
# train ### train
per_device_train_batch_size: 1 per_device_train_batch_size: 1
gradient_accumulation_steps: 8 gradient_accumulation_steps: 8
learning_rate: 0.0001 learning_rate: 1.0e-4
num_train_epochs: 3.0 num_train_epochs: 3.0
lr_scheduler_type: cosine lr_scheduler_type: cosine
warmup_steps: 0.1 warmup_ratio: 0.1
fp16: true fp16: true
# eval ### eval
val_size: 0.1 val_size: 0.1
per_device_eval_batch_size: 1 per_device_eval_batch_size: 1
evaluation_strategy: steps eval_strategy: steps
eval_steps: 500 eval_steps: 500

View File

@@ -1,39 +1,39 @@
# model ### model
model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct
quantization_bit: 4 quantization_bit: 4
# method ### method
stage: sft stage: sft
do_train: true do_train: true
finetuning_type: lora finetuning_type: lora
lora_target: q_proj,v_proj lora_target: all
# dataset ### dataset
dataset: identity,alpaca_gpt4_en dataset: identity,alpaca_en_demo
template: llama3 template: llama3
cutoff_len: 1024 cutoff_len: 1024
max_samples: 1000 max_samples: 1000
overwrite_cache: true overwrite_cache: true
preprocessing_num_workers: 16 preprocessing_num_workers: 16
# output ### output
output_dir: saves/llama3-8b/lora/sft output_dir: saves/llama3-8b/lora/sft
logging_steps: 10 logging_steps: 10
save_steps: 500 save_steps: 500
plot_loss: true plot_loss: true
overwrite_output_dir: true overwrite_output_dir: true
# train ### train
per_device_train_batch_size: 1 per_device_train_batch_size: 1
gradient_accumulation_steps: 8 gradient_accumulation_steps: 8
learning_rate: 0.0001 learning_rate: 1.0e-4
num_train_epochs: 3.0 num_train_epochs: 3.0
lr_scheduler_type: cosine lr_scheduler_type: cosine
warmup_steps: 0.1 warmup_ratio: 0.1
fp16: true fp16: true
# eval ### eval
val_size: 0.1 val_size: 0.1
per_device_eval_batch_size: 1 per_device_eval_batch_size: 1
evaluation_strategy: steps eval_strategy: steps
eval_steps: 500 eval_steps: 500

View File

@@ -1,38 +1,38 @@
# model ### model
model_name_or_path: TechxGenus/Meta-Llama-3-8B-Instruct-GPTQ model_name_or_path: TechxGenus/Meta-Llama-3-8B-Instruct-GPTQ
# method ### method
stage: sft stage: sft
do_train: true do_train: true
finetuning_type: lora finetuning_type: lora
lora_target: q_proj,v_proj lora_target: all
# dataset ### dataset
dataset: identity,alpaca_gpt4_en dataset: identity,alpaca_en_demo
template: llama3 template: llama3
cutoff_len: 1024 cutoff_len: 1024
max_samples: 1000 max_samples: 1000
overwrite_cache: true overwrite_cache: true
preprocessing_num_workers: 16 preprocessing_num_workers: 16
# output ### output
output_dir: saves/llama3-8b/lora/sft output_dir: saves/llama3-8b/lora/sft
logging_steps: 10 logging_steps: 10
save_steps: 500 save_steps: 500
plot_loss: true plot_loss: true
overwrite_output_dir: true overwrite_output_dir: true
# train ### train
per_device_train_batch_size: 1 per_device_train_batch_size: 1
gradient_accumulation_steps: 8 gradient_accumulation_steps: 8
learning_rate: 0.0001 learning_rate: 1.0e-4
num_train_epochs: 3.0 num_train_epochs: 3.0
lr_scheduler_type: cosine lr_scheduler_type: cosine
warmup_steps: 0.1 warmup_ratio: 0.1
fp16: true fp16: true
# eval ### eval
val_size: 0.1 val_size: 0.1
per_device_eval_batch_size: 1 per_device_eval_batch_size: 1
evaluation_strategy: steps eval_strategy: steps
eval_steps: 500 eval_steps: 500

View File

@@ -13,7 +13,7 @@ select = ["C", "E", "F", "I", "W"]
[tool.ruff.lint.isort] [tool.ruff.lint.isort]
lines-after-imports = 2 lines-after-imports = 2
known-first-party = ["llmtuner"] known-first-party = ["llamafactory"]
known-third-party = [ known-third-party = [
"accelerate", "accelerate",
"datasets", "datasets",

View File

@@ -1,12 +1,13 @@
transformers>=4.37.2 transformers>=4.41.2
datasets>=2.14.3 datasets>=2.16.0
accelerate>=0.27.2 accelerate>=0.30.1
peft>=0.10.0 peft>=0.11.1
trl>=0.8.1 trl>=0.8.6
gradio>=4.0.0 gradio>=4.0.0
scipy scipy
einops einops
sentencepiece sentencepiece
tiktoken
protobuf protobuf
uvicorn uvicorn
pydantic pydantic

View File

@@ -8,7 +8,7 @@ import torch
from deepspeed.accelerator import get_accelerator # type: ignore from deepspeed.accelerator import get_accelerator # type: ignore
from deepspeed.profiling.flops_profiler import get_model_profile # type: ignore from deepspeed.profiling.flops_profiler import get_model_profile # type: ignore
from llmtuner.chat import ChatModel from llamafactory.chat import ChatModel
def calculate_flops( def calculate_flops(

View File

@@ -12,10 +12,10 @@ from torch.utils.data import DataLoader
from tqdm import tqdm from tqdm import tqdm
from transformers import DataCollatorForLanguageModeling, DataCollatorForSeq2Seq from transformers import DataCollatorForLanguageModeling, DataCollatorForSeq2Seq
from llmtuner.data import get_dataset from llamafactory.data import get_dataset
from llmtuner.extras.constants import IGNORE_INDEX from llamafactory.extras.constants import IGNORE_INDEX
from llmtuner.hparams import get_train_args from llamafactory.hparams import get_train_args
from llmtuner.model import load_tokenizer from llamafactory.model import load_tokenizer
BASE_LR = 3e-4 # 1.5e-4 for 30B-70B models BASE_LR = 3e-4 # 1.5e-4 for 30B-70B models

View File

@@ -12,10 +12,10 @@ from torch.utils.data import DataLoader
from tqdm import tqdm from tqdm import tqdm
from transformers import DataCollatorForLanguageModeling, DataCollatorForSeq2Seq from transformers import DataCollatorForLanguageModeling, DataCollatorForSeq2Seq
from llmtuner.data import get_dataset from llamafactory.data import get_dataset
from llmtuner.extras.constants import IGNORE_INDEX from llamafactory.extras.constants import IGNORE_INDEX
from llmtuner.hparams import get_train_args from llamafactory.hparams import get_train_args
from llmtuner.model import load_model, load_tokenizer from llamafactory.model import load_model, load_tokenizer
@dataclass @dataclass

View File

@@ -7,9 +7,9 @@ from collections import defaultdict
import fire import fire
from tqdm import tqdm from tqdm import tqdm
from llmtuner.data import get_dataset from llamafactory.data import get_dataset
from llmtuner.hparams import get_train_args from llamafactory.hparams import get_train_args
from llmtuner.model import load_tokenizer from llamafactory.model import load_tokenizer
def length_cdf( def length_cdf(

View File

@@ -104,10 +104,10 @@ def block_expansion(
print("Model weights saved in {}".format(output_dir)) print("Model weights saved in {}".format(output_dir))
print("Fine-tune this model with:") print("Fine-tune this model with:")
print(" --model_name_or_path {} \\".format(output_dir)) print("model_name_or_path: {}".format(output_dir))
print(" --finetuning_type freeze \\") print("finetuning_type: freeze")
print(" --freeze_trainable_layers {} \\".format(num_expand)) print("freeze_trainable_layers: {}".format(num_expand))
print(" --use_llama_pro") print("use_llama_pro: true")
if __name__ == "__main__": if __name__ == "__main__":

View File

@@ -20,7 +20,7 @@ def calculate_gpa(grades: Sequence[str], hours: Sequence[int]) -> float:
def main(): def main():
client = OpenAI( client = OpenAI(
api_key="0", api_key="{}".format(os.environ.get("API_KEY", "0")),
base_url="http://localhost:{}/v1".format(os.environ.get("API_PORT", 8000)), base_url="http://localhost:{}/v1".format(os.environ.get("API_PORT", 8000)),
) )
tools = [ tools = [

View File

@@ -5,7 +5,7 @@ from setuptools import find_packages, setup
def get_version(): def get_version():
with open(os.path.join("src", "llmtuner", "cli.py"), "r", encoding="utf-8") as f: with open(os.path.join("src", "llamafactory", "extras", "env.py"), "r", encoding="utf-8") as f:
file_content = f.read() file_content = f.read()
pattern = r"{}\W*=\W*\"([^\"]+)\"".format("VERSION") pattern = r"{}\W*=\W*\"([^\"]+)\"".format("VERSION")
(version,) = re.findall(pattern, file_content) (version,) = re.findall(pattern, file_content)
@@ -21,24 +21,25 @@ def get_requires():
extra_require = { extra_require = {
"torch": ["torch>=1.13.1"], "torch": ["torch>=1.13.1"],
"torch-npu": ["torch==2.1.0", "torch-npu==2.1.0.post3", "decorator"],
"metrics": ["nltk", "jieba", "rouge-chinese"], "metrics": ["nltk", "jieba", "rouge-chinese"],
"deepspeed": ["deepspeed>=0.10.0,<=0.14.0"], "deepspeed": ["deepspeed>=0.10.0,<=0.14.0"],
"bitsandbytes": ["bitsandbytes>=0.39.0"], "bitsandbytes": ["bitsandbytes>=0.39.0"],
"vllm": ["vllm>=0.4.0"], "vllm": ["vllm>=0.4.3"],
"galore": ["galore-torch"], "galore": ["galore-torch"],
"badam": ["badam"], "badam": ["badam"],
"gptq": ["optimum>=1.16.0", "auto-gptq>=0.5.0"], "gptq": ["optimum>=1.16.0", "auto-gptq>=0.5.0"],
"awq": ["autoawq"], "awq": ["autoawq"],
"aqlm": ["aqlm[gpu]>=1.1.0"], "aqlm": ["aqlm[gpu]>=1.1.0"],
"qwen": ["tiktoken", "transformers_stream_generator"], "qwen": ["transformers_stream_generator"],
"modelscope": ["modelscope"], "modelscope": ["modelscope"],
"quality": ["ruff"], "dev": ["ruff", "pytest"],
} }
def main(): def main():
setup( setup(
name="llmtuner", name="llamafactory",
version=get_version(), version=get_version(),
author="hiyouga", author="hiyouga",
author_email="hiyouga" "@" "buaa.edu.cn", author_email="hiyouga" "@" "buaa.edu.cn",
@@ -53,7 +54,7 @@ def main():
python_requires=">=3.8.0", python_requires=">=3.8.0",
install_requires=get_requires(), install_requires=get_requires(),
extras_require=extra_require, extras_require=extra_require,
entry_points={"console_scripts": ["llamafactory-cli = llmtuner.cli:main"]}, entry_points={"console_scripts": ["llamafactory-cli = llamafactory.cli:main"]},
classifiers=[ classifiers=[
"Development Status :: 4 - Beta", "Development Status :: 4 - Beta",
"Intended Audience :: Developers", "Intended Audience :: Developers",

View File

@@ -2,8 +2,8 @@ import os
import uvicorn import uvicorn
from llmtuner.api.app import create_app from llamafactory.api.app import create_app
from llmtuner.chat import ChatModel from llamafactory.chat import ChatModel
def main(): def main():

View File

@@ -0,0 +1,6 @@
# Level: api, webui > chat, eval, train > data, model > hparams > extras
from .cli import VERSION
__version__ = VERSION

View File

@@ -1,10 +1,13 @@
import base64
import io
import json import json
import os
import uuid import uuid
from typing import TYPE_CHECKING, AsyncGenerator, Dict, List, Optional, Tuple from typing import TYPE_CHECKING, AsyncGenerator, Dict, List, Optional, Tuple
from ..data import Role as DataRole from ..data import Role as DataRole
from ..extras.logging import get_logger from ..extras.logging import get_logger
from ..extras.packages import is_fastapi_available from ..extras.packages import is_fastapi_available, is_pillow_available, is_requests_available
from .common import dictify, jsonify from .common import dictify, jsonify
from .protocol import ( from .protocol import (
ChatCompletionMessage, ChatCompletionMessage,
@@ -25,7 +28,17 @@ if is_fastapi_available():
from fastapi import HTTPException, status from fastapi import HTTPException, status
if is_pillow_available():
from PIL import Image
if is_requests_available():
import requests
if TYPE_CHECKING: if TYPE_CHECKING:
from numpy.typing import NDArray
from ..chat import ChatModel from ..chat import ChatModel
from .protocol import ChatCompletionRequest, ScoreEvaluationRequest from .protocol import ChatCompletionRequest, ScoreEvaluationRequest
@@ -40,7 +53,9 @@ ROLE_MAPPING = {
} }
def _process_request(request: "ChatCompletionRequest") -> Tuple[List[Dict[str, str]], str, str]: def _process_request(
request: "ChatCompletionRequest",
) -> Tuple[List[Dict[str, str]], Optional[str], Optional[str], Optional["NDArray"]]:
logger.info("==== request ====\n{}".format(json.dumps(dictify(request), indent=2, ensure_ascii=False))) logger.info("==== request ====\n{}".format(json.dumps(dictify(request), indent=2, ensure_ascii=False)))
if len(request.messages) == 0: if len(request.messages) == 0:
@@ -49,12 +64,13 @@ def _process_request(request: "ChatCompletionRequest") -> Tuple[List[Dict[str, s
if request.messages[0].role == Role.SYSTEM: if request.messages[0].role == Role.SYSTEM:
system = request.messages.pop(0).content system = request.messages.pop(0).content
else: else:
system = "" system = None
if len(request.messages) % 2 == 0: if len(request.messages) % 2 == 0:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Only supports u/a/u/a/u...") raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Only supports u/a/u/a/u...")
input_messages = [] input_messages = []
image = None
for i, message in enumerate(request.messages): for i, message in enumerate(request.messages):
if i % 2 == 0 and message.role not in [Role.USER, Role.TOOL]: if i % 2 == 0 and message.role not in [Role.USER, Role.TOOL]:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid role") raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid role")
@@ -66,6 +82,21 @@ def _process_request(request: "ChatCompletionRequest") -> Tuple[List[Dict[str, s
arguments = message.tool_calls[0].function.arguments arguments = message.tool_calls[0].function.arguments
content = json.dumps({"name": name, "argument": arguments}, ensure_ascii=False) content = json.dumps({"name": name, "argument": arguments}, ensure_ascii=False)
input_messages.append({"role": ROLE_MAPPING[Role.FUNCTION], "content": content}) input_messages.append({"role": ROLE_MAPPING[Role.FUNCTION], "content": content})
elif isinstance(message.content, list):
for input_item in message.content:
if input_item.type == "text":
input_messages.append({"role": ROLE_MAPPING[message.role], "content": input_item.text})
else:
image_url = input_item.image_url.url
if image_url.startswith("data:image"): # base64 image
image_data = base64.b64decode(image_url.split(",", maxsplit=1)[1])
image_path = io.BytesIO(image_data)
elif os.path.isfile(image_url): # local file
image_path = open(image_url, "rb")
else: # web uri
image_path = requests.get(image_url, stream=True).raw
image = Image.open(image_path).convert("RGB")
else: else:
input_messages.append({"role": ROLE_MAPPING[message.role], "content": message.content}) input_messages.append({"role": ROLE_MAPPING[message.role], "content": message.content})
@@ -76,9 +107,9 @@ def _process_request(request: "ChatCompletionRequest") -> Tuple[List[Dict[str, s
except Exception: except Exception:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid tools") raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid tools")
else: else:
tools = "" tools = None
return input_messages, system, tools return input_messages, system, tools, image
def _create_stream_chat_completion_chunk( def _create_stream_chat_completion_chunk(
@@ -97,11 +128,12 @@ async def create_chat_completion_response(
request: "ChatCompletionRequest", chat_model: "ChatModel" request: "ChatCompletionRequest", chat_model: "ChatModel"
) -> "ChatCompletionResponse": ) -> "ChatCompletionResponse":
completion_id = "chatcmpl-{}".format(uuid.uuid4().hex) completion_id = "chatcmpl-{}".format(uuid.uuid4().hex)
input_messages, system, tools = _process_request(request) input_messages, system, tools, image = _process_request(request)
responses = await chat_model.achat( responses = await chat_model.achat(
input_messages, input_messages,
system, system,
tools, tools,
image,
do_sample=request.do_sample, do_sample=request.do_sample,
temperature=request.temperature, temperature=request.temperature,
top_p=request.top_p, top_p=request.top_p,
@@ -145,7 +177,7 @@ async def create_stream_chat_completion_response(
request: "ChatCompletionRequest", chat_model: "ChatModel" request: "ChatCompletionRequest", chat_model: "ChatModel"
) -> AsyncGenerator[str, None]: ) -> AsyncGenerator[str, None]:
completion_id = "chatcmpl-{}".format(uuid.uuid4().hex) completion_id = "chatcmpl-{}".format(uuid.uuid4().hex)
input_messages, system, tools = _process_request(request) input_messages, system, tools, image = _process_request(request)
if tools: if tools:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Cannot stream function calls.") raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Cannot stream function calls.")
@@ -159,6 +191,7 @@ async def create_stream_chat_completion_response(
input_messages, input_messages,
system, system,
tools, tools,
image,
do_sample=request.do_sample, do_sample=request.do_sample,
temperature=request.temperature, temperature=request.temperature,
top_p=request.top_p, top_p=request.top_p,

View File

@@ -56,9 +56,19 @@ class FunctionCall(BaseModel):
function: Function function: Function
class ImageURL(BaseModel):
url: str
class MultimodalInputItem(BaseModel):
type: Literal["text", "image_url"]
text: Optional[str] = None
image_url: Optional[ImageURL] = None
class ChatMessage(BaseModel): class ChatMessage(BaseModel):
role: Role role: Role
content: Optional[str] = None content: Optional[Union[str, List[MultimodalInputItem]]] = None
tool_calls: Optional[List[FunctionCall]] = None tool_calls: Optional[List[FunctionCall]] = None

View File

@@ -2,12 +2,13 @@ import asyncio
import concurrent.futures import concurrent.futures
import os import os
from threading import Thread from threading import Thread
from typing import TYPE_CHECKING, Any, AsyncGenerator, Callable, Dict, List, Optional, Sequence, Tuple from typing import TYPE_CHECKING, Any, AsyncGenerator, Callable, Dict, List, Optional, Sequence, Tuple, Union
import torch import torch
from transformers import GenerationConfig, TextIteratorStreamer from transformers import GenerationConfig, TextIteratorStreamer
from ..data import get_template_and_fix_tokenizer from ..data import get_template_and_fix_tokenizer
from ..extras.logging import get_logger
from ..extras.misc import get_logits_processor from ..extras.misc import get_logits_processor
from ..model import load_model, load_tokenizer from ..model import load_model, load_tokenizer
from .base_engine import BaseEngine, Response from .base_engine import BaseEngine, Response
@@ -23,6 +24,9 @@ if TYPE_CHECKING:
from ..hparams import DataArguments, FinetuningArguments, GeneratingArguments, ModelArguments from ..hparams import DataArguments, FinetuningArguments, GeneratingArguments, ModelArguments
logger = get_logger(__name__)
class HuggingfaceEngine(BaseEngine): class HuggingfaceEngine(BaseEngine):
def __init__( def __init__(
self, self,
@@ -55,47 +59,69 @@ class HuggingfaceEngine(BaseEngine):
image: Optional["NDArray"] = None, image: Optional["NDArray"] = None,
input_kwargs: Optional[Dict[str, Any]] = {}, input_kwargs: Optional[Dict[str, Any]] = {},
) -> Tuple[Dict[str, Any], int]: ) -> Tuple[Dict[str, Any], int]:
if processor is not None and image is not None and "<image>" not in messages[0]["content"]: if (
messages[0]["content"] = "<image>" + messages[0]["content"] processor is not None
and image is not None
and not hasattr(processor, "image_seq_length")
and template.image_token not in messages[0]["content"]
): # llava-like models
messages[0]["content"] = template.image_token + messages[0]["content"]
paired_messages = messages + [{"role": "assistant", "content": ""}] paired_messages = messages + [{"role": "assistant", "content": ""}]
system = system or generating_args["default_system"]
pixel_values = None
prompt_ids, _ = template.encode_oneturn( prompt_ids, _ = template.encode_oneturn(
tokenizer=tokenizer, messages=paired_messages, system=system, tools=tools tokenizer=tokenizer, messages=paired_messages, system=system, tools=tools
) )
if processor is not None and image is not None: # add image features
image_processor: "BaseImageProcessor" = getattr(processor, "image_processor")
batch_feature = image_processor(image, return_tensors="pt")
pixel_values = batch_feature.to(model.device)["pixel_values"] # shape (B, C, H, W)
if hasattr(processor, "image_seq_length"): # paligemma models
image_token_id = tokenizer.convert_tokens_to_ids(template.image_token)
prompt_ids = [image_token_id] * getattr(processor, "image_seq_length") + prompt_ids
prompt_length = len(prompt_ids) prompt_length = len(prompt_ids)
inputs = torch.tensor([prompt_ids], device=model.device) inputs = torch.tensor([prompt_ids], device=model.device)
attention_mask = torch.ones_like(inputs, dtype=torch.bool)
do_sample = input_kwargs.pop("do_sample", generating_args["do_sample"]) do_sample: Optional[bool] = input_kwargs.pop("do_sample", None)
temperature = input_kwargs.pop("temperature", generating_args["temperature"]) temperature: Optional[float] = input_kwargs.pop("temperature", None)
top_p = input_kwargs.pop("top_p", generating_args["top_p"]) top_p: Optional[float] = input_kwargs.pop("top_p", None)
top_k = input_kwargs.pop("top_k", generating_args["top_k"]) top_k: Optional[float] = input_kwargs.pop("top_k", None)
num_return_sequences = input_kwargs.pop("num_return_sequences", 1) num_return_sequences: int = input_kwargs.pop("num_return_sequences", 1)
repetition_penalty = input_kwargs.pop("repetition_penalty", generating_args["repetition_penalty"]) repetition_penalty: Optional[float] = input_kwargs.pop("repetition_penalty", None)
length_penalty = input_kwargs.pop("length_penalty", generating_args["length_penalty"]) length_penalty: Optional[float] = input_kwargs.pop("length_penalty", None)
max_length = input_kwargs.pop("max_length", None) max_length: Optional[int] = input_kwargs.pop("max_length", None)
max_new_tokens = input_kwargs.pop("max_new_tokens", None) max_new_tokens: Optional[int] = input_kwargs.pop("max_new_tokens", None)
stop = input_kwargs.pop("stop", None) stop: Optional[Union[str, List[str]]] = input_kwargs.pop("stop", None)
if stop is not None: if stop is not None:
raise ValueError("Stop parameter is not supported in Huggingface engine yet.") logger.warning("Stop parameter is not supported in Huggingface engine yet.")
generating_args = generating_args.copy() generating_args = generating_args.copy()
generating_args.update( generating_args.update(
dict( dict(
do_sample=do_sample, do_sample=do_sample if do_sample is not None else generating_args["do_sample"],
temperature=temperature, temperature=temperature if temperature is not None else generating_args["temperature"],
top_p=top_p, top_p=top_p if top_p is not None else generating_args["top_p"],
top_k=top_k, top_k=top_k if top_k is not None else generating_args["top_k"],
num_return_sequences=num_return_sequences, num_return_sequences=num_return_sequences,
repetition_penalty=repetition_penalty, repetition_penalty=repetition_penalty
length_penalty=length_penalty, if repetition_penalty is not None
else generating_args["repetition_penalty"],
length_penalty=length_penalty if length_penalty is not None else generating_args["length_penalty"],
eos_token_id=[tokenizer.eos_token_id] + tokenizer.additional_special_tokens_ids, eos_token_id=[tokenizer.eos_token_id] + tokenizer.additional_special_tokens_ids,
pad_token_id=tokenizer.pad_token_id, pad_token_id=tokenizer.pad_token_id,
) )
) )
if isinstance(num_return_sequences, int) and num_return_sequences > 1: if isinstance(num_return_sequences, int) and num_return_sequences > 1: # do_sample needs temperature > 0
generating_args["do_sample"] = True generating_args["do_sample"] = True
generating_args["temperature"] = generating_args["temperature"] or 1.0
if not generating_args["temperature"]:
generating_args["do_sample"] = False
if not generating_args["do_sample"]: if not generating_args["do_sample"]:
generating_args.pop("temperature", None) generating_args.pop("temperature", None)
@@ -111,14 +137,13 @@ class HuggingfaceEngine(BaseEngine):
gen_kwargs = dict( gen_kwargs = dict(
inputs=inputs, inputs=inputs,
attention_mask=attention_mask,
generation_config=GenerationConfig(**generating_args), generation_config=GenerationConfig(**generating_args),
logits_processor=get_logits_processor(), logits_processor=get_logits_processor(),
) )
if processor is not None and image is not None: if pixel_values is not None:
image_processor: "BaseImageProcessor" = getattr(processor, "image_processor") gen_kwargs["pixel_values"] = pixel_values
pixel_values: "torch.Tensor" = image_processor(image, return_tensors="pt")["pixel_values"]
gen_kwargs["pixel_values"] = pixel_values.to(model.device)
return gen_kwargs, prompt_length return gen_kwargs, prompt_length

View File

@@ -1,12 +1,12 @@
import uuid import uuid
from typing import TYPE_CHECKING, AsyncGenerator, AsyncIterator, Dict, List, Optional, Sequence from typing import TYPE_CHECKING, AsyncGenerator, AsyncIterator, Dict, List, Optional, Sequence, Union
from ..data import get_template_and_fix_tokenizer from ..data import get_template_and_fix_tokenizer
from ..extras.logging import get_logger from ..extras.logging import get_logger
from ..extras.misc import get_device_count, infer_optim_dtype from ..extras.misc import get_device_count
from ..extras.packages import is_vllm_available from ..extras.packages import is_vllm_available
from ..model import load_config, load_tokenizer from ..model import load_config, load_tokenizer
from ..model.utils.visual import LlavaMultiModalProjectorForYiVLForVLLM from ..model.model_utils.visual import LlavaMultiModalProjectorForYiVLForVLLM
from .base_engine import BaseEngine, Response from .base_engine import BaseEngine, Response
@@ -17,7 +17,6 @@ if is_vllm_available():
if TYPE_CHECKING: if TYPE_CHECKING:
import torch
from numpy.typing import NDArray from numpy.typing import NDArray
from transformers.image_processing_utils import BaseImageProcessor from transformers.image_processing_utils import BaseImageProcessor
@@ -36,8 +35,6 @@ class VllmEngine(BaseEngine):
generating_args: "GeneratingArguments", generating_args: "GeneratingArguments",
) -> None: ) -> None:
config = load_config(model_args) # may download model from ms hub config = load_config(model_args) # may download model from ms hub
infer_dtype = infer_optim_dtype(model_dtype=getattr(config, "torch_dtype", None))
infer_dtype = str(infer_dtype).split(".")[-1]
self.can_generate = finetuning_args.stage == "sft" self.can_generate = finetuning_args.stage == "sft"
tokenizer_module = load_tokenizer(model_args) tokenizer_module = load_tokenizer(model_args)
@@ -51,7 +48,7 @@ class VllmEngine(BaseEngine):
"model": model_args.model_name_or_path, "model": model_args.model_name_or_path,
"trust_remote_code": True, "trust_remote_code": True,
"download_dir": model_args.cache_dir, "download_dir": model_args.cache_dir,
"dtype": infer_dtype, "dtype": model_args.vllm_dtype,
"max_model_len": model_args.vllm_maxlen, "max_model_len": model_args.vllm_maxlen,
"tensor_parallel_size": get_device_count() or 1, "tensor_parallel_size": get_device_count() or 1,
"gpu_memory_utilization": model_args.vllm_gpu_util, "gpu_memory_utilization": model_args.vllm_gpu_util,
@@ -59,6 +56,7 @@ class VllmEngine(BaseEngine):
"disable_log_requests": True, "disable_log_requests": True,
"enforce_eager": model_args.vllm_enforce_eager, "enforce_eager": model_args.vllm_enforce_eager,
"enable_lora": model_args.adapter_name_or_path is not None, "enable_lora": model_args.adapter_name_or_path is not None,
"max_lora_rank": model_args.vllm_max_lora_rank,
} }
if model_args.visual_inputs: if model_args.visual_inputs:
@@ -66,11 +64,10 @@ class VllmEngine(BaseEngine):
patch_size = config.vision_config.patch_size patch_size = config.vision_config.patch_size
self.image_feature_size = (image_size // patch_size) ** 2 self.image_feature_size = (image_size // patch_size) ** 2
engine_args["image_input_type"] = "pixel_values" engine_args["image_input_type"] = "pixel_values"
engine_args["image_token_id"] = self.tokenizer.convert_tokens_to_ids("<image>") engine_args["image_token_id"] = self.tokenizer.convert_tokens_to_ids(self.template.image_token)
engine_args["image_input_shape"] = "1,3,{},{}".format(image_size, image_size) engine_args["image_input_shape"] = "1,3,{},{}".format(image_size, image_size)
engine_args["image_feature_size"] = self.image_feature_size engine_args["image_feature_size"] = self.image_feature_size
if getattr(config, "is_yi_vl_derived_model", None): if getattr(config, "is_yi_vl_derived_model", None):
# bug in vllm 0.4.2, see: https://github.com/vllm-project/vllm/pull/4828
import vllm.model_executor.models.llava import vllm.model_executor.models.llava
logger.info("Detected Yi-VL model, applying projector patch.") logger.info("Detected Yi-VL model, applying projector patch.")
@@ -91,27 +88,49 @@ class VllmEngine(BaseEngine):
**input_kwargs, **input_kwargs,
) -> AsyncIterator["RequestOutput"]: ) -> AsyncIterator["RequestOutput"]:
request_id = "chatcmpl-{}".format(uuid.uuid4().hex) request_id = "chatcmpl-{}".format(uuid.uuid4().hex)
if self.processor is not None and image is not None and "<image>" not in messages[0]["content"]:
messages[0]["content"] = "<image>" * self.image_feature_size + messages[0]["content"] if (
self.processor is not None
and image is not None
and not hasattr(self.processor, "image_seq_length")
and self.template.image_token not in messages[0]["content"]
): # llava-like models (TODO: paligemma models)
messages[0]["content"] = self.template.image_token * self.image_feature_size + messages[0]["content"]
paired_messages = messages + [{"role": "assistant", "content": ""}] paired_messages = messages + [{"role": "assistant", "content": ""}]
system = system or self.generating_args["default_system"]
prompt_ids, _ = self.template.encode_oneturn( prompt_ids, _ = self.template.encode_oneturn(
tokenizer=self.tokenizer, messages=paired_messages, system=system, tools=tools tokenizer=self.tokenizer, messages=paired_messages, system=system, tools=tools
) )
if self.processor is not None and image is not None: # add image features
image_processor: "BaseImageProcessor" = getattr(self.processor, "image_processor")
pixel_values = image_processor(image, return_tensors="pt")["pixel_values"]
multi_modal_data = MultiModalData(type=MultiModalData.Type.IMAGE, data=pixel_values)
else:
multi_modal_data = None
prompt_length = len(prompt_ids) prompt_length = len(prompt_ids)
use_beam_search = self.generating_args["num_beams"] > 1 use_beam_search: bool = self.generating_args["num_beams"] > 1
temperature = input_kwargs.pop("temperature", self.generating_args["temperature"]) temperature: Optional[float] = input_kwargs.pop("temperature", None)
top_p = input_kwargs.pop("top_p", self.generating_args["top_p"]) top_p: Optional[float] = input_kwargs.pop("top_p", None)
top_k = input_kwargs.pop("top_k", self.generating_args["top_k"]) top_k: Optional[float] = input_kwargs.pop("top_k", None)
num_return_sequences = input_kwargs.pop("num_return_sequences", 1) num_return_sequences: int = input_kwargs.pop("num_return_sequences", 1)
repetition_penalty = input_kwargs.pop("repetition_penalty", self.generating_args["repetition_penalty"]) repetition_penalty: Optional[float] = input_kwargs.pop("repetition_penalty", None)
length_penalty = input_kwargs.pop("length_penalty", self.generating_args["length_penalty"]) length_penalty: Optional[float] = input_kwargs.pop("length_penalty", None)
max_length = input_kwargs.pop("max_length", None) max_length: Optional[int] = input_kwargs.pop("max_length", None)
max_new_tokens = input_kwargs.pop("max_new_tokens", None) max_new_tokens: Optional[int] = input_kwargs.pop("max_new_tokens", None)
stop = input_kwargs.pop("stop", None) stop: Optional[Union[str, List[str]]] = input_kwargs.pop("stop", None)
if "max_new_tokens" in self.generating_args:
max_tokens = self.generating_args["max_new_tokens"]
elif "max_length" in self.generating_args:
if self.generating_args["max_length"] > prompt_length:
max_tokens = self.generating_args["max_length"] - prompt_length
else:
max_tokens = 1
max_tokens = self.generating_args["max_new_tokens"] or self.generating_args["max_length"]
if max_length: if max_length:
max_tokens = max_length - prompt_length if max_length > prompt_length else 1 max_tokens = max_length - prompt_length if max_length > prompt_length else 1
@@ -120,32 +139,26 @@ class VllmEngine(BaseEngine):
sampling_params = SamplingParams( sampling_params = SamplingParams(
n=num_return_sequences, n=num_return_sequences,
repetition_penalty=repetition_penalty, repetition_penalty=(
temperature=temperature, repetition_penalty if repetition_penalty is not None else self.generating_args["repetition_penalty"]
top_p=top_p, )
top_k=top_k, or 1.0, # repetition_penalty must > 0
temperature=temperature if temperature is not None else self.generating_args["temperature"],
top_p=(top_p if top_p is not None else self.generating_args["top_p"]) or 1.0, # top_p must > 0
top_k=top_k if top_k is not None else self.generating_args["top_k"],
use_beam_search=use_beam_search, use_beam_search=use_beam_search,
length_penalty=length_penalty, length_penalty=length_penalty if length_penalty is not None else self.generating_args["length_penalty"],
stop=stop, stop=stop,
stop_token_ids=[self.tokenizer.eos_token_id] + self.tokenizer.additional_special_tokens_ids, stop_token_ids=[self.tokenizer.eos_token_id] + self.tokenizer.additional_special_tokens_ids,
max_tokens=max_tokens, max_tokens=max_tokens,
skip_special_tokens=True, skip_special_tokens=True,
) )
if self.processor is not None and image is not None:
image_processor: "BaseImageProcessor" = getattr(self.processor, "image_processor")
pixel_values: "torch.Tensor" = image_processor(image, return_tensors="pt")["pixel_values"]
multi_modal_data = MultiModalData(type=MultiModalData.Type.IMAGE, data=pixel_values)
else:
multi_modal_data = None
result_generator = self.model.generate( result_generator = self.model.generate(
prompt=None, inputs={"prompt_token_ids": prompt_ids, "multi_modal_data": multi_modal_data},
sampling_params=sampling_params, sampling_params=sampling_params,
request_id=request_id, request_id=request_id,
prompt_token_ids=prompt_ids,
lora_request=self.lora_request, lora_request=self.lora_request,
multi_modal_data=multi_modal_data,
) )
return result_generator return result_generator

View File

@@ -1,9 +1,16 @@
import os
import random
import subprocess
import sys import sys
from enum import Enum, unique from enum import Enum, unique
from . import launcher
from .api.app import run_api from .api.app import run_api
from .chat.chat_model import run_chat from .chat.chat_model import run_chat
from .eval.evaluator import run_eval from .eval.evaluator import run_eval
from .extras.env import VERSION, print_env
from .extras.logging import get_logger
from .extras.misc import get_device_count
from .train.tuner import export_model, run_exp from .train.tuner import export_model, run_exp
from .webui.interface import run_web_demo, run_web_ui from .webui.interface import run_web_demo, run_web_ui
@@ -23,8 +30,6 @@ USAGE = (
+ "-" * 70 + "-" * 70
) )
VERSION = "0.7.1"
WELCOME = ( WELCOME = (
"-" * 58 "-" * 58
+ "\n" + "\n"
@@ -37,11 +42,14 @@ WELCOME = (
+ "-" * 58 + "-" * 58
) )
logger = get_logger(__name__)
@unique @unique
class Command(str, Enum): class Command(str, Enum):
API = "api" API = "api"
CHAT = "chat" CHAT = "chat"
ENV = "env"
EVAL = "eval" EVAL = "eval"
EXPORT = "export" EXPORT = "export"
TRAIN = "train" TRAIN = "train"
@@ -57,11 +65,34 @@ def main():
run_api() run_api()
elif command == Command.CHAT: elif command == Command.CHAT:
run_chat() run_chat()
elif command == Command.ENV:
print_env()
elif command == Command.EVAL: elif command == Command.EVAL:
run_eval() run_eval()
elif command == Command.EXPORT: elif command == Command.EXPORT:
export_model() export_model()
elif command == Command.TRAIN: elif command == Command.TRAIN:
force_torchrun = os.environ.get("FORCE_TORCHRUN", "0").lower() in ["true", "1"]
if force_torchrun or get_device_count() > 1:
master_addr = os.environ.get("MASTER_ADDR", "127.0.0.1")
master_port = os.environ.get("MASTER_PORT", str(random.randint(20001, 29999)))
logger.info("Initializing distributed tasks at: {}:{}".format(master_addr, master_port))
subprocess.run(
(
"torchrun --nnodes {nnodes} --node_rank {node_rank} --nproc_per_node {nproc_per_node} "
"--master_addr {master_addr} --master_port {master_port} {file_name} {args}"
).format(
nnodes=os.environ.get("NNODES", "1"),
node_rank=os.environ.get("RANK", "0"),
nproc_per_node=os.environ.get("NPROC_PER_NODE", str(get_device_count())),
master_addr=master_addr,
master_port=master_port,
file_name=launcher.__file__,
args=" ".join(sys.argv[1:]),
),
shell=True,
)
else:
run_exp() run_exp()
elif command == Command.WEBDEMO: elif command == Command.WEBDEMO:
run_web_demo() run_web_demo()

View File

@@ -0,0 +1,16 @@
from .collator import KTODataCollatorWithPadding, PairwiseDataCollatorWithPadding
from .data_utils import Role, split_dataset
from .loader import get_dataset
from .template import TEMPLATES, Template, get_template_and_fix_tokenizer
__all__ = [
"KTODataCollatorWithPadding",
"PairwiseDataCollatorWithPadding",
"Role",
"split_dataset",
"get_dataset",
"TEMPLATES",
"Template",
"get_template_and_fix_tokenizer",
]

View File

@@ -4,7 +4,8 @@ from typing import TYPE_CHECKING, Any, Dict, List, Union
from datasets import Features from datasets import Features
from .utils import Role from ..extras.logging import get_logger
from .data_utils import Role
if TYPE_CHECKING: if TYPE_CHECKING:
@@ -14,7 +15,13 @@ if TYPE_CHECKING:
from .parser import DatasetAttr from .parser import DatasetAttr
logger = get_logger(__name__)
def _convert_images(images: List[Any], dataset_attr: "DatasetAttr", data_args: "DataArguments") -> List[Any]: def _convert_images(images: List[Any], dataset_attr: "DatasetAttr", data_args: "DataArguments") -> List[Any]:
r"""
Optionally concatenates image path to dataset dir when loading from local disk.
"""
outputs = [] outputs = []
if dataset_attr.load_from in ["script", "file"]: if dataset_attr.load_from in ["script", "file"]:
for image in images: for image in images:
@@ -29,6 +36,9 @@ def _convert_images(images: List[Any], dataset_attr: "DatasetAttr", data_args: "
def convert_alpaca( def convert_alpaca(
examples: Dict[str, List[Any]], dataset_attr: "DatasetAttr", data_args: "DataArguments" examples: Dict[str, List[Any]], dataset_attr: "DatasetAttr", data_args: "DataArguments"
) -> Dict[str, List[Any]]: ) -> Dict[str, List[Any]]:
r"""
Converts alpaca format dataset to the standard format.
"""
outputs = {"prompt": [], "response": [], "system": [], "tools": [], "images": []} outputs = {"prompt": [], "response": [], "system": [], "tools": [], "images": []}
convert_images = partial(_convert_images, dataset_attr=dataset_attr, data_args=data_args) convert_images = partial(_convert_images, dataset_attr=dataset_attr, data_args=data_args)
for i in range(len(examples[dataset_attr.prompt])): for i in range(len(examples[dataset_attr.prompt])):
@@ -45,21 +55,32 @@ def convert_alpaca(
if dataset_attr.query and examples[dataset_attr.query][i]: if dataset_attr.query and examples[dataset_attr.query][i]:
content.append(examples[dataset_attr.query][i]) content.append(examples[dataset_attr.query][i])
prompt.append({"role": Role.USER.value, "content": "\n".join(content)}) prompt.append({"role": Role.USER.value, "content": "\n".join(content)}) # "prompt\nquery"
if dataset_attr.response and isinstance(examples[dataset_attr.response][i], list): if dataset_attr.kto_tag and isinstance(examples[dataset_attr.kto_tag][i], bool): # kto example
response = [
{"role": Role.ASSISTANT.value, "content": content} for content in examples[dataset_attr.response][i]
]
elif dataset_attr.response and isinstance(examples[dataset_attr.response][i], str):
response = [{"role": Role.ASSISTANT.value, "content": examples[dataset_attr.response][i]}] response = [{"role": Role.ASSISTANT.value, "content": examples[dataset_attr.response][i]}]
if examples[dataset_attr.kto_tag][i]:
response = response + [{"role": Role.ASSISTANT.value, "content": ""}]
else: else:
response = [{"role": Role.ASSISTANT.value, "content": ""}] + response
elif (
dataset_attr.ranking
and isinstance(examples[dataset_attr.chosen][i], str)
and isinstance(examples[dataset_attr.rejected][i], str)
): # pairwise example
response = [
{"role": Role.ASSISTANT.value, "content": examples[dataset_attr.chosen][i]},
{"role": Role.ASSISTANT.value, "content": examples[dataset_attr.rejected][i]},
]
elif dataset_attr.response and isinstance(examples[dataset_attr.response][i], str): # normal example
response = [{"role": Role.ASSISTANT.value, "content": examples[dataset_attr.response][i]}]
else: # unsupervised
response = [] response = []
outputs["prompt"].append(prompt) outputs["prompt"].append(prompt)
outputs["response"].append(response) outputs["response"].append(response)
outputs["system"].append(examples[dataset_attr.system][i] if dataset_attr.system else "") outputs["system"].append(examples[dataset_attr.system][i] if dataset_attr.system else "")
outputs["tools"].append("") outputs["tools"].append(examples[dataset_attr.tools][i] if dataset_attr.tools else "")
outputs["images"].append(convert_images(examples[dataset_attr.images][i]) if dataset_attr.images else []) outputs["images"].append(convert_images(examples[dataset_attr.images][i]) if dataset_attr.images else [])
return outputs return outputs
@@ -68,6 +89,9 @@ def convert_alpaca(
def convert_sharegpt( def convert_sharegpt(
examples: Dict[str, List[Any]], dataset_attr: "DatasetAttr", data_args: "DataArguments" examples: Dict[str, List[Any]], dataset_attr: "DatasetAttr", data_args: "DataArguments"
) -> Dict[str, List[Any]]: ) -> Dict[str, List[Any]]:
r"""
Converts sharegpt format dataset to the standard format.
"""
outputs = {"prompt": [], "response": [], "system": [], "tools": [], "images": []} outputs = {"prompt": [], "response": [], "system": [], "tools": [], "images": []}
convert_images = partial(_convert_images, dataset_attr=dataset_attr, data_args=data_args) convert_images = partial(_convert_images, dataset_attr=dataset_attr, data_args=data_args)
tag_mapping = { tag_mapping = {
@@ -87,21 +111,62 @@ def convert_sharegpt(
else: else:
system = examples[dataset_attr.system][i] if dataset_attr.system else "" system = examples[dataset_attr.system][i] if dataset_attr.system else ""
messages = messages[: len(messages) // 2 * 2] # should be multiples of 2
if len(messages) == 0: if len(messages) == 0:
continue continue
aligned_messages = [] aligned_messages = []
broken_data = False
for turn_idx, message in enumerate(messages): for turn_idx, message in enumerate(messages):
if message[dataset_attr.role_tag] not in accept_tags[turn_idx % 2]: if message[dataset_attr.role_tag] not in accept_tags[turn_idx % 2]:
raise ValueError("Invalid role tag in {}.".format(messages)) logger.warning("Invalid role tag in {}.".format(messages))
broken_data = True
aligned_messages.append( aligned_messages.append(
{"role": tag_mapping[message[dataset_attr.role_tag]], "content": message[dataset_attr.content_tag]} {"role": tag_mapping[message[dataset_attr.role_tag]], "content": message[dataset_attr.content_tag]}
) )
outputs["prompt"].append(aligned_messages[:-1]) if (not dataset_attr.ranking and len(aligned_messages) % 2 != 0) or (
outputs["response"].append(aligned_messages[-1:]) dataset_attr.ranking and len(aligned_messages) % 2 == 0
):
logger.warning("Invalid message count in {}.".format(messages))
broken_data = True
if dataset_attr.kto_tag and isinstance(examples[dataset_attr.kto_tag][i], bool): # kto example
prompt = aligned_messages[:-1]
response = aligned_messages[-1:]
if examples[dataset_attr.kto_tag][i]:
response = response + [{"role": Role.ASSISTANT.value, "content": ""}]
else:
response = [{"role": Role.ASSISTANT.value, "content": ""}] + response
elif (
dataset_attr.ranking
and isinstance(examples[dataset_attr.chosen][i], dict)
and isinstance(examples[dataset_attr.rejected][i], dict)
): # pairwise example
chosen = examples[dataset_attr.chosen][i]
rejected = examples[dataset_attr.rejected][i]
if (
chosen[dataset_attr.role_tag] not in accept_tags[-1]
or rejected[dataset_attr.role_tag] not in accept_tags[-1]
):
logger.warning("Invalid role tag in {}.".format([chosen, rejected]))
broken_data = True
prompt = aligned_messages
response = [
{"role": tag_mapping[chosen[dataset_attr.role_tag]], "content": chosen[dataset_attr.content_tag]},
{"role": tag_mapping[rejected[dataset_attr.role_tag]], "content": rejected[dataset_attr.content_tag]},
]
else: # normal example
prompt = aligned_messages[:-1]
response = aligned_messages[-1:]
if broken_data:
logger.warning("Skipping this abnormal example.")
continue
outputs["prompt"].append(prompt)
outputs["response"].append(response)
outputs["system"].append(system) outputs["system"].append(system)
outputs["tools"].append(examples[dataset_attr.tools][i] if dataset_attr.tools else "") outputs["tools"].append(examples[dataset_attr.tools][i] if dataset_attr.tools else "")
outputs["images"].append(convert_images(examples[dataset_attr.images][i]) if dataset_attr.images else []) outputs["images"].append(convert_images(examples[dataset_attr.images][i]) if dataset_attr.images else [])

View File

@@ -0,0 +1,81 @@
from dataclasses import dataclass
from typing import Any, Dict, Sequence
import torch
from transformers import DataCollatorForSeq2Seq
@dataclass
class PairwiseDataCollatorWithPadding(DataCollatorForSeq2Seq):
    r"""
    Data collator for pairwise (chosen/rejected) preference data.
    """

    def __call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, torch.Tensor]:
        r"""
        Pads batched data to the longest sequence in the batch.

        Given n paired examples, emits a flat batch of 2 * n rows: rows 0..n-1
        come from the "chosen" fields and rows n..2n-1 from the "rejected" fields.
        """
        flattened = []
        for side in ("chosen", "rejected"):  # chosen rows first, then rejected
            for example in features:
                row = {
                    "input_ids": example["{}_input_ids".format(side)],
                    "attention_mask": example["{}_attention_mask".format(side)],
                    "labels": example["{}_labels".format(side)],
                }
                # Image features are shared between both sides of the pair.
                if "pixel_values" in example:
                    row["pixel_values"] = example["pixel_values"]

                token_type_key = "{}_token_type_ids".format(side)
                if token_type_key in example:
                    row["token_type_ids"] = example[token_type_key]

                flattened.append(row)

        # Delegate the actual padding to the parent seq2seq collator.
        return super().__call__(flattened)
@dataclass
class KTODataCollatorWithPadding(DataCollatorForSeq2Seq):
    r"""
    Data collator for KTO data.
    """

    def __call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, torch.Tensor]:
        r"""
        Pads the target sequences and the KL reference sequences separately,
        then merges the KL tensors back into the batch under ``kl_``-prefixed
        keys together with the per-example ``kto_tags`` tensor.
        """
        target_rows, kl_rows, tags = [], [], []
        for example in features:
            target_row = {
                "input_ids": example["input_ids"],
                "attention_mask": example["attention_mask"],
                "labels": example["labels"],
            }
            kl_row = {
                "input_ids": example["kl_input_ids"],
                "attention_mask": example["kl_attention_mask"],
                "labels": example["kl_labels"],
            }
            # Image features only apply to the target sequence.
            if "pixel_values" in example:
                target_row["pixel_values"] = example["pixel_values"]

            if "token_type_ids" in example:
                target_row["token_type_ids"] = example["token_type_ids"]
                kl_row["token_type_ids"] = example["kl_token_type_ids"]

            target_rows.append(target_row)
            kl_rows.append(kl_row)
            tags.append(example["kto_tags"])

        # Pad both views independently, then fold the KL batch into the target batch.
        batch = super().__call__(target_rows)
        kl_batch = super().__call__(kl_rows)
        batch["kl_input_ids"] = kl_batch["input_ids"]
        batch["kl_attention_mask"] = kl_batch["attention_mask"]
        batch["kl_labels"] = kl_batch["labels"]
        if "token_type_ids" in batch:
            batch["kl_token_type_ids"] = kl_batch["token_type_ids"]

        batch["kto_tags"] = torch.tensor(tags)
        return batch

View File

@@ -10,7 +10,7 @@ if TYPE_CHECKING:
from datasets import Dataset, IterableDataset from datasets import Dataset, IterableDataset
from transformers import Seq2SeqTrainingArguments from transformers import Seq2SeqTrainingArguments
from llmtuner.hparams import DataArguments from ..hparams import DataArguments
logger = get_logger(__name__) logger = get_logger(__name__)

View File

@@ -1,17 +1,19 @@
import inspect import inspect
import os import os
import sys
from typing import TYPE_CHECKING, Literal, Optional, Union from typing import TYPE_CHECKING, Literal, Optional, Union
import numpy as np
from datasets import load_dataset, load_from_disk from datasets import load_dataset, load_from_disk
from ..extras.constants import FILEEXT2TYPE from ..extras.constants import FILEEXT2TYPE
from ..extras.logging import get_logger from ..extras.logging import get_logger
from ..extras.misc import has_tokenized_data from ..extras.misc import has_tokenized_data
from .aligner import align_dataset from .aligner import align_dataset
from .data_utils import merge_dataset
from .parser import get_dataset_list from .parser import get_dataset_list
from .preprocess import get_preprocess_and_print_func from .preprocess import get_preprocess_and_print_func
from .template import get_template_and_fix_tokenizer from .template import get_template_and_fix_tokenizer
from .utils import merge_dataset
if TYPE_CHECKING: if TYPE_CHECKING:
@@ -57,12 +59,12 @@ def load_single_dataset(
data_files.append(local_path) data_files.append(local_path)
data_path = FILEEXT2TYPE.get(local_path.split(".")[-1], None) data_path = FILEEXT2TYPE.get(local_path.split(".")[-1], None)
else: else:
raise ValueError("File not found.") raise ValueError("File {} not found.".format(local_path))
if data_path is None: if data_path is None:
raise ValueError("File extension must be txt, csv, json or jsonl.") raise ValueError("Allowed file types: {}.".format(",".join(FILEEXT2TYPE.keys())))
else: else:
raise NotImplementedError raise NotImplementedError("Unknown load type: {}.".format(dataset_attr.load_from))
if dataset_attr.load_from == "ms_hub": if dataset_attr.load_from == "ms_hub":
try: try:
@@ -105,9 +107,21 @@ def load_single_dataset(
if data_args.streaming and (dataset_attr.load_from == "file"): # faster than specifying streaming=True if data_args.streaming and (dataset_attr.load_from == "file"): # faster than specifying streaming=True
dataset = dataset.to_iterable_dataset() # TODO: add num shards parameter dataset = dataset.to_iterable_dataset() # TODO: add num shards parameter
if dataset_attr.num_samples is not None and not data_args.streaming:
target_num = dataset_attr.num_samples
indexes = np.random.permutation(len(dataset))[:target_num]
target_num -= len(indexes)
if target_num > 0:
expand_indexes = np.random.choice(len(dataset), target_num)
indexes = np.concatenate((indexes, expand_indexes), axis=0)
assert len(indexes) == dataset_attr.num_samples, "Sample num mismatched."
dataset = dataset.select(indexes)
logger.info("Sampled {} examples from dataset {}.".format(dataset_attr.num_samples, dataset_attr))
if data_args.max_samples is not None: # truncate dataset if data_args.max_samples is not None: # truncate dataset
num_samples = min(data_args.max_samples, len(dataset)) max_samples = min(data_args.max_samples, len(dataset))
dataset = dataset.select(range(num_samples)) dataset = dataset.select(range(max_samples))
return align_dataset(dataset, dataset_attr, data_args) return align_dataset(dataset, dataset_attr, data_args)
@@ -116,7 +130,7 @@ def get_dataset(
model_args: "ModelArguments", model_args: "ModelArguments",
data_args: "DataArguments", data_args: "DataArguments",
training_args: "Seq2SeqTrainingArguments", training_args: "Seq2SeqTrainingArguments",
stage: Literal["pt", "sft", "rm", "ppo"], stage: Literal["pt", "sft", "rm", "ppo", "kto"],
tokenizer: "PreTrainedTokenizer", tokenizer: "PreTrainedTokenizer",
processor: Optional["ProcessorMixin"] = None, processor: Optional["ProcessorMixin"] = None,
) -> Union["Dataset", "IterableDataset"]: ) -> Union["Dataset", "IterableDataset"]:
@@ -165,14 +179,17 @@ def get_dataset(
if training_args.should_save: if training_args.should_save:
dataset.save_to_disk(data_args.tokenized_path) dataset.save_to_disk(data_args.tokenized_path)
logger.info("Tokenized dataset saved at {}.".format(data_args.tokenized_path)) logger.info("Tokenized dataset saved at {}.".format(data_args.tokenized_path))
logger.info("Please restart the training with `--tokenized_path {}`.".format(data_args.tokenized_path)) logger.info("Please restart the training with `tokenized_path: {}`.".format(data_args.tokenized_path))
exit(0) sys.exit(0)
if training_args.should_log: if training_args.should_log:
try: try:
print_function(next(iter(dataset))) print_function(next(iter(dataset)))
except StopIteration: except StopIteration:
if stage == "pt":
raise RuntimeError("Cannot find sufficient samples, consider increasing dataset size.")
else:
raise RuntimeError("Cannot find valid samples, check `data/README.md` for the data format.") raise RuntimeError("Cannot find valid samples, check `data/README.md` for the data format.")
return dataset return dataset

View File

@@ -20,23 +20,28 @@ class DatasetAttr:
""" basic configs """ """ basic configs """
load_from: Literal["hf_hub", "ms_hub", "script", "file"] load_from: Literal["hf_hub", "ms_hub", "script", "file"]
dataset_name: str dataset_name: str
formatting: Literal["alpaca", "sharegpt"] = "alpaca"
ranking: bool = False
""" extra configs """ """ extra configs """
subset: Optional[str] = None subset: Optional[str] = None
folder: Optional[str] = None folder: Optional[str] = None
ranking: bool = False num_samples: Optional[int] = None
formatting: Literal["alpaca", "sharegpt"] = "alpaca" """ common columns """
""" columns """
system: Optional[str] = None system: Optional[str] = None
tools: Optional[str] = None
images: Optional[str] = None images: Optional[str] = None
""" columns for the alpaca format """ """ rlhf columns """
chosen: Optional[str] = None
rejected: Optional[str] = None
kto_tag: Optional[str] = None
""" alpaca columns """
prompt: Optional[str] = "instruction" prompt: Optional[str] = "instruction"
query: Optional[str] = "input" query: Optional[str] = "input"
response: Optional[str] = "output" response: Optional[str] = "output"
history: Optional[str] = None history: Optional[str] = None
""" columns for the sharegpt format """ """ sharegpt columns """
messages: Optional[str] = "conversations" messages: Optional[str] = "conversations"
tools: Optional[str] = None """ sharegpt tags """
""" tags for the sharegpt format """
role_tag: Optional[str] = "from" role_tag: Optional[str] = "from"
content_tag: Optional[str] = "value" content_tag: Optional[str] = "value"
user_tag: Optional[str] = "human" user_tag: Optional[str] = "human"
@@ -98,17 +103,18 @@ def get_dataset_list(data_args: "DataArguments") -> List["DatasetAttr"]:
else: else:
dataset_attr = DatasetAttr("file", dataset_name=dataset_info[name]["file_name"]) dataset_attr = DatasetAttr("file", dataset_name=dataset_info[name]["file_name"])
dataset_attr.set_attr("formatting", dataset_info[name], default="alpaca")
dataset_attr.set_attr("ranking", dataset_info[name], default=False)
dataset_attr.set_attr("subset", dataset_info[name]) dataset_attr.set_attr("subset", dataset_info[name])
dataset_attr.set_attr("folder", dataset_info[name]) dataset_attr.set_attr("folder", dataset_info[name])
dataset_attr.set_attr("ranking", dataset_info[name], default=False) dataset_attr.set_attr("num_samples", dataset_info[name])
dataset_attr.set_attr("formatting", dataset_info[name], default="alpaca")
if "columns" in dataset_info[name]: if "columns" in dataset_info[name]:
column_names = ["system", "images"] column_names = ["system", "tools", "images", "chosen", "rejected", "kto_tag"]
if dataset_attr.formatting == "alpaca": if dataset_attr.formatting == "alpaca":
column_names.extend(["prompt", "query", "response", "history"]) column_names.extend(["prompt", "query", "response", "history"])
else: else:
column_names.extend(["messages", "tools"]) column_names.extend(["messages"])
for column_name in column_names: for column_name in column_names:
dataset_attr.set_attr(column_name, dataset_info[name]["columns"]) dataset_attr.set_attr(column_name, dataset_info[name]["columns"])

View File

@@ -0,0 +1,84 @@
from functools import partial
from typing import TYPE_CHECKING, Callable, Literal, Optional, Tuple
from .processors.feedback import preprocess_feedback_dataset
from .processors.pairwise import preprocess_pairwise_dataset, print_pairwise_dataset_example
from .processors.pretrain import preprocess_pretrain_dataset
from .processors.supervised import (
preprocess_packed_supervised_dataset,
preprocess_supervised_dataset,
print_supervised_dataset_example,
)
from .processors.unsupervised import preprocess_unsupervised_dataset, print_unsupervised_dataset_example
if TYPE_CHECKING:
from transformers import ProcessorMixin, Seq2SeqTrainingArguments
from transformers.tokenization_utils import PreTrainedTokenizer
from ..hparams import DataArguments
from .template import Template
def get_preprocess_and_print_func(
    data_args: "DataArguments",
    training_args: "Seq2SeqTrainingArguments",
    stage: Literal["pt", "sft", "rm", "ppo", "kto"],
    template: "Template",
    tokenizer: "PreTrainedTokenizer",
    processor: Optional["ProcessorMixin"],
) -> Tuple[Callable, Callable]:
    r"""
    Selects the dataset preprocessing function and the example-printing function for a stage.

    Returns a `(preprocess_func, print_function)` pair where both callables already have the
    template/tokenizer/processor/data arguments bound via `functools.partial`, so callers only
    pass the batched examples (and, for printing, one encoded example).
    """
    # arguments shared by every template-based processor
    shared_kwargs = {"template": template, "tokenizer": tokenizer, "processor": processor, "data_args": data_args}
    if stage == "pt":
        # pretraining does not use a chat template or a multimodal processor
        preprocess_func = partial(preprocess_pretrain_dataset, tokenizer=tokenizer, data_args=data_args)
        print_function = partial(print_unsupervised_dataset_example, tokenizer=tokenizer)
    elif stage == "sft" and not training_args.predict_with_generate:
        if data_args.packing:
            # the packed variant does not accept a processor (no multimodal support)
            preprocess_func = partial(
                preprocess_packed_supervised_dataset, template=template, tokenizer=tokenizer, data_args=data_args
            )
        else:
            preprocess_func = partial(preprocess_supervised_dataset, **shared_kwargs)
        print_function = partial(print_supervised_dataset_example, tokenizer=tokenizer)
    elif stage == "rm":
        preprocess_func = partial(preprocess_pairwise_dataset, **shared_kwargs)
        print_function = partial(print_pairwise_dataset_example, tokenizer=tokenizer)
    elif stage == "kto":
        preprocess_func = partial(preprocess_feedback_dataset, **shared_kwargs)
        print_function = partial(print_supervised_dataset_example, tokenizer=tokenizer)
    else:
        # sft with predict_with_generate, and ppo, fall back to unsupervised encoding
        preprocess_func = partial(preprocess_unsupervised_dataset, **shared_kwargs)
        print_function = partial(print_unsupervised_dataset_example, tokenizer=tokenizer)
    return preprocess_func, print_function

View File

@@ -0,0 +1,126 @@
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple
from ...extras.constants import IGNORE_INDEX
from ...extras.logging import get_logger
from .processor_utils import get_paligemma_token_type_ids, get_pixel_values
if TYPE_CHECKING:
from transformers import ProcessorMixin
from transformers.tokenization_utils import PreTrainedTokenizer
from ...hparams import DataArguments
from ..template import Template
logger = get_logger(__name__)
def _encode_feedback_example(
    prompt: Sequence[Dict[str, str]],
    response: Sequence[Dict[str, str]],
    kl_response: Sequence[Dict[str, str]],
    system: Optional[str],
    tools: Optional[str],
    template: "Template",
    tokenizer: "PreTrainedTokenizer",
    processor: Optional["ProcessorMixin"],
    data_args: "DataArguments",
) -> Tuple[List[int], List[int], List[int], List[int], bool]:
    r"""
    Encodes one KTO feedback example into token ids.

    `response` holds two candidate answers: non-empty content at index 0 marks the example
    as desired (`kto_tag=True`); otherwise index 1 is used and the example is undesired.
    `kl_response` is an unrelated response chosen by the caller, encoded against the same
    prompt to build the KL reference pair.

    Returns `(input_ids, labels, kl_input_ids, kl_labels, kto_tag)` where the prompt part
    of both label sequences is masked with IGNORE_INDEX.
    """
    if processor is not None and not hasattr(processor, "image_seq_length"):  # llava-like models
        # NOTE(review): mutates the caller's `prompt` in place by prepending the image token
        prompt[0]["content"] = template.image_token + prompt[0]["content"]
    if response[0]["content"]:  # desired example
        kto_tag = True
        messages = prompt + [response[0]]
    else:  # undesired example
        kto_tag = False
        messages = prompt + [response[1]]
    # the KL message uses whichever slot of the (unrelated) response is non-empty
    if kl_response[0]["content"]:
        kl_messages = prompt + [kl_response[0]]
    else:
        kl_messages = prompt + [kl_response[1]]
    prompt_ids, response_ids = template.encode_oneturn(
        tokenizer, messages, system, tools, data_args.cutoff_len, data_args.reserved_label_len
    )
    # re-encode with the KL response; the prompt ids from the first call are reused below
    _, kl_response_ids = template.encode_oneturn(
        tokenizer, kl_messages, system, tools, data_args.cutoff_len, data_args.reserved_label_len
    )
    if template.efficient_eos:
        response_ids += [tokenizer.eos_token_id]
        kl_response_ids += [tokenizer.eos_token_id]
    if processor is not None and hasattr(processor, "image_seq_length"):  # paligemma models
        # paligemma expects a run of image tokens in front of the text prompt
        image_token_id = tokenizer.convert_tokens_to_ids(template.image_token)
        prompt_ids = [image_token_id] * getattr(processor, "image_seq_length") + prompt_ids
    input_ids = prompt_ids + response_ids
    labels = [IGNORE_INDEX] * len(prompt_ids) + response_ids
    kl_input_ids = prompt_ids + kl_response_ids
    kl_labels = [IGNORE_INDEX] * len(prompt_ids) + kl_response_ids
    return input_ids, labels, kl_input_ids, kl_labels, kto_tag
def preprocess_feedback_dataset(
    examples: Dict[str, List[Any]],
    template: "Template",
    tokenizer: "PreTrainedTokenizer",
    processor: Optional["ProcessorMixin"],
    data_args: "DataArguments",
) -> Dict[str, List[List[int]]]:
    r"""
    Converts KTO feedback examples into model inputs plus mismatched KL pairs.

    Invalid examples (even number of prompt turns, or fewer than two responses) are
    dropped with a warning. Warns when every kept example has the same preference tag.
    """
    # create unrelated input-output pairs for estimating the KL term by flipping the matched pairs
    kl_response = examples["response"][::-1]
    feature_names = [
        "input_ids",
        "attention_mask",
        "labels",
        "kl_input_ids",
        "kl_attention_mask",
        "kl_labels",
        "kto_tags",
    ]
    is_multimodal = processor is not None
    is_paligemma = is_multimodal and hasattr(processor, "image_seq_length")
    if is_multimodal:
        feature_names.append("pixel_values")
    if is_paligemma:
        feature_names.extend(["token_type_ids", "kl_token_type_ids"])
    model_inputs = {name: [] for name in feature_names}
    for idx in range(len(examples["prompt"])):
        prompt, response = examples["prompt"][idx], examples["response"][idx]
        # a valid example alternates user/assistant turns and provides two response slots
        if len(prompt) % 2 != 1 or len(response) < 2:
            logger.warning("Dropped invalid example: {}".format(prompt + response))
            continue
        input_ids, labels, kl_input_ids, kl_labels, kto_tag = _encode_feedback_example(
            prompt=prompt,
            response=response,
            kl_response=kl_response[idx],
            system=examples["system"][idx],
            tools=examples["tools"][idx],
            template=template,
            tokenizer=tokenizer,
            processor=processor,
            data_args=data_args,
        )
        model_inputs["input_ids"].append(input_ids)
        model_inputs["attention_mask"].append([1] * len(input_ids))
        model_inputs["labels"].append(labels)
        model_inputs["kl_input_ids"].append(kl_input_ids)
        model_inputs["kl_attention_mask"].append([1] * len(kl_input_ids))
        model_inputs["kl_labels"].append(kl_labels)
        model_inputs["kto_tags"].append(kto_tag)
        if is_multimodal:
            model_inputs["pixel_values"].append(get_pixel_values(examples["images"][idx], processor))
        if is_paligemma:
            model_inputs["token_type_ids"].append(get_paligemma_token_type_ids(len(input_ids), processor))
            model_inputs["kl_token_type_ids"].append(get_paligemma_token_type_ids(len(kl_input_ids), processor))
    desirable_num = sum(1 for tag in model_inputs["kto_tags"] if tag)
    undesirable_num = len(model_inputs["kto_tags"]) - desirable_num
    if desirable_num == 0 or undesirable_num == 0:
        logger.warning("Your dataset only has one preference type.")
    return model_inputs

View File

@@ -0,0 +1,123 @@
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple
from ...extras.constants import IGNORE_INDEX
from ...extras.logging import get_logger
from .processor_utils import get_paligemma_token_type_ids, get_pixel_values
if TYPE_CHECKING:
from transformers import ProcessorMixin
from transformers.tokenization_utils import PreTrainedTokenizer
from ...hparams import DataArguments
from ..template import Template
logger = get_logger(__name__)
def _encode_pairwise_example(
    prompt: Sequence[Dict[str, str]],
    response: Sequence[Dict[str, str]],
    system: Optional[str],
    tools: Optional[str],
    template: "Template",
    tokenizer: "PreTrainedTokenizer",
    processor: Optional["ProcessorMixin"],
    data_args: "DataArguments",
) -> Tuple[List[int], List[int], List[int], List[int]]:
    r"""
    Encodes one preference pair for reward modeling.

    `response[0]` is the chosen answer and `response[1]` the rejected one; both share the
    same prompt. Returns `(chosen_input_ids, chosen_labels, rejected_input_ids,
    rejected_labels)` with the prompt part of each label sequence masked by IGNORE_INDEX.
    """
    if processor is not None and not hasattr(processor, "image_seq_length"):  # llava-like models
        # NOTE(review): mutates the caller's `prompt` in place by prepending the image token
        prompt[0]["content"] = template.image_token + prompt[0]["content"]
    chosen_messages = prompt + [response[0]]
    rejected_messages = prompt + [response[1]]
    prompt_ids, chosen_ids = template.encode_oneturn(
        tokenizer, chosen_messages, system, tools, data_args.cutoff_len, data_args.reserved_label_len
    )
    # re-encode with the rejected answer; the shared prompt ids from the first call are reused
    _, rejected_ids = template.encode_oneturn(
        tokenizer, rejected_messages, system, tools, data_args.cutoff_len, data_args.reserved_label_len
    )
    if template.efficient_eos:
        chosen_ids += [tokenizer.eos_token_id]
        rejected_ids += [tokenizer.eos_token_id]
    if processor is not None and hasattr(processor, "image_seq_length"):  # paligemma models
        # paligemma expects a run of image tokens in front of the text prompt
        image_token_id = tokenizer.convert_tokens_to_ids(template.image_token)
        prompt_ids = [image_token_id] * getattr(processor, "image_seq_length") + prompt_ids
    chosen_input_ids = prompt_ids + chosen_ids
    chosen_labels = [IGNORE_INDEX] * len(prompt_ids) + chosen_ids
    rejected_input_ids = prompt_ids + rejected_ids
    rejected_labels = [IGNORE_INDEX] * len(prompt_ids) + rejected_ids
    return chosen_input_ids, chosen_labels, rejected_input_ids, rejected_labels
def preprocess_pairwise_dataset(
    examples: Dict[str, List[Any]],
    template: "Template",
    tokenizer: "PreTrainedTokenizer",
    processor: Optional["ProcessorMixin"],
    data_args: "DataArguments",
) -> Dict[str, List[List[int]]]:
    r"""
    Tokenizes preference pairs for reward modeling.

    Builds input pairs with format `<bos> X`, `Y1 <eos>` and `Y2 <eos>`; invalid examples
    (even number of prompt turns, or fewer than two answers) are dropped with a warning.
    """
    is_multimodal = processor is not None
    is_paligemma = is_multimodal and hasattr(processor, "image_seq_length")
    model_inputs = {
        "chosen_input_ids": [],
        "chosen_attention_mask": [],
        "chosen_labels": [],
        "rejected_input_ids": [],
        "rejected_attention_mask": [],
        "rejected_labels": [],
    }
    if is_multimodal:
        model_inputs["pixel_values"] = []
    if is_paligemma:
        model_inputs["chosen_token_type_ids"] = []
        model_inputs["rejected_token_type_ids"] = []
    for idx in range(len(examples["prompt"])):
        prompt, response = examples["prompt"][idx], examples["response"][idx]
        if len(prompt) % 2 != 1 or len(response) < 2:
            logger.warning("Dropped invalid example: {}".format(prompt + response))
            continue
        chosen_input_ids, chosen_labels, rejected_input_ids, rejected_labels = _encode_pairwise_example(
            prompt=prompt,
            response=response,
            system=examples["system"][idx],
            tools=examples["tools"][idx],
            template=template,
            tokenizer=tokenizer,
            processor=processor,
            data_args=data_args,
        )
        for key, value in (
            ("chosen_input_ids", chosen_input_ids),
            ("chosen_attention_mask", [1] * len(chosen_input_ids)),
            ("chosen_labels", chosen_labels),
            ("rejected_input_ids", rejected_input_ids),
            ("rejected_attention_mask", [1] * len(rejected_input_ids)),
            ("rejected_labels", rejected_labels),
        ):
            model_inputs[key].append(value)
        if is_multimodal:
            model_inputs["pixel_values"].append(get_pixel_values(examples["images"][idx], processor))
        if is_paligemma:
            model_inputs["chosen_token_type_ids"].append(
                get_paligemma_token_type_ids(len(chosen_input_ids), processor)
            )
            model_inputs["rejected_token_type_ids"].append(
                get_paligemma_token_type_ids(len(rejected_input_ids), processor)
            )
    return model_inputs
def print_pairwise_dataset_example(example: Dict[str, List[int]], tokenizer: "PreTrainedTokenizer") -> None:
    r"""
    Prints one encoded chosen/rejected pair in human-readable form for inspection.

    Label sequences are decoded after stripping IGNORE_INDEX placeholders.
    """
    valid_chosen_labels = [label_id for label_id in example["chosen_labels"] if label_id != IGNORE_INDEX]
    valid_rejected_labels = [label_id for label_id in example["rejected_labels"] if label_id != IGNORE_INDEX]
    print("chosen_input_ids:\n{}".format(example["chosen_input_ids"]))
    print("chosen_inputs:\n{}".format(tokenizer.decode(example["chosen_input_ids"], skip_special_tokens=False)))
    print("chosen_label_ids:\n{}".format(example["chosen_labels"]))
    print("chosen_labels:\n{}".format(tokenizer.decode(valid_chosen_labels, skip_special_tokens=False)))
    print("rejected_input_ids:\n{}".format(example["rejected_input_ids"]))
    print("rejected_inputs:\n{}".format(tokenizer.decode(example["rejected_input_ids"], skip_special_tokens=False)))
    print("rejected_label_ids:\n{}".format(example["rejected_labels"]))
    print("rejected_labels:\n{}".format(tokenizer.decode(valid_rejected_labels, skip_special_tokens=False)))

View File

@@ -0,0 +1,36 @@
from itertools import chain
from typing import TYPE_CHECKING, Any, Dict, List
if TYPE_CHECKING:
from transformers.tokenization_utils import PreTrainedTokenizer
from ...hparams import DataArguments
def preprocess_pretrain_dataset(
    examples: Dict[str, List[Any]], tokenizer: "PreTrainedTokenizer", data_args: "DataArguments"
) -> Dict[str, List[List[int]]]:
    r"""
    Tokenizes plain-text pretraining examples.

    With packing enabled, texts are tokenized, concatenated, and split into blocks of
    `data_args.cutoff_len` tokens (format `X1 X2 X3 ...`), dropping the remainder; without
    packing, each text is tokenized and truncated independently. For the gemma template a
    bos token leads every sequence (replacing the first token of each packed block).
    """
    eos_token = tokenizer.eos_token
    text_examples = [messages[0]["content"] + eos_token for messages in examples["prompt"]]
    if data_args.packing:
        tokenized = tokenizer(text_examples, add_special_tokens=False)
        concatenated = {key: list(chain(*tokenized[key])) for key in tokenized.keys()}
        block_size = data_args.cutoff_len
        total_length = len(next(iter(concatenated.values())))
        total_length -= total_length % block_size  # drop the tail shorter than one block
        result = {
            key: [sequence[start : start + block_size] for start in range(0, total_length, block_size)]
            for key, sequence in concatenated.items()
        }
        if data_args.template == "gemma":
            for block in result["input_ids"]:
                block[0] = tokenizer.bos_token_id
    else:
        if data_args.template == "gemma":
            text_examples = [tokenizer.bos_token + text for text in text_examples]
        result = tokenizer(text_examples, add_special_tokens=False, max_length=data_args.cutoff_len, truncation=True)
    return result

View File

@@ -0,0 +1,64 @@
import bisect
from typing import TYPE_CHECKING, List, Sequence
from ...extras.packages import is_pillow_available
if is_pillow_available():
from PIL import Image
if TYPE_CHECKING:
from numpy.typing import NDArray
from PIL.Image import Image as ImageObject
from transformers import ProcessorMixin
from transformers.image_processing_utils import BaseImageProcessor
def search_for_fit(numbers: Sequence[int], capacity: int) -> int:
    r"""
    Finds the index of the largest number that fits into the knapsack with the given capacity.

    `numbers` must be sorted in ascending order. Returns -1 when nothing fits.
    """
    insertion_point = bisect.bisect(numbers, capacity)
    if insertion_point == 0:
        return -1
    return insertion_point - 1
def greedy_knapsack(numbers: List[int], capacity: int) -> List[List[int]]:
    r"""
    An efficient greedy algorithm with binary search for the knapsack problem.

    Packs `numbers` into knapsacks of size `capacity` by repeatedly taking the largest
    remaining number that still fits. Mutates `numbers`: it is sorted in place and emptied
    by the end of the call.

    Args:
        numbers: item sizes (consumed by this call).
        capacity: maximum total size per knapsack.

    Returns:
        A list of knapsacks, each a list of item sizes in the order they were packed.
    """
    numbers.sort()  # sort numbers in ascending order for binary search
    knapsacks = []
    while numbers:
        current_knapsack = []
        remaining_capacity = capacity
        while True:
            # largest number that still fits (mirrors `search_for_fit`): rightmost
            # insertion point minus one, or -1 when nothing fits
            index = bisect.bisect(numbers, remaining_capacity) - 1
            if index < 0:
                break  # no more numbers fit in this knapsack
            remaining_capacity -= numbers[index]  # update the remaining capacity
            current_knapsack.append(numbers.pop(index))  # add the number to knapsack
        if not current_knapsack:
            # Every remaining number exceeds the capacity, so nothing was packed. The old
            # code looped forever here; emit the smallest oversized item alone instead so
            # the outer loop always terminates.
            current_knapsack.append(numbers.pop(0))
        knapsacks.append(current_knapsack)
    return knapsacks
def get_pixel_values(images: Sequence["ImageObject"], processor: "ProcessorMixin") -> "NDArray":
    r"""
    Processes visual inputs. (currently only supports a single image)

    Returns the pixel tensor of the first image, shape (C, H, W); an all-white 100x100 RGB
    placeholder is processed instead when no image is supplied.
    """
    image_processor: "BaseImageProcessor" = getattr(processor, "image_processor")
    if len(images) != 0:
        image = images[0]
    else:
        # keep multimodal batches well-formed even for text-only examples
        image = Image.new("RGB", (100, 100), (255, 255, 255))
    return image_processor(image, return_tensors="pt")["pixel_values"][0]
def get_paligemma_token_type_ids(input_len: int, processor: "ProcessorMixin") -> List[int]:
    r"""
    Gets paligemma token type ids for computing loss.

    Image positions (the leading `image_seq_length` slots) get type 0; the remaining text
    positions get type 1.
    """
    num_image_tokens = getattr(processor, "image_seq_length")
    num_text_tokens = input_len - num_image_tokens
    return [0] * num_image_tokens + [1] * num_text_tokens

View File

@@ -0,0 +1,169 @@
from collections import defaultdict
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple
from ...extras.constants import IGNORE_INDEX
from ...extras.logging import get_logger
from .processor_utils import get_paligemma_token_type_ids, get_pixel_values, greedy_knapsack
if TYPE_CHECKING:
from transformers import ProcessorMixin
from transformers.tokenization_utils import PreTrainedTokenizer
from ...hparams import DataArguments
from ..template import Template
logger = get_logger(__name__)
def _encode_supervised_example(
    prompt: Sequence[Dict[str, str]],
    response: Sequence[Dict[str, str]],
    system: Optional[str],
    tools: Optional[str],
    template: "Template",
    tokenizer: "PreTrainedTokenizer",
    processor: Optional["ProcessorMixin"],
    data_args: "DataArguments",
) -> Tuple[List[int], List[int]]:
    r"""
    Encodes one (possibly multi-turn) supervised example into `(input_ids, labels)`.

    Each turn's source tokens are masked with IGNORE_INDEX in the labels unless
    `data_args.train_on_prompt` is set. For multimodal examples, the image token
    (llava-like) is prepended to the prompt text, or a run of image tokens (paligemma)
    is prepended to the id sequences.
    """
    if processor is not None and not hasattr(processor, "image_seq_length"):  # llava-like models
        # NOTE(review): mutates the caller's `prompt` in place by prepending the image token
        prompt[0]["content"] = template.image_token + prompt[0]["content"]
    messages = prompt + response
    input_ids, labels = [], []
    if processor is not None and hasattr(processor, "image_seq_length"):  # paligemma models
        image_token_id = tokenizer.convert_tokens_to_ids(template.image_token)
        input_ids += [image_token_id] * getattr(processor, "image_seq_length")
        labels += [IGNORE_INDEX] * getattr(processor, "image_seq_length")
    encoded_pairs = template.encode_multiturn(
        tokenizer, messages, system, tools, data_args.cutoff_len, data_args.reserved_label_len
    )
    for turn_idx, (source_ids, target_ids) in enumerate(encoded_pairs):
        if data_args.train_on_prompt:
            source_mask = source_ids  # train on prompt tokens as well
        elif turn_idx != 0 and template.efficient_eos:
            # efficient_eos templates: keep an eos label at the boundary of the previous turn,
            # mask the rest of this turn's source tokens
            source_mask = [tokenizer.eos_token_id] + [IGNORE_INDEX] * (len(source_ids) - 1)
        else:
            source_mask = [IGNORE_INDEX] * len(source_ids)
        input_ids += source_ids + target_ids
        labels += source_mask + target_ids
    if template.efficient_eos:
        # close the final turn with an explicit eos in both ids and labels
        input_ids += [tokenizer.eos_token_id]
        labels += [tokenizer.eos_token_id]
    return input_ids, labels
def preprocess_supervised_dataset(
    examples: Dict[str, List[Any]],
    template: "Template",
    tokenizer: "PreTrainedTokenizer",
    processor: Optional["ProcessorMixin"],
    data_args: "DataArguments",
) -> Dict[str, List[List[int]]]:
    r"""
    Tokenizes supervised fine-tuning examples one by one (no packing).

    Builds inputs with format `<bos> X Y <eos>` and labels with format
    `<ignore> ... <ignore> Y <eos>`; for multi-turn examples only the prompt part of each
    prompt-response pair is masked. Invalid examples are dropped with a warning.
    """
    is_multimodal = processor is not None
    is_paligemma = is_multimodal and hasattr(processor, "image_seq_length")
    model_inputs = {"input_ids": [], "attention_mask": [], "labels": []}
    if is_multimodal:
        model_inputs["pixel_values"] = []
    if is_paligemma:
        model_inputs["token_type_ids"] = []
    for idx in range(len(examples["prompt"])):
        prompt, response = examples["prompt"][idx], examples["response"][idx]
        # a well-formed example alternates user/assistant turns and has exactly one answer
        if len(prompt) % 2 != 1 or len(response) != 1:
            logger.warning("Dropped invalid example: {}".format(prompt + response))
            continue
        input_ids, labels = _encode_supervised_example(
            prompt=prompt,
            response=response,
            system=examples["system"][idx],
            tools=examples["tools"][idx],
            template=template,
            tokenizer=tokenizer,
            processor=processor,
            data_args=data_args,
        )
        model_inputs["input_ids"].append(input_ids)
        model_inputs["attention_mask"].append([1] * len(input_ids))
        model_inputs["labels"].append(labels)
        if is_multimodal:
            model_inputs["pixel_values"].append(get_pixel_values(examples["images"][idx], processor))
        if is_paligemma:
            model_inputs["token_type_ids"].append(get_paligemma_token_type_ids(len(input_ids), processor))
    return model_inputs
def preprocess_packed_supervised_dataset(
    examples: Dict[str, List[Any]],
    template: "Template",
    tokenizer: "PreTrainedTokenizer",
    data_args: "DataArguments",
) -> Dict[str, List[List[int]]]:
    r"""
    Tokenizes supervised examples and packs them into fixed-length sequences.

    Builds inputs with format `<bos> X1 Y1 <eos> <bos> X2 Y2 <eos>` and labels with format
    `<ignore> ... <ignore> Y1 <eos> <ignore> ... <ignore> Y2 <eos>`, grouping examples with
    a greedy knapsack so every packed row is exactly `data_args.cutoff_len` tokens long
    (right-padded when the knapsack does not fill up). Examples longer than the cutoff,
    and malformed examples, are dropped with a warning.
    """
    valid_num = 0
    batch_input_ids, batch_labels = [], []
    lengths = []
    # maps an encoded length to the example indexes having that length, so knapsack
    # results (lists of lengths) can be mapped back to concrete examples
    length2indexes = defaultdict(list)
    for i in range(len(examples["prompt"])):
        # a well-formed example alternates user/assistant turns and has exactly one answer
        if len(examples["prompt"][i]) % 2 != 1 or len(examples["response"][i]) != 1:
            logger.warning("Dropped invalid example: {}".format(examples["prompt"][i] + examples["response"][i]))
            continue
        input_ids, labels = _encode_supervised_example(
            prompt=examples["prompt"][i],
            response=examples["response"][i],
            system=examples["system"][i],
            tools=examples["tools"][i],
            template=template,
            tokenizer=tokenizer,
            processor=None,  # the packed path takes no multimodal processor
            data_args=data_args,
        )
        length = len(input_ids)
        if length > data_args.cutoff_len:
            logger.warning("Dropped lengthy example with length {} > {}.".format(length, data_args.cutoff_len))
        else:
            lengths.append(length)
            length2indexes[length].append(valid_num)
            batch_input_ids.append(input_ids)
            batch_labels.append(labels)
            valid_num += 1
    model_inputs = {"input_ids": [], "attention_mask": [], "labels": []}
    # NOTE: greedy_knapsack sorts and consumes `lengths` in place
    knapsacks = greedy_knapsack(lengths, data_args.cutoff_len)
    for knapsack in knapsacks:
        packed_input_ids, packed_labels = [], []
        for length in knapsack:
            index = length2indexes[length].pop()
            packed_input_ids += batch_input_ids[index]
            packed_labels += batch_labels[index]
        if len(packed_input_ids) < data_args.cutoff_len:
            # right-pad to the cutoff length; padded label positions are ignored by the loss
            pad_length = data_args.cutoff_len - len(packed_input_ids)
            packed_input_ids += [tokenizer.pad_token_id] * pad_length
            packed_labels += [IGNORE_INDEX] * pad_length
        if len(packed_input_ids) != data_args.cutoff_len:
            raise ValueError("The length of packed example should be identical to the cutoff length.")
        model_inputs["input_ids"].append(packed_input_ids)
        # note: the mask is all ones, so padded positions are attended to as well
        model_inputs["attention_mask"].append([1] * data_args.cutoff_len)
        model_inputs["labels"].append(packed_labels)
    return model_inputs
def print_supervised_dataset_example(example: Dict[str, List[int]], tokenizer: "PreTrainedTokenizer") -> None:
    r"""
    Prints one encoded supervised example in human-readable form for inspection.

    Labels are decoded after stripping IGNORE_INDEX placeholders.
    """
    valid_labels = [label_id for label_id in example["labels"] if label_id != IGNORE_INDEX]
    print("input_ids:\n{}".format(example["input_ids"]))
    print("inputs:\n{}".format(tokenizer.decode(example["input_ids"], skip_special_tokens=False)))
    print("label_ids:\n{}".format(example["labels"]))
    print("labels:\n{}".format(tokenizer.decode(valid_labels, skip_special_tokens=False)))

Some files were not shown because too many files have changed in this diff Show More