418 Commits

Author SHA1 Message Date
hoshi-hiyouga
57354fc990 Merge pull request #6124 from hiyouga/hiyouga/release
[release] release v0.9.1

Former-commit-id: f61cdd99fd282612884c92d36e111ad46b4e0d00
2024-11-25 00:20:02 +08:00
hoshi-hiyouga
89f240805c Merge pull request #6126 from hiyouga/hiyouga/fix_vllm
[inference] fix vllm

Former-commit-id: c5025c3ee6e67e62724cc3f34fbf8aa9968590f5
2024-11-25 00:19:54 +08:00
hoshi-hiyouga
27bbea886c Merge pull request #6010 from XYZliang/fix-#4316
Increase shm_size to 16GB in docker-compose.yml

Former-commit-id: 73194233f9f1aa8299be1360deb25b753338e168
2024-11-25 00:16:42 +08:00
hoshi-hiyouga
3ec3dda33a Merge pull request #6125 from hiyouga/hiyouga/fix_cli
[cli] remove shell=True in cli

Former-commit-id: cf3ec28baa9a9f1ba342fe3a627e85d8799a1912
2024-11-25 00:07:35 +08:00
hiyouga
ae9f338bf7 fix vllm
Former-commit-id: 9ce0e4b07e3733c015137bc93c7e6d53bf25b08e
2024-11-25 00:07:24 +08:00
hiyouga
bf44f76dc7 fix cli
Former-commit-id: 9338c287cc15c0cad8d5ddbdadfb6f64d383c034
2024-11-24 23:56:21 +08:00
hiyouga
c18581f0a4 release v0.9.1
Former-commit-id: a134ad42c65dc4d72e3083c932ddfaaa687c513d
2024-11-24 23:48:41 +08:00
hoshi-hiyouga
9f6c5c4798 Merge pull request #6123 from hiyouga/hiyouga/fix_qwen2vl_vllm
[inference] fix qwen2vl vllm infer

Former-commit-id: 5d886f99e3bd20795d5313dccf9f045d37a0aefc
2024-11-24 23:42:11 +08:00
hiyouga
7bc03ac986 fix qwen2vl vllm infer
Former-commit-id: 3ac98847fdc23129912c8994ed19a8c66fe00b8c
2024-11-24 23:27:24 +08:00
hoshi-hiyouga
85d7e4f4ab Merge pull request #6121 from hiyouga/hiyouga/readme
[readme] update readme

Former-commit-id: d603650a671c3a323f29001fd0cc53563d28f3e0
2024-11-24 03:28:09 +08:00
hiyouga
bf69747f40 update readme
Former-commit-id: 48423afe53d6f6de1a257a33019909009626a42e
2024-11-23 19:27:18 +00:00
hoshi-hiyouga
f1146bf7b6 Merge pull request #6120 from hiyouga/hiyouga/fix_ci
[test] fix ci

Former-commit-id: 573a0978b82986ec45aae16637edb6ff4af54a35
2024-11-24 03:21:11 +08:00
hiyouga
9efd1fec90 fix ci
Former-commit-id: 91c672f0147bb6eb998871a42f8a89992af88528
2024-11-23 19:13:32 +00:00
hoshi-hiyouga
3b91839a55 Merge pull request #5555 from marko1616/feat/llama3.2vl
Support llama3.2 vision

Former-commit-id: 8151dc488585d1cec6d4a0c9c6dcd46a6a57e9f0
2024-11-24 02:49:07 +08:00
hiyouga
bc4421eeef add forbidden modules
Former-commit-id: c9f4d051d0eca7515bab201afdef17f1ac1b3cb9
2024-11-23 18:34:15 +00:00
hiyouga
5003820a6a fix inputs
Former-commit-id: 7d535bb8cdf7e81edda81152e63c8cfe6c9dcc9f
2024-11-23 18:26:02 +00:00
marko1616
cd2485f28d Linter.
Former-commit-id: 719d124f65ebb18ba0a1212751da9909160fb6f1
2024-11-23 16:09:04 +00:00
marko1616
918a367378 Tiny fix.
Former-commit-id: 4c1cef12d812832eed58b5da562ba083104756d3
2024-11-23 16:09:01 +00:00
marko1616
3d35aeca72 Support llama3.2vl.
Former-commit-id: 664229d7d1f7994e1ae68c5d197ab81f081bcd2e
2024-11-23 16:07:35 +00:00
hoshi-hiyouga
53b1e5fd1d Merge commit from fork
[patch] Patch remote OS command injection vulnerability

Former-commit-id: 960897b950e29aa440afa45b4deb9d42d2f6e941
2024-11-21 22:39:44 +08:00
hoshi-hiyouga
b852c895cf do not split save_cmd ret value
Former-commit-id: 1e312072fb4a9f472e2d3fa7e6b4fb0aec00b566
2024-11-21 22:30:23 +08:00
superboy-zjc
aaa7ed8712 [patch] Patch remote OS command injection vulnerability
Former-commit-id: 4678ceea4ce334a8289caf87d86047e67c67c603
2024-11-21 01:52:12 -05:00
hoshi-hiyouga
205aca5b03 Merge pull request #6078 from wtmlon/support-efficient-tokens-calculation
support effective tokens calculation on sft/dpo

Former-commit-id: d0510e6d49b43c5ffadd8af653c3bdecc1582417
2024-11-20 13:43:15 +08:00
Ting
87b1f851f1 code refactor
Former-commit-id: ee3f85aa9677d0aeecb3bc396530d2cd7c50dce5
2024-11-19 20:33:18 +08:00
Ting
fca814b30d update
Former-commit-id: 516ed0ea5fed8c74fe3669a7e85dd89b5a0ec3c2
2024-11-19 19:12:10 +08:00
Ting
a20c2b6ecf update
Former-commit-id: a3e8ca53e654136242197a2da872cc0e5cf67880
2024-11-19 19:10:07 +08:00
Ting
fee94e1c54 support efficient tokens calculation on sft/dpo
Former-commit-id: b157d5cccdeb42412b8b440d25d5bdfa8a50be68
2024-11-19 17:15:47 +08:00
hoshi-hiyouga
047a596542 Merge pull request #6065 from hiyouga/hiyouga-patch-1
[misc] fix dep package version

Former-commit-id: 34a09e6cd1a8b1c2acddf837f1c787978bc526f5
2024-11-18 21:13:59 +08:00
hoshi-hiyouga
3d45606984 fix #6061
Former-commit-id: 4eb0b6763f0a1b3cde89bd5c69760178bb35d303
2024-11-18 20:56:44 +08:00
hoshi-hiyouga
310c107d56 Merge pull request #6052 from hiyouga/hiyouga-patch-1
[trainer] fix DPO metrics

Former-commit-id: 94add263fe874d2be1b37110faf5da7a5096df6d
2024-11-16 16:20:12 +08:00
hoshi-hiyouga
089e4d9e96 fix #6050
Former-commit-id: 028ea3d9b4fa4ab74a969ac80e61a449d6c15e74
2024-11-16 16:11:16 +08:00
hoshi-hiyouga
ae56c3cf49 Merge pull request #6046 from hiyouga/hiyouga/add_code_model
[model] add qwen-coder and opencoder

Former-commit-id: 5b485671aee8dd2f775371d0b9ff3d0d043159f3
2024-11-15 21:58:03 +08:00
hiyouga
0a0288a286 add qwen-coder and opencoder
Former-commit-id: 9669a42704cd40bdfc76ca278cc6a562549bc27d
2024-11-15 21:48:38 +08:00
XYZliang
25da686758 Increase shm_size to 16GB in docker-compose.yml to optimize shared memory allocation for large-scale model fine-tuning tasks.
This pull request increases the shm_size parameter in docker-compose.yml to 16GB. The goal is to enhance the LLaMA-Factory framework’s performance for large model fine-tuning tasks by providing sufficient shared memory for efficient data loading and parallel processing.

This PR also addresses the issues discussed in [this comment](https://github.com/hiyouga/LLaMA-Factory/issues/4316#issuecomment-2466270708) regarding the shared memory limit error.


Former-commit-id: de2616d103b4bdc2458874068b1a223c7de82b4e
2024-11-13 10:13:59 +08:00
hoshi-hiyouga
e2da3cc9fa Merge pull request #5990 from hiyouga/hiyouga/dev_vllm
[generate] fix vllm config args

Former-commit-id: ee0745022bd7484f4f2e6b183088f55d5e60c085
2024-11-11 14:10:35 +08:00
hoshi-hiyouga
c42e5cf401 fix #5988
Former-commit-id: 9e08e206a8ea9926768b0f1d5ff9d7e3e216c269
2024-11-11 13:57:14 +08:00
hoshi-hiyouga
9943cd1c96 Merge pull request #5982 from hiyouga/hiyouga/vllm_args
[args] add vllm config

Former-commit-id: 07d3de5c8376d3c4147411ec603da4254885d2d7
2024-11-10 21:37:18 +08:00
hiyouga
1e6f96508a add vllm config
Former-commit-id: 95365f0ce4f362bde7de8b679b54b548d7055bfb
2024-11-10 21:28:18 +08:00
hoshi-hiyouga
d401974f69 Merge pull request #5973 from JJJJerry/fix_vllm_generate
fix VllmEngine: replace the inputs argument with prompt

Former-commit-id: d3271416a316e6b92aea3026f6941f6967215a7b
2024-11-10 21:04:38 +08:00
hoshi-hiyouga
09b2dbe859 Update vllm_engine.py
Former-commit-id: 5638fae81c180b7d91eb6aebe6629640beb217d8
2024-11-10 20:57:00 +08:00
JJJJerry
7f8ef8c132 fix VllmEngine: replace the inputs argument with prompt
Former-commit-id: 5affb1d20921afd3fe48802ff80785e412e2e3aa
2024-11-09 11:45:59 +08:00
hoshi-hiyouga
fcb6283a72 Merge pull request #5971 from hiyouga/hiyouga/fix_webui
[webui] fix extra args

Former-commit-id: d04e21d69e60ab4a350e70da7d1abbf11cfeed0e
2024-11-09 00:25:24 +08:00
hiyouga
0027f46ccc fix extra args
Former-commit-id: 2c98a1bc3d885170f8298872c2ea2e24427fb447
2024-11-09 00:24:27 +08:00
hoshi-hiyouga
967a27695e Merge pull request #5970 from hiyouga/hiyouga/fix_beam
[generation] fix vllm v0.6.3

Former-commit-id: 571d4538568272fd59cc5621e56113329c857546
2024-11-08 23:58:15 +08:00
hiyouga
3ce8a326c6 fix #5966
Former-commit-id: a9a99b545609083533cca1fd1e5480c60ea68750
2024-11-08 23:49:16 +08:00
hoshi-hiyouga
91b56b7baf Merge pull request #5927 from hiyouga/hiyouga/dev_fixmmchat
[fix] chat engines

Former-commit-id: e9c22e2d089927eee3bce052bbf7d6502d0ac544
2024-11-04 16:36:23 +08:00
hiyouga
e2fa961302 add image input type
Former-commit-id: 6fe260e35ff12662b72f26ec9df44e87b9693551
2024-11-04 08:27:20 +00:00
hiyouga
87d6d7dc61 fix chat engines
Former-commit-id: 3a220b7992d265c77d9a1a406ef86eefbc699cfe
2024-11-04 08:18:12 +00:00
hoshi-hiyouga
00019e2ca4 Merge pull request #5926 from hiyouga/hiyouga/dev_deps
[version] update datasets version

Former-commit-id: 4a24e8fc8e1c229ef8751bd7eafe024661d46661
2024-11-04 16:04:00 +08:00
hiyouga
b104739d63 update datasets version
Former-commit-id: feba2c6418a15715fee77a34428fa3cf47fcee5b
2024-11-04 07:52:26 +00:00
hoshi-hiyouga
b238d1aa04 Merge pull request #5914 from hiyouga/hiyouga/dev_read
[misc] update readme

Former-commit-id: 2897696bad6bcc2d826845750c0c913882449829
2024-11-02 21:44:10 +08:00
hoshi-hiyouga
aa497d5d96 Merge pull request #5475 from menibrief/main
Fix phi-3-small issues 

Former-commit-id: c1daf49a967f6c0b641c9639a78971275aaa7cae
2024-11-02 21:31:34 +08:00
hiyouga
fecf04b2f4 fix phi3 template
Former-commit-id: b62131a3c5b4ff6f2969a8041e6e7b9cf2c444ed
2024-11-02 21:31:23 +08:00
hiyouga
3f157e2f6f update readme
Former-commit-id: 94bae8360b1aa124cc57dca481b9e686ba559f31
2024-11-02 21:28:04 +08:00
hoshi-hiyouga
c7c558562e update template
Former-commit-id: 3559ef6115a831dcd1adf7210995ffd62890cff6
2024-11-02 21:21:22 +08:00
hoshi-hiyouga
c2ea5fb618 Merge branch 'main' into main
Former-commit-id: 154f504fc2cebaae2b58c0121d6d8d8016db1bb2
2024-11-02 21:20:27 +08:00
hoshi-hiyouga
fa9c32bb8d Merge pull request #5913 from hiyouga/hiyouga/dev_metrics
[train] support gather DPO metrics, fix return output

Former-commit-id: a17ac67f22c4de7699a8f2c1d4980af4babd2c7e
2024-11-02 21:13:43 +08:00
hiyouga
c610deb5a2 fix webchat
Former-commit-id: 071fe40f209156f994c069507a2d53cc4f586d67
2024-11-02 21:04:18 +08:00
hiyouga
2bb3255e74 fix dpo metrics
Former-commit-id: 57029280da825a39fbf5a05097921b861f126669
2024-11-02 20:59:01 +08:00
hoshi-hiyouga
b28b74c71e Merge pull request #5880 from sd3ntato/make-image-parametric
make base image parametric.

Former-commit-id: e2ea7c8b67cf598bba2b2b298e638b23712f14b3
2024-11-02 20:26:14 +08:00
hoshi-hiyouga
1ed921bff7 Update Dockerfile
Former-commit-id: 89a1c1eb6d717b20107c06a645652b87fba388e8
2024-11-02 20:20:26 +08:00
hoshi-hiyouga
80f634cc95 Merge pull request #5910 from Cuiyn/index
Support Index series models.

Former-commit-id: b74d9fa8efeb4f52ba0e20538ad90c8b40492e29
2024-11-02 20:16:54 +08:00
Cuiyn
a3eb5e200c fix: rename to Index-1.9B-Charater-Chat and Index-1.9B-Chat-32K
Former-commit-id: 95ab64749155a781ab5e55b989388ccd9e094c8d
2024-11-02 20:04:14 +08:00
hoshi-hiyouga
2d02c0e22d Merge pull request #5912 from hiyouga/hiyouga/dev_logging
[misc] support rank0 logger

Former-commit-id: ed34a6322814f302f050ba8ca4ecc53689f4d646
2024-11-02 18:48:41 +08:00
hiyouga
093eda2ad6 support rank0 logger
Former-commit-id: 84528eabe560091bfd866b6a0ca864085af7529b
2024-11-02 18:31:04 +08:00
Cuiyn
dbaf621f57 Add support for Index
Former-commit-id: 4e6dba16ca1755235d2ae117b53b68c5ae2f239a
2024-11-02 13:45:27 +08:00
hoshi-hiyouga
ceb701c2d4 Merge pull request #5909 from hiyouga/hiyouga/dev2
[data] support auto convert for single image, add image_dir argument

Former-commit-id: ced43fa0c84f7d0792694721d2c5e572c0d0e718
2024-11-02 13:43:04 +08:00
hoshi-hiyouga
29ad3783f5 Merge pull request #5907 from hiyouga/hiyouga/dev
[data] fix template replace behavior

Former-commit-id: 0a51c0bfdd9b193d2a3ac34a62fe8b073569c41a
2024-11-02 13:42:53 +08:00
hiyouga
fa2386e73c fix #5904
Former-commit-id: 079ebe038b11f36a11681dc8688f8ea48bccf324
2024-11-02 13:08:15 +08:00
hiyouga
e0045e8386 fix #5883
Former-commit-id: 73b93caa9ac16ffd8d3faae24d16210d85ae9754
2024-11-02 13:06:34 +08:00
hoshi-hiyouga
b94c941196 Merge pull request #5906 from hiyouga/dev
[test] update tests

Former-commit-id: f95f2824b3c078508408da23e1958292dc96d0fa
2024-11-02 12:50:43 +08:00
hiyouga
ba66ac084f update tests
Former-commit-id: 4e92b656e324725048d914946e70867be20032ff
2024-11-02 12:41:44 +08:00
hoshi-hiyouga
83479c9ef0 Merge pull request #5895 from hiyouga/dev
[inference] support multiple images

Former-commit-id: 491132e5db483fd00aa9f3cbc201b8fb83693f57
2024-11-01 16:52:55 +08:00
hiyouga
df8ac15ef0 add examples
Former-commit-id: 9eff9625adba643263bc6cba480f30edc6bb086a
2024-11-01 08:41:54 +00:00
hiyouga
8cea5cd967 support multiimage inference
Former-commit-id: 8083e4607549e805eb308c4e93c8aa256202f438
2024-11-01 07:25:20 +00:00
Valerio Mariani
a2d7d6a518 make base image parametric.
default `BASE_IMAGE` is nvcr.io/nvidia/pytorch:24.02-py3 for backward compatibility


Former-commit-id: db8d00536acb02b29d10a3d735438d194656ece3
2024-10-30 21:53:32 +01:00
hoshi-hiyouga
a63e624eca Merge pull request #5873 from hiyouga/dev
[misc] update readme

Former-commit-id: e02c3bea981dff6beae45a9428d5d88d210db5e1
2024-10-30 17:14:44 +08:00
hiyouga
8596c321ce update readme
Former-commit-id: b3d3b440e8879198603da042441d4b4f84296109
2024-10-30 09:14:01 +00:00
hoshi-hiyouga
54cd799aa0 Merge pull request #5871 from hiyouga/dev
[loss&ui] fix incorrect loss of vlms, add extra args to ui

Former-commit-id: 5f4a62b600ab47db6aab3a1f831ecfe1df4335d9
2024-10-30 17:13:17 +08:00
hiyouga
8185eb1890 fix incorrect loss value for vlms
Former-commit-id: 0aa29a71ce958343a2086090d647eb63b8f5f5be
2024-10-30 08:56:46 +00:00
hiyouga
03213984ec tiny fix
Former-commit-id: b8f4b145506851cf5488cd8551a04d1c7603019b
2024-10-30 08:56:29 +00:00
hiyouga
aeeee9d4b5 support extra args in llamaboard
Former-commit-id: da0a5fd612e2214cc4bcb72516efd768fbe18a20
2024-10-30 08:55:54 +00:00
hoshi-hiyouga
c8a1fb99bf Merge pull request #5581 from Kuangdd01/pixtral-patch
[WIP] Support Pixtral-12B

Former-commit-id: fcddf4ec5c2914f73e23eeda2dbf67b048246669
2024-10-29 22:29:10 +08:00
hoshi-hiyouga
f0181a41ff fix bug
Former-commit-id: e69665746d9fcd17a92ace7d5d9c8de1fc0c29b7
2024-10-29 22:19:04 +08:00
hoshi-hiyouga
f6b06d0c6f Update mm_plugin.py
Former-commit-id: 830315cb438e75b589017fd57f70d0a513780a53
2024-10-29 22:16:22 +08:00
hoshi-hiyouga
1047217f78 Update template.py
Former-commit-id: 99a01547ca31adade1c48feae5796e06b73d387c
2024-10-29 22:11:21 +08:00
hoshi-hiyouga
16a9a44849 Update visual.py
Former-commit-id: 6f1db7b9abfbdea1781452388d66df3e9f9a5dd9
2024-10-29 22:10:29 +08:00
hoshi-hiyouga
58fb24ce41 Update collator.py
Former-commit-id: 941fa8a0d9c3a9106ad0af6e776db7e57f69548f
2024-10-29 22:03:42 +08:00
hoshi-hiyouga
a9afffa246 Update hf_engine.py
Former-commit-id: 7412a8b95678ca6827a8c42c9f4d38115fede897
2024-10-29 22:00:59 +08:00
hoshi-hiyouga
1fdd053022 Update README_zh.md
Former-commit-id: e14535aa97062d0e57bbf1230c050f2c56a45556
2024-10-29 21:58:03 +08:00
hoshi-hiyouga
0a833968a0 Update README.md
Former-commit-id: 65be32f6b12c2be80a12a4e903001820f64a0833
2024-10-29 21:57:28 +08:00
hoshi-hiyouga
58b681de78 Merge pull request #5801 from NLPJCL/main
Projects using LLaMA Factory: RAG-Retrieval uses LLaMA-Factory as the fine-tuning framework for generative reranker tasks.

Former-commit-id: cc9995cc99a7d7ba2958094bcd3d597eddc349e3
2024-10-29 21:20:16 +08:00
hoshi-hiyouga
22d5fc5f4c Update README_zh.md
Former-commit-id: 9e356805aa631810fd5897cb6a6cfc1fe0e939ab
2024-10-29 21:19:17 +08:00
hoshi-hiyouga
cc0119f698 Update README.md
Former-commit-id: 9181486c630bca23f68868128c9b0e04a0d7cea4
2024-10-29 21:18:15 +08:00
hoshi-hiyouga
580cedebde Merge pull request #5857 from hiyouga/dev
[train] fix saving processor

Former-commit-id: 5aaa90124483c8b54225797fa91065ed072d171a
2024-10-29 21:12:04 +08:00
hiyouga
43bd1b070c fix #5749
Former-commit-id: c36c5c61fc022b3f144d4c798ec584c4954b0181
2024-10-29 13:02:13 +00:00
Kingsley
42aa9c65be Merge branch 'hiyouga:main' into pixtral-patch
Former-commit-id: 438302edfdb66b6397266b8b17ac66f60a89300c
2024-10-29 21:01:25 +08:00
hoshi-hiyouga
b0b87fa33f Merge pull request #5852 from hiyouga/dev
[misc] several important updates

Former-commit-id: 5bc5ddf3b62abc132df08be477ffb46e9257e2ba
2024-10-29 20:30:02 +08:00
hiyouga
22912eba1a fix pissa
Former-commit-id: 4ac65a318b87249d42ffa73cbd3b33f0934f2afa
2024-10-29 12:18:45 +00:00
hiyouga
e2748fa967 fix #5747
Former-commit-id: 26d07de349c98b547cd6a6166ea20616d08ba343
2024-10-29 10:47:04 +00:00
hiyouga
248d5daaff use pre-commit
Former-commit-id: 7cfede95df22a9ff236788f04159b6b16b8d04bb
2024-10-29 09:07:46 +00:00
hiyouga
8f5921692e update requires
Former-commit-id: cae0e688ddcead370821e126c192bddc53ff6017
2024-10-29 16:10:07 +08:00
grok
e880eb8844 Update README_zh.md
Former-commit-id: e0c4aa091e71bcb4be44f5a07bdda5df6b949af2
2024-10-23 23:50:56 +08:00
grok
dc076c4e52 Update README.md
update English readme

Former-commit-id: c295a8b549603ec1d58f460c041401e1393d18b5
2024-10-23 23:49:47 +08:00
grok
8306e93ef3 Update README_zh.md
Former-commit-id: 77e39e7c34410a24055ab63cc088e6ec768d49c7
2024-10-23 23:36:14 +08:00
hoshi-hiyouga
6a2cd129c0 fix #5797
Former-commit-id: 71d23ed3444f24b31785d9f0f6dd711f6f516731
2024-10-23 20:49:44 +08:00
KUANGDD
30d7f6a22e rm comment
Former-commit-id: 80b58eaaec1996571d24b2dc2b73859cc28911a1
2024-10-23 15:50:59 +08:00
KUANGDD
5440ebbae6 rm useless code
Former-commit-id: 2dc337a49a8646ce916981b2914718e7472b5946
2024-10-23 15:38:11 +08:00
KUANGDD
22dbe694e9 Merge branch 'pixtral-patch' of https://github.com/Kuangdd01/LLaMA-Factory-X into pixtral-patch
Former-commit-id: 10c58488558549c382f9bba43c487d7f9222f16e
2024-10-23 15:32:50 +08:00
KUANGDD
64ac6ca396 rm import torch
Former-commit-id: 561a0f8155afca20ac699e124320b0eaef6dac07
2024-10-23 15:32:33 +08:00
Kingsley
377d37fa7f Merge branch 'hiyouga:main' into pixtral-patch
Former-commit-id: f3ad96aea6f2602981bf5f27d2bbd1f729d11aa0
2024-10-23 15:30:03 +08:00
KUANGDD
55296744a8 Merge branch 'pixtral-patch' of https://github.com/Kuangdd01/LLaMA-Factory-X into pixtral-patch
Former-commit-id: 3c1694157d61d88fd53fb3c9197196013b98e0e7
2024-10-23 15:28:19 +08:00
KUANGDD
d0889012c2 modify style & little change
Former-commit-id: c988477d14dc656450d5fec31895781b7f9f7dce
2024-10-23 15:24:07 +08:00
hoshi-hiyouga
3a8b2890eb fix test
Former-commit-id: a0a23f79d2d94d68e3bf1e90b95beff817bc409c
2024-10-22 12:35:36 +08:00
hoshi-hiyouga
5b2284a51d fix #5768
Former-commit-id: 9f9e3fd186ce917f0b323c8cd42cf050ed238c58
2024-10-22 11:06:22 +08:00
hoshi-hiyouga
4807d8a4ef Update misc.py
Former-commit-id: fe9a927f1ea8e44e0429b437e5feecf13e34e9aa
2024-10-17 19:48:51 +08:00
hoshi-hiyouga
c6e1313977 Update loader.py
Former-commit-id: 3b229a27a108b840e6bed3c8684737f51ce9faf4
2024-10-17 19:48:12 +08:00
hoshi-hiyouga
66819fd3ee Update README_zh.md
Former-commit-id: a829d4a28fae77b08a6ea451479c71578b3b552f
2024-10-17 19:47:33 +08:00
hoshi-hiyouga
bd85e370be Update README.md
Former-commit-id: f62b0682e476dd62a4a3ac5620f8fc244e8bf150
2024-10-17 19:46:36 +08:00
BUAADreamer
cc097174cc tiny fix [skip ci]
Former-commit-id: 937f69190e529fe7bf0fdf58d7bbb39017854c5e
2024-10-16 15:55:30 +08:00
KUANGDD
7d135bbdb8 remove useless codes
Former-commit-id: 01247fcdde215398ec67cbd6cf1bc6cfb512a9ba
2024-10-16 01:14:51 +08:00
KUANGDD
4845a76535 fix bug for webui infer
Former-commit-id: 17768832908cc59ab64ed72522b2954c575ce21d
2024-10-16 01:09:33 +08:00
Kingsley
67645c0db8 Merge branch 'pixtral-patch' of https://github.com/Kuangdd01/LLaMA-Factory-X into pixtral-patch
Former-commit-id: 995eae4333f4346734d76f7d18cfffb5147e2f7b
2024-10-15 17:09:56 +08:00
Kingsley
f463b3f038 add extra test for pixtral mm_input
Former-commit-id: c706ec8a5dbd3c72ab15a709668624c0c7bbd8ce
2024-10-15 17:09:24 +08:00
BUAADreamer
01defc2779 tiny fix [skip ci]
Former-commit-id: 95f968eec2628cb26b3c4f4d4e81a9536e23cc31
2024-10-15 13:53:33 +08:00
Kingsley
c9e77ab352 Merge branch 'hiyouga:main' into pixtral-patch
Former-commit-id: da6eb7bab2b4e551366d33b81083773cfd45ec08
2024-10-15 13:41:10 +08:00
BUAADreamer
c3de160d1c fix some
Former-commit-id: c9b644693996f96d234349823911fc267635acb9
2024-10-15 13:30:41 +08:00
KUANGDD
3693d7b571 plugin test & check
Former-commit-id: 76c7c8c5a729b8b43e3a31efc44f2c9c2678bf3d
2024-10-15 12:12:46 +08:00
hiyouga
a63144c28f fix #5705
Former-commit-id: 0c85fd253f860eee3c7b9b5a4e77ffbf93af372a
2024-10-15 10:10:16 +08:00
KUANGDD
2b3b0473cd required transformers version
Former-commit-id: d9915db327a038c93b5e3421c90b1f218fb23f92
2024-10-14 21:11:09 +08:00
Kingsley
9d929897ce remove bs condition
Former-commit-id: bf3520178ab66058c62a9cf31b42f36a9d88ce20
2024-10-14 16:55:59 +08:00
Kingsley
313a5e1494 Merge branch 'hiyouga:main' into pixtral-patch
Former-commit-id: 28696e2f945a9f55e4ca9e9dc5ebd8af9df45d8b
2024-10-13 17:42:02 +08:00
hiyouga
74dd25224a fix #5668
Former-commit-id: 116f2946201d55305f6b57b3f926670a3e2173c8
2024-10-12 01:24:43 +08:00
hiyouga
c7efc7f2ed tiny fix
Former-commit-id: 1fe424323b212094856f423351dc2a15774d39c3
2024-10-11 23:51:54 +08:00
hoshi-hiyouga
c71c78da50 Merge pull request #5665 from johnnynunez/main
vllm 0.6.3

Former-commit-id: 6f8a9581fa406e255ca6955794f16cc06b5cf287
2024-10-11 23:45:58 +08:00
hoshi-hiyouga
f4897da009 Merge pull request #5642 from huniu20/main
[hub support] add modelers hub support

Former-commit-id: ea96c8ba3f81546df1311ca738ff961aa4ef7446
2024-10-11 23:45:17 +08:00
huniu20
a6951db970 bugs fixed
Former-commit-id: 5457ba7512d70564ea784b9ec6bdb86cfd2d7e3d
2024-10-11 19:56:13 +08:00
Johnny
9d27aaa38f Update parser.py
Former-commit-id: 60b13c86f4feaffbb43f5a23a28376fe416ed118
2024-10-11 12:29:33 +02:00
Johnny
3b19b6f31b Update setup.py
Former-commit-id: f85b756ffafa241304624819b7612603ad5e0ee3
2024-10-11 12:29:09 +02:00
huniu20
5b15ca0b0b add om_hub_token argument
Former-commit-id: b3214e69d32067a1c22dbd60c2cde1545ba75b19
2024-10-10 17:16:46 +08:00
huniu20
aad79127e6 1. add model and dataset info to support webui
Former-commit-id: 92f6226f3fecbd9af744a7232dda2c68b2bb0d86
2024-10-10 16:46:34 +08:00
huniu20
c42dcab32b 1. add modelers hub support
Former-commit-id: 14678eb444d8181176745d18d4a6865fd6860f58
2024-10-09 17:21:37 +08:00
Kingsley
be519c84d9 Merge branch 'hiyouga:main' into pixtral-patch
Former-commit-id: 2076d00dfbe1279a91207157fd6d9a118427626a
2024-10-08 21:04:08 +08:00
hiyouga
b2dc6dc59a tiny fix
Former-commit-id: d8ddd07c2ed14d871fb25743c20265fc99e3e221
2024-10-08 17:48:56 +08:00
hoshi-hiyouga
9df626dc18 Merge pull request #5546 from chengchengpei/cpei/refactor
1. log exceptions in detail; 2. check whether processor is None before calling it

Former-commit-id: 81c23ebdd7ef46102437b1d352818fe205fa3851
2024-10-08 17:46:54 +08:00
hoshi-hiyouga
8d4b9200a1 Merge branch 'main' into cpei/refactor
Former-commit-id: c2951f17f726470bcd5dff6bf7028ec90212442e
2024-10-08 17:31:17 +08:00
hoshi-hiyouga
7806df46ba Merge pull request #5615 from johnnynunez/patch-1
Update setup.py (Compatible with Jetson)

Former-commit-id: baa3cd4c0db2502cf8a606e034df20492a83e6b2
2024-10-07 16:50:34 +08:00
hoshi-hiyouga
bba026a212 Update parser.py
Former-commit-id: e7d291605f184f6ac48429015e15755192d2f274
2024-10-07 16:27:23 +08:00
hoshi-hiyouga
6e111eb29f Update setup.py
Former-commit-id: 4c017fe014b708d79c65eff24329b9c324399461
2024-10-07 16:26:50 +08:00
Johnny
2b69ae0eb2 Update parser.py
Former-commit-id: 55c449b54aec04e2141bffe75d4016cbac9ef4c5
2024-10-07 10:17:45 +02:00
Johnny
13d73574ef Update setup.py
Former-commit-id: 73d3f93496712edace38711613e14768922d6c96
2024-10-07 10:16:53 +02:00
hiyouga
bc264807ae update readme
Former-commit-id: 915f25e9b34fc4554fd1198a383f96a2536fec60
2024-10-07 11:31:18 +08:00
Johnny
f9815dd20a Update parser.py
Former-commit-id: f832edc8dc0e2b78c12dc8edd702fe147a0a5292
2024-10-06 20:34:19 +02:00
Johnny
1f58943b32 Update setup.py
Former-commit-id: b4de2c84b078194bb6358697fd6815d622843f58
2024-10-06 08:53:55 +02:00
hiyouga
6476507429 fix #5611
Former-commit-id: 3bef07ecf0557999bb0b33b650a778addc8e5b91
2024-10-06 10:34:55 +08:00
hiyouga
35862d19ec fix #5611
Former-commit-id: 76c813d37c1d945a8bb6d3e4168e15fbe97c7a87
2024-10-06 10:33:11 +08:00
Kingsley
1272cb00df Merge branch 'hiyouga:main' into pixtral-patch
Former-commit-id: 9372ac93f304db438383d539ccd00bffe7415dbc
2024-10-01 00:52:31 +08:00
Kingsley
e9ac26db4c unfactor md
Former-commit-id: 1a79d61f8d25a4c1127c2f393418e14ab9d2abd4
2024-09-30 23:36:16 +08:00
hiyouga
20ee1d2e19 fix #5542
Former-commit-id: cf28e7418c2eb07e86923a53ef832ef218e45af1
2024-09-30 23:28:55 +08:00
Kingsley
cbc1dd0c88 sync with former
Former-commit-id: f8707e52586182144c4fb70c7c0de8bf7044ef5e
2024-09-30 20:27:05 +08:00
Kingsley
870bbabbc4 register model fix
Former-commit-id: 077d8e3c0344d944705254cc5a2cd06c9f5dc116
2024-09-30 20:04:47 +08:00
Kingsley
8fd84c375e fix some errors due to inconsistency of model cards
Former-commit-id: dd83265b9b8768eb8732f59ace128dfe4aac1c47
2024-09-30 19:58:34 +08:00
Kingsley
32b5364051 Merge branch 'hiyouga:main' into pixtral-patch
Former-commit-id: df0baeaa3fd093433d92b7921d3a57d88061d6d4
2024-09-30 19:33:29 +08:00
hiyouga
cf72aec098 add patch processor func
Former-commit-id: 0cd6327da6a044b4a62f203a662e5bb6068d9c29
2024-09-30 17:07:43 +08:00
hiyouga
87849d12d2 lint
Former-commit-id: d7564365f4008e468f89102879d6e65c627ad447
2024-09-30 17:00:33 +08:00
hoshi-hiyouga
a19512436f Merge pull request #5585 from shing100/main
Support EXAONE3.0 Model

Former-commit-id: 2fba28d586757bbb3ac57e4dd10c756381766b51
2024-09-30 16:56:08 +08:00
hoshi-hiyouga
6c89d93aea Update constants.py
Former-commit-id: 7c04e1caea38fd1e1e9abcf8ed1bbdc24ddd6df1
2024-09-30 16:47:52 +08:00
hoshi-hiyouga
345f40a660 Update template.py
Former-commit-id: d893289b595c0530b5aeb8902369885118809b86
2024-09-30 16:39:48 +08:00
Zhangchi Feng
8b9a814653 Merge branch 'main' into pixtral-patch
Former-commit-id: 0cf52d48fbc505e2fba29e5df0f2e6722db7ac79
2024-09-30 12:37:03 +08:00
shing100
05fabf9095 fix chat template Exaone3.0
Former-commit-id: 2e32864b59c1ef1a78f3eb1c28fbf578cfaa19cd
2024-09-30 09:44:21 +09:00
Geun, Lim
95eede911a Update README_zh.md
Former-commit-id: c4bf9d86e14a9d7a5ed5f9c49d73006d13df2707
2024-09-30 09:25:02 +09:00
Geun, Lim
7bc7f7d673 Update README.md
Former-commit-id: d014eb931cd9ed70abb8a466281668a0b00ba9f9
2024-09-30 09:24:44 +09:00
shing100
054fdbe186 update docs Support model Exaone3.0
Former-commit-id: e6fbf8fd7c84cfb11a0a4a173657b1541806b5f9
2024-09-30 09:19:27 +09:00
shing100
f0f80819a0 add Exaone3.0 template
Former-commit-id: f7478af1d04353ab13236323e3bfb96fd2870fce
2024-09-30 09:18:25 +09:00
hoshi-hiyouga
e702678252 Merge pull request #5574 from BUAADreamer/main
support llava-next(video)/video-llava

Former-commit-id: bf7611e15a7e7ee9fb870efeba9bdac358c6d462
2024-09-30 00:22:43 +08:00
hoshi-hiyouga
553579986a Update common.py
Former-commit-id: 7f7f4b67b8b757e3787a78993cf083552cd5fbbd
2024-09-29 23:58:09 +08:00
hoshi-hiyouga
622cb04f27 Update README_zh.md
Former-commit-id: 01ee426c745f522bd0dee79ace2c6b2eb52d0510
2024-09-29 23:56:32 +08:00
hoshi-hiyouga
f3ba11a432 Update README.md
Former-commit-id: 45b79a78f62a1d916083f8c74ebf08ad0fb8fe6f
2024-09-29 23:55:55 +08:00
hoshi-hiyouga
8b1f53bca5 Update README.md
Former-commit-id: 0bcf6a30ae95d5c76e477f829f6ba633d9ccdd64
2024-09-29 23:55:21 +08:00
hoshi-hiyouga
ac25fef80e Update constants.py
Former-commit-id: a0dd90fa41fc10d7944521d95a312631be64af8f
2024-09-29 23:45:34 +08:00
hoshi-hiyouga
15f819d273 Update test_mm_plugin.py
Former-commit-id: 8490ba1bb3b429d10c5a1cf791aa1bfe3547fd5f
2024-09-29 22:59:47 +08:00
BUAADreamer
f2d1c43d28 fix template
Former-commit-id: cfd05bb009895a936c59f3d97afebf2ed8006f84
2024-09-29 22:56:36 +08:00
BUAADreamer
464acc7d6c fix template
Former-commit-id: 6291c933448022ae80fd85d7f1d785bf6c0fcb25
2024-09-29 22:55:45 +08:00
BUAADreamer
a96c5da737 fix constants
Former-commit-id: e66a338410be6812064a119d8c6a6644e0f035d1
2024-09-29 22:40:43 +08:00
BUAADreamer
28d09b81c9 Merge branch 'main' of https://github.com/BUAADreamer/LLaMA-Factory
Former-commit-id: 2358bdde973dfde3abff251d02f7622e9c144e4d
2024-09-29 22:00:35 +08:00
BUAADreamer
a769d0e3d4 fix constants
Former-commit-id: 69309a23598995aa1937fd8d80732a018c18db87
2024-09-29 22:00:01 +08:00
hoshi-hiyouga
1b98b5e65c Update requirements.txt
Former-commit-id: bd3b235904aae267ead8db1809d06d6935d2ea30
2024-09-29 21:51:23 +08:00
BUAADreamer
3cc5408da7 fix style
Former-commit-id: dc1bdcb69e6f2c605a2c533dab15613affc902f4
2024-09-29 21:39:37 +08:00
Zhangchi Feng
689f5c4554 Merge branch 'main' into main
Former-commit-id: 7566589b820e6030269523e9d08c312594f893ae
2024-09-29 21:32:54 +08:00
BUAADreamer
ab5d042cd3 add more llava-next series template
Former-commit-id: 93f64f2aebf41582d39aa8a2c6059e562ca694b0
2024-09-29 21:29:29 +08:00
BUAADreamer
4d43317aa1 Merge branch 'main' of https://github.com/BUAADreamer/LLaMA-Factory
Former-commit-id: bf6d6eb0bfe00453a77bbe42a3842b856dd2e47f
2024-09-29 20:55:23 +08:00
BUAADreamer
ed3b0c5b40 fix readme_zh
Former-commit-id: b663d664793b79c02db1b91d206dea2beb168e26
2024-09-29 20:55:18 +08:00
hoshi-hiyouga
67a97794ee Update mm_plugin.py
Former-commit-id: 507de0df036e39eae3a3887ded9165bd918ee48f
2024-09-29 20:54:04 +08:00
hoshi-hiyouga
2c7c93cb9b Update mm_plugin.py
Former-commit-id: b8be270f9c97bfcaf431bbd9f06c4c0b83980539
2024-09-29 20:53:34 +08:00
BUAADreamer
4d4fe08d14 fix readme_zh
Former-commit-id: 4621cc3e0b8a5dc7fcfa7cf2d60ff1838aef9a1a
2024-09-29 20:46:47 +08:00
BUAADreamer
85a919b6f7 fix readme
Former-commit-id: 867e7e70dbff207dbd78668af09a638654937f71
2024-09-29 20:45:02 +08:00
BUAADreamer
fe2abe20fc tiny fix
Former-commit-id: 0c7c875d55bc45795a41c0b8a5c407d72b1f3d8d
2024-09-29 20:38:46 +08:00
BUAADreamer
12444720db fix style
Former-commit-id: 7b922803586c05981cd095cfb730061091f0204c
2024-09-29 20:30:57 +08:00
BUAADreamer
510faf5805 fix tests
Former-commit-id: e932907f6f6473bd6917d61a464366cc9918f66c
2024-09-29 18:00:45 +08:00
BUAADreamer
722e01c8ab fix some
Former-commit-id: aeca8c0f978cb9754e0526b40cd431aaf867044f
2024-09-29 17:55:40 +08:00
hoshi-hiyouga
6050e6cff9 update readme
Former-commit-id: e5c8634cbd4e00459894c031ef0e10fcc6ef5775
2024-09-29 05:02:44 +00:00
hoshi-hiyouga
c8abbe4fc3 Merge pull request #5580 from amrear/main
made a small change to a warning about fa2 for gemma2 models.

Former-commit-id: 5e2d90ab976dd55b8c61a68e929d7e5b3583156c
2024-09-29 12:45:03 +08:00
BUAADreamer
f2881c9d4a fix some params of visual regularize
Former-commit-id: 15cbc35af4559dad73c09317e82a63571a8c3540
2024-09-29 12:38:25 +08:00
hoshi-hiyouga
1ded3abdf1 Update attention.py
Former-commit-id: 2adf79c195053bb4541e0317573a2c89da28b5bc
2024-09-29 10:47:41 +08:00
Kingsley
e641f1215a Tiny fix
Former-commit-id: ae66e1a545f4cd209a57fd824f9bfb7e94436cba
2024-09-29 00:00:23 +08:00
Amirreza A
ca736bcab7 made a small change to a warning about fa2 for gemma2 models.
Former-commit-id: e0695a026d822c896cb4f5b33e0c4f88441d75e9
2024-09-28 19:03:36 +03:30
Kingsley
bddb2646bd tiny fix
Former-commit-id: 35bc71b2a68fd303798c35fe22ad29ceea87cf9b
2024-09-28 22:50:53 +08:00
Kingsley
e4c57f54f8 remove some unnecessary if conditions
Former-commit-id: 482d3e5ff3338385da664475fee88c7dc623c993
2024-09-28 02:14:06 +08:00
BUAADreamer
6de82ca843 fix some
Former-commit-id: 12e509da85af76ccf1e9a879a78e450a7b70cc4b
2024-09-28 01:15:33 +08:00
BUAADreamer
b2c02df555 modify some style
Former-commit-id: 36bc408b8296cfc6d565b2f968fb1059bc6d1305
2024-09-28 01:07:38 +08:00
BUAADreamer
ca86d6361e add tests
Former-commit-id: f0ed66bf6f9b45e0c3fddb5179a93363f5a4194f
2024-09-28 00:59:14 +08:00
BUAADreamer
b6fb00e046 add llava-next/llava-next-video/video-llava
Former-commit-id: a4e4239931b0b0e3fd12c9f9bbfd2c201cbc78ca
2024-09-28 00:57:03 +08:00
Zhangchi Feng
86c84972c8 Merge branch 'hiyouga:main' into main
Former-commit-id: 2695dcdf468f9e39e3aeec7892eb3dad399736ee
2024-09-27 18:14:39 +08:00
Kingsley
9390927875 add pixtral template
Former-commit-id: c7b4e47e0fda955272ccd6340b2047fd92acbfcf
2024-09-26 17:14:51 +08:00
Kingsley
c4a585f232 Merge branches 'pixtral-patch' and 'pixtral-patch' of https://github.com/Kuangdd01/LLaMA-Factory-X into pixtral-patch
Former-commit-id: 197bb14e6308bdf9af65eafe7bf06b36dbf96df6
2024-09-26 12:18:25 +08:00
Kingsley
300feb3245 add pixtral template
Former-commit-id: e0bcaa6c6e902e29361438a6d215bbc2535b648f
2024-09-26 12:11:58 +08:00
Chengcheng Pei
cacafb0038 address comments
Former-commit-id: 6311bb2ca266ce156537cfa477202b2904921593
2024-09-25 21:07:51 -07:00
hoshi-hiyouga
6509114259 Merge pull request #5547 from marko1616/chore/llama3.2
Chore: Support llama3.2.
Former-commit-id: 979ecc92a0db6b90ed8249d9a17120d5ed18b6aa
2024-09-26 11:38:34 +08:00
hoshi-hiyouga
7d4cb79822 add modelscope models
Former-commit-id: 4de3081eea9cede78a1f2db65cf22a5731c54447
2024-09-26 11:22:48 +08:00
marko1616
b867e164fe Chore: Support llama3.2.
Former-commit-id: 2741ac784c1a776bd545fa6dffc07b6346273519
2024-09-25 16:08:44 -04:00
Chengcheng Pei
26bbfc084d 1. log exceptions in detail; 2. check whether processor is None before calling it.
Former-commit-id: 0f0a4813db9ca4e9bb5762a781a0a214129284a6
2024-09-25 12:59:48 -07:00
hiyouga
c376eed31d fix ci
Former-commit-id: f354593ca9b13e542fccd8fe2b64ea0ec4db78b2
2024-09-25 23:14:17 +08:00
hoshi-hiyouga
7c595abc38 Merge pull request #5533 from StrangeBytesOrg/add-docker-args
Add additional install options to Dockerfiles

Former-commit-id: c52aa3d5323e270f6b50a51d97a92e79138b7293
2024-09-25 23:04:57 +08:00
hiyouga
c428ab68d8 optionally replace jinja template
Former-commit-id: f15dec3001f785eeac1ed9cc545fab96bac2c4fd
2024-09-25 23:02:02 +08:00
hiyouga
968b9f1852 update readme
Former-commit-id: 826a47909f22b72228cd8944875a13f5f65232b1
2024-09-25 20:13:04 +08:00
hiyouga
018266c66e update readme
Former-commit-id: fe482183ae9d19cc42f78b5cd144ef21b93ec8d1
2024-09-25 19:39:52 +08:00
StrangeBytesDev
111c644bf1 Add additional install options to Dockerfiles
Former-commit-id: 5310af2f2ac8d226b95785d6b1eb0632312871a7
2024-09-24 16:54:46 -07:00
hoshi-hiyouga
de72d1f0e7 Merge pull request #5483 from whybeyoung/main
fix: if a function_call value in a function-call dataset is not valid JSON, report the error and abort training.
Former-commit-id: 9e36ebebd087cd3b128b9426255d420f3c94353c
2024-09-19 17:01:52 +08:00
hoshi-hiyouga
8bfb856923 flat string
Former-commit-id: f1e7731075e6ded4a5ecac7ef46ca4a318b91597
2024-09-19 16:43:42 +08:00
hoshi-hiyouga
8fdbaab95d lint
Former-commit-id: dd94fdd69c8f36df80d6d70d63ab7403a0e55d46
2024-09-19 16:21:43 +08:00
hoshi-hiyouga
a01668bbe8 fix bug
Former-commit-id: b6d0ee1fd8b555bc6aac8b8686c9a3eea784c3a8
2024-09-19 16:21:21 +08:00
hoshi-hiyouga
3385616a37 improve error message
Former-commit-id: e7735dd487ae4e31c34dcd8e2ea9af0a39d1cf9e
2024-09-19 16:06:00 +08:00
ybyang
1f0d89328d fix: if a function_call value in a function-call dataset is not valid JSON, report the error and abort training.
Former-commit-id: 625a0cd7cb5725a0f76c8c19cd23d6c0275bd146
2024-09-19 15:00:10 +08:00
menibrief
a7feab45d5 fix phi-small template
Former-commit-id: 48fb6bae6245dc6d5f72ebfc1c2bd9ffacd51b86
2024-09-18 23:52:30 +03:00
menibrief
f34322afd7 Update README.md
update readme for the phi-small template

Former-commit-id: e9df26aa45f916ab0756db3329dff48dcdfce1f1
2024-09-18 23:51:36 +03:00
hoshi-hiyouga
3815fa40b7 tiny fix
Former-commit-id: 1f45d18a780c2aa501f060688a09ff04071379b9
2024-09-19 02:20:24 +08:00
hoshi-hiyouga
c43050b3fa Update README_zh.md
Former-commit-id: 750c57cbcee3ecdd6a9096f1569b9bee282d5ac7
2024-09-19 02:17:59 +08:00
hoshi-hiyouga
3e152872ad Update README.md
Former-commit-id: 40b0e51092289dbf1f2a112cd8c36df399314c8b
2024-09-19 02:16:16 +08:00
hoshi-hiyouga
ae6ad55758 fix webui
Former-commit-id: aa6e65b24451fe9f65d58e5eca5a56eb9aba71e8
2024-09-19 02:13:39 +08:00
hoshi-hiyouga
0118a2fc04 add qwen2.5 models
Former-commit-id: 408a7d7b2e1a2316cbeefade872b732c88191b75
2024-09-19 02:07:54 +08:00
hoshi-hiyouga
4dd81976f4 Merge pull request #5438 from aliencaocao/patch-1
Add qwen_vl to liger kernel supported list

Former-commit-id: c706ff61dc3e5c152a10789c7524844e2be554a2
2024-09-16 13:40:02 +08:00
Billy Cao
2b4da8baf6 Add qwen_vl to liger kernel supported list
Former-commit-id: 053b2d832450cb6cd6af673b9fc51404f1fb1e41
2024-09-14 19:28:20 +08:00
hoshi-hiyouga
7d1b4071e8 Merge pull request #5427 from HardAndHeavy/update-rocm
Update the ROCm version to 6.2

Former-commit-id: 5dcdf5d16590b59004be9d728887781729344ea0
2024-09-13 10:25:47 +08:00
HardAndHeavy
8fc5377f50 update the ROCm version to 6.2
Former-commit-id: a6eda6a500daa4f3383a7868f6abe2434f967b1d
2024-09-12 23:46:33 +03:00
hiyouga
e5812f261d update ci
https://github.com/huggingface/transformers/pull/33436

Former-commit-id: c723f16cdb919cedbf938d51d422ad49b9c6eecf
2024-09-11 20:44:42 +08:00
hiyouga
f7e85cd7de set dev version
Former-commit-id: 39edf597f050bcb2099a10d6f6018f96e29b7e65
2024-09-11 18:56:37 +08:00
hiyouga
749395420b remove windows in ci
Former-commit-id: 56046767c086853b6d40fbc42e0ed9662546de6b
2024-09-11 18:14:39 +08:00
hiyouga
7d536d1d75 fix ci
Former-commit-id: 627f30200068f58d06eb53b1b4797ed426c9c1f1
2024-09-11 18:01:09 +08:00
hiyouga
7fd0d2fc2f fix #5411
Former-commit-id: 392bdaf1ea9e5baf6289f2d4415a175dd55a479d
2024-09-11 17:36:42 +08:00
BUAADreamer
ec696bbcdd try to pass test
Former-commit-id: 2db97e1e5e06370375f4f5c577671524e399321f
2024-09-10 13:29:09 +08:00
BUAADreamer
df24345d65 try to pass test
Former-commit-id: 76a4cfcb84b55467792318dc15a5fbcd6807b674
2024-09-10 13:25:30 +08:00
Zhangchi Feng
386dd26097 Merge branch 'hiyouga:main' into main
Former-commit-id: 8619ad7dc124c50e254b1bb2e173ff99ca4f0e22
2024-09-10 13:20:24 +08:00
BUAADreamer
514f976cc1 try to pass test
Former-commit-id: 3b6bfae0e5fe795a70d530b2765f27d95c5862f8
2024-09-10 13:12:51 +08:00
BUAADreamer
66b870fd08 try to pass test
Former-commit-id: 808a4bd77daca4dd92423652878d8262f3a6f2a4
2024-09-10 12:56:12 +08:00
BUAADreamer
24d3c7e378 resolve conflict
Former-commit-id: d6168da2a1f74424b83416cbcbf685861e76ff5f
2024-09-10 12:39:17 +08:00
BUAADreamer
484128b641 support llava-next(video)
Former-commit-id: 27e94593ac467e56e3a7f5c64f4ff6cee81f4b47
2024-09-10 12:31:53 +08:00
hiyouga
588ea95732 update accelerate ver for schedule_free optimizers
Former-commit-id: 2de74e79049ce8e50f605f649275b1dbfb899c8c
2024-09-09 22:51:08 +08:00
hiyouga
800567cde7 fix mm plugin
Former-commit-id: 6a3549c6c1a8c40de61e748f0b280bfc9e1279a2
2024-09-09 22:41:28 +08:00
hiyouga
7a3ba5a25d fix qwen2vl preprocess
Former-commit-id: 52ddd42b7d2ae9e1aa08c15fd5c13ddad96f1b74
2024-09-09 22:33:33 +08:00
hiyouga
dfff411e1a release v0.9.0 (real)
Former-commit-id: 8ff781c8ae5654680f738f69a6db9d7b95d76baf
2024-09-09 01:00:25 +08:00
hiyouga
e20baa4218 fix constants
Former-commit-id: fce6671d2764d7a2b77c44401fc5582c7cbb77aa
2024-09-08 23:52:30 +08:00
hiyouga
d1ab9b501a release v0.9.0
Former-commit-id: 594c450f648ad326ef39c0f4d70d67cda5f36159
2024-09-08 23:43:35 +08:00
hiyouga
3cbc9109ea tiny fix
Former-commit-id: 76177039c8f9ef5a63724a339dae6195d89fa215
2024-09-08 23:18:08 +08:00
hiyouga
3259397f89 update scripts
Former-commit-id: 51d087cbc14bf3c7dfa06b8b66052cd80a6081be
2024-09-08 14:17:41 +08:00
hiyouga
eb5af3d90b support vllm 0.6.0
Former-commit-id: e39470ec51a9c74ad901871eb816df10e851f351
2024-09-08 02:26:20 +08:00
hiyouga
b6810b209a fix test case
Former-commit-id: b075b2971c6acb2c6039b36420a296f1f4e1b91b
2024-09-08 01:50:51 +08:00
hiyouga
158e0e1f63 add test case
Former-commit-id: c452d65e1551074dddd1d87517c0d44dc014c6aa
2024-09-08 01:40:49 +08:00
hiyouga
294a103ead support activation offloading via unsloth gc
Former-commit-id: d3d0dd0feba3ca6f0ae970d5856bec989d26ef67
2024-09-08 01:22:19 +08:00
hiyouga
7f71276ad8 add docstrings, refactor logger
Former-commit-id: c34e489d71f8f539028543ccf8ee92cecedd6276
2024-09-08 00:56:56 +08:00
hoshi-hiyouga
93d4570a59 Merge pull request #5388 from yzoaim/cal_mfu_update
update cal_mfu.py

Former-commit-id: fe5eac2cb6a4646b653232d7c68d535105b60f3a
2024-09-08 00:49:28 +08:00
hoshi-hiyouga
527ba2eb2e fix
Former-commit-id: 53a74cbc3afec58b36c2265e080061bcdf702f98
2024-09-08 00:41:45 +08:00
hoshi-hiyouga
3021b31cf3 Update cal_mfu.py
Former-commit-id: 0c391b2e59943b0ca9dd4e8561398e7c856a4b29
2024-09-08 00:39:48 +08:00
-.-
9f2427907e update cal_mfu.py
Former-commit-id: 1cdbb4c774d463969c6be14fb08d92c7a0bdb565
2024-09-07 23:21:35 +08:00
hoshi-hiyouga
570ce100c1 fix #5384
Former-commit-id: 2e86c54f381f7403c30ba78d2acf5003aab6e049
2024-09-07 01:21:14 +08:00
hiyouga
27547355e6 tiny fix
Former-commit-id: c0e9c0484dae6db93cef5048bad827ff22b1986a
2024-09-05 23:41:16 +08:00
hiyouga
c5ef52a67a fix ci
Former-commit-id: b5ffca5a190f3aed8ba8c49bd8cf3239fb787bf5
2024-09-05 22:39:47 +08:00
hiyouga
b48b47d519 fix ci
Former-commit-id: cf0758b03e9b8b4931ba790a9726b8256ee4286c
2024-09-05 22:27:48 +08:00
hiyouga
9bdba2f6a8 add e2e tests
Former-commit-id: 0156a37450604641c4f5f9756ad84324698fc88c
2024-09-05 21:52:28 +08:00
hoshi-hiyouga
d6ce902d80 Merge pull request #5372 from LDLINGLINGLING/main
Added support for minicpm3.0

Former-commit-id: 2e3c221d9c87bd59f48648be8878b7b50347280f
2024-09-05 21:35:42 +08:00
liudan
ce6dcf3600 Revised the code according to the code style guidelines
Former-commit-id: fe5351980b42e0e38175b0da2401a61b3807fa7c
2024-09-05 20:17:55 +08:00
hoshi-hiyouga
e7f92d16d8 fix #5366
Former-commit-id: b0a4964846dd5be7aa2c54d43f28ba62985587f1
2024-09-05 18:08:09 +08:00
hiyouga
abd26f5f67 update data readme
Former-commit-id: 0af5f054b7b8da8b39eb44b1dfa76050f0c45667
2024-09-05 04:44:49 +08:00
hiyouga
4d35ace75e update data readme
Former-commit-id: 81adb153b7d0b30e6cd50c9bf4ca1ccf17458611
2024-09-05 04:25:27 +08:00
hiyouga
72222d1598 support Yi-Coder models
Former-commit-id: ea3f1659e70541c4fa8b7079a0a8c94fce9a41c8
2024-09-05 03:12:24 +08:00
hiyouga
26d914b8fc fix ci
Former-commit-id: 280c0f3f2cea4dfced797cc0e15f72b8b3a93542
2024-09-05 03:02:59 +08:00
hiyouga
7b01c0676c fix ci
Former-commit-id: 7899b44b19c3d0a70706d987bb7d2e0e3536014b
2024-09-05 02:49:22 +08:00
hiyouga
571a9b8669 update ci
Former-commit-id: e24bf7345442701ca874d439f0ca3da49fa59a84
2024-09-05 02:26:10 +08:00
hoshi-hiyouga
ed35eb1e9e Merge pull request #5365 from hiyouga/video_finetuning
Support Qwen2-VL Fine-Tuning on Video Datasets

Former-commit-id: 178cc3fbc48bf2c68723b487681db04e660b12fa
2024-09-05 02:24:58 +08:00
hiyouga
d291e0d60d tiny fix
Former-commit-id: 9da6e084e1e5daf7403e7fabeaaec686167fb11f
2024-09-05 02:16:49 +08:00
hiyouga
1874d579c5 video datasets
Former-commit-id: 33f28ce82d9e44d2615909250dc56d6a4a03cd99
2024-09-05 02:04:17 +08:00
liudan
c692339020 Added support for minicpm3.0
Former-commit-id: 4ad3a761af2452ef3f6c61190b7e47c9ea5227b9
2024-09-04 23:10:05 +08:00
hiyouga
2c1eef34cb fix test
Former-commit-id: 553a83aff9f9da35c9a0eca81f7d2b0bf2adf6ff
2024-09-04 22:38:26 +08:00
hiyouga
af178cbcd1 update get template
Former-commit-id: 21ea0d0786f91c0bce79630963e66b815a6792a0
2024-09-04 22:36:20 +08:00
hoshi-hiyouga
5d85be31ca Merge pull request #5323 from naem1023/feat/add-dataset-map-batch-size-argument
Add a batch size argument for the map function in dataset preprocessing

Former-commit-id: c3428c5807500d087cdee4386798e10e39c9cf30
2024-09-04 22:09:36 +08:00
hoshi-hiyouga
372b71c847 fix #5228
Former-commit-id: 0d332ca8d0987c0331361934ab110fafa6402a7e
2024-09-04 19:10:30 +08:00
hiyouga
41a9c415e1 fix #5252
Former-commit-id: 73f30b4dfffb260e24f9e2332617b8ca2c249ed5
2024-09-04 03:17:54 +08:00
hiyouga
915e32a5f8 add vl_feedback dataset
Former-commit-id: 6ff34ad2db383b5fbd51008bcc5eec880658811e
2024-09-04 03:13:03 +08:00
hiyouga
f4dd429cbf fix #5344
Former-commit-id: 9d445c0b5be5ccc0e6d1979e76a869ddf92d9534
2024-09-04 03:06:06 +08:00
hoshi-hiyouga
7435cde2ef Merge pull request #5346 from hiyouga/lazy_image
[exp] Lazyload for multimodal inputs

Former-commit-id: 4bbd721361a8c5888b28f5fcdcbb2a4ad2305445
2024-09-04 03:00:53 +08:00
hiyouga
7056087e92 lazy image load
Former-commit-id: cdd733b575411e003bc5ffd6560dd8eff8aa09cf
2024-09-04 02:27:08 +08:00
hiyouga
fed7ae5661 fix #5334
Former-commit-id: a5ea0f83f00c81d128a1f50ce244866ce38ee15f
2024-09-03 19:09:42 +08:00
hiyouga
5019c6148b fix #5338
Former-commit-id: a66ddfea218feefde50fa097d20b4bcbe89ab791
2024-09-03 17:45:17 +08:00
hiyouga
2e1396cd6b lint
Former-commit-id: d821d933e6cb982d648a69f85f6ad01d0560ed70
2024-09-03 00:46:25 +08:00
hiyouga
b5e9df5df8 fix #5324
Former-commit-id: f7aa06c9c0b18c28419ea5792410915d3f322cbf
2024-09-02 23:56:21 +08:00
naem1023
3622856994 feat: add a batch size argument for the map function in dataset preprocessing
Former-commit-id: 94b6cf06c2f84d0619b1a2dccaf8abb51de9951c
2024-09-02 13:52:47 +09:00
hoshi-hiyouga
7367c6ec21 fix trainer predict
Former-commit-id: 2790790cd26c6743105555a60523b89f367ebce3
2024-09-02 10:15:29 +08:00
hoshi-hiyouga
6579ec8c4c remove .cpu()
Former-commit-id: 35c57cc9dcba305d40282a9757ddc23968c210ac
2024-09-02 10:10:53 +08:00
hiyouga
a7fbae47d5 fix mm inference
Former-commit-id: fa782c15a07ed40f8a6381acdf2da395377efd04
2024-09-02 01:47:40 +08:00
hiyouga
f203a9d78e tiny fix
Former-commit-id: 8b4f408da110d74285bae20bbd969013a979964b
2024-09-02 01:33:22 +08:00
hiyouga
bae73e676c add image num check
Former-commit-id: 15201113bf16b748c0a758c7a5b363da8272e0e6
2024-09-02 01:31:36 +08:00
hiyouga
806e1061d4 add pokemon dataset
Former-commit-id: 06680158a0f0a1e3c542e77af92ac877fbe357c5
2024-09-02 01:02:25 +08:00
hiyouga
f920091667 update readme
Former-commit-id: 25a05d9f96718e06ce83f5bb1f41d2c001790295
2024-09-01 23:32:39 +08:00
hiyouga
801979f779 update wechat
Former-commit-id: 7f88dfe080db10ff12d1fb80b43099a356c899ea
2024-09-01 23:30:57 +08:00
hoshi-hiyouga
df2d32e7aa Merge pull request #5317 from ByronHsu/patch-1
Add liger kernel link

Former-commit-id: a319b3cf9119fd49cbcfb17b963e111a2f86bb51
2024-09-01 23:30:12 +08:00
hiyouga
60cf12727b add rlhf-v dataset
Former-commit-id: 3fd18fc34a0c994a738504746abfd5548e002437
2024-09-01 22:57:41 +08:00
hiyouga
7621526d22 tiny fix
Former-commit-id: 8ccaae3871d8d1fe3ea4633d427aecb2ab3addec
2024-09-01 21:15:44 +08:00
hiyouga
559b84dceb fix bug
Former-commit-id: 6e19e56000dd18d5faf84ceabce8d7708ff21e4d
2024-09-01 21:07:49 +08:00
hiyouga
7e4c5d4bb3 fix mixed mm inputs and rlhf-v
Former-commit-id: 7c248fac20bf85d57a91132ce7a793c7f84e9218
2024-09-01 20:52:47 +08:00
Byron Hsu
2a4ed6610e Add liger kernel link
Former-commit-id: 4f313044cf8efd9c6ebcbd4741f6f38d56804b7f
2024-08-30 17:16:16 -07:00
hiyouga
1d8e9c7897 fix ci (temp)
Former-commit-id: 9ebaafd2e5c16ecef0243e4df77344ed7c823e57
2024-08-31 02:03:56 +08:00
hiyouga
43654028eb add test mm plugin
Former-commit-id: ddea5cca5a3174de1dcc7fdee8ec69e77700b6bf
2024-08-31 01:53:38 +08:00
hiyouga
2f6fc27c8b remove visual_inputs, fix qlora
Former-commit-id: be30c01c4f1482520ece770bd54c6a4837c26f0a
2024-08-31 00:24:51 +08:00
hiyouga
d789b667d7 optimize predict vram
Former-commit-id: a577e44eee351b3ed8011a33ae01cd713354ff97
2024-08-30 23:08:45 +08:00
hiyouga
66a1abac6a add examples
Former-commit-id: 169c68921b1b8ac279834b060d9e7d38a56fe1aa
2024-08-30 21:43:19 +08:00
hiyouga
665db18661 tiny fix
Former-commit-id: 830511a6d0216da99520aee8b3a753d347a71fa9
2024-08-30 03:21:50 +08:00
hiyouga
30d97ca879 fix #5307
Former-commit-id: 63c19ddfe483a16c1c9afc2f1441e8070bb0f7e4
2024-08-30 02:45:40 +08:00
hiyouga
c62a6ca59d refactor mm training
Former-commit-id: 179c0558699e287cbf38a2d73bff47e86d589c5a
2024-08-30 02:14:31 +08:00
hoshi-hiyouga
77c2c7076b Merge pull request #5290 from simonJJJ/qwen2_vl
support qwen2-vl

Former-commit-id: 7156f832af8505b26371559d340c0e69eb962bbc
2024-08-30 02:10:36 +08:00
hoshi-hiyouga
7466fd4387 fix bug
Former-commit-id: 365e6df71509569f59c40743c115f1a4b945ef0f
2024-08-30 02:05:26 +08:00
hiyouga
c1369a1ec9 update liger kernel
Former-commit-id: d6bf6ca2161c99dd5d644e31d2b1df451017b68c
2024-08-29 20:46:08 +08:00
hiyouga
d677fe053d fix #5292
Former-commit-id: dd81ce8ce5fdf450027c5f9634abb6ac2cd52128
2024-08-29 20:37:47 +08:00
hiyouga
7c6785d3df fix #5295
Former-commit-id: c76873b0eb8225f6e6bfc7223c6012387dceb8ed
2024-08-29 20:30:18 +08:00
hiyouga
77341ee3c4 fix #5305
Former-commit-id: a710ebaf97c258c802f24e508d83f1f3f10edc6d
2024-08-29 20:16:01 +08:00
simonJJJ
5b4b60cfb5 update
Former-commit-id: a968a416d5e513320c97109229ca1e6ddc003cb1
2024-08-28 20:22:46 +08:00
simonJJJ
0f3d54d8a0 initial-commit
Former-commit-id: b6a39847a10b417b09db4b5512dd835e9e4ce928
2024-08-28 16:51:35 +08:00
hiyouga
7272792f65 update wechat
Former-commit-id: ef91752cc6f53088eaf7fc2f64f7148821d82ec2
2024-08-27 12:55:23 +08:00
hiyouga
4cc8e16595 add extra requires
Former-commit-id: c47511773ae9886aae4e5ea1841866d2125abc34
2024-08-27 12:52:12 +08:00
hiyouga
ca5a759f94 tiny fix
Former-commit-id: d2cede7023bbe28525ef8b4ad27247445d8c22e5
2024-08-27 12:49:32 +08:00
hoshi-hiyouga
be51e56a2e Merge pull request #5237 from marko1616/patch-1
Fix mllm api

Former-commit-id: 017703c7ab7f3dc566792619537c3202ca4f4bb7
2024-08-27 12:24:43 +08:00
marko1616
3a9171e275 ruff pass.
Former-commit-id: c2f817772f8e7d947dca04f546befc70001abe64
2024-08-27 11:30:16 +08:00
marko1616
bd0f3b4050 Update chat.py
Former-commit-id: 4e5893a5c4a47ff3cb989bbef0841effc713fc08
2024-08-27 11:27:56 +08:00
hiyouga
206a8364d4 support liger kernel
Former-commit-id: 0f4e54abf6c5feb2329855a4047597ad5147720a
2024-08-27 11:20:14 +08:00
marko1616
097d031066 Force re check.
Former-commit-id: 5f04452f7d65e535d0af08944f7b9e29e85f51d7
2024-08-23 14:43:18 +08:00
marko1616
2674b42b59 Update chat.py
Former-commit-id: 206a16c17d253956afb96daea6f24478e17334fc
2024-08-22 12:24:34 +08:00
marko1616
edf2e51bbc Update chat.py
Former-commit-id: edf6dc1995daa6c3635c3fda1052b340693a04f5
2024-08-22 12:14:34 +08:00
MengqingCao
47877acc2a update npu base image
Former-commit-id: 20819f7707cfff6b951484e91fc7ecda2bf68528
2024-08-21 09:12:38 +00:00
hiyouga
d111a324bc tiny fix
Former-commit-id: 23961bdf6fdbcde64e7b943f699fdeb4ac024043
2024-08-20 00:10:52 +08:00
hoshi-hiyouga
388f0a6e05 Merge pull request #5156 from YeQiuO/main
fix Llama-template's system prompt bug

Former-commit-id: 0b57175d3bd029675dae2f55995b7eeb4e9adc7a
2024-08-20 00:09:03 +08:00
hoshi-hiyouga
8c13c02c55 Update template.py
Former-commit-id: f5a075cb1c90f05bb0de26c6aea718f556c54623
2024-08-20 00:03:33 +08:00
hoshi-hiyouga
a101fde917 Merge pull request #5163 from liu-zichen/fix_ppo_optim
fix lr not change

Former-commit-id: f3c03ec6a89bf57f290820fa31eda24291355e4e
2024-08-19 23:56:24 +08:00
hoshi-hiyouga
1f4373b6e5 Merge pull request #5185 from chenhuiyu/feature/add-sailorllm-template
Add SailorLLM template

Former-commit-id: 28387d6b2f9e3bcc6321345c46b525c8180ebf7e
2024-08-19 23:51:49 +08:00
hoshi-hiyouga
525747b472 Merge pull request #5188 from Zxilly/main
fix: report correct device count for intel xpu
Former-commit-id: cd3c536cb3936061d905256850b0e57df4498010
2024-08-19 23:51:39 +08:00
hoshi-hiyouga
472f12c985 Merge pull request #5193 from Ricardo-L-C/main
_is_bf16_available check supports NPU

Former-commit-id: 18b9ac49c45af773a2ea563f5e1852dc4b775db8
2024-08-19 23:40:59 +08:00
hoshi-hiyouga
b681f24f43 Update template.py
Former-commit-id: c6822a217e1c296f4aedd9a2c7610acd1dbd443e
2024-08-19 23:40:16 +08:00
hiyouga
fd02b089b6 update readme
Former-commit-id: 756e438866876fa54495cf557dd1e299b17a42fb
2024-08-19 23:32:04 +08:00
Ricardo
57d4c4a4f8 _is_bf16_available check supports NPU
Former-commit-id: 50a1e892a1005b4cdd82dca1005f71db08ed89a2
2024-08-16 02:58:22 +00:00
Zxilly
3595d26846 fix: report correct device count for intel xpu
Former-commit-id: 0618f660b6511599365bd9be64499dbab41a79ba
2024-08-15 08:30:43 +00:00
Huiyu Chen
22a79c169d Add SailorLLM template
Former-commit-id: a594abe0321a718394a97b5a48ded16e2012c1f0
2024-08-15 15:10:14 +08:00
liu-zichen
75dfe259cf fix lr not change
Former-commit-id: 387dd2d51b5d8cd666459040fdd16525b34720d9
2024-08-13 16:33:34 +08:00
codingma
2e257d6af0 add tutorial and doc links
Former-commit-id: 4f6072562a34e0ec97471210ff54244cf0d0f3df
2024-08-13 16:13:10 +08:00
“Wzw”
e734222373 fix Llama-template's system prompt bug
Former-commit-id: 2e3eddcd0918b0c968ded0df7c82e3dcff870381
2024-08-12 19:22:12 +08:00
hiyouga
6a351b9912 update readme
Former-commit-id: 4fecc5ee56873a7ab4941e46a5168cfe2ecb4bb6
2024-08-10 10:17:35 +08:00
hiyouga
cfc04aa162 update readme
Former-commit-id: fa7bc9f1c7347153f9092ffbbb8e88c6b2f59632
2024-08-09 20:46:02 +08:00
hiyouga
943c795318 add magpie ultra dataset
Former-commit-id: 3317b24329b87e30f13a78936ac5554f211abf7a
2024-08-09 20:28:55 +08:00
hiyouga
7fb61bad04 add qwen2 math models
Former-commit-id: 72ff43a1772c9de5ff914d5e1c8bdc8dea9ae0c8
2024-08-09 20:20:35 +08:00
hiyouga
47efcdb1dd update examples
Former-commit-id: d5c57c8b7f64afe8061045ec9689abbac45c1175
2024-08-09 20:13:46 +08:00
hiyouga
59cbce1a46 add adam_mini to readme
Former-commit-id: d610c6bcf8a8ba6f4236f5d11f79571b83f4fb11
2024-08-09 20:02:03 +08:00
hoshi-hiyouga
7e755e9cac Merge pull request #5095 from relic-yuexi/feat-optimizer
Feat optimizer

Former-commit-id: f08390d252d42a812b71a08daba7339cc40889b7
2024-08-09 19:51:33 +08:00
hiyouga
9d1e2c3c1f update scripts
Former-commit-id: dabf5a1dc661a6581474c6a5ec115322d168ed5f
2024-08-09 19:16:23 +08:00
hiyouga
5af32ce705 follow #5115
Former-commit-id: 7d917e03e2df570139bae18227d9c7303a12de2a
2024-08-09 18:03:00 +08:00
hoshi-hiyouga
4e8861e653 Merge pull request #5115 from YeQiuO/main
fix: `Train on the last turn only` truncation bug
Former-commit-id: 2c6dae45f7a7b72c961489ac407b1b444ab7752e
2024-08-09 17:58:27 +08:00
hoshi-hiyouga
d4d7ffb17c Merge pull request #5072 from relic-yuexi/main
fix the deepseekcoder template to avoid the repetition problem

Former-commit-id: 2ae7d5c91725eab9f994015d8d3577894c7978b6
2024-08-09 16:35:21 +08:00
hoshi-hiyouga
46f834ec75 Update template.py
Former-commit-id: ae2a5221c109ae3474d219c37433be767abbee91
2024-08-09 16:27:42 +08:00
“Wzw”
6ec64a7e56 verify that the mask_history argument is valid
Former-commit-id: 2f8388b4f4195d934400ad9267d72e10ca4105a3
2024-08-08 10:12:01 +08:00
“Wzw”
d71446e387 fix mask_history tiny bug
Former-commit-id: cac07aac6196be026f723b2397a343d4fb675973
2024-08-08 10:09:33 +08:00
codingma
eada49e56b fix eval_dataset in example
Former-commit-id: e1ffc54f7e58419cc8da958a4d3c2697e18d5583
2024-08-07 18:24:19 +08:00
moontidef
8f42d7df56 feat: add support for adammini
Former-commit-id: a2d5fafb705ff44db1711e972490f0abebc2012b
2024-08-07 10:08:22 +08:00
moontidef
33a90b9026 fix: rename optimzer to optimizer
Former-commit-id: 186dc1fde822e6a603ac273538741ea3853f243e
2024-08-07 10:05:01 +08:00
moontidef
710902b0d0 Merge branch 'hiyouga:main' into main
Former-commit-id: d1b23283e0e4286f126d38d7bdc55802f74c8922
2024-08-06 00:18:45 +08:00
moontidef
7b4f5d3b21 fix: fix the deepseekcoder template to avoid the repetition problem
Former-commit-id: 56294831115f095135f72490a8a435434b2f0a11
2024-08-05 23:55:45 +08:00
hiyouga
13093963b1 fix #5048
Former-commit-id: 71a6861667ae68c1fd6a69acf68e1359b858cf1b
2024-08-05 23:48:19 +08:00
hoshi-hiyouga
2e477e7458 Merge pull request #5037 from codemayq/feature-gemma-2-2b
support gemma-2-2b

Former-commit-id: 6af51fadff92cd3e665c556ac073a1876f792ada
2024-08-05 23:27:37 +08:00
codingma
4b6252151e support gemma-2-2b
Former-commit-id: 7037192cf6049fd7d675aed4a6237ed929c6b170
2024-08-01 13:45:48 +08:00
hoshi-hiyouga
f3765d1996 Merge pull request #5010 from Eruly/main
Add Korean web UI (llamafactory-cli webui)

Former-commit-id: 2050806aa826028df45c0c746b4314afe178dcd3
2024-07-30 01:55:54 +08:00
hoshi-hiyouga
1f5cdd66b7 Merge pull request #4996 from LDLINGLINGLING/main
Added MiniCPM to the supported models list on the project homepage; the official MiniCPM GitHub also added a friendly link to LLama_factory.

Former-commit-id: a86a776fb0f75697b0fee7694a5a0d6bd04fee0a
2024-07-30 01:55:30 +08:00
hoshi-hiyouga
5b0ddbb835 Update README_zh.md
Former-commit-id: 922906faf2d432def7cfdac82f90472fa1bb24a9
2024-07-30 01:55:13 +08:00
hoshi-hiyouga
4f92b56f06 Update README.md
Former-commit-id: 6bc7f71940be0a8f1614f9036b9c539ce46d34e1
2024-07-30 01:53:19 +08:00
hoshi-hiyouga
a1f6ff92be Update README.md
Former-commit-id: 54eecdec0da06677ea55847c74642d0fc12d8908
2024-07-30 01:52:35 +08:00
hoshi-hiyouga
ef98e91618 Merge pull request #4995 from codemayq/fix-pissa
fix pissa callback

Former-commit-id: 052c0f6bd9e872ea325b5a6aef98c4c070733384
2024-07-30 01:47:25 +08:00
eruly
9fdf800750 Add Korean web UI (llamafactory-cli webui)
Former-commit-id: 357a035f2aeb9548368c230c5a17dcdfa4844b17
2024-07-29 13:47:13 +00:00
liudan
32c698e4c2 Added MiniCPM to the supported models list on the project homepage; the official MiniCPM GitHub also added a friendly link to LLama_factory.
Former-commit-id: f482a6e2fd30aff5113e53f3f07b4649982bcc2e
2024-07-29 10:58:28 +08:00
codingma
75e80fa820 fix pissa save
Former-commit-id: 25a1dad7c8df79c15efecb8c6f871a13a327f57a
2024-07-29 10:44:34 +08:00
hiyouga
f8329bc632 tiny fix
Former-commit-id: 183d8bd500a8e9513a077161ba8e8d61bea9200f
2024-07-26 11:51:00 +08:00
hoshi-hiyouga
9f74d36ba4 Merge pull request #4892 from piamo/main
update deepseek template

Former-commit-id: 3233efc8404972098665286d9dec7312dd6ecfab
2024-07-26 11:49:34 +08:00
hoshi-hiyouga
fc2435f135 Merge pull request #4950 from liuwwang/main
fix: Repair the issue where quantization failed after merging the adapter.
Former-commit-id: 93a68ea1f4372973f745a2c250250ecaac515e27
2024-07-26 11:48:56 +08:00
hoshi-hiyouga
0636519ba3 Merge pull request #4970 from HardAndHeavy/add-rocm
Add ROCm support

Former-commit-id: c0f21d869bce6e59825d57c66bce3fe54f50065f
2024-07-26 11:41:23 +08:00
hoshi-hiyouga
573bf03a6f Update README_zh.md
Former-commit-id: 86a27a97ff67b0d4bcd671c62759cd049542dc1b
2024-07-26 11:30:57 +08:00
hoshi-hiyouga
9e529be4e7 Update README.md
Former-commit-id: 1c167bb2ea3a47bdeeccc044a653662132c61698
2024-07-26 11:29:28 +08:00
hoshi-hiyouga
7af4ffa6cc Update README.md
Former-commit-id: d6e7a69c274c3756587e18a039637dd37fa152b2
2024-07-26 11:29:09 +08:00
HardAndHeavy
5b67ccd1c6 Add ROCm support
Former-commit-id: cf9df10a24936efd420b0fdac541fd6c0808a327
2024-07-25 21:29:28 +03:00
khazic
5166dbbcd3 Added the reference address for TRL PPO details.
Former-commit-id: 509c55608643eae3a6456683d425a7c636cfc3e9
2024-07-25 09:03:21 +08:00
hiyouga
21adb09730 fix #4959
Former-commit-id: 96e8a1d47874708c6157865c78be4cd6c533e01b
2024-07-24 23:44:00 +08:00
hiyouga
28b5f656db update webui
Former-commit-id: 463edec1b1c1345afc791e225deb33f118f3582e
2024-07-24 21:11:51 +08:00
hoshi-hiyouga
68ee2d512f Update README_zh.md
Former-commit-id: 1443e876697e18108573387e501a7453ba9fc06c
2024-07-24 21:08:42 +08:00
hoshi-hiyouga
a5f7e0efc6 Update README.md
Former-commit-id: 07d86e38cfd857d1dfa898541f3e5bd9c6f11581
2024-07-24 21:07:14 +08:00
hiyouga
211038584a tiny fix
Former-commit-id: 28cac0e325bfd7a6c0c344ad2d46511613190cd7
2024-07-24 18:33:39 +08:00
hiyouga
ff5ba97970 fix #4928
Former-commit-id: 6d557e8959678f9d4edbcb3d5a6dfba14b429b18
2024-07-24 17:00:29 +08:00
hiyouga
27f2c3cae1 fix #4925
Former-commit-id: 79c336e2339974471627487858d59e4ed2152370
2024-07-24 16:56:58 +08:00
hiyouga
48f0819327 fix #4944
Former-commit-id: 9e8cf3b21a0b12d1413c3c7f3d60399784909242
2024-07-24 16:42:51 +08:00
hiyouga
5c6d88e91c add mistral nemo model
Former-commit-id: 428bb49f53b32947bc0a62ca19ab10844154c07c
2024-07-24 16:25:53 +08:00
hiyouga
0a04d9470f add llama3.1
Former-commit-id: 3c433890f9b61c520572f5233aae70584da0f330
2024-07-24 16:20:11 +08:00
Liuww
f0408c0dde fix: Repair the issue where quantization failed after merging the adapter.
Former-commit-id: 8109561b7f577d448f8bca7e569f7f443cf6bb52
2024-07-24 14:31:29 +08:00
hiyouga
a041f4a111 tiny fix
Former-commit-id: bf6a2f032c598f969708c1c3db4875d6239c41a9
2024-07-22 21:10:15 +08:00
hoshi-hiyouga
cdf9dae53e fix #4917
Former-commit-id: e26919aafd8436489d065789c9c25d72c8d05a6d
2024-07-22 11:28:31 +08:00
hiyouga
1917f431f5 tiny fix
Former-commit-id: 9133316e558a3c8744f5eb6ab8678686bf4859ed
2024-07-22 00:06:03 +08:00
hiyouga
a770afbff2 fix flashattn + packing
Former-commit-id: 4adc6ce4abc718c25f39b316bfc3352d0d01ed1e
2024-07-21 17:07:45 +08:00
huangpan.foo
b1a5bf025b update deepseek template
Former-commit-id: f5ca86ec95bb301df42ffaa6923fc3037a224e34
2024-07-19 15:02:54 +08:00
hiyouga
adff3e5050 set dev version
Former-commit-id: 0b9a2275dc533b65578278f979ce053e95a644b3
2024-07-19 02:01:46 +08:00
181 changed files with 7514 additions and 2936 deletions

.dockerignore

@@ -7,6 +7,8 @@ data
docker docker
saves saves
hf_cache hf_cache
ms_cache
om_cache
output output
.dockerignore .dockerignore
.gitattributes .gitattributes

.env.local (new file, 37 lines)

@@ -0,0 +1,37 @@
# Note: actually we do not support .env, just for reference
# api
API_HOST=
API_PORT=
API_KEY=
API_MODEL_NAME=
FASTAPI_ROOT_PATH=
MAX_CONCURRENT=
# general
DISABLE_VERSION_CHECK=
FORCE_CHECK_IMPORTS=
LLAMAFACTORY_VERBOSITY=
USE_MODELSCOPE_HUB=
USE_OPENMIND_HUB=
RECORD_VRAM=
# torchrun
FORCE_TORCHRUN=
MASTER_ADDR=
MASTER_PORT=
NNODES=
NODE_RANK=
NPROC_PER_NODE=
# wandb
WANDB_DISABLED=
WANDB_PROJECT=
WANDB_API_KEY=
# gradio ui
GRADIO_SHARE=
GRADIO_SERVER_NAME=
GRADIO_SERVER_PORT=
GRADIO_ROOT_PATH=
GRADIO_IPV6=
# setup
ENABLE_SHORT_CONSOLE=1
# reserved (do not use)
LLAMABOARD_ENABLED=
LLAMABOARD_WORKDIR=
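
For reference, a minimal sketch of how a few of these variables might be set before launching the web UI. The selection and values are illustrative only, and, as the note at the top of the file says, they are read from the environment rather than from an actual `.env` file:

```bash
# Illustrative values; any variable listed above can be exported the same way.
export GRADIO_SERVER_NAME=0.0.0.0
export GRADIO_SERVER_PORT=7860
export USE_MODELSCOPE_HUB=1
llamafactory-cli webui
```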

CONTRIBUTING.md

@@ -19,3 +19,49 @@ There are several ways you can contribute to LLaMA Factory:
### Style guide ### Style guide
LLaMA Factory follows the [Google Python Style Guide](https://google.github.io/styleguide/pyguide.html), check it for details. LLaMA Factory follows the [Google Python Style Guide](https://google.github.io/styleguide/pyguide.html), check it for details.
### Create a Pull Request
1. Fork the [repository](https://github.com/hiyouga/LLaMA-Factory) by clicking on the [Fork](https://github.com/hiyouga/LLaMA-Factory/fork) button on the repository's page. This creates a copy of the code under your GitHub user account.
2. Clone your fork to your local disk, and add the base repository as a remote:
```bash
git clone git@github.com:[username]/LLaMA-Factory.git
cd LLaMA-Factory
git remote add upstream https://github.com/hiyouga/LLaMA-Factory.git
```
3. Create a new branch to hold your development changes:
```bash
git checkout -b dev_your_branch
```
4. Set up a development environment by running the following command in a virtual environment:
```bash
pip install -e ".[dev]"
```
If LLaMA Factory was already installed in the virtual environment, remove it with `pip uninstall llamafactory` before reinstalling it in editable mode with the -e flag.
5. Check code before commit:
```bash
make commit
make style && make quality
make test
```
6. Submit changes:
```bash
git add .
git commit -m "commit message"
git fetch upstream
git rebase upstream/main
git push -u origin dev_your_branch
```
7. Create a merge request from your branch `dev_your_branch` at [origin repo](https://github.com/hiyouga/LLaMA-Factory).

.github/workflows/tests.yml

@@ -3,14 +3,14 @@ name: tests
on: on:
push: push:
branches: branches:
- main - "main"
paths: paths:
- "**.py" - "**.py"
- "requirements.txt" - "requirements.txt"
- ".github/workflows/*.yml" - ".github/workflows/*.yml"
pull_request: pull_request:
branches: branches:
- main - "main"
paths: paths:
- "**.py" - "**.py"
- "requirements.txt" - "requirements.txt"
@@ -18,13 +18,27 @@ on:
jobs: jobs:
tests: tests:
runs-on: ubuntu-latest strategy:
fail-fast: false
matrix:
python-version:
- "3.8" # TODO: remove py38 in next transformers release
- "3.9"
- "3.10"
- "3.11"
os:
- "ubuntu-latest"
- "windows-latest"
- "macos-13"
runs-on: ${{ matrix.os }}
environment: environment:
name: tests name: tests
env: env:
HF_TOKEN: ${{ secrets.HF_TOKEN }} HF_TOKEN: ${{ secrets.HF_TOKEN }}
OS_NAME: ${{ matrix.os }}
steps: steps:
- name: Checkout - name: Checkout
@@ -33,7 +47,7 @@ jobs:
- name: Set up Python - name: Set up Python
uses: actions/setup-python@v5 uses: actions/setup-python@v5
with: with:
python-version: "3.8" python-version: ${{ matrix.python-version }}
cache: "pip" cache: "pip"
cache-dependency-path: "setup.py" cache-dependency-path: "setup.py"

.gitignore

@@ -159,7 +159,13 @@ cython_debug/
# option (not recommended) you can uncomment the following to ignore the entire idea folder. # option (not recommended) you can uncomment the following to ignore the entire idea folder.
.idea/ .idea/
# vscode
.vscode/
# custom .gitignore # custom .gitignore
ms_cache/
hf_cache/
om_cache/
cache/ cache/
config/ config/
saves/ saves/

.pre-commit-config.yaml (new file, 28 lines)

@@ -0,0 +1,28 @@
repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v5.0.0
hooks:
- id: check-ast
- id: check-added-large-files
args: ['--maxkb=25000']
- id: check-merge-conflict
- id: check-yaml
- id: debug-statements
- id: end-of-file-fixer
- id: trailing-whitespace
args: [--markdown-linebreak-ext=md]
- id: no-commit-to-branch
args: ['--branch', 'main']
- repo: https://github.com/asottile/pyupgrade
rev: v3.17.0
hooks:
- id: pyupgrade
args: [--py38-plus]
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.6.9
hooks:
- id: ruff
args: [--fix]
- id: ruff-format
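
A minimal sketch of enabling these hooks locally; it assumes `pre-commit` itself is available (for example via the `[dev]` extras mentioned in the contributing guide above):

```bash
pre-commit install          # register the hooks above as a git pre-commit hook
pre-commit run --all-files  # run them once over the entire tree
```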

Makefile

@@ -1,6 +1,13 @@
.PHONY: quality style test .PHONY: build commit quality style test
check_dirs := scripts src tests check_dirs := scripts src tests setup.py
build:
pip install build && python -m build
commit:
pre-commit install
pre-commit run --all-files
quality: quality:
ruff check $(check_dirs) ruff check $(check_dirs)
@@ -11,4 +18,4 @@ style:
ruff format $(check_dirs) ruff format $(check_dirs)
test: test:
CUDA_VISIBLE_DEVICES= pytest tests/ CUDA_VISIBLE_DEVICES= WANDB_DISABLED=true pytest -vv tests/
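
Taken together with the contributing guide above, a local check before opening a pull request might look like the following sketch (it assumes the `[dev]` extras are installed):

```bash
make commit                  # pre-commit install && pre-commit run --all-files
make style && make quality   # ruff format / ruff check over scripts, src, tests, setup.py
make test                    # CUDA_VISIBLE_DEVICES= WANDB_DISABLED=true pytest -vv tests/
```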

README.md

@@ -4,7 +4,7 @@
[![GitHub Code License](https://img.shields.io/github/license/hiyouga/LLaMA-Factory)](LICENSE) [![GitHub Code License](https://img.shields.io/github/license/hiyouga/LLaMA-Factory)](LICENSE)
[![GitHub last commit](https://img.shields.io/github/last-commit/hiyouga/LLaMA-Factory)](https://github.com/hiyouga/LLaMA-Factory/commits/main) [![GitHub last commit](https://img.shields.io/github/last-commit/hiyouga/LLaMA-Factory)](https://github.com/hiyouga/LLaMA-Factory/commits/main)
[![PyPI](https://img.shields.io/pypi/v/llamafactory)](https://pypi.org/project/llamafactory/) [![PyPI](https://img.shields.io/pypi/v/llamafactory)](https://pypi.org/project/llamafactory/)
[![Citation](https://img.shields.io/badge/citation-72-green)](#projects-using-llama-factory) [![Citation](https://img.shields.io/badge/citation-93-green)](#projects-using-llama-factory)
[![GitHub pull request](https://img.shields.io/badge/PRs-welcome-blue)](https://github.com/hiyouga/LLaMA-Factory/pulls) [![GitHub pull request](https://img.shields.io/badge/PRs-welcome-blue)](https://github.com/hiyouga/LLaMA-Factory/pulls)
[![Discord](https://dcbadge.vercel.app/api/server/rKfvV9r9FK?compact=true&style=flat)](https://discord.gg/rKfvV9r9FK) [![Discord](https://dcbadge.vercel.app/api/server/rKfvV9r9FK?compact=true&style=flat)](https://discord.gg/rKfvV9r9FK)
[![Twitter](https://img.shields.io/twitter/follow/llamafactory_ai)](https://twitter.com/llamafactory_ai) [![Twitter](https://img.shields.io/twitter/follow/llamafactory_ai)](https://twitter.com/llamafactory_ai)
@@ -12,6 +12,7 @@
[![Open in DSW](https://gallery.pai-ml.com/assets/open-in-dsw.svg)](https://gallery.pai-ml.com/#/preview/deepLearning/nlp/llama_factory) [![Open in DSW](https://gallery.pai-ml.com/assets/open-in-dsw.svg)](https://gallery.pai-ml.com/#/preview/deepLearning/nlp/llama_factory)
[![Spaces](https://img.shields.io/badge/🤗-Open%20in%20Spaces-blue)](https://huggingface.co/spaces/hiyouga/LLaMA-Board) [![Spaces](https://img.shields.io/badge/🤗-Open%20in%20Spaces-blue)](https://huggingface.co/spaces/hiyouga/LLaMA-Board)
[![Studios](https://img.shields.io/badge/ModelScope-Open%20in%20Studios-blue)](https://modelscope.cn/studios/hiyouga/LLaMA-Board) [![Studios](https://img.shields.io/badge/ModelScope-Open%20in%20Studios-blue)](https://modelscope.cn/studios/hiyouga/LLaMA-Board)
[![SageMaker](https://img.shields.io/badge/SageMaker-Open%20in%20AWS-blue)](https://aws.amazon.com/cn/blogs/china/a-one-stop-code-free-model-fine-tuning-deployment-platform-based-on-sagemaker-and-llama-factory/)
[![GitHub Tread](https://trendshift.io/api/badge/repositories/4535)](https://trendshift.io/repositories/4535) [![GitHub Tread](https://trendshift.io/api/badge/repositories/4535)](https://trendshift.io/repositories/4535)
@@ -21,13 +22,22 @@
**Fine-tuning a large language model can be easy as...** **Fine-tuning a large language model can be easy as...**
https://github.com/hiyouga/LLaMA-Factory/assets/16256802/9840a653-7e9c-41c8-ae89-7ace5698baf6 https://github.com/user-attachments/assets/7c96b465-9df7-45f4-8053-bf03e58386d3
Choose your path: Choose your path:
- **Documentation (WIP)**: https://llamafactory.readthedocs.io/zh-cn/latest/
- **Colab**: https://colab.research.google.com/drive/1eRTPn37ltBbYsISy9Aw2NuI2Aq5CQrD9?usp=sharing - **Colab**: https://colab.research.google.com/drive/1eRTPn37ltBbYsISy9Aw2NuI2Aq5CQrD9?usp=sharing
- **PAI-DSW**: https://gallery.pai-ml.com/#/preview/deepLearning/nlp/llama_factory
- **Local machine**: Please refer to [usage](#getting-started) - **Local machine**: Please refer to [usage](#getting-started)
- **PAI-DSW**: [Llama3 Example](https://gallery.pai-ml.com/#/preview/deepLearning/nlp/llama_factory) | [Qwen2-VL Example](https://gallery.pai-ml.com/#/preview/deepLearning/nlp/llama_factory_qwen2vl)
- **Amazon SageMaker**: [Blog](https://aws.amazon.com/cn/blogs/china/a-one-stop-code-free-model-fine-tuning-deployment-platform-based-on-sagemaker-and-llama-factory/)
Recent activities:
- **2024/10/18-2024/11/30**: Build a personal tour guide bot using PAI+LLaMA Factory. [[website]](https://developer.aliyun.com/topic/llamafactory2)
> [!NOTE]
> Except for the above links, all other websites are unauthorized third-party websites. Please carefully use them.
## Table of Contents ## Table of Contents
@@ -46,11 +56,11 @@ Choose your path:
## Features ## Features
- **Various models**: LLaMA, LLaVA, Mistral, Mixtral-MoE, Qwen, Yi, Gemma, Baichuan, ChatGLM, Phi, etc. - **Various models**: LLaMA, LLaVA, Mistral, Mixtral-MoE, Qwen, Qwen2-VL, Yi, Gemma, Baichuan, ChatGLM, Phi, etc.
- **Integrated methods**: (Continuous) pre-training, (multimodal) supervised fine-tuning, reward modeling, PPO, DPO, KTO, ORPO, etc. - **Integrated methods**: (Continuous) pre-training, (multimodal) supervised fine-tuning, reward modeling, PPO, DPO, KTO, ORPO, etc.
- **Scalable resources**: 16-bit full-tuning, freeze-tuning, LoRA and 2/3/4/5/6/8-bit QLoRA via AQLM/AWQ/GPTQ/LLM.int8/HQQ/EETQ. - **Scalable resources**: 16-bit full-tuning, freeze-tuning, LoRA and 2/3/4/5/6/8-bit QLoRA via AQLM/AWQ/GPTQ/LLM.int8/HQQ/EETQ.
- **Advanced algorithms**: GaLore, BAdam, DoRA, LongLoRA, LLaMA Pro, Mixture-of-Depths, LoRA+, LoftQ, PiSSA and Agent tuning. - **Advanced algorithms**: [GaLore](https://github.com/jiaweizzhao/GaLore), [BAdam](https://github.com/Ledzy/BAdam), [Adam-mini](https://github.com/zyushun/Adam-mini), DoRA, LongLoRA, LLaMA Pro, Mixture-of-Depths, LoRA+, LoftQ, PiSSA and Agent tuning.
- **Practical tricks**: FlashAttention-2, Unsloth, RoPE scaling, NEFTune and rsLoRA. - **Practical tricks**: [FlashAttention-2](https://github.com/Dao-AILab/flash-attention), [Unsloth](https://github.com/unslothai/unsloth), [Liger Kernel](https://github.com/linkedin/Liger-Kernel), RoPE scaling, NEFTune and rsLoRA.
- **Experiment monitors**: LlamaBoard, TensorBoard, Wandb, MLflow, etc. - **Experiment monitors**: LlamaBoard, TensorBoard, Wandb, MLflow, etc.
- **Faster inference**: OpenAI-style API, Gradio UI and CLI with vLLM worker. - **Faster inference**: OpenAI-style API, Gradio UI and CLI with vLLM worker.
@@ -71,15 +81,27 @@ Compared to ChatGLM's [P-Tuning](https://github.com/THUDM/ChatGLM2-6B/tree/main/
## Changelog ## Changelog
[24/10/09] We supported downloading pre-trained models and datasets from the **[Modelers Hub](https://modelers.cn/models)**. See [this tutorial](#download-from-modelers-hub) for usage.
[24/09/19] We support fine-tuning the **[Qwen2.5](https://qwenlm.github.io/blog/qwen2.5/)** models.
[24/08/30] We support fine-tuning the **[Qwen2-VL](https://qwenlm.github.io/blog/qwen2-vl/)** models. Thank [@simonJJJ](https://github.com/simonJJJ)'s PR.
[24/08/27] We support **[Liger Kernel](https://github.com/linkedin/Liger-Kernel)**. Try `enable_liger_kernel: true` for efficient training.
[24/08/09] We support **[Adam-mini](https://github.com/zyushun/Adam-mini)** optimizer. See [examples](examples/README.md) for usage. Thank [@relic-yuexi](https://github.com/relic-yuexi)'s PR.
<details><summary>Full Changelog</summary>
[24/07/04] We support [contamination-free packed training](https://github.com/MeetKai/functionary/tree/main/functionary/train/packing). Use `neat_packing: true` to activate it. Thank [@chuan298](https://github.com/chuan298)'s PR.
[24/06/16] We support **[PiSSA](https://arxiv.org/abs/2404.02948)** algorithm. See [examples](examples/README.md) for usage. [24/06/16] We support **[PiSSA](https://arxiv.org/abs/2404.02948)** algorithm. See [examples](examples/README.md) for usage.
[24/06/07] We supported fine-tuning the **[Qwen2](https://qwenlm.github.io/blog/qwen2/)** and **[GLM-4](https://github.com/THUDM/GLM-4)** models. [24/06/07] We supported fine-tuning the **[Qwen2](https://qwenlm.github.io/blog/qwen2/)** and **[GLM-4](https://github.com/THUDM/GLM-4)** models.
[24/05/26] We supported **[SimPO](https://arxiv.org/abs/2405.14734)** algorithm for preference learning. See [examples](examples/README.md) for usage. [24/05/26] We supported **[SimPO](https://arxiv.org/abs/2405.14734)** algorithm for preference learning. See [examples](examples/README.md) for usage.
<details><summary>Full Changelog</summary> [24/05/20] We supported fine-tuning the **PaliGemma** series models. Note that the PaliGemma models are pre-trained models, you need to fine-tune them with `paligemma` template for chat completion.
[24/05/20] We supported fine-tuning the **PaliGemma** series models. Note that the PaliGemma models are pre-trained models, you need to fine-tune them with `gemma` template for chat completion.
[24/05/18] We supported **[KTO](https://arxiv.org/abs/2402.01306)** algorithm for preference learning. See [examples](examples/README.md) for usage. [24/05/18] We supported **[KTO](https://arxiv.org/abs/2402.01306)** algorithm for preference learning. See [examples](examples/README.md) for usage.
@@ -91,7 +113,7 @@ Compared to ChatGLM's [P-Tuning](https://github.com/THUDM/ChatGLM2-6B/tree/main/
[24/04/21] We supported **[Mixture-of-Depths](https://arxiv.org/abs/2404.02258)** according to [AstraMindAI's implementation](https://github.com/astramind-ai/Mixture-of-depths). See [examples](examples/README.md) for usage. [24/04/21] We supported **[Mixture-of-Depths](https://arxiv.org/abs/2404.02258)** according to [AstraMindAI's implementation](https://github.com/astramind-ai/Mixture-of-depths). See [examples](examples/README.md) for usage.
[24/04/16] We supported **[BAdam](https://arxiv.org/abs/2404.02827)**. See [examples](examples/README.md) for usage. [24/04/16] We supported **[BAdam](https://arxiv.org/abs/2404.02827)** optimizer. See [examples](examples/README.md) for usage.
[24/04/16] We supported **[unsloth](https://github.com/unslothai/unsloth)**'s long-sequence training (Llama-2-7B-56k within 24GB). It achieves **117%** speed and **50%** memory compared with FlashAttention-2, more benchmarks can be found in [this page](https://github.com/hiyouga/LLaMA-Factory/wiki/Performance-comparison). [24/04/16] We supported **[unsloth](https://github.com/unslothai/unsloth)**'s long-sequence training (Llama-2-7B-56k within 24GB). It achieves **117%** speed and **50%** memory compared with FlashAttention-2, more benchmarks can be found in [this page](https://github.com/hiyouga/LLaMA-Factory/wiki/Performance-comparison).
@@ -103,7 +125,7 @@ Compared to ChatGLM's [P-Tuning](https://github.com/THUDM/ChatGLM2-6B/tree/main/
[24/03/13] We supported **[LoRA+](https://arxiv.org/abs/2402.12354)**. See [examples](examples/README.md) for usage. [24/03/13] We supported **[LoRA+](https://arxiv.org/abs/2402.12354)**. See [examples](examples/README.md) for usage.
[24/03/07] We supported gradient low-rank projection (**[GaLore](https://arxiv.org/abs/2403.03507)**) algorithm. See [examples](examples/README.md) for usage. [24/03/07] We supported **[GaLore](https://arxiv.org/abs/2403.03507)** optimizer. See [examples](examples/README.md) for usage.
[24/03/07] We integrated **[vLLM](https://github.com/vllm-project/vllm)** for faster and concurrent inference. Try `infer_backend: vllm` to enjoy **270%** inference speed. [24/03/07] We integrated **[vLLM](https://github.com/vllm-project/vllm)** for faster and concurrent inference. Try `infer_backend: vllm` to enjoy **270%** inference speed.
@@ -119,7 +141,7 @@ Compared to ChatGLM's [P-Tuning](https://github.com/THUDM/ChatGLM2-6B/tree/main/
[23/12/12] We supported fine-tuning the latest MoE model **[Mixtral 8x7B](https://huggingface.co/mistralai/Mixtral-8x7B-v0.1)** in our framework. See hardware requirement [here](#hardware-requirement). [23/12/12] We supported fine-tuning the latest MoE model **[Mixtral 8x7B](https://huggingface.co/mistralai/Mixtral-8x7B-v0.1)** in our framework. See hardware requirement [here](#hardware-requirement).
[23/12/01] We supported downloading pre-trained models and datasets from the **[ModelScope Hub](https://modelscope.cn/models)** for Chinese mainland users. See [this tutorial](#download-from-modelscope-hub) for usage. [23/12/01] We supported downloading pre-trained models and datasets from the **[ModelScope Hub](https://modelscope.cn/models)**. See [this tutorial](#download-from-modelscope-hub) for usage.
[23/10/21] We supported **[NEFTune](https://arxiv.org/abs/2310.05914)** trick for fine-tuning. Try `neftune_noise_alpha: 5` argument to activate NEFTune. [23/10/21] We supported **[NEFTune](https://arxiv.org/abs/2310.05914)** trick for fine-tuning. Try `neftune_noise_alpha: 5` argument to activate NEFTune.
@@ -152,7 +174,7 @@ Compared to ChatGLM's [P-Tuning](https://github.com/THUDM/ChatGLM2-6B/tree/main/
## Supported Models ## Supported Models
| Model | Model size | Template | | Model | Model size | Template |
| ------------------------------------------------------------ | -------------------------------- | --------- | | ----------------------------------------------------------------- | -------------------------------- | ---------------- |
| [Baichuan 2](https://huggingface.co/baichuan-inc) | 7B/13B | baichuan2 | | [Baichuan 2](https://huggingface.co/baichuan-inc) | 7B/13B | baichuan2 |
| [BLOOM/BLOOMZ](https://huggingface.co/bigscience) | 560M/1.1B/1.7B/3B/7.1B/176B | - | | [BLOOM/BLOOMZ](https://huggingface.co/bigscience) | 560M/1.1B/1.7B/3B/7.1B/176B | - |
| [ChatGLM3](https://huggingface.co/THUDM) | 6B | chatglm3 | | [ChatGLM3](https://huggingface.co/THUDM) | 6B | chatglm3 |
@@ -161,20 +183,28 @@ Compared to ChatGLM's [P-Tuning](https://github.com/THUDM/ChatGLM2-6B/tree/main/
| [Falcon](https://huggingface.co/tiiuae) | 7B/11B/40B/180B | falcon | | [Falcon](https://huggingface.co/tiiuae) | 7B/11B/40B/180B | falcon |
| [Gemma/Gemma 2/CodeGemma](https://huggingface.co/google) | 2B/7B/9B/27B | gemma | | [Gemma/Gemma 2/CodeGemma](https://huggingface.co/google) | 2B/7B/9B/27B | gemma |
| [GLM-4](https://huggingface.co/THUDM) | 9B | glm4 | | [GLM-4](https://huggingface.co/THUDM) | 9B | glm4 |
| [InternLM2](https://huggingface.co/internlm) | 7B/20B | intern2 | | [Index](https://huggingface.co/IndexTeam) | 1.9B | index |
| [InternLM2/InternLM2.5](https://huggingface.co/internlm) | 7B/20B | intern2 |
| [Llama](https://github.com/facebookresearch/llama) | 7B/13B/33B/65B | - | | [Llama](https://github.com/facebookresearch/llama) | 7B/13B/33B/65B | - |
| [Llama 2](https://huggingface.co/meta-llama) | 7B/13B/70B | llama2 | | [Llama 2](https://huggingface.co/meta-llama) | 7B/13B/70B | llama2 |
| [Llama 3](https://huggingface.co/meta-llama) | 8B/70B | llama3 | | [Llama 3-3.2](https://huggingface.co/meta-llama) | 1B/3B/8B/70B | llama3 |
| [LLaVA-1.5](https://huggingface.co/llava-hf) | 7B/13B | vicuna | | [Llama 3.2 Vision](https://huggingface.co/meta-llama) | 11B/90B | mllama |
| [LLaVA-1.5](https://huggingface.co/llava-hf) | 7B/13B | llava |
| [LLaVA-NeXT](https://huggingface.co/llava-hf) | 7B/8B/13B/34B/72B/110B | llava_next |
| [LLaVA-NeXT-Video](https://huggingface.co/llava-hf) | 7B/34B | llava_next_video |
| [MiniCPM](https://huggingface.co/openbmb) | 1B/2B/4B | cpm/cpm3 |
| [Mistral/Mixtral](https://huggingface.co/mistralai) | 7B/8x7B/8x22B | mistral | | [Mistral/Mixtral](https://huggingface.co/mistralai) | 7B/8x7B/8x22B | mistral |
| [OLMo](https://huggingface.co/allenai) | 1B/7B | - | | [OLMo](https://huggingface.co/allenai) | 1B/7B | - |
| [PaliGemma](https://huggingface.co/google) | 3B | gemma | | [PaliGemma](https://huggingface.co/google) | 3B | paligemma |
| [Phi-1.5/Phi-2](https://huggingface.co/microsoft) | 1.3B/2.7B | - | | [Phi-1.5/Phi-2](https://huggingface.co/microsoft) | 1.3B/2.7B | - |
| [Phi-3](https://huggingface.co/microsoft) | 4B/7B/14B | phi | | [Phi-3](https://huggingface.co/microsoft) | 4B/14B | phi |
| [Qwen/Qwen1.5/Qwen2 (Code/MoE)](https://huggingface.co/Qwen) | 0.5B/1.5B/4B/7B/14B/32B/72B/110B | qwen | | [Phi-3-small](https://huggingface.co/microsoft) | 7B | phi_small |
| [Pixtral](https://huggingface.co/mistralai) | 12B | pixtral |
| [Qwen (1-2.5) (Code/Math/MoE)](https://huggingface.co/Qwen) | 0.5B/1.5B/3B/7B/14B/32B/72B/110B | qwen |
| [Qwen2-VL](https://huggingface.co/Qwen) | 2B/7B/72B | qwen2_vl |
| [StarCoder 2](https://huggingface.co/bigcode) | 3B/7B/15B | - | | [StarCoder 2](https://huggingface.co/bigcode) | 3B/7B/15B | - |
| [XVERSE](https://huggingface.co/xverse) | 7B/13B/65B | xverse | | [XVERSE](https://huggingface.co/xverse) | 7B/13B/65B | xverse |
| [Yi/Yi-1.5](https://huggingface.co/01-ai) | 6B/9B/34B | yi | | [Yi/Yi-1.5 (Code)](https://huggingface.co/01-ai) | 1.5B/6B/9B/34B | yi |
| [Yi-VL](https://huggingface.co/01-ai) | 6B/34B | yi_vl | | [Yi-VL](https://huggingface.co/01-ai) | 6B/34B | yi_vl |
| [Yuan 2](https://huggingface.co/IEITYuan) | 2B/51B/102B | yuan | | [Yuan 2](https://huggingface.co/IEITYuan) | 2B/51B/102B | yuan |
@@ -200,6 +230,9 @@ You also can add a custom chat template to [template.py](src/llamafactory/data/t
| ORPO Training | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: | | ORPO Training | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
| SimPO Training | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: | | SimPO Training | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
> [!TIP]
> The implementation details of PPO can be found in [this blog](https://newfacade.github.io/notes-on-reinforcement-learning/17-ppo-trl.html).
## Provided Datasets ## Provided Datasets
<details><summary>Pre-training datasets</summary> <details><summary>Pre-training datasets</summary>
@@ -259,7 +292,9 @@ You also can add a custom chat template to [template.py](src/llamafactory/data/t
- [Neo-sft (zh)](https://huggingface.co/datasets/m-a-p/neo_sft_phase2) - [Neo-sft (zh)](https://huggingface.co/datasets/m-a-p/neo_sft_phase2)
- [WebInstructSub (en)](https://huggingface.co/datasets/TIGER-Lab/WebInstructSub) - [WebInstructSub (en)](https://huggingface.co/datasets/TIGER-Lab/WebInstructSub)
- [Magpie-Pro-300K-Filtered (en)](https://huggingface.co/datasets/Magpie-Align/Magpie-Pro-300K-Filtered) - [Magpie-Pro-300K-Filtered (en)](https://huggingface.co/datasets/Magpie-Align/Magpie-Pro-300K-Filtered)
- [Magpie-ultra-v0.1 (en)](https://huggingface.co/datasets/argilla/magpie-ultra-v0.1)
- [LLaVA mixed (en&zh)](https://huggingface.co/datasets/BUAADreamer/llava-en-zh-300k) - [LLaVA mixed (en&zh)](https://huggingface.co/datasets/BUAADreamer/llava-en-zh-300k)
- [Pokemon-gpt4o-captions (en&zh)](https://huggingface.co/datasets/jugg1024/pokemon-gpt4o-captions)
- [Open Assistant (de)](https://huggingface.co/datasets/mayflowergmbh/oasst_de) - [Open Assistant (de)](https://huggingface.co/datasets/mayflowergmbh/oasst_de)
- [Dolly 15k (de)](https://huggingface.co/datasets/mayflowergmbh/dolly-15k_de) - [Dolly 15k (de)](https://huggingface.co/datasets/mayflowergmbh/dolly-15k_de)
- [Alpaca GPT4 (de)](https://huggingface.co/datasets/mayflowergmbh/alpaca-gpt4_de) - [Alpaca GPT4 (de)](https://huggingface.co/datasets/mayflowergmbh/alpaca-gpt4_de)
@@ -276,6 +311,8 @@ You also can add a custom chat template to [template.py](src/llamafactory/data/t
- [DPO mixed (en&zh)](https://huggingface.co/datasets/hiyouga/DPO-En-Zh-20k) - [DPO mixed (en&zh)](https://huggingface.co/datasets/hiyouga/DPO-En-Zh-20k)
- [UltraFeedback (en)](https://huggingface.co/datasets/HuggingFaceH4/ultrafeedback_binarized) - [UltraFeedback (en)](https://huggingface.co/datasets/HuggingFaceH4/ultrafeedback_binarized)
- [RLHF-V (en)](https://huggingface.co/datasets/openbmb/RLHF-V-Dataset)
- [VLFeedback (en)](https://huggingface.co/datasets/Zhihui/VLFeedback)
- [Orca DPO Pairs (en)](https://huggingface.co/datasets/Intel/orca_dpo_pairs) - [Orca DPO Pairs (en)](https://huggingface.co/datasets/Intel/orca_dpo_pairs)
- [HH-RLHF (en)](https://huggingface.co/datasets/Anthropic/hh-rlhf) - [HH-RLHF (en)](https://huggingface.co/datasets/Anthropic/hh-rlhf)
- [Nectar (en)](https://huggingface.co/datasets/berkeley-nest/Nectar) - [Nectar (en)](https://huggingface.co/datasets/berkeley-nest/Nectar)
@@ -296,20 +333,20 @@ huggingface-cli login
| Mandatory | Minimum | Recommend | | Mandatory | Minimum | Recommend |
| ------------ | ------- | --------- | | ------------ | ------- | --------- |
| python | 3.8 | 3.11 | | python | 3.8 | 3.11 |
| torch | 1.13.1 | 2.3.0 | | torch | 1.13.1 | 2.4.0 |
| transformers | 4.41.2 | 4.41.2 | | transformers | 4.41.2 | 4.43.4 |
| datasets | 2.16.0 | 2.19.2 | | datasets | 2.16.0 | 2.20.0 |
| accelerate | 0.30.1 | 0.30.1 | | accelerate | 0.30.1 | 0.32.0 |
| peft | 0.11.1 | 0.11.1 | | peft | 0.11.1 | 0.12.0 |
| trl | 0.8.6 | 0.9.4 | | trl | 0.8.6 | 0.9.6 |
| Optional | Minimum | Recommend | | Optional | Minimum | Recommend |
| ------------ | ------- | --------- | | ------------ | ------- | --------- |
| CUDA | 11.6 | 12.2 | | CUDA | 11.6 | 12.2 |
| deepspeed | 0.10.0 | 0.14.0 | | deepspeed | 0.10.0 | 0.14.0 |
| bitsandbytes | 0.39.0 | 0.43.1 | | bitsandbytes | 0.39.0 | 0.43.1 |
| vllm | 0.4.3 | 0.4.3 | | vllm | 0.4.3 | 0.5.0 |
| flash-attn | 2.3.0 | 2.5.9 | | flash-attn | 2.3.0 | 2.6.3 |
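
A quick, non-authoritative way to compare the installed core packages against the table above (it assumes they are already installed in the current environment):

```bash
# Print installed versions of the mandatory dependencies listed above.
pip list 2>/dev/null | grep -E "^(torch|transformers|datasets|accelerate|peft|trl) "
```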
### Hardware Requirement ### Hardware Requirement
@@ -338,7 +375,7 @@ cd LLaMA-Factory
pip install -e ".[torch,metrics]" pip install -e ".[torch,metrics]"
``` ```
Extra dependencies available: torch, torch-npu, metrics, deepspeed, bitsandbytes, hqq, eetq, gptq, awq, aqlm, vllm, galore, badam, qwen, modelscope, quality Extra dependencies available: torch, torch-npu, metrics, deepspeed, liger-kernel, bitsandbytes, hqq, eetq, gptq, awq, aqlm, vllm, galore, badam, adam-mini, qwen, modelscope, openmind, quality
> [!TIP] > [!TIP]
> Use `pip install --no-deps -e .` to resolve package conflicts. > Use `pip install --no-deps -e .` to resolve package conflicts.
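
As an example, an installation that selects a few of the extras listed above might look like the following (the selection is illustrative; pick only the extras that match your hardware and use case):

```bash
pip install -e ".[torch,metrics,deepspeed,bitsandbytes]"
```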
@@ -390,7 +427,7 @@ Download the pre-built Docker images: [32GB](http://mirrors.cn-central-221.ovaij
### Data Preparation ### Data Preparation
Please refer to [data/README.md](data/README.md) for checking the details about the format of dataset files. You can either use datasets on HuggingFace / ModelScope hub or load the dataset in local disk. Please refer to [data/README.md](data/README.md) for checking the details about the format of dataset files. You can either use datasets on HuggingFace / ModelScope / Modelers hub or load the dataset in local disk.
> [!NOTE] > [!NOTE]
> Please update `data/dataset_info.json` to use your custom dataset. > Please update `data/dataset_info.json` to use your custom dataset.
@@ -422,16 +459,24 @@ For CUDA users:
```bash ```bash
cd docker/docker-cuda/ cd docker/docker-cuda/
docker-compose up -d docker compose up -d
docker-compose exec llamafactory bash docker compose exec llamafactory bash
``` ```
For Ascend NPU users: For Ascend NPU users:
```bash ```bash
cd docker/docker-npu/ cd docker/docker-npu/
docker-compose up -d docker compose up -d
docker-compose exec llamafactory bash docker compose exec llamafactory bash
```
For AMD ROCm users:
```bash
cd docker/docker-rocm/
docker compose up -d
docker compose exec llamafactory bash
``` ```
<details><summary>Build without Docker Compose</summary> <details><summary>Build without Docker Compose</summary>
@@ -450,6 +495,7 @@ docker build -f ./docker/docker-cuda/Dockerfile \
docker run -dit --gpus=all \ docker run -dit --gpus=all \
-v ./hf_cache:/root/.cache/huggingface \ -v ./hf_cache:/root/.cache/huggingface \
-v ./ms_cache:/root/.cache/modelscope \ -v ./ms_cache:/root/.cache/modelscope \
-v ./om_cache:/root/.cache/openmind \
-v ./data:/app/data \ -v ./data:/app/data \
-v ./output:/app/output \ -v ./output:/app/output \
-p 7860:7860 \ -p 7860:7860 \
@@ -474,6 +520,7 @@ docker build -f ./docker/docker-npu/Dockerfile \
docker run -dit \ docker run -dit \
-v ./hf_cache:/root/.cache/huggingface \ -v ./hf_cache:/root/.cache/huggingface \
-v ./ms_cache:/root/.cache/modelscope \ -v ./ms_cache:/root/.cache/modelscope \
-v ./om_cache:/root/.cache/openmind \
-v ./data:/app/data \ -v ./data:/app/data \
-v ./output:/app/output \ -v ./output:/app/output \
-v /usr/local/dcmi:/usr/local/dcmi \ -v /usr/local/dcmi:/usr/local/dcmi \
@@ -493,13 +540,44 @@ docker run -dit \
docker exec -it llamafactory bash docker exec -it llamafactory bash
``` ```
For AMD ROCm users:
```bash
docker build -f ./docker/docker-rocm/Dockerfile \
--build-arg INSTALL_BNB=false \
--build-arg INSTALL_VLLM=false \
--build-arg INSTALL_DEEPSPEED=false \
--build-arg INSTALL_FLASHATTN=false \
--build-arg PIP_INDEX=https://pypi.org/simple \
-t llamafactory:latest .
docker run -dit \
-v ./hf_cache:/root/.cache/huggingface \
-v ./ms_cache:/root/.cache/modelscope \
-v ./om_cache:/root/.cache/openmind \
-v ./data:/app/data \
-v ./output:/app/output \
-v ./saves:/app/saves \
-p 7860:7860 \
-p 8000:8000 \
--device /dev/kfd \
--device /dev/dri \
--shm-size 16G \
--name llamafactory \
llamafactory:latest
docker exec -it llamafactory bash
```
</details> </details>
<details><summary>Details about volume</summary> <details><summary>Details about volume</summary>
- hf_cache: Utilize Hugging Face cache on the host machine. Reassignable if a cache already exists in a different directory. - `hf_cache`: Utilize Hugging Face cache on the host machine. Reassignable if a cache already exists in a different directory.
- data: Place datasets on this dir of the host machine so that they can be selected on LLaMA Board GUI. - `ms_cache`: Similar to Hugging Face cache but for ModelScope users.
- output: Set export dir to this location so that the merged result can be accessed directly on the host machine. - `om_cache`: Similar to Hugging Face cache but for Modelers users.
- `data`: Place datasets on this dir of the host machine so that they can be selected on LLaMA Board GUI.
- `output`: Set export dir to this location so that the merged result can be accessed directly on the host machine.
</details> </details>
@@ -510,7 +588,9 @@ API_PORT=8000 llamafactory-cli api examples/inference/llama3_vllm.yaml
``` ```
> [!TIP] > [!TIP]
> Visit https://platform.openai.com/docs/api-reference/chat/create for API document. > Visit [this page](https://platform.openai.com/docs/api-reference/chat/create) for API document.
>
> Examples: [Image understanding](scripts/test_image.py) | [Function calling](scripts/test_toolcall.py)
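
As a sketch, querying the OpenAI-style endpoint started above could look like the following; the model name, port and prompt are placeholders to adapt to your deployment:

```bash
curl http://localhost:8000/v1/chat/completions \
  -H "Content-Type: application/json" \
  -d '{"model": "test", "messages": [{"role": "user", "content": "Hello"}]}'
```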
### Download from ModelScope Hub ### Download from ModelScope Hub
@@ -522,6 +602,16 @@ export USE_MODELSCOPE_HUB=1 # `set USE_MODELSCOPE_HUB=1` for Windows
Train the model by specifying a model ID of the ModelScope Hub as the `model_name_or_path`. You can find a full list of model IDs at [ModelScope Hub](https://modelscope.cn/models), e.g., `LLM-Research/Meta-Llama-3-8B-Instruct`. Train the model by specifying a model ID of the ModelScope Hub as the `model_name_or_path`. You can find a full list of model IDs at [ModelScope Hub](https://modelscope.cn/models), e.g., `LLM-Research/Meta-Llama-3-8B-Instruct`.
### Download from Modelers Hub
You can also use Modelers Hub to download models and datasets.
```bash
export USE_OPENMIND_HUB=1 # `set USE_OPENMIND_HUB=1` for Windows
```
Train the model by specifying a model ID of the Modelers Hub as the `model_name_or_path`. You can find a full list of model IDs at [Modelers Hub](https://modelers.cn/models), e.g., `TeleAI/TeleChat-7B-pt`.
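
A minimal sketch combining this switch with the CLI; the model ID is the example given above, and `llamafactory-cli webui` stands in for whichever command you actually run:

```bash
export USE_OPENMIND_HUB=1   # `set USE_OPENMIND_HUB=1` on Windows
# In your training yaml: model_name_or_path: TeleAI/TeleChat-7B-pt
llamafactory-cli webui
```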
### Use W&B Logger ### Use W&B Logger
To use [Weights & Biases](https://wandb.ai) for logging experimental results, you need to add the following arguments to yaml files. To use [Weights & Biases](https://wandb.ai) for logging experimental results, you need to add the following arguments to yaml files.
@@ -600,17 +690,39 @@ If you have a project that should be incorporated, please contact via email or c
1. Feng et al. SS-Bench: A Benchmark for Social Story Generation and Evaluation. 2024. [[arxiv]](https://arxiv.org/abs/2406.15695) 1. Feng et al. SS-Bench: A Benchmark for Social Story Generation and Evaluation. 2024. [[arxiv]](https://arxiv.org/abs/2406.15695)
1. Feng et al. Self-Constructed Context Decompilation with Fined-grained Alignment Enhancement. 2024. [[arxiv]](https://arxiv.org/abs/2406.17233) 1. Feng et al. Self-Constructed Context Decompilation with Fined-grained Alignment Enhancement. 2024. [[arxiv]](https://arxiv.org/abs/2406.17233)
1. Liu et al. Large Language Models for Cuffless Blood Pressure Measurement From Wearable Biosignals. 2024. [[arxiv]](https://arxiv.org/abs/2406.18069) 1. Liu et al. Large Language Models for Cuffless Blood Pressure Measurement From Wearable Biosignals. 2024. [[arxiv]](https://arxiv.org/abs/2406.18069)
1. Iyer et al. Exploring Very Low-Resource Translation with LLMs: The University of Edinburghs Submission to AmericasNLP 2024 Translation Task. AmericasNLP 2024. [[paper]](https://aclanthology.org/2024.americasnlp-1.25) 1. Iyer et al. Exploring Very Low-Resource Translation with LLMs: The University of Edinburgh's Submission to AmericasNLP 2024 Translation Task. AmericasNLP 2024. [[paper]](https://aclanthology.org/2024.americasnlp-1.25)
1. Li et al. Calibrating LLMs with Preference Optimization on Thought Trees for Generating Rationale in Science Question Scoring. 2024. [[arxiv]](https://arxiv.org/abs/2406.19949)
1. Yang et al. Financial Knowledge Large Language Model. 2024. [[arxiv]](https://arxiv.org/abs/2407.00365)
1. Lin et al. DogeRM: Equipping Reward Models with Domain Knowledge through Model Merging. 2024. [[arxiv]](https://arxiv.org/abs/2407.01470)
1. Bako et al. Evaluating the Semantic Profiling Abilities of LLMs for Natural Language Utterances in Data Visualization. 2024. [[arxiv]](https://arxiv.org/abs/2407.06129)
1. Huang et al. RoLoRA: Fine-tuning Rotated Outlier-free LLMs for Effective Weight-Activation Quantization. 2024. [[arxiv]](https://arxiv.org/abs/2407.08044)
1. Jiang et al. LLM-Collaboration on Automatic Science Journalism for the General Audience. 2024. [[arxiv]](https://arxiv.org/abs/2407.09756)
1. Inouye et al. Applied Auto-tuning on LoRA Hyperparameters. 2024. [[paper]](https://scholarcommons.scu.edu/cseng_senior/272/)
1. Qi et al. Research on Tibetan Tourism Viewpoints information generation system based on LLM. 2024. [[arxiv]](https://arxiv.org/abs/2407.13561)
1. Xu et al. Course-Correction: Safety Alignment Using Synthetic Preferences. 2024. [[arxiv]](https://arxiv.org/abs/2407.16637)
1. Sun et al. LAMBDA: A Large Model Based Data Agent. 2024. [[arxiv]](https://arxiv.org/abs/2407.17535)
1. Zhu et al. CollectiveSFT: Scaling Large Language Models for Chinese Medical Benchmark with Collective Instructions in Healthcare. 2024. [[arxiv]](https://arxiv.org/abs/2407.19705)
1. Yu et al. Correcting Negative Bias in Large Language Models through Negative Attention Score Alignment. 2024. [[arxiv]](https://arxiv.org/abs/2408.00137)
1. Xie et al. The Power of Personalized Datasets: Advancing Chinese Composition Writing for Elementary School through Targeted Model Fine-Tuning. IALP 2024. [[paper]](https://www.asianlp.sg/conferences/ialp2024/proceedings/papers/IALP2024_P055.pdf)
1. Liu et al. Instruct-Code-Llama: Improving Capabilities of Language Model in Competition Level Code Generation by Online Judge Feedback. ICIC 2024. [[paper]](https://link.springer.com/chapter/10.1007/978-981-97-5669-8_11)
1. Wang et al. Cybernetic Sentinels: Unveiling the Impact of Safety Data Selection on Model Security in Supervised Fine-Tuning. ICIC 2024. [[paper]](https://link.springer.com/chapter/10.1007/978-981-97-5669-8_23)
1. Xia et al. Understanding the Performance and Estimating the Cost of LLM Fine-Tuning. 2024. [[arxiv]](https://arxiv.org/abs/2408.04693)
1. Zeng et al. Perceive, Reflect, and Plan: Designing LLM Agent for Goal-Directed City Navigation without Instructions. 2024. [[arxiv]](https://arxiv.org/abs/2408.04168)
1. Xia et al. Using Pre-trained Language Model for Accurate ESG Prediction. FinNLP 2024. [[paper]](https://aclanthology.org/2024.finnlp-2.1/)
1. Liang et al. I-SHEEP: Self-Alignment of LLM from Scratch through an Iterative Self-Enhancement Paradigm. 2024. [[arxiv]](https://arxiv.org/abs/2408.08072)
1. Bai et al. Aligning Large Language Model with Direct Multi-Preference Optimization for Recommendation. CIKM 2024. [[paper]](https://dl.acm.org/doi/10.1145/3627673.3679611)
1. **[StarWhisper](https://github.com/Yu-Yang-Li/StarWhisper)**: A large language model for Astronomy, based on ChatGLM2-6B and Qwen-14B. 1. **[StarWhisper](https://github.com/Yu-Yang-Li/StarWhisper)**: A large language model for Astronomy, based on ChatGLM2-6B and Qwen-14B.
1. **[DISC-LawLLM](https://github.com/FudanDISC/DISC-LawLLM)**: A large language model specialized in Chinese legal domain, based on Baichuan-13B, is capable of retrieving and reasoning on legal knowledge. 1. **[DISC-LawLLM](https://github.com/FudanDISC/DISC-LawLLM)**: A large language model specialized in Chinese legal domain, based on Baichuan-13B, is capable of retrieving and reasoning on legal knowledge.
1. **[Sunsimiao](https://github.com/X-D-Lab/Sunsimiao)**: A large language model specialized in Chinese medical domain, based on Baichuan-7B and ChatGLM-6B. 1. **[Sunsimiao](https://github.com/X-D-Lab/Sunsimiao)**: A large language model specialized in Chinese medical domain, based on Baichuan-7B and ChatGLM-6B.
1. **[CareGPT](https://github.com/WangRongsheng/CareGPT)**: A series of large language models for Chinese medical domain, based on LLaMA2-7B and Baichuan-13B. 1. **[CareGPT](https://github.com/WangRongsheng/CareGPT)**: A series of large language models for Chinese medical domain, based on LLaMA2-7B and Baichuan-13B.
1. **[MachineMindset](https://github.com/PKU-YuanGroup/Machine-Mindset/)**: A series of MBTI Personality large language models, capable of giving any LLM 16 different personality types based on different datasets and training methods. 1. **[MachineMindset](https://github.com/PKU-YuanGroup/Machine-Mindset/)**: A series of MBTI Personality large language models, capable of giving any LLM 16 different personality types based on different datasets and training methods.
1. **[Luminia-13B-v3](https://huggingface.co/Nekochu/Luminia-13B-v3)**: A large language model specialized in generate metadata for stable diffusion. [[🤗Demo]](https://huggingface.co/spaces/Nekochu/Luminia-13B_SD_Prompt) 1. **[Luminia-13B-v3](https://huggingface.co/Nekochu/Luminia-13B-v3)**: A large language model specialized in generate metadata for stable diffusion. [[demo]](https://huggingface.co/spaces/Nekochu/Luminia-13B_SD_Prompt)
1. **[Chinese-LLaVA-Med](https://github.com/BUAADreamer/Chinese-LLaVA-Med)**: A multimodal large language model specialized in Chinese medical domain, based on LLaVA-1.5-7B. 1. **[Chinese-LLaVA-Med](https://github.com/BUAADreamer/Chinese-LLaVA-Med)**: A multimodal large language model specialized in Chinese medical domain, based on LLaVA-1.5-7B.
1. **[AutoRE](https://github.com/THUDM/AutoRE)**: A document-level relation extraction system based on large language models. 1. **[AutoRE](https://github.com/THUDM/AutoRE)**: A document-level relation extraction system based on large language models.
1. **[NVIDIA RTX AI Toolkit](https://github.com/NVIDIA/RTX-AI-Toolkit)**: SDKs for fine-tuning LLMs on Windows PC for NVIDIA RTX. 1. **[NVIDIA RTX AI Toolkit](https://github.com/NVIDIA/RTX-AI-Toolkit)**: SDKs for fine-tuning LLMs on Windows PC for NVIDIA RTX.
1. **[LazyLLM](https://github.com/LazyAGI/LazyLLM)**: An easy and lazy way for building multi-agent LLMs applications and supports model fine-tuning via LLaMA Factory. 1. **[LazyLLM](https://github.com/LazyAGI/LazyLLM)**: An easy and lazy way for building multi-agent LLMs applications and supports model fine-tuning via LLaMA Factory.
1. **[RAG-Retrieval](https://github.com/NLPJCL/RAG-Retrieval)**: A full pipeline for RAG retrieval model fine-tuning, inference, and distillation. [[blog]](https://zhuanlan.zhihu.com/p/987727357)
</details> </details>
@@ -618,7 +730,7 @@ If you have a project that should be incorporated, please contact via email or c
This repository is licensed under the [Apache-2.0 License](LICENSE). This repository is licensed under the [Apache-2.0 License](LICENSE).
Please follow the model licenses to use the corresponding model weights: [Baichuan 2](https://huggingface.co/baichuan-inc/Baichuan2-7B-Base/blob/main/Community%20License%20for%20Baichuan%202%20Model.pdf) / [BLOOM](https://huggingface.co/spaces/bigscience/license) / [ChatGLM3](https://github.com/THUDM/ChatGLM3/blob/main/MODEL_LICENSE) / [Command R](https://cohere.com/c4ai-cc-by-nc-license) / [DeepSeek](https://github.com/deepseek-ai/DeepSeek-LLM/blob/main/LICENSE-MODEL) / [Falcon](https://huggingface.co/tiiuae/falcon-180B/blob/main/LICENSE.txt) / [Gemma](https://ai.google.dev/gemma/terms) / [GLM-4](https://huggingface.co/THUDM/glm-4-9b/blob/main/LICENSE) / [InternLM2](https://github.com/InternLM/InternLM#license) / [Llama](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) / [Llama 2 (LLaVA-1.5)](https://ai.meta.com/llama/license/) / [Llama 3](https://llama.meta.com/llama3/license/) / [Mistral](LICENSE) / [OLMo](LICENSE) / [Phi-1.5/Phi-2](https://huggingface.co/microsoft/phi-1_5/resolve/main/Research%20License.docx) / [Phi-3](https://huggingface.co/microsoft/Phi-3-mini-4k-instruct/blob/main/LICENSE) / [Qwen](https://github.com/QwenLM/Qwen/blob/main/Tongyi%20Qianwen%20LICENSE%20AGREEMENT) / [StarCoder 2](https://huggingface.co/spaces/bigcode/bigcode-model-license-agreement) / [XVERSE](https://github.com/xverse-ai/XVERSE-13B/blob/main/MODEL_LICENSE.pdf) / [Yi](https://huggingface.co/01-ai/Yi-6B/blob/main/LICENSE) / [Yi-1.5](LICENSE) / [Yuan 2](https://github.com/IEIT-Yuan/Yuan-2.0/blob/main/LICENSE-Yuan) Please follow the model licenses to use the corresponding model weights: [Baichuan 2](https://huggingface.co/baichuan-inc/Baichuan2-7B-Base/blob/main/Community%20License%20for%20Baichuan%202%20Model.pdf) / [BLOOM](https://huggingface.co/spaces/bigscience/license) / [ChatGLM3](https://github.com/THUDM/ChatGLM3/blob/main/MODEL_LICENSE) / [Command R](https://cohere.com/c4ai-cc-by-nc-license) / [DeepSeek](https://github.com/deepseek-ai/DeepSeek-LLM/blob/main/LICENSE-MODEL) / [Falcon](https://huggingface.co/tiiuae/falcon-180B/blob/main/LICENSE.txt) / [Gemma](https://ai.google.dev/gemma/terms) / [GLM-4](https://huggingface.co/THUDM/glm-4-9b/blob/main/LICENSE) / [Index](https://huggingface.co/IndexTeam/Index-1.9B/blob/main/LICENSE) / [InternLM2](https://github.com/InternLM/InternLM#license) / [Llama](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) / [Llama 2 (LLaVA-1.5)](https://ai.meta.com/llama/license/) / [Llama 3](https://llama.meta.com/llama3/license/) / [MiniCPM](https://github.com/OpenBMB/MiniCPM/blob/main/MiniCPM%20Model%20License.md) / [Mistral/Mixtral/Pixtral](LICENSE) / [OLMo](LICENSE) / [Phi-1.5/Phi-2](https://huggingface.co/microsoft/phi-1_5/resolve/main/Research%20License.docx) / [Phi-3](https://huggingface.co/microsoft/Phi-3-mini-4k-instruct/blob/main/LICENSE) / [Qwen](https://github.com/QwenLM/Qwen/blob/main/Tongyi%20Qianwen%20LICENSE%20AGREEMENT) / [StarCoder 2](https://huggingface.co/spaces/bigcode/bigcode-model-license-agreement) / [XVERSE](https://github.com/xverse-ai/XVERSE-13B/blob/main/MODEL_LICENSE.pdf) / [Yi](https://huggingface.co/01-ai/Yi-6B/blob/main/LICENSE) / [Yi-1.5](LICENSE) / [Yuan 2](https://github.com/IEIT-Yuan/Yuan-2.0/blob/main/LICENSE-Yuan)
## Citation ## Citation

README_zh.md

@@ -4,7 +4,7 @@
[![GitHub Code License](https://img.shields.io/github/license/hiyouga/LLaMA-Factory)](LICENSE) [![GitHub Code License](https://img.shields.io/github/license/hiyouga/LLaMA-Factory)](LICENSE)
[![GitHub last commit](https://img.shields.io/github/last-commit/hiyouga/LLaMA-Factory)](https://github.com/hiyouga/LLaMA-Factory/commits/main) [![GitHub last commit](https://img.shields.io/github/last-commit/hiyouga/LLaMA-Factory)](https://github.com/hiyouga/LLaMA-Factory/commits/main)
[![PyPI](https://img.shields.io/pypi/v/llamafactory)](https://pypi.org/project/llamafactory/) [![PyPI](https://img.shields.io/pypi/v/llamafactory)](https://pypi.org/project/llamafactory/)
[![Citation](https://img.shields.io/badge/citation-72-green)](#使用了-llama-factory-的项目) [![Citation](https://img.shields.io/badge/citation-93-green)](#使用了-llama-factory-的项目)
[![GitHub pull request](https://img.shields.io/badge/PRs-welcome-blue)](https://github.com/hiyouga/LLaMA-Factory/pulls) [![GitHub pull request](https://img.shields.io/badge/PRs-welcome-blue)](https://github.com/hiyouga/LLaMA-Factory/pulls)
[![Discord](https://dcbadge.vercel.app/api/server/rKfvV9r9FK?compact=true&style=flat)](https://discord.gg/rKfvV9r9FK) [![Discord](https://dcbadge.vercel.app/api/server/rKfvV9r9FK?compact=true&style=flat)](https://discord.gg/rKfvV9r9FK)
[![Twitter](https://img.shields.io/twitter/follow/llamafactory_ai)](https://twitter.com/llamafactory_ai) [![Twitter](https://img.shields.io/twitter/follow/llamafactory_ai)](https://twitter.com/llamafactory_ai)
@@ -12,6 +12,7 @@
[![Open in DSW](https://gallery.pai-ml.com/assets/open-in-dsw.svg)](https://gallery.pai-ml.com/#/preview/deepLearning/nlp/llama_factory) [![Open in DSW](https://gallery.pai-ml.com/assets/open-in-dsw.svg)](https://gallery.pai-ml.com/#/preview/deepLearning/nlp/llama_factory)
[![Spaces](https://img.shields.io/badge/🤗-Open%20in%20Spaces-blue)](https://huggingface.co/spaces/hiyouga/LLaMA-Board) [![Spaces](https://img.shields.io/badge/🤗-Open%20in%20Spaces-blue)](https://huggingface.co/spaces/hiyouga/LLaMA-Board)
[![Studios](https://img.shields.io/badge/ModelScope-Open%20in%20Studios-blue)](https://modelscope.cn/studios/hiyouga/LLaMA-Board) [![Studios](https://img.shields.io/badge/ModelScope-Open%20in%20Studios-blue)](https://modelscope.cn/studios/hiyouga/LLaMA-Board)
[![SageMaker](https://img.shields.io/badge/SageMaker-Open%20in%20AWS-blue)](https://aws.amazon.com/cn/blogs/china/a-one-stop-code-free-model-fine-tuning-deployment-platform-based-on-sagemaker-and-llama-factory/)
[![GitHub Tread](https://trendshift.io/api/badge/repositories/4535)](https://trendshift.io/repositories/4535) [![GitHub Tread](https://trendshift.io/api/badge/repositories/4535)](https://trendshift.io/repositories/4535)
@@ -21,13 +22,23 @@
**微调大模型可以像这样轻松…** **微调大模型可以像这样轻松…**
https://github.com/hiyouga/LLaMA-Factory/assets/16256802/ec36a9dd-37f4-4f72-81bd-d76c6d0a6594 https://github.com/user-attachments/assets/e6ce34b0-52d5-4f3e-a830-592106c4c272
选择你的打开方式: 选择你的打开方式:
- **入门教程**https://zhuanlan.zhihu.com/p/695287607
- **框架文档**https://llamafactory.readthedocs.io/zh-cn/latest/
- **Colab**https://colab.research.google.com/drive/1d5KQtbemerlSDSxZIfAaWXhKr30QypiK?usp=sharing - **Colab**https://colab.research.google.com/drive/1d5KQtbemerlSDSxZIfAaWXhKr30QypiK?usp=sharing
- **PAI-DSW**: https://gallery.pai-ml.com/#/preview/deepLearning/nlp/llama_factory
- **本地机器**:请见[如何使用](#如何使用) - **本地机器**:请见[如何使用](#如何使用)
- **PAI-DSW**[Llama3 案例](https://gallery.pai-ml.com/#/preview/deepLearning/nlp/llama_factory) | [Qwen2-VL 案例](https://gallery.pai-ml.com/#/preview/deepLearning/nlp/llama_factory_qwen2vl)
- **Amazon SageMaker**[博客](https://aws.amazon.com/cn/blogs/china/a-one-stop-code-free-model-fine-tuning-deployment-platform-based-on-sagemaker-and-llama-factory/)
近期活动:
- **2024/10/18-2024/11/30**:使用 PAI+LLaMA Factory 构建个性化导游机器人。[[活动页面]](https://developer.aliyun.com/topic/llamafactory2)
> [!NOTE]
> 除上述链接以外的其他网站均为未经许可的第三方网站,请小心甄别。
## 目录 ## 目录
@@ -46,11 +57,11 @@ https://github.com/hiyouga/LLaMA-Factory/assets/16256802/ec36a9dd-37f4-4f72-81bd
## 项目特色 ## 项目特色
- **多种模型**LLaMA、LLaVA、Mistral、Mixtral-MoE、Qwen、Yi、Gemma、Baichuan、ChatGLM、Phi 等等。 - **多种模型**LLaMA、LLaVA、Mistral、Mixtral-MoE、Qwen、Qwen2-VL、Yi、Gemma、Baichuan、ChatGLM、Phi 等等。
- **集成方法**增量预训练、多模态指令监督微调、奖励模型训练、PPO 训练、DPO 训练、KTO 训练、ORPO 训练等等。 - **集成方法**增量预训练、多模态指令监督微调、奖励模型训练、PPO 训练、DPO 训练、KTO 训练、ORPO 训练等等。
- **多种精度**16 比特全参数微调、冻结微调、LoRA 微调和基于 AQLM/AWQ/GPTQ/LLM.int8/HQQ/EETQ 的 2/3/4/5/6/8 比特 QLoRA 微调。 - **多种精度**16 比特全参数微调、冻结微调、LoRA 微调和基于 AQLM/AWQ/GPTQ/LLM.int8/HQQ/EETQ 的 2/3/4/5/6/8 比特 QLoRA 微调。
- **先进算法**GaLore、BAdam、DoRA、LongLoRA、LLaMA Pro、Mixture-of-Depths、LoRA+、LoftQ、PiSSA 和 Agent 微调。 - **先进算法**[GaLore](https://github.com/jiaweizzhao/GaLore)、[BAdam](https://github.com/Ledzy/BAdam)、[Adam-mini](https://github.com/zyushun/Adam-mini)、DoRA、LongLoRA、LLaMA Pro、Mixture-of-Depths、LoRA+、LoftQ、PiSSA 和 Agent 微调。
- **实用技巧**FlashAttention-2、Unsloth、RoPE scaling、NEFTune 和 rsLoRA。 - **实用技巧**[FlashAttention-2](https://github.com/Dao-AILab/flash-attention)、[Unsloth](https://github.com/unslothai/unsloth)、[Liger Kernel](https://github.com/linkedin/Liger-Kernel)、RoPE scaling、NEFTune 和 rsLoRA。
- **实验监控**LlamaBoard、TensorBoard、Wandb、MLflow 等等。 - **实验监控**LlamaBoard、TensorBoard、Wandb、MLflow 等等。
- **极速推理**:基于 vLLM 的 OpenAI 风格 API、浏览器界面和命令行接口。 - **极速推理**:基于 vLLM 的 OpenAI 风格 API、浏览器界面和命令行接口。
@@ -71,15 +82,27 @@ https://github.com/hiyouga/LLaMA-Factory/assets/16256802/ec36a9dd-37f4-4f72-81bd
## 更新日志 ## 更新日志
[24/10/09] 我们支持了从 **[魔乐社区](https://modelers.cn/models)** 下载预训练模型和数据集。详细用法请参照 [此教程](#从魔乐社区下载)。
[24/09/19] 我们支持了 **[Qwen2.5](https://qwenlm.github.io/blog/qwen2.5/)** 模型的微调。
[24/08/30] 我们支持了 **[Qwen2-VL](https://qwenlm.github.io/blog/qwen2-vl/)** 模型的微调。感谢 [@simonJJJ](https://github.com/simonJJJ) 的 PR。
[24/08/27] 我们支持了 **[Liger Kernel](https://github.com/linkedin/Liger-Kernel)**。请使用 `enable_liger_kernel: true` 来加速训练。
[24/08/09] 我们支持了 **[Adam-mini](https://github.com/zyushun/Adam-mini)** 优化器。详细用法请参照 [examples](examples/README_zh.md)。感谢 [@relic-yuexi](https://github.com/relic-yuexi) 的 PR。
<details><summary>展开日志</summary>
[24/07/04] 我们支持了[无污染打包训练](https://github.com/MeetKai/functionary/tree/main/functionary/train/packing)。请使用 `neat_packing: true` 参数。感谢 [@chuan298](https://github.com/chuan298) 的 PR。
[24/06/16] 我们支持了 **[PiSSA](https://arxiv.org/abs/2404.02948)** 算法。详细用法请参照 [examples](examples/README_zh.md)。 [24/06/16] 我们支持了 **[PiSSA](https://arxiv.org/abs/2404.02948)** 算法。详细用法请参照 [examples](examples/README_zh.md)。
[24/06/07] 我们支持了 **[Qwen2](https://qwenlm.github.io/blog/qwen2/)** 和 **[GLM-4](https://github.com/THUDM/GLM-4)** 模型的微调。 [24/06/07] 我们支持了 **[Qwen2](https://qwenlm.github.io/blog/qwen2/)** 和 **[GLM-4](https://github.com/THUDM/GLM-4)** 模型的微调。
[24/05/26] 我们支持了 **[SimPO](https://arxiv.org/abs/2405.14734)** 偏好对齐算法。详细用法请参照 [examples](examples/README_zh.md)。 [24/05/26] 我们支持了 **[SimPO](https://arxiv.org/abs/2405.14734)** 偏好对齐算法。详细用法请参照 [examples](examples/README_zh.md)。
<details><summary>展开日志</summary> [24/05/20] 我们支持了 **PaliGemma** 系列模型的微调。注意 PaliGemma 是预训练模型,你需要使用 `paligemma` 模板进行微调使其获得对话能力。
[24/05/20] 我们支持了 **PaliGemma** 系列模型的微调。注意 PaliGemma 是预训练模型,你需要使用 `gemma` 模板进行微调使其获得对话能力。
[24/05/18] 我们支持了 **[KTO](https://arxiv.org/abs/2402.01306)** 偏好对齐算法。详细用法请参照 [examples](examples/README_zh.md)。 [24/05/18] 我们支持了 **[KTO](https://arxiv.org/abs/2402.01306)** 偏好对齐算法。详细用法请参照 [examples](examples/README_zh.md)。
@@ -91,7 +114,7 @@ https://github.com/hiyouga/LLaMA-Factory/assets/16256802/ec36a9dd-37f4-4f72-81bd
[24/04/21] 我们基于 [AstraMindAI 的仓库](https://github.com/astramind-ai/Mixture-of-depths)支持了 **[混合深度训练](https://arxiv.org/abs/2404.02258)**。详细用法请参照 [examples](examples/README_zh.md)。 [24/04/21] 我们基于 [AstraMindAI 的仓库](https://github.com/astramind-ai/Mixture-of-depths)支持了 **[混合深度训练](https://arxiv.org/abs/2404.02258)**。详细用法请参照 [examples](examples/README_zh.md)。
[24/04/16] 我们支持了 **[BAdam](https://arxiv.org/abs/2404.02827)**。详细用法请参照 [examples](examples/README_zh.md)。 [24/04/16] 我们支持了 **[BAdam](https://arxiv.org/abs/2404.02827)** 优化器。详细用法请参照 [examples](examples/README_zh.md)。
[24/04/16] 我们支持了 **[unsloth](https://github.com/unslothai/unsloth)** 的长序列训练24GB 可训练 Llama-2-7B-56k。该方法相比 FlashAttention-2 提供了 **117%** 的训练速度和 **50%** 的显存节约。更多数据请见[此页面](https://github.com/hiyouga/LLaMA-Factory/wiki/Performance-comparison)。 [24/04/16] 我们支持了 **[unsloth](https://github.com/unslothai/unsloth)** 的长序列训练24GB 可训练 Llama-2-7B-56k。该方法相比 FlashAttention-2 提供了 **117%** 的训练速度和 **50%** 的显存节约。更多数据请见[此页面](https://github.com/hiyouga/LLaMA-Factory/wiki/Performance-comparison)。
@@ -103,7 +126,7 @@ https://github.com/hiyouga/LLaMA-Factory/assets/16256802/ec36a9dd-37f4-4f72-81bd
[24/03/13] 我们支持了 **[LoRA+](https://arxiv.org/abs/2402.12354)**。详细用法请参照 [examples](examples/README_zh.md)。 [24/03/13] 我们支持了 **[LoRA+](https://arxiv.org/abs/2402.12354)**。详细用法请参照 [examples](examples/README_zh.md)。
[24/03/07] 我们支持了梯度低秩投影(**[GaLore](https://arxiv.org/abs/2403.03507)**)算法。详细用法请参照 [examples](examples/README_zh.md)。 [24/03/07] 我们支持了 **[GaLore](https://arxiv.org/abs/2403.03507)** 优化器。详细用法请参照 [examples](examples/README_zh.md)。
[24/03/07] 我们集成了 **[vLLM](https://github.com/vllm-project/vllm)** 以实现极速并发推理。请使用 `infer_backend: vllm` 来获得 **270%** 的推理速度。 [24/03/07] 我们集成了 **[vLLM](https://github.com/vllm-project/vllm)** 以实现极速并发推理。请使用 `infer_backend: vllm` 来获得 **270%** 的推理速度。
@@ -152,7 +175,7 @@ https://github.com/hiyouga/LLaMA-Factory/assets/16256802/ec36a9dd-37f4-4f72-81bd
## 模型

| 模型名 | 模型大小 | Template |
| ------------------------------------------------------------ | -------------------------------- | ---------------- |
| [Baichuan 2](https://huggingface.co/baichuan-inc) | 7B/13B | baichuan2 |
| [BLOOM/BLOOMZ](https://huggingface.co/bigscience) | 560M/1.1B/1.7B/3B/7.1B/176B | - |
| [ChatGLM3](https://huggingface.co/THUDM) | 6B | chatglm3 |
@@ -161,20 +184,27 @@ https://github.com/hiyouga/LLaMA-Factory/assets/16256802/ec36a9dd-37f4-4f72-81bd
| [Falcon](https://huggingface.co/tiiuae) | 7B/11B/40B/180B | falcon |
| [Gemma/Gemma 2/CodeGemma](https://huggingface.co/google) | 2B/7B/9B/27B | gemma |
| [GLM-4](https://huggingface.co/THUDM) | 9B | glm4 |
| [Index](https://huggingface.co/IndexTeam) | 1.9B | index |
| [InternLM2/InternLM2.5](https://huggingface.co/internlm) | 7B/20B | intern2 |
| [Llama](https://github.com/facebookresearch/llama) | 7B/13B/33B/65B | - |
| [Llama 2](https://huggingface.co/meta-llama) | 7B/13B/70B | llama2 |
| [Llama 3-3.2](https://huggingface.co/meta-llama) | 1B/3B/8B/70B | llama3 |
| [Llama 3.2 Vision](https://huggingface.co/meta-llama) | 11B/90B | mllama |
| [LLaVA-1.5](https://huggingface.co/llava-hf) | 7B/13B | llava |
| [LLaVA-NeXT](https://huggingface.co/llava-hf) | 7B/8B/13B/34B/72B/110B | llava_next |
| [LLaVA-NeXT-Video](https://huggingface.co/llava-hf) | 7B/34B | llava_next_video |
| [MiniCPM](https://huggingface.co/openbmb) | 1B/2B/4B | cpm/cpm3 |
| [Mistral/Mixtral](https://huggingface.co/mistralai) | 7B/8x7B/8x22B | mistral |
| [OLMo](https://huggingface.co/allenai) | 1B/7B | - |
| [PaliGemma](https://huggingface.co/google) | 3B | paligemma |
| [Phi-1.5/Phi-2](https://huggingface.co/microsoft) | 1.3B/2.7B | - |
| [Phi-3](https://huggingface.co/microsoft) | 4B/7B/14B | phi |
| [Pixtral](https://huggingface.co/mistralai) | 12B | pixtral |
| [Qwen (1-2.5) (Code/Math/MoE)](https://huggingface.co/Qwen) | 0.5B/1.5B/3B/7B/14B/32B/72B/110B | qwen |
| [Qwen2-VL](https://huggingface.co/Qwen) | 2B/7B/72B | qwen2_vl |
| [StarCoder 2](https://huggingface.co/bigcode) | 3B/7B/15B | - |
| [XVERSE](https://huggingface.co/xverse) | 7B/13B/65B | xverse |
| [Yi/Yi-1.5 (Code)](https://huggingface.co/01-ai) | 1.5B/6B/9B/34B | yi |
| [Yi-VL](https://huggingface.co/01-ai) | 6B/34B | yi_vl |
| [Yuan 2](https://huggingface.co/IEITYuan) | 2B/51B/102B | yuan |
@@ -200,6 +230,9 @@ https://github.com/hiyouga/LLaMA-Factory/assets/16256802/ec36a9dd-37f4-4f72-81bd
| ORPO 训练 | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
| SimPO 训练 | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
> [!TIP]
> 有关 PPO 的实现细节,请参考[此博客](https://newfacade.github.io/notes-on-reinforcement-learning/17-ppo-trl.html)。
## 数据集

<details><summary>预训练数据集</summary>
@@ -259,7 +292,9 @@ https://github.com/hiyouga/LLaMA-Factory/assets/16256802/ec36a9dd-37f4-4f72-81bd
- [Neo-sft (zh)](https://huggingface.co/datasets/m-a-p/neo_sft_phase2)
- [WebInstructSub (en)](https://huggingface.co/datasets/TIGER-Lab/WebInstructSub)
- [Magpie-Pro-300K-Filtered (en)](https://huggingface.co/datasets/Magpie-Align/Magpie-Pro-300K-Filtered)
- [Magpie-ultra-v0.1 (en)](https://huggingface.co/datasets/argilla/magpie-ultra-v0.1)
- [LLaVA mixed (en&zh)](https://huggingface.co/datasets/BUAADreamer/llava-en-zh-300k)
- [Pokemon-gpt4o-captions (en&zh)](https://huggingface.co/datasets/jugg1024/pokemon-gpt4o-captions)
- [Open Assistant (de)](https://huggingface.co/datasets/mayflowergmbh/oasst_de)
- [Dolly 15k (de)](https://huggingface.co/datasets/mayflowergmbh/dolly-15k_de)
- [Alpaca GPT4 (de)](https://huggingface.co/datasets/mayflowergmbh/alpaca-gpt4_de)
@@ -276,6 +311,8 @@ https://github.com/hiyouga/LLaMA-Factory/assets/16256802/ec36a9dd-37f4-4f72-81bd
- [DPO mixed (en&zh)](https://huggingface.co/datasets/hiyouga/DPO-En-Zh-20k)
- [UltraFeedback (en)](https://huggingface.co/datasets/HuggingFaceH4/ultrafeedback_binarized)
- [RLHF-V (en)](https://huggingface.co/datasets/openbmb/RLHF-V-Dataset)
- [VLFeedback (en)](https://huggingface.co/datasets/Zhihui/VLFeedback)
- [Orca DPO Pairs (en)](https://huggingface.co/datasets/Intel/orca_dpo_pairs)
- [HH-RLHF (en)](https://huggingface.co/datasets/Anthropic/hh-rlhf)
- [Nectar (en)](https://huggingface.co/datasets/berkeley-nest/Nectar)
@@ -296,20 +333,20 @@ huggingface-cli login
| 必需项       | 至少   | 推荐   |
| ------------ | ------ | ------ |
| python       | 3.8    | 3.11   |
| torch        | 1.13.1 | 2.4.0  |
| transformers | 4.41.2 | 4.43.4 |
| datasets     | 2.16.0 | 2.20.0 |
| accelerate   | 0.30.1 | 0.32.0 |
| peft         | 0.11.1 | 0.12.0 |
| trl          | 0.8.6  | 0.9.6  |

| 可选项       | 至少   | 推荐   |
| ------------ | ------ | ------ |
| CUDA         | 11.6   | 12.2   |
| deepspeed    | 0.10.0 | 0.14.0 |
| bitsandbytes | 0.39.0 | 0.43.1 |
| vllm         | 0.4.3  | 0.5.0  |
| flash-attn   | 2.3.0  | 2.6.3  |
### 硬件依赖
@@ -338,7 +375,7 @@ cd LLaMA-Factory
pip install -e ".[torch,metrics]"
```

可选的额外依赖项torch、torch-npu、metrics、deepspeed、liger-kernel、bitsandbytes、hqq、eetq、gptq、awq、aqlm、vllm、galore、badam、adam-mini、qwen、modelscope、openmind、quality

> [!TIP]
> 遇到包冲突时,可使用 `pip install --no-deps -e .` 解决。
@@ -390,7 +427,7 @@ source /usr/local/Ascend/ascend-toolkit/set_env.sh
### 数据准备

关于数据集文件的格式,请参考 [data/README_zh.md](data/README_zh.md) 的内容。你可以使用 HuggingFace / ModelScope / Modelers 上的数据集或加载本地数据集。

> [!NOTE]
> 使用自定义数据集时,请更新 `data/dataset_info.json` 文件。
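下面是一个最小的 `dataset_info.json` 条目示意(其中 `my_dataset` 与 `my_data.json` 均为假设的名称,字段含义以 [data/README_zh.md](data/README_zh.md) 为准):

```json
"my_dataset": {
  "file_name": "my_data.json",
  "columns": {
    "prompt": "instruction",
    "query": "input",
    "response": "output"
  }
}
```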
@@ -422,16 +459,24 @@ CUDA 用户:
```bash
cd docker/docker-cuda/
docker compose up -d
docker compose exec llamafactory bash
```

昇腾 NPU 用户:

```bash
cd docker/docker-npu/
docker compose up -d
docker compose exec llamafactory bash
```

AMD ROCm 用户:

```bash
cd docker/docker-rocm/
docker compose up -d
docker compose exec llamafactory bash
```
<details><summary>不使用 Docker Compose 构建</summary>
@@ -450,6 +495,7 @@ docker build -f ./docker/docker-cuda/Dockerfile \
docker run -dit --gpus=all \
-v ./hf_cache:/root/.cache/huggingface \
-v ./ms_cache:/root/.cache/modelscope \
-v ./om_cache:/root/.cache/openmind \
-v ./data:/app/data \
-v ./output:/app/output \
-p 7860:7860 \
@@ -474,6 +520,7 @@ docker build -f ./docker/docker-npu/Dockerfile \
docker run -dit \
-v ./hf_cache:/root/.cache/huggingface \
-v ./ms_cache:/root/.cache/modelscope \
-v ./om_cache:/root/.cache/openmind \
-v ./data:/app/data \
-v ./output:/app/output \
-v /usr/local/dcmi:/usr/local/dcmi \
@@ -493,13 +540,44 @@ docker run -dit \
docker exec -it llamafactory bash
```
AMD ROCm 用户:
```bash
docker build -f ./docker/docker-rocm/Dockerfile \
--build-arg INSTALL_BNB=false \
--build-arg INSTALL_VLLM=false \
--build-arg INSTALL_DEEPSPEED=false \
--build-arg INSTALL_FLASHATTN=false \
--build-arg PIP_INDEX=https://pypi.org/simple \
-t llamafactory:latest .
docker run -dit \
-v ./hf_cache:/root/.cache/huggingface \
-v ./ms_cache:/root/.cache/modelscope \
-v ./om_cache:/root/.cache/openmind \
-v ./data:/app/data \
-v ./output:/app/output \
-v ./saves:/app/saves \
-p 7860:7860 \
-p 8000:8000 \
--device /dev/kfd \
--device /dev/dri \
--shm-size 16G \
--name llamafactory \
llamafactory:latest
docker exec -it llamafactory bash
```
</details>

<details><summary>数据卷详情</summary>

- `hf_cache`:使用宿主机的 Hugging Face 缓存文件夹,允许更改为新的目录。
- `ms_cache`:类似 Hugging Face 缓存文件夹,为 ModelScope 用户提供。
- `om_cache`:类似 Hugging Face 缓存文件夹,为 Modelers 用户提供。
- `data`:宿主机中存放数据集的文件夹路径。
- `output`:将导出目录设置为该路径后,即可在宿主机中访问导出后的模型。

</details>
@@ -510,7 +588,9 @@ API_PORT=8000 llamafactory-cli api examples/inference/llama3_vllm.yaml
```

> [!TIP]
> API 文档请查阅[这里](https://platform.openai.com/docs/api-reference/chat/create)。
>
> 示例:[图像理解](scripts/test_image.py) | [工具调用](scripts/test_toolcall.py)
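下面是一个调用该 OpenAI 兼容接口的示意(假设服务监听本机 8000 端口,`model` 字段按实际部署填写):

```bash
curl http://localhost:8000/v1/chat/completions \
  -H "Content-Type: application/json" \
  -d '{
    "model": "llama3",
    "messages": [{"role": "user", "content": "你好"}]
  }'
```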
### 从魔搭社区下载
@@ -522,6 +602,16 @@ export USE_MODELSCOPE_HUB=1 # Windows 使用 `set USE_MODELSCOPE_HUB=1`
`model_name_or_path` 设置为模型 ID 来加载对应的模型。在[魔搭社区](https://modelscope.cn/models)查看所有可用的模型,例如 `LLM-Research/Meta-Llama-3-8B-Instruct` `model_name_or_path` 设置为模型 ID 来加载对应的模型。在[魔搭社区](https://modelscope.cn/models)查看所有可用的模型,例如 `LLM-Research/Meta-Llama-3-8B-Instruct`
### 从魔乐社区下载
您也可以通过下述方法,使用魔乐社区下载数据集和模型。
```bash
export USE_OPENMIND_HUB=1 # Windows 使用 `set USE_OPENMIND_HUB=1`
```
`model_name_or_path` 设置为模型 ID 来加载对应的模型。在[魔乐社区](https://modelers.cn/models)查看所有可用的模型,例如 `TeleAI/TeleChat-7B-pt`
### 使用 W&B 面板

若要使用 [Weights & Biases](https://wandb.ai) 记录实验数据,请在 yaml 文件中添加下面的参数。
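一个示意性的配置片段(`run_name` 为可选项,参数名遵循 transformers 训练参数的常见约定):

```yaml
report_to: wandb
run_name: test_run # 可选
```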
@@ -600,17 +690,38 @@ run_name: test_run # 可选
1. Feng et al. SS-Bench: A Benchmark for Social Story Generation and Evaluation. 2024. [[arxiv]](https://arxiv.org/abs/2406.15695)
1. Feng et al. Self-Constructed Context Decompilation with Fined-grained Alignment Enhancement. 2024. [[arxiv]](https://arxiv.org/abs/2406.17233)
1. Liu et al. Large Language Models for Cuffless Blood Pressure Measurement From Wearable Biosignals. 2024. [[arxiv]](https://arxiv.org/abs/2406.18069)
1. Iyer et al. Exploring Very Low-Resource Translation with LLMs: The University of Edinburgh's Submission to AmericasNLP 2024 Translation Task. AmericasNLP 2024. [[paper]](https://aclanthology.org/2024.americasnlp-1.25)
1. Li et al. Calibrating LLMs with Preference Optimization on Thought Trees for Generating Rationale in Science Question Scoring. 2024. [[arxiv]](https://arxiv.org/abs/2406.19949)
1. Yang et al. Financial Knowledge Large Language Model. 2024. [[arxiv]](https://arxiv.org/abs/2407.00365)
1. Lin et al. DogeRM: Equipping Reward Models with Domain Knowledge through Model Merging. 2024. [[arxiv]](https://arxiv.org/abs/2407.01470)
1. Bako et al. Evaluating the Semantic Profiling Abilities of LLMs for Natural Language Utterances in Data Visualization. 2024. [[arxiv]](https://arxiv.org/abs/2407.06129)
1. Huang et al. RoLoRA: Fine-tuning Rotated Outlier-free LLMs for Effective Weight-Activation Quantization. 2024. [[arxiv]](https://arxiv.org/abs/2407.08044)
1. Jiang et al. LLM-Collaboration on Automatic Science Journalism for the General Audience. 2024. [[arxiv]](https://arxiv.org/abs/2407.09756)
1. Inouye et al. Applied Auto-tuning on LoRA Hyperparameters. 2024. [[paper]](https://scholarcommons.scu.edu/cseng_senior/272/)
1. Qi et al. Research on Tibetan Tourism Viewpoints information generation system based on LLM. 2024. [[arxiv]](https://arxiv.org/abs/2407.13561)
1. Xu et al. Course-Correction: Safety Alignment Using Synthetic Preferences. 2024. [[arxiv]](https://arxiv.org/abs/2407.16637)
1. Sun et al. LAMBDA: A Large Model Based Data Agent. 2024. [[arxiv]](https://arxiv.org/abs/2407.17535)
1. Zhu et al. CollectiveSFT: Scaling Large Language Models for Chinese Medical Benchmark with Collective Instructions in Healthcare. 2024. [[arxiv]](https://arxiv.org/abs/2407.19705)
1. Yu et al. Correcting Negative Bias in Large Language Models through Negative Attention Score Alignment. 2024. [[arxiv]](https://arxiv.org/abs/2408.00137)
1. Xie et al. The Power of Personalized Datasets: Advancing Chinese Composition Writing for Elementary School through Targeted Model Fine-Tuning. IALP 2024. [[paper]](https://www.asianlp.sg/conferences/ialp2024/proceedings/papers/IALP2024_P055.pdf)
1. Liu et al. Instruct-Code-Llama: Improving Capabilities of Language Model in Competition Level Code Generation by Online Judge Feedback. ICIC 2024. [[paper]](https://link.springer.com/chapter/10.1007/978-981-97-5669-8_11)
1. Wang et al. Cybernetic Sentinels: Unveiling the Impact of Safety Data Selection on Model Security in Supervised Fine-Tuning. ICIC 2024. [[paper]](https://link.springer.com/chapter/10.1007/978-981-97-5669-8_23)
1. Xia et al. Understanding the Performance and Estimating the Cost of LLM Fine-Tuning. 2024. [[arxiv]](https://arxiv.org/abs/2408.04693)
1. Zeng et al. Perceive, Reflect, and Plan: Designing LLM Agent for Goal-Directed City Navigation without Instructions. 2024. [[arxiv]](https://arxiv.org/abs/2408.04168)
1. Xia et al. Using Pre-trained Language Model for Accurate ESG Prediction. FinNLP 2024. [[paper]](https://aclanthology.org/2024.finnlp-2.1/)
1. Liang et al. I-SHEEP: Self-Alignment of LLM from Scratch through an Iterative Self-Enhancement Paradigm. 2024. [[arxiv]](https://arxiv.org/abs/2408.08072)
1. Bai et al. Aligning Large Language Model with Direct Multi-Preference Optimization for Recommendation. CIKM 2024. [[paper]](https://dl.acm.org/doi/10.1145/3627673.3679611)
1. **[StarWhisper](https://github.com/Yu-Yang-Li/StarWhisper)**: 天文大模型 StarWhisper基于 ChatGLM2-6B 和 Qwen-14B 在天文数据上微调而得。
1. **[DISC-LawLLM](https://github.com/FudanDISC/DISC-LawLLM)**: 中文法律领域大模型 DISC-LawLLM基于 Baichuan-13B 微调而得,具有法律推理和知识检索能力。
1. **[Sunsimiao](https://github.com/X-D-Lab/Sunsimiao)**: 孙思邈中文医疗大模型 Sunsimiao基于 Baichuan-7B 和 ChatGLM-6B 在中文医疗数据上微调而得。
1. **[CareGPT](https://github.com/WangRongsheng/CareGPT)**: 医疗大模型项目 CareGPT基于 LLaMA2-7B 和 Baichuan-13B 在中文医疗数据上微调而得。
1. **[MachineMindset](https://github.com/PKU-YuanGroup/Machine-Mindset/)**MBTI 性格大模型项目,根据数据集与训练方式让任意 LLM 拥有 16 个不同的性格类型。
1. **[Luminia-13B-v3](https://huggingface.co/Nekochu/Luminia-13B-v3)**:一个用于生成 Stable Diffusion 提示词的大型语言模型。[[demo]](https://huggingface.co/spaces/Nekochu/Luminia-13B_SD_Prompt)
1. **[Chinese-LLaVA-Med](https://github.com/BUAADreamer/Chinese-LLaVA-Med)**:中文多模态医学大模型,基于 LLaVA-1.5-7B 在中文多模态医疗数据上微调而得。
1. **[AutoRE](https://github.com/THUDM/AutoRE)**:基于大语言模型的文档级关系抽取系统。
1. **[NVIDIA RTX AI Toolkit](https://github.com/NVIDIA/RTX-AI-Toolkit)**:在 Windows 主机上利用英伟达 RTX 设备进行大型语言模型微调的开发包。
1. **[LazyLLM](https://github.com/LazyAGI/LazyLLM)**:一个低代码构建多 Agent 大模型应用的开发工具,支持基于 LLaMA Factory 的模型微调。
1. **[RAG-Retrieval](https://github.com/NLPJCL/RAG-Retrieval)**:一个全链路 RAG 检索模型微调、推理和蒸馏代码库。[[blog]](https://zhuanlan.zhihu.com/p/987727357)
</details>
@@ -618,7 +729,7 @@ run_name: test_run # 可选
本仓库的代码依照 [Apache-2.0](LICENSE) 协议开源。

使用模型权重时,请遵循对应的模型协议:[Baichuan 2](https://huggingface.co/baichuan-inc/Baichuan2-7B-Base/blob/main/Community%20License%20for%20Baichuan%202%20Model.pdf) / [BLOOM](https://huggingface.co/spaces/bigscience/license) / [ChatGLM3](https://github.com/THUDM/ChatGLM3/blob/main/MODEL_LICENSE) / [Command R](https://cohere.com/c4ai-cc-by-nc-license) / [DeepSeek](https://github.com/deepseek-ai/DeepSeek-LLM/blob/main/LICENSE-MODEL) / [Falcon](https://huggingface.co/tiiuae/falcon-180B/blob/main/LICENSE.txt) / [Gemma](https://ai.google.dev/gemma/terms) / [GLM-4](https://huggingface.co/THUDM/glm-4-9b/blob/main/LICENSE) / [Index](https://huggingface.co/IndexTeam/Index-1.9B/blob/main/LICENSE) / [InternLM2](https://github.com/InternLM/InternLM#license) / [Llama](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) / [Llama 2 (LLaVA-1.5)](https://ai.meta.com/llama/license/) / [Llama 3](https://llama.meta.com/llama3/license/) / [MiniCPM](https://github.com/OpenBMB/MiniCPM/blob/main/MiniCPM%20Model%20License.md) / [Mistral/Mixtral/Pixtral](LICENSE) / [OLMo](LICENSE) / [Phi-1.5/Phi-2](https://huggingface.co/microsoft/phi-1_5/resolve/main/Research%20License.docx) / [Phi-3](https://huggingface.co/microsoft/Phi-3-mini-4k-instruct/blob/main/LICENSE) / [Qwen](https://github.com/QwenLM/Qwen/blob/main/Tongyi%20Qianwen%20LICENSE%20AGREEMENT) / [StarCoder 2](https://huggingface.co/spaces/bigcode/bigcode-model-license-agreement) / [XVERSE](https://github.com/xverse-ai/XVERSE-13B/blob/main/MODEL_LICENSE.pdf) / [Yi](https://huggingface.co/01-ai/Yi-6B/blob/main/LICENSE) / [Yi-1.5](LICENSE) / [Yuan 2](https://github.com/IEIT-Yuan/Yuan-2.0/blob/main/LICENSE-Yuan)

## 引用

View File

@@ -23,6 +23,7 @@ Currently we support datasets in **alpaca** and **sharegpt** format.
"system": "the column name in the dataset containing the system prompts. (default: None)", "system": "the column name in the dataset containing the system prompts. (default: None)",
"tools": "the column name in the dataset containing the tool description. (default: None)", "tools": "the column name in the dataset containing the tool description. (default: None)",
"images": "the column name in the dataset containing the image inputs. (default: None)", "images": "the column name in the dataset containing the image inputs. (default: None)",
"videos": "the column name in the dataset containing the videos inputs. (default: None)",
"chosen": "the column name in the dataset containing the chosen answers. (default: None)", "chosen": "the column name in the dataset containing the chosen answers. (default: None)",
"rejected": "the column name in the dataset containing the rejected answers. (default: None)", "rejected": "the column name in the dataset containing the rejected answers. (default: None)",
"kto_tag": "the column name in the dataset containing the kto tags. (default: None)" "kto_tag": "the column name in the dataset containing the kto tags. (default: None)"
@@ -107,7 +108,7 @@ Regarding the above dataset, the *dataset description* in `dataset_info.json` sh
### Preference Dataset

Preference datasets are used for reward modeling, DPO training, ORPO training and SimPO training.

They require a better response in the `chosen` column and a worse response in the `rejected` column.
@@ -139,67 +140,15 @@ Regarding the above dataset, the *dataset description* in `dataset_info.json` sh
### KTO Dataset

An additional column `kto_tag` is required. Please refer to the [sharegpt](#sharegpt-format) format for details.

### Multimodal Image Dataset

An additional column `images` is required. Please refer to the [sharegpt](#sharegpt-format) format for details.

### Multimodal Video Dataset

An additional column `videos` is required. Please refer to the [sharegpt](#sharegpt-format) format for details.
## Sharegpt Format
@@ -252,6 +201,10 @@ Regarding the above dataset, the *dataset description* in `dataset_info.json` sh
}
```
### Pre-training Dataset
Not yet supported, please use the [alpaca](#alpaca-format) format.
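For reference, a minimal pre-training sample in the alpaca format looks like the sketch below (illustrative only; the `text` column name follows the convention used by `c4_demo`, adjust it to your own file):

```json
[
  {"text": "pre-training document one"},
  {"text": "pre-training document two"}
]
```

The corresponding *dataset description* would then map the column, e.g. `"columns": {"prompt": "text"}`.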
### Preference Dataset

- [Example dataset](dpo_en_demo.json)
@@ -302,6 +255,125 @@ Regarding the above dataset, the *dataset description* in `dataset_info.json` sh
}
```
### KTO Dataset
- [Example dataset](kto_en_demo.json)
KTO datasets require an extra `kto_tag` column containing the boolean human feedback.
```json
[
{
"conversations": [
{
"from": "human",
"value": "human instruction"
},
{
"from": "gpt",
"value": "model response"
}
],
"kto_tag": "human feedback [true/false] (required)"
}
]
```
Regarding the above dataset, the *dataset description* in `dataset_info.json` should be:
```json
"dataset_name": {
"file_name": "data.json",
"formatting": "sharegpt",
"columns": {
"messages": "conversations",
"kto_tag": "kto_tag"
}
}
```
### Multimodal Image Dataset
- [Example dataset](mllm_demo.json)
Multimodal image datasets require an `images` column containing the paths to the input images.
The number of images should be identical to the number of `<image>` tokens in the conversations.
```json
[
{
"conversations": [
{
"from": "human",
"value": "<image>human instruction"
},
{
"from": "gpt",
"value": "model response"
}
],
"images": [
"image path (required)"
]
}
]
```
Regarding the above dataset, the *dataset description* in `dataset_info.json` should be:
```json
"dataset_name": {
"file_name": "data.json",
"formatting": "sharegpt",
"columns": {
"messages": "conversations",
"images": "images"
}
}
```
### Multimodal Video Dataset
- [Example dataset](mllm_video_demo.json)
Multimodal video datasets require a `videos` column containing the paths to the input videos.
The number of videos should be identical to the number of `<video>` tokens in the conversations.
```json
[
{
"conversations": [
{
"from": "human",
"value": "<video>human instruction"
},
{
"from": "gpt",
"value": "model response"
}
],
"videos": [
"video path (required)"
]
}
]
```
Regarding the above dataset, the *dataset description* in `dataset_info.json` should be:
```json
"dataset_name": {
"file_name": "data.json",
"formatting": "sharegpt",
"columns": {
"messages": "conversations",
"videos": "videos"
}
}
```
### OpenAI Format

The OpenAI format is simply a special case of the sharegpt format, where the first message may be a system prompt.
@@ -345,7 +417,3 @@ Regarding the above dataset, the *dataset description* in `dataset_info.json` sh
}
}
```

View File

@@ -23,6 +23,7 @@
"system": "数据集代表系统提示的表头名称默认None", "system": "数据集代表系统提示的表头名称默认None",
"tools": "数据集代表工具描述的表头名称默认None", "tools": "数据集代表工具描述的表头名称默认None",
"images": "数据集代表图像输入的表头名称默认None", "images": "数据集代表图像输入的表头名称默认None",
"videos": "数据集代表视频输入的表头名称默认None",
"chosen": "数据集代表更优回答的表头名称默认None", "chosen": "数据集代表更优回答的表头名称默认None",
"rejected": "数据集代表更差回答的表头名称默认None", "rejected": "数据集代表更差回答的表头名称默认None",
"kto_tag": "数据集代表 KTO 标签的表头名称默认None" "kto_tag": "数据集代表 KTO 标签的表头名称默认None"
@@ -107,7 +108,7 @@
### 偏好数据集

偏好数据集用于奖励模型训练、DPO 训练、ORPO 训练和 SimPO 训练。

它需要在 `chosen` 列中提供更优的回答,并在 `rejected` 列中提供更差的回答。
@@ -139,67 +140,15 @@
### KTO 数据集

KTO 数据集需要提供额外的 `kto_tag` 列。详情请参阅 [sharegpt](#sharegpt-格式) 格式。

### 多模态图像数据集

多模态图像数据集需要提供额外的 `images` 列。详情请参阅 [sharegpt](#sharegpt-格式) 格式。

### 多模态视频数据集

多模态视频数据集需要提供额外的 `videos` 列。详情请参阅 [sharegpt](#sharegpt-格式) 格式。
## Sharegpt 格式
@@ -252,6 +201,10 @@ KTO 数据集需要额外添加一个 `kto_tag` 列,包含 bool 类型的人
}
```
### 预训练数据集
尚不支持,请使用 [alpaca](#alpaca-格式) 格式。
### 偏好数据集

- [样例数据集](dpo_zh_demo.json)
@@ -302,6 +255,125 @@ Sharegpt 格式的偏好数据集同样需要在 `chosen` 列中提供更优的
}
```
### KTO 数据集
- [样例数据集](kto_en_demo.json)
KTO 数据集需要额外添加一个 `kto_tag` 列,包含 bool 类型的人类反馈。
```json
[
{
"conversations": [
{
"from": "human",
"value": "人类指令"
},
{
"from": "gpt",
"value": "模型回答"
}
],
"kto_tag": "人类反馈 [true/false](必填)"
}
]
```
对于上述格式的数据,`dataset_info.json` 中的*数据集描述*应为:
```json
"数据集名称": {
"file_name": "data.json",
"formatting": "sharegpt",
"columns": {
"messages": "conversations",
"kto_tag": "kto_tag"
}
}
```
### 多模态图像数据集
- [样例数据集](mllm_demo.json)
多模态图像数据集需要额外添加一个 `images` 列,包含输入图像的路径。
注意图片的数量必须与文本中所有 `<image>` 标记的数量严格一致。
```json
[
{
"conversations": [
{
"from": "human",
"value": "<image>人类指令"
},
{
"from": "gpt",
"value": "模型回答"
}
],
"images": [
"图像路径(必填)"
]
}
]
```
对于上述格式的数据,`dataset_info.json` 中的*数据集描述*应为:
```json
"数据集名称": {
"file_name": "data.json",
"formatting": "sharegpt",
"columns": {
"messages": "conversations",
"images": "images"
}
}
```
### 多模态视频数据集
- [样例数据集](mllm_video_demo.json)
多模态视频数据集需要额外添加一个 `videos` 列,包含输入视频的路径。
注意视频的数量必须与文本中所有 `<video>` 标记的数量严格一致。
```json
[
{
"conversations": [
{
"from": "human",
"value": "<video>人类指令"
},
{
"from": "gpt",
"value": "模型回答"
}
],
"videos": [
"视频路径(必填)"
]
}
]
```
对于上述格式的数据,`dataset_info.json` 中的*数据集描述*应为:
```json
"数据集名称": {
"file_name": "data.json",
"formatting": "sharegpt",
"columns": {
"messages": "conversations",
"videos": "videos"
}
}
```
### OpenAI 格式

OpenAI 格式仅仅是 sharegpt 格式的一种特殊情况,其中第一条消息可能是系统提示词。
@@ -345,7 +417,3 @@ OpenAI 格式仅仅是 sharegpt 格式的一种特殊情况,其中第一条消
}
}
```

View File

@@ -17,9 +17,9 @@ _CITATION = """\
}
"""

_HOMEPAGE = f"{_HF_ENDPOINT}/datasets/BelleGroup/multiturn_chat_0.8M"
_LICENSE = "gpl-3.0"
_URL = f"{_HF_ENDPOINT}/datasets/BelleGroup/multiturn_chat_0.8M/resolve/main/multiturn_chat_0.8M.json"

class BelleMultiturn(datasets.GeneratorBasedBuilder):
@@ -38,7 +38,7 @@ class BelleMultiturn(datasets.GeneratorBasedBuilder):
return [datasets.SplitGenerator(name=datasets.Split.TRAIN, gen_kwargs={"filepath": file_path})]

def _generate_examples(self, filepath: str):
    with open(filepath, encoding="utf-8") as f:
        for key, row in enumerate(f):
            data = json.loads(row)
            conversations = []

View File

@@ -8,9 +8,9 @@ import datasets
_HF_ENDPOINT = os.getenv("HF_ENDPOINT", "https://huggingface.co")
_DESCRIPTION = "Human preference data about helpfulness and harmlessness."
_CITATION = ""
_HOMEPAGE = f"{_HF_ENDPOINT}/datasets/Anthropic/hh-rlhf"
_LICENSE = "mit"
_URL = f"{_HF_ENDPOINT}/datasets/Anthropic/hh-rlhf/resolve/main/"
_URLS = {
    "train": [
        _URL + "harmless-base/train.jsonl.gz",
@@ -53,7 +53,7 @@ class HhRlhfEn(datasets.GeneratorBasedBuilder):
def _generate_examples(self, filepaths: List[str]):
    key = 0
    for filepath in filepaths:
        with open(filepath, encoding="utf-8") as f:
            for row in f:
                data = json.loads(row)
                chosen = data["chosen"]

BIN
data/mllm_demo_data/1.mp4 Normal file

Binary file not shown.

BIN
data/mllm_demo_data/2.avi Normal file

Binary file not shown.

BIN
data/mllm_demo_data/3.mp4 Normal file

Binary file not shown.

View File

@@ -20,9 +20,9 @@ _CITATION = """\
}
"""

_HOMEPAGE = f"{_HF_ENDPOINT}/datasets/stingning/ultrachat"
_LICENSE = "cc-by-nc-4.0"
_BASE_DATA_URL = f"{_HF_ENDPOINT}/datasets/stingning/ultrachat/resolve/main/train_{{idx}}.jsonl"

class UltraChat(datasets.GeneratorBasedBuilder):
@@ -42,7 +42,7 @@ class UltraChat(datasets.GeneratorBasedBuilder):
def _generate_examples(self, filepaths: List[str]):
    for filepath in filepaths:
        with open(filepath, encoding="utf-8") as f:
            for row in f:
                try:
                    data = json.loads(row)

View File

@@ -1,6 +1,7 @@
# By default, use the NVIDIA official image with PyTorch 2.3.0
# https://docs.nvidia.com/deeplearning/frameworks/pytorch-release-notes/index.html
ARG BASE_IMAGE=nvcr.io/nvidia/pytorch:24.02-py3
FROM ${BASE_IMAGE}

# Define environments
ENV MAX_JOBS=4
@@ -12,6 +13,9 @@ ARG INSTALL_BNB=false
ARG INSTALL_VLLM=false
ARG INSTALL_DEEPSPEED=false
ARG INSTALL_FLASHATTN=false
ARG INSTALL_LIGER_KERNEL=false
ARG INSTALL_HQQ=false
ARG INSTALL_EETQ=false
ARG PIP_INDEX=https://pypi.org/simple

# Set the working directory
@@ -38,6 +42,15 @@ RUN EXTRA_PACKAGES="metrics"; \
if [ "$INSTALL_DEEPSPEED" == "true" ]; then \ if [ "$INSTALL_DEEPSPEED" == "true" ]; then \
EXTRA_PACKAGES="${EXTRA_PACKAGES},deepspeed"; \ EXTRA_PACKAGES="${EXTRA_PACKAGES},deepspeed"; \
fi; \ fi; \
if [ "$INSTALL_LIGER_KERNEL" == "true" ]; then \
EXTRA_PACKAGES="${EXTRA_PACKAGES},liger-kernel"; \
fi; \
if [ "$INSTALL_HQQ" == "true" ]; then \
EXTRA_PACKAGES="${EXTRA_PACKAGES},hqq"; \
fi; \
if [ "$INSTALL_EETQ" == "true" ]; then \
EXTRA_PACKAGES="${EXTRA_PACKAGES},eetq"; \
fi; \
pip install -e ".[$EXTRA_PACKAGES]"

# Rebuild flash attention
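For reference, the new optional extras can be enabled at build time roughly as follows (a sketch; the image tag is arbitrary and the argument names are the ones defined in this Dockerfile):

```bash
docker build -f ./docker/docker-cuda/Dockerfile \
    --build-arg INSTALL_BNB=true \
    --build-arg INSTALL_LIGER_KERNEL=true \
    --build-arg INSTALL_HQQ=true \
    --build-arg INSTALL_EETQ=true \
    -t llamafactory:latest .
```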

View File

@@ -8,11 +8,15 @@ services:
INSTALL_VLLM: false
INSTALL_DEEPSPEED: false
INSTALL_FLASHATTN: false
INSTALL_LIGER_KERNEL: false
INSTALL_HQQ: false
INSTALL_EETQ: false
PIP_INDEX: https://pypi.org/simple
container_name: llamafactory
volumes:
- ../../hf_cache:/root/.cache/huggingface
- ../../ms_cache:/root/.cache/modelscope
- ../../om_cache:/root/.cache/openmind
- ../../data:/app/data
- ../../output:/app/output
ports:
@@ -20,6 +24,7 @@ services:
- "8000:8000" - "8000:8000"
ipc: host ipc: host
tty: true tty: true
shm_size: '16gb'
stdin_open: true stdin_open: true
command: bash command: bash
deploy: deploy:

View File

@@ -1,9 +1,9 @@
# Use the Ubuntu 22.04 image with CANN 8.0.rc1
# More versions can be found at https://hub.docker.com/r/ascendai/cann/tags
# FROM ascendai/cann:8.0.rc1-910-ubuntu22.04-py3.8
FROM ascendai/cann:8.0.rc1-910b-ubuntu22.04-py3.8
# FROM ascendai/cann:8.0.rc1-910-openeuler22.03-py3.8
# FROM ascendai/cann:8.0.rc1-910b-openeuler22.03-py3.8

# Define environments
ENV DEBIAN_FRONTEND=noninteractive

View File

@@ -10,6 +10,7 @@ services:
volumes:
- ../../hf_cache:/root/.cache/huggingface
- ../../ms_cache:/root/.cache/modelscope
- ../../om_cache:/root/.cache/openmind
- ../../data:/app/data
- ../../output:/app/output
- /usr/local/dcmi:/usr/local/dcmi
@@ -21,6 +22,7 @@ services:
- "8000:8000" - "8000:8000"
ipc: host ipc: host
tty: true tty: true
shm_size: '16gb'
stdin_open: true stdin_open: true
command: bash command: bash
devices: devices:

View File

@@ -0,0 +1,65 @@
FROM hardandheavy/transformers-rocm:2.2.0
# Define environments
ENV MAX_JOBS=4
ENV FLASH_ATTENTION_FORCE_BUILD=TRUE
ENV VLLM_WORKER_MULTIPROC_METHOD=spawn
# Define installation arguments
ARG INSTALL_BNB=false
ARG INSTALL_VLLM=false
ARG INSTALL_DEEPSPEED=false
ARG INSTALL_FLASHATTN=false
ARG INSTALL_LIGER_KERNEL=false
ARG INSTALL_HQQ=false
ARG PIP_INDEX=https://pypi.org/simple
# Set the working directory
WORKDIR /app
# Install the requirements
COPY requirements.txt /app
RUN pip config set global.index-url "$PIP_INDEX" && \
pip config set global.extra-index-url "$PIP_INDEX" && \
python -m pip install --upgrade pip && \
python -m pip install -r requirements.txt
# Copy the rest of the application into the image
COPY . /app
# Install the LLaMA Factory
RUN EXTRA_PACKAGES="metrics"; \
if [ "$INSTALL_BNB" == "true" ]; then \
EXTRA_PACKAGES="${EXTRA_PACKAGES},bitsandbytes"; \
fi; \
if [ "$INSTALL_VLLM" == "true" ]; then \
EXTRA_PACKAGES="${EXTRA_PACKAGES},vllm"; \
fi; \
if [ "$INSTALL_DEEPSPEED" == "true" ]; then \
EXTRA_PACKAGES="${EXTRA_PACKAGES},deepspeed"; \
fi; \
if [ "$INSTALL_LIGER_KERNEL" == "true" ]; then \
EXTRA_PACKAGES="${EXTRA_PACKAGES},liger-kernel"; \
fi; \
if [ "$INSTALL_HQQ" == "true" ]; then \
EXTRA_PACKAGES="${EXTRA_PACKAGES},hqq"; \
fi; \
pip install -e ".[$EXTRA_PACKAGES]"
# Rebuild flash attention
RUN pip uninstall -y transformer-engine flash-attn && \
if [ "$INSTALL_FLASHATTN" == "true" ]; then \
pip uninstall -y ninja && pip install ninja && \
pip install --no-cache-dir flash-attn --no-build-isolation; \
fi
# Set up volumes
VOLUME [ "/root/.cache/huggingface", "/root/.cache/modelscope", "/app/data", "/app/output" ]
# Expose port 7860 for the LLaMA Board
ENV GRADIO_SERVER_PORT 7860
EXPOSE 7860
# Expose port 8000 for the API service
ENV API_PORT 8000
EXPOSE 8000

View File

@@ -0,0 +1,33 @@
services:
llamafactory:
build:
dockerfile: ./docker/docker-rocm/Dockerfile
context: ../..
args:
INSTALL_BNB: false
INSTALL_VLLM: false
INSTALL_DEEPSPEED: false
INSTALL_FLASHATTN: false
INSTALL_LIGER_KERNEL: false
INSTALL_HQQ: false
PIP_INDEX: https://pypi.org/simple
container_name: llamafactory
volumes:
- ../../hf_cache:/root/.cache/huggingface
- ../../ms_cache:/root/.cache/modelscope
- ../../om_cache:/root/.cache/openmind
- ../../data:/app/data
- ../../output:/app/output
- ../../saves:/app/saves
ports:
- "7860:7860"
- "8000:8000"
ipc: host
tty: true
shm_size: '16gb'
stdin_open: true
command: bash
devices:
- /dev/kfd:/dev/kfd
- /dev/dri:/dev/dri
restart: unless-stopped

View File

@@ -158,5 +158,4 @@ class MMLU(datasets.GeneratorBasedBuilder):
df = pd.read_csv(filepath, header=None)
df.columns = ["question", "A", "B", "C", "D", "answer"]
yield from enumerate(df.to_dict(orient="records"))

View File

@@ -33,6 +33,19 @@ llamafactory-cli train examples/train_lora/llama3_lora_sft.yaml
```bash
llamafactory-cli train examples/train_lora/llava1_5_lora_sft.yaml
llamafactory-cli train examples/train_lora/qwen2vl_lora_sft.yaml
```
#### DPO/ORPO/SimPO Training
```bash
llamafactory-cli train examples/train_lora/llama3_lora_dpo.yaml
```
#### Multimodal DPO/ORPO/SimPO Training
```bash
llamafactory-cli train examples/train_lora/qwen2vl_lora_dpo.yaml
```

#### Reward Modeling
@@ -47,12 +60,6 @@ llamafactory-cli train examples/train_lora/llama3_lora_reward.yaml
llamafactory-cli train examples/train_lora/llama3_lora_ppo.yaml
```
#### KTO Training

```bash
@@ -82,8 +89,8 @@ llamafactory-cli train examples/train_lora/llama3_lora_predict.yaml
#### Supervised Fine-Tuning on Multiple Nodes

```bash
FORCE_TORCHRUN=1 NNODES=2 NODE_RANK=0 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 llamafactory-cli train examples/train_lora/llama3_lora_sft.yaml
FORCE_TORCHRUN=1 NNODES=2 NODE_RANK=1 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 llamafactory-cli train examples/train_lora/llama3_lora_sft.yaml
```

#### Supervised Fine-Tuning with DeepSpeed ZeRO-3 (Weight Sharding)
@@ -133,6 +140,12 @@ FORCE_TORCHRUN=1 NNODES=2 RANK=0 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 llama
FORCE_TORCHRUN=1 NNODES=2 RANK=1 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 llamafactory-cli train examples/train_full/llama3_full_sft_ds3.yaml
```
#### Multimodal Supervised Fine-Tuning
```bash
FORCE_TORCHRUN=1 llamafactory-cli train examples/train_full/qwen2vl_full_sft.yaml
```
#### Batch Predicting and Computing BLEU and ROUGE Scores

```bash
@@ -189,6 +202,12 @@ llamafactory-cli train examples/extras/galore/llama3_full_sft.yaml
llamafactory-cli train examples/extras/badam/llama3_full_sft.yaml
```
#### Full-Parameter Fine-Tuning using Adam-mini
```bash
llamafactory-cli train examples/extras/adam_mini/qwen2_full_sft.yaml
```
#### LoRA+ Fine-Tuning

```bash

View File

@@ -33,6 +33,19 @@ llamafactory-cli train examples/train_lora/llama3_lora_sft.yaml
```bash
llamafactory-cli train examples/train_lora/llava1_5_lora_sft.yaml
llamafactory-cli train examples/train_lora/qwen2vl_lora_sft.yaml
```
#### DPO/ORPO/SimPO 训练
```bash
llamafactory-cli train examples/train_lora/llama3_lora_dpo.yaml
```
#### 多模态 DPO/ORPO/SimPO 训练
```bash
llamafactory-cli train examples/train_lora/qwen2vl_lora_dpo.yaml
```

#### 奖励模型训练
@@ -47,12 +60,6 @@ llamafactory-cli train examples/train_lora/llama3_lora_reward.yaml
llamafactory-cli train examples/train_lora/llama3_lora_ppo.yaml
```
#### KTO 训练

```bash
@@ -82,8 +89,8 @@ llamafactory-cli train examples/train_lora/llama3_lora_predict.yaml
#### 多机指令监督微调

```bash
FORCE_TORCHRUN=1 NNODES=2 NODE_RANK=0 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 llamafactory-cli train examples/train_lora/llama3_lora_sft.yaml
FORCE_TORCHRUN=1 NNODES=2 NODE_RANK=1 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 llamafactory-cli train examples/train_lora/llama3_lora_sft.yaml
```

#### 使用 DeepSpeed ZeRO-3 平均分配显存
@@ -133,6 +140,12 @@ FORCE_TORCHRUN=1 NNODES=2 RANK=0 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 llama
FORCE_TORCHRUN=1 NNODES=2 RANK=1 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 llamafactory-cli train examples/train_full/llama3_full_sft_ds3.yaml
```
#### 多模态指令监督微调
```bash
FORCE_TORCHRUN=1 llamafactory-cli train examples/train_full/qwen2vl_full_sft.yaml
```
#### 批量预测并计算 BLEU 和 ROUGE 分数

```bash
@@ -189,6 +202,12 @@ llamafactory-cli train examples/extras/galore/llama3_full_sft.yaml
llamafactory-cli train examples/extras/badam/llama3_full_sft.yaml
```
#### 使用 Adam-mini 进行全参数训练
```bash
llamafactory-cli train examples/extras/adam_mini/qwen2_full_sft.yaml
```
#### LoRA+ 微调

```bash

View File

@@ -1,27 +1,22 @@
### model
model_name_or_path: Qwen/Qwen2-1.5B-Instruct

### method
stage: sft
do_train: true
finetuning_type: full
use_adam_mini: true

### dataset
dataset: identity,alpaca_en_demo
template: qwen
cutoff_len: 2048
max_samples: 1000
overwrite_cache: true
preprocessing_num_workers: 16

### output
output_dir: saves/qwen2-1_5b/full/sft
logging_steps: 10
save_steps: 500
plot_loss: true
@@ -30,10 +25,12 @@ overwrite_output_dir: true
### train
per_device_train_batch_size: 1
gradient_accumulation_steps: 8
learning_rate: 1.0e-5
num_train_epochs: 3.0
lr_scheduler_type: cosine
warmup_ratio: 0.1
bf16: true
ddp_timeout: 180000000

### eval
val_size: 0.1

View File

@@ -10,11 +10,12 @@ badam_mode: layer
badam_switch_mode: ascending
badam_switch_interval: 50
badam_verbose: 2
# deepspeed: examples/deepspeed/ds_z3_config.json

### dataset
dataset: identity,alpaca_en_demo
template: llama3
cutoff_len: 2048
max_samples: 1000
overwrite_cache: true
preprocessing_num_workers: 16
@@ -29,7 +30,7 @@ overwrite_output_dir: true
### train
per_device_train_batch_size: 1
gradient_accumulation_steps: 8
learning_rate: 1.0e-5
num_train_epochs: 3.0
lr_scheduler_type: cosine
warmup_ratio: 0.1

View File

@@ -11,7 +11,7 @@ lora_target: all
### dataset
dataset: identity,alpaca_en_demo
template: llama3
cutoff_len: 2048
max_samples: 1000
overwrite_cache: true
preprocessing_num_workers: 16

View File

@@ -14,7 +14,7 @@ galore_scale: 2.0
### dataset
dataset: identity,alpaca_en_demo
template: llama3
cutoff_len: 2048
max_samples: 1000
overwrite_cache: true
preprocessing_num_workers: 16
@@ -29,11 +29,12 @@ overwrite_output_dir: true
### train
per_device_train_batch_size: 1
gradient_accumulation_steps: 1
learning_rate: 1.0e-5
num_train_epochs: 3.0
lr_scheduler_type: cosine
warmup_ratio: 0.1
pure_bf16: true
ddp_timeout: 180000000

### eval
val_size: 0.1

View File

@@ -2,5 +2,5 @@
python scripts/llama_pro.py \
--model_name_or_path meta-llama/Meta-Llama-3-8B-Instruct \
--output_dir models/llama3-8b-pro \
--num_expand 8

View File

@@ -1,5 +1,5 @@
### model
model_name_or_path: models/llama3-8b-pro

### method
stage: sft
@@ -12,13 +12,13 @@ use_llama_pro: true
### dataset
dataset: identity,alpaca_en_demo
template: llama3
cutoff_len: 2048
max_samples: 1000
overwrite_cache: true
preprocessing_num_workers: 16

### output
output_dir: saves/llama3-8b-pro/freeze/sft
logging_steps: 10
save_steps: 500
plot_loss: true

View File

@@ -11,7 +11,7 @@ loraplus_lr_ratio: 16.0
### dataset
dataset: identity,alpaca_en_demo
template: llama3
cutoff_len: 2048
max_samples: 1000
overwrite_cache: true
preprocessing_num_workers: 16

View File

@@ -10,7 +10,7 @@ mixture_of_depths: convert
### dataset
dataset: identity,alpaca_en_demo
template: llama3
cutoff_len: 2048
max_samples: 1000
overwrite_cache: true
preprocessing_num_workers: 16
@@ -26,7 +26,7 @@ overwrite_output_dir: true
per_device_train_batch_size: 1
gradient_accumulation_steps: 8
optim: paged_adamw_8bit
learning_rate: 1.0e-5
num_train_epochs: 3.0
lr_scheduler_type: cosine
warmup_ratio: 0.1

View File

@@ -0,0 +1,5 @@
#!/bin/bash
python scripts/pissa_init.py \
--model_name_or_path meta-llama/Meta-Llama-3-8B-Instruct \
--output_dir models/llama3-8b-pissa

View File

@@ -13,7 +13,7 @@ pissa_convert: true
### dataset
dataset: identity,alpaca_en_demo
template: llama3
cutoff_len: 2048
max_samples: 1000
overwrite_cache: true
preprocessing_num_workers: 16

View File

@@ -1,3 +1,2 @@
model_name_or_path: llava-hf/llava-1.5-7b-hf
template: llava

View File

@@ -0,0 +1,2 @@
model_name_or_path: Qwen/Qwen2-VL-7B-Instruct
template: qwen2_vl
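A config like this is typically passed to the CLI for interactive inference, e.g. (the file path below is an assumption; adjust it to where this YAML is stored):

```bash
llamafactory-cli chat examples/inference/qwen2_vl.yaml
```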

View File

@@ -0,0 +1,13 @@
### Note: DO NOT use quantized model or quantization_bit when merging lora adapters
### model
model_name_or_path: Qwen/Qwen2-VL-7B-Instruct
adapter_name_or_path: saves/qwen2_vl-7b/lora/sft
template: qwen2_vl
finetuning_type: lora
### export
export_dir: models/qwen2_vl_lora_sft
export_size: 2
export_device: cpu
export_legacy_format: false
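Such a merge config is typically consumed by the export subcommand, e.g. (the file path below is an assumption):

```bash
llamafactory-cli export examples/merge_lora/qwen2vl_lora_sft.yaml
```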

View File

@@ -7,9 +7,9 @@ do_predict: true
finetuning_type: full

### dataset
eval_dataset: identity,alpaca_en_demo
template: llama3
cutoff_len: 2048
max_samples: 50
overwrite_cache: true
preprocessing_num_workers: 16

View File

@@ -10,7 +10,7 @@ deepspeed: examples/deepspeed/ds_z3_config.json
### dataset
dataset: identity,alpaca_en_demo
template: llama3
cutoff_len: 2048
max_samples: 1000
overwrite_cache: true
preprocessing_num_workers: 16
@@ -25,7 +25,7 @@ overwrite_output_dir: true
### train
per_device_train_batch_size: 1
gradient_accumulation_steps: 2
learning_rate: 1.0e-5
num_train_epochs: 3.0
lr_scheduler_type: cosine
warmup_ratio: 0.1

View File

@@ -0,0 +1,39 @@
### model
model_name_or_path: Qwen/Qwen2-VL-7B-Instruct
### method
stage: sft
do_train: true
finetuning_type: full
deepspeed: examples/deepspeed/ds_z3_config.json
### dataset
dataset: mllm_demo,identity
template: qwen2_vl
cutoff_len: 2048
max_samples: 1000
overwrite_cache: true
preprocessing_num_workers: 16
### output
output_dir: saves/qwen2_vl-7b/full/sft
logging_steps: 10
save_steps: 500
plot_loss: true
overwrite_output_dir: true
### train
per_device_train_batch_size: 1
gradient_accumulation_steps: 2
learning_rate: 1.0e-5
num_train_epochs: 3.0
lr_scheduler_type: cosine
warmup_ratio: 0.1
bf16: true
ddp_timeout: 180000000
### eval
val_size: 0.1
per_device_eval_batch_size: 1
eval_strategy: steps
eval_steps: 500
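Following the total-batch-size definition used in scripts/cal_lr.py further down (per-device batch × gradient accumulation × world size), the effective batch size of this config depends on the GPU count; a small arithmetic sketch with an assumed 8-GPU node:

# Effective batch size for the config above; world_size is an assumption.
per_device_train_batch_size = 1
gradient_accumulation_steps = 2
world_size = 8  # e.g. a single node with 8 GPUs
total_batch_size = per_device_train_batch_size * gradient_accumulation_steps * world_size
print(total_batch_size)  # 16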

View File

@@ -12,7 +12,7 @@ pref_loss: sigmoid # choices: [sigmoid (dpo), orpo, simpo]
### dataset
dataset: dpo_en_demo
template: llama3
- cutoff_len: 1024
+ cutoff_len: 2048
max_samples: 1000
overwrite_cache: true
preprocessing_num_workers: 16

View File

@@ -11,7 +11,7 @@ pref_beta: 0.1
### dataset
dataset: kto_en_demo
template: llama3
- cutoff_len: 1024
+ cutoff_len: 2048
max_samples: 1000
overwrite_cache: true
preprocessing_num_workers: 16

View File

@@ -11,7 +11,7 @@ lora_target: all
### dataset
dataset: identity,alpaca_en_demo
template: llama3
- cutoff_len: 1024
+ cutoff_len: 2048
max_samples: 1000
overwrite_cache: true
preprocessing_num_workers: 16

View File

@@ -10,7 +10,7 @@ finetuning_type: lora
### dataset
eval_dataset: identity,alpaca_en_demo
template: llama3
- cutoff_len: 1024
+ cutoff_len: 2048
max_samples: 50
overwrite_cache: true
preprocessing_num_workers: 16

View File

@@ -9,7 +9,7 @@ lora_target: all
### dataset
dataset: c4_demo
- cutoff_len: 1024
+ cutoff_len: 2048
max_samples: 1000
overwrite_cache: true
preprocessing_num_workers: 16

View File

@@ -10,7 +10,7 @@ lora_target: all
### dataset
dataset: dpo_en_demo
template: llama3
- cutoff_len: 1024
+ cutoff_len: 2048
max_samples: 1000
overwrite_cache: true
preprocessing_num_workers: 16

View File

@@ -10,7 +10,7 @@ lora_target: all
### dataset
dataset: identity,alpaca_en_demo
template: llama3
- cutoff_len: 1024
+ cutoff_len: 2048
max_samples: 1000
overwrite_cache: true
preprocessing_num_workers: 16

View File

@@ -11,7 +11,7 @@ deepspeed: examples/deepspeed/ds_z0_config.json
### dataset
dataset: identity,alpaca_en_demo
template: llama3
- cutoff_len: 1024
+ cutoff_len: 2048
max_samples: 1000
overwrite_cache: true
preprocessing_num_workers: 16

View File

@@ -11,7 +11,7 @@ deepspeed: examples/deepspeed/ds_z3_config.json
### dataset
dataset: identity,alpaca_en_demo
template: llama3
- cutoff_len: 1024
+ cutoff_len: 2048
max_samples: 1000
overwrite_cache: true
preprocessing_num_workers: 16

View File

@@ -10,7 +10,7 @@ lora_target: all
### dataset
dataset: identity,alpaca_en_demo
template: llama3
- cutoff_len: 1024
+ cutoff_len: 2048
max_samples: 1000
overwrite_cache: true
preprocessing_num_workers: 16

View File

@@ -1,6 +1,5 @@
### model
model_name_or_path: llava-hf/llava-1.5-7b-hf
- visual_inputs: true
### method
stage: sft
@@ -10,8 +9,8 @@ lora_target: all
### dataset
dataset: mllm_demo
- template: vicuna
+ template: llava
- cutoff_len: 1024
+ cutoff_len: 2048
max_samples: 1000
overwrite_cache: true
preprocessing_num_workers: 16

View File

@@ -0,0 +1,41 @@
### model
model_name_or_path: Qwen/Qwen2-VL-7B-Instruct
### method
stage: dpo
do_train: true
finetuning_type: lora
lora_target: all
pref_beta: 0.1
pref_loss: sigmoid # choices: [sigmoid (dpo), orpo, simpo]
### dataset
dataset: rlhf_v
template: qwen2_vl
cutoff_len: 2048
max_samples: 1000
overwrite_cache: true
preprocessing_num_workers: 16
### output
output_dir: saves/qwen2_vl-7b/lora/dpo
logging_steps: 10
save_steps: 500
plot_loss: true
overwrite_output_dir: true
### train
per_device_train_batch_size: 1
gradient_accumulation_steps: 8
learning_rate: 5.0e-6
num_train_epochs: 3.0
lr_scheduler_type: cosine
warmup_ratio: 0.1
bf16: true
ddp_timeout: 180000000
### eval
val_size: 0.1
per_device_eval_batch_size: 1
eval_strategy: steps
eval_steps: 500
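With pref_loss: sigmoid this corresponds to the standard DPO objective, -log σ(β · (chosen log-ratio − rejected log-ratio)); a tiny numeric sketch using the pref_beta value above and made-up log-ratios:

import math

beta = 0.1  # pref_beta from the config above
chosen_logratio, rejected_logratio = 0.8, -0.3  # illustrative values, not from a real run
loss = -math.log(1.0 / (1.0 + math.exp(-beta * (chosen_logratio - rejected_logratio))))
print(f"{loss:.4f}")  # ~0.6397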

View File

@@ -0,0 +1,39 @@
### model
model_name_or_path: Qwen/Qwen2-VL-7B-Instruct
### method
stage: sft
do_train: true
finetuning_type: lora
lora_target: all
### dataset
dataset: mllm_demo,identity # video: mllm_video_demo
template: qwen2_vl
cutoff_len: 2048
max_samples: 1000
overwrite_cache: true
preprocessing_num_workers: 16
### output
output_dir: saves/qwen2_vl-7b/lora/sft
logging_steps: 10
save_steps: 500
plot_loss: true
overwrite_output_dir: true
### train
per_device_train_batch_size: 1
gradient_accumulation_steps: 8
learning_rate: 1.0e-4
num_train_epochs: 3.0
lr_scheduler_type: cosine
warmup_ratio: 0.1
bf16: true
ddp_timeout: 180000000
### eval
val_size: 0.1
per_device_eval_batch_size: 1
eval_strategy: steps
eval_steps: 500
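The same arguments can also be passed programmatically as a plain dict, mirroring how scripts/cal_mfu.py further down drives run_exp; a rough sketch that restates part of the YAML above (the CLI with the YAML file remains the usual entry point):

# Programmatic sketch of the training config above via run_exp (subset of keys shown).
from llamafactory.train.tuner import run_exp

run_exp({
    "model_name_or_path": "Qwen/Qwen2-VL-7B-Instruct",
    "stage": "sft",
    "do_train": True,
    "finetuning_type": "lora",
    "lora_target": "all",
    "dataset": "mllm_demo,identity",
    "template": "qwen2_vl",
    "cutoff_len": 2048,
    "output_dir": "saves/qwen2_vl-7b/lora/sft",
    "per_device_train_batch_size": 1,
    "gradient_accumulation_steps": 8,
    "learning_rate": 1.0e-4,
    "num_train_epochs": 3.0,
    "bf16": True,
})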

View File

@@ -10,7 +10,7 @@ lora_target: all
### dataset
dataset: identity,alpaca_en_demo
template: llama3
- cutoff_len: 1024
+ cutoff_len: 2048
max_samples: 1000
overwrite_cache: true
preprocessing_num_workers: 16

View File

@@ -10,7 +10,7 @@ lora_target: all
### dataset
dataset: identity,alpaca_en_demo
template: llama3
- cutoff_len: 1024
+ cutoff_len: 2048
max_samples: 1000
overwrite_cache: true
preprocessing_num_workers: 16

View File

@@ -10,7 +10,7 @@ lora_target: all
### dataset
dataset: identity,alpaca_en_demo
template: llama3
- cutoff_len: 1024
+ cutoff_len: 2048
max_samples: 1000
overwrite_cache: true
preprocessing_num_workers: 16

View File

@@ -12,7 +12,7 @@ lora_target: all
### dataset
dataset: identity,alpaca_en_demo
template: llama3
- cutoff_len: 1024
+ cutoff_len: 2048
max_samples: 1000
overwrite_cache: true
preprocessing_num_workers: 16

View File

@@ -1,9 +1,9 @@
- transformers>=4.41.2
+ transformers>=4.41.2,<=4.46.1
- datasets>=2.16.0
+ datasets>=2.16.0,<=3.1.0
- accelerate>=0.30.1
+ accelerate>=0.34.0,<=1.0.1
- peft>=0.11.1
+ peft>=0.11.1,<=0.12.0
- trl>=0.8.6
+ trl>=0.8.6,<=0.9.6
- gradio>=4.0.0
+ gradio>=4.0.0,<5.0.0
pandas>=2.0.0
scipy
einops
@@ -19,3 +19,5 @@ fire
packaging
pyyaml
numpy<2.0.0
+ av
+ tyro<0.9.0
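To check an existing environment against the tightened pins, the require_version helper already used by scripts/test_image.py below accepts these specifiers directly; a short sketch:

# Verify a few of the new upper bounds at runtime.
from transformers.utils.versions import require_version

for req in ("transformers>=4.41.2,<=4.46.1", "datasets>=2.16.0,<=3.1.0", "trl>=0.8.6,<=0.9.6"):
    require_version(req, f"To fix: pip install '{req}'")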

View File

@@ -1,4 +1,3 @@
# coding=utf-8
# Copyright 2024 Microsoft Corporation and the LlamaFactory team. # Copyright 2024 Microsoft Corporation and the LlamaFactory team.
# #
# This code is inspired by the Microsoft's DeepSpeed library. # This code is inspired by the Microsoft's DeepSpeed library.
@@ -27,7 +26,7 @@ from llamafactory.chat import ChatModel
def calculate_flops( def calculate_flops(
model_name_or_path: str, model_name_or_path: str,
batch_size: int = 1, batch_size: int = 1,
seq_length: int = 256, seq_length: int = 512,
flash_attn: str = "auto", flash_attn: str = "auto",
): ):
r""" r"""
@@ -36,9 +35,11 @@ def calculate_flops(
""" """
with get_accelerator().device(0): with get_accelerator().device(0):
chat_model = ChatModel(dict(model_name_or_path=model_name_or_path, template="empty", flash_attn=flash_attn)) chat_model = ChatModel(dict(model_name_or_path=model_name_or_path, template="empty", flash_attn=flash_attn))
fake_input = torch.ones((batch_size, seq_length), dtype=torch.long, device=chat_model.model.device) fake_input = torch.ones((batch_size, seq_length), dtype=torch.long, device=chat_model.engine.model.device)
input_dict = {"input_ids": fake_input, "labels": fake_input.clone()} input_dict = {"input_ids": fake_input, "labels": fake_input.clone()}
flops, macs, params = get_model_profile(chat_model.model, kwargs=input_dict, print_profile=True, detailed=True) flops, macs, params = get_model_profile(
chat_model.engine.model, kwargs=input_dict, print_profile=True, detailed=True
)
print("FLOPs:", flops) print("FLOPs:", flops)
print("MACs:", macs) print("MACs:", macs)
print("Params:", params) print("Params:", params)

View File

@@ -1,4 +1,3 @@
# coding=utf-8
# Copyright 2024 imoneoi and the LlamaFactory team. # Copyright 2024 imoneoi and the LlamaFactory team.
# #
# This code is inspired by the imoneoi's OpenChat library. # This code is inspired by the imoneoi's OpenChat library.
@@ -25,7 +24,7 @@ from torch.utils.data import DataLoader
from tqdm import tqdm from tqdm import tqdm
from transformers import DataCollatorForLanguageModeling, DataCollatorForSeq2Seq from transformers import DataCollatorForLanguageModeling, DataCollatorForSeq2Seq
from llamafactory.data import get_dataset from llamafactory.data import get_dataset, get_template_and_fix_tokenizer
from llamafactory.extras.constants import IGNORE_INDEX from llamafactory.extras.constants import IGNORE_INDEX
from llamafactory.hparams import get_train_args from llamafactory.hparams import get_train_args
from llamafactory.model import load_tokenizer from llamafactory.model import load_tokenizer
@@ -39,16 +38,17 @@ def calculate_lr(
model_name_or_path: str, model_name_or_path: str,
batch_size: int, # total batch size, namely (batch size * gradient accumulation * world size) batch_size: int, # total batch size, namely (batch size * gradient accumulation * world size)
stage: Literal["pt", "sft"] = "sft", stage: Literal["pt", "sft"] = "sft",
dataset: str = "alpaca_en", dataset: str = "alpaca_en_demo",
dataset_dir: str = "data", dataset_dir: str = "data",
template: str = "default", template: str = "default",
cutoff_len: int = 1024, # i.e. maximum input length during training cutoff_len: int = 1024, # i.e. maximum input length during training
is_mistral: bool = False, # mistral model uses a smaller learning rate, is_mistral_or_gemma: bool = False, # mistral and gemma models opt for a smaller learning rate,
packing: bool = False, packing: bool = False,
): ):
r""" r"""
Calculates the optimal learning rate for 7B/13B models using LLaMA's hyper-parameters. Calculates the optimal learning rate for 7B/13B models using LLaMA's hyper-parameters.
Usage: python cal_lr.py --model_name_or_path path_to_model --dataset alpaca_en --cutoff_len 1024 --batch_size 16 Usage:
python cal_lr.py --model_name_or_path path_to_model --dataset alpaca_en_demo --cutoff_len 1024 --batch_size 16
""" """
model_args, data_args, training_args, _, _ = get_train_args( model_args, data_args, training_args, _, _ = get_train_args(
dict( dict(
@@ -66,13 +66,14 @@ def calculate_lr(
) )
tokenizer_module = load_tokenizer(model_args) tokenizer_module = load_tokenizer(model_args)
tokenizer = tokenizer_module["tokenizer"] tokenizer = tokenizer_module["tokenizer"]
trainset = get_dataset(model_args, data_args, training_args, stage, **tokenizer_module)["train_dataset"] template = get_template_and_fix_tokenizer(tokenizer, data_args)
trainset = get_dataset(template, model_args, data_args, training_args, stage, **tokenizer_module)["train_dataset"]
if stage == "pt": if stage == "pt":
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False) data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
elif stage == "sft": elif stage == "sft":
data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, label_pad_token_id=IGNORE_INDEX) data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, label_pad_token_id=IGNORE_INDEX)
else: else:
raise NotImplementedError("Stage does not supported: {}.".format(stage)) raise NotImplementedError(f"Stage does not supported: {stage}.")
dataloader = DataLoader(trainset, batch_size, shuffle=False, collate_fn=data_collator, pin_memory=True) dataloader = DataLoader(trainset, batch_size, shuffle=False, collate_fn=data_collator, pin_memory=True)
valid_tokens, total_tokens = 0, 0 valid_tokens, total_tokens = 0, 0
@@ -84,7 +85,7 @@ def calculate_lr(
valid_ratio = valid_tokens / total_tokens valid_ratio = valid_tokens / total_tokens
batch_valid_len = batch_max_len * valid_ratio batch_valid_len = batch_max_len * valid_ratio
lr = BASE_LR * math.sqrt(batch_valid_len / BASE_BS) # lr ~ sqrt(batch_size) lr = BASE_LR * math.sqrt(batch_valid_len / BASE_BS) # lr ~ sqrt(batch_size)
lr = lr / 6.0 if is_mistral else lr lr = lr / 6.0 if is_mistral_or_gemma else lr
print( print(
"Optimal learning rate is {:.2e} for valid ratio% {:.2f} and effective batch size {:.2f}".format( "Optimal learning rate is {:.2e} for valid ratio% {:.2f} and effective batch size {:.2f}".format(
lr, valid_ratio * 100, batch_valid_len lr, valid_ratio * 100, batch_valid_len
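The final scaling rule lr ~ sqrt(batch_size) can be reproduced by hand; a small sketch where BASE_LR and BASE_BS are assumptions, since their definitions sit above the hunk shown here:

import math

BASE_LR = 3.0e-4     # assumed reference learning rate (defined near the top of cal_lr.py)
BASE_BS = 4_000_000  # assumed reference tokens per optimizer step

batch_valid_len = 16 * 1024 * 0.75  # batch_max_len * valid_ratio, illustrative numbers
lr = BASE_LR * math.sqrt(batch_valid_len / BASE_BS)  # lr ~ sqrt(batch_size)
print(f"{lr:.2e}")  # ~1.66e-05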

scripts/cal_mfu.py (new file, 163 lines)
View File

@@ -0,0 +1,163 @@
# Copyright 2024 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import os
import fire
import torch
import torch.distributed as dist
from transformers import AutoConfig
from llamafactory.train.tuner import run_exp
BASE = 2 # gemm (add + mul)
def compute_model_flops(
model_name_or_path: str,
total_batch_size: int,
seq_length: int,
include_backward: bool = True,
include_recompute: bool = False,
include_flashattn: bool = False,
) -> int:
r"""
Calculates the FLOPs of model per forward/backward pass.
"""
config = AutoConfig.from_pretrained(model_name_or_path)
hidden_size = getattr(config, "hidden_size", None)
vocab_size = getattr(config, "vocab_size", None)
intermediate_size = getattr(config, "intermediate_size", None)
num_attention_heads = getattr(config, "num_attention_heads", None)
num_key_value_heads = getattr(config, "num_key_value_heads", None)
num_hidden_layers = getattr(config, "num_hidden_layers", None)
tie_word_embeddings = getattr(config, "tie_word_embeddings", False)
# mlp module
mlp_flops_per_token = 3 * BASE * hidden_size * intermediate_size # up, gate, down
mlp_flops = total_batch_size * seq_length * num_hidden_layers * mlp_flops_per_token
# attn projector module
q_flops_per_token = BASE * hidden_size * hidden_size
o_flops_per_token = BASE * hidden_size * hidden_size
k_flops_per_token = BASE * hidden_size * hidden_size * num_key_value_heads // num_attention_heads
v_flops_per_token = BASE * hidden_size * hidden_size * num_key_value_heads // num_attention_heads
attn_proj_flops_per_token = q_flops_per_token + o_flops_per_token + k_flops_per_token + v_flops_per_token
attn_proj_flops = total_batch_size * seq_length * num_hidden_layers * attn_proj_flops_per_token
# attn sdpa module
sdpa_flops_per_layer = 2 * BASE * hidden_size * seq_length * seq_length # (q * k^T) * v
sdpa_flops = total_batch_size * num_hidden_layers * sdpa_flops_per_layer
# embedding module
embedding_flops_per_token = hidden_size * vocab_size
embedding_flops = total_batch_size * seq_length * embedding_flops_per_token
if tie_word_embeddings is False:
embedding_flops *= 2
non_embedding_flops = mlp_flops + attn_proj_flops + sdpa_flops
non_embedding_coeff, embedding_coeff = 1, 1
if include_backward:
non_embedding_coeff += 2
embedding_coeff += 2
if include_recompute:
non_embedding_coeff += 1
total_flops = non_embedding_coeff * non_embedding_flops + embedding_coeff * embedding_flops
if include_flashattn:
total_flops += sdpa_flops
return total_flops
def compute_device_flops(world_size: int) -> float:
r"""
Calculates the FLOPs of the device capability per second.
"""
device_name = torch.cuda.get_device_name()
if "H100" in device_name or "H800" in device_name:
return 989 * 1e12 * world_size
elif "A100" in device_name or "A800" in device_name:
return 312 * 1e12 * world_size
elif "V100" in device_name:
return 125 * 1e12 * world_size
elif "4090" in device_name:
return 98 * 1e12 * world_size
else:
raise NotImplementedError(f"Device not supported: {device_name}.")
def calculate_mfu(
model_name_or_path: str,
batch_size: int = 1,
seq_length: int = 1024,
num_steps: int = 100,
finetuning_type: str = "lora",
flash_attn: str = "auto",
deepspeed_stage: int = 0,
disable_gc: bool = False,
liger_kernel: bool = False,
unsloth_gc: bool = False,
) -> float:
r"""
Calculates MFU for given model and hyper-params.
Usage: python cal_mfu.py --model_name_or_path path_to_model --batch_size 1 --seq_length 1024
"""
args = {
"model_name_or_path": model_name_or_path,
"flash_attn": flash_attn,
"disable_gradient_checkpointing": disable_gc,
"enable_liger_kernel": liger_kernel,
"use_unsloth_gc": unsloth_gc,
"stage": "pt",
"do_train": True,
"finetuning_type": finetuning_type,
"dataset": "c4_demo",
"cutoff_len": seq_length,
"output_dir": os.path.join("saves", "test_mfu"),
"logging_strategy": "no",
"save_strategy": "no",
"save_only_model": True,
"overwrite_output_dir": True,
"per_device_train_batch_size": batch_size,
"max_steps": num_steps,
"bf16": True,
}
if deepspeed_stage in [2, 3]:
args["deepspeed"] = f"examples/deepspeed/ds_z{deepspeed_stage}_config.json"
run_exp(args)
with open(os.path.join("saves", "test_mfu", "all_results.json"), encoding="utf-8") as f:
result = json.load(f)
if dist.is_initialized():
world_size = dist.get_world_size()
else:
world_size = 1
total_batch_size = batch_size * world_size
mfu_value = (
result["train_steps_per_second"]
* compute_model_flops(model_name_or_path, total_batch_size, seq_length)
/ compute_device_flops(world_size)
)
print(f"MFU: {mfu_value * 100:.2f}%")
if __name__ == "__main__":
fire.Fire(calculate_mfu)
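The MFU expression at the end of calculate_mfu is simply achieved FLOP/s divided by peak FLOP/s; a hand-computed sketch with illustrative numbers (only the 312 TFLOP/s peak comes from the A100 branch of compute_device_flops above):

# Illustrative MFU arithmetic mirroring calculate_mfu above; input values are made up.
train_steps_per_second = 0.5    # assumed value from all_results.json of a short run
model_flops_per_step = 2.5e14   # assumed output of compute_model_flops(...)
device_peak_flops = 312e12      # A100/A800 branch of compute_device_flops, world_size = 1

mfu = train_steps_per_second * model_flops_per_step / device_peak_flops
print(f"MFU: {mfu * 100:.2f}%")  # MFU: 40.06%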

View File

@@ -1,4 +1,3 @@
# coding=utf-8
# Copyright 2024 the LlamaFactory team. # Copyright 2024 the LlamaFactory team.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
@@ -23,7 +22,7 @@ from torch.utils.data import DataLoader
from tqdm import tqdm from tqdm import tqdm
from transformers import DataCollatorForLanguageModeling, DataCollatorForSeq2Seq from transformers import DataCollatorForLanguageModeling, DataCollatorForSeq2Seq
from llamafactory.data import get_dataset from llamafactory.data import get_dataset, get_template_and_fix_tokenizer
from llamafactory.extras.constants import IGNORE_INDEX from llamafactory.extras.constants import IGNORE_INDEX
from llamafactory.hparams import get_train_args from llamafactory.hparams import get_train_args
from llamafactory.model import load_model, load_tokenizer from llamafactory.model import load_model, load_tokenizer
@@ -55,12 +54,12 @@ class PairwiseDataCollatorWithPadding(DataCollatorForSeq2Seq):
return super().__call__(chosen_features) return super().__call__(chosen_features)
def cal_ppl( def calculate_ppl(
model_name_or_path: str, model_name_or_path: str,
save_name: str, save_name: str,
batch_size: int = 4, batch_size: int = 4,
stage: Literal["pt", "sft", "rm"] = "sft", stage: Literal["pt", "sft", "rm"] = "sft",
dataset: str = "alpaca_en", dataset: str = "alpaca_en_demo",
dataset_dir: str = "data", dataset_dir: str = "data",
template: str = "default", template: str = "default",
cutoff_len: int = 1024, cutoff_len: int = 1024,
@@ -69,7 +68,7 @@ def cal_ppl(
): ):
r""" r"""
Calculates the ppl on the dataset of the pre-trained models. Calculates the ppl on the dataset of the pre-trained models.
Usage: python cal_ppl.py --model_name_or_path path_to_model --save_name ppl.json Usage: python cal_ppl.py --model_name_or_path path_to_model --dataset alpaca_en_demo --save_name ppl.json
""" """
model_args, data_args, training_args, finetuning_args, _ = get_train_args( model_args, data_args, training_args, finetuning_args, _ = get_train_args(
dict( dict(
@@ -88,7 +87,8 @@ def cal_ppl(
) )
tokenizer_module = load_tokenizer(model_args) tokenizer_module = load_tokenizer(model_args)
tokenizer = tokenizer_module["tokenizer"] tokenizer = tokenizer_module["tokenizer"]
trainset = get_dataset(model_args, data_args, training_args, stage, **tokenizer_module)["train_dataset"] template = get_template_and_fix_tokenizer(tokenizer, data_args)
trainset = get_dataset(template, model_args, data_args, training_args, stage, **tokenizer_module)["train_dataset"]
model = load_model(tokenizer, model_args, finetuning_args, is_trainable=False) model = load_model(tokenizer, model_args, finetuning_args, is_trainable=False)
if stage == "pt": if stage == "pt":
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False) data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
@@ -99,7 +99,7 @@ def cal_ppl(
tokenizer=tokenizer, label_pad_token_id=IGNORE_INDEX, train_on_prompt=train_on_prompt tokenizer=tokenizer, label_pad_token_id=IGNORE_INDEX, train_on_prompt=train_on_prompt
) )
else: else:
raise NotImplementedError("Stage does not supported: {}.".format(stage)) raise NotImplementedError(f"Stage does not supported: {stage}.")
dataloader = DataLoader(trainset, batch_size, shuffle=False, collate_fn=data_collator, pin_memory=True) dataloader = DataLoader(trainset, batch_size, shuffle=False, collate_fn=data_collator, pin_memory=True)
criterion = torch.nn.CrossEntropyLoss(reduction="none") criterion = torch.nn.CrossEntropyLoss(reduction="none")
@@ -124,9 +124,9 @@ def cal_ppl(
with open(save_name, "w", encoding="utf-8") as f: with open(save_name, "w", encoding="utf-8") as f:
json.dump(perplexities, f, indent=2) json.dump(perplexities, f, indent=2)
print("Average perplexity is {:.2f}".format(total_ppl / len(perplexities))) print(f"Average perplexity is {total_ppl / len(perplexities):.2f}")
print("Perplexities have been saved at {}.".format(save_name)) print(f"Perplexities have been saved at {save_name}.")
if __name__ == "__main__": if __name__ == "__main__":
fire.Fire(cal_ppl) fire.Fire(calculate_ppl)

View File

@@ -1,4 +1,3 @@
# coding=utf-8
# Copyright 2024 the LlamaFactory team. # Copyright 2024 the LlamaFactory team.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
@@ -18,21 +17,21 @@ from collections import defaultdict
import fire import fire
from tqdm import tqdm from tqdm import tqdm
from llamafactory.data import get_dataset from llamafactory.data import get_dataset, get_template_and_fix_tokenizer
from llamafactory.hparams import get_train_args from llamafactory.hparams import get_train_args
from llamafactory.model import load_tokenizer from llamafactory.model import load_tokenizer
def length_cdf( def length_cdf(
model_name_or_path: str, model_name_or_path: str,
dataset: str = "alpaca_en", dataset: str = "alpaca_en_demo",
dataset_dir: str = "data", dataset_dir: str = "data",
template: str = "default", template: str = "default",
interval: int = 1000, interval: int = 1000,
): ):
r""" r"""
Calculates the distribution of the input lengths in the dataset. Calculates the distribution of the input lengths in the dataset.
Usage: python length_cdf.py --model_name_or_path path_to_model --dataset alpaca_en --template default Usage: python length_cdf.py --model_name_or_path path_to_model --dataset alpaca_en_demo --template default
""" """
model_args, data_args, training_args, _, _ = get_train_args( model_args, data_args, training_args, _, _ = get_train_args(
dict( dict(
@@ -48,7 +47,8 @@ def length_cdf(
) )
) )
tokenizer_module = load_tokenizer(model_args) tokenizer_module = load_tokenizer(model_args)
trainset = get_dataset(model_args, data_args, training_args, stage="sft", **tokenizer_module)["train_dataset"] template = get_template_and_fix_tokenizer(tokenizer_module["tokenizer"], data_args)
trainset = get_dataset(template, model_args, data_args, training_args, "sft", **tokenizer_module)["train_dataset"]
total_num = len(trainset) total_num = len(trainset)
length_dict = defaultdict(int) length_dict = defaultdict(int)
for sample in tqdm(trainset["input_ids"]): for sample in tqdm(trainset["input_ids"]):
@@ -60,7 +60,7 @@ def length_cdf(
for length, count in length_tuples: for length, count in length_tuples:
count_accu += count count_accu += count
prob_accu += count / total_num * 100 prob_accu += count / total_num * 100
print("{:d} ({:.2f}%) samples have length < {}.".format(count_accu, prob_accu, length + interval)) print(f"{count_accu:d} ({prob_accu:.2f}%) samples have length < {length + interval}.")
if __name__ == "__main__": if __name__ == "__main__":

View File

@@ -1,4 +1,3 @@
# coding=utf-8
# Copyright 2024 Tencent Inc. and the LlamaFactory team. # Copyright 2024 Tencent Inc. and the LlamaFactory team.
# #
# This code is inspired by the Tencent's LLaMA-Pro library. # This code is inspired by the Tencent's LLaMA-Pro library.
@@ -19,7 +18,7 @@
import json import json
import os import os
from collections import OrderedDict from collections import OrderedDict
from typing import TYPE_CHECKING, Optional from typing import TYPE_CHECKING
import fire import fire
import torch import torch
@@ -40,15 +39,15 @@ if TYPE_CHECKING:
def change_name(name: str, old_index: int, new_index: int) -> str: def change_name(name: str, old_index: int, new_index: int) -> str:
return name.replace(".{:d}.".format(old_index), ".{:d}.".format(new_index)) return name.replace(f".{old_index:d}.", f".{new_index:d}.")
def block_expansion( def block_expansion(
model_name_or_path: str, model_name_or_path: str,
output_dir: str, output_dir: str,
num_expand: int, num_expand: int,
shard_size: Optional[str] = "2GB", shard_size: str = "2GB",
save_safetensors: Optional[bool] = False, save_safetensors: bool = True,
): ):
r""" r"""
Performs block expansion for LLaMA, Mistral, Qwen1.5 or Yi models. Performs block expansion for LLaMA, Mistral, Qwen1.5 or Yi models.
@@ -76,27 +75,27 @@ def block_expansion(
state_dict = model.state_dict() state_dict = model.state_dict()
if num_layers % num_expand != 0: if num_layers % num_expand != 0:
raise ValueError("`num_layers` {} should be divisible by `num_expand` {}.".format(num_layers, num_expand)) raise ValueError(f"`num_layers` {num_layers} should be divisible by `num_expand` {num_expand}.")
split = num_layers // num_expand split = num_layers // num_expand
layer_cnt = 0 layer_cnt = 0
output_state_dict = OrderedDict() output_state_dict = OrderedDict()
for i in range(num_layers): for i in range(num_layers):
for key, value in state_dict.items(): for key, value in state_dict.items():
if ".{:d}.".format(i) in key: if f".{i:d}." in key:
output_state_dict[change_name(key, i, layer_cnt)] = value output_state_dict[change_name(key, i, layer_cnt)] = value
print("Add layer {} copied from layer {}".format(layer_cnt, i)) print(f"Add layer {layer_cnt} copied from layer {i}")
layer_cnt += 1 layer_cnt += 1
if (i + 1) % split == 0: if (i + 1) % split == 0:
for key, value in state_dict.items(): for key, value in state_dict.items():
if ".{:d}.".format(i) in key: if f".{i:d}." in key:
if "down_proj" in key or "o_proj" in key: if "down_proj" in key or "o_proj" in key:
output_state_dict[change_name(key, i, layer_cnt)] = torch.zeros_like(value) output_state_dict[change_name(key, i, layer_cnt)] = torch.zeros_like(value)
else: else:
output_state_dict[change_name(key, i, layer_cnt)] = torch.clone(value) output_state_dict[change_name(key, i, layer_cnt)] = torch.clone(value)
print("Add layer {} expanded from layer {}".format(layer_cnt, i)) print(f"Add layer {layer_cnt} expanded from layer {i}")
layer_cnt += 1 layer_cnt += 1
for key, value in state_dict.items(): for key, value in state_dict.items():
@@ -113,17 +112,17 @@ def block_expansion(
torch.save(shard, os.path.join(output_dir, shard_file)) torch.save(shard, os.path.join(output_dir, shard_file))
if index is None: if index is None:
print("Model weights saved in {}".format(os.path.join(output_dir, weights_name))) print(f"Model weights saved in {os.path.join(output_dir, weights_name)}")
else: else:
index_name = SAFE_WEIGHTS_INDEX_NAME if save_safetensors else WEIGHTS_INDEX_NAME index_name = SAFE_WEIGHTS_INDEX_NAME if save_safetensors else WEIGHTS_INDEX_NAME
with open(os.path.join(output_dir, index_name), "w", encoding="utf-8") as f: with open(os.path.join(output_dir, index_name), "w", encoding="utf-8") as f:
json.dump(index, f, indent=2, sort_keys=True) json.dump(index, f, indent=2, sort_keys=True)
print("Model weights saved in {}".format(output_dir)) print(f"Model weights saved in {output_dir}")
print("- Fine-tune this model with:") print("- Fine-tune this model with:")
print("model_name_or_path: {}".format(output_dir)) print(f"model_name_or_path: {output_dir}")
print("finetuning_type: freeze") print("finetuning_type: freeze")
print("freeze_trainable_layers: {}".format(num_expand)) print(f"freeze_trainable_layers: {num_expand}")
print("use_llama_pro: true") print("use_llama_pro: true")

View File

@@ -1,4 +1,3 @@
# coding=utf-8
# Copyright 2024 the LlamaFactory team. # Copyright 2024 the LlamaFactory team.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
@@ -16,7 +15,7 @@
import json import json
import os import os
from collections import OrderedDict from collections import OrderedDict
from typing import Any, Dict, Optional from typing import Any, Dict
import fire import fire
import torch import torch
@@ -63,16 +62,16 @@ def save_weight(input_dir: str, output_dir: str, shard_size: str, save_safetenso
torch.save(shard, os.path.join(output_dir, shard_file)) torch.save(shard, os.path.join(output_dir, shard_file))
if index is None: if index is None:
print("Model weights saved in {}".format(os.path.join(output_dir, WEIGHTS_NAME))) print(f"Model weights saved in {os.path.join(output_dir, WEIGHTS_NAME)}")
else: else:
index_name = SAFE_WEIGHTS_INDEX_NAME if save_safetensors else WEIGHTS_INDEX_NAME index_name = SAFE_WEIGHTS_INDEX_NAME if save_safetensors else WEIGHTS_INDEX_NAME
with open(os.path.join(output_dir, index_name), "w", encoding="utf-8") as f: with open(os.path.join(output_dir, index_name), "w", encoding="utf-8") as f:
json.dump(index, f, indent=2, sort_keys=True) json.dump(index, f, indent=2, sort_keys=True)
print("Model weights saved in {}".format(output_dir)) print(f"Model weights saved in {output_dir}")
def save_config(input_dir: str, output_dir: str): def save_config(input_dir: str, output_dir: str):
with open(os.path.join(input_dir, CONFIG_NAME), "r", encoding="utf-8") as f: with open(os.path.join(input_dir, CONFIG_NAME), encoding="utf-8") as f:
llama2_config_dict: Dict[str, Any] = json.load(f) llama2_config_dict: Dict[str, Any] = json.load(f)
llama2_config_dict["architectures"] = ["LlamaForCausalLM"] llama2_config_dict["architectures"] = ["LlamaForCausalLM"]
@@ -82,11 +81,14 @@ def save_config(input_dir: str, output_dir: str):
with open(os.path.join(output_dir, CONFIG_NAME), "w", encoding="utf-8") as f: with open(os.path.join(output_dir, CONFIG_NAME), "w", encoding="utf-8") as f:
json.dump(llama2_config_dict, f, indent=2) json.dump(llama2_config_dict, f, indent=2)
print("Model config saved in {}".format(os.path.join(output_dir, CONFIG_NAME))) print(f"Model config saved in {os.path.join(output_dir, CONFIG_NAME)}")
def llamafy_baichuan2( def llamafy_baichuan2(
input_dir: str, output_dir: str, shard_size: Optional[str] = "2GB", save_safetensors: Optional[bool] = False input_dir: str,
output_dir: str,
shard_size: str = "2GB",
save_safetensors: bool = True,
): ):
r""" r"""
Converts the Baichuan2-7B model in the same format as LLaMA2-7B. Converts the Baichuan2-7B model in the same format as LLaMA2-7B.

View File

@@ -1,4 +1,3 @@
# coding=utf-8
# Copyright 2024 the LlamaFactory team. # Copyright 2024 the LlamaFactory team.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
@@ -16,7 +15,7 @@
import json import json
import os import os
from collections import OrderedDict from collections import OrderedDict
from typing import Any, Dict, Optional from typing import Any, Dict
import fire import fire
import torch import torch
@@ -86,7 +85,7 @@ def save_weight(input_dir: str, output_dir: str, shard_size: str, save_safetenso
elif "lm_head" in key: elif "lm_head" in key:
llama2_state_dict[key] = value llama2_state_dict[key] = value
else: else:
raise KeyError("Unable to process key {}".format(key)) raise KeyError(f"Unable to process key {key}")
weights_name = SAFE_WEIGHTS_NAME if save_safetensors else WEIGHTS_NAME weights_name = SAFE_WEIGHTS_NAME if save_safetensors else WEIGHTS_NAME
shards, index = shard_checkpoint(llama2_state_dict, max_shard_size=shard_size, weights_name=weights_name) shards, index = shard_checkpoint(llama2_state_dict, max_shard_size=shard_size, weights_name=weights_name)
@@ -98,18 +97,18 @@ def save_weight(input_dir: str, output_dir: str, shard_size: str, save_safetenso
torch.save(shard, os.path.join(output_dir, shard_file)) torch.save(shard, os.path.join(output_dir, shard_file))
if index is None: if index is None:
print("Model weights saved in {}".format(os.path.join(output_dir, weights_name))) print(f"Model weights saved in {os.path.join(output_dir, weights_name)}")
else: else:
index_name = SAFE_WEIGHTS_INDEX_NAME if save_safetensors else WEIGHTS_INDEX_NAME index_name = SAFE_WEIGHTS_INDEX_NAME if save_safetensors else WEIGHTS_INDEX_NAME
with open(os.path.join(output_dir, index_name), "w", encoding="utf-8") as f: with open(os.path.join(output_dir, index_name), "w", encoding="utf-8") as f:
json.dump(index, f, indent=2, sort_keys=True) json.dump(index, f, indent=2, sort_keys=True)
print("Model weights saved in {}".format(output_dir)) print(f"Model weights saved in {output_dir}")
return str(torch_dtype).replace("torch.", "") return str(torch_dtype).replace("torch.", "")
def save_config(input_dir: str, output_dir: str, torch_dtype: str): def save_config(input_dir: str, output_dir: str, torch_dtype: str):
with open(os.path.join(input_dir, CONFIG_NAME), "r", encoding="utf-8") as f: with open(os.path.join(input_dir, CONFIG_NAME), encoding="utf-8") as f:
qwen_config_dict: Dict[str, Any] = json.load(f) qwen_config_dict: Dict[str, Any] = json.load(f)
llama2_config_dict: Dict[str, Any] = OrderedDict() llama2_config_dict: Dict[str, Any] = OrderedDict()
@@ -135,11 +134,14 @@ def save_config(input_dir: str, output_dir: str, torch_dtype: str):
with open(os.path.join(output_dir, CONFIG_NAME), "w", encoding="utf-8") as f: with open(os.path.join(output_dir, CONFIG_NAME), "w", encoding="utf-8") as f:
json.dump(llama2_config_dict, f, indent=2) json.dump(llama2_config_dict, f, indent=2)
print("Model config saved in {}".format(os.path.join(output_dir, CONFIG_NAME))) print(f"Model config saved in {os.path.join(output_dir, CONFIG_NAME)}")
def llamafy_qwen( def llamafy_qwen(
input_dir: str, output_dir: str, shard_size: Optional[str] = "2GB", save_safetensors: Optional[bool] = False input_dir: str,
output_dir: str,
shard_size: str = "2GB",
save_safetensors: bool = False,
): ):
r""" r"""
Converts the Qwen models in the same format as LLaMA2. Converts the Qwen models in the same format as LLaMA2.

View File

@@ -1,4 +1,3 @@
# coding=utf-8
# Copyright 2024 HuggingFace Inc. and the LlamaFactory team. # Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
# #
# This code is based on the HuggingFace's PEFT library. # This code is based on the HuggingFace's PEFT library.
@@ -67,22 +66,22 @@ def quantize_loftq(
loftq_dir = os.path.join(output_dir, "loftq_init") loftq_dir = os.path.join(output_dir, "loftq_init")
# Save LoftQ model # Save LoftQ model
setattr(peft_model.peft_config["default"], "base_model_name_or_path", output_dir) setattr(peft_model.peft_config["default"], "base_model_name_or_path", os.path.abspath(output_dir))
setattr(peft_model.peft_config["default"], "init_lora_weights", True) # don't apply loftq again setattr(peft_model.peft_config["default"], "init_lora_weights", True) # don't apply loftq again
peft_model.save_pretrained(loftq_dir, safe_serialization=save_safetensors) peft_model.save_pretrained(loftq_dir, safe_serialization=save_safetensors)
print("Adapter weights saved in {}".format(loftq_dir)) print(f"Adapter weights saved in {loftq_dir}")
# Save base model # Save base model
base_model: "PreTrainedModel" = peft_model.unload() base_model: "PreTrainedModel" = peft_model.unload()
base_model.save_pretrained(output_dir, safe_serialization=save_safetensors) base_model.save_pretrained(output_dir, safe_serialization=save_safetensors)
tokenizer.save_pretrained(output_dir) tokenizer.save_pretrained(output_dir)
print("Model weights saved in {}".format(output_dir)) print(f"Model weights saved in {output_dir}")
print("- Fine-tune this model with:") print("- Fine-tune this model with:")
print("model_name_or_path: {}".format(output_dir)) print(f"model_name_or_path: {output_dir}")
print("adapter_name_or_path: {}".format(loftq_dir)) print(f"adapter_name_or_path: {loftq_dir}")
print("finetuning_type: lora") print("finetuning_type: lora")
print("quantization_bit: {}".format(loftq_bits)) print(f"quantization_bit: {loftq_bits}")
if __name__ == "__main__": if __name__ == "__main__":

View File

@@ -1,4 +1,3 @@
# coding=utf-8
# Copyright 2024 HuggingFace Inc. and the LlamaFactory team. # Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
# #
# This code is based on the HuggingFace's PEFT library. # This code is based on the HuggingFace's PEFT library.
@@ -31,7 +30,7 @@ if TYPE_CHECKING:
def quantize_pissa( def quantize_pissa(
model_name_or_path: str, model_name_or_path: str,
output_dir: str, output_dir: str,
pissa_iter: int = 4, pissa_iter: int = 16,
lora_alpha: int = None, lora_alpha: int = None,
lora_rank: int = 16, lora_rank: int = 16,
lora_dropout: float = 0, lora_dropout: float = 0,
@@ -54,7 +53,7 @@ def quantize_pissa(
lora_alpha=lora_alpha if lora_alpha is not None else lora_rank * 2, lora_alpha=lora_alpha if lora_alpha is not None else lora_rank * 2,
lora_dropout=lora_dropout, lora_dropout=lora_dropout,
target_modules=lora_target, target_modules=lora_target,
init_lora_weights="pissa" if pissa_iter == -1 else "pissa_niter_{}".format(pissa_iter), init_lora_weights="pissa" if pissa_iter == -1 else f"pissa_niter_{pissa_iter}",
) )
# Init PiSSA model # Init PiSSA model
@@ -62,19 +61,20 @@ def quantize_pissa(
pissa_dir = os.path.join(output_dir, "pissa_init") pissa_dir = os.path.join(output_dir, "pissa_init")
# Save PiSSA model # Save PiSSA model
setattr(peft_model.peft_config["default"], "base_model_name_or_path", os.path.abspath(output_dir))
setattr(peft_model.peft_config["default"], "init_lora_weights", True) # don't apply pissa again setattr(peft_model.peft_config["default"], "init_lora_weights", True) # don't apply pissa again
peft_model.save_pretrained(pissa_dir, safe_serialization=save_safetensors) peft_model.save_pretrained(pissa_dir, safe_serialization=save_safetensors)
print("Adapter weights saved in {}".format(pissa_dir)) print(f"Adapter weights saved in {pissa_dir}")
# Save base model # Save base model
base_model: "PreTrainedModel" = peft_model.unload() base_model: "PreTrainedModel" = peft_model.unload()
base_model.save_pretrained(output_dir, safe_serialization=save_safetensors) base_model.save_pretrained(output_dir, safe_serialization=save_safetensors)
tokenizer.save_pretrained(output_dir) tokenizer.save_pretrained(output_dir)
print("Model weights saved in {}".format(output_dir)) print(f"Model weights saved in {output_dir}")
print("- Fine-tune this model with:") print("- Fine-tune this model with:")
print("model_name_or_path: {}".format(output_dir)) print(f"model_name_or_path: {output_dir}")
print("adapter_name_or_path: {}".format(pissa_dir)) print(f"adapter_name_or_path: {pissa_dir}")
print("finetuning_type: lora") print("finetuning_type: lora")
print("pissa_init: false") print("pissa_init: false")
print("pissa_convert: true") print("pissa_convert: true")

scripts/test_image.py (new file, 65 lines)
View File

@@ -0,0 +1,65 @@
# Copyright 2024 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from openai import OpenAI
from transformers.utils.versions import require_version
require_version("openai>=1.5.0", "To fix: pip install openai>=1.5.0")
def main():
client = OpenAI(
api_key="{}".format(os.environ.get("API_KEY", "0")),
base_url="http://localhost:{}/v1".format(os.environ.get("API_PORT", 8000)),
)
messages = []
messages.append(
{
"role": "user",
"content": [
{"type": "text", "text": "Output the color and number of each box."},
{
"type": "image_url",
"image_url": {"url": "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen2-VL/boxes.png"},
},
],
}
)
result = client.chat.completions.create(messages=messages, model="test")
messages.append(result.choices[0].message)
print("Round 1:", result.choices[0].message.content)
# The image shows a pyramid of colored blocks with numbers on them. Here are the colors and numbers of ...
messages.append(
{
"role": "user",
"content": [
{"type": "text", "text": "What kind of flower is this?"},
{
"type": "image_url",
"image_url": {"url": "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen2-VL/flowers.jpg"},
},
],
}
)
result = client.chat.completions.create(messages=messages, model="test")
messages.append(result.choices[0].message)
print("Round 2:", result.choices[0].message.content)
# The image shows a cluster of forget-me-not flowers. Forget-me-nots are small ...
if __name__ == "__main__":
main()

View File

@@ -1,4 +1,3 @@
# coding=utf-8
# Copyright 2024 the LlamaFactory team. # Copyright 2024 the LlamaFactory team.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");

View File

@@ -14,42 +14,54 @@
import os import os
import re import re
from typing import List
from setuptools import find_packages, setup from setuptools import find_packages, setup
def get_version(): def get_version() -> str:
with open(os.path.join("src", "llamafactory", "extras", "env.py"), "r", encoding="utf-8") as f: with open(os.path.join("src", "llamafactory", "extras", "env.py"), encoding="utf-8") as f:
file_content = f.read() file_content = f.read()
pattern = r"{}\W*=\W*\"([^\"]+)\"".format("VERSION") pattern = r"{}\W*=\W*\"([^\"]+)\"".format("VERSION")
(version,) = re.findall(pattern, file_content) (version,) = re.findall(pattern, file_content)
return version return version
def get_requires(): def get_requires() -> List[str]:
with open("requirements.txt", "r", encoding="utf-8") as f: with open("requirements.txt", encoding="utf-8") as f:
file_content = f.read() file_content = f.read()
lines = [line.strip() for line in file_content.strip().split("\n") if not line.startswith("#")] lines = [line.strip() for line in file_content.strip().split("\n") if not line.startswith("#")]
return lines return lines
def get_console_scripts() -> List[str]:
console_scripts = ["llamafactory-cli = llamafactory.cli:main"]
if os.environ.get("ENABLE_SHORT_CONSOLE", "1").lower() in ["true", "1"]:
console_scripts.append("lmf = llamafactory.cli:main")
return console_scripts
extra_require = { extra_require = {
"torch": ["torch>=1.13.1"], "torch": ["torch>=1.13.1"],
"torch-npu": ["torch==2.1.0", "torch-npu==2.1.0.post3", "decorator"], "torch-npu": ["torch==2.1.0", "torch-npu==2.1.0.post3", "decorator"],
"metrics": ["nltk", "jieba", "rouge-chinese"], "metrics": ["nltk", "jieba", "rouge-chinese"],
"deepspeed": ["deepspeed>=0.10.0"], "deepspeed": ["deepspeed>=0.10.0,<=0.14.4"],
"liger-kernel": ["liger-kernel"],
"bitsandbytes": ["bitsandbytes>=0.39.0"], "bitsandbytes": ["bitsandbytes>=0.39.0"],
"hqq": ["hqq"], "hqq": ["hqq"],
"eetq": ["eetq"], "eetq": ["eetq"],
"gptq": ["optimum>=1.17.0", "auto-gptq>=0.5.0"], "gptq": ["optimum>=1.17.0", "auto-gptq>=0.5.0"],
"awq": ["autoawq"], "awq": ["autoawq"],
"aqlm": ["aqlm[gpu]>=1.1.0"], "aqlm": ["aqlm[gpu]>=1.1.0"],
"vllm": ["vllm>=0.4.3"], "vllm": ["vllm>=0.4.3,<0.6.4"],
"galore": ["galore-torch"], "galore": ["galore-torch"],
"badam": ["badam>=1.2.1"], "badam": ["badam>=1.2.1"],
"adam-mini": ["adam-mini"],
"qwen": ["transformers_stream_generator"], "qwen": ["transformers_stream_generator"],
"modelscope": ["modelscope"], "modelscope": ["modelscope"],
"dev": ["ruff", "pytest"], "openmind": ["openmind"],
"dev": ["pre-commit", "ruff", "pytest"],
} }
@@ -60,7 +72,7 @@ def main():
author="hiyouga", author="hiyouga",
author_email="hiyouga" "@" "buaa.edu.cn", author_email="hiyouga" "@" "buaa.edu.cn",
description="Easy-to-use LLM fine-tuning framework", description="Easy-to-use LLM fine-tuning framework",
long_description=open("README.md", "r", encoding="utf-8").read(), long_description=open("README.md", encoding="utf-8").read(),
long_description_content_type="text/markdown", long_description_content_type="text/markdown",
keywords=["LLaMA", "BLOOM", "Falcon", "LLM", "ChatGPT", "transformer", "pytorch", "deep learning"], keywords=["LLaMA", "BLOOM", "Falcon", "LLM", "ChatGPT", "transformer", "pytorch", "deep learning"],
license="Apache 2.0 License", license="Apache 2.0 License",
@@ -70,7 +82,7 @@ def main():
python_requires=">=3.8.0", python_requires=">=3.8.0",
install_requires=get_requires(), install_requires=get_requires(),
extras_require=extra_require, extras_require=extra_require,
entry_points={"console_scripts": ["llamafactory-cli = llamafactory.cli:main"]}, entry_points={"console_scripts": get_console_scripts()},
classifiers=[ classifiers=[
"Development Status :: 4 - Beta", "Development Status :: 4 - Beta",
"Intended Audience :: Developers", "Intended Audience :: Developers",

View File

@@ -23,9 +23,9 @@ from llamafactory.chat import ChatModel
def main(): def main():
chat_model = ChatModel() chat_model = ChatModel()
app = create_app(chat_model) app = create_app(chat_model)
api_host = os.environ.get("API_HOST", "0.0.0.0") api_host = os.getenv("API_HOST", "0.0.0.0")
api_port = int(os.environ.get("API_PORT", "8000")) api_port = int(os.getenv("API_PORT", "8000"))
print("Visit http://localhost:{}/docs for API document.".format(api_port)) print(f"Visit http://localhost:{api_port}/docs for API document.")
uvicorn.run(app, host=api_host, port=api_port) uvicorn.run(app, host=api_host, port=api_port)

View File

@@ -20,22 +20,28 @@ Level:
Dependency graph:
main:
- transformers>=4.41.2
+ transformers>=4.41.2,<=4.46.1
- datasets>=2.16.0
+ datasets>=2.16.0,<=3.1.0
- accelerate>=0.30.1
+ accelerate>=0.34.0,<=1.0.1
- peft>=0.11.1
+ peft>=0.11.1,<=0.12.0
- trl>=0.8.6
+ trl>=0.8.6,<=0.9.6
attention:
transformers>=4.42.4 (gemma+fa2)
longlora:
- transformers>=4.41.2,<=4.42.4
+ transformers>=4.41.2,<=4.46.1
packing:
- transformers>=4.41.2,<=4.42.4
+ transformers>=4.41.2,<=4.46.1
- patcher:
- transformers==4.41.2 (chatglm)
+ Disable version checking: DISABLE_VERSION_CHECK=1
+ Enable VRAM recording: RECORD_VRAM=1
+ Force check imports: FORCE_CHECK_IMPORTS=1
+ Force using torchrun: FORCE_TORCHRUN=1
+ Set logging verbosity: LLAMAFACTORY_VERBOSITY=WARN
+ Use modelscope: USE_MODELSCOPE_HUB=1
+ Use openmind: USE_OPENMIND_HUB=1
"""
- from .cli import VERSION
+ from .extras.env import VERSION
__version__ = VERSION

View File

@@ -12,8 +12,10 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import asyncio
import os import os
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
from functools import partial
from typing import Optional from typing import Optional
from typing_extensions import Annotated from typing_extensions import Annotated
@@ -50,14 +52,24 @@ if is_uvicorn_available():
import uvicorn import uvicorn
async def sweeper() -> None:
while True:
torch_gc()
await asyncio.sleep(300)
@asynccontextmanager @asynccontextmanager
async def lifespan(app: "FastAPI"): # collects GPU memory async def lifespan(app: "FastAPI", chat_model: "ChatModel"): # collects GPU memory
if chat_model.engine_type == "huggingface":
asyncio.create_task(sweeper())
yield yield
torch_gc() torch_gc()
def create_app(chat_model: "ChatModel") -> "FastAPI": def create_app(chat_model: "ChatModel") -> "FastAPI":
app = FastAPI(lifespan=lifespan) root_path = os.getenv("FASTAPI_ROOT_PATH", "")
app = FastAPI(lifespan=partial(lifespan, chat_model=chat_model), root_path=root_path)
app.add_middleware( app.add_middleware(
CORSMiddleware, CORSMiddleware,
allow_origins=["*"], allow_origins=["*"],
@@ -65,7 +77,7 @@ def create_app(chat_model: "ChatModel") -> "FastAPI":
allow_methods=["*"], allow_methods=["*"],
allow_headers=["*"], allow_headers=["*"],
) )
api_key = os.environ.get("API_KEY") api_key = os.getenv("API_KEY")
security = HTTPBearer(auto_error=False) security = HTTPBearer(auto_error=False)
async def verify_api_key(auth: Annotated[Optional[HTTPAuthorizationCredentials], Depends(security)]): async def verify_api_key(auth: Annotated[Optional[HTTPAuthorizationCredentials], Depends(security)]):
@@ -79,7 +91,7 @@ def create_app(chat_model: "ChatModel") -> "FastAPI":
dependencies=[Depends(verify_api_key)], dependencies=[Depends(verify_api_key)],
) )
async def list_models(): async def list_models():
model_card = ModelCard(id="gpt-3.5-turbo") model_card = ModelCard(id=os.getenv("API_MODEL_NAME", "gpt-3.5-turbo"))
return ModelList(data=[model_card]) return ModelList(data=[model_card])
@app.post( @app.post(
@@ -116,7 +128,7 @@ def create_app(chat_model: "ChatModel") -> "FastAPI":
def run_api() -> None: def run_api() -> None:
chat_model = ChatModel() chat_model = ChatModel()
app = create_app(chat_model) app = create_app(chat_model)
api_host = os.environ.get("API_HOST", "0.0.0.0") api_host = os.getenv("API_HOST", "0.0.0.0")
api_port = int(os.environ.get("API_PORT", "8000")) api_port = int(os.getenv("API_PORT", "8000"))
print("Visit http://localhost:{}/docs for API document.".format(api_port)) print(f"Visit http://localhost:{api_port}/docs for API document.")
uvicorn.run(app, host=api_host, port=api_port) uvicorn.run(app, host=api_host, port=api_port)
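The new FASTAPI_ROOT_PATH and API_MODEL_NAME settings (plus the existing API_HOST/API_PORT/API_KEY) are read from the environment at startup; a minimal sketch of configuring them before launching the server (values are placeholders):

# Configure the API process before calling run_api() / `llamafactory-cli api`.
import os

os.environ["API_HOST"] = "0.0.0.0"
os.environ["API_PORT"] = "8000"
os.environ["API_MODEL_NAME"] = "qwen2_vl-7b-sft"  # name reported by GET /v1/models
os.environ["FASTAPI_ROOT_PATH"] = "/llm"          # useful behind a reverse proxy
os.environ["API_KEY"] = "sk-example"              # optional bearer-token check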

View File

@@ -16,11 +16,12 @@ import base64
import io import io
import json import json
import os import os
import re
import uuid import uuid
from typing import TYPE_CHECKING, AsyncGenerator, Dict, List, Optional, Tuple from typing import TYPE_CHECKING, AsyncGenerator, Dict, List, Optional, Tuple
from ..data import Role as DataRole from ..data import Role as DataRole
from ..extras.logging import get_logger from ..extras import logging
from ..extras.packages import is_fastapi_available, is_pillow_available, is_requests_available from ..extras.packages import is_fastapi_available, is_pillow_available, is_requests_available
from .common import dictify, jsonify from .common import dictify, jsonify
from .protocol import ( from .protocol import (
@@ -51,13 +52,12 @@ if is_requests_available():
if TYPE_CHECKING: if TYPE_CHECKING:
from numpy.typing import NDArray
from ..chat import ChatModel from ..chat import ChatModel
from ..data.mm_plugin import ImageInput
from .protocol import ChatCompletionRequest, ScoreEvaluationRequest from .protocol import ChatCompletionRequest, ScoreEvaluationRequest
logger = get_logger(__name__) logger = logging.get_logger(__name__)
ROLE_MAPPING = { ROLE_MAPPING = {
Role.USER: DataRole.USER.value, Role.USER: DataRole.USER.value,
Role.ASSISTANT: DataRole.ASSISTANT.value, Role.ASSISTANT: DataRole.ASSISTANT.value,
@@ -69,8 +69,8 @@ ROLE_MAPPING = {
def _process_request( def _process_request(
request: "ChatCompletionRequest", request: "ChatCompletionRequest",
-) -> Tuple[List[Dict[str, str]], Optional[str], Optional[str], Optional["NDArray"]]:
+) -> Tuple[List[Dict[str, str]], Optional[str], Optional[str], Optional[List["ImageInput"]]]:
-logger.info("==== request ====\n{}".format(json.dumps(dictify(request), indent=2, ensure_ascii=False)))
+logger.info_rank0(f"==== request ====\n{json.dumps(dictify(request), indent=2, ensure_ascii=False)}")
 if len(request.messages) == 0:
 raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid length")
@@ -84,7 +84,7 @@ def _process_request(
 raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Only supports u/a/u/a/u...")
 input_messages = []
-image = None
+images = []
 for i, message in enumerate(request.messages):
 if i % 2 == 0 and message.role not in [Role.USER, Role.TOOL]:
 raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid role")
@@ -104,15 +104,14 @@ def _process_request(
 input_messages.append({"role": ROLE_MAPPING[message.role], "content": input_item.text})
 else:
 image_url = input_item.image_url.url
-if image_url.startswith("data:image"): # base64 image
+if re.match(r"^data:image\/(png|jpg|jpeg|gif|bmp);base64,(.+)$", image_url): # base64 image
-image_data = base64.b64decode(image_url.split(",", maxsplit=1)[1])
-image_path = io.BytesIO(image_data)
+image_stream = io.BytesIO(base64.b64decode(image_url.split(",", maxsplit=1)[1]))
 elif os.path.isfile(image_url): # local file
-image_path = open(image_url, "rb")
+image_stream = open(image_url, "rb")
 else: # web uri
-image_path = requests.get(image_url, stream=True).raw
+image_stream = requests.get(image_url, stream=True).raw
-image = Image.open(image_path).convert("RGB")
+images.append(Image.open(image_stream).convert("RGB"))
 else:
 input_messages.append({"role": ROLE_MAPPING[message.role], "content": message.content})
@@ -125,7 +124,7 @@ def _process_request(
 else:
 tools = None
-return input_messages, system, tools, image
+return input_messages, system, tools, images or None
 def _create_stream_chat_completion_chunk(
@@ -143,13 +142,13 @@ def _create_stream_chat_completion_chunk(
 async def create_chat_completion_response(
 request: "ChatCompletionRequest", chat_model: "ChatModel"
 ) -> "ChatCompletionResponse":
-completion_id = "chatcmpl-{}".format(uuid.uuid4().hex)
+completion_id = f"chatcmpl-{uuid.uuid4().hex}"
-input_messages, system, tools, image = _process_request(request)
+input_messages, system, tools, images = _process_request(request)
 responses = await chat_model.achat(
 input_messages,
 system,
 tools,
-image,
+images,
 do_sample=request.do_sample,
 temperature=request.temperature,
 top_p=request.top_p,
@@ -170,7 +169,7 @@ async def create_chat_completion_response(
 tool_calls = []
 for tool in result:
 function = Function(name=tool[0], arguments=tool[1])
-tool_calls.append(FunctionCall(id="call_{}".format(uuid.uuid4().hex), function=function))
+tool_calls.append(FunctionCall(id=f"call_{uuid.uuid4().hex}", function=function))
 response_message = ChatCompletionMessage(role=Role.ASSISTANT, tool_calls=tool_calls)
 finish_reason = Finish.TOOL
@@ -194,8 +193,8 @@ async def create_chat_completion_response(
 async def create_stream_chat_completion_response(
 request: "ChatCompletionRequest", chat_model: "ChatModel"
 ) -> AsyncGenerator[str, None]:
-completion_id = "chatcmpl-{}".format(uuid.uuid4().hex)
+completion_id = f"chatcmpl-{uuid.uuid4().hex}"
-input_messages, system, tools, image = _process_request(request)
+input_messages, system, tools, images = _process_request(request)
 if tools:
 raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Cannot stream function calls.")
@@ -209,7 +208,7 @@ async def create_stream_chat_completion_response(
 input_messages,
 system,
 tools,
-image,
+images,
 do_sample=request.do_sample,
 temperature=request.temperature,
 top_p=request.top_p,
@@ -230,8 +229,9 @@ async def create_stream_chat_completion_response(
 async def create_score_evaluation_response(
 request: "ScoreEvaluationRequest", chat_model: "ChatModel"
 ) -> "ScoreEvaluationResponse":
+score_id = f"scoreval-{uuid.uuid4().hex}"
 if len(request.messages) == 0:
 raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid request")
 scores = await chat_model.aget_scores(request.messages, max_length=request.max_length)
-return ScoreEvaluationResponse(model=request.model, scores=scores)
+return ScoreEvaluationResponse(id=score_id, model=request.model, scores=scores)
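Note on the hunks above: the loose startswith("data:image") test becomes an anchored regex, and every image in the request is collected into a list instead of keeping only the last one (None is returned when no image is present). A minimal, self-contained sketch of a conforming data URL and the same decoding path, assuming Pillow is installed and example.png is a hypothetical local file:

    import base64
    import io
    import re

    from PIL import Image

    # Encode a local file as a data URL, the form accepted by an OpenAI-style image_url item.
    with open("example.png", "rb") as f:  # hypothetical input file
        image_url = "data:image/png;base64," + base64.b64encode(f.read()).decode("utf-8")

    # Same pattern as the new check above.
    if re.match(r"^data:image\/(png|jpg|jpeg|gif|bmp);base64,(.+)$", image_url):
        image_stream = io.BytesIO(base64.b64decode(image_url.split(",", maxsplit=1)[1]))
        images = [Image.open(image_stream).convert("RGB")]  # one entry of the returned list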

View File

@@ -18,11 +18,11 @@ from typing import TYPE_CHECKING, Any, AsyncGenerator, Dict, List, Literal, Opti
 if TYPE_CHECKING:
-from numpy.typing import NDArray
 from transformers import PreTrainedModel, PreTrainedTokenizer
 from vllm import AsyncLLMEngine
 from ..data import Template
+from ..data.mm_plugin import ImageInput, VideoInput
 from ..hparams import DataArguments, FinetuningArguments, GeneratingArguments, ModelArguments
@@ -35,6 +35,12 @@ class Response:
 class BaseEngine(ABC):
+r"""
+Base class for inference engine of chat models.
+Must implements async methods: chat(), stream_chat() and get_scores().
+"""
 model: Union["PreTrainedModel", "AsyncLLMEngine"]
 tokenizer: "PreTrainedTokenizer"
 can_generate: bool
@@ -48,7 +54,11 @@ class BaseEngine(ABC):
 data_args: "DataArguments",
 finetuning_args: "FinetuningArguments",
 generating_args: "GeneratingArguments",
-) -> None: ...
+) -> None:
+r"""
+Initializes an inference engine.
+"""
+...
 @abstractmethod
 async def chat(
@@ -56,9 +66,14 @@ class BaseEngine(ABC):
 messages: Sequence[Dict[str, str]],
 system: Optional[str] = None,
 tools: Optional[str] = None,
-image: Optional["NDArray"] = None,
+images: Optional[Sequence["ImageInput"]] = None,
+videos: Optional[Sequence["VideoInput"]] = None,
 **input_kwargs,
-) -> List["Response"]: ...
+) -> List["Response"]:
+r"""
+Gets a list of responses of the chat model.
+"""
+...
 @abstractmethod
 async def stream_chat(
@@ -66,13 +81,22 @@ class BaseEngine(ABC):
 messages: Sequence[Dict[str, str]],
 system: Optional[str] = None,
 tools: Optional[str] = None,
-image: Optional["NDArray"] = None,
+images: Optional[Sequence["ImageInput"]] = None,
+videos: Optional[Sequence["VideoInput"]] = None,
 **input_kwargs,
-) -> AsyncGenerator[str, None]: ...
+) -> AsyncGenerator[str, None]:
+r"""
+Gets the response token-by-token of the chat model.
+"""
+...
 @abstractmethod
 async def get_scores(
 self,
 batch_input: List[str],
 **input_kwargs,
-) -> List[float]: ...
+) -> List[float]:
+r"""
+Gets a list of scores of the reward model.
+"""
+...

View File

@@ -27,8 +27,7 @@ from .vllm_engine import VllmEngine
 if TYPE_CHECKING:
-from numpy.typing import NDArray
+from ..data.mm_plugin import ImageInput, VideoInput
 from .base_engine import BaseEngine, Response
@@ -38,14 +37,23 @@ def _start_background_loop(loop: "asyncio.AbstractEventLoop") -> None:
 class ChatModel:
+r"""
+General class for chat models. Backed by huggingface or vllm engines.
+Supports both sync and async methods.
+Sync methods: chat(), stream_chat() and get_scores().
+Async methods: achat(), astream_chat() and aget_scores().
+"""
 def __init__(self, args: Optional[Dict[str, Any]] = None) -> None:
 model_args, data_args, finetuning_args, generating_args = get_infer_args(args)
+self.engine_type = model_args.infer_backend
 if model_args.infer_backend == "huggingface":
 self.engine: "BaseEngine" = HuggingfaceEngine(model_args, data_args, finetuning_args, generating_args)
 elif model_args.infer_backend == "vllm":
 self.engine: "BaseEngine" = VllmEngine(model_args, data_args, finetuning_args, generating_args)
 else:
-raise NotImplementedError("Unknown backend: {}".format(model_args.infer_backend))
+raise NotImplementedError(f"Unknown backend: {model_args.infer_backend}")
 self._loop = asyncio.new_event_loop()
 self._thread = Thread(target=_start_background_loop, args=(self._loop,), daemon=True)
@@ -56,10 +64,16 @@ class ChatModel:
 messages: Sequence[Dict[str, str]],
 system: Optional[str] = None,
 tools: Optional[str] = None,
-image: Optional["NDArray"] = None,
+images: Optional[Sequence["ImageInput"]] = None,
+videos: Optional[Sequence["VideoInput"]] = None,
 **input_kwargs,
 ) -> List["Response"]:
-task = asyncio.run_coroutine_threadsafe(self.achat(messages, system, tools, image, **input_kwargs), self._loop)
+r"""
+Gets a list of responses of the chat model.
+"""
+task = asyncio.run_coroutine_threadsafe(
+self.achat(messages, system, tools, images, videos, **input_kwargs), self._loop
+)
 return task.result()
 async def achat(
@@ -67,20 +81,28 @@ class ChatModel:
 messages: Sequence[Dict[str, str]],
 system: Optional[str] = None,
 tools: Optional[str] = None,
-image: Optional["NDArray"] = None,
+images: Optional[Sequence["ImageInput"]] = None,
+videos: Optional[Sequence["VideoInput"]] = None,
 **input_kwargs,
 ) -> List["Response"]:
-return await self.engine.chat(messages, system, tools, image, **input_kwargs)
+r"""
+Asynchronously gets a list of responses of the chat model.
+"""
+return await self.engine.chat(messages, system, tools, images, videos, **input_kwargs)
 def stream_chat(
 self,
 messages: Sequence[Dict[str, str]],
 system: Optional[str] = None,
 tools: Optional[str] = None,
-image: Optional["NDArray"] = None,
+images: Optional[Sequence["ImageInput"]] = None,
+videos: Optional[Sequence["VideoInput"]] = None,
 **input_kwargs,
 ) -> Generator[str, None, None]:
-generator = self.astream_chat(messages, system, tools, image, **input_kwargs)
+r"""
+Gets the response token-by-token of the chat model.
+"""
+generator = self.astream_chat(messages, system, tools, images, videos, **input_kwargs)
 while True:
 try:
 task = asyncio.run_coroutine_threadsafe(generator.__anext__(), self._loop)
@@ -93,10 +115,14 @@ class ChatModel:
 messages: Sequence[Dict[str, str]],
 system: Optional[str] = None,
 tools: Optional[str] = None,
-image: Optional["NDArray"] = None,
+images: Optional[Sequence["ImageInput"]] = None,
+videos: Optional[Sequence["VideoInput"]] = None,
 **input_kwargs,
 ) -> AsyncGenerator[str, None]:
-async for new_token in self.engine.stream_chat(messages, system, tools, image, **input_kwargs):
+r"""
+Asynchronously gets the response token-by-token of the chat model.
+"""
+async for new_token in self.engine.stream_chat(messages, system, tools, images, videos, **input_kwargs):
 yield new_token
 def get_scores(
@@ -104,6 +130,9 @@ class ChatModel:
 batch_input: List[str],
 **input_kwargs,
 ) -> List[float]:
+r"""
+Gets a list of scores of the reward model.
+"""
 task = asyncio.run_coroutine_threadsafe(self.aget_scores(batch_input, **input_kwargs), self._loop)
 return task.result()
@@ -112,6 +141,9 @@ class ChatModel:
 batch_input: List[str],
 **input_kwargs,
 ) -> List[float]:
+r"""
+Asynchronously gets a list of scores of the reward model.
+"""
 return await self.engine.get_scores(batch_input, **input_kwargs)
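With ChatModel now taking images and videos as sequences, a rough usage sketch of the synchronous API; the model path, template name and file path below are placeholders for illustration, not values taken from the diff:

    from llamafactory.chat import ChatModel

    # Hypothetical arguments; any vision-language checkpoint with a matching template works.
    chat_model = ChatModel({"model_name_or_path": "Qwen/Qwen2-VL-7B-Instruct", "template": "qwen2_vl"})

    messages = [{"role": "user", "content": "<image>Describe the picture."}]
    responses = chat_model.chat(messages, images=["/path/to/picture.jpg"])  # images is a sequence now
    print(responses[0].response_text)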

View File

@@ -20,25 +20,26 @@ from typing import TYPE_CHECKING, Any, AsyncGenerator, Callable, Dict, List, Opt
 import torch
 from transformers import GenerationConfig, TextIteratorStreamer
+from typing_extensions import override
 from ..data import get_template_and_fix_tokenizer
-from ..extras.logging import get_logger
+from ..extras import logging
+from ..extras.constants import IMAGE_PLACEHOLDER, VIDEO_PLACEHOLDER
 from ..extras.misc import get_logits_processor
 from ..model import load_model, load_tokenizer
 from .base_engine import BaseEngine, Response
 if TYPE_CHECKING:
-from numpy.typing import NDArray
 from transformers import PreTrainedModel, PreTrainedTokenizer, ProcessorMixin
-from transformers.image_processing_utils import BaseImageProcessor
 from trl import PreTrainedModelWrapper
 from ..data import Template
+from ..data.mm_plugin import ImageInput, VideoInput
 from ..hparams import DataArguments, FinetuningArguments, GeneratingArguments, ModelArguments
-logger = get_logger(__name__)
+logger = logging.get_logger(__name__)
 class HuggingfaceEngine(BaseEngine):
@@ -54,7 +55,7 @@ class HuggingfaceEngine(BaseEngine):
 self.tokenizer = tokenizer_module["tokenizer"]
 self.processor = tokenizer_module["processor"]
 self.tokenizer.padding_side = "left" if self.can_generate else "right"
-self.template = get_template_and_fix_tokenizer(self.tokenizer, data_args.template, data_args.tool_format)
+self.template = get_template_and_fix_tokenizer(self.tokenizer, data_args)
 self.model = load_model(
 self.tokenizer, model_args, finetuning_args, is_trainable=False, add_valuehead=(not self.can_generate)
 ) # must after fixing tokenizer to resize vocab
@@ -62,11 +63,11 @@ class HuggingfaceEngine(BaseEngine):
 try:
 asyncio.get_event_loop()
 except RuntimeError:
-logger.warning("There is no current event loop, creating a new one.")
+logger.warning_once("There is no current event loop, creating a new one.")
 loop = asyncio.new_event_loop()
 asyncio.set_event_loop(loop)
-self.semaphore = asyncio.Semaphore(int(os.environ.get("MAX_CONCURRENT", "1")))
+self.semaphore = asyncio.Semaphore(int(os.getenv("MAX_CONCURRENT", "1")))
 @staticmethod
 def _process_args(
@@ -78,31 +79,30 @@ class HuggingfaceEngine(BaseEngine):
 messages: Sequence[Dict[str, str]],
 system: Optional[str] = None,
 tools: Optional[str] = None,
-image: Optional["NDArray"] = None,
+images: Optional[Sequence["ImageInput"]] = None,
+videos: Optional[Sequence["VideoInput"]] = None,
 input_kwargs: Optional[Dict[str, Any]] = {},
 ) -> Tuple[Dict[str, Any], int]:
-if (
-processor is not None
-and image is not None
-and not hasattr(processor, "image_seq_length")
-and template.image_token not in messages[0]["content"]
-): # llava-like models
-messages[0]["content"] = template.image_token + messages[0]["content"]
+mm_input_dict = {"images": [], "videos": [], "imglens": [0], "vidlens": [0]}
+if images is not None:
+mm_input_dict.update({"images": images, "imglens": [len(images)]})
+if not any(IMAGE_PLACEHOLDER in message["content"] for message in messages):
+messages[0]["content"] = IMAGE_PLACEHOLDER * len(images) + messages[0]["content"]
+if videos is not None:
+mm_input_dict.update({"videos": videos, "vidlens": [len(videos)]})
+if not any(VIDEO_PLACEHOLDER in message["content"] for message in messages):
+messages[0]["content"] = VIDEO_PLACEHOLDER * len(videos) + messages[0]["content"]
+messages = template.mm_plugin.process_messages(
+messages, mm_input_dict["images"], mm_input_dict["videos"], processor
+)
 paired_messages = messages + [{"role": "assistant", "content": ""}]
 system = system or generating_args["default_system"]
-pixel_values = None
-prompt_ids, _ = template.encode_oneturn(
-tokenizer=tokenizer, messages=paired_messages, system=system, tools=tools
+prompt_ids, _ = template.encode_oneturn(tokenizer, paired_messages, system, tools)
+prompt_ids, _ = template.mm_plugin.process_token_ids(
+prompt_ids, None, mm_input_dict["images"], mm_input_dict["videos"], tokenizer, processor
 )
-if processor is not None and image is not None: # add image features
-image_processor: "BaseImageProcessor" = getattr(processor, "image_processor")
-batch_feature = image_processor(image, return_tensors="pt")
-pixel_values = batch_feature.to(model.device)["pixel_values"] # shape (B, C, H, W)
-if hasattr(processor, "image_seq_length"): # paligemma models
-image_token_id = tokenizer.convert_tokens_to_ids(template.image_token)
-prompt_ids = [image_token_id] * getattr(processor, "image_seq_length") + prompt_ids
 prompt_length = len(prompt_ids)
 inputs = torch.tensor([prompt_ids], device=model.device)
 attention_mask = torch.ones_like(inputs, dtype=torch.bool)
@@ -119,7 +119,7 @@ class HuggingfaceEngine(BaseEngine):
 stop: Optional[Union[str, List[str]]] = input_kwargs.pop("stop", None)
 if stop is not None:
-logger.warning("Stop parameter is not supported by the huggingface engine yet.")
+logger.warning_rank0("Stop parameter is not supported by the huggingface engine yet.")
 generating_args = generating_args.copy()
 generating_args.update(
@@ -164,8 +164,14 @@ class HuggingfaceEngine(BaseEngine):
 logits_processor=get_logits_processor(),
 )
-if pixel_values is not None:
-gen_kwargs["pixel_values"] = pixel_values
+mm_inputs = template.mm_plugin.get_mm_inputs(**mm_input_dict, batch_ids=[prompt_ids], processor=processor)
+for key, value in mm_inputs.items():
+if isinstance(value, list) and all(isinstance(v, torch.Tensor) for v in value): # for pixtral inputs
+value = torch.stack(value) # assume they have same sizes
+elif not isinstance(value, torch.Tensor):
+value = torch.tensor(value)
+gen_kwargs[key] = value.to(model.device)
 return gen_kwargs, prompt_length
@@ -180,11 +186,22 @@ class HuggingfaceEngine(BaseEngine):
 messages: Sequence[Dict[str, str]],
 system: Optional[str] = None,
 tools: Optional[str] = None,
-image: Optional["NDArray"] = None,
+images: Optional[Sequence["ImageInput"]] = None,
+videos: Optional[Sequence["VideoInput"]] = None,
 input_kwargs: Optional[Dict[str, Any]] = {},
 ) -> List["Response"]:
 gen_kwargs, prompt_length = HuggingfaceEngine._process_args(
-model, tokenizer, processor, template, generating_args, messages, system, tools, image, input_kwargs
+model,
+tokenizer,
+processor,
+template,
+generating_args,
+messages,
+system,
+tools,
+images,
+videos,
+input_kwargs,
 )
 generate_output = model.generate(**gen_kwargs)
 response_ids = generate_output[:, prompt_length:]
@@ -215,11 +232,22 @@ class HuggingfaceEngine(BaseEngine):
 messages: Sequence[Dict[str, str]],
 system: Optional[str] = None,
 tools: Optional[str] = None,
-image: Optional["NDArray"] = None,
+images: Optional[Sequence["ImageInput"]] = None,
+videos: Optional[Sequence["VideoInput"]] = None,
 input_kwargs: Optional[Dict[str, Any]] = {},
 ) -> Callable[[], str]:
 gen_kwargs, _ = HuggingfaceEngine._process_args(
-model, tokenizer, processor, template, generating_args, messages, system, tools, image, input_kwargs
+model,
+tokenizer,
+processor,
+template,
+generating_args,
+messages,
+system,
+tools,
+images,
+videos,
+input_kwargs,
 )
 streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
 gen_kwargs["streamer"] = streamer
@@ -242,37 +270,28 @@ class HuggingfaceEngine(BaseEngine):
 batch_input: List[str],
 input_kwargs: Optional[Dict[str, Any]] = {},
 ) -> List[float]:
-max_length = input_kwargs.pop("max_length", None)
+max_length: Optional[int] = input_kwargs.pop("max_length", None)
 device = getattr(model.pretrained_model, "device", "cuda")
-inputs = tokenizer(
+inputs: Dict[str, "torch.Tensor"] = tokenizer(
 batch_input,
 padding=True,
 truncation=True,
 max_length=max_length or getattr(model.config, "max_position_embeddings", 1024),
 return_tensors="pt",
-add_special_tokens=True,
+add_special_tokens=False,
 ).to(device)
-input_ids: torch.Tensor = inputs["input_ids"]
-_, _, values = model(**inputs, output_hidden_states=True, return_dict=True)
-if getattr(model.config, "model_type", None) == "chatglm":
-values = torch.transpose(values, 0, 1)
-scores = []
-for i in range(input_ids.size(0)):
-end_indexes = (input_ids[i] != tokenizer.pad_token_id).nonzero()
-end_index = end_indexes[-1].item() if len(end_indexes) else 0
-scores.append(values[i, end_index].nan_to_num().item())
+values: "torch.Tensor" = model(**inputs, return_dict=True, use_cache=False)[-1]
+scores = values.gather(dim=-1, index=(inputs["attention_mask"].sum(dim=-1, keepdim=True) - 1))
 return scores
+@override
 async def chat(
 self,
 messages: Sequence[Dict[str, str]],
 system: Optional[str] = None,
 tools: Optional[str] = None,
-image: Optional["NDArray"] = None,
+images: Optional[Sequence["ImageInput"]] = None,
+videos: Optional[Sequence["VideoInput"]] = None,
 **input_kwargs,
 ) -> List["Response"]:
 if not self.can_generate:
@@ -288,19 +307,22 @@ class HuggingfaceEngine(BaseEngine):
 messages,
 system,
 tools,
-image,
+images,
+videos,
 input_kwargs,
 )
 async with self.semaphore:
 with concurrent.futures.ThreadPoolExecutor() as pool:
 return await loop.run_in_executor(pool, self._chat, *input_args)
+@override
 async def stream_chat(
 self,
 messages: Sequence[Dict[str, str]],
 system: Optional[str] = None,
 tools: Optional[str] = None,
-image: Optional["NDArray"] = None,
+images: Optional[Sequence["ImageInput"]] = None,
+videos: Optional[Sequence["VideoInput"]] = None,
 **input_kwargs,
 ) -> AsyncGenerator[str, None]:
 if not self.can_generate:
@@ -316,7 +338,8 @@ class HuggingfaceEngine(BaseEngine):
 messages,
 system,
 tools,
-image,
+images,
+videos,
 input_kwargs,
 )
 async with self.semaphore:
@@ -328,6 +351,7 @@ class HuggingfaceEngine(BaseEngine):
 except StopAsyncIteration:
 break
+@override
 async def get_scores(
 self,
 batch_input: List[str],
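The rewritten reward scoring above drops the per-example Python loop and reads the value at the last attended position with a single gather. A self-contained sketch of that indexing trick on a toy (batch, seq_len) tensor, assuming right padding as the engine sets for non-generating models:

    import torch

    values = torch.tensor([[0.1, 0.2, 0.9], [0.3, 0.8, 0.0]])  # per-token value-head outputs
    attention_mask = torch.tensor([[1, 1, 1], [1, 1, 0]])      # second row has one pad token at the end

    last_index = attention_mask.sum(dim=-1, keepdim=True) - 1  # [[2], [1]]
    scores = values.gather(dim=-1, index=last_index)           # [[0.9000], [0.8000]]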

View File

@@ -15,36 +15,35 @@
 import uuid
 from typing import TYPE_CHECKING, Any, AsyncGenerator, AsyncIterator, Dict, List, Optional, Sequence, Union
+from typing_extensions import override
 from ..data import get_template_and_fix_tokenizer
-from ..extras.logging import get_logger
+from ..extras import logging
+from ..extras.constants import IMAGE_PLACEHOLDER
 from ..extras.misc import get_device_count
-from ..extras.packages import is_vllm_available, is_vllm_version_greater_than_0_5, is_vllm_version_greater_than_0_5_1
+from ..extras.packages import is_pillow_available, is_vllm_available
 from ..model import load_config, load_tokenizer
 from ..model.model_utils.quantization import QuantizationMethod
 from ..model.model_utils.visual import LlavaMultiModalProjectorForYiVLForVLLM
 from .base_engine import BaseEngine, Response
+if is_pillow_available():
+from PIL import Image
+from PIL.Image import Image as ImageObject
 if is_vllm_available():
 from vllm import AsyncEngineArgs, AsyncLLMEngine, RequestOutput, SamplingParams
 from vllm.lora.request import LoRARequest
-if is_vllm_version_greater_than_0_5_1():
-pass
-elif is_vllm_version_greater_than_0_5():
-from vllm.multimodal.image import ImagePixelData
-else:
-from vllm.sequence import MultiModalData
 if TYPE_CHECKING:
-from numpy.typing import NDArray
-from transformers.image_processing_utils import BaseImageProcessor
+from ..data.mm_plugin import ImageInput, VideoInput
 from ..hparams import DataArguments, FinetuningArguments, GeneratingArguments, ModelArguments
-logger = get_logger(__name__)
+logger = logging.get_logger(__name__)
 class VllmEngine(BaseEngine):
@@ -67,7 +66,7 @@ class VllmEngine(BaseEngine):
 self.tokenizer = tokenizer_module["tokenizer"]
 self.processor = tokenizer_module["processor"]
 self.tokenizer.padding_side = "left"
-self.template = get_template_and_fix_tokenizer(self.tokenizer, data_args.template, data_args.tool_format)
+self.template = get_template_and_fix_tokenizer(self.tokenizer, data_args)
 self.generating_args = generating_args.to_dict()
 engine_args = {
@@ -84,19 +83,13 @@ class VllmEngine(BaseEngine):
 "enable_lora": model_args.adapter_name_or_path is not None,
 "max_lora_rank": model_args.vllm_max_lora_rank,
 }
+if isinstance(model_args.vllm_config, dict):
+engine_args.update(model_args.vllm_config)
-if model_args.visual_inputs:
-image_size = config.vision_config.image_size
-patch_size = config.vision_config.patch_size
-self.image_feature_size = (image_size // patch_size) ** 2
-engine_args["image_input_type"] = "pixel_values"
-engine_args["image_token_id"] = self.tokenizer.convert_tokens_to_ids(self.template.image_token)
-engine_args["image_input_shape"] = "1,3,{},{}".format(image_size, image_size)
-engine_args["image_feature_size"] = self.image_feature_size
 if getattr(config, "is_yi_vl_derived_model", None):
 import vllm.model_executor.models.llava
-logger.info("Detected Yi-VL model, applying projector patch.")
+logger.info_rank0("Detected Yi-VL model, applying projector patch.")
 vllm.model_executor.models.llava.LlavaMultiModalProjector = LlavaMultiModalProjectorForYiVLForVLLM
 self.model = AsyncLLMEngine.from_engine_args(AsyncEngineArgs(**engine_args))
@@ -110,40 +103,28 @@ class VllmEngine(BaseEngine):
 messages: Sequence[Dict[str, str]],
 system: Optional[str] = None,
 tools: Optional[str] = None,
-image: Optional["NDArray"] = None,
+images: Optional[Sequence["ImageInput"]] = None,
+videos: Optional[Sequence["VideoInput"]] = None,
 **input_kwargs,
 ) -> AsyncIterator["RequestOutput"]:
-request_id = "chatcmpl-{}".format(uuid.uuid4().hex)
+request_id = f"chatcmpl-{uuid.uuid4().hex}"
+if images is not None:
+if not any(IMAGE_PLACEHOLDER in message["content"] for message in messages):
+messages[0]["content"] = IMAGE_PLACEHOLDER * len(images) + messages[0]["content"]
-if (
-self.processor is not None
-and image is not None
-and not hasattr(self.processor, "image_seq_length")
-and self.template.image_token not in messages[0]["content"]
-): # llava-like models (TODO: paligemma models)
-messages[0]["content"] = self.template.image_token * self.image_feature_size + messages[0]["content"]
-paired_messages = messages + [{"role": "assistant", "content": ""}]
-system = system or self.generating_args["default_system"]
-prompt_ids, _ = self.template.encode_oneturn(
-tokenizer=self.tokenizer, messages=paired_messages, system=system, tools=tools
-)
-if self.processor is not None and image is not None: # add image features
-image_processor: "BaseImageProcessor" = getattr(self.processor, "image_processor")
-pixel_values = image_processor(image, return_tensors="pt")["pixel_values"]
-if is_vllm_version_greater_than_0_5_1():
-multi_modal_data = {"image": pixel_values}
-elif is_vllm_version_greater_than_0_5():
-multi_modal_data = ImagePixelData(image=pixel_values)
-else: # TODO: remove vllm 0.4.3 support
-multi_modal_data = MultiModalData(type=MultiModalData.Type.IMAGE, data=pixel_values)
+if self.template.mm_plugin.__class__.__name__ == "Qwen2vlPlugin": # temporary solution
+image_str = f"<|vision_start|>{self.template.mm_plugin.image_token}<|vision_end|>"
 else:
-multi_modal_data = None
+image_str = self.template.mm_plugin.image_token or ""
+paired_messages = [
+{"role": message["role"], "content": message["content"].replace(IMAGE_PLACEHOLDER, image_str)}
+for message in messages
+] + [{"role": "assistant", "content": ""}]
+system = system or self.generating_args["default_system"]
+prompt_ids, _ = self.template.encode_oneturn(self.tokenizer, paired_messages, system, tools)
 prompt_length = len(prompt_ids)
-use_beam_search: bool = self.generating_args["num_beams"] > 1
 temperature: Optional[float] = input_kwargs.pop("temperature", None)
 top_p: Optional[float] = input_kwargs.pop("top_p", None)
 top_k: Optional[float] = input_kwargs.pop("top_k", None)
@@ -154,6 +135,9 @@ class VllmEngine(BaseEngine):
 max_new_tokens: Optional[int] = input_kwargs.pop("max_new_tokens", None)
 stop: Optional[Union[str, List[str]]] = input_kwargs.pop("stop", None)
+if length_penalty is not None:
+logger.warning_rank0("Length penalty is not supported by the vllm engine yet.")
 if "max_new_tokens" in self.generating_args:
 max_tokens = self.generating_args["max_new_tokens"]
 elif "max_length" in self.generating_args:
@@ -177,32 +161,47 @@ class VllmEngine(BaseEngine):
 temperature=temperature if temperature is not None else self.generating_args["temperature"],
 top_p=(top_p if top_p is not None else self.generating_args["top_p"]) or 1.0, # top_p must > 0
 top_k=top_k if top_k is not None else self.generating_args["top_k"],
-use_beam_search=use_beam_search,
-length_penalty=length_penalty if length_penalty is not None else self.generating_args["length_penalty"],
 stop=stop,
 stop_token_ids=[self.tokenizer.eos_token_id] + self.tokenizer.additional_special_tokens_ids,
 max_tokens=max_tokens,
 skip_special_tokens=True,
 )
+if images is not None: # add image features
+image_data = []
+for image in images:
+if not isinstance(image, (str, ImageObject)):
+raise ValueError(f"Expected image input is a path or PIL.Image, but got {type(image)}.")
+if isinstance(image, str):
+image = Image.open(image).convert("RGB")
+image_data.append(image)
+multi_modal_data = {"image": image_data}
+else:
+multi_modal_data = None
 result_generator = self.model.generate(
-inputs={"prompt_token_ids": prompt_ids, "multi_modal_data": multi_modal_data},
+{"prompt_token_ids": prompt_ids, "multi_modal_data": multi_modal_data},
 sampling_params=sampling_params,
 request_id=request_id,
 lora_request=self.lora_request,
 )
 return result_generator
+@override
 async def chat(
 self,
 messages: Sequence[Dict[str, str]],
 system: Optional[str] = None,
 tools: Optional[str] = None,
-image: Optional["NDArray"] = None,
+images: Optional[Sequence["ImageInput"]] = None,
+videos: Optional[Sequence["VideoInput"]] = None,
 **input_kwargs,
 ) -> List["Response"]:
 final_output = None
-generator = await self._generate(messages, system, tools, image, **input_kwargs)
+generator = await self._generate(messages, system, tools, images, videos, **input_kwargs)
 async for request_output in generator:
 final_output = request_output
@@ -219,21 +218,24 @@ class VllmEngine(BaseEngine):
 return results
+@override
 async def stream_chat(
 self,
 messages: Sequence[Dict[str, str]],
 system: Optional[str] = None,
 tools: Optional[str] = None,
-image: Optional["NDArray"] = None,
+images: Optional[Sequence["ImageInput"]] = None,
+videos: Optional[Sequence["VideoInput"]] = None,
 **input_kwargs,
 ) -> AsyncGenerator[str, None]:
 generated_text = ""
-generator = await self._generate(messages, system, tools, image, **input_kwargs)
+generator = await self._generate(messages, system, tools, images, videos, **input_kwargs)
 async for result in generator:
 delta_text = result.outputs[0].text[len(generated_text) :]
 generated_text = result.outputs[0].text
 yield delta_text
+@override
 async def get_scores(
 self,
 batch_input: List[str],

View File

@@ -22,8 +22,8 @@ from . import launcher
 from .api.app import run_api
 from .chat.chat_model import run_chat
 from .eval.evaluator import run_eval
+from .extras import logging
 from .extras.env import VERSION, print_env
-from .extras.logging import get_logger
 from .extras.misc import get_device_count
 from .train.tuner import export_model, run_exp
 from .webui.interface import run_web_demo, run_web_ui
@@ -47,7 +47,7 @@ USAGE = (
 WELCOME = (
 "-" * 58
 + "\n"
-+ "| Welcome to LLaMA Factory, version {}".format(VERSION)
++ f"| Welcome to LLaMA Factory, version {VERSION}"
 + " " * (21 - len(VERSION))
 + "|\n|"
 + " " * 56
@@ -56,7 +56,7 @@ WELCOME = (
 + "-" * 58
 )
-logger = get_logger(__name__)
+logger = logging.get_logger(__name__)
 @unique
@@ -86,25 +86,26 @@ def main():
 elif command == Command.EXPORT:
 export_model()
 elif command == Command.TRAIN:
-force_torchrun = os.environ.get("FORCE_TORCHRUN", "0").lower() in ["true", "1"]
+force_torchrun = os.getenv("FORCE_TORCHRUN", "0").lower() in ["true", "1"]
 if force_torchrun or get_device_count() > 1:
-master_addr = os.environ.get("MASTER_ADDR", "127.0.0.1")
+master_addr = os.getenv("MASTER_ADDR", "127.0.0.1")
-master_port = os.environ.get("MASTER_PORT", str(random.randint(20001, 29999)))
+master_port = os.getenv("MASTER_PORT", str(random.randint(20001, 29999)))
-logger.info("Initializing distributed tasks at: {}:{}".format(master_addr, master_port))
+logger.info_rank0(f"Initializing distributed tasks at: {master_addr}:{master_port}")
 process = subprocess.run(
 (
 "torchrun --nnodes {nnodes} --node_rank {node_rank} --nproc_per_node {nproc_per_node} "
 "--master_addr {master_addr} --master_port {master_port} {file_name} {args}"
-).format(
-nnodes=os.environ.get("NNODES", "1"),
-node_rank=os.environ.get("RANK", "0"),
-nproc_per_node=os.environ.get("NPROC_PER_NODE", str(get_device_count())),
+)
+.format(
+nnodes=os.getenv("NNODES", "1"),
+node_rank=os.getenv("NODE_RANK", "0"),
+nproc_per_node=os.getenv("NPROC_PER_NODE", str(get_device_count())),
 master_addr=master_addr,
 master_port=master_port,
 file_name=launcher.__file__,
 args=" ".join(sys.argv[1:]),
-),
-shell=True,
+)
+.split()
 )
 sys.exit(process.returncode)
 else:
@@ -118,4 +119,4 @@ def main():
 elif command == Command.HELP:
 print(USAGE)
 else:
-raise NotImplementedError("Unknown command: {}".format(command))
+raise NotImplementedError(f"Unknown command: {command}.")
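The training branch above now builds the torchrun command as a plain string, splits it into an argument list, and drops shell=True, which is the command-injection fix from the security patch. A short sketch of the pattern in isolation; the command and values are illustrative only, and train.py is a hypothetical script name:

    import subprocess

    master_addr, master_port = "127.0.0.1", "29500"  # hypothetical values
    cmd = (
        "torchrun --nnodes {nnodes} --node_rank {node_rank} --nproc_per_node {nproc_per_node} "
        "--master_addr {master_addr} --master_port {master_port} train.py"
    ).format(nnodes="1", node_rank="0", nproc_per_node="1", master_addr=master_addr, master_port=master_port)

    # Without a shell, metacharacters such as ";" or "$(...)" inside any field are passed to
    # torchrun as literal text instead of being executed by /bin/sh.
    process = subprocess.run(cmd.split())
    print(process.returncode)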

View File

@@ -12,7 +12,12 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from .collator import KTODataCollatorWithPadding, PairwiseDataCollatorWithPadding, SFTDataCollatorWith4DAttentionMask
+from .collator import (
+KTODataCollatorWithPadding,
+MultiModalDataCollatorForSeq2Seq,
+PairwiseDataCollatorWithPadding,
+SFTDataCollatorWith4DAttentionMask,
+)
 from .data_utils import Role, split_dataset
 from .loader import get_dataset
 from .template import TEMPLATES, Template, get_template_and_fix_tokenizer
@@ -20,6 +25,7 @@ from .template import TEMPLATES, Template, get_template_and_fix_tokenizer
 __all__ = [
 "KTODataCollatorWithPadding",
+"MultiModalDataCollatorForSeq2Seq",
 "PairwiseDataCollatorWithPadding",
 "SFTDataCollatorWith4DAttentionMask",
 "Role",

View File

@@ -14,11 +14,9 @@
 import os
 from functools import partial
-from typing import TYPE_CHECKING, Any, Dict, List, Union
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Union
-from datasets import Features
-from ..extras.logging import get_logger
+from ..extras import logging
 from .data_utils import Role
@@ -27,88 +25,123 @@ if TYPE_CHECKING:
 from transformers import Seq2SeqTrainingArguments
 from ..hparams import DataArguments
+from .mm_plugin import ImageInput, VideoInput
 from .parser import DatasetAttr
-logger = get_logger(__name__)
+logger = logging.get_logger(__name__)
-def _convert_images(images: List[Any], dataset_attr: "DatasetAttr", data_args: "DataArguments") -> List[Any]:
+def _convert_images(
+images: Union["ImageInput", Sequence["ImageInput"]],
+dataset_attr: "DatasetAttr",
+data_args: "DataArguments",
+) -> Optional[List["ImageInput"]]:
 r"""
 Optionally concatenates image path to dataset dir when loading from local disk.
 """
-outputs = []
-if dataset_attr.load_from in ["script", "file"]:
-for image in images:
-if isinstance(image, str) and os.path.isfile(os.path.join(data_args.dataset_dir, image)):
-outputs.append(os.path.join(data_args.dataset_dir, image))
+if not isinstance(images, list):
+images = [images]
+elif len(images) == 0:
+return None
 else:
-outputs.append(image)
-return outputs
+images = images[:]
+if dataset_attr.load_from in ["script", "file"]:
+for i in range(len(images)):
+if isinstance(images[i], str) and os.path.isfile(os.path.join(data_args.image_dir, images[i])):
+images[i] = os.path.join(data_args.image_dir, images[i])
+return images
+def _convert_videos(
+videos: Union["VideoInput", Sequence["VideoInput"]],
+dataset_attr: "DatasetAttr",
+data_args: "DataArguments",
+) -> Optional[List["VideoInput"]]:
+r"""
+Optionally concatenates video path to dataset dir when loading from local disk.
+"""
+if not isinstance(videos, list):
+videos = [videos]
+elif len(videos) == 0:
+return None
+else:
+videos = videos[:]
+if dataset_attr.load_from in ["script", "file"]:
+for i in range(len(videos)):
+if isinstance(videos[i], str) and os.path.isfile(os.path.join(data_args.image_dir, videos[i])):
+videos[i] = os.path.join(data_args.image_dir, videos[i])
+return videos
 def convert_alpaca(
-examples: Dict[str, List[Any]], dataset_attr: "DatasetAttr", data_args: "DataArguments"
-) -> Dict[str, List[Any]]:
+example: Dict[str, Any],
+dataset_attr: "DatasetAttr",
+data_args: "DataArguments",
+) -> Dict[str, Any]:
 r"""
 Converts alpaca format dataset to the standard format.
 """
-outputs = {"prompt": [], "response": [], "system": [], "tools": [], "images": []}
-convert_images = partial(_convert_images, dataset_attr=dataset_attr, data_args=data_args)
-for i in range(len(examples[dataset_attr.prompt])):
 prompt = []
-if dataset_attr.history and isinstance(examples[dataset_attr.history][i], list):
+if dataset_attr.history and isinstance(example[dataset_attr.history], list):
-for old_prompt, old_response in examples[dataset_attr.history][i]:
+for old_prompt, old_response in example[dataset_attr.history]:
 prompt.append({"role": Role.USER.value, "content": old_prompt})
 prompt.append({"role": Role.ASSISTANT.value, "content": old_response})
-content = []
+query = []
-if dataset_attr.prompt and examples[dataset_attr.prompt][i]:
+if dataset_attr.prompt and example[dataset_attr.prompt]:
-content.append(examples[dataset_attr.prompt][i])
+query.append(example[dataset_attr.prompt])
-if dataset_attr.query and examples[dataset_attr.query][i]:
+if dataset_attr.query and example[dataset_attr.query]:
-content.append(examples[dataset_attr.query][i])
+query.append(example[dataset_attr.query])
-prompt.append({"role": Role.USER.value, "content": "\n".join(content)}) # "prompt\nquery"
+prompt.append({"role": Role.USER.value, "content": "\n".join(query)}) # "prompt\nquery"
-if dataset_attr.kto_tag and isinstance(examples[dataset_attr.kto_tag][i], bool): # kto example
+if dataset_attr.kto_tag and isinstance(example[dataset_attr.kto_tag], bool): # kto example
-response = [{"role": Role.ASSISTANT.value, "content": examples[dataset_attr.response][i]}]
+response = [{"role": Role.ASSISTANT.value, "content": example[dataset_attr.response]}]
-if examples[dataset_attr.kto_tag][i]:
+if example[dataset_attr.kto_tag]:
 response = response + [{"role": Role.ASSISTANT.value, "content": ""}]
 else:
 response = [{"role": Role.ASSISTANT.value, "content": ""}] + response
 elif (
 dataset_attr.ranking
-and isinstance(examples[dataset_attr.chosen][i], str)
+and isinstance(example[dataset_attr.chosen], str)
-and isinstance(examples[dataset_attr.rejected][i], str)
+and isinstance(example[dataset_attr.rejected], str)
 ): # pairwise example
 response = [
-{"role": Role.ASSISTANT.value, "content": examples[dataset_attr.chosen][i]},
+{"role": Role.ASSISTANT.value, "content": example[dataset_attr.chosen]},
-{"role": Role.ASSISTANT.value, "content": examples[dataset_attr.rejected][i]},
+{"role": Role.ASSISTANT.value, "content": example[dataset_attr.rejected]},
 ]
-elif dataset_attr.response and isinstance(examples[dataset_attr.response][i], str): # normal example
+elif dataset_attr.response and isinstance(example[dataset_attr.response], str): # normal example
-response = [{"role": Role.ASSISTANT.value, "content": examples[dataset_attr.response][i]}]
+response = [{"role": Role.ASSISTANT.value, "content": example[dataset_attr.response]}]
 else: # unsupervised
 response = []
-outputs["prompt"].append(prompt)
-outputs["response"].append(response)
-outputs["system"].append(examples[dataset_attr.system][i] if dataset_attr.system else "")
-outputs["tools"].append(examples[dataset_attr.tools][i] if dataset_attr.tools else "")
-outputs["images"].append(convert_images(examples[dataset_attr.images][i]) if dataset_attr.images else [])
-return outputs
+convert_images = partial(_convert_images, dataset_attr=dataset_attr, data_args=data_args)
+convert_videos = partial(_convert_videos, dataset_attr=dataset_attr, data_args=data_args)
+output = {
+"_prompt": prompt,
+"_response": response,
+"_system": example[dataset_attr.system] if dataset_attr.system else "",
+"_tools": example[dataset_attr.tools] if dataset_attr.tools else "",
+"_images": convert_images(example[dataset_attr.images]) if dataset_attr.images else None,
+"_videos": convert_videos(example[dataset_attr.videos]) if dataset_attr.videos else None,
+}
+return output
 def convert_sharegpt(
-examples: Dict[str, List[Any]], dataset_attr: "DatasetAttr", data_args: "DataArguments"
-) -> Dict[str, List[Any]]:
+example: Dict[str, Any],
+dataset_attr: "DatasetAttr",
+data_args: "DataArguments",
+) -> Dict[str, Any]:
 r"""
 Converts sharegpt format dataset to the standard format.
 """
-outputs = {"prompt": [], "response": [], "system": [], "tools": [], "images": []}
-convert_images = partial(_convert_images, dataset_attr=dataset_attr, data_args=data_args)
 tag_mapping = {
 dataset_attr.user_tag: Role.USER.value,
 dataset_attr.assistant_tag: Role.ASSISTANT.value,
@@ -119,21 +152,22 @@ def convert_sharegpt(
 odd_tags = (dataset_attr.user_tag, dataset_attr.observation_tag)
 even_tags = (dataset_attr.assistant_tag, dataset_attr.function_tag)
 accept_tags = (odd_tags, even_tags)
-for i, messages in enumerate(examples[dataset_attr.messages]):
+messages = example[dataset_attr.messages]
-if dataset_attr.system_tag and messages[0][dataset_attr.role_tag] == dataset_attr.system_tag:
+if (
+dataset_attr.system_tag
+and len(messages) != 0
+and messages[0][dataset_attr.role_tag] == dataset_attr.system_tag
+):
 system = messages[0][dataset_attr.content_tag]
 messages = messages[1:]
 else:
-system = examples[dataset_attr.system][i] if dataset_attr.system else ""
+system = example[dataset_attr.system] if dataset_attr.system else ""
-if len(messages) == 0:
-continue
 aligned_messages = []
 broken_data = False
 for turn_idx, message in enumerate(messages):
 if message[dataset_attr.role_tag] not in accept_tags[turn_idx % 2]:
-logger.warning("Invalid role tag in {}.".format(messages))
+logger.warning_rank0(f"Invalid role tag in {messages}.")
 broken_data = True
 aligned_messages.append(
@@ -143,28 +177,28 @@ def convert_sharegpt(
 if (not dataset_attr.ranking and len(aligned_messages) % 2 != 0) or (
 dataset_attr.ranking and len(aligned_messages) % 2 == 0
 ):
-logger.warning("Invalid message count in {}.".format(messages))
+logger.warning_rank0(f"Invalid message count in {messages}.")
 broken_data = True
-if dataset_attr.kto_tag and isinstance(examples[dataset_attr.kto_tag][i], bool): # kto example
+if dataset_attr.kto_tag and isinstance(example[dataset_attr.kto_tag], bool): # kto example
 prompt = aligned_messages[:-1]
 response = aligned_messages[-1:]
-if examples[dataset_attr.kto_tag][i]:
+if example[dataset_attr.kto_tag]:
 response = response + [{"role": Role.ASSISTANT.value, "content": ""}]
 else:
 response = [{"role": Role.ASSISTANT.value, "content": ""}] + response
 elif (
 dataset_attr.ranking
-and isinstance(examples[dataset_attr.chosen][i], dict)
+and isinstance(example[dataset_attr.chosen], dict)
-and isinstance(examples[dataset_attr.rejected][i], dict)
+and isinstance(example[dataset_attr.rejected], dict)
 ): # pairwise example
-chosen = examples[dataset_attr.chosen][i]
+chosen = example[dataset_attr.chosen]
-rejected = examples[dataset_attr.rejected][i]
+rejected = example[dataset_attr.rejected]
 if (
 chosen[dataset_attr.role_tag] not in accept_tags[-1]
 or rejected[dataset_attr.role_tag] not in accept_tags[-1]
 ):
-logger.warning("Invalid role tag in {}.".format([chosen, rejected]))
+logger.warning_rank0(f"Invalid role tag in {[chosen, rejected]}.")
 broken_data = True
 prompt = aligned_messages
@@ -177,16 +211,20 @@ def convert_sharegpt(
 response = aligned_messages[-1:]
 if broken_data:
-logger.warning("Skipping this abnormal example.")
+logger.warning_rank0("Skipping this abnormal example.")
-continue
+prompt, response = [], []
-outputs["prompt"].append(prompt)
-outputs["response"].append(response)
-outputs["system"].append(system)
-outputs["tools"].append(examples[dataset_attr.tools][i] if dataset_attr.tools else "")
-outputs["images"].append(convert_images(examples[dataset_attr.images][i]) if dataset_attr.images else [])
-return outputs
+convert_images = partial(_convert_images, dataset_attr=dataset_attr, data_args=data_args)
+convert_videos = partial(_convert_videos, dataset_attr=dataset_attr, data_args=data_args)
+output = {
+"_prompt": prompt,
+"_response": response,
+"_system": system,
+"_tools": example[dataset_attr.tools] if dataset_attr.tools else "",
+"_images": convert_images(example[dataset_attr.images]) if dataset_attr.images else None,
+"_videos": convert_videos(example[dataset_attr.videos]) if dataset_attr.videos else None,
+}
+return output
 def align_dataset(
@@ -197,11 +235,12 @@ def align_dataset(
 ) -> Union["Dataset", "IterableDataset"]:
 r"""
 Aligned dataset:
-prompt: [{"role": "user", "content": "..."}] * (2T - 1)
+_prompt: [{"role": "user", "content": "..."}] * (2T - 1)
-response: [{"role": "assistant", "content": "..."}] * N (N > 1 for ranking dataset)
+_response: [{"role": "assistant", "content": "..."}] * N (N > 1 for ranking dataset)
-system: "..."
+_system: "..."
-tools: "...",
+_tools: "...",
-images: [],
+_images: [],
+_videos: [],
 """
 if dataset_attr.formatting == "alpaca":
 convert_func = partial(convert_alpaca, dataset_attr=dataset_attr, data_args=data_args)
@@ -209,19 +248,6 @@ def align_dataset(
 convert_func = partial(convert_sharegpt, dataset_attr=dataset_attr, data_args=data_args)
 column_names = list(next(iter(dataset)).keys())
-features = Features.from_dict(
-{
-"prompt": [
-{"role": {"dtype": "string", "_type": "Value"}, "content": {"dtype": "string", "_type": "Value"}}
-],
-"response": [
-{"role": {"dtype": "string", "_type": "Value"}, "content": {"dtype": "string", "_type": "Value"}}
-],
-"system": {"dtype": "string", "_type": "Value"},
-"tools": {"dtype": "string", "_type": "Value"},
-"images": [{"_type": "Image"}],
-}
-)
 kwargs = {}
 if not data_args.streaming:
 kwargs = dict(
@@ -232,8 +258,7 @@ def align_dataset(
 return dataset.map(
 convert_func,
-batched=True,
+batched=False,
 remove_columns=column_names,
-features=features,
 **kwargs,
 )
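Since the converters above now run once per record (batched=False) and emit underscore-prefixed columns, a rough sketch of what a single alpaca example maps to, assuming the default instruction/input/output column mapping; the field values are invented for illustration:

    # Input record (alpaca format, hypothetical values).
    example = {"instruction": "Translate to French", "input": "Good morning", "output": "Bonjour"}

    # Roughly what convert_alpaca() returns for it under the new per-example schema.
    converted = {
        "_prompt": [{"role": "user", "content": "Translate to French\nGood morning"}],
        "_response": [{"role": "assistant", "content": "Bonjour"}],
        "_system": "",
        "_tools": "",
        "_images": None,
        "_videos": None,
    }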

View File

@@ -16,12 +16,18 @@
# limitations under the License. # limitations under the License.
from dataclasses import dataclass from dataclasses import dataclass
from typing import Any, Dict, Literal, Sequence from typing import TYPE_CHECKING, Any, Dict, Literal, Optional, Sequence
import torch import torch
from transformers import DataCollatorForSeq2Seq from transformers import DataCollatorForSeq2Seq
if TYPE_CHECKING:
from transformers import ProcessorMixin
from .template import Template
def prepare_4d_attention_mask(attention_mask_with_indices: "torch.Tensor", dtype: "torch.dtype") -> "torch.Tensor": def prepare_4d_attention_mask(attention_mask_with_indices: "torch.Tensor", dtype: "torch.dtype") -> "torch.Tensor":
r""" r"""
Expands the attention mask with indices from (batch_size, seq_len) to (batch_size, 1, seq_len, seq_len), Expands the attention mask with indices from (batch_size, seq_len) to (batch_size, 1, seq_len, seq_len),
@@ -62,7 +68,45 @@ def prepare_4d_attention_mask(attention_mask_with_indices: "torch.Tensor", dtype
@dataclass @dataclass
class SFTDataCollatorWith4DAttentionMask(DataCollatorForSeq2Seq): class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
r"""
Data collator that supports VLMs.
Features should contain input_ids, attention_mask, labels and images.
"""
template: Optional["Template"] = None
processor: Optional["ProcessorMixin"] = None
def __call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, "torch.Tensor"]:
batch_images, batch_videos, batch_imglens, batch_vidlens, batch_input_ids = [], [], [], [], []
for feature in features:
images = feature.pop("images", None) or []
videos = feature.pop("videos", None) or []
batch_images.extend(images)
batch_videos.extend(videos)
batch_imglens.append(len(images))
batch_vidlens.append(len(videos))
batch_input_ids.append(feature["input_ids"])
mm_inputs = self.template.mm_plugin.get_mm_inputs(
batch_images, batch_videos, batch_imglens, batch_vidlens, batch_input_ids, self.processor
)
if "token_type_ids" in mm_inputs:
token_type_ids = mm_inputs.pop("token_type_ids")
for i, feature in enumerate(features):
feature["token_type_ids"] = token_type_ids[i]
features: Dict[str, "torch.Tensor"] = super().__call__(features)
features.update(mm_inputs)
if isinstance(features.get("pixel_values"), list): # for pixtral inputs
features = features.data # use default_collate() instead of BatchEncoding.to()
return features
@dataclass
class SFTDataCollatorWith4DAttentionMask(MultiModalDataCollatorForSeq2Seq):
r""" r"""
Data collator for 4d attention mask. Data collator for 4d attention mask.
""" """
@@ -80,7 +124,7 @@ class SFTDataCollatorWith4DAttentionMask(DataCollatorForSeq2Seq):
@dataclass @dataclass
class PairwiseDataCollatorWithPadding(DataCollatorForSeq2Seq): class PairwiseDataCollatorWithPadding(MultiModalDataCollatorForSeq2Seq):
r""" r"""
Data collator for pairwise data. Data collator for pairwise data.
""" """
@@ -96,23 +140,19 @@ class PairwiseDataCollatorWithPadding(DataCollatorForSeq2Seq):
for key in ("chosen", "rejected"): for key in ("chosen", "rejected"):
for feature in features: for feature in features:
target_feature = { target_feature = {
"input_ids": feature["{}_input_ids".format(key)], "input_ids": feature[f"{key}_input_ids"],
"attention_mask": feature["{}_attention_mask".format(key)], "attention_mask": feature[f"{key}_attention_mask"],
"labels": feature["{}_labels".format(key)], "labels": feature[f"{key}_labels"],
"images": feature["images"],
"videos": feature["videos"],
} }
if "pixel_values" in feature:
target_feature["pixel_values"] = feature["pixel_values"]
if "{}_token_type_ids".format(key) in feature:
target_feature["token_type_ids"] = feature["{}_token_type_ids".format(key)]
concatenated_features.append(target_feature) concatenated_features.append(target_feature)
return super().__call__(concatenated_features) return super().__call__(concatenated_features)
@dataclass
- class KTODataCollatorWithPadding(DataCollatorForSeq2Seq):
+ class KTODataCollatorWithPadding(MultiModalDataCollatorForSeq2Seq):
r"""
Data collator for KTO data.
"""
@@ -126,19 +166,16 @@ class KTODataCollatorWithPadding(DataCollatorForSeq2Seq):
"input_ids": feature["input_ids"],
"attention_mask": feature["attention_mask"],
"labels": feature["labels"],
+ "images": feature["images"],
+ "videos": feature["videos"],
}
kl_feature = {
"input_ids": feature["kl_input_ids"],
"attention_mask": feature["kl_attention_mask"],
"labels": feature["kl_labels"],
+ "images": feature["images"],
+ "videos": feature["videos"],
}
- if "pixel_values" in feature:
- target_feature["pixel_values"] = feature["pixel_values"]
- if "token_type_ids" in feature:
- target_feature["token_type_ids"] = feature["token_type_ids"]
- kl_feature["token_type_ids"] = feature["kl_token_type_ids"]
target_features.append(target_feature)
kl_features.append(kl_feature)
kto_tags.append(feature["kto_tags"])
@@ -148,7 +185,7 @@ class KTODataCollatorWithPadding(DataCollatorForSeq2Seq):
batch["kl_input_ids"] = kl_batch["input_ids"]
batch["kl_attention_mask"] = kl_batch["attention_mask"]
batch["kl_labels"] = kl_batch["labels"]
- if "token_type_ids" in batch:
+ if "token_type_ids" in kl_batch:
batch["kl_token_type_ids"] = kl_batch["token_type_ids"]
batch["kto_tags"] = torch.tensor(kto_tags)

View File

@@ -17,7 +17,7 @@ from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Set, TypedDict
from datasets import DatasetDict, concatenate_datasets, interleave_datasets
- from ..extras.logging import get_logger
+ from ..extras import logging
if TYPE_CHECKING:
@@ -26,7 +26,7 @@ if TYPE_CHECKING:
from ..hparams import DataArguments
- logger = get_logger(__name__)
+ logger = logging.get_logger(__name__)
SLOTS = Sequence[Union[str, Set[str], Dict[str, str]]]
@@ -49,16 +49,19 @@ class DatasetModule(TypedDict):
def merge_dataset(
all_datasets: List[Union["Dataset", "IterableDataset"]], data_args: "DataArguments", seed: int
) -> Union["Dataset", "IterableDataset"]:
r"""
Merges multiple datasets to a unified dataset.
"""
if len(all_datasets) == 1:
return all_datasets[0]
elif data_args.mix_strategy == "concat":
if data_args.streaming:
- logger.warning("The samples between different datasets will not be mixed in streaming mode.")
+ logger.warning_once("The samples between different datasets will not be mixed in streaming mode.")
return concatenate_datasets(all_datasets)
elif data_args.mix_strategy.startswith("interleave"):
if not data_args.streaming:
- logger.warning("We recommend using `mix_strategy=concat` in non-streaming mode.")
+ logger.warning_once("We recommend using `mix_strategy=concat` in non-streaming mode.")
return interleave_datasets(
datasets=all_datasets,
@@ -67,14 +70,16 @@ def merge_dataset(
stopping_strategy="first_exhausted" if data_args.mix_strategy.endswith("under") else "all_exhausted",
)
else:
- raise ValueError("Unknown mixing strategy.")
+ raise ValueError(f"Unknown mixing strategy: {data_args.mix_strategy}.")
def split_dataset(
dataset: Union["Dataset", "IterableDataset"], data_args: "DataArguments", seed: int
) -> "DatasetDict":
r"""
- Splits the dataset and returns a dataset dict containing train set (required) and validation set (optional).
+ Splits the dataset and returns a dataset dict containing train set and validation set.
+ Supports both map dataset and iterable dataset.
"""
if data_args.streaming:
dataset = dataset.shuffle(buffer_size=data_args.buffer_size, seed=seed)

View File

@@ -16,21 +16,36 @@ import json
import re
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
- from typing import List, Literal, Optional, Tuple, Union
+ from typing import TYPE_CHECKING, List, Optional, Tuple, Union
+ from typing_extensions import override
from .data_utils import SLOTS
- from .tool_utils import DefaultToolUtils, GLM4ToolUtils
+ from .tool_utils import get_tool_utils
+ if TYPE_CHECKING:
+ from .tool_utils import FunctionCall
@dataclass
class Formatter(ABC):
slots: SLOTS = field(default_factory=list)
- tool_format: Optional[Literal["default", "glm4"]] = None
+ tool_format: Optional[str] = None
@abstractmethod
- def apply(self, **kwargs) -> SLOTS: ...
+ def apply(self, **kwargs) -> SLOTS:
+ r"""
+ Forms a list of slots according to the inputs to encode.
+ """
+ ...
- def extract(self, content: str) -> Union[str, List[Tuple[str, str]]]:
+ def extract(self, content: str) -> Union[str, List["FunctionCall"]]:
+ r"""
+ Extract a list of tuples from the response message if using tools.
+ Each tuple consists of function name and function arguments.
+ """
raise NotImplementedError
@@ -45,6 +60,7 @@ class EmptyFormatter(Formatter):
if has_placeholder:
raise ValueError("Empty formatter should not contain any placeholder.")
+ @override
def apply(self, **kwargs) -> SLOTS:
return self.slots
@@ -60,20 +76,21 @@ class StringFormatter(Formatter):
if not has_placeholder:
raise ValueError("A placeholder is required in the string formatter.")
+ @override
def apply(self, **kwargs) -> SLOTS:
elements = []
for slot in self.slots:
if isinstance(slot, str):
for name, value in kwargs.items():
if not isinstance(value, str):
- raise RuntimeError("Expected a string, got {}".format(value))
+ raise RuntimeError(f"Expected a string, got {value}")
slot = slot.replace("{{" + name + "}}", value, 1)
elements.append(slot)
elif isinstance(slot, (dict, set)):
elements.append(slot)
else:
- raise RuntimeError("Input must be string, set[str] or dict[str, str], got {}".format(type(slot)))
+ raise RuntimeError(f"Input must be string, set[str] or dict[str, str], got {type(slot)}")
return elements
@@ -81,13 +98,9 @@ class StringFormatter(Formatter):
@dataclass
class FunctionFormatter(Formatter):
def __post_init__(self):
- if self.tool_format == "default":
- self.slots = DefaultToolUtils.get_function_slots() + self.slots
- elif self.tool_format == "glm4":
- self.slots = GLM4ToolUtils.get_function_slots() + self.slots
- else:
- raise NotImplementedError("Tool format {} was not found.".format(self.tool_format))
+ self.slots = get_tool_utils(self.tool_format).get_function_slots() + self.slots
+ @override
def apply(self, **kwargs) -> SLOTS:
content = kwargs.pop("content")
functions: List[Tuple[str, str]] = []
@@ -100,7 +113,7 @@ class FunctionFormatter(Formatter):
functions.append((tool_call["name"], json.dumps(tool_call["arguments"], ensure_ascii=False)))
except json.JSONDecodeError:
- functions = []
+ raise RuntimeError(f"Invalid JSON format in function message: {str([content])}") # flat string
elements = []
for name, arguments in functions:
@@ -111,7 +124,7 @@ class FunctionFormatter(Formatter):
elif isinstance(slot, (dict, set)):
elements.append(slot)
else:
- raise RuntimeError("Input must be string, set[str] or dict[str, str], got {}".format(type(slot)))
+ raise RuntimeError(f"Input must be string, set[str] or dict[str, str], got {type(slot)}")
return elements
@@ -119,22 +132,17 @@ class FunctionFormatter(Formatter):
@dataclass
class ToolFormatter(Formatter):
def __post_init__(self):
- if self.tool_format == "default":
- self._tool_formatter = DefaultToolUtils.tool_formatter
- self._tool_extractor = DefaultToolUtils.tool_extractor
- elif self.tool_format == "glm4":
- self._tool_formatter = GLM4ToolUtils.tool_formatter
- self._tool_extractor = GLM4ToolUtils.tool_extractor
- else:
- raise NotImplementedError("Tool format {} was not found.".format(self.tool_format))
+ self.tool_utils = get_tool_utils(self.tool_format)
+ @override
def apply(self, **kwargs) -> SLOTS:
content = kwargs.pop("content")
try:
tools = json.loads(content)
- return [self._tool_formatter(tools) if len(tools) != 0 else ""]
+ return [self.tool_utils.tool_formatter(tools) if len(tools) != 0 else ""]
except json.JSONDecodeError:
- return [""]
+ raise RuntimeError(f"Invalid JSON format in tool description: {str([content])}") # flat string
- def extract(self, content: str) -> Union[str, List[Tuple[str, str]]]:
- return self._tool_extractor(content)
+ @override
+ def extract(self, content: str) -> Union[str, List["FunctionCall"]]:
+ return self.tool_utils.tool_extractor(content)
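A small illustrative sketch of the new registry-driven flow (assumptions: the tool_utils registry still exposes the pre-existing "default" format, and the JSON payloads here are invented; the exact extractor pattern is not defined in this diff):

tool_formatter = ToolFormatter(tool_format="default")
# __post_init__ resolves the helpers once via get_tool_utils("default")
prompt_slots = tool_formatter.apply(
    content='[{"name": "get_weather", "description": "", "parameters": {}}]'
)
result = tool_formatter.extract("model output that may contain a tool call")
# result is either the raw string or a list of FunctionCall entries, per the new return type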

View File

@@ -20,14 +20,13 @@ import numpy as np
from datasets import DatasetDict, load_dataset, load_from_disk
from transformers.utils.versions import require_version
+ from ..extras import logging
from ..extras.constants import FILEEXT2TYPE
- from ..extras.logging import get_logger
from ..extras.misc import has_tokenized_data
from .aligner import align_dataset
from .data_utils import merge_dataset, split_dataset
from .parser import get_dataset_list
from .preprocess import get_preprocess_and_print_func
- from .template import get_template_and_fix_tokenizer
if TYPE_CHECKING:
@@ -40,7 +39,7 @@ if TYPE_CHECKING:
from .template import Template
- logger = get_logger(__name__)
+ logger = logging.get_logger(__name__)
def _load_single_dataset(
@@ -49,9 +48,12 @@ def _load_single_dataset(
data_args: "DataArguments", data_args: "DataArguments",
training_args: "Seq2SeqTrainingArguments", training_args: "Seq2SeqTrainingArguments",
) -> Union["Dataset", "IterableDataset"]: ) -> Union["Dataset", "IterableDataset"]:
logger.info("Loading dataset {}...".format(dataset_attr)) r"""
Loads a single dataset and aligns it to the standard format.
"""
logger.info_rank0(f"Loading dataset {dataset_attr}...")
data_path, data_name, data_dir, data_files = None, None, None, None data_path, data_name, data_dir, data_files = None, None, None, None
if dataset_attr.load_from in ["hf_hub", "ms_hub"]: if dataset_attr.load_from in ["hf_hub", "ms_hub", "om_hub"]:
data_path = dataset_attr.dataset_name data_path = dataset_attr.dataset_name
data_name = dataset_attr.subset data_name = dataset_attr.subset
data_dir = dataset_attr.folder data_dir = dataset_attr.folder
@@ -67,25 +69,24 @@ def _load_single_dataset(
if os.path.isdir(local_path): # is directory
for file_name in os.listdir(local_path):
data_files.append(os.path.join(local_path, file_name))
- if data_path is None:
- data_path = FILEEXT2TYPE.get(file_name.split(".")[-1], None)
- elif data_path != FILEEXT2TYPE.get(file_name.split(".")[-1], None):
- raise ValueError("File types should be identical.")
elif os.path.isfile(local_path): # is file
data_files.append(local_path)
- data_path = FILEEXT2TYPE.get(local_path.split(".")[-1], None)
else:
- raise ValueError("File {} not found.".format(local_path))
+ raise ValueError(f"File {local_path} not found.")
+ data_path = FILEEXT2TYPE.get(os.path.splitext(data_files[0])[-1][1:], None)
if data_path is None:
raise ValueError("Allowed file types: {}.".format(",".join(FILEEXT2TYPE.keys())))
+ if any(data_path != FILEEXT2TYPE.get(os.path.splitext(data_file)[-1][1:], None) for data_file in data_files):
+ raise ValueError("File types should be identical.")
else:
- raise NotImplementedError("Unknown load type: {}.".format(dataset_attr.load_from))
+ raise NotImplementedError(f"Unknown load type: {dataset_attr.load_from}.")
if dataset_attr.load_from == "ms_hub":
require_version("modelscope>=1.11.0", "To fix: pip install modelscope>=1.11.0")
- from modelscope import MsDataset
- from modelscope.utils.config_ds import MS_DATASETS_CACHE
+ from modelscope import MsDataset # type: ignore
+ from modelscope.utils.config_ds import MS_DATASETS_CACHE # type: ignore
cache_dir = model_args.cache_dir or MS_DATASETS_CACHE
dataset = MsDataset.load(
@@ -96,10 +97,27 @@ def _load_single_dataset(
split=dataset_attr.split,
cache_dir=cache_dir,
token=model_args.ms_hub_token,
- use_streaming=(data_args.streaming and (dataset_attr.load_from != "file")),
+ use_streaming=data_args.streaming,
)
if isinstance(dataset, MsDataset):
dataset = dataset.to_hf_dataset()
elif dataset_attr.load_from == "om_hub":
require_version("openmind>=0.8.0", "To fix: pip install openmind>=0.8.0")
from openmind import OmDataset # type: ignore
from openmind.utils.hub import OM_DATASETS_CACHE # type: ignore
cache_dir = model_args.cache_dir or OM_DATASETS_CACHE
dataset = OmDataset.load_dataset(
path=data_path,
name=data_name,
data_dir=data_dir,
data_files=data_files,
split=dataset_attr.split,
cache_dir=cache_dir,
token=model_args.om_hub_token,
streaming=data_args.streaming,
)
else:
dataset = load_dataset(
path=data_path,
@@ -109,16 +127,13 @@ def _load_single_dataset(
split=dataset_attr.split,
cache_dir=model_args.cache_dir,
token=model_args.hf_hub_token,
- streaming=(data_args.streaming and (dataset_attr.load_from != "file")),
+ streaming=data_args.streaming,
trust_remote_code=True,
)
- if data_args.streaming and (dataset_attr.load_from == "file"): # faster than specifying streaming=True
- dataset = dataset.to_iterable_dataset() # TODO: add num shards parameter
if dataset_attr.num_samples is not None and not data_args.streaming:
target_num = dataset_attr.num_samples
- indexes = np.random.permutation(len(dataset))[:target_num]
+ indexes = np.random.permutation(len(dataset))[:target_num] # all samples should be included
target_num -= len(indexes)
if target_num > 0:
expand_indexes = np.random.choice(len(dataset), target_num)
@@ -126,7 +141,7 @@ def _load_single_dataset(
assert len(indexes) == dataset_attr.num_samples, "Sample num mismatched."
dataset = dataset.select(indexes)
- logger.info("Sampled {} examples from dataset {}.".format(dataset_attr.num_samples, dataset_attr))
+ logger.info_rank0(f"Sampled {dataset_attr.num_samples} examples from dataset {dataset_attr}.")
if data_args.max_samples is not None: # truncate dataset
max_samples = min(data_args.max_samples, len(dataset))
@@ -142,6 +157,9 @@ def _get_merged_dataset(
training_args: "Seq2SeqTrainingArguments", training_args: "Seq2SeqTrainingArguments",
stage: Literal["pt", "sft", "rm", "ppo", "kto"], stage: Literal["pt", "sft", "rm", "ppo", "kto"],
) -> Optional[Union["Dataset", "IterableDataset"]]: ) -> Optional[Union["Dataset", "IterableDataset"]]:
r"""
Gets the merged datasets in the standard format.
"""
if dataset_names is None: if dataset_names is None:
return None return None
@@ -165,6 +183,9 @@ def _get_preprocessed_dataset(
processor: Optional["ProcessorMixin"] = None, processor: Optional["ProcessorMixin"] = None,
is_eval: bool = False, is_eval: bool = False,
) -> Optional[Union["Dataset", "IterableDataset"]]: ) -> Optional[Union["Dataset", "IterableDataset"]]:
r"""
Preprocesses the dataset, including format checking and tokenization.
"""
if dataset is None: if dataset is None:
return None return None
@@ -180,7 +201,13 @@ def _get_preprocessed_dataset(
desc="Running tokenizer on dataset", desc="Running tokenizer on dataset",
) )
dataset = dataset.map(preprocess_func, batched=True, remove_columns=column_names, **kwargs) dataset = dataset.map(
preprocess_func,
batched=True,
batch_size=data_args.preprocessing_batch_size,
remove_columns=column_names,
**kwargs,
)
if training_args.should_log: if training_args.should_log:
try: try:
@@ -196,6 +223,7 @@ def _get_preprocessed_dataset(
def get_dataset(
+ template: "Template",
model_args: "ModelArguments",
data_args: "DataArguments",
training_args: "Seq2SeqTrainingArguments",
@@ -203,20 +231,20 @@ def get_dataset(
tokenizer: "PreTrainedTokenizer",
processor: Optional["ProcessorMixin"] = None,
) -> "DatasetModule":
- template = get_template_and_fix_tokenizer(tokenizer, data_args.template, data_args.tool_format)
- if data_args.train_on_prompt and template.efficient_eos:
- raise ValueError("Current template does not support `train_on_prompt`.")
+ r"""
+ Gets the train dataset and optionally gets the evaluation dataset.
+ """
# Load tokenized dataset
if data_args.tokenized_path is not None:
if has_tokenized_data(data_args.tokenized_path):
- logger.warning("Loading dataset from disk will ignore other data arguments.")
+ logger.warning_rank0("Loading dataset from disk will ignore other data arguments.")
dataset_dict: "DatasetDict" = load_from_disk(data_args.tokenized_path)
- logger.info("Loaded tokenized dataset from {}.".format(data_args.tokenized_path))
+ logger.info_rank0(f"Loaded tokenized dataset from {data_args.tokenized_path}.")
dataset_module: Dict[str, "Dataset"] = {}
if "train" in dataset_dict:
dataset_module["train_dataset"] = dataset_dict["train"]
if "validation" in dataset_dict:
dataset_module["eval_dataset"] = dataset_dict["validation"]
@@ -262,14 +290,15 @@ def get_dataset(
if data_args.tokenized_path is not None:
if training_args.should_save:
dataset_dict.save_to_disk(data_args.tokenized_path)
- logger.info("Tokenized dataset saved at {}.".format(data_args.tokenized_path))
- logger.info("Please restart the training with `tokenized_path: {}`.".format(data_args.tokenized_path))
+ logger.info_rank0(f"Tokenized dataset saved at {data_args.tokenized_path}.")
+ logger.info_rank0(f"Please restart the training with `tokenized_path: {data_args.tokenized_path}`.")
sys.exit(0)
dataset_module = {}
if "train" in dataset_dict:
dataset_module["train_dataset"] = dataset_dict["train"]
if "validation" in dataset_dict:
dataset_module["eval_dataset"] = dataset_dict["validation"]

View File

@@ -0,0 +1,787 @@
import math
from copy import deepcopy
from io import BytesIO
from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Tuple, TypedDict, Union
import numpy as np
import torch
from transformers.image_utils import get_image_size, to_numpy_array
from typing_extensions import override
from ..extras.constants import IGNORE_INDEX, IMAGE_PLACEHOLDER, VIDEO_PLACEHOLDER
from ..extras.packages import is_pillow_available, is_pyav_available, is_transformers_version_greater_than
if is_pillow_available():
from PIL import Image
from PIL.Image import Image as ImageObject
if is_pyav_available():
import av
if is_transformers_version_greater_than("4.45.0"):
from transformers.models.mllama.processing_mllama import (
convert_sparse_cross_attention_mask_to_dense,
get_cross_attention_token_mask,
)
if TYPE_CHECKING:
from av.stream import Stream
from transformers import PreTrainedTokenizer, ProcessorMixin
from transformers.image_processing_utils import BaseImageProcessor
class EncodedImage(TypedDict):
path: Optional[str]
bytes: Optional[bytes]
ImageInput = Union[str, bytes, EncodedImage, ImageObject]
VideoInput = str
def _get_paligemma_token_type_ids(
imglens: Sequence[int], seqlens: Sequence[int], processor: "ProcessorMixin"
) -> List[List[int]]:
r"""
Gets paligemma token type ids for computing loss.
Returns:
batch_token_type_ids: shape (batch_size, sequence_length)
"""
batch_token_type_ids = []
for imglen, seqlen in zip(imglens, seqlens):
image_seqlen = imglen * getattr(processor, "image_seqlen")
batch_token_type_ids.append([0] * image_seqlen + [1] * (seqlen - image_seqlen))
return batch_token_type_ids
class BasePlugin:
def __init__(self, image_token: Optional[str], video_token: Optional[str]) -> None:
self.image_token = image_token
self.video_token = video_token
def _validate_input(
self,
images: Sequence["ImageInput"],
videos: Sequence["VideoInput"],
) -> None:
r"""
Validates if this model accepts the input modalities.
"""
if len(images) != 0 and self.image_token is None:
raise ValueError("This model does not support image input.")
if len(videos) != 0 and self.video_token is None:
raise ValueError("This model does not support video input.")
def _preprocess_image(self, image: "ImageObject", **kwargs) -> "ImageObject":
r"""
Pre-processes a single image.
"""
image_resolution: int = kwargs.get("image_resolution")
if (image.width * image.height) > image_resolution:
resize_factor = math.sqrt(image_resolution / (image.width * image.height))
width, height = int(image.width * resize_factor), int(image.height * resize_factor)
image = image.resize((width, height), resample=Image.NEAREST)
if image.mode != "RGB":
image = image.convert("RGB")
return image
def _get_video_sample_frames(self, video_stream: "Stream", **kwargs) -> int:
r"""
Computes video sample frames according to fps.
"""
video_fps: float = kwargs.get("video_fps")
video_maxlen: int = kwargs.get("video_maxlen")
total_frames = video_stream.frames
sample_frames = float(video_stream.duration * video_stream.time_base) * video_fps
sample_frames = min(total_frames, video_maxlen, sample_frames)
return math.floor(sample_frames)
def _regularize_images(self, images: Sequence["ImageInput"], **kwargs) -> List["ImageObject"]:
r"""
Regularizes images to avoid error. Including reading and pre-processing.
"""
results = []
for image in images:
if isinstance(image, str):
image = Image.open(image)
elif isinstance(image, bytes):
image = Image.open(BytesIO(image))
elif isinstance(image, dict):
if image["bytes"] is not None:
image = Image.open(BytesIO(image["bytes"]))
else:
image = Image.open(image["path"])
if not isinstance(image, ImageObject):
raise ValueError(f"Expect input is a list of Images, but got {type(image)}.")
results.append(self._preprocess_image(image, **kwargs))
return results
def _regularize_videos(self, videos: Sequence["VideoInput"], **kwargs) -> List[List["ImageObject"]]:
r"""
Regularizes videos to avoid error. Including reading, resizing and converting.
"""
results = []
for video in videos:
container = av.open(video, "r")
video_stream = next(stream for stream in container.streams if stream.type == "video")
total_frames = video_stream.frames
sample_frames = self._get_video_sample_frames(video_stream, **kwargs)
sample_indices = np.linspace(0, total_frames - 1, sample_frames).astype(np.int32)
frames: List["ImageObject"] = []
container.seek(0)
for frame_idx, frame in enumerate(container.decode(video_stream)):
if frame_idx in sample_indices:
frames.append(frame.to_image())
frames = self._regularize_images(frames, **kwargs)
results.append(frames)
return results
def _get_mm_inputs(
self,
images: Sequence["ImageInput"],
videos: Sequence["VideoInput"],
processor: "ProcessorMixin",
) -> Dict[str, "torch.Tensor"]:
r"""
Processes visual inputs.
Returns: (llava and paligemma)
pixel_values: tensor with shape (B, C, H, W)
Returns: (qwen2-vl)
pixel_values: tensor with shape (num_patches, patch_dim)
image_grid_thw: tensor with shape (num_images, 3), where the three numbers are time, width, height
It holds num_patches == torch.prod(image_grid_thw)
"""
image_processor: "BaseImageProcessor" = getattr(processor, "image_processor")
video_processor: "BaseImageProcessor" = getattr(processor, "video_processor", image_processor)
input_dict = {"images": None} # default key
if len(images) != 0:
images = self._regularize_images(
images,
image_resolution=getattr(processor, "image_resolution", 512 * 512),
)
input_dict["images"] = images
if len(videos) != 0:
videos = self._regularize_videos(
videos,
image_resolution=getattr(processor, "video_resolution", 128 * 128),
video_fps=getattr(processor, "video_fps", 2.0),
video_maxlen=getattr(processor, "video_maxlen", 64),
)
input_dict["videos"] = videos
mm_inputs = {}
if image_processor != video_processor:
if input_dict.get("images") is not None:
mm_inputs.update(image_processor(input_dict["images"], return_tensors="pt"))
if input_dict.get("videos") is not None:
mm_inputs.update(video_processor(input_dict["videos"], return_tensors="pt"))
elif input_dict.get("images") is not None or input_dict.get("videos") is not None: # same processor (qwen2-vl)
mm_inputs.update(image_processor(**input_dict, return_tensors="pt"))
return mm_inputs
def process_messages(
self,
messages: Sequence[Dict[str, str]],
images: Sequence["ImageInput"],
videos: Sequence["VideoInput"],
processor: Optional["ProcessorMixin"],
) -> List[Dict[str, str]]:
r"""
Pre-processes input messages before tokenization for VLMs.
"""
self._validate_input(images, videos)
return messages
def process_token_ids(
self,
input_ids: List[int],
labels: Optional[List[int]],
images: Sequence["ImageInput"],
videos: Sequence["VideoInput"],
tokenizer: "PreTrainedTokenizer",
processor: Optional["ProcessorMixin"],
) -> Tuple[List[int], Optional[List[int]]]:
r"""
Pre-processes token ids after tokenization for VLMs.
"""
self._validate_input(images, videos)
return input_ids, labels
def get_mm_inputs(
self,
images: Sequence["ImageInput"],
videos: Sequence["VideoInput"],
imglens: Sequence[int],
vidlens: Sequence[int],
batch_ids: Sequence[List[int]],
processor: Optional["ProcessorMixin"],
) -> Dict[str, Union[List[int], "torch.Tensor"]]:
r"""
Builds batched multimodal inputs for VLMs.
Arguments:
images: a list of image inputs, shape (num_images,)
videos: a list of video inputs, shape (num_videos,)
imglens: number of images in each sample, shape (batch_size,)
vidlens: number of videos in each sample, shape (batch_size,)
batch_ids: input ids of samples, shape (batch_size, seq_len)
processor: a processor for pre-processing images and videos
"""
self._validate_input(images, videos)
return {}
class LlavaPlugin(BasePlugin):
@override
def process_messages(
self,
messages: Sequence[Dict[str, str]],
images: Sequence["ImageInput"],
videos: Sequence["VideoInput"],
processor: Optional["ProcessorMixin"],
) -> List[Dict[str, str]]:
self._validate_input(images, videos)
num_image_tokens = 0
image_seqlen = getattr(processor, "image_seqlen")
messages = deepcopy(messages)
for message in messages:
content = message["content"]
while IMAGE_PLACEHOLDER in content:
num_image_tokens += 1
content = content.replace(IMAGE_PLACEHOLDER, "{{image}}" * image_seqlen, 1)
message["content"] = content.replace("{{image}}", self.image_token)
if len(images) != num_image_tokens:
raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens.")
return messages
@override
def get_mm_inputs(
self,
images: Sequence["ImageInput"],
videos: Sequence["VideoInput"],
imglens: Sequence[int],
vidlens: Sequence[int],
batch_ids: Sequence[List[int]],
processor: Optional["ProcessorMixin"],
) -> Dict[str, Union[List[int], "torch.Tensor"]]:
self._validate_input(images, videos)
return self._get_mm_inputs(images, videos, processor)
class LlavaNextPlugin(BasePlugin):
@override
def process_messages(
self,
messages: Sequence[Dict[str, str]],
images: Sequence["ImageInput"],
videos: Sequence["VideoInput"],
processor: Optional["ProcessorMixin"],
) -> List[Dict[str, str]]:
self._validate_input(images, videos)
num_image_tokens = 0
messages = deepcopy(messages)
mm_inputs = self._get_mm_inputs(images, videos, processor)
if "image_sizes" in mm_inputs:
image_sizes = iter(mm_inputs["image_sizes"])
if "pixel_values" in mm_inputs:
height, width = get_image_size(to_numpy_array(mm_inputs["pixel_values"][0][0]))
for message in messages:
content = message["content"]
while IMAGE_PLACEHOLDER in content:
image_size = next(image_sizes)
orig_height, orig_width = image_size
image_seqlen = processor._get_number_of_features(orig_height, orig_width, height, width)
if getattr(processor, "vision_feature_select_strategy") == "default":
image_seqlen -= 1
num_image_tokens += 1
content = content.replace(IMAGE_PLACEHOLDER, "{{image}}" * image_seqlen, 1)
message["content"] = content.replace("{{image}}", self.image_token)
if len(images) != num_image_tokens:
raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens.")
return messages
@override
def get_mm_inputs(
self,
images: Sequence["ImageInput"],
videos: Sequence["VideoInput"],
imglens: Sequence[int],
vidlens: Sequence[int],
batch_ids: Sequence[List[int]],
processor: Optional["ProcessorMixin"],
) -> Dict[str, Union[List[int], "torch.Tensor"]]:
self._validate_input(images, videos)
return self._get_mm_inputs(images, videos, processor)
class LlavaNextVideoPlugin(BasePlugin):
@override
def process_messages(
self,
messages: Sequence[Dict[str, str]],
images: Sequence["ImageInput"],
videos: Sequence["VideoInput"],
processor: Optional["ProcessorMixin"],
) -> List[Dict[str, str]]:
self._validate_input(images, videos)
num_image_tokens, num_video_tokens = 0, 0
messages = deepcopy(messages)
mm_inputs = self._get_mm_inputs(images, videos, processor)
if "pixel_values" in mm_inputs:
image_sizes = iter(mm_inputs["image_sizes"])
height, width = get_image_size(to_numpy_array(mm_inputs["pixel_values"][0][0]))
for message in messages:
content = message["content"]
while IMAGE_PLACEHOLDER in content:
image_size = next(image_sizes)
orig_height, orig_width = image_size
image_seqlen = processor._get_number_of_features(orig_height, orig_width, height, width)
if getattr(processor, "vision_feature_select_strategy") == "default":
image_seqlen -= 1
num_image_tokens += 1
content = content.replace(IMAGE_PLACEHOLDER, "{{image}}" * image_seqlen, 1)
message["content"] = content.replace("{{image}}", self.image_token)
if "pixel_values_videos" in mm_inputs:
pixel_values_video = to_numpy_array(mm_inputs.get("pixel_values_videos")[0])
height, width = get_image_size(pixel_values_video[0])
num_frames = pixel_values_video.shape[0] # frame dim is always after batch dim
image_seqlen = (height // processor.patch_size) * (width // processor.patch_size)
video_seqlen = image_seqlen // 4 * num_frames # divide by 4 needed for avg pooling layer
for message in messages:
content = message["content"]
while VIDEO_PLACEHOLDER in content:
num_video_tokens += 1
content = content.replace(VIDEO_PLACEHOLDER, "{{video}}" * video_seqlen, 1)
message["content"] = content.replace("{{video}}", self.video_token)
if len(images) != num_image_tokens:
raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens.")
if len(videos) != num_video_tokens:
raise ValueError(f"The number of videos does not match the number of {VIDEO_PLACEHOLDER} tokens.")
return messages
@override
def get_mm_inputs(
self,
images: Sequence["ImageInput"],
videos: Sequence["VideoInput"],
imglens: Sequence[int],
vidlens: Sequence[int],
batch_ids: Sequence[List[int]],
processor: Optional["ProcessorMixin"],
) -> Dict[str, Union[List[int], "torch.Tensor"]]:
self._validate_input(images, videos)
return self._get_mm_inputs(images, videos, processor)
class PaliGemmaPlugin(BasePlugin):
@override
def process_messages(
self,
messages: Sequence[Dict[str, str]],
images: Sequence["ImageInput"],
videos: Sequence["VideoInput"],
processor: Optional["ProcessorMixin"],
) -> List[Dict[str, str]]:
self._validate_input(images, videos)
num_image_tokens = 0
messages = deepcopy(messages)
for message in messages:
content = message["content"]
while IMAGE_PLACEHOLDER in content:
num_image_tokens += 1
content = content.replace(IMAGE_PLACEHOLDER, "{{image}}", 1)
message["content"] = content.replace("{{image}}", "")
if len(images) != num_image_tokens:
raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens.")
return messages
@override
def process_token_ids(
self,
input_ids: List[int],
labels: Optional[List[int]],
images: Sequence["ImageInput"],
videos: Sequence["VideoInput"],
tokenizer: "PreTrainedTokenizer",
processor: Optional["ProcessorMixin"],
) -> Tuple[List[int], Optional[List[int]]]:
self._validate_input(images, videos)
num_images = len(images)
image_seqlen = num_images * getattr(processor, "image_seqlen")
image_token_id = tokenizer.convert_tokens_to_ids(self.image_token)
input_ids = [image_token_id] * image_seqlen + input_ids
if labels is not None:
labels = [IGNORE_INDEX] * image_seqlen + labels
return input_ids, labels
@override
def get_mm_inputs(
self,
images: Sequence["ImageInput"],
videos: Sequence["VideoInput"],
imglens: Sequence[int],
vidlens: Sequence[int],
batch_ids: Sequence[List[int]],
processor: Optional["ProcessorMixin"],
) -> Dict[str, Union[List[int], "torch.Tensor"]]:
self._validate_input(images, videos)
seqlens = [len(input_ids) for input_ids in batch_ids]
mm_inputs = self._get_mm_inputs(images, videos, processor)
mm_inputs["token_type_ids"] = _get_paligemma_token_type_ids(imglens, seqlens, processor)
return mm_inputs
class PixtralPlugin(BasePlugin):
@override
def process_messages(
self,
messages: Sequence[Dict[str, str]],
images: Sequence["ImageInput"],
videos: Sequence["VideoInput"],
processor: Optional["ProcessorMixin"],
) -> List[Dict[str, str]]:
self._validate_input(images, videos)
patch_size = getattr(processor, "patch_size")
image_token = getattr(processor, "image_token")
image_break_token = getattr(processor, "image_break_token")
image_end_token = getattr(processor, "image_end_token")
num_image_tokens = 0
messages = deepcopy(messages)
mm_inputs = self._get_mm_inputs(images, videos, processor)
image_input_sizes = mm_inputs.get("image_sizes", None)
for message in messages:
content = message["content"]
while IMAGE_PLACEHOLDER in content:
if image_input_sizes is None:
raise ValueError("Cannot get image input sizes.")
image_size = image_input_sizes[0][num_image_tokens]
height, width = image_size
num_height_tokens = height // patch_size
num_width_tokens = width // patch_size
replace_tokens = [[image_token] * num_width_tokens + [image_break_token]] * num_height_tokens
replace_tokens = [item for sublist in replace_tokens for item in sublist] # flatten list
replace_tokens[-1] = image_end_token
replace_str = "".join(replace_tokens)
content = content.replace(IMAGE_PLACEHOLDER, replace_str, 1)
num_image_tokens += 1
message["content"] = content
if len(images) != num_image_tokens:
raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens.")
return messages
@override
def get_mm_inputs(
self,
images: Sequence["ImageInput"],
videos: Sequence["VideoInput"],
imglens: Sequence[int],
vidlens: Sequence[int],
batch_ids: Sequence[List[int]],
processor: Optional["ProcessorMixin"],
) -> Dict[str, Union[List[int], "torch.Tensor"]]:
self._validate_input(images, videos)
mm_inputs = self._get_mm_inputs(images, videos, processor)
if mm_inputs.get("pixel_values"):
mm_inputs["pixel_values"] = mm_inputs["pixel_values"][0]
mm_inputs.pop("image_sizes", None)
return mm_inputs
class Qwen2vlPlugin(BasePlugin):
@override
def _preprocess_image(self, image: "ImageObject", **kwargs) -> "ImageObject":
image = super()._preprocess_image(image, **kwargs)
if min(image.width, image.height) < 28:
width, height = max(image.width, 28), max(image.height, 28)
image = image.resize((width, height), resample=Image.NEAREST)
if image.width / image.height > 200:
width, height = image.height * 180, image.height
image = image.resize((width, height), resample=Image.NEAREST)
if image.height / image.width > 200:
width, height = image.width, image.width * 180
image = image.resize((width, height), resample=Image.NEAREST)
return image
@override
def _get_video_sample_frames(self, video_stream: "Stream", **kwargs) -> int:
sample_frames = super()._get_video_sample_frames(video_stream, **kwargs)
sample_frames = sample_frames // 2 * 2
return sample_frames
@override
def process_messages(
self,
messages: Sequence[Dict[str, str]],
images: Sequence["ImageInput"],
videos: Sequence["VideoInput"],
processor: Optional["ProcessorMixin"],
) -> List[Dict[str, str]]:
self._validate_input(images, videos)
image_processor: "BaseImageProcessor" = getattr(processor, "image_processor")
merge_length: int = getattr(image_processor, "merge_size") ** 2
mm_inputs = self._get_mm_inputs(images, videos, processor)
image_grid_thw = mm_inputs.get("image_grid_thw", [])
video_grid_thw = mm_inputs.get("video_grid_thw", [])
num_image_tokens, num_video_tokens = 0, 0
messages = deepcopy(messages)
for message in messages:
content = message["content"]
while IMAGE_PLACEHOLDER in content:
if num_image_tokens >= len(image_grid_thw):
raise ValueError(f"`len(images)` is less than the number of {IMAGE_PLACEHOLDER} tokens.")
content = content.replace(
IMAGE_PLACEHOLDER,
"<|vision_start|>{}<|vision_end|>".format(
self.image_token * (image_grid_thw[num_image_tokens].prod() // merge_length)
),
1,
)
num_image_tokens += 1
while VIDEO_PLACEHOLDER in content:
if num_video_tokens >= len(video_grid_thw):
raise ValueError(f"`len(videos)` is less than the number of {VIDEO_PLACEHOLDER} tokens.")
content = content.replace(
VIDEO_PLACEHOLDER,
"<|vision_start|>{}<|vision_end|>".format(
self.video_token * (video_grid_thw[num_video_tokens].prod() // merge_length)
),
1,
)
num_video_tokens += 1
message["content"] = content
if len(images) != num_image_tokens:
raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens.")
if len(videos) != num_video_tokens:
raise ValueError(f"The number of videos does not match the number of {VIDEO_PLACEHOLDER} tokens.")
return messages
@override
def get_mm_inputs(
self,
images: Sequence["ImageInput"],
videos: Sequence["VideoInput"],
imglens: Sequence[int],
vidlens: Sequence[int],
batch_ids: Sequence[List[int]],
processor: Optional["ProcessorMixin"],
) -> Dict[str, Union[List[int], "torch.Tensor"]]:
self._validate_input(images, videos)
return self._get_mm_inputs(images, videos, processor)
class VideoLlavaPlugin(BasePlugin):
@override
def process_messages(
self,
messages: Sequence[Dict[str, str]],
images: Sequence["ImageInput"],
videos: Sequence["VideoInput"],
processor: Optional["ProcessorMixin"],
) -> List[Dict[str, str]]:
self._validate_input(images, videos)
num_image_tokens, num_video_tokens = 0, 0
messages = deepcopy(messages)
mm_inputs = self._get_mm_inputs(images, videos, processor)
num_frames = 0
has_images = "pixel_values_images" in mm_inputs
has_videos = "pixel_values_videos" in mm_inputs
if has_images or has_videos:
if has_images:
height, width = get_image_size(to_numpy_array(mm_inputs.get("pixel_values_images")[0]))
num_frames = 1
if has_videos:
pixel_values_video = to_numpy_array(mm_inputs.get("pixel_values_videos")[0])
height, width = get_image_size(pixel_values_video[0])
num_frames = pixel_values_video.shape[0] # frame dim is always after batch dim
image_seqlen = (height // processor.patch_size) * (width // processor.patch_size) + 1
video_seqlen = image_seqlen * num_frames
if getattr(processor, "vision_feature_select_strategy") == "default":
image_seqlen -= 1
for message in messages:
content = message["content"]
while IMAGE_PLACEHOLDER in content:
num_image_tokens += 1
content = content.replace(IMAGE_PLACEHOLDER, "{{image}}" * image_seqlen, 1)
while VIDEO_PLACEHOLDER in content:
num_video_tokens += 1
content = content.replace(VIDEO_PLACEHOLDER, "{{video}}" * video_seqlen, 1)
content = content.replace("{{image}}", self.image_token)
message["content"] = content.replace("{{video}}", self.video_token)
if len(images) != num_image_tokens:
raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens.")
if len(videos) != num_video_tokens:
raise ValueError(f"The number of videos does not match the number of {VIDEO_PLACEHOLDER} tokens.")
return messages
@override
def get_mm_inputs(
self,
images: Sequence["ImageInput"],
videos: Sequence["VideoInput"],
imglens: Sequence[int],
vidlens: Sequence[int],
batch_ids: Sequence[List[int]],
processor: Optional["ProcessorMixin"],
) -> Dict[str, Union[List[int], "torch.Tensor"]]:
self._validate_input(images, videos)
return self._get_mm_inputs(images, videos, processor)
class MllamaPlugin(BasePlugin):
@override
def process_messages(
self,
messages: Sequence[Dict[str, str]],
images: Sequence["ImageInput"],
videos: Sequence["VideoInput"],
processor: Optional["ProcessorMixin"],
) -> List[Dict[str, str]]:
self._validate_input(images, videos)
num_image_tokens = 0
messages = deepcopy(messages)
for message in messages:
content = message["content"]
num_image_tokens += content.count(IMAGE_PLACEHOLDER)
message["content"] = content.replace(IMAGE_PLACEHOLDER, self.image_token)
if len(images) != num_image_tokens:
raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens.")
return messages
@override
def _get_mm_inputs(
self,
images: Sequence["ImageInput"],
videos: Sequence["VideoInput"],
processor: "ProcessorMixin",
) -> Dict[str, "torch.Tensor"]:
r"""
Processes visual inputs for mllama because its image processor only accepts List[List[ImageInput]].
Returns:
pixel_values: tensor with shape
(batch_size, max_num_images, max_image_tiles, channels, tile_height, tile_width)
For example, (2, 1, 4, 3, 560, 560).
aspect_ratio_ids: tensor with shape (batch_size, max_num_images). For example, (2, 1).
aspect_ratio_mask: tensor with shape (batch_size, max_num_images, max_image_tiles). For example, (2, 1, 4).
num_tiles: List[List[int]] with shape (batch_size, num_images_in_batch). For example, (2, 1).
"""
image_processor: "BaseImageProcessor" = getattr(processor, "image_processor")
images = self._regularize_images(images, image_resolution=getattr(processor, "image_resolution", 512 * 512))
return image_processor([[image] for image in images], return_tensors="pt")
def get_mm_inputs(
self,
images: Sequence["ImageInput"],
videos: Sequence["VideoInput"],
imglens: Sequence[int],
vidlens: Sequence[int],
batch_ids: Sequence[List[int]],
processor: Optional["ProcessorMixin"],
) -> Dict[str, Union[List[int], "torch.Tensor"]]:
self._validate_input(images, videos)
if len(images) != len(batch_ids):
raise ValueError("Mllama only supports one image per sample.")
mm_inputs = self._get_mm_inputs(images, videos, processor)
num_tiles = mm_inputs.pop("num_tiles")
image_token_id = getattr(processor, "image_token_id")
max_image_tiles = getattr(processor.image_processor, "max_image_tiles")
cross_attention_token_mask = [
get_cross_attention_token_mask(input_ids, image_token_id) for input_ids in batch_ids
]
mm_inputs["cross_attention_mask"] = convert_sparse_cross_attention_mask_to_dense(
cross_attention_token_mask,
num_tiles=num_tiles,
max_num_tiles=max_image_tiles,
length=max(len(input_ids) for input_ids in batch_ids),
)
return mm_inputs
PLUGINS = {
"base": BasePlugin,
"llava": LlavaPlugin,
"llava_next": LlavaNextPlugin,
"llava_next_video": LlavaNextVideoPlugin,
"paligemma": PaliGemmaPlugin,
"pixtral": PixtralPlugin,
"qwen2_vl": Qwen2vlPlugin,
"video_llava": VideoLlavaPlugin,
"mllama": MllamaPlugin,
}
def get_mm_plugin(
name: str,
image_token: Optional[str] = None,
video_token: Optional[str] = None,
) -> "BasePlugin":
plugin_class = PLUGINS.get(name, None)
if plugin_class is None:
raise ValueError(f"Multimodal plugin `{name}` not found.")
return plugin_class(image_token, video_token)
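To make the plugin contract above concrete, a minimal sketch of how a caller might use the registry; the placeholder string, image token and file name are illustrative assumptions, and processor stands for the model's multimodal processor loaded elsewhere:

plugin = get_mm_plugin(name="llava", image_token="<image>")
messages = plugin.process_messages(
    [{"role": "user", "content": "<image>What is shown here?"}],
    images=["photo.jpg"], videos=[], processor=processor,
)
mm_inputs = plugin.get_mm_inputs(
    images=["photo.jpg"], videos=[], imglens=[1], vidlens=[0],
    batch_ids=[[1, 2, 3]], processor=processor,
)
# for llava-style models mm_inputs would typically carry pixel_values shaped (B, C, H, W)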

View File

@@ -20,7 +20,7 @@ from typing import Any, Dict, List, Literal, Optional, Sequence
from transformers.utils import cached_file
from ..extras.constants import DATA_CONFIG
- from ..extras.misc import use_modelscope
+ from ..extras.misc import use_modelscope, use_openmind
@dataclass
@@ -30,7 +30,7 @@ class DatasetAttr:
"""
# basic configs
- load_from: Literal["hf_hub", "ms_hub", "script", "file"]
+ load_from: Literal["hf_hub", "ms_hub", "om_hub", "script", "file"]
dataset_name: str
formatting: Literal["alpaca", "sharegpt"] = "alpaca"
ranking: bool = False
@@ -43,6 +43,7 @@ class DatasetAttr:
system: Optional[str] = None
tools: Optional[str] = None
images: Optional[str] = None
+ videos: Optional[str] = None
# rlhf columns
chosen: Optional[str] = None
rejected: Optional[str] = None
@@ -86,31 +87,39 @@ def get_dataset_list(dataset_names: Optional[Sequence[str]], dataset_dir: str) -
config_path = os.path.join(dataset_dir, DATA_CONFIG)
try:
- with open(config_path, "r") as f:
+ with open(config_path) as f:
dataset_info = json.load(f)
except Exception as err:
if len(dataset_names) != 0:
- raise ValueError("Cannot open {} due to {}.".format(config_path, str(err)))
+ raise ValueError(f"Cannot open {config_path} due to {str(err)}.")
dataset_info = None
dataset_list: List["DatasetAttr"] = []
for name in dataset_names:
if dataset_info is None: # dataset_dir is ONLINE
- load_from = "ms_hub" if use_modelscope() else "hf_hub"
+ if use_modelscope():
+ load_from = "ms_hub"
+ elif use_openmind():
+ load_from = "om_hub"
+ else:
+ load_from = "hf_hub"
dataset_attr = DatasetAttr(load_from, dataset_name=name)
dataset_list.append(dataset_attr)
continue
if name not in dataset_info:
- raise ValueError("Undefined dataset {} in {}.".format(name, DATA_CONFIG))
+ raise ValueError(f"Undefined dataset {name} in {DATA_CONFIG}.")
has_hf_url = "hf_hub_url" in dataset_info[name]
has_ms_url = "ms_hub_url" in dataset_info[name]
+ has_om_url = "om_hub_url" in dataset_info[name]
- if has_hf_url or has_ms_url:
+ if has_hf_url or has_ms_url or has_om_url:
- if (use_modelscope() and has_ms_url) or (not has_hf_url):
+ if has_ms_url and (use_modelscope() or not has_hf_url):
dataset_attr = DatasetAttr("ms_hub", dataset_name=dataset_info[name]["ms_hub_url"])
+ elif has_om_url and (use_openmind() or not has_hf_url):
+ dataset_attr = DatasetAttr("om_hub", dataset_name=dataset_info[name]["om_hub_url"])
else:
dataset_attr = DatasetAttr("hf_hub", dataset_name=dataset_info[name]["hf_hub_url"])
elif "script_url" in dataset_info[name]:
@@ -126,7 +135,7 @@ def get_dataset_list(dataset_names: Optional[Sequence[str]], dataset_dir: str) -
dataset_attr.set_attr("num_samples", dataset_info[name]) dataset_attr.set_attr("num_samples", dataset_info[name])
if "columns" in dataset_info[name]: if "columns" in dataset_info[name]:
column_names = ["system", "tools", "images", "chosen", "rejected", "kto_tag"] column_names = ["system", "tools", "images", "videos", "chosen", "rejected", "kto_tag"]
if dataset_attr.formatting == "alpaca": if dataset_attr.formatting == "alpaca":
column_names.extend(["prompt", "query", "response", "history"]) column_names.extend(["prompt", "query", "response", "history"])
else: else:
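For reference, a hypothetical dataset_info.json entry (written out as a Python dict) that exercises the new fields; the dataset name, repository paths and source column names are made up, only the keys come from the parser code above:

dataset_info = {
    "demo_video_dataset": {
        "hf_hub_url": "someorg/demo_video_dataset",
        "om_hub_url": "someorg/demo_video_dataset",  # chosen when use_openmind() is active
        "formatting": "sharegpt",
        "columns": {
            "images": "images",
            "videos": "videos",  # new optional media column
            "system": "system",
        },
    }
}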

View File

@@ -50,7 +50,7 @@ def get_preprocess_and_print_func(
print_function = partial(print_unsupervised_dataset_example, tokenizer=tokenizer)
elif stage == "sft" and not do_generate:
if data_args.packing:
- if data_args.neat_packing:
+ if data_args.neat_packing: # hack datasets to have int32 attention mask
from datasets.arrow_writer import OptimizedTypedSequence, TypedSequence
def __init__(self, data, **kwargs):
@@ -67,6 +67,7 @@ def get_preprocess_and_print_func(
preprocess_packed_supervised_dataset,
template=template,
tokenizer=tokenizer,
+ processor=processor,
data_args=data_args,
)
else:

View File

@@ -12,21 +12,23 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from collections import defaultdict
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple
+ from ...extras import logging
from ...extras.constants import IGNORE_INDEX
- from ...extras.logging import get_logger
- from .processor_utils import get_paligemma_token_type_ids, get_pixel_values, infer_seqlen
+ from .processor_utils import infer_seqlen
if TYPE_CHECKING:
from transformers import PreTrainedTokenizer, ProcessorMixin
from ...hparams import DataArguments
+ from ..mm_plugin import ImageInput, VideoInput
from ..template import Template
- logger = get_logger(__name__)
+ logger = logging.get_logger(__name__)
def _encode_feedback_example(
@@ -35,14 +37,13 @@ def _encode_feedback_example(
kl_response: Sequence[Dict[str, str]],
system: Optional[str],
tools: Optional[str],
+ images: Sequence["ImageInput"],
+ videos: Sequence["VideoInput"],
template: "Template",
tokenizer: "PreTrainedTokenizer",
processor: Optional["ProcessorMixin"],
- data_args: "DataArguments",
+ cutoff_len: int,
) -> Tuple[List[int], List[int], List[int], List[int], bool]:
- if processor is not None and not hasattr(processor, "image_seq_length"): # llava-like models
- prompt[0]["content"] = template.image_token + prompt[0]["content"]
if response[0]["content"]: # desired example
kto_tag = True
messages = prompt + [response[0]]
@@ -55,6 +56,8 @@ def _encode_feedback_example(
else: else:
kl_messages = prompt + [kl_response[1]] kl_messages = prompt + [kl_response[1]]
messages = template.mm_plugin.process_messages(messages, images, videos, processor)
kl_messages = template.mm_plugin.process_messages(kl_messages, images, videos, processor)
prompt_ids, response_ids = template.encode_oneturn(tokenizer, messages, system, tools) prompt_ids, response_ids = template.encode_oneturn(tokenizer, messages, system, tools)
kl_prompt_ids, kl_response_ids = template.encode_oneturn(tokenizer, kl_messages, system, tools) kl_prompt_ids, kl_response_ids = template.encode_oneturn(tokenizer, kl_messages, system, tools)
@@ -62,15 +65,13 @@ def _encode_feedback_example(
response_ids += [tokenizer.eos_token_id] response_ids += [tokenizer.eos_token_id]
kl_response_ids += [tokenizer.eos_token_id] kl_response_ids += [tokenizer.eos_token_id]
if processor is not None and hasattr(processor, "image_seq_length"): # paligemma models prompt_ids, _ = template.mm_plugin.process_token_ids(prompt_ids, None, images, videos, tokenizer, processor)
image_token_id = tokenizer.convert_tokens_to_ids(template.image_token) kl_prompt_ids, _ = template.mm_plugin.process_token_ids(kl_prompt_ids, None, images, videos, tokenizer, processor)
prompt_ids = [image_token_id] * getattr(processor, "image_seq_length") + prompt_ids
kl_prompt_ids = [image_token_id] * getattr(processor, "image_seq_length") + kl_prompt_ids
source_len, target_len = infer_seqlen(len(prompt_ids), len(response_ids), data_args.cutoff_len) source_len, target_len = infer_seqlen(len(prompt_ids), len(response_ids), cutoff_len)
prompt_ids = prompt_ids[:source_len] prompt_ids = prompt_ids[:source_len]
response_ids = response_ids[:target_len] response_ids = response_ids[:target_len]
kl_source_len, kl_target_len = infer_seqlen(len(kl_prompt_ids), len(kl_response_ids), data_args.cutoff_len) kl_source_len, kl_target_len = infer_seqlen(len(kl_prompt_ids), len(kl_response_ids), cutoff_len)
kl_prompt_ids = kl_prompt_ids[:kl_source_len] kl_prompt_ids = kl_prompt_ids[:kl_source_len]
kl_response_ids = kl_response_ids[:kl_target_len] kl_response_ids = kl_response_ids[:kl_target_len]
@@ -78,7 +79,6 @@ def _encode_feedback_example(
labels = [IGNORE_INDEX] * source_len + response_ids labels = [IGNORE_INDEX] * source_len + response_ids
kl_input_ids = kl_prompt_ids + kl_response_ids kl_input_ids = kl_prompt_ids + kl_response_ids
kl_labels = [IGNORE_INDEX] * kl_source_len + kl_response_ids kl_labels = [IGNORE_INDEX] * kl_source_len + kl_response_ids
return input_ids, labels, kl_input_ids, kl_labels, kto_tag return input_ids, labels, kl_input_ids, kl_labels, kto_tag
@@ -88,39 +88,29 @@ def preprocess_feedback_dataset(
tokenizer: "PreTrainedTokenizer", tokenizer: "PreTrainedTokenizer",
processor: Optional["ProcessorMixin"], processor: Optional["ProcessorMixin"],
data_args: "DataArguments", data_args: "DataArguments",
) -> Dict[str, List[List[int]]]: ) -> Dict[str, List[Any]]:
# create unrelated input-output pairs for estimating the KL term by flipping the matched pairs # create unrelated input-output pairs for estimating the KL term by flipping the matched pairs
kl_response = examples["response"][::-1] kl_response = examples["_response"][::-1]
model_inputs = { model_inputs = defaultdict(list)
"input_ids": [], for i in range(len(examples["_prompt"])):
"attention_mask": [], if len(examples["_prompt"][i]) % 2 != 1 or len(examples["_response"][i]) < 2:
"labels": [], logger.warning_rank0(
"kl_input_ids": [], "Dropped invalid example: {}".format(examples["_prompt"][i] + examples["_response"][i])
"kl_attention_mask": [], )
"kl_labels": [],
"kto_tags": [],
}
if processor is not None:
model_inputs["pixel_values"] = []
if hasattr(processor, "image_seq_length"): # paligemma models
model_inputs["token_type_ids"] = []
model_inputs["kl_token_type_ids"] = []
for i in range(len(examples["prompt"])):
if len(examples["prompt"][i]) % 2 != 1 or len(examples["response"][i]) < 2:
logger.warning("Dropped invalid example: {}".format(examples["prompt"][i] + examples["response"][i]))
continue continue
input_ids, labels, kl_input_ids, kl_labels, kto_tag = _encode_feedback_example( input_ids, labels, kl_input_ids, kl_labels, kto_tag = _encode_feedback_example(
prompt=examples["prompt"][i], prompt=examples["_prompt"][i],
response=examples["response"][i], response=examples["_response"][i],
kl_response=kl_response[i], kl_response=kl_response[i],
system=examples["system"][i], system=examples["_system"][i],
tools=examples["tools"][i], tools=examples["_tools"][i],
images=examples["_images"][i] or [],
videos=examples["_videos"][i] or [],
template=template, template=template,
tokenizer=tokenizer, tokenizer=tokenizer,
processor=processor, processor=processor,
data_args=data_args, cutoff_len=data_args.cutoff_len,
) )
model_inputs["input_ids"].append(input_ids) model_inputs["input_ids"].append(input_ids)
model_inputs["attention_mask"].append([1] * len(input_ids)) model_inputs["attention_mask"].append([1] * len(input_ids))
@@ -129,15 +119,12 @@ def preprocess_feedback_dataset(
model_inputs["kl_attention_mask"].append([1] * len(kl_input_ids)) model_inputs["kl_attention_mask"].append([1] * len(kl_input_ids))
model_inputs["kl_labels"].append(kl_labels) model_inputs["kl_labels"].append(kl_labels)
model_inputs["kto_tags"].append(kto_tag) model_inputs["kto_tags"].append(kto_tag)
if processor is not None: model_inputs["images"].append(examples["_images"][i])
model_inputs["pixel_values"].append(get_pixel_values(examples["images"][i], processor)) model_inputs["videos"].append(examples["_videos"][i])
if hasattr(processor, "image_seq_length"): # paligemma models
model_inputs["token_type_ids"].append(get_paligemma_token_type_ids(len(input_ids), processor))
model_inputs["kl_token_type_ids"].append(get_paligemma_token_type_ids(len(kl_input_ids), processor))
desirable_num = sum([1 for tag in model_inputs["kto_tags"] if tag]) desirable_num = sum([1 for tag in model_inputs["kto_tags"] if tag])
undesirable_num = len(model_inputs["kto_tags"]) - desirable_num undesirable_num = len(model_inputs["kto_tags"]) - desirable_num
if desirable_num == 0 or undesirable_num == 0: if desirable_num == 0 or undesirable_num == 0:
logger.warning("Your dataset only has one preference type.") logger.warning_rank0("Your dataset only has one preference type.")
return model_inputs return model_inputs
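As a small illustration of the `[::-1]` trick used in `preprocess_feedback_dataset`: reversing the batch of responses pairs each prompt with another example's response, which yields the mismatched pairs used to estimate the KL reference term in KTO. The data below is made up.

prompts = ["p0", "p1", "p2", "p3"]
responses = ["r0", "r1", "r2", "r3"]

kl_responses = responses[::-1]  # ["r3", "r2", "r1", "r0"]
for prompt, response, kl_response in zip(prompts, responses, kl_responses):
    # (prompt, response) is the matched pair; (prompt, kl_response) is an
    # unrelated pair used only for the KL estimate.
    print(prompt, response, kl_response)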

View File

@@ -12,21 +12,23 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+from collections import defaultdict
 from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple

+from ...extras import logging
 from ...extras.constants import IGNORE_INDEX
-from ...extras.logging import get_logger
-from .processor_utils import get_paligemma_token_type_ids, get_pixel_values, infer_seqlen
+from .processor_utils import infer_seqlen

 if TYPE_CHECKING:
     from transformers import PreTrainedTokenizer, ProcessorMixin

     from ...hparams import DataArguments
+    from ..mm_plugin import ImageInput, VideoInput
     from ..template import Template

-logger = get_logger(__name__)
+logger = logging.get_logger(__name__)

 def _encode_pairwise_example(
@@ -34,16 +36,15 @@ def _encode_pairwise_example(
     response: Sequence[Dict[str, str]],
     system: Optional[str],
     tools: Optional[str],
+    images: Sequence["ImageInput"],
+    videos: Sequence["VideoInput"],
     template: "Template",
     tokenizer: "PreTrainedTokenizer",
     processor: Optional["ProcessorMixin"],
-    data_args: "DataArguments",
+    cutoff_len: int,
 ) -> Tuple[List[int], List[int], List[int], List[int]]:
-    if processor is not None and not hasattr(processor, "image_seq_length"):  # llava-like models
-        prompt[0]["content"] = template.image_token + prompt[0]["content"]
-    chosen_messages = prompt + [response[0]]
-    rejected_messages = prompt + [response[1]]
+    chosen_messages = template.mm_plugin.process_messages(prompt + [response[0]], images, videos, processor)
+    rejected_messages = template.mm_plugin.process_messages(prompt + [response[1]], images, videos, processor)
     prompt_ids, chosen_ids = template.encode_oneturn(tokenizer, chosen_messages, system, tools)
     _, rejected_ids = template.encode_oneturn(tokenizer, rejected_messages, system, tools)
@@ -51,13 +52,9 @@ def _encode_pairwise_example(
         chosen_ids += [tokenizer.eos_token_id]
         rejected_ids += [tokenizer.eos_token_id]

-    if processor is not None and hasattr(processor, "image_seq_length"):  # paligemma models
-        image_token_id = tokenizer.convert_tokens_to_ids(template.image_token)
-        prompt_ids = [image_token_id] * getattr(processor, "image_seq_length") + prompt_ids
-    source_len, target_len = infer_seqlen(
-        len(prompt_ids), max(len(chosen_ids), len(rejected_ids)), data_args.cutoff_len
-    )  # consider the response is more important
+    prompt_ids, _ = template.mm_plugin.process_token_ids(prompt_ids, None, images, videos, tokenizer, processor)
+    # consider the response is more important
+    source_len, target_len = infer_seqlen(len(prompt_ids), max(len(chosen_ids), len(rejected_ids)), cutoff_len)
     prompt_ids = prompt_ids[:source_len]
     chosen_ids = chosen_ids[:target_len]
     rejected_ids = rejected_ids[:target_len]
@@ -66,7 +63,6 @@ def _encode_pairwise_example(
     chosen_labels = [IGNORE_INDEX] * source_len + chosen_ids
     rejected_input_ids = prompt_ids + rejected_ids
     rejected_labels = [IGNORE_INDEX] * source_len + rejected_ids
     return chosen_input_ids, chosen_labels, rejected_input_ids, rejected_labels
@@ -76,36 +72,27 @@ def preprocess_pairwise_dataset(
     tokenizer: "PreTrainedTokenizer",
     processor: Optional["ProcessorMixin"],
     data_args: "DataArguments",
-) -> Dict[str, List[List[int]]]:
+) -> Dict[str, List[Any]]:
     # build input pairs with format `<bos> X`, `Y1 <eos>` and `Y2 <eos>`
-    model_inputs = {
-        "chosen_input_ids": [],
-        "chosen_attention_mask": [],
-        "chosen_labels": [],
-        "rejected_input_ids": [],
-        "rejected_attention_mask": [],
-        "rejected_labels": [],
-    }
-    if processor is not None:
-        model_inputs["pixel_values"] = []
-        if hasattr(processor, "image_seq_length"):  # paligemma models
-            model_inputs["chosen_token_type_ids"] = []
-            model_inputs["rejected_token_type_ids"] = []
-    for i in range(len(examples["prompt"])):
-        if len(examples["prompt"][i]) % 2 != 1 or len(examples["response"][i]) < 2:
-            logger.warning("Dropped invalid example: {}".format(examples["prompt"][i] + examples["response"][i]))
+    model_inputs = defaultdict(list)
+    for i in range(len(examples["_prompt"])):
+        if len(examples["_prompt"][i]) % 2 != 1 or len(examples["_response"][i]) < 2:
+            logger.warning_rank0(
+                "Dropped invalid example: {}".format(examples["_prompt"][i] + examples["_response"][i])
+            )
             continue

         chosen_input_ids, chosen_labels, rejected_input_ids, rejected_labels = _encode_pairwise_example(
-            prompt=examples["prompt"][i],
-            response=examples["response"][i],
-            system=examples["system"][i],
-            tools=examples["tools"][i],
+            prompt=examples["_prompt"][i],
+            response=examples["_response"][i],
+            system=examples["_system"][i],
+            tools=examples["_tools"][i],
+            images=examples["_images"][i] or [],
+            videos=examples["_videos"][i] or [],
             template=template,
             tokenizer=tokenizer,
             processor=processor,
-            data_args=data_args,
+            cutoff_len=data_args.cutoff_len,
         )
         model_inputs["chosen_input_ids"].append(chosen_input_ids)
         model_inputs["chosen_attention_mask"].append([1] * len(chosen_input_ids))
@@ -113,15 +100,8 @@ def preprocess_pairwise_dataset(
         model_inputs["rejected_input_ids"].append(rejected_input_ids)
         model_inputs["rejected_attention_mask"].append([1] * len(rejected_input_ids))
         model_inputs["rejected_labels"].append(rejected_labels)
-        if processor is not None:
-            model_inputs["pixel_values"].append(get_pixel_values(examples["images"][i], processor))
-            if hasattr(processor, "image_seq_length"):  # paligemma models
-                model_inputs["chosen_token_type_ids"].append(
-                    get_paligemma_token_type_ids(len(chosen_input_ids), processor)
-                )
-                model_inputs["rejected_token_type_ids"].append(
-                    get_paligemma_token_type_ids(len(rejected_input_ids), processor)
-                )
+        model_inputs["images"].append(examples["_images"][i])
+        model_inputs["videos"].append(examples["_videos"][i])

     return model_inputs
@@ -132,8 +112,8 @@ def print_pairwise_dataset_example(example: Dict[str, List[int]], tokenizer: "Pr
     print("chosen_input_ids:\n{}".format(example["chosen_input_ids"]))
     print("chosen_inputs:\n{}".format(tokenizer.decode(example["chosen_input_ids"], skip_special_tokens=False)))
     print("chosen_label_ids:\n{}".format(example["chosen_labels"]))
-    print("chosen_labels:\n{}".format(tokenizer.decode(valid_chosen_labels, skip_special_tokens=False)))
+    print(f"chosen_labels:\n{tokenizer.decode(valid_chosen_labels, skip_special_tokens=False)}")
     print("rejected_input_ids:\n{}".format(example["rejected_input_ids"]))
     print("rejected_inputs:\n{}".format(tokenizer.decode(example["rejected_input_ids"], skip_special_tokens=False)))
     print("rejected_label_ids:\n{}".format(example["rejected_labels"]))
-    print("rejected_labels:\n{}".format(tokenizer.decode(valid_rejected_labels, skip_special_tokens=False)))
+    print(f"rejected_labels:\n{tokenizer.decode(valid_rejected_labels, skip_special_tokens=False)}")

View File

@@ -27,16 +27,16 @@ if TYPE_CHECKING:
 def preprocess_pretrain_dataset(
     examples: Dict[str, List[Any]], tokenizer: "PreTrainedTokenizer", data_args: "DataArguments"
-) -> Dict[str, List[List[int]]]:
+) -> Dict[str, List[Any]]:
     # build grouped texts with format `X1 X2 X3 ...` if packing is enabled
     eos_token = "<|end_of_text|>" if data_args.template == "llama3" else tokenizer.eos_token
-    text_examples = [messages[0]["content"] + eos_token for messages in examples["prompt"]]
+    text_examples = [messages[0]["content"] + eos_token for messages in examples["_prompt"]]

     if not data_args.packing:
         if data_args.template == "gemma":
             text_examples = [tokenizer.bos_token + example for example in text_examples]

-        result = tokenizer(text_examples, add_special_tokens=False, max_length=data_args.cutoff_len, truncation=True)
+        result = tokenizer(text_examples, add_special_tokens=False, truncation=True, max_length=data_args.cutoff_len)
     else:
         tokenized_examples = tokenizer(text_examples, add_special_tokens=False)
         concatenated_examples = {k: list(chain(*tokenized_examples[k])) for k in tokenized_examples.keys()}
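For the packing branch above, the grouping idea is roughly the sketch below: concatenate every tokenized text and cut the stream into fixed `cutoff_len` blocks, dropping the ragged tail. The helper name is illustrative, not from the repository.

from itertools import chain

def group_texts(token_lists, cutoff_len):
    # concatenate all examples, then split into fixed-size blocks
    concatenated = list(chain(*token_lists))
    total_length = (len(concatenated) // cutoff_len) * cutoff_len
    return [concatenated[i : i + cutoff_len] for i in range(0, total_length, cutoff_len)]

print(group_texts([[1, 2, 3], [4, 5], [6, 7, 8, 9]], cutoff_len=4))
# [[1, 2, 3, 4], [5, 6, 7, 8]]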

View File

@@ -13,20 +13,7 @@
 # limitations under the License.

 import bisect
-from typing import TYPE_CHECKING, List, Sequence, Tuple
-
-from ...extras.packages import is_pillow_available
-
-if is_pillow_available():
-    from PIL import Image
-
-if TYPE_CHECKING:
-    from numpy.typing import NDArray
-    from PIL.Image import Image as ImageObject
-    from transformers import ProcessorMixin
-    from transformers.image_processing_utils import BaseImageProcessor
+from typing import List, Sequence, Tuple

 def search_for_fit(numbers: Sequence[int], capacity: int) -> int:
@@ -61,23 +48,6 @@ def greedy_knapsack(numbers: List[int], capacity: int) -> List[List[int]]:
     return knapsacks

-def get_pixel_values(images: Sequence["ImageObject"], processor: "ProcessorMixin") -> "NDArray":
-    r"""
-    Processes visual inputs. (currently only supports a single image)
-    """
-    image_processor: "BaseImageProcessor" = getattr(processor, "image_processor")
-    image = images[0] if len(images) != 0 else Image.new("RGB", (100, 100), (255, 255, 255))
-    return image_processor(image, return_tensors="pt")["pixel_values"][0]  # shape (C, H, W)
-
-def get_paligemma_token_type_ids(input_len: int, processor: "ProcessorMixin") -> List[int]:
-    r"""
-    Gets paligemma token type ids for computing loss.
-    """
-    image_seq_length = getattr(processor, "image_seq_length")
-    return [0] * image_seq_length + [1] * (input_len - image_seq_length)

 def infer_seqlen(source_len: int, target_len: int, cutoff_len: int) -> Tuple[int, int]:
     r"""
     Computes the real sequence length after truncation by the cutoff_len.
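For context, a self-contained sketch of the two helpers this file keeps after the cleanup: `infer_seqlen` splits a length budget between source and target (here simply in proportion to their lengths, which approximates the real rule), and `greedy_knapsack` bins example lengths into groups of at most `capacity` using bisection to find the largest length that still fits.

import bisect
from typing import List, Tuple

def infer_seqlen(source_len: int, target_len: int, cutoff_len: int) -> Tuple[int, int]:
    # proportional split of cutoff_len between source and target (approximation)
    if source_len + target_len <= cutoff_len:
        return source_len, target_len
    new_target_len = max(int(cutoff_len * target_len / (source_len + target_len)), 1) if target_len > 0 else 0
    return cutoff_len - new_target_len, new_target_len

def greedy_knapsack(numbers: List[int], capacity: int) -> List[List[int]]:
    numbers = sorted(numbers)
    knapsacks = []
    while numbers:
        remaining, current = capacity, []
        while numbers:
            index = bisect.bisect_right(numbers, remaining) - 1  # largest length that still fits
            if index < 0:
                break
            remaining -= numbers[index]
            current.append(numbers.pop(index))
        knapsacks.append(current)
    return knapsacks

print(greedy_knapsack([5, 3, 3, 2, 1], capacity=6))  # [[5, 1], [3, 3], [2]]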

View File

@@ -15,19 +15,20 @@
 from collections import defaultdict
 from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple

+from ...extras import logging
 from ...extras.constants import IGNORE_INDEX
-from ...extras.logging import get_logger
-from .processor_utils import get_paligemma_token_type_ids, get_pixel_values, greedy_knapsack, infer_seqlen
+from .processor_utils import greedy_knapsack, infer_seqlen

 if TYPE_CHECKING:
     from transformers import PreTrainedTokenizer, ProcessorMixin

     from ...hparams import DataArguments
+    from ..mm_plugin import ImageInput, VideoInput
     from ..template import Template

-logger = get_logger(__name__)
+logger = logging.get_logger(__name__)

 def _encode_supervised_example(
@@ -35,45 +36,47 @@ def _encode_supervised_example(
     response: Sequence[Dict[str, str]],
     system: Optional[str],
     tools: Optional[str],
+    images: Sequence["ImageInput"],
+    videos: Sequence["VideoInput"],
     template: "Template",
     tokenizer: "PreTrainedTokenizer",
     processor: Optional["ProcessorMixin"],
-    data_args: "DataArguments",
+    cutoff_len: int,
+    train_on_prompt: bool,
+    mask_history: bool,
 ) -> Tuple[List[int], List[int]]:
-    if processor is not None and not hasattr(processor, "image_seq_length"):  # llava-like models
-        prompt[0]["content"] = template.image_token + prompt[0]["content"]
-    messages = prompt + response
-    input_ids, labels = [], []
-
-    if processor is not None and hasattr(processor, "image_seq_length"):  # paligemma models
-        image_token_id = tokenizer.convert_tokens_to_ids(template.image_token)
-        input_ids += [image_token_id] * getattr(processor, "image_seq_length")
-        labels += [IGNORE_INDEX] * getattr(processor, "image_seq_length")
-
+    messages = template.mm_plugin.process_messages(prompt + response, images, videos, processor)
+    input_ids, labels = template.mm_plugin.process_token_ids([], [], images, videos, tokenizer, processor)
     encoded_pairs = template.encode_multiturn(tokenizer, messages, system, tools)
-    total_length = 1 if template.efficient_eos else 0
+    total_length = len(input_ids) + (1 if template.efficient_eos else 0)
+    if mask_history:
+        encoded_pairs = encoded_pairs[::-1]  # high priority for last turns
+
     for turn_idx, (source_ids, target_ids) in enumerate(encoded_pairs):
-        if total_length >= data_args.cutoff_len:
+        if total_length >= cutoff_len:
             break

-        source_len, target_len = infer_seqlen(len(source_ids), len(target_ids), data_args.cutoff_len - total_length)
+        source_len, target_len = infer_seqlen(len(source_ids), len(target_ids), cutoff_len - total_length)
         source_ids = source_ids[:source_len]
         target_ids = target_ids[:target_len]
         total_length += source_len + target_len

-        if data_args.train_on_prompt:
+        if train_on_prompt:
             source_label = source_ids
-        elif turn_idx != 0 and template.efficient_eos:
+        elif template.efficient_eos:
             source_label = [tokenizer.eos_token_id] + [IGNORE_INDEX] * (source_len - 1)
         else:
             source_label = [IGNORE_INDEX] * source_len

-        if data_args.mask_history and turn_idx != len(encoded_pairs) - 1:
+        if mask_history and turn_idx != 0:  # train on the last turn only
             target_label = [IGNORE_INDEX] * target_len
         else:
             target_label = target_ids

-        input_ids += source_ids + target_ids
-        labels += source_label + target_label
+        if mask_history:  # reversed sequences
+            input_ids = source_ids + target_ids + input_ids
+            labels = source_label + target_label + labels
+        else:
+            input_ids += source_ids + target_ids
+            labels += source_label + target_label
@@ -90,37 +93,36 @@ def preprocess_supervised_dataset(
     tokenizer: "PreTrainedTokenizer",
     processor: Optional["ProcessorMixin"],
     data_args: "DataArguments",
-) -> Dict[str, List[List[int]]]:
+) -> Dict[str, List[Any]]:
     # build inputs with format `<bos> X Y <eos>` and labels with format `<ignore> ... <ignore> Y <eos>`
     # for multiturn examples, we only mask the prompt part in each prompt-response pair.
-    model_inputs = {"input_ids": [], "attention_mask": [], "labels": []}
-    if processor is not None:
-        model_inputs["pixel_values"] = []
-        if hasattr(processor, "image_seq_length"):  # paligemma models
-            model_inputs["token_type_ids"] = []
-
-    for i in range(len(examples["prompt"])):
-        if len(examples["prompt"][i]) % 2 != 1 or len(examples["response"][i]) != 1:
-            logger.warning("Dropped invalid example: {}".format(examples["prompt"][i] + examples["response"][i]))
+    model_inputs = defaultdict(list)
+    for i in range(len(examples["_prompt"])):
+        if len(examples["_prompt"][i]) % 2 != 1 or len(examples["_response"][i]) != 1:
+            logger.warning_rank0(
+                "Dropped invalid example: {}".format(examples["_prompt"][i] + examples["_response"][i])
+            )
             continue

         input_ids, labels = _encode_supervised_example(
-            prompt=examples["prompt"][i],
-            response=examples["response"][i],
-            system=examples["system"][i],
-            tools=examples["tools"][i],
+            prompt=examples["_prompt"][i],
+            response=examples["_response"][i],
+            system=examples["_system"][i],
+            tools=examples["_tools"][i],
+            images=examples["_images"][i] or [],
+            videos=examples["_videos"][i] or [],
             template=template,
             tokenizer=tokenizer,
             processor=processor,
-            data_args=data_args,
+            cutoff_len=data_args.cutoff_len,
+            train_on_prompt=data_args.train_on_prompt,
+            mask_history=data_args.mask_history,
         )
         model_inputs["input_ids"].append(input_ids)
         model_inputs["attention_mask"].append([1] * len(input_ids))
         model_inputs["labels"].append(labels)
-        if processor is not None:
-            model_inputs["pixel_values"].append(get_pixel_values(examples["images"][i], processor))
-            if hasattr(processor, "image_seq_length"):  # paligemma models
-                model_inputs["token_type_ids"].append(get_paligemma_token_type_ids(len(input_ids), processor))
+        model_inputs["images"].append(examples["_images"][i])
+        model_inputs["videos"].append(examples["_videos"][i])

     return model_inputs
@@ -129,47 +131,60 @@ def preprocess_packed_supervised_dataset(
     examples: Dict[str, List[Any]],
     template: "Template",
     tokenizer: "PreTrainedTokenizer",
+    processor: Optional["ProcessorMixin"],
     data_args: "DataArguments",
-) -> Dict[str, List[List[int]]]:
+) -> Dict[str, List[Any]]:
+    # TODO: use `position_ids` to achieve packing
     # build inputs with format `<bos> X1 Y1 <eos> <bos> X2 Y2 <eos>`
     # and labels with format `<ignore> ... <ignore> Y1 <eos> <ignore> ... <ignore> Y2 <eos>`
     valid_num = 0
-    batch_input_ids, batch_labels = [], []
+    batch_input_ids, batch_labels, batch_images, batch_videos = [], [], [], []
     lengths = []
     length2indexes = defaultdict(list)
-    for i in range(len(examples["prompt"])):
-        if len(examples["prompt"][i]) % 2 != 1 or len(examples["response"][i]) != 1:
-            logger.warning("Dropped invalid example: {}".format(examples["prompt"][i] + examples["response"][i]))
+    for i in range(len(examples["_prompt"])):
+        if len(examples["_prompt"][i]) % 2 != 1 or len(examples["_response"][i]) != 1:
+            logger.warning_rank0(
+                "Dropped invalid example: {}".format(examples["_prompt"][i] + examples["_response"][i])
+            )
             continue

         input_ids, labels = _encode_supervised_example(
-            prompt=examples["prompt"][i],
-            response=examples["response"][i],
-            system=examples["system"][i],
-            tools=examples["tools"][i],
+            prompt=examples["_prompt"][i],
+            response=examples["_response"][i],
+            system=examples["_system"][i],
+            tools=examples["_tools"][i],
+            images=examples["_images"][i] or [],
+            videos=examples["_videos"][i] or [],
             template=template,
             tokenizer=tokenizer,
-            processor=None,
-            data_args=data_args,
+            processor=processor,
+            cutoff_len=data_args.cutoff_len - 1,  # reserved for the padding token
+            train_on_prompt=data_args.train_on_prompt,
+            mask_history=data_args.mask_history,
         )
         length = len(input_ids)
         if length > data_args.cutoff_len:
-            logger.warning("Dropped lengthy example with length {} > {}.".format(length, data_args.cutoff_len))
+            logger.warning_rank0(f"Dropped lengthy example with length {length} > {data_args.cutoff_len}.")
         else:
             lengths.append(length)
             length2indexes[length].append(valid_num)
             batch_input_ids.append(input_ids)
             batch_labels.append(labels)
+            batch_images.append(examples["_images"][i] or [])
+            batch_videos.append(examples["_videos"][i] or [])
             valid_num += 1

-    model_inputs = {"input_ids": [], "attention_mask": [], "labels": []}
-    knapsacks = greedy_knapsack(lengths, data_args.cutoff_len)
+    model_inputs = defaultdict(list)
+    knapsacks = greedy_knapsack(lengths, data_args.cutoff_len - 1)  # reserved for the padding token
    for knapsack in knapsacks:
         packed_input_ids, packed_attention_masks, packed_labels = [], [], []
+        packed_images, packed_videos = [], []
         for i, length in enumerate(knapsack):
             index = length2indexes[length].pop()
             packed_input_ids += batch_input_ids[index]
             packed_labels += batch_labels[index]
+            packed_images += batch_images[index]
+            packed_videos += batch_videos[index]
             if data_args.neat_packing:
                 packed_attention_masks += [i + 1] * len(batch_input_ids[index])  # start from 1
             else:
@@ -190,6 +205,8 @@ def preprocess_packed_supervised_dataset(
         model_inputs["input_ids"].append(packed_input_ids)
         model_inputs["attention_mask"].append(packed_attention_masks)
         model_inputs["labels"].append(packed_labels)
+        model_inputs["images"].append(packed_images or None)
+        model_inputs["videos"].append(packed_videos or None)

     return model_inputs
@@ -199,4 +216,4 @@ def print_supervised_dataset_example(example: Dict[str, List[int]], tokenizer: "
     print("input_ids:\n{}".format(example["input_ids"]))
     print("inputs:\n{}".format(tokenizer.decode(example["input_ids"], skip_special_tokens=False)))
     print("label_ids:\n{}".format(example["labels"]))
-    print("labels:\n{}".format(tokenizer.decode(valid_labels, skip_special_tokens=False)))
+    print(f"labels:\n{tokenizer.decode(valid_labels, skip_special_tokens=False)}")

View File

@@ -12,21 +12,23 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+from collections import defaultdict
 from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple

-from ...extras.logging import get_logger
+from ...extras import logging
 from ..data_utils import Role
-from .processor_utils import get_paligemma_token_type_ids, get_pixel_values, infer_seqlen
+from .processor_utils import infer_seqlen

 if TYPE_CHECKING:
     from transformers import PreTrainedTokenizer, ProcessorMixin

     from ...hparams import DataArguments
+    from ..mm_plugin import ImageInput, VideoInput
     from ..template import Template

-logger = get_logger(__name__)
+logger = logging.get_logger(__name__)

 def _encode_unsupervised_example(
@@ -34,28 +36,25 @@ def _encode_unsupervised_example(
     response: Sequence[Dict[str, str]],
     system: Optional[str],
     tools: Optional[str],
+    images: Sequence["ImageInput"],
+    videos: Sequence["VideoInput"],
     template: "Template",
     tokenizer: "PreTrainedTokenizer",
     processor: Optional["ProcessorMixin"],
-    data_args: "DataArguments",
+    cutoff_len: int,
 ) -> Tuple[List[int], List[int]]:
-    if processor is not None and not hasattr(processor, "image_seq_length"):  # llava-like models
-        prompt[0]["content"] = template.image_token + prompt[0]["content"]
     if len(response) == 1:
         messages = prompt + response
     else:
         messages = prompt + [{"role": Role.ASSISTANT.value, "content": ""}]

+    messages = template.mm_plugin.process_messages(messages, images, videos, processor)
     input_ids, labels = template.encode_oneturn(tokenizer, messages, system, tools)
     if template.efficient_eos:
         labels += [tokenizer.eos_token_id]

-    if processor is not None and hasattr(processor, "image_seq_length"):  # paligemma models
-        image_token_id = tokenizer.convert_tokens_to_ids(template.image_token)
-        input_ids = [image_token_id] * getattr(processor, "image_seq_length") + input_ids
-
-    source_len, target_len = infer_seqlen(len(input_ids), len(labels), data_args.cutoff_len)
+    input_ids, _ = template.mm_plugin.process_token_ids(input_ids, None, images, videos, tokenizer, processor)
+    source_len, target_len = infer_seqlen(len(input_ids), len(labels), cutoff_len)
     input_ids = input_ids[:source_len]
     labels = labels[:target_len]
     return input_ids, labels
@@ -67,36 +66,33 @@ def preprocess_unsupervised_dataset(
     tokenizer: "PreTrainedTokenizer",
     processor: Optional["ProcessorMixin"],
     data_args: "DataArguments",
-) -> Dict[str, List[List[int]]]:
+) -> Dict[str, List[Any]]:
     # build inputs with format `<bos> X` and labels with format `Y <eos>`
-    model_inputs = {"input_ids": [], "attention_mask": [], "labels": []}
-    if processor is not None:
-        model_inputs["pixel_values"] = []
-        if hasattr(processor, "image_seq_length"):  # paligemma models
-            model_inputs["token_type_ids"] = []
-
-    for i in range(len(examples["prompt"])):
-        if len(examples["prompt"][i]) % 2 != 1:
-            logger.warning("Dropped invalid example: {}".format(examples["prompt"][i] + examples["response"][i]))
+    model_inputs = defaultdict(list)
+    for i in range(len(examples["_prompt"])):
+        if len(examples["_prompt"][i]) % 2 != 1:
+            logger.warning_rank0(
+                "Dropped invalid example: {}".format(examples["_prompt"][i] + examples["_response"][i])
+            )
             continue

         input_ids, labels = _encode_unsupervised_example(
-            prompt=examples["prompt"][i],
-            response=examples["response"][i],
-            system=examples["system"][i],
-            tools=examples["tools"][i],
+            prompt=examples["_prompt"][i],
+            response=examples["_response"][i],
+            system=examples["_system"][i],
+            tools=examples["_tools"][i],
+            images=examples["_images"][i] or [],
+            videos=examples["_videos"][i] or [],
             template=template,
             tokenizer=tokenizer,
             processor=processor,
-            data_args=data_args,
+            cutoff_len=data_args.cutoff_len,
         )
         model_inputs["input_ids"].append(input_ids)
         model_inputs["attention_mask"].append([1] * len(input_ids))
         model_inputs["labels"].append(labels)
-        if processor is not None:
-            model_inputs["pixel_values"].append(get_pixel_values(examples["images"][i], processor))
-            if hasattr(processor, "image_seq_length"):  # paligemma models
-                model_inputs["token_type_ids"].append(get_paligemma_token_type_ids(len(input_ids), processor))
+        model_inputs["images"].append(examples["_images"][i])
+        model_inputs["videos"].append(examples["_videos"][i])

     return model_inputs

View File

@@ -15,18 +15,24 @@
 from dataclasses import dataclass
 from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Tuple, Union

-from ..extras.logging import get_logger
+from transformers.utils.versions import require_version
+from typing_extensions import override
+
+from ..extras import logging
 from .data_utils import Role
 from .formatter import EmptyFormatter, FunctionFormatter, StringFormatter, ToolFormatter
+from .mm_plugin import get_mm_plugin

 if TYPE_CHECKING:
     from transformers import PreTrainedTokenizer

+    from ..hparams import DataArguments
     from .formatter import SLOTS, Formatter
+    from .mm_plugin import BasePlugin

-logger = get_logger(__name__)
+logger = logging.get_logger(__name__)

 @dataclass
@@ -41,9 +47,10 @@ class Template:
     format_prefix: "Formatter"
     default_system: str
     stop_words: List[str]
-    image_token: str
     efficient_eos: bool
     replace_eos: bool
+    replace_jinja_template: bool
+    mm_plugin: "BasePlugin"

     def encode_oneturn(
         self,
@@ -140,13 +147,14 @@ class Template:
             elif "eos_token" in elem and tokenizer.eos_token_id is not None:
                 token_ids += [tokenizer.eos_token_id]
             else:
-                raise ValueError("Input must be string, set[str] or dict[str, str], got {}".format(type(elem)))
+                raise ValueError(f"Input must be string, set[str] or dict[str, str], got {type(elem)}")

         return token_ids

 @dataclass
 class Llama2Template(Template):
+    @override
     def _encode(
         self,
         tokenizer: "PreTrainedTokenizer",
@@ -190,7 +198,7 @@ class Llama2Template(Template):
     return encoded_messages

-TEMPLATES: Dict[str, Template] = {}
+TEMPLATES: Dict[str, "Template"] = {}

 def _register_template(
@@ -205,9 +213,10 @@ def _register_template(
     format_prefix: Optional["Formatter"] = None,
     default_system: str = "",
     stop_words: Sequence[str] = [],
-    image_token: str = "<image>",
     efficient_eos: bool = False,
     replace_eos: bool = False,
+    replace_jinja_template: bool = True,
+    mm_plugin: "BasePlugin" = get_mm_plugin(name="base"),
 ) -> None:
     r"""
     Registers a chat template.
@@ -254,9 +263,10 @@ def _register_template(
         format_prefix=format_prefix or default_prefix_formatter,
         default_system=default_system,
         stop_words=stop_words,
-        image_token=image_token,
         efficient_eos=efficient_eos,
         replace_eos=replace_eos,
+        replace_jinja_template=replace_jinja_template,
+        mm_plugin=mm_plugin,
     )

@@ -265,12 +275,12 @@ def _add_or_replace_eos_token(tokenizer: "PreTrainedTokenizer", eos_token: str)
     num_added_tokens = tokenizer.add_special_tokens({"eos_token": eos_token})

     if is_added:
-        logger.info("Add eos token: {}".format(tokenizer.eos_token))
+        logger.info_rank0(f"Add eos token: {tokenizer.eos_token}")
     else:
-        logger.info("Replace eos token: {}".format(tokenizer.eos_token))
+        logger.info_rank0(f"Replace eos token: {tokenizer.eos_token}")

     if num_added_tokens > 0:
-        logger.warning("New tokens have been added, make sure `resize_vocab` is True.")
+        logger.warning_rank0("New tokens have been added, make sure `resize_vocab` is True.")

 def _jinja_escape(content: str) -> str:
@@ -300,6 +310,9 @@ def _convert_slots_to_jinja(slots: "SLOTS", tokenizer: "PreTrainedTokenizer", pl
 def _get_jinja_template(template: "Template", tokenizer: "PreTrainedTokenizer") -> str:
+    r"""
+    Returns the jinja template.
+    """
     jinja_template = ""

     prefix = _convert_slots_to_jinja(template.format_prefix.apply(), tokenizer)
@@ -310,14 +323,15 @@
         jinja_template += "{% set system_message = '" + _jinja_escape(template.default_system) + "' %}"

     jinja_template += (
-        "{% if messages[0]['role'] == 'system' %}{% set system_message = messages[0]['content'] %}{% endif %}"
+        "{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}"
+        "{% set system_message = messages[0]['content'] %}{% else %}{% set loop_messages = messages %}{% endif %}"
     )

     system_message = _convert_slots_to_jinja(template.format_system.apply(), tokenizer, placeholder="system_message")
     if not isinstance(template, Llama2Template):
         jinja_template += "{% if system_message is defined %}{{ " + system_message + " }}{% endif %}"

-    jinja_template += "{% for message in messages %}"
+    jinja_template += "{% for message in loop_messages %}"
     jinja_template += "{% set content = message['content'] %}"
     if isinstance(template, Llama2Template):
         jinja_template += "{% if loop.index0 == 0 and system_message is defined %}"
@@ -338,23 +352,28 @@ def _get_jinja_template(template: "Template", tokenizer: "PreTrainedTokenizer")
     return jinja_template

-def get_template_and_fix_tokenizer(
-    tokenizer: "PreTrainedTokenizer",
-    name: Optional[str] = None,
-    tool_format: Optional[str] = None,
-) -> Template:
-    if name is None:
+def get_template_and_fix_tokenizer(tokenizer: "PreTrainedTokenizer", data_args: "DataArguments") -> "Template":
+    r"""
+    Gets chat template and fixes the tokenizer.
+    """
+    if data_args.template is None:
         template = TEMPLATES["empty"]  # placeholder
     else:
-        template = TEMPLATES.get(name, None)
+        template = TEMPLATES.get(data_args.template, None)
         if template is None:
-            raise ValueError("Template {} does not exist.".format(name))
+            raise ValueError(f"Template {data_args.template} does not exist.")

-    if tool_format is not None:
-        logger.info("Using tool format: {}.".format(tool_format))
+    if template.mm_plugin.__class__.__name__ != "BasePlugin":
+        require_version("transformers>=4.45.0", "To fix: pip install transformers>=4.45.0")
+
+    if data_args.train_on_prompt and template.efficient_eos:
+        raise ValueError("Current template does not support `train_on_prompt`.")
+
+    if data_args.tool_format is not None:
+        logger.info_rank0(f"Using tool format: {data_args.tool_format}.")
         eos_slots = [] if template.efficient_eos else [{"eos_token"}]
-        template.format_tools = ToolFormatter(tool_format=tool_format)
-        template.format_function = FunctionFormatter(slots=eos_slots, tool_format=tool_format)
+        template.format_function = FunctionFormatter(slots=eos_slots, tool_format=data_args.tool_format)
+        template.format_tools = ToolFormatter(tool_format=data_args.tool_format)

     stop_words = template.stop_words
     if template.replace_eos:
@@ -369,20 +388,21 @@ def get_template_and_fix_tokenizer(
     if tokenizer.pad_token_id is None:
         tokenizer.pad_token = tokenizer.eos_token
-        logger.info("Add pad token: {}".format(tokenizer.pad_token))
+        logger.info_rank0(f"Add pad token: {tokenizer.pad_token}")

     if stop_words:
         num_added_tokens = tokenizer.add_special_tokens(
             dict(additional_special_tokens=stop_words), replace_additional_special_tokens=False
         )
-        logger.info("Add {} to stop words.".format(",".join(stop_words)))
+        logger.info_rank0("Add {} to stop words.".format(",".join(stop_words)))
         if num_added_tokens > 0:
-            logger.warning("New tokens have been added, make sure `resize_vocab` is True.")
+            logger.warning_rank0("New tokens have been added, make sure `resize_vocab` is True.")

+    if tokenizer.chat_template is None or template.replace_jinja_template:
         try:
             tokenizer.chat_template = _get_jinja_template(template, tokenizer)
-        except ValueError:
-            logger.info("Cannot add this chat template to tokenizer.")
+        except ValueError as e:
+            logger.info_rank0(f"Cannot add this chat template to tokenizer: {e}.")

     return template
@@ -549,6 +569,15 @@ _register_template(
 )
_register_template(
name="cpm3",
format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
stop_words=["<|im_end|>"],
)
 _register_template(
     name="dbrx",
     format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
@@ -578,6 +607,7 @@ _register_template(
 _register_template(
     name="deepseek",
     format_user=StringFormatter(slots=["User: {{content}}\n\nAssistant:"]),
+    format_system=StringFormatter(slots=["{{content}}\n\n"]),
     format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
 )
@@ -585,14 +615,14 @@ _register_template(
 _register_template(
     name="deepseekcoder",
     format_user=StringFormatter(slots=["### Instruction:\n{{content}}\n### Response:"]),
-    format_assistant=StringFormatter(slots=["\n{{content}}\n"]),
+    format_assistant=StringFormatter(slots=["\n{{content}}\n<|EOT|>"]),
     format_separator=EmptyFormatter(slots=["\n"]),
     format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
     default_system=(
-        "You are an AI programming assistant, utilizing the Deepseek Coder model, "
-        "developed by Deepseek Company, and you only answer questions related to computer science. "
+        "You are an AI programming assistant, utilizing the DeepSeek Coder model, "
+        "developed by DeepSeek Company, and you only answer questions related to computer science. "
         "For politically sensitive questions, security and privacy issues, "
-        "and other non-computer science questions, you will refuse to answer\n"
+        "and other non-computer science questions, you will refuse to answer.\n"
     ),
 )
@@ -611,6 +641,14 @@ _register_template(
) )
_register_template(
name="exaone",
format_user=StringFormatter(slots=["[|user|]{{content}}\n[|assistant|]"]),
format_system=StringFormatter(slots=["[|system|]{{content}}[|endofturn|]\n"]),
format_separator=EmptyFormatter(slots=["\n"]),
)
 _register_template(
     name="falcon",
     format_user=StringFormatter(slots=["User: {{content}}\nFalcon:"]),
@@ -635,6 +673,7 @@ _register_template(
     format_separator=EmptyFormatter(slots=["<end_of_turn>\n"]),
     format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
     efficient_eos=True,
+    replace_jinja_template=False,
 )
@@ -652,6 +691,14 @@ _register_template(
 )
_register_template(
name="index",
format_user=StringFormatter(slots=["reserved_0{{content}}reserved_1"]),
format_system=StringFormatter(slots=["<unk>{{content}}"]),
efficient_eos=True,
)
 _register_template(
     name="intern",
     format_user=StringFormatter(slots=["<|User|>:{{content}}\n<|Bot|>:"]),
@@ -711,6 +758,146 @@ _register_template(
     format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
     stop_words=["<|eot_id|>"],
     replace_eos=True,
replace_jinja_template=False,
)
_register_template(
name="mllama",
format_user=StringFormatter(
slots=[
(
"<|start_header_id|>user<|end_header_id|>\n\n{{content}}<|eot_id|>"
"<|start_header_id|>assistant<|end_header_id|>\n\n"
)
]
),
format_system=StringFormatter(slots=["<|start_header_id|>system<|end_header_id|>\n\n{{content}}<|eot_id|>"]),
format_observation=StringFormatter(
slots=[
(
"<|start_header_id|>tool<|end_header_id|>\n\n{{content}}<|eot_id|>"
"<|start_header_id|>assistant<|end_header_id|>\n\n"
)
]
),
format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
stop_words=["<|eot_id|>"],
replace_eos=True,
replace_jinja_template=False,
mm_plugin=get_mm_plugin(name="mllama", image_token="<|image|>"),
)
_register_template(
name="llava",
format_user=StringFormatter(slots=["USER: {{content}} ASSISTANT:"]),
default_system=(
"A chat between a curious user and an artificial intelligence assistant. "
"The assistant gives helpful, detailed, and polite answers to the user's questions."
),
mm_plugin=get_mm_plugin(name="llava", image_token="<image>"),
)
_register_template(
name="llava_next",
format_user=StringFormatter(slots=["USER: {{content}} ASSISTANT:"]),
default_system=(
"A chat between a curious user and an artificial intelligence assistant. "
"The assistant gives helpful, detailed, and polite answers to the user's questions."
),
mm_plugin=get_mm_plugin(name="llava_next", image_token="<image>"),
)
_register_template(
name="llava_next_llama3",
format_user=StringFormatter(
slots=[
(
"<|start_header_id|>user<|end_header_id|>\n\n{{content}}<|eot_id|>"
"<|start_header_id|>assistant<|end_header_id|>\n\n"
)
]
),
format_system=StringFormatter(slots=["<|start_header_id|>system<|end_header_id|>\n\n{{content}}<|eot_id|>"]),
format_observation=StringFormatter(
slots=[
(
"<|start_header_id|>tool<|end_header_id|>\n\n{{content}}<|eot_id|>"
"<|start_header_id|>assistant<|end_header_id|>\n\n"
)
]
),
format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
stop_words=["<|eot_id|>"],
replace_eos=True,
replace_jinja_template=False,
mm_plugin=get_mm_plugin(name="llava_next", image_token="<image>"),
)
_register_template(
name="llava_next_mistral",
format_user=StringFormatter(slots=["[INST] {{content}} [/INST]"]),
format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
mm_plugin=get_mm_plugin(name="llava_next", image_token="<image>"),
)
_register_template(
name="llava_next_qwen",
format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
format_observation=StringFormatter(slots=["<|im_start|>tool\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
format_separator=EmptyFormatter(slots=["\n"]),
default_system="You are a helpful assistant.",
stop_words=["<|im_end|>"],
replace_eos=True,
replace_jinja_template=False,
mm_plugin=get_mm_plugin(name="llava_next", image_token="<image>"),
)
_register_template(
name="llava_next_yi",
format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
format_separator=EmptyFormatter(slots=["\n"]),
stop_words=["<|im_end|>"],
replace_eos=True,
mm_plugin=get_mm_plugin(name="llava_next", image_token="<image>"),
)
_register_template(
name="llava_next_video",
format_user=StringFormatter(slots=["USER: {{content}} ASSISTANT:"]),
default_system=(
"A chat between a curious user and an artificial intelligence assistant. "
"The assistant gives helpful, detailed, and polite answers to the user's questions."
),
mm_plugin=get_mm_plugin(name="llava_next_video", image_token="<image>", video_token="<video>"),
)
_register_template(
name="llava_next_video_mistral",
format_user=StringFormatter(slots=["[INST] {{content}} [/INST]"]),
format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
mm_plugin=get_mm_plugin(name="llava_next_video", image_token="<image>", video_token="<video>"),
)
_register_template(
name="llava_next_video_yi",
format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
format_separator=EmptyFormatter(slots=["\n"]),
stop_words=["<|im_end|>"],
replace_eos=True,
mm_plugin=get_mm_plugin(name="llava_next_video", image_token="<image>", video_token="<video>"),
 )
@@ -751,6 +938,19 @@ _register_template(
 )
_register_template(
name="opencoder",
format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
format_observation=StringFormatter(slots=["<|im_start|>tool\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
format_separator=EmptyFormatter(slots=["\n"]),
default_system="You are OpenCoder, created by OpenCoder Team.",
stop_words=["<|im_end|>"],
replace_eos=True,
replace_jinja_template=False,
)
 _register_template(
     name="orion",
     format_user=StringFormatter(slots=["Human: {{content}}\n\nAssistant: ", {"eos_token"}]),
@@ -758,6 +958,19 @@ _register_template(
 )
_register_template(
name="paligemma",
format_user=StringFormatter(slots=["<start_of_turn>user\n{{content}}<end_of_turn>\n<start_of_turn>model\n"]),
format_observation=StringFormatter(
slots=["<start_of_turn>tool\n{{content}}<end_of_turn>\n<start_of_turn>model\n"]
),
format_separator=EmptyFormatter(slots=["<end_of_turn>\n"]),
format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
efficient_eos=True,
mm_plugin=get_mm_plugin(name="paligemma", image_token="<image>"),
)
 _register_template(
     name="phi",
     format_user=StringFormatter(slots=["<|user|>\n{{content}}<|end|>\n<|assistant|>\n"]),
@@ -769,6 +982,25 @@ _register_template(
 )
_register_template(
name="phi_small",
format_user=StringFormatter(slots=["<|user|>\n{{content}}<|end|>\n<|assistant|>\n"]),
format_system=StringFormatter(slots=["<|system|>\n{{content}}<|end|>\n"]),
format_separator=EmptyFormatter(slots=["\n"]),
format_prefix=EmptyFormatter(slots=[{"<|endoftext|>"}]),
stop_words=["<|end|>"],
replace_eos=True,
)
_register_template(
name="pixtral",
format_user=StringFormatter(slots=["[INST] {{content}} [/INST]"]),
format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
mm_plugin=get_mm_plugin(name="pixtral", image_token="[IMG]"),
)
_register_template(
name="qwen",
format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
@@ -778,6 +1010,35 @@ _register_template(
default_system="You are a helpful assistant.",
stop_words=["<|im_end|>"],
replace_eos=True,
replace_jinja_template=False,
)
_register_template(
name="qwen2_vl",
format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
format_observation=StringFormatter(slots=["<|im_start|>tool\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
format_separator=EmptyFormatter(slots=["\n"]),
default_system="You are a helpful assistant.",
stop_words=["<|im_end|>"],
replace_eos=True,
replace_jinja_template=False,
mm_plugin=get_mm_plugin(name="qwen2_vl", image_token="<|image_pad|>", video_token="<|video_pad|>"),
)
_register_template(
name="sailor",
format_user=StringFormatter(slots=["<|im_start|>question\n{{content}}<|im_end|>\n<|im_start|>answer\n"]),
format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
format_separator=EmptyFormatter(slots=["\n"]),
default_system=(
"You are an AI assistant named Sailor created by Sea AI Lab. "
"Your answer should be friendly, unbiased, faithful, informative and detailed."
),
stop_words=["<|im_end|>"],
replace_eos=True,
)
@@ -818,6 +1079,17 @@ _register_template(
)
_register_template(
name="video_llava",
format_user=StringFormatter(slots=["USER: {{content}} ASSISTANT:"]),
default_system=(
"A chat between a curious user and an artificial intelligence assistant. "
"The assistant gives helpful, detailed, and polite answers to the user's questions."
),
mm_plugin=get_mm_plugin(name="video_llava", image_token="<image>", video_token="<video>"),
)
_register_template(
name="xuanyuan",
format_user=StringFormatter(slots=["Human: {{content}} Assistant:"]),
@@ -878,6 +1150,7 @@ _register_template(
),
stop_words=["###"],
efficient_eos=True,
mm_plugin=get_mm_plugin(name="llava", image_token="<image>"),
)
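A rough illustration (not part of the diff above) of the prompt string a single user turn would render to under the "qwen2_vl" slots registered above; the system and user strings are invented, and the real mm_plugin's expansion of <|image_pad|> tokens is not modeled here.
import textwrap

# Hypothetical conversation content, formatted with the qwen2_vl slot layout shown above.
system = "You are a helpful assistant."
user = "<|image_pad|>Describe this image."
prompt = (
    f"<|im_start|>system\n{system}<|im_end|>\n"
    f"<|im_start|>user\n{user}<|im_end|>\n"
    "<|im_start|>assistant\n"
)
print(textwrap.indent(prompt, "  "))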

View File

@@ -15,9 +15,12 @@
import json
import re
from abc import ABC, abstractmethod
from collections import namedtuple
from dataclasses import dataclass
from typing import Any, Dict, List, Tuple, Union
from typing_extensions import override
from .data_utils import SLOTS
@@ -38,26 +41,47 @@ GLM4_TOOL_PROMPT = (
)
FunctionCall = namedtuple("FunctionCall", ["name", "arguments"])
@dataclass
class ToolUtils(ABC):
"""
Base class for tool utilities.
"""
@staticmethod
@abstractmethod
def get_function_slots() -> SLOTS:
r"""
Gets a list of slots corresponding to a single function call.
"""
...
@staticmethod
@abstractmethod
def tool_formatter(tools: List[Dict[str, Any]]) -> str:
r"""
Generates the system message describing all the available tools.
"""
...
@staticmethod
@abstractmethod
def tool_extractor(content: str) -> Union[str, List["FunctionCall"]]:
r"""
Extracts all the function calls from the response message.
"""
...
class DefaultToolUtils(ToolUtils):
@override
@staticmethod
def get_function_slots() -> SLOTS:
return ["Action: {{name}}\nAction Input: {{arguments}}\n"]
@override
@staticmethod
def tool_formatter(tools: List[Dict[str, Any]]) -> str:
tool_text = ""
@@ -91,8 +115,9 @@ class DefaultToolUtils(ToolUtils):
return DEFAULT_TOOL_PROMPT.format(tool_text=tool_text, tool_names=", ".join(tool_names))
@override
@staticmethod
def tool_extractor(content: str) -> Union[str, List["FunctionCall"]]:
regex = re.compile(r"Action:\s*([a-zA-Z0-9_]+)\s*Action Input:\s*(.+?)(?=\s*Action:|\s*$)", re.DOTALL)
action_match: List[Tuple[str, str]] = re.findall(regex, content)
if not action_match:
@@ -112,10 +137,12 @@ class DefaultToolUtils(ToolUtils):
class GLM4ToolUtils(ToolUtils):
@override
@staticmethod
def get_function_slots() -> SLOTS:
return ["{{name}}\n{{arguments}}"]
@override
@staticmethod
def tool_formatter(tools: List[Dict[str, Any]]) -> str:
tool_text = ""
@@ -126,8 +153,9 @@ class GLM4ToolUtils(ToolUtils):
return GLM4_TOOL_PROMPT.format(tool_text=tool_text)
@override
@staticmethod
def tool_extractor(content: str) -> Union[str, List["FunctionCall"]]:
if "\n" not in content:
return content
@@ -138,3 +166,17 @@ class GLM4ToolUtils(ToolUtils):
return content
return [(tool_name, json.dumps(arguments, ensure_ascii=False))]
TOOLS = {
"default": DefaultToolUtils(),
"glm4": GLM4ToolUtils(),
}
def get_tool_utils(name: str) -> "ToolUtils":
tool_utils = TOOLS.get(name, None)
if tool_utils is None:
raise ValueError(f"Tool utils `{name}` not found.")
return tool_utils
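For reference, a standalone sketch of how the Action-style pattern parsed by DefaultToolUtils.tool_extractor above behaves; the sample model response is invented, and only the regex and the FunctionCall tuple from this file are reused.
import re
from collections import namedtuple

FunctionCall = namedtuple("FunctionCall", ["name", "arguments"])

# Same pattern as DefaultToolUtils.tool_extractor above; the response text is made up.
regex = re.compile(r"Action:\s*([a-zA-Z0-9_]+)\s*Action Input:\s*(.+?)(?=\s*Action:|\s*$)", re.DOTALL)
response = 'Action: get_weather\nAction Input: {"city": "Beijing"}'
calls = [FunctionCall(name, arguments.strip()) for name, arguments in re.findall(regex, response)]
print(calls)  # [FunctionCall(name='get_weather', arguments='{"city": "Beijing"}')]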

View File

@@ -39,7 +39,7 @@
import json
import os
from typing import TYPE_CHECKING, Any, Dict, List, Optional
import numpy as np
import torch
@@ -54,18 +54,22 @@ from ..model import load_model, load_tokenizer
from .template import get_eval_template
if TYPE_CHECKING:
from numpy.typing import NDArray
class Evaluator:
def __init__(self, args: Optional[Dict[str, Any]] = None) -> None:
self.model_args, self.data_args, self.eval_args, finetuning_args = get_eval_args(args)
self.tokenizer = load_tokenizer(self.model_args)["tokenizer"]
self.tokenizer.padding_side = "right" # avoid overflow issue in batched inference for llama2
self.template = get_template_and_fix_tokenizer(self.tokenizer, self.data_args)
self.model = load_model(self.tokenizer, self.model_args, finetuning_args)
self.eval_template = get_eval_template(self.eval_args.lang)
self.choice_inputs = [self.tokenizer.encode(ch, add_special_tokens=False)[-1] for ch in CHOICES]
@torch.inference_mode()
def batch_inference(self, batch_input: Dict[str, "torch.Tensor"]) -> List[str]:
logits = self.model(**batch_input).logits
lengths = torch.sum(batch_input["attention_mask"], dim=-1)
word_probs = torch.stack([logits[i, lengths[i] - 1] for i in range(len(lengths))], dim=0)
@@ -83,7 +87,7 @@ class Evaluator:
token=self.model_args.hf_hub_token,
)
with open(mapping, encoding="utf-8") as f:
categorys: Dict[str, Dict[str, str]] = json.load(f)
category_corrects = {subj: np.array([], dtype="bool") for subj in SUBJECTS}
@@ -132,10 +136,10 @@ class Evaluator:
pbar.close()
self._save_results(category_corrects, results)
def _save_results(self, category_corrects: Dict[str, "NDArray"], results: Dict[str, Dict[int, str]]) -> None:
score_info = "\n".join(
[
f"{category_name:>15}: {100 * np.mean(category_correct):.2f}"
for category_name, category_correct in category_corrects.items()
if len(category_correct)
]
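A tiny self-contained sketch of the f-string score formatting used in _save_results above; the category names and boolean result arrays are made up for illustration.
import numpy as np

# Invented category results, just to show the right-aligned "{name:>15}: {score:.2f}" layout.
category_corrects = {
    "STEM": np.array([True, False, True, True]),
    "Average": np.array([True, True, False, False]),
}
score_info = "\n".join(
    f"{category_name:>15}: {100 * np.mean(category_correct):.2f}"
    for category_name, category_correct in category_corrects.items()
    if len(category_correct)
)
print(score_info)  # STEM: 75.00 and Average: 50.00, each label right-aligned to 15 characters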

Some files were not shown because too many files have changed in this diff.