Compare commits
267 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
ab67528e89 | ||
|
|
27f281480a | ||
|
|
50459a39f4 | ||
|
|
5c9815ef6f | ||
|
|
aed00a97b6 | ||
|
|
7543dc4a9d | ||
|
|
841fa0030f | ||
|
|
66e0e651b9 | ||
|
|
1750218057 | ||
|
|
80637fc06d | ||
|
|
8efc055511 | ||
|
|
be61bfda93 | ||
|
|
1a39f529c0 | ||
|
|
0868d5c550 | ||
|
|
384f0e7678 | ||
|
|
9b390c4bea | ||
|
|
42a13fec46 | ||
|
|
790acc4c17 | ||
|
|
b74cf27538 | ||
|
|
ffc874ec6f | ||
|
|
546d6bd0b2 | ||
|
|
8b68ca029e | ||
|
|
502f84b30c | ||
|
|
b7df920860 | ||
|
|
e4a424cb6a | ||
|
|
d8affd3967 | ||
|
|
a423274fd9 | ||
|
|
f7329b1a0e | ||
|
|
48eb07c956 | ||
|
|
636d8a886c | ||
|
|
97b52c7fdf | ||
|
|
344412e66e | ||
|
|
5cdea14cdf | ||
|
|
7b1a56b96f | ||
|
|
d1ec884e75 | ||
|
|
aa72a4349e | ||
|
|
5ab7fd0842 | ||
|
|
86d5e9802a | ||
|
|
18df39e3a1 | ||
|
|
cfe1e24471 | ||
|
|
2edbe87a8c | ||
|
|
880055bc90 | ||
|
|
ad99bd0a14 | ||
|
|
c5f099138d | ||
|
|
6e64e02f71 | ||
|
|
f95f6ec009 | ||
|
|
8aeecc20e1 | ||
|
|
38d0f6c63f | ||
|
|
ac8534a9e7 | ||
|
|
73cab9d9d4 | ||
|
|
64246d42d2 | ||
|
|
6fa6d4532e | ||
|
|
92b9956c06 | ||
|
|
4d6669c268 | ||
|
|
89f4ae51f9 | ||
|
|
af0659f573 | ||
|
|
45a10d501e | ||
|
|
e529ff1245 | ||
|
|
b29371dc87 | ||
|
|
0bef890000 | ||
|
|
75fe1404b1 | ||
|
|
b460c9372f | ||
|
|
c3e574ceaa | ||
|
|
04ae80a52e | ||
|
|
a7ff095399 | ||
|
|
a655dcebaf | ||
|
|
8c74851b70 | ||
|
|
7168392a51 | ||
|
|
ccc5b324fe | ||
|
|
e85c205a81 | ||
|
|
7e225be16e | ||
|
|
ebb32e85f8 | ||
|
|
90d279f39f | ||
|
|
af3f5b6e16 | ||
|
|
53d7c5109f | ||
|
|
bf381563ff | ||
|
|
de4b9334e1 | ||
|
|
c33fbea469 | ||
|
|
921f593632 | ||
|
|
940403720a | ||
|
|
f869e44fe5 | ||
|
|
bcc92919a0 | ||
|
|
306a70c7ba | ||
|
|
d358d955e5 | ||
|
|
0fdd6074c3 | ||
|
|
6faf9c35a9 | ||
|
|
1066898e32 | ||
|
|
d05febe5de | ||
|
|
67f7034a21 | ||
|
|
79f301a2c6 | ||
|
|
31cbc67986 | ||
|
|
fe66bf3663 | ||
|
|
4691d4b35d | ||
|
|
acf5241845 | ||
|
|
2bce99b82f | ||
|
|
3c330869ef | ||
|
|
dba1af4841 | ||
|
|
2b1e52dcc9 | ||
|
|
b5238e945a | ||
|
|
afc0f29704 | ||
|
|
de0bb1d2da | ||
|
|
cc16ece283 | ||
|
|
31ba802fc9 | ||
|
|
4b27cf5460 | ||
|
|
a53b2a643f | ||
|
|
d925ecae1b | ||
|
|
13fd751a78 | ||
|
|
74575f8922 | ||
|
|
5e7bb5fe73 | ||
|
|
790a31404a | ||
|
|
f927601702 | ||
|
|
c4654d54d7 | ||
|
|
df777c30d1 | ||
|
|
d81ad2d4bc | ||
|
|
9f77e8b025 | ||
|
|
04dc3f4614 | ||
|
|
7d1fe50977 | ||
|
|
c0e5e3c5d5 | ||
|
|
3a45cfb604 | ||
|
|
393e4b0f5a | ||
|
|
296711d502 | ||
|
|
9121722999 | ||
|
|
d8d74091f6 | ||
|
|
33521fb45e | ||
|
|
e5204e60ed | ||
|
|
0409428d87 | ||
|
|
f902b0d420 | ||
|
|
27ef5b1aa7 | ||
|
|
c32303fc7e | ||
|
|
45abe361ba | ||
|
|
3ae479faae | ||
|
|
5698038f49 | ||
|
|
020233f725 | ||
|
|
6f9d55b8eb | ||
|
|
2542b62d77 | ||
|
|
95678bb6b1 | ||
|
|
a78759e7ee | ||
|
|
cc5c523f58 | ||
|
|
e39bbdd287 | ||
|
|
d9a50bf93f | ||
|
|
934d00ea1e | ||
|
|
c27675f70d | ||
|
|
7c9f37c83d | ||
|
|
b9736c13e0 | ||
|
|
c47725ff34 | ||
|
|
3ee3fe0bbb | ||
|
|
e54dad75da | ||
|
|
39c2f03eab | ||
|
|
fb9e1c4087 | ||
|
|
ed26bb3d82 | ||
|
|
0baf32e219 | ||
|
|
79a376d1db | ||
|
|
b634e91c43 | ||
|
|
9e2cc21d04 | ||
|
|
6975124a57 | ||
|
|
9f69307db1 | ||
|
|
c3448a045c | ||
|
|
95c561983c | ||
|
|
7a03c8dab5 | ||
|
|
f3ffa8310f | ||
|
|
596f496f19 | ||
|
|
2e6ed731cf | ||
|
|
24ce319b6f | ||
|
|
7b7bfea37d | ||
|
|
3be461260a | ||
|
|
8dab8d9831 | ||
|
|
fb4c5f3c91 | ||
|
|
5fe3cce5a3 | ||
|
|
09f165d442 | ||
|
|
60aea7521b | ||
|
|
29545d0e5e | ||
|
|
4a14099cfd | ||
|
|
b052574ddf | ||
|
|
5ea6a7c6d6 | ||
|
|
8ca196d51f | ||
|
|
5f572cbd77 | ||
|
|
679bd3ab30 | ||
|
|
da3d59fada | ||
|
|
835d27151d | ||
|
|
f1d7228a74 | ||
|
|
72bbd5bdef | ||
|
|
ad9d866547 | ||
|
|
a1ec668b70 | ||
|
|
389687a56d | ||
|
|
97280c73b9 | ||
|
|
f3c622b665 | ||
|
|
d71e8d8dbf | ||
|
|
02c2089ac8 | ||
|
|
07ad28a053 | ||
|
|
d323ccc3ec | ||
|
|
4738d002c7 | ||
|
|
ec099b0586 | ||
|
|
a51253fea2 | ||
|
|
304ec9ec6a | ||
|
|
8547085615 | ||
|
|
14b139ecb5 | ||
|
|
7b45f5068f | ||
|
|
99ceee840e | ||
|
|
8ed68301e3 | ||
|
|
664267e050 | ||
|
|
7ef8f46591 | ||
|
|
6933c1fed2 | ||
|
|
9d125bf533 | ||
|
|
08d5340bd8 | ||
|
|
0e6f4f981e | ||
|
|
670ee3934f | ||
|
|
569860d7ac | ||
|
|
953a562ec1 | ||
|
|
7f54008d3c | ||
|
|
5f5959bc33 | ||
|
|
0105cd48f2 | ||
|
|
28258aecd2 | ||
|
|
e585950c54 | ||
|
|
bcd661afa6 | ||
|
|
adf2730d1d | ||
|
|
ba2be6371d | ||
|
|
d2ff09a404 | ||
|
|
9f364d3880 | ||
|
|
cfad41b901 | ||
|
|
6889f044fb | ||
|
|
3d1ee27ccd | ||
|
|
775ce62950 | ||
|
|
821a6f2fa6 | ||
|
|
5197fb2fad | ||
|
|
92abe91d22 | ||
|
|
a7bf0b85d7 | ||
|
|
5ce5ea84a9 | ||
|
|
992be39f90 | ||
|
|
cab80a3c56 | ||
|
|
6af7107938 | ||
|
|
bcd31cf245 | ||
|
|
85c4ccfef9 | ||
|
|
dc0f81aabc | ||
|
|
07f934566a | ||
|
|
77cb18e9e3 | ||
|
|
fccaecf730 | ||
|
|
53cdfe8f73 | ||
|
|
ea03523c6a | ||
|
|
caf3cbf8d7 | ||
|
|
da411066c9 | ||
|
|
95d0f77fc2 | ||
|
|
9b2654277b | ||
|
|
f1b3bdac3f | ||
|
|
595fdbd95d | ||
|
|
dab9385297 | ||
|
|
df83def566 | ||
|
|
f9d4e37b3c | ||
|
|
e59a3d71e0 | ||
|
|
de3a84ac59 | ||
|
|
e017266b98 | ||
|
|
f81a8a5e5c | ||
|
|
7a3a0144a5 | ||
|
|
8263b2d32d | ||
|
|
833cd490b8 | ||
|
|
2162c37e41 | ||
|
|
b2ac8376e1 | ||
|
|
8079584143 | ||
|
|
09a4474e7f | ||
|
|
81530133ff | ||
|
|
cc4b384ac3 | ||
|
|
3852daf447 | ||
|
|
5c97111f9d | ||
|
|
75dd1f0f7e | ||
|
|
c9a4551012 | ||
|
|
87197ba91d | ||
|
|
7461bf84e5 | ||
|
|
fbc0357b2e |
58
.github/ISSUE_TEMPLATE/bug-report.yml
vendored
Normal file
58
.github/ISSUE_TEMPLATE/bug-report.yml
vendored
Normal file
@@ -0,0 +1,58 @@
|
||||
name: "\U0001F41B Bug / Help"
|
||||
description: Create a report to help us improve the LLaMA Factory
|
||||
body:
|
||||
- type: checkboxes
|
||||
id: reminder
|
||||
attributes:
|
||||
label: Reminder
|
||||
description: |
|
||||
Please ensure you have read the README carefully and searched the existing issues.
|
||||
请确保您已经认真阅读了 README 并且搜索过现有的 Issue。
|
||||
|
||||
options:
|
||||
- label: I have read the README and searched the existing issues.
|
||||
required: true
|
||||
|
||||
- type: textarea
|
||||
id: reproduction
|
||||
validations:
|
||||
required: true
|
||||
attributes:
|
||||
label: Reproduction
|
||||
description: |
|
||||
Please provide code snippets, error messages and stack traces that reproduces the problem.
|
||||
请提供运行参数,错误信息以及异常堆栈以便于我们复现该问题。
|
||||
Remember to use Markdown tags to correctly format your code.
|
||||
请合理使用 Markdown 标签来格式化您的文本。
|
||||
|
||||
placeholder: |
|
||||
python src/train_bash.py ...
|
||||
|
||||
- type: textarea
|
||||
id: expected-behavior
|
||||
validations:
|
||||
required: false
|
||||
attributes:
|
||||
label: Expected behavior
|
||||
description: |
|
||||
Please provide a clear and concise description of what you would expect to happen.
|
||||
请提供您原本的目的,即这段代码的期望行为。
|
||||
|
||||
- type: textarea
|
||||
id: system-info
|
||||
validations:
|
||||
required: false
|
||||
attributes:
|
||||
label: System Info
|
||||
description: |
|
||||
Please share your system info with us. You can run the command **transformers-cli env** and copy-paste its output below.
|
||||
请提供您的系统信息。您可以在命令行运行 **transformers-cli env** 并将其输出复制到该文本框中。
|
||||
|
||||
placeholder: transformers version, platform, python version, ...
|
||||
|
||||
- type: textarea
|
||||
id: others
|
||||
validations:
|
||||
required: false
|
||||
attributes:
|
||||
label: Others
|
||||
7
.gitignore
vendored
7
.gitignore
vendored
@@ -157,4 +157,9 @@ cython_debug/
|
||||
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
|
||||
# and can be added to the global gitignore or merged into this file. For a more nuclear
|
||||
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
|
||||
#.idea/
|
||||
.idea/
|
||||
|
||||
# custom .gitignore
|
||||
user.config
|
||||
saves/
|
||||
cache/
|
||||
|
||||
128
CODE_OF_CONDUCT.md
Normal file
128
CODE_OF_CONDUCT.md
Normal file
@@ -0,0 +1,128 @@
|
||||
# Contributor Covenant Code of Conduct
|
||||
|
||||
## Our Pledge
|
||||
|
||||
We as members, contributors, and leaders pledge to make participation in our
|
||||
community a harassment-free experience for everyone, regardless of age, body
|
||||
size, visible or invisible disability, ethnicity, sex characteristics, gender
|
||||
identity and expression, level of experience, education, socio-economic status,
|
||||
nationality, personal appearance, race, religion, or sexual identity
|
||||
and orientation.
|
||||
|
||||
We pledge to act and interact in ways that contribute to an open, welcoming,
|
||||
diverse, inclusive, and healthy community.
|
||||
|
||||
## Our Standards
|
||||
|
||||
Examples of behavior that contributes to a positive environment for our
|
||||
community include:
|
||||
|
||||
* Demonstrating empathy and kindness toward other people
|
||||
* Being respectful of differing opinions, viewpoints, and experiences
|
||||
* Giving and gracefully accepting constructive feedback
|
||||
* Accepting responsibility and apologizing to those affected by our mistakes,
|
||||
and learning from the experience
|
||||
* Focusing on what is best not just for us as individuals, but for the
|
||||
overall community
|
||||
|
||||
Examples of unacceptable behavior include:
|
||||
|
||||
* The use of sexualized language or imagery, and sexual attention or
|
||||
advances of any kind
|
||||
* Trolling, insulting or derogatory comments, and personal or political attacks
|
||||
* Public or private harassment
|
||||
* Publishing others' private information, such as a physical or email
|
||||
address, without their explicit permission
|
||||
* Other conduct which could reasonably be considered inappropriate in a
|
||||
professional setting
|
||||
|
||||
## Enforcement Responsibilities
|
||||
|
||||
Community leaders are responsible for clarifying and enforcing our standards of
|
||||
acceptable behavior and will take appropriate and fair corrective action in
|
||||
response to any behavior that they deem inappropriate, threatening, offensive,
|
||||
or harmful.
|
||||
|
||||
Community leaders have the right and responsibility to remove, edit, or reject
|
||||
comments, commits, code, wiki edits, issues, and other contributions that are
|
||||
not aligned to this Code of Conduct, and will communicate reasons for moderation
|
||||
decisions when appropriate.
|
||||
|
||||
## Scope
|
||||
|
||||
This Code of Conduct applies within all community spaces, and also applies when
|
||||
an individual is officially representing the community in public spaces.
|
||||
Examples of representing our community include using an official e-mail address,
|
||||
posting via an official social media account, or acting as an appointed
|
||||
representative at an online or offline event.
|
||||
|
||||
## Enforcement
|
||||
|
||||
Instances of abusive, harassing, or otherwise unacceptable behavior may be
|
||||
reported to the community leaders responsible for enforcement at
|
||||
`hoshihiyouga AT gmail DOT com`.
|
||||
All complaints will be reviewed and investigated promptly and fairly.
|
||||
|
||||
All community leaders are obligated to respect the privacy and security of the
|
||||
reporter of any incident.
|
||||
|
||||
## Enforcement Guidelines
|
||||
|
||||
Community leaders will follow these Community Impact Guidelines in determining
|
||||
the consequences for any action they deem in violation of this Code of Conduct:
|
||||
|
||||
### 1. Correction
|
||||
|
||||
**Community Impact**: Use of inappropriate language or other behavior deemed
|
||||
unprofessional or unwelcome in the community.
|
||||
|
||||
**Consequence**: A private, written warning from community leaders, providing
|
||||
clarity around the nature of the violation and an explanation of why the
|
||||
behavior was inappropriate. A public apology may be requested.
|
||||
|
||||
### 2. Warning
|
||||
|
||||
**Community Impact**: A violation through a single incident or series
|
||||
of actions.
|
||||
|
||||
**Consequence**: A warning with consequences for continued behavior. No
|
||||
interaction with the people involved, including unsolicited interaction with
|
||||
those enforcing the Code of Conduct, for a specified period of time. This
|
||||
includes avoiding interactions in community spaces as well as external channels
|
||||
like social media. Violating these terms may lead to a temporary or
|
||||
permanent ban.
|
||||
|
||||
### 3. Temporary Ban
|
||||
|
||||
**Community Impact**: A serious violation of community standards, including
|
||||
sustained inappropriate behavior.
|
||||
|
||||
**Consequence**: A temporary ban from any sort of interaction or public
|
||||
communication with the community for a specified period of time. No public or
|
||||
private interaction with the people involved, including unsolicited interaction
|
||||
with those enforcing the Code of Conduct, is allowed during this period.
|
||||
Violating these terms may lead to a permanent ban.
|
||||
|
||||
### 4. Permanent Ban
|
||||
|
||||
**Community Impact**: Demonstrating a pattern of violation of community
|
||||
standards, including sustained inappropriate behavior, harassment of an
|
||||
individual, or aggression toward or disparagement of classes of individuals.
|
||||
|
||||
**Consequence**: A permanent ban from any sort of public interaction within
|
||||
the community.
|
||||
|
||||
## Attribution
|
||||
|
||||
This Code of Conduct is adapted from the [Contributor Covenant][homepage],
|
||||
version 2.0, available at
|
||||
https://www.contributor-covenant.org/version/2/0/code_of_conduct.html.
|
||||
|
||||
Community Impact Guidelines were inspired by [Mozilla's code of conduct
|
||||
enforcement ladder](https://github.com/mozilla/diversity).
|
||||
|
||||
[homepage]: https://www.contributor-covenant.org
|
||||
|
||||
For answers to common questions about this code of conduct, see the FAQ at
|
||||
https://www.contributor-covenant.org/faq. Translations are available at
|
||||
https://www.contributor-covenant.org/translations.
|
||||
11
Makefile
Normal file
11
Makefile
Normal file
@@ -0,0 +1,11 @@
|
||||
.PHONY: quality style
|
||||
|
||||
check_dirs := src tests
|
||||
|
||||
quality:
|
||||
black --check $(check_dirs)
|
||||
ruff $(check_dirs)
|
||||
|
||||
style:
|
||||
black $(check_dirs)
|
||||
ruff $(check_dirs) --fix
|
||||
198
README.md
198
README.md
@@ -1,4 +1,4 @@
|
||||
# LLaMA Factory: Training and Evaluating Large Language Models with Minimal Effort
|
||||

|
||||
|
||||
[](https://github.com/hiyouga/LLaMA-Factory/stargazers)
|
||||
[](LICENSE)
|
||||
@@ -6,7 +6,9 @@
|
||||
[](https://pypi.org/project/llmtuner/)
|
||||
[](https://pypi.org/project/llmtuner/)
|
||||
[](https://github.com/hiyouga/LLaMA-Factory/pulls)
|
||||
[](https://discord.gg/e73gccsSd)
|
||||
[](https://discord.gg/rKfvV9r9FK)
|
||||
[](https://huggingface.co/spaces/hiyouga/LLaMA-Board)
|
||||
[](https://modelscope.cn/studios/hiyouga/LLaMA-Board)
|
||||
|
||||
👋 Join our [WeChat](assets/wechat.jpg).
|
||||
|
||||
@@ -14,21 +16,62 @@
|
||||
|
||||
## LLaMA Board: A One-stop Web UI for Getting Started with LLaMA Factory
|
||||
|
||||
Launch **LLaMA Board** via `CUDA_VISIBLE_DEVICES=0 python src/train_web.py`. (multiple GPUs are not supported yet)
|
||||
Preview LLaMA Board at **[🤗 Spaces](https://huggingface.co/spaces/hiyouga/LLaMA-Board)** or **[ModelScope](https://modelscope.cn/studios/hiyouga/LLaMA-Board)**.
|
||||
|
||||
Launch LLaMA Board via `CUDA_VISIBLE_DEVICES=0 python src/train_web.py`. (multiple GPUs are not supported yet in this mode)
|
||||
|
||||
Here is an example of altering the self-cognition of an instruction-tuned language model within 10 minutes on a single GPU.
|
||||
|
||||
https://github.com/hiyouga/LLaMA-Factory/assets/16256802/6ba60acc-e2e2-4bec-b846-2d88920d5ba1
|
||||
|
||||
## Table of Contents
|
||||
|
||||
- [Benchmark](#benchmark)
|
||||
- [Changelog](#changelog)
|
||||
- [Supported Models](#supported-models)
|
||||
- [Supported Training Approaches](#supported-training-approaches)
|
||||
- [Provided Datasets](#provided-datasets)
|
||||
- [Requirement](#requirement)
|
||||
- [Getting Started](#getting-started)
|
||||
- [Projects using LLaMA Factory](#projects-using-llama-factory)
|
||||
- [License](#license)
|
||||
- [Citation](#citation)
|
||||
- [Acknowledgement](#acknowledgement)
|
||||
|
||||
## Benchmark
|
||||
|
||||
Compared to ChatGLM's [P-Tuning](https://github.com/THUDM/ChatGLM2-6B/tree/main/ptuning), LLaMA-Factory's LoRA tuning offers up to **3.7 times faster** training speed with a better Rouge score on the advertising text generation task. By leveraging 4-bit quantization technique, LLaMA-Factory's QLoRA further improves the efficiency regarding the GPU memory.
|
||||
|
||||

|
||||
|
||||
<details><summary>Definitions</summary>
|
||||
|
||||
- **Training Speed**: the number of training samples processed per second during the training. (bs=4, cutoff_len=1024)
|
||||
- **Rouge Score**: Rouge-2 score on the development set of the [advertising text generation](https://aclanthology.org/D19-1321.pdf) task. (bs=4, cutoff_len=1024)
|
||||
- **GPU Memory**: Peak GPU memory usage in 4-bit quantized training. (bs=1, cutoff_len=1024)
|
||||
- We adopt `pre_seq_len=128` for ChatGLM's P-Tuning and `lora_rank=32` for LLaMA-Factory's LoRA tuning.
|
||||
|
||||
</details>
|
||||
|
||||
## Changelog
|
||||
|
||||
[23/10/21] We supported **[NEFTune](https://arxiv.org/abs/2310.05914)** trick for fine-tuning. Try `--neft_alpha` argument to activate NEFTune, e.g., `--neft_alpha 5`.
|
||||
[24/01/18] We supported **agent tuning** for most models, equipping model with tool using abilities by fine-tuning with `--dataset glaive_toolcall`.
|
||||
|
||||
[23/12/23] We supported **[unsloth](https://github.com/unslothai/unsloth)**'s implementation to boost LoRA tuning for the LLaMA, Mistral and Yi models. Try `--use_unsloth` argument to activate unsloth patch. It achieves 1.7x speed in our benchmark, check [this page](https://github.com/hiyouga/LLaMA-Factory/wiki/Performance-comparison) for details.
|
||||
|
||||
[23/12/12] We supported fine-tuning the latest MoE model **[Mixtral 8x7B](https://huggingface.co/mistralai/Mixtral-8x7B-v0.1)** in our framework. See hardware requirement [here](#hardware-requirement).
|
||||
|
||||
<details><summary>Full Changelog</summary>
|
||||
|
||||
[23/12/01] We supported downloading pre-trained models and datasets from the **[ModelScope Hub](https://modelscope.cn/models)** for Chinese mainland users. See [this tutorial](#use-modelscope-hub-optional) for usage.
|
||||
|
||||
[23/10/21] We supported **[NEFTune](https://arxiv.org/abs/2310.05914)** trick for fine-tuning. Try `--neftune_noise_alpha` argument to activate NEFTune, e.g., `--neftune_noise_alpha 5`.
|
||||
|
||||
[23/09/27] We supported **$S^2$-Attn** proposed by [LongLoRA](https://github.com/dvlab-research/LongLoRA) for the LLaMA models. Try `--shift_attn` argument to enable shift short attention.
|
||||
|
||||
[23/09/23] We integrated MMLU, C-Eval and CMMLU benchmarks in this repo. See [this example](#evaluation) to evaluate your models.
|
||||
|
||||
[23/09/10] We supported using **[FlashAttention-2](https://github.com/Dao-AILab/flash-attention)** for the LLaMA models. Try `--flash_attn` argument to enable FlashAttention-2 if you are using RTX4090, A100 or H100 GPUs.
|
||||
[23/09/10] We supported **[FlashAttention-2](https://github.com/Dao-AILab/flash-attention)**. Try `--flash_attn` argument to enable FlashAttention-2 if you are using RTX4090, A100 or H100 GPUs.
|
||||
|
||||
[23/08/12] We supported **RoPE scaling** to extend the context length of the LLaMA models. Try `--rope_scaling linear` argument in training and `--rope_scaling dynamic` argument at inference to extrapolate the position embeddings.
|
||||
|
||||
@@ -48,30 +91,35 @@ https://github.com/hiyouga/LLaMA-Factory/assets/16256802/6ba60acc-e2e2-4bec-b846
|
||||
|
||||
[23/06/03] We supported quantized training and inference (aka **[QLoRA](https://github.com/artidoro/qlora)**). Try `--quantization_bit 4/8` argument to work with quantized models.
|
||||
|
||||
</details>
|
||||
|
||||
## Supported Models
|
||||
|
||||
| Model | Model size | Default module | Template |
|
||||
| -------------------------------------------------------- | --------------------------- | ----------------- | --------- |
|
||||
| [Baichuan](https://github.com/baichuan-inc/Baichuan-13B) | 7B/13B | W_pack | baichuan |
|
||||
| [Baichuan2](https://github.com/baichuan-inc/Baichuan2) | 7B/13B | W_pack | baichuan2 |
|
||||
| [Baichuan2](https://huggingface.co/baichuan-inc) | 7B/13B | W_pack | baichuan2 |
|
||||
| [BLOOM](https://huggingface.co/bigscience/bloom) | 560M/1.1B/1.7B/3B/7.1B/176B | query_key_value | - |
|
||||
| [BLOOMZ](https://huggingface.co/bigscience/bloomz) | 560M/1.1B/1.7B/3B/7.1B/176B | query_key_value | - |
|
||||
| [ChatGLM3](https://github.com/THUDM/ChatGLM3) | 6B | query_key_value | chatglm3 |
|
||||
| [Falcon](https://huggingface.co/tiiuae/falcon-7b) | 7B/40B/180B | query_key_value | falcon |
|
||||
| [InternLM](https://github.com/InternLM/InternLM) | 7B/20B | q_proj,v_proj | intern |
|
||||
| [ChatGLM3](https://huggingface.co/THUDM/chatglm3-6b) | 6B | query_key_value | chatglm3 |
|
||||
| [DeepSeek (MoE)](https://huggingface.co/deepseek-ai) | 7B/16B/67B | q_proj,v_proj | deepseek |
|
||||
| [Falcon](https://huggingface.co/tiiuae) | 7B/40B/180B | query_key_value | falcon |
|
||||
| [InternLM2](https://huggingface.co/internlm) | 7B/20B | wqkv | intern2 |
|
||||
| [LLaMA](https://github.com/facebookresearch/llama) | 7B/13B/33B/65B | q_proj,v_proj | - |
|
||||
| [LLaMA-2](https://huggingface.co/meta-llama) | 7B/13B/70B | q_proj,v_proj | llama2 |
|
||||
| [Mistral](https://huggingface.co/mistralai) | 7B | q_proj,v_proj | mistral |
|
||||
| [Phi-1.5](https://huggingface.co/microsoft/phi-1_5) | 1.3B | Wqkv | - |
|
||||
| [Qwen](https://github.com/QwenLM/Qwen) | 7B/14B | c_attn | qwen |
|
||||
| [XVERSE](https://github.com/xverse-ai) | 7B/13B/65B | q_proj,v_proj | xverse |
|
||||
| [Mixtral](https://huggingface.co/mistralai) | 8x7B | q_proj,v_proj | mistral |
|
||||
| [Phi-1.5/2](https://huggingface.co/microsoft) | 1.3B/2.7B | q_proj,v_proj | - |
|
||||
| [Qwen](https://huggingface.co/Qwen) | 1.8B/7B/14B/72B | c_attn | qwen |
|
||||
| [XVERSE](https://huggingface.co/xverse) | 7B/13B/65B | q_proj,v_proj | xverse |
|
||||
| [Yi](https://huggingface.co/01-ai) | 6B/34B | q_proj,v_proj | yi |
|
||||
| [Yuan](https://huggingface.co/IEITYuan) | 2B/51B/102B | q_proj,v_proj | yuan |
|
||||
|
||||
> [!NOTE]
|
||||
> **Default module** is used for the `--lora_target` argument, you can use `--lora_target all` to specify all the available modules.
|
||||
>
|
||||
> For the "base" models, the `--template` argument can be chosen from `default`, `alpaca`, `vicuna` etc. But make sure to use the **corresponding template** for the "chat" models.
|
||||
|
||||
Please refer to [template.py](src/llmtuner/extras/template.py) for a full list of models we supported.
|
||||
Please refer to [constants.py](src/llmtuner/extras/constants.py) for a full list of models we supported.
|
||||
|
||||
## Supported Training Approaches
|
||||
|
||||
@@ -79,12 +127,12 @@ Please refer to [template.py](src/llmtuner/extras/template.py) for a full list o
|
||||
| ---------------------- | ------------------ | ------------------ | ------------------ | ------------------ |
|
||||
| Pre-Training | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
|
||||
| Supervised Fine-Tuning | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
|
||||
| Reward Modeling | | | :white_check_mark: | :white_check_mark: |
|
||||
| PPO Training | | | :white_check_mark: | :white_check_mark: |
|
||||
| DPO Training | :white_check_mark: | | :white_check_mark: | :white_check_mark: |
|
||||
| Reward Modeling | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
|
||||
| PPO Training | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
|
||||
| DPO Training | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
|
||||
|
||||
> [!NOTE]
|
||||
> Use `--quantization_bit 4/8` argument to enable QLoRA.
|
||||
> Use `--quantization_bit 4` argument to enable QLoRA.
|
||||
|
||||
## Provided Datasets
|
||||
|
||||
@@ -122,10 +170,13 @@ Please refer to [template.py](src/llmtuner/extras/template.py) for a full list o
|
||||
- [OpenPlatypus (en)](https://huggingface.co/datasets/garage-bAInd/Open-Platypus)
|
||||
- [CodeAlpaca 20k (en)](https://huggingface.co/datasets/sahil2801/CodeAlpaca-20k)
|
||||
- [Alpaca CoT (multilingual)](https://huggingface.co/datasets/QingyiSi/Alpaca-CoT)
|
||||
- [OpenOrca (en)](https://huggingface.co/datasets/Open-Orca/OpenOrca)
|
||||
- [MathInstruct (en)](https://huggingface.co/datasets/TIGER-Lab/MathInstruct)
|
||||
- [Firefly 1.1M (zh)](https://huggingface.co/datasets/YeungNLP/firefly-train-1.1M)
|
||||
- [Web QA (zh)](https://huggingface.co/datasets/suolyer/webqa)
|
||||
- [WebNovel (zh)](https://huggingface.co/datasets/zxbsmk/webnovel_cn)
|
||||
- [Nectar (en)](https://huggingface.co/datasets/berkeley-nest/Nectar)
|
||||
- [deepctrl (en&zh)](https://www.modelscope.cn/datasets/deepctrl/deepctrl-sft-data)
|
||||
- [Ad Gen (zh)](https://huggingface.co/datasets/HasturOfficial/adgen)
|
||||
- [ShareGPT Hyperfiltered (en)](https://huggingface.co/datasets/totally-not-an-llm/sharegpt-hyperfiltered-3k)
|
||||
- [ShareGPT4 (en&zh)](https://huggingface.co/datasets/shibing624/sharegpt_gpt4)
|
||||
@@ -133,6 +184,7 @@ Please refer to [template.py](src/llmtuner/extras/template.py) for a full list o
|
||||
- [AgentInstruct (en)](https://huggingface.co/datasets/THUDM/AgentInstruct)
|
||||
- [LMSYS Chat 1M (en)](https://huggingface.co/datasets/lmsys/lmsys-chat-1m)
|
||||
- [Evol Instruct V2 (en)](https://huggingface.co/datasets/WizardLM/WizardLM_evol_instruct_V2_196k)
|
||||
- [Glaive Function Calling V2 (en)](https://huggingface.co/datasets/glaiveai/glaive-function-calling-v2)
|
||||
|
||||
</details>
|
||||
|
||||
@@ -141,6 +193,7 @@ Please refer to [template.py](src/llmtuner/extras/template.py) for a full list o
|
||||
- [HH-RLHF (en)](https://huggingface.co/datasets/Anthropic/hh-rlhf)
|
||||
- [Open Assistant (multilingual)](https://huggingface.co/datasets/OpenAssistant/oasst1)
|
||||
- [GPT-4 Generated Data (en&zh)](https://github.com/Instruction-Tuning-with-GPT-4/GPT-4-LLM)
|
||||
- [Nectar (en)](https://huggingface.co/datasets/berkeley-nest/Nectar)
|
||||
|
||||
</details>
|
||||
|
||||
@@ -162,7 +215,15 @@ huggingface-cli login
|
||||
- gradio and matplotlib (used in web UI)
|
||||
- uvicorn, fastapi and sse-starlette (used in API)
|
||||
|
||||
And **powerful GPUs**!
|
||||
### Hardware Requirement
|
||||
|
||||
| Method | Bits | 7B | 13B | 30B | 65B | 8x7B |
|
||||
| ------ | ---- | ----- | ----- | ----- | ------ | ------ |
|
||||
| Full | 16 | 160GB | 320GB | 600GB | 1200GB | 900GB |
|
||||
| Freeze | 16 | 20GB | 40GB | 120GB | 240GB | 200GB |
|
||||
| LoRA | 16 | 16GB | 32GB | 80GB | 160GB | 120GB |
|
||||
| QLoRA | 8 | 10GB | 16GB | 40GB | 80GB | 80GB |
|
||||
| QLoRA | 4 | 6GB | 12GB | 24GB | 48GB | 32GB |
|
||||
|
||||
## Getting Started
|
||||
|
||||
@@ -189,6 +250,28 @@ If you want to enable the quantized LoRA (QLoRA) on the Windows platform, you wi
|
||||
pip install https://github.com/jllllll/bitsandbytes-windows-webui/releases/download/wheels/bitsandbytes-0.39.1-py3-none-win_amd64.whl
|
||||
```
|
||||
|
||||
### Use ModelScope Hub (optional)
|
||||
|
||||
If you have trouble with downloading models and datasets from Hugging Face, you can use LLaMA-Factory together with ModelScope in the following manner.
|
||||
|
||||
```bash
|
||||
export USE_MODELSCOPE_HUB=1 # `set USE_MODELSCOPE_HUB=1` for Windows
|
||||
```
|
||||
|
||||
Then you can train the corresponding model by specifying a model ID of the ModelScope Hub. (find a full list of model IDs at [ModelScope Hub](https://modelscope.cn/models))
|
||||
|
||||
```bash
|
||||
CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
|
||||
--model_name_or_path modelscope/Llama-2-7b-ms \
|
||||
... # arguments (same as above)
|
||||
```
|
||||
|
||||
LLaMA Board also supports using the models and datasets on the ModelScope Hub.
|
||||
|
||||
```bash
|
||||
CUDA_VISIBLE_DEVICES=0 USE_MODELSCOPE_HUB=1 python src/train_web.py
|
||||
```
|
||||
|
||||
### Train on a single GPU
|
||||
|
||||
> [!IMPORTANT]
|
||||
@@ -199,8 +282,8 @@ pip install https://github.com/jllllll/bitsandbytes-windows-webui/releases/downl
|
||||
```bash
|
||||
CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
|
||||
--stage pt \
|
||||
--model_name_or_path path_to_llama_model \
|
||||
--do_train \
|
||||
--model_name_or_path path_to_llama_model \
|
||||
--dataset wiki_demo \
|
||||
--finetuning_type lora \
|
||||
--lora_target q_proj,v_proj \
|
||||
@@ -222,8 +305,8 @@ CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
|
||||
```bash
|
||||
CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
|
||||
--stage sft \
|
||||
--model_name_or_path path_to_llama_model \
|
||||
--do_train \
|
||||
--model_name_or_path path_to_llama_model \
|
||||
--dataset alpaca_gpt4_en \
|
||||
--template default \
|
||||
--finetuning_type lora \
|
||||
@@ -246,14 +329,14 @@ CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
|
||||
```bash
|
||||
CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
|
||||
--stage rm \
|
||||
--model_name_or_path path_to_llama_model \
|
||||
--do_train \
|
||||
--model_name_or_path path_to_llama_model \
|
||||
--adapter_name_or_path path_to_sft_checkpoint \
|
||||
--create_new_adapter \
|
||||
--dataset comparison_gpt4_en \
|
||||
--template default \
|
||||
--finetuning_type lora \
|
||||
--lora_target q_proj,v_proj \
|
||||
--resume_lora_training False \
|
||||
--checkpoint_dir path_to_sft_checkpoint \
|
||||
--output_dir path_to_rm_checkpoint \
|
||||
--per_device_train_batch_size 2 \
|
||||
--gradient_accumulation_steps 4 \
|
||||
@@ -271,19 +354,21 @@ CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
|
||||
```bash
|
||||
CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
|
||||
--stage ppo \
|
||||
--model_name_or_path path_to_llama_model \
|
||||
--do_train \
|
||||
--model_name_or_path path_to_llama_model \
|
||||
--adapter_name_or_path path_to_sft_checkpoint \
|
||||
--create_new_adapter \
|
||||
--dataset alpaca_gpt4_en \
|
||||
--template default \
|
||||
--finetuning_type lora \
|
||||
--lora_target q_proj,v_proj \
|
||||
--resume_lora_training False \
|
||||
--checkpoint_dir path_to_sft_checkpoint \
|
||||
--reward_model path_to_rm_checkpoint \
|
||||
--output_dir path_to_ppo_checkpoint \
|
||||
--per_device_train_batch_size 2 \
|
||||
--gradient_accumulation_steps 4 \
|
||||
--lr_scheduler_type cosine \
|
||||
--top_k 0 \
|
||||
--top_p 0.9 \
|
||||
--logging_steps 10 \
|
||||
--save_steps 1000 \
|
||||
--learning_rate 1e-5 \
|
||||
@@ -292,19 +377,22 @@ CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
|
||||
--fp16
|
||||
```
|
||||
|
||||
> [!WARNING]
|
||||
> Use `--per_device_train_batch_size=1` for LLaMA-2 models in fp16 PPO training.
|
||||
|
||||
#### DPO Training
|
||||
|
||||
```bash
|
||||
CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
|
||||
--stage dpo \
|
||||
--model_name_or_path path_to_llama_model \
|
||||
--do_train \
|
||||
--model_name_or_path path_to_llama_model \
|
||||
--adapter_name_or_path path_to_sft_checkpoint \
|
||||
--create_new_adapter \
|
||||
--dataset comparison_gpt4_en \
|
||||
--template default \
|
||||
--finetuning_type lora \
|
||||
--lora_target q_proj,v_proj \
|
||||
--resume_lora_training False \
|
||||
--checkpoint_dir path_to_sft_checkpoint \
|
||||
--output_dir path_to_dpo_checkpoint \
|
||||
--per_device_train_batch_size 2 \
|
||||
--gradient_accumulation_steps 4 \
|
||||
@@ -387,28 +475,36 @@ deepspeed --num_gpus 8 --master_port=9901 src/train_bash.py \
|
||||
|
||||
</details>
|
||||
|
||||
### Export model
|
||||
### Merge LoRA weights and export model
|
||||
|
||||
```bash
|
||||
python src/export_model.py \
|
||||
--model_name_or_path path_to_llama_model \
|
||||
--adapter_name_or_path path_to_checkpoint \
|
||||
--template default \
|
||||
--finetuning_type lora \
|
||||
--checkpoint_dir path_to_checkpoint \
|
||||
--export_dir path_to_export
|
||||
--export_dir path_to_export \
|
||||
--export_size 2 \
|
||||
--export_legacy_format False
|
||||
```
|
||||
|
||||
> [!WARNING]
|
||||
> Merging LoRA weights into a quantized model is not supported.
|
||||
|
||||
> [!TIP]
|
||||
> Use `--export_quantization_bit 4` and `--export_quantization_dataset data/c4_demo.json` to quantize the model after merging the LoRA weights.
|
||||
|
||||
### API Demo
|
||||
|
||||
```bash
|
||||
python src/api_demo.py \
|
||||
--model_name_or_path path_to_llama_model \
|
||||
--adapter_name_or_path path_to_checkpoint \
|
||||
--template default \
|
||||
--finetuning_type lora \
|
||||
--checkpoint_dir path_to_checkpoint
|
||||
--finetuning_type lora
|
||||
```
|
||||
|
||||
> [!NOTE]
|
||||
> [!TIP]
|
||||
> Visit `http://localhost:8000/docs` for API documentation.
|
||||
|
||||
### CLI Demo
|
||||
@@ -416,9 +512,9 @@ python src/api_demo.py \
|
||||
```bash
|
||||
python src/cli_demo.py \
|
||||
--model_name_or_path path_to_llama_model \
|
||||
--adapter_name_or_path path_to_checkpoint \
|
||||
--template default \
|
||||
--finetuning_type lora \
|
||||
--checkpoint_dir path_to_checkpoint
|
||||
--finetuning_type lora
|
||||
```
|
||||
|
||||
### Web Demo
|
||||
@@ -426,9 +522,9 @@ python src/cli_demo.py \
|
||||
```bash
|
||||
python src/web_demo.py \
|
||||
--model_name_or_path path_to_llama_model \
|
||||
--adapter_name_or_path path_to_checkpoint \
|
||||
--template default \
|
||||
--finetuning_type lora \
|
||||
--checkpoint_dir path_to_checkpoint
|
||||
--finetuning_type lora
|
||||
```
|
||||
|
||||
### Evaluation
|
||||
@@ -436,9 +532,9 @@ python src/web_demo.py \
|
||||
```bash
|
||||
CUDA_VISIBLE_DEVICES=0 python src/evaluate.py \
|
||||
--model_name_or_path path_to_llama_model \
|
||||
--finetuning_type lora \
|
||||
--checkpoint_dir path_to_checkpoint \
|
||||
--adapter_name_or_path path_to_checkpoint \
|
||||
--template vanilla \
|
||||
--finetuning_type lora \
|
||||
--task mmlu \
|
||||
--split test \
|
||||
--lang en \
|
||||
@@ -451,19 +547,23 @@ CUDA_VISIBLE_DEVICES=0 python src/evaluate.py \
|
||||
```bash
|
||||
CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
|
||||
--stage sft \
|
||||
--model_name_or_path path_to_llama_model \
|
||||
--do_predict \
|
||||
--model_name_or_path path_to_llama_model \
|
||||
--adapter_name_or_path path_to_checkpoint \
|
||||
--dataset alpaca_gpt4_en \
|
||||
--template default \
|
||||
--finetuning_type lora \
|
||||
--checkpoint_dir path_to_checkpoint \
|
||||
--output_dir path_to_predict_result \
|
||||
--per_device_eval_batch_size 8 \
|
||||
--max_samples 100 \
|
||||
--predict_with_generate
|
||||
--predict_with_generate \
|
||||
--fp16
|
||||
```
|
||||
|
||||
> [!NOTE]
|
||||
> [!WARNING]
|
||||
> Use `--per_device_train_batch_size=1` for LLaMA-2 models in fp16 predict.
|
||||
|
||||
> [!TIP]
|
||||
> We recommend using `--per_device_eval_batch_size=1` and `--max_target_length 128` at 4/8-bit predict.
|
||||
|
||||
## Projects using LLaMA Factory
|
||||
@@ -472,12 +572,16 @@ CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
|
||||
- **[DISC-LawLLM](https://github.com/FudanDISC/DISC-LawLLM)**: A large language model specialized in Chinese legal domain, based on Baichuan-13B, is capable of retrieving and reasoning on legal knowledge.
|
||||
- **[Sunsimiao](https://github.com/thomas-yanxin/Sunsimiao)**: A large language model specialized in Chinese medical domain, based on Baichuan-7B and ChatGLM-6B.
|
||||
- **[CareGPT](https://github.com/WangRongsheng/CareGPT)**: A series of large language models for Chinese medical domain, based on LLaMA2-7B and Baichuan-13B.
|
||||
- **[MachineMindset](https://github.com/PKU-YuanGroup/Machine-Mindset/)**: A series of MBTI Personality large language models, capable of giving any LLM 16 different personality types based on different datasets and training methods.
|
||||
|
||||
> [!TIP]
|
||||
> If you have a project that should be incorporated, please contact via email or create a pull request.
|
||||
|
||||
## License
|
||||
|
||||
This repository is licensed under the [Apache-2.0 License](LICENSE).
|
||||
|
||||
Please follow the model licenses to use the corresponding model weights: [Baichuan](https://huggingface.co/baichuan-inc/Baichuan-13B-Base/resolve/main/Community%20License%20for%20Baichuan-13B%20Model.pdf) / [Baichuan2](https://huggingface.co/baichuan-inc/Baichuan2-13B-Chat/resolve/main/Community%20License%20for%20Baichuan2%20Model.pdf) / [BLOOM](https://huggingface.co/spaces/bigscience/license) / [ChatGLM3](https://github.com/THUDM/ChatGLM3/blob/main/MODEL_LICENSE) / [Falcon](https://huggingface.co/tiiuae/falcon-180B/blob/main/LICENSE.txt) / [InternLM](https://github.com/InternLM/InternLM#license) / [LLaMA](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) / [LLaMA-2](https://ai.meta.com/llama/license/) / [Mistral](LICENSE) / [Phi-1.5](https://huggingface.co/microsoft/phi-1_5/resolve/main/Research%20License.docx) / [Qwen](https://github.com/QwenLM/Qwen/blob/main/LICENSE) / [XVERSE](https://github.com/xverse-ai/XVERSE-13B/blob/main/MODEL_LICENSE.pdf)
|
||||
Please follow the model licenses to use the corresponding model weights: [Baichuan2](https://huggingface.co/baichuan-inc/Baichuan2-7B-Base/blob/main/Community%20License%20for%20Baichuan%202%20Model.pdf) / [BLOOM](https://huggingface.co/spaces/bigscience/license) / [ChatGLM3](https://github.com/THUDM/ChatGLM3/blob/main/MODEL_LICENSE) / [DeepSeek](https://github.com/deepseek-ai/DeepSeek-LLM/blob/main/LICENSE-MODEL) / [Falcon](https://huggingface.co/tiiuae/falcon-180B/blob/main/LICENSE.txt) / [InternLM2](https://github.com/InternLM/InternLM#license) / [LLaMA](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) / [LLaMA-2](https://ai.meta.com/llama/license/) / [Mistral](LICENSE) / [Phi-1.5/2](https://huggingface.co/microsoft/phi-1_5/resolve/main/Research%20License.docx) / [Qwen](https://github.com/QwenLM/Qwen/blob/main/Tongyi%20Qianwen%20LICENSE%20AGREEMENT) / [XVERSE](https://github.com/xverse-ai/XVERSE-13B/blob/main/MODEL_LICENSE.pdf) / [Yi](https://huggingface.co/01-ai/Yi-6B/blob/main/LICENSE) / [Yuan](https://github.com/IEIT-Yuan/Yuan-2.0/blob/main/LICENSE-Yuan)
|
||||
|
||||
## Citation
|
||||
|
||||
|
||||
205
README_zh.md
205
README_zh.md
@@ -1,4 +1,4 @@
|
||||
# LLaMA Factory: 轻松的大模型训练与评估
|
||||

|
||||
|
||||
[](https://github.com/hiyouga/LLaMA-Factory/stargazers)
|
||||
[](LICENSE)
|
||||
@@ -6,7 +6,9 @@
|
||||
[](https://pypi.org/project/llmtuner/)
|
||||
[](https://pypi.org/project/llmtuner/)
|
||||
[](https://github.com/hiyouga/LLaMA-Factory/pulls)
|
||||
[](https://discord.gg/e73gccsSd)
|
||||
[](https://discord.gg/rKfvV9r9FK)
|
||||
[](https://huggingface.co/spaces/hiyouga/LLaMA-Board)
|
||||
[](https://modelscope.cn/studios/hiyouga/LLaMA-Board)
|
||||
|
||||
👋 加入我们的[微信群](assets/wechat.jpg)。
|
||||
|
||||
@@ -14,27 +16,68 @@
|
||||
|
||||
## LLaMA Board: 通过一站式网页界面快速上手 LLaMA Factory
|
||||
|
||||
使用 `CUDA_VISIBLE_DEVICES=0 python src/train_web.py` 启动 **LLaMA Board**。(该界面目前仅支持单卡训练)
|
||||
通过 **[🤗 Spaces](https://huggingface.co/spaces/hiyouga/LLaMA-Board)** 或 **[ModelScope](https://modelscope.cn/studios/hiyouga/LLaMA-Board)** 预览 LLaMA Board。
|
||||
|
||||
使用 `CUDA_VISIBLE_DEVICES=0 python src/train_web.py` 启动 LLaMA Board。(该模式目前仅支持单卡训练)
|
||||
|
||||
下面是使用单张 GPU 在 10 分钟内更改对话式大型语言模型自我认知的示例。
|
||||
|
||||
https://github.com/hiyouga/LLaMA-Factory/assets/16256802/6ba60acc-e2e2-4bec-b846-2d88920d5ba1
|
||||
|
||||
## 目录
|
||||
|
||||
- [性能指标](#性能指标)
|
||||
- [更新日志](#更新日志)
|
||||
- [模型](#模型)
|
||||
- [训练方法](#训练方法)
|
||||
- [数据集](#数据集)
|
||||
- [软硬件依赖](#软硬件依赖)
|
||||
- [如何使用](#如何使用)
|
||||
- [使用了 LLaMA Factory 的项目](#使用了-llama-factory-的项目)
|
||||
- [协议](#协议)
|
||||
- [引用](#引用)
|
||||
- [致谢](#致谢)
|
||||
|
||||
## 性能指标
|
||||
|
||||
与 ChatGLM 官方的 [P-Tuning](https://github.com/THUDM/ChatGLM2-6B/tree/main/ptuning) 微调相比,LLaMA-Factory 的 LoRA 微调提供了 **3.7 倍**的加速比,同时在广告文案生成任务上取得了更高的 Rouge 分数。结合 4 比特量化技术,LLaMA-Factory 的 QLoRA 微调进一步降低了 GPU 显存消耗。
|
||||
|
||||

|
||||
|
||||
<details><summary>变量定义</summary>
|
||||
|
||||
- **Training Speed**: 训练阶段每秒处理的样本数量。(批处理大小=4,截断长度=1024)
|
||||
- **Rouge Score**: [广告文案生成](https://aclanthology.org/D19-1321.pdf)任务验证集上的 Rouge-2 分数。(批处理大小=4,截断长度=1024)
|
||||
- **GPU Memory**: 4 比特量化训练的 GPU 显存峰值。(批处理大小=1,截断长度=1024)
|
||||
- 我们在 ChatGLM 的 P-Tuning 中采用 `pre_seq_len=128`,在 LLaMA-Factory 的 LoRA 微调中采用 `lora_rank=32`。
|
||||
|
||||
</details>
|
||||
|
||||
## 更新日志
|
||||
|
||||
[23/10/21] 我们支持了 **[NEFTune](https://arxiv.org/abs/2310.05914)** 训练技巧。请使用 `--neft_alpha` 参数启用 NEFTune,例如 `--neft_alpha 5`。
|
||||
[24/01/18] 我们针对绝大多数模型实现了 **Agent 微调**,微调时指定 `--dataset glaive_toolcall` 即可使模型获得工具调用能力。
|
||||
|
||||
[23/12/23] 我们针对 LLaMA, Mistral 和 Yi 模型支持了 **[unsloth](https://github.com/unslothai/unsloth)** 的 LoRA 训练加速。请使用 `--use_unsloth` 参数启用 unsloth 优化。该方法可提供 1.7 倍的训练速度,详情请查阅[此页面](https://github.com/hiyouga/LLaMA-Factory/wiki/Performance-comparison)。
|
||||
|
||||
[23/12/12] 我们支持了微调最新的混合专家模型 **[Mixtral 8x7B](https://huggingface.co/mistralai/Mixtral-8x7B-v0.1)**。硬件需求请查阅[此处](#硬件依赖)。
|
||||
|
||||
<details><summary>展开日志</summary>
|
||||
|
||||
[23/12/01] 我们支持了从 **[魔搭社区](https://modelscope.cn/models)** 下载预训练模型和数据集。详细用法请参照 [此教程](#使用魔搭社区可跳过)。
|
||||
|
||||
[23/10/21] 我们支持了 **[NEFTune](https://arxiv.org/abs/2310.05914)** 训练技巧。请使用 `--neftune_noise_alpha` 参数启用 NEFTune,例如 `--neftune_noise_alpha 5`。
|
||||
|
||||
[23/09/27] 我们针对 LLaMA 模型支持了 [LongLoRA](https://github.com/dvlab-research/LongLoRA) 提出的 **$S^2$-Attn**。请使用 `--shift_attn` 参数以启用该功能。
|
||||
|
||||
[23/09/23] 我们在项目中集成了 MMLU、C-Eval 和 CMMLU 评估集。使用方法请参阅[此示例](#模型评估)。
|
||||
|
||||
[23/09/10] 我们针对 LLaMA 模型支持了 **[FlashAttention-2](https://github.com/Dao-AILab/flash-attention)**。如果您使用的是 RTX4090、A100 或 H100 GPU,请使用 `--flash_attn` 参数以启用 FlashAttention-2。
|
||||
[23/09/10] 我们支持了 **[FlashAttention-2](https://github.com/Dao-AILab/flash-attention)**。如果您使用的是 RTX4090、A100 或 H100 GPU,请使用 `--flash_attn` 参数以启用 FlashAttention-2。
|
||||
|
||||
[23/08/12] 我们支持了 **RoPE 插值**来扩展 LLaMA 模型的上下文长度。请使用 `--rope_scaling linear` 参数训练模型或使用 `--rope_scaling dynamic` 参数评估模型。
|
||||
|
||||
[23/08/11] 我们支持了指令模型的 **[DPO 训练](https://arxiv.org/abs/2305.18290)**。使用方法请参阅[此示例](#dpo-训练)。
|
||||
|
||||
[23/07/31] 我们支持了**数据流式加载**。请尝试使用 `--streaming` 和 `--max_steps 10000` 参数来流式加载数据集。
|
||||
[23/07/31] 我们支持了**数据流式加载**。请使用 `--streaming` 和 `--max_steps 10000` 参数来流式加载数据集。
|
||||
|
||||
[23/07/29] 我们在 Hugging Face 发布了两个 13B 指令微调模型。详细内容请查阅我们的 Hugging Face 项目([LLaMA-2](https://huggingface.co/hiyouga/Llama-2-Chinese-13b-chat) / [Baichuan](https://huggingface.co/hiyouga/Baichuan-13B-sft))。
|
||||
|
||||
@@ -48,30 +91,35 @@ https://github.com/hiyouga/LLaMA-Factory/assets/16256802/6ba60acc-e2e2-4bec-b846
|
||||
|
||||
[23/06/03] 我们实现了 4 比特的 LoRA 训练(也称 **[QLoRA](https://github.com/artidoro/qlora)**)。请使用 `--quantization_bit 4` 参数进行 4 比特量化微调。
|
||||
|
||||
</details>
|
||||
|
||||
## 模型
|
||||
|
||||
| 模型名 | 模型大小 | 默认模块 | Template |
|
||||
| -------------------------------------------------------- | --------------------------- | ----------------- | --------- |
|
||||
| [Baichuan](https://github.com/baichuan-inc/Baichuan-13B) | 7B/13B | W_pack | baichuan |
|
||||
| [Baichuan2](https://github.com/baichuan-inc/Baichuan2) | 7B/13B | W_pack | baichuan2 |
|
||||
| [Baichuan2](https://huggingface.co/baichuan-inc) | 7B/13B | W_pack | baichuan2 |
|
||||
| [BLOOM](https://huggingface.co/bigscience/bloom) | 560M/1.1B/1.7B/3B/7.1B/176B | query_key_value | - |
|
||||
| [BLOOMZ](https://huggingface.co/bigscience/bloomz) | 560M/1.1B/1.7B/3B/7.1B/176B | query_key_value | - |
|
||||
| [ChatGLM3](https://github.com/THUDM/ChatGLM3) | 6B | query_key_value | chatglm3 |
|
||||
| [Falcon](https://huggingface.co/tiiuae/falcon-7b) | 7B/40B/180B | query_key_value | falcon |
|
||||
| [InternLM](https://github.com/InternLM/InternLM) | 7B/20B | q_proj,v_proj | intern |
|
||||
| [ChatGLM3](https://huggingface.co/THUDM/chatglm3-6b) | 6B | query_key_value | chatglm3 |
|
||||
| [DeepSeek (MoE)](https://huggingface.co/deepseek-ai) | 7B/16B/67B | q_proj,v_proj | deepseek |
|
||||
| [Falcon](https://huggingface.co/tiiuae) | 7B/40B/180B | query_key_value | falcon |
|
||||
| [InternLM2](https://huggingface.co/internlm) | 7B/20B | wqkv | intern2 |
|
||||
| [LLaMA](https://github.com/facebookresearch/llama) | 7B/13B/33B/65B | q_proj,v_proj | - |
|
||||
| [LLaMA-2](https://huggingface.co/meta-llama) | 7B/13B/70B | q_proj,v_proj | llama2 |
|
||||
| [Mistral](https://huggingface.co/mistralai) | 7B | q_proj,v_proj | mistral |
|
||||
| [Phi-1.5](https://huggingface.co/microsoft/phi-1_5) | 1.3B | Wqkv | - |
|
||||
| [Qwen](https://github.com/QwenLM/Qwen) | 7B/14B | c_attn | qwen |
|
||||
| [XVERSE](https://github.com/xverse-ai) | 7B/13B/65B | q_proj,v_proj | xverse |
|
||||
| [Mixtral](https://huggingface.co/mistralai) | 8x7B | q_proj,v_proj | mistral |
|
||||
| [Phi-1.5/2](https://huggingface.co/microsoft) | 1.3B/2.7B | q_proj,v_proj | - |
|
||||
| [Qwen](https://huggingface.co/Qwen) | 1.8B/7B/14B/72B | c_attn | qwen |
|
||||
| [XVERSE](https://huggingface.co/xverse) | 7B/13B/65B | q_proj,v_proj | xverse |
|
||||
| [Yi](https://huggingface.co/01-ai) | 6B/34B | q_proj,v_proj | yi |
|
||||
| [Yuan](https://huggingface.co/IEITYuan) | 2B/51B/102B | q_proj,v_proj | yuan |
|
||||
|
||||
> [!NOTE]
|
||||
> **默认模块**应作为 `--lora_target` 参数的默认值,可使用 `--lora_target all` 参数指定全部模块。
|
||||
>
|
||||
> 对于所有“基座”(Base)模型,`--template` 参数可以是 `default`, `alpaca`, `vicuna` 等任意值。但“对话”(Chat)模型请务必使用**对应的模板**。
|
||||
|
||||
项目所支持模型的完整列表请参阅 [template.py](src/llmtuner/extras/template.py)。
|
||||
项目所支持模型的完整列表请参阅 [constants.py](src/llmtuner/extras/constants.py)。
|
||||
|
||||
## 训练方法
|
||||
|
||||
@@ -79,12 +127,12 @@ https://github.com/hiyouga/LLaMA-Factory/assets/16256802/6ba60acc-e2e2-4bec-b846
|
||||
| ---------------------- | ------------------ | ------------------ | ------------------ | ------------------ |
|
||||
| 预训练 | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
|
||||
| 指令监督微调 | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
|
||||
| 奖励模型训练 | | | :white_check_mark: | :white_check_mark: |
|
||||
| PPO 训练 | | | :white_check_mark: | :white_check_mark: |
|
||||
| DPO 训练 | :white_check_mark: | | :white_check_mark: | :white_check_mark: |
|
||||
| 奖励模型训练 | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
|
||||
| PPO 训练 | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
|
||||
| DPO 训练 | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
|
||||
|
||||
> [!NOTE]
|
||||
> 请使用 `--quantization_bit 4/8` 参数来启用 QLoRA 训练。
|
||||
> 请使用 `--quantization_bit 4` 参数来启用 QLoRA 训练。
|
||||
|
||||
## 数据集
|
||||
|
||||
@@ -122,10 +170,13 @@ https://github.com/hiyouga/LLaMA-Factory/assets/16256802/6ba60acc-e2e2-4bec-b846
|
||||
- [OpenPlatypus (en)](https://huggingface.co/datasets/garage-bAInd/Open-Platypus)
|
||||
- [CodeAlpaca 20k (en)](https://huggingface.co/datasets/sahil2801/CodeAlpaca-20k)
|
||||
- [Alpaca CoT (multilingual)](https://huggingface.co/datasets/QingyiSi/Alpaca-CoT)
|
||||
- [OpenOrca (en)](https://huggingface.co/datasets/Open-Orca/OpenOrca)
|
||||
- [MathInstruct (en)](https://huggingface.co/datasets/TIGER-Lab/MathInstruct)
|
||||
- [Firefly 1.1M (zh)](https://huggingface.co/datasets/YeungNLP/firefly-train-1.1M)
|
||||
- [Web QA (zh)](https://huggingface.co/datasets/suolyer/webqa)
|
||||
- [WebNovel (zh)](https://huggingface.co/datasets/zxbsmk/webnovel_cn)
|
||||
- [Nectar (en)](https://huggingface.co/datasets/berkeley-nest/Nectar)
|
||||
- [deepctrl (en&zh)](https://www.modelscope.cn/datasets/deepctrl/deepctrl-sft-data)
|
||||
- [Ad Gen (zh)](https://huggingface.co/datasets/HasturOfficial/adgen)
|
||||
- [ShareGPT Hyperfiltered (en)](https://huggingface.co/datasets/totally-not-an-llm/sharegpt-hyperfiltered-3k)
|
||||
- [ShareGPT4 (en&zh)](https://huggingface.co/datasets/shibing624/sharegpt_gpt4)
|
||||
@@ -133,6 +184,7 @@ https://github.com/hiyouga/LLaMA-Factory/assets/16256802/6ba60acc-e2e2-4bec-b846
|
||||
- [AgentInstruct (en)](https://huggingface.co/datasets/THUDM/AgentInstruct)
|
||||
- [LMSYS Chat 1M (en)](https://huggingface.co/datasets/lmsys/lmsys-chat-1m)
|
||||
- [Evol Instruct V2 (en)](https://huggingface.co/datasets/WizardLM/WizardLM_evol_instruct_V2_196k)
|
||||
- [Glaive Function Calling V2 (en)](https://huggingface.co/datasets/glaiveai/glaive-function-calling-v2)
|
||||
|
||||
</details>
|
||||
|
||||
@@ -141,6 +193,7 @@ https://github.com/hiyouga/LLaMA-Factory/assets/16256802/6ba60acc-e2e2-4bec-b846
|
||||
- [HH-RLHF (en)](https://huggingface.co/datasets/Anthropic/hh-rlhf)
|
||||
- [Open Assistant (multilingual)](https://huggingface.co/datasets/OpenAssistant/oasst1)
|
||||
- [GPT-4 Generated Data (en&zh)](https://github.com/Instruction-Tuning-with-GPT-4/GPT-4-LLM)
|
||||
- [Nectar (en)](https://huggingface.co/datasets/berkeley-nest/Nectar)
|
||||
|
||||
</details>
|
||||
|
||||
@@ -153,7 +206,7 @@ pip install --upgrade huggingface_hub
|
||||
huggingface-cli login
|
||||
```
|
||||
|
||||
## 软件依赖
|
||||
## 软硬件依赖
|
||||
|
||||
- Python 3.8+ 和 PyTorch 1.13.1+
|
||||
- 🤗Transformers, Datasets, Accelerate, PEFT 和 TRL
|
||||
@@ -162,7 +215,15 @@ huggingface-cli login
|
||||
- gradio 和 matplotlib (用于网页端交互)
|
||||
- uvicorn, fastapi 和 sse-starlette (用于 API)
|
||||
|
||||
以及 **强而有力的 GPU**!
|
||||
### 硬件依赖
|
||||
|
||||
| 训练方法 | 精度 | 7B | 13B | 30B | 65B | 8x7B |
|
||||
| ------- | ---- | ----- | ----- | ----- | ------ | ------ |
|
||||
| 全参数 | 16 | 160GB | 320GB | 600GB | 1200GB | 900GB |
|
||||
| 部分参数 | 16 | 20GB | 40GB | 120GB | 240GB | 200GB |
|
||||
| LoRA | 16 | 16GB | 32GB | 80GB | 160GB | 120GB |
|
||||
| QLoRA | 8 | 10GB | 16GB | 40GB | 80GB | 80GB |
|
||||
| QLoRA | 4 | 6GB | 12GB | 24GB | 48GB | 32GB |
|
||||
|
||||
## 如何使用
|
||||
|
||||
@@ -189,6 +250,28 @@ pip install -r requirements.txt
|
||||
pip install https://github.com/jllllll/bitsandbytes-windows-webui/releases/download/wheels/bitsandbytes-0.39.1-py3-none-win_amd64.whl
|
||||
```
|
||||
|
||||
### 使用魔搭社区(可跳过)
|
||||
|
||||
如果您在 Hugging Face 模型和数据集的下载中遇到了问题,可以通过下述方法使用魔搭社区。
|
||||
|
||||
```bash
|
||||
export USE_MODELSCOPE_HUB=1 # Windows 使用 `set USE_MODELSCOPE_HUB=1`
|
||||
```
|
||||
|
||||
接着即可通过指定模型名称来训练对应的模型。(在[魔搭社区](https://modelscope.cn/models)查看所有可用的模型)
|
||||
|
||||
```bash
|
||||
CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
|
||||
--model_name_or_path modelscope/Llama-2-7b-ms \
|
||||
... # 参数同上
|
||||
```
|
||||
|
||||
LLaMA Board 同样支持魔搭社区的模型和数据集下载。
|
||||
|
||||
```bash
|
||||
CUDA_VISIBLE_DEVICES=0 USE_MODELSCOPE_HUB=1 python src/train_web.py
|
||||
```
|
||||
|
||||
### 单 GPU 训练
|
||||
|
||||
> [!IMPORTANT]
|
||||
@@ -199,8 +282,8 @@ pip install https://github.com/jllllll/bitsandbytes-windows-webui/releases/downl
|
||||
```bash
|
||||
CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
|
||||
--stage pt \
|
||||
--model_name_or_path path_to_llama_model \
|
||||
--do_train \
|
||||
--model_name_or_path path_to_llama_model \
|
||||
--dataset wiki_demo \
|
||||
--finetuning_type lora \
|
||||
--lora_target q_proj,v_proj \
|
||||
@@ -222,8 +305,8 @@ CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
|
||||
```bash
|
||||
CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
|
||||
--stage sft \
|
||||
--model_name_or_path path_to_llama_model \
|
||||
--do_train \
|
||||
--model_name_or_path path_to_llama_model \
|
||||
--dataset alpaca_gpt4_zh \
|
||||
--template default \
|
||||
--finetuning_type lora \
|
||||
@@ -246,14 +329,14 @@ CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
|
||||
```bash
|
||||
CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
|
||||
--stage rm \
|
||||
--model_name_or_path path_to_llama_model \
|
||||
--do_train \
|
||||
--model_name_or_path path_to_llama_model \
|
||||
--adapter_name_or_path path_to_sft_checkpoint \
|
||||
--create_new_adapter \
|
||||
--dataset comparison_gpt4_zh \
|
||||
--template default \
|
||||
--finetuning_type lora \
|
||||
--lora_target q_proj,v_proj \
|
||||
--resume_lora_training False \
|
||||
--checkpoint_dir path_to_sft_checkpoint \
|
||||
--output_dir path_to_rm_checkpoint \
|
||||
--per_device_train_batch_size 2 \
|
||||
--gradient_accumulation_steps 4 \
|
||||
@@ -271,39 +354,45 @@ CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
|
||||
```bash
|
||||
CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
|
||||
--stage ppo \
|
||||
--model_name_or_path path_to_llama_model \
|
||||
--do_train \
|
||||
--model_name_or_path path_to_llama_model \
|
||||
--adapter_name_or_path path_to_sft_checkpoint \
|
||||
--create_new_adapter \
|
||||
--dataset alpaca_gpt4_zh \
|
||||
--template default \
|
||||
--finetuning_type lora \
|
||||
--lora_target q_proj,v_proj \
|
||||
--resume_lora_training False \
|
||||
--checkpoint_dir path_to_sft_checkpoint \
|
||||
--reward_model path_to_rm_checkpoint \
|
||||
--output_dir path_to_ppo_checkpoint \
|
||||
--per_device_train_batch_size 2 \
|
||||
--gradient_accumulation_steps 4 \
|
||||
--lr_scheduler_type cosine \
|
||||
--top_k 0 \
|
||||
--top_p 0.9 \
|
||||
--logging_steps 10 \
|
||||
--save_steps 1000 \
|
||||
--learning_rate 1e-5 \
|
||||
--num_train_epochs 1.0 \
|
||||
--plot_loss
|
||||
--plot_loss \
|
||||
--fp16
|
||||
```
|
||||
|
||||
> [!WARNING]
|
||||
> 如果使用 fp16 精度进行 LLaMA-2 模型的 PPO 训练,请使用 `--per_device_train_batch_size=1`。
|
||||
|
||||
#### DPO 训练
|
||||
|
||||
```bash
|
||||
CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
|
||||
--stage dpo \
|
||||
--model_name_or_path path_to_llama_model \
|
||||
--do_train \
|
||||
--model_name_or_path path_to_llama_model \
|
||||
--adapter_name_or_path path_to_sft_checkpoint \
|
||||
--create_new_adapter \
|
||||
--dataset comparison_gpt4_zh \
|
||||
--template default \
|
||||
--finetuning_type lora \
|
||||
--lora_target q_proj,v_proj \
|
||||
--resume_lora_training False \
|
||||
--checkpoint_dir path_to_sft_checkpoint \
|
||||
--output_dir path_to_dpo_checkpoint \
|
||||
--per_device_train_batch_size 2 \
|
||||
--gradient_accumulation_steps 4 \
|
||||
@@ -386,28 +475,36 @@ deepspeed --num_gpus 8 --master_port=9901 src/train_bash.py \
|
||||
|
||||
</details>
|
||||
|
||||
### 导出微调后的完整模型
|
||||
### 合并 LoRA 权重并导出模型
|
||||
|
||||
```bash
|
||||
python src/export_model.py \
|
||||
--model_name_or_path path_to_llama_model \
|
||||
--adapter_name_or_path path_to_checkpoint \
|
||||
--template default \
|
||||
--finetuning_type lora \
|
||||
--checkpoint_dir path_to_checkpoint \
|
||||
--export_dir path_to_export
|
||||
--export_dir path_to_export \
|
||||
--export_size 2 \
|
||||
--export_legacy_format False
|
||||
```
|
||||
|
||||
> [!WARNING]
|
||||
> 尚不支持量化模型的 LoRA 权重合并及导出。
|
||||
|
||||
> [!TIP]
|
||||
> 合并 LoRA 权重之后可再次使用 `--export_quantization_bit 4` 和 `--export_quantization_dataset data/c4_demo.json` 量化模型。
|
||||
|
||||
### API 服务
|
||||
|
||||
```bash
|
||||
python src/api_demo.py \
|
||||
--model_name_or_path path_to_llama_model \
|
||||
--adapter_name_or_path path_to_checkpoint \
|
||||
--template default \
|
||||
--finetuning_type lora \
|
||||
--checkpoint_dir path_to_checkpoint
|
||||
--finetuning_type lora
|
||||
```
|
||||
|
||||
> [!NOTE]
|
||||
> [!TIP]
|
||||
> 关于 API 文档请见 `http://localhost:8000/docs`。
|
||||
|
||||
### 命令行测试
|
||||
@@ -415,9 +512,9 @@ python src/api_demo.py \
|
||||
```bash
|
||||
python src/cli_demo.py \
|
||||
--model_name_or_path path_to_llama_model \
|
||||
--adapter_name_or_path path_to_checkpoint \
|
||||
--template default \
|
||||
--finetuning_type lora \
|
||||
--checkpoint_dir path_to_checkpoint
|
||||
--finetuning_type lora
|
||||
```
|
||||
|
||||
### 浏览器测试
|
||||
@@ -425,9 +522,9 @@ python src/cli_demo.py \
|
||||
```bash
|
||||
python src/web_demo.py \
|
||||
--model_name_or_path path_to_llama_model \
|
||||
--adapter_name_or_path path_to_checkpoint \
|
||||
--template default \
|
||||
--finetuning_type lora \
|
||||
--checkpoint_dir path_to_checkpoint
|
||||
--finetuning_type lora
|
||||
```
|
||||
|
||||
### 模型评估
|
||||
@@ -435,9 +532,9 @@ python src/web_demo.py \
|
||||
```bash
|
||||
CUDA_VISIBLE_DEVICES=0 python src/evaluate.py \
|
||||
--model_name_or_path path_to_llama_model \
|
||||
--finetuning_type lora \
|
||||
--checkpoint_dir path_to_checkpoint \
|
||||
--adapter_name_or_path path_to_checkpoint \
|
||||
--template vanilla \
|
||||
--finetuning_type lora \
|
||||
--task ceval \
|
||||
--split validation \
|
||||
--lang zh \
|
||||
@@ -450,19 +547,23 @@ CUDA_VISIBLE_DEVICES=0 python src/evaluate.py \
|
||||
```bash
|
||||
CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
|
||||
--stage sft \
|
||||
--model_name_or_path path_to_llama_model \
|
||||
--do_predict \
|
||||
--model_name_or_path path_to_llama_model \
|
||||
--adapter_name_or_path path_to_checkpoint \
|
||||
--dataset alpaca_gpt4_zh \
|
||||
--template default \
|
||||
--finetuning_type lora \
|
||||
--checkpoint_dir path_to_checkpoint \
|
||||
--output_dir path_to_predict_result \
|
||||
--per_device_eval_batch_size 8 \
|
||||
--max_samples 100 \
|
||||
--predict_with_generate
|
||||
--predict_with_generate \
|
||||
--fp16
|
||||
```
|
||||
|
||||
> [!NOTE]
|
||||
> [!WARNING]
|
||||
> 如果使用 fp16 精度进行 LLaMA-2 模型的预测,请使用 `--per_device_eval_batch_size=1`。
|
||||
|
||||
> [!TIP]
|
||||
> 我们建议在量化模型的预测中使用 `--per_device_eval_batch_size=1` 和 `--max_target_length 128`。
|
||||
|
||||
## 使用了 LLaMA Factory 的项目
|
||||
@@ -471,12 +572,16 @@ CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
|
||||
- **[DISC-LawLLM](https://github.com/FudanDISC/DISC-LawLLM)**: 中文法律领域大模型 DISC-LawLLM,基于 Baichuan-13B 微调而得,具有法律推理和知识检索能力。
|
||||
- **[Sunsimiao](https://github.com/thomas-yanxin/Sunsimiao)**: 孙思邈中文医疗大模型 Sumsimiao,基于 Baichuan-7B 和 ChatGLM-6B 在中文医疗数据上微调而得。
|
||||
- **[CareGPT](https://github.com/WangRongsheng/CareGPT)**: 医疗大模型项目 CareGPT,基于 LLaMA2-7B 和 Baichuan-13B 在中文医疗数据上微调而得。
|
||||
- **[MachineMindset](https://github.com/PKU-YuanGroup/Machine-Mindset/)**:MBTI性格大模型项目,根据数据集与训练方式让任意 LLM 拥有 16 个不同的性格类型。
|
||||
|
||||
> [!TIP]
|
||||
> 如果您有项目希望添加至上述列表,请通过邮件联系或者创建一个 PR。
|
||||
|
||||
## 协议
|
||||
|
||||
本仓库的代码依照 [Apache-2.0](LICENSE) 协议开源。
|
||||
|
||||
使用模型权重时,请遵循对应的模型协议:[Baichuan](https://huggingface.co/baichuan-inc/Baichuan-13B-Base/resolve/main/Community%20License%20for%20Baichuan-13B%20Model.pdf) / [Baichuan2](https://huggingface.co/baichuan-inc/Baichuan2-13B-Chat/resolve/main/Community%20License%20for%20Baichuan2%20Model.pdf) / [BLOOM](https://huggingface.co/spaces/bigscience/license) / [ChatGLM3](https://github.com/THUDM/ChatGLM3/blob/main/MODEL_LICENSE) / [Falcon](https://huggingface.co/tiiuae/falcon-180B/blob/main/LICENSE.txt) / [InternLM](https://github.com/InternLM/InternLM#license) / [LLaMA](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) / [LLaMA-2](https://ai.meta.com/llama/license/) / [Mistral](LICENSE) / [Phi-1.5](https://huggingface.co/microsoft/phi-1_5/resolve/main/Research%20License.docx) / [Qwen](https://github.com/QwenLM/Qwen/blob/main/LICENSE) / [XVERSE](https://github.com/xverse-ai/XVERSE-13B/blob/main/MODEL_LICENSE.pdf)
|
||||
使用模型权重时,请遵循对应的模型协议:[Baichuan2](https://huggingface.co/baichuan-inc/Baichuan2-7B-Base/blob/main/Community%20License%20for%20Baichuan%202%20Model.pdf) / [BLOOM](https://huggingface.co/spaces/bigscience/license) / [ChatGLM3](https://github.com/THUDM/ChatGLM3/blob/main/MODEL_LICENSE) / [DeepSeek](https://github.com/deepseek-ai/DeepSeek-LLM/blob/main/LICENSE-MODEL) / [Falcon](https://huggingface.co/tiiuae/falcon-180B/blob/main/LICENSE.txt) / [InternLM2](https://github.com/InternLM/InternLM#license) / [LLaMA](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) / [LLaMA-2](https://ai.meta.com/llama/license/) / [Mistral](LICENSE) / [Phi-1.5/2](https://huggingface.co/microsoft/phi-1_5/resolve/main/Research%20License.docx) / [Qwen](https://github.com/QwenLM/Qwen/blob/main/Tongyi%20Qianwen%20LICENSE%20AGREEMENT) / [XVERSE](https://github.com/xverse-ai/XVERSE-13B/blob/main/MODEL_LICENSE.pdf) / [Yi](https://huggingface.co/01-ai/Yi-6B/blob/main/LICENSE) / [Yuan](https://github.com/IEIT-Yuan/Yuan-2.0/blob/main/LICENSE-Yuan)
|
||||
|
||||
## 引用
|
||||
|
||||
|
||||
1216
assets/benchmark.svg
Normal file
1216
assets/benchmark.svg
Normal file
File diff suppressed because it is too large
Load Diff
|
After Width: | Height: | Size: 29 KiB |
@@ -2,11 +2,13 @@ If you are using a custom dataset, please provide your dataset definition in the
|
||||
|
||||
```json
|
||||
"dataset_name": {
|
||||
"hf_hub_url": "the name of the dataset repository on the Hugging Face hub. (if specified, ignore below 3 arguments)",
|
||||
"script_url": "the name of the directory containing a dataset loading script. (if specified, ignore below 2 arguments)",
|
||||
"file_name": "the name of the dataset file in the this directory. (required if above are not specified)",
|
||||
"hf_hub_url": "the name of the dataset repository on the Hugging Face hub. (if specified, ignore script_url and file_name)",
|
||||
"ms_hub_url": "the name of the dataset repository on the ModelScope hub. (if specified, ignore script_url and file_name)",
|
||||
"script_url": "the name of the directory containing a dataset loading script. (if specified, ignore file_name)",
|
||||
"file_name": "the name of the dataset file in this directory. (required if above are not specified)",
|
||||
"file_sha1": "the SHA-1 hash value of the dataset file. (optional, does not affect training)",
|
||||
"subset": "the name of the subset. (optional, default: None)",
|
||||
"folder": "the name of the folder of the dataset repository on the Hugging Face hub. (optional, default: None)",
|
||||
"ranking": "whether the dataset is a preference dataset or not. (default: false)",
|
||||
"formatting": "the format of the dataset. (optional, default: alpaca, can be chosen from {alpaca, sharegpt})",
|
||||
"columns": {
|
||||
@@ -16,7 +18,8 @@ If you are using a custom dataset, please provide your dataset definition in the
|
||||
"history": "the column name in the dataset containing the histories. (default: None, for alpaca)",
|
||||
"messages": "the column name in the dataset containing the messages. (default: conversations, for sharegpt)",
|
||||
"role": "the key in the message represents the identity. (default: from, for sharegpt)",
|
||||
"content": "the key in the message represents the content. (default: value, for sharegpt)"
|
||||
"content": "the key in the message represents the content. (default: value, for sharegpt)",
|
||||
"system": "the column name in the dataset containing the system prompts. (default: None, for both)"
|
||||
}
|
||||
}
|
||||
```
|
||||
@@ -31,6 +34,7 @@ Currently we support dataset in **alpaca** or **sharegpt** format, the dataset i
|
||||
"instruction": "user instruction (required)",
|
||||
"input": "user input (optional)",
|
||||
"output": "model response (required)",
|
||||
"system": "system prompt (optional)",
|
||||
"history": [
|
||||
["user instruction in the first round (optional)", "model response in the first round (optional)"],
|
||||
["user instruction in the second round (optional)", "model response in the second round (optional)"]
|
||||
@@ -47,6 +51,7 @@ Regarding the above dataset, the `columns` in `dataset_info.json` should be:
|
||||
"prompt": "instruction",
|
||||
"query": "input",
|
||||
"response": "output",
|
||||
"system": "system",
|
||||
"history": "history"
|
||||
}
|
||||
}
|
||||
@@ -54,7 +59,7 @@ Regarding the above dataset, the `columns` in `dataset_info.json` should be:
|
||||
|
||||
where the `prompt` and `response` columns should contain non-empty values, represent instruction and response respectively. The `query` column will be concatenated with the `prompt` column and used as input for the model.
|
||||
|
||||
The `history` column is a list consisting string tuples representing query-response pairs in history. Note that the responses **in each round will be used for training**.
|
||||
The `system` column will be used as the system prompt in the template. The `history` column is a list consisting string tuples representing query-response pairs in history. Note that the responses **in each round will be used for training**.
|
||||
|
||||
For the pre-training datasets, only the `prompt` column will be used for training.
|
||||
|
||||
@@ -85,7 +90,8 @@ The dataset in sharegpt format should follow the below format:
|
||||
"from": "gpt",
|
||||
"value": "model response"
|
||||
}
|
||||
]
|
||||
],
|
||||
"system": "system prompt (optional)"
|
||||
}
|
||||
]
|
||||
```
|
||||
@@ -97,7 +103,8 @@ Regarding the above dataset, the `columns` in `dataset_info.json` should be:
|
||||
"columns": {
|
||||
"messages": "conversations",
|
||||
"role": "from",
|
||||
"content": "value"
|
||||
"content": "value",
|
||||
"system": "system"
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
@@ -2,11 +2,13 @@
|
||||
|
||||
```json
|
||||
"数据集名称": {
|
||||
"hf_hub_url": "Hugging Face 上的项目地址(若指定,则忽略下列三个参数)",
|
||||
"script_url": "包含数据加载脚本的本地文件夹名称(若指定,则忽略下列两个参数)",
|
||||
"hf_hub_url": "Hugging Face 的数据集仓库地址(若指定,则忽略 script_url 和 file_name)",
|
||||
"ms_hub_url": "ModelScope 的数据集仓库地址(若指定,则忽略 script_url 和 file_name)",
|
||||
"script_url": "包含数据加载脚本的本地文件夹名称(若指定,则忽略 file_name)",
|
||||
"file_name": "该目录下数据集文件的名称(若上述参数未指定,则此项必需)",
|
||||
"file_sha1": "数据集文件的SHA-1哈希值(可选,留空不影响训练)",
|
||||
"file_sha1": "数据集文件的 SHA-1 哈希值(可选,留空不影响训练)",
|
||||
"subset": "数据集子集的名称(可选,默认:None)",
|
||||
"folder": "Hugging Face 仓库的文件夹名称(可选,默认:None)",
|
||||
"ranking": "是否为偏好数据集(可选,默认:False)",
|
||||
"formatting": "数据集格式(可选,默认:alpaca,可以为 alpaca 或 sharegpt)",
|
||||
"columns": {
|
||||
@@ -16,7 +18,8 @@
|
||||
"history": "数据集代表历史对话的表头名称(默认:None,用于 alpaca 格式)",
|
||||
"messages": "数据集代表消息列表的表头名称(默认:conversations,用于 sharegpt 格式)",
|
||||
"role": "消息中代表发送者身份的键名(默认:from,用于 sharegpt 格式)",
|
||||
"content": "消息中代表文本内容的键名(默认:value,用于 sharegpt 格式)"
|
||||
"content": "消息中代表文本内容的键名(默认:value,用于 sharegpt 格式)",
|
||||
"system": "数据集代表系统提示的表头名称(默认:None,用于两种格式)"
|
||||
}
|
||||
}
|
||||
```
|
||||
@@ -31,6 +34,7 @@
|
||||
"instruction": "用户指令(必填)",
|
||||
"input": "用户输入(选填)",
|
||||
"output": "模型回答(必填)",
|
||||
"system": "系统提示词(选填)",
|
||||
"history": [
|
||||
["第一轮指令(选填)", "第一轮回答(选填)"],
|
||||
["第二轮指令(选填)", "第二轮回答(选填)"]
|
||||
@@ -47,6 +51,7 @@
|
||||
"prompt": "instruction",
|
||||
"query": "input",
|
||||
"response": "output",
|
||||
"system": "system",
|
||||
"history": "history"
|
||||
}
|
||||
}
|
||||
@@ -54,7 +59,7 @@
|
||||
|
||||
其中 `prompt` 和 `response` 列应当是非空的字符串,分别代表用户指令和模型回答。`query` 列的内容将会和 `prompt` 列拼接作为模型输入。
|
||||
|
||||
`history` 列是由多个字符串二元组构成的列表,分别代表历史消息中每轮的指令和回答。注意每轮的模型回答**均会被用于训练**。
|
||||
`system` 为模板中的系统提示词。`history` 列是由多个字符串二元组构成的列表,分别代表历史消息中每轮的指令和回答。注意每轮的模型回答**均会被用于训练**。
|
||||
|
||||
对于预训练数据集,仅 `prompt` 列中的内容会用于模型训练。
|
||||
|
||||
@@ -85,7 +90,8 @@
|
||||
"from": "gpt",
|
||||
"value": "模型回答"
|
||||
}
|
||||
]
|
||||
],
|
||||
"system": "系统提示词(选填)"
|
||||
}
|
||||
]
|
||||
```
|
||||
@@ -97,7 +103,8 @@
|
||||
"columns": {
|
||||
"messages": "conversations",
|
||||
"role": "from",
|
||||
"content": "value"
|
||||
"content": "value",
|
||||
"system": "system"
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
@@ -24,9 +24,7 @@ class BelleMultiturn(datasets.GeneratorBasedBuilder):
|
||||
|
||||
def _info(self):
|
||||
features = datasets.Features({
|
||||
"instruction": datasets.Value("string"),
|
||||
"output": datasets.Value("string"),
|
||||
"history": datasets.Sequence(datasets.Sequence(datasets.Value("string")))
|
||||
"conversations": [{"from": datasets.Value("string"), "value": datasets.Value("string")}]
|
||||
})
|
||||
return datasets.DatasetInfo(
|
||||
description=_DESCRIPTION,
|
||||
@@ -51,6 +49,7 @@ class BelleMultiturn(datasets.GeneratorBasedBuilder):
|
||||
with open(filepath, "r", encoding="utf-8") as f:
|
||||
for key, row in enumerate(f):
|
||||
data = json.loads(row)
|
||||
conversations = []
|
||||
prompt = data["instruction"].strip()
|
||||
response = data["output"].strip()
|
||||
|
||||
@@ -58,7 +57,8 @@ class BelleMultiturn(datasets.GeneratorBasedBuilder):
|
||||
human_idx = prompt.rfind("Human:")
|
||||
query = prompt[human_idx+6:assist_idx].strip()
|
||||
prompt = prompt[:human_idx].strip()
|
||||
history = []
|
||||
conversations.insert(0, {"from": "gpt", "value": response})
|
||||
conversations.insert(0, {"from": "human", "value": query})
|
||||
|
||||
while prompt.rfind("Assistant:") != -1:
|
||||
assist_idx = prompt.rfind("Assistant:")
|
||||
@@ -66,13 +66,10 @@ class BelleMultiturn(datasets.GeneratorBasedBuilder):
|
||||
if human_idx != -1:
|
||||
old_query = prompt[human_idx+6:assist_idx].strip()
|
||||
old_resp = prompt[assist_idx+10:].strip()
|
||||
history.insert(0, (old_query, old_resp))
|
||||
conversations.insert(0, {"from": "gpt", "value": old_resp})
|
||||
conversations.insert(0, {"from": "human", "value": old_query})
|
||||
else:
|
||||
break
|
||||
prompt = prompt[:human_idx].strip()
|
||||
|
||||
yield key, {
|
||||
"instruction": query,
|
||||
"output": response,
|
||||
"history": history
|
||||
}
|
||||
yield key, {"conversations": conversations}
|
||||
|
||||
1
data/glaive_toolcall_10k.json.REMOVED.git-id
Normal file
1
data/glaive_toolcall_10k.json.REMOVED.git-id
Normal file
@@ -0,0 +1 @@
|
||||
4748dff00d1dc42768a5b6cc772143c313017812
|
||||
@@ -1 +0,0 @@
|
||||
38c89869c6aeca2a3af9ea1e09afe460f9b46810
|
||||
@@ -66,6 +66,4 @@ class UltraChat(datasets.GeneratorBasedBuilder):
|
||||
"from": "human" if i % 2 == 0 else "gpt",
|
||||
"value": content[i]
|
||||
} for i in range(len(content))]
|
||||
yield key, {
|
||||
"conversations": conversations
|
||||
}
|
||||
yield key, {"conversations": conversations}
|
||||
|
||||
@@ -1,50 +0,0 @@
|
||||
Machine learning (ML) is a field devoted to understanding and building methods that let machines "learn" – that is, methods that leverage data to improve computer performance on some set of tasks.
|
||||
Machine learning algorithms build a model based on sample data, known as training data, in order to make predictions or decisions without being explicitly programmed to do so. Machine learning algorithms are used in a wide variety of applications, such as in medicine, email filtering, speech recognition, agriculture, and computer vision, where it is difficult or unfeasible to develop conventional algorithms to perform the needed tasks.
|
||||
A subset of machine learning is closely related to computational statistics, which focuses on making predictions using computers, but not all machine learning is statistical learning. The study of mathematical optimization delivers methods, theory and application domains to the field of machine learning. Data mining is a related field of study, focusing on exploratory data analysis through unsupervised learning.
|
||||
Some implementations of machine learning use data and neural networks in a way that mimics the working of a biological brain.
|
||||
In its application across business problems, machine learning is also referred to as predictive analytics.
|
||||
Learning algorithms work on the basis that strategies, algorithms, and inferences that worked well in the past are likely to continue working well in the future. These inferences can sometimes be obvious, such as "since the sun rose every morning for the last 10,000 days, it will probably rise tomorrow morning as well". Other times, they can be more nuanced, such as "X% of families have geographically separate species with color variants, so there is a Y% chance that undiscovered black swans exist".
|
||||
Machine learning programs can perform tasks without being explicitly programmed to do so. It involves computers learning from data provided so that they carry out certain tasks. For simple tasks assigned to computers, it is possible to program algorithms telling the machine how to execute all steps required to solve the problem at hand; on the computer's part, no learning is needed. For more advanced tasks, it can be challenging for a human to manually create the needed algorithms. In practice, it can turn out to be more effective to help the machine develop its own algorithm, rather than having human programmers specify every needed step.
|
||||
The discipline of machine learning employs various approaches to teach computers to accomplish tasks where no fully satisfactory algorithm is available. In cases where vast numbers of potential answers exist, one approach is to label some of the correct answers as valid. This can then be used as training data for the computer to improve the algorithm(s) it uses to determine correct answers. For example, to train a system for the task of digital character recognition, the MNIST dataset of handwritten digits has often been used.
|
||||
The term machine learning was coined in 1959 by Arthur Samuel, an IBM employee and pioneer in the field of computer gaming and artificial intelligence. The synonym self-teaching computers was also used in this time period.
|
||||
By the early 1960s an experimental "learning machine" with punched tape memory, called Cybertron, had been developed by Raytheon Company to analyze sonar signals, electrocardiograms, and speech patterns using rudimentary reinforcement learning. It was repetitively "trained" by a human operator/teacher to recognize patterns and equipped with a "goof" button to cause it to re-evaluate incorrect decisions. A representative book on research into machine learning during the 1960s was Nilsson's book on Learning Machines, dealing mostly with machine learning for pattern classification. Interest related to pattern recognition continued into the 1970s, as described by Duda and Hart in 1973. In 1981 a report was given on using teaching strategies so that a neural network learns to recognize 40 characters (26 letters, 10 digits, and 4 special symbols) from a computer terminal.
|
||||
Tom M. Mitchell provided a widely quoted, more formal definition of the algorithms studied in the machine learning field: "A computer program is said to learn from experience E with respect to some class of tasks T and performance measure P if its performance at tasks in T, as measured by P, improves with experience E." This definition of the tasks in which machine learning is concerned offers a fundamentally operational definition rather than defining the field in cognitive terms. This follows Alan Turing's proposal in his paper "Computing Machinery and Intelligence", in which the question "Can machines think?" is replaced with the question "Can machines do what we (as thinking entities) can do?".
|
||||
Modern-day machine learning has two objectives, one is to classify data based on models which have been developed, the other purpose is to make predictions for future outcomes based on these models. A hypothetical algorithm specific to classifying data may use computer vision of moles coupled with supervised learning in order to train it to classify the cancerous moles. A machine learning algorithm for stock trading may inform the trader of future potential predictions.
|
||||
As a scientific endeavor, machine learning grew out of the quest for artificial intelligence (AI). In the early days of AI as an academic discipline, some researchers were interested in having machines learn from data. They attempted to approach the problem with various symbolic methods, as well as what were then termed "neural networks"; these were mostly perceptrons and other models that were later found to be reinventions of the generalized linear models of statistics. Probabilistic reasoning was also employed, especially in automated medical diagnosis.: 488
|
||||
However, an increasing emphasis on the logical, knowledge-based approach caused a rift between AI and machine learning. Probabilistic systems were plagued by theoretical and practical problems of data acquisition and representation.: 488 By 1980, expert systems had come to dominate AI, and statistics was out of favor. Work on symbolic/knowledge-based learning did continue within AI, leading to inductive logic programming, but the more statistical line of research was now outside the field of AI proper, in pattern recognition and information retrieval.: 708–710, 755 Neural networks research had been abandoned by AI and computer science around the same time. This line, too, was continued outside the AI/CS field, as "connectionism", by researchers from other disciplines including Hopfield, Rumelhart, and Hinton. Their main success came in the mid-1980s with the reinvention of backpropagation.: 25
|
||||
Machine learning (ML), reorganized and recognized as its own field, started to flourish in the 1990s. The field changed its goal from achieving artificial intelligence to tackling solvable problems of a practical nature. It shifted focus away from the symbolic approaches it had inherited from AI, and toward methods and models borrowed from statistics, fuzzy logic, and probability theory.
|
||||
Machine learning and data mining often employ the same methods and overlap significantly, but while machine learning focuses on prediction, based on known properties learned from the training data, data mining focuses on the discovery of (previously) unknown properties in the data (this is the analysis step of knowledge discovery in databases). Data mining uses many machine learning methods, but with different goals; on the other hand, machine learning also employs data mining methods as "unsupervised learning" or as a preprocessing step to improve learner accuracy. Much of the confusion between these two research communities (which do often have separate conferences and separate journals, ECML PKDD being a major exception) comes from the basic assumptions they work with: in machine learning, performance is usually evaluated with respect to the ability to reproduce known knowledge, while in knowledge discovery and data mining (KDD) the key task is the discovery of previously unknown knowledge. Evaluated with respect to known knowledge, an uninformed (unsupervised) method will easily be outperformed by other supervised methods, while in a typical KDD task, supervised methods cannot be used due to the unavailability of training data.
|
||||
Machine learning also has intimate ties to optimization: many learning problems are formulated as minimization of some loss function on a training set of examples. Loss functions express the discrepancy between the predictions of the model being trained and the actual problem instances (for example, in classification, one wants to assign a label to instances, and models are trained to correctly predict the pre-assigned labels of a set of examples).
|
||||
The difference between optimization and machine learning arises from the goal of generalization: while optimization algorithms can minimize the loss on a training set, machine learning is concerned with minimizing the loss on unseen samples. Characterizing the generalization of various learning algorithms is an active topic of current research, especially for deep learning algorithms.
|
||||
Machine learning and statistics are closely related fields in terms of methods, but distinct in their principal goal: statistics draws population inferences from a sample, while machine learning finds generalizable predictive patterns. According to Michael I. Jordan, the ideas of machine learning, from methodological principles to theoretical tools, have had a long pre-history in statistics. He also suggested the term data science as a placeholder to call the overall field.
|
||||
Leo Breiman distinguished two statistical modeling paradigms: data model and algorithmic model, wherein "algorithmic model" means more or less the machine learning algorithms like Random Forest.
|
||||
Some statisticians have adopted methods from machine learning, leading to a combined field that they call statistical learning.
|
||||
Analytical and computational techniques derived from deep-rooted physics of disordered systems can be extended to large-scale problems, including machine learning, e.g., to analyze the weight space of deep neural networks. Statistical physics is thus finding applications in the area of medical diagnostics.
|
||||
A core objective of a learner is to generalize from its experience. Generalization in this context is the ability of a learning machine to perform accurately on new, unseen examples/tasks after having experienced a learning data set. The training examples come from some generally unknown probability distribution (considered representative of the space of occurrences) and the learner has to build a general model about this space that enables it to produce sufficiently accurate predictions in new cases.
|
||||
The computational analysis of machine learning algorithms and their performance is a branch of theoretical computer science known as computational learning theory via the Probably Approximately Correct Learning (PAC) model. Because training sets are finite and the future is uncertain, learning theory usually does not yield guarantees of the performance of algorithms. Instead, probabilistic bounds on the performance are quite common. The bias–variance decomposition is one way to quantify generalization error.
|
||||
For the best performance in the context of generalization, the complexity of the hypothesis should match the complexity of the function underlying the data. If the hypothesis is less complex than the function, then the model has under fitted the data. If the complexity of the model is increased in response, then the training error decreases. But if the hypothesis is too complex, then the model is subject to overfitting and generalization will be poorer.
|
||||
In addition to performance bounds, learning theorists study the time complexity and feasibility of learning. In computational learning theory, a computation is considered feasible if it can be done in polynomial time. There are two kinds of time complexity results: Positive results show that a certain class of functions can be learned in polynomial time. Negative results show that certain classes cannot be learned in polynomial time.
|
||||
Machine learning approaches are traditionally divided into three broad categories, which correspond to learning paradigms, depending on the nature of the "signal" or "feedback" available to the learning system:
|
||||
Supervised learning: The computer is presented with example inputs and their desired outputs, given by a "teacher", and the goal is to learn a general rule that maps inputs to outputs.
|
||||
Unsupervised learning: No labels are given to the learning algorithm, leaving it on its own to find structure in its input. Unsupervised learning can be a goal in itself (discovering hidden patterns in data) or a means towards an end (feature learning).
|
||||
Reinforcement learning: A computer program interacts with a dynamic environment in which it must perform a certain goal (such as driving a vehicle or playing a game against an opponent). As it navigates its problem space, the program is provided feedback that's analogous to rewards, which it tries to maximize. Although each algorithm has advantages and limitations, no single algorithm works for all problems.
|
||||
Supervised learning algorithms build a mathematical model of a set of data that contains both the inputs and the desired outputs. The data is known as training data, and consists of a set of training examples. Each training example has one or more inputs and the desired output, also known as a supervisory signal. In the mathematical model, each training example is represented by an array or vector, sometimes called a feature vector, and the training data is represented by a matrix. Through iterative optimization of an objective function, supervised learning algorithms learn a function that can be used to predict the output associated with new inputs. An optimal function will allow the algorithm to correctly determine the output for inputs that were not a part of the training data. An algorithm that improves the accuracy of its outputs or predictions over time is said to have learned to perform that task.
|
||||
Types of supervised-learning algorithms include active learning, classification and regression. Classification algorithms are used when the outputs are restricted to a limited set of values, and regression algorithms are used when the outputs may have any numerical value within a range. As an example, for a classification algorithm that filters emails, the input would be an incoming email, and the output would be the name of the folder in which to file the email.
|
||||
Similarity learning is an area of supervised machine learning closely related to regression and classification, but the goal is to learn from examples using a similarity function that measures how similar or related two objects are. It has applications in ranking, recommendation systems, visual identity tracking, face verification, and speaker verification.
|
||||
Unsupervised learning algorithms take a set of data that contains only inputs, and find structure in the data, like grouping or clustering of data points. The algorithms, therefore, learn from test data that has not been labeled, classified or categorized. Instead of responding to feedback, unsupervised learning algorithms identify commonalities in the data and react based on the presence or absence of such commonalities in each new piece of data. A central application of unsupervised learning is in the field of density estimation in statistics, such as finding the probability density function. Though unsupervised learning encompasses other domains involving summarizing and explaining data features. Unsupervised learning algorithms streamlined the process of survey and graph large indel based haplotypes of a gene of interest from pan-genome.
|
||||
Cluster analysis is the assignment of a set of observations into subsets (called clusters) so that observations within the same cluster are similar according to one or more predesignated criteria, while observations drawn from different clusters are dissimilar. Different clustering techniques make different assumptions on the structure of the data, often defined by some similarity metric and evaluated, for example, by internal compactness, or the similarity between members of the same cluster, and separation, the difference between clusters. Other methods are based on estimated density and graph connectivity.
|
||||
Semi-supervised learning falls between unsupervised learning (without any labeled training data) and supervised learning (with completely labeled training data). Some of the training examples are missing training labels, yet many machine-learning researchers have found that unlabeled data, when used in conjunction with a small amount of labeled data, can produce a considerable improvement in learning accuracy.
|
||||
In weakly supervised learning, the training labels are noisy, limited, or imprecise; however, these labels are often cheaper to obtain, resulting in larger effective training sets.
|
||||
Reinforcement learning is an area of machine learning concerned with how software agents ought to take actions in an environment so as to maximize some notion of cumulative reward. Due to its generality, the field is studied in many other disciplines, such as game theory, control theory, operations research, information theory, simulation-based optimization, multi-agent systems, swarm intelligence, statistics and genetic algorithms. In machine learning, the environment is typically represented as a Markov decision process (MDP). Many reinforcements learning algorithms use dynamic programming techniques. Reinforcement learning algorithms do not assume knowledge of an exact mathematical model of the MDP and are used when exact models are infeasible. Reinforcement learning algorithms are used in autonomous vehicles or in learning to play a game against a human opponent.
|
||||
Dimensionality reduction is a process of reducing the number of random variables under consideration by obtaining a set of principal variables. In other words, it is a process of reducing the dimension of the feature set, also called the "number of features". Most of the dimensionality reduction techniques can be considered as either feature elimination or extraction. One of the popular methods of dimensionality reduction is principal component analysis (PCA). PCA involves changing higher-dimensional data (e.g., 3D) to a smaller space (e.g., 2D). This results in a smaller dimension of data (2D instead of 3D), while keeping all original variables in the model without changing the data. The manifold hypothesis proposes that high-dimensional data sets lie along low-dimensional manifolds, and many dimensionality reduction techniques make this assumption, leading to the area of manifold learning and manifold regularization.
|
||||
Although machine learning has been transformative in some fields, machine-learning programs often fail to deliver expected results. Reasons for this are numerous: lack of (suitable) data, lack of access to the data, data bias, privacy problems, badly chosen tasks and algorithms, wrong tools and people, lack of resources, and evaluation problems.
|
||||
In 2018, a self-driving car from Uber failed to detect a pedestrian, who was killed after a collision. Attempts to use machine learning in healthcare with the IBM Watson system failed to deliver even after years of time and billions of dollars invested.
|
||||
Machine learning has been used as a strategy to update the evidence related to a systematic review and increased reviewer burden related to the growth of biomedical literature. While it has improved with training sets, it has not yet developed sufficiently to reduce the workload burden without limiting the necessary sensitivity for the findings research themselves.
|
||||
Machine learning approaches in particular can suffer from different data biases. A machine learning system trained specifically on current customers may not be able to predict the needs of new customer groups that are not represented in the training data. When trained on human-made data, machine learning is likely to pick up the constitutional and unconscious biases already present in society. Language models learned from data have been shown to contain human-like biases. Machine learning systems used for criminal risk assessment have been found to be biased against black people. In 2015, Google photos would often tag black people as gorillas, and in 2018 this still was not well resolved, but Google reportedly was still using the workaround to remove all gorillas from the training data, and thus was not able to recognize real gorillas at all. Similar issues with recognizing non-white people have been found in many other systems. In 2016, Microsoft tested a chatbot that learned from Twitter, and it quickly picked up racist and sexist language. Because of such challenges, the effective use of machine learning may take longer to be adopted in other domains. Concern for fairness in machine learning, that is, reducing bias in machine learning and propelling its use for human good is increasingly expressed by artificial intelligence scientists, including Fei-Fei Li, who reminds engineers that "There's nothing artificial about AI...It's inspired by people, it's created by people, and—most importantly—it impacts people. It is a powerful tool we are only just beginning to understand, and that is a profound responsibility."
|
||||
Learners can also disappoint by "learning the wrong lesson". A toy example is that an image classifier trained only on pictures of brown horses and black cats might conclude that all brown patches are likely to be horses. A real-world example is that, unlike humans, current image classifiers often do not primarily make judgments from the spatial relationship between components of the picture, and they learn relationships between pixels that humans are oblivious to, but that still correlate with images of certain types of real objects. Modifying these patterns on a legitimate image can result in "adversarial" images that the system misclassifies.
|
||||
Adversarial vulnerabilities can also result in nonlinear systems, or from non-pattern perturbations. Some systems are so brittle that changing a single adversarial pixel predictably induces misclassification.[citation needed] Machine learning models are often vulnerable to manipulation and/or evasion via adversarial machine learning.
|
||||
Researchers have demonstrated how backdoors can be placed undetectably into classifying (e.g., for categories "spam" and well-visible "not spam" of posts) machine learning models which are often developed and/or trained by third parties. Parties can change the classification of any input, including in cases for which a type of data/software transparency is provided, possibly including white-box access.
|
||||
Machine learning poses a host of ethical questions. Systems that are trained on datasets collected with biases may exhibit these biases upon use (algorithmic bias), thus digitizing cultural prejudices. For example, in 1988, the UK's Commission for Racial Equality found that St. George's Medical School had been using a computer program trained from data of previous admissions staff and this program had denied nearly 60 candidates who were found to be either women or had non-European sounding names. Using job hiring data from a firm with racist hiring policies may lead to a machine learning system duplicating the bias by scoring job applicants by similarity to previous successful applicants. Responsible collection of data and documentation of algorithmic rules used by a system thus is a critical part of machine learning.
|
||||
AI can be well-equipped to make decisions in technical fields, which rely heavily on data and historical information. These decisions rely on the objectivity and logical reasoning. Because human languages contain biases, machines trained on language corpora will necessarily also learn these biases.
|
||||
Other forms of ethical challenges, not related to personal biases, are seen in health care. There are concerns among health care professionals that these systems might not be designed in the public's interest but as income-generating machines. This is especially true in the United States where there is a long-standing ethical dilemma of improving health care, but also increase profits. For example, the algorithms could be designed to provide patients with unnecessary tests or medication in which the algorithm's proprietary owners hold stakes. There is potential for machine learning in health care to provide professionals an additional tool to diagnose, medicate, and plan recovery paths for patients, but this requires these biases to be mitigated.
|
||||
Since the 2010s, advances in both machine learning algorithms and computer hardware have led to more efficient methods for training deep neural networks (a particular narrow subdomain of machine learning) that contain many layers of non-linear hidden units. By 2019, graphic processing units (GPUs), often with AI-specific enhancements, had displaced CPUs as the dominant method of training large-scale commercial cloud AI. OpenAI estimated the hardware computing used in the largest deep learning projects from AlexNet (2012) to AlphaZero (2017), and found a 300,000-fold increase in the amount of compute required, with a doubling-time trendline of 3.4 months.
|
||||
1
data/wiki_demo.txt.REMOVED.git-id
Normal file
1
data/wiki_demo.txt.REMOVED.git-id
Normal file
@@ -0,0 +1 @@
|
||||
c9cf509b7fdac5490cfd6dae72c2d7b8a60af6cb
|
||||
@@ -1,3 +1,37 @@
|
||||
[build-system]
|
||||
requires = ["setuptools>=61.0"]
|
||||
build-backend = "setuptools.build_meta"
|
||||
|
||||
[tool.black]
|
||||
line-length = 119
|
||||
target-version = ["py38"]
|
||||
|
||||
[tool.ruff]
|
||||
ignore = ["C408", "C901", "E501", "E731", "E741", "W605"]
|
||||
select = ["C", "E", "F", "I", "W"]
|
||||
line-length = 119
|
||||
|
||||
[tool.ruff.isort]
|
||||
lines-after-imports = 2
|
||||
known-first-party = ["llmtuner"]
|
||||
|
||||
[isort]
|
||||
default_section = "FIRSTPARTY"
|
||||
known_first_party = "llmtuner"
|
||||
known_third_party = [
|
||||
"accelerate",
|
||||
"datasets",
|
||||
"gradio",
|
||||
"numpy",
|
||||
"peft",
|
||||
"torch",
|
||||
"transformers",
|
||||
"trl"
|
||||
]
|
||||
line_length = 119
|
||||
lines_after_imports = 2
|
||||
multi_line_output = 3
|
||||
include_trailing_comma = true
|
||||
force_grid_wrap = 0
|
||||
use_parentheses = true
|
||||
ensure_newline_before_comments = true
|
||||
|
||||
@@ -1,14 +1,14 @@
|
||||
torch>=1.13.1
|
||||
transformers>=4.31.0,<4.35.0
|
||||
datasets>=2.14.0
|
||||
transformers>=4.36.2
|
||||
datasets>=2.14.3
|
||||
accelerate>=0.21.0
|
||||
peft>=0.6.0
|
||||
trl>=0.7.4
|
||||
peft>=0.7.0
|
||||
trl>=0.7.6
|
||||
gradio>=3.38.0,<4.0.0
|
||||
scipy
|
||||
einops
|
||||
sentencepiece
|
||||
protobuf
|
||||
tiktoken
|
||||
jieba
|
||||
rouge-chinese
|
||||
nltk
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
import os
|
||||
|
||||
import uvicorn
|
||||
|
||||
from llmtuner import ChatModel, create_app
|
||||
@@ -6,8 +8,8 @@ from llmtuner import ChatModel, create_app
|
||||
def main():
|
||||
chat_model = ChatModel()
|
||||
app = create_app(chat_model)
|
||||
print("Visit http://localhost:8000/docs for API document.")
|
||||
uvicorn.run(app, host="0.0.0.0", port=8000, workers=1)
|
||||
print("Visit http://localhost:{}/docs for API document.".format(os.environ.get("API_PORT", 8000)))
|
||||
uvicorn.run(app, host="0.0.0.0", port=int(os.environ.get("API_PORT", 8000)), workers=1)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@@ -1,10 +1,19 @@
|
||||
import readline
|
||||
from llmtuner import ChatModel
|
||||
from llmtuner.extras.misc import torch_gc
|
||||
|
||||
|
||||
try:
|
||||
import platform
|
||||
|
||||
if platform.system() != "Windows":
|
||||
import readline # noqa: F401
|
||||
except ImportError:
|
||||
print("Install `readline` for a better experience.")
|
||||
|
||||
|
||||
def main():
|
||||
chat_model = ChatModel()
|
||||
history = []
|
||||
messages = []
|
||||
print("Welcome to the CLI application, use `clear` to remove the history, use `exit` to exit the application.")
|
||||
|
||||
while True:
|
||||
@@ -20,19 +29,20 @@ def main():
|
||||
break
|
||||
|
||||
if query.strip() == "clear":
|
||||
history = []
|
||||
messages = []
|
||||
torch_gc()
|
||||
print("History has been removed.")
|
||||
continue
|
||||
|
||||
messages.append({"role": "user", "content": query})
|
||||
print("Assistant: ", end="", flush=True)
|
||||
|
||||
response = ""
|
||||
for new_text in chat_model.stream_chat(query, history):
|
||||
for new_text in chat_model.stream_chat(messages):
|
||||
print(new_text, end="", flush=True)
|
||||
response += new_text
|
||||
print()
|
||||
|
||||
history = history + [(query, response)]
|
||||
messages.append({"role": "assistant", "content": response})
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@@ -1,10 +1,11 @@
|
||||
# Level: api, webui > chat, eval > tuner > dsets > extras, hparams
|
||||
# Level: api, webui > chat, eval, train > data, model > extras, hparams
|
||||
|
||||
from llmtuner.api import create_app
|
||||
from llmtuner.chat import ChatModel
|
||||
from llmtuner.eval import Evaluator
|
||||
from llmtuner.tuner import export_model, run_exp
|
||||
from llmtuner.webui import create_ui, create_web_demo
|
||||
from .api import create_app
|
||||
from .chat import ChatModel
|
||||
from .eval import Evaluator
|
||||
from .train import export_model, run_exp
|
||||
from .webui import create_ui, create_web_demo
|
||||
|
||||
|
||||
__version__ = "0.2.2"
|
||||
__version__ = "0.5.0"
|
||||
__all__ = ["create_app", "ChatModel", "Evaluator", "export_model", "run_exp", "create_ui", "create_web_demo"]
|
||||
|
||||
@@ -1 +1,4 @@
|
||||
from llmtuner.api.app import create_app
|
||||
from .app import create_app
|
||||
|
||||
|
||||
__all__ = ["create_app"]
|
||||
|
||||
@@ -1,44 +1,68 @@
|
||||
import asyncio
|
||||
import json
|
||||
import uvicorn
|
||||
from fastapi import FastAPI, HTTPException, status
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
import os
|
||||
from contextlib import asynccontextmanager
|
||||
from sse_starlette import EventSourceResponse
|
||||
from typing import List, Tuple
|
||||
from typing import Any, Dict, Sequence
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from llmtuner.extras.misc import torch_gc
|
||||
from llmtuner.chat import ChatModel
|
||||
from llmtuner.api.protocol import (
|
||||
Role,
|
||||
Finish,
|
||||
ModelCard,
|
||||
ModelList,
|
||||
ChatMessage,
|
||||
DeltaMessage,
|
||||
from ..chat import ChatModel
|
||||
from ..data import Role as DataRole
|
||||
from ..extras.misc import torch_gc
|
||||
from ..extras.packages import is_fastapi_availble, is_starlette_available, is_uvicorn_available
|
||||
from .protocol import (
|
||||
ChatCompletionMessage,
|
||||
ChatCompletionRequest,
|
||||
ChatCompletionResponse,
|
||||
ChatCompletionStreamResponse,
|
||||
ChatCompletionResponseChoice,
|
||||
ChatCompletionResponseStreamChoice,
|
||||
ChatCompletionResponseUsage
|
||||
ChatCompletionResponseUsage,
|
||||
ChatCompletionStreamResponse,
|
||||
Finish,
|
||||
Function,
|
||||
FunctionCall,
|
||||
ModelCard,
|
||||
ModelList,
|
||||
Role,
|
||||
ScoreEvaluationRequest,
|
||||
ScoreEvaluationResponse,
|
||||
)
|
||||
|
||||
|
||||
if is_fastapi_availble():
|
||||
from fastapi import FastAPI, HTTPException, status
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
|
||||
|
||||
if is_starlette_available():
|
||||
from sse_starlette import EventSourceResponse
|
||||
|
||||
|
||||
if is_uvicorn_available():
|
||||
import uvicorn
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI): # collects GPU memory
|
||||
async def lifespan(app: "FastAPI"): # collects GPU memory
|
||||
yield
|
||||
torch_gc()
|
||||
|
||||
|
||||
def to_json(data: BaseModel) -> str:
|
||||
try: # pydantic v2
|
||||
def dictify(data: "BaseModel") -> Dict[str, Any]:
|
||||
try: # pydantic v2
|
||||
return data.model_dump(exclude_unset=True)
|
||||
except AttributeError: # pydantic v1
|
||||
return data.dict(exclude_unset=True)
|
||||
|
||||
|
||||
def jsonify(data: "BaseModel") -> str:
|
||||
try: # pydantic v2
|
||||
return json.dumps(data.model_dump(exclude_unset=True), ensure_ascii=False)
|
||||
except: # pydantic v1
|
||||
except AttributeError: # pydantic v1
|
||||
return data.json(exclude_unset=True, ensure_ascii=False)
|
||||
|
||||
|
||||
def create_app(chat_model: ChatModel) -> FastAPI:
|
||||
def create_app(chat_model: "ChatModel") -> "FastAPI":
|
||||
app = FastAPI(lifespan=lifespan)
|
||||
|
||||
app.add_middleware(
|
||||
@@ -49,6 +73,8 @@ def create_app(chat_model: ChatModel) -> FastAPI:
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
semaphore = asyncio.Semaphore(int(os.environ.get("MAX_CONCURRENT", 1)))
|
||||
|
||||
@app.get("/v1/models", response_model=ModelList)
|
||||
async def list_models():
|
||||
model_card = ModelCard(id="gpt-3.5-turbo")
|
||||
@@ -56,91 +82,145 @@ def create_app(chat_model: ChatModel) -> FastAPI:
|
||||
|
||||
@app.post("/v1/chat/completions", response_model=ChatCompletionResponse, status_code=status.HTTP_200_OK)
|
||||
async def create_chat_completion(request: ChatCompletionRequest):
|
||||
if len(request.messages) < 1 or request.messages[-1].role != Role.USER:
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid request")
|
||||
if not chat_model.can_generate:
|
||||
raise HTTPException(status_code=status.HTTP_405_METHOD_NOT_ALLOWED, detail="Not allowed")
|
||||
|
||||
query = request.messages[-1].content
|
||||
prev_messages = request.messages[:-1]
|
||||
if len(prev_messages) > 0 and prev_messages[0].role == Role.SYSTEM:
|
||||
system = prev_messages.pop(0).content
|
||||
if len(request.messages) == 0 or request.messages[-1].role != Role.USER:
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid length")
|
||||
|
||||
messages = [dictify(message) for message in request.messages]
|
||||
if len(messages) and messages[0]["role"] == Role.SYSTEM:
|
||||
system = messages.pop(0)["content"]
|
||||
else:
|
||||
system = None
|
||||
|
||||
history = []
|
||||
if len(prev_messages) % 2 == 0:
|
||||
for i in range(0, len(prev_messages), 2):
|
||||
if prev_messages[i].role == Role.USER and prev_messages[i+1].role == Role.ASSISTANT:
|
||||
history.append([prev_messages[i].content, prev_messages[i+1].content])
|
||||
else:
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Only supports u/a/u/a/u...")
|
||||
if len(messages) % 2 == 0:
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Only supports u/a/u/a/u...")
|
||||
|
||||
for i in range(len(messages)):
|
||||
if i % 2 == 0 and messages[i]["role"] not in [Role.USER, Role.TOOL]:
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid role")
|
||||
elif i % 2 == 1 and messages[i]["role"] not in [Role.ASSISTANT, Role.FUNCTION]:
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid role")
|
||||
elif messages[i]["role"] == Role.TOOL:
|
||||
messages[i]["role"] = DataRole.OBSERVATION
|
||||
|
||||
tool_list = request.tools
|
||||
if len(tool_list):
|
||||
try:
|
||||
tools = json.dumps([tool_list[0]["function"]], ensure_ascii=False)
|
||||
except Exception:
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid tools")
|
||||
else:
|
||||
tools = ""
|
||||
|
||||
async with semaphore:
|
||||
loop = asyncio.get_running_loop()
|
||||
return await loop.run_in_executor(None, chat_completion, messages, system, tools, request)
|
||||
|
||||
def chat_completion(messages: Sequence[Dict[str, str]], system: str, tools: str, request: ChatCompletionRequest):
|
||||
if request.stream:
|
||||
generate = predict(query, history, system, request)
|
||||
generate = stream_chat_completion(messages, system, tools, request)
|
||||
return EventSourceResponse(generate, media_type="text/event-stream")
|
||||
|
||||
response, (prompt_length, response_length) = chat_model.chat(
|
||||
query, history, system,
|
||||
responses = chat_model.chat(
|
||||
messages,
|
||||
system,
|
||||
tools,
|
||||
do_sample=request.do_sample,
|
||||
temperature=request.temperature,
|
||||
top_p=request.top_p,
|
||||
max_new_tokens=request.max_tokens,
|
||||
num_return_sequences=request.n
|
||||
num_return_sequences=request.n,
|
||||
)
|
||||
|
||||
prompt_length, response_length = 0, 0
|
||||
choices = []
|
||||
for i, response in enumerate(responses):
|
||||
if tools:
|
||||
result = chat_model.template.format_tools.extract(response.response_text)
|
||||
else:
|
||||
result = response.response_text
|
||||
|
||||
if isinstance(result, tuple):
|
||||
name, arguments = result
|
||||
function = Function(name=name, arguments=arguments)
|
||||
response_message = ChatCompletionMessage(
|
||||
role=Role.ASSISTANT, tool_calls=[FunctionCall(function=function)]
|
||||
)
|
||||
finish_reason = Finish.TOOL
|
||||
else:
|
||||
response_message = ChatCompletionMessage(role=Role.ASSISTANT, content=result)
|
||||
finish_reason = Finish.STOP if response.finish_reason == "stop" else Finish.LENGTH
|
||||
|
||||
choices.append(
|
||||
ChatCompletionResponseChoice(index=i, message=response_message, finish_reason=finish_reason)
|
||||
)
|
||||
prompt_length = response.prompt_length
|
||||
response_length += response.response_length
|
||||
|
||||
usage = ChatCompletionResponseUsage(
|
||||
prompt_tokens=prompt_length,
|
||||
completion_tokens=response_length,
|
||||
total_tokens=prompt_length+response_length
|
||||
total_tokens=prompt_length + response_length,
|
||||
)
|
||||
|
||||
choices = [ChatCompletionResponseChoice(
|
||||
index=i,
|
||||
message=ChatMessage(role=Role.ASSISTANT, content=choice),
|
||||
finish_reason=Finish.STOP
|
||||
) for i, choice in enumerate(response)]
|
||||
|
||||
return ChatCompletionResponse(model=request.model, choices=choices, usage=usage)
|
||||
|
||||
async def predict(query: str, history: List[Tuple[str, str]], system: str, request: ChatCompletionRequest):
|
||||
def stream_chat_completion(
|
||||
messages: Sequence[Dict[str, str]], system: str, tools: str, request: ChatCompletionRequest
|
||||
):
|
||||
choice_data = ChatCompletionResponseStreamChoice(
|
||||
index=0,
|
||||
delta=DeltaMessage(role=Role.ASSISTANT),
|
||||
finish_reason=None
|
||||
index=0, delta=ChatCompletionMessage(role=Role.ASSISTANT, content=""), finish_reason=None
|
||||
)
|
||||
chunk = ChatCompletionStreamResponse(model=request.model, choices=[choice_data])
|
||||
yield to_json(chunk)
|
||||
yield jsonify(chunk)
|
||||
|
||||
for new_text in chat_model.stream_chat(
|
||||
query, history, system,
|
||||
messages,
|
||||
system,
|
||||
tools,
|
||||
do_sample=request.do_sample,
|
||||
temperature=request.temperature,
|
||||
top_p=request.top_p,
|
||||
max_new_tokens=request.max_tokens
|
||||
max_new_tokens=request.max_tokens,
|
||||
):
|
||||
if len(new_text) == 0:
|
||||
continue
|
||||
|
||||
choice_data = ChatCompletionResponseStreamChoice(
|
||||
index=0,
|
||||
delta=DeltaMessage(content=new_text),
|
||||
finish_reason=None
|
||||
index=0, delta=ChatCompletionMessage(content=new_text), finish_reason=None
|
||||
)
|
||||
chunk = ChatCompletionStreamResponse(model=request.model, choices=[choice_data])
|
||||
yield to_json(chunk)
|
||||
yield jsonify(chunk)
|
||||
|
||||
choice_data = ChatCompletionResponseStreamChoice(
|
||||
index=0,
|
||||
delta=DeltaMessage(),
|
||||
finish_reason=Finish.STOP
|
||||
index=0, delta=ChatCompletionMessage(), finish_reason=Finish.STOP
|
||||
)
|
||||
chunk = ChatCompletionStreamResponse(model=request.model, choices=[choice_data])
|
||||
yield to_json(chunk)
|
||||
yield jsonify(chunk)
|
||||
yield "[DONE]"
|
||||
|
||||
@app.post("/v1/score/evaluation", response_model=ScoreEvaluationResponse, status_code=status.HTTP_200_OK)
|
||||
async def create_score_evaluation(request: ScoreEvaluationRequest):
|
||||
if chat_model.can_generate:
|
||||
raise HTTPException(status_code=status.HTTP_405_METHOD_NOT_ALLOWED, detail="Not allowed")
|
||||
|
||||
if len(request.messages) == 0:
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid request")
|
||||
|
||||
async with semaphore:
|
||||
loop = asyncio.get_running_loop()
|
||||
return await loop.run_in_executor(None, get_score, request)
|
||||
|
||||
def get_score(request: ScoreEvaluationRequest):
|
||||
scores = chat_model.get_scores(request.messages, max_length=request.max_length)
|
||||
return ScoreEvaluationResponse(model=request.model, scores=scores)
|
||||
|
||||
return app
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
chat_model = ChatModel()
|
||||
app = create_app(chat_model)
|
||||
uvicorn.run(app, host="0.0.0.0", port=8000, workers=1)
|
||||
uvicorn.run(app, host="0.0.0.0", port=int(os.environ.get("API_PORT", 8000)), workers=1)
|
||||
|
||||
@@ -1,30 +1,48 @@
|
||||
import time
|
||||
from enum import Enum
|
||||
from pydantic import BaseModel, Field
|
||||
from enum import Enum, unique
|
||||
from typing import List, Optional
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
from typing_extensions import Literal
|
||||
|
||||
|
||||
@unique
|
||||
class Role(str, Enum):
|
||||
USER = "user"
|
||||
ASSISTANT = "assistant"
|
||||
SYSTEM = "system"
|
||||
FUNCTION = "function"
|
||||
TOOL = "tool"
|
||||
|
||||
|
||||
@unique
|
||||
class Finish(str, Enum):
|
||||
STOP = "stop"
|
||||
LENGTH = "length"
|
||||
TOOL = "tool_calls"
|
||||
|
||||
|
||||
class ModelCard(BaseModel):
|
||||
id: str
|
||||
object: Optional[str] = "model"
|
||||
created: Optional[int] = Field(default_factory=lambda: int(time.time()))
|
||||
owned_by: Optional[str] = "owner"
|
||||
object: Literal["model"] = "model"
|
||||
created: int = Field(default_factory=lambda: int(time.time()))
|
||||
owned_by: Literal["owner"] = "owner"
|
||||
|
||||
|
||||
class ModelList(BaseModel):
|
||||
object: Optional[str] = "list"
|
||||
data: Optional[List[ModelCard]] = []
|
||||
object: Literal["list"] = "list"
|
||||
data: List[ModelCard] = []
|
||||
|
||||
|
||||
class Function(BaseModel):
|
||||
name: str
|
||||
arguments: str
|
||||
|
||||
|
||||
class FunctionCall(BaseModel):
|
||||
id: Literal["call_default"] = "call_default"
|
||||
type: Literal["function"] = "function"
|
||||
function: Function
|
||||
|
||||
|
||||
class ChatMessage(BaseModel):
|
||||
@@ -32,31 +50,33 @@ class ChatMessage(BaseModel):
|
||||
content: str
|
||||
|
||||
|
||||
class DeltaMessage(BaseModel):
|
||||
class ChatCompletionMessage(BaseModel):
|
||||
role: Optional[Role] = None
|
||||
content: Optional[str] = None
|
||||
tool_calls: Optional[List[FunctionCall]] = None
|
||||
|
||||
|
||||
class ChatCompletionRequest(BaseModel):
|
||||
model: str
|
||||
messages: List[ChatMessage]
|
||||
do_sample: Optional[bool] = True
|
||||
tools: Optional[list] = []
|
||||
do_sample: bool = True
|
||||
temperature: Optional[float] = None
|
||||
top_p: Optional[float] = None
|
||||
n: Optional[int] = 1
|
||||
n: int = 1
|
||||
max_tokens: Optional[int] = None
|
||||
stream: Optional[bool] = False
|
||||
stream: bool = False
|
||||
|
||||
|
||||
class ChatCompletionResponseChoice(BaseModel):
|
||||
index: int
|
||||
message: ChatMessage
|
||||
message: ChatCompletionMessage
|
||||
finish_reason: Finish
|
||||
|
||||
|
||||
class ChatCompletionResponseStreamChoice(BaseModel):
|
||||
index: int
|
||||
delta: DeltaMessage
|
||||
delta: ChatCompletionMessage
|
||||
finish_reason: Optional[Finish] = None
|
||||
|
||||
|
||||
@@ -67,17 +87,30 @@ class ChatCompletionResponseUsage(BaseModel):
|
||||
|
||||
|
||||
class ChatCompletionResponse(BaseModel):
|
||||
id: Optional[str] = "chatcmpl-default"
|
||||
object: Optional[str] = "chat.completion"
|
||||
created: Optional[int] = Field(default_factory=lambda: int(time.time()))
|
||||
id: Literal["chatcmpl-default"] = "chatcmpl-default"
|
||||
object: Literal["chat.completion"] = "chat.completion"
|
||||
created: int = Field(default_factory=lambda: int(time.time()))
|
||||
model: str
|
||||
choices: List[ChatCompletionResponseChoice]
|
||||
usage: ChatCompletionResponseUsage
|
||||
|
||||
|
||||
class ChatCompletionStreamResponse(BaseModel):
|
||||
id: Optional[str] = "chatcmpl-default"
|
||||
object: Optional[str] = "chat.completion.chunk"
|
||||
created: Optional[int] = Field(default_factory=lambda: int(time.time()))
|
||||
id: Literal["chatcmpl-default"] = "chatcmpl-default"
|
||||
object: Literal["chat.completion.chunk"] = "chat.completion.chunk"
|
||||
created: int = Field(default_factory=lambda: int(time.time()))
|
||||
model: str
|
||||
choices: List[ChatCompletionResponseStreamChoice]
|
||||
|
||||
|
||||
class ScoreEvaluationRequest(BaseModel):
|
||||
model: str
|
||||
messages: List[str]
|
||||
max_length: Optional[int] = None
|
||||
|
||||
|
||||
class ScoreEvaluationResponse(BaseModel):
|
||||
id: Literal["scoreeval-default"] = "scoreeval-default"
|
||||
object: Literal["score.evaluation"] = "score.evaluation"
|
||||
model: str
|
||||
scores: List[float]
|
||||
|
||||
@@ -1 +1,4 @@
|
||||
from llmtuner.chat.stream_chat import ChatModel
|
||||
from .chat_model import ChatModel
|
||||
|
||||
|
||||
__all__ = ["ChatModel"]
|
||||
|
||||
161
src/llmtuner/chat/chat_model.py
Normal file
161
src/llmtuner/chat/chat_model.py
Normal file
@@ -0,0 +1,161 @@
|
||||
from dataclasses import dataclass
|
||||
from threading import Thread
|
||||
from typing import Any, Dict, Generator, List, Literal, Optional, Sequence, Tuple
|
||||
|
||||
import torch
|
||||
from transformers import GenerationConfig, TextIteratorStreamer
|
||||
|
||||
from ..data import get_template_and_fix_tokenizer
|
||||
from ..extras.misc import get_logits_processor
|
||||
from ..hparams import get_infer_args
|
||||
from ..model import dispatch_model, load_model_and_tokenizer
|
||||
|
||||
|
||||
@dataclass
|
||||
class Response:
|
||||
response_text: str
|
||||
response_length: int
|
||||
prompt_length: int
|
||||
finish_reason: Literal["stop", "length"]
|
||||
|
||||
|
||||
class ChatModel:
|
||||
def __init__(self, args: Optional[Dict[str, Any]] = None) -> None:
|
||||
model_args, data_args, finetuning_args, self.generating_args = get_infer_args(args)
|
||||
self.can_generate = finetuning_args.stage == "sft"
|
||||
self.model, self.tokenizer = load_model_and_tokenizer(
|
||||
model_args, finetuning_args, is_trainable=False, add_valuehead=(not self.can_generate)
|
||||
)
|
||||
self.tokenizer.padding_side = "left" if self.can_generate else "right"
|
||||
self.model = dispatch_model(self.model)
|
||||
self.template = get_template_and_fix_tokenizer(data_args.template, self.tokenizer)
|
||||
|
||||
def _process_args(
|
||||
self,
|
||||
messages: Sequence[Dict[str, str]],
|
||||
system: Optional[str] = None,
|
||||
tools: Optional[str] = None,
|
||||
**input_kwargs,
|
||||
) -> Tuple[Dict[str, Any], int]:
|
||||
paired_messages = messages + [{"role": "assistant", "content": ""}]
|
||||
prompt, _ = self.template.encode_oneturn(
|
||||
tokenizer=self.tokenizer, messages=paired_messages, system=system, tools=tools
|
||||
)
|
||||
prompt_length = len(prompt)
|
||||
input_ids = torch.tensor([prompt], device=self.model.device)
|
||||
|
||||
do_sample = input_kwargs.pop("do_sample", None)
|
||||
temperature = input_kwargs.pop("temperature", None)
|
||||
top_p = input_kwargs.pop("top_p", None)
|
||||
top_k = input_kwargs.pop("top_k", None)
|
||||
num_return_sequences = input_kwargs.pop("num_return_sequences", None)
|
||||
repetition_penalty = input_kwargs.pop("repetition_penalty", None)
|
||||
max_length = input_kwargs.pop("max_length", None)
|
||||
max_new_tokens = input_kwargs.pop("max_new_tokens", None)
|
||||
|
||||
generating_args = self.generating_args.to_dict()
|
||||
generating_args.update(
|
||||
dict(
|
||||
do_sample=do_sample if do_sample is not None else generating_args["do_sample"],
|
||||
temperature=temperature or generating_args["temperature"],
|
||||
top_p=top_p or generating_args["top_p"],
|
||||
top_k=top_k or generating_args["top_k"],
|
||||
num_return_sequences=num_return_sequences or 1,
|
||||
repetition_penalty=repetition_penalty or generating_args["repetition_penalty"],
|
||||
eos_token_id=[self.tokenizer.eos_token_id] + self.tokenizer.additional_special_tokens_ids,
|
||||
pad_token_id=self.tokenizer.pad_token_id,
|
||||
)
|
||||
)
|
||||
|
||||
if isinstance(num_return_sequences, int) and num_return_sequences > 1:
|
||||
generating_args["do_sample"] = True
|
||||
|
||||
if max_length:
|
||||
generating_args.pop("max_new_tokens", None)
|
||||
generating_args["max_length"] = max_length
|
||||
|
||||
if max_new_tokens:
|
||||
generating_args.pop("max_length", None)
|
||||
generating_args["max_new_tokens"] = max_new_tokens
|
||||
|
||||
gen_kwargs = dict(
|
||||
inputs=input_ids,
|
||||
generation_config=GenerationConfig(**generating_args),
|
||||
logits_processor=get_logits_processor(),
|
||||
)
|
||||
|
||||
return gen_kwargs, prompt_length
|
||||
|
||||
@torch.inference_mode()
|
||||
def chat(
|
||||
self,
|
||||
messages: Sequence[Dict[str, str]],
|
||||
system: Optional[str] = None,
|
||||
tools: Optional[str] = None,
|
||||
**input_kwargs,
|
||||
) -> List[Response]:
|
||||
gen_kwargs, prompt_length = self._process_args(messages, system, tools, **input_kwargs)
|
||||
generate_output = self.model.generate(**gen_kwargs)
|
||||
response_ids = generate_output[:, prompt_length:]
|
||||
response = self.tokenizer.batch_decode(
|
||||
response_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True
|
||||
)
|
||||
results = []
|
||||
for i in range(len(response)):
|
||||
eos_index = (response_ids[i] == self.tokenizer.eos_token_id).nonzero()
|
||||
response_length = (eos_index[0].item() + 1) if len(eos_index) else len(response_ids[i])
|
||||
results.append(
|
||||
Response(
|
||||
response_text=response[i],
|
||||
response_length=response_length,
|
||||
prompt_length=prompt_length,
|
||||
finish_reason="stop" if len(eos_index) else "length",
|
||||
)
|
||||
)
|
||||
|
||||
return results
|
||||
|
||||
@torch.inference_mode()
|
||||
def stream_chat(
|
||||
self,
|
||||
messages: Sequence[Dict[str, str]],
|
||||
system: Optional[str] = None,
|
||||
tools: Optional[str] = None,
|
||||
**input_kwargs,
|
||||
) -> Generator[str, None, None]:
|
||||
gen_kwargs, _ = self._process_args(messages, system, tools, **input_kwargs)
|
||||
streamer = TextIteratorStreamer(self.tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True)
|
||||
gen_kwargs["streamer"] = streamer
|
||||
|
||||
thread = Thread(target=self.model.generate, kwargs=gen_kwargs)
|
||||
thread.start()
|
||||
|
||||
yield from streamer
|
||||
|
||||
@torch.inference_mode()
|
||||
def get_scores(self, batch_input: List[str], **input_kwargs) -> List[float]:
|
||||
max_length = input_kwargs.pop("max_length", None)
|
||||
device = getattr(self.model.pretrained_model, "device", "cuda")
|
||||
|
||||
inputs = self.tokenizer(
|
||||
batch_input,
|
||||
padding=True,
|
||||
truncation=True,
|
||||
max_length=max_length or getattr(self.model.config, "max_position_embeddings", 1024),
|
||||
return_tensors="pt",
|
||||
add_special_tokens=True,
|
||||
).to(device)
|
||||
|
||||
input_ids: torch.Tensor = inputs["input_ids"]
|
||||
_, _, values = self.model(**inputs, output_hidden_states=True, return_dict=True)
|
||||
|
||||
if getattr(self.model.config, "model_type", None) == "chatglm":
|
||||
values = torch.transpose(values, 0, 1)
|
||||
|
||||
scores = []
|
||||
for i in range(input_ids.size(0)):
|
||||
end_indexes = (input_ids[i] != self.tokenizer.pad_token_id).nonzero()
|
||||
end_index = end_indexes[-1].item() if len(end_indexes) else 0
|
||||
scores.append(values[i, end_index].nan_to_num().item())
|
||||
|
||||
return scores
|
||||
@@ -1,109 +0,0 @@
|
||||
import torch
|
||||
from typing import Any, Dict, Generator, List, Optional, Tuple
|
||||
from threading import Thread
|
||||
from transformers import GenerationConfig, TextIteratorStreamer
|
||||
|
||||
from llmtuner.extras.misc import dispatch_model, get_logits_processor
|
||||
from llmtuner.extras.template import get_template_and_fix_tokenizer
|
||||
from llmtuner.tuner.core import get_infer_args, load_model_and_tokenizer
|
||||
|
||||
|
||||
class ChatModel:
|
||||
|
||||
def __init__(self, args: Optional[Dict[str, Any]] = None) -> None:
|
||||
model_args, data_args, finetuning_args, self.generating_args = get_infer_args(args)
|
||||
self.model, self.tokenizer = load_model_and_tokenizer(model_args, finetuning_args)
|
||||
self.tokenizer.padding_side = "left"
|
||||
self.model = dispatch_model(self.model)
|
||||
self.template = get_template_and_fix_tokenizer(data_args.template, self.tokenizer)
|
||||
self.system_prompt = data_args.system_prompt
|
||||
|
||||
def process_args(
|
||||
self,
|
||||
query: str,
|
||||
history: Optional[List[Tuple[str, str]]] = None,
|
||||
system: Optional[str] = None,
|
||||
**input_kwargs
|
||||
) -> Tuple[Dict[str, Any], int]:
|
||||
system = system or self.system_prompt
|
||||
prompt, _ = self.template.encode_oneturn(
|
||||
tokenizer=self.tokenizer, query=query, resp="", history=history, system=system
|
||||
)
|
||||
prompt_length = len(prompt)
|
||||
input_ids = torch.tensor([prompt], device=self.model.device)
|
||||
|
||||
do_sample = input_kwargs.pop("do_sample", None)
|
||||
temperature = input_kwargs.pop("temperature", None)
|
||||
top_p = input_kwargs.pop("top_p", None)
|
||||
top_k = input_kwargs.pop("top_k", None)
|
||||
num_return_sequences = input_kwargs.pop("num_return_sequences", None)
|
||||
repetition_penalty = input_kwargs.pop("repetition_penalty", None)
|
||||
max_length = input_kwargs.pop("max_length", None)
|
||||
max_new_tokens = input_kwargs.pop("max_new_tokens", None)
|
||||
|
||||
generating_args = self.generating_args.to_dict()
|
||||
generating_args.update(dict(
|
||||
do_sample=do_sample if do_sample is not None else generating_args["do_sample"],
|
||||
temperature=temperature or generating_args["temperature"],
|
||||
top_p=top_p or generating_args["top_p"],
|
||||
top_k=top_k or generating_args["top_k"],
|
||||
num_return_sequences=num_return_sequences or 1,
|
||||
repetition_penalty=repetition_penalty or generating_args["repetition_penalty"],
|
||||
eos_token_id=[self.tokenizer.eos_token_id] + self.tokenizer.additional_special_tokens_ids,
|
||||
pad_token_id=self.tokenizer.pad_token_id
|
||||
))
|
||||
|
||||
if isinstance(num_return_sequences, int) and num_return_sequences > 1:
|
||||
generating_args["do_sample"] = True
|
||||
|
||||
if max_length:
|
||||
generating_args.pop("max_new_tokens", None)
|
||||
generating_args["max_length"] = max_length
|
||||
|
||||
if max_new_tokens:
|
||||
generating_args.pop("max_length", None)
|
||||
generating_args["max_new_tokens"] = max_new_tokens
|
||||
|
||||
gen_kwargs = dict(
|
||||
inputs=input_ids,
|
||||
generation_config=GenerationConfig(**generating_args),
|
||||
logits_processor=get_logits_processor()
|
||||
)
|
||||
|
||||
return gen_kwargs, prompt_length
|
||||
|
||||
@torch.inference_mode()
|
||||
def chat(
|
||||
self,
|
||||
query: str,
|
||||
history: Optional[List[Tuple[str, str]]] = None,
|
||||
system: Optional[str] = None,
|
||||
**input_kwargs
|
||||
) -> Tuple[List[str], Tuple[int, int]]:
|
||||
gen_kwargs, prompt_length = self.process_args(query, history, system, **input_kwargs)
|
||||
generate_output = self.model.generate(**gen_kwargs)
|
||||
response_ids = generate_output[:, prompt_length:]
|
||||
response = self.tokenizer.batch_decode(response_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True)
|
||||
response_length = 0
|
||||
for i in range(len(response_ids)):
|
||||
eos_index = (response_ids[i] == self.tokenizer.eos_token_id).nonzero()
|
||||
response_length += eos_index[0].item() if len(eos_index) else len(response_ids[i])
|
||||
|
||||
return response, (prompt_length, response_length)
|
||||
|
||||
@torch.inference_mode()
|
||||
def stream_chat(
|
||||
self,
|
||||
query: str,
|
||||
history: Optional[List[Tuple[str, str]]] = None,
|
||||
system: Optional[str] = None,
|
||||
**input_kwargs
|
||||
) -> Generator[str, None, None]:
|
||||
gen_kwargs, _ = self.process_args(query, history, system, **input_kwargs)
|
||||
streamer = TextIteratorStreamer(self.tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True)
|
||||
gen_kwargs["streamer"] = streamer
|
||||
|
||||
thread = Thread(target=self.model.generate, kwargs=gen_kwargs)
|
||||
thread.start()
|
||||
|
||||
yield from streamer
|
||||
6
src/llmtuner/data/__init__.py
Normal file
6
src/llmtuner/data/__init__.py
Normal file
@@ -0,0 +1,6 @@
|
||||
from .loader import get_dataset
|
||||
from .template import get_template_and_fix_tokenizer, templates
|
||||
from .utils import Role, split_dataset
|
||||
|
||||
|
||||
__all__ = ["get_dataset", "get_template_and_fix_tokenizer", "templates", "Role", "split_dataset"]
|
||||
108
src/llmtuner/data/aligner.py
Normal file
108
src/llmtuner/data/aligner.py
Normal file
@@ -0,0 +1,108 @@
|
||||
from functools import partial
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Union
|
||||
|
||||
from .utils import Role
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from datasets import Dataset, IterableDataset
|
||||
|
||||
from ..hparams import DataArguments
|
||||
from .parser import DatasetAttr
|
||||
|
||||
|
||||
def convert_alpaca(examples: Dict[str, List[Any]], dataset_attr: "DatasetAttr") -> Dict[str, List[Any]]:
|
||||
outputs = {"prompt": [], "response": [], "system": [], "tools": []}
|
||||
for i in range(len(examples[dataset_attr.prompt])):
|
||||
prompt = []
|
||||
if dataset_attr.history:
|
||||
for old_prompt, old_response in examples[dataset_attr.history][i]:
|
||||
prompt.append({"role": Role.USER, "content": old_prompt})
|
||||
prompt.append({"role": Role.ASSISTANT, "content": old_response})
|
||||
|
||||
instruction = examples[dataset_attr.prompt][i]
|
||||
if dataset_attr.query and examples[dataset_attr.query][i]:
|
||||
instruction += "\n" + examples[dataset_attr.query][i]
|
||||
prompt.append({"role": Role.USER, "content": instruction})
|
||||
|
||||
if dataset_attr.response:
|
||||
if isinstance(examples[dataset_attr.response][i], list):
|
||||
response = [
|
||||
{"role": Role.ASSISTANT, "content": content} for content in examples[dataset_attr.response][i]
|
||||
]
|
||||
else:
|
||||
response = [{"role": Role.ASSISTANT, "content": examples[dataset_attr.response][i]}]
|
||||
else:
|
||||
response = []
|
||||
|
||||
outputs["prompt"].append(prompt)
|
||||
outputs["response"].append(response)
|
||||
outputs["system"].append(examples[dataset_attr.system][i] if dataset_attr.system else "")
|
||||
outputs["tools"].append("")
|
||||
|
||||
return outputs
|
||||
|
||||
|
||||
def convert_sharegpt(examples: Dict[str, List[Any]], dataset_attr: "DatasetAttr") -> Dict[str, List[Any]]:
|
||||
outputs = {"prompt": [], "response": [], "system": [], "tools": []}
|
||||
tag_mapping = {
|
||||
dataset_attr.user_tag: Role.USER,
|
||||
dataset_attr.assistant_tag: Role.ASSISTANT,
|
||||
dataset_attr.observation_tag: Role.OBSERVATION,
|
||||
dataset_attr.function_tag: Role.FUNCTION,
|
||||
}
|
||||
for i, messages in enumerate(examples[dataset_attr.messages]):
|
||||
messages = messages[: len(messages) // 2 * 2] # should be multiples of 2
|
||||
if len(messages) == 0:
|
||||
continue
|
||||
|
||||
prompt = []
|
||||
response = []
|
||||
for turn_idx, message in enumerate(messages):
|
||||
if turn_idx % 2 == 0:
|
||||
accept_tags = [dataset_attr.user_tag, dataset_attr.observation_tag]
|
||||
else:
|
||||
accept_tags = [dataset_attr.assistant_tag, dataset_attr.function_tag]
|
||||
|
||||
if message[dataset_attr.role_tag] not in accept_tags:
|
||||
raise ValueError("Invalid role tag in {}.".format(messages))
|
||||
|
||||
prompt.append(
|
||||
{"role": tag_mapping[message[dataset_attr.role_tag]], "content": message[dataset_attr.content_tag]}
|
||||
)
|
||||
|
||||
last_message = prompt.pop(-1)
|
||||
response.append(last_message)
|
||||
outputs["prompt"].append(prompt)
|
||||
outputs["response"].append(response)
|
||||
outputs["system"].append(examples[dataset_attr.system][i] if dataset_attr.system else "")
|
||||
outputs["tools"].append(examples[dataset_attr.tools][i] if dataset_attr.tools else "")
|
||||
|
||||
return outputs
|
||||
|
||||
|
||||
def align_dataset(
|
||||
dataset: Union["Dataset", "IterableDataset"], dataset_attr: "DatasetAttr", data_args: "DataArguments"
|
||||
) -> Union["Dataset", "IterableDataset"]:
|
||||
r"""
|
||||
Aligned dataset:
|
||||
prompt: [{"role": "user", "content": "..."}]
|
||||
response: [{"role": "assistant", "content": "..."}]
|
||||
system: "..."
|
||||
tools: "..."
|
||||
"""
|
||||
if dataset_attr.formatting == "alpaca":
|
||||
convert_func = partial(convert_alpaca, dataset_attr=dataset_attr)
|
||||
else:
|
||||
convert_func = partial(convert_sharegpt, dataset_attr=dataset_attr)
|
||||
|
||||
column_names = list(next(iter(dataset)).keys())
|
||||
kwargs = {}
|
||||
if not data_args.streaming:
|
||||
kwargs = dict(
|
||||
num_proc=data_args.preprocessing_num_workers,
|
||||
load_from_cache_file=(not data_args.overwrite_cache),
|
||||
desc="Converting format of dataset",
|
||||
)
|
||||
|
||||
return dataset.map(convert_func, batched=True, remove_columns=column_names, **kwargs)
|
||||
148
src/llmtuner/data/formatter.py
Normal file
148
src/llmtuner/data/formatter.py
Normal file
@@ -0,0 +1,148 @@
|
||||
import json
|
||||
import re
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Dict, List, Literal, Sequence, Set, Tuple, Union
|
||||
|
||||
|
||||
SLOTS = Sequence[Union[str, Set[str], Dict[str, str]]]
|
||||
|
||||
|
||||
JSON_FORMAT_PROMPT = (
|
||||
""", in a JSON format representing the kwargs (e.g. ```{"input": "hello world", "num_beams": 5}```)"""
|
||||
)
|
||||
|
||||
|
||||
TOOL_SYSTEM_PROMPT = (
|
||||
"You have access to the following tools:\n{tool_text}"
|
||||
"Use the following format to answer the question:\n"
|
||||
"```\n"
|
||||
"Action: the action to take, should be one of [{tool_names}] if using a tool.\n"
|
||||
"Action Input: the input to the action{format_prompt}.\n"
|
||||
"```"
|
||||
)
|
||||
|
||||
|
||||
def default_tool_formatter(tools: List[Dict[str, Any]]) -> str:
|
||||
tool_text = ""
|
||||
tool_names = []
|
||||
for tool in tools:
|
||||
param_text = ""
|
||||
for name, param in tool["parameters"]["properties"].items():
|
||||
required = ", required" if name in tool["parameters"].get("required", []) else ""
|
||||
enum = ", should be one of [{}]".format(", ".join(param["enum"])) if param.get("enum", None) else ""
|
||||
param_text += " - {name} ({type}{required}): {desc}{enum}\n".format(
|
||||
name=name,
|
||||
type=param.get("type", ""),
|
||||
required=required,
|
||||
desc=param.get("description", ""),
|
||||
enum=enum,
|
||||
)
|
||||
|
||||
tool_text += "> Tool Name: {name}\nTool Description: {desc}\nTool Args:\n{args}\n".format(
|
||||
name=tool["name"], desc=tool.get("description", ""), args=param_text
|
||||
)
|
||||
tool_names.append(tool["name"])
|
||||
|
||||
return TOOL_SYSTEM_PROMPT.format(
|
||||
tool_text=tool_text, tool_names=", ".join(tool_names), format_prompt=JSON_FORMAT_PROMPT
|
||||
)
|
||||
|
||||
|
||||
def default_tool_extractor(content: str) -> Union[str, Tuple[str, str]]:
|
||||
regex = re.compile(r"Action:\s*([a-zA-Z0-9_]+).*?Action Input:\s*(.*)", re.DOTALL)
|
||||
action_match = re.search(regex, content)
|
||||
if not action_match:
|
||||
return content
|
||||
|
||||
tool_name = action_match.group(1).strip()
|
||||
tool_input = action_match.group(2).strip().strip('"').strip("```")
|
||||
try:
|
||||
arguments = json.loads(tool_input)
|
||||
except json.JSONDecodeError:
|
||||
return content
|
||||
|
||||
return tool_name, json.dumps(arguments, ensure_ascii=False)
|
||||
|
||||
|
||||
@dataclass
|
||||
class Formatter(ABC):
|
||||
slots: SLOTS = field(default_factory=list)
|
||||
tool_format: Literal["default"] = "default"
|
||||
|
||||
@abstractmethod
|
||||
def apply(self, **kwargs) -> SLOTS:
|
||||
...
|
||||
|
||||
def extract(self, content: str) -> Union[str, Tuple[str, str]]:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
@dataclass
|
||||
class EmptyFormatter(Formatter):
|
||||
def apply(self, **kwargs) -> SLOTS:
|
||||
return self.slots
|
||||
|
||||
|
||||
@dataclass
|
||||
class StringFormatter(Formatter):
|
||||
def apply(self, **kwargs) -> SLOTS:
|
||||
elements = []
|
||||
for slot in self.slots:
|
||||
if isinstance(slot, str):
|
||||
for name, value in kwargs.items():
|
||||
slot = slot.replace("{{" + name + "}}", value, 1)
|
||||
elements.append(slot)
|
||||
elif isinstance(slot, (dict, set)):
|
||||
elements.append(slot)
|
||||
else:
|
||||
raise ValueError("Input must be string, set[str] or dict[str, str], got {}".format(type(slot)))
|
||||
|
||||
return elements
|
||||
|
||||
|
||||
@dataclass
|
||||
class FunctionFormatter(Formatter):
|
||||
def apply(self, **kwargs) -> SLOTS:
|
||||
content = kwargs.pop("content")
|
||||
try:
|
||||
function = json.loads(content)
|
||||
name = function["name"]
|
||||
arguments = json.dumps(function["arguments"], ensure_ascii=False)
|
||||
except Exception:
|
||||
name, arguments = "", ""
|
||||
|
||||
elements = []
|
||||
for slot in self.slots:
|
||||
if isinstance(slot, str):
|
||||
slot = slot.replace("{{name}}", name).replace("{{arguments}}", arguments)
|
||||
elements.append(slot)
|
||||
elif isinstance(slot, (dict, set)):
|
||||
elements.append(slot)
|
||||
else:
|
||||
raise ValueError("Input must be string, set[str] or dict[str, str], got {}".format(type(slot)))
|
||||
|
||||
return elements
|
||||
|
||||
|
||||
@dataclass
|
||||
class ToolFormatter(Formatter):
|
||||
def apply(self, **kwargs) -> SLOTS:
|
||||
content = kwargs.pop("content")
|
||||
try:
|
||||
tools = json.loads(content)
|
||||
if not len(tools):
|
||||
return [""]
|
||||
|
||||
if self.tool_format == "default":
|
||||
return [default_tool_formatter(tools)]
|
||||
else:
|
||||
raise NotImplementedError
|
||||
except Exception:
|
||||
return [""]
|
||||
|
||||
def extract(self, content: str) -> Union[str, Tuple[str, str]]:
|
||||
if self.tool_format == "default":
|
||||
return default_tool_extractor(content)
|
||||
else:
|
||||
raise NotImplementedError
|
||||
193
src/llmtuner/data/loader.py
Normal file
193
src/llmtuner/data/loader.py
Normal file
@@ -0,0 +1,193 @@
|
||||
import inspect
|
||||
import os
|
||||
from typing import TYPE_CHECKING, List, Literal, Union
|
||||
|
||||
from datasets import concatenate_datasets, interleave_datasets, load_dataset, load_from_disk
|
||||
|
||||
from ..extras.constants import FILEEXT2TYPE
|
||||
from ..extras.logging import get_logger
|
||||
from .aligner import align_dataset
|
||||
from .parser import get_dataset_list
|
||||
from .preprocess import get_preprocess_and_print_func
|
||||
from .template import get_template_and_fix_tokenizer
|
||||
from .utils import checksum
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from datasets import Dataset, IterableDataset
|
||||
from transformers import Seq2SeqTrainingArguments
|
||||
from transformers.tokenization_utils import PreTrainedTokenizer
|
||||
|
||||
from ..hparams import DataArguments, ModelArguments
|
||||
from .parser import DatasetAttr
|
||||
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
def load_single_dataset(
|
||||
dataset_attr: "DatasetAttr",
|
||||
model_args: "ModelArguments",
|
||||
data_args: "DataArguments",
|
||||
):
|
||||
data_path, data_name, data_dir, data_files = None, None, None, None
|
||||
if dataset_attr.load_from in ["hf_hub", "ms_hub"]:
|
||||
data_path = dataset_attr.dataset_name
|
||||
data_name = dataset_attr.subset
|
||||
data_dir = dataset_attr.folder
|
||||
|
||||
elif dataset_attr.load_from == "script":
|
||||
data_path = os.path.join(data_args.dataset_dir, dataset_attr.dataset_name)
|
||||
data_name = dataset_attr.subset
|
||||
data_dir = dataset_attr.folder
|
||||
|
||||
elif dataset_attr.load_from == "file":
|
||||
data_files = []
|
||||
local_path: str = os.path.join(data_args.dataset_dir, dataset_attr.dataset_name)
|
||||
if os.path.isdir(local_path): # is directory
|
||||
for file_name in os.listdir(local_path):
|
||||
data_files.append(os.path.join(local_path, file_name))
|
||||
if data_path is None:
|
||||
data_path = FILEEXT2TYPE.get(file_name.split(".")[-1], None)
|
||||
elif data_path != FILEEXT2TYPE.get(file_name.split(".")[-1], None):
|
||||
raise ValueError("File types should be identical.")
|
||||
elif os.path.isfile(local_path): # is file
|
||||
data_files.append(local_path)
|
||||
data_path = FILEEXT2TYPE.get(local_path.split(".")[-1], None)
|
||||
else:
|
||||
raise ValueError("File not found.")
|
||||
|
||||
if data_path is None:
|
||||
raise ValueError("File extension must be txt, csv, json or jsonl.")
|
||||
|
||||
checksum(data_files, dataset_attr.dataset_sha1)
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
if dataset_attr.load_from == "ms_hub":
|
||||
try:
|
||||
from modelscope import MsDataset
|
||||
from modelscope.utils.config_ds import MS_DATASETS_CACHE
|
||||
|
||||
cache_dir = model_args.cache_dir or MS_DATASETS_CACHE
|
||||
dataset = MsDataset.load(
|
||||
dataset_name=data_path,
|
||||
subset_name=data_name,
|
||||
data_dir=data_dir,
|
||||
data_files=data_files,
|
||||
split=data_args.split,
|
||||
cache_dir=cache_dir,
|
||||
token=model_args.ms_hub_token,
|
||||
use_streaming=(data_args.streaming and (dataset_attr.load_from != "file")),
|
||||
).to_hf_dataset()
|
||||
except ImportError:
|
||||
raise ImportError("Please install modelscope via `pip install modelscope -U`")
|
||||
else:
|
||||
if "trust_remote_code" in inspect.signature(load_dataset).parameters: # for datasets==2.16.0
|
||||
kwargs = {"trust_remote_code": True}
|
||||
else:
|
||||
kwargs = {}
|
||||
|
||||
dataset = load_dataset(
|
||||
path=data_path,
|
||||
name=data_name,
|
||||
data_dir=data_dir,
|
||||
data_files=data_files,
|
||||
split=data_args.split,
|
||||
cache_dir=model_args.cache_dir,
|
||||
token=model_args.hf_hub_token,
|
||||
streaming=(data_args.streaming and (dataset_attr.load_from != "file")),
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
if data_args.streaming and (dataset_attr.load_from == "file"): # faster than specifying streaming=True
|
||||
dataset = dataset.to_iterable_dataset() # TODO: add num shards parameter
|
||||
|
||||
if data_args.max_samples is not None: # truncate dataset
|
||||
num_samples = min(data_args.max_samples, len(dataset))
|
||||
dataset = dataset.select(range(num_samples))
|
||||
|
||||
return align_dataset(dataset, dataset_attr, data_args)
|
||||
|
||||
|
||||
def merge_dataset(
|
||||
all_datasets: List[Union["Dataset", "IterableDataset"]],
|
||||
data_args: "DataArguments",
|
||||
training_args: "Seq2SeqTrainingArguments",
|
||||
) -> Union["Dataset", "IterableDataset"]:
|
||||
if len(all_datasets) == 1:
|
||||
return all_datasets[0]
|
||||
elif data_args.mix_strategy == "concat":
|
||||
if data_args.streaming:
|
||||
logger.warning("The samples between different datasets will not be mixed in streaming mode.")
|
||||
return concatenate_datasets(all_datasets)
|
||||
elif data_args.mix_strategy.startswith("interleave"):
|
||||
if not data_args.streaming:
|
||||
logger.warning("We recommend using `mix_strategy=concat` in non-streaming mode.")
|
||||
return interleave_datasets(
|
||||
datasets=all_datasets,
|
||||
probabilities=data_args.interleave_probs,
|
||||
seed=training_args.seed,
|
||||
stopping_strategy="first_exhausted" if data_args.mix_strategy.endswith("under") else "all_exhausted",
|
||||
)
|
||||
else:
|
||||
raise ValueError("Unknown mixing strategy.")
|
||||
|
||||
|
||||
def get_dataset(
|
||||
tokenizer: "PreTrainedTokenizer",
|
||||
model_args: "ModelArguments",
|
||||
data_args: "DataArguments",
|
||||
training_args: "Seq2SeqTrainingArguments",
|
||||
stage: Literal["pt", "sft", "rm", "ppo"],
|
||||
# split: Optional[str] = "train", # TODO: add split
|
||||
) -> Union["Dataset", "IterableDataset"]:
|
||||
template = get_template_and_fix_tokenizer(data_args.template, tokenizer)
|
||||
if data_args.train_on_prompt and template.efficient_eos:
|
||||
raise ValueError("Current template does not support `train_on_prompt`.")
|
||||
|
||||
# Load from cache
|
||||
if data_args.cache_path is not None:
|
||||
if os.path.exists(data_args.cache_path):
|
||||
logger.warning("Loading dataset from disk will ignore other data arguments.")
|
||||
dataset = load_from_disk(data_args.cache_path)
|
||||
if data_args.streaming:
|
||||
dataset = dataset.to_iterable_dataset()
|
||||
return dataset
|
||||
|
||||
if data_args.streaming:
|
||||
raise ValueError("Turn off dataset streaming to save cache files.")
|
||||
|
||||
with training_args.main_process_first(desc="load dataset"):
|
||||
all_datasets = []
|
||||
for dataset_attr in get_dataset_list(data_args): # TODO: add split
|
||||
all_datasets.append(load_single_dataset(dataset_attr, model_args, data_args))
|
||||
dataset = merge_dataset(all_datasets, data_args, training_args)
|
||||
|
||||
with training_args.main_process_first(desc="pre-process dataset"):
|
||||
preprocess_func, print_function = get_preprocess_and_print_func(
|
||||
tokenizer, template, data_args, training_args, stage
|
||||
)
|
||||
column_names = list(next(iter(dataset)).keys())
|
||||
kwargs = {}
|
||||
if not data_args.streaming:
|
||||
kwargs = dict(
|
||||
num_proc=data_args.preprocessing_num_workers,
|
||||
load_from_cache_file=(not data_args.overwrite_cache),
|
||||
desc="Running tokenizer on dataset",
|
||||
)
|
||||
|
||||
dataset = dataset.map(preprocess_func, batched=True, remove_columns=column_names, **kwargs)
|
||||
|
||||
if data_args.cache_path is not None and not os.path.exists(data_args.cache_path):
|
||||
if training_args.should_save:
|
||||
dataset.save_to_disk(data_args.cache_path)
|
||||
logger.info("Dataset cache saved at {}.".format(data_args.cache_path))
|
||||
|
||||
if training_args.should_log:
|
||||
try:
|
||||
print_function(next(iter(dataset)))
|
||||
except StopIteration:
|
||||
raise RuntimeError("Empty dataset!")
|
||||
|
||||
return dataset
|
||||
103
src/llmtuner/data/parser.py
Normal file
103
src/llmtuner/data/parser.py
Normal file
@@ -0,0 +1,103 @@
|
||||
import json
|
||||
import os
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, List, Literal, Optional
|
||||
|
||||
from ..extras.constants import DATA_CONFIG
|
||||
from ..extras.misc import use_modelscope
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..hparams import DataArguments
|
||||
|
||||
|
||||
@dataclass
|
||||
class DatasetAttr:
|
||||
load_from: Literal["hf_hub", "ms_hub", "script", "file"]
|
||||
dataset_name: Optional[str] = None
|
||||
dataset_sha1: Optional[str] = None
|
||||
subset: Optional[str] = None
|
||||
folder: Optional[str] = None
|
||||
ranking: Optional[bool] = False
|
||||
formatting: Optional[Literal["alpaca", "sharegpt"]] = "alpaca"
|
||||
|
||||
system: Optional[str] = None
|
||||
|
||||
prompt: Optional[str] = "instruction"
|
||||
query: Optional[str] = "input"
|
||||
response: Optional[str] = "output"
|
||||
history: Optional[str] = None
|
||||
|
||||
messages: Optional[str] = "conversations"
|
||||
tools: Optional[str] = None
|
||||
|
||||
role_tag: Optional[str] = "from"
|
||||
content_tag: Optional[str] = "value"
|
||||
user_tag: Optional[str] = "human"
|
||||
assistant_tag: Optional[str] = "gpt"
|
||||
observation_tag: Optional[str] = "observation"
|
||||
function_tag: Optional[str] = "function_call"
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return self.dataset_name
|
||||
|
||||
|
||||
def get_dataset_list(data_args: "DataArguments") -> List["DatasetAttr"]:
|
||||
dataset_names = [ds.strip() for ds in data_args.dataset.split(",")] if data_args.dataset is not None else []
|
||||
try:
|
||||
with open(os.path.join(data_args.dataset_dir, DATA_CONFIG), "r") as f:
|
||||
dataset_info = json.load(f)
|
||||
except Exception as err:
|
||||
if data_args.dataset is not None:
|
||||
raise ValueError(
|
||||
"Cannot open {} due to {}.".format(os.path.join(data_args.dataset_dir, DATA_CONFIG), str(err))
|
||||
)
|
||||
dataset_info = None
|
||||
|
||||
if data_args.interleave_probs is not None:
|
||||
data_args.interleave_probs = [float(prob.strip()) for prob in data_args.interleave_probs.split(",")]
|
||||
|
||||
dataset_list: List[DatasetAttr] = []
|
||||
for name in dataset_names:
|
||||
if name not in dataset_info:
|
||||
raise ValueError("Undefined dataset {} in {}.".format(name, DATA_CONFIG))
|
||||
|
||||
has_hf_url = "hf_hub_url" in dataset_info[name]
|
||||
has_ms_url = "ms_hub_url" in dataset_info[name]
|
||||
|
||||
if has_hf_url or has_ms_url:
|
||||
if (use_modelscope() and has_ms_url) or (not has_hf_url):
|
||||
dataset_attr = DatasetAttr("ms_hub", dataset_name=dataset_info[name]["ms_hub_url"])
|
||||
else:
|
||||
dataset_attr = DatasetAttr("hf_hub", dataset_name=dataset_info[name]["hf_hub_url"])
|
||||
elif "script_url" in dataset_info[name]:
|
||||
dataset_attr = DatasetAttr("script", dataset_name=dataset_info[name]["script_url"])
|
||||
else:
|
||||
dataset_attr = DatasetAttr(
|
||||
"file",
|
||||
dataset_name=dataset_info[name]["file_name"],
|
||||
dataset_sha1=dataset_info[name].get("file_sha1", None),
|
||||
)
|
||||
|
||||
dataset_attr.subset = dataset_info[name].get("subset", None)
|
||||
dataset_attr.folder = dataset_info[name].get("folder", None)
|
||||
dataset_attr.ranking = dataset_info[name].get("ranking", False)
|
||||
dataset_attr.formatting = dataset_info[name].get("formatting", "alpaca")
|
||||
|
||||
if "columns" in dataset_info[name]:
|
||||
if dataset_attr.formatting == "alpaca":
|
||||
column_names = ["prompt", "query", "response", "history"]
|
||||
else:
|
||||
column_names = ["messages", "tools"]
|
||||
|
||||
column_names += ["system"]
|
||||
for column_name in column_names:
|
||||
setattr(dataset_attr, column_name, dataset_info[name]["columns"].get(column_name, None))
|
||||
|
||||
if dataset_attr.formatting == "sharegpt" and "tags" in dataset_info[name]:
|
||||
for tag in ["role_tag", "content_tag", "user_tag", "assistant_tag", "observation_tag", "function_tag"]:
|
||||
setattr(dataset_attr, tag, dataset_info[name]["tags"].get(tag, None))
|
||||
|
||||
dataset_list.append(dataset_attr)
|
||||
|
||||
return dataset_list
|
||||
248
src/llmtuner/data/preprocess.py
Normal file
248
src/llmtuner/data/preprocess.py
Normal file
@@ -0,0 +1,248 @@
|
||||
from functools import partial
|
||||
from itertools import chain
|
||||
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Literal, Tuple
|
||||
|
||||
from ..extras.constants import IGNORE_INDEX
|
||||
from ..extras.logging import get_logger
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from transformers import Seq2SeqTrainingArguments
|
||||
from transformers.tokenization_utils import PreTrainedTokenizer
|
||||
|
||||
from ..hparams import DataArguments
|
||||
from .template import Template
|
||||
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
def preprocess_pretrain_dataset(
|
||||
examples: Dict[str, List[Any]], tokenizer: "PreTrainedTokenizer", data_args: "DataArguments"
|
||||
) -> Dict[str, List[List[int]]]:
|
||||
# build grouped texts with format `X1 X2 X3 ...`
|
||||
text_examples = [examples["prompt"][i][0]["content"] for i in range(len(examples["prompt"]))]
|
||||
tokenized_examples = tokenizer(text_examples, add_special_tokens=False)
|
||||
for i in range(len(tokenized_examples["input_ids"])):
|
||||
tokenized_examples["input_ids"][i] += [tokenizer.eos_token_id]
|
||||
tokenized_examples["attention_mask"][i] += [1]
|
||||
|
||||
concatenated_examples = {k: list(chain(*tokenized_examples[k])) for k in tokenized_examples.keys()}
|
||||
total_length = len(concatenated_examples[list(concatenated_examples.keys())[0]])
|
||||
block_size = data_args.cutoff_len
|
||||
# we drop the small remainder, and if the total_length < block_size, we exclude this batch
|
||||
total_length = (total_length // block_size) * block_size
|
||||
# split by chunks of cutoff_len
|
||||
result = {
|
||||
k: [t[i : i + block_size] for i in range(0, total_length, block_size)]
|
||||
for k, t in concatenated_examples.items()
|
||||
}
|
||||
return result
|
||||
|
||||
|
||||
def preprocess_supervised_dataset(
|
||||
examples: Dict[str, List[Any]],
|
||||
tokenizer: "PreTrainedTokenizer",
|
||||
template: "Template",
|
||||
data_args: "DataArguments",
|
||||
) -> Dict[str, List[List[int]]]:
|
||||
# build inputs with format `<bos> X Y <eos>` and labels with format `<ignore> ... <ignore> Y <eos>`
|
||||
# for multiturn examples, we only mask the prompt part in each prompt-response pair.
|
||||
model_inputs = {"input_ids": [], "attention_mask": [], "labels": []}
|
||||
|
||||
for i in range(len(examples["prompt"])):
|
||||
if len(examples["prompt"][i]) == 0 or len(examples["response"][i]) != 1:
|
||||
continue
|
||||
|
||||
messages = examples["prompt"][i] + examples["response"][i]
|
||||
input_ids, labels = [], []
|
||||
for turn_idx, (source_ids, target_ids) in enumerate(
|
||||
template.encode_multiturn(
|
||||
tokenizer, messages, examples["system"][i], examples["tools"][i], data_args.cutoff_len
|
||||
)
|
||||
):
|
||||
if data_args.train_on_prompt:
|
||||
source_mask = source_ids
|
||||
elif turn_idx != 0 and template.efficient_eos:
|
||||
source_mask = [tokenizer.eos_token_id] + [IGNORE_INDEX] * (len(source_ids) - 1)
|
||||
else:
|
||||
source_mask = [IGNORE_INDEX] * len(source_ids)
|
||||
|
||||
input_ids += source_ids + target_ids
|
||||
labels += source_mask + target_ids
|
||||
|
||||
if template.efficient_eos:
|
||||
input_ids += [tokenizer.eos_token_id]
|
||||
labels += [tokenizer.eos_token_id]
|
||||
|
||||
model_inputs["input_ids"].append(input_ids)
|
||||
model_inputs["attention_mask"].append([1] * len(input_ids))
|
||||
model_inputs["labels"].append(labels)
|
||||
|
||||
return model_inputs
|
||||
|
||||
|
||||
def preprocess_packed_supervised_dataset(
|
||||
examples: Dict[str, List[Any]],
|
||||
tokenizer: "PreTrainedTokenizer",
|
||||
template: "Template",
|
||||
data_args: "DataArguments",
|
||||
) -> Dict[str, List[List[int]]]:
|
||||
# build inputs with format `<bos> X1 Y1 <eos> <bos> X2 Y2 <eos>`
|
||||
# and labels with format `<ignore> ... <ignore> Y1 <eos> <ignore> ... <ignore> Y2 <eos>`
|
||||
model_inputs = {"input_ids": [], "attention_mask": [], "labels": []}
|
||||
input_ids, labels = [], []
|
||||
for i in range(len(examples["prompt"])):
|
||||
if len(examples["prompt"][i]) == 0 or len(examples["response"][i]) != 1:
|
||||
continue
|
||||
|
||||
messages = examples["prompt"][i] + examples["response"][i]
|
||||
for turn_idx, (source_ids, target_ids) in enumerate(
|
||||
template.encode_multiturn(tokenizer, messages, examples["system"][i], examples["tools"][i])
|
||||
):
|
||||
if data_args.train_on_prompt:
|
||||
source_mask = source_ids
|
||||
elif turn_idx != 0 and template.efficient_eos:
|
||||
source_mask = [tokenizer.eos_token_id] + [IGNORE_INDEX] * (len(source_ids) - 1)
|
||||
else:
|
||||
source_mask = [IGNORE_INDEX] * len(source_ids)
|
||||
|
||||
input_ids += source_ids + target_ids
|
||||
labels += source_mask + target_ids
|
||||
|
||||
if template.efficient_eos:
|
||||
input_ids += [tokenizer.eos_token_id]
|
||||
labels += [tokenizer.eos_token_id]
|
||||
|
||||
total_length = len(input_ids)
|
||||
block_size = data_args.cutoff_len
|
||||
# we drop the small remainder, and if the total_length < block_size, we exclude this batch
|
||||
total_length = (total_length // block_size) * block_size
|
||||
# split by chunks of cutoff_len
|
||||
for i in range(0, total_length, block_size):
|
||||
model_inputs["input_ids"].append(input_ids[i : i + block_size])
|
||||
model_inputs["attention_mask"].append([1] * block_size)
|
||||
model_inputs["labels"].append(labels[i : i + block_size])
|
||||
|
||||
return model_inputs
|
||||
|
||||
|
||||
def preprocess_unsupervised_dataset(
|
||||
examples: Dict[str, List[Any]],
|
||||
tokenizer: "PreTrainedTokenizer",
|
||||
template: "Template",
|
||||
data_args: "DataArguments",
|
||||
) -> Dict[str, List[List[int]]]:
|
||||
# build inputs with format `<bos> X` and labels with format `Y <eos>`
|
||||
model_inputs = {"input_ids": [], "attention_mask": [], "labels": []}
|
||||
|
||||
for i in range(len(examples["prompt"])):
|
||||
if len(examples["prompt"][i]) == 0 or len(examples["response"][i]) != 1:
|
||||
continue
|
||||
|
||||
messages = examples["prompt"][i] + examples["response"][i]
|
||||
input_ids, labels = template.encode_oneturn(
|
||||
tokenizer, messages, examples["system"][i], examples["tools"][i], data_args.cutoff_len
|
||||
)
|
||||
|
||||
if template.efficient_eos:
|
||||
labels += [tokenizer.eos_token_id]
|
||||
|
||||
model_inputs["input_ids"].append(input_ids)
|
||||
model_inputs["attention_mask"].append([1] * len(input_ids))
|
||||
model_inputs["labels"].append(labels)
|
||||
|
||||
return model_inputs
|
||||
|
||||
|
||||
def preprocess_pairwise_dataset(
|
||||
examples: Dict[str, List[Any]],
|
||||
tokenizer: "PreTrainedTokenizer",
|
||||
template: "Template",
|
||||
data_args: "DataArguments",
|
||||
) -> Dict[str, List[List[int]]]:
|
||||
# build input pairs with format `<bos> X`, `Y1 <eos>` and `Y2 <eos>`
|
||||
model_inputs = {"prompt_ids": [], "chosen_ids": [], "rejected_ids": []}
|
||||
for i in range(len(examples["prompt"])):
|
||||
if len(examples["prompt"][i]) == 0 or len(examples["response"][i]) < 2:
|
||||
continue
|
||||
|
||||
chosen_messages = examples["prompt"][i] + [examples["response"][i][0]]
|
||||
rejected_messages = examples["prompt"][i] + [examples["response"][i][1]]
|
||||
|
||||
prompt_ids, chosen_ids = template.encode_oneturn(
|
||||
tokenizer, chosen_messages, examples["system"][i], examples["tools"][i], data_args.cutoff_len
|
||||
)
|
||||
_, rejected_ids = template.encode_oneturn(
|
||||
tokenizer, rejected_messages, examples["system"][i], examples["tools"][i], data_args.cutoff_len
|
||||
)
|
||||
|
||||
if template.efficient_eos:
|
||||
chosen_ids += [tokenizer.eos_token_id]
|
||||
rejected_ids += [tokenizer.eos_token_id]
|
||||
|
||||
model_inputs["prompt_ids"].append(prompt_ids)
|
||||
model_inputs["chosen_ids"].append(chosen_ids)
|
||||
model_inputs["rejected_ids"].append(rejected_ids)
|
||||
|
||||
return model_inputs
|
||||
|
||||
|
||||
def print_supervised_dataset_example(example: Dict[str, List[int]], tokenizer: "PreTrainedTokenizer") -> None:
|
||||
print("input_ids:\n{}".format(example["input_ids"]))
|
||||
print("inputs:\n{}".format(tokenizer.decode(example["input_ids"], skip_special_tokens=False)))
|
||||
print("label_ids:\n{}".format(example["labels"]))
|
||||
print(
|
||||
"labels:\n{}".format(
|
||||
tokenizer.decode(list(filter(lambda x: x != IGNORE_INDEX, example["labels"])), skip_special_tokens=False)
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def print_pairwise_dataset_example(example: Dict[str, List[int]], tokenizer: "PreTrainedTokenizer") -> None:
|
||||
print("prompt_ids:\n{}".format(example["prompt_ids"]))
|
||||
print("prompt:\n{}".format(tokenizer.decode(example["prompt_ids"], skip_special_tokens=False)))
|
||||
print("chosen_ids:\n{}".format(example["chosen_ids"]))
|
||||
print("chosen:\n{}".format(tokenizer.decode(example["chosen_ids"], skip_special_tokens=False)))
|
||||
print("rejected_ids:\n{}".format(example["rejected_ids"]))
|
||||
print("rejected:\n{}".format(tokenizer.decode(example["rejected_ids"], skip_special_tokens=False)))
|
||||
|
||||
|
||||
def print_unsupervised_dataset_example(example: Dict[str, List[int]], tokenizer: "PreTrainedTokenizer") -> None:
|
||||
print("input_ids:\n{}".format(example["input_ids"]))
|
||||
print("inputs:\n{}".format(tokenizer.decode(example["input_ids"], skip_special_tokens=False)))
|
||||
|
||||
|
||||
def get_preprocess_and_print_func(
|
||||
tokenizer: "PreTrainedTokenizer",
|
||||
template: "Template",
|
||||
data_args: "DataArguments",
|
||||
training_args: "Seq2SeqTrainingArguments",
|
||||
stage: Literal["pt", "sft", "rm", "ppo"],
|
||||
) -> Tuple[Callable, Callable]:
|
||||
if stage == "pt":
|
||||
preprocess_func = partial(preprocess_pretrain_dataset, tokenizer=tokenizer, data_args=data_args)
|
||||
print_function = partial(print_unsupervised_dataset_example, tokenizer=tokenizer)
|
||||
elif stage == "sft" and not training_args.predict_with_generate:
|
||||
if data_args.sft_packing:
|
||||
preprocess_func = partial(
|
||||
preprocess_packed_supervised_dataset, tokenizer=tokenizer, template=template, data_args=data_args
|
||||
)
|
||||
else:
|
||||
preprocess_func = partial(
|
||||
preprocess_supervised_dataset, tokenizer=tokenizer, template=template, data_args=data_args
|
||||
)
|
||||
|
||||
print_function = partial(print_supervised_dataset_example, tokenizer=tokenizer)
|
||||
elif stage == "rm":
|
||||
preprocess_func = partial(
|
||||
preprocess_pairwise_dataset, tokenizer=tokenizer, template=template, data_args=data_args
|
||||
)
|
||||
print_function = partial(print_pairwise_dataset_example, tokenizer=tokenizer)
|
||||
else:
|
||||
preprocess_func = partial(
|
||||
preprocess_unsupervised_dataset, tokenizer=tokenizer, template=template, data_args=data_args
|
||||
)
|
||||
print_function = partial(print_unsupervised_dataset_example, tokenizer=tokenizer)
|
||||
|
||||
return preprocess_func, print_function
|
||||
574
src/llmtuner/data/template.py
Normal file
574
src/llmtuner/data/template.py
Normal file
@@ -0,0 +1,574 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Tuple, Union
|
||||
|
||||
from ..extras.logging import get_logger
|
||||
from .formatter import EmptyFormatter, FunctionFormatter, StringFormatter, ToolFormatter
|
||||
from .utils import Role, infer_max_len
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from transformers import PreTrainedTokenizer
|
||||
|
||||
from .formatter import Formatter
|
||||
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class Template:
|
||||
format_user: "Formatter"
|
||||
format_assistant: "Formatter"
|
||||
format_system: "Formatter"
|
||||
format_function: "Formatter"
|
||||
format_observation: "Formatter"
|
||||
format_tools: "Formatter"
|
||||
format_separator: "Formatter"
|
||||
default_system: str
|
||||
stop_words: List[str]
|
||||
efficient_eos: bool
|
||||
replace_eos: bool
|
||||
force_system: bool
|
||||
|
||||
def encode_oneturn(
|
||||
self,
|
||||
tokenizer: "PreTrainedTokenizer",
|
||||
messages: List[Dict[str, str]],
|
||||
system: Optional[str] = None,
|
||||
tools: Optional[str] = None,
|
||||
cutoff_len: Optional[int] = 1_000_000,
|
||||
reserved_label_len: Optional[int] = 16,
|
||||
) -> Tuple[List[int], List[int]]:
|
||||
r"""
|
||||
Returns a single pair of token ids representing prompt and response respectively.
|
||||
"""
|
||||
encoded_pairs = self._encode(tokenizer, messages, system, tools, cutoff_len, reserved_label_len)
|
||||
prompt_ids = []
|
||||
for query_ids, resp_ids in encoded_pairs[:-1]:
|
||||
prompt_ids += query_ids + resp_ids
|
||||
prompt_ids = prompt_ids + encoded_pairs[-1][0]
|
||||
answer_ids = encoded_pairs[-1][1]
|
||||
return prompt_ids, answer_ids
|
||||
|
||||
def encode_multiturn(
|
||||
self,
|
||||
tokenizer: "PreTrainedTokenizer",
|
||||
messages: List[Dict[str, str]],
|
||||
system: Optional[str] = None,
|
||||
tools: Optional[str] = None,
|
||||
cutoff_len: Optional[int] = 1_000_000,
|
||||
reserved_label_len: Optional[int] = 16,
|
||||
) -> Sequence[Tuple[List[int], List[int]]]:
|
||||
r"""
|
||||
Returns multiple pairs of token ids representing prompts and responses respectively.
|
||||
"""
|
||||
return self._encode(tokenizer, messages, system, tools, cutoff_len, reserved_label_len)
|
||||
|
||||
def _encode(
|
||||
self,
|
||||
tokenizer: "PreTrainedTokenizer",
|
||||
messages: List[Dict[str, str]],
|
||||
system: str,
|
||||
tools: str,
|
||||
cutoff_len: int,
|
||||
reserved_label_len: int,
|
||||
) -> Sequence[Tuple[List[int], List[int]]]:
|
||||
r"""
|
||||
Encodes formatted inputs to pairs of token ids.
|
||||
Turn 0: system + query resp
|
||||
Turn t: sep + query resp
|
||||
"""
|
||||
system = system or self.default_system
|
||||
encoded_messages = []
|
||||
for i, message in enumerate(messages):
|
||||
elements = []
|
||||
if i == 0 and (system or tools or self.force_system):
|
||||
tool_text = self.format_tools.apply(content=tools)[0] if tools else ""
|
||||
elements += self.format_system.apply(content=(system + tool_text))
|
||||
elif i > 0 and i % 2 == 0:
|
||||
elements += self.format_separator.apply()
|
||||
|
||||
if message["role"] == Role.USER:
|
||||
elements += self.format_user.apply(content=message["content"], idx=str(i // 2))
|
||||
elif message["role"] == Role.ASSISTANT:
|
||||
elements += self.format_assistant.apply(content=message["content"])
|
||||
elif message["role"] == Role.OBSERVATION:
|
||||
elements += self.format_observation.apply(content=message["content"])
|
||||
elif message["role"] == Role.FUNCTION:
|
||||
elements += self.format_function.apply(content=message["content"])
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
encoded_messages.append(self._convert_elements_to_ids(tokenizer, elements))
|
||||
|
||||
return self._make_pairs(encoded_messages, cutoff_len, reserved_label_len)
|
||||
|
||||
def _convert_elements_to_ids(
|
||||
self, tokenizer: "PreTrainedTokenizer", elements: List[Union[str, Dict[str, str]]]
|
||||
) -> List[int]:
|
||||
r"""
|
||||
Converts elements to token ids.
|
||||
"""
|
||||
token_ids = []
|
||||
for elem in elements:
|
||||
if isinstance(elem, str):
|
||||
if len(elem) != 0:
|
||||
token_ids += tokenizer.encode(elem, add_special_tokens=False)
|
||||
elif isinstance(elem, dict):
|
||||
token_ids += [tokenizer.convert_tokens_to_ids(elem.get("token"))]
|
||||
elif isinstance(elem, set):
|
||||
if "bos_token" in elem and tokenizer.bos_token_id:
|
||||
token_ids += [tokenizer.bos_token_id]
|
||||
elif "eos_token" in elem and tokenizer.eos_token_id:
|
||||
token_ids += [tokenizer.eos_token_id]
|
||||
else:
|
||||
raise ValueError("Input must be string, set[str] or dict[str, str], got {}".format(type(elem)))
|
||||
|
||||
return token_ids
|
||||
|
||||
def _make_pairs(
|
||||
self,
|
||||
encoded_messages: Sequence[List[int]],
|
||||
cutoff_len: int,
|
||||
reserved_label_len: int,
|
||||
) -> Sequence[Tuple[List[int], List[int]]]:
|
||||
encoded_pairs = []
|
||||
total_length = 0
|
||||
for i in range(0, len(encoded_messages), 2):
|
||||
if total_length >= cutoff_len:
|
||||
break
|
||||
|
||||
max_source_len, max_target_len = infer_max_len(
|
||||
source_len=len(encoded_messages[i]),
|
||||
target_len=len(encoded_messages[i + 1]),
|
||||
max_len=(cutoff_len - total_length),
|
||||
reserved_label_len=reserved_label_len,
|
||||
)
|
||||
encoded_messages[i] = encoded_messages[i][:max_source_len]
|
||||
encoded_messages[i + 1] = encoded_messages[i + 1][:max_target_len]
|
||||
total_length += len(encoded_messages[i]) + len(encoded_messages[i + 1])
|
||||
encoded_pairs.append((encoded_messages[i], encoded_messages[i + 1]))
|
||||
|
||||
return encoded_pairs
|
||||
|
||||
|
||||
@dataclass
|
||||
class Llama2Template(Template):
|
||||
def _encode(
|
||||
self,
|
||||
tokenizer: "PreTrainedTokenizer",
|
||||
messages: List[Dict[str, str]],
|
||||
system: str,
|
||||
tools: str,
|
||||
cutoff_len: int,
|
||||
reserved_label_len: int,
|
||||
) -> Sequence[Tuple[List[int], List[int]]]:
|
||||
r"""
|
||||
Encodes formatted inputs to pairs of token ids.
|
||||
Turn 0: system + query resp
|
||||
Turn t: sep + query resp
|
||||
"""
|
||||
system = system or self.default_system
|
||||
encoded_messages = []
|
||||
for i, message in enumerate(messages):
|
||||
elements = []
|
||||
system_text = ""
|
||||
if i == 0 and (system or tools or self.force_system):
|
||||
tool_text = self.format_tools.apply(content=tools)[0] if tools else ""
|
||||
system_text = self.format_system.apply(content=(system + tool_text))[0]
|
||||
elif i > 0 and i % 2 == 0:
|
||||
elements += self.format_separator.apply()
|
||||
|
||||
if message["role"] == Role.USER:
|
||||
elements += self.format_user.apply(content=system_text + message["content"])
|
||||
elif message["role"] == Role.ASSISTANT:
|
||||
elements += self.format_assistant.apply(content=message["content"])
|
||||
elif message["role"] == Role.OBSERVATION:
|
||||
elements += self.format_observation.apply(content=message["content"])
|
||||
elif message["role"] == Role.FUNCTION:
|
||||
elements += self.format_function.apply(content=message["content"])
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
encoded_messages.append(self._convert_elements_to_ids(tokenizer, elements))
|
||||
|
||||
return self._make_pairs(encoded_messages, cutoff_len, reserved_label_len)
|
||||
|
||||
|
||||
templates: Dict[str, Template] = {}
|
||||
|
||||
|
||||
def register_template(
|
||||
name: str,
|
||||
format_user: Optional["Formatter"] = None,
|
||||
format_assistant: Optional["Formatter"] = None,
|
||||
format_system: Optional["Formatter"] = None,
|
||||
format_function: Optional["Formatter"] = None,
|
||||
format_observation: Optional["Formatter"] = None,
|
||||
format_tools: Optional["Formatter"] = None,
|
||||
format_separator: Optional["Formatter"] = None,
|
||||
default_system: Optional[str] = "",
|
||||
stop_words: Optional[List[str]] = [],
|
||||
efficient_eos: Optional[bool] = False,
|
||||
replace_eos: Optional[bool] = False,
|
||||
force_system: Optional[bool] = False,
|
||||
) -> None:
|
||||
eos_slots = [] if efficient_eos else [{"eos_token"}]
|
||||
template_class = Llama2Template if name.startswith("llama2") else Template
|
||||
default_user_formatter = StringFormatter(slots=["{{content}}"])
|
||||
default_assistant_formatter = StringFormatter(slots=["{{content}}"] + eos_slots)
|
||||
default_function_formatter = FunctionFormatter(slots=["Action: {{name}}\nAction Input: {{arguments}}"] + eos_slots)
|
||||
default_tool_formatter = ToolFormatter(slots="default")
|
||||
default_separator_formatter = EmptyFormatter()
|
||||
templates[name] = template_class(
|
||||
format_user=format_user or default_user_formatter,
|
||||
format_assistant=format_assistant or default_assistant_formatter,
|
||||
format_system=format_system or default_user_formatter,
|
||||
format_function=format_function or default_function_formatter,
|
||||
format_observation=format_observation or format_user or default_user_formatter,
|
||||
format_tools=format_tools or default_tool_formatter,
|
||||
format_separator=format_separator or default_separator_formatter,
|
||||
default_system=default_system,
|
||||
stop_words=stop_words,
|
||||
efficient_eos=efficient_eos,
|
||||
replace_eos=replace_eos,
|
||||
force_system=force_system,
|
||||
)
|
||||
|
||||
|
||||
def get_template_and_fix_tokenizer(name: str, tokenizer: "PreTrainedTokenizer") -> Template:
|
||||
if tokenizer.eos_token_id is None:
|
||||
tokenizer.eos_token = "<|endoftext|>"
|
||||
logger.info("Add eos token: {}".format(tokenizer.eos_token))
|
||||
|
||||
if tokenizer.pad_token_id is None:
|
||||
tokenizer.pad_token = tokenizer.eos_token
|
||||
logger.info("Add pad token: {}".format(tokenizer.pad_token))
|
||||
|
||||
if name is None: # for pre-training
|
||||
return None
|
||||
|
||||
template = templates.get(name, None)
|
||||
assert template is not None, "Template {} does not exist.".format(name)
|
||||
|
||||
stop_words = template.stop_words
|
||||
if template.replace_eos:
|
||||
if not stop_words:
|
||||
raise ValueError("Stop words are required to replace the EOS token.")
|
||||
|
||||
tokenizer.eos_token = stop_words[0]
|
||||
stop_words = stop_words[1:]
|
||||
logger.info("Replace eos token: {}".format(tokenizer.eos_token))
|
||||
|
||||
if stop_words:
|
||||
tokenizer.add_special_tokens(
|
||||
dict(additional_special_tokens=stop_words), replace_additional_special_tokens=False
|
||||
)
|
||||
logger.info("Add {} to stop words.".format(",".join(stop_words)))
|
||||
|
||||
return template
|
||||
|
||||
|
||||
register_template(
|
||||
name="alpaca",
|
||||
format_user=StringFormatter(slots=["### Instruction:\n{{content}}\n\n### Response:\n"]),
|
||||
format_separator=EmptyFormatter(slots=["\n\n"]),
|
||||
default_system=(
|
||||
"Below is an instruction that describes a task. " "Write a response that appropriately completes the request."
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
register_template(
|
||||
name="aquila",
|
||||
format_user=StringFormatter(slots=["Human: {{content}}###Assistant:"]),
|
||||
format_separator=EmptyFormatter(slots=["###"]),
|
||||
default_system=(
|
||||
"A chat between a curious human and an artificial intelligence assistant. "
|
||||
"The assistant gives helpful, detailed, and polite answers to the human's questions."
|
||||
),
|
||||
stop_words=["</s>"],
|
||||
efficient_eos=True,
|
||||
)
|
||||
|
||||
|
||||
register_template(
|
||||
name="baichuan",
|
||||
format_user=StringFormatter(slots=[{"token": "<reserved_102>"}, "{{content}}", {"token": "<reserved_103>"}]),
|
||||
efficient_eos=True,
|
||||
)
|
||||
|
||||
|
||||
register_template(
|
||||
name="baichuan2",
|
||||
format_user=StringFormatter(slots=[{"token": "<reserved_106>"}, "{{content}}", {"token": "<reserved_107>"}]),
|
||||
efficient_eos=True,
|
||||
)
|
||||
|
||||
|
||||
register_template(
|
||||
name="belle",
|
||||
format_user=StringFormatter(slots=["Human: {{content}}\n\nBelle: "]),
|
||||
format_system=StringFormatter(slots=[{"bos_token"}, "{{content}}"]),
|
||||
format_separator=EmptyFormatter(slots=["\n\n"]),
|
||||
force_system=True,
|
||||
)
|
||||
|
||||
|
||||
register_template(
|
||||
name="bluelm",
|
||||
format_user=StringFormatter(slots=[{"token": "[|Human|]:"}, "{{content}}", {"token": "[|AI|]:"}]),
|
||||
)
|
||||
|
||||
|
||||
register_template(
|
||||
name="chatglm2",
|
||||
format_user=StringFormatter(slots=["[Round {{idx}}]\n\n问:{{content}}\n\n答:"]),
|
||||
format_system=StringFormatter(slots=[{"token": "[gMASK]"}, {"token": "sop"}, "{{content}}"]),
|
||||
format_separator=EmptyFormatter(slots=["\n\n"]),
|
||||
efficient_eos=True,
|
||||
force_system=True,
|
||||
)
|
||||
|
||||
|
||||
register_template(
|
||||
name="chatglm3",
|
||||
format_user=StringFormatter(slots=[{"token": "<|user|>"}, "\n", "{{content}}", {"token": "<|assistant|>"}]),
|
||||
format_assistant=StringFormatter(slots=["\n", "{{content}}"]),
|
||||
format_system=StringFormatter(
|
||||
slots=[{"token": "[gMASK]"}, {"token": "sop"}, {"token": "<|system|>"}, "\n", "{{content}}"]
|
||||
),
|
||||
format_function=FunctionFormatter(slots=["{{name}}\n{{arguments}}"]),
|
||||
format_observation=StringFormatter(slots=[{"token": "<|observation|>"}, "\n", "{{content}}"]),
|
||||
default_system=(
|
||||
"You are ChatGLM3, a large language model trained by Zhipu.AI. "
|
||||
"Follow the user's instructions carefully. Respond using markdown."
|
||||
),
|
||||
stop_words=["<|user|>", "<|observation|>"],
|
||||
efficient_eos=True,
|
||||
)
|
||||
|
||||
|
||||
register_template(
|
||||
name="codegeex2",
|
||||
format_system=StringFormatter(slots=[{"token": "[gMASK]"}, {"token": "sop"}, "{{content}}"]),
|
||||
force_system=True,
|
||||
)
|
||||
|
||||
|
||||
register_template(
|
||||
name="deepseek",
|
||||
format_user=StringFormatter(slots=["User: {{content}}\n\nAssistant:"]),
|
||||
format_system=StringFormatter(slots=[{"bos_token"}, "{{content}}"]),
|
||||
force_system=True,
|
||||
)
|
||||
|
||||
|
||||
register_template(
|
||||
name="deepseekcoder",
|
||||
format_user=StringFormatter(slots=["### Instruction:\n{{content}}\n### Response:\n"]),
|
||||
format_separator=EmptyFormatter(slots=["\n", {"token": "<|EOT|>"}, "\n"]),
|
||||
default_system=(
|
||||
"You are an AI programming assistant, utilizing the Deepseek Coder model, "
|
||||
"developed by Deepseek Company, and you only answer questions related to computer science. "
|
||||
"For politically sensitive questions, security and privacy issues, "
|
||||
"and other non-computer science questions, you will refuse to answer\n"
|
||||
),
|
||||
stop_words=["<|EOT|>"],
|
||||
efficient_eos=True,
|
||||
)
|
||||
|
||||
|
||||
register_template(
|
||||
name="default",
|
||||
format_user=StringFormatter(slots=["Human: {{content}}\nAssistant: "]),
|
||||
format_separator=EmptyFormatter(slots=["\n"]),
|
||||
)
|
||||
|
||||
|
||||
register_template(
|
||||
name="falcon",
|
||||
format_user=StringFormatter(slots=["User: {{content}}\nFalcon:"]),
|
||||
format_separator=EmptyFormatter(slots=["\n"]),
|
||||
efficient_eos=True,
|
||||
)
|
||||
|
||||
|
||||
register_template(
|
||||
name="intern",
|
||||
format_user=StringFormatter(slots=["<|User|>:{{content}}", {"token": "<eoh>"}, "\n<|Bot|>:"]),
|
||||
format_separator=EmptyFormatter(slots=[{"token": "<eoa>"}, "\n"]),
|
||||
stop_words=["<eoa>"],
|
||||
efficient_eos=True,
|
||||
)
|
||||
|
||||
|
||||
register_template(
|
||||
name="intern2",
|
||||
format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
|
||||
format_system=StringFormatter(slots=[{"bos_token"}, "<|im_start|>system\n{{content}}<|im_end|>\n"]),
|
||||
format_separator=EmptyFormatter(slots=["\n"]),
|
||||
default_system=(
|
||||
"You are an AI assistant whose name is InternLM (书生·浦语).\n"
|
||||
"- InternLM (书生·浦语) is a conversational language model that is developed "
|
||||
"by Shanghai AI Laboratory (上海人工智能实验室). It is designed to be helpful, honest, and harmless.\n"
|
||||
"- InternLM (书生·浦语) can understand and communicate fluently in the language chosen "
|
||||
"by the user such as English and 中文."
|
||||
),
|
||||
stop_words=["<|im_end|>"],
|
||||
efficient_eos=True, # internlm2 tokenizer cannot set eos_token_id
|
||||
)
|
||||
|
||||
|
||||
register_template(
|
||||
name="llama2",
|
||||
format_user=StringFormatter(slots=[{"bos_token"}, "[INST] {{content}} [/INST]"]),
|
||||
format_system=StringFormatter(slots=["<<SYS>>\n{{content}}\n<</SYS>>\n\n"]),
|
||||
default_system=(
|
||||
"You are a helpful, respectful and honest assistant. "
|
||||
"Always answer as helpfully as possible, while being safe. "
|
||||
"Your answers should not include any harmful, unethical, "
|
||||
"racist, sexist, toxic, dangerous, or illegal content. "
|
||||
"Please ensure that your responses are socially unbiased and positive in nature.\n\n"
|
||||
"If a question does not make any sense, or is not factually coherent, "
|
||||
"explain why instead of answering something not correct. "
|
||||
"If you don't know the answer to a question, please don't share false information."
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
register_template(
|
||||
name="llama2_zh",
|
||||
format_user=StringFormatter(slots=[{"bos_token"}, "[INST] {{content}} [/INST]"]),
|
||||
format_system=StringFormatter(slots=["<<SYS>>\n{{content}}\n<</SYS>>\n\n"]),
|
||||
default_system="You are a helpful assistant. 你是一个乐于助人的助手。",
|
||||
)
|
||||
|
||||
|
||||
register_template(
|
||||
name="mistral",
|
||||
format_user=StringFormatter(slots=["[INST] {{content}} [/INST]"]),
|
||||
format_system=StringFormatter(slots=[{"bos_token"}, "{{content}}"]),
|
||||
force_system=True,
|
||||
)
|
||||
|
||||
|
||||
register_template(
|
||||
name="openchat",
|
||||
format_user=StringFormatter(slots=["GPT4 Correct User: {{content}}", {"eos_token"}, "GPT4 Correct Assistant:"]),
|
||||
format_assistant=StringFormatter(slots=["{{content}}"]),
|
||||
format_system=StringFormatter(slots=[{"bos_token"}, "{{content}}"]),
|
||||
force_system=True,
|
||||
)
|
||||
|
||||
|
||||
register_template(
|
||||
name="qwen",
|
||||
format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
|
||||
format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
|
||||
format_separator=EmptyFormatter(slots=["\n"]),
|
||||
default_system="You are a helpful assistant.",
|
||||
stop_words=["<|im_end|>"],
|
||||
replace_eos=True,
|
||||
)
|
||||
|
||||
|
||||
register_template(
|
||||
name="solar",
|
||||
format_user=StringFormatter(slots=["### User:\n{{content}}\n\n### Assistant:\n"]),
|
||||
format_system=StringFormatter(slots=["### System:\n{{content}}\n\n"]),
|
||||
efficient_eos=True,
|
||||
)
|
||||
|
||||
|
||||
register_template(
|
||||
name="starchat",
|
||||
format_user=StringFormatter(
|
||||
slots=[{"token": "<|user|>"}, "\n{{content}}", {"token": "<|end|>"}, "\n", {"token": "<|assistant|>"}]
|
||||
),
|
||||
format_system=StringFormatter(slots=[{"token": "<|system|>"}, "\n{{content}}", {"token": "<|end|>"}, "\n"]),
|
||||
format_separator=EmptyFormatter(slots=["\n"]),
|
||||
stop_words=["<|end|>"],
|
||||
replace_eos=True,
|
||||
force_system=True,
|
||||
)
|
||||
|
||||
|
||||
register_template(name="vanilla")
|
||||
|
||||
|
||||
register_template(
|
||||
name="vicuna",
|
||||
format_user=StringFormatter(slots=["USER: {{content}} ASSISTANT:"]),
|
||||
default_system=(
|
||||
"A chat between a curious user and an artificial intelligence assistant. "
|
||||
"The assistant gives helpful, detailed, and polite answers to the user's questions."
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
register_template(
|
||||
name="xuanyuan",
|
||||
format_user=StringFormatter(slots=["Human: {{content}} Assistant:"]),
|
||||
default_system=(
|
||||
"以下是用户和人工智能助手之间的对话。用户以Human开头,人工智能助手以Assistant开头,"
|
||||
"会对人类提出的问题给出有帮助、高质量、详细和礼貌的回答,并且总是拒绝参与与不道德、"
|
||||
"不安全、有争议、政治敏感等相关的话题、问题和指示。\n"
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
register_template(name="xverse", format_user=StringFormatter(slots=["Human: {{content}}\n\nAssistant: "]))
|
||||
|
||||
|
||||
register_template(
|
||||
name="yayi",
|
||||
format_user=StringFormatter(slots=[{"token": "<|Human|>"}, ":\n{{content}}\n\n", {"token": "<|YaYi|>"}, ":"]),
|
||||
format_system=StringFormatter(slots=[{"token": "<|System|>"}, ":\n{{content}}\n\n"]),
|
||||
format_separator=EmptyFormatter(slots=["\n\n"]),
|
||||
default_system=(
|
||||
"You are a helpful, respectful and honest assistant named YaYi "
|
||||
"developed by Beijing Wenge Technology Co.,Ltd. "
|
||||
"Always answer as helpfully as possible, while being safe. "
|
||||
"Your answers should not include any harmful, unethical, "
|
||||
"racist, sexist, toxic, dangerous, or illegal content. "
|
||||
"Please ensure that your responses are socially unbiased and positive in nature.\n\n"
|
||||
"If a question does not make any sense, or is not factually coherent, "
|
||||
"explain why instead of answering something not correct. "
|
||||
"If you don't know the answer to a question, please don't share false information."
|
||||
),
|
||||
stop_words=["<|End|>"],
|
||||
)
|
||||
|
||||
|
||||
register_template(
|
||||
name="yi",
|
||||
format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
|
||||
format_separator=EmptyFormatter(slots=["\n"]),
|
||||
stop_words=["<|im_end|>"],
|
||||
replace_eos=True,
|
||||
)
|
||||
|
||||
|
||||
register_template(
|
||||
name="yuan",
|
||||
format_user=StringFormatter(slots=["{{content}}", {"token": "<sep>"}]),
|
||||
format_separator=EmptyFormatter(slots=["\n"]),
|
||||
stop_words=["<eod>"],
|
||||
replace_eos=True,
|
||||
)
|
||||
|
||||
|
||||
register_template(
|
||||
name="zephyr",
|
||||
format_user=StringFormatter(slots=["<|user|>\n{{content}}", {"eos_token"}, "<|assistant|>"]),
|
||||
format_system=StringFormatter(slots=["<|system|>\n{{content}}", {"eos_token"}]),
|
||||
default_system="You are a friendly chatbot who always responds in the style of a pirate",
|
||||
)
|
||||
|
||||
|
||||
register_template(
|
||||
name="ziya",
|
||||
format_user=StringFormatter(slots=[{"token": "<human>"}, ":{{content}}\n", {"token": "<bot>"}, ":"]),
|
||||
format_separator=EmptyFormatter(slots=["\n"]),
|
||||
)
|
||||
@@ -1,25 +1,26 @@
|
||||
import hashlib
|
||||
from typing import TYPE_CHECKING, Dict, List, Optional, Union
|
||||
from enum import Enum, unique
|
||||
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
|
||||
|
||||
from ..extras.logging import get_logger
|
||||
|
||||
from llmtuner.extras.logging import get_logger
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from datasets import Dataset, IterableDataset
|
||||
from transformers import TrainingArguments
|
||||
|
||||
from llmtuner.hparams import DataArguments
|
||||
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
EXT2TYPE = {
|
||||
"arrow": "arrow",
|
||||
"csv": "csv",
|
||||
"json": "json",
|
||||
"jsonl": "json",
|
||||
"parquet": "parquet",
|
||||
"txt": "text"
|
||||
}
|
||||
@unique
|
||||
class Role(str, Enum):
|
||||
USER = "user"
|
||||
ASSISTANT = "assistant"
|
||||
OBSERVATION = "observation"
|
||||
FUNCTION = "function"
|
||||
|
||||
|
||||
def checksum(data_files: List[str], file_sha1: Optional[str] = None) -> None:
|
||||
@@ -37,13 +38,18 @@ def checksum(data_files: List[str], file_sha1: Optional[str] = None) -> None:
|
||||
logger.warning("Checksum failed: mismatched SHA-1 hash value at {}.".format(data_files[0]))
|
||||
|
||||
|
||||
def infer_max_len(source_len: int, target_len: int, max_len: int, reserved_label_len: int) -> Tuple[int, int]:
|
||||
max_target_len = int(max_len * (target_len / (source_len + target_len)))
|
||||
max_target_len = max(max_target_len, reserved_label_len)
|
||||
max_source_len = max_len - max_target_len
|
||||
return max_source_len, max_target_len
|
||||
|
||||
|
||||
def split_dataset(
|
||||
dataset: Union["Dataset", "IterableDataset"],
|
||||
data_args: "DataArguments",
|
||||
training_args: "TrainingArguments"
|
||||
dataset: Union["Dataset", "IterableDataset"], data_args: "DataArguments", training_args: "TrainingArguments"
|
||||
) -> Dict[str, "Dataset"]:
|
||||
if training_args.do_train:
|
||||
if data_args.val_size > 1e-6: # Split the dataset
|
||||
if data_args.val_size > 1e-6: # Split the dataset
|
||||
if data_args.streaming:
|
||||
val_set = dataset.take(int(data_args.val_size))
|
||||
train_set = dataset.skip(int(data_args.val_size))
|
||||
@@ -57,5 +63,5 @@ def split_dataset(
|
||||
if data_args.streaming:
|
||||
dataset = dataset.shuffle(buffer_size=data_args.buffer_size, seed=training_args.seed)
|
||||
return {"train_dataset": dataset}
|
||||
else: # do_eval or do_predict
|
||||
else: # do_eval or do_predict
|
||||
return {"eval_dataset": dataset}
|
||||
@@ -1,3 +0,0 @@
|
||||
from llmtuner.dsets.loader import get_dataset
|
||||
from llmtuner.dsets.preprocess import preprocess_dataset
|
||||
from llmtuner.dsets.utils import split_dataset
|
||||
@@ -1,145 +0,0 @@
|
||||
import os
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Union
|
||||
|
||||
from datasets import concatenate_datasets, interleave_datasets, load_dataset
|
||||
|
||||
from llmtuner.dsets.utils import checksum, EXT2TYPE
|
||||
from llmtuner.extras.logging import get_logger
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from datasets import Dataset, IterableDataset
|
||||
from llmtuner.hparams import ModelArguments, DataArguments
|
||||
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
def get_dataset(
|
||||
model_args: "ModelArguments",
|
||||
data_args: "DataArguments"
|
||||
) -> Union["Dataset", "IterableDataset"]:
|
||||
max_samples = data_args.max_samples
|
||||
all_datasets: List[Union["Dataset", "IterableDataset"]] = [] # support multiple datasets
|
||||
|
||||
for dataset_attr in data_args.dataset_list:
|
||||
logger.info("Loading dataset {}...".format(dataset_attr))
|
||||
|
||||
if dataset_attr.load_from == "hf_hub":
|
||||
data_path = dataset_attr.dataset_name
|
||||
data_name = dataset_attr.subset
|
||||
data_files = None
|
||||
elif dataset_attr.load_from == "script":
|
||||
data_path = os.path.join(data_args.dataset_dir, dataset_attr.dataset_name)
|
||||
data_name = dataset_attr.subset
|
||||
data_files = None
|
||||
elif dataset_attr.load_from == "file":
|
||||
data_path, data_name = None, None
|
||||
data_files: List[str] = []
|
||||
if os.path.isdir(os.path.join(data_args.dataset_dir, dataset_attr.dataset_name)): # is directory
|
||||
for file_name in os.listdir(os.path.join(data_args.dataset_dir, dataset_attr.dataset_name)):
|
||||
data_files.append(os.path.join(data_args.dataset_dir, dataset_attr.dataset_name, file_name))
|
||||
if data_path is None:
|
||||
data_path = EXT2TYPE.get(file_name.split(".")[-1], None)
|
||||
else:
|
||||
assert data_path == EXT2TYPE.get(file_name.split(".")[-1], None), "file types are not identical."
|
||||
elif os.path.isfile(os.path.join(data_args.dataset_dir, dataset_attr.dataset_name)): # is file
|
||||
data_files.append(os.path.join(data_args.dataset_dir, dataset_attr.dataset_name))
|
||||
data_path = EXT2TYPE.get(dataset_attr.dataset_name.split(".")[-1], None)
|
||||
else:
|
||||
raise ValueError("File not found.")
|
||||
|
||||
assert data_path, "File extension must be txt, csv, json or jsonl."
|
||||
checksum(data_files, dataset_attr.dataset_sha1)
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
dataset = load_dataset(
|
||||
path=data_path,
|
||||
name=data_name,
|
||||
data_files=data_files,
|
||||
split=data_args.split,
|
||||
cache_dir=model_args.cache_dir,
|
||||
token=model_args.hf_hub_token,
|
||||
streaming=data_args.streaming
|
||||
)
|
||||
|
||||
if max_samples is not None: # truncate dataset
|
||||
dataset = dataset.select(range(min(len(dataset), max_samples)))
|
||||
|
||||
def convert_format(examples: Dict[str, List[Any]]) -> Dict[str, List[Any]]:
|
||||
# convert dataset from sharegpt format to alpaca format
|
||||
outputs = {"prompt": [], "query": [], "response": [], "history": []}
|
||||
for msg_list in examples[dataset_attr.messages]:
|
||||
msg_list = msg_list[:len(msg_list) // 2 * 2] # should be multiples of 2
|
||||
if len(msg_list) == 0:
|
||||
continue
|
||||
|
||||
msg_pairs = []
|
||||
user_role, assistant_role = None, None
|
||||
for idx in range(0, len(msg_list), 2):
|
||||
if user_role is None and assistant_role is None:
|
||||
user_role = msg_list[idx][dataset_attr.role]
|
||||
assistant_role = msg_list[idx + 1][dataset_attr.role]
|
||||
else:
|
||||
if (
|
||||
msg_list[idx][dataset_attr.role] != user_role
|
||||
or msg_list[idx+1][dataset_attr.role] != assistant_role
|
||||
):
|
||||
raise ValueError("Only accepts conversation in u/a/u/a/u/a order.")
|
||||
msg_pairs.append((msg_list[idx][dataset_attr.content], msg_list[idx + 1][dataset_attr.content]))
|
||||
|
||||
if len(msg_pairs) != 0:
|
||||
outputs["prompt"].append(msg_pairs[-1][0])
|
||||
outputs["query"].append("")
|
||||
outputs["response"].append(msg_pairs[-1][1])
|
||||
outputs["history"].append(msg_pairs[:-1])
|
||||
|
||||
return outputs
|
||||
|
||||
if dataset_attr.formatting == "sharegpt": # convert format
|
||||
column_names = list(next(iter(dataset)).keys())
|
||||
kwargs = {}
|
||||
if not data_args.streaming:
|
||||
kwargs = dict(
|
||||
num_proc=data_args.preprocessing_num_workers,
|
||||
load_from_cache_file=(not data_args.overwrite_cache),
|
||||
desc="Converting format of dataset"
|
||||
)
|
||||
|
||||
dataset = dataset.map(
|
||||
convert_format,
|
||||
batched=True,
|
||||
remove_columns=column_names,
|
||||
**kwargs
|
||||
)
|
||||
else:
|
||||
for column_name in ["prompt", "query", "response", "history"]: # align dataset
|
||||
if getattr(dataset_attr, column_name) and getattr(dataset_attr, column_name) != column_name:
|
||||
dataset = dataset.rename_column(getattr(dataset_attr, column_name), column_name)
|
||||
|
||||
if dataset_attr.system_prompt: # add system prompt
|
||||
system_prompt = dataset_attr.system_prompt
|
||||
if data_args.streaming:
|
||||
dataset = dataset.map(lambda _: {"system": system_prompt})
|
||||
else:
|
||||
dataset = dataset.add_column("system", [system_prompt] * len(dataset))
|
||||
|
||||
all_datasets.append(dataset)
|
||||
|
||||
if len(data_args.dataset_list) == 1:
|
||||
return all_datasets[0]
|
||||
elif data_args.mix_strategy == "concat":
|
||||
if data_args.streaming:
|
||||
logger.warning("The samples between different datasets will not be mixed in streaming mode.")
|
||||
return concatenate_datasets(all_datasets)
|
||||
elif data_args.mix_strategy.startswith("interleave"):
|
||||
if not data_args.streaming:
|
||||
logger.warning("We recommend using `mix_strategy=concat` in non-streaming mode.")
|
||||
return interleave_datasets(
|
||||
datasets=all_datasets,
|
||||
probabilities=data_args.interleave_probs,
|
||||
seed=data_args.seed,
|
||||
stopping_strategy="first_exhausted" if data_args.mix_strategy.endswith("under") else "all_exhausted"
|
||||
)
|
||||
else:
|
||||
raise ValueError("Unknown mixing strategy.")
|
||||
@@ -1,272 +0,0 @@
|
||||
import os
|
||||
import tiktoken
|
||||
from itertools import chain
|
||||
from typing import TYPE_CHECKING, Any, Dict, Generator, List, Literal, Union
|
||||
|
||||
from datasets import load_from_disk
|
||||
|
||||
from llmtuner.extras.constants import IGNORE_INDEX
|
||||
from llmtuner.extras.logging import get_logger
|
||||
from llmtuner.extras.template import get_template_and_fix_tokenizer
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from datasets import Dataset, IterableDataset
|
||||
from transformers import Seq2SeqTrainingArguments
|
||||
from transformers.tokenization_utils import PreTrainedTokenizer
|
||||
from llmtuner.hparams import DataArguments
|
||||
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
def preprocess_dataset(
|
||||
dataset: Union["Dataset", "IterableDataset"],
|
||||
tokenizer: "PreTrainedTokenizer",
|
||||
data_args: "DataArguments",
|
||||
training_args: "Seq2SeqTrainingArguments",
|
||||
stage: Literal["pt", "sft", "rm", "ppo"]
|
||||
) -> Union["Dataset", "IterableDataset"]:
|
||||
template = get_template_and_fix_tokenizer(data_args.template, tokenizer)
|
||||
|
||||
if data_args.train_on_prompt and template.efficient_eos:
|
||||
raise ValueError("Current template does not support `train_on_prompt`.")
|
||||
|
||||
def construct_example(examples: Dict[str, List[Any]]) -> Generator[Any, None, None]:
|
||||
for i in range(len(examples["prompt"])):
|
||||
query, response = examples["prompt"][i], examples["response"][i]
|
||||
query = query + "\n" + examples["query"][i] if "query" in examples and examples["query"][i] else query
|
||||
history = examples["history"][i] if "history" in examples else None
|
||||
system = examples["system"][i] if "system" in examples else None
|
||||
yield query, response, history, system
|
||||
|
||||
def preprocess_pretrain_dataset(examples: Dict[str, List[Any]]) -> Dict[str, List[List[int]]]:
|
||||
# build grouped texts with format `X1 X2 X3 ...`
|
||||
if isinstance(getattr(tokenizer, "tokenizer", None), tiktoken.Encoding): # for tiktoken tokenizer (Qwen)
|
||||
kwargs = dict(allowed_special="all")
|
||||
else:
|
||||
kwargs = dict(add_special_tokens=True)
|
||||
|
||||
if hasattr(tokenizer, "add_eos_token"): # for LLaMA tokenizer
|
||||
add_eos_token_flag = getattr(tokenizer, "add_eos_token")
|
||||
setattr(tokenizer, "add_eos_token", True)
|
||||
|
||||
tokenized_examples = tokenizer(examples["prompt"], **kwargs)
|
||||
concatenated_examples = {k: list(chain(*tokenized_examples[k])) for k in tokenized_examples.keys()}
|
||||
total_length = len(concatenated_examples[list(concatenated_examples.keys())[0]])
|
||||
block_size = data_args.cutoff_len
|
||||
# we drop the small remainder, and if the total_length < block_size, we exclude this batch
|
||||
total_length = (total_length // block_size) * block_size
|
||||
# split by chunks of cutoff_len
|
||||
result = {
|
||||
k: [t[i: i + block_size] for i in range(0, total_length, block_size)]
|
||||
for k, t in concatenated_examples.items()
|
||||
}
|
||||
# make sure the saved tokenizer is the same as the original one
|
||||
if hasattr(tokenizer, "add_eos_token"):
|
||||
setattr(tokenizer, "add_eos_token", add_eos_token_flag)
|
||||
return result
|
||||
|
||||
def preprocess_supervised_dataset(examples: Dict[str, List[Any]]) -> Dict[str, List[List[int]]]:
|
||||
# build inputs with format `<bos> X Y <eos>` and labels with format `<ignore> ... <ignore> Y <eos>`
|
||||
# for multiturn examples, we only mask the prompt part in each prompt-response pair.
|
||||
model_inputs = {"input_ids": [], "attention_mask": [], "labels": []}
|
||||
|
||||
for query, response, history, system in construct_example(examples):
|
||||
if not (isinstance(query, str) and isinstance(response, str) and query != "" and response != ""):
|
||||
continue
|
||||
|
||||
input_ids, labels = [], []
|
||||
for turn_idx, (source_ids, target_ids) in enumerate(template.encode_multiturn(
|
||||
tokenizer, query, response, history, system
|
||||
)):
|
||||
total_len = len(source_ids) + len(target_ids)
|
||||
max_source_len = int(data_args.cutoff_len * (len(source_ids) / total_len))
|
||||
max_target_len = int(data_args.cutoff_len * (len(target_ids) / total_len))
|
||||
|
||||
if len(source_ids) > max_source_len:
|
||||
source_ids = source_ids[:max_source_len]
|
||||
if len(target_ids) > max_target_len:
|
||||
target_ids = target_ids[:max_target_len]
|
||||
|
||||
if data_args.train_on_prompt:
|
||||
source_mask = source_ids
|
||||
elif turn_idx != 0 and template.efficient_eos:
|
||||
source_mask = [tokenizer.eos_token_id] + [IGNORE_INDEX] * (len(source_ids) - 1)
|
||||
else:
|
||||
source_mask = [IGNORE_INDEX] * len(source_ids)
|
||||
|
||||
input_ids += source_ids + target_ids
|
||||
labels += source_mask + target_ids
|
||||
|
||||
if template.efficient_eos:
|
||||
input_ids += [tokenizer.eos_token_id]
|
||||
labels += [tokenizer.eos_token_id]
|
||||
|
||||
if len(input_ids) > data_args.cutoff_len:
|
||||
input_ids = input_ids[:data_args.cutoff_len]
|
||||
labels = labels[:data_args.cutoff_len]
|
||||
|
||||
model_inputs["input_ids"].append(input_ids)
|
||||
model_inputs["attention_mask"].append([1] * len(input_ids))
|
||||
model_inputs["labels"].append(labels)
|
||||
|
||||
return model_inputs
|
||||
|
||||
def preprocess_packed_supervised_dataset(examples: Dict[str, List[Any]]) -> Dict[str, List[List[int]]]:
|
||||
# build inputs with format `<bos> X1 Y1 <eos> <bos> X2 Y2 <eos>`
|
||||
# and labels with format `<ignore> ... <ignore> Y1 <eos> <ignore> ... <ignore> Y2 <eos>`
|
||||
model_inputs = {"input_ids": [], "attention_mask": [], "labels": []}
|
||||
input_ids, labels = [], []
|
||||
for query, response, history, system in construct_example(examples):
|
||||
if not (isinstance(query, str) and isinstance(response, str) and query != "" and response != ""):
|
||||
continue
|
||||
|
||||
for turn_idx, (source_ids, target_ids) in enumerate(template.encode_multiturn(
|
||||
tokenizer, query, response, history, system
|
||||
)):
|
||||
if data_args.train_on_prompt:
|
||||
source_mask = source_ids
|
||||
elif turn_idx != 0 and template.efficient_eos:
|
||||
source_mask = [tokenizer.eos_token_id] + [IGNORE_INDEX] * (len(source_ids) - 1)
|
||||
else:
|
||||
source_mask = [IGNORE_INDEX] * len(source_ids)
|
||||
input_ids += source_ids + target_ids
|
||||
labels += source_mask + target_ids
|
||||
|
||||
if template.efficient_eos:
|
||||
input_ids += [tokenizer.eos_token_id]
|
||||
labels += [tokenizer.eos_token_id]
|
||||
|
||||
total_length = len(input_ids)
|
||||
block_size = data_args.cutoff_len
|
||||
# we drop the small remainder, and if the total_length < block_size, we exclude this batch
|
||||
total_length = (total_length // block_size) * block_size
|
||||
# split by chunks of cutoff_len
|
||||
for i in range(0, total_length, block_size):
|
||||
model_inputs["input_ids"].append(input_ids[i: i + block_size])
|
||||
model_inputs["attention_mask"].append([1] * block_size)
|
||||
model_inputs["labels"].append(labels[i: i + block_size])
|
||||
|
||||
return model_inputs
|
||||
|
||||
def preprocess_unsupervised_dataset(examples: Dict[str, List[Any]]) -> Dict[str, List[List[int]]]:
|
||||
# build inputs with format `<bos> X` and labels with format `Y <eos>`
|
||||
model_inputs = {"input_ids": [], "attention_mask": [], "labels": []}
|
||||
|
||||
for query, response, history, system in construct_example(examples):
|
||||
if not (isinstance(query, str) and query != ""):
|
||||
continue
|
||||
|
||||
input_ids, labels = template.encode_oneturn(tokenizer, query, response, history, system)
|
||||
|
||||
if template.efficient_eos:
|
||||
labels += [tokenizer.eos_token_id]
|
||||
|
||||
if len(input_ids) > data_args.cutoff_len:
|
||||
input_ids = input_ids[:data_args.cutoff_len]
|
||||
if len(labels) > data_args.cutoff_len:
|
||||
labels = labels[:data_args.cutoff_len]
|
||||
|
||||
model_inputs["input_ids"].append(input_ids)
|
||||
model_inputs["attention_mask"].append([1] * len(input_ids))
|
||||
model_inputs["labels"].append(labels)
|
||||
|
||||
return model_inputs
|
||||
|
||||
def preprocess_pairwise_dataset(examples: Dict[str, List[Any]]) -> Dict[str, List[List[int]]]:
|
||||
# build input pairs with format `<bos> X`, `Y1 <eos>` and `Y2 <eos>`
|
||||
model_inputs = {"prompt_ids": [], "chosen_ids": [], "rejected_ids": []}
|
||||
for query, response, history, system in construct_example(examples):
|
||||
if not (isinstance(query, str) and isinstance(response, list) and query != "" and len(response) > 1):
|
||||
continue
|
||||
|
||||
prompt_ids, chosen_ids = template.encode_oneturn(tokenizer, query, response[0], history, system)
|
||||
_, rejected_ids = template.encode_oneturn(tokenizer, query, response[1], history, system)
|
||||
|
||||
if template.efficient_eos:
|
||||
chosen_ids += [tokenizer.eos_token_id]
|
||||
rejected_ids += [tokenizer.eos_token_id]
|
||||
|
||||
total_len = len(prompt_ids) + max(len(chosen_ids), len(rejected_ids))
|
||||
max_source_len = int(data_args.cutoff_len * (len(prompt_ids) / total_len))
|
||||
max_target_len = int(data_args.cutoff_len * (max(len(chosen_ids), len(rejected_ids)) / total_len))
|
||||
|
||||
if len(prompt_ids) > max_source_len:
|
||||
prompt_ids = prompt_ids[:max_source_len]
|
||||
if len(chosen_ids) > max_target_len:
|
||||
chosen_ids = chosen_ids[:max_target_len]
|
||||
if len(rejected_ids) > max_target_len:
|
||||
rejected_ids = rejected_ids[:max_target_len]
|
||||
|
||||
model_inputs["prompt_ids"].append(prompt_ids)
|
||||
model_inputs["chosen_ids"].append(chosen_ids)
|
||||
model_inputs["rejected_ids"].append(rejected_ids)
|
||||
|
||||
return model_inputs
|
||||
|
||||
def print_supervised_dataset_example(example: Dict[str, List[int]]) -> None:
|
||||
print("input_ids:\n{}".format(example["input_ids"]))
|
||||
print("inputs:\n{}".format(tokenizer.decode(example["input_ids"], skip_special_tokens=False)))
|
||||
print("label_ids:\n{}".format(example["labels"]))
|
||||
print("labels:\n{}".format(
|
||||
tokenizer.decode(list(filter(lambda x: x != IGNORE_INDEX, example["labels"])), skip_special_tokens=False)
|
||||
))
|
||||
|
||||
def print_pairwise_dataset_example(example: Dict[str, List[int]]) -> None:
|
||||
print("prompt_ids:\n{}".format(example["prompt_ids"]))
|
||||
print("prompt:\n{}".format(tokenizer.decode(example["prompt_ids"], skip_special_tokens=False)))
|
||||
print("chosen_ids:\n{}".format(example["chosen_ids"]))
|
||||
print("chosen:\n{}".format(tokenizer.decode(example["chosen_ids"], skip_special_tokens=False)))
|
||||
print("rejected_ids:\n{}".format(example["rejected_ids"]))
|
||||
print("rejected:\n{}".format(tokenizer.decode(example["rejected_ids"], skip_special_tokens=False)))
|
||||
|
||||
def print_unsupervised_dataset_example(example: Dict[str, List[int]]) -> None:
|
||||
print("input_ids:\n{}".format(example["input_ids"]))
|
||||
print("inputs:\n{}".format(tokenizer.decode(example["input_ids"], skip_special_tokens=False)))
|
||||
|
||||
if stage == "pt":
|
||||
preprocess_func = preprocess_pretrain_dataset
|
||||
print_function = print_unsupervised_dataset_example
|
||||
elif stage == "sft" and not training_args.predict_with_generate:
|
||||
preprocess_func = preprocess_packed_supervised_dataset if data_args.sft_packing else preprocess_supervised_dataset
|
||||
print_function = print_supervised_dataset_example
|
||||
elif stage == "rm":
|
||||
preprocess_func = preprocess_pairwise_dataset
|
||||
print_function = print_pairwise_dataset_example
|
||||
else:
|
||||
preprocess_func = preprocess_unsupervised_dataset
|
||||
print_function = print_unsupervised_dataset_example
|
||||
|
||||
if data_args.cache_path is not None and os.path.exists(data_args.cache_path):
|
||||
logger.warning("Loading dataset from disk will ignore other data arguments.")
|
||||
return load_from_disk(data_args.cache_path)
|
||||
|
||||
with training_args.main_process_first(desc="dataset map pre-processing"):
|
||||
column_names = list(next(iter(dataset)).keys())
|
||||
kwargs = {}
|
||||
if not data_args.streaming:
|
||||
kwargs = dict(
|
||||
num_proc=data_args.preprocessing_num_workers,
|
||||
load_from_cache_file=(not data_args.overwrite_cache),
|
||||
desc="Running tokenizer on dataset"
|
||||
)
|
||||
|
||||
dataset = dataset.map(
|
||||
preprocess_func,
|
||||
batched=True,
|
||||
remove_columns=column_names,
|
||||
**kwargs
|
||||
)
|
||||
|
||||
if data_args.cache_path is not None and not os.path.exists(data_args.cache_path):
|
||||
if training_args.should_save:
|
||||
dataset.save_to_disk(data_args.cache_path)
|
||||
raise SystemExit("Dataset saved, rerun this script with the same `--cache_path`.")
|
||||
|
||||
if training_args.should_log:
|
||||
try:
|
||||
print_function(next(iter(dataset)))
|
||||
except StopIteration:
|
||||
raise RuntimeError("Empty dataset!")
|
||||
|
||||
return dataset
|
||||
@@ -1 +1,4 @@
|
||||
from llmtuner.eval.engine import Evaluator
|
||||
from .evaluator import Evaluator
|
||||
|
||||
|
||||
__all__ = ["Evaluator"]
|
||||
|
||||
@@ -1,3 +0,0 @@
|
||||
CHOICES = ["A", "B", "C", "D"]
|
||||
|
||||
SUBJECTS = ["Average", "STEM", "Social Sciences", "Humanities", "Other"]
|
||||
@@ -1,40 +1,34 @@
|
||||
# Inspired by: https://github.com/hendrycks/test/blob/master/evaluate_flan.py
|
||||
|
||||
import os
|
||||
import inspect
|
||||
import json
|
||||
import torch
|
||||
import tiktoken
|
||||
import numpy as np
|
||||
from tqdm import tqdm, trange
|
||||
from datasets import load_dataset
|
||||
import os
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from llmtuner.eval.constants import CHOICES, SUBJECTS
|
||||
from llmtuner.eval.parser import get_eval_args
|
||||
from llmtuner.eval.template import get_eval_template
|
||||
from llmtuner.extras.misc import dispatch_model
|
||||
from llmtuner.extras.template import get_template_and_fix_tokenizer
|
||||
from llmtuner.tuner.core import load_model_and_tokenizer
|
||||
import numpy as np
|
||||
import torch
|
||||
from datasets import load_dataset
|
||||
from tqdm import tqdm, trange
|
||||
from transformers.utils import cached_file
|
||||
|
||||
from ..data import get_template_and_fix_tokenizer
|
||||
from ..extras.constants import CHOICES, SUBJECTS
|
||||
from ..hparams import get_eval_args
|
||||
from ..model import dispatch_model, load_model_and_tokenizer
|
||||
from .template import get_eval_template
|
||||
|
||||
|
||||
class Evaluator:
|
||||
|
||||
def __init__(self, args: Optional[Dict[str, Any]] = None) -> None:
|
||||
model_args, self.data_args, self.eval_args, finetuning_args = get_eval_args(args)
|
||||
self.model, self.tokenizer = load_model_and_tokenizer(model_args, finetuning_args)
|
||||
self.tokenizer.padding_side = "right" # avoid overflow issue in batched inference for llama2
|
||||
self.model_args, self.data_args, self.eval_args, finetuning_args = get_eval_args(args)
|
||||
self.model, self.tokenizer = load_model_and_tokenizer(self.model_args, finetuning_args)
|
||||
self.tokenizer.padding_side = "right" # avoid overflow issue in batched inference for llama2
|
||||
self.model = dispatch_model(self.model)
|
||||
self.template = get_template_and_fix_tokenizer(self.data_args.template, self.tokenizer)
|
||||
self.eval_template = get_eval_template(self.eval_args.lang)
|
||||
self.choice_inputs = self._encode_choices()
|
||||
|
||||
def _encode_choices(self) -> List[int]:
|
||||
if isinstance(getattr(self.tokenizer, "tokenizer", None), tiktoken.Encoding): # for tiktoken tokenizer (Qwen)
|
||||
kwargs = dict(allowed_special="all")
|
||||
else:
|
||||
kwargs = dict(add_special_tokens=False)
|
||||
|
||||
return [self.tokenizer.encode(self.eval_template.prefix + ch, **kwargs)[-1] for ch in CHOICES]
|
||||
self.choice_inputs = [
|
||||
self.tokenizer.encode(self.eval_template.prefix + ch, add_special_tokens=False)[-1] for ch in CHOICES
|
||||
]
|
||||
|
||||
@torch.inference_mode()
|
||||
def batch_inference(self, batch_input: Dict[str, torch.Tensor]) -> List[str]:
|
||||
@@ -45,7 +39,13 @@ class Evaluator:
|
||||
return [chr(ord("A") + offset.item()) for offset in torch.argmax(choice_probs, dim=-1)]
|
||||
|
||||
def eval(self) -> None:
|
||||
mapping = os.path.join(self.eval_args.task_dir, self.eval_args.task, "mapping.json")
|
||||
mapping = cached_file(
|
||||
path_or_repo_id=os.path.join(self.eval_args.task_dir, self.eval_args.task),
|
||||
filename="mapping.json",
|
||||
cache_dir=self.model_args.cache_dir,
|
||||
token=self.model_args.hf_hub_token,
|
||||
)
|
||||
|
||||
with open(mapping, "r", encoding="utf-8") as f:
|
||||
categorys: Dict[str, Dict[str, str]] = json.load(f)
|
||||
|
||||
@@ -53,35 +53,45 @@ class Evaluator:
|
||||
pbar = tqdm(categorys.keys(), desc="Processing subjects", position=0)
|
||||
results = {}
|
||||
for subject in pbar:
|
||||
if "trust_remote_code" in inspect.signature(load_dataset).parameters: # for datasets==2.16.0
|
||||
kwargs = {"trust_remote_code": True}
|
||||
else:
|
||||
kwargs = {}
|
||||
|
||||
dataset = load_dataset(
|
||||
path=os.path.join(self.eval_args.task_dir, self.eval_args.task),
|
||||
name=subject,
|
||||
download_mode="force_redownload"
|
||||
cache_dir=self.model_args.cache_dir,
|
||||
download_mode=self.eval_args.download_mode,
|
||||
token=self.model_args.hf_hub_token,
|
||||
**kwargs,
|
||||
)
|
||||
pbar.set_postfix_str(categorys[subject]["name"])
|
||||
inputs, outputs, labels = [], [], []
|
||||
for i in trange(len(dataset[self.data_args.split]), desc="Formatting batches", position=1, leave=False):
|
||||
support_set = dataset["train"].shuffle().select(range(min(self.eval_args.n_shot, len(dataset["train"]))))
|
||||
query, resp, history = self.eval_template.format_example(
|
||||
support_set = (
|
||||
dataset["train"].shuffle().select(range(min(self.eval_args.n_shot, len(dataset["train"]))))
|
||||
)
|
||||
messages = self.eval_template.format_example(
|
||||
target_data=dataset[self.data_args.split][i],
|
||||
support_set=support_set,
|
||||
subject_name=categorys[subject]["name"],
|
||||
use_history=self.template.use_history
|
||||
)
|
||||
input_ids, _ = self.template.encode_oneturn(
|
||||
tokenizer=self.tokenizer, query=query, resp=resp, history=history
|
||||
)
|
||||
inputs.append({"input_ids": input_ids, "attention_mask": [1] * len(input_ids)})
|
||||
labels.append(resp)
|
||||
|
||||
for i in trange(0, len(inputs), self.eval_args.batch_size, desc="Predicting batches", position=1, leave=False):
|
||||
input_ids, _ = self.template.encode_oneturn(tokenizer=self.tokenizer, messages=messages)
|
||||
inputs.append({"input_ids": input_ids, "attention_mask": [1] * len(input_ids)})
|
||||
labels.append(messages[-1]["content"])
|
||||
|
||||
for i in trange(
|
||||
0, len(inputs), self.eval_args.batch_size, desc="Predicting batches", position=1, leave=False
|
||||
):
|
||||
batch_input = self.tokenizer.pad(
|
||||
inputs[i : i + self.eval_args.batch_size], return_attention_mask=True, return_tensors="pt"
|
||||
).to(self.model.device)
|
||||
preds = self.batch_inference(batch_input)
|
||||
outputs += preds
|
||||
|
||||
corrects = (np.array(outputs) == np.array(labels))
|
||||
corrects = np.array(outputs) == np.array(labels)
|
||||
category_name = categorys[subject]["category"]
|
||||
category_corrects[category_name] = np.concatenate([category_corrects[category_name], corrects], axis=0)
|
||||
category_corrects["Average"] = np.concatenate([category_corrects["Average"], corrects], axis=0)
|
||||
@@ -91,10 +101,13 @@ class Evaluator:
|
||||
self._save_results(category_corrects, results)
|
||||
|
||||
def _save_results(self, category_corrects: Dict[str, np.ndarray], results: Dict[str, Dict[int, str]]) -> None:
|
||||
score_info = "\n".join([
|
||||
"{:>15}: {:.2f}".format(category_name, 100 * np.mean(category_correct))
|
||||
for category_name, category_correct in category_corrects.items() if len(category_correct)
|
||||
])
|
||||
score_info = "\n".join(
|
||||
[
|
||||
"{:>15}: {:.2f}".format(category_name, 100 * np.mean(category_correct))
|
||||
for category_name, category_correct in category_corrects.items()
|
||||
if len(category_correct)
|
||||
]
|
||||
)
|
||||
print(score_info)
|
||||
if self.eval_args.save_dir is not None:
|
||||
os.makedirs(self.eval_args.save_dir, exist_ok=False)
|
||||
@@ -1,49 +0,0 @@
|
||||
import transformers
|
||||
from typing import Any, Dict, Optional, Tuple
|
||||
from transformers import HfArgumentParser
|
||||
|
||||
from llmtuner.extras.misc import parse_args
|
||||
from llmtuner.hparams import (
|
||||
ModelArguments,
|
||||
DataArguments,
|
||||
EvaluationArguments,
|
||||
FinetuningArguments
|
||||
)
|
||||
|
||||
|
||||
def parse_eval_args(
|
||||
args: Optional[Dict[str, Any]] = None
|
||||
) -> Tuple[
|
||||
ModelArguments,
|
||||
DataArguments,
|
||||
EvaluationArguments,
|
||||
FinetuningArguments
|
||||
]:
|
||||
parser = HfArgumentParser((
|
||||
ModelArguments,
|
||||
DataArguments,
|
||||
EvaluationArguments,
|
||||
FinetuningArguments
|
||||
))
|
||||
return parse_args(parser, args)
|
||||
|
||||
|
||||
def get_eval_args(
|
||||
args: Optional[Dict[str, Any]] = None
|
||||
) -> Tuple[
|
||||
ModelArguments,
|
||||
DataArguments,
|
||||
EvaluationArguments,
|
||||
FinetuningArguments
|
||||
]:
|
||||
model_args, data_args, eval_args, finetuning_args = parse_eval_args(args)
|
||||
|
||||
if data_args.template is None:
|
||||
raise ValueError("Please specify which `template` to use.")
|
||||
|
||||
if model_args.quantization_bit is not None and finetuning_args.finetuning_type != "lora":
|
||||
raise ValueError("Quantization is only compatible with the LoRA method.")
|
||||
|
||||
transformers.set_seed(eval_args.seed)
|
||||
|
||||
return model_args, data_args, eval_args, finetuning_args
|
||||
@@ -1,7 +1,9 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, Dict, List, Tuple
|
||||
|
||||
from llmtuner.eval.constants import CHOICES
|
||||
from ..data import Role
|
||||
from ..extras.constants import CHOICES
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from datasets import Dataset
|
||||
@@ -9,60 +11,39 @@ if TYPE_CHECKING:
|
||||
|
||||
@dataclass
|
||||
class EvalTemplate:
|
||||
|
||||
system: str
|
||||
choice: str
|
||||
answer: str
|
||||
prefix: str
|
||||
|
||||
def parse_example(
|
||||
self,
|
||||
example: Dict[str, str]
|
||||
) -> Tuple[str, str]:
|
||||
def parse_example(self, example: Dict[str, str]) -> Tuple[str, str]:
|
||||
candidates = [self.choice.format(choice=ch, content=example[ch]) for ch in CHOICES if ch in example]
|
||||
return "".join([example["question"]] + candidates + [self.answer]), example["answer"]
|
||||
|
||||
def format_example(
|
||||
self,
|
||||
target_data: Dict[str, str],
|
||||
support_set: "Dataset",
|
||||
subject_name: str,
|
||||
use_history: bool
|
||||
) -> Tuple[str, str, List[Tuple[str, str]]]:
|
||||
query, resp = self.parse_example(target_data)
|
||||
history = [self.parse_example(support_set[k]) for k in range(len(support_set))]
|
||||
self, target_data: Dict[str, str], support_set: "Dataset", subject_name: str
|
||||
) -> List[Dict[str, str]]:
|
||||
messages = []
|
||||
for k in range(len(support_set)):
|
||||
prompt, response = self.parse_example(support_set[k])
|
||||
messages.append({"role": Role.USER, "content": prompt})
|
||||
messages.append({"role": Role.ASSISTANT, "content": response})
|
||||
|
||||
if len(history):
|
||||
temp = history.pop(0)
|
||||
history.insert(0, (self.system.format(subject=subject_name) + temp[0], temp[1]))
|
||||
else:
|
||||
query = self.system.format(subject=subject_name) + query
|
||||
|
||||
if not use_history:
|
||||
query = "\n\n".join(["".join(item) for item in history] + [query])
|
||||
history = []
|
||||
return query.strip(), resp, history
|
||||
prompt, response = self.parse_example(target_data)
|
||||
messages.append({"role": Role.USER, "content": prompt})
|
||||
messages.append({"role": Role.ASSISTANT, "content": response})
|
||||
messages[0]["content"] = self.system.format(subject=subject_name) + messages[0]["content"]
|
||||
return messages
|
||||
|
||||
|
||||
eval_templates: Dict[str, EvalTemplate] = {}
|
||||
eval_templates: Dict[str, "EvalTemplate"] = {}
|
||||
|
||||
|
||||
def register_eval_template(
|
||||
name: str,
|
||||
system: str,
|
||||
choice: str,
|
||||
answer: str,
|
||||
prefix: str
|
||||
) -> None:
|
||||
eval_templates[name] = EvalTemplate(
|
||||
system=system,
|
||||
choice=choice,
|
||||
answer=answer,
|
||||
prefix=prefix
|
||||
)
|
||||
def register_eval_template(name: str, system: str, choice: str, answer: str, prefix: str) -> None:
|
||||
eval_templates[name] = EvalTemplate(system=system, choice=choice, answer=answer, prefix=prefix)
|
||||
|
||||
|
||||
def get_eval_template(name: str) -> EvalTemplate:
|
||||
def get_eval_template(name: str) -> "EvalTemplate":
|
||||
eval_template = eval_templates.get(name, None)
|
||||
assert eval_template is not None, "Template {} does not exist.".format(name)
|
||||
return eval_template
|
||||
@@ -73,7 +54,7 @@ register_eval_template(
|
||||
system="The following are multiple choice questions (with answers) about {subject}.\n\n",
|
||||
choice="\n{choice}. {content}",
|
||||
answer="\nAnswer: ",
|
||||
prefix=" "
|
||||
prefix=" ",
|
||||
)
|
||||
|
||||
|
||||
@@ -82,5 +63,5 @@ register_eval_template(
|
||||
system="以下是中国关于{subject}考试的单项选择题,请选出其中的正确答案。\n\n",
|
||||
choice="\n{choice}. {content}",
|
||||
answer="\n答案:",
|
||||
prefix="\n"
|
||||
prefix="\n",
|
||||
)
|
||||
|
||||
@@ -1,46 +1,38 @@
|
||||
import os
|
||||
import json
|
||||
import os
|
||||
import time
|
||||
from typing import TYPE_CHECKING
|
||||
from datetime import timedelta
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from transformers import TrainerCallback
|
||||
from transformers.trainer_utils import has_length, PREFIX_CHECKPOINT_DIR
|
||||
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, has_length
|
||||
|
||||
from .constants import LOG_FILE_NAME
|
||||
from .logging import get_logger
|
||||
from .misc import fix_valuehead_checkpoint
|
||||
|
||||
from llmtuner.extras.constants import LOG_FILE_NAME
|
||||
from llmtuner.extras.logging import get_logger
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from transformers import TrainingArguments, TrainerState, TrainerControl
|
||||
from transformers import TrainerControl, TrainerState, TrainingArguments
|
||||
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class SavePeftModelCallback(TrainerCallback):
|
||||
|
||||
class FixValueHeadModelCallback(TrainerCallback):
|
||||
def on_save(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
|
||||
r"""
|
||||
Event called after a checkpoint save.
|
||||
"""
|
||||
if args.should_save:
|
||||
output_dir = os.path.join(args.output_dir, "{}-{}".format(PREFIX_CHECKPOINT_DIR, state.global_step))
|
||||
model = kwargs.pop("model")
|
||||
if getattr(model, "is_peft_model", False):
|
||||
getattr(model, "pretrained_model").save_pretrained(output_dir)
|
||||
|
||||
def on_train_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
|
||||
r"""
|
||||
Event called at the end of training.
|
||||
"""
|
||||
if args.should_save:
|
||||
model = kwargs.pop("model")
|
||||
if getattr(model, "is_peft_model", False):
|
||||
getattr(model, "pretrained_model").save_pretrained(args.output_dir)
|
||||
fix_valuehead_checkpoint(
|
||||
model=kwargs.pop("model"),
|
||||
output_dir=os.path.join(args.output_dir, "{}-{}".format(PREFIX_CHECKPOINT_DIR, state.global_step)),
|
||||
safe_serialization=args.save_safetensors,
|
||||
)
|
||||
|
||||
|
||||
class LogCallback(TrainerCallback):
|
||||
|
||||
def __init__(self, runner=None):
|
||||
self.runner = runner
|
||||
self.in_training = False
|
||||
@@ -106,7 +98,9 @@ class LogCallback(TrainerCallback):
|
||||
self.cur_steps = 0
|
||||
self.max_steps = 0
|
||||
|
||||
def on_predict(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", *other, **kwargs):
|
||||
def on_predict(
|
||||
self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", *other, **kwargs
|
||||
):
|
||||
r"""
|
||||
Event called after a successful prediction.
|
||||
"""
|
||||
@@ -132,18 +126,22 @@ class LogCallback(TrainerCallback):
|
||||
epoch=state.log_history[-1].get("epoch", None),
|
||||
percentage=round(self.cur_steps / self.max_steps * 100, 2) if self.max_steps != 0 else 100,
|
||||
elapsed_time=self.elapsed_time,
|
||||
remaining_time=self.remaining_time
|
||||
remaining_time=self.remaining_time,
|
||||
)
|
||||
if self.runner is not None:
|
||||
logger.info("{{'loss': {:.4f}, 'learning_rate': {:2.4e}, 'epoch': {:.2f}}}".format(
|
||||
logs["loss"] or 0, logs["learning_rate"] or 0, logs["epoch"] or 0
|
||||
))
|
||||
logger.info(
|
||||
"{{'loss': {:.4f}, 'learning_rate': {:2.4e}, 'epoch': {:.2f}}}".format(
|
||||
logs["loss"] or 0, logs["learning_rate"] or 0, logs["epoch"] or 0
|
||||
)
|
||||
)
|
||||
|
||||
os.makedirs(args.output_dir, exist_ok=True)
|
||||
with open(os.path.join(args.output_dir, "trainer_log.jsonl"), "a", encoding="utf-8") as f:
|
||||
f.write(json.dumps(logs) + "\n")
|
||||
|
||||
def on_prediction_step(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
|
||||
def on_prediction_step(
|
||||
self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs
|
||||
):
|
||||
r"""
|
||||
Event called after a prediction step.
|
||||
"""
|
||||
|
||||
@@ -1,34 +1,52 @@
|
||||
from collections import defaultdict, OrderedDict
|
||||
from collections import OrderedDict, defaultdict
|
||||
from enum import Enum
|
||||
from typing import Dict, Optional
|
||||
|
||||
|
||||
CHOICES = ["A", "B", "C", "D"]
|
||||
|
||||
DATA_CONFIG = "dataset_info.json"
|
||||
|
||||
DEFAULT_MODULE = defaultdict(str)
|
||||
|
||||
DEFAULT_TEMPLATE = defaultdict(str)
|
||||
|
||||
FILEEXT2TYPE = {"arrow": "arrow", "csv": "csv", "json": "json", "jsonl": "json", "parquet": "parquet", "txt": "text"}
|
||||
|
||||
IGNORE_INDEX = -100
|
||||
|
||||
LAYERNORM_NAMES = {"norm", "ln"}
|
||||
|
||||
LOG_FILE_NAME = "trainer_log.jsonl"
|
||||
|
||||
METHODS = ["full", "freeze", "lora"]
|
||||
|
||||
PEFT_METHODS = ["lora"]
|
||||
|
||||
SUBJECTS = ["Average", "STEM", "Social Sciences", "Humanities", "Other"]
|
||||
|
||||
SUPPORTED_MODELS = OrderedDict()
|
||||
|
||||
TRAINING_STAGES = {
|
||||
"Supervised Fine-Tuning": "sft",
|
||||
"Reward Modeling": "rm",
|
||||
"PPO": "ppo",
|
||||
"DPO": "dpo",
|
||||
"Pre-Training": "pt"
|
||||
"Pre-Training": "pt",
|
||||
}
|
||||
|
||||
LAYERNORM_NAMES = {"norm", "ln"}
|
||||
V_HEAD_WEIGHTS_NAME = "value_head.bin"
|
||||
|
||||
SUPPORTED_MODELS = OrderedDict()
|
||||
V_HEAD_SAFE_WEIGHTS_NAME = "value_head.safetensors"
|
||||
|
||||
DEFAULT_MODULE = defaultdict(str)
|
||||
|
||||
DEFAULT_TEMPLATE = defaultdict(str)
|
||||
class DownloadSource(str, Enum):
|
||||
DEFAULT = "hf"
|
||||
MODELSCOPE = "ms"
|
||||
|
||||
|
||||
def register_model_group(
|
||||
models: Dict[str, str],
|
||||
module: Optional[str] = None,
|
||||
template: Optional[str] = None
|
||||
models: Dict[str, Dict[DownloadSource, str]], module: Optional[str] = None, template: Optional[str] = None
|
||||
) -> None:
|
||||
prefix = None
|
||||
for name, path in models.items():
|
||||
@@ -45,193 +63,608 @@ def register_model_group(
|
||||
|
||||
register_model_group(
|
||||
models={
|
||||
"Baichuan-7B-Base": "baichuan-inc/Baichuan-7B",
|
||||
"Baichuan-13B-Base": "baichuan-inc/Baichuan-13B-Base",
|
||||
"Baichuan-13B-Chat": "baichuan-inc/Baichuan-13B-Chat"
|
||||
"Baichuan-7B-Base": {
|
||||
DownloadSource.DEFAULT: "baichuan-inc/Baichuan-7B",
|
||||
DownloadSource.MODELSCOPE: "baichuan-inc/baichuan-7B",
|
||||
},
|
||||
"Baichuan-13B-Base": {
|
||||
DownloadSource.DEFAULT: "baichuan-inc/Baichuan-13B-Base",
|
||||
DownloadSource.MODELSCOPE: "baichuan-inc/Baichuan-13B-Base",
|
||||
},
|
||||
"Baichuan-13B-Chat": {
|
||||
DownloadSource.DEFAULT: "baichuan-inc/Baichuan-13B-Chat",
|
||||
DownloadSource.MODELSCOPE: "baichuan-inc/Baichuan-13B-Chat",
|
||||
},
|
||||
},
|
||||
module="W_pack",
|
||||
template="baichuan"
|
||||
template="baichuan",
|
||||
)
|
||||
|
||||
|
||||
register_model_group(
|
||||
models={
|
||||
"Baichuan2-7B-Base": "baichuan-inc/Baichuan2-7B-Base",
|
||||
"Baichuan2-13B-Base": "baichuan-inc/Baichuan2-13B-Base",
|
||||
"Baichuan2-7B-Chat": "baichuan-inc/Baichuan2-7B-Chat",
|
||||
"Baichuan2-13B-Chat": "baichuan-inc/Baichuan2-13B-Chat"
|
||||
"Baichuan2-7B-Base": {
|
||||
DownloadSource.DEFAULT: "baichuan-inc/Baichuan2-7B-Base",
|
||||
DownloadSource.MODELSCOPE: "baichuan-inc/Baichuan2-7B-Base",
|
||||
},
|
||||
"Baichuan2-13B-Base": {
|
||||
DownloadSource.DEFAULT: "baichuan-inc/Baichuan2-13B-Base",
|
||||
DownloadSource.MODELSCOPE: "baichuan-inc/Baichuan2-13B-Base",
|
||||
},
|
||||
"Baichuan2-7B-Chat": {
|
||||
DownloadSource.DEFAULT: "baichuan-inc/Baichuan2-7B-Chat",
|
||||
DownloadSource.MODELSCOPE: "baichuan-inc/Baichuan2-7B-Chat",
|
||||
},
|
||||
"Baichuan2-13B-Chat": {
|
||||
DownloadSource.DEFAULT: "baichuan-inc/Baichuan2-13B-Chat",
|
||||
DownloadSource.MODELSCOPE: "baichuan-inc/Baichuan2-13B-Chat",
|
||||
},
|
||||
},
|
||||
module="W_pack",
|
||||
template="baichuan2"
|
||||
template="baichuan2",
|
||||
)
|
||||
|
||||
|
||||
register_model_group(
|
||||
models={
|
||||
"BLOOM-560M": "bigscience/bloom-560m",
|
||||
"BLOOM-3B": "bigscience/bloom-3b",
|
||||
"BLOOM-7B1": "bigscience/bloom-7b1"
|
||||
},
|
||||
module="query_key_value"
|
||||
)
|
||||
|
||||
|
||||
register_model_group(
|
||||
models={
|
||||
"BLOOMZ-560M": "bigscience/bloomz-560m",
|
||||
"BLOOMZ-3B": "bigscience/bloomz-3b",
|
||||
"BLOOMZ-7B1-mt": "bigscience/bloomz-7b1-mt"
|
||||
},
|
||||
module="query_key_value"
|
||||
)
|
||||
|
||||
|
||||
register_model_group(
|
||||
models={
|
||||
"BlueLM-7B-Base": "vivo-ai/BlueLM-7B-Base",
|
||||
"BlueLM-7B-Chat": "vivo-ai/BlueLM-7B-Chat"
|
||||
},
|
||||
template="bluelm"
|
||||
)
|
||||
|
||||
|
||||
register_model_group(
|
||||
models={
|
||||
"ChatGLM2-6B-Chat": "THUDM/chatglm2-6b"
|
||||
"BLOOM-560M": {
|
||||
DownloadSource.DEFAULT: "bigscience/bloom-560m",
|
||||
DownloadSource.MODELSCOPE: "AI-ModelScope/bloom-560m",
|
||||
},
|
||||
"BLOOM-3B": {
|
||||
DownloadSource.DEFAULT: "bigscience/bloom-3b",
|
||||
DownloadSource.MODELSCOPE: "AI-ModelScope/bloom-3b",
|
||||
},
|
||||
"BLOOM-7B1": {
|
||||
DownloadSource.DEFAULT: "bigscience/bloom-7b1",
|
||||
DownloadSource.MODELSCOPE: "AI-ModelScope/bloom-7b1",
|
||||
},
|
||||
},
|
||||
module="query_key_value",
|
||||
template="chatglm2"
|
||||
)
|
||||
|
||||
|
||||
register_model_group(
|
||||
models={
|
||||
"ChatGLM3-6B-Base": "THUDM/chatglm3-6b-base",
|
||||
"ChatGLM3-6B-Chat": "THUDM/chatglm3-6b"
|
||||
"BLOOMZ-560M": {
|
||||
DownloadSource.DEFAULT: "bigscience/bloomz-560m",
|
||||
DownloadSource.MODELSCOPE: "AI-ModelScope/bloomz-560m",
|
||||
},
|
||||
"BLOOMZ-3B": {
|
||||
DownloadSource.DEFAULT: "bigscience/bloomz-3b",
|
||||
DownloadSource.MODELSCOPE: "AI-ModelScope/bloomz-3b",
|
||||
},
|
||||
"BLOOMZ-7B1-mt": {
|
||||
DownloadSource.DEFAULT: "bigscience/bloomz-7b1-mt",
|
||||
DownloadSource.MODELSCOPE: "AI-ModelScope/bloomz-7b1-mt",
|
||||
},
|
||||
},
|
||||
module="query_key_value",
|
||||
template="chatglm3"
|
||||
)
|
||||
|
||||
|
||||
register_model_group(
|
||||
models={
|
||||
"ChineseLLaMA2-7B": "ziqingyang/chinese-llama-2-7b",
|
||||
"ChineseLLaMA2-13B": "ziqingyang/chinese-llama-2-13b",
|
||||
"ChineseLLaMA2-7B-Chat": "ziqingyang/chinese-alpaca-2-7b",
|
||||
"ChineseLLaMA2-13B-Chat": "ziqingyang/chinese-alpaca-2-13b"
|
||||
"BlueLM-7B-Base": {
|
||||
DownloadSource.DEFAULT: "vivo-ai/BlueLM-7B-Base",
|
||||
DownloadSource.MODELSCOPE: "vivo-ai/BlueLM-7B-Base",
|
||||
},
|
||||
"BlueLM-7B-Chat": {
|
||||
DownloadSource.DEFAULT: "vivo-ai/BlueLM-7B-Chat",
|
||||
DownloadSource.MODELSCOPE: "vivo-ai/BlueLM-7B-Chat",
|
||||
},
|
||||
},
|
||||
template="llama2_zh"
|
||||
template="bluelm",
|
||||
)
|
||||
|
||||
|
||||
register_model_group(
|
||||
models={
|
||||
"Falcon-7B": "tiiuae/falcon-7b",
|
||||
"Falcon-40B": "tiiuae/falcon-40b",
|
||||
"Falcon-180B": "tiiuae/falcon-180B",
|
||||
"Falcon-7B-Chat": "tiiuae/falcon-7b-instruct",
|
||||
"Falcon-40B-Chat": "tiiuae/falcon-40b-instruct",
|
||||
"Falcon-180B-Chat": "tiiuae/falcon-180B-chat"
|
||||
"ChatGLM2-6B-Chat": {
|
||||
DownloadSource.DEFAULT: "THUDM/chatglm2-6b",
|
||||
DownloadSource.MODELSCOPE: "ZhipuAI/chatglm2-6b",
|
||||
}
|
||||
},
|
||||
module="query_key_value",
|
||||
template="falcon"
|
||||
template="chatglm2",
|
||||
)
|
||||
|
||||
|
||||
register_model_group(
|
||||
models={
|
||||
"InternLM-7B": "internlm/internlm-7b",
|
||||
"InternLM-20B": "internlm/internlm-20b",
|
||||
"InternLM-7B-Chat": "internlm/internlm-chat-7b",
|
||||
"InternLM-20B-Chat": "internlm/internlm-chat-20b"
|
||||
"ChatGLM3-6B-Base": {
|
||||
DownloadSource.DEFAULT: "THUDM/chatglm3-6b-base",
|
||||
DownloadSource.MODELSCOPE: "ZhipuAI/chatglm3-6b-base",
|
||||
},
|
||||
"ChatGLM3-6B-Chat": {
|
||||
DownloadSource.DEFAULT: "THUDM/chatglm3-6b",
|
||||
DownloadSource.MODELSCOPE: "ZhipuAI/chatglm3-6b",
|
||||
},
|
||||
},
|
||||
template="intern"
|
||||
module="query_key_value",
|
||||
template="chatglm3",
|
||||
)
|
||||
|
||||
|
||||
register_model_group(
|
||||
models={
|
||||
"LingoWhale-8B": "deeplang-ai/LingoWhale-8B"
|
||||
"ChineseLLaMA2-1.3B": {
|
||||
DownloadSource.DEFAULT: "hfl/chinese-llama-2-1.3b",
|
||||
DownloadSource.MODELSCOPE: "AI-ModelScope/chinese-llama-2-1.3b",
|
||||
},
|
||||
"ChineseLLaMA2-7B": {
|
||||
DownloadSource.DEFAULT: "hfl/chinese-llama-2-7b",
|
||||
DownloadSource.MODELSCOPE: "AI-ModelScope/chinese-llama-2-7b",
|
||||
},
|
||||
"ChineseLLaMA2-13B": {
|
||||
DownloadSource.DEFAULT: "hfl/chinese-llama-2-13b",
|
||||
DownloadSource.MODELSCOPE: "AI-ModelScope/chinese-llama-2-13b",
|
||||
},
|
||||
"ChineseLLaMA2-1.3B-Chat": {
|
||||
DownloadSource.DEFAULT: "hfl/chinese-alpaca-2-1.3b",
|
||||
DownloadSource.MODELSCOPE: "AI-ModelScope/chinese-alpaca-2-1.3b",
|
||||
},
|
||||
"ChineseLLaMA2-7B-Chat": {
|
||||
DownloadSource.DEFAULT: "hfl/chinese-alpaca-2-7b",
|
||||
DownloadSource.MODELSCOPE: "AI-ModelScope/chinese-alpaca-2-7b",
|
||||
},
|
||||
"ChineseLLaMA2-13B-Chat": {
|
||||
DownloadSource.DEFAULT: "hfl/chinese-alpaca-2-13b",
|
||||
DownloadSource.MODELSCOPE: "AI-ModelScope/chinese-alpaca-2-13b",
|
||||
},
|
||||
},
|
||||
module="qkv_proj"
|
||||
template="llama2_zh",
|
||||
)
|
||||
|
||||
|
||||
register_model_group(
|
||||
models={
|
||||
"LLaMA-7B": "huggyllama/llama-7b",
|
||||
"LLaMA-13B": "huggyllama/llama-13b",
|
||||
"LLaMA-30B": "huggyllama/llama-30b",
|
||||
"LLaMA-65B": "huggyllama/llama-65b"
|
||||
"DeepSeekLLM-7B-Base": {
|
||||
DownloadSource.DEFAULT: "deepseek-ai/deepseek-llm-7b-base",
|
||||
DownloadSource.MODELSCOPE: "deepseek-ai/deepseek-llm-7b-base",
|
||||
},
|
||||
"DeepSeekLLM-67B-Base": {
|
||||
DownloadSource.DEFAULT: "deepseek-ai/deepseek-llm-67b-base",
|
||||
DownloadSource.MODELSCOPE: "deepseek-ai/deepseek-llm-67b-base",
|
||||
},
|
||||
"DeepSeekLLM-7B-Chat": {
|
||||
DownloadSource.DEFAULT: "deepseek-ai/deepseek-llm-7b-chat",
|
||||
DownloadSource.MODELSCOPE: "deepseek-ai/deepseek-llm-7b-chat",
|
||||
},
|
||||
"DeepSeekLLM-67B-Chat": {
|
||||
DownloadSource.DEFAULT: "deepseek-ai/deepseek-llm-67b-chat",
|
||||
DownloadSource.MODELSCOPE: "deepseek-ai/deepseek-llm-67b-chat",
|
||||
},
|
||||
},
|
||||
template="deepseek",
|
||||
)
|
||||
|
||||
|
||||
register_model_group(
|
||||
models={
|
||||
"DeepSeekCoder-6.7B-Base": {
|
||||
DownloadSource.DEFAULT: "deepseek-ai/deepseek-coder-6.7b-base",
|
||||
DownloadSource.MODELSCOPE: "deepseek-ai/deepseek-coder-6.7b-base",
|
||||
},
|
||||
"DeepSeekCoder-33B-Base": {
|
||||
DownloadSource.DEFAULT: "deepseek-ai/deepseek-coder-33b-base",
|
||||
DownloadSource.MODELSCOPE: "deepseek-ai/deepseek-coder-33b-base",
|
||||
},
|
||||
"DeepSeekCoder-6.7B-Chat": {
|
||||
DownloadSource.DEFAULT: "deepseek-ai/deepseek-coder-6.7b-instruct",
|
||||
DownloadSource.MODELSCOPE: "deepseek-ai/deepseek-coder-6.7b-instruct",
|
||||
},
|
||||
"DeepSeekCoder-33B-Chat": {
|
||||
DownloadSource.DEFAULT: "deepseek-ai/deepseek-coder-33b-instruct",
|
||||
DownloadSource.MODELSCOPE: "deepseek-ai/deepseek-coder-33b-instruct",
|
||||
},
|
||||
},
|
||||
template="deepseekcoder",
|
||||
)
|
||||
|
||||
|
||||
register_model_group(
|
||||
models={
|
||||
"DeepSeekMoE-16B-Base": {
|
||||
DownloadSource.DEFAULT: "deepseek-ai/deepseek-moe-16b-base",
|
||||
DownloadSource.MODELSCOPE: "deepseek-ai/deepseek-moe-16b-base",
|
||||
},
|
||||
"DeepSeekMoE-16B-Chat": {
|
||||
DownloadSource.DEFAULT: "deepseek-ai/deepseek-moe-16b-chat",
|
||||
DownloadSource.MODELSCOPE: "deepseek-ai/deepseek-moe-16b-chat",
|
||||
},
|
||||
},
|
||||
template="deepseek",
|
||||
)
|
||||
|
||||
|
||||
register_model_group(
|
||||
models={
|
||||
"Falcon-7B": {
|
||||
DownloadSource.DEFAULT: "tiiuae/falcon-7b",
|
||||
DownloadSource.MODELSCOPE: "AI-ModelScope/falcon-7b",
|
||||
},
|
||||
"Falcon-40B": {
|
||||
DownloadSource.DEFAULT: "tiiuae/falcon-40b",
|
||||
DownloadSource.MODELSCOPE: "AI-ModelScope/falcon-40b",
|
||||
},
|
||||
"Falcon-180B": {
|
||||
DownloadSource.DEFAULT: "tiiuae/falcon-180b",
|
||||
DownloadSource.MODELSCOPE: "modelscope/falcon-180B",
|
||||
},
|
||||
"Falcon-7B-Chat": {
|
||||
DownloadSource.DEFAULT: "tiiuae/falcon-7b-instruct",
|
||||
DownloadSource.MODELSCOPE: "AI-ModelScope/falcon-7b-instruct",
|
||||
},
|
||||
"Falcon-40B-Chat": {
|
||||
DownloadSource.DEFAULT: "tiiuae/falcon-40b-instruct",
|
||||
DownloadSource.MODELSCOPE: "AI-ModelScope/falcon-40b-instruct",
|
||||
},
|
||||
"Falcon-180B-Chat": {
|
||||
DownloadSource.DEFAULT: "tiiuae/falcon-180b-chat",
|
||||
DownloadSource.MODELSCOPE: "modelscope/falcon-180B-chat",
|
||||
},
|
||||
},
|
||||
module="query_key_value",
|
||||
template="falcon",
|
||||
)
|
||||
|
||||
|
||||
register_model_group(
|
||||
models={
|
||||
"InternLM-7B": {
|
||||
DownloadSource.DEFAULT: "internlm/internlm-7b",
|
||||
DownloadSource.MODELSCOPE: "Shanghai_AI_Laboratory/internlm-7b",
|
||||
},
|
||||
"InternLM-20B": {
|
||||
DownloadSource.DEFAULT: "internlm/internlm-20b",
|
||||
DownloadSource.MODELSCOPE: "Shanghai_AI_Laboratory/internlm-20b",
|
||||
},
|
||||
"InternLM-7B-Chat": {
|
||||
DownloadSource.DEFAULT: "internlm/internlm-chat-7b",
|
||||
DownloadSource.MODELSCOPE: "Shanghai_AI_Laboratory/internlm-chat-7b",
|
||||
},
|
||||
"InternLM-20B-Chat": {
|
||||
DownloadSource.DEFAULT: "internlm/internlm-chat-20b",
|
||||
DownloadSource.MODELSCOPE: "Shanghai_AI_Laboratory/internlm-chat-20b",
|
||||
},
|
||||
},
|
||||
template="intern",
|
||||
)
|
||||
|
||||
|
||||
register_model_group(
|
||||
models={
|
||||
"InternLM2-7B": {
|
||||
DownloadSource.DEFAULT: "internlm/internlm2-7b",
|
||||
DownloadSource.MODELSCOPE: "Shanghai_AI_Laboratory/internlm2-7b",
|
||||
},
|
||||
"InternLM2-20B": {
|
||||
DownloadSource.DEFAULT: "internlm/internlm2-20b",
|
||||
DownloadSource.MODELSCOPE: "Shanghai_AI_Laboratory/internlm2-20b",
|
||||
},
|
||||
"InternLM2-7B-Chat": {
|
||||
DownloadSource.DEFAULT: "internlm/internlm2-chat-7b",
|
||||
DownloadSource.MODELSCOPE: "Shanghai_AI_Laboratory/internlm2-chat-7b",
|
||||
},
|
||||
"InternLM2-20B-Chat": {
|
||||
DownloadSource.DEFAULT: "internlm/internlm2-chat-20b",
|
||||
DownloadSource.MODELSCOPE: "Shanghai_AI_Laboratory/internlm2-chat-20b",
|
||||
},
|
||||
},
|
||||
module="wqkv",
|
||||
template="intern2",
|
||||
)
|
||||
|
||||
|
||||
register_model_group(
|
||||
models={
|
||||
"LingoWhale-8B": {
|
||||
DownloadSource.DEFAULT: "deeplang-ai/LingoWhale-8B",
|
||||
DownloadSource.MODELSCOPE: "DeepLang/LingoWhale-8B",
|
||||
}
|
||||
},
|
||||
module="qkv_proj",
|
||||
)
|
||||
|
||||
|
||||
register_model_group(
|
||||
models={
|
||||
"LLaMA-7B": {DownloadSource.DEFAULT: "huggyllama/llama-7b", DownloadSource.MODELSCOPE: "skyline2006/llama-7b"},
|
||||
"LLaMA-13B": {
|
||||
DownloadSource.DEFAULT: "huggyllama/llama-13b",
|
||||
DownloadSource.MODELSCOPE: "skyline2006/llama-13b",
|
||||
},
|
||||
"LLaMA-30B": {
|
||||
DownloadSource.DEFAULT: "huggyllama/llama-30b",
|
||||
DownloadSource.MODELSCOPE: "skyline2006/llama-30b",
|
||||
},
|
||||
"LLaMA-65B": {
|
||||
DownloadSource.DEFAULT: "huggyllama/llama-65b",
|
||||
DownloadSource.MODELSCOPE: "skyline2006/llama-65b",
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
register_model_group(
|
||||
models={
|
||||
"LLaMA2-7B": "meta-llama/Llama-2-7b-hf",
|
||||
"LLaMA2-13B": "meta-llama/Llama-2-13b-hf",
|
||||
"LLaMA2-70B": "meta-llama/Llama-2-70b-hf",
|
||||
"LLaMA2-7B-Chat": "meta-llama/Llama-2-7b-chat-hf",
|
||||
"LLaMA2-13B-Chat": "meta-llama/Llama-2-13b-chat-hf",
|
||||
"LLaMA2-70B-Chat": "meta-llama/Llama-2-70b-chat-hf"
|
||||
"LLaMA2-7B": {
|
||||
DownloadSource.DEFAULT: "meta-llama/Llama-2-7b-hf",
|
||||
DownloadSource.MODELSCOPE: "modelscope/Llama-2-7b-ms",
|
||||
},
|
||||
"LLaMA2-13B": {
|
||||
DownloadSource.DEFAULT: "meta-llama/Llama-2-13b-hf",
|
||||
DownloadSource.MODELSCOPE: "modelscope/Llama-2-13b-ms",
|
||||
},
|
||||
"LLaMA2-70B": {
|
||||
DownloadSource.DEFAULT: "meta-llama/Llama-2-70b-hf",
|
||||
DownloadSource.MODELSCOPE: "modelscope/Llama-2-70b-ms",
|
||||
},
|
||||
"LLaMA2-7B-Chat": {
|
||||
DownloadSource.DEFAULT: "meta-llama/Llama-2-7b-chat-hf",
|
||||
DownloadSource.MODELSCOPE: "modelscope/Llama-2-7b-chat-ms",
|
||||
},
|
||||
"LLaMA2-13B-Chat": {
|
||||
DownloadSource.DEFAULT: "meta-llama/Llama-2-13b-chat-hf",
|
||||
DownloadSource.MODELSCOPE: "modelscope/Llama-2-13b-chat-ms",
|
||||
},
|
||||
"LLaMA2-70B-Chat": {
|
||||
DownloadSource.DEFAULT: "meta-llama/Llama-2-70b-chat-hf",
|
||||
DownloadSource.MODELSCOPE: "modelscope/Llama-2-70b-chat-ms",
|
||||
},
|
||||
},
|
||||
template="llama2"
|
||||
template="llama2",
|
||||
)
|
||||
|
||||
|
||||
register_model_group(
|
||||
models={
|
||||
"Mistral-7B": "mistralai/Mistral-7B-v0.1",
|
||||
"Mistral-7B-Chat": "mistralai/Mistral-7B-Instruct-v0.1"
|
||||
"Mistral-7B": {
|
||||
DownloadSource.DEFAULT: "mistralai/Mistral-7B-v0.1",
|
||||
DownloadSource.MODELSCOPE: "AI-ModelScope/Mistral-7B-v0.1",
|
||||
},
|
||||
"Mistral-7B-Chat": {
|
||||
DownloadSource.DEFAULT: "mistralai/Mistral-7B-Instruct-v0.1",
|
||||
DownloadSource.MODELSCOPE: "AI-ModelScope/Mistral-7B-Instruct-v0.1",
|
||||
},
|
||||
"Mistral-7B-v0.2-Chat": {
|
||||
DownloadSource.DEFAULT: "mistralai/Mistral-7B-Instruct-v0.2",
|
||||
DownloadSource.MODELSCOPE: "AI-ModelScope/Mistral-7B-Instruct-v0.2",
|
||||
},
|
||||
},
|
||||
template="mistral"
|
||||
template="mistral",
|
||||
)
|
||||
|
||||
|
||||
register_model_group(
|
||||
models={
|
||||
"Phi1.5-1.3B": "microsoft/phi-1_5"
|
||||
"Mixtral-8x7B": {
|
||||
DownloadSource.DEFAULT: "mistralai/Mixtral-8x7B-v0.1",
|
||||
DownloadSource.MODELSCOPE: "AI-ModelScope/Mixtral-8x7B-v0.1",
|
||||
},
|
||||
"Mixtral-8x7B-Chat": {
|
||||
DownloadSource.DEFAULT: "mistralai/Mixtral-8x7B-Instruct-v0.1",
|
||||
DownloadSource.MODELSCOPE: "AI-ModelScope/Mixtral-8x7B-Instruct-v0.1",
|
||||
},
|
||||
},
|
||||
module="Wqkv"
|
||||
template="mistral",
|
||||
)
|
||||
|
||||
|
||||
register_model_group(
|
||||
models={
|
||||
"Qwen-7B": "Qwen/Qwen-7B",
|
||||
"Qwen-14B": "Qwen/Qwen-14B",
|
||||
"Qwen-7B-Chat": "Qwen/Qwen-7B-Chat",
|
||||
"Qwen-14B-Chat": "Qwen/Qwen-14B-Chat"
|
||||
"OpenChat3.5-7B-Chat": {
|
||||
DownloadSource.DEFAULT: "openchat/openchat_3.5",
|
||||
DownloadSource.MODELSCOPE: "myxiongmodel/openchat_3.5",
|
||||
}
|
||||
},
|
||||
template="openchat",
|
||||
)
|
||||
|
||||
|
||||
register_model_group(
|
||||
models={
|
||||
"Phi-1.5-1.3B": {DownloadSource.DEFAULT: "microsoft/phi-1_5", DownloadSource.MODELSCOPE: "allspace/PHI_1-5"},
|
||||
"Phi-2-2.7B": {DownloadSource.DEFAULT: "microsoft/phi-2", DownloadSource.MODELSCOPE: "AI-ModelScope/phi-2"},
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
register_model_group(
|
||||
models={
|
||||
"Qwen-1.8B": {DownloadSource.DEFAULT: "Qwen/Qwen-1_8B", DownloadSource.MODELSCOPE: "qwen/Qwen-1_8B"},
|
||||
"Qwen-7B": {DownloadSource.DEFAULT: "Qwen/Qwen-7B", DownloadSource.MODELSCOPE: "qwen/Qwen-7B"},
|
||||
"Qwen-14B": {DownloadSource.DEFAULT: "Qwen/Qwen-14B", DownloadSource.MODELSCOPE: "qwen/Qwen-14B"},
|
||||
"Qwen-72B": {DownloadSource.DEFAULT: "Qwen/Qwen-72B", DownloadSource.MODELSCOPE: "qwen/Qwen-72B"},
|
||||
"Qwen-1.8B-Chat": {
|
||||
DownloadSource.DEFAULT: "Qwen/Qwen-1_8B-Chat",
|
||||
DownloadSource.MODELSCOPE: "qwen/Qwen-1_8B-Chat",
|
||||
},
|
||||
"Qwen-7B-Chat": {DownloadSource.DEFAULT: "Qwen/Qwen-7B-Chat", DownloadSource.MODELSCOPE: "qwen/Qwen-7B-Chat"},
|
||||
"Qwen-14B-Chat": {
|
||||
DownloadSource.DEFAULT: "Qwen/Qwen-14B-Chat",
|
||||
DownloadSource.MODELSCOPE: "qwen/Qwen-14B-Chat",
|
||||
},
|
||||
"Qwen-72B-Chat": {
|
||||
DownloadSource.DEFAULT: "Qwen/Qwen-72B-Chat",
|
||||
DownloadSource.MODELSCOPE: "qwen/Qwen-72B-Chat",
|
||||
},
|
||||
"Qwen-1.8B-int8-Chat": {
|
||||
DownloadSource.DEFAULT: "Qwen/Qwen-1_8B-Chat-Int8",
|
||||
DownloadSource.MODELSCOPE: "qwen/Qwen-1_8B-Chat-Int8",
|
||||
},
|
||||
"Qwen-1.8B-int4-Chat": {
|
||||
DownloadSource.DEFAULT: "Qwen/Qwen-1_8B-Chat-Int4",
|
||||
DownloadSource.MODELSCOPE: "qwen/Qwen-1_8B-Chat-Int4",
|
||||
},
|
||||
"Qwen-7B-int8-Chat": {
|
||||
DownloadSource.DEFAULT: "Qwen/Qwen-7B-Chat-Int8",
|
||||
DownloadSource.MODELSCOPE: "qwen/Qwen-7B-Chat-Int8",
|
||||
},
|
||||
"Qwen-7B-int4-Chat": {
|
||||
DownloadSource.DEFAULT: "Qwen/Qwen-7B-Chat-Int4",
|
||||
DownloadSource.MODELSCOPE: "qwen/Qwen-7B-Chat-Int4",
|
||||
},
|
||||
"Qwen-14B-int8-Chat": {
|
||||
DownloadSource.DEFAULT: "Qwen/Qwen-14B-Chat-Int8",
|
||||
DownloadSource.MODELSCOPE: "qwen/Qwen-14B-Chat-Int8",
|
||||
},
|
||||
"Qwen-14B-int4-Chat": {
|
||||
DownloadSource.DEFAULT: "Qwen/Qwen-14B-Chat-Int4",
|
||||
DownloadSource.MODELSCOPE: "qwen/Qwen-14B-Chat-Int4",
|
||||
},
|
||||
"Qwen-72B-int8-Chat": {
|
||||
DownloadSource.DEFAULT: "Qwen/Qwen-72B-Chat-Int8",
|
||||
DownloadSource.MODELSCOPE: "qwen/Qwen-72B-Chat-Int8",
|
||||
},
|
||||
"Qwen-72B-int4-Chat": {
|
||||
DownloadSource.DEFAULT: "Qwen/Qwen-72B-Chat-Int4",
|
||||
DownloadSource.MODELSCOPE: "qwen/Qwen-72B-Chat-Int4",
|
||||
},
|
||||
},
|
||||
module="c_attn",
|
||||
template="qwen"
|
||||
template="qwen",
|
||||
)
|
||||
|
||||
|
||||
register_model_group(
|
||||
models={
|
||||
"Skywork-13B-Base": "Skywork/Skywork-13B-base"
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
register_model_group(
|
||||
models={
|
||||
"XVERSE-7B": "xverse/XVERSE-7B",
|
||||
"XVERSE-13B": "xverse/XVERSE-13B",
|
||||
"XVERSE-65B": "xverse/XVERSE-65B",
|
||||
"XVERSE-7B-Chat": "xverse/XVERSE-7B-Chat",
|
||||
"XVERSE-13B-Chat": "xverse/XVERSE-13B-Chat"
|
||||
"SOLAR-10.7B": {DownloadSource.DEFAULT: "upstage/SOLAR-10.7B-v1.0"},
|
||||
"SOLAR-10.7B-Chat": {
|
||||
DownloadSource.DEFAULT: "upstage/SOLAR-10.7B-Instruct-v1.0",
|
||||
DownloadSource.MODELSCOPE: "AI-ModelScope/SOLAR-10.7B-Instruct-v1.0",
|
||||
},
|
||||
},
|
||||
template="xverse"
|
||||
template="solar",
|
||||
)
|
||||
|
||||
|
||||
register_model_group(
|
||||
models={
|
||||
"Yi-6B": "01-ai/Yi-6B",
|
||||
"Yi-34B": "01-ai/Yi-34B"
|
||||
"Skywork-13B-Base": {
|
||||
DownloadSource.DEFAULT: "Skywork/Skywork-13B-base",
|
||||
DownloadSource.MODELSCOPE: "skywork/Skywork-13B-base",
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
register_model_group(
|
||||
models={
|
||||
"Vicuna1.5-7B-Chat": {
|
||||
DownloadSource.DEFAULT: "lmsys/vicuna-7b-v1.5",
|
||||
DownloadSource.MODELSCOPE: "Xorbits/vicuna-7b-v1.5",
|
||||
},
|
||||
"Vicuna1.5-13B-Chat": {
|
||||
DownloadSource.DEFAULT: "lmsys/vicuna-13b-v1.5",
|
||||
DownloadSource.MODELSCOPE: "Xorbits/vicuna-13b-v1.5",
|
||||
},
|
||||
},
|
||||
template="vicuna",
|
||||
)
|
||||
|
||||
|
||||
register_model_group(
|
||||
models={
|
||||
"XuanYuan-70B": {DownloadSource.DEFAULT: "Duxiaoman-DI/XuanYuan-70B"},
|
||||
"XuanYuan-70B-Chat": {DownloadSource.DEFAULT: "Duxiaoman-DI/XuanYuan-70B-Chat"},
|
||||
"XuanYuan-70B-int8-Chat": {DownloadSource.DEFAULT: "Duxiaoman-DI/XuanYuan-70B-Chat-8bit"},
|
||||
"XuanYuan-70B-int4-Chat": {DownloadSource.DEFAULT: "Duxiaoman-DI/XuanYuan-70B-Chat-4bit"},
|
||||
},
|
||||
template="xuanyuan",
|
||||
)
|
||||
|
||||
|
||||
register_model_group(
|
||||
models={
|
||||
"XVERSE-7B": {DownloadSource.DEFAULT: "xverse/XVERSE-7B", DownloadSource.MODELSCOPE: "xverse/XVERSE-7B"},
|
||||
"XVERSE-13B": {DownloadSource.DEFAULT: "xverse/XVERSE-13B", DownloadSource.MODELSCOPE: "xverse/XVERSE-13B"},
|
||||
"XVERSE-65B": {DownloadSource.DEFAULT: "xverse/XVERSE-65B", DownloadSource.MODELSCOPE: "xverse/XVERSE-65B"},
|
||||
"XVERSE-65B-2": {
|
||||
DownloadSource.DEFAULT: "xverse/XVERSE-65B-2",
|
||||
DownloadSource.MODELSCOPE: "xverse/XVERSE-65B-2",
|
||||
},
|
||||
"XVERSE-7B-Chat": {
|
||||
DownloadSource.DEFAULT: "xverse/XVERSE-7B-Chat",
|
||||
DownloadSource.MODELSCOPE: "xverse/XVERSE-7B-Chat",
|
||||
},
|
||||
"XVERSE-13B-Chat": {
|
||||
DownloadSource.DEFAULT: "xverse/XVERSE-13B-Chat",
|
||||
DownloadSource.MODELSCOPE: "xverse/XVERSE-13B-Chat",
|
||||
},
|
||||
"XVERSE-65B-Chat": {
|
||||
DownloadSource.DEFAULT: "xverse/XVERSE-65B-Chat",
|
||||
DownloadSource.MODELSCOPE: "xverse/XVERSE-65B-Chat",
|
||||
},
|
||||
},
|
||||
template="xverse",
|
||||
)
|
||||
|
||||
|
||||
register_model_group(
|
||||
models={
|
||||
"Yayi-7B": {
|
||||
DownloadSource.DEFAULT: "wenge-research/yayi-7b-llama2",
|
||||
DownloadSource.MODELSCOPE: "AI-ModelScope/yayi-7b-llama2",
|
||||
},
|
||||
"Yayi-13B": {
|
||||
DownloadSource.DEFAULT: "wenge-research/yayi-13b-llama2",
|
||||
DownloadSource.MODELSCOPE: "AI-ModelScope/yayi-13b-llama2",
|
||||
},
|
||||
},
|
||||
template="yayi",
|
||||
)
|
||||
|
||||
|
||||
register_model_group(
|
||||
models={
|
||||
"Yi-6B": {DownloadSource.DEFAULT: "01-ai/Yi-6B", DownloadSource.MODELSCOPE: "01ai/Yi-6B"},
|
||||
"Yi-34B": {DownloadSource.DEFAULT: "01-ai/Yi-34B", DownloadSource.MODELSCOPE: "01ai/Yi-34B"},
|
||||
"Yi-6B-Chat": {DownloadSource.DEFAULT: "01-ai/Yi-6B-Chat", DownloadSource.MODELSCOPE: "01ai/Yi-6B-Chat"},
|
||||
"Yi-34B-Chat": {DownloadSource.DEFAULT: "01-ai/Yi-34B-Chat", DownloadSource.MODELSCOPE: "01ai/Yi-34B-Chat"},
|
||||
"Yi-6B-int8-Chat": {
|
||||
DownloadSource.DEFAULT: "01-ai/Yi-6B-Chat-8bits",
|
||||
DownloadSource.MODELSCOPE: "01ai/Yi-6B-Chat-8bits",
|
||||
},
|
||||
"Yi-34B-int8-Chat": {
|
||||
DownloadSource.DEFAULT: "01-ai/Yi-34B-Chat-8bits",
|
||||
DownloadSource.MODELSCOPE: "01ai/Yi-34B-Chat-8bits",
|
||||
},
|
||||
},
|
||||
template="yi",
|
||||
)
|
||||
|
||||
|
||||
register_model_group(
|
||||
models={
|
||||
"Yuan2-2B-Chat": {
|
||||
DownloadSource.DEFAULT: "IEITYuan/Yuan2-2B-hf",
|
||||
DownloadSource.MODELSCOPE: "YuanLLM/Yuan2.0-2B-hf",
|
||||
},
|
||||
"Yuan2-51B-Chat": {
|
||||
DownloadSource.DEFAULT: "IEITYuan/Yuan2-51B-hf",
|
||||
DownloadSource.MODELSCOPE: "YuanLLM/Yuan2.0-51B-hf",
|
||||
},
|
||||
"Yuan2-102B-Chat": {
|
||||
DownloadSource.DEFAULT: "IEITYuan/Yuan2-102B-hf",
|
||||
DownloadSource.MODELSCOPE: "YuanLLM/Yuan2.0-102B-hf",
|
||||
},
|
||||
},
|
||||
template="yuan",
|
||||
)
|
||||
|
||||
|
||||
register_model_group(
|
||||
models={
|
||||
"Zephyr-7B-Alpha-Chat": {
|
||||
DownloadSource.DEFAULT: "HuggingFaceH4/zephyr-7b-alpha",
|
||||
DownloadSource.MODELSCOPE: "AI-ModelScope/zephyr-7b-alpha",
|
||||
},
|
||||
"Zephyr-7B-Beta-Chat": {
|
||||
DownloadSource.DEFAULT: "HuggingFaceH4/zephyr-7b-beta",
|
||||
DownloadSource.MODELSCOPE: "modelscope/zephyr-7b-beta",
|
||||
},
|
||||
},
|
||||
template="zephyr",
|
||||
)
|
||||
|
||||
@@ -1,8 +1,11 @@
|
||||
import sys
|
||||
import logging
|
||||
import sys
|
||||
|
||||
|
||||
class LoggerHandler(logging.Handler):
|
||||
r"""
|
||||
Logger handler used in Web UI.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
@@ -19,19 +22,12 @@ class LoggerHandler(logging.Handler):
|
||||
self.log += "\n\n"
|
||||
|
||||
|
||||
def reset_logging():
|
||||
r"""
|
||||
Removes basic config of root logger
|
||||
"""
|
||||
root = logging.getLogger()
|
||||
list(map(root.removeHandler, root.handlers))
|
||||
list(map(root.removeFilter, root.filters))
|
||||
|
||||
|
||||
def get_logger(name: str) -> logging.Logger:
|
||||
r"""
|
||||
Gets a standard logger with a stream hander to stdout.
|
||||
"""
|
||||
formatter = logging.Formatter(
|
||||
fmt="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
|
||||
datefmt="%m/%d/%Y %H:%M:%S"
|
||||
fmt="%(asctime)s - %(levelname)s - %(name)s - %(message)s", datefmt="%m/%d/%Y %H:%M:%S"
|
||||
)
|
||||
handler = logging.StreamHandler(sys.stdout)
|
||||
handler.setFormatter(formatter)
|
||||
@@ -41,3 +37,12 @@ def get_logger(name: str) -> logging.Logger:
|
||||
logger.addHandler(handler)
|
||||
|
||||
return logger
|
||||
|
||||
|
||||
def reset_logging() -> None:
|
||||
r"""
|
||||
Removes basic config of root logger. (unused in script)
|
||||
"""
|
||||
root = logging.getLogger()
|
||||
list(map(root.removeHandler, root.handlers))
|
||||
list(map(root.removeFilter, root.filters))
|
||||
|
||||
@@ -1,32 +1,44 @@
|
||||
import gc
|
||||
import os
|
||||
import sys
|
||||
import torch
|
||||
from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple
|
||||
from transformers import InfNanRemoveLogitsProcessor, LogitsProcessorList
|
||||
from typing import TYPE_CHECKING, Dict, Tuple
|
||||
|
||||
import torch
|
||||
from peft import PeftModel
|
||||
from transformers import InfNanRemoveLogitsProcessor, LogitsProcessorList, PreTrainedModel
|
||||
from transformers.utils import (
|
||||
SAFE_WEIGHTS_NAME,
|
||||
WEIGHTS_NAME,
|
||||
is_torch_bf16_gpu_available,
|
||||
is_torch_cuda_available,
|
||||
is_torch_npu_available,
|
||||
is_torch_xpu_available,
|
||||
)
|
||||
|
||||
from .constants import V_HEAD_SAFE_WEIGHTS_NAME, V_HEAD_WEIGHTS_NAME
|
||||
from .logging import get_logger
|
||||
|
||||
|
||||
_is_fp16_available = is_torch_npu_available() or is_torch_cuda_available()
|
||||
try:
|
||||
from transformers.utils import (
|
||||
is_torch_bf16_cpu_available,
|
||||
is_torch_bf16_gpu_available,
|
||||
is_torch_cuda_available,
|
||||
is_torch_npu_available
|
||||
)
|
||||
_is_fp16_available = is_torch_npu_available() or is_torch_cuda_available()
|
||||
_is_bf16_available = is_torch_bf16_gpu_available() or is_torch_bf16_cpu_available
|
||||
except ImportError:
|
||||
_is_fp16_available = torch.cuda.is_available()
|
||||
_is_bf16_available = torch.cuda.is_bf16_supported()
|
||||
_is_bf16_available = is_torch_bf16_gpu_available()
|
||||
except Exception:
|
||||
_is_bf16_available = False
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from transformers import HfArgumentParser
|
||||
from transformers.modeling_utils import PreTrainedModel
|
||||
from trl import AutoModelForCausalLMWithValueHead
|
||||
|
||||
from llmtuner.hparams import ModelArguments
|
||||
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class AverageMeter:
|
||||
r"""
|
||||
Computes and stores the average and current value.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.reset()
|
||||
|
||||
@@ -65,6 +77,83 @@ def count_parameters(model: torch.nn.Module) -> Tuple[int, int]:
|
||||
return trainable_params, all_param
|
||||
|
||||
|
||||
def fix_valuehead_checkpoint(
|
||||
model: "AutoModelForCausalLMWithValueHead", output_dir: str, safe_serialization: bool
|
||||
) -> None:
|
||||
r"""
|
||||
The model is already unwrapped.
|
||||
|
||||
There are three cases:
|
||||
1. full tuning without ds_zero3: state_dict = {"model.layers.*": ..., "v_head.summary.*": ...}
|
||||
2. lora tuning without ds_zero3: state_dict = {"v_head.summary.*": ...}
|
||||
3. under deepspeed zero3: state_dict = {"pretrained_model.model.layers.*": ..., "v_head.summary.*": ...}
|
||||
|
||||
We assume `stage3_gather_16bit_weights_on_model_save=true`.
|
||||
"""
|
||||
if not isinstance(model.pretrained_model, (PreTrainedModel, PeftModel)):
|
||||
return
|
||||
|
||||
if safe_serialization:
|
||||
from safetensors import safe_open
|
||||
from safetensors.torch import save_file
|
||||
|
||||
path_to_checkpoint = os.path.join(output_dir, SAFE_WEIGHTS_NAME)
|
||||
with safe_open(path_to_checkpoint, framework="pt", device="cpu") as f:
|
||||
state_dict: Dict[str, torch.Tensor] = {key: f.get_tensor(key) for key in f.keys()}
|
||||
else:
|
||||
path_to_checkpoint = os.path.join(output_dir, WEIGHTS_NAME)
|
||||
state_dict: Dict[str, torch.Tensor] = torch.load(path_to_checkpoint, map_location="cpu")
|
||||
|
||||
decoder_state_dict = {}
|
||||
v_head_state_dict = {}
|
||||
for name, param in state_dict.items():
|
||||
if name.startswith("v_head."):
|
||||
v_head_state_dict[name] = param
|
||||
else:
|
||||
decoder_state_dict[name.replace("pretrained_model.", "")] = param
|
||||
|
||||
os.remove(path_to_checkpoint)
|
||||
model.pretrained_model.save_pretrained(
|
||||
output_dir, state_dict=decoder_state_dict or None, safe_serialization=safe_serialization
|
||||
)
|
||||
|
||||
if safe_serialization:
|
||||
save_file(v_head_state_dict, os.path.join(output_dir, V_HEAD_SAFE_WEIGHTS_NAME), metadata={"format": "pt"})
|
||||
else:
|
||||
torch.save(v_head_state_dict, os.path.join(output_dir, V_HEAD_WEIGHTS_NAME))
|
||||
|
||||
logger.info("Value head model saved at: {}".format(output_dir))
|
||||
|
||||
|
||||
def get_current_device() -> torch.device:
|
||||
r"""
|
||||
Gets the current available device.
|
||||
"""
|
||||
if is_torch_xpu_available():
|
||||
device = "xpu:{}".format(os.environ.get("LOCAL_RANK", "0"))
|
||||
elif is_torch_npu_available():
|
||||
device = "npu:{}".format(os.environ.get("LOCAL_RANK", "0"))
|
||||
elif is_torch_cuda_available():
|
||||
device = "cuda:{}".format(os.environ.get("LOCAL_RANK", "0"))
|
||||
else:
|
||||
device = "cpu"
|
||||
|
||||
return torch.device(device)
|
||||
|
||||
|
||||
def get_device_count() -> int:
|
||||
return torch.cuda.device_count()
|
||||
|
||||
|
||||
def get_logits_processor() -> "LogitsProcessorList":
|
||||
r"""
|
||||
Gets logits processor that removes NaN and Inf logits.
|
||||
"""
|
||||
logits_processor = LogitsProcessorList()
|
||||
logits_processor.append(InfNanRemoveLogitsProcessor())
|
||||
return logits_processor
|
||||
|
||||
|
||||
def infer_optim_dtype(model_dtype: torch.dtype) -> torch.dtype:
|
||||
r"""
|
||||
Infers the optimal dtype according to the model_dtype and device compatibility.
|
||||
@@ -77,15 +166,6 @@ def infer_optim_dtype(model_dtype: torch.dtype) -> torch.dtype:
|
||||
return torch.float32
|
||||
|
||||
|
||||
def get_logits_processor() -> "LogitsProcessorList":
|
||||
r"""
|
||||
Gets logits processor that removes NaN and Inf logits.
|
||||
"""
|
||||
logits_processor = LogitsProcessorList()
|
||||
logits_processor.append(InfNanRemoveLogitsProcessor())
|
||||
return logits_processor
|
||||
|
||||
|
||||
def torch_gc() -> None:
|
||||
r"""
|
||||
Collects GPU memory.
|
||||
@@ -96,37 +176,20 @@ def torch_gc() -> None:
|
||||
torch.cuda.ipc_collect()
|
||||
|
||||
|
||||
def parse_args(parser: "HfArgumentParser", args: Optional[Dict[str, Any]] = None) -> Tuple[Any]:
|
||||
if args is not None:
|
||||
return parser.parse_dict(args)
|
||||
elif len(sys.argv) == 2 and sys.argv[1].endswith(".yaml"):
|
||||
return parser.parse_yaml_file(os.path.abspath(sys.argv[1]))
|
||||
elif len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
|
||||
return parser.parse_json_file(os.path.abspath(sys.argv[1]))
|
||||
else:
|
||||
return parser.parse_args_into_dataclasses()
|
||||
def try_download_model_from_ms(model_args: "ModelArguments") -> None:
|
||||
if not use_modelscope() or os.path.exists(model_args.model_name_or_path):
|
||||
return
|
||||
|
||||
try:
|
||||
from modelscope import snapshot_download
|
||||
|
||||
revision = "master" if model_args.model_revision == "main" else model_args.model_revision
|
||||
model_args.model_name_or_path = snapshot_download(
|
||||
model_args.model_name_or_path, revision=revision, cache_dir=model_args.cache_dir
|
||||
)
|
||||
except ImportError:
|
||||
raise ImportError("Please install modelscope via `pip install modelscope -U`")
|
||||
|
||||
|
||||
def dispatch_model(model: "PreTrainedModel") -> "PreTrainedModel":
|
||||
r"""
|
||||
Dispatches a pre-trained model to GPUs with balanced memory.
|
||||
Borrowed from: https://github.com/huggingface/transformers/blob/v4.31.0/src/transformers/modeling_utils.py#L2803
|
||||
"""
|
||||
if getattr(model, "is_loaded_in_8bit", False) or getattr(model, "is_loaded_in_4bit", False): # do nothing
|
||||
return model
|
||||
|
||||
if torch.cuda.device_count() > 1:
|
||||
from accelerate import dispatch_model
|
||||
from accelerate.utils import infer_auto_device_map, get_balanced_memory
|
||||
|
||||
if model._no_split_modules is None:
|
||||
raise ValueError("The model class needs to implement the `_no_split_modules` attribute.")
|
||||
|
||||
kwargs = {"dtype": model.dtype, "no_split_module_classes": model._no_split_modules}
|
||||
max_memory = get_balanced_memory(model, **kwargs)
|
||||
# Make sure tied weights are tied before creating the device map.
|
||||
model.tie_weights()
|
||||
device_map = infer_auto_device_map(model, max_memory=max_memory, **kwargs)
|
||||
return dispatch_model(model, device_map)
|
||||
else:
|
||||
return model.cuda()
|
||||
def use_modelscope() -> bool:
|
||||
return bool(int(os.environ.get("USE_MODELSCOPE_HUB", "0")))
|
||||
|
||||
49
src/llmtuner/extras/packages.py
Normal file
49
src/llmtuner/extras/packages.py
Normal file
@@ -0,0 +1,49 @@
|
||||
import importlib.metadata
|
||||
import importlib.util
|
||||
|
||||
|
||||
def is_package_available(name: str) -> bool:
|
||||
return importlib.util.find_spec(name) is not None
|
||||
|
||||
|
||||
def get_package_version(name: str) -> str:
|
||||
try:
|
||||
return importlib.metadata.version(name)
|
||||
except Exception:
|
||||
return "0.0.0"
|
||||
|
||||
|
||||
def is_fastapi_availble():
|
||||
return is_package_available("fastapi")
|
||||
|
||||
|
||||
def is_flash_attn2_available():
|
||||
return is_package_available("flash_attn") and get_package_version("flash_attn").startswith("2")
|
||||
|
||||
|
||||
def is_jieba_available():
|
||||
return is_package_available("jieba")
|
||||
|
||||
|
||||
def is_matplotlib_available():
|
||||
return is_package_available("matplotlib")
|
||||
|
||||
|
||||
def is_nltk_available():
|
||||
return is_package_available("nltk")
|
||||
|
||||
|
||||
def is_requests_available():
|
||||
return is_package_available("requests")
|
||||
|
||||
|
||||
def is_rouge_available():
|
||||
return is_package_available("rouge_chinese")
|
||||
|
||||
|
||||
def is_starlette_available():
|
||||
return is_package_available("sse_starlette")
|
||||
|
||||
|
||||
def is_uvicorn_available():
|
||||
return is_package_available("uvicorn")
|
||||
@@ -1,221 +1,197 @@
|
||||
import math
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from typing import Optional, Tuple
|
||||
from transformers.models.llama.modeling_llama import (
|
||||
Cache,
|
||||
LlamaAttention,
|
||||
LlamaFlashAttention2,
|
||||
apply_rotary_pos_emb,
|
||||
repeat_kv,
|
||||
)
|
||||
from transformers.utils import logging
|
||||
from transformers.models.llama.modeling_llama import LlamaAttention, apply_rotary_pos_emb, repeat_kv
|
||||
|
||||
is_flash_attn_2_available = False
|
||||
|
||||
try:
|
||||
from flash_attn import flash_attn_func, flash_attn_varlen_func # type: ignore
|
||||
from flash_attn.bert_padding import pad_input, unpad_input # type: ignore
|
||||
is_flash_attn_2_available = True
|
||||
except ImportError:
|
||||
is_flash_attn_2_available = False
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
# Modified from: https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py
|
||||
class LlamaShiftShortAttention(LlamaAttention):
|
||||
def llama_torch_attn_forward(
|
||||
self: "LlamaAttention",
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_value: Optional["Cache"] = None,
|
||||
output_attentions: bool = False,
|
||||
**kwargs,
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||
bsz, q_len, _ = hidden_states.size()
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
||||
output_attentions: bool = False,
|
||||
use_cache: bool = False,
|
||||
**kwargs
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||
bsz, q_len, _ = hidden_states.size()
|
||||
query_states = self.q_proj(hidden_states)
|
||||
key_states = self.k_proj(hidden_states)
|
||||
value_states = self.v_proj(hidden_states)
|
||||
|
||||
query_states = self.q_proj(hidden_states)
|
||||
key_states = self.k_proj(hidden_states)
|
||||
value_states = self.v_proj(hidden_states)
|
||||
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
|
||||
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
||||
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
||||
|
||||
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
|
||||
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
||||
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
||||
kv_seq_len = key_states.shape[-2]
|
||||
if past_key_value is not None:
|
||||
kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
|
||||
|
||||
kv_seq_len = key_states.shape[-2]
|
||||
if past_key_value is not None:
|
||||
kv_seq_len += past_key_value[0].shape[-2]
|
||||
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
|
||||
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
|
||||
|
||||
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
|
||||
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
|
||||
if past_key_value is not None:
|
||||
cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models
|
||||
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
|
||||
|
||||
if past_key_value is not None: # reuse k, v, self_attention
|
||||
key_states = torch.cat([past_key_value[0], key_states], dim=2)
|
||||
value_states = torch.cat([past_key_value[1], value_states], dim=2)
|
||||
key_states = repeat_kv(key_states, self.num_key_value_groups)
|
||||
value_states = repeat_kv(value_states, self.num_key_value_groups)
|
||||
|
||||
past_key_value = (key_states, value_states) if use_cache else None
|
||||
if getattr(self.config, "group_size_ratio", None) and self.training: # shift
|
||||
groupsz = int(q_len * getattr(self.config, "group_size_ratio"))
|
||||
assert q_len % groupsz == 0, "q_len {} should be divisible by group size {}.".format(q_len, groupsz)
|
||||
num_groups = q_len // groupsz
|
||||
|
||||
if getattr(self, "num_key_value_groups"):
|
||||
key_states = repeat_kv(key_states, self.num_key_value_groups)
|
||||
value_states = repeat_kv(value_states, self.num_key_value_groups)
|
||||
|
||||
if getattr(self.config, "group_size_ratio", None) and self.training: # shift
|
||||
groupsz = int(q_len * getattr(self.config, "group_size_ratio"))
|
||||
assert q_len % groupsz == 0, "q_len {} should be divisible by group size {}.".format(q_len, groupsz)
|
||||
num_groups = q_len // groupsz
|
||||
def shift(state: torch.Tensor) -> torch.Tensor:
|
||||
state = state.transpose(1, 2) # output: (bsz, seq_len, n_heads, head_dim)
|
||||
state = torch.cat((
|
||||
state[:, :, :self.num_heads//2], state[:, :, self.num_heads//2:].roll(-groupsz//2, dims=1)
|
||||
), dim=2)
|
||||
return state.reshape(bsz * num_groups, groupsz, self.num_heads, self.head_dim).transpose(1, 2)
|
||||
|
||||
query_states, key_states, value_states = shift(query_states), shift(key_states), shift(value_states)
|
||||
if attention_mask is not None:
|
||||
attention_mask = attention_mask[:, :, :groupsz, :groupsz].repeat(num_groups, 1, 1, 1)
|
||||
|
||||
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
|
||||
|
||||
if attention_mask is not None:
|
||||
attn_weights = attn_weights + attention_mask
|
||||
|
||||
# upcast attention to fp32
|
||||
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
|
||||
attn_output = torch.matmul(attn_weights, value_states) # (bsz, :, seq_len, :) or (bsz*n_group, :, groupsz, :)
|
||||
attn_output = attn_output.transpose(1, 2).contiguous()
|
||||
|
||||
if getattr(self.config, "group_size_ratio", None) and self.training: # shift back
|
||||
attn_output.reshape(bsz, q_len, self.num_heads, self.head_dim)
|
||||
attn_output = torch.cat((
|
||||
attn_output[:, :, :self.num_heads//2], attn_output[:, :, self.num_heads//2:].roll(groupsz//2, dims=1)
|
||||
))
|
||||
|
||||
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
|
||||
attn_output = self.o_proj(attn_output)
|
||||
|
||||
if not output_attentions:
|
||||
attn_weights = None
|
||||
|
||||
return attn_output, attn_weights, past_key_value
|
||||
|
||||
|
||||
class LlamaFlashAttention2(LlamaAttention):
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
||||
output_attentions: bool = False,
|
||||
use_cache: bool = False,
|
||||
**kwargs
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||
# LlamaFlashAttention2 attention does not support output_attentions
|
||||
output_attentions = False
|
||||
|
||||
bsz, q_len, _ = hidden_states.size()
|
||||
|
||||
query_states = self.q_proj(hidden_states)
|
||||
key_states = self.k_proj(hidden_states)
|
||||
value_states = self.v_proj(hidden_states)
|
||||
|
||||
# FlashAttention requires the input to have the shape (bsz, seq_len, n_heads, head_dim)
|
||||
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
|
||||
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
||||
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
||||
|
||||
kv_seq_len = key_states.shape[-2]
|
||||
if past_key_value is not None:
|
||||
kv_seq_len += past_key_value[0].shape[-2]
|
||||
|
||||
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
|
||||
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
|
||||
|
||||
if past_key_value is not None: # reuse k, v, self_attention
|
||||
key_states = torch.cat([past_key_value[0], key_states], dim=2)
|
||||
value_states = torch.cat([past_key_value[1], value_states], dim=2)
|
||||
|
||||
past_key_value = (key_states, value_states) if use_cache else None
|
||||
|
||||
# cast to half precision
|
||||
input_dtype = query_states.dtype
|
||||
if input_dtype == torch.float32:
|
||||
logger.warning_once("The input hidden states seems to be silently casted in float32.")
|
||||
query_states = query_states.to(self.config.torch_dtype)
|
||||
key_states = key_states.to(self.config.torch_dtype)
|
||||
value_states = value_states.to(self.config.torch_dtype)
|
||||
|
||||
if getattr(self, "num_key_value_groups", None):
|
||||
key_states = repeat_kv(key_states, self.num_key_value_groups)
|
||||
value_states = repeat_kv(value_states, self.num_key_value_groups)
|
||||
|
||||
query_states = query_states.transpose(1, 2) # (bsz, seq_len, n_heads, head_dim)
|
||||
key_states = key_states.transpose(1, 2) # (bsz, seq_len, n_heads, head_dim)
|
||||
value_states = value_states.transpose(1, 2) # (bsz, seq_len, n_heads, head_dim)
|
||||
|
||||
if getattr(self.config, "group_size_ratio", None) and self.training: # shift
|
||||
groupsz = int(q_len * getattr(self.config, "group_size_ratio"))
|
||||
assert q_len % groupsz == 0, "q_len {} should be divisible by group size {}.".format(q_len, groupsz)
|
||||
num_groups = q_len // groupsz
|
||||
def shift(state: torch.Tensor) -> torch.Tensor:
|
||||
state = torch.cat((
|
||||
state[:, :, :self.num_heads//2], state[:, :, self.num_heads//2:].roll(-groupsz//2, dims=1)
|
||||
), dim=2)
|
||||
return state.reshape(bsz * num_groups, groupsz, self.num_heads, self.head_dim)
|
||||
|
||||
query_states, key_states, value_states = shift(query_states), shift(key_states), shift(value_states)
|
||||
if attention_mask is not None:
|
||||
attention_mask = attention_mask.reshape(bsz * num_groups, groupsz)
|
||||
|
||||
if attention_mask is not None:
|
||||
logger.warning_once("Padded sequences are less efficient in FlashAttention.")
|
||||
# -q_len: assumes left padding when q_len != kv_len
|
||||
unpadded_q, indices_q, cu_seqlens_q, max_seqlen_q = unpad_input(query_states, attention_mask[:, -q_len:])
|
||||
unpadded_k, _, cu_seqlens_k, max_seqlen_k = unpad_input(key_states, attention_mask)
|
||||
unpadded_v, _, _, _ = unpad_input(value_states, attention_mask)
|
||||
attn_output_unpad = flash_attn_varlen_func(
|
||||
unpadded_q,
|
||||
unpadded_k,
|
||||
unpadded_v,
|
||||
cu_seqlens_q=cu_seqlens_q,
|
||||
cu_seqlens_k=cu_seqlens_k,
|
||||
max_seqlen_q=max_seqlen_q,
|
||||
max_seqlen_k=max_seqlen_k,
|
||||
dropout_p=0.0,
|
||||
softmax_scale=None,
|
||||
causal=True,
|
||||
def shift(state: torch.Tensor) -> torch.Tensor:
|
||||
state = state.transpose(1, 2) # output: (bsz, seq_len, n_heads, head_dim)
|
||||
state = torch.cat(
|
||||
(state[:, :, : self.num_heads // 2], state[:, :, self.num_heads // 2 :].roll(-groupsz // 2, dims=1)),
|
||||
dim=2,
|
||||
)
|
||||
attn_output = pad_input(attn_output_unpad, indices_q, bsz, q_len)
|
||||
return state.reshape(bsz * num_groups, groupsz, self.num_heads, self.head_dim).transpose(1, 2)
|
||||
|
||||
query_states, key_states, value_states = shift(query_states), shift(key_states), shift(value_states)
|
||||
if attention_mask is not None:
|
||||
attention_mask = attention_mask[:, :, :groupsz, :groupsz].repeat(num_groups, 1, 1, 1)
|
||||
|
||||
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
|
||||
|
||||
if attention_mask is not None:
|
||||
attn_weights = attn_weights + attention_mask
|
||||
|
||||
# upcast attention to fp32
|
||||
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
|
||||
attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
|
||||
attn_output = torch.matmul(attn_weights, value_states) # (bsz, :, seq_len, :) or (bsz*n_group, :, groupsz, :)
|
||||
attn_output = attn_output.transpose(1, 2).contiguous()
|
||||
|
||||
if getattr(self.config, "group_size_ratio", None) and self.training: # shift back
|
||||
attn_output.reshape(bsz, q_len, self.num_heads, self.head_dim)
|
||||
attn_output = torch.cat(
|
||||
(
|
||||
attn_output[:, :, : self.num_heads // 2],
|
||||
attn_output[:, :, self.num_heads // 2 :].roll(groupsz // 2, dims=1),
|
||||
)
|
||||
)
|
||||
|
||||
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
|
||||
attn_output = self.o_proj(attn_output)
|
||||
|
||||
if not output_attentions:
|
||||
attn_weights = None
|
||||
|
||||
return attn_output, attn_weights, past_key_value
|
||||
|
||||
|
||||
# Modified from: https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py
|
||||
def llama_flash_attn_forward(
|
||||
self: "LlamaFlashAttention2",
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
||||
output_attentions: bool = False,
|
||||
**kwargs,
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||
# LlamaFlashAttention2 attention does not support output_attentions
|
||||
output_attentions = False
|
||||
|
||||
bsz, q_len, _ = hidden_states.size()
|
||||
|
||||
query_states = self.q_proj(hidden_states)
|
||||
key_states = self.k_proj(hidden_states)
|
||||
value_states = self.v_proj(hidden_states)
|
||||
|
||||
# FlashAttention requires the input to have the shape (bsz, seq_len, n_heads, head_dim)
|
||||
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
|
||||
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
||||
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
||||
|
||||
kv_seq_len = key_states.shape[-2]
|
||||
if past_key_value is not None:
|
||||
kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
|
||||
|
||||
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
|
||||
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
|
||||
|
||||
if past_key_value is not None:
|
||||
cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models
|
||||
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
|
||||
|
||||
key_states = repeat_kv(key_states, self.num_key_value_groups)
|
||||
value_states = repeat_kv(value_states, self.num_key_value_groups)
|
||||
|
||||
query_states = query_states.transpose(1, 2) # (bsz, seq_len, n_heads, head_dim)
|
||||
key_states = key_states.transpose(1, 2) # (bsz, seq_len, n_heads, head_dim)
|
||||
value_states = value_states.transpose(1, 2) # (bsz, seq_len, n_heads, head_dim)
|
||||
|
||||
dropout_rate = self.attention_dropout if self.training else 0.0
|
||||
|
||||
input_dtype = query_states.dtype
|
||||
if input_dtype == torch.float32:
|
||||
if torch.is_autocast_enabled():
|
||||
target_dtype = torch.get_autocast_gpu_dtype()
|
||||
elif hasattr(self.config, "_pre_quantization_dtype"):
|
||||
target_dtype = self.config._pre_quantization_dtype
|
||||
else:
|
||||
attn_output = flash_attn_func(
|
||||
query_states, key_states, value_states, 0.0, softmax_scale=None, causal=True
|
||||
target_dtype = self.q_proj.weight.dtype
|
||||
|
||||
logger.warning_once("The input hidden states seems to be silently casted in float32.")
|
||||
query_states = query_states.to(target_dtype)
|
||||
key_states = key_states.to(target_dtype)
|
||||
value_states = value_states.to(target_dtype)
|
||||
|
||||
if getattr(self.config, "group_size_ratio", None) and self.training: # shift
|
||||
groupsz = int(q_len * getattr(self.config, "group_size_ratio"))
|
||||
assert q_len % groupsz == 0, "q_len {} should be divisible by group size {}.".format(q_len, groupsz)
|
||||
num_groups = q_len // groupsz
|
||||
|
||||
def shift(state: torch.Tensor) -> torch.Tensor:
|
||||
state = torch.cat(
|
||||
(state[:, :, : self.num_heads // 2], state[:, :, self.num_heads // 2 :].roll(-groupsz // 2, dims=1)),
|
||||
dim=2,
|
||||
)
|
||||
return state.reshape(bsz * num_groups, groupsz, self.num_heads, self.head_dim)
|
||||
|
||||
if getattr(self.config, "group_size_ratio", None) and self.training: # shift back
|
||||
attn_output.reshape(bsz, q_len, self.num_heads, self.head_dim)
|
||||
attn_output = torch.cat((
|
||||
attn_output[:, :, :self.num_heads//2], attn_output[:, :, self.num_heads//2:].roll(groupsz//2, dims=1)
|
||||
))
|
||||
query_states, key_states, value_states = shift(query_states), shift(key_states), shift(value_states)
|
||||
if attention_mask is not None:
|
||||
attention_mask = attention_mask[:, :, :groupsz, :groupsz].repeat(num_groups, 1, 1, 1)
|
||||
|
||||
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
|
||||
attn_output = self.o_proj(attn_output)
|
||||
attn_output: torch.Tensor = self._flash_attention_forward(
|
||||
query_states, key_states, value_states, attention_mask, q_len, dropout=dropout_rate
|
||||
)
|
||||
|
||||
if not output_attentions:
|
||||
attn_weights = None
|
||||
if getattr(self.config, "group_size_ratio", None) and self.training: # shift back
|
||||
attn_output.reshape(bsz, q_len, self.num_heads, self.head_dim)
|
||||
attn_output = torch.cat(
|
||||
(
|
||||
attn_output[:, :, : self.num_heads // 2],
|
||||
attn_output[:, :, self.num_heads // 2 :].roll(groupsz // 2, dims=1),
|
||||
)
|
||||
)
|
||||
|
||||
return attn_output, attn_weights, past_key_value
|
||||
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
|
||||
attn_output = self.o_proj(attn_output)
|
||||
|
||||
if not output_attentions:
|
||||
attn_weights = None
|
||||
|
||||
return attn_output, attn_weights, past_key_value
|
||||
|
||||
|
||||
# Disable the transformation of the attention mask in LlamaModel as flash attention
|
||||
# takes a boolean padding_mask. Fills in the past kv length for use in forward.
|
||||
def _prepare_decoder_attention_mask(
|
||||
self,
|
||||
attention_mask: torch.Tensor,
|
||||
input_shape: torch.Tensor,
|
||||
inputs_embeds: torch.Tensor,
|
||||
past_key_values_length: int
|
||||
) -> torch.Tensor:
|
||||
if attention_mask is not None and torch.all(attention_mask):
|
||||
return None # This uses the faster call when training with full samples
|
||||
|
||||
return attention_mask
|
||||
def apply_llama_patch() -> None:
|
||||
LlamaAttention.forward = llama_torch_attn_forward
|
||||
LlamaFlashAttention2.forward = llama_flash_attn_forward
|
||||
|
||||
@@ -1,11 +1,16 @@
|
||||
import os
|
||||
import math
|
||||
import json
|
||||
import matplotlib.pyplot as plt
|
||||
import math
|
||||
import os
|
||||
from typing import List, Optional
|
||||
|
||||
from transformers.trainer import TRAINER_STATE_NAME
|
||||
|
||||
from llmtuner.extras.logging import get_logger
|
||||
from .logging import get_logger
|
||||
from .packages import is_matplotlib_available
|
||||
|
||||
|
||||
if is_matplotlib_available():
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
|
||||
logger = get_logger(__name__)
|
||||
@@ -17,7 +22,7 @@ def smooth(scalars: List[float]) -> List[float]:
|
||||
"""
|
||||
last = scalars[0]
|
||||
smoothed = list()
|
||||
weight = 1.8 * (1 / (1 + math.exp(-0.05 * len(scalars))) - 0.5) # a sigmoid function
|
||||
weight = 1.8 * (1 / (1 + math.exp(-0.05 * len(scalars))) - 0.5) # a sigmoid function
|
||||
for next_val in scalars:
|
||||
smoothed_val = last * weight + (1 - weight) * next_val
|
||||
smoothed.append(smoothed_val)
|
||||
@@ -26,7 +31,6 @@ def smooth(scalars: List[float]) -> List[float]:
|
||||
|
||||
|
||||
def plot_loss(save_dictionary: os.PathLike, keys: Optional[List[str]] = ["loss"]) -> None:
|
||||
|
||||
with open(os.path.join(save_dictionary, TRAINER_STATE_NAME), "r", encoding="utf-8") as f:
|
||||
data = json.load(f)
|
||||
|
||||
|
||||
@@ -1,769 +0,0 @@
|
||||
import tiktoken
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
|
||||
|
||||
from llmtuner.extras.logging import get_logger
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from transformers import PreTrainedTokenizer
|
||||
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class Template:
|
||||
|
||||
prefix: List[Union[str, Dict[str, str]]]
|
||||
prompt: List[Union[str, Dict[str, str]]]
|
||||
system: str
|
||||
sep: List[Union[str, Dict[str, str]]]
|
||||
stop_words: List[str]
|
||||
use_history: bool
|
||||
efficient_eos: bool
|
||||
|
||||
def encode_oneturn(
|
||||
self,
|
||||
tokenizer: "PreTrainedTokenizer",
|
||||
query: str,
|
||||
resp: str,
|
||||
history: Optional[List[Tuple[str, str]]] = None,
|
||||
system: Optional[str] = None
|
||||
) -> Tuple[List[int], List[int]]:
|
||||
r"""
|
||||
Returns a single pair of token ids representing prompt and response respectively.
|
||||
"""
|
||||
system, history = self._format(query, resp, history, system)
|
||||
encoded_pairs = self._encode(tokenizer, system, history)
|
||||
prompt_ids = []
|
||||
for query_ids, resp_ids in encoded_pairs[:-1]:
|
||||
prompt_ids = prompt_ids + query_ids + resp_ids
|
||||
prompt_ids, answer_ids = prompt_ids + encoded_pairs[-1][0], encoded_pairs[-1][1]
|
||||
return prompt_ids, answer_ids
|
||||
|
||||
def encode_multiturn(
|
||||
self,
|
||||
tokenizer: "PreTrainedTokenizer",
|
||||
query: str,
|
||||
resp: str,
|
||||
history: Optional[List[Tuple[str, str]]] = None,
|
||||
system: Optional[str] = None
|
||||
) -> List[Tuple[List[int], List[int]]]:
|
||||
r"""
|
||||
Returns multiple pairs of token ids representing prompts and responses respectively.
|
||||
"""
|
||||
system, history = self._format(query, resp, history, system)
|
||||
encoded_pairs = self._encode(tokenizer, system, history)
|
||||
return encoded_pairs
|
||||
|
||||
def _format(
|
||||
self,
|
||||
query: str,
|
||||
resp: str,
|
||||
history: Optional[List[Tuple[str, str]]] = None,
|
||||
system: Optional[str] = None
|
||||
) -> Tuple[str, List[Tuple[str, str]]]:
|
||||
r"""
|
||||
Aligns inputs to the standard format.
|
||||
"""
|
||||
system = system or self.system # use system if provided
|
||||
history = history if (history and self.use_history) else []
|
||||
history = history + [(query, resp)]
|
||||
return system, history
|
||||
|
||||
def _get_special_ids(
|
||||
self,
|
||||
tokenizer: "PreTrainedTokenizer"
|
||||
) -> Tuple[List[int], List[int]]:
|
||||
if tokenizer.bos_token_id is not None and getattr(tokenizer, "add_bos_token", True):
|
||||
bos_ids = [tokenizer.bos_token_id]
|
||||
else: # baichuan, qwen and gpt2 models have no bos token
|
||||
bos_ids = []
|
||||
|
||||
if tokenizer.eos_token_id is None:
|
||||
raise ValueError("EOS token is required.")
|
||||
|
||||
if self.efficient_eos: # used in baichuan, qwen, chatglm, etc.
|
||||
eos_ids = []
|
||||
else:
|
||||
eos_ids = [tokenizer.eos_token_id]
|
||||
|
||||
return bos_ids, eos_ids
|
||||
|
||||
def _encode(
|
||||
self,
|
||||
tokenizer: "PreTrainedTokenizer",
|
||||
system: str,
|
||||
history: List[Tuple[str, str]]
|
||||
) -> List[Tuple[List[int], List[int]]]:
|
||||
r"""
|
||||
Encodes formatted inputs to pairs of token ids.
|
||||
Turn 0: bos + prefix + sep + query resp + eos
|
||||
Turn t: sep + bos + query resp + eos
|
||||
"""
|
||||
bos_ids, eos_ids = self._get_special_ids(tokenizer)
|
||||
sep_ids = self._convert_inputs_to_ids(tokenizer, context=self.sep)
|
||||
encoded_pairs = []
|
||||
for turn_idx, (query, resp) in enumerate(history):
|
||||
if turn_idx == 0:
|
||||
prefix_ids = self._convert_inputs_to_ids(tokenizer, context=self.prefix, system=system)
|
||||
if len(prefix_ids) != 0: # has prefix
|
||||
prefix_ids = bos_ids + prefix_ids + sep_ids
|
||||
else:
|
||||
prefix_ids = bos_ids
|
||||
else:
|
||||
prefix_ids = sep_ids + bos_ids
|
||||
|
||||
query_ids = self._convert_inputs_to_ids(tokenizer, context=self.prompt, query=query, idx=str(turn_idx))
|
||||
resp_ids = self._convert_inputs_to_ids(tokenizer, context=[resp])
|
||||
encoded_pairs.append((prefix_ids + query_ids, resp_ids + eos_ids))
|
||||
return encoded_pairs
|
||||
|
||||
def _convert_inputs_to_ids(
|
||||
self,
|
||||
tokenizer: "PreTrainedTokenizer",
|
||||
context: List[Union[str, Dict[str, str]]],
|
||||
system: Optional[str] = None,
|
||||
query: Optional[str] = None,
|
||||
idx: Optional[str] = None
|
||||
) -> List[int]:
|
||||
r"""
|
||||
Converts context to token ids.
|
||||
"""
|
||||
if isinstance(getattr(tokenizer, "tokenizer", None), tiktoken.Encoding): # for tiktoken tokenizer (Qwen)
|
||||
kwargs = dict(allowed_special="all")
|
||||
else:
|
||||
kwargs = dict(add_special_tokens=False)
|
||||
|
||||
token_ids = []
|
||||
for elem in context:
|
||||
if isinstance(elem, str):
|
||||
elem = elem.replace("{{system}}", system, 1) if system is not None else elem
|
||||
elem = elem.replace("{{query}}", query, 1) if query is not None else elem
|
||||
elem = elem.replace("{{idx}}", idx, 1) if idx is not None else elem
|
||||
if len(elem) != 0:
|
||||
token_ids = token_ids + tokenizer.encode(elem, **kwargs)
|
||||
elif isinstance(elem, dict):
|
||||
token_ids = token_ids + [tokenizer.convert_tokens_to_ids(elem.get("token"))]
|
||||
else:
|
||||
raise ValueError("Input must be string or dict[str, str], got {}".format(type(elem)))
|
||||
|
||||
return token_ids
|
||||
|
||||
|
||||
@dataclass
|
||||
class Llama2Template(Template):
|
||||
|
||||
def _encode(
|
||||
self,
|
||||
tokenizer: "PreTrainedTokenizer",
|
||||
system: str,
|
||||
history: List[Tuple[str, str]]
|
||||
) -> List[Tuple[List[int], List[int]]]:
|
||||
r"""
|
||||
Encodes formatted inputs to pairs of token ids.
|
||||
Turn 0: bos + prefix + query resp + eos
|
||||
Turn t: bos + query resp + eos
|
||||
"""
|
||||
bos_ids, eos_ids = self._get_special_ids(tokenizer)
|
||||
encoded_pairs = []
|
||||
for turn_idx, (query, resp) in enumerate(history):
|
||||
if turn_idx == 0: # llama2 template has no sep_ids
|
||||
query = self.prefix[0].replace("{{system}}", system) + query
|
||||
query_ids = self._convert_inputs_to_ids(tokenizer, context=self.prompt, query=query)
|
||||
resp_ids = self._convert_inputs_to_ids(tokenizer, context=[resp])
|
||||
encoded_pairs.append((bos_ids + query_ids, resp_ids + eos_ids))
|
||||
return encoded_pairs
|
||||
|
||||
|
||||
templates: Dict[str, Template] = {}
|
||||
|
||||
|
||||
def register_template(
|
||||
name: str,
|
||||
prefix: List[Union[str, Dict[str, str]]],
|
||||
prompt: List[Union[str, Dict[str, str]]],
|
||||
system: str,
|
||||
sep: List[Union[str, Dict[str, str]]],
|
||||
stop_words: Optional[List[str]] = [],
|
||||
use_history: Optional[bool] = True,
|
||||
efficient_eos: Optional[bool] = False
|
||||
) -> None:
|
||||
template_class = Llama2Template if "llama2" in name else Template
|
||||
templates[name] = template_class(
|
||||
prefix=prefix,
|
||||
prompt=prompt,
|
||||
system=system,
|
||||
sep=sep,
|
||||
stop_words=stop_words,
|
||||
use_history=use_history,
|
||||
efficient_eos=efficient_eos
|
||||
)
|
||||
|
||||
|
||||
def get_template_and_fix_tokenizer(
|
||||
name: str,
|
||||
tokenizer: "PreTrainedTokenizer"
|
||||
) -> Template:
|
||||
if tokenizer.eos_token_id is None:
|
||||
tokenizer.eos_token = "<|endoftext|>"
|
||||
logger.info("Add eos token: {}".format(tokenizer.eos_token))
|
||||
|
||||
if tokenizer.pad_token_id is None:
|
||||
tokenizer.pad_token = tokenizer.eos_token
|
||||
logger.info("Add pad token: {}".format(tokenizer.pad_token))
|
||||
|
||||
if name is None:
|
||||
return None
|
||||
|
||||
template = templates.get(name, None)
|
||||
assert template is not None, "Template {} does not exist.".format(name)
|
||||
tokenizer.add_special_tokens(
|
||||
dict(additional_special_tokens=template.stop_words),
|
||||
replace_additional_special_tokens=False
|
||||
)
|
||||
return template
|
||||
|
||||
|
||||
r"""
|
||||
Supports: https://huggingface.co/tatsu-lab/alpaca-7b-wdiff
|
||||
"""
|
||||
register_template(
|
||||
name="alpaca",
|
||||
prefix=[
|
||||
"{{system}}"
|
||||
],
|
||||
prompt=[
|
||||
"### Instruction:\n{{query}}\n\n### Response:\n"
|
||||
],
|
||||
system=(
|
||||
"Below is an instruction that describes a task. "
|
||||
"Write a response that appropriately completes the request."
|
||||
),
|
||||
sep=[
|
||||
"\n\n"
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
r"""
|
||||
Supports: https://huggingface.co/BAAI/AquilaChat-7B
|
||||
https://huggingface.co/BAAI/AquilaChat2-7B
|
||||
https://huggingface.co/BAAI/AquilaChat2-34B
|
||||
"""
|
||||
register_template(
|
||||
name="aquila",
|
||||
prefix=[
|
||||
"{{system}}"
|
||||
],
|
||||
prompt=[
|
||||
"Human: {{query}}###Assistant:"
|
||||
],
|
||||
system=(
|
||||
"A chat between a curious human and an artificial intelligence assistant. "
|
||||
"The assistant gives helpful, detailed, and polite answers to the human's questions."
|
||||
),
|
||||
sep=[
|
||||
"###"
|
||||
],
|
||||
stop_words=[
|
||||
"</s>"
|
||||
],
|
||||
efficient_eos=True
|
||||
)
|
||||
|
||||
|
||||
r"""
|
||||
Supports: https://huggingface.co/baichuan-inc/Baichuan-13B-Chat
|
||||
"""
|
||||
register_template(
|
||||
name="baichuan",
|
||||
prefix=[
|
||||
"{{system}}"
|
||||
],
|
||||
prompt=[
|
||||
{"token": "<reserved_102>"}, # user token
|
||||
"{{query}}",
|
||||
{"token": "<reserved_103>"} # assistant token
|
||||
],
|
||||
system="",
|
||||
sep=[],
|
||||
efficient_eos=True
|
||||
)
|
||||
|
||||
|
||||
r"""
|
||||
Supports: https://huggingface.co/baichuan-inc/Baichuan2-7B-Chat
|
||||
https://huggingface.co/baichuan-inc/Baichuan2-13B-Chat
|
||||
"""
|
||||
register_template(
|
||||
name="baichuan2",
|
||||
prefix=[
|
||||
"{{system}}"
|
||||
],
|
||||
prompt=[
|
||||
{"token": "<reserved_106>"}, # user token
|
||||
"{{query}}",
|
||||
{"token": "<reserved_107>"} # assistant token
|
||||
],
|
||||
system="",
|
||||
sep=[],
|
||||
efficient_eos=True
|
||||
)
|
||||
|
||||
|
||||
r"""
|
||||
Supports: https://huggingface.co/BelleGroup/BELLE-LLaMA-EXT-13B
|
||||
"""
|
||||
register_template(
|
||||
name="belle",
|
||||
prefix=[
|
||||
"{{system}}"
|
||||
],
|
||||
prompt=[
|
||||
"Human: {{query}}\n\nBelle: "
|
||||
],
|
||||
system="",
|
||||
sep=[
|
||||
"\n\n"
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
r"""
|
||||
Supports: https://huggingface.co/vivo-ai/BlueLM-7B-Chat
|
||||
"""
|
||||
register_template(
|
||||
name="bluelm",
|
||||
prefix=[
|
||||
"{{system}}"
|
||||
],
|
||||
prompt=[
|
||||
{"token": "[|Human|]:"},
|
||||
"{{query}}",
|
||||
{"token": "[|AI|]:"}
|
||||
],
|
||||
system="",
|
||||
sep=[]
|
||||
)
|
||||
|
||||
|
||||
r"""
|
||||
Supports: https://huggingface.co/THUDM/chatglm2-6b
|
||||
"""
|
||||
register_template(
|
||||
name="chatglm2",
|
||||
prefix=[
|
||||
{"token": "[gMASK]"},
|
||||
{"token": "sop"},
|
||||
"{{system}}"
|
||||
],
|
||||
prompt=[
|
||||
"[Round {{idx}}]\n\n问:{{query}}\n\n答:"
|
||||
],
|
||||
system="",
|
||||
sep=[
|
||||
"\n\n"
|
||||
],
|
||||
efficient_eos=True
|
||||
)
|
||||
|
||||
|
||||
r"""
|
||||
Supports: https://huggingface.co/THUDM/chatglm3-6b
|
||||
"""
|
||||
register_template(
|
||||
name="chatglm3",
|
||||
prefix=[
|
||||
{"token": "[gMASK]"},
|
||||
{"token": "sop"},
|
||||
"{{system}}"
|
||||
],
|
||||
prompt=[
|
||||
{"token": "<|user|>"},
|
||||
"\n",
|
||||
"{{query}}",
|
||||
{"token": "<|assistant|>"}
|
||||
],
|
||||
system="",
|
||||
sep=[],
|
||||
stop_words=[
|
||||
"<|user|>",
|
||||
"<|observation|>"
|
||||
],
|
||||
efficient_eos=True
|
||||
)
|
||||
|
||||
|
||||
r"""
|
||||
Supports: https://huggingface.co/deepseek-ai/deepseek-coder-1.3b-instruct
|
||||
https://huggingface.co/deepseek-ai/deepseek-coder-6.7b-instruct
|
||||
https://huggingface.co/deepseek-ai/deepseek-coder-33b-instruct
|
||||
"""
|
||||
register_template(
|
||||
name="deepseek",
|
||||
prefix=[
|
||||
"{{system}}"
|
||||
],
|
||||
prompt=[
|
||||
"### Instruction:\n{{query}}\n\n### Response:\n"
|
||||
],
|
||||
system=(
|
||||
"You are an AI programming assistant, utilizing the Deepseek Coder model, "
|
||||
"developed by Deepseek Company, and you only answer questions related to computer science. "
|
||||
"For politically sensitive questions, security and privacy issues, "
|
||||
"and other non-computer science questions, you will refuse to answer."
|
||||
),
|
||||
sep=[
|
||||
"\n",
|
||||
{"token": "<|EOT|>"},
|
||||
"\n\n"
|
||||
],
|
||||
stop_words=[
|
||||
"<|EOT|>"
|
||||
],
|
||||
efficient_eos=True
|
||||
)
|
||||
|
||||
|
||||
r"""
|
||||
Default template.
|
||||
"""
|
||||
register_template(
|
||||
name="default",
|
||||
prefix=[
|
||||
"{{system}}"
|
||||
],
|
||||
prompt=[
|
||||
"Human: {{query}}\nAssistant:"
|
||||
],
|
||||
system=(
|
||||
"A chat between a curious user and an artificial intelligence assistant. "
|
||||
"The assistant gives helpful, detailed, and polite answers to the user's questions."
|
||||
),
|
||||
sep=[
|
||||
"\n"
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
r"""
|
||||
Supports: https://huggingface.co/tiiuae/falcon-180B-chat
|
||||
"""
|
||||
register_template(
|
||||
name="falcon",
|
||||
prefix=[
|
||||
"{{system}}"
|
||||
],
|
||||
prompt=[
|
||||
"User: {{query}}\nFalcon:"
|
||||
],
|
||||
system="",
|
||||
sep=[
|
||||
"\n"
|
||||
],
|
||||
efficient_eos=True
|
||||
)
|
||||
|
||||
|
||||
r"""
|
||||
Supports: https://huggingface.co/internlm/internlm-chat-7b
|
||||
https://huggingface.co/internlm/internlm-chat-20b
|
||||
"""
|
||||
register_template(
|
||||
name="intern",
|
||||
prefix=[
|
||||
"{{system}}"
|
||||
],
|
||||
prompt=[
|
||||
"<|User|>:{{query}}",
|
||||
{"token": "<eoh>"},
|
||||
"\n<|Bot|>:"
|
||||
],
|
||||
system="",
|
||||
sep=[
|
||||
{"token": "<eoa>"},
|
||||
"\n"
|
||||
],
|
||||
stop_words=[
|
||||
"<eoa>"
|
||||
],
|
||||
efficient_eos=True
|
||||
)
|
||||
|
||||
|
||||
r"""
|
||||
Supports: https://huggingface.co/meta-llama/Llama-2-7b-chat-hf
|
||||
https://huggingface.co/meta-llama/Llama-2-13b-chat-hf
|
||||
https://huggingface.co/meta-llama/Llama-2-70b-chat-hf
|
||||
"""
|
||||
register_template(
|
||||
name="llama2",
|
||||
prefix=[
|
||||
"<<SYS>>\n{{system}}\n<</SYS>>\n\n"
|
||||
],
|
||||
prompt=[
|
||||
"[INST] {{query}} [/INST]"
|
||||
],
|
||||
system=(
|
||||
"You are a helpful, respectful and honest assistant. "
|
||||
"Always answer as helpfully as possible, while being safe. "
|
||||
"Your answers should not include any harmful, unethical, "
|
||||
"racist, sexist, toxic, dangerous, or illegal content. "
|
||||
"Please ensure that your responses are socially unbiased and positive in nature.\n\n"
|
||||
"If a question does not make any sense, or is not factually coherent, "
|
||||
"explain why instead of answering something not correct. "
|
||||
"If you don't know the answer to a question, please don't share false information."
|
||||
),
|
||||
sep=[]
|
||||
)
|
||||
|
||||
|
||||
r"""
|
||||
Supports: https://huggingface.co/ziqingyang/chinese-alpaca-2-7b
|
||||
https://huggingface.co/ziqingyang/chinese-alpaca-2-13b
|
||||
"""
|
||||
register_template(
|
||||
name="llama2_zh",
|
||||
prefix=[
|
||||
"<<SYS>>\n{{system}}\n<</SYS>>\n\n"
|
||||
],
|
||||
prompt=[
|
||||
"[INST] {{query}} [/INST]"
|
||||
],
|
||||
system="You are a helpful assistant. 你是一个乐于助人的助手。",
|
||||
sep=[]
|
||||
)
|
||||
|
||||
|
||||
r"""
|
||||
Supports: https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.1
|
||||
"""
|
||||
register_template(
|
||||
name="mistral",
|
||||
prefix=[
|
||||
"{{system}}"
|
||||
],
|
||||
prompt=[
|
||||
"[INST] {{query}} [/INST]"
|
||||
],
|
||||
system="",
|
||||
sep=[]
|
||||
)
|
||||
|
||||
|
||||
r"""
|
||||
Supports: https://huggingface.co/openchat/openchat_3.5
|
||||
"""
|
||||
register_template(
|
||||
name="openchat",
|
||||
prefix=[
|
||||
"{{system}}"
|
||||
],
|
||||
prompt=[
|
||||
"GPT4 Correct User: {{query}}",
|
||||
{"token": "<|end_of_turn|>"},
|
||||
"GPT4 Correct Assistant:"
|
||||
],
|
||||
system="",
|
||||
sep=[
|
||||
{"token": "<|end_of_turn|>"}
|
||||
],
|
||||
stop_words=[
|
||||
"<|end_of_turn|>"
|
||||
],
|
||||
efficient_eos=True
|
||||
)
|
||||
|
||||
|
||||
r"""
|
||||
Supports: https://huggingface.co/Qwen/Qwen-7B-Chat
|
||||
https://huggingface.co/Qwen/Qwen-14B-Chat
|
||||
"""
|
||||
register_template(
|
||||
name="qwen",
|
||||
prefix=[
|
||||
{"token": "<|im_start|>"},
|
||||
"system\n{{system}}"
|
||||
],
|
||||
prompt=[
|
||||
{"token": "<|im_start|>"},
|
||||
"user\n{{query}}",
|
||||
{"token": "<|im_end|>"},
|
||||
"\n",
|
||||
{"token": "<|im_start|>"},
|
||||
"assistant\n"
|
||||
],
|
||||
system="You are a helpful assistant.",
|
||||
sep=[
|
||||
{"token": "<|im_end|>"},
|
||||
"\n"
|
||||
],
|
||||
stop_words=[
|
||||
"<|im_end|>"
|
||||
],
|
||||
efficient_eos=True
|
||||
)
|
||||
|
||||
|
||||
r"""
|
||||
Supports: https://huggingface.co/HuggingFaceH4/starchat-alpha
|
||||
https://huggingface.co/HuggingFaceH4/starchat-beta
|
||||
"""
|
||||
register_template(
|
||||
name="starchat",
|
||||
prefix=[
|
||||
{"token": "<|system|>"},
|
||||
"\n{{system}}",
|
||||
],
|
||||
prompt=[
|
||||
{"token": "<|user|>"},
|
||||
"\n{{query}}",
|
||||
{"token": "<|end|>"},
|
||||
"\n",
|
||||
{"token": "<|assistant|>"}
|
||||
],
|
||||
system="",
|
||||
sep=[
|
||||
{"token": "<|end|>"},
|
||||
"\n"
|
||||
],
|
||||
stop_words=[
|
||||
"<|end|>"
|
||||
],
|
||||
efficient_eos=True
|
||||
)
|
||||
|
||||
|
||||
r"""
|
||||
Supports language model inference without histories.
|
||||
"""
|
||||
register_template(
|
||||
name="vanilla",
|
||||
prefix=[],
|
||||
prompt=[
|
||||
"{{query}}"
|
||||
],
|
||||
system="",
|
||||
sep=[],
|
||||
use_history=False
|
||||
)
|
||||
|
||||
|
||||
r"""
|
||||
Supports: https://huggingface.co/lmsys/vicuna-7b-v1.5
|
||||
https://huggingface.co/lmsys/vicuna-13b-v1.5
|
||||
"""
|
||||
register_template(
|
||||
name="vicuna",
|
||||
prefix=[
|
||||
"{{system}}"
|
||||
],
|
||||
prompt=[
|
||||
"USER: {{query}} ASSISTANT:"
|
||||
],
|
||||
system=(
|
||||
"A chat between a curious user and an artificial intelligence assistant. "
|
||||
"The assistant gives helpful, detailed, and polite answers to the user's questions."
|
||||
),
|
||||
sep=[]
|
||||
)
|
||||
|
||||
|
||||
r"""
|
||||
Supports: https://huggingface.co/xverse/XVERSE-7B-Chat
|
||||
https://huggingface.co/xverse/XVERSE-13B-Chat
|
||||
"""
|
||||
register_template(
|
||||
name="xverse",
|
||||
prefix=[
|
||||
"{{system}}"
|
||||
],
|
||||
prompt=[
|
||||
"Human: {{query}}\n\nAssistant: "
|
||||
],
|
||||
system="",
|
||||
sep=[]
|
||||
)
|
||||
|
||||
|
||||
r"""
|
||||
Supports: https://huggingface.co/wenge-research/yayi-7b
|
||||
https://huggingface.co/wenge-research/yayi-7b-llama2
|
||||
https://huggingface.co/wenge-research/yayi-13b-llama2
|
||||
"""
|
||||
register_template(
|
||||
name="yayi",
|
||||
prefix=[
|
||||
{"token": "<|System|>"},
|
||||
":\n{{system}}"
|
||||
],
|
||||
prompt=[
|
||||
{"token": "<|Human|>"},
|
||||
":\n{{query}}\n\n",
|
||||
{"token": "<|YaYi|>"},
|
||||
":"
|
||||
],
|
||||
system=(
|
||||
"You are a helpful, respectful and honest assistant named YaYi "
|
||||
"developed by Beijing Wenge Technology Co.,Ltd. "
|
||||
"Always answer as helpfully as possible, while being safe. "
|
||||
"Your answers should not include any harmful, unethical, "
|
||||
"racist, sexist, toxic, dangerous, or illegal content. "
|
||||
"Please ensure that your responses are socially unbiased and positive in nature.\n\n"
|
||||
"If a question does not make any sense, or is not factually coherent, "
|
||||
"explain why instead of answering something not correct. "
|
||||
"If you don't know the answer to a question, please don't share false information."
|
||||
),
|
||||
sep=[
|
||||
"\n\n"
|
||||
],
|
||||
stop_words=[
|
||||
"<|End|>"
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
r"""
|
||||
Supports: https://huggingface.co/HuggingFaceH4/zephyr-7b-alpha
|
||||
https://huggingface.co/HuggingFaceH4/zephyr-7b-beta
|
||||
"""
|
||||
register_template(
|
||||
name="zephyr",
|
||||
prefix=[
|
||||
{"token": "<|system|>"},
|
||||
"\n{{system}}",
|
||||
{"token": "</s>"}
|
||||
],
|
||||
prompt=[
|
||||
{"token": "<|user|>"},
|
||||
"\n{{query}}",
|
||||
{"token": "</s>"},
|
||||
{"token": "<|assistant|>"}
|
||||
],
|
||||
system="You are a friendly chatbot who always responds in the style of a pirate",
|
||||
sep=[]
|
||||
)
|
||||
|
||||
|
||||
r"""
|
||||
Supports: https://huggingface.co/IDEA-CCNL/Ziya-LLaMA-13B-v1
|
||||
https://huggingface.co/IDEA-CCNL/Ziya-LLaMA-13B-v1.1
|
||||
https://huggingface.co/IDEA-CCNL/Ziya2-13B-Chat
|
||||
"""
|
||||
register_template(
|
||||
name="ziya",
|
||||
prefix=[
|
||||
"{{system}}"
|
||||
],
|
||||
prompt=[
|
||||
{"token": "<human>"},
|
||||
":{{query}}\n",
|
||||
{"token": "<bot>"},
|
||||
":"
|
||||
],
|
||||
system="",
|
||||
sep=[
|
||||
"\n"
|
||||
]
|
||||
)
|
||||
@@ -3,3 +3,16 @@ from .evaluation_args import EvaluationArguments
|
||||
from .finetuning_args import FinetuningArguments
|
||||
from .generating_args import GeneratingArguments
|
||||
from .model_args import ModelArguments
|
||||
from .parser import get_eval_args, get_infer_args, get_train_args
|
||||
|
||||
|
||||
__all__ = [
|
||||
"DataArguments",
|
||||
"EvaluationArguments",
|
||||
"FinetuningArguments",
|
||||
"GeneratingArguments",
|
||||
"ModelArguments",
|
||||
"get_eval_args",
|
||||
"get_infer_args",
|
||||
"get_train_args",
|
||||
]
|
||||
|
||||
@@ -1,30 +1,5 @@
|
||||
import os
|
||||
import json
|
||||
from typing import List, Literal, Optional
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
|
||||
@dataclass
|
||||
class DatasetAttr:
|
||||
|
||||
load_from: str
|
||||
dataset_name: Optional[str] = None
|
||||
dataset_sha1: Optional[str] = None
|
||||
system_prompt: Optional[str] = None
|
||||
subset: Optional[str] = None
|
||||
ranking: Optional[bool] = False
|
||||
formatting: Optional[Literal["alpaca", "sharegpt"]] = "alpaca"
|
||||
|
||||
prompt: Optional[str] = "instruction"
|
||||
query: Optional[str] = "input"
|
||||
response: Optional[str] = "output"
|
||||
history: Optional[str] = None
|
||||
messages: Optional[str] = "conversations"
|
||||
role: Optional[str] = "from"
|
||||
content: Optional[str] = "value"
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return self.dataset_name
|
||||
from typing import Literal, Optional
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -33,137 +8,74 @@ class DataArguments:
|
||||
Arguments pertaining to what data we are going to input our model for training and evaluation.
|
||||
"""
|
||||
template: Optional[str] = field(
|
||||
default=None,
|
||||
metadata={"help": "Which template to use for constructing prompts in training and inference."}
|
||||
default=None, metadata={"help": "Which template to use for constructing prompts in training and inference."}
|
||||
)
|
||||
dataset: Optional[str] = field(
|
||||
default=None,
|
||||
metadata={"help": "The name of provided dataset(s) to use. Use commas to separate multiple datasets."}
|
||||
metadata={"help": "The name of provided dataset(s) to use. Use commas to separate multiple datasets."},
|
||||
)
|
||||
dataset_dir: Optional[str] = field(
|
||||
default="data",
|
||||
metadata={"help": "Path to the folder containing the datasets."}
|
||||
default="data", metadata={"help": "Path to the folder containing the datasets."}
|
||||
)
|
||||
split: Optional[str] = field(
|
||||
default="train",
|
||||
metadata={"help": "Which dataset split to use for training and evaluation."}
|
||||
default="train", metadata={"help": "Which dataset split to use for training and evaluation."}
|
||||
)
|
||||
cutoff_len: Optional[int] = field(
|
||||
default=1024,
|
||||
metadata={"help": "The maximum length of the model inputs after tokenization."}
|
||||
default=1024, metadata={"help": "The maximum length of the model inputs after tokenization."}
|
||||
)
|
||||
reserved_label_len: Optional[int] = field(
|
||||
default=1, metadata={"help": "The maximum length reserved for label after tokenization."}
|
||||
)
|
||||
train_on_prompt: Optional[bool] = field(
|
||||
default=False,
|
||||
metadata={"help": "Whether to disable the mask on the prompt or not."}
|
||||
)
|
||||
streaming: Optional[bool] = field(
|
||||
default=False,
|
||||
metadata={"help": "Enable dataset streaming."}
|
||||
default=False, metadata={"help": "Whether to disable the mask on the prompt or not."}
|
||||
)
|
||||
streaming: Optional[bool] = field(default=False, metadata={"help": "Enable dataset streaming."})
|
||||
buffer_size: Optional[int] = field(
|
||||
default=16384,
|
||||
metadata={"help": "Size of the buffer to randomly sample examples from in dataset streaming."}
|
||||
default=16384, metadata={"help": "Size of the buffer to randomly sample examples from in dataset streaming."}
|
||||
)
|
||||
mix_strategy: Optional[Literal["concat", "interleave_under", "interleave_over"]] = field(
|
||||
default="concat",
|
||||
metadata={"help": "Strategy to use in dataset mixing (concat/interleave) (undersampling/oversampling)."}
|
||||
metadata={"help": "Strategy to use in dataset mixing (concat/interleave) (undersampling/oversampling)."},
|
||||
)
|
||||
interleave_probs: Optional[str] = field(
|
||||
default=None,
|
||||
metadata={"help": "Probabilities to sample data from datasets. Use commas to separate multiple datasets."}
|
||||
metadata={"help": "Probabilities to sample data from datasets. Use commas to separate multiple datasets."},
|
||||
)
|
||||
overwrite_cache: Optional[bool] = field(
|
||||
default=False,
|
||||
metadata={"help": "Overwrite the cached training and evaluation sets."}
|
||||
default=False, metadata={"help": "Overwrite the cached training and evaluation sets."}
|
||||
)
|
||||
preprocessing_num_workers: Optional[int] = field(
|
||||
default=None,
|
||||
metadata={"help": "The number of processes to use for the preprocessing."}
|
||||
default=None, metadata={"help": "The number of processes to use for the preprocessing."}
|
||||
)
|
||||
max_samples: Optional[int] = field(
|
||||
default=None,
|
||||
metadata={"help": "For debugging purposes, truncate the number of examples for each dataset."}
|
||||
default=None, metadata={"help": "For debugging purposes, truncate the number of examples for each dataset."}
|
||||
)
|
||||
eval_num_beams: Optional[int] = field(
|
||||
default=None,
|
||||
metadata={"help": "Number of beams to use for evaluation. This argument will be passed to `model.generate`"}
|
||||
metadata={"help": "Number of beams to use for evaluation. This argument will be passed to `model.generate`"},
|
||||
)
|
||||
ignore_pad_token_for_loss: Optional[bool] = field(
|
||||
default=True,
|
||||
metadata={"help": "Whether to ignore the tokens corresponding to padded labels in the loss computation or not."}
|
||||
)
|
||||
system_prompt: Optional[str] = field(
|
||||
default=None,
|
||||
metadata={"help": "System prompt to add before the user query. Use `|` to separate multiple prompts in training."}
|
||||
metadata={
|
||||
"help": "Whether to ignore the tokens corresponding to padded labels in the loss computation or not."
|
||||
},
|
||||
)
|
||||
val_size: Optional[float] = field(
|
||||
default=0,
|
||||
metadata={"help": "Size of the development set, should be an integer or a float in range `[0,1)`."}
|
||||
default=0, metadata={"help": "Size of the development set, should be an integer or a float in range `[0,1)`."}
|
||||
)
|
||||
sft_packing: Optional[bool] = field(
|
||||
default=False,
|
||||
metadata={"help": "Packing the questions and answers in the supervised fine-tuning stage."}
|
||||
default=False, metadata={"help": "Packing the questions and answers in the supervised fine-tuning stage."}
|
||||
)
|
||||
cache_path: Optional[str] = field(
|
||||
default=None,
|
||||
metadata={"help": "Path to save or load the preprocessed datasets."}
|
||||
default=None, metadata={"help": "Path to save or load the preprocessed datasets."}
|
||||
)
|
||||
|
||||
def __post_init__(self):
|
||||
if self.reserved_label_len >= self.cutoff_len:
|
||||
raise ValueError("`reserved_label_len` must be smaller than `cutoff_len`.")
|
||||
|
||||
if self.streaming and self.val_size > 1e-6 and self.val_size < 1:
|
||||
raise ValueError("Streaming mode should have an integer val size.")
|
||||
|
||||
if self.streaming and self.max_samples is not None:
|
||||
raise ValueError("`max_samples` is incompatible with `streaming`.")
|
||||
|
||||
if self.streaming and self.cache_path:
|
||||
raise ValueError("`cache_path` is incompatible with `streaming`.")
|
||||
|
||||
def init_for_training(self, seed: int): # support mixing multiple datasets
|
||||
self.seed = seed
|
||||
dataset_names = [ds.strip() for ds in self.dataset.split(",")] if self.dataset is not None else []
|
||||
try:
|
||||
with open(os.path.join(self.dataset_dir, "dataset_info.json"), "r") as f:
|
||||
dataset_info = json.load(f)
|
||||
except Exception:
|
||||
if self.dataset is not None:
|
||||
raise ValueError("Cannot find dataset_info.json in `dataset_dir`.")
|
||||
dataset_info = None
|
||||
|
||||
prompt_list = self.system_prompt.split("|") if self.system_prompt else [None]
|
||||
prompt_list = prompt_list * (len(dataset_names) // len(prompt_list))
|
||||
assert len(prompt_list) == len(dataset_names), "Number of system prompts should be equal to datasets or 1."
|
||||
|
||||
if self.interleave_probs is not None:
|
||||
self.interleave_probs = [float(prob.strip()) for prob in self.interleave_probs.split(",")]
|
||||
|
||||
self.dataset_list: List[DatasetAttr] = []
|
||||
for i, name in enumerate(dataset_names):
|
||||
if name not in dataset_info:
|
||||
raise ValueError("Undefined dataset {} in dataset_info.json.".format(name))
|
||||
|
||||
if "hf_hub_url" in dataset_info[name]:
|
||||
dataset_attr = DatasetAttr("hf_hub", dataset_name=dataset_info[name]["hf_hub_url"])
|
||||
elif "script_url" in dataset_info[name]:
|
||||
dataset_attr = DatasetAttr("script", dataset_name=dataset_info[name]["script_url"])
|
||||
else:
|
||||
dataset_attr = DatasetAttr(
|
||||
"file",
|
||||
dataset_name=dataset_info[name]["file_name"],
|
||||
dataset_sha1=dataset_info[name].get("file_sha1", None)
|
||||
)
|
||||
|
||||
if "columns" in dataset_info[name]:
|
||||
dataset_attr.prompt = dataset_info[name]["columns"].get("prompt", None)
|
||||
dataset_attr.query = dataset_info[name]["columns"].get("query", None)
|
||||
dataset_attr.response = dataset_info[name]["columns"].get("response", None)
|
||||
dataset_attr.history = dataset_info[name]["columns"].get("history", None)
|
||||
dataset_attr.messages = dataset_info[name]["columns"].get("messages", None)
|
||||
dataset_attr.role = dataset_info[name]["columns"].get("role", None)
|
||||
dataset_attr.content = dataset_info[name]["columns"].get("content", None)
|
||||
|
||||
dataset_attr.subset = dataset_info[name].get("subset", None)
|
||||
dataset_attr.ranking = dataset_info[name].get("ranking", False)
|
||||
dataset_attr.formatting = dataset_info[name].get("formatting", "alpaca")
|
||||
dataset_attr.system_prompt = prompt_list[i]
|
||||
self.dataset_list.append(dataset_attr)
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import os
|
||||
from typing import Literal, Optional
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Literal, Optional
|
||||
|
||||
from datasets import DownloadMode
|
||||
|
||||
@@ -10,46 +10,20 @@ class EvaluationArguments:
|
||||
r"""
|
||||
Arguments pertaining to specify the evaluation parameters.
|
||||
"""
|
||||
task: str = field(
|
||||
metadata={"help": "Name of the evaluation task."}
|
||||
)
|
||||
task: str = field(metadata={"help": "Name of the evaluation task."})
|
||||
task_dir: Optional[str] = field(
|
||||
default="evaluation",
|
||||
metadata={"help": "Path to the folder containing the evaluation datasets."}
|
||||
)
|
||||
batch_size: Optional[int] = field(
|
||||
default=4,
|
||||
metadata={"help": "The batch size per GPU for evaluation."}
|
||||
)
|
||||
seed: Optional[int] = field(
|
||||
default=42,
|
||||
metadata={"help": "Random seed to be used with data loaders."}
|
||||
)
|
||||
lang: Optional[Literal["en", "zh"]] = field(
|
||||
default="en",
|
||||
metadata={"help": "Language used at evaluation."}
|
||||
)
|
||||
n_shot: Optional[int] = field(
|
||||
default=5,
|
||||
metadata={"help": "Number of examplars for few-shot learning."}
|
||||
)
|
||||
save_dir: Optional[str] = field(
|
||||
default=None,
|
||||
metadata={"help": "Path to save the evaluation results."}
|
||||
default="evaluation", metadata={"help": "Path to the folder containing the evaluation datasets."}
|
||||
)
|
||||
batch_size: Optional[int] = field(default=4, metadata={"help": "The batch size per GPU for evaluation."})
|
||||
seed: Optional[int] = field(default=42, metadata={"help": "Random seed to be used with data loaders."})
|
||||
lang: Optional[Literal["en", "zh"]] = field(default="en", metadata={"help": "Language used at evaluation."})
|
||||
n_shot: Optional[int] = field(default=5, metadata={"help": "Number of examplars for few-shot learning."})
|
||||
save_dir: Optional[str] = field(default=None, metadata={"help": "Path to save the evaluation results."})
|
||||
download_mode: Optional[DownloadMode] = field(
|
||||
default=DownloadMode.REUSE_DATASET_IF_EXISTS,
|
||||
metadata={"help": "Download mode used for the evaluation datasets."}
|
||||
metadata={"help": "Download mode used for the evaluation datasets."},
|
||||
)
|
||||
|
||||
def __post_init__(self):
|
||||
task_available = []
|
||||
for folder in os.listdir(self.task_dir):
|
||||
if os.path.isdir(os.path.join(self.task_dir, folder)):
|
||||
task_available.append(folder)
|
||||
|
||||
if self.task not in task_available:
|
||||
raise ValueError("Task {} not found in {}.".format(self.task, self.task_dir))
|
||||
|
||||
if self.save_dir is not None and os.path.exists(self.save_dir):
|
||||
raise ValueError("`save_dir` already exists, use another one.")
|
||||
|
||||
@@ -1,105 +1,156 @@
|
||||
import json
|
||||
from typing import Literal, Optional
|
||||
from dataclasses import asdict, dataclass, field
|
||||
from typing import Literal, Optional
|
||||
|
||||
|
||||
@dataclass
|
||||
class FinetuningArguments:
|
||||
class FreezeArguments:
|
||||
r"""
|
||||
Arguments pertaining to the freeze (partial-parameter) training.
|
||||
"""
|
||||
name_module_trainable: Optional[str] = field(
|
||||
default="mlp",
|
||||
metadata={
|
||||
"help": 'Name of trainable modules for partial-parameter (freeze) fine-tuning. \
|
||||
Use commas to separate multiple modules. \
|
||||
LLaMA choices: ["mlp", "self_attn"], \
|
||||
BLOOM & Falcon & ChatGLM choices: ["mlp", "self_attention"], \
|
||||
Qwen choices: ["mlp", "attn"], \
|
||||
Phi choices: ["mlp", "mixer"], \
|
||||
Others choices: the same as LLaMA.'
|
||||
},
|
||||
)
|
||||
num_layer_trainable: Optional[int] = field(
|
||||
default=3, metadata={"help": "The number of trainable layers for partial-parameter (freeze) fine-tuning."}
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class LoraArguments:
|
||||
r"""
|
||||
Arguments pertaining to the LoRA training.
|
||||
"""
|
||||
additional_target: Optional[str] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "Name(s) of modules apart from LoRA layers to be set as trainable and saved in the final checkpoint."
|
||||
},
|
||||
)
|
||||
lora_alpha: Optional[int] = field(
|
||||
default=None, metadata={"help": "The scale factor for LoRA fine-tuning (default: lora_rank * 2)."}
|
||||
)
|
||||
lora_dropout: Optional[float] = field(default=0.0, metadata={"help": "Dropout rate for the LoRA fine-tuning."})
|
||||
lora_rank: Optional[int] = field(default=8, metadata={"help": "The intrinsic dimension for LoRA fine-tuning."})
|
||||
lora_target: Optional[str] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": 'Name(s) of target modules to apply LoRA. Use commas to separate multiple modules. \
|
||||
LLaMA choices: ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"], \
|
||||
BLOOM & Falcon & ChatGLM choices: ["query_key_value", "dense", "dense_h_to_4h", "dense_4h_to_h"], \
|
||||
Baichuan choices: ["W_pack", "o_proj", "gate_proj", "up_proj", "down_proj"], \
|
||||
Qwen choices: ["c_attn", "attn.c_proj", "w1", "w2", "mlp.c_proj"], \
|
||||
Phi choices: ["Wqkv", "out_proj", "fc1", "fc2"], \
|
||||
Others choices: the same as LLaMA.'
|
||||
},
|
||||
)
|
||||
lora_bf16_mode: Optional[bool] = field(
|
||||
default=False, metadata={"help": "Whether or not to train lora adapters in bf16 precision."}
|
||||
)
|
||||
create_new_adapter: Optional[bool] = field(
|
||||
default=False, metadata={"help": "Whether or not to create a new adapter with randomly initialized weight."}
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class RLHFArguments:
|
||||
r"""
|
||||
Arguments pertaining to the PPO and DPO training.
|
||||
"""
|
||||
dpo_beta: Optional[float] = field(default=0.1, metadata={"help": "The beta parameter for the DPO loss."})
|
||||
dpo_loss: Optional[Literal["sigmoid", "hinge", "ipo", "kto"]] = field(
|
||||
default="sigmoid", metadata={"help": "The type of DPO loss to use."}
|
||||
)
|
||||
dpo_ftx: Optional[float] = field(
|
||||
default=0, metadata={"help": "The supervised fine-tuning loss coefficient in DPO training."}
|
||||
)
|
||||
ppo_buffer_size: Optional[int] = field(
|
||||
default=1,
|
||||
metadata={"help": "The number of mini-batches to make experience buffer in a PPO optimization step."},
|
||||
)
|
||||
ppo_epochs: Optional[int] = field(
|
||||
default=4, metadata={"help": "The number of epochs to perform in a PPO optimization step."}
|
||||
)
|
||||
ppo_logger: Optional[str] = field(
|
||||
default=None, metadata={"help": 'Log with either "wandb" or "tensorboard" in PPO training.'}
|
||||
)
|
||||
ppo_score_norm: Optional[bool] = field(
|
||||
default=False, metadata={"help": "Use score normalization in PPO training."}
|
||||
)
|
||||
ppo_target: Optional[float] = field(
|
||||
default=6.0, metadata={"help": "Target KL value for adaptive KL control in PPO training."}
|
||||
)
|
||||
ppo_whiten_rewards: Optional[bool] = field(
|
||||
default=False, metadata={"help": "Whiten the rewards before compute advantages in PPO training."}
|
||||
)
|
||||
ref_model: Optional[str] = field(
|
||||
default=None, metadata={"help": "Path to the reference model used for the PPO or DPO training."}
|
||||
)
|
||||
ref_model_adapters: Optional[str] = field(
|
||||
default=None, metadata={"help": "Path to the adapters of the reference model."}
|
||||
)
|
||||
ref_model_quantization_bit: Optional[int] = field(
|
||||
default=None, metadata={"help": "The number of bits to quantize the reference model."}
|
||||
)
|
||||
reward_model: Optional[str] = field(
|
||||
default=None, metadata={"help": "Path to the reward model used for the PPO training."}
|
||||
)
|
||||
reward_model_adapters: Optional[str] = field(
|
||||
default=None, metadata={"help": "Path to the adapters of the reward model."}
|
||||
)
|
||||
reward_model_quantization_bit: Optional[int] = field(
|
||||
default=None, metadata={"help": "The number of bits to quantize the reward model."}
|
||||
)
|
||||
reward_model_type: Optional[Literal["lora", "full", "api"]] = field(
|
||||
default="lora",
|
||||
metadata={"help": "The type of the reward model in PPO training. Lora model only supports lora training."},
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class FinetuningArguments(FreezeArguments, LoraArguments, RLHFArguments):
|
||||
r"""
|
||||
Arguments pertaining to which techniques we are going to fine-tuning with.
|
||||
"""
|
||||
stage: Optional[Literal["pt", "sft", "rm", "ppo", "dpo"]] = field(
|
||||
default="sft",
|
||||
metadata={"help": "Which stage will be performed in training."}
|
||||
default="sft", metadata={"help": "Which stage will be performed in training."}
|
||||
)
|
||||
finetuning_type: Optional[Literal["lora", "freeze", "full"]] = field(
|
||||
default="lora",
|
||||
metadata={"help": "Which fine-tuning method to use."}
|
||||
default="lora", metadata={"help": "Which fine-tuning method to use."}
|
||||
)
|
||||
num_layer_trainable: Optional[int] = field(
|
||||
default=3,
|
||||
metadata={"help": "Number of trainable layers for partial-parameter (freeze) fine-tuning."}
|
||||
)
|
||||
name_module_trainable: Optional[Literal["mlp", "self_attn", "self_attention"]] = field(
|
||||
default="mlp",
|
||||
metadata={"help": "Name of trainable modules for partial-parameter (freeze) fine-tuning. \
|
||||
LLaMA choices: [\"mlp\", \"self_attn\"], \
|
||||
BLOOM & Falcon & ChatGLM choices: [\"mlp\", \"self_attention\"], \
|
||||
Qwen choices: [\"mlp\", \"attn\"], \
|
||||
Phi-1.5 choices: [\"mlp\", \"mixer\"], \
|
||||
LLaMA-2, BlueLM, Baichuan, InternLM, Mistral, Skywork, XVERSE, Yi choices: the same as LLaMA."}
|
||||
)
|
||||
lora_rank: Optional[int] = field(
|
||||
default=8,
|
||||
metadata={"help": "The intrinsic dimension for LoRA fine-tuning."}
|
||||
)
|
||||
lora_alpha: Optional[float] = field(
|
||||
default=32.0,
|
||||
metadata={"help": "The scale factor for LoRA fine-tuning (similar with the learning rate)."}
|
||||
)
|
||||
lora_dropout: Optional[float] = field(
|
||||
default=0.1,
|
||||
metadata={"help": "Dropout rate for the LoRA fine-tuning."}
|
||||
)
|
||||
lora_target: Optional[str] = field(
|
||||
default=None,
|
||||
metadata={"help": "Name(s) of target modules to apply LoRA. Use commas to separate multiple modules. \
|
||||
LLaMA choices: [\"q_proj\", \"k_proj\", \"v_proj\", \"o_proj\", \"gate_proj\", \"up_proj\", \"down_proj\"], \
|
||||
BLOOM & Falcon & ChatGLM choices: [\"query_key_value\", \"dense\", \"dense_h_to_4h\", \"dense_4h_to_h\"], \
|
||||
Baichuan choices: [\"W_pack\", \"o_proj\", \"gate_proj\", \"up_proj\", \"down_proj\"], \
|
||||
Qwen choices: [\"c_attn\", \"attn.c_proj\", \"w1\", \"w2\", \"mlp.c_proj\"], \
|
||||
Phi-1.5 choices: [\"Wqkv\", \"out_proj\", \"fc1\", \"fc2\"], \
|
||||
LLaMA-2, BlueLM, InternLM, Mistral, Skywork, XVERSE, Yi choices: the same as LLaMA."}
|
||||
)
|
||||
additional_target: Optional[str] = field(
|
||||
default=None,
|
||||
metadata={"help": "Name(s) of modules apart from LoRA layers to be set as trainable and saved in the final checkpoint."}
|
||||
)
|
||||
resume_lora_training: Optional[bool] = field(
|
||||
default=True,
|
||||
metadata={"help": "Whether to resume training from the last LoRA weights or create new weights after merging them."}
|
||||
)
|
||||
ppo_score_norm: Optional[bool] = field(
|
||||
default=False,
|
||||
metadata={"help": "Use score normalization in PPO training."}
|
||||
)
|
||||
ppo_logger: Optional[str] = field(
|
||||
default=None,
|
||||
metadata={"help": "Log with either 'wandb' or 'tensorboard' in PPO training."}
|
||||
)
|
||||
ppo_target: Optional[float] = field(
|
||||
default=6.0,
|
||||
metadata={"help": "Target KL value for adaptive KL control in PPO training."}
|
||||
)
|
||||
dpo_beta: Optional[float] = field(
|
||||
default=0.1,
|
||||
metadata={"help": "The beta parameter for the DPO loss."}
|
||||
)
|
||||
dpo_ref_model: Optional[str] = field(
|
||||
default=None,
|
||||
metadata={"help": "Path to the reference model used for the DPO training."}
|
||||
)
|
||||
dpo_ref_model_checkpoint: Optional[str] = field(
|
||||
default=None,
|
||||
metadata={"help": "Path to the directory(s) containing the model checkpoints of the reference model."}
|
||||
)
|
||||
upcast_layernorm: Optional[bool] = field(
|
||||
default=False,
|
||||
metadata={"help": "Whether to upcast the layernorm weights in fp32."}
|
||||
)
|
||||
neft_alpha: Optional[float] = field(
|
||||
default=0,
|
||||
metadata={"help": "The alpha parameter to control the noise magnitude in NEFTune."}
|
||||
plot_loss: Optional[bool] = field(
|
||||
default=False, metadata={"help": "Whether or not to save the training loss curves."}
|
||||
)
|
||||
|
||||
def __post_init__(self):
|
||||
if isinstance(self.lora_target, str): # support custom target modules/layers of LoRA
|
||||
self.lora_target = [target.strip() for target in self.lora_target.split(",")]
|
||||
def split_arg(arg):
|
||||
if isinstance(arg, str):
|
||||
return [item.strip() for item in arg.split(",")]
|
||||
return arg
|
||||
|
||||
if isinstance(self.additional_target, str):
|
||||
self.additional_target = [target.strip() for target in self.additional_target.split(",")]
|
||||
self.name_module_trainable = split_arg(self.name_module_trainable)
|
||||
self.lora_alpha = self.lora_alpha or self.lora_rank * 2
|
||||
self.lora_target = split_arg(self.lora_target)
|
||||
self.additional_target = split_arg(self.additional_target)
|
||||
|
||||
assert self.finetuning_type in ["lora", "freeze", "full"], "Invalid fine-tuning method."
|
||||
assert self.ref_model_quantization_bit in [None, 8, 4], "We only accept 4-bit or 8-bit quantization."
|
||||
assert self.reward_model_quantization_bit in [None, 8, 4], "We only accept 4-bit or 8-bit quantization."
|
||||
|
||||
if self.stage == "ppo" and self.reward_model is None:
|
||||
raise ValueError("Reward model is necessary for PPO training.")
|
||||
|
||||
if self.stage == "ppo" and self.reward_model_type == "lora" and self.finetuning_type != "lora":
|
||||
raise ValueError("Freeze/Full PPO training needs `reward_model_type=full`.")
|
||||
|
||||
def save_to_json(self, json_path: str):
|
||||
r"""Saves the content of this instance in JSON format inside `json_path`."""
|
||||
@@ -112,4 +163,5 @@ class FinetuningArguments:
|
||||
r"""Creates an instance from the content of `json_path`."""
|
||||
with open(json_path, "r", encoding="utf-8") as f:
|
||||
text = f.read()
|
||||
|
||||
return cls(**json.loads(text))
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
from typing import Any, Dict, Optional
|
||||
from dataclasses import asdict, dataclass, field
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -8,40 +8,37 @@ class GeneratingArguments:
|
||||
Arguments pertaining to specify the decoding parameters.
|
||||
"""
|
||||
do_sample: Optional[bool] = field(
|
||||
default=True,
|
||||
metadata={"help": "Whether or not to use sampling, use greedy decoding otherwise."}
|
||||
default=True, metadata={"help": "Whether or not to use sampling, use greedy decoding otherwise."}
|
||||
)
|
||||
temperature: Optional[float] = field(
|
||||
default=0.95,
|
||||
metadata={"help": "The value used to modulate the next token probabilities."}
|
||||
default=0.95, metadata={"help": "The value used to modulate the next token probabilities."}
|
||||
)
|
||||
top_p: Optional[float] = field(
|
||||
default=0.7,
|
||||
metadata={"help": "The smallest set of most probable tokens with probabilities that add up to top_p or higher are kept."}
|
||||
metadata={
|
||||
"help": "The smallest set of most probable tokens with probabilities that add up to top_p or higher are kept."
|
||||
},
|
||||
)
|
||||
top_k: Optional[int] = field(
|
||||
default=50,
|
||||
metadata={"help": "The number of highest probability vocabulary tokens to keep for top-k filtering."}
|
||||
metadata={"help": "The number of highest probability vocabulary tokens to keep for top-k filtering."},
|
||||
)
|
||||
num_beams: Optional[int] = field(
|
||||
default=1,
|
||||
metadata={"help": "Number of beams for beam search. 1 means no beam search."}
|
||||
default=1, metadata={"help": "Number of beams for beam search. 1 means no beam search."}
|
||||
)
|
||||
max_length: Optional[int] = field(
|
||||
default=512,
|
||||
metadata={"help": "The maximum length the generated tokens can have. It can be overridden by max_new_tokens."}
|
||||
metadata={"help": "The maximum length the generated tokens can have. It can be overridden by max_new_tokens."},
|
||||
)
|
||||
max_new_tokens: Optional[int] = field(
|
||||
default=512,
|
||||
metadata={"help": "The maximum numbers of tokens to generate, ignoring the number of tokens in the prompt."}
|
||||
metadata={"help": "The maximum numbers of tokens to generate, ignoring the number of tokens in the prompt."},
|
||||
)
|
||||
repetition_penalty: Optional[float] = field(
|
||||
default=1.0,
|
||||
metadata={"help": "The parameter for repetition penalty. 1.0 means no penalty."}
|
||||
default=1.0, metadata={"help": "The parameter for repetition penalty. 1.0 means no penalty."}
|
||||
)
|
||||
length_penalty: Optional[float] = field(
|
||||
default=1.0,
|
||||
metadata={"help": "Exponential penalty to the length that is used with beam-based generation."}
|
||||
default=1.0, metadata={"help": "Exponential penalty to the length that is used with beam-based generation."}
|
||||
)
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
from typing import Any, Dict, Literal, Optional
|
||||
from dataclasses import asdict, dataclass, field
|
||||
from typing import Any, Dict, Literal, Optional
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -8,67 +8,85 @@ class ModelArguments:
|
||||
Arguments pertaining to which model/config/tokenizer we are going to fine-tune.
|
||||
"""
|
||||
model_name_or_path: str = field(
|
||||
metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models."}
|
||||
metadata={"help": "Path to the model weight or identifier from huggingface.co/models or modelscope.cn/models."}
|
||||
)
|
||||
adapter_name_or_path: Optional[str] = field(
|
||||
default=None, metadata={"help": "Path to the adapter weight or identifier from huggingface.co/models."}
|
||||
)
|
||||
cache_dir: Optional[str] = field(
|
||||
default=None,
|
||||
metadata={"help": "Where to store the pretrained models downloaded from huggingface.co."}
|
||||
metadata={"help": "Where to store the pre-trained models downloaded from huggingface.co or modelscope.cn."},
|
||||
)
|
||||
use_fast_tokenizer: Optional[bool] = field(
|
||||
default=True,
|
||||
metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."}
|
||||
default=False,
|
||||
metadata={"help": "Whether or not to use one of the fast tokenizer (backed by the tokenizers library)."},
|
||||
)
|
||||
resize_vocab: Optional[bool] = field(
|
||||
default=False, metadata={"help": "Whether or not to resize the tokenizer vocab and the embedding layers."}
|
||||
)
|
||||
split_special_tokens: Optional[bool] = field(
|
||||
default=False,
|
||||
metadata={"help": "Whether or not the special tokens should be split during the tokenization process."}
|
||||
metadata={"help": "Whether or not the special tokens should be split during the tokenization process."},
|
||||
)
|
||||
model_revision: Optional[str] = field(
|
||||
default="main",
|
||||
metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."}
|
||||
metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."},
|
||||
)
|
||||
quantization_bit: Optional[int] = field(
|
||||
default=None,
|
||||
metadata={"help": "The number of bits to quantize the model."}
|
||||
default=None, metadata={"help": "The number of bits to quantize the model."}
|
||||
)
|
||||
quantization_type: Optional[Literal["fp4", "nf4"]] = field(
|
||||
default="nf4",
|
||||
metadata={"help": "Quantization data type to use in int4 training."}
|
||||
default="nf4", metadata={"help": "Quantization data type to use in int4 training."}
|
||||
)
|
||||
double_quantization: Optional[bool] = field(
|
||||
default=True,
|
||||
metadata={"help": "Whether to use double quantization in int4 training or not."}
|
||||
default=True, metadata={"help": "Whether or not to use double quantization in int4 training."}
|
||||
)
|
||||
rope_scaling: Optional[Literal["linear", "dynamic"]] = field(
|
||||
default=None,
|
||||
metadata={"help": "Adopt scaled rotary positional embeddings."}
|
||||
)
|
||||
checkpoint_dir: Optional[str] = field(
|
||||
default=None,
|
||||
metadata={"help": "Path to the directory(s) containing the model checkpoints as well as the configurations."}
|
||||
default=None, metadata={"help": "Which scaling strategy should be adopted for the RoPE embeddings."}
|
||||
)
|
||||
flash_attn: Optional[bool] = field(
|
||||
default=False,
|
||||
metadata={"help": "Enable FlashAttention-2 for faster training."}
|
||||
default=False, metadata={"help": "Enable FlashAttention-2 for faster training."}
|
||||
)
|
||||
shift_attn: Optional[bool] = field(
|
||||
default=False,
|
||||
metadata={"help": "Enable shift short attention (S^2-Attn) proposed by LongLoRA."}
|
||||
default=False, metadata={"help": "Enable shift short attention (S^2-Attn) proposed by LongLoRA."}
|
||||
)
|
||||
reward_model: Optional[str] = field( # TODO: move it to FinetuningArguments
|
||||
default=None,
|
||||
metadata={"help": "Path to the directory containing the checkpoints of the reward model."}
|
||||
use_unsloth: Optional[bool] = field(
|
||||
default=False, metadata={"help": "Whether or not to use unsloth's optimization for the LoRA training."}
|
||||
)
|
||||
plot_loss: Optional[bool] = field( # TODO: move it to FinetuningArguments
|
||||
default=False,
|
||||
metadata={"help": "Whether to plot the training loss after fine-tuning or not."}
|
||||
disable_gradient_checkpointing: Optional[bool] = field(
|
||||
default=False, metadata={"help": "Whether or not to disable gradient checkpointing."}
|
||||
)
|
||||
hf_hub_token: Optional[str] = field(
|
||||
default=None,
|
||||
metadata={"help": "Auth token to log in with Hugging Face Hub."}
|
||||
upcast_layernorm: Optional[bool] = field(
|
||||
default=False, metadata={"help": "Whether or not to upcast the layernorm weights in fp32."}
|
||||
)
|
||||
upcast_lmhead_output: Optional[bool] = field(
|
||||
default=False, metadata={"help": "Whether or not to upcast the output of lm_head in fp32."}
|
||||
)
|
||||
hf_hub_token: Optional[str] = field(default=None, metadata={"help": "Auth token to log in with Hugging Face Hub."})
|
||||
ms_hub_token: Optional[str] = field(default=None, metadata={"help": "Auth token to log in with ModelScope Hub."})
|
||||
export_dir: Optional[str] = field(
|
||||
default=None,
|
||||
metadata={"help": "Path to the directory to save the exported model."}
|
||||
default=None, metadata={"help": "Path to the directory to save the exported model."}
|
||||
)
|
||||
export_size: Optional[int] = field(
|
||||
default=1, metadata={"help": "The file shard size (in GB) of the exported model."}
|
||||
)
|
||||
export_quantization_bit: Optional[int] = field(
|
||||
default=None, metadata={"help": "The number of bits to quantize the exported model."}
|
||||
)
|
||||
export_quantization_dataset: Optional[str] = field(
|
||||
default=None, metadata={"help": "Path to the dataset or dataset name to use in quantizing the exported model."}
|
||||
)
|
||||
export_quantization_nsamples: Optional[int] = field(
|
||||
default=128, metadata={"help": "The number of samples used for quantization."}
|
||||
)
|
||||
export_quantization_maxlen: Optional[int] = field(
|
||||
default=1024, metadata={"help": "The maximum length of the model inputs used for quantization."}
|
||||
)
|
||||
export_legacy_format: Optional[bool] = field(
|
||||
default=False, metadata={"help": "Whether or not to save the `.bin` files instead of `.safetensors`."}
|
||||
)
|
||||
export_hub_model_id: Optional[str] = field(
|
||||
default=None, metadata={"help": "The name of the repository if push the model to the Hugging Face hub."}
|
||||
)
|
||||
|
||||
def __post_init__(self):
|
||||
@@ -78,11 +96,14 @@ class ModelArguments:
|
||||
if self.split_special_tokens and self.use_fast_tokenizer:
|
||||
raise ValueError("`split_special_tokens` is only supported for slow tokenizers.")
|
||||
|
||||
if self.checkpoint_dir is not None: # support merging multiple lora weights
|
||||
self.checkpoint_dir = [cd.strip() for cd in self.checkpoint_dir.split(",")]
|
||||
if self.adapter_name_or_path is not None: # support merging multiple lora weights
|
||||
self.adapter_name_or_path = [path.strip() for path in self.adapter_name_or_path.split(",")]
|
||||
|
||||
if self.quantization_bit is not None:
|
||||
assert self.quantization_bit in [4, 8], "We only accept 4-bit or 8-bit quantization."
|
||||
assert self.quantization_bit in [None, 8, 4], "We only accept 4-bit or 8-bit quantization."
|
||||
assert self.export_quantization_bit in [None, 8, 4, 3, 2], "We only accept 2/3/4/8-bit quantization."
|
||||
|
||||
if self.export_quantization_bit is not None and self.export_quantization_dataset is None:
|
||||
raise ValueError("Quantization dataset is necessary for exporting.")
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
return asdict(self)
|
||||
|
||||
237
src/llmtuner/hparams/parser.py
Normal file
237
src/llmtuner/hparams/parser.py
Normal file
@@ -0,0 +1,237 @@
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
from typing import Any, Dict, Optional, Tuple
|
||||
|
||||
import datasets
|
||||
import torch
|
||||
import transformers
|
||||
from transformers import HfArgumentParser, Seq2SeqTrainingArguments
|
||||
from transformers.trainer_utils import get_last_checkpoint
|
||||
|
||||
from ..extras.logging import get_logger
|
||||
from .data_args import DataArguments
|
||||
from .evaluation_args import EvaluationArguments
|
||||
from .finetuning_args import FinetuningArguments
|
||||
from .generating_args import GeneratingArguments
|
||||
from .model_args import ModelArguments
|
||||
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
_TRAIN_ARGS = [ModelArguments, DataArguments, Seq2SeqTrainingArguments, FinetuningArguments, GeneratingArguments]
|
||||
_TRAIN_CLS = Tuple[ModelArguments, DataArguments, Seq2SeqTrainingArguments, FinetuningArguments, GeneratingArguments]
|
||||
_INFER_ARGS = [ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments]
|
||||
_INFER_CLS = Tuple[ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments]
|
||||
_EVAL_ARGS = [ModelArguments, DataArguments, EvaluationArguments, FinetuningArguments]
|
||||
_EVAL_CLS = Tuple[ModelArguments, DataArguments, EvaluationArguments, FinetuningArguments]
|
||||
|
||||
|
||||
def _parse_args(parser: "HfArgumentParser", args: Optional[Dict[str, Any]] = None) -> Tuple[Any]:
|
||||
if args is not None:
|
||||
return parser.parse_dict(args)
|
||||
|
||||
if len(sys.argv) == 2 and sys.argv[1].endswith(".yaml"):
|
||||
return parser.parse_yaml_file(os.path.abspath(sys.argv[1]))
|
||||
|
||||
if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
|
||||
return parser.parse_json_file(os.path.abspath(sys.argv[1]))
|
||||
|
||||
(*parsed_args, unknown_args) = parser.parse_args_into_dataclasses(return_remaining_strings=True)
|
||||
|
||||
if unknown_args:
|
||||
print(parser.format_help())
|
||||
print("Got unknown args, potentially deprecated arguments: {}".format(unknown_args))
|
||||
raise ValueError("Some specified arguments are not used by the HfArgumentParser: {}".format(unknown_args))
|
||||
|
||||
return (*parsed_args,)
|
||||
|
||||
|
||||
def _set_transformers_logging(log_level: Optional[int] = logging.INFO) -> None:
|
||||
datasets.utils.logging.set_verbosity(log_level)
|
||||
transformers.utils.logging.set_verbosity(log_level)
|
||||
transformers.utils.logging.enable_default_handler()
|
||||
transformers.utils.logging.enable_explicit_format()
|
||||
|
||||
|
||||
def _verify_model_args(model_args: "ModelArguments", finetuning_args: "FinetuningArguments") -> None:
|
||||
if model_args.quantization_bit is not None:
|
||||
if finetuning_args.finetuning_type != "lora":
|
||||
raise ValueError("Quantization is only compatible with the LoRA method.")
|
||||
|
||||
if finetuning_args.create_new_adapter:
|
||||
raise ValueError("Cannot create new adapter upon a quantized model.")
|
||||
|
||||
if model_args.adapter_name_or_path is not None and len(model_args.adapter_name_or_path) != 1:
|
||||
if finetuning_args.finetuning_type != "lora":
|
||||
raise ValueError("Multiple adapters are only available for LoRA tuning.")
|
||||
|
||||
if model_args.quantization_bit is not None:
|
||||
raise ValueError("Quantized model only accepts a single adapter. Merge them first.")
|
||||
|
||||
|
||||
def _parse_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
|
||||
parser = HfArgumentParser(_TRAIN_ARGS)
|
||||
return _parse_args(parser, args)
|
||||
|
||||
|
||||
def _parse_infer_args(args: Optional[Dict[str, Any]] = None) -> _INFER_CLS:
|
||||
parser = HfArgumentParser(_INFER_ARGS)
|
||||
return _parse_args(parser, args)
|
||||
|
||||
|
||||
def _parse_eval_args(args: Optional[Dict[str, Any]] = None) -> _EVAL_CLS:
|
||||
parser = HfArgumentParser(_EVAL_ARGS)
|
||||
return _parse_args(parser, args)
|
||||
|
||||
|
||||
def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
|
||||
model_args, data_args, training_args, finetuning_args, generating_args = _parse_train_args(args)
|
||||
|
||||
# Setup logging
|
||||
if training_args.should_log:
|
||||
_set_transformers_logging()
|
||||
|
||||
# Check arguments
|
||||
if finetuning_args.stage != "pt" and data_args.template is None:
|
||||
raise ValueError("Please specify which `template` to use.")
|
||||
|
||||
if finetuning_args.stage != "sft" and training_args.predict_with_generate:
|
||||
raise ValueError("`predict_with_generate` cannot be set as True except SFT.")
|
||||
|
||||
if finetuning_args.stage == "sft" and training_args.do_predict and not training_args.predict_with_generate:
|
||||
raise ValueError("Please enable `predict_with_generate` to save model predictions.")
|
||||
|
||||
if finetuning_args.stage in ["rm", "ppo"] and training_args.load_best_model_at_end:
|
||||
raise ValueError("RM and PPO stages do not support `load_best_model_at_end`.")
|
||||
|
||||
if finetuning_args.stage == "ppo" and not training_args.do_train:
|
||||
raise ValueError("PPO training does not support evaluation, use the SFT stage to evaluate models.")
|
||||
|
||||
if finetuning_args.stage == "ppo" and model_args.shift_attn:
|
||||
raise ValueError("PPO training is incompatible with S^2-Attn.")
|
||||
|
||||
if finetuning_args.stage == "ppo" and finetuning_args.reward_model_type == "lora" and model_args.use_unsloth:
|
||||
raise ValueError("Unsloth does not support lora reward model.")
|
||||
|
||||
if training_args.max_steps == -1 and data_args.streaming:
|
||||
raise ValueError("Please specify `max_steps` in streaming mode.")
|
||||
|
||||
if training_args.do_train and training_args.predict_with_generate:
|
||||
raise ValueError("`predict_with_generate` cannot be set as True while training.")
|
||||
|
||||
if training_args.do_train and finetuning_args.finetuning_type == "lora" and finetuning_args.lora_target is None:
|
||||
raise ValueError("Please specify `lora_target` in LoRA training.")
|
||||
|
||||
_verify_model_args(model_args, finetuning_args)
|
||||
|
||||
if training_args.do_train and model_args.quantization_bit is not None and (not model_args.upcast_layernorm):
|
||||
logger.warning("We recommend enable `upcast_layernorm` in quantized training.")
|
||||
|
||||
if training_args.do_train and (not training_args.fp16) and (not training_args.bf16):
|
||||
logger.warning("We recommend enable mixed precision training.")
|
||||
|
||||
if (not training_args.do_train) and model_args.quantization_bit is not None:
|
||||
logger.warning("Evaluating model in 4/8-bit mode may cause lower scores.")
|
||||
|
||||
if (not training_args.do_train) and finetuning_args.stage == "dpo" and finetuning_args.ref_model is None:
|
||||
logger.warning("Specify `ref_model` for computing rewards at evaluation.")
|
||||
|
||||
# postprocess training_args
|
||||
if (
|
||||
training_args.local_rank != -1
|
||||
and training_args.ddp_find_unused_parameters is None
|
||||
and finetuning_args.finetuning_type == "lora"
|
||||
):
|
||||
logger.warning("`ddp_find_unused_parameters` needs to be set as False for LoRA in DDP training.")
|
||||
training_args_dict = training_args.to_dict()
|
||||
training_args_dict.update(dict(ddp_find_unused_parameters=False))
|
||||
training_args = Seq2SeqTrainingArguments(**training_args_dict)
|
||||
|
||||
if finetuning_args.stage in ["rm", "ppo"] and finetuning_args.finetuning_type in ["full", "freeze"]:
|
||||
can_resume_from_checkpoint = False
|
||||
training_args.resume_from_checkpoint = None
|
||||
else:
|
||||
can_resume_from_checkpoint = True
|
||||
|
||||
if (
|
||||
training_args.resume_from_checkpoint is None
|
||||
and training_args.do_train
|
||||
and os.path.isdir(training_args.output_dir)
|
||||
and not training_args.overwrite_output_dir
|
||||
and can_resume_from_checkpoint
|
||||
):
|
||||
last_checkpoint = get_last_checkpoint(training_args.output_dir)
|
||||
if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0:
|
||||
raise ValueError("Output directory already exists and is not empty. Please set `overwrite_output_dir`.")
|
||||
|
||||
if last_checkpoint is not None:
|
||||
training_args_dict = training_args.to_dict()
|
||||
training_args_dict.update(dict(resume_from_checkpoint=last_checkpoint))
|
||||
training_args = Seq2SeqTrainingArguments(**training_args_dict)
|
||||
logger.info(
|
||||
"Resuming training from {}. Change `output_dir` or use `overwrite_output_dir` to avoid.".format(
|
||||
training_args.resume_from_checkpoint
|
||||
)
|
||||
)
|
||||
|
||||
if (
|
||||
finetuning_args.stage in ["rm", "ppo"]
|
||||
and finetuning_args.finetuning_type == "lora"
|
||||
and training_args.resume_from_checkpoint is not None
|
||||
):
|
||||
logger.warning(
|
||||
"Add {} to `adapter_name_or_path` to resume training from checkpoint.".format(
|
||||
training_args.resume_from_checkpoint
|
||||
)
|
||||
)
|
||||
|
||||
# postprocess model_args
|
||||
model_args.compute_dtype = (
|
||||
torch.bfloat16 if training_args.bf16 else (torch.float16 if training_args.fp16 else None)
|
||||
)
|
||||
model_args.model_max_length = data_args.cutoff_len
|
||||
|
||||
# Log on each process the small summary:
|
||||
logger.info(
|
||||
"Process rank: {}, device: {}, n_gpu: {}\n distributed training: {}, compute dtype: {}".format(
|
||||
training_args.local_rank,
|
||||
training_args.device,
|
||||
training_args.n_gpu,
|
||||
bool(training_args.local_rank != -1),
|
||||
str(model_args.compute_dtype),
|
||||
)
|
||||
)
|
||||
logger.info(f"Training/evaluation parameters {training_args}")
|
||||
|
||||
# Set seed before initializing model.
|
||||
transformers.set_seed(training_args.seed)
|
||||
|
||||
return model_args, data_args, training_args, finetuning_args, generating_args
|
||||
|
||||
|
||||
def get_infer_args(args: Optional[Dict[str, Any]] = None) -> _INFER_CLS:
|
||||
model_args, data_args, finetuning_args, generating_args = _parse_infer_args(args)
|
||||
_set_transformers_logging()
|
||||
|
||||
if data_args.template is None:
|
||||
raise ValueError("Please specify which `template` to use.")
|
||||
|
||||
_verify_model_args(model_args, finetuning_args)
|
||||
|
||||
return model_args, data_args, finetuning_args, generating_args
|
||||
|
||||
|
||||
def get_eval_args(args: Optional[Dict[str, Any]] = None) -> _EVAL_CLS:
|
||||
model_args, data_args, eval_args, finetuning_args = _parse_eval_args(args)
|
||||
_set_transformers_logging()
|
||||
|
||||
if data_args.template is None:
|
||||
raise ValueError("Please specify which `template` to use.")
|
||||
|
||||
_verify_model_args(model_args, finetuning_args)
|
||||
|
||||
transformers.set_seed(eval_args.seed)
|
||||
|
||||
return model_args, data_args, eval_args, finetuning_args
|
||||
5
src/llmtuner/model/__init__.py
Normal file
5
src/llmtuner/model/__init__.py
Normal file
@@ -0,0 +1,5 @@
|
||||
from .loader import load_model_and_tokenizer
|
||||
from .utils import dispatch_model, get_modelcard_args, load_valuehead_params
|
||||
|
||||
|
||||
__all__ = ["load_model_and_tokenizer", "dispatch_model", "get_modelcard_args", "load_valuehead_params"]
|
||||
138
src/llmtuner/model/adapter.py
Normal file
138
src/llmtuner/model/adapter.py
Normal file
@@ -0,0 +1,138 @@
|
||||
import inspect
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import torch
|
||||
from peft import LoraConfig, PeftModel, TaskType, get_peft_model
|
||||
from transformers.integrations import is_deepspeed_zero3_enabled
|
||||
|
||||
from ..extras.logging import get_logger
|
||||
from .utils import find_all_linear_modules
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from transformers.modeling_utils import PreTrainedModel
|
||||
|
||||
from ..hparams import FinetuningArguments, ModelArguments
|
||||
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
def init_adapter(
|
||||
model: "PreTrainedModel", model_args: "ModelArguments", finetuning_args: "FinetuningArguments", is_trainable: bool
|
||||
) -> "PreTrainedModel":
|
||||
r"""
|
||||
Initializes the adapters.
|
||||
|
||||
Support full-parameter, freeze and LoRA training.
|
||||
|
||||
Note that the trainable parameters must be cast to float32.
|
||||
"""
|
||||
|
||||
if (not is_trainable) and model_args.adapter_name_or_path is None:
|
||||
logger.info("Adapter is not found at evaluation, load the base model.")
|
||||
return model
|
||||
|
||||
if finetuning_args.finetuning_type == "full" and is_trainable:
|
||||
logger.info("Fine-tuning method: Full")
|
||||
model = model.float()
|
||||
|
||||
if finetuning_args.finetuning_type == "freeze" and is_trainable:
|
||||
logger.info("Fine-tuning method: Freeze")
|
||||
num_layers = (
|
||||
getattr(model.config, "num_hidden_layers", None)
|
||||
or getattr(model.config, "num_layers", None)
|
||||
or getattr(model.config, "n_layer", None)
|
||||
)
|
||||
if not num_layers:
|
||||
raise ValueError("Current model does not support freeze tuning.")
|
||||
|
||||
if finetuning_args.num_layer_trainable > 0: # fine-tuning the last n layers if num_layer_trainable > 0
|
||||
trainable_layer_ids = [num_layers - k - 1 for k in range(finetuning_args.num_layer_trainable)]
|
||||
else: # fine-tuning the first n layers if num_layer_trainable < 0
|
||||
trainable_layer_ids = [k for k in range(-finetuning_args.num_layer_trainable)] # noqa: C416
|
||||
|
||||
trainable_layers = []
|
||||
for module_name in finetuning_args.name_module_trainable:
|
||||
for idx in trainable_layer_ids:
|
||||
trainable_layers.append("{:d}.{}".format(idx, module_name))
|
||||
|
||||
for name, param in model.named_parameters():
|
||||
if not any(trainable_layer in name for trainable_layer in trainable_layers):
|
||||
param.requires_grad_(False)
|
||||
else:
|
||||
param.data = param.data.to(torch.float32)
|
||||
|
||||
if finetuning_args.finetuning_type == "lora":
|
||||
logger.info("Fine-tuning method: LoRA")
|
||||
adapter_to_resume = None
|
||||
|
||||
if model_args.adapter_name_or_path is not None:
|
||||
is_mergeable = True
|
||||
if getattr(model, "quantization_method", None): # merge lora in quantized model is unstable
|
||||
assert len(model_args.adapter_name_or_path) == 1, "Quantized model only accepts a single adapter."
|
||||
is_mergeable = False
|
||||
|
||||
if is_deepspeed_zero3_enabled():
|
||||
assert len(model_args.adapter_name_or_path) == 1, "Cannot use multiple adapters in DeepSpeed ZeRO-3."
|
||||
is_mergeable = False
|
||||
|
||||
if (is_trainable and not finetuning_args.create_new_adapter) or (not is_mergeable):
|
||||
adapter_to_merge = model_args.adapter_name_or_path[:-1]
|
||||
adapter_to_resume = model_args.adapter_name_or_path[-1]
|
||||
else:
|
||||
adapter_to_merge = model_args.adapter_name_or_path
|
||||
|
||||
for adapter in adapter_to_merge:
|
||||
model = PeftModel.from_pretrained(model, adapter)
|
||||
model = model.merge_and_unload()
|
||||
|
||||
if len(adapter_to_merge) > 0:
|
||||
logger.info("Merged {} adapter(s).".format(len(adapter_to_merge)))
|
||||
|
||||
if adapter_to_resume is not None: # resume lora training
|
||||
model = PeftModel.from_pretrained(model, adapter_to_resume, is_trainable=is_trainable)
|
||||
|
||||
if is_trainable and adapter_to_resume is None: # create new lora weights while training
|
||||
if len(finetuning_args.lora_target) == 1 and finetuning_args.lora_target[0] == "all":
|
||||
target_modules = find_all_linear_modules(model)
|
||||
else:
|
||||
target_modules = finetuning_args.lora_target
|
||||
|
||||
peft_kwargs = {
|
||||
"r": finetuning_args.lora_rank,
|
||||
"target_modules": target_modules,
|
||||
"lora_alpha": finetuning_args.lora_alpha,
|
||||
"lora_dropout": finetuning_args.lora_dropout,
|
||||
}
|
||||
|
||||
if model_args.use_unsloth:
|
||||
from unsloth import FastLlamaModel, FastMistralModel # type: ignore
|
||||
|
||||
unsloth_peft_kwargs = {"model": model, "max_seq_length": model_args.model_max_length}
|
||||
if "loftq_config" in inspect.signature(FastLlamaModel.get_peft_model).parameters:
|
||||
unsloth_peft_kwargs["loftq_config"] = {}
|
||||
|
||||
if getattr(model.config, "model_type", None) == "llama":
|
||||
model = FastLlamaModel.get_peft_model(**peft_kwargs, **unsloth_peft_kwargs)
|
||||
elif getattr(model.config, "model_type", None) == "mistral":
|
||||
model = FastMistralModel.get_peft_model(**peft_kwargs, **unsloth_peft_kwargs)
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
else:
|
||||
lora_config = LoraConfig(
|
||||
task_type=TaskType.CAUSAL_LM,
|
||||
inference_mode=False,
|
||||
modules_to_save=finetuning_args.additional_target,
|
||||
**peft_kwargs,
|
||||
)
|
||||
model = get_peft_model(model, lora_config)
|
||||
|
||||
for param in filter(lambda p: p.requires_grad, model.parameters()):
|
||||
param.data = param.data.to(torch.bfloat16 if finetuning_args.lora_bf16_mode else torch.float32)
|
||||
|
||||
if model_args.adapter_name_or_path is not None:
|
||||
logger.info("Loaded adapter(s): {}".format(",".join(model_args.adapter_name_or_path)))
|
||||
|
||||
return model
|
||||
135
src/llmtuner/model/loader.py
Normal file
135
src/llmtuner/model/loader.py
Normal file
@@ -0,0 +1,135 @@
|
||||
from typing import TYPE_CHECKING, Optional, Tuple
|
||||
|
||||
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
|
||||
from transformers.integrations import is_deepspeed_zero3_enabled
|
||||
from transformers.utils.versions import require_version
|
||||
from trl import AutoModelForCausalLMWithValueHead
|
||||
|
||||
from ..extras.logging import get_logger
|
||||
from ..extras.misc import count_parameters, get_current_device, try_download_model_from_ms
|
||||
from .adapter import init_adapter
|
||||
from .patcher import patch_config, patch_model, patch_tokenizer, patch_valuehead_model
|
||||
from .utils import load_valuehead_params, register_autoclass
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from transformers import PreTrainedModel, PreTrainedTokenizer
|
||||
|
||||
from ..hparams import FinetuningArguments, ModelArguments
|
||||
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
require_version("transformers>=4.36.2", "To fix: pip install transformers>=4.36.2")
|
||||
require_version("datasets>=2.14.3", "To fix: pip install datasets>=2.14.3")
|
||||
require_version("accelerate>=0.21.0", "To fix: pip install accelerate>=0.21.0")
|
||||
require_version("peft>=0.7.0", "To fix: pip install peft>=0.7.0")
|
||||
require_version("trl>=0.7.6", "To fix: pip install trl>=0.7.6")
|
||||
|
||||
|
||||
def load_model_and_tokenizer(
|
||||
model_args: "ModelArguments",
|
||||
finetuning_args: "FinetuningArguments",
|
||||
is_trainable: Optional[bool] = False,
|
||||
add_valuehead: Optional[bool] = False,
|
||||
) -> Tuple["PreTrainedModel", "PreTrainedTokenizer"]:
|
||||
r"""
|
||||
Loads pretrained model and tokenizer.
|
||||
|
||||
Support both training and inference.
|
||||
"""
|
||||
|
||||
try_download_model_from_ms(model_args)
|
||||
|
||||
config_kwargs = {
|
||||
"trust_remote_code": True,
|
||||
"cache_dir": model_args.cache_dir,
|
||||
"revision": model_args.model_revision,
|
||||
"token": model_args.hf_hub_token,
|
||||
}
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
model_args.model_name_or_path,
|
||||
use_fast=model_args.use_fast_tokenizer,
|
||||
split_special_tokens=model_args.split_special_tokens,
|
||||
padding_side="right",
|
||||
**config_kwargs,
|
||||
)
|
||||
patch_tokenizer(tokenizer)
|
||||
|
||||
config = AutoConfig.from_pretrained(model_args.model_name_or_path, **config_kwargs)
|
||||
patch_config(config, tokenizer, model_args, config_kwargs, is_trainable)
|
||||
|
||||
model = None
|
||||
if is_trainable and model_args.use_unsloth:
|
||||
require_version("unsloth", "Follow the instructions at: https://github.com/unslothai/unsloth")
|
||||
from unsloth import FastLlamaModel, FastMistralModel # type: ignore
|
||||
|
||||
unsloth_kwargs = {
|
||||
"model_name": model_args.model_name_or_path,
|
||||
"max_seq_length": model_args.model_max_length,
|
||||
"dtype": model_args.compute_dtype,
|
||||
"load_in_4bit": model_args.quantization_bit == 4,
|
||||
"token": model_args.hf_hub_token,
|
||||
"device_map": get_current_device(),
|
||||
"rope_scaling": getattr(config, "rope_scaling", None),
|
||||
}
|
||||
if getattr(config, "model_type", None) == "llama":
|
||||
model, _ = FastLlamaModel.from_pretrained(**unsloth_kwargs)
|
||||
elif getattr(config, "model_type", None) == "mistral":
|
||||
model, _ = FastMistralModel.from_pretrained(**unsloth_kwargs)
|
||||
else:
|
||||
logger.warning("Unsloth does not support model type {}.".format(getattr(config, "model_type", None)))
|
||||
model_args.use_unsloth = False
|
||||
|
||||
if model_args.adapter_name_or_path:
|
||||
model_args.adapter_name_or_path = None
|
||||
logger.warning("Unsloth does not support loading adapters.")
|
||||
|
||||
if model is None:
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
model_args.model_name_or_path,
|
||||
config=config,
|
||||
torch_dtype=model_args.compute_dtype,
|
||||
low_cpu_mem_usage=(not is_deepspeed_zero3_enabled()),
|
||||
**config_kwargs,
|
||||
)
|
||||
|
||||
patch_model(model, tokenizer, model_args, is_trainable)
|
||||
register_autoclass(config, model, tokenizer)
|
||||
|
||||
model = init_adapter(model, model_args, finetuning_args, is_trainable)
|
||||
|
||||
if add_valuehead:
|
||||
model: "AutoModelForCausalLMWithValueHead" = AutoModelForCausalLMWithValueHead.from_pretrained(model)
|
||||
patch_valuehead_model(model)
|
||||
|
||||
if model_args.adapter_name_or_path is not None:
|
||||
vhead_path = model_args.adapter_name_or_path[-1]
|
||||
else:
|
||||
vhead_path = model_args.model_name_or_path
|
||||
|
||||
vhead_params = load_valuehead_params(vhead_path, model_args)
|
||||
if vhead_params is not None:
|
||||
model.load_state_dict(vhead_params, strict=False)
|
||||
logger.info("Loaded valuehead from checkpoint: {}".format(vhead_path))
|
||||
|
||||
if not is_trainable:
|
||||
model.requires_grad_(False)
|
||||
model = model.to(model_args.compute_dtype) if not getattr(model, "quantization_method", None) else model
|
||||
model.eval()
|
||||
else:
|
||||
model.train()
|
||||
|
||||
trainable_params, all_param = count_parameters(model)
|
||||
logger.info(
|
||||
"trainable params: {:d} || all params: {:d} || trainable%: {:.4f}".format(
|
||||
trainable_params, all_param, 100 * trainable_params / all_param
|
||||
)
|
||||
)
|
||||
|
||||
if not is_trainable:
|
||||
logger.info("This IS expected that the trainable params is 0 if you are using model for inference only.")
|
||||
|
||||
return model, tokenizer
|
||||
299
src/llmtuner/model/patcher.py
Normal file
299
src/llmtuner/model/patcher.py
Normal file
@@ -0,0 +1,299 @@
|
||||
import math
|
||||
import os
|
||||
import random
|
||||
from contextlib import nullcontext
|
||||
from types import MethodType
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
from datasets import load_dataset
|
||||
from transformers import BitsAndBytesConfig, GPTQConfig, PreTrainedModel, PreTrainedTokenizerBase
|
||||
from transformers.integrations import is_deepspeed_zero3_enabled
|
||||
from transformers.utils.versions import require_version
|
||||
|
||||
from ..extras.constants import FILEEXT2TYPE, LAYERNORM_NAMES
|
||||
from ..extras.logging import get_logger
|
||||
from ..extras.misc import get_current_device, infer_optim_dtype
|
||||
from ..extras.packages import is_flash_attn2_available
|
||||
from ..extras.patches.llama_patch import apply_llama_patch
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from transformers import PretrainedConfig, PreTrainedTokenizer
|
||||
from trl import AutoModelForCausalLMWithValueHead
|
||||
|
||||
from ..hparams import ModelArguments
|
||||
|
||||
|
||||
logger = get_logger(__name__)
|
||||
SUPPORTED_CLASS_FOR_S2ATTN = ["llama"]
|
||||
|
||||
|
||||
def _noisy_mean_initialization(embed_weight: torch.Tensor, num_new_tokens: int):
|
||||
embedding_dim = embed_weight.size(1)
|
||||
avg_weight = embed_weight[:-num_new_tokens].mean(dim=0, keepdim=True)
|
||||
noise_weight = torch.empty_like(embed_weight[-num_new_tokens:])
|
||||
noise_weight.normal_(mean=0, std=(1.0 / math.sqrt(embedding_dim)))
|
||||
embed_weight[-num_new_tokens:] = avg_weight + noise_weight
|
||||
|
||||
|
||||
def _resize_embedding_layer(model: "PreTrainedModel", tokenizer: "PreTrainedTokenizer") -> None:
|
||||
r"""
|
||||
Resize token embeddings.
|
||||
"""
|
||||
if is_deepspeed_zero3_enabled():
|
||||
import deepspeed # type: ignore
|
||||
|
||||
params = [model.get_input_embeddings().weight]
|
||||
if model.get_output_embeddings() is not None and not model.config.tie_word_embeddings:
|
||||
params.append(model.get_output_embeddings().weight)
|
||||
|
||||
context_maybe_zero3 = deepspeed.zero.GatheredParameters(params, modifier_rank=0)
|
||||
else:
|
||||
context_maybe_zero3 = nullcontext()
|
||||
|
||||
with context_maybe_zero3:
|
||||
current_embedding_size = model.get_input_embeddings().weight.size(0)
|
||||
|
||||
if len(tokenizer) > current_embedding_size:
|
||||
if not isinstance(model.get_output_embeddings(), torch.nn.Linear):
|
||||
logger.warning("Current model does not support resizing token embeddings.")
|
||||
return
|
||||
|
||||
model.resize_token_embeddings(len(tokenizer), pad_to_multiple_of=64)
|
||||
with context_maybe_zero3:
|
||||
new_embedding_size = model.get_input_embeddings().weight.size(0)
|
||||
num_new_tokens = new_embedding_size - current_embedding_size
|
||||
_noisy_mean_initialization(model.get_input_embeddings().weight.data, num_new_tokens)
|
||||
_noisy_mean_initialization(model.get_output_embeddings().weight.data, num_new_tokens)
|
||||
|
||||
logger.info("Resized token embeddings from {} to {}.".format(current_embedding_size, new_embedding_size))
|
||||
|
||||
|
||||
def _get_quantization_dataset(tokenizer: "PreTrainedTokenizer", model_args: "ModelArguments") -> List[str]:
|
||||
r"""
|
||||
Inspired by: https://github.com/huggingface/optimum/blob/v1.16.0/optimum/gptq/data.py#L133
|
||||
TODO: remove tokenizer.decode() https://github.com/huggingface/optimum/pull/1600
|
||||
"""
|
||||
if os.path.isfile(model_args.export_quantization_dataset):
|
||||
data_path = FILEEXT2TYPE.get(model_args.export_quantization_dataset.split(".")[-1], None)
|
||||
data_files = model_args.export_quantization_dataset
|
||||
else:
|
||||
data_path = model_args.export_quantization_dataset
|
||||
data_files = None
|
||||
|
||||
dataset = load_dataset(path=data_path, data_files=data_files, split="train", cache_dir=model_args.cache_dir)
|
||||
maxlen = model_args.export_quantization_maxlen
|
||||
|
||||
samples = []
|
||||
for _ in range(model_args.export_quantization_nsamples):
|
||||
while True:
|
||||
sample_idx = random.randint(0, len(dataset) - 1)
|
||||
sample: Dict[str, torch.Tensor] = tokenizer(dataset[sample_idx]["text"], return_tensors="pt")
|
||||
if sample["input_ids"].size(1) >= maxlen:
|
||||
break # TODO: fix large maxlen
|
||||
|
||||
word_idx = random.randint(0, sample["input_ids"].size(1) - maxlen - 1)
|
||||
input_ids = sample["input_ids"][:, word_idx : word_idx + maxlen]
|
||||
samples.append(tokenizer.decode(input_ids[0].tolist(), skip_special_tokens=True))
|
||||
|
||||
return samples
|
||||
|
||||
|
||||
def _configure_rope(config: "PretrainedConfig", model_args: "ModelArguments", is_trainable: bool) -> None:
|
||||
if not hasattr(config, "rope_scaling"):
|
||||
logger.warning("Current model does not support RoPE scaling.")
|
||||
return
|
||||
|
||||
if is_trainable:
|
||||
if model_args.rope_scaling == "dynamic":
|
||||
logger.warning(
|
||||
"Dynamic NTK scaling may not work well with fine-tuning. "
|
||||
"See: https://github.com/huggingface/transformers/pull/24653"
|
||||
)
|
||||
|
||||
current_max_length = getattr(config, "max_position_embeddings", None)
|
||||
if current_max_length and model_args.model_max_length > current_max_length:
|
||||
scaling_factor = float(math.ceil(model_args.model_max_length / current_max_length))
|
||||
else:
|
||||
logger.warning("Input length is smaller than max length. Consider increase input length.")
|
||||
scaling_factor = 1.0
|
||||
else:
|
||||
scaling_factor = 2.0
|
||||
|
||||
setattr(config, "rope_scaling", {"type": model_args.rope_scaling, "factor": scaling_factor})
|
||||
logger.info(
|
||||
"Using {} scaling strategy and setting scaling factor to {}".format(model_args.rope_scaling, scaling_factor)
|
||||
)
|
||||
|
||||
|
||||
def _configure_flashattn(config_kwargs: Dict[str, Any]) -> None:
|
||||
if not is_flash_attn2_available():
|
||||
logger.warning("FlashAttention2 is not installed.")
|
||||
return
|
||||
|
||||
config_kwargs["use_flash_attention_2"] = True
|
||||
logger.info("Using FlashAttention-2 for faster training and inference.")
|
||||
|
||||
|
||||
def _configure_longlora(config: "PretrainedConfig") -> None:
|
||||
if getattr(config, "model_type", None) in SUPPORTED_CLASS_FOR_S2ATTN:
|
||||
setattr(config, "group_size_ratio", 0.25)
|
||||
apply_llama_patch()
|
||||
logger.info("Using shift short attention with group_size_ratio=1/4.")
|
||||
else:
|
||||
logger.warning("Current model does not support shift short attention.")
|
||||
|
||||
|
||||
def _configure_quantization(
|
||||
config: "PretrainedConfig",
|
||||
tokenizer: "PreTrainedTokenizer",
|
||||
model_args: "ModelArguments",
|
||||
config_kwargs: Dict[str, Any],
|
||||
) -> None:
|
||||
r"""
|
||||
Priority: GPTQ-quantized (training) > AutoGPTQ (export) > Bitsandbytes (training)
|
||||
"""
|
||||
if getattr(config, "quantization_config", None): # gptq
|
||||
if is_deepspeed_zero3_enabled():
|
||||
raise ValueError("DeepSpeed ZeRO-3 is incompatible with quantization.")
|
||||
|
||||
config_kwargs["device_map"] = {"": get_current_device()}
|
||||
quantization_config: Dict[str, Any] = getattr(config, "quantization_config", None)
|
||||
if quantization_config.get("quant_method", None) == "gptq" and quantization_config.get("bits", -1) == 4:
|
||||
quantization_config["use_exllama"] = False # disable exllama
|
||||
logger.info("Loading {}-bit GPTQ-quantized model.".format(quantization_config.get("bits", -1)))
|
||||
|
||||
elif model_args.export_quantization_bit is not None: # auto-gptq
|
||||
require_version("optimum>=1.16.0", "To fix: pip install optimum>=1.16.0")
|
||||
require_version("auto_gptq>=0.5.0", "To fix: pip install auto_gptq>=0.5.0")
|
||||
from accelerate.utils import get_max_memory
|
||||
|
||||
if getattr(config, "model_type", None) == "chatglm":
|
||||
raise ValueError("ChatGLM model is not supported.")
|
||||
|
||||
config_kwargs["quantization_config"] = GPTQConfig(
|
||||
bits=model_args.export_quantization_bit,
|
||||
tokenizer=tokenizer,
|
||||
dataset=_get_quantization_dataset(tokenizer, model_args),
|
||||
)
|
||||
config_kwargs["device_map"] = "auto"
|
||||
config_kwargs["max_memory"] = get_max_memory()
|
||||
logger.info("Quantizing model to {} bit.".format(model_args.export_quantization_bit))
|
||||
|
||||
elif model_args.quantization_bit is not None: # bnb
|
||||
if is_deepspeed_zero3_enabled():
|
||||
raise ValueError("DeepSpeed ZeRO-3 is incompatible with quantization.")
|
||||
|
||||
if model_args.quantization_bit == 8:
|
||||
require_version("bitsandbytes>=0.37.0", "To fix: pip install bitsandbytes>=0.37.0")
|
||||
config_kwargs["quantization_config"] = BitsAndBytesConfig(load_in_8bit=True)
|
||||
|
||||
elif model_args.quantization_bit == 4:
|
||||
require_version("bitsandbytes>=0.39.0", "To fix: pip install bitsandbytes>=0.39.0")
|
||||
config_kwargs["quantization_config"] = BitsAndBytesConfig(
|
||||
load_in_4bit=True,
|
||||
bnb_4bit_compute_dtype=model_args.compute_dtype,
|
||||
bnb_4bit_use_double_quant=model_args.double_quantization,
|
||||
bnb_4bit_quant_type=model_args.quantization_type,
|
||||
)
|
||||
|
||||
config_kwargs["device_map"] = {"": get_current_device()}
|
||||
logger.info("Quantizing model to {} bit.".format(model_args.quantization_bit))
|
||||
|
||||
|
||||
def _prepare_model_for_training(
|
||||
model: "PreTrainedModel", model_args: "ModelArguments", output_layer_name: Optional[str] = "lm_head"
|
||||
) -> None:
|
||||
r"""
|
||||
Includes:
|
||||
(1) cast the layernorm in fp32
|
||||
(2) make output embedding layer require grads
|
||||
(3) add the upcasting of the lm_head in fp32
|
||||
Inspired by: https://github.com/huggingface/peft/blob/v0.7.1/src/peft/utils/other.py#L72
|
||||
"""
|
||||
if model_args.upcast_layernorm:
|
||||
for name, param in model.named_parameters():
|
||||
if param.ndim == 1 and any(ln_name in name for ln_name in LAYERNORM_NAMES):
|
||||
param.data = param.data.to(torch.float32)
|
||||
logger.info("Upcasting layernorm weights in float32.")
|
||||
|
||||
if not model_args.disable_gradient_checkpointing:
|
||||
if not getattr(model, "supports_gradient_checkpointing", False):
|
||||
logger.warning("Current model does not support gradient checkpointing.")
|
||||
else:
|
||||
model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False})
|
||||
model.config.use_cache = False # turn off when gradient checkpointing is enabled
|
||||
logger.info("Gradient checkpointing enabled.")
|
||||
|
||||
if hasattr(model, output_layer_name) and model_args.upcast_lmhead_output:
|
||||
|
||||
def fp32_forward_post_hook(module: torch.nn.Module, args: Tuple[torch.Tensor], output: torch.Tensor):
|
||||
return output.to(torch.float32)
|
||||
|
||||
output_layer = getattr(model, output_layer_name)
|
||||
if isinstance(output_layer, torch.nn.Linear) and output_layer.weight.dtype != torch.float32:
|
||||
output_layer.register_forward_hook(fp32_forward_post_hook)
|
||||
|
||||
|
||||
def patch_tokenizer(tokenizer: "PreTrainedTokenizer") -> None:
|
||||
if "PreTrainedTokenizerBase" not in str(tokenizer._pad.__func__):
|
||||
tokenizer._pad = MethodType(PreTrainedTokenizerBase._pad, tokenizer)
|
||||
|
||||
|
||||
def patch_config(
|
||||
config: "PretrainedConfig",
|
||||
tokenizer: "PreTrainedTokenizer",
|
||||
model_args: "ModelArguments",
|
||||
config_kwargs: Dict[str, Any],
|
||||
is_trainable: bool,
|
||||
) -> None:
|
||||
if model_args.compute_dtype is None: # priority: bf16 > fp16 > fp32
|
||||
model_args.compute_dtype = infer_optim_dtype(model_dtype=getattr(config, "torch_dtype", None))
|
||||
|
||||
if getattr(config, "model_type", None) == "qwen":
|
||||
for dtype_name, dtype in [("fp16", torch.float16), ("bf16", torch.bfloat16), ("fp32", torch.float32)]:
|
||||
setattr(config, dtype_name, model_args.compute_dtype == dtype)
|
||||
|
||||
if model_args.rope_scaling is not None:
|
||||
_configure_rope(config, model_args, is_trainable)
|
||||
|
||||
if model_args.flash_attn:
|
||||
_configure_flashattn(config_kwargs)
|
||||
|
||||
if is_trainable and model_args.shift_attn:
|
||||
_configure_longlora(config)
|
||||
|
||||
_configure_quantization(config, tokenizer, model_args, config_kwargs)
|
||||
|
||||
|
||||
def patch_model(
|
||||
model: "PreTrainedModel", tokenizer: "PreTrainedTokenizer", model_args: "ModelArguments", is_trainable: bool
|
||||
) -> None:
|
||||
if "GenerationMixin" not in str(model.generate.__func__):
|
||||
model.generate = MethodType(PreTrainedModel.generate, model)
|
||||
|
||||
if getattr(model.config, "model_type", None) == "chatglm":
|
||||
setattr(model, "lm_head", model.transformer.output_layer)
|
||||
setattr(model, "_keys_to_ignore_on_save", ["lm_head.weight"])
|
||||
|
||||
if model_args.resize_vocab:
|
||||
_resize_embedding_layer(model, tokenizer)
|
||||
|
||||
if is_trainable:
|
||||
_prepare_model_for_training(model, model_args)
|
||||
|
||||
|
||||
def patch_valuehead_model(model: "AutoModelForCausalLMWithValueHead") -> None:
|
||||
def tie_weights(self: "AutoModelForCausalLMWithValueHead") -> None:
|
||||
if isinstance(self.pretrained_model, PreTrainedModel):
|
||||
self.pretrained_model.tie_weights()
|
||||
|
||||
def get_input_embeddings(self: "AutoModelForCausalLMWithValueHead") -> torch.nn.Module:
|
||||
if isinstance(self.pretrained_model, PreTrainedModel):
|
||||
return self.pretrained_model.get_input_embeddings()
|
||||
|
||||
ignore_modules = [name for name, _ in model.named_parameters() if "pretrained_model" in name]
|
||||
setattr(model, "_keys_to_ignore_on_save", ignore_modules)
|
||||
setattr(model, "tie_weights", MethodType(tie_weights, model))
|
||||
setattr(model, "get_input_embeddings", MethodType(get_input_embeddings, model))
|
||||
125
src/llmtuner/model/utils.py
Normal file
125
src/llmtuner/model/utils.py
Normal file
@@ -0,0 +1,125 @@
|
||||
import inspect
|
||||
from typing import TYPE_CHECKING, Any, Dict, List
|
||||
|
||||
import torch
|
||||
from transformers import PreTrainedModel
|
||||
from transformers.utils import cached_file
|
||||
|
||||
from ..extras.constants import V_HEAD_SAFE_WEIGHTS_NAME, V_HEAD_WEIGHTS_NAME
|
||||
from ..extras.logging import get_logger
|
||||
from ..extras.misc import get_current_device
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from transformers import PretrainedConfig, PreTrainedTokenizer
|
||||
|
||||
from ..hparams import DataArguments, FinetuningArguments, ModelArguments
|
||||
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
def dispatch_model(model: "PreTrainedModel") -> "PreTrainedModel":
|
||||
r"""
|
||||
Dispatches a pre-trained model to GPUs with balanced memory when the GPU is available.
|
||||
Borrowed from: https://github.com/huggingface/transformers/blob/v4.36.2/src/transformers/modeling_utils.py#L3570
|
||||
"""
|
||||
if getattr(model, "quantization_method", None): # already set on current device
|
||||
return model
|
||||
|
||||
if (
|
||||
torch.cuda.device_count() > 1
|
||||
and isinstance(model, PreTrainedModel)
|
||||
and model._no_split_modules is not None
|
||||
and model.config.model_type != "chatglm"
|
||||
):
|
||||
from accelerate import dispatch_model
|
||||
from accelerate.utils import get_balanced_memory, infer_auto_device_map
|
||||
|
||||
kwargs = {"dtype": model.dtype, "no_split_module_classes": model._get_no_split_modules("auto")}
|
||||
max_memory = get_balanced_memory(model, **kwargs)
|
||||
# Make sure tied weights are tied before creating the device map.
|
||||
model.tie_weights()
|
||||
device_map = infer_auto_device_map(model, max_memory=max_memory, **kwargs)
|
||||
device_map_kwargs = {"device_map": device_map}
|
||||
if "skip_keys" in inspect.signature(dispatch_model).parameters:
|
||||
device_map_kwargs["skip_keys"] = model._skip_keys_device_placement
|
||||
return dispatch_model(model, **device_map_kwargs)
|
||||
else:
|
||||
return model.to(device=get_current_device())
|
||||
|
||||
|
||||
def find_all_linear_modules(model: "PreTrainedModel") -> List[str]:
|
||||
r"""
|
||||
Finds all available modules to apply lora.
|
||||
"""
|
||||
quantization_method = getattr(model, "quantization_method", None)
|
||||
if quantization_method is None:
|
||||
linear_cls = torch.nn.Linear
|
||||
elif quantization_method == "bitsandbytes":
|
||||
import bitsandbytes as bnb
|
||||
|
||||
linear_cls = bnb.nn.Linear4bit if getattr(model, "is_loaded_in_4bit", False) else bnb.nn.Linear8bitLt
|
||||
else:
|
||||
raise ValueError("Finding linear modules for {} models is not supported.".format(quantization_method))
|
||||
|
||||
output_layer_names = ["lm_head"]
|
||||
if model.config.model_type == "chatglm":
|
||||
output_layer_names.append("output_layer")
|
||||
|
||||
module_names = set()
|
||||
for name, module in model.named_modules():
|
||||
if isinstance(module, linear_cls) and not any(output_layer in name for output_layer in output_layer_names):
|
||||
module_names.add(name.split(".")[-1])
|
||||
|
||||
logger.info("Found linear modules: {}".format(",".join(module_names)))
|
||||
return list(module_names)
|
||||
|
||||
|
||||
def get_modelcard_args(
|
||||
model_args: "ModelArguments", data_args: "DataArguments", finetuning_args: "FinetuningArguments"
|
||||
) -> Dict[str, Any]:
|
||||
return {
|
||||
"tasks": "text-generation",
|
||||
"license": "other",
|
||||
"finetuned_from": model_args.model_name_or_path,
|
||||
"dataset": [dataset.strip() for dataset in data_args.dataset.split(",")],
|
||||
"tags": ["llama-factory"] + (["lora"] if finetuning_args.finetuning_type == "lora" else []),
|
||||
}
|
||||
|
||||
|
||||
def load_valuehead_params(path_or_repo_id: str, model_args: "ModelArguments") -> Dict[str, torch.Tensor]:
|
||||
r"""
|
||||
Loads value head parameters from Hugging Face Hub or local disk.
|
||||
|
||||
Returns: dict with keys `v_head.summary.weight` and `v_head.summary.bias`.
|
||||
"""
|
||||
kwargs = {"path_or_repo_id": path_or_repo_id, "cache_dir": model_args.cache_dir, "token": model_args.hf_hub_token}
|
||||
|
||||
try:
|
||||
from safetensors import safe_open
|
||||
|
||||
vhead_file = cached_file(filename=V_HEAD_SAFE_WEIGHTS_NAME, **kwargs)
|
||||
with safe_open(vhead_file, framework="pt", device="cpu") as f:
|
||||
return {key: f.get_tensor(key) for key in f.keys()}
|
||||
except Exception as err:
|
||||
logger.info("Failed to load {}: {}".format(V_HEAD_SAFE_WEIGHTS_NAME, str(err)))
|
||||
|
||||
try:
|
||||
vhead_file = cached_file(filename=V_HEAD_WEIGHTS_NAME, **kwargs)
|
||||
return torch.load(vhead_file, map_location="cpu")
|
||||
except Exception as err:
|
||||
logger.info("Failed to load {}: {}".format(V_HEAD_WEIGHTS_NAME, str(err)))
|
||||
|
||||
logger.info("Provided path ({}) does not contain value head weights.".format(path_or_repo_id))
|
||||
logger.info("Ignore these messages if you are not resuming the training of a value head model.")
|
||||
return None
|
||||
|
||||
|
||||
def register_autoclass(config: "PretrainedConfig", model: "PreTrainedModel", tokenizer: "PreTrainedTokenizer"):
|
||||
if "AutoConfig" in getattr(config, "auto_map", {}):
|
||||
config.__class__.register_for_auto_class()
|
||||
if "AutoModelForCausalLM" in getattr(config, "auto_map", {}):
|
||||
model.__class__.register_for_auto_class()
|
||||
if "AutoTokenizer" in tokenizer.init_kwargs.get("auto_map", {}):
|
||||
tokenizer.__class__.register_for_auto_class()
|
||||
4
src/llmtuner/train/__init__.py
Normal file
4
src/llmtuner/train/__init__.py
Normal file
@@ -0,0 +1,4 @@
|
||||
from .tuner import export_model, run_exp
|
||||
|
||||
|
||||
__all__ = ["export_model", "run_exp"]
|
||||
4
src/llmtuner/train/dpo/__init__.py
Normal file
4
src/llmtuner/train/dpo/__init__.py
Normal file
@@ -0,0 +1,4 @@
|
||||
from .workflow import run_dpo
|
||||
|
||||
|
||||
__all__ = ["run_dpo"]
|
||||
@@ -1,6 +1,7 @@
|
||||
import torch
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, List, Sequence, Tuple
|
||||
|
||||
import torch
|
||||
from transformers import DataCollatorForSeq2Seq
|
||||
|
||||
|
||||
@@ -20,7 +21,7 @@ class DPODataCollatorWithPadding(DataCollatorForSeq2Seq):
|
||||
padded_tensor = self.label_pad_token_id * torch.ones_like(feature)
|
||||
padded_tensor[start:end] = feature[start:end]
|
||||
padded_labels.append(padded_tensor)
|
||||
return torch.stack(padded_labels, dim=0).contiguous() # in contiguous memory
|
||||
return torch.stack(padded_labels, dim=0).contiguous() # in contiguous memory
|
||||
|
||||
def __call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, torch.Tensor]:
|
||||
r"""
|
||||
@@ -34,10 +35,12 @@ class DPODataCollatorWithPadding(DataCollatorForSeq2Seq):
|
||||
for key in ("chosen_ids", "rejected_ids"):
|
||||
for feature in features:
|
||||
prompt_len, answer_len = len(feature["prompt_ids"]), len(feature[key])
|
||||
concatenated_features.append({
|
||||
"input_ids": feature["prompt_ids"] + feature[key],
|
||||
"attention_mask": [1] * (prompt_len + answer_len)
|
||||
})
|
||||
concatenated_features.append(
|
||||
{
|
||||
"input_ids": feature["prompt_ids"] + feature[key],
|
||||
"attention_mask": [1] * (prompt_len + answer_len),
|
||||
}
|
||||
)
|
||||
label_positions.append((prompt_len, answer_len))
|
||||
|
||||
batch = self.tokenizer.pad(
|
||||
148
src/llmtuner/train/dpo/trainer.py
Normal file
148
src/llmtuner/train/dpo/trainer.py
Normal file
@@ -0,0 +1,148 @@
|
||||
from collections import defaultdict
|
||||
from contextlib import nullcontext
|
||||
from typing import TYPE_CHECKING, Dict, Literal, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
from transformers import BatchEncoding, Trainer
|
||||
from trl import DPOTrainer
|
||||
from trl.trainer.utils import disable_dropout_in_model
|
||||
|
||||
from ...extras.constants import IGNORE_INDEX
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from transformers import PreTrainedModel
|
||||
|
||||
|
||||
class CustomDPOTrainer(DPOTrainer):
|
||||
def __init__(
|
||||
self,
|
||||
beta: float,
|
||||
loss_type: Literal["sigmoid", "hinge", "ipo", "kto"],
|
||||
ftx_gamma: float,
|
||||
model: Union["PreTrainedModel", torch.nn.Module],
|
||||
ref_model: Optional[Union["PreTrainedModel", torch.nn.Module]] = None,
|
||||
disable_dropout: Optional[bool] = True,
|
||||
**kwargs,
|
||||
):
|
||||
if disable_dropout:
|
||||
disable_dropout_in_model(model)
|
||||
if ref_model is not None:
|
||||
disable_dropout_in_model(ref_model)
|
||||
|
||||
self.use_dpo_data_collator = True # hack to avoid warning
|
||||
self.generate_during_eval = False # disable at evaluation
|
||||
self.label_pad_token_id = IGNORE_INDEX
|
||||
self.padding_value = 0
|
||||
self.is_encoder_decoder = model.config.is_encoder_decoder
|
||||
self.precompute_ref_log_probs = False
|
||||
self._precomputed_train_ref_log_probs = False
|
||||
self._precomputed_eval_ref_log_probs = False
|
||||
self._peft_has_been_casted_to_bf16 = False
|
||||
|
||||
self.ref_model = ref_model
|
||||
self.beta = beta
|
||||
self.label_smoothing = 0
|
||||
self.loss_type = loss_type
|
||||
self.ftx_gamma = ftx_gamma
|
||||
self._stored_metrics = defaultdict(lambda: defaultdict(list))
|
||||
|
||||
Trainer.__init__(self, model=model, **kwargs)
|
||||
if not hasattr(self, "accelerator"):
|
||||
raise AttributeError("Please update `transformers`.")
|
||||
|
||||
if ref_model is not None:
|
||||
if self.is_deepspeed_enabled:
|
||||
if not (
|
||||
getattr(ref_model, "is_loaded_in_8bit", False) or getattr(ref_model, "is_loaded_in_4bit", False)
|
||||
): # quantized models are already set on the correct device
|
||||
self.ref_model = self._prepare_deepspeed(self.ref_model)
|
||||
else:
|
||||
self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True)
|
||||
|
||||
def sft_loss(self, chosen_logits: torch.FloatTensor, chosen_labels: torch.LongTensor) -> torch.Tensor:
|
||||
r"""
|
||||
Computes supervised cross-entropy loss of given labels under the given logits.
|
||||
|
||||
Returns:
|
||||
A tensor of shape (batch_size,) containing the cross-entropy loss of each samples.
|
||||
"""
|
||||
all_logps = self.get_batch_logps(chosen_logits, chosen_labels, average_log_prob=True)
|
||||
return -all_logps
|
||||
|
||||
def concatenated_forward(
|
||||
self, model: "PreTrainedModel", batch: Dict[str, torch.Tensor]
|
||||
) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
|
||||
batch_copied = BatchEncoding({k: v.detach().clone() for k, v in batch.items()}) # avoid error
|
||||
|
||||
all_logits = model(
|
||||
input_ids=batch_copied["input_ids"], attention_mask=batch_copied["attention_mask"], return_dict=True
|
||||
).logits.to(torch.float32)
|
||||
|
||||
all_logps = self.get_batch_logps(
|
||||
all_logits,
|
||||
batch["labels"],
|
||||
average_log_prob=False,
|
||||
label_pad_token_id=self.label_pad_token_id,
|
||||
)
|
||||
batch_size = batch["input_ids"].size(0) // 2
|
||||
chosen_logps, rejected_logps = all_logps.split(batch_size, dim=0)
|
||||
chosen_logits, rejected_logits = all_logits.split(batch_size, dim=0)
|
||||
return chosen_logps, rejected_logps, chosen_logits, rejected_logits
|
||||
|
||||
def get_batch_loss_metrics(
|
||||
self,
|
||||
model: "PreTrainedModel",
|
||||
batch: Dict[str, torch.Tensor],
|
||||
train_eval: Optional[Literal["train", "eval"]] = "train",
|
||||
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
|
||||
r"""
|
||||
Computes the DPO loss and other metrics for the given batch of inputs for train or test.
|
||||
"""
|
||||
metrics = {}
|
||||
(
|
||||
policy_chosen_logps,
|
||||
policy_rejected_logps,
|
||||
policy_chosen_logits,
|
||||
policy_rejected_logits,
|
||||
) = self.concatenated_forward(model, batch)
|
||||
with torch.no_grad():
|
||||
if self.ref_model is None:
|
||||
ref_model = self.model
|
||||
ref_context = self.accelerator.unwrap_model(self.model).disable_adapter()
|
||||
else:
|
||||
ref_model = self.ref_model
|
||||
ref_context = nullcontext()
|
||||
|
||||
with ref_context:
|
||||
(
|
||||
reference_chosen_logps,
|
||||
reference_rejected_logps,
|
||||
_,
|
||||
_,
|
||||
) = self.concatenated_forward(ref_model, batch)
|
||||
|
||||
losses, chosen_rewards, rejected_rewards = self.dpo_loss(
|
||||
policy_chosen_logps,
|
||||
policy_rejected_logps,
|
||||
reference_chosen_logps,
|
||||
reference_rejected_logps,
|
||||
)
|
||||
if self.ftx_gamma > 1e-6:
|
||||
batch_size = batch["input_ids"].size(0) // 2
|
||||
chosen_labels, _ = batch["labels"].split(batch_size, dim=0)
|
||||
losses += self.ftx_gamma * self.sft_loss(policy_chosen_logits, chosen_labels)
|
||||
|
||||
reward_accuracies = (chosen_rewards > rejected_rewards).float()
|
||||
|
||||
prefix = "eval_" if train_eval == "eval" else ""
|
||||
metrics[f"{prefix}rewards/chosen"] = chosen_rewards.cpu().mean()
|
||||
metrics[f"{prefix}rewards/rejected"] = rejected_rewards.cpu().mean()
|
||||
metrics[f"{prefix}rewards/accuracies"] = reward_accuracies.cpu().mean()
|
||||
metrics[f"{prefix}rewards/margins"] = (chosen_rewards - rejected_rewards).cpu().mean()
|
||||
metrics[f"{prefix}logps/rejected"] = policy_rejected_logps.detach().cpu().mean()
|
||||
metrics[f"{prefix}logps/chosen"] = policy_chosen_logps.detach().cpu().mean()
|
||||
metrics[f"{prefix}logits/rejected"] = policy_rejected_logits.detach().cpu().mean()
|
||||
metrics[f"{prefix}logits/chosen"] = policy_chosen_logits.detach().cpu().mean()
|
||||
|
||||
return losses.mean(), metrics
|
||||
84
src/llmtuner/train/dpo/workflow.py
Normal file
84
src/llmtuner/train/dpo/workflow.py
Normal file
@@ -0,0 +1,84 @@
|
||||
# Inspired by: https://github.com/huggingface/trl/blob/main/examples/research_projects/stack_llama_2/scripts/dpo_llama2.py
|
||||
|
||||
from typing import TYPE_CHECKING, List, Optional
|
||||
|
||||
from transformers import Seq2SeqTrainingArguments
|
||||
|
||||
from ...data import get_dataset, split_dataset
|
||||
from ...extras.constants import IGNORE_INDEX
|
||||
from ...extras.ploting import plot_loss
|
||||
from ...hparams import ModelArguments
|
||||
from ...model import load_model_and_tokenizer
|
||||
from ...train.dpo.collator import DPODataCollatorWithPadding
|
||||
from ...train.dpo.trainer import CustomDPOTrainer
|
||||
from ...train.utils import create_modelcard_and_push, create_ref_model
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from transformers import TrainerCallback
|
||||
|
||||
from ...hparams import DataArguments, FinetuningArguments
|
||||
|
||||
|
||||
def run_dpo(
|
||||
model_args: "ModelArguments",
|
||||
data_args: "DataArguments",
|
||||
training_args: "Seq2SeqTrainingArguments",
|
||||
finetuning_args: "FinetuningArguments",
|
||||
callbacks: Optional[List["TrainerCallback"]] = None,
|
||||
):
|
||||
model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args, training_args.do_train)
|
||||
dataset = get_dataset(tokenizer, model_args, data_args, training_args, stage="rm")
|
||||
data_collator = DPODataCollatorWithPadding(
|
||||
tokenizer=tokenizer,
|
||||
pad_to_multiple_of=8,
|
||||
label_pad_token_id=IGNORE_INDEX if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id,
|
||||
)
|
||||
|
||||
# Create reference model
|
||||
if finetuning_args.ref_model is None and (not training_args.do_train): # use the model itself
|
||||
ref_model = model
|
||||
else:
|
||||
ref_model = create_ref_model(model_args, finetuning_args)
|
||||
|
||||
# Update arguments
|
||||
training_args_dict = training_args.to_dict()
|
||||
training_args_dict.update(dict(remove_unused_columns=False)) # important for pairwise dataset
|
||||
training_args = Seq2SeqTrainingArguments(**training_args_dict)
|
||||
|
||||
# Initialize our Trainer
|
||||
trainer = CustomDPOTrainer(
|
||||
beta=finetuning_args.dpo_beta,
|
||||
loss_type=finetuning_args.dpo_loss,
|
||||
ftx_gamma=finetuning_args.dpo_ftx,
|
||||
model=model,
|
||||
ref_model=ref_model,
|
||||
args=training_args,
|
||||
tokenizer=tokenizer,
|
||||
data_collator=data_collator,
|
||||
callbacks=callbacks,
|
||||
**split_dataset(dataset, data_args, training_args),
|
||||
)
|
||||
|
||||
# Training
|
||||
if training_args.do_train:
|
||||
train_result = trainer.train(resume_from_checkpoint=training_args.resume_from_checkpoint)
|
||||
trainer.save_model()
|
||||
trainer.log_metrics("train", train_result.metrics)
|
||||
trainer.save_metrics("train", train_result.metrics)
|
||||
trainer.save_state()
|
||||
if trainer.is_world_process_zero() and finetuning_args.plot_loss:
|
||||
plot_loss(training_args.output_dir, keys=["loss", "eval_loss"])
|
||||
|
||||
# Evaluation
|
||||
if training_args.do_eval:
|
||||
metrics = trainer.evaluate(metric_key_prefix="eval")
|
||||
if id(model) == id(ref_model): # unable to compute rewards without a reference model
|
||||
remove_keys = [key for key in metrics.keys() if "rewards" in key]
|
||||
for key in remove_keys:
|
||||
metrics.pop(key)
|
||||
trainer.log_metrics("eval", metrics)
|
||||
trainer.save_metrics("eval", metrics)
|
||||
|
||||
# Create model card
|
||||
create_modelcard_and_push(trainer, model_args, data_args, training_args, finetuning_args)
|
||||
4
src/llmtuner/train/ppo/__init__.py
Normal file
4
src/llmtuner/train/ppo/__init__.py
Normal file
@@ -0,0 +1,4 @@
|
||||
from .workflow import run_ppo
|
||||
|
||||
|
||||
__all__ = ["run_ppo"]
|
||||
@@ -1,25 +1,28 @@
|
||||
import math
|
||||
import os
|
||||
import sys
|
||||
import math
|
||||
import torch
|
||||
from tqdm import tqdm
|
||||
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple
|
||||
|
||||
from transformers import BatchEncoding, GenerationConfig, Trainer, TrainerState, TrainerControl
|
||||
import torch
|
||||
from tqdm import tqdm
|
||||
from transformers import GenerationConfig, Trainer, TrainerControl, TrainerState
|
||||
from transformers.trainer_pt_utils import remove_dummy_checkpoint
|
||||
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR
|
||||
|
||||
from transformers.utils import SAFE_WEIGHTS_NAME, WEIGHTS_NAME
|
||||
from trl import PPOTrainer
|
||||
from trl.core import PPODecorators, logprobs_from_logits
|
||||
|
||||
from llmtuner.extras.callbacks import LogCallback, SavePeftModelCallback
|
||||
from llmtuner.extras.logging import get_logger
|
||||
from llmtuner.extras.misc import AverageMeter, count_parameters, get_logits_processor
|
||||
from llmtuner.tuner.ppo.utils import dump_layernorm, restore_layernorm, replace_model
|
||||
from ...extras.callbacks import FixValueHeadModelCallback, LogCallback
|
||||
from ...extras.logging import get_logger
|
||||
from ...extras.misc import AverageMeter, count_parameters, get_logits_processor
|
||||
from .utils import dump_layernorm, get_rewards_from_server, replace_model, restore_layernorm
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from transformers import Seq2SeqTrainingArguments, TrainerCallback
|
||||
from trl import AutoModelForCausalLMWithValueHead
|
||||
from llmtuner.hparams import ModelArguments, FinetuningArguments, GeneratingArguments
|
||||
|
||||
from ...hparams import FinetuningArguments, GeneratingArguments, ModelArguments
|
||||
|
||||
|
||||
logger = get_logger(__name__)
|
||||
@@ -37,36 +40,61 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
|
||||
finetuning_args: "FinetuningArguments",
|
||||
generating_args: "GeneratingArguments",
|
||||
callbacks: List["TrainerCallback"],
|
||||
**kwargs
|
||||
reward_model: "AutoModelForCausalLMWithValueHead",
|
||||
**kwargs,
|
||||
):
|
||||
PPOTrainer.__init__(self, **kwargs)
|
||||
|
||||
self.args = training_args
|
||||
self.model_args = model_args
|
||||
self.finetuning_args = finetuning_args
|
||||
self.reward_model = reward_model
|
||||
|
||||
self.generation_config = GenerationConfig(
|
||||
pad_token_id=self.tokenizer.pad_token_id,
|
||||
eos_token_id=[self.tokenizer.eos_token_id] + self.tokenizer.additional_special_tokens_ids,
|
||||
**generating_args.to_dict()
|
||||
**generating_args.to_dict(),
|
||||
)
|
||||
|
||||
self.state = TrainerState()
|
||||
self.control = TrainerControl()
|
||||
self.is_deepspeed_enabled = self.accelerator.distributed_type == "DEEPSPEED" and hasattr(
|
||||
self.accelerator.state, "deepspeed_plugin"
|
||||
)
|
||||
self.log_callback, self.save_callback = callbacks[0], callbacks[1]
|
||||
assert isinstance(self.log_callback, LogCallback) and isinstance(self.save_callback, SavePeftModelCallback)
|
||||
assert isinstance(self.log_callback, LogCallback) and isinstance(self.save_callback, FixValueHeadModelCallback)
|
||||
|
||||
if self.args.max_steps > 0:
|
||||
logger.info("max_steps is given, it will override any value given in num_train_epochs")
|
||||
|
||||
def ppo_train(self) -> None:
|
||||
if finetuning_args.reward_model_type == "full":
|
||||
if self.is_deepspeed_enabled:
|
||||
if not (
|
||||
getattr(reward_model.pretrained_model, "is_loaded_in_8bit", False)
|
||||
or getattr(reward_model.pretrained_model, "is_loaded_in_4bit", False)
|
||||
): # quantized models are already set on the correct device
|
||||
self.reward_model = self._prepare_deepspeed(self.reward_model)
|
||||
else:
|
||||
self.reward_model = self.accelerator.prepare_model(self.reward_model, evaluation_mode=True)
|
||||
|
||||
def ppo_train(self, resume_from_checkpoint: Optional[str] = None) -> None:
|
||||
r"""
|
||||
Implements training loop for the PPO stage, like _inner_training_loop() in Huggingface's Trainer.
|
||||
"""
|
||||
if resume_from_checkpoint is not None:
|
||||
raise ValueError("`resume_from_checkpoint` will be supported in the future version.")
|
||||
|
||||
total_train_batch_size = (
|
||||
self.args.per_device_train_batch_size * self.args.gradient_accumulation_steps * self.args.world_size
|
||||
self.args.per_device_train_batch_size
|
||||
* self.args.gradient_accumulation_steps
|
||||
* self.finetuning_args.ppo_buffer_size
|
||||
* self.args.world_size
|
||||
)
|
||||
if self.args.max_steps > 0:
|
||||
num_examples = total_train_batch_size * self.args.max_steps
|
||||
num_train_epochs = sys.maxsize
|
||||
max_steps = self.args.max_steps
|
||||
steps_in_epoch = self.args.max_steps * self.args.gradient_accumulation_steps
|
||||
steps_in_epoch = self.args.max_steps
|
||||
else:
|
||||
len_dataloader = len(self.dataloader)
|
||||
num_examples = len(self.dataset)
|
||||
@@ -81,13 +109,18 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
|
||||
|
||||
if self.is_world_process_zero():
|
||||
logger.info("***** Running training *****")
|
||||
logger.info(f" Num examples = {num_examples}")
|
||||
logger.info(f" Num Epochs = {num_train_epochs}")
|
||||
logger.info(f" Instantaneous batch size per device = {self.args.per_device_train_batch_size}")
|
||||
logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_train_batch_size}")
|
||||
logger.info(f" Gradient Accumulation steps = {self.args.gradient_accumulation_steps}")
|
||||
logger.info(f" Total optimization steps = {max_steps}")
|
||||
logger.info(f" Number of trainable parameters = {count_parameters(self.model)[0]}")
|
||||
logger.info(" Num examples = {}".format(num_examples))
|
||||
logger.info(" Num Epochs = {}".format(num_train_epochs))
|
||||
logger.info(" Instantaneous batch size per device = {}".format(self.args.per_device_train_batch_size))
|
||||
logger.info(
|
||||
" Total train batch size (w. parallel, buffer, distributed & accumulation) = {}".format(
|
||||
total_train_batch_size
|
||||
)
|
||||
)
|
||||
logger.info(" Gradient Accumulation steps = {}".format(self.args.gradient_accumulation_steps))
|
||||
logger.info(" Num optimization epochs per batch = {}".format(self.finetuning_args.ppo_epochs))
|
||||
logger.info(" Total training steps = {}".format(max_steps))
|
||||
logger.info(" Number of trainable parameters = {}".format(count_parameters(self.model)[0]))
|
||||
|
||||
unwrapped_model: "AutoModelForCausalLMWithValueHead" = self.accelerator.unwrap_model(self.model)
|
||||
dataiter = iter(self.dataloader)
|
||||
@@ -108,10 +141,12 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
|
||||
self.model.eval()
|
||||
|
||||
# Get inputs
|
||||
self.tokenizer.padding_side = "right" # change padding side
|
||||
self.tokenizer.padding_side = "right" # change padding side
|
||||
queries, responses, rewards = [], [], []
|
||||
for idx in range(0, self.config.batch_size, self.config.mini_batch_size):
|
||||
mini_batch_queries, mini_batch_responses = self.get_inputs(batch[idx:idx+self.config.mini_batch_size])
|
||||
mini_batch_queries, mini_batch_responses = self.get_inputs(
|
||||
batch[idx : idx + self.config.mini_batch_size]
|
||||
)
|
||||
mini_batch_rewards = self.get_rewards(mini_batch_queries, mini_batch_responses, unwrapped_model)
|
||||
queries.extend(mini_batch_queries)
|
||||
responses.extend(mini_batch_responses)
|
||||
@@ -124,7 +159,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
|
||||
|
||||
# Run PPO step
|
||||
stats = self.step(queries, responses, rewards)
|
||||
self.tokenizer.padding_side = "left" # restore padding side
|
||||
self.tokenizer.padding_side = "left" # restore padding side
|
||||
loss_meter.update(float(stats["ppo/loss/total"]), n=len(rewards))
|
||||
reward_meter.update(torch.stack(rewards).mean().item(), n=len(rewards))
|
||||
|
||||
@@ -133,18 +168,18 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
|
||||
batch["query"] = self.tokenizer.batch_decode(queries, skip_special_tokens=True)
|
||||
batch["response"] = self.tokenizer.batch_decode(responses, skip_special_tokens=True)
|
||||
self.log_stats(stats, batch, rewards)
|
||||
except:
|
||||
except Exception:
|
||||
logger.warning("Failed to save stats due to unknown errors.")
|
||||
|
||||
self.state.global_step += 1
|
||||
self.log_callback.on_step_end(self.args, self.state, self.control)
|
||||
|
||||
if self.is_local_process_zero() and (step+1) % self.args.logging_steps == 0:
|
||||
if self.is_local_process_zero() and (step + 1) % self.args.logging_steps == 0:
|
||||
logs = dict(
|
||||
loss=round(loss_meter.avg, 4),
|
||||
reward=round(reward_meter.avg, 4),
|
||||
learning_rate=stats["ppo/learning_rate"],
|
||||
epoch=round(step / steps_in_epoch, 2)
|
||||
epoch=round(step / steps_in_epoch, 2),
|
||||
)
|
||||
tqdm.write(str(logs))
|
||||
logs["step"] = step
|
||||
@@ -153,10 +188,10 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
|
||||
loss_meter.reset()
|
||||
reward_meter.reset()
|
||||
|
||||
if (step+1) % self.args.save_steps == 0: # save checkpoint
|
||||
self.save_model(os.path.join(
|
||||
self.args.output_dir, "{}-{}".format(PREFIX_CHECKPOINT_DIR, self.state.global_step)
|
||||
))
|
||||
if (step + 1) % self.args.save_steps == 0: # save checkpoint
|
||||
self.save_model(
|
||||
os.path.join(self.args.output_dir, "{}-{}".format(PREFIX_CHECKPOINT_DIR, self.state.global_step))
|
||||
)
|
||||
self.save_callback.on_save(
|
||||
self.args, self.state, self.control, model=self.accelerator.unwrap_model(self.model)
|
||||
)
|
||||
@@ -170,36 +205,40 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
|
||||
)
|
||||
|
||||
@torch.no_grad()
|
||||
def get_inputs(self, batch: BatchEncoding) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
|
||||
def get_inputs(self, batch: Dict[str, torch.Tensor]) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
|
||||
r"""
|
||||
Generates model's responses given queries.
|
||||
"""
|
||||
if self.finetuning_args.upcast_layernorm:
|
||||
if self.model_args.upcast_layernorm:
|
||||
layernorm_params = dump_layernorm(self.model)
|
||||
|
||||
if batch["input_ids"].size(0) == 1: # handle llama2 ppo with gradient accumulation > 1
|
||||
start_index = (batch["input_ids"][0] != self.tokenizer.pad_token_id).nonzero()[0].item()
|
||||
for k, v in batch.items():
|
||||
batch[k] = v[:, start_index:]
|
||||
|
||||
unwrapped_model: "AutoModelForCausalLMWithValueHead" = self.accelerator.unwrap_model(self.model)
|
||||
response: torch.Tensor = unwrapped_model.generate(
|
||||
generation_config=self.generation_config,
|
||||
logits_processor=get_logits_processor(),
|
||||
**batch
|
||||
generate_output: torch.Tensor = unwrapped_model.generate(
|
||||
generation_config=self.generation_config, logits_processor=get_logits_processor(), **batch
|
||||
)
|
||||
|
||||
if self.finetuning_args.upcast_layernorm:
|
||||
if self.model_args.upcast_layernorm:
|
||||
restore_layernorm(self.model, layernorm_params)
|
||||
|
||||
query, response = batch["input_ids"].detach().cpu(), response[:, batch["input_ids"].size(-1):].detach().cpu()
|
||||
query = batch["input_ids"].detach().cpu()
|
||||
response = generate_output[:, batch["input_ids"].size(-1) :].detach().cpu()
|
||||
queries, responses = [], []
|
||||
for i in range(len(query)):
|
||||
query_length = (query[i] != self.tokenizer.pad_token_id).nonzero()[0].item()
|
||||
query_start_index = (query[i] != self.tokenizer.pad_token_id).nonzero()[0].item()
|
||||
response_index = (response[i] != self.tokenizer.pad_token_id).nonzero()
|
||||
|
||||
if len(response_index) == 0:
|
||||
response_length = 1 # allow empty response
|
||||
response_length = 1 # allow empty response
|
||||
else:
|
||||
response_length = response_index[-1].item() + 1
|
||||
|
||||
queries.append(query[i, query_length:]) # remove padding from left
|
||||
responses.append(response[i, :response_length]) # remove padding from right
|
||||
queries.append(query[i, query_start_index:]) # remove padding from left
|
||||
responses.append(response[i, :response_length]) # remove padding from right
|
||||
|
||||
return queries, responses
|
||||
|
||||
@@ -208,27 +247,41 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
|
||||
self,
|
||||
queries: List[torch.Tensor],
|
||||
responses: List[torch.Tensor],
|
||||
unwrapped_model: "AutoModelForCausalLMWithValueHead"
|
||||
unwrapped_model: "AutoModelForCausalLMWithValueHead",
|
||||
) -> List[torch.Tensor]:
|
||||
r"""
|
||||
Computes scores using given reward model.
|
||||
|
||||
Both inputs and outputs are put on CPU.
|
||||
"""
|
||||
replace_model(unwrapped_model, target="reward")
|
||||
if self.finetuning_args.reward_model_type == "api":
|
||||
token_ids = [torch.cat((q, r), dim=-1).tolist() for q, r in zip(queries, responses)]
|
||||
messages = self.tokenizer.batch_decode(token_ids, skip_special_tokens=True)
|
||||
return get_rewards_from_server(self.reward_model, messages)
|
||||
|
||||
if self.finetuning_args.reward_model_type == "lora":
|
||||
replace_model(unwrapped_model, target="reward")
|
||||
reward_model = self.model
|
||||
else:
|
||||
reward_model = self.reward_model
|
||||
|
||||
batch = self.prepare_model_inputs(queries, responses)
|
||||
|
||||
with torch.cuda.amp.autocast(dtype=self.model_args.compute_dtype): # support bf16
|
||||
_, _, values = self.model(**batch, output_hidden_states=True, return_dict=True)
|
||||
with torch.cuda.amp.autocast(dtype=self.model_args.compute_dtype): # support bf16
|
||||
_, _, values = reward_model(**batch, output_hidden_states=True, return_dict=True)
|
||||
|
||||
if values.size(0) != batch["input_ids"].size(0): # adapt to chatglm2
|
||||
if getattr(unwrapped_model.config, "model_type", None) == "chatglm": # assume same architecture
|
||||
values = torch.transpose(values, 0, 1)
|
||||
|
||||
rewards = []
|
||||
for i in range(values.size(0)):
|
||||
end_indexes = (batch["input_ids"][i] != self.tokenizer.pad_token_id).nonzero()
|
||||
end_index = end_indexes[-1].item() if len(end_indexes) else 0
|
||||
rewards.append(values[i, end_index].float().detach().cpu()) # use fp32 type
|
||||
rewards.append(values[i, end_index].float().detach().cpu()) # use fp32 type
|
||||
|
||||
if self.finetuning_args.reward_model_type == "lora":
|
||||
replace_model(unwrapped_model, target="default")
|
||||
|
||||
replace_model(unwrapped_model, target="default")
|
||||
return rewards
|
||||
|
||||
@PPODecorators.empty_device_cache()
|
||||
@@ -239,7 +292,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
|
||||
responses: torch.Tensor,
|
||||
model_inputs: dict,
|
||||
return_logits: Optional[bool] = False,
|
||||
response_masks: Optional[torch.Tensor] = None
|
||||
response_masks: Optional[torch.Tensor] = None,
|
||||
):
|
||||
r"""
|
||||
Calculates model outputs in multiple batches.
|
||||
@@ -262,10 +315,11 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
|
||||
input_ids = input_kwargs["input_ids"]
|
||||
attention_mask = input_kwargs["attention_mask"]
|
||||
|
||||
with torch.cuda.amp.autocast(dtype=self.model_args.compute_dtype): # support bf16
|
||||
with torch.cuda.amp.autocast(dtype=self.model_args.compute_dtype): # support bf16
|
||||
logits, _, values = model(**input_kwargs)
|
||||
|
||||
if values.size(0) != input_ids.size(0): # adapt to chatglm2
|
||||
unwrapped_model: "AutoModelForCausalLMWithValueHead" = self.accelerator.unwrap_model(self.model)
|
||||
if getattr(unwrapped_model.config, "model_type", None) == "chatglm":
|
||||
values = torch.transpose(values, 0, 1)
|
||||
|
||||
logprobs = logprobs_from_logits(logits[:, :-1, :], input_ids[:, 1:])
|
||||
@@ -274,14 +328,12 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
|
||||
|
||||
for j in range(len(query_batch)):
|
||||
start = len(query_batch[j]) - 1
|
||||
if attention_mask[j, 0] == 0: # offset left padding
|
||||
if attention_mask[j, 0] == 0: # offset left padding
|
||||
start += attention_mask[j, :].nonzero()[0].item()
|
||||
end = start + len(response_batch[j])
|
||||
|
||||
if response_masks is not None:
|
||||
response_masks_batch = torch.cat(
|
||||
(torch.zeros_like(query_batch[j]), response_masks_batch[j])
|
||||
)[1:]
|
||||
response_masks_batch = torch.cat((torch.zeros_like(query_batch[j]), response_masks_batch[j]))[1:]
|
||||
|
||||
masks[j, :start] = 0
|
||||
masks[j, end:] = 0
|
||||
@@ -311,4 +363,13 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
|
||||
Subclass and override to inject custom behavior.
|
||||
"""
|
||||
if self.args.should_save:
|
||||
self._save(output_dir)
|
||||
try:
|
||||
self._save(output_dir, state_dict=self.accelerator.get_state_dict(self.model))
|
||||
except ValueError:
|
||||
logger.warning(
|
||||
" stage3_gather_16bit_weights_on_model_save=false. Saving the full checkpoint instead,"
|
||||
" use zero_to_fp32.py to recover weights"
|
||||
)
|
||||
self._save(output_dir, state_dict={})
|
||||
remove_dummy_checkpoint(True, output_dir, [WEIGHTS_NAME, SAFE_WEIGHTS_NAME])
|
||||
self.model.save_checkpoint(output_dir)
|
||||
@@ -1,22 +1,40 @@
|
||||
import json
|
||||
from typing import TYPE_CHECKING, Dict, List, Literal, Optional
|
||||
|
||||
import torch
|
||||
from typing import TYPE_CHECKING, Dict, Literal, Optional
|
||||
|
||||
from ...extras.packages import is_requests_available
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from transformers import PreTrainedModel
|
||||
from trl import AutoModelForCausalLMWithValueHead
|
||||
|
||||
if is_requests_available():
|
||||
import requests
|
||||
|
||||
|
||||
def get_rewards_from_server(server_url: str, messages: List[str]) -> List[torch.Tensor]:
|
||||
headers = {"Content-Type": "application/json"}
|
||||
payload = {"model": "model", "messages": messages}
|
||||
response = requests.post(server_url, json=payload, headers=headers)
|
||||
rewards = json.loads(response.text)["scores"]
|
||||
return torch.Tensor(rewards)
|
||||
|
||||
|
||||
def replace_model(model: "AutoModelForCausalLMWithValueHead", target: Literal["default", "reward"]) -> None:
|
||||
if target == "reward": # save default head temporarily
|
||||
if target == "reward": # save default head temporarily
|
||||
valuehead_state_dict: Dict[str, torch.Tensor] = model.v_head.state_dict()
|
||||
setattr(model, "default_head_weight", valuehead_state_dict["summary.weight"].detach().clone())
|
||||
setattr(model, "default_head_bias", valuehead_state_dict["summary.bias"].detach().clone())
|
||||
|
||||
model.pretrained_model.set_adapter(target) # set the LoRA adapter to be active
|
||||
model.v_head.load_state_dict({
|
||||
"summary.weight": model.get_buffer("{}_head_weight".format(target)).detach().clone(),
|
||||
"summary.bias": model.get_buffer("{}_head_bias".format(target)).detach().clone()
|
||||
})
|
||||
model.pretrained_model.set_adapter(target) # set the LoRA adapter to be active
|
||||
model.v_head.load_state_dict(
|
||||
{
|
||||
"summary.weight": model.get_buffer("{}_head_weight".format(target)).detach().clone(),
|
||||
"summary.bias": model.get_buffer("{}_head_bias".format(target)).detach().clone(),
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def dump_layernorm(model: "PreTrainedModel") -> Dict[str, torch.Tensor]:
|
||||
@@ -1,21 +1,26 @@
|
||||
# Inspired by: https://github.com/lvwerra/trl/blob/main/examples/research_projects/stack_llama/scripts/rl_training.py
|
||||
|
||||
import math
|
||||
from trl import PPOConfig
|
||||
from typing import TYPE_CHECKING, List, Optional
|
||||
|
||||
from torch.optim import AdamW
|
||||
from typing import TYPE_CHECKING, Optional, List
|
||||
from transformers import DataCollatorWithPadding
|
||||
from transformers.optimization import get_scheduler
|
||||
from trl import PPOConfig
|
||||
|
||||
from ...data import get_dataset
|
||||
from ...extras.callbacks import FixValueHeadModelCallback
|
||||
from ...extras.misc import fix_valuehead_checkpoint
|
||||
from ...extras.ploting import plot_loss
|
||||
from ...model import load_model_and_tokenizer
|
||||
from ...train.ppo.trainer import CustomPPOTrainer
|
||||
from ...train.utils import create_ref_model, create_reward_model
|
||||
|
||||
from llmtuner.dsets import get_dataset, preprocess_dataset
|
||||
from llmtuner.extras.callbacks import SavePeftModelCallback
|
||||
from llmtuner.extras.ploting import plot_loss
|
||||
from llmtuner.tuner.core import load_model_and_tokenizer
|
||||
from llmtuner.tuner.ppo.trainer import CustomPPOTrainer
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from transformers import Seq2SeqTrainingArguments, TrainerCallback
|
||||
from llmtuner.hparams import ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments
|
||||
|
||||
from ...hparams import DataArguments, FinetuningArguments, GeneratingArguments, ModelArguments
|
||||
|
||||
|
||||
def run_ppo(
|
||||
@@ -24,22 +29,29 @@ def run_ppo(
|
||||
training_args: "Seq2SeqTrainingArguments",
|
||||
finetuning_args: "FinetuningArguments",
|
||||
generating_args: "GeneratingArguments",
|
||||
callbacks: Optional[List["TrainerCallback"]] = None
|
||||
callbacks: Optional[List["TrainerCallback"]] = None,
|
||||
):
|
||||
dataset = get_dataset(model_args, data_args)
|
||||
model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args, training_args.do_train, stage="ppo")
|
||||
dataset = preprocess_dataset(dataset, tokenizer, data_args, training_args, stage="ppo")
|
||||
model, tokenizer = load_model_and_tokenizer(
|
||||
model_args, finetuning_args, training_args.do_train, add_valuehead=True
|
||||
)
|
||||
dataset = get_dataset(tokenizer, model_args, data_args, training_args, stage="ppo")
|
||||
|
||||
tokenizer.padding_side = "left" # use left-padding in generation while using right-padding in training
|
||||
tokenizer.padding_side = "left" # use left-padding in generation while using right-padding in training
|
||||
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)
|
||||
|
||||
# Create reference model and reward model
|
||||
ref_model = create_ref_model(model_args, finetuning_args, add_valuehead=True)
|
||||
reward_model = create_reward_model(model, model_args, finetuning_args)
|
||||
|
||||
# Create ppo config
|
||||
backward_batch_size = training_args.per_device_train_batch_size * training_args.gradient_accumulation_steps
|
||||
ppo_config = PPOConfig(
|
||||
model_name=model_args.model_name_or_path,
|
||||
learning_rate=training_args.learning_rate,
|
||||
mini_batch_size=training_args.per_device_train_batch_size,
|
||||
batch_size=training_args.per_device_train_batch_size * training_args.gradient_accumulation_steps,
|
||||
batch_size=backward_batch_size * finetuning_args.ppo_buffer_size,
|
||||
gradient_accumulation_steps=training_args.gradient_accumulation_steps,
|
||||
ppo_epochs=1,
|
||||
ppo_epochs=finetuning_args.ppo_epochs,
|
||||
max_grad_norm=training_args.max_grad_norm,
|
||||
seed=training_args.seed,
|
||||
optimize_device_cache=True,
|
||||
@@ -47,23 +59,23 @@ def run_ppo(
|
||||
log_with=finetuning_args.ppo_logger,
|
||||
use_score_scaling=finetuning_args.ppo_score_norm,
|
||||
use_score_norm=finetuning_args.ppo_score_norm,
|
||||
accelerator_kwargs={"step_scheduler_with_optimizer": False}
|
||||
whiten_rewards=finetuning_args.ppo_whiten_rewards,
|
||||
accelerator_kwargs={"step_scheduler_with_optimizer": False},
|
||||
)
|
||||
|
||||
# Create optimizer and scheduler
|
||||
optimizer = AdamW(filter(lambda p: p.requires_grad, model.parameters()), lr=training_args.learning_rate)
|
||||
if training_args.max_steps > 0:
|
||||
num_training_steps = training_args.max_steps
|
||||
else:
|
||||
total_train_batch_size = (
|
||||
training_args.per_device_train_batch_size * training_args.gradient_accumulation_steps * training_args.world_size
|
||||
)
|
||||
total_train_batch_size = backward_batch_size * finetuning_args.ppo_buffer_size * training_args.world_size
|
||||
num_training_steps = training_args.num_train_epochs * math.ceil(len(dataset) / total_train_batch_size)
|
||||
|
||||
lr_scheduler = get_scheduler(
|
||||
training_args.lr_scheduler_type,
|
||||
optimizer=optimizer,
|
||||
num_warmup_steps=training_args.get_warmup_steps(num_training_steps),
|
||||
num_training_steps=num_training_steps
|
||||
num_training_steps=num_training_steps,
|
||||
)
|
||||
|
||||
# Initialize our Trainer
|
||||
@@ -72,21 +84,24 @@ def run_ppo(
|
||||
training_args=training_args,
|
||||
finetuning_args=finetuning_args,
|
||||
generating_args=generating_args,
|
||||
callbacks=callbacks + [SavePeftModelCallback()],
|
||||
callbacks=callbacks + [FixValueHeadModelCallback()],
|
||||
reward_model=reward_model,
|
||||
config=ppo_config,
|
||||
model=model,
|
||||
ref_model=None,
|
||||
ref_model=ref_model,
|
||||
tokenizer=tokenizer,
|
||||
dataset=dataset,
|
||||
data_collator=data_collator,
|
||||
optimizer=optimizer,
|
||||
lr_scheduler=lr_scheduler
|
||||
lr_scheduler=lr_scheduler,
|
||||
)
|
||||
|
||||
# Training
|
||||
if training_args.do_train:
|
||||
ppo_trainer.ppo_train()
|
||||
ppo_trainer.ppo_train(resume_from_checkpoint=training_args.resume_from_checkpoint)
|
||||
ppo_trainer.save_model()
|
||||
ppo_trainer.save_state() # must be called after save_model to have a folder
|
||||
if ppo_trainer.is_world_process_zero() and model_args.plot_loss:
|
||||
if training_args.should_save:
|
||||
fix_valuehead_checkpoint(model, training_args.output_dir, training_args.save_safetensors)
|
||||
ppo_trainer.save_state() # must be called after save_model to have a folder
|
||||
if ppo_trainer.is_world_process_zero() and finetuning_args.plot_loss:
|
||||
plot_loss(training_args.output_dir, keys=["loss", "reward"])
|
||||
4
src/llmtuner/train/pt/__init__.py
Normal file
4
src/llmtuner/train/pt/__init__.py
Normal file
@@ -0,0 +1,4 @@
|
||||
from .workflow import run_pt
|
||||
|
||||
|
||||
__all__ = ["run_pt"]
|
||||
@@ -1,16 +1,20 @@
|
||||
# Inspired by: https://github.com/huggingface/transformers/blob/v4.34.1/examples/pytorch/language-modeling/run_clm.py
|
||||
|
||||
import math
|
||||
from typing import TYPE_CHECKING, Optional, List
|
||||
from typing import TYPE_CHECKING, List, Optional
|
||||
|
||||
from transformers import DataCollatorForLanguageModeling, Trainer
|
||||
|
||||
from llmtuner.dsets import get_dataset, preprocess_dataset, split_dataset
|
||||
from llmtuner.extras.ploting import plot_loss
|
||||
from llmtuner.tuner.core import generate_model_card, load_model_and_tokenizer
|
||||
from ...data import get_dataset, split_dataset
|
||||
from ...extras.ploting import plot_loss
|
||||
from ...model import load_model_and_tokenizer
|
||||
from ...train.utils import create_modelcard_and_push
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from transformers import Seq2SeqTrainingArguments, TrainerCallback
|
||||
from llmtuner.hparams import ModelArguments, DataArguments, FinetuningArguments
|
||||
|
||||
from ...hparams import DataArguments, FinetuningArguments, ModelArguments
|
||||
|
||||
|
||||
def run_pt(
|
||||
@@ -18,11 +22,10 @@ def run_pt(
|
||||
data_args: "DataArguments",
|
||||
training_args: "Seq2SeqTrainingArguments",
|
||||
finetuning_args: "FinetuningArguments",
|
||||
callbacks: Optional[List["TrainerCallback"]] = None
|
||||
callbacks: Optional[List["TrainerCallback"]] = None,
|
||||
):
|
||||
dataset = get_dataset(model_args, data_args)
|
||||
model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args, training_args.do_train, stage="pt")
|
||||
dataset = preprocess_dataset(dataset, tokenizer, data_args, training_args, stage="pt")
|
||||
model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args, training_args.do_train)
|
||||
dataset = get_dataset(tokenizer, model_args, data_args, training_args, stage="pt")
|
||||
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
|
||||
|
||||
# Initialize our Trainer
|
||||
@@ -32,7 +35,7 @@ def run_pt(
|
||||
tokenizer=tokenizer,
|
||||
data_collator=data_collator,
|
||||
callbacks=callbacks,
|
||||
**split_dataset(dataset, data_args, training_args)
|
||||
**split_dataset(dataset, data_args, training_args),
|
||||
)
|
||||
|
||||
# Training
|
||||
@@ -42,7 +45,7 @@ def run_pt(
|
||||
trainer.log_metrics("train", train_result.metrics)
|
||||
trainer.save_metrics("train", train_result.metrics)
|
||||
trainer.save_state()
|
||||
if trainer.is_world_process_zero() and model_args.plot_loss:
|
||||
if trainer.is_world_process_zero() and finetuning_args.plot_loss:
|
||||
plot_loss(training_args.output_dir, keys=["loss", "eval_loss"])
|
||||
|
||||
# Evaluation
|
||||
@@ -58,8 +61,4 @@ def run_pt(
|
||||
trainer.save_metrics("eval", metrics)
|
||||
|
||||
# Create model card
|
||||
if training_args.do_train:
|
||||
if training_args.push_to_hub:
|
||||
trainer.push_to_hub(**generate_model_card(model_args, data_args, finetuning_args))
|
||||
else:
|
||||
trainer.create_model_card(**generate_model_card(model_args, data_args, finetuning_args))
|
||||
create_modelcard_and_push(trainer, model_args, data_args, training_args, finetuning_args)
|
||||
4
src/llmtuner/train/rm/__init__.py
Normal file
4
src/llmtuner/train/rm/__init__.py
Normal file
@@ -0,0 +1,4 @@
|
||||
from .workflow import run_rm
|
||||
|
||||
|
||||
__all__ = ["run_rm"]
|
||||
@@ -1,6 +1,7 @@
|
||||
import torch
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, Sequence
|
||||
|
||||
import torch
|
||||
from transformers import DataCollatorWithPadding
|
||||
|
||||
|
||||
@@ -20,8 +21,9 @@ class PairwiseDataCollatorWithPadding(DataCollatorWithPadding):
|
||||
features = [
|
||||
{
|
||||
"input_ids": feature["prompt_ids"] + feature[key],
|
||||
"attention_mask": [1] * (len(feature["prompt_ids"]) + len(feature[key]))
|
||||
"attention_mask": [1] * (len(feature["prompt_ids"]) + len(feature[key])),
|
||||
}
|
||||
for key in ("chosen_ids", "rejected_ids") for feature in features
|
||||
for key in ("chosen_ids", "rejected_ids")
|
||||
for feature in features
|
||||
]
|
||||
return super().__call__(features)
|
||||
@@ -1,6 +1,7 @@
|
||||
import numpy as np
|
||||
from typing import Dict, Sequence, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
|
||||
|
||||
def compute_accuracy(eval_preds: Sequence[Union[np.ndarray, Tuple[np.ndarray]]]) -> Dict[str, float]:
|
||||
preds, _ = eval_preds
|
||||
@@ -1,14 +1,16 @@
|
||||
import os
|
||||
import json
|
||||
import torch
|
||||
import os
|
||||
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
from transformers import Trainer
|
||||
|
||||
from llmtuner.extras.logging import get_logger
|
||||
from ...extras.logging import get_logger
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from transformers.trainer import PredictionOutput
|
||||
from transformers.modeling_utils import PreTrainedModel
|
||||
from transformers.trainer import PredictionOutput
|
||||
|
||||
|
||||
logger = get_logger(__name__)
|
||||
@@ -21,13 +23,10 @@ class PairwiseTrainer(Trainer):
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.can_return_loss = True # override property to return eval_loss
|
||||
self.can_return_loss = True # override property to return eval_loss
|
||||
|
||||
def compute_loss(
|
||||
self,
|
||||
model: "PreTrainedModel",
|
||||
inputs: Dict[str, torch.Tensor],
|
||||
return_outputs: Optional[bool] = False
|
||||
self, model: "PreTrainedModel", inputs: Dict[str, torch.Tensor], return_outputs: Optional[bool] = False
|
||||
) -> Union[torch.Tensor, Tuple[torch.Tensor, List[torch.Tensor]]]:
|
||||
r"""
|
||||
Computes pairwise loss. The first n examples are chosen and the last n examples are rejected.
|
||||
@@ -39,7 +38,9 @@ class PairwiseTrainer(Trainer):
|
||||
"""
|
||||
# Compute rewards
|
||||
_, _, values = model(**inputs, output_hidden_states=True, return_dict=True)
|
||||
if values.size(0) != inputs["input_ids"].size(0): # adapt to chatglm2
|
||||
|
||||
unwrapped_model: "PreTrainedModel" = self.accelerator.unwrap_model(self.model)
|
||||
if getattr(unwrapped_model.config, "model_type", None) == "chatglm":
|
||||
values = torch.transpose(values, 0, 1)
|
||||
|
||||
# Split the inputs and rewards into two parts, chosen and rejected
|
||||
@@ -66,9 +67,9 @@ class PairwiseTrainer(Trainer):
|
||||
assert div_index > 0
|
||||
chosen_trunc_rewards = chosen_rewards[i, div_index:end_index]
|
||||
rejected_trunc_rewards = rejected_rewards[i, div_index:end_index]
|
||||
if return_outputs: # use the score on the last token except pad token for inference
|
||||
chosen_scores.append(chosen_rewards[i, chosen_length-1])
|
||||
rejected_scores.append(rejected_rewards[i, rejected_length-1])
|
||||
if return_outputs: # use the score on the last token except pad token for inference
|
||||
chosen_scores.append(chosen_rewards[i, chosen_length - 1])
|
||||
rejected_scores.append(rejected_rewards[i, rejected_length - 1])
|
||||
loss += -torch.nn.functional.logsigmoid(chosen_trunc_rewards - rejected_trunc_rewards).mean()
|
||||
|
||||
loss = loss / batch_size
|
||||
@@ -78,10 +79,7 @@ class PairwiseTrainer(Trainer):
|
||||
|
||||
return loss
|
||||
|
||||
def save_predictions(
|
||||
self,
|
||||
predict_results: "PredictionOutput"
|
||||
) -> None:
|
||||
def save_predictions(self, predict_results: "PredictionOutput") -> None:
|
||||
r"""
|
||||
Saves model predictions to `output_dir`.
|
||||
|
||||
@@ -1,19 +1,24 @@
|
||||
# Inspired by: https://github.com/CarperAI/trlx/blob/main/examples/summarize_rlhf/reward_model/train_reward_model_gptj.py
|
||||
|
||||
from typing import TYPE_CHECKING, Optional, List
|
||||
from typing import TYPE_CHECKING, List, Optional
|
||||
|
||||
from transformers import Seq2SeqTrainingArguments
|
||||
|
||||
from llmtuner.dsets import get_dataset, preprocess_dataset, split_dataset
|
||||
from llmtuner.extras.callbacks import SavePeftModelCallback
|
||||
from llmtuner.extras.ploting import plot_loss
|
||||
from llmtuner.tuner.core import generate_model_card, load_model_and_tokenizer
|
||||
from llmtuner.tuner.rm.metric import compute_accuracy
|
||||
from llmtuner.tuner.rm.collator import PairwiseDataCollatorWithPadding
|
||||
from llmtuner.tuner.rm.trainer import PairwiseTrainer
|
||||
from ...data import get_dataset, split_dataset
|
||||
from ...extras.callbacks import FixValueHeadModelCallback
|
||||
from ...extras.misc import fix_valuehead_checkpoint
|
||||
from ...extras.ploting import plot_loss
|
||||
from ...model import load_model_and_tokenizer
|
||||
from ...train.rm.collator import PairwiseDataCollatorWithPadding
|
||||
from ...train.rm.metric import compute_accuracy
|
||||
from ...train.rm.trainer import PairwiseTrainer
|
||||
from ...train.utils import create_modelcard_and_push
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from transformers import TrainerCallback
|
||||
from llmtuner.hparams import ModelArguments, DataArguments, FinetuningArguments
|
||||
|
||||
from ...hparams import DataArguments, FinetuningArguments, ModelArguments
|
||||
|
||||
|
||||
def run_rm(
|
||||
@@ -21,16 +26,17 @@ def run_rm(
|
||||
data_args: "DataArguments",
|
||||
training_args: "Seq2SeqTrainingArguments",
|
||||
finetuning_args: "FinetuningArguments",
|
||||
callbacks: Optional[List["TrainerCallback"]] = None
|
||||
callbacks: Optional[List["TrainerCallback"]] = None,
|
||||
):
|
||||
dataset = get_dataset(model_args, data_args)
|
||||
model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args, training_args.do_train, stage="rm")
|
||||
dataset = preprocess_dataset(dataset, tokenizer, data_args, training_args, stage="rm")
|
||||
data_collator = PairwiseDataCollatorWithPadding(tokenizer, pad_to_multiple_of=4)
|
||||
model, tokenizer = load_model_and_tokenizer(
|
||||
model_args, finetuning_args, training_args.do_train, add_valuehead=True
|
||||
)
|
||||
dataset = get_dataset(tokenizer, model_args, data_args, training_args, stage="rm")
|
||||
data_collator = PairwiseDataCollatorWithPadding(tokenizer, pad_to_multiple_of=8)
|
||||
|
||||
# Update arguments
|
||||
training_args_dict = training_args.to_dict()
|
||||
training_args_dict.update(dict(remove_unused_columns=False)) # important for pairwise dataset
|
||||
training_args_dict.update(dict(remove_unused_columns=False)) # important for pairwise dataset
|
||||
training_args = Seq2SeqTrainingArguments(**training_args_dict)
|
||||
|
||||
# Initialize our Trainer
|
||||
@@ -39,19 +45,21 @@ def run_rm(
|
||||
args=training_args,
|
||||
tokenizer=tokenizer,
|
||||
data_collator=data_collator,
|
||||
callbacks=callbacks + [SavePeftModelCallback()],
|
||||
callbacks=callbacks + [FixValueHeadModelCallback()],
|
||||
compute_metrics=compute_accuracy,
|
||||
**split_dataset(dataset, data_args, training_args)
|
||||
**split_dataset(dataset, data_args, training_args),
|
||||
)
|
||||
|
||||
# Training
|
||||
if training_args.do_train:
|
||||
train_result = trainer.train()
|
||||
train_result = trainer.train(resume_from_checkpoint=training_args.resume_from_checkpoint)
|
||||
trainer.save_model()
|
||||
if training_args.should_save:
|
||||
fix_valuehead_checkpoint(model, training_args.output_dir, training_args.save_safetensors)
|
||||
trainer.log_metrics("train", train_result.metrics)
|
||||
trainer.save_metrics("train", train_result.metrics)
|
||||
trainer.save_state()
|
||||
if trainer.is_world_process_zero() and model_args.plot_loss:
|
||||
if trainer.is_world_process_zero() and finetuning_args.plot_loss:
|
||||
plot_loss(training_args.output_dir, keys=["loss", "eval_loss"])
|
||||
|
||||
# Evaluation
|
||||
@@ -68,8 +76,4 @@ def run_rm(
|
||||
trainer.save_predictions(predict_results)
|
||||
|
||||
# Create model card
|
||||
if training_args.do_train:
|
||||
if training_args.push_to_hub:
|
||||
trainer.push_to_hub(**generate_model_card(model_args, data_args, finetuning_args))
|
||||
else:
|
||||
trainer.create_model_card(**generate_model_card(model_args, data_args, finetuning_args))
|
||||
create_modelcard_and_push(trainer, model_args, data_args, training_args, finetuning_args)
|
||||
4
src/llmtuner/train/sft/__init__.py
Normal file
4
src/llmtuner/train/sft/__init__.py
Normal file
@@ -0,0 +1,4 @@
|
||||
from .workflow import run_sft
|
||||
|
||||
|
||||
__all__ = ["run_sft"]
|
||||
@@ -1,16 +1,24 @@
|
||||
import numpy as np
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, Dict, Sequence, Tuple, Union
|
||||
|
||||
import jieba
|
||||
from rouge_chinese import Rouge
|
||||
from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction
|
||||
import numpy as np
|
||||
|
||||
from ...extras.constants import IGNORE_INDEX
|
||||
from ...extras.packages import is_jieba_available, is_nltk_available, is_rouge_available
|
||||
|
||||
from llmtuner.extras.constants import IGNORE_INDEX
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from transformers.tokenization_utils import PreTrainedTokenizer
|
||||
|
||||
if is_jieba_available():
|
||||
import jieba
|
||||
|
||||
if is_nltk_available():
|
||||
from nltk.translate.bleu_score import SmoothingFunction, sentence_bleu
|
||||
|
||||
if is_rouge_available():
|
||||
from rouge_chinese import Rouge
|
||||
|
||||
|
||||
@dataclass
|
||||
class ComputeMetrics:
|
||||
@@ -1,13 +1,15 @@
|
||||
import os
|
||||
import json
|
||||
import torch
|
||||
import numpy as np
|
||||
import torch.nn as nn
|
||||
import os
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from transformers import Seq2SeqTrainer
|
||||
|
||||
from llmtuner.extras.constants import IGNORE_INDEX
|
||||
from llmtuner.extras.logging import get_logger
|
||||
from ...extras.constants import IGNORE_INDEX
|
||||
from ...extras.logging import get_logger
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from transformers.trainer import PredictionOutput
|
||||
@@ -33,16 +35,16 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer):
|
||||
|
||||
Subclass and override to inject custom behavior.
|
||||
"""
|
||||
labels = inputs["labels"].detach().clone() if "labels" in inputs else None # backup labels
|
||||
labels = inputs["labels"].detach().clone() if "labels" in inputs else None # backup labels
|
||||
if self.args.predict_with_generate:
|
||||
assert self.tokenizer.padding_side == "left", "This method only accepts left-padded tensor."
|
||||
prompt_len, label_len = inputs["input_ids"].size(-1), inputs["labels"].size(-1)
|
||||
if prompt_len > label_len:
|
||||
inputs["labels"] = self._pad_tensors_to_target_len(inputs["labels"], inputs["input_ids"])
|
||||
if label_len > prompt_len:
|
||||
inputs["labels"] = inputs["labels"][:, :prompt_len] # truncate the labels instead of padding the inputs
|
||||
if label_len > prompt_len: # truncate the labels instead of padding the inputs (llama2 fp16 compatibility)
|
||||
inputs["labels"] = inputs["labels"][:, :prompt_len]
|
||||
|
||||
loss, generated_tokens, _ = super().prediction_step(
|
||||
loss, generated_tokens, _ = super().prediction_step( # ignore the returned labels (may be truncated)
|
||||
model, inputs, prediction_loss_only=prediction_loss_only, ignore_keys=ignore_keys
|
||||
)
|
||||
if generated_tokens is not None and self.args.predict_with_generate:
|
||||
@@ -51,23 +53,16 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer):
|
||||
|
||||
return loss, generated_tokens, labels
|
||||
|
||||
def _pad_tensors_to_target_len(
|
||||
self,
|
||||
src_tensor: torch.Tensor,
|
||||
tgt_tensor: torch.Tensor
|
||||
) -> torch.Tensor:
|
||||
def _pad_tensors_to_target_len(self, src_tensor: torch.Tensor, tgt_tensor: torch.Tensor) -> torch.Tensor:
|
||||
r"""
|
||||
Pads the tensor to the same length as the target tensor.
|
||||
"""
|
||||
assert self.tokenizer.pad_token_id is not None, "Pad token is required."
|
||||
padded_tensor = self.tokenizer.pad_token_id * torch.ones_like(tgt_tensor)
|
||||
padded_tensor[:, -src_tensor.shape[-1]:] = src_tensor # adopt left-padding
|
||||
return padded_tensor.contiguous() # in contiguous memory
|
||||
padded_tensor[:, -src_tensor.shape[-1] :] = src_tensor # adopt left-padding
|
||||
return padded_tensor.contiguous() # in contiguous memory
|
||||
|
||||
def save_predictions(
|
||||
self,
|
||||
predict_results: "PredictionOutput"
|
||||
) -> None:
|
||||
def save_predictions(self, predict_results: "PredictionOutput") -> None:
|
||||
r"""
|
||||
Saves model predictions to `output_dir`.
|
||||
|
||||
@@ -79,14 +74,27 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer):
|
||||
output_prediction_file = os.path.join(self.args.output_dir, "generated_predictions.jsonl")
|
||||
logger.info(f"Saving prediction results to {output_prediction_file}")
|
||||
|
||||
preds = np.where(predict_results.predictions != IGNORE_INDEX, predict_results.predictions, self.tokenizer.pad_token_id)
|
||||
labels = np.where(predict_results.label_ids != IGNORE_INDEX, predict_results.label_ids, self.tokenizer.pad_token_id)
|
||||
labels = np.where(
|
||||
predict_results.label_ids != IGNORE_INDEX, predict_results.label_ids, self.tokenizer.pad_token_id
|
||||
)
|
||||
preds = np.where(
|
||||
predict_results.predictions != IGNORE_INDEX, predict_results.predictions, self.tokenizer.pad_token_id
|
||||
)
|
||||
|
||||
for i in range(len(preds)):
|
||||
pad_len = np.nonzero(preds[i] != self.tokenizer.pad_token_id)[0]
|
||||
if len(pad_len):
|
||||
preds[i] = np.concatenate(
|
||||
(preds[i][pad_len[0] :], preds[i][: pad_len[0]]), axis=-1
|
||||
) # move pad token to last
|
||||
|
||||
decoded_labels = self.tokenizer.batch_decode(
|
||||
labels, skip_special_tokens=True, clean_up_tokenization_spaces=False
|
||||
)
|
||||
decoded_preds = self.tokenizer.batch_decode(preds, skip_special_tokens=True, clean_up_tokenization_spaces=True)
|
||||
decoded_labels = self.tokenizer.batch_decode(labels, skip_special_tokens=True, clean_up_tokenization_spaces=True)
|
||||
|
||||
with open(output_prediction_file, "w", encoding="utf-8") as writer:
|
||||
res: List[str] = []
|
||||
for pred, label in zip(decoded_preds, decoded_labels):
|
||||
for label, pred in zip(decoded_labels, decoded_preds):
|
||||
res.append(json.dumps({"label": label, "predict": pred}, ensure_ascii=False))
|
||||
writer.write("\n".join(res))
|
||||
@@ -1,19 +1,23 @@
|
||||
# Inspired by: https://github.com/huggingface/transformers/blob/v4.34.1/examples/pytorch/summarization/run_summarization.py
|
||||
|
||||
from typing import TYPE_CHECKING, Optional, List
|
||||
from typing import TYPE_CHECKING, List, Optional
|
||||
|
||||
from transformers import DataCollatorForSeq2Seq, Seq2SeqTrainingArguments
|
||||
|
||||
from llmtuner.dsets import get_dataset, preprocess_dataset, split_dataset
|
||||
from llmtuner.extras.constants import IGNORE_INDEX
|
||||
from llmtuner.extras.misc import get_logits_processor
|
||||
from llmtuner.extras.ploting import plot_loss
|
||||
from llmtuner.tuner.core import generate_model_card, load_model_and_tokenizer
|
||||
from llmtuner.tuner.sft.metric import ComputeMetrics
|
||||
from llmtuner.tuner.sft.trainer import CustomSeq2SeqTrainer
|
||||
from ...data import get_dataset, split_dataset
|
||||
from ...extras.constants import IGNORE_INDEX
|
||||
from ...extras.misc import get_logits_processor
|
||||
from ...extras.ploting import plot_loss
|
||||
from ...model import load_model_and_tokenizer
|
||||
from ...train.sft.metric import ComputeMetrics
|
||||
from ...train.sft.trainer import CustomSeq2SeqTrainer
|
||||
from ...train.utils import create_modelcard_and_push
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from transformers import TrainerCallback
|
||||
from llmtuner.hparams import ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments
|
||||
|
||||
from ...hparams import DataArguments, FinetuningArguments, GeneratingArguments, ModelArguments
|
||||
|
||||
|
||||
def run_sft(
|
||||
@@ -22,27 +26,31 @@ def run_sft(
|
||||
training_args: "Seq2SeqTrainingArguments",
|
||||
finetuning_args: "FinetuningArguments",
|
||||
generating_args: "GeneratingArguments",
|
||||
callbacks: Optional[List["TrainerCallback"]] = None
|
||||
callbacks: Optional[List["TrainerCallback"]] = None,
|
||||
):
|
||||
dataset = get_dataset(model_args, data_args)
|
||||
model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args, training_args.do_train, stage="sft")
|
||||
dataset = preprocess_dataset(dataset, tokenizer, data_args, training_args, stage="sft")
|
||||
model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args, training_args.do_train)
|
||||
dataset = get_dataset(tokenizer, model_args, data_args, training_args, stage="sft")
|
||||
|
||||
if training_args.predict_with_generate:
|
||||
tokenizer.padding_side = "left" # use left-padding in generation
|
||||
tokenizer.padding_side = "left" # use left-padding in generation
|
||||
|
||||
if getattr(model, "is_quantized", False) and not training_args.do_train:
|
||||
setattr(model, "_hf_peft_config_loaded", True) # hack here: make model compatible with prediction
|
||||
|
||||
data_collator = DataCollatorForSeq2Seq(
|
||||
tokenizer=tokenizer,
|
||||
pad_to_multiple_of=4 if tokenizer.padding_side == "right" else None, # for shift short attention
|
||||
label_pad_token_id=IGNORE_INDEX if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id
|
||||
pad_to_multiple_of=8 if tokenizer.padding_side == "right" else None, # for shift short attention
|
||||
label_pad_token_id=IGNORE_INDEX if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id,
|
||||
)
|
||||
|
||||
# Override the decoding parameters of Seq2SeqTrainer
|
||||
training_args_dict = training_args.to_dict()
|
||||
training_args_dict.update(dict(
|
||||
generation_max_length=training_args.generation_max_length or data_args.cutoff_len,
|
||||
generation_num_beams=data_args.eval_num_beams or training_args.generation_num_beams
|
||||
))
|
||||
training_args_dict.update(
|
||||
dict(
|
||||
generation_max_length=training_args.generation_max_length or data_args.cutoff_len,
|
||||
generation_num_beams=data_args.eval_num_beams or training_args.generation_num_beams,
|
||||
)
|
||||
)
|
||||
training_args = Seq2SeqTrainingArguments(**training_args_dict)
|
||||
|
||||
# Initialize our Trainer
|
||||
@@ -53,7 +61,7 @@ def run_sft(
|
||||
data_collator=data_collator,
|
||||
callbacks=callbacks,
|
||||
compute_metrics=ComputeMetrics(tokenizer) if training_args.predict_with_generate else None,
|
||||
**split_dataset(dataset, data_args, training_args)
|
||||
**split_dataset(dataset, data_args, training_args),
|
||||
)
|
||||
|
||||
# Keyword arguments for `model.generate`
|
||||
@@ -69,13 +77,13 @@ def run_sft(
|
||||
trainer.log_metrics("train", train_result.metrics)
|
||||
trainer.save_metrics("train", train_result.metrics)
|
||||
trainer.save_state()
|
||||
if trainer.is_world_process_zero() and model_args.plot_loss:
|
||||
if trainer.is_world_process_zero() and finetuning_args.plot_loss:
|
||||
plot_loss(training_args.output_dir, keys=["loss", "eval_loss"])
|
||||
|
||||
# Evaluation
|
||||
if training_args.do_eval:
|
||||
metrics = trainer.evaluate(metric_key_prefix="eval", **gen_kwargs)
|
||||
if training_args.predict_with_generate: # eval_loss will be wrong if predict_with_generate is enabled
|
||||
if training_args.predict_with_generate: # eval_loss will be wrong if predict_with_generate is enabled
|
||||
metrics.pop("eval_loss", None)
|
||||
trainer.log_metrics("eval", metrics)
|
||||
trainer.save_metrics("eval", metrics)
|
||||
@@ -83,15 +91,11 @@ def run_sft(
|
||||
# Predict
|
||||
if training_args.do_predict:
|
||||
predict_results = trainer.predict(dataset, metric_key_prefix="predict", **gen_kwargs)
|
||||
if training_args.predict_with_generate: # predict_loss will be wrong if predict_with_generate is enabled
|
||||
if training_args.predict_with_generate: # predict_loss will be wrong if predict_with_generate is enabled
|
||||
predict_results.metrics.pop("predict_loss", None)
|
||||
trainer.log_metrics("predict", predict_results.metrics)
|
||||
trainer.save_metrics("predict", predict_results.metrics)
|
||||
trainer.save_predictions(predict_results)
|
||||
|
||||
# Create model card
|
||||
if training_args.do_train:
|
||||
if training_args.push_to_hub:
|
||||
trainer.push_to_hub(**generate_model_card(model_args, data_args, finetuning_args))
|
||||
else:
|
||||
trainer.create_model_card(**generate_model_card(model_args, data_args, finetuning_args))
|
||||
create_modelcard_and_push(trainer, model_args, data_args, training_args, finetuning_args)
|
||||
90
src/llmtuner/train/tuner.py
Normal file
90
src/llmtuner/train/tuner.py
Normal file
@@ -0,0 +1,90 @@
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional
|
||||
|
||||
import torch
|
||||
from transformers import PreTrainedModel
|
||||
|
||||
from ..extras.callbacks import LogCallback
|
||||
from ..extras.logging import get_logger
|
||||
from ..hparams import get_infer_args, get_train_args
|
||||
from ..model import load_model_and_tokenizer
|
||||
from .dpo import run_dpo
|
||||
from .ppo import run_ppo
|
||||
from .pt import run_pt
|
||||
from .rm import run_rm
|
||||
from .sft import run_sft
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from transformers import TrainerCallback
|
||||
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
def run_exp(args: Optional[Dict[str, Any]] = None, callbacks: Optional[List["TrainerCallback"]] = None):
|
||||
model_args, data_args, training_args, finetuning_args, generating_args = get_train_args(args)
|
||||
callbacks = [LogCallback()] if callbacks is None else callbacks
|
||||
|
||||
if finetuning_args.stage == "pt":
|
||||
run_pt(model_args, data_args, training_args, finetuning_args, callbacks)
|
||||
elif finetuning_args.stage == "sft":
|
||||
run_sft(model_args, data_args, training_args, finetuning_args, generating_args, callbacks)
|
||||
elif finetuning_args.stage == "rm":
|
||||
run_rm(model_args, data_args, training_args, finetuning_args, callbacks)
|
||||
elif finetuning_args.stage == "ppo":
|
||||
run_ppo(model_args, data_args, training_args, finetuning_args, generating_args, callbacks)
|
||||
elif finetuning_args.stage == "dpo":
|
||||
run_dpo(model_args, data_args, training_args, finetuning_args, callbacks)
|
||||
else:
|
||||
raise ValueError("Unknown task.")
|
||||
|
||||
|
||||
def export_model(args: Optional[Dict[str, Any]] = None):
|
||||
model_args, _, finetuning_args, _ = get_infer_args(args)
|
||||
|
||||
if model_args.export_dir is None:
|
||||
raise ValueError("Please specify `export_dir`.")
|
||||
|
||||
if model_args.adapter_name_or_path is not None and model_args.export_quantization_bit is not None:
|
||||
raise ValueError("Please merge adapters before quantizing the model.")
|
||||
|
||||
model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args)
|
||||
|
||||
if getattr(model, "quantization_method", None) and model_args.adapter_name_or_path is not None:
|
||||
raise ValueError("Cannot merge adapters to a quantized model.")
|
||||
|
||||
if not isinstance(model, PreTrainedModel):
|
||||
raise ValueError("The model is not a `PreTrainedModel`, export aborted.")
|
||||
|
||||
setattr(model.config, "use_cache", True)
|
||||
if getattr(model.config, "torch_dtype", None) == "bfloat16":
|
||||
model = model.to(torch.bfloat16).to("cpu")
|
||||
else:
|
||||
model = model.to(torch.float16).to("cpu")
|
||||
setattr(model.config, "torch_dtype", "float16")
|
||||
|
||||
model.save_pretrained(
|
||||
save_directory=model_args.export_dir,
|
||||
max_shard_size="{}GB".format(model_args.export_size),
|
||||
safe_serialization=(not model_args.export_legacy_format),
|
||||
)
|
||||
if model_args.export_hub_model_id is not None:
|
||||
model.push_to_hub(
|
||||
model_args.export_hub_model_id,
|
||||
token=model_args.hf_hub_token,
|
||||
max_shard_size="{}GB".format(model_args.export_size),
|
||||
safe_serialization=(not model_args.export_legacy_format),
|
||||
)
|
||||
|
||||
try:
|
||||
tokenizer.padding_side = "left" # restore padding side
|
||||
tokenizer.init_kwargs["padding_side"] = "left"
|
||||
tokenizer.save_pretrained(model_args.export_dir)
|
||||
if model_args.export_hub_model_id is not None:
|
||||
tokenizer.push_to_hub(model_args.export_hub_model_id, token=model_args.hf_hub_token)
|
||||
except Exception:
|
||||
logger.warning("Cannot save tokenizer, please copy the files manually.")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
run_exp()
|
||||
116
src/llmtuner/train/utils.py
Normal file
116
src/llmtuner/train/utils.py
Normal file
@@ -0,0 +1,116 @@
|
||||
from typing import TYPE_CHECKING, Optional, Union
|
||||
|
||||
import torch
|
||||
|
||||
from ..extras.logging import get_logger
|
||||
from ..hparams import FinetuningArguments, ModelArguments
|
||||
from ..model import get_modelcard_args, load_model_and_tokenizer, load_valuehead_params
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from transformers import Seq2SeqTrainingArguments, Trainer
|
||||
from transformers.modeling_utils import PreTrainedModel
|
||||
from trl import AutoModelForCausalLMWithValueHead
|
||||
|
||||
from ..hparams import DataArguments
|
||||
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
def create_modelcard_and_push(
|
||||
trainer: "Trainer",
|
||||
model_args: "ModelArguments",
|
||||
data_args: "DataArguments",
|
||||
training_args: "Seq2SeqTrainingArguments",
|
||||
finetuning_args: "FinetuningArguments",
|
||||
) -> None:
|
||||
if training_args.do_train:
|
||||
if training_args.push_to_hub:
|
||||
trainer.push_to_hub(**get_modelcard_args(model_args, data_args, finetuning_args))
|
||||
return
|
||||
try:
|
||||
trainer.create_model_card(**get_modelcard_args(model_args, data_args, finetuning_args))
|
||||
except Exception as err:
|
||||
logger.warning("Failed to create model card: {}".format(str(err)))
|
||||
|
||||
|
||||
def create_ref_model(
|
||||
model_args: "ModelArguments", finetuning_args: "FinetuningArguments", add_valuehead: Optional[bool] = False
|
||||
) -> Union["PreTrainedModel", "AutoModelForCausalLMWithValueHead"]:
|
||||
r"""
|
||||
Creates reference model for PPO/DPO training. Evaluation mode is not supported.
|
||||
|
||||
The valuehead parameter is randomly initialized since it is useless for PPO training.
|
||||
"""
|
||||
if finetuning_args.ref_model is not None:
|
||||
ref_model_args_dict = model_args.to_dict()
|
||||
ref_model_args_dict.update(
|
||||
dict(
|
||||
model_name_or_path=finetuning_args.ref_model,
|
||||
adapter_name_or_path=finetuning_args.ref_model_adapters,
|
||||
quantization_bit=finetuning_args.ref_model_quantization_bit,
|
||||
)
|
||||
)
|
||||
ref_model_args = ModelArguments(**ref_model_args_dict)
|
||||
ref_finetuning_args = FinetuningArguments(finetuning_type="lora")
|
||||
ref_model, _ = load_model_and_tokenizer(
|
||||
ref_model_args, ref_finetuning_args, is_trainable=False, add_valuehead=add_valuehead
|
||||
)
|
||||
logger.info("Created reference model from {}".format(finetuning_args.ref_model))
|
||||
else:
|
||||
if finetuning_args.finetuning_type == "lora":
|
||||
ref_model = None
|
||||
else:
|
||||
ref_model, _ = load_model_and_tokenizer(
|
||||
model_args, finetuning_args, is_trainable=False, add_valuehead=add_valuehead
|
||||
)
|
||||
logger.info("Created reference model from the model itself.")
|
||||
|
||||
return ref_model
|
||||
|
||||
|
||||
def create_reward_model(
|
||||
model: "AutoModelForCausalLMWithValueHead", model_args: "ModelArguments", finetuning_args: "FinetuningArguments"
|
||||
) -> "AutoModelForCausalLMWithValueHead":
|
||||
r"""
|
||||
Creates reward model for PPO training.
|
||||
"""
|
||||
if finetuning_args.reward_model_type == "api":
|
||||
assert finetuning_args.reward_model.startswith("http"), "Please provide full url."
|
||||
logger.info("Use reward server {}".format(finetuning_args.reward_model))
|
||||
return finetuning_args.reward_model
|
||||
elif finetuning_args.reward_model_type == "lora":
|
||||
model.pretrained_model.load_adapter(finetuning_args.reward_model, "reward")
|
||||
for name, param in model.named_parameters(): # https://github.com/huggingface/peft/issues/1090
|
||||
if "default" in name:
|
||||
param.data = param.data.to(torch.float32) # trainable params should in fp32
|
||||
vhead_params = load_valuehead_params(finetuning_args.reward_model, model_args)
|
||||
assert vhead_params is not None, "Reward model is not correctly loaded."
|
||||
model.register_buffer("reward_head_weight", vhead_params["v_head.summary.weight"], persistent=False)
|
||||
model.register_buffer("reward_head_bias", vhead_params["v_head.summary.bias"], persistent=False)
|
||||
model.register_buffer(
|
||||
"default_head_weight", torch.zeros_like(vhead_params["v_head.summary.weight"]), persistent=False
|
||||
)
|
||||
model.register_buffer(
|
||||
"default_head_bias", torch.zeros_like(vhead_params["v_head.summary.bias"]), persistent=False
|
||||
)
|
||||
logger.info("Loaded adapter weights of reward model from {}".format(finetuning_args.reward_model))
|
||||
return None
|
||||
else:
|
||||
reward_model_args_dict = model_args.to_dict()
|
||||
reward_model_args_dict.update(
|
||||
dict(
|
||||
model_name_or_path=finetuning_args.reward_model,
|
||||
adapter_name_or_path=finetuning_args.reward_model_adapters,
|
||||
quantization_bit=finetuning_args.reward_model_quantization_bit,
|
||||
)
|
||||
)
|
||||
reward_model_args = ModelArguments(**reward_model_args_dict)
|
||||
reward_finetuning_args = FinetuningArguments(finetuning_type="lora")
|
||||
reward_model, _ = load_model_and_tokenizer(
|
||||
reward_model_args, reward_finetuning_args, is_trainable=False, add_valuehead=True
|
||||
)
|
||||
logger.info("Loaded full weights of reward model from {}".format(finetuning_args.reward_model))
|
||||
logger.warning("Please ensure the ppo model and reward model share SAME tokenizer and vocabulary.")
|
||||
return reward_model
|
||||
@@ -1 +0,0 @@
|
||||
from llmtuner.tuner.tune import export_model, run_exp
|
||||
@@ -1,3 +0,0 @@
|
||||
from llmtuner.tuner.core.parser import get_train_args, get_infer_args
|
||||
from llmtuner.tuner.core.loader import load_model_and_tokenizer
|
||||
from llmtuner.tuner.core.utils import generate_model_card
|
||||
@@ -1,129 +0,0 @@
|
||||
import os
|
||||
import torch
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from transformers.utils import cached_file
|
||||
from transformers.trainer import WEIGHTS_NAME, SAFE_WEIGHTS_NAME
|
||||
from peft import (
|
||||
PeftModel,
|
||||
TaskType,
|
||||
LoraConfig,
|
||||
get_peft_model
|
||||
)
|
||||
|
||||
from llmtuner.extras.logging import get_logger
|
||||
from llmtuner.tuner.core.utils import find_all_linear_modules
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from transformers.modeling_utils import PreTrainedModel
|
||||
from llmtuner.hparams import ModelArguments, FinetuningArguments
|
||||
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
def init_adapter(
|
||||
model: "PreTrainedModel",
|
||||
model_args: "ModelArguments",
|
||||
finetuning_args: "FinetuningArguments",
|
||||
is_trainable: bool
|
||||
) -> "PreTrainedModel":
|
||||
r"""
|
||||
Initializes the adapters.
|
||||
|
||||
Support full-parameter, freeze and LoRA training.
|
||||
|
||||
Note that the trainable parameters must be cast to float32.
|
||||
"""
|
||||
|
||||
if (not is_trainable) and model_args.checkpoint_dir is None:
|
||||
logger.info("Checkpoint is not found at evaluation, load the original model.")
|
||||
return model
|
||||
|
||||
if finetuning_args.finetuning_type == "full" and is_trainable:
|
||||
logger.info("Fine-tuning method: Full")
|
||||
model = model.float()
|
||||
|
||||
if finetuning_args.finetuning_type == "freeze" and is_trainable:
|
||||
logger.info("Fine-tuning method: Freeze")
|
||||
num_layers = getattr(model.config, "num_layers")
|
||||
if finetuning_args.num_layer_trainable > 0: # fine-tuning the last n layers if num_layer_trainable > 0
|
||||
trainable_layer_ids = [num_layers - k - 1 for k in range(finetuning_args.num_layer_trainable)]
|
||||
else: # fine-tuning the first n layers if num_layer_trainable < 0
|
||||
trainable_layer_ids = [k for k in range(-finetuning_args.num_layer_trainable)]
|
||||
|
||||
trainable_layers = ["{:d}.{}".format(idx, finetuning_args.name_module_trainable) for idx in trainable_layer_ids]
|
||||
for name, param in model.named_parameters():
|
||||
if not any(trainable_layer in name for trainable_layer in trainable_layers):
|
||||
param.requires_grad_(False)
|
||||
else:
|
||||
param.data = param.data.to(torch.float32)
|
||||
|
||||
if finetuning_args.finetuning_type == "lora":
|
||||
logger.info("Fine-tuning method: LoRA")
|
||||
checkpoint_to_resume = None
|
||||
|
||||
if model_args.checkpoint_dir is not None:
|
||||
if is_trainable and finetuning_args.resume_lora_training:
|
||||
checkpoints_to_merge, checkpoint_to_resume = model_args.checkpoint_dir[:-1], model_args.checkpoint_dir[-1]
|
||||
else:
|
||||
checkpoints_to_merge = model_args.checkpoint_dir
|
||||
|
||||
for checkpoint in checkpoints_to_merge:
|
||||
model = PeftModel.from_pretrained(model, checkpoint)
|
||||
model = model.merge_and_unload()
|
||||
|
||||
if len(checkpoints_to_merge) > 0:
|
||||
logger.info("Merged {} model checkpoint(s).".format(len(checkpoints_to_merge)))
|
||||
|
||||
if checkpoint_to_resume is not None: # resume lora training
|
||||
model = PeftModel.from_pretrained(model, checkpoint_to_resume, is_trainable=is_trainable)
|
||||
|
||||
if is_trainable and checkpoint_to_resume is None: # create new lora weights while training
|
||||
if len(finetuning_args.lora_target) == 1 and finetuning_args.lora_target[0] == "all":
|
||||
target_modules = find_all_linear_modules(model, model_args.quantization_bit)
|
||||
else:
|
||||
target_modules = finetuning_args.lora_target
|
||||
|
||||
lora_config = LoraConfig(
|
||||
task_type=TaskType.CAUSAL_LM,
|
||||
inference_mode=False,
|
||||
r=finetuning_args.lora_rank,
|
||||
lora_alpha=finetuning_args.lora_alpha,
|
||||
lora_dropout=finetuning_args.lora_dropout,
|
||||
target_modules=target_modules,
|
||||
modules_to_save=finetuning_args.additional_target
|
||||
)
|
||||
model = get_peft_model(model, lora_config)
|
||||
|
||||
if model_args.checkpoint_dir is not None:
|
||||
logger.info("Loaded fine-tuned model from checkpoint(s): {}".format(",".join(model_args.checkpoint_dir)))
|
||||
|
||||
return model
|
||||
|
||||
|
||||
def load_valuehead_params(
|
||||
model: "PreTrainedModel",
|
||||
model_args: "ModelArguments"
|
||||
) -> bool:
|
||||
kwargs = {
|
||||
"path_or_repo_id": model_args.reward_model,
|
||||
"cache_dir": model_args.cache_dir,
|
||||
"token": model_args.hf_hub_token,
|
||||
"revision": model_args.model_revision
|
||||
}
|
||||
try:
|
||||
vhead_file = cached_file(filename=WEIGHTS_NAME, **kwargs)
|
||||
except:
|
||||
try:
|
||||
vhead_file = cached_file(filename=SAFE_WEIGHTS_NAME, **kwargs)
|
||||
except:
|
||||
logger.warning("Provided path ({}) does not contain valuehead weights.".format(model_args.reward_model))
|
||||
return False
|
||||
|
||||
vhead_params = torch.load(vhead_file, map_location="cpu")
|
||||
model.register_buffer("reward_head_weight", vhead_params["v_head.summary.weight"], persistent=False)
|
||||
model.register_buffer("reward_head_bias", vhead_params["v_head.summary.bias"], persistent=False)
|
||||
model.register_buffer("default_head_weight", torch.zeros_like(vhead_params["v_head.summary.weight"]), persistent=False)
|
||||
model.register_buffer("default_head_bias", torch.zeros_like(vhead_params["v_head.summary.bias"]), persistent=False)
|
||||
return True
|
||||
@@ -1,236 +0,0 @@
|
||||
import os
|
||||
import math
|
||||
import torch
|
||||
from types import MethodType
|
||||
from typing import TYPE_CHECKING, Literal, Optional, Tuple
|
||||
|
||||
from transformers import (
|
||||
AutoConfig,
|
||||
AutoModelForCausalLM,
|
||||
AutoTokenizer,
|
||||
BitsAndBytesConfig,
|
||||
PretrainedConfig,
|
||||
PreTrainedModel,
|
||||
PreTrainedTokenizerBase
|
||||
)
|
||||
from transformers.models.llama import modeling_llama as LlamaModule
|
||||
from transformers.utils.versions import require_version
|
||||
from peft import PeftModel
|
||||
from trl import AutoModelForCausalLMWithValueHead
|
||||
|
||||
try:
|
||||
from transformers.integrations import is_deepspeed_zero3_enabled
|
||||
except ImportError: # https://github.com/huggingface/transformers/releases/tag/v4.33.1
|
||||
from transformers.deepspeed import is_deepspeed_zero3_enabled
|
||||
|
||||
from llmtuner.extras.logging import reset_logging, get_logger
|
||||
from llmtuner.extras.misc import count_parameters, infer_optim_dtype
|
||||
from llmtuner.extras.patches import llama_patch as LlamaPatches
|
||||
from llmtuner.hparams import FinetuningArguments
|
||||
from llmtuner.tuner.core.adapter import init_adapter, load_valuehead_params
|
||||
from llmtuner.tuner.core.utils import prepare_model_for_training
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from transformers import PreTrainedTokenizer
|
||||
from llmtuner.hparams import ModelArguments
|
||||
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
require_version("transformers>=4.31.0,<4.35.0", "To fix: pip install \"transformers>=4.31.0,<4.35.0\"")
|
||||
require_version("datasets>=2.14.0", "To fix: pip install datasets>=2.14.0")
|
||||
require_version("accelerate>=0.21.0", "To fix: pip install accelerate>=0.21.0")
|
||||
require_version("peft>=0.6.0", "To fix: pip install peft>=0.6.0")
|
||||
require_version("trl>=0.7.4", "To fix: pip install trl>=0.7.4")
|
||||
|
||||
|
||||
def load_model_and_tokenizer(
|
||||
model_args: "ModelArguments",
|
||||
finetuning_args: "FinetuningArguments",
|
||||
is_trainable: Optional[bool] = False,
|
||||
stage: Optional[Literal["pt", "sft", "rm", "ppo"]] = "sft"
|
||||
) -> Tuple[PreTrainedModel, "PreTrainedTokenizer"]:
|
||||
r"""
|
||||
Loads pretrained model and tokenizer.
|
||||
|
||||
Support both training and inference.
|
||||
"""
|
||||
|
||||
config_kwargs = {
|
||||
"trust_remote_code": True,
|
||||
"cache_dir": model_args.cache_dir,
|
||||
"revision": model_args.model_revision,
|
||||
"token": model_args.hf_hub_token
|
||||
}
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
model_args.model_name_or_path,
|
||||
use_fast=model_args.use_fast_tokenizer,
|
||||
split_special_tokens=model_args.split_special_tokens,
|
||||
padding_side="right", # training with left-padded tensors in fp16 precision may cause overflow
|
||||
**config_kwargs
|
||||
)
|
||||
|
||||
if finetuning_args.finetuning_type != "lora" and model_args.checkpoint_dir is not None:
|
||||
model_to_load = model_args.checkpoint_dir[0]
|
||||
else:
|
||||
model_to_load = model_args.model_name_or_path
|
||||
|
||||
config = AutoConfig.from_pretrained(model_to_load, **config_kwargs)
|
||||
|
||||
# Fix tokenizer (for ChatGLM2 and ChatGLM3)
|
||||
if getattr(config, "model_type", None) == "chatglm":
|
||||
tokenizer._pad = MethodType(PreTrainedTokenizerBase._pad, tokenizer)
|
||||
|
||||
# Set model dtype
|
||||
if model_args.compute_dtype is not None: # for training
|
||||
setattr(config, "torch_dtype", model_args.compute_dtype)
|
||||
else: # for evaluation, priority: bf16 > fp16 > fp32
|
||||
model_args.compute_dtype = infer_optim_dtype(model_dtype=getattr(config, "torch_dtype", None))
|
||||
|
||||
# Fix config (for Qwen)
|
||||
if getattr(config, "model_type", None) == "qwen":
|
||||
for dtype_name, dtype in [("fp16", torch.float16), ("bf16", torch.bfloat16), ("fp32", torch.float32)]:
|
||||
setattr(config, dtype_name, getattr(config, "torch_dtype", None) == dtype)
|
||||
|
||||
# Set RoPE scaling
|
||||
if model_args.rope_scaling is not None:
|
||||
if not hasattr(config, "rope_scaling"):
|
||||
logger.warning("Current model does not support RoPE scaling.")
|
||||
else:
|
||||
if is_trainable:
|
||||
if model_args.rope_scaling == "dynamic":
|
||||
logger.warning(
|
||||
"Dynamic NTK may not work well with fine-tuning. "
|
||||
"See: https://github.com/huggingface/transformers/pull/24653"
|
||||
)
|
||||
|
||||
current_max_length = getattr(config, "max_position_embeddings", None)
|
||||
if current_max_length and model_args.model_max_length > current_max_length:
|
||||
scaling_factor = float(math.ceil(model_args.model_max_length / current_max_length))
|
||||
else:
|
||||
logger.warning("Input length is smaller than max length. Consider increase input length.")
|
||||
scaling_factor = 1.0
|
||||
else:
|
||||
scaling_factor = 2.0
|
||||
|
||||
setattr(config, "rope_scaling", {"type": model_args.rope_scaling, "factor": scaling_factor})
|
||||
logger.info("Using {} scaling strategy and setting scaling factor to {}".format(
|
||||
model_args.rope_scaling, scaling_factor
|
||||
))
|
||||
|
||||
# Set FlashAttention-2
|
||||
if model_args.flash_attn:
|
||||
if getattr(config, "model_type", None) == "llama":
|
||||
if LlamaPatches.is_flash_attn_2_available:
|
||||
LlamaModule.LlamaAttention = LlamaPatches.LlamaFlashAttention2
|
||||
LlamaModule.LlamaModel._prepare_decoder_attention_mask = LlamaPatches._prepare_decoder_attention_mask
|
||||
logger.info("Using FlashAttention-2 for faster training and inference.")
|
||||
else:
|
||||
logger.warning("FlashAttention-2 is not installed.")
|
||||
elif getattr(config, "model_type", None) in ["qwen", "Yi"]:
|
||||
logger.info("Current model automatically enables FlashAttention if installed.")
|
||||
else:
|
||||
logger.warning("Current model does not support FlashAttention-2.")
|
||||
elif is_trainable and model_args.shift_attn and getattr(config, "model_type", None) == "llama":
|
||||
LlamaModule.LlamaAttention = LlamaPatches.LlamaShiftShortAttention
|
||||
logger.warning("Using `--flash_attn` for faster training in large context length.")
|
||||
|
||||
# Set shift short attention (S^2-Attn)
|
||||
if is_trainable and model_args.shift_attn:
|
||||
if getattr(config, "model_type", None) == "llama":
|
||||
setattr(config, "group_size_ratio", 0.25)
|
||||
logger.info("Using shift short attention with group_size_ratio=1/4.")
|
||||
else:
|
||||
logger.warning("Current model does not support shift short attention.")
|
||||
|
||||
# Quantization configurations (using bitsandbytes library).
|
||||
if model_args.quantization_bit is not None:
|
||||
if is_deepspeed_zero3_enabled():
|
||||
raise ValueError("DeepSpeed ZeRO-3 is incompatible with quantization.")
|
||||
|
||||
if model_args.quantization_bit == 8:
|
||||
require_version("bitsandbytes>=0.37.0", "To fix: pip install bitsandbytes>=0.37.0")
|
||||
config_kwargs["load_in_8bit"] = True
|
||||
config_kwargs["quantization_config"] = BitsAndBytesConfig(load_in_8bit=True)
|
||||
|
||||
if model_args.quantization_bit == 4:
|
||||
require_version("bitsandbytes>=0.39.0", "To fix: pip install bitsandbytes>=0.39.0")
|
||||
config_kwargs["load_in_4bit"] = True
|
||||
config_kwargs["quantization_config"] = BitsAndBytesConfig(
|
||||
load_in_4bit=True,
|
||||
bnb_4bit_compute_dtype=model_args.compute_dtype,
|
||||
bnb_4bit_use_double_quant=model_args.double_quantization,
|
||||
bnb_4bit_quant_type=model_args.quantization_type
|
||||
)
|
||||
|
||||
config_kwargs["device_map"] = {"": int(os.environ.get("LOCAL_RANK", "0"))} if is_trainable else "auto"
|
||||
logger.info("Quantizing model to {} bit.".format(model_args.quantization_bit))
|
||||
|
||||
# Load and prepare pre-trained models (without valuehead).
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
model_to_load,
|
||||
config=config,
|
||||
torch_dtype=model_args.compute_dtype,
|
||||
low_cpu_mem_usage=(not is_deepspeed_zero3_enabled()),
|
||||
**config_kwargs
|
||||
)
|
||||
|
||||
# Disable custom generate method (for Qwen and Baichuan2)
|
||||
if isinstance(model, PreTrainedModel) and "GenerationMixin" not in str(model.generate.__func__):
|
||||
model.generate = MethodType(PreTrainedModel.generate, model)
|
||||
|
||||
# Fix LM head (for ChatGLM2 and ChatGLM3)
|
||||
if getattr(config, "model_type", None) == "chatglm":
|
||||
setattr(model, "lm_head", model.transformer.output_layer)
|
||||
setattr(model, "_keys_to_ignore_on_save", ["lm_head.weight"])
|
||||
|
||||
# Register auto class to save the custom code files.
|
||||
if isinstance(config, PretrainedConfig) and "AutoConfig" in getattr(config, "auto_map", {}):
|
||||
config.__class__.register_for_auto_class()
|
||||
if isinstance(model, PreTrainedModel) and "AutoModelForCausalLM" in getattr(config, "auto_map", {}):
|
||||
model.__class__.register_for_auto_class()
|
||||
if isinstance(tokenizer, PreTrainedTokenizerBase) and "AutoTokenizer" in tokenizer.init_kwargs.get("auto_map", {}):
|
||||
tokenizer.__class__.register_for_auto_class()
|
||||
|
||||
# Initialize adapters
|
||||
model = prepare_model_for_training(model=model, finetuning_args=finetuning_args) if is_trainable else model
|
||||
model = init_adapter(model, model_args, finetuning_args, is_trainable)
|
||||
model = model.train() if is_trainable else model.eval()
|
||||
|
||||
# Prepare model with valuehead for RLHF
|
||||
if stage == "rm" or stage == "ppo":
|
||||
model: "AutoModelForCausalLMWithValueHead" = AutoModelForCausalLMWithValueHead.from_pretrained(model)
|
||||
reset_logging()
|
||||
if stage == "rm" and model_args.checkpoint_dir is not None: # load valuehead weights to evaluate reward model
|
||||
logger.warning("Only the last checkpoint containing valuehead will be loaded.")
|
||||
if load_valuehead_params(model, model_args):
|
||||
model.v_head.load_state_dict({
|
||||
"summary.weight": getattr(model, "reward_head_weight"),
|
||||
"summary.bias": getattr(model, "reward_head_bias")
|
||||
})
|
||||
|
||||
if stage == "ppo": # load reward model
|
||||
logger.info("Load reward model from {}".format(model_args.reward_model))
|
||||
if isinstance(model.pretrained_model, PeftModel):
|
||||
model.pretrained_model.load_adapter(model_args.reward_model, "reward")
|
||||
for name, param in model.named_parameters(): # https://github.com/huggingface/peft/issues/1090
|
||||
if "default" in name:
|
||||
param.data = param.data.to(torch.float32) # trainable params should in fp32
|
||||
assert load_valuehead_params(model, model_args), "Reward model is not correctly loaded."
|
||||
|
||||
# Prepare model for inference
|
||||
if not is_trainable:
|
||||
model.requires_grad_(False) # fix all model params
|
||||
model = model.to(model_args.compute_dtype) if model_args.quantization_bit is None else model
|
||||
|
||||
trainable_params, all_param = count_parameters(model)
|
||||
logger.info("trainable params: {:d} || all params: {:d} || trainable%: {:.4f}".format(
|
||||
trainable_params, all_param, 100 * trainable_params / all_param
|
||||
))
|
||||
|
||||
if not is_trainable:
|
||||
logger.info("This IS expected that the trainable params is 0 if you are using model for inference only.")
|
||||
|
||||
return model, tokenizer
|
||||
@@ -1,213 +0,0 @@
|
||||
import os
|
||||
import torch
|
||||
import datasets
|
||||
import transformers
|
||||
from typing import Any, Dict, Optional, Tuple
|
||||
from transformers import HfArgumentParser, Seq2SeqTrainingArguments
|
||||
from transformers.trainer_utils import get_last_checkpoint
|
||||
|
||||
from llmtuner.extras.logging import get_logger
|
||||
from llmtuner.extras.misc import parse_args
|
||||
from llmtuner.hparams import (
|
||||
ModelArguments,
|
||||
DataArguments,
|
||||
FinetuningArguments,
|
||||
GeneratingArguments
|
||||
)
|
||||
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
def parse_train_args(
|
||||
args: Optional[Dict[str, Any]] = None
|
||||
) -> Tuple[
|
||||
ModelArguments,
|
||||
DataArguments,
|
||||
Seq2SeqTrainingArguments,
|
||||
FinetuningArguments,
|
||||
GeneratingArguments
|
||||
]:
|
||||
parser = HfArgumentParser((
|
||||
ModelArguments,
|
||||
DataArguments,
|
||||
Seq2SeqTrainingArguments,
|
||||
FinetuningArguments,
|
||||
GeneratingArguments
|
||||
))
|
||||
return parse_args(parser, args)
|
||||
|
||||
|
||||
def parse_infer_args(
|
||||
args: Optional[Dict[str, Any]] = None
|
||||
) -> Tuple[
|
||||
ModelArguments,
|
||||
DataArguments,
|
||||
FinetuningArguments,
|
||||
GeneratingArguments
|
||||
]:
|
||||
parser = HfArgumentParser((
|
||||
ModelArguments,
|
||||
DataArguments,
|
||||
FinetuningArguments,
|
||||
GeneratingArguments
|
||||
))
|
||||
return parse_args(parser, args)
|
||||
|
||||
|
||||
def get_train_args(
|
||||
args: Optional[Dict[str, Any]] = None
|
||||
) -> Tuple[
|
||||
ModelArguments,
|
||||
DataArguments,
|
||||
Seq2SeqTrainingArguments,
|
||||
FinetuningArguments,
|
||||
GeneratingArguments
|
||||
]:
|
||||
model_args, data_args, training_args, finetuning_args, generating_args = parse_train_args(args)
|
||||
|
||||
# Setup logging
|
||||
if training_args.should_log:
|
||||
# The default of training_args.log_level is passive, so we set log level at info here to have that default.
|
||||
transformers.utils.logging.set_verbosity_info()
|
||||
|
||||
log_level = training_args.get_process_log_level()
|
||||
datasets.utils.logging.set_verbosity(log_level)
|
||||
transformers.utils.logging.set_verbosity(log_level)
|
||||
transformers.utils.logging.enable_default_handler()
|
||||
transformers.utils.logging.enable_explicit_format()
|
||||
|
||||
# Check arguments
|
||||
data_args.init_for_training(training_args.seed)
|
||||
|
||||
if finetuning_args.stage != "pt" and data_args.template is None:
|
||||
raise ValueError("Please specify which `template` to use.")
|
||||
|
||||
if finetuning_args.stage != "sft" and training_args.predict_with_generate:
|
||||
raise ValueError("`predict_with_generate` cannot be set as True except SFT.")
|
||||
|
||||
if finetuning_args.stage == "sft" and training_args.do_predict and not training_args.predict_with_generate:
|
||||
raise ValueError("Please enable `predict_with_generate` to save model predictions.")
|
||||
|
||||
if finetuning_args.stage in ["rm", "ppo"]:
|
||||
if finetuning_args.finetuning_type != "lora":
|
||||
raise ValueError("RM and PPO stages can only be performed with the LoRA method.")
|
||||
if training_args.resume_from_checkpoint is not None:
|
||||
raise ValueError("RM and PPO stages do not support `resume_from_checkpoint`.")
|
||||
if training_args.load_best_model_at_end:
|
||||
raise ValueError("RM and PPO stages do not support `load_best_model_at_end`.")
|
||||
|
||||
if finetuning_args.stage == "ppo" and not training_args.do_train:
|
||||
raise ValueError("PPO training does not support evaluation.")
|
||||
|
||||
if finetuning_args.stage in ["rm", "dpo"]:
|
||||
for dataset_attr in data_args.dataset_list:
|
||||
if not dataset_attr.ranking:
|
||||
raise ValueError("Please use ranked datasets for reward modeling or DPO training.")
|
||||
|
||||
if finetuning_args.stage == "ppo" and model_args.reward_model is None:
|
||||
raise ValueError("Reward model is necessary for PPO training.")
|
||||
|
||||
if finetuning_args.stage == "ppo" and model_args.shift_attn:
|
||||
raise ValueError("PPO training is incompatible with S^2-Attn.")
|
||||
|
||||
if training_args.max_steps == -1 and data_args.streaming:
|
||||
raise ValueError("Please specify `max_steps` in streaming mode.")
|
||||
|
||||
if training_args.do_train and training_args.predict_with_generate:
|
||||
raise ValueError("`predict_with_generate` cannot be set as True while training.")
|
||||
|
||||
if training_args.do_train and finetuning_args.finetuning_type == "lora" and finetuning_args.lora_target is None:
|
||||
raise ValueError("Please specify `lora_target` in LoRA training.")
|
||||
|
||||
if model_args.quantization_bit is not None and finetuning_args.finetuning_type != "lora":
|
||||
raise ValueError("Quantization is only compatible with the LoRA method.")
|
||||
|
||||
if (
|
||||
model_args.checkpoint_dir is not None
|
||||
and len(model_args.checkpoint_dir) != 1
|
||||
and finetuning_args.finetuning_type != "lora"
|
||||
):
|
||||
raise ValueError("Only LoRA tuning accepts multiple checkpoints.")
|
||||
|
||||
if training_args.do_train and model_args.quantization_bit is not None and (not finetuning_args.upcast_layernorm):
|
||||
logger.warning("We recommend enable `upcast_layernorm` in quantized training.")
|
||||
|
||||
if training_args.do_train and (not training_args.fp16) and (not training_args.bf16):
|
||||
logger.warning("We recommend enable mixed precision training.")
|
||||
|
||||
if (not training_args.do_train) and model_args.quantization_bit is not None:
|
||||
logger.warning("Evaluating model in 4/8-bit mode may cause lower scores.")
|
||||
|
||||
# postprocess training_args
|
||||
if (
|
||||
training_args.local_rank != -1
|
||||
and training_args.ddp_find_unused_parameters is None
|
||||
and finetuning_args.finetuning_type == "lora"
|
||||
):
|
||||
logger.warning("`ddp_find_unused_parameters` needs to be set as False for LoRA in DDP training.")
|
||||
training_args_dict = training_args.to_dict()
|
||||
training_args_dict.update(dict(ddp_find_unused_parameters=False))
|
||||
training_args = Seq2SeqTrainingArguments(**training_args_dict)
|
||||
|
||||
if (
|
||||
training_args.resume_from_checkpoint is None
|
||||
and training_args.do_train
|
||||
and os.path.isdir(training_args.output_dir)
|
||||
and not training_args.overwrite_output_dir
|
||||
):
|
||||
last_checkpoint = get_last_checkpoint(training_args.output_dir)
|
||||
if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0:
|
||||
raise ValueError("Output directory already exists and is not empty. Please set `overwrite_output_dir`.")
|
||||
|
||||
if last_checkpoint is not None:
|
||||
training_args_dict = training_args.to_dict()
|
||||
training_args_dict.update(dict(resume_from_checkpoint=last_checkpoint))
|
||||
training_args = Seq2SeqTrainingArguments(**training_args_dict)
|
||||
logger.info(
|
||||
"Resuming from checkpoint. Change `output_dir` or use `overwrite_output_dir` to avoid."
|
||||
)
|
||||
|
||||
# postprocess model_args
|
||||
model_args.compute_dtype = (
|
||||
torch.bfloat16 if training_args.bf16 else (torch.float16 if training_args.fp16 else None)
|
||||
)
|
||||
model_args.model_max_length = data_args.cutoff_len
|
||||
|
||||
# Log on each process the small summary:
|
||||
logger.info("Process rank: {}, device: {}, n_gpu: {}\n distributed training: {}, compute dtype: {}".format(
|
||||
training_args.local_rank, training_args.device, training_args.n_gpu,
|
||||
bool(training_args.local_rank != -1), str(model_args.compute_dtype)
|
||||
))
|
||||
logger.info(f"Training/evaluation parameters {training_args}")
|
||||
|
||||
# Set seed before initializing model.
|
||||
transformers.set_seed(training_args.seed)
|
||||
|
||||
return model_args, data_args, training_args, finetuning_args, generating_args
|
||||
|
||||
|
||||
def get_infer_args(
|
||||
args: Optional[Dict[str, Any]] = None
|
||||
) -> Tuple[
|
||||
ModelArguments,
|
||||
DataArguments,
|
||||
FinetuningArguments,
|
||||
GeneratingArguments
|
||||
]:
|
||||
model_args, data_args, finetuning_args, generating_args = parse_infer_args(args)
|
||||
|
||||
if data_args.template is None:
|
||||
raise ValueError("Please specify which `template` to use.")
|
||||
|
||||
if model_args.quantization_bit is not None and finetuning_args.finetuning_type != "lora":
|
||||
raise ValueError("Quantization is only compatible with the LoRA method.")
|
||||
|
||||
if (
|
||||
model_args.checkpoint_dir is not None
|
||||
and len(model_args.checkpoint_dir) != 1
|
||||
and finetuning_args.finetuning_type != "lora"
|
||||
):
|
||||
raise ValueError("Only LoRA tuning accepts multiple checkpoints.")
|
||||
|
||||
return model_args, data_args, finetuning_args, generating_args
|
||||
@@ -1,107 +0,0 @@
|
||||
import torch
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple
|
||||
|
||||
from llmtuner.extras.constants import LAYERNORM_NAMES
|
||||
from llmtuner.extras.logging import get_logger
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from transformers.modeling_utils import PreTrainedModel
|
||||
from llmtuner.hparams import ModelArguments, DataArguments, FinetuningArguments
|
||||
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
def find_all_linear_modules(
|
||||
model: "PreTrainedModel",
|
||||
quantization_bit: Optional[int] = None
|
||||
) -> List[str]:
|
||||
if quantization_bit is not None:
|
||||
import bitsandbytes as bnb
|
||||
linear_cls = bnb.nn.Linear4bit if quantization_bit == 4 else bnb.nn.Linear8bitLt
|
||||
else:
|
||||
linear_cls = torch.nn.Linear
|
||||
|
||||
output_layer_names = ["lm_head"]
|
||||
if model.config.model_type == "chatglm":
|
||||
output_layer_names.append("output_layer")
|
||||
|
||||
module_names = set()
|
||||
for name, module in model.named_modules():
|
||||
if (
|
||||
isinstance(module, linear_cls)
|
||||
and not any([output_layer in name for output_layer in output_layer_names])
|
||||
):
|
||||
module_names.add(name.split(".")[-1])
|
||||
|
||||
logger.info("Found linear modules: {}".format(",".join(module_names)))
|
||||
return list(module_names)
|
||||
|
||||
|
||||
def generate_model_card(
|
||||
model_args: "ModelArguments",
|
||||
data_args: "DataArguments",
|
||||
finetuning_args: "FinetuningArguments"
|
||||
) -> Dict[str, Any]:
|
||||
return {
|
||||
"tasks": "text-generation",
|
||||
"finetuned_from": model_args.model_name_or_path,
|
||||
"dataset": [dataset.strip() for dataset in data_args.dataset.split(",")],
|
||||
"tags": ["llama-factory"] + (["lora"] if finetuning_args.finetuning_type == "lora" else [])
|
||||
}
|
||||
|
||||
|
||||
def prepare_model_for_training(
|
||||
model: "PreTrainedModel",
|
||||
finetuning_args: "FinetuningArguments",
|
||||
output_layer_name: Optional[str] = "lm_head",
|
||||
use_gradient_checkpointing: Optional[bool] = True,
|
||||
layernorm_names: Optional[Set[str]] = LAYERNORM_NAMES
|
||||
) -> "PreTrainedModel":
|
||||
r"""
|
||||
Includes:
|
||||
(1) cast the layernorm in fp32
|
||||
(2) make output embedding layer require grads
|
||||
(3) upcast the lm_head to fp32
|
||||
Inspired by: https://github.com/huggingface/peft/blob/v0.2.0/src/peft/utils/other.py#L33
|
||||
"""
|
||||
if finetuning_args.upcast_layernorm:
|
||||
for name, param in model.named_parameters():
|
||||
if param.ndim == 1 and any(ln_name in name for ln_name in layernorm_names):
|
||||
param.data = param.data.to(torch.float32)
|
||||
logger.info("Upcasting weights in layernorm in float32.")
|
||||
|
||||
if finetuning_args.neft_alpha > 1e-6:
|
||||
def neftune_forward_hook(module: torch.nn.Module, args: Tuple[torch.Tensor], output: torch.Tensor):
|
||||
if module.training:
|
||||
dims = torch.tensor(output.size(1) * output.size(2))
|
||||
mag_norm = finetuning_args.neft_alpha / torch.sqrt(dims)
|
||||
output = output + torch.zeros_like(output).uniform_(-mag_norm, mag_norm)
|
||||
return output
|
||||
|
||||
model.get_input_embeddings().register_forward_hook(neftune_forward_hook)
|
||||
logger.info("Using noisy embedding with alpha={:.2f}".format(finetuning_args.neft_alpha))
|
||||
|
||||
if use_gradient_checkpointing:
|
||||
if hasattr(model, "enable_input_require_grads"):
|
||||
model.enable_input_require_grads()
|
||||
else:
|
||||
def make_inputs_require_grad(module: torch.nn.Module, args: Tuple[torch.Tensor], output: torch.Tensor):
|
||||
output.requires_grad_(True)
|
||||
model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)
|
||||
|
||||
model.gradient_checkpointing_enable()
|
||||
model.config.use_cache = False # turn off when gradient checkpointing is enabled
|
||||
logger.info("Gradient checkpointing enabled.")
|
||||
|
||||
if finetuning_args.finetuning_type != "full" and hasattr(model, output_layer_name):
|
||||
output_layer = getattr(model, output_layer_name)
|
||||
if isinstance(output_layer, torch.nn.Linear):
|
||||
def fp32_forward_pre_hook(module: torch.nn.Module, args: Tuple[torch.Tensor]):
|
||||
return args[0].to(output_layer.weight.dtype)
|
||||
def fp32_forward_post_hook(module: torch.nn.Module, args: Tuple[torch.Tensor], output: torch.Tensor):
|
||||
return output.to(torch.float32)
|
||||
output_layer.register_forward_pre_hook(fp32_forward_pre_hook)
|
||||
output_layer.register_forward_hook(fp32_forward_post_hook)
|
||||
|
||||
return model
|
||||
@@ -1 +0,0 @@
|
||||
from llmtuner.tuner.dpo.workflow import run_dpo
|
||||
@@ -1,71 +0,0 @@
|
||||
import torch
|
||||
from collections import defaultdict
|
||||
from typing import TYPE_CHECKING, Dict, Literal, Optional, Tuple, Union
|
||||
from transformers import BatchEncoding, Trainer
|
||||
from trl import DPOTrainer
|
||||
from trl.trainer.utils import disable_dropout_in_model
|
||||
|
||||
from llmtuner.extras.constants import IGNORE_INDEX
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from transformers import PreTrainedModel
|
||||
|
||||
|
||||
class CustomDPOTrainer(DPOTrainer):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
beta: float,
|
||||
model: Union["PreTrainedModel", torch.nn.Module],
|
||||
ref_model: Optional[Union["PreTrainedModel", torch.nn.Module]] = None,
|
||||
disable_dropout: Optional[bool] = True,
|
||||
loss_type: Optional[Literal["sigmoid", "hinge"]] = "sigmoid",
|
||||
**kwargs
|
||||
):
|
||||
if disable_dropout:
|
||||
disable_dropout_in_model(model)
|
||||
if ref_model is not None:
|
||||
disable_dropout_in_model(ref_model)
|
||||
|
||||
self.is_encoder_decoder = model.config.is_encoder_decoder
|
||||
self.ref_model = ref_model
|
||||
self.use_dpo_data_collator = True # hack to avoid warning
|
||||
self.generate_during_eval = False # disable at evaluation
|
||||
self.label_pad_token_id = IGNORE_INDEX
|
||||
self.padding_value = 0
|
||||
self.beta = beta
|
||||
self.loss_type = loss_type
|
||||
self._stored_metrics = defaultdict(lambda: defaultdict(list))
|
||||
|
||||
Trainer.__init__(self, model=model, **kwargs)
|
||||
if not hasattr(self, "accelerator"):
|
||||
raise AttributeError("Please update `transformers`.")
|
||||
|
||||
if ref_model is not None:
|
||||
if self.is_deepspeed_enabled:
|
||||
self.ref_model = self._prepare_deepspeed(self.ref_model)
|
||||
else:
|
||||
self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True)
|
||||
|
||||
def concatenated_forward(
|
||||
self,
|
||||
model: Optional[torch.nn.Module] = None,
|
||||
batch: Optional[Dict[str, torch.Tensor]] = None
|
||||
) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
|
||||
batch_copied = BatchEncoding({k: v.detach().clone() for k, v in batch.items()}) # avoid error
|
||||
|
||||
all_logits = model(
|
||||
input_ids=batch_copied["input_ids"],
|
||||
attention_mask=batch_copied["attention_mask"],
|
||||
return_dict=True
|
||||
).logits.to(torch.float32)
|
||||
|
||||
all_logps = self._get_batch_logps(
|
||||
all_logits,
|
||||
batch["labels"],
|
||||
average_log_prob=False
|
||||
)
|
||||
batch_size = batch["input_ids"].size(0) // 2
|
||||
chosen_logps, rejected_logps = all_logps.split(batch_size, dim=0)
|
||||
chosen_logits, rejected_logits = all_logits.split(batch_size, dim=0)
|
||||
return chosen_logps, rejected_logps, chosen_logits, rejected_logits
|
||||
@@ -1,102 +0,0 @@
|
||||
# Inspired by: https://github.com/huggingface/trl/blob/main/examples/research_projects/stack_llama_2/scripts/dpo_llama2.py
|
||||
|
||||
from peft import PeftModel
|
||||
from typing import TYPE_CHECKING, Optional, List
|
||||
from transformers import Seq2SeqTrainingArguments
|
||||
|
||||
from llmtuner.dsets import get_dataset, preprocess_dataset, split_dataset
|
||||
from llmtuner.extras.constants import IGNORE_INDEX
|
||||
from llmtuner.extras.logging import get_logger
|
||||
from llmtuner.extras.ploting import plot_loss
|
||||
from llmtuner.hparams import ModelArguments
|
||||
from llmtuner.tuner.core import generate_model_card, load_model_and_tokenizer
|
||||
from llmtuner.tuner.dpo.collator import DPODataCollatorWithPadding
|
||||
from llmtuner.tuner.dpo.trainer import CustomDPOTrainer
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from transformers import TrainerCallback
|
||||
from llmtuner.hparams import DataArguments, FinetuningArguments
|
||||
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
def run_dpo(
|
||||
model_args: "ModelArguments",
|
||||
data_args: "DataArguments",
|
||||
training_args: "Seq2SeqTrainingArguments",
|
||||
finetuning_args: "FinetuningArguments",
|
||||
callbacks: Optional[List["TrainerCallback"]] = None
|
||||
):
|
||||
dataset = get_dataset(model_args, data_args)
|
||||
model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args, training_args.do_train, stage="sft")
|
||||
dataset = preprocess_dataset(dataset, tokenizer, data_args, training_args, stage="rm")
|
||||
data_collator = DPODataCollatorWithPadding(
|
||||
tokenizer=tokenizer,
|
||||
pad_to_multiple_of=4,
|
||||
label_pad_token_id=IGNORE_INDEX if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id
|
||||
)
|
||||
|
||||
# Create reference model
|
||||
if finetuning_args.dpo_ref_model is not None:
|
||||
ref_model_args_dict = model_args.to_dict()
|
||||
ref_model_args_dict.update(dict(
|
||||
model_name_or_path=finetuning_args.dpo_ref_model,
|
||||
checkpoint_dir=finetuning_args.dpo_ref_model_checkpoint
|
||||
))
|
||||
ref_model_args = ModelArguments(**ref_model_args_dict)
|
||||
ref_model, _ = load_model_and_tokenizer(ref_model_args, finetuning_args, is_trainable=False, stage="sft")
|
||||
logger.info("Created reference model from {}".format(finetuning_args.dpo_ref_model))
|
||||
elif training_args.do_train:
|
||||
if isinstance(model, PeftModel):
|
||||
ref_model = None
|
||||
else:
|
||||
ref_model, _ = load_model_and_tokenizer(model_args, finetuning_args, is_trainable=False, stage="sft")
|
||||
logger.info("Created reference model from the model itself.")
|
||||
else:
|
||||
ref_model = model
|
||||
|
||||
# Update arguments
|
||||
training_args_dict = training_args.to_dict()
|
||||
training_args_dict.update(dict(remove_unused_columns=False)) # important for pairwise dataset
|
||||
training_args = Seq2SeqTrainingArguments(**training_args_dict)
|
||||
|
||||
# Initialize our Trainer
|
||||
trainer = CustomDPOTrainer(
|
||||
beta=finetuning_args.dpo_beta,
|
||||
model=model,
|
||||
ref_model=ref_model,
|
||||
args=training_args,
|
||||
tokenizer=tokenizer,
|
||||
data_collator=data_collator,
|
||||
callbacks=callbacks,
|
||||
**split_dataset(dataset, data_args, training_args)
|
||||
)
|
||||
|
||||
# Training
|
||||
if training_args.do_train:
|
||||
train_result = trainer.train(resume_from_checkpoint=training_args.resume_from_checkpoint)
|
||||
trainer.save_model()
|
||||
trainer.log_metrics("train", train_result.metrics)
|
||||
trainer.save_metrics("train", train_result.metrics)
|
||||
trainer.save_state()
|
||||
if trainer.is_world_process_zero() and model_args.plot_loss:
|
||||
plot_loss(training_args.output_dir, keys=["loss", "eval_loss"])
|
||||
|
||||
# Evaluation
|
||||
if training_args.do_eval:
|
||||
metrics = trainer.evaluate(metric_key_prefix="eval")
|
||||
if id(model) == id(ref_model): # unable to compute rewards without a reference model
|
||||
logger.warning("Pass `dpo_ref_model` for computing rewards at evaluation.")
|
||||
remove_keys = [key for key in metrics.keys() if "rewards" in key]
|
||||
for key in remove_keys:
|
||||
metrics.pop(key)
|
||||
trainer.log_metrics("eval", metrics)
|
||||
trainer.save_metrics("eval", metrics)
|
||||
|
||||
# Create model card
|
||||
if training_args.do_train:
|
||||
if training_args.push_to_hub:
|
||||
trainer.push_to_hub(**generate_model_card(model_args, data_args, finetuning_args))
|
||||
else:
|
||||
trainer.create_model_card(**generate_model_card(model_args, data_args, finetuning_args))
|
||||
@@ -1 +0,0 @@
|
||||
from llmtuner.tuner.ppo.workflow import run_ppo
|
||||
@@ -1 +0,0 @@
|
||||
from llmtuner.tuner.pt.workflow import run_pt
|
||||
@@ -1 +0,0 @@
|
||||
from llmtuner.tuner.rm.workflow import run_rm
|
||||
@@ -1 +0,0 @@
|
||||
from llmtuner.tuner.sft.workflow import run_sft
|
||||
@@ -1,51 +0,0 @@
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional
|
||||
|
||||
from llmtuner.extras.callbacks import LogCallback
|
||||
from llmtuner.extras.logging import get_logger
|
||||
from llmtuner.tuner.core import get_train_args, get_infer_args, load_model_and_tokenizer
|
||||
from llmtuner.tuner.pt import run_pt
|
||||
from llmtuner.tuner.sft import run_sft
|
||||
from llmtuner.tuner.rm import run_rm
|
||||
from llmtuner.tuner.ppo import run_ppo
|
||||
from llmtuner.tuner.dpo import run_dpo
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from transformers import TrainerCallback
|
||||
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
def run_exp(args: Optional[Dict[str, Any]] = None, callbacks: Optional[List["TrainerCallback"]] = None):
|
||||
model_args, data_args, training_args, finetuning_args, generating_args = get_train_args(args)
|
||||
callbacks = [LogCallback()] if callbacks is None else callbacks
|
||||
|
||||
if finetuning_args.stage == "pt":
|
||||
run_pt(model_args, data_args, training_args, finetuning_args, callbacks)
|
||||
elif finetuning_args.stage == "sft":
|
||||
run_sft(model_args, data_args, training_args, finetuning_args, generating_args, callbacks)
|
||||
elif finetuning_args.stage == "rm":
|
||||
run_rm(model_args, data_args, training_args, finetuning_args, callbacks)
|
||||
elif finetuning_args.stage == "ppo":
|
||||
run_ppo(model_args, data_args, training_args, finetuning_args, generating_args, callbacks)
|
||||
elif finetuning_args.stage == "dpo":
|
||||
run_dpo(model_args, data_args, training_args, finetuning_args, callbacks)
|
||||
else:
|
||||
raise ValueError("Unknown task.")
|
||||
|
||||
|
||||
def export_model(args: Optional[Dict[str, Any]] = None, max_shard_size: Optional[str] = "10GB"):
|
||||
model_args, _, finetuning_args, _ = get_infer_args(args)
|
||||
model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args)
|
||||
model.config.use_cache = True
|
||||
model.save_pretrained(model_args.export_dir, max_shard_size=max_shard_size)
|
||||
try:
|
||||
tokenizer.padding_side = "left" # restore padding side
|
||||
tokenizer.init_kwargs["padding_side"] = "left"
|
||||
tokenizer.save_pretrained(model_args.export_dir)
|
||||
except:
|
||||
logger.warning("Cannot save tokenizer, please copy the files manually.")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
run_exp()
|
||||
@@ -1 +1,4 @@
|
||||
from llmtuner.webui.interface import create_ui, create_web_demo
|
||||
from .interface import create_ui, create_web_demo
|
||||
|
||||
|
||||
__all__ = ["create_ui", "create_web_demo"]
|
||||
|
||||
@@ -1,27 +1,47 @@
|
||||
import gradio as gr
|
||||
from gradio.components import Component # cannot use TYPE_CHECKING here
|
||||
from typing import TYPE_CHECKING, Any, Dict, Generator, List, Optional, Tuple
|
||||
import json
|
||||
from typing import TYPE_CHECKING, Any, Dict, Generator, List, Optional, Sequence, Tuple
|
||||
|
||||
import gradio as gr
|
||||
from gradio.components import Component # cannot use TYPE_CHECKING here
|
||||
|
||||
from ..chat import ChatModel
|
||||
from ..data import Role
|
||||
from ..extras.misc import torch_gc
|
||||
from ..hparams import GeneratingArguments
|
||||
from .common import get_save_dir
|
||||
from .locales import ALERTS
|
||||
|
||||
from llmtuner.chat.stream_chat import ChatModel
|
||||
from llmtuner.extras.misc import torch_gc
|
||||
from llmtuner.hparams import GeneratingArguments
|
||||
from llmtuner.webui.common import get_save_dir
|
||||
from llmtuner.webui.locales import ALERTS
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from llmtuner.webui.manager import Manager
|
||||
from .manager import Manager
|
||||
|
||||
|
||||
class WebChatModel(ChatModel):
|
||||
|
||||
def __init__(self, manager: "Manager", lazy_init: Optional[bool] = True) -> None:
|
||||
def __init__(
|
||||
self, manager: "Manager", demo_mode: Optional[bool] = False, lazy_init: Optional[bool] = True
|
||||
) -> None:
|
||||
self.manager = manager
|
||||
self.demo_mode = demo_mode
|
||||
self.model = None
|
||||
self.tokenizer = None
|
||||
self.generating_args = GeneratingArguments()
|
||||
if not lazy_init:
|
||||
|
||||
if not lazy_init: # read arguments from command line
|
||||
super().__init__()
|
||||
|
||||
if demo_mode: # load demo_config.json if exists
|
||||
import json
|
||||
|
||||
try:
|
||||
with open("demo_config.json", "r", encoding="utf-8") as f:
|
||||
args = json.load(f)
|
||||
assert args.get("model_name_or_path", None) and args.get("template", None)
|
||||
super().__init__(args)
|
||||
except AssertionError:
|
||||
print("Please provided model name and template in `demo_config.json`.")
|
||||
except Exception:
|
||||
print("Cannot find `demo_config.json` at current directory.")
|
||||
|
||||
@property
|
||||
def loaded(self) -> bool:
|
||||
return self.model is not None
|
||||
@@ -36,30 +56,34 @@ class WebChatModel(ChatModel):
|
||||
error = ALERTS["err_no_model"][lang]
|
||||
elif not get("top.model_path"):
|
||||
error = ALERTS["err_no_path"][lang]
|
||||
elif self.demo_mode:
|
||||
error = ALERTS["err_demo"][lang]
|
||||
|
||||
if error:
|
||||
gr.Warning(error)
|
||||
yield error
|
||||
return
|
||||
|
||||
if get("top.checkpoints"):
|
||||
checkpoint_dir = ",".join([
|
||||
get_save_dir(get("top.model_name"), get("top.finetuning_type"), ckpt) for ckpt in get("top.checkpoints")
|
||||
])
|
||||
if get("top.adapter_path"):
|
||||
adapter_name_or_path = ",".join(
|
||||
[
|
||||
get_save_dir(get("top.model_name"), get("top.finetuning_type"), adapter)
|
||||
for adapter in get("top.adapter_path")
|
||||
]
|
||||
)
|
||||
else:
|
||||
checkpoint_dir = None
|
||||
adapter_name_or_path = None
|
||||
|
||||
yield ALERTS["info_loading"][lang]
|
||||
args = dict(
|
||||
model_name_or_path=get("top.model_path"),
|
||||
checkpoint_dir=checkpoint_dir,
|
||||
adapter_name_or_path=adapter_name_or_path,
|
||||
finetuning_type=get("top.finetuning_type"),
|
||||
quantization_bit=int(get("top.quantization_bit")) if get("top.quantization_bit") in ["8", "4"] else None,
|
||||
template=get("top.template"),
|
||||
system_prompt=get("top.system_prompt"),
|
||||
flash_attn=get("top.flash_attn"),
|
||||
shift_attn=get("top.shift_attn"),
|
||||
rope_scaling=get("top.rope_scaling") if get("top.rope_scaling") in ["linear", "dynamic"] else None
|
||||
flash_attn=(get("top.booster") == "flash_attn"),
|
||||
use_unsloth=(get("top.booster") == "unsloth"),
|
||||
rope_scaling=get("top.rope_scaling") if get("top.rope_scaling") in ["linear", "dynamic"] else None,
|
||||
)
|
||||
super().__init__(args)
|
||||
|
||||
@@ -67,6 +91,12 @@ class WebChatModel(ChatModel):
|
||||
|
||||
def unload_model(self, data: Dict[Component, Any]) -> Generator[str, None, None]:
|
||||
lang = data[self.manager.get_elem_by_name("top.lang")]
|
||||
|
||||
if self.demo_mode:
|
||||
gr.Warning(ALERTS["err_demo"][lang])
|
||||
yield ALERTS["err_demo"][lang]
|
||||
return
|
||||
|
||||
yield ALERTS["info_unloading"][lang]
|
||||
self.model = None
|
||||
self.tokenizer = None
|
||||
@@ -77,21 +107,37 @@ class WebChatModel(ChatModel):
|
||||
self,
|
||||
chatbot: List[Tuple[str, str]],
|
||||
query: str,
|
||||
history: List[Tuple[str, str]],
|
||||
messages: Sequence[Tuple[str, str]],
|
||||
system: str,
|
||||
tools: str,
|
||||
max_new_tokens: int,
|
||||
top_p: float,
|
||||
temperature: float
|
||||
) -> Generator[Tuple[List[Tuple[str, str]], List[Tuple[str, str]]], None, None]:
|
||||
temperature: float,
|
||||
) -> Generator[Tuple[Sequence[Tuple[str, str]], Sequence[Tuple[str, str]]], None, None]:
|
||||
chatbot.append([query, ""])
|
||||
query_messages = messages + [{"role": Role.USER, "content": query}]
|
||||
response = ""
|
||||
for new_text in self.stream_chat(
|
||||
query, history, system, max_new_tokens=max_new_tokens, top_p=top_p, temperature=temperature
|
||||
query_messages, system, tools, max_new_tokens=max_new_tokens, top_p=top_p, temperature=temperature
|
||||
):
|
||||
response += new_text
|
||||
new_history = history + [(query, response)]
|
||||
chatbot[-1] = [query, self.postprocess(response)]
|
||||
yield chatbot, new_history
|
||||
if tools:
|
||||
result = self.template.format_tools.extract(response)
|
||||
else:
|
||||
result = response
|
||||
|
||||
if isinstance(result, tuple):
|
||||
name, arguments = result
|
||||
arguments = json.loads(arguments)
|
||||
tool_call = json.dumps({"name": name, "arguments": arguments}, ensure_ascii=False)
|
||||
output_messages = query_messages + [{"role": Role.FUNCTION, "content": tool_call}]
|
||||
bot_text = "```json\n" + tool_call + "\n```"
|
||||
else:
|
||||
output_messages = query_messages + [{"role": Role.ASSISTANT, "content": result}]
|
||||
bot_text = result
|
||||
|
||||
chatbot[-1] = [query, self.postprocess(bot_text)]
|
||||
yield chatbot, output_messages
|
||||
|
||||
def postprocess(self, response: str) -> str:
|
||||
blocks = response.split("```")
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user