Compare commits
283 Commits

.github/PULL_REQUEST_TEMPLATE.md (vendored, new file, +7)

@@ -0,0 +1,7 @@
+# What does this PR do?
+
+Fixes # (issue)
+
+## Before submitting
+
+- [ ] Did you read the [contributor guideline](/CONTRIBUTING.md)?

.github/workflows/tests.yml (vendored, new file, +29)

@@ -0,0 +1,29 @@
+name: tests
+
+on:
+  push:
+    branches: [ "main" ]
+  pull_request:
+    branches: [ "main" ]
+
+jobs:
+  check_code_quality:
+
+    runs-on: ubuntu-latest
+
+    steps:
+    - uses: actions/checkout@v4
+
+    - name: Set up Python
+      uses: actions/setup-python@v5
+      with:
+        python-version: "3.8"
+
+    - name: Install dependencies
+      run: |
+        python -m pip install --upgrade pip
+        python -m pip install ruff
+
+    - name: Check quality
+      run: |
+        make style && make quality

CONTRIBUTING.md (new file, +21)

@@ -0,0 +1,21 @@
+# Contributing to LLaMA Factory
+
+Everyone is welcome to contribute, and we value everybody's contribution. Code contributions are not the only way to help the community. Answering questions, helping others, and improving the documentation are also immensely valuable.
+
+It also helps us if you spread the word! Reference the library in blog posts about the awesome projects it made possible, shout out on Twitter every time it has helped you, or simply ⭐️ the repository to say thank you.
+
+However you choose to contribute, please be mindful and respect our [code of conduct](CODE_OF_CONDUCT.md).
+
+**This guide was heavily inspired by [transformers guide to contributing](https://github.com/huggingface/transformers/blob/main/CONTRIBUTING.md).**
+
+## Ways to contribute
+
+There are several ways you can contribute to LLaMA Factory:
+
+* Fix outstanding issues with the existing code.
+* Submit issues related to bugs or desired new features.
+* Contribute to the examples or to the documentation.
+
+### Style guide
+
+LLaMA Factory follows the [Google Python Style Guide](https://google.github.io/styleguide/pyguide.html), check it for details.

Makefile (new file, +11)

@@ -0,0 +1,11 @@
+.PHONY: quality style
+
+check_dirs := src tests
+
+quality:
+	ruff $(check_dirs)
+	ruff format --check $(check_dirs)
+
+style:
+	ruff $(check_dirs) --fix
+	ruff format $(check_dirs)
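
Taken together, the new workflow and Makefile define the project's lint gate: CI sets up Python 3.8, installs `ruff`, and runs the two targets above over `src` and `tests`. A minimal sketch of reproducing that check locally, assuming a repository checkout with Python 3.8+ and `make` on PATH, would be:

```bash
# Local equivalent of the steps in .github/workflows/tests.yml above (a sketch, not part of the diff).
python -m pip install --upgrade pip
python -m pip install ruff
make style     # runs `ruff --fix` and `ruff format` over src/ and tests/
make quality   # runs `ruff` and `ruff format --check` over src/ and tests/
```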

README.md (226 changes)

@@ -5,8 +5,9 @@
 [](https://github.com/hiyouga/LLaMA-Factory/commits/main)
 [](https://pypi.org/project/llmtuner/)
 [](https://pypi.org/project/llmtuner/)
+[](#projects-using-llama-factory)
 [](https://github.com/hiyouga/LLaMA-Factory/pulls)
-[](https://discord.gg/c2EPEt5NU)
+[](https://discord.gg/rKfvV9r9FK)
 [](https://huggingface.co/spaces/hiyouga/LLaMA-Board)
 [](https://modelscope.cn/studios/hiyouga/LLaMA-Board)
 
@@ -16,9 +17,7 @@
 
 ## LLaMA Board: A One-stop Web UI for Getting Started with LLaMA Factory
 
-Preview LLaMA Board at **[🤗 Spaces](https://huggingface.co/spaces/hiyouga/LLaMA-Board)** or **[ModelScope](https://modelscope.cn/studios/hiyouga/LLaMA-Board)**.
+Preview LLaMA Board at **[🤗 Spaces](https://huggingface.co/spaces/hiyouga/LLaMA-Board)** and **[ModelScope](https://modelscope.cn/studios/hiyouga/LLaMA-Board)**, or launch it locally with `CUDA_VISIBLE_DEVICES=0 python src/train_web.py`.
 
-Launch LLaMA Board via `CUDA_VISIBLE_DEVICES=0 python src/train_web.py`. (multiple GPUs are not supported yet in this mode)
-
 Here is an example of altering the self-cognition of an instruction-tuned language model within 10 minutes on a single GPU.
 
@@ -26,6 +25,7 @@ https://github.com/hiyouga/LLaMA-Factory/assets/16256802/6ba60acc-e2e2-4bec-b846
 
 ## Table of Contents
 
+- [Features](#features)
 - [Benchmark](#benchmark)
 - [Changelog](#changelog)
 - [Supported Models](#supported-models)
@@ -38,6 +38,15 @@ https://github.com/hiyouga/LLaMA-Factory/assets/16256802/6ba60acc-e2e2-4bec-b846
 - [Citation](#citation)
 - [Acknowledgement](#acknowledgement)
 
+## Features
+
+- **Various models**: LLaMA, Mistral, Mixtral-MoE, Qwen, Yi, Gemma, Baichuan, ChatGLM, Phi, etc.
+- **Integrated methods**: (Continuous) pre-training, supervised fine-tuning, reward modeling, PPO and DPO.
+- **Scalable resources**: 32-bit full-tuning, 16-bit freeze-tuning, 16-bit LoRA, 2/4/8-bit QLoRA via AQLM/AWQ/GPTQ/LLM.int8.
+- **Advanced algorithms**: DoRA, LongLoRA, LLaMA Pro, LoftQ, agent tuning.
+- **Practical tricks**: FlashAttention-2, Unsloth, RoPE scaling, NEFTune, rsLoRA.
+- **Experiment monitors**: LlamaBoard, TensorBoard, Wandb, MLflow, etc.
+
 ## Benchmark
 
 Compared to ChatGLM's [P-Tuning](https://github.com/THUDM/ChatGLM2-6B/tree/main/ptuning), LLaMA-Factory's LoRA tuning offers up to **3.7 times faster** training speed with a better Rouge score on the advertising text generation task. By leveraging 4-bit quantization technique, LLaMA-Factory's QLoRA further improves the efficiency regarding the GPU memory.
@@ -55,17 +64,29 @@ Compared to ChatGLM's [P-Tuning](https://github.com/THUDM/ChatGLM2-6B/tree/main/
 
 ## Changelog
 
-[23/12/01] We supported downloading pre-trained models from the **[ModelScope Hub](https://modelscope.cn/models)** for Chinese mainland users. See [this tutorial](#use-modelscope-models-optional) for usage.
+[24/02/28] We supported weight-decomposed LoRA (**[DoRA](https://arxiv.org/abs/2402.09353)**). Try `--use_dora` to activate DoRA training.
 
-[23/10/21] We supported **[NEFTune](https://arxiv.org/abs/2310.05914)** trick for fine-tuning. Try `--neft_alpha` argument to activate NEFTune, e.g., `--neft_alpha 5`.
+[24/02/15] We supported **block expansion** proposed by [LLaMA Pro](https://github.com/TencentARC/LLaMA-Pro). See `tests/llama_pro.py` for usage.
 
+[24/02/05] Qwen1.5 (Qwen2 beta version) series models are supported in LLaMA-Factory. Check this [blog post](https://qwenlm.github.io/blog/qwen1.5/) for details.
+
 <details><summary>Full Changelog</summary>
 
+[24/01/18] We supported **agent tuning** for most models, equipping model with tool using abilities by fine-tuning with `--dataset glaive_toolcall`.
+
+[23/12/23] We supported **[unsloth](https://github.com/unslothai/unsloth)**'s implementation to boost LoRA tuning for the LLaMA, Mistral and Yi models. Try `--use_unsloth` argument to activate unsloth patch. It achieves 1.7x speed in our benchmark, check [this page](https://github.com/hiyouga/LLaMA-Factory/wiki/Performance-comparison) for details.
+
+[23/12/12] We supported fine-tuning the latest MoE model **[Mixtral 8x7B](https://huggingface.co/mistralai/Mixtral-8x7B-v0.1)** in our framework. See hardware requirement [here](#hardware-requirement).
+
+[23/12/01] We supported downloading pre-trained models and datasets from the **[ModelScope Hub](https://modelscope.cn/models)** for Chinese mainland users. See [this tutorial](#use-modelscope-hub-optional) for usage.
+
+[23/10/21] We supported **[NEFTune](https://arxiv.org/abs/2310.05914)** trick for fine-tuning. Try `--neftune_noise_alpha` argument to activate NEFTune, e.g., `--neftune_noise_alpha 5`.
+
 [23/09/27] We supported **$S^2$-Attn** proposed by [LongLoRA](https://github.com/dvlab-research/LongLoRA) for the LLaMA models. Try `--shift_attn` argument to enable shift short attention.
 
 [23/09/23] We integrated MMLU, C-Eval and CMMLU benchmarks in this repo. See [this example](#evaluation) to evaluate your models.
 
-[23/09/10] We supported using **[FlashAttention-2](https://github.com/Dao-AILab/flash-attention)** for the LLaMA models. Try `--flash_attn` argument to enable FlashAttention-2 if you are using RTX4090, A100 or H100 GPUs.
+[23/09/10] We supported **[FlashAttention-2](https://github.com/Dao-AILab/flash-attention)**. Try `--flash_attn` argument to enable FlashAttention-2 if you are using RTX4090, A100 or H100 GPUs.
 
 [23/08/12] We supported **RoPE scaling** to extend the context length of the LLaMA models. Try `--rope_scaling linear` argument in training and `--rope_scaling dynamic` argument at inference to extrapolate the position embeddings.
 
@@ -91,19 +112,24 @@ Compared to ChatGLM's [P-Tuning](https://github.com/THUDM/ChatGLM2-6B/tree/main/
 
 | Model | Model size | Default module | Template |
 | -------------------------------------------------------- | --------------------------- | ----------------- | --------- |
-| [Baichuan](https://github.com/baichuan-inc/Baichuan-13B) | 7B/13B | W_pack | baichuan |
-| [Baichuan2](https://github.com/baichuan-inc/Baichuan2) | 7B/13B | W_pack | baichuan2 |
+| [Baichuan2](https://huggingface.co/baichuan-inc) | 7B/13B | W_pack | baichuan2 |
 | [BLOOM](https://huggingface.co/bigscience/bloom) | 560M/1.1B/1.7B/3B/7.1B/176B | query_key_value | - |
 | [BLOOMZ](https://huggingface.co/bigscience/bloomz) | 560M/1.1B/1.7B/3B/7.1B/176B | query_key_value | - |
-| [ChatGLM3](https://github.com/THUDM/ChatGLM3) | 6B | query_key_value | chatglm3 |
-| [Falcon](https://huggingface.co/tiiuae/falcon-7b) | 7B/40B/180B | query_key_value | falcon |
-| [InternLM](https://github.com/InternLM/InternLM) | 7B/20B | q_proj,v_proj | intern |
+| [ChatGLM3](https://huggingface.co/THUDM/chatglm3-6b) | 6B | query_key_value | chatglm3 |
+| [DeepSeek (MoE)](https://huggingface.co/deepseek-ai) | 7B/16B/67B | q_proj,v_proj | deepseek |
+| [Falcon](https://huggingface.co/tiiuae) | 7B/40B/180B | query_key_value | falcon |
+| [Gemma](https://huggingface.co/google) | 2B/7B | q_proj,v_proj | gemma |
+| [InternLM2](https://huggingface.co/internlm) | 7B/20B | wqkv | intern2 |
 | [LLaMA](https://github.com/facebookresearch/llama) | 7B/13B/33B/65B | q_proj,v_proj | - |
 | [LLaMA-2](https://huggingface.co/meta-llama) | 7B/13B/70B | q_proj,v_proj | llama2 |
 | [Mistral](https://huggingface.co/mistralai) | 7B | q_proj,v_proj | mistral |
-| [Phi-1.5](https://huggingface.co/microsoft/phi-1_5) | 1.3B | Wqkv | - |
-| [Qwen](https://github.com/QwenLM/Qwen) | 1.8B/7B/14B/72B | c_attn | qwen |
-| [XVERSE](https://github.com/xverse-ai) | 7B/13B/65B | q_proj,v_proj | xverse |
+| [Mixtral](https://huggingface.co/mistralai) | 8x7B | q_proj,v_proj | mistral |
+| [Phi-1.5/2](https://huggingface.co/microsoft) | 1.3B/2.7B | q_proj,v_proj | - |
+| [Qwen](https://huggingface.co/Qwen) | 1.8B/7B/14B/72B | c_attn | qwen |
+| [Qwen1.5](https://huggingface.co/Qwen) | 0.5B/1.8B/4B/7B/14B/72B | q_proj,v_proj | qwen |
+| [XVERSE](https://huggingface.co/xverse) | 7B/13B/65B | q_proj,v_proj | xverse |
+| [Yi](https://huggingface.co/01-ai) | 6B/34B | q_proj,v_proj | yi |
+| [Yuan](https://huggingface.co/IEITYuan) | 2B/51B/102B | q_proj,v_proj | yuan |
 
 > [!NOTE]
 > **Default module** is used for the `--lora_target` argument, you can use `--lora_target all` to specify all the available modules.
@@ -114,7 +140,7 @@ Please refer to [constants.py](src/llmtuner/extras/constants.py) for a full list
 
 ## Supported Training Approaches
 
-| Approach | Full-parameter | Partial-parameter | LoRA | QLoRA |
+| Approach | Full-tuning | Freeze-tuning | LoRA | QLoRA |
 | ---------------------- | ------------------ | ------------------ | ------------------ | ------------------ |
 | Pre-Training | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
 | Supervised Fine-Tuning | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
@@ -123,7 +149,7 @@ Please refer to [constants.py](src/llmtuner/extras/constants.py) for a full list
 | DPO Training | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
 
 > [!NOTE]
-> Use `--quantization_bit 4/8` argument to enable QLoRA.
+> Use `--quantization_bit 4` argument to enable QLoRA.
 
 ## Provided Datasets
 
@@ -145,8 +171,8 @@ Please refer to [constants.py](src/llmtuner/extras/constants.py) for a full list
 
 - [Stanford Alpaca (en)](https://github.com/tatsu-lab/stanford_alpaca)
 - [Stanford Alpaca (zh)](https://github.com/ymcui/Chinese-LLaMA-Alpaca)
-- [GPT-4 Generated Data (en&zh)](https://github.com/Instruction-Tuning-with-GPT-4/GPT-4-LLM)
-- [Self-cognition (zh)](data/self_cognition.json)
+- [Alpaca GPT4 (en&zh)](https://github.com/Instruction-Tuning-with-GPT-4/GPT-4-LLM)
+- [Self Cognition (zh)](data/self_cognition.json)
 - [Open Assistant (multilingual)](https://huggingface.co/datasets/OpenAssistant/oasst1)
 - [ShareGPT (zh)](https://huggingface.co/datasets/QingyiSi/Alpaca-CoT/tree/main/Chinese-instruction-collection)
 - [Guanaco Dataset (multilingual)](https://huggingface.co/datasets/JosephusCheung/GuanacoDataset)
@@ -162,11 +188,14 @@ Please refer to [constants.py](src/llmtuner/extras/constants.py) for a full list
 - [CodeAlpaca 20k (en)](https://huggingface.co/datasets/sahil2801/CodeAlpaca-20k)
 - [Alpaca CoT (multilingual)](https://huggingface.co/datasets/QingyiSi/Alpaca-CoT)
 - [OpenOrca (en)](https://huggingface.co/datasets/Open-Orca/OpenOrca)
+- [SlimOrca (en)](https://huggingface.co/datasets/Open-Orca/SlimOrca)
 - [MathInstruct (en)](https://huggingface.co/datasets/TIGER-Lab/MathInstruct)
 - [Firefly 1.1M (zh)](https://huggingface.co/datasets/YeungNLP/firefly-train-1.1M)
+- [Wiki QA (en)](https://huggingface.co/datasets/wiki_qa)
 - [Web QA (zh)](https://huggingface.co/datasets/suolyer/webqa)
 - [WebNovel (zh)](https://huggingface.co/datasets/zxbsmk/webnovel_cn)
 - [Nectar (en)](https://huggingface.co/datasets/berkeley-nest/Nectar)
+- [deepctrl (en&zh)](https://www.modelscope.cn/datasets/deepctrl/deepctrl-sft-data)
 - [Ad Gen (zh)](https://huggingface.co/datasets/HasturOfficial/adgen)
 - [ShareGPT Hyperfiltered (en)](https://huggingface.co/datasets/totally-not-an-llm/sharegpt-hyperfiltered-3k)
 - [ShareGPT4 (en&zh)](https://huggingface.co/datasets/shibing624/sharegpt_gpt4)
@@ -174,6 +203,16 @@ Please refer to [constants.py](src/llmtuner/extras/constants.py) for a full list
 - [AgentInstruct (en)](https://huggingface.co/datasets/THUDM/AgentInstruct)
 - [LMSYS Chat 1M (en)](https://huggingface.co/datasets/lmsys/lmsys-chat-1m)
 - [Evol Instruct V2 (en)](https://huggingface.co/datasets/WizardLM/WizardLM_evol_instruct_V2_196k)
+- [Glaive Function Calling V2 (en)](https://huggingface.co/datasets/glaiveai/glaive-function-calling-v2)
+- [Open Assistant (de)](https://huggingface.co/datasets/mayflowergmbh/oasst_de)
+- [Dolly 15k (de)](https://huggingface.co/datasets/mayflowergmbh/dolly-15k_de)
+- [Alpaca GPT4 (de)](https://huggingface.co/datasets/mayflowergmbh/alpaca-gpt4_de)
+- [OpenSchnabeltier (de)](https://huggingface.co/datasets/mayflowergmbh/openschnabeltier_de)
+- [Evol Instruct (de)](https://huggingface.co/datasets/mayflowergmbh/evol-instruct_de)
+- [Dolphin (de)](https://huggingface.co/datasets/mayflowergmbh/dolphin_de)
+- [Booksum (de)](https://huggingface.co/datasets/mayflowergmbh/booksum_de)
+- [Airoboros (de)](https://huggingface.co/datasets/mayflowergmbh/airoboros-3.0_de)
+- [Ultrachat (de)](https://huggingface.co/datasets/mayflowergmbh/ultra-chat_de)
 
 </details>
 
@@ -183,6 +222,7 @@ Please refer to [constants.py](src/llmtuner/extras/constants.py) for a full list
 - [Open Assistant (multilingual)](https://huggingface.co/datasets/OpenAssistant/oasst1)
 - [GPT-4 Generated Data (en&zh)](https://github.com/Instruction-Tuning-with-GPT-4/GPT-4-LLM)
 - [Nectar (en)](https://huggingface.co/datasets/berkeley-nest/Nectar)
+- [Orca DPO (de)](https://huggingface.co/datasets/mayflowergmbh/intel_orca_dpo_pairs_de)
 
 </details>
 
@@ -197,22 +237,34 @@ huggingface-cli login
 
 ## Requirement
 
-- Python 3.8+ and PyTorch 1.13.1+
-- 🤗Transformers, Datasets, Accelerate, PEFT and TRL
-- sentencepiece, protobuf and tiktoken
-- jieba, rouge-chinese and nltk (used at evaluation and predict)
-- gradio and matplotlib (used in web UI)
-- uvicorn, fastapi and sse-starlette (used in API)
+| Mandatory | Minimum | Recommend |
+| ------------ | ------- | --------- |
+| python | 3.8 | 3.10 |
+| torch | 1.13.1 | 2.2.1 |
+| transformers | 4.37.2 | 4.38.1 |
+| datasets | 2.14.3 | 2.17.1 |
+| accelerate | 0.27.2 | 0.27.2 |
+| peft | 0.9.0 | 0.9.0 |
+| trl | 0.7.11 | 0.7.11 |
+
+| Optional | Minimum | Recommend |
+| ------------ | ------- | --------- |
+| CUDA | 11.6 | 12.2 |
+| deepspeed | 0.10.0 | 0.13.4 |
+| bitsandbytes | 0.39.0 | 0.41.3 |
+| flash-attn | 2.3.0 | 2.5.5 |
 
 ### Hardware Requirement
 
-| Method | Bits | 7B | 13B | 30B | 65B |
-| ------ | ---- | ----- | ----- | ----- | ------ |
-| Full | 16 | 140GB | 240GB | 520GB | 1200GB |
-| Freeze | 16 | 20GB | 40GB | 120GB | 240GB |
-| LoRA | 16 | 16GB | 32GB | 80GB | 160GB |
-| QLoRA | 8 | 10GB | 16GB | 40GB | 80GB |
-| QLoRA | 4 | 6GB | 12GB | 24GB | 48GB |
+\* *estimated*
+
+| Method | Bits | 7B | 13B | 30B | 65B | 8x7B |
+| ------ | ---- | ----- | ----- | ----- | ------ | ------ |
+| Full | 16 | 160GB | 320GB | 600GB | 1200GB | 900GB |
+| Freeze | 16 | 20GB | 40GB | 120GB | 240GB | 200GB |
+| LoRA | 16 | 16GB | 32GB | 80GB | 160GB | 120GB |
+| QLoRA | 8 | 10GB | 16GB | 40GB | 80GB | 80GB |
+| QLoRA | 4 | 6GB | 12GB | 24GB | 48GB | 32GB |
 
 ## Getting Started
 
@@ -233,15 +285,17 @@ cd LLaMA-Factory
 pip install -r requirements.txt
 ```
 
-If you want to enable the quantized LoRA (QLoRA) on the Windows platform, you will be required to install a pre-built version of `bitsandbytes` library, which supports CUDA 11.1 to 12.1.
+If you want to enable the quantized LoRA (QLoRA) on the Windows platform, you will be required to install a pre-built version of `bitsandbytes` library, which supports CUDA 11.1 to 12.2.
 
 ```bash
-pip install https://github.com/jllllll/bitsandbytes-windows-webui/releases/download/wheels/bitsandbytes-0.39.1-py3-none-win_amd64.whl
+pip install https://github.com/jllllll/bitsandbytes-windows-webui/releases/download/wheels/bitsandbytes-0.40.0-py3-none-win_amd64.whl
 ```
 
-### Use ModelScope Models (optional)
+To enable FlashAttention-2 on the Windows platform, you need to install the precompiled `flash-attn` library, which supports CUDA 12.1 to 12.2. Please download the corresponding version from [flash-attention](https://github.com/bdashore3/flash-attention/releases) based on your requirements.
 
-If you have trouble with downloading models from Hugging Face, you can use LLaMA-Factory together with ModelScope in the following manner.
+### Use ModelScope Hub (optional)
+
+If you have trouble with downloading models and datasets from Hugging Face, you can use LLaMA-Factory together with ModelScope in the following manner.
 
 ```bash
 export USE_MODELSCOPE_HUB=1 # `set USE_MODELSCOPE_HUB=1` for Windows
@@ -255,7 +309,7 @@ CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
 ... # arguments (same as above)
 ```
 
-LLaMA Board also supports using the models on the ModelScope Hub.
+LLaMA Board also supports using the models and datasets on the ModelScope Hub.
 
 ```bash
 CUDA_VISIBLE_DEVICES=0 USE_MODELSCOPE_HUB=1 python src/train_web.py
@@ -271,8 +325,8 @@ CUDA_VISIBLE_DEVICES=0 USE_MODELSCOPE_HUB=1 python src/train_web.py
 ```bash
 CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
 --stage pt \
---model_name_or_path path_to_llama_model \
 --do_train \
+--model_name_or_path path_to_llama_model \
 --dataset wiki_demo \
 --finetuning_type lora \
 --lora_target q_proj,v_proj \
@@ -294,8 +348,8 @@ CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
 ```bash
 CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
 --stage sft \
---model_name_or_path path_to_llama_model \
 --do_train \
+--model_name_or_path path_to_llama_model \
 --dataset alpaca_gpt4_en \
 --template default \
 --finetuning_type lora \
@@ -318,14 +372,14 @@ CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
 ```bash
 CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
 --stage rm \
---model_name_or_path path_to_llama_model \
 --do_train \
+--model_name_or_path path_to_llama_model \
+--adapter_name_or_path path_to_sft_checkpoint \
+--create_new_adapter \
 --dataset comparison_gpt4_en \
 --template default \
 --finetuning_type lora \
 --lora_target q_proj,v_proj \
---resume_lora_training False \
---checkpoint_dir path_to_sft_checkpoint \
 --output_dir path_to_rm_checkpoint \
 --per_device_train_batch_size 2 \
 --gradient_accumulation_steps 4 \
@@ -343,14 +397,14 @@ CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
 ```bash
 CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
 --stage ppo \
---model_name_or_path path_to_llama_model \
 --do_train \
+--model_name_or_path path_to_llama_model \
+--adapter_name_or_path path_to_sft_checkpoint \
+--create_new_adapter \
 --dataset alpaca_gpt4_en \
 --template default \
 --finetuning_type lora \
 --lora_target q_proj,v_proj \
---resume_lora_training False \
---checkpoint_dir path_to_sft_checkpoint \
 --reward_model path_to_rm_checkpoint \
 --output_dir path_to_ppo_checkpoint \
 --per_device_train_batch_size 2 \
@@ -366,6 +420,9 @@ CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
 --fp16
 ```
 
+> [!TIP]
+> Use `--adapter_name_or_path path_to_sft_checkpoint,path_to_ppo_checkpoint` to infer the fine-tuned model.
+
 > [!WARNING]
 > Use `--per_device_train_batch_size=1` for LLaMA-2 models in fp16 PPO training.
 
@@ -374,14 +431,14 @@ CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
 ```bash
 CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
 --stage dpo \
---model_name_or_path path_to_llama_model \
 --do_train \
+--model_name_or_path path_to_llama_model \
+--adapter_name_or_path path_to_sft_checkpoint \
+--create_new_adapter \
 --dataset comparison_gpt4_en \
 --template default \
 --finetuning_type lora \
 --lora_target q_proj,v_proj \
---resume_lora_training False \
---checkpoint_dir path_to_sft_checkpoint \
 --output_dir path_to_dpo_checkpoint \
 --per_device_train_batch_size 2 \
 --gradient_accumulation_steps 4 \
@@ -394,6 +451,9 @@ CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
 --fp16
 ```
 
+> [!TIP]
+> Use `--adapter_name_or_path path_to_sft_checkpoint,path_to_dpo_checkpoint` to infer the fine-tuned model.
+
 ### Distributed Training
 
 #### Use Huggingface Accelerate
@@ -407,6 +467,7 @@ accelerate launch src/train_bash.py # arguments (same as above)
 
 ```yaml
 compute_environment: LOCAL_MACHINE
+debug: false
 distributed_type: MULTI_GPU
 downcast_bf16: 'no'
 gpu_ids: all
@@ -449,7 +510,7 @@ deepspeed --num_gpus 8 --master_port=9901 src/train_bash.py \
 "loss_scale_window": 1000,
 "hysteresis": 2,
 "min_loss_scale": 1
 },
 "zero_optimization": {
 "stage": 2,
 "allgather_partitions": true,
@@ -469,43 +530,51 @@ deepspeed --num_gpus 8 --master_port=9901 src/train_bash.py \
 ```bash
 python src/export_model.py \
 --model_name_or_path path_to_llama_model \
+--adapter_name_or_path path_to_checkpoint \
 --template default \
 --finetuning_type lora \
---checkpoint_dir path_to_checkpoint \
---export_dir path_to_export
+--export_dir path_to_export \
+--export_size 2 \
+--export_legacy_format False
 ```
 
-### API Demo
+> [!WARNING]
+> Merging LoRA weights into a quantized model is not supported.
+
+> [!TIP]
+> Use `--export_quantization_bit 4` and `--export_quantization_dataset data/c4_demo.json` to quantize the model after merging the LoRA weights.
+
+### Inference with OpenAI-style API
 
 ```bash
 python src/api_demo.py \
 --model_name_or_path path_to_llama_model \
+--adapter_name_or_path path_to_checkpoint \
 --template default \
---finetuning_type lora \
---checkpoint_dir path_to_checkpoint
+--finetuning_type lora
 ```
 
 > [!TIP]
 > Visit `http://localhost:8000/docs` for API documentation.
 
-### CLI Demo
+### Inference with command line
 
 ```bash
 python src/cli_demo.py \
 --model_name_or_path path_to_llama_model \
+--adapter_name_or_path path_to_checkpoint \
 --template default \
---finetuning_type lora \
---checkpoint_dir path_to_checkpoint
+--finetuning_type lora
 ```
 
-### Web Demo
+### Inference with web browser
 
 ```bash
 python src/web_demo.py \
 --model_name_or_path path_to_llama_model \
+--adapter_name_or_path path_to_checkpoint \
 --template default \
---finetuning_type lora \
---checkpoint_dir path_to_checkpoint
+--finetuning_type lora
 ```
 
 ### Evaluation
@@ -513,9 +582,9 @@ python src/web_demo.py \
 ```bash
 CUDA_VISIBLE_DEVICES=0 python src/evaluate.py \
 --model_name_or_path path_to_llama_model \
---finetuning_type lora \
---checkpoint_dir path_to_checkpoint \
+--adapter_name_or_path path_to_checkpoint \
 --template vanilla \
+--finetuning_type lora \
 --task mmlu \
 --split test \
 --lang en \
@@ -528,14 +597,14 @@ CUDA_VISIBLE_DEVICES=0 python src/evaluate.py \
 ```bash
 CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
 --stage sft \
---model_name_or_path path_to_llama_model \
 --do_predict \
+--model_name_or_path path_to_llama_model \
+--adapter_name_or_path path_to_checkpoint \
 --dataset alpaca_gpt4_en \
 --template default \
 --finetuning_type lora \
---checkpoint_dir path_to_checkpoint \
 --output_dir path_to_predict_result \
---per_device_eval_batch_size 8 \
+--per_device_eval_batch_size 1 \
 --max_samples 100 \
 --predict_with_generate \
 --fp16
@@ -549,10 +618,27 @@ CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
 
 ## Projects using LLaMA Factory
 
-- **[StarWhisper](https://github.com/Yu-Yang-Li/StarWhisper)**: A large language model for Astronomy, based on ChatGLM2-6B and Qwen-14B.
-- **[DISC-LawLLM](https://github.com/FudanDISC/DISC-LawLLM)**: A large language model specialized in Chinese legal domain, based on Baichuan-13B, is capable of retrieving and reasoning on legal knowledge.
-- **[Sunsimiao](https://github.com/thomas-yanxin/Sunsimiao)**: A large language model specialized in Chinese medical domain, based on Baichuan-7B and ChatGLM-6B.
-- **[CareGPT](https://github.com/WangRongsheng/CareGPT)**: A series of large language models for Chinese medical domain, based on LLaMA2-7B and Baichuan-13B.
+1. Wang et al. ESRL: Efficient Sampling-based Reinforcement Learning for Sequence Generation. 2023. [[arxiv]](https://arxiv.org/abs/2308.02223)
+1. Yu et al. Open, Closed, or Small Language Models for Text Classification? 2023. [[arxiv]](https://arxiv.org/abs/2308.10092)
+1. Luceri et al. Leveraging Large Language Models to Detect Influence Campaigns in Social Media. 2023. [[arxiv]](https://arxiv.org/abs/2311.07816)
+1. Zhang et al. Alleviating Hallucinations of Large Language Models through Induced Hallucinations. 2023. [[arxiv]](https://arxiv.org/abs/2312.15710)
+1. Wang et al. Know Your Needs Better: Towards Structured Understanding of Marketer Demands with Analogical Reasoning Augmented LLMs. 2024. [[arxiv]](https://arxiv.org/abs/2401.04319)
+1. Wang et al. CANDLE: Iterative Conceptualization and Instantiation Distillation from Large Language Models for Commonsense Reasoning. 2024. [[arxiv]](https://arxiv.org/abs/2401.07286)
+1. Choi et al. FACT-GPT: Fact-Checking Augmentation via Claim Matching with LLMs. 2024. [[arxiv]](https://arxiv.org/abs/2402.05904)
+1. Zhang et al. AutoMathText: Autonomous Data Selection with Language Models for Mathematical Texts. 2024. [[arxiv]](https://arxiv.org/abs/2402.07625)
+1. Lyu et al. KnowTuning: Knowledge-aware Fine-tuning for Large Language Models. 2024. [[arxiv]](https://arxiv.org/abs/2402.11176)
+1. Yang et al. LaCo: Large Language Model Pruning via Layer Collaps. 2024. [[arxiv]](https://arxiv.org/abs/2402.11187)
+1. Bhardwaj et al. Language Models are Homer Simpson! Safety Re-Alignment of Fine-tuned Language Models through Task Arithmetic. 2024. [[arxiv]](https://arxiv.org/abs/2402.11746)
+1. Yang et al. Enhancing Empathetic Response Generation by Augmenting LLMs with Small-scale Empathetic Models. 2024. [[arxiv]](https://arxiv.org/abs/2402.11801)
+1. Yi et al. Generation Meets Verification: Accelerating Large Language Model Inference with Smart Parallel Auto-Correct Decoding. 2024. [[arxiv]](https://arxiv.org/abs/2402.11809)
+1. Cao et al. Head-wise Shareable Attention for Large Language Models. 2024. [[arxiv]](https://arxiv.org/abs/2402.11819)
+1. Zhang et al. Enhancing Multilingual Capabilities of Large Language Models through Self-Distillation from Resource-Rich Languages. 2024. [[arxiv]](https://arxiv.org/abs/2402.12204)
+1. Kim et al. Efficient and Effective Vocabulary Expansion Towards Multilingual Large Language Models. 2024. [[arxiv]](https://arxiv.org/abs/2402.14714)
+1. **[StarWhisper](https://github.com/Yu-Yang-Li/StarWhisper)**: A large language model for Astronomy, based on ChatGLM2-6B and Qwen-14B.
+1. **[DISC-LawLLM](https://github.com/FudanDISC/DISC-LawLLM)**: A large language model specialized in Chinese legal domain, based on Baichuan-13B, is capable of retrieving and reasoning on legal knowledge.
+1. **[Sunsimiao](https://github.com/thomas-yanxin/Sunsimiao)**: A large language model specialized in Chinese medical domain, based on Baichuan-7B and ChatGLM-6B.
+1. **[CareGPT](https://github.com/WangRongsheng/CareGPT)**: A series of large language models for Chinese medical domain, based on LLaMA2-7B and Baichuan-13B.
+1. **[MachineMindset](https://github.com/PKU-YuanGroup/Machine-Mindset/)**: A series of MBTI Personality large language models, capable of giving any LLM 16 different personality types based on different datasets and training methods.
 
 > [!TIP]
 > If you have a project that should be incorporated, please contact via email or create a pull request.
@@ -561,7 +647,7 @@ CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
 
 This repository is licensed under the [Apache-2.0 License](LICENSE).
 
-Please follow the model licenses to use the corresponding model weights: [Baichuan](https://huggingface.co/baichuan-inc/Baichuan-13B-Base/resolve/main/Community%20License%20for%20Baichuan-13B%20Model.pdf) / [Baichuan2](https://huggingface.co/baichuan-inc/Baichuan2-13B-Chat/resolve/main/Community%20License%20for%20Baichuan2%20Model.pdf) / [BLOOM](https://huggingface.co/spaces/bigscience/license) / [ChatGLM3](https://github.com/THUDM/ChatGLM3/blob/main/MODEL_LICENSE) / [Falcon](https://huggingface.co/tiiuae/falcon-180B/blob/main/LICENSE.txt) / [InternLM](https://github.com/InternLM/InternLM#license) / [LLaMA](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) / [LLaMA-2](https://ai.meta.com/llama/license/) / [Mistral](LICENSE) / [Phi-1.5](https://huggingface.co/microsoft/phi-1_5/resolve/main/Research%20License.docx) / [Qwen](https://github.com/QwenLM/Qwen/blob/main/LICENSE) / [XVERSE](https://github.com/xverse-ai/XVERSE-13B/blob/main/MODEL_LICENSE.pdf)
+Please follow the model licenses to use the corresponding model weights: [Baichuan2](https://huggingface.co/baichuan-inc/Baichuan2-7B-Base/blob/main/Community%20License%20for%20Baichuan%202%20Model.pdf) / [BLOOM](https://huggingface.co/spaces/bigscience/license) / [ChatGLM3](https://github.com/THUDM/ChatGLM3/blob/main/MODEL_LICENSE) / [DeepSeek](https://github.com/deepseek-ai/DeepSeek-LLM/blob/main/LICENSE-MODEL) / [Falcon](https://huggingface.co/tiiuae/falcon-180B/blob/main/LICENSE.txt) / [Gemma](https://ai.google.dev/gemma/terms) / [InternLM2](https://github.com/InternLM/InternLM#license) / [LLaMA](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) / [LLaMA-2](https://ai.meta.com/llama/license/) / [Mistral](LICENSE) / [Phi-1.5/2](https://huggingface.co/microsoft/phi-1_5/resolve/main/Research%20License.docx) / [Qwen](https://github.com/QwenLM/Qwen/blob/main/Tongyi%20Qianwen%20LICENSE%20AGREEMENT) / [XVERSE](https://github.com/xverse-ai/XVERSE-13B/blob/main/MODEL_LICENSE.pdf) / [Yi](https://huggingface.co/01-ai/Yi-6B/blob/main/LICENSE) / [Yuan](https://github.com/IEIT-Yuan/Yuan-2.0/blob/main/LICENSE-Yuan)
 
 ## Citation
 
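
For reference, the "Requirement" table introduced in the README.md diff above pins minimum and recommended library versions; a hedged sketch of installing the minimum set with pip (version pins taken from that table, not an official install command; the repository's `requirements.txt` remains the authoritative source) would be:

```bash
# Illustrative only: minimum versions from the new requirement table above.
pip install "torch>=1.13.1" "transformers>=4.37.2" "datasets>=2.14.3" \
            "accelerate>=0.27.2" "peft>=0.9.0" "trl>=0.7.11"
```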
226
README_zh.md
226
README_zh.md
@@ -5,8 +5,9 @@
|
|||||||
[](https://github.com/hiyouga/LLaMA-Factory/commits/main)
|
[](https://github.com/hiyouga/LLaMA-Factory/commits/main)
|
||||||
[](https://pypi.org/project/llmtuner/)
|
[](https://pypi.org/project/llmtuner/)
|
||||||
[](https://pypi.org/project/llmtuner/)
|
[](https://pypi.org/project/llmtuner/)
|
||||||
|
[](#使用了-llama-factory-的项目)
|
||||||
[](https://github.com/hiyouga/LLaMA-Factory/pulls)
|
[](https://github.com/hiyouga/LLaMA-Factory/pulls)
|
||||||
[](https://discord.gg/c2EPEt5NU)
|
[](https://discord.gg/rKfvV9r9FK)
|
||||||
[](https://huggingface.co/spaces/hiyouga/LLaMA-Board)
|
[](https://huggingface.co/spaces/hiyouga/LLaMA-Board)
|
||||||
[](https://modelscope.cn/studios/hiyouga/LLaMA-Board)
|
[](https://modelscope.cn/studios/hiyouga/LLaMA-Board)
|
||||||
|
|
||||||
@@ -16,9 +17,7 @@
|
|||||||
|
|
||||||
## LLaMA Board: 通过一站式网页界面快速上手 LLaMA Factory
|
## LLaMA Board: 通过一站式网页界面快速上手 LLaMA Factory
|
||||||
|
|
||||||
通过 **[🤗 Spaces](https://huggingface.co/spaces/hiyouga/LLaMA-Board)** 或 **[ModelScope](https://modelscope.cn/studios/hiyouga/LLaMA-Board)** 预览 LLaMA Board。
|
通过 **[🤗 Spaces](https://huggingface.co/spaces/hiyouga/LLaMA-Board)** 或 **[ModelScope](https://modelscope.cn/studios/hiyouga/LLaMA-Board)** 预览 LLaMA Board,或者通过命令 `CUDA_VISIBLE_DEVICES=0 python src/train_web.py` 本地启动。
|
||||||
|
|
||||||
使用 `CUDA_VISIBLE_DEVICES=0 python src/train_web.py` 启动 LLaMA Board。(该模式目前仅支持单卡训练)
|
|
||||||
|
|
||||||
下面是使用单张 GPU 在 10 分钟内更改对话式大型语言模型自我认知的示例。
|
下面是使用单张 GPU 在 10 分钟内更改对话式大型语言模型自我认知的示例。
|
||||||
|
|
||||||
@@ -26,6 +25,7 @@ https://github.com/hiyouga/LLaMA-Factory/assets/16256802/6ba60acc-e2e2-4bec-b846
|
|||||||
|
|
||||||
## 目录
|
## 目录
|
||||||
|
|
||||||
|
- [项目特色](#项目特色)
|
||||||
- [性能指标](#性能指标)
|
- [性能指标](#性能指标)
|
||||||
- [更新日志](#更新日志)
|
- [更新日志](#更新日志)
|
||||||
- [模型](#模型)
|
- [模型](#模型)
|
||||||
@@ -38,6 +38,15 @@ https://github.com/hiyouga/LLaMA-Factory/assets/16256802/6ba60acc-e2e2-4bec-b846
|
|||||||
- [引用](#引用)
|
- [引用](#引用)
|
||||||
- [致谢](#致谢)
|
- [致谢](#致谢)
|
||||||
|
|
||||||
|
## 项目特色
|
||||||
|
|
||||||
|
- **多种模型**:LLaMA、Mistral、Mixtral-MoE、Qwen、Yi、Gemma、Baichuan、ChatGLM、Phi 等等。
|
||||||
|
- **集成方法**:(增量)预训练、指令监督微调、奖励模型训练、PPO 训练和 DPO 训练。
|
||||||
|
- **多种精度**:32 比特全参数微调、16 比特冻结微调、16 比特 LoRA 微调和基于 AQLM/AWQ/GPTQ/LLM.int8 的 2/4/8 比特 QLoRA 微调。
|
||||||
|
- **先进算法**:DoRA、LongLoRA、LLaMA Pro、LoftQ 和 Agent 微调。
|
||||||
|
- **实用技巧**:FlashAttention-2、Unsloth、RoPE scaling、NEFTune 和 rsLoRA。
|
||||||
|
- **实验监控**:LlamaBoard、TensorBoard、Wandb、MLflow 等等。
|
||||||
|
|
||||||
## 性能指标
|
## 性能指标
|
||||||
|
|
||||||
与 ChatGLM 官方的 [P-Tuning](https://github.com/THUDM/ChatGLM2-6B/tree/main/ptuning) 微调相比,LLaMA-Factory 的 LoRA 微调提供了 **3.7 倍**的加速比,同时在广告文案生成任务上取得了更高的 Rouge 分数。结合 4 比特量化技术,LLaMA-Factory 的 QLoRA 微调进一步降低了 GPU 显存消耗。
|
与 ChatGLM 官方的 [P-Tuning](https://github.com/THUDM/ChatGLM2-6B/tree/main/ptuning) 微调相比,LLaMA-Factory 的 LoRA 微调提供了 **3.7 倍**的加速比,同时在广告文案生成任务上取得了更高的 Rouge 分数。结合 4 比特量化技术,LLaMA-Factory 的 QLoRA 微调进一步降低了 GPU 显存消耗。
|
||||||
@@ -55,23 +64,35 @@ https://github.com/hiyouga/LLaMA-Factory/assets/16256802/6ba60acc-e2e2-4bec-b846
|
|||||||
|
|
||||||
## 更新日志
|
## 更新日志
|
||||||
|
|
||||||
[23/12/01] 我们支持了从 **[魔搭社区](https://modelscope.cn/models)** 下载预训练模型。详细用法请参照 [此教程](#使用魔搭社区可跳过)。
|
[24/02/28] 我们支持了 **[DoRA](https://arxiv.org/abs/2402.09353)** 微调。请使用 `--use_dora` 参数进行 DoRA 微调。
|
||||||
|
|
||||||
[23/10/21] 我们支持了 **[NEFTune](https://arxiv.org/abs/2310.05914)** 训练技巧。请使用 `--neft_alpha` 参数启用 NEFTune,例如 `--neft_alpha 5`。
|
[24/02/15] 我们支持了 [LLaMA Pro](https://github.com/TencentARC/LLaMA-Pro) 提出的**块扩展**方法。详细用法请参照 `tests/llama_pro.py`。
|
||||||
|
|
||||||
|
[24/02/05] Qwen1.5(Qwen2 测试版)系列模型已在 LLaMA-Factory 中实现微调支持。详情请查阅该[博客页面](https://qwenlm.github.io/zh/blog/qwen1.5/)。
|
||||||
|
|
||||||
<details><summary>展开日志</summary>
|
<details><summary>展开日志</summary>
|
||||||
|
|
||||||
|
[24/01/18] 我们针对绝大多数模型实现了 **Agent 微调**,微调时指定 `--dataset glaive_toolcall` 即可使模型获得工具调用能力。
|
||||||
|
|
||||||
|
[23/12/23] 我们针对 LLaMA, Mistral 和 Yi 模型支持了 **[unsloth](https://github.com/unslothai/unsloth)** 的 LoRA 训练加速。请使用 `--use_unsloth` 参数启用 unsloth 优化。该方法可提供 1.7 倍的训练速度,详情请查阅[此页面](https://github.com/hiyouga/LLaMA-Factory/wiki/Performance-comparison)。
|
||||||
|
|
||||||
|
[23/12/12] 我们支持了微调最新的混合专家模型 **[Mixtral 8x7B](https://huggingface.co/mistralai/Mixtral-8x7B-v0.1)**。硬件需求请查阅[此处](#硬件依赖)。
|
||||||
|
|
||||||
|
[23/12/01] 我们支持了从 **[魔搭社区](https://modelscope.cn/models)** 下载预训练模型和数据集。详细用法请参照 [此教程](#使用魔搭社区可跳过)。
|
||||||
|
|
||||||
|
[23/10/21] 我们支持了 **[NEFTune](https://arxiv.org/abs/2310.05914)** 训练技巧。请使用 `--neftune_noise_alpha` 参数启用 NEFTune,例如 `--neftune_noise_alpha 5`。
|
||||||
|
|
||||||
[23/09/27] 我们针对 LLaMA 模型支持了 [LongLoRA](https://github.com/dvlab-research/LongLoRA) 提出的 **$S^2$-Attn**。请使用 `--shift_attn` 参数以启用该功能。
|
[23/09/27] 我们针对 LLaMA 模型支持了 [LongLoRA](https://github.com/dvlab-research/LongLoRA) 提出的 **$S^2$-Attn**。请使用 `--shift_attn` 参数以启用该功能。
|
||||||
|
|
||||||
[23/09/23] 我们在项目中集成了 MMLU、C-Eval 和 CMMLU 评估集。使用方法请参阅[此示例](#模型评估)。
|
[23/09/23] 我们在项目中集成了 MMLU、C-Eval 和 CMMLU 评估集。使用方法请参阅[此示例](#模型评估)。
|
||||||
|
|
||||||
[23/09/10] 我们针对 LLaMA 模型支持了 **[FlashAttention-2](https://github.com/Dao-AILab/flash-attention)**。如果您使用的是 RTX4090、A100 或 H100 GPU,请使用 `--flash_attn` 参数以启用 FlashAttention-2。
|
[23/09/10] 我们支持了 **[FlashAttention-2](https://github.com/Dao-AILab/flash-attention)**。如果您使用的是 RTX4090、A100 或 H100 GPU,请使用 `--flash_attn` 参数以启用 FlashAttention-2。
|
||||||
|
|
||||||
[23/08/12] 我们支持了 **RoPE 插值**来扩展 LLaMA 模型的上下文长度。请使用 `--rope_scaling linear` 参数训练模型或使用 `--rope_scaling dynamic` 参数评估模型。
|
[23/08/12] 我们支持了 **RoPE 插值**来扩展 LLaMA 模型的上下文长度。请使用 `--rope_scaling linear` 参数训练模型或使用 `--rope_scaling dynamic` 参数评估模型。
|
||||||
|
|
||||||
[23/08/11] 我们支持了指令模型的 **[DPO 训练](https://arxiv.org/abs/2305.18290)**。使用方法请参阅[此示例](#dpo-训练)。
|
[23/08/11] 我们支持了指令模型的 **[DPO 训练](https://arxiv.org/abs/2305.18290)**。使用方法请参阅[此示例](#dpo-训练)。
|
||||||
|
|
||||||
[23/07/31] 我们支持了**数据流式加载**。请尝试使用 `--streaming` 和 `--max_steps 10000` 参数来流式加载数据集。
|
[23/07/31] 我们支持了**数据流式加载**。请使用 `--streaming` 和 `--max_steps 10000` 参数来流式加载数据集。
|
||||||
|
|
||||||
[23/07/29] 我们在 Hugging Face 发布了两个 13B 指令微调模型。详细内容请查阅我们的 Hugging Face 项目([LLaMA-2](https://huggingface.co/hiyouga/Llama-2-Chinese-13b-chat) / [Baichuan](https://huggingface.co/hiyouga/Baichuan-13B-sft))。
|
[23/07/29] 我们在 Hugging Face 发布了两个 13B 指令微调模型。详细内容请查阅我们的 Hugging Face 项目([LLaMA-2](https://huggingface.co/hiyouga/Llama-2-Chinese-13b-chat) / [Baichuan](https://huggingface.co/hiyouga/Baichuan-13B-sft))。
|
||||||
|
|
||||||
@@ -91,19 +112,24 @@ https://github.com/hiyouga/LLaMA-Factory/assets/16256802/6ba60acc-e2e2-4bec-b846
| Model | Model size | Default module | Template |
| -------------------------------------------------------- | --------------------------- | ----------------- | --------- |
| [Baichuan2](https://huggingface.co/baichuan-inc) | 7B/13B | W_pack | baichuan2 |
| [BLOOM](https://huggingface.co/bigscience/bloom) | 560M/1.1B/1.7B/3B/7.1B/176B | query_key_value | - |
| [BLOOMZ](https://huggingface.co/bigscience/bloomz) | 560M/1.1B/1.7B/3B/7.1B/176B | query_key_value | - |
| [ChatGLM3](https://huggingface.co/THUDM/chatglm3-6b) | 6B | query_key_value | chatglm3 |
| [DeepSeek (MoE)](https://huggingface.co/deepseek-ai) | 7B/16B/67B | q_proj,v_proj | deepseek |
| [Falcon](https://huggingface.co/tiiuae) | 7B/40B/180B | query_key_value | falcon |
| [Gemma](https://huggingface.co/google) | 2B/7B | q_proj,v_proj | gemma |
| [InternLM2](https://huggingface.co/internlm) | 7B/20B | wqkv | intern2 |
| [LLaMA](https://github.com/facebookresearch/llama) | 7B/13B/33B/65B | q_proj,v_proj | - |
| [LLaMA-2](https://huggingface.co/meta-llama) | 7B/13B/70B | q_proj,v_proj | llama2 |
| [Mistral](https://huggingface.co/mistralai) | 7B | q_proj,v_proj | mistral |
| [Mixtral](https://huggingface.co/mistralai) | 8x7B | q_proj,v_proj | mistral |
| [Phi-1.5/2](https://huggingface.co/microsoft) | 1.3B/2.7B | q_proj,v_proj | - |
| [Qwen](https://huggingface.co/Qwen) | 1.8B/7B/14B/72B | c_attn | qwen |
| [Qwen1.5](https://huggingface.co/Qwen) | 0.5B/1.8B/4B/7B/14B/72B | q_proj,v_proj | qwen |
| [XVERSE](https://huggingface.co/xverse) | 7B/13B/65B | q_proj,v_proj | xverse |
| [Yi](https://huggingface.co/01-ai) | 6B/34B | q_proj,v_proj | yi |
| [Yuan](https://huggingface.co/IEITYuan) | 2B/51B/102B | q_proj,v_proj | yuan |

> [!NOTE]
> The **default module** should be used as the value of the `--lora_target` argument; use `--lora_target all` to target all available modules.
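For example, for a model whose default module differs from the LLaMA family, the value from the table above is passed to `--lora_target`. A hedged sketch using Baichuan2 (the checkpoint id is an example, the dataset and output path follow the SFT example later in this README):

```bash
CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
    --stage sft \
    --do_train \
    --model_name_or_path baichuan-inc/Baichuan2-7B-Base \
    --dataset alpaca_gpt4_zh \
    --template baichuan2 \
    --finetuning_type lora \
    --lora_target W_pack \
    --output_dir path_to_sft_checkpoint \
    --fp16
```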
@@ -123,7 +149,7 @@ https://github.com/hiyouga/LLaMA-Factory/assets/16256802/6ba60acc-e2e2-4bec-b846
| DPO 训练 | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |

> [!NOTE]
> Use the `--quantization_bit 4` argument to enable QLoRA training.
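A hedged sketch of a QLoRA run built from the LoRA SFT example elsewhere in this README, with the quantization flag added (paths are placeholders):

```bash
CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
    --stage sft \
    --do_train \
    --model_name_or_path path_to_llama_model \
    --dataset alpaca_gpt4_zh \
    --template default \
    --finetuning_type lora \
    --lora_target q_proj,v_proj \
    --quantization_bit 4 \
    --output_dir path_to_sft_checkpoint \
    --fp16
```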
## 数据集
@@ -145,8 +171,8 @@ https://github.com/hiyouga/LLaMA-Factory/assets/16256802/6ba60acc-e2e2-4bec-b846
- [Stanford Alpaca (en)](https://github.com/tatsu-lab/stanford_alpaca)
- [Stanford Alpaca (zh)](https://github.com/ymcui/Chinese-LLaMA-Alpaca)
- [Alpaca GPT4 (en&zh)](https://github.com/Instruction-Tuning-with-GPT-4/GPT-4-LLM)
- [Self Cognition (zh)](data/self_cognition.json)
- [Open Assistant (multilingual)](https://huggingface.co/datasets/OpenAssistant/oasst1)
- [ShareGPT (zh)](https://huggingface.co/datasets/QingyiSi/Alpaca-CoT/tree/main/Chinese-instruction-collection)
- [Guanaco Dataset (multilingual)](https://huggingface.co/datasets/JosephusCheung/GuanacoDataset)
@@ -162,11 +188,14 @@ https://github.com/hiyouga/LLaMA-Factory/assets/16256802/6ba60acc-e2e2-4bec-b846
- [CodeAlpaca 20k (en)](https://huggingface.co/datasets/sahil2801/CodeAlpaca-20k)
- [Alpaca CoT (multilingual)](https://huggingface.co/datasets/QingyiSi/Alpaca-CoT)
- [OpenOrca (en)](https://huggingface.co/datasets/Open-Orca/OpenOrca)
- [SlimOrca (en)](https://huggingface.co/datasets/Open-Orca/SlimOrca)
- [MathInstruct (en)](https://huggingface.co/datasets/TIGER-Lab/MathInstruct)
- [Firefly 1.1M (zh)](https://huggingface.co/datasets/YeungNLP/firefly-train-1.1M)
- [Wiki QA (en)](https://huggingface.co/datasets/wiki_qa)
- [Web QA (zh)](https://huggingface.co/datasets/suolyer/webqa)
- [WebNovel (zh)](https://huggingface.co/datasets/zxbsmk/webnovel_cn)
- [Nectar (en)](https://huggingface.co/datasets/berkeley-nest/Nectar)
- [deepctrl (en&zh)](https://www.modelscope.cn/datasets/deepctrl/deepctrl-sft-data)
- [Ad Gen (zh)](https://huggingface.co/datasets/HasturOfficial/adgen)
- [ShareGPT Hyperfiltered (en)](https://huggingface.co/datasets/totally-not-an-llm/sharegpt-hyperfiltered-3k)
- [ShareGPT4 (en&zh)](https://huggingface.co/datasets/shibing624/sharegpt_gpt4)
@@ -174,6 +203,16 @@ https://github.com/hiyouga/LLaMA-Factory/assets/16256802/6ba60acc-e2e2-4bec-b846
- [AgentInstruct (en)](https://huggingface.co/datasets/THUDM/AgentInstruct)
- [LMSYS Chat 1M (en)](https://huggingface.co/datasets/lmsys/lmsys-chat-1m)
- [Evol Instruct V2 (en)](https://huggingface.co/datasets/WizardLM/WizardLM_evol_instruct_V2_196k)
- [Glaive Function Calling V2 (en)](https://huggingface.co/datasets/glaiveai/glaive-function-calling-v2)
- [Open Assistant (de)](https://huggingface.co/datasets/mayflowergmbh/oasst_de)
- [Dolly 15k (de)](https://huggingface.co/datasets/mayflowergmbh/dolly-15k_de)
- [Alpaca GPT4 (de)](https://huggingface.co/datasets/mayflowergmbh/alpaca-gpt4_de)
- [OpenSchnabeltier (de)](https://huggingface.co/datasets/mayflowergmbh/openschnabeltier_de)
- [Evol Instruct (de)](https://huggingface.co/datasets/mayflowergmbh/evol-instruct_de)
- [Dolphin (de)](https://huggingface.co/datasets/mayflowergmbh/dolphin_de)
- [Booksum (de)](https://huggingface.co/datasets/mayflowergmbh/booksum_de)
- [Airoboros (de)](https://huggingface.co/datasets/mayflowergmbh/airoboros-3.0_de)
- [Ultrachat (de)](https://huggingface.co/datasets/mayflowergmbh/ultra-chat_de)

</details>
@@ -183,6 +222,7 @@ https://github.com/hiyouga/LLaMA-Factory/assets/16256802/6ba60acc-e2e2-4bec-b846
- [Open Assistant (multilingual)](https://huggingface.co/datasets/OpenAssistant/oasst1)
- [GPT-4 Generated Data (en&zh)](https://github.com/Instruction-Tuning-with-GPT-4/GPT-4-LLM)
- [Nectar (en)](https://huggingface.co/datasets/berkeley-nest/Nectar)
- [Orca DPO (de)](https://huggingface.co/datasets/mayflowergmbh/intel_orca_dpo_pairs_de)

</details>
@@ -197,22 +237,34 @@ huggingface-cli login
## 软硬件依赖

| Mandatory | Minimum | Recommend |
| ------------ | ------- | --------- |
| python | 3.8 | 3.10 |
| torch | 1.13.1 | 2.2.1 |
| transformers | 4.37.2 | 4.38.1 |
| datasets | 2.14.3 | 2.17.1 |
| accelerate | 0.27.2 | 0.27.2 |
| peft | 0.9.0 | 0.9.0 |
| trl | 0.7.11 | 0.7.11 |

| Optional | Minimum | Recommend |
| ------------ | ------- | --------- |
| CUDA | 11.6 | 12.2 |
| deepspeed | 0.10.0 | 0.13.4 |
| bitsandbytes | 0.39.0 | 0.41.3 |
| flash-attn | 2.3.0 | 2.5.5 |
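If you need to bring an existing environment up to the minimum versions by hand, a hedged one-liner built from the table above (the project's own `requirements.txt` remains the authoritative list):

```bash
pip install "torch>=1.13.1" "transformers>=4.37.2" "datasets>=2.14.3" \
    "accelerate>=0.27.2" "peft>=0.9.0" "trl>=0.7.11"
```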
### 硬件依赖

\* *estimated*

| Method | Bits | 7B | 13B | 30B | 65B | 8x7B |
| ------- | ---- | ----- | ----- | ----- | ------ | ------ |
| Full | 16 | 160GB | 320GB | 600GB | 1200GB | 900GB |
| Partial | 16 | 20GB | 40GB | 120GB | 240GB | 200GB |
| LoRA | 16 | 16GB | 32GB | 80GB | 160GB | 120GB |
| QLoRA | 8 | 10GB | 16GB | 40GB | 80GB | 80GB |
| QLoRA | 4 | 6GB | 12GB | 24GB | 48GB | 32GB |
## 如何使用
@@ -233,15 +285,17 @@ cd LLaMA-Factory
pip install -r requirements.txt
```

To enable quantized LoRA (QLoRA) on Windows, you need to install a pre-built `bitsandbytes` wheel, which supports CUDA 11.1 to 12.2:

```bash
pip install https://github.com/jllllll/bitsandbytes-windows-webui/releases/download/wheels/bitsandbytes-0.40.0-py3-none-win_amd64.whl
```

To enable FlashAttention-2 on Windows, you need to install a pre-built `flash-attn` library, which supports CUDA 12.1 to 12.2; download the build that matches your setup from [flash-attention](https://github.com/bdashore3/flash-attention/releases).
### 使用魔搭社区(可跳过)

If you have trouble downloading models and datasets from Hugging Face, you can use the ModelScope hub as follows.

```bash
export USE_MODELSCOPE_HUB=1 # Windows: `set USE_MODELSCOPE_HUB=1`
```
@@ -255,7 +309,7 @@ CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
    ... # arguments same as above
```

LLaMA Board also supports downloading models and datasets from the ModelScope hub.

```bash
CUDA_VISIBLE_DEVICES=0 USE_MODELSCOPE_HUB=1 python src/train_web.py
```
@@ -271,8 +325,8 @@ CUDA_VISIBLE_DEVICES=0 USE_MODELSCOPE_HUB=1 python src/train_web.py
```bash
CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
    --stage pt \
    --do_train \
    --model_name_or_path path_to_llama_model \
    --dataset wiki_demo \
    --finetuning_type lora \
    --lora_target q_proj,v_proj \
@@ -294,8 +348,8 @@ CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
```bash
CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
    --stage sft \
    --do_train \
    --model_name_or_path path_to_llama_model \
    --dataset alpaca_gpt4_zh \
    --template default \
    --finetuning_type lora \
@@ -318,14 +372,14 @@ CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
```bash
CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
    --stage rm \
    --do_train \
    --model_name_or_path path_to_llama_model \
    --adapter_name_or_path path_to_sft_checkpoint \
    --create_new_adapter \
    --dataset comparison_gpt4_zh \
    --template default \
    --finetuning_type lora \
    --lora_target q_proj,v_proj \
    --output_dir path_to_rm_checkpoint \
    --per_device_train_batch_size 2 \
    --gradient_accumulation_steps 4 \
@@ -343,14 +397,14 @@ CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
```bash
CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
    --stage ppo \
    --do_train \
    --model_name_or_path path_to_llama_model \
    --adapter_name_or_path path_to_sft_checkpoint \
    --create_new_adapter \
    --dataset alpaca_gpt4_zh \
    --template default \
    --finetuning_type lora \
    --lora_target q_proj,v_proj \
    --reward_model path_to_rm_checkpoint \
    --output_dir path_to_ppo_checkpoint \
    --per_device_train_batch_size 2 \
@@ -366,6 +420,9 @@ CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
    --fp16
```

> [!TIP]
> Use `--adapter_name_or_path path_to_sft_checkpoint,path_to_ppo_checkpoint` to run inference with the fine-tuned model.

> [!WARNING]
> When running PPO training of LLaMA-2 models in fp16 precision, use `--per_device_train_batch_size=1`.
@@ -374,14 +431,14 @@ CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
```bash
CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
    --stage dpo \
    --do_train \
    --model_name_or_path path_to_llama_model \
    --adapter_name_or_path path_to_sft_checkpoint \
    --create_new_adapter \
    --dataset comparison_gpt4_zh \
    --template default \
    --finetuning_type lora \
    --lora_target q_proj,v_proj \
    --output_dir path_to_dpo_checkpoint \
    --per_device_train_batch_size 2 \
    --gradient_accumulation_steps 4 \
@@ -394,6 +451,9 @@ CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
    --fp16
```

> [!TIP]
> Use `--adapter_name_or_path path_to_sft_checkpoint,path_to_dpo_checkpoint` to run inference with the fine-tuned model.

### 多 GPU 分布式训练

#### 使用 Huggingface Accelerate
@@ -407,6 +467,7 @@ accelerate launch src/train_bash.py # 参数同上
```yaml
compute_environment: LOCAL_MACHINE
debug: false
distributed_type: MULTI_GPU
downcast_bf16: 'no'
gpu_ids: all
@@ -449,7 +510,7 @@ deepspeed --num_gpus 8 --master_port=9901 src/train_bash.py \
    "loss_scale_window": 1000,
    "hysteresis": 2,
    "min_loss_scale": 1
  },
  "zero_optimization": {
    "stage": 2,
    "allgather_partitions": true,
@@ -464,48 +525,56 @@ deepspeed --num_gpus 8 --master_port=9901 src/train_bash.py \
</details>

### 合并 LoRA 权重并导出模型

```bash
python src/export_model.py \
    --model_name_or_path path_to_llama_model \
    --adapter_name_or_path path_to_checkpoint \
    --template default \
    --finetuning_type lora \
    --export_dir path_to_export \
    --export_size 2 \
    --export_legacy_format False
```

> [!WARNING]
> Merging and exporting the LoRA weights of a quantized model is not yet supported.

> [!TIP]
> After merging the LoRA weights, you can quantize the merged model again with `--export_quantization_bit 4` and `--export_quantization_dataset data/c4_demo.json`.
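To apply the post-merge quantization mentioned in the tip above, a hedged sketch is shown below; it assumes the merged model from the previous command lives in `path_to_export`, and the output directory name is a placeholder:

```bash
python src/export_model.py \
    --model_name_or_path path_to_export \
    --template default \
    --export_dir path_to_export_int4 \
    --export_quantization_bit 4 \
    --export_quantization_dataset data/c4_demo.json
```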
### 使用 OpenAI 风格 API 推理

```bash
python src/api_demo.py \
    --model_name_or_path path_to_llama_model \
    --adapter_name_or_path path_to_checkpoint \
    --template default \
    --finetuning_type lora
```

> [!TIP]
> See `http://localhost:8000/docs` for the API documentation.
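Once the server is up it can be queried like any OpenAI-style endpoint. The route and payload below are an assumption based on that interface (check `http://localhost:8000/docs` for the exact schema exposed by your version):

```bash
curl http://localhost:8000/v1/chat/completions \
    -H "Content-Type: application/json" \
    -d '{"model": "default", "messages": [{"role": "user", "content": "Hello"}]}'
```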
### 使用命令行推理

```bash
python src/cli_demo.py \
    --model_name_or_path path_to_llama_model \
    --adapter_name_or_path path_to_checkpoint \
    --template default \
    --finetuning_type lora
```
### 使用浏览器推理

```bash
python src/web_demo.py \
    --model_name_or_path path_to_llama_model \
    --adapter_name_or_path path_to_checkpoint \
    --template default \
    --finetuning_type lora
```

### 模型评估
@@ -513,9 +582,9 @@ python src/web_demo.py \
```bash
CUDA_VISIBLE_DEVICES=0 python src/evaluate.py \
    --model_name_or_path path_to_llama_model \
    --adapter_name_or_path path_to_checkpoint \
    --template vanilla \
    --finetuning_type lora \
    --task ceval \
    --split validation \
    --lang zh \
@@ -528,14 +597,14 @@ CUDA_VISIBLE_DEVICES=0 python src/evaluate.py \
```bash
CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
    --stage sft \
    --do_predict \
    --model_name_or_path path_to_llama_model \
    --adapter_name_or_path path_to_checkpoint \
    --dataset alpaca_gpt4_zh \
    --template default \
    --finetuning_type lora \
    --output_dir path_to_predict_result \
    --per_device_eval_batch_size 1 \
    --max_samples 100 \
    --predict_with_generate \
    --fp16
```
@@ -549,10 +618,27 @@ CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
## 使用了 LLaMA Factory 的项目

1. Wang et al. ESRL: Efficient Sampling-based Reinforcement Learning for Sequence Generation. 2023. [[arxiv]](https://arxiv.org/abs/2308.02223)
1. Yu et al. Open, Closed, or Small Language Models for Text Classification? 2023. [[arxiv]](https://arxiv.org/abs/2308.10092)
1. Luceri et al. Leveraging Large Language Models to Detect Influence Campaigns in Social Media. 2023. [[arxiv]](https://arxiv.org/abs/2311.07816)
1. Zhang et al. Alleviating Hallucinations of Large Language Models through Induced Hallucinations. 2023. [[arxiv]](https://arxiv.org/abs/2312.15710)
1. Wang et al. Know Your Needs Better: Towards Structured Understanding of Marketer Demands with Analogical Reasoning Augmented LLMs. 2024. [[arxiv]](https://arxiv.org/abs/2401.04319)
1. Wang et al. CANDLE: Iterative Conceptualization and Instantiation Distillation from Large Language Models for Commonsense Reasoning. 2024. [[arxiv]](https://arxiv.org/abs/2401.07286)
1. Choi et al. FACT-GPT: Fact-Checking Augmentation via Claim Matching with LLMs. 2024. [[arxiv]](https://arxiv.org/abs/2402.05904)
1. Zhang et al. AutoMathText: Autonomous Data Selection with Language Models for Mathematical Texts. 2024. [[arxiv]](https://arxiv.org/abs/2402.07625)
1. Lyu et al. KnowTuning: Knowledge-aware Fine-tuning for Large Language Models. 2024. [[arxiv]](https://arxiv.org/abs/2402.11176)
1. Yang et al. LaCo: Large Language Model Pruning via Layer Collapse. 2024. [[arxiv]](https://arxiv.org/abs/2402.11187)
1. Bhardwaj et al. Language Models are Homer Simpson! Safety Re-Alignment of Fine-tuned Language Models through Task Arithmetic. 2024. [[arxiv]](https://arxiv.org/abs/2402.11746)
1. Yang et al. Enhancing Empathetic Response Generation by Augmenting LLMs with Small-scale Empathetic Models. 2024. [[arxiv]](https://arxiv.org/abs/2402.11801)
1. Yi et al. Generation Meets Verification: Accelerating Large Language Model Inference with Smart Parallel Auto-Correct Decoding. 2024. [[arxiv]](https://arxiv.org/abs/2402.11809)
1. Cao et al. Head-wise Shareable Attention for Large Language Models. 2024. [[arxiv]](https://arxiv.org/abs/2402.11819)
1. Zhang et al. Enhancing Multilingual Capabilities of Large Language Models through Self-Distillation from Resource-Rich Languages. 2024. [[arxiv]](https://arxiv.org/abs/2402.12204)
1. Kim et al. Efficient and Effective Vocabulary Expansion Towards Multilingual Large Language Models. 2024. [[arxiv]](https://arxiv.org/abs/2402.14714)
1. **[StarWhisper](https://github.com/Yu-Yang-Li/StarWhisper)**: an astronomy LLM fine-tuned from ChatGLM2-6B and Qwen-14B on astronomical data.
1. **[DISC-LawLLM](https://github.com/FudanDISC/DISC-LawLLM)**: a Chinese legal-domain LLM fine-tuned from Baichuan-13B, with legal reasoning and knowledge retrieval abilities.
1. **[Sunsimiao](https://github.com/thomas-yanxin/Sunsimiao)**: a Chinese medical LLM fine-tuned from Baichuan-7B and ChatGLM-6B on Chinese medical data.
1. **[CareGPT](https://github.com/WangRongsheng/CareGPT)**: a medical LLM project fine-tuned from LLaMA2-7B and Baichuan-13B on Chinese medical data.
1. **[MachineMindset](https://github.com/PKU-YuanGroup/Machine-Mindset/)**: a series of MBTI-personality LLMs, giving any LLM one of the 16 personality types through tailored datasets and training methods.

> [!TIP]
> If you have a project that should be added to the list above, please contact us by email or open a PR.
@@ -561,7 +647,7 @@ CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
The code in this repository is open-sourced under the [Apache-2.0](LICENSE) license.

Please follow the corresponding model licenses when using the model weights: [Baichuan2](https://huggingface.co/baichuan-inc/Baichuan2-7B-Base/blob/main/Community%20License%20for%20Baichuan%202%20Model.pdf) / [BLOOM](https://huggingface.co/spaces/bigscience/license) / [ChatGLM3](https://github.com/THUDM/ChatGLM3/blob/main/MODEL_LICENSE) / [DeepSeek](https://github.com/deepseek-ai/DeepSeek-LLM/blob/main/LICENSE-MODEL) / [Falcon](https://huggingface.co/tiiuae/falcon-180B/blob/main/LICENSE.txt) / [Gemma](https://ai.google.dev/gemma/terms) / [InternLM2](https://github.com/InternLM/InternLM#license) / [LLaMA](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) / [LLaMA-2](https://ai.meta.com/llama/license/) / [Mistral](LICENSE) / [Phi-1.5/2](https://huggingface.co/microsoft/phi-1_5/resolve/main/Research%20License.docx) / [Qwen](https://github.com/QwenLM/Qwen/blob/main/Tongyi%20Qianwen%20LICENSE%20AGREEMENT) / [XVERSE](https://github.com/xverse-ai/XVERSE-13B/blob/main/MODEL_LICENSE.pdf) / [Yi](https://huggingface.co/01-ai/Yi-6B/blob/main/LICENSE) / [Yuan](https://github.com/IEIT-Yuan/Yuan-2.0/blob/main/LICENSE-Yuan)

## 引用
7 SECURITY.md (new file)
@@ -0,0 +1,7 @@

# Reporting Security Issues

To report a security issue, please use the GitHub Security Advisory ["Report a Vulnerability"](https://github.com/electron/electron/security/advisories/new) tab.

We will send a response indicating the next steps in handling your report. After the initial reply to your report, the security team will keep you informed of the progress towards a fix and full announcement, and may ask for additional information or guidance.

Report security bugs in third-party modules to the person or team maintaining the module.
@@ -2,21 +2,32 @@ If you are using a custom dataset, please provide your dataset definition in the
```json
"dataset_name": {
  "hf_hub_url": "the name of the dataset repository on the Hugging Face hub. (if specified, ignore script_url and file_name)",
  "ms_hub_url": "the name of the dataset repository on the ModelScope hub. (if specified, ignore script_url and file_name)",
  "script_url": "the name of the directory containing a dataset loading script. (if specified, ignore file_name)",
  "file_name": "the name of the dataset file in this directory. (required if above are not specified)",
  "file_sha1": "the SHA-1 hash value of the dataset file. (optional, does not affect training)",
  "subset": "the name of the subset. (optional, default: None)",
  "folder": "the name of the folder of the dataset repository on the Hugging Face hub. (optional, default: None)",
  "ranking": "whether the dataset is a preference dataset or not. (default: false)",
  "formatting": "the format of the dataset. (optional, default: alpaca, can be chosen from {alpaca, sharegpt})",
  "columns (optional)": {
    "prompt": "the column name in the dataset containing the prompts. (default: instruction)",
    "query": "the column name in the dataset containing the queries. (default: input)",
    "response": "the column name in the dataset containing the responses. (default: output)",
    "history": "the column name in the dataset containing the histories. (default: None)",
    "messages": "the column name in the dataset containing the messages. (default: conversations)",
    "system": "the column name in the dataset containing the system prompts. (default: None)",
    "tools": "the column name in the dataset containing the tool description. (default: None)"
  },
  "tags (optional, used for the sharegpt format)": {
    "role_tag": "the key in the message represents the identity. (default: from)",
    "content_tag": "the key in the message represents the content. (default: value)",
    "user_tag": "the value of the role_tag represents the user. (default: human)",
    "assistant_tag": "the value of the role_tag represents the assistant. (default: gpt)",
    "observation_tag": "the value of the role_tag represents the tool results. (default: observation)",
    "function_tag": "the value of the role_tag represents the function call. (default: function_call)",
    "system_tag": "the value of the role_tag represents the system prompt. (default: system, can override system column)"
  }
}
```
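The optional `file_sha1` value can be generated from the shell; a minimal sketch on Linux (the file name is a placeholder):

```bash
sha1sum data/my_dataset.json | cut -d ' ' -f 1
```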
@@ -31,6 +42,7 @@ Currently we support dataset in **alpaca** or **sharegpt** format, the dataset i
"instruction": "user instruction (required)",
"input": "user input (optional)",
"output": "model response (required)",
"system": "system prompt (optional)",
"history": [
  ["user instruction in the first round (optional)", "model response in the first round (optional)"],
  ["user instruction in the second round (optional)", "model response in the second round (optional)"]
@@ -47,14 +59,15 @@ Regarding the above dataset, the `columns` in `dataset_info.json` should be:
    "prompt": "instruction",
    "query": "input",
    "response": "output",
    "system": "system",
    "history": "history"
  }
}
```

The `query` column will be concatenated with the `prompt` column and used as the user prompt, so the user prompt becomes `prompt\nquery`. The `response` column represents the model response.

The `system` column will be used as the system prompt. The `history` column is a list consisting of string tuples representing prompt-response pairs in the history. Note that the responses in the history **will also be used for training**.

For the pre-training datasets, only the `prompt` column will be used for training.
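As a concrete illustration of the alpaca layout described above, a hedged sketch that writes a one-example dataset file (the file name and contents are placeholders):

```bash
cat > data/my_dataset.json <<'EOF'
[
  {
    "instruction": "Summarize the following sentence.",
    "input": "LLaMA Factory provides an easy way to fine-tune large language models.",
    "output": "LLaMA Factory simplifies LLM fine-tuning."
  }
]
EOF
```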
@@ -85,7 +98,9 @@ The dataset in sharegpt format should follow the below format:
      "from": "gpt",
      "value": "model response"
    }
  ],
  "system": "system prompt (optional)",
  "tools": "tool description (optional)"
}
]
```
@@ -96,12 +111,18 @@ Regarding the above dataset, the `columns` in `dataset_info.json` should be:
"dataset_name": {
  "columns": {
    "messages": "conversations",
    "system": "system",
    "tools": "tools"
  },
  "tags": {
    "role_tag": "from",
    "content_tag": "value",
    "user_tag": "human",
    "assistant_tag": "gpt"
  }
}
```

where the `messages` column should be a list following the `u/a/u/a/u/a` order.

Pre-training datasets and preference datasets are not yet compatible with the sharegpt format.
@@ -2,21 +2,32 @@
|
|||||||
|
|
||||||
```json
|
```json
|
||||||
"数据集名称": {
|
"数据集名称": {
|
||||||
"hf_hub_url": "Hugging Face 上的项目地址(若指定,则忽略下列三个参数)",
|
"hf_hub_url": "Hugging Face 的数据集仓库地址(若指定,则忽略 script_url 和 file_name)",
|
||||||
"script_url": "包含数据加载脚本的本地文件夹名称(若指定,则忽略下列两个参数)",
|
"ms_hub_url": "ModelScope 的数据集仓库地址(若指定,则忽略 script_url 和 file_name)",
|
||||||
|
"script_url": "包含数据加载脚本的本地文件夹名称(若指定,则忽略 file_name)",
|
||||||
"file_name": "该目录下数据集文件的名称(若上述参数未指定,则此项必需)",
|
"file_name": "该目录下数据集文件的名称(若上述参数未指定,则此项必需)",
|
||||||
"file_sha1": "数据集文件的SHA-1哈希值(可选,留空不影响训练)",
|
"file_sha1": "数据集文件的 SHA-1 哈希值(可选,留空不影响训练)",
|
||||||
"subset": "数据集子集的名称(可选,默认:None)",
|
"subset": "数据集子集的名称(可选,默认:None)",
|
||||||
|
"folder": "Hugging Face 仓库的文件夹名称(可选,默认:None)",
|
||||||
"ranking": "是否为偏好数据集(可选,默认:False)",
|
"ranking": "是否为偏好数据集(可选,默认:False)",
|
||||||
"formatting": "数据集格式(可选,默认:alpaca,可以为 alpaca 或 sharegpt)",
|
"formatting": "数据集格式(可选,默认:alpaca,可以为 alpaca 或 sharegpt)",
|
||||||
"columns": {
|
"columns(可选)": {
|
||||||
"prompt": "数据集代表提示词的表头名称(默认:instruction,用于 alpaca 格式)",
|
"prompt": "数据集代表提示词的表头名称(默认:instruction)",
|
||||||
"query": "数据集代表请求的表头名称(默认:input,用于 alpaca 格式)",
|
"query": "数据集代表请求的表头名称(默认:input)",
|
||||||
"response": "数据集代表回答的表头名称(默认:output,用于 alpaca 格式)",
|
"response": "数据集代表回答的表头名称(默认:output)",
|
||||||
"history": "数据集代表历史对话的表头名称(默认:None,用于 alpaca 格式)",
|
"history": "数据集代表历史对话的表头名称(默认:None)",
|
||||||
"messages": "数据集代表消息列表的表头名称(默认:conversations,用于 sharegpt 格式)",
|
"messages": "数据集代表消息列表的表头名称(默认:conversations)",
|
||||||
"role": "消息中代表发送者身份的键名(默认:from,用于 sharegpt 格式)",
|
"system": "数据集代表系统提示的表头名称(默认:None)",
|
||||||
"content": "消息中代表文本内容的键名(默认:value,用于 sharegpt 格式)"
|
"tools": "数据集代表工具描述的表头名称(默认:None)"
|
||||||
|
},
|
||||||
|
"tags(可选,用于 sharegpt 格式)": {
|
||||||
|
"role_tag": "消息中代表发送者身份的键名(默认:from)",
|
||||||
|
"content_tag": "消息中代表文本内容的键名(默认:value)",
|
||||||
|
"user_tag": "消息中代表用户的 role_tag(默认:human)",
|
||||||
|
"assistant_tag": "消息中代表助手的 role_tag(默认:gpt)",
|
||||||
|
"observation_tag": "消息中代表工具返回结果的 role_tag(默认:observation)",
|
||||||
|
"function_tag": "消息中代表工具调用的 role_tag(默认:function_call)",
|
||||||
|
"system_tag": "消息中代表系统提示的 role_tag(默认:system,会覆盖 system 列)"
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
```
|
```
|
||||||
@@ -31,6 +42,7 @@
|
|||||||
"instruction": "用户指令(必填)",
|
"instruction": "用户指令(必填)",
|
||||||
"input": "用户输入(选填)",
|
"input": "用户输入(选填)",
|
||||||
"output": "模型回答(必填)",
|
"output": "模型回答(必填)",
|
||||||
|
"system": "系统提示词(选填)",
|
||||||
"history": [
|
"history": [
|
||||||
["第一轮指令(选填)", "第一轮回答(选填)"],
|
["第一轮指令(选填)", "第一轮回答(选填)"],
|
||||||
["第二轮指令(选填)", "第二轮回答(选填)"]
|
["第二轮指令(选填)", "第二轮回答(选填)"]
|
||||||
@@ -47,14 +59,15 @@
|
|||||||
"prompt": "instruction",
|
"prompt": "instruction",
|
||||||
"query": "input",
|
"query": "input",
|
||||||
"response": "output",
|
"response": "output",
|
||||||
|
"system": "system",
|
||||||
"history": "history"
|
"history": "history"
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
```
|
```
|
||||||
|
|
||||||
其中 `prompt` 和 `response` 列应当是非空的字符串,分别代表用户指令和模型回答。`query` 列的内容将会和 `prompt` 列拼接作为模型输入。
|
其中 `query` 列对应的内容会与 `prompt` 列对应的内容拼接后作为用户指令,即用户指令为 `prompt\nquery`。`response` 列对应的内容为模型回答。
|
||||||
|
|
||||||
`history` 列是由多个字符串二元组构成的列表,分别代表历史消息中每轮的指令和回答。注意每轮的模型回答**均会被用于训练**。
|
`system` 列对应的内容将被作为系统提示词。`history` 列是由多个字符串二元组构成的列表,分别代表历史消息中每轮的指令和回答。注意历史消息中的回答**也会被用于训练**。
|
||||||
|
|
||||||
对于预训练数据集,仅 `prompt` 列中的内容会用于模型训练。
|
对于预训练数据集,仅 `prompt` 列中的内容会用于模型训练。
|
||||||
|
|
||||||
@@ -85,7 +98,9 @@
|
|||||||
"from": "gpt",
|
"from": "gpt",
|
||||||
"value": "模型回答"
|
"value": "模型回答"
|
||||||
}
|
}
|
||||||
]
|
],
|
||||||
|
"system": "系统提示词(选填)",
|
||||||
|
"tools": "工具描述(选填)"
|
||||||
}
|
}
|
||||||
]
|
]
|
||||||
```
|
```
|
||||||
@@ -96,12 +111,18 @@
|
|||||||
"数据集名称": {
|
"数据集名称": {
|
||||||
"columns": {
|
"columns": {
|
||||||
"messages": "conversations",
|
"messages": "conversations",
|
||||||
"role": "from",
|
"system": "system",
|
||||||
"content": "value"
|
"tools": "tools"
|
||||||
|
},
|
||||||
|
"tags": {
|
||||||
|
"role_tag": "from",
|
||||||
|
"content_tag": "value",
|
||||||
|
"user_tag": "human",
|
||||||
|
"assistant_tag": "gpt"
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
```
|
```
|
||||||
|
|
||||||
其中 `messages` 列必须为偶数长度的列表,且符合 `用户/模型/用户/模型/用户/模型` 的顺序。
|
其中 `messages` 列应当是一个列表,且符合 `用户/模型/用户/模型/用户/模型` 的顺序。
|
||||||
|
|
||||||
预训练数据集和偏好数据集尚不支持 sharegpt 格式。
|
预训练数据集和偏好数据集尚不支持 sharegpt 格式。
|
||||||
|
|||||||
@@ -1 +1 @@
- fc9a6a3458caca2af8dafc6181773fe10c6d8657
+ 34c723573fbc2d7601f6d9c882ccf5aa4f9bcc4b

1 data/glaive_toolcall_10k.json.REMOVED.git-id (new file)
@@ -0,0 +1 @@
+ 4748dff00d1dc42768a5b6cc772143c313017812

@@ -1 +0,0 @@
- 38c89869c6aeca2a3af9ea1e09afe460f9b46810
29
examples/full_multi_gpu/sft.sh
Normal file
29
examples/full_multi_gpu/sft.sh
Normal file
@@ -0,0 +1,29 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
|
||||||
|
deepspeed --num_gpus 4 ../../src/train_bash.py \
|
||||||
|
--deepspeed ds_z3_config.json \
|
||||||
|
--stage sft \
|
||||||
|
--do_train \
|
||||||
|
--model_name_or_path meta-llama/Llama-2-7b-hf \
|
||||||
|
--dataset alpaca_gpt4_en \
|
||||||
|
--dataset_dir ../../data \
|
||||||
|
--template default \
|
||||||
|
--finetuning_type full \
|
||||||
|
--output_dir ../../saves/LLaMA2-7B/full/sft \
|
||||||
|
--overwrite_cache \
|
||||||
|
--overwrite_output_dir \
|
||||||
|
--cutoff_len 1024 \
|
||||||
|
--per_device_train_batch_size 1 \
|
||||||
|
--per_device_eval_batch_size 1 \
|
||||||
|
--gradient_accumulation_steps 2 \
|
||||||
|
--lr_scheduler_type cosine \
|
||||||
|
--logging_steps 10 \
|
||||||
|
--save_steps 100 \
|
||||||
|
--eval_steps 100 \
|
||||||
|
--evaluation_strategy steps \
|
||||||
|
--learning_rate 5e-5 \
|
||||||
|
--num_train_epochs 3.0 \
|
||||||
|
--max_samples 3000 \
|
||||||
|
--val_size 0.1 \
|
||||||
|
--plot_loss \
|
||||||
|
--fp16
|
||||||
16
examples/lora_multi_gpu/config.yaml
Normal file
16
examples/lora_multi_gpu/config.yaml
Normal file
@@ -0,0 +1,16 @@
|
|||||||
|
compute_environment: LOCAL_MACHINE
|
||||||
|
debug: false
|
||||||
|
distributed_type: MULTI_GPU
|
||||||
|
downcast_bf16: 'no'
|
||||||
|
gpu_ids: all
|
||||||
|
machine_rank: 0
|
||||||
|
main_training_function: main
|
||||||
|
mixed_precision: fp16
|
||||||
|
num_machines: 1
|
||||||
|
num_processes: 4
|
||||||
|
rdzv_backend: static
|
||||||
|
same_network: true
|
||||||
|
tpu_env: []
|
||||||
|
tpu_use_cluster: false
|
||||||
|
tpu_use_sudo: false
|
||||||
|
use_cpu: false
|
||||||
30
examples/lora_multi_gpu/sft.sh
Normal file
30
examples/lora_multi_gpu/sft.sh
Normal file
@@ -0,0 +1,30 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
|
||||||
|
CUDA_VISIBLE_DEVICES=0,1,2,3 accelerate launch --config_file config.yaml ../../src/train_bash.py \
|
||||||
|
--stage sft \
|
||||||
|
--do_train \
|
||||||
|
--model_name_or_path meta-llama/Llama-2-7b-hf \
|
||||||
|
--dataset alpaca_gpt4_en,glaive_toolcall \
|
||||||
|
--dataset_dir ../../data \
|
||||||
|
--template default \
|
||||||
|
--finetuning_type lora \
|
||||||
|
--lora_target q_proj,v_proj \
|
||||||
|
--output_dir ../../saves/LLaMA2-7B/lora/sft \
|
||||||
|
--overwrite_cache \
|
||||||
|
--overwrite_output_dir \
|
||||||
|
--cutoff_len 1024 \
|
||||||
|
--per_device_train_batch_size 1 \
|
||||||
|
--per_device_eval_batch_size 1 \
|
||||||
|
--gradient_accumulation_steps 2 \
|
||||||
|
--lr_scheduler_type cosine \
|
||||||
|
--logging_steps 10 \
|
||||||
|
--save_steps 100 \
|
||||||
|
--eval_steps 100 \
|
||||||
|
--evaluation_strategy steps \
|
||||||
|
--load_best_model_at_end \
|
||||||
|
--learning_rate 5e-5 \
|
||||||
|
--num_train_epochs 3.0 \
|
||||||
|
--max_samples 3000 \
|
||||||
|
--val_size 0.1 \
|
||||||
|
--plot_loss \
|
||||||
|
--fp16
|
||||||
33
examples/lora_single_gpu/dpo.sh
Normal file
33
examples/lora_single_gpu/dpo.sh
Normal file
@@ -0,0 +1,33 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
|
||||||
|
CUDA_VISIBLE_DEVICES=0 python ../../src/train_bash.py \
|
||||||
|
--stage dpo \
|
||||||
|
--do_train \
|
||||||
|
--model_name_or_path meta-llama/Llama-2-7b-hf \
|
||||||
|
--adapter_name_or_path ../../saves/LLaMA2-7B/lora/sft \
|
||||||
|
--create_new_adapter \
|
||||||
|
--dataset comparison_gpt4_en \
|
||||||
|
--dataset_dir ../../data \
|
||||||
|
--template default \
|
||||||
|
--finetuning_type lora \
|
||||||
|
--lora_target q_proj,v_proj \
|
||||||
|
--output_dir ../../saves/LLaMA2-7B/lora/dpo \
|
||||||
|
--overwrite_cache \
|
||||||
|
--overwrite_output_dir \
|
||||||
|
--cutoff_len 1024 \
|
||||||
|
--per_device_train_batch_size 1 \
|
||||||
|
--per_device_eval_batch_size 1 \
|
||||||
|
--gradient_accumulation_steps 8 \
|
||||||
|
--lr_scheduler_type cosine \
|
||||||
|
--logging_steps 10 \
|
||||||
|
--save_steps 100 \
|
||||||
|
--eval_steps 100 \
|
||||||
|
--evaluation_strategy steps \
|
||||||
|
--load_best_model_at_end \
|
||||||
|
--learning_rate 1e-5 \
|
||||||
|
--num_train_epochs 1.0 \
|
||||||
|
--max_samples 1000 \
|
||||||
|
--val_size 0.1 \
|
||||||
|
--dpo_ftx 1.0 \
|
||||||
|
--plot_loss \
|
||||||
|
--fp16
|
||||||
31
examples/lora_single_gpu/ppo.sh
Normal file
31
examples/lora_single_gpu/ppo.sh
Normal file
@@ -0,0 +1,31 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
|
||||||
|
CUDA_VISIBLE_DEVICES=0 python ../../src/train_bash.py \
|
||||||
|
--stage ppo \
|
||||||
|
--do_train \
|
||||||
|
--model_name_or_path meta-llama/Llama-2-7b-hf \
|
||||||
|
--adapter_name_or_path ../../saves/LLaMA2-7B/lora/sft \
|
||||||
|
--create_new_adapter \
|
||||||
|
--dataset alpaca_gpt4_en \
|
||||||
|
--dataset_dir ../../data \
|
||||||
|
--template default \
|
||||||
|
--finetuning_type lora \
|
||||||
|
--lora_target q_proj,v_proj \
|
||||||
|
--reward_model ../../saves/LLaMA2-7B/lora/reward \
|
||||||
|
--output_dir ../../saves/LLaMA2-7B/lora/ppo \
|
||||||
|
--overwrite_cache \
|
||||||
|
--overwrite_output_dir \
|
||||||
|
--cutoff_len 512 \
|
||||||
|
--per_device_train_batch_size 1 \
|
||||||
|
--gradient_accumulation_steps 8 \
|
||||||
|
--lr_scheduler_type cosine \
|
||||||
|
--logging_steps 10 \
|
||||||
|
--save_steps 100 \
|
||||||
|
--learning_rate 1e-5 \
|
||||||
|
--num_train_epochs 1.0 \
|
||||||
|
--max_samples 1000 \
|
||||||
|
--top_k 0 \
|
||||||
|
--top_p 0.9 \
|
||||||
|
--max_new_tokens 256 \
|
||||||
|
--plot_loss \
|
||||||
|
--fp16
|
||||||
18
examples/lora_single_gpu/predict.sh
Normal file
18
examples/lora_single_gpu/predict.sh
Normal file
@@ -0,0 +1,18 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
|
||||||
|
CUDA_VISIBLE_DEVICES=0 python ../../src/train_bash.py \
|
||||||
|
--stage sft \
|
||||||
|
--do_predict \
|
||||||
|
--model_name_or_path meta-llama/Llama-2-7b-hf \
|
||||||
|
--adapter_name_or_path ../../saves/LLaMA2-7B/lora/sft,../../saves/LLaMA2-7B/lora/dpo \
|
||||||
|
--dataset alpaca_gpt4_en,glaive_toolcall \
|
||||||
|
--dataset_dir ../../data \
|
||||||
|
--template default \
|
||||||
|
--finetuning_type lora \
|
||||||
|
--output_dir ../../saves/LLaMA2-7B/lora/predict \
|
||||||
|
--overwrite_cache \
|
||||||
|
--overwrite_output_dir \
|
||||||
|
--cutoff_len 1024 \
|
||||||
|
--per_device_eval_batch_size 1 \
|
||||||
|
--max_samples 20 \
|
||||||
|
--predict_with_generate
|
||||||
29
examples/lora_single_gpu/pretrain.sh
Normal file
29
examples/lora_single_gpu/pretrain.sh
Normal file
@@ -0,0 +1,29 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
|
||||||
|
CUDA_VISIBLE_DEVICES=0 python ../../src/train_bash.py \
|
||||||
|
--stage pt \
|
||||||
|
--do_train \
|
||||||
|
--model_name_or_path meta-llama/Llama-2-7b-hf \
|
||||||
|
--dataset c4_demo \
|
||||||
|
--dataset_dir ../../data \
|
||||||
|
--finetuning_type lora \
|
||||||
|
--lora_target q_proj,v_proj \
|
||||||
|
--output_dir ../../saves/LLaMA2-7B/lora/pretrain \
|
||||||
|
--overwrite_cache \
|
||||||
|
--overwrite_output_dir \
|
||||||
|
--cutoff_len 1024 \
|
||||||
|
--per_device_train_batch_size 1 \
|
||||||
|
--per_device_eval_batch_size 1 \
|
||||||
|
--gradient_accumulation_steps 8 \
|
||||||
|
--lr_scheduler_type cosine \
|
||||||
|
--logging_steps 10 \
|
||||||
|
--save_steps 100 \
|
||||||
|
--eval_steps 100 \
|
||||||
|
--evaluation_strategy steps \
|
||||||
|
--load_best_model_at_end \
|
||||||
|
--learning_rate 5e-5 \
|
||||||
|
--num_train_epochs 3.0 \
|
||||||
|
--max_samples 10000 \
|
||||||
|
--val_size 0.1 \
|
||||||
|
--plot_loss \
|
||||||
|
--fp16
|
||||||
31
examples/lora_single_gpu/reward.sh
Normal file
31
examples/lora_single_gpu/reward.sh
Normal file
@@ -0,0 +1,31 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
|
||||||
|
CUDA_VISIBLE_DEVICES=0 python ../../src/train_bash.py \
|
||||||
|
--stage rm \
|
||||||
|
--do_train \
|
||||||
|
--model_name_or_path meta-llama/Llama-2-7b-hf \
|
||||||
|
--adapter_name_or_path ../../saves/LLaMA2-7B/lora/sft \
|
||||||
|
--create_new_adapter \
|
||||||
|
--dataset comparison_gpt4_en \
|
||||||
|
--dataset_dir ../../data \
|
||||||
|
--template default \
|
||||||
|
--finetuning_type lora \
|
||||||
|
--lora_target q_proj,v_proj \
|
||||||
|
--output_dir ../../saves/LLaMA2-7B/lora/reward \
|
||||||
|
--overwrite_cache \
|
||||||
|
--overwrite_output_dir \
|
||||||
|
--cutoff_len 1024 \
|
||||||
|
--per_device_train_batch_size 1 \
|
||||||
|
--per_device_eval_batch_size 1 \
|
||||||
|
--gradient_accumulation_steps 8 \
|
||||||
|
--lr_scheduler_type cosine \
|
||||||
|
--logging_steps 10 \
|
||||||
|
--save_steps 100 \
|
||||||
|
--eval_steps 100 \
|
||||||
|
--evaluation_strategy steps \
|
||||||
|
--learning_rate 1e-5 \
|
||||||
|
--num_train_epochs 1.0 \
|
||||||
|
--max_samples 5000 \
|
||||||
|
--val_size 0.1 \
|
||||||
|
--plot_loss \
|
||||||
|
--fp16
|
||||||
30
examples/lora_single_gpu/sft.sh
Normal file
30
examples/lora_single_gpu/sft.sh
Normal file
@@ -0,0 +1,30 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
|
||||||
|
CUDA_VISIBLE_DEVICES=0 python ../../src/train_bash.py \
|
||||||
|
--stage sft \
|
||||||
|
--do_train \
|
||||||
|
--model_name_or_path meta-llama/Llama-2-7b-hf \
|
||||||
|
--dataset alpaca_gpt4_en,glaive_toolcall \
|
||||||
|
--dataset_dir ../../data \
|
||||||
|
--template default \
|
||||||
|
--finetuning_type lora \
|
||||||
|
--lora_target q_proj,v_proj \
|
||||||
|
--output_dir ../../saves/LLaMA2-7B/lora/sft \
|
||||||
|
--overwrite_cache \
|
||||||
|
--overwrite_output_dir \
|
||||||
|
--cutoff_len 1024 \
|
||||||
|
--per_device_train_batch_size 1 \
|
||||||
|
--per_device_eval_batch_size 1 \
|
||||||
|
--gradient_accumulation_steps 8 \
|
||||||
|
--lr_scheduler_type cosine \
|
||||||
|
--logging_steps 10 \
|
||||||
|
--save_steps 100 \
|
||||||
|
--eval_steps 100 \
|
||||||
|
--evaluation_strategy steps \
|
||||||
|
--load_best_model_at_end \
|
||||||
|
--learning_rate 5e-5 \
|
||||||
|
--num_train_epochs 3.0 \
|
||||||
|
--max_samples 3000 \
|
||||||
|
--val_size 0.1 \
|
||||||
|
--plot_loss \
|
||||||
|
--fp16
|
||||||
30
examples/qlora_single_gpu/aqlm.sh
Normal file
30
examples/qlora_single_gpu/aqlm.sh
Normal file
@@ -0,0 +1,30 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
|
||||||
|
CUDA_VISIBLE_DEVICES=0 python ../../src/train_bash.py \
|
||||||
|
--stage sft \
|
||||||
|
--do_train \
|
||||||
|
--model_name_or_path BlackSamorez/Llama-2-7b-AQLM-2Bit-1x16-hf \
|
||||||
|
--dataset alpaca_gpt4_en,glaive_toolcall \
|
||||||
|
--dataset_dir ../../data \
|
||||||
|
--template default \
|
||||||
|
--finetuning_type lora \
|
||||||
|
--lora_target q_proj,v_proj \
|
||||||
|
--output_dir ../../saves/LLaMA2-7B/lora/sft \
|
||||||
|
--overwrite_cache \
|
||||||
|
--overwrite_output_dir \
|
||||||
|
--cutoff_len 1024 \
|
||||||
|
--per_device_train_batch_size 1 \
|
||||||
|
--per_device_eval_batch_size 1 \
|
||||||
|
--gradient_accumulation_steps 8 \
|
||||||
|
--lr_scheduler_type cosine \
|
||||||
|
--logging_steps 10 \
|
||||||
|
--save_steps 100 \
|
||||||
|
--eval_steps 100 \
|
||||||
|
--evaluation_strategy steps \
|
||||||
|
--load_best_model_at_end \
|
||||||
|
--learning_rate 5e-5 \
|
||||||
|
--num_train_epochs 3.0 \
|
||||||
|
--max_samples 3000 \
|
||||||
|
--val_size 0.1 \
|
||||||
|
--plot_loss \
|
||||||
|
--fp16
|
||||||
30
examples/qlora_single_gpu/awq.sh
Normal file
30
examples/qlora_single_gpu/awq.sh
Normal file
@@ -0,0 +1,30 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
|
||||||
|
CUDA_VISIBLE_DEVICES=0 python ../../src/train_bash.py \
|
||||||
|
--stage sft \
|
||||||
|
--do_train \
|
||||||
|
--model_name_or_path TheBloke/Llama-2-7B-AWQ \
|
||||||
|
--dataset alpaca_gpt4_en,glaive_toolcall \
|
||||||
|
--dataset_dir ../../data \
|
||||||
|
--template default \
|
||||||
|
--finetuning_type lora \
|
||||||
|
--lora_target q_proj,v_proj \
|
||||||
|
--output_dir ../../saves/LLaMA2-7B/lora/sft \
|
||||||
|
--overwrite_cache \
|
||||||
|
--overwrite_output_dir \
|
||||||
|
--cutoff_len 1024 \
|
||||||
|
--per_device_train_batch_size 1 \
|
||||||
|
--per_device_eval_batch_size 1 \
|
||||||
|
--gradient_accumulation_steps 8 \
|
||||||
|
--lr_scheduler_type cosine \
|
||||||
|
--logging_steps 10 \
|
||||||
|
--save_steps 100 \
|
||||||
|
--eval_steps 100 \
|
||||||
|
--evaluation_strategy steps \
|
||||||
|
--load_best_model_at_end \
|
||||||
|
--learning_rate 5e-5 \
|
||||||
|
--num_train_epochs 3.0 \
|
||||||
|
--max_samples 3000 \
|
||||||
|
--val_size 0.1 \
|
||||||
|
--plot_loss \
|
||||||
|
--fp16
|
||||||
31
examples/qlora_single_gpu/bitsandbytes.sh
Normal file
@@ -0,0 +1,31 @@
#!/bin/bash

CUDA_VISIBLE_DEVICES=0 python ../../src/train_bash.py \
    --stage sft \
    --do_train \
    --model_name_or_path meta-llama/Llama-2-7b-hf \
    --dataset alpaca_gpt4_en,glaive_toolcall \
    --dataset_dir ../../data \
    --template default \
    --finetuning_type lora \
    --lora_target q_proj,v_proj \
    --output_dir ../../saves/LLaMA2-7B/lora/sft \
    --overwrite_cache \
    --overwrite_output_dir \
    --cutoff_len 1024 \
    --per_device_train_batch_size 1 \
    --per_device_eval_batch_size 1 \
    --gradient_accumulation_steps 8 \
    --lr_scheduler_type cosine \
    --logging_steps 10 \
    --save_steps 100 \
    --eval_steps 100 \
    --evaluation_strategy steps \
    --load_best_model_at_end \
    --learning_rate 5e-5 \
    --num_train_epochs 3.0 \
    --max_samples 3000 \
    --val_size 0.1 \
    --quantization_bit 4 \
    --plot_loss \
    --fp16
30
examples/qlora_single_gpu/gptq.sh
Normal file
@@ -0,0 +1,30 @@
#!/bin/bash

CUDA_VISIBLE_DEVICES=0 python ../../src/train_bash.py \
    --stage sft \
    --do_train \
    --model_name_or_path TheBloke/Llama-2-7B-GPTQ \
    --dataset alpaca_gpt4_en,glaive_toolcall \
    --dataset_dir ../../data \
    --template default \
    --finetuning_type lora \
    --lora_target q_proj,v_proj \
    --output_dir ../../saves/LLaMA2-7B/lora/sft \
    --overwrite_cache \
    --overwrite_output_dir \
    --cutoff_len 1024 \
    --per_device_train_batch_size 1 \
    --per_device_eval_batch_size 1 \
    --gradient_accumulation_steps 8 \
    --lr_scheduler_type cosine \
    --logging_steps 10 \
    --save_steps 100 \
    --eval_steps 100 \
    --evaluation_strategy steps \
    --load_best_model_at_end \
    --learning_rate 5e-5 \
    --num_train_epochs 3.0 \
    --max_samples 3000 \
    --val_size 0.1 \
    --plot_loss \
    --fp16
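The four new scripts differ only in the quantized checkpoint they point at, with the bitsandbytes variant additionally passing `--quantization_bit 4`. As a rough sketch, and assuming `run_exp` accepts the same keys as the CLI flags (an assumption on my part, not something shown in this diff), the same QLoRA run could be launched from Python:

```python
# Hedged sketch: launching the bitsandbytes QLoRA recipe programmatically.
# run_exp is exported from llmtuner (see the __all__ change further down in this diff);
# the exact argument keys mirror the CLI flags and should be treated as assumptions.
from llmtuner import run_exp

run_exp(dict(
    stage="sft",
    do_train=True,
    model_name_or_path="meta-llama/Llama-2-7b-hf",
    dataset="alpaca_gpt4_en,glaive_toolcall",
    template="default",
    finetuning_type="lora",
    lora_target="q_proj,v_proj",
    output_dir="saves/LLaMA2-7B/lora/sft",  # hypothetical output path
    per_device_train_batch_size=1,
    gradient_accumulation_steps=8,
    learning_rate=5e-5,
    num_train_epochs=3.0,
    quantization_bit=4,
    fp16=True,
))
```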
@@ -1,3 +1,32 @@
 [build-system]
 requires = ["setuptools>=61.0"]
 build-backend = "setuptools.build_meta"
+
+[tool.ruff]
+target-version = "py38"
+line-length = 119
+indent-width = 4
+
+[tool.ruff.lint]
+ignore = ["C408", "C901", "E501", "E731", "E741", "W605"]
+select = ["C", "E", "F", "I", "W"]
+
+[tool.ruff.lint.isort]
+lines-after-imports = 2
+known-first-party = ["llmtuner"]
+known-third-party = [
+    "accelerate",
+    "datasets",
+    "gradio",
+    "numpy",
+    "peft",
+    "torch",
+    "transformers",
+    "trl"
+]
+
+[tool.ruff.format]
+quote-style = "double"
+indent-style = "space"
+skip-magic-trailing-comma = false
+line-ending = "auto"
@@ -1,14 +1,14 @@
 torch>=1.13.1
-transformers>=4.31.0,<4.35.0
+transformers>=4.37.2
-datasets>=2.14.0
+datasets>=2.14.3
-accelerate>=0.21.0
+accelerate>=0.27.2
-peft>=0.6.0
+peft>=0.9.0
-trl>=0.7.4
+trl>=0.7.11
 gradio>=3.38.0,<4.0.0
 scipy
+einops
 sentencepiece
 protobuf
-tiktoken
 jieba
 rouge-chinese
 nltk
@@ -1,3 +1,5 @@
+import os
+
 import uvicorn
 
 from llmtuner import ChatModel, create_app
@@ -6,8 +8,8 @@ from llmtuner import ChatModel, create_app
 def main():
     chat_model = ChatModel()
     app = create_app(chat_model)
-    print("Visit http://localhost:8000/docs for API document.")
-    uvicorn.run(app, host="0.0.0.0", port=8000, workers=1)
+    print("Visit http://localhost:{}/docs for API document.".format(os.environ.get("API_PORT", 8000)))
+    uvicorn.run(app, host="0.0.0.0", port=int(os.environ.get("API_PORT", 8000)), workers=1)
 
 
 if __name__ == "__main__":
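With the port now configurable via `API_PORT`, the demo server behaves like any OpenAI-compatible endpoint. A minimal client sketch, assuming the default port and the `openai>=1.0` Python package (neither is part of the diff itself):

```python
# Hedged sketch: calling the API demo through an OpenAI-compatible client.
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="none")  # no auth is enforced in the shown code
result = client.chat.completions.create(
    model="gpt-3.5-turbo",  # the server reports this placeholder model id
    messages=[{"role": "user", "content": "Hello!"}],
)
print(result.choices[0].message.content)
```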
@@ -1,17 +1,19 @@
 from llmtuner import ChatModel
 from llmtuner.extras.misc import torch_gc
 
+
 try:
     import platform
+
     if platform.system() != "Windows":
-        import readline
+        import readline  # noqa: F401
 except ImportError:
     print("Install `readline` for a better experience.")
 
 
 def main():
     chat_model = ChatModel()
-    history = []
+    messages = []
     print("Welcome to the CLI application, use `clear` to remove the history, use `exit` to exit the application.")
 
     while True:
@@ -27,20 +29,20 @@ def main():
             break
 
         if query.strip() == "clear":
-            history = []
+            messages = []
            torch_gc()
             print("History has been removed.")
             continue
 
+        messages.append({"role": "user", "content": query})
         print("Assistant: ", end="", flush=True)
 
         response = ""
-        for new_text in chat_model.stream_chat(query, history):
+        for new_text in chat_model.stream_chat(messages):
             print(new_text, end="", flush=True)
             response += new_text
         print()
-        history = history + [(query, response)]
+        messages.append({"role": "assistant", "content": response})
 
 
 if __name__ == "__main__":
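The CLI now keeps an OpenAI-style list of message dicts instead of `(query, response)` tuples, appending the user turn before generation and the assistant turn afterwards. A condensed sketch of that bookkeeping, reusing a `chat_model` as constructed in the demo above:

```python
# Hedged sketch of the new message bookkeeping in the CLI loop.
messages = []
query = "What is QLoRA?"
messages.append({"role": "user", "content": query})

response = "".join(chat_model.stream_chat(messages))  # stream_chat now takes the message list
messages.append({"role": "assistant", "content": response})
```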
@@ -1,10 +1,11 @@
 # Level: api, webui > chat, eval, train > data, model > extras, hparams
 
-from llmtuner.api import create_app
-from llmtuner.chat import ChatModel
-from llmtuner.eval import Evaluator
-from llmtuner.train import export_model, run_exp
-from llmtuner.webui import create_ui, create_web_demo
+from .api import create_app
+from .chat import ChatModel
+from .eval import Evaluator
+from .train import export_model, run_exp
+from .webui import create_ui, create_web_demo
 
 
-__version__ = "0.3.3"
+__version__ = "0.5.3"
+__all__ = ["create_app", "ChatModel", "Evaluator", "export_model", "run_exp", "create_ui", "create_web_demo"]
@@ -1 +1,4 @@
-from llmtuner.api.app import create_app
+from .app import create_app
+
+
+__all__ = ["create_app"]
@@ -1,28 +1,31 @@
-import json
-from typing import List, Tuple
-from pydantic import BaseModel
-from contextlib import asynccontextmanager
-
-from llmtuner.api.protocol import (
-    Role,
-    Finish,
-    ModelCard,
-    ModelList,
-    ChatMessage,
-    DeltaMessage,
-    ChatCompletionRequest,
-    ChatCompletionResponse,
-    ChatCompletionStreamResponse,
-    ChatCompletionResponseChoice,
-    ChatCompletionResponseStreamChoice,
-    ChatCompletionResponseUsage,
-    ScoreEvaluationRequest,
-    ScoreEvaluationResponse
-)
-from llmtuner.chat import ChatModel
-from llmtuner.extras.misc import torch_gc
-from llmtuner.extras.packages import (
-    is_fastapi_availble, is_starlette_available, is_uvicorn_available
-)
+import asyncio
+import json
+import os
+from contextlib import asynccontextmanager
+from typing import Any, Dict, Sequence
+
+from pydantic import BaseModel
+
+from ..chat import ChatModel
+from ..data import Role as DataRole
+from ..extras.misc import torch_gc
+from ..extras.packages import is_fastapi_availble, is_starlette_available, is_uvicorn_available
+from .protocol import (
+    ChatCompletionMessage,
+    ChatCompletionRequest,
+    ChatCompletionResponse,
+    ChatCompletionResponseChoice,
+    ChatCompletionResponseStreamChoice,
+    ChatCompletionResponseUsage,
+    ChatCompletionStreamResponse,
+    Finish,
+    Function,
+    FunctionCall,
+    ModelCard,
+    ModelList,
+    Role,
+    ScoreEvaluationRequest,
+    ScoreEvaluationResponse,
+)
 
 
@@ -40,15 +43,22 @@ if is_uvicorn_available():
 
 
 @asynccontextmanager
 async def lifespan(app: "FastAPI"): # collects GPU memory
     yield
     torch_gc()
 
 
-def to_json(data: BaseModel) -> str:
+def dictify(data: "BaseModel") -> Dict[str, Any]:
     try: # pydantic v2
+        return data.model_dump(exclude_unset=True)
+    except AttributeError: # pydantic v1
+        return data.dict(exclude_unset=True)
+
+
+def jsonify(data: "BaseModel") -> str:
+    try: # pydantic v2
         return json.dumps(data.model_dump(exclude_unset=True), ensure_ascii=False)
-    except: # pydantic v1
+    except AttributeError: # pydantic v1
         return data.json(exclude_unset=True, ensure_ascii=False)
 
 
@@ -63,6 +73,15 @@ def create_app(chat_model: "ChatModel") -> "FastAPI":
         allow_headers=["*"],
     )
 
+    semaphore = asyncio.Semaphore(int(os.environ.get("MAX_CONCURRENT", 1)))
+    role_mapping = {
+        Role.USER: DataRole.USER.value,
+        Role.ASSISTANT: DataRole.ASSISTANT.value,
+        Role.SYSTEM: DataRole.SYSTEM.value,
+        Role.FUNCTION: DataRole.FUNCTION.value,
+        Role.TOOL: DataRole.OBSERVATION.value,
+    }
+
     @app.get("/v1/models", response_model=ModelList)
     async def list_models():
         model_card = ModelCard(id="gpt-3.5-turbo")
@@ -73,92 +92,123 @@ def create_app(chat_model: "ChatModel") -> "FastAPI":
         if not chat_model.can_generate:
             raise HTTPException(status_code=status.HTTP_405_METHOD_NOT_ALLOWED, detail="Not allowed")
 
-        if len(request.messages) == 0 or request.messages[-1].role != Role.USER:
-            raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid request")
+        if len(request.messages) == 0:
+            raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid length")
 
-        query = request.messages[-1].content
-        prev_messages = request.messages[:-1]
-        if len(prev_messages) and prev_messages[0].role == Role.SYSTEM:
-            system = prev_messages.pop(0).content
+        if request.messages[0].role == Role.SYSTEM:
+            system = request.messages.pop(0).content
         else:
-            system = None
+            system = ""
 
-        history = []
-        if len(prev_messages) % 2 == 0:
-            for i in range(0, len(prev_messages), 2):
-                if prev_messages[i].role == Role.USER and prev_messages[i+1].role == Role.ASSISTANT:
-                    history.append([prev_messages[i].content, prev_messages[i+1].content])
-                else:
-                    raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Only supports u/a/u/a/u...")
-        else:
+        if len(request.messages) % 2 == 0:
             raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Only supports u/a/u/a/u...")
 
+        input_messages = []
+        for i, message in enumerate(request.messages):
+            if i % 2 == 0 and message.role not in [Role.USER, Role.TOOL]:
+                raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid role")
+            elif i % 2 == 1 and message.role not in [Role.ASSISTANT, Role.FUNCTION]:
+                raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid role")
+
+            input_messages.append({"role": role_mapping[message.role], "content": message.content})
+
+        tool_list = request.tools
+        if isinstance(tool_list, list) and len(tool_list):
+            try:
+                tools = json.dumps([tool["function"] for tool in tool_list], ensure_ascii=False)
+            except Exception:
+                raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid tools")
+        else:
+            tools = ""
+
+        async with semaphore:
+            loop = asyncio.get_running_loop()
+            return await loop.run_in_executor(None, chat_completion, input_messages, system, tools, request)
+
+    def chat_completion(messages: Sequence[Dict[str, str]], system: str, tools: str, request: ChatCompletionRequest):
         if request.stream:
-            generate = predict(query, history, system, request)
+            if tools:
+                raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Cannot stream function calls.")
+
+            generate = stream_chat_completion(messages, system, tools, request)
             return EventSourceResponse(generate, media_type="text/event-stream")
 
         responses = chat_model.chat(
-            query, history, system,
+            messages,
+            system,
+            tools,
             do_sample=request.do_sample,
             temperature=request.temperature,
             top_p=request.top_p,
             max_new_tokens=request.max_tokens,
-            num_return_sequences=request.n
+            num_return_sequences=request.n,
         )
 
         prompt_length, response_length = 0, 0
         choices = []
         for i, response in enumerate(responses):
-            choices.append(ChatCompletionResponseChoice(
-                index=i,
-                message=ChatMessage(role=Role.ASSISTANT, content=response.response_text),
-                finish_reason=Finish.STOP if response.finish_reason == "stop" else Finish.LENGTH
-            ))
+            if tools:
+                result = chat_model.template.format_tools.extract(response.response_text)
+            else:
+                result = response.response_text
+
+            if isinstance(result, tuple):
+                name, arguments = result
+                function = Function(name=name, arguments=arguments)
+                response_message = ChatCompletionMessage(
+                    role=Role.ASSISTANT, tool_calls=[FunctionCall(function=function)]
+                )
+                finish_reason = Finish.TOOL
+            else:
+                response_message = ChatCompletionMessage(role=Role.ASSISTANT, content=result)
+                finish_reason = Finish.STOP if response.finish_reason == "stop" else Finish.LENGTH
+
+            choices.append(
+                ChatCompletionResponseChoice(index=i, message=response_message, finish_reason=finish_reason)
+            )
             prompt_length = response.prompt_length
             response_length += response.response_length
 
         usage = ChatCompletionResponseUsage(
             prompt_tokens=prompt_length,
             completion_tokens=response_length,
-            total_tokens=prompt_length+response_length
+            total_tokens=prompt_length + response_length,
         )
 
         return ChatCompletionResponse(model=request.model, choices=choices, usage=usage)
 
-    async def predict(query: str, history: List[Tuple[str, str]], system: str, request: ChatCompletionRequest):
+    def stream_chat_completion(
+        messages: Sequence[Dict[str, str]], system: str, tools: str, request: ChatCompletionRequest
+    ):
         choice_data = ChatCompletionResponseStreamChoice(
-            index=0,
-            delta=DeltaMessage(role=Role.ASSISTANT),
-            finish_reason=None
+            index=0, delta=ChatCompletionMessage(role=Role.ASSISTANT, content=""), finish_reason=None
         )
         chunk = ChatCompletionStreamResponse(model=request.model, choices=[choice_data])
-        yield to_json(chunk)
+        yield jsonify(chunk)
 
         for new_text in chat_model.stream_chat(
-            query, history, system,
+            messages,
+            system,
+            tools,
             do_sample=request.do_sample,
             temperature=request.temperature,
             top_p=request.top_p,
-            max_new_tokens=request.max_tokens
+            max_new_tokens=request.max_tokens,
         ):
             if len(new_text) == 0:
                 continue
 
             choice_data = ChatCompletionResponseStreamChoice(
-                index=0,
-                delta=DeltaMessage(content=new_text),
-                finish_reason=None
+                index=0, delta=ChatCompletionMessage(content=new_text), finish_reason=None
             )
             chunk = ChatCompletionStreamResponse(model=request.model, choices=[choice_data])
-            yield to_json(chunk)
+            yield jsonify(chunk)
 
         choice_data = ChatCompletionResponseStreamChoice(
-            index=0,
-            delta=DeltaMessage(),
-            finish_reason=Finish.STOP
+            index=0, delta=ChatCompletionMessage(), finish_reason=Finish.STOP
        )
         chunk = ChatCompletionStreamResponse(model=request.model, choices=[choice_data])
-        yield to_json(chunk)
+        yield jsonify(chunk)
         yield "[DONE]"
 
     @app.post("/v1/score/evaluation", response_model=ScoreEvaluationResponse, status_code=status.HTTP_200_OK)
@@ -168,7 +218,12 @@ def create_app(chat_model: "ChatModel") -> "FastAPI":
 
         if len(request.messages) == 0:
             raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid request")
 
+        async with semaphore:
+            loop = asyncio.get_running_loop()
+            return await loop.run_in_executor(None, get_score, request)
+
+    def get_score(request: ScoreEvaluationRequest):
         scores = chat_model.get_scores(request.messages, max_length=request.max_length)
         return ScoreEvaluationResponse(model=request.model, scores=scores)
 
@@ -178,4 +233,4 @@ def create_app(chat_model: "ChatModel") -> "FastAPI":
 if __name__ == "__main__":
     chat_model = ChatModel()
     app = create_app(chat_model)
-    uvicorn.run(app, host="0.0.0.0", port=8000, workers=1)
+    uvicorn.run(app, host="0.0.0.0", port=int(os.environ.get("API_PORT", 8000)), workers=1)
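The rewritten endpoint now accepts arbitrary user/assistant/tool turns plus an optional `tools` list, and reports `tool_calls` when the model decides to call a function. A hedged client sketch exercising that path (the endpoint shape is taken from the diff above; the `get_weather` tool and its schema are made up for illustration):

```python
# Hedged sketch: requesting a function call from the rewritten /v1/chat/completions route.
import json
import requests

payload = {
    "model": "gpt-3.5-turbo",
    "messages": [{"role": "user", "content": "What is the weather in Paris?"}],
    "tools": [{
        "type": "function",
        "function": {
            "name": "get_weather",
            "description": "Look up the current weather for a city.",
            "parameters": {
                "type": "object",
                "properties": {"city": {"type": "string"}},
                "required": ["city"],
            },
        },
    }],
}
resp = requests.post("http://localhost:8000/v1/chat/completions", json=payload).json()
choice = resp["choices"][0]
if choice["finish_reason"] == "tool_calls":
    call = choice["message"]["tool_calls"][0]["function"]
    print(call["name"], json.loads(call["arguments"]))
else:
    print(choice["message"]["content"])
```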
@@ -1,30 +1,48 @@
 import time
-from enum import Enum
-from pydantic import BaseModel, Field
+from enum import Enum, unique
 from typing import List, Optional
 
+from pydantic import BaseModel, Field
+from typing_extensions import Literal
+
 
+@unique
 class Role(str, Enum):
     USER = "user"
     ASSISTANT = "assistant"
     SYSTEM = "system"
+    FUNCTION = "function"
+    TOOL = "tool"
 
 
+@unique
 class Finish(str, Enum):
     STOP = "stop"
     LENGTH = "length"
+    TOOL = "tool_calls"
 
 
 class ModelCard(BaseModel):
     id: str
-    object: Optional[str] = "model"
-    created: Optional[int] = Field(default_factory=lambda: int(time.time()))
-    owned_by: Optional[str] = "owner"
+    object: Literal["model"] = "model"
+    created: int = Field(default_factory=lambda: int(time.time()))
+    owned_by: Literal["owner"] = "owner"
 
 
 class ModelList(BaseModel):
-    object: Optional[str] = "list"
-    data: Optional[List[ModelCard]] = []
+    object: Literal["list"] = "list"
+    data: List[ModelCard] = []
+
+
+class Function(BaseModel):
+    name: str
+    arguments: str
+
+
+class FunctionCall(BaseModel):
+    id: Literal["call_default"] = "call_default"
+    type: Literal["function"] = "function"
+    function: Function
 
 
 class ChatMessage(BaseModel):
@@ -32,31 +50,33 @@ class ChatMessage(BaseModel):
     content: str
 
 
-class DeltaMessage(BaseModel):
+class ChatCompletionMessage(BaseModel):
     role: Optional[Role] = None
     content: Optional[str] = None
+    tool_calls: Optional[List[FunctionCall]] = None
 
 
 class ChatCompletionRequest(BaseModel):
     model: str
     messages: List[ChatMessage]
-    do_sample: Optional[bool] = True
+    tools: Optional[list] = []
+    do_sample: bool = True
     temperature: Optional[float] = None
     top_p: Optional[float] = None
-    n: Optional[int] = 1
+    n: int = 1
     max_tokens: Optional[int] = None
-    stream: Optional[bool] = False
+    stream: bool = False
 
 
 class ChatCompletionResponseChoice(BaseModel):
     index: int
-    message: ChatMessage
+    message: ChatCompletionMessage
     finish_reason: Finish
 
 
 class ChatCompletionResponseStreamChoice(BaseModel):
     index: int
-    delta: DeltaMessage
+    delta: ChatCompletionMessage
     finish_reason: Optional[Finish] = None
 
 
@@ -67,18 +87,18 @@ class ChatCompletionResponseUsage(BaseModel):
 
 
 class ChatCompletionResponse(BaseModel):
-    id: Optional[str] = "chatcmpl-default"
-    object: Optional[str] = "chat.completion"
-    created: Optional[int] = Field(default_factory=lambda: int(time.time()))
+    id: Literal["chatcmpl-default"] = "chatcmpl-default"
+    object: Literal["chat.completion"] = "chat.completion"
+    created: int = Field(default_factory=lambda: int(time.time()))
     model: str
     choices: List[ChatCompletionResponseChoice]
     usage: ChatCompletionResponseUsage
 
 
 class ChatCompletionStreamResponse(BaseModel):
-    id: Optional[str] = "chatcmpl-default"
-    object: Optional[str] = "chat.completion.chunk"
-    created: Optional[int] = Field(default_factory=lambda: int(time.time()))
+    id: Literal["chatcmpl-default"] = "chatcmpl-default"
+    object: Literal["chat.completion.chunk"] = "chat.completion.chunk"
+    created: int = Field(default_factory=lambda: int(time.time()))
     model: str
     choices: List[ChatCompletionResponseStreamChoice]
 
@@ -90,7 +110,7 @@ class ScoreEvaluationRequest(BaseModel):
 
 
 class ScoreEvaluationResponse(BaseModel):
-    id: Optional[str] = "scoreeval-default"
-    object: Optional[str] = "score.evaluation"
+    id: Literal["scoreeval-default"] = "scoreeval-default"
+    object: Literal["score.evaluation"] = "score.evaluation"
     model: str
     scores: List[float]
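Under the new protocol models, a function-call turn comes back with `finish_reason` set to `tool_calls` and the fixed `call_default` id. A rough sketch of the JSON shape a client should expect (field values other than the literals are illustrative):

```python
# Hedged sketch of a tool-call response under the new protocol models.
example_response = {
    "id": "chatcmpl-default",
    "object": "chat.completion",
    "model": "gpt-3.5-turbo",
    "choices": [{
        "index": 0,
        "message": {
            "role": "assistant",
            "tool_calls": [{
                "id": "call_default",  # fixed Literal in FunctionCall
                "type": "function",
                "function": {"name": "get_weather", "arguments": "{\"city\": \"Paris\"}"},
            }],
        },
        "finish_reason": "tool_calls",
    }],
    "usage": {"prompt_tokens": 32, "completion_tokens": 16, "total_tokens": 48},
}
```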
@@ -1 +1,4 @@
-from llmtuner.chat.chat_model import ChatModel
+from .chat_model import ChatModel
+
+
+__all__ = ["ChatModel"]
@@ -1,18 +1,18 @@
-import torch
-import tiktoken
 from dataclasses import dataclass
-from typing import Any, Dict, Generator, List, Literal, Optional, Tuple
 from threading import Thread
+from typing import Any, Dict, Generator, List, Literal, Optional, Sequence, Tuple
+
+import torch
 from transformers import GenerationConfig, TextIteratorStreamer
 
-from llmtuner.data.template import get_template_and_fix_tokenizer
-from llmtuner.extras.misc import get_logits_processor
-from llmtuner.model import dispatch_model, get_infer_args, load_model_and_tokenizer
+from ..data import get_template_and_fix_tokenizer
+from ..extras.misc import get_logits_processor
+from ..hparams import get_infer_args
+from ..model import dispatch_model, load_model_and_tokenizer
 
 
 @dataclass
 class Response:
-
     response_text: str
     response_length: int
     prompt_length: int
@@ -20,28 +20,26 @@ class Response:
 
 
 class ChatModel:
-
     def __init__(self, args: Optional[Dict[str, Any]] = None) -> None:
         model_args, data_args, finetuning_args, self.generating_args = get_infer_args(args)
-        self.can_generate = (finetuning_args.stage == "sft")
+        self.can_generate = finetuning_args.stage == "sft"
         self.model, self.tokenizer = load_model_and_tokenizer(
             model_args, finetuning_args, is_trainable=False, add_valuehead=(not self.can_generate)
         )
         self.tokenizer.padding_side = "left" if self.can_generate else "right"
         self.model = dispatch_model(self.model)
-        self.template = get_template_and_fix_tokenizer(data_args.template, self.tokenizer)
-        self.system_prompt = data_args.system_prompt
+        self.template = get_template_and_fix_tokenizer(self.tokenizer, data_args.template)
 
     def _process_args(
         self,
-        query: str,
-        history: Optional[List[Tuple[str, str]]] = None,
+        messages: Sequence[Dict[str, str]],
         system: Optional[str] = None,
-        **input_kwargs
+        tools: Optional[str] = None,
+        **input_kwargs,
     ) -> Tuple[Dict[str, Any], int]:
-        system = system or self.system_prompt
+        paired_messages = messages + [{"role": "assistant", "content": ""}]
         prompt, _ = self.template.encode_oneturn(
-            tokenizer=self.tokenizer, query=query, resp="", history=history, system=system
+            tokenizer=self.tokenizer, messages=paired_messages, system=system, tools=tools
         )
         prompt_length = len(prompt)
         input_ids = torch.tensor([prompt], device=self.model.device)
@@ -56,16 +54,18 @@ class ChatModel:
         max_new_tokens = input_kwargs.pop("max_new_tokens", None)
 
         generating_args = self.generating_args.to_dict()
-        generating_args.update(dict(
-            do_sample=do_sample if do_sample is not None else generating_args["do_sample"],
-            temperature=temperature or generating_args["temperature"],
-            top_p=top_p or generating_args["top_p"],
-            top_k=top_k or generating_args["top_k"],
-            num_return_sequences=num_return_sequences or 1,
-            repetition_penalty=repetition_penalty or generating_args["repetition_penalty"],
-            eos_token_id=[self.tokenizer.eos_token_id] + self.tokenizer.additional_special_tokens_ids,
-            pad_token_id=self.tokenizer.pad_token_id
-        ))
+        generating_args.update(
+            dict(
+                do_sample=do_sample if do_sample is not None else generating_args["do_sample"],
+                temperature=temperature or generating_args["temperature"],
+                top_p=top_p or generating_args["top_p"],
+                top_k=top_k or generating_args["top_k"],
+                num_return_sequences=num_return_sequences or 1,
+                repetition_penalty=repetition_penalty or generating_args["repetition_penalty"],
+                eos_token_id=[self.tokenizer.eos_token_id] + self.tokenizer.additional_special_tokens_ids,
+                pad_token_id=self.tokenizer.pad_token_id,
+            )
+        )
 
         if isinstance(num_return_sequences, int) and num_return_sequences > 1:
             generating_args["do_sample"] = True
@@ -81,7 +81,7 @@ class ChatModel:
         gen_kwargs = dict(
             inputs=input_ids,
             generation_config=GenerationConfig(**generating_args),
-            logits_processor=get_logits_processor()
+            logits_processor=get_logits_processor(),
         )
 
         return gen_kwargs, prompt_length
@@ -89,17 +89,15 @@ class ChatModel:
     @torch.inference_mode()
     def chat(
         self,
-        query: str,
-        history: Optional[List[Tuple[str, str]]] = None,
+        messages: Sequence[Dict[str, str]],
         system: Optional[str] = None,
-        **input_kwargs
+        tools: Optional[str] = None,
+        **input_kwargs,
     ) -> List[Response]:
-        r"""
-        Args: query, history, system, **input_kwargs
-
-        Returns: [(response_text, prompt_length, response_length)] * n (default n=1)
-        """
-        gen_kwargs, prompt_length = self._process_args(query, history, system, **input_kwargs)
+        if not self.can_generate:
+            raise ValueError("The current model does not support `chat`.")
+
+        gen_kwargs, prompt_length = self._process_args(messages, system, tools, **input_kwargs)
         generate_output = self.model.generate(**gen_kwargs)
         response_ids = generate_output[:, prompt_length:]
         response = self.tokenizer.batch_decode(
@@ -109,24 +107,29 @@ class ChatModel:
         for i in range(len(response)):
             eos_index = (response_ids[i] == self.tokenizer.eos_token_id).nonzero()
             response_length = (eos_index[0].item() + 1) if len(eos_index) else len(response_ids[i])
-            results.append(Response(
-                response_text=response[i],
-                response_length=response_length,
-                prompt_length=prompt_length,
-                finish_reason="stop" if len(eos_index) else "length"
-            ))
+            results.append(
+                Response(
+                    response_text=response[i],
+                    response_length=response_length,
+                    prompt_length=prompt_length,
+                    finish_reason="stop" if len(eos_index) else "length",
+                )
+            )
 
         return results
 
     @torch.inference_mode()
     def stream_chat(
         self,
-        query: str,
-        history: Optional[List[Tuple[str, str]]] = None,
+        messages: Sequence[Dict[str, str]],
         system: Optional[str] = None,
-        **input_kwargs
+        tools: Optional[str] = None,
+        **input_kwargs,
    ) -> Generator[str, None, None]:
-        gen_kwargs, _ = self._process_args(query, history, system, **input_kwargs)
+        if not self.can_generate:
+            raise ValueError("The current model does not support `stream_chat`.")
+
+        gen_kwargs, _ = self._process_args(messages, system, tools, **input_kwargs)
         streamer = TextIteratorStreamer(self.tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True)
         gen_kwargs["streamer"] = streamer
@@ -136,27 +139,19 @@ class ChatModel:
             yield from streamer
 
     @torch.inference_mode()
-    def get_scores(
-        self,
-        batch_input: List[str],
-        **input_kwargs
-    ) -> List[float]:
-        if isinstance(getattr(self.tokenizer, "tokenizer", None), tiktoken.Encoding): # for tiktoken tokenizer (Qwen)
-            kwargs = dict(allowed_special="all")
-        else:
-            kwargs = dict(add_special_tokens=True)
-
+    def get_scores(self, batch_input: List[str], **input_kwargs) -> List[float]:
+        if self.can_generate:
+            raise ValueError("Cannot get scores using an auto-regressive model.")
+
         max_length = input_kwargs.pop("max_length", None)
         device = getattr(self.model.pretrained_model, "device", "cuda")
 
         inputs = self.tokenizer(
             batch_input,
             padding=True,
             truncation=True,
             max_length=max_length or getattr(self.model.config, "max_position_embeddings", 1024),
-            pad_to_multiple_of=8,
             return_tensors="pt",
-            **kwargs
+            add_special_tokens=True,
         ).to(device)
 
         input_ids: torch.Tensor = inputs["input_ids"]
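For reference, the new messages-based interface can be exercised directly from Python. A hedged sketch (the argument dict keys mirror the project's CLI flags and the checkpoint choice is illustrative, not prescribed by the diff):

```python
# Hedged sketch of the new ChatModel interface introduced above.
from llmtuner import ChatModel

chat_model = ChatModel(dict(
    model_name_or_path="meta-llama/Llama-2-7b-hf",  # hypothetical checkpoint
    template="default",
))

messages = [{"role": "user", "content": "Hello!"}]
responses = chat_model.chat(messages)              # returns List[Response]
print(responses[0].response_text)

for token in chat_model.stream_chat(messages):     # yields text chunks
    print(token, end="", flush=True)
```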
@@ -1,4 +1,6 @@
-from llmtuner.data.loader import get_dataset
-from llmtuner.data.preprocess import preprocess_dataset
-from llmtuner.data.template import get_template_and_fix_tokenizer
-from llmtuner.data.utils import split_dataset
+from .loader import get_dataset
+from .template import get_template_and_fix_tokenizer, templates
+from .utils import Role, split_dataset
+
+
+__all__ = ["get_dataset", "get_template_and_fix_tokenizer", "templates", "Role", "split_dataset"]
133
src/llmtuner/data/aligner.py
Normal file
@@ -0,0 +1,133 @@
from functools import partial
from typing import TYPE_CHECKING, Any, Dict, List, Union

from datasets import Features

from .utils import Role


if TYPE_CHECKING:
    from datasets import Dataset, IterableDataset

    from ..hparams import DataArguments
    from .parser import DatasetAttr


def convert_alpaca(examples: Dict[str, List[Any]], dataset_attr: "DatasetAttr") -> Dict[str, List[Any]]:
    outputs = {"prompt": [], "response": [], "system": [], "tools": []}
    for i in range(len(examples[dataset_attr.prompt])):
        prompt = []
        if dataset_attr.history and isinstance(examples[dataset_attr.history][i], list):
            for old_prompt, old_response in examples[dataset_attr.history][i]:
                prompt.append({"role": Role.USER.value, "content": old_prompt})
                prompt.append({"role": Role.ASSISTANT.value, "content": old_response})

        content = []
        if dataset_attr.prompt and examples[dataset_attr.prompt][i]:
            content.append(examples[dataset_attr.prompt][i])

        if dataset_attr.query and examples[dataset_attr.query][i]:
            content.append(examples[dataset_attr.query][i])

        prompt.append({"role": Role.USER.value, "content": "\n".join(content)})

        if dataset_attr.response and isinstance(examples[dataset_attr.response][i], list):
            response = [
                {"role": Role.ASSISTANT.value, "content": content} for content in examples[dataset_attr.response][i]
            ]
        elif dataset_attr.response and isinstance(examples[dataset_attr.response][i], str):
            response = [{"role": Role.ASSISTANT.value, "content": examples[dataset_attr.response][i]}]
        else:
            response = []

        outputs["prompt"].append(prompt)
        outputs["response"].append(response)
        outputs["system"].append(examples[dataset_attr.system][i] if dataset_attr.system else "")
        outputs["tools"].append("")

    return outputs


def convert_sharegpt(examples: Dict[str, List[Any]], dataset_attr: "DatasetAttr") -> Dict[str, List[Any]]:
    outputs = {"prompt": [], "response": [], "system": [], "tools": []}
    tag_mapping = {
        dataset_attr.user_tag: Role.USER.value,
        dataset_attr.assistant_tag: Role.ASSISTANT.value,
        dataset_attr.observation_tag: Role.OBSERVATION.value,
        dataset_attr.function_tag: Role.FUNCTION.value,
        dataset_attr.system_tag: Role.SYSTEM.value,
    }
    odd_tags = (dataset_attr.user_tag, dataset_attr.observation_tag)
    even_tags = (dataset_attr.assistant_tag, dataset_attr.function_tag)
    accept_tags = (odd_tags, even_tags)
    for i, messages in enumerate(examples[dataset_attr.messages]):
        if dataset_attr.system_tag and messages[0][dataset_attr.role_tag] == dataset_attr.system_tag:
            system = messages[0][dataset_attr.content_tag]
            messages = messages[1:]
        else:
            system = examples[dataset_attr.system][i] if dataset_attr.system else ""

        messages = messages[: len(messages) // 2 * 2]  # should be multiples of 2
        if len(messages) == 0:
            continue

        aligned_messages = []
        for turn_idx, message in enumerate(messages):
            if message[dataset_attr.role_tag] not in accept_tags[turn_idx % 2]:
                raise ValueError("Invalid role tag in {}.".format(messages))

            aligned_messages.append(
                {"role": tag_mapping[message[dataset_attr.role_tag]], "content": message[dataset_attr.content_tag]}
            )

        outputs["prompt"].append(aligned_messages[:-1])
        outputs["response"].append(aligned_messages[-1:])
        outputs["system"].append(system)
        outputs["tools"].append(examples[dataset_attr.tools][i] if dataset_attr.tools else "")

    return outputs


def align_dataset(
    dataset: Union["Dataset", "IterableDataset"], dataset_attr: "DatasetAttr", data_args: "DataArguments"
) -> Union["Dataset", "IterableDataset"]:
    r"""
    Aligned dataset:
        prompt: [{"role": "user", "content": "..."}] * (2T - 1)
        response: [{"role": "assistant", "content": "..."}] * N (N > 1 for ranking dataset)
        system: "..."
        tools: "..."
    """
    if dataset_attr.formatting == "alpaca":
        convert_func = partial(convert_alpaca, dataset_attr=dataset_attr)
    else:
        convert_func = partial(convert_sharegpt, dataset_attr=dataset_attr)

    column_names = list(next(iter(dataset)).keys())
    features = Features.from_dict(
        {
            "prompt": [
                {"role": {"dtype": "string", "_type": "Value"}, "content": {"dtype": "string", "_type": "Value"}}
            ],
            "response": [
                {"role": {"dtype": "string", "_type": "Value"}, "content": {"dtype": "string", "_type": "Value"}}
            ],
            "system": {"dtype": "string", "_type": "Value"},
            "tools": {"dtype": "string", "_type": "Value"},
        }
    )
    kwargs = {}
    if not data_args.streaming:
        kwargs = dict(
            num_proc=data_args.preprocessing_num_workers,
            load_from_cache_file=(not data_args.overwrite_cache),
            desc="Converting format of dataset",
        )

    return dataset.map(
        convert_func,
        batched=True,
        remove_columns=column_names,
        features=features,
        **kwargs,
    )
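For orientation, this is roughly what the aligner produces for a single alpaca-style record with one prior turn. The column names below are the usual alpaca attribute mapping and are an assumption here; the output shape follows the docstring in the file above:

```python
# Hedged sketch: input/output shape of convert_alpaca for one record.
raw = {
    "instruction": ["Translate to French"],   # dataset_attr.prompt column (assumed name)
    "input": ["Good morning"],                # dataset_attr.query column (assumed name)
    "output": ["Bonjour"],                    # dataset_attr.response column (assumed name)
    "history": [[["Hi", "Hello!"]]],          # optional prior turns
}
aligned = {
    "prompt": [[
        {"role": "user", "content": "Hi"},
        {"role": "assistant", "content": "Hello!"},
        {"role": "user", "content": "Translate to French\nGood morning"},
    ]],
    "response": [[{"role": "assistant", "content": "Bonjour"}]],
    "system": [""],
    "tools": [""],
}
```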
155
src/llmtuner/data/formatter.py
Normal file
@@ -0,0 +1,155 @@
import json
import re
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from typing import Any, Dict, List, Literal, Sequence, Set, Tuple, Union


SLOTS = Sequence[Union[str, Set[str], Dict[str, str]]]


JSON_FORMAT_PROMPT = (
    """, in a JSON format representing the kwargs (e.g. ```{"input": "hello world", "num_beams": 5}```)"""
)


TOOL_SYSTEM_PROMPT = (
    "You have access to the following tools:\n{tool_text}"
    "Use the following format if using a tool:\n"
    "```\n"
    "Action: tool name (one of [{tool_names}]).\n"
    "Action Input: the input to the tool{format_prompt}.\n"
    "```\n"
)


def default_tool_formatter(tools: List[Dict[str, Any]]) -> str:
    tool_text = ""
    tool_names = []
    for tool in tools:
        param_text = ""
        for name, param in tool["parameters"]["properties"].items():
            required = ", required" if name in tool["parameters"].get("required", []) else ""
            enum = ", should be one of [{}]".format(", ".join(param["enum"])) if param.get("enum", None) else ""
            items = (
                ", where each item should be {}".format(param["items"].get("type", "")) if param.get("items") else ""
            )
            param_text += " - {name} ({type}{required}): {desc}{enum}{items}\n".format(
                name=name,
                type=param.get("type", ""),
                required=required,
                desc=param.get("description", ""),
                enum=enum,
                items=items,
            )

        tool_text += "> Tool Name: {name}\nTool Description: {desc}\nTool Args:\n{args}\n".format(
            name=tool["name"], desc=tool.get("description", ""), args=param_text
        )
        tool_names.append(tool["name"])

    return TOOL_SYSTEM_PROMPT.format(
        tool_text=tool_text, tool_names=", ".join(tool_names), format_prompt=JSON_FORMAT_PROMPT
    )


def default_tool_extractor(content: str) -> Union[str, Tuple[str, str]]:
    regex = re.compile(r"Action:\s*([a-zA-Z0-9_]+).*?Action Input:\s*(.*)", re.DOTALL)
    action_match = re.search(regex, content)
    if not action_match:
        return content

    tool_name = action_match.group(1).strip()
    tool_input = action_match.group(2).strip().strip('"').strip("```")
    try:
        arguments = json.loads(tool_input)
    except json.JSONDecodeError:
        return content

    return tool_name, json.dumps(arguments, ensure_ascii=False)


@dataclass
class Formatter(ABC):
    slots: SLOTS = field(default_factory=list)
    tool_format: Literal["default"] = "default"

    @abstractmethod
    def apply(self, **kwargs) -> SLOTS:
        ...

    def extract(self, content: str) -> Union[str, Tuple[str, str]]:
        raise NotImplementedError


@dataclass
class EmptyFormatter(Formatter):
    def apply(self, **kwargs) -> SLOTS:
        return self.slots


@dataclass
class StringFormatter(Formatter):
    def apply(self, **kwargs) -> SLOTS:
        elements = []
        for slot in self.slots:
            if isinstance(slot, str):
                for name, value in kwargs.items():
                    if not isinstance(value, str):
                        raise RuntimeError("Expected a string, got {}".format(value))

                    slot = slot.replace("{{" + name + "}}", value, 1)
                elements.append(slot)
            elif isinstance(slot, (dict, set)):
                elements.append(slot)
            else:
                raise RuntimeError("Input must be string, set[str] or dict[str, str], got {}".format(type(slot)))

        return elements


@dataclass
class FunctionFormatter(Formatter):
    def apply(self, **kwargs) -> SLOTS:
        content = kwargs.pop("content")
        try:
            function = json.loads(content)
            name = function["name"]
            arguments = json.dumps(function["arguments"], ensure_ascii=False)
        except Exception:
            name, arguments = "", ""

        elements = []
        for slot in self.slots:
            if isinstance(slot, str):
                slot = slot.replace("{{name}}", name).replace("{{arguments}}", arguments)
                elements.append(slot)
            elif isinstance(slot, (dict, set)):
                elements.append(slot)
            else:
                raise RuntimeError("Input must be string, set[str] or dict[str, str], got {}".format(type(slot)))

        return elements


@dataclass
class ToolFormatter(Formatter):
    def apply(self, **kwargs) -> SLOTS:
        content = kwargs.pop("content")
        try:
            tools = json.loads(content)
            if not len(tools):
                return [""]

            if self.tool_format == "default":
                return [default_tool_formatter(tools)]
            else:
                raise NotImplementedError
        except Exception:
            return [""]

    def extract(self, content: str) -> Union[str, Tuple[str, str]]:
        if self.tool_format == "default":
            return default_tool_extractor(content)
        else:
            raise NotImplementedError
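A minimal sketch of the new formatter classes in use, built only from the definitions above (the example tool schema is made up):

```python
# Hedged sketch: exercising StringFormatter and ToolFormatter from the new module.
formatter = StringFormatter(slots=["{{content}}\n"])
print(formatter.apply(content="Hello"))  # -> ["Hello\n"]

tool_formatter = ToolFormatter(tool_format="default")
tools = (
    '[{"name": "get_weather", "description": "Look up the weather.", '
    '"parameters": {"type": "object", "properties": {"city": {"type": "string"}}, "required": ["city"]}}]'
)
print(tool_formatter.apply(content=tools)[0])  # renders TOOL_SYSTEM_PROMPT with the tool description

result = tool_formatter.extract('Action: get_weather\nAction Input: {"city": "Paris"}')
print(result)  # -> ('get_weather', '{"city": "Paris"}')
```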
@@ -1,135 +1,122 @@
+import inspect
 import os
-from typing import TYPE_CHECKING, Any, Dict, List, Union
+from typing import TYPE_CHECKING, List, Literal, Union
 
-from datasets import concatenate_datasets, interleave_datasets, load_dataset
+from datasets import concatenate_datasets, interleave_datasets, load_dataset, load_from_disk
 
-from llmtuner.data.utils import checksum, EXT2TYPE
-from llmtuner.extras.logging import get_logger
+from ..extras.constants import FILEEXT2TYPE
+from ..extras.logging import get_logger
+from .aligner import align_dataset
+from .parser import get_dataset_list
+from .preprocess import get_preprocess_and_print_func
+from .template import get_template_and_fix_tokenizer
+from .utils import checksum
 
 
 if TYPE_CHECKING:
     from datasets import Dataset, IterableDataset
-    from llmtuner.hparams import ModelArguments, DataArguments
+    from transformers import Seq2SeqTrainingArguments
+    from transformers.tokenization_utils import PreTrainedTokenizer
+
+    from ..hparams import DataArguments, ModelArguments
+    from .parser import DatasetAttr
 
 
 logger = get_logger(__name__)
 
 
-def get_dataset(
+def load_single_dataset(
+    dataset_attr: "DatasetAttr",
     model_args: "ModelArguments",
-    data_args: "DataArguments"
-) -> Union["Dataset", "IterableDataset"]:
-    max_samples = data_args.max_samples
-    all_datasets: List[Union["Dataset", "IterableDataset"]] = []  # support multiple datasets
-
-    for dataset_attr in data_args.dataset_list:
-        logger.info("Loading dataset {}...".format(dataset_attr))
-
-        if dataset_attr.load_from == "hf_hub":
-            data_path = dataset_attr.dataset_name
-            data_name = dataset_attr.subset
-            data_files = None
-        elif dataset_attr.load_from == "script":
-            data_path = os.path.join(data_args.dataset_dir, dataset_attr.dataset_name)
-            data_name = dataset_attr.subset
-            data_files = None
-        elif dataset_attr.load_from == "file":
-            data_path, data_name = None, None
-            data_files: List[str] = []
-            if os.path.isdir(os.path.join(data_args.dataset_dir, dataset_attr.dataset_name)):  # is directory
-                for file_name in os.listdir(os.path.join(data_args.dataset_dir, dataset_attr.dataset_name)):
-                    data_files.append(os.path.join(data_args.dataset_dir, dataset_attr.dataset_name, file_name))
-                    if data_path is None:
-                        data_path = EXT2TYPE.get(file_name.split(".")[-1], None)
-                    else:
-                        assert data_path == EXT2TYPE.get(file_name.split(".")[-1], None), "file types are not identical."
-            elif os.path.isfile(os.path.join(data_args.dataset_dir, dataset_attr.dataset_name)):  # is file
-                data_files.append(os.path.join(data_args.dataset_dir, dataset_attr.dataset_name))
-                data_path = EXT2TYPE.get(dataset_attr.dataset_name.split(".")[-1], None)
-            else:
-                raise ValueError("File not found.")
-
-            assert data_path, "File extension must be txt, csv, json or jsonl."
-            checksum(data_files, dataset_attr.dataset_sha1)
-        else:
-            raise NotImplementedError
-
-        dataset = load_dataset(
-            path=data_path,
-            name=data_name,
-            data_files=data_files,
-            split=data_args.split,
-            cache_dir=model_args.cache_dir,
-            token=model_args.hf_hub_token,
-            streaming=(data_args.streaming and (dataset_attr.load_from != "file"))
-        )
-
-        if data_args.streaming and (dataset_attr.load_from == "file"):
-            dataset = dataset.to_iterable_dataset()  # TODO: add num shards parameter
-
-        if max_samples is not None:  # truncate dataset
-            dataset = dataset.select(range(min(len(dataset), max_samples)))
-
-        def convert_format(examples: Dict[str, List[Any]]) -> Dict[str, List[Any]]:
-            # convert dataset from sharegpt format to alpaca format
-            outputs = {"prompt": [], "query": [], "response": [], "history": []}
-            for msg_list in examples[dataset_attr.messages]:
-                msg_list = msg_list[:len(msg_list) // 2 * 2]  # should be multiples of 2
-                if len(msg_list) == 0:
-                    continue
-
-                msg_pairs = []
-                user_role, assistant_role = None, None
-                for idx in range(0, len(msg_list), 2):
-                    if user_role is None and assistant_role is None:
-                        user_role = msg_list[idx][dataset_attr.role]
-                        assistant_role = msg_list[idx + 1][dataset_attr.role]
-                    else:
-                        if (
-                            msg_list[idx][dataset_attr.role] != user_role
-                            or msg_list[idx+1][dataset_attr.role] != assistant_role
-                        ):
-                            raise ValueError("Only accepts conversation in u/a/u/a/u/a order.")
-                    msg_pairs.append((msg_list[idx][dataset_attr.content], msg_list[idx + 1][dataset_attr.content]))
-
-                if len(msg_pairs) != 0:
-                    outputs["prompt"].append(msg_pairs[-1][0])
-                    outputs["query"].append("")
-                    outputs["response"].append(msg_pairs[-1][1])
-                    outputs["history"].append(msg_pairs[:-1])
-
-            return outputs
-
-        if dataset_attr.formatting == "sharegpt":  # convert format
-            column_names = list(next(iter(dataset)).keys())
-            kwargs = {}
-            if not data_args.streaming:
-                kwargs = dict(
-                    num_proc=data_args.preprocessing_num_workers,
-                    load_from_cache_file=(not data_args.overwrite_cache),
-                    desc="Converting format of dataset"
-                )
-
-            dataset = dataset.map(
-                convert_format,
-                batched=True,
-                remove_columns=column_names,
-                **kwargs
-            )
-        else:
-            for column_name in ["prompt", "query", "response", "history"]:  # align dataset
-                if getattr(dataset_attr, column_name) and getattr(dataset_attr, column_name) != column_name:
-                    dataset = dataset.rename_column(getattr(dataset_attr, column_name), column_name)
-
-        if dataset_attr.system_prompt:  # add system prompt
-            system_prompt = dataset_attr.system_prompt
-            if data_args.streaming:
-                dataset = dataset.map(lambda _: {"system": system_prompt})
-            else:
-                dataset = dataset.add_column("system", [system_prompt] * len(dataset))
-
-        all_datasets.append(dataset)
-
-    if len(data_args.dataset_list) == 1:
+    data_args: "DataArguments",
+):
+    logger.info("Loading dataset {}...".format(dataset_attr))
+    data_path, data_name, data_dir, data_files = None, None, None, None
+    if dataset_attr.load_from in ["hf_hub", "ms_hub"]:
+        data_path = dataset_attr.dataset_name
+        data_name = dataset_attr.subset
+        data_dir = dataset_attr.folder
+
+    elif dataset_attr.load_from == "script":
+        data_path = os.path.join(data_args.dataset_dir, dataset_attr.dataset_name)
+        data_name = dataset_attr.subset
+        data_dir = dataset_attr.folder
+
+    elif dataset_attr.load_from == "file":
+        data_files = []
+        local_path: str = os.path.join(data_args.dataset_dir, dataset_attr.dataset_name)
+        if os.path.isdir(local_path):  # is directory
+            for file_name in os.listdir(local_path):
+                data_files.append(os.path.join(local_path, file_name))
+                if data_path is None:
+                    data_path = FILEEXT2TYPE.get(file_name.split(".")[-1], None)
+                elif data_path != FILEEXT2TYPE.get(file_name.split(".")[-1], None):
+                    raise ValueError("File types should be identical.")
+        elif os.path.isfile(local_path):  # is file
+            data_files.append(local_path)
+            data_path = FILEEXT2TYPE.get(local_path.split(".")[-1], None)
+        else:
+            raise ValueError("File not found.")
+
+        if data_path is None:
+            raise ValueError("File extension must be txt, csv, json or jsonl.")
+
+        checksum(data_files, dataset_attr.file_sha1)
+    else:
+        raise NotImplementedError
+
+    if dataset_attr.load_from == "ms_hub":
+        try:
+            from modelscope import MsDataset
+            from modelscope.utils.config_ds import MS_DATASETS_CACHE
+
+            cache_dir = model_args.cache_dir or MS_DATASETS_CACHE
+            dataset = MsDataset.load(
+                dataset_name=data_path,
+                subset_name=data_name,
+                data_dir=data_dir,
+                data_files=data_files,
+                split=data_args.split,
+                cache_dir=cache_dir,
+                token=model_args.ms_hub_token,
+                use_streaming=(data_args.streaming and (dataset_attr.load_from != "file")),
+            ).to_hf_dataset()
+        except ImportError:
+            raise ImportError("Please install modelscope via `pip install modelscope -U`")
+    else:
+        if "trust_remote_code" in inspect.signature(load_dataset).parameters:  # for datasets==2.16.0
+            kwargs = {"trust_remote_code": True}
+        else:
+            kwargs = {}
+
+        dataset = load_dataset(
+            path=data_path,
+            name=data_name,
+            data_dir=data_dir,
+            data_files=data_files,
+            split=data_args.split,
+            cache_dir=model_args.cache_dir,
+            token=model_args.hf_hub_token,
+            streaming=(data_args.streaming and (dataset_attr.load_from != "file")),
+            **kwargs,
+        )
+
+    if data_args.streaming and (dataset_attr.load_from == "file"):  # faster than specifying streaming=True
+        dataset = dataset.to_iterable_dataset()  # TODO: add num shards parameter
+
+    if data_args.max_samples is not None:  # truncate dataset
+        num_samples = min(data_args.max_samples, len(dataset))
+        dataset = dataset.select(range(num_samples))
+
+    return align_dataset(dataset, dataset_attr, data_args)
+
+
+def merge_dataset(
+    all_datasets: List[Union["Dataset", "IterableDataset"]],
+    data_args: "DataArguments",
+    training_args: "Seq2SeqTrainingArguments",
+) -> Union["Dataset", "IterableDataset"]:
+    if len(all_datasets) == 1:
         return all_datasets[0]
     elif data_args.mix_strategy == "concat":
         if data_args.streaming:
@@ -141,8 +128,64 @@ def get_dataset(
         return interleave_datasets(
             datasets=all_datasets,
             probabilities=data_args.interleave_probs,
|
||||||
seed=data_args.seed,
|
seed=training_args.seed,
|
||||||
stopping_strategy="first_exhausted" if data_args.mix_strategy.endswith("under") else "all_exhausted"
|
stopping_strategy="first_exhausted" if data_args.mix_strategy.endswith("under") else "all_exhausted",
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
raise ValueError("Unknown mixing strategy.")
|
raise ValueError("Unknown mixing strategy.")
|
||||||
|
|
||||||
|
|
||||||
|
def get_dataset(
|
||||||
|
tokenizer: "PreTrainedTokenizer",
|
||||||
|
model_args: "ModelArguments",
|
||||||
|
data_args: "DataArguments",
|
||||||
|
training_args: "Seq2SeqTrainingArguments",
|
||||||
|
stage: Literal["pt", "sft", "rm", "ppo"],
|
||||||
|
# split: Optional[str] = "train", # TODO: add split
|
||||||
|
) -> Union["Dataset", "IterableDataset"]:
|
||||||
|
template = get_template_and_fix_tokenizer(tokenizer, data_args.template)
|
||||||
|
if data_args.train_on_prompt and template.efficient_eos:
|
||||||
|
raise ValueError("Current template does not support `train_on_prompt`.")
|
||||||
|
|
||||||
|
# Load from cache
|
||||||
|
if data_args.cache_path is not None:
|
||||||
|
if os.path.exists(data_args.cache_path):
|
||||||
|
logger.warning("Loading dataset from disk will ignore other data arguments.")
|
||||||
|
dataset = load_from_disk(data_args.cache_path)
|
||||||
|
if data_args.streaming:
|
||||||
|
dataset = dataset.to_iterable_dataset()
|
||||||
|
return dataset
|
||||||
|
|
||||||
|
with training_args.main_process_first(desc="load dataset"):
|
||||||
|
all_datasets = []
|
||||||
|
for dataset_attr in get_dataset_list(data_args):
|
||||||
|
all_datasets.append(load_single_dataset(dataset_attr, model_args, data_args))
|
||||||
|
dataset = merge_dataset(all_datasets, data_args, training_args)
|
||||||
|
|
||||||
|
with training_args.main_process_first(desc="pre-process dataset"):
|
||||||
|
preprocess_func, print_function = get_preprocess_and_print_func(
|
||||||
|
tokenizer, template, data_args, training_args, stage
|
||||||
|
)
|
||||||
|
column_names = list(next(iter(dataset)).keys())
|
||||||
|
kwargs = {}
|
||||||
|
if not data_args.streaming:
|
||||||
|
kwargs = dict(
|
||||||
|
num_proc=data_args.preprocessing_num_workers,
|
||||||
|
load_from_cache_file=(not data_args.overwrite_cache),
|
||||||
|
desc="Running tokenizer on dataset",
|
||||||
|
)
|
||||||
|
|
||||||
|
dataset = dataset.map(preprocess_func, batched=True, remove_columns=column_names, **kwargs)
|
||||||
|
|
||||||
|
if data_args.cache_path is not None and not os.path.exists(data_args.cache_path):
|
||||||
|
if training_args.should_save:
|
||||||
|
dataset.save_to_disk(data_args.cache_path)
|
||||||
|
logger.info("Dataset cache saved at {}.".format(data_args.cache_path))
|
||||||
|
|
||||||
|
if training_args.should_log:
|
||||||
|
try:
|
||||||
|
print_function(next(iter(dataset)))
|
||||||
|
except StopIteration:
|
||||||
|
raise RuntimeError("Cannot find valid samples, check `data/README.md` for the data format.")
|
||||||
|
|
||||||
|
return dataset
|
||||||
|
|||||||
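As a quick illustration of the mixing behavior in the refactored merge_dataset above (not part of the diff itself): with an "interleave_under" strategy it delegates to datasets.interleave_datasets, sampling from each corpus according to interleave_probs and stopping once the first corpus runs out. A minimal standalone sketch; the toy dataset names and probabilities are invented:

# Illustrative sketch only; dataset contents and probabilities are made up.
from datasets import Dataset, interleave_datasets

corpus_a = Dataset.from_dict({"text": ["a1", "a2", "a3", "a4"]})
corpus_b = Dataset.from_dict({"text": ["b1", "b2"]})

mixed = interleave_datasets(
    datasets=[corpus_a, corpus_b],
    probabilities=[0.7, 0.3],             # plays the role of data_args.interleave_probs
    seed=42,                              # plays the role of training_args.seed in the new code
    stopping_strategy="first_exhausted",  # "interleave_under" -> stop at the shortest dataset
)
print(mixed["text"])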
src/llmtuner/data/parser.py (new file, 119 lines)
@@ -0,0 +1,119 @@
import json
import os
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional

from ..extras.constants import DATA_CONFIG
from ..extras.misc import use_modelscope


if TYPE_CHECKING:
    from ..hparams import DataArguments


@dataclass
class DatasetAttr:
    r"""
    Dataset attributes.
    """

    """ basic configs """
    load_from: Literal["hf_hub", "ms_hub", "script", "file"]
    dataset_name: Optional[str] = None
    """ extra configs """
    file_sha1: Optional[str] = None
    subset: Optional[str] = None
    folder: Optional[str] = None
    ranking: Optional[bool] = False
    formatting: Optional[Literal["alpaca", "sharegpt"]] = "alpaca"
    """ columns """
    system: Optional[str] = None
    """ columns for the alpaca format """
    prompt: Optional[str] = "instruction"
    query: Optional[str] = "input"
    response: Optional[str] = "output"
    history: Optional[str] = None
    """ columns for the sharegpt format """
    messages: Optional[str] = "conversations"
    tools: Optional[str] = None
    """ tags for the sharegpt format """
    role_tag: Optional[str] = "from"
    content_tag: Optional[str] = "value"
    user_tag: Optional[str] = "human"
    assistant_tag: Optional[str] = "gpt"
    observation_tag: Optional[str] = "observation"
    function_tag: Optional[str] = "function_call"
    system_tag: Optional[str] = "system"

    def __repr__(self) -> str:
        return self.dataset_name

    def set_attr(self, key: str, obj: Dict[str, Any], default: Optional[Any] = None) -> None:
        setattr(self, key, obj.get(key, default))


def get_dataset_list(data_args: "DataArguments") -> List["DatasetAttr"]:
    dataset_names = [ds.strip() for ds in data_args.dataset.split(",")] if data_args.dataset is not None else []
    try:
        with open(os.path.join(data_args.dataset_dir, DATA_CONFIG), "r") as f:
            dataset_info = json.load(f)
    except Exception as err:
        if data_args.dataset is not None:
            raise ValueError(
                "Cannot open {} due to {}.".format(os.path.join(data_args.dataset_dir, DATA_CONFIG), str(err))
            )
        dataset_info = None

    if data_args.interleave_probs is not None:
        data_args.interleave_probs = [float(prob.strip()) for prob in data_args.interleave_probs.split(",")]

    dataset_list: List[DatasetAttr] = []
    for name in dataset_names:
        if name not in dataset_info:
            raise ValueError("Undefined dataset {} in {}.".format(name, DATA_CONFIG))

        has_hf_url = "hf_hub_url" in dataset_info[name]
        has_ms_url = "ms_hub_url" in dataset_info[name]

        if has_hf_url or has_ms_url:
            if (use_modelscope() and has_ms_url) or (not has_hf_url):
                dataset_attr = DatasetAttr("ms_hub", dataset_name=dataset_info[name]["ms_hub_url"])
            else:
                dataset_attr = DatasetAttr("hf_hub", dataset_name=dataset_info[name]["hf_hub_url"])
        elif "script_url" in dataset_info[name]:
            dataset_attr = DatasetAttr("script", dataset_name=dataset_info[name]["script_url"])
        else:
            dataset_attr = DatasetAttr("file", dataset_name=dataset_info[name]["file_name"])

        dataset_attr.set_attr("file_sha1", dataset_info[name])
        dataset_attr.set_attr("subset", dataset_info[name])
        dataset_attr.set_attr("folder", dataset_info[name])
        dataset_attr.set_attr("ranking", dataset_info[name], default=False)
        dataset_attr.set_attr("formatting", dataset_info[name], default="alpaca")

        if "columns" in dataset_info[name]:
            column_names = ["system"]
            if dataset_attr.formatting == "alpaca":
                column_names.extend(["prompt", "query", "response", "history"])
            else:
                column_names.extend(["messages", "tools"])

            for column_name in column_names:
                dataset_attr.set_attr(column_name, dataset_info[name]["columns"])

        if dataset_attr.formatting == "sharegpt" and "tags" in dataset_info[name]:
            tag_names = (
                "role_tag",
                "content_tag",
                "user_tag",
                "assistant_tag",
                "observation_tag",
                "function_tag",
                "system_tag",
            )
            for tag in tag_names:
                dataset_attr.set_attr(tag, dataset_info[name]["tags"])

        dataset_list.append(dataset_attr)

    return dataset_list
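To make the parsing logic above concrete, here is a hypothetical entry in the JSON file named by DATA_CONFIG and the DatasetAttr it would roughly produce; the dataset name, file name, and column values below are invented for illustration and are not taken from the repository:

# Hypothetical dataset_info entry (the file referred to by DATA_CONFIG):
# {
#   "my_sharegpt_data": {
#     "file_name": "my_sharegpt_data.json",
#     "formatting": "sharegpt",
#     "columns": {"messages": "conversations"},
#     "tags": {"role_tag": "from", "content_tag": "value", "user_tag": "human", "assistant_tag": "gpt"}
#   }
# }
#
# get_dataset_list() would then build roughly the following attribute object:
attr = DatasetAttr("file", dataset_name="my_sharegpt_data.json")
attr.set_attr("formatting", {"formatting": "sharegpt"}, default="alpaca")
attr.set_attr("messages", {"messages": "conversations"})
print(attr.formatting, attr.messages)  # -> sharegpt conversations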
@@ -1,275 +1,269 @@
-import os
-import tiktoken
+from functools import partial
 from itertools import chain
-from typing import TYPE_CHECKING, Any, Dict, Generator, List, Literal, Tuple, Union
+from typing import TYPE_CHECKING, Any, Callable, Dict, List, Literal, Tuple

-from datasets import load_from_disk
-
-from llmtuner.data.template import get_template_and_fix_tokenizer
-from llmtuner.extras.constants import IGNORE_INDEX
-from llmtuner.extras.logging import get_logger
+from ..extras.constants import IGNORE_INDEX
+from ..extras.logging import get_logger
+from .utils import Role


 if TYPE_CHECKING:
-    from datasets import Dataset, IterableDataset
     from transformers import Seq2SeqTrainingArguments
     from transformers.tokenization_utils import PreTrainedTokenizer
-    from llmtuner.hparams import DataArguments
+
+    from ..hparams import DataArguments
+    from .template import Template


 logger = get_logger(__name__)


-def construct_example(examples: Dict[str, List[Any]]) -> Generator[Any, None, None]:
-    for i in range(len(examples["prompt"])):
-        query, response = examples["prompt"][i], examples["response"][i]
-        query = query + "\n" + examples["query"][i] if "query" in examples and examples["query"][i] else query
-        history = examples["history"][i] if "history" in examples else None
-        system = examples["system"][i] if "system" in examples else None
-        yield query, response, history, system
-
-
-def infer_max_len(source_len: int, target_len: int, data_args: "DataArguments") -> Tuple[int, int]:
-    max_target_len = int(data_args.cutoff_len * (target_len / (source_len + target_len)))
-    max_target_len = max(max_target_len, data_args.reserved_label_len)
-    max_source_len = data_args.cutoff_len - max_target_len
-    return max_source_len, max_target_len
-
-
-def preprocess_dataset(
-    dataset: Union["Dataset", "IterableDataset"],
-    tokenizer: "PreTrainedTokenizer",
-    data_args: "DataArguments",
-    training_args: "Seq2SeqTrainingArguments",
-    stage: Literal["pt", "sft", "rm", "ppo"]
-) -> Union["Dataset", "IterableDataset"]:
-    template = get_template_and_fix_tokenizer(data_args.template, tokenizer)
-
-    if data_args.train_on_prompt and template.efficient_eos:
-        raise ValueError("Current template does not support `train_on_prompt`.")
-
-    def preprocess_pretrain_dataset(examples: Dict[str, List[Any]]) -> Dict[str, List[List[int]]]:
-        # build grouped texts with format `X1 X2 X3 ...`
-        if isinstance(getattr(tokenizer, "tokenizer", None), tiktoken.Encoding):  # for tiktoken tokenizer (Qwen)
-            kwargs = dict(allowed_special="all")
-        else:
-            kwargs = dict(add_special_tokens=True)
-
-        if hasattr(tokenizer, "add_eos_token"):  # for LLaMA tokenizer
-            add_eos_token_flag = getattr(tokenizer, "add_eos_token")
-            setattr(tokenizer, "add_eos_token", True)
-
-        tokenized_examples = tokenizer(examples["prompt"], **kwargs)
-        concatenated_examples = {k: list(chain(*tokenized_examples[k])) for k in tokenized_examples.keys()}
-        total_length = len(concatenated_examples[list(concatenated_examples.keys())[0]])
-        block_size = data_args.cutoff_len
-        # we drop the small remainder, and if the total_length < block_size, we exclude this batch
-        total_length = (total_length // block_size) * block_size
-        # split by chunks of cutoff_len
-        result = {
-            k: [t[i: i + block_size] for i in range(0, total_length, block_size)]
-            for k, t in concatenated_examples.items()
-        }
-        # make sure the saved tokenizer is the same as the original one
-        if hasattr(tokenizer, "add_eos_token"):
-            setattr(tokenizer, "add_eos_token", add_eos_token_flag)
-        return result
-
-    def preprocess_supervised_dataset(examples: Dict[str, List[Any]]) -> Dict[str, List[List[int]]]:
-        # build inputs with format `<bos> X Y <eos>` and labels with format `<ignore> ... <ignore> Y <eos>`
-        # for multiturn examples, we only mask the prompt part in each prompt-response pair.
-        model_inputs = {"input_ids": [], "attention_mask": [], "labels": []}
-
-        for query, response, history, system in construct_example(examples):
-            if not (isinstance(query, str) and isinstance(response, str) and query != "" and response != ""):
-                continue
-
-            input_ids, labels = [], []
-            for turn_idx, (source_ids, target_ids) in enumerate(template.encode_multiturn(
-                tokenizer, query, response, history, system
-            )):
-                source_len, target_len = len(source_ids), len(target_ids)
-                max_source_len, max_target_len = infer_max_len(source_len, target_len, data_args)
-                if source_len > max_source_len:
-                    source_ids = source_ids[:max_source_len]
-                if target_len > max_target_len:
-                    target_ids = target_ids[:max_target_len]
-
-                if data_args.train_on_prompt:
-                    source_mask = source_ids
-                elif turn_idx != 0 and template.efficient_eos:
-                    source_mask = [tokenizer.eos_token_id] + [IGNORE_INDEX] * (len(source_ids) - 1)
-                else:
-                    source_mask = [IGNORE_INDEX] * len(source_ids)
-
-                input_ids += source_ids + target_ids
-                labels += source_mask + target_ids
-
-            if template.efficient_eos:
-                input_ids += [tokenizer.eos_token_id]
-                labels += [tokenizer.eos_token_id]
-
-            if len(input_ids) > data_args.cutoff_len:
-                input_ids = input_ids[:data_args.cutoff_len]
-                labels = labels[:data_args.cutoff_len]
-
-            model_inputs["input_ids"].append(input_ids)
-            model_inputs["attention_mask"].append([1] * len(input_ids))
-            model_inputs["labels"].append(labels)
-
-        return model_inputs
-
-    def preprocess_packed_supervised_dataset(examples: Dict[str, List[Any]]) -> Dict[str, List[List[int]]]:
-        # build inputs with format `<bos> X1 Y1 <eos> <bos> X2 Y2 <eos>`
-        # and labels with format `<ignore> ... <ignore> Y1 <eos> <ignore> ... <ignore> Y2 <eos>`
-        model_inputs = {"input_ids": [], "attention_mask": [], "labels": []}
-        input_ids, labels = [], []
-        for query, response, history, system in construct_example(examples):
-            if not (isinstance(query, str) and isinstance(response, str) and query != "" and response != ""):
-                continue
-
-            for turn_idx, (source_ids, target_ids) in enumerate(template.encode_multiturn(
-                tokenizer, query, response, history, system
-            )):
-                if data_args.train_on_prompt:
-                    source_mask = source_ids
-                elif turn_idx != 0 and template.efficient_eos:
-                    source_mask = [tokenizer.eos_token_id] + [IGNORE_INDEX] * (len(source_ids) - 1)
-                else:
-                    source_mask = [IGNORE_INDEX] * len(source_ids)
-                input_ids += source_ids + target_ids
-                labels += source_mask + target_ids
-
-        if template.efficient_eos:
-            input_ids += [tokenizer.eos_token_id]
-            labels += [tokenizer.eos_token_id]
-
-        total_length = len(input_ids)
-        block_size = data_args.cutoff_len
-        # we drop the small remainder, and if the total_length < block_size, we exclude this batch
-        total_length = (total_length // block_size) * block_size
-        # split by chunks of cutoff_len
-        for i in range(0, total_length, block_size):
-            model_inputs["input_ids"].append(input_ids[i: i + block_size])
-            model_inputs["attention_mask"].append([1] * block_size)
-            model_inputs["labels"].append(labels[i: i + block_size])
-
-        return model_inputs
-
-    def preprocess_unsupervised_dataset(examples: Dict[str, List[Any]]) -> Dict[str, List[List[int]]]:
-        # build inputs with format `<bos> X` and labels with format `Y <eos>`
-        model_inputs = {"input_ids": [], "attention_mask": [], "labels": []}
-
-        for query, response, history, system in construct_example(examples):
-            if not (isinstance(query, str) and query != ""):
-                continue
-
-            input_ids, labels = template.encode_oneturn(tokenizer, query, response, history, system)
-
-            if template.efficient_eos:
-                labels += [tokenizer.eos_token_id]
-
-            if len(input_ids) > data_args.cutoff_len:
-                input_ids = input_ids[:data_args.cutoff_len]
-            if len(labels) > data_args.cutoff_len:
-                labels = labels[:data_args.cutoff_len]
-
-            model_inputs["input_ids"].append(input_ids)
-            model_inputs["attention_mask"].append([1] * len(input_ids))
-            model_inputs["labels"].append(labels)
-
-        return model_inputs
-
-    def preprocess_pairwise_dataset(examples: Dict[str, List[Any]]) -> Dict[str, List[List[int]]]:
-        # build input pairs with format `<bos> X`, `Y1 <eos>` and `Y2 <eos>`
-        model_inputs = {"prompt_ids": [], "chosen_ids": [], "rejected_ids": []}
-        for query, response, history, system in construct_example(examples):
-            if not (isinstance(query, str) and isinstance(response, list) and query != "" and len(response) > 1):
-                continue
-
-            prompt_ids, chosen_ids = template.encode_oneturn(tokenizer, query, response[0], history, system)
-            _, rejected_ids = template.encode_oneturn(tokenizer, query, response[1], history, system)
-
-            if template.efficient_eos:
-                chosen_ids += [tokenizer.eos_token_id]
-                rejected_ids += [tokenizer.eos_token_id]
-
-            source_len, target_len = len(prompt_ids), max(len(chosen_ids), len(rejected_ids))
-            max_source_len, max_target_len = infer_max_len(source_len, target_len, data_args)
-            if source_len > max_source_len:
-                prompt_ids = prompt_ids[:max_source_len]
-            if target_len > max_target_len:
-                chosen_ids = chosen_ids[:max_target_len]
-                rejected_ids = rejected_ids[:max_target_len]
-
-            model_inputs["prompt_ids"].append(prompt_ids)
-            model_inputs["chosen_ids"].append(chosen_ids)
-            model_inputs["rejected_ids"].append(rejected_ids)
-
-        return model_inputs
-
-    def print_supervised_dataset_example(example: Dict[str, List[int]]) -> None:
-        print("input_ids:\n{}".format(example["input_ids"]))
-        print("inputs:\n{}".format(tokenizer.decode(example["input_ids"], skip_special_tokens=False)))
-        print("label_ids:\n{}".format(example["labels"]))
-        print("labels:\n{}".format(
-            tokenizer.decode(list(filter(lambda x: x != IGNORE_INDEX, example["labels"])), skip_special_tokens=False)
-        ))
-
-    def print_pairwise_dataset_example(example: Dict[str, List[int]]) -> None:
-        print("prompt_ids:\n{}".format(example["prompt_ids"]))
-        print("prompt:\n{}".format(tokenizer.decode(example["prompt_ids"], skip_special_tokens=False)))
-        print("chosen_ids:\n{}".format(example["chosen_ids"]))
-        print("chosen:\n{}".format(tokenizer.decode(example["chosen_ids"], skip_special_tokens=False)))
-        print("rejected_ids:\n{}".format(example["rejected_ids"]))
-        print("rejected:\n{}".format(tokenizer.decode(example["rejected_ids"], skip_special_tokens=False)))
-
-    def print_unsupervised_dataset_example(example: Dict[str, List[int]]) -> None:
-        print("input_ids:\n{}".format(example["input_ids"]))
-        print("inputs:\n{}".format(tokenizer.decode(example["input_ids"], skip_special_tokens=False)))
-
-    if stage == "pt":
-        preprocess_func = preprocess_pretrain_dataset
-        print_function = print_unsupervised_dataset_example
-    elif stage == "sft" and not training_args.predict_with_generate:
-        preprocess_func = preprocess_packed_supervised_dataset if data_args.sft_packing else preprocess_supervised_dataset
-        print_function = print_supervised_dataset_example
-    elif stage == "rm":
-        preprocess_func = preprocess_pairwise_dataset
-        print_function = print_pairwise_dataset_example
-    else:
-        preprocess_func = preprocess_unsupervised_dataset
-        print_function = print_unsupervised_dataset_example
-
-    if data_args.cache_path is not None and os.path.exists(data_args.cache_path):
-        logger.warning("Loading dataset from disk will ignore other data arguments.")
-        return load_from_disk(data_args.cache_path)
-
-    with training_args.main_process_first(desc="dataset map pre-processing"):
-        column_names = list(next(iter(dataset)).keys())
-        kwargs = {}
-        if not data_args.streaming:
-            kwargs = dict(
-                num_proc=data_args.preprocessing_num_workers,
-                load_from_cache_file=(not data_args.overwrite_cache),
-                desc="Running tokenizer on dataset"
-            )
-
-        dataset = dataset.map(
-            preprocess_func,
-            batched=True,
-            remove_columns=column_names,
-            **kwargs
-        )
-
-        if data_args.cache_path is not None and not os.path.exists(data_args.cache_path):
-            if training_args.should_save:
-                dataset.save_to_disk(data_args.cache_path)
-            raise SystemExit("Dataset saved, rerun this script with the same `--cache_path`.")
-
-        if training_args.should_log:
-            try:
-                print_function(next(iter(dataset)))
-            except StopIteration:
-                raise RuntimeError("Empty dataset!")
-
-        return dataset
+def preprocess_pretrain_dataset(
+    examples: Dict[str, List[Any]], tokenizer: "PreTrainedTokenizer", data_args: "DataArguments"
+) -> Dict[str, List[List[int]]]:
+    # build grouped texts with format `X1 X2 X3 ...`
+    text_examples = [messages[0]["content"] + tokenizer.eos_token for messages in examples["prompt"]]
+    tokenized_examples = tokenizer(text_examples, add_special_tokens=False)
+    concatenated_examples = {k: list(chain(*tokenized_examples[k])) for k in tokenized_examples.keys()}
+    total_length = len(concatenated_examples[list(concatenated_examples.keys())[0]])
+    block_size = data_args.cutoff_len
+    # we drop the small remainder, and if the total_length < block_size, we exclude this batch
+    total_length = (total_length // block_size) * block_size
+    # split by chunks of cutoff_len
+    result = {
+        k: [t[i : i + block_size] for i in range(0, total_length, block_size)]
+        for k, t in concatenated_examples.items()
+    }
+    return result
+
+
+def preprocess_supervised_dataset(
+    examples: Dict[str, List[Any]],
+    tokenizer: "PreTrainedTokenizer",
+    template: "Template",
+    data_args: "DataArguments",
+) -> Dict[str, List[List[int]]]:
+    # build inputs with format `<bos> X Y <eos>` and labels with format `<ignore> ... <ignore> Y <eos>`
+    # for multiturn examples, we only mask the prompt part in each prompt-response pair.
+    model_inputs = {"input_ids": [], "attention_mask": [], "labels": []}
+
+    for i in range(len(examples["prompt"])):
+        if len(examples["prompt"][i]) % 2 != 1 or len(examples["response"][i]) != 1:
+            continue
+
+        messages = examples["prompt"][i] + examples["response"][i]
+        input_ids, labels = [], []
+        for turn_idx, (source_ids, target_ids) in enumerate(
+            template.encode_multiturn(
+                tokenizer,
+                messages,
+                examples["system"][i],
+                examples["tools"][i],
+                data_args.cutoff_len,
+                data_args.reserved_label_len,
+            )
+        ):
+            if data_args.train_on_prompt:
+                source_mask = source_ids
+            elif turn_idx != 0 and template.efficient_eos:
+                source_mask = [tokenizer.eos_token_id] + [IGNORE_INDEX] * (len(source_ids) - 1)
+            else:
+                source_mask = [IGNORE_INDEX] * len(source_ids)
+
+            input_ids += source_ids + target_ids
+            labels += source_mask + target_ids
+
+        if template.efficient_eos:
+            input_ids += [tokenizer.eos_token_id]
+            labels += [tokenizer.eos_token_id]
+
+        model_inputs["input_ids"].append(input_ids)
+        model_inputs["attention_mask"].append([1] * len(input_ids))
+        model_inputs["labels"].append(labels)
+
+    return model_inputs
+
+
+def preprocess_packed_supervised_dataset(
+    examples: Dict[str, List[Any]],
+    tokenizer: "PreTrainedTokenizer",
+    template: "Template",
+    data_args: "DataArguments",
+) -> Dict[str, List[List[int]]]:
+    # build inputs with format `<bos> X1 Y1 <eos> <bos> X2 Y2 <eos>`
+    # and labels with format `<ignore> ... <ignore> Y1 <eos> <ignore> ... <ignore> Y2 <eos>`
+    model_inputs = {"input_ids": [], "attention_mask": [], "labels": []}
+    input_ids, labels = [], []
+    for i in range(len(examples["prompt"])):
+        if len(examples["prompt"][i]) % 2 != 1 or len(examples["response"][i]) != 1:
+            continue
+
+        messages = examples["prompt"][i] + examples["response"][i]
+        for source_ids, target_ids in template.encode_multiturn(
+            tokenizer, messages, examples["system"][i], examples["tools"][i]
+        ):
+            if data_args.train_on_prompt:
+                source_mask = source_ids
+            elif len(input_ids) != 0 and template.efficient_eos:
+                source_mask = [tokenizer.eos_token_id] + [IGNORE_INDEX] * (len(source_ids) - 1)
+            else:
+                source_mask = [IGNORE_INDEX] * len(source_ids)
+
+            input_ids += source_ids + target_ids
+            labels += source_mask + target_ids
+
+    if template.efficient_eos:
+        input_ids += [tokenizer.eos_token_id]
+        labels += [tokenizer.eos_token_id]
+
+    total_length = len(input_ids)
+    block_size = data_args.cutoff_len
+    # we drop the small remainder, and if the total_length < block_size, we exclude this batch
+    total_length = (total_length // block_size) * block_size
+    # split by chunks of cutoff_len
+    for i in range(0, total_length, block_size):
+        if not all(label == IGNORE_INDEX for label in labels[i : i + block_size]):
+            model_inputs["input_ids"].append(input_ids[i : i + block_size])
+            model_inputs["attention_mask"].append([1] * block_size)
+            model_inputs["labels"].append(labels[i : i + block_size])
+
+    return model_inputs
+
+
+def preprocess_unsupervised_dataset(
+    examples: Dict[str, List[Any]],
+    tokenizer: "PreTrainedTokenizer",
+    template: "Template",
+    data_args: "DataArguments",
+) -> Dict[str, List[List[int]]]:
+    # build inputs with format `<bos> X` and labels with format `Y <eos>`
+    model_inputs = {"input_ids": [], "attention_mask": [], "labels": []}
+
+    for i in range(len(examples["prompt"])):
+        if len(examples["prompt"][i]) % 2 != 1:
+            continue
+
+        if len(examples["response"][i]) == 1:
+            messages = examples["prompt"][i] + examples["response"][i]
+        else:
+            messages = examples["prompt"][i] + [{"role": Role.ASSISTANT.value, "content": ""}]
+
+        input_ids, labels = template.encode_oneturn(
+            tokenizer,
+            messages,
+            examples["system"][i],
+            examples["tools"][i],
+            data_args.cutoff_len,
+            data_args.reserved_label_len,
+        )
+
+        if template.efficient_eos:
+            labels += [tokenizer.eos_token_id]
+
+        model_inputs["input_ids"].append(input_ids)
+        model_inputs["attention_mask"].append([1] * len(input_ids))
+        model_inputs["labels"].append(labels)
+
+    return model_inputs
+
+
+def preprocess_pairwise_dataset(
+    examples: Dict[str, List[Any]],
+    tokenizer: "PreTrainedTokenizer",
+    template: "Template",
+    data_args: "DataArguments",
+) -> Dict[str, List[List[int]]]:
+    # build input pairs with format `<bos> X`, `Y1 <eos>` and `Y2 <eos>`
+    model_inputs = {"prompt_ids": [], "chosen_ids": [], "rejected_ids": []}
+    for i in range(len(examples["prompt"])):
+        if len(examples["prompt"][i]) % 2 != 1 or len(examples["response"][i]) < 2:
+            continue
+
+        chosen_messages = examples["prompt"][i] + [examples["response"][i][0]]
+        rejected_messages = examples["prompt"][i] + [examples["response"][i][1]]
+        prompt_ids, chosen_ids = template.encode_oneturn(
+            tokenizer,
+            chosen_messages,
+            examples["system"][i],
+            examples["tools"][i],
+            data_args.cutoff_len,
+            data_args.reserved_label_len,
+        )
+        _, rejected_ids = template.encode_oneturn(
+            tokenizer,
+            rejected_messages,
+            examples["system"][i],
+            examples["tools"][i],
+            data_args.cutoff_len,
+            data_args.reserved_label_len,
+        )
+
+        if template.efficient_eos:
+            chosen_ids += [tokenizer.eos_token_id]
+            rejected_ids += [tokenizer.eos_token_id]
+
+        model_inputs["prompt_ids"].append(prompt_ids)
+        model_inputs["chosen_ids"].append(chosen_ids)
+        model_inputs["rejected_ids"].append(rejected_ids)
+
+    return model_inputs
+
+
+def print_supervised_dataset_example(example: Dict[str, List[int]], tokenizer: "PreTrainedTokenizer") -> None:
+    print("input_ids:\n{}".format(example["input_ids"]))
+    print("inputs:\n{}".format(tokenizer.decode(example["input_ids"], skip_special_tokens=False)))
+    print("label_ids:\n{}".format(example["labels"]))
+    print(
+        "labels:\n{}".format(
+            tokenizer.decode(list(filter(lambda x: x != IGNORE_INDEX, example["labels"])), skip_special_tokens=False)
+        )
+    )
+
+
+def print_pairwise_dataset_example(example: Dict[str, List[int]], tokenizer: "PreTrainedTokenizer") -> None:
+    print("prompt_ids:\n{}".format(example["prompt_ids"]))
+    print("prompt:\n{}".format(tokenizer.decode(example["prompt_ids"], skip_special_tokens=False)))
+    print("chosen_ids:\n{}".format(example["chosen_ids"]))
+    print("chosen:\n{}".format(tokenizer.decode(example["chosen_ids"], skip_special_tokens=False)))
+    print("rejected_ids:\n{}".format(example["rejected_ids"]))
+    print("rejected:\n{}".format(tokenizer.decode(example["rejected_ids"], skip_special_tokens=False)))
+
+
+def print_unsupervised_dataset_example(example: Dict[str, List[int]], tokenizer: "PreTrainedTokenizer") -> None:
+    print("input_ids:\n{}".format(example["input_ids"]))
+    print("inputs:\n{}".format(tokenizer.decode(example["input_ids"], skip_special_tokens=False)))
+
+
+def get_preprocess_and_print_func(
+    tokenizer: "PreTrainedTokenizer",
+    template: "Template",
+    data_args: "DataArguments",
+    training_args: "Seq2SeqTrainingArguments",
+    stage: Literal["pt", "sft", "rm", "ppo"],
+) -> Tuple[Callable, Callable]:
+    if stage == "pt":
+        preprocess_func = partial(preprocess_pretrain_dataset, tokenizer=tokenizer, data_args=data_args)
+        print_function = partial(print_unsupervised_dataset_example, tokenizer=tokenizer)
+    elif stage == "sft" and not training_args.predict_with_generate:
+        if data_args.sft_packing:
+            preprocess_func = partial(
+                preprocess_packed_supervised_dataset, tokenizer=tokenizer, template=template, data_args=data_args
+            )
+        else:
+            preprocess_func = partial(
+                preprocess_supervised_dataset, tokenizer=tokenizer, template=template, data_args=data_args
+            )
+
+        print_function = partial(print_supervised_dataset_example, tokenizer=tokenizer)
+    elif stage == "rm":
+        preprocess_func = partial(
+            preprocess_pairwise_dataset, tokenizer=tokenizer, template=template, data_args=data_args
+        )
+        print_function = partial(print_pairwise_dataset_example, tokenizer=tokenizer)
+    else:
+        preprocess_func = partial(
+            preprocess_unsupervised_dataset, tokenizer=tokenizer, template=template, data_args=data_args
+        )
+        print_function = partial(print_unsupervised_dataset_example, tokenizer=tokenizer)
+
+    return preprocess_func, print_function
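As an illustration of the label-masking scheme described in the comments above (prompt tokens replaced by the ignore index, response tokens kept as labels), here is a tiny self-contained sketch with invented token ids; it mirrors the `source_mask + target_ids` logic rather than calling the real template code:

# Toy illustration of the supervised masking pattern; token ids are invented.
IGNORE_INDEX = -100

def mask_turn(source_ids, target_ids, train_on_prompt=False):
    # prompt tokens contribute no loss unless train_on_prompt is set
    source_mask = source_ids if train_on_prompt else [IGNORE_INDEX] * len(source_ids)
    return source_ids + target_ids, source_mask + target_ids

input_ids, labels = [], []
for src, tgt in [([1, 2, 3], [4, 5]), ([6, 7], [8, 9, 10])]:  # two chat turns
    ids, lbl = mask_turn(src, tgt)
    input_ids += ids
    labels += lbl

print(input_ids)  # [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
print(labels)     # [-100, -100, -100, 4, 5, -100, -100, 8, 9, 10]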
(File diff suppressed because it is too large.)

@@ -1,25 +1,27 @@
 import hashlib
-from typing import TYPE_CHECKING, Dict, List, Optional, Union
+from enum import Enum, unique
+from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union

-from llmtuner.extras.logging import get_logger
+from ..extras.logging import get_logger


 if TYPE_CHECKING:
     from datasets import Dataset, IterableDataset
     from transformers import TrainingArguments

     from llmtuner.hparams import DataArguments


 logger = get_logger(__name__)


-EXT2TYPE = {
-    "arrow": "arrow",
-    "csv": "csv",
-    "json": "json",
-    "jsonl": "json",
-    "parquet": "parquet",
-    "txt": "text"
-}
+@unique
+class Role(str, Enum):
+    USER = "user"
+    ASSISTANT = "assistant"
+    SYSTEM = "system"
+    FUNCTION = "function"
+    OBSERVATION = "observation"


 def checksum(data_files: List[str], file_sha1: Optional[str] = None) -> None:
@@ -37,13 +39,18 @@ def checksum(data_files: List[str], file_sha1: Optional[str] = None) -> None:
         logger.warning("Checksum failed: mismatched SHA-1 hash value at {}.".format(data_files[0]))


+def infer_max_len(source_len: int, target_len: int, max_len: int, reserved_label_len: int) -> Tuple[int, int]:
+    max_target_len = int(max_len * (target_len / (source_len + target_len)))
+    max_target_len = max(max_target_len, reserved_label_len)
+    max_source_len = max_len - max_target_len
+    return max_source_len, max_target_len
+
+
 def split_dataset(
-    dataset: Union["Dataset", "IterableDataset"],
-    data_args: "DataArguments",
-    training_args: "TrainingArguments"
+    dataset: Union["Dataset", "IterableDataset"], data_args: "DataArguments", training_args: "TrainingArguments"
 ) -> Dict[str, "Dataset"]:
     if training_args.do_train:
         if data_args.val_size > 1e-6:  # Split the dataset
             if data_args.streaming:
                 val_set = dataset.take(int(data_args.val_size))
                 train_set = dataset.skip(int(data_args.val_size))
@@ -57,5 +64,5 @@ def split_dataset(
         if data_args.streaming:
             dataset = dataset.shuffle(buffer_size=data_args.buffer_size, seed=training_args.seed)
         return {"train_dataset": dataset}
     else:  # do_eval or do_predict
         return {"eval_dataset": dataset}
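A quick worked example of the new infer_max_len helper above (the input values are arbitrary): with a 300-token prompt, a 100-token response, max_len=1024 and reserved_label_len=1, the target budget is int(1024 * 100 / 400) = 256 tokens and the source budget is 1024 - 256 = 768.

# Self-contained restatement of the infer_max_len arithmetic; numbers are arbitrary.
max_len, reserved_label_len = 1024, 1
source_len, target_len = 300, 100
max_target_len = max(int(max_len * (target_len / (source_len + target_len))), reserved_label_len)
max_source_len = max_len - max_target_len
print(max_source_len, max_target_len)  # 768 256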
@@ -1 +1,4 @@
-from llmtuner.eval.evaluator import Evaluator
+from .evaluator import Evaluator
+
+
+__all__ = ["Evaluator"]
@@ -1,41 +1,34 @@
 # Inspired by: https://github.com/hendrycks/test/blob/master/evaluate_flan.py

-import os
-import json
-import torch
 import inspect
-import tiktoken
-import numpy as np
-from tqdm import tqdm, trange
+import json
+import os
 from typing import Any, Dict, List, Optional

+import numpy as np
+import torch
 from datasets import load_dataset
+from tqdm import tqdm, trange
 from transformers.utils import cached_file

-from llmtuner.data.template import get_template_and_fix_tokenizer
-from llmtuner.eval.template import get_eval_template
-from llmtuner.extras.constants import CHOICES, SUBJECTS
-from llmtuner.model import dispatch_model, get_eval_args, load_model_and_tokenizer
+from ..data import get_template_and_fix_tokenizer
+from ..extras.constants import CHOICES, SUBJECTS
+from ..hparams import get_eval_args
+from ..model import dispatch_model, load_model_and_tokenizer
+from .template import get_eval_template


 class Evaluator:
     def __init__(self, args: Optional[Dict[str, Any]] = None) -> None:
         self.model_args, self.data_args, self.eval_args, finetuning_args = get_eval_args(args)
         self.model, self.tokenizer = load_model_and_tokenizer(self.model_args, finetuning_args)
         self.tokenizer.padding_side = "right"  # avoid overflow issue in batched inference for llama2
         self.model = dispatch_model(self.model)
-        self.template = get_template_and_fix_tokenizer(self.data_args.template, self.tokenizer)
+        self.template = get_template_and_fix_tokenizer(self.tokenizer, self.data_args.template)
         self.eval_template = get_eval_template(self.eval_args.lang)
-        self.choice_inputs = self._encode_choices()
-
-    def _encode_choices(self) -> List[int]:
-        if isinstance(getattr(self.tokenizer, "tokenizer", None), tiktoken.Encoding):  # for tiktoken tokenizer (Qwen)
-            kwargs = dict(allowed_special="all")
-        else:
-            kwargs = dict(add_special_tokens=False)
-
-        return [self.tokenizer.encode(self.eval_template.prefix + ch, **kwargs)[-1] for ch in CHOICES]
+        self.choice_inputs = [
+            self.tokenizer.encode(self.eval_template.prefix + ch, add_special_tokens=False)[-1] for ch in CHOICES
+        ]

     @torch.inference_mode()
     def batch_inference(self, batch_input: Dict[str, torch.Tensor]) -> List[str]:
@@ -46,16 +39,11 @@ class Evaluator:
         return [chr(ord("A") + offset.item()) for offset in torch.argmax(choice_probs, dim=-1)]

     def eval(self) -> None:
-        if "token" in inspect.signature(cached_file).parameters:
-            kwargs = {"token": self.model_args.hf_hub_token}
-        elif "use_auth_token" in inspect.signature(cached_file).parameters:  # for transformers==4.31.0
-            kwargs = {"use_auth_token": self.model_args.hf_hub_token}
-
         mapping = cached_file(
-            path_or_repo_id = os.path.join(self.eval_args.task_dir, self.eval_args.task),
+            path_or_repo_id=os.path.join(self.eval_args.task_dir, self.eval_args.task),
             filename="mapping.json",
             cache_dir=self.model_args.cache_dir,
-            **kwargs
+            token=self.model_args.hf_hub_token,
         )

         with open(mapping, "r", encoding="utf-8") as f:
@@ -65,37 +53,45 @@ class Evaluator:
         pbar = tqdm(categorys.keys(), desc="Processing subjects", position=0)
         results = {}
         for subject in pbar:
+            if "trust_remote_code" in inspect.signature(load_dataset).parameters:  # for datasets==2.16.0
+                kwargs = {"trust_remote_code": True}
+            else:
+                kwargs = {}
+
             dataset = load_dataset(
                 path=os.path.join(self.eval_args.task_dir, self.eval_args.task),
                 name=subject,
                 cache_dir=self.model_args.cache_dir,
                 download_mode=self.eval_args.download_mode,
-                token=self.model_args.hf_hub_token
+                token=self.model_args.hf_hub_token,
+                **kwargs,
             )
             pbar.set_postfix_str(categorys[subject]["name"])
             inputs, outputs, labels = [], [], []
             for i in trange(len(dataset[self.data_args.split]), desc="Formatting batches", position=1, leave=False):
-                support_set = dataset["train"].shuffle().select(range(min(self.eval_args.n_shot, len(dataset["train"]))))
-                query, resp, history = self.eval_template.format_example(
+                support_set = (
+                    dataset["train"].shuffle().select(range(min(self.eval_args.n_shot, len(dataset["train"]))))
+                )
+                messages = self.eval_template.format_example(
                     target_data=dataset[self.data_args.split][i],
                     support_set=support_set,
                     subject_name=categorys[subject]["name"],
-                    use_history=self.template.use_history
                 )
-                input_ids, _ = self.template.encode_oneturn(
-                    tokenizer=self.tokenizer, query=query, resp=resp, history=history
-                )
-                inputs.append({"input_ids": input_ids, "attention_mask": [1] * len(input_ids)})
-                labels.append(resp)
-
-            for i in trange(0, len(inputs), self.eval_args.batch_size, desc="Predicting batches", position=1, leave=False):
+
+                input_ids, _ = self.template.encode_oneturn(tokenizer=self.tokenizer, messages=messages)
+                inputs.append({"input_ids": input_ids, "attention_mask": [1] * len(input_ids)})
+                labels.append(messages[-1]["content"])
+
+            for i in trange(
+                0, len(inputs), self.eval_args.batch_size, desc="Predicting batches", position=1, leave=False
+            ):
                 batch_input = self.tokenizer.pad(
                     inputs[i : i + self.eval_args.batch_size], return_attention_mask=True, return_tensors="pt"
                 ).to(self.model.device)
                 preds = self.batch_inference(batch_input)
                 outputs += preds

-            corrects = (np.array(outputs) == np.array(labels))
+            corrects = np.array(outputs) == np.array(labels)
             category_name = categorys[subject]["category"]
             category_corrects[category_name] = np.concatenate([category_corrects[category_name], corrects], axis=0)
             category_corrects["Average"] = np.concatenate([category_corrects["Average"], corrects], axis=0)
@@ -105,10 +101,13 @@ class Evaluator:
         self._save_results(category_corrects, results)

     def _save_results(self, category_corrects: Dict[str, np.ndarray], results: Dict[str, Dict[int, str]]) -> None:
-        score_info = "\n".join([
-            "{:>15}: {:.2f}".format(category_name, 100 * np.mean(category_correct))
-            for category_name, category_correct in category_corrects.items() if len(category_correct)
-        ])
+        score_info = "\n".join(
+            [
+                "{:>15}: {:.2f}".format(category_name, 100 * np.mean(category_correct))
+                for category_name, category_correct in category_corrects.items()
+                if len(category_correct)
+            ]
+        )
         print(score_info)
         if self.eval_args.save_dir is not None:
             os.makedirs(self.eval_args.save_dir, exist_ok=False)
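For context on how the evaluator turns per-choice scores into letter answers (grounded in the `chr(ord("A") + offset)` line in batch_inference above), a minimal sketch; the tensor values and the four-choice shape are invented:

# Toy sketch: map per-choice scores back to "A".."D" via argmax, as batch_inference does.
import torch

choice_probs = torch.tensor([[0.1, 0.7, 0.1, 0.1],   # example 1 -> "B"
                             [0.2, 0.1, 0.1, 0.6]])  # example 2 -> "D"
preds = [chr(ord("A") + offset.item()) for offset in torch.argmax(choice_probs, dim=-1)]
print(preds)  # ['B', 'D']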
@@ -1,7 +1,9 @@
 from dataclasses import dataclass
 from typing import TYPE_CHECKING, Dict, List, Tuple

-from llmtuner.extras.constants import CHOICES
+from ..data import Role
+from ..extras.constants import CHOICES


 if TYPE_CHECKING:
     from datasets import Dataset
@@ -9,60 +11,39 @@ if TYPE_CHECKING:

 @dataclass
 class EvalTemplate:
     system: str
     choice: str
     answer: str
     prefix: str

-    def parse_example(
-        self,
-        example: Dict[str, str]
-    ) -> Tuple[str, str]:
+    def parse_example(self, example: Dict[str, str]) -> Tuple[str, str]:
         candidates = [self.choice.format(choice=ch, content=example[ch]) for ch in CHOICES if ch in example]
         return "".join([example["question"]] + candidates + [self.answer]), example["answer"]

     def format_example(
-        self,
-        target_data: Dict[str, str],
-        support_set: "Dataset",
-        subject_name: str,
-        use_history: bool
-    ) -> Tuple[str, str, List[Tuple[str, str]]]:
-        query, resp = self.parse_example(target_data)
-        history = [self.parse_example(support_set[k]) for k in range(len(support_set))]
-
-        if len(history):
-            temp = history.pop(0)
-            history.insert(0, (self.system.format(subject=subject_name) + temp[0], temp[1]))
-        else:
-            query = self.system.format(subject=subject_name) + query
-
-        if not use_history:
-            query = "\n\n".join(["".join(item) for item in history] + [query])
-            history = []
-        return query.strip(), resp, history
+        self, target_data: Dict[str, str], support_set: "Dataset", subject_name: str
+    ) -> List[Dict[str, str]]:
+        messages = []
+        for k in range(len(support_set)):
+            prompt, response = self.parse_example(support_set[k])
+            messages.append({"role": Role.USER, "content": prompt})
+            messages.append({"role": Role.ASSISTANT, "content": response})
+
+        prompt, response = self.parse_example(target_data)
+        messages.append({"role": Role.USER, "content": prompt})
+        messages.append({"role": Role.ASSISTANT, "content": response})
+        messages[0]["content"] = self.system.format(subject=subject_name) + messages[0]["content"]
+        return messages


-eval_templates: Dict[str, EvalTemplate] = {}
+eval_templates: Dict[str, "EvalTemplate"] = {}


-def register_eval_template(
-    name: str,
-    system: str,
-    choice: str,
-    answer: str,
-    prefix: str
-) -> None:
-    eval_templates[name] = EvalTemplate(
-        system=system,
-        choice=choice,
-        answer=answer,
-        prefix=prefix
-    )
+def register_eval_template(name: str, system: str, choice: str, answer: str, prefix: str) -> None:
+    eval_templates[name] = EvalTemplate(system=system, choice=choice, answer=answer, prefix=prefix)


-def get_eval_template(name: str) -> EvalTemplate:
+def get_eval_template(name: str) -> "EvalTemplate":
     eval_template = eval_templates.get(name, None)
     assert eval_template is not None, "Template {} does not exist.".format(name)
     return eval_template
@@ -73,7 +54,7 @@ register_eval_template(
     system="The following are multiple choice questions (with answers) about {subject}.\n\n",
     choice="\n{choice}. {content}",
     answer="\nAnswer: ",
-    prefix=" "
+    prefix=" ",
 )


@@ -82,5 +63,5 @@ register_eval_template(
     system="The following are single-choice questions from Chinese exams about {subject}; choose the correct answer.\n\n",
     choice="\n{choice}. {content}",
     answer="\nAnswer:",
-    prefix="\n"
+    prefix="\n",
 )
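To visualize what parse_example builds with the English template registered above, here is a zero-shot sketch; the question record is invented, and only the choice/answer format strings are taken from the code:

# Invented multiple-choice record; prints the user-turn string parse_example would produce.
example = {
    "question": "Which planet is known as the Red Planet?",
    "A": "Venus", "B": "Mars", "C": "Jupiter", "D": "Mercury",
    "answer": "B",
}
choice_tmpl, answer_tmpl = "\n{choice}. {content}", "\nAnswer: "
candidates = [choice_tmpl.format(choice=ch, content=example[ch]) for ch in "ABCD" if ch in example]
prompt = "".join([example["question"]] + candidates + [answer_tmpl])
print(prompt)
# Which planet is known as the Red Planet?
# A. Venus
# B. Mars
# C. Jupiter
# D. Mercury
# Answer: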
@@ -1,56 +1,38 @@
-import os
 import json
+import os
 import time
-from typing import TYPE_CHECKING
 from datetime import timedelta
+from typing import TYPE_CHECKING

 from transformers import TrainerCallback
-from transformers.modeling_utils import custom_object_save, unwrap_model
-from transformers.trainer_utils import has_length, PREFIX_CHECKPOINT_DIR
+from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, has_length

-from llmtuner.extras.constants import LOG_FILE_NAME
-from llmtuner.extras.logging import get_logger
+from .constants import LOG_FILE_NAME
+from .logging import get_logger
+from .misc import fix_valuehead_checkpoint

 if TYPE_CHECKING:
-    from transformers import TrainingArguments, TrainerState, TrainerControl
-    from trl import AutoModelForCausalLMWithValueHead
+    from transformers import TrainerControl, TrainerState, TrainingArguments


 logger = get_logger(__name__)


-def _save_model_with_valuehead(model: "AutoModelForCausalLMWithValueHead", output_dir: str) -> None:
-    model.pretrained_model.config.save_pretrained(output_dir)
-    if model.pretrained_model.can_generate():
-        model.pretrained_model.generation_config.save_pretrained(output_dir)
-    if getattr(model, "is_peft_model", False):
-        model.pretrained_model.save_pretrained(output_dir)
-    elif getattr(model.pretrained_model, "_auto_class", None):  # must not a peft model
-        custom_object_save(model.pretrained_model, output_dir, config=model.pretrained_model.config)
-
-
-class SavePeftModelCallback(TrainerCallback):
+class FixValueHeadModelCallback(TrainerCallback):

     def on_save(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
         r"""
         Event called after a checkpoint save.
         """
         if args.should_save:
-            _save_model_with_valuehead(
-                model=unwrap_model(kwargs.pop("model")),
-                output_dir=os.path.join(args.output_dir, "{}-{}".format(PREFIX_CHECKPOINT_DIR, state.global_step))
+            fix_valuehead_checkpoint(
+                model=kwargs.pop("model"),
+                output_dir=os.path.join(args.output_dir, "{}-{}".format(PREFIX_CHECKPOINT_DIR, state.global_step)),
+                safe_serialization=args.save_safetensors,
             )

-    def on_train_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
-        r"""
-        Event called at the end of training.
-        """
-        if args.should_save:
-            _save_model_with_valuehead(model=unwrap_model(kwargs.pop("model")), output_dir=args.output_dir)
-

 class LogCallback(TrainerCallback):

     def __init__(self, runner=None):
         self.runner = runner
         self.in_training = False
@@ -116,7 +98,9 @@ class LogCallback(TrainerCallback):
         self.cur_steps = 0
         self.max_steps = 0

-    def on_predict(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", *other, **kwargs):
+    def on_predict(
+        self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", *other, **kwargs
+    ):
         r"""
         Event called after a successful prediction.
         """
@@ -142,18 +126,22 @@ class LogCallback(TrainerCallback):
             epoch=state.log_history[-1].get("epoch", None),
             percentage=round(self.cur_steps / self.max_steps * 100, 2) if self.max_steps != 0 else 100,
             elapsed_time=self.elapsed_time,
-            remaining_time=self.remaining_time
+            remaining_time=self.remaining_time,
         )
         if self.runner is not None:
-            logger.info("{{'loss': {:.4f}, 'learning_rate': {:2.4e}, 'epoch': {:.2f}}}".format(
-                logs["loss"] or 0, logs["learning_rate"] or 0, logs["epoch"] or 0
-            ))
+            logger.info(
+                "{{'loss': {:.4f}, 'learning_rate': {:2.4e}, 'epoch': {:.2f}}}".format(
+                    logs["loss"] or 0, logs["learning_rate"] or 0, logs["epoch"] or 0
+                )
+            )

         os.makedirs(args.output_dir, exist_ok=True)
         with open(os.path.join(args.output_dir, "trainer_log.jsonl"), "a", encoding="utf-8") as f:
             f.write(json.dumps(logs) + "\n")

-    def on_prediction_step(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
+    def on_prediction_step(
+        self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs
+    ):
         r"""
         Event called after a prediction step.
         """
File diff suppressed because it is too large

@@ -1,5 +1,5 @@
-import sys
 import logging
+import sys


 class LoggerHandler(logging.Handler):
@@ -27,8 +27,7 @@ def get_logger(name: str) -> logging.Logger:
     Gets a standard logger with a stream hander to stdout.
     """
     formatter = logging.Formatter(
-        fmt="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
-        datefmt="%m/%d/%Y %H:%M:%S"
+        fmt="%(asctime)s - %(levelname)s - %(name)s - %(message)s", datefmt="%m/%d/%Y %H:%M:%S"
     )
     handler = logging.StreamHandler(sys.stdout)
     handler.setFormatter(formatter)
@@ -1,35 +1,45 @@
 import gc
 import os
-import sys
-import torch
-from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple
-from transformers import InfNanRemoveLogitsProcessor, LogitsProcessorList
+from typing import TYPE_CHECKING, Dict, Tuple

+import torch
+from peft import PeftModel
+from transformers import InfNanRemoveLogitsProcessor, LogitsProcessorList, PreTrainedModel
+from transformers.utils import (
+    SAFE_WEIGHTS_NAME,
+    WEIGHTS_NAME,
+    is_torch_bf16_gpu_available,
+    is_torch_cuda_available,
+    is_torch_mps_available,
+    is_torch_npu_available,
+    is_torch_xpu_available,
+)

+from .constants import V_HEAD_SAFE_WEIGHTS_NAME, V_HEAD_WEIGHTS_NAME
+from .logging import get_logger


+_is_fp16_available = is_torch_npu_available() or is_torch_cuda_available()
 try:
-    from transformers.utils import (
-        is_torch_bf16_cpu_available,
-        is_torch_bf16_gpu_available,
-        is_torch_cuda_available,
-        is_torch_npu_available
-    )
-    _is_fp16_available = is_torch_npu_available() or is_torch_cuda_available()
-    _is_bf16_available = is_torch_bf16_gpu_available() or is_torch_bf16_cpu_available()
-except ImportError:
-    _is_fp16_available = torch.cuda.is_available()
-    try:
-        _is_bf16_available = torch.cuda.is_bf16_supported()
-    except:
-        _is_bf16_available = False
+    _is_bf16_available = is_torch_bf16_gpu_available()
+except Exception:
+    _is_bf16_available = False


 if TYPE_CHECKING:
-    from transformers import HfArgumentParser
+    from trl import AutoModelForCausalLMWithValueHead

     from llmtuner.hparams import ModelArguments


+logger = get_logger(__name__)


 class AverageMeter:
     r"""
     Computes and stores the average and current value.
     """

     def __init__(self):
         self.reset()
@@ -68,6 +78,76 @@ def count_parameters(model: torch.nn.Module) -> Tuple[int, int]:
     return trainable_params, all_param


+def fix_valuehead_checkpoint(
+    model: "AutoModelForCausalLMWithValueHead", output_dir: str, safe_serialization: bool
+) -> None:
+    r"""
+    The model is already unwrapped.
+
+    There are three cases:
+    1. full tuning without ds_zero3: state_dict = {"model.layers.*": ..., "v_head.summary.*": ...}
+    2. lora tuning without ds_zero3: state_dict = {"v_head.summary.*": ...}
+    3. under deepspeed zero3: state_dict = {"pretrained_model.model.layers.*": ..., "v_head.summary.*": ...}
+
+    We assume `stage3_gather_16bit_weights_on_model_save=true`.
+    """
+    if not isinstance(model.pretrained_model, (PreTrainedModel, PeftModel)):
+        return
+
+    if safe_serialization:
+        from safetensors import safe_open
+        from safetensors.torch import save_file
+
+        path_to_checkpoint = os.path.join(output_dir, SAFE_WEIGHTS_NAME)
+        with safe_open(path_to_checkpoint, framework="pt", device="cpu") as f:
+            state_dict: Dict[str, torch.Tensor] = {key: f.get_tensor(key) for key in f.keys()}
+    else:
+        path_to_checkpoint = os.path.join(output_dir, WEIGHTS_NAME)
+        state_dict: Dict[str, torch.Tensor] = torch.load(path_to_checkpoint, map_location="cpu")
+
+    decoder_state_dict = {}
+    v_head_state_dict = {}
+    for name, param in state_dict.items():
+        if name.startswith("v_head."):
+            v_head_state_dict[name] = param
+        else:
+            decoder_state_dict[name.replace("pretrained_model.", "")] = param
+
+    os.remove(path_to_checkpoint)
+    model.pretrained_model.save_pretrained(
+        output_dir, state_dict=decoder_state_dict or None, safe_serialization=safe_serialization
+    )
+
+    if safe_serialization:
+        save_file(v_head_state_dict, os.path.join(output_dir, V_HEAD_SAFE_WEIGHTS_NAME), metadata={"format": "pt"})
+    else:
+        torch.save(v_head_state_dict, os.path.join(output_dir, V_HEAD_WEIGHTS_NAME))
+
+    logger.info("Value head model saved at: {}".format(output_dir))
+
+
+def get_current_device() -> torch.device:
+    r"""
+    Gets the current available device.
+    """
+    if is_torch_xpu_available():
+        device = "xpu:{}".format(os.environ.get("LOCAL_RANK", "0"))
+    elif is_torch_npu_available():
+        device = "npu:{}".format(os.environ.get("LOCAL_RANK", "0"))
+    elif is_torch_mps_available():
+        device = "mps:{}".format(os.environ.get("LOCAL_RANK", "0"))
+    elif is_torch_cuda_available():
+        device = "cuda:{}".format(os.environ.get("LOCAL_RANK", "0"))
+    else:
+        device = "cpu"
+
+    return torch.device(device)
+
+
+def get_device_count() -> int:
+    return torch.cuda.device_count()
+
+
 def get_logits_processor() -> "LogitsProcessorList":
     r"""
     Gets logits processor that removes NaN and Inf logits.
@@ -89,17 +169,6 @@ def infer_optim_dtype(model_dtype: torch.dtype) -> torch.dtype:
     return torch.float32


-def parse_args(parser: "HfArgumentParser", args: Optional[Dict[str, Any]] = None) -> Tuple[Any]:
-    if args is not None:
-        return parser.parse_dict(args)
-    elif len(sys.argv) == 2 and sys.argv[1].endswith(".yaml"):
-        return parser.parse_yaml_file(os.path.abspath(sys.argv[1]))
-    elif len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
-        return parser.parse_json_file(os.path.abspath(sys.argv[1]))
-    else:
-        return parser.parse_args_into_dataclasses()
-
-
 def torch_gc() -> None:
     r"""
     Collects GPU memory.
@@ -115,12 +184,11 @@ def try_download_model_from_ms(model_args: "ModelArguments") -> None:
         return

     try:
-        from modelscope import snapshot_download  # type: ignore
+        from modelscope import snapshot_download

         revision = "master" if model_args.model_revision == "main" else model_args.model_revision
         model_args.model_name_or_path = snapshot_download(
-            model_args.model_name_or_path,
-            revision=revision,
-            cache_dir=model_args.cache_dir
+            model_args.model_name_or_path, revision=revision, cache_dir=model_args.cache_dir
         )
     except ImportError:
         raise ImportError("Please install modelscope via `pip install modelscope -U`")
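Note (not part of the diff): a small sketch of how the helpers added above might be combined at inference time, assuming they are importable from `llmtuner.extras.misc`; the generate call is left commented because the model and batch objects are hypothetical.

from llmtuner.extras.misc import get_current_device, get_logits_processor  # assumed import path

device = get_current_device()  # picks xpu/npu/mps/cuda based on availability, falling back to cpu
logits_processor = get_logits_processor()  # InfNanRemoveLogitsProcessor wrapped in a LogitsProcessorList
# outputs = model.generate(**batch.to(device), logits_processor=logits_processor)  # `model` and `batch` are hypothetical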
@@ -2,59 +2,52 @@ import importlib.metadata
 import importlib.util


-def is_package_available(name: str) -> bool:
+def _is_package_available(name: str) -> bool:
     return importlib.util.find_spec(name) is not None


-def get_package_version(name: str) -> str:
+def _get_package_version(name: str) -> str:
     try:
         return importlib.metadata.version(name)
-    except:
+    except Exception:
         return "0.0.0"


-_fastapi_available = is_package_available("fastapi")
-_flash_attn2_available = is_package_available("flash_attn") and get_package_version("flash_attn").startswith("2")
-_jieba_available = is_package_available("jieba")
-_matplotlib_available = is_package_available("matplotlib")
-_nltk_available = is_package_available("nltk")
-_requests_available = is_package_available("requests")
-_rouge_available = is_package_available("rouge_chinese")
-_starlette_available = is_package_available("sse_starlette")
-_uvicorn_available = is_package_available("uvicorn")
-
-
 def is_fastapi_availble():
-    return _fastapi_available
+    return _is_package_available("fastapi")


 def is_flash_attn2_available():
-    return _flash_attn2_available
+    return _is_package_available("flash_attn") and _get_package_version("flash_attn").startswith("2")


 def is_jieba_available():
-    return _jieba_available
+    return _is_package_available("jieba")


 def is_matplotlib_available():
-    return _matplotlib_available
+    return _is_package_available("matplotlib")


 def is_nltk_available():
-    return _nltk_available
+    return _is_package_available("nltk")


 def is_requests_available():
-    return _requests_available
+    return _is_package_available("requests")


 def is_rouge_available():
-    return _rouge_available
+    return _is_package_available("rouge_chinese")


 def is_starlette_available():
-    return _starlette_available
+    return _is_package_available("sse_starlette")


+def is_unsloth_available():
+    return _is_package_available("unsloth")
+
+
 def is_uvicorn_available():
-    return _uvicorn_available
+    return _is_package_available("uvicorn")
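Note (not part of the diff): after this change the module-level caches are gone, so each `is_*_available()` call probes importlib again; a minimal sketch of guarding an optional dependency the same way the plotting module below does.

from llmtuner.extras.packages import is_matplotlib_available  # assumed import path

if is_matplotlib_available():
    import matplotlib.pyplot as plt  # imported only when the optional package is installed
else:
    plt = None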
@@ -1,224 +1,197 @@
 import math
+from typing import Optional, Tuple

 import torch
 import torch.nn as nn
-from typing import Optional, Tuple
+from transformers.models.llama.modeling_llama import (
+    Cache,
+    LlamaAttention,
+    LlamaFlashAttention2,
+    apply_rotary_pos_emb,
+    repeat_kv,
+)
 from transformers.utils import logging
-from transformers.models.llama.modeling_llama import LlamaAttention, apply_rotary_pos_emb
-
-try:
-    from transformers.models.llama.modeling_llama import repeat_kv
-except ImportError:
-    print("Please upgrade `transformers`.")
-
-from llmtuner.extras.packages import is_flash_attn2_available
-
-
-if is_flash_attn2_available():
-    from flash_attn import flash_attn_func, flash_attn_varlen_func  # type: ignore
-    from flash_attn.bert_padding import pad_input, unpad_input  # type: ignore


 logger = logging.get_logger(__name__)


 # Modified from: https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py
-class LlamaShiftShortAttention(LlamaAttention):
+def llama_torch_attn_forward(
+    self: "LlamaAttention",
+    hidden_states: torch.Tensor,
+    attention_mask: Optional[torch.Tensor] = None,
+    position_ids: Optional[torch.LongTensor] = None,
+    past_key_value: Optional["Cache"] = None,
+    output_attentions: bool = False,
+    **kwargs,
+) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+    bsz, q_len, _ = hidden_states.size()

-    def forward(
-        self,
-        hidden_states: torch.Tensor,
-        attention_mask: Optional[torch.Tensor] = None,
-        position_ids: Optional[torch.LongTensor] = None,
-        past_key_value: Optional[Tuple[torch.Tensor]] = None,
-        output_attentions: bool = False,
-        use_cache: bool = False,
-        **kwargs
-    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
-        bsz, q_len, _ = hidden_states.size()
+    query_states = self.q_proj(hidden_states)
+    key_states = self.k_proj(hidden_states)
+    value_states = self.v_proj(hidden_states)

-        query_states = self.q_proj(hidden_states)
-        key_states = self.k_proj(hidden_states)
-        value_states = self.v_proj(hidden_states)
+    query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+    key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+    value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

-        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
-        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
-        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+    kv_seq_len = key_states.shape[-2]
+    if past_key_value is not None:
+        kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)

-        kv_seq_len = key_states.shape[-2]
-        if past_key_value is not None:
-            kv_seq_len += past_key_value[0].shape[-2]
+    cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
+    query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)

-        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
-        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
+    if past_key_value is not None:
+        cache_kwargs = {"sin": sin, "cos": cos}  # Specific to RoPE models
+        key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

-        if past_key_value is not None:  # reuse k, v, self_attention
-            key_states = torch.cat([past_key_value[0], key_states], dim=2)
-            value_states = torch.cat([past_key_value[1], value_states], dim=2)
+    key_states = repeat_kv(key_states, self.num_key_value_groups)
+    value_states = repeat_kv(value_states, self.num_key_value_groups)

-        past_key_value = (key_states, value_states) if use_cache else None
+    if getattr(self.config, "group_size_ratio", None) and self.training:  # shift
+        groupsz = int(q_len * getattr(self.config, "group_size_ratio"))
+        assert q_len % groupsz == 0, "q_len {} should be divisible by group size {}.".format(q_len, groupsz)
+        num_groups = q_len // groupsz

-        if getattr(self, "num_key_value_groups"):
-            key_states = repeat_kv(key_states, self.num_key_value_groups)
-            value_states = repeat_kv(value_states, self.num_key_value_groups)
-
-        if getattr(self.config, "group_size_ratio", None) and self.training:  # shift
-            groupsz = int(q_len * getattr(self.config, "group_size_ratio"))
-            assert q_len % groupsz == 0, "q_len {} should be divisible by group size {}.".format(q_len, groupsz)
-            num_groups = q_len // groupsz
-            def shift(state: torch.Tensor) -> torch.Tensor:
-                state = state.transpose(1, 2)  # output: (bsz, seq_len, n_heads, head_dim)
-                state = torch.cat((
-                    state[:, :, :self.num_heads//2], state[:, :, self.num_heads//2:].roll(-groupsz//2, dims=1)
-                ), dim=2)
-                return state.reshape(bsz * num_groups, groupsz, self.num_heads, self.head_dim).transpose(1, 2)
+        def shift(state: torch.Tensor) -> torch.Tensor:
+            state = state.transpose(1, 2)  # output: (bsz, seq_len, n_heads, head_dim)
+            state = torch.cat(
+                (state[:, :, : self.num_heads // 2], state[:, :, self.num_heads // 2 :].roll(-groupsz // 2, dims=1)),
+                dim=2,

-            query_states, key_states, value_states = shift(query_states), shift(key_states), shift(value_states)
-            if attention_mask is not None:
-                attention_mask = attention_mask[:, :, :groupsz, :groupsz].repeat(num_groups, 1, 1, 1)
-
-        attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
-
-        if attention_mask is not None:
-            attn_weights = attn_weights + attention_mask
-
-        # upcast attention to fp32
-        attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
-        attn_output = torch.matmul(attn_weights, value_states)  # (bsz, :, seq_len, :) or (bsz*n_group, :, groupsz, :)
-        attn_output = attn_output.transpose(1, 2).contiguous()
-
-        if getattr(self.config, "group_size_ratio", None) and self.training:  # shift back
-            attn_output.reshape(bsz, q_len, self.num_heads, self.head_dim)
-            attn_output = torch.cat((
-                attn_output[:, :, :self.num_heads//2], attn_output[:, :, self.num_heads//2:].roll(groupsz//2, dims=1)
-            ))
-
-        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
-        attn_output = self.o_proj(attn_output)
-
-        if not output_attentions:
-            attn_weights = None
-
-        return attn_output, attn_weights, past_key_value
-
-
-class LlamaFlashAttention2(LlamaAttention):
-
-    def forward(
-        self,
-        hidden_states: torch.Tensor,
-        attention_mask: Optional[torch.Tensor] = None,
-        position_ids: Optional[torch.LongTensor] = None,
-        past_key_value: Optional[Tuple[torch.Tensor]] = None,
-        output_attentions: bool = False,
-        use_cache: bool = False,
-        **kwargs
-    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
-        # LlamaFlashAttention2 attention does not support output_attentions
-        output_attentions = False
-
-        bsz, q_len, _ = hidden_states.size()
-
-        query_states = self.q_proj(hidden_states)
-        key_states = self.k_proj(hidden_states)
-        value_states = self.v_proj(hidden_states)
-
-        # FlashAttention requires the input to have the shape (bsz, seq_len, n_heads, head_dim)
-        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
-        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
-        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
-
-        kv_seq_len = key_states.shape[-2]
-        if past_key_value is not None:
-            kv_seq_len += past_key_value[0].shape[-2]
-
-        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
-        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
-
-        if past_key_value is not None:  # reuse k, v, self_attention
-            key_states = torch.cat([past_key_value[0], key_states], dim=2)
-            value_states = torch.cat([past_key_value[1], value_states], dim=2)
-
-        past_key_value = (key_states, value_states) if use_cache else None
-
-        # cast to half precision
-        input_dtype = query_states.dtype
-        if input_dtype == torch.float32:
-            logger.warning_once("The input hidden states seems to be silently casted in float32.")
-            query_states = query_states.to(self.config.torch_dtype)
-            key_states = key_states.to(self.config.torch_dtype)
-            value_states = value_states.to(self.config.torch_dtype)
-
-        if getattr(self, "num_key_value_groups", None):
-            key_states = repeat_kv(key_states, self.num_key_value_groups)
-            value_states = repeat_kv(value_states, self.num_key_value_groups)
-
-        query_states = query_states.transpose(1, 2)  # (bsz, seq_len, n_heads, head_dim)
-        key_states = key_states.transpose(1, 2)  # (bsz, seq_len, n_heads, head_dim)
-        value_states = value_states.transpose(1, 2)  # (bsz, seq_len, n_heads, head_dim)
-
-        if getattr(self.config, "group_size_ratio", None) and self.training:  # shift
-            groupsz = int(q_len * getattr(self.config, "group_size_ratio"))
-            assert q_len % groupsz == 0, "q_len {} should be divisible by group size {}.".format(q_len, groupsz)
-            num_groups = q_len // groupsz
-            def shift(state: torch.Tensor) -> torch.Tensor:
-                state = torch.cat((
-                    state[:, :, :self.num_heads//2], state[:, :, self.num_heads//2:].roll(-groupsz//2, dims=1)
-                ), dim=2)
-                return state.reshape(bsz * num_groups, groupsz, self.num_heads, self.head_dim)
-
-            query_states, key_states, value_states = shift(query_states), shift(key_states), shift(value_states)
-            if attention_mask is not None:
-                attention_mask = attention_mask.reshape(bsz * num_groups, groupsz)
-
-        if attention_mask is not None:
-            logger.warning_once("Padded sequences are less efficient in FlashAttention.")
-            # -q_len: assumes left padding when q_len != kv_len
-            unpadded_q, indices_q, cu_seqlens_q, max_seqlen_q = unpad_input(query_states, attention_mask[:, -q_len:])
-            unpadded_k, _, cu_seqlens_k, max_seqlen_k = unpad_input(key_states, attention_mask)
-            unpadded_v, _, _, _ = unpad_input(value_states, attention_mask)
-            attn_output_unpad = flash_attn_varlen_func(
-                unpadded_q,
-                unpadded_k,
-                unpadded_v,
-                cu_seqlens_q=cu_seqlens_q,
-                cu_seqlens_k=cu_seqlens_k,
-                max_seqlen_q=max_seqlen_q,
-                max_seqlen_k=max_seqlen_k,
-                dropout_p=0.0,
-                softmax_scale=None,
-                causal=True,
             )
-            attn_output = pad_input(attn_output_unpad, indices_q, bsz, q_len)
+            return state.reshape(bsz * num_groups, groupsz, self.num_heads, self.head_dim).transpose(1, 2)

+        query_states, key_states, value_states = shift(query_states), shift(key_states), shift(value_states)
+        if attention_mask is not None:
+            attention_mask = attention_mask[:, :, :groupsz, :groupsz].repeat(num_groups, 1, 1, 1)

+    attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)

+    if attention_mask is not None:
+        attn_weights = attn_weights + attention_mask

+    # upcast attention to fp32
+    attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
+    attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
+    attn_output = torch.matmul(attn_weights, value_states)  # (bsz, :, seq_len, :) or (bsz*n_group, :, groupsz, :)
+    attn_output = attn_output.transpose(1, 2).contiguous()

+    if getattr(self.config, "group_size_ratio", None) and self.training:  # shift back
+        attn_output.reshape(bsz, q_len, self.num_heads, self.head_dim)
+        attn_output = torch.cat(
+            (
+                attn_output[:, :, : self.num_heads // 2],
+                attn_output[:, :, self.num_heads // 2 :].roll(groupsz // 2, dims=1),
+            )
+        )

+    attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
+    attn_output = self.o_proj(attn_output)

+    if not output_attentions:
+        attn_weights = None

+    return attn_output, attn_weights, past_key_value


+# Modified from: https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py
+def llama_flash_attn_forward(
+    self: "LlamaFlashAttention2",
+    hidden_states: torch.Tensor,
+    attention_mask: Optional[torch.Tensor] = None,
+    position_ids: Optional[torch.LongTensor] = None,
+    past_key_value: Optional[Tuple[torch.Tensor]] = None,
+    output_attentions: bool = False,
+    **kwargs,
+) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+    # LlamaFlashAttention2 attention does not support output_attentions
+    output_attentions = False

+    bsz, q_len, _ = hidden_states.size()

+    query_states = self.q_proj(hidden_states)
+    key_states = self.k_proj(hidden_states)
+    value_states = self.v_proj(hidden_states)

+    # FlashAttention requires the input to have the shape (bsz, seq_len, n_heads, head_dim)
+    query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+    key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+    value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

+    kv_seq_len = key_states.shape[-2]
+    if past_key_value is not None:
+        kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)

+    cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
+    query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)

+    if past_key_value is not None:
+        cache_kwargs = {"sin": sin, "cos": cos}  # Specific to RoPE models
+        key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

+    key_states = repeat_kv(key_states, self.num_key_value_groups)
+    value_states = repeat_kv(value_states, self.num_key_value_groups)

+    query_states = query_states.transpose(1, 2)  # (bsz, seq_len, n_heads, head_dim)
+    key_states = key_states.transpose(1, 2)  # (bsz, seq_len, n_heads, head_dim)
+    value_states = value_states.transpose(1, 2)  # (bsz, seq_len, n_heads, head_dim)

+    dropout_rate = self.attention_dropout if self.training else 0.0

+    input_dtype = query_states.dtype
+    if input_dtype == torch.float32:
+        if torch.is_autocast_enabled():
+            target_dtype = torch.get_autocast_gpu_dtype()
+        elif hasattr(self.config, "_pre_quantization_dtype"):
+            target_dtype = self.config._pre_quantization_dtype
         else:
-            attn_output = flash_attn_func(
-                query_states, key_states, value_states, 0.0, softmax_scale=None, causal=True
+            target_dtype = self.q_proj.weight.dtype

+        logger.warning_once("The input hidden states seems to be silently casted in float32.")
+        query_states = query_states.to(target_dtype)
+        key_states = key_states.to(target_dtype)
+        value_states = value_states.to(target_dtype)

+    if getattr(self.config, "group_size_ratio", None) and self.training:  # shift
+        groupsz = int(q_len * getattr(self.config, "group_size_ratio"))
+        assert q_len % groupsz == 0, "q_len {} should be divisible by group size {}.".format(q_len, groupsz)
+        num_groups = q_len // groupsz

+        def shift(state: torch.Tensor) -> torch.Tensor:
+            state = torch.cat(
+                (state[:, :, : self.num_heads // 2], state[:, :, self.num_heads // 2 :].roll(-groupsz // 2, dims=1)),
+                dim=2,
             )
+            return state.reshape(bsz * num_groups, groupsz, self.num_heads, self.head_dim)

-        if getattr(self.config, "group_size_ratio", None) and self.training:  # shift back
-            attn_output.reshape(bsz, q_len, self.num_heads, self.head_dim)
-            attn_output = torch.cat((
-                attn_output[:, :, :self.num_heads//2], attn_output[:, :, self.num_heads//2:].roll(groupsz//2, dims=1)
-            ))
+        query_states, key_states, value_states = shift(query_states), shift(key_states), shift(value_states)
+        if attention_mask is not None:
+            attention_mask = attention_mask[:, :, :groupsz, :groupsz].repeat(num_groups, 1, 1, 1)

-        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
-        attn_output = self.o_proj(attn_output)
+    attn_output: torch.Tensor = self._flash_attention_forward(
+        query_states, key_states, value_states, attention_mask, q_len, dropout=dropout_rate
+    )

-        if not output_attentions:
-            attn_weights = None
+    if getattr(self.config, "group_size_ratio", None) and self.training:  # shift back
+        attn_output.reshape(bsz, q_len, self.num_heads, self.head_dim)
+        attn_output = torch.cat(
+            (
+                attn_output[:, :, : self.num_heads // 2],
+                attn_output[:, :, self.num_heads // 2 :].roll(groupsz // 2, dims=1),
+            )
+        )

-        return attn_output, attn_weights, past_key_value
+    attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
+    attn_output = self.o_proj(attn_output)

+    if not output_attentions:
+        attn_weights = None

+    return attn_output, attn_weights, past_key_value


-# Disable the transformation of the attention mask in LlamaModel as flash attention
-# takes a boolean padding_mask. Fills in the past kv length for use in forward.
-def _prepare_decoder_attention_mask(
-    self,
-    attention_mask: torch.Tensor,
-    input_shape: torch.Tensor,
-    inputs_embeds: torch.Tensor,
-    past_key_values_length: int
-) -> torch.Tensor:
-    if attention_mask is not None and torch.all(attention_mask):
-        return None  # This uses the faster call when training with full samples
-
-    return attention_mask
+def apply_llama_patch() -> None:
+    LlamaAttention.forward = llama_torch_attn_forward
+    LlamaFlashAttention2.forward = llama_flash_attn_forward
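Note (not part of the diff): a hedged sketch of applying the new monkey patch before building a model so that the replacement forwards take effect; the module path and model id are assumptions, and `group_size_ratio` is the config flag read by the patched code above.

from transformers import AutoConfig, AutoModelForCausalLM

from llmtuner.extras.patches.llama_patch import apply_llama_patch  # assumed module path

apply_llama_patch()  # swaps LlamaAttention.forward and LlamaFlashAttention2.forward in place

config = AutoConfig.from_pretrained("meta-llama/Llama-2-7b-hf")  # example model id
setattr(config, "group_size_ratio", 0.25)  # enables the shifted-group attention path during training
# model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf", config=config)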
src/llmtuner/extras/patches/mixtral_patch.py (new file, 38 lines)
@@ -0,0 +1,38 @@
+import torch
+import torch.nn.functional as F
+from transformers.models.mixtral.modeling_mixtral import MixtralBLockSparseTop2MLP, MixtralSparseMoeBlock
+
+
+def mlp_forward(self: "MixtralBLockSparseTop2MLP", hidden_states: torch.Tensor) -> torch.Tensor:
+    current_hidden_states = self.act_fn(self.w1(hidden_states)) * self.w3(hidden_states)
+    current_hidden_states = self.w2(current_hidden_states)
+    return current_hidden_states
+
+
+# Modified from: https://huggingface.co/deepseek-ai/deepseek-moe-16b-base/blob/main/modeling_deepseek.py
+def moe_forward(self: "MixtralSparseMoeBlock", hidden_states: torch.Tensor) -> torch.Tensor:
+    batch_size, sequence_length, hidden_dim = hidden_states.shape
+    hidden_states = hidden_states.view(-1, hidden_dim)
+    # router_logits: (batch * sequence_length, n_experts)
+    router_logits = self.gate(hidden_states)
+
+    routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
+    topk_weight, topk_idx = torch.topk(routing_weights, self.top_k, dim=-1, sorted=False)
+    topk_weight /= topk_weight.sum(dim=-1, keepdim=True)
+    # we cast back to the input dtype
+    topk_weight = topk_weight.to(hidden_states.dtype)
+
+    hidden_states = hidden_states.repeat_interleave(self.top_k, dim=0)
+    y = torch.empty_like(hidden_states)
+    flat_topk_idx = topk_idx.view(-1)
+    for i in range(self.num_experts):
+        expert = self.experts[i]
+        y[flat_topk_idx == i] = expert(hidden_states[flat_topk_idx == i])
+    y = (y.view(*topk_weight.shape, -1) * topk_weight.unsqueeze(-1)).sum(dim=1)
+    final_hidden_states = y.reshape(batch_size, sequence_length, hidden_dim)
+    return final_hidden_states, router_logits
+
+
+def patch_mixtral_replace_moe_impl() -> None:
+    MixtralBLockSparseTop2MLP.forward = mlp_forward
+    MixtralSparseMoeBlock.forward = moe_forward
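Note (not part of the diff): the new file swaps Mixtral's top-2 MoE dispatch for a simpler loop-over-experts implementation; a minimal sketch of applying it, with the model id shown only as an example.

from transformers import AutoModelForCausalLM

from llmtuner.extras.patches.mixtral_patch import patch_mixtral_replace_moe_impl

patch_mixtral_replace_moe_impl()  # replaces MixtralSparseMoeBlock.forward and MixtralBLockSparseTop2MLP.forward
# model = AutoModelForCausalLM.from_pretrained("mistralai/Mixtral-8x7B-v0.1")  # example model id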
@@ -1,11 +1,13 @@
-import os
-import math
 import json
+import math
+import os
 from typing import List, Optional

 from transformers.trainer import TRAINER_STATE_NAME

-from llmtuner.extras.logging import get_logger
-from llmtuner.extras.packages import is_matplotlib_available
+from .logging import get_logger
+from .packages import is_matplotlib_available


 if is_matplotlib_available():
     import matplotlib.pyplot as plt
@@ -20,7 +22,7 @@ def smooth(scalars: List[float]) -> List[float]:
     """
     last = scalars[0]
     smoothed = list()
     weight = 1.8 * (1 / (1 + math.exp(-0.05 * len(scalars))) - 0.5)  # a sigmoid function
     for next_val in scalars:
         smoothed_val = last * weight + (1 - weight) * next_val
         smoothed.append(smoothed_val)
@@ -29,7 +31,6 @@ def smooth(scalars: List[float]) -> List[float]:


 def plot_loss(save_dictionary: os.PathLike, keys: Optional[List[str]] = ["loss"]) -> None:
-
     with open(os.path.join(save_dictionary, TRAINER_STATE_NAME), "r", encoding="utf-8") as f:
         data = json.load(f)

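Note (not part of the diff): a brief usage sketch for `plot_loss` after the import reshuffle; the module path and output directory are assumptions.

from llmtuner.extras.ploting import plot_loss  # assumed module path

# Reads the trainer state file under the given directory and plots the requested keys.
plot_loss("outputs/sft-demo", keys=["loss", "eval_loss"])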
@@ -3,3 +3,16 @@ from .evaluation_args import EvaluationArguments
 from .finetuning_args import FinetuningArguments
 from .generating_args import GeneratingArguments
 from .model_args import ModelArguments
+from .parser import get_eval_args, get_infer_args, get_train_args
+
+
+__all__ = [
+    "DataArguments",
+    "EvaluationArguments",
+    "FinetuningArguments",
+    "GeneratingArguments",
+    "ModelArguments",
+    "get_eval_args",
+    "get_infer_args",
+    "get_train_args",
+]
@@ -1,33 +1,5 @@
-import os
-import json
-from typing import List, Literal, Optional
 from dataclasses import dataclass, field
+from typing import Literal, Optional

-DATA_CONFIG = "dataset_info.json"
-
-
-@dataclass
-class DatasetAttr:
-
-    load_from: str
-    dataset_name: Optional[str] = None
-    dataset_sha1: Optional[str] = None
-    system_prompt: Optional[str] = None
-    subset: Optional[str] = None
-    ranking: Optional[bool] = False
-    formatting: Optional[Literal["alpaca", "sharegpt"]] = "alpaca"
-
-    prompt: Optional[str] = "instruction"
-    query: Optional[str] = "input"
-    response: Optional[str] = "output"
-    history: Optional[str] = None
-    messages: Optional[str] = "conversations"
-    role: Optional[str] = "from"
-    content: Optional[str] = "value"
-
-    def __repr__(self) -> str:
-        return self.dataset_name
-
-
 @dataclass
@@ -35,85 +7,84 @@ class DataArguments:
     r"""
     Arguments pertaining to what data we are going to input our model for training and evaluation.
     """

     template: Optional[str] = field(
         default=None,
-        metadata={"help": "Which template to use for constructing prompts in training and inference."}
+        metadata={"help": "Which template to use for constructing prompts in training and inference."},
     )
     dataset: Optional[str] = field(
         default=None,
-        metadata={"help": "The name of provided dataset(s) to use. Use commas to separate multiple datasets."}
+        metadata={"help": "The name of provided dataset(s) to use. Use commas to separate multiple datasets."},
     )
     dataset_dir: Optional[str] = field(
         default="data",
-        metadata={"help": "Path to the folder containing the datasets."}
+        metadata={"help": "Path to the folder containing the datasets."},
     )
     split: Optional[str] = field(
         default="train",
-        metadata={"help": "Which dataset split to use for training and evaluation."}
+        metadata={"help": "Which dataset split to use for training and evaluation."},
     )
     cutoff_len: Optional[int] = field(
         default=1024,
-        metadata={"help": "The maximum length of the model inputs after tokenization."}
+        metadata={"help": "The cutoff length of the model inputs after tokenization."},
     )
     reserved_label_len: Optional[int] = field(
         default=1,
-        metadata={"help": "The maximum length reserved for label after tokenization."}
+        metadata={"help": "The minimum cutoff length reserved for label after tokenization."},
     )
     train_on_prompt: Optional[bool] = field(
         default=False,
-        metadata={"help": "Whether to disable the mask on the prompt or not."}
+        metadata={"help": "Whether to disable the mask on the prompt or not."},
     )
     streaming: Optional[bool] = field(
         default=False,
-        metadata={"help": "Enable dataset streaming."}
+        metadata={"help": "Enable dataset streaming."},
     )
     buffer_size: Optional[int] = field(
         default=16384,
-        metadata={"help": "Size of the buffer to randomly sample examples from in dataset streaming."}
+        metadata={"help": "Size of the buffer to randomly sample examples from in dataset streaming."},
     )
     mix_strategy: Optional[Literal["concat", "interleave_under", "interleave_over"]] = field(
         default="concat",
-        metadata={"help": "Strategy to use in dataset mixing (concat/interleave) (undersampling/oversampling)."}
+        metadata={"help": "Strategy to use in dataset mixing (concat/interleave) (undersampling/oversampling)."},
     )
     interleave_probs: Optional[str] = field(
         default=None,
-        metadata={"help": "Probabilities to sample data from datasets. Use commas to separate multiple datasets."}
+        metadata={"help": "Probabilities to sample data from datasets. Use commas to separate multiple datasets."},
     )
     overwrite_cache: Optional[bool] = field(
         default=False,
-        metadata={"help": "Overwrite the cached training and evaluation sets."}
+        metadata={"help": "Overwrite the cached training and evaluation sets."},
     )
     preprocessing_num_workers: Optional[int] = field(
         default=None,
-        metadata={"help": "The number of processes to use for the preprocessing."}
+        metadata={"help": "The number of processes to use for the preprocessing."},
     )
     max_samples: Optional[int] = field(
         default=None,
-        metadata={"help": "For debugging purposes, truncate the number of examples for each dataset."}
+        metadata={"help": "For debugging purposes, truncate the number of examples for each dataset."},
     )
     eval_num_beams: Optional[int] = field(
         default=None,
-        metadata={"help": "Number of beams to use for evaluation. This argument will be passed to `model.generate`"}
+        metadata={"help": "Number of beams to use for evaluation. This argument will be passed to `model.generate`"},
     )
     ignore_pad_token_for_loss: Optional[bool] = field(
         default=True,
-        metadata={"help": "Whether to ignore the tokens corresponding to padded labels in the loss computation or not."}
-    )
-    system_prompt: Optional[str] = field(
-        default=None,
-        metadata={"help": "System prompt to add before the user query. Use `|` to separate multiple prompts in training."}
+        metadata={
+            "help": "Whether or not to ignore the tokens corresponding to padded labels in the loss computation."
+        },
     )
     val_size: Optional[float] = field(
         default=0,
-        metadata={"help": "Size of the development set, should be an integer or a float in range `[0,1)`."}
+        metadata={"help": "Size of the development set, should be an integer or a float in range `[0,1)`."},
     )
     sft_packing: Optional[bool] = field(
         default=False,
-        metadata={"help": "Packing the questions and answers in the supervised fine-tuning stage."}
+        metadata={"help": "Packing the questions and answers in the supervised fine-tuning stage."},
     )
     cache_path: Optional[str] = field(
         default=None,
-        metadata={"help": "Path to save or load the preprocessed datasets."}
+        metadata={"help": "Path to save or load the preprocessed datasets."},
     )

     def __post_init__(self):
@@ -125,55 +96,3 @@ class DataArguments:

         if self.streaming and self.max_samples is not None:
             raise ValueError("`max_samples` is incompatible with `streaming`.")
-
-        if self.streaming and self.cache_path:
-            raise ValueError("`cache_path` is incompatible with `streaming`.")
-
-    def init_for_training(self, seed: int):  # support mixing multiple datasets
-        self.seed = seed
-        dataset_names = [ds.strip() for ds in self.dataset.split(",")] if self.dataset is not None else []
-        try:
-            with open(os.path.join(self.dataset_dir, DATA_CONFIG), "r") as f:
-                dataset_info = json.load(f)
-        except Exception as err:
-            if self.dataset is not None:
-                raise ValueError("Cannot open {} due to {}.".format(os.path.join(self.dataset_dir, DATA_CONFIG), str(err)))
-            dataset_info = None
-
-        prompt_list = self.system_prompt.split("|") if self.system_prompt else [None]
-        prompt_list = prompt_list * (len(dataset_names) // len(prompt_list))
-        assert len(prompt_list) == len(dataset_names), "Number of system prompts should be equal to datasets or 1."
-
-        if self.interleave_probs is not None:
-            self.interleave_probs = [float(prob.strip()) for prob in self.interleave_probs.split(",")]
-
-        self.dataset_list: List[DatasetAttr] = []
-        for i, name in enumerate(dataset_names):
-            if name not in dataset_info:
-                raise ValueError("Undefined dataset {} in {}.".format(name, DATA_CONFIG))
-
-            if "hf_hub_url" in dataset_info[name]:
-                dataset_attr = DatasetAttr("hf_hub", dataset_name=dataset_info[name]["hf_hub_url"])
-            elif "script_url" in dataset_info[name]:
-                dataset_attr = DatasetAttr("script", dataset_name=dataset_info[name]["script_url"])
-            else:
-                dataset_attr = DatasetAttr(
-                    "file",
-                    dataset_name=dataset_info[name]["file_name"],
-                    dataset_sha1=dataset_info[name].get("file_sha1", None)
-                )
-
-            if "columns" in dataset_info[name]:
-                dataset_attr.prompt = dataset_info[name]["columns"].get("prompt", None)
-                dataset_attr.query = dataset_info[name]["columns"].get("query", None)
-                dataset_attr.response = dataset_info[name]["columns"].get("response", None)
-                dataset_attr.history = dataset_info[name]["columns"].get("history", None)
-                dataset_attr.messages = dataset_info[name]["columns"].get("messages", None)
-                dataset_attr.role = dataset_info[name]["columns"].get("role", None)
-                dataset_attr.content = dataset_info[name]["columns"].get("content", None)
-
-            dataset_attr.subset = dataset_info[name].get("subset", None)
-            dataset_attr.ranking = dataset_info[name].get("ranking", False)
-            dataset_attr.formatting = dataset_info[name].get("formatting", "alpaca")
-            dataset_attr.system_prompt = prompt_list[i]
-            self.dataset_list.append(dataset_attr)
@@ -1,6 +1,6 @@
 import os
-from typing import Literal, Optional
 from dataclasses import dataclass, field
+from typing import Literal, Optional

 from datasets import DownloadMode

@@ -10,46 +10,39 @@ class EvaluationArguments:
     r"""
     Arguments pertaining to specify the evaluation parameters.
     """

     task: str = field(
-        metadata={"help": "Name of the evaluation task."}
+        metadata={"help": "Name of the evaluation task."},
     )
     task_dir: Optional[str] = field(
         default="evaluation",
-        metadata={"help": "Path to the folder containing the evaluation datasets."}
+        metadata={"help": "Path to the folder containing the evaluation datasets."},
     )
     batch_size: Optional[int] = field(
         default=4,
-        metadata={"help": "The batch size per GPU for evaluation."}
+        metadata={"help": "The batch size per GPU for evaluation."},
     )
     seed: Optional[int] = field(
         default=42,
-        metadata={"help": "Random seed to be used with data loaders."}
+        metadata={"help": "Random seed to be used with data loaders."},
     )
     lang: Optional[Literal["en", "zh"]] = field(
         default="en",
-        metadata={"help": "Language used at evaluation."}
+        metadata={"help": "Language used at evaluation."},
     )
     n_shot: Optional[int] = field(
         default=5,
-        metadata={"help": "Number of examplars for few-shot learning."}
+        metadata={"help": "Number of examplars for few-shot learning."},
     )
     save_dir: Optional[str] = field(
         default=None,
-        metadata={"help": "Path to save the evaluation results."}
+        metadata={"help": "Path to save the evaluation results."},
     )
     download_mode: Optional[DownloadMode] = field(
         default=DownloadMode.REUSE_DATASET_IF_EXISTS,
-        metadata={"help": "Download mode used for the evaluation datasets."}
+        metadata={"help": "Download mode used for the evaluation datasets."},
     )

     def __post_init__(self):
-        task_available = []
-        for folder in os.listdir(self.task_dir):
-            if os.path.isdir(os.path.join(self.task_dir, folder)):
-                task_available.append(folder)
-
-        if self.task not in task_available:
-            raise ValueError("Task {} not found in {}.".format(self.task, self.task_dir))
-
         if self.save_dir is not None and os.path.exists(self.save_dir):
             raise ValueError("`save_dir` already exists, use another one.")
@@ -1,6 +1,6 @@
 import json
-from typing import Literal, Optional
 from dataclasses import asdict, dataclass, field
+from typing import Literal, Optional


 @dataclass
@@ -8,19 +8,23 @@ class FreezeArguments:
     r"""
     Arguments pertaining to the freeze (partial-parameter) training.
     """

     name_module_trainable: Optional[str] = field(
-        default="mlp",
-        metadata={"help": "Name of trainable modules for partial-parameter (freeze) fine-tuning. \
-                  Use commas to separate multiple modules. \
-                  LLaMA choices: [\"mlp\", \"self_attn\"], \
-                  BLOOM & Falcon & ChatGLM choices: [\"mlp\", \"self_attention\"], \
-                  Qwen choices: [\"mlp\", \"attn\"], \
-                  Phi-1.5 choices: [\"mlp\", \"mixer\"], \
-                  Others choices: the same as LLaMA."}
+        default=None,
+        metadata={
+            "help": """Name of trainable modules for partial-parameter (freeze) fine-tuning. \
+                    Use commas to separate multiple modules. \
+                    Use "all" to specify all the available modules. \
+                    LLaMA choices: ["mlp", "self_attn"], \
+                    BLOOM & Falcon & ChatGLM choices: ["mlp", "self_attention"], \
+                    Qwen choices: ["mlp", "attn"], \
+                    InternLM2 choices: ["feed_forward", "attention"], \
+                    Others choices: the same as LLaMA."""
+        },
     )
     num_layer_trainable: Optional[int] = field(
         default=3,
-        metadata={"help": "The number of trainable layers for partial-parameter (freeze) fine-tuning."}
+        metadata={"help": "The number of trainable layers for partial-parameter (freeze) fine-tuning."},
     )


@@ -29,35 +33,53 @@ class LoraArguments:
     r"""
     Arguments pertaining to the LoRA training.
     """

     additional_target: Optional[str] = field(
         default=None,
-        metadata={"help": "Name(s) of modules apart from LoRA layers to be set as trainable and saved in the final checkpoint."}
+        metadata={
+            "help": "Name(s) of modules apart from LoRA layers to be set as trainable and saved in the final checkpoint."
+        },
     )
-    lora_alpha: Optional[float] = field(
+    lora_alpha: Optional[int] = field(
         default=None,
-        metadata={"help": "The scale factor for LoRA fine-tuning (default: lora_rank * 2.0)."}
+        metadata={"help": "The scale factor for LoRA fine-tuning (default: lora_rank * 2)."},
     )
     lora_dropout: Optional[float] = field(
-        default=0.1,
-        metadata={"help": "Dropout rate for the LoRA fine-tuning."}
+        default=0.0,
+        metadata={"help": "Dropout rate for the LoRA fine-tuning."},
     )
     lora_rank: Optional[int] = field(
         default=8,
-        metadata={"help": "The intrinsic dimension for LoRA fine-tuning."}
+        metadata={"help": "The intrinsic dimension for LoRA fine-tuning."},
     )
     lora_target: Optional[str] = field(
         default=None,
-        metadata={"help": "Name(s) of target modules to apply LoRA. Use commas to separate multiple modules. \
-                  LLaMA choices: [\"q_proj\", \"k_proj\", \"v_proj\", \"o_proj\", \"gate_proj\", \"up_proj\", \"down_proj\"], \
-                  BLOOM & Falcon & ChatGLM choices: [\"query_key_value\", \"dense\", \"dense_h_to_4h\", \"dense_4h_to_h\"], \
-                  Baichuan choices: [\"W_pack\", \"o_proj\", \"gate_proj\", \"up_proj\", \"down_proj\"], \
-                  Qwen choices: [\"c_attn\", \"attn.c_proj\", \"w1\", \"w2\", \"mlp.c_proj\"], \
-                  Phi-1.5 choices: [\"Wqkv\", \"out_proj\", \"fc1\", \"fc2\"], \
-                  Others choices: the same as LLaMA."}
+        metadata={
+            "help": """Name(s) of target modules to apply LoRA. \
+                    Use commas to separate multiple modules. \
+                    Use "all" to specify all the available modules. \
+                    LLaMA choices: ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"], \
+                    BLOOM & Falcon & ChatGLM choices: ["query_key_value", "dense", "dense_h_to_4h", "dense_4h_to_h"], \
+                    Baichuan choices: ["W_pack", "o_proj", "gate_proj", "up_proj", "down_proj"], \
+                    Qwen choices: ["c_attn", "attn.c_proj", "w1", "w2", "mlp.c_proj"], \
+                    InternLM2 choices: ["wqkv", "wo", "w1", "w2", "w3"], \
+                    Others choices: the same as LLaMA."""
+        },
     )
-    resume_lora_training: Optional[bool] = field(
-        default=True,
-        metadata={"help": "Whether to resume training from the last LoRA weights or create new weights after merging them."}
+    lora_bf16_mode: Optional[bool] = field(
+        default=False,
+        metadata={"help": "Whether or not to train lora adapters in bf16 precision."},
+    )
+    use_rslora: Optional[bool] = field(
+        default=False,
+        metadata={"help": "Whether or not to use the rank stabilization scaling factor for LoRA layer."},
+    )
+    use_dora: Optional[bool] = field(
+        default=False, metadata={"help": "Whether or not to use the weight-decomposed lora method (DoRA)."}
+    )
+    create_new_adapter: Optional[bool] = field(
+        default=False,
+        metadata={"help": "Whether or not to create a new adapter with randomly initialized weight."},
     )

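Note (illustrative, not part of the diff): the LoRA hyperparameters above map more or less one-to-one onto a `peft.LoraConfig`. The sketch below assumes peft>=0.9.0 (required elsewhere in this change set for `use_dora`/`use_rslora`); names and values are placeholders.

# Sketch only: how the LoraArguments fields would typically be forwarded to peft.
from peft import LoraConfig, TaskType

lora_rank = 8
lora_alpha = lora_rank * 2  # mirrors the new default: lora_alpha = lora_rank * 2

lora_config = LoraConfig(
    task_type=TaskType.CAUSAL_LM,
    r=lora_rank,
    lora_alpha=lora_alpha,
    lora_dropout=0.0,                      # new default in this change set
    target_modules=["q_proj", "v_proj"],   # "all" is resolved via find_all_linear_modules
    use_rslora=False,                      # rank-stabilized scaling factor
    use_dora=False,                        # weight-decomposed LoRA
    modules_to_save=None,                  # additional_target goes here
)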
@@ -66,61 +88,70 @@ class RLHFArguments:
     r"""
     Arguments pertaining to the PPO and DPO training.
     """

     dpo_beta: Optional[float] = field(
         default=0.1,
-        metadata={"help": "The beta parameter for the DPO loss."}
+        metadata={"help": "The beta parameter for the DPO loss."},
+    )
+    dpo_loss: Optional[Literal["sigmoid", "hinge", "ipo", "kto_pair"]] = field(
+        default="sigmoid",
+        metadata={"help": "The type of DPO loss to use."},
+    )
+    dpo_ftx: Optional[float] = field(
+        default=0,
+        metadata={"help": "The supervised fine-tuning loss coefficient in DPO training."},
     )
     ppo_buffer_size: Optional[int] = field(
         default=1,
-        metadata={"help": "The number of mini-batches to make experience buffer in a PPO optimization step."}
+        metadata={"help": "The number of mini-batches to make experience buffer in a PPO optimization step."},
     )
     ppo_epochs: Optional[int] = field(
         default=4,
-        metadata={"help": "The number of epochs to perform in a PPO optimization step."}
+        metadata={"help": "The number of epochs to perform in a PPO optimization step."},
     )
     ppo_logger: Optional[str] = field(
         default=None,
-        metadata={"help": "Log with either \"wandb\" or \"tensorboard\" in PPO training."}
+        metadata={"help": 'Log with either "wandb" or "tensorboard" in PPO training.'},
     )
     ppo_score_norm: Optional[bool] = field(
         default=False,
-        metadata={"help": "Use score normalization in PPO training."}
+        metadata={"help": "Use score normalization in PPO training."},
     )
     ppo_target: Optional[float] = field(
         default=6.0,
-        metadata={"help": "Target KL value for adaptive KL control in PPO training."}
+        metadata={"help": "Target KL value for adaptive KL control in PPO training."},
     )
     ppo_whiten_rewards: Optional[bool] = field(
         default=False,
-        metadata={"help": "Whiten the rewards before compute advantages in PPO training."}
+        metadata={"help": "Whiten the rewards before compute advantages in PPO training."},
     )
     ref_model: Optional[str] = field(
         default=None,
-        metadata={"help": "Path to the reference model used for the PPO or DPO training."}
+        metadata={"help": "Path to the reference model used for the PPO or DPO training."},
     )
-    ref_model_checkpoint: Optional[str] = field(
+    ref_model_adapters: Optional[str] = field(
         default=None,
-        metadata={"help": "Path to the directory(s) containing the model checkpoints of the reference model."}
+        metadata={"help": "Path to the adapters of the reference model."},
     )
     ref_model_quantization_bit: Optional[int] = field(
         default=None,
-        metadata={"help": "The number of bits to quantize the reference model."}
+        metadata={"help": "The number of bits to quantize the reference model."},
     )
     reward_model: Optional[str] = field(
         default=None,
-        metadata={"help": "Path to the directory containing the checkpoints of the reward model."}
+        metadata={"help": "Path to the reward model used for the PPO training."},
     )
-    reward_model_checkpoint: Optional[str] = field(
+    reward_model_adapters: Optional[str] = field(
         default=None,
-        metadata={"help": "Path to the directory(s) containing the model checkpoints of the reward model."}
+        metadata={"help": "Path to the adapters of the reward model."},
     )
     reward_model_quantization_bit: Optional[int] = field(
         default=None,
-        metadata={"help": "The number of bits to quantize the reward model."}
+        metadata={"help": "The number of bits to quantize the reward model."},
     )
     reward_model_type: Optional[Literal["lora", "full", "api"]] = field(
         default="lora",
-        metadata={"help": "The type of the reward model in PPO training. Lora model only supports lora training."}
+        metadata={"help": "The type of the reward model in PPO training. Lora model only supports lora training."},
     )

@@ -129,33 +160,26 @@ class FinetuningArguments(FreezeArguments, LoraArguments, RLHFArguments):
     r"""
     Arguments pertaining to which techniques we are going to fine-tuning with.
     """

     stage: Optional[Literal["pt", "sft", "rm", "ppo", "dpo"]] = field(
         default="sft",
-        metadata={"help": "Which stage will be performed in training."}
+        metadata={"help": "Which stage will be performed in training."},
     )
     finetuning_type: Optional[Literal["lora", "freeze", "full"]] = field(
         default="lora",
-        metadata={"help": "Which fine-tuning method to use."}
+        metadata={"help": "Which fine-tuning method to use."},
     )
-    upcast_layernorm: Optional[bool] = field(
+    use_llama_pro: Optional[bool] = field(
         default=False,
-        metadata={"help": "Whether to upcast the layernorm weights in fp32."}
+        metadata={"help": "Whether or not to make only the parameters in the expanded blocks trainable."},
     )
-    neft_alpha: Optional[float] = field(
-        default=0,
-        metadata={"help": "The alpha parameter to control the noise magnitude in NEFTune."}
-    )
-    export_dir: Optional[str] = field(
-        default=None,
-        metadata={"help": "Path to the directory to save the exported model."}
-    )
-    export_size: Optional[int] = field(
-        default=1,
-        metadata={"help": "The file shard size (in GB) of the exported model."}
+    disable_version_checking: Optional[bool] = field(
+        default=False,
+        metadata={"help": "Whether or not to disable version checking."},
     )
     plot_loss: Optional[bool] = field(
         default=False,
-        metadata={"help": "Whether to plot the training loss after fine-tuning or not."}
+        metadata={"help": "Whether or not to save the training loss curves."},
     )

     def __post_init__(self):
@@ -165,21 +189,22 @@ class FinetuningArguments(FreezeArguments, LoraArguments, RLHFArguments):
             return arg

         self.name_module_trainable = split_arg(self.name_module_trainable)
-        self.lora_alpha = self.lora_alpha or float(self.lora_rank * 2.0)
+        self.lora_alpha = self.lora_alpha or self.lora_rank * 2
         self.lora_target = split_arg(self.lora_target)
         self.additional_target = split_arg(self.additional_target)
-        self.ref_model_checkpoint = split_arg(self.ref_model_checkpoint)
-        self.reward_model_checkpoint = split_arg(self.reward_model_checkpoint)

         assert self.finetuning_type in ["lora", "freeze", "full"], "Invalid fine-tuning method."
         assert self.ref_model_quantization_bit in [None, 8, 4], "We only accept 4-bit or 8-bit quantization."
         assert self.reward_model_quantization_bit in [None, 8, 4], "We only accept 4-bit or 8-bit quantization."

         if self.stage == "ppo" and self.reward_model is None:
-            raise ValueError("Reward model is necessary for PPO training.")
+            raise ValueError("`reward_model` is necessary for PPO training.")

         if self.stage == "ppo" and self.reward_model_type == "lora" and self.finetuning_type != "lora":
-            raise ValueError("Freeze/Full PPO training needs `reward_model_type=full`.")
+            raise ValueError("`reward_model_type` cannot be lora for Freeze/Full PPO training.")
+
+        if self.use_llama_pro and self.finetuning_type == "full":
+            raise ValueError("`use_llama_pro` is only valid for the Freeze or LoRA method.")

     def save_to_json(self, json_path: str):
         r"""Saves the content of this instance in JSON format inside `json_path`."""
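Note (illustrative, not part of the diff): `__post_init__` normalizes the comma-separated CLI strings into lists via the `split_arg` helper defined earlier in this file, and `lora_alpha` now falls back to `lora_rank * 2`. The helper body below is an assumed sketch of that behavior, not a copy of the source.

# Assumed behavior sketch of split_arg and the lora_alpha fallback.
def split_arg(arg):
    if isinstance(arg, str):
        return [item.strip() for item in arg.split(",")]
    return arg

assert split_arg("q_proj,v_proj") == ["q_proj", "v_proj"]
lora_rank, lora_alpha = 8, None
assert (lora_alpha or lora_rank * 2) == 16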
@@ -1,5 +1,5 @@
-from typing import Any, Dict, Optional
 from dataclasses import asdict, dataclass, field
+from typing import Any, Dict, Optional


 @dataclass
@@ -7,41 +7,44 @@ class GeneratingArguments:
     r"""
     Arguments pertaining to specify the decoding parameters.
     """

     do_sample: Optional[bool] = field(
         default=True,
-        metadata={"help": "Whether or not to use sampling, use greedy decoding otherwise."}
+        metadata={"help": "Whether or not to use sampling, use greedy decoding otherwise."},
     )
     temperature: Optional[float] = field(
         default=0.95,
-        metadata={"help": "The value used to modulate the next token probabilities."}
+        metadata={"help": "The value used to modulate the next token probabilities."},
     )
     top_p: Optional[float] = field(
         default=0.7,
-        metadata={"help": "The smallest set of most probable tokens with probabilities that add up to top_p or higher are kept."}
+        metadata={
+            "help": "The smallest set of most probable tokens with probabilities that add up to top_p or higher are kept."
+        },
     )
     top_k: Optional[int] = field(
         default=50,
-        metadata={"help": "The number of highest probability vocabulary tokens to keep for top-k filtering."}
+        metadata={"help": "The number of highest probability vocabulary tokens to keep for top-k filtering."},
     )
     num_beams: Optional[int] = field(
         default=1,
-        metadata={"help": "Number of beams for beam search. 1 means no beam search."}
+        metadata={"help": "Number of beams for beam search. 1 means no beam search."},
     )
     max_length: Optional[int] = field(
         default=512,
-        metadata={"help": "The maximum length the generated tokens can have. It can be overridden by max_new_tokens."}
+        metadata={"help": "The maximum length the generated tokens can have. It can be overridden by max_new_tokens."},
     )
     max_new_tokens: Optional[int] = field(
         default=512,
-        metadata={"help": "The maximum numbers of tokens to generate, ignoring the number of tokens in the prompt."}
+        metadata={"help": "The maximum numbers of tokens to generate, ignoring the number of tokens in the prompt."},
     )
     repetition_penalty: Optional[float] = field(
         default=1.0,
-        metadata={"help": "The parameter for repetition penalty. 1.0 means no penalty."}
+        metadata={"help": "The parameter for repetition penalty. 1.0 means no penalty."},
     )
     length_penalty: Optional[float] = field(
         default=1.0,
-        metadata={"help": "Exponential penalty to the length that is used with beam-based generation."}
+        metadata={"help": "Exponential penalty to the length that is used with beam-based generation."},
     )

     def to_dict(self) -> Dict[str, Any]:
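Note (illustrative, not part of the diff): these fields are ordinary `generate()` kwargs, so `GeneratingArguments.to_dict()` can be unpacked straight into a transformers `GenerationConfig`. The values below are just the defaults shown above.

# Sketch: feeding the decoding parameters to transformers generation.
from transformers import GenerationConfig

gen_kwargs = dict(
    do_sample=True,
    temperature=0.95,
    top_p=0.7,
    top_k=50,
    num_beams=1,
    max_new_tokens=512,
    repetition_penalty=1.0,
    length_penalty=1.0,
)
generation_config = GenerationConfig(**gen_kwargs)
# outputs = model.generate(**inputs, generation_config=generation_config)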
@@ -1,5 +1,5 @@
-from typing import Any, Dict, Literal, Optional
 from dataclasses import asdict, dataclass, field
+from typing import Any, Dict, Literal, Optional


 @dataclass
@@ -7,57 +7,119 @@ class ModelArguments:
|
|||||||
r"""
|
r"""
|
||||||
Arguments pertaining to which model/config/tokenizer we are going to fine-tune.
|
Arguments pertaining to which model/config/tokenizer we are going to fine-tune.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
model_name_or_path: str = field(
|
model_name_or_path: str = field(
|
||||||
metadata={"help": "Path to pretrained model or model identifier from \
|
metadata={
|
||||||
huggingface.co/models or modelscope.cn/models."}
|
"help": "Path to the model weight or identifier from huggingface.co/models or modelscope.cn/models."
|
||||||
|
},
|
||||||
|
)
|
||||||
|
adapter_name_or_path: Optional[str] = field(
|
||||||
|
default=None,
|
||||||
|
metadata={"help": "Path to the adapter weight or identifier from huggingface.co/models."},
|
||||||
)
|
)
|
||||||
cache_dir: Optional[str] = field(
|
cache_dir: Optional[str] = field(
|
||||||
default=None,
|
default=None,
|
||||||
metadata={"help": "Where to store the pretrained models downloaded from huggingface.co."}
|
metadata={"help": "Where to store the pre-trained models downloaded from huggingface.co or modelscope.cn."},
|
||||||
)
|
)
|
||||||
use_fast_tokenizer: Optional[bool] = field(
|
use_fast_tokenizer: Optional[bool] = field(
|
||||||
default=True,
|
default=False,
|
||||||
metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."}
|
metadata={"help": "Whether or not to use one of the fast tokenizer (backed by the tokenizers library)."},
|
||||||
|
)
|
||||||
|
resize_vocab: Optional[bool] = field(
|
||||||
|
default=False,
|
||||||
|
metadata={"help": "Whether or not to resize the tokenizer vocab and the embedding layers."},
|
||||||
)
|
)
|
||||||
split_special_tokens: Optional[bool] = field(
|
split_special_tokens: Optional[bool] = field(
|
||||||
default=False,
|
default=False,
|
||||||
metadata={"help": "Whether or not the special tokens should be split during the tokenization process."}
|
metadata={"help": "Whether or not the special tokens should be split during the tokenization process."},
|
||||||
)
|
)
|
||||||
model_revision: Optional[str] = field(
|
model_revision: Optional[str] = field(
|
||||||
default="main",
|
default="main",
|
||||||
metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."}
|
metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."},
|
||||||
)
|
)
|
||||||
quantization_bit: Optional[int] = field(
|
quantization_bit: Optional[int] = field(
|
||||||
default=None,
|
default=None,
|
||||||
metadata={"help": "The number of bits to quantize the model."}
|
metadata={"help": "The number of bits to quantize the model."},
|
||||||
)
|
)
|
||||||
quantization_type: Optional[Literal["fp4", "nf4"]] = field(
|
quantization_type: Optional[Literal["fp4", "nf4"]] = field(
|
||||||
default="nf4",
|
default="nf4",
|
||||||
metadata={"help": "Quantization data type to use in int4 training."}
|
metadata={"help": "Quantization data type to use in int4 training."},
|
||||||
)
|
)
|
||||||
double_quantization: Optional[bool] = field(
|
double_quantization: Optional[bool] = field(
|
||||||
default=True,
|
default=True,
|
||||||
metadata={"help": "Whether to use double quantization in int4 training or not."}
|
metadata={"help": "Whether or not to use double quantization in int4 training."},
|
||||||
)
|
)
|
||||||
rope_scaling: Optional[Literal["linear", "dynamic"]] = field(
|
rope_scaling: Optional[Literal["linear", "dynamic"]] = field(
|
||||||
default=None,
|
default=None,
|
||||||
metadata={"help": "Adopt scaled rotary positional embeddings."}
|
metadata={"help": "Which scaling strategy should be adopted for the RoPE embeddings."},
|
||||||
)
|
|
||||||
checkpoint_dir: Optional[str] = field(
|
|
||||||
default=None,
|
|
||||||
metadata={"help": "Path to the directory(s) containing the model checkpoints as well as the configurations."}
|
|
||||||
)
|
)
|
||||||
flash_attn: Optional[bool] = field(
|
flash_attn: Optional[bool] = field(
|
||||||
default=False,
|
default=False,
|
||||||
metadata={"help": "Enable FlashAttention-2 for faster training."}
|
metadata={"help": "Enable FlashAttention-2 for faster training."},
|
||||||
)
|
)
|
||||||
shift_attn: Optional[bool] = field(
|
shift_attn: Optional[bool] = field(
|
||||||
default=False,
|
default=False,
|
||||||
metadata={"help": "Enable shift short attention (S^2-Attn) proposed by LongLoRA."}
|
metadata={"help": "Enable shift short attention (S^2-Attn) proposed by LongLoRA."},
|
||||||
|
)
|
||||||
|
use_unsloth: Optional[bool] = field(
|
||||||
|
default=False,
|
||||||
|
metadata={"help": "Whether or not to use unsloth's optimization for the LoRA training."},
|
||||||
|
)
|
||||||
|
disable_gradient_checkpointing: Optional[bool] = field(
|
||||||
|
default=False,
|
||||||
|
metadata={"help": "Whether or not to disable gradient checkpointing."},
|
||||||
|
)
|
||||||
|
upcast_layernorm: Optional[bool] = field(
|
||||||
|
default=False,
|
||||||
|
metadata={"help": "Whether or not to upcast the layernorm weights in fp32."},
|
||||||
|
)
|
||||||
|
upcast_lmhead_output: Optional[bool] = field(
|
||||||
|
default=False,
|
||||||
|
metadata={"help": "Whether or not to upcast the output of lm_head in fp32."},
|
||||||
)
|
)
|
||||||
hf_hub_token: Optional[str] = field(
|
hf_hub_token: Optional[str] = field(
|
||||||
default=None,
|
default=None,
|
||||||
metadata={"help": "Auth token to log in with Hugging Face Hub."}
|
metadata={"help": "Auth token to log in with Hugging Face Hub."},
|
||||||
|
)
|
||||||
|
ms_hub_token: Optional[str] = field(
|
||||||
|
default=None,
|
||||||
|
metadata={"help": "Auth token to log in with ModelScope Hub."},
|
||||||
|
)
|
||||||
|
export_dir: Optional[str] = field(
|
||||||
|
default=None,
|
||||||
|
metadata={"help": "Path to the directory to save the exported model."},
|
||||||
|
)
|
||||||
|
export_size: Optional[int] = field(
|
||||||
|
default=1,
|
||||||
|
metadata={"help": "The file shard size (in GB) of the exported model."},
|
||||||
|
)
|
||||||
|
export_quantization_bit: Optional[int] = field(
|
||||||
|
default=None,
|
||||||
|
metadata={"help": "The number of bits to quantize the exported model."},
|
||||||
|
)
|
||||||
|
export_quantization_dataset: Optional[str] = field(
|
||||||
|
default=None,
|
||||||
|
metadata={"help": "Path to the dataset or dataset name to use in quantizing the exported model."},
|
||||||
|
)
|
||||||
|
export_quantization_nsamples: Optional[int] = field(
|
||||||
|
default=128,
|
||||||
|
metadata={"help": "The number of samples used for quantization."},
|
||||||
|
)
|
||||||
|
export_quantization_maxlen: Optional[int] = field(
|
||||||
|
default=1024,
|
||||||
|
metadata={"help": "The maximum length of the model inputs used for quantization."},
|
||||||
|
)
|
||||||
|
export_legacy_format: Optional[bool] = field(
|
||||||
|
default=False,
|
||||||
|
metadata={"help": "Whether or not to save the `.bin` files instead of `.safetensors`."},
|
||||||
|
)
|
||||||
|
export_hub_model_id: Optional[str] = field(
|
||||||
|
default=None,
|
||||||
|
metadata={"help": "The name of the repository if push the model to the Hugging Face hub."},
|
||||||
|
)
|
||||||
|
print_param_status: Optional[bool] = field(
|
||||||
|
default=False,
|
||||||
|
metadata={"help": "For debugging purposes, print the status of the parameters in the model."},
|
||||||
)
|
)
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
@@ -67,10 +129,14 @@ class ModelArguments:
         if self.split_special_tokens and self.use_fast_tokenizer:
             raise ValueError("`split_special_tokens` is only supported for slow tokenizers.")

-        if self.checkpoint_dir is not None: # support merging multiple lora weights
-            self.checkpoint_dir = [cd.strip() for cd in self.checkpoint_dir.split(",")]
+        if self.adapter_name_or_path is not None: # support merging multiple lora weights
+            self.adapter_name_or_path = [path.strip() for path in self.adapter_name_or_path.split(",")]

         assert self.quantization_bit in [None, 8, 4], "We only accept 4-bit or 8-bit quantization."
+        assert self.export_quantization_bit in [None, 8, 4, 3, 2], "We only accept 2/3/4/8-bit quantization."
+
+        if self.export_quantization_bit is not None and self.export_quantization_dataset is None:
+            raise ValueError("Quantization dataset is necessary for exporting.")

     def to_dict(self) -> Dict[str, Any]:
         return asdict(self)
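Note (illustrative, not part of the diff): the checks above mean a comma-separated `adapter_name_or_path` becomes a list, and exporting with `export_quantization_bit` requires a calibration dataset. The sketch below uses placeholder model, adapter, and dataset names; the import follows the `from llmtuner.hparams import ...` style used elsewhere in this change set.

# Sanity sketch of ModelArguments.__post_init__ (placeholder values throughout).
from llmtuner.hparams import ModelArguments

args = ModelArguments(
    model_name_or_path="meta-llama/Llama-2-7b-hf",    # placeholder identifier
    adapter_name_or_path="adapter_a,adapter_b",       # split into a list in __post_init__
    export_quantization_bit=4,
    export_quantization_dataset="data/c4_demo.json",  # required once a bit-width is set
)
assert args.adapter_name_or_path == ["adapter_a", "adapter_b"]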
275  src/llmtuner/hparams/parser.py  Normal file
@@ -0,0 +1,275 @@
|
|||||||
|
import logging
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
from typing import Any, Dict, Optional, Tuple
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import transformers
|
||||||
|
from transformers import HfArgumentParser, Seq2SeqTrainingArguments
|
||||||
|
from transformers.trainer_utils import get_last_checkpoint
|
||||||
|
from transformers.utils.versions import require_version
|
||||||
|
|
||||||
|
from ..extras.logging import get_logger
|
||||||
|
from ..extras.packages import is_unsloth_available
|
||||||
|
from .data_args import DataArguments
|
||||||
|
from .evaluation_args import EvaluationArguments
|
||||||
|
from .finetuning_args import FinetuningArguments
|
||||||
|
from .generating_args import GeneratingArguments
|
||||||
|
from .model_args import ModelArguments
|
||||||
|
|
||||||
|
|
||||||
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
_TRAIN_ARGS = [ModelArguments, DataArguments, Seq2SeqTrainingArguments, FinetuningArguments, GeneratingArguments]
|
||||||
|
_TRAIN_CLS = Tuple[ModelArguments, DataArguments, Seq2SeqTrainingArguments, FinetuningArguments, GeneratingArguments]
|
||||||
|
_INFER_ARGS = [ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments]
|
||||||
|
_INFER_CLS = Tuple[ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments]
|
||||||
|
_EVAL_ARGS = [ModelArguments, DataArguments, EvaluationArguments, FinetuningArguments]
|
||||||
|
_EVAL_CLS = Tuple[ModelArguments, DataArguments, EvaluationArguments, FinetuningArguments]
|
||||||
|
|
||||||
|
|
||||||
|
def _check_dependencies(disabled: bool) -> None:
|
||||||
|
if disabled:
|
||||||
|
logger.warning("Version checking has been disabled, may lead to unexpected behaviors.")
|
||||||
|
else:
|
||||||
|
require_version("transformers>=4.37.2", "To fix: pip install transformers>=4.37.2")
|
||||||
|
require_version("datasets>=2.14.3", "To fix: pip install datasets>=2.14.3")
|
||||||
|
require_version("accelerate>=0.27.2", "To fix: pip install accelerate>=0.27.2")
|
||||||
|
require_version("peft>=0.9.0", "To fix: pip install peft>=0.9.0")
|
||||||
|
require_version("trl>=0.7.11", "To fix: pip install trl>=0.7.11")
|
||||||
|
|
||||||
|
|
||||||
|
def _parse_args(parser: "HfArgumentParser", args: Optional[Dict[str, Any]] = None) -> Tuple[Any]:
|
||||||
|
if args is not None:
|
||||||
|
return parser.parse_dict(args)
|
||||||
|
|
||||||
|
if len(sys.argv) == 2 and sys.argv[1].endswith(".yaml"):
|
||||||
|
return parser.parse_yaml_file(os.path.abspath(sys.argv[1]))
|
||||||
|
|
||||||
|
if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
|
||||||
|
return parser.parse_json_file(os.path.abspath(sys.argv[1]))
|
||||||
|
|
||||||
|
(*parsed_args, unknown_args) = parser.parse_args_into_dataclasses(return_remaining_strings=True)
|
||||||
|
|
||||||
|
if unknown_args:
|
||||||
|
print(parser.format_help())
|
||||||
|
print("Got unknown args, potentially deprecated arguments: {}".format(unknown_args))
|
||||||
|
raise ValueError("Some specified arguments are not used by the HfArgumentParser: {}".format(unknown_args))
|
||||||
|
|
||||||
|
return (*parsed_args,)
|
||||||
|
|
||||||
|
|
||||||
|
def _set_transformers_logging(log_level: Optional[int] = logging.INFO) -> None:
|
||||||
|
transformers.utils.logging.set_verbosity(log_level)
|
||||||
|
transformers.utils.logging.enable_default_handler()
|
||||||
|
transformers.utils.logging.enable_explicit_format()
|
||||||
|
|
||||||
|
|
||||||
|
def _verify_model_args(model_args: "ModelArguments", finetuning_args: "FinetuningArguments") -> None:
|
||||||
|
if model_args.quantization_bit is not None:
|
||||||
|
if finetuning_args.finetuning_type != "lora":
|
||||||
|
raise ValueError("Quantization is only compatible with the LoRA method.")
|
||||||
|
|
||||||
|
if model_args.adapter_name_or_path is not None and finetuning_args.create_new_adapter:
|
||||||
|
raise ValueError("Cannot create new adapter upon a quantized model.")
|
||||||
|
|
||||||
|
if model_args.adapter_name_or_path is not None and len(model_args.adapter_name_or_path) != 1:
|
||||||
|
raise ValueError("Quantized model only accepts a single adapter. Merge them first.")
|
||||||
|
|
||||||
|
if model_args.adapter_name_or_path is not None and finetuning_args.finetuning_type != "lora":
|
||||||
|
raise ValueError("Adapter is only valid for the LoRA method.")
|
||||||
|
|
||||||
|
|
||||||
|
def _parse_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
|
||||||
|
parser = HfArgumentParser(_TRAIN_ARGS)
|
||||||
|
return _parse_args(parser, args)
|
||||||
|
|
||||||
|
|
||||||
|
def _parse_infer_args(args: Optional[Dict[str, Any]] = None) -> _INFER_CLS:
|
||||||
|
parser = HfArgumentParser(_INFER_ARGS)
|
||||||
|
return _parse_args(parser, args)
|
||||||
|
|
||||||
|
|
||||||
|
def _parse_eval_args(args: Optional[Dict[str, Any]] = None) -> _EVAL_CLS:
|
||||||
|
parser = HfArgumentParser(_EVAL_ARGS)
|
||||||
|
return _parse_args(parser, args)
|
||||||
|
|
||||||
|
|
||||||
|
def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
|
||||||
|
model_args, data_args, training_args, finetuning_args, generating_args = _parse_train_args(args)
|
||||||
|
|
||||||
|
# Setup logging
|
||||||
|
if training_args.should_log:
|
||||||
|
_set_transformers_logging()
|
||||||
|
|
||||||
|
# Check arguments
|
||||||
|
if finetuning_args.stage != "pt" and data_args.template is None:
|
||||||
|
raise ValueError("Please specify which `template` to use.")
|
||||||
|
|
||||||
|
if finetuning_args.stage != "sft" and training_args.predict_with_generate:
|
||||||
|
raise ValueError("`predict_with_generate` cannot be set as True except SFT.")
|
||||||
|
|
||||||
|
if finetuning_args.stage == "sft" and training_args.do_predict and not training_args.predict_with_generate:
|
||||||
|
raise ValueError("Please enable `predict_with_generate` to save model predictions.")
|
||||||
|
|
||||||
|
if finetuning_args.stage in ["rm", "ppo"] and training_args.load_best_model_at_end:
|
||||||
|
raise ValueError("RM and PPO stages do not support `load_best_model_at_end`.")
|
||||||
|
|
||||||
|
if finetuning_args.stage == "ppo" and not training_args.do_train:
|
||||||
|
raise ValueError("PPO training does not support evaluation, use the SFT stage to evaluate models.")
|
||||||
|
|
||||||
|
if finetuning_args.stage == "ppo" and model_args.shift_attn:
|
||||||
|
raise ValueError("PPO training is incompatible with S^2-Attn.")
|
||||||
|
|
||||||
|
if finetuning_args.stage == "ppo" and finetuning_args.reward_model_type == "lora" and model_args.use_unsloth:
|
||||||
|
raise ValueError("Unsloth does not support lora reward model.")
|
||||||
|
|
||||||
|
if training_args.max_steps == -1 and data_args.streaming:
|
||||||
|
raise ValueError("Please specify `max_steps` in streaming mode.")
|
||||||
|
|
||||||
|
if training_args.do_train and training_args.predict_with_generate:
|
||||||
|
raise ValueError("`predict_with_generate` cannot be set as True while training.")
|
||||||
|
|
||||||
|
if (
|
||||||
|
training_args.do_train
|
||||||
|
and finetuning_args.finetuning_type == "freeze"
|
||||||
|
and finetuning_args.name_module_trainable is None
|
||||||
|
):
|
||||||
|
raise ValueError("Please specify `name_module_trainable` in Freeze training.")
|
||||||
|
|
||||||
|
if training_args.do_train and finetuning_args.finetuning_type == "lora" and finetuning_args.lora_target is None:
|
||||||
|
raise ValueError("Please specify `lora_target` in LoRA training.")
|
||||||
|
|
||||||
|
if training_args.do_train and model_args.use_unsloth and not is_unsloth_available:
|
||||||
|
raise ValueError("Unsloth was not installed: https://github.com/unslothai/unsloth")
|
||||||
|
|
||||||
|
if finetuning_args.use_dora:
|
||||||
|
if model_args.quantization_bit is not None:
|
||||||
|
raise ValueError("DoRA does not support quantization.")
|
||||||
|
|
||||||
|
if model_args.use_unsloth:
|
||||||
|
raise ValueError("Unsloth does not support DoRA.")
|
||||||
|
|
||||||
|
_verify_model_args(model_args, finetuning_args)
|
||||||
|
_check_dependencies(disabled=finetuning_args.disable_version_checking)
|
||||||
|
|
||||||
|
if (
|
||||||
|
training_args.do_train
|
||||||
|
and finetuning_args.finetuning_type == "lora"
|
||||||
|
and model_args.resize_vocab
|
||||||
|
and finetuning_args.additional_target is None
|
||||||
|
):
|
||||||
|
logger.warning("Add token embeddings to `additional_target` to make the added tokens trainable.")
|
||||||
|
|
||||||
|
if training_args.do_train and model_args.quantization_bit is not None and (not model_args.upcast_layernorm):
|
||||||
|
logger.warning("We recommend enable `upcast_layernorm` in quantized training.")
|
||||||
|
|
||||||
|
if training_args.do_train and (not training_args.fp16) and (not training_args.bf16):
|
||||||
|
logger.warning("We recommend enable mixed precision training.")
|
||||||
|
|
||||||
|
if (not training_args.do_train) and model_args.quantization_bit is not None:
|
||||||
|
logger.warning("Evaluating model in 4/8-bit mode may cause lower scores.")
|
||||||
|
|
||||||
|
if (not training_args.do_train) and finetuning_args.stage == "dpo" and finetuning_args.ref_model is None:
|
||||||
|
logger.warning("Specify `ref_model` for computing rewards at evaluation.")
|
||||||
|
|
||||||
|
# Post-process training arguments
|
||||||
|
if (
|
||||||
|
training_args.local_rank != -1
|
||||||
|
and training_args.ddp_find_unused_parameters is None
|
||||||
|
and finetuning_args.finetuning_type == "lora"
|
||||||
|
):
|
||||||
|
logger.warning("`ddp_find_unused_parameters` needs to be set as False for LoRA in DDP training.")
|
||||||
|
training_args_dict = training_args.to_dict()
|
||||||
|
training_args_dict.update(dict(ddp_find_unused_parameters=False))
|
||||||
|
training_args = Seq2SeqTrainingArguments(**training_args_dict)
|
||||||
|
|
||||||
|
if finetuning_args.stage in ["rm", "ppo"] and finetuning_args.finetuning_type in ["full", "freeze"]:
|
||||||
|
can_resume_from_checkpoint = False
|
||||||
|
if training_args.resume_from_checkpoint is not None:
|
||||||
|
logger.warning("Cannot resume from checkpoint in current stage.")
|
||||||
|
training_args.resume_from_checkpoint = None
|
||||||
|
else:
|
||||||
|
can_resume_from_checkpoint = True
|
||||||
|
|
||||||
|
if (
|
||||||
|
training_args.resume_from_checkpoint is None
|
||||||
|
and training_args.do_train
|
||||||
|
and os.path.isdir(training_args.output_dir)
|
||||||
|
and not training_args.overwrite_output_dir
|
||||||
|
and can_resume_from_checkpoint
|
||||||
|
):
|
||||||
|
last_checkpoint = get_last_checkpoint(training_args.output_dir)
|
||||||
|
if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0:
|
||||||
|
raise ValueError("Output directory already exists and is not empty. Please set `overwrite_output_dir`.")
|
||||||
|
|
||||||
|
if last_checkpoint is not None:
|
||||||
|
training_args_dict = training_args.to_dict()
|
||||||
|
training_args_dict.update(dict(resume_from_checkpoint=last_checkpoint))
|
||||||
|
training_args = Seq2SeqTrainingArguments(**training_args_dict)
|
||||||
|
logger.info(
|
||||||
|
"Resuming training from {}. Change `output_dir` or use `overwrite_output_dir` to avoid.".format(
|
||||||
|
training_args.resume_from_checkpoint
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
if (
|
||||||
|
finetuning_args.stage in ["rm", "ppo"]
|
||||||
|
and finetuning_args.finetuning_type == "lora"
|
||||||
|
and training_args.resume_from_checkpoint is not None
|
||||||
|
):
|
||||||
|
logger.warning(
|
||||||
|
"Add {} to `adapter_name_or_path` to resume training from checkpoint.".format(
|
||||||
|
training_args.resume_from_checkpoint
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Post-process model arguments
|
||||||
|
model_args.compute_dtype = (
|
||||||
|
torch.bfloat16 if training_args.bf16 else (torch.float16 if training_args.fp16 else None)
|
||||||
|
)
|
||||||
|
model_args.model_max_length = data_args.cutoff_len
|
||||||
|
|
||||||
|
# Log on each process the small summary:
|
||||||
|
logger.info(
|
||||||
|
"Process rank: {}, device: {}, n_gpu: {}\n distributed training: {}, compute dtype: {}".format(
|
||||||
|
training_args.local_rank,
|
||||||
|
training_args.device,
|
||||||
|
training_args.n_gpu,
|
||||||
|
bool(training_args.local_rank != -1),
|
||||||
|
str(model_args.compute_dtype),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
transformers.set_seed(training_args.seed)
|
||||||
|
|
||||||
|
return model_args, data_args, training_args, finetuning_args, generating_args
|
||||||
|
|
||||||
|
|
||||||
|
def get_infer_args(args: Optional[Dict[str, Any]] = None) -> _INFER_CLS:
|
||||||
|
model_args, data_args, finetuning_args, generating_args = _parse_infer_args(args)
|
||||||
|
|
||||||
|
_set_transformers_logging()
|
||||||
|
_verify_model_args(model_args, finetuning_args)
|
||||||
|
_check_dependencies(disabled=finetuning_args.disable_version_checking)
|
||||||
|
|
||||||
|
if data_args.template is None:
|
||||||
|
raise ValueError("Please specify which `template` to use.")
|
||||||
|
|
||||||
|
return model_args, data_args, finetuning_args, generating_args
|
||||||
|
|
||||||
|
|
||||||
|
def get_eval_args(args: Optional[Dict[str, Any]] = None) -> _EVAL_CLS:
|
||||||
|
model_args, data_args, eval_args, finetuning_args = _parse_eval_args(args)
|
||||||
|
|
||||||
|
_set_transformers_logging()
|
||||||
|
_verify_model_args(model_args, finetuning_args)
|
||||||
|
_check_dependencies(disabled=finetuning_args.disable_version_checking)
|
||||||
|
|
||||||
|
if data_args.template is None:
|
||||||
|
raise ValueError("Please specify which `template` to use.")
|
||||||
|
|
||||||
|
transformers.set_seed(eval_args.seed)
|
||||||
|
|
||||||
|
return model_args, data_args, eval_args, finetuning_args
|
||||||
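Note (illustrative, not part of the diff): the new parser module centralizes argument handling, and `_parse_args()` accepts an in-process dict as well as a single .yaml or .json path on the command line. A minimal usage sketch, with placeholder model, dataset, and output names:

# Sketch: invoking the new entry point with a dict (values are placeholders).
from llmtuner.hparams.parser import get_train_args

model_args, data_args, training_args, finetuning_args, generating_args = get_train_args(
    dict(
        stage="sft",
        do_train=True,
        model_name_or_path="meta-llama/Llama-2-7b-hf",  # placeholder model id
        dataset="alpaca_gpt4_en",                        # placeholder dataset name
        template="default",
        finetuning_type="lora",
        lora_target="q_proj,v_proj",
        output_dir="outputs/sft-demo",                   # placeholder path
    )
)
# Passing a single config.yaml (or .json) on the command line is also accepted,
# because _parse_args() falls back to parse_yaml_file()/parse_json_file().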
@@ -1,5 +1,5 @@
-# Level: loader > adapter > parser, utils
+from .loader import load_model_and_tokenizer
+from .utils import dispatch_model, load_valuehead_params

-from llmtuner.model.loader import load_model_and_tokenizer
-from llmtuner.model.parser import get_train_args, get_infer_args, get_eval_args
-from llmtuner.model.utils import dispatch_model, get_modelcard_args, load_valuehead_params
+
+__all__ = ["load_model_and_tokenizer", "dispatch_model", "load_valuehead_params"]
@@ -1,23 +1,24 @@
-import torch
 from typing import TYPE_CHECKING
-from peft import PeftModel, TaskType, LoraConfig, get_peft_model

-from llmtuner.extras.logging import get_logger
-from llmtuner.model.utils import find_all_linear_modules
+import torch
+from peft import LoraConfig, LoraModel, PeftModel, TaskType, get_peft_model
+from transformers.integrations import is_deepspeed_zero3_enabled
+
+from ..extras.logging import get_logger
+from .utils import find_all_linear_modules, find_expanded_modules


 if TYPE_CHECKING:
     from transformers.modeling_utils import PreTrainedModel
-    from llmtuner.hparams import ModelArguments, FinetuningArguments
+
+    from ..hparams import FinetuningArguments, ModelArguments


 logger = get_logger(__name__)


 def init_adapter(
-    model: "PreTrainedModel",
-    model_args: "ModelArguments",
-    finetuning_args: "FinetuningArguments",
-    is_trainable: bool
+    model: "PreTrainedModel", model_args: "ModelArguments", finetuning_args: "FinetuningArguments", is_trainable: bool
 ) -> "PreTrainedModel":
     r"""
     Initializes the adapters.
@@ -27,8 +28,8 @@ def init_adapter(
     Note that the trainable parameters must be cast to float32.
     """

-    if (not is_trainable) and model_args.checkpoint_dir is None:
-        logger.info("Checkpoint is not found at evaluation, load the original model.")
+    if (not is_trainable) and model_args.adapter_name_or_path is None:
+        logger.info("Adapter is not found at evaluation, load the base model.")
         return model

     if finetuning_args.finetuning_type == "full" and is_trainable:
@@ -44,65 +45,115 @@ def init_adapter(
|
|||||||
)
|
)
|
||||||
if not num_layers:
|
if not num_layers:
|
||||||
raise ValueError("Current model does not support freeze tuning.")
|
raise ValueError("Current model does not support freeze tuning.")
|
||||||
if finetuning_args.num_layer_trainable > 0: # fine-tuning the last n layers if num_layer_trainable > 0
|
|
||||||
trainable_layer_ids = [num_layers - k - 1 for k in range(finetuning_args.num_layer_trainable)]
|
if finetuning_args.use_llama_pro:
|
||||||
else: # fine-tuning the first n layers if num_layer_trainable < 0
|
if num_layers % finetuning_args.num_layer_trainable != 0:
|
||||||
trainable_layer_ids = [k for k in range(-finetuning_args.num_layer_trainable)]
|
raise ValueError(
|
||||||
|
"`num_layers` {} should be divisible by `num_layer_trainable` {}.".format(
|
||||||
|
num_layers, finetuning_args.num_layer_trainable
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
stride = num_layers // finetuning_args.num_layer_trainable
|
||||||
|
trainable_layer_ids = range(stride - 1, num_layers + stride - 1, stride)
|
||||||
|
elif finetuning_args.num_layer_trainable > 0: # fine-tuning the last n layers if num_layer_trainable > 0
|
||||||
|
trainable_layer_ids = range(num_layers - finetuning_args.num_layer_trainable, num_layers)
|
||||||
|
else: # fine-tuning the first n layers if num_layer_trainable < 0
|
||||||
|
trainable_layer_ids = range(-finetuning_args.num_layer_trainable)
|
||||||
|
|
||||||
|
freeze_modules = {"all"}
|
||||||
|
for name, _ in model.named_modules():
|
||||||
|
if ".0." in name:
|
||||||
|
freeze_modules.add(name.split(".0.")[-1].split(".")[0])
|
||||||
|
|
||||||
trainable_layers = []
|
trainable_layers = []
|
||||||
for module_name in finetuning_args.name_module_trainable:
|
for module_name in finetuning_args.name_module_trainable:
|
||||||
|
if module_name not in freeze_modules:
|
||||||
|
raise ValueError(
|
||||||
|
"Module {} is not found, please choose from {}".format(module_name, ", ".join(freeze_modules))
|
||||||
|
)
|
||||||
|
|
||||||
for idx in trainable_layer_ids:
|
for idx in trainable_layer_ids:
|
||||||
trainable_layers.append("{:d}.{}".format(idx, module_name))
|
trainable_layers.append(".{:d}.{}".format(idx, module_name if module_name != "all" else ""))
|
||||||
|
|
||||||
for name, param in model.named_parameters():
|
for name, param in model.named_parameters():
|
||||||
if not any(trainable_layer in name for trainable_layer in trainable_layers):
|
if any(trainable_layer in name for trainable_layer in trainable_layers):
|
||||||
param.requires_grad_(False)
|
|
||||||
else:
|
|
||||||
param.data = param.data.to(torch.float32)
|
param.data = param.data.to(torch.float32)
|
||||||
|
else:
|
||||||
|
param.requires_grad_(False)
|
||||||
|
|
||||||
|
logger.info("Set trainable layers: {}".format(",".join(map(str, trainable_layer_ids))))
|
||||||
|
|
||||||
if finetuning_args.finetuning_type == "lora":
|
if finetuning_args.finetuning_type == "lora":
|
||||||
logger.info("Fine-tuning method: LoRA")
|
logger.info("Fine-tuning method: {}".format("DoRA" if finetuning_args.use_dora else "LoRA"))
|
||||||
checkpoint_to_resume = None
|
adapter_to_resume = None
|
||||||
|
|
||||||
if model_args.checkpoint_dir is not None:
|
if model_args.adapter_name_or_path is not None:
|
||||||
is_mergeable = True
|
is_mergeable = True
|
||||||
if getattr(model, "quantization_method", None) == "gptq":
|
if getattr(model, "quantization_method", None): # merge lora in quantized model is unstable
|
||||||
assert len(model_args.checkpoint_dir) == 1, "GPTQ quantized model only accepts a single checkpoint."
|
assert len(model_args.adapter_name_or_path) == 1, "Quantized model only accepts a single adapter."
|
||||||
is_mergeable = False
|
is_mergeable = False
|
||||||
|
|
||||||
if (is_trainable and finetuning_args.resume_lora_training) or (not is_mergeable):
|
if is_deepspeed_zero3_enabled():
|
||||||
checkpoints_to_merge, checkpoint_to_resume = model_args.checkpoint_dir[:-1], model_args.checkpoint_dir[-1]
|
assert len(model_args.adapter_name_or_path) == 1, "Cannot use multiple adapters in DeepSpeed ZeRO-3."
|
||||||
else:
|
is_mergeable = False
|
||||||
checkpoints_to_merge = model_args.checkpoint_dir
|
|
||||||
|
|
||||||
for checkpoint in checkpoints_to_merge:
|
if (is_trainable and not finetuning_args.create_new_adapter) or (not is_mergeable):
|
||||||
model = PeftModel.from_pretrained(model, checkpoint)
|
adapter_to_merge = model_args.adapter_name_or_path[:-1]
|
||||||
|
adapter_to_resume = model_args.adapter_name_or_path[-1]
|
||||||
|
else:
|
||||||
|
adapter_to_merge = model_args.adapter_name_or_path
|
||||||
|
|
||||||
|
for adapter in adapter_to_merge:
|
||||||
|
model: "LoraModel" = PeftModel.from_pretrained(model, adapter)
|
||||||
model = model.merge_and_unload()
|
model = model.merge_and_unload()
|
||||||
|
|
||||||
if len(checkpoints_to_merge) > 0:
|
if len(adapter_to_merge) > 0:
|
||||||
logger.info("Merged {} model checkpoint(s).".format(len(checkpoints_to_merge)))
|
logger.info("Merged {} adapter(s).".format(len(adapter_to_merge)))
|
||||||
|
|
||||||
if checkpoint_to_resume is not None: # resume lora training
|
if adapter_to_resume is not None: # resume lora training
|
||||||
model = PeftModel.from_pretrained(model, checkpoint_to_resume, is_trainable=is_trainable)
|
model = PeftModel.from_pretrained(model, adapter_to_resume, is_trainable=is_trainable)
|
||||||
|
|
||||||
if is_trainable and checkpoint_to_resume is None: # create new lora weights while training
|
if is_trainable and adapter_to_resume is None: # create new lora weights while training
|
||||||
if len(finetuning_args.lora_target) == 1 and finetuning_args.lora_target[0] == "all":
|
if len(finetuning_args.lora_target) == 1 and finetuning_args.lora_target[0] == "all":
|
||||||
target_modules = find_all_linear_modules(model)
|
target_modules = find_all_linear_modules(model)
|
||||||
else:
|
else:
|
||||||
target_modules = finetuning_args.lora_target
|
target_modules = finetuning_args.lora_target
|
||||||
|
|
||||||
lora_config = LoraConfig(
|
if finetuning_args.use_llama_pro:
|
||||||
task_type=TaskType.CAUSAL_LM,
|
target_modules = find_expanded_modules(model, target_modules, finetuning_args.num_layer_trainable)
|
||||||
inference_mode=False,
|
|
||||||
r=finetuning_args.lora_rank,
|
|
||||||
lora_alpha=finetuning_args.lora_alpha,
|
|
||||||
lora_dropout=finetuning_args.lora_dropout,
|
|
||||||
target_modules=target_modules,
|
|
||||||
modules_to_save=finetuning_args.additional_target
|
|
||||||
)
|
|
||||||
model = get_peft_model(model, lora_config)
|
|
||||||
|
|
||||||
if model_args.checkpoint_dir is not None:
|
if finetuning_args.use_dora:
|
||||||
logger.info("Loaded fine-tuned model from checkpoint(s): {}".format(",".join(model_args.checkpoint_dir)))
|
if getattr(model, "quantization_method", None):
|
||||||
|
raise ValueError("DoRA is currently not compatible with quantized models.")
|
||||||
|
|
||||||
|
peft_kwargs = {
|
||||||
|
"r": finetuning_args.lora_rank,
|
||||||
|
"target_modules": target_modules,
|
||||||
|
"lora_alpha": finetuning_args.lora_alpha,
|
||||||
|
"lora_dropout": finetuning_args.lora_dropout,
|
||||||
|
"use_rslora": finetuning_args.use_rslora,
|
||||||
|
}
|
||||||
|
|
||||||
|
if model_args.use_unsloth:
|
||||||
|
from unsloth import FastLanguageModel # type: ignore
|
||||||
|
|
||||||
|
unsloth_peft_kwargs = {"model": model, "max_seq_length": model_args.model_max_length}
|
||||||
|
model = FastLanguageModel.get_peft_model(**peft_kwargs, **unsloth_peft_kwargs)
|
||||||
|
else:
|
||||||
|
lora_config = LoraConfig(
|
||||||
|
task_type=TaskType.CAUSAL_LM,
|
||||||
|
inference_mode=False,
|
||||||
|
modules_to_save=finetuning_args.additional_target,
|
||||||
|
use_dora=finetuning_args.use_dora,
|
||||||
|
**peft_kwargs,
|
||||||
|
)
|
||||||
|
model = get_peft_model(model, lora_config)
|
||||||
|
|
||||||
|
for param in filter(lambda p: p.requires_grad, model.parameters()):
|
||||||
|
param.data = param.data.to(torch.bfloat16 if finetuning_args.lora_bf16_mode else torch.float32)
|
||||||
|
|
||||||
|
if model_args.adapter_name_or_path is not None:
|
||||||
|
logger.info("Loaded adapter(s): {}".format(",".join(model_args.adapter_name_or_path)))
|
||||||
|
|
||||||
return model
|
return model
|
||||||
|
|||||||
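Note (illustrative, not part of the diff): the reworked freeze branch above selects trainable layer ids either with a LLaMA Pro-style stride (only the expanded blocks stay trainable, so `num_layers` must be divisible by `num_layer_trainable`) or from one end of the layer stack. A standalone arithmetic sketch of that selection, using plain Python:

# Sketch of the trainable-layer selection introduced above (no project imports).
num_layers = 32

# use_llama_pro: stride-spaced block ids, i.e. the expanded blocks of a LLaMA Pro-style model.
num_layer_trainable = 8
stride = num_layers // num_layer_trainable
llama_pro_ids = list(range(stride - 1, num_layers + stride - 1, stride))
assert llama_pro_ids == [3, 7, 11, 15, 19, 23, 27, 31]

# default behaviour: last n layers for positive values, first n layers for negative ones.
assert list(range(num_layers - 8, num_layers)) == [24, 25, 26, 27, 28, 29, 30, 31]
assert list(range(-(-8))) == [0, 1, 2, 3, 4, 5, 6, 7]  # num_layer_trainable = -8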
@@ -1,56 +1,31 @@
-import os
-import math
-import torch
-from types import MethodType
-from typing import TYPE_CHECKING, Literal, Optional, Tuple
+from typing import TYPE_CHECKING, Optional, Tuple
 
-from transformers import (
-    AutoConfig,
-    AutoModelForCausalLM,
-    AutoTokenizer,
-    BitsAndBytesConfig,
-    PretrainedConfig,
-    PreTrainedModel,
-    PreTrainedTokenizerBase
-)
-from transformers.models.llama import modeling_llama as LlamaModule
-from transformers.utils.versions import require_version
+from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
+from transformers.integrations import is_deepspeed_zero3_enabled
 from trl import AutoModelForCausalLMWithValueHead
 
-try:
-    from transformers.integrations import is_deepspeed_zero3_enabled
-except ImportError: # https://github.com/huggingface/transformers/releases/tag/v4.33.1
-    from transformers.deepspeed import is_deepspeed_zero3_enabled
+from ..extras.logging import get_logger
+from ..extras.misc import count_parameters, get_current_device, try_download_model_from_ms
+from .adapter import init_adapter
+from .patcher import patch_config, patch_model, patch_tokenizer, patch_valuehead_model
+from .utils import load_valuehead_params, register_autoclass
 
-from llmtuner.extras.logging import get_logger
-from llmtuner.extras.misc import count_parameters, infer_optim_dtype, try_download_model_from_ms
-from llmtuner.extras.packages import is_flash_attn2_available
-from llmtuner.extras.patches import llama_patch as LlamaPatches
-from llmtuner.hparams import FinetuningArguments
-from llmtuner.model.adapter import init_adapter
-from llmtuner.model.utils import load_valuehead_params, prepare_model_for_training
 
 if TYPE_CHECKING:
-    from transformers import PreTrainedTokenizer
-    from llmtuner.hparams import ModelArguments
+    from transformers import PreTrainedModel, PreTrainedTokenizer
+
+    from ..hparams import FinetuningArguments, ModelArguments
 
 
 logger = get_logger(__name__)
 
 
-require_version("transformers>=4.31.0,<4.35.0", "To fix: pip install \"transformers>=4.31.0,<4.35.0\"")
-require_version("datasets>=2.14.0", "To fix: pip install datasets>=2.14.0")
-require_version("accelerate>=0.21.0", "To fix: pip install accelerate>=0.21.0")
-require_version("peft>=0.6.0", "To fix: pip install peft>=0.6.0")
-require_version("trl>=0.7.4", "To fix: pip install trl>=0.7.4")
-
-
 def load_model_and_tokenizer(
     model_args: "ModelArguments",
     finetuning_args: "FinetuningArguments",
     is_trainable: Optional[bool] = False,
-    add_valuehead: Optional[bool] = False
-) -> Tuple[PreTrainedModel, "PreTrainedTokenizer"]:
+    add_valuehead: Optional[bool] = False,
+) -> Tuple["PreTrainedModel", "PreTrainedTokenizer"]:
     r"""
     Loads pretrained model and tokenizer.
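A minimal usage sketch of the signature shown above; the import paths and the concrete argument values are assumptions for illustration, not something stated in the diff.

# Hedged usage sketch of load_model_and_tokenizer as defined above.
# The hparams import path and the argument values are assumptions.
from llmtuner.hparams import FinetuningArguments, ModelArguments
from llmtuner.model import load_model_and_tokenizer

model_args = ModelArguments(model_name_or_path="meta-llama/Llama-2-7b-hf")
finetuning_args = FinetuningArguments(finetuning_type="lora")

# is_trainable=False freezes the parameters and switches the model to eval mode.
model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args, is_trainable=False)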
@@ -63,176 +38,95 @@ def load_model_and_tokenizer(
|
|||||||
"trust_remote_code": True,
|
"trust_remote_code": True,
|
||||||
"cache_dir": model_args.cache_dir,
|
"cache_dir": model_args.cache_dir,
|
||||||
"revision": model_args.model_revision,
|
"revision": model_args.model_revision,
|
||||||
"token": model_args.hf_hub_token
|
"token": model_args.hf_hub_token,
|
||||||
}
|
}
|
||||||
|
|
||||||
tokenizer = AutoTokenizer.from_pretrained(
|
tokenizer = AutoTokenizer.from_pretrained(
|
||||||
model_args.model_name_or_path,
|
model_args.model_name_or_path,
|
||||||
use_fast=model_args.use_fast_tokenizer,
|
use_fast=model_args.use_fast_tokenizer,
|
||||||
split_special_tokens=model_args.split_special_tokens,
|
split_special_tokens=model_args.split_special_tokens,
|
||||||
padding_side="right", # training with left-padded tensors in fp16 precision may cause overflow
|
padding_side="right",
|
||||||
**config_kwargs
|
**config_kwargs,
|
||||||
)
|
)
|
||||||
|
patch_tokenizer(tokenizer)
|
||||||
|
|
||||||
if finetuning_args.finetuning_type != "lora" and model_args.checkpoint_dir is not None:
|
config = AutoConfig.from_pretrained(model_args.model_name_or_path, **config_kwargs)
|
||||||
logger.info("Use `model_name_or_path` to specify the model trained with full/freeze method.")
|
patch_config(config, tokenizer, model_args, config_kwargs, is_trainable)
|
||||||
model_to_load = model_args.checkpoint_dir[0]
|
|
||||||
else:
|
|
||||||
model_to_load = model_args.model_name_or_path
|
|
||||||
|
|
||||||
config = AutoConfig.from_pretrained(model_to_load, **config_kwargs)
|
model = None
|
||||||
|
if is_trainable and model_args.use_unsloth:
|
||||||
|
from unsloth import FastLanguageModel # type: ignore
|
||||||
|
|
||||||
# Fix tokenizer (for ChatGLM2 and ChatGLM3)
|
unsloth_kwargs = {
|
||||||
if getattr(config, "model_type", None) == "chatglm":
|
"model_name": model_args.model_name_or_path,
|
||||||
tokenizer._pad = MethodType(PreTrainedTokenizerBase._pad, tokenizer)
|
"max_seq_length": model_args.model_max_length,
|
||||||
|
"dtype": model_args.compute_dtype,
|
||||||
|
"load_in_4bit": model_args.quantization_bit == 4,
|
||||||
|
"token": model_args.hf_hub_token,
|
||||||
|
"device_map": {"": get_current_device()},
|
||||||
|
"rope_scaling": getattr(config, "rope_scaling", None),
|
||||||
|
}
|
||||||
|
try:
|
||||||
|
model, _ = FastLanguageModel.from_pretrained(**unsloth_kwargs)
|
||||||
|
except NotImplementedError:
|
||||||
|
logger.warning("Unsloth does not support model type {}.".format(getattr(config, "model_type", None)))
|
||||||
|
model_args.use_unsloth = False
|
||||||
|
|
||||||
# Set model dtype
|
if model_args.adapter_name_or_path:
|
||||||
if model_args.compute_dtype is None: # priority: bf16 > fp16 > fp32
|
model_args.adapter_name_or_path = None
|
||||||
model_args.compute_dtype = infer_optim_dtype(model_dtype=getattr(config, "torch_dtype", None))
|
logger.warning("Unsloth does not support loading adapters.")
|
||||||
setattr(config, "torch_dtype", model_args.compute_dtype)
|
|
||||||
|
|
||||||
# Fix config (for Qwen)
|
if model is None:
|
||||||
if getattr(config, "model_type", None) == "qwen":
|
model = AutoModelForCausalLM.from_pretrained(
|
||||||
for dtype_name, dtype in [("fp16", torch.float16), ("bf16", torch.bfloat16), ("fp32", torch.float32)]:
|
model_args.model_name_or_path,
|
||||||
setattr(config, dtype_name, getattr(config, "torch_dtype", None) == dtype)
|
config=config,
|
||||||
|
torch_dtype=model_args.compute_dtype,
|
||||||
|
low_cpu_mem_usage=(not is_deepspeed_zero3_enabled()),
|
||||||
|
**config_kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
# Set RoPE scaling
|
patch_model(model, tokenizer, model_args, is_trainable)
|
||||||
if model_args.rope_scaling is not None:
|
register_autoclass(config, model, tokenizer)
|
||||||
if not hasattr(config, "rope_scaling"):
|
|
||||||
logger.warning("Current model does not support RoPE scaling.")
|
|
||||||
else:
|
|
||||||
if is_trainable:
|
|
||||||
if model_args.rope_scaling == "dynamic":
|
|
||||||
logger.warning(
|
|
||||||
"Dynamic NTK may not work well with fine-tuning. "
|
|
||||||
"See: https://github.com/huggingface/transformers/pull/24653"
|
|
||||||
)
|
|
||||||
|
|
||||||
current_max_length = getattr(config, "max_position_embeddings", None)
|
|
||||||
if current_max_length and model_args.model_max_length > current_max_length:
|
|
||||||
scaling_factor = float(math.ceil(model_args.model_max_length / current_max_length))
|
|
||||||
else:
|
|
||||||
logger.warning("Input length is smaller than max length. Consider increase input length.")
|
|
||||||
scaling_factor = 1.0
|
|
||||||
else:
|
|
||||||
scaling_factor = 2.0
|
|
||||||
|
|
||||||
setattr(config, "rope_scaling", {"type": model_args.rope_scaling, "factor": scaling_factor})
|
|
||||||
logger.info("Using {} scaling strategy and setting scaling factor to {}".format(
|
|
||||||
model_args.rope_scaling, scaling_factor
|
|
||||||
))
|
|
||||||
|
|
||||||
# Set FlashAttention-2
|
|
||||||
if model_args.flash_attn:
|
|
||||||
if getattr(config, "model_type", None) == "llama":
|
|
||||||
if is_flash_attn2_available():
|
|
||||||
LlamaModule.LlamaAttention = LlamaPatches.LlamaFlashAttention2
|
|
||||||
LlamaModule.LlamaModel._prepare_decoder_attention_mask = LlamaPatches._prepare_decoder_attention_mask
|
|
||||||
logger.info("Using FlashAttention-2 for faster training and inference.")
|
|
||||||
else:
|
|
||||||
logger.warning("FlashAttention-2 is not installed.")
|
|
||||||
elif getattr(config, "model_type", None) in ["qwen", "Yi"]:
|
|
||||||
logger.info("Current model automatically enables FlashAttention if installed.")
|
|
||||||
else:
|
|
||||||
logger.warning("Current model does not support FlashAttention.")
|
|
||||||
elif is_trainable and model_args.shift_attn and getattr(config, "model_type", None) == "llama":
|
|
||||||
LlamaModule.LlamaAttention = LlamaPatches.LlamaShiftShortAttention
|
|
||||||
logger.warning("Using `--flash_attn` for faster training in large context length.")
|
|
||||||
|
|
||||||
# Set shift short attention (S^2-Attn)
|
|
||||||
if is_trainable and model_args.shift_attn:
|
|
||||||
if getattr(config, "model_type", None) == "llama":
|
|
||||||
setattr(config, "group_size_ratio", 0.25)
|
|
||||||
logger.info("Using shift short attention with group_size_ratio=1/4.")
|
|
||||||
else:
|
|
||||||
logger.warning("Current model does not support shift short attention.")
|
|
||||||
|
|
||||||
# Quantization configurations (using gptq or awq)
|
|
||||||
if getattr(config, "quantization_config", None):
|
|
||||||
if model_args.quantization_bit is not None: # remove bnb quantization
|
|
||||||
model_args.quantization_bit = None
|
|
||||||
config_kwargs["device_map"] = {"": int(os.environ.get("LOCAL_RANK", "0"))}
|
|
||||||
quantization_config = getattr(config, "quantization_config", None)
|
|
||||||
logger.info("Loading {}-bit quantized model.".format(quantization_config.get("bits", -1)))
|
|
||||||
|
|
||||||
# Quantization configurations (using bitsandbytes library)
|
|
||||||
if model_args.quantization_bit is not None:
|
|
||||||
if is_deepspeed_zero3_enabled():
|
|
||||||
raise ValueError("DeepSpeed ZeRO-3 is incompatible with quantization.")
|
|
||||||
|
|
||||||
if model_args.quantization_bit == 8:
|
|
||||||
require_version("bitsandbytes>=0.37.0", "To fix: pip install bitsandbytes>=0.37.0")
|
|
||||||
config_kwargs["quantization_config"] = BitsAndBytesConfig(load_in_8bit=True)
|
|
||||||
|
|
||||||
if model_args.quantization_bit == 4:
|
|
||||||
require_version("bitsandbytes>=0.39.0", "To fix: pip install bitsandbytes>=0.39.0")
|
|
||||||
config_kwargs["quantization_config"] = BitsAndBytesConfig(
|
|
||||||
load_in_4bit=True,
|
|
||||||
bnb_4bit_compute_dtype=model_args.compute_dtype,
|
|
||||||
bnb_4bit_use_double_quant=model_args.double_quantization,
|
|
||||||
bnb_4bit_quant_type=model_args.quantization_type
|
|
||||||
)
|
|
||||||
|
|
||||||
config_kwargs["device_map"] = {"": int(os.environ.get("LOCAL_RANK", "0"))}
|
|
||||||
logger.info("Quantizing model to {} bit.".format(model_args.quantization_bit))
|
|
||||||
|
|
||||||
# Load pre-trained models (without valuehead)
|
|
||||||
model = AutoModelForCausalLM.from_pretrained(
|
|
||||||
model_to_load,
|
|
||||||
config=config,
|
|
||||||
torch_dtype=model_args.compute_dtype,
|
|
||||||
low_cpu_mem_usage=(not is_deepspeed_zero3_enabled()),
|
|
||||||
**config_kwargs
|
|
||||||
)
|
|
||||||
|
|
||||||
# Disable custom generate method (for Qwen and Baichuan2)
|
|
||||||
if isinstance(model, PreTrainedModel) and "GenerationMixin" not in str(model.generate.__func__):
|
|
||||||
model.generate = MethodType(PreTrainedModel.generate, model)
|
|
||||||
|
|
||||||
# Fix LM head (for ChatGLM2 and ChatGLM3)
|
|
||||||
if getattr(config, "model_type", None) == "chatglm":
|
|
||||||
setattr(model, "lm_head", model.transformer.output_layer)
|
|
||||||
setattr(model, "_keys_to_ignore_on_save", ["lm_head.weight"])
|
|
||||||
|
|
||||||
# Register auto class to save the custom code files
|
|
||||||
if isinstance(config, PretrainedConfig) and "AutoConfig" in getattr(config, "auto_map", {}):
|
|
||||||
config.__class__.register_for_auto_class()
|
|
||||||
if isinstance(model, PreTrainedModel) and "AutoModelForCausalLM" in getattr(config, "auto_map", {}):
|
|
||||||
model.__class__.register_for_auto_class()
|
|
||||||
if isinstance(tokenizer, PreTrainedTokenizerBase) and "AutoTokenizer" in tokenizer.init_kwargs.get("auto_map", {}):
|
|
||||||
tokenizer.__class__.register_for_auto_class()
|
|
||||||
|
|
||||||
# Initialize adapters
|
|
||||||
model = prepare_model_for_training(model=model, finetuning_args=finetuning_args) if is_trainable else model
|
|
||||||
model = init_adapter(model, model_args, finetuning_args, is_trainable)
|
model = init_adapter(model, model_args, finetuning_args, is_trainable)
|
||||||
|
|
||||||
# Prepare model with valuehead for RLHF
|
|
||||||
if add_valuehead:
|
if add_valuehead:
|
||||||
model: "AutoModelForCausalLMWithValueHead" = AutoModelForCausalLMWithValueHead.from_pretrained(model)
|
model: "AutoModelForCausalLMWithValueHead" = AutoModelForCausalLMWithValueHead.from_pretrained(model)
|
||||||
setattr(model, "_keys_to_ignore_on_save", [name for name, _ in model.named_parameters() if "pretrained_model" in name])
|
patch_valuehead_model(model)
|
||||||
setattr(model, "tie_weights", MethodType(lambda _: None, model)) # use empty method
|
|
||||||
vhead_path = (
|
if model_args.adapter_name_or_path is not None:
|
||||||
model_args.checkpoint_dir[-1] if model_args.checkpoint_dir is not None else model_args.model_name_or_path
|
vhead_path = model_args.adapter_name_or_path[-1]
|
||||||
)
|
else:
|
||||||
|
vhead_path = model_args.model_name_or_path
|
||||||
|
|
||||||
vhead_params = load_valuehead_params(vhead_path, model_args)
|
vhead_params = load_valuehead_params(vhead_path, model_args)
|
||||||
if vhead_params is not None:
|
if vhead_params is not None:
|
||||||
model.load_state_dict(vhead_params, strict=False)
|
model.load_state_dict(vhead_params, strict=False)
|
||||||
logger.info("Loaded valuehead from checkpoint: {}".format(vhead_path))
|
logger.info("Loaded valuehead from checkpoint: {}".format(vhead_path))
|
||||||
|
|
||||||
# Prepare model for inference
|
|
||||||
if not is_trainable:
|
if not is_trainable:
|
||||||
model.requires_grad_(False) # fix all model params
|
model.requires_grad_(False)
|
||||||
model = model.to(model_args.compute_dtype) if model_args.quantization_bit is None else model
|
model = model.to(model_args.compute_dtype) if not getattr(model, "quantization_method", None) else model
|
||||||
model.eval()
|
model.eval()
|
||||||
else:
|
else:
|
||||||
model.train()
|
model.train()
|
||||||
|
|
||||||
trainable_params, all_param = count_parameters(model)
|
trainable_params, all_param = count_parameters(model)
|
||||||
logger.info("trainable params: {:d} || all params: {:d} || trainable%: {:.4f}".format(
|
logger.info(
|
||||||
trainable_params, all_param, 100 * trainable_params / all_param
|
"trainable params: {:d} || all params: {:d} || trainable%: {:.4f}".format(
|
||||||
))
|
trainable_params, all_param, 100 * trainable_params / all_param
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
if not is_trainable:
|
if not is_trainable:
|
||||||
logger.info("This IS expected that the trainable params is 0 if you are using model for inference only.")
|
logger.info("This IS expected that the trainable params is 0 if you are using model for inference only.")
|
||||||
|
|
||||||
|
if model_args.print_param_status:
|
||||||
|
for name, param in model.named_parameters():
|
||||||
|
print(
|
||||||
|
"name: {}, dtype: {}, device: {}, trainable: {}".format(
|
||||||
|
name, param.dtype, param.device, param.requires_grad
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
return model, tokenizer
|
return model, tokenizer
|
||||||
|
|||||||
@@ -1,205 +0,0 @@
|
|||||||
import os
|
|
||||||
import torch
|
|
||||||
import datasets
|
|
||||||
import transformers
|
|
||||||
from typing import Any, Dict, Optional, Tuple
|
|
||||||
from transformers import HfArgumentParser, Seq2SeqTrainingArguments
|
|
||||||
from transformers.trainer_utils import get_last_checkpoint
|
|
||||||
|
|
||||||
from llmtuner.extras.logging import get_logger
|
|
||||||
from llmtuner.extras.misc import parse_args
|
|
||||||
from llmtuner.hparams import (
|
|
||||||
ModelArguments,
|
|
||||||
DataArguments,
|
|
||||||
EvaluationArguments,
|
|
||||||
FinetuningArguments,
|
|
||||||
GeneratingArguments
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
logger = get_logger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
_TRAIN_ARGS = [
|
|
||||||
ModelArguments, DataArguments, Seq2SeqTrainingArguments, FinetuningArguments, GeneratingArguments
|
|
||||||
]
|
|
||||||
_TRAIN_CLS = Tuple[
|
|
||||||
ModelArguments, DataArguments, Seq2SeqTrainingArguments, FinetuningArguments, GeneratingArguments
|
|
||||||
]
|
|
||||||
_INFER_ARGS = [
|
|
||||||
ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments
|
|
||||||
]
|
|
||||||
_INFER_CLS = Tuple[
|
|
||||||
ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments
|
|
||||||
]
|
|
||||||
_EVAL_ARGS = [
|
|
||||||
ModelArguments, DataArguments, EvaluationArguments, FinetuningArguments
|
|
||||||
]
|
|
||||||
_EVAL_CLS = Tuple[
|
|
||||||
ModelArguments, DataArguments, EvaluationArguments, FinetuningArguments
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
def _verify_model_args(model_args: "ModelArguments", finetuning_args: "FinetuningArguments") -> None:
|
|
||||||
if model_args.quantization_bit is not None and finetuning_args.finetuning_type != "lora":
|
|
||||||
raise ValueError("Quantization is only compatible with the LoRA method.")
|
|
||||||
|
|
||||||
if (
|
|
||||||
model_args.checkpoint_dir is not None
|
|
||||||
and len(model_args.checkpoint_dir) != 1
|
|
||||||
and finetuning_args.finetuning_type != "lora"
|
|
||||||
):
|
|
||||||
raise ValueError("Multiple checkpoints are only available for LoRA tuning.")
|
|
||||||
|
|
||||||
|
|
||||||
def parse_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
|
|
||||||
parser = HfArgumentParser(_TRAIN_ARGS)
|
|
||||||
return parse_args(parser, args)
|
|
||||||
|
|
||||||
|
|
||||||
def parse_infer_args(args: Optional[Dict[str, Any]] = None) -> _INFER_CLS:
|
|
||||||
parser = HfArgumentParser(_INFER_ARGS)
|
|
||||||
return parse_args(parser, args)
|
|
||||||
|
|
||||||
|
|
||||||
def parse_eval_args(args: Optional[Dict[str, Any]] = None) -> _EVAL_CLS:
|
|
||||||
parser = HfArgumentParser(_EVAL_ARGS)
|
|
||||||
return parse_args(parser, args)
|
|
||||||
|
|
||||||
|
|
||||||
def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
|
|
||||||
model_args, data_args, training_args, finetuning_args, generating_args = parse_train_args(args)
|
|
||||||
|
|
||||||
# Setup logging
|
|
||||||
if training_args.should_log:
|
|
||||||
# The default of training_args.log_level is passive, so we set log level at info here to have that default.
|
|
||||||
transformers.utils.logging.set_verbosity_info()
|
|
||||||
|
|
||||||
log_level = training_args.get_process_log_level()
|
|
||||||
datasets.utils.logging.set_verbosity(log_level)
|
|
||||||
transformers.utils.logging.set_verbosity(log_level)
|
|
||||||
transformers.utils.logging.enable_default_handler()
|
|
||||||
transformers.utils.logging.enable_explicit_format()
|
|
||||||
|
|
||||||
# Check arguments
|
|
||||||
data_args.init_for_training(training_args.seed)
|
|
||||||
|
|
||||||
if finetuning_args.stage != "pt" and data_args.template is None:
|
|
||||||
raise ValueError("Please specify which `template` to use.")
|
|
||||||
|
|
||||||
if finetuning_args.stage != "sft" and training_args.predict_with_generate:
|
|
||||||
raise ValueError("`predict_with_generate` cannot be set as True except SFT.")
|
|
||||||
|
|
||||||
if finetuning_args.stage == "sft" and training_args.do_predict and not training_args.predict_with_generate:
|
|
||||||
raise ValueError("Please enable `predict_with_generate` to save model predictions.")
|
|
||||||
|
|
||||||
if finetuning_args.stage in ["rm", "ppo"] and training_args.load_best_model_at_end:
|
|
||||||
raise ValueError("RM and PPO stages do not support `load_best_model_at_end`.")
|
|
||||||
|
|
||||||
if finetuning_args.stage == "ppo" and not training_args.do_train:
|
|
||||||
raise ValueError("PPO training does not support evaluation, use the SFT stage to evaluate models.")
|
|
||||||
|
|
||||||
if finetuning_args.stage in ["rm", "dpo"] and (not all([data_attr.ranking for data_attr in data_args.dataset_list])):
|
|
||||||
raise ValueError("Please use ranked datasets for reward modeling or DPO training.")
|
|
||||||
|
|
||||||
if finetuning_args.stage == "ppo" and model_args.shift_attn:
|
|
||||||
raise ValueError("PPO training is incompatible with S^2-Attn.")
|
|
||||||
|
|
||||||
if training_args.max_steps == -1 and data_args.streaming:
|
|
||||||
raise ValueError("Please specify `max_steps` in streaming mode.")
|
|
||||||
|
|
||||||
if training_args.do_train and training_args.predict_with_generate:
|
|
||||||
raise ValueError("`predict_with_generate` cannot be set as True while training.")
|
|
||||||
|
|
||||||
if training_args.do_train and finetuning_args.finetuning_type == "lora" and finetuning_args.lora_target is None:
|
|
||||||
raise ValueError("Please specify `lora_target` in LoRA training.")
|
|
||||||
|
|
||||||
_verify_model_args(model_args, finetuning_args)
|
|
||||||
|
|
||||||
if training_args.do_train and model_args.quantization_bit is not None and (not finetuning_args.upcast_layernorm):
|
|
||||||
logger.warning("We recommend enable `upcast_layernorm` in quantized training.")
|
|
||||||
|
|
||||||
if training_args.do_train and (not training_args.fp16) and (not training_args.bf16):
|
|
||||||
logger.warning("We recommend enable mixed precision training.")
|
|
||||||
|
|
||||||
if (not training_args.do_train) and model_args.quantization_bit is not None:
|
|
||||||
logger.warning("Evaluating model in 4/8-bit mode may cause lower scores.")
|
|
||||||
|
|
||||||
if (not training_args.do_train) and finetuning_args.stage == "dpo" and finetuning_args.ref_model is None:
|
|
||||||
logger.warning("Specify `ref_model` for computing rewards at evaluation.")
|
|
||||||
|
|
||||||
# postprocess training_args
|
|
||||||
if (
|
|
||||||
training_args.local_rank != -1
|
|
||||||
and training_args.ddp_find_unused_parameters is None
|
|
||||||
and finetuning_args.finetuning_type == "lora"
|
|
||||||
):
|
|
||||||
logger.warning("`ddp_find_unused_parameters` needs to be set as False for LoRA in DDP training.")
|
|
||||||
training_args_dict = training_args.to_dict()
|
|
||||||
training_args_dict.update(dict(ddp_find_unused_parameters=False))
|
|
||||||
training_args = Seq2SeqTrainingArguments(**training_args_dict)
|
|
||||||
|
|
||||||
if (
|
|
||||||
training_args.resume_from_checkpoint is None
|
|
||||||
and training_args.do_train
|
|
||||||
and os.path.isdir(training_args.output_dir)
|
|
||||||
and not training_args.overwrite_output_dir
|
|
||||||
):
|
|
||||||
last_checkpoint = get_last_checkpoint(training_args.output_dir)
|
|
||||||
if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0:
|
|
||||||
raise ValueError("Output directory already exists and is not empty. Please set `overwrite_output_dir`.")
|
|
||||||
|
|
||||||
if last_checkpoint is not None:
|
|
||||||
training_args_dict = training_args.to_dict()
|
|
||||||
training_args_dict.update(dict(resume_from_checkpoint=last_checkpoint))
|
|
||||||
training_args = Seq2SeqTrainingArguments(**training_args_dict)
|
|
||||||
logger.info("Resuming training from {}. Change `output_dir` or use `overwrite_output_dir` to avoid.".format(
|
|
||||||
training_args.resume_from_checkpoint
|
|
||||||
))
|
|
||||||
|
|
||||||
if finetuning_args.stage in ["rm", "ppo"] and training_args.resume_from_checkpoint is not None:
|
|
||||||
logger.warning("Add {} to `checkpoint_dir` to resume training from checkpoint.".format(
|
|
||||||
training_args.resume_from_checkpoint
|
|
||||||
))
|
|
||||||
|
|
||||||
# postprocess model_args
|
|
||||||
model_args.compute_dtype = (
|
|
||||||
torch.bfloat16 if training_args.bf16 else (torch.float16 if training_args.fp16 else None)
|
|
||||||
)
|
|
||||||
model_args.model_max_length = data_args.cutoff_len
|
|
||||||
|
|
||||||
# Log on each process the small summary:
|
|
||||||
logger.info("Process rank: {}, device: {}, n_gpu: {}\n distributed training: {}, compute dtype: {}".format(
|
|
||||||
training_args.local_rank, training_args.device, training_args.n_gpu,
|
|
||||||
bool(training_args.local_rank != -1), str(model_args.compute_dtype)
|
|
||||||
))
|
|
||||||
logger.info(f"Training/evaluation parameters {training_args}")
|
|
||||||
|
|
||||||
# Set seed before initializing model.
|
|
||||||
transformers.set_seed(training_args.seed)
|
|
||||||
|
|
||||||
return model_args, data_args, training_args, finetuning_args, generating_args
|
|
||||||
|
|
||||||
|
|
||||||
def get_infer_args(args: Optional[Dict[str, Any]] = None) -> _INFER_CLS:
|
|
||||||
model_args, data_args, finetuning_args, generating_args = parse_infer_args(args)
|
|
||||||
|
|
||||||
if data_args.template is None:
|
|
||||||
raise ValueError("Please specify which `template` to use.")
|
|
||||||
|
|
||||||
_verify_model_args(model_args, finetuning_args)
|
|
||||||
|
|
||||||
return model_args, data_args, finetuning_args, generating_args
|
|
||||||
|
|
||||||
|
|
||||||
def get_eval_args(args: Optional[Dict[str, Any]] = None) -> _EVAL_CLS:
|
|
||||||
model_args, data_args, eval_args, finetuning_args = parse_eval_args(args)
|
|
||||||
|
|
||||||
if data_args.template is None:
|
|
||||||
raise ValueError("Please specify which `template` to use.")
|
|
||||||
|
|
||||||
_verify_model_args(model_args, finetuning_args)
|
|
||||||
|
|
||||||
transformers.set_seed(eval_args.seed)
|
|
||||||
|
|
||||||
return model_args, data_args, eval_args, finetuning_args
|
|
||||||
334
src/llmtuner/model/patcher.py
Normal file
@@ -0,0 +1,334 @@
|
|||||||
|
import math
|
||||||
|
import os
|
||||||
|
import random
|
||||||
|
from contextlib import nullcontext
|
||||||
|
from types import MethodType
|
||||||
|
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from datasets import load_dataset
|
||||||
|
from peft import PeftModel
|
||||||
|
from transformers import BitsAndBytesConfig, GPTQConfig, PreTrainedModel, PreTrainedTokenizerBase
|
||||||
|
from transformers.integrations import is_deepspeed_zero3_enabled
|
||||||
|
from transformers.utils.versions import require_version
|
||||||
|
|
||||||
|
from ..extras.constants import FILEEXT2TYPE, LAYERNORM_NAMES
|
||||||
|
from ..extras.logging import get_logger
|
||||||
|
from ..extras.misc import get_current_device, infer_optim_dtype
|
||||||
|
from ..extras.packages import is_flash_attn2_available
|
||||||
|
from ..extras.patches.llama_patch import apply_llama_patch
|
||||||
|
from ..extras.patches.mixtral_patch import patch_mixtral_replace_moe_impl
|
||||||
|
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from transformers import PretrainedConfig, PreTrainedTokenizer
|
||||||
|
from trl import AutoModelForCausalLMWithValueHead
|
||||||
|
|
||||||
|
from ..hparams import ModelArguments
|
||||||
|
|
||||||
|
|
||||||
|
logger = get_logger(__name__)
|
||||||
|
SUPPORTED_CLASS_FOR_S2ATTN = ["llama"]
|
||||||
|
|
||||||
|
|
||||||
|
def _noisy_mean_initialization(embed_weight: torch.Tensor, num_new_tokens: int):
|
||||||
|
embedding_dim = embed_weight.size(1)
|
||||||
|
avg_weight = embed_weight[:-num_new_tokens].mean(dim=0, keepdim=True)
|
||||||
|
noise_weight = torch.empty_like(embed_weight[-num_new_tokens:])
|
||||||
|
noise_weight.normal_(mean=0, std=(1.0 / math.sqrt(embedding_dim)))
|
||||||
|
embed_weight[-num_new_tokens:] = avg_weight + noise_weight
|
||||||
|
|
||||||
|
|
||||||
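The _noisy_mean_initialization helper above seeds the rows of newly added tokens with the mean of the existing embedding rows plus Gaussian noise scaled by 1/sqrt(dim). A standalone sketch of the same idea, with made-up sizes:

import math
import torch

# Standalone illustration of the noisy-mean idea used above (sizes are made up).
embed_weight = torch.randn(1000, 64)  # assumed vocab of 1000, hidden size 64
num_new_tokens = 16                   # assumed number of newly added tokens

avg_weight = embed_weight[:-num_new_tokens].mean(dim=0, keepdim=True)
noise_weight = torch.empty_like(embed_weight[-num_new_tokens:])
noise_weight.normal_(mean=0, std=(1.0 / math.sqrt(embed_weight.size(1))))
embed_weight[-num_new_tokens:] = avg_weight + noise_weight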
|
def _resize_embedding_layer(model: "PreTrainedModel", tokenizer: "PreTrainedTokenizer") -> None:
|
||||||
|
r"""
|
||||||
|
Resize token embeddings.
|
||||||
|
"""
|
||||||
|
if is_deepspeed_zero3_enabled():
|
||||||
|
import deepspeed # type: ignore
|
||||||
|
|
||||||
|
params = [model.get_input_embeddings().weight]
|
||||||
|
if model.get_output_embeddings() is not None and not model.config.tie_word_embeddings:
|
||||||
|
params.append(model.get_output_embeddings().weight)
|
||||||
|
|
||||||
|
context_maybe_zero3 = deepspeed.zero.GatheredParameters(params, modifier_rank=0)
|
||||||
|
else:
|
||||||
|
context_maybe_zero3 = nullcontext()
|
||||||
|
|
||||||
|
with context_maybe_zero3:
|
||||||
|
current_embedding_size = model.get_input_embeddings().weight.size(0)
|
||||||
|
|
||||||
|
if len(tokenizer) > current_embedding_size:
|
||||||
|
if not isinstance(model.get_output_embeddings(), torch.nn.Linear):
|
||||||
|
logger.warning("Current model does not support resizing token embeddings.")
|
||||||
|
return
|
||||||
|
|
||||||
|
model.resize_token_embeddings(len(tokenizer), pad_to_multiple_of=64)
|
||||||
|
with context_maybe_zero3:
|
||||||
|
new_embedding_size = model.get_input_embeddings().weight.size(0)
|
||||||
|
num_new_tokens = new_embedding_size - current_embedding_size
|
||||||
|
_noisy_mean_initialization(model.get_input_embeddings().weight.data, num_new_tokens)
|
||||||
|
_noisy_mean_initialization(model.get_output_embeddings().weight.data, num_new_tokens)
|
||||||
|
|
||||||
|
logger.info("Resized token embeddings from {} to {}.".format(current_embedding_size, new_embedding_size))
|
||||||
|
|
||||||
|
|
||||||
|
def _get_quantization_dataset(tokenizer: "PreTrainedTokenizer", model_args: "ModelArguments") -> List[str]:
|
||||||
|
r"""
|
||||||
|
Inspired by: https://github.com/huggingface/optimum/blob/v1.16.0/optimum/gptq/data.py#L133
|
||||||
|
TODO: remove tokenizer.decode() https://github.com/huggingface/optimum/pull/1600
|
||||||
|
"""
|
||||||
|
if os.path.isfile(model_args.export_quantization_dataset):
|
||||||
|
data_path = FILEEXT2TYPE.get(model_args.export_quantization_dataset.split(".")[-1], None)
|
||||||
|
data_files = model_args.export_quantization_dataset
|
||||||
|
else:
|
||||||
|
data_path = model_args.export_quantization_dataset
|
||||||
|
data_files = None
|
||||||
|
|
||||||
|
dataset = load_dataset(path=data_path, data_files=data_files, split="train", cache_dir=model_args.cache_dir)
|
||||||
|
maxlen = model_args.export_quantization_maxlen
|
||||||
|
|
||||||
|
samples = []
|
||||||
|
for _ in range(model_args.export_quantization_nsamples):
|
||||||
|
while True:
|
||||||
|
sample_idx = random.randint(0, len(dataset) - 1)
|
||||||
|
sample: Dict[str, torch.Tensor] = tokenizer(dataset[sample_idx]["text"], return_tensors="pt")
|
||||||
|
if sample["input_ids"].size(1) >= maxlen:
|
||||||
|
break # TODO: fix large maxlen
|
||||||
|
|
||||||
|
word_idx = random.randint(0, sample["input_ids"].size(1) - maxlen - 1)
|
||||||
|
input_ids = sample["input_ids"][:, word_idx : word_idx + maxlen]
|
||||||
|
samples.append(tokenizer.decode(input_ids[0].tolist(), skip_special_tokens=True))
|
||||||
|
|
||||||
|
return samples
|
||||||
|
|
||||||
|
|
||||||
|
def _configure_attn_implementation(model_args: "ModelArguments", config_kwargs: Dict[str, Any]) -> None:
|
||||||
|
if model_args.flash_attn:
|
||||||
|
if is_flash_attn2_available():
|
||||||
|
config_kwargs["attn_implementation"] = "flash_attention_2"
|
||||||
|
logger.info("Using FlashAttention-2 for faster training and inference.")
|
||||||
|
else:
|
||||||
|
logger.warning("FlashAttention2 is not installed.")
|
||||||
|
config_kwargs["attn_implementation"] = None
|
||||||
|
else:
|
||||||
|
config_kwargs["attn_implementation"] = "eager"
|
||||||
|
|
||||||
|
|
||||||
|
def _configure_rope(config: "PretrainedConfig", model_args: "ModelArguments", is_trainable: bool) -> None:
|
||||||
|
if not hasattr(config, "rope_scaling"):
|
||||||
|
logger.warning("Current model does not support RoPE scaling.")
|
||||||
|
return
|
||||||
|
|
||||||
|
if is_trainable:
|
||||||
|
if model_args.rope_scaling == "dynamic":
|
||||||
|
logger.warning(
|
||||||
|
"Dynamic NTK scaling may not work well with fine-tuning. "
|
||||||
|
"See: https://github.com/huggingface/transformers/pull/24653"
|
||||||
|
)
|
||||||
|
|
||||||
|
current_max_length = getattr(config, "max_position_embeddings", None)
|
||||||
|
if current_max_length and model_args.model_max_length > current_max_length:
|
||||||
|
scaling_factor = float(math.ceil(model_args.model_max_length / current_max_length))
|
||||||
|
else:
|
||||||
|
logger.warning("Input length is smaller than max length. Consider increase input length.")
|
||||||
|
scaling_factor = 1.0
|
||||||
|
else:
|
||||||
|
scaling_factor = 2.0
|
||||||
|
|
||||||
|
setattr(config, "rope_scaling", {"type": model_args.rope_scaling, "factor": scaling_factor})
|
||||||
|
logger.info(
|
||||||
|
"Using {} scaling strategy and setting scaling factor to {}".format(model_args.rope_scaling, scaling_factor)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _configure_longlora(config: "PretrainedConfig") -> None:
|
||||||
|
if getattr(config, "model_type", None) in SUPPORTED_CLASS_FOR_S2ATTN:
|
||||||
|
setattr(config, "group_size_ratio", 0.25)
|
||||||
|
apply_llama_patch()
|
||||||
|
logger.info("Using shift short attention with group_size_ratio=1/4.")
|
||||||
|
else:
|
||||||
|
logger.warning("Current model does not support shift short attention.")
|
||||||
|
|
||||||
|
|
||||||
|
def _configure_quantization(
|
||||||
|
config: "PretrainedConfig",
|
||||||
|
tokenizer: "PreTrainedTokenizer",
|
||||||
|
model_args: "ModelArguments",
|
||||||
|
config_kwargs: Dict[str, Any],
|
||||||
|
) -> None:
|
||||||
|
r"""
|
||||||
|
Priority: PTQ-quantized (training) > AutoGPTQ (export) > Bitsandbytes (training)
|
||||||
|
"""
|
||||||
|
if getattr(config, "quantization_config", None): # gptq
|
||||||
|
if is_deepspeed_zero3_enabled():
|
||||||
|
raise ValueError("DeepSpeed ZeRO-3 is incompatible with quantization.")
|
||||||
|
|
||||||
|
config_kwargs["device_map"] = {"": get_current_device()}
|
||||||
|
quantization_config: Dict[str, Any] = getattr(config, "quantization_config", None)
|
||||||
|
if quantization_config.get("quant_method", None) == "gptq" and quantization_config.get("bits", -1) == 4:
|
||||||
|
quantization_config["use_exllama"] = False # disable exllama
|
||||||
|
|
||||||
|
if quantization_config.get("quant_method", None) == "aqlm":
|
||||||
|
quantization_config["bits"] = 2
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
"Loading {}-bit {}-quantized model.".format(
|
||||||
|
quantization_config.get("bits", "?"), quantization_config.get("quant_method", None)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
elif model_args.export_quantization_bit is not None: # auto-gptq
|
||||||
|
require_version("optimum>=1.16.0", "To fix: pip install optimum>=1.16.0")
|
||||||
|
require_version("auto_gptq>=0.5.0", "To fix: pip install auto_gptq>=0.5.0")
|
||||||
|
from accelerate.utils import get_max_memory
|
||||||
|
|
||||||
|
if getattr(config, "model_type", None) == "chatglm":
|
||||||
|
raise ValueError("ChatGLM model is not supported.")
|
||||||
|
|
||||||
|
config_kwargs["quantization_config"] = GPTQConfig(
|
||||||
|
bits=model_args.export_quantization_bit,
|
||||||
|
tokenizer=tokenizer,
|
||||||
|
dataset=_get_quantization_dataset(tokenizer, model_args),
|
||||||
|
)
|
||||||
|
config_kwargs["device_map"] = "auto"
|
||||||
|
config_kwargs["max_memory"] = get_max_memory()
|
||||||
|
logger.info("Quantizing model to {} bit.".format(model_args.export_quantization_bit))
|
||||||
|
|
||||||
|
elif model_args.quantization_bit is not None: # bnb
|
||||||
|
if is_deepspeed_zero3_enabled():
|
||||||
|
raise ValueError("DeepSpeed ZeRO-3 is incompatible with quantization.")
|
||||||
|
|
||||||
|
if model_args.quantization_bit == 8:
|
||||||
|
require_version("bitsandbytes>=0.37.0", "To fix: pip install bitsandbytes>=0.37.0")
|
||||||
|
config_kwargs["quantization_config"] = BitsAndBytesConfig(load_in_8bit=True)
|
||||||
|
|
||||||
|
elif model_args.quantization_bit == 4:
|
||||||
|
require_version("bitsandbytes>=0.39.0", "To fix: pip install bitsandbytes>=0.39.0")
|
||||||
|
config_kwargs["quantization_config"] = BitsAndBytesConfig(
|
||||||
|
load_in_4bit=True,
|
||||||
|
bnb_4bit_compute_dtype=model_args.compute_dtype,
|
||||||
|
bnb_4bit_use_double_quant=model_args.double_quantization,
|
||||||
|
bnb_4bit_quant_type=model_args.quantization_type,
|
||||||
|
)
|
||||||
|
|
||||||
|
config_kwargs["device_map"] = {"": get_current_device()}
|
||||||
|
logger.info("Quantizing model to {} bit.".format(model_args.quantization_bit))
|
||||||
|
|
||||||
|
|
||||||
|
def _prepare_model_for_training(
|
||||||
|
model: "PreTrainedModel", model_args: "ModelArguments", output_layer_name: Optional[str] = "lm_head"
|
||||||
|
) -> None:
|
||||||
|
r"""
|
||||||
|
Includes:
|
||||||
|
(1) cast the layernorm in fp32
|
||||||
|
(2) make output embedding layer require grads
|
||||||
|
(3) add the upcasting of the lm_head in fp32
|
||||||
|
Inspired by: https://github.com/huggingface/peft/blob/v0.7.1/src/peft/utils/other.py#L72
|
||||||
|
"""
|
||||||
|
if model_args.upcast_layernorm:
|
||||||
|
for name, param in model.named_parameters():
|
||||||
|
if param.ndim == 1 and any(ln_name in name for ln_name in LAYERNORM_NAMES):
|
||||||
|
param.data = param.data.to(torch.float32)
|
||||||
|
logger.info("Upcasting layernorm weights in float32.")
|
||||||
|
|
||||||
|
if not model_args.disable_gradient_checkpointing:
|
||||||
|
if not getattr(model, "supports_gradient_checkpointing", False):
|
||||||
|
logger.warning("Current model does not support gradient checkpointing.")
|
||||||
|
else:
|
||||||
|
# use_reentrant=False might increase VRAM usage (have not been empirically verified yet)
|
||||||
|
# According to: https://github.com/huggingface/transformers/issues/28339
|
||||||
|
model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": True})
|
||||||
|
model.enable_input_require_grads()
|
||||||
|
model.config.use_cache = False # turn off when gradient checkpointing is enabled
|
||||||
|
logger.info("Gradient checkpointing enabled.")
|
||||||
|
|
||||||
|
if hasattr(model, output_layer_name) and model_args.upcast_lmhead_output:
|
||||||
|
|
||||||
|
def fp32_forward_post_hook(module: torch.nn.Module, args: Tuple[torch.Tensor], output: torch.Tensor):
|
||||||
|
return output.to(torch.float32)
|
||||||
|
|
||||||
|
output_layer = getattr(model, output_layer_name)
|
||||||
|
if isinstance(output_layer, torch.nn.Linear) and output_layer.weight.dtype != torch.float32:
|
||||||
|
output_layer.register_forward_hook(fp32_forward_post_hook)
|
||||||
|
|
||||||
|
|
||||||
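The lm_head hook registered above simply upcasts the layer's output tensor; an isolated, runnable sketch of the same mechanism (layer sizes are arbitrary):

import torch

# Isolated illustration of the fp32 post-forward hook used above (sizes are arbitrary).
lm_head = torch.nn.Linear(64, 1000).to(torch.bfloat16)

def fp32_forward_post_hook(module: torch.nn.Module, args, output: torch.Tensor):
    return output.to(torch.float32)

lm_head.register_forward_hook(fp32_forward_post_hook)
logits = lm_head(torch.randn(2, 64, dtype=torch.bfloat16))
print(logits.dtype)  # torch.float32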
|
def patch_tokenizer(tokenizer: "PreTrainedTokenizer") -> None:
|
||||||
|
if "PreTrainedTokenizerBase" not in str(tokenizer._pad.__func__):
|
||||||
|
tokenizer._pad = MethodType(PreTrainedTokenizerBase._pad, tokenizer)
|
||||||
|
|
||||||
|
|
||||||
|
def patch_config(
|
||||||
|
config: "PretrainedConfig",
|
||||||
|
tokenizer: "PreTrainedTokenizer",
|
||||||
|
model_args: "ModelArguments",
|
||||||
|
config_kwargs: Dict[str, Any],
|
||||||
|
is_trainable: bool,
|
||||||
|
) -> None:
|
||||||
|
if model_args.compute_dtype is None: # priority: bf16 > fp16 > fp32
|
||||||
|
model_args.compute_dtype = infer_optim_dtype(model_dtype=getattr(config, "torch_dtype", None))
|
||||||
|
|
||||||
|
if getattr(config, "model_type", None) == "qwen":
|
||||||
|
for dtype_name, dtype in [("fp16", torch.float16), ("bf16", torch.bfloat16), ("fp32", torch.float32)]:
|
||||||
|
setattr(config, dtype_name, model_args.compute_dtype == dtype)
|
||||||
|
|
||||||
|
_configure_attn_implementation(model_args, config_kwargs)
|
||||||
|
|
||||||
|
if model_args.rope_scaling is not None:
|
||||||
|
_configure_rope(config, model_args, is_trainable)
|
||||||
|
|
||||||
|
if is_trainable and model_args.shift_attn:
|
||||||
|
_configure_longlora(config)
|
||||||
|
|
||||||
|
_configure_quantization(config, tokenizer, model_args, config_kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
def patch_model(
|
||||||
|
model: "PreTrainedModel", tokenizer: "PreTrainedTokenizer", model_args: "ModelArguments", is_trainable: bool
|
||||||
|
) -> None:
|
||||||
|
if "GenerationMixin" not in str(model.generate.__func__):
|
||||||
|
model.generate = MethodType(PreTrainedModel.generate, model)
|
||||||
|
|
||||||
|
if getattr(model.config, "model_type", None) == "chatglm":
|
||||||
|
setattr(model, "lm_head", model.transformer.output_layer)
|
||||||
|
setattr(model, "_keys_to_ignore_on_save", ["lm_head.weight"])
|
||||||
|
|
||||||
|
if model_args.resize_vocab:
|
||||||
|
_resize_embedding_layer(model, tokenizer)
|
||||||
|
|
||||||
|
if is_trainable:
|
||||||
|
_prepare_model_for_training(model, model_args)
|
||||||
|
|
||||||
|
if getattr(model.config, "model_type", None) == "mixtral" and is_deepspeed_zero3_enabled():
|
||||||
|
require_version("deepspeed>=0.13.0", "To fix: pip install deepspeed>=0.13.0")
|
||||||
|
from deepspeed.utils import set_z3_leaf_modules # type: ignore
|
||||||
|
from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock
|
||||||
|
|
||||||
|
set_z3_leaf_modules(model, [MixtralSparseMoeBlock])
|
||||||
|
|
||||||
|
if is_trainable:
|
||||||
|
patch_mixtral_replace_moe_impl()
|
||||||
|
|
||||||
|
try:
|
||||||
|
model.add_model_tags(["llama-factory"])
|
||||||
|
except Exception:
|
||||||
|
logger.warning("Cannot properly tag the model.")
|
||||||
|
|
||||||
|
|
||||||
|
def patch_valuehead_model(model: "AutoModelForCausalLMWithValueHead") -> None:
|
||||||
|
def tie_weights(self: "AutoModelForCausalLMWithValueHead") -> None:
|
||||||
|
if isinstance(self.pretrained_model, PreTrainedModel):
|
||||||
|
self.pretrained_model.tie_weights()
|
||||||
|
|
||||||
|
def get_input_embeddings(self: "AutoModelForCausalLMWithValueHead") -> torch.nn.Module:
|
||||||
|
if isinstance(self.pretrained_model, PreTrainedModel):
|
||||||
|
return self.pretrained_model.get_input_embeddings()
|
||||||
|
|
||||||
|
def create_or_update_model_card(self: "AutoModelForCausalLMWithValueHead", output_dir: str) -> None:
|
||||||
|
if isinstance(self.pretrained_model, PeftModel):
|
||||||
|
self.pretrained_model.create_or_update_model_card(output_dir)
|
||||||
|
|
||||||
|
ignore_modules = [name for name, _ in model.named_parameters() if "pretrained_model" in name]
|
||||||
|
setattr(model, "_keys_to_ignore_on_save", ignore_modules)
|
||||||
|
setattr(model, "tie_weights", MethodType(tie_weights, model))
|
||||||
|
setattr(model, "get_input_embeddings", MethodType(get_input_embeddings, model))
|
||||||
|
setattr(model, "create_or_update_model_card", MethodType(create_or_update_model_card, model))
|
||||||
@@ -1,17 +1,19 @@
-import torch
 import inspect
-from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple
+from typing import TYPE_CHECKING, Dict, List
+
+import torch
+from transformers import PreTrainedModel
 from transformers.utils import cached_file
-from transformers.trainer import WEIGHTS_NAME, SAFE_WEIGHTS_NAME
 
-from llmtuner.extras.constants import LAYERNORM_NAMES
-from llmtuner.extras.logging import get_logger
-from llmtuner.hparams import ModelArguments, FinetuningArguments
+from ..extras.constants import V_HEAD_SAFE_WEIGHTS_NAME, V_HEAD_WEIGHTS_NAME
+from ..extras.logging import get_logger
+from ..extras.misc import get_current_device
+
 
 if TYPE_CHECKING:
-    from transformers.modeling_utils import PreTrainedModel
-    from llmtuner.hparams import DataArguments
+    from transformers import PretrainedConfig, PreTrainedTokenizer
+
+    from ..hparams import ModelArguments
 
 
 logger = get_logger(__name__)
@@ -19,27 +21,32 @@ logger = get_logger(__name__)
 
 def dispatch_model(model: "PreTrainedModel") -> "PreTrainedModel":
     r"""
-    Dispatches a pre-trained model to GPUs with balanced memory.
-    Borrowed from: https://github.com/huggingface/transformers/blob/v4.31.0/src/transformers/modeling_utils.py#L2803
+    Dispatches a pre-trained model to GPUs with balanced memory when the GPU is available.
+    Borrowed from: https://github.com/huggingface/transformers/blob/v4.36.2/src/transformers/modeling_utils.py#L3570
     """
     if getattr(model, "quantization_method", None): # already set on current device
         return model
 
-    if torch.cuda.device_count() > 1 and getattr(model.config, "model_type", None) != "chatglm":
+    if (
+        torch.cuda.device_count() > 1
+        and isinstance(model, PreTrainedModel)
+        and model._no_split_modules is not None
+        and model.config.model_type != "chatglm"
+    ):
         from accelerate import dispatch_model
-        from accelerate.utils import infer_auto_device_map, get_balanced_memory
+        from accelerate.utils import get_balanced_memory, infer_auto_device_map
 
-        if model._no_split_modules is None:
-            raise ValueError("The model class needs to implement the `_no_split_modules` attribute.")
-
-        kwargs = {"dtype": model.dtype, "no_split_module_classes": model._no_split_modules}
+        kwargs = {"dtype": model.dtype, "no_split_module_classes": model._get_no_split_modules("auto")}
         max_memory = get_balanced_memory(model, **kwargs)
         # Make sure tied weights are tied before creating the device map.
         model.tie_weights()
         device_map = infer_auto_device_map(model, max_memory=max_memory, **kwargs)
-        return dispatch_model(model, device_map)
+        device_map_kwargs = {"device_map": device_map, "offload_dir": "offload"}
+        if "skip_keys" in inspect.signature(dispatch_model).parameters:
+            device_map_kwargs["skip_keys"] = model._skip_keys_device_placement
+        return dispatch_model(model, **device_map_kwargs)
     else:
-        return model.cuda()
+        return model.to(device=get_current_device())
 
 
 def find_all_linear_modules(model: "PreTrainedModel") -> List[str]:
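A hedged usage sketch for the rewritten dispatch_model above; the llmtuner import path and the small stand-in checkpoint are assumptions:

# Hedged usage sketch; the import path and the stand-in checkpoint are assumptions.
from transformers import AutoModelForCausalLM
from llmtuner.model.utils import dispatch_model

model = AutoModelForCausalLM.from_pretrained("gpt2")  # small stand-in model
model = dispatch_model(model)  # >1 GPU: balanced device_map; otherwise: moved to the current device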
@@ -51,6 +58,7 @@ def find_all_linear_modules(model: "PreTrainedModel") -> List[str]:
         linear_cls = torch.nn.Linear
     elif quantization_method == "bitsandbytes":
         import bitsandbytes as bnb
+
         linear_cls = bnb.nn.Linear4bit if getattr(model, "is_loaded_in_4bit", False) else bnb.nn.Linear8bitLt
     else:
         raise ValueError("Finding linear modules for {} models is not supported.".format(quantization_method))
@@ -61,123 +69,72 @@ def find_all_linear_modules(model: "PreTrainedModel") -> List[str]:
|
|||||||
|
|
||||||
module_names = set()
|
module_names = set()
|
||||||
for name, module in model.named_modules():
|
for name, module in model.named_modules():
|
||||||
if (
|
if isinstance(module, linear_cls) and not any(output_layer in name for output_layer in output_layer_names):
|
||||||
isinstance(module, linear_cls)
|
|
||||||
and not any([output_layer in name for output_layer in output_layer_names])
|
|
||||||
):
|
|
||||||
module_names.add(name.split(".")[-1])
|
module_names.add(name.split(".")[-1])
|
||||||
|
|
||||||
logger.info("Found linear modules: {}".format(",".join(module_names)))
|
logger.info("Found linear modules: {}".format(",".join(module_names)))
|
||||||
return list(module_names)
|
return list(module_names)
|
||||||
|
|
||||||
|
|
||||||
def get_modelcard_args(
|
def find_expanded_modules(model: "PreTrainedModel", target_modules: List[str], num_layer_trainable: int) -> List[str]:
|
||||||
model_args: "ModelArguments",
|
r"""
|
||||||
data_args: "DataArguments",
|
Finds the modules in the expanded blocks to apply lora.
|
||||||
finetuning_args: "FinetuningArguments"
|
"""
|
||||||
) -> Dict[str, Any]:
|
num_layers = getattr(model.config, "num_hidden_layers", None)
|
||||||
return {
|
if not num_layers:
|
||||||
"tasks": "text-generation",
|
raise ValueError("Model was not supported.")
|
||||||
"license": "other",
|
|
||||||
"finetuned_from": model_args.model_name_or_path,
|
if num_layers % num_layer_trainable != 0:
|
||||||
"dataset": [dataset.strip() for dataset in data_args.dataset.split(",")],
|
raise ValueError(
|
||||||
"tags": ["llama-factory"] + (["lora"] if finetuning_args.finetuning_type == "lora" else [])
|
"`num_layers` {} should be divisible by `num_layer_trainable` {}.".format(num_layers, num_layer_trainable)
|
||||||
}
|
)
|
||||||
|
|
||||||
|
stride = num_layers // num_layer_trainable
|
||||||
|
trainable_layer_ids = range(stride - 1, num_layers + stride - 1, stride)
|
||||||
|
trainable_layers = [".{:d}.".format(idx) for idx in trainable_layer_ids]
|
||||||
|
module_names = []
|
||||||
|
for name, _ in model.named_modules():
|
||||||
|
if any(target_module in name for target_module in target_modules) and any(
|
||||||
|
trainable_layer in name for trainable_layer in trainable_layers
|
||||||
|
):
|
||||||
|
module_names.append(name)
|
||||||
|
|
||||||
|
logger.info("Apply lora to layers: {}".format(",".join(map(str, trainable_layer_ids))))
|
||||||
|
return module_names
|
||||||
|
|
||||||
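A worked example of the layer-selection arithmetic in find_expanded_modules above, assuming a 32-layer model with 8 trainable (expanded) blocks:

# Assumed: 32 hidden layers, 8 of them receive LoRA in the expanded setting.
num_layers, num_layer_trainable = 32, 8
stride = num_layers // num_layer_trainable                                # 4
trainable_layer_ids = range(stride - 1, num_layers + stride - 1, stride)
print(list(trainable_layer_ids))  # [3, 7, 11, 15, 19, 23, 27, 31]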
|
|
||||||
def load_valuehead_params(
|
def load_valuehead_params(path_or_repo_id: str, model_args: "ModelArguments") -> Dict[str, torch.Tensor]:
|
||||||
path_or_repo_id: str,
|
|
||||||
model_args: "ModelArguments"
|
|
||||||
) -> Dict[str, torch.Tensor]:
|
|
||||||
r"""
|
r"""
|
||||||
Loads value head parameters from Hugging Face Hub or local disk.
|
Loads value head parameters from Hugging Face Hub or local disk.
|
||||||
|
|
||||||
Returns: dict with keys `v_head.summary.weight` and `v_head.summary.bias`.
|
Returns: dict with keys `v_head.summary.weight` and `v_head.summary.bias`.
|
||||||
"""
|
"""
|
||||||
kwargs = {
|
kwargs = {"path_or_repo_id": path_or_repo_id, "cache_dir": model_args.cache_dir, "token": model_args.hf_hub_token}
|
||||||
"path_or_repo_id": path_or_repo_id,
|
|
||||||
"cache_dir": model_args.cache_dir
|
|
||||||
}
|
|
||||||
|
|
||||||
if "token" in inspect.signature(cached_file).parameters:
|
|
||||||
kwargs["token"] = model_args.hf_hub_token
|
|
||||||
elif "use_auth_token" in inspect.signature(cached_file).parameters: # for transformers==4.31.0
|
|
||||||
kwargs["use_auth_token"] = model_args.hf_hub_token
|
|
||||||
else:
|
|
||||||
logger.warning("Ignore `hf_hub_token` since matched parameter is not found.")
|
|
||||||
|
|
||||||
try:
|
|
||||||
vhead_file = cached_file(filename=WEIGHTS_NAME, **kwargs)
|
|
||||||
return torch.load(vhead_file, map_location="cpu")
|
|
||||||
except Exception as err:
|
|
||||||
logger.info("Failed to load {}: {}".format(WEIGHTS_NAME, str(err)))
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from safetensors import safe_open
|
from safetensors import safe_open
|
||||||
vhead_file = cached_file(filename=SAFE_WEIGHTS_NAME, **kwargs)
|
|
||||||
with safe_open(vhead_file, framework="pt", device="cpu") as f:
|
|
||||||
return {
|
|
||||||
"v_head.summary.weight": f.get_tensor("v_head.summary.weight"),
|
|
||||||
"v_head.summary.bias": f.get_tensor("v_head.summary.bias")
|
|
||||||
}
|
|
||||||
except Exception as err:
|
|
||||||
logger.info("Failed to load {}: {}".format(SAFE_WEIGHTS_NAME, str(err)))
|
|
||||||
|
|
||||||
logger.warning("Provided path ({}) does not contain valuehead weights.".format(path_or_repo_id))
|
vhead_file = cached_file(filename=V_HEAD_SAFE_WEIGHTS_NAME, **kwargs)
|
||||||
|
with safe_open(vhead_file, framework="pt", device="cpu") as f:
|
||||||
|
return {key: f.get_tensor(key) for key in f.keys()}
|
||||||
|
except Exception as err:
|
||||||
|
logger.info("Failed to load {}: {}".format(V_HEAD_SAFE_WEIGHTS_NAME, str(err)))
|
||||||
|
|
||||||
|
try:
|
||||||
|
vhead_file = cached_file(filename=V_HEAD_WEIGHTS_NAME, **kwargs)
|
||||||
|
return torch.load(vhead_file, map_location="cpu")
|
||||||
|
except Exception as err:
|
||||||
|
logger.info("Failed to load {}: {}".format(V_HEAD_WEIGHTS_NAME, str(err)))
|
||||||
|
|
||||||
|
logger.info("Provided path ({}) does not contain value head weights.".format(path_or_repo_id))
|
||||||
|
logger.info("Ignore these messages if you are not resuming the training of a value head model.")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
def prepare_model_for_training(
|
def register_autoclass(config: "PretrainedConfig", model: "PreTrainedModel", tokenizer: "PreTrainedTokenizer"):
|
||||||
model: "PreTrainedModel",
|
if "AutoConfig" in getattr(config, "auto_map", {}):
|
||||||
finetuning_args: "FinetuningArguments",
|
config.__class__.register_for_auto_class()
|
||||||
output_layer_name: Optional[str] = "lm_head",
|
if "AutoModelForCausalLM" in getattr(config, "auto_map", {}):
|
||||||
use_gradient_checkpointing: Optional[bool] = True,
|
model.__class__.register_for_auto_class()
|
||||||
layernorm_names: Optional[Set[str]] = LAYERNORM_NAMES
|
if "AutoTokenizer" in tokenizer.init_kwargs.get("auto_map", {}):
|
||||||
) -> "PreTrainedModel":
|
tokenizer.__class__.register_for_auto_class()
|
||||||
r"""
|
|
||||||
Includes:
|
|
||||||
(1) cast the layernorm in fp32
|
|
||||||
(2) make output embedding layer require grads
|
|
||||||
(3) upcast the lm_head to fp32
|
|
||||||
Inspired by: https://github.com/huggingface/peft/blob/v0.2.0/src/peft/utils/other.py#L33
|
|
||||||
"""
|
|
||||||
if finetuning_args.upcast_layernorm:
|
|
||||||
for name, param in model.named_parameters():
|
|
||||||
if param.ndim == 1 and any(ln_name in name for ln_name in layernorm_names):
|
|
||||||
param.data = param.data.to(torch.float32)
|
|
||||||
logger.info("Upcasting weights in layernorm in float32.")
|
|
||||||
|
|
||||||
if finetuning_args.neft_alpha > 1e-6:
|
|
||||||
def neftune_forward_hook(module: torch.nn.Module, args: Tuple[torch.Tensor], output: torch.Tensor):
|
|
||||||
if module.training:
|
|
||||||
dims = torch.tensor(output.size(1) * output.size(2))
|
|
||||||
mag_norm = finetuning_args.neft_alpha / torch.sqrt(dims)
|
|
||||||
output = output + torch.zeros_like(output).uniform_(-mag_norm, mag_norm)
|
|
||||||
return output
|
|
||||||
|
|
||||||
model.get_input_embeddings().register_forward_hook(neftune_forward_hook)
|
|
||||||
logger.info("Using noisy embedding with alpha={:.2f}".format(finetuning_args.neft_alpha))
|
|
||||||
|
|
||||||
if use_gradient_checkpointing and getattr(model, "supports_gradient_checkpointing", False):
|
|
||||||
if hasattr(model, "enable_input_require_grads"):
|
|
||||||
model.enable_input_require_grads()
|
|
||||||
else:
|
|
||||||
def make_inputs_require_grad(module: torch.nn.Module, args: Tuple[torch.Tensor], output: torch.Tensor):
|
|
||||||
output.requires_grad_(True)
|
|
||||||
model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)
|
|
||||||
|
|
||||||
model.gradient_checkpointing_enable()
|
|
||||||
model.config.use_cache = False # turn off when gradient checkpointing is enabled
|
|
||||||
logger.info("Gradient checkpointing enabled.")
|
|
||||||
|
|
||||||
if finetuning_args.finetuning_type != "full" and hasattr(model, output_layer_name):
|
|
||||||
output_layer = getattr(model, output_layer_name)
|
|
||||||
if isinstance(output_layer, torch.nn.Linear):
|
|
||||||
def fp32_forward_pre_hook(module: torch.nn.Module, args: Tuple[torch.Tensor]):
|
|
||||||
return args[0].to(output_layer.weight.dtype)
|
|
||||||
def fp32_forward_post_hook(module: torch.nn.Module, args: Tuple[torch.Tensor], output: torch.Tensor):
|
|
||||||
return output.to(torch.float32)
|
|
||||||
output_layer.register_forward_pre_hook(fp32_forward_pre_hook)
|
|
||||||
output_layer.register_forward_hook(fp32_forward_post_hook)
|
|
||||||
|
|
||||||
return model
|
|
||||||
|
@@ -1 +1,4 @@
-from llmtuner.train.tuner import export_model, run_exp
+from .tuner import export_model, run_exp
+
+
+__all__ = ["export_model", "run_exp"]
@@ -1 +1,4 @@
-from llmtuner.train.dpo.workflow import run_dpo
+from .workflow import run_dpo
+
+
+__all__ = ["run_dpo"]
@@ -1,6 +1,7 @@
-import torch
 from dataclasses import dataclass
 from typing import Any, Dict, List, Sequence, Tuple
+
+import torch
 from transformers import DataCollatorForSeq2Seq
 
 
@@ -20,7 +21,7 @@ class DPODataCollatorWithPadding(DataCollatorForSeq2Seq):
             padded_tensor = self.label_pad_token_id * torch.ones_like(feature)
             padded_tensor[start:end] = feature[start:end]
             padded_labels.append(padded_tensor)
         return torch.stack(padded_labels, dim=0).contiguous()  # in contiguous memory
 
     def __call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, torch.Tensor]:
         r"""
@@ -34,10 +35,12 @@ class DPODataCollatorWithPadding(DataCollatorForSeq2Seq):
         for key in ("chosen_ids", "rejected_ids"):
             for feature in features:
                 prompt_len, answer_len = len(feature["prompt_ids"]), len(feature[key])
-                concatenated_features.append({
-                    "input_ids": feature["prompt_ids"] + feature[key],
-                    "attention_mask": [1] * (prompt_len + answer_len)
-                })
+                concatenated_features.append(
+                    {
+                        "input_ids": feature["prompt_ids"] + feature[key],
+                        "attention_mask": [1] * (prompt_len + answer_len),
+                    }
+                )
                 label_positions.append((prompt_len, answer_len))
 
         batch = self.tokenizer.pad(
|||||||
@@ -1,40 +1,51 @@
-import torch
 from collections import defaultdict
+from contextlib import nullcontext
 from typing import TYPE_CHECKING, Dict, Literal, Optional, Tuple, Union
+
+import torch
 from transformers import BatchEncoding, Trainer
 from trl import DPOTrainer
 from trl.trainer.utils import disable_dropout_in_model
 
-from llmtuner.extras.constants import IGNORE_INDEX
+from ...extras.constants import IGNORE_INDEX
 
+
 if TYPE_CHECKING:
     from transformers import PreTrainedModel
 
 
 class CustomDPOTrainer(DPOTrainer):
-
     def __init__(
         self,
         beta: float,
+        loss_type: Literal["sigmoid", "hinge", "ipo", "kto_pair"],
+        ftx_gamma: float,
         model: Union["PreTrainedModel", torch.nn.Module],
        ref_model: Optional[Union["PreTrainedModel", torch.nn.Module]] = None,
        disable_dropout: Optional[bool] = True,
-        loss_type: Optional[Literal["sigmoid", "hinge"]] = "sigmoid",
-        **kwargs
+        **kwargs,
     ):
         if disable_dropout:
             disable_dropout_in_model(model)
             if ref_model is not None:
                 disable_dropout_in_model(ref_model)
 
-        self.is_encoder_decoder = model.config.is_encoder_decoder
-        self.ref_model = ref_model
-        self.use_dpo_data_collator = True # hack to avoid warning
-        self.generate_during_eval = False # disable at evaluation
+        self.reference_free = False
+        self.use_dpo_data_collator = True  # hack to avoid warning
+        self.generate_during_eval = False  # disable at evaluation
         self.label_pad_token_id = IGNORE_INDEX
         self.padding_value = 0
+        self.is_encoder_decoder = model.config.is_encoder_decoder
+        self.precompute_ref_log_probs = False
+        self._precomputed_train_ref_log_probs = False
+        self._precomputed_eval_ref_log_probs = False
+        self._peft_has_been_casted_to_bf16 = False
+
+        self.ref_model = ref_model
         self.beta = beta
+        self.label_smoothing = 0
         self.loss_type = loss_type
+        self.ftx_gamma = ftx_gamma
         self._stored_metrics = defaultdict(lambda: defaultdict(list))
 
         Trainer.__init__(self, model=model, **kwargs)
@@ -44,32 +55,95 @@ class CustomDPOTrainer(DPOTrainer):
         if ref_model is not None:
             if self.is_deepspeed_enabled:
                 if not (
-                    getattr(ref_model, "is_loaded_in_8bit", False)
-                    or getattr(ref_model, "is_loaded_in_4bit", False)
-                ): # quantized models are already set on the correct device
+                    getattr(ref_model, "is_loaded_in_8bit", False) or getattr(ref_model, "is_loaded_in_4bit", False)
+                ):  # quantized models are already set on the correct device
                     self.ref_model = self._prepare_deepspeed(self.ref_model)
             else:
                 self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True)
 
+    def sft_loss(self, chosen_logits: torch.FloatTensor, chosen_labels: torch.LongTensor) -> torch.Tensor:
+        r"""
+        Computes supervised cross-entropy loss of given labels under the given logits.
+
+        Returns:
+            A tensor of shape (batch_size,) containing the cross-entropy loss of each samples.
+        """
+        all_logps = self.get_batch_logps(chosen_logits, chosen_labels, average_log_prob=True)
+        return -all_logps
+
     def concatenated_forward(
-        self,
-        model: Optional[torch.nn.Module] = None,
-        batch: Optional[Dict[str, torch.Tensor]] = None
+        self, model: "PreTrainedModel", batch: Dict[str, torch.Tensor]
     ) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
-        batch_copied = BatchEncoding({k: v.detach().clone() for k, v in batch.items()}) # avoid error
+        batch_copied = BatchEncoding({k: v.detach().clone() for k, v in batch.items()})  # avoid error
 
         all_logits = model(
-            input_ids=batch_copied["input_ids"],
-            attention_mask=batch_copied["attention_mask"],
-            return_dict=True
+            input_ids=batch_copied["input_ids"], attention_mask=batch_copied["attention_mask"], return_dict=True
         ).logits.to(torch.float32)
 
-        all_logps = self._get_batch_logps(
+        all_logps = self.get_batch_logps(
             all_logits,
             batch["labels"],
-            average_log_prob=False
+            average_log_prob=False,
+            label_pad_token_id=self.label_pad_token_id,
         )
         batch_size = batch["input_ids"].size(0) // 2
         chosen_logps, rejected_logps = all_logps.split(batch_size, dim=0)
         chosen_logits, rejected_logits = all_logits.split(batch_size, dim=0)
         return chosen_logps, rejected_logps, chosen_logits, rejected_logits
+
+    def get_batch_loss_metrics(
+        self,
+        model: "PreTrainedModel",
+        batch: Dict[str, torch.Tensor],
+        train_eval: Optional[Literal["train", "eval"]] = "train",
+    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
+        r"""
+        Computes the DPO loss and other metrics for the given batch of inputs for train or test.
+        """
+        metrics = {}
+        (
+            policy_chosen_logps,
+            policy_rejected_logps,
+            policy_chosen_logits,
+            policy_rejected_logits,
+        ) = self.concatenated_forward(model, batch)
+        with torch.no_grad():
+            if self.ref_model is None:
+                ref_model = self.model
+                ref_context = self.accelerator.unwrap_model(self.model).disable_adapter()
+            else:
+                ref_model = self.ref_model
+                ref_context = nullcontext()
+
+            with ref_context:
+                (
+                    reference_chosen_logps,
+                    reference_rejected_logps,
+                    _,
+                    _,
+                ) = self.concatenated_forward(ref_model, batch)
+
+        losses, chosen_rewards, rejected_rewards = self.dpo_loss(
+            policy_chosen_logps,
+            policy_rejected_logps,
+            reference_chosen_logps,
+            reference_rejected_logps,
+        )
+        if self.ftx_gamma > 1e-6:
+            batch_size = batch["input_ids"].size(0) // 2
+            chosen_labels, _ = batch["labels"].split(batch_size, dim=0)
+            losses += self.ftx_gamma * self.sft_loss(policy_chosen_logits, chosen_labels)
+
+        reward_accuracies = (chosen_rewards > rejected_rewards).float()
+
+        prefix = "eval_" if train_eval == "eval" else ""
+        metrics[f"{prefix}rewards/chosen"] = chosen_rewards.cpu().mean()
+        metrics[f"{prefix}rewards/rejected"] = rejected_rewards.cpu().mean()
+        metrics[f"{prefix}rewards/accuracies"] = reward_accuracies.cpu().mean()
+        metrics[f"{prefix}rewards/margins"] = (chosen_rewards - rejected_rewards).cpu().mean()
+        metrics[f"{prefix}logps/rejected"] = policy_rejected_logps.detach().cpu().mean()
+        metrics[f"{prefix}logps/chosen"] = policy_chosen_logps.detach().cpu().mean()
+        metrics[f"{prefix}logits/rejected"] = policy_rejected_logits.detach().cpu().mean()
+        metrics[f"{prefix}logits/chosen"] = policy_chosen_logits.detach().cpu().mean()
+
+        return losses.mean(), metrics
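Note on the trainer hunk above: when `ftx_gamma` is positive, `get_batch_loss_metrics` mixes the sigmoid DPO objective with an auxiliary supervised term on the chosen responses. A minimal sketch of that combination, with illustrative tensor names that are not part of the patch itself:

```python
import torch
import torch.nn.functional as F


def combined_dpo_sft_loss(
    policy_chosen_logps: torch.Tensor,    # log p_theta(chosen), shape (batch,)
    policy_rejected_logps: torch.Tensor,  # log p_theta(rejected), shape (batch,)
    ref_chosen_logps: torch.Tensor,       # log p_ref(chosen), shape (batch,)
    ref_rejected_logps: torch.Tensor,     # log p_ref(rejected), shape (batch,)
    chosen_avg_logps: torch.Tensor,       # per-sample average log-prob of the chosen tokens
    beta: float = 0.1,
    ftx_gamma: float = 0.0,
) -> torch.Tensor:
    # Sigmoid DPO loss: -log sigmoid(beta * (policy margin - reference margin))
    logits = (policy_chosen_logps - policy_rejected_logps) - (ref_chosen_logps - ref_rejected_logps)
    losses = -F.logsigmoid(beta * logits)
    # Optional SFT term on the chosen responses, mirroring sft_loss() in the hunk
    if ftx_gamma > 1e-6:
        losses = losses + ftx_gamma * (-chosen_avg_logps)
    return losses.mean()
```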
@@ -1,20 +1,23 @@
 # Inspired by: https://github.com/huggingface/trl/blob/main/examples/research_projects/stack_llama_2/scripts/dpo_llama2.py
 
-from typing import TYPE_CHECKING, Optional, List
+from typing import TYPE_CHECKING, List, Optional
 
 from transformers import Seq2SeqTrainingArguments
 
-from llmtuner.data import get_dataset, preprocess_dataset, split_dataset
-from llmtuner.extras.constants import IGNORE_INDEX
-from llmtuner.extras.ploting import plot_loss
-from llmtuner.hparams import ModelArguments
-from llmtuner.model import load_model_and_tokenizer
-from llmtuner.train.dpo.collator import DPODataCollatorWithPadding
-from llmtuner.train.dpo.trainer import CustomDPOTrainer
-from llmtuner.train.utils import create_modelcard_and_push, create_ref_model
+from ...data import get_dataset, split_dataset
+from ...extras.constants import IGNORE_INDEX
+from ...extras.ploting import plot_loss
+from ...hparams import ModelArguments
+from ...model import load_model_and_tokenizer
+from ...train.dpo.collator import DPODataCollatorWithPadding
+from ...train.dpo.trainer import CustomDPOTrainer
+from ...train.utils import create_modelcard_and_push, create_ref_model
 
+
 if TYPE_CHECKING:
     from transformers import TrainerCallback
-    from llmtuner.hparams import DataArguments, FinetuningArguments
+
+    from ...hparams import DataArguments, FinetuningArguments
 
+
 def run_dpo(
@@ -22,38 +25,39 @@ def run_dpo(
     data_args: "DataArguments",
     training_args: "Seq2SeqTrainingArguments",
     finetuning_args: "FinetuningArguments",
-    callbacks: Optional[List["TrainerCallback"]] = None
+    callbacks: Optional[List["TrainerCallback"]] = None,
 ):
-    dataset = get_dataset(model_args, data_args)
     model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args, training_args.do_train)
-    dataset = preprocess_dataset(dataset, tokenizer, data_args, training_args, stage="rm")
+    dataset = get_dataset(tokenizer, model_args, data_args, training_args, stage="rm")
     data_collator = DPODataCollatorWithPadding(
         tokenizer=tokenizer,
         pad_to_multiple_of=8,
-        label_pad_token_id=IGNORE_INDEX if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id
+        label_pad_token_id=IGNORE_INDEX if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id,
     )
 
     # Create reference model
     if finetuning_args.ref_model is None and (not training_args.do_train): # use the model itself
         ref_model = model
     else:
         ref_model = create_ref_model(model_args, finetuning_args)
 
     # Update arguments
     training_args_dict = training_args.to_dict()
     training_args_dict.update(dict(remove_unused_columns=False)) # important for pairwise dataset
     training_args = Seq2SeqTrainingArguments(**training_args_dict)
 
     # Initialize our Trainer
     trainer = CustomDPOTrainer(
         beta=finetuning_args.dpo_beta,
+        loss_type=finetuning_args.dpo_loss,
+        ftx_gamma=finetuning_args.dpo_ftx,
         model=model,
         ref_model=ref_model,
         args=training_args,
         tokenizer=tokenizer,
         data_collator=data_collator,
         callbacks=callbacks,
-        **split_dataset(dataset, data_args, training_args)
+        **split_dataset(dataset, data_args, training_args),
     )
 
     # Training
@@ -69,7 +73,7 @@ def run_dpo(
     # Evaluation
     if training_args.do_eval:
         metrics = trainer.evaluate(metric_key_prefix="eval")
-        if id(model) == id(ref_model): # unable to compute rewards without a reference model
+        if id(model) == id(ref_model):  # unable to compute rewards without a reference model
             remove_keys = [key for key in metrics.keys() if "rewards" in key]
             for key in remove_keys:
                 metrics.pop(key)
@@ -1 +1,4 @@
-from llmtuner.train.ppo.workflow import run_ppo
+from .workflow import run_ppo
+
+
+__all__ = ["run_ppo"]
@@ -1,27 +1,28 @@
+import math
 import os
 import sys
-import math
-import torch
-from tqdm import tqdm
 from typing import TYPE_CHECKING, Dict, List, Optional, Tuple
 
-from transformers import GenerationConfig, Trainer, TrainerState, TrainerControl
-from transformers.utils import WEIGHTS_NAME, SAFE_WEIGHTS_NAME
-from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR
+import torch
+from tqdm import tqdm
+from transformers import GenerationConfig, Trainer, TrainerControl, TrainerState
 from transformers.trainer_pt_utils import remove_dummy_checkpoint
+from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR
+from transformers.utils import SAFE_WEIGHTS_NAME, WEIGHTS_NAME
 from trl import PPOTrainer
 from trl.core import PPODecorators, logprobs_from_logits
 
-from llmtuner.extras.callbacks import LogCallback, SavePeftModelCallback
-from llmtuner.extras.logging import get_logger
-from llmtuner.extras.misc import AverageMeter, count_parameters, get_logits_processor
-from llmtuner.train.ppo.utils import dump_layernorm, get_rewards_from_server, restore_layernorm, replace_model
+from ...extras.callbacks import FixValueHeadModelCallback, LogCallback
+from ...extras.logging import get_logger
+from ...extras.misc import AverageMeter, count_parameters, get_logits_processor
+from .utils import dump_layernorm, get_rewards_from_server, replace_model, restore_layernorm
 
 
 if TYPE_CHECKING:
     from transformers import Seq2SeqTrainingArguments, TrainerCallback
     from trl import AutoModelForCausalLMWithValueHead
-    from llmtuner.hparams import ModelArguments, FinetuningArguments, GeneratingArguments
+
+    from ...hparams import FinetuningArguments, GeneratingArguments, ModelArguments
 
 
 logger = get_logger(__name__)
@@ -40,7 +41,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
         generating_args: "GeneratingArguments",
         callbacks: List["TrainerCallback"],
         reward_model: "AutoModelForCausalLMWithValueHead",
-        **kwargs
+        **kwargs,
     ):
         PPOTrainer.__init__(self, **kwargs)
 
@@ -52,7 +53,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
         self.generation_config = GenerationConfig(
             pad_token_id=self.tokenizer.pad_token_id,
             eos_token_id=[self.tokenizer.eos_token_id] + self.tokenizer.additional_special_tokens_ids,
-            **generating_args.to_dict()
+            **generating_args.to_dict(),
         )
 
         self.state = TrainerState()
@@ -61,7 +62,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
             self.accelerator.state, "deepspeed_plugin"
         )
         self.log_callback, self.save_callback = callbacks[0], callbacks[1]
-        assert isinstance(self.log_callback, LogCallback) and isinstance(self.save_callback, SavePeftModelCallback)
+        assert isinstance(self.log_callback, LogCallback) and isinstance(self.save_callback, FixValueHeadModelCallback)
 
         if self.args.max_steps > 0:
             logger.info("max_steps is given, it will override any value given in num_train_epochs")
@@ -71,7 +72,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
                 if not (
                     getattr(reward_model.pretrained_model, "is_loaded_in_8bit", False)
                     or getattr(reward_model.pretrained_model, "is_loaded_in_4bit", False)
-                ): # quantized models are already set on the correct device
+                ):  # quantized models are already set on the correct device
                     self.reward_model = self._prepare_deepspeed(self.reward_model)
                 else:
                     self.reward_model = self.accelerator.prepare_model(self.reward_model, evaluation_mode=True)
@@ -111,9 +112,11 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
         logger.info(" Num examples = {}".format(num_examples))
         logger.info(" Num Epochs = {}".format(num_train_epochs))
         logger.info(" Instantaneous batch size per device = {}".format(self.args.per_device_train_batch_size))
-        logger.info(" Total train batch size (w. parallel, buffer, distributed & accumulation) = {}".format(
-            total_train_batch_size
-        ))
+        logger.info(
+            " Total train batch size (w. parallel, buffer, distributed & accumulation) = {}".format(
+                total_train_batch_size
+            )
+        )
         logger.info(" Gradient Accumulation steps = {}".format(self.args.gradient_accumulation_steps))
         logger.info(" Num optimization epochs per batch = {}".format(self.finetuning_args.ppo_epochs))
         logger.info(" Total training steps = {}".format(max_steps))
@@ -138,10 +141,12 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
             self.model.eval()
 
             # Get inputs
             self.tokenizer.padding_side = "right" # change padding side
             queries, responses, rewards = [], [], []
             for idx in range(0, self.config.batch_size, self.config.mini_batch_size):
-                mini_batch_queries, mini_batch_responses = self.get_inputs(batch[idx:idx+self.config.mini_batch_size])
+                mini_batch_queries, mini_batch_responses = self.get_inputs(
+                    batch[idx : idx + self.config.mini_batch_size]
+                )
                 mini_batch_rewards = self.get_rewards(mini_batch_queries, mini_batch_responses, unwrapped_model)
                 queries.extend(mini_batch_queries)
                 responses.extend(mini_batch_responses)
@@ -154,7 +159,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
 
             # Run PPO step
             stats = self.step(queries, responses, rewards)
-            self.tokenizer.padding_side = "left" # restore padding side
+            self.tokenizer.padding_side = "left"  # restore padding side
             loss_meter.update(float(stats["ppo/loss/total"]), n=len(rewards))
             reward_meter.update(torch.stack(rewards).mean().item(), n=len(rewards))
 
@@ -163,18 +168,18 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
                 batch["query"] = self.tokenizer.batch_decode(queries, skip_special_tokens=True)
                 batch["response"] = self.tokenizer.batch_decode(responses, skip_special_tokens=True)
                 self.log_stats(stats, batch, rewards)
-            except:
+            except Exception:
                 logger.warning("Failed to save stats due to unknown errors.")
 
             self.state.global_step += 1
             self.log_callback.on_step_end(self.args, self.state, self.control)
 
-            if self.is_local_process_zero() and (step+1) % self.args.logging_steps == 0:
+            if self.is_local_process_zero() and (step + 1) % self.args.logging_steps == 0:
                 logs = dict(
                     loss=round(loss_meter.avg, 4),
                     reward=round(reward_meter.avg, 4),
                     learning_rate=stats["ppo/learning_rate"],
-                    epoch=round(step / steps_in_epoch, 2)
+                    epoch=round(step / steps_in_epoch, 2),
                 )
                 tqdm.write(str(logs))
                 logs["step"] = step
@@ -183,10 +188,10 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
             loss_meter.reset()
             reward_meter.reset()
 
-            if (step+1) % self.args.save_steps == 0: # save checkpoint
-                self.save_model(os.path.join(
-                    self.args.output_dir, "{}-{}".format(PREFIX_CHECKPOINT_DIR, self.state.global_step)
-                ))
+            if (step + 1) % self.args.save_steps == 0: # save checkpoint
+                self.save_model(
+                    os.path.join(self.args.output_dir, "{}-{}".format(PREFIX_CHECKPOINT_DIR, self.state.global_step))
+                )
                 self.save_callback.on_save(
                     self.args, self.state, self.control, model=self.accelerator.unwrap_model(self.model)
                 )
@@ -204,33 +209,36 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
         r"""
         Generates model's responses given queries.
         """
-        if self.finetuning_args.upcast_layernorm:
+        if self.model_args.upcast_layernorm:
             layernorm_params = dump_layernorm(self.model)
 
+        if batch["input_ids"].size(0) == 1: # handle llama2 ppo with gradient accumulation > 1
+            start_index = (batch["input_ids"][0] != self.tokenizer.pad_token_id).nonzero()[0].item()
+            for k, v in batch.items():
+                batch[k] = v[:, start_index:]
+
         unwrapped_model: "AutoModelForCausalLMWithValueHead" = self.accelerator.unwrap_model(self.model)
         generate_output: torch.Tensor = unwrapped_model.generate(
-            generation_config=self.generation_config,
-            logits_processor=get_logits_processor(),
-            **batch
+            generation_config=self.generation_config, logits_processor=get_logits_processor(), **batch
         )
 
-        if self.finetuning_args.upcast_layernorm:
+        if self.model_args.upcast_layernorm:
             restore_layernorm(self.model, layernorm_params)
 
         query = batch["input_ids"].detach().cpu()
-        response = generate_output[:, batch["input_ids"].size(-1):].detach().cpu()
+        response = generate_output[:, batch["input_ids"].size(-1) :].detach().cpu()
         queries, responses = [], []
         for i in range(len(query)):
-            query_length = (query[i] != self.tokenizer.pad_token_id).nonzero()[0].item()
+            query_start_index = (query[i] != self.tokenizer.pad_token_id).nonzero()[0].item()
             response_index = (response[i] != self.tokenizer.pad_token_id).nonzero()
 
             if len(response_index) == 0:
                 response_length = 1 # allow empty response
             else:
                 response_length = response_index[-1].item() + 1
 
-            queries.append(query[i, query_length:]) # remove padding from left
+            queries.append(query[i, query_start_index:]) # remove padding from left
             responses.append(response[i, :response_length]) # remove padding from right
 
         return queries, responses
 
@@ -239,7 +247,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
         self,
         queries: List[torch.Tensor],
         responses: List[torch.Tensor],
-        unwrapped_model: "AutoModelForCausalLMWithValueHead"
+        unwrapped_model: "AutoModelForCausalLMWithValueHead",
     ) -> List[torch.Tensor]:
         r"""
         Computes scores using given reward model.
@@ -259,17 +267,17 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
 
         batch = self.prepare_model_inputs(queries, responses)
 
-        with torch.cuda.amp.autocast(dtype=self.model_args.compute_dtype): # support bf16
+        with torch.cuda.amp.autocast(dtype=self.model_args.compute_dtype):  # support bf16
             _, _, values = reward_model(**batch, output_hidden_states=True, return_dict=True)
 
-        if getattr(unwrapped_model.config, "model_type", None) == "chatglm": # assume same architecture
+        if getattr(unwrapped_model.config, "model_type", None) == "chatglm":  # assume same architecture
             values = torch.transpose(values, 0, 1)
 
         rewards = []
         for i in range(values.size(0)):
             end_indexes = (batch["input_ids"][i] != self.tokenizer.pad_token_id).nonzero()
             end_index = end_indexes[-1].item() if len(end_indexes) else 0
-            rewards.append(values[i, end_index].float().detach().cpu()) # use fp32 type
+            rewards.append(values[i, end_index].float().detach().cpu())  # use fp32 type
 
         if self.finetuning_args.reward_model_type == "lora":
             replace_model(unwrapped_model, target="default")
@@ -284,7 +292,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
         responses: torch.Tensor,
         model_inputs: dict,
         return_logits: Optional[bool] = False,
-        response_masks: Optional[torch.Tensor] = None
+        response_masks: Optional[torch.Tensor] = None,
     ):
         r"""
         Calculates model outputs in multiple batches.
@@ -307,7 +315,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
             input_ids = input_kwargs["input_ids"]
             attention_mask = input_kwargs["attention_mask"]
 
-            with torch.cuda.amp.autocast(dtype=self.model_args.compute_dtype): # support bf16
+            with torch.cuda.amp.autocast(dtype=self.model_args.compute_dtype):  # support bf16
                 logits, _, values = model(**input_kwargs)
 
             unwrapped_model: "AutoModelForCausalLMWithValueHead" = self.accelerator.unwrap_model(self.model)
@@ -320,14 +328,12 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
 
             for j in range(len(query_batch)):
                 start = len(query_batch[j]) - 1
                 if attention_mask[j, 0] == 0: # offset left padding
                     start += attention_mask[j, :].nonzero()[0].item()
                 end = start + len(response_batch[j])
 
                 if response_masks is not None:
-                    response_masks_batch = torch.cat(
-                        (torch.zeros_like(query_batch[j]), response_masks_batch[j])
-                    )[1:]
+                    response_masks_batch = torch.cat((torch.zeros_like(query_batch[j]), response_masks_batch[j]))[1:]
 
                 masks[j, :start] = 0
                 masks[j, end:] = 0
@@ -361,9 +367,9 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
                 self._save(output_dir, state_dict=self.accelerator.get_state_dict(self.model))
             except ValueError:
                 logger.warning(
-                    " stage3_gather_16bit_weights_on_model_save=false. Saving the full checkpoint instead, use"
-                    " zero_to_fp32.py to recover weights"
+                    " stage3_gather_16bit_weights_on_model_save=false. Saving the full checkpoint instead,"
+                    " use zero_to_fp32.py to recover weights"
                 )
                 self._save(output_dir, state_dict={})
-                remove_dummy_checkpoint(self.args.should_save, output_dir, [WEIGHTS_NAME, SAFE_WEIGHTS_NAME])
-                self.model.save_checkpoint(output_dir) # wrapped model
+                remove_dummy_checkpoint(True, output_dir, [WEIGHTS_NAME, SAFE_WEIGHTS_NAME])
+                self.model.save_checkpoint(output_dir)
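Note on `get_inputs` above: queries arrive left-padded and generations are right-padded, so the code trims pad tokens from opposite ends before handing sequences to PPO. A small illustration of that trimming under an assumed pad token id (the names and constant here are made up for the example):

```python
import torch

PAD_ID = 0  # assumed pad token id for this illustration


def trim_query_and_response(query: torch.Tensor, response: torch.Tensor):
    """query/response: 1-D token id tensors; pads sit left of the query and right of the response."""
    query_start = (query != PAD_ID).nonzero()[0].item()  # first non-pad position
    resp_nonpad = (response != PAD_ID).nonzero()
    resp_len = resp_nonpad[-1].item() + 1 if len(resp_nonpad) else 1  # keep at least one token
    return query[query_start:], response[:resp_len]
```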
@@ -1,8 +1,12 @@
 import json
-import torch
+from contextlib import nullcontext
 from typing import TYPE_CHECKING, Dict, List, Literal, Optional
 
-from llmtuner.extras.packages import is_requests_available
+import torch
+from transformers.integrations import is_deepspeed_zero3_enabled
+
+from ...extras.packages import is_requests_available
+
 
 if TYPE_CHECKING:
     from transformers import PreTrainedModel
@@ -21,16 +25,22 @@ def get_rewards_from_server(server_url: str, messages: List[str]) -> List[torch.
 
 
 def replace_model(model: "AutoModelForCausalLMWithValueHead", target: Literal["default", "reward"]) -> None:
-    if target == "reward": # save default head temporarily
-        valuehead_state_dict: Dict[str, torch.Tensor] = model.v_head.state_dict()
-        setattr(model, "default_head_weight", valuehead_state_dict["summary.weight"].detach().clone())
-        setattr(model, "default_head_bias", valuehead_state_dict["summary.bias"].detach().clone())
+    if is_deepspeed_zero3_enabled():
+        import deepspeed # type: ignore
 
-    model.pretrained_model.set_adapter(target) # set the LoRA adapter to be active
-    model.v_head.load_state_dict({
-        "summary.weight": model.get_buffer("{}_head_weight".format(target)).detach().clone(),
-        "summary.bias": model.get_buffer("{}_head_bias".format(target)).detach().clone()
-    })
+        params = [model.v_head.summary.weight, model.v_head.summary.bias]
+        context_maybe_zero3 = deepspeed.zero.GatheredParameters(params, modifier_rank=0)
+    else:
+        context_maybe_zero3 = nullcontext()
+
+    with context_maybe_zero3:
+        if target == "reward": # save default head temporarily
+            setattr(model, "default_head_weight", model.v_head.summary.weight.data.detach().clone())
+            setattr(model, "default_head_bias", model.v_head.summary.bias.data.detach().clone())
+
+        model.pretrained_model.set_adapter(target) # set the LoRA adapter to be active
+        model.v_head.summary.weight.data = model.get_buffer("{}_head_weight".format(target)).detach().clone()
+        model.v_head.summary.bias.data = model.get_buffer("{}_head_bias".format(target)).detach().clone()
 
 
 def dump_layernorm(model: "PreTrainedModel") -> Dict[str, torch.Tensor]:
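Note on the `replace_model` hunk above: under DeepSpeed ZeRO-3 the value-head weights are partitioned across ranks, so they have to be gathered before they can be read or swapped. A rough sketch of the same pattern, where the helper name and the parameter it receives are placeholders rather than part of the patch:

```python
from contextlib import nullcontext

import torch
from transformers.integrations import is_deepspeed_zero3_enabled


def read_full_weight(param: torch.nn.Parameter) -> torch.Tensor:
    """Return a full (un-partitioned) copy of `param`, gathering it first when ZeRO-3 is active."""
    if is_deepspeed_zero3_enabled():
        import deepspeed  # type: ignore

        # Inside this context every rank sees the full tensor; modifier_rank=0 allows rank 0 to modify it.
        ctx = deepspeed.zero.GatheredParameters([param], modifier_rank=0)
    else:
        ctx = nullcontext()
    with ctx:
        return param.data.detach().clone()
```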
@@ -1,22 +1,26 @@
 # Inspired by: https://github.com/lvwerra/trl/blob/main/examples/research_projects/stack_llama/scripts/rl_training.py
 
 import math
-from trl import PPOConfig
+from typing import TYPE_CHECKING, List, Optional
 
 from torch.optim import AdamW
-from typing import TYPE_CHECKING, Optional, List
 from transformers import DataCollatorWithPadding
 from transformers.optimization import get_scheduler
+from trl import PPOConfig
+
+from ...data import get_dataset
+from ...extras.callbacks import FixValueHeadModelCallback
+from ...extras.misc import fix_valuehead_checkpoint
+from ...extras.ploting import plot_loss
+from ...model import load_model_and_tokenizer
+from ...train.ppo.trainer import CustomPPOTrainer
+from ...train.utils import create_ref_model, create_reward_model
 
-from llmtuner.data import get_dataset, preprocess_dataset
-from llmtuner.extras.callbacks import SavePeftModelCallback
-from llmtuner.extras.ploting import plot_loss
-from llmtuner.model import load_model_and_tokenizer
-from llmtuner.train.utils import create_ref_model, create_reward_model
-from llmtuner.train.ppo.trainer import CustomPPOTrainer
 
+
 if TYPE_CHECKING:
     from transformers import Seq2SeqTrainingArguments, TrainerCallback
-    from llmtuner.hparams import ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments
+
+    from ...hparams import DataArguments, FinetuningArguments, GeneratingArguments, ModelArguments
 
+
 def run_ppo(
@@ -25,13 +29,14 @@ def run_ppo(
     training_args: "Seq2SeqTrainingArguments",
     finetuning_args: "FinetuningArguments",
     generating_args: "GeneratingArguments",
-    callbacks: Optional[List["TrainerCallback"]] = None
+    callbacks: Optional[List["TrainerCallback"]] = None,
 ):
-    dataset = get_dataset(model_args, data_args)
-    model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args, training_args.do_train, add_valuehead=True)
-    dataset = preprocess_dataset(dataset, tokenizer, data_args, training_args, stage="ppo")
+    model, tokenizer = load_model_and_tokenizer(
+        model_args, finetuning_args, training_args.do_train, add_valuehead=True
+    )
+    dataset = get_dataset(tokenizer, model_args, data_args, training_args, stage="ppo")
 
     tokenizer.padding_side = "left" # use left-padding in generation while using right-padding in training
     data_collator = DataCollatorWithPadding(tokenizer=tokenizer)
 
     # Create reference model and reward model
@@ -55,7 +60,8 @@ def run_ppo(
         use_score_scaling=finetuning_args.ppo_score_norm,
         use_score_norm=finetuning_args.ppo_score_norm,
         whiten_rewards=finetuning_args.ppo_whiten_rewards,
-        accelerator_kwargs={"step_scheduler_with_optimizer": False}
+        accelerator_kwargs={"step_scheduler_with_optimizer": False},
+        project_kwargs={"logging_dir": training_args.logging_dir},
     )
 
     # Create optimizer and scheduler
@@ -70,7 +76,7 @@ def run_ppo(
         training_args.lr_scheduler_type,
         optimizer=optimizer,
         num_warmup_steps=training_args.get_warmup_steps(num_training_steps),
-        num_training_steps=num_training_steps
+        num_training_steps=num_training_steps,
     )
 
     # Initialize our Trainer
@@ -79,7 +85,7 @@ def run_ppo(
         training_args=training_args,
         finetuning_args=finetuning_args,
        generating_args=generating_args,
-        callbacks=callbacks + [SavePeftModelCallback()],
+        callbacks=callbacks + [FixValueHeadModelCallback()],
        reward_model=reward_model,
        config=ppo_config,
        model=model,
@@ -88,13 +94,15 @@ def run_ppo(
         dataset=dataset,
         data_collator=data_collator,
         optimizer=optimizer,
-        lr_scheduler=lr_scheduler
+        lr_scheduler=lr_scheduler,
     )
 
     # Training
     if training_args.do_train:
         ppo_trainer.ppo_train(resume_from_checkpoint=training_args.resume_from_checkpoint)
         ppo_trainer.save_model()
-        ppo_trainer.save_state() # must be called after save_model to have a folder
+        if training_args.should_save:
+            fix_valuehead_checkpoint(model, training_args.output_dir, training_args.save_safetensors)
+        ppo_trainer.save_state() # must be called after save_model to have a folder
         if ppo_trainer.is_world_process_zero() and finetuning_args.plot_loss:
             plot_loss(training_args.output_dir, keys=["loss", "reward"])
@@ -1 +1,4 @@
-from llmtuner.train.pt.workflow import run_pt
+from .workflow import run_pt
+
+
+__all__ = ["run_pt"]
@@ -1,17 +1,20 @@
 # Inspired by: https://github.com/huggingface/transformers/blob/v4.34.1/examples/pytorch/language-modeling/run_clm.py
 
 import math
-from typing import TYPE_CHECKING, Optional, List
+from typing import TYPE_CHECKING, List, Optional
 
 from transformers import DataCollatorForLanguageModeling, Trainer
 
-from llmtuner.data import get_dataset, preprocess_dataset, split_dataset
-from llmtuner.extras.ploting import plot_loss
-from llmtuner.model import load_model_and_tokenizer
-from llmtuner.train.utils import create_modelcard_and_push
+from ...data import get_dataset, split_dataset
+from ...extras.ploting import plot_loss
+from ...model import load_model_and_tokenizer
+from ...train.utils import create_modelcard_and_push
 
+
 if TYPE_CHECKING:
     from transformers import Seq2SeqTrainingArguments, TrainerCallback
-    from llmtuner.hparams import ModelArguments, DataArguments, FinetuningArguments
+
+    from ...hparams import DataArguments, FinetuningArguments, ModelArguments
 
+
 def run_pt(
@@ -19,11 +22,10 @@ def run_pt(
     data_args: "DataArguments",
     training_args: "Seq2SeqTrainingArguments",
     finetuning_args: "FinetuningArguments",
-    callbacks: Optional[List["TrainerCallback"]] = None
+    callbacks: Optional[List["TrainerCallback"]] = None,
 ):
-    dataset = get_dataset(model_args, data_args)
     model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args, training_args.do_train)
-    dataset = preprocess_dataset(dataset, tokenizer, data_args, training_args, stage="pt")
+    dataset = get_dataset(tokenizer, model_args, data_args, training_args, stage="pt")
     data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
 
     # Initialize our Trainer
@@ -33,7 +35,7 @@ def run_pt(
         tokenizer=tokenizer,
         data_collator=data_collator,
         callbacks=callbacks,
-        **split_dataset(dataset, data_args, training_args)
+        **split_dataset(dataset, data_args, training_args),
     )
 
     # Training
@@ -1 +1,4 @@
-from llmtuner.train.rm.workflow import run_rm
+from .workflow import run_rm
+
+
+__all__ = ["run_rm"]
@@ -1,6 +1,7 @@
-import torch
 from dataclasses import dataclass
 from typing import Any, Dict, Sequence
 
+import torch
 from transformers import DataCollatorWithPadding
 
+
@@ -20,8 +21,9 @@ class PairwiseDataCollatorWithPadding(DataCollatorWithPadding):
         features = [
             {
                 "input_ids": feature["prompt_ids"] + feature[key],
-                "attention_mask": [1] * (len(feature["prompt_ids"]) + len(feature[key]))
+                "attention_mask": [1] * (len(feature["prompt_ids"]) + len(feature[key])),
             }
-            for key in ("chosen_ids", "rejected_ids") for feature in features
+            for key in ("chosen_ids", "rejected_ids")
+            for feature in features
         ]
         return super().__call__(features)
@@ -1,6 +1,7 @@
-import numpy as np
 from typing import Dict, Sequence, Tuple, Union
 
+import numpy as np
+
 
 def compute_accuracy(eval_preds: Sequence[Union[np.ndarray, Tuple[np.ndarray]]]) -> Dict[str, float]:
     preds, _ = eval_preds
@@ -1,14 +1,16 @@
-import os
 import json
-import torch
+import os
 from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
 
+import torch
 from transformers import Trainer
 
-from llmtuner.extras.logging import get_logger
+from ...extras.logging import get_logger
 
+
 if TYPE_CHECKING:
-    from transformers.trainer import PredictionOutput
     from transformers.modeling_utils import PreTrainedModel
+    from transformers.trainer import PredictionOutput
 
+
 logger = get_logger(__name__)
@@ -21,13 +23,10 @@ class PairwiseTrainer(Trainer):
 
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
         self.can_return_loss = True # override property to return eval_loss
 
     def compute_loss(
-        self,
-        model: "PreTrainedModel",
-        inputs: Dict[str, torch.Tensor],
-        return_outputs: Optional[bool] = False
+        self, model: "PreTrainedModel", inputs: Dict[str, torch.Tensor], return_outputs: Optional[bool] = False
     ) -> Union[torch.Tensor, Tuple[torch.Tensor, List[torch.Tensor]]]:
         r"""
         Computes pairwise loss. The first n examples are chosen and the last n examples are rejected.
@@ -68,9 +67,9 @@ class PairwiseTrainer(Trainer):
             assert div_index > 0
             chosen_trunc_rewards = chosen_rewards[i, div_index:end_index]
             rejected_trunc_rewards = rejected_rewards[i, div_index:end_index]
             if return_outputs: # use the score on the last token except pad token for inference
-                chosen_scores.append(chosen_rewards[i, chosen_length-1])
-                rejected_scores.append(rejected_rewards[i, rejected_length-1])
+                chosen_scores.append(chosen_rewards[i, chosen_length - 1])
+                rejected_scores.append(rejected_rewards[i, rejected_length - 1])
             loss += -torch.nn.functional.logsigmoid(chosen_trunc_rewards - rejected_trunc_rewards).mean()
 
         loss = loss / batch_size
@@ -80,10 +79,7 @@ class PairwiseTrainer(Trainer):
 
         return loss
 
-    def save_predictions(
-        self,
-        predict_results: "PredictionOutput"
-    ) -> None:
+    def save_predictions(self, predict_results: "PredictionOutput") -> None:
         r"""
         Saves model predictions to `output_dir`.
 
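Note on `PairwiseTrainer.compute_loss` above: the ranking objective is the usual Bradley-Terry style reward-model loss, `-log sigmoid(r_chosen - r_rejected)`, averaged over the positions where the two responses differ. A toy sketch of the core computation, with illustrative names and shapes:

```python
import torch
import torch.nn.functional as F


def pairwise_reward_loss(chosen_rewards: torch.Tensor, rejected_rewards: torch.Tensor) -> torch.Tensor:
    """chosen_rewards / rejected_rewards: per-token reward scores, shape (batch, seq_len)."""
    # Average the -log sigmoid margin over the scored positions, then over the batch
    return -F.logsigmoid(chosen_rewards - rejected_rewards).mean()
```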
@@ -1,20 +1,24 @@
 # Inspired by: https://github.com/CarperAI/trlx/blob/main/examples/summarize_rlhf/reward_model/train_reward_model_gptj.py
 
-from typing import TYPE_CHECKING, Optional, List
+from typing import TYPE_CHECKING, List, Optional
 
 from transformers import Seq2SeqTrainingArguments
 
-from llmtuner.data import get_dataset, preprocess_dataset, split_dataset
-from llmtuner.extras.callbacks import SavePeftModelCallback
-from llmtuner.extras.ploting import plot_loss
-from llmtuner.model import load_model_and_tokenizer
-from llmtuner.train.rm.collator import PairwiseDataCollatorWithPadding
-from llmtuner.train.rm.metric import compute_accuracy
-from llmtuner.train.rm.trainer import PairwiseTrainer
-from llmtuner.train.utils import create_modelcard_and_push
+from ...data import get_dataset, split_dataset
+from ...extras.callbacks import FixValueHeadModelCallback
+from ...extras.misc import fix_valuehead_checkpoint
+from ...extras.ploting import plot_loss
+from ...model import load_model_and_tokenizer
+from ...train.rm.collator import PairwiseDataCollatorWithPadding
+from ...train.rm.metric import compute_accuracy
+from ...train.rm.trainer import PairwiseTrainer
+from ...train.utils import create_modelcard_and_push
 
+
 if TYPE_CHECKING:
     from transformers import TrainerCallback
-    from llmtuner.hparams import ModelArguments, DataArguments, FinetuningArguments
+
+    from ...hparams import DataArguments, FinetuningArguments, ModelArguments
 
+
 def run_rm(
@@ -22,16 +26,17 @@ def run_rm(
     data_args: "DataArguments",
     training_args: "Seq2SeqTrainingArguments",
     finetuning_args: "FinetuningArguments",
-    callbacks: Optional[List["TrainerCallback"]] = None
+    callbacks: Optional[List["TrainerCallback"]] = None,
 ):
-    dataset = get_dataset(model_args, data_args)
-    model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args, training_args.do_train, add_valuehead=True)
-    dataset = preprocess_dataset(dataset, tokenizer, data_args, training_args, stage="rm")
+    model, tokenizer = load_model_and_tokenizer(
+        model_args, finetuning_args, training_args.do_train, add_valuehead=True
+    )
+    dataset = get_dataset(tokenizer, model_args, data_args, training_args, stage="rm")
     data_collator = PairwiseDataCollatorWithPadding(tokenizer, pad_to_multiple_of=8)
 
     # Update arguments
     training_args_dict = training_args.to_dict()
     training_args_dict.update(dict(remove_unused_columns=False)) # important for pairwise dataset
     training_args = Seq2SeqTrainingArguments(**training_args_dict)
 
     # Initialize our Trainer
@@ -40,15 +45,17 @@ def run_rm(
         args=training_args,
         tokenizer=tokenizer,
         data_collator=data_collator,
-        callbacks=callbacks + [SavePeftModelCallback()],
+        callbacks=callbacks + [FixValueHeadModelCallback()],
         compute_metrics=compute_accuracy,
-        **split_dataset(dataset, data_args, training_args)
+        **split_dataset(dataset, data_args, training_args),
     )
 
     # Training
     if training_args.do_train:
         train_result = trainer.train(resume_from_checkpoint=training_args.resume_from_checkpoint)
         trainer.save_model()
+        if training_args.should_save:
+            fix_valuehead_checkpoint(model, training_args.output_dir, training_args.save_safetensors)
         trainer.log_metrics("train", train_result.metrics)
         trainer.save_metrics("train", train_result.metrics)
         trainer.save_state()
@@ -1 +1,4 @@
-from llmtuner.train.sft.workflow import run_sft
+from .workflow import run_sft
+
+
+__all__ = ["run_sft"]
@@ -1,11 +1,11 @@
-import numpy as np
 from dataclasses import dataclass
 from typing import TYPE_CHECKING, Dict, Sequence, Tuple, Union

-from llmtuner.extras.constants import IGNORE_INDEX
-from llmtuner.extras.packages import (
-    is_jieba_available, is_nltk_available, is_rouge_available
-)
+import numpy as np
+
+from ...extras.constants import IGNORE_INDEX
+from ...extras.packages import is_jieba_available, is_nltk_available, is_rouge_available


 if TYPE_CHECKING:
     from transformers.tokenization_utils import PreTrainedTokenizer

@@ -14,7 +14,7 @@ if is_jieba_available():
     import jieba

 if is_nltk_available():
-    from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction
+    from nltk.translate.bleu_score import SmoothingFunction, sentence_bleu

 if is_rouge_available():
     from rouge_chinese import Rouge
@@ -1,13 +1,15 @@
-import os
 import json
-import torch
-import numpy as np
-import torch.nn as nn
+import os
 from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union

+import numpy as np
+import torch
+import torch.nn as nn
 from transformers import Seq2SeqTrainer

-from llmtuner.extras.constants import IGNORE_INDEX
-from llmtuner.extras.logging import get_logger
+from ...extras.constants import IGNORE_INDEX
+from ...extras.logging import get_logger


 if TYPE_CHECKING:
     from transformers.trainer import PredictionOutput

@@ -33,16 +35,16 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer):

         Subclass and override to inject custom behavior.
         """
         labels = inputs["labels"].detach().clone() if "labels" in inputs else None  # backup labels
         if self.args.predict_with_generate:
             assert self.tokenizer.padding_side == "left", "This method only accepts left-padded tensor."
             prompt_len, label_len = inputs["input_ids"].size(-1), inputs["labels"].size(-1)
             if prompt_len > label_len:
                 inputs["labels"] = self._pad_tensors_to_target_len(inputs["labels"], inputs["input_ids"])
             if label_len > prompt_len:  # truncate the labels instead of padding the inputs (llama2 fp16 compatibility)
                 inputs["labels"] = inputs["labels"][:, :prompt_len]

         loss, generated_tokens, _ = super().prediction_step(  # ignore the returned labels (may be truncated)
             model, inputs, prediction_loss_only=prediction_loss_only, ignore_keys=ignore_keys
         )
         if generated_tokens is not None and self.args.predict_with_generate:

@@ -51,23 +53,16 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer):

         return loss, generated_tokens, labels

-    def _pad_tensors_to_target_len(
-        self,
-        src_tensor: torch.Tensor,
-        tgt_tensor: torch.Tensor
-    ) -> torch.Tensor:
+    def _pad_tensors_to_target_len(self, src_tensor: torch.Tensor, tgt_tensor: torch.Tensor) -> torch.Tensor:
         r"""
         Pads the tensor to the same length as the target tensor.
         """
         assert self.tokenizer.pad_token_id is not None, "Pad token is required."
         padded_tensor = self.tokenizer.pad_token_id * torch.ones_like(tgt_tensor)
-        padded_tensor[:, -src_tensor.shape[-1]:] = src_tensor  # adopt left-padding
+        padded_tensor[:, -src_tensor.shape[-1] :] = src_tensor  # adopt left-padding
         return padded_tensor.contiguous()  # in contiguous memory

-    def save_predictions(
-        self,
-        predict_results: "PredictionOutput"
-    ) -> None:
+    def save_predictions(self, predict_results: "PredictionOutput") -> None:
         r"""
         Saves model predictions to `output_dir`.

@@ -79,15 +74,23 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer):
         output_prediction_file = os.path.join(self.args.output_dir, "generated_predictions.jsonl")
         logger.info(f"Saving prediction results to {output_prediction_file}")

-        labels = np.where(predict_results.label_ids != IGNORE_INDEX, predict_results.label_ids, self.tokenizer.pad_token_id)
-        preds = np.where(predict_results.predictions != IGNORE_INDEX, predict_results.predictions, self.tokenizer.pad_token_id)
+        labels = np.where(
+            predict_results.label_ids != IGNORE_INDEX, predict_results.label_ids, self.tokenizer.pad_token_id
+        )
+        preds = np.where(
+            predict_results.predictions != IGNORE_INDEX, predict_results.predictions, self.tokenizer.pad_token_id
+        )

         for i in range(len(preds)):
             pad_len = np.nonzero(preds[i] != self.tokenizer.pad_token_id)[0]
             if len(pad_len):
-                preds[i] = np.concatenate((preds[i][pad_len[0]:], preds[i][:pad_len[0]]), axis=-1)  # move pad token to last
+                preds[i] = np.concatenate(
+                    (preds[i][pad_len[0] :], preds[i][: pad_len[0]]), axis=-1
+                )  # move pad token to last

-        decoded_labels = self.tokenizer.batch_decode(labels, skip_special_tokens=True, clean_up_tokenization_spaces=False)
+        decoded_labels = self.tokenizer.batch_decode(
+            labels, skip_special_tokens=True, clean_up_tokenization_spaces=False
+        )
         decoded_preds = self.tokenizer.batch_decode(preds, skip_special_tokens=True, clean_up_tokenization_spaces=True)

         with open(output_prediction_file, "w", encoding="utf-8") as writer:
@@ -1,20 +1,23 @@
 # Inspired by: https://github.com/huggingface/transformers/blob/v4.34.1/examples/pytorch/summarization/run_summarization.py

-from typing import TYPE_CHECKING, Optional, List
+from typing import TYPE_CHECKING, List, Optional

 from transformers import DataCollatorForSeq2Seq, Seq2SeqTrainingArguments

-from llmtuner.data import get_dataset, preprocess_dataset, split_dataset
-from llmtuner.extras.constants import IGNORE_INDEX
-from llmtuner.extras.misc import get_logits_processor
-from llmtuner.extras.ploting import plot_loss
-from llmtuner.model import load_model_and_tokenizer
-from llmtuner.train.sft.metric import ComputeMetrics
-from llmtuner.train.sft.trainer import CustomSeq2SeqTrainer
-from llmtuner.train.utils import create_modelcard_and_push
+from ...data import get_dataset, split_dataset
+from ...extras.constants import IGNORE_INDEX
+from ...extras.misc import get_logits_processor
+from ...extras.ploting import plot_loss
+from ...model import load_model_and_tokenizer
+from ...train.sft.metric import ComputeMetrics
+from ...train.sft.trainer import CustomSeq2SeqTrainer
+from ...train.utils import create_modelcard_and_push


 if TYPE_CHECKING:
     from transformers import TrainerCallback
-    from llmtuner.hparams import ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments
+
+    from ...hparams import DataArguments, FinetuningArguments, GeneratingArguments, ModelArguments


 def run_sft(

@@ -23,27 +26,31 @@ def run_sft(
     training_args: "Seq2SeqTrainingArguments",
     finetuning_args: "FinetuningArguments",
     generating_args: "GeneratingArguments",
-    callbacks: Optional[List["TrainerCallback"]] = None
+    callbacks: Optional[List["TrainerCallback"]] = None,
 ):
-    dataset = get_dataset(model_args, data_args)
     model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args, training_args.do_train)
-    dataset = preprocess_dataset(dataset, tokenizer, data_args, training_args, stage="sft")
+    dataset = get_dataset(tokenizer, model_args, data_args, training_args, stage="sft")

     if training_args.predict_with_generate:
         tokenizer.padding_side = "left"  # use left-padding in generation

+    if getattr(model, "is_quantized", False) and not training_args.do_train:
+        setattr(model, "_hf_peft_config_loaded", True)  # hack here: make model compatible with prediction
+
     data_collator = DataCollatorForSeq2Seq(
         tokenizer=tokenizer,
         pad_to_multiple_of=8 if tokenizer.padding_side == "right" else None,  # for shift short attention
-        label_pad_token_id=IGNORE_INDEX if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id
+        label_pad_token_id=IGNORE_INDEX if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id,
     )

     # Override the decoding parameters of Seq2SeqTrainer
     training_args_dict = training_args.to_dict()
-    training_args_dict.update(dict(
-        generation_max_length=training_args.generation_max_length or data_args.cutoff_len,
-        generation_num_beams=data_args.eval_num_beams or training_args.generation_num_beams
-    ))
+    training_args_dict.update(
+        dict(
+            generation_max_length=training_args.generation_max_length or data_args.cutoff_len,
+            generation_num_beams=data_args.eval_num_beams or training_args.generation_num_beams,
+        )
+    )
     training_args = Seq2SeqTrainingArguments(**training_args_dict)

     # Initialize our Trainer

@@ -54,7 +61,7 @@ def run_sft(
         data_collator=data_collator,
         callbacks=callbacks,
         compute_metrics=ComputeMetrics(tokenizer) if training_args.predict_with_generate else None,
-        **split_dataset(dataset, data_args, training_args)
+        **split_dataset(dataset, data_args, training_args),
     )

     # Keyword arguments for `model.generate`

@@ -76,7 +83,7 @@ def run_sft(
     # Evaluation
     if training_args.do_eval:
         metrics = trainer.evaluate(metric_key_prefix="eval", **gen_kwargs)
         if training_args.predict_with_generate:  # eval_loss will be wrong if predict_with_generate is enabled
             metrics.pop("eval_loss", None)
         trainer.log_metrics("eval", metrics)
         trainer.save_metrics("eval", metrics)

@@ -84,7 +91,7 @@ def run_sft(
     # Predict
     if training_args.do_predict:
         predict_results = trainer.predict(dataset, metric_key_prefix="predict", **gen_kwargs)
         if training_args.predict_with_generate:  # predict_loss will be wrong if predict_with_generate is enabled
             predict_results.metrics.pop("predict_loss", None)
         trainer.log_metrics("predict", predict_results.metrics)
         trainer.save_metrics("predict", predict_results.metrics)
@@ -1,13 +1,18 @@
 from typing import TYPE_CHECKING, Any, Dict, List, Optional

-from llmtuner.extras.callbacks import LogCallback
-from llmtuner.extras.logging import get_logger
-from llmtuner.model import get_train_args, get_infer_args, load_model_and_tokenizer
-from llmtuner.train.pt import run_pt
-from llmtuner.train.sft import run_sft
-from llmtuner.train.rm import run_rm
-from llmtuner.train.ppo import run_ppo
-from llmtuner.train.dpo import run_dpo
+import torch
+from transformers import PreTrainedModel
+
+from ..extras.callbacks import LogCallback
+from ..extras.logging import get_logger
+from ..hparams import get_infer_args, get_train_args
+from ..model import load_model_and_tokenizer
+from .dpo import run_dpo
+from .ppo import run_ppo
+from .pt import run_pt
+from .rm import run_rm
+from .sft import run_sft


 if TYPE_CHECKING:
     from transformers import TrainerCallback

@@ -36,19 +41,49 @@ def run_exp(args: Optional[Dict[str, Any]] = None, callbacks: Optional[List["Tra

 def export_model(args: Optional[Dict[str, Any]] = None):
     model_args, _, finetuning_args, _ = get_infer_args(args)

+    if model_args.export_dir is None:
+        raise ValueError("Please specify `export_dir`.")
+
+    if model_args.adapter_name_or_path is not None and model_args.export_quantization_bit is not None:
+        raise ValueError("Please merge adapters before quantizing the model.")
+
     model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args)

-    if getattr(model, "quantization_method", None) in ["gptq", "awq"]:
-        raise ValueError("Cannot export a GPTQ or AWQ quantized model.")
+    if getattr(model, "quantization_method", None) and model_args.adapter_name_or_path is not None:
+        raise ValueError("Cannot merge adapters to a quantized model.")

-    model.config.use_cache = True
-    model.save_pretrained(finetuning_args.export_dir, max_shard_size="{}GB".format(finetuning_args.export_size))
+    if not isinstance(model, PreTrainedModel):
+        raise ValueError("The model is not a `PreTrainedModel`, export aborted.")
+
+    if getattr(model, "quantization_method", None):
+        model = model.to("cpu")
+    elif hasattr(model.config, "torch_dtype"):
+        model = model.to(getattr(model.config, "torch_dtype")).to("cpu")
+    else:
+        model = model.to(torch.float16).to("cpu")
+        setattr(model.config, "torch_dtype", torch.float16)
+
+    model.save_pretrained(
+        save_directory=model_args.export_dir,
+        max_shard_size="{}GB".format(model_args.export_size),
+        safe_serialization=(not model_args.export_legacy_format),
+    )
+    if model_args.export_hub_model_id is not None:
+        model.push_to_hub(
+            model_args.export_hub_model_id,
+            token=model_args.hf_hub_token,
+            max_shard_size="{}GB".format(model_args.export_size),
+            safe_serialization=(not model_args.export_legacy_format),
+        )

     try:
         tokenizer.padding_side = "left"  # restore padding side
         tokenizer.init_kwargs["padding_side"] = "left"
-        tokenizer.save_pretrained(finetuning_args.export_dir)
-    except:
+        tokenizer.save_pretrained(model_args.export_dir)
+        if model_args.export_hub_model_id is not None:
+            tokenizer.push_to_hub(model_args.export_hub_model_id, token=model_args.hf_hub_token)
+    except Exception:
         logger.warning("Cannot save tokenizer, please copy the files manually.")
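A note for readers of the reworked `export_model` above: everything it needs now comes from `ModelArguments` (export directory, shard size, safetensors flag, optional hub repo), so a call boils down to passing one args dict. A minimal usage sketch under that assumption; the model, adapter, and output paths below are placeholders, not values taken from this changeset:

```python
from llmtuner.train import export_model  # the same module the Web UI imports it from in this diff

# Hypothetical paths/values, for illustration only.
export_model(
    dict(
        model_name_or_path="meta-llama/Llama-2-7b-hf",    # base model to load (placeholder)
        adapter_name_or_path="saves/llama2-7b/lora/sft",  # LoRA adapter to merge, optional (placeholder)
        template="default",
        finetuning_type="lora",
        export_dir="exported_model",                      # required: export_model raises without it
        export_size=2,                                    # max shard size in GB
        export_legacy_format=False,                       # False keeps safetensors output
    )
)
```

The dict keys mirror the ones the Web UI's export tab assembles later in this comparison.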
@@ -1,15 +1,18 @@
-import torch
 from typing import TYPE_CHECKING, Optional, Union

-from llmtuner.extras.logging import get_logger
-from llmtuner.hparams import ModelArguments, FinetuningArguments
-from llmtuner.model import get_modelcard_args, load_model_and_tokenizer, load_valuehead_params
+import torch
+
+from ..extras.logging import get_logger
+from ..hparams import FinetuningArguments, ModelArguments
+from ..model import load_model_and_tokenizer, load_valuehead_params


 if TYPE_CHECKING:
     from transformers import Seq2SeqTrainingArguments, Trainer
     from transformers.modeling_utils import PreTrainedModel
     from trl import AutoModelForCausalLMWithValueHead
-    from llmtuner.hparams import DataArguments
+
+    from ..hparams import DataArguments


 logger = get_logger(__name__)

@@ -20,22 +23,24 @@ def create_modelcard_and_push(
     model_args: "ModelArguments",
     data_args: "DataArguments",
     training_args: "Seq2SeqTrainingArguments",
-    finetuning_args: "FinetuningArguments"
+    finetuning_args: "FinetuningArguments",
 ) -> None:
-    if training_args.do_train:
-        if training_args.push_to_hub:
-            trainer.push_to_hub(**get_modelcard_args(model_args, data_args, finetuning_args))
-            return
-        try:
-            trainer.create_model_card(**get_modelcard_args(model_args, data_args, finetuning_args))
-        except Exception as err:
-            logger.warning("Failed to create model card: {}".format(str(err)))
+    kwargs = {
+        "tasks": "text-generation",
+        "finetuned_from": model_args.model_name_or_path,
+        "dataset": [dataset.strip() for dataset in data_args.dataset.split(",")],
+        "tags": ["llama-factory", finetuning_args.finetuning_type],
+    }
+    if not training_args.do_train:
+        pass
+    elif training_args.push_to_hub:
+        trainer.push_to_hub(**kwargs)
+    else:
+        trainer.create_model_card(license="other", **kwargs)  # prevent from connecting to hub


 def create_ref_model(
-    model_args: "ModelArguments",
-    finetuning_args: "FinetuningArguments",
-    add_valuehead: Optional[bool] = False
+    model_args: "ModelArguments", finetuning_args: "FinetuningArguments", add_valuehead: Optional[bool] = False
 ) -> Union["PreTrainedModel", "AutoModelForCausalLMWithValueHead"]:
     r"""
     Creates reference model for PPO/DPO training. Evaluation mode is not supported.

@@ -44,11 +49,13 @@ def create_ref_model(
     """
     if finetuning_args.ref_model is not None:
         ref_model_args_dict = model_args.to_dict()
-        ref_model_args_dict.update(dict(
-            model_name_or_path=finetuning_args.ref_model,
-            checkpoint_dir=finetuning_args.ref_model_checkpoint,
-            quantization_bit=finetuning_args.ref_model_quantization_bit
-        ))
+        ref_model_args_dict.update(
+            dict(
+                model_name_or_path=finetuning_args.ref_model,
+                adapter_name_or_path=finetuning_args.ref_model_adapters,
+                quantization_bit=finetuning_args.ref_model_quantization_bit,
+            )
+        )
         ref_model_args = ModelArguments(**ref_model_args_dict)
         ref_finetuning_args = FinetuningArguments(finetuning_type="lora")
         ref_model, _ = load_model_and_tokenizer(

@@ -68,9 +75,7 @@ def create_ref_model(


 def create_reward_model(
-    model: "AutoModelForCausalLMWithValueHead",
-    model_args: "ModelArguments",
-    finetuning_args: "FinetuningArguments"
+    model: "AutoModelForCausalLMWithValueHead", model_args: "ModelArguments", finetuning_args: "FinetuningArguments"
 ) -> "AutoModelForCausalLMWithValueHead":
     r"""
     Creates reward model for PPO training.

@@ -81,24 +86,30 @@ def create_reward_model(
         return finetuning_args.reward_model
     elif finetuning_args.reward_model_type == "lora":
         model.pretrained_model.load_adapter(finetuning_args.reward_model, "reward")
         for name, param in model.named_parameters():  # https://github.com/huggingface/peft/issues/1090
             if "default" in name:
                 param.data = param.data.to(torch.float32)  # trainable params should in fp32
         vhead_params = load_valuehead_params(finetuning_args.reward_model, model_args)
         assert vhead_params is not None, "Reward model is not correctly loaded."
         model.register_buffer("reward_head_weight", vhead_params["v_head.summary.weight"], persistent=False)
         model.register_buffer("reward_head_bias", vhead_params["v_head.summary.bias"], persistent=False)
-        model.register_buffer("default_head_weight", torch.zeros_like(vhead_params["v_head.summary.weight"]), persistent=False)
-        model.register_buffer("default_head_bias", torch.zeros_like(vhead_params["v_head.summary.bias"]), persistent=False)
+        model.register_buffer(
+            "default_head_weight", torch.zeros_like(vhead_params["v_head.summary.weight"]), persistent=False
+        )
+        model.register_buffer(
+            "default_head_bias", torch.zeros_like(vhead_params["v_head.summary.bias"]), persistent=False
+        )
         logger.info("Loaded adapter weights of reward model from {}".format(finetuning_args.reward_model))
         return None
     else:
         reward_model_args_dict = model_args.to_dict()
-        reward_model_args_dict.update(dict(
-            model_name_or_path=finetuning_args.reward_model,
-            checkpoint_dir=finetuning_args.reward_model_checkpoint,
-            quantization_bit=finetuning_args.reward_model_quantization_bit
-        ))
+        reward_model_args_dict.update(
+            dict(
+                model_name_or_path=finetuning_args.reward_model,
+                adapter_name_or_path=finetuning_args.reward_model_adapters,
+                quantization_bit=finetuning_args.reward_model_quantization_bit,
+            )
+        )
         reward_model_args = ModelArguments(**reward_model_args_dict)
         reward_finetuning_args = FinetuningArguments(finetuning_type="lora")
         reward_model, _ = load_model_and_tokenizer(
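For context on the rewritten `create_modelcard_and_push` above: the card metadata is now assembled inline instead of through `get_modelcard_args`. A sketch of what the kwargs dict could look like for one hypothetical run (model path, dataset name, and finetuning type below are made-up examples, not values from this changeset):

```python
# Placeholder values, shaped exactly like the dict built in the diff above.
kwargs = {
    "tasks": "text-generation",
    "finetuned_from": "meta-llama/Llama-2-7b-hf",  # model_args.model_name_or_path
    "dataset": ["alpaca_gpt4_en"],                 # data_args.dataset split on "," and stripped
    "tags": ["llama-factory", "lora"],             # finetuning_args.finetuning_type appended
}
# With push_to_hub set: trainer.push_to_hub(**kwargs)
# Otherwise:            trainer.create_model_card(license="other", **kwargs)
```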
@@ -1 +1,4 @@
-from llmtuner.webui.interface import create_ui, create_web_demo
+from .interface import create_ui, create_web_demo
+
+
+__all__ = ["create_ui", "create_web_demo"]
@@ -1,24 +1,24 @@
-import gradio as gr
-from gradio.components import Component  # cannot use TYPE_CHECKING here
-from typing import TYPE_CHECKING, Any, Dict, Generator, List, Optional, Tuple
+import json
+from typing import TYPE_CHECKING, Any, Dict, Generator, List, Optional, Sequence, Tuple

-from llmtuner.chat import ChatModel
-from llmtuner.extras.misc import torch_gc
-from llmtuner.hparams import GeneratingArguments
-from llmtuner.webui.common import get_save_dir
-from llmtuner.webui.locales import ALERTS
+import gradio as gr
+from gradio.components import Component  # cannot use TYPE_CHECKING here
+
+from ..chat import ChatModel
+from ..data import Role
+from ..extras.misc import torch_gc
+from ..hparams import GeneratingArguments
+from .common import get_save_dir
+from .locales import ALERTS
+

 if TYPE_CHECKING:
-    from llmtuner.webui.manager import Manager
+    from .manager import Manager


 class WebChatModel(ChatModel):

     def __init__(
-        self,
-        manager: "Manager",
-        demo_mode: Optional[bool] = False,
-        lazy_init: Optional[bool] = True
+        self, manager: "Manager", demo_mode: Optional[bool] = False, lazy_init: Optional[bool] = True
     ) -> None:
         self.manager = manager
         self.demo_mode = demo_mode

@@ -26,11 +26,12 @@ class WebChatModel(ChatModel):
         self.tokenizer = None
         self.generating_args = GeneratingArguments()

         if not lazy_init:  # read arguments from command line
             super().__init__()

         if demo_mode:  # load demo_config.json if exists
             import json
+
             try:
                 with open("demo_config.json", "r", encoding="utf-8") as f:
                     args = json.load(f)

@@ -38,7 +39,7 @@ class WebChatModel(ChatModel):
                 super().__init__(args)
             except AssertionError:
                 print("Please provided model name and template in `demo_config.json`.")
-            except:
+            except Exception:
                 print("Cannot find `demo_config.json` at current directory.")

     @property

@@ -63,24 +64,26 @@ class WebChatModel(ChatModel):
             yield error
             return

-        if get("top.checkpoints"):
-            checkpoint_dir = ",".join([
-                get_save_dir(get("top.model_name"), get("top.finetuning_type"), ckpt) for ckpt in get("top.checkpoints")
-            ])
+        if get("top.adapter_path"):
+            adapter_name_or_path = ",".join(
+                [
+                    get_save_dir(get("top.model_name"), get("top.finetuning_type"), adapter)
+                    for adapter in get("top.adapter_path")
+                ]
+            )
         else:
-            checkpoint_dir = None
+            adapter_name_or_path = None

         yield ALERTS["info_loading"][lang]
         args = dict(
             model_name_or_path=get("top.model_path"),
-            checkpoint_dir=checkpoint_dir,
+            adapter_name_or_path=adapter_name_or_path,
             finetuning_type=get("top.finetuning_type"),
             quantization_bit=int(get("top.quantization_bit")) if get("top.quantization_bit") in ["8", "4"] else None,
             template=get("top.template"),
-            system_prompt=get("top.system_prompt"),
-            flash_attn=get("top.flash_attn"),
-            shift_attn=get("top.shift_attn"),
-            rope_scaling=get("top.rope_scaling") if get("top.rope_scaling") in ["linear", "dynamic"] else None
+            flash_attn=(get("top.booster") == "flash_attn"),
+            use_unsloth=(get("top.booster") == "unsloth"),
+            rope_scaling=get("top.rope_scaling") if get("top.rope_scaling") in ["linear", "dynamic"] else None,
         )
         super().__init__(args)

@@ -103,22 +106,39 @@ class WebChatModel(ChatModel):
     def predict(
         self,
         chatbot: List[Tuple[str, str]],
+        role: str,
         query: str,
-        history: List[Tuple[str, str]],
+        messages: Sequence[Tuple[str, str]],
         system: str,
+        tools: str,
         max_new_tokens: int,
         top_p: float,
-        temperature: float
-    ) -> Generator[Tuple[List[Tuple[str, str]], List[Tuple[str, str]]], None, None]:
+        temperature: float,
+    ) -> Generator[Tuple[Sequence[Tuple[str, str]], Sequence[Tuple[str, str]]], None, None]:
         chatbot.append([query, ""])
+        query_messages = messages + [{"role": role, "content": query}]
         response = ""
         for new_text in self.stream_chat(
-            query, history, system, max_new_tokens=max_new_tokens, top_p=top_p, temperature=temperature
+            query_messages, system, tools, max_new_tokens=max_new_tokens, top_p=top_p, temperature=temperature
         ):
             response += new_text
-            new_history = history + [(query, response)]
-            chatbot[-1] = [query, self.postprocess(response)]
-            yield chatbot, new_history
+            if tools:
+                result = self.template.format_tools.extract(response)
+            else:
+                result = response
+
+            if isinstance(result, tuple):
+                name, arguments = result
+                arguments = json.loads(arguments)
+                tool_call = json.dumps({"name": name, "arguments": arguments}, ensure_ascii=False)
+                output_messages = query_messages + [{"role": Role.FUNCTION.value, "content": tool_call}]
+                bot_text = "```json\n" + tool_call + "\n```"
+            else:
+                output_messages = query_messages + [{"role": Role.ASSISTANT.value, "content": result}]
+                bot_text = result
+
+            chatbot[-1] = [query, self.postprocess(bot_text)]
+            yield chatbot, output_messages

     def postprocess(self, response: str) -> str:
         blocks = response.split("```")
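The new `predict` above separates tool calls from plain replies: when the template's tool extractor returns a `(name, arguments)` tuple, the turn is recorded as a function message and rendered as a JSON block. A small sketch of the resulting message objects, using an invented tool call and the string values the `Role` enum is expected to carry:

```python
import json

# Hypothetical extracted tool call: a (name, JSON-encoded arguments) pair.
name, arguments = "get_weather", '{"location": "Berlin"}'
tool_call = json.dumps({"name": name, "arguments": json.loads(arguments)}, ensure_ascii=False)

function_message = {"role": "function", "content": tool_call}         # tool-call branch
assistant_message = {"role": "assistant", "content": "It is sunny."}  # plain-reply branch
```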
@@ -1,39 +1,28 @@
-import os
 import json
-import gradio as gr
+import os
+from collections import defaultdict
 from typing import Any, Dict, Optional
-from transformers.utils import (
-    WEIGHTS_NAME,
-    WEIGHTS_INDEX_NAME,
-    SAFE_WEIGHTS_NAME,
-    SAFE_WEIGHTS_INDEX_NAME,
-    ADAPTER_WEIGHTS_NAME,
-    ADAPTER_SAFE_WEIGHTS_NAME
-)

-from llmtuner.extras.constants import (
+import gradio as gr
+from peft.utils import SAFETENSORS_WEIGHTS_NAME, WEIGHTS_NAME
+
+from ..extras.constants import (
+    DATA_CONFIG,
     DEFAULT_MODULE,
     DEFAULT_TEMPLATE,
+    PEFT_METHODS,
     SUPPORTED_MODELS,
     TRAINING_STAGES,
-    DownloadSource
+    DownloadSource,
 )
-from llmtuner.extras.misc import use_modelscope
-from llmtuner.hparams.data_args import DATA_CONFIG
+from ..extras.misc import use_modelscope


+ADAPTER_NAMES = {WEIGHTS_NAME, SAFETENSORS_WEIGHTS_NAME}
 DEFAULT_CACHE_DIR = "cache"
 DEFAULT_DATA_DIR = "data"
 DEFAULT_SAVE_DIR = "saves"
 USER_CONFIG = "user.config"
-CKPT_NAMES = [
-    WEIGHTS_NAME,
-    WEIGHTS_INDEX_NAME,
-    SAFE_WEIGHTS_NAME,
-    SAFE_WEIGHTS_INDEX_NAME,
-    ADAPTER_WEIGHTS_NAME,
-    ADAPTER_SAFE_WEIGHTS_NAME
-]


 def get_save_dir(*args) -> os.PathLike:

@@ -48,7 +37,7 @@ def load_config() -> Dict[str, Any]:
     try:
         with open(get_config_path(), "r", encoding="utf-8") as f:
             return json.load(f)
-    except:
+    except Exception:
         return {"lang": None, "last_model": None, "path_dict": {}, "cache_dir": None}


@@ -65,13 +54,13 @@ def save_config(lang: str, model_name: Optional[str] = None, model_path: Optiona

 def get_model_path(model_name: str) -> str:
     user_config = load_config()
-    path_dict: Dict[DownloadSource, str] = SUPPORTED_MODELS.get(model_name, [])
-    model_path = user_config["path_dict"].get(model_name, None) or path_dict.get(DownloadSource.DEFAULT, "")
+    path_dict: Dict[DownloadSource, str] = SUPPORTED_MODELS.get(model_name, defaultdict(str))
+    model_path = user_config["path_dict"].get(model_name, None) or path_dict.get(DownloadSource.DEFAULT, None)
     if (
         use_modelscope()
         and path_dict.get(DownloadSource.MODELSCOPE)
         and model_path == path_dict.get(DownloadSource.DEFAULT)
     ):  # replace path
         model_path = path_dict.get(DownloadSource.MODELSCOPE)
     return model_path

@@ -90,18 +79,20 @@ def get_template(model_name: str) -> str:
     return "default"


-def list_checkpoint(model_name: str, finetuning_type: str) -> Dict[str, Any]:
-    checkpoints = []
-    if model_name:
+def list_adapters(model_name: str, finetuning_type: str) -> Dict[str, Any]:
+    if finetuning_type not in PEFT_METHODS:
+        return gr.update(value=[], choices=[], interactive=False)
+
+    adapters = []
+    if model_name and finetuning_type == "lora":
         save_dir = get_save_dir(model_name, finetuning_type)
         if save_dir and os.path.isdir(save_dir):
-            for checkpoint in os.listdir(save_dir):
-                if (
-                    os.path.isdir(os.path.join(save_dir, checkpoint))
-                    and any([os.path.isfile(os.path.join(save_dir, checkpoint, name)) for name in CKPT_NAMES])
+            for adapter in os.listdir(save_dir):
+                if os.path.isdir(os.path.join(save_dir, adapter)) and any(
+                    os.path.isfile(os.path.join(save_dir, adapter, name)) for name in ADAPTER_NAMES
                 ):
-                    checkpoints.append(checkpoint)
-    return gr.update(value=[], choices=checkpoints)
+                    adapters.append(adapter)
+    return gr.update(value=[], choices=adapters, interactive=True)


 def load_dataset_info(dataset_dir: str) -> Dict[str, Dict[str, Any]]:
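As a quick illustration of the renamed helper above, `list_adapters` now only reports save-dir subfolders that contain a PEFT weight file (the peft `WEIGHTS_NAME`/`SAFETENSORS_WEIGHTS_NAME` constants collected in `ADAPTER_NAMES`). A hypothetical call, with made-up model and adapter folder names and assuming `get_save_dir` joins `DEFAULT_SAVE_DIR` with its arguments:

```python
# Assuming saves/LLaMA2-7B/lora/sft-1/ contains an adapter weight file:
update = list_adapters(model_name="LLaMA2-7B", finetuning_type="lora")
# -> gr.update(value=[], choices=["sft-1"], interactive=True)

# Finetuning types outside PEFT_METHODS short-circuit to an empty, non-interactive dropdown:
update = list_adapters(model_name="LLaMA2-7B", finetuning_type="full")
# -> gr.update(value=[], choices=[], interactive=False)
```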
@@ -1,6 +1,16 @@
-from llmtuner.webui.components.top import create_top
-from llmtuner.webui.components.train import create_train_tab
-from llmtuner.webui.components.eval import create_eval_tab
-from llmtuner.webui.components.infer import create_infer_tab
-from llmtuner.webui.components.export import create_export_tab
-from llmtuner.webui.components.chatbot import create_chat_box
+from .chatbot import create_chat_box
+from .eval import create_eval_tab
+from .export import create_export_tab
+from .infer import create_infer_tab
+from .top import create_top
+from .train import create_train_tab
+
+
+__all__ = [
+    "create_chat_box",
+    "create_eval_tab",
+    "create_export_tab",
+    "create_infer_tab",
+    "create_top",
+    "create_train_tab",
+]
@@ -1,49 +1,63 @@
-import gradio as gr
 from typing import TYPE_CHECKING, Dict, Optional, Tuple

+import gradio as gr
+
+from ...data import Role
+from ..utils import check_json_schema
+
+
 if TYPE_CHECKING:
     from gradio.blocks import Block
     from gradio.components import Component
-    from llmtuner.webui.engine import Engine
+
+    from ..engine import Engine


 def create_chat_box(
-    engine: "Engine",
-    visible: Optional[bool] = False
+    engine: "Engine", visible: Optional[bool] = False
 ) -> Tuple["Block", "Component", "Component", Dict[str, "Component"]]:
     with gr.Box(visible=visible) as chat_box:
         chatbot = gr.Chatbot()
-        history = gr.State([])
+        messages = gr.State([])
         with gr.Row():
             with gr.Column(scale=4):
+                role = gr.Dropdown(choices=[Role.USER.value, Role.OBSERVATION.value], value=Role.USER.value)
                 system = gr.Textbox(show_label=False)
+                tools = gr.Textbox(show_label=False, lines=2)
                 query = gr.Textbox(show_label=False, lines=8)
                 submit_btn = gr.Button(variant="primary")

             with gr.Column(scale=1):
-                clear_btn = gr.Button()
                 gen_kwargs = engine.chatter.generating_args
                 max_new_tokens = gr.Slider(10, 2048, value=gen_kwargs.max_new_tokens, step=1)
                 top_p = gr.Slider(0.01, 1, value=gen_kwargs.top_p, step=0.01)
                 temperature = gr.Slider(0.01, 1.5, value=gen_kwargs.temperature, step=0.01)
+                clear_btn = gr.Button()
+
+    tools.input(check_json_schema, [tools, engine.manager.get_elem_by_name("top.lang")])

     submit_btn.click(
         engine.chatter.predict,
-        [chatbot, query, history, system, max_new_tokens, top_p, temperature],
-        [chatbot, history],
-        show_progress=True
-    ).then(
-        lambda: gr.update(value=""), outputs=[query]
-    )
+        [chatbot, role, query, messages, system, tools, max_new_tokens, top_p, temperature],
+        [chatbot, messages],
+        show_progress=True,
+    ).then(lambda: gr.update(value=""), outputs=[query])

-    clear_btn.click(lambda: ([], []), outputs=[chatbot, history], show_progress=True)
+    clear_btn.click(lambda: ([], []), outputs=[chatbot, messages], show_progress=True)

-    return chat_box, chatbot, history, dict(
-        system=system,
-        query=query,
-        submit_btn=submit_btn,
-        clear_btn=clear_btn,
-        max_new_tokens=max_new_tokens,
-        top_p=top_p,
-        temperature=temperature
+    return (
+        chat_box,
+        chatbot,
+        messages,
+        dict(
+            role=role,
+            system=system,
+            tools=tools,
+            query=query,
+            submit_btn=submit_btn,
+            max_new_tokens=max_new_tokens,
+            top_p=top_p,
+            temperature=temperature,
+            clear_btn=clear_btn,
+        ),
     )
@@ -1,9 +1,11 @@
-import os
 import json
-import gradio as gr
+import os
 from typing import TYPE_CHECKING, Any, Dict, Tuple

-from llmtuner.webui.common import DATA_CONFIG
+import gradio as gr
+
+from ...extras.constants import DATA_CONFIG


 if TYPE_CHECKING:
     from gradio.components import Component

@@ -21,8 +23,11 @@ def next_page(page_index: int, total_num: int) -> int:


 def can_preview(dataset_dir: str, dataset: list) -> Dict[str, Any]:
-    with open(os.path.join(dataset_dir, DATA_CONFIG), "r", encoding="utf-8") as f:
-        dataset_info = json.load(f)
+    try:
+        with open(os.path.join(dataset_dir, DATA_CONFIG), "r", encoding="utf-8") as f:
+            dataset_info = json.load(f)
+    except Exception:
+        return gr.update(interactive=False)

     if (
         len(dataset) > 0

@@ -45,7 +50,7 @@ def get_preview(dataset_dir: str, dataset: list, page_index: int) -> Tuple[int,
         elif data_file.endswith(".jsonl"):
             data = [json.loads(line) for line in f]
         else:
-            data = [line for line in f]
+            data = [line for line in f]  # noqa: C416
     return len(data), data[PAGE_SIZE * page_index : PAGE_SIZE * (page_index + 1)], gr.update(visible=True)


@@ -64,32 +69,17 @@ def create_preview_box(dataset_dir: "gr.Textbox", dataset: "gr.Dropdown") -> Dic
         with gr.Row():
             preview_samples = gr.JSON(interactive=False)

-    dataset.change(
-        can_preview, [dataset_dir, dataset], [data_preview_btn], queue=False
-    ).then(
+    dataset.change(can_preview, [dataset_dir, dataset], [data_preview_btn], queue=False).then(
         lambda: 0, outputs=[page_index], queue=False
     )
     data_preview_btn.click(
-        get_preview,
-        [dataset_dir, dataset, page_index],
-        [preview_count, preview_samples, preview_box],
-        queue=False
+        get_preview, [dataset_dir, dataset, page_index], [preview_count, preview_samples, preview_box], queue=False
     )
-    prev_btn.click(
-        prev_page, [page_index], [page_index], queue=False
-    ).then(
-        get_preview,
-        [dataset_dir, dataset, page_index],
-        [preview_count, preview_samples, preview_box],
-        queue=False
+    prev_btn.click(prev_page, [page_index], [page_index], queue=False).then(
+        get_preview, [dataset_dir, dataset, page_index], [preview_count, preview_samples, preview_box], queue=False
     )
-    next_btn.click(
-        next_page, [page_index, preview_count], [page_index], queue=False
-    ).then(
-        get_preview,
-        [dataset_dir, dataset, page_index],
-        [preview_count, preview_samples, preview_box],
-        queue=False
+    next_btn.click(next_page, [page_index, preview_count], [page_index], queue=False).then(
+        get_preview, [dataset_dir, dataset, page_index], [preview_count, preview_samples, preview_box], queue=False
     )
     close_btn.click(lambda: gr.update(visible=False), outputs=[preview_box], queue=False)
     return dict(

@@ -99,5 +89,5 @@ def create_preview_box(dataset_dir: "gr.Textbox", dataset: "gr.Dropdown") -> Dic
         prev_btn=prev_btn,
         next_btn=next_btn,
         close_btn=close_btn,
-        preview_samples=preview_samples
+        preview_samples=preview_samples,
     )
@@ -1,12 +1,15 @@
-import gradio as gr
 from typing import TYPE_CHECKING, Dict

-from llmtuner.webui.common import list_dataset, DEFAULT_DATA_DIR
-from llmtuner.webui.components.data import create_preview_box
+import gradio as gr
+
+from ..common import DEFAULT_DATA_DIR, list_dataset
+from .data import create_preview_box


 if TYPE_CHECKING:
     from gradio.components import Component
-    from llmtuner.webui.engine import Engine
+
+    from ..engine import Engine


 def create_eval_tab(engine: "Engine") -> Dict[str, "Component"]:

@@ -30,9 +33,7 @@ def create_eval_tab(engine: "Engine") -> Dict[str, "Component"]:
         predict = gr.Checkbox(value=True)

     input_elems.update({cutoff_len, max_samples, batch_size, predict})
-    elem_dict.update(dict(
-        cutoff_len=cutoff_len, max_samples=max_samples, batch_size=batch_size, predict=predict
-    ))
+    elem_dict.update(dict(cutoff_len=cutoff_len, max_samples=max_samples, batch_size=batch_size, predict=predict))

     with gr.Row():
         max_new_tokens = gr.Slider(10, 2048, value=128, step=1)

@@ -41,9 +42,7 @@ def create_eval_tab(engine: "Engine") -> Dict[str, "Component"]:
         output_dir = gr.Textbox()

     input_elems.update({max_new_tokens, top_p, temperature, output_dir})
-    elem_dict.update(dict(
-        max_new_tokens=max_new_tokens, top_p=top_p, temperature=temperature, output_dir=output_dir
-    ))
+    elem_dict.update(dict(max_new_tokens=max_new_tokens, top_p=top_p, temperature=temperature, output_dir=output_dir))

     with gr.Row():
         cmd_preview_btn = gr.Button()

@@ -58,10 +57,16 @@ def create_eval_tab(engine: "Engine") -> Dict[str, "Component"]:
         output_box = gr.Markdown()

     output_elems = [output_box, process_bar]
-    elem_dict.update(dict(
-        cmd_preview_btn=cmd_preview_btn, start_btn=start_btn, stop_btn=stop_btn,
-        resume_btn=resume_btn, process_bar=process_bar, output_box=output_box
-    ))
+    elem_dict.update(
+        dict(
+            cmd_preview_btn=cmd_preview_btn,
+            start_btn=start_btn,
+            stop_btn=stop_btn,
+            resume_btn=resume_btn,
+            process_bar=process_bar,
+            output_box=output_box,
+        )
+    )

     cmd_preview_btn.click(engine.runner.preview_eval, input_elems, output_elems)
     start_btn.click(engine.runner.run_eval, input_elems, output_elems)
@@ -1,47 +1,68 @@
|
|||||||
import gradio as gr
|
|
||||||
from typing import TYPE_CHECKING, Dict, Generator, List
|
from typing import TYPE_CHECKING, Dict, Generator, List
|
||||||
|
|
||||||
from llmtuner.train import export_model
|
import gradio as gr
|
||||||
from llmtuner.webui.common import get_save_dir
|
|
||||||
from llmtuner.webui.locales import ALERTS
|
from ...train import export_model
|
||||||
|
from ..common import get_save_dir
|
||||||
|
+from ..locales import ALERTS


 if TYPE_CHECKING:
     from gradio.components import Component
-    from llmtuner.webui.engine import Engine
+
+    from ..engine import Engine


+GPTQ_BITS = ["8", "4", "3", "2"]
+
+
 def save_model(
     lang: str,
     model_name: str,
     model_path: str,
-    checkpoints: List[str],
+    adapter_path: List[str],
     finetuning_type: str,
     template: str,
     max_shard_size: int,
-    export_dir: str
+    export_quantization_bit: int,
+    export_quantization_dataset: str,
+    export_legacy_format: bool,
+    export_dir: str,
 ) -> Generator[str, None, None]:
     error = ""
     if not model_name:
         error = ALERTS["err_no_model"][lang]
     elif not model_path:
         error = ALERTS["err_no_path"][lang]
-    elif not checkpoints:
-        error = ALERTS["err_no_checkpoint"][lang]
     elif not export_dir:
         error = ALERTS["err_no_export_dir"][lang]
+    elif export_quantization_bit in GPTQ_BITS and not export_quantization_dataset:
+        error = ALERTS["err_no_dataset"][lang]
+    elif export_quantization_bit not in GPTQ_BITS and not adapter_path:
+        error = ALERTS["err_no_adapter"][lang]

     if error:
         gr.Warning(error)
         yield error
         return

+    if adapter_path:
+        adapter_name_or_path = ",".join(
+            [get_save_dir(model_name, finetuning_type, adapter) for adapter in adapter_path]
+        )
+    else:
+        adapter_name_or_path = None
+
     args = dict(
         model_name_or_path=model_path,
-        checkpoint_dir=",".join([get_save_dir(model_name, finetuning_type, ckpt) for ckpt in checkpoints]),
+        adapter_name_or_path=adapter_name_or_path,
         finetuning_type=finetuning_type,
         template=template,
         export_dir=export_dir,
-        export_size=max_shard_size
+        export_size=max_shard_size,
+        export_quantization_bit=int(export_quantization_bit) if export_quantization_bit in GPTQ_BITS else None,
+        export_quantization_dataset=export_quantization_dataset,
+        export_legacy_format=export_legacy_format,
     )

     yield ALERTS["info_exporting"][lang]
@@ -51,9 +72,12 @@ def save_model(

 def create_export_tab(engine: "Engine") -> Dict[str, "Component"]:
     with gr.Row():
-        export_dir = gr.Textbox()
         max_shard_size = gr.Slider(value=1, minimum=1, maximum=100)
+        export_quantization_bit = gr.Dropdown(choices=["none", "8", "4", "3", "2"], value="none")
+        export_quantization_dataset = gr.Textbox(value="data/c4_demo.json")
+        export_legacy_format = gr.Checkbox()

+    export_dir = gr.Textbox()
     export_btn = gr.Button()
     info_box = gr.Textbox(show_label=False, interactive=False)

@@ -63,18 +87,24 @@ def create_export_tab(engine: "Engine") -> Dict[str, "Component"]:
             engine.manager.get_elem_by_name("top.lang"),
             engine.manager.get_elem_by_name("top.model_name"),
             engine.manager.get_elem_by_name("top.model_path"),
-            engine.manager.get_elem_by_name("top.checkpoints"),
+            engine.manager.get_elem_by_name("top.adapter_path"),
             engine.manager.get_elem_by_name("top.finetuning_type"),
             engine.manager.get_elem_by_name("top.template"),
             max_shard_size,
-            export_dir
+            export_quantization_bit,
+            export_quantization_dataset,
+            export_legacy_format,
+            export_dir,
         ],
-        [info_box]
+        [info_box],
     )

     return dict(
-        export_dir=export_dir,
         max_shard_size=max_shard_size,
+        export_quantization_bit=export_quantization_bit,
+        export_quantization_dataset=export_quantization_dataset,
+        export_legacy_format=export_legacy_format,
+        export_dir=export_dir,
         export_btn=export_btn,
-        info_box=info_box
+        info_box=info_box,
     )
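The reworked export callback follows a common Gradio pattern: a generator function that validates its inputs up front, surfaces problems with `gr.Warning`, and then streams status strings into a read-only textbox. A minimal, self-contained sketch of that pattern is below; the `fake_export` helper and the component labels are illustrative stand-ins, not code from this diff.

```python
from typing import Generator

import gradio as gr


def fake_export(model_name: str, export_dir: str) -> Generator[str, None, None]:
    # Validate first; yield the error so it lands in the status box, then stop early.
    if not model_name:
        gr.Warning("Please provide a model name.")  # toast popup in recent Gradio versions
        yield "Error: model name is missing."
        return

    yield "Exporting..."  # shown immediately because the handler is a generator
    # ... the long-running export work would run here ...
    yield f"Done. Files written to {export_dir or './export'}."


with gr.Blocks() as demo:
    model_name = gr.Textbox(label="Model name")
    export_dir = gr.Textbox(label="Export dir")
    export_btn = gr.Button("Export")
    info_box = gr.Textbox(show_label=False, interactive=False)

    # Each yielded string replaces the content of info_box, giving live progress.
    export_btn.click(fake_export, [model_name, export_dir], [info_box])

# Generator handlers need the queue in Gradio 3.x.
demo.queue().launch()
```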
@@ -1,11 +1,14 @@
-import gradio as gr
 from typing import TYPE_CHECKING, Dict

-from llmtuner.webui.components.chatbot import create_chat_box
+import gradio as gr
+
+from .chatbot import create_chat_box
+

 if TYPE_CHECKING:
     from gradio.components import Component
-    from llmtuner.webui.engine import Engine
+
+    from ..engine import Engine


 def create_infer_tab(engine: "Engine") -> Dict[str, "Component"]:
@@ -22,18 +25,12 @@ def create_infer_tab(engine: "Engine") -> Dict[str, "Component"]:
     chat_box, chatbot, history, chat_elems = create_chat_box(engine, visible=False)
     elem_dict.update(dict(chat_box=chat_box, **chat_elems))

-    load_btn.click(
-        engine.chatter.load_model, input_elems, [info_box]
-    ).then(
+    load_btn.click(engine.chatter.load_model, input_elems, [info_box]).then(
         lambda: gr.update(visible=engine.chatter.loaded), outputs=[chat_box]
     )

-    unload_btn.click(
-        engine.chatter.unload_model, input_elems, [info_box]
-    ).then(
+    unload_btn.click(engine.chatter.unload_model, input_elems, [info_box]).then(
         lambda: ([], []), outputs=[chatbot, history]
-    ).then(
-        lambda: gr.update(visible=engine.chatter.loaded), outputs=[chat_box]
-    )
+    ).then(lambda: gr.update(visible=engine.chatter.loaded), outputs=[chat_box])

     return elem_dict
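The load/unload bindings above rely on Gradio's event chaining: `.click(...)` runs the model call, and each `.then(...)` step runs only after the previous one finishes, which is how the chat box is revealed or the history cleared afterwards. A hedged sketch of the same chaining with dummy callbacks (none of these names come from the diff):

```python
import gradio as gr

with gr.Blocks() as demo:
    info_box = gr.Textbox(interactive=False)
    load_btn = gr.Button("Load")
    unload_btn = gr.Button("Unload")
    history = gr.State([])
    with gr.Column(visible=False) as chat_box:
        chatbot = gr.Chatbot()

    # Step 1 updates the status box; step 2 toggles visibility once step 1 is done.
    load_btn.click(lambda: "Model loaded.", outputs=[info_box]).then(
        lambda: gr.update(visible=True), outputs=[chat_box]
    )

    # Unloading clears both the visible chat and the stored history, then hides the box.
    unload_btn.click(lambda: "Model unloaded.", outputs=[info_box]).then(
        lambda: ([], []), outputs=[chatbot, history]
    ).then(lambda: gr.update(visible=False), outputs=[chat_box])

demo.launch()
```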
@@ -1,10 +1,12 @@
-import gradio as gr
 from typing import TYPE_CHECKING, Dict

-from llmtuner.data.template import templates
-from llmtuner.extras.constants import METHODS, SUPPORTED_MODELS
-from llmtuner.webui.common import get_model_path, get_template, list_checkpoint, save_config
-from llmtuner.webui.utils import can_quantize
+import gradio as gr
+
+from ...data import templates
+from ...extras.constants import METHODS, SUPPORTED_MODELS
+from ..common import get_model_path, get_template, list_adapters, save_config
+from ..utils import can_quantize
+

 if TYPE_CHECKING:
     from gradio.components import Component
@@ -14,61 +16,44 @@ def create_top() -> Dict[str, "Component"]:
     available_models = list(SUPPORTED_MODELS.keys()) + ["Custom"]

     with gr.Row():
-        lang = gr.Dropdown(choices=["en", "zh"], scale=1)
+        lang = gr.Dropdown(choices=["en", "ru", "zh"], scale=1)
         model_name = gr.Dropdown(choices=available_models, scale=3)
         model_path = gr.Textbox(scale=3)

     with gr.Row():
         finetuning_type = gr.Dropdown(choices=METHODS, value="lora", scale=1)
-        checkpoints = gr.Dropdown(multiselect=True, scale=5)
+        adapter_path = gr.Dropdown(multiselect=True, scale=5, allow_custom_value=True)
         refresh_btn = gr.Button(scale=1)

     with gr.Accordion(label="Advanced config", open=False) as advanced_tab:
         with gr.Row():
-            quantization_bit = gr.Dropdown(choices=["none", "8", "4"], value="none", scale=1)
-            template = gr.Dropdown(choices=list(templates.keys()), value="default", scale=1)
-            system_prompt = gr.Textbox(scale=2)
-
-    with gr.Accordion(label="Model config (LLaMA only)", open=False) as llama_tab:
-        with gr.Row():
-            with gr.Column():
-                flash_attn = gr.Checkbox(value=False)
-                shift_attn = gr.Checkbox(value=False)
-            rope_scaling = gr.Radio(choices=["none", "linear", "dynamic"], value="none")
+            quantization_bit = gr.Dropdown(choices=["none", "8", "4"], value="none")
+            template = gr.Dropdown(choices=list(templates.keys()), value="default")
+            rope_scaling = gr.Radio(choices=["none", "linear", "dynamic"], value="none")
+            booster = gr.Radio(choices=["none", "flashattn", "unsloth"], value="none")

-    model_name.change(
-        list_checkpoint, [model_name, finetuning_type], [checkpoints], queue=False
-    ).then(
+    model_name.change(list_adapters, [model_name, finetuning_type], [adapter_path], queue=False).then(
         get_model_path, [model_name], [model_path], queue=False
-    ).then(
-        get_template, [model_name], [template], queue=False
-    )  # do not save config since the below line will save
+    ).then(get_template, [model_name], [template], queue=False)  # do not save config since the below line will save

     model_path.change(save_config, inputs=[lang, model_name, model_path], queue=False)

-    finetuning_type.change(
-        list_checkpoint, [model_name, finetuning_type], [checkpoints], queue=False
-    ).then(
+    finetuning_type.change(list_adapters, [model_name, finetuning_type], [adapter_path], queue=False).then(
         can_quantize, [finetuning_type], [quantization_bit], queue=False
     )

-    refresh_btn.click(
-        list_checkpoint, [model_name, finetuning_type], [checkpoints], queue=False
-    )
+    refresh_btn.click(list_adapters, [model_name, finetuning_type], [adapter_path], queue=False)

     return dict(
         lang=lang,
         model_name=model_name,
         model_path=model_path,
         finetuning_type=finetuning_type,
-        checkpoints=checkpoints,
+        adapter_path=adapter_path,
         refresh_btn=refresh_btn,
         advanced_tab=advanced_tab,
         quantization_bit=quantization_bit,
         template=template,
-        system_prompt=system_prompt,
-        llama_tab=llama_tab,
-        flash_attn=flash_attn,
-        shift_attn=shift_attn,
-        rope_scaling=rope_scaling
+        rope_scaling=rope_scaling,
+        booster=booster,
     )
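Most of the wiring in this component is the same cascade: changing `model_name` repopulates the adapter dropdown, then fills in the model path, then picks a template. Below is a small sketch of how a change event can rebuild a multiselect dropdown by returning `gr.update(...)`; the lookup tables and helper names are hypothetical, standing in for `SUPPORTED_MODELS` and the adapters saved on disk.

```python
from typing import Dict, List

import gradio as gr

# Hypothetical lookup tables; the real project reads these from its constants and save dirs.
MODEL_PATHS: Dict[str, str] = {"LLaMA-7B": "huggyllama/llama-7b"}
ADAPTERS: Dict[str, List[str]] = {"LLaMA-7B": ["sft-run-1", "dpo-run-2"]}


def list_adapters_demo(model_name: str) -> dict:
    # Returning gr.update(...) swaps the dropdown's choices and clears the current selection.
    return gr.update(choices=ADAPTERS.get(model_name, []), value=[])


def get_model_path_demo(model_name: str) -> str:
    return MODEL_PATHS.get(model_name, "")


with gr.Blocks() as demo:
    model_name = gr.Dropdown(choices=list(MODEL_PATHS.keys()) + ["Custom"], label="Model")
    model_path = gr.Textbox(label="Model path")
    adapter_path = gr.Dropdown(multiselect=True, allow_custom_value=True, label="Adapters")

    # One change event fans out into sequential updates, mirroring the chained
    # model_name.change(...).then(...) calls above.
    model_name.change(list_adapters_demo, [model_name], [adapter_path], queue=False).then(
        get_model_path_demo, [model_name], [model_path], queue=False
    )

demo.launch()
```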
@@ -1,15 +1,18 @@
-import gradio as gr
 from typing import TYPE_CHECKING, Dict

+import gradio as gr
 from transformers.trainer_utils import SchedulerType

-from llmtuner.extras.constants import TRAINING_STAGES
-from llmtuner.webui.common import list_checkpoint, list_dataset, DEFAULT_DATA_DIR
-from llmtuner.webui.components.data import create_preview_box
-from llmtuner.webui.utils import gen_plot
+from ...extras.constants import TRAINING_STAGES
+from ..common import DEFAULT_DATA_DIR, list_adapters, list_dataset
+from ..components.data import create_preview_box
+from ..utils import gen_plot


 if TYPE_CHECKING:
     from gradio.components import Component
-    from llmtuner.webui.engine import Engine
+
+    from ..engine import Engine


 def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
@@ -28,84 +31,143 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
     dataset_dir.change(list_dataset, [dataset_dir, training_stage], [dataset], queue=False)

     input_elems.update({training_stage, dataset_dir, dataset})
-    elem_dict.update(dict(
-        training_stage=training_stage, dataset_dir=dataset_dir, dataset=dataset, **preview_elems
-    ))
+    elem_dict.update(dict(training_stage=training_stage, dataset_dir=dataset_dir, dataset=dataset, **preview_elems))

     with gr.Row():
-        cutoff_len = gr.Slider(value=1024, minimum=4, maximum=8192, step=1)
+        cutoff_len = gr.Slider(value=1024, minimum=4, maximum=16384, step=1)
         learning_rate = gr.Textbox(value="5e-5")
         num_train_epochs = gr.Textbox(value="3.0")
         max_samples = gr.Textbox(value="100000")
-        compute_type = gr.Radio(choices=["fp16", "bf16"], value="fp16")
+        compute_type = gr.Radio(choices=["fp16", "bf16", "fp32"], value="fp16")

     input_elems.update({cutoff_len, learning_rate, num_train_epochs, max_samples, compute_type})
-    elem_dict.update(dict(
-        cutoff_len=cutoff_len, learning_rate=learning_rate, num_train_epochs=num_train_epochs,
-        max_samples=max_samples, compute_type=compute_type
-    ))
+    elem_dict.update(
+        dict(
+            cutoff_len=cutoff_len,
+            learning_rate=learning_rate,
+            num_train_epochs=num_train_epochs,
+            max_samples=max_samples,
+            compute_type=compute_type,
+        )
+    )

     with gr.Row():
-        batch_size = gr.Slider(value=4, minimum=1, maximum=512, step=1)
-        gradient_accumulation_steps = gr.Slider(value=4, minimum=1, maximum=512, step=1)
-        lr_scheduler_type = gr.Dropdown(
-            choices=[scheduler.value for scheduler in SchedulerType], value="cosine"
-        )
+        batch_size = gr.Slider(value=2, minimum=1, maximum=1024, step=1)
+        gradient_accumulation_steps = gr.Slider(value=8, minimum=1, maximum=1024, step=1)
+        lr_scheduler_type = gr.Dropdown(choices=[scheduler.value for scheduler in SchedulerType], value="cosine")
         max_grad_norm = gr.Textbox(value="1.0")
         val_size = gr.Slider(value=0, minimum=0, maximum=1, step=0.001)

     input_elems.update({batch_size, gradient_accumulation_steps, lr_scheduler_type, max_grad_norm, val_size})
-    elem_dict.update(dict(
-        batch_size=batch_size, gradient_accumulation_steps=gradient_accumulation_steps,
-        lr_scheduler_type=lr_scheduler_type, max_grad_norm=max_grad_norm, val_size=val_size
-    ))
+    elem_dict.update(
+        dict(
+            batch_size=batch_size,
+            gradient_accumulation_steps=gradient_accumulation_steps,
+            lr_scheduler_type=lr_scheduler_type,
+            max_grad_norm=max_grad_norm,
+            val_size=val_size,
+        )
+    )

-    with gr.Accordion(label="Advanced config", open=False) as advanced_tab:
+    with gr.Accordion(label="Extra config", open=False) as extra_tab:
         with gr.Row():
             logging_steps = gr.Slider(value=5, minimum=5, maximum=1000, step=5)
             save_steps = gr.Slider(value=100, minimum=10, maximum=5000, step=10)
             warmup_steps = gr.Slider(value=0, minimum=0, maximum=5000, step=1)
-            neft_alpha = gr.Slider(value=0, minimum=0, maximum=10, step=0.1)
+            neftune_alpha = gr.Slider(value=0, minimum=0, maximum=10, step=0.1)

-        with gr.Column():
-            train_on_prompt = gr.Checkbox(value=False)
-            upcast_layernorm = gr.Checkbox(value=False)
+        with gr.Row():
+            resize_vocab = gr.Checkbox()
+            sft_packing = gr.Checkbox()
+            upcast_layernorm = gr.Checkbox()
+            use_llama_pro = gr.Checkbox()

-    input_elems.update({logging_steps, save_steps, warmup_steps, neft_alpha, train_on_prompt, upcast_layernorm})
-    elem_dict.update(dict(
-        advanced_tab=advanced_tab, logging_steps=logging_steps, save_steps=save_steps, warmup_steps=warmup_steps,
-        neft_alpha=neft_alpha, train_on_prompt=train_on_prompt, upcast_layernorm=upcast_layernorm
-    ))
+    input_elems.update(
+        {
+            logging_steps,
+            save_steps,
+            warmup_steps,
+            neftune_alpha,
+            resize_vocab,
+            sft_packing,
+            upcast_layernorm,
+            use_llama_pro,
+        }
+    )
+    elem_dict.update(
+        dict(
+            extra_tab=extra_tab,
+            logging_steps=logging_steps,
+            save_steps=save_steps,
+            warmup_steps=warmup_steps,
+            neftune_alpha=neftune_alpha,
+            resize_vocab=resize_vocab,
+            sft_packing=sft_packing,
+            upcast_layernorm=upcast_layernorm,
+            use_llama_pro=use_llama_pro,
+        )
+    )
+
+    with gr.Accordion(label="Freeze config", open=False) as freeze_tab:
+        with gr.Row():
+            num_layer_trainable = gr.Slider(value=3, minimum=1, maximum=128, step=1, scale=2)
+            name_module_trainable = gr.Textbox(scale=3)
+
+    input_elems.update({num_layer_trainable, name_module_trainable})
+    elem_dict.update(
+        dict(
+            freeze_tab=freeze_tab, num_layer_trainable=num_layer_trainable, name_module_trainable=name_module_trainable
+        )
+    )

     with gr.Accordion(label="LoRA config", open=False) as lora_tab:
         with gr.Row():
             lora_rank = gr.Slider(value=8, minimum=1, maximum=1024, step=1, scale=1)
+            lora_alpha = gr.Slider(value=16, minimum=1, maximum=2048, step=0.1, scale=1)
             lora_dropout = gr.Slider(value=0.1, minimum=0, maximum=1, step=0.01, scale=1)
-            lora_target = gr.Textbox(scale=1)
-            additional_target = gr.Textbox(scale=1)
-            resume_lora_training = gr.Checkbox(value=True, scale=1)
+            lora_target = gr.Textbox(scale=2)

-    input_elems.update({lora_rank, lora_dropout, lora_target, additional_target, resume_lora_training})
-    elem_dict.update(dict(
-        lora_tab=lora_tab, lora_rank=lora_rank, lora_dropout=lora_dropout, lora_target=lora_target,
-        additional_target=additional_target, resume_lora_training=resume_lora_training,
-    ))
+        with gr.Row():
+            use_rslora = gr.Checkbox(scale=1)
+            use_dora = gr.Checkbox(scale=1)
+            create_new_adapter = gr.Checkbox(scale=1)
+            additional_target = gr.Textbox(scale=2)
+
+    input_elems.update(
+        {lora_rank, lora_alpha, lora_dropout, lora_target, use_rslora, use_dora, create_new_adapter, additional_target}
+    )
+    elem_dict.update(
+        dict(
+            lora_tab=lora_tab,
+            lora_rank=lora_rank,
+            lora_alpha=lora_alpha,
+            lora_dropout=lora_dropout,
+            lora_target=lora_target,
+            use_rslora=use_rslora,
+            use_dora=use_dora,
+            create_new_adapter=create_new_adapter,
+            additional_target=additional_target,
+        )
+    )

     with gr.Accordion(label="RLHF config", open=False) as rlhf_tab:
         with gr.Row():
             dpo_beta = gr.Slider(value=0.1, minimum=0, maximum=1, step=0.01, scale=1)
-            reward_model = gr.Dropdown(scale=3)
+            dpo_ftx = gr.Slider(value=0, minimum=0, maximum=10, step=0.01, scale=1)
+            reward_model = gr.Dropdown(scale=2, allow_custom_value=True)
             refresh_btn = gr.Button(scale=1)

         refresh_btn.click(
-            list_checkpoint,
+            list_adapters,
             [engine.manager.get_elem_by_name("top.model_name"), engine.manager.get_elem_by_name("top.finetuning_type")],
             [reward_model],
-            queue=False
+            queue=False,
         )

-    input_elems.update({dpo_beta, reward_model})
-    elem_dict.update(dict(rlhf_tab=rlhf_tab, dpo_beta=dpo_beta, reward_model=reward_model, refresh_btn=refresh_btn))
+    input_elems.update({dpo_beta, dpo_ftx, reward_model})
+    elem_dict.update(
+        dict(rlhf_tab=rlhf_tab, dpo_beta=dpo_beta, dpo_ftx=dpo_ftx, reward_model=reward_model, refresh_btn=refresh_btn)
+    )

     with gr.Row():
         cmd_preview_btn = gr.Button()
@@ -118,7 +180,7 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
         output_dir = gr.Textbox()

     with gr.Row():
-        resume_btn = gr.Checkbox(visible=False, interactive=False, value=False)
+        resume_btn = gr.Checkbox(visible=False, interactive=False)
         process_bar = gr.Slider(visible=False, interactive=False)

     with gr.Box():
@@ -135,20 +197,28 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
     stop_btn.click(engine.runner.set_abort, queue=False)
     resume_btn.change(engine.runner.monitor, outputs=output_elems)

-    elem_dict.update(dict(
-        cmd_preview_btn=cmd_preview_btn, start_btn=start_btn, stop_btn=stop_btn, output_dir=output_dir,
-        resume_btn=resume_btn, process_bar=process_bar, output_box=output_box, loss_viewer=loss_viewer
-    ))
+    elem_dict.update(
+        dict(
+            cmd_preview_btn=cmd_preview_btn,
+            start_btn=start_btn,
+            stop_btn=stop_btn,
+            output_dir=output_dir,
+            resume_btn=resume_btn,
+            process_bar=process_bar,
+            output_box=output_box,
+            loss_viewer=loss_viewer,
+        )
+    )

     output_box.change(
         gen_plot,
         [
             engine.manager.get_elem_by_name("top.model_name"),
             engine.manager.get_elem_by_name("top.finetuning_type"),
-            output_dir
+            output_dir,
         ],
         loss_viewer,
-        queue=False
+        queue=False,
     )

     return elem_dict
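Throughout this tab, every new control is added both to `input_elems` (a set) and to `elem_dict` (a name-to-component map). The set matters because Gradio treats a set passed as `inputs` specially: the callback receives a single dictionary keyed by component, so the runner can read every training option from one button click without listing arguments positionally. A minimal sketch of that mechanism, with made-up components and a toy `preview` callback:

```python
import gradio as gr

with gr.Blocks() as demo:
    input_elems = set()

    learning_rate = gr.Textbox(value="5e-5", label="learning_rate")
    batch_size = gr.Slider(value=2, minimum=1, maximum=1024, step=1, label="batch_size")
    input_elems.update({learning_rate, batch_size})

    cmd_preview_btn = gr.Button("Preview command")
    output_box = gr.Textbox(interactive=False)

    def preview(data: dict) -> str:
        # Because `inputs` is a set, Gradio passes one dict mapping component -> current value,
        # so values are looked up by component object rather than by argument position.
        return "--learning_rate {} --per_device_train_batch_size {}".format(
            data[learning_rate], data[batch_size]
        )

    cmd_preview_btn.click(preview, input_elems, [output_box])

demo.launch()
```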