Compare commits

37 Commits

| Author | SHA1 | Date |
|---|---|---|
|  | 95d0f77fc2 |  |
|  | 9b2654277b |  |
|  | f1b3bdac3f |  |
|  | 595fdbd95d |  |
|  | dab9385297 |  |
|  | df83def566 |  |
|  | f9d4e37b3c |  |
|  | e59a3d71e0 |  |
|  | de3a84ac59 |  |
|  | e017266b98 |  |
|  | f81a8a5e5c |  |
|  | 7a3a0144a5 |  |
|  | 8263b2d32d |  |
|  | 833cd490b8 |  |
|  | 2162c37e41 |  |
|  | b2ac8376e1 |  |
|  | 8079584143 |  |
|  | 09a4474e7f |  |
|  | 81530133ff |  |
|  | cc4b384ac3 |  |
|  | 3852daf447 |  |
|  | 5c97111f9d |  |
|  | 75dd1f0f7e |  |
|  | c9a4551012 |  |
|  | 87197ba91d |  |
|  | 7461bf84e5 |  |
|  | fbc0357b2e |  |
|  | ec334f5891 |  |
|  | 885efe772e |  |
|  | 64fc9ba678 |  |
|  | 989eccd286 |  |
|  | f0766a2ab0 |  |
|  | 178b85ff9a |  |
|  | 68dd1ef121 |  |
|  | b222cffe98 |  |
|  | b4f1ab93d1 |  |
|  | f2e139f5cd |  |
CODE_OF_CONDUCT.md (new file, 128 lines)

@@ -0,0 +1,128 @@
+# Contributor Covenant Code of Conduct
+
+## Our Pledge
+
+We as members, contributors, and leaders pledge to make participation in our
+community a harassment-free experience for everyone, regardless of age, body
+size, visible or invisible disability, ethnicity, sex characteristics, gender
+identity and expression, level of experience, education, socio-economic status,
+nationality, personal appearance, race, religion, or sexual identity
+and orientation.
+
+We pledge to act and interact in ways that contribute to an open, welcoming,
+diverse, inclusive, and healthy community.
+
+## Our Standards
+
+Examples of behavior that contributes to a positive environment for our
+community include:
+
+* Demonstrating empathy and kindness toward other people
+* Being respectful of differing opinions, viewpoints, and experiences
+* Giving and gracefully accepting constructive feedback
+* Accepting responsibility and apologizing to those affected by our mistakes,
+  and learning from the experience
+* Focusing on what is best not just for us as individuals, but for the
+  overall community
+
+Examples of unacceptable behavior include:
+
+* The use of sexualized language or imagery, and sexual attention or
+  advances of any kind
+* Trolling, insulting or derogatory comments, and personal or political attacks
+* Public or private harassment
+* Publishing others' private information, such as a physical or email
+  address, without their explicit permission
+* Other conduct which could reasonably be considered inappropriate in a
+  professional setting
+
+## Enforcement Responsibilities
+
+Community leaders are responsible for clarifying and enforcing our standards of
+acceptable behavior and will take appropriate and fair corrective action in
+response to any behavior that they deem inappropriate, threatening, offensive,
+or harmful.
+
+Community leaders have the right and responsibility to remove, edit, or reject
+comments, commits, code, wiki edits, issues, and other contributions that are
+not aligned to this Code of Conduct, and will communicate reasons for moderation
+decisions when appropriate.
+
+## Scope
+
+This Code of Conduct applies within all community spaces, and also applies when
+an individual is officially representing the community in public spaces.
+Examples of representing our community include using an official e-mail address,
+posting via an official social media account, or acting as an appointed
+representative at an online or offline event.
+
+## Enforcement
+
+Instances of abusive, harassing, or otherwise unacceptable behavior may be
+reported to the community leaders responsible for enforcement at
+`hoshihiyouga AT gmail DOT com`.
+All complaints will be reviewed and investigated promptly and fairly.
+
+All community leaders are obligated to respect the privacy and security of the
+reporter of any incident.
+
+## Enforcement Guidelines
+
+Community leaders will follow these Community Impact Guidelines in determining
+the consequences for any action they deem in violation of this Code of Conduct:
+
+### 1. Correction
+
+**Community Impact**: Use of inappropriate language or other behavior deemed
+unprofessional or unwelcome in the community.
+
+**Consequence**: A private, written warning from community leaders, providing
+clarity around the nature of the violation and an explanation of why the
+behavior was inappropriate. A public apology may be requested.
+
+### 2. Warning
+
+**Community Impact**: A violation through a single incident or series
+of actions.
+
+**Consequence**: A warning with consequences for continued behavior. No
+interaction with the people involved, including unsolicited interaction with
+those enforcing the Code of Conduct, for a specified period of time. This
+includes avoiding interactions in community spaces as well as external channels
+like social media. Violating these terms may lead to a temporary or
+permanent ban.
+
+### 3. Temporary Ban
+
+**Community Impact**: A serious violation of community standards, including
+sustained inappropriate behavior.
+
+**Consequence**: A temporary ban from any sort of interaction or public
+communication with the community for a specified period of time. No public or
+private interaction with the people involved, including unsolicited interaction
+with those enforcing the Code of Conduct, is allowed during this period.
+Violating these terms may lead to a permanent ban.
+
+### 4. Permanent Ban
+
+**Community Impact**: Demonstrating a pattern of violation of community
+standards, including sustained inappropriate behavior, harassment of an
+individual, or aggression toward or disparagement of classes of individuals.
+
+**Consequence**: A permanent ban from any sort of public interaction within
+the community.
+
+## Attribution
+
+This Code of Conduct is adapted from the [Contributor Covenant][homepage],
+version 2.0, available at
+https://www.contributor-covenant.org/version/2/0/code_of_conduct.html.
+
+Community Impact Guidelines were inspired by [Mozilla's code of conduct
+enforcement ladder](https://github.com/mozilla/diversity).
+
+[homepage]: https://www.contributor-covenant.org
+
+For answers to common questions about this code of conduct, see the FAQ at
+https://www.contributor-covenant.org/faq. Translations are available at
+https://www.contributor-covenant.org/translations.
README.md (20 changed lines)

@@ -6,7 +6,8 @@
 [](https://pypi.org/project/llmtuner/)
 [](https://pypi.org/project/llmtuner/)
 [](https://github.com/hiyouga/LLaMA-Factory/pulls)
-[](https://discord.gg/e73gccsSd)
+[](https://discord.gg/c2EPEt5NU)
+[](https://huggingface.co/spaces/hiyouga/LLaMA-Board)
 
 👋 Join our [WeChat](assets/wechat.jpg).
 
@@ -14,7 +15,9 @@
 ## LLaMA Board: A One-stop Web UI for Getting Started with LLaMA Factory
 
-Launch **LLaMA Board** via `CUDA_VISIBLE_DEVICES=0 python src/train_web.py`. (multiple GPUs are not supported yet)
+Preview LLaMA Board at **[🤗 Spaces](https://huggingface.co/spaces/hiyouga/LLaMA-Board)**.
+
+Launch LLaMA Board via `CUDA_VISIBLE_DEVICES=0 python src/train_web.py`. (multiple GPUs are not supported yet)
 
 Here is an example of altering the self-cognition of an instruction-tuned language model within 10 minutes on a single GPU.
 
@@ -57,7 +60,7 @@ https://github.com/hiyouga/LLaMA-Factory/assets/16256802/6ba60acc-e2e2-4bec-b846
 | [BLOOM](https://huggingface.co/bigscience/bloom) | 560M/1.1B/1.7B/3B/7.1B/176B | query_key_value | - |
 | [BLOOMZ](https://huggingface.co/bigscience/bloomz) | 560M/1.1B/1.7B/3B/7.1B/176B | query_key_value | - |
 | [ChatGLM3](https://github.com/THUDM/ChatGLM3) | 6B | query_key_value | chatglm3 |
-| [Falcon](https://huggingface.co/tiiuae/falcon-7b) | 7B/40B/180B | query_key_value | - |
+| [Falcon](https://huggingface.co/tiiuae/falcon-7b) | 7B/40B/180B | query_key_value | falcon |
 | [InternLM](https://github.com/InternLM/InternLM) | 7B/20B | q_proj,v_proj | intern |
 | [LLaMA](https://github.com/facebookresearch/llama) | 7B/13B/33B/65B | q_proj,v_proj | - |
 | [LLaMA-2](https://huggingface.co/meta-llama) | 7B/13B/70B | q_proj,v_proj | llama2 |
@@ -71,7 +74,7 @@ https://github.com/hiyouga/LLaMA-Factory/assets/16256802/6ba60acc-e2e2-4bec-b846
 >
 > For the "base" models, the `--template` argument can be chosen from `default`, `alpaca`, `vicuna` etc. But make sure to use the **corresponding template** for the "chat" models.
 
-Please refer to [template.py](src/llmtuner/extras/template.py) for a full list of models we supported.
+Please refer to [constants.py](src/llmtuner/extras/constants.py) for a full list of models we supported.
 
 ## Supported Training Approaches
 
@@ -79,9 +82,9 @@ Please refer to [template.py](src/llmtuner/extras/template.py) for a full list o
 | ---------------------- | ------------------ | ------------------ | ------------------ | ------------------ |
 | Pre-Training | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
 | Supervised Fine-Tuning | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
-| Reward Modeling | | | :white_check_mark: | :white_check_mark: |
-| PPO Training | | | :white_check_mark: | :white_check_mark: |
-| DPO Training | :white_check_mark: | | :white_check_mark: | :white_check_mark: |
+| Reward Modeling | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
+| PPO Training | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
+| DPO Training | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
 
 > [!NOTE]
 > Use `--quantization_bit 4/8` argument to enable QLoRA.
@@ -122,6 +125,7 @@ Please refer to [template.py](src/llmtuner/extras/template.py) for a full list o
 - [OpenPlatypus (en)](https://huggingface.co/datasets/garage-bAInd/Open-Platypus)
 - [CodeAlpaca 20k (en)](https://huggingface.co/datasets/sahil2801/CodeAlpaca-20k)
 - [Alpaca CoT (multilingual)](https://huggingface.co/datasets/QingyiSi/Alpaca-CoT)
+- [OpenOrca (en)](https://huggingface.co/datasets/Open-Orca/OpenOrca)
 - [MathInstruct (en)](https://huggingface.co/datasets/TIGER-Lab/MathInstruct)
 - [Firefly 1.1M (zh)](https://huggingface.co/datasets/YeungNLP/firefly-train-1.1M)
 - [Web QA (zh)](https://huggingface.co/datasets/suolyer/webqa)
@@ -158,7 +162,7 @@ huggingface-cli login
 - Python 3.8+ and PyTorch 1.13.1+
 - 🤗Transformers, Datasets, Accelerate, PEFT and TRL
 - sentencepiece, protobuf and tiktoken
-- fire, jieba, rouge-chinese and nltk (used at evaluation and predict)
+- jieba, rouge-chinese and nltk (used at evaluation and predict)
 - gradio and matplotlib (used in web UI)
 - uvicorn, fastapi and sse-starlette (used in API)
README_zh.md (20 changed lines)

@@ -6,7 +6,8 @@
 [](https://pypi.org/project/llmtuner/)
 [](https://pypi.org/project/llmtuner/)
 [](https://github.com/hiyouga/LLaMA-Factory/pulls)
-[](https://discord.gg/e73gccsSd)
+[](https://discord.gg/c2EPEt5NU)
+[](https://huggingface.co/spaces/hiyouga/LLaMA-Board)
 
 👋 加入我们的[微信群](assets/wechat.jpg)。
 
@@ -14,7 +15,9 @@
 ## LLaMA Board: 通过一站式网页界面快速上手 LLaMA Factory
 
-使用 `CUDA_VISIBLE_DEVICES=0 python src/train_web.py` 启动 **LLaMA Board**。(该界面目前仅支持单卡训练)
+通过 **[🤗 Spaces](https://huggingface.co/spaces/hiyouga/LLaMA-Board)** 预览 LLaMA Board。
+
+使用 `CUDA_VISIBLE_DEVICES=0 python src/train_web.py` 启动 LLaMA Board。(该模式目前仅支持单卡训练)
 
 下面是使用单张 GPU 在 10 分钟内更改对话式大型语言模型自我认知的示例。
 
@@ -57,7 +60,7 @@ https://github.com/hiyouga/LLaMA-Factory/assets/16256802/6ba60acc-e2e2-4bec-b846
 | [BLOOM](https://huggingface.co/bigscience/bloom) | 560M/1.1B/1.7B/3B/7.1B/176B | query_key_value | - |
 | [BLOOMZ](https://huggingface.co/bigscience/bloomz) | 560M/1.1B/1.7B/3B/7.1B/176B | query_key_value | - |
 | [ChatGLM3](https://github.com/THUDM/ChatGLM3) | 6B | query_key_value | chatglm3 |
-| [Falcon](https://huggingface.co/tiiuae/falcon-7b) | 7B/40B/180B | query_key_value | - |
+| [Falcon](https://huggingface.co/tiiuae/falcon-7b) | 7B/40B/180B | query_key_value | falcon |
 | [InternLM](https://github.com/InternLM/InternLM) | 7B/20B | q_proj,v_proj | intern |
 | [LLaMA](https://github.com/facebookresearch/llama) | 7B/13B/33B/65B | q_proj,v_proj | - |
 | [LLaMA-2](https://huggingface.co/meta-llama) | 7B/13B/70B | q_proj,v_proj | llama2 |
@@ -71,7 +74,7 @@ https://github.com/hiyouga/LLaMA-Factory/assets/16256802/6ba60acc-e2e2-4bec-b846
 >
 > 对于所有“基座”(Base)模型,`--template` 参数可以是 `default`, `alpaca`, `vicuna` 等任意值。但“对话”(Chat)模型请务必使用**对应的模板**。
 
-项目所支持模型的完整列表请参阅 [template.py](src/llmtuner/extras/template.py)。
+项目所支持模型的完整列表请参阅 [constants.py](src/llmtuner/extras/constants.py)。
 
 ## 训练方法
 
@@ -79,9 +82,9 @@ https://github.com/hiyouga/LLaMA-Factory/assets/16256802/6ba60acc-e2e2-4bec-b846
 | ---------------------- | ------------------ | ------------------ | ------------------ | ------------------ |
 | 预训练 | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
 | 指令监督微调 | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
-| 奖励模型训练 | | | :white_check_mark: | :white_check_mark: |
-| PPO 训练 | | | :white_check_mark: | :white_check_mark: |
-| DPO 训练 | :white_check_mark: | | :white_check_mark: | :white_check_mark: |
+| 奖励模型训练 | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
+| PPO 训练 | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
+| DPO 训练 | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
 
 > [!NOTE]
 > 请使用 `--quantization_bit 4/8` 参数来启用 QLoRA 训练。
@@ -122,6 +125,7 @@ https://github.com/hiyouga/LLaMA-Factory/assets/16256802/6ba60acc-e2e2-4bec-b846
 - [OpenPlatypus (en)](https://huggingface.co/datasets/garage-bAInd/Open-Platypus)
 - [CodeAlpaca 20k (en)](https://huggingface.co/datasets/sahil2801/CodeAlpaca-20k)
 - [Alpaca CoT (multilingual)](https://huggingface.co/datasets/QingyiSi/Alpaca-CoT)
+- [OpenOrca (en)](https://huggingface.co/datasets/Open-Orca/OpenOrca)
 - [MathInstruct (en)](https://huggingface.co/datasets/TIGER-Lab/MathInstruct)
 - [Firefly 1.1M (zh)](https://huggingface.co/datasets/YeungNLP/firefly-train-1.1M)
 - [Web QA (zh)](https://huggingface.co/datasets/suolyer/webqa)
@@ -158,7 +162,7 @@ huggingface-cli login
 - Python 3.8+ 和 PyTorch 1.13.1+
 - 🤗Transformers, Datasets, Accelerate, PEFT 和 TRL
 - sentencepiece, protobuf 和 tiktoken
-- fire, jieba, rouge-chinese 和 nltk (用于评估及预测)
+- jieba, rouge-chinese 和 nltk (用于评估及预测)
 - gradio 和 matplotlib (用于网页端交互)
 - uvicorn, fastapi 和 sse-starlette (用于 API)
@@ -24,9 +24,7 @@ class BelleMultiturn(datasets.GeneratorBasedBuilder):
 
     def _info(self):
         features = datasets.Features({
-            "instruction": datasets.Value("string"),
-            "output": datasets.Value("string"),
-            "history": datasets.Sequence(datasets.Sequence(datasets.Value("string")))
+            "conversations": [{"from": datasets.Value("string"), "value": datasets.Value("string")}]
         })
         return datasets.DatasetInfo(
             description=_DESCRIPTION,
@@ -51,6 +49,7 @@ class BelleMultiturn(datasets.GeneratorBasedBuilder):
         with open(filepath, "r", encoding="utf-8") as f:
             for key, row in enumerate(f):
                 data = json.loads(row)
+                conversations = []
                 prompt = data["instruction"].strip()
                 response = data["output"].strip()
 
@@ -58,7 +57,8 @@ class BelleMultiturn(datasets.GeneratorBasedBuilder):
                 human_idx = prompt.rfind("Human:")
                 query = prompt[human_idx+6:assist_idx].strip()
                 prompt = prompt[:human_idx].strip()
-                history = []
+                conversations.insert(0, {"from": "gpt", "value": response})
+                conversations.insert(0, {"from": "human", "value": query})
 
                 while prompt.rfind("Assistant:") != -1:
                     assist_idx = prompt.rfind("Assistant:")
@@ -66,13 +66,10 @@ class BelleMultiturn(datasets.GeneratorBasedBuilder):
                     if human_idx != -1:
                         old_query = prompt[human_idx+6:assist_idx].strip()
                         old_resp = prompt[assist_idx+10:].strip()
-                        history.insert(0, (old_query, old_resp))
+                        conversations.insert(0, {"from": "gpt", "value": old_resp})
+                        conversations.insert(0, {"from": "human", "value": old_query})
                     else:
                         break
                     prompt = prompt[:human_idx].strip()
 
-                yield key, {
-                    "instruction": query,
-                    "output": response,
-                    "history": history
-                }
+                yield key, {"conversations": conversations}
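The rewrite above replaces the flat `instruction`/`output`/`history` schema with a single `conversations` list, matching the layout the `UltraChat` builder below already yields. For reference, a record emitted by the new generator takes this shape (a sketch with placeholder turn texts):

```python
# Illustrative record in the new BelleMultiturn format; values are placeholders.
example = {
    "conversations": [
        {"from": "human", "value": "earliest user turn"},
        {"from": "gpt", "value": "earliest assistant reply"},
        {"from": "human", "value": "latest user turn"},
        {"from": "gpt", "value": "latest assistant reply"}
    ]
}
```

Because turns are parsed back-to-front with `rfind` and prepended via `insert(0, ...)`, the list still comes out in chronological order.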
@@ -66,6 +66,4 @@ class UltraChat(datasets.GeneratorBasedBuilder):
                     "from": "human" if i % 2 == 0 else "gpt",
                     "value": content[i]
                 } for i in range(len(content))]
-                yield key, {
-                    "conversations": conversations
-                }
+                yield key, {"conversations": conversations}
@@ -3,13 +3,12 @@ transformers>=4.31.0,<4.35.0
 datasets>=2.14.0
 accelerate>=0.21.0
 peft>=0.6.0
-trl==0.7.2
+trl>=0.7.4
 gradio>=3.38.0,<4.0.0
 scipy
 sentencepiece
 protobuf
 tiktoken
-fire
 jieba
 rouge-chinese
 nltk
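Two dependency changes here: `trl` is relaxed from the exact pin `0.7.2` to a floor of `0.7.4`, and `fire` is dropped because the rewritten `src/evaluate.py` below no longer calls `fire.Fire`. A quick post-install check might look like this (a sketch; it assumes a plain `x.y.z` version string):

```python
from importlib.metadata import version

# Verify that the installed trl satisfies the new floor; raises otherwise.
installed = tuple(int(p) for p in version("trl").split(".")[:3])
assert installed >= (0, 7, 4), "trl>=0.7.4 is required"
```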
@@ -1,5 +1,12 @@
-import readline
 from llmtuner import ChatModel
+from llmtuner.extras.misc import torch_gc
+
+try:
+    import platform
+    if platform.system() != "Windows":
+        import readline
+except ImportError:
+    print("Install `readline` for a better experience.")
 
 
 def main():
@@ -21,6 +28,7 @@ def main():
 
         if query.strip() == "clear":
            history = []
+            torch_gc()
            print("History has been removed.")
            continue
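Besides guarding the `readline` import (which is unavailable on Windows), the CLI demo now calls `torch_gc()` whenever the user clears the history, so GPU memory held by the previous conversation is released immediately. The helper is imported from `llmtuner.extras.misc`; a plausible minimal implementation is sketched below (an assumption for illustration, not necessarily the project's exact code):

```python
import gc

import torch


def torch_gc() -> None:
    # Collect Python garbage first so dropped tensors become unreachable,
    # then release cached GPU memory back to the driver.
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.ipc_collect()
```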
src/evaluate.py (190 changed lines)

@@ -1,190 +1,10 @@
-# coding=utf-8
-# Evaluates the performance of pre-trained models.
-# Usage: python evaluate.py --model_name_or_path path_to_model --checkpoint_dir path_to_ckpt --template vanilla
-#        --task ceval --split validation --lang zh --n_shot 5 --batch_size 4 --save_name result
-# Inspired by: https://github.com/hendrycks/test/blob/master/evaluate_flan.py
-
-import os
-import fire
-import json
-import torch
-import numpy as np
-import transformers
-from collections import Counter
-from datasets import load_dataset
-from dataclasses import dataclass
-from tqdm import tqdm, trange
-from typing import TYPE_CHECKING, Dict, List, Literal, Optional, Tuple
-
-from llmtuner import ChatModel
-
-if TYPE_CHECKING:
-    from datasets import Dataset
-
-
-choices = ["A", "B", "C", "D"]
-
-
-@dataclass
-class EvalTemplate:
-
-    system: str
-    choice: str
-    answer: str
-    prefix: str
-
-    def parse_example(
-        self,
-        example: Dict[str, str]
-    ) -> Tuple[str, str]:
-        candidates = [self.choice.format(choice=ch, content=example[ch]) for ch in choices if ch in example]
-        return "".join([example["question"]] + candidates + [self.answer]), example["answer"]
-
-    def format_example(
-        self,
-        target_data: Dict[str, str],
-        support_set: "Dataset",
-        subject_name: str,
-        use_history: bool
-    ) -> Tuple[str, str, List[Tuple[str, str]]]:
-        query, resp = self.parse_example(target_data)
-        history = [self.parse_example(support_set[k]) for k in range(len(support_set))]
-
-        if len(history):
-            temp = history.pop(0)
-            history.insert(0, (self.system.format(subject=subject_name) + temp[0], temp[1]))
-        else:
-            query = self.system.format(subject=subject_name) + query
-
-        if not use_history:
-            query = "\n\n".join(["".join(item) for item in history] + [query])
-            history = []
-        return query.strip(), resp, history
-
-
-eval_templates = {
-    "en": EvalTemplate(
-        system="The following are multiple choice questions (with answers) about {subject}.\n\n",
-        choice="\n{choice}. {content}",
-        answer="\nAnswer: ",
-        prefix=" "
-    ),
-    "zh": EvalTemplate(
-        system="以下是中国关于{subject}考试的单项选择题,请选出其中的正确答案。\n\n",
-        choice="\n{choice}. {content}",
-        answer="\n答案:",
-        prefix="\n"
-    )
-}
-
-
-@torch.inference_mode()
-def batch_inference(
-    chat_model: ChatModel,
-    batch_input: Dict[str, torch.Tensor],
-    prefix_char: str
-) -> List[str]:
-    logits = chat_model.model(**batch_input).logits
-    lengths = torch.sum(batch_input["attention_mask"], dim=-1)
-    nextword_logits = torch.stack([logits[i, lengths[i] - 1] for i in range(len(lengths))], dim=0)
-    probs = torch.nn.functional.softmax(
-        torch.stack(
-            [
-                nextword_logits[:, chat_model.tokenizer.encode(prefix_char + choice, add_special_tokens=False)[-1]]
-                for choice in choices
-            ],
-            dim=-1
-        ),
-        dim=-1
-    ).detach()
-    return [chr(ord("A") + offset.item()) for offset in torch.argmax(probs, dim=-1)]
-
-
-def evaluate(
-    model_name_or_path: str,
-    finetuning_type: Optional[str] = "lora",
-    checkpoint_dir: Optional[str] = None,
-    template: Optional[str] = "vanilla",
-    task: Optional[str] = "ceval",
-    dataset_dir: Optional[str] = "evaluation",
-    split: Optional[Literal["validation", "test"]] = "validation",
-    lang: Optional[Literal["zh", "en"]] = "zh",
-    n_shot: Optional[int] = 5,
-    n_avg: Optional[int] = 1,
-    batch_size: Optional[int] = 4,
-    save_name: Optional[str] = None,
-    seed: Optional[int] = 42
-):
-    with open(os.path.join(dataset_dir, task, "mapping.json"), "r", encoding="utf-8") as f:
-        categorys: Dict[str, Dict[str, str]] = json.load(f)
-
-    transformers.set_seed(seed)
-    chat_model = ChatModel(dict(
-        model_name_or_path=model_name_or_path,
-        finetuning_type=finetuning_type,
-        checkpoint_dir=checkpoint_dir,
-        template=template
-    ))
-    chat_model.tokenizer.padding_side = "right" # avoid overflow issue in batched inference for llama2
-    eval_template = eval_templates[lang]
-
-    category_corrects: Dict[str, np.ndarray] = {
-        subj: np.array([], dtype="bool") for subj in ["Average", "STEM", "Social Sciences", "Humanities", "Other"]
-    }
-    pbar = tqdm(categorys.keys(), desc="Processing subjects", position=0)
-    results = {}
-    for subject in pbar:
-        dataset = load_dataset(os.path.join(dataset_dir, task), subject)
-        labels, answers, all_outputs = [], [], []
-        for epoch in range(n_avg):
-            pbar.set_postfix_str("{} Trial: {}".format(categorys[subject]["name"], epoch))
-            inputs, outputs = [], []
-            for i in trange(len(dataset[split]), desc="Formatting batches", position=1, leave=False):
-                support_set = dataset["train"].shuffle().select(range(min(n_shot, len(dataset["train"]))))
-                query, resp, history = eval_template.format_example(
-                    target_data=dataset[split][i],
-                    support_set=support_set,
-                    subject_name=categorys[subject]["name"],
-                    use_history=chat_model.template.use_history
-                )
-                input_ids, _ = chat_model.template.encode_oneturn(
-                    tokenizer=chat_model.tokenizer, query=query, resp=resp, history=history
-                )
-                inputs.append({"input_ids": input_ids, "attention_mask": [1] * len(input_ids)})
-                if epoch == 0:
-                    labels.append(resp)
-
-            for i in trange(0, len(inputs), batch_size, desc="Predicting batches", position=1, leave=False):
-                batch_input = chat_model.tokenizer.pad(
-                    inputs[i : i + batch_size], return_attention_mask=True, return_tensors="pt"
-                ).to(chat_model.model.device)
-                preds = batch_inference(chat_model, batch_input, eval_template.prefix)
-                outputs += preds
-            all_outputs.append(outputs)
-
-        for i in range(len(all_outputs[0])):
-            count = Counter([all_outputs[epoch][i] for epoch in range(n_avg)])
-            answers.append(count.most_common(1)[0][0])
-
-        corrects = (np.array(answers) == np.array(labels))
-        category_name = categorys[subject]["category"]
-        category_corrects[category_name] = np.concatenate([category_corrects[category_name], corrects], axis=0)
-        category_corrects["Average"] = np.concatenate([category_corrects["Average"], corrects], axis=0)
-        results[subject] = {str(i): answers[i] for i in range(len(answers))}
-
-    score_info = "\n".join([
-        "{:>15}: {:.2f}".format(category_name, 100 * np.mean(category_correct))
-        for category_name, category_correct in category_corrects.items() if len(category_correct)
-    ])
-
-    print(score_info)
-    if save_name is not None:
-        with open(save_name + ".json", "w", encoding="utf-8", newline="\n") as f:
-            json.dump(results, f, indent=2)
-
-        with open(save_name + ".log", "w", encoding="utf-8", newline="\n") as f:
-            f.write(score_info)
-
-
+from llmtuner import Evaluator
+
+
+def main():
+    evaluator = Evaluator()
+    evaluator.eval()
+
+
 if __name__ == "__main__":
-    fire.Fire(evaluate)
+    main()
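The 190-line script collapses into a thin wrapper: the `EvalTemplate`, `batch_inference`, and per-category scoring logic move behind the new `Evaluator` class exported from `llmtuner`. Assuming the class parses the same flags the old `fire`-based entry point accepted (an assumption based on the deleted usage comment), an invocation would look like:

```python
# Hypothetical CLI usage mirroring the deleted header comment; argument
# parsing is now presumed to happen inside Evaluator rather than fire.Fire:
#
#   python src/evaluate.py --model_name_or_path path_to_model \
#       --template vanilla --task ceval --split validation \
#       --lang zh --n_shot 5 --batch_size 4 --save_name result

from llmtuner import Evaluator

Evaluator().eval()  # equivalent to running the script's main()
```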
@@ -1,9 +1,10 @@
-# Level: api, webui > chat > tuner > dsets > extras, hparams
+# Level: api, webui > chat, eval, train > data, model > extras, hparams
 
 from llmtuner.api import create_app
 from llmtuner.chat import ChatModel
-from llmtuner.tuner import export_model, run_exp
+from llmtuner.eval import Evaluator
+from llmtuner.train import export_model, run_exp
 from llmtuner.webui import create_ui, create_web_demo
 
 
-__version__ = "0.2.1"
+__version__ = "0.3.0"
@@ -1,14 +1,8 @@
 import json
-import uvicorn
-from fastapi import FastAPI, HTTPException, status
-from fastapi.middleware.cors import CORSMiddleware
-from contextlib import asynccontextmanager
-from sse_starlette import EventSourceResponse
 from typing import List, Tuple
 from pydantic import BaseModel
+from contextlib import asynccontextmanager
 
-from llmtuner.extras.misc import torch_gc
-from llmtuner.chat import ChatModel
 from llmtuner.api.protocol import (
     Role,
     Finish,
@@ -23,10 +17,28 @@ from llmtuner.api.protocol import (
     ChatCompletionResponseStreamChoice,
     ChatCompletionResponseUsage
 )
+from llmtuner.chat import ChatModel
+from llmtuner.extras.misc import torch_gc
+from llmtuner.extras.packages import (
+    is_fastapi_availble, is_starlette_available, is_uvicorn_available
+)
+
+
+if is_fastapi_availble():
+    from fastapi import FastAPI, HTTPException, status
+    from fastapi.middleware.cors import CORSMiddleware
+
+
+if is_starlette_available():
+    from sse_starlette import EventSourceResponse
+
+
+if is_uvicorn_available():
+    import uvicorn
 
 
 @asynccontextmanager
-async def lifespan(app: FastAPI): # collects GPU memory
+async def lifespan(app: "FastAPI"): # collects GPU memory
     yield
     torch_gc()
@@ -38,7 +50,7 @@ def to_json(data: BaseModel) -> str:
     return data.json(exclude_unset=True, ensure_ascii=False)
 
 
-def create_app(chat_model: ChatModel) -> FastAPI:
+def create_app(chat_model: "ChatModel") -> "FastAPI":
     app = FastAPI(lifespan=lifespan)
 
     app.add_middleware(
@@ -56,12 +68,12 @@ def create_app(chat_model: ChatModel) -> FastAPI:
 
     @app.post("/v1/chat/completions", response_model=ChatCompletionResponse, status_code=status.HTTP_200_OK)
     async def create_chat_completion(request: ChatCompletionRequest):
-        if len(request.messages) < 1 or request.messages[-1].role != Role.USER:
+        if len(request.messages) == 0 or request.messages[-1].role != Role.USER:
             raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid request")
 
         query = request.messages[-1].content
         prev_messages = request.messages[:-1]
-        if len(prev_messages) > 0 and prev_messages[0].role == Role.SYSTEM:
+        if len(prev_messages) and prev_messages[0].role == Role.SYSTEM:
             system = prev_messages.pop(0).content
         else:
             system = None
@@ -73,12 +85,14 @@ def create_app(chat_model: ChatModel) -> FastAPI:
                     history.append([prev_messages[i].content, prev_messages[i+1].content])
                 else:
                     raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Only supports u/a/u/a/u...")
+        else:
+            raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Only supports u/a/u/a/u...")
 
         if request.stream:
             generate = predict(query, history, system, request)
             return EventSourceResponse(generate, media_type="text/event-stream")
 
-        response, (prompt_length, response_length) = chat_model.chat(
+        responses = chat_model.chat(
             query, history, system,
             do_sample=request.do_sample,
             temperature=request.temperature,
@@ -87,18 +101,23 @@ def create_app(chat_model: ChatModel) -> FastAPI:
             num_return_sequences=request.n
         )
 
+        prompt_length, response_length = 0, 0
+        choices = []
+        for i, response in enumerate(responses):
+            choices.append(ChatCompletionResponseChoice(
+                index=i,
+                message=ChatMessage(role=Role.ASSISTANT, content=response.response_text),
+                finish_reason=Finish.STOP if response.finish_reason == "stop" else Finish.LENGTH
+            ))
+            prompt_length = response.prompt_length
+            response_length += response.response_length
+
         usage = ChatCompletionResponseUsage(
             prompt_tokens=prompt_length,
             completion_tokens=response_length,
             total_tokens=prompt_length+response_length
         )
 
-        choices = [ChatCompletionResponseChoice(
-            index=i,
-            message=ChatMessage(role=Role.ASSISTANT, content=choice),
-            finish_reason=Finish.STOP
-        ) for i, choice in enumerate(response)]
-
         return ChatCompletionResponse(model=request.model, choices=choices, usage=usage)
 
     async def predict(query: str, history: List[Tuple[str, str]], system: str, request: ChatCompletionRequest):
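With `choices` now built from the per-sequence `Response` objects, the request's `n` parameter is actually honored and `finish_reason` can report `length` as well as `stop`. A minimal client call against a locally running API server (a sketch; the host and port are assumptions matching typical demo defaults):

```python
import json
from urllib import request

payload = {
    "model": "default",
    "messages": [{"role": "user", "content": "Hello"}],
    "n": 2  # two sampled completions -> two entries in "choices"
}
req = request.Request(
    "http://localhost:8000/v1/chat/completions",
    data=json.dumps(payload).encode("utf-8"),
    headers={"Content-Type": "application/json"},
)
with request.urlopen(req) as resp:
    body = json.load(resp)

for choice in body["choices"]:
    print(choice["index"], choice["finish_reason"], choice["message"]["content"])
```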
@@ -1 +1 @@
-from llmtuner.chat.stream_chat import ChatModel
+from llmtuner.chat.chat_model import ChatModel
@@ -1,11 +1,21 @@
 import torch
-from typing import Any, Dict, Generator, List, Optional, Tuple
+from dataclasses import dataclass
+from typing import Any, Dict, Generator, List, Literal, Optional, Tuple
 from threading import Thread
 from transformers import GenerationConfig, TextIteratorStreamer
 
-from llmtuner.extras.misc import dispatch_model, get_logits_processor
-from llmtuner.extras.template import get_template_and_fix_tokenizer
-from llmtuner.tuner.core import get_infer_args, load_model_and_tokenizer
+from llmtuner.data.template import get_template_and_fix_tokenizer
+from llmtuner.extras.misc import get_logits_processor
+from llmtuner.model import dispatch_model, get_infer_args, load_model_and_tokenizer
+
+
+@dataclass
+class Response:
+
+    response_text: str
+    response_length: int
+    prompt_length: int
+    finish_reason: Literal["stop", "length"]
 
 
 class ChatModel:
@@ -18,7 +28,7 @@ class ChatModel:
         self.template = get_template_and_fix_tokenizer(data_args.template, self.tokenizer)
         self.system_prompt = data_args.system_prompt
 
-    def process_args(
+    def _process_args(
         self,
         query: str,
         history: Optional[List[Tuple[str, str]]] = None,
@@ -79,17 +89,30 @@ class ChatModel:
         history: Optional[List[Tuple[str, str]]] = None,
         system: Optional[str] = None,
         **input_kwargs
-    ) -> Tuple[List[str], Tuple[int, int]]:
-        gen_kwargs, prompt_length = self.process_args(query, history, system, **input_kwargs)
+    ) -> List[Response]:
+        r"""
+        Args: query, history, system, **input_kwargs
+
+        Returns: [(response_text, prompt_length, response_length)] * n (default n=1)
+        """
+        gen_kwargs, prompt_length = self._process_args(query, history, system, **input_kwargs)
         generate_output = self.model.generate(**gen_kwargs)
         response_ids = generate_output[:, prompt_length:]
-        response = self.tokenizer.batch_decode(response_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True)
-        response_length = 0
-        for i in range(len(response_ids)):
+        response = self.tokenizer.batch_decode(
+            response_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True
+        )
+        results = []
+        for i in range(len(response)):
             eos_index = (response_ids[i] == self.tokenizer.eos_token_id).nonzero()
-            response_length += eos_index[0].item() if len(eos_index) else len(response_ids[i])
+            response_length = (eos_index[0].item() + 1) if len(eos_index) else len(response_ids[i])
+            results.append(Response(
+                response_text=response[i],
+                response_length=response_length,
+                prompt_length=prompt_length,
+                finish_reason="stop" if len(eos_index) else "length"
+            ))
 
-        return response, (prompt_length, response_length)
+        return results
 
     @torch.inference_mode()
     def stream_chat(
@@ -99,7 +122,7 @@ class ChatModel:
         system: Optional[str] = None,
         **input_kwargs
     ) -> Generator[str, None, None]:
-        gen_kwargs, _ = self.process_args(query, history, system, **input_kwargs)
+        gen_kwargs, _ = self._process_args(query, history, system, **input_kwargs)
         streamer = TextIteratorStreamer(self.tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True)
         gen_kwargs["streamer"] = streamer
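`chat()` now returns one `Response` per generated sequence instead of a list of strings plus a single shared `(prompt_length, response_length)` tuple; this is what lets the API layer above fill in per-choice finish reasons and token counts. A usage sketch (the model path is a placeholder):

```python
from llmtuner import ChatModel

chat_model = ChatModel(dict(
    model_name_or_path="path_to_model",  # placeholder, not a real checkpoint
    template="default"
))

for r in chat_model.chat("Briefly introduce yourself.", num_return_sequences=2):
    # Each Response carries its own text, token counts and finish reason.
    print(r.finish_reason, r.prompt_length, r.response_length)
    print(r.response_text)
```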
src/llmtuner/data/__init__.py (new file, 4 lines)

@@ -0,0 +1,4 @@
+from llmtuner.data.loader import get_dataset
+from llmtuner.data.preprocess import preprocess_dataset
+from llmtuner.data.template import get_template_and_fix_tokenizer
+from llmtuner.data.utils import split_dataset
@@ -3,7 +3,7 @@ from typing import TYPE_CHECKING, Any, Dict, List, Union
 
 from datasets import concatenate_datasets, interleave_datasets, load_dataset
 
-from llmtuner.dsets.utils import checksum, EXT2TYPE
+from llmtuner.data.utils import checksum, EXT2TYPE
 from llmtuner.extras.logging import get_logger
 
 if TYPE_CHECKING:
@@ -1,13 +1,13 @@
 import os
 import tiktoken
 from itertools import chain
-from typing import TYPE_CHECKING, Any, Dict, Generator, List, Literal, Union
+from typing import TYPE_CHECKING, Any, Dict, Generator, List, Literal, Tuple, Union
 
 from datasets import load_from_disk
 
+from llmtuner.data.template import get_template_and_fix_tokenizer
 from llmtuner.extras.constants import IGNORE_INDEX
 from llmtuner.extras.logging import get_logger
-from llmtuner.extras.template import get_template_and_fix_tokenizer
 
 if TYPE_CHECKING:
     from datasets import Dataset, IterableDataset
@@ -19,6 +19,22 @@ if TYPE_CHECKING:
 
 logger = get_logger(__name__)
 
 
+def construct_example(examples: Dict[str, List[Any]]) -> Generator[Any, None, None]:
+    for i in range(len(examples["prompt"])):
+        query, response = examples["prompt"][i], examples["response"][i]
+        query = query + "\n" + examples["query"][i] if "query" in examples and examples["query"][i] else query
+        history = examples["history"][i] if "history" in examples else None
+        system = examples["system"][i] if "system" in examples else None
+        yield query, response, history, system
+
+
+def infer_max_len(source_len: int, target_len: int, data_args: "DataArguments") -> Tuple[int, int]:
+    max_target_len = int(data_args.cutoff_len * (target_len / (source_len + target_len)))
+    max_target_len = max(max_target_len, data_args.reserved_label_len)
+    max_source_len = data_args.cutoff_len - max_target_len
+    return max_source_len, max_target_len
+
+
 def preprocess_dataset(
     dataset: Union["Dataset", "IterableDataset"],
     tokenizer: "PreTrainedTokenizer",
@@ -31,14 +47,6 @@ def preprocess_dataset(
     if data_args.train_on_prompt and template.efficient_eos:
         raise ValueError("Current template does not support `train_on_prompt`.")
 
-    def construct_example(examples: Dict[str, List[Any]]) -> Generator[Any, None, None]:
-        for i in range(len(examples["prompt"])):
-            query, response = examples["prompt"][i], examples["response"][i]
-            query = query + "\n" + examples["query"][i] if "query" in examples and examples["query"][i] else query
-            history = examples["history"][i] if "history" in examples else None
-            system = examples["system"][i] if "system" in examples else None
-            yield query, response, history, system
-
     def preprocess_pretrain_dataset(examples: Dict[str, List[Any]]) -> Dict[str, List[List[int]]]:
         # build grouped texts with format `X1 X2 X3 ...`
         if isinstance(getattr(tokenizer, "tokenizer", None), tiktoken.Encoding): # for tiktoken tokenizer (Qwen)
@@ -79,13 +87,11 @@ def preprocess_dataset(
             for turn_idx, (source_ids, target_ids) in enumerate(template.encode_multiturn(
                 tokenizer, query, response, history, system
             )):
-                total_len = len(source_ids) + len(target_ids)
-                max_source_len = int(data_args.cutoff_len * (len(source_ids) / total_len))
-                max_target_len = int(data_args.cutoff_len * (len(target_ids) / total_len))
-
-                if len(source_ids) > max_source_len:
+                source_len, target_len = len(source_ids), len(target_ids)
+                max_source_len, max_target_len = infer_max_len(source_len, target_len, data_args)
+                if source_len > max_source_len:
                     source_ids = source_ids[:max_source_len]
-                if len(target_ids) > max_target_len:
+                if target_len > max_target_len:
                     target_ids = target_ids[:max_target_len]
 
                 if data_args.train_on_prompt:
@@ -187,15 +193,12 @@ def preprocess_dataset(
             chosen_ids += [tokenizer.eos_token_id]
             rejected_ids += [tokenizer.eos_token_id]
 
-            total_len = len(prompt_ids) + max(len(chosen_ids), len(rejected_ids))
-            max_source_len = int(data_args.cutoff_len * (len(prompt_ids) / total_len))
-            max_target_len = int(data_args.cutoff_len * (max(len(chosen_ids), len(rejected_ids)) / total_len))
-
-            if len(prompt_ids) > max_source_len:
+            source_len, target_len = len(prompt_ids), max(len(chosen_ids), len(rejected_ids))
+            max_source_len, max_target_len = infer_max_len(source_len, target_len, data_args)
+            if source_len > max_source_len:
                 prompt_ids = prompt_ids[:max_source_len]
-            if len(chosen_ids) > max_target_len:
+            if target_len > max_target_len:
                 chosen_ids = chosen_ids[:max_target_len]
-            if len(rejected_ids) > max_target_len:
                 rejected_ids = rejected_ids[:max_target_len]
 
             model_inputs["prompt_ids"].append(prompt_ids)
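`infer_max_len` centralizes the truncation budget that was previously computed inline (and duplicated) in the supervised and pairwise branches: the target's share of `cutoff_len` is floored at `reserved_label_len`, so a very long prompt can no longer squeeze the labels out entirely, and the source and target budgets always sum to exactly `cutoff_len`. A worked example with a stand-in for `DataArguments` (an assumption for illustration):

```python
from types import SimpleNamespace


def infer_max_len(source_len, target_len, data_args):
    # Proportional split of the context budget, with a floor on the label side.
    max_target_len = int(data_args.cutoff_len * (target_len / (source_len + target_len)))
    max_target_len = max(max_target_len, data_args.reserved_label_len)
    max_source_len = data_args.cutoff_len - max_target_len
    return max_source_len, max_target_len


data_args = SimpleNamespace(cutoff_len=1024, reserved_label_len=16)
print(infer_max_len(900, 300, data_args))   # (768, 256): 1024 * 300/1200 = 256
print(infer_max_len(5000, 10, data_args))   # (1008, 16): the floor preserves labels
```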
@@ -225,9 +225,6 @@ get_template_and_fix_tokenizer(
     return template
 
 
-r"""
-Supports: https://huggingface.co/tatsu-lab/alpaca-7b-wdiff
-"""
 register_template(
     name="alpaca",
     prefix=[
@@ -246,11 +243,6 @@ register_template(
 )
 
 
-r"""
-Supports: https://huggingface.co/BAAI/AquilaChat-7B
-          https://huggingface.co/BAAI/AquilaChat2-7B
-          https://huggingface.co/BAAI/AquilaChat2-34B
-"""
 register_template(
     name="aquila",
     prefix=[
@@ -273,9 +265,6 @@ register_template(
 )
 
 
-r"""
-Supports: https://huggingface.co/baichuan-inc/Baichuan-13B-Chat
-"""
 register_template(
     name="baichuan",
     prefix=[
@@ -292,10 +281,6 @@ register_template(
 )
 
 
-r"""
-Supports: https://huggingface.co/baichuan-inc/Baichuan2-7B-Chat
-          https://huggingface.co/baichuan-inc/Baichuan2-13B-Chat
-"""
 register_template(
     name="baichuan2",
     prefix=[
@@ -312,9 +297,6 @@ register_template(
 )
 
 
-r"""
-Supports: https://huggingface.co/BelleGroup/BELLE-LLaMA-EXT-13B
-"""
 register_template(
     name="belle",
     prefix=[
@@ -330,9 +312,6 @@ register_template(
 )
 
 
-r"""
-Supports: https://huggingface.co/vivo-ai/BlueLM-7B-Chat
-"""
 register_template(
     name="bluelm",
     prefix=[
@@ -348,9 +327,6 @@ register_template(
 )
 
 
-r"""
-Supports: https://huggingface.co/THUDM/chatglm2-6b
-"""
 register_template(
     name="chatglm2",
     prefix=[
@@ -369,9 +345,6 @@ register_template(
 )
 
 
-r"""
-Supports: https://huggingface.co/THUDM/chatglm3-6b
-"""
 register_template(
     name="chatglm3",
     prefix=[
@@ -395,11 +368,6 @@ register_template(
 )
 
 
-r"""
-Supports: https://huggingface.co/deepseek-ai/deepseek-coder-1.3b-instruct
-          https://huggingface.co/deepseek-ai/deepseek-coder-6.7b-instruct
-          https://huggingface.co/deepseek-ai/deepseek-coder-33b-instruct
-"""
 register_template(
     name="deepseek",
     prefix=[
@@ -426,9 +394,6 @@ register_template(
 )
 
 
-r"""
-Default template.
-"""
 register_template(
     name="default",
     prefix=[
@@ -447,10 +412,22 @@ register_template(
 )
 
 
-r"""
-Supports: https://huggingface.co/internlm/internlm-chat-7b
-          https://huggingface.co/internlm/internlm-chat-20b
-"""
+register_template(
+    name="falcon",
+    prefix=[
+        "{{system}}"
+    ],
+    prompt=[
+        "User: {{query}}\nFalcon:"
+    ],
+    system="",
+    sep=[
+        "\n"
+    ],
+    efficient_eos=True
+)
+
+
 register_template(
     name="intern",
     prefix=[
@@ -473,11 +450,6 @@ register_template(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
r"""
|
|
||||||
Supports: https://huggingface.co/meta-llama/Llama-2-7b-chat-hf
|
|
||||||
https://huggingface.co/meta-llama/Llama-2-13b-chat-hf
|
|
||||||
https://huggingface.co/meta-llama/Llama-2-70b-chat-hf
|
|
||||||
"""
|
|
||||||
register_template(
|
register_template(
|
||||||
name="llama2",
|
name="llama2",
|
||||||
prefix=[
|
prefix=[
|
||||||
@@ -500,10 +472,6 @@ register_template(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
r"""
|
|
||||||
Supports: https://huggingface.co/ziqingyang/chinese-alpaca-2-7b
|
|
||||||
https://huggingface.co/ziqingyang/chinese-alpaca-2-13b
|
|
||||||
"""
|
|
||||||
register_template(
|
register_template(
|
||||||
name="llama2_zh",
|
name="llama2_zh",
|
||||||
prefix=[
|
prefix=[
|
||||||
@@ -517,9 +485,6 @@ register_template(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
r"""
|
|
||||||
Supports: https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.1
|
|
||||||
"""
|
|
||||||
register_template(
|
register_template(
|
||||||
name="mistral",
|
name="mistral",
|
||||||
prefix=[
|
prefix=[
|
||||||
@@ -533,9 +498,6 @@ register_template(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
r"""
|
|
||||||
Supports: https://huggingface.co/openchat/openchat_3.5
|
|
||||||
"""
|
|
||||||
register_template(
|
register_template(
|
||||||
name="openchat",
|
name="openchat",
|
||||||
prefix=[
|
prefix=[
|
||||||
@@ -557,10 +519,6 @@ register_template(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
r"""
|
|
||||||
Supports: https://huggingface.co/Qwen/Qwen-7B-Chat
|
|
||||||
https://huggingface.co/Qwen/Qwen-14B-Chat
|
|
||||||
"""
|
|
||||||
register_template(
|
register_template(
|
||||||
name="qwen",
|
name="qwen",
|
||||||
prefix=[
|
prefix=[
|
||||||
@@ -587,10 +545,6 @@ register_template(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
r"""
|
|
||||||
Supports: https://huggingface.co/HuggingFaceH4/starchat-alpha
|
|
||||||
https://huggingface.co/HuggingFaceH4/starchat-beta
|
|
||||||
"""
|
|
||||||
register_template(
|
register_template(
|
||||||
name="starchat",
|
name="starchat",
|
||||||
prefix=[
|
prefix=[
|
||||||
@@ -631,10 +585,6 @@ register_template(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
r"""
|
|
||||||
Supports: https://huggingface.co/lmsys/vicuna-7b-v1.5
|
|
||||||
https://huggingface.co/lmsys/vicuna-13b-v1.5
|
|
||||||
"""
|
|
||||||
register_template(
|
register_template(
|
||||||
name="vicuna",
|
name="vicuna",
|
||||||
prefix=[
|
prefix=[
|
||||||
@@ -651,10 +601,6 @@ register_template(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
r"""
|
|
||||||
Supports: https://huggingface.co/xverse/XVERSE-7B-Chat
|
|
||||||
https://huggingface.co/xverse/XVERSE-13B-Chat
|
|
||||||
"""
|
|
||||||
register_template(
|
register_template(
|
||||||
name="xverse",
|
name="xverse",
|
||||||
prefix=[
|
prefix=[
|
||||||
@@ -668,11 +614,6 @@ register_template(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
r"""
|
|
||||||
Supports: https://huggingface.co/wenge-research/yayi-7b
|
|
||||||
https://huggingface.co/wenge-research/yayi-7b-llama2
|
|
||||||
https://huggingface.co/wenge-research/yayi-13b-llama2
|
|
||||||
"""
|
|
||||||
register_template(
|
register_template(
|
||||||
name="yayi",
|
name="yayi",
|
||||||
prefix=[
|
prefix=[
|
||||||
@@ -705,10 +646,6 @@ register_template(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
r"""
|
|
||||||
Supports: https://huggingface.co/HuggingFaceH4/zephyr-7b-alpha
|
|
||||||
https://huggingface.co/HuggingFaceH4/zephyr-7b-beta
|
|
||||||
"""
|
|
||||||
register_template(
|
register_template(
|
||||||
name="zephyr",
|
name="zephyr",
|
||||||
prefix=[
|
prefix=[
|
||||||
@@ -727,11 +664,6 @@ register_template(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
r"""
|
|
||||||
Supports: https://huggingface.co/IDEA-CCNL/Ziya-LLaMA-13B-v1
|
|
||||||
https://huggingface.co/IDEA-CCNL/Ziya-LLaMA-13B-v1.1
|
|
||||||
https://huggingface.co/IDEA-CCNL/Ziya2-13B-Chat
|
|
||||||
"""
|
|
||||||
register_template(
|
register_template(
|
||||||
name="ziya",
|
name="ziya",
|
||||||
prefix=[
|
prefix=[
|
||||||
@@ -1,3 +0,0 @@
-from llmtuner.dsets.loader import get_dataset
-from llmtuner.dsets.preprocess import preprocess_dataset
-from llmtuner.dsets.utils import split_dataset
1 src/llmtuner/eval/__init__.py Normal file
@@ -0,0 +1 @@
+from llmtuner.eval.evaluator import Evaluator

116 src/llmtuner/eval/evaluator.py Normal file
@@ -0,0 +1,116 @@
+# Inspired by: https://github.com/hendrycks/test/blob/master/evaluate_flan.py
+
+import os
+import json
+import torch
+import tiktoken
+import numpy as np
+from tqdm import tqdm, trange
+from typing import Any, Dict, List, Optional
+
+from datasets import load_dataset
+from transformers.utils import cached_file
+
+from llmtuner.data.template import get_template_and_fix_tokenizer
+from llmtuner.eval.template import get_eval_template
+from llmtuner.extras.constants import CHOICES, SUBJECTS
+from llmtuner.model import dispatch_model, get_eval_args, load_model_and_tokenizer
+
+
+class Evaluator:
+
+    def __init__(self, args: Optional[Dict[str, Any]] = None) -> None:
+        self.model_args, self.data_args, self.eval_args, finetuning_args = get_eval_args(args)
+        self.model, self.tokenizer = load_model_and_tokenizer(self.model_args, finetuning_args)
+        self.tokenizer.padding_side = "right"  # avoid overflow issue in batched inference for llama2
+        self.model = dispatch_model(self.model)
+        self.template = get_template_and_fix_tokenizer(self.data_args.template, self.tokenizer)
+        self.eval_template = get_eval_template(self.eval_args.lang)
+        self.choice_inputs = self._encode_choices()
+
+    def _encode_choices(self) -> List[int]:
+        if isinstance(getattr(self.tokenizer, "tokenizer", None), tiktoken.Encoding):  # for tiktoken tokenizer (Qwen)
+            kwargs = dict(allowed_special="all")
+        else:
+            kwargs = dict(add_special_tokens=False)
+
+        return [self.tokenizer.encode(self.eval_template.prefix + ch, **kwargs)[-1] for ch in CHOICES]
+
+    @torch.inference_mode()
+    def batch_inference(self, batch_input: Dict[str, torch.Tensor]) -> List[str]:
+        logits = self.model(**batch_input).logits
+        lengths = torch.sum(batch_input["attention_mask"], dim=-1)
+        word_probs = torch.stack([logits[i, lengths[i] - 1] for i in range(len(lengths))], dim=0)
+        choice_probs = torch.nn.functional.softmax(word_probs[:, self.choice_inputs], dim=-1).detach()
+        return [chr(ord("A") + offset.item()) for offset in torch.argmax(choice_probs, dim=-1)]
+
+    def eval(self) -> None:
+        mapping = cached_file(
+            path_or_repo_id=os.path.join(self.eval_args.task_dir, self.eval_args.task),
+            filename="mapping.json",
+            cache_dir=self.model_args.cache_dir,
+            token=self.model_args.hf_hub_token,
+            revision=self.model_args.model_revision
+        )
+        with open(mapping, "r", encoding="utf-8") as f:
+            categorys: Dict[str, Dict[str, str]] = json.load(f)
+
+        category_corrects = {subj: np.array([], dtype="bool") for subj in SUBJECTS}
+        pbar = tqdm(categorys.keys(), desc="Processing subjects", position=0)
+        results = {}
+        for subject in pbar:
+            dataset = load_dataset(
+                path=os.path.join(self.eval_args.task_dir, self.eval_args.task),
+                name=subject,
+                download_mode="force_redownload"
+            )
+            pbar.set_postfix_str(categorys[subject]["name"])
+            inputs, outputs, labels = [], [], []
+            for i in trange(len(dataset[self.data_args.split]), desc="Formatting batches", position=1, leave=False):
+                support_set = dataset["train"].shuffle().select(range(min(self.eval_args.n_shot, len(dataset["train"]))))
+                query, resp, history = self.eval_template.format_example(
+                    target_data=dataset[self.data_args.split][i],
+                    support_set=support_set,
+                    subject_name=categorys[subject]["name"],
+                    use_history=self.template.use_history
+                )
+                input_ids, _ = self.template.encode_oneturn(
+                    tokenizer=self.tokenizer, query=query, resp=resp, history=history
+                )
+                inputs.append({"input_ids": input_ids, "attention_mask": [1] * len(input_ids)})
+                labels.append(resp)
+
+            for i in trange(0, len(inputs), self.eval_args.batch_size, desc="Predicting batches", position=1, leave=False):
+                batch_input = self.tokenizer.pad(
+                    inputs[i : i + self.eval_args.batch_size], return_attention_mask=True, return_tensors="pt"
+                ).to(self.model.device)
+                preds = self.batch_inference(batch_input)
+                outputs += preds
+
+            corrects = (np.array(outputs) == np.array(labels))
+            category_name = categorys[subject]["category"]
+            category_corrects[category_name] = np.concatenate([category_corrects[category_name], corrects], axis=0)
+            category_corrects["Average"] = np.concatenate([category_corrects["Average"], corrects], axis=0)
+            results[subject] = {str(i): outputs[i] for i in range(len(outputs))}
+
+        pbar.close()
+        self._save_results(category_corrects, results)

+    def _save_results(self, category_corrects: Dict[str, np.ndarray], results: Dict[str, Dict[int, str]]) -> None:
+        score_info = "\n".join([
+            "{:>15}: {:.2f}".format(category_name, 100 * np.mean(category_correct))
+            for category_name, category_correct in category_corrects.items() if len(category_correct)
+        ])
+        print(score_info)
+        if self.eval_args.save_dir is not None:
+            os.makedirs(self.eval_args.save_dir, exist_ok=False)
+            with open(os.path.join(self.eval_args.save_dir, "results.json"), "w", encoding="utf-8", newline="\n") as f:
+                json.dump(results, f, indent=2)
+
+            with open(os.path.join(self.eval_args.save_dir, "results.log"), "w", encoding="utf-8", newline="\n") as f:
+                f.write(score_info)
+
+
+if __name__ == "__main__":
+    evaluator = Evaluator()
+    evaluator.eval()
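For orientation, the new evaluator can be driven programmatically as well as from the command line. A minimal sketch, with a placeholder model and argument names taken from the hparams dataclasses that appear later in this diff:

```python
# Hypothetical invocation of the new Evaluator; model path and template are placeholders.
from llmtuner.eval.evaluator import Evaluator

evaluator = Evaluator(args={
    "model_name_or_path": "meta-llama/Llama-2-7b-hf",  # any supported model
    "template": "llama2",        # chat template registered above
    "task": "mmlu",              # must be a folder under task_dir
    "task_dir": "evaluation",
    "lang": "en",
    "n_shot": 5,
    "batch_size": 4
})
evaluator.eval()  # prints per-category accuracy and optionally saves results
```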
86 src/llmtuner/eval/template.py Normal file
@@ -0,0 +1,86 @@
+from dataclasses import dataclass
+from typing import TYPE_CHECKING, Dict, List, Tuple
+
+from llmtuner.extras.constants import CHOICES
+
+if TYPE_CHECKING:
+    from datasets import Dataset
+
+
+@dataclass
+class EvalTemplate:
+
+    system: str
+    choice: str
+    answer: str
+    prefix: str
+
+    def parse_example(
+        self,
+        example: Dict[str, str]
+    ) -> Tuple[str, str]:
+        candidates = [self.choice.format(choice=ch, content=example[ch]) for ch in CHOICES if ch in example]
+        return "".join([example["question"]] + candidates + [self.answer]), example["answer"]
+
+    def format_example(
+        self,
+        target_data: Dict[str, str],
+        support_set: "Dataset",
+        subject_name: str,
+        use_history: bool
+    ) -> Tuple[str, str, List[Tuple[str, str]]]:
+        query, resp = self.parse_example(target_data)
+        history = [self.parse_example(support_set[k]) for k in range(len(support_set))]
+
+        if len(history):
+            temp = history.pop(0)
+            history.insert(0, (self.system.format(subject=subject_name) + temp[0], temp[1]))
+        else:
+            query = self.system.format(subject=subject_name) + query
+
+        if not use_history:
+            query = "\n\n".join(["".join(item) for item in history] + [query])
+            history = []
+        return query.strip(), resp, history
+
+
+eval_templates: Dict[str, EvalTemplate] = {}
+
+
+def register_eval_template(
+    name: str,
+    system: str,
+    choice: str,
+    answer: str,
+    prefix: str
+) -> None:
+    eval_templates[name] = EvalTemplate(
+        system=system,
+        choice=choice,
+        answer=answer,
+        prefix=prefix
+    )
+
+
+def get_eval_template(name: str) -> EvalTemplate:
+    eval_template = eval_templates.get(name, None)
+    assert eval_template is not None, "Template {} does not exist.".format(name)
+    return eval_template
+
+
+register_eval_template(
+    name="en",
+    system="The following are multiple choice questions (with answers) about {subject}.\n\n",
+    choice="\n{choice}. {content}",
+    answer="\nAnswer: ",
+    prefix=" "
+)
+
+
+register_eval_template(
+    name="zh",
+    system="以下是中国关于{subject}考试的单项选择题,请选出其中的正确答案。\n\n",
+    choice="\n{choice}. {content}",
+    answer="\n答案:",
+    prefix="\n"
+)
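To make the template mechanics concrete, here is a small sketch of the `en` template on a made-up MMLU-style record (zero-shot, so the system line is prepended to the query):

```python
# Standalone sketch of the "en" eval template; the example record is invented.
from llmtuner.eval.template import get_eval_template

template = get_eval_template("en")
query, resp, history = template.format_example(
    target_data={
        "question": "What is 2 + 2?",
        "A": "3", "B": "4", "C": "5", "D": "6",
        "answer": "B"
    },
    support_set=[],             # no few-shot examplars
    subject_name="elementary mathematics",
    use_history=True
)
# query = system line + question + "\nA. 3\nB. 4\nC. 5\nD. 6\nAnswer:"
print(query)
print(resp)   # -> "B"
```

With a non-empty `support_set`, the few-shot examplars become `history` turns instead, and the system line is folded into the first of them.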
@@ -12,6 +12,7 @@ from llmtuner.extras.logging import get_logger

 if TYPE_CHECKING:
     from transformers import TrainingArguments, TrainerState, TrainerControl
+    from trl import AutoModelForCausalLMWithValueHead


 logger = get_logger(__name__)
@@ -25,18 +26,24 @@ class SavePeftModelCallback(TrainerCallback):
         """
         if args.should_save:
             output_dir = os.path.join(args.output_dir, "{}-{}".format(PREFIX_CHECKPOINT_DIR, state.global_step))
-            model = kwargs.pop("model")
+            model: "AutoModelForCausalLMWithValueHead" = kwargs.pop("model")
+            model.pretrained_model.config.save_pretrained(output_dir)
+            if model.pretrained_model.can_generate():
+                model.pretrained_model.generation_config.save_pretrained(output_dir)
             if getattr(model, "is_peft_model", False):
-                getattr(model, "pretrained_model").save_pretrained(output_dir)
+                model.pretrained_model.save_pretrained(output_dir)

     def on_train_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
         r"""
         Event called at the end of training.
         """
         if args.should_save:
-            model = kwargs.pop("model")
+            model: "AutoModelForCausalLMWithValueHead" = kwargs.pop("model")
+            model.pretrained_model.config.save_pretrained(args.output_dir)
+            if model.pretrained_model.can_generate():
+                model.pretrained_model.generation_config.save_pretrained(args.output_dir)
             if getattr(model, "is_peft_model", False):
-                getattr(model, "pretrained_model").save_pretrained(args.output_dir)
+                model.pretrained_model.save_pretrained(args.output_dir)


 class LogCallback(TrainerCallback):
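As a usage note, the patched save logic still plugs in like any standard `transformers` callback. A minimal sketch, assuming the callback lives in `llmtuner.extras.callbacks` and that `model`, `train_dataset`, and `tokenizer` are built elsewhere:

```python
# Illustrative sketch only: wiring SavePeftModelCallback into a Trainer.
from transformers import Trainer, TrainingArguments
from llmtuner.extras.callbacks import SavePeftModelCallback  # assumed module path

def build_trainer(model, train_dataset, tokenizer):
    training_args = TrainingArguments(output_dir="outputs", save_steps=500)
    return Trainer(
        model=model,                          # a value-head model in the RLHF stages
        args=training_args,
        train_dataset=train_dataset,
        tokenizer=tokenizer,
        callbacks=[SavePeftModelCallback()],  # now also saves config/generation_config per checkpoint
    )
```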
@@ -1,11 +1,25 @@
+from collections import defaultdict, OrderedDict
+from typing import Dict, Optional
+
+
+CHOICES = ["A", "B", "C", "D"]
+
+DEFAULT_MODULE = defaultdict(str)
+
+DEFAULT_TEMPLATE = defaultdict(str)
+
 IGNORE_INDEX = -100

+LAYERNORM_NAMES = {"norm", "ln"}
+
 LOG_FILE_NAME = "trainer_log.jsonl"

-LAYERNORM_NAMES = ["norm", "ln_f", "ln_attn", "ln_mlp", "ln_1", "ln_2", "ln1", "ln2"]
-
 METHODS = ["full", "freeze", "lora"]

+SUBJECTS = ["Average", "STEM", "Social Sciences", "Humanities", "Other"]
+
+SUPPORTED_MODELS = OrderedDict()
+
 TRAINING_STAGES = {
     "Supervised Fine-Tuning": "sft",
     "Reward Modeling": "rm",
@@ -14,79 +28,251 @@ TRAINING_STAGES = {
     "Pre-Training": "pt"
 }

-SUPPORTED_MODELS = {
-    "LLaMA-7B": "huggyllama/llama-7b",
-    "LLaMA-13B": "huggyllama/llama-13b",
-    "LLaMA-30B": "huggyllama/llama-30b",
-    "LLaMA-65B": "huggyllama/llama-65b",
-    "LLaMA2-7B": "meta-llama/Llama-2-7b-hf",
-    "LLaMA2-13B": "meta-llama/Llama-2-13b-hf",
-    "LLaMA2-70B": "meta-llama/Llama-2-70b-hf",
-    "LLaMA2-7B-Chat": "meta-llama/Llama-2-7b-chat-hf",
-    "LLaMA2-13B-Chat": "meta-llama/Llama-2-13b-chat-hf",
-    "LLaMA2-70B-Chat": "meta-llama/Llama-2-70b-chat-hf",
-    "ChineseLLaMA2-7B": "ziqingyang/chinese-llama-2-7b",
-    "ChineseLLaMA2-13B": "ziqingyang/chinese-llama-2-13b",
-    "ChineseLLaMA2-7B-Chat": "ziqingyang/chinese-alpaca-2-7b",
-    "ChineseLLaMA2-13B-Chat": "ziqingyang/chinese-alpaca-2-13b",
-    "BLOOM-560M": "bigscience/bloom-560m",
-    "BLOOM-3B": "bigscience/bloom-3b",
-    "BLOOM-7B1": "bigscience/bloom-7b1",
-    "BLOOMZ-560M": "bigscience/bloomz-560m",
-    "BLOOMZ-3B": "bigscience/bloomz-3b",
-    "BLOOMZ-7B1-mt": "bigscience/bloomz-7b1-mt",
-    "Falcon-7B": "tiiuae/falcon-7b",
-    "Falcon-40B": "tiiuae/falcon-40b",
-    "Falcon-7B-Chat": "tiiuae/falcon-7b-instruct",
-    "Falcon-40B-Chat": "tiiuae/falcon-40b-instruct",
-    "Baichuan-7B": "baichuan-inc/Baichuan-7B",
-    "Baichuan-13B": "baichuan-inc/Baichuan-13B-Base",
-    "Baichuan-13B-Chat": "baichuan-inc/Baichuan-13B-Chat",
-    "Baichuan2-7B": "baichuan-inc/Baichuan2-7B-Base",
-    "Baichuan2-13B": "baichuan-inc/Baichuan2-13B-Base",
-    "Baichuan2-7B-Chat": "baichuan-inc/Baichuan2-7B-Chat",
-    "Baichuan2-13B-Chat": "baichuan-inc/Baichuan2-13B-Chat",
-    "InternLM-7B": "internlm/internlm-7b",
-    "InternLM-20B": "internlm/internlm-20b",
-    "InternLM-7B-Chat": "internlm/internlm-chat-7b",
-    "InternLM-20B-Chat": "internlm/internlm-chat-20b",
-    "Qwen-7B": "Qwen/Qwen-7B",
-    "Qwen-14B": "Qwen/Qwen-14B",
-    "Qwen-7B-Chat": "Qwen/Qwen-7B-Chat",
-    "Qwen-14B-Chat": "Qwen/Qwen-14B-Chat",
-    "XVERSE-13B": "xverse/XVERSE-13B",
-    "XVERSE-13B-Chat": "xverse/XVERSE-13B-Chat",
-    "ChatGLM2-6B-Chat": "THUDM/chatglm2-6b",
-    "ChatGLM3-6B-Base": "THUDM/chatglm3-6b-base",
-    "ChatGLM3-6B-Chat": "THUDM/chatglm3-6b",
-    "Phi1.5-1.3B": "microsoft/phi-1_5"
-}
-
-DEFAULT_MODULE = {
-    "LLaMA": "q_proj,v_proj",
-    "LLaMA2": "q_proj,v_proj",
-    "ChineseLLaMA2": "q_proj,v_proj",
-    "BLOOM": "query_key_value",
-    "BLOOMZ": "query_key_value",
-    "Falcon": "query_key_value",
-    "Baichuan": "W_pack",
-    "Baichuan2": "W_pack",
-    "InternLM": "q_proj,v_proj",
-    "Qwen": "c_attn",
-    "XVERSE": "q_proj,v_proj",
-    "ChatGLM2": "query_key_value",
-    "ChatGLM3": "query_key_value",
-    "Phi1.5": "Wqkv"
-}
-
-DEFAULT_TEMPLATE = {
-    "LLaMA2": "llama2",
-    "ChineseLLaMA2": "llama2_zh",
-    "Baichuan": "baichuan",
-    "Baichuan2": "baichuan2",
-    "InternLM": "intern",
-    "Qwen": "chatml",
-    "XVERSE": "xverse",
-    "ChatGLM2": "chatglm2",
-    "ChatGLM3": "chatglm3"
-}
+
+def register_model_group(
+    models: Dict[str, str],
+    module: Optional[str] = None,
+    template: Optional[str] = None
+) -> None:
+    prefix = None
+    for name, path in models.items():
+        if prefix is None:
+            prefix = name.split("-")[0]
+        else:
+            assert prefix == name.split("-")[0], "prefix should be identical."
+        SUPPORTED_MODELS[name] = path
+    if module is not None:
+        DEFAULT_MODULE[prefix] = module
+    if template is not None:
+        DEFAULT_TEMPLATE[prefix] = template
+
+
+register_model_group(
+    models={
+        "Baichuan-7B-Base": "baichuan-inc/Baichuan-7B",
+        "Baichuan-13B-Base": "baichuan-inc/Baichuan-13B-Base",
+        "Baichuan-13B-Chat": "baichuan-inc/Baichuan-13B-Chat"
+    },
+    module="W_pack",
+    template="baichuan"
+)
+
+
+register_model_group(
+    models={
+        "Baichuan2-7B-Base": "baichuan-inc/Baichuan2-7B-Base",
+        "Baichuan2-13B-Base": "baichuan-inc/Baichuan2-13B-Base",
+        "Baichuan2-7B-Chat": "baichuan-inc/Baichuan2-7B-Chat",
+        "Baichuan2-13B-Chat": "baichuan-inc/Baichuan2-13B-Chat"
+    },
+    module="W_pack",
+    template="baichuan2"
+)
+
+
+register_model_group(
+    models={
+        "BLOOM-560M": "bigscience/bloom-560m",
+        "BLOOM-3B": "bigscience/bloom-3b",
+        "BLOOM-7B1": "bigscience/bloom-7b1"
+    },
+    module="query_key_value"
+)
+
+
+register_model_group(
+    models={
+        "BLOOMZ-560M": "bigscience/bloomz-560m",
+        "BLOOMZ-3B": "bigscience/bloomz-3b",
+        "BLOOMZ-7B1-mt": "bigscience/bloomz-7b1-mt"
+    },
+    module="query_key_value"
+)
+
+
+register_model_group(
+    models={
+        "BlueLM-7B-Base": "vivo-ai/BlueLM-7B-Base",
+        "BlueLM-7B-Chat": "vivo-ai/BlueLM-7B-Chat"
+    },
+    template="bluelm"
+)
+
+
+register_model_group(
+    models={
+        "ChatGLM2-6B-Chat": "THUDM/chatglm2-6b"
+    },
+    module="query_key_value",
+    template="chatglm2"
+)
+
+
+register_model_group(
+    models={
+        "ChatGLM3-6B-Base": "THUDM/chatglm3-6b-base",
+        "ChatGLM3-6B-Chat": "THUDM/chatglm3-6b"
+    },
+    module="query_key_value",
+    template="chatglm3"
+)
+
+
+register_model_group(
+    models={
+        "ChineseLLaMA2-1.3B": "hfl/chinese-llama-2-1.3b",
+        "ChineseLLaMA2-7B": "hfl/chinese-llama-2-7b",
+        "ChineseLLaMA2-13B": "hfl/chinese-llama-2-13b",
+        "ChineseLLaMA2-1.3B-Chat": "hfl/chinese-alpaca-2-1.3b",
+        "ChineseLLaMA2-7B-Chat": "hfl/chinese-alpaca-2-7b",
+        "ChineseLLaMA2-13B-Chat": "hfl/chinese-alpaca-2-13b"
+    },
+    template="llama2_zh"
+)
+
+
+register_model_group(
+    models={
+        "Falcon-7B": "tiiuae/falcon-7b",
+        "Falcon-40B": "tiiuae/falcon-40b",
+        "Falcon-180B": "tiiuae/falcon-180B",
+        "Falcon-7B-Chat": "tiiuae/falcon-7b-instruct",
+        "Falcon-40B-Chat": "tiiuae/falcon-40b-instruct",
+        "Falcon-180B-Chat": "tiiuae/falcon-180B-chat"
+    },
+    module="query_key_value",
+    template="falcon"
+)
+
+
+register_model_group(
+    models={
+        "InternLM-7B": "internlm/internlm-7b",
+        "InternLM-20B": "internlm/internlm-20b",
+        "InternLM-7B-Chat": "internlm/internlm-chat-7b",
+        "InternLM-20B-Chat": "internlm/internlm-chat-20b"
+    },
+    template="intern"
+)
+
+
+register_model_group(
+    models={
+        "LingoWhale-8B": "deeplang-ai/LingoWhale-8B"
+    },
+    module="qkv_proj"
+)
+
+
+register_model_group(
+    models={
+        "LLaMA-7B": "huggyllama/llama-7b",
+        "LLaMA-13B": "huggyllama/llama-13b",
+        "LLaMA-30B": "huggyllama/llama-30b",
+        "LLaMA-65B": "huggyllama/llama-65b"
+    }
+)
+
+
+register_model_group(
+    models={
+        "LLaMA2-7B": "meta-llama/Llama-2-7b-hf",
+        "LLaMA2-13B": "meta-llama/Llama-2-13b-hf",
+        "LLaMA2-70B": "meta-llama/Llama-2-70b-hf",
+        "LLaMA2-7B-Chat": "meta-llama/Llama-2-7b-chat-hf",
+        "LLaMA2-13B-Chat": "meta-llama/Llama-2-13b-chat-hf",
+        "LLaMA2-70B-Chat": "meta-llama/Llama-2-70b-chat-hf"
+    },
+    template="llama2"
+)
+
+
+register_model_group(
+    models={
+        "Mistral-7B": "mistralai/Mistral-7B-v0.1",
+        "Mistral-7B-Chat": "mistralai/Mistral-7B-Instruct-v0.1"
+    },
+    template="mistral"
+)
+
+
+register_model_group(
+    models={
+        "OpenChat3.5-7B-Chat": "openchat/openchat_3.5"
+    },
+    template="openchat"
+)
+
+
+register_model_group(
+    models={
+        "Phi1.5-1.3B": "microsoft/phi-1_5"
+    },
+    module="Wqkv"
+)
+
+
+register_model_group(
+    models={
+        "Qwen-7B": "Qwen/Qwen-7B",
+        "Qwen-14B": "Qwen/Qwen-14B",
+        "Qwen-7B-Chat": "Qwen/Qwen-7B-Chat",
+        "Qwen-14B-Chat": "Qwen/Qwen-14B-Chat"
+    },
+    module="c_attn",
+    template="qwen"
+)
+
+
+register_model_group(
+    models={
+        "Skywork-13B-Base": "Skywork/Skywork-13B-base"
+    }
+)
+
+
+register_model_group(
+    models={
+        "Vicuna1.5-7B-Chat": "lmsys/vicuna-7b-v1.5",
+        "Vicuna1.5-13B-Chat": "lmsys/vicuna-13b-v1.5"
+    },
+    template="vicuna"
+)
+
+
+register_model_group(
+    models={
+        "XVERSE-7B": "xverse/XVERSE-7B",
+        "XVERSE-13B": "xverse/XVERSE-13B",
+        "XVERSE-65B": "xverse/XVERSE-65B",
+        "XVERSE-7B-Chat": "xverse/XVERSE-7B-Chat",
+        "XVERSE-13B-Chat": "xverse/XVERSE-13B-Chat"
+    },
+    template="xverse"
+)
+
+
+register_model_group(
+    models={
+        "Yayi-7B": "wenge-research/yayi-7b-llama2",
+        "Yayi-13B": "wenge-research/yayi-13b-llama2"
+    },
+    template="yayi"
+)
+
+
+register_model_group(
+    models={
+        "Yi-6B": "01-ai/Yi-6B",
+        "Yi-34B": "01-ai/Yi-34B"
+    }
+)
+
+
+register_model_group(
+    models={
+        "Zephyr-7B-Alpha-Chat": "HuggingFaceH4/zephyr-7b-alpha",
+        "Zephyr-7B-Beta-Chat": "HuggingFaceH4/zephyr-7b-beta"
+    },
+    template="zephyr"
+)
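The registries populated above are keyed by the prefix before the first `-` of each display name, so the default LoRA module and chat template can be recovered from a model name alone. A lookup sketch, assuming these constants live in `llmtuner.extras.constants`:

```python
# Sketch: how the registries built by register_model_group are consumed.
from llmtuner.extras.constants import DEFAULT_MODULE, DEFAULT_TEMPLATE, SUPPORTED_MODELS

name = "Baichuan2-13B-Chat"
prefix = name.split("-")[0]        # "Baichuan2"
print(SUPPORTED_MODELS[name])      # "baichuan-inc/Baichuan2-13B-Chat"
print(DEFAULT_MODULE[prefix])      # "W_pack"; defaultdict(str) yields "" for unregistered prefixes
print(DEFAULT_TEMPLATE[prefix])    # "baichuan2"
```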
@@ -3,6 +3,9 @@ import logging


 class LoggerHandler(logging.Handler):
+    r"""
+    Logger handler used in Web UI.
+    """

     def __init__(self):
         super().__init__()
@@ -19,16 +22,10 @@ class LoggerHandler(logging.Handler):
         self.log += "\n\n"


-def reset_logging():
-    r"""
-    Removes basic config of root logger
-    """
-    root = logging.getLogger()
-    list(map(root.removeHandler, root.handlers))
-    list(map(root.removeFilter, root.filters))
-
-
 def get_logger(name: str) -> logging.Logger:
+    r"""
+    Gets a standard logger with a stream handler to stdout.
+    """
     formatter = logging.Formatter(
         fmt="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
         datefmt="%m/%d/%Y %H:%M:%S"
@@ -41,3 +38,12 @@ def get_logger(name: str) -> logging.Logger:
     logger.addHandler(handler)

     return logger
+
+
+def reset_logging() -> None:
+    r"""
+    Removes the basic config of the root logger. (unused in script)
+    """
+    root = logging.getLogger()
+    list(map(root.removeHandler, root.handlers))
+    list(map(root.removeFilter, root.filters))
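For reference, loggers produced by this helper are plain `logging.Logger` objects with a stdout stream handler attached; a one-line smoke test:

```python
# Minimal usage of the logging helpers above.
from llmtuner.extras.logging import get_logger, reset_logging

reset_logging()                    # optional: drop handlers added by logging.basicConfig
logger = get_logger(__name__)
logger.info("tokenizer loaded")    # e.g. "11/15/2023 10:00:00 - INFO - __main__ - tokenizer loaded"
```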
@@ -1,6 +1,8 @@
 import gc
+import os
+import sys
 import torch
-from typing import TYPE_CHECKING, Tuple
+from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple
 from transformers import InfNanRemoveLogitsProcessor, LogitsProcessorList

 try:
@@ -11,13 +13,13 @@ try:
         is_torch_npu_available
     )
     _is_fp16_available = is_torch_npu_available() or is_torch_cuda_available()
-    _is_bf16_available = is_torch_bf16_gpu_available() or is_torch_bf16_cpu_available
+    _is_bf16_available = is_torch_bf16_gpu_available() or is_torch_bf16_cpu_available()
 except ImportError:
     _is_fp16_available = torch.cuda.is_available()
     _is_bf16_available = torch.cuda.is_bf16_supported()

 if TYPE_CHECKING:
-    from transformers.modeling_utils import PreTrainedModel
+    from transformers import HfArgumentParser


 class AverageMeter:
@@ -62,6 +64,25 @@ def count_parameters(model: torch.nn.Module) -> Tuple[int, int]:
     return trainable_params, all_param


+def get_current_device() -> str:
+    import accelerate
+    from accelerate import Accelerator
+    dummy_accelerator = Accelerator()
+    if accelerate.utils.is_xpu_available():
+        return "xpu:{}".format(dummy_accelerator.local_process_index)
+    else:
+        return dummy_accelerator.local_process_index if torch.cuda.is_available() else "cpu"
+
+
+def get_logits_processor() -> "LogitsProcessorList":
+    r"""
+    Gets logits processor that removes NaN and Inf logits.
+    """
+    logits_processor = LogitsProcessorList()
+    logits_processor.append(InfNanRemoveLogitsProcessor())
+    return logits_processor
+
+
 def infer_optim_dtype(model_dtype: torch.dtype) -> torch.dtype:
     r"""
     Infers the optimal dtype according to the model_dtype and device compatibility.
@@ -74,13 +95,15 @@ def infer_optim_dtype(model_dtype: torch.dtype) -> torch.dtype:
     return torch.float32


-def get_logits_processor() -> LogitsProcessorList:
-    r"""
-    Gets logits processor that removes NaN and Inf logits.
-    """
-    logits_processor = LogitsProcessorList()
-    logits_processor.append(InfNanRemoveLogitsProcessor())
-    return logits_processor
+def parse_args(parser: "HfArgumentParser", args: Optional[Dict[str, Any]] = None) -> Tuple[Any]:
+    if args is not None:
+        return parser.parse_dict(args)
+    elif len(sys.argv) == 2 and sys.argv[1].endswith(".yaml"):
+        return parser.parse_yaml_file(os.path.abspath(sys.argv[1]))
+    elif len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
+        return parser.parse_json_file(os.path.abspath(sys.argv[1]))
+    else:
+        return parser.parse_args_into_dataclasses()


 def torch_gc() -> None:
@@ -91,28 +114,3 @@ def torch_gc() -> None:
     if torch.cuda.is_available():
         torch.cuda.empty_cache()
         torch.cuda.ipc_collect()
-
-
-def dispatch_model(model: "PreTrainedModel") -> "PreTrainedModel":
-    r"""
-    Dispatches a pre-trained model to GPUs with balanced memory.
-    Borrowed from: https://github.com/huggingface/transformers/blob/v4.31.0/src/transformers/modeling_utils.py#L2803
-    """
-    if getattr(model, "is_loaded_in_8bit", False) or getattr(model, "is_loaded_in_4bit", False):  # do nothing
-        return model
-
-    if torch.cuda.device_count() > 1:
-        from accelerate import dispatch_model
-        from accelerate.utils import infer_auto_device_map, get_balanced_memory
-
-        if model._no_split_modules is None:
-            raise ValueError("The model class needs to implement the `_no_split_modules` attribute.")
-
-        kwargs = {"dtype": model.dtype, "no_split_module_classes": model._no_split_modules}
-        max_memory = get_balanced_memory(model, **kwargs)
-        # Make sure tied weights are tied before creating the device map.
-        model.tie_weights()
-        device_map = infer_auto_device_map(model, max_memory=max_memory, **kwargs)
-        return dispatch_model(model, device_map)
-    else:
-        return model.cuda()
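The new `parse_args` helper gives every entry point three input styles: an explicit dict, a single `.yaml` or `.json` config path as the sole CLI argument, or ordinary `--flag` arguments. A sketch under the assumption that it is imported from `llmtuner.extras.misc` and paired with one of the hparams dataclasses:

```python
# Sketch of the three argument styles accepted by parse_args.
from transformers import HfArgumentParser
from llmtuner.hparams import DataArguments
from llmtuner.extras.misc import parse_args  # assumed module path

parser = HfArgumentParser(DataArguments)
data_args, = parse_args(parser, args={"dataset_dir": "data", "cutoff_len": 2048})  # dict input
# python train.py config.yaml            -> parser.parse_yaml_file(...)
# python train.py config.json            -> parser.parse_json_file(...)
# python train.py --cutoff_len 2048 ...  -> parser.parse_args_into_dataclasses()
```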
55 src/llmtuner/extras/packages.py Normal file
@@ -0,0 +1,55 @@
+import importlib.metadata
+import importlib.util
+
+
+def is_package_available(name: str) -> bool:
+    return importlib.util.find_spec(name) is not None
+
+
+def get_package_version(name: str) -> str:
+    try:
+        return importlib.metadata.version(name)
+    except:
+        return "0.0.0"
+
+
+_fastapi_available = is_package_available("fastapi")
+_flash_attn2_available = is_package_available("flash_attn") and get_package_version("flash_attn").startswith("2")
+_jieba_available = is_package_available("jieba")
+_matplotlib_available = is_package_available("matplotlib")
+_nltk_available = is_package_available("nltk")
+_rouge_available = is_package_available("rouge-chinese")
+_starlette_available = is_package_available("sse-starlette")
+_uvicorn_available = is_package_available("uvicorn")
+
+
+def is_fastapi_availble():
+    return _fastapi_available
+
+
+def is_flash_attn2_available():
+    return _flash_attn2_available
+
+
+def is_jieba_available():
+    return _jieba_available
+
+
+def is_matplotlib_available():
+    return _matplotlib_available
+
+
+def is_nltk_available():
+    return _nltk_available
+
+
+def is_rouge_available():
+    return _rouge_available
+
+
+def is_starlette_available():
+    return _starlette_available
+
+
+def is_uvicorn_available():
+    return _uvicorn_available
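These probes let callers guard heavy or optional imports at module load instead of crashing on a missing extra; the pattern used elsewhere in this refactor looks like the following sketch:

```python
# Sketch: guarding an optional dependency with the new probes.
from llmtuner.extras.packages import is_matplotlib_available

if is_matplotlib_available():
    import matplotlib.pyplot as plt  # only imported when the wheel is installed

def plot_or_warn(values):
    if not is_matplotlib_available():
        print("matplotlib not installed, skipping plot")
        return
    plt.plot(values)           # safe: the guard above ensures plt exists here
    plt.savefig("loss.png")
```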
@@ -3,13 +3,19 @@ import torch
 import torch.nn as nn
 from typing import Optional, Tuple
 from transformers.utils import logging
-from transformers.models.llama.modeling_llama import LlamaAttention, apply_rotary_pos_emb, repeat_kv
+from transformers.models.llama.modeling_llama import LlamaAttention, apply_rotary_pos_emb

 try:
+    from transformers.models.llama.modeling_llama import repeat_kv
+except ImportError:
+    print("Please upgrade `transformers`.")
+
+from llmtuner.extras.packages import is_flash_attn2_available
+
+
+if is_flash_attn2_available():
     from flash_attn import flash_attn_func, flash_attn_varlen_func  # type: ignore
     from flash_attn.bert_padding import pad_input, unpad_input  # type: ignore
-except ImportError:
-    print("FlashAttention-2 is not installed, ignore this if you are not using FlashAttention.")


 logger = logging.get_logger(__name__)
@@ -1,11 +1,14 @@
 import os
 import math
 import json
-import matplotlib.pyplot as plt
 from typing import List, Optional
 from transformers.trainer import TRAINER_STATE_NAME

 from llmtuner.extras.logging import get_logger
+from llmtuner.extras.packages import is_matplotlib_available
+
+if is_matplotlib_available():
+    import matplotlib.pyplot as plt


 logger = get_logger(__name__)
@@ -1,4 +1,5 @@
 from .data_args import DataArguments
+from .evaluation_args import EvaluationArguments
 from .finetuning_args import FinetuningArguments
 from .generating_args import GeneratingArguments
 from .model_args import ModelArguments
@@ -42,7 +42,7 @@ class DataArguments:
     )
     dataset_dir: Optional[str] = field(
         default="data",
-        metadata={"help": "The name of the folder containing datasets."}
+        metadata={"help": "Path to the folder containing the datasets."}
     )
     split: Optional[str] = field(
         default="train",
@@ -52,6 +52,10 @@ class DataArguments:
         default=1024,
         metadata={"help": "The maximum length of the model inputs after tokenization."}
     )
+    reserved_label_len: Optional[int] = field(
+        default=1,
+        metadata={"help": "The maximum length reserved for label after tokenization."}
+    )
     train_on_prompt: Optional[bool] = field(
         default=False,
         metadata={"help": "Whether to disable the mask on the prompt or not."}
@@ -110,6 +114,9 @@ class DataArguments:
     )

     def __post_init__(self):
+        if self.reserved_label_len >= self.cutoff_len:
+            raise ValueError("`reserved_label_len` must be smaller than `cutoff_len`.")
+
         if self.streaming and self.val_size > 1e-6 and self.val_size < 1:
             raise ValueError("Streaming mode should have an integer val size.")
55 src/llmtuner/hparams/evaluation_args.py Normal file
@@ -0,0 +1,55 @@
+import os
+from typing import Literal, Optional
+from dataclasses import dataclass, field
+
+from datasets import DownloadMode
+
+
+@dataclass
+class EvaluationArguments:
+    r"""
+    Arguments pertaining to specify the evaluation parameters.
+    """
+    task: str = field(
+        metadata={"help": "Name of the evaluation task."}
+    )
+    task_dir: Optional[str] = field(
+        default="evaluation",
+        metadata={"help": "Path to the folder containing the evaluation datasets."}
+    )
+    batch_size: Optional[int] = field(
+        default=4,
+        metadata={"help": "The batch size per GPU for evaluation."}
+    )
+    seed: Optional[int] = field(
+        default=42,
+        metadata={"help": "Random seed to be used with data loaders."}
+    )
+    lang: Optional[Literal["en", "zh"]] = field(
+        default="en",
+        metadata={"help": "Language used at evaluation."}
+    )
+    n_shot: Optional[int] = field(
+        default=5,
+        metadata={"help": "Number of examplars for few-shot learning."}
+    )
+    save_dir: Optional[str] = field(
+        default=None,
+        metadata={"help": "Path to save the evaluation results."}
+    )
+    download_mode: Optional[DownloadMode] = field(
+        default=DownloadMode.REUSE_DATASET_IF_EXISTS,
+        metadata={"help": "Download mode used for the evaluation datasets."}
+    )
+
+    def __post_init__(self):
+        task_available = []
+        for folder in os.listdir(self.task_dir):
+            if os.path.isdir(os.path.join(self.task_dir, folder)):
+                task_available.append(folder)
+
+        if self.task not in task_available:
+            raise ValueError("Task {} not found in {}.".format(self.task, self.task_dir))
+
+        if self.save_dir is not None and os.path.exists(self.save_dir):
+            raise ValueError("`save_dir` already exists, use another one.")
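Like the other hparams dataclasses, `EvaluationArguments` is meant to be parsed through `HfArgumentParser`. Note that `__post_init__` validates the task folder on disk, so the sketch below assumes an `evaluation/mmlu` directory exists in the working directory:

```python
# Sketch: parsing EvaluationArguments (assumes ./evaluation/mmlu exists on disk).
from transformers import HfArgumentParser
from llmtuner.hparams import EvaluationArguments

parser = HfArgumentParser(EvaluationArguments)
eval_args, = parser.parse_dict({"task": "mmlu", "lang": "en", "n_shot": 5})
print(eval_args.batch_size)  # 4 (default)
```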
@@ -4,38 +4,38 @@ from dataclasses import asdict, dataclass, field


 @dataclass
-class FinetuningArguments:
+class FreezeArguments:
     r"""
-    Arguments pertaining to which techniques we are going to fine-tuning with.
+    Arguments pertaining to the freeze (partial-parameter) training.
     """
-    stage: Optional[Literal["pt", "sft", "rm", "ppo", "dpo"]] = field(
-        default="sft",
-        metadata={"help": "Which stage will be performed in training."}
-    )
-    finetuning_type: Optional[Literal["lora", "freeze", "full", "none"]] = field(
-        default="lora",
-        metadata={"help": "Which fine-tuning method to use."}
-    )
     num_layer_trainable: Optional[int] = field(
         default=3,
         metadata={"help": "Number of trainable layers for partial-parameter (freeze) fine-tuning."}
     )
-    name_module_trainable: Optional[Literal["mlp", "self_attn", "self_attention"]] = field(
+    name_module_trainable: Optional[str] = field(
         default="mlp",
         metadata={"help": "Name of trainable modules for partial-parameter (freeze) fine-tuning. \
+                  Use commas to separate multiple modules. \
                   LLaMA choices: [\"mlp\", \"self_attn\"], \
                   BLOOM & Falcon & ChatGLM choices: [\"mlp\", \"self_attention\"], \
                   Qwen choices: [\"mlp\", \"attn\"], \
                   Phi-1.5 choices: [\"mlp\", \"mixer\"], \
-                  LLaMA-2, BlueLM, Baichuan, InternLM, Mistral, Skywork, XVERSE, Yi choices: the same as LLaMA."}
+                  Others choices: the same as LLaMA."}
     )
+
+
+@dataclass
+class LoraArguments:
+    r"""
+    Arguments pertaining to the LoRA training.
+    """
     lora_rank: Optional[int] = field(
         default=8,
         metadata={"help": "The intrinsic dimension for LoRA fine-tuning."}
     )
     lora_alpha: Optional[float] = field(
-        default=32.0,
-        metadata={"help": "The scale factor for LoRA fine-tuning (similar with the learning rate)."}
+        default=None,
+        metadata={"help": "The scale factor for LoRA fine-tuning (default: lora_rank * 2.0)."}
     )
     lora_dropout: Optional[float] = field(
         default=0.1,
@@ -45,11 +45,11 @@ class FinetuningArguments:
         default=None,
         metadata={"help": "Name(s) of target modules to apply LoRA. Use commas to separate multiple modules. \
                   LLaMA choices: [\"q_proj\", \"k_proj\", \"v_proj\", \"o_proj\", \"gate_proj\", \"up_proj\", \"down_proj\"], \
-                  BLOOM & Falcon & ChatGLM choices: [\"query_key_value\", \"self_attention.dense\", \"mlp.dense\"], \
+                  BLOOM & Falcon & ChatGLM choices: [\"query_key_value\", \"dense\", \"dense_h_to_4h\", \"dense_4h_to_h\"], \
                   Baichuan choices: [\"W_pack\", \"o_proj\", \"gate_proj\", \"up_proj\", \"down_proj\"], \
                   Qwen choices: [\"c_attn\", \"attn.c_proj\", \"w1\", \"w2\", \"mlp.c_proj\"], \
                   Phi-1.5 choices: [\"Wqkv\", \"out_proj\", \"fc1\", \"fc2\"], \
-                  LLaMA-2, BlueLM, InternLM, Mistral, Skywork, XVERSE, Yi choices: the same as LLaMA."}
+                  Others choices: the same as LLaMA."}
     )
     additional_target: Optional[str] = field(
         default=None,
@@ -59,30 +59,76 @@ class FinetuningArguments:
         default=True,
         metadata={"help": "Whether to resume training from the last LoRA weights or create new weights after merging them."}
     )
-    ppo_score_norm: Optional[bool] = field(
-        default=False,
-        metadata={"help": "Use score normalization in PPO training."}
+
+
+@dataclass
+class RLHFArguments:
+    r"""
+    Arguments pertaining to the PPO and DPO training.
+    """
+    dpo_beta: Optional[float] = field(
+        default=0.1,
+        metadata={"help": "The beta parameter for the DPO loss."}
     )
     ppo_logger: Optional[str] = field(
         default=None,
         metadata={"help": "Log with either 'wandb' or 'tensorboard' in PPO training."}
     )
+    ppo_score_norm: Optional[bool] = field(
+        default=False,
+        metadata={"help": "Use score normalization in PPO training."}
+    )
     ppo_target: Optional[float] = field(
         default=6.0,
         metadata={"help": "Target KL value for adaptive KL control in PPO training."}
     )
-    dpo_beta: Optional[float] = field(
-        default=0.1,
-        metadata={"help": "The beta parameter for the DPO loss."}
+    ppo_whiten_rewards: Optional[bool] = field(
+        default=False,
+        metadata={"help": "Whiten the rewards before compute advantages in PPO training."}
     )
-    dpo_ref_model: Optional[str] = field(
+    ref_model: Optional[str] = field(
         default=None,
-        metadata={"help": "Path to the reference model used for the DPO training."}
+        metadata={"help": "Path to the reference model used for the PPO or DPO training."}
     )
-    dpo_ref_model_checkpoint: Optional[str] = field(
+    ref_model_checkpoint: Optional[str] = field(
         default=None,
         metadata={"help": "Path to the directory(s) containing the model checkpoints of the reference model."}
     )
+    ref_model_quantization_bit: Optional[int] = field(
+        default=None,
+        metadata={"help": "The number of bits to quantize the reference model."}
+    )
+    reward_model: Optional[str] = field(
+        default=None,
+        metadata={"help": "Path to the directory containing the checkpoints of the reward model."}
+    )
+    reward_model_checkpoint: Optional[str] = field(
+        default=None,
+        metadata={"help": "Path to the directory(s) containing the model checkpoints of the reward model."}
+    )
+    reward_model_quantization_bit: Optional[int] = field(
+        default=None,
+        metadata={"help": "The number of bits to quantize the reward model."}
+    )
+    reward_model_type: Optional[Literal["lora", "full"]] = field(
+        default="lora",
+        metadata={"help": "The checkpoint type of the reward model. The lora type only supports lora training."}
+    )
+
+
+@dataclass
+class FinetuningArguments(FreezeArguments, LoraArguments, RLHFArguments):
+    r"""
+    Arguments pertaining to which techniques we are going to fine-tuning with.
+    """
+    stage: Optional[Literal["pt", "sft", "rm", "ppo", "dpo"]] = field(
+        default="sft",
+        metadata={"help": "Which stage will be performed in training."}
+    )
+    finetuning_type: Optional[Literal["lora", "freeze", "full"]] = field(
+        default="lora",
+        metadata={"help": "Which fine-tuning method to use."}
+    )
     upcast_layernorm: Optional[bool] = field(
         default=False,
         metadata={"help": "Whether to upcast the layernorm weights in fp32."}
@@ -91,15 +137,37 @@ class FinetuningArguments:
         default=0,
         metadata={"help": "The alpha parameter to control the noise magnitude in NEFTune."}
     )
+    export_dir: Optional[str] = field(
+        default=None,
+        metadata={"help": "Path to the directory to save the exported model."}
+    )
+    plot_loss: Optional[bool] = field(
+        default=False,
+        metadata={"help": "Whether to plot the training loss after fine-tuning or not."}
+    )

     def __post_init__(self):
-        if isinstance(self.lora_target, str):  # support custom target modules/layers of LoRA
-            self.lora_target = [target.strip() for target in self.lora_target.split(",")]
+        def split_arg(arg):
+            if isinstance(arg, str):
+                return [item.strip() for item in arg.split(",")]
+            return arg

-        if isinstance(self.additional_target, str):
-            self.additional_target = [target.strip() for target in self.additional_target.split(",")]
+        self.name_module_trainable = split_arg(self.name_module_trainable)
+        self.lora_alpha = self.lora_alpha or float(self.lora_rank * 2.0)
+        self.lora_target = split_arg(self.lora_target)
+        self.additional_target = split_arg(self.additional_target)
+        self.ref_model_checkpoint = split_arg(self.ref_model_checkpoint)
+        self.reward_model_checkpoint = split_arg(self.reward_model_checkpoint)

         assert self.finetuning_type in ["lora", "freeze", "full"], "Invalid fine-tuning method."
+        assert self.ref_model_quantization_bit in [None, 8, 4], "We only accept 4-bit or 8-bit quantization."
+        assert self.reward_model_quantization_bit in [None, 8, 4], "We only accept 4-bit or 8-bit quantization."
+
+        if self.stage == "ppo" and self.reward_model is None:
+            raise ValueError("Reward model is necessary for PPO training.")
+
+        if self.stage == "ppo" and self.reward_model_type == "lora" and self.finetuning_type != "lora":
+            raise ValueError("Lora reward model only supports lora training.")

     def save_to_json(self, json_path: str):
         r"""Saves the content of this instance in JSON format inside `json_path`."""
@@ -112,4 +180,5 @@ class FinetuningArguments:
         r"""Creates an instance from the content of `json_path`."""
         with open(json_path, "r", encoding="utf-8") as f:
             text = f.read()
+
         return cls(**json.loads(text))
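Because `FinetuningArguments` now inherits the three groups, one parsed object carries the freeze, LoRA, and RLHF options together, and `split_arg` normalizes every comma-separated string field into a list. A small sketch of the post-init behavior:

```python
# Sketch: comma-separated fields are split and lora_alpha is derived in __post_init__.
from llmtuner.hparams import FinetuningArguments

args = FinetuningArguments(lora_rank=16, lora_target="q_proj,v_proj")
print(args.lora_target)                   # ['q_proj', 'v_proj']
print(args.lora_alpha)                    # 32.0 (= lora_rank * 2.0 when left unset)
print(args.stage, args.finetuning_type)   # sft lora
```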
@@ -54,22 +54,10 @@ class ModelArguments:
         default=False,
         metadata={"help": "Enable shift short attention (S^2-Attn) proposed by LongLoRA."}
     )
-    reward_model: Optional[str] = field(
-        default=None,
-        metadata={"help": "Path to the directory containing the checkpoints of the reward model."}
-    )
-    plot_loss: Optional[bool] = field(
-        default=False,
-        metadata={"help": "Whether to plot the training loss after fine-tuning or not."}
-    )
     hf_hub_token: Optional[str] = field(
         default=None,
         metadata={"help": "Auth token to log in with Hugging Face Hub."}
     )
-    export_dir: Optional[str] = field(
-        default=None,
-        metadata={"help": "Path to the directory to save the exported model."}
-    )

     def __post_init__(self):
         self.compute_dtype = None
@@ -81,8 +69,7 @@ class ModelArguments:
         if self.checkpoint_dir is not None:  # support merging multiple lora weights
             self.checkpoint_dir = [cd.strip() for cd in self.checkpoint_dir.split(",")]

-        if self.quantization_bit is not None:
-            assert self.quantization_bit in [4, 8], "We only accept 4-bit or 8-bit quantization."
+        assert self.quantization_bit in [None, 8, 4], "We only accept 4-bit or 8-bit quantization."

     def to_dict(self) -> Dict[str, Any]:
         return asdict(self)
src/llmtuner/model/__init__.py (new file, 5 lines)

@@ -0,0 +1,5 @@
+# Level: loader > adapter > parser, utils
+
+from llmtuner.model.loader import load_model_and_tokenizer
+from llmtuner.model.parser import get_train_args, get_infer_args, get_eval_args
+from llmtuner.model.utils import dispatch_model, generate_model_card, load_valuehead_params
src/llmtuner/model/adapter.py (moved from src/llmtuner/tuner/core/adapter.py)

@@ -1,18 +1,9 @@
-import os
 import torch
 from typing import TYPE_CHECKING
-from transformers.utils import cached_file
-from transformers.trainer import WEIGHTS_NAME, SAFE_WEIGHTS_NAME
-from peft import (
-    PeftModel,
-    TaskType,
-    LoraConfig,
-    get_peft_model
-)
+from peft import PeftModel, TaskType, LoraConfig, get_peft_model
 
 from llmtuner.extras.logging import get_logger
-from llmtuner.tuner.core.utils import find_all_linear_modules
+from llmtuner.model.utils import find_all_linear_modules
 
 if TYPE_CHECKING:
     from transformers.modeling_utils import PreTrainedModel
@@ -38,20 +29,31 @@ def init_adapter(
 
     if (not is_trainable) and model_args.checkpoint_dir is None:
         logger.info("Checkpoint is not found at evaluation, load the original model.")
+        return model
 
     if finetuning_args.finetuning_type == "full" and is_trainable:
         logger.info("Fine-tuning method: Full")
         model = model.float()
 
-    if finetuning_args.finetuning_type == "freeze":
+    if finetuning_args.finetuning_type == "freeze" and is_trainable:
         logger.info("Fine-tuning method: Freeze")
-        num_layers = getattr(model.config, "num_layers")
+        num_layers = (
+            getattr(model.config, "num_hidden_layers", None)
+            or getattr(model.config, "num_layers", None)
+            or getattr(model.config, "n_layer", None)
+        )
+        if not num_layers:
+            raise ValueError("Current model does not support freeze tuning.")
         if finetuning_args.num_layer_trainable > 0: # fine-tuning the last n layers if num_layer_trainable > 0
             trainable_layer_ids = [num_layers - k - 1 for k in range(finetuning_args.num_layer_trainable)]
         else: # fine-tuning the first n layers if num_layer_trainable < 0
             trainable_layer_ids = [k for k in range(-finetuning_args.num_layer_trainable)]
 
-        trainable_layers = ["{:d}.{}".format(idx, finetuning_args.name_module_trainable) for idx in trainable_layer_ids]
+        trainable_layers = []
+        for module_name in finetuning_args.name_module_trainable:
+            for idx in trainable_layer_ids:
+                trainable_layers.append("{:d}.{}".format(idx, module_name))
 
         for name, param in model.named_parameters():
             if not any(trainable_layer in name for trainable_layer in trainable_layers):
                 param.requires_grad_(False)
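Freeze tuning now resolves the layer count across the common config attribute names and matches trainable module patterns per layer. A small sketch of the index arithmetic, assuming a 32-layer model:

```python
num_layers = 32

def layer_ids(num_layer_trainable: int) -> list:
    if num_layer_trainable > 0:   # train the last n layers
        return [num_layers - k - 1 for k in range(num_layer_trainable)]
    return list(range(-num_layer_trainable))  # train the first n layers

assert layer_ids(3) == [31, 30, 29]
assert layer_ids(-3) == [0, 1, 2]

# Pattern names matched against parameter names, e.g. "31.mlp" ("mlp" is a placeholder module name).
trainable_layers = ["{:d}.{}".format(idx, m) for m in ["mlp"] for idx in layer_ids(3)]
```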
@@ -99,30 +101,3 @@ def init_adapter(
         logger.info("Loaded fine-tuned model from checkpoint(s): {}".format(",".join(model_args.checkpoint_dir)))
 
     return model
-
-
-def load_valuehead_params(
-    model: "PreTrainedModel",
-    model_args: "ModelArguments"
-) -> bool:
-    kwargs = {
-        "path_or_repo_id": model_args.reward_model,
-        "cache_dir": model_args.cache_dir,
-        "token": model_args.hf_hub_token,
-        "revision": model_args.model_revision
-    }
-    try:
-        vhead_file = cached_file(filename=WEIGHTS_NAME, **kwargs)
-    except:
-        try:
-            vhead_file = cached_file(filename=SAFE_WEIGHTS_NAME, **kwargs)
-        except:
-            logger.warning("Provided path ({}) does not contain valuehead weights.".format(model_args.reward_model))
-            return False
-
-    vhead_params = torch.load(vhead_file, map_location="cpu")
-    model.register_buffer("reward_head_weight", vhead_params["v_head.summary.weight"], persistent=False)
-    model.register_buffer("reward_head_bias", vhead_params["v_head.summary.bias"], persistent=False)
-    model.register_buffer("default_head_weight", torch.zeros_like(vhead_params["v_head.summary.weight"]), persistent=False)
-    model.register_buffer("default_head_bias", torch.zeros_like(vhead_params["v_head.summary.bias"]), persistent=False)
-    return True
src/llmtuner/model/loader.py (moved from src/llmtuner/tuner/core/loader.py)

@@ -15,7 +15,6 @@ from transformers import (
 )
 from transformers.models.llama import modeling_llama as LlamaModule
 from transformers.utils.versions import require_version
-from peft import PeftModel
 from trl import AutoModelForCausalLMWithValueHead
 
 try:

@@ -24,11 +23,12 @@ except ImportError: # https://github.com/huggingface/transformers/releases/tag/v
     from transformers.deepspeed import is_deepspeed_zero3_enabled
 
 from llmtuner.extras.logging import reset_logging, get_logger
-from llmtuner.extras.misc import count_parameters, infer_optim_dtype
+from llmtuner.extras.misc import count_parameters, get_current_device, infer_optim_dtype
+from llmtuner.extras.packages import is_flash_attn2_available
 from llmtuner.extras.patches import llama_patch as LlamaPatches
 from llmtuner.hparams import FinetuningArguments
-from llmtuner.tuner.core.adapter import init_adapter, load_valuehead_params
-from llmtuner.tuner.core.utils import prepare_model_for_training
+from llmtuner.model.adapter import init_adapter
+from llmtuner.model.utils import load_valuehead_params, prepare_model_for_training
 
 if TYPE_CHECKING:
     from transformers import PreTrainedTokenizer
@@ -42,7 +42,7 @@ require_version("transformers>=4.31.0,<4.35.0", "To fix: pip install \"transformers>=4.31.0,<4.35.0\"")
 require_version("datasets>=2.14.0", "To fix: pip install datasets>=2.14.0")
 require_version("accelerate>=0.21.0", "To fix: pip install accelerate>=0.21.0")
 require_version("peft>=0.6.0", "To fix: pip install peft>=0.6.0")
-require_version("trl==0.7.2", "To fix: pip install trl==0.7.2")
+require_version("trl>=0.7.4", "To fix: pip install trl>=0.7.4")
 
 
 def load_model_and_tokenizer(
@@ -73,6 +73,7 @@ def load_model_and_tokenizer(
     )
 
     if finetuning_args.finetuning_type != "lora" and model_args.checkpoint_dir is not None:
+        logger.info("Use `model_name_or_path` to specify the model trained with full/freeze method.")
         model_to_load = model_args.checkpoint_dir[0]
     else:
         model_to_load = model_args.model_name_or_path
@@ -84,10 +85,9 @@ def load_model_and_tokenizer(
     tokenizer._pad = MethodType(PreTrainedTokenizerBase._pad, tokenizer)
 
     # Set model dtype
-    if model_args.compute_dtype is not None: # for training
-        setattr(config, "torch_dtype", model_args.compute_dtype)
-    else: # for evaluation, priority: bf16 > fp16 > fp32
+    if model_args.compute_dtype is None: # priority: bf16 > fp16 > fp32
         model_args.compute_dtype = infer_optim_dtype(model_dtype=getattr(config, "torch_dtype", None))
+    setattr(config, "torch_dtype", model_args.compute_dtype)
 
     # Fix config (for Qwen)
     if getattr(config, "model_type", None) == "qwen":
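The dtype now flows one way: when no compute dtype is requested, it is inferred from the checkpoint's `torch_dtype` with bf16 preferred over fp16 over fp32, and the result is always written back to the config. The real helper lives in `llmtuner.extras.misc` and is not shown in this diff; a standalone sketch of its assumed behavior:

```python
import torch

def infer_optim_dtype(model_dtype: "torch.dtype") -> "torch.dtype":
    """Assumed rule: keep bf16 when the checkpoint uses it and the GPU supports it,
    otherwise fall back to fp16 on GPU and fp32 on CPU."""
    if model_dtype == torch.bfloat16 and torch.cuda.is_available() and torch.cuda.is_bf16_supported():
        return torch.bfloat16
    if torch.cuda.is_available():
        return torch.float16
    return torch.float32
```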
@@ -123,13 +123,16 @@ def load_model_and_tokenizer(
     # Set FlashAttention-2
     if model_args.flash_attn:
         if getattr(config, "model_type", None) == "llama":
-            LlamaModule.LlamaAttention = LlamaPatches.LlamaFlashAttention2
-            LlamaModule.LlamaModel._prepare_decoder_attention_mask = LlamaPatches._prepare_decoder_attention_mask
-            logger.info("Using FlashAttention-2 for faster training and inference.")
+            if is_flash_attn2_available():
+                LlamaModule.LlamaAttention = LlamaPatches.LlamaFlashAttention2
+                LlamaModule.LlamaModel._prepare_decoder_attention_mask = LlamaPatches._prepare_decoder_attention_mask
+                logger.info("Using FlashAttention-2 for faster training and inference.")
+            else:
+                logger.warning("FlashAttention-2 is not installed.")
         elif getattr(config, "model_type", None) in ["qwen", "Yi"]:
             logger.info("Current model automatically enables FlashAttention if installed.")
         else:
-            logger.warning("Current model does not support FlashAttention-2.")
+            logger.warning("Current model does not support FlashAttention.")
     elif is_trainable and model_args.shift_attn and getattr(config, "model_type", None) == "llama":
         LlamaModule.LlamaAttention = LlamaPatches.LlamaShiftShortAttention
         logger.warning("Using `--flash_attn` for faster training in large context length.")
@@ -142,7 +145,7 @@ def load_model_and_tokenizer(
     else:
         logger.warning("Current model does not support shift short attention.")
 
-    # Quantization configurations (using bitsandbytes library).
+    # Quantization configurations (using bitsandbytes library)
     if model_args.quantization_bit is not None:
         if is_deepspeed_zero3_enabled():
             raise ValueError("DeepSpeed ZeRO-3 is incompatible with quantization.")
@@ -162,10 +165,10 @@ def load_model_and_tokenizer(
             bnb_4bit_quant_type=model_args.quantization_type
         )
 
-        config_kwargs["device_map"] = {"": int(os.environ.get("LOCAL_RANK", "0"))} if is_trainable else "auto"
+        config_kwargs["device_map"] = {"": get_current_device()}
        logger.info("Quantizing model to {} bit.".format(model_args.quantization_bit))
 
-    # Load and prepare pre-trained models (without valuehead).
+    # Load pre-trained models (without valuehead)
     model = AutoModelForCausalLM.from_pretrained(
         model_to_load,
         config=config,
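Quantized weights are now pinned to the device returned by `get_current_device()` instead of reading `LOCAL_RANK` inline, so training and inference share one code path. The helper is defined in `llmtuner.extras.misc` and is not shown in this diff; a hypothetical reconstruction of what it resolves:

```python
import os
import torch
from typing import Union

def get_current_device() -> Union[int, str]:
    """Hypothetical sketch: the local CUDA device index under a distributed
    launcher (LOCAL_RANK), device 0 otherwise, and CPU without CUDA."""
    if torch.cuda.is_available():
        return int(os.environ.get("LOCAL_RANK", "0"))
    return "cpu"
```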
@@ -183,7 +186,7 @@ def load_model_and_tokenizer(
     setattr(model, "lm_head", model.transformer.output_layer)
     setattr(model, "_keys_to_ignore_on_save", ["lm_head.weight"])
 
-    # Register auto class to save the custom code files.
+    # Register auto class to save the custom code files
     if isinstance(config, PretrainedConfig) and "AutoConfig" in getattr(config, "auto_map", {}):
         config.__class__.register_for_auto_class()
     if isinstance(model, PreTrainedModel) and "AutoModelForCausalLM" in getattr(config, "auto_map", {}):
@@ -197,25 +200,15 @@ def load_model_and_tokenizer(
     model = model.train() if is_trainable else model.eval()
 
     # Prepare model with valuehead for RLHF
-    if stage == "rm" or stage == "ppo":
+    if stage in ["rm", "ppo"]:
         model: "AutoModelForCausalLMWithValueHead" = AutoModelForCausalLMWithValueHead.from_pretrained(model)
-        reset_logging()
-        if stage == "rm" and model_args.checkpoint_dir is not None: # load valuehead weights to evaluate reward model
-            logger.warning("Only the last checkpoint containing valuehead will be loaded.")
-            if load_valuehead_params(model, model_args):
-                model.v_head.load_state_dict({
-                    "summary.weight": getattr(model, "reward_head_weight"),
-                    "summary.bias": getattr(model, "reward_head_bias")
-                })
-
-        if stage == "ppo": # load reward model
-            logger.info("Load reward model from {}".format(model_args.reward_model))
-            if isinstance(model.pretrained_model, PeftModel):
-                model.pretrained_model.load_adapter(model_args.reward_model, "reward")
-            for name, param in model.named_parameters(): # https://github.com/huggingface/peft/issues/1090
-                if "default" in name:
-                    param.data = param.data.to(torch.float32) # trainable params should in fp32
-            assert load_valuehead_params(model, model_args), "Reward model is not correctly loaded."
+        vhead_path = (
+            model_args.checkpoint_dir[-1] if model_args.checkpoint_dir is not None else model_args.model_name_or_path
+        )
+        vhead_params = load_valuehead_params(vhead_path, model_args)
+        if vhead_params is not None:
+            model.load_state_dict(vhead_params, strict=False)
+            logger.info("Loaded valuehead from checkpoint: {}".format(vhead_path))
 
     # Prepare model for inference
     if not is_trainable:
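With the helper relocated to `llmtuner.model.utils`, RM and PPO now share one path: pick the last checkpoint (falling back to the base model), fetch the `v_head.*` tensors, and merge them non-strictly. A usage sketch with placeholder paths:

```python
from trl import AutoModelForCausalLMWithValueHead

# Placeholders: "base-model" and "rm-checkpoint" stand in for real paths.
model = AutoModelForCausalLMWithValueHead.from_pretrained("base-model")
vhead_params = load_valuehead_params("rm-checkpoint", model_args)  # dict or None
if vhead_params is not None:
    # strict=False applies only the v_head.summary.* keys present in the dict
    model.load_state_dict(vhead_params, strict=False)
```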
src/llmtuner/model/parser.py (moved from src/llmtuner/tuner/core/parser.py)

@@ -1,5 +1,4 @@
 import os
-import sys
 import torch
 import datasets
 import transformers
@@ -8,9 +7,11 @@ from transformers import HfArgumentParser, Seq2SeqTrainingArguments
 from transformers.trainer_utils import get_last_checkpoint
 
 from llmtuner.extras.logging import get_logger
+from llmtuner.extras.misc import parse_args
 from llmtuner.hparams import (
     ModelArguments,
     DataArguments,
+    EvaluationArguments,
     FinetuningArguments,
     GeneratingArguments
 )
@@ -19,62 +20,42 @@ from llmtuner.hparams import (
 logger = get_logger(__name__)
 
 
-def _parse_args(parser: HfArgumentParser, args: Optional[Dict[str, Any]] = None) -> Tuple[Any]:
-    if args is not None:
-        return parser.parse_dict(args)
-    elif len(sys.argv) == 2 and sys.argv[1].endswith(".yaml"):
-        return parser.parse_yaml_file(os.path.abspath(sys.argv[1]))
-    elif len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
-        return parser.parse_json_file(os.path.abspath(sys.argv[1]))
-    else:
-        return parser.parse_args_into_dataclasses()
+_TRAIN_ARGS = [
+    ModelArguments, DataArguments, Seq2SeqTrainingArguments, FinetuningArguments, GeneratingArguments
+]
+_TRAIN_CLS = Tuple[
+    ModelArguments, DataArguments, Seq2SeqTrainingArguments, FinetuningArguments, GeneratingArguments
+]
+_INFER_ARGS = [
+    ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments
+]
+_INFER_CLS = Tuple[
+    ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments
+]
+_EVAL_ARGS = [
+    ModelArguments, DataArguments, EvaluationArguments, FinetuningArguments
+]
+_EVAL_CLS = Tuple[
+    ModelArguments, DataArguments, EvaluationArguments, FinetuningArguments
+]
 
 
-def parse_train_args(
-    args: Optional[Dict[str, Any]] = None
-) -> Tuple[
-    ModelArguments,
-    DataArguments,
-    Seq2SeqTrainingArguments,
-    FinetuningArguments,
-    GeneratingArguments
-]:
-    parser = HfArgumentParser((
-        ModelArguments,
-        DataArguments,
-        Seq2SeqTrainingArguments,
-        FinetuningArguments,
-        GeneratingArguments
-    ))
-    return _parse_args(parser, args)
+def parse_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
+    parser = HfArgumentParser(_TRAIN_ARGS)
+    return parse_args(parser, args)
 
 
-def parse_infer_args(
-    args: Optional[Dict[str, Any]] = None
-) -> Tuple[
-    ModelArguments,
-    DataArguments,
-    FinetuningArguments,
-    GeneratingArguments
-]:
-    parser = HfArgumentParser((
-        ModelArguments,
-        DataArguments,
-        FinetuningArguments,
-        GeneratingArguments
-    ))
-    return _parse_args(parser, args)
+def parse_infer_args(args: Optional[Dict[str, Any]] = None) -> _INFER_CLS:
+    parser = HfArgumentParser(_INFER_ARGS)
+    return parse_args(parser, args)
 
 
-def get_train_args(
-    args: Optional[Dict[str, Any]] = None
-) -> Tuple[
-    ModelArguments,
-    DataArguments,
-    Seq2SeqTrainingArguments,
-    FinetuningArguments,
-    GeneratingArguments
-]:
+def parse_eval_args(args: Optional[Dict[str, Any]] = None) -> _EVAL_CLS:
+    parser = HfArgumentParser(_EVAL_ARGS)
+    return parse_args(parser, args)
+
+
+def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
     model_args, data_args, training_args, finetuning_args, generating_args = parse_train_args(args)
 
     # Setup logging
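`_parse_args` moves to `llmtuner.extras.misc` as `parse_args`; judging by the deleted body above, it dispatches on whether a dict, a single `.yaml`/`.json` path, or plain CLI flags were supplied. A sketch reconstructing that dispatch (the relocated helper itself is not shown in this diff, so treat the exact body as an assumption):

```python
import os
import sys
from typing import Any, Dict, Optional, Tuple
from transformers import HfArgumentParser

def parse_args(parser: HfArgumentParser, args: Optional[Dict[str, Any]] = None) -> Tuple[Any, ...]:
    if args is not None:                                      # programmatic call with a dict
        return parser.parse_dict(args)
    if len(sys.argv) == 2 and sys.argv[1].endswith(".yaml"):  # single YAML config file
        return parser.parse_yaml_file(os.path.abspath(sys.argv[1]))
    if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):  # single JSON config file
        return parser.parse_json_file(os.path.abspath(sys.argv[1]))
    return parser.parse_args_into_dataclasses()               # ordinary CLI flags
```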
@@ -101,24 +82,19 @@ def get_train_args(
         raise ValueError("Please enable `predict_with_generate` to save model predictions.")
 
     if finetuning_args.stage in ["rm", "ppo"]:
-        if finetuning_args.finetuning_type != "lora":
-            raise ValueError("RM and PPO stages can only be performed with the LoRA method.")
         if training_args.resume_from_checkpoint is not None:
             raise ValueError("RM and PPO stages do not support `resume_from_checkpoint`.")
         if training_args.load_best_model_at_end:
             raise ValueError("RM and PPO stages do not support `load_best_model_at_end`.")
 
     if finetuning_args.stage == "ppo" and not training_args.do_train:
-        raise ValueError("PPO training does not support evaluation.")
+        raise ValueError("PPO training does not support evaluation, use the SFT stage to evaluate models.")
 
     if finetuning_args.stage in ["rm", "dpo"]:
         for dataset_attr in data_args.dataset_list:
             if not dataset_attr.ranking:
                 raise ValueError("Please use ranked datasets for reward modeling or DPO training.")
 
-    if finetuning_args.stage == "ppo" and model_args.reward_model is None:
-        raise ValueError("Reward model is necessary for PPO training.")
-
     if finetuning_args.stage == "ppo" and model_args.shift_attn:
         raise ValueError("PPO training is incompatible with S^2-Attn.")
 
@@ -150,6 +126,9 @@ def get_train_args(
     if (not training_args.do_train) and model_args.quantization_bit is not None:
         logger.warning("Evaluating model in 4/8-bit mode may cause lower scores.")
 
+    if (not training_args.do_train) and finetuning_args.stage == "dpo" and finetuning_args.ref_model is None:
+        logger.warning("Specify `ref_model` for computing rewards at evaluation.")
+
     # postprocess training_args
     if (
         training_args.local_rank != -1
@@ -198,14 +177,7 @@ def get_train_args(
     return model_args, data_args, training_args, finetuning_args, generating_args
 
 
-def get_infer_args(
-    args: Optional[Dict[str, Any]] = None
-) -> Tuple[
-    ModelArguments,
-    DataArguments,
-    FinetuningArguments,
-    GeneratingArguments
-]:
+def get_infer_args(args: Optional[Dict[str, Any]] = None) -> _INFER_CLS:
     model_args, data_args, finetuning_args, generating_args = parse_infer_args(args)
 
     if data_args.template is None:
@@ -222,3 +194,17 @@ def get_infer_args(
         raise ValueError("Only LoRA tuning accepts multiple checkpoints.")
 
     return model_args, data_args, finetuning_args, generating_args
+
+
+def get_eval_args(args: Optional[Dict[str, Any]] = None) -> _EVAL_CLS:
+    model_args, data_args, eval_args, finetuning_args = parse_eval_args(args)
+
+    if data_args.template is None:
+        raise ValueError("Please specify which `template` to use.")
+
+    if model_args.quantization_bit is not None and finetuning_args.finetuning_type != "lora":
+        raise ValueError("Quantization is only compatible with the LoRA method.")
+
+    transformers.set_seed(eval_args.seed)
+
+    return model_args, data_args, eval_args, finetuning_args
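Because every `parse_*`/`get_*` entry point still accepts an optional dict, callers can bypass the CLI entirely. A usage sketch in which all argument values are placeholders (keys mirror the dataclass field names):

```python
model_args, data_args, training_args, finetuning_args, generating_args = get_train_args({
    "model_name_or_path": "meta-llama/Llama-2-7b-hf",  # placeholder model
    "stage": "sft",
    "do_train": True,
    "dataset": "alpaca_en",                            # placeholder dataset name
    "finetuning_type": "lora",
    "output_dir": "outputs/sft",
})
```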
src/llmtuner/model/utils.py (moved from src/llmtuner/tuner/core/utils.py)

@@ -1,21 +1,53 @@
 import torch
-from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple
+
+from transformers.utils import cached_file
+from transformers.trainer import WEIGHTS_NAME, SAFE_WEIGHTS_NAME
 
 from llmtuner.extras.constants import LAYERNORM_NAMES
 from llmtuner.extras.logging import get_logger
+from llmtuner.hparams import ModelArguments, FinetuningArguments
 
 if TYPE_CHECKING:
     from transformers.modeling_utils import PreTrainedModel
-    from llmtuner.hparams import ModelArguments, DataArguments, FinetuningArguments
+    from llmtuner.hparams import DataArguments
 
 
 logger = get_logger(__name__)
 
 
+def dispatch_model(model: "PreTrainedModel") -> "PreTrainedModel":
+    r"""
+    Dispatches a pre-trained model to GPUs with balanced memory.
+    Borrowed from: https://github.com/huggingface/transformers/blob/v4.31.0/src/transformers/modeling_utils.py#L2803
+    """
+    if getattr(model, "is_loaded_in_8bit", False) or getattr(model, "is_loaded_in_4bit", False): # do nothing
+        return model
+
+    if torch.cuda.device_count() > 1:
+        from accelerate import dispatch_model
+        from accelerate.utils import infer_auto_device_map, get_balanced_memory
+
+        if model._no_split_modules is None:
+            raise ValueError("The model class needs to implement the `_no_split_modules` attribute.")
+
+        kwargs = {"dtype": model.dtype, "no_split_module_classes": model._no_split_modules}
+        max_memory = get_balanced_memory(model, **kwargs)
+        # Make sure tied weights are tied before creating the device map.
+        model.tie_weights()
+        device_map = infer_auto_device_map(model, max_memory=max_memory, **kwargs)
+        return dispatch_model(model, device_map)
+    else:
+        return model.cuda()
+
+
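The new `dispatch_model` spreads a full-precision model across all visible GPUs at inference time, while quantized models (already placed by bitsandbytes) pass through untouched. A usage sketch, assuming the loader signature shown elsewhere in this diff:

```python
# Illustrative: shard a chat model across GPUs for inference.
model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args, is_trainable=False)
model = dispatch_model(model)  # balanced device map on >1 GPU, plain .cuda() otherwise
```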
 def find_all_linear_modules(
     model: "PreTrainedModel",
     quantization_bit: Optional[int] = None
 ) -> List[str]:
+    r"""
+    Finds all available modules to apply lora.
+    """
     if quantization_bit is not None:
         import bitsandbytes as bnb
         linear_cls = bnb.nn.Linear4bit if quantization_bit == 4 else bnb.nn.Linear8bitLt
@@ -51,12 +83,38 @@ def generate_model_card(
     }
 
 
+def load_valuehead_params(
+    path_or_repo_id: str,
+    model_args: "ModelArguments"
+) -> Dict[str, torch.Tensor]:
+    r"""
+    Loads value head parameters from Hugging Face Hub or local disk.
+
+    Returns: dict with keys `v_head.summary.weight` and `v_head.summary.bias`.
+    """
+    kwargs = {
+        "path_or_repo_id": path_or_repo_id,
+        "cache_dir": model_args.cache_dir,
+        "token": model_args.hf_hub_token
+    }
+    try:
+        vhead_file = cached_file(filename=WEIGHTS_NAME, **kwargs)
+    except:
+        try:
+            vhead_file = cached_file(filename=SAFE_WEIGHTS_NAME, **kwargs)
+        except:
+            logger.warning("Provided path ({}) does not contain valuehead weights.".format(path_or_repo_id))
+            return None
+
+    return torch.load(vhead_file, map_location="cpu")
+
+
 def prepare_model_for_training(
     model: "PreTrainedModel",
     finetuning_args: "FinetuningArguments",
     output_layer_name: Optional[str] = "lm_head",
     use_gradient_checkpointing: Optional[bool] = True,
-    layernorm_names: Optional[List[str]] = LAYERNORM_NAMES
+    layernorm_names: Optional[Set[str]] = LAYERNORM_NAMES
 ) -> "PreTrainedModel":
     r"""
     Includes:
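One caveat worth noting: when only `SAFE_WEIGHTS_NAME` (`model.safetensors`) is found, `torch.load` cannot read it, since safetensors is not a pickle format. A hedged sketch of a more defensive fallback, assuming the `safetensors` package is installed (this is not what the diff does):

```python
import torch
from safetensors.torch import load_file

def load_vhead_file(vhead_file: str) -> dict:
    """Load value-head tensors from either a pickle or a safetensors checkpoint."""
    if vhead_file.endswith(".safetensors"):
        return load_file(vhead_file, device="cpu")
    return torch.load(vhead_file, map_location="cpu")
```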
src/llmtuner/train/__init__.py (new file, 1 line)

@@ -0,0 +1 @@
+from llmtuner.train.tuner import export_model, run_exp

src/llmtuner/train/dpo/__init__.py (new file, 1 line)

@@ -0,0 +1 @@
+from llmtuner.train.dpo.workflow import run_dpo
src/llmtuner/train/dpo/trainer.py (moved from src/llmtuner/tuner/dpo/trainer.py)

@@ -1,6 +1,4 @@
 import torch
-import deepspeed # type: ignore
-from copy import deepcopy
 from collections import defaultdict
 from typing import TYPE_CHECKING, Dict, Literal, Optional, Tuple, Union
 from transformers import BatchEncoding, Trainer

@@ -11,7 +9,6 @@ from llmtuner.extras.constants import IGNORE_INDEX
 
 if TYPE_CHECKING:
     from transformers import PreTrainedModel
-    from trl import PreTrainedModelWrapper
 
 
 class CustomDPOTrainer(DPOTrainer):
@@ -46,40 +43,14 @@ class CustomDPOTrainer(DPOTrainer):
 
         if ref_model is not None:
             if self.is_deepspeed_enabled:
-                self.ref_model = self._prepare_deepspeed(self.ref_model)
+                if not (
+                    getattr(ref_model, "is_loaded_in_8bit", False)
+                    or getattr(ref_model, "is_loaded_in_4bit", False)
+                ): # quantized models are already set on the correct device
+                    self.ref_model = self._prepare_deepspeed(self.ref_model)
             else:
                 self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True)
 
-    def _prepare_deepspeed(self, model: "PreTrainedModelWrapper"):
-        # Adapted from accelerate: https://github.com/huggingface/accelerate/blob/739b135f8367becb67ffaada12fe76e3aa60fefd/src/accelerate/accelerator.py#L1473
-        deepspeed_plugin = self.accelerator.state.deepspeed_plugin
-        config_kwargs = deepcopy(deepspeed_plugin.deepspeed_config)
-        if model is not None:
-            if hasattr(model, "config"):
-                hidden_size = (
-                    max(model.config.hidden_sizes)
-                    if getattr(model.config, "hidden_sizes", None)
-                    else getattr(model.config, "hidden_size", None)
-                )
-                if hidden_size is not None and config_kwargs["zero_optimization"]["stage"] == 3:
-                    # Note that `stage3_prefetch_bucket_size` can produce DeepSpeed messages like: `Invalidate trace cache @ step 0: expected module 1, but got module 0`
-                    # This is expected and is not an error, see: https://github.com/microsoft/DeepSpeed/discussions/4081
-                    config_kwargs.update(
-                        {
-                            "zero_optimization.reduce_bucket_size": hidden_size * hidden_size,
-                            "zero_optimization.stage3_param_persistence_threshold": 10 * hidden_size,
-                            "zero_optimization.stage3_prefetch_bucket_size": 0.9 * hidden_size * hidden_size,
-                        }
-                    )
-
-        # If ZeRO-3 is used, we shard both the active and reference model.
-        # Otherwise, we assume the reference model fits in memory and is initialized on each device with ZeRO disabled (stage 0)
-        if config_kwargs["zero_optimization"]["stage"] != 3:
-            config_kwargs["zero_optimization"]["stage"] = 0
-        model, *_ = deepspeed.initialize(model=model, config=config_kwargs)
-        model.eval()
-        return model
-
     def concatenated_forward(
         self,
         model: Optional[torch.nn.Module] = None,
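The hand-rolled `_prepare_deepspeed` is deleted, yet the surviving branch still calls `self._prepare_deepspeed(...)`; presumably the method is now inherited from trl's `DPOTrainer` base class, which fits the `trl>=0.7.4` pin introduced above. The added guard skips DeepSpeed preparation for quantized reference models, which bitsandbytes has already placed on the correct device. The same check as a standalone predicate (a sketch, not repo code):

```python
def is_quantized(model) -> bool:
    """True if bitsandbytes loaded this model in 8-bit or 4-bit."""
    return getattr(model, "is_loaded_in_8bit", False) or getattr(model, "is_loaded_in_4bit", False)
```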
src/llmtuner/train/dpo/workflow.py (moved from src/llmtuner/tuner/dpo/workflow.py)

@@ -4,23 +4,20 @@ from peft import PeftModel
 from typing import TYPE_CHECKING, Optional, List
 from transformers import Seq2SeqTrainingArguments
 
-from llmtuner.dsets import get_dataset, preprocess_dataset, split_dataset
+from llmtuner.data import get_dataset, preprocess_dataset, split_dataset
 from llmtuner.extras.constants import IGNORE_INDEX
-from llmtuner.extras.logging import get_logger
 from llmtuner.extras.ploting import plot_loss
 from llmtuner.hparams import ModelArguments
-from llmtuner.tuner.core import generate_model_card, load_model_and_tokenizer
-from llmtuner.tuner.dpo.collator import DPODataCollatorWithPadding
-from llmtuner.tuner.dpo.trainer import CustomDPOTrainer
+from llmtuner.model import generate_model_card, load_model_and_tokenizer
+from llmtuner.train.utils import create_ref_model
+from llmtuner.train.dpo.collator import DPODataCollatorWithPadding
+from llmtuner.train.dpo.trainer import CustomDPOTrainer
 
 if TYPE_CHECKING:
     from transformers import TrainerCallback
     from llmtuner.hparams import DataArguments, FinetuningArguments
 
 
-logger = get_logger(__name__)
-
-
 def run_dpo(
     model_args: "ModelArguments",
     data_args: "DataArguments",
@@ -38,23 +35,10 @@ def run_dpo(
     )
 
     # Create reference model
-    if finetuning_args.dpo_ref_model is not None:
-        ref_model_args_dict = model_args.to_dict()
-        ref_model_args_dict.update(dict(
-            model_name_or_path=finetuning_args.dpo_ref_model,
-            checkpoint_dir=finetuning_args.dpo_ref_model_checkpoint
-        ))
-        ref_model_args = ModelArguments(**ref_model_args_dict)
-        ref_model, _ = load_model_and_tokenizer(ref_model_args, finetuning_args, is_trainable=False, stage="sft")
-        logger.info("Created reference model from {}".format(finetuning_args.dpo_ref_model))
-    elif training_args.do_train:
-        if isinstance(model, PeftModel):
-            ref_model = None
-        else:
-            ref_model, _ = load_model_and_tokenizer(model_args, finetuning_args, is_trainable=False, stage="sft")
-            logger.info("Created reference model from the model itself.")
-    else:
+    if finetuning_args.ref_model is None and (not training_args.do_train): # use the model itself
         ref_model = model
+    else:
+        ref_model = create_ref_model(model_args, finetuning_args, stage="dpo")
 
     # Update arguments
     training_args_dict = training_args.to_dict()
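The branching logic collapses into `create_ref_model` in `llmtuner.train.utils`, which is not shown in this diff. Judging from the deleted code, it most likely rebuilds a frozen copy of the policy (or of `finetuning_args.ref_model` when given) via `load_model_and_tokenizer`, returning `None` for LoRA policies whose adapter can simply be disabled. A heavily hedged sketch of that assumed behavior:

```python
# Assumed shape of the helper; the real implementation lives in llmtuner.train.utils.
def create_ref_model(model_args, finetuning_args, stage: str):
    if finetuning_args.ref_model is not None:        # explicit reference checkpoint
        ref_args_dict = model_args.to_dict()
        ref_args_dict.update(
            model_name_or_path=finetuning_args.ref_model,
            checkpoint_dir=finetuning_args.ref_model_checkpoint,
        )
        ref_args = ModelArguments(**ref_args_dict)
        ref_model, _ = load_model_and_tokenizer(ref_args, finetuning_args, is_trainable=False, stage="sft")
        return ref_model
    if finetuning_args.finetuning_type == "lora":    # adapter can be disabled in-place
        return None
    ref_model, _ = load_model_and_tokenizer(model_args, finetuning_args, is_trainable=False, stage="sft")
    return ref_model
```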
@@ -80,14 +64,13 @@ def run_dpo(
     trainer.log_metrics("train", train_result.metrics)
     trainer.save_metrics("train", train_result.metrics)
     trainer.save_state()
-    if trainer.is_world_process_zero() and model_args.plot_loss:
+    if trainer.is_world_process_zero() and finetuning_args.plot_loss:
         plot_loss(training_args.output_dir, keys=["loss", "eval_loss"])
 
     # Evaluation
     if training_args.do_eval:
         metrics = trainer.evaluate(metric_key_prefix="eval")
         if id(model) == id(ref_model): # unable to compute rewards without a reference model
-            logger.warning("Pass `dpo_ref_model` for computing rewards at evaluation.")
             remove_keys = [key for key in metrics.keys() if "rewards" in key]
             for key in remove_keys:
                 metrics.pop(key)
src/llmtuner/train/ppo/__init__.py (new file, 1 line)

@@ -0,0 +1 @@
+from llmtuner.train.ppo.workflow import run_ppo
src/llmtuner/train/ppo/trainer.py (moved from src/llmtuner/tuner/ppo/trainer.py)

@@ -3,9 +3,9 @@
 import math
 import torch
 from tqdm import tqdm
-from typing import TYPE_CHECKING, Dict, List, Optional, Tuple
+from typing import TYPE_CHECKING, List, Optional, Tuple
 
-from transformers import GenerationConfig, Trainer, TrainerState, TrainerControl
+from transformers import BatchEncoding, GenerationConfig, Trainer, TrainerState, TrainerControl
 from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR
 
 from trl import PPOTrainer

@@ -14,7 +14,7 @@ from trl.core import PPODecorators, logprobs_from_logits
 from llmtuner.extras.callbacks import LogCallback, SavePeftModelCallback
 from llmtuner.extras.logging import get_logger
 from llmtuner.extras.misc import AverageMeter, count_parameters, get_logits_processor
-from llmtuner.tuner.ppo.utils import dump_layernorm, restore_layernorm, replace_model
+from llmtuner.train.ppo.utils import dump_layernorm, restore_layernorm, replace_model
 
 if TYPE_CHECKING:
     from transformers import Seq2SeqTrainingArguments, TrainerCallback
@@ -37,24 +37,43 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
         finetuning_args: "FinetuningArguments",
         generating_args: "GeneratingArguments",
         callbacks: List["TrainerCallback"],
+        reward_model: "AutoModelForCausalLMWithValueHead",
         **kwargs
     ):
         PPOTrainer.__init__(self, **kwargs)
 
         self.args = training_args
         self.model_args = model_args
         self.finetuning_args = finetuning_args
+        self.reward_model = reward_model
+
         self.generation_config = GenerationConfig(
             pad_token_id=self.tokenizer.pad_token_id,
             eos_token_id=[self.tokenizer.eos_token_id] + self.tokenizer.additional_special_tokens_ids,
             **generating_args.to_dict()
         )
 
         self.state = TrainerState()
         self.control = TrainerControl()
         self.log_callback, self.save_callback = callbacks[0], callbacks[1]
         assert isinstance(self.log_callback, LogCallback) and isinstance(self.save_callback, SavePeftModelCallback)
 
         if self.args.max_steps > 0:
             logger.info("max_steps is given, it will override any value given in num_train_epochs")
 
+        if reward_model is not None:
+            is_deepspeed_enabled = self.accelerator.distributed_type == "DEEPSPEED" and hasattr(
+                self.accelerator.state, "deepspeed_plugin"
+            )
+            if is_deepspeed_enabled:
+                if not (
+                    getattr(reward_model.pretrained_model, "is_loaded_in_8bit", False)
+                    or getattr(reward_model.pretrained_model, "is_loaded_in_4bit", False)
+                ): # quantized models are already set on the correct device
+                    self.reward_model = self._prepare_deepspeed(self.reward_model)
+            else:
+                self.reward_model = self.accelerator.prepare_model(self.reward_model, evaluation_mode=True)
+
     def ppo_train(self) -> None:
         r"""
         Implements training loop for the PPO stage, like _inner_training_loop() in Huggingface's Trainer.
@@ -108,9 +127,14 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
             self.model.eval()
 
             # Get inputs
-            queries, responses = self.get_inputs(batch)
             self.tokenizer.padding_side = "right" # change padding side
-            rewards = self.get_rewards(queries, responses, unwrapped_model)
+            queries, responses, rewards = [], [], []
+            for idx in range(0, self.config.batch_size, self.config.mini_batch_size):
+                mini_batch_queries, mini_batch_responses = self.get_inputs(batch[idx:idx+self.config.mini_batch_size])
+                mini_batch_rewards = self.get_rewards(mini_batch_queries, mini_batch_responses, unwrapped_model)
+                queries.extend(mini_batch_queries)
+                responses.extend(mini_batch_responses)
+                rewards.extend(mini_batch_rewards)
 
             # Cast to training mode
             unwrapped_model.gradient_checkpointing_enable()
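Generation and reward scoring now run mini-batch by mini-batch instead of over the whole PPO batch at once, which caps peak memory during rollout. The slicing arithmetic in isolation (sizes are example values):

```python
batch_size, mini_batch_size = 8, 2
slices = [(idx, idx + mini_batch_size) for idx in range(0, batch_size, mini_batch_size)]
assert slices == [(0, 2), (2, 4), (4, 6), (6, 8)]  # four rollout/scoring passes per PPO step
```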
@@ -165,7 +189,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
         )
 
     @torch.no_grad()
-    def get_inputs(self, batch: Dict[str, torch.Tensor]) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
+    def get_inputs(self, batch: BatchEncoding) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
         r"""
         Generates model's responses given queries.
         """
@@ -208,25 +232,30 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
         r"""
         Computes scores using given reward model.
         """
-        replace_model(unwrapped_model, target="reward")
+        if self.reward_model is None:
+            replace_model(unwrapped_model, target="reward")
+
         batch = self.prepare_model_inputs(queries, responses)
 
         with torch.cuda.amp.autocast(dtype=self.model_args.compute_dtype): # support bf16
-            _, _, values = self.model(**batch, output_hidden_states=True, return_dict=True)
+            reward_model = self.reward_model if self.reward_model is not None else self.model
+            _, _, values = reward_model(**batch, output_hidden_states=True, return_dict=True)
 
         if values.size(0) != batch["input_ids"].size(0): # adapt to chatglm2
             values = torch.transpose(values, 0, 1)
 
         rewards = []
         for i in range(values.size(0)):
-            end_indexes = (batch["input_ids"][i] != self.tokenizer.eos_token_id).nonzero()
+            end_indexes = (batch["input_ids"][i] != self.tokenizer.pad_token_id).nonzero()
             end_index = end_indexes[-1].item() if len(end_indexes) else 0
             rewards.append(values[i, end_index].float().detach().cpu()) # use fp32 type
 
-        replace_model(unwrapped_model, target="default")
+        if self.reward_model is None:
+            replace_model(unwrapped_model, target="default")
+
         return rewards
 
-    @PPODecorators.empty_cuda_cache()
+    @PPODecorators.empty_device_cache()
     def batched_forward_pass(
         self,
         model: "AutoModelForCausalLMWithValueHead",
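The reward is now read at the last non-pad position instead of the last non-EOS position, which is the correct end of a right-padded sequence when EOS doubles as the pad token. A small check of the indexing (toy ids; pad id 0 is an example value):

```python
import torch

input_ids = torch.tensor([5, 9, 7, 2, 0, 0])  # right-padded; pad_token_id == 0
end_indexes = (input_ids != 0).nonzero()
end_index = end_indexes[-1].item() if len(end_indexes) else 0
assert end_index == 3  # position of the final real token, where the value head is read
```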
src/llmtuner/train/ppo/workflow.py (moved from src/llmtuner/tuner/ppo/workflow.py)

@@ -7,11 +7,12 @@ from typing import TYPE_CHECKING, Optional, List
 from transformers import DataCollatorWithPadding
 from transformers.optimization import get_scheduler
 
-from llmtuner.dsets import get_dataset, preprocess_dataset
+from llmtuner.data import get_dataset, preprocess_dataset
 from llmtuner.extras.callbacks import SavePeftModelCallback
 from llmtuner.extras.ploting import plot_loss
-from llmtuner.tuner.core import load_model_and_tokenizer
-from llmtuner.tuner.ppo.trainer import CustomPPOTrainer
+from llmtuner.model import load_model_and_tokenizer
+from llmtuner.train.utils import create_ref_model, create_reward_model
+from llmtuner.train.ppo.trainer import CustomPPOTrainer
 
 if TYPE_CHECKING:
     from transformers import Seq2SeqTrainingArguments, TrainerCallback
@@ -33,6 +34,11 @@ def run_ppo(
     tokenizer.padding_side = "left" # use left-padding in generation while using right-padding in training
     data_collator = DataCollatorWithPadding(tokenizer=tokenizer)
 
+    # Create reference model and reward model
+    ref_model = create_ref_model(model_args, finetuning_args, stage="ppo")
+    reward_model = create_reward_model(model, model_args, finetuning_args)
+
+    # Create ppo config
     ppo_config = PPOConfig(
         model_name=model_args.model_name_or_path,
         learning_rate=training_args.learning_rate,
@@ -42,14 +48,16 @@ def run_ppo(
         ppo_epochs=1,
         max_grad_norm=training_args.max_grad_norm,
         seed=training_args.seed,
-        optimize_cuda_cache=True,
+        optimize_device_cache=True,
         target=finetuning_args.ppo_target,
         log_with=finetuning_args.ppo_logger,
         use_score_scaling=finetuning_args.ppo_score_norm,
         use_score_norm=finetuning_args.ppo_score_norm,
+        whiten_rewards=finetuning_args.ppo_whiten_rewards,
         accelerator_kwargs={"step_scheduler_with_optimizer": False}
     )
 
+    # Create optimizer and scheduler
     optimizer = AdamW(filter(lambda p: p.requires_grad, model.parameters()), lr=training_args.learning_rate)
     if training_args.max_steps > 0:
         num_training_steps = training_args.max_steps
@@ -73,9 +81,10 @@ def run_ppo(
         finetuning_args=finetuning_args,
         generating_args=generating_args,
         callbacks=callbacks + [SavePeftModelCallback()],
+        reward_model=reward_model,
         config=ppo_config,
         model=model,
-        ref_model=None,
+        ref_model=ref_model,
         tokenizer=tokenizer,
         dataset=dataset,
         data_collator=data_collator,
@@ -88,5 +97,5 @@ def run_ppo(
     ppo_trainer.ppo_train()
     ppo_trainer.save_model()
     ppo_trainer.save_state() # must be called after save_model to have a folder
-    if ppo_trainer.is_world_process_zero() and model_args.plot_loss:
+    if ppo_trainer.is_world_process_zero() and finetuning_args.plot_loss:
         plot_loss(training_args.output_dir, keys=["loss", "reward"])
src/llmtuner/train/pt/__init__.py (new file, 1 line)

@@ -0,0 +1 @@
+from llmtuner.train.pt.workflow import run_pt
src/llmtuner/train/pt/workflow.py (moved from src/llmtuner/tuner/pt/workflow.py)

@@ -4,9 +4,9 @@
 from typing import TYPE_CHECKING, Optional, List
 from transformers import DataCollatorForLanguageModeling, Trainer
 
-from llmtuner.dsets import get_dataset, preprocess_dataset, split_dataset
+from llmtuner.data import get_dataset, preprocess_dataset, split_dataset
 from llmtuner.extras.ploting import plot_loss
-from llmtuner.tuner.core import generate_model_card, load_model_and_tokenizer
+from llmtuner.model import generate_model_card, load_model_and_tokenizer
 
 if TYPE_CHECKING:
     from transformers import Seq2SeqTrainingArguments, TrainerCallback
@@ -42,7 +42,7 @@ def run_pt(
     trainer.log_metrics("train", train_result.metrics)
     trainer.save_metrics("train", train_result.metrics)
     trainer.save_state()
-    if trainer.is_world_process_zero() and model_args.plot_loss:
+    if trainer.is_world_process_zero() and finetuning_args.plot_loss:
         plot_loss(training_args.output_dir, keys=["loss", "eval_loss"])
 
     # Evaluation
src/llmtuner/train/rm/__init__.py (new file, 1 line)

@@ -0,0 +1 @@
+from llmtuner.train.rm.workflow import run_rm
src/llmtuner/train/rm/workflow.py (moved from src/llmtuner/tuner/rm/workflow.py)

@@ -3,13 +3,13 @@
 from typing import TYPE_CHECKING, Optional, List
 from transformers import Seq2SeqTrainingArguments
 
-from llmtuner.dsets import get_dataset, preprocess_dataset, split_dataset
+from llmtuner.data import get_dataset, preprocess_dataset, split_dataset
 from llmtuner.extras.callbacks import SavePeftModelCallback
 from llmtuner.extras.ploting import plot_loss
-from llmtuner.tuner.core import generate_model_card, load_model_and_tokenizer
-from llmtuner.tuner.rm.metric import compute_accuracy
-from llmtuner.tuner.rm.collator import PairwiseDataCollatorWithPadding
-from llmtuner.tuner.rm.trainer import PairwiseTrainer
+from llmtuner.model import generate_model_card, load_model_and_tokenizer
+from llmtuner.train.rm.collator import PairwiseDataCollatorWithPadding
+from llmtuner.train.rm.metric import compute_accuracy
+from llmtuner.train.rm.trainer import PairwiseTrainer
 
 if TYPE_CHECKING:
     from transformers import TrainerCallback
@@ -51,7 +51,7 @@ def run_rm(
     trainer.log_metrics("train", train_result.metrics)
     trainer.save_metrics("train", train_result.metrics)
     trainer.save_state()
-    if trainer.is_world_process_zero() and model_args.plot_loss:
+    if trainer.is_world_process_zero() and finetuning_args.plot_loss:
         plot_loss(training_args.output_dir, keys=["loss", "eval_loss"])
 
     # Evaluation
src/llmtuner/train/sft/__init__.py (new file, 1 line)
@@ -0,0 +1 @@
+from llmtuner.train.sft.workflow import run_sft
@@ -2,15 +2,23 @@ import numpy as np
 from dataclasses import dataclass
 from typing import TYPE_CHECKING, Dict, Sequence, Tuple, Union
 
-import jieba
-from rouge_chinese import Rouge
-from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction
 
 from llmtuner.extras.constants import IGNORE_INDEX
+from llmtuner.extras.packages import (
+    is_jieba_available, is_nltk_available, is_rouge_available
+)
 
 if TYPE_CHECKING:
     from transformers.tokenization_utils import PreTrainedTokenizer
 
+if is_jieba_available():
+    import jieba
+
+if is_nltk_available():
+    from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction
+
+if is_rouge_available():
+    from rouge_chinese import Rouge
+
 
 @dataclass
 class ComputeMetrics:
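The hunk above replaces unconditional imports of jieba, NLTK, and rouge_chinese with imports guarded by availability probes, so the SFT metrics module can load even when these optional packages are missing. A minimal sketch of how such probes are commonly written is below; the real helpers live in llmtuner.extras.packages, whose body is not shown in this diff, so treat this as an illustration rather than the repository's code.

```python
import importlib.util


def _is_package_available(name: str) -> bool:
    # find_spec returns None when the package cannot be imported
    return importlib.util.find_spec(name) is not None


def is_jieba_available() -> bool:
    return _is_package_available("jieba")


def is_nltk_available() -> bool:
    return _is_package_available("nltk")


def is_rouge_available() -> bool:
    return _is_package_available("rouge_chinese")
```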
@@ -3,13 +3,13 @@
 from typing import TYPE_CHECKING, Optional, List
 from transformers import DataCollatorForSeq2Seq, Seq2SeqTrainingArguments
 
-from llmtuner.dsets import get_dataset, preprocess_dataset, split_dataset
+from llmtuner.data import get_dataset, preprocess_dataset, split_dataset
 from llmtuner.extras.constants import IGNORE_INDEX
 from llmtuner.extras.misc import get_logits_processor
 from llmtuner.extras.ploting import plot_loss
-from llmtuner.tuner.core import generate_model_card, load_model_and_tokenizer
-from llmtuner.tuner.sft.metric import ComputeMetrics
-from llmtuner.tuner.sft.trainer import CustomSeq2SeqTrainer
+from llmtuner.model import generate_model_card, load_model_and_tokenizer
+from llmtuner.train.sft.metric import ComputeMetrics
+from llmtuner.train.sft.trainer import CustomSeq2SeqTrainer
 
 if TYPE_CHECKING:
     from transformers import TrainerCallback
@@ -69,7 +69,7 @@ def run_sft(
     trainer.log_metrics("train", train_result.metrics)
     trainer.save_metrics("train", train_result.metrics)
     trainer.save_state()
-    if trainer.is_world_process_zero() and model_args.plot_loss:
+    if trainer.is_world_process_zero() and finetuning_args.plot_loss:
         plot_loss(training_args.output_dir, keys=["loss", "eval_loss"])
 
     # Evaluation
@@ -2,12 +2,12 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional
 
 from llmtuner.extras.callbacks import LogCallback
 from llmtuner.extras.logging import get_logger
-from llmtuner.tuner.core import get_train_args, get_infer_args, load_model_and_tokenizer
-from llmtuner.tuner.pt import run_pt
-from llmtuner.tuner.sft import run_sft
-from llmtuner.tuner.rm import run_rm
-from llmtuner.tuner.ppo import run_ppo
-from llmtuner.tuner.dpo import run_dpo
+from llmtuner.model import get_train_args, get_infer_args, load_model_and_tokenizer
+from llmtuner.train.pt import run_pt
+from llmtuner.train.sft import run_sft
+from llmtuner.train.rm import run_rm
+from llmtuner.train.ppo import run_ppo
+from llmtuner.train.dpo import run_dpo
 
 if TYPE_CHECKING:
     from transformers import TrainerCallback
@@ -38,11 +38,11 @@ def export_model(args: Optional[Dict[str, Any]] = None, max_shard_size: Optional
     model_args, _, finetuning_args, _ = get_infer_args(args)
     model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args)
     model.config.use_cache = True
-    model.save_pretrained(model_args.export_dir, max_shard_size=max_shard_size)
+    model.save_pretrained(finetuning_args.export_dir, max_shard_size=max_shard_size)
     try:
         tokenizer.padding_side = "left" # restore padding side
         tokenizer.init_kwargs["padding_side"] = "left"
-        tokenizer.save_pretrained(model_args.export_dir)
+        tokenizer.save_pretrained(finetuning_args.export_dir)
     except:
         logger.warning("Cannot save tokenizer, please copy the files manually.")
 
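After this change the export directory is read from the fine-tuning arguments instead of the model arguments. A hypothetical call under the new signature might look as follows; the argument keys are assumptions based on the argument classes named in this diff, not verified documentation.

```python
from llmtuner.train import export_model

# Merge a LoRA checkpoint into the base model and write sharded weights.
# All key names and paths here are illustrative placeholders.
export_model(dict(
    model_name_or_path="path_to_base_model",
    checkpoint_dir="path_to_lora_checkpoint",
    finetuning_type="lora",
    template="default",
    export_dir="path_to_export"  # now read from finetuning_args
), max_shard_size="10GB")
```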
src/llmtuner/train/utils.py (new file, 80 lines)
@@ -0,0 +1,80 @@
+import torch
+from typing import TYPE_CHECKING, Literal, Union
+
+from llmtuner.extras.logging import get_logger
+from llmtuner.hparams import ModelArguments, FinetuningArguments
+from llmtuner.model import load_model_and_tokenizer, load_valuehead_params
+
+if TYPE_CHECKING:
+    from transformers.modeling_utils import PreTrainedModel
+    from trl import AutoModelForCausalLMWithValueHead
+
+
+logger = get_logger(__name__)
+
+
+def create_ref_model(
+    model_args: "ModelArguments",
+    finetuning_args: "FinetuningArguments",
+    stage: Literal["ppo", "dpo"]
+) -> Union["PreTrainedModel", "AutoModelForCausalLMWithValueHead"]:
+    r"""
+    Creates reference model for PPO/DPO training. Evaluation mode is not supported.
+
+    The valuehead parameter is randomly initialized since it is useless for PPO training.
+    """
+    if finetuning_args.ref_model is not None:
+        ref_model_args_dict = model_args.to_dict()
+        ref_model_args_dict.update(dict(
+            model_name_or_path=finetuning_args.ref_model,
+            checkpoint_dir=finetuning_args.ref_model_checkpoint,
+            quantization_bit=finetuning_args.ref_model_quantization_bit
+        ))
+        ref_model_args = ModelArguments(**ref_model_args_dict)
+        ref_finetuning_args = FinetuningArguments(finetuning_type="lora")
+        ref_model, _ = load_model_and_tokenizer(ref_model_args, ref_finetuning_args, is_trainable=False, stage=stage)
+        logger.info("Created reference model from {}".format(finetuning_args.ref_model))
+    else:
+        if finetuning_args.finetuning_type == "lora":
+            ref_model = None
+        else:
+            ref_model, _ = load_model_and_tokenizer(model_args, finetuning_args, is_trainable=False, stage=stage)
+            logger.info("Created reference model from the model itself.")
+
+    return ref_model
+
+
+def create_reward_model(
+    model: "AutoModelForCausalLMWithValueHead",
+    model_args: "ModelArguments",
+    finetuning_args: "FinetuningArguments"
+) -> "AutoModelForCausalLMWithValueHead":
+    r"""
+    Creates reward model for PPO training.
+    """
+    if finetuning_args.reward_model_type == "lora":
+        model.pretrained_model.load_adapter(finetuning_args.reward_model, "reward")
+        for name, param in model.named_parameters(): # https://github.com/huggingface/peft/issues/1090
+            if "default" in name:
+                param.data = param.data.to(torch.float32) # trainable params should in fp32
+        vhead_params = load_valuehead_params(finetuning_args.reward_model, model_args)
+        assert vhead_params is not None, "Reward model is not correctly loaded."
+        model.register_buffer("reward_head_weight", vhead_params["v_head.summary.weight"], persistent=False)
+        model.register_buffer("reward_head_bias", vhead_params["v_head.summary.bias"], persistent=False)
+        model.register_buffer("default_head_weight", torch.zeros_like(vhead_params["v_head.summary.weight"]), persistent=False)
+        model.register_buffer("default_head_bias", torch.zeros_like(vhead_params["v_head.summary.bias"]), persistent=False)
+        logger.info("Loaded adapter weights of reward model from {}".format(finetuning_args.reward_model))
+        return None
+    else:
+        reward_model_args_dict = model_args.to_dict()
+        reward_model_args_dict.update(dict(
+            model_name_or_path=finetuning_args.reward_model,
+            checkpoint_dir=finetuning_args.reward_model_checkpoint,
+            quantization_bit=finetuning_args.reward_model_quantization_bit
+        ))
+        reward_model_args = ModelArguments(**reward_model_args_dict)
+        reward_finetuning_args = FinetuningArguments(finetuning_type="lora")
+        reward_model, _ = load_model_and_tokenizer(reward_model_args, reward_finetuning_args, is_trainable=False, stage="ppo")
+        logger.info("Load full weights of reward model from {}".format(finetuning_args.reward_model))
+        logger.warning("Please ensure the ppo model and reward model share SAME tokenizer and vocabulary.")
+        return reward_model
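When the reward model is a LoRA adapter sharing the PPO backbone, create_reward_model registers the reward value-head weights as buffers next to zero-initialized "default" buffers, so the trainer can flip the value head between reward scoring and value estimation without keeping a second model in memory. The sketch below shows one way such a swap could be written; the attribute path model.v_head.summary follows trl's AutoModelForCausalLMWithValueHead, but the helper itself is an illustration of the idea, not code from this changeset.

```python
import torch
from contextlib import contextmanager


@contextmanager
def swapped_value_head(model, prefix: str):
    # prefix is "reward" or "default"; the buffers were registered by create_reward_model
    v_head = model.v_head.summary  # linear layer producing the scalar score
    backup = (v_head.weight.data.clone(), v_head.bias.data.clone())
    v_head.weight.data = model.get_buffer("{}_head_weight".format(prefix)).clone()
    v_head.bias.data = model.get_buffer("{}_head_bias".format(prefix)).clone()
    try:
        yield model
    finally:
        v_head.weight.data, v_head.bias.data = backup  # restore original head on exit
```

Reward scoring would then run inside `with swapped_value_head(model, "reward"):`, with the previous head restored afterwards.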
(deleted file)
@@ -1 +0,0 @@
-from llmtuner.tuner.tune import export_model, run_exp

(deleted file)
@@ -1,3 +0,0 @@
-from llmtuner.tuner.core.parser import get_train_args, get_infer_args
-from llmtuner.tuner.core.loader import load_model_and_tokenizer
-from llmtuner.tuner.core.utils import generate_model_card

(deleted file)
@@ -1 +0,0 @@
-from llmtuner.tuner.dpo.workflow import run_dpo

(deleted file)
@@ -1 +0,0 @@
-from llmtuner.tuner.ppo.workflow import run_ppo

(deleted file)
@@ -1 +0,0 @@
-from llmtuner.tuner.pt.workflow import run_pt

(deleted file)
@@ -1 +0,0 @@
-from llmtuner.tuner.rm.workflow import run_rm

(deleted file)
@@ -1 +0,0 @@
-from llmtuner.tuner.sft.workflow import run_sft
@@ -2,7 +2,7 @@ import gradio as gr
 from gradio.components import Component # cannot use TYPE_CHECKING here
 from typing import TYPE_CHECKING, Any, Dict, Generator, List, Optional, Tuple
 
-from llmtuner.chat.stream_chat import ChatModel
+from llmtuner.chat import ChatModel
 from llmtuner.extras.misc import torch_gc
 from llmtuner.hparams import GeneratingArguments
 from llmtuner.webui.common import get_save_dir
@@ -14,14 +14,24 @@ if TYPE_CHECKING:
 
 class WebChatModel(ChatModel):
 
-    def __init__(self, manager: "Manager", lazy_init: Optional[bool] = True) -> None:
+    def __init__(
+        self,
+        manager: "Manager",
+        demo_mode: Optional[bool] = False,
+        lazy_init: Optional[bool] = True
+    ) -> None:
         self.manager = manager
+        self.demo_mode = demo_mode
         self.model = None
         self.tokenizer = None
         self.generating_args = GeneratingArguments()
-        if not lazy_init:
+
+        if not lazy_init: # read arguments from command line
             super().__init__()
 
+        if demo_mode: # load openchat 3.5 by default
+            super().__init__(dict(model_name_or_path="openchat/openchat_3.5", template="openchat"))
+
     @property
     def loaded(self) -> bool:
         return self.model is not None
@@ -36,6 +46,8 @@ class WebChatModel(ChatModel):
             error = ALERTS["err_no_model"][lang]
         elif not get("top.model_path"):
             error = ALERTS["err_no_path"][lang]
+        elif self.demo_mode:
+            error = ALERTS["err_demo"][lang]
 
         if error:
             gr.Warning(error)
@@ -67,6 +79,11 @@ class WebChatModel(ChatModel):
 
     def unload_model(self, data: Dict[Component, Any]) -> Generator[str, None, None]:
         lang = data[self.manager.get_elem_by_name("top.lang")]
+
+        if self.demo_mode:
+            yield ALERTS["err_demo"][lang]
+            return
+
         yield ALERTS["info_unloading"][lang]
         self.model = None
         self.tokenizer = None
@@ -61,13 +61,17 @@ def get_model_path(model_name: str) -> str:
     return user_config["path_dict"].get(model_name, None) or SUPPORTED_MODELS.get(model_name, "")
 
 
+def get_prefix(model_name: str) -> str:
+    return model_name.split("-")[0]
+
+
 def get_module(model_name: str) -> str:
-    return DEFAULT_MODULE.get(model_name.split("-")[0], "q_proj,v_proj")
+    return DEFAULT_MODULE.get(get_prefix(model_name), "q_proj,v_proj")
 
 
 def get_template(model_name: str) -> str:
-    if model_name.endswith("Chat") and model_name.split("-")[0] in DEFAULT_TEMPLATE:
-        return DEFAULT_TEMPLATE[model_name.split("-")[0]]
+    if model_name and model_name.endswith("Chat") and get_prefix(model_name) in DEFAULT_TEMPLATE:
+        return DEFAULT_TEMPLATE[get_prefix(model_name)]
     return "default"
 
 
@@ -1,7 +1,7 @@
 import gradio as gr
 from typing import TYPE_CHECKING, Dict, Generator, List
 
-from llmtuner.tuner import export_model
+from llmtuner.train import export_model
 from llmtuner.webui.common import get_save_dir
 from llmtuner.webui.locales import ALERTS
 
@@ -1,8 +1,8 @@
 import gradio as gr
 from typing import TYPE_CHECKING, Dict
 
+from llmtuner.data.template import templates
 from llmtuner.extras.constants import METHODS, SUPPORTED_MODELS
-from llmtuner.extras.template import templates
 from llmtuner.webui.common import get_model_path, get_template, list_checkpoint, save_config
 from llmtuner.webui.utils import can_quantize
 
@@ -1,4 +1,11 @@
 CSS = r"""
+.duplicate-button {
+  margin: auto !important;
+  color: white !important;
+  background: black !important;
+  border-radius: 100vh !important;
+}
+
 .modal-box {
   position: fixed !important;
   top: 50%;
@@ -12,11 +12,11 @@ from llmtuner.webui.utils import get_time
 
 class Engine:
 
-    def __init__(self, pure_chat: Optional[bool] = False) -> None:
+    def __init__(self, demo_mode: Optional[bool] = False, pure_chat: Optional[bool] = False) -> None:
         self.pure_chat = pure_chat
-        self.manager: "Manager" = Manager()
-        self.runner: "Runner" = Runner(self.manager)
-        self.chatter: "WebChatModel" = WebChatModel(manager=self.manager, lazy_init=(not pure_chat))
+        self.manager = Manager()
+        self.runner = Runner(self.manager, demo_mode=demo_mode)
+        self.chatter = WebChatModel(manager=self.manager, demo_mode=demo_mode, lazy_init=(not pure_chat))
 
     def _form_dict(self, resume_dict: Dict[str, Dict[str, Any]]):
         return {self.manager.get_elem_by_name(k): gr.update(**v) for k, v in resume_dict.items()}
@@ -1,4 +1,5 @@
 import gradio as gr
+from typing import Optional
 from transformers.utils.versions import require_version
 
 from llmtuner.webui.components import (
@@ -17,24 +18,35 @@ from llmtuner.webui.engine import Engine
 require_version("gradio>=3.38.0,<4.0.0", "To fix: pip install \"gradio>=3.38.0,<4.0.0\"")
 
 
-def create_ui() -> gr.Blocks:
-    engine = Engine(pure_chat=False)
+def create_ui(demo_mode: Optional[bool] = False) -> gr.Blocks:
+    engine = Engine(demo_mode=demo_mode, pure_chat=False)
 
     with gr.Blocks(title="LLaMA Board", css=CSS) as demo:
+        if demo_mode:
+            gr.HTML(
+                "<h1><center>LLaMA Board: A One-stop Web UI for Getting Started with LLaMA Factory</center></h1>"
+            )
+            gr.HTML(
+                "<h3><center>Visit <a href=\"https://github.com/hiyouga/LLaMA-Factory\" target=\"_blank\">"
+                "LLaMA Factory</a> for details.</center></h3>"
+            )
+            gr.DuplicateButton(value="Duplicate Space for private use", elem_classes="duplicate-button")
+
         engine.manager.all_elems["top"] = create_top()
         lang: "gr.Dropdown" = engine.manager.get_elem_by_name("top.lang")
 
         with gr.Tab("Train"):
             engine.manager.all_elems["train"] = create_train_tab(engine)
 
-        with gr.Tab("Evaluate"):
+        with gr.Tab("Evaluate & Predict"):
             engine.manager.all_elems["eval"] = create_eval_tab(engine)
 
         with gr.Tab("Chat"):
             engine.manager.all_elems["infer"] = create_infer_tab(engine)
 
-        with gr.Tab("Export"):
-            engine.manager.all_elems["export"] = create_export_tab(engine)
+        if not demo_mode:
+            with gr.Tab("Export"):
+                engine.manager.all_elems["export"] = create_export_tab(engine)
 
         demo.load(engine.resume, outputs=engine.manager.list_elems())
         lang.change(engine.change_lang, [lang], engine.manager.list_elems(), queue=False)
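With the new flag, the same entry point can serve both the full board and a hosted read-only demo. A hypothetical launcher is shown below; the module path and launch call are assumptions based on the diff, not taken from it.

```python
from llmtuner.webui.interface import create_ui  # assumed module path

# Serve the demo variant: training, unloading, and export are disabled.
create_ui(demo_mode=True).queue().launch()
```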
@@ -659,6 +659,10 @@ ALERTS = {
        "en": "Failed.",
        "zh": "训练出错。"
    },
+   "err_demo": {
+       "en": "Training is unavailable in demo mode, duplicate the space to a private one first.",
+       "zh": "展示模式不支持训练,请先复制到私人空间。"
+   },
    "info_aborting": {
        "en": "Aborted, wait for terminating...",
        "zh": "训练中断,正在等待线程结束……"
@@ -4,7 +4,7 @@ import logging
 import gradio as gr
 from threading import Thread
 from gradio.components import Component # cannot use TYPE_CHECKING here
-from typing import TYPE_CHECKING, Any, Dict, Generator, List, Tuple
+from typing import TYPE_CHECKING, Any, Dict, Generator, Optional, Tuple
 
 import transformers
 from transformers.trainer import TRAINING_ARGS_NAME
@@ -13,7 +13,7 @@ from llmtuner.extras.callbacks import LogCallback
 from llmtuner.extras.constants import TRAINING_STAGES
 from llmtuner.extras.logging import LoggerHandler
 from llmtuner.extras.misc import torch_gc
-from llmtuner.tuner import run_exp
+from llmtuner.train import run_exp
 from llmtuner.webui.common import get_module, get_save_dir, load_config
 from llmtuner.webui.locales import ALERTS
 from llmtuner.webui.utils import gen_cmd, get_eval_results, update_process_bar
@@ -24,13 +24,13 @@ if TYPE_CHECKING:
 
 class Runner:
 
-    def __init__(self, manager: "Manager") -> None:
+    def __init__(self, manager: "Manager", demo_mode: Optional[bool] = False) -> None:
         self.manager = manager
+        self.demo_mode = demo_mode
         """ Resume """
         self.thread: "Thread" = None
         self.do_train = True
         self.running_data: Dict["Component", Any] = None
-        self.monitor_inputs: Dict[str, str] = None
         """ State """
         self.aborted = False
         self.running = False
@@ -46,9 +46,8 @@
 
     def set_abort(self) -> None:
         self.aborted = True
-        self.running = False
 
-    def _initialize(self, data: Dict[Component, Any], do_train: bool) -> str:
+    def _initialize(self, data: Dict[Component, Any], do_train: bool, from_preview: bool) -> str:
         get = lambda name: data[self.manager.get_elem_by_name(name)]
         lang, model_name, model_path = get("top.lang"), get("top.model_name"), get("top.model_path")
         dataset = get("train.dataset") if do_train else get("eval.dataset")
@@ -65,6 +64,9 @@
         if len(dataset) == 0:
             return ALERTS["err_no_dataset"][lang]
 
+        if self.demo_mode and (not from_preview):
+            return ALERTS["err_demo"][lang]
+
         self.aborted = False
         self.logger_handler.reset()
         self.trainer_callback = LogCallback(self)
@@ -72,6 +74,7 @@
 
     def _finalize(self, lang: str, finish_info: str) -> str:
         self.thread = None
+        self.running_data = None
         self.running = False
         torch_gc()
         if self.aborted:
@@ -84,9 +87,9 @@
         user_config = load_config()
 
         if get("top.checkpoints"):
-            checkpoint_dir = ",".join([
-                get_save_dir(get("top.model_name"), get("top.finetuning_type"), ckpt) for ckpt in get("top.checkpoints")
-            ])
+            checkpoint_dir = ",".join([get_save_dir(
+                get("top.model_name"), get("top.finetuning_type"), ckpt
+            ) for ckpt in get("top.checkpoints")])
         else:
             checkpoint_dir = None
 
@@ -136,7 +139,10 @@
             args["upcast_layernorm"] = True
 
         if args["stage"] == "ppo":
-            args["reward_model"] = get("train.reward_model")
+            args["reward_model"] = get_save_dir(
+                get("top.model_name"), get("top.finetuning_type"), get("train.reward_model")
+            )
+            args["reward_model_type"] = "lora" if get("top.finetuning_type") == "lora" else "full"
 
         if args["stage"] == "dpo":
             args["dpo_beta"] = get("train.dpo_beta")
@@ -154,9 +160,9 @@
         user_config = load_config()
 
         if get("top.checkpoints"):
-            checkpoint_dir = ",".join([
-                get_save_dir(get("top.model_name"), get("top.finetuning_type"), ckpt) for ckpt in get("top.checkpoints")
-            ])
+            checkpoint_dir = ",".join([get_save_dir(
+                get("top.model_name"), get("top.finetuning_type"), ckpt
+            ) for ckpt in get("top.checkpoints")])
             output_dir = get_save_dir(
                 get("top.model_name"), get("top.finetuning_type"), "eval_" + "_".join(get("top.checkpoints"))
             )
@@ -196,7 +202,7 @@
         return args
 
     def _preview(self, data: Dict[Component, Any], do_train: bool) -> Generator[Tuple[str, Dict[str, Any]], None, None]:
-        error = self._initialize(data, do_train)
+        error = self._initialize(data, do_train, from_preview=True)
        if error:
             gr.Warning(error)
             yield error, gr.update(visible=False)
@@ -205,16 +211,14 @@
         yield gen_cmd(args), gr.update(visible=False)
 
     def _launch(self, data: Dict[Component, Any], do_train: bool) -> Generator[Tuple[str, Dict[str, Any]], None, None]:
-        error = self._initialize(data, do_train)
+        error = self._initialize(data, do_train, from_preview=False)
         if error:
             gr.Warning(error)
             yield error, gr.update(visible=False)
         else:
             args = self._parse_train_args(data) if do_train else self._parse_eval_args(data)
             run_kwargs = dict(args=args, callbacks=[self.trainer_callback])
-            self.running = True
             self.do_train, self.running_data = do_train, data
-            self.monitor_inputs = dict(lang=data[self.manager.get_elem_by_name("top.lang")], output_dir=args["output_dir"])
             self.thread = Thread(target=run_exp, kwargs=run_kwargs)
             self.thread.start()
             yield from self.monitor()
@@ -232,7 +236,12 @@
         yield from self._launch(data, do_train=False)
 
     def monitor(self) -> Generator[Tuple[str, Dict[str, Any]], None, None]:
-        lang, output_dir = self.monitor_inputs["lang"], self.monitor_inputs["output_dir"]
+        get = lambda name: self.running_data[self.manager.get_elem_by_name(name)]
+        self.running = True
+        lang = get("top.lang")
+        output_dir = get_save_dir(get("top.model_name"), get("top.finetuning_type"), get(
+            "{}.output_dir".format("train" if self.do_train else "eval")
+        ))
         while self.thread.is_alive():
             time.sleep(2)
             if self.aborted:
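The refactor above retires the monitor_inputs dict: monitor() now derives the language and output directory directly from running_data, and setting self.running = True moves into monitor() so the flag tracks the polling loop itself. Stripped of the Gradio plumbing, the launch-and-poll pattern looks roughly like the sketch below; run_job and read_progress are placeholders, not repository APIs.

```python
import time
from threading import Thread
from typing import Callable, Generator


def launch(run_job: Callable[[], None], read_progress: Callable[[], str]) -> Generator[str, None, None]:
    # Run the job in a background thread and stream progress to the caller.
    thread = Thread(target=run_job)
    thread.start()
    while thread.is_alive():
        time.sleep(2)          # same poll interval as the Runner above
        yield read_progress()  # e.g. the tail of trainer_log.jsonl
    yield "finished"
```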
@@ -1,17 +1,20 @@
 import os
 import json
 import gradio as gr
-import matplotlib.figure
-import matplotlib.pyplot as plt
 from typing import TYPE_CHECKING, Any, Dict
 from datetime import datetime
 
+from llmtuner.extras.packages import is_matplotlib_available
 from llmtuner.extras.ploting import smooth
 from llmtuner.webui.common import get_save_dir
 
 if TYPE_CHECKING:
     from llmtuner.extras.callbacks import LogCallback
 
+if is_matplotlib_available():
+    import matplotlib.figure
+    import matplotlib.pyplot as plt
+
 
 def update_process_bar(callback: "LogCallback") -> Dict[str, Any]:
     if not callback.max_steps:
@@ -56,7 +59,7 @@ def get_eval_results(path: os.PathLike) -> str:
     return "```json\n{}\n```\n".format(result)
 
 
-def gen_plot(base_model: str, finetuning_type: str, output_dir: str) -> matplotlib.figure.Figure:
+def gen_plot(base_model: str, finetuning_type: str, output_dir: str) -> "matplotlib.figure.Figure":
     if not base_model:
         return
     log_file = get_save_dir(base_model, finetuning_type, output_dir, "trainer_log.jsonl")
@@ -12,7 +12,7 @@ from deepspeed.profiling.flops_profiler import get_model_profile # type: ignore
 from llmtuner import ChatModel
 
 
-def calculate(
+def calculate_flops(
     model_name_or_path: str,
     batch_size: Optional[int] = 1,
     seq_length: Optional[int] = 256,
@@ -41,4 +41,4 @@ def calculate(
 
 
 if __name__ == "__main__":
-    fire.Fire(calculate)
+    fire.Fire(calculate_flops)
tests/cal_lr.py (new file, 63 lines)
@@ -0,0 +1,63 @@
+# coding=utf-8
+# Calculates the optimal learning rate for 7B/13B models using LLaMA's hyper-parameters.
+# Usage: python cal_lr.py --model_name_or_path path_to_model --dataset alpaca_en --cutoff_len 1024 --batch_size 16
+# Inspired by: https://github.com/imoneoi/openchat/blob/master/ochat/training_deepspeed/train.py
+
+import fire
+import math
+import torch
+from tqdm import tqdm
+from typing import Optional
+from torch.utils.data import DataLoader
+from transformers import DataCollatorForSeq2Seq
+
+from llmtuner.data import get_dataset, preprocess_dataset
+from llmtuner.extras.constants import IGNORE_INDEX
+from llmtuner.model import get_train_args, load_model_and_tokenizer
+
+
+BASE_LR = 3e-4  # 1.5e-4 for 30B-70B models
+BASE_BS = 4_000_000  # from llama paper
+
+
+def calculate_lr(
+    model_name_or_path: str,
+    dataset: str,
+    cutoff_len: int,  # i.e. maximum input length during training
+    batch_size: int,  # total batch size, namely (batch size * gradient accumulation * world size)
+    is_mistral: bool,  # mistral model uses a smaller learning rate
+    dataset_dir: Optional[str] = "data"
+):
+    model_args, data_args, training_args, finetuning_args, _ = get_train_args(dict(
+        stage="sft",
+        model_name_or_path=model_name_or_path,
+        dataset=dataset,
+        dataset_dir=dataset_dir,
+        template="default",
+        cutoff_len=cutoff_len,
+        output_dir="dummy_dir"
+    ))
+    trainset = get_dataset(model_args, data_args)
+    _, tokenizer = load_model_and_tokenizer(model_args, finetuning_args, is_trainable=False, stage="sft")
+    trainset = preprocess_dataset(trainset, tokenizer, data_args, training_args, stage="sft")
+    data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, label_pad_token_id=IGNORE_INDEX)
+    dataloader = DataLoader(
+        dataset=trainset, batch_size=batch_size, shuffle=True, collate_fn=data_collator, pin_memory=True
+    )
+    valid_tokens, total_tokens = 0, 0
+    for batch in tqdm(dataloader):
+        valid_tokens += torch.sum(batch["labels"] != IGNORE_INDEX).item()
+        total_tokens += torch.numel(batch["labels"])
+
+    batch_max_len = cutoff_len * batch_size  # max tokens in a batch
+    valid_ratio = valid_tokens / total_tokens
+    batch_valid_len = batch_max_len * valid_ratio
+    lr = BASE_LR * math.sqrt(batch_valid_len / BASE_BS)  # lr ~ sqrt(batch_size)
+    lr = lr / 6.0 if is_mistral else lr
+    print("Optimal learning rate is {:.2e} for valid ratio% {:.2f} and effective batch size {:.2f}".format(
+        lr, valid_ratio * 100, batch_valid_len
+    ))
+
+
+if __name__ == "__main__":
+    fire.Fire(calculate_lr)
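To make the square-root scaling rule in cal_lr.py concrete, here is a worked example with assumed numbers; the 60% valid-token ratio is hypothetical, whereas the script measures the real ratio from the dataset.

```python
import math

BASE_LR, BASE_BS = 3e-4, 4_000_000
cutoff_len, batch_size, valid_ratio = 1024, 16, 0.60  # assumed 60% non-padding tokens

batch_valid_len = cutoff_len * batch_size * valid_ratio  # ~9830 effective tokens per step
lr = BASE_LR * math.sqrt(batch_valid_len / BASE_BS)      # ~1.49e-5
print("{:.2e}".format(lr))
```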
@@ -4,7 +4,6 @@
 # --max_length 1024 --max_samples 1024
 # dataset format: instruction (string), input (string), output (string), history (List[string])
-
 
 import fire
 from datasets import load_dataset
 from transformers import AutoTokenizer