Compare commits
419 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
7468f2535c | ||
|
|
38e4f22605 | ||
|
|
2bc2fe7b5e | ||
|
|
6d0140d8a0 | ||
|
|
7856f98965 | ||
|
|
e25ddef08c | ||
|
|
95a4589bbf | ||
|
|
566d71b7a9 | ||
|
|
6030a4a720 | ||
|
|
5dc0cb94d4 | ||
|
|
325dafcbb0 | ||
|
|
1a8a8b8651 | ||
|
|
61a495cb1e | ||
|
|
75866aa020 | ||
|
|
9e4fda326d | ||
|
|
1131ddfaff | ||
|
|
9f437b5c43 | ||
|
|
0cc03d3f05 | ||
|
|
04fc2f78bf | ||
|
|
3ac333fc6a | ||
|
|
a246ac1914 | ||
|
|
48ceac845c | ||
|
|
b1986a06b9 | ||
|
|
43d134ba29 | ||
|
|
1348f7d860 | ||
|
|
f6530222f7 | ||
|
|
a74a7585e0 | ||
|
|
5bf0cca2b8 | ||
|
|
755b6511ff | ||
|
|
35621c6089 | ||
|
|
38b59664e6 | ||
|
|
933a084999 | ||
|
|
c1510d19c7 | ||
|
|
2074cf99fb | ||
|
|
b12176d818 | ||
|
|
117b67ea30 | ||
|
|
03e20bb5c6 | ||
|
|
0c4a1381a4 | ||
|
|
9e14501edb | ||
|
|
1dc963caa6 | ||
|
|
85726c91ce | ||
|
|
40211db275 | ||
|
|
e7f13098c6 | ||
|
|
61eb3a3d46 | ||
|
|
be0a807e8c | ||
|
|
52d402e2a9 | ||
|
|
c5a46f9113 | ||
|
|
00e17a377c | ||
|
|
9abd83adb1 | ||
|
|
f0d2afcf90 | ||
|
|
1aba442bcd | ||
|
|
d764cd8736 | ||
|
|
526111a303 | ||
|
|
b8364046df | ||
|
|
1f617c6e08 | ||
|
|
a6858a36c0 | ||
|
|
6198121923 | ||
|
|
b0efebf853 | ||
|
|
fbd0584391 | ||
|
|
50224b09cc | ||
|
|
32dcc5a491 | ||
|
|
9408366a36 | ||
|
|
f0e564beaa | ||
|
|
14b75a0b93 | ||
|
|
59e6ebf039 | ||
|
|
dc540dfaa8 | ||
|
|
587e65e442 | ||
|
|
a916688723 | ||
|
|
3336422760 | ||
|
|
04423b916f | ||
|
|
bf8d2f8eda | ||
|
|
2a5d02fd0f | ||
|
|
ea550ed9e0 | ||
|
|
02665cd42b | ||
|
|
0c6a94e66d | ||
|
|
ebd6bc2604 | ||
|
|
daab85e3e6 | ||
|
|
769d81a83d | ||
|
|
ac2a401b1d | ||
|
|
bb53c18153 | ||
|
|
04e0fe9147 | ||
|
|
39f75c7001 | ||
|
|
7f99cb1817 | ||
|
|
c555b2cce3 | ||
|
|
2eba1c6851 | ||
|
|
edeed55664 | ||
|
|
92248f9cb2 | ||
|
|
c548ad5e69 | ||
|
|
a57d839e1d | ||
|
|
d88a34bc79 | ||
|
|
60cbc9d0e5 | ||
|
|
d5005e766f | ||
|
|
4d0753cffe | ||
|
|
1cf0f11840 | ||
|
|
052e8b2cc6 | ||
|
|
8963e89633 | ||
|
|
935ee0a023 | ||
|
|
5ed234ca63 | ||
|
|
04884a0911 | ||
|
|
c7af26a9e3 | ||
|
|
d8073488be | ||
|
|
6fc2d7e063 | ||
|
|
e93c7cdb80 | ||
|
|
c32d6c8250 | ||
|
|
757158da63 | ||
|
|
ffdacaa618 | ||
|
|
e194efab10 | ||
|
|
772fc2eac7 | ||
|
|
ed020579dc | ||
|
|
096869c7b6 | ||
|
|
c6873211e9 | ||
|
|
623ee1bd88 | ||
|
|
aabe90343e | ||
|
|
764cfb506d | ||
|
|
249ad56075 | ||
|
|
46f99ff277 | ||
|
|
73f4513c84 | ||
|
|
3c91e86268 | ||
|
|
42473ec150 | ||
|
|
6a4e4b9c5b | ||
|
|
9a784fb4f3 | ||
|
|
43fd80a1aa | ||
|
|
e6ab1a57ea | ||
|
|
282edb9161 | ||
|
|
dff77004f2 | ||
|
|
6c1b4aec75 | ||
|
|
7814db1b42 | ||
|
|
c9ed3fc3a4 | ||
|
|
9ee416a8fc | ||
|
|
4f9a47c026 | ||
|
|
3fcb1c6d09 | ||
|
|
7c492864e9 | ||
|
|
7ff8a064f3 | ||
|
|
c635bbe465 | ||
|
|
4881f4e631 | ||
|
|
c631799f5d | ||
|
|
48846676d8 | ||
|
|
f37d481c5d | ||
|
|
5d7d8bd55c | ||
|
|
8ed1463236 | ||
|
|
43b2ede0f8 | ||
|
|
2f095e2017 | ||
|
|
9b55bb964c | ||
|
|
9b97b23ce7 | ||
|
|
53ab28533e | ||
|
|
940c00e7ae | ||
|
|
18cfd5f349 | ||
|
|
6169df1c52 | ||
|
|
d46c2bbcba | ||
|
|
48d4364586 | ||
|
|
8042c66a76 | ||
|
|
3879d79b89 | ||
|
|
e416cecf62 | ||
|
|
81fcb80466 | ||
|
|
bf812fbe40 | ||
|
|
1e6fb6c8aa | ||
|
|
5d0c95bd02 | ||
|
|
7cd2417002 | ||
|
|
16851d66e5 | ||
|
|
056d2d956a | ||
|
|
9a69cadab3 | ||
|
|
3de642bffd | ||
|
|
286b9d9849 | ||
|
|
cef1ede826 | ||
|
|
5007566588 | ||
|
|
e93fb3cc6c | ||
|
|
7578209735 | ||
|
|
67f02f75d0 | ||
|
|
73d9dfc7ab | ||
|
|
6b407092d9 | ||
|
|
3168abc0a1 | ||
|
|
46ee267cfc | ||
|
|
a10bead9b5 | ||
|
|
3553e301dd | ||
|
|
02b838b9b0 | ||
|
|
b1de6d1025 | ||
|
|
bc67872218 | ||
|
|
0229fffde5 | ||
|
|
3555b87363 | ||
|
|
2dca53962e | ||
|
|
f4f71f2797 | ||
|
|
77ab9457ed | ||
|
|
4fa53b6282 | ||
|
|
790b73586b | ||
|
|
9c29c2a172 | ||
|
|
863960d33e | ||
|
|
330e5381b4 | ||
|
|
5bb411fdb8 | ||
|
|
59a9a5994e | ||
|
|
5306a71b42 | ||
|
|
3eafa2dd9e | ||
|
|
88fddb879d | ||
|
|
71491825bf | ||
|
|
30855b924a | ||
|
|
48d2e6d7fe | ||
|
|
041c83ea03 | ||
|
|
0e621c2dc9 | ||
|
|
544e7a491b | ||
|
|
a2c881fa08 | ||
|
|
c53c7af168 | ||
|
|
a2d93e5269 | ||
|
|
b392e6cfb9 | ||
|
|
13aa2d389a | ||
|
|
1e7962dfc4 | ||
|
|
1c9556c84c | ||
|
|
ca3ca7a5b5 | ||
|
|
0500befdb4 | ||
|
|
f618feab51 | ||
|
|
4b06aa134f | ||
|
|
9cde56d760 | ||
|
|
d0ea203694 | ||
|
|
c5eb3fba62 | ||
|
|
a8bc32553c | ||
|
|
88f3358320 | ||
|
|
a85bdcf2f6 | ||
|
|
caf56b313e | ||
|
|
75603c45fc | ||
|
|
89f86cc970 | ||
|
|
c09a0e4f08 | ||
|
|
7bac6c9460 | ||
|
|
0c7d0bf172 | ||
|
|
a274900188 | ||
|
|
67deefe527 | ||
|
|
823f618cba | ||
|
|
bc16c9a54a | ||
|
|
a3f30038a0 | ||
|
|
e237f618c2 | ||
|
|
688adad665 | ||
|
|
0158812afb | ||
|
|
e52e0d9b07 | ||
|
|
eb2aa2c073 | ||
|
|
debfd46749 | ||
|
|
5ccf8fcd6b | ||
|
|
7bd1991513 | ||
|
|
456e4ca569 | ||
|
|
6bf0fe4913 | ||
|
|
596b6828cb | ||
|
|
b403f8d8a8 | ||
|
|
590b6c2143 | ||
|
|
5537ef1e7d | ||
|
|
5f83860aa1 | ||
|
|
62b6a7971a | ||
|
|
1d16e87c5f | ||
|
|
1955a8ea5a | ||
|
|
a41fa6e730 | ||
|
|
b98a64448a | ||
|
|
1ce82f391a | ||
|
|
4d473894fd | ||
|
|
5788b7c7d0 | ||
|
|
04515f6b55 | ||
|
|
96f8ccf3d5 | ||
|
|
2c3ef480a6 | ||
|
|
fa6873122c | ||
|
|
34bc0c22b1 | ||
|
|
e5484b2729 | ||
|
|
f67f781fed | ||
|
|
b564b97b7e | ||
|
|
0dd68d1e06 | ||
|
|
73f40f1ca4 | ||
|
|
ea53bebac4 | ||
|
|
00418012bd | ||
|
|
5f3d8c514b | ||
|
|
cb39a3f1c4 | ||
|
|
4d78fe6ece | ||
|
|
a3e3ea9846 | ||
|
|
feba34e82d | ||
|
|
e134013e04 | ||
|
|
5589d0296a | ||
|
|
de0ebab464 | ||
|
|
f2e7122a96 | ||
|
|
996cc5d900 | ||
|
|
a2ae5bd867 | ||
|
|
5fa52e87cb | ||
|
|
bcd76d2c7a | ||
|
|
36fcbedc11 | ||
|
|
1dad01cc53 | ||
|
|
5fb21f6e54 | ||
|
|
08dfac8352 | ||
|
|
956751e419 | ||
|
|
fe2ae04c91 | ||
|
|
5b8712d061 | ||
|
|
dc7ff90c1e | ||
|
|
1ace676170 | ||
|
|
8947a87b95 | ||
|
|
786a2f1103 | ||
|
|
36ac14a566 | ||
|
|
7a048fc91d | ||
|
|
3f3756b113 | ||
|
|
b36c4b99cc | ||
|
|
9856a2276e | ||
|
|
b6dc3ed3ad | ||
|
|
75be329994 | ||
|
|
1fe1ca1c8b | ||
|
|
882a6a1d51 | ||
|
|
712ab4ae7a | ||
|
|
18ad259fb3 | ||
|
|
fe4d93c6db | ||
|
|
c6ba588e37 | ||
|
|
3fda60fca0 | ||
|
|
96531a0ef8 | ||
|
|
7abc3065fb | ||
|
|
013ded4bac | ||
|
|
010c3c7348 | ||
|
|
bf075c075c | ||
|
|
41b34e5f60 | ||
|
|
5a889398e7 | ||
|
|
054cae86d8 | ||
|
|
cd1cb8b83c | ||
|
|
a34779c027 | ||
|
|
d19cb77d74 | ||
|
|
ab67528e89 | ||
|
|
27f281480a | ||
|
|
50459a39f4 | ||
|
|
5c9815ef6f | ||
|
|
aed00a97b6 | ||
|
|
7543dc4a9d | ||
|
|
841fa0030f | ||
|
|
66e0e651b9 | ||
|
|
1750218057 | ||
|
|
80637fc06d | ||
|
|
8efc055511 | ||
|
|
be61bfda93 | ||
|
|
1a39f529c0 | ||
|
|
0868d5c550 | ||
|
|
384f0e7678 | ||
|
|
9b390c4bea | ||
|
|
42a13fec46 | ||
|
|
790acc4c17 | ||
|
|
b74cf27538 | ||
|
|
ffc874ec6f | ||
|
|
546d6bd0b2 | ||
|
|
8b68ca029e | ||
|
|
502f84b30c | ||
|
|
b7df920860 | ||
|
|
e4a424cb6a | ||
|
|
d8affd3967 | ||
|
|
a423274fd9 | ||
|
|
f7329b1a0e | ||
|
|
48eb07c956 | ||
|
|
636d8a886c | ||
|
|
97b52c7fdf | ||
|
|
344412e66e | ||
|
|
5cdea14cdf | ||
|
|
7b1a56b96f | ||
|
|
d1ec884e75 | ||
|
|
aa72a4349e | ||
|
|
5ab7fd0842 | ||
|
|
86d5e9802a | ||
|
|
18df39e3a1 | ||
|
|
cfe1e24471 | ||
|
|
2edbe87a8c | ||
|
|
880055bc90 | ||
|
|
ad99bd0a14 | ||
|
|
c5f099138d | ||
|
|
6e64e02f71 | ||
|
|
f95f6ec009 | ||
|
|
8aeecc20e1 | ||
|
|
38d0f6c63f | ||
|
|
ac8534a9e7 | ||
|
|
73cab9d9d4 | ||
|
|
64246d42d2 | ||
|
|
6fa6d4532e | ||
|
|
92b9956c06 | ||
|
|
4d6669c268 | ||
|
|
89f4ae51f9 | ||
|
|
af0659f573 | ||
|
|
45a10d501e | ||
|
|
e529ff1245 | ||
|
|
b29371dc87 | ||
|
|
0bef890000 | ||
|
|
75fe1404b1 | ||
|
|
b460c9372f | ||
|
|
c3e574ceaa | ||
|
|
04ae80a52e | ||
|
|
a7ff095399 | ||
|
|
a655dcebaf | ||
|
|
8c74851b70 | ||
|
|
7168392a51 | ||
|
|
ccc5b324fe | ||
|
|
e85c205a81 | ||
|
|
7e225be16e | ||
|
|
ebb32e85f8 | ||
|
|
90d279f39f | ||
|
|
af3f5b6e16 | ||
|
|
53d7c5109f | ||
|
|
bf381563ff | ||
|
|
de4b9334e1 | ||
|
|
c33fbea469 | ||
|
|
921f593632 | ||
|
|
940403720a | ||
|
|
f869e44fe5 | ||
|
|
bcc92919a0 | ||
|
|
306a70c7ba | ||
|
|
d358d955e5 | ||
|
|
0fdd6074c3 | ||
|
|
6faf9c35a9 | ||
|
|
1066898e32 | ||
|
|
d05febe5de | ||
|
|
67f7034a21 | ||
|
|
79f301a2c6 | ||
|
|
31cbc67986 | ||
|
|
fe66bf3663 | ||
|
|
4691d4b35d | ||
|
|
acf5241845 | ||
|
|
2bce99b82f | ||
|
|
3c330869ef | ||
|
|
dba1af4841 | ||
|
|
2b1e52dcc9 | ||
|
|
b5238e945a | ||
|
|
afc0f29704 | ||
|
|
de0bb1d2da | ||
|
|
cc16ece283 | ||
|
|
31ba802fc9 | ||
|
|
4b27cf5460 | ||
|
|
a53b2a643f | ||
|
|
d925ecae1b | ||
|
|
13fd751a78 | ||
|
|
74575f8922 | ||
|
|
5e7bb5fe73 |
11
.dockerignore
Normal file
11
.dockerignore
Normal file
@@ -0,0 +1,11 @@
|
|||||||
|
.vscode
|
||||||
|
.git
|
||||||
|
.github
|
||||||
|
.venv
|
||||||
|
cache
|
||||||
|
data
|
||||||
|
examples
|
||||||
|
.dockerignore
|
||||||
|
.gitattributes
|
||||||
|
.gitignore
|
||||||
|
Dockerfile
|
||||||
21
.github/CONTRIBUTING.md
vendored
Normal file
21
.github/CONTRIBUTING.md
vendored
Normal file
@@ -0,0 +1,21 @@
|
|||||||
|
# Contributing to LLaMA Factory
|
||||||
|
|
||||||
|
Everyone is welcome to contribute, and we value everybody's contribution. Code contributions are not the only way to help the community. Answering questions, helping others, and improving the documentation are also immensely valuable.
|
||||||
|
|
||||||
|
It also helps us if you spread the word! Reference the library in blog posts about the awesome projects it made possible, shout out on Twitter every time it has helped you, or simply ⭐️ the repository to say thank you.
|
||||||
|
|
||||||
|
However you choose to contribute, please be mindful and respect our [code of conduct](CODE_OF_CONDUCT.md).
|
||||||
|
|
||||||
|
**This guide was heavily inspired by [transformers guide to contributing](https://github.com/huggingface/transformers/blob/main/CONTRIBUTING.md).**
|
||||||
|
|
||||||
|
## Ways to contribute
|
||||||
|
|
||||||
|
There are several ways you can contribute to LLaMA Factory:
|
||||||
|
|
||||||
|
* Fix outstanding issues with the existing code.
|
||||||
|
* Submit issues related to bugs or desired new features.
|
||||||
|
* Contribute to the examples or to the documentation.
|
||||||
|
|
||||||
|
### Style guide
|
||||||
|
|
||||||
|
LLaMA Factory follows the [Google Python Style Guide](https://google.github.io/styleguide/pyguide.html), check it for details.
|
||||||
7
.github/PULL_REQUEST_TEMPLATE.md
vendored
Normal file
7
.github/PULL_REQUEST_TEMPLATE.md
vendored
Normal file
@@ -0,0 +1,7 @@
|
|||||||
|
# What does this PR do?
|
||||||
|
|
||||||
|
Fixes # (issue)
|
||||||
|
|
||||||
|
## Before submitting
|
||||||
|
|
||||||
|
- [ ] Did you read the [contributor guideline](https://github.com/hiyouga/LLaMA-Factory/blob/main/.github/CONTRIBUTING.md)?
|
||||||
7
.github/SECURITY.md
vendored
Normal file
7
.github/SECURITY.md
vendored
Normal file
@@ -0,0 +1,7 @@
|
|||||||
|
# Reporting Security Issues
|
||||||
|
|
||||||
|
To report a security issue, please use the GitHub Security Advisory ["Report a Vulnerability"](https://github.com/hiyouga/LLaMA-Factory/security/advisories/new) tab.
|
||||||
|
|
||||||
|
We will send a response indicating the next steps in handling your report. After the initial reply to your report, the security team will keep you informed of the progress towards a fix and full announcement, and may ask for additional information or guidance.
|
||||||
|
|
||||||
|
Report security bugs in third-party modules to the person or team maintaining the module.
|
||||||
29
.github/workflows/tests.yml
vendored
Normal file
29
.github/workflows/tests.yml
vendored
Normal file
@@ -0,0 +1,29 @@
|
|||||||
|
name: tests
|
||||||
|
|
||||||
|
on:
|
||||||
|
push:
|
||||||
|
branches: [ "main" ]
|
||||||
|
pull_request:
|
||||||
|
branches: [ "main" ]
|
||||||
|
|
||||||
|
jobs:
|
||||||
|
check_code_quality:
|
||||||
|
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
|
||||||
|
steps:
|
||||||
|
- uses: actions/checkout@v4
|
||||||
|
|
||||||
|
- name: Set up Python
|
||||||
|
uses: actions/setup-python@v5
|
||||||
|
with:
|
||||||
|
python-version: "3.8"
|
||||||
|
|
||||||
|
- name: Install dependencies
|
||||||
|
run: |
|
||||||
|
python -m pip install --upgrade pip
|
||||||
|
python -m pip install ruff
|
||||||
|
|
||||||
|
- name: Check quality
|
||||||
|
run: |
|
||||||
|
make style && make quality
|
||||||
37
CITATION.cff
Normal file
37
CITATION.cff
Normal file
@@ -0,0 +1,37 @@
|
|||||||
|
cff-version: 1.2.0
|
||||||
|
date-released: 2024-03
|
||||||
|
message: "If you use this software, please cite it as below."
|
||||||
|
authors:
|
||||||
|
- family-names: "Zheng"
|
||||||
|
given-names: "Yaowei"
|
||||||
|
- family-names: "Zhang"
|
||||||
|
given-names: "Richong"
|
||||||
|
- family-names: "Zhang"
|
||||||
|
given-names: "Junhao"
|
||||||
|
- family-names: "Ye"
|
||||||
|
given-names: "Yanhan"
|
||||||
|
- family-names: "Luo"
|
||||||
|
given-names: "Zheyan"
|
||||||
|
- family-names: "Ma"
|
||||||
|
given-names: "Yongqiang"
|
||||||
|
title: "LlamaFactory: Unified Efficient Fine-Tuning of 100+ Language Models"
|
||||||
|
url: "https://arxiv.org/abs/2403.13372"
|
||||||
|
preferred-citation:
|
||||||
|
type: article
|
||||||
|
authors:
|
||||||
|
- family-names: "Zheng"
|
||||||
|
given-names: "Yaowei"
|
||||||
|
- family-names: "Zhang"
|
||||||
|
given-names: "Richong"
|
||||||
|
- family-names: "Zhang"
|
||||||
|
given-names: "Junhao"
|
||||||
|
- family-names: "Ye"
|
||||||
|
given-names: "Yanhan"
|
||||||
|
- family-names: "Luo"
|
||||||
|
given-names: "Zheyan"
|
||||||
|
- family-names: "Ma"
|
||||||
|
given-names: "Yongqiang"
|
||||||
|
journal: "arXiv preprint arXiv:2403.13372"
|
||||||
|
title: "LlamaFactory: Unified Efficient Fine-Tuning of 100+ Language Models"
|
||||||
|
url: "https://arxiv.org/abs/2403.13372"
|
||||||
|
year: 2024
|
||||||
14
Dockerfile
Normal file
14
Dockerfile
Normal file
@@ -0,0 +1,14 @@
|
|||||||
|
FROM nvcr.io/nvidia/pytorch:24.01-py3
|
||||||
|
|
||||||
|
WORKDIR /app
|
||||||
|
|
||||||
|
COPY requirements.txt /app/
|
||||||
|
RUN pip install -r requirements.txt
|
||||||
|
|
||||||
|
COPY . /app/
|
||||||
|
RUN pip install -e .[deepspeed,metrics,bitsandbytes,qwen]
|
||||||
|
|
||||||
|
VOLUME [ "/root/.cache/huggingface/", "/app/data", "/app/output" ]
|
||||||
|
EXPOSE 7860
|
||||||
|
|
||||||
|
CMD [ "python", "src/train_web.py" ]
|
||||||
11
Makefile
Normal file
11
Makefile
Normal file
@@ -0,0 +1,11 @@
|
|||||||
|
.PHONY: quality style
|
||||||
|
|
||||||
|
check_dirs := scripts src tests
|
||||||
|
|
||||||
|
quality:
|
||||||
|
ruff check $(check_dirs)
|
||||||
|
ruff format --check $(check_dirs)
|
||||||
|
|
||||||
|
style:
|
||||||
|
ruff check $(check_dirs) --fix
|
||||||
|
ruff format $(check_dirs)
|
||||||
596
README.md
596
README.md
@@ -5,27 +5,30 @@
|
|||||||
[](https://github.com/hiyouga/LLaMA-Factory/commits/main)
|
[](https://github.com/hiyouga/LLaMA-Factory/commits/main)
|
||||||
[](https://pypi.org/project/llmtuner/)
|
[](https://pypi.org/project/llmtuner/)
|
||||||
[](https://pypi.org/project/llmtuner/)
|
[](https://pypi.org/project/llmtuner/)
|
||||||
|
[](#projects-using-llama-factory)
|
||||||
[](https://github.com/hiyouga/LLaMA-Factory/pulls)
|
[](https://github.com/hiyouga/LLaMA-Factory/pulls)
|
||||||
[](https://discord.gg/rKfvV9r9FK)
|
[](https://discord.gg/rKfvV9r9FK)
|
||||||
[](https://huggingface.co/spaces/hiyouga/LLaMA-Board)
|
[](https://twitter.com/llamafactory_ai)
|
||||||
[](https://modelscope.cn/studios/hiyouga/LLaMA-Board)
|
[](https://huggingface.co/spaces/hiyouga/LLaMA-Board)
|
||||||
|
[](https://modelscope.cn/studios/hiyouga/LLaMA-Board)
|
||||||
|
[](https://colab.research.google.com/drive/1eRTPn37ltBbYsISy9Aw2NuI2Aq5CQrD9?usp=sharing)
|
||||||
|
|
||||||
👋 Join our [WeChat](assets/wechat.jpg).
|
👋 Join our [WeChat](assets/wechat.jpg).
|
||||||
|
|
||||||
\[ English | [中文](README_zh.md) \]
|
\[ English | [中文](README_zh.md) \]
|
||||||
|
|
||||||
## LLaMA Board: A One-stop Web UI for Getting Started with LLaMA Factory
|
**Fine-tuning a large language model can be easy as...**
|
||||||
|
|
||||||
Preview LLaMA Board at **[🤗 Spaces](https://huggingface.co/spaces/hiyouga/LLaMA-Board)** or **[ModelScope](https://modelscope.cn/studios/hiyouga/LLaMA-Board)**.
|
https://github.com/hiyouga/LLaMA-Factory/assets/16256802/9840a653-7e9c-41c8-ae89-7ace5698baf6
|
||||||
|
|
||||||
Launch LLaMA Board via `CUDA_VISIBLE_DEVICES=0 python src/train_web.py`. (multiple GPUs are not supported yet in this mode)
|
Choose your path:
|
||||||
|
|
||||||
Here is an example of altering the self-cognition of an instruction-tuned language model within 10 minutes on a single GPU.
|
- **Colab**: https://colab.research.google.com/drive/1eRTPn37ltBbYsISy9Aw2NuI2Aq5CQrD9?usp=sharing
|
||||||
|
- **Local machine**: Please refer to [usage](#getting-started)
|
||||||
https://github.com/hiyouga/LLaMA-Factory/assets/16256802/6ba60acc-e2e2-4bec-b846-2d88920d5ba1
|
|
||||||
|
|
||||||
## Table of Contents
|
## Table of Contents
|
||||||
|
|
||||||
|
- [Features](#features)
|
||||||
- [Benchmark](#benchmark)
|
- [Benchmark](#benchmark)
|
||||||
- [Changelog](#changelog)
|
- [Changelog](#changelog)
|
||||||
- [Supported Models](#supported-models)
|
- [Supported Models](#supported-models)
|
||||||
@@ -38,9 +41,19 @@ https://github.com/hiyouga/LLaMA-Factory/assets/16256802/6ba60acc-e2e2-4bec-b846
|
|||||||
- [Citation](#citation)
|
- [Citation](#citation)
|
||||||
- [Acknowledgement](#acknowledgement)
|
- [Acknowledgement](#acknowledgement)
|
||||||
|
|
||||||
|
## Features
|
||||||
|
|
||||||
|
- **Various models**: LLaMA, Mistral, Mixtral-MoE, Qwen, Yi, Gemma, Baichuan, ChatGLM, Phi, etc.
|
||||||
|
- **Integrated methods**: (Continuous) pre-training, supervised fine-tuning, reward modeling, PPO, DPO and ORPO.
|
||||||
|
- **Scalable resources**: 32-bit full-tuning, 16-bit freeze-tuning, 16-bit LoRA and 2/4/8-bit QLoRA via AQLM/AWQ/GPTQ/LLM.int8.
|
||||||
|
- **Advanced algorithms**: GaLore, DoRA, LongLoRA, LLaMA Pro, LoRA+, LoftQ and Agent tuning.
|
||||||
|
- **Practical tricks**: FlashAttention-2, Unsloth, RoPE scaling, NEFTune and rsLoRA.
|
||||||
|
- **Experiment monitors**: LlamaBoard, TensorBoard, Wandb, MLflow, etc.
|
||||||
|
- **Faster inference**: OpenAI-style API, Gradio UI and CLI with vLLM worker.
|
||||||
|
|
||||||
## Benchmark
|
## Benchmark
|
||||||
|
|
||||||
Compared to ChatGLM's [P-Tuning](https://github.com/THUDM/ChatGLM2-6B/tree/main/ptuning), LLaMA-Factory's LoRA tuning offers up to **3.7 times faster** training speed with a better Rouge score on the advertising text generation task. By leveraging 4-bit quantization technique, LLaMA-Factory's QLoRA further improves the efficiency regarding the GPU memory.
|
Compared to ChatGLM's [P-Tuning](https://github.com/THUDM/ChatGLM2-6B/tree/main/ptuning), LLaMA Factory's LoRA tuning offers up to **3.7 times faster** training speed with a better Rouge score on the advertising text generation task. By leveraging 4-bit quantization technique, LLaMA Factory's QLoRA further improves the efficiency regarding the GPU memory.
|
||||||
|
|
||||||

|

|
||||||
|
|
||||||
@@ -49,18 +62,40 @@ Compared to ChatGLM's [P-Tuning](https://github.com/THUDM/ChatGLM2-6B/tree/main/
|
|||||||
- **Training Speed**: the number of training samples processed per second during the training. (bs=4, cutoff_len=1024)
|
- **Training Speed**: the number of training samples processed per second during the training. (bs=4, cutoff_len=1024)
|
||||||
- **Rouge Score**: Rouge-2 score on the development set of the [advertising text generation](https://aclanthology.org/D19-1321.pdf) task. (bs=4, cutoff_len=1024)
|
- **Rouge Score**: Rouge-2 score on the development set of the [advertising text generation](https://aclanthology.org/D19-1321.pdf) task. (bs=4, cutoff_len=1024)
|
||||||
- **GPU Memory**: Peak GPU memory usage in 4-bit quantized training. (bs=1, cutoff_len=1024)
|
- **GPU Memory**: Peak GPU memory usage in 4-bit quantized training. (bs=1, cutoff_len=1024)
|
||||||
- We adopt `pre_seq_len=128` for ChatGLM's P-Tuning and `lora_rank=32` for LLaMA-Factory's LoRA tuning.
|
- We adopt `pre_seq_len=128` for ChatGLM's P-Tuning and `lora_rank=32` for LLaMA Factory's LoRA tuning.
|
||||||
|
|
||||||
</details>
|
</details>
|
||||||
|
|
||||||
## Changelog
|
## Changelog
|
||||||
|
|
||||||
|
[24/03/31] We supported **[ORPO](https://arxiv.org/abs/2403.07691)**. See `examples/lora_single_gpu` for usage.
|
||||||
|
|
||||||
|
[24/03/21] Our paper "[LlamaFactory: Unified Efficient Fine-Tuning of 100+ Language Models](https://arxiv.org/abs/2403.13372)" is available at arXiv!
|
||||||
|
|
||||||
|
[24/03/20] We supported **FSDP+QLoRA** that fine-tunes a 70B model on 2x24GB GPUs. See `examples/extras/fsdp_qlora` for usage.
|
||||||
|
|
||||||
|
<details><summary>Full Changelog</summary>
|
||||||
|
|
||||||
|
[24/03/13] We supported **[LoRA+](https://arxiv.org/abs/2402.12354)**. See `examples/extras/loraplus` for usage.
|
||||||
|
|
||||||
|
[24/03/07] We supported gradient low-rank projection (**[GaLore](https://arxiv.org/abs/2403.03507)**) algorithm. See `examples/extras/galore` for usage.
|
||||||
|
|
||||||
|
[24/03/07] We integrated **[vLLM](https://github.com/vllm-project/vllm)** for faster and concurrent inference. Try `--infer_backend vllm` to enjoy **270%** inference speed. (LoRA is not yet supported, merge it first.)
|
||||||
|
|
||||||
|
[24/02/28] We supported weight-decomposed LoRA (**[DoRA](https://arxiv.org/abs/2402.09353)**). Try `--use_dora` to activate DoRA training.
|
||||||
|
|
||||||
|
[24/02/15] We supported **block expansion** proposed by [LLaMA Pro](https://github.com/TencentARC/LLaMA-Pro). See `examples/extras/llama_pro` for usage.
|
||||||
|
|
||||||
|
[24/02/05] Qwen1.5 (Qwen2 beta version) series models are supported in LLaMA-Factory. Check this [blog post](https://qwenlm.github.io/blog/qwen1.5/) for details.
|
||||||
|
|
||||||
|
[24/01/18] We supported **agent tuning** for most models, equipping model with tool using abilities by fine-tuning with `--dataset glaive_toolcall`.
|
||||||
|
|
||||||
|
[23/12/23] We supported **[unsloth](https://github.com/unslothai/unsloth)**'s implementation to boost LoRA tuning for the LLaMA, Mistral and Yi models. Try `--use_unsloth` argument to activate unsloth patch. It achieves **170%** speed in our benchmark, check [this page](https://github.com/hiyouga/LLaMA-Factory/wiki/Performance-comparison) for details.
|
||||||
|
|
||||||
[23/12/12] We supported fine-tuning the latest MoE model **[Mixtral 8x7B](https://huggingface.co/mistralai/Mixtral-8x7B-v0.1)** in our framework. See hardware requirement [here](#hardware-requirement).
|
[23/12/12] We supported fine-tuning the latest MoE model **[Mixtral 8x7B](https://huggingface.co/mistralai/Mixtral-8x7B-v0.1)** in our framework. See hardware requirement [here](#hardware-requirement).
|
||||||
|
|
||||||
[23/12/01] We supported downloading pre-trained models and datasets from the **[ModelScope Hub](https://modelscope.cn/models)** for Chinese mainland users. See [this tutorial](#use-modelscope-hub-optional) for usage.
|
[23/12/01] We supported downloading pre-trained models and datasets from the **[ModelScope Hub](https://modelscope.cn/models)** for Chinese mainland users. See [this tutorial](#use-modelscope-hub-optional) for usage.
|
||||||
|
|
||||||
<details><summary>Full Changelog</summary>
|
|
||||||
|
|
||||||
[23/10/21] We supported **[NEFTune](https://arxiv.org/abs/2310.05914)** trick for fine-tuning. Try `--neftune_noise_alpha` argument to activate NEFTune, e.g., `--neftune_noise_alpha 5`.
|
[23/10/21] We supported **[NEFTune](https://arxiv.org/abs/2310.05914)** trick for fine-tuning. Try `--neftune_noise_alpha` argument to activate NEFTune, e.g., `--neftune_noise_alpha 5`.
|
||||||
|
|
||||||
[23/09/27] We supported **$S^2$-Attn** proposed by [LongLoRA](https://github.com/dvlab-research/LongLoRA) for the LLaMA models. Try `--shift_attn` argument to enable shift short attention.
|
[23/09/27] We supported **$S^2$-Attn** proposed by [LongLoRA](https://github.com/dvlab-research/LongLoRA) for the LLaMA models. Try `--shift_attn` argument to enable shift short attention.
|
||||||
@@ -93,20 +128,25 @@ Compared to ChatGLM's [P-Tuning](https://github.com/THUDM/ChatGLM2-6B/tree/main/
|
|||||||
|
|
||||||
| Model | Model size | Default module | Template |
|
| Model | Model size | Default module | Template |
|
||||||
| -------------------------------------------------------- | --------------------------- | ----------------- | --------- |
|
| -------------------------------------------------------- | --------------------------- | ----------------- | --------- |
|
||||||
| [Baichuan](https://github.com/baichuan-inc/Baichuan-13B) | 7B/13B | W_pack | baichuan |
|
| [Baichuan2](https://huggingface.co/baichuan-inc) | 7B/13B | W_pack | baichuan2 |
|
||||||
| [Baichuan2](https://github.com/baichuan-inc/Baichuan2) | 7B/13B | W_pack | baichuan2 |
|
|
||||||
| [BLOOM](https://huggingface.co/bigscience/bloom) | 560M/1.1B/1.7B/3B/7.1B/176B | query_key_value | - |
|
| [BLOOM](https://huggingface.co/bigscience/bloom) | 560M/1.1B/1.7B/3B/7.1B/176B | query_key_value | - |
|
||||||
| [BLOOMZ](https://huggingface.co/bigscience/bloomz) | 560M/1.1B/1.7B/3B/7.1B/176B | query_key_value | - |
|
| [BLOOMZ](https://huggingface.co/bigscience/bloomz) | 560M/1.1B/1.7B/3B/7.1B/176B | query_key_value | - |
|
||||||
| [ChatGLM3](https://github.com/THUDM/ChatGLM3) | 6B | query_key_value | chatglm3 |
|
| [ChatGLM3](https://huggingface.co/THUDM/chatglm3-6b) | 6B | query_key_value | chatglm3 |
|
||||||
| [Falcon](https://huggingface.co/tiiuae/falcon-7b) | 7B/40B/180B | query_key_value | falcon |
|
| [DeepSeek (MoE)](https://huggingface.co/deepseek-ai) | 7B/16B/67B | q_proj,v_proj | deepseek |
|
||||||
| [InternLM](https://github.com/InternLM/InternLM) | 7B/20B | q_proj,v_proj | intern |
|
| [Falcon](https://huggingface.co/tiiuae) | 7B/40B/180B | query_key_value | falcon |
|
||||||
|
| [Gemma](https://huggingface.co/google) | 2B/7B | q_proj,v_proj | gemma |
|
||||||
|
| [InternLM2](https://huggingface.co/internlm) | 7B/20B | wqkv | intern2 |
|
||||||
| [LLaMA](https://github.com/facebookresearch/llama) | 7B/13B/33B/65B | q_proj,v_proj | - |
|
| [LLaMA](https://github.com/facebookresearch/llama) | 7B/13B/33B/65B | q_proj,v_proj | - |
|
||||||
| [LLaMA-2](https://huggingface.co/meta-llama) | 7B/13B/70B | q_proj,v_proj | llama2 |
|
| [LLaMA-2](https://huggingface.co/meta-llama) | 7B/13B/70B | q_proj,v_proj | llama2 |
|
||||||
| [Mistral](https://huggingface.co/mistralai) | 7B | q_proj,v_proj | mistral |
|
| [Mistral/Mixtral](https://huggingface.co/mistralai) | 7B/8x7B | q_proj,v_proj | mistral |
|
||||||
| [Mixtral](https://huggingface.co/mistralai) | 8x7B | q_proj,v_proj | mistral |
|
| [OLMo](https://huggingface.co/allenai) | 1B/7B | att_proj | olmo |
|
||||||
| [Phi-1.5](https://huggingface.co/microsoft/phi-1_5) | 1.3B | Wqkv | - |
|
| [Phi-1.5/2](https://huggingface.co/microsoft) | 1.3B/2.7B | q_proj,v_proj | - |
|
||||||
| [Qwen](https://github.com/QwenLM/Qwen) | 1.8B/7B/14B/72B | c_attn | qwen |
|
| [Qwen](https://huggingface.co/Qwen) | 1.8B/7B/14B/72B | c_attn | qwen |
|
||||||
| [XVERSE](https://github.com/xverse-ai) | 7B/13B/65B | q_proj,v_proj | xverse |
|
| [Qwen1.5 (MoE)](https://huggingface.co/Qwen) | 0.5B/1.8B/4B/7B/14B/32B/72B | q_proj,v_proj | qwen |
|
||||||
|
| [StarCoder2](https://huggingface.co/bigcode) | 3B/7B/15B | q_proj,v_proj | - |
|
||||||
|
| [XVERSE](https://huggingface.co/xverse) | 7B/13B/65B | q_proj,v_proj | xverse |
|
||||||
|
| [Yi](https://huggingface.co/01-ai) | 6B/9B/34B | q_proj,v_proj | yi |
|
||||||
|
| [Yuan](https://huggingface.co/IEITYuan) | 2B/51B/102B | q_proj,v_proj | yuan |
|
||||||
|
|
||||||
> [!NOTE]
|
> [!NOTE]
|
||||||
> **Default module** is used for the `--lora_target` argument, you can use `--lora_target all` to specify all the available modules.
|
> **Default module** is used for the `--lora_target` argument, you can use `--lora_target all` to specify all the available modules.
|
||||||
@@ -115,18 +155,18 @@ Compared to ChatGLM's [P-Tuning](https://github.com/THUDM/ChatGLM2-6B/tree/main/
|
|||||||
|
|
||||||
Please refer to [constants.py](src/llmtuner/extras/constants.py) for a full list of models we supported.
|
Please refer to [constants.py](src/llmtuner/extras/constants.py) for a full list of models we supported.
|
||||||
|
|
||||||
|
You also can add a custom chat template to [template.py](src/llmtuner/data/template.py).
|
||||||
|
|
||||||
## Supported Training Approaches
|
## Supported Training Approaches
|
||||||
|
|
||||||
| Approach | Full-parameter | Partial-parameter | LoRA | QLoRA |
|
| Approach | Full-tuning | Freeze-tuning | LoRA | QLoRA |
|
||||||
| ---------------------- | ------------------ | ------------------ | ------------------ | ------------------ |
|
| ---------------------- | ------------------ | ------------------ | ------------------ | ------------------ |
|
||||||
| Pre-Training | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
|
| Pre-Training | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
|
||||||
| Supervised Fine-Tuning | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
|
| Supervised Fine-Tuning | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
|
||||||
| Reward Modeling | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
|
| Reward Modeling | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
|
||||||
| PPO Training | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
|
| PPO Training | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
|
||||||
| DPO Training | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
|
| DPO Training | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
|
||||||
|
| ORPO Training | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
|
||||||
> [!NOTE]
|
|
||||||
> Use `--quantization_bit 4/8` argument to enable QLoRA.
|
|
||||||
|
|
||||||
## Provided Datasets
|
## Provided Datasets
|
||||||
|
|
||||||
@@ -148,8 +188,8 @@ Please refer to [constants.py](src/llmtuner/extras/constants.py) for a full list
|
|||||||
|
|
||||||
- [Stanford Alpaca (en)](https://github.com/tatsu-lab/stanford_alpaca)
|
- [Stanford Alpaca (en)](https://github.com/tatsu-lab/stanford_alpaca)
|
||||||
- [Stanford Alpaca (zh)](https://github.com/ymcui/Chinese-LLaMA-Alpaca)
|
- [Stanford Alpaca (zh)](https://github.com/ymcui/Chinese-LLaMA-Alpaca)
|
||||||
- [GPT-4 Generated Data (en&zh)](https://github.com/Instruction-Tuning-with-GPT-4/GPT-4-LLM)
|
- [Alpaca GPT4 (en&zh)](https://github.com/Instruction-Tuning-with-GPT-4/GPT-4-LLM)
|
||||||
- [Self-cognition (zh)](data/self_cognition.json)
|
- [Self Cognition (zh)](data/self_cognition.json)
|
||||||
- [Open Assistant (multilingual)](https://huggingface.co/datasets/OpenAssistant/oasst1)
|
- [Open Assistant (multilingual)](https://huggingface.co/datasets/OpenAssistant/oasst1)
|
||||||
- [ShareGPT (zh)](https://huggingface.co/datasets/QingyiSi/Alpaca-CoT/tree/main/Chinese-instruction-collection)
|
- [ShareGPT (zh)](https://huggingface.co/datasets/QingyiSi/Alpaca-CoT/tree/main/Chinese-instruction-collection)
|
||||||
- [Guanaco Dataset (multilingual)](https://huggingface.co/datasets/JosephusCheung/GuanacoDataset)
|
- [Guanaco Dataset (multilingual)](https://huggingface.co/datasets/JosephusCheung/GuanacoDataset)
|
||||||
@@ -165,11 +205,14 @@ Please refer to [constants.py](src/llmtuner/extras/constants.py) for a full list
|
|||||||
- [CodeAlpaca 20k (en)](https://huggingface.co/datasets/sahil2801/CodeAlpaca-20k)
|
- [CodeAlpaca 20k (en)](https://huggingface.co/datasets/sahil2801/CodeAlpaca-20k)
|
||||||
- [Alpaca CoT (multilingual)](https://huggingface.co/datasets/QingyiSi/Alpaca-CoT)
|
- [Alpaca CoT (multilingual)](https://huggingface.co/datasets/QingyiSi/Alpaca-CoT)
|
||||||
- [OpenOrca (en)](https://huggingface.co/datasets/Open-Orca/OpenOrca)
|
- [OpenOrca (en)](https://huggingface.co/datasets/Open-Orca/OpenOrca)
|
||||||
|
- [SlimOrca (en)](https://huggingface.co/datasets/Open-Orca/SlimOrca)
|
||||||
- [MathInstruct (en)](https://huggingface.co/datasets/TIGER-Lab/MathInstruct)
|
- [MathInstruct (en)](https://huggingface.co/datasets/TIGER-Lab/MathInstruct)
|
||||||
- [Firefly 1.1M (zh)](https://huggingface.co/datasets/YeungNLP/firefly-train-1.1M)
|
- [Firefly 1.1M (zh)](https://huggingface.co/datasets/YeungNLP/firefly-train-1.1M)
|
||||||
|
- [Wiki QA (en)](https://huggingface.co/datasets/wiki_qa)
|
||||||
- [Web QA (zh)](https://huggingface.co/datasets/suolyer/webqa)
|
- [Web QA (zh)](https://huggingface.co/datasets/suolyer/webqa)
|
||||||
- [WebNovel (zh)](https://huggingface.co/datasets/zxbsmk/webnovel_cn)
|
- [WebNovel (zh)](https://huggingface.co/datasets/zxbsmk/webnovel_cn)
|
||||||
- [Nectar (en)](https://huggingface.co/datasets/berkeley-nest/Nectar)
|
- [Nectar (en)](https://huggingface.co/datasets/berkeley-nest/Nectar)
|
||||||
|
- [deepctrl (en&zh)](https://www.modelscope.cn/datasets/deepctrl/deepctrl-sft-data)
|
||||||
- [Ad Gen (zh)](https://huggingface.co/datasets/HasturOfficial/adgen)
|
- [Ad Gen (zh)](https://huggingface.co/datasets/HasturOfficial/adgen)
|
||||||
- [ShareGPT Hyperfiltered (en)](https://huggingface.co/datasets/totally-not-an-llm/sharegpt-hyperfiltered-3k)
|
- [ShareGPT Hyperfiltered (en)](https://huggingface.co/datasets/totally-not-an-llm/sharegpt-hyperfiltered-3k)
|
||||||
- [ShareGPT4 (en&zh)](https://huggingface.co/datasets/shibing624/sharegpt_gpt4)
|
- [ShareGPT4 (en&zh)](https://huggingface.co/datasets/shibing624/sharegpt_gpt4)
|
||||||
@@ -177,6 +220,17 @@ Please refer to [constants.py](src/llmtuner/extras/constants.py) for a full list
|
|||||||
- [AgentInstruct (en)](https://huggingface.co/datasets/THUDM/AgentInstruct)
|
- [AgentInstruct (en)](https://huggingface.co/datasets/THUDM/AgentInstruct)
|
||||||
- [LMSYS Chat 1M (en)](https://huggingface.co/datasets/lmsys/lmsys-chat-1m)
|
- [LMSYS Chat 1M (en)](https://huggingface.co/datasets/lmsys/lmsys-chat-1m)
|
||||||
- [Evol Instruct V2 (en)](https://huggingface.co/datasets/WizardLM/WizardLM_evol_instruct_V2_196k)
|
- [Evol Instruct V2 (en)](https://huggingface.co/datasets/WizardLM/WizardLM_evol_instruct_V2_196k)
|
||||||
|
- [Glaive Function Calling V2 (en)](https://huggingface.co/datasets/glaiveai/glaive-function-calling-v2)
|
||||||
|
- [Cosmopedia (en)](https://huggingface.co/datasets/HuggingFaceTB/cosmopedia)
|
||||||
|
- [Open Assistant (de)](https://huggingface.co/datasets/mayflowergmbh/oasst_de)
|
||||||
|
- [Dolly 15k (de)](https://huggingface.co/datasets/mayflowergmbh/dolly-15k_de)
|
||||||
|
- [Alpaca GPT4 (de)](https://huggingface.co/datasets/mayflowergmbh/alpaca-gpt4_de)
|
||||||
|
- [OpenSchnabeltier (de)](https://huggingface.co/datasets/mayflowergmbh/openschnabeltier_de)
|
||||||
|
- [Evol Instruct (de)](https://huggingface.co/datasets/mayflowergmbh/evol-instruct_de)
|
||||||
|
- [Dolphin (de)](https://huggingface.co/datasets/mayflowergmbh/dolphin_de)
|
||||||
|
- [Booksum (de)](https://huggingface.co/datasets/mayflowergmbh/booksum_de)
|
||||||
|
- [Airoboros (de)](https://huggingface.co/datasets/mayflowergmbh/airoboros-3.0_de)
|
||||||
|
- [Ultrachat (de)](https://huggingface.co/datasets/mayflowergmbh/ultra-chat_de)
|
||||||
|
|
||||||
</details>
|
</details>
|
||||||
|
|
||||||
@@ -185,12 +239,12 @@ Please refer to [constants.py](src/llmtuner/extras/constants.py) for a full list
|
|||||||
- [HH-RLHF (en)](https://huggingface.co/datasets/Anthropic/hh-rlhf)
|
- [HH-RLHF (en)](https://huggingface.co/datasets/Anthropic/hh-rlhf)
|
||||||
- [Open Assistant (multilingual)](https://huggingface.co/datasets/OpenAssistant/oasst1)
|
- [Open Assistant (multilingual)](https://huggingface.co/datasets/OpenAssistant/oasst1)
|
||||||
- [GPT-4 Generated Data (en&zh)](https://github.com/Instruction-Tuning-with-GPT-4/GPT-4-LLM)
|
- [GPT-4 Generated Data (en&zh)](https://github.com/Instruction-Tuning-with-GPT-4/GPT-4-LLM)
|
||||||
|
- [Orca DPO (en)](https://huggingface.co/datasets/Intel/orca_dpo_pairs)
|
||||||
- [Nectar (en)](https://huggingface.co/datasets/berkeley-nest/Nectar)
|
- [Nectar (en)](https://huggingface.co/datasets/berkeley-nest/Nectar)
|
||||||
|
- [Orca DPO (de)](https://huggingface.co/datasets/mayflowergmbh/intel_orca_dpo_pairs_de)
|
||||||
|
|
||||||
</details>
|
</details>
|
||||||
|
|
||||||
Please refer to [data/README.md](data/README.md) for details.
|
|
||||||
|
|
||||||
Some datasets require confirmation before using them, so we recommend logging in with your Hugging Face account using these commands.
|
Some datasets require confirmation before using them, so we recommend logging in with your Hugging Face account using these commands.
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
@@ -200,394 +254,198 @@ huggingface-cli login
|
|||||||
|
|
||||||
## Requirement
|
## Requirement
|
||||||
|
|
||||||
- Python 3.8+ and PyTorch 1.13.1+
|
| Mandatory | Minimum | Recommend |
|
||||||
- 🤗Transformers, Datasets, Accelerate, PEFT and TRL
|
| ------------ | ------- | --------- |
|
||||||
- sentencepiece, protobuf and tiktoken
|
| python | 3.8 | 3.10 |
|
||||||
- jieba, rouge-chinese and nltk (used at evaluation and predict)
|
| torch | 1.13.1 | 2.2.0 |
|
||||||
- gradio and matplotlib (used in web UI)
|
| transformers | 4.37.2 | 4.39.3 |
|
||||||
- uvicorn, fastapi and sse-starlette (used in API)
|
| datasets | 2.14.3 | 2.18.0 |
|
||||||
|
| accelerate | 0.27.2 | 0.28.0 |
|
||||||
|
| peft | 0.9.0 | 0.10.0 |
|
||||||
|
| trl | 0.8.1 | 0.8.1 |
|
||||||
|
|
||||||
|
| Optional | Minimum | Recommend |
|
||||||
|
| ------------ | ------- | --------- |
|
||||||
|
| CUDA | 11.6 | 12.2 |
|
||||||
|
| deepspeed | 0.10.0 | 0.14.0 |
|
||||||
|
| bitsandbytes | 0.39.0 | 0.43.0 |
|
||||||
|
| flash-attn | 2.3.0 | 2.5.6 |
|
||||||
|
|
||||||
### Hardware Requirement
|
### Hardware Requirement
|
||||||
|
|
||||||
| Method | Bits | 7B | 13B | 30B | 65B | 8x7B |
|
\* *estimated*
|
||||||
|
|
||||||
|
| Method | Bits | 7B | 13B | 30B | 70B | 8x7B |
|
||||||
| ------ | ---- | ----- | ----- | ----- | ------ | ------ |
|
| ------ | ---- | ----- | ----- | ----- | ------ | ------ |
|
||||||
| Full | 16 | 160GB | 320GB | 600GB | 1200GB | 1000GB |
|
| Full | AMP | 120GB | 240GB | 600GB | 1200GB | 900GB |
|
||||||
| Freeze | 16 | 20GB | 40GB | 120GB | 240GB | 200GB |
|
| Full | 16 | 60GB | 120GB | 300GB | 600GB | 400GB |
|
||||||
| LoRA | 16 | 16GB | 32GB | 80GB | 160GB | 120GB |
|
| GaLore | 16 | 16GB | 32GB | 64GB | 160GB | 120GB |
|
||||||
| QLoRA | 8 | 10GB | 16GB | 40GB | 80GB | 80GB |
|
| Freeze | 16 | 20GB | 40GB | 80GB | 200GB | 160GB |
|
||||||
| QLoRA | 4 | 6GB | 12GB | 24GB | 48GB | 32GB |
|
| LoRA | 16 | 16GB | 32GB | 64GB | 160GB | 120GB |
|
||||||
|
| QLoRA | 8 | 10GB | 20GB | 40GB | 80GB | 60GB |
|
||||||
|
| QLoRA | 4 | 6GB | 12GB | 24GB | 48GB | 30GB |
|
||||||
|
| QLoRA | 2 | 4GB | 8GB | 16GB | 24GB | 18GB |
|
||||||
|
|
||||||
## Getting Started
|
## Getting Started
|
||||||
|
|
||||||
### Data Preparation (optional)
|
### Data Preparation
|
||||||
|
|
||||||
Please refer to [data/README.md](data/README.md) for checking the details about the format of dataset files. You can either use a single `.json` file or a [dataset loading script](https://huggingface.co/docs/datasets/dataset_script) with multiple files to create a custom dataset.
|
Please refer to [data/README.md](data/README.md) for checking the details about the format of dataset files. You can either use datasets on HuggingFace / ModelScope hub or load the dataset in local disk.
|
||||||
|
|
||||||
> [!NOTE]
|
> [!NOTE]
|
||||||
> Please update `data/dataset_info.json` to use your custom dataset. About the format of this file, please refer to `data/README.md`.
|
> Please update `data/dataset_info.json` to use your custom dataset.
|
||||||
|
|
||||||
### Dependence Installation (optional)
|
### Dependence Installation
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
git clone https://github.com/hiyouga/LLaMA-Factory.git
|
git clone https://github.com/hiyouga/LLaMA-Factory.git
|
||||||
conda create -n llama_factory python=3.10
|
conda create -n llama_factory python=3.10
|
||||||
conda activate llama_factory
|
conda activate llama_factory
|
||||||
cd LLaMA-Factory
|
cd LLaMA-Factory
|
||||||
pip install -r requirements.txt
|
pip install -e .[metrics]
|
||||||
```
|
```
|
||||||
|
|
||||||
If you want to enable the quantized LoRA (QLoRA) on the Windows platform, you will be required to install a pre-built version of `bitsandbytes` library, which supports CUDA 11.1 to 12.1.
|
Extra dependencies available: deepspeed, metrics, unsloth, galore, vllm, bitsandbytes, gptq, awq, aqlm, qwen, modelscope, quality
|
||||||
|
|
||||||
|
<details><summary>For Windows users</summary>
|
||||||
|
|
||||||
|
If you want to enable the quantized LoRA (QLoRA) on the Windows platform, you will be required to install a pre-built version of `bitsandbytes` library, which supports CUDA 11.1 to 12.2, please select the appropriate [release version](https://github.com/jllllll/bitsandbytes-windows-webui/releases/tag/wheels) based on your CUDA version.
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
pip install https://github.com/jllllll/bitsandbytes-windows-webui/releases/download/wheels/bitsandbytes-0.39.1-py3-none-win_amd64.whl
|
pip install https://github.com/jllllll/bitsandbytes-windows-webui/releases/download/wheels/bitsandbytes-0.41.2.post2-py3-none-win_amd64.whl
|
||||||
```
|
```
|
||||||
|
|
||||||
### Use ModelScope Hub (optional)
|
To enable FlashAttention-2 on the Windows platform, you need to install the precompiled `flash-attn` library, which supports CUDA 12.1 to 12.2. Please download the corresponding version from [flash-attention](https://github.com/bdashore3/flash-attention/releases) based on your requirements.
|
||||||
|
|
||||||
If you have trouble with downloading models and datasets from Hugging Face, you can use LLaMA-Factory together with ModelScope in the following manner.
|
</details>
|
||||||
|
|
||||||
|
### LLaMA Board GUI
|
||||||
|
|
||||||
|
> [!IMPORTANT]
|
||||||
|
> LLaMA Board GUI only supports training on a single GPU, please use [CLI](#command-line-interface) for distributed training.
|
||||||
|
|
||||||
|
#### Use local environment
|
||||||
|
|
||||||
|
```bash
|
||||||
|
export CUDA_VISIBLE_DEVICES=0 # `set CUDA_VISIBLE_DEVICES=0` for Windows
|
||||||
|
python src/train_web.py # or python -m llmtuner.webui.interface
|
||||||
|
```
|
||||||
|
|
||||||
|
#### Use Docker
|
||||||
|
|
||||||
|
```bash
|
||||||
|
docker build -f ./Dockerfile -t llama-factory:latest .
|
||||||
|
docker run --gpus=all \
|
||||||
|
-v ./hf_cache:/root/.cache/huggingface/ \
|
||||||
|
-v ./data:/app/data \
|
||||||
|
-v ./output:/app/output \
|
||||||
|
-e CUDA_VISIBLE_DEVICES=0 \
|
||||||
|
-p 7860:7860 \
|
||||||
|
--shm-size 16G \
|
||||||
|
--name llama_factory \
|
||||||
|
-d llama-factory:latest
|
||||||
|
```
|
||||||
|
|
||||||
|
#### Use Docker Compose
|
||||||
|
|
||||||
|
```bash
|
||||||
|
docker compose -f ./docker-compose.yml up -d
|
||||||
|
```
|
||||||
|
|
||||||
|
<details><summary>Details about volume</summary>
|
||||||
|
|
||||||
|
- hf_cache: Utilize Hugging Face cache on the host machine. Reassignable if a cache already exists in a different directory.
|
||||||
|
- data: Place datasets on this dir of the host machine so that they can be selected on LLaMA Board GUI.
|
||||||
|
- output: Set export dir to this location so that the merged result can be accessed directly on the host machine.
|
||||||
|
|
||||||
|
</details>
|
||||||
|
|
||||||
|
### Command Line Interface
|
||||||
|
|
||||||
|
See [examples/README.md](examples/README.md) for usage.
|
||||||
|
|
||||||
|
Use `python src/train_bash.py -h` to display arguments description.
|
||||||
|
|
||||||
|
### Deploy with OpenAI-style API and vLLM
|
||||||
|
|
||||||
|
```bash
|
||||||
|
CUDA_VISIBLE_DEVICES=0,1 API_PORT=8000 python src/api_demo.py \
|
||||||
|
--model_name_or_path mistralai/Mistral-7B-Instruct-v0.2 \
|
||||||
|
--template mistral \
|
||||||
|
--infer_backend vllm \
|
||||||
|
--vllm_enforce_eager
|
||||||
|
```
|
||||||
|
|
||||||
|
### Use ModelScope Hub
|
||||||
|
|
||||||
|
If you have trouble with downloading models and datasets from Hugging Face, you can use ModelScope.
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
export USE_MODELSCOPE_HUB=1 # `set USE_MODELSCOPE_HUB=1` for Windows
|
export USE_MODELSCOPE_HUB=1 # `set USE_MODELSCOPE_HUB=1` for Windows
|
||||||
```
|
```
|
||||||
|
|
||||||
Then you can train the corresponding model by specifying a model ID of the ModelScope Hub. (find a full list of model IDs at [ModelScope Hub](https://modelscope.cn/models))
|
Train the model by specifying a model ID of the ModelScope Hub as the `--model_name_or_path`. You can find a full list of model IDs at [ModelScope Hub](https://modelscope.cn/models), e.g., `modelscope/Llama-2-7b-ms`.
|
||||||
|
|
||||||
```bash
|
|
||||||
CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
|
|
||||||
--model_name_or_path modelscope/Llama-2-7b-ms \
|
|
||||||
... # arguments (same as above)
|
|
||||||
```
|
|
||||||
|
|
||||||
LLaMA Board also supports using the models and datasets on the ModelScope Hub.
|
|
||||||
|
|
||||||
```bash
|
|
||||||
CUDA_VISIBLE_DEVICES=0 USE_MODELSCOPE_HUB=1 python src/train_web.py
|
|
||||||
```
|
|
||||||
|
|
||||||
### Train on a single GPU
|
|
||||||
|
|
||||||
> [!IMPORTANT]
|
|
||||||
> If you want to train models on multiple GPUs, please refer to [Distributed Training](#distributed-training).
|
|
||||||
|
|
||||||
#### Pre-Training
|
|
||||||
|
|
||||||
```bash
|
|
||||||
CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
|
|
||||||
--stage pt \
|
|
||||||
--do_train \
|
|
||||||
--model_name_or_path path_to_llama_model \
|
|
||||||
--dataset wiki_demo \
|
|
||||||
--finetuning_type lora \
|
|
||||||
--lora_target q_proj,v_proj \
|
|
||||||
--output_dir path_to_pt_checkpoint \
|
|
||||||
--overwrite_cache \
|
|
||||||
--per_device_train_batch_size 4 \
|
|
||||||
--gradient_accumulation_steps 4 \
|
|
||||||
--lr_scheduler_type cosine \
|
|
||||||
--logging_steps 10 \
|
|
||||||
--save_steps 1000 \
|
|
||||||
--learning_rate 5e-5 \
|
|
||||||
--num_train_epochs 3.0 \
|
|
||||||
--plot_loss \
|
|
||||||
--fp16
|
|
||||||
```
|
|
||||||
|
|
||||||
#### Supervised Fine-Tuning
|
|
||||||
|
|
||||||
```bash
|
|
||||||
CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
|
|
||||||
--stage sft \
|
|
||||||
--do_train \
|
|
||||||
--model_name_or_path path_to_llama_model \
|
|
||||||
--dataset alpaca_gpt4_en \
|
|
||||||
--template default \
|
|
||||||
--finetuning_type lora \
|
|
||||||
--lora_target q_proj,v_proj \
|
|
||||||
--output_dir path_to_sft_checkpoint \
|
|
||||||
--overwrite_cache \
|
|
||||||
--per_device_train_batch_size 4 \
|
|
||||||
--gradient_accumulation_steps 4 \
|
|
||||||
--lr_scheduler_type cosine \
|
|
||||||
--logging_steps 10 \
|
|
||||||
--save_steps 1000 \
|
|
||||||
--learning_rate 5e-5 \
|
|
||||||
--num_train_epochs 3.0 \
|
|
||||||
--plot_loss \
|
|
||||||
--fp16
|
|
||||||
```
|
|
||||||
|
|
||||||
#### Reward Modeling
|
|
||||||
|
|
||||||
```bash
|
|
||||||
CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
|
|
||||||
--stage rm \
|
|
||||||
--do_train \
|
|
||||||
--model_name_or_path path_to_llama_model \
|
|
||||||
--adapter_name_or_path path_to_sft_checkpoint \
|
|
||||||
--create_new_adapter \
|
|
||||||
--dataset comparison_gpt4_en \
|
|
||||||
--template default \
|
|
||||||
--finetuning_type lora \
|
|
||||||
--lora_target q_proj,v_proj \
|
|
||||||
--output_dir path_to_rm_checkpoint \
|
|
||||||
--per_device_train_batch_size 2 \
|
|
||||||
--gradient_accumulation_steps 4 \
|
|
||||||
--lr_scheduler_type cosine \
|
|
||||||
--logging_steps 10 \
|
|
||||||
--save_steps 1000 \
|
|
||||||
--learning_rate 1e-6 \
|
|
||||||
--num_train_epochs 1.0 \
|
|
||||||
--plot_loss \
|
|
||||||
--fp16
|
|
||||||
```
|
|
||||||
|
|
||||||
#### PPO Training
|
|
||||||
|
|
||||||
```bash
|
|
||||||
CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
|
|
||||||
--stage ppo \
|
|
||||||
--do_train \
|
|
||||||
--model_name_or_path path_to_llama_model \
|
|
||||||
--adapter_name_or_path path_to_sft_checkpoint \
|
|
||||||
--create_new_adapter \
|
|
||||||
--dataset alpaca_gpt4_en \
|
|
||||||
--template default \
|
|
||||||
--finetuning_type lora \
|
|
||||||
--lora_target q_proj,v_proj \
|
|
||||||
--reward_model path_to_rm_checkpoint \
|
|
||||||
--output_dir path_to_ppo_checkpoint \
|
|
||||||
--per_device_train_batch_size 2 \
|
|
||||||
--gradient_accumulation_steps 4 \
|
|
||||||
--lr_scheduler_type cosine \
|
|
||||||
--top_k 0 \
|
|
||||||
--top_p 0.9 \
|
|
||||||
--logging_steps 10 \
|
|
||||||
--save_steps 1000 \
|
|
||||||
--learning_rate 1e-5 \
|
|
||||||
--num_train_epochs 1.0 \
|
|
||||||
--plot_loss \
|
|
||||||
--fp16
|
|
||||||
```
|
|
||||||
|
|
||||||
> [!WARNING]
|
|
||||||
> Use `--per_device_train_batch_size=1` for LLaMA-2 models in fp16 PPO training.
|
|
||||||
|
|
||||||
#### DPO Training
|
|
||||||
|
|
||||||
```bash
|
|
||||||
CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
|
|
||||||
--stage dpo \
|
|
||||||
--do_train \
|
|
||||||
--model_name_or_path path_to_llama_model \
|
|
||||||
--adapter_name_or_path path_to_sft_checkpoint \
|
|
||||||
--create_new_adapter \
|
|
||||||
--dataset comparison_gpt4_en \
|
|
||||||
--template default \
|
|
||||||
--finetuning_type lora \
|
|
||||||
--lora_target q_proj,v_proj \
|
|
||||||
--output_dir path_to_dpo_checkpoint \
|
|
||||||
--per_device_train_batch_size 2 \
|
|
||||||
--gradient_accumulation_steps 4 \
|
|
||||||
--lr_scheduler_type cosine \
|
|
||||||
--logging_steps 10 \
|
|
||||||
--save_steps 1000 \
|
|
||||||
--learning_rate 1e-5 \
|
|
||||||
--num_train_epochs 1.0 \
|
|
||||||
--plot_loss \
|
|
||||||
--fp16
|
|
||||||
```
|
|
||||||
|
|
||||||
### Distributed Training
|
|
||||||
|
|
||||||
#### Use Huggingface Accelerate
|
|
||||||
|
|
||||||
```bash
|
|
||||||
accelerate config # configure the environment
|
|
||||||
accelerate launch src/train_bash.py # arguments (same as above)
|
|
||||||
```
|
|
||||||
|
|
||||||
<details><summary>Example config for LoRA training</summary>
|
|
||||||
|
|
||||||
```yaml
|
|
||||||
compute_environment: LOCAL_MACHINE
|
|
||||||
distributed_type: MULTI_GPU
|
|
||||||
downcast_bf16: 'no'
|
|
||||||
gpu_ids: all
|
|
||||||
machine_rank: 0
|
|
||||||
main_training_function: main
|
|
||||||
mixed_precision: fp16
|
|
||||||
num_machines: 1
|
|
||||||
num_processes: 4
|
|
||||||
rdzv_backend: static
|
|
||||||
same_network: true
|
|
||||||
tpu_env: []
|
|
||||||
tpu_use_cluster: false
|
|
||||||
tpu_use_sudo: false
|
|
||||||
use_cpu: false
|
|
||||||
```
|
|
||||||
|
|
||||||
</details>
|
|
||||||
|
|
||||||
#### Use DeepSpeed
|
|
||||||
|
|
||||||
```bash
|
|
||||||
deepspeed --num_gpus 8 --master_port=9901 src/train_bash.py \
|
|
||||||
--deepspeed ds_config.json \
|
|
||||||
... # arguments (same as above)
|
|
||||||
```
|
|
||||||
|
|
||||||
<details><summary>Example config for full-parameter training with DeepSpeed ZeRO-2</summary>
|
|
||||||
|
|
||||||
```json
|
|
||||||
{
|
|
||||||
"train_batch_size": "auto",
|
|
||||||
"train_micro_batch_size_per_gpu": "auto",
|
|
||||||
"gradient_accumulation_steps": "auto",
|
|
||||||
"gradient_clipping": "auto",
|
|
||||||
"zero_allow_untested_optimizer": true,
|
|
||||||
"fp16": {
|
|
||||||
"enabled": "auto",
|
|
||||||
"loss_scale": 0,
|
|
||||||
"initial_scale_power": 16,
|
|
||||||
"loss_scale_window": 1000,
|
|
||||||
"hysteresis": 2,
|
|
||||||
"min_loss_scale": 1
|
|
||||||
},
|
|
||||||
"zero_optimization": {
|
|
||||||
"stage": 2,
|
|
||||||
"allgather_partitions": true,
|
|
||||||
"allgather_bucket_size": 5e8,
|
|
||||||
"reduce_scatter": true,
|
|
||||||
"reduce_bucket_size": 5e8,
|
|
||||||
"overlap_comm": false,
|
|
||||||
"contiguous_gradients": true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
```
|
|
||||||
|
|
||||||
</details>
|
|
||||||
|
|
||||||
### Merge LoRA weights and export model
|
|
||||||
|
|
||||||
```bash
|
|
||||||
python src/export_model.py \
|
|
||||||
--model_name_or_path path_to_llama_model \
|
|
||||||
--adapter_name_or_path path_to_checkpoint \
|
|
||||||
--template default \
|
|
||||||
--finetuning_type lora \
|
|
||||||
--export_dir path_to_export
|
|
||||||
```
|
|
||||||
|
|
||||||
> [!WARNING]
|
|
||||||
> Merging LoRA weights into a quantized model is not supported.
|
|
||||||
|
|
||||||
> [!TIP]
|
|
||||||
> Use `--export_quantization_bit 4` and `--export_quantization_dataset data/c4_demo.json` to quantize the model.
|
|
||||||
|
|
||||||
### API Demo
|
|
||||||
|
|
||||||
```bash
|
|
||||||
python src/api_demo.py \
|
|
||||||
--model_name_or_path path_to_llama_model \
|
|
||||||
--adapter_name_or_path path_to_checkpoint \
|
|
||||||
--template default \
|
|
||||||
--finetuning_type lora
|
|
||||||
```
|
|
||||||
|
|
||||||
> [!TIP]
|
|
||||||
> Visit `http://localhost:8000/docs` for API documentation.
|
|
||||||
|
|
||||||
### CLI Demo
|
|
||||||
|
|
||||||
```bash
|
|
||||||
python src/cli_demo.py \
|
|
||||||
--model_name_or_path path_to_llama_model \
|
|
||||||
--adapter_name_or_path path_to_checkpoint \
|
|
||||||
--template default \
|
|
||||||
--finetuning_type lora
|
|
||||||
```
|
|
||||||
|
|
||||||
### Web Demo
|
|
||||||
|
|
||||||
```bash
|
|
||||||
python src/web_demo.py \
|
|
||||||
--model_name_or_path path_to_llama_model \
|
|
||||||
--adapter_name_or_path path_to_checkpoint \
|
|
||||||
--template default \
|
|
||||||
--finetuning_type lora
|
|
||||||
```
|
|
||||||
|
|
||||||
### Evaluation
|
|
||||||
|
|
||||||
```bash
|
|
||||||
CUDA_VISIBLE_DEVICES=0 python src/evaluate.py \
|
|
||||||
--model_name_or_path path_to_llama_model \
|
|
||||||
--adapter_name_or_path path_to_checkpoint \
|
|
||||||
--template vanilla \
|
|
||||||
--finetuning_type lora
|
|
||||||
--task mmlu \
|
|
||||||
--split test \
|
|
||||||
--lang en \
|
|
||||||
--n_shot 5 \
|
|
||||||
--batch_size 4
|
|
||||||
```
|
|
||||||
|
|
||||||
### Predict
|
|
||||||
|
|
||||||
```bash
|
|
||||||
CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
|
|
||||||
--stage sft \
|
|
||||||
--do_predict \
|
|
||||||
--model_name_or_path path_to_llama_model \
|
|
||||||
--adapter_name_or_path path_to_checkpoint \
|
|
||||||
--dataset alpaca_gpt4_en \
|
|
||||||
--template default \
|
|
||||||
--finetuning_type lora \
|
|
||||||
--output_dir path_to_predict_result \
|
|
||||||
--per_device_eval_batch_size 8 \
|
|
||||||
--max_samples 100 \
|
|
||||||
--predict_with_generate \
|
|
||||||
--fp16
|
|
||||||
```
|
|
||||||
|
|
||||||
> [!WARNING]
|
|
||||||
> Use `--per_device_train_batch_size=1` for LLaMA-2 models in fp16 predict.
|
|
||||||
|
|
||||||
> [!TIP]
|
|
||||||
> We recommend using `--per_device_eval_batch_size=1` and `--max_target_length 128` at 4/8-bit predict.
|
|
||||||
|
|
||||||
## Projects using LLaMA Factory
|
## Projects using LLaMA Factory
|
||||||
|
|
||||||
- **[StarWhisper](https://github.com/Yu-Yang-Li/StarWhisper)**: A large language model for Astronomy, based on ChatGLM2-6B and Qwen-14B.
|
If you have a project that should be incorporated, please contact via email or create a pull request.
|
||||||
- **[DISC-LawLLM](https://github.com/FudanDISC/DISC-LawLLM)**: A large language model specialized in Chinese legal domain, based on Baichuan-13B, is capable of retrieving and reasoning on legal knowledge.
|
|
||||||
- **[Sunsimiao](https://github.com/thomas-yanxin/Sunsimiao)**: A large language model specialized in Chinese medical domain, based on Baichuan-7B and ChatGLM-6B.
|
|
||||||
- **[CareGPT](https://github.com/WangRongsheng/CareGPT)**: A series of large language models for Chinese medical domain, based on LLaMA2-7B and Baichuan-13B.
|
|
||||||
|
|
||||||
> [!TIP]
|
<details><summary>Click to show</summary>
|
||||||
> If you have a project that should be incorporated, please contact via email or create a pull request.
|
|
||||||
|
1. Wang et al. ESRL: Efficient Sampling-based Reinforcement Learning for Sequence Generation. 2023. [[arxiv]](https://arxiv.org/abs/2308.02223)
|
||||||
|
1. Yu et al. Open, Closed, or Small Language Models for Text Classification? 2023. [[arxiv]](https://arxiv.org/abs/2308.10092)
|
||||||
|
1. Wang et al. UbiPhysio: Support Daily Functioning, Fitness, and Rehabilitation with Action Understanding and Feedback in Natural Language. 2023. [[arxiv]](https://arxiv.org/abs/2308.10526)
|
||||||
|
1. Luceri et al. Leveraging Large Language Models to Detect Influence Campaigns in Social Media. 2023. [[arxiv]](https://arxiv.org/abs/2311.07816)
|
||||||
|
1. Zhang et al. Alleviating Hallucinations of Large Language Models through Induced Hallucinations. 2023. [[arxiv]](https://arxiv.org/abs/2312.15710)
|
||||||
|
1. Wang et al. Know Your Needs Better: Towards Structured Understanding of Marketer Demands with Analogical Reasoning Augmented LLMs. 2024. [[arxiv]](https://arxiv.org/abs/2401.04319)
|
||||||
|
1. Wang et al. CANDLE: Iterative Conceptualization and Instantiation Distillation from Large Language Models for Commonsense Reasoning. 2024. [[arxiv]](https://arxiv.org/abs/2401.07286)
|
||||||
|
1. Choi et al. FACT-GPT: Fact-Checking Augmentation via Claim Matching with LLMs. 2024. [[arxiv]](https://arxiv.org/abs/2402.05904)
|
||||||
|
1. Zhang et al. AutoMathText: Autonomous Data Selection with Language Models for Mathematical Texts. 2024. [[arxiv]](https://arxiv.org/abs/2402.07625)
|
||||||
|
1. Lyu et al. KnowTuning: Knowledge-aware Fine-tuning for Large Language Models. 2024. [[arxiv]](https://arxiv.org/abs/2402.11176)
|
||||||
|
1. Yang et al. LaCo: Large Language Model Pruning via Layer Collaps. 2024. [[arxiv]](https://arxiv.org/abs/2402.11187)
|
||||||
|
1. Bhardwaj et al. Language Models are Homer Simpson! Safety Re-Alignment of Fine-tuned Language Models through Task Arithmetic. 2024. [[arxiv]](https://arxiv.org/abs/2402.11746)
|
||||||
|
1. Yang et al. Enhancing Empathetic Response Generation by Augmenting LLMs with Small-scale Empathetic Models. 2024. [[arxiv]](https://arxiv.org/abs/2402.11801)
|
||||||
|
1. Yi et al. Generation Meets Verification: Accelerating Large Language Model Inference with Smart Parallel Auto-Correct Decoding. 2024. [[arxiv]](https://arxiv.org/abs/2402.11809)
|
||||||
|
1. Cao et al. Head-wise Shareable Attention for Large Language Models. 2024. [[arxiv]](https://arxiv.org/abs/2402.11819)
|
||||||
|
1. Zhang et al. Enhancing Multilingual Capabilities of Large Language Models through Self-Distillation from Resource-Rich Languages. 2024. [[arxiv]](https://arxiv.org/abs/2402.12204)
|
||||||
|
1. Kim et al. Efficient and Effective Vocabulary Expansion Towards Multilingual Large Language Models. 2024. [[arxiv]](https://arxiv.org/abs/2402.14714)
|
||||||
|
1. Yu et al. KIEval: A Knowledge-grounded Interactive Evaluation Framework for Large Language Models. 2024. [[arxiv]](https://arxiv.org/abs/2402.15043)
|
||||||
|
1. Huang et al. Key-Point-Driven Data Synthesis with its Enhancement on Mathematical Reasoning. 2024. [[arxiv]](https://arxiv.org/abs/2403.02333)
|
||||||
|
1. Duan et al. Negating Negatives: Alignment without Human Positive Samples via Distributional Dispreference Optimization. 2024. [[arxiv]](https://arxiv.org/abs/2403.03419)
|
||||||
|
1. Xie and Schwertfeger. Empowering Robotics with Large Language Models: osmAG Map Comprehension with LLMs. 2024. [[arxiv]](https://arxiv.org/abs/2403.08228)
|
||||||
|
1. Weller et al. FollowIR: Evaluating and Teaching Information Retrieval Models to Follow Instructions. 2024. [[arxiv]](https://arxiv.org/abs/2403.15246)
|
||||||
|
1. Hongbin Na. CBT-LLM: A Chinese Large Language Model for Cognitive Behavioral Therapy-based Mental Health Question Answering. 2024. [[arxiv]](https://arxiv.org/abs/2403.16008)
|
||||||
|
1. **[StarWhisper](https://github.com/Yu-Yang-Li/StarWhisper)**: A large language model for Astronomy, based on ChatGLM2-6B and Qwen-14B.
|
||||||
|
1. **[DISC-LawLLM](https://github.com/FudanDISC/DISC-LawLLM)**: A large language model specialized in Chinese legal domain, based on Baichuan-13B, is capable of retrieving and reasoning on legal knowledge.
|
||||||
|
1. **[Sunsimiao](https://github.com/thomas-yanxin/Sunsimiao)**: A large language model specialized in Chinese medical domain, based on Baichuan-7B and ChatGLM-6B.
|
||||||
|
1. **[CareGPT](https://github.com/WangRongsheng/CareGPT)**: A series of large language models for Chinese medical domain, based on LLaMA2-7B and Baichuan-13B.
|
||||||
|
1. **[MachineMindset](https://github.com/PKU-YuanGroup/Machine-Mindset/)**: A series of MBTI Personality large language models, capable of giving any LLM 16 different personality types based on different datasets and training methods.
|
||||||
|
|
||||||
|
</details>
|
||||||
|
|
||||||
## License
|
## License
|
||||||
|
|
||||||
This repository is licensed under the [Apache-2.0 License](LICENSE).
|
This repository is licensed under the [Apache-2.0 License](LICENSE).
|
||||||
|
|
||||||
Please follow the model licenses to use the corresponding model weights: [Baichuan](https://huggingface.co/baichuan-inc/Baichuan-13B-Base/resolve/main/Community%20License%20for%20Baichuan-13B%20Model.pdf) / [Baichuan2](https://huggingface.co/baichuan-inc/Baichuan2-13B-Chat/resolve/main/Community%20License%20for%20Baichuan2%20Model.pdf) / [BLOOM](https://huggingface.co/spaces/bigscience/license) / [ChatGLM3](https://github.com/THUDM/ChatGLM3/blob/main/MODEL_LICENSE) / [Falcon](https://huggingface.co/tiiuae/falcon-180B/blob/main/LICENSE.txt) / [InternLM](https://github.com/InternLM/InternLM#license) / [LLaMA](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) / [LLaMA-2](https://ai.meta.com/llama/license/) / [Mistral](LICENSE) / [Phi-1.5](https://huggingface.co/microsoft/phi-1_5/resolve/main/Research%20License.docx) / [Qwen](https://github.com/QwenLM/Qwen/blob/main/LICENSE) / [XVERSE](https://github.com/xverse-ai/XVERSE-13B/blob/main/MODEL_LICENSE.pdf)
|
Please follow the model licenses to use the corresponding model weights: [Baichuan2](https://huggingface.co/baichuan-inc/Baichuan2-7B-Base/blob/main/Community%20License%20for%20Baichuan%202%20Model.pdf) / [BLOOM](https://huggingface.co/spaces/bigscience/license) / [ChatGLM3](https://github.com/THUDM/ChatGLM3/blob/main/MODEL_LICENSE) / [DeepSeek](https://github.com/deepseek-ai/DeepSeek-LLM/blob/main/LICENSE-MODEL) / [Falcon](https://huggingface.co/tiiuae/falcon-180B/blob/main/LICENSE.txt) / [Gemma](https://ai.google.dev/gemma/terms) / [InternLM2](https://github.com/InternLM/InternLM#license) / [LLaMA](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) / [LLaMA-2](https://ai.meta.com/llama/license/) / [Mistral](LICENSE) / [OLMo](LICENSE) / [Phi-1.5/2](https://huggingface.co/microsoft/phi-1_5/resolve/main/Research%20License.docx) / [Qwen](https://github.com/QwenLM/Qwen/blob/main/Tongyi%20Qianwen%20LICENSE%20AGREEMENT) / [StarCoder2](https://huggingface.co/spaces/bigcode/bigcode-model-license-agreement) / [XVERSE](https://github.com/xverse-ai/XVERSE-13B/blob/main/MODEL_LICENSE.pdf) / [Yi](https://huggingface.co/01-ai/Yi-6B/blob/main/LICENSE) / [Yuan](https://github.com/IEIT-Yuan/Yuan-2.0/blob/main/LICENSE-Yuan)
|
||||||
|
|
||||||
## Citation
|
## Citation
|
||||||
|
|
||||||
If this work is helpful, please kindly cite as:
|
If this work is helpful, please kindly cite as:
|
||||||
|
|
||||||
```bibtex
|
```bibtex
|
||||||
@Misc{llama-factory,
|
@article{zheng2024llamafactory,
|
||||||
title = {LLaMA Factory},
|
title={LlamaFactory: Unified Efficient Fine-Tuning of 100+ Language Models},
|
||||||
author = {hiyouga},
|
author={Yaowei Zheng and Richong Zhang and Junhao Zhang and Yanhan Ye and Zheyan Luo and Yongqiang Ma},
|
||||||
howpublished = {\url{https://github.com/hiyouga/LLaMA-Factory}},
|
journal={arXiv preprint arXiv:2403.13372},
|
||||||
year = {2023}
|
year={2024},
|
||||||
|
url={http://arxiv.org/abs/2403.13372}
|
||||||
}
|
}
|
||||||
```
|
```
|
||||||
|
|
||||||
## Acknowledgement
|
## Acknowledgement
|
||||||
|
|
||||||
This repo benefits from [PEFT](https://github.com/huggingface/peft), [QLoRA](https://github.com/artidoro/qlora) and [FastChat](https://github.com/lm-sys/FastChat). Thanks for their wonderful works.
|
This repo benefits from [PEFT](https://github.com/huggingface/peft), [TRL](https://github.com/huggingface/trl), [QLoRA](https://github.com/artidoro/qlora) and [FastChat](https://github.com/lm-sys/FastChat). Thanks for their wonderful works.
|
||||||
|
|
||||||
## Star History
|
## Star History
|
||||||
|
|
||||||
|
|||||||
592
README_zh.md
592
README_zh.md
@@ -5,27 +5,30 @@
|
|||||||
[](https://github.com/hiyouga/LLaMA-Factory/commits/main)
|
[](https://github.com/hiyouga/LLaMA-Factory/commits/main)
|
||||||
[](https://pypi.org/project/llmtuner/)
|
[](https://pypi.org/project/llmtuner/)
|
||||||
[](https://pypi.org/project/llmtuner/)
|
[](https://pypi.org/project/llmtuner/)
|
||||||
|
[](#使用了-llama-factory-的项目)
|
||||||
[](https://github.com/hiyouga/LLaMA-Factory/pulls)
|
[](https://github.com/hiyouga/LLaMA-Factory/pulls)
|
||||||
[](https://discord.gg/rKfvV9r9FK)
|
[](https://discord.gg/rKfvV9r9FK)
|
||||||
[](https://huggingface.co/spaces/hiyouga/LLaMA-Board)
|
[](https://twitter.com/llamafactory_ai)
|
||||||
[](https://modelscope.cn/studios/hiyouga/LLaMA-Board)
|
[](https://huggingface.co/spaces/hiyouga/LLaMA-Board)
|
||||||
|
[](https://modelscope.cn/studios/hiyouga/LLaMA-Board)
|
||||||
|
[](https://colab.research.google.com/drive/1eRTPn37ltBbYsISy9Aw2NuI2Aq5CQrD9?usp=sharing)
|
||||||
|
|
||||||
👋 加入我们的[微信群](assets/wechat.jpg)。
|
👋 加入我们的[微信群](assets/wechat.jpg)。
|
||||||
|
|
||||||
\[ [English](README.md) | 中文 \]
|
\[ [English](README.md) | 中文 \]
|
||||||
|
|
||||||
## LLaMA Board: 通过一站式网页界面快速上手 LLaMA Factory
|
**微调大模型可以像这样轻松…**
|
||||||
|
|
||||||
通过 **[🤗 Spaces](https://huggingface.co/spaces/hiyouga/LLaMA-Board)** 或 **[ModelScope](https://modelscope.cn/studios/hiyouga/LLaMA-Board)** 预览 LLaMA Board。
|
https://github.com/hiyouga/LLaMA-Factory/assets/16256802/ec36a9dd-37f4-4f72-81bd-d76c6d0a6594
|
||||||
|
|
||||||
使用 `CUDA_VISIBLE_DEVICES=0 python src/train_web.py` 启动 LLaMA Board。(该模式目前仅支持单卡训练)
|
选择你的打开方式:
|
||||||
|
|
||||||
下面是使用单张 GPU 在 10 分钟内更改对话式大型语言模型自我认知的示例。
|
- **Colab**:https://colab.research.google.com/drive/1eRTPn37ltBbYsISy9Aw2NuI2Aq5CQrD9?usp=sharing
|
||||||
|
- **本地机器**:请见[如何使用](#如何使用)
|
||||||
https://github.com/hiyouga/LLaMA-Factory/assets/16256802/6ba60acc-e2e2-4bec-b846-2d88920d5ba1
|
|
||||||
|
|
||||||
## 目录
|
## 目录
|
||||||
|
|
||||||
|
- [项目特色](#项目特色)
|
||||||
- [性能指标](#性能指标)
|
- [性能指标](#性能指标)
|
||||||
- [更新日志](#更新日志)
|
- [更新日志](#更新日志)
|
||||||
- [模型](#模型)
|
- [模型](#模型)
|
||||||
@@ -38,9 +41,19 @@ https://github.com/hiyouga/LLaMA-Factory/assets/16256802/6ba60acc-e2e2-4bec-b846
|
|||||||
- [引用](#引用)
|
- [引用](#引用)
|
||||||
- [致谢](#致谢)
|
- [致谢](#致谢)
|
||||||
|
|
||||||
|
## 项目特色
|
||||||
|
|
||||||
|
- **多种模型**:LLaMA、Mistral、Mixtral-MoE、Qwen、Yi、Gemma、Baichuan、ChatGLM、Phi 等等。
|
||||||
|
- **集成方法**:(增量)预训练、指令监督微调、奖励模型训练、PPO 训练、DPO 训练和 ORPO 训练。
|
||||||
|
- **多种精度**:32 比特全参数微调、16 比特冻结微调、16 比特 LoRA 微调和基于 AQLM/AWQ/GPTQ/LLM.int8 的 2/4/8 比特 QLoRA 微调。
|
||||||
|
- **先进算法**:GaLore、DoRA、LongLoRA、LLaMA Pro、LoRA+、LoftQ 和 Agent 微调。
|
||||||
|
- **实用技巧**:FlashAttention-2、Unsloth、RoPE scaling、NEFTune 和 rsLoRA。
|
||||||
|
- **实验监控**:LlamaBoard、TensorBoard、Wandb、MLflow 等等。
|
||||||
|
- **极速推理**:基于 vLLM 的 OpenAI 风格 API、浏览器界面和命令行接口。
|
||||||
|
|
||||||
## 性能指标
|
## 性能指标
|
||||||
|
|
||||||
与 ChatGLM 官方的 [P-Tuning](https://github.com/THUDM/ChatGLM2-6B/tree/main/ptuning) 微调相比,LLaMA-Factory 的 LoRA 微调提供了 **3.7 倍**的加速比,同时在广告文案生成任务上取得了更高的 Rouge 分数。结合 4 比特量化技术,LLaMA-Factory 的 QLoRA 微调进一步降低了 GPU 显存消耗。
|
与 ChatGLM 官方的 [P-Tuning](https://github.com/THUDM/ChatGLM2-6B/tree/main/ptuning) 微调相比,LLaMA Factory 的 LoRA 微调提供了 **3.7 倍**的加速比,同时在广告文案生成任务上取得了更高的 Rouge 分数。结合 4 比特量化技术,LLaMA Factory 的 QLoRA 微调进一步降低了 GPU 显存消耗。
|
||||||
|
|
||||||

|

|
||||||
|
|
||||||
@@ -49,18 +62,40 @@ https://github.com/hiyouga/LLaMA-Factory/assets/16256802/6ba60acc-e2e2-4bec-b846
|
|||||||
- **Training Speed**: 训练阶段每秒处理的样本数量。(批处理大小=4,截断长度=1024)
|
- **Training Speed**: 训练阶段每秒处理的样本数量。(批处理大小=4,截断长度=1024)
|
||||||
- **Rouge Score**: [广告文案生成](https://aclanthology.org/D19-1321.pdf)任务验证集上的 Rouge-2 分数。(批处理大小=4,截断长度=1024)
|
- **Rouge Score**: [广告文案生成](https://aclanthology.org/D19-1321.pdf)任务验证集上的 Rouge-2 分数。(批处理大小=4,截断长度=1024)
|
||||||
- **GPU Memory**: 4 比特量化训练的 GPU 显存峰值。(批处理大小=1,截断长度=1024)
|
- **GPU Memory**: 4 比特量化训练的 GPU 显存峰值。(批处理大小=1,截断长度=1024)
|
||||||
- 我们在 ChatGLM 的 P-Tuning 中采用 `pre_seq_len=128`,在 LLaMA-Factory 的 LoRA 微调中采用 `lora_rank=32`。
|
- 我们在 ChatGLM 的 P-Tuning 中采用 `pre_seq_len=128`,在 LLaMA Factory 的 LoRA 微调中采用 `lora_rank=32`。
|
||||||
|
|
||||||
</details>
|
</details>
|
||||||
|
|
||||||
## 更新日志
|
## 更新日志
|
||||||
|
|
||||||
|
[24/03/31] 我们支持了 **[ORPO](https://arxiv.org/abs/2403.07691)**。详细用法请参照 `examples/lora_single_gpu`。
|
||||||
|
|
||||||
|
[24/03/21] 我们的论文 "[LlamaFactory: Unified Efficient Fine-Tuning of 100+ Language Models](https://arxiv.org/abs/2403.13372)" 可在 arXiv 上查看!
|
||||||
|
|
||||||
|
[24/03/20] 我们支持了能在 2x24GB GPU 上微调 70B 模型的 **FSDP+QLoRA**。详细用法请参照 `examples/extras/fsdp_qlora`。
|
||||||
|
|
||||||
|
<details><summary>展开日志</summary>
|
||||||
|
|
||||||
|
[24/03/13] 我们支持了 **[LoRA+](https://arxiv.org/abs/2402.12354)**。详细用法请参照 `examples/extras/loraplus`。
|
||||||
|
|
||||||
|
[24/03/07] 我们支持了梯度低秩投影(**[GaLore](https://arxiv.org/abs/2403.03507)**)算法。详细用法请参照 `examples/extras/galore`。
|
||||||
|
|
||||||
|
[24/03/07] 我们集成了 **[vLLM](https://github.com/vllm-project/vllm)** 以实现极速并发推理。请使用 `--infer_backend vllm` 来获得 **270%** 的推理速度。(尚不支持 LoRA,请先合并权重。)
|
||||||
|
|
||||||
|
[24/02/28] 我们支持了 **[DoRA](https://arxiv.org/abs/2402.09353)** 微调。请使用 `--use_dora` 参数进行 DoRA 微调。
|
||||||
|
|
||||||
|
[24/02/15] 我们支持了 [LLaMA Pro](https://github.com/TencentARC/LLaMA-Pro) 提出的**块扩展**方法。详细用法请参照 `examples/extras/llama_pro`。
|
||||||
|
|
||||||
|
[24/02/05] Qwen1.5(Qwen2 测试版)系列模型已在 LLaMA-Factory 中实现微调支持。详情请查阅该[博客页面](https://qwenlm.github.io/zh/blog/qwen1.5/)。
|
||||||
|
|
||||||
|
[24/01/18] 我们针对绝大多数模型实现了 **Agent 微调**,微调时指定 `--dataset glaive_toolcall` 即可使模型获得工具调用能力。
|
||||||
|
|
||||||
|
[23/12/23] 我们针对 LLaMA, Mistral 和 Yi 模型支持了 **[unsloth](https://github.com/unslothai/unsloth)** 的 LoRA 训练加速。请使用 `--use_unsloth` 参数启用 unsloth 优化。该方法可提供 **170%** 的训练速度,详情请查阅[此页面](https://github.com/hiyouga/LLaMA-Factory/wiki/Performance-comparison)。
|
||||||
|
|
||||||
[23/12/12] 我们支持了微调最新的混合专家模型 **[Mixtral 8x7B](https://huggingface.co/mistralai/Mixtral-8x7B-v0.1)**。硬件需求请查阅[此处](#硬件依赖)。
|
[23/12/12] 我们支持了微调最新的混合专家模型 **[Mixtral 8x7B](https://huggingface.co/mistralai/Mixtral-8x7B-v0.1)**。硬件需求请查阅[此处](#硬件依赖)。
|
||||||
|
|
||||||
[23/12/01] 我们支持了从 **[魔搭社区](https://modelscope.cn/models)** 下载预训练模型和数据集。详细用法请参照 [此教程](#使用魔搭社区可跳过)。
|
[23/12/01] 我们支持了从 **[魔搭社区](https://modelscope.cn/models)** 下载预训练模型和数据集。详细用法请参照 [此教程](#使用魔搭社区可跳过)。
|
||||||
|
|
||||||
<details><summary>展开日志</summary>
|
|
||||||
|
|
||||||
[23/10/21] 我们支持了 **[NEFTune](https://arxiv.org/abs/2310.05914)** 训练技巧。请使用 `--neftune_noise_alpha` 参数启用 NEFTune,例如 `--neftune_noise_alpha 5`。
|
[23/10/21] 我们支持了 **[NEFTune](https://arxiv.org/abs/2310.05914)** 训练技巧。请使用 `--neftune_noise_alpha` 参数启用 NEFTune,例如 `--neftune_noise_alpha 5`。
|
||||||
|
|
||||||
[23/09/27] 我们针对 LLaMA 模型支持了 [LongLoRA](https://github.com/dvlab-research/LongLoRA) 提出的 **$S^2$-Attn**。请使用 `--shift_attn` 参数以启用该功能。
|
[23/09/27] 我们针对 LLaMA 模型支持了 [LongLoRA](https://github.com/dvlab-research/LongLoRA) 提出的 **$S^2$-Attn**。请使用 `--shift_attn` 参数以启用该功能。
|
||||||
@@ -93,20 +128,25 @@ https://github.com/hiyouga/LLaMA-Factory/assets/16256802/6ba60acc-e2e2-4bec-b846
|
|||||||
|
|
||||||
| 模型名 | 模型大小 | 默认模块 | Template |
|
| 模型名 | 模型大小 | 默认模块 | Template |
|
||||||
| -------------------------------------------------------- | --------------------------- | ----------------- | --------- |
|
| -------------------------------------------------------- | --------------------------- | ----------------- | --------- |
|
||||||
| [Baichuan](https://github.com/baichuan-inc/Baichuan-13B) | 7B/13B | W_pack | baichuan |
|
| [Baichuan2](https://huggingface.co/baichuan-inc) | 7B/13B | W_pack | baichuan2 |
|
||||||
| [Baichuan2](https://github.com/baichuan-inc/Baichuan2) | 7B/13B | W_pack | baichuan2 |
|
|
||||||
| [BLOOM](https://huggingface.co/bigscience/bloom) | 560M/1.1B/1.7B/3B/7.1B/176B | query_key_value | - |
|
| [BLOOM](https://huggingface.co/bigscience/bloom) | 560M/1.1B/1.7B/3B/7.1B/176B | query_key_value | - |
|
||||||
| [BLOOMZ](https://huggingface.co/bigscience/bloomz) | 560M/1.1B/1.7B/3B/7.1B/176B | query_key_value | - |
|
| [BLOOMZ](https://huggingface.co/bigscience/bloomz) | 560M/1.1B/1.7B/3B/7.1B/176B | query_key_value | - |
|
||||||
| [ChatGLM3](https://github.com/THUDM/ChatGLM3) | 6B | query_key_value | chatglm3 |
|
| [ChatGLM3](https://huggingface.co/THUDM/chatglm3-6b) | 6B | query_key_value | chatglm3 |
|
||||||
| [Falcon](https://huggingface.co/tiiuae/falcon-7b) | 7B/40B/180B | query_key_value | falcon |
|
| [DeepSeek (MoE)](https://huggingface.co/deepseek-ai) | 7B/16B/67B | q_proj,v_proj | deepseek |
|
||||||
| [InternLM](https://github.com/InternLM/InternLM) | 7B/20B | q_proj,v_proj | intern |
|
| [Falcon](https://huggingface.co/tiiuae) | 7B/40B/180B | query_key_value | falcon |
|
||||||
|
| [Gemma](https://huggingface.co/google) | 2B/7B | q_proj,v_proj | gemma |
|
||||||
|
| [InternLM2](https://huggingface.co/internlm) | 7B/20B | wqkv | intern2 |
|
||||||
| [LLaMA](https://github.com/facebookresearch/llama) | 7B/13B/33B/65B | q_proj,v_proj | - |
|
| [LLaMA](https://github.com/facebookresearch/llama) | 7B/13B/33B/65B | q_proj,v_proj | - |
|
||||||
| [LLaMA-2](https://huggingface.co/meta-llama) | 7B/13B/70B | q_proj,v_proj | llama2 |
|
| [LLaMA-2](https://huggingface.co/meta-llama) | 7B/13B/70B | q_proj,v_proj | llama2 |
|
||||||
| [Mistral](https://huggingface.co/mistralai) | 7B | q_proj,v_proj | mistral |
|
| [Mistral/Mixtral](https://huggingface.co/mistralai) | 7B/8x7B | q_proj,v_proj | mistral |
|
||||||
| [Mixtral](https://huggingface.co/mistralai) | 8x7B | q_proj,v_proj | mistral |
|
| [OLMo](https://huggingface.co/allenai) | 1B/7B | att_proj | olmo |
|
||||||
| [Phi-1.5](https://huggingface.co/microsoft/phi-1_5) | 1.3B | Wqkv | - |
|
| [Phi-1.5/2](https://huggingface.co/microsoft) | 1.3B/2.7B | q_proj,v_proj | - |
|
||||||
| [Qwen](https://github.com/QwenLM/Qwen) | 1.8B/7B/14B/72B | c_attn | qwen |
|
| [Qwen](https://huggingface.co/Qwen) | 1.8B/7B/14B/72B | c_attn | qwen |
|
||||||
| [XVERSE](https://github.com/xverse-ai) | 7B/13B/65B | q_proj,v_proj | xverse |
|
| [Qwen1.5 (MoE)](https://huggingface.co/Qwen) | 0.5B/1.8B/4B/7B/14B/32B/72B | q_proj,v_proj | qwen |
|
||||||
|
| [StarCoder2](https://huggingface.co/bigcode) | 3B/7B/15B | q_proj,v_proj | - |
|
||||||
|
| [XVERSE](https://huggingface.co/xverse) | 7B/13B/65B | q_proj,v_proj | xverse |
|
||||||
|
| [Yi](https://huggingface.co/01-ai) | 6B/9B/34B | q_proj,v_proj | yi |
|
||||||
|
| [Yuan](https://huggingface.co/IEITYuan) | 2B/51B/102B | q_proj,v_proj | yuan |
|
||||||
|
|
||||||
> [!NOTE]
|
> [!NOTE]
|
||||||
> **默认模块**应作为 `--lora_target` 参数的默认值,可使用 `--lora_target all` 参数指定全部模块。
|
> **默认模块**应作为 `--lora_target` 参数的默认值,可使用 `--lora_target all` 参数指定全部模块。
|
||||||
@@ -115,6 +155,8 @@ https://github.com/hiyouga/LLaMA-Factory/assets/16256802/6ba60acc-e2e2-4bec-b846
|
|||||||
|
|
||||||
项目所支持模型的完整列表请参阅 [constants.py](src/llmtuner/extras/constants.py)。
|
项目所支持模型的完整列表请参阅 [constants.py](src/llmtuner/extras/constants.py)。
|
||||||
|
|
||||||
|
您也可以在 [template.py](src/llmtuner/data/template.py) 中添加自己的对话模板。
|
||||||
|
|
||||||
## 训练方法
|
## 训练方法
|
||||||
|
|
||||||
| 方法 | 全参数训练 | 部分参数训练 | LoRA | QLoRA |
|
| 方法 | 全参数训练 | 部分参数训练 | LoRA | QLoRA |
|
||||||
@@ -124,9 +166,7 @@ https://github.com/hiyouga/LLaMA-Factory/assets/16256802/6ba60acc-e2e2-4bec-b846
|
|||||||
| 奖励模型训练 | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
|
| 奖励模型训练 | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
|
||||||
| PPO 训练 | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
|
| PPO 训练 | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
|
||||||
| DPO 训练 | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
|
| DPO 训练 | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
|
||||||
|
| ORPO 训练 | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
|
||||||
> [!NOTE]
|
|
||||||
> 请使用 `--quantization_bit 4/8` 参数来启用 QLoRA 训练。
|
|
||||||
|
|
||||||
## 数据集
|
## 数据集
|
||||||
|
|
||||||
@@ -148,8 +188,8 @@ https://github.com/hiyouga/LLaMA-Factory/assets/16256802/6ba60acc-e2e2-4bec-b846
|
|||||||
|
|
||||||
- [Stanford Alpaca (en)](https://github.com/tatsu-lab/stanford_alpaca)
|
- [Stanford Alpaca (en)](https://github.com/tatsu-lab/stanford_alpaca)
|
||||||
- [Stanford Alpaca (zh)](https://github.com/ymcui/Chinese-LLaMA-Alpaca)
|
- [Stanford Alpaca (zh)](https://github.com/ymcui/Chinese-LLaMA-Alpaca)
|
||||||
- [GPT-4 Generated Data (en&zh)](https://github.com/Instruction-Tuning-with-GPT-4/GPT-4-LLM)
|
- [Alpaca GPT4 (en&zh)](https://github.com/Instruction-Tuning-with-GPT-4/GPT-4-LLM)
|
||||||
- [Self-cognition (zh)](data/self_cognition.json)
|
- [Self Cognition (zh)](data/self_cognition.json)
|
||||||
- [Open Assistant (multilingual)](https://huggingface.co/datasets/OpenAssistant/oasst1)
|
- [Open Assistant (multilingual)](https://huggingface.co/datasets/OpenAssistant/oasst1)
|
||||||
- [ShareGPT (zh)](https://huggingface.co/datasets/QingyiSi/Alpaca-CoT/tree/main/Chinese-instruction-collection)
|
- [ShareGPT (zh)](https://huggingface.co/datasets/QingyiSi/Alpaca-CoT/tree/main/Chinese-instruction-collection)
|
||||||
- [Guanaco Dataset (multilingual)](https://huggingface.co/datasets/JosephusCheung/GuanacoDataset)
|
- [Guanaco Dataset (multilingual)](https://huggingface.co/datasets/JosephusCheung/GuanacoDataset)
|
||||||
@@ -165,11 +205,14 @@ https://github.com/hiyouga/LLaMA-Factory/assets/16256802/6ba60acc-e2e2-4bec-b846
|
|||||||
- [CodeAlpaca 20k (en)](https://huggingface.co/datasets/sahil2801/CodeAlpaca-20k)
|
- [CodeAlpaca 20k (en)](https://huggingface.co/datasets/sahil2801/CodeAlpaca-20k)
|
||||||
- [Alpaca CoT (multilingual)](https://huggingface.co/datasets/QingyiSi/Alpaca-CoT)
|
- [Alpaca CoT (multilingual)](https://huggingface.co/datasets/QingyiSi/Alpaca-CoT)
|
||||||
- [OpenOrca (en)](https://huggingface.co/datasets/Open-Orca/OpenOrca)
|
- [OpenOrca (en)](https://huggingface.co/datasets/Open-Orca/OpenOrca)
|
||||||
|
- [SlimOrca (en)](https://huggingface.co/datasets/Open-Orca/SlimOrca)
|
||||||
- [MathInstruct (en)](https://huggingface.co/datasets/TIGER-Lab/MathInstruct)
|
- [MathInstruct (en)](https://huggingface.co/datasets/TIGER-Lab/MathInstruct)
|
||||||
- [Firefly 1.1M (zh)](https://huggingface.co/datasets/YeungNLP/firefly-train-1.1M)
|
- [Firefly 1.1M (zh)](https://huggingface.co/datasets/YeungNLP/firefly-train-1.1M)
|
||||||
|
- [Wiki QA (en)](https://huggingface.co/datasets/wiki_qa)
|
||||||
- [Web QA (zh)](https://huggingface.co/datasets/suolyer/webqa)
|
- [Web QA (zh)](https://huggingface.co/datasets/suolyer/webqa)
|
||||||
- [WebNovel (zh)](https://huggingface.co/datasets/zxbsmk/webnovel_cn)
|
- [WebNovel (zh)](https://huggingface.co/datasets/zxbsmk/webnovel_cn)
|
||||||
- [Nectar (en)](https://huggingface.co/datasets/berkeley-nest/Nectar)
|
- [Nectar (en)](https://huggingface.co/datasets/berkeley-nest/Nectar)
|
||||||
|
- [deepctrl (en&zh)](https://www.modelscope.cn/datasets/deepctrl/deepctrl-sft-data)
|
||||||
- [Ad Gen (zh)](https://huggingface.co/datasets/HasturOfficial/adgen)
|
- [Ad Gen (zh)](https://huggingface.co/datasets/HasturOfficial/adgen)
|
||||||
- [ShareGPT Hyperfiltered (en)](https://huggingface.co/datasets/totally-not-an-llm/sharegpt-hyperfiltered-3k)
|
- [ShareGPT Hyperfiltered (en)](https://huggingface.co/datasets/totally-not-an-llm/sharegpt-hyperfiltered-3k)
|
||||||
- [ShareGPT4 (en&zh)](https://huggingface.co/datasets/shibing624/sharegpt_gpt4)
|
- [ShareGPT4 (en&zh)](https://huggingface.co/datasets/shibing624/sharegpt_gpt4)
|
||||||
@@ -177,6 +220,17 @@ https://github.com/hiyouga/LLaMA-Factory/assets/16256802/6ba60acc-e2e2-4bec-b846
|
|||||||
- [AgentInstruct (en)](https://huggingface.co/datasets/THUDM/AgentInstruct)
|
- [AgentInstruct (en)](https://huggingface.co/datasets/THUDM/AgentInstruct)
|
||||||
- [LMSYS Chat 1M (en)](https://huggingface.co/datasets/lmsys/lmsys-chat-1m)
|
- [LMSYS Chat 1M (en)](https://huggingface.co/datasets/lmsys/lmsys-chat-1m)
|
||||||
- [Evol Instruct V2 (en)](https://huggingface.co/datasets/WizardLM/WizardLM_evol_instruct_V2_196k)
|
- [Evol Instruct V2 (en)](https://huggingface.co/datasets/WizardLM/WizardLM_evol_instruct_V2_196k)
|
||||||
|
- [Glaive Function Calling V2 (en)](https://huggingface.co/datasets/glaiveai/glaive-function-calling-v2)
|
||||||
|
- [Cosmopedia (en)](https://huggingface.co/datasets/HuggingFaceTB/cosmopedia)
|
||||||
|
- [Open Assistant (de)](https://huggingface.co/datasets/mayflowergmbh/oasst_de)
|
||||||
|
- [Dolly 15k (de)](https://huggingface.co/datasets/mayflowergmbh/dolly-15k_de)
|
||||||
|
- [Alpaca GPT4 (de)](https://huggingface.co/datasets/mayflowergmbh/alpaca-gpt4_de)
|
||||||
|
- [OpenSchnabeltier (de)](https://huggingface.co/datasets/mayflowergmbh/openschnabeltier_de)
|
||||||
|
- [Evol Instruct (de)](https://huggingface.co/datasets/mayflowergmbh/evol-instruct_de)
|
||||||
|
- [Dolphin (de)](https://huggingface.co/datasets/mayflowergmbh/dolphin_de)
|
||||||
|
- [Booksum (de)](https://huggingface.co/datasets/mayflowergmbh/booksum_de)
|
||||||
|
- [Airoboros (de)](https://huggingface.co/datasets/mayflowergmbh/airoboros-3.0_de)
|
||||||
|
- [Ultrachat (de)](https://huggingface.co/datasets/mayflowergmbh/ultra-chat_de)
|
||||||
|
|
||||||
</details>
|
</details>
|
||||||
|
|
||||||
@@ -185,12 +239,12 @@ https://github.com/hiyouga/LLaMA-Factory/assets/16256802/6ba60acc-e2e2-4bec-b846
|
|||||||
- [HH-RLHF (en)](https://huggingface.co/datasets/Anthropic/hh-rlhf)
|
- [HH-RLHF (en)](https://huggingface.co/datasets/Anthropic/hh-rlhf)
|
||||||
- [Open Assistant (multilingual)](https://huggingface.co/datasets/OpenAssistant/oasst1)
|
- [Open Assistant (multilingual)](https://huggingface.co/datasets/OpenAssistant/oasst1)
|
||||||
- [GPT-4 Generated Data (en&zh)](https://github.com/Instruction-Tuning-with-GPT-4/GPT-4-LLM)
|
- [GPT-4 Generated Data (en&zh)](https://github.com/Instruction-Tuning-with-GPT-4/GPT-4-LLM)
|
||||||
|
- [Orca DPO (en)](https://huggingface.co/datasets/Intel/orca_dpo_pairs)
|
||||||
- [Nectar (en)](https://huggingface.co/datasets/berkeley-nest/Nectar)
|
- [Nectar (en)](https://huggingface.co/datasets/berkeley-nest/Nectar)
|
||||||
|
- [Orca DPO (de)](https://huggingface.co/datasets/mayflowergmbh/intel_orca_dpo_pairs_de)
|
||||||
|
|
||||||
</details>
|
</details>
|
||||||
|
|
||||||
使用方法请参考 [data/README_zh.md](data/README_zh.md) 文件。
|
|
||||||
|
|
||||||
部分数据集的使用需要确认,我们推荐使用下述命令登录您的 Hugging Face 账户。
|
部分数据集的使用需要确认,我们推荐使用下述命令登录您的 Hugging Face 账户。
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
@@ -200,49 +254,129 @@ huggingface-cli login
|
|||||||
|
|
||||||
## 软硬件依赖
|
## 软硬件依赖
|
||||||
|
|
||||||
- Python 3.8+ 和 PyTorch 1.13.1+
|
| 必需项 | 至少 | 推荐 |
|
||||||
- 🤗Transformers, Datasets, Accelerate, PEFT 和 TRL
|
| ------------ | ------- | --------- |
|
||||||
- sentencepiece, protobuf 和 tiktoken
|
| python | 3.8 | 3.10 |
|
||||||
- jieba, rouge-chinese 和 nltk (用于评估及预测)
|
| torch | 1.13.1 | 2.2.0 |
|
||||||
- gradio 和 matplotlib (用于网页端交互)
|
| transformers | 4.37.2 | 4.39.3 |
|
||||||
- uvicorn, fastapi 和 sse-starlette (用于 API)
|
| datasets | 2.14.3 | 2.18.0 |
|
||||||
|
| accelerate | 0.27.2 | 0.28.0 |
|
||||||
|
| peft | 0.9.0 | 0.10.0 |
|
||||||
|
| trl | 0.8.1 | 0.8.1 |
|
||||||
|
|
||||||
|
| 可选项 | 至少 | 推荐 |
|
||||||
|
| ------------ | ------- | --------- |
|
||||||
|
| CUDA | 11.6 | 12.2 |
|
||||||
|
| deepspeed | 0.10.0 | 0.14.0 |
|
||||||
|
| bitsandbytes | 0.39.0 | 0.43.0 |
|
||||||
|
| flash-attn | 2.3.0 | 2.5.6 |
|
||||||
|
|
||||||
### 硬件依赖
|
### 硬件依赖
|
||||||
|
|
||||||
| 训练方法 | 精度 | 7B | 13B | 30B | 65B | 8x7B |
|
\* *估算值*
|
||||||
|
|
||||||
|
| 训练方法 | 精度 | 7B | 13B | 30B | 70B | 8x7B |
|
||||||
| ------- | ---- | ----- | ----- | ----- | ------ | ------ |
|
| ------- | ---- | ----- | ----- | ----- | ------ | ------ |
|
||||||
| 全参数 | 16 | 160GB | 320GB | 600GB | 1200GB | 1000GB |
|
| 全参数 | AMP | 120GB | 240GB | 600GB | 1200GB | 900GB |
|
||||||
| 部分参数 | 16 | 20GB | 40GB | 120GB | 240GB | 200GB |
|
| 全参数 | 16 | 60GB | 120GB | 300GB | 600GB | 400GB |
|
||||||
| LoRA | 16 | 16GB | 32GB | 80GB | 160GB | 120GB |
|
| GaLore | 16 | 16GB | 32GB | 64GB | 160GB | 120GB |
|
||||||
| QLoRA | 8 | 10GB | 16GB | 40GB | 80GB | 80GB |
|
| 部分参数 | 16 | 20GB | 40GB | 80GB | 200GB | 160GB |
|
||||||
| QLoRA | 4 | 6GB | 12GB | 24GB | 48GB | 32GB |
|
| LoRA | 16 | 16GB | 32GB | 64GB | 160GB | 120GB |
|
||||||
|
| QLoRA | 8 | 10GB | 20GB | 40GB | 80GB | 60GB |
|
||||||
|
| QLoRA | 4 | 6GB | 12GB | 24GB | 48GB | 30GB |
|
||||||
|
| QLoRA | 2 | 4GB | 8GB | 16GB | 24GB | 18GB |
|
||||||
|
|
||||||
## 如何使用
|
## 如何使用
|
||||||
|
|
||||||
### 数据准备(可跳过)
|
### 数据准备
|
||||||
|
|
||||||
关于数据集文件的格式,请参考 [data/README_zh.md](data/README_zh.md) 的内容。构建自定义数据集时,既可以使用单个 `.json` 文件,也可以使用一个[数据加载脚本](https://huggingface.co/docs/datasets/dataset_script)和多个文件。
|
关于数据集文件的格式,请参考 [data/README_zh.md](data/README_zh.md) 的内容。你可以使用 HuggingFace / ModelScope 上的数据集或加载本地数据集。
|
||||||
|
|
||||||
> [!NOTE]
|
> [!NOTE]
|
||||||
> 使用自定义数据集时,请更新 `data/dataset_info.json` 文件,该文件的格式请参考 `data/README_zh.md`。
|
> 使用自定义数据集时,请更新 `data/dataset_info.json` 文件。
|
||||||
|
|
||||||
### 环境搭建(可跳过)
|
### 安装依赖
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
git clone https://github.com/hiyouga/LLaMA-Factory.git
|
git clone https://github.com/hiyouga/LLaMA-Factory.git
|
||||||
conda create -n llama_factory python=3.10
|
conda create -n llama_factory python=3.10
|
||||||
conda activate llama_factory
|
conda activate llama_factory
|
||||||
cd LLaMA-Factory
|
cd LLaMA-Factory
|
||||||
pip install -r requirements.txt
|
pip install -e .[metrics]
|
||||||
```
|
```
|
||||||
|
|
||||||
如果要在 Windows 平台上开启量化 LoRA(QLoRA),需要安装预编译的 `bitsandbytes` 库, 支持 CUDA 11.1 到 12.1.
|
可选的额外依赖项:deepspeed、metrics、unsloth、galore、vllm、bitsandbytes、gptq、awq、aqlm、qwen、modelscope、quality
|
||||||
|
|
||||||
|
<details><summary>Windows 用户指南</summary>
|
||||||
|
|
||||||
|
如果要在 Windows 平台上开启量化 LoRA(QLoRA),需要安装预编译的 `bitsandbytes` 库, 支持 CUDA 11.1 到 12.2, 请根据您的 CUDA 版本情况选择适合的[发布版本](https://github.com/jllllll/bitsandbytes-windows-webui/releases/tag/wheels)。
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
pip install https://github.com/jllllll/bitsandbytes-windows-webui/releases/download/wheels/bitsandbytes-0.39.1-py3-none-win_amd64.whl
|
pip install https://github.com/jllllll/bitsandbytes-windows-webui/releases/download/wheels/bitsandbytes-0.41.2.post2-py3-none-win_amd64.whl
|
||||||
```
|
```
|
||||||
|
|
||||||
### 使用魔搭社区(可跳过)
|
如果要在 Windows 平台上开启 FlashAttention-2,需要安装预编译的 `flash-attn` 库,支持 CUDA 12.1 到 12.2,请根据需求到 [flash-attention](https://github.com/bdashore3/flash-attention/releases) 下载对应版本安装。
|
||||||
|
|
||||||
|
</details>
|
||||||
|
|
||||||
|
### LLaMA Board 可视化界面
|
||||||
|
|
||||||
|
> [!IMPORTANT]
|
||||||
|
> LLaMA Board 可视化界面目前仅支持单 GPU 训练,请使用[命令行接口](#命令行接口)来进行分布式训练。
|
||||||
|
|
||||||
|
#### 使用本地环境
|
||||||
|
|
||||||
|
```bash
|
||||||
|
export CUDA_VISIBLE_DEVICES=0 # Windows 使用 `set CUDA_VISIBLE_DEVICES=0`
|
||||||
|
python src/train_web.py # 或 python -m llmtuner.webui.interface
|
||||||
|
```
|
||||||
|
|
||||||
|
#### 使用 Docker
|
||||||
|
|
||||||
|
```bash
|
||||||
|
docker build -f ./Dockerfile -t llama-factory:latest .
|
||||||
|
docker run --gpus=all \
|
||||||
|
-v ./hf_cache:/root/.cache/huggingface/ \
|
||||||
|
-v ./data:/app/data \
|
||||||
|
-v ./output:/app/output \
|
||||||
|
-e CUDA_VISIBLE_DEVICES=0 \
|
||||||
|
-p 7860:7860 \
|
||||||
|
--shm-size 16G \
|
||||||
|
--name llama_factory \
|
||||||
|
-d llama-factory:latest
|
||||||
|
```
|
||||||
|
|
||||||
|
#### 使用 Docker Compose
|
||||||
|
|
||||||
|
```bash
|
||||||
|
docker compose -f ./docker-compose.yml up -d
|
||||||
|
```
|
||||||
|
|
||||||
|
<details><summary>数据卷详情</summary>
|
||||||
|
|
||||||
|
- hf_cache:使用宿主机的 Hugging Face 缓存文件夹,允许更改为新的目录。
|
||||||
|
- data:宿主机中存放数据集的文件夹路径。
|
||||||
|
- output:将导出目录设置为该路径后,即可在宿主机中访问导出后的模型。
|
||||||
|
|
||||||
|
</details>
|
||||||
|
|
||||||
|
### 命令行接口
|
||||||
|
|
||||||
|
使用方法请参考 [examples/README_zh.md](examples/README_zh.md)。
|
||||||
|
|
||||||
|
使用 `python src/train_bash.py -h` 查看参数文档。
|
||||||
|
|
||||||
|
### 使用 OpenAI 风格 API 和 vLLM 部署
|
||||||
|
|
||||||
|
```bash
|
||||||
|
CUDA_VISIBLE_DEVICES=0,1 API_PORT=8000 python src/api_demo.py \
|
||||||
|
--model_name_or_path mistralai/Mistral-7B-Instruct-v0.2 \
|
||||||
|
--template mistral \
|
||||||
|
--infer_backend vllm \
|
||||||
|
--vllm_enforce_eager
|
||||||
|
```
|
||||||
|
|
||||||
|
### 使用魔搭社区
|
||||||
|
|
||||||
如果您在 Hugging Face 模型和数据集的下载中遇到了问题,可以通过下述方法使用魔搭社区。
|
如果您在 Hugging Face 模型和数据集的下载中遇到了问题,可以通过下述方法使用魔搭社区。
|
||||||
|
|
||||||
@@ -250,344 +384,68 @@ pip install https://github.com/jllllll/bitsandbytes-windows-webui/releases/downl
|
|||||||
export USE_MODELSCOPE_HUB=1 # Windows 使用 `set USE_MODELSCOPE_HUB=1`
|
export USE_MODELSCOPE_HUB=1 # Windows 使用 `set USE_MODELSCOPE_HUB=1`
|
||||||
```
|
```
|
||||||
|
|
||||||
接着即可通过指定模型名称来训练对应的模型。(在[魔搭社区](https://modelscope.cn/models)查看所有可用的模型)
|
将 `--model_name_or_path` 设置为模型 ID 来加载对应的模型。在[魔搭社区](https://modelscope.cn/models)查看所有可用的模型,例如 `modelscope/Llama-2-7b-ms`。
|
||||||
|
|
||||||
```bash
|
|
||||||
CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
|
|
||||||
--model_name_or_path modelscope/Llama-2-7b-ms \
|
|
||||||
... # 参数同上
|
|
||||||
```
|
|
||||||
|
|
||||||
LLaMA Board 同样支持魔搭社区的模型和数据集下载。
|
|
||||||
|
|
||||||
```bash
|
|
||||||
CUDA_VISIBLE_DEVICES=0 USE_MODELSCOPE_HUB=1 python src/train_web.py
|
|
||||||
```
|
|
||||||
|
|
||||||
### 单 GPU 训练
|
|
||||||
|
|
||||||
> [!IMPORTANT]
|
|
||||||
> 如果您使用多张 GPU 训练模型,请移步[多 GPU 分布式训练](#多-gpu-分布式训练)部分。
|
|
||||||
|
|
||||||
#### 预训练
|
|
||||||
|
|
||||||
```bash
|
|
||||||
CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
|
|
||||||
--stage pt \
|
|
||||||
--do_train \
|
|
||||||
--model_name_or_path path_to_llama_model \
|
|
||||||
--dataset wiki_demo \
|
|
||||||
--finetuning_type lora \
|
|
||||||
--lora_target q_proj,v_proj \
|
|
||||||
--output_dir path_to_pt_checkpoint \
|
|
||||||
--overwrite_cache \
|
|
||||||
--per_device_train_batch_size 4 \
|
|
||||||
--gradient_accumulation_steps 4 \
|
|
||||||
--lr_scheduler_type cosine \
|
|
||||||
--logging_steps 10 \
|
|
||||||
--save_steps 1000 \
|
|
||||||
--learning_rate 5e-5 \
|
|
||||||
--num_train_epochs 3.0 \
|
|
||||||
--plot_loss \
|
|
||||||
--fp16
|
|
||||||
```
|
|
||||||
|
|
||||||
#### 指令监督微调
|
|
||||||
|
|
||||||
```bash
|
|
||||||
CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
|
|
||||||
--stage sft \
|
|
||||||
--do_train \
|
|
||||||
--model_name_or_path path_to_llama_model \
|
|
||||||
--dataset alpaca_gpt4_zh \
|
|
||||||
--template default \
|
|
||||||
--finetuning_type lora \
|
|
||||||
--lora_target q_proj,v_proj \
|
|
||||||
--output_dir path_to_sft_checkpoint \
|
|
||||||
--overwrite_cache \
|
|
||||||
--per_device_train_batch_size 4 \
|
|
||||||
--gradient_accumulation_steps 4 \
|
|
||||||
--lr_scheduler_type cosine \
|
|
||||||
--logging_steps 10 \
|
|
||||||
--save_steps 1000 \
|
|
||||||
--learning_rate 5e-5 \
|
|
||||||
--num_train_epochs 3.0 \
|
|
||||||
--plot_loss \
|
|
||||||
--fp16
|
|
||||||
```
|
|
||||||
|
|
||||||
#### 奖励模型训练
|
|
||||||
|
|
||||||
```bash
|
|
||||||
CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
|
|
||||||
--stage rm \
|
|
||||||
--do_train \
|
|
||||||
--model_name_or_path path_to_llama_model \
|
|
||||||
--adapter_name_or_path path_to_sft_checkpoint \
|
|
||||||
--create_new_adapter \
|
|
||||||
--dataset comparison_gpt4_zh \
|
|
||||||
--template default \
|
|
||||||
--finetuning_type lora \
|
|
||||||
--lora_target q_proj,v_proj \
|
|
||||||
--output_dir path_to_rm_checkpoint \
|
|
||||||
--per_device_train_batch_size 2 \
|
|
||||||
--gradient_accumulation_steps 4 \
|
|
||||||
--lr_scheduler_type cosine \
|
|
||||||
--logging_steps 10 \
|
|
||||||
--save_steps 1000 \
|
|
||||||
--learning_rate 1e-6 \
|
|
||||||
--num_train_epochs 1.0 \
|
|
||||||
--plot_loss \
|
|
||||||
--fp16
|
|
||||||
```
|
|
||||||
|
|
||||||
#### PPO 训练
|
|
||||||
|
|
||||||
```bash
|
|
||||||
CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
|
|
||||||
--stage ppo \
|
|
||||||
--do_train \
|
|
||||||
--model_name_or_path path_to_llama_model \
|
|
||||||
--adapter_name_or_path path_to_sft_checkpoint \
|
|
||||||
--create_new_adapter \
|
|
||||||
--dataset alpaca_gpt4_zh \
|
|
||||||
--template default \
|
|
||||||
--finetuning_type lora \
|
|
||||||
--lora_target q_proj,v_proj \
|
|
||||||
--reward_model path_to_rm_checkpoint \
|
|
||||||
--output_dir path_to_ppo_checkpoint \
|
|
||||||
--per_device_train_batch_size 2 \
|
|
||||||
--gradient_accumulation_steps 4 \
|
|
||||||
--lr_scheduler_type cosine \
|
|
||||||
--top_k 0 \
|
|
||||||
--top_p 0.9 \
|
|
||||||
--logging_steps 10 \
|
|
||||||
--save_steps 1000 \
|
|
||||||
--learning_rate 1e-5 \
|
|
||||||
--num_train_epochs 1.0 \
|
|
||||||
--plot_loss \
|
|
||||||
--fp16
|
|
||||||
```
|
|
||||||
|
|
||||||
> [!WARNING]
|
|
||||||
> 如果使用 fp16 精度进行 LLaMA-2 模型的 PPO 训练,请使用 `--per_device_train_batch_size=1`。
|
|
||||||
|
|
||||||
#### DPO 训练
|
|
||||||
|
|
||||||
```bash
|
|
||||||
CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
|
|
||||||
--stage dpo \
|
|
||||||
--do_train \
|
|
||||||
--model_name_or_path path_to_llama_model \
|
|
||||||
--adapter_name_or_path path_to_sft_checkpoint \
|
|
||||||
--create_new_adapter \
|
|
||||||
--dataset comparison_gpt4_zh \
|
|
||||||
--template default \
|
|
||||||
--finetuning_type lora \
|
|
||||||
--lora_target q_proj,v_proj \
|
|
||||||
--output_dir path_to_dpo_checkpoint \
|
|
||||||
--per_device_train_batch_size 2 \
|
|
||||||
--gradient_accumulation_steps 4 \
|
|
||||||
--lr_scheduler_type cosine \
|
|
||||||
--logging_steps 10 \
|
|
||||||
--save_steps 1000 \
|
|
||||||
--learning_rate 1e-5 \
|
|
||||||
--num_train_epochs 1.0 \
|
|
||||||
--plot_loss \
|
|
||||||
--fp16
|
|
||||||
```
|
|
||||||
|
|
||||||
### 多 GPU 分布式训练
|
|
||||||
|
|
||||||
#### 使用 Huggingface Accelerate
|
|
||||||
|
|
||||||
```bash
|
|
||||||
accelerate config # 首先配置分布式环境
|
|
||||||
accelerate launch src/train_bash.py # 参数同上
|
|
||||||
```
|
|
||||||
|
|
||||||
<details><summary>LoRA 训练的 Accelerate 配置示例</summary>
|
|
||||||
|
|
||||||
```yaml
|
|
||||||
compute_environment: LOCAL_MACHINE
|
|
||||||
distributed_type: MULTI_GPU
|
|
||||||
downcast_bf16: 'no'
|
|
||||||
gpu_ids: all
|
|
||||||
machine_rank: 0
|
|
||||||
main_training_function: main
|
|
||||||
mixed_precision: fp16
|
|
||||||
num_machines: 1
|
|
||||||
num_processes: 4
|
|
||||||
rdzv_backend: static
|
|
||||||
same_network: true
|
|
||||||
tpu_env: []
|
|
||||||
tpu_use_cluster: false
|
|
||||||
tpu_use_sudo: false
|
|
||||||
use_cpu: false
|
|
||||||
```
|
|
||||||
|
|
||||||
</details>
|
|
||||||
|
|
||||||
#### 使用 DeepSpeed
|
|
||||||
|
|
||||||
```bash
|
|
||||||
deepspeed --num_gpus 8 --master_port=9901 src/train_bash.py \
|
|
||||||
--deepspeed ds_config.json \
|
|
||||||
... # 参数同上
|
|
||||||
```
|
|
||||||
|
|
||||||
<details><summary>使用 DeepSpeed ZeRO-2 进行全参数训练的 DeepSpeed 配置示例</summary>
|
|
||||||
|
|
||||||
```json
|
|
||||||
{
|
|
||||||
"train_batch_size": "auto",
|
|
||||||
"train_micro_batch_size_per_gpu": "auto",
|
|
||||||
"gradient_accumulation_steps": "auto",
|
|
||||||
"gradient_clipping": "auto",
|
|
||||||
"zero_allow_untested_optimizer": true,
|
|
||||||
"fp16": {
|
|
||||||
"enabled": "auto",
|
|
||||||
"loss_scale": 0,
|
|
||||||
"initial_scale_power": 16,
|
|
||||||
"loss_scale_window": 1000,
|
|
||||||
"hysteresis": 2,
|
|
||||||
"min_loss_scale": 1
|
|
||||||
},
|
|
||||||
"zero_optimization": {
|
|
||||||
"stage": 2,
|
|
||||||
"allgather_partitions": true,
|
|
||||||
"allgather_bucket_size": 5e8,
|
|
||||||
"reduce_scatter": true,
|
|
||||||
"reduce_bucket_size": 5e8,
|
|
||||||
"overlap_comm": false,
|
|
||||||
"contiguous_gradients": true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
```
|
|
||||||
|
|
||||||
</details>
|
|
||||||
|
|
||||||
### 合并 LoRA 权重并导出完整模型
|
|
||||||
|
|
||||||
```bash
|
|
||||||
python src/export_model.py \
|
|
||||||
--model_name_or_path path_to_llama_model \
|
|
||||||
--adapter_name_or_path path_to_checkpoint \
|
|
||||||
--template default \
|
|
||||||
--finetuning_type lora \
|
|
||||||
--export_dir path_to_export
|
|
||||||
```
|
|
||||||
|
|
||||||
> [!WARNING]
|
|
||||||
> 尚不支持量化模型的 LoRA 权重合并及导出。
|
|
||||||
|
|
||||||
> [!TIP]
|
|
||||||
> 使用 `--export_quantization_bit 4` 和 `--export_quantization_dataset data/c4_demo.json` 量化导出模型。
|
|
||||||
|
|
||||||
### API 服务
|
|
||||||
|
|
||||||
```bash
|
|
||||||
python src/api_demo.py \
|
|
||||||
--model_name_or_path path_to_llama_model \
|
|
||||||
--adapter_name_or_path path_to_checkpoint \
|
|
||||||
--template default \
|
|
||||||
--finetuning_type lora
|
|
||||||
```
|
|
||||||
|
|
||||||
> [!TIP]
|
|
||||||
> 关于 API 文档请见 `http://localhost:8000/docs`。
|
|
||||||
|
|
||||||
### 命令行测试
|
|
||||||
|
|
||||||
```bash
|
|
||||||
python src/cli_demo.py \
|
|
||||||
--model_name_or_path path_to_llama_model \
|
|
||||||
--adapter_name_or_path path_to_checkpoint \
|
|
||||||
--template default \
|
|
||||||
--finetuning_type lora
|
|
||||||
```
|
|
||||||
|
|
||||||
### 浏览器测试
|
|
||||||
|
|
||||||
```bash
|
|
||||||
python src/web_demo.py \
|
|
||||||
--model_name_or_path path_to_llama_model \
|
|
||||||
--adapter_name_or_path path_to_checkpoint \
|
|
||||||
--template default \
|
|
||||||
--finetuning_type lora
|
|
||||||
```
|
|
||||||
|
|
||||||
### 模型评估
|
|
||||||
|
|
||||||
```bash
|
|
||||||
CUDA_VISIBLE_DEVICES=0 python src/evaluate.py \
|
|
||||||
--model_name_or_path path_to_llama_model \
|
|
||||||
--adapter_name_or_path path_to_checkpoint \
|
|
||||||
--template vanilla \
|
|
||||||
--finetuning_type lora \
|
|
||||||
--task ceval \
|
|
||||||
--split validation \
|
|
||||||
--lang zh \
|
|
||||||
--n_shot 5 \
|
|
||||||
--batch_size 4
|
|
||||||
```
|
|
||||||
|
|
||||||
### 模型预测
|
|
||||||
|
|
||||||
```bash
|
|
||||||
CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
|
|
||||||
--stage sft \
|
|
||||||
--do_predict \
|
|
||||||
--model_name_or_path path_to_llama_model \
|
|
||||||
--adapter_name_or_path path_to_checkpoint \
|
|
||||||
--dataset alpaca_gpt4_zh \
|
|
||||||
--template default \
|
|
||||||
--finetuning_type lora \
|
|
||||||
--output_dir path_to_predict_result \
|
|
||||||
--per_device_eval_batch_size 8 \
|
|
||||||
--max_samples 100 \
|
|
||||||
--predict_with_generate \
|
|
||||||
--fp16
|
|
||||||
```
|
|
||||||
|
|
||||||
> [!WARNING]
|
|
||||||
> 如果使用 fp16 精度进行 LLaMA-2 模型的预测,请使用 `--per_device_eval_batch_size=1`。
|
|
||||||
|
|
||||||
> [!TIP]
|
|
||||||
> 我们建议在量化模型的预测中使用 `--per_device_eval_batch_size=1` 和 `--max_target_length 128`。
|
|
||||||
|
|
||||||
## 使用了 LLaMA Factory 的项目
|
## 使用了 LLaMA Factory 的项目
|
||||||
|
|
||||||
- **[StarWhisper](https://github.com/Yu-Yang-Li/StarWhisper)**: 天文大模型 StarWhisper,基于 ChatGLM2-6B 和 Qwen-14B 在天文数据上微调而得。
|
如果您有项目希望添加至上述列表,请通过邮件联系或者创建一个 PR。
|
||||||
- **[DISC-LawLLM](https://github.com/FudanDISC/DISC-LawLLM)**: 中文法律领域大模型 DISC-LawLLM,基于 Baichuan-13B 微调而得,具有法律推理和知识检索能力。
|
|
||||||
- **[Sunsimiao](https://github.com/thomas-yanxin/Sunsimiao)**: 孙思邈中文医疗大模型 Sumsimiao,基于 Baichuan-7B 和 ChatGLM-6B 在中文医疗数据上微调而得。
|
|
||||||
- **[CareGPT](https://github.com/WangRongsheng/CareGPT)**: 医疗大模型项目 CareGPT,基于 LLaMA2-7B 和 Baichuan-13B 在中文医疗数据上微调而得。
|
|
||||||
|
|
||||||
> [!TIP]
|
<details><summary>点击显示</summary>
|
||||||
> 如果您有项目希望添加至上述列表,请通过邮件联系或者创建一个 PR。
|
|
||||||
|
1. Wang et al. ESRL: Efficient Sampling-based Reinforcement Learning for Sequence Generation. 2023. [[arxiv]](https://arxiv.org/abs/2308.02223)
|
||||||
|
1. Yu et al. Open, Closed, or Small Language Models for Text Classification? 2023. [[arxiv]](https://arxiv.org/abs/2308.10092)
|
||||||
|
1. Wang et al. UbiPhysio: Support Daily Functioning, Fitness, and Rehabilitation with Action Understanding and Feedback in Natural Language. 2023. [[arxiv]](https://arxiv.org/abs/2308.10526)
|
||||||
|
1. Luceri et al. Leveraging Large Language Models to Detect Influence Campaigns in Social Media. 2023. [[arxiv]](https://arxiv.org/abs/2311.07816)
|
||||||
|
1. Zhang et al. Alleviating Hallucinations of Large Language Models through Induced Hallucinations. 2023. [[arxiv]](https://arxiv.org/abs/2312.15710)
|
||||||
|
1. Wang et al. Know Your Needs Better: Towards Structured Understanding of Marketer Demands with Analogical Reasoning Augmented LLMs. 2024. [[arxiv]](https://arxiv.org/abs/2401.04319)
|
||||||
|
1. Wang et al. CANDLE: Iterative Conceptualization and Instantiation Distillation from Large Language Models for Commonsense Reasoning. 2024. [[arxiv]](https://arxiv.org/abs/2401.07286)
|
||||||
|
1. Choi et al. FACT-GPT: Fact-Checking Augmentation via Claim Matching with LLMs. 2024. [[arxiv]](https://arxiv.org/abs/2402.05904)
|
||||||
|
1. Zhang et al. AutoMathText: Autonomous Data Selection with Language Models for Mathematical Texts. 2024. [[arxiv]](https://arxiv.org/abs/2402.07625)
|
||||||
|
1. Lyu et al. KnowTuning: Knowledge-aware Fine-tuning for Large Language Models. 2024. [[arxiv]](https://arxiv.org/abs/2402.11176)
|
||||||
|
1. Yang et al. LaCo: Large Language Model Pruning via Layer Collaps. 2024. [[arxiv]](https://arxiv.org/abs/2402.11187)
|
||||||
|
1. Bhardwaj et al. Language Models are Homer Simpson! Safety Re-Alignment of Fine-tuned Language Models through Task Arithmetic. 2024. [[arxiv]](https://arxiv.org/abs/2402.11746)
|
||||||
|
1. Yang et al. Enhancing Empathetic Response Generation by Augmenting LLMs with Small-scale Empathetic Models. 2024. [[arxiv]](https://arxiv.org/abs/2402.11801)
|
||||||
|
1. Yi et al. Generation Meets Verification: Accelerating Large Language Model Inference with Smart Parallel Auto-Correct Decoding. 2024. [[arxiv]](https://arxiv.org/abs/2402.11809)
|
||||||
|
1. Cao et al. Head-wise Shareable Attention for Large Language Models. 2024. [[arxiv]](https://arxiv.org/abs/2402.11819)
|
||||||
|
1. Zhang et al. Enhancing Multilingual Capabilities of Large Language Models through Self-Distillation from Resource-Rich Languages. 2024. [[arxiv]](https://arxiv.org/abs/2402.12204)
|
||||||
|
1. Kim et al. Efficient and Effective Vocabulary Expansion Towards Multilingual Large Language Models. 2024. [[arxiv]](https://arxiv.org/abs/2402.14714)
|
||||||
|
1. Yu et al. KIEval: A Knowledge-grounded Interactive Evaluation Framework for Large Language Models. 2024. [[arxiv]](https://arxiv.org/abs/2402.15043)
|
||||||
|
1. Huang et al. Key-Point-Driven Data Synthesis with its Enhancement on Mathematical Reasoning. 2024. [[arxiv]](https://arxiv.org/abs/2403.02333)
|
||||||
|
1. Duan et al. Negating Negatives: Alignment without Human Positive Samples via Distributional Dispreference Optimization. 2024. [[arxiv]](https://arxiv.org/abs/2403.03419)
|
||||||
|
1. Xie and Schwertfeger. Empowering Robotics with Large Language Models: osmAG Map Comprehension with LLMs. 2024. [[arxiv]](https://arxiv.org/abs/2403.08228)
|
||||||
|
1. Weller et al. FollowIR: Evaluating and Teaching Information Retrieval Models to Follow Instructions. 2024. [[arxiv]](https://arxiv.org/abs/2403.15246)
|
||||||
|
1. Hongbin Na. CBT-LLM: A Chinese Large Language Model for Cognitive Behavioral Therapy-based Mental Health Question Answering. 2024. [[arxiv]](https://arxiv.org/abs/2403.16008)
|
||||||
|
1. **[StarWhisper](https://github.com/Yu-Yang-Li/StarWhisper)**: 天文大模型 StarWhisper,基于 ChatGLM2-6B 和 Qwen-14B 在天文数据上微调而得。
|
||||||
|
1. **[DISC-LawLLM](https://github.com/FudanDISC/DISC-LawLLM)**: 中文法律领域大模型 DISC-LawLLM,基于 Baichuan-13B 微调而得,具有法律推理和知识检索能力。
|
||||||
|
1. **[Sunsimiao](https://github.com/thomas-yanxin/Sunsimiao)**: 孙思邈中文医疗大模型 Sumsimiao,基于 Baichuan-7B 和 ChatGLM-6B 在中文医疗数据上微调而得。
|
||||||
|
1. **[CareGPT](https://github.com/WangRongsheng/CareGPT)**: 医疗大模型项目 CareGPT,基于 LLaMA2-7B 和 Baichuan-13B 在中文医疗数据上微调而得。
|
||||||
|
1. **[MachineMindset](https://github.com/PKU-YuanGroup/Machine-Mindset/)**:MBTI性格大模型项目,根据数据集与训练方式让任意 LLM 拥有 16 个不同的性格类型。
|
||||||
|
|
||||||
|
</details>
|
||||||
|
|
||||||
## 协议
|
## 协议
|
||||||
|
|
||||||
本仓库的代码依照 [Apache-2.0](LICENSE) 协议开源。
|
本仓库的代码依照 [Apache-2.0](LICENSE) 协议开源。
|
||||||
|
|
||||||
使用模型权重时,请遵循对应的模型协议:[Baichuan](https://huggingface.co/baichuan-inc/Baichuan-13B-Base/resolve/main/Community%20License%20for%20Baichuan-13B%20Model.pdf) / [Baichuan2](https://huggingface.co/baichuan-inc/Baichuan2-13B-Chat/resolve/main/Community%20License%20for%20Baichuan2%20Model.pdf) / [BLOOM](https://huggingface.co/spaces/bigscience/license) / [ChatGLM3](https://github.com/THUDM/ChatGLM3/blob/main/MODEL_LICENSE) / [Falcon](https://huggingface.co/tiiuae/falcon-180B/blob/main/LICENSE.txt) / [InternLM](https://github.com/InternLM/InternLM#license) / [LLaMA](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) / [LLaMA-2](https://ai.meta.com/llama/license/) / [Mistral](LICENSE) / [Phi-1.5](https://huggingface.co/microsoft/phi-1_5/resolve/main/Research%20License.docx) / [Qwen](https://github.com/QwenLM/Qwen/blob/main/LICENSE) / [XVERSE](https://github.com/xverse-ai/XVERSE-13B/blob/main/MODEL_LICENSE.pdf)
|
使用模型权重时,请遵循对应的模型协议:[Baichuan2](https://huggingface.co/baichuan-inc/Baichuan2-7B-Base/blob/main/Community%20License%20for%20Baichuan%202%20Model.pdf) / [BLOOM](https://huggingface.co/spaces/bigscience/license) / [ChatGLM3](https://github.com/THUDM/ChatGLM3/blob/main/MODEL_LICENSE) / [DeepSeek](https://github.com/deepseek-ai/DeepSeek-LLM/blob/main/LICENSE-MODEL) / [Falcon](https://huggingface.co/tiiuae/falcon-180B/blob/main/LICENSE.txt) / [Gemma](https://ai.google.dev/gemma/terms) / [InternLM2](https://github.com/InternLM/InternLM#license) / [LLaMA](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) / [LLaMA-2](https://ai.meta.com/llama/license/) / [Mistral](LICENSE) / [OLMo](LICENSE) / [Phi-1.5/2](https://huggingface.co/microsoft/phi-1_5/resolve/main/Research%20License.docx) / [Qwen](https://github.com/QwenLM/Qwen/blob/main/Tongyi%20Qianwen%20LICENSE%20AGREEMENT) / [StarCoder2](https://huggingface.co/spaces/bigcode/bigcode-model-license-agreement) / [XVERSE](https://github.com/xverse-ai/XVERSE-13B/blob/main/MODEL_LICENSE.pdf) / [Yi](https://huggingface.co/01-ai/Yi-6B/blob/main/LICENSE) / [Yuan](https://github.com/IEIT-Yuan/Yuan-2.0/blob/main/LICENSE-Yuan)
|
||||||
|
|
||||||
## 引用
|
## 引用
|
||||||
|
|
||||||
如果您觉得此项目有帮助,请考虑以下列格式引用
|
如果您觉得此项目有帮助,请考虑以下列格式引用
|
||||||
|
|
||||||
```bibtex
|
```bibtex
|
||||||
@Misc{llama-factory,
|
@article{zheng2024llamafactory,
|
||||||
title = {LLaMA Factory},
|
title={LlamaFactory: Unified Efficient Fine-Tuning of 100+ Language Models},
|
||||||
author = {hiyouga},
|
author={Yaowei Zheng and Richong Zhang and Junhao Zhang and Yanhan Ye and Zheyan Luo and Yongqiang Ma},
|
||||||
howpublished = {\url{https://github.com/hiyouga/LLaMA-Factory}},
|
journal={arXiv preprint arXiv:2403.13372},
|
||||||
year = {2023}
|
year={2024},
|
||||||
|
url={http://arxiv.org/abs/2403.13372}
|
||||||
}
|
}
|
||||||
```
|
```
|
||||||
|
|
||||||
## 致谢
|
## 致谢
|
||||||
|
|
||||||
本项目受益于 [PEFT](https://github.com/huggingface/peft)、[QLoRA](https://github.com/artidoro/qlora) 和 [FastChat](https://github.com/lm-sys/FastChat),感谢以上诸位作者的付出。
|
本项目受益于 [PEFT](https://github.com/huggingface/peft)、[TRL](https://github.com/huggingface/trl)、[QLoRA](https://github.com/artidoro/qlora) 和 [FastChat](https://github.com/lm-sys/FastChat),感谢以上诸位作者的付出。
|
||||||
|
|
||||||
## Star History
|
## Star History
|
||||||
|
|
||||||
|
|||||||
@@ -2,29 +2,40 @@ If you are using a custom dataset, please provide your dataset definition in the
|
|||||||
|
|
||||||
```json
|
```json
|
||||||
"dataset_name": {
|
"dataset_name": {
|
||||||
"hf_hub_url": "the name of the dataset repository on the Hugging Face hub. (if specified, ignore below 3 arguments)",
|
"hf_hub_url": "the name of the dataset repository on the Hugging Face hub. (if specified, ignore script_url and file_name)",
|
||||||
"script_url": "the name of the directory containing a dataset loading script. (if specified, ignore below 2 arguments)",
|
"ms_hub_url": "the name of the dataset repository on the ModelScope hub. (if specified, ignore script_url and file_name)",
|
||||||
|
"script_url": "the name of the directory containing a dataset loading script. (if specified, ignore file_name)",
|
||||||
"file_name": "the name of the dataset file in this directory. (required if above are not specified)",
|
"file_name": "the name of the dataset file in this directory. (required if above are not specified)",
|
||||||
"file_sha1": "the SHA-1 hash value of the dataset file. (optional, does not affect training)",
|
"file_sha1": "the SHA-1 hash value of the dataset file. (optional, does not affect training)",
|
||||||
"subset": "the name of the subset. (optional, default: None)",
|
"subset": "the name of the subset. (optional, default: None)",
|
||||||
"folder": "the name of the folder of the dataset repository on the Hugging Face hub. (optional, default: None)",
|
"folder": "the name of the folder of the dataset repository on the Hugging Face hub. (optional, default: None)",
|
||||||
"ranking": "whether the dataset is a preference dataset or not. (default: false)",
|
"ranking": "whether the dataset is a preference dataset or not. (default: false)",
|
||||||
"formatting": "the format of the dataset. (optional, default: alpaca, can be chosen from {alpaca, sharegpt})",
|
"formatting": "the format of the dataset. (optional, default: alpaca, can be chosen from {alpaca, sharegpt})",
|
||||||
"columns": {
|
"columns (optional)": {
|
||||||
"prompt": "the column name in the dataset containing the prompts. (default: instruction, for alpaca)",
|
"prompt": "the column name in the dataset containing the prompts. (default: instruction)",
|
||||||
"query": "the column name in the dataset containing the queries. (default: input, for alpaca)",
|
"query": "the column name in the dataset containing the queries. (default: input)",
|
||||||
"response": "the column name in the dataset containing the responses. (default: output, for alpaca)",
|
"response": "the column name in the dataset containing the responses. (default: output)",
|
||||||
"history": "the column name in the dataset containing the histories. (default: None, for alpaca)",
|
"history": "the column name in the dataset containing the histories. (default: None)",
|
||||||
"messages": "the column name in the dataset containing the messages. (default: conversations, for sharegpt)",
|
"messages": "the column name in the dataset containing the messages. (default: conversations)",
|
||||||
"role": "the key in the message represents the identity. (default: from, for sharegpt)",
|
"system": "the column name in the dataset containing the system prompts. (default: None)",
|
||||||
"content": "the key in the message represents the content. (default: value, for sharegpt)",
|
"tools": "the column name in the dataset containing the tool description. (default: None)"
|
||||||
"system": "the column name in the dataset containing the system prompts. (default: None, for both)"
|
},
|
||||||
|
"tags (optional, used for the sharegpt format)": {
|
||||||
|
"role_tag": "the key in the message represents the identity. (default: from)",
|
||||||
|
"content_tag": "the key in the message represents the content. (default: value)",
|
||||||
|
"user_tag": "the value of the role_tag represents the user. (default: human)",
|
||||||
|
"assistant_tag": "the value of the role_tag represents the assistant. (default: gpt)",
|
||||||
|
"observation_tag": "the value of the role_tag represents the tool results. (default: observation)",
|
||||||
|
"function_tag": "the value of the role_tag represents the function call. (default: function_call)",
|
||||||
|
"system_tag": "the value of the role_tag represents the system prompt. (default: system, can override system column)"
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
```
|
```
|
||||||
|
|
||||||
Given above, you can use the custom dataset via specifying `--dataset dataset_name`.
|
Given above, you can use the custom dataset via specifying `--dataset dataset_name`.
|
||||||
|
|
||||||
|
----
|
||||||
|
|
||||||
Currently we support dataset in **alpaca** or **sharegpt** format, the dataset in alpaca format should follow the below format:
|
Currently we support dataset in **alpaca** or **sharegpt** format, the dataset in alpaca format should follow the below format:
|
||||||
|
|
||||||
```json
|
```json
|
||||||
@@ -56,9 +67,9 @@ Regarding the above dataset, the `columns` in `dataset_info.json` should be:
|
|||||||
}
|
}
|
||||||
```
|
```
|
||||||
|
|
||||||
where the `prompt` and `response` columns should contain non-empty values, represent instruction and response respectively. The `query` column will be concatenated with the `prompt` column and used as input for the model.
|
The `query` column will be concatenated with the `prompt` column and used as the user prompt, then the user prompt would be `prompt\nquery`. The `response` column represents the model response.
|
||||||
|
|
||||||
The `system` column will be used as the system prompt in the template. The `history` column is a list consisting string tuples representing query-response pairs in history. Note that the responses **in each round will be used for training**.
|
The `system` column will be used as the system prompt. The `history` column is a list consisting string tuples representing prompt-response pairs in the history. Note that the responses in the history **will also be used for training**.
|
||||||
|
|
||||||
For the pre-training datasets, only the `prompt` column will be used for training.
|
For the pre-training datasets, only the `prompt` column will be used for training.
|
||||||
|
|
||||||
@@ -75,6 +86,10 @@ For the preference datasets, the `response` column should be a string list whose
|
|||||||
}
|
}
|
||||||
```
|
```
|
||||||
|
|
||||||
|
Remember to set `"ranking": true` for the preference datasets.
|
||||||
|
|
||||||
|
----
|
||||||
|
|
||||||
The dataset in sharegpt format should follow the below format:
|
The dataset in sharegpt format should follow the below format:
|
||||||
|
|
||||||
```json
|
```json
|
||||||
@@ -90,7 +105,8 @@ The dataset in sharegpt format should follow the below format:
|
|||||||
"value": "model response"
|
"value": "model response"
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
"system": "system prompt (optional)"
|
"system": "system prompt (optional)",
|
||||||
|
"tools": "tool description (optional)"
|
||||||
}
|
}
|
||||||
]
|
]
|
||||||
```
|
```
|
||||||
@@ -101,13 +117,18 @@ Regarding the above dataset, the `columns` in `dataset_info.json` should be:
|
|||||||
"dataset_name": {
|
"dataset_name": {
|
||||||
"columns": {
|
"columns": {
|
||||||
"messages": "conversations",
|
"messages": "conversations",
|
||||||
"role": "from",
|
"system": "system",
|
||||||
"content": "value",
|
"tools": "tools"
|
||||||
"system": "system"
|
},
|
||||||
|
"tags": {
|
||||||
|
"role_tag": "from",
|
||||||
|
"content_tag": "value",
|
||||||
|
"user_tag": "human",
|
||||||
|
"assistant_tag": "gpt"
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
```
|
```
|
||||||
|
|
||||||
where the `messages` column should be a list whose length is even, and follow the `u/a/u/a/u/a` order.
|
where the `messages` column should be a list following the `u/a/u/a/u/a` order.
|
||||||
|
|
||||||
Pre-training datasets and preference datasets are incompatible with the sharegpt format yet.
|
Pre-training datasets and preference datasets are incompatible with the sharegpt format yet.
|
||||||
|
|||||||
@@ -2,29 +2,40 @@
|
|||||||
|
|
||||||
```json
|
```json
|
||||||
"数据集名称": {
|
"数据集名称": {
|
||||||
"hf_hub_url": "Hugging Face 的仓库地址(若指定,则忽略下列三个参数)",
|
"hf_hub_url": "Hugging Face 的数据集仓库地址(若指定,则忽略 script_url 和 file_name)",
|
||||||
"script_url": "包含数据加载脚本的本地文件夹名称(若指定,则忽略下列两个参数)",
|
"ms_hub_url": "ModelScope 的数据集仓库地址(若指定,则忽略 script_url 和 file_name)",
|
||||||
|
"script_url": "包含数据加载脚本的本地文件夹名称(若指定,则忽略 file_name)",
|
||||||
"file_name": "该目录下数据集文件的名称(若上述参数未指定,则此项必需)",
|
"file_name": "该目录下数据集文件的名称(若上述参数未指定,则此项必需)",
|
||||||
"file_sha1": "数据集文件的 SHA-1 哈希值(可选,留空不影响训练)",
|
"file_sha1": "数据集文件的 SHA-1 哈希值(可选,留空不影响训练)",
|
||||||
"subset": "数据集子集的名称(可选,默认:None)",
|
"subset": "数据集子集的名称(可选,默认:None)",
|
||||||
"folder": "Hugging Face 仓库的文件夹名称(可选,默认:None)",
|
"folder": "Hugging Face 仓库的文件夹名称(可选,默认:None)",
|
||||||
"ranking": "是否为偏好数据集(可选,默认:False)",
|
"ranking": "是否为偏好数据集(可选,默认:False)",
|
||||||
"formatting": "数据集格式(可选,默认:alpaca,可以为 alpaca 或 sharegpt)",
|
"formatting": "数据集格式(可选,默认:alpaca,可以为 alpaca 或 sharegpt)",
|
||||||
"columns": {
|
"columns(可选)": {
|
||||||
"prompt": "数据集代表提示词的表头名称(默认:instruction,用于 alpaca 格式)",
|
"prompt": "数据集代表提示词的表头名称(默认:instruction)",
|
||||||
"query": "数据集代表请求的表头名称(默认:input,用于 alpaca 格式)",
|
"query": "数据集代表请求的表头名称(默认:input)",
|
||||||
"response": "数据集代表回答的表头名称(默认:output,用于 alpaca 格式)",
|
"response": "数据集代表回答的表头名称(默认:output)",
|
||||||
"history": "数据集代表历史对话的表头名称(默认:None,用于 alpaca 格式)",
|
"history": "数据集代表历史对话的表头名称(默认:None)",
|
||||||
"messages": "数据集代表消息列表的表头名称(默认:conversations,用于 sharegpt 格式)",
|
"messages": "数据集代表消息列表的表头名称(默认:conversations)",
|
||||||
"role": "消息中代表发送者身份的键名(默认:from,用于 sharegpt 格式)",
|
"system": "数据集代表系统提示的表头名称(默认:None)",
|
||||||
"content": "消息中代表文本内容的键名(默认:value,用于 sharegpt 格式)",
|
"tools": "数据集代表工具描述的表头名称(默认:None)"
|
||||||
"system": "数据集代表系统提示的表头名称(默认:None,用于两种格式)"
|
},
|
||||||
|
"tags(可选,用于 sharegpt 格式)": {
|
||||||
|
"role_tag": "消息中代表发送者身份的键名(默认:from)",
|
||||||
|
"content_tag": "消息中代表文本内容的键名(默认:value)",
|
||||||
|
"user_tag": "消息中代表用户的 role_tag(默认:human)",
|
||||||
|
"assistant_tag": "消息中代表助手的 role_tag(默认:gpt)",
|
||||||
|
"observation_tag": "消息中代表工具返回结果的 role_tag(默认:observation)",
|
||||||
|
"function_tag": "消息中代表工具调用的 role_tag(默认:function_call)",
|
||||||
|
"system_tag": "消息中代表系统提示的 role_tag(默认:system,会覆盖 system 列)"
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
```
|
```
|
||||||
|
|
||||||
添加后可通过指定 `--dataset 数据集名称` 参数使用自定义数据集。
|
添加后可通过指定 `--dataset 数据集名称` 参数使用自定义数据集。
|
||||||
|
|
||||||
|
----
|
||||||
|
|
||||||
该项目目前支持两种格式的数据集:**alpaca** 和 **sharegpt**,其中 alpaca 格式的数据集按照以下方式组织:
|
该项目目前支持两种格式的数据集:**alpaca** 和 **sharegpt**,其中 alpaca 格式的数据集按照以下方式组织:
|
||||||
|
|
||||||
```json
|
```json
|
||||||
@@ -56,9 +67,9 @@
|
|||||||
}
|
}
|
||||||
```
|
```
|
||||||
|
|
||||||
其中 `prompt` 和 `response` 列应当是非空的字符串,分别代表用户指令和模型回答。`query` 列的内容将会和 `prompt` 列拼接作为模型输入。
|
其中 `query` 列对应的内容会与 `prompt` 列对应的内容拼接后作为用户指令,即用户指令为 `prompt\nquery`。`response` 列对应的内容为模型回答。
|
||||||
|
|
||||||
`system` 为模板中的系统提示词。`history` 列是由多个字符串二元组构成的列表,分别代表历史消息中每轮的指令和回答。注意每轮的模型回答**均会被用于训练**。
|
`system` 列对应的内容将被作为系统提示词。`history` 列是由多个字符串二元组构成的列表,分别代表历史消息中每轮的指令和回答。注意历史消息中的回答**也会被用于训练**。
|
||||||
|
|
||||||
对于预训练数据集,仅 `prompt` 列中的内容会用于模型训练。
|
对于预训练数据集,仅 `prompt` 列中的内容会用于模型训练。
|
||||||
|
|
||||||
@@ -75,6 +86,10 @@
|
|||||||
}
|
}
|
||||||
```
|
```
|
||||||
|
|
||||||
|
添加偏好数据集需要额外指定 `"ranking": true`。
|
||||||
|
|
||||||
|
----
|
||||||
|
|
||||||
而 sharegpt 格式的数据集按照以下方式组织:
|
而 sharegpt 格式的数据集按照以下方式组织:
|
||||||
|
|
||||||
```json
|
```json
|
||||||
@@ -90,7 +105,8 @@
|
|||||||
"value": "模型回答"
|
"value": "模型回答"
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
"system": "系统提示词(选填)"
|
"system": "系统提示词(选填)",
|
||||||
|
"tools": "工具描述(选填)"
|
||||||
}
|
}
|
||||||
]
|
]
|
||||||
```
|
```
|
||||||
@@ -101,13 +117,18 @@
|
|||||||
"数据集名称": {
|
"数据集名称": {
|
||||||
"columns": {
|
"columns": {
|
||||||
"messages": "conversations",
|
"messages": "conversations",
|
||||||
"role": "from",
|
"system": "system",
|
||||||
"content": "value",
|
"tools": "tools"
|
||||||
"system": "system"
|
},
|
||||||
|
"tags": {
|
||||||
|
"role_tag": "from",
|
||||||
|
"content_tag": "value",
|
||||||
|
"user_tag": "human",
|
||||||
|
"assistant_tag": "gpt"
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
```
|
```
|
||||||
|
|
||||||
其中 `messages` 列必须为偶数长度的列表,且符合 `用户/模型/用户/模型/用户/模型` 的顺序。
|
其中 `messages` 列应当是一个列表,且符合 `用户/模型/用户/模型/用户/模型` 的顺序。
|
||||||
|
|
||||||
预训练数据集和偏好数据集尚不支持 sharegpt 格式。
|
预训练数据集和偏好数据集尚不支持 sharegpt 格式。
|
||||||
|
|||||||
@@ -1 +1 @@
|
|||||||
fc9a6a3458caca2af8dafc6181773fe10c6d8657
|
34c723573fbc2d7601f6d9c882ccf5aa4f9bcc4b
|
||||||
@@ -1,7 +1,10 @@
|
|||||||
|
import os
|
||||||
import json
|
import json
|
||||||
import datasets
|
import datasets
|
||||||
|
|
||||||
|
|
||||||
|
_HF_ENDPOINT = os.getenv("HF_ENDPOINT", "https://huggingface.co")
|
||||||
|
|
||||||
_DESCRIPTION = "BELLE multiturn chat dataset."
|
_DESCRIPTION = "BELLE multiturn chat dataset."
|
||||||
|
|
||||||
_CITATION = """\
|
_CITATION = """\
|
||||||
@@ -13,9 +16,9 @@ _CITATION = """\
|
|||||||
}
|
}
|
||||||
"""
|
"""
|
||||||
|
|
||||||
_HOMEPAGE = "https://huggingface.co/datasets/BelleGroup/multiturn_chat_0.8M"
|
_HOMEPAGE = "{}/datasets/BelleGroup/multiturn_chat_0.8M".format(_HF_ENDPOINT)
|
||||||
_LICENSE = "gpl-3.0"
|
_LICENSE = "gpl-3.0"
|
||||||
_URL = "https://huggingface.co/datasets/BelleGroup/multiturn_chat_0.8M/resolve/main/multiturn_chat_0.8M.json"
|
_URL = "{}/datasets/BelleGroup/multiturn_chat_0.8M/resolve/main/multiturn_chat_0.8M.json".format(_HF_ENDPOINT)
|
||||||
|
|
||||||
|
|
||||||
class BelleMultiturn(datasets.GeneratorBasedBuilder):
|
class BelleMultiturn(datasets.GeneratorBasedBuilder):
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
import json
|
import json
|
||||||
import datasets
|
import datasets
|
||||||
from typing import Any, Dict, List
|
from typing import Any, Dict, Generator, List, Tuple
|
||||||
|
|
||||||
|
|
||||||
_DESCRIPTION = "An example of dataset."
|
_DESCRIPTION = "An example of dataset."
|
||||||
@@ -40,7 +40,7 @@ class ExampleDataset(datasets.GeneratorBasedBuilder):
|
|||||||
)
|
)
|
||||||
]
|
]
|
||||||
|
|
||||||
def _generate_examples(self, filepath: str) -> Dict[int, Dict[str, Any]]:
|
def _generate_examples(self, filepath: str) -> Generator[Tuple[int, Dict[str, Any]], None, None]:
|
||||||
example_dataset = json.load(open(filepath, "r", encoding="utf-8"))
|
example_dataset = json.load(open(filepath, "r", encoding="utf-8"))
|
||||||
for key, example in enumerate(example_dataset):
|
for key, example in enumerate(example_dataset):
|
||||||
yield key, example
|
yield key, example
|
||||||
|
|||||||
1
data/glaive_toolcall_10k.json.REMOVED.git-id
Normal file
1
data/glaive_toolcall_10k.json.REMOVED.git-id
Normal file
@@ -0,0 +1 @@
|
|||||||
|
4748dff00d1dc42768a5b6cc772143c313017812
|
||||||
@@ -1,13 +1,14 @@
|
|||||||
|
import os
|
||||||
import json
|
import json
|
||||||
import datasets
|
import datasets
|
||||||
from typing import List
|
from typing import List
|
||||||
|
|
||||||
|
_HF_ENDPOINT = os.getenv("HF_ENDPOINT", "https://huggingface.co")
|
||||||
_DESCRIPTION = "Human preference data about helpfulness and harmlessness."
|
_DESCRIPTION = "Human preference data about helpfulness and harmlessness."
|
||||||
_CITATION = ""
|
_CITATION = ""
|
||||||
_HOMEPAGE = "https://huggingface.co/datasets/Anthropic/hh-rlhf"
|
_HOMEPAGE = "{}/datasets/Anthropic/hh-rlhf".format(_HF_ENDPOINT)
|
||||||
_LICENSE = "mit"
|
_LICENSE = "mit"
|
||||||
_URL = "https://huggingface.co/datasets/Anthropic/hh-rlhf/resolve/main/"
|
_URL = "{}/datasets/Anthropic/hh-rlhf/resolve/main/".format(_HF_ENDPOINT)
|
||||||
_URLS = {
|
_URLS = {
|
||||||
"train": [
|
"train": [
|
||||||
_URL + "harmless-base/train.jsonl.gz",
|
_URL + "harmless-base/train.jsonl.gz",
|
||||||
|
|||||||
1
data/orca_rlhf.json.REMOVED.git-id
Normal file
1
data/orca_rlhf.json.REMOVED.git-id
Normal file
@@ -0,0 +1 @@
|
|||||||
|
736bcedea2b24a1414765c6d69cbdafaea839f3c
|
||||||
@@ -1,7 +1,9 @@
|
|||||||
|
import os
|
||||||
import json
|
import json
|
||||||
import datasets
|
import datasets
|
||||||
from typing import List
|
from typing import List
|
||||||
|
|
||||||
|
_HF_ENDPOINT = os.getenv("HF_ENDPOINT", "https://huggingface.co")
|
||||||
|
|
||||||
_DESCRIPTION = "UltraChat: Large-scale, Informative, and Diverse Multi-round Dialogue Data."
|
_DESCRIPTION = "UltraChat: Large-scale, Informative, and Diverse Multi-round Dialogue Data."
|
||||||
|
|
||||||
@@ -16,9 +18,9 @@ _CITATION = """\
|
|||||||
}
|
}
|
||||||
"""
|
"""
|
||||||
|
|
||||||
_HOMEPAGE = "https://huggingface.co/datasets/stingning/ultrachat"
|
_HOMEPAGE = "{}/datasets/stingning/ultrachat".format(_HF_ENDPOINT)
|
||||||
_LICENSE = "cc-by-nc-4.0"
|
_LICENSE = "cc-by-nc-4.0"
|
||||||
_BASE_DATA_URL = "https://huggingface.co/datasets/stingning/ultrachat/resolve/main/train_{idx}.jsonl"
|
_BASE_DATA_URL = "{}/datasets/stingning/ultrachat/resolve/main/train_{{idx}}.jsonl".format(_HF_ENDPOINT)
|
||||||
|
|
||||||
|
|
||||||
class UltraChat(datasets.GeneratorBasedBuilder):
|
class UltraChat(datasets.GeneratorBasedBuilder):
|
||||||
|
|||||||
25
docker-compose.yml
Normal file
25
docker-compose.yml
Normal file
@@ -0,0 +1,25 @@
|
|||||||
|
version: '3.8'
|
||||||
|
|
||||||
|
services:
|
||||||
|
llama-factory:
|
||||||
|
build:
|
||||||
|
dockerfile: Dockerfile
|
||||||
|
context: .
|
||||||
|
container_name: llama_factory
|
||||||
|
volumes:
|
||||||
|
- ./hf_cache:/root/.cache/huggingface/
|
||||||
|
- ./data:/app/data
|
||||||
|
- ./output:/app/output
|
||||||
|
environment:
|
||||||
|
- CUDA_VISIBLE_DEVICES=0
|
||||||
|
ports:
|
||||||
|
- "7860:7860"
|
||||||
|
ipc: host
|
||||||
|
deploy:
|
||||||
|
resources:
|
||||||
|
reservations:
|
||||||
|
devices:
|
||||||
|
- driver: nvidia
|
||||||
|
count: "all"
|
||||||
|
capabilities: [gpu]
|
||||||
|
restart: unless-stopped
|
||||||
43
examples/README.md
Normal file
43
examples/README.md
Normal file
@@ -0,0 +1,43 @@
|
|||||||
|
We provide diverse examples about fine-tuning LLMs.
|
||||||
|
|
||||||
|
```
|
||||||
|
examples/
|
||||||
|
├── lora_single_gpu/
|
||||||
|
│ ├── pretrain.sh: Do pre-training
|
||||||
|
│ ├── sft.sh: Do supervised fine-tuning
|
||||||
|
│ ├── reward.sh: Do reward modeling
|
||||||
|
│ ├── ppo.sh: Do PPO training
|
||||||
|
│ ├── dpo.sh: Do DPO training
|
||||||
|
│ ├── orpo.sh: Do ORPO training
|
||||||
|
│ ├── prepare.sh: Save tokenized dataset
|
||||||
|
│ └── predict.sh: Do batch predict
|
||||||
|
├── qlora_single_gpu/
|
||||||
|
│ ├── bitsandbytes.sh: Fine-tune 4/8-bit BNB models
|
||||||
|
│ ├── gptq.sh: Fine-tune 4/8-bit GPTQ models
|
||||||
|
│ ├── awq.sh: Fine-tune 4-bit AWQ models
|
||||||
|
│ └── aqlm.sh: Fine-tune 2-bit AQLM models
|
||||||
|
├── lora_multi_gpu/
|
||||||
|
│ ├── single_node.sh: Fine-tune model with Accelerate on single node
|
||||||
|
│ └── multi_node.sh: Fine-tune model with Accelerate on multiple nodes
|
||||||
|
├── full_multi_gpu/
|
||||||
|
│ ├── single_node.sh: Fine-tune model with DeepSpeed on single node
|
||||||
|
│ └── multi_node.sh: Fine-tune model with DeepSpeed on multiple nodes
|
||||||
|
├── merge_lora/
|
||||||
|
│ ├── merge.sh: Merge LoRA weights into the pre-trained models
|
||||||
|
│ └── quantize.sh: Quantize fine-tuned model with AutoGPTQ
|
||||||
|
├── inference/
|
||||||
|
│ ├── cli_demo.sh: Launch a command line interface
|
||||||
|
│ ├── api_demo.sh: Launch an OpenAI-style API
|
||||||
|
│ ├── web_demo.sh: Launch a web interface
|
||||||
|
│ └── evaluate.sh: Evaluate model on the MMLU benchmark
|
||||||
|
└── extras/
|
||||||
|
├── galore/
|
||||||
|
│ └── sft.sh: Fine-tune model with GaLore
|
||||||
|
├── loraplus/
|
||||||
|
│ └── sft.sh: Fine-tune model with LoRA+
|
||||||
|
├── llama_pro/
|
||||||
|
│ ├── expand.sh: Expand layers in the model
|
||||||
|
│ └── sft.sh: Fine-tune expanded model
|
||||||
|
└── fsdp_qlora/
|
||||||
|
└── sft.sh: Fine-tune quantized model with FSDP
|
||||||
|
```
|
||||||
43
examples/README_zh.md
Normal file
43
examples/README_zh.md
Normal file
@@ -0,0 +1,43 @@
|
|||||||
|
我们提供了多样化的示例脚本。
|
||||||
|
|
||||||
|
```
|
||||||
|
examples/
|
||||||
|
├── lora_single_gpu/
|
||||||
|
│ ├── pretrain.sh: 进行预训练
|
||||||
|
│ ├── sft.sh: 进行指令监督微调
|
||||||
|
│ ├── reward.sh: 进行奖励模型训练
|
||||||
|
│ ├── ppo.sh: 进行 PPO 训练
|
||||||
|
│ ├── dpo.sh: 进行 DPO 训练
|
||||||
|
│ ├── orpo.sh: 进行 ORPO 训练
|
||||||
|
│ ├── prepare.sh: 保存预处理后的数据集
|
||||||
|
│ └── predict.sh: 进行批量预测
|
||||||
|
├── qlora_single_gpu/
|
||||||
|
│ ├── bitsandbytes.sh: 微调 4/8 比特 BNB 模型
|
||||||
|
│ ├── gptq.sh: 微调 4/8 比特 GPTQ 模型
|
||||||
|
│ ├── awq.sh: 微调 4 比特 AWQ 模型
|
||||||
|
│ └── aqlm.sh: 微调 2 比特 AQLM 模型
|
||||||
|
├── lora_multi_gpu/
|
||||||
|
│ ├── single_node.sh: 使用 Accelerate 进行单节点训练
|
||||||
|
│ └── multi_node.sh: 使用 Accelerate 进行多节点训练
|
||||||
|
├── full_multi_gpu/
|
||||||
|
│ ├── single_node.sh: 使用 DeepSpeed 进行单节点训练
|
||||||
|
│ └── multi_node.sh: 使用 DeepSpeed 进行多节点训练
|
||||||
|
├── merge_lora/
|
||||||
|
│ ├── merge.sh: 将 LoRA 权重合并到预训练模型中
|
||||||
|
│ └── quantize.sh: 使用 AutoGPTQ 量化模型
|
||||||
|
├── inference/
|
||||||
|
│ ├── cli_demo.sh: 启动命令行推理接口
|
||||||
|
│ ├── api_demo.sh: 启动 OpenAI 风格 API
|
||||||
|
│ ├── web_demo.sh: 启动浏览器推理接口
|
||||||
|
│ └── evaluate.sh: 在 MMLU 数据集上评测模型
|
||||||
|
└── extras/
|
||||||
|
├── galore/
|
||||||
|
│ └── sft.sh: 使用 GaLore 训练模型
|
||||||
|
├── loraplus/
|
||||||
|
│ └── sft.sh: 使用 LoRA+ 训练模型
|
||||||
|
├── llama_pro/
|
||||||
|
│ ├── expand.sh: 扩展模型中的层
|
||||||
|
│ └── sft.sh: 训练扩展后的模型
|
||||||
|
└── fsdp_qlora/
|
||||||
|
└── sft.sh: 使用 FSDP 微调量化模型
|
||||||
|
```
|
||||||
25
examples/accelerate/fsdp_config.yaml
Normal file
25
examples/accelerate/fsdp_config.yaml
Normal file
@@ -0,0 +1,25 @@
|
|||||||
|
compute_environment: LOCAL_MACHINE
|
||||||
|
debug: false
|
||||||
|
distributed_type: FSDP
|
||||||
|
downcast_bf16: 'no'
|
||||||
|
fsdp_config:
|
||||||
|
fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
|
||||||
|
fsdp_backward_prefetch: BACKWARD_PRE
|
||||||
|
fsdp_cpu_ram_efficient_loading: true
|
||||||
|
fsdp_forward_prefetch: false
|
||||||
|
fsdp_offload_params: true
|
||||||
|
fsdp_sharding_strategy: FULL_SHARD
|
||||||
|
fsdp_state_dict_type: FULL_STATE_DICT
|
||||||
|
fsdp_sync_module_states: true
|
||||||
|
fsdp_use_orig_params: false
|
||||||
|
machine_rank: 0
|
||||||
|
main_training_function: main
|
||||||
|
mixed_precision: fp16
|
||||||
|
num_machines: 1 # the number of nodes
|
||||||
|
num_processes: 2 # the number of GPUs in all nodes
|
||||||
|
rdzv_backend: static
|
||||||
|
same_network: true
|
||||||
|
tpu_env: []
|
||||||
|
tpu_use_cluster: false
|
||||||
|
tpu_use_sudo: false
|
||||||
|
use_cpu: false
|
||||||
18
examples/accelerate/master_config.yaml
Normal file
18
examples/accelerate/master_config.yaml
Normal file
@@ -0,0 +1,18 @@
|
|||||||
|
compute_environment: LOCAL_MACHINE
|
||||||
|
debug: false
|
||||||
|
distributed_type: MULTI_GPU
|
||||||
|
downcast_bf16: 'no'
|
||||||
|
gpu_ids: all
|
||||||
|
machine_rank: 0
|
||||||
|
main_process_ip: 192.168.0.1
|
||||||
|
main_process_port: 29555
|
||||||
|
main_training_function: main
|
||||||
|
mixed_precision: fp16
|
||||||
|
num_machines: 2 # the number of nodes
|
||||||
|
num_processes: 16 # the number of GPUs in all nodes
|
||||||
|
rdzv_backend: static
|
||||||
|
same_network: true
|
||||||
|
tpu_env: []
|
||||||
|
tpu_use_cluster: false
|
||||||
|
tpu_use_sudo: false
|
||||||
|
use_cpu: false
|
||||||
16
examples/accelerate/single_config.yaml
Normal file
16
examples/accelerate/single_config.yaml
Normal file
@@ -0,0 +1,16 @@
|
|||||||
|
compute_environment: LOCAL_MACHINE
|
||||||
|
debug: false
|
||||||
|
distributed_type: MULTI_GPU
|
||||||
|
downcast_bf16: 'no'
|
||||||
|
gpu_ids: all
|
||||||
|
machine_rank: 0
|
||||||
|
main_training_function: main
|
||||||
|
mixed_precision: fp16
|
||||||
|
num_machines: 1 # the number of nodes
|
||||||
|
num_processes: 4 # the number of GPUs in all nodes
|
||||||
|
rdzv_backend: static
|
||||||
|
same_network: true
|
||||||
|
tpu_env: []
|
||||||
|
tpu_use_cluster: false
|
||||||
|
tpu_use_sudo: false
|
||||||
|
use_cpu: false
|
||||||
18
examples/accelerate/slave_config.yaml
Normal file
18
examples/accelerate/slave_config.yaml
Normal file
@@ -0,0 +1,18 @@
|
|||||||
|
compute_environment: LOCAL_MACHINE
|
||||||
|
debug: false
|
||||||
|
distributed_type: MULTI_GPU
|
||||||
|
downcast_bf16: 'no'
|
||||||
|
gpu_ids: all
|
||||||
|
machine_rank: 1
|
||||||
|
main_process_ip: 192.168.0.1
|
||||||
|
main_process_port: 29555
|
||||||
|
main_training_function: main
|
||||||
|
mixed_precision: fp16
|
||||||
|
num_machines: 2 # the number of nodes
|
||||||
|
num_processes: 16 # the number of GPUs in all nodes
|
||||||
|
rdzv_backend: static
|
||||||
|
same_network: true
|
||||||
|
tpu_env: []
|
||||||
|
tpu_use_cluster: false
|
||||||
|
tpu_use_sudo: false
|
||||||
|
use_cpu: false
|
||||||
40
examples/extras/fsdp_qlora/sft.sh
Normal file
40
examples/extras/fsdp_qlora/sft.sh
Normal file
@@ -0,0 +1,40 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
|
||||||
|
pip install "transformers>=4.39.1"
|
||||||
|
pip install "accelerate>=0.28.0"
|
||||||
|
pip install "bitsandbytes>=0.43.0"
|
||||||
|
|
||||||
|
CUDA_VISIBLE_DEVICES=0,1 accelerate launch \
|
||||||
|
--config_file ../../accelerate/fsdp_config.yaml \
|
||||||
|
../../../src/train_bash.py \
|
||||||
|
--stage sft \
|
||||||
|
--do_train \
|
||||||
|
--model_name_or_path meta-llama/Llama-2-70b-hf \
|
||||||
|
--dataset alpaca_gpt4_en,glaive_toolcall \
|
||||||
|
--dataset_dir ../../../data \
|
||||||
|
--template default \
|
||||||
|
--finetuning_type lora \
|
||||||
|
--lora_target q_proj,v_proj \
|
||||||
|
--output_dir ../../../saves/LLaMA2-70B/lora/sft \
|
||||||
|
--overwrite_cache \
|
||||||
|
--overwrite_output_dir \
|
||||||
|
--cutoff_len 1024 \
|
||||||
|
--preprocessing_num_workers 16 \
|
||||||
|
--per_device_train_batch_size 1 \
|
||||||
|
--per_device_eval_batch_size 1 \
|
||||||
|
--gradient_accumulation_steps 4 \
|
||||||
|
--lr_scheduler_type cosine \
|
||||||
|
--logging_steps 10 \
|
||||||
|
--warmup_steps 20 \
|
||||||
|
--save_steps 100 \
|
||||||
|
--eval_steps 100 \
|
||||||
|
--evaluation_strategy steps \
|
||||||
|
--load_best_model_at_end \
|
||||||
|
--learning_rate 5e-5 \
|
||||||
|
--num_train_epochs 3.0 \
|
||||||
|
--max_samples 3000 \
|
||||||
|
--val_size 0.1 \
|
||||||
|
--ddp_timeout 180000000 \
|
||||||
|
--quantization_bit 4 \
|
||||||
|
--plot_loss \
|
||||||
|
--fp16
|
||||||
35
examples/extras/galore/sft.sh
Normal file
35
examples/extras/galore/sft.sh
Normal file
@@ -0,0 +1,35 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
|
||||||
|
CUDA_VISIBLE_DEVICES=0 python ../../../src/train_bash.py \
|
||||||
|
--stage sft \
|
||||||
|
--do_train \
|
||||||
|
--model_name_or_path meta-llama/Llama-2-7b-hf \
|
||||||
|
--dataset alpaca_gpt4_en,glaive_toolcall \
|
||||||
|
--dataset_dir ../../../data \
|
||||||
|
--template default \
|
||||||
|
--finetuning_type full \
|
||||||
|
--use_galore \
|
||||||
|
--galore_layerwise \
|
||||||
|
--galore_target mlp,self_attn \
|
||||||
|
--galore_rank 128 \
|
||||||
|
--output_dir ../../../saves/LLaMA2-7B/galore/sft \
|
||||||
|
--overwrite_cache \
|
||||||
|
--overwrite_output_dir \
|
||||||
|
--cutoff_len 1024 \
|
||||||
|
--preprocessing_num_workers 16 \
|
||||||
|
--per_device_train_batch_size 1 \
|
||||||
|
--per_device_eval_batch_size 1 \
|
||||||
|
--gradient_accumulation_steps 1 \
|
||||||
|
--lr_scheduler_type cosine \
|
||||||
|
--logging_steps 10 \
|
||||||
|
--warmup_steps 20 \
|
||||||
|
--save_steps 100 \
|
||||||
|
--eval_steps 100 \
|
||||||
|
--evaluation_strategy steps \
|
||||||
|
--load_best_model_at_end \
|
||||||
|
--learning_rate 5e-5 \
|
||||||
|
--num_train_epochs 3.0 \
|
||||||
|
--max_samples 3000 \
|
||||||
|
--val_size 0.1 \
|
||||||
|
--plot_loss \
|
||||||
|
--pure_bf16
|
||||||
6
examples/extras/llama_pro/expand.sh
Normal file
6
examples/extras/llama_pro/expand.sh
Normal file
@@ -0,0 +1,6 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
|
||||||
|
python ../../../scripts/llama_pro.py \
|
||||||
|
--model_name_or_path meta-llama/Llama-2-7b-hf \
|
||||||
|
--output_dir ../../../models/llama2-7b-pro \
|
||||||
|
--num_expand 8
|
||||||
34
examples/extras/llama_pro/sft.sh
Normal file
34
examples/extras/llama_pro/sft.sh
Normal file
@@ -0,0 +1,34 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
|
||||||
|
CUDA_VISIBLE_DEVICES=0 python ../../../src/train_bash.py \
|
||||||
|
--stage sft \
|
||||||
|
--do_train \
|
||||||
|
--model_name_or_path ../../../models/llama2-7b-pro \
|
||||||
|
--dataset alpaca_gpt4_en,glaive_toolcall \
|
||||||
|
--dataset_dir ../../../data \
|
||||||
|
--template default \
|
||||||
|
--finetuning_type freeze \
|
||||||
|
--name_module_trainable all \
|
||||||
|
--num_layer_trainable 8 \
|
||||||
|
--use_llama_pro \
|
||||||
|
--output_dir ../../../saves/LLaMA2-7B-Pro/lora/sft \
|
||||||
|
--overwrite_cache \
|
||||||
|
--overwrite_output_dir \
|
||||||
|
--cutoff_len 1024 \
|
||||||
|
--preprocessing_num_workers 16 \
|
||||||
|
--per_device_train_batch_size 1 \
|
||||||
|
--per_device_eval_batch_size 1 \
|
||||||
|
--gradient_accumulation_steps 8 \
|
||||||
|
--lr_scheduler_type cosine \
|
||||||
|
--logging_steps 10 \
|
||||||
|
--warmup_steps 20 \
|
||||||
|
--save_steps 100 \
|
||||||
|
--eval_steps 100 \
|
||||||
|
--evaluation_strategy steps \
|
||||||
|
--load_best_model_at_end \
|
||||||
|
--learning_rate 5e-5 \
|
||||||
|
--num_train_epochs 3.0 \
|
||||||
|
--max_samples 3000 \
|
||||||
|
--val_size 0.1 \
|
||||||
|
--plot_loss \
|
||||||
|
--fp16
|
||||||
33
examples/extras/loraplus/sft.sh
Normal file
33
examples/extras/loraplus/sft.sh
Normal file
@@ -0,0 +1,33 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
|
||||||
|
CUDA_VISIBLE_DEVICES=0 python ../../src/train_bash.py \
|
||||||
|
--stage sft \
|
||||||
|
--do_train \
|
||||||
|
--model_name_or_path meta-llama/Llama-2-7b-hf \
|
||||||
|
--dataset alpaca_gpt4_en,glaive_toolcall \
|
||||||
|
--dataset_dir ../../data \
|
||||||
|
--template default \
|
||||||
|
--finetuning_type lora \
|
||||||
|
--lora_target q_proj,v_proj \
|
||||||
|
--output_dir ../../saves/LLaMA2-7B/loraplus/sft \
|
||||||
|
--overwrite_cache \
|
||||||
|
--overwrite_output_dir \
|
||||||
|
--cutoff_len 1024 \
|
||||||
|
--preprocessing_num_workers 16 \
|
||||||
|
--per_device_train_batch_size 1 \
|
||||||
|
--per_device_eval_batch_size 1 \
|
||||||
|
--gradient_accumulation_steps 8 \
|
||||||
|
--lr_scheduler_type cosine \
|
||||||
|
--logging_steps 10 \
|
||||||
|
--warmup_steps 20 \
|
||||||
|
--save_steps 100 \
|
||||||
|
--eval_steps 100 \
|
||||||
|
--evaluation_strategy steps \
|
||||||
|
--load_best_model_at_end \
|
||||||
|
--learning_rate 5e-5 \
|
||||||
|
--num_train_epochs 3.0 \
|
||||||
|
--max_samples 3000 \
|
||||||
|
--val_size 0.1 \
|
||||||
|
--plot_loss \
|
||||||
|
--fp16 \
|
||||||
|
--loraplus_lr_ratio 16.0
|
||||||
38
examples/full_multi_gpu/multi_node.sh
Normal file
38
examples/full_multi_gpu/multi_node.sh
Normal file
@@ -0,0 +1,38 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
|
||||||
|
python -m torch.distributed.run \
|
||||||
|
--nproc_per_node $NPROC_PER_NODE \
|
||||||
|
--nnodes $NNODES \
|
||||||
|
--node_rank $RANK \
|
||||||
|
--master_addr $MASTER_ADDR \
|
||||||
|
--master_port $MASTER_PORT \
|
||||||
|
../../src/train_bash.py \
|
||||||
|
--deepspeed ../deepspeed/ds_z3_config.json \
|
||||||
|
--stage sft \
|
||||||
|
--do_train \
|
||||||
|
--model_name_or_path meta-llama/Llama-2-7b-hf \
|
||||||
|
--dataset alpaca_gpt4_en,glaive_toolcall \
|
||||||
|
--dataset_dir ../../data \
|
||||||
|
--template default \
|
||||||
|
--finetuning_type full \
|
||||||
|
--output_dir ../../saves/LLaMA2-7B/full/sft \
|
||||||
|
--overwrite_cache \
|
||||||
|
--overwrite_output_dir \
|
||||||
|
--cutoff_len 1024 \
|
||||||
|
--preprocessing_num_workers 16 \
|
||||||
|
--per_device_train_batch_size 1 \
|
||||||
|
--per_device_eval_batch_size 1 \
|
||||||
|
--gradient_accumulation_steps 2 \
|
||||||
|
--lr_scheduler_type cosine \
|
||||||
|
--logging_steps 10 \
|
||||||
|
--warmup_steps 20 \
|
||||||
|
--save_steps 100 \
|
||||||
|
--eval_steps 100 \
|
||||||
|
--evaluation_strategy steps \
|
||||||
|
--learning_rate 5e-5 \
|
||||||
|
--num_train_epochs 3.0 \
|
||||||
|
--max_samples 3000 \
|
||||||
|
--val_size 0.1 \
|
||||||
|
--ddp_timeout 180000000 \
|
||||||
|
--plot_loss \
|
||||||
|
--fp16
|
||||||
32
examples/full_multi_gpu/single_node.sh
Normal file
32
examples/full_multi_gpu/single_node.sh
Normal file
@@ -0,0 +1,32 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
|
||||||
|
deepspeed --num_gpus 4 ../../src/train_bash.py \
|
||||||
|
--deepspeed ../deepspeed/ds_z3_config.json \
|
||||||
|
--stage sft \
|
||||||
|
--do_train \
|
||||||
|
--model_name_or_path meta-llama/Llama-2-7b-hf \
|
||||||
|
--dataset alpaca_gpt4_en,glaive_toolcall \
|
||||||
|
--dataset_dir ../../data \
|
||||||
|
--template default \
|
||||||
|
--finetuning_type full \
|
||||||
|
--output_dir ../../saves/LLaMA2-7B/full/sft \
|
||||||
|
--overwrite_cache \
|
||||||
|
--overwrite_output_dir \
|
||||||
|
--cutoff_len 1024 \
|
||||||
|
--preprocessing_num_workers 16 \
|
||||||
|
--per_device_train_batch_size 1 \
|
||||||
|
--per_device_eval_batch_size 1 \
|
||||||
|
--gradient_accumulation_steps 2 \
|
||||||
|
--lr_scheduler_type cosine \
|
||||||
|
--logging_steps 10 \
|
||||||
|
--warmup_steps 20 \
|
||||||
|
--save_steps 100 \
|
||||||
|
--eval_steps 100 \
|
||||||
|
--evaluation_strategy steps \
|
||||||
|
--learning_rate 5e-5 \
|
||||||
|
--num_train_epochs 3.0 \
|
||||||
|
--max_samples 3000 \
|
||||||
|
--val_size 0.1 \
|
||||||
|
--ddp_timeout 180000000 \
|
||||||
|
--plot_loss \
|
||||||
|
--fp16
|
||||||
7
examples/inference/api_demo.sh
Normal file
7
examples/inference/api_demo.sh
Normal file
@@ -0,0 +1,7 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
|
||||||
|
CUDA_VISIBLE_DEVICES=0 API_PORT=8000 python ../../src/api_demo.py \
|
||||||
|
--model_name_or_path meta-llama/Llama-2-7b-hf \
|
||||||
|
--adapter_name_or_path ../../saves/LLaMA2-7B/lora/sft \
|
||||||
|
--template default \
|
||||||
|
--finetuning_type lora
|
||||||
7
examples/inference/cli_demo.sh
Normal file
7
examples/inference/cli_demo.sh
Normal file
@@ -0,0 +1,7 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
|
||||||
|
CUDA_VISIBLE_DEVICES=0 python ../../src/cli_demo.py \
|
||||||
|
--model_name_or_path meta-llama/Llama-2-7b-hf \
|
||||||
|
--adapter_name_or_path ../../saves/LLaMA2-7B/lora/sft \
|
||||||
|
--template default \
|
||||||
|
--finetuning_type lora
|
||||||
12
examples/inference/evaluate.sh
Normal file
12
examples/inference/evaluate.sh
Normal file
@@ -0,0 +1,12 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
|
||||||
|
CUDA_VISIBLE_DEVICES=0 python ../../src/evaluate.py \
|
||||||
|
--model_name_or_path meta-llama/Llama-2-7b-hf \
|
||||||
|
--adapter_name_or_path ../../saves/LLaMA2-7B/lora/sft \
|
||||||
|
--template vanilla \
|
||||||
|
--finetuning_type lora \
|
||||||
|
--task mmlu \
|
||||||
|
--split test \
|
||||||
|
--lang en \
|
||||||
|
--n_shot 5 \
|
||||||
|
--batch_size 4
|
||||||
7
examples/inference/web_demo.sh
Normal file
7
examples/inference/web_demo.sh
Normal file
@@ -0,0 +1,7 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
|
||||||
|
CUDA_VISIBLE_DEVICES=0 python ../../src/web_demo.py \
|
||||||
|
--model_name_or_path meta-llama/Llama-2-7b-hf \
|
||||||
|
--adapter_name_or_path ../../saves/LLaMA2-7B/lora/sft \
|
||||||
|
--template default \
|
||||||
|
--finetuning_type lora
|
||||||
35
examples/lora_multi_gpu/multi_node.sh
Normal file
35
examples/lora_multi_gpu/multi_node.sh
Normal file
@@ -0,0 +1,35 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
|
||||||
|
CUDA_VISIBLE_DEVICES=0,1,2,3 accelerate launch \
|
||||||
|
--config_file ../accelerate/master_config.yaml \
|
||||||
|
../../src/train_bash.py \
|
||||||
|
--stage sft \
|
||||||
|
--do_train \
|
||||||
|
--model_name_or_path meta-llama/Llama-2-7b-hf \
|
||||||
|
--dataset alpaca_gpt4_en,glaive_toolcall \
|
||||||
|
--dataset_dir ../../data \
|
||||||
|
--template default \
|
||||||
|
--finetuning_type lora \
|
||||||
|
--lora_target q_proj,v_proj \
|
||||||
|
--output_dir ../../saves/LLaMA2-7B/lora/sft \
|
||||||
|
--overwrite_cache \
|
||||||
|
--overwrite_output_dir \
|
||||||
|
--cutoff_len 1024 \
|
||||||
|
--preprocessing_num_workers 16 \
|
||||||
|
--per_device_train_batch_size 1 \
|
||||||
|
--per_device_eval_batch_size 1 \
|
||||||
|
--gradient_accumulation_steps 2 \
|
||||||
|
--lr_scheduler_type cosine \
|
||||||
|
--logging_steps 10 \
|
||||||
|
--warmup_steps 20 \
|
||||||
|
--save_steps 100 \
|
||||||
|
--eval_steps 100 \
|
||||||
|
--evaluation_strategy steps \
|
||||||
|
--load_best_model_at_end \
|
||||||
|
--learning_rate 5e-5 \
|
||||||
|
--num_train_epochs 3.0 \
|
||||||
|
--max_samples 3000 \
|
||||||
|
--val_size 0.1 \
|
||||||
|
--ddp_timeout 180000000 \
|
||||||
|
--plot_loss \
|
||||||
|
--fp16
|
||||||
35
examples/lora_multi_gpu/single_node.sh
Normal file
35
examples/lora_multi_gpu/single_node.sh
Normal file
@@ -0,0 +1,35 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
|
||||||
|
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 accelerate launch \
|
||||||
|
--config_file ../accelerate/single_config.yaml \
|
||||||
|
../../src/train_bash.py \
|
||||||
|
--stage sft \
|
||||||
|
--do_train \
|
||||||
|
--model_name_or_path meta-llama/Llama-2-7b-hf \
|
||||||
|
--dataset alpaca_gpt4_en,glaive_toolcall \
|
||||||
|
--dataset_dir ../../data \
|
||||||
|
--template default \
|
||||||
|
--finetuning_type lora \
|
||||||
|
--lora_target q_proj,v_proj \
|
||||||
|
--output_dir ../../saves/LLaMA2-7B/lora/sft \
|
||||||
|
--overwrite_cache \
|
||||||
|
--overwrite_output_dir \
|
||||||
|
--cutoff_len 1024 \
|
||||||
|
--preprocessing_num_workers 16 \
|
||||||
|
--per_device_train_batch_size 1 \
|
||||||
|
--per_device_eval_batch_size 1 \
|
||||||
|
--gradient_accumulation_steps 2 \
|
||||||
|
--lr_scheduler_type cosine \
|
||||||
|
--logging_steps 10 \
|
||||||
|
--warmup_steps 20 \
|
||||||
|
--save_steps 100 \
|
||||||
|
--eval_steps 100 \
|
||||||
|
--evaluation_strategy steps \
|
||||||
|
--load_best_model_at_end \
|
||||||
|
--learning_rate 5e-5 \
|
||||||
|
--num_train_epochs 3.0 \
|
||||||
|
--max_samples 3000 \
|
||||||
|
--val_size 0.1 \
|
||||||
|
--ddp_timeout 180000000 \
|
||||||
|
--plot_loss \
|
||||||
|
--fp16
|
||||||
35
examples/lora_single_gpu/dpo.sh
Normal file
35
examples/lora_single_gpu/dpo.sh
Normal file
@@ -0,0 +1,35 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
|
||||||
|
CUDA_VISIBLE_DEVICES=0 python ../../src/train_bash.py \
|
||||||
|
--stage dpo \
|
||||||
|
--do_train \
|
||||||
|
--model_name_or_path meta-llama/Llama-2-7b-hf \
|
||||||
|
--adapter_name_or_path ../../saves/LLaMA2-7B/lora/sft \
|
||||||
|
--create_new_adapter \
|
||||||
|
--dataset orca_rlhf \
|
||||||
|
--dataset_dir ../../data \
|
||||||
|
--template default \
|
||||||
|
--finetuning_type lora \
|
||||||
|
--lora_target q_proj,v_proj \
|
||||||
|
--output_dir ../../saves/LLaMA2-7B/lora/dpo \
|
||||||
|
--overwrite_cache \
|
||||||
|
--overwrite_output_dir \
|
||||||
|
--cutoff_len 1024 \
|
||||||
|
--preprocessing_num_workers 16 \
|
||||||
|
--per_device_train_batch_size 1 \
|
||||||
|
--per_device_eval_batch_size 1 \
|
||||||
|
--gradient_accumulation_steps 8 \
|
||||||
|
--lr_scheduler_type cosine \
|
||||||
|
--logging_steps 10 \
|
||||||
|
--warmup_steps 20 \
|
||||||
|
--save_steps 100 \
|
||||||
|
--eval_steps 100 \
|
||||||
|
--evaluation_strategy steps \
|
||||||
|
--load_best_model_at_end \
|
||||||
|
--learning_rate 1e-5 \
|
||||||
|
--num_train_epochs 1.0 \
|
||||||
|
--max_samples 1000 \
|
||||||
|
--val_size 0.1 \
|
||||||
|
--dpo_ftx 1.0 \
|
||||||
|
--plot_loss \
|
||||||
|
--fp16
|
||||||
32
examples/lora_single_gpu/orpo.sh
Normal file
32
examples/lora_single_gpu/orpo.sh
Normal file
@@ -0,0 +1,32 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
|
||||||
|
CUDA_VISIBLE_DEVICES=0 python ../../src/train_bash.py \
|
||||||
|
--stage orpo \
|
||||||
|
--do_train \
|
||||||
|
--model_name_or_path meta-llama/Llama-2-7b-hf \
|
||||||
|
--dataset orca_rlhf \
|
||||||
|
--dataset_dir ../../data \
|
||||||
|
--template default \
|
||||||
|
--finetuning_type lora \
|
||||||
|
--lora_target q_proj,v_proj \
|
||||||
|
--output_dir ../../saves/LLaMA2-7B/lora/orpo \
|
||||||
|
--overwrite_cache \
|
||||||
|
--overwrite_output_dir \
|
||||||
|
--cutoff_len 1024 \
|
||||||
|
--preprocessing_num_workers 16 \
|
||||||
|
--per_device_train_batch_size 1 \
|
||||||
|
--per_device_eval_batch_size 1 \
|
||||||
|
--gradient_accumulation_steps 8 \
|
||||||
|
--lr_scheduler_type cosine \
|
||||||
|
--logging_steps 10 \
|
||||||
|
--warmup_steps 20 \
|
||||||
|
--save_steps 100 \
|
||||||
|
--eval_steps 100 \
|
||||||
|
--evaluation_strategy steps \
|
||||||
|
--load_best_model_at_end \
|
||||||
|
--learning_rate 1e-5 \
|
||||||
|
--num_train_epochs 1.0 \
|
||||||
|
--max_samples 1000 \
|
||||||
|
--val_size 0.1 \
|
||||||
|
--plot_loss \
|
||||||
|
--fp16
|
||||||
32
examples/lora_single_gpu/ppo.sh
Normal file
32
examples/lora_single_gpu/ppo.sh
Normal file
@@ -0,0 +1,32 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
|
||||||
|
CUDA_VISIBLE_DEVICES=0 python ../../src/train_bash.py \
|
||||||
|
--stage ppo \
|
||||||
|
--do_train \
|
||||||
|
--model_name_or_path meta-llama/Llama-2-7b-hf \
|
||||||
|
--adapter_name_or_path ../../saves/LLaMA2-7B/lora/sft \
|
||||||
|
--create_new_adapter \
|
||||||
|
--dataset alpaca_gpt4_en \
|
||||||
|
--dataset_dir ../../data \
|
||||||
|
--template default \
|
||||||
|
--finetuning_type lora \
|
||||||
|
--lora_target q_proj,v_proj \
|
||||||
|
--reward_model ../../saves/LLaMA2-7B/lora/reward \
|
||||||
|
--output_dir ../../saves/LLaMA2-7B/lora/ppo \
|
||||||
|
--overwrite_cache \
|
||||||
|
--overwrite_output_dir \
|
||||||
|
--cutoff_len 512 \
|
||||||
|
--preprocessing_num_workers 16 \
|
||||||
|
--per_device_train_batch_size 1 \
|
||||||
|
--gradient_accumulation_steps 8 \
|
||||||
|
--lr_scheduler_type cosine \
|
||||||
|
--logging_steps 10 \
|
||||||
|
--save_steps 100 \
|
||||||
|
--learning_rate 1e-5 \
|
||||||
|
--num_train_epochs 1.0 \
|
||||||
|
--max_samples 1000 \
|
||||||
|
--top_k 0 \
|
||||||
|
--top_p 0.9 \
|
||||||
|
--max_new_tokens 256 \
|
||||||
|
--plot_loss \
|
||||||
|
--fp16
|
||||||
19
examples/lora_single_gpu/predict.sh
Normal file
19
examples/lora_single_gpu/predict.sh
Normal file
@@ -0,0 +1,19 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
|
||||||
|
CUDA_VISIBLE_DEVICES=0 python ../../src/train_bash.py \
|
||||||
|
--stage sft \
|
||||||
|
--do_predict \
|
||||||
|
--model_name_or_path meta-llama/Llama-2-7b-hf \
|
||||||
|
--adapter_name_or_path ../../saves/LLaMA2-7B/lora/sft,../../saves/LLaMA2-7B/lora/dpo \
|
||||||
|
--dataset alpaca_gpt4_en,glaive_toolcall \
|
||||||
|
--dataset_dir ../../data \
|
||||||
|
--template default \
|
||||||
|
--finetuning_type lora \
|
||||||
|
--output_dir ../../saves/LLaMA2-7B/lora/predict \
|
||||||
|
--overwrite_cache \
|
||||||
|
--overwrite_output_dir \
|
||||||
|
--cutoff_len 1024 \
|
||||||
|
--preprocessing_num_workers 16 \
|
||||||
|
--per_device_eval_batch_size 1 \
|
||||||
|
--max_samples 20 \
|
||||||
|
--predict_with_generate
|
||||||
18
examples/lora_single_gpu/prepare.sh
Normal file
18
examples/lora_single_gpu/prepare.sh
Normal file
@@ -0,0 +1,18 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
|
||||||
|
CUDA_VISIBLE_DEVICES= python ../../src/train_bash.py \
|
||||||
|
--stage sft \
|
||||||
|
--do_train \
|
||||||
|
--model_name_or_path meta-llama/Llama-2-7b-hf \
|
||||||
|
--dataset alpaca_gpt4_en,glaive_toolcall \
|
||||||
|
--dataset_dir ../../data \
|
||||||
|
--template default \
|
||||||
|
--finetuning_type lora \
|
||||||
|
--lora_target q_proj,v_proj \
|
||||||
|
--output_dir ../../saves/LLaMA2-7B/lora/sft \
|
||||||
|
--overwrite_cache \
|
||||||
|
--overwrite_output_dir \
|
||||||
|
--cutoff_len 1024 \
|
||||||
|
--preprocessing_num_workers 16 \
|
||||||
|
--max_samples 3000 \
|
||||||
|
--tokenized_path ../../saves/datasets/sft
|
||||||
31
examples/lora_single_gpu/pretrain.sh
Normal file
31
examples/lora_single_gpu/pretrain.sh
Normal file
@@ -0,0 +1,31 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
|
||||||
|
CUDA_VISIBLE_DEVICES=0 python ../../src/train_bash.py \
|
||||||
|
--stage pt \
|
||||||
|
--do_train \
|
||||||
|
--model_name_or_path meta-llama/Llama-2-7b-hf \
|
||||||
|
--dataset c4_demo \
|
||||||
|
--dataset_dir ../../data \
|
||||||
|
--finetuning_type lora \
|
||||||
|
--lora_target q_proj,v_proj \
|
||||||
|
--output_dir ../../saves/LLaMA2-7B/lora/pretrain \
|
||||||
|
--overwrite_cache \
|
||||||
|
--overwrite_output_dir \
|
||||||
|
--cutoff_len 1024 \
|
||||||
|
--preprocessing_num_workers 16 \
|
||||||
|
--per_device_train_batch_size 1 \
|
||||||
|
--per_device_eval_batch_size 1 \
|
||||||
|
--gradient_accumulation_steps 8 \
|
||||||
|
--lr_scheduler_type cosine \
|
||||||
|
--logging_steps 10 \
|
||||||
|
--warmup_steps 20 \
|
||||||
|
--save_steps 100 \
|
||||||
|
--eval_steps 100 \
|
||||||
|
--evaluation_strategy steps \
|
||||||
|
--load_best_model_at_end \
|
||||||
|
--learning_rate 5e-5 \
|
||||||
|
--num_train_epochs 3.0 \
|
||||||
|
--max_samples 10000 \
|
||||||
|
--val_size 0.1 \
|
||||||
|
--plot_loss \
|
||||||
|
--fp16
|
||||||
33
examples/lora_single_gpu/reward.sh
Normal file
33
examples/lora_single_gpu/reward.sh
Normal file
@@ -0,0 +1,33 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
|
||||||
|
CUDA_VISIBLE_DEVICES=0 python ../../src/train_bash.py \
|
||||||
|
--stage rm \
|
||||||
|
--do_train \
|
||||||
|
--model_name_or_path meta-llama/Llama-2-7b-hf \
|
||||||
|
--adapter_name_or_path ../../saves/LLaMA2-7B/lora/sft \
|
||||||
|
--create_new_adapter \
|
||||||
|
--dataset orca_rlhf \
|
||||||
|
--dataset_dir ../../data \
|
||||||
|
--template default \
|
||||||
|
--finetuning_type lora \
|
||||||
|
--lora_target q_proj,v_proj \
|
||||||
|
--output_dir ../../saves/LLaMA2-7B/lora/reward \
|
||||||
|
--overwrite_cache \
|
||||||
|
--overwrite_output_dir \
|
||||||
|
--cutoff_len 1024 \
|
||||||
|
--preprocessing_num_workers 16 \
|
||||||
|
--per_device_train_batch_size 1 \
|
||||||
|
--per_device_eval_batch_size 1 \
|
||||||
|
--gradient_accumulation_steps 8 \
|
||||||
|
--lr_scheduler_type cosine \
|
||||||
|
--logging_steps 10 \
|
||||||
|
--warmup_steps 20 \
|
||||||
|
--save_steps 100 \
|
||||||
|
--eval_steps 100 \
|
||||||
|
--evaluation_strategy steps \
|
||||||
|
--learning_rate 1e-5 \
|
||||||
|
--num_train_epochs 1.0 \
|
||||||
|
--max_samples 5000 \
|
||||||
|
--val_size 0.1 \
|
||||||
|
--plot_loss \
|
||||||
|
--fp16
|
||||||
32
examples/lora_single_gpu/sft.sh
Normal file
32
examples/lora_single_gpu/sft.sh
Normal file
@@ -0,0 +1,32 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
|
||||||
|
CUDA_VISIBLE_DEVICES=0 python ../../src/train_bash.py \
|
||||||
|
--stage sft \
|
||||||
|
--do_train \
|
||||||
|
--model_name_or_path meta-llama/Llama-2-7b-hf \
|
||||||
|
--dataset alpaca_gpt4_en,glaive_toolcall \
|
||||||
|
--dataset_dir ../../data \
|
||||||
|
--template default \
|
||||||
|
--finetuning_type lora \
|
||||||
|
--lora_target q_proj,v_proj \
|
||||||
|
--output_dir ../../saves/LLaMA2-7B/lora/sft \
|
||||||
|
--overwrite_cache \
|
||||||
|
--overwrite_output_dir \
|
||||||
|
--cutoff_len 1024 \
|
||||||
|
--preprocessing_num_workers 16 \
|
||||||
|
--per_device_train_batch_size 1 \
|
||||||
|
--per_device_eval_batch_size 1 \
|
||||||
|
--gradient_accumulation_steps 8 \
|
||||||
|
--lr_scheduler_type cosine \
|
||||||
|
--logging_steps 10 \
|
||||||
|
--warmup_steps 20 \
|
||||||
|
--save_steps 100 \
|
||||||
|
--eval_steps 100 \
|
||||||
|
--evaluation_strategy steps \
|
||||||
|
--load_best_model_at_end \
|
||||||
|
--learning_rate 5e-5 \
|
||||||
|
--num_train_epochs 3.0 \
|
||||||
|
--max_samples 3000 \
|
||||||
|
--val_size 0.1 \
|
||||||
|
--plot_loss \
|
||||||
|
--fp16
|
||||||
11
examples/merge_lora/merge.sh
Normal file
11
examples/merge_lora/merge.sh
Normal file
@@ -0,0 +1,11 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
# DO NOT use quantized model or quantization_bit when merging lora weights
|
||||||
|
|
||||||
|
CUDA_VISIBLE_DEVICES= python ../../src/export_model.py \
|
||||||
|
--model_name_or_path meta-llama/Llama-2-7b-hf \
|
||||||
|
--adapter_name_or_path ../../saves/LLaMA2-7B/lora/sft \
|
||||||
|
--template default \
|
||||||
|
--finetuning_type lora \
|
||||||
|
--export_dir ../../models/llama2-7b-sft \
|
||||||
|
--export_size 2 \
|
||||||
|
--export_legacy_format False
|
||||||
10
examples/merge_lora/quantize.sh
Normal file
10
examples/merge_lora/quantize.sh
Normal file
@@ -0,0 +1,10 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
|
||||||
|
CUDA_VISIBLE_DEVICES=0 python ../../src/export_model.py \
|
||||||
|
--model_name_or_path ../../models/llama2-7b-sft \
|
||||||
|
--template default \
|
||||||
|
--export_dir ../../models/llama2-7b-sft-int4 \
|
||||||
|
--export_quantization_bit 4 \
|
||||||
|
--export_quantization_dataset ../../data/c4_demo.json \
|
||||||
|
--export_size 2 \
|
||||||
|
--export_legacy_format False
|
||||||
30
examples/qlora_single_gpu/aqlm.sh
Normal file
30
examples/qlora_single_gpu/aqlm.sh
Normal file
@@ -0,0 +1,30 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
|
||||||
|
CUDA_VISIBLE_DEVICES=0 python ../../src/train_bash.py \
|
||||||
|
--stage sft \
|
||||||
|
--do_train \
|
||||||
|
--model_name_or_path BlackSamorez/Llama-2-7b-AQLM-2Bit-1x16-hf \
|
||||||
|
--dataset alpaca_gpt4_en,glaive_toolcall \
|
||||||
|
--dataset_dir ../../data \
|
||||||
|
--template default \
|
||||||
|
--finetuning_type lora \
|
||||||
|
--lora_target q_proj,v_proj \
|
||||||
|
--output_dir ../../saves/LLaMA2-7B/lora/sft \
|
||||||
|
--overwrite_cache \
|
||||||
|
--overwrite_output_dir \
|
||||||
|
--cutoff_len 1024 \
|
||||||
|
--per_device_train_batch_size 1 \
|
||||||
|
--per_device_eval_batch_size 1 \
|
||||||
|
--gradient_accumulation_steps 8 \
|
||||||
|
--lr_scheduler_type cosine \
|
||||||
|
--logging_steps 10 \
|
||||||
|
--save_steps 100 \
|
||||||
|
--eval_steps 100 \
|
||||||
|
--evaluation_strategy steps \
|
||||||
|
--load_best_model_at_end \
|
||||||
|
--learning_rate 5e-5 \
|
||||||
|
--num_train_epochs 3.0 \
|
||||||
|
--max_samples 3000 \
|
||||||
|
--val_size 0.1 \
|
||||||
|
--plot_loss \
|
||||||
|
--fp16
|
||||||
30
examples/qlora_single_gpu/awq.sh
Normal file
30
examples/qlora_single_gpu/awq.sh
Normal file
@@ -0,0 +1,30 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
|
||||||
|
CUDA_VISIBLE_DEVICES=0 python ../../src/train_bash.py \
|
||||||
|
--stage sft \
|
||||||
|
--do_train \
|
||||||
|
--model_name_or_path TheBloke/Llama-2-7B-AWQ \
|
||||||
|
--dataset alpaca_gpt4_en,glaive_toolcall \
|
||||||
|
--dataset_dir ../../data \
|
||||||
|
--template default \
|
||||||
|
--finetuning_type lora \
|
||||||
|
--lora_target q_proj,v_proj \
|
||||||
|
--output_dir ../../saves/LLaMA2-7B/lora/sft \
|
||||||
|
--overwrite_cache \
|
||||||
|
--overwrite_output_dir \
|
||||||
|
--cutoff_len 1024 \
|
||||||
|
--per_device_train_batch_size 1 \
|
||||||
|
--per_device_eval_batch_size 1 \
|
||||||
|
--gradient_accumulation_steps 8 \
|
||||||
|
--lr_scheduler_type cosine \
|
||||||
|
--logging_steps 10 \
|
||||||
|
--save_steps 100 \
|
||||||
|
--eval_steps 100 \
|
||||||
|
--evaluation_strategy steps \
|
||||||
|
--load_best_model_at_end \
|
||||||
|
--learning_rate 5e-5 \
|
||||||
|
--num_train_epochs 3.0 \
|
||||||
|
--max_samples 3000 \
|
||||||
|
--val_size 0.1 \
|
||||||
|
--plot_loss \
|
||||||
|
--fp16
|
||||||
31
examples/qlora_single_gpu/bitsandbytes.sh
Normal file
31
examples/qlora_single_gpu/bitsandbytes.sh
Normal file
@@ -0,0 +1,31 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
|
||||||
|
CUDA_VISIBLE_DEVICES=0 python ../../src/train_bash.py \
|
||||||
|
--stage sft \
|
||||||
|
--do_train \
|
||||||
|
--model_name_or_path meta-llama/Llama-2-7b-hf \
|
||||||
|
--dataset alpaca_gpt4_en,glaive_toolcall \
|
||||||
|
--dataset_dir ../../data \
|
||||||
|
--template default \
|
||||||
|
--finetuning_type lora \
|
||||||
|
--lora_target q_proj,v_proj \
|
||||||
|
--output_dir ../../saves/LLaMA2-7B/lora/sft \
|
||||||
|
--overwrite_cache \
|
||||||
|
--overwrite_output_dir \
|
||||||
|
--cutoff_len 1024 \
|
||||||
|
--per_device_train_batch_size 1 \
|
||||||
|
--per_device_eval_batch_size 1 \
|
||||||
|
--gradient_accumulation_steps 8 \
|
||||||
|
--lr_scheduler_type cosine \
|
||||||
|
--logging_steps 10 \
|
||||||
|
--save_steps 100 \
|
||||||
|
--eval_steps 100 \
|
||||||
|
--evaluation_strategy steps \
|
||||||
|
--load_best_model_at_end \
|
||||||
|
--learning_rate 5e-5 \
|
||||||
|
--num_train_epochs 3.0 \
|
||||||
|
--max_samples 3000 \
|
||||||
|
--val_size 0.1 \
|
||||||
|
--quantization_bit 4 \
|
||||||
|
--plot_loss \
|
||||||
|
--fp16
|
||||||
30
examples/qlora_single_gpu/gptq.sh
Normal file
30
examples/qlora_single_gpu/gptq.sh
Normal file
@@ -0,0 +1,30 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
|
||||||
|
CUDA_VISIBLE_DEVICES=0 python ../../src/train_bash.py \
|
||||||
|
--stage sft \
|
||||||
|
--do_train \
|
||||||
|
--model_name_or_path TheBloke/Llama-2-7B-GPTQ \
|
||||||
|
--dataset alpaca_gpt4_en,glaive_toolcall \
|
||||||
|
--dataset_dir ../../data \
|
||||||
|
--template default \
|
||||||
|
--finetuning_type lora \
|
||||||
|
--lora_target q_proj,v_proj \
|
||||||
|
--output_dir ../../saves/LLaMA2-7B/lora/sft \
|
||||||
|
--overwrite_cache \
|
||||||
|
--overwrite_output_dir \
|
||||||
|
--cutoff_len 1024 \
|
||||||
|
--per_device_train_batch_size 1 \
|
||||||
|
--per_device_eval_batch_size 1 \
|
||||||
|
--gradient_accumulation_steps 8 \
|
||||||
|
--lr_scheduler_type cosine \
|
||||||
|
--logging_steps 10 \
|
||||||
|
--save_steps 100 \
|
||||||
|
--eval_steps 100 \
|
||||||
|
--evaluation_strategy steps \
|
||||||
|
--load_best_model_at_end \
|
||||||
|
--learning_rate 5e-5 \
|
||||||
|
--num_train_epochs 3.0 \
|
||||||
|
--max_samples 3000 \
|
||||||
|
--val_size 0.1 \
|
||||||
|
--plot_loss \
|
||||||
|
--fp16
|
||||||
@@ -1,3 +1,33 @@
|
|||||||
[build-system]
|
[build-system]
|
||||||
requires = ["setuptools>=61.0"]
|
requires = ["setuptools>=61.0"]
|
||||||
build-backend = "setuptools.build_meta"
|
build-backend = "setuptools.build_meta"
|
||||||
|
|
||||||
|
[tool.ruff]
|
||||||
|
target-version = "py38"
|
||||||
|
line-length = 119
|
||||||
|
indent-width = 4
|
||||||
|
|
||||||
|
[tool.ruff.lint]
|
||||||
|
ignore = ["C408", "C901", "E501", "E731", "E741", "W605"]
|
||||||
|
select = ["C", "E", "F", "I", "W"]
|
||||||
|
|
||||||
|
[tool.ruff.lint.isort]
|
||||||
|
lines-after-imports = 2
|
||||||
|
known-first-party = ["llmtuner"]
|
||||||
|
known-third-party = [
|
||||||
|
"accelerate",
|
||||||
|
"datasets",
|
||||||
|
"gradio",
|
||||||
|
"numpy",
|
||||||
|
"peft",
|
||||||
|
"torch",
|
||||||
|
"transformers",
|
||||||
|
"trl"
|
||||||
|
]
|
||||||
|
|
||||||
|
[tool.ruff.format]
|
||||||
|
quote-style = "double"
|
||||||
|
indent-style = "space"
|
||||||
|
docstring-code-format = true
|
||||||
|
skip-magic-trailing-comma = false
|
||||||
|
line-ending = "auto"
|
||||||
|
|||||||
@@ -1,19 +1,17 @@
|
|||||||
torch>=1.13.1
|
torch>=1.13.1
|
||||||
transformers>=4.36.1
|
transformers>=4.37.2
|
||||||
datasets>=2.14.3
|
datasets>=2.14.3
|
||||||
accelerate>=0.21.0
|
accelerate>=0.27.2
|
||||||
peft>=0.7.0
|
peft>=0.10.0
|
||||||
trl==0.7.4
|
trl>=0.8.1
|
||||||
gradio>=3.38.0,<4.0.0
|
gradio>=4.0.0,<=4.21.0
|
||||||
scipy
|
scipy
|
||||||
|
einops
|
||||||
sentencepiece
|
sentencepiece
|
||||||
protobuf
|
protobuf
|
||||||
tiktoken
|
|
||||||
jieba
|
|
||||||
rouge-chinese
|
|
||||||
nltk
|
|
||||||
uvicorn
|
uvicorn
|
||||||
pydantic
|
pydantic
|
||||||
fastapi
|
fastapi
|
||||||
sse-starlette
|
sse-starlette
|
||||||
matplotlib
|
matplotlib
|
||||||
|
fire
|
||||||
|
|||||||
@@ -3,11 +3,12 @@
|
|||||||
# Usage: python cal_flops.py --model_name_or_path path_to_model --batch_size 1 --seq_length 512
|
# Usage: python cal_flops.py --model_name_or_path path_to_model --batch_size 1 --seq_length 512
|
||||||
# Inspired by: https://www.deepspeed.ai/tutorials/flops-profiler/
|
# Inspired by: https://www.deepspeed.ai/tutorials/flops-profiler/
|
||||||
|
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
import fire
|
import fire
|
||||||
import torch
|
import torch
|
||||||
from typing import Optional
|
from deepspeed.accelerator import get_accelerator # type: ignore
|
||||||
from deepspeed.accelerator import get_accelerator # type: ignore
|
from deepspeed.profiling.flops_profiler import get_model_profile # type: ignore
|
||||||
from deepspeed.profiling.flops_profiler import get_model_profile # type: ignore
|
|
||||||
|
|
||||||
from llmtuner import ChatModel
|
from llmtuner import ChatModel
|
||||||
|
|
||||||
@@ -16,25 +17,13 @@ def calculate_flops(
|
|||||||
model_name_or_path: str,
|
model_name_or_path: str,
|
||||||
batch_size: Optional[int] = 1,
|
batch_size: Optional[int] = 1,
|
||||||
seq_length: Optional[int] = 256,
|
seq_length: Optional[int] = 256,
|
||||||
flash_attn: Optional[bool] = False
|
flash_attn: Optional[bool] = False,
|
||||||
):
|
):
|
||||||
with get_accelerator().device(0):
|
with get_accelerator().device(0):
|
||||||
chat_model = ChatModel(dict(
|
chat_model = ChatModel(dict(model_name_or_path=model_name_or_path, template="vanilla", flash_attn=flash_attn))
|
||||||
model_name_or_path=model_name_or_path,
|
|
||||||
template="vanilla",
|
|
||||||
flash_attn=flash_attn
|
|
||||||
))
|
|
||||||
fake_input = torch.ones((batch_size, seq_length), dtype=torch.long, device=chat_model.model.device)
|
fake_input = torch.ones((batch_size, seq_length), dtype=torch.long, device=chat_model.model.device)
|
||||||
input_dict = {
|
input_dict = {"input_ids": fake_input, "labels": fake_input.clone()}
|
||||||
"input_ids": fake_input,
|
flops, macs, params = get_model_profile(chat_model.model, kwargs=input_dict, print_profile=True, detailed=True)
|
||||||
"labels": fake_input.clone()
|
|
||||||
}
|
|
||||||
flops, macs, params = get_model_profile(
|
|
||||||
chat_model.model,
|
|
||||||
kwargs=input_dict,
|
|
||||||
print_profile=True,
|
|
||||||
detailed=True
|
|
||||||
)
|
|
||||||
print("FLOPs:", flops)
|
print("FLOPs:", flops)
|
||||||
print("MACs:", macs)
|
print("MACs:", macs)
|
||||||
print("Params:", params)
|
print("Params:", params)
|
||||||
77
scripts/cal_lr.py
Normal file
77
scripts/cal_lr.py
Normal file
@@ -0,0 +1,77 @@
|
|||||||
|
# coding=utf-8
|
||||||
|
# Calculates the optimal learning rate for 7B/13B models using LLaMA's hyper-parameters.
|
||||||
|
# Usage: python cal_lr.py --model_name_or_path path_to_model --dataset alpaca_en --cutoff_len 1024 --batch_size 16
|
||||||
|
# Inspired by: https://github.com/imoneoi/openchat/blob/master/ochat/training_deepspeed/train.py
|
||||||
|
|
||||||
|
import math
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
import fire
|
||||||
|
import torch
|
||||||
|
from torch.utils.data import DataLoader
|
||||||
|
from tqdm import tqdm
|
||||||
|
from transformers import DataCollatorForLanguageModeling, DataCollatorForSeq2Seq
|
||||||
|
|
||||||
|
from llmtuner.data import get_dataset
|
||||||
|
from llmtuner.extras.constants import IGNORE_INDEX
|
||||||
|
from llmtuner.hparams import get_train_args
|
||||||
|
from llmtuner.model import load_tokenizer
|
||||||
|
|
||||||
|
|
||||||
|
BASE_LR = 3e-4 # 1.5e-4 for 30B-70B models
|
||||||
|
BASE_BS = 4_000_000 # from llama paper
|
||||||
|
|
||||||
|
|
||||||
|
def calculate_lr(
|
||||||
|
model_name_or_path: str,
|
||||||
|
batch_size: int, # total batch size, namely (batch size * gradient accumulation * world size)
|
||||||
|
stage: Optional[str] = "sft",
|
||||||
|
dataset: Optional[str] = "alpaca_en",
|
||||||
|
dataset_dir: Optional[str] = "data",
|
||||||
|
template: Optional[str] = "default",
|
||||||
|
cutoff_len: Optional[int] = 1024, # i.e. maximum input length during training
|
||||||
|
is_mistral: Optional[bool] = False, # mistral model uses a smaller learning rate,
|
||||||
|
):
|
||||||
|
model_args, data_args, training_args, _, _ = get_train_args(
|
||||||
|
dict(
|
||||||
|
stage=stage,
|
||||||
|
model_name_or_path=model_name_or_path,
|
||||||
|
dataset=dataset,
|
||||||
|
dataset_dir=dataset_dir,
|
||||||
|
template=template,
|
||||||
|
cutoff_len=cutoff_len,
|
||||||
|
output_dir="dummy_dir",
|
||||||
|
overwrite_cache=True,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
tokenizer = load_tokenizer(model_args)
|
||||||
|
trainset = get_dataset(tokenizer, model_args, data_args, training_args, stage)
|
||||||
|
if stage == "pt":
|
||||||
|
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
|
||||||
|
elif stage == "sft":
|
||||||
|
data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, label_pad_token_id=IGNORE_INDEX)
|
||||||
|
else:
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
dataloader = DataLoader(
|
||||||
|
dataset=trainset, batch_size=batch_size, shuffle=True, collate_fn=data_collator, pin_memory=True
|
||||||
|
)
|
||||||
|
valid_tokens, total_tokens = 0, 0
|
||||||
|
for batch in tqdm(dataloader):
|
||||||
|
valid_tokens += torch.sum(batch["labels"] != IGNORE_INDEX).item()
|
||||||
|
total_tokens += torch.numel(batch["labels"])
|
||||||
|
|
||||||
|
batch_max_len = cutoff_len * batch_size # max tokens in a batch
|
||||||
|
valid_ratio = valid_tokens / total_tokens
|
||||||
|
batch_valid_len = batch_max_len * valid_ratio
|
||||||
|
lr = BASE_LR * math.sqrt(batch_valid_len / BASE_BS) # lr ~ sqrt(batch_size)
|
||||||
|
lr = lr / 6.0 if is_mistral else lr
|
||||||
|
print(
|
||||||
|
"Optimal learning rate is {:.2e} for valid ratio% {:.2f} and effective batch size {:.2f}".format(
|
||||||
|
lr, valid_ratio * 100, batch_valid_len
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
fire.Fire(calculate_lr)
|
||||||
52
scripts/length_cdf.py
Normal file
52
scripts/length_cdf.py
Normal file
@@ -0,0 +1,52 @@
|
|||||||
|
# coding=utf-8
|
||||||
|
# Calculates the distribution of the input lengths in the dataset.
|
||||||
|
# Usage: python length_cdf.py --model_name_or_path path_to_model --dataset alpaca_en --template default
|
||||||
|
|
||||||
|
from collections import defaultdict
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
import fire
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
from llmtuner.data import get_dataset
|
||||||
|
from llmtuner.hparams import get_train_args
|
||||||
|
from llmtuner.model import load_tokenizer
|
||||||
|
|
||||||
|
|
||||||
|
def length_cdf(
|
||||||
|
model_name_or_path: str,
|
||||||
|
dataset: Optional[str] = "alpaca_en",
|
||||||
|
dataset_dir: Optional[str] = "data",
|
||||||
|
template: Optional[str] = "default",
|
||||||
|
interval: Optional[int] = 1000,
|
||||||
|
):
|
||||||
|
model_args, data_args, training_args, _, _ = get_train_args(
|
||||||
|
dict(
|
||||||
|
stage="sft",
|
||||||
|
model_name_or_path=model_name_or_path,
|
||||||
|
dataset=dataset,
|
||||||
|
dataset_dir=dataset_dir,
|
||||||
|
template=template,
|
||||||
|
cutoff_len=1_000_000,
|
||||||
|
output_dir="dummy_dir",
|
||||||
|
overwrite_cache=True,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
tokenizer = load_tokenizer(model_args)
|
||||||
|
trainset = get_dataset(tokenizer, model_args, data_args, training_args, stage="sft")
|
||||||
|
total_num = len(trainset)
|
||||||
|
length_dict = defaultdict(int)
|
||||||
|
for sample in tqdm(trainset["input_ids"]):
|
||||||
|
length_dict[len(sample) // interval * interval] += 1
|
||||||
|
|
||||||
|
length_tuples = list(length_dict.items())
|
||||||
|
length_tuples.sort()
|
||||||
|
count_accu, prob_accu = 0, 0
|
||||||
|
for length, count in length_tuples:
|
||||||
|
count_accu += count
|
||||||
|
prob_accu += count / total_num * 100
|
||||||
|
print("{:d} ({:.2f}%) samples have length < {}.".format(count_accu, prob_accu, length + interval))
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
fire.Fire(length_cdf)
|
||||||
115
scripts/llama_pro.py
Normal file
115
scripts/llama_pro.py
Normal file
@@ -0,0 +1,115 @@
|
|||||||
|
# coding=utf-8
|
||||||
|
# Performs block expansion for LLaMA, Mistral or Qwen1.5 models.
|
||||||
|
# Usage: python llama_pro.py --model_name_or_path meta-llama/Llama-2-7b-hf --output_dir llama2_pro --num_expand 8
|
||||||
|
# Inspired by: https://github.com/TencentARC/LLaMA-Pro/blob/main/scripts/block_expansion.py
|
||||||
|
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
from collections import OrderedDict
|
||||||
|
from typing import TYPE_CHECKING, Optional
|
||||||
|
|
||||||
|
import fire
|
||||||
|
import torch
|
||||||
|
from safetensors.torch import save_file
|
||||||
|
from tqdm import tqdm
|
||||||
|
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
|
||||||
|
from transformers.modeling_utils import (
|
||||||
|
SAFE_WEIGHTS_INDEX_NAME,
|
||||||
|
SAFE_WEIGHTS_NAME,
|
||||||
|
WEIGHTS_INDEX_NAME,
|
||||||
|
WEIGHTS_NAME,
|
||||||
|
shard_checkpoint,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from transformers import PretrainedConfig, PreTrainedModel
|
||||||
|
|
||||||
|
|
||||||
|
def change_name(name: str, old_index: int, new_index: int) -> str:
|
||||||
|
return name.replace(".{:d}.".format(old_index), ".{:d}.".format(new_index))
|
||||||
|
|
||||||
|
|
||||||
|
def block_expansion(
|
||||||
|
model_name_or_path: str,
|
||||||
|
output_dir: str,
|
||||||
|
num_expand: int,
|
||||||
|
shard_size: Optional[str] = "2GB",
|
||||||
|
save_safetensors: Optional[bool] = False,
|
||||||
|
):
|
||||||
|
config: "PretrainedConfig" = AutoConfig.from_pretrained(model_name_or_path)
|
||||||
|
num_layers = getattr(config, "num_hidden_layers")
|
||||||
|
setattr(config, "num_hidden_layers", num_layers + num_expand)
|
||||||
|
config.save_pretrained(output_dir)
|
||||||
|
|
||||||
|
tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
|
||||||
|
tokenizer.save_pretrained(output_dir)
|
||||||
|
|
||||||
|
config: "PretrainedConfig" = AutoConfig.from_pretrained(model_name_or_path) # load the original one
|
||||||
|
if save_safetensors:
|
||||||
|
setattr(config, "tie_word_embeddings", False) # safetensors does not allow shared weights
|
||||||
|
|
||||||
|
model: "PreTrainedModel" = AutoModelForCausalLM.from_pretrained(
|
||||||
|
model_name_or_path,
|
||||||
|
config=config,
|
||||||
|
torch_dtype="auto",
|
||||||
|
trust_remote_code=True,
|
||||||
|
low_cpu_mem_usage=True,
|
||||||
|
)
|
||||||
|
state_dict = model.state_dict()
|
||||||
|
|
||||||
|
if num_layers % num_expand != 0:
|
||||||
|
raise ValueError("`num_layers` {} should be divisible by `num_expand` {}.".format(num_layers, num_expand))
|
||||||
|
|
||||||
|
split = num_layers // num_expand
|
||||||
|
layer_cnt = 0
|
||||||
|
output_state_dict = OrderedDict()
|
||||||
|
for i in range(num_layers):
|
||||||
|
for key, value in state_dict.items():
|
||||||
|
if ".{:d}.".format(i) in key:
|
||||||
|
output_state_dict[change_name(key, i, layer_cnt)] = value
|
||||||
|
|
||||||
|
print("Add layer {} copied from layer {}".format(layer_cnt, i))
|
||||||
|
layer_cnt += 1
|
||||||
|
if (i + 1) % split == 0:
|
||||||
|
for key, value in state_dict.items():
|
||||||
|
if ".{:d}.".format(i) in key:
|
||||||
|
if "down_proj" in key or "o_proj" in key:
|
||||||
|
output_state_dict[change_name(key, i, layer_cnt)] = torch.zeros_like(value)
|
||||||
|
else:
|
||||||
|
output_state_dict[change_name(key, i, layer_cnt)] = torch.clone(value)
|
||||||
|
|
||||||
|
print("Add layer {} expanded from layer {}".format(layer_cnt, i))
|
||||||
|
layer_cnt += 1
|
||||||
|
|
||||||
|
for key, value in state_dict.items():
|
||||||
|
if key not in output_state_dict:
|
||||||
|
output_state_dict[key] = value
|
||||||
|
|
||||||
|
weights_name = SAFE_WEIGHTS_NAME if save_safetensors else WEIGHTS_NAME
|
||||||
|
shards, index = shard_checkpoint(output_state_dict, max_shard_size=shard_size, weights_name=weights_name)
|
||||||
|
|
||||||
|
for shard_file, shard in tqdm(shards.items(), desc="Save weights"):
|
||||||
|
if save_safetensors:
|
||||||
|
save_file(shard, os.path.join(output_dir, shard_file), metadata={"format": "pt"})
|
||||||
|
else:
|
||||||
|
torch.save(shard, os.path.join(output_dir, shard_file))
|
||||||
|
|
||||||
|
if index is None:
|
||||||
|
print("Model weights saved in {}".format(os.path.join(output_dir, weights_name)))
|
||||||
|
else:
|
||||||
|
index_name = SAFE_WEIGHTS_INDEX_NAME if save_safetensors else WEIGHTS_INDEX_NAME
|
||||||
|
with open(os.path.join(output_dir, index_name), "w", encoding="utf-8") as f:
|
||||||
|
json.dump(index, f, indent=2, sort_keys=True)
|
||||||
|
print("Model weights saved in {}".format(output_dir))
|
||||||
|
|
||||||
|
print("Fine-tune this model with:")
|
||||||
|
print(" --model_name_or_path {} \\".format(output_dir))
|
||||||
|
print(" --finetuning_type freeze \\")
|
||||||
|
print(" --name_module_trainable all \\")
|
||||||
|
print(" --num_layer_trainable {} \\".format(num_expand))
|
||||||
|
print(" --use_llama_pro")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
fire.Fire(block_expansion)
|
||||||
@@ -1,60 +1,68 @@
|
|||||||
# coding=utf-8
|
# coding=utf-8
|
||||||
# Converts the Baichuan2-7B model in the same format as LLaMA2-7B.
|
# Converts the Baichuan2-7B model in the same format as LLaMA2-7B.
|
||||||
# Usage: python llamafy_baichuan2.py --input_dir input --output_dir output --shard_size 10GB
|
# Usage: python llamafy_baichuan2.py --input_dir input --output_dir output
|
||||||
# Inspired by: https://huggingface.co/fireballoon/baichuan-llama-7b/blob/main/convert_baichuan_to_llama.py
|
# Inspired by: https://huggingface.co/fireballoon/baichuan-llama-7b/blob/main/convert_baichuan_to_llama.py
|
||||||
# Converted model: https://huggingface.co/hiyouga/Baichuan2-7B-Base-LLaMAfied
|
# Converted model: https://huggingface.co/hiyouga/Baichuan2-7B-Base-LLaMAfied
|
||||||
|
|
||||||
import os
|
|
||||||
import fire
|
|
||||||
import json
|
import json
|
||||||
import torch
|
import os
|
||||||
from collections import OrderedDict
|
from collections import OrderedDict
|
||||||
from transformers.modeling_utils import shard_checkpoint, WEIGHTS_NAME, WEIGHTS_INDEX_NAME
|
from typing import Any, Dict, Optional
|
||||||
from typing import Any, Dict
|
|
||||||
|
import fire
|
||||||
|
import torch
|
||||||
|
from safetensors.torch import save_file
|
||||||
|
from tqdm import tqdm
|
||||||
|
from transformers.modeling_utils import (
|
||||||
|
SAFE_WEIGHTS_INDEX_NAME,
|
||||||
|
SAFE_WEIGHTS_NAME,
|
||||||
|
WEIGHTS_INDEX_NAME,
|
||||||
|
WEIGHTS_NAME,
|
||||||
|
shard_checkpoint,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
CONFIG_NAME = "config.json"
|
CONFIG_NAME = "config.json"
|
||||||
|
|
||||||
|
|
||||||
def save_weight(
|
def save_weight(input_dir: str, output_dir: str, shard_size: str, save_safetensors: bool):
|
||||||
input_dir: str,
|
|
||||||
output_dir: str,
|
|
||||||
shard_size: str
|
|
||||||
):
|
|
||||||
baichuan2_state_dict: Dict[str, torch.Tensor] = OrderedDict()
|
baichuan2_state_dict: Dict[str, torch.Tensor] = OrderedDict()
|
||||||
for filepath in os.listdir(input_dir):
|
for filepath in tqdm(os.listdir(input_dir), desc="Load weights"):
|
||||||
if os.path.isfile(os.path.join(input_dir, filepath)) and filepath.endswith(".bin"):
|
if os.path.isfile(os.path.join(input_dir, filepath)) and filepath.endswith(".bin"):
|
||||||
shard_weight = torch.load(os.path.join(input_dir, filepath), map_location="cpu")
|
shard_weight = torch.load(os.path.join(input_dir, filepath), map_location="cpu")
|
||||||
baichuan2_state_dict.update(shard_weight)
|
baichuan2_state_dict.update(shard_weight)
|
||||||
|
|
||||||
llama2_state_dict: Dict[str, torch.Tensor] = OrderedDict()
|
llama2_state_dict: Dict[str, torch.Tensor] = OrderedDict()
|
||||||
for key, value in baichuan2_state_dict.items():
|
for key, value in tqdm(baichuan2_state_dict.items(), desc="Convert format"):
|
||||||
if "W_pack" in key:
|
if "W_pack" in key:
|
||||||
proj_size = value.size(0) // 3
|
proj_size = value.size(0) // 3
|
||||||
llama2_state_dict[key.replace("W_pack", "q_proj")] = value[:proj_size, :]
|
llama2_state_dict[key.replace("W_pack", "q_proj")] = value[:proj_size, :]
|
||||||
llama2_state_dict[key.replace("W_pack", "k_proj")] = value[proj_size:2*proj_size, :]
|
llama2_state_dict[key.replace("W_pack", "k_proj")] = value[proj_size : 2 * proj_size, :]
|
||||||
llama2_state_dict[key.replace("W_pack", "v_proj")] = value[2*proj_size:, :]
|
llama2_state_dict[key.replace("W_pack", "v_proj")] = value[2 * proj_size :, :]
|
||||||
elif "lm_head" in key:
|
elif "lm_head" in key:
|
||||||
llama2_state_dict[key] = torch.nn.functional.normalize(value)
|
llama2_state_dict[key] = torch.nn.functional.normalize(value)
|
||||||
else:
|
else:
|
||||||
llama2_state_dict[key] = value
|
llama2_state_dict[key] = value
|
||||||
|
|
||||||
shards, index = shard_checkpoint(llama2_state_dict, max_shard_size=shard_size, weights_name=WEIGHTS_NAME)
|
weights_name = SAFE_WEIGHTS_NAME if save_safetensors else WEIGHTS_NAME
|
||||||
for shard_file, shard in shards.items():
|
shards, index = shard_checkpoint(llama2_state_dict, max_shard_size=shard_size, weights_name=weights_name)
|
||||||
torch.save(shard, os.path.join(output_dir, shard_file))
|
|
||||||
|
for shard_file, shard in tqdm(shards.items(), desc="Save weights"):
|
||||||
|
if save_safetensors:
|
||||||
|
save_file(shard, os.path.join(output_dir, shard_file), metadata={"format": "pt"})
|
||||||
|
else:
|
||||||
|
torch.save(shard, os.path.join(output_dir, shard_file))
|
||||||
|
|
||||||
if index is None:
|
if index is None:
|
||||||
print("Model weights saved in {}".format(os.path.join(output_dir, WEIGHTS_NAME)))
|
print("Model weights saved in {}".format(os.path.join(output_dir, WEIGHTS_NAME)))
|
||||||
else:
|
else:
|
||||||
with open(os.path.join(output_dir, WEIGHTS_INDEX_NAME), "w", encoding="utf-8") as f:
|
index_name = SAFE_WEIGHTS_INDEX_NAME if save_safetensors else WEIGHTS_INDEX_NAME
|
||||||
|
with open(os.path.join(output_dir, index_name), "w", encoding="utf-8") as f:
|
||||||
json.dump(index, f, indent=2, sort_keys=True)
|
json.dump(index, f, indent=2, sort_keys=True)
|
||||||
print("Model weights saved in {}".format(output_dir))
|
print("Model weights saved in {}".format(output_dir))
|
||||||
|
|
||||||
|
|
||||||
def save_config(
|
def save_config(input_dir: str, output_dir: str):
|
||||||
input_dir: str,
|
|
||||||
output_dir: str
|
|
||||||
):
|
|
||||||
with open(os.path.join(input_dir, CONFIG_NAME), "r", encoding="utf-8") as f:
|
with open(os.path.join(input_dir, CONFIG_NAME), "r", encoding="utf-8") as f:
|
||||||
llama2_config_dict: Dict[str, Any] = json.load(f)
|
llama2_config_dict: Dict[str, Any] = json.load(f)
|
||||||
|
|
||||||
@@ -69,17 +77,15 @@ def save_config(
|
|||||||
|
|
||||||
|
|
||||||
def llamafy_baichuan2(
|
def llamafy_baichuan2(
|
||||||
input_dir: str,
|
input_dir: str, output_dir: str, shard_size: Optional[str] = "2GB", save_safetensors: Optional[bool] = False
|
||||||
output_dir: str,
|
|
||||||
shard_size: str
|
|
||||||
):
|
):
|
||||||
try:
|
try:
|
||||||
os.makedirs(output_dir, exist_ok=False)
|
os.makedirs(output_dir, exist_ok=False)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise print("Output dir already exists", e)
|
raise print("Output dir already exists", e)
|
||||||
|
|
||||||
save_weight(input_dir, output_dir, shard_size)
|
save_weight(input_dir, output_dir, shard_size, save_safetensors)
|
||||||
save_config(input_dir, output_dir)
|
save_config(input_dir, output_dir)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
@@ -1,33 +1,40 @@
|
|||||||
# coding=utf-8
|
# coding=utf-8
|
||||||
# Converts the Qwen models in the same format as LLaMA2.
|
# Converts the Qwen models in the same format as LLaMA2.
|
||||||
# Usage: python llamafy_qwen.py --input_dir input --output_dir output --shard_size 10GB
|
# Usage: python llamafy_qwen.py --input_dir input --output_dir output
|
||||||
|
# Converted model: https://huggingface.co/hiyouga/Qwen-14B-Chat-LLaMAfied
|
||||||
|
|
||||||
import os
|
|
||||||
import fire
|
|
||||||
import json
|
import json
|
||||||
import torch
|
import os
|
||||||
from collections import OrderedDict
|
from collections import OrderedDict
|
||||||
|
from typing import Any, Dict, Optional
|
||||||
|
|
||||||
|
import fire
|
||||||
|
import torch
|
||||||
from safetensors import safe_open
|
from safetensors import safe_open
|
||||||
from transformers.modeling_utils import shard_checkpoint, WEIGHTS_NAME, WEIGHTS_INDEX_NAME
|
from safetensors.torch import save_file
|
||||||
|
from tqdm import tqdm
|
||||||
|
from transformers.modeling_utils import (
|
||||||
|
SAFE_WEIGHTS_INDEX_NAME,
|
||||||
|
SAFE_WEIGHTS_NAME,
|
||||||
|
WEIGHTS_INDEX_NAME,
|
||||||
|
WEIGHTS_NAME,
|
||||||
|
shard_checkpoint,
|
||||||
|
)
|
||||||
from transformers.utils import check_min_version
|
from transformers.utils import check_min_version
|
||||||
from typing import Any, Dict
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
check_min_version("4.34.0")
|
check_min_version("4.34.0")
|
||||||
except:
|
except Exception:
|
||||||
raise ValueError("Please upgrade `transformers` to 4.34.0")
|
raise ValueError("Please upgrade `transformers` to 4.34.0")
|
||||||
|
|
||||||
|
|
||||||
CONFIG_NAME = "config.json"
|
CONFIG_NAME = "config.json"
|
||||||
|
|
||||||
|
|
||||||
def save_weight(
|
def save_weight(input_dir: str, output_dir: str, shard_size: str, save_safetensors: bool) -> str:
|
||||||
input_dir: str,
|
|
||||||
output_dir: str,
|
|
||||||
shard_size: str
|
|
||||||
) -> str:
|
|
||||||
qwen_state_dict: Dict[str, torch.Tensor] = OrderedDict()
|
qwen_state_dict: Dict[str, torch.Tensor] = OrderedDict()
|
||||||
for filepath in os.listdir(input_dir):
|
for filepath in tqdm(os.listdir(input_dir), desc="Load weights"):
|
||||||
if os.path.isfile(os.path.join(input_dir, filepath)) and filepath.endswith(".safetensors"):
|
if os.path.isfile(os.path.join(input_dir, filepath)) and filepath.endswith(".safetensors"):
|
||||||
with safe_open(os.path.join(input_dir, filepath), framework="pt", device="cpu") as f:
|
with safe_open(os.path.join(input_dir, filepath), framework="pt", device="cpu") as f:
|
||||||
for key in f.keys():
|
for key in f.keys():
|
||||||
@@ -35,7 +42,7 @@ def save_weight(
|
|||||||
|
|
||||||
llama2_state_dict: Dict[str, torch.Tensor] = OrderedDict()
|
llama2_state_dict: Dict[str, torch.Tensor] = OrderedDict()
|
||||||
torch_dtype = None
|
torch_dtype = None
|
||||||
for key, value in qwen_state_dict.items():
|
for key, value in tqdm(qwen_state_dict.items(), desc="Convert format"):
|
||||||
if torch_dtype is None:
|
if torch_dtype is None:
|
||||||
torch_dtype = value.dtype
|
torch_dtype = value.dtype
|
||||||
if "wte" in key:
|
if "wte" in key:
|
||||||
@@ -47,13 +54,15 @@ def save_weight(
|
|||||||
if "attn.c_attn" in key:
|
if "attn.c_attn" in key:
|
||||||
proj_size = value.size(0) // 3
|
proj_size = value.size(0) // 3
|
||||||
llama2_state_dict[key.replace("attn.c_attn", "self_attn.q_proj")] = value[:proj_size, ...]
|
llama2_state_dict[key.replace("attn.c_attn", "self_attn.q_proj")] = value[:proj_size, ...]
|
||||||
llama2_state_dict[key.replace("attn.c_attn", "self_attn.k_proj")] = value[proj_size:2*proj_size, ...]
|
llama2_state_dict[key.replace("attn.c_attn", "self_attn.k_proj")] = value[
|
||||||
llama2_state_dict[key.replace("attn.c_attn", "self_attn.v_proj")] = value[2*proj_size:, ...]
|
proj_size : 2 * proj_size, ...
|
||||||
|
]
|
||||||
|
llama2_state_dict[key.replace("attn.c_attn", "self_attn.v_proj")] = value[2 * proj_size :, ...]
|
||||||
elif "attn.c_proj" in key:
|
elif "attn.c_proj" in key:
|
||||||
llama2_state_dict[key.replace("attn.c_proj", "self_attn.o_proj")] = value
|
llama2_state_dict[key.replace("attn.c_proj", "self_attn.o_proj")] = value
|
||||||
llama2_state_dict[key.replace("attn.c_proj.weight", "self_attn.o_proj.bias")] = (
|
llama2_state_dict[key.replace("attn.c_proj.weight", "self_attn.o_proj.bias")] = torch.zeros_like(
|
||||||
torch.zeros_like(value[:, 0]).squeeze()
|
value[:, 0]
|
||||||
)
|
).squeeze()
|
||||||
elif "ln_1" in key:
|
elif "ln_1" in key:
|
||||||
llama2_state_dict[key.replace("ln_1", "input_layernorm")] = value
|
llama2_state_dict[key.replace("ln_1", "input_layernorm")] = value
|
||||||
elif "ln_2" in key:
|
elif "ln_2" in key:
|
||||||
@@ -69,25 +78,27 @@ def save_weight(
|
|||||||
else:
|
else:
|
||||||
raise KeyError("Unable to process key {}".format(key))
|
raise KeyError("Unable to process key {}".format(key))
|
||||||
|
|
||||||
shards, index = shard_checkpoint(llama2_state_dict, max_shard_size=shard_size, weights_name=WEIGHTS_NAME)
|
weights_name = SAFE_WEIGHTS_NAME if save_safetensors else WEIGHTS_NAME
|
||||||
for shard_file, shard in shards.items():
|
shards, index = shard_checkpoint(llama2_state_dict, max_shard_size=shard_size, weights_name=weights_name)
|
||||||
torch.save(shard, os.path.join(output_dir, shard_file))
|
|
||||||
|
for shard_file, shard in tqdm(shards.items(), desc="Save weights"):
|
||||||
|
if save_safetensors:
|
||||||
|
save_file(shard, os.path.join(output_dir, shard_file), metadata={"format": "pt"})
|
||||||
|
else:
|
||||||
|
torch.save(shard, os.path.join(output_dir, shard_file))
|
||||||
|
|
||||||
if index is None:
|
if index is None:
|
||||||
print("Model weights saved in {}".format(os.path.join(output_dir, WEIGHTS_NAME)))
|
print("Model weights saved in {}".format(os.path.join(output_dir, weights_name)))
|
||||||
else:
|
else:
|
||||||
with open(os.path.join(output_dir, WEIGHTS_INDEX_NAME), "w", encoding="utf-8") as f:
|
index_name = SAFE_WEIGHTS_INDEX_NAME if save_safetensors else WEIGHTS_INDEX_NAME
|
||||||
|
with open(os.path.join(output_dir, index_name), "w", encoding="utf-8") as f:
|
||||||
json.dump(index, f, indent=2, sort_keys=True)
|
json.dump(index, f, indent=2, sort_keys=True)
|
||||||
print("Model weights saved in {}".format(output_dir))
|
print("Model weights saved in {}".format(output_dir))
|
||||||
|
|
||||||
return str(torch_dtype).replace("torch.", "")
|
return str(torch_dtype).replace("torch.", "")
|
||||||
|
|
||||||
|
|
||||||
def save_config(
|
def save_config(input_dir: str, output_dir: str, torch_dtype: str):
|
||||||
input_dir: str,
|
|
||||||
output_dir: str,
|
|
||||||
torch_dtype: str
|
|
||||||
):
|
|
||||||
with open(os.path.join(input_dir, CONFIG_NAME), "r", encoding="utf-8") as f:
|
with open(os.path.join(input_dir, CONFIG_NAME), "r", encoding="utf-8") as f:
|
||||||
qwen_config_dict: Dict[str, Any] = json.load(f)
|
qwen_config_dict: Dict[str, Any] = json.load(f)
|
||||||
|
|
||||||
@@ -118,17 +129,15 @@ def save_config(
|
|||||||
|
|
||||||
|
|
||||||
def llamafy_qwen(
|
def llamafy_qwen(
|
||||||
input_dir: str,
|
input_dir: str, output_dir: str, shard_size: Optional[str] = "2GB", save_safetensors: Optional[bool] = False
|
||||||
output_dir: str,
|
|
||||||
shard_size: str
|
|
||||||
):
|
):
|
||||||
try:
|
try:
|
||||||
os.makedirs(output_dir, exist_ok=False)
|
os.makedirs(output_dir, exist_ok=False)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise print("Output dir already exists", e)
|
raise print("Output dir already exists", e)
|
||||||
|
|
||||||
torch_dtype = save_weight(input_dir, output_dir, shard_size)
|
torch_dtype = save_weight(input_dir, output_dir, shard_size, save_safetensors)
|
||||||
save_config(input_dir, output_dir, torch_dtype)
|
save_config(input_dir, output_dir, torch_dtype)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
@@ -4,16 +4,20 @@
|
|||||||
# Inspired by: https://github.com/huggingface/peft/blob/main/examples/loftq_finetuning/quantize_save_load.py
|
# Inspired by: https://github.com/huggingface/peft/blob/main/examples/loftq_finetuning/quantize_save_load.py
|
||||||
|
|
||||||
import os
|
import os
|
||||||
|
from typing import TYPE_CHECKING, Optional
|
||||||
|
|
||||||
import fire
|
import fire
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from typing import Optional
|
|
||||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
|
||||||
from peft import LoftQConfig, LoraConfig, TaskType, get_peft_model
|
from peft import LoftQConfig, LoraConfig, TaskType, get_peft_model
|
||||||
|
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||||
|
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from transformers import PreTrainedModel
|
||||||
|
|
||||||
|
|
||||||
class Shell(nn.Module):
|
class Shell(nn.Module):
|
||||||
|
|
||||||
def __init__(self, weight: torch.Tensor, bias: Optional[torch.Tensor] = None):
|
def __init__(self, weight: torch.Tensor, bias: Optional[torch.Tensor] = None):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.weight = nn.Parameter(weight, requires_grad=False)
|
self.weight = nn.Parameter(weight, requires_grad=False)
|
||||||
@@ -22,7 +26,7 @@ class Shell(nn.Module):
|
|||||||
|
|
||||||
|
|
||||||
def unwrap_model(model: nn.Module, pattern=".base_layer") -> None:
|
def unwrap_model(model: nn.Module, pattern=".base_layer") -> None:
|
||||||
for name in set([k.split(pattern)[0] for k, _ in model.named_modules() if pattern in k]):
|
for name in {k.split(pattern)[0] for k, _ in model.named_modules() if pattern in k}:
|
||||||
parent_name = ".".join(name.split(".")[:-1])
|
parent_name = ".".join(name.split(".")[:-1])
|
||||||
child_name = name.split(".")[-1]
|
child_name = name.split(".")[-1]
|
||||||
parent_module = model.get_submodule(parent_name)
|
parent_module = model.get_submodule(parent_name)
|
||||||
@@ -31,7 +35,7 @@ def unwrap_model(model: nn.Module, pattern=".base_layer") -> None:
|
|||||||
weight = getattr(base_layer, "weight", None)
|
weight = getattr(base_layer, "weight", None)
|
||||||
bias = getattr(base_layer, "bias", None)
|
bias = getattr(base_layer, "bias", None)
|
||||||
setattr(parent_module, child_name, Shell(weight, bias))
|
setattr(parent_module, child_name, Shell(weight, bias))
|
||||||
|
|
||||||
print("Model unwrapped.")
|
print("Model unwrapped.")
|
||||||
|
|
||||||
|
|
||||||
@@ -42,7 +46,8 @@ def quantize_loftq(
|
|||||||
loftq_iter: Optional[int] = 1,
|
loftq_iter: Optional[int] = 1,
|
||||||
lora_alpha: Optional[int] = None,
|
lora_alpha: Optional[int] = None,
|
||||||
lora_rank: Optional[int] = 16,
|
lora_rank: Optional[int] = 16,
|
||||||
lora_target: Optional[str] = "q_proj,v_proj"
|
lora_target: Optional[str] = "q_proj,v_proj",
|
||||||
|
save_safetensors: Optional[bool] = False,
|
||||||
):
|
):
|
||||||
tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, trust_remote_code=True)
|
tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, trust_remote_code=True)
|
||||||
model = AutoModelForCausalLM.from_pretrained(model_name_or_path, trust_remote_code=True, torch_dtype="auto")
|
model = AutoModelForCausalLM.from_pretrained(model_name_or_path, trust_remote_code=True, torch_dtype="auto")
|
||||||
@@ -55,21 +60,21 @@ def quantize_loftq(
|
|||||||
lora_dropout=0.1,
|
lora_dropout=0.1,
|
||||||
target_modules=[name.strip() for name in lora_target.split(",")],
|
target_modules=[name.strip() for name in lora_target.split(",")],
|
||||||
init_lora_weights="loftq",
|
init_lora_weights="loftq",
|
||||||
loftq_config=loftq_config
|
loftq_config=loftq_config,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Init LoftQ model
|
# Init LoftQ model
|
||||||
lora_model = get_peft_model(model, lora_config)
|
lora_model = get_peft_model(model, lora_config)
|
||||||
base_model = lora_model.get_base_model()
|
base_model: "PreTrainedModel" = lora_model.get_base_model()
|
||||||
|
|
||||||
# Save LoftQ model
|
# Save LoftQ model
|
||||||
setattr(lora_model.base_model.peft_config["default"], "base_model_name_or_path", save_dir)
|
setattr(lora_model.base_model.peft_config["default"], "base_model_name_or_path", save_dir)
|
||||||
setattr(lora_model.base_model.peft_config["default"], "init_lora_weights", True)
|
setattr(lora_model.base_model.peft_config["default"], "init_lora_weights", True)
|
||||||
lora_model.save_pretrained(os.path.join(save_dir, "adapters"))
|
lora_model.save_pretrained(os.path.join(save_dir, "adapters"), safe_serialization=save_safetensors)
|
||||||
|
|
||||||
# Save base model
|
# Save base model
|
||||||
unwrap_model(base_model)
|
unwrap_model(base_model)
|
||||||
base_model.save_pretrained(save_dir)
|
base_model.save_pretrained(save_dir, safe_serialization=save_safetensors)
|
||||||
tokenizer.save_pretrained(save_dir)
|
tokenizer.save_pretrained(save_dir)
|
||||||
|
|
||||||
|
|
||||||
28
setup.py
28
setup.py
@@ -1,13 +1,14 @@
|
|||||||
import os
|
import os
|
||||||
import re
|
import re
|
||||||
from setuptools import setup, find_packages
|
|
||||||
|
from setuptools import find_packages, setup
|
||||||
|
|
||||||
|
|
||||||
def get_version():
|
def get_version():
|
||||||
with open(os.path.join("src", "llmtuner", "__init__.py"), "r", encoding="utf-8") as f:
|
with open(os.path.join("src", "llmtuner", "__init__.py"), "r", encoding="utf-8") as f:
|
||||||
file_content = f.read()
|
file_content = f.read()
|
||||||
pattern = r"{0}\W*=\W*\"([^\"]+)\"".format("__version__")
|
pattern = r"{0}\W*=\W*\"([^\"]+)\"".format("__version__")
|
||||||
version, = re.findall(pattern, file_content)
|
(version,) = re.findall(pattern, file_content)
|
||||||
return version
|
return version
|
||||||
|
|
||||||
|
|
||||||
@@ -18,8 +19,23 @@ def get_requires():
|
|||||||
return lines
|
return lines
|
||||||
|
|
||||||
|
|
||||||
def main():
|
extra_require = {
|
||||||
|
"deepspeed": ["deepspeed>=0.10.0"],
|
||||||
|
"metrics": ["nltk", "jieba", "rouge-chinese"],
|
||||||
|
"unsloth": ["torch==2.2.0", "unsloth[cu121-ampere-torch220]"],
|
||||||
|
"galore": ["galore-torch"],
|
||||||
|
"vllm": ["vllm>=0.3.3"],
|
||||||
|
"bitsandbytes": ["bitsandbytes>=0.39.0"],
|
||||||
|
"gptq": ["optimum>=1.16.0", "auto-gptq>=0.5.0"],
|
||||||
|
"awq": ["autoawq"],
|
||||||
|
"aqlm": ["aqlm[gpu]>=1.1.0"],
|
||||||
|
"qwen": ["tiktoken", "transformers_stream_generator"],
|
||||||
|
"modelscope": ["modelscope"],
|
||||||
|
"quality": ["ruff"],
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
setup(
|
setup(
|
||||||
name="llmtuner",
|
name="llmtuner",
|
||||||
version=get_version(),
|
version=get_version(),
|
||||||
@@ -35,8 +51,9 @@ def main():
|
|||||||
packages=find_packages("src"),
|
packages=find_packages("src"),
|
||||||
python_requires=">=3.8.0",
|
python_requires=">=3.8.0",
|
||||||
install_requires=get_requires(),
|
install_requires=get_requires(),
|
||||||
|
extras_require=extra_require,
|
||||||
classifiers=[
|
classifiers=[
|
||||||
"Development Status :: 3 - Alpha",
|
"Development Status :: 4 - Beta",
|
||||||
"Intended Audience :: Developers",
|
"Intended Audience :: Developers",
|
||||||
"Intended Audience :: Education",
|
"Intended Audience :: Education",
|
||||||
"Intended Audience :: Science/Research",
|
"Intended Audience :: Science/Research",
|
||||||
@@ -46,8 +63,9 @@ def main():
|
|||||||
"Programming Language :: Python :: 3.8",
|
"Programming Language :: Python :: 3.8",
|
||||||
"Programming Language :: Python :: 3.9",
|
"Programming Language :: Python :: 3.9",
|
||||||
"Programming Language :: Python :: 3.10",
|
"Programming Language :: Python :: 3.10",
|
||||||
|
"Programming Language :: Python :: 3.11",
|
||||||
"Topic :: Scientific/Engineering :: Artificial Intelligence",
|
"Topic :: Scientific/Engineering :: Artificial Intelligence",
|
||||||
]
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -1,3 +1,5 @@
|
|||||||
|
import os
|
||||||
|
|
||||||
import uvicorn
|
import uvicorn
|
||||||
|
|
||||||
from llmtuner import ChatModel, create_app
|
from llmtuner import ChatModel, create_app
|
||||||
@@ -6,8 +8,8 @@ from llmtuner import ChatModel, create_app
|
|||||||
def main():
|
def main():
|
||||||
chat_model = ChatModel()
|
chat_model = ChatModel()
|
||||||
app = create_app(chat_model)
|
app = create_app(chat_model)
|
||||||
print("Visit http://localhost:8000/docs for API document.")
|
print("Visit http://localhost:{}/docs for API document.".format(os.environ.get("API_PORT", 8000)))
|
||||||
uvicorn.run(app, host="0.0.0.0", port=8000, workers=1)
|
uvicorn.run(app, host="0.0.0.0", port=int(os.environ.get("API_PORT", 8000)), workers=1)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|||||||
@@ -1,17 +1,19 @@
|
|||||||
from llmtuner import ChatModel
|
from llmtuner import ChatModel
|
||||||
from llmtuner.extras.misc import torch_gc
|
from llmtuner.extras.misc import torch_gc
|
||||||
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
import platform
|
import platform
|
||||||
|
|
||||||
if platform.system() != "Windows":
|
if platform.system() != "Windows":
|
||||||
import readline
|
import readline # noqa: F401
|
||||||
except ImportError:
|
except ImportError:
|
||||||
print("Install `readline` for a better experience.")
|
print("Install `readline` for a better experience.")
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
chat_model = ChatModel()
|
chat_model = ChatModel()
|
||||||
history = []
|
messages = []
|
||||||
print("Welcome to the CLI application, use `clear` to remove the history, use `exit` to exit the application.")
|
print("Welcome to the CLI application, use `clear` to remove the history, use `exit` to exit the application.")
|
||||||
|
|
||||||
while True:
|
while True:
|
||||||
@@ -27,20 +29,20 @@ def main():
|
|||||||
break
|
break
|
||||||
|
|
||||||
if query.strip() == "clear":
|
if query.strip() == "clear":
|
||||||
history = []
|
messages = []
|
||||||
torch_gc()
|
torch_gc()
|
||||||
print("History has been removed.")
|
print("History has been removed.")
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
messages.append({"role": "user", "content": query})
|
||||||
print("Assistant: ", end="", flush=True)
|
print("Assistant: ", end="", flush=True)
|
||||||
|
|
||||||
response = ""
|
response = ""
|
||||||
for new_text in chat_model.stream_chat(query, history):
|
for new_text in chat_model.stream_chat(messages):
|
||||||
print(new_text, end="", flush=True)
|
print(new_text, end="", flush=True)
|
||||||
response += new_text
|
response += new_text
|
||||||
print()
|
print()
|
||||||
|
messages.append({"role": "assistant", "content": response})
|
||||||
history = history + [(query, response)]
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|||||||
@@ -2,8 +2,7 @@ from llmtuner import Evaluator
|
|||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
evaluator = Evaluator()
|
Evaluator().eval()
|
||||||
evaluator.eval()
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|||||||
@@ -1,10 +1,11 @@
|
|||||||
# Level: api, webui > chat, eval, train > data, model > extras, hparams
|
# Level: api, webui > chat, eval, train > data, model > extras, hparams
|
||||||
|
|
||||||
from llmtuner.api import create_app
|
from .api import create_app
|
||||||
from llmtuner.chat import ChatModel
|
from .chat import ChatModel
|
||||||
from llmtuner.eval import Evaluator
|
from .eval import Evaluator
|
||||||
from llmtuner.train import export_model, run_exp
|
from .train import export_model, run_exp
|
||||||
from llmtuner.webui import create_ui, create_web_demo
|
from .webui import create_ui, create_web_demo
|
||||||
|
|
||||||
|
|
||||||
__version__ = "0.4.0"
|
__version__ = "0.6.2"
|
||||||
|
__all__ = ["create_app", "ChatModel", "Evaluator", "export_model", "run_exp", "create_ui", "create_web_demo"]
|
||||||
|
|||||||
@@ -1 +1,4 @@
|
|||||||
from llmtuner.api.app import create_app
|
from .app import create_app
|
||||||
|
|
||||||
|
|
||||||
|
__all__ = ["create_app"]
|
||||||
|
|||||||
@@ -1,28 +1,30 @@
|
|||||||
import json
|
import json
|
||||||
from typing import List, Tuple
|
import os
|
||||||
from pydantic import BaseModel
|
|
||||||
from contextlib import asynccontextmanager
|
from contextlib import asynccontextmanager
|
||||||
|
from typing import Any, Dict, Sequence
|
||||||
|
|
||||||
from llmtuner.api.protocol import (
|
from pydantic import BaseModel
|
||||||
Role,
|
|
||||||
Finish,
|
from ..chat import ChatModel
|
||||||
ModelCard,
|
from ..data import Role as DataRole
|
||||||
ModelList,
|
from ..extras.misc import torch_gc
|
||||||
ChatMessage,
|
from ..extras.packages import is_fastapi_availble, is_starlette_available, is_uvicorn_available
|
||||||
DeltaMessage,
|
from .protocol import (
|
||||||
|
ChatCompletionMessage,
|
||||||
ChatCompletionRequest,
|
ChatCompletionRequest,
|
||||||
ChatCompletionResponse,
|
ChatCompletionResponse,
|
||||||
ChatCompletionStreamResponse,
|
|
||||||
ChatCompletionResponseChoice,
|
ChatCompletionResponseChoice,
|
||||||
ChatCompletionResponseStreamChoice,
|
ChatCompletionResponseStreamChoice,
|
||||||
ChatCompletionResponseUsage,
|
ChatCompletionResponseUsage,
|
||||||
|
ChatCompletionStreamResponse,
|
||||||
|
Finish,
|
||||||
|
Function,
|
||||||
|
FunctionCall,
|
||||||
|
ModelCard,
|
||||||
|
ModelList,
|
||||||
|
Role,
|
||||||
ScoreEvaluationRequest,
|
ScoreEvaluationRequest,
|
||||||
ScoreEvaluationResponse
|
ScoreEvaluationResponse,
|
||||||
)
|
|
||||||
from llmtuner.chat import ChatModel
|
|
||||||
from llmtuner.extras.misc import torch_gc
|
|
||||||
from llmtuner.extras.packages import (
|
|
||||||
is_fastapi_availble, is_starlette_available, is_uvicorn_available
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -40,15 +42,22 @@ if is_uvicorn_available():
|
|||||||
|
|
||||||
|
|
||||||
@asynccontextmanager
|
@asynccontextmanager
|
||||||
async def lifespan(app: "FastAPI"): # collects GPU memory
|
async def lifespan(app: "FastAPI"): # collects GPU memory
|
||||||
yield
|
yield
|
||||||
torch_gc()
|
torch_gc()
|
||||||
|
|
||||||
|
|
||||||
def to_json(data: BaseModel) -> str:
|
def dictify(data: "BaseModel") -> Dict[str, Any]:
|
||||||
try: # pydantic v2
|
try: # pydantic v2
|
||||||
|
return data.model_dump(exclude_unset=True)
|
||||||
|
except AttributeError: # pydantic v1
|
||||||
|
return data.dict(exclude_unset=True)
|
||||||
|
|
||||||
|
|
||||||
|
def jsonify(data: "BaseModel") -> str:
|
||||||
|
try: # pydantic v2
|
||||||
return json.dumps(data.model_dump(exclude_unset=True), ensure_ascii=False)
|
return json.dumps(data.model_dump(exclude_unset=True), ensure_ascii=False)
|
||||||
except: # pydantic v1
|
except AttributeError: # pydantic v1
|
||||||
return data.json(exclude_unset=True, ensure_ascii=False)
|
return data.json(exclude_unset=True, ensure_ascii=False)
|
||||||
|
|
||||||
|
|
||||||
@@ -63,6 +72,14 @@ def create_app(chat_model: "ChatModel") -> "FastAPI":
|
|||||||
allow_headers=["*"],
|
allow_headers=["*"],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
role_mapping = {
|
||||||
|
Role.USER: DataRole.USER.value,
|
||||||
|
Role.ASSISTANT: DataRole.ASSISTANT.value,
|
||||||
|
Role.SYSTEM: DataRole.SYSTEM.value,
|
||||||
|
Role.FUNCTION: DataRole.FUNCTION.value,
|
||||||
|
Role.TOOL: DataRole.OBSERVATION.value,
|
||||||
|
}
|
||||||
|
|
||||||
@app.get("/v1/models", response_model=ModelList)
|
@app.get("/v1/models", response_model=ModelList)
|
||||||
async def list_models():
|
async def list_models():
|
||||||
model_card = ModelCard(id="gpt-3.5-turbo")
|
model_card = ModelCard(id="gpt-3.5-turbo")
|
||||||
@@ -70,106 +87,138 @@ def create_app(chat_model: "ChatModel") -> "FastAPI":
|
|||||||
|
|
||||||
@app.post("/v1/chat/completions", response_model=ChatCompletionResponse, status_code=status.HTTP_200_OK)
|
@app.post("/v1/chat/completions", response_model=ChatCompletionResponse, status_code=status.HTTP_200_OK)
|
||||||
async def create_chat_completion(request: ChatCompletionRequest):
|
async def create_chat_completion(request: ChatCompletionRequest):
|
||||||
if not chat_model.can_generate:
|
if not chat_model.engine.can_generate:
|
||||||
raise HTTPException(status_code=status.HTTP_405_METHOD_NOT_ALLOWED, detail="Not allowed")
|
raise HTTPException(status_code=status.HTTP_405_METHOD_NOT_ALLOWED, detail="Not allowed")
|
||||||
|
|
||||||
if len(request.messages) == 0 or request.messages[-1].role != Role.USER:
|
if len(request.messages) == 0:
|
||||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid request")
|
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid length")
|
||||||
|
|
||||||
query = request.messages[-1].content
|
if request.messages[0].role == Role.SYSTEM:
|
||||||
prev_messages = request.messages[:-1]
|
system = request.messages.pop(0).content
|
||||||
if len(prev_messages) and prev_messages[0].role == Role.SYSTEM:
|
|
||||||
system = prev_messages.pop(0).content
|
|
||||||
else:
|
else:
|
||||||
system = None
|
system = ""
|
||||||
|
|
||||||
history = []
|
if len(request.messages) % 2 == 0:
|
||||||
if len(prev_messages) % 2 == 0:
|
|
||||||
for i in range(0, len(prev_messages), 2):
|
|
||||||
if prev_messages[i].role == Role.USER and prev_messages[i+1].role == Role.ASSISTANT:
|
|
||||||
history.append([prev_messages[i].content, prev_messages[i+1].content])
|
|
||||||
else:
|
|
||||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Only supports u/a/u/a/u...")
|
|
||||||
else:
|
|
||||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Only supports u/a/u/a/u...")
|
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Only supports u/a/u/a/u...")
|
||||||
|
|
||||||
|
input_messages = []
|
||||||
|
for i, message in enumerate(request.messages):
|
||||||
|
if i % 2 == 0 and message.role not in [Role.USER, Role.TOOL]:
|
||||||
|
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid role")
|
||||||
|
elif i % 2 == 1 and message.role not in [Role.ASSISTANT, Role.FUNCTION]:
|
||||||
|
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid role")
|
||||||
|
|
||||||
|
if message.role == Role.ASSISTANT and isinstance(message.tool_calls, list) and len(message.tool_calls):
|
||||||
|
name = message.tool_calls[0].function.name
|
||||||
|
arguments = message.tool_calls[0].function.arguments
|
||||||
|
content = json.dumps({"name": name, "argument": arguments}, ensure_ascii=False)
|
||||||
|
input_messages.append({"role": role_mapping[Role.FUNCTION], "content": content})
|
||||||
|
else:
|
||||||
|
input_messages.append({"role": role_mapping[message.role], "content": message.content})
|
||||||
|
|
||||||
|
tool_list = request.tools
|
||||||
|
if isinstance(tool_list, list) and len(tool_list):
|
||||||
|
try:
|
||||||
|
tools = json.dumps([dictify(tool.function) for tool in tool_list], ensure_ascii=False)
|
||||||
|
except Exception:
|
||||||
|
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid tools")
|
||||||
|
else:
|
||||||
|
tools = ""
|
||||||
|
|
||||||
if request.stream:
|
if request.stream:
|
||||||
generate = predict(query, history, system, request)
|
if tools:
|
||||||
|
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Cannot stream function calls.")
|
||||||
|
|
||||||
|
generate = stream_chat_completion(input_messages, system, tools, request)
|
||||||
return EventSourceResponse(generate, media_type="text/event-stream")
|
return EventSourceResponse(generate, media_type="text/event-stream")
|
||||||
|
|
||||||
responses = chat_model.chat(
|
responses = await chat_model.achat(
|
||||||
query, history, system,
|
input_messages,
|
||||||
|
system,
|
||||||
|
tools,
|
||||||
do_sample=request.do_sample,
|
do_sample=request.do_sample,
|
||||||
temperature=request.temperature,
|
temperature=request.temperature,
|
||||||
top_p=request.top_p,
|
top_p=request.top_p,
|
||||||
max_new_tokens=request.max_tokens,
|
max_new_tokens=request.max_tokens,
|
||||||
num_return_sequences=request.n
|
num_return_sequences=request.n,
|
||||||
)
|
)
|
||||||
|
|
||||||
prompt_length, response_length = 0, 0
|
prompt_length, response_length = 0, 0
|
||||||
choices = []
|
choices = []
|
||||||
for i, response in enumerate(responses):
|
for i, response in enumerate(responses):
|
||||||
choices.append(ChatCompletionResponseChoice(
|
if tools:
|
||||||
index=i,
|
result = chat_model.engine.template.format_tools.extract(response.response_text)
|
||||||
message=ChatMessage(role=Role.ASSISTANT, content=response.response_text),
|
else:
|
||||||
finish_reason=Finish.STOP if response.finish_reason == "stop" else Finish.LENGTH
|
result = response.response_text
|
||||||
))
|
|
||||||
|
if isinstance(result, tuple):
|
||||||
|
name, arguments = result
|
||||||
|
function = Function(name=name, arguments=arguments)
|
||||||
|
response_message = ChatCompletionMessage(
|
||||||
|
role=Role.ASSISTANT, tool_calls=[FunctionCall(function=function)]
|
||||||
|
)
|
||||||
|
finish_reason = Finish.TOOL
|
||||||
|
else:
|
||||||
|
response_message = ChatCompletionMessage(role=Role.ASSISTANT, content=result)
|
||||||
|
finish_reason = Finish.STOP if response.finish_reason == "stop" else Finish.LENGTH
|
||||||
|
|
||||||
|
choices.append(
|
||||||
|
ChatCompletionResponseChoice(index=i, message=response_message, finish_reason=finish_reason)
|
||||||
|
)
|
||||||
prompt_length = response.prompt_length
|
prompt_length = response.prompt_length
|
||||||
response_length += response.response_length
|
response_length += response.response_length
|
||||||
|
|
||||||
usage = ChatCompletionResponseUsage(
|
usage = ChatCompletionResponseUsage(
|
||||||
prompt_tokens=prompt_length,
|
prompt_tokens=prompt_length,
|
||||||
completion_tokens=response_length,
|
completion_tokens=response_length,
|
||||||
total_tokens=prompt_length+response_length
|
total_tokens=prompt_length + response_length,
|
||||||
)
|
)
|
||||||
|
|
||||||
return ChatCompletionResponse(model=request.model, choices=choices, usage=usage)
|
return ChatCompletionResponse(model=request.model, choices=choices, usage=usage)
|
||||||
|
|
||||||
async def predict(query: str, history: List[Tuple[str, str]], system: str, request: ChatCompletionRequest):
|
async def stream_chat_completion(
|
||||||
|
messages: Sequence[Dict[str, str]], system: str, tools: str, request: ChatCompletionRequest
|
||||||
|
):
|
||||||
choice_data = ChatCompletionResponseStreamChoice(
|
choice_data = ChatCompletionResponseStreamChoice(
|
||||||
index=0,
|
index=0, delta=ChatCompletionMessage(role=Role.ASSISTANT, content=""), finish_reason=None
|
||||||
delta=DeltaMessage(role=Role.ASSISTANT, content=""),
|
|
||||||
finish_reason=None
|
|
||||||
)
|
)
|
||||||
chunk = ChatCompletionStreamResponse(model=request.model, choices=[choice_data])
|
chunk = ChatCompletionStreamResponse(model=request.model, choices=[choice_data])
|
||||||
yield to_json(chunk)
|
yield jsonify(chunk)
|
||||||
|
|
||||||
for new_text in chat_model.stream_chat(
|
async for new_token in chat_model.astream_chat(
|
||||||
query, history, system,
|
messages,
|
||||||
|
system,
|
||||||
|
tools,
|
||||||
do_sample=request.do_sample,
|
do_sample=request.do_sample,
|
||||||
temperature=request.temperature,
|
temperature=request.temperature,
|
||||||
top_p=request.top_p,
|
top_p=request.top_p,
|
||||||
max_new_tokens=request.max_tokens
|
max_new_tokens=request.max_tokens,
|
||||||
):
|
):
|
||||||
if len(new_text) == 0:
|
if len(new_token) == 0:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
choice_data = ChatCompletionResponseStreamChoice(
|
choice_data = ChatCompletionResponseStreamChoice(
|
||||||
index=0,
|
index=0, delta=ChatCompletionMessage(content=new_token), finish_reason=None
|
||||||
delta=DeltaMessage(content=new_text),
|
|
||||||
finish_reason=None
|
|
||||||
)
|
)
|
||||||
chunk = ChatCompletionStreamResponse(model=request.model, choices=[choice_data])
|
chunk = ChatCompletionStreamResponse(model=request.model, choices=[choice_data])
|
||||||
yield to_json(chunk)
|
yield jsonify(chunk)
|
||||||
|
|
||||||
choice_data = ChatCompletionResponseStreamChoice(
|
choice_data = ChatCompletionResponseStreamChoice(
|
||||||
index=0,
|
index=0, delta=ChatCompletionMessage(), finish_reason=Finish.STOP
|
||||||
delta=DeltaMessage(),
|
|
||||||
finish_reason=Finish.STOP
|
|
||||||
)
|
)
|
||||||
chunk = ChatCompletionStreamResponse(model=request.model, choices=[choice_data])
|
chunk = ChatCompletionStreamResponse(model=request.model, choices=[choice_data])
|
||||||
yield to_json(chunk)
|
yield jsonify(chunk)
|
||||||
yield "[DONE]"
|
yield "[DONE]"
|
||||||
|
|
||||||
@app.post("/v1/score/evaluation", response_model=ScoreEvaluationResponse, status_code=status.HTTP_200_OK)
|
@app.post("/v1/score/evaluation", response_model=ScoreEvaluationResponse, status_code=status.HTTP_200_OK)
|
||||||
async def create_score_evaluation(request: ScoreEvaluationRequest):
|
async def create_score_evaluation(request: ScoreEvaluationRequest):
|
||||||
if chat_model.can_generate:
|
if chat_model.engine.can_generate:
|
||||||
raise HTTPException(status_code=status.HTTP_405_METHOD_NOT_ALLOWED, detail="Not allowed")
|
raise HTTPException(status_code=status.HTTP_405_METHOD_NOT_ALLOWED, detail="Not allowed")
|
||||||
|
|
||||||
if len(request.messages) == 0:
|
if len(request.messages) == 0:
|
||||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid request")
|
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid request")
|
||||||
|
|
||||||
scores = chat_model.get_scores(request.messages, max_length=request.max_length)
|
scores = await chat_model.aget_scores(request.messages, max_length=request.max_length)
|
||||||
return ScoreEvaluationResponse(model=request.model, scores=scores)
|
return ScoreEvaluationResponse(model=request.model, scores=scores)
|
||||||
|
|
||||||
return app
|
return app
|
||||||
@@ -178,4 +227,4 @@ def create_app(chat_model: "ChatModel") -> "FastAPI":
|
|||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
chat_model = ChatModel()
|
chat_model = ChatModel()
|
||||||
app = create_app(chat_model)
|
app = create_app(chat_model)
|
||||||
uvicorn.run(app, host="0.0.0.0", port=8000, workers=1)
|
uvicorn.run(app, host="0.0.0.0", port=int(os.environ.get("API_PORT", 8000)), workers=1)
|
||||||
|
|||||||
@@ -1,62 +1,94 @@
|
|||||||
import time
|
import time
|
||||||
from enum import Enum
|
from enum import Enum, unique
|
||||||
|
from typing import Any, Dict, List, Optional
|
||||||
|
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
from typing import List, Optional
|
from typing_extensions import Literal
|
||||||
|
|
||||||
|
|
||||||
|
@unique
|
||||||
class Role(str, Enum):
|
class Role(str, Enum):
|
||||||
USER = "user"
|
USER = "user"
|
||||||
ASSISTANT = "assistant"
|
ASSISTANT = "assistant"
|
||||||
SYSTEM = "system"
|
SYSTEM = "system"
|
||||||
|
FUNCTION = "function"
|
||||||
|
TOOL = "tool"
|
||||||
|
|
||||||
|
|
||||||
|
@unique
|
||||||
class Finish(str, Enum):
|
class Finish(str, Enum):
|
||||||
STOP = "stop"
|
STOP = "stop"
|
||||||
LENGTH = "length"
|
LENGTH = "length"
|
||||||
|
TOOL = "tool_calls"
|
||||||
|
|
||||||
|
|
||||||
class ModelCard(BaseModel):
|
class ModelCard(BaseModel):
|
||||||
id: str
|
id: str
|
||||||
object: Optional[str] = "model"
|
object: Literal["model"] = "model"
|
||||||
created: Optional[int] = Field(default_factory=lambda: int(time.time()))
|
created: int = Field(default_factory=lambda: int(time.time()))
|
||||||
owned_by: Optional[str] = "owner"
|
owned_by: Literal["owner"] = "owner"
|
||||||
|
|
||||||
|
|
||||||
class ModelList(BaseModel):
|
class ModelList(BaseModel):
|
||||||
object: Optional[str] = "list"
|
object: Literal["list"] = "list"
|
||||||
data: Optional[List[ModelCard]] = []
|
data: List[ModelCard] = []
|
||||||
|
|
||||||
|
|
||||||
|
class Function(BaseModel):
|
||||||
|
name: str
|
||||||
|
arguments: str
|
||||||
|
|
||||||
|
|
||||||
|
class FunctionDefinition(BaseModel):
|
||||||
|
name: str
|
||||||
|
description: str
|
||||||
|
parameters: Dict[str, Any]
|
||||||
|
|
||||||
|
|
||||||
|
class FunctionAvailable(BaseModel):
|
||||||
|
type: Literal["function", "code_interpreter"] = "function"
|
||||||
|
function: Optional[FunctionDefinition] = None
|
||||||
|
|
||||||
|
|
||||||
|
class FunctionCall(BaseModel):
|
||||||
|
id: Literal["call_default"] = "call_default"
|
||||||
|
type: Literal["function"] = "function"
|
||||||
|
function: Function
|
||||||
|
|
||||||
|
|
||||||
class ChatMessage(BaseModel):
|
class ChatMessage(BaseModel):
|
||||||
role: Role
|
role: Role
|
||||||
content: str
|
content: Optional[str] = None
|
||||||
|
tool_calls: Optional[List[FunctionCall]] = None
|
||||||
|
|
||||||
|
|
||||||
class DeltaMessage(BaseModel):
|
class ChatCompletionMessage(BaseModel):
|
||||||
role: Optional[Role] = None
|
role: Optional[Role] = None
|
||||||
content: Optional[str] = None
|
content: Optional[str] = None
|
||||||
|
tool_calls: Optional[List[FunctionCall]] = None
|
||||||
|
|
||||||
|
|
||||||
class ChatCompletionRequest(BaseModel):
|
class ChatCompletionRequest(BaseModel):
|
||||||
model: str
|
model: str
|
||||||
messages: List[ChatMessage]
|
messages: List[ChatMessage]
|
||||||
do_sample: Optional[bool] = True
|
tools: Optional[List[FunctionAvailable]] = None
|
||||||
|
do_sample: bool = True
|
||||||
temperature: Optional[float] = None
|
temperature: Optional[float] = None
|
||||||
top_p: Optional[float] = None
|
top_p: Optional[float] = None
|
||||||
n: Optional[int] = 1
|
n: int = 1
|
||||||
max_tokens: Optional[int] = None
|
max_tokens: Optional[int] = None
|
||||||
stream: Optional[bool] = False
|
stream: bool = False
|
||||||
|
|
||||||
|
|
||||||
class ChatCompletionResponseChoice(BaseModel):
|
class ChatCompletionResponseChoice(BaseModel):
|
||||||
index: int
|
index: int
|
||||||
message: ChatMessage
|
message: ChatCompletionMessage
|
||||||
finish_reason: Finish
|
finish_reason: Finish
|
||||||
|
|
||||||
|
|
||||||
class ChatCompletionResponseStreamChoice(BaseModel):
|
class ChatCompletionResponseStreamChoice(BaseModel):
|
||||||
index: int
|
index: int
|
||||||
delta: DeltaMessage
|
delta: ChatCompletionMessage
|
||||||
finish_reason: Optional[Finish] = None
|
finish_reason: Optional[Finish] = None
|
||||||
|
|
||||||
|
|
||||||
@@ -67,18 +99,18 @@ class ChatCompletionResponseUsage(BaseModel):
|
|||||||
|
|
||||||
|
|
||||||
class ChatCompletionResponse(BaseModel):
|
class ChatCompletionResponse(BaseModel):
|
||||||
id: Optional[str] = "chatcmpl-default"
|
id: Literal["chatcmpl-default"] = "chatcmpl-default"
|
||||||
object: Optional[str] = "chat.completion"
|
object: Literal["chat.completion"] = "chat.completion"
|
||||||
created: Optional[int] = Field(default_factory=lambda: int(time.time()))
|
created: int = Field(default_factory=lambda: int(time.time()))
|
||||||
model: str
|
model: str
|
||||||
choices: List[ChatCompletionResponseChoice]
|
choices: List[ChatCompletionResponseChoice]
|
||||||
usage: ChatCompletionResponseUsage
|
usage: ChatCompletionResponseUsage
|
||||||
|
|
||||||
|
|
||||||
class ChatCompletionStreamResponse(BaseModel):
|
class ChatCompletionStreamResponse(BaseModel):
|
||||||
id: Optional[str] = "chatcmpl-default"
|
id: Literal["chatcmpl-default"] = "chatcmpl-default"
|
||||||
object: Optional[str] = "chat.completion.chunk"
|
object: Literal["chat.completion.chunk"] = "chat.completion.chunk"
|
||||||
created: Optional[int] = Field(default_factory=lambda: int(time.time()))
|
created: int = Field(default_factory=lambda: int(time.time()))
|
||||||
model: str
|
model: str
|
||||||
choices: List[ChatCompletionResponseStreamChoice]
|
choices: List[ChatCompletionResponseStreamChoice]
|
||||||
|
|
||||||
@@ -90,7 +122,7 @@ class ScoreEvaluationRequest(BaseModel):
|
|||||||
|
|
||||||
|
|
||||||
class ScoreEvaluationResponse(BaseModel):
|
class ScoreEvaluationResponse(BaseModel):
|
||||||
id: Optional[str] = "scoreeval-default"
|
id: Literal["scoreeval-default"] = "scoreeval-default"
|
||||||
object: Optional[str] = "score.evaluation"
|
object: Literal["score.evaluation"] = "score.evaluation"
|
||||||
model: str
|
model: str
|
||||||
scores: List[float]
|
scores: List[float]
|
||||||
|
|||||||
@@ -1 +1,5 @@
|
|||||||
from llmtuner.chat.chat_model import ChatModel
|
from .base_engine import BaseEngine
|
||||||
|
from .chat_model import ChatModel
|
||||||
|
|
||||||
|
|
||||||
|
__all__ = ["BaseEngine", "ChatModel"]
|
||||||
|
|||||||
69
src/llmtuner/chat/base_engine.py
Normal file
69
src/llmtuner/chat/base_engine.py
Normal file
@@ -0,0 +1,69 @@
|
|||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import TYPE_CHECKING, Any, AsyncGenerator, Dict, List, Literal, Optional, Sequence, Union
|
||||||
|
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from transformers import PreTrainedModel, PreTrainedTokenizer
|
||||||
|
|
||||||
|
from ..data import Template
|
||||||
|
from ..extras.packages import is_vllm_available
|
||||||
|
from ..hparams import DataArguments, FinetuningArguments, GeneratingArguments, ModelArguments
|
||||||
|
|
||||||
|
if is_vllm_available():
|
||||||
|
from vllm import AsyncLLMEngine
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class Response:
|
||||||
|
response_text: str
|
||||||
|
response_length: int
|
||||||
|
prompt_length: int
|
||||||
|
finish_reason: Literal["stop", "length"]
|
||||||
|
|
||||||
|
|
||||||
|
class BaseEngine(ABC):
|
||||||
|
model: Union["PreTrainedModel", "AsyncLLMEngine"]
|
||||||
|
tokenizer: "PreTrainedTokenizer"
|
||||||
|
can_generate: bool
|
||||||
|
template: "Template"
|
||||||
|
generating_args: Dict[str, Any]
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
model_args: "ModelArguments",
|
||||||
|
data_args: "DataArguments",
|
||||||
|
finetuning_args: "FinetuningArguments",
|
||||||
|
generating_args: "GeneratingArguments",
|
||||||
|
) -> None: ...
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
async def start(
|
||||||
|
self,
|
||||||
|
) -> None: ...
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
async def chat(
|
||||||
|
self,
|
||||||
|
messages: Sequence[Dict[str, str]],
|
||||||
|
system: Optional[str] = None,
|
||||||
|
tools: Optional[str] = None,
|
||||||
|
**input_kwargs,
|
||||||
|
) -> List["Response"]: ...
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
async def stream_chat(
|
||||||
|
self,
|
||||||
|
messages: Sequence[Dict[str, str]],
|
||||||
|
system: Optional[str] = None,
|
||||||
|
tools: Optional[str] = None,
|
||||||
|
**input_kwargs,
|
||||||
|
) -> AsyncGenerator[str, None]: ...
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
async def get_scores(
|
||||||
|
self,
|
||||||
|
batch_input: List[str],
|
||||||
|
**input_kwargs,
|
||||||
|
) -> List[float]: ...
|
||||||
@@ -1,172 +1,91 @@
|
|||||||
import torch
|
import asyncio
|
||||||
import tiktoken
|
|
||||||
from dataclasses import dataclass
|
|
||||||
from typing import Any, Dict, Generator, List, Literal, Optional, Tuple
|
|
||||||
from threading import Thread
|
from threading import Thread
|
||||||
from transformers import GenerationConfig, TextIteratorStreamer
|
from typing import TYPE_CHECKING, Any, AsyncGenerator, Dict, Generator, List, Optional, Sequence
|
||||||
|
|
||||||
from llmtuner.data.template import get_template_and_fix_tokenizer
|
from ..hparams import get_infer_args
|
||||||
from llmtuner.extras.misc import get_logits_processor
|
from .hf_engine import HuggingfaceEngine
|
||||||
from llmtuner.model import dispatch_model, get_infer_args, load_model_and_tokenizer
|
from .vllm_engine import VllmEngine
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
if TYPE_CHECKING:
|
||||||
class Response:
|
from .base_engine import BaseEngine, Response
|
||||||
|
|
||||||
response_text: str
|
|
||||||
response_length: int
|
def _start_background_loop(loop: asyncio.AbstractEventLoop) -> None:
|
||||||
prompt_length: int
|
asyncio.set_event_loop(loop)
|
||||||
finish_reason: Literal["stop", "length"]
|
loop.run_forever()
|
||||||
|
|
||||||
|
|
||||||
class ChatModel:
|
class ChatModel:
|
||||||
|
|
||||||
def __init__(self, args: Optional[Dict[str, Any]] = None) -> None:
|
def __init__(self, args: Optional[Dict[str, Any]] = None) -> None:
|
||||||
model_args, data_args, finetuning_args, self.generating_args = get_infer_args(args)
|
model_args, data_args, finetuning_args, generating_args = get_infer_args(args)
|
||||||
self.can_generate = (finetuning_args.stage == "sft")
|
if model_args.infer_backend == "huggingface":
|
||||||
self.model, self.tokenizer = load_model_and_tokenizer(
|
self.engine: "BaseEngine" = HuggingfaceEngine(model_args, data_args, finetuning_args, generating_args)
|
||||||
model_args, finetuning_args, is_trainable=False, add_valuehead=(not self.can_generate)
|
elif model_args.infer_backend == "vllm":
|
||||||
)
|
self.engine: "BaseEngine" = VllmEngine(model_args, data_args, finetuning_args, generating_args)
|
||||||
self.tokenizer.padding_side = "left" if self.can_generate else "right"
|
else:
|
||||||
self.model = dispatch_model(self.model)
|
raise NotImplementedError("Unknown backend: {}".format(model_args.infer_backend))
|
||||||
self.template = get_template_and_fix_tokenizer(data_args.template, self.tokenizer)
|
|
||||||
|
|
||||||
def _process_args(
|
self._loop = asyncio.new_event_loop()
|
||||||
self,
|
self._thread = Thread(target=_start_background_loop, args=(self._loop,), daemon=True)
|
||||||
query: str,
|
self._thread.start()
|
||||||
history: Optional[List[Tuple[str, str]]] = None,
|
asyncio.run_coroutine_threadsafe(self.engine.start(), self._loop)
|
||||||
system: Optional[str] = None,
|
|
||||||
**input_kwargs
|
|
||||||
) -> Tuple[Dict[str, Any], int]:
|
|
||||||
prompt, _ = self.template.encode_oneturn(
|
|
||||||
tokenizer=self.tokenizer, query=query, resp="", history=history, system=system
|
|
||||||
)
|
|
||||||
prompt_length = len(prompt)
|
|
||||||
input_ids = torch.tensor([prompt], device=self.model.device)
|
|
||||||
|
|
||||||
do_sample = input_kwargs.pop("do_sample", None)
|
|
||||||
temperature = input_kwargs.pop("temperature", None)
|
|
||||||
top_p = input_kwargs.pop("top_p", None)
|
|
||||||
top_k = input_kwargs.pop("top_k", None)
|
|
||||||
num_return_sequences = input_kwargs.pop("num_return_sequences", None)
|
|
||||||
repetition_penalty = input_kwargs.pop("repetition_penalty", None)
|
|
||||||
max_length = input_kwargs.pop("max_length", None)
|
|
||||||
max_new_tokens = input_kwargs.pop("max_new_tokens", None)
|
|
||||||
|
|
||||||
generating_args = self.generating_args.to_dict()
|
|
||||||
generating_args.update(dict(
|
|
||||||
do_sample=do_sample if do_sample is not None else generating_args["do_sample"],
|
|
||||||
temperature=temperature or generating_args["temperature"],
|
|
||||||
top_p=top_p or generating_args["top_p"],
|
|
||||||
top_k=top_k or generating_args["top_k"],
|
|
||||||
num_return_sequences=num_return_sequences or 1,
|
|
||||||
repetition_penalty=repetition_penalty or generating_args["repetition_penalty"],
|
|
||||||
eos_token_id=[self.tokenizer.eos_token_id] + self.tokenizer.additional_special_tokens_ids,
|
|
||||||
pad_token_id=self.tokenizer.pad_token_id
|
|
||||||
))
|
|
||||||
|
|
||||||
if isinstance(num_return_sequences, int) and num_return_sequences > 1:
|
|
||||||
generating_args["do_sample"] = True
|
|
||||||
|
|
||||||
if max_length:
|
|
||||||
generating_args.pop("max_new_tokens", None)
|
|
||||||
generating_args["max_length"] = max_length
|
|
||||||
|
|
||||||
if max_new_tokens:
|
|
||||||
generating_args.pop("max_length", None)
|
|
||||||
generating_args["max_new_tokens"] = max_new_tokens
|
|
||||||
|
|
||||||
gen_kwargs = dict(
|
|
||||||
inputs=input_ids,
|
|
||||||
generation_config=GenerationConfig(**generating_args),
|
|
||||||
logits_processor=get_logits_processor()
|
|
||||||
)
|
|
||||||
|
|
||||||
return gen_kwargs, prompt_length
|
|
||||||
|
|
||||||
@torch.inference_mode()
|
|
||||||
def chat(
|
def chat(
|
||||||
self,
|
self,
|
||||||
query: str,
|
messages: Sequence[Dict[str, str]],
|
||||||
history: Optional[List[Tuple[str, str]]] = None,
|
|
||||||
system: Optional[str] = None,
|
system: Optional[str] = None,
|
||||||
**input_kwargs
|
tools: Optional[str] = None,
|
||||||
) -> List[Response]:
|
**input_kwargs,
|
||||||
r"""
|
) -> List["Response"]:
|
||||||
Args: query, history, system, **input_kwargs
|
task = asyncio.run_coroutine_threadsafe(self.achat(messages, system, tools, **input_kwargs), self._loop)
|
||||||
|
return task.result()
|
||||||
|
|
||||||
Returns: [(response_text, prompt_length, response_length)] * n (default n=1)
|
async def achat(
|
||||||
"""
|
self,
|
||||||
gen_kwargs, prompt_length = self._process_args(query, history, system, **input_kwargs)
|
messages: Sequence[Dict[str, str]],
|
||||||
generate_output = self.model.generate(**gen_kwargs)
|
system: Optional[str] = None,
|
||||||
response_ids = generate_output[:, prompt_length:]
|
tools: Optional[str] = None,
|
||||||
response = self.tokenizer.batch_decode(
|
**input_kwargs,
|
||||||
response_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True
|
) -> List["Response"]:
|
||||||
)
|
return await self.engine.chat(messages, system, tools, **input_kwargs)
|
||||||
results = []
|
|
||||||
for i in range(len(response)):
|
|
||||||
eos_index = (response_ids[i] == self.tokenizer.eos_token_id).nonzero()
|
|
||||||
response_length = (eos_index[0].item() + 1) if len(eos_index) else len(response_ids[i])
|
|
||||||
results.append(Response(
|
|
||||||
response_text=response[i],
|
|
||||||
response_length=response_length,
|
|
||||||
prompt_length=prompt_length,
|
|
||||||
finish_reason="stop" if len(eos_index) else "length"
|
|
||||||
))
|
|
||||||
|
|
||||||
return results
|
|
||||||
|
|
||||||
@torch.inference_mode()
|
|
||||||
def stream_chat(
|
def stream_chat(
|
||||||
self,
|
self,
|
||||||
query: str,
|
messages: Sequence[Dict[str, str]],
|
||||||
history: Optional[List[Tuple[str, str]]] = None,
|
|
||||||
system: Optional[str] = None,
|
system: Optional[str] = None,
|
||||||
**input_kwargs
|
tools: Optional[str] = None,
|
||||||
|
**input_kwargs,
|
||||||
) -> Generator[str, None, None]:
|
) -> Generator[str, None, None]:
|
||||||
gen_kwargs, _ = self._process_args(query, history, system, **input_kwargs)
|
generator = self.astream_chat(messages, system, tools, **input_kwargs)
|
||||||
streamer = TextIteratorStreamer(self.tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True)
|
while True:
|
||||||
gen_kwargs["streamer"] = streamer
|
try:
|
||||||
|
task = asyncio.run_coroutine_threadsafe(generator.__anext__(), self._loop)
|
||||||
|
yield task.result()
|
||||||
|
except StopAsyncIteration:
|
||||||
|
break
|
||||||
|
|
||||||
thread = Thread(target=self.model.generate, kwargs=gen_kwargs)
|
async def astream_chat(
|
||||||
thread.start()
|
self,
|
||||||
|
messages: Sequence[Dict[str, str]],
|
||||||
|
system: Optional[str] = None,
|
||||||
|
tools: Optional[str] = None,
|
||||||
|
**input_kwargs,
|
||||||
|
) -> AsyncGenerator[str, None]:
|
||||||
|
async for new_token in self.engine.stream_chat(messages, system, tools, **input_kwargs):
|
||||||
|
yield new_token
|
||||||
|
|
||||||
yield from streamer
|
|
||||||
|
|
||||||
@torch.inference_mode()
|
|
||||||
def get_scores(
|
def get_scores(
|
||||||
self,
|
self,
|
||||||
batch_input: List[str],
|
batch_input: List[str],
|
||||||
**input_kwargs
|
**input_kwargs,
|
||||||
) -> List[float]:
|
) -> List[float]:
|
||||||
if isinstance(getattr(self.tokenizer, "tokenizer", None), tiktoken.Encoding): # for tiktoken tokenizer (Qwen)
|
task = asyncio.run_coroutine_threadsafe(self.aget_scores(batch_input, **input_kwargs), self._loop)
|
||||||
kwargs = dict(allowed_special="all")
|
return task.result()
|
||||||
else:
|
|
||||||
kwargs = dict(add_special_tokens=True)
|
|
||||||
|
|
||||||
max_length = input_kwargs.pop("max_length", None)
|
async def aget_scores(
|
||||||
device = getattr(self.model.pretrained_model, "device", "cuda")
|
self,
|
||||||
|
batch_input: List[str],
|
||||||
inputs = self.tokenizer(
|
**input_kwargs,
|
||||||
batch_input,
|
) -> List[float]:
|
||||||
padding=True,
|
return await self.engine.get_scores(batch_input, **input_kwargs)
|
||||||
truncation=True,
|
|
||||||
max_length=max_length or getattr(self.model.config, "max_position_embeddings", 1024),
|
|
||||||
pad_to_multiple_of=8,
|
|
||||||
return_tensors="pt",
|
|
||||||
**kwargs
|
|
||||||
).to(device)
|
|
||||||
|
|
||||||
input_ids: torch.Tensor = inputs["input_ids"]
|
|
||||||
_, _, values = self.model(**inputs, output_hidden_states=True, return_dict=True)
|
|
||||||
|
|
||||||
if getattr(self.model.config, "model_type", None) == "chatglm":
|
|
||||||
values = torch.transpose(values, 0, 1)
|
|
||||||
|
|
||||||
scores = []
|
|
||||||
for i in range(input_ids.size(0)):
|
|
||||||
end_indexes = (input_ids[i] != self.tokenizer.pad_token_id).nonzero()
|
|
||||||
end_index = end_indexes[-1].item() if len(end_indexes) else 0
|
|
||||||
scores.append(values[i, end_index].nan_to_num().item())
|
|
||||||
|
|
||||||
return scores
|
|
||||||
|
|||||||
264
src/llmtuner/chat/hf_engine.py
Normal file
264
src/llmtuner/chat/hf_engine.py
Normal file
@@ -0,0 +1,264 @@
|
|||||||
|
import asyncio
|
||||||
|
import concurrent.futures
|
||||||
|
import os
|
||||||
|
from threading import Thread
|
||||||
|
from typing import TYPE_CHECKING, Any, AsyncGenerator, Callable, Dict, List, Optional, Sequence, Tuple
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from transformers import GenerationConfig, TextIteratorStreamer
|
||||||
|
|
||||||
|
from ..data import get_template_and_fix_tokenizer
|
||||||
|
from ..extras.misc import get_logits_processor
|
||||||
|
from ..model import load_model, load_tokenizer
|
||||||
|
from .base_engine import BaseEngine, Response
|
||||||
|
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from transformers import PreTrainedModel, PreTrainedTokenizer
|
||||||
|
from trl import PreTrainedModelWrapper
|
||||||
|
|
||||||
|
from ..data import Template
|
||||||
|
from ..hparams import DataArguments, FinetuningArguments, GeneratingArguments, ModelArguments
|
||||||
|
|
||||||
|
|
||||||
|
class HuggingfaceEngine(BaseEngine):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
model_args: "ModelArguments",
|
||||||
|
data_args: "DataArguments",
|
||||||
|
finetuning_args: "FinetuningArguments",
|
||||||
|
generating_args: "GeneratingArguments",
|
||||||
|
) -> None:
|
||||||
|
self.can_generate = finetuning_args.stage == "sft"
|
||||||
|
self.tokenizer = load_tokenizer(model_args)
|
||||||
|
self.tokenizer.padding_side = "left" if self.can_generate else "right"
|
||||||
|
self.template = get_template_and_fix_tokenizer(self.tokenizer, data_args.template)
|
||||||
|
self.model = load_model(
|
||||||
|
self.tokenizer, model_args, finetuning_args, is_trainable=False, add_valuehead=(not self.can_generate)
|
||||||
|
) # must after fixing tokenizer to resize vocab
|
||||||
|
self.generating_args = generating_args.to_dict()
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _process_args(
|
||||||
|
model: "PreTrainedModel",
|
||||||
|
tokenizer: "PreTrainedTokenizer",
|
||||||
|
template: "Template",
|
||||||
|
generating_args: Dict[str, Any],
|
||||||
|
messages: Sequence[Dict[str, str]],
|
||||||
|
system: Optional[str] = None,
|
||||||
|
tools: Optional[str] = None,
|
||||||
|
input_kwargs: Optional[Dict[str, Any]] = {},
|
||||||
|
) -> Tuple[Dict[str, Any], int]:
|
||||||
|
paired_messages = messages + [{"role": "assistant", "content": ""}]
|
||||||
|
prompt_ids, _ = template.encode_oneturn(
|
||||||
|
tokenizer=tokenizer, messages=paired_messages, system=system, tools=tools
|
||||||
|
)
|
||||||
|
prompt_length = len(prompt_ids)
|
||||||
|
inputs = torch.tensor([prompt_ids], device=model.device)
|
||||||
|
|
||||||
|
do_sample = input_kwargs.pop("do_sample", None)
|
||||||
|
temperature = input_kwargs.pop("temperature", None)
|
||||||
|
top_p = input_kwargs.pop("top_p", None)
|
||||||
|
top_k = input_kwargs.pop("top_k", None)
|
||||||
|
num_return_sequences = input_kwargs.pop("num_return_sequences", None)
|
||||||
|
repetition_penalty = input_kwargs.pop("repetition_penalty", None)
|
||||||
|
max_length = input_kwargs.pop("max_length", None)
|
||||||
|
max_new_tokens = input_kwargs.pop("max_new_tokens", None)
|
||||||
|
|
||||||
|
generating_args.update(
|
||||||
|
dict(
|
||||||
|
do_sample=do_sample if do_sample is not None else generating_args["do_sample"],
|
||||||
|
temperature=temperature or generating_args["temperature"],
|
||||||
|
top_p=top_p or generating_args["top_p"],
|
||||||
|
top_k=top_k or generating_args["top_k"],
|
||||||
|
num_return_sequences=num_return_sequences or 1,
|
||||||
|
repetition_penalty=repetition_penalty or generating_args["repetition_penalty"],
|
||||||
|
eos_token_id=[tokenizer.eos_token_id] + tokenizer.additional_special_tokens_ids,
|
||||||
|
pad_token_id=tokenizer.pad_token_id,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
if isinstance(num_return_sequences, int) and num_return_sequences > 1:
|
||||||
|
generating_args["do_sample"] = True
|
||||||
|
|
||||||
|
if max_length:
|
||||||
|
generating_args.pop("max_new_tokens", None)
|
||||||
|
generating_args["max_length"] = max_length
|
||||||
|
|
||||||
|
if max_new_tokens:
|
||||||
|
generating_args.pop("max_length", None)
|
||||||
|
generating_args["max_new_tokens"] = max_new_tokens
|
||||||
|
|
||||||
|
gen_kwargs = dict(
|
||||||
|
inputs=inputs,
|
||||||
|
generation_config=GenerationConfig(**generating_args),
|
||||||
|
logits_processor=get_logits_processor(),
|
||||||
|
)
|
||||||
|
|
||||||
|
return gen_kwargs, prompt_length
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
@torch.inference_mode()
|
||||||
|
def _chat(
|
||||||
|
model: "PreTrainedModel",
|
||||||
|
tokenizer: "PreTrainedTokenizer",
|
||||||
|
template: "Template",
|
||||||
|
generating_args: Dict[str, Any],
|
||||||
|
messages: Sequence[Dict[str, str]],
|
||||||
|
system: Optional[str] = None,
|
||||||
|
tools: Optional[str] = None,
|
||||||
|
input_kwargs: Optional[Dict[str, Any]] = {},
|
||||||
|
) -> List["Response"]:
|
||||||
|
gen_kwargs, prompt_length = HuggingfaceEngine._process_args(
|
||||||
|
model, tokenizer, template, generating_args, messages, system, tools, input_kwargs
|
||||||
|
)
|
||||||
|
generate_output = model.generate(**gen_kwargs)
|
||||||
|
response_ids = generate_output[:, prompt_length:]
|
||||||
|
response = tokenizer.batch_decode(response_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True)
|
||||||
|
results = []
|
||||||
|
for i in range(len(response)):
|
||||||
|
eos_index = (response_ids[i] == tokenizer.eos_token_id).nonzero()
|
||||||
|
response_length = (eos_index[0].item() + 1) if len(eos_index) else len(response_ids[i])
|
||||||
|
results.append(
|
||||||
|
Response(
|
||||||
|
response_text=response[i],
|
||||||
|
response_length=response_length,
|
||||||
|
prompt_length=prompt_length,
|
||||||
|
finish_reason="stop" if len(eos_index) else "length",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
return results
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
@torch.inference_mode()
|
||||||
|
def _stream_chat(
|
||||||
|
model: "PreTrainedModel",
|
||||||
|
tokenizer: "PreTrainedTokenizer",
|
||||||
|
template: "Template",
|
||||||
|
generating_args: Dict[str, Any],
|
||||||
|
messages: Sequence[Dict[str, str]],
|
||||||
|
system: Optional[str] = None,
|
||||||
|
tools: Optional[str] = None,
|
||||||
|
input_kwargs: Optional[Dict[str, Any]] = {},
|
||||||
|
) -> Callable[[], str]:
|
||||||
|
gen_kwargs, _ = HuggingfaceEngine._process_args(
|
||||||
|
model, tokenizer, template, generating_args, messages, system, tools, input_kwargs
|
||||||
|
)
|
||||||
|
streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
|
||||||
|
gen_kwargs["streamer"] = streamer
|
||||||
|
thread = Thread(target=model.generate, kwargs=gen_kwargs, daemon=True)
|
||||||
|
thread.start()
|
||||||
|
|
||||||
|
def stream():
|
||||||
|
try:
|
||||||
|
return streamer.__next__()
|
||||||
|
except StopIteration:
|
||||||
|
raise StopAsyncIteration()
|
||||||
|
|
||||||
|
return stream
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
@torch.inference_mode()
|
||||||
|
def _get_scores(
|
||||||
|
model: "PreTrainedModelWrapper",
|
||||||
|
tokenizer: "PreTrainedTokenizer",
|
||||||
|
batch_input: List[str],
|
||||||
|
input_kwargs: Optional[Dict[str, Any]] = {},
|
||||||
|
) -> List[float]:
|
||||||
|
max_length = input_kwargs.pop("max_length", None)
|
||||||
|
device = getattr(model.pretrained_model, "device", "cuda")
|
||||||
|
inputs = tokenizer(
|
||||||
|
batch_input,
|
||||||
|
padding=True,
|
||||||
|
truncation=True,
|
||||||
|
max_length=max_length or getattr(model.config, "max_position_embeddings", 1024),
|
||||||
|
return_tensors="pt",
|
||||||
|
add_special_tokens=True,
|
||||||
|
).to(device)
|
||||||
|
|
||||||
|
input_ids: torch.Tensor = inputs["input_ids"]
|
||||||
|
_, _, values = model(**inputs, output_hidden_states=True, return_dict=True)
|
||||||
|
|
||||||
|
if getattr(model.config, "model_type", None) == "chatglm":
|
||||||
|
values = torch.transpose(values, 0, 1)
|
||||||
|
|
||||||
|
scores = []
|
||||||
|
for i in range(input_ids.size(0)):
|
||||||
|
end_indexes = (input_ids[i] != tokenizer.pad_token_id).nonzero()
|
||||||
|
end_index = end_indexes[-1].item() if len(end_indexes) else 0
|
||||||
|
scores.append(values[i, end_index].nan_to_num().item())
|
||||||
|
|
||||||
|
return scores
|
||||||
|
|
||||||
|
async def start(self) -> None:
|
||||||
|
self._semaphore = asyncio.Semaphore(int(os.environ.get("MAX_CONCURRENT", 1)))
|
||||||
|
|
||||||
|
async def chat(
|
||||||
|
self,
|
||||||
|
messages: Sequence[Dict[str, str]],
|
||||||
|
system: Optional[str] = None,
|
||||||
|
tools: Optional[str] = None,
|
||||||
|
**input_kwargs,
|
||||||
|
) -> List["Response"]:
|
||||||
|
if not self.can_generate:
|
||||||
|
raise ValueError("The current model does not support `chat`.")
|
||||||
|
|
||||||
|
loop = asyncio.get_running_loop()
|
||||||
|
input_args = (
|
||||||
|
self.model,
|
||||||
|
self.tokenizer,
|
||||||
|
self.template,
|
||||||
|
self.generating_args,
|
||||||
|
messages,
|
||||||
|
system,
|
||||||
|
tools,
|
||||||
|
input_kwargs,
|
||||||
|
)
|
||||||
|
async with self._semaphore:
|
||||||
|
with concurrent.futures.ThreadPoolExecutor() as pool:
|
||||||
|
return await loop.run_in_executor(pool, self._chat, *input_args)
|
||||||
|
|
||||||
|
async def stream_chat(
|
||||||
|
self,
|
||||||
|
messages: Sequence[Dict[str, str]],
|
||||||
|
system: Optional[str] = None,
|
||||||
|
tools: Optional[str] = None,
|
||||||
|
**input_kwargs,
|
||||||
|
) -> AsyncGenerator[str, None]:
|
||||||
|
if not self.can_generate:
|
||||||
|
raise ValueError("The current model does not support `stream_chat`.")
|
||||||
|
|
||||||
|
loop = asyncio.get_running_loop()
|
||||||
|
input_args = (
|
||||||
|
self.model,
|
||||||
|
self.tokenizer,
|
||||||
|
self.template,
|
||||||
|
self.generating_args,
|
||||||
|
messages,
|
||||||
|
system,
|
||||||
|
tools,
|
||||||
|
input_kwargs,
|
||||||
|
)
|
||||||
|
async with self._semaphore:
|
||||||
|
with concurrent.futures.ThreadPoolExecutor() as pool:
|
||||||
|
stream = self._stream_chat(*input_args)
|
||||||
|
while True:
|
||||||
|
try:
|
||||||
|
yield await loop.run_in_executor(pool, stream)
|
||||||
|
except StopAsyncIteration:
|
||||||
|
break
|
||||||
|
|
||||||
|
async def get_scores(
|
||||||
|
self,
|
||||||
|
batch_input: List[str],
|
||||||
|
**input_kwargs,
|
||||||
|
) -> List[float]:
|
||||||
|
if self.can_generate:
|
||||||
|
raise ValueError("Cannot get scores using an auto-regressive model.")
|
||||||
|
|
||||||
|
loop = asyncio.get_running_loop()
|
||||||
|
input_args = (self.model, self.tokenizer, batch_input, input_kwargs)
|
||||||
|
async with self._semaphore:
|
||||||
|
with concurrent.futures.ThreadPoolExecutor() as pool:
|
||||||
|
return await loop.run_in_executor(pool, self._get_scores, *input_args)
|
||||||
149
src/llmtuner/chat/vllm_engine.py
Normal file
149
src/llmtuner/chat/vllm_engine.py
Normal file
@@ -0,0 +1,149 @@
|
|||||||
|
import uuid
|
||||||
|
from typing import TYPE_CHECKING, AsyncGenerator, AsyncIterator, Dict, List, Optional, Sequence
|
||||||
|
|
||||||
|
from transformers.utils.versions import require_version
|
||||||
|
|
||||||
|
from ..data import get_template_and_fix_tokenizer
|
||||||
|
from ..extras.misc import get_device_count
|
||||||
|
from ..extras.packages import is_vllm_available
|
||||||
|
from ..model import load_tokenizer
|
||||||
|
from .base_engine import BaseEngine, Response
|
||||||
|
|
||||||
|
|
||||||
|
if is_vllm_available():
|
||||||
|
from vllm import AsyncEngineArgs, AsyncLLMEngine, RequestOutput, SamplingParams
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from ..hparams import DataArguments, FinetuningArguments, GeneratingArguments, ModelArguments
|
||||||
|
|
||||||
|
|
||||||
|
class VllmEngine(BaseEngine):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
model_args: "ModelArguments",
|
||||||
|
data_args: "DataArguments",
|
||||||
|
finetuning_args: "FinetuningArguments",
|
||||||
|
generating_args: "GeneratingArguments",
|
||||||
|
) -> None:
|
||||||
|
require_version("vllm>=0.3.3", "To fix: pip install vllm>=0.3.3")
|
||||||
|
self.can_generate = finetuning_args.stage == "sft"
|
||||||
|
engine_args = AsyncEngineArgs(
|
||||||
|
model=model_args.model_name_or_path,
|
||||||
|
trust_remote_code=True,
|
||||||
|
max_model_len=model_args.vllm_maxlen,
|
||||||
|
tensor_parallel_size=get_device_count() or 1,
|
||||||
|
gpu_memory_utilization=model_args.vllm_gpu_util,
|
||||||
|
disable_log_stats=True,
|
||||||
|
disable_log_requests=True,
|
||||||
|
enforce_eager=model_args.vllm_enforce_eager,
|
||||||
|
)
|
||||||
|
self.model = AsyncLLMEngine.from_engine_args(engine_args)
|
||||||
|
self.tokenizer = load_tokenizer(model_args)
|
||||||
|
self.tokenizer.padding_side = "left"
|
||||||
|
self.template = get_template_and_fix_tokenizer(self.tokenizer, data_args.template)
|
||||||
|
self.generating_args = generating_args.to_dict()
|
||||||
|
|
||||||
|
async def _generate(
|
||||||
|
self,
|
||||||
|
messages: Sequence[Dict[str, str]],
|
||||||
|
system: Optional[str] = None,
|
||||||
|
tools: Optional[str] = None,
|
||||||
|
**input_kwargs,
|
||||||
|
) -> AsyncIterator["RequestOutput"]:
|
||||||
|
request_id = "chatcmpl-{}".format(uuid.uuid4().hex)
|
||||||
|
paired_messages = messages + [{"role": "assistant", "content": ""}]
|
||||||
|
prompt_ids, _ = self.template.encode_oneturn(
|
||||||
|
tokenizer=self.tokenizer, messages=paired_messages, system=system, tools=tools
|
||||||
|
)
|
||||||
|
prompt_length = len(prompt_ids)
|
||||||
|
|
||||||
|
temperature = input_kwargs.pop("temperature", None)
|
||||||
|
top_p = input_kwargs.pop("top_p", None)
|
||||||
|
top_k = input_kwargs.pop("top_k", None)
|
||||||
|
num_return_sequences = input_kwargs.pop("num_return_sequences", None)
|
||||||
|
repetition_penalty = input_kwargs.pop("repetition_penalty", None)
|
||||||
|
max_length = input_kwargs.pop("max_length", None)
|
||||||
|
max_new_tokens = input_kwargs.pop("max_new_tokens", None)
|
||||||
|
|
||||||
|
generating_args = self.generating_args.copy()
|
||||||
|
generating_args.update(
|
||||||
|
dict(
|
||||||
|
temperature=temperature or generating_args["temperature"],
|
||||||
|
top_p=top_p or generating_args["top_p"],
|
||||||
|
top_k=top_k or generating_args["top_k"],
|
||||||
|
num_return_sequences=num_return_sequences or 1,
|
||||||
|
repetition_penalty=repetition_penalty or generating_args["repetition_penalty"],
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
if max_length:
|
||||||
|
generating_args["max_new_tokens"] = max_length - prompt_length
|
||||||
|
|
||||||
|
if max_new_tokens:
|
||||||
|
generating_args["max_new_tokens"] = max_new_tokens
|
||||||
|
|
||||||
|
sampling_params = SamplingParams(
|
||||||
|
n=generating_args["num_return_sequences"],
|
||||||
|
repetition_penalty=generating_args["repetition_penalty"],
|
||||||
|
temperature=generating_args["temperature"],
|
||||||
|
top_p=generating_args["top_p"],
|
||||||
|
top_k=generating_args["top_k"],
|
||||||
|
use_beam_search=generating_args["num_beams"] > 1,
|
||||||
|
length_penalty=generating_args["length_penalty"],
|
||||||
|
stop_token_ids=[self.tokenizer.eos_token_id] + self.tokenizer.additional_special_tokens_ids,
|
||||||
|
max_tokens=generating_args["max_new_tokens"],
|
||||||
|
skip_special_tokens=True,
|
||||||
|
)
|
||||||
|
result_generator = self.model.generate(
|
||||||
|
prompt=None, sampling_params=sampling_params, request_id=request_id, prompt_token_ids=prompt_ids
|
||||||
|
)
|
||||||
|
return result_generator
|
||||||
|
|
||||||
|
async def start(self) -> None:
|
||||||
|
pass
|
||||||
|
|
||||||
|
async def chat(
|
||||||
|
self,
|
||||||
|
messages: Sequence[Dict[str, str]],
|
||||||
|
system: Optional[str] = None,
|
||||||
|
tools: Optional[str] = None,
|
||||||
|
**input_kwargs,
|
||||||
|
) -> List["Response"]:
|
||||||
|
final_output = None
|
||||||
|
generator = await self._generate(messages, system, tools, **input_kwargs)
|
||||||
|
async for request_output in generator:
|
||||||
|
final_output = request_output
|
||||||
|
|
||||||
|
results = []
|
||||||
|
for output in final_output.outputs:
|
||||||
|
results.append(
|
||||||
|
Response(
|
||||||
|
response_text=output.text,
|
||||||
|
response_length=len(output.token_ids),
|
||||||
|
prompt_length=len(final_output.prompt_token_ids),
|
||||||
|
finish_reason=output.finish_reason,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
return results
|
||||||
|
|
||||||
|
async def stream_chat(
|
||||||
|
self,
|
||||||
|
messages: Sequence[Dict[str, str]],
|
||||||
|
system: Optional[str] = None,
|
||||||
|
tools: Optional[str] = None,
|
||||||
|
**input_kwargs,
|
||||||
|
) -> AsyncGenerator[str, None]:
|
||||||
|
generated_text = ""
|
||||||
|
generator = await self._generate(messages, system, tools, **input_kwargs)
|
||||||
|
async for result in generator:
|
||||||
|
delta_text = result.outputs[0].text[len(generated_text) :]
|
||||||
|
generated_text = result.outputs[0].text
|
||||||
|
yield delta_text
|
||||||
|
|
||||||
|
async def get_scores(
|
||||||
|
self,
|
||||||
|
batch_input: List[str],
|
||||||
|
**input_kwargs,
|
||||||
|
) -> List[float]:
|
||||||
|
raise NotImplementedError("vLLM engine does not support get_scores.")
|
||||||
@@ -1,4 +1,15 @@
|
|||||||
from llmtuner.data.loader import get_dataset
|
from .collator import PairwiseDataCollatorWithPadding
|
||||||
from llmtuner.data.preprocess import preprocess_dataset
|
from .loader import get_dataset
|
||||||
from llmtuner.data.template import get_template_and_fix_tokenizer
|
from .template import Template, get_template_and_fix_tokenizer, templates
|
||||||
from llmtuner.data.utils import split_dataset
|
from .utils import Role, split_dataset
|
||||||
|
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"PairwiseDataCollatorWithPadding",
|
||||||
|
"get_dataset",
|
||||||
|
"Template",
|
||||||
|
"get_template_and_fix_tokenizer",
|
||||||
|
"templates",
|
||||||
|
"Role",
|
||||||
|
"split_dataset",
|
||||||
|
]
|
||||||
|
|||||||
133
src/llmtuner/data/aligner.py
Normal file
133
src/llmtuner/data/aligner.py
Normal file
@@ -0,0 +1,133 @@
|
|||||||
|
from functools import partial
|
||||||
|
from typing import TYPE_CHECKING, Any, Dict, List, Union
|
||||||
|
|
||||||
|
from datasets import Features
|
||||||
|
|
||||||
|
from .utils import Role
|
||||||
|
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from datasets import Dataset, IterableDataset
|
||||||
|
|
||||||
|
from ..hparams import DataArguments
|
||||||
|
from .parser import DatasetAttr
|
||||||
|
|
||||||
|
|
||||||
|
def convert_alpaca(examples: Dict[str, List[Any]], dataset_attr: "DatasetAttr") -> Dict[str, List[Any]]:
|
||||||
|
outputs = {"prompt": [], "response": [], "system": [], "tools": []}
|
||||||
|
for i in range(len(examples[dataset_attr.prompt])):
|
||||||
|
prompt = []
|
||||||
|
if dataset_attr.history and isinstance(examples[dataset_attr.history][i], list):
|
||||||
|
for old_prompt, old_response in examples[dataset_attr.history][i]:
|
||||||
|
prompt.append({"role": Role.USER.value, "content": old_prompt})
|
||||||
|
prompt.append({"role": Role.ASSISTANT.value, "content": old_response})
|
||||||
|
|
||||||
|
content = []
|
||||||
|
if dataset_attr.prompt and examples[dataset_attr.prompt][i]:
|
||||||
|
content.append(examples[dataset_attr.prompt][i])
|
||||||
|
|
||||||
|
if dataset_attr.query and examples[dataset_attr.query][i]:
|
||||||
|
content.append(examples[dataset_attr.query][i])
|
||||||
|
|
||||||
|
prompt.append({"role": Role.USER.value, "content": "\n".join(content)})
|
||||||
|
|
||||||
|
if dataset_attr.response and isinstance(examples[dataset_attr.response][i], list):
|
||||||
|
response = [
|
||||||
|
{"role": Role.ASSISTANT.value, "content": content} for content in examples[dataset_attr.response][i]
|
||||||
|
]
|
||||||
|
elif dataset_attr.response and isinstance(examples[dataset_attr.response][i], str):
|
||||||
|
response = [{"role": Role.ASSISTANT.value, "content": examples[dataset_attr.response][i]}]
|
||||||
|
else:
|
||||||
|
response = []
|
||||||
|
|
||||||
|
outputs["prompt"].append(prompt)
|
||||||
|
outputs["response"].append(response)
|
||||||
|
outputs["system"].append(examples[dataset_attr.system][i] if dataset_attr.system else "")
|
||||||
|
outputs["tools"].append("")
|
||||||
|
|
||||||
|
return outputs
|
||||||
|
|
||||||
|
|
||||||
|
def convert_sharegpt(examples: Dict[str, List[Any]], dataset_attr: "DatasetAttr") -> Dict[str, List[Any]]:
|
||||||
|
outputs = {"prompt": [], "response": [], "system": [], "tools": []}
|
||||||
|
tag_mapping = {
|
||||||
|
dataset_attr.user_tag: Role.USER.value,
|
||||||
|
dataset_attr.assistant_tag: Role.ASSISTANT.value,
|
||||||
|
dataset_attr.observation_tag: Role.OBSERVATION.value,
|
||||||
|
dataset_attr.function_tag: Role.FUNCTION.value,
|
||||||
|
dataset_attr.system_tag: Role.SYSTEM.value,
|
||||||
|
}
|
||||||
|
odd_tags = (dataset_attr.user_tag, dataset_attr.observation_tag)
|
||||||
|
even_tags = (dataset_attr.assistant_tag, dataset_attr.function_tag)
|
||||||
|
accept_tags = (odd_tags, even_tags)
|
||||||
|
for i, messages in enumerate(examples[dataset_attr.messages]):
|
||||||
|
if dataset_attr.system_tag and messages[0][dataset_attr.role_tag] == dataset_attr.system_tag:
|
||||||
|
system = messages[0][dataset_attr.content_tag]
|
||||||
|
messages = messages[1:]
|
||||||
|
else:
|
||||||
|
system = examples[dataset_attr.system][i] if dataset_attr.system else ""
|
||||||
|
|
||||||
|
messages = messages[: len(messages) // 2 * 2] # should be multiples of 2
|
||||||
|
if len(messages) == 0:
|
||||||
|
continue
|
||||||
|
|
||||||
|
aligned_messages = []
|
||||||
|
for turn_idx, message in enumerate(messages):
|
||||||
|
if message[dataset_attr.role_tag] not in accept_tags[turn_idx % 2]:
|
||||||
|
raise ValueError("Invalid role tag in {}.".format(messages))
|
||||||
|
|
||||||
|
aligned_messages.append(
|
||||||
|
{"role": tag_mapping[message[dataset_attr.role_tag]], "content": message[dataset_attr.content_tag]}
|
||||||
|
)
|
||||||
|
|
||||||
|
outputs["prompt"].append(aligned_messages[:-1])
|
||||||
|
outputs["response"].append(aligned_messages[-1:])
|
||||||
|
outputs["system"].append(system)
|
||||||
|
outputs["tools"].append(examples[dataset_attr.tools][i] if dataset_attr.tools else "")
|
||||||
|
|
||||||
|
return outputs
|
||||||
|
|
||||||
|
|
||||||
|
def align_dataset(
|
||||||
|
dataset: Union["Dataset", "IterableDataset"], dataset_attr: "DatasetAttr", data_args: "DataArguments"
|
||||||
|
) -> Union["Dataset", "IterableDataset"]:
|
||||||
|
r"""
|
||||||
|
Aligned dataset:
|
||||||
|
prompt: [{"role": "user", "content": "..."}] * (2T - 1)
|
||||||
|
response: [{"role": "assistant", "content": "..."}] * N (N > 1 for ranking dataset)
|
||||||
|
system: "..."
|
||||||
|
tools: "..."
|
||||||
|
"""
|
||||||
|
if dataset_attr.formatting == "alpaca":
|
||||||
|
convert_func = partial(convert_alpaca, dataset_attr=dataset_attr)
|
||||||
|
else:
|
||||||
|
convert_func = partial(convert_sharegpt, dataset_attr=dataset_attr)
|
||||||
|
|
||||||
|
column_names = list(next(iter(dataset)).keys())
|
||||||
|
features = Features.from_dict(
|
||||||
|
{
|
||||||
|
"prompt": [
|
||||||
|
{"role": {"dtype": "string", "_type": "Value"}, "content": {"dtype": "string", "_type": "Value"}}
|
||||||
|
],
|
||||||
|
"response": [
|
||||||
|
{"role": {"dtype": "string", "_type": "Value"}, "content": {"dtype": "string", "_type": "Value"}}
|
||||||
|
],
|
||||||
|
"system": {"dtype": "string", "_type": "Value"},
|
||||||
|
"tools": {"dtype": "string", "_type": "Value"},
|
||||||
|
}
|
||||||
|
)
|
||||||
|
kwargs = {}
|
||||||
|
if not data_args.streaming:
|
||||||
|
kwargs = dict(
|
||||||
|
num_proc=data_args.preprocessing_num_workers,
|
||||||
|
load_from_cache_file=(not data_args.overwrite_cache),
|
||||||
|
desc="Converting format of dataset",
|
||||||
|
)
|
||||||
|
|
||||||
|
return dataset.map(
|
||||||
|
convert_func,
|
||||||
|
batched=True,
|
||||||
|
remove_columns=column_names,
|
||||||
|
features=features,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
@@ -1,16 +1,20 @@
|
|||||||
import torch
|
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Any, Dict, List, Sequence, Tuple
|
from typing import Any, Dict, List, Sequence, Tuple
|
||||||
|
|
||||||
|
import torch
|
||||||
from transformers import DataCollatorForSeq2Seq
|
from transformers import DataCollatorForSeq2Seq
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class DPODataCollatorWithPadding(DataCollatorForSeq2Seq):
|
class PairwiseDataCollatorWithPadding(DataCollatorForSeq2Seq):
|
||||||
r"""
|
r"""
|
||||||
Data collator for pairwise data.
|
Data collator for pairwise data.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def _pad_labels(self, batch: torch.Tensor, positions: List[Tuple[int, int]]) -> torch.Tensor:
|
def _pad_labels(self, batch: torch.Tensor, positions: List[Tuple[int, int]]) -> torch.Tensor:
|
||||||
|
r"""
|
||||||
|
Masks out the input ids except for the responses.
|
||||||
|
"""
|
||||||
padded_labels = []
|
padded_labels = []
|
||||||
for feature, (prompt_len, answer_len) in zip(batch, positions):
|
for feature, (prompt_len, answer_len) in zip(batch, positions):
|
||||||
if self.tokenizer.padding_side == "left":
|
if self.tokenizer.padding_side == "left":
|
||||||
@@ -20,7 +24,7 @@ class DPODataCollatorWithPadding(DataCollatorForSeq2Seq):
|
|||||||
padded_tensor = self.label_pad_token_id * torch.ones_like(feature)
|
padded_tensor = self.label_pad_token_id * torch.ones_like(feature)
|
||||||
padded_tensor[start:end] = feature[start:end]
|
padded_tensor[start:end] = feature[start:end]
|
||||||
padded_labels.append(padded_tensor)
|
padded_labels.append(padded_tensor)
|
||||||
return torch.stack(padded_labels, dim=0).contiguous() # in contiguous memory
|
return torch.stack(padded_labels, dim=0).contiguous() # in contiguous memory
|
||||||
|
|
||||||
def __call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, torch.Tensor]:
|
def __call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, torch.Tensor]:
|
||||||
r"""
|
r"""
|
||||||
@@ -34,18 +38,14 @@ class DPODataCollatorWithPadding(DataCollatorForSeq2Seq):
|
|||||||
for key in ("chosen_ids", "rejected_ids"):
|
for key in ("chosen_ids", "rejected_ids"):
|
||||||
for feature in features:
|
for feature in features:
|
||||||
prompt_len, answer_len = len(feature["prompt_ids"]), len(feature[key])
|
prompt_len, answer_len = len(feature["prompt_ids"]), len(feature[key])
|
||||||
concatenated_features.append({
|
concatenated_features.append(
|
||||||
"input_ids": feature["prompt_ids"] + feature[key],
|
{
|
||||||
"attention_mask": [1] * (prompt_len + answer_len)
|
"input_ids": feature["prompt_ids"] + feature[key],
|
||||||
})
|
"attention_mask": [1] * (prompt_len + answer_len),
|
||||||
|
}
|
||||||
|
)
|
||||||
label_positions.append((prompt_len, answer_len))
|
label_positions.append((prompt_len, answer_len))
|
||||||
|
|
||||||
batch = self.tokenizer.pad(
|
batch = super().__call__(concatenated_features)
|
||||||
concatenated_features,
|
|
||||||
padding=self.padding,
|
|
||||||
max_length=self.max_length,
|
|
||||||
pad_to_multiple_of=self.pad_to_multiple_of,
|
|
||||||
return_tensors=self.return_tensors,
|
|
||||||
)
|
|
||||||
batch["labels"] = self._pad_labels(batch["input_ids"], label_positions)
|
batch["labels"] = self._pad_labels(batch["input_ids"], label_positions)
|
||||||
return batch
|
return batch
|
||||||
187
src/llmtuner/data/formatter.py
Normal file
187
src/llmtuner/data/formatter.py
Normal file
@@ -0,0 +1,187 @@
|
|||||||
|
import json
|
||||||
|
import re
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from typing import Any, Dict, List, Literal, Optional, Sequence, Set, Tuple, Union
|
||||||
|
|
||||||
|
|
||||||
|
SLOTS = Sequence[Union[str, Set[str], Dict[str, str]]]
|
||||||
|
|
||||||
|
|
||||||
|
JSON_FORMAT_PROMPT = (
|
||||||
|
""", in a JSON format representing the kwargs (e.g. ```{"input": "hello world", "num_beams": 5}```)"""
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
TOOL_SYSTEM_PROMPT = (
|
||||||
|
"You have access to the following tools:\n{tool_text}"
|
||||||
|
"Use the following format if using a tool:\n"
|
||||||
|
"```\n"
|
||||||
|
"Action: tool name (one of [{tool_names}]).\n"
|
||||||
|
"Action Input: the input to the tool{format_prompt}.\n"
|
||||||
|
"```\n"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def default_tool_formatter(tools: List[Dict[str, Any]]) -> str:
|
||||||
|
tool_text = ""
|
||||||
|
tool_names = []
|
||||||
|
for tool in tools:
|
||||||
|
param_text = ""
|
||||||
|
for name, param in tool["parameters"]["properties"].items():
|
||||||
|
required = ", required" if name in tool["parameters"].get("required", []) else ""
|
||||||
|
enum = ", should be one of [{}]".format(", ".join(param["enum"])) if param.get("enum", None) else ""
|
||||||
|
items = (
|
||||||
|
", where each item should be {}".format(param["items"].get("type", "")) if param.get("items") else ""
|
||||||
|
)
|
||||||
|
param_text += " - {name} ({type}{required}): {desc}{enum}{items}\n".format(
|
||||||
|
name=name,
|
||||||
|
type=param.get("type", ""),
|
||||||
|
required=required,
|
||||||
|
desc=param.get("description", ""),
|
||||||
|
enum=enum,
|
||||||
|
items=items,
|
||||||
|
)
|
||||||
|
|
||||||
|
tool_text += "> Tool Name: {name}\nTool Description: {desc}\nTool Args:\n{args}\n".format(
|
||||||
|
name=tool["name"], desc=tool.get("description", ""), args=param_text
|
||||||
|
)
|
||||||
|
tool_names.append(tool["name"])
|
||||||
|
|
||||||
|
return TOOL_SYSTEM_PROMPT.format(
|
||||||
|
tool_text=tool_text, tool_names=", ".join(tool_names), format_prompt=JSON_FORMAT_PROMPT
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def default_tool_extractor(content: str) -> Union[str, Tuple[str, str]]:
|
||||||
|
regex = re.compile(r"Action:\s*([a-zA-Z0-9_]+).*?Action Input:\s*(.*)", re.DOTALL)
|
||||||
|
action_match = re.search(regex, content)
|
||||||
|
if not action_match:
|
||||||
|
return content
|
||||||
|
|
||||||
|
tool_name = action_match.group(1).strip()
|
||||||
|
tool_input = action_match.group(2).strip().strip('"').strip("```")
|
||||||
|
try:
|
||||||
|
arguments = json.loads(tool_input)
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
return content
|
||||||
|
|
||||||
|
return tool_name, json.dumps(arguments, ensure_ascii=False)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class Formatter(ABC):
|
||||||
|
slots: SLOTS = field(default_factory=list)
|
||||||
|
tool_format: Optional[Literal["default"]] = None
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def apply(self, **kwargs) -> SLOTS: ...
|
||||||
|
|
||||||
|
def extract(self, content: str) -> Union[str, Tuple[str, str]]:
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class EmptyFormatter(Formatter):
|
||||||
|
def __post_init__(self):
|
||||||
|
has_placeholder = False
|
||||||
|
for slot in filter(lambda s: isinstance(s, str), self.slots):
|
||||||
|
if re.search(r"\{\{[a-zA-Z_][a-zA-Z0-9_]*\}\}", slot):
|
||||||
|
has_placeholder = True
|
||||||
|
|
||||||
|
if has_placeholder:
|
||||||
|
raise ValueError("Empty formatter should not contain any placeholder.")
|
||||||
|
|
||||||
|
def apply(self, **kwargs) -> SLOTS:
|
||||||
|
return self.slots
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class StringFormatter(Formatter):
|
||||||
|
def __post_init__(self):
|
||||||
|
has_placeholder = False
|
||||||
|
for slot in filter(lambda s: isinstance(s, str), self.slots):
|
||||||
|
if re.search(r"\{\{[a-zA-Z_][a-zA-Z0-9_]*\}\}", slot):
|
||||||
|
has_placeholder = True
|
||||||
|
|
||||||
|
if not has_placeholder:
|
||||||
|
raise ValueError("A placeholder is required in the string formatter.")
|
||||||
|
|
||||||
|
def apply(self, **kwargs) -> SLOTS:
|
||||||
|
elements = []
|
||||||
|
for slot in self.slots:
|
||||||
|
if isinstance(slot, str):
|
||||||
|
for name, value in kwargs.items():
|
||||||
|
if not isinstance(value, str):
|
||||||
|
raise RuntimeError("Expected a string, got {}".format(value))
|
||||||
|
|
||||||
|
slot = slot.replace("{{" + name + "}}", value, 1)
|
||||||
|
elements.append(slot)
|
||||||
|
elif isinstance(slot, (dict, set)):
|
||||||
|
elements.append(slot)
|
||||||
|
else:
|
||||||
|
raise RuntimeError("Input must be string, set[str] or dict[str, str], got {}".format(type(slot)))
|
||||||
|
|
||||||
|
return elements
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class FunctionFormatter(Formatter):
|
||||||
|
def __post_init__(self):
|
||||||
|
has_name, has_args = False, False
|
||||||
|
for slot in filter(lambda s: isinstance(s, str), self.slots):
|
||||||
|
if "{{name}}" in slot:
|
||||||
|
has_name = True
|
||||||
|
if "{{arguments}}" in slot:
|
||||||
|
has_args = True
|
||||||
|
|
||||||
|
if not has_name or not has_args:
|
||||||
|
raise ValueError("Name and arguments placeholders are required in the function formatter.")
|
||||||
|
|
||||||
|
def apply(self, **kwargs) -> SLOTS:
|
||||||
|
content = kwargs.pop("content")
|
||||||
|
try:
|
||||||
|
function = json.loads(content)
|
||||||
|
name = function["name"]
|
||||||
|
arguments = json.dumps(function["arguments"], ensure_ascii=False)
|
||||||
|
except Exception:
|
||||||
|
name, arguments = "", ""
|
||||||
|
|
||||||
|
elements = []
|
||||||
|
for slot in self.slots:
|
||||||
|
if isinstance(slot, str):
|
||||||
|
slot = slot.replace("{{name}}", name).replace("{{arguments}}", arguments)
|
||||||
|
elements.append(slot)
|
||||||
|
elif isinstance(slot, (dict, set)):
|
||||||
|
elements.append(slot)
|
||||||
|
else:
|
||||||
|
raise RuntimeError("Input must be string, set[str] or dict[str, str], got {}".format(type(slot)))
|
||||||
|
|
||||||
|
return elements
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ToolFormatter(Formatter):
|
||||||
|
def __post_init__(self):
|
||||||
|
if self.tool_format is None:
|
||||||
|
raise ValueError("Tool format was not found.")
|
||||||
|
|
||||||
|
def apply(self, **kwargs) -> SLOTS:
|
||||||
|
content = kwargs.pop("content")
|
||||||
|
try:
|
||||||
|
tools = json.loads(content)
|
||||||
|
if not len(tools):
|
||||||
|
return [""]
|
||||||
|
|
||||||
|
if self.tool_format == "default":
|
||||||
|
return [default_tool_formatter(tools)]
|
||||||
|
else:
|
||||||
|
raise NotImplementedError
|
||||||
|
except Exception:
|
||||||
|
return [""]
|
||||||
|
|
||||||
|
def extract(self, content: str) -> Union[str, Tuple[str, str]]:
|
||||||
|
if self.tool_format == "default":
|
||||||
|
return default_tool_extractor(content)
|
||||||
|
else:
|
||||||
|
raise NotImplementedError
|
||||||
@@ -1,163 +1,179 @@
|
|||||||
|
import inspect
|
||||||
import os
|
import os
|
||||||
from typing import TYPE_CHECKING, Any, Dict, List, Union
|
from typing import TYPE_CHECKING, Literal, Union
|
||||||
|
|
||||||
from datasets import concatenate_datasets, interleave_datasets, load_dataset
|
from datasets import load_dataset, load_from_disk
|
||||||
|
|
||||||
|
from ..extras.constants import FILEEXT2TYPE
|
||||||
|
from ..extras.logging import get_logger
|
||||||
|
from ..extras.misc import has_tokenized_data
|
||||||
|
from .aligner import align_dataset
|
||||||
|
from .parser import get_dataset_list
|
||||||
|
from .preprocess import get_preprocess_and_print_func
|
||||||
|
from .template import get_template_and_fix_tokenizer
|
||||||
|
from .utils import checksum, merge_dataset
|
||||||
|
|
||||||
from llmtuner.data.utils import checksum
|
|
||||||
from llmtuner.extras.constants import FILEEXT2TYPE
|
|
||||||
from llmtuner.extras.logging import get_logger
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from datasets import Dataset, IterableDataset
|
from datasets import Dataset, IterableDataset
|
||||||
from llmtuner.hparams import ModelArguments, DataArguments
|
from transformers import Seq2SeqTrainingArguments
|
||||||
|
from transformers.tokenization_utils import PreTrainedTokenizer
|
||||||
|
|
||||||
|
from ..hparams import DataArguments, ModelArguments
|
||||||
|
from .parser import DatasetAttr
|
||||||
|
|
||||||
|
|
||||||
logger = get_logger(__name__)
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def get_dataset(
|
def load_single_dataset(
|
||||||
|
dataset_attr: "DatasetAttr",
|
||||||
model_args: "ModelArguments",
|
model_args: "ModelArguments",
|
||||||
data_args: "DataArguments"
|
data_args: "DataArguments",
|
||||||
) -> Union["Dataset", "IterableDataset"]:
|
) -> Union["Dataset", "IterableDataset"]:
|
||||||
max_samples = data_args.max_samples
|
logger.info("Loading dataset {}...".format(dataset_attr))
|
||||||
all_datasets: List[Union["Dataset", "IterableDataset"]] = [] # support multiple datasets
|
data_path, data_name, data_dir, data_files = None, None, None, None
|
||||||
|
if dataset_attr.load_from in ["hf_hub", "ms_hub"]:
|
||||||
|
data_path = dataset_attr.dataset_name
|
||||||
|
data_name = dataset_attr.subset
|
||||||
|
data_dir = dataset_attr.folder
|
||||||
|
|
||||||
for dataset_attr in data_args.dataset_list:
|
elif dataset_attr.load_from == "script":
|
||||||
logger.info("Loading dataset {}...".format(dataset_attr))
|
data_path = os.path.join(data_args.dataset_dir, dataset_attr.dataset_name)
|
||||||
|
data_name = dataset_attr.subset
|
||||||
|
data_dir = dataset_attr.folder
|
||||||
|
|
||||||
data_path, data_name, data_dir, data_files = None, None, None, None
|
elif dataset_attr.load_from == "file":
|
||||||
if dataset_attr.load_from in ["hf_hub", "ms_hub"]:
|
data_files = []
|
||||||
data_path = dataset_attr.dataset_name
|
local_path = os.path.join(data_args.dataset_dir, dataset_attr.dataset_name)
|
||||||
data_name = dataset_attr.subset
|
if os.path.isdir(local_path): # is directory
|
||||||
data_dir = dataset_attr.folder
|
for file_name in os.listdir(local_path):
|
||||||
elif dataset_attr.load_from == "script":
|
data_files.append(os.path.join(local_path, file_name))
|
||||||
data_path = os.path.join(data_args.dataset_dir, dataset_attr.dataset_name)
|
if data_path is None:
|
||||||
data_name = dataset_attr.subset
|
data_path = FILEEXT2TYPE.get(file_name.split(".")[-1], None)
|
||||||
elif dataset_attr.load_from == "file":
|
elif data_path != FILEEXT2TYPE.get(file_name.split(".")[-1], None):
|
||||||
data_files = []
|
raise ValueError("File types should be identical.")
|
||||||
local_path: str = os.path.join(data_args.dataset_dir, dataset_attr.dataset_name)
|
elif os.path.isfile(local_path): # is file
|
||||||
if os.path.isdir(local_path): # is directory
|
data_files.append(local_path)
|
||||||
for file_name in os.listdir(local_path):
|
data_path = FILEEXT2TYPE.get(local_path.split(".")[-1], None)
|
||||||
data_files.append(os.path.join(local_path, file_name))
|
|
||||||
if data_path is None:
|
|
||||||
data_path = FILEEXT2TYPE.get(file_name.split(".")[-1], None)
|
|
||||||
else:
|
|
||||||
assert data_path == FILEEXT2TYPE.get(file_name.split(".")[-1], None), "file types are not identical."
|
|
||||||
elif os.path.isfile(local_path): # is file
|
|
||||||
data_files.append(local_path)
|
|
||||||
data_path = FILEEXT2TYPE.get(local_path.split(".")[-1], None)
|
|
||||||
else:
|
|
||||||
raise ValueError("File not found.")
|
|
||||||
|
|
||||||
assert data_path, "File extension must be txt, csv, json or jsonl."
|
|
||||||
checksum(data_files, dataset_attr.dataset_sha1)
|
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError
|
raise ValueError("File not found.")
|
||||||
|
|
||||||
if dataset_attr.load_from == "ms_hub":
|
if data_path is None:
|
||||||
try:
|
raise ValueError("File extension must be txt, csv, json or jsonl.")
|
||||||
from modelscope import MsDataset # type: ignore
|
|
||||||
from modelscope.utils.config_ds import MS_DATASETS_CACHE # type: ignore
|
|
||||||
|
|
||||||
cache_dir = model_args.cache_dir or MS_DATASETS_CACHE
|
checksum(data_files, dataset_attr.file_sha1)
|
||||||
dataset = MsDataset.load(
|
else:
|
||||||
dataset_name=data_path,
|
raise NotImplementedError
|
||||||
subset_name=data_name,
|
|
||||||
data_dir=data_dir,
|
if dataset_attr.load_from == "ms_hub":
|
||||||
data_files=data_files,
|
try:
|
||||||
split=data_args.split,
|
from modelscope import MsDataset
|
||||||
cache_dir=cache_dir,
|
from modelscope.utils.config_ds import MS_DATASETS_CACHE
|
||||||
token=model_args.ms_hub_token,
|
|
||||||
use_streaming=(data_args.streaming and (dataset_attr.load_from != "file")),
|
cache_dir = model_args.cache_dir or MS_DATASETS_CACHE
|
||||||
).to_hf_dataset()
|
dataset = MsDataset.load(
|
||||||
except ImportError:
|
dataset_name=data_path,
|
||||||
raise ImportError("Please install modelscope via `pip install modelscope -U`")
|
subset_name=data_name,
|
||||||
else:
|
|
||||||
dataset = load_dataset(
|
|
||||||
path=data_path,
|
|
||||||
name=data_name,
|
|
||||||
data_dir=data_dir,
|
data_dir=data_dir,
|
||||||
data_files=data_files,
|
data_files=data_files,
|
||||||
split=data_args.split,
|
split=data_args.split,
|
||||||
cache_dir=model_args.cache_dir,
|
cache_dir=cache_dir,
|
||||||
token=model_args.hf_hub_token,
|
token=model_args.ms_hub_token,
|
||||||
streaming=(data_args.streaming and (dataset_attr.load_from != "file"))
|
use_streaming=(data_args.streaming and (dataset_attr.load_from != "file")),
|
||||||
)
|
)
|
||||||
|
if isinstance(dataset, MsDataset):
|
||||||
if data_args.streaming and (dataset_attr.load_from == "file"): # faster than specifying streaming=True
|
dataset = dataset.to_hf_dataset()
|
||||||
dataset = dataset.to_iterable_dataset() # TODO: add num shards parameter
|
except ImportError:
|
||||||
|
raise ImportError("Please install modelscope via `pip install modelscope -U`")
|
||||||
if max_samples is not None: # truncate dataset
|
|
||||||
dataset = dataset.select(range(min(len(dataset), max_samples)))
|
|
||||||
|
|
||||||
def convert_format(examples: Dict[str, List[Any]]) -> Dict[str, List[Any]]:
|
|
||||||
# convert dataset from sharegpt format to alpaca format
|
|
||||||
outputs = {"prompt": [], "query": [], "response": [], "history": [], "system": []}
|
|
||||||
for i, msg_list in enumerate(examples[dataset_attr.messages]):
|
|
||||||
msg_list = msg_list[:len(msg_list) // 2 * 2] # should be multiples of 2
|
|
||||||
if len(msg_list) == 0:
|
|
||||||
continue
|
|
||||||
|
|
||||||
msg_pairs = []
|
|
||||||
user_role, assistant_role = None, None
|
|
||||||
for idx in range(0, len(msg_list), 2):
|
|
||||||
if user_role is None and assistant_role is None:
|
|
||||||
user_role = msg_list[idx][dataset_attr.role]
|
|
||||||
assistant_role = msg_list[idx + 1][dataset_attr.role]
|
|
||||||
else:
|
|
||||||
if (
|
|
||||||
msg_list[idx][dataset_attr.role] != user_role
|
|
||||||
or msg_list[idx+1][dataset_attr.role] != assistant_role
|
|
||||||
):
|
|
||||||
raise ValueError("Only accepts conversation in u/a/u/a/u/a order.")
|
|
||||||
msg_pairs.append((msg_list[idx][dataset_attr.content], msg_list[idx + 1][dataset_attr.content]))
|
|
||||||
|
|
||||||
if len(msg_pairs) != 0:
|
|
||||||
outputs["prompt"].append(msg_pairs[-1][0])
|
|
||||||
outputs["query"].append("")
|
|
||||||
outputs["response"].append(msg_pairs[-1][1])
|
|
||||||
outputs["history"].append(msg_pairs[:-1] if len(msg_pairs) > 1 else None)
|
|
||||||
outputs["system"].append(examples[dataset_attr.system][i] if dataset_attr.system else "")
|
|
||||||
|
|
||||||
return outputs
|
|
||||||
|
|
||||||
if dataset_attr.formatting == "sharegpt": # convert format
|
|
||||||
column_names = list(next(iter(dataset)).keys())
|
|
||||||
kwargs = {}
|
|
||||||
if not data_args.streaming:
|
|
||||||
kwargs = dict(
|
|
||||||
num_proc=data_args.preprocessing_num_workers,
|
|
||||||
load_from_cache_file=(not data_args.overwrite_cache),
|
|
||||||
desc="Converting format of dataset"
|
|
||||||
)
|
|
||||||
|
|
||||||
dataset = dataset.map(
|
|
||||||
convert_format,
|
|
||||||
batched=True,
|
|
||||||
remove_columns=column_names,
|
|
||||||
**kwargs
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
for column_name in ["prompt", "query", "response", "history", "system"]: # align dataset
|
|
||||||
if getattr(dataset_attr, column_name) and getattr(dataset_attr, column_name) != column_name:
|
|
||||||
dataset = dataset.rename_column(getattr(dataset_attr, column_name), column_name)
|
|
||||||
|
|
||||||
all_datasets.append(dataset)
|
|
||||||
|
|
||||||
if len(data_args.dataset_list) == 1:
|
|
||||||
return all_datasets[0]
|
|
||||||
elif data_args.mix_strategy == "concat":
|
|
||||||
if data_args.streaming:
|
|
||||||
logger.warning("The samples between different datasets will not be mixed in streaming mode.")
|
|
||||||
return concatenate_datasets(all_datasets)
|
|
||||||
elif data_args.mix_strategy.startswith("interleave"):
|
|
||||||
if not data_args.streaming:
|
|
||||||
logger.warning("We recommend using `mix_strategy=concat` in non-streaming mode.")
|
|
||||||
return interleave_datasets(
|
|
||||||
datasets=all_datasets,
|
|
||||||
probabilities=data_args.interleave_probs,
|
|
||||||
seed=data_args.seed,
|
|
||||||
stopping_strategy="first_exhausted" if data_args.mix_strategy.endswith("under") else "all_exhausted"
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
raise ValueError("Unknown mixing strategy.")
|
if "trust_remote_code" in inspect.signature(load_dataset).parameters: # for datasets==2.16.0
|
||||||
|
kwargs = {"trust_remote_code": True}
|
||||||
|
else:
|
||||||
|
kwargs = {}
|
||||||
|
|
||||||
|
dataset = load_dataset(
|
||||||
|
path=data_path,
|
||||||
|
name=data_name,
|
||||||
|
data_dir=data_dir,
|
||||||
|
data_files=data_files,
|
||||||
|
split=data_args.split,
|
||||||
|
cache_dir=model_args.cache_dir,
|
||||||
|
token=model_args.hf_hub_token,
|
||||||
|
streaming=(data_args.streaming and (dataset_attr.load_from != "file")),
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
if data_args.streaming and (dataset_attr.load_from == "file"): # faster than specifying streaming=True
|
||||||
|
dataset = dataset.to_iterable_dataset() # TODO: add num shards parameter
|
||||||
|
|
||||||
|
if data_args.max_samples is not None: # truncate dataset
|
||||||
|
num_samples = min(data_args.max_samples, len(dataset))
|
||||||
|
dataset = dataset.select(range(num_samples))
|
||||||
|
|
||||||
|
return align_dataset(dataset, dataset_attr, data_args)
|
||||||
|
|
||||||
|
|
||||||
|
def get_dataset(
|
||||||
|
tokenizer: "PreTrainedTokenizer",
|
||||||
|
model_args: "ModelArguments",
|
||||||
|
data_args: "DataArguments",
|
||||||
|
training_args: "Seq2SeqTrainingArguments",
|
||||||
|
stage: Literal["pt", "sft", "rm", "ppo"],
|
||||||
|
) -> Union["Dataset", "IterableDataset"]:
|
||||||
|
template = get_template_and_fix_tokenizer(tokenizer, data_args.template)
|
||||||
|
if data_args.train_on_prompt and template.efficient_eos:
|
||||||
|
raise ValueError("Current template does not support `train_on_prompt`.")
|
||||||
|
|
||||||
|
# Load tokenized dataset
|
||||||
|
if data_args.tokenized_path is not None:
|
||||||
|
if has_tokenized_data(data_args.tokenized_path):
|
||||||
|
logger.warning("Loading dataset from disk will ignore other data arguments.")
|
||||||
|
dataset = load_from_disk(data_args.tokenized_path)
|
||||||
|
logger.info("Loaded tokenized dataset from {}.".format(data_args.tokenized_path))
|
||||||
|
if data_args.streaming:
|
||||||
|
dataset = dataset.to_iterable_dataset()
|
||||||
|
return dataset
|
||||||
|
|
||||||
|
if data_args.streaming:
|
||||||
|
raise ValueError("Turn off `streaming` when saving dataset to disk.")
|
||||||
|
|
||||||
|
with training_args.main_process_first(desc="load dataset"):
|
||||||
|
all_datasets = []
|
||||||
|
for dataset_attr in get_dataset_list(data_args):
|
||||||
|
if (stage == "rm" and dataset_attr.ranking is False) or (stage != "rm" and dataset_attr.ranking is True):
|
||||||
|
raise ValueError("The dataset is not applicable in the current training stage.")
|
||||||
|
|
||||||
|
all_datasets.append(load_single_dataset(dataset_attr, model_args, data_args))
|
||||||
|
dataset = merge_dataset(all_datasets, data_args, training_args)
|
||||||
|
|
||||||
|
with training_args.main_process_first(desc="pre-process dataset"):
|
||||||
|
preprocess_func, print_function = get_preprocess_and_print_func(
|
||||||
|
tokenizer, template, data_args, training_args, stage
|
||||||
|
)
|
||||||
|
column_names = list(next(iter(dataset)).keys())
|
||||||
|
kwargs = {}
|
||||||
|
if not data_args.streaming:
|
||||||
|
kwargs = dict(
|
||||||
|
num_proc=data_args.preprocessing_num_workers,
|
||||||
|
load_from_cache_file=(not data_args.overwrite_cache),
|
||||||
|
desc="Running tokenizer on dataset",
|
||||||
|
)
|
||||||
|
|
||||||
|
dataset = dataset.map(preprocess_func, batched=True, remove_columns=column_names, **kwargs)
|
||||||
|
|
||||||
|
if data_args.tokenized_path is not None:
|
||||||
|
if training_args.should_save:
|
||||||
|
dataset.save_to_disk(data_args.tokenized_path)
|
||||||
|
logger.info("Tokenized dataset saved at {}.".format(data_args.tokenized_path))
|
||||||
|
logger.info("Please restart the training with `--tokenized_path {}`.".format(data_args.tokenized_path))
|
||||||
|
|
||||||
|
exit(0)
|
||||||
|
|
||||||
|
if training_args.should_log:
|
||||||
|
try:
|
||||||
|
print_function(next(iter(dataset)))
|
||||||
|
except StopIteration:
|
||||||
|
raise RuntimeError("Cannot find valid samples, check `data/README.md` for the data format.")
|
||||||
|
|
||||||
|
return dataset
|
||||||
|
|||||||
132
src/llmtuner/data/parser.py
Normal file
132
src/llmtuner/data/parser.py
Normal file
@@ -0,0 +1,132 @@
|
|||||||
|
import json
|
||||||
|
import os
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional
|
||||||
|
|
||||||
|
from ..extras.constants import DATA_CONFIG
|
||||||
|
from ..extras.misc import use_modelscope
|
||||||
|
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from ..hparams import DataArguments
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class DatasetAttr:
|
||||||
|
r"""
|
||||||
|
Dataset attributes.
|
||||||
|
"""
|
||||||
|
|
||||||
|
""" basic configs """
|
||||||
|
load_from: Literal["hf_hub", "ms_hub", "script", "file"]
|
||||||
|
dataset_name: str
|
||||||
|
""" extra configs """
|
||||||
|
file_sha1: Optional[str] = None
|
||||||
|
subset: Optional[str] = None
|
||||||
|
folder: Optional[str] = None
|
||||||
|
ranking: bool = False
|
||||||
|
formatting: Literal["alpaca", "sharegpt"] = "alpaca"
|
||||||
|
""" columns """
|
||||||
|
system: Optional[str] = None
|
||||||
|
""" columns for the alpaca format """
|
||||||
|
prompt: Optional[str] = "instruction"
|
||||||
|
query: Optional[str] = "input"
|
||||||
|
response: Optional[str] = "output"
|
||||||
|
history: Optional[str] = None
|
||||||
|
""" columns for the sharegpt format """
|
||||||
|
messages: Optional[str] = "conversations"
|
||||||
|
tools: Optional[str] = None
|
||||||
|
""" tags for the sharegpt format """
|
||||||
|
role_tag: Optional[str] = "from"
|
||||||
|
content_tag: Optional[str] = "value"
|
||||||
|
user_tag: Optional[str] = "human"
|
||||||
|
assistant_tag: Optional[str] = "gpt"
|
||||||
|
observation_tag: Optional[str] = "observation"
|
||||||
|
function_tag: Optional[str] = "function_call"
|
||||||
|
system_tag: Optional[str] = "system"
|
||||||
|
|
||||||
|
def __repr__(self) -> str:
|
||||||
|
return self.dataset_name
|
||||||
|
|
||||||
|
def set_attr(self, key: str, obj: Dict[str, Any], default: Optional[Any] = None) -> None:
|
||||||
|
setattr(self, key, obj.get(key, default))
|
||||||
|
|
||||||
|
|
||||||
|
def get_dataset_list(data_args: "DataArguments") -> List["DatasetAttr"]:
|
||||||
|
if data_args.dataset is not None:
|
||||||
|
dataset_names = [ds.strip() for ds in data_args.dataset.split(",")]
|
||||||
|
else:
|
||||||
|
dataset_names = []
|
||||||
|
|
||||||
|
if data_args.dataset_dir == "ONLINE":
|
||||||
|
dataset_info = None
|
||||||
|
else:
|
||||||
|
try:
|
||||||
|
with open(os.path.join(data_args.dataset_dir, DATA_CONFIG), "r") as f:
|
||||||
|
dataset_info = json.load(f)
|
||||||
|
except Exception as err:
|
||||||
|
if len(dataset_names) != 0:
|
||||||
|
raise ValueError(
|
||||||
|
"Cannot open {} due to {}.".format(os.path.join(data_args.dataset_dir, DATA_CONFIG), str(err))
|
||||||
|
)
|
||||||
|
dataset_info = None
|
||||||
|
|
||||||
|
if data_args.interleave_probs is not None:
|
||||||
|
data_args.interleave_probs = [float(prob.strip()) for prob in data_args.interleave_probs.split(",")]
|
||||||
|
|
||||||
|
dataset_list: List[DatasetAttr] = []
|
||||||
|
for name in dataset_names:
|
||||||
|
if dataset_info is None:
|
||||||
|
load_from = "ms_hub" if use_modelscope() else "hf_hub"
|
||||||
|
dataset_attr = DatasetAttr(load_from, dataset_name=name)
|
||||||
|
dataset_list.append(dataset_attr)
|
||||||
|
continue
|
||||||
|
|
||||||
|
if name not in dataset_info:
|
||||||
|
raise ValueError("Undefined dataset {} in {}.".format(name, DATA_CONFIG))
|
||||||
|
|
||||||
|
has_hf_url = "hf_hub_url" in dataset_info[name]
|
||||||
|
has_ms_url = "ms_hub_url" in dataset_info[name]
|
||||||
|
|
||||||
|
if has_hf_url or has_ms_url:
|
||||||
|
if (use_modelscope() and has_ms_url) or (not has_hf_url):
|
||||||
|
dataset_attr = DatasetAttr("ms_hub", dataset_name=dataset_info[name]["ms_hub_url"])
|
||||||
|
else:
|
||||||
|
dataset_attr = DatasetAttr("hf_hub", dataset_name=dataset_info[name]["hf_hub_url"])
|
||||||
|
elif "script_url" in dataset_info[name]:
|
||||||
|
dataset_attr = DatasetAttr("script", dataset_name=dataset_info[name]["script_url"])
|
||||||
|
else:
|
||||||
|
dataset_attr = DatasetAttr("file", dataset_name=dataset_info[name]["file_name"])
|
||||||
|
|
||||||
|
dataset_attr.set_attr("file_sha1", dataset_info[name])
|
||||||
|
dataset_attr.set_attr("subset", dataset_info[name])
|
||||||
|
dataset_attr.set_attr("folder", dataset_info[name])
|
||||||
|
dataset_attr.set_attr("ranking", dataset_info[name], default=False)
|
||||||
|
dataset_attr.set_attr("formatting", dataset_info[name], default="alpaca")
|
||||||
|
|
||||||
|
if "columns" in dataset_info[name]:
|
||||||
|
column_names = ["system"]
|
||||||
|
if dataset_attr.formatting == "alpaca":
|
||||||
|
column_names.extend(["prompt", "query", "response", "history"])
|
||||||
|
else:
|
||||||
|
column_names.extend(["messages", "tools"])
|
||||||
|
|
||||||
|
for column_name in column_names:
|
||||||
|
dataset_attr.set_attr(column_name, dataset_info[name]["columns"])
|
||||||
|
|
||||||
|
if dataset_attr.formatting == "sharegpt" and "tags" in dataset_info[name]:
|
||||||
|
tag_names = (
|
||||||
|
"role_tag",
|
||||||
|
"content_tag",
|
||||||
|
"user_tag",
|
||||||
|
"assistant_tag",
|
||||||
|
"observation_tag",
|
||||||
|
"function_tag",
|
||||||
|
"system_tag",
|
||||||
|
)
|
||||||
|
for tag in tag_names:
|
||||||
|
dataset_attr.set_attr(tag, dataset_info[name]["tags"])
|
||||||
|
|
||||||
|
dataset_list.append(dataset_attr)
|
||||||
|
|
||||||
|
return dataset_list
|
||||||
@@ -1,275 +1,278 @@
|
|||||||
import os
|
from functools import partial
|
||||||
import tiktoken
|
|
||||||
from itertools import chain
|
from itertools import chain
|
||||||
from typing import TYPE_CHECKING, Any, Dict, Generator, List, Literal, Tuple, Union
|
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Literal, Tuple
|
||||||
|
|
||||||
from datasets import load_from_disk
|
from ..extras.constants import IGNORE_INDEX
|
||||||
|
from ..extras.logging import get_logger
|
||||||
|
from .utils import Role
|
||||||
|
|
||||||
from llmtuner.data.template import get_template_and_fix_tokenizer
|
|
||||||
from llmtuner.extras.constants import IGNORE_INDEX
|
|
||||||
from llmtuner.extras.logging import get_logger
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from datasets import Dataset, IterableDataset
|
|
||||||
from transformers import Seq2SeqTrainingArguments
|
from transformers import Seq2SeqTrainingArguments
|
||||||
from transformers.tokenization_utils import PreTrainedTokenizer
|
from transformers.tokenization_utils import PreTrainedTokenizer
|
||||||
from llmtuner.hparams import DataArguments
|
|
||||||
|
from ..hparams import DataArguments
|
||||||
|
from .template import Template
|
||||||
|
|
||||||
|
|
||||||
logger = get_logger(__name__)
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def construct_example(examples: Dict[str, List[Any]]) -> Generator[Any, None, None]:
|
def preprocess_pretrain_dataset(
|
||||||
for i in range(len(examples["prompt"])):
|
examples: Dict[str, List[Any]], tokenizer: "PreTrainedTokenizer", data_args: "DataArguments"
|
||||||
query, response = examples["prompt"][i], examples["response"][i]
|
) -> Dict[str, List[List[int]]]:
|
||||||
query = query + "\n" + examples["query"][i] if "query" in examples and examples["query"][i] else query
|
# build grouped texts with format `X1 X2 X3 ...` if packing is enabled
|
||||||
history = examples["history"][i] if "history" in examples else None
|
text_examples = [messages[0]["content"] + tokenizer.eos_token for messages in examples["prompt"]]
|
||||||
system = examples["system"][i] if "system" in examples else None
|
|
||||||
yield query, response, history, system
|
|
||||||
|
|
||||||
|
if not data_args.packing:
|
||||||
|
if data_args.template == "gemma":
|
||||||
|
text_examples = [tokenizer.bos_token + example for example in text_examples]
|
||||||
|
|
||||||
def infer_max_len(source_len: int, target_len: int, data_args: "DataArguments") -> Tuple[int, int]:
|
result = tokenizer(text_examples, add_special_tokens=False, max_length=data_args.cutoff_len)
|
||||||
max_target_len = int(data_args.cutoff_len * (target_len / (source_len + target_len)))
|
else:
|
||||||
max_target_len = max(max_target_len, data_args.reserved_label_len)
|
tokenized_examples = tokenizer(text_examples, add_special_tokens=False)
|
||||||
max_source_len = data_args.cutoff_len - max_target_len
|
|
||||||
return max_source_len, max_target_len
|
|
||||||
|
|
||||||
|
|
||||||
def preprocess_dataset(
|
|
||||||
dataset: Union["Dataset", "IterableDataset"],
|
|
||||||
tokenizer: "PreTrainedTokenizer",
|
|
||||||
data_args: "DataArguments",
|
|
||||||
training_args: "Seq2SeqTrainingArguments",
|
|
||||||
stage: Literal["pt", "sft", "rm", "ppo"]
|
|
||||||
) -> Union["Dataset", "IterableDataset"]:
|
|
||||||
template = get_template_and_fix_tokenizer(data_args.template, tokenizer)
|
|
||||||
|
|
||||||
if data_args.train_on_prompt and template.efficient_eos:
|
|
||||||
raise ValueError("Current template does not support `train_on_prompt`.")
|
|
||||||
|
|
||||||
def preprocess_pretrain_dataset(examples: Dict[str, List[Any]]) -> Dict[str, List[List[int]]]:
|
|
||||||
# build grouped texts with format `X1 X2 X3 ...`
|
|
||||||
if isinstance(getattr(tokenizer, "tokenizer", None), tiktoken.Encoding): # for tiktoken tokenizer (Qwen)
|
|
||||||
kwargs = dict(allowed_special="all")
|
|
||||||
else:
|
|
||||||
kwargs = dict(add_special_tokens=True)
|
|
||||||
|
|
||||||
if hasattr(tokenizer, "add_eos_token"): # for LLaMA tokenizer
|
|
||||||
add_eos_token_flag = getattr(tokenizer, "add_eos_token")
|
|
||||||
setattr(tokenizer, "add_eos_token", True)
|
|
||||||
|
|
||||||
tokenized_examples = tokenizer(examples["prompt"], **kwargs)
|
|
||||||
concatenated_examples = {k: list(chain(*tokenized_examples[k])) for k in tokenized_examples.keys()}
|
concatenated_examples = {k: list(chain(*tokenized_examples[k])) for k in tokenized_examples.keys()}
|
||||||
total_length = len(concatenated_examples[list(concatenated_examples.keys())[0]])
|
total_length = len(concatenated_examples[list(concatenated_examples.keys())[0]])
|
||||||
block_size = data_args.cutoff_len
|
block_size = data_args.cutoff_len
|
||||||
# we drop the small remainder, and if the total_length < block_size, we exclude this batch
|
|
||||||
total_length = (total_length // block_size) * block_size
|
total_length = (total_length // block_size) * block_size
|
||||||
# split by chunks of cutoff_len
|
|
||||||
result = {
|
result = {
|
||||||
k: [t[i: i + block_size] for i in range(0, total_length, block_size)]
|
k: [t[i : i + block_size] for i in range(0, total_length, block_size)]
|
||||||
for k, t in concatenated_examples.items()
|
for k, t in concatenated_examples.items()
|
||||||
}
|
}
|
||||||
# make sure the saved tokenizer is the same as the original one
|
if data_args.template == "gemma":
|
||||||
if hasattr(tokenizer, "add_eos_token"):
|
for i in range(len(result["input_ids"])):
|
||||||
setattr(tokenizer, "add_eos_token", add_eos_token_flag)
|
result["input_ids"][i][0] = tokenizer.bos_token_id
|
||||||
return result
|
|
||||||
|
|
||||||
def preprocess_supervised_dataset(examples: Dict[str, List[Any]]) -> Dict[str, List[List[int]]]:
|
return result
|
||||||
# build inputs with format `<bos> X Y <eos>` and labels with format `<ignore> ... <ignore> Y <eos>`
|
|
||||||
# for multiturn examples, we only mask the prompt part in each prompt-response pair.
|
|
||||||
model_inputs = {"input_ids": [], "attention_mask": [], "labels": []}
|
|
||||||
|
|
||||||
for query, response, history, system in construct_example(examples):
|
|
||||||
if not (isinstance(query, str) and isinstance(response, str) and query != "" and response != ""):
|
|
||||||
continue
|
|
||||||
|
|
||||||
input_ids, labels = [], []
|
def preprocess_supervised_dataset(
|
||||||
for turn_idx, (source_ids, target_ids) in enumerate(template.encode_multiturn(
|
examples: Dict[str, List[Any]],
|
||||||
tokenizer, query, response, history, system
|
tokenizer: "PreTrainedTokenizer",
|
||||||
)):
|
template: "Template",
|
||||||
source_len, target_len = len(source_ids), len(target_ids)
|
data_args: "DataArguments",
|
||||||
max_source_len, max_target_len = infer_max_len(source_len, target_len, data_args)
|
) -> Dict[str, List[List[int]]]:
|
||||||
if source_len > max_source_len:
|
# build inputs with format `<bos> X Y <eos>` and labels with format `<ignore> ... <ignore> Y <eos>`
|
||||||
source_ids = source_ids[:max_source_len]
|
# for multiturn examples, we only mask the prompt part in each prompt-response pair.
|
||||||
if target_len > max_target_len:
|
model_inputs = {"input_ids": [], "attention_mask": [], "labels": []}
|
||||||
target_ids = target_ids[:max_target_len]
|
|
||||||
|
|
||||||
if data_args.train_on_prompt:
|
for i in range(len(examples["prompt"])):
|
||||||
source_mask = source_ids
|
if len(examples["prompt"][i]) % 2 != 1 or len(examples["response"][i]) != 1:
|
||||||
elif turn_idx != 0 and template.efficient_eos:
|
continue
|
||||||
source_mask = [tokenizer.eos_token_id] + [IGNORE_INDEX] * (len(source_ids) - 1)
|
|
||||||
else:
|
|
||||||
source_mask = [IGNORE_INDEX] * len(source_ids)
|
|
||||||
|
|
||||||
input_ids += source_ids + target_ids
|
messages = examples["prompt"][i] + examples["response"][i]
|
||||||
labels += source_mask + target_ids
|
|
||||||
|
|
||||||
if template.efficient_eos:
|
|
||||||
input_ids += [tokenizer.eos_token_id]
|
|
||||||
labels += [tokenizer.eos_token_id]
|
|
||||||
|
|
||||||
if len(input_ids) > data_args.cutoff_len:
|
|
||||||
input_ids = input_ids[:data_args.cutoff_len]
|
|
||||||
labels = labels[:data_args.cutoff_len]
|
|
||||||
|
|
||||||
model_inputs["input_ids"].append(input_ids)
|
|
||||||
model_inputs["attention_mask"].append([1] * len(input_ids))
|
|
||||||
model_inputs["labels"].append(labels)
|
|
||||||
|
|
||||||
return model_inputs
|
|
||||||
|
|
||||||
def preprocess_packed_supervised_dataset(examples: Dict[str, List[Any]]) -> Dict[str, List[List[int]]]:
|
|
||||||
# build inputs with format `<bos> X1 Y1 <eos> <bos> X2 Y2 <eos>`
|
|
||||||
# and labels with format `<ignore> ... <ignore> Y1 <eos> <ignore> ... <ignore> Y2 <eos>`
|
|
||||||
model_inputs = {"input_ids": [], "attention_mask": [], "labels": []}
|
|
||||||
input_ids, labels = [], []
|
input_ids, labels = [], []
|
||||||
for query, response, history, system in construct_example(examples):
|
for turn_idx, (source_ids, target_ids) in enumerate(
|
||||||
if not (isinstance(query, str) and isinstance(response, str) and query != "" and response != ""):
|
template.encode_multiturn(
|
||||||
continue
|
tokenizer,
|
||||||
|
messages,
|
||||||
|
examples["system"][i],
|
||||||
|
examples["tools"][i],
|
||||||
|
data_args.cutoff_len,
|
||||||
|
data_args.reserved_label_len,
|
||||||
|
)
|
||||||
|
):
|
||||||
|
if data_args.train_on_prompt:
|
||||||
|
source_mask = source_ids
|
||||||
|
elif turn_idx != 0 and template.efficient_eos:
|
||||||
|
source_mask = [tokenizer.eos_token_id] + [IGNORE_INDEX] * (len(source_ids) - 1)
|
||||||
|
else:
|
||||||
|
source_mask = [IGNORE_INDEX] * len(source_ids)
|
||||||
|
|
||||||
for turn_idx, (source_ids, target_ids) in enumerate(template.encode_multiturn(
|
input_ids += source_ids + target_ids
|
||||||
tokenizer, query, response, history, system
|
labels += source_mask + target_ids
|
||||||
)):
|
|
||||||
if data_args.train_on_prompt:
|
|
||||||
source_mask = source_ids
|
|
||||||
elif turn_idx != 0 and template.efficient_eos:
|
|
||||||
source_mask = [tokenizer.eos_token_id] + [IGNORE_INDEX] * (len(source_ids) - 1)
|
|
||||||
else:
|
|
||||||
source_mask = [IGNORE_INDEX] * len(source_ids)
|
|
||||||
input_ids += source_ids + target_ids
|
|
||||||
labels += source_mask + target_ids
|
|
||||||
|
|
||||||
if template.efficient_eos:
|
if template.efficient_eos:
|
||||||
input_ids += [tokenizer.eos_token_id]
|
input_ids += [tokenizer.eos_token_id]
|
||||||
labels += [tokenizer.eos_token_id]
|
labels += [tokenizer.eos_token_id]
|
||||||
|
|
||||||
total_length = len(input_ids)
|
model_inputs["input_ids"].append(input_ids)
|
||||||
block_size = data_args.cutoff_len
|
model_inputs["attention_mask"].append([1] * len(input_ids))
|
||||||
# we drop the small remainder, and if the total_length < block_size, we exclude this batch
|
model_inputs["labels"].append(labels)
|
||||||
total_length = (total_length // block_size) * block_size
|
|
||||||
# split by chunks of cutoff_len
|
return model_inputs
|
||||||
for i in range(0, total_length, block_size):
|
|
||||||
model_inputs["input_ids"].append(input_ids[i: i + block_size])
|
|
||||||
|
def preprocess_packed_supervised_dataset(
|
||||||
|
examples: Dict[str, List[Any]],
|
||||||
|
tokenizer: "PreTrainedTokenizer",
|
||||||
|
template: "Template",
|
||||||
|
data_args: "DataArguments",
|
||||||
|
) -> Dict[str, List[List[int]]]:
|
||||||
|
# build inputs with format `<bos> X1 Y1 <eos> <bos> X2 Y2 <eos>`
|
||||||
|
# and labels with format `<ignore> ... <ignore> Y1 <eos> <ignore> ... <ignore> Y2 <eos>`
|
||||||
|
model_inputs = {"input_ids": [], "attention_mask": [], "labels": []}
|
||||||
|
input_ids, labels = [], []
|
||||||
|
for i in range(len(examples["prompt"])):
|
||||||
|
if len(examples["prompt"][i]) % 2 != 1 or len(examples["response"][i]) != 1:
|
||||||
|
continue
|
||||||
|
|
||||||
|
messages = examples["prompt"][i] + examples["response"][i]
|
||||||
|
for source_ids, target_ids in template.encode_multiturn(
|
||||||
|
tokenizer, messages, examples["system"][i], examples["tools"][i]
|
||||||
|
):
|
||||||
|
if data_args.train_on_prompt:
|
||||||
|
source_mask = source_ids
|
||||||
|
elif len(input_ids) != 0 and template.efficient_eos:
|
||||||
|
source_mask = [tokenizer.eos_token_id] + [IGNORE_INDEX] * (len(source_ids) - 1)
|
||||||
|
else:
|
||||||
|
source_mask = [IGNORE_INDEX] * len(source_ids)
|
||||||
|
|
||||||
|
input_ids += source_ids + target_ids
|
||||||
|
labels += source_mask + target_ids
|
||||||
|
|
||||||
|
if template.efficient_eos:
|
||||||
|
input_ids += [tokenizer.eos_token_id]
|
||||||
|
labels += [tokenizer.eos_token_id]
|
||||||
|
|
||||||
|
total_length = len(input_ids)
|
||||||
|
block_size = data_args.cutoff_len
|
||||||
|
# we drop the small remainder, and if the total_length < block_size, we exclude this batch
|
||||||
|
total_length = (total_length // block_size) * block_size
|
||||||
|
# split by chunks of cutoff_len
|
||||||
|
for i in range(0, total_length, block_size):
|
||||||
|
if not all(label == IGNORE_INDEX for label in labels[i : i + block_size]):
|
||||||
|
model_inputs["input_ids"].append(input_ids[i : i + block_size])
|
||||||
model_inputs["attention_mask"].append([1] * block_size)
|
model_inputs["attention_mask"].append([1] * block_size)
|
||||||
model_inputs["labels"].append(labels[i: i + block_size])
|
model_inputs["labels"].append(labels[i : i + block_size])
|
||||||
|
|
||||||
return model_inputs
|
return model_inputs
|
||||||
|
|
||||||
def preprocess_unsupervised_dataset(examples: Dict[str, List[Any]]) -> Dict[str, List[List[int]]]:
|
|
||||||
# build inputs with format `<bos> X` and labels with format `Y <eos>`
|
|
||||||
model_inputs = {"input_ids": [], "attention_mask": [], "labels": []}
|
|
||||||
|
|
||||||
for query, response, history, system in construct_example(examples):
|
def preprocess_unsupervised_dataset(
|
||||||
if not (isinstance(query, str) and query != ""):
|
examples: Dict[str, List[Any]],
|
||||||
continue
|
tokenizer: "PreTrainedTokenizer",
|
||||||
|
template: "Template",
|
||||||
|
data_args: "DataArguments",
|
||||||
|
) -> Dict[str, List[List[int]]]:
|
||||||
|
# build inputs with format `<bos> X` and labels with format `Y <eos>`
|
||||||
|
model_inputs = {"input_ids": [], "attention_mask": [], "labels": []}
|
||||||
|
|
||||||
input_ids, labels = template.encode_oneturn(tokenizer, query, response, history, system)
|
for i in range(len(examples["prompt"])):
|
||||||
|
if len(examples["prompt"][i]) % 2 != 1:
|
||||||
|
continue
|
||||||
|
|
||||||
if template.efficient_eos:
|
if len(examples["response"][i]) == 1:
|
||||||
labels += [tokenizer.eos_token_id]
|
messages = examples["prompt"][i] + examples["response"][i]
|
||||||
|
else:
|
||||||
|
messages = examples["prompt"][i] + [{"role": Role.ASSISTANT.value, "content": ""}]
|
||||||
|
|
||||||
if len(input_ids) > data_args.cutoff_len:
|
input_ids, labels = template.encode_oneturn(
|
||||||
input_ids = input_ids[:data_args.cutoff_len]
|
tokenizer,
|
||||||
if len(labels) > data_args.cutoff_len:
|
messages,
|
||||||
labels = labels[:data_args.cutoff_len]
|
examples["system"][i],
|
||||||
|
examples["tools"][i],
|
||||||
model_inputs["input_ids"].append(input_ids)
|
data_args.cutoff_len,
|
||||||
model_inputs["attention_mask"].append([1] * len(input_ids))
|
data_args.reserved_label_len,
|
||||||
model_inputs["labels"].append(labels)
|
|
||||||
|
|
||||||
return model_inputs
|
|
||||||
|
|
||||||
def preprocess_pairwise_dataset(examples: Dict[str, List[Any]]) -> Dict[str, List[List[int]]]:
|
|
||||||
# build input pairs with format `<bos> X`, `Y1 <eos>` and `Y2 <eos>`
|
|
||||||
model_inputs = {"prompt_ids": [], "chosen_ids": [], "rejected_ids": []}
|
|
||||||
for query, response, history, system in construct_example(examples):
|
|
||||||
if not (isinstance(query, str) and isinstance(response, list) and query != "" and len(response) > 1):
|
|
||||||
continue
|
|
||||||
|
|
||||||
prompt_ids, chosen_ids = template.encode_oneturn(tokenizer, query, response[0], history, system)
|
|
||||||
_, rejected_ids = template.encode_oneturn(tokenizer, query, response[1], history, system)
|
|
||||||
|
|
||||||
if template.efficient_eos:
|
|
||||||
chosen_ids += [tokenizer.eos_token_id]
|
|
||||||
rejected_ids += [tokenizer.eos_token_id]
|
|
||||||
|
|
||||||
source_len, target_len = len(prompt_ids), max(len(chosen_ids), len(rejected_ids))
|
|
||||||
max_source_len, max_target_len = infer_max_len(source_len, target_len, data_args)
|
|
||||||
if source_len > max_source_len:
|
|
||||||
prompt_ids = prompt_ids[:max_source_len]
|
|
||||||
if target_len > max_target_len:
|
|
||||||
chosen_ids = chosen_ids[:max_target_len]
|
|
||||||
rejected_ids = rejected_ids[:max_target_len]
|
|
||||||
|
|
||||||
model_inputs["prompt_ids"].append(prompt_ids)
|
|
||||||
model_inputs["chosen_ids"].append(chosen_ids)
|
|
||||||
model_inputs["rejected_ids"].append(rejected_ids)
|
|
||||||
|
|
||||||
return model_inputs
|
|
||||||
|
|
||||||
def print_supervised_dataset_example(example: Dict[str, List[int]]) -> None:
|
|
||||||
print("input_ids:\n{}".format(example["input_ids"]))
|
|
||||||
print("inputs:\n{}".format(tokenizer.decode(example["input_ids"], skip_special_tokens=False)))
|
|
||||||
print("label_ids:\n{}".format(example["labels"]))
|
|
||||||
print("labels:\n{}".format(
|
|
||||||
tokenizer.decode(list(filter(lambda x: x != IGNORE_INDEX, example["labels"])), skip_special_tokens=False)
|
|
||||||
))
|
|
||||||
|
|
||||||
def print_pairwise_dataset_example(example: Dict[str, List[int]]) -> None:
|
|
||||||
print("prompt_ids:\n{}".format(example["prompt_ids"]))
|
|
||||||
print("prompt:\n{}".format(tokenizer.decode(example["prompt_ids"], skip_special_tokens=False)))
|
|
||||||
print("chosen_ids:\n{}".format(example["chosen_ids"]))
|
|
||||||
print("chosen:\n{}".format(tokenizer.decode(example["chosen_ids"], skip_special_tokens=False)))
|
|
||||||
print("rejected_ids:\n{}".format(example["rejected_ids"]))
|
|
||||||
print("rejected:\n{}".format(tokenizer.decode(example["rejected_ids"], skip_special_tokens=False)))
|
|
||||||
|
|
||||||
def print_unsupervised_dataset_example(example: Dict[str, List[int]]) -> None:
|
|
||||||
print("input_ids:\n{}".format(example["input_ids"]))
|
|
||||||
print("inputs:\n{}".format(tokenizer.decode(example["input_ids"], skip_special_tokens=False)))
|
|
||||||
|
|
||||||
if stage == "pt":
|
|
||||||
preprocess_func = preprocess_pretrain_dataset
|
|
||||||
print_function = print_unsupervised_dataset_example
|
|
||||||
elif stage == "sft" and not training_args.predict_with_generate:
|
|
||||||
preprocess_func = preprocess_packed_supervised_dataset if data_args.sft_packing else preprocess_supervised_dataset
|
|
||||||
print_function = print_supervised_dataset_example
|
|
||||||
elif stage == "rm":
|
|
||||||
preprocess_func = preprocess_pairwise_dataset
|
|
||||||
print_function = print_pairwise_dataset_example
|
|
||||||
else:
|
|
||||||
preprocess_func = preprocess_unsupervised_dataset
|
|
||||||
print_function = print_unsupervised_dataset_example
|
|
||||||
|
|
||||||
if data_args.cache_path is not None and os.path.exists(data_args.cache_path):
|
|
||||||
logger.warning("Loading dataset from disk will ignore other data arguments.")
|
|
||||||
return load_from_disk(data_args.cache_path)
|
|
||||||
|
|
||||||
with training_args.main_process_first(desc="dataset map pre-processing"):
|
|
||||||
column_names = list(next(iter(dataset)).keys())
|
|
||||||
kwargs = {}
|
|
||||||
if not data_args.streaming:
|
|
||||||
kwargs = dict(
|
|
||||||
num_proc=data_args.preprocessing_num_workers,
|
|
||||||
load_from_cache_file=(not data_args.overwrite_cache),
|
|
||||||
desc="Running tokenizer on dataset"
|
|
||||||
)
|
|
||||||
|
|
||||||
dataset = dataset.map(
|
|
||||||
preprocess_func,
|
|
||||||
batched=True,
|
|
||||||
remove_columns=column_names,
|
|
||||||
**kwargs
|
|
||||||
)
|
)
|
||||||
|
|
||||||
if data_args.cache_path is not None and not os.path.exists(data_args.cache_path):
|
if template.efficient_eos:
|
||||||
if training_args.should_save:
|
labels += [tokenizer.eos_token_id]
|
||||||
dataset.save_to_disk(data_args.cache_path)
|
|
||||||
raise SystemExit("Dataset saved, rerun this script with the same `--cache_path`.")
|
|
||||||
|
|
||||||
if training_args.should_log:
|
model_inputs["input_ids"].append(input_ids)
|
||||||
try:
|
model_inputs["attention_mask"].append([1] * len(input_ids))
|
||||||
print_function(next(iter(dataset)))
|
model_inputs["labels"].append(labels)
|
||||||
except StopIteration:
|
|
||||||
raise RuntimeError("Empty dataset!")
|
|
||||||
|
|
||||||
return dataset
|
return model_inputs
|
||||||
|
|
||||||
|
|
||||||
|
def preprocess_pairwise_dataset(
|
||||||
|
examples: Dict[str, List[Any]],
|
||||||
|
tokenizer: "PreTrainedTokenizer",
|
||||||
|
template: "Template",
|
||||||
|
data_args: "DataArguments",
|
||||||
|
) -> Dict[str, List[List[int]]]:
|
||||||
|
# build input pairs with format `<bos> X`, `Y1 <eos>` and `Y2 <eos>`
|
||||||
|
model_inputs = {"prompt_ids": [], "chosen_ids": [], "rejected_ids": []}
|
||||||
|
for i in range(len(examples["prompt"])):
|
||||||
|
if len(examples["prompt"][i]) % 2 != 1 or len(examples["response"][i]) < 2:
|
||||||
|
continue
|
||||||
|
|
||||||
|
chosen_messages = examples["prompt"][i] + [examples["response"][i][0]]
|
||||||
|
rejected_messages = examples["prompt"][i] + [examples["response"][i][1]]
|
||||||
|
prompt_ids, chosen_ids = template.encode_oneturn(
|
||||||
|
tokenizer,
|
||||||
|
chosen_messages,
|
||||||
|
examples["system"][i],
|
||||||
|
examples["tools"][i],
|
||||||
|
data_args.cutoff_len,
|
||||||
|
data_args.reserved_label_len,
|
||||||
|
)
|
||||||
|
_, rejected_ids = template.encode_oneturn(
|
||||||
|
tokenizer,
|
||||||
|
rejected_messages,
|
||||||
|
examples["system"][i],
|
||||||
|
examples["tools"][i],
|
||||||
|
data_args.cutoff_len,
|
||||||
|
data_args.reserved_label_len,
|
||||||
|
)
|
||||||
|
|
||||||
|
if template.efficient_eos:
|
||||||
|
chosen_ids += [tokenizer.eos_token_id]
|
||||||
|
rejected_ids += [tokenizer.eos_token_id]
|
||||||
|
|
||||||
|
model_inputs["prompt_ids"].append(prompt_ids)
|
||||||
|
model_inputs["chosen_ids"].append(chosen_ids)
|
||||||
|
model_inputs["rejected_ids"].append(rejected_ids)
|
||||||
|
|
||||||
|
return model_inputs
|
||||||
|
|
||||||
|
|
||||||
|
def print_supervised_dataset_example(example: Dict[str, List[int]], tokenizer: "PreTrainedTokenizer") -> None:
|
||||||
|
print("input_ids:\n{}".format(example["input_ids"]))
|
||||||
|
print("inputs:\n{}".format(tokenizer.decode(example["input_ids"], skip_special_tokens=False)))
|
||||||
|
print("label_ids:\n{}".format(example["labels"]))
|
||||||
|
print(
|
||||||
|
"labels:\n{}".format(
|
||||||
|
tokenizer.decode(list(filter(lambda x: x != IGNORE_INDEX, example["labels"])), skip_special_tokens=False)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def print_pairwise_dataset_example(example: Dict[str, List[int]], tokenizer: "PreTrainedTokenizer") -> None:
|
||||||
|
print("prompt_ids:\n{}".format(example["prompt_ids"]))
|
||||||
|
print("prompt:\n{}".format(tokenizer.decode(example["prompt_ids"], skip_special_tokens=False)))
|
||||||
|
print("chosen_ids:\n{}".format(example["chosen_ids"]))
|
||||||
|
print("chosen:\n{}".format(tokenizer.decode(example["chosen_ids"], skip_special_tokens=False)))
|
||||||
|
print("rejected_ids:\n{}".format(example["rejected_ids"]))
|
||||||
|
print("rejected:\n{}".format(tokenizer.decode(example["rejected_ids"], skip_special_tokens=False)))
|
||||||
|
|
||||||
|
|
||||||
|
def print_unsupervised_dataset_example(example: Dict[str, List[int]], tokenizer: "PreTrainedTokenizer") -> None:
|
||||||
|
print("input_ids:\n{}".format(example["input_ids"]))
|
||||||
|
print("inputs:\n{}".format(tokenizer.decode(example["input_ids"], skip_special_tokens=False)))
|
||||||
|
|
||||||
|
|
||||||
|
def get_preprocess_and_print_func(
|
||||||
|
tokenizer: "PreTrainedTokenizer",
|
||||||
|
template: "Template",
|
||||||
|
data_args: "DataArguments",
|
||||||
|
training_args: "Seq2SeqTrainingArguments",
|
||||||
|
stage: Literal["pt", "sft", "rm", "ppo"],
|
||||||
|
) -> Tuple[Callable, Callable]:
|
||||||
|
if stage == "pt":
|
||||||
|
preprocess_func = partial(preprocess_pretrain_dataset, tokenizer=tokenizer, data_args=data_args)
|
||||||
|
print_function = partial(print_unsupervised_dataset_example, tokenizer=tokenizer)
|
||||||
|
elif stage == "sft" and not training_args.predict_with_generate:
|
||||||
|
if data_args.packing:
|
||||||
|
preprocess_func = partial(
|
||||||
|
preprocess_packed_supervised_dataset, tokenizer=tokenizer, template=template, data_args=data_args
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
preprocess_func = partial(
|
||||||
|
preprocess_supervised_dataset, tokenizer=tokenizer, template=template, data_args=data_args
|
||||||
|
)
|
||||||
|
|
||||||
|
print_function = partial(print_supervised_dataset_example, tokenizer=tokenizer)
|
||||||
|
elif stage == "rm":
|
||||||
|
preprocess_func = partial(
|
||||||
|
preprocess_pairwise_dataset, tokenizer=tokenizer, template=template, data_args=data_args
|
||||||
|
)
|
||||||
|
print_function = partial(print_pairwise_dataset_example, tokenizer=tokenizer)
|
||||||
|
else:
|
||||||
|
preprocess_func = partial(
|
||||||
|
preprocess_unsupervised_dataset, tokenizer=tokenizer, template=template, data_args=data_args
|
||||||
|
)
|
||||||
|
print_function = partial(print_unsupervised_dataset_example, tokenizer=tokenizer)
|
||||||
|
|
||||||
|
return preprocess_func, print_function
|
||||||
|
|||||||
File diff suppressed because it is too large
Load Diff
@@ -1,17 +1,31 @@
|
|||||||
import hashlib
|
import hashlib
|
||||||
from typing import TYPE_CHECKING, Dict, List, Optional, Union
|
from enum import Enum, unique
|
||||||
|
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
|
||||||
|
|
||||||
|
from datasets import concatenate_datasets, interleave_datasets
|
||||||
|
|
||||||
|
from ..extras.logging import get_logger
|
||||||
|
|
||||||
from llmtuner.extras.logging import get_logger
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from datasets import Dataset, IterableDataset
|
from datasets import Dataset, IterableDataset
|
||||||
from transformers import TrainingArguments
|
from transformers import Seq2SeqTrainingArguments
|
||||||
|
|
||||||
from llmtuner.hparams import DataArguments
|
from llmtuner.hparams import DataArguments
|
||||||
|
|
||||||
|
|
||||||
logger = get_logger(__name__)
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@unique
|
||||||
|
class Role(str, Enum):
|
||||||
|
USER = "user"
|
||||||
|
ASSISTANT = "assistant"
|
||||||
|
SYSTEM = "system"
|
||||||
|
FUNCTION = "function"
|
||||||
|
OBSERVATION = "observation"
|
||||||
|
|
||||||
|
|
||||||
def checksum(data_files: List[str], file_sha1: Optional[str] = None) -> None:
|
def checksum(data_files: List[str], file_sha1: Optional[str] = None) -> None:
|
||||||
if file_sha1 is None:
|
if file_sha1 is None:
|
||||||
logger.warning("Checksum failed: missing SHA-1 hash value in dataset_info.json.")
|
logger.warning("Checksum failed: missing SHA-1 hash value in dataset_info.json.")
|
||||||
@@ -27,13 +41,42 @@ def checksum(data_files: List[str], file_sha1: Optional[str] = None) -> None:
|
|||||||
logger.warning("Checksum failed: mismatched SHA-1 hash value at {}.".format(data_files[0]))
|
logger.warning("Checksum failed: mismatched SHA-1 hash value at {}.".format(data_files[0]))
|
||||||
|
|
||||||
|
|
||||||
def split_dataset(
|
def infer_max_len(source_len: int, target_len: int, max_len: int, reserved_label_len: int) -> Tuple[int, int]:
|
||||||
dataset: Union["Dataset", "IterableDataset"],
|
max_target_len = int(max_len * (target_len / (source_len + target_len)))
|
||||||
|
max_target_len = max(max_target_len, reserved_label_len)
|
||||||
|
max_source_len = max_len - min(max_target_len, target_len)
|
||||||
|
return max_source_len, max_target_len
|
||||||
|
|
||||||
|
|
||||||
|
def merge_dataset(
|
||||||
|
all_datasets: List[Union["Dataset", "IterableDataset"]],
|
||||||
data_args: "DataArguments",
|
data_args: "DataArguments",
|
||||||
training_args: "TrainingArguments"
|
training_args: "Seq2SeqTrainingArguments",
|
||||||
|
) -> Union["Dataset", "IterableDataset"]:
|
||||||
|
if len(all_datasets) == 1:
|
||||||
|
return all_datasets[0]
|
||||||
|
elif data_args.mix_strategy == "concat":
|
||||||
|
if data_args.streaming:
|
||||||
|
logger.warning("The samples between different datasets will not be mixed in streaming mode.")
|
||||||
|
return concatenate_datasets(all_datasets)
|
||||||
|
elif data_args.mix_strategy.startswith("interleave"):
|
||||||
|
if not data_args.streaming:
|
||||||
|
logger.warning("We recommend using `mix_strategy=concat` in non-streaming mode.")
|
||||||
|
return interleave_datasets(
|
||||||
|
datasets=all_datasets,
|
||||||
|
probabilities=data_args.interleave_probs,
|
||||||
|
seed=training_args.seed,
|
||||||
|
stopping_strategy="first_exhausted" if data_args.mix_strategy.endswith("under") else "all_exhausted",
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise ValueError("Unknown mixing strategy.")
|
||||||
|
|
||||||
|
|
||||||
|
def split_dataset(
|
||||||
|
dataset: Union["Dataset", "IterableDataset"], data_args: "DataArguments", training_args: "Seq2SeqTrainingArguments"
|
||||||
) -> Dict[str, "Dataset"]:
|
) -> Dict[str, "Dataset"]:
|
||||||
if training_args.do_train:
|
if training_args.do_train:
|
||||||
if data_args.val_size > 1e-6: # Split the dataset
|
if data_args.val_size > 1e-6: # Split the dataset
|
||||||
if data_args.streaming:
|
if data_args.streaming:
|
||||||
val_set = dataset.take(int(data_args.val_size))
|
val_set = dataset.take(int(data_args.val_size))
|
||||||
train_set = dataset.skip(int(data_args.val_size))
|
train_set = dataset.skip(int(data_args.val_size))
|
||||||
@@ -47,5 +90,5 @@ def split_dataset(
|
|||||||
if data_args.streaming:
|
if data_args.streaming:
|
||||||
dataset = dataset.shuffle(buffer_size=data_args.buffer_size, seed=training_args.seed)
|
dataset = dataset.shuffle(buffer_size=data_args.buffer_size, seed=training_args.seed)
|
||||||
return {"train_dataset": dataset}
|
return {"train_dataset": dataset}
|
||||||
else: # do_eval or do_predict
|
else: # do_eval or do_predict
|
||||||
return {"eval_dataset": dataset}
|
return {"eval_dataset": dataset}
|
||||||
|
|||||||
@@ -1 +1,4 @@
|
|||||||
from llmtuner.eval.evaluator import Evaluator
|
from .evaluator import Evaluator
|
||||||
|
|
||||||
|
|
||||||
|
__all__ = ["Evaluator"]
|
||||||
|
|||||||
@@ -1,41 +1,34 @@
|
|||||||
# Inspired by: https://github.com/hendrycks/test/blob/master/evaluate_flan.py
|
# Inspired by: https://github.com/hendrycks/test/blob/master/evaluate_flan.py
|
||||||
|
|
||||||
import os
|
|
||||||
import json
|
|
||||||
import torch
|
|
||||||
import inspect
|
import inspect
|
||||||
import tiktoken
|
import json
|
||||||
import numpy as np
|
import os
|
||||||
from tqdm import tqdm, trange
|
|
||||||
from typing import Any, Dict, List, Optional
|
from typing import Any, Dict, List, Optional
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
from datasets import load_dataset
|
from datasets import load_dataset
|
||||||
|
from tqdm import tqdm, trange
|
||||||
from transformers.utils import cached_file
|
from transformers.utils import cached_file
|
||||||
|
|
||||||
from llmtuner.data.template import get_template_and_fix_tokenizer
|
from ..data import get_template_and_fix_tokenizer
|
||||||
from llmtuner.eval.template import get_eval_template
|
from ..extras.constants import CHOICES, SUBJECTS
|
||||||
from llmtuner.extras.constants import CHOICES, SUBJECTS
|
from ..hparams import get_eval_args
|
||||||
from llmtuner.model import dispatch_model, get_eval_args, load_model_and_tokenizer
|
from ..model import load_model, load_tokenizer
|
||||||
|
from .template import get_eval_template
|
||||||
|
|
||||||
|
|
||||||
class Evaluator:
|
class Evaluator:
|
||||||
|
|
||||||
def __init__(self, args: Optional[Dict[str, Any]] = None) -> None:
|
def __init__(self, args: Optional[Dict[str, Any]] = None) -> None:
|
||||||
self.model_args, self.data_args, self.eval_args, finetuning_args = get_eval_args(args)
|
self.model_args, self.data_args, self.eval_args, finetuning_args = get_eval_args(args)
|
||||||
self.model, self.tokenizer = load_model_and_tokenizer(self.model_args, finetuning_args)
|
self.tokenizer = load_tokenizer(self.model_args)
|
||||||
self.tokenizer.padding_side = "right" # avoid overflow issue in batched inference for llama2
|
self.tokenizer.padding_side = "right" # avoid overflow issue in batched inference for llama2
|
||||||
self.model = dispatch_model(self.model)
|
self.template = get_template_and_fix_tokenizer(self.tokenizer, self.data_args.template)
|
||||||
self.template = get_template_and_fix_tokenizer(self.data_args.template, self.tokenizer)
|
self.model = load_model(self.tokenizer, self.model_args, finetuning_args)
|
||||||
self.eval_template = get_eval_template(self.eval_args.lang)
|
self.eval_template = get_eval_template(self.eval_args.lang)
|
||||||
self.choice_inputs = self._encode_choices()
|
self.choice_inputs = [
|
||||||
|
self.tokenizer.encode(self.eval_template.prefix + ch, add_special_tokens=False)[-1] for ch in CHOICES
|
||||||
def _encode_choices(self) -> List[int]:
|
]
|
||||||
if isinstance(getattr(self.tokenizer, "tokenizer", None), tiktoken.Encoding): # for tiktoken tokenizer (Qwen)
|
|
||||||
kwargs = dict(allowed_special="all")
|
|
||||||
else:
|
|
||||||
kwargs = dict(add_special_tokens=False)
|
|
||||||
|
|
||||||
return [self.tokenizer.encode(self.eval_template.prefix + ch, **kwargs)[-1] for ch in CHOICES]
|
|
||||||
|
|
||||||
@torch.inference_mode()
|
@torch.inference_mode()
|
||||||
def batch_inference(self, batch_input: Dict[str, torch.Tensor]) -> List[str]:
|
def batch_inference(self, batch_input: Dict[str, torch.Tensor]) -> List[str]:
|
||||||
@@ -46,16 +39,11 @@ class Evaluator:
|
|||||||
return [chr(ord("A") + offset.item()) for offset in torch.argmax(choice_probs, dim=-1)]
|
return [chr(ord("A") + offset.item()) for offset in torch.argmax(choice_probs, dim=-1)]
|
||||||
|
|
||||||
def eval(self) -> None:
|
def eval(self) -> None:
|
||||||
if "token" in inspect.signature(cached_file).parameters:
|
|
||||||
kwargs = {"token": self.model_args.hf_hub_token}
|
|
||||||
elif "use_auth_token" in inspect.signature(cached_file).parameters: # for transformers==4.31.0
|
|
||||||
kwargs = {"use_auth_token": self.model_args.hf_hub_token}
|
|
||||||
|
|
||||||
mapping = cached_file(
|
mapping = cached_file(
|
||||||
path_or_repo_id = os.path.join(self.eval_args.task_dir, self.eval_args.task),
|
path_or_repo_id=os.path.join(self.eval_args.task_dir, self.eval_args.task),
|
||||||
filename="mapping.json",
|
filename="mapping.json",
|
||||||
cache_dir=self.model_args.cache_dir,
|
cache_dir=self.model_args.cache_dir,
|
||||||
**kwargs
|
token=self.model_args.hf_hub_token,
|
||||||
)
|
)
|
||||||
|
|
||||||
with open(mapping, "r", encoding="utf-8") as f:
|
with open(mapping, "r", encoding="utf-8") as f:
|
||||||
@@ -65,37 +53,45 @@ class Evaluator:
|
|||||||
pbar = tqdm(categorys.keys(), desc="Processing subjects", position=0)
|
pbar = tqdm(categorys.keys(), desc="Processing subjects", position=0)
|
||||||
results = {}
|
results = {}
|
||||||
for subject in pbar:
|
for subject in pbar:
|
||||||
|
if "trust_remote_code" in inspect.signature(load_dataset).parameters: # for datasets==2.16.0
|
||||||
|
kwargs = {"trust_remote_code": True}
|
||||||
|
else:
|
||||||
|
kwargs = {}
|
||||||
|
|
||||||
dataset = load_dataset(
|
dataset = load_dataset(
|
||||||
path=os.path.join(self.eval_args.task_dir, self.eval_args.task),
|
path=os.path.join(self.eval_args.task_dir, self.eval_args.task),
|
||||||
name=subject,
|
name=subject,
|
||||||
cache_dir=self.model_args.cache_dir,
|
cache_dir=self.model_args.cache_dir,
|
||||||
download_mode=self.eval_args.download_mode,
|
download_mode=self.eval_args.download_mode,
|
||||||
token=self.model_args.hf_hub_token
|
token=self.model_args.hf_hub_token,
|
||||||
|
**kwargs,
|
||||||
)
|
)
|
||||||
pbar.set_postfix_str(categorys[subject]["name"])
|
pbar.set_postfix_str(categorys[subject]["name"])
|
||||||
inputs, outputs, labels = [], [], []
|
inputs, outputs, labels = [], [], []
|
||||||
for i in trange(len(dataset[self.data_args.split]), desc="Formatting batches", position=1, leave=False):
|
for i in trange(len(dataset[self.data_args.split]), desc="Formatting batches", position=1, leave=False):
|
||||||
support_set = dataset["train"].shuffle().select(range(min(self.eval_args.n_shot, len(dataset["train"]))))
|
support_set = (
|
||||||
query, resp, history = self.eval_template.format_example(
|
dataset["train"].shuffle().select(range(min(self.eval_args.n_shot, len(dataset["train"]))))
|
||||||
|
)
|
||||||
|
messages = self.eval_template.format_example(
|
||||||
target_data=dataset[self.data_args.split][i],
|
target_data=dataset[self.data_args.split][i],
|
||||||
support_set=support_set,
|
support_set=support_set,
|
||||||
subject_name=categorys[subject]["name"],
|
subject_name=categorys[subject]["name"],
|
||||||
use_history=self.template.use_history
|
|
||||||
)
|
)
|
||||||
input_ids, _ = self.template.encode_oneturn(
|
|
||||||
tokenizer=self.tokenizer, query=query, resp=resp, history=history
|
|
||||||
)
|
|
||||||
inputs.append({"input_ids": input_ids, "attention_mask": [1] * len(input_ids)})
|
|
||||||
labels.append(resp)
|
|
||||||
|
|
||||||
for i in trange(0, len(inputs), self.eval_args.batch_size, desc="Predicting batches", position=1, leave=False):
|
input_ids, _ = self.template.encode_oneturn(tokenizer=self.tokenizer, messages=messages)
|
||||||
|
inputs.append({"input_ids": input_ids, "attention_mask": [1] * len(input_ids)})
|
||||||
|
labels.append(messages[-1]["content"])
|
||||||
|
|
||||||
|
for i in trange(
|
||||||
|
0, len(inputs), self.eval_args.batch_size, desc="Predicting batches", position=1, leave=False
|
||||||
|
):
|
||||||
batch_input = self.tokenizer.pad(
|
batch_input = self.tokenizer.pad(
|
||||||
inputs[i : i + self.eval_args.batch_size], return_attention_mask=True, return_tensors="pt"
|
inputs[i : i + self.eval_args.batch_size], return_attention_mask=True, return_tensors="pt"
|
||||||
).to(self.model.device)
|
).to(self.model.device)
|
||||||
preds = self.batch_inference(batch_input)
|
preds = self.batch_inference(batch_input)
|
||||||
outputs += preds
|
outputs += preds
|
||||||
|
|
||||||
corrects = (np.array(outputs) == np.array(labels))
|
corrects = np.array(outputs) == np.array(labels)
|
||||||
category_name = categorys[subject]["category"]
|
category_name = categorys[subject]["category"]
|
||||||
category_corrects[category_name] = np.concatenate([category_corrects[category_name], corrects], axis=0)
|
category_corrects[category_name] = np.concatenate([category_corrects[category_name], corrects], axis=0)
|
||||||
category_corrects["Average"] = np.concatenate([category_corrects["Average"], corrects], axis=0)
|
category_corrects["Average"] = np.concatenate([category_corrects["Average"], corrects], axis=0)
|
||||||
@@ -105,10 +101,13 @@ class Evaluator:
|
|||||||
self._save_results(category_corrects, results)
|
self._save_results(category_corrects, results)
|
||||||
|
|
||||||
def _save_results(self, category_corrects: Dict[str, np.ndarray], results: Dict[str, Dict[int, str]]) -> None:
|
def _save_results(self, category_corrects: Dict[str, np.ndarray], results: Dict[str, Dict[int, str]]) -> None:
|
||||||
score_info = "\n".join([
|
score_info = "\n".join(
|
||||||
"{:>15}: {:.2f}".format(category_name, 100 * np.mean(category_correct))
|
[
|
||||||
for category_name, category_correct in category_corrects.items() if len(category_correct)
|
"{:>15}: {:.2f}".format(category_name, 100 * np.mean(category_correct))
|
||||||
])
|
for category_name, category_correct in category_corrects.items()
|
||||||
|
if len(category_correct)
|
||||||
|
]
|
||||||
|
)
|
||||||
print(score_info)
|
print(score_info)
|
||||||
if self.eval_args.save_dir is not None:
|
if self.eval_args.save_dir is not None:
|
||||||
os.makedirs(self.eval_args.save_dir, exist_ok=False)
|
os.makedirs(self.eval_args.save_dir, exist_ok=False)
|
||||||
|
|||||||
@@ -1,86 +1,70 @@
|
|||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import TYPE_CHECKING, Dict, List, Tuple
|
from typing import Dict, List, Sequence, Tuple
|
||||||
|
|
||||||
from llmtuner.extras.constants import CHOICES
|
from ..data import Role
|
||||||
|
from ..extras.constants import CHOICES
|
||||||
if TYPE_CHECKING:
|
|
||||||
from datasets import Dataset
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class EvalTemplate:
|
class EvalTemplate:
|
||||||
|
|
||||||
system: str
|
system: str
|
||||||
choice: str
|
choice: str
|
||||||
answer: str
|
answer: str
|
||||||
prefix: str
|
prefix: str
|
||||||
|
|
||||||
def parse_example(
|
def _parse_example(self, example: Dict[str, str]) -> Tuple[str, str]:
|
||||||
self,
|
r"""
|
||||||
example: Dict[str, str]
|
input: a dict with keys {"question", "A", "B", "C", "D", "answer"}
|
||||||
) -> Tuple[str, str]:
|
output: a tuple of (prompt, response)
|
||||||
|
"""
|
||||||
candidates = [self.choice.format(choice=ch, content=example[ch]) for ch in CHOICES if ch in example]
|
candidates = [self.choice.format(choice=ch, content=example[ch]) for ch in CHOICES if ch in example]
|
||||||
return "".join([example["question"]] + candidates + [self.answer]), example["answer"]
|
return "".join([example["question"]] + candidates + [self.answer]), example["answer"]
|
||||||
|
|
||||||
def format_example(
|
def format_example(
|
||||||
self,
|
self, target_data: Dict[str, str], support_set: Sequence[Dict[str, str]], subject_name: str
|
||||||
target_data: Dict[str, str],
|
) -> List[Dict[str, str]]:
|
||||||
support_set: "Dataset",
|
r"""
|
||||||
subject_name: str,
|
Converts dataset examples to messages.
|
||||||
use_history: bool
|
"""
|
||||||
) -> Tuple[str, str, List[Tuple[str, str]]]:
|
messages = []
|
||||||
query, resp = self.parse_example(target_data)
|
for k in range(len(support_set)):
|
||||||
history = [self.parse_example(support_set[k]) for k in range(len(support_set))]
|
prompt, response = self._parse_example(support_set[k])
|
||||||
|
messages.append({"role": Role.USER.value, "content": prompt})
|
||||||
|
messages.append({"role": Role.ASSISTANT.value, "content": response})
|
||||||
|
|
||||||
if len(history):
|
prompt, response = self._parse_example(target_data)
|
||||||
temp = history.pop(0)
|
messages.append({"role": Role.USER.value, "content": prompt})
|
||||||
history.insert(0, (self.system.format(subject=subject_name) + temp[0], temp[1]))
|
messages.append({"role": Role.ASSISTANT.value, "content": response})
|
||||||
else:
|
messages[0]["content"] = self.system.format(subject=subject_name) + messages[0]["content"]
|
||||||
query = self.system.format(subject=subject_name) + query
|
return messages
|
||||||
|
|
||||||
if not use_history:
|
|
||||||
query = "\n\n".join(["".join(item) for item in history] + [query])
|
|
||||||
history = []
|
|
||||||
return query.strip(), resp, history
|
|
||||||
|
|
||||||
|
|
||||||
eval_templates: Dict[str, EvalTemplate] = {}
|
eval_templates: Dict[str, "EvalTemplate"] = {}
|
||||||
|
|
||||||
|
|
||||||
def register_eval_template(
|
def _register_eval_template(name: str, system: str, choice: str, answer: str, prefix: str) -> None:
|
||||||
name: str,
|
eval_templates[name] = EvalTemplate(system=system, choice=choice, answer=answer, prefix=prefix)
|
||||||
system: str,
|
|
||||||
choice: str,
|
|
||||||
answer: str,
|
|
||||||
prefix: str
|
|
||||||
) -> None:
|
|
||||||
eval_templates[name] = EvalTemplate(
|
|
||||||
system=system,
|
|
||||||
choice=choice,
|
|
||||||
answer=answer,
|
|
||||||
prefix=prefix
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def get_eval_template(name: str) -> EvalTemplate:
|
def get_eval_template(name: str) -> "EvalTemplate":
|
||||||
eval_template = eval_templates.get(name, None)
|
eval_template = eval_templates.get(name, None)
|
||||||
assert eval_template is not None, "Template {} does not exist.".format(name)
|
assert eval_template is not None, "Template {} does not exist.".format(name)
|
||||||
return eval_template
|
return eval_template
|
||||||
|
|
||||||
|
|
||||||
register_eval_template(
|
_register_eval_template(
|
||||||
name="en",
|
name="en",
|
||||||
system="The following are multiple choice questions (with answers) about {subject}.\n\n",
|
system="The following are multiple choice questions (with answers) about {subject}.\n\n",
|
||||||
choice="\n{choice}. {content}",
|
choice="\n{choice}. {content}",
|
||||||
answer="\nAnswer: ",
|
answer="\nAnswer: ",
|
||||||
prefix=" "
|
prefix=" ",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
register_eval_template(
|
_register_eval_template(
|
||||||
name="zh",
|
name="zh",
|
||||||
system="以下是中国关于{subject}考试的单项选择题,请选出其中的正确答案。\n\n",
|
system="以下是中国关于{subject}考试的单项选择题,请选出其中的正确答案。\n\n",
|
||||||
choice="\n{choice}. {content}",
|
choice="\n{choice}. {content}",
|
||||||
answer="\n答案:",
|
answer="\n答案:",
|
||||||
prefix="\n"
|
prefix=" ",
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -1,56 +1,38 @@
|
|||||||
import os
|
|
||||||
import json
|
import json
|
||||||
|
import os
|
||||||
import time
|
import time
|
||||||
from typing import TYPE_CHECKING
|
|
||||||
from datetime import timedelta
|
from datetime import timedelta
|
||||||
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
from transformers import TrainerCallback
|
from transformers import TrainerCallback
|
||||||
from transformers.modeling_utils import custom_object_save, unwrap_model
|
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, has_length
|
||||||
from transformers.trainer_utils import has_length, PREFIX_CHECKPOINT_DIR
|
|
||||||
|
from .constants import LOG_FILE_NAME
|
||||||
|
from .logging import get_logger
|
||||||
|
from .misc import fix_valuehead_checkpoint
|
||||||
|
|
||||||
from llmtuner.extras.constants import LOG_FILE_NAME
|
|
||||||
from llmtuner.extras.logging import get_logger
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from transformers import TrainingArguments, TrainerState, TrainerControl
|
from transformers import TrainerControl, TrainerState, TrainingArguments
|
||||||
from trl import AutoModelForCausalLMWithValueHead
|
|
||||||
|
|
||||||
|
|
||||||
logger = get_logger(__name__)
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def _save_model_with_valuehead(model: "AutoModelForCausalLMWithValueHead", output_dir: str) -> None:
|
class FixValueHeadModelCallback(TrainerCallback):
|
||||||
model.pretrained_model.config.save_pretrained(output_dir)
|
|
||||||
if model.pretrained_model.can_generate():
|
|
||||||
model.pretrained_model.generation_config.save_pretrained(output_dir)
|
|
||||||
if getattr(model, "is_peft_model", False):
|
|
||||||
model.pretrained_model.save_pretrained(output_dir)
|
|
||||||
elif getattr(model.pretrained_model, "_auto_class", None): # must not a peft model
|
|
||||||
custom_object_save(model.pretrained_model, output_dir, config=model.pretrained_model.config)
|
|
||||||
|
|
||||||
|
|
||||||
class SavePeftModelCallback(TrainerCallback):
|
|
||||||
|
|
||||||
def on_save(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
|
def on_save(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
|
||||||
r"""
|
r"""
|
||||||
Event called after a checkpoint save.
|
Event called after a checkpoint save.
|
||||||
"""
|
"""
|
||||||
if args.should_save:
|
if args.should_save:
|
||||||
_save_model_with_valuehead(
|
fix_valuehead_checkpoint(
|
||||||
model=unwrap_model(kwargs.pop("model")),
|
model=kwargs.pop("model"),
|
||||||
output_dir=os.path.join(args.output_dir, "{}-{}".format(PREFIX_CHECKPOINT_DIR, state.global_step))
|
output_dir=os.path.join(args.output_dir, "{}-{}".format(PREFIX_CHECKPOINT_DIR, state.global_step)),
|
||||||
|
safe_serialization=args.save_safetensors,
|
||||||
)
|
)
|
||||||
|
|
||||||
def on_train_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
|
|
||||||
r"""
|
|
||||||
Event called at the end of training.
|
|
||||||
"""
|
|
||||||
if args.should_save:
|
|
||||||
_save_model_with_valuehead(model=unwrap_model(kwargs.pop("model")), output_dir=args.output_dir)
|
|
||||||
|
|
||||||
|
|
||||||
class LogCallback(TrainerCallback):
|
class LogCallback(TrainerCallback):
|
||||||
|
|
||||||
def __init__(self, runner=None):
|
def __init__(self, runner=None):
|
||||||
self.runner = runner
|
self.runner = runner
|
||||||
self.in_training = False
|
self.in_training = False
|
||||||
@@ -76,9 +58,17 @@ class LogCallback(TrainerCallback):
|
|||||||
self.in_training = True
|
self.in_training = True
|
||||||
self.start_time = time.time()
|
self.start_time = time.time()
|
||||||
self.max_steps = state.max_steps
|
self.max_steps = state.max_steps
|
||||||
if os.path.exists(os.path.join(args.output_dir, LOG_FILE_NAME)) and args.overwrite_output_dir:
|
|
||||||
logger.warning("Previous log file in this folder will be deleted.")
|
if args.save_on_each_node:
|
||||||
os.remove(os.path.join(args.output_dir, LOG_FILE_NAME))
|
if not state.is_local_process_zero:
|
||||||
|
return
|
||||||
|
else:
|
||||||
|
if not state.is_world_process_zero:
|
||||||
|
return
|
||||||
|
|
||||||
|
if os.path.exists(os.path.join(args.output_dir, LOG_FILE_NAME)) and args.overwrite_output_dir:
|
||||||
|
logger.warning("Previous log file in this folder will be deleted.")
|
||||||
|
os.remove(os.path.join(args.output_dir, LOG_FILE_NAME))
|
||||||
|
|
||||||
def on_train_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
|
def on_train_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
|
||||||
r"""
|
r"""
|
||||||
@@ -116,7 +106,9 @@ class LogCallback(TrainerCallback):
|
|||||||
self.cur_steps = 0
|
self.cur_steps = 0
|
||||||
self.max_steps = 0
|
self.max_steps = 0
|
||||||
|
|
||||||
def on_predict(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", *other, **kwargs):
|
def on_predict(
|
||||||
|
self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", *other, **kwargs
|
||||||
|
):
|
||||||
r"""
|
r"""
|
||||||
Event called after a successful prediction.
|
Event called after a successful prediction.
|
||||||
"""
|
"""
|
||||||
@@ -128,8 +120,12 @@ class LogCallback(TrainerCallback):
|
|||||||
r"""
|
r"""
|
||||||
Event called after logging the last logs.
|
Event called after logging the last logs.
|
||||||
"""
|
"""
|
||||||
if not state.is_local_process_zero:
|
if args.save_on_each_node:
|
||||||
return
|
if not state.is_local_process_zero:
|
||||||
|
return
|
||||||
|
else:
|
||||||
|
if not state.is_world_process_zero:
|
||||||
|
return
|
||||||
|
|
||||||
logs = dict(
|
logs = dict(
|
||||||
current_steps=self.cur_steps,
|
current_steps=self.cur_steps,
|
||||||
@@ -138,22 +134,27 @@ class LogCallback(TrainerCallback):
|
|||||||
eval_loss=state.log_history[-1].get("eval_loss", None),
|
eval_loss=state.log_history[-1].get("eval_loss", None),
|
||||||
predict_loss=state.log_history[-1].get("predict_loss", None),
|
predict_loss=state.log_history[-1].get("predict_loss", None),
|
||||||
reward=state.log_history[-1].get("reward", None),
|
reward=state.log_history[-1].get("reward", None),
|
||||||
|
accuracy=state.log_history[-1].get("rewards/accuracies", None),
|
||||||
learning_rate=state.log_history[-1].get("learning_rate", None),
|
learning_rate=state.log_history[-1].get("learning_rate", None),
|
||||||
epoch=state.log_history[-1].get("epoch", None),
|
epoch=state.log_history[-1].get("epoch", None),
|
||||||
percentage=round(self.cur_steps / self.max_steps * 100, 2) if self.max_steps != 0 else 100,
|
percentage=round(self.cur_steps / self.max_steps * 100, 2) if self.max_steps != 0 else 100,
|
||||||
elapsed_time=self.elapsed_time,
|
elapsed_time=self.elapsed_time,
|
||||||
remaining_time=self.remaining_time
|
remaining_time=self.remaining_time,
|
||||||
)
|
)
|
||||||
if self.runner is not None:
|
if self.runner is not None:
|
||||||
logger.info("{{'loss': {:.4f}, 'learning_rate': {:2.4e}, 'epoch': {:.2f}}}".format(
|
logger.info(
|
||||||
logs["loss"] or 0, logs["learning_rate"] or 0, logs["epoch"] or 0
|
"{{'loss': {:.4f}, 'learning_rate': {:2.4e}, 'epoch': {:.2f}}}".format(
|
||||||
))
|
logs["loss"] or 0, logs["learning_rate"] or 0, logs["epoch"] or 0
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
os.makedirs(args.output_dir, exist_ok=True)
|
os.makedirs(args.output_dir, exist_ok=True)
|
||||||
with open(os.path.join(args.output_dir, "trainer_log.jsonl"), "a", encoding="utf-8") as f:
|
with open(os.path.join(args.output_dir, "trainer_log.jsonl"), "a", encoding="utf-8") as f:
|
||||||
f.write(json.dumps(logs) + "\n")
|
f.write(json.dumps(logs) + "\n")
|
||||||
|
|
||||||
def on_prediction_step(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
|
def on_prediction_step(
|
||||||
|
self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs
|
||||||
|
):
|
||||||
r"""
|
r"""
|
||||||
Event called after a prediction step.
|
Event called after a prediction step.
|
||||||
"""
|
"""
|
||||||
|
|||||||
File diff suppressed because it is too large
Load Diff
@@ -1,5 +1,5 @@
|
|||||||
import sys
|
|
||||||
import logging
|
import logging
|
||||||
|
import sys
|
||||||
|
|
||||||
|
|
||||||
class LoggerHandler(logging.Handler):
|
class LoggerHandler(logging.Handler):
|
||||||
@@ -27,8 +27,7 @@ def get_logger(name: str) -> logging.Logger:
|
|||||||
Gets a standard logger with a stream hander to stdout.
|
Gets a standard logger with a stream hander to stdout.
|
||||||
"""
|
"""
|
||||||
formatter = logging.Formatter(
|
formatter = logging.Formatter(
|
||||||
fmt="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
|
fmt="%(asctime)s - %(levelname)s - %(name)s - %(message)s", datefmt="%m/%d/%Y %H:%M:%S"
|
||||||
datefmt="%m/%d/%Y %H:%M:%S"
|
|
||||||
)
|
)
|
||||||
handler = logging.StreamHandler(sys.stdout)
|
handler = logging.StreamHandler(sys.stdout)
|
||||||
handler.setFormatter(formatter)
|
handler.setFormatter(formatter)
|
||||||
|
|||||||
@@ -1,34 +1,46 @@
|
|||||||
import gc
|
import gc
|
||||||
import os
|
import os
|
||||||
import torch
|
from typing import TYPE_CHECKING, Dict, Tuple
|
||||||
from typing import TYPE_CHECKING, Tuple
|
|
||||||
from transformers import InfNanRemoveLogitsProcessor, LogitsProcessorList
|
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from peft import PeftModel
|
||||||
|
from transformers import InfNanRemoveLogitsProcessor, LogitsProcessorList, PreTrainedModel
|
||||||
|
from transformers.utils import (
|
||||||
|
SAFE_WEIGHTS_NAME,
|
||||||
|
WEIGHTS_NAME,
|
||||||
|
is_torch_bf16_gpu_available,
|
||||||
|
is_torch_cuda_available,
|
||||||
|
is_torch_mps_available,
|
||||||
|
is_torch_npu_available,
|
||||||
|
is_torch_xpu_available,
|
||||||
|
)
|
||||||
|
from transformers.utils.versions import require_version
|
||||||
|
|
||||||
|
from .constants import V_HEAD_SAFE_WEIGHTS_NAME, V_HEAD_WEIGHTS_NAME
|
||||||
|
from .logging import get_logger
|
||||||
|
|
||||||
|
|
||||||
|
_is_fp16_available = is_torch_npu_available() or is_torch_cuda_available()
|
||||||
try:
|
try:
|
||||||
from transformers.utils import (
|
_is_bf16_available = is_torch_bf16_gpu_available()
|
||||||
is_torch_bf16_cpu_available,
|
except Exception:
|
||||||
is_torch_bf16_gpu_available,
|
_is_bf16_available = False
|
||||||
is_torch_cuda_available,
|
|
||||||
is_torch_npu_available
|
|
||||||
)
|
|
||||||
_is_fp16_available = is_torch_npu_available() or is_torch_cuda_available()
|
|
||||||
_is_bf16_available = is_torch_bf16_gpu_available() or is_torch_bf16_cpu_available()
|
|
||||||
except ImportError:
|
|
||||||
_is_fp16_available = torch.cuda.is_available()
|
|
||||||
try:
|
|
||||||
_is_bf16_available = torch.cuda.is_bf16_supported()
|
|
||||||
except:
|
|
||||||
_is_bf16_available = False
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from transformers import HfArgumentParser
|
from trl import AutoModelForCausalLMWithValueHead
|
||||||
|
|
||||||
from llmtuner.hparams import ModelArguments
|
from llmtuner.hparams import ModelArguments
|
||||||
|
|
||||||
|
|
||||||
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class AverageMeter:
|
class AverageMeter:
|
||||||
r"""
|
r"""
|
||||||
Computes and stores the average and current value.
|
Computes and stores the average and current value.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.reset()
|
self.reset()
|
||||||
|
|
||||||
@@ -45,6 +57,18 @@ class AverageMeter:
|
|||||||
self.avg = self.sum / self.count
|
self.avg = self.sum / self.count
|
||||||
|
|
||||||
|
|
||||||
|
def check_dependencies() -> None:
|
||||||
|
if int(os.environ.get("DISABLE_VERSION_CHECK", "0")):
|
||||||
|
logger.warning("Version checking has been disabled, may lead to unexpected behaviors.")
|
||||||
|
else:
|
||||||
|
require_version("transformers>=4.37.2", "To fix: pip install transformers>=4.37.2")
|
||||||
|
require_version("datasets>=2.14.3", "To fix: pip install datasets>=2.14.3")
|
||||||
|
require_version("accelerate>=0.27.2", "To fix: pip install accelerate>=0.27.2")
|
||||||
|
require_version("peft>=0.10.0", "To fix: pip install peft>=0.10.0")
|
||||||
|
require_version("trl>=0.8.1", "To fix: pip install trl>=0.8.1")
|
||||||
|
require_version("gradio>=4.0.0,<=4.21.0", "To fix: pip install gradio==4.21.0")
|
||||||
|
|
||||||
|
|
||||||
def count_parameters(model: torch.nn.Module) -> Tuple[int, int]:
|
def count_parameters(model: torch.nn.Module) -> Tuple[int, int]:
|
||||||
r"""
|
r"""
|
||||||
Returns the number of trainable parameters and number of all parameters in the model.
|
Returns the number of trainable parameters and number of all parameters in the model.
|
||||||
@@ -58,7 +82,12 @@ def count_parameters(model: torch.nn.Module) -> Tuple[int, int]:
|
|||||||
|
|
||||||
# Due to the design of 4bit linear layers from bitsandbytes, multiply the number of parameters by 2
|
# Due to the design of 4bit linear layers from bitsandbytes, multiply the number of parameters by 2
|
||||||
if param.__class__.__name__ == "Params4bit":
|
if param.__class__.__name__ == "Params4bit":
|
||||||
num_params = num_params * 2
|
if hasattr(param, "quant_storage") and hasattr(param.quant_storage, "itemsize"):
|
||||||
|
num_bytes = param.quant_storage.itemsize
|
||||||
|
else:
|
||||||
|
num_bytes = 1
|
||||||
|
|
||||||
|
num_params = num_params * 2 * num_bytes
|
||||||
|
|
||||||
all_param += num_params
|
all_param += num_params
|
||||||
if param.requires_grad:
|
if param.requires_grad:
|
||||||
@@ -67,13 +96,65 @@ def count_parameters(model: torch.nn.Module) -> Tuple[int, int]:
|
|||||||
return trainable_params, all_param
|
return trainable_params, all_param
|
||||||
|
|
||||||
|
|
||||||
|
def fix_valuehead_checkpoint(
|
||||||
|
model: "AutoModelForCausalLMWithValueHead", output_dir: str, safe_serialization: bool
|
||||||
|
) -> None:
|
||||||
|
r"""
|
||||||
|
The model is already unwrapped.
|
||||||
|
|
||||||
|
There are three cases:
|
||||||
|
1. full tuning without ds_zero3: state_dict = {"model.layers.*": ..., "v_head.summary.*": ...}
|
||||||
|
2. lora tuning without ds_zero3: state_dict = {"v_head.summary.*": ...}
|
||||||
|
3. under deepspeed zero3: state_dict = {"pretrained_model.model.layers.*": ..., "v_head.summary.*": ...}
|
||||||
|
|
||||||
|
We assume `stage3_gather_16bit_weights_on_model_save=true`.
|
||||||
|
"""
|
||||||
|
if not isinstance(model.pretrained_model, (PreTrainedModel, PeftModel)):
|
||||||
|
return
|
||||||
|
|
||||||
|
if safe_serialization:
|
||||||
|
from safetensors import safe_open
|
||||||
|
from safetensors.torch import save_file
|
||||||
|
|
||||||
|
path_to_checkpoint = os.path.join(output_dir, SAFE_WEIGHTS_NAME)
|
||||||
|
with safe_open(path_to_checkpoint, framework="pt", device="cpu") as f:
|
||||||
|
state_dict: Dict[str, torch.Tensor] = {key: f.get_tensor(key) for key in f.keys()}
|
||||||
|
else:
|
||||||
|
path_to_checkpoint = os.path.join(output_dir, WEIGHTS_NAME)
|
||||||
|
state_dict: Dict[str, torch.Tensor] = torch.load(path_to_checkpoint, map_location="cpu")
|
||||||
|
|
||||||
|
decoder_state_dict = {}
|
||||||
|
v_head_state_dict = {}
|
||||||
|
for name, param in state_dict.items():
|
||||||
|
if name.startswith("v_head."):
|
||||||
|
v_head_state_dict[name] = param
|
||||||
|
else:
|
||||||
|
decoder_state_dict[name.replace("pretrained_model.", "")] = param
|
||||||
|
|
||||||
|
os.remove(path_to_checkpoint)
|
||||||
|
model.pretrained_model.save_pretrained(
|
||||||
|
output_dir, state_dict=decoder_state_dict or None, safe_serialization=safe_serialization
|
||||||
|
)
|
||||||
|
|
||||||
|
if safe_serialization:
|
||||||
|
save_file(v_head_state_dict, os.path.join(output_dir, V_HEAD_SAFE_WEIGHTS_NAME), metadata={"format": "pt"})
|
||||||
|
else:
|
||||||
|
torch.save(v_head_state_dict, os.path.join(output_dir, V_HEAD_WEIGHTS_NAME))
|
||||||
|
|
||||||
|
logger.info("Value head model saved at: {}".format(output_dir))
|
||||||
|
|
||||||
|
|
||||||
def get_current_device() -> torch.device:
|
def get_current_device() -> torch.device:
|
||||||
import accelerate
|
r"""
|
||||||
if accelerate.utils.is_xpu_available():
|
Gets the current available device.
|
||||||
|
"""
|
||||||
|
if is_torch_xpu_available():
|
||||||
device = "xpu:{}".format(os.environ.get("LOCAL_RANK", "0"))
|
device = "xpu:{}".format(os.environ.get("LOCAL_RANK", "0"))
|
||||||
elif accelerate.utils.is_npu_available():
|
elif is_torch_npu_available():
|
||||||
device = "npu:{}".format(os.environ.get("LOCAL_RANK", "0"))
|
device = "npu:{}".format(os.environ.get("LOCAL_RANK", "0"))
|
||||||
elif torch.cuda.is_available():
|
elif is_torch_mps_available():
|
||||||
|
device = "mps:{}".format(os.environ.get("LOCAL_RANK", "0"))
|
||||||
|
elif is_torch_cuda_available():
|
||||||
device = "cuda:{}".format(os.environ.get("LOCAL_RANK", "0"))
|
device = "cuda:{}".format(os.environ.get("LOCAL_RANK", "0"))
|
||||||
else:
|
else:
|
||||||
device = "cpu"
|
device = "cpu"
|
||||||
@@ -81,6 +162,16 @@ def get_current_device() -> torch.device:
|
|||||||
return torch.device(device)
|
return torch.device(device)
|
||||||
|
|
||||||
|
|
||||||
|
def get_device_count() -> int:
|
||||||
|
r"""
|
||||||
|
Gets the number of available GPU devices.
|
||||||
|
"""
|
||||||
|
if not torch.cuda.is_available():
|
||||||
|
return 0
|
||||||
|
|
||||||
|
return torch.cuda.device_count()
|
||||||
|
|
||||||
|
|
||||||
def get_logits_processor() -> "LogitsProcessorList":
|
def get_logits_processor() -> "LogitsProcessorList":
|
||||||
r"""
|
r"""
|
||||||
Gets logits processor that removes NaN and Inf logits.
|
Gets logits processor that removes NaN and Inf logits.
|
||||||
@@ -102,6 +193,13 @@ def infer_optim_dtype(model_dtype: torch.dtype) -> torch.dtype:
|
|||||||
return torch.float32
|
return torch.float32
|
||||||
|
|
||||||
|
|
||||||
|
def has_tokenized_data(path: os.PathLike) -> bool:
|
||||||
|
r"""
|
||||||
|
Checks if the path has a tokenized dataset.
|
||||||
|
"""
|
||||||
|
return os.path.isdir(path) and len(os.listdir(path)) > 0
|
||||||
|
|
||||||
|
|
||||||
def torch_gc() -> None:
|
def torch_gc() -> None:
|
||||||
r"""
|
r"""
|
||||||
Collects GPU memory.
|
Collects GPU memory.
|
||||||
@@ -112,18 +210,15 @@ def torch_gc() -> None:
|
|||||||
torch.cuda.ipc_collect()
|
torch.cuda.ipc_collect()
|
||||||
|
|
||||||
|
|
||||||
def try_download_model_from_ms(model_args: "ModelArguments") -> None:
|
def try_download_model_from_ms(model_args: "ModelArguments") -> str:
|
||||||
if not use_modelscope() or os.path.exists(model_args.model_name_or_path):
|
if not use_modelscope() or os.path.exists(model_args.model_name_or_path):
|
||||||
return
|
return model_args.model_name_or_path
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from modelscope import snapshot_download # type: ignore
|
from modelscope import snapshot_download
|
||||||
|
|
||||||
revision = "master" if model_args.model_revision == "main" else model_args.model_revision
|
revision = "master" if model_args.model_revision == "main" else model_args.model_revision
|
||||||
model_args.model_name_or_path = snapshot_download(
|
return snapshot_download(model_args.model_name_or_path, revision=revision, cache_dir=model_args.cache_dir)
|
||||||
model_args.model_name_or_path,
|
|
||||||
revision=revision,
|
|
||||||
cache_dir=model_args.cache_dir
|
|
||||||
)
|
|
||||||
except ImportError:
|
except ImportError:
|
||||||
raise ImportError("Please install modelscope via `pip install modelscope -U`")
|
raise ImportError("Please install modelscope via `pip install modelscope -U`")
|
||||||
|
|
||||||
|
|||||||
@@ -2,48 +2,60 @@ import importlib.metadata
|
|||||||
import importlib.util
|
import importlib.util
|
||||||
|
|
||||||
|
|
||||||
def is_package_available(name: str) -> bool:
|
def _is_package_available(name: str) -> bool:
|
||||||
return importlib.util.find_spec(name) is not None
|
return importlib.util.find_spec(name) is not None
|
||||||
|
|
||||||
|
|
||||||
def get_package_version(name: str) -> str:
|
def _get_package_version(name: str) -> str:
|
||||||
try:
|
try:
|
||||||
return importlib.metadata.version(name)
|
return importlib.metadata.version(name)
|
||||||
except:
|
except Exception:
|
||||||
return "0.0.0"
|
return "0.0.0"
|
||||||
|
|
||||||
|
|
||||||
def is_fastapi_availble():
|
def is_fastapi_availble():
|
||||||
return is_package_available("fastapi")
|
return _is_package_available("fastapi")
|
||||||
|
|
||||||
|
|
||||||
def is_flash_attn2_available():
|
def is_flash_attn2_available():
|
||||||
return is_package_available("flash_attn") and get_package_version("flash_attn").startswith("2")
|
return _is_package_available("flash_attn") and _get_package_version("flash_attn").startswith("2")
|
||||||
|
|
||||||
|
|
||||||
|
def is_galore_available():
|
||||||
|
return _is_package_available("galore_torch")
|
||||||
|
|
||||||
|
|
||||||
def is_jieba_available():
|
def is_jieba_available():
|
||||||
return is_package_available("jieba")
|
return _is_package_available("jieba")
|
||||||
|
|
||||||
|
|
||||||
def is_matplotlib_available():
|
def is_matplotlib_available():
|
||||||
return is_package_available("matplotlib")
|
return _is_package_available("matplotlib")
|
||||||
|
|
||||||
|
|
||||||
def is_nltk_available():
|
def is_nltk_available():
|
||||||
return is_package_available("nltk")
|
return _is_package_available("nltk")
|
||||||
|
|
||||||
|
|
||||||
def is_requests_available():
|
def is_requests_available():
|
||||||
return is_package_available("requests")
|
return _is_package_available("requests")
|
||||||
|
|
||||||
|
|
||||||
def is_rouge_available():
|
def is_rouge_available():
|
||||||
return is_package_available("rouge_chinese")
|
return _is_package_available("rouge_chinese")
|
||||||
|
|
||||||
|
|
||||||
def is_starlette_available():
|
def is_starlette_available():
|
||||||
return is_package_available("sse_starlette")
|
return _is_package_available("sse_starlette")
|
||||||
|
|
||||||
|
|
||||||
|
def is_unsloth_available():
|
||||||
|
return _is_package_available("unsloth")
|
||||||
|
|
||||||
|
|
||||||
def is_uvicorn_available():
|
def is_uvicorn_available():
|
||||||
return is_package_available("uvicorn")
|
return _is_package_available("uvicorn")
|
||||||
|
|
||||||
|
|
||||||
|
def is_vllm_available():
|
||||||
|
return _is_package_available("vllm")
|
||||||
|
|||||||
@@ -1,224 +1,198 @@
|
|||||||
import math
|
import math
|
||||||
|
from typing import Optional, Tuple
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from typing import Optional, Tuple
|
from transformers.models.llama.modeling_llama import (
|
||||||
|
Cache,
|
||||||
|
LlamaAttention,
|
||||||
|
LlamaFlashAttention2,
|
||||||
|
apply_rotary_pos_emb,
|
||||||
|
repeat_kv,
|
||||||
|
)
|
||||||
from transformers.utils import logging
|
from transformers.utils import logging
|
||||||
from transformers.models.llama.modeling_llama import LlamaAttention, apply_rotary_pos_emb
|
from transformers.utils.versions import require_version
|
||||||
|
|
||||||
try:
|
|
||||||
from transformers.models.llama.modeling_llama import repeat_kv
|
|
||||||
except ImportError:
|
|
||||||
print("Please upgrade `transformers`.")
|
|
||||||
|
|
||||||
from llmtuner.extras.packages import is_flash_attn2_available
|
|
||||||
|
|
||||||
|
|
||||||
if is_flash_attn2_available():
|
|
||||||
from flash_attn import flash_attn_func, flash_attn_varlen_func # type: ignore
|
|
||||||
from flash_attn.bert_padding import pad_input, unpad_input # type: ignore
|
|
||||||
|
|
||||||
|
|
||||||
logger = logging.get_logger(__name__)
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
# Modified from: https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py
|
# Modified from:
|
||||||
class LlamaShiftShortAttention(LlamaAttention):
|
# https://github.com/huggingface/transformers/blob/v4.39.1/src/transformers/models/llama/modeling_llama.py
|
||||||
|
def llama_torch_attn_forward(
|
||||||
|
self: "LlamaAttention",
|
||||||
|
hidden_states: torch.Tensor,
|
||||||
|
attention_mask: Optional[torch.Tensor] = None,
|
||||||
|
position_ids: Optional[torch.LongTensor] = None,
|
||||||
|
past_key_value: Optional["Cache"] = None,
|
||||||
|
output_attentions: bool = False,
|
||||||
|
cache_position: Optional[torch.LongTensor] = None,
|
||||||
|
**kwargs,
|
||||||
|
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||||
|
bsz, q_len, _ = hidden_states.size()
|
||||||
|
|
||||||
def forward(
|
query_states = self.q_proj(hidden_states)
|
||||||
self,
|
key_states = self.k_proj(hidden_states)
|
||||||
hidden_states: torch.Tensor,
|
value_states = self.v_proj(hidden_states)
|
||||||
attention_mask: Optional[torch.Tensor] = None,
|
|
||||||
position_ids: Optional[torch.LongTensor] = None,
|
|
||||||
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
|
||||||
output_attentions: bool = False,
|
|
||||||
use_cache: bool = False,
|
|
||||||
**kwargs
|
|
||||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
|
||||||
bsz, q_len, _ = hidden_states.size()
|
|
||||||
|
|
||||||
query_states = self.q_proj(hidden_states)
|
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
|
||||||
key_states = self.k_proj(hidden_states)
|
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
||||||
value_states = self.v_proj(hidden_states)
|
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
||||||
|
|
||||||
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
|
past_key_value = getattr(self, "past_key_value", past_key_value)
|
||||||
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
cos, sin = self.rotary_emb(value_states, position_ids)
|
||||||
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
|
||||||
|
|
||||||
kv_seq_len = key_states.shape[-2]
|
if past_key_value is not None:
|
||||||
if past_key_value is not None:
|
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
|
||||||
kv_seq_len += past_key_value[0].shape[-2]
|
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
|
||||||
|
|
||||||
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
|
key_states = repeat_kv(key_states, self.num_key_value_groups)
|
||||||
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
|
value_states = repeat_kv(value_states, self.num_key_value_groups)
|
||||||
|
|
||||||
if past_key_value is not None: # reuse k, v, self_attention
|
if getattr(self.config, "group_size_ratio", None) and self.training: # shift
|
||||||
key_states = torch.cat([past_key_value[0], key_states], dim=2)
|
groupsz = int(q_len * getattr(self.config, "group_size_ratio"))
|
||||||
value_states = torch.cat([past_key_value[1], value_states], dim=2)
|
assert q_len % groupsz == 0, "q_len {} should be divisible by group size {}.".format(q_len, groupsz)
|
||||||
|
num_groups = q_len // groupsz
|
||||||
|
|
||||||
past_key_value = (key_states, value_states) if use_cache else None
|
def shift(state: torch.Tensor) -> torch.Tensor:
|
||||||
|
state = state.transpose(1, 2) # output: (bsz, seq_len, n_heads, head_dim)
|
||||||
if getattr(self, "num_key_value_groups"):
|
state = torch.cat(
|
||||||
key_states = repeat_kv(key_states, self.num_key_value_groups)
|
(state[:, :, : self.num_heads // 2], state[:, :, self.num_heads // 2 :].roll(-groupsz // 2, dims=1)),
|
||||||
value_states = repeat_kv(value_states, self.num_key_value_groups)
|
dim=2,
|
||||||
|
|
||||||
if getattr(self.config, "group_size_ratio", None) and self.training: # shift
|
|
||||||
groupsz = int(q_len * getattr(self.config, "group_size_ratio"))
|
|
||||||
assert q_len % groupsz == 0, "q_len {} should be divisible by group size {}.".format(q_len, groupsz)
|
|
||||||
num_groups = q_len // groupsz
|
|
||||||
def shift(state: torch.Tensor) -> torch.Tensor:
|
|
||||||
state = state.transpose(1, 2) # output: (bsz, seq_len, n_heads, head_dim)
|
|
||||||
state = torch.cat((
|
|
||||||
state[:, :, :self.num_heads//2], state[:, :, self.num_heads//2:].roll(-groupsz//2, dims=1)
|
|
||||||
), dim=2)
|
|
||||||
return state.reshape(bsz * num_groups, groupsz, self.num_heads, self.head_dim).transpose(1, 2)
|
|
||||||
|
|
||||||
query_states, key_states, value_states = shift(query_states), shift(key_states), shift(value_states)
|
|
||||||
if attention_mask is not None:
|
|
||||||
attention_mask = attention_mask[:, :, :groupsz, :groupsz].repeat(num_groups, 1, 1, 1)
|
|
||||||
|
|
||||||
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
|
|
||||||
|
|
||||||
if attention_mask is not None:
|
|
||||||
attn_weights = attn_weights + attention_mask
|
|
||||||
|
|
||||||
# upcast attention to fp32
|
|
||||||
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
|
|
||||||
attn_output = torch.matmul(attn_weights, value_states) # (bsz, :, seq_len, :) or (bsz*n_group, :, groupsz, :)
|
|
||||||
attn_output = attn_output.transpose(1, 2).contiguous()
|
|
||||||
|
|
||||||
if getattr(self.config, "group_size_ratio", None) and self.training: # shift back
|
|
||||||
attn_output.reshape(bsz, q_len, self.num_heads, self.head_dim)
|
|
||||||
attn_output = torch.cat((
|
|
||||||
attn_output[:, :, :self.num_heads//2], attn_output[:, :, self.num_heads//2:].roll(groupsz//2, dims=1)
|
|
||||||
))
|
|
||||||
|
|
||||||
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
|
|
||||||
attn_output = self.o_proj(attn_output)
|
|
||||||
|
|
||||||
if not output_attentions:
|
|
||||||
attn_weights = None
|
|
||||||
|
|
||||||
return attn_output, attn_weights, past_key_value
|
|
||||||
|
|
||||||
|
|
||||||
class LlamaFlashAttention2(LlamaAttention):
|
|
||||||
|
|
||||||
def forward(
|
|
||||||
self,
|
|
||||||
hidden_states: torch.Tensor,
|
|
||||||
attention_mask: Optional[torch.Tensor] = None,
|
|
||||||
position_ids: Optional[torch.LongTensor] = None,
|
|
||||||
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
|
||||||
output_attentions: bool = False,
|
|
||||||
use_cache: bool = False,
|
|
||||||
**kwargs
|
|
||||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
|
||||||
# LlamaFlashAttention2 attention does not support output_attentions
|
|
||||||
output_attentions = False
|
|
||||||
|
|
||||||
bsz, q_len, _ = hidden_states.size()
|
|
||||||
|
|
||||||
query_states = self.q_proj(hidden_states)
|
|
||||||
key_states = self.k_proj(hidden_states)
|
|
||||||
value_states = self.v_proj(hidden_states)
|
|
||||||
|
|
||||||
# FlashAttention requires the input to have the shape (bsz, seq_len, n_heads, head_dim)
|
|
||||||
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
|
|
||||||
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
|
||||||
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
|
||||||
|
|
||||||
kv_seq_len = key_states.shape[-2]
|
|
||||||
if past_key_value is not None:
|
|
||||||
kv_seq_len += past_key_value[0].shape[-2]
|
|
||||||
|
|
||||||
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
|
|
||||||
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
|
|
||||||
|
|
||||||
if past_key_value is not None: # reuse k, v, self_attention
|
|
||||||
key_states = torch.cat([past_key_value[0], key_states], dim=2)
|
|
||||||
value_states = torch.cat([past_key_value[1], value_states], dim=2)
|
|
||||||
|
|
||||||
past_key_value = (key_states, value_states) if use_cache else None
|
|
||||||
|
|
||||||
# cast to half precision
|
|
||||||
input_dtype = query_states.dtype
|
|
||||||
if input_dtype == torch.float32:
|
|
||||||
logger.warning_once("The input hidden states seems to be silently casted in float32.")
|
|
||||||
query_states = query_states.to(self.config.torch_dtype)
|
|
||||||
key_states = key_states.to(self.config.torch_dtype)
|
|
||||||
value_states = value_states.to(self.config.torch_dtype)
|
|
||||||
|
|
||||||
if getattr(self, "num_key_value_groups", None):
|
|
||||||
key_states = repeat_kv(key_states, self.num_key_value_groups)
|
|
||||||
value_states = repeat_kv(value_states, self.num_key_value_groups)
|
|
||||||
|
|
||||||
query_states = query_states.transpose(1, 2) # (bsz, seq_len, n_heads, head_dim)
|
|
||||||
key_states = key_states.transpose(1, 2) # (bsz, seq_len, n_heads, head_dim)
|
|
||||||
value_states = value_states.transpose(1, 2) # (bsz, seq_len, n_heads, head_dim)
|
|
||||||
|
|
||||||
if getattr(self.config, "group_size_ratio", None) and self.training: # shift
|
|
||||||
groupsz = int(q_len * getattr(self.config, "group_size_ratio"))
|
|
||||||
assert q_len % groupsz == 0, "q_len {} should be divisible by group size {}.".format(q_len, groupsz)
|
|
||||||
num_groups = q_len // groupsz
|
|
||||||
def shift(state: torch.Tensor) -> torch.Tensor:
|
|
||||||
state = torch.cat((
|
|
||||||
state[:, :, :self.num_heads//2], state[:, :, self.num_heads//2:].roll(-groupsz//2, dims=1)
|
|
||||||
), dim=2)
|
|
||||||
return state.reshape(bsz * num_groups, groupsz, self.num_heads, self.head_dim)
|
|
||||||
|
|
||||||
query_states, key_states, value_states = shift(query_states), shift(key_states), shift(value_states)
|
|
||||||
if attention_mask is not None:
|
|
||||||
attention_mask = attention_mask.reshape(bsz * num_groups, groupsz)
|
|
||||||
|
|
||||||
if attention_mask is not None:
|
|
||||||
logger.warning_once("Padded sequences are less efficient in FlashAttention.")
|
|
||||||
# -q_len: assumes left padding when q_len != kv_len
|
|
||||||
unpadded_q, indices_q, cu_seqlens_q, max_seqlen_q = unpad_input(query_states, attention_mask[:, -q_len:])
|
|
||||||
unpadded_k, _, cu_seqlens_k, max_seqlen_k = unpad_input(key_states, attention_mask)
|
|
||||||
unpadded_v, _, _, _ = unpad_input(value_states, attention_mask)
|
|
||||||
attn_output_unpad = flash_attn_varlen_func(
|
|
||||||
unpadded_q,
|
|
||||||
unpadded_k,
|
|
||||||
unpadded_v,
|
|
||||||
cu_seqlens_q=cu_seqlens_q,
|
|
||||||
cu_seqlens_k=cu_seqlens_k,
|
|
||||||
max_seqlen_q=max_seqlen_q,
|
|
||||||
max_seqlen_k=max_seqlen_k,
|
|
||||||
dropout_p=0.0,
|
|
||||||
softmax_scale=None,
|
|
||||||
causal=True,
|
|
||||||
)
|
)
|
||||||
attn_output = pad_input(attn_output_unpad, indices_q, bsz, q_len)
|
return state.reshape(bsz * num_groups, groupsz, self.num_heads, self.head_dim).transpose(1, 2)
|
||||||
|
|
||||||
|
query_states, key_states, value_states = shift(query_states), shift(key_states), shift(value_states)
|
||||||
|
if attention_mask is not None:
|
||||||
|
attention_mask = attention_mask[:, :, :groupsz, :groupsz].repeat(num_groups, 1, 1, 1)
|
||||||
|
|
||||||
|
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
|
||||||
|
|
||||||
|
if attention_mask is not None:
|
||||||
|
attn_weights = attn_weights + attention_mask
|
||||||
|
|
||||||
|
# upcast attention to fp32
|
||||||
|
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
|
||||||
|
attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
|
||||||
|
attn_output = torch.matmul(attn_weights, value_states) # (bsz, :, seq_len, :) or (bsz*n_group, :, groupsz, :)
|
||||||
|
attn_output = attn_output.transpose(1, 2).contiguous()
|
||||||
|
|
||||||
|
if getattr(self.config, "group_size_ratio", None) and self.training: # shift back
|
||||||
|
attn_output.reshape(bsz, q_len, self.num_heads, self.head_dim)
|
||||||
|
attn_output = torch.cat(
|
||||||
|
(
|
||||||
|
attn_output[:, :, : self.num_heads // 2],
|
||||||
|
attn_output[:, :, self.num_heads // 2 :].roll(groupsz // 2, dims=1),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
|
||||||
|
attn_output = self.o_proj(attn_output)
|
||||||
|
|
||||||
|
if not output_attentions:
|
||||||
|
attn_weights = None
|
||||||
|
|
||||||
|
return attn_output, attn_weights, past_key_value
|
||||||
|
|
||||||
|
|
||||||
|
# Modified from:
|
||||||
|
# https://github.com/huggingface/transformers/blob/v4.39.1/src/transformers/models/llama/modeling_llama.py
|
||||||
|
def llama_flash_attn_forward(
|
||||||
|
self: "LlamaFlashAttention2",
|
||||||
|
hidden_states: torch.Tensor,
|
||||||
|
attention_mask: Optional[torch.Tensor] = None,
|
||||||
|
position_ids: Optional[torch.LongTensor] = None,
|
||||||
|
past_key_value: Optional["Cache"] = None,
|
||||||
|
output_attentions: bool = False,
|
||||||
|
cache_position: Optional[torch.LongTensor] = None,
|
||||||
|
**kwargs,
|
||||||
|
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||||
|
# LlamaFlashAttention2 attention does not support output_attentions
|
||||||
|
output_attentions = False
|
||||||
|
|
||||||
|
bsz, q_len, _ = hidden_states.size()
|
||||||
|
|
||||||
|
query_states = self.q_proj(hidden_states)
|
||||||
|
key_states = self.k_proj(hidden_states)
|
||||||
|
value_states = self.v_proj(hidden_states)
|
||||||
|
|
||||||
|
# FlashAttention requires the input to have the shape (bsz, seq_len, n_heads, head_dim)
|
||||||
|
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
|
||||||
|
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
||||||
|
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
||||||
|
|
||||||
|
cos, sin = self.rotary_emb(value_states, position_ids)
|
||||||
|
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
|
||||||
|
|
||||||
|
past_key_value = getattr(self, "past_key_value", past_key_value)
|
||||||
|
|
||||||
|
if past_key_value is not None:
|
||||||
|
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
|
||||||
|
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
|
||||||
|
|
||||||
|
key_states = repeat_kv(key_states, self.num_key_value_groups)
|
||||||
|
value_states = repeat_kv(value_states, self.num_key_value_groups)
|
||||||
|
|
||||||
|
query_states = query_states.transpose(1, 2) # (bsz, seq_len, n_heads, head_dim)
|
||||||
|
key_states = key_states.transpose(1, 2) # (bsz, seq_len, n_heads, head_dim)
|
||||||
|
value_states = value_states.transpose(1, 2) # (bsz, seq_len, n_heads, head_dim)
|
||||||
|
|
||||||
|
dropout_rate = self.attention_dropout if self.training else 0.0
|
||||||
|
|
||||||
|
input_dtype = query_states.dtype
|
||||||
|
if input_dtype == torch.float32:
|
||||||
|
if torch.is_autocast_enabled():
|
||||||
|
target_dtype = torch.get_autocast_gpu_dtype()
|
||||||
|
elif hasattr(self.config, "_pre_quantization_dtype"):
|
||||||
|
target_dtype = self.config._pre_quantization_dtype
|
||||||
else:
|
else:
|
||||||
attn_output = flash_attn_func(
|
target_dtype = self.q_proj.weight.dtype
|
||||||
query_states, key_states, value_states, 0.0, softmax_scale=None, causal=True
|
|
||||||
|
logger.warning_once("The input hidden states seems to be silently casted in float32.")
|
||||||
|
query_states = query_states.to(target_dtype)
|
||||||
|
key_states = key_states.to(target_dtype)
|
||||||
|
value_states = value_states.to(target_dtype)
|
||||||
|
|
||||||
|
if getattr(self.config, "group_size_ratio", None) and self.training: # shift
|
||||||
|
groupsz = int(q_len * getattr(self.config, "group_size_ratio"))
|
||||||
|
assert q_len % groupsz == 0, "q_len {} should be divisible by group size {}.".format(q_len, groupsz)
|
||||||
|
num_groups = q_len // groupsz
|
||||||
|
|
||||||
|
def shift(state: torch.Tensor) -> torch.Tensor:
|
||||||
|
state = torch.cat(
|
||||||
|
(state[:, :, : self.num_heads // 2], state[:, :, self.num_heads // 2 :].roll(-groupsz // 2, dims=1)),
|
||||||
|
dim=2,
|
||||||
)
|
)
|
||||||
|
return state.reshape(bsz * num_groups, groupsz, self.num_heads, self.head_dim)
|
||||||
|
|
||||||
if getattr(self.config, "group_size_ratio", None) and self.training: # shift back
|
query_states, key_states, value_states = shift(query_states), shift(key_states), shift(value_states)
|
||||||
attn_output.reshape(bsz, q_len, self.num_heads, self.head_dim)
|
if attention_mask is not None:
|
||||||
attn_output = torch.cat((
|
attention_mask = attention_mask[:, :, :groupsz, :groupsz].repeat(num_groups, 1, 1, 1)
|
||||||
attn_output[:, :, :self.num_heads//2], attn_output[:, :, self.num_heads//2:].roll(groupsz//2, dims=1)
|
|
||||||
))
|
|
||||||
|
|
||||||
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
|
attn_output: torch.Tensor = self._flash_attention_forward(
|
||||||
attn_output = self.o_proj(attn_output)
|
query_states, key_states, value_states, attention_mask, q_len, dropout=dropout_rate
|
||||||
|
)
|
||||||
|
|
||||||
if not output_attentions:
|
if getattr(self.config, "group_size_ratio", None) and self.training: # shift back
|
||||||
attn_weights = None
|
attn_output.reshape(bsz, q_len, self.num_heads, self.head_dim)
|
||||||
|
attn_output = torch.cat(
|
||||||
|
(
|
||||||
|
attn_output[:, :, : self.num_heads // 2],
|
||||||
|
attn_output[:, :, self.num_heads // 2 :].roll(groupsz // 2, dims=1),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
return attn_output, attn_weights, past_key_value
|
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
|
||||||
|
attn_output = self.o_proj(attn_output)
|
||||||
|
|
||||||
|
if not output_attentions:
|
||||||
|
attn_weights = None
|
||||||
|
|
||||||
|
return attn_output, attn_weights, past_key_value
|
||||||
|
|
||||||
|
|
||||||
# Disable the transformation of the attention mask in LlamaModel as flash attention
|
def apply_llama_patch() -> None:
|
||||||
# takes a boolean padding_mask. Fills in the past kv length for use in forward.
|
require_version("transformers==4.39.3", "To fix: pip install transformers==4.39.3")
|
||||||
def _prepare_decoder_attention_mask(
|
LlamaAttention.forward = llama_torch_attn_forward
|
||||||
self,
|
LlamaFlashAttention2.forward = llama_flash_attn_forward
|
||||||
attention_mask: torch.Tensor,
|
|
||||||
input_shape: torch.Tensor,
|
|
||||||
inputs_embeds: torch.Tensor,
|
|
||||||
past_key_values_length: int
|
|
||||||
) -> torch.Tensor:
|
|
||||||
if attention_mask is not None and torch.all(attention_mask):
|
|
||||||
return None # This uses the faster call when training with full samples
|
|
||||||
|
|
||||||
return attention_mask
|
|
||||||
|
|||||||
@@ -1,11 +1,13 @@
|
|||||||
import os
|
|
||||||
import math
|
|
||||||
import json
|
import json
|
||||||
from typing import List, Optional
|
import math
|
||||||
|
import os
|
||||||
|
from typing import List
|
||||||
|
|
||||||
from transformers.trainer import TRAINER_STATE_NAME
|
from transformers.trainer import TRAINER_STATE_NAME
|
||||||
|
|
||||||
from llmtuner.extras.logging import get_logger
|
from .logging import get_logger
|
||||||
from llmtuner.extras.packages import is_matplotlib_available
|
from .packages import is_matplotlib_available
|
||||||
|
|
||||||
|
|
||||||
if is_matplotlib_available():
|
if is_matplotlib_available():
|
||||||
import matplotlib.pyplot as plt
|
import matplotlib.pyplot as plt
|
||||||
@@ -20,7 +22,7 @@ def smooth(scalars: List[float]) -> List[float]:
|
|||||||
"""
|
"""
|
||||||
last = scalars[0]
|
last = scalars[0]
|
||||||
smoothed = list()
|
smoothed = list()
|
||||||
weight = 1.8 * (1 / (1 + math.exp(-0.05 * len(scalars))) - 0.5) # a sigmoid function
|
weight = 1.8 * (1 / (1 + math.exp(-0.05 * len(scalars))) - 0.5) # a sigmoid function
|
||||||
for next_val in scalars:
|
for next_val in scalars:
|
||||||
smoothed_val = last * weight + (1 - weight) * next_val
|
smoothed_val = last * weight + (1 - weight) * next_val
|
||||||
smoothed.append(smoothed_val)
|
smoothed.append(smoothed_val)
|
||||||
@@ -28,8 +30,7 @@ def smooth(scalars: List[float]) -> List[float]:
|
|||||||
return smoothed
|
return smoothed
|
||||||
|
|
||||||
|
|
||||||
def plot_loss(save_dictionary: os.PathLike, keys: Optional[List[str]] = ["loss"]) -> None:
|
def plot_loss(save_dictionary: os.PathLike, keys: List[str] = ["loss"]) -> None:
|
||||||
|
|
||||||
with open(os.path.join(save_dictionary, TRAINER_STATE_NAME), "r", encoding="utf-8") as f:
|
with open(os.path.join(save_dictionary, TRAINER_STATE_NAME), "r", encoding="utf-8") as f:
|
||||||
data = json.load(f)
|
data = json.load(f)
|
||||||
|
|
||||||
@@ -45,11 +46,12 @@ def plot_loss(save_dictionary: os.PathLike, keys: Optional[List[str]] = ["loss"]
|
|||||||
continue
|
continue
|
||||||
|
|
||||||
plt.figure()
|
plt.figure()
|
||||||
plt.plot(steps, metrics, alpha=0.4, label="original")
|
plt.plot(steps, metrics, color="#1f77b4", alpha=0.4, label="original")
|
||||||
plt.plot(steps, smooth(metrics), label="smoothed")
|
plt.plot(steps, smooth(metrics), color="#1f77b4", label="smoothed")
|
||||||
plt.title("training {} of {}".format(key, save_dictionary))
|
plt.title("training {} of {}".format(key, save_dictionary))
|
||||||
plt.xlabel("step")
|
plt.xlabel("step")
|
||||||
plt.ylabel(key)
|
plt.ylabel(key)
|
||||||
plt.legend()
|
plt.legend()
|
||||||
plt.savefig(os.path.join(save_dictionary, "training_{}.png".format(key)), format="png", dpi=100)
|
figure_path = os.path.join(save_dictionary, "training_{}.png".format(key.replace("/", "_")))
|
||||||
print("Figure saved:", os.path.join(save_dictionary, "training_{}.png".format(key)))
|
plt.savefig(figure_path, format="png", dpi=100)
|
||||||
|
print("Figure saved at:", figure_path)
|
||||||
|
|||||||
@@ -3,3 +3,16 @@ from .evaluation_args import EvaluationArguments
|
|||||||
from .finetuning_args import FinetuningArguments
|
from .finetuning_args import FinetuningArguments
|
||||||
from .generating_args import GeneratingArguments
|
from .generating_args import GeneratingArguments
|
||||||
from .model_args import ModelArguments
|
from .model_args import ModelArguments
|
||||||
|
from .parser import get_eval_args, get_infer_args, get_train_args
|
||||||
|
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"DataArguments",
|
||||||
|
"EvaluationArguments",
|
||||||
|
"FinetuningArguments",
|
||||||
|
"GeneratingArguments",
|
||||||
|
"ModelArguments",
|
||||||
|
"get_eval_args",
|
||||||
|
"get_infer_args",
|
||||||
|
"get_train_args",
|
||||||
|
]
|
||||||
|
|||||||
@@ -1,38 +1,5 @@
|
|||||||
import os
|
|
||||||
import json
|
|
||||||
from typing import List, Literal, Optional
|
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
|
from typing import Literal, Optional
|
||||||
|
|
||||||
DATA_CONFIG = "dataset_info.json"
|
|
||||||
|
|
||||||
|
|
||||||
def use_modelscope() -> bool:
|
|
||||||
return bool(int(os.environ.get("USE_MODELSCOPE_HUB", "0")))
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class DatasetAttr:
|
|
||||||
|
|
||||||
load_from: Literal["hf_hub", "ms_hub", "script", "file"]
|
|
||||||
dataset_name: Optional[str] = None
|
|
||||||
dataset_sha1: Optional[str] = None
|
|
||||||
subset: Optional[str] = None
|
|
||||||
folder: Optional[str] = None
|
|
||||||
ranking: Optional[bool] = False
|
|
||||||
formatting: Optional[Literal["alpaca", "sharegpt"]] = "alpaca"
|
|
||||||
|
|
||||||
prompt: Optional[str] = "instruction"
|
|
||||||
query: Optional[str] = "input"
|
|
||||||
response: Optional[str] = "output"
|
|
||||||
history: Optional[str] = None
|
|
||||||
messages: Optional[str] = "conversations"
|
|
||||||
role: Optional[str] = "from"
|
|
||||||
content: Optional[str] = "value"
|
|
||||||
system: Optional[str] = None
|
|
||||||
|
|
||||||
def __repr__(self) -> str:
|
|
||||||
return self.dataset_name
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@@ -40,81 +7,86 @@ class DataArguments:
|
|||||||
r"""
|
r"""
|
||||||
Arguments pertaining to what data we are going to input our model for training and evaluation.
|
Arguments pertaining to what data we are going to input our model for training and evaluation.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
template: Optional[str] = field(
|
template: Optional[str] = field(
|
||||||
default=None,
|
default=None,
|
||||||
metadata={"help": "Which template to use for constructing prompts in training and inference."}
|
metadata={"help": "Which template to use for constructing prompts in training and inference."},
|
||||||
)
|
)
|
||||||
dataset: Optional[str] = field(
|
dataset: Optional[str] = field(
|
||||||
default=None,
|
default=None,
|
||||||
metadata={"help": "The name of provided dataset(s) to use. Use commas to separate multiple datasets."}
|
metadata={"help": "The name of provided dataset(s) to use. Use commas to separate multiple datasets."},
|
||||||
)
|
)
|
||||||
dataset_dir: Optional[str] = field(
|
dataset_dir: str = field(
|
||||||
default="data",
|
default="data",
|
||||||
metadata={"help": "Path to the folder containing the datasets."}
|
metadata={"help": "Path to the folder containing the datasets."},
|
||||||
)
|
)
|
||||||
split: Optional[str] = field(
|
split: str = field(
|
||||||
default="train",
|
default="train",
|
||||||
metadata={"help": "Which dataset split to use for training and evaluation."}
|
metadata={"help": "Which dataset split to use for training and evaluation."},
|
||||||
)
|
)
|
||||||
cutoff_len: Optional[int] = field(
|
cutoff_len: int = field(
|
||||||
default=1024,
|
default=1024,
|
||||||
metadata={"help": "The maximum length of the model inputs after tokenization."}
|
metadata={"help": "The cutoff length of the model inputs after tokenization."},
|
||||||
)
|
)
|
||||||
reserved_label_len: Optional[int] = field(
|
reserved_label_len: int = field(
|
||||||
default=1,
|
default=1,
|
||||||
metadata={"help": "The maximum length reserved for label after tokenization."}
|
metadata={"help": "The minimum cutoff length reserved for label after tokenization."},
|
||||||
)
|
)
|
||||||
train_on_prompt: Optional[bool] = field(
|
train_on_prompt: bool = field(
|
||||||
default=False,
|
default=False,
|
||||||
metadata={"help": "Whether to disable the mask on the prompt or not."}
|
metadata={"help": "Whether to disable the mask on the prompt or not."},
|
||||||
)
|
)
|
||||||
streaming: Optional[bool] = field(
|
streaming: bool = field(
|
||||||
default=False,
|
default=False,
|
||||||
metadata={"help": "Enable dataset streaming."}
|
metadata={"help": "Enable dataset streaming."},
|
||||||
)
|
)
|
||||||
buffer_size: Optional[int] = field(
|
buffer_size: int = field(
|
||||||
default=16384,
|
default=16384,
|
||||||
metadata={"help": "Size of the buffer to randomly sample examples from in dataset streaming."}
|
metadata={"help": "Size of the buffer to randomly sample examples from in dataset streaming."},
|
||||||
)
|
)
|
||||||
mix_strategy: Optional[Literal["concat", "interleave_under", "interleave_over"]] = field(
|
mix_strategy: Literal["concat", "interleave_under", "interleave_over"] = field(
|
||||||
default="concat",
|
default="concat",
|
||||||
metadata={"help": "Strategy to use in dataset mixing (concat/interleave) (undersampling/oversampling)."}
|
metadata={"help": "Strategy to use in dataset mixing (concat/interleave) (undersampling/oversampling)."},
|
||||||
)
|
)
|
||||||
interleave_probs: Optional[str] = field(
|
interleave_probs: Optional[str] = field(
|
||||||
default=None,
|
default=None,
|
||||||
metadata={"help": "Probabilities to sample data from datasets. Use commas to separate multiple datasets."}
|
metadata={"help": "Probabilities to sample data from datasets. Use commas to separate multiple datasets."},
|
||||||
)
|
)
|
||||||
overwrite_cache: Optional[bool] = field(
|
overwrite_cache: bool = field(
|
||||||
default=False,
|
default=False,
|
||||||
metadata={"help": "Overwrite the cached training and evaluation sets."}
|
metadata={"help": "Overwrite the cached training and evaluation sets."},
|
||||||
)
|
)
|
||||||
preprocessing_num_workers: Optional[int] = field(
|
preprocessing_num_workers: Optional[int] = field(
|
||||||
default=None,
|
default=None,
|
||||||
metadata={"help": "The number of processes to use for the preprocessing."}
|
metadata={"help": "The number of processes to use for the pre-processing."},
|
||||||
)
|
)
|
||||||
max_samples: Optional[int] = field(
|
max_samples: Optional[int] = field(
|
||||||
default=None,
|
default=None,
|
||||||
metadata={"help": "For debugging purposes, truncate the number of examples for each dataset."}
|
metadata={"help": "For debugging purposes, truncate the number of examples for each dataset."},
|
||||||
)
|
)
|
||||||
eval_num_beams: Optional[int] = field(
|
eval_num_beams: Optional[int] = field(
|
||||||
default=None,
|
default=None,
|
||||||
metadata={"help": "Number of beams to use for evaluation. This argument will be passed to `model.generate`"}
|
metadata={"help": "Number of beams to use for evaluation. This argument will be passed to `model.generate`"},
|
||||||
)
|
)
|
||||||
ignore_pad_token_for_loss: Optional[bool] = field(
|
ignore_pad_token_for_loss: bool = field(
|
||||||
default=True,
|
default=True,
|
||||||
metadata={"help": "Whether to ignore the tokens corresponding to padded labels in the loss computation or not."}
|
metadata={
|
||||||
|
"help": "Whether or not to ignore the tokens corresponding to padded labels in the loss computation."
|
||||||
|
},
|
||||||
)
|
)
|
||||||
val_size: Optional[float] = field(
|
val_size: float = field(
|
||||||
default=0,
|
default=0.0,
|
||||||
metadata={"help": "Size of the development set, should be an integer or a float in range `[0,1)`."}
|
metadata={"help": "Size of the development set, should be an integer or a float in range `[0,1)`."},
|
||||||
)
|
)
|
||||||
sft_packing: Optional[bool] = field(
|
packing: Optional[bool] = field(
|
||||||
default=False,
|
|
||||||
metadata={"help": "Packing the questions and answers in the supervised fine-tuning stage."}
|
|
||||||
)
|
|
||||||
cache_path: Optional[str] = field(
|
|
||||||
default=None,
|
default=None,
|
||||||
metadata={"help": "Path to save or load the preprocessed datasets."}
|
metadata={
|
||||||
|
"help": "Whether or not to pack the sequences in training. Will automatically enable in pre-training."
|
||||||
|
},
|
||||||
|
)
|
||||||
|
tokenized_path: Optional[str] = field(
|
||||||
|
default=None,
|
||||||
|
metadata={"help": "Path to save or load the tokenized datasets."},
|
||||||
)
|
)
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
@@ -126,67 +98,3 @@ class DataArguments:
|
|||||||
|
|
||||||
if self.streaming and self.max_samples is not None:
|
if self.streaming and self.max_samples is not None:
|
||||||
raise ValueError("`max_samples` is incompatible with `streaming`.")
|
raise ValueError("`max_samples` is incompatible with `streaming`.")
|
||||||
|
|
||||||
if self.streaming and self.cache_path:
|
|
||||||
raise ValueError("`cache_path` is incompatible with `streaming`.")
|
|
||||||
|
|
||||||
def init_for_training(self, seed: int): # support mixing multiple datasets
|
|
||||||
self.seed = seed
|
|
||||||
dataset_names = [ds.strip() for ds in self.dataset.split(",")] if self.dataset is not None else []
|
|
||||||
try:
|
|
||||||
with open(os.path.join(self.dataset_dir, DATA_CONFIG), "r") as f:
|
|
||||||
dataset_info = json.load(f)
|
|
||||||
except Exception as err:
|
|
||||||
if self.dataset is not None:
|
|
||||||
raise ValueError("Cannot open {} due to {}.".format(os.path.join(self.dataset_dir, DATA_CONFIG), str(err)))
|
|
||||||
dataset_info = None
|
|
||||||
|
|
||||||
if self.interleave_probs is not None:
|
|
||||||
self.interleave_probs = [float(prob.strip()) for prob in self.interleave_probs.split(",")]
|
|
||||||
|
|
||||||
self.dataset_list: List[DatasetAttr] = []
|
|
||||||
for name in dataset_names:
|
|
||||||
if name not in dataset_info:
|
|
||||||
raise ValueError("Undefined dataset {} in {}.".format(name, DATA_CONFIG))
|
|
||||||
|
|
||||||
has_hf_url = "hf_hub_url" in dataset_info[name]
|
|
||||||
has_ms_url = "ms_hub_url" in dataset_info[name]
|
|
||||||
|
|
||||||
if has_hf_url or has_ms_url:
|
|
||||||
if (use_modelscope() and has_ms_url) or (not has_hf_url):
|
|
||||||
dataset_attr = DatasetAttr(
|
|
||||||
"ms_hub",
|
|
||||||
dataset_name=dataset_info[name]["ms_hub_url"]
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
dataset_attr = DatasetAttr(
|
|
||||||
"hf_hub",
|
|
||||||
dataset_name=dataset_info[name]["hf_hub_url"]
|
|
||||||
)
|
|
||||||
elif "script_url" in dataset_info[name]:
|
|
||||||
dataset_attr = DatasetAttr(
|
|
||||||
"script",
|
|
||||||
dataset_name=dataset_info[name]["script_url"]
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
dataset_attr = DatasetAttr(
|
|
||||||
"file",
|
|
||||||
dataset_name=dataset_info[name]["file_name"],
|
|
||||||
dataset_sha1=dataset_info[name].get("file_sha1", None)
|
|
||||||
)
|
|
||||||
|
|
||||||
if "columns" in dataset_info[name]:
|
|
||||||
dataset_attr.prompt = dataset_info[name]["columns"].get("prompt", None)
|
|
||||||
dataset_attr.query = dataset_info[name]["columns"].get("query", None)
|
|
||||||
dataset_attr.response = dataset_info[name]["columns"].get("response", None)
|
|
||||||
dataset_attr.history = dataset_info[name]["columns"].get("history", None)
|
|
||||||
dataset_attr.messages = dataset_info[name]["columns"].get("messages", None)
|
|
||||||
dataset_attr.role = dataset_info[name]["columns"].get("role", None)
|
|
||||||
dataset_attr.content = dataset_info[name]["columns"].get("content", None)
|
|
||||||
dataset_attr.system = dataset_info[name]["columns"].get("system", None)
|
|
||||||
|
|
||||||
dataset_attr.subset = dataset_info[name].get("subset", None)
|
|
||||||
dataset_attr.folder = dataset_info[name].get("folder", None)
|
|
||||||
dataset_attr.ranking = dataset_info[name].get("ranking", False)
|
|
||||||
dataset_attr.formatting = dataset_info[name].get("formatting", "alpaca")
|
|
||||||
self.dataset_list.append(dataset_attr)
|
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
import os
|
import os
|
||||||
from typing import Literal, Optional
|
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
|
from typing import Literal, Optional
|
||||||
|
|
||||||
from datasets import DownloadMode
|
from datasets import DownloadMode
|
||||||
|
|
||||||
@@ -10,46 +10,39 @@ class EvaluationArguments:
|
|||||||
r"""
|
r"""
|
||||||
Arguments pertaining to specify the evaluation parameters.
|
Arguments pertaining to specify the evaluation parameters.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
task: str = field(
|
task: str = field(
|
||||||
metadata={"help": "Name of the evaluation task."}
|
metadata={"help": "Name of the evaluation task."},
|
||||||
)
|
)
|
||||||
task_dir: Optional[str] = field(
|
task_dir: str = field(
|
||||||
default="evaluation",
|
default="evaluation",
|
||||||
metadata={"help": "Path to the folder containing the evaluation datasets."}
|
metadata={"help": "Path to the folder containing the evaluation datasets."},
|
||||||
)
|
)
|
||||||
batch_size: Optional[int] = field(
|
batch_size: int = field(
|
||||||
default=4,
|
default=4,
|
||||||
metadata={"help": "The batch size per GPU for evaluation."}
|
metadata={"help": "The batch size per GPU for evaluation."},
|
||||||
)
|
)
|
||||||
seed: Optional[int] = field(
|
seed: int = field(
|
||||||
default=42,
|
default=42,
|
||||||
metadata={"help": "Random seed to be used with data loaders."}
|
metadata={"help": "Random seed to be used with data loaders."},
|
||||||
)
|
)
|
||||||
lang: Optional[Literal["en", "zh"]] = field(
|
lang: Literal["en", "zh"] = field(
|
||||||
default="en",
|
default="en",
|
||||||
metadata={"help": "Language used at evaluation."}
|
metadata={"help": "Language used at evaluation."},
|
||||||
)
|
)
|
||||||
n_shot: Optional[int] = field(
|
n_shot: int = field(
|
||||||
default=5,
|
default=5,
|
||||||
metadata={"help": "Number of examplars for few-shot learning."}
|
metadata={"help": "Number of examplars for few-shot learning."},
|
||||||
)
|
)
|
||||||
save_dir: Optional[str] = field(
|
save_dir: Optional[str] = field(
|
||||||
default=None,
|
default=None,
|
||||||
metadata={"help": "Path to save the evaluation results."}
|
metadata={"help": "Path to save the evaluation results."},
|
||||||
)
|
)
|
||||||
download_mode: Optional[DownloadMode] = field(
|
download_mode: DownloadMode = field(
|
||||||
default=DownloadMode.REUSE_DATASET_IF_EXISTS,
|
default=DownloadMode.REUSE_DATASET_IF_EXISTS,
|
||||||
metadata={"help": "Download mode used for the evaluation datasets."}
|
metadata={"help": "Download mode used for the evaluation datasets."},
|
||||||
)
|
)
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
task_available = []
|
|
||||||
for folder in os.listdir(self.task_dir):
|
|
||||||
if os.path.isdir(os.path.join(self.task_dir, folder)):
|
|
||||||
task_available.append(folder)
|
|
||||||
|
|
||||||
if self.task not in task_available:
|
|
||||||
raise ValueError("Task {} not found in {}.".format(self.task, self.task_dir))
|
|
||||||
|
|
||||||
if self.save_dir is not None and os.path.exists(self.save_dir):
|
if self.save_dir is not None and os.path.exists(self.save_dir):
|
||||||
raise ValueError("`save_dir` already exists, use another one.")
|
raise ValueError("`save_dir` already exists, use another one.")
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
import json
|
import json
|
||||||
from typing import Literal, Optional
|
|
||||||
from dataclasses import asdict, dataclass, field
|
from dataclasses import asdict, dataclass, field
|
||||||
|
from typing import Literal, Optional
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@@ -8,19 +8,23 @@ class FreezeArguments:
|
|||||||
r"""
|
r"""
|
||||||
Arguments pertaining to the freeze (partial-parameter) training.
|
Arguments pertaining to the freeze (partial-parameter) training.
|
||||||
"""
|
"""
|
||||||
name_module_trainable: Optional[str] = field(
|
|
||||||
default="mlp",
|
name_module_trainable: str = field(
|
||||||
metadata={"help": "Name of trainable modules for partial-parameter (freeze) fine-tuning. \
|
default="all",
|
||||||
Use commas to separate multiple modules. \
|
metadata={
|
||||||
LLaMA choices: [\"mlp\", \"self_attn\"], \
|
"help": """Name of trainable modules for partial-parameter (freeze) fine-tuning. \
|
||||||
BLOOM & Falcon & ChatGLM choices: [\"mlp\", \"self_attention\"], \
|
Use commas to separate multiple modules. \
|
||||||
Qwen choices: [\"mlp\", \"attn\"], \
|
Use "all" to specify all the available modules. \
|
||||||
Phi choices: [\"mlp\", \"mixer\"], \
|
LLaMA choices: ["mlp", "self_attn"], \
|
||||||
Others choices: the same as LLaMA."}
|
BLOOM & Falcon & ChatGLM choices: ["mlp", "self_attention"], \
|
||||||
|
Qwen choices: ["mlp", "attn"], \
|
||||||
|
InternLM2 choices: ["feed_forward", "attention"], \
|
||||||
|
Others choices: the same as LLaMA."""
|
||||||
|
},
|
||||||
)
|
)
|
||||||
num_layer_trainable: Optional[int] = field(
|
num_layer_trainable: int = field(
|
||||||
default=3,
|
default=2,
|
||||||
metadata={"help": "The number of trainable layers for partial-parameter (freeze) fine-tuning."}
|
metadata={"help": "The number of trainable layers for partial-parameter (freeze) fine-tuning."},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -29,35 +33,58 @@ class LoraArguments:
|
|||||||
r"""
|
r"""
|
||||||
Arguments pertaining to the LoRA training.
|
Arguments pertaining to the LoRA training.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
additional_target: Optional[str] = field(
|
additional_target: Optional[str] = field(
|
||||||
default=None,
|
default=None,
|
||||||
metadata={"help": "Name(s) of modules apart from LoRA layers to be set as trainable and saved in the final checkpoint."}
|
metadata={
|
||||||
|
"help": "Name(s) of modules apart from LoRA layers to be set as trainable and saved in the final checkpoint."
|
||||||
|
},
|
||||||
)
|
)
|
||||||
lora_alpha: Optional[int] = field(
|
lora_alpha: Optional[int] = field(
|
||||||
default=None,
|
default=None,
|
||||||
metadata={"help": "The scale factor for LoRA fine-tuning (default: lora_rank * 2)."}
|
metadata={"help": "The scale factor for LoRA fine-tuning (default: lora_rank * 2)."},
|
||||||
)
|
)
|
||||||
lora_dropout: Optional[float] = field(
|
lora_dropout: float = field(
|
||||||
default=0.1,
|
default=0.0,
|
||||||
metadata={"help": "Dropout rate for the LoRA fine-tuning."}
|
metadata={"help": "Dropout rate for the LoRA fine-tuning."},
|
||||||
)
|
)
|
||||||
lora_rank: Optional[int] = field(
|
lora_rank: int = field(
|
||||||
default=8,
|
default=8,
|
||||||
metadata={"help": "The intrinsic dimension for LoRA fine-tuning."}
|
metadata={"help": "The intrinsic dimension for LoRA fine-tuning."},
|
||||||
)
|
)
|
||||||
lora_target: Optional[str] = field(
|
lora_target: str = field(
|
||||||
|
default="all",
|
||||||
|
metadata={
|
||||||
|
"help": """Name(s) of target modules to apply LoRA. \
|
||||||
|
Use commas to separate multiple modules. \
|
||||||
|
Use "all" to specify all the linear modules. \
|
||||||
|
LLaMA choices: ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"], \
|
||||||
|
BLOOM & Falcon & ChatGLM choices: ["query_key_value", "dense", "dense_h_to_4h", "dense_4h_to_h"], \
|
||||||
|
Baichuan choices: ["W_pack", "o_proj", "gate_proj", "up_proj", "down_proj"], \
|
||||||
|
Qwen choices: ["c_attn", "attn.c_proj", "w1", "w2", "mlp.c_proj"], \
|
||||||
|
InternLM2 choices: ["wqkv", "wo", "w1", "w2", "w3"], \
|
||||||
|
Others choices: the same as LLaMA."""
|
||||||
|
},
|
||||||
|
)
|
||||||
|
loraplus_lr_ratio: Optional[float] = field(
|
||||||
default=None,
|
default=None,
|
||||||
metadata={"help": "Name(s) of target modules to apply LoRA. Use commas to separate multiple modules. \
|
metadata={"help": "LoRA plus learning rate ratio (lr_B / lr_A)."},
|
||||||
LLaMA choices: [\"q_proj\", \"k_proj\", \"v_proj\", \"o_proj\", \"gate_proj\", \"up_proj\", \"down_proj\"], \
|
|
||||||
BLOOM & Falcon & ChatGLM choices: [\"query_key_value\", \"dense\", \"dense_h_to_4h\", \"dense_4h_to_h\"], \
|
|
||||||
Baichuan choices: [\"W_pack\", \"o_proj\", \"gate_proj\", \"up_proj\", \"down_proj\"], \
|
|
||||||
Qwen choices: [\"c_attn\", \"attn.c_proj\", \"w1\", \"w2\", \"mlp.c_proj\"], \
|
|
||||||
Phi choices: [\"Wqkv\", \"out_proj\", \"fc1\", \"fc2\"], \
|
|
||||||
Others choices: the same as LLaMA."}
|
|
||||||
)
|
)
|
||||||
create_new_adapter: Optional[bool] = field(
|
loraplus_lr_embedding: float = field(
|
||||||
|
default=1e-6,
|
||||||
|
metadata={"help": "LoRA plus learning rate for lora embedding layers."},
|
||||||
|
)
|
||||||
|
use_rslora: bool = field(
|
||||||
default=False,
|
default=False,
|
||||||
metadata={"help": "Whether to create a new adapter with randomly initialized weight or not."}
|
metadata={"help": "Whether or not to use the rank stabilization scaling factor for LoRA layer."},
|
||||||
|
)
|
||||||
|
use_dora: bool = field(
|
||||||
|
default=False,
|
||||||
|
metadata={"help": "Whether or not to use the weight-decomposed lora method (DoRA)."},
|
||||||
|
)
|
||||||
|
create_new_adapter: bool = field(
|
||||||
|
default=False,
|
||||||
|
metadata={"help": "Whether or not to create a new adapter with randomly initialized weight."},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -66,123 +93,141 @@ class RLHFArguments:
|
|||||||
r"""
|
r"""
|
||||||
Arguments pertaining to the PPO and DPO training.
|
Arguments pertaining to the PPO and DPO training.
|
||||||
"""
|
"""
|
||||||
dpo_beta: Optional[float] = field(
|
|
||||||
|
dpo_beta: float = field(
|
||||||
default=0.1,
|
default=0.1,
|
||||||
metadata={"help": "The beta parameter for the DPO loss."}
|
metadata={"help": "The beta parameter for the DPO loss."},
|
||||||
)
|
)
|
||||||
dpo_loss: Optional[Literal["sigmoid", "hinge"]] = field(
|
dpo_loss: Literal["sigmoid", "hinge", "ipo", "kto_pair"] = field(
|
||||||
default="sigmoid",
|
default="sigmoid",
|
||||||
metadata={"help": "The type of DPO loss to use."}
|
metadata={"help": "The type of DPO loss to use."},
|
||||||
)
|
)
|
||||||
dpo_ftx: Optional[float] = field(
|
dpo_label_smoothing: float = field(
|
||||||
default=0,
|
default=0.0,
|
||||||
metadata={"help": "The supervised fine-tuning loss coefficient in DPO training."}
|
metadata={"help": "The robust DPO label smoothing parameter in cDPO that should be between 0 and 0.5."},
|
||||||
)
|
)
|
||||||
ppo_buffer_size: Optional[int] = field(
|
dpo_ftx: float = field(
|
||||||
|
default=0.0,
|
||||||
|
metadata={"help": "The supervised fine-tuning loss coefficient in DPO training."},
|
||||||
|
)
|
||||||
|
orpo_beta: float = field(
|
||||||
|
default=0.1,
|
||||||
|
metadata={"help": "The beta (lambda) parameter in ORPO loss representing the weight of the SFT loss."},
|
||||||
|
)
|
||||||
|
ppo_buffer_size: int = field(
|
||||||
default=1,
|
default=1,
|
||||||
metadata={"help": "The number of mini-batches to make experience buffer in a PPO optimization step."}
|
metadata={"help": "The number of mini-batches to make experience buffer in a PPO optimization step."},
|
||||||
)
|
)
|
||||||
ppo_epochs: Optional[int] = field(
|
ppo_epochs: int = field(
|
||||||
default=4,
|
default=4,
|
||||||
metadata={"help": "The number of epochs to perform in a PPO optimization step."}
|
metadata={"help": "The number of epochs to perform in a PPO optimization step."},
|
||||||
)
|
)
|
||||||
ppo_logger: Optional[str] = field(
|
ppo_score_norm: bool = field(
|
||||||
default=None,
|
|
||||||
metadata={"help": "Log with either \"wandb\" or \"tensorboard\" in PPO training."}
|
|
||||||
)
|
|
||||||
ppo_score_norm: Optional[bool] = field(
|
|
||||||
default=False,
|
default=False,
|
||||||
metadata={"help": "Use score normalization in PPO training."}
|
metadata={"help": "Use score normalization in PPO training."},
|
||||||
)
|
)
|
||||||
ppo_target: Optional[float] = field(
|
ppo_target: float = field(
|
||||||
default=6.0,
|
default=6.0,
|
||||||
metadata={"help": "Target KL value for adaptive KL control in PPO training."}
|
metadata={"help": "Target KL value for adaptive KL control in PPO training."},
|
||||||
)
|
)
|
||||||
ppo_whiten_rewards: Optional[bool] = field(
|
ppo_whiten_rewards: bool = field(
|
||||||
default=False,
|
default=False,
|
||||||
metadata={"help": "Whiten the rewards before compute advantages in PPO training."}
|
metadata={"help": "Whiten the rewards before compute advantages in PPO training."},
|
||||||
)
|
)
|
||||||
ref_model: Optional[str] = field(
|
ref_model: Optional[str] = field(
|
||||||
default=None,
|
default=None,
|
||||||
metadata={"help": "Path to the reference model used for the PPO or DPO training."}
|
metadata={"help": "Path to the reference model used for the PPO or DPO training."},
|
||||||
)
|
)
|
||||||
ref_model_adapters: Optional[str] = field(
|
ref_model_adapters: Optional[str] = field(
|
||||||
default=None,
|
default=None,
|
||||||
metadata={"help": "Path to the adapters of the reference model."}
|
metadata={"help": "Path to the adapters of the reference model."},
|
||||||
)
|
)
|
||||||
ref_model_quantization_bit: Optional[int] = field(
|
ref_model_quantization_bit: Optional[int] = field(
|
||||||
default=None,
|
default=None,
|
||||||
metadata={"help": "The number of bits to quantize the reference model."}
|
metadata={"help": "The number of bits to quantize the reference model."},
|
||||||
)
|
)
|
||||||
reward_model: Optional[str] = field(
|
reward_model: Optional[str] = field(
|
||||||
default=None,
|
default=None,
|
||||||
metadata={"help": "Path to the reward model used for the PPO training."}
|
metadata={"help": "Path to the reward model used for the PPO training."},
|
||||||
)
|
)
|
||||||
reward_model_adapters: Optional[str] = field(
|
reward_model_adapters: Optional[str] = field(
|
||||||
default=None,
|
default=None,
|
||||||
metadata={"help": "Path to the adapters of the reward model."}
|
metadata={"help": "Path to the adapters of the reward model."},
|
||||||
)
|
)
|
||||||
reward_model_quantization_bit: Optional[int] = field(
|
reward_model_quantization_bit: Optional[int] = field(
|
||||||
default=None,
|
default=None,
|
||||||
metadata={"help": "The number of bits to quantize the reward model."}
|
metadata={"help": "The number of bits to quantize the reward model."},
|
||||||
)
|
)
|
||||||
reward_model_type: Optional[Literal["lora", "full", "api"]] = field(
|
reward_model_type: Literal["lora", "full", "api"] = field(
|
||||||
default="lora",
|
default="lora",
|
||||||
metadata={"help": "The type of the reward model in PPO training. Lora model only supports lora training."}
|
metadata={"help": "The type of the reward model in PPO training. Lora model only supports lora training."},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class ExportArguments:
|
class GaloreArguments:
|
||||||
r"""
|
r"""
|
||||||
Arguments pertaining to model exporting.
|
Arguments pertaining to the GaLore algorithm.
|
||||||
"""
|
"""
|
||||||
export_dir: Optional[str] = field(
|
|
||||||
default=None,
|
use_galore: bool = field(
|
||||||
metadata={"help": "Path to the directory to save the exported model."}
|
default=False,
|
||||||
|
metadata={"help": "Whether or not to use gradient low-Rank projection."},
|
||||||
)
|
)
|
||||||
export_size: Optional[int] = field(
|
galore_target: str = field(
|
||||||
default=1,
|
default="all",
|
||||||
metadata={"help": "The file shard size (in GB) of the exported model."}
|
metadata={
|
||||||
|
"help": """Name(s) of modules to apply GaLore. Use commas to separate multiple modules. \
|
||||||
|
Use "all" to specify all the linear modules."""
|
||||||
|
},
|
||||||
)
|
)
|
||||||
export_quantization_bit: Optional[int] = field(
|
galore_rank: int = field(
|
||||||
default=None,
|
default=16,
|
||||||
metadata={"help": "The number of bits to quantize the exported model."}
|
metadata={"help": "The rank of GaLore gradients."},
|
||||||
)
|
)
|
||||||
export_quantization_dataset: Optional[str] = field(
|
galore_update_interval: int = field(
|
||||||
default=None,
|
default=200,
|
||||||
metadata={"help": "Path to the dataset or dataset name to use in quantizing the exported model."}
|
metadata={"help": "Number of steps to update the GaLore projection."},
|
||||||
)
|
)
|
||||||
export_quantization_nsamples: Optional[int] = field(
|
galore_scale: float = field(
|
||||||
default=128,
|
default=0.25,
|
||||||
metadata={"help": "The number of samples used for quantization."}
|
metadata={"help": "GaLore scaling coefficient."},
|
||||||
)
|
)
|
||||||
export_quantization_maxlen: Optional[str] = field(
|
galore_proj_type: Literal["std", "reverse_std", "right", "left", "full"] = field(
|
||||||
default=1024,
|
default="std",
|
||||||
metadata={"help": "The maximum length of the model inputs used for quantization."}
|
metadata={"help": "Type of GaLore projection."},
|
||||||
|
)
|
||||||
|
galore_layerwise: bool = field(
|
||||||
|
default=False,
|
||||||
|
metadata={"help": "Whether or not to enable layer-wise update to further save memory."},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class FinetuningArguments(FreezeArguments, LoraArguments, RLHFArguments, ExportArguments):
|
class FinetuningArguments(FreezeArguments, LoraArguments, RLHFArguments, GaloreArguments):
|
||||||
r"""
|
r"""
|
||||||
Arguments pertaining to which techniques we are going to fine-tuning with.
|
Arguments pertaining to which techniques we are going to fine-tuning with.
|
||||||
"""
|
"""
|
||||||
stage: Optional[Literal["pt", "sft", "rm", "ppo", "dpo"]] = field(
|
|
||||||
|
pure_bf16: bool = field(
|
||||||
|
default=False,
|
||||||
|
metadata={"help": "Whether or not to train model in purely bf16 precision (without AMP)."},
|
||||||
|
)
|
||||||
|
stage: Literal["pt", "sft", "rm", "ppo", "dpo", "orpo"] = field(
|
||||||
default="sft",
|
default="sft",
|
||||||
metadata={"help": "Which stage will be performed in training."}
|
metadata={"help": "Which stage will be performed in training."},
|
||||||
)
|
)
|
||||||
finetuning_type: Optional[Literal["lora", "freeze", "full"]] = field(
|
finetuning_type: Literal["lora", "freeze", "full"] = field(
|
||||||
default="lora",
|
default="lora",
|
||||||
metadata={"help": "Which fine-tuning method to use."}
|
metadata={"help": "Which fine-tuning method to use."},
|
||||||
)
|
)
|
||||||
upcast_layernorm: Optional[bool] = field(
|
use_llama_pro: bool = field(
|
||||||
default=False,
|
default=False,
|
||||||
metadata={"help": "Whether to upcast the layernorm weights in fp32."}
|
metadata={"help": "Whether or not to make only the parameters in the expanded blocks trainable."},
|
||||||
)
|
)
|
||||||
plot_loss: Optional[bool] = field(
|
plot_loss: bool = field(
|
||||||
default=False,
|
default=False,
|
||||||
metadata={"help": "Whether to plot the training loss after fine-tuning or not."}
|
metadata={"help": "Whether or not to save the training loss curves."},
|
||||||
)
|
)
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
@@ -195,22 +240,26 @@ class FinetuningArguments(FreezeArguments, LoraArguments, RLHFArguments, ExportA
|
|||||||
self.lora_alpha = self.lora_alpha or self.lora_rank * 2
|
self.lora_alpha = self.lora_alpha or self.lora_rank * 2
|
||||||
self.lora_target = split_arg(self.lora_target)
|
self.lora_target = split_arg(self.lora_target)
|
||||||
self.additional_target = split_arg(self.additional_target)
|
self.additional_target = split_arg(self.additional_target)
|
||||||
self.ref_model_adapters = split_arg(self.ref_model_adapters)
|
self.galore_target = split_arg(self.galore_target)
|
||||||
self.reward_model_adapters = split_arg(self.reward_model_adapters)
|
|
||||||
|
|
||||||
assert self.finetuning_type in ["lora", "freeze", "full"], "Invalid fine-tuning method."
|
assert self.finetuning_type in ["lora", "freeze", "full"], "Invalid fine-tuning method."
|
||||||
assert self.ref_model_quantization_bit in [None, 8, 4], "We only accept 4-bit or 8-bit quantization."
|
assert self.ref_model_quantization_bit in [None, 8, 4], "We only accept 4-bit or 8-bit quantization."
|
||||||
assert self.reward_model_quantization_bit in [None, 8, 4], "We only accept 4-bit or 8-bit quantization."
|
assert self.reward_model_quantization_bit in [None, 8, 4], "We only accept 4-bit or 8-bit quantization."
|
||||||
assert self.export_quantization_bit in [None, 8, 4, 3, 2], "We only accept 2/3/4/8-bit quantization."
|
|
||||||
|
|
||||||
if self.stage == "ppo" and self.reward_model is None:
|
if self.stage == "ppo" and self.reward_model is None:
|
||||||
raise ValueError("Reward model is necessary for PPO training.")
|
raise ValueError("`reward_model` is necessary for PPO training.")
|
||||||
|
|
||||||
if self.stage == "ppo" and self.reward_model_type == "lora" and self.finetuning_type != "lora":
|
if self.stage == "ppo" and self.reward_model_type == "lora" and self.finetuning_type != "lora":
|
||||||
raise ValueError("Freeze/Full PPO training needs `reward_model_type=full`.")
|
raise ValueError("`reward_model_type` cannot be lora for Freeze/Full PPO training.")
|
||||||
|
|
||||||
if self.export_quantization_bit is not None and self.export_quantization_dataset is None:
|
if self.stage == "dpo" and self.dpo_loss != "sigmoid" and self.dpo_label_smoothing > 1e-6:
|
||||||
raise ValueError("Quantization dataset is necessary for exporting.")
|
raise ValueError("`dpo_label_smoothing` is only valid for sigmoid loss function.")
|
||||||
|
|
||||||
|
if self.use_llama_pro and self.finetuning_type == "full":
|
||||||
|
raise ValueError("`use_llama_pro` is only valid for the Freeze or LoRA method.")
|
||||||
|
|
||||||
|
if self.use_galore and self.finetuning_type == "lora":
|
||||||
|
raise ValueError("Cannot use LoRA with GaLore together.")
|
||||||
|
|
||||||
def save_to_json(self, json_path: str):
|
def save_to_json(self, json_path: str):
|
||||||
r"""Saves the content of this instance in JSON format inside `json_path`."""
|
r"""Saves the content of this instance in JSON format inside `json_path`."""
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
from typing import Any, Dict, Optional
|
|
||||||
from dataclasses import asdict, dataclass, field
|
from dataclasses import asdict, dataclass, field
|
||||||
|
from typing import Any, Dict
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@@ -7,41 +7,44 @@ class GeneratingArguments:
|
|||||||
r"""
|
r"""
|
||||||
Arguments pertaining to specify the decoding parameters.
|
Arguments pertaining to specify the decoding parameters.
|
||||||
"""
|
"""
|
||||||
do_sample: Optional[bool] = field(
|
|
||||||
|
do_sample: bool = field(
|
||||||
default=True,
|
default=True,
|
||||||
metadata={"help": "Whether or not to use sampling, use greedy decoding otherwise."}
|
metadata={"help": "Whether or not to use sampling, use greedy decoding otherwise."},
|
||||||
)
|
)
|
||||||
temperature: Optional[float] = field(
|
temperature: float = field(
|
||||||
default=0.95,
|
default=0.95,
|
||||||
metadata={"help": "The value used to modulate the next token probabilities."}
|
metadata={"help": "The value used to modulate the next token probabilities."},
|
||||||
)
|
)
|
||||||
top_p: Optional[float] = field(
|
top_p: float = field(
|
||||||
default=0.7,
|
default=0.7,
|
||||||
metadata={"help": "The smallest set of most probable tokens with probabilities that add up to top_p or higher are kept."}
|
metadata={
|
||||||
|
"help": "The smallest set of most probable tokens with probabilities that add up to top_p or higher are kept."
|
||||||
|
},
|
||||||
)
|
)
|
||||||
top_k: Optional[int] = field(
|
top_k: int = field(
|
||||||
default=50,
|
default=50,
|
||||||
metadata={"help": "The number of highest probability vocabulary tokens to keep for top-k filtering."}
|
metadata={"help": "The number of highest probability vocabulary tokens to keep for top-k filtering."},
|
||||||
)
|
)
|
||||||
num_beams: Optional[int] = field(
|
num_beams: int = field(
|
||||||
default=1,
|
default=1,
|
||||||
metadata={"help": "Number of beams for beam search. 1 means no beam search."}
|
metadata={"help": "Number of beams for beam search. 1 means no beam search."},
|
||||||
)
|
)
|
||||||
max_length: Optional[int] = field(
|
max_length: int = field(
|
||||||
default=512,
|
default=512,
|
||||||
metadata={"help": "The maximum length the generated tokens can have. It can be overridden by max_new_tokens."}
|
metadata={"help": "The maximum length the generated tokens can have. It can be overridden by max_new_tokens."},
|
||||||
)
|
)
|
||||||
max_new_tokens: Optional[int] = field(
|
max_new_tokens: int = field(
|
||||||
default=512,
|
default=512,
|
||||||
metadata={"help": "The maximum numbers of tokens to generate, ignoring the number of tokens in the prompt."}
|
metadata={"help": "The maximum numbers of tokens to generate, ignoring the number of tokens in the prompt."},
|
||||||
)
|
)
|
||||||
repetition_penalty: Optional[float] = field(
|
repetition_penalty: float = field(
|
||||||
default=1.0,
|
default=1.0,
|
||||||
metadata={"help": "The parameter for repetition penalty. 1.0 means no penalty."}
|
metadata={"help": "The parameter for repetition penalty. 1.0 means no penalty."},
|
||||||
)
|
)
|
||||||
length_penalty: Optional[float] = field(
|
length_penalty: float = field(
|
||||||
default=1.0,
|
default=1.0,
|
||||||
metadata={"help": "Exponential penalty to the length that is used with beam-based generation."}
|
metadata={"help": "Exponential penalty to the length that is used with beam-based generation."},
|
||||||
)
|
)
|
||||||
|
|
||||||
def to_dict(self) -> Dict[str, Any]:
|
def to_dict(self) -> Dict[str, Any]:
|
||||||
|
|||||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user