Compare commits

162 Commits (SHA1):

04423b916f, bf8d2f8eda, 2a5d02fd0f, ea550ed9e0, 02665cd42b, 0c6a94e66d, ebd6bc2604, daab85e3e6, 769d81a83d,
ac2a401b1d, bb53c18153, 04e0fe9147, 39f75c7001, 7f99cb1817, c555b2cce3, 2eba1c6851, edeed55664, 92248f9cb2,
c548ad5e69, a57d839e1d, d88a34bc79, 60cbc9d0e5, d5005e766f, 4d0753cffe, 1cf0f11840, 052e8b2cc6, 8963e89633,
935ee0a023, 5ed234ca63, 04884a0911, c7af26a9e3, d8073488be, 6fc2d7e063, e93c7cdb80, c32d6c8250, 757158da63,
ffdacaa618, e194efab10, 772fc2eac7, ed020579dc, 096869c7b6, c6873211e9, 623ee1bd88, aabe90343e, 764cfb506d,
249ad56075, 46f99ff277, 73f4513c84, 3c91e86268, 42473ec150, 6a4e4b9c5b, 9a784fb4f3, 43fd80a1aa, e6ab1a57ea,
282edb9161, dff77004f2, 6c1b4aec75, 7814db1b42, c9ed3fc3a4, 9ee416a8fc, 4f9a47c026, 3fcb1c6d09, 7c492864e9,
7ff8a064f3, c635bbe465, 4881f4e631, c631799f5d, 48846676d8, f37d481c5d, 5d7d8bd55c, 8ed1463236, 43b2ede0f8,
2f095e2017, 9b55bb964c, 9b97b23ce7, 53ab28533e, 940c00e7ae, 18cfd5f349, 6169df1c52, d46c2bbcba, 48d4364586,
8042c66a76, 3879d79b89, e416cecf62, 81fcb80466, bf812fbe40, 1e6fb6c8aa, 5d0c95bd02, 7cd2417002, 16851d66e5,
056d2d956a, 9a69cadab3, 3de642bffd, 286b9d9849, cef1ede826, 5007566588, e93fb3cc6c, 7578209735, 67f02f75d0,
73d9dfc7ab, 6b407092d9, 3168abc0a1, 46ee267cfc, a10bead9b5, 3553e301dd, 02b838b9b0, b1de6d1025, bc67872218,
0229fffde5, 3555b87363, 2dca53962e, f4f71f2797, 77ab9457ed, 4fa53b6282, 790b73586b, 9c29c2a172, 863960d33e,
330e5381b4, 5bb411fdb8, 59a9a5994e, 5306a71b42, 3eafa2dd9e, 88fddb879d, 71491825bf, 30855b924a, 48d2e6d7fe,
041c83ea03, 0e621c2dc9, 544e7a491b, a2c881fa08, c53c7af168, a2d93e5269, b392e6cfb9, 13aa2d389a, 1e7962dfc4,
1c9556c84c, ca3ca7a5b5, 0500befdb4, f618feab51, 4b06aa134f, 9cde56d760, d0ea203694, c5eb3fba62, a8bc32553c,
88f3358320, a85bdcf2f6, caf56b313e, 75603c45fc, 89f86cc970, c09a0e4f08, 7bac6c9460, 0c7d0bf172, a274900188,
67deefe527, 823f618cba, bc16c9a54a, a3f30038a0, e237f618c2, 688adad665, 0158812afb, e52e0d9b07, eb2aa2c073
.dockerignore (new file, 11 lines)

````diff
@@ -0,0 +1,11 @@
+.vscode
+.git
+.github
+.venv
+cache
+data
+examples
+.dockerignore
+.gitattributes
+.gitignore
+Dockerfile
````
.github/CONTRIBUTING.md (vendored, new file, 21 lines)

````diff
@@ -0,0 +1,21 @@
+# Contributing to LLaMA Factory
+
+Everyone is welcome to contribute, and we value everybody's contribution. Code contributions are not the only way to help the community. Answering questions, helping others, and improving the documentation are also immensely valuable.
+
+It also helps us if you spread the word! Reference the library in blog posts about the awesome projects it made possible, shout out on Twitter every time it has helped you, or simply ⭐️ the repository to say thank you.
+
+However you choose to contribute, please be mindful and respect our [code of conduct](CODE_OF_CONDUCT.md).
+
+**This guide was heavily inspired by [transformers guide to contributing](https://github.com/huggingface/transformers/blob/main/CONTRIBUTING.md).**
+
+## Ways to contribute
+
+There are several ways you can contribute to LLaMA Factory:
+
+* Fix outstanding issues with the existing code.
+* Submit issues related to bugs or desired new features.
+* Contribute to the examples or to the documentation.
+
+### Style guide
+
+LLaMA Factory follows the [Google Python Style Guide](https://google.github.io/styleguide/pyguide.html), check it for details.
````
.github/PULL_REQUEST_TEMPLATE.md (vendored, new file, 7 lines)

````diff
@@ -0,0 +1,7 @@
+# What does this PR do?
+
+Fixes # (issue)
+
+## Before submitting
+
+- [ ] Did you read the [contributor guideline](https://github.com/hiyouga/LLaMA-Factory/blob/main/.github/CONTRIBUTING.md)?
````
.github/SECURITY.md (vendored, new file, 7 lines)

````diff
@@ -0,0 +1,7 @@
+# Reporting Security Issues
+
+To report a security issue, please use the GitHub Security Advisory ["Report a Vulnerability"](https://github.com/electron/electron/security/advisories/new) tab.
+
+We will send a response indicating the next steps in handling your report. After the initial reply to your report, the security team will keep you informed of the progress towards a fix and full announcement, and may ask for additional information or guidance.
+
+Report security bugs in third-party modules to the person or team maintaining the module.
````
.github/workflows/tests.yml (vendored, 2 changed lines)

````diff
@@ -22,7 +22,7 @@ jobs:
       - name: Install dependencies
         run: |
           python -m pip install --upgrade pip
-          python -m pip install black ruff
+          python -m pip install ruff

       - name: Check quality
         run: |
````
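The command run by the `Check quality` step is not shown in this hunk, but together with the Makefile change further down in this compare it indicates the lint gate now relies on ruff alone. A rough local reproduction (a sketch; the directories come from the Makefile's `check_dirs`):

```bash
pip install ruff
ruff check scripts src tests           # lint checks, previously split between black and ruff
ruff format --check scripts src tests  # formatting check, replacing black --check
```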
CITATION.cff (new file, 37 lines)

````diff
@@ -0,0 +1,37 @@
+cff-version: 1.2.0
+date-released: 2024-03
+message: "If you use this software, please cite it as below."
+authors:
+- family-names: "Zheng"
+  given-names: "Yaowei"
+- family-names: "Zhang"
+  given-names: "Richong"
+- family-names: "Zhang"
+  given-names: "Junhao"
+- family-names: "Ye"
+  given-names: "Yanhan"
+- family-names: "Luo"
+  given-names: "Zheyan"
+- family-names: "Ma"
+  given-names: "Yongqiang"
+title: "LlamaFactory: Unified Efficient Fine-Tuning of 100+ Language Models"
+url: "https://arxiv.org/abs/2403.13372"
+preferred-citation:
+  type: article
+  authors:
+  - family-names: "Zheng"
+    given-names: "Yaowei"
+  - family-names: "Zhang"
+    given-names: "Richong"
+  - family-names: "Zhang"
+    given-names: "Junhao"
+  - family-names: "Ye"
+    given-names: "Yanhan"
+  - family-names: "Luo"
+    given-names: "Zheyan"
+  - family-names: "Ma"
+    given-names: "Yongqiang"
+  journal: "arXiv preprint arXiv:2403.13372"
+  title: "LlamaFactory: Unified Efficient Fine-Tuning of 100+ Language Models"
+  url: "https://arxiv.org/abs/2403.13372"
+  year: 2024
````
Dockerfile (new file, 14 lines)

````diff
@@ -0,0 +1,14 @@
+FROM nvcr.io/nvidia/pytorch:24.01-py3
+
+WORKDIR /app
+
+COPY requirements.txt /app/
+RUN pip install -r requirements.txt
+
+COPY . /app/
+RUN pip install -e .[deepspeed,metrics,bitsandbytes,qwen]
+
+VOLUME [ "/root/.cache/huggingface/", "/app/data", "/app/output" ]
+EXPOSE 7860
+
+CMD [ "python", "src/train_web.py" ]
````
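A quick way to sanity-check the resulting image before a real run (a sketch; it assumes the NVIDIA Container Toolkit is installed on the host, and the full `docker run` invocation with volume mounts is documented in the README changes below):

```bash
# Build the image from the repository root.
docker build -f ./Dockerfile -t llama-factory:latest .

# Override the default CMD once to confirm the container can see the GPUs.
docker run --rm --gpus=all llama-factory:latest nvidia-smi
```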
Makefile (10 changed lines)

````diff
@@ -1,11 +1,11 @@
 .PHONY: quality style

-check_dirs := src tests
+check_dirs := scripts src tests

 quality:
-	black --check $(check_dirs)
-	ruff $(check_dirs)
+	ruff check $(check_dirs)
+	ruff format --check $(check_dirs)

 style:
-	black $(check_dirs)
-	ruff $(check_dirs) --fix
+	ruff check $(check_dirs) --fix
+	ruff format $(check_dirs)
````
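Typical usage of the two targets stays the same after the switch from black to ruff:

```bash
pip install ruff   # the only lint/format dependency left after dropping black
make quality       # ruff check + ruff format --check over scripts, src and tests
make style         # auto-fix lint issues and reformat the same directories in place
```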
README.md (241 changed lines)

````diff
@@ -5,27 +5,30 @@
 [](https://github.com/hiyouga/LLaMA-Factory/commits/main)
 [](https://pypi.org/project/llmtuner/)
 [](https://pypi.org/project/llmtuner/)
+[](#projects-using-llama-factory)
 [](https://github.com/hiyouga/LLaMA-Factory/pulls)
 [](https://discord.gg/rKfvV9r9FK)
-[](https://huggingface.co/spaces/hiyouga/LLaMA-Board)
-[](https://modelscope.cn/studios/hiyouga/LLaMA-Board)
+[](https://twitter.com/llamafactory_ai)
+[](https://huggingface.co/spaces/hiyouga/LLaMA-Board)
+[](https://modelscope.cn/studios/hiyouga/LLaMA-Board)
+[](https://colab.research.google.com/drive/1eRTPn37ltBbYsISy9Aw2NuI2Aq5CQrD9?usp=sharing)

 👋 Join our [WeChat](assets/wechat.jpg).

 \[ English | [中文](README_zh.md) \]

-## LLaMA Board: A One-stop Web UI for Getting Started with LLaMA Factory
+**Fine-tuning a large language model can be easy as...**

-Preview LLaMA Board at **[🤗 Spaces](https://huggingface.co/spaces/hiyouga/LLaMA-Board)** or **[ModelScope](https://modelscope.cn/studios/hiyouga/LLaMA-Board)**.
+https://github.com/hiyouga/LLaMA-Factory/assets/16256802/9840a653-7e9c-41c8-ae89-7ace5698baf6

-Launch LLaMA Board via `CUDA_VISIBLE_DEVICES=0 python src/train_web.py`. (multiple GPUs are not supported yet in this mode)
+Choose your path:

-Here is an example of altering the self-cognition of an instruction-tuned language model within 10 minutes on a single GPU.
-
-https://github.com/hiyouga/LLaMA-Factory/assets/16256802/6ba60acc-e2e2-4bec-b846-2d88920d5ba1
+- **Colab**: https://colab.research.google.com/drive/1eRTPn37ltBbYsISy9Aw2NuI2Aq5CQrD9?usp=sharing
+- **Local machine**: Please refer to [usage](#getting-started)

 ## Table of Contents

+- [Features](#features)
 - [Benchmark](#benchmark)
 - [Changelog](#changelog)
 - [Supported Models](#supported-models)
@@ -38,6 +41,16 @@ https://github.com/hiyouga/LLaMA-Factory/assets/16256802/6ba60acc-e2e2-4bec-b846
 - [Citation](#citation)
 - [Acknowledgement](#acknowledgement)

+## Features
+
+- **Various models**: LLaMA, Mistral, Mixtral-MoE, Qwen, Yi, Gemma, Baichuan, ChatGLM, Phi, etc.
+- **Integrated methods**: (Continuous) pre-training, supervised fine-tuning, reward modeling, PPO and DPO.
+- **Scalable resources**: 32-bit full-tuning, 16-bit freeze-tuning, 16-bit LoRA and 2/4/8-bit QLoRA via AQLM/AWQ/GPTQ/LLM.int8.
+- **Advanced algorithms**: GaLore, DoRA, LongLoRA, LLaMA Pro, LoRA+, LoftQ and Agent tuning.
+- **Practical tricks**: FlashAttention-2, Unsloth, RoPE scaling, NEFTune and rsLoRA.
+- **Experiment monitors**: LlamaBoard, TensorBoard, Wandb, MLflow, etc.
+- **Faster inference**: OpenAI-style API, Gradio UI and CLI with vLLM worker.
+
 ## Benchmark

 Compared to ChatGLM's [P-Tuning](https://github.com/THUDM/ChatGLM2-6B/tree/main/ptuning), LLaMA-Factory's LoRA tuning offers up to **3.7 times faster** training speed with a better Rouge score on the advertising text generation task. By leveraging 4-bit quantization technique, LLaMA-Factory's QLoRA further improves the efficiency regarding the GPU memory.
````
````diff
@@ -55,15 +68,27 @@ Compared to ChatGLM's [P-Tuning](https://github.com/THUDM/ChatGLM2-6B/tree/main/
 ## Changelog

-[24/02/15] We supported **block expansion** proposed by [LLaMA Pro](https://github.com/TencentARC/LLaMA-Pro). See `tests/llama_pro.py` for usage.
+[24/03/21] Our paper "[LlamaFactory: Unified Efficient Fine-Tuning of 100+ Language Models](https://arxiv.org/abs/2403.13372)" is available at arXiv!

+[24/03/20] We supported **FSDP+QLoRA** that fine-tunes a 70B model on 2x24GB GPUs. See `examples/fsdp_qlora` for usage.
+
+[24/03/13] We supported **[LoRA+](https://arxiv.org/abs/2402.12354)**. Try `loraplus_lr_ratio=16.0` to enable LoRA+ algorithm.
+
+[24/03/07] We supported gradient low-rank projection (**[GaLore](https://arxiv.org/abs/2403.03507)**) algorithm. Try `--use_galore` to use the memory-efficient optimizer.
+
+[24/03/07] We integrated **[vLLM](https://github.com/vllm-project/vllm)** for faster and concurrent inference. Try `--infer_backend vllm` to enjoy **270%** inference speed. (LoRA is not yet supported, merge it first.)
+
+<details><summary>Full Changelog</summary>
+
+[24/02/28] We supported weight-decomposed LoRA (**[DoRA](https://arxiv.org/abs/2402.09353)**). Try `--use_dora` to activate DoRA training.
+
+[24/02/15] We supported **block expansion** proposed by [LLaMA Pro](https://github.com/TencentARC/LLaMA-Pro). See `examples/extras/llama_pro` for usage.
+
 [24/02/05] Qwen1.5 (Qwen2 beta version) series models are supported in LLaMA-Factory. Check this [blog post](https://qwenlm.github.io/blog/qwen1.5/) for details.

 [24/01/18] We supported **agent tuning** for most models, equipping model with tool using abilities by fine-tuning with `--dataset glaive_toolcall`.

-<details><summary>Full Changelog</summary>
-
-[23/12/23] We supported **[unsloth](https://github.com/unslothai/unsloth)**'s implementation to boost LoRA tuning for the LLaMA, Mistral and Yi models. Try `--use_unsloth` argument to activate unsloth patch. It achieves 1.7x speed in our benchmark, check [this page](https://github.com/hiyouga/LLaMA-Factory/wiki/Performance-comparison) for details.
+[23/12/23] We supported **[unsloth](https://github.com/unslothai/unsloth)**'s implementation to boost LoRA tuning for the LLaMA, Mistral and Yi models. Try `--use_unsloth` argument to activate unsloth patch. It achieves **170%** speed in our benchmark, check [this page](https://github.com/hiyouga/LLaMA-Factory/wiki/Performance-comparison) for details.

 [23/12/12] We supported fine-tuning the latest MoE model **[Mixtral 8x7B](https://huggingface.co/mistralai/Mixtral-8x7B-v0.1)** in our framework. See hardware requirement [here](#hardware-requirement).
````
````diff
@@ -107,16 +132,19 @@ Compared to ChatGLM's [P-Tuning](https://github.com/THUDM/ChatGLM2-6B/tree/main/
 | [ChatGLM3](https://huggingface.co/THUDM/chatglm3-6b) | 6B | query_key_value | chatglm3 |
 | [DeepSeek (MoE)](https://huggingface.co/deepseek-ai) | 7B/16B/67B | q_proj,v_proj | deepseek |
 | [Falcon](https://huggingface.co/tiiuae) | 7B/40B/180B | query_key_value | falcon |
+| [Gemma](https://huggingface.co/google) | 2B/7B | q_proj,v_proj | gemma |
 | [InternLM2](https://huggingface.co/internlm) | 7B/20B | wqkv | intern2 |
 | [LLaMA](https://github.com/facebookresearch/llama) | 7B/13B/33B/65B | q_proj,v_proj | - |
 | [LLaMA-2](https://huggingface.co/meta-llama) | 7B/13B/70B | q_proj,v_proj | llama2 |
 | [Mistral](https://huggingface.co/mistralai) | 7B | q_proj,v_proj | mistral |
 | [Mixtral](https://huggingface.co/mistralai) | 8x7B | q_proj,v_proj | mistral |
+| [OLMo](https://huggingface.co/allenai) | 1B/7B | att_proj | olmo |
 | [Phi-1.5/2](https://huggingface.co/microsoft) | 1.3B/2.7B | q_proj,v_proj | - |
 | [Qwen](https://huggingface.co/Qwen) | 1.8B/7B/14B/72B | c_attn | qwen |
 | [Qwen1.5](https://huggingface.co/Qwen) | 0.5B/1.8B/4B/7B/14B/72B | q_proj,v_proj | qwen |
+| [StarCoder2](https://huggingface.co/bigcode) | 3B/7B/15B | q_proj,v_proj | - |
 | [XVERSE](https://huggingface.co/xverse) | 7B/13B/65B | q_proj,v_proj | xverse |
-| [Yi](https://huggingface.co/01-ai) | 6B/34B | q_proj,v_proj | yi |
+| [Yi](https://huggingface.co/01-ai) | 6B/9B/34B | q_proj,v_proj | yi |
 | [Yuan](https://huggingface.co/IEITYuan) | 2B/51B/102B | q_proj,v_proj | yuan |

 > [!NOTE]
@@ -126,9 +154,11 @@ Compared to ChatGLM's [P-Tuning](https://github.com/THUDM/ChatGLM2-6B/tree/main/

 Please refer to [constants.py](src/llmtuner/extras/constants.py) for a full list of models we supported.

+You also can add a custom chat template to [template.py](src/llmtuner/data/template.py).
+
 ## Supported Training Approaches

-| Approach | Full-parameter | Partial-parameter | LoRA | QLoRA |
+| Approach | Full-tuning | Freeze-tuning | LoRA | QLoRA |
 | ---------------------- | ------------------ | ------------------ | ------------------ | ------------------ |
 | Pre-Training | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
 | Supervised Fine-Tuning | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
````
````diff
@@ -192,6 +222,7 @@ Please refer to [constants.py](src/llmtuner/extras/constants.py) for a full list
 - [LMSYS Chat 1M (en)](https://huggingface.co/datasets/lmsys/lmsys-chat-1m)
 - [Evol Instruct V2 (en)](https://huggingface.co/datasets/WizardLM/WizardLM_evol_instruct_V2_196k)
 - [Glaive Function Calling V2 (en)](https://huggingface.co/datasets/glaiveai/glaive-function-calling-v2)
+- [Cosmopedia (en)](https://huggingface.co/datasets/HuggingFaceTB/cosmopedia)
 - [Open Assistant (de)](https://huggingface.co/datasets/mayflowergmbh/oasst_de)
 - [Dolly 15k (de)](https://huggingface.co/datasets/mayflowergmbh/dolly-15k_de)
 - [Alpaca GPT4 (de)](https://huggingface.co/datasets/mayflowergmbh/alpaca-gpt4_de)
@@ -209,6 +240,7 @@ Please refer to [constants.py](src/llmtuner/extras/constants.py) for a full list
 - [HH-RLHF (en)](https://huggingface.co/datasets/Anthropic/hh-rlhf)
 - [Open Assistant (multilingual)](https://huggingface.co/datasets/OpenAssistant/oasst1)
 - [GPT-4 Generated Data (en&zh)](https://github.com/Instruction-Tuning-with-GPT-4/GPT-4-LLM)
+- [Orca DPO (en)](https://huggingface.co/datasets/Intel/orca_dpo_pairs)
 - [Nectar (en)](https://huggingface.co/datasets/berkeley-nest/Nectar)
 - [Orca DPO (de)](https://huggingface.co/datasets/mayflowergmbh/intel_orca_dpo_pairs_de)

````
````diff
@@ -225,22 +257,37 @@ huggingface-cli login

 ## Requirement

-- Python 3.8+ and PyTorch 1.13.1+
-- 🤗Transformers, Datasets, Accelerate, PEFT and TRL
-- sentencepiece, protobuf and tiktoken
-- jieba, rouge-chinese and nltk (used at evaluation and predict)
-- gradio and matplotlib (used in web UI)
-- uvicorn, fastapi and sse-starlette (used in API)
+| Mandatory | Minimum | Recommend |
+| ------------ | ------- | --------- |
+| python | 3.8 | 3.10 |
+| torch | 1.13.1 | 2.2.0 |
+| transformers | 4.37.2 | 4.39.1 |
+| datasets | 2.14.3 | 2.17.1 |
+| accelerate | 0.27.2 | 0.28.0 |
+| peft | 0.9.0 | 0.10.0 |
+| trl | 0.8.1 | 0.8.1 |
+
+| Optional | Minimum | Recommend |
+| ------------ | ------- | --------- |
+| CUDA | 11.6 | 12.2 |
+| deepspeed | 0.10.0 | 0.14.0 |
+| bitsandbytes | 0.39.0 | 0.43.0 |
+| flash-attn | 2.3.0 | 2.5.6 |

````
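To compare an existing environment against the version tables above, a quick check along these lines is usually enough (a sketch; package names are taken from the table):

```bash
python -m pip show torch transformers datasets accelerate peft trl | grep -E "^(Name|Version)"
python -c "import torch; print(torch.__version__, torch.cuda.is_available())"
```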
````diff
 ### Hardware Requirement

-| Method | Bits | 7B | 13B | 30B | 65B | 8x7B |
+\* *estimated*
+
+| Method | Bits | 7B | 13B | 30B | 70B | 8x7B |
 | ------ | ---- | ----- | ----- | ----- | ------ | ------ |
-| Full | 16 | 160GB | 320GB | 600GB | 1200GB | 900GB |
-| Freeze | 16 | 20GB | 40GB | 120GB | 240GB | 200GB |
-| LoRA | 16 | 16GB | 32GB | 80GB | 160GB | 120GB |
-| QLoRA | 8 | 10GB | 16GB | 40GB | 80GB | 80GB |
-| QLoRA | 4 | 6GB | 12GB | 24GB | 48GB | 32GB |
+| Full | AMP | 120GB | 240GB | 600GB | 1200GB | 900GB |
+| Full | 16 | 60GB | 120GB | 300GB | 600GB | 400GB |
+| GaLore | 16 | 16GB | 32GB | 64GB | 160GB | 120GB |
+| Freeze | 16 | 20GB | 40GB | 80GB | 200GB | 160GB |
+| LoRA | 16 | 16GB | 32GB | 64GB | 160GB | 120GB |
+| QLoRA | 8 | 10GB | 20GB | 40GB | 80GB | 60GB |
+| QLoRA | 4 | 6GB | 12GB | 24GB | 48GB | 30GB |
+| QLoRA | 2 | 4GB | 8GB | 16GB | 24GB | 18GB |

````
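The QLoRA rows correspond to loading the base model in 2/4/8-bit precision during LoRA training. A hedged sketch of how that maps onto the training commands shown later in this README (it assumes llmtuner's `--quantization_bit` argument, which is not shown in this excerpt):

```bash
# 4-bit QLoRA fine-tuning, matching the "QLoRA / 4" row (roughly 6GB for a 7B model).
CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
    --finetuning_type lora \
    --quantization_bit 4 \
    ... # remaining arguments as in the fine-tuning examples below
```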
````diff
 ## Getting Started

@@ -261,12 +308,14 @@ cd LLaMA-Factory
 pip install -r requirements.txt
 ```

-If you want to enable the quantized LoRA (QLoRA) on the Windows platform, you will be required to install a pre-built version of `bitsandbytes` library, which supports CUDA 11.1 to 12.1.
+If you want to enable the quantized LoRA (QLoRA) on the Windows platform, you will be required to install a pre-built version of `bitsandbytes` library, which supports CUDA 11.1 to 12.2, please select the appropriate [release version](https://github.com/jllllll/bitsandbytes-windows-webui/releases/tag/wheels) based on your CUDA version.

 ```bash
-pip install https://github.com/jllllll/bitsandbytes-windows-webui/releases/download/wheels/bitsandbytes-0.39.1-py3-none-win_amd64.whl
+pip install https://github.com/jllllll/bitsandbytes-windows-webui/releases/download/wheels/bitsandbytes-0.41.2.post2-py3-none-win_amd64.whl
 ```

+To enable FlashAttention-2 on the Windows platform, you need to install the precompiled `flash-attn` library, which supports CUDA 12.1 to 12.2. Please download the corresponding version from [flash-attention](https://github.com/bdashore3/flash-attention/releases) based on your requirements.
+
 ### Use ModelScope Hub (optional)

 If you have trouble with downloading models and datasets from Hugging Face, you can use LLaMA-Factory together with ModelScope in the following manner.
````
````diff
@@ -280,7 +329,7 @@ Then you can train the corresponding model by specifying a model ID of the Model
 ```bash
 CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
     --model_name_or_path modelscope/Llama-2-7b-ms \
-    ... # arguments (same as above)
+    ... # arguments (same as below)
 ```

 LLaMA Board also supports using the models and datasets on the ModelScope Hub.
@@ -294,6 +343,13 @@ CUDA_VISIBLE_DEVICES=0 USE_MODELSCOPE_HUB=1 python src/train_web.py
 > [!IMPORTANT]
 > If you want to train models on multiple GPUs, please refer to [Distributed Training](#distributed-training).

+#### LLaMA Board GUI
+
+```bash
+CUDA_VISIBLE_DEVICES=0 python src/train_web.py
+```
+
 #### Pre-Training

 ```bash
````
````diff
@@ -360,7 +416,7 @@ CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
     --lr_scheduler_type cosine \
     --logging_steps 10 \
     --save_steps 1000 \
-    --learning_rate 1e-6 \
+    --learning_rate 1e-5 \
     --num_train_epochs 1.0 \
     --plot_loss \
     --fp16
@@ -394,6 +450,9 @@ CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
     --fp16
 ```

+> [!TIP]
+> Use `--adapter_name_or_path path_to_sft_checkpoint,path_to_ppo_checkpoint` to infer the fine-tuned model.
+
 > [!WARNING]
 > Use `--per_device_train_batch_size=1` for LLaMA-2 models in fp16 PPO training.

````
````diff
@@ -422,19 +481,24 @@ CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
     --fp16
 ```

+> [!TIP]
+> Use `--adapter_name_or_path path_to_sft_checkpoint,path_to_dpo_checkpoint` to infer the fine-tuned model.
+
 ### Distributed Training

 #### Use Huggingface Accelerate

 ```bash
-accelerate config # configure the environment
-accelerate launch src/train_bash.py # arguments (same as above)
+accelerate launch --config_file config.yaml src/train_bash.py \
+    --ddp_timeout 180000000 \
+    ... # arguments (same as above)
 ```

````
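For an ad-hoc multi-GPU run without writing `config.yaml` first, `accelerate launch` can also be configured from the command line (a sketch using standard Accelerate CLI flags; the config-file route above remains the documented one):

```bash
accelerate launch --multi_gpu --num_processes 4 --main_process_port 29500 \
    src/train_bash.py \
    --ddp_timeout 180000000 \
    ... # arguments (same as above)
```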
````diff
-<details><summary>Example config for LoRA training</summary>
+<details><summary>Example config.yaml for LoRA training</summary>

 ```yaml
 compute_environment: LOCAL_MACHINE
+debug: false
 distributed_type: MULTI_GPU
 downcast_bf16: 'no'
 gpu_ids: all
@@ -453,15 +517,19 @@ use_cpu: false

 </details>

+> [!TIP]
+> We commend using Accelerate for LoRA tuning.
+
 #### Use DeepSpeed

 ```bash
-deepspeed --num_gpus 8 --master_port=9901 src/train_bash.py \
+deepspeed --num_gpus 8 src/train_bash.py \
     --deepspeed ds_config.json \
+    --ddp_timeout 180000000 \
     ... # arguments (same as above)
 ```

-<details><summary>Example config for full-parameter training with DeepSpeed ZeRO-2</summary>
+<details><summary>Example ds_config.json for full-parameter training with DeepSpeed ZeRO-2</summary>

 ```json
 {
@@ -473,29 +541,36 @@ deepspeed --num_gpus 8 --master_port=9901 src/train_bash.py \
   "fp16": {
     "enabled": "auto",
     "loss_scale": 0,
-    "initial_scale_power": 16,
     "loss_scale_window": 1000,
+    "initial_scale_power": 16,
     "hysteresis": 2,
     "min_loss_scale": 1
   },
+  "bf16": {
+    "enabled": "auto"
+  },
   "zero_optimization": {
     "stage": 2,
     "allgather_partitions": true,
     "allgather_bucket_size": 5e8,
+    "overlap_comm": true,
     "reduce_scatter": true,
     "reduce_bucket_size": 5e8,
-    "overlap_comm": false,
-    "contiguous_gradients": true
+    "contiguous_gradients": true,
+    "round_robin_gradients": true
   }
 }
 ```

 </details>

+> [!TIP]
+> Refer to [examples](examples) for more training scripts.
+
 ### Merge LoRA weights and export model

 ```bash
-python src/export_model.py \
+CUDA_VISIBLE_DEVICES=0 python src/export_model.py \
     --model_name_or_path path_to_llama_model \
     --adapter_name_or_path path_to_checkpoint \
     --template default \
````
````diff
@@ -509,12 +584,14 @@ python src/export_model.py \
 > Merging LoRA weights into a quantized model is not supported.

 > [!TIP]
-> Use `--export_quantization_bit 4` and `--export_quantization_dataset data/c4_demo.json` to quantize the model after merging the LoRA weights.
+> Use `--model_name_or_path path_to_export` solely to use the exported model.
+>
+> Use `--export_quantization_bit 4` and `--export_quantization_dataset data/c4_demo.json` to quantize the model with AutoGPTQ after merging the LoRA weights.

-### API Demo
+### Inference with OpenAI-style API

 ```bash
-python src/api_demo.py \
+CUDA_VISIBLE_DEVICES=0 API_PORT=8000 python src/api_demo.py \
     --model_name_or_path path_to_llama_model \
     --adapter_name_or_path path_to_checkpoint \
     --template default \
@@ -524,20 +601,20 @@ python src/api_demo.py \
 > [!TIP]
 > Visit `http://localhost:8000/docs` for API documentation.

-### CLI Demo
+### Inference with command line

 ```bash
-python src/cli_demo.py \
+CUDA_VISIBLE_DEVICES=0 python src/cli_demo.py \
     --model_name_or_path path_to_llama_model \
     --adapter_name_or_path path_to_checkpoint \
     --template default \
     --finetuning_type lora
 ```

-### Web Demo
+### Inference with web browser

 ```bash
-python src/web_demo.py \
+CUDA_VISIBLE_DEVICES=0 python src/web_demo.py \
     --model_name_or_path path_to_llama_model \
     --adapter_name_or_path path_to_checkpoint \
     --template default \
````
````diff
@@ -571,7 +648,7 @@ CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
     --template default \
     --finetuning_type lora \
     --output_dir path_to_predict_result \
-    --per_device_eval_batch_size 8 \
+    --per_device_eval_batch_size 1 \
     --max_samples 100 \
     --predict_with_generate \
     --fp16
````
````diff
@@ -583,13 +660,60 @@ CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
 > [!TIP]
 > We recommend using `--per_device_eval_batch_size=1` and `--max_target_length 128` at 4/8-bit predict.

+### Dockerize Training
+
+#### Get ready
+
+Necessary dockerized environment is needed, such as Docker or Docker Compose.
+
+#### Docker support
+
+```bash
+docker build -f ./Dockerfile -t llama-factory:latest .
+
+docker run --gpus=all -v ./hf_cache:/root/.cache/huggingface/ -v ./data:/app/data -v ./output:/app/output -p 7860:7860 --shm-size 16G --name llama_factory -d llama-factory:latest
+```
+
+#### Docker Compose support
+
+```bash
+docker compose -f ./docker-compose.yml up -d
+```
+
+> [!TIP]
+> Details about volume:
+> * hf_cache: Utilize Huggingface cache on the host machine. Reassignable if a cache already exists in a different directory.
+> * data: Place datasets on this dir of the host machine so that they can be selected on LLaMA Board GUI.
+> * output: Set export dir to this location so that the merged result can be accessed directly on the host machine.
+
````
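Once the container is up by either route, the usual Docker commands can be used to inspect it (a sketch assuming the container name `llama_factory` from the `docker run` example above; a Compose service may use a different name):

```bash
docker ps --filter name=llama_factory     # confirm the container is running
docker logs -f llama_factory              # follow the LLaMA Board output served on port 7860
docker exec -it llama_factory nvidia-smi  # verify GPU visibility inside the container
```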
````diff
 ## Projects using LLaMA Factory

-- **[StarWhisper](https://github.com/Yu-Yang-Li/StarWhisper)**: A large language model for Astronomy, based on ChatGLM2-6B and Qwen-14B.
-- **[DISC-LawLLM](https://github.com/FudanDISC/DISC-LawLLM)**: A large language model specialized in Chinese legal domain, based on Baichuan-13B, is capable of retrieving and reasoning on legal knowledge.
-- **[Sunsimiao](https://github.com/thomas-yanxin/Sunsimiao)**: A large language model specialized in Chinese medical domain, based on Baichuan-7B and ChatGLM-6B.
-- **[CareGPT](https://github.com/WangRongsheng/CareGPT)**: A series of large language models for Chinese medical domain, based on LLaMA2-7B and Baichuan-13B.
-- **[MachineMindset](https://github.com/PKU-YuanGroup/Machine-Mindset/)**: A series of MBTI Personality large language models, capable of giving any LLM 16 different personality types based on different datasets and training methods.
+1. Wang et al. ESRL: Efficient Sampling-based Reinforcement Learning for Sequence Generation. 2023. [[arxiv]](https://arxiv.org/abs/2308.02223)
+1. Yu et al. Open, Closed, or Small Language Models for Text Classification? 2023. [[arxiv]](https://arxiv.org/abs/2308.10092)
+1. Wang et al. UbiPhysio: Support Daily Functioning, Fitness, and Rehabilitation with Action Understanding and Feedback in Natural Language. 2023. [[arxiv]](https://arxiv.org/abs/2308.10526)
+1. Luceri et al. Leveraging Large Language Models to Detect Influence Campaigns in Social Media. 2023. [[arxiv]](https://arxiv.org/abs/2311.07816)
+1. Zhang et al. Alleviating Hallucinations of Large Language Models through Induced Hallucinations. 2023. [[arxiv]](https://arxiv.org/abs/2312.15710)
+1. Wang et al. Know Your Needs Better: Towards Structured Understanding of Marketer Demands with Analogical Reasoning Augmented LLMs. 2024. [[arxiv]](https://arxiv.org/abs/2401.04319)
+1. Wang et al. CANDLE: Iterative Conceptualization and Instantiation Distillation from Large Language Models for Commonsense Reasoning. 2024. [[arxiv]](https://arxiv.org/abs/2401.07286)
+1. Choi et al. FACT-GPT: Fact-Checking Augmentation via Claim Matching with LLMs. 2024. [[arxiv]](https://arxiv.org/abs/2402.05904)
+1. Zhang et al. AutoMathText: Autonomous Data Selection with Language Models for Mathematical Texts. 2024. [[arxiv]](https://arxiv.org/abs/2402.07625)
+1. Lyu et al. KnowTuning: Knowledge-aware Fine-tuning for Large Language Models. 2024. [[arxiv]](https://arxiv.org/abs/2402.11176)
+1. Yang et al. LaCo: Large Language Model Pruning via Layer Collaps. 2024. [[arxiv]](https://arxiv.org/abs/2402.11187)
+1. Bhardwaj et al. Language Models are Homer Simpson! Safety Re-Alignment of Fine-tuned Language Models through Task Arithmetic. 2024. [[arxiv]](https://arxiv.org/abs/2402.11746)
+1. Yang et al. Enhancing Empathetic Response Generation by Augmenting LLMs with Small-scale Empathetic Models. 2024. [[arxiv]](https://arxiv.org/abs/2402.11801)
+1. Yi et al. Generation Meets Verification: Accelerating Large Language Model Inference with Smart Parallel Auto-Correct Decoding. 2024. [[arxiv]](https://arxiv.org/abs/2402.11809)
+1. Cao et al. Head-wise Shareable Attention for Large Language Models. 2024. [[arxiv]](https://arxiv.org/abs/2402.11819)
+1. Zhang et al. Enhancing Multilingual Capabilities of Large Language Models through Self-Distillation from Resource-Rich Languages. 2024. [[arxiv]](https://arxiv.org/abs/2402.12204)
+1. Kim et al. Efficient and Effective Vocabulary Expansion Towards Multilingual Large Language Models. 2024. [[arxiv]](https://arxiv.org/abs/2402.14714)
+1. Yu et al. KIEval: A Knowledge-grounded Interactive Evaluation Framework for Large Language Models. 2024. [[arxiv]](https://arxiv.org/abs/2402.15043)
+1. Huang et al. Key-Point-Driven Data Synthesis with its Enhancement on Mathematical Reasoning. 2024. [[arxiv]](https://arxiv.org/abs/2403.02333)
+1. Duan et al. Negating Negatives: Alignment without Human Positive Samples via Distributional Dispreference Optimization. 2024. [[arxiv]](https://arxiv.org/abs/2403.03419)
+1. Xie and Schwertfeger. Empowering Robotics with Large Language Models: osmAG Map Comprehension with LLMs. 2024. [[arxiv]](https://arxiv.org/abs/2403.08228)
+1. **[StarWhisper](https://github.com/Yu-Yang-Li/StarWhisper)**: A large language model for Astronomy, based on ChatGLM2-6B and Qwen-14B.
+1. **[DISC-LawLLM](https://github.com/FudanDISC/DISC-LawLLM)**: A large language model specialized in Chinese legal domain, based on Baichuan-13B, is capable of retrieving and reasoning on legal knowledge.
+1. **[Sunsimiao](https://github.com/thomas-yanxin/Sunsimiao)**: A large language model specialized in Chinese medical domain, based on Baichuan-7B and ChatGLM-6B.
+1. **[CareGPT](https://github.com/WangRongsheng/CareGPT)**: A series of large language models for Chinese medical domain, based on LLaMA2-7B and Baichuan-13B.
+1. **[MachineMindset](https://github.com/PKU-YuanGroup/Machine-Mindset/)**: A series of MBTI Personality large language models, capable of giving any LLM 16 different personality types based on different datasets and training methods.

 > [!TIP]
 > If you have a project that should be incorporated, please contact via email or create a pull request.
````
````diff
@@ -598,18 +722,19 @@ CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \

 This repository is licensed under the [Apache-2.0 License](LICENSE).

-Please follow the model licenses to use the corresponding model weights: [Baichuan2](https://huggingface.co/baichuan-inc/Baichuan2-7B-Base/blob/main/Community%20License%20for%20Baichuan%202%20Model.pdf) / [BLOOM](https://huggingface.co/spaces/bigscience/license) / [ChatGLM3](https://github.com/THUDM/ChatGLM3/blob/main/MODEL_LICENSE) / [DeepSeek](https://github.com/deepseek-ai/DeepSeek-LLM/blob/main/LICENSE-MODEL) / [Falcon](https://huggingface.co/tiiuae/falcon-180B/blob/main/LICENSE.txt) / [InternLM2](https://github.com/InternLM/InternLM#license) / [LLaMA](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) / [LLaMA-2](https://ai.meta.com/llama/license/) / [Mistral](LICENSE) / [Phi-1.5/2](https://huggingface.co/microsoft/phi-1_5/resolve/main/Research%20License.docx) / [Qwen](https://github.com/QwenLM/Qwen/blob/main/Tongyi%20Qianwen%20LICENSE%20AGREEMENT) / [XVERSE](https://github.com/xverse-ai/XVERSE-13B/blob/main/MODEL_LICENSE.pdf) / [Yi](https://huggingface.co/01-ai/Yi-6B/blob/main/LICENSE) / [Yuan](https://github.com/IEIT-Yuan/Yuan-2.0/blob/main/LICENSE-Yuan)
+Please follow the model licenses to use the corresponding model weights: [Baichuan2](https://huggingface.co/baichuan-inc/Baichuan2-7B-Base/blob/main/Community%20License%20for%20Baichuan%202%20Model.pdf) / [BLOOM](https://huggingface.co/spaces/bigscience/license) / [ChatGLM3](https://github.com/THUDM/ChatGLM3/blob/main/MODEL_LICENSE) / [DeepSeek](https://github.com/deepseek-ai/DeepSeek-LLM/blob/main/LICENSE-MODEL) / [Falcon](https://huggingface.co/tiiuae/falcon-180B/blob/main/LICENSE.txt) / [Gemma](https://ai.google.dev/gemma/terms) / [InternLM2](https://github.com/InternLM/InternLM#license) / [LLaMA](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) / [LLaMA-2](https://ai.meta.com/llama/license/) / [Mistral](LICENSE) / [OLMo](LICENSE) / [Phi-1.5/2](https://huggingface.co/microsoft/phi-1_5/resolve/main/Research%20License.docx) / [Qwen](https://github.com/QwenLM/Qwen/blob/main/Tongyi%20Qianwen%20LICENSE%20AGREEMENT) / [StarCoder2](https://huggingface.co/spaces/bigcode/bigcode-model-license-agreement) / [XVERSE](https://github.com/xverse-ai/XVERSE-13B/blob/main/MODEL_LICENSE.pdf) / [Yi](https://huggingface.co/01-ai/Yi-6B/blob/main/LICENSE) / [Yuan](https://github.com/IEIT-Yuan/Yuan-2.0/blob/main/LICENSE-Yuan)

 ## Citation

 If this work is helpful, please kindly cite as:

 ```bibtex
-@Misc{llama-factory,
-  title = {LLaMA Factory},
-  author = {hiyouga},
-  howpublished = {\url{https://github.com/hiyouga/LLaMA-Factory}},
-  year = {2023}
+@article{zheng2024llamafactory,
+  title={LlamaFactory: Unified Efficient Fine-Tuning of 100+ Language Models},
+  author={Yaowei Zheng and Richong Zhang and Junhao Zhang and Yanhan Ye and Zheyan Luo and Yongqiang Ma},
+  journal={arXiv preprint arXiv:2403.13372},
+  year={2024},
+  url={http://arxiv.org/abs/2403.13372}
 }
 ```

````
212
README_zh.md
212
README_zh.md
@@ -5,27 +5,30 @@
|
|||||||
[](https://github.com/hiyouga/LLaMA-Factory/commits/main)
|
[](https://github.com/hiyouga/LLaMA-Factory/commits/main)
|
||||||
[](https://pypi.org/project/llmtuner/)
|
[](https://pypi.org/project/llmtuner/)
|
||||||
[](https://pypi.org/project/llmtuner/)
|
[](https://pypi.org/project/llmtuner/)
|
||||||
|
[](#使用了-llama-factory-的项目)
|
||||||
[](https://github.com/hiyouga/LLaMA-Factory/pulls)
|
[](https://github.com/hiyouga/LLaMA-Factory/pulls)
|
||||||
[](https://discord.gg/rKfvV9r9FK)
|
[](https://discord.gg/rKfvV9r9FK)
|
||||||
[](https://huggingface.co/spaces/hiyouga/LLaMA-Board)
|
[](https://twitter.com/llamafactory_ai)
|
||||||
[](https://modelscope.cn/studios/hiyouga/LLaMA-Board)
|
[](https://huggingface.co/spaces/hiyouga/LLaMA-Board)
|
||||||
|
[](https://modelscope.cn/studios/hiyouga/LLaMA-Board)
|
||||||
|
[](https://colab.research.google.com/drive/1eRTPn37ltBbYsISy9Aw2NuI2Aq5CQrD9?usp=sharing)
|
||||||
|
|
||||||
👋 加入我们的[微信群](assets/wechat.jpg)。
|
👋 加入我们的[微信群](assets/wechat.jpg)。
|
||||||
|
|
||||||
\[ [English](README.md) | 中文 \]
|
\[ [English](README.md) | 中文 \]
|
||||||
|
|
||||||
## LLaMA Board: 通过一站式网页界面快速上手 LLaMA Factory
|
**微调大模型可以像这样轻松…**
|
||||||
|
|
||||||
通过 **[🤗 Spaces](https://huggingface.co/spaces/hiyouga/LLaMA-Board)** 或 **[ModelScope](https://modelscope.cn/studios/hiyouga/LLaMA-Board)** 预览 LLaMA Board。
|
https://github.com/hiyouga/LLaMA-Factory/assets/16256802/ec36a9dd-37f4-4f72-81bd-d76c6d0a6594
|
||||||
|
|
||||||
使用 `CUDA_VISIBLE_DEVICES=0 python src/train_web.py` 启动 LLaMA Board。(该模式目前仅支持单卡训练)
|
选择你的打开方式:
|
||||||
|
|
||||||
下面是使用单张 GPU 在 10 分钟内更改对话式大型语言模型自我认知的示例。
|
- **Colab**:https://colab.research.google.com/drive/1eRTPn37ltBbYsISy9Aw2NuI2Aq5CQrD9?usp=sharing
|
||||||
|
- **本地机器**:请见[如何使用](#如何使用)
|
||||||
https://github.com/hiyouga/LLaMA-Factory/assets/16256802/6ba60acc-e2e2-4bec-b846-2d88920d5ba1
|
|
||||||
|
|
||||||
## 目录
|
## 目录
|
||||||
|
|
||||||
|
- [项目特色](#项目特色)
|
||||||
- [性能指标](#性能指标)
|
- [性能指标](#性能指标)
|
||||||
- [更新日志](#更新日志)
|
- [更新日志](#更新日志)
|
||||||
- [模型](#模型)
|
- [模型](#模型)
|
||||||
@@ -38,6 +41,16 @@ https://github.com/hiyouga/LLaMA-Factory/assets/16256802/6ba60acc-e2e2-4bec-b846
|
|||||||
- [引用](#引用)
|
- [引用](#引用)
|
||||||
- [致谢](#致谢)
|
- [致谢](#致谢)
|
||||||
|
|
||||||
|
## 项目特色
|
||||||
|
|
||||||
|
- **多种模型**:LLaMA、Mistral、Mixtral-MoE、Qwen、Yi、Gemma、Baichuan、ChatGLM、Phi 等等。
|
||||||
|
- **集成方法**:(增量)预训练、指令监督微调、奖励模型训练、PPO 训练和 DPO 训练。
|
||||||
|
- **多种精度**:32 比特全参数微调、16 比特冻结微调、16 比特 LoRA 微调和基于 AQLM/AWQ/GPTQ/LLM.int8 的 2/4/8 比特 QLoRA 微调。
|
||||||
|
- **先进算法**:GaLore、DoRA、LongLoRA、LLaMA Pro、LoRA+、LoftQ 和 Agent 微调。
|
||||||
|
- **实用技巧**:FlashAttention-2、Unsloth、RoPE scaling、NEFTune 和 rsLoRA。
|
||||||
|
- **实验监控**:LlamaBoard、TensorBoard、Wandb、MLflow 等等。
|
||||||
|
- **极速推理**:基于 vLLM 的 OpenAI 风格 API、浏览器界面和命令行接口。
|
||||||
|
|
||||||
## 性能指标
|
## 性能指标
|
||||||
|
|
||||||
与 ChatGLM 官方的 [P-Tuning](https://github.com/THUDM/ChatGLM2-6B/tree/main/ptuning) 微调相比,LLaMA-Factory 的 LoRA 微调提供了 **3.7 倍**的加速比,同时在广告文案生成任务上取得了更高的 Rouge 分数。结合 4 比特量化技术,LLaMA-Factory 的 QLoRA 微调进一步降低了 GPU 显存消耗。
|
与 ChatGLM 官方的 [P-Tuning](https://github.com/THUDM/ChatGLM2-6B/tree/main/ptuning) 微调相比,LLaMA-Factory 的 LoRA 微调提供了 **3.7 倍**的加速比,同时在广告文案生成任务上取得了更高的 Rouge 分数。结合 4 比特量化技术,LLaMA-Factory 的 QLoRA 微调进一步降低了 GPU 显存消耗。
|
||||||
@@ -55,15 +68,27 @@ https://github.com/hiyouga/LLaMA-Factory/assets/16256802/6ba60acc-e2e2-4bec-b846
|
|||||||
|
|
||||||
## 更新日志
|
## 更新日志
|
||||||
|
|
||||||
[24/02/15] 我们支持了 [LLaMA Pro](https://github.com/TencentARC/LLaMA-Pro) 提出的**块扩展**方法。详细用法请参照 `tests/llama_pro.py`。
|
[24/03/21] 我们的论文 "[LlamaFactory: Unified Efficient Fine-Tuning of 100+ Language Models](https://arxiv.org/abs/2403.13372)" 可在 arXiv 上查看!
|
||||||
|
|
||||||
|
[24/03/20] 我们支持了能在 2x24GB GPU 上微调 70B 模型的 **FSDP+QLoRA**。详细用法请参照 `examples/fsdp_qlora`。
|
||||||
|
|
||||||
|
[24/03/13] 我们支持了 **[LoRA+](https://arxiv.org/abs/2402.12354)**。请使用 `loraplus_lr_ratio=16.0` 参数开启 LoRA+ 方法。
|
||||||
|
|
||||||
|
[24/03/07] 我们支持了梯度低秩投影(**[GaLore](https://arxiv.org/abs/2403.03507)**)算法。请使用 `--use_galore` 参数切换显存高效的优化器。
|
||||||
|
|
||||||
|
[24/03/07] 我们集成了 **[vLLM](https://github.com/vllm-project/vllm)** 以实现极速并发推理。请使用 `--infer_backend vllm` 来获得 **270%** 的推理速度。(尚不支持 LoRA,请先合并权重。)
|
||||||
|
|
||||||
|
<details><summary>展开日志</summary>
|
||||||
|
|
||||||
|
[24/02/28] 我们支持了 **[DoRA](https://arxiv.org/abs/2402.09353)** 微调。请使用 `--use_dora` 参数进行 DoRA 微调。
|
||||||
|
|
||||||
|
[24/02/15] 我们支持了 [LLaMA Pro](https://github.com/TencentARC/LLaMA-Pro) 提出的**块扩展**方法。详细用法请参照 `examples/extras/llama_pro`。
|
||||||
|
|
||||||
[24/02/05] Qwen1.5(Qwen2 测试版)系列模型已在 LLaMA-Factory 中实现微调支持。详情请查阅该[博客页面](https://qwenlm.github.io/zh/blog/qwen1.5/)。
|
[24/02/05] Qwen1.5(Qwen2 测试版)系列模型已在 LLaMA-Factory 中实现微调支持。详情请查阅该[博客页面](https://qwenlm.github.io/zh/blog/qwen1.5/)。
|
||||||
|
|
||||||
[24/01/18] 我们针对绝大多数模型实现了 **Agent 微调**,微调时指定 `--dataset glaive_toolcall` 即可使模型获得工具调用能力。
|
[24/01/18] 我们针对绝大多数模型实现了 **Agent 微调**,微调时指定 `--dataset glaive_toolcall` 即可使模型获得工具调用能力。
|
||||||
|
|
||||||
<details><summary>展开日志</summary>
|
[23/12/23] 我们针对 LLaMA, Mistral 和 Yi 模型支持了 **[unsloth](https://github.com/unslothai/unsloth)** 的 LoRA 训练加速。请使用 `--use_unsloth` 参数启用 unsloth 优化。该方法可提供 **170%** 的训练速度,详情请查阅[此页面](https://github.com/hiyouga/LLaMA-Factory/wiki/Performance-comparison)。
|
||||||
|
|
||||||
[23/12/23] 我们针对 LLaMA, Mistral 和 Yi 模型支持了 **[unsloth](https://github.com/unslothai/unsloth)** 的 LoRA 训练加速。请使用 `--use_unsloth` 参数启用 unsloth 优化。该方法可提供 1.7 倍的训练速度,详情请查阅[此页面](https://github.com/hiyouga/LLaMA-Factory/wiki/Performance-comparison)。
|
|
||||||
|
|
||||||
[23/12/12] 我们支持了微调最新的混合专家模型 **[Mixtral 8x7B](https://huggingface.co/mistralai/Mixtral-8x7B-v0.1)**。硬件需求请查阅[此处](#硬件依赖)。
|
[23/12/12] 我们支持了微调最新的混合专家模型 **[Mixtral 8x7B](https://huggingface.co/mistralai/Mixtral-8x7B-v0.1)**。硬件需求请查阅[此处](#硬件依赖)。
|
||||||
|
|
||||||
@@ -107,16 +132,19 @@ https://github.com/hiyouga/LLaMA-Factory/assets/16256802/6ba60acc-e2e2-4bec-b846
|
|||||||
| [ChatGLM3](https://huggingface.co/THUDM/chatglm3-6b) | 6B | query_key_value | chatglm3 |
|
| [ChatGLM3](https://huggingface.co/THUDM/chatglm3-6b) | 6B | query_key_value | chatglm3 |
|
||||||
| [DeepSeek (MoE)](https://huggingface.co/deepseek-ai) | 7B/16B/67B | q_proj,v_proj | deepseek |
|
| [DeepSeek (MoE)](https://huggingface.co/deepseek-ai) | 7B/16B/67B | q_proj,v_proj | deepseek |
|
||||||
| [Falcon](https://huggingface.co/tiiuae) | 7B/40B/180B | query_key_value | falcon |
|
| [Falcon](https://huggingface.co/tiiuae) | 7B/40B/180B | query_key_value | falcon |
|
||||||
|
| [Gemma](https://huggingface.co/google) | 2B/7B | q_proj,v_proj | gemma |
|
||||||
| [InternLM2](https://huggingface.co/internlm) | 7B/20B | wqkv | intern2 |
|
| [InternLM2](https://huggingface.co/internlm) | 7B/20B | wqkv | intern2 |
|
||||||
| [LLaMA](https://github.com/facebookresearch/llama) | 7B/13B/33B/65B | q_proj,v_proj | - |
|
| [LLaMA](https://github.com/facebookresearch/llama) | 7B/13B/33B/65B | q_proj,v_proj | - |
|
||||||
| [LLaMA-2](https://huggingface.co/meta-llama) | 7B/13B/70B | q_proj,v_proj | llama2 |
|
| [LLaMA-2](https://huggingface.co/meta-llama) | 7B/13B/70B | q_proj,v_proj | llama2 |
|
||||||
| [Mistral](https://huggingface.co/mistralai) | 7B | q_proj,v_proj | mistral |
|
| [Mistral](https://huggingface.co/mistralai) | 7B | q_proj,v_proj | mistral |
|
||||||
| [Mixtral](https://huggingface.co/mistralai) | 8x7B | q_proj,v_proj | mistral |
|
| [Mixtral](https://huggingface.co/mistralai) | 8x7B | q_proj,v_proj | mistral |
|
||||||
|
| [OLMo](https://huggingface.co/allenai) | 1B/7B | att_proj | olmo |
|
||||||
| [Phi-1.5/2](https://huggingface.co/microsoft) | 1.3B/2.7B | q_proj,v_proj | - |
|
| [Phi-1.5/2](https://huggingface.co/microsoft) | 1.3B/2.7B | q_proj,v_proj | - |
|
||||||
| [Qwen](https://huggingface.co/Qwen) | 1.8B/7B/14B/72B | c_attn | qwen |
|
| [Qwen](https://huggingface.co/Qwen) | 1.8B/7B/14B/72B | c_attn | qwen |
|
||||||
| [Qwen1.5](https://huggingface.co/Qwen) | 0.5B/1.8B/4B/7B/14B/72B | q_proj,v_proj | qwen |
|
| [Qwen1.5](https://huggingface.co/Qwen) | 0.5B/1.8B/4B/7B/14B/72B | q_proj,v_proj | qwen |
|
||||||
|
| [StarCoder2](https://huggingface.co/bigcode) | 3B/7B/15B | q_proj,v_proj | - |
|
||||||
| [XVERSE](https://huggingface.co/xverse) | 7B/13B/65B | q_proj,v_proj | xverse |
|
| [XVERSE](https://huggingface.co/xverse) | 7B/13B/65B | q_proj,v_proj | xverse |
|
||||||
| [Yi](https://huggingface.co/01-ai) | 6B/34B | q_proj,v_proj | yi |
|
| [Yi](https://huggingface.co/01-ai) | 6B/9B/34B | q_proj,v_proj | yi |
|
||||||
| [Yuan](https://huggingface.co/IEITYuan) | 2B/51B/102B | q_proj,v_proj | yuan |
|
| [Yuan](https://huggingface.co/IEITYuan) | 2B/51B/102B | q_proj,v_proj | yuan |
|
||||||
|
|
||||||
> [!NOTE]
|
> [!NOTE]
|
||||||
@@ -126,6 +154,8 @@ https://github.com/hiyouga/LLaMA-Factory/assets/16256802/6ba60acc-e2e2-4bec-b846
|
|||||||
|
|
||||||
项目所支持模型的完整列表请参阅 [constants.py](src/llmtuner/extras/constants.py)。
|
项目所支持模型的完整列表请参阅 [constants.py](src/llmtuner/extras/constants.py)。
|
||||||
|
|
||||||
|
您也可以在 [template.py](src/llmtuner/data/template.py) 中添加自己的对话模板。
|
||||||
|
|
||||||
## 训练方法
|
## 训练方法
|
||||||
|
|
||||||
| 方法 | 全参数训练 | 部分参数训练 | LoRA | QLoRA |
|
| 方法 | 全参数训练 | 部分参数训练 | LoRA | QLoRA |
|
||||||
@@ -192,6 +222,7 @@ https://github.com/hiyouga/LLaMA-Factory/assets/16256802/6ba60acc-e2e2-4bec-b846
- [LMSYS Chat 1M (en)](https://huggingface.co/datasets/lmsys/lmsys-chat-1m)
- [Evol Instruct V2 (en)](https://huggingface.co/datasets/WizardLM/WizardLM_evol_instruct_V2_196k)
- [Glaive Function Calling V2 (en)](https://huggingface.co/datasets/glaiveai/glaive-function-calling-v2)
- [Cosmopedia (en)](https://huggingface.co/datasets/HuggingFaceTB/cosmopedia)
- [Open Assistant (de)](https://huggingface.co/datasets/mayflowergmbh/oasst_de)
- [Dolly 15k (de)](https://huggingface.co/datasets/mayflowergmbh/dolly-15k_de)
- [Alpaca GPT4 (de)](https://huggingface.co/datasets/mayflowergmbh/alpaca-gpt4_de)

@@ -209,6 +240,7 @@ https://github.com/hiyouga/LLaMA-Factory/assets/16256802/6ba60acc-e2e2-4bec-b846

- [HH-RLHF (en)](https://huggingface.co/datasets/Anthropic/hh-rlhf)
- [Open Assistant (multilingual)](https://huggingface.co/datasets/OpenAssistant/oasst1)
- [GPT-4 Generated Data (en&zh)](https://github.com/Instruction-Tuning-with-GPT-4/GPT-4-LLM)
- [Orca DPO (en)](https://huggingface.co/datasets/Intel/orca_dpo_pairs)
- [Nectar (en)](https://huggingface.co/datasets/berkeley-nest/Nectar)
- [Orca DPO (de)](https://huggingface.co/datasets/mayflowergmbh/intel_orca_dpo_pairs_de)
@@ -225,22 +257,37 @@ huggingface-cli login

## Requirements

| Mandatory    | Minimum | Recommend |
| ------------ | ------- | --------- |
| python       | 3.8     | 3.10      |
| torch        | 1.13.1  | 2.2.0     |
| transformers | 4.37.2  | 4.39.1    |
| datasets     | 2.14.3  | 2.17.1    |
| accelerate   | 0.27.2  | 0.28.0    |
| peft         | 0.9.0   | 0.10.0    |
| trl          | 0.8.1   | 0.8.1     |

| Optional     | Minimum | Recommend |
| ------------ | ------- | --------- |
| CUDA         | 11.6    | 12.2      |
| deepspeed    | 0.10.0  | 0.14.0    |
| bitsandbytes | 0.39.0  | 0.43.0    |
| flash-attn   | 2.3.0   | 2.5.6     |
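As a quick way to satisfy the mandatory requirements above, the recommended versions can be installed directly with pip (a sketch only; `requirements.txt` remains the authoritative source, and CUDA builds of torch may need the appropriate index URL):

```bash
pip install "torch>=2.2.0" "transformers>=4.39.1" "datasets>=2.17.1" \
    "accelerate>=0.28.0" "peft>=0.10.0" "trl>=0.8.1"
```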
### Hardware Requirement

\* *estimated*

| Method | Bits |  7B   |  13B  |  30B  |  70B   |  8x7B |
| ------ | ---- | ----- | ----- | ----- | ------ | ----- |
| Full   | AMP  | 120GB | 240GB | 600GB | 1200GB | 900GB |
| Full   | 16   | 60GB  | 120GB | 300GB | 600GB  | 400GB |
| GaLore | 16   | 16GB  | 32GB  | 64GB  | 160GB  | 120GB |
| Freeze | 16   | 20GB  | 40GB  | 80GB  | 200GB  | 160GB |
| LoRA   | 16   | 16GB  | 32GB  | 64GB  | 160GB  | 120GB |
| QLoRA  | 8    | 10GB  | 20GB  | 40GB  | 80GB   | 60GB  |
| QLoRA  | 4    | 6GB   | 12GB  | 24GB  | 48GB   | 30GB  |
| QLoRA  | 2    | 4GB   | 8GB   | 16GB  | 24GB   | 18GB  |
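If the 16-bit rows above do not fit on your hardware, 4-bit QLoRA is usually the first fallback. A minimal sketch, mirroring the QLoRA example scripts later in this diff (model, dataset and output path are placeholders):

```bash
CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
    --stage sft \
    --do_train \
    --model_name_or_path meta-llama/Llama-2-7b-hf \
    --dataset alpaca_gpt4_en \
    --template default \
    --finetuning_type lora \
    --lora_target q_proj,v_proj \
    --quantization_bit 4 \
    --output_dir path_to_qlora_checkpoint \
    --fp16
```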
## Getting Started

@@ -261,12 +308,14 @@ cd LLaMA-Factory

pip install -r requirements.txt
```

To enable QLoRA on the Windows platform, you need to install a pre-built version of the `bitsandbytes` library, which supports CUDA 11.1 to 12.2. Please select the appropriate [release](https://github.com/jllllll/bitsandbytes-windows-webui/releases/tag/wheels) according to your CUDA version.

```bash
pip install https://github.com/jllllll/bitsandbytes-windows-webui/releases/download/wheels/bitsandbytes-0.41.2.post2-py3-none-win_amd64.whl
```

To enable FlashAttention-2 on the Windows platform, you need to install the pre-built `flash-attn` library, which supports CUDA 12.1 to 12.2. Please download the corresponding version from [flash-attention](https://github.com/bdashore3/flash-attention/releases) based on your requirements.

### Use ModelScope Hub (optional)

If you have trouble downloading models and datasets from Hugging Face, you can use ModelScope as follows.
@@ -280,7 +329,7 @@ export USE_MODELSCOPE_HUB=1 # Windows 使用 `set USE_MODELSCOPE_HUB=1`
```bash
CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
    --model_name_or_path modelscope/Llama-2-7b-ms \
    ... # arguments same as below
```

LLaMA Board also supports downloading models and datasets from ModelScope.
@@ -294,6 +343,12 @@ CUDA_VISIBLE_DEVICES=0 USE_MODELSCOPE_HUB=1 python src/train_web.py
> [!IMPORTANT]
> If you want to train models on multiple GPUs, please refer to the [distributed training with multiple GPUs](#多-gpu-分布式训练) section.

#### LLaMA Board GUI

```bash
CUDA_VISIBLE_DEVICES=0 python src/train_web.py
```

#### Pre-Training

```bash
@@ -360,7 +415,7 @@ CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
    --lr_scheduler_type cosine \
    --logging_steps 10 \
    --save_steps 1000 \
    --learning_rate 1e-5 \
    --num_train_epochs 1.0 \
    --plot_loss \
    --fp16

@@ -394,6 +449,9 @@ CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \

    --fp16
```

> [!TIP]
> Use `--adapter_name_or_path path_to_sft_checkpoint,path_to_ppo_checkpoint` to infer the fine-tuned model.

> [!WARNING]
> Use `--per_device_train_batch_size=1` for LLaMA-2 models in fp16 PPO training.
@@ -422,19 +480,24 @@ CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
    --fp16
```

> [!TIP]
> Use `--adapter_name_or_path path_to_sft_checkpoint,path_to_dpo_checkpoint` to infer the fine-tuned model.

### Distributed Training with Multiple GPUs

#### Use Huggingface Accelerate

```bash
accelerate launch --config_file config.yaml src/train_bash.py \
    --ddp_timeout 180000000 \
    ... # arguments same as above
```

<details><summary>Example config.yaml for LoRA training with Accelerate</summary>

```yaml
compute_environment: LOCAL_MACHINE
debug: false
distributed_type: MULTI_GPU
downcast_bf16: 'no'
gpu_ids: all

@@ -453,15 +516,19 @@ use_cpu: false

</details>

> [!TIP]
> We recommend using Accelerate for LoRA training.

#### Use DeepSpeed

```bash
deepspeed --num_gpus 8 src/train_bash.py \
    --deepspeed ds_config.json \
    --ddp_timeout 180000000 \
    ... # arguments same as above
```

<details><summary>Example ds_config.json for full-parameter training with DeepSpeed ZeRO-2</summary>

```json
{
@@ -473,29 +540,36 @@ deepspeed --num_gpus 8 --master_port=9901 src/train_bash.py \
  "fp16": {
    "enabled": "auto",
    "loss_scale": 0,
    "loss_scale_window": 1000,
    "initial_scale_power": 16,
    "hysteresis": 2,
    "min_loss_scale": 1
  },
  "bf16": {
    "enabled": "auto"
  },
  "zero_optimization": {
    "stage": 2,
    "allgather_partitions": true,
    "allgather_bucket_size": 5e8,
    "overlap_comm": true,
    "reduce_scatter": true,
    "reduce_bucket_size": 5e8,
    "contiguous_gradients": true,
    "round_robin_gradients": true
  }
}
```

</details>

> [!TIP]
> See [examples](examples) for more training scripts.
### Merge LoRA weights and export model

```bash
CUDA_VISIBLE_DEVICES=0 python src/export_model.py \
    --model_name_or_path path_to_llama_model \
    --adapter_name_or_path path_to_checkpoint \
    --template default \
@@ -509,12 +583,14 @@ python src/export_model.py \
> Merging LoRA weights into a quantized model is not supported.

> [!TIP]
> Use `--model_name_or_path path_to_export` solely to load the exported model.
>
> After merging the LoRA weights, you can quantize the model with AutoGPTQ by using `--export_quantization_bit 4` and `--export_quantization_dataset data/c4_demo.json`.
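The exported model can then be loaded like any regular checkpoint, for example with the CLI demo described below (a sketch; the path is a placeholder):

```bash
CUDA_VISIBLE_DEVICES=0 python src/cli_demo.py \
    --model_name_or_path path_to_export \
    --template default
```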
### Inference with OpenAI-style API

```bash
CUDA_VISIBLE_DEVICES=0 API_PORT=8000 python src/api_demo.py \
    --model_name_or_path path_to_llama_model \
    --adapter_name_or_path path_to_checkpoint \
    --template default \
@@ -524,20 +600,20 @@ python src/api_demo.py \
> [!TIP]
> Visit `http://localhost:8000/docs` for the API documentation.
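Since the server follows the OpenAI API format, a quick smoke test can be done with curl (a sketch; the `/v1/chat/completions` route and the `model` field are assumptions based on the OpenAI-style interface):

```bash
curl http://localhost:8000/v1/chat/completions \
  -H "Content-Type: application/json" \
  -d '{"model": "default", "messages": [{"role": "user", "content": "Hello!"}]}'
```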
### Inference with command line

```bash
CUDA_VISIBLE_DEVICES=0 python src/cli_demo.py \
    --model_name_or_path path_to_llama_model \
    --adapter_name_or_path path_to_checkpoint \
    --template default \
    --finetuning_type lora
```

### Inference with web browser

```bash
CUDA_VISIBLE_DEVICES=0 python src/web_demo.py \
    --model_name_or_path path_to_llama_model \
    --adapter_name_or_path path_to_checkpoint \
    --template default \
@@ -571,7 +647,7 @@ CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
    --template default \
    --finetuning_type lora \
    --output_dir path_to_predict_result \
    --per_device_eval_batch_size 1 \
    --max_samples 100 \
    --predict_with_generate \
    --fp16
@@ -585,11 +661,32 @@ CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
## Projects using LLaMA Factory

1. Wang et al. ESRL: Efficient Sampling-based Reinforcement Learning for Sequence Generation. 2023. [[arxiv]](https://arxiv.org/abs/2308.02223)
1. Yu et al. Open, Closed, or Small Language Models for Text Classification? 2023. [[arxiv]](https://arxiv.org/abs/2308.10092)
1. Wang et al. UbiPhysio: Support Daily Functioning, Fitness, and Rehabilitation with Action Understanding and Feedback in Natural Language. 2023. [[arxiv]](https://arxiv.org/abs/2308.10526)
1. Luceri et al. Leveraging Large Language Models to Detect Influence Campaigns in Social Media. 2023. [[arxiv]](https://arxiv.org/abs/2311.07816)
1. Zhang et al. Alleviating Hallucinations of Large Language Models through Induced Hallucinations. 2023. [[arxiv]](https://arxiv.org/abs/2312.15710)
1. Wang et al. Know Your Needs Better: Towards Structured Understanding of Marketer Demands with Analogical Reasoning Augmented LLMs. 2024. [[arxiv]](https://arxiv.org/abs/2401.04319)
1. Wang et al. CANDLE: Iterative Conceptualization and Instantiation Distillation from Large Language Models for Commonsense Reasoning. 2024. [[arxiv]](https://arxiv.org/abs/2401.07286)
1. Choi et al. FACT-GPT: Fact-Checking Augmentation via Claim Matching with LLMs. 2024. [[arxiv]](https://arxiv.org/abs/2402.05904)
1. Zhang et al. AutoMathText: Autonomous Data Selection with Language Models for Mathematical Texts. 2024. [[arxiv]](https://arxiv.org/abs/2402.07625)
1. Lyu et al. KnowTuning: Knowledge-aware Fine-tuning for Large Language Models. 2024. [[arxiv]](https://arxiv.org/abs/2402.11176)
1. Yang et al. LaCo: Large Language Model Pruning via Layer Collapse. 2024. [[arxiv]](https://arxiv.org/abs/2402.11187)
1. Bhardwaj et al. Language Models are Homer Simpson! Safety Re-Alignment of Fine-tuned Language Models through Task Arithmetic. 2024. [[arxiv]](https://arxiv.org/abs/2402.11746)
1. Yang et al. Enhancing Empathetic Response Generation by Augmenting LLMs with Small-scale Empathetic Models. 2024. [[arxiv]](https://arxiv.org/abs/2402.11801)
1. Yi et al. Generation Meets Verification: Accelerating Large Language Model Inference with Smart Parallel Auto-Correct Decoding. 2024. [[arxiv]](https://arxiv.org/abs/2402.11809)
1. Cao et al. Head-wise Shareable Attention for Large Language Models. 2024. [[arxiv]](https://arxiv.org/abs/2402.11819)
1. Zhang et al. Enhancing Multilingual Capabilities of Large Language Models through Self-Distillation from Resource-Rich Languages. 2024. [[arxiv]](https://arxiv.org/abs/2402.12204)
1. Kim et al. Efficient and Effective Vocabulary Expansion Towards Multilingual Large Language Models. 2024. [[arxiv]](https://arxiv.org/abs/2402.14714)
1. Yu et al. KIEval: A Knowledge-grounded Interactive Evaluation Framework for Large Language Models. 2024. [[arxiv]](https://arxiv.org/abs/2402.15043)
1. Huang et al. Key-Point-Driven Data Synthesis with its Enhancement on Mathematical Reasoning. 2024. [[arxiv]](https://arxiv.org/abs/2403.02333)
1. Duan et al. Negating Negatives: Alignment without Human Positive Samples via Distributional Dispreference Optimization. 2024. [[arxiv]](https://arxiv.org/abs/2403.03419)
1. Xie and Schwertfeger. Empowering Robotics with Large Language Models: osmAG Map Comprehension with LLMs. 2024. [[arxiv]](https://arxiv.org/abs/2403.08228)
1. **[StarWhisper](https://github.com/Yu-Yang-Li/StarWhisper)**: A large language model for astronomy, fine-tuned on ChatGLM2-6B and Qwen-14B with astronomical data.
1. **[DISC-LawLLM](https://github.com/FudanDISC/DISC-LawLLM)**: A large language model for the Chinese legal domain, fine-tuned from Baichuan-13B, with capabilities in legal reasoning and knowledge retrieval.
1. **[Sunsimiao](https://github.com/thomas-yanxin/Sunsimiao)**: A Chinese medical large language model, fine-tuned on Baichuan-7B and ChatGLM-6B with Chinese medical data.
1. **[CareGPT](https://github.com/WangRongsheng/CareGPT)**: A medical large language model project, fine-tuned on LLaMA2-7B and Baichuan-13B with Chinese medical data.
1. **[MachineMindset](https://github.com/PKU-YuanGroup/Machine-Mindset/)**: A series of MBTI-personality large language models, giving any LLM one of the 16 personality types through tailored datasets and training methods.

> [!TIP]
> If you have a project that you wish to be added to the list above, please contact us via email or create a pull request.
@@ -598,18 +695,19 @@ CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
The code in this repository is open-sourced under the [Apache-2.0](LICENSE) license.

Please follow the corresponding model licenses when using the model weights: [Baichuan2](https://huggingface.co/baichuan-inc/Baichuan2-7B-Base/blob/main/Community%20License%20for%20Baichuan%202%20Model.pdf) / [BLOOM](https://huggingface.co/spaces/bigscience/license) / [ChatGLM3](https://github.com/THUDM/ChatGLM3/blob/main/MODEL_LICENSE) / [DeepSeek](https://github.com/deepseek-ai/DeepSeek-LLM/blob/main/LICENSE-MODEL) / [Falcon](https://huggingface.co/tiiuae/falcon-180B/blob/main/LICENSE.txt) / [Gemma](https://ai.google.dev/gemma/terms) / [InternLM2](https://github.com/InternLM/InternLM#license) / [LLaMA](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) / [LLaMA-2](https://ai.meta.com/llama/license/) / [Mistral](LICENSE) / [OLMo](LICENSE) / [Phi-1.5/2](https://huggingface.co/microsoft/phi-1_5/resolve/main/Research%20License.docx) / [Qwen](https://github.com/QwenLM/Qwen/blob/main/Tongyi%20Qianwen%20LICENSE%20AGREEMENT) / [StarCoder2](https://huggingface.co/spaces/bigcode/bigcode-model-license-agreement) / [XVERSE](https://github.com/xverse-ai/XVERSE-13B/blob/main/MODEL_LICENSE.pdf) / [Yi](https://huggingface.co/01-ai/Yi-6B/blob/main/LICENSE) / [Yuan](https://github.com/IEIT-Yuan/Yuan-2.0/blob/main/LICENSE-Yuan)

## Citation

If this work is helpful, please kindly cite as:

```bibtex
@article{zheng2024llamafactory,
  title={LlamaFactory: Unified Efficient Fine-Tuning of 100+ Language Models},
  author={Yaowei Zheng and Richong Zhang and Junhao Zhang and Yanhan Ye and Zheyan Luo and Yongqiang Ma},
  journal={arXiv preprint arXiv:2403.13372},
  year={2024},
  url={http://arxiv.org/abs/2403.13372}
}
```
@@ -1,7 +1,10 @@
import os
import json
import datasets


_HF_ENDPOINT = os.getenv("HF_ENDPOINT", "https://huggingface.co")

_DESCRIPTION = "BELLE multiturn chat dataset."

_CITATION = """\

@@ -13,9 +16,9 @@ _CITATION = """\

}
"""

_HOMEPAGE = "{}/datasets/BelleGroup/multiturn_chat_0.8M".format(_HF_ENDPOINT)
_LICENSE = "gpl-3.0"
_URL = "{}/datasets/BelleGroup/multiturn_chat_0.8M/resolve/main/multiturn_chat_0.8M.json".format(_HF_ENDPOINT)


class BelleMultiturn(datasets.GeneratorBasedBuilder):
@@ -1,6 +1,6 @@
import json
import datasets
from typing import Any, Dict, Generator, List, Tuple


_DESCRIPTION = "An example of dataset."

@@ -40,7 +40,7 @@ class ExampleDataset(datasets.GeneratorBasedBuilder):

            )
        ]

    def _generate_examples(self, filepath: str) -> Generator[Tuple[int, Dict[str, Any]], None, None]:
        example_dataset = json.load(open(filepath, "r", encoding="utf-8"))
        for key, example in enumerate(example_dataset):
            yield key, example
@@ -1,13 +1,14 @@
import os
import json
import datasets
from typing import List

_HF_ENDPOINT = os.getenv("HF_ENDPOINT", "https://huggingface.co")

_DESCRIPTION = "Human preference data about helpfulness and harmlessness."
_CITATION = ""
_HOMEPAGE = "{}/datasets/Anthropic/hh-rlhf".format(_HF_ENDPOINT)
_LICENSE = "mit"
_URL = "{}/datasets/Anthropic/hh-rlhf/resolve/main/".format(_HF_ENDPOINT)
_URLS = {
    "train": [
        _URL + "harmless-base/train.jsonl.gz",
1 data/orca_rlhf.json.REMOVED.git-id Normal file
@@ -0,0 +1 @@

736bcedea2b24a1414765c6d69cbdafaea839f3c
@@ -1,7 +1,9 @@
import os
import json
import datasets
from typing import List


_HF_ENDPOINT = os.getenv("HF_ENDPOINT", "https://huggingface.co")


_DESCRIPTION = "UltraChat: Large-scale, Informative, and Diverse Multi-round Dialogue Data."

@@ -16,9 +18,9 @@ _CITATION = """\

}
"""

_HOMEPAGE = "{}/datasets/stingning/ultrachat".format(_HF_ENDPOINT)
_LICENSE = "cc-by-nc-4.0"
_BASE_DATA_URL = "{}/datasets/stingning/ultrachat/resolve/main/train_{{idx}}.jsonl".format(_HF_ENDPOINT)


class UltraChat(datasets.GeneratorBasedBuilder):
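These loaders now read the hub address from the `HF_ENDPOINT` environment variable, so a mirror can be selected without touching the code. A sketch (the mirror URL is only an illustration):

```bash
export HF_ENDPOINT=https://hf-mirror.com  # any Hugging Face-compatible mirror
CUDA_VISIBLE_DEVICES=0 python src/train_bash.py ... # remaining arguments unchanged
```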
23 docker-compose.yml Normal file
@@ -0,0 +1,23 @@

version: '3.8'

services:
  llama-factory:
    build:
      dockerfile: Dockerfile
      context: .
    container_name: llama_factory
    volumes:
      - ./hf_cache:/root/.cache/huggingface/
      - ./data:/app/data
      - ./output:/app/output
    ports:
      - "7860:7860"
    ipc: host
    deploy:
      resources:
        reservations:
          devices:
          - driver: nvidia
            count: "all"
            capabilities: [gpu]
    restart: unless-stopped
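With this compose file in place, the container can be managed with the standard Compose workflow (a sketch; it assumes a CUDA-capable host with the NVIDIA container toolkit installed):

```bash
docker compose up -d     # build the image and start LLaMA Board on port 7860
docker compose logs -f   # follow the logs
docker compose down      # stop and remove the container
```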
25 examples/accelerate/fsdp_config.yaml Normal file
@@ -0,0 +1,25 @@

compute_environment: LOCAL_MACHINE
debug: false
distributed_type: FSDP
downcast_bf16: 'no'
fsdp_config:
  fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
  fsdp_backward_prefetch: BACKWARD_PRE
  fsdp_cpu_ram_efficient_loading: true
  fsdp_forward_prefetch: false
  fsdp_offload_params: true
  fsdp_sharding_strategy: FULL_SHARD
  fsdp_state_dict_type: SHARDED_STATE_DICT
  fsdp_sync_module_states: true
  fsdp_use_orig_params: false
machine_rank: 0
main_training_function: main
mixed_precision: fp16
num_machines: 1
num_processes: 2
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false
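This config is consumed by `accelerate launch`. A minimal sketch of how it is used (training arguments abbreviated; see the FSDP+QLoRA example script later in this diff):

```bash
CUDA_VISIBLE_DEVICES=0,1 accelerate launch \
    --config_file examples/accelerate/fsdp_config.yaml \
    src/train_bash.py \
    ... # training arguments as in the example scripts
```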
18 examples/accelerate/master_config.yaml Normal file
@@ -0,0 +1,18 @@

compute_environment: LOCAL_MACHINE
debug: false
distributed_type: MULTI_GPU
downcast_bf16: 'no'
gpu_ids: all
machine_rank: 0
main_process_ip: 192.168.0.1
main_process_port: 29555
main_training_function: main
mixed_precision: fp16
num_machines: 2
num_processes: 16
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false
16 examples/accelerate/single_config.yaml Normal file
@@ -0,0 +1,16 @@

compute_environment: LOCAL_MACHINE
debug: false
distributed_type: MULTI_GPU
downcast_bf16: 'no'
gpu_ids: all
machine_rank: 0
main_training_function: main
mixed_precision: fp16
num_machines: 1
num_processes: 4
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false
18 examples/accelerate/slave_config.yaml Normal file
@@ -0,0 +1,18 @@

compute_environment: LOCAL_MACHINE
debug: false
distributed_type: MULTI_GPU
downcast_bf16: 'no'
gpu_ids: all
machine_rank: 1
main_process_ip: 192.168.0.1
main_process_port: 29555
main_training_function: main
mixed_precision: fp16
num_machines: 2
num_processes: 16
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false
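The master/slave pair is intended for a two-machine run: launch the same training command on each node, pointing rank 0 at `master_config.yaml` and rank 1 at `slave_config.yaml` (a sketch; the IP, port and process counts must match your cluster):

```bash
# on the master node (machine_rank: 0)
accelerate launch --config_file examples/accelerate/master_config.yaml src/train_bash.py ...

# on the worker node (machine_rank: 1)
accelerate launch --config_file examples/accelerate/slave_config.yaml src/train_bash.py ...
```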
31
examples/extras/galore/adamw.sh
Normal file
31
examples/extras/galore/adamw.sh
Normal file
@@ -0,0 +1,31 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
|
||||||
|
CUDA_VISIBLE_DEVICES=0 python ../../../src/train_bash.py \
|
||||||
|
--stage sft \
|
||||||
|
--do_train \
|
||||||
|
--model_name_or_path meta-llama/Llama-2-7b-hf \
|
||||||
|
--dataset alpaca_gpt4_en,glaive_toolcall \
|
||||||
|
--dataset_dir ../../../data \
|
||||||
|
--template default \
|
||||||
|
--finetuning_type full \
|
||||||
|
--output_dir ../../../saves/LLaMA2-7B/galore/sft \
|
||||||
|
--overwrite_cache \
|
||||||
|
--overwrite_output_dir \
|
||||||
|
--cutoff_len 1024 \
|
||||||
|
--preprocessing_num_workers 16 \
|
||||||
|
--per_device_train_batch_size 1 \
|
||||||
|
--per_device_eval_batch_size 1 \
|
||||||
|
--gradient_accumulation_steps 1 \
|
||||||
|
--lr_scheduler_type cosine \
|
||||||
|
--logging_steps 10 \
|
||||||
|
--warmup_steps 20 \
|
||||||
|
--save_steps 100 \
|
||||||
|
--eval_steps 100 \
|
||||||
|
--evaluation_strategy steps \
|
||||||
|
--load_best_model_at_end \
|
||||||
|
--learning_rate 5e-5 \
|
||||||
|
--num_train_epochs 3.0 \
|
||||||
|
--max_samples 3000 \
|
||||||
|
--val_size 0.1 \
|
||||||
|
--plot_loss \
|
||||||
|
--fp16
|
||||||
32
examples/extras/galore/adamw_8bit_bf16.sh
Normal file
32
examples/extras/galore/adamw_8bit_bf16.sh
Normal file
@@ -0,0 +1,32 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
|
||||||
|
CUDA_VISIBLE_DEVICES=0 python ../../../src/train_bash.py \
|
||||||
|
--stage sft \
|
||||||
|
--do_train \
|
||||||
|
--model_name_or_path meta-llama/Llama-2-7b-hf \
|
||||||
|
--dataset alpaca_gpt4_en,glaive_toolcall \
|
||||||
|
--dataset_dir ../../../data \
|
||||||
|
--template default \
|
||||||
|
--finetuning_type full \
|
||||||
|
--optim adamw_8bit \
|
||||||
|
--output_dir ../../../saves/LLaMA2-7B/galore/sft \
|
||||||
|
--overwrite_cache \
|
||||||
|
--overwrite_output_dir \
|
||||||
|
--cutoff_len 1024 \
|
||||||
|
--preprocessing_num_workers 16 \
|
||||||
|
--per_device_train_batch_size 1 \
|
||||||
|
--per_device_eval_batch_size 1 \
|
||||||
|
--gradient_accumulation_steps 1 \
|
||||||
|
--lr_scheduler_type cosine \
|
||||||
|
--logging_steps 10 \
|
||||||
|
--warmup_steps 20 \
|
||||||
|
--save_steps 100 \
|
||||||
|
--eval_steps 100 \
|
||||||
|
--evaluation_strategy steps \
|
||||||
|
--load_best_model_at_end \
|
||||||
|
--learning_rate 5e-5 \
|
||||||
|
--num_train_epochs 3.0 \
|
||||||
|
--max_samples 3000 \
|
||||||
|
--val_size 0.1 \
|
||||||
|
--plot_loss \
|
||||||
|
--pure_bf16
|
||||||
35
examples/extras/galore/galore_adamw.sh
Normal file
35
examples/extras/galore/galore_adamw.sh
Normal file
@@ -0,0 +1,35 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
|
||||||
|
CUDA_VISIBLE_DEVICES=0 python ../../../src/train_bash.py \
|
||||||
|
--stage sft \
|
||||||
|
--do_train \
|
||||||
|
--model_name_or_path meta-llama/Llama-2-7b-hf \
|
||||||
|
--dataset alpaca_gpt4_en,glaive_toolcall \
|
||||||
|
--dataset_dir ../../../data \
|
||||||
|
--template default \
|
||||||
|
--finetuning_type full \
|
||||||
|
--use_galore \
|
||||||
|
--galore_layerwise \
|
||||||
|
--galore_target mlp,self_attn \
|
||||||
|
--galore_rank 128 \
|
||||||
|
--output_dir ../../../saves/LLaMA2-7B/galore/sft \
|
||||||
|
--overwrite_cache \
|
||||||
|
--overwrite_output_dir \
|
||||||
|
--cutoff_len 1024 \
|
||||||
|
--preprocessing_num_workers 16 \
|
||||||
|
--per_device_train_batch_size 1 \
|
||||||
|
--per_device_eval_batch_size 1 \
|
||||||
|
--gradient_accumulation_steps 1 \
|
||||||
|
--lr_scheduler_type cosine \
|
||||||
|
--logging_steps 10 \
|
||||||
|
--warmup_steps 20 \
|
||||||
|
--save_steps 100 \
|
||||||
|
--eval_steps 100 \
|
||||||
|
--evaluation_strategy steps \
|
||||||
|
--load_best_model_at_end \
|
||||||
|
--learning_rate 5e-5 \
|
||||||
|
--num_train_epochs 3.0 \
|
||||||
|
--max_samples 3000 \
|
||||||
|
--val_size 0.1 \
|
||||||
|
--plot_loss \
|
||||||
|
--fp16
|
||||||
36
examples/extras/galore/galore_adamw_8bit_bf16.sh
Normal file
36
examples/extras/galore/galore_adamw_8bit_bf16.sh
Normal file
@@ -0,0 +1,36 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
|
||||||
|
CUDA_VISIBLE_DEVICES=0 python ../../../src/train_bash.py \
|
||||||
|
--stage sft \
|
||||||
|
--do_train \
|
||||||
|
--model_name_or_path meta-llama/Llama-2-7b-hf \
|
||||||
|
--dataset alpaca_gpt4_en,glaive_toolcall \
|
||||||
|
--dataset_dir ../../../data \
|
||||||
|
--template default \
|
||||||
|
--finetuning_type full \
|
||||||
|
--optim adamw_8bit \
|
||||||
|
--use_galore \
|
||||||
|
--galore_layerwise \
|
||||||
|
--galore_target mlp,self_attn \
|
||||||
|
--galore_rank 128 \
|
||||||
|
--output_dir ../../../saves/LLaMA2-7B/galore/sft \
|
||||||
|
--overwrite_cache \
|
||||||
|
--overwrite_output_dir \
|
||||||
|
--cutoff_len 1024 \
|
||||||
|
--preprocessing_num_workers 16 \
|
||||||
|
--per_device_train_batch_size 1 \
|
||||||
|
--per_device_eval_batch_size 1 \
|
||||||
|
--gradient_accumulation_steps 1 \
|
||||||
|
--lr_scheduler_type cosine \
|
||||||
|
--logging_steps 10 \
|
||||||
|
--warmup_steps 20 \
|
||||||
|
--save_steps 100 \
|
||||||
|
--eval_steps 100 \
|
||||||
|
--evaluation_strategy steps \
|
||||||
|
--load_best_model_at_end \
|
||||||
|
--learning_rate 5e-5 \
|
||||||
|
--num_train_epochs 3.0 \
|
||||||
|
--max_samples 3000 \
|
||||||
|
--val_size 0.1 \
|
||||||
|
--plot_loss \
|
||||||
|
--pure_bf16
|
||||||
6
examples/extras/llama_pro/expand.sh
Normal file
6
examples/extras/llama_pro/expand.sh
Normal file
@@ -0,0 +1,6 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
|
||||||
|
python ../../../scripts/llama_pro.py \
|
||||||
|
--model_name_or_path meta-llama/Llama-2-7b-hf \
|
||||||
|
--output_dir ../../../models/llama2-7b-pro \
|
||||||
|
--num_expand 8
|
||||||
34
examples/extras/llama_pro/sft.sh
Normal file
34
examples/extras/llama_pro/sft.sh
Normal file
@@ -0,0 +1,34 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
|
||||||
|
CUDA_VISIBLE_DEVICES=0 python ../../../src/train_bash.py \
|
||||||
|
--stage sft \
|
||||||
|
--do_train \
|
||||||
|
--model_name_or_path ../../../models/llama2-7b-pro \
|
||||||
|
--dataset alpaca_gpt4_en,glaive_toolcall \
|
||||||
|
--dataset_dir ../../../data \
|
||||||
|
--template default \
|
||||||
|
--finetuning_type freeze \
|
||||||
|
--name_module_trainable all \
|
||||||
|
--num_layer_trainable 8 \
|
||||||
|
--use_llama_pro \
|
||||||
|
--output_dir ../../../saves/LLaMA2-7B-Pro/lora/sft \
|
||||||
|
--overwrite_cache \
|
||||||
|
--overwrite_output_dir \
|
||||||
|
--cutoff_len 1024 \
|
||||||
|
--preprocessing_num_workers 16 \
|
||||||
|
--per_device_train_batch_size 1 \
|
||||||
|
--per_device_eval_batch_size 1 \
|
||||||
|
--gradient_accumulation_steps 8 \
|
||||||
|
--lr_scheduler_type cosine \
|
||||||
|
--logging_steps 10 \
|
||||||
|
--warmup_steps 20 \
|
||||||
|
--save_steps 100 \
|
||||||
|
--eval_steps 100 \
|
||||||
|
--evaluation_strategy steps \
|
||||||
|
--load_best_model_at_end \
|
||||||
|
--learning_rate 5e-5 \
|
||||||
|
--num_train_epochs 3.0 \
|
||||||
|
--max_samples 3000 \
|
||||||
|
--val_size 0.1 \
|
||||||
|
--plot_loss \
|
||||||
|
--fp16
|
||||||
33
examples/extras/loraplus/sft.sh
Normal file
33
examples/extras/loraplus/sft.sh
Normal file
@@ -0,0 +1,33 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
|
||||||
|
CUDA_VISIBLE_DEVICES=0 python ../../src/train_bash.py \
|
||||||
|
--stage sft \
|
||||||
|
--do_train \
|
||||||
|
--model_name_or_path meta-llama/Llama-2-7b-hf \
|
||||||
|
--dataset alpaca_gpt4_en,glaive_toolcall \
|
||||||
|
--dataset_dir ../../data \
|
||||||
|
--template default \
|
||||||
|
--finetuning_type lora \
|
||||||
|
--lora_target q_proj,v_proj \
|
||||||
|
--output_dir ../../saves/LLaMA2-7B/loraplus/sft \
|
||||||
|
--overwrite_cache \
|
||||||
|
--overwrite_output_dir \
|
||||||
|
--cutoff_len 1024 \
|
||||||
|
--preprocessing_num_workers 16 \
|
||||||
|
--per_device_train_batch_size 1 \
|
||||||
|
--per_device_eval_batch_size 1 \
|
||||||
|
--gradient_accumulation_steps 8 \
|
||||||
|
--lr_scheduler_type cosine \
|
||||||
|
--logging_steps 10 \
|
||||||
|
--warmup_steps 20 \
|
||||||
|
--save_steps 100 \
|
||||||
|
--eval_steps 100 \
|
||||||
|
--evaluation_strategy steps \
|
||||||
|
--load_best_model_at_end \
|
||||||
|
--learning_rate 5e-5 \
|
||||||
|
--num_train_epochs 3.0 \
|
||||||
|
--max_samples 3000 \
|
||||||
|
--val_size 0.1 \
|
||||||
|
--plot_loss \
|
||||||
|
--fp16 \
|
||||||
|
--loraplus_lr_ratio 16.0
|
||||||
5 examples/fsdp_qlora/README.md Normal file
@@ -0,0 +1,5 @@

```bash
pip install git+https://github.com/huggingface/transformers.git
pip install "accelerate>=0.28.0"
pip install "bitsandbytes>=0.43.0"
```
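Once these dependencies are installed, the FSDP+QLoRA recipe in this folder can be launched directly (a sketch):

```bash
cd examples/fsdp_qlora
bash fsdp.sh  # 4-bit QLoRA fine-tuning of Llama-2-70B sharded across two GPUs
```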
33
examples/fsdp_qlora/fsdp.sh
Normal file
33
examples/fsdp_qlora/fsdp.sh
Normal file
@@ -0,0 +1,33 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
|
||||||
|
CUDA_VISIBLE_DEVICES=0,1 accelerate launch \
|
||||||
|
--config_file ../accelerate/fsdp_config.yaml \
|
||||||
|
../../src/train_bash.py \
|
||||||
|
--stage sft \
|
||||||
|
--do_train \
|
||||||
|
--model_name_or_path meta-llama/Llama-2-70b-hf \
|
||||||
|
--dataset alpaca_gpt4_en,glaive_toolcall \
|
||||||
|
--dataset_dir ../../data \
|
||||||
|
--template default \
|
||||||
|
--finetuning_type lora \
|
||||||
|
--lora_target q_proj,v_proj \
|
||||||
|
--output_dir ../../saves/LLaMA2-70B/lora/sft \
|
||||||
|
--overwrite_cache \
|
||||||
|
--overwrite_output_dir \
|
||||||
|
--cutoff_len 1024 \
|
||||||
|
--per_device_train_batch_size 1 \
|
||||||
|
--per_device_eval_batch_size 1 \
|
||||||
|
--gradient_accumulation_steps 8 \
|
||||||
|
--lr_scheduler_type cosine \
|
||||||
|
--logging_steps 10 \
|
||||||
|
--save_steps 100 \
|
||||||
|
--eval_steps 100 \
|
||||||
|
--evaluation_strategy steps \
|
||||||
|
--load_best_model_at_end \
|
||||||
|
--learning_rate 5e-5 \
|
||||||
|
--num_train_epochs 3.0 \
|
||||||
|
--max_samples 3000 \
|
||||||
|
--val_size 0.1 \
|
||||||
|
--quantization_bit 4 \
|
||||||
|
--plot_loss \
|
||||||
|
--fp16
|
||||||
38
examples/full_multi_gpu/multi_node.sh
Normal file
38
examples/full_multi_gpu/multi_node.sh
Normal file
@@ -0,0 +1,38 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
|
||||||
|
python -m torch.distributed.run \
|
||||||
|
--nproc_per_node $NPROC_PER_NODE \
|
||||||
|
--nnodes $NNODES \
|
||||||
|
--node_rank $RANK \
|
||||||
|
--master_addr $MASTER_ADDR \
|
||||||
|
--master_port $MASTER_PORT \
|
||||||
|
../../src/train_bash.py \
|
||||||
|
--deepspeed ../deepspeed/ds_z3_config.json \
|
||||||
|
--stage sft \
|
||||||
|
--do_train \
|
||||||
|
--model_name_or_path meta-llama/Llama-2-7b-hf \
|
||||||
|
--dataset alpaca_gpt4_en,glaive_toolcall \
|
||||||
|
--dataset_dir ../../data \
|
||||||
|
--template default \
|
||||||
|
--finetuning_type full \
|
||||||
|
--output_dir ../../saves/LLaMA2-7B/full/sft \
|
||||||
|
--overwrite_cache \
|
||||||
|
--overwrite_output_dir \
|
||||||
|
--cutoff_len 1024 \
|
||||||
|
--preprocessing_num_workers 16 \
|
||||||
|
--per_device_train_batch_size 1 \
|
||||||
|
--per_device_eval_batch_size 1 \
|
||||||
|
--gradient_accumulation_steps 2 \
|
||||||
|
--lr_scheduler_type cosine \
|
||||||
|
--logging_steps 10 \
|
||||||
|
--warmup_steps 20 \
|
||||||
|
--save_steps 100 \
|
||||||
|
--eval_steps 100 \
|
||||||
|
--evaluation_strategy steps \
|
||||||
|
--learning_rate 5e-5 \
|
||||||
|
--num_train_epochs 3.0 \
|
||||||
|
--max_samples 3000 \
|
||||||
|
--val_size 0.1 \
|
||||||
|
--ddp_timeout 1800000 \
|
||||||
|
--plot_loss \
|
||||||
|
--fp16
|
||||||
32
examples/full_multi_gpu/single_node.sh
Normal file
32
examples/full_multi_gpu/single_node.sh
Normal file
@@ -0,0 +1,32 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
|
||||||
|
deepspeed --num_gpus 4 ../../src/train_bash.py \
|
||||||
|
--deepspeed ../deepspeed/ds_z3_config.json \
|
||||||
|
--stage sft \
|
||||||
|
--do_train \
|
||||||
|
--model_name_or_path meta-llama/Llama-2-7b-hf \
|
||||||
|
--dataset alpaca_gpt4_en,glaive_toolcall \
|
||||||
|
--dataset_dir ../../data \
|
||||||
|
--template default \
|
||||||
|
--finetuning_type full \
|
||||||
|
--output_dir ../../saves/LLaMA2-7B/full/sft \
|
||||||
|
--overwrite_cache \
|
||||||
|
--overwrite_output_dir \
|
||||||
|
--cutoff_len 1024 \
|
||||||
|
--preprocessing_num_workers 16 \
|
||||||
|
--per_device_train_batch_size 1 \
|
||||||
|
--per_device_eval_batch_size 1 \
|
||||||
|
--gradient_accumulation_steps 2 \
|
||||||
|
--lr_scheduler_type cosine \
|
||||||
|
--logging_steps 10 \
|
||||||
|
--warmup_steps 20 \
|
||||||
|
--save_steps 100 \
|
||||||
|
--eval_steps 100 \
|
||||||
|
--evaluation_strategy steps \
|
||||||
|
--learning_rate 5e-5 \
|
||||||
|
--num_train_epochs 3.0 \
|
||||||
|
--max_samples 3000 \
|
||||||
|
--val_size 0.1 \
|
||||||
|
--ddp_timeout 1800000 \
|
||||||
|
--plot_loss \
|
||||||
|
--fp16
|
||||||
35
examples/lora_multi_gpu/multi_node.sh
Normal file
35
examples/lora_multi_gpu/multi_node.sh
Normal file
@@ -0,0 +1,35 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
|
||||||
|
CUDA_VISIBLE_DEVICES=0,1,2,3 accelerate launch \
|
||||||
|
--config_file ../accelerate/master_config.yaml \
|
||||||
|
../../src/train_bash.py \
|
||||||
|
--stage sft \
|
||||||
|
--do_train \
|
||||||
|
--model_name_or_path meta-llama/Llama-2-7b-hf \
|
||||||
|
--dataset alpaca_gpt4_en,glaive_toolcall \
|
||||||
|
--dataset_dir ../../data \
|
||||||
|
--template default \
|
||||||
|
--finetuning_type lora \
|
||||||
|
--lora_target q_proj,v_proj \
|
||||||
|
--output_dir ../../saves/LLaMA2-7B/lora/sft \
|
||||||
|
--overwrite_cache \
|
||||||
|
--overwrite_output_dir \
|
||||||
|
--cutoff_len 1024 \
|
||||||
|
--preprocessing_num_workers 16 \
|
||||||
|
--per_device_train_batch_size 1 \
|
||||||
|
--per_device_eval_batch_size 1 \
|
||||||
|
--gradient_accumulation_steps 2 \
|
||||||
|
--lr_scheduler_type cosine \
|
||||||
|
--logging_steps 10 \
|
||||||
|
--warmup_steps 20 \
|
||||||
|
--save_steps 100 \
|
||||||
|
--eval_steps 100 \
|
||||||
|
--evaluation_strategy steps \
|
||||||
|
--load_best_model_at_end \
|
||||||
|
--learning_rate 5e-5 \
|
||||||
|
--num_train_epochs 3.0 \
|
||||||
|
--max_samples 3000 \
|
||||||
|
--val_size 0.1 \
|
||||||
|
--ddp_timeout 1800000 \
|
||||||
|
--plot_loss \
|
||||||
|
--fp16
|
||||||
35
examples/lora_multi_gpu/single_node.sh
Normal file
35
examples/lora_multi_gpu/single_node.sh
Normal file
@@ -0,0 +1,35 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
|
||||||
|
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 accelerate launch \
|
||||||
|
--config_file ../accelerate/single_config.yaml \
|
||||||
|
../../src/train_bash.py \
|
||||||
|
--stage sft \
|
||||||
|
--do_train \
|
||||||
|
--model_name_or_path meta-llama/Llama-2-7b-hf \
|
||||||
|
--dataset alpaca_gpt4_en,glaive_toolcall \
|
||||||
|
--dataset_dir ../../data \
|
||||||
|
--template default \
|
||||||
|
--finetuning_type lora \
|
||||||
|
--lora_target q_proj,v_proj \
|
||||||
|
--output_dir ../../saves/LLaMA2-7B/lora/sft \
|
||||||
|
--overwrite_cache \
|
||||||
|
--overwrite_output_dir \
|
||||||
|
--cutoff_len 1024 \
|
||||||
|
--preprocessing_num_workers 16 \
|
||||||
|
--per_device_train_batch_size 1 \
|
||||||
|
--per_device_eval_batch_size 1 \
|
||||||
|
--gradient_accumulation_steps 2 \
|
||||||
|
--lr_scheduler_type cosine \
|
||||||
|
--logging_steps 10 \
|
||||||
|
--warmup_steps 20 \
|
||||||
|
--save_steps 100 \
|
||||||
|
--eval_steps 100 \
|
||||||
|
--evaluation_strategy steps \
|
||||||
|
--load_best_model_at_end \
|
||||||
|
--learning_rate 5e-5 \
|
||||||
|
--num_train_epochs 3.0 \
|
||||||
|
--max_samples 3000 \
|
||||||
|
--val_size 0.1 \
|
||||||
|
--ddp_timeout 1800000 \
|
||||||
|
--plot_loss \
|
||||||
|
--fp16
|
||||||
8 examples/lora_single_gpu/README.md Normal file
@@ -0,0 +1,8 @@

Usage:

- `pretrain.sh`: do pre-train (optional)
- `sft.sh`: do supervised fine-tune
- `reward.sh`: do reward modeling (must after sft.sh)
- `ppo.sh`: do PPO training (must after sft.sh and reward.sh)
- `dpo.sh`: do DPO training (must after sft.sh)
- `predict.sh`: do predict (must after sft.sh and dpo.sh)
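Put together, the intended order is roughly the following (a sketch; each script is meant to be run from inside `examples/lora_single_gpu`):

```bash
cd examples/lora_single_gpu
bash sft.sh      # supervised fine-tuning
bash reward.sh   # reward modeling on top of the SFT adapter
bash ppo.sh      # PPO training, needs the SFT adapter and the reward model
bash dpo.sh      # DPO training, needs only the SFT adapter
bash predict.sh  # batch prediction with the SFT and DPO adapters
```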
35
examples/lora_single_gpu/dpo.sh
Normal file
35
examples/lora_single_gpu/dpo.sh
Normal file
@@ -0,0 +1,35 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
|
||||||
|
CUDA_VISIBLE_DEVICES=0 python ../../src/train_bash.py \
|
||||||
|
--stage dpo \
|
||||||
|
--do_train \
|
||||||
|
--model_name_or_path meta-llama/Llama-2-7b-hf \
|
||||||
|
--adapter_name_or_path ../../saves/LLaMA2-7B/lora/sft \
|
||||||
|
--create_new_adapter \
|
||||||
|
--dataset comparison_gpt4_en \
|
||||||
|
--dataset_dir ../../data \
|
||||||
|
--template default \
|
||||||
|
--finetuning_type lora \
|
||||||
|
--lora_target q_proj,v_proj \
|
||||||
|
--output_dir ../../saves/LLaMA2-7B/lora/dpo \
|
||||||
|
--overwrite_cache \
|
||||||
|
--overwrite_output_dir \
|
||||||
|
--cutoff_len 1024 \
|
||||||
|
--preprocessing_num_workers 16 \
|
||||||
|
--per_device_train_batch_size 1 \
|
||||||
|
--per_device_eval_batch_size 1 \
|
||||||
|
--gradient_accumulation_steps 8 \
|
||||||
|
--lr_scheduler_type cosine \
|
||||||
|
--logging_steps 10 \
|
||||||
|
--warmup_steps 20 \
|
||||||
|
--save_steps 100 \
|
||||||
|
--eval_steps 100 \
|
||||||
|
--evaluation_strategy steps \
|
||||||
|
--load_best_model_at_end \
|
||||||
|
--learning_rate 1e-5 \
|
||||||
|
--num_train_epochs 1.0 \
|
||||||
|
--max_samples 1000 \
|
||||||
|
--val_size 0.1 \
|
||||||
|
--dpo_ftx 1.0 \
|
||||||
|
--plot_loss \
|
||||||
|
--fp16
|
||||||
32
examples/lora_single_gpu/ppo.sh
Normal file
32
examples/lora_single_gpu/ppo.sh
Normal file
@@ -0,0 +1,32 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
|
||||||
|
CUDA_VISIBLE_DEVICES=0 python ../../src/train_bash.py \
|
||||||
|
--stage ppo \
|
||||||
|
--do_train \
|
||||||
|
--model_name_or_path meta-llama/Llama-2-7b-hf \
|
||||||
|
--adapter_name_or_path ../../saves/LLaMA2-7B/lora/sft \
|
||||||
|
--create_new_adapter \
|
||||||
|
--dataset alpaca_gpt4_en \
|
||||||
|
--dataset_dir ../../data \
|
||||||
|
--template default \
|
||||||
|
--finetuning_type lora \
|
||||||
|
--lora_target q_proj,v_proj \
|
||||||
|
--reward_model ../../saves/LLaMA2-7B/lora/reward \
|
||||||
|
--output_dir ../../saves/LLaMA2-7B/lora/ppo \
|
||||||
|
--overwrite_cache \
|
||||||
|
--overwrite_output_dir \
|
||||||
|
--cutoff_len 512 \
|
||||||
|
--preprocessing_num_workers 16 \
|
||||||
|
--per_device_train_batch_size 1 \
|
||||||
|
--gradient_accumulation_steps 8 \
|
||||||
|
--lr_scheduler_type cosine \
|
||||||
|
--logging_steps 10 \
|
||||||
|
--save_steps 100 \
|
||||||
|
--learning_rate 1e-5 \
|
||||||
|
--num_train_epochs 1.0 \
|
||||||
|
--max_samples 1000 \
|
||||||
|
--top_k 0 \
|
||||||
|
--top_p 0.9 \
|
||||||
|
--max_new_tokens 256 \
|
||||||
|
--plot_loss \
|
||||||
|
--fp16
|
||||||
19
examples/lora_single_gpu/predict.sh
Normal file
19
examples/lora_single_gpu/predict.sh
Normal file
@@ -0,0 +1,19 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
|
||||||
|
CUDA_VISIBLE_DEVICES=0 python ../../src/train_bash.py \
|
||||||
|
--stage sft \
|
||||||
|
--do_predict \
|
||||||
|
--model_name_or_path meta-llama/Llama-2-7b-hf \
|
||||||
|
--adapter_name_or_path ../../saves/LLaMA2-7B/lora/sft,../../saves/LLaMA2-7B/lora/dpo \
|
||||||
|
--dataset alpaca_gpt4_en,glaive_toolcall \
|
||||||
|
--dataset_dir ../../data \
|
||||||
|
--template default \
|
||||||
|
--finetuning_type lora \
|
||||||
|
--output_dir ../../saves/LLaMA2-7B/lora/predict \
|
||||||
|
--overwrite_cache \
|
||||||
|
--overwrite_output_dir \
|
||||||
|
--cutoff_len 1024 \
|
||||||
|
--preprocessing_num_workers 16 \
|
||||||
|
--per_device_eval_batch_size 1 \
|
||||||
|
--max_samples 20 \
|
||||||
|
--predict_with_generate
|
||||||
31
examples/lora_single_gpu/pretrain.sh
Normal file
31
examples/lora_single_gpu/pretrain.sh
Normal file
@@ -0,0 +1,31 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
|
||||||
|
CUDA_VISIBLE_DEVICES=0 python ../../src/train_bash.py \
|
||||||
|
--stage pt \
|
||||||
|
--do_train \
|
||||||
|
--model_name_or_path meta-llama/Llama-2-7b-hf \
|
||||||
|
--dataset c4_demo \
|
||||||
|
--dataset_dir ../../data \
|
||||||
|
--finetuning_type lora \
|
||||||
|
--lora_target q_proj,v_proj \
|
||||||
|
--output_dir ../../saves/LLaMA2-7B/lora/pretrain \
|
||||||
|
--overwrite_cache \
|
||||||
|
--overwrite_output_dir \
|
||||||
|
--cutoff_len 1024 \
|
||||||
|
--preprocessing_num_workers 16 \
|
||||||
|
--per_device_train_batch_size 1 \
|
||||||
|
--per_device_eval_batch_size 1 \
|
||||||
|
--gradient_accumulation_steps 8 \
|
||||||
|
--lr_scheduler_type cosine \
|
||||||
|
--logging_steps 10 \
|
||||||
|
--warmup_steps 20 \
|
||||||
|
--save_steps 100 \
|
||||||
|
--eval_steps 100 \
|
||||||
|
--evaluation_strategy steps \
|
||||||
|
--load_best_model_at_end \
|
||||||
|
--learning_rate 5e-5 \
|
||||||
|
--num_train_epochs 3.0 \
|
||||||
|
--max_samples 10000 \
|
||||||
|
--val_size 0.1 \
|
||||||
|
--plot_loss \
|
||||||
|
--fp16
|
||||||
33
examples/lora_single_gpu/reward.sh
Normal file
33
examples/lora_single_gpu/reward.sh
Normal file
@@ -0,0 +1,33 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
|
||||||
|
CUDA_VISIBLE_DEVICES=0 python ../../src/train_bash.py \
|
||||||
|
--stage rm \
|
||||||
|
--do_train \
|
||||||
|
--model_name_or_path meta-llama/Llama-2-7b-hf \
|
||||||
|
--adapter_name_or_path ../../saves/LLaMA2-7B/lora/sft \
|
||||||
|
--create_new_adapter \
|
||||||
|
--dataset comparison_gpt4_en \
|
||||||
|
--dataset_dir ../../data \
|
||||||
|
--template default \
|
||||||
|
--finetuning_type lora \
|
||||||
|
--lora_target q_proj,v_proj \
|
||||||
|
--output_dir ../../saves/LLaMA2-7B/lora/reward \
|
||||||
|
--overwrite_cache \
|
||||||
|
--overwrite_output_dir \
|
||||||
|
--cutoff_len 1024 \
|
||||||
|
--preprocessing_num_workers 16 \
|
||||||
|
--per_device_train_batch_size 1 \
|
||||||
|
--per_device_eval_batch_size 1 \
|
||||||
|
--gradient_accumulation_steps 8 \
|
||||||
|
--lr_scheduler_type cosine \
|
||||||
|
--logging_steps 10 \
|
||||||
|
--warmup_steps 20 \
|
||||||
|
--save_steps 100 \
|
||||||
|
--eval_steps 100 \
|
||||||
|
--evaluation_strategy steps \
|
||||||
|
--learning_rate 1e-5 \
|
||||||
|
--num_train_epochs 1.0 \
|
||||||
|
--max_samples 5000 \
|
||||||
|
--val_size 0.1 \
|
||||||
|
--plot_loss \
|
||||||
|
--fp16
|
||||||
32
examples/lora_single_gpu/sft.sh
Normal file
32
examples/lora_single_gpu/sft.sh
Normal file
@@ -0,0 +1,32 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
|
||||||
|
CUDA_VISIBLE_DEVICES=0 python ../../src/train_bash.py \
|
||||||
|
--stage sft \
|
||||||
|
--do_train \
|
||||||
|
--model_name_or_path meta-llama/Llama-2-7b-hf \
|
||||||
|
--dataset alpaca_gpt4_en,glaive_toolcall \
|
||||||
|
--dataset_dir ../../data \
|
||||||
|
--template default \
|
||||||
|
--finetuning_type lora \
|
||||||
|
--lora_target q_proj,v_proj \
|
||||||
|
--output_dir ../../saves/LLaMA2-7B/lora/sft \
|
||||||
|
--overwrite_cache \
|
||||||
|
--overwrite_output_dir \
|
||||||
|
--cutoff_len 1024 \
|
||||||
|
--preprocessing_num_workers 16 \
|
||||||
|
--per_device_train_batch_size 1 \
|
||||||
|
--per_device_eval_batch_size 1 \
|
||||||
|
--gradient_accumulation_steps 8 \
|
||||||
|
--lr_scheduler_type cosine \
|
||||||
|
--logging_steps 10 \
|
||||||
|
--warmup_steps 20 \
|
||||||
|
--save_steps 100 \
|
||||||
|
--eval_steps 100 \
|
||||||
|
--evaluation_strategy steps \
|
||||||
|
--load_best_model_at_end \
|
||||||
|
--learning_rate 5e-5 \
|
||||||
|
--num_train_epochs 3.0 \
|
||||||
|
--max_samples 3000 \
|
||||||
|
--val_size 0.1 \
|
||||||
|
--plot_loss \
|
||||||
|
--fp16
|
||||||
4 examples/merge_lora/README.md Normal file
@@ -0,0 +1,4 @@

Usage:

- `merge.sh`: merge the lora weights
- `quantize.sh`: quantize the model with AutoGPTQ (must after merge.sh, optional)
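In other words, the two scripts below are meant to be run back to back (a sketch):

```bash
cd examples/merge_lora
bash merge.sh     # export the merged fp16 model to ../../models/llama2-7b-sft
bash quantize.sh  # optionally produce a 4-bit GPTQ copy of the merged model
```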
10
examples/merge_lora/merge.sh
Normal file
10
examples/merge_lora/merge.sh
Normal file
@@ -0,0 +1,10 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
|
||||||
|
CUDA_VISIBLE_DEVICES=0 python ../../src/export_model.py \
|
||||||
|
--model_name_or_path meta-llama/Llama-2-7b-hf \
|
||||||
|
--adapter_name_or_path ../../saves/LLaMA2-7B/lora/sft \
|
||||||
|
--template default \
|
||||||
|
--finetuning_type lora \
|
||||||
|
--export_dir ../../models/llama2-7b-sft \
|
||||||
|
--export_size 2 \
|
||||||
|
--export_legacy_format False
|
||||||
10
examples/merge_lora/quantize.sh
Normal file
10
examples/merge_lora/quantize.sh
Normal file
@@ -0,0 +1,10 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
|
||||||
|
CUDA_VISIBLE_DEVICES=0 python ../../src/export_model.py \
|
||||||
|
--model_name_or_path ../../models/llama2-7b-sft \
|
||||||
|
--template default \
|
||||||
|
--export_dir ../../models/llama2-7b-sft-int4 \
|
||||||
|
--export_quantization_bit 4 \
|
||||||
|
--export_quantization_dataset ../../data/c4_demo.json \
|
||||||
|
--export_size 2 \
|
||||||
|
--export_legacy_format False
|
||||||
30
examples/qlora_single_gpu/aqlm.sh
Normal file
30
examples/qlora_single_gpu/aqlm.sh
Normal file
@@ -0,0 +1,30 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
|
||||||
|
CUDA_VISIBLE_DEVICES=0 python ../../src/train_bash.py \
|
||||||
|
--stage sft \
|
||||||
|
--do_train \
|
||||||
|
--model_name_or_path BlackSamorez/Llama-2-7b-AQLM-2Bit-1x16-hf \
|
||||||
|
--dataset alpaca_gpt4_en,glaive_toolcall \
|
||||||
|
--dataset_dir ../../data \
|
||||||
|
--template default \
|
||||||
|
--finetuning_type lora \
|
||||||
|
--lora_target q_proj,v_proj \
|
||||||
|
--output_dir ../../saves/LLaMA2-7B/lora/sft \
|
||||||
|
--overwrite_cache \
|
||||||
|
--overwrite_output_dir \
|
||||||
|
--cutoff_len 1024 \
|
||||||
|
--per_device_train_batch_size 1 \
|
||||||
|
--per_device_eval_batch_size 1 \
|
||||||
|
--gradient_accumulation_steps 8 \
|
||||||
|
--lr_scheduler_type cosine \
|
||||||
|
--logging_steps 10 \
|
||||||
|
--save_steps 100 \
|
||||||
|
--eval_steps 100 \
|
||||||
|
--evaluation_strategy steps \
|
||||||
|
--load_best_model_at_end \
|
||||||
|
--learning_rate 5e-5 \
|
||||||
|
--num_train_epochs 3.0 \
|
||||||
|
--max_samples 3000 \
|
||||||
|
--val_size 0.1 \
|
||||||
|
--plot_loss \
|
||||||
|
--fp16
|
||||||
30
examples/qlora_single_gpu/awq.sh
Normal file
30
examples/qlora_single_gpu/awq.sh
Normal file
@@ -0,0 +1,30 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
|
||||||
|
CUDA_VISIBLE_DEVICES=0 python ../../src/train_bash.py \
|
||||||
|
--stage sft \
|
||||||
|
--do_train \
|
||||||
|
--model_name_or_path TheBloke/Llama-2-7B-AWQ \
|
||||||
|
--dataset alpaca_gpt4_en,glaive_toolcall \
|
||||||
|
--dataset_dir ../../data \
|
||||||
|
--template default \
|
||||||
|
--finetuning_type lora \
|
||||||
|
--lora_target q_proj,v_proj \
|
||||||
|
--output_dir ../../saves/LLaMA2-7B/lora/sft \
|
||||||
|
--overwrite_cache \
|
||||||
|
--overwrite_output_dir \
|
||||||
|
--cutoff_len 1024 \
|
||||||
|
--per_device_train_batch_size 1 \
|
||||||
|
--per_device_eval_batch_size 1 \
|
||||||
|
--gradient_accumulation_steps 8 \
|
||||||
|
--lr_scheduler_type cosine \
|
||||||
|
--logging_steps 10 \
|
||||||
|
--save_steps 100 \
|
||||||
|
--eval_steps 100 \
|
||||||
|
--evaluation_strategy steps \
|
||||||
|
--load_best_model_at_end \
|
||||||
|
--learning_rate 5e-5 \
|
||||||
|
--num_train_epochs 3.0 \
|
||||||
|
--max_samples 3000 \
|
||||||
|
--val_size 0.1 \
|
||||||
|
--plot_loss \
|
||||||
|
--fp16
|
||||||
examples/qlora_single_gpu/bitsandbytes.sh (new file, 31 lines)

#!/bin/bash

CUDA_VISIBLE_DEVICES=0 python ../../src/train_bash.py \
    --stage sft \
    --do_train \
    --model_name_or_path meta-llama/Llama-2-7b-hf \
    --dataset alpaca_gpt4_en,glaive_toolcall \
    --dataset_dir ../../data \
    --template default \
    --finetuning_type lora \
    --lora_target q_proj,v_proj \
    --output_dir ../../saves/LLaMA2-7B/lora/sft \
    --overwrite_cache \
    --overwrite_output_dir \
    --cutoff_len 1024 \
    --per_device_train_batch_size 1 \
    --per_device_eval_batch_size 1 \
    --gradient_accumulation_steps 8 \
    --lr_scheduler_type cosine \
    --logging_steps 10 \
    --save_steps 100 \
    --eval_steps 100 \
    --evaluation_strategy steps \
    --load_best_model_at_end \
    --learning_rate 5e-5 \
    --num_train_epochs 3.0 \
    --max_samples 3000 \
    --val_size 0.1 \
    --quantization_bit 4 \
    --plot_loss \
    --fp16
examples/qlora_single_gpu/gptq.sh (new file, 30 lines)

#!/bin/bash

CUDA_VISIBLE_DEVICES=0 python ../../src/train_bash.py \
    --stage sft \
    --do_train \
    --model_name_or_path TheBloke/Llama-2-7B-GPTQ \
    --dataset alpaca_gpt4_en,glaive_toolcall \
    --dataset_dir ../../data \
    --template default \
    --finetuning_type lora \
    --lora_target q_proj,v_proj \
    --output_dir ../../saves/LLaMA2-7B/lora/sft \
    --overwrite_cache \
    --overwrite_output_dir \
    --cutoff_len 1024 \
    --per_device_train_batch_size 1 \
    --per_device_eval_batch_size 1 \
    --gradient_accumulation_steps 8 \
    --lr_scheduler_type cosine \
    --logging_steps 10 \
    --save_steps 100 \
    --eval_steps 100 \
    --evaluation_strategy steps \
    --load_best_model_at_end \
    --learning_rate 5e-5 \
    --num_train_epochs 3.0 \
    --max_samples 3000 \
    --val_size 0.1 \
    --plot_loss \
    --fp16
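The four QLoRA scripts above are identical apart from the starting checkpoint: the AQLM, AWQ and GPTQ variants load already-quantized models, while bitsandbytes.sh starts from the full-precision meta-llama/Llama-2-7b-hf and quantizes on the fly via --quantization_bit 4. A rough Python equivalent of the bitsandbytes variant, assuming run_exp accepts the same key/value arguments as the train_bash.py flags (a sketch, not taken verbatim from this changeset):

    from llmtuner import run_exp

    # Mirrors examples/qlora_single_gpu/bitsandbytes.sh; only the main flags are shown.
    run_exp(dict(
        stage="sft",
        do_train=True,
        model_name_or_path="meta-llama/Llama-2-7b-hf",
        dataset="alpaca_gpt4_en,glaive_toolcall",
        dataset_dir="../../data",
        template="default",
        finetuning_type="lora",
        lora_target="q_proj,v_proj",
        quantization_bit=4,  # on-the-fly 4-bit quantization via bitsandbytes
        output_dir="../../saves/LLaMA2-7B/lora/sft",
        per_device_train_batch_size=1,
        gradient_accumulation_steps=8,
        learning_rate=5e-5,
        num_train_epochs=3.0,
        fp16=True,
    ))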
@@ -2,11 +2,8 @@
 requires = ["setuptools>=61.0"]
 build-backend = "setuptools.build_meta"
 
-[tool.black]
-line-length = 119
-target-version = ["py38"]
-
 [tool.ruff]
+target-version = "py38"
 line-length = 119
 indent-width = 4
 
@@ -17,17 +14,7 @@ select = ["C", "E", "F", "I", "W"]
 [tool.ruff.lint.isort]
 lines-after-imports = 2
 known-first-party = ["llmtuner"]
-
-[tool.ruff.format]
-quote-style = "double"
-indent-style = "space"
-skip-magic-trailing-comma = false
-line-ending = "auto"
-
-[isort]
-default_section = "FIRSTPARTY"
-known_first_party = "llmtuner"
-known_third_party = [
+known-third-party = [
   "accelerate",
   "datasets",
   "gradio",
@@ -37,10 +24,10 @@ known_third_party = [
   "transformers",
   "trl"
 ]
-line_length = 119
-lines_after_imports = 2
-multi_line_output = 3
-include_trailing_comma = true
-force_grid_wrap = 0
-use_parentheses = true
-ensure_newline_before_comments = true
+
+[tool.ruff.format]
+quote-style = "double"
+indent-style = "space"
+docstring-code-format = true
+skip-magic-trailing-comma = false
+line-ending = "auto"
@@ -1,19 +1,18 @@
 torch>=1.13.1
 transformers>=4.37.2
 datasets>=2.14.3
-accelerate>=0.21.0
-peft>=0.8.2
-trl>=0.7.6
+accelerate>=0.27.2
+peft>=0.9.0
+trl>=0.8.1
 gradio>=3.38.0,<4.0.0
 scipy
 einops
 sentencepiece
 protobuf
-jieba
-rouge-chinese
-nltk
 uvicorn
 pydantic
 fastapi
 sse-starlette
 matplotlib
+fire
+galore-torch
setup.py
@@ -1,13 +1,14 @@
 import os
 import re
-from setuptools import setup, find_packages
+
+from setuptools import find_packages, setup
 
 
 def get_version():
     with open(os.path.join("src", "llmtuner", "__init__.py"), "r", encoding="utf-8") as f:
         file_content = f.read()
         pattern = r"{0}\W*=\W*\"([^\"]+)\"".format("__version__")
-        version, = re.findall(pattern, file_content)
+        (version,) = re.findall(pattern, file_content)
         return version
 
@@ -18,8 +19,21 @@ def get_requires():
         return lines
 
 
+extra_require = {
+    "deepspeed": ["deepspeed"],
+    "metrics": ["nltk", "jieba", "rouge-chinese"],
+    "unsloth": ["torch==2.2.0", "unsloth[cu121-ampere-torch220]"],
+    "vllm": ["vllm>=0.3.3"],
+    "bitsandbytes": ["bitsandbytes>=0.39.0"],
+    "gptq": ["optimum>=1.16.0", "auto-gptq>=0.5.0"],
+    "awq": ["autoawq"],
+    "aqlm": ["aqlm[gpu]>=1.1.0"],
+    "qwen": ["tiktoken", "transformers_stream_generator"],
+    "quality": ["ruff"],
+}
+
+
 def main():
     setup(
         name="llmtuner",
         version=get_version(),
@@ -35,8 +49,9 @@ def main():
         packages=find_packages("src"),
         python_requires=">=3.8.0",
         install_requires=get_requires(),
+        extras_require=extra_require,
         classifiers=[
-            "Development Status :: 3 - Alpha",
+            "Development Status :: 4 - Beta",
             "Intended Audience :: Developers",
             "Intended Audience :: Education",
             "Intended Audience :: Science/Research",
@@ -46,8 +61,9 @@ def main():
             "Programming Language :: Python :: 3.8",
             "Programming Language :: Python :: 3.9",
             "Programming Language :: Python :: 3.10",
+            "Programming Language :: Python :: 3.11",
             "Topic :: Scientific/Engineering :: Artificial Intelligence",
-        ]
+        ],
     )
@@ -7,5 +7,5 @@ from .train import export_model, run_exp
 from .webui import create_ui, create_web_demo
 
 
-__version__ = "0.5.2"
+__version__ = "0.6.0"
 __all__ = ["create_app", "ChatModel", "Evaluator", "export_model", "run_exp", "create_ui", "create_web_demo"]
@@ -1,4 +1,3 @@
-import asyncio
 import json
 import os
 from contextlib import asynccontextmanager
@@ -73,13 +72,12 @@ def create_app(chat_model: "ChatModel") -> "FastAPI":
         allow_headers=["*"],
     )
 
-    semaphore = asyncio.Semaphore(int(os.environ.get("MAX_CONCURRENT", 1)))
     role_mapping = {
-        Role.USER: DataRole.USER,
-        Role.ASSISTANT: DataRole.ASSISTANT,
-        Role.SYSTEM: DataRole.SYSTEM,
-        Role.FUNCTION: DataRole.FUNCTION,
-        Role.TOOL: DataRole.OBSERVATION,
+        Role.USER: DataRole.USER.value,
+        Role.ASSISTANT: DataRole.ASSISTANT.value,
+        Role.SYSTEM: DataRole.SYSTEM.value,
+        Role.FUNCTION: DataRole.FUNCTION.value,
+        Role.TOOL: DataRole.OBSERVATION.value,
     }
 
     @app.get("/v1/models", response_model=ModelList)
@@ -89,13 +87,13 @@ def create_app(chat_model: "ChatModel") -> "FastAPI":
 
     @app.post("/v1/chat/completions", response_model=ChatCompletionResponse, status_code=status.HTTP_200_OK)
     async def create_chat_completion(request: ChatCompletionRequest):
-        if not chat_model.can_generate:
+        if not chat_model.engine.can_generate:
             raise HTTPException(status_code=status.HTTP_405_METHOD_NOT_ALLOWED, detail="Not allowed")
 
         if len(request.messages) == 0:
             raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid length")
 
-        if role_mapping[request.messages[0].role] == DataRole.SYSTEM:
+        if request.messages[0].role == Role.SYSTEM:
             system = request.messages.pop(0).content
         else:
             system = ""
@@ -105,11 +103,12 @@ def create_app(chat_model: "ChatModel") -> "FastAPI":
 
         input_messages = []
         for i, message in enumerate(request.messages):
+            if i % 2 == 0 and message.role not in [Role.USER, Role.TOOL]:
+                raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid role")
+            elif i % 2 == 1 and message.role not in [Role.ASSISTANT, Role.FUNCTION]:
+                raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid role")
+
             input_messages.append({"role": role_mapping[message.role], "content": message.content})
-            if i % 2 == 0 and input_messages[i]["role"] not in [DataRole.USER, DataRole.OBSERVATION]:
-                raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid role")
-            elif i % 2 == 1 and input_messages[i]["role"] not in [DataRole.ASSISTANT, DataRole.FUNCTION]:
-                raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid role")
 
         tool_list = request.tools
         if isinstance(tool_list, list) and len(tool_list):
@@ -120,20 +119,15 @@ def create_app(chat_model: "ChatModel") -> "FastAPI":
         else:
             tools = ""
 
-        async with semaphore:
-            loop = asyncio.get_running_loop()
-            return await loop.run_in_executor(None, chat_completion, input_messages, system, tools, request)
-
-    def chat_completion(messages: Sequence[Dict[str, str]], system: str, tools: str, request: ChatCompletionRequest):
         if request.stream:
             if tools:
                 raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Cannot stream function calls.")
 
-            generate = stream_chat_completion(messages, system, tools, request)
+            generate = stream_chat_completion(input_messages, system, tools, request)
             return EventSourceResponse(generate, media_type="text/event-stream")
 
-        responses = chat_model.chat(
-            messages,
+        responses = await chat_model.achat(
+            input_messages,
             system,
             tools,
             do_sample=request.do_sample,
@@ -147,7 +141,7 @@ def create_app(chat_model: "ChatModel") -> "FastAPI":
         choices = []
         for i, response in enumerate(responses):
             if tools:
-                result = chat_model.template.format_tools.extract(response.response_text)
+                result = chat_model.engine.template.format_tools.extract(response.response_text)
             else:
                 result = response.response_text
 
@@ -176,7 +170,7 @@ def create_app(chat_model: "ChatModel") -> "FastAPI":
 
         return ChatCompletionResponse(model=request.model, choices=choices, usage=usage)
 
-    def stream_chat_completion(
+    async def stream_chat_completion(
         messages: Sequence[Dict[str, str]], system: str, tools: str, request: ChatCompletionRequest
     ):
         choice_data = ChatCompletionResponseStreamChoice(
@@ -185,7 +179,7 @@ def create_app(chat_model: "ChatModel") -> "FastAPI":
         chunk = ChatCompletionStreamResponse(model=request.model, choices=[choice_data])
         yield jsonify(chunk)
 
-        for new_text in chat_model.stream_chat(
+        async for new_token in chat_model.astream_chat(
             messages,
             system,
             tools,
@@ -194,11 +188,11 @@ def create_app(chat_model: "ChatModel") -> "FastAPI":
             top_p=request.top_p,
             max_new_tokens=request.max_tokens,
         ):
-            if len(new_text) == 0:
+            if len(new_token) == 0:
                 continue
 
             choice_data = ChatCompletionResponseStreamChoice(
-                index=0, delta=ChatCompletionMessage(content=new_text), finish_reason=None
+                index=0, delta=ChatCompletionMessage(content=new_token), finish_reason=None
             )
             chunk = ChatCompletionStreamResponse(model=request.model, choices=[choice_data])
             yield jsonify(chunk)
@@ -212,18 +206,13 @@ def create_app(chat_model: "ChatModel") -> "FastAPI":
 
     @app.post("/v1/score/evaluation", response_model=ScoreEvaluationResponse, status_code=status.HTTP_200_OK)
     async def create_score_evaluation(request: ScoreEvaluationRequest):
-        if chat_model.can_generate:
+        if chat_model.engine.can_generate:
             raise HTTPException(status_code=status.HTTP_405_METHOD_NOT_ALLOWED, detail="Not allowed")
 
         if len(request.messages) == 0:
             raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid request")
 
-        async with semaphore:
-            loop = asyncio.get_running_loop()
-            return await loop.run_in_executor(None, get_score, request)
-
-    def get_score(request: ScoreEvaluationRequest):
-        scores = chat_model.get_scores(request.messages, max_length=request.max_length)
+        scores = await chat_model.aget_scores(request.messages, max_length=request.max_length)
         return ScoreEvaluationResponse(model=request.model, scores=scores)
 
     return app
@@ -59,7 +59,7 @@ class ChatCompletionMessage(BaseModel):
 class ChatCompletionRequest(BaseModel):
     model: str
     messages: List[ChatMessage]
-    tools: Optional[list] = []
+    tools: list = []
     do_sample: bool = True
     temperature: Optional[float] = None
     top_p: Optional[float] = None
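The request schema above stays OpenAI-compatible, so any OpenAI-style client can talk to the refactored server. A minimal client sketch with plain requests; the host, port and model name are placeholders rather than values taken from this changeset:

    import requests

    payload = {
        "model": "llama2-7b-sft",  # placeholder model name
        "messages": [{"role": "user", "content": "Hello, who are you?"}],
        "temperature": 0.7,
    }
    # Placeholder host/port; point this at wherever the FastAPI app is served.
    resp = requests.post("http://localhost:8000/v1/chat/completions", json=payload)
    # Assumes the OpenAI-style response layout (choices -> message -> content).
    print(resp.json()["choices"][0]["message"]["content"])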
@@ -1,4 +1,5 @@
+from .base_engine import BaseEngine
 from .chat_model import ChatModel
 
 
-__all__ = ["ChatModel"]
+__all__ = ["BaseEngine", "ChatModel"]
src/llmtuner/chat/base_engine.py (new file, 69 lines)

from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, AsyncGenerator, Dict, List, Literal, Optional, Sequence, Union


if TYPE_CHECKING:
    from transformers import PreTrainedModel, PreTrainedTokenizer

    from ..data import Template
    from ..extras.packages import is_vllm_available
    from ..hparams import DataArguments, FinetuningArguments, GeneratingArguments, ModelArguments

    if is_vllm_available():
        from vllm import AsyncLLMEngine


@dataclass
class Response:
    response_text: str
    response_length: int
    prompt_length: int
    finish_reason: Literal["stop", "length"]


class BaseEngine(ABC):
    model: Union["PreTrainedModel", "AsyncLLMEngine"]
    tokenizer: "PreTrainedTokenizer"
    can_generate: bool
    template: "Template"
    generating_args: Dict[str, Any]

    @abstractmethod
    def __init__(
        self,
        model_args: "ModelArguments",
        data_args: "DataArguments",
        finetuning_args: "FinetuningArguments",
        generating_args: "GeneratingArguments",
    ) -> None: ...

    @abstractmethod
    async def start(
        self,
    ) -> None: ...

    @abstractmethod
    async def chat(
        self,
        messages: Sequence[Dict[str, str]],
        system: Optional[str] = None,
        tools: Optional[str] = None,
        **input_kwargs,
    ) -> List["Response"]: ...

    @abstractmethod
    async def stream_chat(
        self,
        messages: Sequence[Dict[str, str]],
        system: Optional[str] = None,
        tools: Optional[str] = None,
        **input_kwargs,
    ) -> AsyncGenerator[str, None]: ...

    @abstractmethod
    async def get_scores(
        self,
        batch_input: List[str],
        **input_kwargs,
    ) -> List[float]: ...
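BaseEngine pins down the async surface that every backend has to provide. A purely illustrative sketch of a custom engine; the EchoEngine class and its trivial behaviour are not part of this changeset:

    from llmtuner.chat import BaseEngine
    from llmtuner.chat.base_engine import Response


    class EchoEngine(BaseEngine):
        """Toy engine that echoes the last user message (illustration only)."""

        def __init__(self, model_args, data_args, finetuning_args, generating_args) -> None:
            self.can_generate = True

        async def start(self) -> None:
            pass

        async def chat(self, messages, system=None, tools=None, **input_kwargs):
            text = messages[-1]["content"]
            return [Response(response_text=text, response_length=len(text), prompt_length=0, finish_reason="stop")]

        async def stream_chat(self, messages, system=None, tools=None, **input_kwargs):
            yield messages[-1]["content"]

        async def get_scores(self, batch_input, **input_kwargs):
            raise NotImplementedError()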
@@ -1,124 +1,55 @@
-from dataclasses import dataclass
+import asyncio
 from threading import Thread
-from typing import Any, Dict, Generator, List, Literal, Optional, Sequence, Tuple
+from typing import TYPE_CHECKING, Any, AsyncGenerator, Dict, Generator, List, Optional, Sequence
 
-import torch
-from transformers import GenerationConfig, TextIteratorStreamer
-
-from ..data import get_template_and_fix_tokenizer
-from ..extras.misc import get_logits_processor
 from ..hparams import get_infer_args
-from ..model import dispatch_model, load_model_and_tokenizer
+from .hf_engine import HuggingfaceEngine
+from .vllm_engine import VllmEngine
 
 
-@dataclass
-class Response:
-    response_text: str
-    response_length: int
-    prompt_length: int
-    finish_reason: Literal["stop", "length"]
+if TYPE_CHECKING:
+    from .base_engine import BaseEngine, Response
+
+
+def _start_background_loop(loop: asyncio.AbstractEventLoop) -> None:
+    asyncio.set_event_loop(loop)
+    loop.run_forever()
 
 
 class ChatModel:
     def __init__(self, args: Optional[Dict[str, Any]] = None) -> None:
-        model_args, data_args, finetuning_args, self.generating_args = get_infer_args(args)
-        self.can_generate = finetuning_args.stage == "sft"
-        self.model, self.tokenizer = load_model_and_tokenizer(
-            model_args, finetuning_args, is_trainable=False, add_valuehead=(not self.can_generate)
-        )
-        self.tokenizer.padding_side = "left" if self.can_generate else "right"
-        self.model = dispatch_model(self.model)
-        self.template = get_template_and_fix_tokenizer(self.tokenizer, data_args.template)
-
-    def _process_args(
-        self,
-        messages: Sequence[Dict[str, str]],
-        system: Optional[str] = None,
-        tools: Optional[str] = None,
-        **input_kwargs,
-    ) -> Tuple[Dict[str, Any], int]:
-        paired_messages = messages + [{"role": "assistant", "content": ""}]
-        prompt, _ = self.template.encode_oneturn(
-            tokenizer=self.tokenizer, messages=paired_messages, system=system, tools=tools
-        )
-        prompt_length = len(prompt)
-        input_ids = torch.tensor([prompt], device=self.model.device)
-
-        do_sample = input_kwargs.pop("do_sample", None)
-        temperature = input_kwargs.pop("temperature", None)
-        top_p = input_kwargs.pop("top_p", None)
-        top_k = input_kwargs.pop("top_k", None)
-        num_return_sequences = input_kwargs.pop("num_return_sequences", None)
-        repetition_penalty = input_kwargs.pop("repetition_penalty", None)
-        max_length = input_kwargs.pop("max_length", None)
-        max_new_tokens = input_kwargs.pop("max_new_tokens", None)
-
-        generating_args = self.generating_args.to_dict()
-        generating_args.update(
-            dict(
-                do_sample=do_sample if do_sample is not None else generating_args["do_sample"],
-                temperature=temperature or generating_args["temperature"],
-                top_p=top_p or generating_args["top_p"],
-                top_k=top_k or generating_args["top_k"],
-                num_return_sequences=num_return_sequences or 1,
-                repetition_penalty=repetition_penalty or generating_args["repetition_penalty"],
-                eos_token_id=[self.tokenizer.eos_token_id] + self.tokenizer.additional_special_tokens_ids,
-                pad_token_id=self.tokenizer.pad_token_id,
-            )
-        )
-
-        if isinstance(num_return_sequences, int) and num_return_sequences > 1:
-            generating_args["do_sample"] = True
-
-        if max_length:
-            generating_args.pop("max_new_tokens", None)
-            generating_args["max_length"] = max_length
-
-        if max_new_tokens:
-            generating_args.pop("max_length", None)
-            generating_args["max_new_tokens"] = max_new_tokens
-
-        gen_kwargs = dict(
-            inputs=input_ids,
-            generation_config=GenerationConfig(**generating_args),
-            logits_processor=get_logits_processor(),
-        )
-
-        return gen_kwargs, prompt_length
-
-    @torch.inference_mode()
+        model_args, data_args, finetuning_args, generating_args = get_infer_args(args)
+        if model_args.infer_backend == "huggingface":
+            self.engine: "BaseEngine" = HuggingfaceEngine(model_args, data_args, finetuning_args, generating_args)
+        elif model_args.infer_backend == "vllm":
+            self.engine: "BaseEngine" = VllmEngine(model_args, data_args, finetuning_args, generating_args)
+        else:
+            raise NotImplementedError("Unknown backend: {}".format(model_args.infer_backend))
+
+        self._loop = asyncio.new_event_loop()
+        self._thread = Thread(target=_start_background_loop, args=(self._loop,), daemon=True)
+        self._thread.start()
+        asyncio.run_coroutine_threadsafe(self.engine.start(), self._loop)
+
     def chat(
         self,
         messages: Sequence[Dict[str, str]],
         system: Optional[str] = None,
         tools: Optional[str] = None,
         **input_kwargs,
-    ) -> List[Response]:
-        if not self.can_generate:
-            raise ValueError("The current model does not support `chat`.")
-
-        gen_kwargs, prompt_length = self._process_args(messages, system, tools, **input_kwargs)
-        generate_output = self.model.generate(**gen_kwargs)
-        response_ids = generate_output[:, prompt_length:]
-        response = self.tokenizer.batch_decode(
-            response_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True
-        )
-        results = []
-        for i in range(len(response)):
-            eos_index = (response_ids[i] == self.tokenizer.eos_token_id).nonzero()
-            response_length = (eos_index[0].item() + 1) if len(eos_index) else len(response_ids[i])
-            results.append(
-                Response(
-                    response_text=response[i],
-                    response_length=response_length,
-                    prompt_length=prompt_length,
-                    finish_reason="stop" if len(eos_index) else "length",
-                )
-            )
-
-        return results
-
-    @torch.inference_mode()
+    ) -> List["Response"]:
+        task = asyncio.run_coroutine_threadsafe(self.achat(messages, system, tools, **input_kwargs), self._loop)
+        return task.result()
+
+    async def achat(
+        self,
+        messages: Sequence[Dict[str, str]],
+        system: Optional[str] = None,
+        tools: Optional[str] = None,
+        **input_kwargs,
+    ) -> List["Response"]:
+        return await self.engine.chat(messages, system, tools, **input_kwargs)
+
     def stream_chat(
         self,
         messages: Sequence[Dict[str, str]],
@@ -126,44 +57,35 @@ class ChatModel:
         tools: Optional[str] = None,
         **input_kwargs,
     ) -> Generator[str, None, None]:
-        if not self.can_generate:
-            raise ValueError("The current model does not support `stream_chat`.")
-
-        gen_kwargs, _ = self._process_args(messages, system, tools, **input_kwargs)
-        streamer = TextIteratorStreamer(self.tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True)
-        gen_kwargs["streamer"] = streamer
-
-        thread = Thread(target=self.model.generate, kwargs=gen_kwargs)
-        thread.start()
-
-        yield from streamer
-
-    @torch.inference_mode()
-    def get_scores(self, batch_input: List[str], **input_kwargs) -> List[float]:
-        if self.can_generate:
-            raise ValueError("Cannot get scores using an auto-regressive model.")
-
-        max_length = input_kwargs.pop("max_length", None)
-        device = getattr(self.model.pretrained_model, "device", "cuda")
-        inputs = self.tokenizer(
-            batch_input,
-            padding=True,
-            truncation=True,
-            max_length=max_length or getattr(self.model.config, "max_position_embeddings", 1024),
-            return_tensors="pt",
-            add_special_tokens=True,
-        ).to(device)
-
-        input_ids: torch.Tensor = inputs["input_ids"]
-        _, _, values = self.model(**inputs, output_hidden_states=True, return_dict=True)
-
-        if getattr(self.model.config, "model_type", None) == "chatglm":
-            values = torch.transpose(values, 0, 1)
-
-        scores = []
-        for i in range(input_ids.size(0)):
-            end_indexes = (input_ids[i] != self.tokenizer.pad_token_id).nonzero()
-            end_index = end_indexes[-1].item() if len(end_indexes) else 0
-            scores.append(values[i, end_index].nan_to_num().item())
-
-        return scores
+        generator = self.astream_chat(messages, system, tools, **input_kwargs)
+        while True:
+            try:
+                task = asyncio.run_coroutine_threadsafe(generator.__anext__(), self._loop)
+                yield task.result()
+            except StopAsyncIteration:
+                break
+
+    async def astream_chat(
+        self,
+        messages: Sequence[Dict[str, str]],
+        system: Optional[str] = None,
+        tools: Optional[str] = None,
+        **input_kwargs,
+    ) -> AsyncGenerator[str, None]:
+        async for new_token in self.engine.stream_chat(messages, system, tools, **input_kwargs):
+            yield new_token
+
+    def get_scores(
+        self,
+        batch_input: List[str],
+        **input_kwargs,
+    ) -> List[float]:
+        task = asyncio.run_coroutine_threadsafe(self.aget_scores(batch_input, **input_kwargs), self._loop)
+        return task.result()
+
+    async def aget_scores(
+        self,
+        batch_input: List[str],
+        **input_kwargs,
+    ) -> List[float]:
+        return await self.engine.get_scores(batch_input, **input_kwargs)
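After this rewrite ChatModel only dispatches to an engine and drives it from a private event loop, so the existing synchronous chat/stream_chat/get_scores calls keep working while achat/astream_chat/aget_scores expose the same operations to async callers. A minimal usage sketch; the model path and template are placeholders, not values taken from this diff:

    from llmtuner import ChatModel

    # Any key accepted by get_infer_args can be passed in the dict.
    chat_model = ChatModel(dict(model_name_or_path="path/to/llama2-7b-sft", template="default"))

    messages = [{"role": "user", "content": "Hello!"}]
    print(chat_model.chat(messages)[0].response_text)   # blocking call
    for token in chat_model.stream_chat(messages):      # blocking, token by token
        print(token, end="", flush=True)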
src/llmtuner/chat/hf_engine.py (new file, 263 lines)

import asyncio
import concurrent.futures
import os
from threading import Thread
from typing import TYPE_CHECKING, Any, AsyncGenerator, Callable, Dict, List, Optional, Sequence, Tuple

import torch
from transformers import GenerationConfig, TextIteratorStreamer

from ..data import get_template_and_fix_tokenizer
from ..extras.misc import get_logits_processor
from ..model import load_model_and_tokenizer
from .base_engine import BaseEngine, Response


if TYPE_CHECKING:
    from transformers import PreTrainedModel, PreTrainedTokenizer
    from trl import PreTrainedModelWrapper

    from ..data import Template
    from ..hparams import DataArguments, FinetuningArguments, GeneratingArguments, ModelArguments


class HuggingfaceEngine(BaseEngine):
    def __init__(
        self,
        model_args: "ModelArguments",
        data_args: "DataArguments",
        finetuning_args: "FinetuningArguments",
        generating_args: "GeneratingArguments",
    ) -> None:
        self.can_generate = finetuning_args.stage == "sft"
        self.model, self.tokenizer = load_model_and_tokenizer(
            model_args, finetuning_args, is_trainable=False, add_valuehead=(not self.can_generate)
        )
        self.tokenizer.padding_side = "left" if self.can_generate else "right"
        self.template = get_template_and_fix_tokenizer(self.tokenizer, data_args.template)
        self.generating_args = generating_args.to_dict()

    @staticmethod
    def _process_args(
        model: "PreTrainedModel",
        tokenizer: "PreTrainedTokenizer",
        template: "Template",
        generating_args: Dict[str, Any],
        messages: Sequence[Dict[str, str]],
        system: Optional[str] = None,
        tools: Optional[str] = None,
        input_kwargs: Optional[Dict[str, Any]] = {},
    ) -> Tuple[Dict[str, Any], int]:
        paired_messages = messages + [{"role": "assistant", "content": ""}]
        prompt_ids, _ = template.encode_oneturn(
            tokenizer=tokenizer, messages=paired_messages, system=system, tools=tools
        )
        prompt_length = len(prompt_ids)
        inputs = torch.tensor([prompt_ids], device=model.device)

        do_sample = input_kwargs.pop("do_sample", None)
        temperature = input_kwargs.pop("temperature", None)
        top_p = input_kwargs.pop("top_p", None)
        top_k = input_kwargs.pop("top_k", None)
        num_return_sequences = input_kwargs.pop("num_return_sequences", None)
        repetition_penalty = input_kwargs.pop("repetition_penalty", None)
        max_length = input_kwargs.pop("max_length", None)
        max_new_tokens = input_kwargs.pop("max_new_tokens", None)

        generating_args.update(
            dict(
                do_sample=do_sample if do_sample is not None else generating_args["do_sample"],
                temperature=temperature or generating_args["temperature"],
                top_p=top_p or generating_args["top_p"],
                top_k=top_k or generating_args["top_k"],
                num_return_sequences=num_return_sequences or 1,
                repetition_penalty=repetition_penalty or generating_args["repetition_penalty"],
                eos_token_id=[tokenizer.eos_token_id] + tokenizer.additional_special_tokens_ids,
                pad_token_id=tokenizer.pad_token_id,
            )
        )

        if isinstance(num_return_sequences, int) and num_return_sequences > 1:
            generating_args["do_sample"] = True

        if max_length:
            generating_args.pop("max_new_tokens", None)
            generating_args["max_length"] = max_length

        if max_new_tokens:
            generating_args.pop("max_length", None)
            generating_args["max_new_tokens"] = max_new_tokens

        gen_kwargs = dict(
            inputs=inputs,
            generation_config=GenerationConfig(**generating_args),
            logits_processor=get_logits_processor(),
        )

        return gen_kwargs, prompt_length

    @staticmethod
    @torch.inference_mode()
    def _chat(
        model: "PreTrainedModel",
        tokenizer: "PreTrainedTokenizer",
        template: "Template",
        generating_args: Dict[str, Any],
        messages: Sequence[Dict[str, str]],
        system: Optional[str] = None,
        tools: Optional[str] = None,
        input_kwargs: Optional[Dict[str, Any]] = {},
    ) -> List["Response"]:
        gen_kwargs, prompt_length = HuggingfaceEngine._process_args(
            model, tokenizer, template, generating_args, messages, system, tools, input_kwargs
        )
        generate_output = model.generate(**gen_kwargs)
        response_ids = generate_output[:, prompt_length:]
        response = tokenizer.batch_decode(response_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True)
        results = []
        for i in range(len(response)):
            eos_index = (response_ids[i] == tokenizer.eos_token_id).nonzero()
            response_length = (eos_index[0].item() + 1) if len(eos_index) else len(response_ids[i])
            results.append(
                Response(
                    response_text=response[i],
                    response_length=response_length,
                    prompt_length=prompt_length,
                    finish_reason="stop" if len(eos_index) else "length",
                )
            )

        return results

    @staticmethod
    @torch.inference_mode()
    def _stream_chat(
        model: "PreTrainedModel",
        tokenizer: "PreTrainedTokenizer",
        template: "Template",
        generating_args: Dict[str, Any],
        messages: Sequence[Dict[str, str]],
        system: Optional[str] = None,
        tools: Optional[str] = None,
        input_kwargs: Optional[Dict[str, Any]] = {},
    ) -> Callable[[], str]:
        gen_kwargs, _ = HuggingfaceEngine._process_args(
            model, tokenizer, template, generating_args, messages, system, tools, input_kwargs
        )
        streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
        gen_kwargs["streamer"] = streamer
        thread = Thread(target=model.generate, kwargs=gen_kwargs, daemon=True)
        thread.start()

        def stream():
            try:
                return streamer.__next__()
            except StopIteration:
                raise StopAsyncIteration()

        return stream

    @staticmethod
    @torch.inference_mode()
    def _get_scores(
        model: "PreTrainedModelWrapper",
        tokenizer: "PreTrainedTokenizer",
        batch_input: List[str],
        input_kwargs: Optional[Dict[str, Any]] = {},
    ) -> List[float]:
        max_length = input_kwargs.pop("max_length", None)
        device = getattr(model.pretrained_model, "device", "cuda")
        inputs = tokenizer(
            batch_input,
            padding=True,
            truncation=True,
            max_length=max_length or getattr(model.config, "max_position_embeddings", 1024),
            return_tensors="pt",
            add_special_tokens=True,
        ).to(device)

        input_ids: torch.Tensor = inputs["input_ids"]
        _, _, values = model(**inputs, output_hidden_states=True, return_dict=True)

        if getattr(model.config, "model_type", None) == "chatglm":
            values = torch.transpose(values, 0, 1)

        scores = []
        for i in range(input_ids.size(0)):
            end_indexes = (input_ids[i] != tokenizer.pad_token_id).nonzero()
            end_index = end_indexes[-1].item() if len(end_indexes) else 0
            scores.append(values[i, end_index].nan_to_num().item())

        return scores

    async def start(self) -> None:
        self._semaphore = asyncio.Semaphore(int(os.environ.get("MAX_CONCURRENT", 1)))

    async def chat(
        self,
        messages: Sequence[Dict[str, str]],
        system: Optional[str] = None,
        tools: Optional[str] = None,
        **input_kwargs,
    ) -> List["Response"]:
        if not self.can_generate:
            raise ValueError("The current model does not support `chat`.")

        loop = asyncio.get_running_loop()
        input_args = (
            self.model,
            self.tokenizer,
            self.template,
            self.generating_args,
            messages,
            system,
            tools,
            input_kwargs,
        )
        async with self._semaphore:
            with concurrent.futures.ThreadPoolExecutor() as pool:
                return await loop.run_in_executor(pool, self._chat, *input_args)

    async def stream_chat(
        self,
        messages: Sequence[Dict[str, str]],
        system: Optional[str] = None,
        tools: Optional[str] = None,
        **input_kwargs,
    ) -> AsyncGenerator[str, None]:
        if not self.can_generate:
            raise ValueError("The current model does not support `stream_chat`.")

        loop = asyncio.get_running_loop()
        input_args = (
            self.model,
            self.tokenizer,
            self.template,
            self.generating_args,
            messages,
            system,
            tools,
            input_kwargs,
        )
        async with self._semaphore:
            with concurrent.futures.ThreadPoolExecutor() as pool:
                stream = self._stream_chat(*input_args)
                while True:
                    try:
                        yield await loop.run_in_executor(pool, stream)
                    except StopAsyncIteration:
                        break

    async def get_scores(
        self,
        batch_input: List[str],
        **input_kwargs,
    ) -> List[float]:
        if self.can_generate:
            raise ValueError("Cannot get scores using an auto-regressive model.")

        loop = asyncio.get_running_loop()
        input_args = (self.model, self.tokenizer, batch_input, input_kwargs)
        async with self._semaphore:
            with concurrent.futures.ThreadPoolExecutor() as pool:
                return await loop.run_in_executor(pool, self._get_scores, *input_args)
src/llmtuner/chat/vllm_engine.py (new file, 149 lines)

import uuid
from typing import TYPE_CHECKING, AsyncGenerator, AsyncIterator, Dict, List, Optional, Sequence

from transformers.utils.versions import require_version

from ..data import get_template_and_fix_tokenizer
from ..extras.misc import get_device_count
from ..extras.packages import is_vllm_available
from ..model import load_tokenizer
from .base_engine import BaseEngine, Response


if is_vllm_available():
    from vllm import AsyncEngineArgs, AsyncLLMEngine, RequestOutput, SamplingParams

if TYPE_CHECKING:
    from ..hparams import DataArguments, FinetuningArguments, GeneratingArguments, ModelArguments


class VllmEngine(BaseEngine):
    def __init__(
        self,
        model_args: "ModelArguments",
        data_args: "DataArguments",
        finetuning_args: "FinetuningArguments",
        generating_args: "GeneratingArguments",
    ) -> None:
        require_version("vllm>=0.3.3", "To fix: pip install vllm>=0.3.3")
        self.can_generate = finetuning_args.stage == "sft"
        engine_args = AsyncEngineArgs(
            model=model_args.model_name_or_path,
            trust_remote_code=True,
            max_model_len=model_args.vllm_maxlen,
            tensor_parallel_size=get_device_count() or 1,
            gpu_memory_utilization=model_args.vllm_gpu_util,
            disable_log_stats=True,
            disable_log_requests=True,
            enforce_eager=model_args.vllm_enforce_eager,
        )
        self.model = AsyncLLMEngine.from_engine_args(engine_args)
        self.tokenizer = load_tokenizer(model_args)
        self.tokenizer.padding_side = "left"
        self.template = get_template_and_fix_tokenizer(self.tokenizer, data_args.template)
        self.generating_args = generating_args.to_dict()

    async def _generate(
        self,
        messages: Sequence[Dict[str, str]],
        system: Optional[str] = None,
        tools: Optional[str] = None,
        **input_kwargs,
    ) -> AsyncIterator["RequestOutput"]:
        request_id = "chatcmpl-{}".format(uuid.uuid4().hex)
        paired_messages = messages + [{"role": "assistant", "content": ""}]
        prompt_ids, _ = self.template.encode_oneturn(
            tokenizer=self.tokenizer, messages=paired_messages, system=system, tools=tools
        )
        prompt_length = len(prompt_ids)

        temperature = input_kwargs.pop("temperature", None)
        top_p = input_kwargs.pop("top_p", None)
        top_k = input_kwargs.pop("top_k", None)
        num_return_sequences = input_kwargs.pop("num_return_sequences", None)
        repetition_penalty = input_kwargs.pop("repetition_penalty", None)
        max_length = input_kwargs.pop("max_length", None)
        max_new_tokens = input_kwargs.pop("max_new_tokens", None)

        generating_args = self.generating_args.copy()
        generating_args.update(
            dict(
                temperature=temperature or generating_args["temperature"],
                top_p=top_p or generating_args["top_p"],
                top_k=top_k or generating_args["top_k"],
                num_return_sequences=num_return_sequences or 1,
                repetition_penalty=repetition_penalty or generating_args["repetition_penalty"],
            )
        )

        if max_length:
            generating_args["max_new_tokens"] = max_length - prompt_length

        if max_new_tokens:
            generating_args["max_new_tokens"] = max_new_tokens

        sampling_params = SamplingParams(
            n=generating_args["num_return_sequences"],
            repetition_penalty=generating_args["repetition_penalty"],
            temperature=generating_args["temperature"],
            top_p=generating_args["top_p"],
            top_k=generating_args["top_k"],
            use_beam_search=generating_args["num_beams"] > 1,
            length_penalty=generating_args["length_penalty"],
            stop_token_ids=[self.tokenizer.eos_token_id] + self.tokenizer.additional_special_tokens_ids,
            max_tokens=generating_args["max_new_tokens"],
            skip_special_tokens=True,
        )
        result_generator = self.model.generate(
            prompt=None, sampling_params=sampling_params, request_id=request_id, prompt_token_ids=prompt_ids
        )
        return result_generator

    async def start(self) -> None:
        pass

    async def chat(
        self,
        messages: Sequence[Dict[str, str]],
        system: Optional[str] = None,
        tools: Optional[str] = None,
        **input_kwargs,
    ) -> List["Response"]:
        final_output = None
        generator = await self._generate(messages, system, tools, **input_kwargs)
        async for request_output in generator:
            final_output = request_output

        results = []
        for output in final_output.outputs:
            results.append(
                Response(
                    response_text=output.text,
                    response_length=len(output.token_ids),
                    prompt_length=len(final_output.prompt_token_ids),
                    finish_reason=output.finish_reason,
                )
            )

        return results

    async def stream_chat(
        self,
        messages: Sequence[Dict[str, str]],
        system: Optional[str] = None,
        tools: Optional[str] = None,
        **input_kwargs,
    ) -> AsyncGenerator[str, None]:
        generated_text = ""
        generator = await self._generate(messages, system, tools, **input_kwargs)
        async for result in generator:
            delta_text = result.outputs[0].text[len(generated_text) :]
            generated_text = result.outputs[0].text
            yield delta_text

    async def get_scores(
        self,
        batch_input: List[str],
        **input_kwargs,
    ) -> List[float]:
        raise NotImplementedError("vLLM engine does not support get_scores.")
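With this engine in place the inference backend is selected through model_args.infer_backend; nothing else in ChatModel changes. A minimal sketch, assuming a vLLM install that satisfies the vllm>=0.3.3 check above (the model path is a placeholder):

    from llmtuner import ChatModel

    # Same entry point as before; only infer_backend decides which engine is constructed.
    chat_model = ChatModel(dict(
        model_name_or_path="path/to/llama2-7b-sft",  # placeholder path
        template="default",
        infer_backend="vllm",
    ))
    print(chat_model.chat([{"role": "user", "content": "Hi"}])[0].response_text)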
@@ -1,6 +1,6 @@
 from .loader import get_dataset
-from .template import get_template_and_fix_tokenizer, templates
+from .template import Template, get_template_and_fix_tokenizer, templates
 from .utils import Role, split_dataset
 
 
-__all__ = ["get_dataset", "get_template_and_fix_tokenizer", "templates", "Role", "split_dataset"]
+__all__ = ["get_dataset", "Template", "get_template_and_fix_tokenizer", "templates", "Role", "split_dataset"]
@@ -19,8 +19,8 @@ def convert_alpaca(examples: Dict[str, List[Any]], dataset_attr: "DatasetAttr")
         prompt = []
         if dataset_attr.history and isinstance(examples[dataset_attr.history][i], list):
             for old_prompt, old_response in examples[dataset_attr.history][i]:
-                prompt.append({"role": Role.USER, "content": old_prompt})
-                prompt.append({"role": Role.ASSISTANT, "content": old_response})
+                prompt.append({"role": Role.USER.value, "content": old_prompt})
+                prompt.append({"role": Role.ASSISTANT.value, "content": old_response})
 
         content = []
         if dataset_attr.prompt and examples[dataset_attr.prompt][i]:
@@ -29,12 +29,14 @@ def convert_alpaca(examples: Dict[str, List[Any]], dataset_attr: "DatasetAttr")
         if dataset_attr.query and examples[dataset_attr.query][i]:
             content.append(examples[dataset_attr.query][i])
 
-        prompt.append({"role": Role.USER, "content": "\n".join(content)})
+        prompt.append({"role": Role.USER.value, "content": "\n".join(content)})
 
         if dataset_attr.response and isinstance(examples[dataset_attr.response][i], list):
-            response = [{"role": Role.ASSISTANT, "content": content} for content in examples[dataset_attr.response][i]]
+            response = [
+                {"role": Role.ASSISTANT.value, "content": content} for content in examples[dataset_attr.response][i]
+            ]
         elif dataset_attr.response and isinstance(examples[dataset_attr.response][i], str):
-            response = [{"role": Role.ASSISTANT, "content": examples[dataset_attr.response][i]}]
+            response = [{"role": Role.ASSISTANT.value, "content": examples[dataset_attr.response][i]}]
         else:
             response = []
 
@@ -49,11 +51,11 @@ def convert_alpaca(examples: Dict[str, List[Any]], dataset_attr: "DatasetAttr")
 def convert_sharegpt(examples: Dict[str, List[Any]], dataset_attr: "DatasetAttr") -> Dict[str, List[Any]]:
     outputs = {"prompt": [], "response": [], "system": [], "tools": []}
     tag_mapping = {
-        dataset_attr.user_tag: Role.USER,
-        dataset_attr.assistant_tag: Role.ASSISTANT,
-        dataset_attr.observation_tag: Role.OBSERVATION,
-        dataset_attr.function_tag: Role.FUNCTION,
-        dataset_attr.system_tag: Role.SYSTEM,
+        dataset_attr.user_tag: Role.USER.value,
+        dataset_attr.assistant_tag: Role.ASSISTANT.value,
+        dataset_attr.observation_tag: Role.OBSERVATION.value,
+        dataset_attr.function_tag: Role.FUNCTION.value,
+        dataset_attr.system_tag: Role.SYSTEM.value,
     }
     odd_tags = (dataset_attr.user_tag, dataset_attr.observation_tag)
     even_tags = (dataset_attr.assistant_tag, dataset_attr.function_tag)
@@ -2,7 +2,7 @@ import json
 import re
 from abc import ABC, abstractmethod
 from dataclasses import dataclass, field
-from typing import Any, Dict, List, Literal, Sequence, Set, Tuple, Union
+from typing import Any, Dict, List, Literal, Optional, Sequence, Set, Tuple, Union
 
 
 SLOTS = Sequence[Union[str, Set[str], Dict[str, str]]]
@@ -72,7 +72,7 @@ def default_tool_extractor(content: str) -> Union[str, Tuple[str, str]]:
 @dataclass
 class Formatter(ABC):
     slots: SLOTS = field(default_factory=list)
-    tool_format: Literal["default"] = "default"
+    tool_format: Optional[Literal["default"]] = None
 
     @abstractmethod
     def apply(self, **kwargs) -> SLOTS: ...
@@ -83,12 +83,30 @@ class Formatter(ABC):
 
 @dataclass
 class EmptyFormatter(Formatter):
+    def __post_init__(self):
+        has_placeholder = False
+        for slot in filter(lambda s: isinstance(s, str), self.slots):
+            if re.search(r"\{\{[a-zA-Z_][a-zA-Z0-9_]*\}\}", slot):
+                has_placeholder = True
+
+        if has_placeholder:
+            raise ValueError("Empty formatter should not contain any placeholder.")
+
     def apply(self, **kwargs) -> SLOTS:
         return self.slots
 
 
 @dataclass
 class StringFormatter(Formatter):
+    def __post_init__(self):
+        has_placeholder = False
+        for slot in filter(lambda s: isinstance(s, str), self.slots):
+            if re.search(r"\{\{[a-zA-Z_][a-zA-Z0-9_]*\}\}", slot):
+                has_placeholder = True
+
+        if not has_placeholder:
+            raise ValueError("A placeholder is required in the string formatter.")
+
     def apply(self, **kwargs) -> SLOTS:
         elements = []
         for slot in self.slots:
@@ -109,6 +127,17 @@ class StringFormatter(Formatter):
 
 @dataclass
 class FunctionFormatter(Formatter):
+    def __post_init__(self):
+        has_name, has_args = False, False
+        for slot in filter(lambda s: isinstance(s, str), self.slots):
+            if "{{name}}" in slot:
+                has_name = True
+            if "{{arguments}}" in slot:
+                has_args = True
+
+        if not has_name or not has_args:
+            raise ValueError("Name and arguments placeholders are required in the function formatter.")
+
     def apply(self, **kwargs) -> SLOTS:
         content = kwargs.pop("content")
         try:
@@ -133,6 +162,10 @@ class FunctionFormatter(Formatter):
 
 @dataclass
 class ToolFormatter(Formatter):
+    def __post_init__(self):
+        if self.tool_format is None:
+            raise ValueError("Tool format was not found.")
+
     def apply(self, **kwargs) -> SLOTS:
         content = kwargs.pop("content")
         try:
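The new __post_init__ hooks turn malformed formatter definitions into construction-time errors instead of silent misbehaviour at render time. A minimal sketch of what that means in practice, assuming the formatters live in llmtuner.data.formatter (the module path is inferred, it is not shown in these hunks):

    from llmtuner.data.formatter import StringFormatter

    # A slot with a {{content}} placeholder is accepted...
    ok = StringFormatter(slots=["Human: {{content}}\nAssistant: "])

    # ...while a slot without any placeholder now fails fast.
    try:
        StringFormatter(slots=["no placeholder here"])
    except ValueError as err:
        print(err)  # "A placeholder is required in the string formatter."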
@@ -1,8 +1,8 @@
 import inspect
 import os
-from typing import TYPE_CHECKING, List, Literal, Union
+from typing import TYPE_CHECKING, Literal, Union

-from datasets import concatenate_datasets, interleave_datasets, load_dataset, load_from_disk
+from datasets import load_dataset, load_from_disk

 from ..extras.constants import FILEEXT2TYPE
 from ..extras.logging import get_logger
@@ -10,7 +10,7 @@ from .aligner import align_dataset
 from .parser import get_dataset_list
 from .preprocess import get_preprocess_and_print_func
 from .template import get_template_and_fix_tokenizer
-from .utils import checksum
+from .utils import checksum, merge_dataset


 if TYPE_CHECKING:
@@ -29,7 +29,7 @@ def load_single_dataset(
     dataset_attr: "DatasetAttr",
     model_args: "ModelArguments",
     data_args: "DataArguments",
-):
+) -> Union["Dataset", "IterableDataset"]:
     logger.info("Loading dataset {}...".format(dataset_attr))
     data_path, data_name, data_dir, data_files = None, None, None, None
     if dataset_attr.load_from in ["hf_hub", "ms_hub"]:
@@ -44,7 +44,7 @@ def load_single_dataset(

     elif dataset_attr.load_from == "file":
         data_files = []
-        local_path: str = os.path.join(data_args.dataset_dir, dataset_attr.dataset_name)
+        local_path = os.path.join(data_args.dataset_dir, dataset_attr.dataset_name)
         if os.path.isdir(local_path):  # is directory
             for file_name in os.listdir(local_path):
                 data_files.append(os.path.join(local_path, file_name))
@@ -111,30 +111,6 @@ def load_single_dataset(
     return align_dataset(dataset, dataset_attr, data_args)


-def merge_dataset(
-    all_datasets: List[Union["Dataset", "IterableDataset"]],
-    data_args: "DataArguments",
-    training_args: "Seq2SeqTrainingArguments",
-) -> Union["Dataset", "IterableDataset"]:
-    if len(all_datasets) == 1:
-        return all_datasets[0]
-    elif data_args.mix_strategy == "concat":
-        if data_args.streaming:
-            logger.warning("The samples between different datasets will not be mixed in streaming mode.")
-        return concatenate_datasets(all_datasets)
-    elif data_args.mix_strategy.startswith("interleave"):
-        if not data_args.streaming:
-            logger.warning("We recommend using `mix_strategy=concat` in non-streaming mode.")
-        return interleave_datasets(
-            datasets=all_datasets,
-            probabilities=data_args.interleave_probs,
-            seed=training_args.seed,
-            stopping_strategy="first_exhausted" if data_args.mix_strategy.endswith("under") else "all_exhausted",
-        )
-    else:
-        raise ValueError("Unknown mixing strategy.")
-
-
 def get_dataset(
     tokenizer: "PreTrainedTokenizer",
     model_args: "ModelArguments",
@@ -156,6 +132,9 @@ def get_dataset(
             dataset = dataset.to_iterable_dataset()
         return dataset

+        if data_args.streaming:
+            raise ValueError("Turn off `streaming` when saving dataset to disk.")
+
     with training_args.main_process_first(desc="load dataset"):
         all_datasets = []
         for dataset_attr in get_dataset_list(data_args):

@@ -19,13 +19,13 @@ class DatasetAttr:

     """ basic configs """
     load_from: Literal["hf_hub", "ms_hub", "script", "file"]
-    dataset_name: Optional[str] = None
+    dataset_name: str
     """ extra configs """
     file_sha1: Optional[str] = None
     subset: Optional[str] = None
     folder: Optional[str] = None
-    ranking: Optional[bool] = False
-    formatting: Optional[Literal["alpaca", "sharegpt"]] = "alpaca"
+    ranking: bool = False
+    formatting: Literal["alpaca", "sharegpt"] = "alpaca"
     """ columns """
     system: Optional[str] = None
     """ columns for the alpaca format """

@@ -21,8 +21,11 @@ logger = get_logger(__name__)
 def preprocess_pretrain_dataset(
     examples: Dict[str, List[Any]], tokenizer: "PreTrainedTokenizer", data_args: "DataArguments"
 ) -> Dict[str, List[List[int]]]:
-    # build grouped texts with format `X1 X2 X3 ...`
+    # build grouped texts with format `X1 X2 X3 ...` if packing is enabled
     text_examples = [messages[0]["content"] + tokenizer.eos_token for messages in examples["prompt"]]
+    if not data_args.packing:
+        return tokenizer(text_examples, add_special_tokens=False, max_length=data_args.cutoff_len)
+
     tokenized_examples = tokenizer(text_examples, add_special_tokens=False)
     concatenated_examples = {k: list(chain(*tokenized_examples[k])) for k in tokenized_examples.keys()}
     total_length = len(concatenated_examples[list(concatenated_examples.keys())[0]])
@@ -34,6 +37,10 @@ def preprocess_pretrain_dataset(
         k: [t[i : i + block_size] for i in range(0, total_length, block_size)]
         for k, t in concatenated_examples.items()
     }
+    if data_args.template == "gemma":
+        for i in range(len(result["input_ids"])):
+            result["input_ids"][i][0] = tokenizer.bos_token_id
+
     return result


@@ -99,12 +106,12 @@ def preprocess_packed_supervised_dataset(
             continue

         messages = examples["prompt"][i] + examples["response"][i]
-        for turn_idx, (source_ids, target_ids) in enumerate(
-            template.encode_multiturn(tokenizer, messages, examples["system"][i], examples["tools"][i])
+        for source_ids, target_ids in template.encode_multiturn(
+            tokenizer, messages, examples["system"][i], examples["tools"][i]
         ):
             if data_args.train_on_prompt:
                 source_mask = source_ids
-            elif turn_idx != 0 and template.efficient_eos:
+            elif len(input_ids) != 0 and template.efficient_eos:
                 source_mask = [tokenizer.eos_token_id] + [IGNORE_INDEX] * (len(source_ids) - 1)
             else:
                 source_mask = [IGNORE_INDEX] * len(source_ids)
@@ -122,9 +129,10 @@ def preprocess_packed_supervised_dataset(
     total_length = (total_length // block_size) * block_size
     # split by chunks of cutoff_len
     for i in range(0, total_length, block_size):
-        model_inputs["input_ids"].append(input_ids[i : i + block_size])
-        model_inputs["attention_mask"].append([1] * block_size)
-        model_inputs["labels"].append(labels[i : i + block_size])
+        if not all(label == IGNORE_INDEX for label in labels[i : i + block_size]):
+            model_inputs["input_ids"].append(input_ids[i : i + block_size])
+            model_inputs["attention_mask"].append([1] * block_size)
+            model_inputs["labels"].append(labels[i : i + block_size])

     return model_inputs

@@ -145,7 +153,7 @@ def preprocess_unsupervised_dataset(
         if len(examples["response"][i]) == 1:
             messages = examples["prompt"][i] + examples["response"][i]
         else:
-            messages = examples["prompt"][i] + [{"role": Role.ASSISTANT, "content": ""}]
+            messages = examples["prompt"][i] + [{"role": Role.ASSISTANT.value, "content": ""}]

         input_ids, labels = template.encode_oneturn(
             tokenizer,
@@ -180,7 +188,6 @@ def preprocess_pairwise_dataset(

         chosen_messages = examples["prompt"][i] + [examples["response"][i][0]]
         rejected_messages = examples["prompt"][i] + [examples["response"][i][1]]
-
         prompt_ids, chosen_ids = template.encode_oneturn(
             tokenizer,
             chosen_messages,
@@ -245,7 +252,7 @@ def get_preprocess_and_print_func(
         preprocess_func = partial(preprocess_pretrain_dataset, tokenizer=tokenizer, data_args=data_args)
         print_function = partial(print_unsupervised_dataset_example, tokenizer=tokenizer)
     elif stage == "sft" and not training_args.predict_with_generate:
-        if data_args.sft_packing:
+        if data_args.packing:
             preprocess_func = partial(
                 preprocess_packed_supervised_dataset, tokenizer=tokenizer, template=template, data_args=data_args
             )

@@ -9,7 +9,7 @@ from .utils import Role, infer_max_len
 if TYPE_CHECKING:
     from transformers import PreTrainedTokenizer

-    from .formatter import Formatter
+    from .formatter import SLOTS, Formatter


 logger = get_logger(__name__)
@@ -36,8 +36,8 @@ class Template:
         messages: List[Dict[str, str]],
         system: Optional[str] = None,
         tools: Optional[str] = None,
-        cutoff_len: Optional[int] = 1_000_000,
-        reserved_label_len: Optional[int] = 1,
+        cutoff_len: int = 1_000_000,
+        reserved_label_len: int = 1,
     ) -> Tuple[List[int], List[int]]:
         r"""
         Returns a single pair of token ids representing prompt and response respectively.
@@ -56,8 +56,8 @@ class Template:
         messages: List[Dict[str, str]],
         system: Optional[str] = None,
         tools: Optional[str] = None,
-        cutoff_len: Optional[int] = 1_000_000,
-        reserved_label_len: Optional[int] = 1,
+        cutoff_len: int = 1_000_000,
+        reserved_label_len: int = 1,
     ) -> Sequence[Tuple[List[int], List[int]]]:
         r"""
         Returns multiple pairs of token ids representing prompts and responses respectively.
@@ -88,16 +88,16 @@ class Template:
             elif i > 0 and i % 2 == 0:
                 elements += self.format_separator.apply()

-            if message["role"] == Role.USER:
+            if message["role"] == Role.USER.value:
                 elements += self.format_user.apply(content=message["content"], idx=str(i // 2))
-            elif message["role"] == Role.ASSISTANT:
+            elif message["role"] == Role.ASSISTANT.value:
                 elements += self.format_assistant.apply(content=message["content"])
-            elif message["role"] == Role.OBSERVATION:
+            elif message["role"] == Role.OBSERVATION.value:
                 elements += self.format_observation.apply(content=message["content"])
-            elif message["role"] == Role.FUNCTION:
+            elif message["role"] == Role.FUNCTION.value:
                 elements += self.format_function.apply(content=message["content"])
             else:
-                raise NotImplementedError
+                raise NotImplementedError("Unexpected role: {}".format(message["role"]))

             encoded_messages.append(self._convert_elements_to_ids(tokenizer, elements))

@@ -179,16 +179,16 @@ class Llama2Template(Template):
             elif i > 0 and i % 2 == 0:
                 elements += self.format_separator.apply()

-            if message["role"] == Role.USER:
+            if message["role"] == Role.USER.value:
                 elements += self.format_user.apply(content=system_text + message["content"])
-            elif message["role"] == Role.ASSISTANT:
+            elif message["role"] == Role.ASSISTANT.value:
                 elements += self.format_assistant.apply(content=message["content"])
-            elif message["role"] == Role.OBSERVATION:
+            elif message["role"] == Role.OBSERVATION.value:
                 elements += self.format_observation.apply(content=message["content"])
-            elif message["role"] == Role.FUNCTION:
+            elif message["role"] == Role.FUNCTION.value:
                 elements += self.format_function.apply(content=message["content"])
             else:
-                raise NotImplementedError
+                raise NotImplementedError("Unexpected role: {}".format(message["role"]))

             encoded_messages.append(self._convert_elements_to_ids(tokenizer, elements))

@@ -207,12 +207,38 @@ def _register_template(
     format_observation: Optional["Formatter"] = None,
     format_tools: Optional["Formatter"] = None,
     format_separator: Optional["Formatter"] = None,
-    default_system: Optional[str] = "",
-    stop_words: Optional[List[str]] = [],
-    efficient_eos: Optional[bool] = False,
-    replace_eos: Optional[bool] = False,
-    force_system: Optional[bool] = False,
+    default_system: str = "",
+    stop_words: List[str] = [],
+    efficient_eos: bool = False,
+    replace_eos: bool = False,
+    force_system: bool = False,
 ) -> None:
+    r"""
+    Registers a chat template.
+
+    To add the following chat template:
+    ```
+    [HUMAN]:
+    user prompt here
+    [AI]:
+    model response here
+
+    [HUMAN]:
+    user prompt here
+    [AI]:
+    model response here
+    ```
+
+    The corresponding code should be:
+    ```
+    _register_template(
+        name="custom",
+        format_user=StringFormatter(slots=["[HUMAN]:\n{{content}}\n[AI]:\n"]),
+        format_separator=EmptyFormatter(slots=["\n\n"]),
+        efficient_eos=True,
+    )
+    ```
+    """
     eos_slots = [] if efficient_eos else [{"eos_token"}]
     template_class = Llama2Template if name.startswith("llama2") else Template
     default_user_formatter = StringFormatter(slots=["{{content}}"])
@@ -238,18 +264,80 @@ def _register_template(

 def _add_or_replace_eos_token(tokenizer: "PreTrainedTokenizer", eos_token: str) -> None:
     is_added = tokenizer.eos_token_id is None
-    is_oov = eos_token not in tokenizer.get_vocab()
-    tokenizer.add_special_tokens({"eos_token": eos_token})
+    num_added_tokens = tokenizer.add_special_tokens({"eos_token": eos_token})

     if is_added:
         logger.info("Add eos token: {}".format(tokenizer.eos_token))
     else:
         logger.info("Replace eos token: {}".format(tokenizer.eos_token))

-    if is_oov:
+    if num_added_tokens > 0:
         logger.warning("New tokens have been added, make sure `resize_vocab` is True.")


+def _jinja_escape(content: str) -> str:
+    return content.replace("\n", r"\n").replace("'", r"\'")
+
+
+def _convert_slots_to_jinja(slots: "SLOTS", tokenizer: "PreTrainedTokenizer", placeholder: str = "content") -> str:
+    slot_items = []
+    for slot in slots:
+        if isinstance(slot, str):
+            slot_pieces = slot.split("{{content}}")
+            if slot_pieces[0]:
+                slot_items.append("'" + _jinja_escape(slot_pieces[0]) + "'")
+            if len(slot_pieces) > 1:
+                slot_items.append(placeholder)
+                if slot_pieces[1]:
+                    slot_items.append("'" + _jinja_escape(slot_pieces[1]) + "'")
+        elif isinstance(slot, set):
+            if "bos_token" in slot:
+                slot_items.append("'" + tokenizer.bos_token + "'")
+            elif "eos_token" in slot:  # do not use {{ eos_token }} since it may be replaced
+                slot_items.append("'" + tokenizer.eos_token + "'")
+        elif isinstance(slot, dict):
+            raise ValueError("Dict is not supported.")
+
+    return " + ".join(slot_items)
+
+
+def _get_jinja_template(template: "Template", tokenizer: "PreTrainedTokenizer") -> str:
+    jinja_template = ""
+
+    if template.default_system:
+        jinja_template += "{% set system_message = '" + _jinja_escape(template.default_system) + "' %}"
+
+    jinja_template += (
+        "{% if messages[0]['role'] == 'system' %}" "{% set system_message = messages[0]['content'] %}" "{% endif %}"
+    )
+
+    system_message = _convert_slots_to_jinja(template.format_system.apply(), tokenizer, placeholder="system_message")
+    if isinstance(template, Llama2Template):
+        pass
+    elif template.force_system:
+        jinja_template += "{{ " + system_message + " }}"
+    else:
+        jinja_template += "{% if system_message is defined %}{{ " + system_message + " }}{% endif %}"
+
+    jinja_template += "{% for message in messages %}"
+    jinja_template += "{% set content = message['content'] %}"
+    if isinstance(template, Llama2Template):
+        jinja_template += "{% if loop.index0 == 0 and system_message is defined %}"
+        jinja_template += "{% set content = " + system_message + " + message['content'] %}"
+        jinja_template += "{% endif %}"
+    jinja_template += "{% if message['role'] == 'user' %}"
+    user_message = _convert_slots_to_jinja(template.format_user.apply(), tokenizer)
+    jinja_template += "{{ " + user_message + " }}"
+    jinja_template += "{% elif message['role'] == 'assistant' %}"
+    assistant_message = _convert_slots_to_jinja(
+        template.format_assistant.apply() + template.format_separator.apply(), tokenizer
+    )
+    jinja_template += "{{ " + assistant_message + " }}"
+    jinja_template += "{% endif %}"
+    jinja_template += "{% endfor %}"
+    return jinja_template
+
+
 def get_template_and_fix_tokenizer(
     tokenizer: "PreTrainedTokenizer",
     name: Optional[str] = None,
@@ -258,7 +346,7 @@ def get_template_and_fix_tokenizer(
         template = templates["vanilla"]  # placeholder
     else:
         template = templates.get(name, None)
-        if templates is None:
+        if template is None:
             raise ValueError("Template {} does not exist.".format(name))

     stop_words = template.stop_words
@@ -277,10 +365,17 @@ def get_template_and_fix_tokenizer(
         logger.info("Add pad token: {}".format(tokenizer.pad_token))

     if stop_words:
-        tokenizer.add_special_tokens(
+        num_added_tokens = tokenizer.add_special_tokens(
             dict(additional_special_tokens=stop_words), replace_additional_special_tokens=False
         )
         logger.info("Add {} to stop words.".format(",".join(stop_words)))
+        if num_added_tokens > 0:
+            logger.warning("New tokens have been added, make sure `resize_vocab` is True.")

+    try:
+        tokenizer.chat_template = _get_jinja_template(template, tokenizer)
+    except ValueError:
+        logger.info("Cannot add this chat template to tokenizer.")
+
     return template

@@ -308,16 +403,25 @@ _register_template(
 )


+_register_template(
+    name="atom",
+    format_user=StringFormatter(
+        slots=[{"bos_token"}, "Human: {{content}}\n", {"eos_token"}, {"bos_token"}, "Assistant:"]
+    ),
+    format_assistant=StringFormatter(slots=["{{content}}\n", {"eos_token"}]),
+)
+
+
 _register_template(
     name="baichuan",
-    format_user=StringFormatter(slots=[{"token": "<reserved_102>"}, "{{content}}", {"token": "<reserved_103>"}]),
+    format_user=StringFormatter(slots=["<reserved_102>{{content}}<reserved_103>"]),
     efficient_eos=True,
 )


 _register_template(
     name="baichuan2",
-    format_user=StringFormatter(slots=[{"token": "<reserved_106>"}, "{{content}}", {"token": "<reserved_107>"}]),
+    format_user=StringFormatter(slots=["<reserved_106>{{content}}<reserved_107>"]),
     efficient_eos=True,
 )

@@ -351,6 +455,21 @@ _register_template(
     name="chatglm3",
     format_user=StringFormatter(slots=[{"token": "<|user|>"}, "\n", "{{content}}", {"token": "<|assistant|>"}]),
     format_assistant=StringFormatter(slots=["\n", "{{content}}"]),
+    format_system=StringFormatter(slots=[{"token": "[gMASK]"}, {"token": "sop"}, "{{content}}"]),
+    format_function=FunctionFormatter(slots=["{{name}}\n{{arguments}}"]),
+    format_observation=StringFormatter(
+        slots=[{"token": "<|observation|>"}, "\n", "{{content}}", {"token": "<|assistant|>"}]
+    ),
+    stop_words=["<|user|>", "<|observation|>"],
+    efficient_eos=True,
+    force_system=True,
+)
+
+
+_register_template(
+    name="chatglm3_system",
+    format_user=StringFormatter(slots=[{"token": "<|user|>"}, "\n", "{{content}}", {"token": "<|assistant|>"}]),
+    format_assistant=StringFormatter(slots=["\n", "{{content}}"]),
     format_system=StringFormatter(
         slots=[{"token": "[gMASK]"}, {"token": "sop"}, {"token": "<|system|>"}, "\n", "{{content}}"]
     ),
@@ -367,13 +486,23 @@ _register_template(
 )


+_register_template(
+    name="chatml",
+    format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
+    format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
+    format_separator=EmptyFormatter(slots=["\n"]),
+    stop_words=["<|im_end|>", "<|im_start|>"],
+    replace_eos=True,
+)
+
+
 _register_template(
     name="chatml_de",
     format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
     format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
     format_separator=EmptyFormatter(slots=["\n"]),
     default_system="Du bist ein freundlicher und hilfsbereiter KI-Assistent.",
-    stop_words=["<|im_end|>"],
+    stop_words=["<|im_end|>", "<|im_start|>"],
     replace_eos=True,
 )

@@ -405,7 +534,7 @@ _register_template(
     name="deepseekcoder",
     format_user=StringFormatter(slots=["### Instruction:\n{{content}}\n### Response:"]),
     format_assistant=StringFormatter(slots=["\n", "{{content}}"]),
-    format_separator=EmptyFormatter(slots=["\n", {"token": "<|EOT|>"}, "\n"]),
+    format_separator=EmptyFormatter(slots=["\n<|EOT|>\n"]),
     default_system=(
         "You are an AI programming assistant, utilizing the Deepseek Coder model, "
         "developed by Deepseek Company, and you only answer questions related to computer science. "
@@ -433,6 +562,16 @@ _register_template(
 )


+_register_template(
+    name="gemma",
+    format_user=StringFormatter(slots=["<start_of_turn>user\n{{content}}<end_of_turn>\n<start_of_turn>model\n"]),
+    format_system=StringFormatter(slots=[{"bos_token"}, "{{content}}"]),
+    format_separator=EmptyFormatter(slots=["<end_of_turn>\n"]),
+    efficient_eos=True,
+    force_system=True,
+)
+
+
 _register_template(
     name="intern",
     format_user=StringFormatter(slots=["<|User|>:{{content}}", {"token": "<eoh>"}, "\n<|Bot|>:"]),
@@ -492,10 +631,19 @@ _register_template(
 )


+_register_template(
+    name="olmo",
+    format_user=StringFormatter(slots=["<|user|>\n{{content}}<|assistant|>"]),
+    format_assistant=StringFormatter(slots=["{{content}}", {"eos_token"}]),
+    format_system=StringFormatter(slots=[{"eos_token"}, "{{content}}"]),
+    force_system=True,
+)
+
+
 _register_template(
     name="openchat",
     format_user=StringFormatter(slots=["GPT4 Correct User: {{content}}", {"eos_token"}, "GPT4 Correct Assistant:"]),
-    format_assistant=StringFormatter(slots=["{{content}}"]),
+    format_assistant=StringFormatter(slots=["{{content}}", {"eos_token"}]),
     format_system=StringFormatter(slots=[{"bos_token"}, "{{content}}"]),
     force_system=True,
 )
@@ -530,10 +678,8 @@ _register_template(

 _register_template(
     name="starchat",
-    format_user=StringFormatter(
-        slots=[{"token": "<|user|>"}, "\n{{content}}", {"token": "<|end|>"}, "\n", {"token": "<|assistant|>"}]
-    ),
-    format_system=StringFormatter(slots=[{"token": "<|system|>"}, "\n{{content}}", {"token": "<|end|>"}, "\n"]),
+    format_user=StringFormatter(slots=["<|user|>\n{{content}}<|end|>\n<|assistant|>"]),
+    format_system=StringFormatter(slots=["<|system|>\n{{content}}<|end|>\n"]),
     format_separator=EmptyFormatter(slots=["\n"]),
     stop_words=["<|end|>"],
     replace_eos=True,
@@ -614,6 +760,7 @@ _register_template(
 _register_template(
     name="zephyr",
     format_user=StringFormatter(slots=["<|user|>\n{{content}}", {"eos_token"}, "<|assistant|>"]),
+    format_assistant=StringFormatter(slots=["\n{{content}}", {"eos_token"}]),
     format_system=StringFormatter(slots=["<|system|>\n{{content}}", {"eos_token"}]),
     default_system="You are a friendly chatbot who always responds in the style of a pirate",
 )
@@ -621,6 +768,6 @@ _register_template(

 _register_template(
     name="ziya",
-    format_user=StringFormatter(slots=[{"token": "<human>"}, ":{{content}}\n", {"token": "<bot>"}, ":"]),
+    format_user=StringFormatter(slots=["<human>:{{content}}\n<bot>:"]),
     format_separator=EmptyFormatter(slots=["\n"]),
 )

@@ -2,12 +2,14 @@ import hashlib
 from enum import Enum, unique
 from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union

+from datasets import concatenate_datasets, interleave_datasets
+
 from ..extras.logging import get_logger


 if TYPE_CHECKING:
     from datasets import Dataset, IterableDataset
-    from transformers import TrainingArguments
+    from transformers import Seq2SeqTrainingArguments

     from llmtuner.hparams import DataArguments

@@ -46,8 +48,32 @@ def infer_max_len(source_len: int, target_len: int, max_len: int, reserved_label
     return max_source_len, max_target_len


+def merge_dataset(
+    all_datasets: List[Union["Dataset", "IterableDataset"]],
+    data_args: "DataArguments",
+    training_args: "Seq2SeqTrainingArguments",
+) -> Union["Dataset", "IterableDataset"]:
+    if len(all_datasets) == 1:
+        return all_datasets[0]
+    elif data_args.mix_strategy == "concat":
+        if data_args.streaming:
+            logger.warning("The samples between different datasets will not be mixed in streaming mode.")
+        return concatenate_datasets(all_datasets)
+    elif data_args.mix_strategy.startswith("interleave"):
+        if not data_args.streaming:
+            logger.warning("We recommend using `mix_strategy=concat` in non-streaming mode.")
+        return interleave_datasets(
+            datasets=all_datasets,
+            probabilities=data_args.interleave_probs,
+            seed=training_args.seed,
+            stopping_strategy="first_exhausted" if data_args.mix_strategy.endswith("under") else "all_exhausted",
+        )
+    else:
+        raise ValueError("Unknown mixing strategy.")
+
+
 def split_dataset(
-    dataset: Union["Dataset", "IterableDataset"], data_args: "DataArguments", training_args: "TrainingArguments"
+    dataset: Union["Dataset", "IterableDataset"], data_args: "DataArguments", training_args: "Seq2SeqTrainingArguments"
 ) -> Dict[str, "Dataset"]:
     if training_args.do_train:
         if data_args.val_size > 1e-6:  # Split the dataset

@@ -14,7 +14,7 @@ from transformers.utils import cached_file
 from ..data import get_template_and_fix_tokenizer
 from ..extras.constants import CHOICES, SUBJECTS
 from ..hparams import get_eval_args
-from ..model import dispatch_model, load_model_and_tokenizer
+from ..model import load_model_and_tokenizer
 from .template import get_eval_template


@@ -23,7 +23,6 @@ class Evaluator:
         self.model_args, self.data_args, self.eval_args, finetuning_args = get_eval_args(args)
         self.model, self.tokenizer = load_model_and_tokenizer(self.model_args, finetuning_args)
         self.tokenizer.padding_side = "right"  # avoid overflow issue in batched inference for llama2
-        self.model = dispatch_model(self.model)
         self.template = get_template_and_fix_tokenizer(self.tokenizer, self.data_args.template)
         self.eval_template = get_eval_template(self.eval_args.lang)
         self.choice_inputs = [

@@ -324,6 +324,29 @@ register_model_group(
 )


+register_model_group(
+    models={
+        "Gemma-2B": {
+            DownloadSource.DEFAULT: "google/gemma-2b",
+            DownloadSource.MODELSCOPE: "AI-ModelScope/gemma-2b",
+        },
+        "Gemma-7B": {
+            DownloadSource.DEFAULT: "google/gemma-7b",
+            DownloadSource.MODELSCOPE: "AI-ModelScope/gemma-2b-it",
+        },
+        "Gemma-2B-Chat": {
+            DownloadSource.DEFAULT: "google/gemma-2b-it",
+            DownloadSource.MODELSCOPE: "AI-ModelScope/gemma-7b",
+        },
+        "Gemma-7B-Chat": {
+            DownloadSource.DEFAULT: "google/gemma-7b-it",
+            DownloadSource.MODELSCOPE: "AI-ModelScope/gemma-7b-it",
+        },
+    },
+    template="gemma",
+)
+
+
 register_model_group(
     models={
         "InternLM-7B": {
@@ -469,6 +492,24 @@ register_model_group(
 )


+register_model_group(
+    models={
+        "OLMo-1B": {
+            DownloadSource.DEFAULT: "allenai/OLMo-1B",
+        },
+        "OLMo-7B": {
+            DownloadSource.DEFAULT: "allenai/OLMo-7B",
+            DownloadSource.MODELSCOPE: "AI-ModelScope/OLMo-7B",
+        },
+        "OLMo-7B-Chat": {
+            DownloadSource.DEFAULT: "allenai/OLMo-7B-Instruct",
+        },
+    },
+    module="att_proj",
+    template="olmo",
+)
+
+
 register_model_group(
     models={
         "OpenChat3.5-7B-Chat": {
@@ -543,7 +584,10 @@ register_model_group(
             DownloadSource.DEFAULT: "Qwen/Qwen-1_8B-Chat",
             DownloadSource.MODELSCOPE: "qwen/Qwen-1_8B-Chat",
         },
-        "Qwen-7B-Chat": {DownloadSource.DEFAULT: "Qwen/Qwen-7B-Chat", DownloadSource.MODELSCOPE: "qwen/Qwen-7B-Chat"},
+        "Qwen-7B-Chat": {
+            DownloadSource.DEFAULT: "Qwen/Qwen-7B-Chat",
+            DownloadSource.MODELSCOPE: "qwen/Qwen-7B-Chat",
+        },
         "Qwen-14B-Chat": {
             DownloadSource.DEFAULT: "Qwen/Qwen-14B-Chat",
             DownloadSource.MODELSCOPE: "qwen/Qwen-14B-Chat",
@@ -645,48 +689,48 @@ register_model_group(
             DownloadSource.MODELSCOPE: "qwen/Qwen1.5-0.5B-Chat-GPTQ-Int8",
         },
         "Qwen1.5-0.5B-int4-Chat": {
-            DownloadSource.DEFAULT: "Qwen/Qwen1.5-0.5B-Chat-GPTQ-Int4",
-            DownloadSource.MODELSCOPE: "qwen/Qwen1.5-0.5B-Chat-GPTQ-Int4",
+            DownloadSource.DEFAULT: "Qwen/Qwen1.5-0.5B-Chat-AWQ",
+            DownloadSource.MODELSCOPE: "qwen/Qwen1.5-0.5B-Chat-AWQ",
         },
         "Qwen1.5-1.8B-int8-Chat": {
             DownloadSource.DEFAULT: "Qwen/Qwen1.5-1.8B-Chat-GPTQ-Int8",
             DownloadSource.MODELSCOPE: "qwen/Qwen1.5-1.8B-Chat-GPTQ-Int8",
         },
         "Qwen1.5-1.8B-int4-Chat": {
-            DownloadSource.DEFAULT: "Qwen/Qwen1.5-1.8B-Chat-GPTQ-Int4",
-            DownloadSource.MODELSCOPE: "qwen/Qwen1.5-1.8B-Chat-GPTQ-Int4",
+            DownloadSource.DEFAULT: "Qwen/Qwen1.5-1.8B-Chat-AWQ",
+            DownloadSource.MODELSCOPE: "qwen/Qwen1.5-1.8B-Chat-AWQ",
         },
         "Qwen1.5-4B-int8-Chat": {
             DownloadSource.DEFAULT: "Qwen/Qwen1.5-4B-Chat-GPTQ-Int8",
             DownloadSource.MODELSCOPE: "qwen/Qwen1.5-4B-Chat-GPTQ-Int8",
         },
         "Qwen1.5-4B-int4-Chat": {
-            DownloadSource.DEFAULT: "Qwen/Qwen1.5-4B-Chat-GPTQ-Int4",
-            DownloadSource.MODELSCOPE: "qwen/Qwen1.5-4B-Chat-GPTQ-Int4",
+            DownloadSource.DEFAULT: "Qwen/Qwen1.5-4B-Chat-AWQ",
+            DownloadSource.MODELSCOPE: "qwen/Qwen1.5-4B-Chat-AWQ",
         },
         "Qwen1.5-7B-int8-Chat": {
             DownloadSource.DEFAULT: "Qwen/Qwen1.5-7B-Chat-GPTQ-Int8",
             DownloadSource.MODELSCOPE: "qwen/Qwen1.5-7B-Chat-GPTQ-Int8",
         },
         "Qwen1.5-7B-int4-Chat": {
-            DownloadSource.DEFAULT: "Qwen/Qwen1.5-7B-Chat-GPTQ-Int4",
-            DownloadSource.MODELSCOPE: "qwen/Qwen1.5-7B-Chat-GPTQ-Int4",
+            DownloadSource.DEFAULT: "Qwen/Qwen1.5-7B-Chat-AWQ",
+            DownloadSource.MODELSCOPE: "qwen/Qwen1.5-7B-Chat-AWQ",
         },
         "Qwen1.5-14B-int8-Chat": {
             DownloadSource.DEFAULT: "Qwen/Qwen1.5-14B-Chat-GPTQ-Int8",
             DownloadSource.MODELSCOPE: "qwen/Qwen1.5-14B-Chat-GPTQ-Int8",
         },
         "Qwen1.5-14B-int4-Chat": {
-            DownloadSource.DEFAULT: "Qwen/Qwen1.5-14B-Chat-GPTQ-Int4",
-            DownloadSource.MODELSCOPE: "qwen/Qwen1.5-14B-Chat-GPTQ-Int4",
+            DownloadSource.DEFAULT: "Qwen/Qwen1.5-14B-Chat-AWQ",
+            DownloadSource.MODELSCOPE: "qwen/Qwen1.5-14B-Chat-AWQ",
         },
         "Qwen1.5-72B-int8-Chat": {
             DownloadSource.DEFAULT: "Qwen/Qwen1.5-72B-Chat-GPTQ-Int8",
             DownloadSource.MODELSCOPE: "qwen/Qwen1.5-72B-Chat-GPTQ-Int8",
         },
         "Qwen1.5-72B-int4-Chat": {
-            DownloadSource.DEFAULT: "Qwen/Qwen1.5-72B-Chat-GPTQ-Int4",
-            DownloadSource.MODELSCOPE: "qwen/Qwen1.5-72B-Chat-GPTQ-Int4",
+            DownloadSource.DEFAULT: "Qwen/Qwen1.5-72B-Chat-AWQ",
+            DownloadSource.MODELSCOPE: "qwen/Qwen1.5-72B-Chat-AWQ",
         },
     },
     template="qwen",
@@ -717,6 +761,21 @@ register_model_group(
 )


+register_model_group(
+    models={
+        "StarCoder2-3B": {
+            DownloadSource.DEFAULT: "bigcode/starcoder2-3b",
+        },
+        "StarCoder2-7B": {
+            DownloadSource.DEFAULT: "bigcode/starcoder2-7b",
+        },
+        "StarCoder2-15B": {
+            DownloadSource.DEFAULT: "bigcode/starcoder2-15b",
+        },
+    }
+)
+
+
 register_model_group(
     models={
         "Vicuna1.5-7B-Chat": {
@@ -807,6 +866,10 @@ register_model_group(
             DownloadSource.DEFAULT: "01-ai/Yi-6B",
             DownloadSource.MODELSCOPE: "01ai/Yi-6B",
         },
+        "Yi-9B": {
+            DownloadSource.DEFAULT: "01-ai/Yi-9B",
+            DownloadSource.MODELSCOPE: "01ai/Yi-9B",
+        },
         "Yi-34B": {
             DownloadSource.DEFAULT: "01-ai/Yi-34B",
             DownloadSource.MODELSCOPE: "01ai/Yi-34B",
@@ -823,10 +886,18 @@ register_model_group(
             DownloadSource.DEFAULT: "01-ai/Yi-6B-Chat-8bits",
             DownloadSource.MODELSCOPE: "01ai/Yi-6B-Chat-8bits",
         },
+        "Yi-6B-int4-Chat": {
+            DownloadSource.DEFAULT: "01-ai/Yi-6B-Chat-4bits",
+            DownloadSource.MODELSCOPE: "01ai/Yi-6B-Chat-4bits",
+        },
         "Yi-34B-int8-Chat": {
             DownloadSource.DEFAULT: "01-ai/Yi-34B-Chat-8bits",
             DownloadSource.MODELSCOPE: "01ai/Yi-34B-Chat-8bits",
         },
+        "Yi-34B-int4-Chat": {
+            DownloadSource.DEFAULT: "01-ai/Yi-34B-Chat-4bits",
+            DownloadSource.MODELSCOPE: "01ai/Yi-34B-Chat-4bits",
+        },
     },
     template="yi",
 )
@@ -864,3 +935,18 @@ register_model_group(
     },
     template="zephyr",
 )
+
+
+register_model_group(
+    models={
+        "Atom-7B": {
+            DownloadSource.DEFAULT: "FlagAlpha/Atom-7B",
+            DownloadSource.MODELSCOPE: "FlagAlpha/Atom-7B",
+        },
+        "Atom-7B-Chat": {
+            DownloadSource.DEFAULT: "FlagAlpha/Atom-7B-Chat",
+            DownloadSource.MODELSCOPE: "FlagAlpha/Atom-7B-Chat",
+        },
+    },
+    template="atom",
+)

@@ -14,6 +14,7 @@ from transformers.utils import (
     is_torch_npu_available,
     is_torch_xpu_available,
 )
+from transformers.utils.versions import require_version

 from .constants import V_HEAD_SAFE_WEIGHTS_NAME, V_HEAD_WEIGHTS_NAME
 from .logging import get_logger
@@ -56,6 +57,17 @@ class AverageMeter:
         self.avg = self.sum / self.count


+def check_dependencies() -> None:
+    if int(os.environ.get("DISABLE_VERSION_CHECK", "0")):
+        logger.warning("Version checking has been disabled, may lead to unexpected behaviors.")
+    else:
+        require_version("transformers>=4.37.2", "To fix: pip install transformers>=4.37.2")
+        require_version("datasets>=2.14.3", "To fix: pip install datasets>=2.14.3")
+        require_version("accelerate>=0.27.2", "To fix: pip install accelerate>=0.27.2")
+        require_version("peft>=0.9.0", "To fix: pip install peft>=0.9.0")
+        require_version("trl>=0.8.1", "To fix: pip install trl>=0.8.1")
+
+
 def count_parameters(model: torch.nn.Module) -> Tuple[int, int]:
     r"""
     Returns the number of trainable parameters and number of all parameters in the model.
@@ -69,7 +81,12 @@ def count_parameters(model: torch.nn.Module) -> Tuple[int, int]:

         # Due to the design of 4bit linear layers from bitsandbytes, multiply the number of parameters by 2
         if param.__class__.__name__ == "Params4bit":
-            num_params = num_params * 2
+            if hasattr(param, "quant_storage") and hasattr(param.quant_storage, "itemsize"):
+                num_bytes = param.quant_storage.itemsize
+            else:
+                num_bytes = 1
+
+            num_params = num_params * 2 * num_bytes

         all_param += num_params
         if param.requires_grad:
@@ -145,6 +162,12 @@ def get_current_device() -> torch.device:


 def get_device_count() -> int:
+    r"""
+    Gets the number of available GPU devices.
+    """
+    if not torch.cuda.is_available():
+        return 0
+
     return torch.cuda.device_count()


@@ -21,6 +21,10 @@ def is_flash_attn2_available():
     return _is_package_available("flash_attn") and _get_package_version("flash_attn").startswith("2")


+def is_galore_available():
+    return _is_package_available("galore_torch")
+
+
 def is_jieba_available():
     return _is_package_available("jieba")

@@ -51,3 +55,7 @@ def is_unsloth_available():

 def is_uvicorn_available():
     return _is_package_available("uvicorn")
+
+
+def is_vllm_available():
+    return _is_package_available("vllm")

@@ -11,12 +11,14 @@ from transformers.models.llama.modeling_llama import (
|
|||||||
repeat_kv,
|
repeat_kv,
|
||||||
)
|
)
|
||||||
from transformers.utils import logging
|
from transformers.utils import logging
|
||||||
|
from transformers.utils.versions import require_version
|
||||||
|
|
||||||
|
|
||||||
logger = logging.get_logger(__name__)
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
# Modified from: https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py
|
# Modified from:
|
||||||
|
# https://github.com/huggingface/transformers/blob/v4.39.1/src/transformers/models/llama/modeling_llama.py
|
||||||
def llama_torch_attn_forward(
|
def llama_torch_attn_forward(
|
||||||
self: "LlamaAttention",
|
self: "LlamaAttention",
|
||||||
hidden_states: torch.Tensor,
|
hidden_states: torch.Tensor,
|
||||||
@@ -24,6 +26,7 @@ def llama_torch_attn_forward(
|
|||||||
position_ids: Optional[torch.LongTensor] = None,
|
position_ids: Optional[torch.LongTensor] = None,
|
||||||
past_key_value: Optional["Cache"] = None,
|
past_key_value: Optional["Cache"] = None,
|
||||||
output_attentions: bool = False,
|
output_attentions: bool = False,
|
||||||
|
cache_position: Optional[torch.LongTensor] = None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||||
bsz, q_len, _ = hidden_states.size()
|
bsz, q_len, _ = hidden_states.size()
|
||||||
@@ -36,15 +39,12 @@ def llama_torch_attn_forward(
|
|||||||
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
||||||
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
||||||
|
|
||||||
kv_seq_len = key_states.shape[-2]
|
past_key_value = getattr(self, "past_key_value", past_key_value)
|
||||||
if past_key_value is not None:
|
cos, sin = self.rotary_emb(value_states, position_ids)
|
||||||
kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
|
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
|
||||||
|
|
||||||
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
|
|
||||||
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
|
|
||||||
|
|
||||||
 if past_key_value is not None:
-cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models
+cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
 key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

 key_states = repeat_kv(key_states, self.num_key_value_groups)
@@ -96,14 +96,16 @@ def llama_torch_attn_forward(
 return attn_output, attn_weights, past_key_value


-# Modified from: https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py
+# Modified from:
+# https://github.com/huggingface/transformers/blob/v4.39.1/src/transformers/models/llama/modeling_llama.py
 def llama_flash_attn_forward(
 self: "LlamaFlashAttention2",
 hidden_states: torch.Tensor,
 attention_mask: Optional[torch.Tensor] = None,
 position_ids: Optional[torch.LongTensor] = None,
-past_key_value: Optional[Tuple[torch.Tensor]] = None,
+past_key_value: Optional["Cache"] = None,
 output_attentions: bool = False,
+cache_position: Optional[torch.LongTensor] = None,
 **kwargs,
 ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
 # LlamaFlashAttention2 attention does not support output_attentions
@@ -120,15 +122,13 @@ def llama_flash_attn_forward(
 key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
 value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

-kv_seq_len = key_states.shape[-2]
-if past_key_value is not None:
-kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
+cos, sin = self.rotary_emb(value_states, position_ids)
+query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

-cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
-query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
+past_key_value = getattr(self, "past_key_value", past_key_value)

 if past_key_value is not None:
-cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models
+cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
 key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

 key_states = repeat_kv(key_states, self.num_key_value_groups)
@@ -193,5 +193,6 @@ def llama_flash_attn_forward(


 def apply_llama_patch() -> None:
+require_version("transformers==4.39.1", "To fix: pip install transformers==4.39.1")
 LlamaAttention.forward = llama_torch_attn_forward
 LlamaFlashAttention2.forward = llama_flash_attn_forward
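Note: the hunk above now pins transformers==4.39.1 because the patched forwards mirror that release's attention signatures (the "Cache" type and the new cache_position argument). A minimal, hypothetical usage sketch, not taken from this diff; since the patch swaps the class-level forward methods, any LLaMA model built from these classes picks it up:

from transformers import AutoModelForCausalLM

apply_llama_patch()  # monkey-patches LlamaAttention / LlamaFlashAttention2 as defined above
model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf")  # model id is illustrative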

@@ -1,7 +1,7 @@
 import json
 import math
 import os
-from typing import List, Optional
+from typing import List

 from transformers.trainer import TRAINER_STATE_NAME

@@ -30,7 +30,7 @@ def smooth(scalars: List[float]) -> List[float]:
 return smoothed


-def plot_loss(save_dictionary: os.PathLike, keys: Optional[List[str]] = ["loss"]) -> None:
+def plot_loss(save_dictionary: os.PathLike, keys: List[str] = ["loss"]) -> None:
 with open(os.path.join(save_dictionary, TRAINER_STATE_NAME), "r", encoding="utf-8") as f:
 data = json.load(f)

@@ -46,11 +46,12 @@ def plot_loss(save_dictionary: os.PathLike, keys: Optional[List[str]] = ["loss"]
 continue

 plt.figure()
-plt.plot(steps, metrics, alpha=0.4, label="original")
-plt.plot(steps, smooth(metrics), label="smoothed")
+plt.plot(steps, metrics, color="#1f77b4", alpha=0.4, label="original")
+plt.plot(steps, smooth(metrics), color="#1f77b4", label="smoothed")
 plt.title("training {} of {}".format(key, save_dictionary))
 plt.xlabel("step")
 plt.ylabel(key)
 plt.legend()
-plt.savefig(os.path.join(save_dictionary, "training_{}.png".format(key)), format="png", dpi=100)
-print("Figure saved:", os.path.join(save_dictionary, "training_{}.png".format(key)))
+figure_path = os.path.join(save_dictionary, "training_{}.png".format(key.replace(os.path.sep, "_")))
+plt.savefig(figure_path, format="png", dpi=100)
+print("Figure saved at:", figure_path)

@@ -16,35 +16,35 @@ class DataArguments:
 default=None,
 metadata={"help": "The name of provided dataset(s) to use. Use commas to separate multiple datasets."},
 )
-dataset_dir: Optional[str] = field(
+dataset_dir: str = field(
 default="data",
 metadata={"help": "Path to the folder containing the datasets."},
 )
-split: Optional[str] = field(
+split: str = field(
 default="train",
 metadata={"help": "Which dataset split to use for training and evaluation."},
 )
-cutoff_len: Optional[int] = field(
+cutoff_len: int = field(
 default=1024,
 metadata={"help": "The cutoff length of the model inputs after tokenization."},
 )
-reserved_label_len: Optional[int] = field(
+reserved_label_len: int = field(
 default=1,
 metadata={"help": "The minimum cutoff length reserved for label after tokenization."},
 )
-train_on_prompt: Optional[bool] = field(
+train_on_prompt: bool = field(
 default=False,
 metadata={"help": "Whether to disable the mask on the prompt or not."},
 )
-streaming: Optional[bool] = field(
+streaming: bool = field(
 default=False,
 metadata={"help": "Enable dataset streaming."},
 )
-buffer_size: Optional[int] = field(
+buffer_size: int = field(
 default=16384,
 metadata={"help": "Size of the buffer to randomly sample examples from in dataset streaming."},
 )
-mix_strategy: Optional[Literal["concat", "interleave_under", "interleave_over"]] = field(
+mix_strategy: Literal["concat", "interleave_under", "interleave_over"] = field(
 default="concat",
 metadata={"help": "Strategy to use in dataset mixing (concat/interleave) (undersampling/oversampling)."},
 )
@@ -52,13 +52,13 @@ class DataArguments:
 default=None,
 metadata={"help": "Probabilities to sample data from datasets. Use commas to separate multiple datasets."},
 )
-overwrite_cache: Optional[bool] = field(
+overwrite_cache: bool = field(
 default=False,
 metadata={"help": "Overwrite the cached training and evaluation sets."},
 )
 preprocessing_num_workers: Optional[int] = field(
 default=None,
-metadata={"help": "The number of processes to use for the preprocessing."},
+metadata={"help": "The number of processes to use for the pre-processing."},
 )
 max_samples: Optional[int] = field(
 default=None,
@@ -68,23 +68,25 @@ class DataArguments:
 default=None,
 metadata={"help": "Number of beams to use for evaluation. This argument will be passed to `model.generate`"},
 )
-ignore_pad_token_for_loss: Optional[bool] = field(
+ignore_pad_token_for_loss: bool = field(
 default=True,
 metadata={
 "help": "Whether or not to ignore the tokens corresponding to padded labels in the loss computation."
 },
 )
-val_size: Optional[float] = field(
-default=0,
+val_size: float = field(
+default=0.0,
 metadata={"help": "Size of the development set, should be an integer or a float in range `[0,1)`."},
 )
-sft_packing: Optional[bool] = field(
-default=False,
-metadata={"help": "Packing the questions and answers in the supervised fine-tuning stage."},
+packing: Optional[bool] = field(
+default=None,
+metadata={
+"help": "Whether or not to pack the sequences in training. Will automatically enable in pre-training."
+},
 )
 cache_path: Optional[str] = field(
 default=None,
-metadata={"help": "Path to save or load the preprocessed datasets."},
+metadata={"help": "Path to save or load the pre-processed datasets."},
 )

 def __post_init__(self):

@@ -14,23 +14,23 @@ class EvaluationArguments:
 task: str = field(
 metadata={"help": "Name of the evaluation task."},
 )
-task_dir: Optional[str] = field(
+task_dir: str = field(
 default="evaluation",
 metadata={"help": "Path to the folder containing the evaluation datasets."},
 )
-batch_size: Optional[int] = field(
+batch_size: int = field(
 default=4,
 metadata={"help": "The batch size per GPU for evaluation."},
 )
-seed: Optional[int] = field(
+seed: int = field(
 default=42,
 metadata={"help": "Random seed to be used with data loaders."},
 )
-lang: Optional[Literal["en", "zh"]] = field(
+lang: Literal["en", "zh"] = field(
 default="en",
 metadata={"help": "Language used at evaluation."},
 )
-n_shot: Optional[int] = field(
+n_shot: int = field(
 default=5,
 metadata={"help": "Number of examplars for few-shot learning."},
 )
@@ -38,7 +38,7 @@ class EvaluationArguments:
 default=None,
 metadata={"help": "Path to save the evaluation results."},
 )
-download_mode: Optional[DownloadMode] = field(
+download_mode: DownloadMode = field(
 default=DownloadMode.REUSE_DATASET_IF_EXISTS,
 metadata={"help": "Download mode used for the evaluation datasets."},
 )

@@ -9,8 +9,8 @@ class FreezeArguments:
 Arguments pertaining to the freeze (partial-parameter) training.
 """

-name_module_trainable: Optional[str] = field(
-default=None,
+name_module_trainable: str = field(
+default="all",
 metadata={
 "help": """Name of trainable modules for partial-parameter (freeze) fine-tuning. \
 Use commas to separate multiple modules. \
@@ -22,14 +22,10 @@ class FreezeArguments:
 Others choices: the same as LLaMA."""
 },
 )
-num_layer_trainable: Optional[int] = field(
-default=3,
+num_layer_trainable: int = field(
+default=2,
 metadata={"help": "The number of trainable layers for partial-parameter (freeze) fine-tuning."},
 )
-use_llama_pro: Optional[bool] = field(
-default=False,
-metadata={"help": "Whether or not to use llama pro for partial-parameter (freeze) fine-tuning."},
-)


 @dataclass
@@ -48,20 +44,20 @@ class LoraArguments:
 default=None,
 metadata={"help": "The scale factor for LoRA fine-tuning (default: lora_rank * 2)."},
 )
-lora_dropout: Optional[float] = field(
+lora_dropout: float = field(
 default=0.0,
 metadata={"help": "Dropout rate for the LoRA fine-tuning."},
 )
-lora_rank: Optional[int] = field(
+lora_rank: int = field(
 default=8,
 metadata={"help": "The intrinsic dimension for LoRA fine-tuning."},
 )
-lora_target: Optional[str] = field(
-default=None,
+lora_target: str = field(
+default="all",
 metadata={
 "help": """Name(s) of target modules to apply LoRA. \
 Use commas to separate multiple modules. \
-Use "all" to specify all the available modules. \
+Use "all" to specify all the linear modules. \
 LLaMA choices: ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"], \
 BLOOM & Falcon & ChatGLM choices: ["query_key_value", "dense", "dense_h_to_4h", "dense_4h_to_h"], \
 Baichuan choices: ["W_pack", "o_proj", "gate_proj", "up_proj", "down_proj"], \
@@ -70,15 +66,23 @@ class LoraArguments:
 Others choices: the same as LLaMA."""
 },
 )
-lora_bf16_mode: Optional[bool] = field(
-default=False,
-metadata={"help": "Whether or not to train lora adapters in bf16 precision."},
+loraplus_lr_ratio: Optional[float] = field(
+default=None,
+metadata={"help": "LoRA plus learning rate ratio (lr_B / lr_A)."},
 )
-use_rslora: Optional[bool] = field(
+loraplus_lr_embedding: float = field(
+default=1e-6,
+metadata={"help": "LoRA plus learning rate for lora embedding layers."},
+)
+use_rslora: bool = field(
 default=False,
 metadata={"help": "Whether or not to use the rank stabilization scaling factor for LoRA layer."},
 )
-create_new_adapter: Optional[bool] = field(
+use_dora: bool = field(
+default=False,
+metadata={"help": "Whether or not to use the weight-decomposed lora method (DoRA)."},
+)
+create_new_adapter: bool = field(
 default=False,
 metadata={"help": "Whether or not to create a new adapter with randomly initialized weight."},
 )
@@ -90,23 +94,23 @@ class RLHFArguments:
 Arguments pertaining to the PPO and DPO training.
 """

-dpo_beta: Optional[float] = field(
+dpo_beta: float = field(
 default=0.1,
 metadata={"help": "The beta parameter for the DPO loss."},
 )
-dpo_loss: Optional[Literal["sigmoid", "hinge", "ipo", "kto"]] = field(
+dpo_loss: Literal["sigmoid", "hinge", "ipo", "kto_pair"] = field(
 default="sigmoid",
 metadata={"help": "The type of DPO loss to use."},
 )
-dpo_ftx: Optional[float] = field(
-default=0,
+dpo_ftx: float = field(
+default=0.0,
 metadata={"help": "The supervised fine-tuning loss coefficient in DPO training."},
 )
-ppo_buffer_size: Optional[int] = field(
+ppo_buffer_size: int = field(
 default=1,
 metadata={"help": "The number of mini-batches to make experience buffer in a PPO optimization step."},
 )
-ppo_epochs: Optional[int] = field(
+ppo_epochs: int = field(
 default=4,
 metadata={"help": "The number of epochs to perform in a PPO optimization step."},
 )
@@ -114,15 +118,15 @@ class RLHFArguments:
 default=None,
 metadata={"help": 'Log with either "wandb" or "tensorboard" in PPO training.'},
 )
-ppo_score_norm: Optional[bool] = field(
+ppo_score_norm: bool = field(
 default=False,
 metadata={"help": "Use score normalization in PPO training."},
 )
-ppo_target: Optional[float] = field(
+ppo_target: float = field(
 default=6.0,
 metadata={"help": "Target KL value for adaptive KL control in PPO training."},
 )
-ppo_whiten_rewards: Optional[bool] = field(
+ppo_whiten_rewards: bool = field(
 default=False,
 metadata={"help": "Whiten the rewards before compute advantages in PPO training."},
 )
@@ -150,31 +154,74 @@ class RLHFArguments:
 default=None,
 metadata={"help": "The number of bits to quantize the reward model."},
 )
-reward_model_type: Optional[Literal["lora", "full", "api"]] = field(
+reward_model_type: Literal["lora", "full", "api"] = field(
 default="lora",
 metadata={"help": "The type of the reward model in PPO training. Lora model only supports lora training."},
 )


 @dataclass
-class FinetuningArguments(FreezeArguments, LoraArguments, RLHFArguments):
+class GaloreArguments:
+r"""
+Arguments pertaining to the GaLore algorithm.
+"""
+
+use_galore: bool = field(
+default=False,
+metadata={"help": "Whether or not to use gradient low-Rank projection."},
+)
+galore_target: str = field(
+default="all",
+metadata={
+"help": """Name(s) of modules to apply GaLore. Use commas to separate multiple modules. \
+Use "all" to specify all the linear modules."""
+},
+)
+galore_rank: int = field(
+default=16,
+metadata={"help": "The rank of GaLore gradients."},
+)
+galore_update_interval: int = field(
+default=200,
+metadata={"help": "Number of steps to update the GaLore projection."},
+)
+galore_scale: float = field(
+default=0.25,
+metadata={"help": "GaLore scaling coefficient."},
+)
+galore_proj_type: Literal["std", "reverse_std", "right", "left", "full"] = field(
+default="std",
+metadata={"help": "Type of GaLore projection."},
+)
+galore_layerwise: bool = field(
+default=False,
+metadata={"help": "Whether or not to enable layer-wise update to further save memory."},
+)
+
+
+@dataclass
+class FinetuningArguments(FreezeArguments, LoraArguments, RLHFArguments, GaloreArguments):
 r"""
 Arguments pertaining to which techniques we are going to fine-tuning with.
 """

-stage: Optional[Literal["pt", "sft", "rm", "ppo", "dpo"]] = field(
+pure_bf16: bool = field(
+default=False,
+metadata={"help": "Whether or not to train model in purely bf16 precision (without AMP)."},
+)
+stage: Literal["pt", "sft", "rm", "ppo", "dpo"] = field(
 default="sft",
 metadata={"help": "Which stage will be performed in training."},
 )
-finetuning_type: Optional[Literal["lora", "freeze", "full"]] = field(
+finetuning_type: Literal["lora", "freeze", "full"] = field(
 default="lora",
 metadata={"help": "Which fine-tuning method to use."},
 )
-disable_version_checking: Optional[bool] = field(
+use_llama_pro: bool = field(
 default=False,
-metadata={"help": "Whether or not to disable version checking."},
+metadata={"help": "Whether or not to make only the parameters in the expanded blocks trainable."},
 )
-plot_loss: Optional[bool] = field(
+plot_loss: bool = field(
 default=False,
 metadata={"help": "Whether or not to save the training loss curves."},
 )
@@ -189,19 +236,23 @@ class FinetuningArguments(FreezeArguments, LoraArguments, RLHFArguments):
 self.lora_alpha = self.lora_alpha or self.lora_rank * 2
 self.lora_target = split_arg(self.lora_target)
 self.additional_target = split_arg(self.additional_target)
+self.galore_target = split_arg(self.galore_target)

 assert self.finetuning_type in ["lora", "freeze", "full"], "Invalid fine-tuning method."
 assert self.ref_model_quantization_bit in [None, 8, 4], "We only accept 4-bit or 8-bit quantization."
 assert self.reward_model_quantization_bit in [None, 8, 4], "We only accept 4-bit or 8-bit quantization."

 if self.stage == "ppo" and self.reward_model is None:
-raise ValueError("Reward model is necessary for PPO training.")
+raise ValueError("`reward_model` is necessary for PPO training.")

 if self.stage == "ppo" and self.reward_model_type == "lora" and self.finetuning_type != "lora":
-raise ValueError("Freeze/Full PPO training needs `reward_model_type=full`.")
+raise ValueError("`reward_model_type` cannot be lora for Freeze/Full PPO training.")

-if self.use_llama_pro and self.finetuning_type != "freeze":
-raise ValueError("`use_llama_pro` is only valid for the Freeze method.")
+if self.use_llama_pro and self.finetuning_type == "full":
+raise ValueError("`use_llama_pro` is only valid for the Freeze or LoRA method.")
+
+if self.use_galore and self.finetuning_type == "lora":
+raise ValueError("Cannot use LoRA with GaLore together.")

 def save_to_json(self, json_path: str):
 r"""Saves the content of this instance in JSON format inside `json_path`."""
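The GaLore, DoRA and LoRA+ options introduced above are ordinary dataclass fields, so they flow through the same HfArgumentParser machinery touched by the parser changes later in this diff. A minimal sketch, with illustrative values not taken from the diff; FinetuningArguments refers to the dataclass assembled in the hunks above:

from transformers import HfArgumentParser

parser = HfArgumentParser(FinetuningArguments)  # hypothetical standalone use of the dataclass above
(finetuning_args,) = parser.parse_dict({
    "stage": "sft",
    "finetuning_type": "full",
    "use_galore": True,        # GaloreArguments field
    "galore_target": "all",    # comma-separated module names, or "all" for every linear module
    "galore_rank": 16,
    "pure_bf16": True,         # GaLore is typically paired with pure bf16 rather than AMP here
})
# __post_init__ splits the comma-separated target string into a list:
assert finetuning_args.galore_target == ["all"]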

@@ -1,5 +1,5 @@
 from dataclasses import asdict, dataclass, field
-from typing import Any, Dict, Optional
+from typing import Any, Dict


 @dataclass
@@ -8,41 +8,41 @@ class GeneratingArguments:
 Arguments pertaining to specify the decoding parameters.
 """

-do_sample: Optional[bool] = field(
+do_sample: bool = field(
 default=True,
 metadata={"help": "Whether or not to use sampling, use greedy decoding otherwise."},
 )
-temperature: Optional[float] = field(
+temperature: float = field(
 default=0.95,
 metadata={"help": "The value used to modulate the next token probabilities."},
 )
-top_p: Optional[float] = field(
+top_p: float = field(
 default=0.7,
 metadata={
 "help": "The smallest set of most probable tokens with probabilities that add up to top_p or higher are kept."
 },
 )
-top_k: Optional[int] = field(
+top_k: int = field(
 default=50,
 metadata={"help": "The number of highest probability vocabulary tokens to keep for top-k filtering."},
 )
-num_beams: Optional[int] = field(
+num_beams: int = field(
 default=1,
 metadata={"help": "Number of beams for beam search. 1 means no beam search."},
 )
-max_length: Optional[int] = field(
+max_length: int = field(
 default=512,
 metadata={"help": "The maximum length the generated tokens can have. It can be overridden by max_new_tokens."},
 )
-max_new_tokens: Optional[int] = field(
+max_new_tokens: int = field(
 default=512,
 metadata={"help": "The maximum numbers of tokens to generate, ignoring the number of tokens in the prompt."},
 )
-repetition_penalty: Optional[float] = field(
+repetition_penalty: float = field(
 default=1.0,
 metadata={"help": "The parameter for repetition penalty. 1.0 means no penalty."},
 )
-length_penalty: Optional[float] = field(
+length_penalty: float = field(
 default=1.0,
 metadata={"help": "Exponential penalty to the length that is used with beam-based generation."},
 )

@@ -5,7 +5,7 @@ from typing import Any, Dict, Literal, Optional
 @dataclass
 class ModelArguments:
 r"""
-Arguments pertaining to which model/config/tokenizer we are going to fine-tune.
+Arguments pertaining to which model/config/tokenizer we are going to fine-tune or infer.
 """

 model_name_or_path: str = field(
@@ -21,31 +21,35 @@ class ModelArguments:
 default=None,
 metadata={"help": "Where to store the pre-trained models downloaded from huggingface.co or modelscope.cn."},
 )
-use_fast_tokenizer: Optional[bool] = field(
+use_fast_tokenizer: bool = field(
 default=False,
 metadata={"help": "Whether or not to use one of the fast tokenizer (backed by the tokenizers library)."},
 )
-resize_vocab: Optional[bool] = field(
+resize_vocab: bool = field(
 default=False,
 metadata={"help": "Whether or not to resize the tokenizer vocab and the embedding layers."},
 )
-split_special_tokens: Optional[bool] = field(
+split_special_tokens: bool = field(
 default=False,
 metadata={"help": "Whether or not the special tokens should be split during the tokenization process."},
 )
-model_revision: Optional[str] = field(
+model_revision: str = field(
 default="main",
 metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."},
 )
+low_cpu_mem_usage: bool = field(
+default=True,
+metadata={"help": "Whether or not to use memory-efficient model loading."},
+)
 quantization_bit: Optional[int] = field(
 default=None,
-metadata={"help": "The number of bits to quantize the model."},
+metadata={"help": "The number of bits to quantize the model using bitsandbytes."},
 )
-quantization_type: Optional[Literal["fp4", "nf4"]] = field(
+quantization_type: Literal["fp4", "nf4"] = field(
 default="nf4",
 metadata={"help": "Quantization data type to use in int4 training."},
 )
-double_quantization: Optional[bool] = field(
+double_quantization: bool = field(
 default=True,
 metadata={"help": "Whether or not to use double quantization in int4 training."},
 )
@@ -53,30 +57,54 @@ class ModelArguments:
 default=None,
 metadata={"help": "Which scaling strategy should be adopted for the RoPE embeddings."},
 )
-flash_attn: Optional[bool] = field(
+flash_attn: bool = field(
 default=False,
 metadata={"help": "Enable FlashAttention-2 for faster training."},
 )
-shift_attn: Optional[bool] = field(
+shift_attn: bool = field(
 default=False,
 metadata={"help": "Enable shift short attention (S^2-Attn) proposed by LongLoRA."},
 )
-use_unsloth: Optional[bool] = field(
+use_unsloth: bool = field(
 default=False,
 metadata={"help": "Whether or not to use unsloth's optimization for the LoRA training."},
 )
-disable_gradient_checkpointing: Optional[bool] = field(
+disable_gradient_checkpointing: bool = field(
 default=False,
 metadata={"help": "Whether or not to disable gradient checkpointing."},
 )
-upcast_layernorm: Optional[bool] = field(
+upcast_layernorm: bool = field(
 default=False,
 metadata={"help": "Whether or not to upcast the layernorm weights in fp32."},
 )
-upcast_lmhead_output: Optional[bool] = field(
+upcast_lmhead_output: bool = field(
 default=False,
 metadata={"help": "Whether or not to upcast the output of lm_head in fp32."},
 )
+infer_backend: Literal["huggingface", "vllm"] = field(
+default="huggingface",
+metadata={"help": "Backend engine used at inference."},
+)
+vllm_maxlen: int = field(
+default=2048,
+metadata={"help": "Maximum input length of the vLLM engine."},
+)
+vllm_gpu_util: float = field(
+default=0.9,
+metadata={"help": "The fraction of GPU memory in (0,1) to be used for the vLLM engine."},
+)
+vllm_enforce_eager: bool = field(
+default=False,
+metadata={"help": "Whether or not to disable CUDA graph in the vLLM engine."},
+)
+offload_folder: str = field(
+default="offload",
+metadata={"help": "Path to offload model weights."},
+)
+use_cache: bool = field(
+default=True,
+metadata={"help": "Whether or not to use KV cache in generation."},
+)
 hf_hub_token: Optional[str] = field(
 default=None,
 metadata={"help": "Auth token to log in with Hugging Face Hub."},
@@ -89,7 +117,7 @@ class ModelArguments:
 default=None,
 metadata={"help": "Path to the directory to save the exported model."},
 )
-export_size: Optional[int] = field(
+export_size: int = field(
 default=1,
 metadata={"help": "The file shard size (in GB) of the exported model."},
 )
@@ -101,15 +129,15 @@ class ModelArguments:
 default=None,
 metadata={"help": "Path to the dataset or dataset name to use in quantizing the exported model."},
 )
-export_quantization_nsamples: Optional[int] = field(
+export_quantization_nsamples: int = field(
 default=128,
 metadata={"help": "The number of samples used for quantization."},
 )
-export_quantization_maxlen: Optional[int] = field(
+export_quantization_maxlen: int = field(
 default=1024,
 metadata={"help": "The maximum length of the model inputs used for quantization."},
 )
-export_legacy_format: Optional[bool] = field(
+export_legacy_format: bool = field(
 default=False,
 metadata={"help": "Whether or not to save the `.bin` files instead of `.safetensors`."},
 )
@@ -117,13 +145,14 @@ class ModelArguments:
 default=None,
 metadata={"help": "The name of the repository if push the model to the Hugging Face hub."},
 )
-print_param_status: Optional[bool] = field(
+print_param_status: bool = field(
 default=False,
 metadata={"help": "For debugging purposes, print the status of the parameters in the model."},
 )

 def __post_init__(self):
 self.compute_dtype = None
+self.device_map = None
 self.model_max_length = None

 if self.split_special_tokens and self.use_fast_tokenizer:

@@ -3,14 +3,15 @@ import os
 import sys
 from typing import Any, Dict, Optional, Tuple

-import datasets
 import torch
 import transformers
 from transformers import HfArgumentParser, Seq2SeqTrainingArguments
 from transformers.trainer_utils import get_last_checkpoint
+from transformers.utils import is_torch_bf16_gpu_available
 from transformers.utils.versions import require_version

 from ..extras.logging import get_logger
+from ..extras.misc import check_dependencies
 from ..extras.packages import is_unsloth_available
 from .data_args import DataArguments
 from .evaluation_args import EvaluationArguments
@@ -22,6 +23,9 @@ from .model_args import ModelArguments
 logger = get_logger(__name__)


+check_dependencies()
+
+
 _TRAIN_ARGS = [ModelArguments, DataArguments, Seq2SeqTrainingArguments, FinetuningArguments, GeneratingArguments]
 _TRAIN_CLS = Tuple[ModelArguments, DataArguments, Seq2SeqTrainingArguments, FinetuningArguments, GeneratingArguments]
 _INFER_ARGS = [ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments]
@@ -30,17 +34,6 @@ _EVAL_ARGS = [ModelArguments, DataArguments, EvaluationArguments, FinetuningArgu
 _EVAL_CLS = Tuple[ModelArguments, DataArguments, EvaluationArguments, FinetuningArguments]


-def _check_dependencies(disabled: bool) -> None:
-if disabled:
-logger.warning("Version checking has been disabled, may lead to unexpected behaviors.")
-else:
-require_version("transformers>=4.37.2", "To fix: pip install transformers>=4.37.2")
-require_version("datasets>=2.14.3", "To fix: pip install datasets>=2.14.3")
-require_version("accelerate>=0.21.0", "To fix: pip install accelerate>=0.21.0")
-require_version("peft>=0.8.2", "To fix: pip install peft>=0.8.2")
-require_version("trl>=0.7.6", "To fix: pip install trl>=0.7.6")
-
-
 def _parse_args(parser: "HfArgumentParser", args: Optional[Dict[str, Any]] = None) -> Tuple[Any]:
 if args is not None:
 return parser.parse_dict(args)
@@ -62,13 +55,15 @@ def _parse_args(parser: "HfArgumentParser", args: Optional[Dict[str, Any]] = Non


 def _set_transformers_logging(log_level: Optional[int] = logging.INFO) -> None:
-datasets.utils.logging.set_verbosity(log_level)
 transformers.utils.logging.set_verbosity(log_level)
 transformers.utils.logging.enable_default_handler()
 transformers.utils.logging.enable_explicit_format()


 def _verify_model_args(model_args: "ModelArguments", finetuning_args: "FinetuningArguments") -> None:
+if model_args.adapter_name_or_path is not None and finetuning_args.finetuning_type != "lora":
+raise ValueError("Adapter is only valid for the LoRA method.")
+
 if model_args.quantization_bit is not None:
 if finetuning_args.finetuning_type != "lora":
 raise ValueError("Quantization is only compatible with the LoRA method.")
@@ -79,9 +74,6 @@ def _verify_model_args(model_args: "ModelArguments", finetuning_args: "Finetunin
 if model_args.adapter_name_or_path is not None and len(model_args.adapter_name_or_path) != 1:
 raise ValueError("Quantized model only accepts a single adapter. Merge them first.")

-if model_args.adapter_name_or_path is not None and finetuning_args.finetuning_type != "lora":
-raise ValueError("Adapter is only valid for the LoRA method.")
-

 def _parse_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
 parser = HfArgumentParser(_TRAIN_ARGS)
@@ -133,21 +125,37 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
 if training_args.do_train and training_args.predict_with_generate:
 raise ValueError("`predict_with_generate` cannot be set as True while training.")

+if training_args.do_train and model_args.use_unsloth and not is_unsloth_available():
+raise ValueError("Unsloth was not installed: https://github.com/unslothai/unsloth")
+
+if finetuning_args.use_dora:
+if model_args.quantization_bit is not None:
+require_version("peft>=0.10.0", "To fix: pip install peft>=0.10.0")
+
+if model_args.use_unsloth:
+raise ValueError("Unsloth does not support DoRA.")
+
+if finetuning_args.pure_bf16:
+if not is_torch_bf16_gpu_available():
+raise ValueError("This device does not support `pure_bf16`.")
+
+if training_args.fp16 or training_args.bf16:
+raise ValueError("Turn off mixed precision training when using `pure_bf16`.")
+
 if (
-training_args.do_train
-and finetuning_args.finetuning_type == "freeze"
-and finetuning_args.name_module_trainable is None
+finetuning_args.use_galore
+and finetuning_args.galore_layerwise
+and training_args.parallel_mode.value == "distributed"
 ):
-raise ValueError("Please specify `name_module_trainable` in Freeze training.")
+raise ValueError("Distributed training does not support layer-wise GaLore.")

-if training_args.do_train and finetuning_args.finetuning_type == "lora" and finetuning_args.lora_target is None:
-raise ValueError("Please specify `lora_target` in LoRA training.")
+if finetuning_args.use_galore and training_args.deepspeed is not None:
+raise ValueError("GaLore is incompatible with DeepSpeed.")

-if training_args.do_train and model_args.use_unsloth and not is_unsloth_available:
-raise ValueError("Install Unsloth: https://github.com/unslothai/unsloth")
+if model_args.infer_backend == "vllm":
+raise ValueError("vLLM backend is only available for API, CLI and Web.")

 _verify_model_args(model_args, finetuning_args)
-_check_dependencies(disabled=finetuning_args.disable_version_checking)

 if (
 training_args.do_train
@@ -163,6 +171,9 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
 if training_args.do_train and (not training_args.fp16) and (not training_args.bf16):
 logger.warning("We recommend enable mixed precision training.")

+if training_args.do_train and finetuning_args.use_galore and not finetuning_args.pure_bf16:
+logger.warning("Using GaLore with mixed precision training may significantly increases GPU memory usage.")
+
 if (not training_args.do_train) and model_args.quantization_bit is not None:
 logger.warning("Evaluating model in 4/8-bit mode may cause lower scores.")

@@ -171,14 +182,12 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:

 # Post-process training arguments
 if (
-training_args.local_rank != -1
+training_args.parallel_mode.value == "distributed"
 and training_args.ddp_find_unused_parameters is None
 and finetuning_args.finetuning_type == "lora"
 ):
 logger.warning("`ddp_find_unused_parameters` needs to be set as False for LoRA in DDP training.")
-training_args_dict = training_args.to_dict()
-training_args_dict.update(dict(ddp_find_unused_parameters=False))
-training_args = Seq2SeqTrainingArguments(**training_args_dict)
+training_args.ddp_find_unused_parameters = False

 if finetuning_args.stage in ["rm", "ppo"] and finetuning_args.finetuning_type in ["full", "freeze"]:
 can_resume_from_checkpoint = False
@@ -200,9 +209,7 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
 raise ValueError("Output directory already exists and is not empty. Please set `overwrite_output_dir`.")

 if last_checkpoint is not None:
-training_args_dict = training_args.to_dict()
-training_args_dict.update(dict(resume_from_checkpoint=last_checkpoint))
-training_args = Seq2SeqTrainingArguments(**training_args_dict)
+training_args.resume_from_checkpoint = last_checkpoint
 logger.info(
 "Resuming training from {}. Change `output_dir` or use `overwrite_output_dir` to avoid.".format(
 training_args.resume_from_checkpoint
@@ -221,22 +228,24 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
 )

 # Post-process model arguments
-model_args.compute_dtype = (
-torch.bfloat16 if training_args.bf16 else (torch.float16 if training_args.fp16 else None)
-)
+if training_args.bf16 or finetuning_args.pure_bf16:
+model_args.compute_dtype = torch.bfloat16
+elif training_args.fp16:
+model_args.compute_dtype = torch.float16

 model_args.model_max_length = data_args.cutoff_len
+data_args.packing = data_args.packing if data_args.packing is not None else finetuning_args.stage == "pt"

 # Log on each process the small summary:
 logger.info(
-"Process rank: {}, device: {}, n_gpu: {}\n distributed training: {}, compute dtype: {}".format(
+"Process rank: {}, device: {}, n_gpu: {}, distributed training: {}, compute dtype: {}".format(
 training_args.local_rank,
 training_args.device,
 training_args.n_gpu,
-bool(training_args.local_rank != -1),
+training_args.parallel_mode.value == "distributed",
 str(model_args.compute_dtype),
 )
 )
-logger.info(f"Training/evaluation parameters {training_args}")

 transformers.set_seed(training_args.seed)

@@ -247,12 +256,27 @@ def get_infer_args(args: Optional[Dict[str, Any]] = None) -> _INFER_CLS:
 model_args, data_args, finetuning_args, generating_args = _parse_infer_args(args)

 _set_transformers_logging()
-_verify_model_args(model_args, finetuning_args)
-_check_dependencies(disabled=finetuning_args.disable_version_checking)

 if data_args.template is None:
 raise ValueError("Please specify which `template` to use.")

+if model_args.infer_backend == "vllm":
+if finetuning_args.stage != "sft":
+raise ValueError("vLLM engine only supports auto-regressive models.")
+
+if model_args.adapter_name_or_path is not None:
+raise ValueError("vLLM engine does not support LoRA adapters. Merge them first.")
+
+if model_args.quantization_bit is not None:
+raise ValueError("vLLM engine does not support quantization.")
+
+if model_args.rope_scaling is not None:
+raise ValueError("vLLM engine does not support RoPE scaling.")
+
+_verify_model_args(model_args, finetuning_args)
+
+model_args.device_map = "auto"
+
 return model_args, data_args, finetuning_args, generating_args


@@ -260,12 +284,17 @@ def get_eval_args(args: Optional[Dict[str, Any]] = None) -> _EVAL_CLS:
 model_args, data_args, eval_args, finetuning_args = _parse_eval_args(args)

 _set_transformers_logging()
-_verify_model_args(model_args, finetuning_args)
-_check_dependencies(disabled=finetuning_args.disable_version_checking)

 if data_args.template is None:
 raise ValueError("Please specify which `template` to use.")

+if model_args.infer_backend == "vllm":
+raise ValueError("vLLM backend is only available for API, CLI and Web.")
+
+_verify_model_args(model_args, finetuning_args)
+
+model_args.device_map = "auto"
+
 transformers.set_seed(eval_args.seed)

 return model_args, data_args, eval_args, finetuning_args
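With the reworked get_infer_args, selecting the vLLM engine is driven entirely by the new ModelArguments fields, and the guards above reject adapters, quantization and RoPE scaling that the engine cannot serve. A hypothetical call, with an illustrative path and template name, using get_infer_args as defined above:

args = {
    "model_name_or_path": "path/to/merged-model",  # adapters must be merged beforehand for vLLM
    "template": "default",                         # a template is required, see the check above
    "infer_backend": "vllm",
    "vllm_maxlen": 2048,
    "vllm_gpu_util": 0.9,
}
model_args, data_args, finetuning_args, generating_args = get_infer_args(args)
# get_infer_args also sets model_args.device_map = "auto" before returning.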

@@ -1,5 +1,11 @@
-from .loader import load_model_and_tokenizer
-from .utils import dispatch_model, load_valuehead_params
+from .loader import load_model, load_model_and_tokenizer, load_tokenizer
+from .utils import find_all_linear_modules, load_valuehead_params


-__all__ = ["load_model_and_tokenizer", "dispatch_model", "load_valuehead_params"]
+__all__ = [
+"load_model",
+"load_model_and_tokenizer",
+"load_tokenizer",
+"load_valuehead_params",
+"find_all_linear_modules",
+]
|

@@ -5,7 +5,7 @@ from peft import LoraConfig, LoraModel, PeftModel, TaskType, get_peft_model
 from transformers.integrations import is_deepspeed_zero3_enabled

 from ..extras.logging import get_logger
-from .utils import find_all_linear_modules
+from .utils import QuantizationMethod, find_all_linear_modules, find_expanded_modules


 if TYPE_CHECKING:
@@ -34,7 +34,8 @@ def init_adapter(

     if finetuning_args.finetuning_type == "full" and is_trainable:
         logger.info("Fine-tuning method: Full")
-        model = model.float()
+        if not finetuning_args.pure_bf16:
+            model = model.float()

     if finetuning_args.finetuning_type == "freeze" and is_trainable:
         logger.info("Fine-tuning method: Freeze")
@@ -78,12 +79,15 @@ def init_adapter(

         for name, param in model.named_parameters():
             if any(trainable_layer in name for trainable_layer in trainable_layers):
-                param.data = param.data.to(torch.float32)
+                if not finetuning_args.pure_bf16:
+                    param.data = param.data.to(torch.float32)
             else:
                 param.requires_grad_(False)

+        logger.info("Set trainable layers: {}".format(",".join(map(str, trainable_layer_ids))))
+
     if finetuning_args.finetuning_type == "lora":
-        logger.info("Fine-tuning method: LoRA")
+        logger.info("Fine-tuning method: {}".format("DoRA" if finetuning_args.use_dora else "LoRA"))
         adapter_to_resume = None

         if model_args.adapter_name_or_path is not None:
@@ -103,14 +107,18 @@ def init_adapter(
                 adapter_to_merge = model_args.adapter_name_or_path

             for adapter in adapter_to_merge:
-                model: "LoraModel" = PeftModel.from_pretrained(model, adapter)
+                model: "LoraModel" = PeftModel.from_pretrained(
+                    model, adapter, offload_folder=model_args.offload_folder
+                )
                 model = model.merge_and_unload()

             if len(adapter_to_merge) > 0:
                 logger.info("Merged {} adapter(s).".format(len(adapter_to_merge)))

             if adapter_to_resume is not None:  # resume lora training
-                model = PeftModel.from_pretrained(model, adapter_to_resume, is_trainable=is_trainable)
+                model = PeftModel.from_pretrained(
+                    model, adapter_to_resume, is_trainable=is_trainable, offload_folder=model_args.offload_folder
+                )

         if is_trainable and adapter_to_resume is None:  # create new lora weights while training
             if len(finetuning_args.lora_target) == 1 and finetuning_args.lora_target[0] == "all":
@@ -118,6 +126,13 @@ def init_adapter(
             else:
                 target_modules = finetuning_args.lora_target

+            if finetuning_args.use_llama_pro:
+                target_modules = find_expanded_modules(model, target_modules, finetuning_args.num_layer_trainable)
+
+            if finetuning_args.use_dora and getattr(model, "quantization_method", None) is not None:
+                if getattr(model, "quantization_method", None) != QuantizationMethod.BITS_AND_BYTES:
+                    raise ValueError("DoRA is not compatible with PTQ-quantized models.")
+
             peft_kwargs = {
                 "r": finetuning_args.lora_rank,
                 "target_modules": target_modules,
@@ -136,12 +151,14 @@ def init_adapter(
                 task_type=TaskType.CAUSAL_LM,
                 inference_mode=False,
                 modules_to_save=finetuning_args.additional_target,
+                use_dora=finetuning_args.use_dora,
                 **peft_kwargs,
             )
             model = get_peft_model(model, lora_config)

-        for param in filter(lambda p: p.requires_grad, model.parameters()):
-            param.data = param.data.to(torch.bfloat16 if finetuning_args.lora_bf16_mode else torch.float32)
+        if not finetuning_args.pure_bf16:
+            for param in filter(lambda p: p.requires_grad, model.parameters()):
+                param.data = param.data.to(torch.float32)

     if model_args.adapter_name_or_path is not None:
         logger.info("Loaded adapter(s): {}".format(",".join(model_args.adapter_name_or_path)))
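The new `use_dora` flag is forwarded straight into peft's LoraConfig. A minimal sketch of the resulting adapter configuration, assuming peft>=0.9.0 (which introduced DoRA) and illustrative rank/alpha/target values rather than the project's defaults:

    from peft import LoraConfig, TaskType

    lora_config = LoraConfig(
        task_type=TaskType.CAUSAL_LM,
        r=8,                                  # stands in for finetuning_args.lora_rank
        lora_alpha=16,
        target_modules=["q_proj", "v_proj"],  # or find_all_linear_modules(model) when lora_target is "all"
        use_dora=True,                        # weight-decomposed LoRA (DoRA) instead of plain LoRA
    )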

@@ -1,7 +1,6 @@
-from typing import TYPE_CHECKING, Optional, Tuple
+from typing import TYPE_CHECKING, Any, Dict, Tuple

 from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
-from transformers.integrations import is_deepspeed_zero3_enabled
 from trl import AutoModelForCausalLMWithValueHead

 from ..extras.logging import get_logger
@@ -20,38 +19,48 @@ if TYPE_CHECKING:
 logger = get_logger(__name__)


-def load_model_and_tokenizer(
-    model_args: "ModelArguments",
-    finetuning_args: "FinetuningArguments",
-    is_trainable: Optional[bool] = False,
-    add_valuehead: Optional[bool] = False,
-) -> Tuple["PreTrainedModel", "PreTrainedTokenizer"]:
-    r"""
-    Loads pretrained model and tokenizer.
-
-    Support both training and inference.
-    """
-    try_download_model_from_ms(model_args)
-
-    config_kwargs = {
+def _get_init_kwargs(model_args: "ModelArguments") -> Dict[str, Any]:
+    return {
         "trust_remote_code": True,
         "cache_dir": model_args.cache_dir,
         "revision": model_args.model_revision,
         "token": model_args.hf_hub_token,
     }


+def load_tokenizer(model_args: "ModelArguments") -> "PreTrainedTokenizer":
+    r"""
+    Loads pretrained tokenizer. Must before load_model.
+
+    Note: including inplace operation of model_args.
+    """
+    try_download_model_from_ms(model_args)
+    init_kwargs = _get_init_kwargs(model_args)
+
     tokenizer = AutoTokenizer.from_pretrained(
         model_args.model_name_or_path,
         use_fast=model_args.use_fast_tokenizer,
         split_special_tokens=model_args.split_special_tokens,
         padding_side="right",
-        **config_kwargs,
+        **init_kwargs,
     )
     patch_tokenizer(tokenizer)
+    return tokenizer
+

-    config = AutoConfig.from_pretrained(model_args.model_name_or_path, **config_kwargs)
-    patch_config(config, tokenizer, model_args, config_kwargs, is_trainable)
+def load_model(
+    tokenizer: "PreTrainedTokenizer",
+    model_args: "ModelArguments",
+    finetuning_args: "FinetuningArguments",
+    is_trainable: bool = False,
+    add_valuehead: bool = False,
+) -> "PreTrainedModel":
+    r"""
+    Loads pretrained model. Must after load_tokenizer.
+    """
+    init_kwargs = _get_init_kwargs(model_args)
+    config = AutoConfig.from_pretrained(model_args.model_name_or_path, **init_kwargs)
+    patch_config(config, tokenizer, model_args, init_kwargs, is_trainable)

     model = None
     if is_trainable and model_args.use_unsloth:
@@ -77,13 +86,7 @@ def load_model_and_tokenizer(
             logger.warning("Unsloth does not support loading adapters.")

     if model is None:
-        model = AutoModelForCausalLM.from_pretrained(
-            model_args.model_name_or_path,
-            config=config,
-            torch_dtype=model_args.compute_dtype,
-            low_cpu_mem_usage=(not is_deepspeed_zero3_enabled()),
-            **config_kwargs,
-        )
+        model = AutoModelForCausalLM.from_pretrained(model_args.model_name_or_path, config=config, **init_kwargs)

     patch_model(model, tokenizer, model_args, is_trainable)
     register_autoclass(config, model, tokenizer)
@@ -106,20 +109,21 @@ def load_model_and_tokenizer(

     if not is_trainable:
         model.requires_grad_(False)
-        model = model.to(model_args.compute_dtype) if not getattr(model, "quantization_method", None) else model
         model.eval()
+        for param in model.parameters():
+            if param.device.type == "cuda":
+                param.data = param.data.to(model_args.compute_dtype)
     else:
         model.train()

     trainable_params, all_param = count_parameters(model)
-    logger.info(
-        "trainable params: {:d} || all params: {:d} || trainable%: {:.4f}".format(
+    if is_trainable:
+        param_stats = "trainable params: {:d} || all params: {:d} || trainable%: {:.4f}".format(
             trainable_params, all_param, 100 * trainable_params / all_param
         )
-    )
-    if not is_trainable:
-        logger.info("This IS expected that the trainable params is 0 if you are using model for inference only.")
+    else:
+        param_stats = "all params: {:d}".format(all_param)
+    logger.info(param_stats)

     if model_args.print_param_status:
         for name, param in model.named_parameters():
@@ -129,4 +133,18 @@ def load_model_and_tokenizer(
                 )
             )
+
+    return model
+
+
+def load_model_and_tokenizer(
+    model_args: "ModelArguments",
+    finetuning_args: "FinetuningArguments",
+    is_trainable: bool = False,
+    add_valuehead: bool = False,
+) -> Tuple["PreTrainedModel", "PreTrainedTokenizer"]:
+    r"""
+    Loads pretrained model and tokenizer.
+    """
+    tokenizer = load_tokenizer(model_args)
+    model = load_model(tokenizer, model_args, finetuning_args, is_trainable, add_valuehead)
     return model, tokenizer
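A short sketch of how the split API is meant to be called, mirroring the new load_model_and_tokenizer wrapper above; model_args and finetuning_args are assumed to come from llmtuner's argument parser:

    from llmtuner.model import load_model, load_tokenizer

    tokenizer = load_tokenizer(model_args)  # must run first; may modify model_args in place
    model = load_model(tokenizer, model_args, finetuning_args, is_trainable=False)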

@@ -3,7 +3,7 @@ import os
 import random
 from contextlib import nullcontext
 from types import MethodType
-from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
+from typing import TYPE_CHECKING, Any, Dict, List, Tuple

 import torch
 from datasets import load_dataset
@@ -18,6 +18,7 @@ from ..extras.misc import get_current_device, infer_optim_dtype
 from ..extras.packages import is_flash_attn2_available
 from ..extras.patches.llama_patch import apply_llama_patch
 from ..extras.patches.mixtral_patch import patch_mixtral_replace_moe_impl
+from .utils import QuantizationMethod


 if TYPE_CHECKING:
@@ -102,19 +103,27 @@ def _get_quantization_dataset(tokenizer: "PreTrainedTokenizer", model_args: "Mod
     return samples


-def _configure_attn_implementation(model_args: "ModelArguments", config_kwargs: Dict[str, Any]) -> None:
+def _configure_attn_implementation(
+    config: "PretrainedConfig", model_args: "ModelArguments", init_kwargs: Dict[str, Any]
+) -> None:
     if model_args.flash_attn:
-        if is_flash_attn2_available():
-            config_kwargs["attn_implementation"] = "flash_attention_2"
-            logger.info("Using FlashAttention-2 for faster training and inference.")
-        else:
+        if not is_flash_attn2_available():
             logger.warning("FlashAttention2 is not installed.")
-            config_kwargs["attn_implementation"] = None
+            return
+
+        logger.info("Using FlashAttention-2 for faster training and inference.")
+        if getattr(config, "model_type", None) == "internlm2":  # special case for custom models
+            setattr(config, "attn_implementation", "flash_attention_2")
+        else:
+            init_kwargs["attn_implementation"] = "flash_attention_2"
     else:
-        config_kwargs["attn_implementation"] = "eager"
+        init_kwargs["attn_implementation"] = "eager"


 def _configure_rope(config: "PretrainedConfig", model_args: "ModelArguments", is_trainable: bool) -> None:
+    if model_args.rope_scaling is None:
+        return
+
     if not hasattr(config, "rope_scaling"):
         logger.warning("Current model does not support RoPE scaling.")
         return
@@ -141,7 +150,10 @@ def _configure_rope(config: "PretrainedConfig", model_args: "ModelArguments", is
     )


-def _configure_longlora(config: "PretrainedConfig") -> None:
+def _configure_longlora(config: "PretrainedConfig", model_args: "ModelArguments", is_trainable: bool) -> None:
+    if not is_trainable or not model_args.shift_attn:
+        return
+
     if getattr(config, "model_type", None) in SUPPORTED_CLASS_FOR_S2ATTN:
         setattr(config, "group_size_ratio", 0.25)
         apply_llama_patch()
@@ -154,20 +166,29 @@ def _configure_quantization(
     config: "PretrainedConfig",
     tokenizer: "PreTrainedTokenizer",
     model_args: "ModelArguments",
-    config_kwargs: Dict[str, Any],
+    init_kwargs: Dict[str, Any],
 ) -> None:
     r"""
-    Priority: GPTQ-quantized (training) > AutoGPTQ (export) > Bitsandbytes (training)
+    Priority: PTQ-quantized (training) > AutoGPTQ (export) > Bitsandbytes (training)
     """
-    if getattr(config, "quantization_config", None):  # gptq
+    if getattr(config, "quantization_config", None):  # ptq
         if is_deepspeed_zero3_enabled():
             raise ValueError("DeepSpeed ZeRO-3 is incompatible with quantization.")

-        config_kwargs["device_map"] = {"": get_current_device()}
+        init_kwargs["device_map"] = {"": get_current_device()}
         quantization_config: Dict[str, Any] = getattr(config, "quantization_config", None)
-        if quantization_config.get("quant_method", None) == "gptq" and quantization_config.get("bits", -1) == 4:
+        quant_method = quantization_config.get("quant_method", "")
+
+        if quant_method == QuantizationMethod.GPTQ:
             quantization_config["use_exllama"] = False  # disable exllama
-            logger.info("Loading {}-bit GPTQ-quantized model.".format(quantization_config.get("bits", -1)))
+
+        if quant_method == QuantizationMethod.AQLM:
+            require_version("transformers>=4.39.0", "To fix: pip install transformers>=4.39.0")
+            require_version("aqlm>=1.1.0", "To fix: pip install aqlm[gpu]>=1.1.0")
+            quantization_config["bits"] = 2
+
+        quant_bits = quantization_config.get("bits", "?")
+        logger.info("Loading {}-bit {}-quantized model.".format(quant_bits, quant_method.upper()))

 elif model_args.export_quantization_bit is not None:  # auto-gptq
         require_version("optimum>=1.16.0", "To fix: pip install optimum>=1.16.0")
@@ -177,38 +198,41 @@ def _configure_quantization(
         if getattr(config, "model_type", None) == "chatglm":
             raise ValueError("ChatGLM model is not supported.")

-        config_kwargs["quantization_config"] = GPTQConfig(
+        init_kwargs["quantization_config"] = GPTQConfig(
             bits=model_args.export_quantization_bit,
             tokenizer=tokenizer,
             dataset=_get_quantization_dataset(tokenizer, model_args),
         )
-        config_kwargs["device_map"] = "auto"
-        config_kwargs["max_memory"] = get_max_memory()
+        init_kwargs["device_map"] = "auto"
+        init_kwargs["max_memory"] = get_max_memory()
         logger.info("Quantizing model to {} bit.".format(model_args.export_quantization_bit))

     elif model_args.quantization_bit is not None:  # bnb
         if is_deepspeed_zero3_enabled():
-            raise ValueError("DeepSpeed ZeRO-3 is incompatible with quantization.")
+            require_version("transformers>=4.39.0", "To fix: pip install transformers>=4.39.0")
+            require_version("accelerate>=0.28.0", "To fix: pip install accelerate>=0.28.0")
+            require_version("bitsandbytes>=0.43.0", "To fix: pip install bitsandbytes>=0.43.0")

         if model_args.quantization_bit == 8:
             require_version("bitsandbytes>=0.37.0", "To fix: pip install bitsandbytes>=0.37.0")
-            config_kwargs["quantization_config"] = BitsAndBytesConfig(load_in_8bit=True)
+            init_kwargs["quantization_config"] = BitsAndBytesConfig(load_in_8bit=True)

         elif model_args.quantization_bit == 4:
             require_version("bitsandbytes>=0.39.0", "To fix: pip install bitsandbytes>=0.39.0")
-            config_kwargs["quantization_config"] = BitsAndBytesConfig(
+            init_kwargs["quantization_config"] = BitsAndBytesConfig(
                 load_in_4bit=True,
                 bnb_4bit_compute_dtype=model_args.compute_dtype,
                 bnb_4bit_use_double_quant=model_args.double_quantization,
                 bnb_4bit_quant_type=model_args.quantization_type,
+                bnb_4bit_quant_storage=model_args.compute_dtype,  # crucial for fsdp qlora
             )

-        config_kwargs["device_map"] = {"": get_current_device()}
+        init_kwargs["device_map"] = {"": get_current_device()}
         logger.info("Quantizing model to {} bit.".format(model_args.quantization_bit))


 def _prepare_model_for_training(
-    model: "PreTrainedModel", model_args: "ModelArguments", output_layer_name: Optional[str] = "lm_head"
+    model: "PreTrainedModel", model_args: "ModelArguments", output_layer_name: str = "lm_head"
 ) -> None:
     r"""
     Includes:
@@ -218,10 +242,10 @@ def _prepare_model_for_training(
     Inspired by: https://github.com/huggingface/peft/blob/v0.7.1/src/peft/utils/other.py#L72
     """
     if model_args.upcast_layernorm:
+        logger.info("Upcasting layernorm weights in float32.")
         for name, param in model.named_parameters():
             if param.ndim == 1 and any(ln_name in name for ln_name in LAYERNORM_NAMES):
                 param.data = param.data.to(torch.float32)
-        logger.info("Upcasting layernorm weights in float32.")

     if not model_args.disable_gradient_checkpointing:
         if not getattr(model, "supports_gradient_checkpointing", False):
@@ -231,7 +255,7 @@ def _prepare_model_for_training(
         # According to: https://github.com/huggingface/transformers/issues/28339
         model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": True})
         model.enable_input_require_grads()
-        model.config.use_cache = False  # turn off when gradient checkpointing is enabled
+        setattr(model.config, "use_cache", False)  # turn off when gradient checkpointing is enabled
         logger.info("Gradient checkpointing enabled.")

     if hasattr(model, output_layer_name) and model_args.upcast_lmhead_output:
@@ -239,6 +263,7 @@ def _prepare_model_for_training(
         def fp32_forward_post_hook(module: torch.nn.Module, args: Tuple[torch.Tensor], output: torch.Tensor):
             return output.to(torch.float32)

+        logger.info("Upcasting lm_head outputs in float32.")
         output_layer = getattr(model, output_layer_name)
         if isinstance(output_layer, torch.nn.Linear) and output_layer.weight.dtype != torch.float32:
             output_layer.register_forward_hook(fp32_forward_post_hook)
@@ -253,25 +278,35 @@ def patch_config(
     config: "PretrainedConfig",
     tokenizer: "PreTrainedTokenizer",
     model_args: "ModelArguments",
-    config_kwargs: Dict[str, Any],
+    init_kwargs: Dict[str, Any],
     is_trainable: bool,
 ) -> None:
     if model_args.compute_dtype is None:  # priority: bf16 > fp16 > fp32
         model_args.compute_dtype = infer_optim_dtype(model_dtype=getattr(config, "torch_dtype", None))

     if getattr(config, "model_type", None) == "qwen":
+        setattr(config, "use_flash_attn", model_args.flash_attn)
         for dtype_name, dtype in [("fp16", torch.float16), ("bf16", torch.bfloat16), ("fp32", torch.float32)]:
             setattr(config, dtype_name, model_args.compute_dtype == dtype)

-    _configure_attn_implementation(model_args, config_kwargs)
-
-    if model_args.rope_scaling is not None:
-        _configure_rope(config, model_args, is_trainable)
-
-    if is_trainable and model_args.shift_attn:
-        _configure_longlora(config)
-
-    _configure_quantization(config, tokenizer, model_args, config_kwargs)
+    _configure_attn_implementation(config, model_args, init_kwargs)
+    _configure_rope(config, model_args, is_trainable)
+    _configure_longlora(config, model_args, is_trainable)
+    _configure_quantization(config, tokenizer, model_args, init_kwargs)
+
+    if model_args.use_cache and not is_trainable:
+        setattr(config, "use_cache", True)
+        logger.info("Using KV cache for faster generation.")
+
+    init_kwargs["torch_dtype"] = model_args.compute_dtype
+    if not is_deepspeed_zero3_enabled():
+        init_kwargs["low_cpu_mem_usage"] = model_args.low_cpu_mem_usage
+        if init_kwargs["low_cpu_mem_usage"]:
+            if "device_map" not in init_kwargs:  # quant models cannot use auto device map
+                init_kwargs["device_map"] = model_args.device_map or {"": get_current_device()}
+
+            if init_kwargs["device_map"] == "auto":
+                init_kwargs["offload_folder"] = model_args.offload_folder


 def patch_model(
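For reference, a standalone sketch of the 4-bit configuration assembled in _configure_quantization above, with an explicit dtype standing in for model_args.compute_dtype and illustrative values for the other fields; per the version checks in the diff, bnb_4bit_quant_storage assumes transformers>=4.39.0 and bitsandbytes>=0.43.0:

    import torch
    from transformers import BitsAndBytesConfig

    compute_dtype = torch.bfloat16  # stand-in for model_args.compute_dtype
    quantization_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_compute_dtype=compute_dtype,
        bnb_4bit_use_double_quant=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_quant_storage=compute_dtype,  # crucial for FSDP + QLoRA, as noted in the diff
    )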

@@ -1,4 +1,4 @@
-import inspect
+from enum import Enum, unique
 from typing import TYPE_CHECKING, Dict, List

 import torch
@@ -7,7 +7,6 @@ from transformers.utils import cached_file

 from ..extras.constants import V_HEAD_SAFE_WEIGHTS_NAME, V_HEAD_WEIGHTS_NAME
 from ..extras.logging import get_logger
-from ..extras.misc import get_current_device


 if TYPE_CHECKING:
@@ -19,34 +18,16 @@ if TYPE_CHECKING:
 logger = get_logger(__name__)


-def dispatch_model(model: "PreTrainedModel") -> "PreTrainedModel":
+@unique
+class QuantizationMethod(str, Enum):
     r"""
-    Dispatches a pre-trained model to GPUs with balanced memory when the GPU is available.
-    Borrowed from: https://github.com/huggingface/transformers/blob/v4.36.2/src/transformers/modeling_utils.py#L3570
+    Borrowed from `transformers.utils.quantization_config.QuantizationMethod`.
     """
-    if getattr(model, "quantization_method", None):  # already set on current device
-        return model
-
-    if (
-        torch.cuda.device_count() > 1
-        and isinstance(model, PreTrainedModel)
-        and model._no_split_modules is not None
-        and model.config.model_type != "chatglm"
-    ):
-        from accelerate import dispatch_model
-        from accelerate.utils import get_balanced_memory, infer_auto_device_map
-
-        kwargs = {"dtype": model.dtype, "no_split_module_classes": model._get_no_split_modules("auto")}
-        max_memory = get_balanced_memory(model, **kwargs)
-        # Make sure tied weights are tied before creating the device map.
-        model.tie_weights()
-        device_map = infer_auto_device_map(model, max_memory=max_memory, **kwargs)
-        device_map_kwargs = {"device_map": device_map, "offload_dir": "offload"}
-        if "skip_keys" in inspect.signature(dispatch_model).parameters:
-            device_map_kwargs["skip_keys"] = model._skip_keys_device_placement
-        return dispatch_model(model, **device_map_kwargs)
-    else:
-        return model.to(device=get_current_device())
+    BITS_AND_BYTES = "bitsandbytes"
+    GPTQ = "gptq"
+    AWQ = "awq"
+    AQLM = "aqlm"


 def find_all_linear_modules(model: "PreTrainedModel") -> List[str]:
@@ -56,7 +37,7 @@ def find_all_linear_modules(model: "PreTrainedModel") -> List[str]:
     quantization_method = getattr(model, "quantization_method", None)
     if quantization_method is None:
         linear_cls = torch.nn.Linear
-    elif quantization_method == "bitsandbytes":
+    elif quantization_method == QuantizationMethod.BITS_AND_BYTES:
         import bitsandbytes as bnb

         linear_cls = bnb.nn.Linear4bit if getattr(model, "is_loaded_in_4bit", False) else bnb.nn.Linear8bitLt
@@ -76,6 +57,33 @@ def find_all_linear_modules(model: "PreTrainedModel") -> List[str]:
     return list(module_names)


+def find_expanded_modules(model: "PreTrainedModel", target_modules: List[str], num_layer_trainable: int) -> List[str]:
+    r"""
+    Finds the modules in the expanded blocks to apply lora.
+    """
+    num_layers = getattr(model.config, "num_hidden_layers", None)
+    if not num_layers:
+        raise ValueError("Model was not supported.")
+
+    if num_layers % num_layer_trainable != 0:
+        raise ValueError(
+            "`num_layers` {} should be divisible by `num_layer_trainable` {}.".format(num_layers, num_layer_trainable)
+        )
+
+    stride = num_layers // num_layer_trainable
+    trainable_layer_ids = range(stride - 1, num_layers + stride - 1, stride)
+    trainable_layers = [".{:d}.".format(idx) for idx in trainable_layer_ids]
+    module_names = []
+    for name, _ in model.named_modules():
+        if any(target_module in name for target_module in target_modules) and any(
+            trainable_layer in name for trainable_layer in trainable_layers
+        ):
+            module_names.append(name)
+
+    logger.info("Apply lora to layers: {}".format(",".join(map(str, trainable_layer_ids))))
+    return module_names
+
+
 def load_valuehead_params(path_or_repo_id: str, model_args: "ModelArguments") -> Dict[str, torch.Tensor]:
     r"""
     Loads value head parameters from Hugging Face Hub or local disk.
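A worked example of the stride arithmetic in find_expanded_modules, using illustrative values (a 40-block expanded model with 8 trainable blocks) rather than any particular checkpoint:

    num_layers = 40            # e.g. a block-expanded model with 40 hidden layers
    num_layer_trainable = 8
    stride = num_layers // num_layer_trainable                                 # 5
    trainable_layer_ids = range(stride - 1, num_layers + stride - 1, stride)
    print(list(trainable_layer_ids))                                           # [4, 9, 14, 19, 24, 29, 34, 39]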

@@ -8,21 +8,25 @@ from trl import DPOTrainer
 from trl.trainer.utils import disable_dropout_in_model

 from ...extras.constants import IGNORE_INDEX
+from ..utils import create_custom_optimzer


 if TYPE_CHECKING:
     from transformers import PreTrainedModel

+    from ...hparams import FinetuningArguments
+

 class CustomDPOTrainer(DPOTrainer):
     def __init__(
         self,
         beta: float,
-        loss_type: Literal["sigmoid", "hinge", "ipo", "kto"],
+        loss_type: Literal["sigmoid", "hinge", "ipo", "kto_pair"],
         ftx_gamma: float,
         model: Union["PreTrainedModel", torch.nn.Module],
+        finetuning_args: "FinetuningArguments",
         ref_model: Optional[Union["PreTrainedModel", torch.nn.Module]] = None,
-        disable_dropout: Optional[bool] = True,
+        disable_dropout: bool = True,
         **kwargs,
     ):
         if disable_dropout:
@@ -30,6 +34,8 @@ class CustomDPOTrainer(DPOTrainer):
             if ref_model is not None:
                 disable_dropout_in_model(ref_model)

+        self.finetuning_args = finetuning_args
+        self.reference_free = False
         self.use_dpo_data_collator = True  # hack to avoid warning
         self.generate_during_eval = False  # disable at evaluation
         self.label_pad_token_id = IGNORE_INDEX
@@ -60,6 +66,13 @@ class CustomDPOTrainer(DPOTrainer):
         else:
             self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True)

+    def create_optimizer_and_scheduler(self, num_training_steps: int) -> None:
+        self.optimizer = create_custom_optimzer(self.model, self.args, self.finetuning_args, num_training_steps)
+        if self.optimizer is None:
+            self.create_optimizer()
+
+        self.create_scheduler(num_training_steps=num_training_steps, optimizer=self.optimizer)
+
     def sft_loss(self, chosen_logits: torch.FloatTensor, chosen_labels: torch.LongTensor) -> torch.Tensor:
         r"""
         Computes supervised cross-entropy loss of given labels under the given logits.
@@ -94,7 +107,7 @@ class CustomDPOTrainer(DPOTrainer):
         self,
         model: "PreTrainedModel",
         batch: Dict[str, torch.Tensor],
-        train_eval: Optional[Literal["train", "eval"]] = "train",
+        train_eval: Literal["train", "eval"] = "train",
     ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
         r"""
         Computes the DPO loss and other metrics for the given batch of inputs for train or test.

@@ -2,20 +2,18 @@

 from typing import TYPE_CHECKING, List, Optional

-from transformers import Seq2SeqTrainingArguments
-
 from ...data import get_dataset, split_dataset
 from ...extras.constants import IGNORE_INDEX
 from ...extras.ploting import plot_loss
 from ...hparams import ModelArguments
-from ...model import load_model_and_tokenizer
-from ...train.dpo.collator import DPODataCollatorWithPadding
-from ...train.dpo.trainer import CustomDPOTrainer
-from ...train.utils import create_modelcard_and_push, create_ref_model
+from ...model import load_model, load_tokenizer
+from ..utils import create_modelcard_and_push, create_ref_model
+from .collator import DPODataCollatorWithPadding
+from .trainer import CustomDPOTrainer


 if TYPE_CHECKING:
-    from transformers import TrainerCallback
+    from transformers import Seq2SeqTrainingArguments, TrainerCallback

     from ...hparams import DataArguments, FinetuningArguments

@@ -27,8 +25,9 @@ def run_dpo(
     finetuning_args: "FinetuningArguments",
     callbacks: Optional[List["TrainerCallback"]] = None,
 ):
-    model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args, training_args.do_train)
+    tokenizer = load_tokenizer(model_args)
     dataset = get_dataset(tokenizer, model_args, data_args, training_args, stage="rm")
+    model = load_model(tokenizer, model_args, finetuning_args, training_args.do_train)
     data_collator = DPODataCollatorWithPadding(
         tokenizer=tokenizer,
         pad_to_multiple_of=8,
@@ -42,15 +41,14 @@ def run_dpo(
         ref_model = create_ref_model(model_args, finetuning_args)

     # Update arguments
-    training_args_dict = training_args.to_dict()
-    training_args_dict.update(dict(remove_unused_columns=False))  # important for pairwise dataset
-    training_args = Seq2SeqTrainingArguments(**training_args_dict)
+    training_args.remove_unused_columns = False  # important for pairwise dataset

     # Initialize our Trainer
     trainer = CustomDPOTrainer(
         beta=finetuning_args.dpo_beta,
         loss_type=finetuning_args.dpo_loss,
         ftx_gamma=finetuning_args.dpo_ftx,
+        finetuning_args=finetuning_args,
         model=model,
         ref_model=ref_model,
         args=training_args,

@@ -14,7 +14,7 @@ from trl.core import PPODecorators, logprobs_from_logits

 from ...extras.callbacks import FixValueHeadModelCallback, LogCallback
 from ...extras.logging import get_logger
-from ...extras.misc import AverageMeter, count_parameters, get_logits_processor
+from ...extras.misc import AverageMeter, count_parameters, get_current_device, get_logits_processor
 from .utils import dump_layernorm, get_rewards_from_server, replace_model, restore_layernorm


@@ -49,6 +49,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
         self.model_args = model_args
         self.finetuning_args = finetuning_args
         self.reward_model = reward_model
+        self.current_device = get_current_device()  # patch for deepspeed training

         self.generation_config = GenerationConfig(
             pad_token_id=self.tokenizer.pad_token_id,
@@ -291,7 +292,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
         queries: torch.Tensor,
         responses: torch.Tensor,
         model_inputs: dict,
-        return_logits: Optional[bool] = False,
+        return_logits: bool = False,
         response_masks: Optional[torch.Tensor] = None,
     ):
         r"""

@@ -12,9 +12,9 @@ from ...data import get_dataset
 from ...extras.callbacks import FixValueHeadModelCallback
 from ...extras.misc import fix_valuehead_checkpoint
 from ...extras.ploting import plot_loss
-from ...model import load_model_and_tokenizer
-from ...train.ppo.trainer import CustomPPOTrainer
-from ...train.utils import create_ref_model, create_reward_model
+from ...model import load_model, load_tokenizer
+from ..utils import create_custom_optimzer, create_ref_model, create_reward_model
+from .trainer import CustomPPOTrainer


 if TYPE_CHECKING:
@@ -31,10 +31,9 @@ def run_ppo(
     generating_args: "GeneratingArguments",
     callbacks: Optional[List["TrainerCallback"]] = None,
 ):
-    model, tokenizer = load_model_and_tokenizer(
-        model_args, finetuning_args, training_args.do_train, add_valuehead=True
-    )
+    tokenizer = load_tokenizer(model_args)
     dataset = get_dataset(tokenizer, model_args, data_args, training_args, stage="ppo")
+    model = load_model(tokenizer, model_args, finetuning_args, training_args.do_train, add_valuehead=True)

     tokenizer.padding_side = "left"  # use left-padding in generation while using right-padding in training
     data_collator = DataCollatorWithPadding(tokenizer=tokenizer)
@@ -61,16 +60,20 @@ def run_ppo(
         use_score_norm=finetuning_args.ppo_score_norm,
         whiten_rewards=finetuning_args.ppo_whiten_rewards,
         accelerator_kwargs={"step_scheduler_with_optimizer": False},
+        project_kwargs={"logging_dir": training_args.logging_dir},
     )

     # Create optimizer and scheduler
-    optimizer = AdamW(filter(lambda p: p.requires_grad, model.parameters()), lr=training_args.learning_rate)
     if training_args.max_steps > 0:
         num_training_steps = training_args.max_steps
     else:
         total_train_batch_size = backward_batch_size * finetuning_args.ppo_buffer_size * training_args.world_size
         num_training_steps = training_args.num_train_epochs * math.ceil(len(dataset) / total_train_batch_size)

+    optimizer = create_custom_optimzer(model, training_args, finetuning_args, num_training_steps)
+    if optimizer is None:
+        optimizer = AdamW(filter(lambda p: p.requires_grad, model.parameters()), lr=training_args.learning_rate)
+
     lr_scheduler = get_scheduler(
         training_args.lr_scheduler_type,
         optimizer=optimizer,
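To make the step count above concrete, a quick worked example with purely illustrative numbers (none of them are project defaults):

    import math

    backward_batch_size = 8    # assumed per-process backward batch size
    ppo_buffer_size = 1
    world_size = 2
    num_train_epochs = 3
    dataset_size = 10_000

    total_train_batch_size = backward_batch_size * ppo_buffer_size * world_size       # 16
    num_training_steps = num_train_epochs * math.ceil(dataset_size / total_train_batch_size)
    print(num_training_steps)  # 3 * 625 = 1875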

src/llmtuner/train/pt/trainer.py (new file, 30 lines)
@@ -0,0 +1,30 @@
+from typing import TYPE_CHECKING
+
+from transformers import Trainer
+
+from ...extras.logging import get_logger
+from ..utils import create_custom_optimzer
+
+
+if TYPE_CHECKING:
+    from ...hparams import FinetuningArguments
+
+
+logger = get_logger(__name__)
+
+
+class CustomTrainer(Trainer):
+    r"""
+    Inherits Trainer for custom optimizer.
+    """
+
+    def __init__(self, finetuning_args: "FinetuningArguments", **kwargs) -> None:
+        super().__init__(**kwargs)
+        self.finetuning_args = finetuning_args
+
+    def create_optimizer_and_scheduler(self, num_training_steps: int) -> None:
+        self.optimizer = create_custom_optimzer(self.model, self.args, self.finetuning_args, num_training_steps)
+        if self.optimizer is None:
+            self.create_optimizer()
+
+        self.create_scheduler(num_training_steps=num_training_steps, optimizer=self.optimizer)

@@ -3,12 +3,13 @@
 import math
 from typing import TYPE_CHECKING, List, Optional

-from transformers import DataCollatorForLanguageModeling, Trainer
+from transformers import DataCollatorForLanguageModeling

 from ...data import get_dataset, split_dataset
 from ...extras.ploting import plot_loss
-from ...model import load_model_and_tokenizer
-from ...train.utils import create_modelcard_and_push
+from ...model import load_model, load_tokenizer
+from ..utils import create_modelcard_and_push
+from .trainer import CustomTrainer


 if TYPE_CHECKING:
@@ -24,14 +25,16 @@ def run_pt(
     finetuning_args: "FinetuningArguments",
     callbacks: Optional[List["TrainerCallback"]] = None,
 ):
-    model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args, training_args.do_train)
+    tokenizer = load_tokenizer(model_args)
     dataset = get_dataset(tokenizer, model_args, data_args, training_args, stage="pt")
+    model = load_model(tokenizer, model_args, finetuning_args, training_args.do_train)
     data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

     # Initialize our Trainer
-    trainer = Trainer(
+    trainer = CustomTrainer(
         model=model,
         args=training_args,
+        finetuning_args=finetuning_args,
         tokenizer=tokenizer,
         data_collator=data_collator,
         callbacks=callbacks,

@@ -1,32 +1,43 @@
 import json
 import os
-from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
+from typing import TYPE_CHECKING, Dict, List, Tuple, Union

 import torch
 from transformers import Trainer

 from ...extras.logging import get_logger
+from ..utils import create_custom_optimzer


 if TYPE_CHECKING:
     from transformers.modeling_utils import PreTrainedModel
     from transformers.trainer import PredictionOutput

+    from ...hparams import FinetuningArguments
+

 logger = get_logger(__name__)


 class PairwiseTrainer(Trainer):
     r"""
-    Inherits PeftTrainer to compute pairwise loss.
+    Inherits Trainer to compute pairwise loss.
     """

-    def __init__(self, *args, **kwargs):
-        super().__init__(*args, **kwargs)
+    def __init__(self, finetuning_args: "FinetuningArguments", **kwargs) -> None:
+        super().__init__(**kwargs)
+        self.finetuning_args = finetuning_args
         self.can_return_loss = True  # override property to return eval_loss

+    def create_optimizer_and_scheduler(self, num_training_steps: int) -> None:
+        self.optimizer = create_custom_optimzer(self.model, self.args, self.finetuning_args, num_training_steps)
+        if self.optimizer is None:
+            self.create_optimizer()
+
+        self.create_scheduler(num_training_steps=num_training_steps, optimizer=self.optimizer)
+
     def compute_loss(
-        self, model: "PreTrainedModel", inputs: Dict[str, torch.Tensor], return_outputs: Optional[bool] = False
+        self, model: "PreTrainedModel", inputs: Dict[str, torch.Tensor], return_outputs: bool = False
     ) -> Union[torch.Tensor, Tuple[torch.Tensor, List[torch.Tensor]]]:
         r"""
         Computes pairwise loss. The first n examples are chosen and the last n examples are rejected.
@@ -34,7 +45,7 @@ class PairwiseTrainer(Trainer):
         Subclass and override to inject custom behavior.

         Note that the first element will be removed from the output tuple.
-        See: https://github.com/huggingface/transformers/blob/v4.30.2/src/transformers/trainer.py#L3509
+        See: https://github.com/huggingface/transformers/blob/v4.39.1/src/transformers/trainer.py#L3777
         """
         # Compute rewards
         _, _, values = model(**inputs, output_hidden_states=True, return_dict=True)

@@ -2,21 +2,19 @@

 from typing import TYPE_CHECKING, List, Optional

-from transformers import Seq2SeqTrainingArguments
-
 from ...data import get_dataset, split_dataset
 from ...extras.callbacks import FixValueHeadModelCallback
 from ...extras.misc import fix_valuehead_checkpoint
 from ...extras.ploting import plot_loss
-from ...model import load_model_and_tokenizer
-from ...train.rm.collator import PairwiseDataCollatorWithPadding
-from ...train.rm.metric import compute_accuracy
-from ...train.rm.trainer import PairwiseTrainer
-from ...train.utils import create_modelcard_and_push
+from ...model import load_model, load_tokenizer
+from ..utils import create_modelcard_and_push
+from .collator import PairwiseDataCollatorWithPadding
+from .metric import compute_accuracy
+from .trainer import PairwiseTrainer


 if TYPE_CHECKING:
-    from transformers import TrainerCallback
+    from transformers import Seq2SeqTrainingArguments, TrainerCallback

     from ...hparams import DataArguments, FinetuningArguments, ModelArguments

@@ -28,21 +26,19 @@ def run_rm(
     finetuning_args: "FinetuningArguments",
     callbacks: Optional[List["TrainerCallback"]] = None,
 ):
-    model, tokenizer = load_model_and_tokenizer(
-        model_args, finetuning_args, training_args.do_train, add_valuehead=True
-    )
+    tokenizer = load_tokenizer(model_args)
     dataset = get_dataset(tokenizer, model_args, data_args, training_args, stage="rm")
+    model = load_model(tokenizer, model_args, finetuning_args, training_args.do_train, add_valuehead=True)
     data_collator = PairwiseDataCollatorWithPadding(tokenizer, pad_to_multiple_of=8)

     # Update arguments
-    training_args_dict = training_args.to_dict()
-    training_args_dict.update(dict(remove_unused_columns=False))  # important for pairwise dataset
-    training_args = Seq2SeqTrainingArguments(**training_args_dict)
+    training_args.remove_unused_columns = False  # important for pairwise dataset

     # Initialize our Trainer
     trainer = PairwiseTrainer(
         model=model,
         args=training_args,
+        finetuning_args=finetuning_args,
         tokenizer=tokenizer,
         data_collator=data_collator,
         callbacks=callbacks + [FixValueHeadModelCallback()],
Some files were not shown because too many files have changed in this diff.