37 Commits

Author SHA1 Message Date
hiyouga
95d0f77fc2 release v0.3.0
Former-commit-id: de7f5b622340ab09ebbe57ad2703e63d06dfdeea
2023-11-16 16:00:11 +08:00
hiyouga
9b2654277b update readme
Former-commit-id: 4018aabc5d1623033d27a8aced25804de79b7e7b
2023-11-16 15:58:37 +08:00
hoshi-hiyouga
f1b3bdac3f Merge #1525 from hiyouga/dev, fix #224 #336 #931 #936 #1011
Refactor llmtuner, support full-parameter RLHF

Former-commit-id: 3b92826803dc69471827b4f8204c2c3dc5310619
2023-11-16 15:47:13 +08:00
hiyouga
595fdbd95d fix css
Former-commit-id: 7afec127f60257462828298b25a5f6fd9c6f42c5
2023-11-16 15:45:38 +08:00
hiyouga
dab9385297 fix bug in web ui
Former-commit-id: a598f145ec903dd2b2c984d951b6c450b142ece5
2023-11-16 15:21:24 +08:00
hiyouga
df83def566 update ppo and demo in webui
Former-commit-id: de7571704c82121db13e3fc907379d2453100191
2023-11-16 14:55:26 +08:00
hiyouga
f9d4e37b3c fix bug in freeze tuning
Former-commit-id: f6b436a08421ca17d64abc51497f4aa43729a43b
2023-11-16 14:25:11 +08:00
hiyouga
e59a3d71e0 tiny fix
Former-commit-id: d65519d8a44b73bbb713741c23465f13c35c83f5
2023-11-16 03:27:19 +08:00
hiyouga
de3a84ac59 fix rlhf callback
Former-commit-id: f5485452d660caef56474cb7dc37abbe4f34599e
2023-11-16 03:26:19 +08:00
hiyouga
e017266b98 fix bug in PPO training
Former-commit-id: 2e99f0e53ce6de0acbcab85dd50aef874e8c6336
2023-11-16 02:32:54 +08:00
hiyouga
f81a8a5e5c fix import bug
Former-commit-id: 2356029cdd120d5f7bf630b80681ce8c53bff90d
2023-11-16 02:27:03 +08:00
hiyouga
7a3a0144a5 support full-parameter PPO
Former-commit-id: 4af967d69475e1c9fdf1a7983cd6b83bd431abff
2023-11-16 02:08:04 +08:00
hiyouga
8263b2d32d add demo mode for web UI
Former-commit-id: 5ad34f08b4e1505d7933b973497347f126b2e818
2023-11-15 23:51:26 +08:00
hoshi-hiyouga
833cd490b8 Create CODE_OF_CONDUCT.md
Former-commit-id: 6bee64cdf9c75488033e600fb5b48738daa1ed3b
2023-11-15 20:42:15 +08:00
hiyouga
2162c37e41 update readme and constants
Former-commit-id: 7d83e3dd9101a4fdd0b589d0c1f7b609c0feecd1
2023-11-15 18:04:37 +08:00
hiyouga
b2ac8376e1 support multiple modules in freeze training #1514
Former-commit-id: 60abac70dfd778df2ae8b3a2e960ed8b607d7ab6
2023-11-15 17:08:18 +08:00
hiyouga
8079584143 fix imports
Former-commit-id: 6156f1abef631c675d150dd1cb0325cfc3820c91
2023-11-15 16:47:45 +08:00
hiyouga
09a4474e7f disentangle model from tuner and rename modules
Former-commit-id: 02cbf91e7e424f8379c1fed01b82a5f7a83b6947
2023-11-15 16:29:09 +08:00
hiyouga
81530133ff fix #1507
Former-commit-id: 1ba9c53bd9743fa95fca1516c0ed9da352dbe9a1
2023-11-15 16:22:32 +08:00
hiyouga
cc4b384ac3 Update cal_lr.py
Former-commit-id: b92ef6c80ae108982046ec1419efb67c8b10b250
2023-11-14 21:14:42 +08:00
hiyouga
3852daf447 Update cal_lr.py
Former-commit-id: b6c3f9b24324403db41c5680a00aabc6d53bbeb9
2023-11-14 21:13:01 +08:00
hiyouga
5c97111f9d Update cal_lr.py
Former-commit-id: 1258eec806f6f4580a6eb7d9eb44f431f4c0da4f
2023-11-14 21:09:30 +08:00
hiyouga
75dd1f0f7e add cal_lr.py
Former-commit-id: cea2ba17efc47917e63437a376f220864f7f90dd
2023-11-14 20:58:37 +08:00
hiyouga
c9a4551012 fix #1494
Former-commit-id: 07c8d734529f03e47ef638a1bda222e8824d3d38
2023-11-14 18:07:20 +08:00
hiyouga
87197ba91d fix #1489
Former-commit-id: ebdeaca9cdfd6138c690a0fcb9f676deaddff177
2023-11-14 15:27:05 +08:00
hiyouga
7461bf84e5 support eval remote dataset
Former-commit-id: 71dd2698bf8c0b9ef7af995fb1e49e39fa66074e
2023-11-14 02:42:30 +08:00
hiyouga
fbc0357b2e fix dc link
Former-commit-id: 04c3a1f1c98d8f191102e359def0c8dcdc9621e3
2023-11-13 23:22:56 +08:00
hiyouga
ec334f5891 release v0.2.2, fix #1478 #1466
Former-commit-id: c9534c411716e1dceb54c5eb35fe845c93ee2973
2023-11-13 23:09:05 +08:00
hiyouga
885efe772e fix #424
Former-commit-id: ca24d445f825e120e659f5cd080a954c2243b8f2
2023-11-13 22:42:23 +08:00
hiyouga
64fc9ba678 refactor evaluation, upgrade trl to 074
Former-commit-id: ed09ebe2c1926ffdb0520b3866f7fd03a9aed046
2023-11-13 22:20:35 +08:00
hiyouga
989eccd286 fix flashattn warning
Former-commit-id: 6eb095d39bd82fdbdb729a0ea57fc7246e3a60d6
2023-11-10 18:34:54 +08:00
hiyouga
f0766a2ab0 add todo
Former-commit-id: 0bd884feb11736d0ab24ca19885151cb47d9dcd3
2023-11-10 14:38:18 +08:00
hiyouga
178b85ff9a refactor constants
Former-commit-id: a4d4c3fd35276f20e3b354e9d13ea971029c8775
2023-11-10 14:16:10 +08:00
hiyouga
68dd1ef121 tiny fix
Former-commit-id: 97ba2027bb1ddc01a3c824c40d5a180828810c2c
2023-11-09 17:20:49 +08:00
hoshi-hiyouga
b222cffe98 Merge pull request #1454 from yyq/main
Update finetuning_args.py

Former-commit-id: e67d8b93705383a8590f99e26e9fe8f663712aef
2023-11-09 17:12:18 +08:00
Yanqing
b4f1ab93d1 Update finetuning_args.py
更新 chatglm/falcon/bloom 的 lora_target 的名称

Former-commit-id: 06606739af035a80ae9ddba9d12c965ed289305d
2023-11-09 17:04:40 +08:00
hiyouga
f2e139f5cd fix #1452
Former-commit-id: 4d16214467715df458e24d03bb7d303d62b8bdcd
2023-11-09 16:41:32 +08:00
80 changed files with 1563 additions and 831 deletions

128
CODE_OF_CONDUCT.md Normal file
View File

@@ -0,0 +1,128 @@
# Contributor Covenant Code of Conduct
## Our Pledge
We as members, contributors, and leaders pledge to make participation in our
community a harassment-free experience for everyone, regardless of age, body
size, visible or invisible disability, ethnicity, sex characteristics, gender
identity and expression, level of experience, education, socio-economic status,
nationality, personal appearance, race, religion, or sexual identity
and orientation.
We pledge to act and interact in ways that contribute to an open, welcoming,
diverse, inclusive, and healthy community.
## Our Standards
Examples of behavior that contributes to a positive environment for our
community include:
* Demonstrating empathy and kindness toward other people
* Being respectful of differing opinions, viewpoints, and experiences
* Giving and gracefully accepting constructive feedback
* Accepting responsibility and apologizing to those affected by our mistakes,
and learning from the experience
* Focusing on what is best not just for us as individuals, but for the
overall community
Examples of unacceptable behavior include:
* The use of sexualized language or imagery, and sexual attention or
advances of any kind
* Trolling, insulting or derogatory comments, and personal or political attacks
* Public or private harassment
* Publishing others' private information, such as a physical or email
address, without their explicit permission
* Other conduct which could reasonably be considered inappropriate in a
professional setting
## Enforcement Responsibilities
Community leaders are responsible for clarifying and enforcing our standards of
acceptable behavior and will take appropriate and fair corrective action in
response to any behavior that they deem inappropriate, threatening, offensive,
or harmful.
Community leaders have the right and responsibility to remove, edit, or reject
comments, commits, code, wiki edits, issues, and other contributions that are
not aligned to this Code of Conduct, and will communicate reasons for moderation
decisions when appropriate.
## Scope
This Code of Conduct applies within all community spaces, and also applies when
an individual is officially representing the community in public spaces.
Examples of representing our community include using an official e-mail address,
posting via an official social media account, or acting as an appointed
representative at an online or offline event.
## Enforcement
Instances of abusive, harassing, or otherwise unacceptable behavior may be
reported to the community leaders responsible for enforcement at
`hoshihiyouga AT gmail DOT com`.
All complaints will be reviewed and investigated promptly and fairly.
All community leaders are obligated to respect the privacy and security of the
reporter of any incident.
## Enforcement Guidelines
Community leaders will follow these Community Impact Guidelines in determining
the consequences for any action they deem in violation of this Code of Conduct:
### 1. Correction
**Community Impact**: Use of inappropriate language or other behavior deemed
unprofessional or unwelcome in the community.
**Consequence**: A private, written warning from community leaders, providing
clarity around the nature of the violation and an explanation of why the
behavior was inappropriate. A public apology may be requested.
### 2. Warning
**Community Impact**: A violation through a single incident or series
of actions.
**Consequence**: A warning with consequences for continued behavior. No
interaction with the people involved, including unsolicited interaction with
those enforcing the Code of Conduct, for a specified period of time. This
includes avoiding interactions in community spaces as well as external channels
like social media. Violating these terms may lead to a temporary or
permanent ban.
### 3. Temporary Ban
**Community Impact**: A serious violation of community standards, including
sustained inappropriate behavior.
**Consequence**: A temporary ban from any sort of interaction or public
communication with the community for a specified period of time. No public or
private interaction with the people involved, including unsolicited interaction
with those enforcing the Code of Conduct, is allowed during this period.
Violating these terms may lead to a permanent ban.
### 4. Permanent Ban
**Community Impact**: Demonstrating a pattern of violation of community
standards, including sustained inappropriate behavior, harassment of an
individual, or aggression toward or disparagement of classes of individuals.
**Consequence**: A permanent ban from any sort of public interaction within
the community.
## Attribution
This Code of Conduct is adapted from the [Contributor Covenant][homepage],
version 2.0, available at
https://www.contributor-covenant.org/version/2/0/code_of_conduct.html.
Community Impact Guidelines were inspired by [Mozilla's code of conduct
enforcement ladder](https://github.com/mozilla/diversity).
[homepage]: https://www.contributor-covenant.org
For answers to common questions about this code of conduct, see the FAQ at
https://www.contributor-covenant.org/faq. Translations are available at
https://www.contributor-covenant.org/translations.

View File

@@ -6,7 +6,8 @@
[![PyPI](https://img.shields.io/pypi/v/llmtuner)](https://pypi.org/project/llmtuner/)
[![Downloads](https://static.pepy.tech/badge/llmtuner)](https://pypi.org/project/llmtuner/)
[![GitHub pull request](https://img.shields.io/badge/PRs-welcome-blue)](https://github.com/hiyouga/LLaMA-Factory/pulls)
[![Discord](https://dcbadge.vercel.app/api/server/e73gccsSd?compact=true&style=flat)](https://discord.gg/e73gccsSd)
[![Discord](https://dcbadge.vercel.app/api/server/c2EPEt5NU?compact=true&style=flat)](https://discord.gg/c2EPEt5NU)
[![Spaces](https://img.shields.io/badge/🤗-Open%20In%20Spaces-blue)](https://huggingface.co/spaces/hiyouga/LLaMA-Board)
👋 Join our [WeChat](assets/wechat.jpg).
@@ -14,7 +15,9 @@
## LLaMA Board: A One-stop Web UI for Getting Started with LLaMA Factory
Launch **LLaMA Board** via `CUDA_VISIBLE_DEVICES=0 python src/train_web.py`. (multiple GPUs are not supported yet)
Preview LLaMA Board at **[🤗 Spaces](https://huggingface.co/spaces/hiyouga/LLaMA-Board)**.
Launch LLaMA Board via `CUDA_VISIBLE_DEVICES=0 python src/train_web.py`. (multiple GPUs are not supported yet)
Here is an example of altering the self-cognition of an instruction-tuned language model within 10 minutes on a single GPU.
@@ -57,7 +60,7 @@ https://github.com/hiyouga/LLaMA-Factory/assets/16256802/6ba60acc-e2e2-4bec-b846
| [BLOOM](https://huggingface.co/bigscience/bloom) | 560M/1.1B/1.7B/3B/7.1B/176B | query_key_value | - |
| [BLOOMZ](https://huggingface.co/bigscience/bloomz) | 560M/1.1B/1.7B/3B/7.1B/176B | query_key_value | - |
| [ChatGLM3](https://github.com/THUDM/ChatGLM3) | 6B | query_key_value | chatglm3 |
| [Falcon](https://huggingface.co/tiiuae/falcon-7b) | 7B/40B/180B | query_key_value | - |
| [Falcon](https://huggingface.co/tiiuae/falcon-7b) | 7B/40B/180B | query_key_value | falcon |
| [InternLM](https://github.com/InternLM/InternLM) | 7B/20B | q_proj,v_proj | intern |
| [LLaMA](https://github.com/facebookresearch/llama) | 7B/13B/33B/65B | q_proj,v_proj | - |
| [LLaMA-2](https://huggingface.co/meta-llama) | 7B/13B/70B | q_proj,v_proj | llama2 |
@@ -71,7 +74,7 @@ https://github.com/hiyouga/LLaMA-Factory/assets/16256802/6ba60acc-e2e2-4bec-b846
>
> For the "base" models, the `--template` argument can be chosen from `default`, `alpaca`, `vicuna` etc. But make sure to use the **corresponding template** for the "chat" models.
Please refer to [template.py](src/llmtuner/extras/template.py) for a full list of models we supported.
Please refer to [constants.py](src/llmtuner/extras/constants.py) for a full list of models we supported.
## Supported Training Approaches
@@ -79,9 +82,9 @@ Please refer to [template.py](src/llmtuner/extras/template.py) for a full list o
| ---------------------- | ------------------ | ------------------ | ------------------ | ------------------ |
| Pre-Training | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
| Supervised Fine-Tuning | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
| Reward Modeling | | | :white_check_mark: | :white_check_mark: |
| PPO Training | | | :white_check_mark: | :white_check_mark: |
| DPO Training | :white_check_mark: | | :white_check_mark: | :white_check_mark: |
| Reward Modeling | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
| PPO Training | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
| DPO Training | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
> [!NOTE]
> Use `--quantization_bit 4/8` argument to enable QLoRA.
@@ -122,6 +125,7 @@ Please refer to [template.py](src/llmtuner/extras/template.py) for a full list o
- [OpenPlatypus (en)](https://huggingface.co/datasets/garage-bAInd/Open-Platypus)
- [CodeAlpaca 20k (en)](https://huggingface.co/datasets/sahil2801/CodeAlpaca-20k)
- [Alpaca CoT (multilingual)](https://huggingface.co/datasets/QingyiSi/Alpaca-CoT)
- [OpenOrca (en)](https://huggingface.co/datasets/Open-Orca/OpenOrca)
- [MathInstruct (en)](https://huggingface.co/datasets/TIGER-Lab/MathInstruct)
- [Firefly 1.1M (zh)](https://huggingface.co/datasets/YeungNLP/firefly-train-1.1M)
- [Web QA (zh)](https://huggingface.co/datasets/suolyer/webqa)
@@ -158,7 +162,7 @@ huggingface-cli login
- Python 3.8+ and PyTorch 1.13.1+
- 🤗Transformers, Datasets, Accelerate, PEFT and TRL
- sentencepiece, protobuf and tiktoken
- fire, jieba, rouge-chinese and nltk (used at evaluation and predict)
- jieba, rouge-chinese and nltk (used at evaluation and predict)
- gradio and matplotlib (used in web UI)
- uvicorn, fastapi and sse-starlette (used in API)

View File

@@ -6,7 +6,8 @@
[![PyPI](https://img.shields.io/pypi/v/llmtuner)](https://pypi.org/project/llmtuner/)
[![Downloads](https://static.pepy.tech/badge/llmtuner)](https://pypi.org/project/llmtuner/)
[![GitHub pull request](https://img.shields.io/badge/PRs-welcome-blue)](https://github.com/hiyouga/LLaMA-Factory/pulls)
[![Discord](https://dcbadge.vercel.app/api/server/e73gccsSd?compact=true&style=flat)](https://discord.gg/e73gccsSd)
[![Discord](https://dcbadge.vercel.app/api/server/c2EPEt5NU?compact=true&style=flat)](https://discord.gg/c2EPEt5NU)
[![Spaces](https://img.shields.io/badge/🤗-Open%20In%20Spaces-blue)](https://huggingface.co/spaces/hiyouga/LLaMA-Board)
👋 加入我们的[微信群](assets/wechat.jpg)。
@@ -14,7 +15,9 @@
## LLaMA Board: 通过一站式网页界面快速上手 LLaMA Factory
使用 `CUDA_VISIBLE_DEVICES=0 python src/train_web.py` 启动 **LLaMA Board**。(该界面目前仅支持单卡训练)
通过 **[🤗 Spaces](https://huggingface.co/spaces/hiyouga/LLaMA-Board)** 预览 LLaMA Board。
使用 `CUDA_VISIBLE_DEVICES=0 python src/train_web.py` 启动 LLaMA Board。该模式目前仅支持单卡训练
下面是使用单张 GPU 在 10 分钟内更改对话式大型语言模型自我认知的示例。
@@ -57,7 +60,7 @@ https://github.com/hiyouga/LLaMA-Factory/assets/16256802/6ba60acc-e2e2-4bec-b846
| [BLOOM](https://huggingface.co/bigscience/bloom) | 560M/1.1B/1.7B/3B/7.1B/176B | query_key_value | - |
| [BLOOMZ](https://huggingface.co/bigscience/bloomz) | 560M/1.1B/1.7B/3B/7.1B/176B | query_key_value | - |
| [ChatGLM3](https://github.com/THUDM/ChatGLM3) | 6B | query_key_value | chatglm3 |
| [Falcon](https://huggingface.co/tiiuae/falcon-7b) | 7B/40B/180B | query_key_value | - |
| [Falcon](https://huggingface.co/tiiuae/falcon-7b) | 7B/40B/180B | query_key_value | falcon |
| [InternLM](https://github.com/InternLM/InternLM) | 7B/20B | q_proj,v_proj | intern |
| [LLaMA](https://github.com/facebookresearch/llama) | 7B/13B/33B/65B | q_proj,v_proj | - |
| [LLaMA-2](https://huggingface.co/meta-llama) | 7B/13B/70B | q_proj,v_proj | llama2 |
@@ -71,7 +74,7 @@ https://github.com/hiyouga/LLaMA-Factory/assets/16256802/6ba60acc-e2e2-4bec-b846
>
> 对于所有“基座”Base模型`--template` 参数可以是 `default`, `alpaca`, `vicuna` 等任意值。但“对话”Chat模型请务必使用**对应的模板**。
项目所支持模型的完整列表请参阅 [template.py](src/llmtuner/extras/template.py)。
项目所支持模型的完整列表请参阅 [constants.py](src/llmtuner/extras/constants.py)。
## 训练方法
@@ -79,9 +82,9 @@ https://github.com/hiyouga/LLaMA-Factory/assets/16256802/6ba60acc-e2e2-4bec-b846
| ---------------------- | ------------------ | ------------------ | ------------------ | ------------------ |
| 预训练 | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
| 指令监督微调 | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
| 奖励模型训练 | | | :white_check_mark: | :white_check_mark: |
| PPO 训练 | | | :white_check_mark: | :white_check_mark: |
| DPO 训练 | :white_check_mark: | | :white_check_mark: | :white_check_mark: |
| 奖励模型训练 | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
| PPO 训练 | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
| DPO 训练 | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
> [!NOTE]
> 请使用 `--quantization_bit 4/8` 参数来启用 QLoRA 训练。
@@ -122,6 +125,7 @@ https://github.com/hiyouga/LLaMA-Factory/assets/16256802/6ba60acc-e2e2-4bec-b846
- [OpenPlatypus (en)](https://huggingface.co/datasets/garage-bAInd/Open-Platypus)
- [CodeAlpaca 20k (en)](https://huggingface.co/datasets/sahil2801/CodeAlpaca-20k)
- [Alpaca CoT (multilingual)](https://huggingface.co/datasets/QingyiSi/Alpaca-CoT)
- [OpenOrca (en)](https://huggingface.co/datasets/Open-Orca/OpenOrca)
- [MathInstruct (en)](https://huggingface.co/datasets/TIGER-Lab/MathInstruct)
- [Firefly 1.1M (zh)](https://huggingface.co/datasets/YeungNLP/firefly-train-1.1M)
- [Web QA (zh)](https://huggingface.co/datasets/suolyer/webqa)
@@ -158,7 +162,7 @@ huggingface-cli login
- Python 3.8+ 和 PyTorch 1.13.1+
- 🤗Transformers, Datasets, Accelerate, PEFT 和 TRL
- sentencepiece, protobuf 和 tiktoken
- fire, jieba, rouge-chinese 和 nltk (用于评估及预测)
- jieba, rouge-chinese 和 nltk (用于评估及预测)
- gradio 和 matplotlib (用于网页端交互)
- uvicorn, fastapi 和 sse-starlette (用于 API)

View File

@@ -24,9 +24,7 @@ class BelleMultiturn(datasets.GeneratorBasedBuilder):
def _info(self):
features = datasets.Features({
"instruction": datasets.Value("string"),
"output": datasets.Value("string"),
"history": datasets.Sequence(datasets.Sequence(datasets.Value("string")))
"conversations": [{"from": datasets.Value("string"), "value": datasets.Value("string")}]
})
return datasets.DatasetInfo(
description=_DESCRIPTION,
@@ -51,6 +49,7 @@ class BelleMultiturn(datasets.GeneratorBasedBuilder):
with open(filepath, "r", encoding="utf-8") as f:
for key, row in enumerate(f):
data = json.loads(row)
conversations = []
prompt = data["instruction"].strip()
response = data["output"].strip()
@@ -58,7 +57,8 @@ class BelleMultiturn(datasets.GeneratorBasedBuilder):
human_idx = prompt.rfind("Human:")
query = prompt[human_idx+6:assist_idx].strip()
prompt = prompt[:human_idx].strip()
history = []
conversations.insert(0, {"from": "gpt", "value": response})
conversations.insert(0, {"from": "human", "value": query})
while prompt.rfind("Assistant:") != -1:
assist_idx = prompt.rfind("Assistant:")
@@ -66,13 +66,10 @@ class BelleMultiturn(datasets.GeneratorBasedBuilder):
if human_idx != -1:
old_query = prompt[human_idx+6:assist_idx].strip()
old_resp = prompt[assist_idx+10:].strip()
history.insert(0, (old_query, old_resp))
conversations.insert(0, {"from": "gpt", "value": old_resp})
conversations.insert(0, {"from": "human", "value": old_query})
else:
break
prompt = prompt[:human_idx].strip()
yield key, {
"instruction": query,
"output": response,
"history": history
}
yield key, {"conversations": conversations}

View File

@@ -66,6 +66,4 @@ class UltraChat(datasets.GeneratorBasedBuilder):
"from": "human" if i % 2 == 0 else "gpt",
"value": content[i]
} for i in range(len(content))]
yield key, {
"conversations": conversations
}
yield key, {"conversations": conversations}

View File

@@ -3,13 +3,12 @@ transformers>=4.31.0,<4.35.0
datasets>=2.14.0
accelerate>=0.21.0
peft>=0.6.0
trl==0.7.2
trl>=0.7.4
gradio>=3.38.0,<4.0.0
scipy
sentencepiece
protobuf
tiktoken
fire
jieba
rouge-chinese
nltk

View File

@@ -1,5 +1,12 @@
import readline
from llmtuner import ChatModel
from llmtuner.extras.misc import torch_gc
try:
import platform
if platform.system() != "Windows":
import readline
except ImportError:
print("Install `readline` for a better experience.")
def main():
@@ -21,6 +28,7 @@ def main():
if query.strip() == "clear":
history = []
torch_gc()
print("History has been removed.")
continue

View File

@@ -1,190 +1,10 @@
# coding=utf-8
# Evaluates the performance of pre-trained models.
# Usage: python evaluate.py --model_name_or_path path_to_model --checkpoint_dir path_to_ckpt --template vanilla
# --task ceval --split validation --lang zh --n_shot 5 --batch_size 4 --save_name result
# Inspired by: https://github.com/hendrycks/test/blob/master/evaluate_flan.py
import os
import fire
import json
import torch
import numpy as np
import transformers
from collections import Counter
from datasets import load_dataset
from dataclasses import dataclass
from tqdm import tqdm, trange
from typing import TYPE_CHECKING, Dict, List, Literal, Optional, Tuple
from llmtuner import ChatModel
if TYPE_CHECKING:
from datasets import Dataset
from llmtuner import Evaluator
choices = ["A", "B", "C", "D"]
@dataclass
class EvalTemplate:
system: str
choice: str
answer: str
prefix: str
def parse_example(
self,
example: Dict[str, str]
) -> Tuple[str, str]:
candidates = [self.choice.format(choice=ch, content=example[ch]) for ch in choices if ch in example]
return "".join([example["question"]] + candidates + [self.answer]), example["answer"]
def format_example(
self,
target_data: Dict[str, str],
support_set: "Dataset",
subject_name: str,
use_history: bool
) -> Tuple[str, str, List[Tuple[str, str]]]:
query, resp = self.parse_example(target_data)
history = [self.parse_example(support_set[k]) for k in range(len(support_set))]
if len(history):
temp = history.pop(0)
history.insert(0, (self.system.format(subject=subject_name) + temp[0], temp[1]))
else:
query = self.system.format(subject=subject_name) + query
if not use_history:
query = "\n\n".join(["".join(item) for item in history] + [query])
history = []
return query.strip(), resp, history
eval_templates = {
"en": EvalTemplate(
system="The following are multiple choice questions (with answers) about {subject}.\n\n",
choice="\n{choice}. {content}",
answer="\nAnswer: ",
prefix=" "
),
"zh": EvalTemplate(
system="以下是中国关于{subject}考试的单项选择题,请选出其中的正确答案。\n\n",
choice="\n{choice}. {content}",
answer="\n答案:",
prefix="\n"
)
}
@torch.inference_mode()
def batch_inference(
chat_model: ChatModel,
batch_input: Dict[str, torch.Tensor],
prefix_char: str
) -> List[str]:
logits = chat_model.model(**batch_input).logits
lengths = torch.sum(batch_input["attention_mask"], dim=-1)
nextword_logits = torch.stack([logits[i, lengths[i] - 1] for i in range(len(lengths))], dim=0)
probs = torch.nn.functional.softmax(
torch.stack(
[
nextword_logits[:, chat_model.tokenizer.encode(prefix_char + choice, add_special_tokens=False)[-1]]
for choice in choices
],
dim=-1
),
dim=-1
).detach()
return [chr(ord("A") + offset.item()) for offset in torch.argmax(probs, dim=-1)]
def evaluate(
model_name_or_path: str,
finetuning_type: Optional[str] = "lora",
checkpoint_dir: Optional[str] = None,
template: Optional[str] = "vanilla",
task: Optional[str] = "ceval",
dataset_dir: Optional[str] = "evaluation",
split: Optional[Literal["validation", "test"]] = "validation",
lang: Optional[Literal["zh", "en"]] = "zh",
n_shot: Optional[int] = 5,
n_avg: Optional[int] = 1,
batch_size: Optional[int] = 4,
save_name: Optional[str] = None,
seed: Optional[int] = 42
):
with open(os.path.join(dataset_dir, task, "mapping.json"), "r", encoding="utf-8") as f:
categorys: Dict[str, Dict[str, str]] = json.load(f)
transformers.set_seed(seed)
chat_model = ChatModel(dict(
model_name_or_path=model_name_or_path,
finetuning_type=finetuning_type,
checkpoint_dir=checkpoint_dir,
template=template
))
chat_model.tokenizer.padding_side = "right" # avoid overflow issue in batched inference for llama2
eval_template = eval_templates[lang]
category_corrects: Dict[str, np.ndarray] = {
subj: np.array([], dtype="bool") for subj in ["Average", "STEM", "Social Sciences", "Humanities", "Other"]
}
pbar = tqdm(categorys.keys(), desc="Processing subjects", position=0)
results = {}
for subject in pbar:
dataset = load_dataset(os.path.join(dataset_dir, task), subject)
labels, answers, all_outputs = [], [], []
for epoch in range(n_avg):
pbar.set_postfix_str("{} Trial: {}".format(categorys[subject]["name"], epoch))
inputs, outputs = [], []
for i in trange(len(dataset[split]), desc="Formatting batches", position=1, leave=False):
support_set = dataset["train"].shuffle().select(range(min(n_shot, len(dataset["train"]))))
query, resp, history = eval_template.format_example(
target_data=dataset[split][i],
support_set=support_set,
subject_name=categorys[subject]["name"],
use_history=chat_model.template.use_history
)
input_ids, _ = chat_model.template.encode_oneturn(
tokenizer=chat_model.tokenizer, query=query, resp=resp, history=history
)
inputs.append({"input_ids": input_ids, "attention_mask": [1] * len(input_ids)})
if epoch == 0:
labels.append(resp)
for i in trange(0, len(inputs), batch_size, desc="Predicting batches", position=1, leave=False):
batch_input = chat_model.tokenizer.pad(
inputs[i : i + batch_size], return_attention_mask=True, return_tensors="pt"
).to(chat_model.model.device)
preds = batch_inference(chat_model, batch_input, eval_template.prefix)
outputs += preds
all_outputs.append(outputs)
for i in range(len(all_outputs[0])):
count = Counter([all_outputs[epoch][i] for epoch in range(n_avg)])
answers.append(count.most_common(1)[0][0])
corrects = (np.array(answers) == np.array(labels))
category_name = categorys[subject]["category"]
category_corrects[category_name] = np.concatenate([category_corrects[category_name], corrects], axis=0)
category_corrects["Average"] = np.concatenate([category_corrects["Average"], corrects], axis=0)
results[subject] = {str(i): answers[i] for i in range(len(answers))}
score_info = "\n".join([
"{:>15}: {:.2f}".format(category_name, 100 * np.mean(category_correct))
for category_name, category_correct in category_corrects.items() if len(category_correct)
])
print(score_info)
if save_name is not None:
with open(save_name + ".json", "w", encoding="utf-8", newline="\n") as f:
json.dump(results, f, indent=2)
with open(save_name + ".log", "w", encoding="utf-8", newline="\n") as f:
f.write(score_info)
def main():
evaluator = Evaluator()
evaluator.eval()
if __name__ == "__main__":
fire.Fire(evaluate)
main()

View File

@@ -1,9 +1,10 @@
# Level: api, webui > chat > tuner > dsets > extras, hparams
# Level: api, webui > chat, eval, train > data, model > extras, hparams
from llmtuner.api import create_app
from llmtuner.chat import ChatModel
from llmtuner.tuner import export_model, run_exp
from llmtuner.eval import Evaluator
from llmtuner.train import export_model, run_exp
from llmtuner.webui import create_ui, create_web_demo
__version__ = "0.2.1"
__version__ = "0.3.0"

View File

@@ -1,14 +1,8 @@
import json
import uvicorn
from fastapi import FastAPI, HTTPException, status
from fastapi.middleware.cors import CORSMiddleware
from contextlib import asynccontextmanager
from sse_starlette import EventSourceResponse
from typing import List, Tuple
from pydantic import BaseModel
from contextlib import asynccontextmanager
from llmtuner.extras.misc import torch_gc
from llmtuner.chat import ChatModel
from llmtuner.api.protocol import (
Role,
Finish,
@@ -23,10 +17,28 @@ from llmtuner.api.protocol import (
ChatCompletionResponseStreamChoice,
ChatCompletionResponseUsage
)
from llmtuner.chat import ChatModel
from llmtuner.extras.misc import torch_gc
from llmtuner.extras.packages import (
is_fastapi_availble, is_starlette_available, is_uvicorn_available
)
if is_fastapi_availble():
from fastapi import FastAPI, HTTPException, status
from fastapi.middleware.cors import CORSMiddleware
if is_starlette_available():
from sse_starlette import EventSourceResponse
if is_uvicorn_available():
import uvicorn
@asynccontextmanager
async def lifespan(app: FastAPI): # collects GPU memory
async def lifespan(app: "FastAPI"): # collects GPU memory
yield
torch_gc()
@@ -38,7 +50,7 @@ def to_json(data: BaseModel) -> str:
return data.json(exclude_unset=True, ensure_ascii=False)
def create_app(chat_model: ChatModel) -> FastAPI:
def create_app(chat_model: "ChatModel") -> "FastAPI":
app = FastAPI(lifespan=lifespan)
app.add_middleware(
@@ -56,12 +68,12 @@ def create_app(chat_model: ChatModel) -> FastAPI:
@app.post("/v1/chat/completions", response_model=ChatCompletionResponse, status_code=status.HTTP_200_OK)
async def create_chat_completion(request: ChatCompletionRequest):
if len(request.messages) < 1 or request.messages[-1].role != Role.USER:
if len(request.messages) == 0 or request.messages[-1].role != Role.USER:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid request")
query = request.messages[-1].content
prev_messages = request.messages[:-1]
if len(prev_messages) > 0 and prev_messages[0].role == Role.SYSTEM:
if len(prev_messages) and prev_messages[0].role == Role.SYSTEM:
system = prev_messages.pop(0).content
else:
system = None
@@ -73,12 +85,14 @@ def create_app(chat_model: ChatModel) -> FastAPI:
history.append([prev_messages[i].content, prev_messages[i+1].content])
else:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Only supports u/a/u/a/u...")
else:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Only supports u/a/u/a/u...")
if request.stream:
generate = predict(query, history, system, request)
return EventSourceResponse(generate, media_type="text/event-stream")
response, (prompt_length, response_length) = chat_model.chat(
responses = chat_model.chat(
query, history, system,
do_sample=request.do_sample,
temperature=request.temperature,
@@ -87,18 +101,23 @@ def create_app(chat_model: ChatModel) -> FastAPI:
num_return_sequences=request.n
)
prompt_length, response_length = 0, 0
choices = []
for i, response in enumerate(responses):
choices.append(ChatCompletionResponseChoice(
index=i,
message=ChatMessage(role=Role.ASSISTANT, content=response.response_text),
finish_reason=Finish.STOP if response.finish_reason == "stop" else Finish.LENGTH
))
prompt_length = response.prompt_length
response_length += response.response_length
usage = ChatCompletionResponseUsage(
prompt_tokens=prompt_length,
completion_tokens=response_length,
total_tokens=prompt_length+response_length
)
choices = [ChatCompletionResponseChoice(
index=i,
message=ChatMessage(role=Role.ASSISTANT, content=choice),
finish_reason=Finish.STOP
) for i, choice in enumerate(response)]
return ChatCompletionResponse(model=request.model, choices=choices, usage=usage)
async def predict(query: str, history: List[Tuple[str, str]], system: str, request: ChatCompletionRequest):

View File

@@ -1 +1 @@
from llmtuner.chat.stream_chat import ChatModel
from llmtuner.chat.chat_model import ChatModel

View File

@@ -1,11 +1,21 @@
import torch
from typing import Any, Dict, Generator, List, Optional, Tuple
from dataclasses import dataclass
from typing import Any, Dict, Generator, List, Literal, Optional, Tuple
from threading import Thread
from transformers import GenerationConfig, TextIteratorStreamer
from llmtuner.extras.misc import dispatch_model, get_logits_processor
from llmtuner.extras.template import get_template_and_fix_tokenizer
from llmtuner.tuner.core import get_infer_args, load_model_and_tokenizer
from llmtuner.data.template import get_template_and_fix_tokenizer
from llmtuner.extras.misc import get_logits_processor
from llmtuner.model import dispatch_model, get_infer_args, load_model_and_tokenizer
@dataclass
class Response:
response_text: str
response_length: int
prompt_length: int
finish_reason: Literal["stop", "length"]
class ChatModel:
@@ -18,7 +28,7 @@ class ChatModel:
self.template = get_template_and_fix_tokenizer(data_args.template, self.tokenizer)
self.system_prompt = data_args.system_prompt
def process_args(
def _process_args(
self,
query: str,
history: Optional[List[Tuple[str, str]]] = None,
@@ -79,17 +89,30 @@ class ChatModel:
history: Optional[List[Tuple[str, str]]] = None,
system: Optional[str] = None,
**input_kwargs
) -> Tuple[List[str], Tuple[int, int]]:
gen_kwargs, prompt_length = self.process_args(query, history, system, **input_kwargs)
) -> List[Response]:
r"""
Args: query, history, system, **input_kwargs
Returns: [(response_text, prompt_length, response_length)] * n (default n=1)
"""
gen_kwargs, prompt_length = self._process_args(query, history, system, **input_kwargs)
generate_output = self.model.generate(**gen_kwargs)
response_ids = generate_output[:, prompt_length:]
response = self.tokenizer.batch_decode(response_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True)
response_length = 0
for i in range(len(response_ids)):
response = self.tokenizer.batch_decode(
response_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True
)
results = []
for i in range(len(response)):
eos_index = (response_ids[i] == self.tokenizer.eos_token_id).nonzero()
response_length += eos_index[0].item() if len(eos_index) else len(response_ids[i])
response_length = (eos_index[0].item() + 1) if len(eos_index) else len(response_ids[i])
results.append(Response(
response_text=response[i],
response_length=response_length,
prompt_length=prompt_length,
finish_reason="stop" if len(eos_index) else "length"
))
return response, (prompt_length, response_length)
return results
@torch.inference_mode()
def stream_chat(
@@ -99,7 +122,7 @@ class ChatModel:
system: Optional[str] = None,
**input_kwargs
) -> Generator[str, None, None]:
gen_kwargs, _ = self.process_args(query, history, system, **input_kwargs)
gen_kwargs, _ = self._process_args(query, history, system, **input_kwargs)
streamer = TextIteratorStreamer(self.tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True)
gen_kwargs["streamer"] = streamer

View File

@@ -0,0 +1,4 @@
from llmtuner.data.loader import get_dataset
from llmtuner.data.preprocess import preprocess_dataset
from llmtuner.data.template import get_template_and_fix_tokenizer
from llmtuner.data.utils import split_dataset

View File

@@ -3,7 +3,7 @@ from typing import TYPE_CHECKING, Any, Dict, List, Union
from datasets import concatenate_datasets, interleave_datasets, load_dataset
from llmtuner.dsets.utils import checksum, EXT2TYPE
from llmtuner.data.utils import checksum, EXT2TYPE
from llmtuner.extras.logging import get_logger
if TYPE_CHECKING:

View File

@@ -1,13 +1,13 @@
import os
import tiktoken
from itertools import chain
from typing import TYPE_CHECKING, Any, Dict, Generator, List, Literal, Union
from typing import TYPE_CHECKING, Any, Dict, Generator, List, Literal, Tuple, Union
from datasets import load_from_disk
from llmtuner.data.template import get_template_and_fix_tokenizer
from llmtuner.extras.constants import IGNORE_INDEX
from llmtuner.extras.logging import get_logger
from llmtuner.extras.template import get_template_and_fix_tokenizer
if TYPE_CHECKING:
from datasets import Dataset, IterableDataset
@@ -19,6 +19,22 @@ if TYPE_CHECKING:
logger = get_logger(__name__)
def construct_example(examples: Dict[str, List[Any]]) -> Generator[Any, None, None]:
for i in range(len(examples["prompt"])):
query, response = examples["prompt"][i], examples["response"][i]
query = query + "\n" + examples["query"][i] if "query" in examples and examples["query"][i] else query
history = examples["history"][i] if "history" in examples else None
system = examples["system"][i] if "system" in examples else None
yield query, response, history, system
def infer_max_len(source_len: int, target_len: int, data_args: "DataArguments") -> Tuple[int, int]:
max_target_len = int(data_args.cutoff_len * (target_len / (source_len + target_len)))
max_target_len = max(max_target_len, data_args.reserved_label_len)
max_source_len = data_args.cutoff_len - max_target_len
return max_source_len, max_target_len
def preprocess_dataset(
dataset: Union["Dataset", "IterableDataset"],
tokenizer: "PreTrainedTokenizer",
@@ -31,14 +47,6 @@ def preprocess_dataset(
if data_args.train_on_prompt and template.efficient_eos:
raise ValueError("Current template does not support `train_on_prompt`.")
def construct_example(examples: Dict[str, List[Any]]) -> Generator[Any, None, None]:
for i in range(len(examples["prompt"])):
query, response = examples["prompt"][i], examples["response"][i]
query = query + "\n" + examples["query"][i] if "query" in examples and examples["query"][i] else query
history = examples["history"][i] if "history" in examples else None
system = examples["system"][i] if "system" in examples else None
yield query, response, history, system
def preprocess_pretrain_dataset(examples: Dict[str, List[Any]]) -> Dict[str, List[List[int]]]:
# build grouped texts with format `X1 X2 X3 ...`
if isinstance(getattr(tokenizer, "tokenizer", None), tiktoken.Encoding): # for tiktoken tokenizer (Qwen)
@@ -79,13 +87,11 @@ def preprocess_dataset(
for turn_idx, (source_ids, target_ids) in enumerate(template.encode_multiturn(
tokenizer, query, response, history, system
)):
total_len = len(source_ids) + len(target_ids)
max_source_len = int(data_args.cutoff_len * (len(source_ids) / total_len))
max_target_len = int(data_args.cutoff_len * (len(target_ids) / total_len))
if len(source_ids) > max_source_len:
source_len, target_len = len(source_ids), len(target_ids)
max_source_len, max_target_len = infer_max_len(source_len, target_len, data_args)
if source_len > max_source_len:
source_ids = source_ids[:max_source_len]
if len(target_ids) > max_target_len:
if target_len > max_target_len:
target_ids = target_ids[:max_target_len]
if data_args.train_on_prompt:
@@ -187,15 +193,12 @@ def preprocess_dataset(
chosen_ids += [tokenizer.eos_token_id]
rejected_ids += [tokenizer.eos_token_id]
total_len = len(prompt_ids) + max(len(chosen_ids), len(rejected_ids))
max_source_len = int(data_args.cutoff_len * (len(prompt_ids) / total_len))
max_target_len = int(data_args.cutoff_len * (max(len(chosen_ids), len(rejected_ids)) / total_len))
if len(prompt_ids) > max_source_len:
source_len, target_len = len(prompt_ids), max(len(chosen_ids), len(rejected_ids))
max_source_len, max_target_len = infer_max_len(source_len, target_len, data_args)
if source_len > max_source_len:
prompt_ids = prompt_ids[:max_source_len]
if len(chosen_ids) > max_target_len:
if target_len > max_target_len:
chosen_ids = chosen_ids[:max_target_len]
if len(rejected_ids) > max_target_len:
rejected_ids = rejected_ids[:max_target_len]
model_inputs["prompt_ids"].append(prompt_ids)

View File

@@ -225,9 +225,6 @@ def get_template_and_fix_tokenizer(
return template
r"""
Supports: https://huggingface.co/tatsu-lab/alpaca-7b-wdiff
"""
register_template(
name="alpaca",
prefix=[
@@ -246,11 +243,6 @@ register_template(
)
r"""
Supports: https://huggingface.co/BAAI/AquilaChat-7B
https://huggingface.co/BAAI/AquilaChat2-7B
https://huggingface.co/BAAI/AquilaChat2-34B
"""
register_template(
name="aquila",
prefix=[
@@ -273,9 +265,6 @@ register_template(
)
r"""
Supports: https://huggingface.co/baichuan-inc/Baichuan-13B-Chat
"""
register_template(
name="baichuan",
prefix=[
@@ -292,10 +281,6 @@ register_template(
)
r"""
Supports: https://huggingface.co/baichuan-inc/Baichuan2-7B-Chat
https://huggingface.co/baichuan-inc/Baichuan2-13B-Chat
"""
register_template(
name="baichuan2",
prefix=[
@@ -312,9 +297,6 @@ register_template(
)
r"""
Supports: https://huggingface.co/BelleGroup/BELLE-LLaMA-EXT-13B
"""
register_template(
name="belle",
prefix=[
@@ -330,9 +312,6 @@ register_template(
)
r"""
Supports: https://huggingface.co/vivo-ai/BlueLM-7B-Chat
"""
register_template(
name="bluelm",
prefix=[
@@ -348,9 +327,6 @@ register_template(
)
r"""
Supports: https://huggingface.co/THUDM/chatglm2-6b
"""
register_template(
name="chatglm2",
prefix=[
@@ -369,9 +345,6 @@ register_template(
)
r"""
Supports: https://huggingface.co/THUDM/chatglm3-6b
"""
register_template(
name="chatglm3",
prefix=[
@@ -395,11 +368,6 @@ register_template(
)
r"""
Supports: https://huggingface.co/deepseek-ai/deepseek-coder-1.3b-instruct
https://huggingface.co/deepseek-ai/deepseek-coder-6.7b-instruct
https://huggingface.co/deepseek-ai/deepseek-coder-33b-instruct
"""
register_template(
name="deepseek",
prefix=[
@@ -426,9 +394,6 @@ register_template(
)
r"""
Default template.
"""
register_template(
name="default",
prefix=[
@@ -447,10 +412,22 @@ register_template(
)
r"""
Supports: https://huggingface.co/internlm/internlm-chat-7b
https://huggingface.co/internlm/internlm-chat-20b
"""
register_template(
name="falcon",
prefix=[
"{{system}}"
],
prompt=[
"User: {{query}}\nFalcon:"
],
system="",
sep=[
"\n"
],
efficient_eos=True
)
register_template(
name="intern",
prefix=[
@@ -473,11 +450,6 @@ register_template(
)
r"""
Supports: https://huggingface.co/meta-llama/Llama-2-7b-chat-hf
https://huggingface.co/meta-llama/Llama-2-13b-chat-hf
https://huggingface.co/meta-llama/Llama-2-70b-chat-hf
"""
register_template(
name="llama2",
prefix=[
@@ -500,10 +472,6 @@ register_template(
)
r"""
Supports: https://huggingface.co/ziqingyang/chinese-alpaca-2-7b
https://huggingface.co/ziqingyang/chinese-alpaca-2-13b
"""
register_template(
name="llama2_zh",
prefix=[
@@ -517,9 +485,6 @@ register_template(
)
r"""
Supports: https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.1
"""
register_template(
name="mistral",
prefix=[
@@ -533,9 +498,6 @@ register_template(
)
r"""
Supports: https://huggingface.co/openchat/openchat_3.5
"""
register_template(
name="openchat",
prefix=[
@@ -557,10 +519,6 @@ register_template(
)
r"""
Supports: https://huggingface.co/Qwen/Qwen-7B-Chat
https://huggingface.co/Qwen/Qwen-14B-Chat
"""
register_template(
name="qwen",
prefix=[
@@ -587,10 +545,6 @@ register_template(
)
r"""
Supports: https://huggingface.co/HuggingFaceH4/starchat-alpha
https://huggingface.co/HuggingFaceH4/starchat-beta
"""
register_template(
name="starchat",
prefix=[
@@ -631,10 +585,6 @@ register_template(
)
r"""
Supports: https://huggingface.co/lmsys/vicuna-7b-v1.5
https://huggingface.co/lmsys/vicuna-13b-v1.5
"""
register_template(
name="vicuna",
prefix=[
@@ -651,10 +601,6 @@ register_template(
)
r"""
Supports: https://huggingface.co/xverse/XVERSE-7B-Chat
https://huggingface.co/xverse/XVERSE-13B-Chat
"""
register_template(
name="xverse",
prefix=[
@@ -668,11 +614,6 @@ register_template(
)
r"""
Supports: https://huggingface.co/wenge-research/yayi-7b
https://huggingface.co/wenge-research/yayi-7b-llama2
https://huggingface.co/wenge-research/yayi-13b-llama2
"""
register_template(
name="yayi",
prefix=[
@@ -705,10 +646,6 @@ register_template(
)
r"""
Supports: https://huggingface.co/HuggingFaceH4/zephyr-7b-alpha
https://huggingface.co/HuggingFaceH4/zephyr-7b-beta
"""
register_template(
name="zephyr",
prefix=[
@@ -727,11 +664,6 @@ register_template(
)
r"""
Supports: https://huggingface.co/IDEA-CCNL/Ziya-LLaMA-13B-v1
https://huggingface.co/IDEA-CCNL/Ziya-LLaMA-13B-v1.1
https://huggingface.co/IDEA-CCNL/Ziya2-13B-Chat
"""
register_template(
name="ziya",
prefix=[

View File

@@ -1,3 +0,0 @@
from llmtuner.dsets.loader import get_dataset
from llmtuner.dsets.preprocess import preprocess_dataset
from llmtuner.dsets.utils import split_dataset

View File

@@ -0,0 +1 @@
from llmtuner.eval.evaluator import Evaluator

View File

@@ -0,0 +1,116 @@
# Inspired by: https://github.com/hendrycks/test/blob/master/evaluate_flan.py
import os
import json
import torch
import tiktoken
import numpy as np
from tqdm import tqdm, trange
from typing import Any, Dict, List, Optional
from datasets import load_dataset
from transformers.utils import cached_file
from llmtuner.data.template import get_template_and_fix_tokenizer
from llmtuner.eval.template import get_eval_template
from llmtuner.extras.constants import CHOICES, SUBJECTS
from llmtuner.model import dispatch_model, get_eval_args, load_model_and_tokenizer
class Evaluator:
def __init__(self, args: Optional[Dict[str, Any]] = None) -> None:
self.model_args, self.data_args, self.eval_args, finetuning_args = get_eval_args(args)
self.model, self.tokenizer = load_model_and_tokenizer(self.model_args, finetuning_args)
self.tokenizer.padding_side = "right" # avoid overflow issue in batched inference for llama2
self.model = dispatch_model(self.model)
self.template = get_template_and_fix_tokenizer(self.data_args.template, self.tokenizer)
self.eval_template = get_eval_template(self.eval_args.lang)
self.choice_inputs = self._encode_choices()
def _encode_choices(self) -> List[int]:
if isinstance(getattr(self.tokenizer, "tokenizer", None), tiktoken.Encoding): # for tiktoken tokenizer (Qwen)
kwargs = dict(allowed_special="all")
else:
kwargs = dict(add_special_tokens=False)
return [self.tokenizer.encode(self.eval_template.prefix + ch, **kwargs)[-1] for ch in CHOICES]
@torch.inference_mode()
def batch_inference(self, batch_input: Dict[str, torch.Tensor]) -> List[str]:
logits = self.model(**batch_input).logits
lengths = torch.sum(batch_input["attention_mask"], dim=-1)
word_probs = torch.stack([logits[i, lengths[i] - 1] for i in range(len(lengths))], dim=0)
choice_probs = torch.nn.functional.softmax(word_probs[:, self.choice_inputs], dim=-1).detach()
return [chr(ord("A") + offset.item()) for offset in torch.argmax(choice_probs, dim=-1)]
def eval(self) -> None:
mapping = cached_file(
path_or_repo_id = os.path.join(self.eval_args.task_dir, self.eval_args.task),
filename="mapping.json",
cache_dir=self.model_args.cache_dir,
token=self.model_args.hf_hub_token,
revision=self.model_args.model_revision
)
with open(mapping, "r", encoding="utf-8") as f:
categorys: Dict[str, Dict[str, str]] = json.load(f)
category_corrects = {subj: np.array([], dtype="bool") for subj in SUBJECTS}
pbar = tqdm(categorys.keys(), desc="Processing subjects", position=0)
results = {}
for subject in pbar:
dataset = load_dataset(
path=os.path.join(self.eval_args.task_dir, self.eval_args.task),
name=subject,
download_mode="force_redownload"
)
pbar.set_postfix_str(categorys[subject]["name"])
inputs, outputs, labels = [], [], []
for i in trange(len(dataset[self.data_args.split]), desc="Formatting batches", position=1, leave=False):
support_set = dataset["train"].shuffle().select(range(min(self.eval_args.n_shot, len(dataset["train"]))))
query, resp, history = self.eval_template.format_example(
target_data=dataset[self.data_args.split][i],
support_set=support_set,
subject_name=categorys[subject]["name"],
use_history=self.template.use_history
)
input_ids, _ = self.template.encode_oneturn(
tokenizer=self.tokenizer, query=query, resp=resp, history=history
)
inputs.append({"input_ids": input_ids, "attention_mask": [1] * len(input_ids)})
labels.append(resp)
for i in trange(0, len(inputs), self.eval_args.batch_size, desc="Predicting batches", position=1, leave=False):
batch_input = self.tokenizer.pad(
inputs[i : i + self.eval_args.batch_size], return_attention_mask=True, return_tensors="pt"
).to(self.model.device)
preds = self.batch_inference(batch_input)
outputs += preds
corrects = (np.array(outputs) == np.array(labels))
category_name = categorys[subject]["category"]
category_corrects[category_name] = np.concatenate([category_corrects[category_name], corrects], axis=0)
category_corrects["Average"] = np.concatenate([category_corrects["Average"], corrects], axis=0)
results[subject] = {str(i): outputs[i] for i in range(len(outputs))}
pbar.close()
self._save_results(category_corrects, results)
def _save_results(self, category_corrects: Dict[str, np.ndarray], results: Dict[str, Dict[int, str]]) -> None:
score_info = "\n".join([
"{:>15}: {:.2f}".format(category_name, 100 * np.mean(category_correct))
for category_name, category_correct in category_corrects.items() if len(category_correct)
])
print(score_info)
if self.eval_args.save_dir is not None:
os.makedirs(self.eval_args.save_dir, exist_ok=False)
with open(os.path.join(self.eval_args.save_dir, "results.json"), "w", encoding="utf-8", newline="\n") as f:
json.dump(results, f, indent=2)
with open(os.path.join(self.eval_args.save_dir, "results.log"), "w", encoding="utf-8", newline="\n") as f:
f.write(score_info)
if __name__ == "__main__":
evaluator = Evaluator()
evaluator.eval()

View File

@@ -0,0 +1,86 @@
from dataclasses import dataclass
from typing import TYPE_CHECKING, Dict, List, Tuple
from llmtuner.extras.constants import CHOICES
if TYPE_CHECKING:
from datasets import Dataset
@dataclass
class EvalTemplate:
system: str
choice: str
answer: str
prefix: str
def parse_example(
self,
example: Dict[str, str]
) -> Tuple[str, str]:
candidates = [self.choice.format(choice=ch, content=example[ch]) for ch in CHOICES if ch in example]
return "".join([example["question"]] + candidates + [self.answer]), example["answer"]
def format_example(
self,
target_data: Dict[str, str],
support_set: "Dataset",
subject_name: str,
use_history: bool
) -> Tuple[str, str, List[Tuple[str, str]]]:
query, resp = self.parse_example(target_data)
history = [self.parse_example(support_set[k]) for k in range(len(support_set))]
if len(history):
temp = history.pop(0)
history.insert(0, (self.system.format(subject=subject_name) + temp[0], temp[1]))
else:
query = self.system.format(subject=subject_name) + query
if not use_history:
query = "\n\n".join(["".join(item) for item in history] + [query])
history = []
return query.strip(), resp, history
eval_templates: Dict[str, EvalTemplate] = {}
def register_eval_template(
name: str,
system: str,
choice: str,
answer: str,
prefix: str
) -> None:
eval_templates[name] = EvalTemplate(
system=system,
choice=choice,
answer=answer,
prefix=prefix
)
def get_eval_template(name: str) -> EvalTemplate:
eval_template = eval_templates.get(name, None)
assert eval_template is not None, "Template {} does not exist.".format(name)
return eval_template
register_eval_template(
name="en",
system="The following are multiple choice questions (with answers) about {subject}.\n\n",
choice="\n{choice}. {content}",
answer="\nAnswer: ",
prefix=" "
)
register_eval_template(
name="zh",
system="以下是中国关于{subject}考试的单项选择题,请选出其中的正确答案。\n\n",
choice="\n{choice}. {content}",
answer="\n答案:",
prefix="\n"
)

View File

@@ -12,6 +12,7 @@ from llmtuner.extras.logging import get_logger
if TYPE_CHECKING:
from transformers import TrainingArguments, TrainerState, TrainerControl
from trl import AutoModelForCausalLMWithValueHead
logger = get_logger(__name__)
@@ -25,18 +26,24 @@ class SavePeftModelCallback(TrainerCallback):
"""
if args.should_save:
output_dir = os.path.join(args.output_dir, "{}-{}".format(PREFIX_CHECKPOINT_DIR, state.global_step))
model = kwargs.pop("model")
model: "AutoModelForCausalLMWithValueHead" = kwargs.pop("model")
model.pretrained_model.config.save_pretrained(output_dir)
if model.pretrained_model.can_generate():
model.pretrained_model.generation_config.save_pretrained(output_dir)
if getattr(model, "is_peft_model", False):
getattr(model, "pretrained_model").save_pretrained(output_dir)
model.pretrained_model.save_pretrained(output_dir)
def on_train_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
r"""
Event called at the end of training.
"""
if args.should_save:
model = kwargs.pop("model")
model: "AutoModelForCausalLMWithValueHead" = kwargs.pop("model")
model.pretrained_model.config.save_pretrained(args.output_dir)
if model.pretrained_model.can_generate():
model.pretrained_model.generation_config.save_pretrained(args.output_dir)
if getattr(model, "is_peft_model", False):
getattr(model, "pretrained_model").save_pretrained(args.output_dir)
model.pretrained_model.save_pretrained(args.output_dir)
class LogCallback(TrainerCallback):

View File

@@ -1,11 +1,25 @@
from collections import defaultdict, OrderedDict
from typing import Dict, Optional
CHOICES = ["A", "B", "C", "D"]
DEFAULT_MODULE = defaultdict(str)
DEFAULT_TEMPLATE = defaultdict(str)
IGNORE_INDEX = -100
LAYERNORM_NAMES = {"norm", "ln"}
LOG_FILE_NAME = "trainer_log.jsonl"
LAYERNORM_NAMES = ["norm", "ln_f", "ln_attn", "ln_mlp", "ln_1", "ln_2", "ln1", "ln2"]
METHODS = ["full", "freeze", "lora"]
SUBJECTS = ["Average", "STEM", "Social Sciences", "Humanities", "Other"]
SUPPORTED_MODELS = OrderedDict()
TRAINING_STAGES = {
"Supervised Fine-Tuning": "sft",
"Reward Modeling": "rm",
@@ -14,79 +28,251 @@ TRAINING_STAGES = {
"Pre-Training": "pt"
}
SUPPORTED_MODELS = {
def register_model_group(
models: Dict[str, str],
module: Optional[str] = None,
template: Optional[str] = None
) -> None:
prefix = None
for name, path in models.items():
if prefix is None:
prefix = name.split("-")[0]
else:
assert prefix == name.split("-")[0], "prefix should be identical."
SUPPORTED_MODELS[name] = path
if module is not None:
DEFAULT_MODULE[prefix] = module
if template is not None:
DEFAULT_TEMPLATE[prefix] = template
register_model_group(
models={
"Baichuan-7B-Base": "baichuan-inc/Baichuan-7B",
"Baichuan-13B-Base": "baichuan-inc/Baichuan-13B-Base",
"Baichuan-13B-Chat": "baichuan-inc/Baichuan-13B-Chat"
},
module="W_pack",
template="baichuan"
)
register_model_group(
models={
"Baichuan2-7B-Base": "baichuan-inc/Baichuan2-7B-Base",
"Baichuan2-13B-Base": "baichuan-inc/Baichuan2-13B-Base",
"Baichuan2-7B-Chat": "baichuan-inc/Baichuan2-7B-Chat",
"Baichuan2-13B-Chat": "baichuan-inc/Baichuan2-13B-Chat"
},
module="W_pack",
template="baichuan2"
)
register_model_group(
models={
"BLOOM-560M": "bigscience/bloom-560m",
"BLOOM-3B": "bigscience/bloom-3b",
"BLOOM-7B1": "bigscience/bloom-7b1"
},
module="query_key_value"
)
register_model_group(
models={
"BLOOMZ-560M": "bigscience/bloomz-560m",
"BLOOMZ-3B": "bigscience/bloomz-3b",
"BLOOMZ-7B1-mt": "bigscience/bloomz-7b1-mt"
},
module="query_key_value"
)
register_model_group(
models={
"BlueLM-7B-Base": "vivo-ai/BlueLM-7B-Base",
"BlueLM-7B-Chat": "vivo-ai/BlueLM-7B-Chat"
},
template="bluelm"
)
register_model_group(
models={
"ChatGLM2-6B-Chat": "THUDM/chatglm2-6b"
},
module="query_key_value",
template="chatglm2"
)
register_model_group(
models={
"ChatGLM3-6B-Base": "THUDM/chatglm3-6b-base",
"ChatGLM3-6B-Chat": "THUDM/chatglm3-6b"
},
module="query_key_value",
template="chatglm3"
)
register_model_group(
models={
"ChineseLLaMA2-1.3B": "hfl/chinese-llama-2-1.3b",
"ChineseLLaMA2-7B": "hfl/chinese-llama-2-7b",
"ChineseLLaMA2-13B": "hfl/chinese-llama-2-13b",
"ChineseLLaMA2-1.3B-Chat": "hfl/chinese-alpaca-2-1.3b",
"ChineseLLaMA2-7B-Chat": "hfl/chinese-alpaca-2-7b",
"ChineseLLaMA2-13B-Chat": "hfl/chinese-alpaca-2-13b"
},
template="llama2_zh"
)
register_model_group(
models={
"Falcon-7B": "tiiuae/falcon-7b",
"Falcon-40B": "tiiuae/falcon-40b",
"Falcon-180B": "tiiuae/falcon-180B",
"Falcon-7B-Chat": "tiiuae/falcon-7b-instruct",
"Falcon-40B-Chat": "tiiuae/falcon-40b-instruct",
"Falcon-180B-Chat": "tiiuae/falcon-180B-chat"
},
module="query_key_value",
template="falcon"
)
register_model_group(
models={
"InternLM-7B": "internlm/internlm-7b",
"InternLM-20B": "internlm/internlm-20b",
"InternLM-7B-Chat": "internlm/internlm-chat-7b",
"InternLM-20B-Chat": "internlm/internlm-chat-20b"
},
template="intern"
)
register_model_group(
models={
"LingoWhale-8B": "deeplang-ai/LingoWhale-8B"
},
module="qkv_proj"
)
register_model_group(
models={
"LLaMA-7B": "huggyllama/llama-7b",
"LLaMA-13B": "huggyllama/llama-13b",
"LLaMA-30B": "huggyllama/llama-30b",
"LLaMA-65B": "huggyllama/llama-65b",
"LLaMA-65B": "huggyllama/llama-65b"
}
)
register_model_group(
models={
"LLaMA2-7B": "meta-llama/Llama-2-7b-hf",
"LLaMA2-13B": "meta-llama/Llama-2-13b-hf",
"LLaMA2-70B": "meta-llama/Llama-2-70b-hf",
"LLaMA2-7B-Chat": "meta-llama/Llama-2-7b-chat-hf",
"LLaMA2-13B-Chat": "meta-llama/Llama-2-13b-chat-hf",
"LLaMA2-70B-Chat": "meta-llama/Llama-2-70b-chat-hf",
"ChineseLLaMA2-7B": "ziqingyang/chinese-llama-2-7b",
"ChineseLLaMA2-13B": "ziqingyang/chinese-llama-2-13b",
"ChineseLLaMA2-7B-Chat": "ziqingyang/chinese-alpaca-2-7b",
"ChineseLLaMA2-13B-Chat": "ziqingyang/chinese-alpaca-2-13b",
"BLOOM-560M": "bigscience/bloom-560m",
"BLOOM-3B": "bigscience/bloom-3b",
"BLOOM-7B1": "bigscience/bloom-7b1",
"BLOOMZ-560M": "bigscience/bloomz-560m",
"BLOOMZ-3B": "bigscience/bloomz-3b",
"BLOOMZ-7B1-mt": "bigscience/bloomz-7b1-mt",
"Falcon-7B": "tiiuae/falcon-7b",
"Falcon-40B": "tiiuae/falcon-40b",
"Falcon-7B-Chat": "tiiuae/falcon-7b-instruct",
"Falcon-40B-Chat": "tiiuae/falcon-40b-instruct",
"Baichuan-7B": "baichuan-inc/Baichuan-7B",
"Baichuan-13B": "baichuan-inc/Baichuan-13B-Base",
"Baichuan-13B-Chat": "baichuan-inc/Baichuan-13B-Chat",
"Baichuan2-7B": "baichuan-inc/Baichuan2-7B-Base",
"Baichuan2-13B": "baichuan-inc/Baichuan2-13B-Base",
"Baichuan2-7B-Chat": "baichuan-inc/Baichuan2-7B-Chat",
"Baichuan2-13B-Chat": "baichuan-inc/Baichuan2-13B-Chat",
"InternLM-7B": "internlm/internlm-7b",
"InternLM-20B": "internlm/internlm-20b",
"InternLM-7B-Chat": "internlm/internlm-chat-7b",
"InternLM-20B-Chat": "internlm/internlm-chat-20b",
"LLaMA2-70B-Chat": "meta-llama/Llama-2-70b-chat-hf"
},
template="llama2"
)
register_model_group(
models={
"Mistral-7B": "mistralai/Mistral-7B-v0.1",
"Mistral-7B-Chat": "mistralai/Mistral-7B-Instruct-v0.1"
},
template="mistral"
)
register_model_group(
models={
"OpenChat3.5-7B-Chat": "openchat/openchat_3.5"
},
template="openchat"
)
register_model_group(
models={
"Phi1.5-1.3B": "microsoft/phi-1_5"
},
module="Wqkv"
)
register_model_group(
models={
"Qwen-7B": "Qwen/Qwen-7B",
"Qwen-14B": "Qwen/Qwen-14B",
"Qwen-7B-Chat": "Qwen/Qwen-7B-Chat",
"Qwen-14B-Chat": "Qwen/Qwen-14B-Chat",
"Qwen-14B-Chat": "Qwen/Qwen-14B-Chat"
},
module="c_attn",
template="qwen"
)
register_model_group(
models={
"Skywork-13B-Base": "Skywork/Skywork-13B-base"
}
)
register_model_group(
models={
"Vicuna1.5-7B-Chat": "lmsys/vicuna-7b-v1.5",
"Vicuna1.5-13B-Chat": "lmsys/vicuna-13b-v1.5"
},
template="vicuna"
)
register_model_group(
models={
"XVERSE-7B": "xverse/XVERSE-7B",
"XVERSE-13B": "xverse/XVERSE-13B",
"XVERSE-13B-Chat": "xverse/XVERSE-13B-Chat",
"ChatGLM2-6B-Chat": "THUDM/chatglm2-6b",
"ChatGLM3-6B-Base": "THUDM/chatglm3-6b-base",
"ChatGLM3-6B-Chat": "THUDM/chatglm3-6b",
"Phi1.5-1.3B": "microsoft/phi-1_5"
}
"XVERSE-65B": "xverse/XVERSE-65B",
"XVERSE-7B-Chat": "xverse/XVERSE-7B-Chat",
"XVERSE-13B-Chat": "xverse/XVERSE-13B-Chat"
},
template="xverse"
)
DEFAULT_MODULE = {
"LLaMA": "q_proj,v_proj",
"LLaMA2": "q_proj,v_proj",
"ChineseLLaMA2": "q_proj,v_proj",
"BLOOM": "query_key_value",
"BLOOMZ": "query_key_value",
"Falcon": "query_key_value",
"Baichuan": "W_pack",
"Baichuan2": "W_pack",
"InternLM": "q_proj,v_proj",
"Qwen": "c_attn",
"XVERSE": "q_proj,v_proj",
"ChatGLM2": "query_key_value",
"ChatGLM3": "query_key_value",
"Phi1.5": "Wqkv"
}
DEFAULT_TEMPLATE = {
"LLaMA2": "llama2",
"ChineseLLaMA2": "llama2_zh",
"Baichuan": "baichuan",
"Baichuan2": "baichuan2",
"InternLM": "intern",
"Qwen": "chatml",
"XVERSE": "xverse",
"ChatGLM2": "chatglm2",
"ChatGLM3": "chatglm3"
}
register_model_group(
models={
"Yayi-7B": "wenge-research/yayi-7b-llama2",
"Yayi-13B": "wenge-research/yayi-13b-llama2"
},
template="yayi"
)
register_model_group(
models={
"Yi-6B": "01-ai/Yi-6B",
"Yi-34B": "01-ai/Yi-34B"
}
)
register_model_group(
models={
"Zephyr-7B-Alpha-Chat": "HuggingFaceH4/zephyr-7b-alpha",
"Zephyr-7B-Beta-Chat": "HuggingFaceH4/zephyr-7b-beta"
},
template="zephyr"
)

View File

@@ -3,6 +3,9 @@ import logging
class LoggerHandler(logging.Handler):
r"""
Logger handler used in Web UI.
"""
def __init__(self):
super().__init__()
@@ -19,16 +22,10 @@ class LoggerHandler(logging.Handler):
self.log += "\n\n"
def reset_logging():
r"""
Removes basic config of root logger
"""
root = logging.getLogger()
list(map(root.removeHandler, root.handlers))
list(map(root.removeFilter, root.filters))
def get_logger(name: str) -> logging.Logger:
r"""
Gets a standard logger with a stream hander to stdout.
"""
formatter = logging.Formatter(
fmt="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
datefmt="%m/%d/%Y %H:%M:%S"
@@ -41,3 +38,12 @@ def get_logger(name: str) -> logging.Logger:
logger.addHandler(handler)
return logger
def reset_logging() -> None:
r"""
Removes basic config of root logger. (unused in script)
"""
root = logging.getLogger()
list(map(root.removeHandler, root.handlers))
list(map(root.removeFilter, root.filters))

View File

@@ -1,6 +1,8 @@
import gc
import os
import sys
import torch
from typing import TYPE_CHECKING, Tuple
from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple
from transformers import InfNanRemoveLogitsProcessor, LogitsProcessorList
try:
@@ -11,13 +13,13 @@ try:
is_torch_npu_available
)
_is_fp16_available = is_torch_npu_available() or is_torch_cuda_available()
_is_bf16_available = is_torch_bf16_gpu_available() or is_torch_bf16_cpu_available
_is_bf16_available = is_torch_bf16_gpu_available() or is_torch_bf16_cpu_available()
except ImportError:
_is_fp16_available = torch.cuda.is_available()
_is_bf16_available = torch.cuda.is_bf16_supported()
if TYPE_CHECKING:
from transformers.modeling_utils import PreTrainedModel
from transformers import HfArgumentParser
class AverageMeter:
@@ -62,6 +64,25 @@ def count_parameters(model: torch.nn.Module) -> Tuple[int, int]:
return trainable_params, all_param
def get_current_device() -> str:
import accelerate
from accelerate import Accelerator
dummy_accelerator = Accelerator()
if accelerate.utils.is_xpu_available():
return "xpu:{}".format(dummy_accelerator.local_process_index)
else:
return dummy_accelerator.local_process_index if torch.cuda.is_available() else "cpu"
def get_logits_processor() -> "LogitsProcessorList":
r"""
Gets logits processor that removes NaN and Inf logits.
"""
logits_processor = LogitsProcessorList()
logits_processor.append(InfNanRemoveLogitsProcessor())
return logits_processor
def infer_optim_dtype(model_dtype: torch.dtype) -> torch.dtype:
r"""
Infers the optimal dtype according to the model_dtype and device compatibility.
@@ -74,13 +95,15 @@ def infer_optim_dtype(model_dtype: torch.dtype) -> torch.dtype:
return torch.float32
def get_logits_processor() -> LogitsProcessorList:
r"""
Gets logits processor that removes NaN and Inf logits.
"""
logits_processor = LogitsProcessorList()
logits_processor.append(InfNanRemoveLogitsProcessor())
return logits_processor
def parse_args(parser: "HfArgumentParser", args: Optional[Dict[str, Any]] = None) -> Tuple[Any]:
if args is not None:
return parser.parse_dict(args)
elif len(sys.argv) == 2 and sys.argv[1].endswith(".yaml"):
return parser.parse_yaml_file(os.path.abspath(sys.argv[1]))
elif len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
return parser.parse_json_file(os.path.abspath(sys.argv[1]))
else:
return parser.parse_args_into_dataclasses()
def torch_gc() -> None:
@@ -91,28 +114,3 @@ def torch_gc() -> None:
if torch.cuda.is_available():
torch.cuda.empty_cache()
torch.cuda.ipc_collect()
def dispatch_model(model: "PreTrainedModel") -> "PreTrainedModel":
r"""
Dispatches a pre-trained model to GPUs with balanced memory.
Borrowed from: https://github.com/huggingface/transformers/blob/v4.31.0/src/transformers/modeling_utils.py#L2803
"""
if getattr(model, "is_loaded_in_8bit", False) or getattr(model, "is_loaded_in_4bit", False): # do nothing
return model
if torch.cuda.device_count() > 1:
from accelerate import dispatch_model
from accelerate.utils import infer_auto_device_map, get_balanced_memory
if model._no_split_modules is None:
raise ValueError("The model class needs to implement the `_no_split_modules` attribute.")
kwargs = {"dtype": model.dtype, "no_split_module_classes": model._no_split_modules}
max_memory = get_balanced_memory(model, **kwargs)
# Make sure tied weights are tied before creating the device map.
model.tie_weights()
device_map = infer_auto_device_map(model, max_memory=max_memory, **kwargs)
return dispatch_model(model, device_map)
else:
return model.cuda()

View File

@@ -0,0 +1,55 @@
import importlib.metadata
import importlib.util
def is_package_available(name: str) -> bool:
return importlib.util.find_spec(name) is not None
def get_package_version(name: str) -> str:
try:
return importlib.metadata.version(name)
except:
return "0.0.0"
_fastapi_available = is_package_available("fastapi")
_flash_attn2_available = is_package_available("flash_attn") and get_package_version("flash_attn").startswith("2")
_jieba_available = is_package_available("jieba")
_matplotlib_available = is_package_available("matplotlib")
_nltk_available = is_package_available("nltk")
_rouge_available = is_package_available("rouge-chinese")
_starlette_available = is_package_available("sse-starlette")
_uvicorn_available = is_package_available("uvicorn")
def is_fastapi_availble():
return _fastapi_available
def is_flash_attn2_available():
return _flash_attn2_available
def is_jieba_available():
return _jieba_available
def is_matplotlib_available():
return _matplotlib_available
def is_nltk_available():
return _nltk_available
def is_rouge_available():
return _rouge_available
def is_starlette_available():
return _starlette_available
def is_uvicorn_available():
return _uvicorn_available

View File

@@ -3,13 +3,19 @@ import torch
import torch.nn as nn
from typing import Optional, Tuple
from transformers.utils import logging
from transformers.models.llama.modeling_llama import LlamaAttention, apply_rotary_pos_emb, repeat_kv
from transformers.models.llama.modeling_llama import LlamaAttention, apply_rotary_pos_emb
try:
from transformers.models.llama.modeling_llama import repeat_kv
except ImportError:
print("Please upgrade `transformers`.")
from llmtuner.extras.packages import is_flash_attn2_available
if is_flash_attn2_available():
from flash_attn import flash_attn_func, flash_attn_varlen_func # type: ignore
from flash_attn.bert_padding import pad_input, unpad_input # type: ignore
except ImportError:
print("FlashAttention-2 is not installed, ignore this if you are not using FlashAttention.")
logger = logging.get_logger(__name__)

View File

@@ -1,11 +1,14 @@
import os
import math
import json
import matplotlib.pyplot as plt
from typing import List, Optional
from transformers.trainer import TRAINER_STATE_NAME
from llmtuner.extras.logging import get_logger
from llmtuner.extras.packages import is_matplotlib_available
if is_matplotlib_available():
import matplotlib.pyplot as plt
logger = get_logger(__name__)

View File

@@ -1,4 +1,5 @@
from .data_args import DataArguments
from .evaluation_args import EvaluationArguments
from .finetuning_args import FinetuningArguments
from .generating_args import GeneratingArguments
from .model_args import ModelArguments

View File

@@ -42,7 +42,7 @@ class DataArguments:
)
dataset_dir: Optional[str] = field(
default="data",
metadata={"help": "The name of the folder containing datasets."}
metadata={"help": "Path to the folder containing the datasets."}
)
split: Optional[str] = field(
default="train",
@@ -52,6 +52,10 @@ class DataArguments:
default=1024,
metadata={"help": "The maximum length of the model inputs after tokenization."}
)
reserved_label_len: Optional[int] = field(
default=1,
metadata={"help": "The maximum length reserved for label after tokenization."}
)
train_on_prompt: Optional[bool] = field(
default=False,
metadata={"help": "Whether to disable the mask on the prompt or not."}
@@ -110,6 +114,9 @@ class DataArguments:
)
def __post_init__(self):
if self.reserved_label_len >= self.cutoff_len:
raise ValueError("`reserved_label_len` must be smaller than `cutoff_len`.")
if self.streaming and self.val_size > 1e-6 and self.val_size < 1:
raise ValueError("Streaming mode should have an integer val size.")

View File

@@ -0,0 +1,55 @@
import os
from typing import Literal, Optional
from dataclasses import dataclass, field
from datasets import DownloadMode
@dataclass
class EvaluationArguments:
r"""
Arguments pertaining to specify the evaluation parameters.
"""
task: str = field(
metadata={"help": "Name of the evaluation task."}
)
task_dir: Optional[str] = field(
default="evaluation",
metadata={"help": "Path to the folder containing the evaluation datasets."}
)
batch_size: Optional[int] = field(
default=4,
metadata={"help": "The batch size per GPU for evaluation."}
)
seed: Optional[int] = field(
default=42,
metadata={"help": "Random seed to be used with data loaders."}
)
lang: Optional[Literal["en", "zh"]] = field(
default="en",
metadata={"help": "Language used at evaluation."}
)
n_shot: Optional[int] = field(
default=5,
metadata={"help": "Number of examplars for few-shot learning."}
)
save_dir: Optional[str] = field(
default=None,
metadata={"help": "Path to save the evaluation results."}
)
download_mode: Optional[DownloadMode] = field(
default=DownloadMode.REUSE_DATASET_IF_EXISTS,
metadata={"help": "Download mode used for the evaluation datasets."}
)
def __post_init__(self):
task_available = []
for folder in os.listdir(self.task_dir):
if os.path.isdir(os.path.join(self.task_dir, folder)):
task_available.append(folder)
if self.task not in task_available:
raise ValueError("Task {} not found in {}.".format(self.task, self.task_dir))
if self.save_dir is not None and os.path.exists(self.save_dir):
raise ValueError("`save_dir` already exists, use another one.")

View File

@@ -4,38 +4,38 @@ from dataclasses import asdict, dataclass, field
@dataclass
class FinetuningArguments:
class FreezeArguments:
r"""
Arguments pertaining to which techniques we are going to fine-tuning with.
Arguments pertaining to the freeze (partial-parameter) training.
"""
stage: Optional[Literal["pt", "sft", "rm", "ppo", "dpo"]] = field(
default="sft",
metadata={"help": "Which stage will be performed in training."}
)
finetuning_type: Optional[Literal["lora", "freeze", "full", "none"]] = field(
default="lora",
metadata={"help": "Which fine-tuning method to use."}
)
num_layer_trainable: Optional[int] = field(
default=3,
metadata={"help": "Number of trainable layers for partial-parameter (freeze) fine-tuning."}
)
name_module_trainable: Optional[Literal["mlp", "self_attn", "self_attention"]] = field(
name_module_trainable: Optional[str] = field(
default="mlp",
metadata={"help": "Name of trainable modules for partial-parameter (freeze) fine-tuning. \
Use commas to separate multiple modules. \
LLaMA choices: [\"mlp\", \"self_attn\"], \
BLOOM & Falcon & ChatGLM choices: [\"mlp\", \"self_attention\"], \
Qwen choices: [\"mlp\", \"attn\"], \
Phi-1.5 choices: [\"mlp\", \"mixer\"], \
LLaMA-2, BlueLM, Baichuan, InternLM, Mistral, Skywork, XVERSE, Yi choices: the same as LLaMA."}
Others choices: the same as LLaMA."}
)
@dataclass
class LoraArguments:
r"""
Arguments pertaining to the LoRA training.
"""
lora_rank: Optional[int] = field(
default=8,
metadata={"help": "The intrinsic dimension for LoRA fine-tuning."}
)
lora_alpha: Optional[float] = field(
default=32.0,
metadata={"help": "The scale factor for LoRA fine-tuning (similar with the learning rate)."}
default=None,
metadata={"help": "The scale factor for LoRA fine-tuning (default: lora_rank * 2.0)."}
)
lora_dropout: Optional[float] = field(
default=0.1,
@@ -45,11 +45,11 @@ class FinetuningArguments:
default=None,
metadata={"help": "Name(s) of target modules to apply LoRA. Use commas to separate multiple modules. \
LLaMA choices: [\"q_proj\", \"k_proj\", \"v_proj\", \"o_proj\", \"gate_proj\", \"up_proj\", \"down_proj\"], \
BLOOM & Falcon & ChatGLM choices: [\"query_key_value\", \"self_attention.dense\", \"mlp.dense\"], \
BLOOM & Falcon & ChatGLM choices: [\"query_key_value\", \"dense\", \"dense_h_to_4h\", \"dense_4h_to_h\"], \
Baichuan choices: [\"W_pack\", \"o_proj\", \"gate_proj\", \"up_proj\", \"down_proj\"], \
Qwen choices: [\"c_attn\", \"attn.c_proj\", \"w1\", \"w2\", \"mlp.c_proj\"], \
Phi-1.5 choices: [\"Wqkv\", \"out_proj\", \"fc1\", \"fc2\"], \
LLaMA-2, BlueLM, InternLM, Mistral, Skywork, XVERSE, Yi choices: the same as LLaMA."}
Others choices: the same as LLaMA."}
)
additional_target: Optional[str] = field(
default=None,
@@ -59,30 +59,76 @@ class FinetuningArguments:
default=True,
metadata={"help": "Whether to resume training from the last LoRA weights or create new weights after merging them."}
)
ppo_score_norm: Optional[bool] = field(
default=False,
metadata={"help": "Use score normalization in PPO training."}
@dataclass
class RLHFArguments:
r"""
Arguments pertaining to the PPO and DPO training.
"""
dpo_beta: Optional[float] = field(
default=0.1,
metadata={"help": "The beta parameter for the DPO loss."}
)
ppo_logger: Optional[str] = field(
default=None,
metadata={"help": "Log with either 'wandb' or 'tensorboard' in PPO training."}
)
ppo_score_norm: Optional[bool] = field(
default=False,
metadata={"help": "Use score normalization in PPO training."}
)
ppo_target: Optional[float] = field(
default=6.0,
metadata={"help": "Target KL value for adaptive KL control in PPO training."}
)
dpo_beta: Optional[float] = field(
default=0.1,
metadata={"help": "The beta parameter for the DPO loss."}
ppo_whiten_rewards: Optional[bool] = field(
default=False,
metadata={"help": "Whiten the rewards before compute advantages in PPO training."}
)
dpo_ref_model: Optional[str] = field(
ref_model: Optional[str] = field(
default=None,
metadata={"help": "Path to the reference model used for the DPO training."}
metadata={"help": "Path to the reference model used for the PPO or DPO training."}
)
dpo_ref_model_checkpoint: Optional[str] = field(
ref_model_checkpoint: Optional[str] = field(
default=None,
metadata={"help": "Path to the directory(s) containing the model checkpoints of the reference model."}
)
ref_model_quantization_bit: Optional[int] = field(
default=None,
metadata={"help": "The number of bits to quantize the reference model."}
)
reward_model: Optional[str] = field(
default=None,
metadata={"help": "Path to the directory containing the checkpoints of the reward model."}
)
reward_model_checkpoint: Optional[str] = field(
default=None,
metadata={"help": "Path to the directory(s) containing the model checkpoints of the reward model."}
)
reward_model_quantization_bit: Optional[int] = field(
default=None,
metadata={"help": "The number of bits to quantize the reward model."}
)
reward_model_type: Optional[Literal["lora", "full"]] = field(
default="lora",
metadata={"help": "The checkpoint type of the reward model. The lora type only supports lora training."}
)
@dataclass
class FinetuningArguments(FreezeArguments, LoraArguments, RLHFArguments):
r"""
Arguments pertaining to which techniques we are going to fine-tuning with.
"""
stage: Optional[Literal["pt", "sft", "rm", "ppo", "dpo"]] = field(
default="sft",
metadata={"help": "Which stage will be performed in training."}
)
finetuning_type: Optional[Literal["lora", "freeze", "full"]] = field(
default="lora",
metadata={"help": "Which fine-tuning method to use."}
)
upcast_layernorm: Optional[bool] = field(
default=False,
metadata={"help": "Whether to upcast the layernorm weights in fp32."}
@@ -91,15 +137,37 @@ class FinetuningArguments:
default=0,
metadata={"help": "The alpha parameter to control the noise magnitude in NEFTune."}
)
export_dir: Optional[str] = field(
default=None,
metadata={"help": "Path to the directory to save the exported model."}
)
plot_loss: Optional[bool] = field(
default=False,
metadata={"help": "Whether to plot the training loss after fine-tuning or not."}
)
def __post_init__(self):
if isinstance(self.lora_target, str): # support custom target modules/layers of LoRA
self.lora_target = [target.strip() for target in self.lora_target.split(",")]
def split_arg(arg):
if isinstance(arg, str):
return [item.strip() for item in arg.split(",")]
return arg
if isinstance(self.additional_target, str):
self.additional_target = [target.strip() for target in self.additional_target.split(",")]
self.name_module_trainable = split_arg(self.name_module_trainable)
self.lora_alpha = self.lora_alpha or float(self.lora_rank * 2.0)
self.lora_target = split_arg(self.lora_target)
self.additional_target = split_arg(self.additional_target)
self.ref_model_checkpoint = split_arg(self.ref_model_checkpoint)
self.reward_model_checkpoint = split_arg(self.reward_model_checkpoint)
assert self.finetuning_type in ["lora", "freeze", "full"], "Invalid fine-tuning method."
assert self.ref_model_quantization_bit in [None, 8, 4], "We only accept 4-bit or 8-bit quantization."
assert self.reward_model_quantization_bit in [None, 8, 4], "We only accept 4-bit or 8-bit quantization."
if self.stage == "ppo" and self.reward_model is None:
raise ValueError("Reward model is necessary for PPO training.")
if self.stage == "ppo" and self.reward_model_type == "lora" and self.finetuning_type != "lora":
raise ValueError("Lora reward model only supports lora training.")
def save_to_json(self, json_path: str):
r"""Saves the content of this instance in JSON format inside `json_path`."""
@@ -112,4 +180,5 @@ class FinetuningArguments:
r"""Creates an instance from the content of `json_path`."""
with open(json_path, "r", encoding="utf-8") as f:
text = f.read()
return cls(**json.loads(text))

View File

@@ -54,22 +54,10 @@ class ModelArguments:
default=False,
metadata={"help": "Enable shift short attention (S^2-Attn) proposed by LongLoRA."}
)
reward_model: Optional[str] = field(
default=None,
metadata={"help": "Path to the directory containing the checkpoints of the reward model."}
)
plot_loss: Optional[bool] = field(
default=False,
metadata={"help": "Whether to plot the training loss after fine-tuning or not."}
)
hf_hub_token: Optional[str] = field(
default=None,
metadata={"help": "Auth token to log in with Hugging Face Hub."}
)
export_dir: Optional[str] = field(
default=None,
metadata={"help": "Path to the directory to save the exported model."}
)
def __post_init__(self):
self.compute_dtype = None
@@ -81,8 +69,7 @@ class ModelArguments:
if self.checkpoint_dir is not None: # support merging multiple lora weights
self.checkpoint_dir = [cd.strip() for cd in self.checkpoint_dir.split(",")]
if self.quantization_bit is not None:
assert self.quantization_bit in [4, 8], "We only accept 4-bit or 8-bit quantization."
assert self.quantization_bit in [None, 8, 4], "We only accept 4-bit or 8-bit quantization."
def to_dict(self) -> Dict[str, Any]:
return asdict(self)

View File

@@ -0,0 +1,5 @@
# Level: loader > adapter > parser, utils
from llmtuner.model.loader import load_model_and_tokenizer
from llmtuner.model.parser import get_train_args, get_infer_args, get_eval_args
from llmtuner.model.utils import dispatch_model, generate_model_card, load_valuehead_params

View File

@@ -1,18 +1,9 @@
import os
import torch
from typing import TYPE_CHECKING
from transformers.utils import cached_file
from transformers.trainer import WEIGHTS_NAME, SAFE_WEIGHTS_NAME
from peft import (
PeftModel,
TaskType,
LoraConfig,
get_peft_model
)
from peft import PeftModel, TaskType, LoraConfig, get_peft_model
from llmtuner.extras.logging import get_logger
from llmtuner.tuner.core.utils import find_all_linear_modules
from llmtuner.model.utils import find_all_linear_modules
if TYPE_CHECKING:
from transformers.modeling_utils import PreTrainedModel
@@ -38,20 +29,31 @@ def init_adapter(
if (not is_trainable) and model_args.checkpoint_dir is None:
logger.info("Checkpoint is not found at evaluation, load the original model.")
return model
if finetuning_args.finetuning_type == "full" and is_trainable:
logger.info("Fine-tuning method: Full")
model = model.float()
if finetuning_args.finetuning_type == "freeze":
if finetuning_args.finetuning_type == "freeze" and is_trainable:
logger.info("Fine-tuning method: Freeze")
num_layers = getattr(model.config, "num_layers")
num_layers = (
getattr(model.config, "num_hidden_layers", None)
or getattr(model.config, "num_layers", None)
or getattr(model.config, "n_layer", None)
)
if not num_layers:
raise ValueError("Current model does not support freeze tuning.")
if finetuning_args.num_layer_trainable > 0: # fine-tuning the last n layers if num_layer_trainable > 0
trainable_layer_ids = [num_layers - k - 1 for k in range(finetuning_args.num_layer_trainable)]
else: # fine-tuning the first n layers if num_layer_trainable < 0
trainable_layer_ids = [k for k in range(-finetuning_args.num_layer_trainable)]
trainable_layers = ["{:d}.{}".format(idx, finetuning_args.name_module_trainable) for idx in trainable_layer_ids]
trainable_layers = []
for module_name in finetuning_args.name_module_trainable:
for idx in trainable_layer_ids:
trainable_layers.append("{:d}.{}".format(idx, module_name))
for name, param in model.named_parameters():
if not any(trainable_layer in name for trainable_layer in trainable_layers):
param.requires_grad_(False)
@@ -99,30 +101,3 @@ def init_adapter(
logger.info("Loaded fine-tuned model from checkpoint(s): {}".format(",".join(model_args.checkpoint_dir)))
return model
def load_valuehead_params(
model: "PreTrainedModel",
model_args: "ModelArguments"
) -> bool:
kwargs = {
"path_or_repo_id": model_args.reward_model,
"cache_dir": model_args.cache_dir,
"token": model_args.hf_hub_token,
"revision": model_args.model_revision
}
try:
vhead_file = cached_file(filename=WEIGHTS_NAME, **kwargs)
except:
try:
vhead_file = cached_file(filename=SAFE_WEIGHTS_NAME, **kwargs)
except:
logger.warning("Provided path ({}) does not contain valuehead weights.".format(model_args.reward_model))
return False
vhead_params = torch.load(vhead_file, map_location="cpu")
model.register_buffer("reward_head_weight", vhead_params["v_head.summary.weight"], persistent=False)
model.register_buffer("reward_head_bias", vhead_params["v_head.summary.bias"], persistent=False)
model.register_buffer("default_head_weight", torch.zeros_like(vhead_params["v_head.summary.weight"]), persistent=False)
model.register_buffer("default_head_bias", torch.zeros_like(vhead_params["v_head.summary.bias"]), persistent=False)
return True

View File

@@ -15,7 +15,6 @@ from transformers import (
)
from transformers.models.llama import modeling_llama as LlamaModule
from transformers.utils.versions import require_version
from peft import PeftModel
from trl import AutoModelForCausalLMWithValueHead
try:
@@ -24,11 +23,12 @@ except ImportError: # https://github.com/huggingface/transformers/releases/tag/v
from transformers.deepspeed import is_deepspeed_zero3_enabled
from llmtuner.extras.logging import reset_logging, get_logger
from llmtuner.extras.misc import count_parameters, infer_optim_dtype
from llmtuner.extras.misc import count_parameters, get_current_device, infer_optim_dtype
from llmtuner.extras.packages import is_flash_attn2_available
from llmtuner.extras.patches import llama_patch as LlamaPatches
from llmtuner.hparams import FinetuningArguments
from llmtuner.tuner.core.adapter import init_adapter, load_valuehead_params
from llmtuner.tuner.core.utils import prepare_model_for_training
from llmtuner.model.adapter import init_adapter
from llmtuner.model.utils import load_valuehead_params, prepare_model_for_training
if TYPE_CHECKING:
from transformers import PreTrainedTokenizer
@@ -42,7 +42,7 @@ require_version("transformers>=4.31.0,<4.35.0", "To fix: pip install \"transform
require_version("datasets>=2.14.0", "To fix: pip install datasets>=2.14.0")
require_version("accelerate>=0.21.0", "To fix: pip install accelerate>=0.21.0")
require_version("peft>=0.6.0", "To fix: pip install peft>=0.6.0")
require_version("trl==0.7.2", "To fix: pip install trl==0.7.2")
require_version("trl>=0.7.4", "To fix: pip install trl>=0.7.4")
def load_model_and_tokenizer(
@@ -73,6 +73,7 @@ def load_model_and_tokenizer(
)
if finetuning_args.finetuning_type != "lora" and model_args.checkpoint_dir is not None:
logger.info("Use `model_name_or_path` to specify the model trained with full/freeze method.")
model_to_load = model_args.checkpoint_dir[0]
else:
model_to_load = model_args.model_name_or_path
@@ -84,10 +85,9 @@ def load_model_and_tokenizer(
tokenizer._pad = MethodType(PreTrainedTokenizerBase._pad, tokenizer)
# Set model dtype
if model_args.compute_dtype is not None: # for training
setattr(config, "torch_dtype", model_args.compute_dtype)
else: # for evaluation, priority: bf16 > fp16 > fp32
if model_args.compute_dtype is None: # priority: bf16 > fp16 > fp32
model_args.compute_dtype = infer_optim_dtype(model_dtype=getattr(config, "torch_dtype", None))
setattr(config, "torch_dtype", model_args.compute_dtype)
# Fix config (for Qwen)
if getattr(config, "model_type", None) == "qwen":
@@ -123,13 +123,16 @@ def load_model_and_tokenizer(
# Set FlashAttention-2
if model_args.flash_attn:
if getattr(config, "model_type", None) == "llama":
if is_flash_attn2_available():
LlamaModule.LlamaAttention = LlamaPatches.LlamaFlashAttention2
LlamaModule.LlamaModel._prepare_decoder_attention_mask = LlamaPatches._prepare_decoder_attention_mask
logger.info("Using FlashAttention-2 for faster training and inference.")
else:
logger.warning("FlashAttention-2 is not installed.")
elif getattr(config, "model_type", None) in ["qwen", "Yi"]:
logger.info("Current model automatically enables FlashAttention if installed.")
else:
logger.warning("Current model does not support FlashAttention-2.")
logger.warning("Current model does not support FlashAttention.")
elif is_trainable and model_args.shift_attn and getattr(config, "model_type", None) == "llama":
LlamaModule.LlamaAttention = LlamaPatches.LlamaShiftShortAttention
logger.warning("Using `--flash_attn` for faster training in large context length.")
@@ -142,7 +145,7 @@ def load_model_and_tokenizer(
else:
logger.warning("Current model does not support shift short attention.")
# Quantization configurations (using bitsandbytes library).
# Quantization configurations (using bitsandbytes library)
if model_args.quantization_bit is not None:
if is_deepspeed_zero3_enabled():
raise ValueError("DeepSpeed ZeRO-3 is incompatible with quantization.")
@@ -162,10 +165,10 @@ def load_model_and_tokenizer(
bnb_4bit_quant_type=model_args.quantization_type
)
config_kwargs["device_map"] = {"": int(os.environ.get("LOCAL_RANK", "0"))} if is_trainable else "auto"
config_kwargs["device_map"] = {"": get_current_device()}
logger.info("Quantizing model to {} bit.".format(model_args.quantization_bit))
# Load and prepare pre-trained models (without valuehead).
# Load pre-trained models (without valuehead)
model = AutoModelForCausalLM.from_pretrained(
model_to_load,
config=config,
@@ -183,7 +186,7 @@ def load_model_and_tokenizer(
setattr(model, "lm_head", model.transformer.output_layer)
setattr(model, "_keys_to_ignore_on_save", ["lm_head.weight"])
# Register auto class to save the custom code files.
# Register auto class to save the custom code files
if isinstance(config, PretrainedConfig) and "AutoConfig" in getattr(config, "auto_map", {}):
config.__class__.register_for_auto_class()
if isinstance(model, PreTrainedModel) and "AutoModelForCausalLM" in getattr(config, "auto_map", {}):
@@ -197,25 +200,15 @@ def load_model_and_tokenizer(
model = model.train() if is_trainable else model.eval()
# Prepare model with valuehead for RLHF
if stage == "rm" or stage == "ppo":
if stage in ["rm", "ppo"]:
model: "AutoModelForCausalLMWithValueHead" = AutoModelForCausalLMWithValueHead.from_pretrained(model)
reset_logging()
if stage == "rm" and model_args.checkpoint_dir is not None: # load valuehead weights to evaluate reward model
logger.warning("Only the last checkpoint containing valuehead will be loaded.")
if load_valuehead_params(model, model_args):
model.v_head.load_state_dict({
"summary.weight": getattr(model, "reward_head_weight"),
"summary.bias": getattr(model, "reward_head_bias")
})
if stage == "ppo": # load reward model
logger.info("Load reward model from {}".format(model_args.reward_model))
if isinstance(model.pretrained_model, PeftModel):
model.pretrained_model.load_adapter(model_args.reward_model, "reward")
for name, param in model.named_parameters(): # https://github.com/huggingface/peft/issues/1090
if "default" in name:
param.data = param.data.to(torch.float32) # trainable params should in fp32
assert load_valuehead_params(model, model_args), "Reward model is not correctly loaded."
vhead_path = (
model_args.checkpoint_dir[-1] if model_args.checkpoint_dir is not None else model_args.model_name_or_path
)
vhead_params = load_valuehead_params(vhead_path, model_args)
if vhead_params is not None:
model.load_state_dict(vhead_params, strict=False)
logger.info("Loaded valuehead from checkpoint: {}".format(vhead_path))
# Prepare model for inference
if not is_trainable:

View File

@@ -1,5 +1,4 @@
import os
import sys
import torch
import datasets
import transformers
@@ -8,9 +7,11 @@ from transformers import HfArgumentParser, Seq2SeqTrainingArguments
from transformers.trainer_utils import get_last_checkpoint
from llmtuner.extras.logging import get_logger
from llmtuner.extras.misc import parse_args
from llmtuner.hparams import (
ModelArguments,
DataArguments,
EvaluationArguments,
FinetuningArguments,
GeneratingArguments
)
@@ -19,62 +20,42 @@ from llmtuner.hparams import (
logger = get_logger(__name__)
def _parse_args(parser: HfArgumentParser, args: Optional[Dict[str, Any]] = None) -> Tuple[Any]:
if args is not None:
return parser.parse_dict(args)
elif len(sys.argv) == 2 and sys.argv[1].endswith(".yaml"):
return parser.parse_yaml_file(os.path.abspath(sys.argv[1]))
elif len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
return parser.parse_json_file(os.path.abspath(sys.argv[1]))
else:
return parser.parse_args_into_dataclasses()
_TRAIN_ARGS = [
ModelArguments, DataArguments, Seq2SeqTrainingArguments, FinetuningArguments, GeneratingArguments
]
_TRAIN_CLS = Tuple[
ModelArguments, DataArguments, Seq2SeqTrainingArguments, FinetuningArguments, GeneratingArguments
]
_INFER_ARGS = [
ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments
]
_INFER_CLS = Tuple[
ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments
]
_EVAL_ARGS = [
ModelArguments, DataArguments, EvaluationArguments, FinetuningArguments
]
_EVAL_CLS = Tuple[
ModelArguments, DataArguments, EvaluationArguments, FinetuningArguments
]
def parse_train_args(
args: Optional[Dict[str, Any]] = None
) -> Tuple[
ModelArguments,
DataArguments,
Seq2SeqTrainingArguments,
FinetuningArguments,
GeneratingArguments
]:
parser = HfArgumentParser((
ModelArguments,
DataArguments,
Seq2SeqTrainingArguments,
FinetuningArguments,
GeneratingArguments
))
return _parse_args(parser, args)
def parse_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
parser = HfArgumentParser(_TRAIN_ARGS)
return parse_args(parser, args)
def parse_infer_args(
args: Optional[Dict[str, Any]] = None
) -> Tuple[
ModelArguments,
DataArguments,
FinetuningArguments,
GeneratingArguments
]:
parser = HfArgumentParser((
ModelArguments,
DataArguments,
FinetuningArguments,
GeneratingArguments
))
return _parse_args(parser, args)
def parse_infer_args(args: Optional[Dict[str, Any]] = None) -> _INFER_CLS:
parser = HfArgumentParser(_INFER_ARGS)
return parse_args(parser, args)
def get_train_args(
args: Optional[Dict[str, Any]] = None
) -> Tuple[
ModelArguments,
DataArguments,
Seq2SeqTrainingArguments,
FinetuningArguments,
GeneratingArguments
]:
def parse_eval_args(args: Optional[Dict[str, Any]] = None) -> _EVAL_CLS:
parser = HfArgumentParser(_EVAL_ARGS)
return parse_args(parser, args)
def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
model_args, data_args, training_args, finetuning_args, generating_args = parse_train_args(args)
# Setup logging
@@ -101,24 +82,19 @@ def get_train_args(
raise ValueError("Please enable `predict_with_generate` to save model predictions.")
if finetuning_args.stage in ["rm", "ppo"]:
if finetuning_args.finetuning_type != "lora":
raise ValueError("RM and PPO stages can only be performed with the LoRA method.")
if training_args.resume_from_checkpoint is not None:
raise ValueError("RM and PPO stages do not support `resume_from_checkpoint`.")
if training_args.load_best_model_at_end:
raise ValueError("RM and PPO stages do not support `load_best_model_at_end`.")
if finetuning_args.stage == "ppo" and not training_args.do_train:
raise ValueError("PPO training does not support evaluation.")
raise ValueError("PPO training does not support evaluation, use the SFT stage to evaluate models.")
if finetuning_args.stage in ["rm", "dpo"]:
for dataset_attr in data_args.dataset_list:
if not dataset_attr.ranking:
raise ValueError("Please use ranked datasets for reward modeling or DPO training.")
if finetuning_args.stage == "ppo" and model_args.reward_model is None:
raise ValueError("Reward model is necessary for PPO training.")
if finetuning_args.stage == "ppo" and model_args.shift_attn:
raise ValueError("PPO training is incompatible with S^2-Attn.")
@@ -150,6 +126,9 @@ def get_train_args(
if (not training_args.do_train) and model_args.quantization_bit is not None:
logger.warning("Evaluating model in 4/8-bit mode may cause lower scores.")
if (not training_args.do_train) and finetuning_args.stage == "dpo" and finetuning_args.ref_model is None:
logger.warning("Specify `ref_model` for computing rewards at evaluation.")
# postprocess training_args
if (
training_args.local_rank != -1
@@ -198,14 +177,7 @@ def get_train_args(
return model_args, data_args, training_args, finetuning_args, generating_args
def get_infer_args(
args: Optional[Dict[str, Any]] = None
) -> Tuple[
ModelArguments,
DataArguments,
FinetuningArguments,
GeneratingArguments
]:
def get_infer_args(args: Optional[Dict[str, Any]] = None) -> _INFER_CLS:
model_args, data_args, finetuning_args, generating_args = parse_infer_args(args)
if data_args.template is None:
@@ -222,3 +194,17 @@ def get_infer_args(
raise ValueError("Only LoRA tuning accepts multiple checkpoints.")
return model_args, data_args, finetuning_args, generating_args
def get_eval_args(args: Optional[Dict[str, Any]] = None) -> _EVAL_CLS:
model_args, data_args, eval_args, finetuning_args = parse_eval_args(args)
if data_args.template is None:
raise ValueError("Please specify which `template` to use.")
if model_args.quantization_bit is not None and finetuning_args.finetuning_type != "lora":
raise ValueError("Quantization is only compatible with the LoRA method.")
transformers.set_seed(eval_args.seed)
return model_args, data_args, eval_args, finetuning_args

View File

@@ -1,21 +1,53 @@
import torch
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple
from transformers.utils import cached_file
from transformers.trainer import WEIGHTS_NAME, SAFE_WEIGHTS_NAME
from llmtuner.extras.constants import LAYERNORM_NAMES
from llmtuner.extras.logging import get_logger
from llmtuner.hparams import ModelArguments, FinetuningArguments
if TYPE_CHECKING:
from transformers.modeling_utils import PreTrainedModel
from llmtuner.hparams import ModelArguments, DataArguments, FinetuningArguments
from llmtuner.hparams import DataArguments
logger = get_logger(__name__)
def dispatch_model(model: "PreTrainedModel") -> "PreTrainedModel":
r"""
Dispatches a pre-trained model to GPUs with balanced memory.
Borrowed from: https://github.com/huggingface/transformers/blob/v4.31.0/src/transformers/modeling_utils.py#L2803
"""
if getattr(model, "is_loaded_in_8bit", False) or getattr(model, "is_loaded_in_4bit", False): # do nothing
return model
if torch.cuda.device_count() > 1:
from accelerate import dispatch_model
from accelerate.utils import infer_auto_device_map, get_balanced_memory
if model._no_split_modules is None:
raise ValueError("The model class needs to implement the `_no_split_modules` attribute.")
kwargs = {"dtype": model.dtype, "no_split_module_classes": model._no_split_modules}
max_memory = get_balanced_memory(model, **kwargs)
# Make sure tied weights are tied before creating the device map.
model.tie_weights()
device_map = infer_auto_device_map(model, max_memory=max_memory, **kwargs)
return dispatch_model(model, device_map)
else:
return model.cuda()
def find_all_linear_modules(
model: "PreTrainedModel",
quantization_bit: Optional[int] = None
) -> List[str]:
r"""
Finds all available modules to apply lora.
"""
if quantization_bit is not None:
import bitsandbytes as bnb
linear_cls = bnb.nn.Linear4bit if quantization_bit == 4 else bnb.nn.Linear8bitLt
@@ -51,12 +83,38 @@ def generate_model_card(
}
def load_valuehead_params(
path_or_repo_id: str,
model_args: "ModelArguments"
) -> Dict[str, torch.Tensor]:
r"""
Loads value head parameters from Hugging Face Hub or local disk.
Returns: dict with keys `v_head.summary.weight` and `v_head.summary.bias`.
"""
kwargs = {
"path_or_repo_id": path_or_repo_id,
"cache_dir": model_args.cache_dir,
"token": model_args.hf_hub_token
}
try:
vhead_file = cached_file(filename=WEIGHTS_NAME, **kwargs)
except:
try:
vhead_file = cached_file(filename=SAFE_WEIGHTS_NAME, **kwargs)
except:
logger.warning("Provided path ({}) does not contain valuehead weights.".format(path_or_repo_id))
return None
return torch.load(vhead_file, map_location="cpu")
def prepare_model_for_training(
model: "PreTrainedModel",
finetuning_args: "FinetuningArguments",
output_layer_name: Optional[str] = "lm_head",
use_gradient_checkpointing: Optional[bool] = True,
layernorm_names: Optional[List[str]] = LAYERNORM_NAMES
layernorm_names: Optional[Set[str]] = LAYERNORM_NAMES
) -> "PreTrainedModel":
r"""
Includes:

View File

@@ -0,0 +1 @@
from llmtuner.train.tuner import export_model, run_exp

View File

@@ -0,0 +1 @@
from llmtuner.train.dpo.workflow import run_dpo

View File

@@ -1,6 +1,4 @@
import torch
import deepspeed # type: ignore
from copy import deepcopy
from collections import defaultdict
from typing import TYPE_CHECKING, Dict, Literal, Optional, Tuple, Union
from transformers import BatchEncoding, Trainer
@@ -11,7 +9,6 @@ from llmtuner.extras.constants import IGNORE_INDEX
if TYPE_CHECKING:
from transformers import PreTrainedModel
from trl import PreTrainedModelWrapper
class CustomDPOTrainer(DPOTrainer):
@@ -46,40 +43,14 @@ class CustomDPOTrainer(DPOTrainer):
if ref_model is not None:
if self.is_deepspeed_enabled:
if not (
getattr(ref_model, "is_loaded_in_8bit", False)
or getattr(ref_model, "is_loaded_in_4bit", False)
): # quantized models are already set on the correct device
self.ref_model = self._prepare_deepspeed(self.ref_model)
else:
self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True)
def _prepare_deepspeed(self, model: "PreTrainedModelWrapper"):
# Adapted from accelerate: https://github.com/huggingface/accelerate/blob/739b135f8367becb67ffaada12fe76e3aa60fefd/src/accelerate/accelerator.py#L1473
deepspeed_plugin = self.accelerator.state.deepspeed_plugin
config_kwargs = deepcopy(deepspeed_plugin.deepspeed_config)
if model is not None:
if hasattr(model, "config"):
hidden_size = (
max(model.config.hidden_sizes)
if getattr(model.config, "hidden_sizes", None)
else getattr(model.config, "hidden_size", None)
)
if hidden_size is not None and config_kwargs["zero_optimization"]["stage"] == 3:
# Note that `stage3_prefetch_bucket_size` can produce DeepSpeed messages like: `Invalidate trace cache @ step 0: expected module 1, but got module 0`
# This is expected and is not an error, see: https://github.com/microsoft/DeepSpeed/discussions/4081
config_kwargs.update(
{
"zero_optimization.reduce_bucket_size": hidden_size * hidden_size,
"zero_optimization.stage3_param_persistence_threshold": 10 * hidden_size,
"zero_optimization.stage3_prefetch_bucket_size": 0.9 * hidden_size * hidden_size,
}
)
# If ZeRO-3 is used, we shard both the active and reference model.
# Otherwise, we assume the reference model fits in memory and is initialized on each device with ZeRO disabled (stage 0)
if config_kwargs["zero_optimization"]["stage"] != 3:
config_kwargs["zero_optimization"]["stage"] = 0
model, *_ = deepspeed.initialize(model=model, config=config_kwargs)
model.eval()
return model
def concatenated_forward(
self,
model: Optional[torch.nn.Module] = None,

View File

@@ -4,23 +4,20 @@ from peft import PeftModel
from typing import TYPE_CHECKING, Optional, List
from transformers import Seq2SeqTrainingArguments
from llmtuner.dsets import get_dataset, preprocess_dataset, split_dataset
from llmtuner.data import get_dataset, preprocess_dataset, split_dataset
from llmtuner.extras.constants import IGNORE_INDEX
from llmtuner.extras.logging import get_logger
from llmtuner.extras.ploting import plot_loss
from llmtuner.hparams import ModelArguments
from llmtuner.tuner.core import generate_model_card, load_model_and_tokenizer
from llmtuner.tuner.dpo.collator import DPODataCollatorWithPadding
from llmtuner.tuner.dpo.trainer import CustomDPOTrainer
from llmtuner.model import generate_model_card, load_model_and_tokenizer
from llmtuner.train.utils import create_ref_model
from llmtuner.train.dpo.collator import DPODataCollatorWithPadding
from llmtuner.train.dpo.trainer import CustomDPOTrainer
if TYPE_CHECKING:
from transformers import TrainerCallback
from llmtuner.hparams import DataArguments, FinetuningArguments
logger = get_logger(__name__)
def run_dpo(
model_args: "ModelArguments",
data_args: "DataArguments",
@@ -38,23 +35,10 @@ def run_dpo(
)
# Create reference model
if finetuning_args.dpo_ref_model is not None:
ref_model_args_dict = model_args.to_dict()
ref_model_args_dict.update(dict(
model_name_or_path=finetuning_args.dpo_ref_model,
checkpoint_dir=finetuning_args.dpo_ref_model_checkpoint
))
ref_model_args = ModelArguments(**ref_model_args_dict)
ref_model, _ = load_model_and_tokenizer(ref_model_args, finetuning_args, is_trainable=False, stage="sft")
logger.info("Created reference model from {}".format(finetuning_args.dpo_ref_model))
elif training_args.do_train:
if isinstance(model, PeftModel):
ref_model = None
else:
ref_model, _ = load_model_and_tokenizer(model_args, finetuning_args, is_trainable=False, stage="sft")
logger.info("Created reference model from the model itself.")
else:
if finetuning_args.ref_model is None and (not training_args.do_train): # use the model itself
ref_model = model
else:
ref_model = create_ref_model(model_args, finetuning_args, stage="dpo")
# Update arguments
training_args_dict = training_args.to_dict()
@@ -80,14 +64,13 @@ def run_dpo(
trainer.log_metrics("train", train_result.metrics)
trainer.save_metrics("train", train_result.metrics)
trainer.save_state()
if trainer.is_world_process_zero() and model_args.plot_loss:
if trainer.is_world_process_zero() and finetuning_args.plot_loss:
plot_loss(training_args.output_dir, keys=["loss", "eval_loss"])
# Evaluation
if training_args.do_eval:
metrics = trainer.evaluate(metric_key_prefix="eval")
if id(model) == id(ref_model): # unable to compute rewards without a reference model
logger.warning("Pass `dpo_ref_model` for computing rewards at evaluation.")
remove_keys = [key for key in metrics.keys() if "rewards" in key]
for key in remove_keys:
metrics.pop(key)

View File

@@ -0,0 +1 @@
from llmtuner.train.ppo.workflow import run_ppo

View File

@@ -3,9 +3,9 @@ import sys
import math
import torch
from tqdm import tqdm
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple
from typing import TYPE_CHECKING, List, Optional, Tuple
from transformers import GenerationConfig, Trainer, TrainerState, TrainerControl
from transformers import BatchEncoding, GenerationConfig, Trainer, TrainerState, TrainerControl
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR
from trl import PPOTrainer
@@ -14,7 +14,7 @@ from trl.core import PPODecorators, logprobs_from_logits
from llmtuner.extras.callbacks import LogCallback, SavePeftModelCallback
from llmtuner.extras.logging import get_logger
from llmtuner.extras.misc import AverageMeter, count_parameters, get_logits_processor
from llmtuner.tuner.ppo.utils import dump_layernorm, restore_layernorm, replace_model
from llmtuner.train.ppo.utils import dump_layernorm, restore_layernorm, replace_model
if TYPE_CHECKING:
from transformers import Seq2SeqTrainingArguments, TrainerCallback
@@ -37,24 +37,43 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
finetuning_args: "FinetuningArguments",
generating_args: "GeneratingArguments",
callbacks: List["TrainerCallback"],
reward_model: "AutoModelForCausalLMWithValueHead",
**kwargs
):
PPOTrainer.__init__(self, **kwargs)
self.args = training_args
self.model_args = model_args
self.finetuning_args = finetuning_args
self.reward_model = reward_model
self.generation_config = GenerationConfig(
pad_token_id=self.tokenizer.pad_token_id,
eos_token_id=[self.tokenizer.eos_token_id] + self.tokenizer.additional_special_tokens_ids,
**generating_args.to_dict()
)
self.state = TrainerState()
self.control = TrainerControl()
self.log_callback, self.save_callback = callbacks[0], callbacks[1]
assert isinstance(self.log_callback, LogCallback) and isinstance(self.save_callback, SavePeftModelCallback)
if self.args.max_steps > 0:
logger.info("max_steps is given, it will override any value given in num_train_epochs")
if reward_model is not None:
is_deepspeed_enabled = self.accelerator.distributed_type == "DEEPSPEED" and hasattr(
self.accelerator.state, "deepspeed_plugin"
)
if is_deepspeed_enabled:
if not (
getattr(reward_model.pretrained_model, "is_loaded_in_8bit", False)
or getattr(reward_model.pretrained_model, "is_loaded_in_4bit", False)
): # quantized models are already set on the correct device
self.reward_model = self._prepare_deepspeed(self.reward_model)
else:
self.reward_model = self.accelerator.prepare_model(self.reward_model, evaluation_mode=True)
def ppo_train(self) -> None:
r"""
Implements training loop for the PPO stage, like _inner_training_loop() in Huggingface's Trainer.
@@ -108,9 +127,14 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
self.model.eval()
# Get inputs
queries, responses = self.get_inputs(batch)
self.tokenizer.padding_side = "right" # change padding side
rewards = self.get_rewards(queries, responses, unwrapped_model)
queries, responses, rewards = [], [], []
for idx in range(0, self.config.batch_size, self.config.mini_batch_size):
mini_batch_queries, mini_batch_responses = self.get_inputs(batch[idx:idx+self.config.mini_batch_size])
mini_batch_rewards = self.get_rewards(mini_batch_queries, mini_batch_responses, unwrapped_model)
queries.extend(mini_batch_queries)
responses.extend(mini_batch_responses)
rewards.extend(mini_batch_rewards)
# Cast to training mode
unwrapped_model.gradient_checkpointing_enable()
@@ -165,7 +189,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
)
@torch.no_grad()
def get_inputs(self, batch: Dict[str, torch.Tensor]) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
def get_inputs(self, batch: BatchEncoding) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
r"""
Generates model's responses given queries.
"""
@@ -208,25 +232,30 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
r"""
Computes scores using given reward model.
"""
if self.reward_model is None:
replace_model(unwrapped_model, target="reward")
batch = self.prepare_model_inputs(queries, responses)
with torch.cuda.amp.autocast(dtype=self.model_args.compute_dtype): # support bf16
_, _, values = self.model(**batch, output_hidden_states=True, return_dict=True)
reward_model = self.reward_model if self.reward_model is not None else self.model
_, _, values = reward_model(**batch, output_hidden_states=True, return_dict=True)
if values.size(0) != batch["input_ids"].size(0): # adapt to chatglm2
values = torch.transpose(values, 0, 1)
rewards = []
for i in range(values.size(0)):
end_indexes = (batch["input_ids"][i] != self.tokenizer.eos_token_id).nonzero()
end_indexes = (batch["input_ids"][i] != self.tokenizer.pad_token_id).nonzero()
end_index = end_indexes[-1].item() if len(end_indexes) else 0
rewards.append(values[i, end_index].float().detach().cpu()) # use fp32 type
if self.reward_model is None:
replace_model(unwrapped_model, target="default")
return rewards
@PPODecorators.empty_cuda_cache()
@PPODecorators.empty_device_cache()
def batched_forward_pass(
self,
model: "AutoModelForCausalLMWithValueHead",

View File

@@ -7,11 +7,12 @@ from typing import TYPE_CHECKING, Optional, List
from transformers import DataCollatorWithPadding
from transformers.optimization import get_scheduler
from llmtuner.dsets import get_dataset, preprocess_dataset
from llmtuner.data import get_dataset, preprocess_dataset
from llmtuner.extras.callbacks import SavePeftModelCallback
from llmtuner.extras.ploting import plot_loss
from llmtuner.tuner.core import load_model_and_tokenizer
from llmtuner.tuner.ppo.trainer import CustomPPOTrainer
from llmtuner.model import load_model_and_tokenizer
from llmtuner.train.utils import create_ref_model, create_reward_model
from llmtuner.train.ppo.trainer import CustomPPOTrainer
if TYPE_CHECKING:
from transformers import Seq2SeqTrainingArguments, TrainerCallback
@@ -33,6 +34,11 @@ def run_ppo(
tokenizer.padding_side = "left" # use left-padding in generation while using right-padding in training
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)
# Create reference model and reward model
ref_model = create_ref_model(model_args, finetuning_args, stage="ppo")
reward_model = create_reward_model(model, model_args, finetuning_args)
# Create ppo config
ppo_config = PPOConfig(
model_name=model_args.model_name_or_path,
learning_rate=training_args.learning_rate,
@@ -42,14 +48,16 @@ def run_ppo(
ppo_epochs=1,
max_grad_norm=training_args.max_grad_norm,
seed=training_args.seed,
optimize_cuda_cache=True,
optimize_device_cache=True,
target=finetuning_args.ppo_target,
log_with=finetuning_args.ppo_logger,
use_score_scaling=finetuning_args.ppo_score_norm,
use_score_norm=finetuning_args.ppo_score_norm,
whiten_rewards=finetuning_args.ppo_whiten_rewards,
accelerator_kwargs={"step_scheduler_with_optimizer": False}
)
# Create optimizer and scheduler
optimizer = AdamW(filter(lambda p: p.requires_grad, model.parameters()), lr=training_args.learning_rate)
if training_args.max_steps > 0:
num_training_steps = training_args.max_steps
@@ -73,9 +81,10 @@ def run_ppo(
finetuning_args=finetuning_args,
generating_args=generating_args,
callbacks=callbacks + [SavePeftModelCallback()],
reward_model=reward_model,
config=ppo_config,
model=model,
ref_model=None,
ref_model=ref_model,
tokenizer=tokenizer,
dataset=dataset,
data_collator=data_collator,
@@ -88,5 +97,5 @@ def run_ppo(
ppo_trainer.ppo_train()
ppo_trainer.save_model()
ppo_trainer.save_state() # must be called after save_model to have a folder
if ppo_trainer.is_world_process_zero() and model_args.plot_loss:
if ppo_trainer.is_world_process_zero() and finetuning_args.plot_loss:
plot_loss(training_args.output_dir, keys=["loss", "reward"])

View File

@@ -0,0 +1 @@
from llmtuner.train.pt.workflow import run_pt

View File

@@ -4,9 +4,9 @@ import math
from typing import TYPE_CHECKING, Optional, List
from transformers import DataCollatorForLanguageModeling, Trainer
from llmtuner.dsets import get_dataset, preprocess_dataset, split_dataset
from llmtuner.data import get_dataset, preprocess_dataset, split_dataset
from llmtuner.extras.ploting import plot_loss
from llmtuner.tuner.core import generate_model_card, load_model_and_tokenizer
from llmtuner.model import generate_model_card, load_model_and_tokenizer
if TYPE_CHECKING:
from transformers import Seq2SeqTrainingArguments, TrainerCallback
@@ -42,7 +42,7 @@ def run_pt(
trainer.log_metrics("train", train_result.metrics)
trainer.save_metrics("train", train_result.metrics)
trainer.save_state()
if trainer.is_world_process_zero() and model_args.plot_loss:
if trainer.is_world_process_zero() and finetuning_args.plot_loss:
plot_loss(training_args.output_dir, keys=["loss", "eval_loss"])
# Evaluation

View File

@@ -0,0 +1 @@
from llmtuner.train.rm.workflow import run_rm

View File

@@ -3,13 +3,13 @@
from typing import TYPE_CHECKING, Optional, List
from transformers import Seq2SeqTrainingArguments
from llmtuner.dsets import get_dataset, preprocess_dataset, split_dataset
from llmtuner.data import get_dataset, preprocess_dataset, split_dataset
from llmtuner.extras.callbacks import SavePeftModelCallback
from llmtuner.extras.ploting import plot_loss
from llmtuner.tuner.core import generate_model_card, load_model_and_tokenizer
from llmtuner.tuner.rm.metric import compute_accuracy
from llmtuner.tuner.rm.collator import PairwiseDataCollatorWithPadding
from llmtuner.tuner.rm.trainer import PairwiseTrainer
from llmtuner.model import generate_model_card, load_model_and_tokenizer
from llmtuner.train.rm.collator import PairwiseDataCollatorWithPadding
from llmtuner.train.rm.metric import compute_accuracy
from llmtuner.train.rm.trainer import PairwiseTrainer
if TYPE_CHECKING:
from transformers import TrainerCallback
@@ -51,7 +51,7 @@ def run_rm(
trainer.log_metrics("train", train_result.metrics)
trainer.save_metrics("train", train_result.metrics)
trainer.save_state()
if trainer.is_world_process_zero() and model_args.plot_loss:
if trainer.is_world_process_zero() and finetuning_args.plot_loss:
plot_loss(training_args.output_dir, keys=["loss", "eval_loss"])
# Evaluation

View File

@@ -0,0 +1 @@
from llmtuner.train.sft.workflow import run_sft

View File

@@ -2,15 +2,23 @@ import numpy as np
from dataclasses import dataclass
from typing import TYPE_CHECKING, Dict, Sequence, Tuple, Union
import jieba
from rouge_chinese import Rouge
from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction
from llmtuner.extras.constants import IGNORE_INDEX
from llmtuner.extras.packages import (
is_jieba_available, is_nltk_available, is_rouge_available
)
if TYPE_CHECKING:
from transformers.tokenization_utils import PreTrainedTokenizer
if is_jieba_available():
import jieba
if is_nltk_available():
from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction
if is_rouge_available():
from rouge_chinese import Rouge
@dataclass
class ComputeMetrics:

View File

@@ -3,13 +3,13 @@
from typing import TYPE_CHECKING, Optional, List
from transformers import DataCollatorForSeq2Seq, Seq2SeqTrainingArguments
from llmtuner.dsets import get_dataset, preprocess_dataset, split_dataset
from llmtuner.data import get_dataset, preprocess_dataset, split_dataset
from llmtuner.extras.constants import IGNORE_INDEX
from llmtuner.extras.misc import get_logits_processor
from llmtuner.extras.ploting import plot_loss
from llmtuner.tuner.core import generate_model_card, load_model_and_tokenizer
from llmtuner.tuner.sft.metric import ComputeMetrics
from llmtuner.tuner.sft.trainer import CustomSeq2SeqTrainer
from llmtuner.model import generate_model_card, load_model_and_tokenizer
from llmtuner.train.sft.metric import ComputeMetrics
from llmtuner.train.sft.trainer import CustomSeq2SeqTrainer
if TYPE_CHECKING:
from transformers import TrainerCallback
@@ -69,7 +69,7 @@ def run_sft(
trainer.log_metrics("train", train_result.metrics)
trainer.save_metrics("train", train_result.metrics)
trainer.save_state()
if trainer.is_world_process_zero() and model_args.plot_loss:
if trainer.is_world_process_zero() and finetuning_args.plot_loss:
plot_loss(training_args.output_dir, keys=["loss", "eval_loss"])
# Evaluation

View File

@@ -2,12 +2,12 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional
from llmtuner.extras.callbacks import LogCallback
from llmtuner.extras.logging import get_logger
from llmtuner.tuner.core import get_train_args, get_infer_args, load_model_and_tokenizer
from llmtuner.tuner.pt import run_pt
from llmtuner.tuner.sft import run_sft
from llmtuner.tuner.rm import run_rm
from llmtuner.tuner.ppo import run_ppo
from llmtuner.tuner.dpo import run_dpo
from llmtuner.model import get_train_args, get_infer_args, load_model_and_tokenizer
from llmtuner.train.pt import run_pt
from llmtuner.train.sft import run_sft
from llmtuner.train.rm import run_rm
from llmtuner.train.ppo import run_ppo
from llmtuner.train.dpo import run_dpo
if TYPE_CHECKING:
from transformers import TrainerCallback
@@ -38,11 +38,11 @@ def export_model(args: Optional[Dict[str, Any]] = None, max_shard_size: Optional
model_args, _, finetuning_args, _ = get_infer_args(args)
model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args)
model.config.use_cache = True
model.save_pretrained(model_args.export_dir, max_shard_size=max_shard_size)
model.save_pretrained(finetuning_args.export_dir, max_shard_size=max_shard_size)
try:
tokenizer.padding_side = "left" # restore padding side
tokenizer.init_kwargs["padding_side"] = "left"
tokenizer.save_pretrained(model_args.export_dir)
tokenizer.save_pretrained(finetuning_args.export_dir)
except:
logger.warning("Cannot save tokenizer, please copy the files manually.")

View File

@@ -0,0 +1,80 @@
import torch
from typing import TYPE_CHECKING, Literal, Union
from llmtuner.extras.logging import get_logger
from llmtuner.hparams import ModelArguments, FinetuningArguments
from llmtuner.model import load_model_and_tokenizer, load_valuehead_params
if TYPE_CHECKING:
from transformers.modeling_utils import PreTrainedModel
from trl import AutoModelForCausalLMWithValueHead
logger = get_logger(__name__)
def create_ref_model(
model_args: "ModelArguments",
finetuning_args: "FinetuningArguments",
stage: Literal["ppo", "dpo"]
) -> Union["PreTrainedModel", "AutoModelForCausalLMWithValueHead"]:
r"""
Creates reference model for PPO/DPO training. Evaluation mode is not supported.
The valuehead parameter is randomly initialized since it is useless for PPO training.
"""
if finetuning_args.ref_model is not None:
ref_model_args_dict = model_args.to_dict()
ref_model_args_dict.update(dict(
model_name_or_path=finetuning_args.ref_model,
checkpoint_dir=finetuning_args.ref_model_checkpoint,
quantization_bit=finetuning_args.ref_model_quantization_bit
))
ref_model_args = ModelArguments(**ref_model_args_dict)
ref_finetuning_args = FinetuningArguments(finetuning_type="lora")
ref_model, _ = load_model_and_tokenizer(ref_model_args, ref_finetuning_args, is_trainable=False, stage=stage)
logger.info("Created reference model from {}".format(finetuning_args.ref_model))
else:
if finetuning_args.finetuning_type == "lora":
ref_model = None
else:
ref_model, _ = load_model_and_tokenizer(model_args, finetuning_args, is_trainable=False, stage=stage)
logger.info("Created reference model from the model itself.")
return ref_model
def create_reward_model(
model: "AutoModelForCausalLMWithValueHead",
model_args: "ModelArguments",
finetuning_args: "FinetuningArguments"
) -> "AutoModelForCausalLMWithValueHead":
r"""
Creates reward model for PPO training.
"""
if finetuning_args.reward_model_type == "lora":
model.pretrained_model.load_adapter(finetuning_args.reward_model, "reward")
for name, param in model.named_parameters(): # https://github.com/huggingface/peft/issues/1090
if "default" in name:
param.data = param.data.to(torch.float32) # trainable params should in fp32
vhead_params = load_valuehead_params(finetuning_args.reward_model, model_args)
assert vhead_params is not None, "Reward model is not correctly loaded."
model.register_buffer("reward_head_weight", vhead_params["v_head.summary.weight"], persistent=False)
model.register_buffer("reward_head_bias", vhead_params["v_head.summary.bias"], persistent=False)
model.register_buffer("default_head_weight", torch.zeros_like(vhead_params["v_head.summary.weight"]), persistent=False)
model.register_buffer("default_head_bias", torch.zeros_like(vhead_params["v_head.summary.bias"]), persistent=False)
logger.info("Loaded adapter weights of reward model from {}".format(finetuning_args.reward_model))
return None
else:
reward_model_args_dict = model_args.to_dict()
reward_model_args_dict.update(dict(
model_name_or_path=finetuning_args.reward_model,
checkpoint_dir=finetuning_args.reward_model_checkpoint,
quantization_bit=finetuning_args.reward_model_quantization_bit
))
reward_model_args = ModelArguments(**reward_model_args_dict)
reward_finetuning_args = FinetuningArguments(finetuning_type="lora")
reward_model, _ = load_model_and_tokenizer(reward_model_args, reward_finetuning_args, is_trainable=False, stage="ppo")
logger.info("Load full weights of reward model from {}".format(finetuning_args.reward_model))
logger.warning("Please ensure the ppo model and reward model share SAME tokenizer and vocabulary.")
return reward_model

View File

@@ -1 +0,0 @@
from llmtuner.tuner.tune import export_model, run_exp

View File

@@ -1,3 +0,0 @@
from llmtuner.tuner.core.parser import get_train_args, get_infer_args
from llmtuner.tuner.core.loader import load_model_and_tokenizer
from llmtuner.tuner.core.utils import generate_model_card

View File

@@ -1 +0,0 @@
from llmtuner.tuner.dpo.workflow import run_dpo

View File

@@ -1 +0,0 @@
from llmtuner.tuner.ppo.workflow import run_ppo

View File

@@ -1 +0,0 @@
from llmtuner.tuner.pt.workflow import run_pt

View File

@@ -1 +0,0 @@
from llmtuner.tuner.rm.workflow import run_rm

View File

@@ -1 +0,0 @@
from llmtuner.tuner.sft.workflow import run_sft

View File

@@ -2,7 +2,7 @@ import gradio as gr
from gradio.components import Component # cannot use TYPE_CHECKING here
from typing import TYPE_CHECKING, Any, Dict, Generator, List, Optional, Tuple
from llmtuner.chat.stream_chat import ChatModel
from llmtuner.chat import ChatModel
from llmtuner.extras.misc import torch_gc
from llmtuner.hparams import GeneratingArguments
from llmtuner.webui.common import get_save_dir
@@ -14,14 +14,24 @@ if TYPE_CHECKING:
class WebChatModel(ChatModel):
def __init__(self, manager: "Manager", lazy_init: Optional[bool] = True) -> None:
def __init__(
self,
manager: "Manager",
demo_mode: Optional[bool] = False,
lazy_init: Optional[bool] = True
) -> None:
self.manager = manager
self.demo_mode = demo_mode
self.model = None
self.tokenizer = None
self.generating_args = GeneratingArguments()
if not lazy_init:
if not lazy_init: # read arguments from command line
super().__init__()
if demo_mode: # load openchat 3.5 by default
super().__init__(dict(model_name_or_path="openchat/openchat_3.5", template="openchat"))
@property
def loaded(self) -> bool:
return self.model is not None
@@ -36,6 +46,8 @@ class WebChatModel(ChatModel):
error = ALERTS["err_no_model"][lang]
elif not get("top.model_path"):
error = ALERTS["err_no_path"][lang]
elif self.demo_mode:
error = ALERTS["err_demo"][lang]
if error:
gr.Warning(error)
@@ -67,6 +79,11 @@ class WebChatModel(ChatModel):
def unload_model(self, data: Dict[Component, Any]) -> Generator[str, None, None]:
lang = data[self.manager.get_elem_by_name("top.lang")]
if self.demo_mode:
yield ALERTS["err_demo"][lang]
return
yield ALERTS["info_unloading"][lang]
self.model = None
self.tokenizer = None

View File

@@ -61,13 +61,17 @@ def get_model_path(model_name: str) -> str:
return user_config["path_dict"].get(model_name, None) or SUPPORTED_MODELS.get(model_name, "")
def get_prefix(model_name: str) -> str:
return model_name.split("-")[0]
def get_module(model_name: str) -> str:
return DEFAULT_MODULE.get(model_name.split("-")[0], "q_proj,v_proj")
return DEFAULT_MODULE.get(get_prefix(model_name), "q_proj,v_proj")
def get_template(model_name: str) -> str:
if model_name.endswith("Chat") and model_name.split("-")[0] in DEFAULT_TEMPLATE:
return DEFAULT_TEMPLATE[model_name.split("-")[0]]
if model_name and model_name.endswith("Chat") and get_prefix(model_name) in DEFAULT_TEMPLATE:
return DEFAULT_TEMPLATE[get_prefix(model_name)]
return "default"

View File

@@ -1,7 +1,7 @@
import gradio as gr
from typing import TYPE_CHECKING, Dict, Generator, List
from llmtuner.tuner import export_model
from llmtuner.train import export_model
from llmtuner.webui.common import get_save_dir
from llmtuner.webui.locales import ALERTS

View File

@@ -1,8 +1,8 @@
import gradio as gr
from typing import TYPE_CHECKING, Dict
from llmtuner.data.template import templates
from llmtuner.extras.constants import METHODS, SUPPORTED_MODELS
from llmtuner.extras.template import templates
from llmtuner.webui.common import get_model_path, get_template, list_checkpoint, save_config
from llmtuner.webui.utils import can_quantize

View File

@@ -1,4 +1,11 @@
CSS = r"""
.duplicate-button {
margin: auto !important;
color: white !important;
background: black !important;
border-radius: 100vh !important;
}
.modal-box {
position: fixed !important;
top: 50%;

View File

@@ -12,11 +12,11 @@ from llmtuner.webui.utils import get_time
class Engine:
def __init__(self, pure_chat: Optional[bool] = False) -> None:
def __init__(self, demo_mode: Optional[bool] = False, pure_chat: Optional[bool] = False) -> None:
self.pure_chat = pure_chat
self.manager: "Manager" = Manager()
self.runner: "Runner" = Runner(self.manager)
self.chatter: "WebChatModel" = WebChatModel(manager=self.manager, lazy_init=(not pure_chat))
self.manager = Manager()
self.runner = Runner(self.manager, demo_mode=demo_mode)
self.chatter = WebChatModel(manager=self.manager, demo_mode=demo_mode, lazy_init=(not pure_chat))
def _form_dict(self, resume_dict: Dict[str, Dict[str, Any]]):
return {self.manager.get_elem_by_name(k): gr.update(**v) for k, v in resume_dict.items()}

View File

@@ -1,4 +1,5 @@
import gradio as gr
from typing import Optional
from transformers.utils.versions import require_version
from llmtuner.webui.components import (
@@ -17,22 +18,33 @@ from llmtuner.webui.engine import Engine
require_version("gradio>=3.38.0,<4.0.0", "To fix: pip install \"gradio>=3.38.0,<4.0.0\"")
def create_ui() -> gr.Blocks:
engine = Engine(pure_chat=False)
def create_ui(demo_mode: Optional[bool] = False) -> gr.Blocks:
engine = Engine(demo_mode=demo_mode, pure_chat=False)
with gr.Blocks(title="LLaMA Board", css=CSS) as demo:
if demo_mode:
gr.HTML(
"<h1><center>LLaMA Board: A One-stop Web UI for Getting Started with LLaMA Factory</center></h1>"
)
gr.HTML(
"<h3><center>Visit <a href=\"https://github.com/hiyouga/LLaMA-Factory\" target=\"_blank\">"
"LLaMA Factory</a> for details.</center></h3>"
)
gr.DuplicateButton(value="Duplicate Space for private use", elem_classes="duplicate-button")
engine.manager.all_elems["top"] = create_top()
lang: "gr.Dropdown" = engine.manager.get_elem_by_name("top.lang")
with gr.Tab("Train"):
engine.manager.all_elems["train"] = create_train_tab(engine)
with gr.Tab("Evaluate"):
with gr.Tab("Evaluate & Predict"):
engine.manager.all_elems["eval"] = create_eval_tab(engine)
with gr.Tab("Chat"):
engine.manager.all_elems["infer"] = create_infer_tab(engine)
if not demo_mode:
with gr.Tab("Export"):
engine.manager.all_elems["export"] = create_export_tab(engine)

View File

@@ -659,6 +659,10 @@ ALERTS = {
"en": "Failed.",
"zh": "训练出错。"
},
"err_demo": {
"en": "Training is unavailable in demo mode, duplicate the space to a private one first.",
"zh": "展示模式不支持训练,请先复制到私人空间。"
},
"info_aborting": {
"en": "Aborted, wait for terminating...",
"zh": "训练中断,正在等待线程结束……"

View File

@@ -4,7 +4,7 @@ import logging
import gradio as gr
from threading import Thread
from gradio.components import Component # cannot use TYPE_CHECKING here
from typing import TYPE_CHECKING, Any, Dict, Generator, List, Tuple
from typing import TYPE_CHECKING, Any, Dict, Generator, Optional, Tuple
import transformers
from transformers.trainer import TRAINING_ARGS_NAME
@@ -13,7 +13,7 @@ from llmtuner.extras.callbacks import LogCallback
from llmtuner.extras.constants import TRAINING_STAGES
from llmtuner.extras.logging import LoggerHandler
from llmtuner.extras.misc import torch_gc
from llmtuner.tuner import run_exp
from llmtuner.train import run_exp
from llmtuner.webui.common import get_module, get_save_dir, load_config
from llmtuner.webui.locales import ALERTS
from llmtuner.webui.utils import gen_cmd, get_eval_results, update_process_bar
@@ -24,13 +24,13 @@ if TYPE_CHECKING:
class Runner:
def __init__(self, manager: "Manager") -> None:
def __init__(self, manager: "Manager", demo_mode: Optional[bool] = False) -> None:
self.manager = manager
self.demo_mode = demo_mode
""" Resume """
self.thread: "Thread" = None
self.do_train = True
self.running_data: Dict["Component", Any] = None
self.monitor_inputs: Dict[str, str] = None
""" State """
self.aborted = False
self.running = False
@@ -46,9 +46,8 @@ class Runner:
def set_abort(self) -> None:
self.aborted = True
self.running = False
def _initialize(self, data: Dict[Component, Any], do_train: bool) -> str:
def _initialize(self, data: Dict[Component, Any], do_train: bool, from_preview: bool) -> str:
get = lambda name: data[self.manager.get_elem_by_name(name)]
lang, model_name, model_path = get("top.lang"), get("top.model_name"), get("top.model_path")
dataset = get("train.dataset") if do_train else get("eval.dataset")
@@ -65,6 +64,9 @@ class Runner:
if len(dataset) == 0:
return ALERTS["err_no_dataset"][lang]
if self.demo_mode and (not from_preview):
return ALERTS["err_demo"][lang]
self.aborted = False
self.logger_handler.reset()
self.trainer_callback = LogCallback(self)
@@ -72,6 +74,7 @@ class Runner:
def _finalize(self, lang: str, finish_info: str) -> str:
self.thread = None
self.running_data = None
self.running = False
torch_gc()
if self.aborted:
@@ -84,9 +87,9 @@ class Runner:
user_config = load_config()
if get("top.checkpoints"):
checkpoint_dir = ",".join([
get_save_dir(get("top.model_name"), get("top.finetuning_type"), ckpt) for ckpt in get("top.checkpoints")
])
checkpoint_dir = ",".join([get_save_dir(
get("top.model_name"), get("top.finetuning_type"), ckpt
) for ckpt in get("top.checkpoints")])
else:
checkpoint_dir = None
@@ -136,7 +139,10 @@ class Runner:
args["upcast_layernorm"] = True
if args["stage"] == "ppo":
args["reward_model"] = get("train.reward_model")
args["reward_model"] = get_save_dir(
get("top.model_name"), get("top.finetuning_type"), get("train.reward_model")
)
args["reward_model_type"] = "lora" if get("top.finetuning_type") == "lora" else "full"
if args["stage"] == "dpo":
args["dpo_beta"] = get("train.dpo_beta")
@@ -154,9 +160,9 @@ class Runner:
user_config = load_config()
if get("top.checkpoints"):
checkpoint_dir = ",".join([
get_save_dir(get("top.model_name"), get("top.finetuning_type"), ckpt) for ckpt in get("top.checkpoints")
])
checkpoint_dir = ",".join([get_save_dir(
get("top.model_name"), get("top.finetuning_type"), ckpt
) for ckpt in get("top.checkpoints")])
output_dir = get_save_dir(
get("top.model_name"), get("top.finetuning_type"), "eval_" + "_".join(get("top.checkpoints"))
)
@@ -196,7 +202,7 @@ class Runner:
return args
def _preview(self, data: Dict[Component, Any], do_train: bool) -> Generator[Tuple[str, Dict[str, Any]], None, None]:
error = self._initialize(data, do_train)
error = self._initialize(data, do_train, from_preview=True)
if error:
gr.Warning(error)
yield error, gr.update(visible=False)
@@ -205,16 +211,14 @@ class Runner:
yield gen_cmd(args), gr.update(visible=False)
def _launch(self, data: Dict[Component, Any], do_train: bool) -> Generator[Tuple[str, Dict[str, Any]], None, None]:
error = self._initialize(data, do_train)
error = self._initialize(data, do_train, from_preview=False)
if error:
gr.Warning(error)
yield error, gr.update(visible=False)
else:
args = self._parse_train_args(data) if do_train else self._parse_eval_args(data)
run_kwargs = dict(args=args, callbacks=[self.trainer_callback])
self.running = True
self.do_train, self.running_data = do_train, data
self.monitor_inputs = dict(lang=data[self.manager.get_elem_by_name("top.lang")], output_dir=args["output_dir"])
self.thread = Thread(target=run_exp, kwargs=run_kwargs)
self.thread.start()
yield from self.monitor()
@@ -232,7 +236,12 @@ class Runner:
yield from self._launch(data, do_train=False)
def monitor(self) -> Generator[Tuple[str, Dict[str, Any]], None, None]:
lang, output_dir = self.monitor_inputs["lang"], self.monitor_inputs["output_dir"]
get = lambda name: self.running_data[self.manager.get_elem_by_name(name)]
self.running = True
lang = get("top.lang")
output_dir = get_save_dir(get("top.model_name"), get("top.finetuning_type"), get(
"{}.output_dir".format("train" if self.do_train else "eval")
))
while self.thread.is_alive():
time.sleep(2)
if self.aborted:

View File

@@ -1,17 +1,20 @@
import os
import json
import gradio as gr
import matplotlib.figure
import matplotlib.pyplot as plt
from typing import TYPE_CHECKING, Any, Dict
from datetime import datetime
from llmtuner.extras.packages import is_matplotlib_available
from llmtuner.extras.ploting import smooth
from llmtuner.webui.common import get_save_dir
if TYPE_CHECKING:
from llmtuner.extras.callbacks import LogCallback
if is_matplotlib_available():
import matplotlib.figure
import matplotlib.pyplot as plt
def update_process_bar(callback: "LogCallback") -> Dict[str, Any]:
if not callback.max_steps:
@@ -56,7 +59,7 @@ def get_eval_results(path: os.PathLike) -> str:
return "```json\n{}\n```\n".format(result)
def gen_plot(base_model: str, finetuning_type: str, output_dir: str) -> matplotlib.figure.Figure:
def gen_plot(base_model: str, finetuning_type: str, output_dir: str) -> "matplotlib.figure.Figure":
if not base_model:
return
log_file = get_save_dir(base_model, finetuning_type, output_dir, "trainer_log.jsonl")

View File

@@ -12,7 +12,7 @@ from deepspeed.profiling.flops_profiler import get_model_profile # type: ignore
from llmtuner import ChatModel
def calculate(
def calculate_flops(
model_name_or_path: str,
batch_size: Optional[int] = 1,
seq_length: Optional[int] = 256,
@@ -41,4 +41,4 @@ def calculate(
if __name__ == "__main__":
fire.Fire(calculate)
fire.Fire(calculate_flops)

63
tests/cal_lr.py Normal file
View File

@@ -0,0 +1,63 @@
# coding=utf-8
# Calculates the optimal learning rate for 7B/13B models using LLaMA's hyper-parameters.
# Usage: python cal_lr.py --model_name_or_path path_to_model --dataset alpaca_en --cutoff_len 1024 --batch_size 16
# Inspired by: https://github.com/imoneoi/openchat/blob/master/ochat/training_deepspeed/train.py
import fire
import math
import torch
from tqdm import tqdm
from typing import Optional
from torch.utils.data import DataLoader
from transformers import DataCollatorForSeq2Seq
from llmtuner.data import get_dataset, preprocess_dataset
from llmtuner.extras.constants import IGNORE_INDEX
from llmtuner.model import get_train_args, load_model_and_tokenizer
BASE_LR = 3e-4 # 1.5e-4 for 30B-70B models
BASE_BS = 4_000_000 # from llama paper
def calculate_lr(
model_name_or_path: str,
dataset: str,
cutoff_len: int, # i.e. maximum input length during training
batch_size: int, # total batch size, namely (batch size * gradient accumulation * world size)
is_mistral: bool, # mistral model uses a smaller learning rate,
dataset_dir: Optional[str] = "data"
):
model_args, data_args, training_args, finetuning_args, _ = get_train_args(dict(
stage="sft",
model_name_or_path=model_name_or_path,
dataset=dataset,
dataset_dir=dataset_dir,
template="default",
cutoff_len=cutoff_len,
output_dir="dummy_dir"
))
trainset = get_dataset(model_args, data_args)
_, tokenizer = load_model_and_tokenizer(model_args, finetuning_args, is_trainable=False, stage="sft")
trainset = preprocess_dataset(trainset, tokenizer, data_args, training_args, stage="sft")
data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, label_pad_token_id=IGNORE_INDEX)
dataloader = DataLoader(
dataset=trainset, batch_size=batch_size, shuffle=True, collate_fn=data_collator, pin_memory=True
)
valid_tokens, total_tokens = 0, 0
for batch in tqdm(dataloader):
valid_tokens += torch.sum(batch["labels"] != IGNORE_INDEX).item()
total_tokens += torch.numel(batch["labels"])
batch_max_len = cutoff_len * batch_size # max tokens in a batch
valid_ratio = valid_tokens / total_tokens
batch_valid_len = batch_max_len * valid_ratio
lr = BASE_LR * math.sqrt(batch_valid_len / BASE_BS) # lr ~ sqrt(batch_size)
lr = lr / 6.0 if is_mistral else lr
print("Optimal learning rate is {:.2e} for valid ratio% {:.2f} and effective batch size {:.2f}".format(
lr, valid_ratio * 100, batch_valid_len
))
if __name__ == "__main__":
fire.Fire(calculate_lr)

View File

@@ -4,7 +4,6 @@
# --max_length 1024 --max_samples 1024
# dataset format: instruction (string), input (string), output (string), history (List[string])
import fire
from datasets import load_dataset
from transformers import AutoTokenizer