Compare commits

116 commits

Commits included in this comparison (SHA1 only):

48d2e6d7fe, 041c83ea03, 0e621c2dc9, 544e7a491b, a2c881fa08, c53c7af168, a2d93e5269, b392e6cfb9, 13aa2d389a, 1e7962dfc4, 1c9556c84c, ca3ca7a5b5, 0500befdb4, f618feab51, 4b06aa134f, 9cde56d760, d0ea203694, c5eb3fba62, a8bc32553c, 88f3358320, a85bdcf2f6, caf56b313e, 75603c45fc, 89f86cc970, c09a0e4f08, 7bac6c9460, 0c7d0bf172, a274900188, 67deefe527, 823f618cba, bc16c9a54a, a3f30038a0, e237f618c2, 688adad665, 0158812afb, e52e0d9b07, eb2aa2c073, debfd46749, 5ccf8fcd6b, 7bd1991513, 456e4ca569, 6bf0fe4913, 596b6828cb, b403f8d8a8, 590b6c2143, 5537ef1e7d, 5f83860aa1, 62b6a7971a, 1d16e87c5f, 1955a8ea5a, a41fa6e730, b98a64448a, 1ce82f391a, 4d473894fd, 5788b7c7d0, 04515f6b55, 96f8ccf3d5, 2c3ef480a6, fa6873122c, 34bc0c22b1, e5484b2729, f67f781fed, b564b97b7e, 0dd68d1e06, 73f40f1ca4, ea53bebac4, 00418012bd, 5f3d8c514b, cb39a3f1c4, 4d78fe6ece, a3e3ea9846, feba34e82d, e134013e04, 5589d0296a, de0ebab464, f2e7122a96, 996cc5d900, a2ae5bd867, 5fa52e87cb, bcd76d2c7a, 36fcbedc11, 1dad01cc53, 5fb21f6e54, 08dfac8352, 956751e419, fe2ae04c91, 5b8712d061, dc7ff90c1e, 1ace676170, 8947a87b95, 786a2f1103, 36ac14a566, 7a048fc91d, 3f3756b113, b36c4b99cc, 9856a2276e, b6dc3ed3ad, 75be329994, 1fe1ca1c8b, 882a6a1d51, 712ab4ae7a, 18ad259fb3, fe4d93c6db, c6ba588e37, 3fda60fca0, 96531a0ef8, 7abc3065fb, 013ded4bac, 010c3c7348, bf075c075c, 41b34e5f60, 5a889398e7, 054cae86d8, cd1cb8b83c, a34779c027, d19cb77d74
7  .github/PULL_REQUEST_TEMPLATE.md  (vendored, new file)
@@ -0,0 +1,7 @@
# What does this PR do?

Fixes # (issue)

## Before submitting

- [ ] Did you read the [contributor guideline](/CONTRIBUTING.md)?
29  .github/workflows/tests.yml  (vendored, new file)
@@ -0,0 +1,29 @@
name: tests

on:
  push:
    branches: [ "main" ]
  pull_request:
    branches: [ "main" ]

jobs:
  check_code_quality:

    runs-on: ubuntu-latest

    steps:
    - uses: actions/checkout@v4

    - name: Set up Python
      uses: actions/setup-python@v5
      with:
        python-version: "3.8"

    - name: Install dependencies
      run: |
        python -m pip install --upgrade pip
        python -m pip install ruff

    - name: Check quality
      run: |
        make style && make quality
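Contributors can reproduce the same checks locally before opening a PR. The snippet below is a minimal sketch, assuming neither linter is installed yet; it installs `ruff` (used by the workflow) and `black` (still invoked by the Makefile in this diff) and then runs the same Make targets as CI.

```bash
# Install the linters used by the Makefile targets, then run the same
# checks that the GitHub Actions workflow runs on every push and PR.
python -m pip install --upgrade pip
python -m pip install ruff black
make style && make quality
```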
21  CONTRIBUTING.md  (new file)
@@ -0,0 +1,21 @@
# Contributing to LLaMA Factory

Everyone is welcome to contribute, and we value everybody's contribution. Code contributions are not the only way to help the community. Answering questions, helping others, and improving the documentation are also immensely valuable.

It also helps us if you spread the word! Reference the library in blog posts about the awesome projects it made possible, shout out on Twitter every time it has helped you, or simply ⭐️ the repository to say thank you.

However you choose to contribute, please be mindful and respect our [code of conduct](CODE_OF_CONDUCT.md).

**This guide was heavily inspired by the [transformers guide to contributing](https://github.com/huggingface/transformers/blob/main/CONTRIBUTING.md).**

## Ways to contribute

There are several ways you can contribute to LLaMA Factory:

* Fix outstanding issues with the existing code.
* Submit issues related to bugs or desired new features.
* Contribute to the examples or to the documentation.

### Style guide

LLaMA Factory follows the [Google Python Style Guide](https://google.github.io/styleguide/pyguide.html); check it for details.
4  Makefile
@@ -3,9 +3,9 @@
check_dirs := src tests

quality:
	black --check $(check_dirs)
	ruff $(check_dirs)
	ruff format --check $(check_dirs)

style:
	black $(check_dirs)
	ruff $(check_dirs) --fix
	ruff format $(check_dirs)
118  README.md
@@ -5,6 +5,7 @@

[](https://github.com/hiyouga/LLaMA-Factory/commits/main)
[](https://pypi.org/project/llmtuner/)
[](https://pypi.org/project/llmtuner/)
[](#projects-using-llama-factory)
[](https://github.com/hiyouga/LLaMA-Factory/pulls)
[](https://discord.gg/rKfvV9r9FK)
[](https://huggingface.co/spaces/hiyouga/LLaMA-Board)
@@ -16,9 +17,7 @@

## LLaMA Board: A One-stop Web UI for Getting Started with LLaMA Factory

Preview LLaMA Board at **[🤗 Spaces](https://huggingface.co/spaces/hiyouga/LLaMA-Board)** or **[ModelScope](https://modelscope.cn/studios/hiyouga/LLaMA-Board)**.

Launch LLaMA Board via `CUDA_VISIBLE_DEVICES=0 python src/train_web.py`. (multiple GPUs are not supported yet in this mode)

Preview LLaMA Board at **[🤗 Spaces](https://huggingface.co/spaces/hiyouga/LLaMA-Board)** and **[ModelScope](https://modelscope.cn/studios/hiyouga/LLaMA-Board)**, or launch it locally with `CUDA_VISIBLE_DEVICES=0 python src/train_web.py`.

Here is an example of altering the self-cognition of an instruction-tuned language model within 10 minutes on a single GPU.
@@ -26,6 +25,7 @@ https://github.com/hiyouga/LLaMA-Factory/assets/16256802/6ba60acc-e2e2-4bec-b846

## Table of Contents

- [Features](#features)
- [Benchmark](#benchmark)
- [Changelog](#changelog)
- [Supported Models](#supported-models)

@@ -38,6 +38,15 @@ https://github.com/hiyouga/LLaMA-Factory/assets/16256802/6ba60acc-e2e2-4bec-b846

- [Citation](#citation)
- [Acknowledgement](#acknowledgement)
## Features

- **Various models**: LLaMA, Mistral, Mixtral-MoE, Qwen, Yi, Gemma, Baichuan, ChatGLM, Phi, etc.
- **Integrated methods**: (Continuous) pre-training, supervised fine-tuning, reward modeling, PPO and DPO.
- **Scalable resources**: 32-bit full-tuning, 16-bit freeze-tuning, 16-bit LoRA and 2/4/8-bit QLoRA via AQLM/AWQ/GPTQ/LLM.int8.
- **Advanced algorithms**: DoRA, LongLoRA, LLaMA Pro, LoftQ and agent tuning.
- **Practical tricks**: FlashAttention-2, Unsloth, RoPE scaling, NEFTune and rsLoRA.
- **Experiment monitors**: LlamaBoard, TensorBoard, Wandb, MLflow, etc.

## Benchmark

Compared to ChatGLM's [P-Tuning](https://github.com/THUDM/ChatGLM2-6B/tree/main/ptuning), LLaMA-Factory's LoRA tuning offers up to **3.7 times faster** training speed with a better Rouge score on the advertising text generation task. By leveraging the 4-bit quantization technique, LLaMA-Factory's QLoRA further improves efficiency in terms of GPU memory usage.
@@ -55,14 +64,20 @@ Compared to ChatGLM's [P-Tuning](https://github.com/THUDM/ChatGLM2-6B/tree/main/

## Changelog

[24/02/28] We supported weight-decomposed LoRA (**[DoRA](https://arxiv.org/abs/2402.09353)**). Try `--use_dora` to activate DoRA training.

[24/02/15] We supported **block expansion** proposed by [LLaMA Pro](https://github.com/TencentARC/LLaMA-Pro). See `tests/llama_pro.py` for usage.

[24/02/05] Qwen1.5 (Qwen2 beta version) series models are supported in LLaMA-Factory. Check this [blog post](https://qwenlm.github.io/blog/qwen1.5/) for details.

<details><summary>Full Changelog</summary>

[24/01/18] We supported **agent tuning** for most models, equipping the model with tool-using abilities by fine-tuning with `--dataset glaive_toolcall`.

[23/12/23] We supported **[unsloth](https://github.com/unslothai/unsloth)**'s implementation to boost LoRA tuning for the LLaMA, Mistral and Yi models. Try the `--use_unsloth` argument to activate the unsloth patch. It achieves a 1.7x speedup in our benchmark; check [this page](https://github.com/hiyouga/LLaMA-Factory/wiki/Performance-comparison) for details.

[23/12/12] We supported fine-tuning the latest MoE model **[Mixtral 8x7B](https://huggingface.co/mistralai/Mixtral-8x7B-v0.1)** in our framework. See the hardware requirements [here](#hardware-requirement).

<details><summary>Full Changelog</summary>

[23/12/01] We supported downloading pre-trained models and datasets from the **[ModelScope Hub](https://modelscope.cn/models)** for users in mainland China. See [this tutorial](#use-modelscope-hub-optional) for usage.

[23/10/21] We supported the **[NEFTune](https://arxiv.org/abs/2310.05914)** trick for fine-tuning. Try the `--neftune_noise_alpha` argument to activate NEFTune, e.g., `--neftune_noise_alpha 5`.
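The flags announced in the changelog above are ordinary command-line arguments to `src/train_bash.py`. The sketch below shows how two of them might be combined with a plain LoRA SFT run; the model, dataset and output path are placeholders, and this exact flag combination is an illustrative assumption rather than a command taken from this README.

```bash
# Hypothetical example: a single-GPU LoRA SFT run that also enables DoRA
# and NEFTune via the changelog flags mentioned above.
CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
    --stage sft \
    --do_train \
    --model_name_or_path meta-llama/Llama-2-7b-hf \
    --dataset alpaca_gpt4_en \
    --template default \
    --finetuning_type lora \
    --lora_target q_proj,v_proj \
    --output_dir path_to_sft_checkpoint \
    --use_dora \
    --neftune_noise_alpha 5 \
    --fp16
```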
@@ -103,6 +118,7 @@ Compared to ChatGLM's [P-Tuning](https://github.com/THUDM/ChatGLM2-6B/tree/main/

| [ChatGLM3](https://huggingface.co/THUDM/chatglm3-6b) | 6B | query_key_value | chatglm3 |
| [DeepSeek (MoE)](https://huggingface.co/deepseek-ai) | 7B/16B/67B | q_proj,v_proj | deepseek |
| [Falcon](https://huggingface.co/tiiuae) | 7B/40B/180B | query_key_value | falcon |
| [Gemma](https://huggingface.co/google) | 2B/7B | q_proj,v_proj | gemma |
| [InternLM2](https://huggingface.co/internlm) | 7B/20B | wqkv | intern2 |
| [LLaMA](https://github.com/facebookresearch/llama) | 7B/13B/33B/65B | q_proj,v_proj | - |
| [LLaMA-2](https://huggingface.co/meta-llama) | 7B/13B/70B | q_proj,v_proj | llama2 |

@@ -110,6 +126,7 @@ Compared to ChatGLM's [P-Tuning](https://github.com/THUDM/ChatGLM2-6B/tree/main/

| [Mixtral](https://huggingface.co/mistralai) | 8x7B | q_proj,v_proj | mistral |
| [Phi-1.5/2](https://huggingface.co/microsoft) | 1.3B/2.7B | q_proj,v_proj | - |
| [Qwen](https://huggingface.co/Qwen) | 1.8B/7B/14B/72B | c_attn | qwen |
| [Qwen1.5](https://huggingface.co/Qwen) | 0.5B/1.8B/4B/7B/14B/72B | q_proj,v_proj | qwen |
| [XVERSE](https://huggingface.co/xverse) | 7B/13B/65B | q_proj,v_proj | xverse |
| [Yi](https://huggingface.co/01-ai) | 6B/34B | q_proj,v_proj | yi |
| [Yuan](https://huggingface.co/IEITYuan) | 2B/51B/102B | q_proj,v_proj | yuan |

@@ -123,7 +140,7 @@ Please refer to [constants.py](src/llmtuner/extras/constants.py) for a full list

## Supported Training Approaches

| Approach               | Full-parameter     | Partial-parameter  | LoRA               | QLoRA              |
| Approach               | Full-tuning        | Freeze-tuning      | LoRA               | QLoRA              |
| ---------------------- | ------------------ | ------------------ | ------------------ | ------------------ |
| Pre-Training           | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
| Supervised Fine-Tuning | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
@@ -154,8 +171,8 @@ Please refer to [constants.py](src/llmtuner/extras/constants.py) for a full list

- [Stanford Alpaca (en)](https://github.com/tatsu-lab/stanford_alpaca)
- [Stanford Alpaca (zh)](https://github.com/ymcui/Chinese-LLaMA-Alpaca)
- [GPT-4 Generated Data (en&zh)](https://github.com/Instruction-Tuning-with-GPT-4/GPT-4-LLM)
- [Self-cognition (zh)](data/self_cognition.json)
- [Alpaca GPT4 (en&zh)](https://github.com/Instruction-Tuning-with-GPT-4/GPT-4-LLM)
- [Self Cognition (zh)](data/self_cognition.json)
- [Open Assistant (multilingual)](https://huggingface.co/datasets/OpenAssistant/oasst1)
- [ShareGPT (zh)](https://huggingface.co/datasets/QingyiSi/Alpaca-CoT/tree/main/Chinese-instruction-collection)
- [Guanaco Dataset (multilingual)](https://huggingface.co/datasets/JosephusCheung/GuanacoDataset)

@@ -171,8 +188,10 @@ Please refer to [constants.py](src/llmtuner/extras/constants.py) for a full list

- [CodeAlpaca 20k (en)](https://huggingface.co/datasets/sahil2801/CodeAlpaca-20k)
- [Alpaca CoT (multilingual)](https://huggingface.co/datasets/QingyiSi/Alpaca-CoT)
- [OpenOrca (en)](https://huggingface.co/datasets/Open-Orca/OpenOrca)
- [SlimOrca (en)](https://huggingface.co/datasets/Open-Orca/SlimOrca)
- [MathInstruct (en)](https://huggingface.co/datasets/TIGER-Lab/MathInstruct)
- [Firefly 1.1M (zh)](https://huggingface.co/datasets/YeungNLP/firefly-train-1.1M)
- [Wiki QA (en)](https://huggingface.co/datasets/wiki_qa)
- [Web QA (zh)](https://huggingface.co/datasets/suolyer/webqa)
- [WebNovel (zh)](https://huggingface.co/datasets/zxbsmk/webnovel_cn)
- [Nectar (en)](https://huggingface.co/datasets/berkeley-nest/Nectar)

@@ -185,6 +204,15 @@ Please refer to [constants.py](src/llmtuner/extras/constants.py) for a full list

- [LMSYS Chat 1M (en)](https://huggingface.co/datasets/lmsys/lmsys-chat-1m)
- [Evol Instruct V2 (en)](https://huggingface.co/datasets/WizardLM/WizardLM_evol_instruct_V2_196k)
- [Glaive Function Calling V2 (en)](https://huggingface.co/datasets/glaiveai/glaive-function-calling-v2)
- [Open Assistant (de)](https://huggingface.co/datasets/mayflowergmbh/oasst_de)
- [Dolly 15k (de)](https://huggingface.co/datasets/mayflowergmbh/dolly-15k_de)
- [Alpaca GPT4 (de)](https://huggingface.co/datasets/mayflowergmbh/alpaca-gpt4_de)
- [OpenSchnabeltier (de)](https://huggingface.co/datasets/mayflowergmbh/openschnabeltier_de)
- [Evol Instruct (de)](https://huggingface.co/datasets/mayflowergmbh/evol-instruct_de)
- [Dolphin (de)](https://huggingface.co/datasets/mayflowergmbh/dolphin_de)
- [Booksum (de)](https://huggingface.co/datasets/mayflowergmbh/booksum_de)
- [Airoboros (de)](https://huggingface.co/datasets/mayflowergmbh/airoboros-3.0_de)
- [Ultrachat (de)](https://huggingface.co/datasets/mayflowergmbh/ultra-chat_de)

</details>

@@ -194,6 +222,7 @@ Please refer to [constants.py](src/llmtuner/extras/constants.py) for a full list

- [Open Assistant (multilingual)](https://huggingface.co/datasets/OpenAssistant/oasst1)
- [GPT-4 Generated Data (en&zh)](https://github.com/Instruction-Tuning-with-GPT-4/GPT-4-LLM)
- [Nectar (en)](https://huggingface.co/datasets/berkeley-nest/Nectar)
- [Orca DPO (de)](https://huggingface.co/datasets/mayflowergmbh/intel_orca_dpo_pairs_de)

</details>
@@ -208,15 +237,27 @@ huggingface-cli login

## Requirement

- Python 3.8+ and PyTorch 1.13.1+
- 🤗Transformers, Datasets, Accelerate, PEFT and TRL
- sentencepiece, protobuf and tiktoken
- jieba, rouge-chinese and nltk (used for evaluation and prediction)
- gradio and matplotlib (used in the web UI)
- uvicorn, fastapi and sse-starlette (used in the API)

| Mandatory    | Minimum | Recommend |
| ------------ | ------- | --------- |
| python       | 3.8     | 3.10      |
| torch        | 1.13.1  | 2.2.1     |
| transformers | 4.37.2  | 4.38.1    |
| datasets     | 2.14.3  | 2.17.1    |
| accelerate   | 0.27.2  | 0.27.2    |
| peft         | 0.9.0   | 0.9.0     |
| trl          | 0.7.11  | 0.7.11    |

| Optional     | Minimum | Recommend |
| ------------ | ------- | --------- |
| CUDA         | 11.6    | 12.2      |
| deepspeed    | 0.10.0  | 0.13.4    |
| bitsandbytes | 0.39.0  | 0.41.3    |
| flash-attn   | 2.3.0   | 2.5.5     |
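The mandatory packages are normally installed with `pip install -r requirements.txt` (see the installation step further below). If you specifically want the "Recommend" column, a pinned install such as the following would match it; the versions are copied from the table above, while the command itself is only an illustrative assumption.

```bash
# Pin the recommended versions from the table above (illustrative only).
pip install "torch==2.2.1" "transformers==4.38.1" "datasets==2.17.1" \
    "accelerate==0.27.2" "peft==0.9.0" "trl==0.7.11"
```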
### Hardware Requirement

\* *estimated*

| Method | Bits | 7B    | 13B   | 30B   | 65B    | 8x7B  |
| ------ | ---- | ----- | ----- | ----- | ------ | ----- |
| Full   | 16   | 160GB | 320GB | 600GB | 1200GB | 900GB |
@@ -244,12 +285,14 @@ cd LLaMA-Factory

pip install -r requirements.txt
```

If you want to enable the quantized LoRA (QLoRA) on the Windows platform, you will need to install a pre-built version of the `bitsandbytes` library, which supports CUDA 11.1 to 12.1.
If you want to enable the quantized LoRA (QLoRA) on the Windows platform, you will need to install a pre-built version of the `bitsandbytes` library, which supports CUDA 11.1 to 12.2.

```bash
pip install https://github.com/jllllll/bitsandbytes-windows-webui/releases/download/wheels/bitsandbytes-0.39.1-py3-none-win_amd64.whl
pip install https://github.com/jllllll/bitsandbytes-windows-webui/releases/download/wheels/bitsandbytes-0.40.0-py3-none-win_amd64.whl
```

To enable FlashAttention-2 on the Windows platform, you need to install the precompiled `flash-attn` library, which supports CUDA 12.1 to 12.2. Please download the corresponding version from [flash-attention](https://github.com/bdashore3/flash-attention/releases) based on your requirements.

### Use ModelScope Hub (optional)

If you have trouble downloading models and datasets from Hugging Face, you can use LLaMA-Factory together with ModelScope in the following manner.
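A minimal sketch of that manner, based on the `USE_MODELSCOPE_HUB` environment variable the project documents for ModelScope support; the model identifier below is a placeholder and the remaining arguments are elided on purpose.

```bash
# Route model and dataset downloads through the ModelScope Hub instead of
# Hugging Face, then launch training as usual.
export USE_MODELSCOPE_HUB=1
CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
    --model_name_or_path modelscope/Llama-2-7b-ms \
    ... # remaining arguments are the same as in the training examples below
```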
@@ -377,6 +420,9 @@ CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \

    --fp16
```

> [!TIP]
> Use `--adapter_name_or_path path_to_sft_checkpoint,path_to_ppo_checkpoint` to infer the fine-tuned model.

> [!WARNING]
> Use `--per_device_train_batch_size=1` for LLaMA-2 models in fp16 PPO training.

@@ -405,6 +451,9 @@ CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \

    --fp16
```

> [!TIP]
> Use `--adapter_name_or_path path_to_sft_checkpoint,path_to_dpo_checkpoint` to infer the fine-tuned model.
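To make the tip above concrete, here is a sketch of chaining the SFT and DPO adapters at inference time with the CLI demo described later in this README; the paths are placeholders.

```bash
# Load the base model, then apply the SFT adapter followed by the DPO adapter.
CUDA_VISIBLE_DEVICES=0 python src/cli_demo.py \
    --model_name_or_path path_to_llama2_model \
    --adapter_name_or_path path_to_sft_checkpoint,path_to_dpo_checkpoint \
    --template default \
    --finetuning_type lora
```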
### Distributed Training

#### Use Huggingface Accelerate

@@ -418,6 +467,7 @@ accelerate launch src/train_bash.py # arguments (same as above)

```yaml
compute_environment: LOCAL_MACHINE
debug: false
distributed_type: MULTI_GPU
downcast_bf16: 'no'
gpu_ids: all
@@ -494,7 +544,7 @@ python src/export_model.py \

> [!TIP]
> Use `--export_quantization_bit 4` and `--export_quantization_dataset data/c4_demo.json` to quantize the model after merging the LoRA weights.
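Putting the tip together with the `export_model.py` invocation named in the hunk header above, a quantized export might look like the following; the paths and the exact set of remaining flags are assumptions.

```bash
# Merge the LoRA weights into the base model and export a 4-bit quantized copy,
# using data/c4_demo.json as the calibration dataset.
python src/export_model.py \
    --model_name_or_path path_to_llama2_model \
    --adapter_name_or_path path_to_sft_checkpoint \
    --template default \
    --finetuning_type lora \
    --export_dir path_to_export \
    --export_quantization_bit 4 \
    --export_quantization_dataset data/c4_demo.json
```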
### API Demo
### Inference with OpenAI-style API

```bash
python src/api_demo.py \

@@ -507,7 +557,7 @@ python src/api_demo.py \

> [!TIP]
> Visit `http://localhost:8000/docs` for API documentation.

### CLI Demo
### Inference with command line

```bash
python src/cli_demo.py \

@@ -517,7 +567,7 @@ python src/cli_demo.py \

    --finetuning_type lora
```

### Web Demo
### Inference with web browser

```bash
python src/web_demo.py \

@@ -554,7 +604,7 @@ CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \

    --template default \
    --finetuning_type lora \
    --output_dir path_to_predict_result \
    --per_device_eval_batch_size 8 \
    --per_device_eval_batch_size 1 \
    --max_samples 100 \
    --predict_with_generate \
    --fp16
@@ -568,11 +618,27 @@ CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \

## Projects using LLaMA Factory

- **[StarWhisper](https://github.com/Yu-Yang-Li/StarWhisper)**: A large language model for astronomy, based on ChatGLM2-6B and Qwen-14B.
- **[DISC-LawLLM](https://github.com/FudanDISC/DISC-LawLLM)**: A large language model specialized in the Chinese legal domain, based on Baichuan-13B, capable of retrieving and reasoning on legal knowledge.
- **[Sunsimiao](https://github.com/thomas-yanxin/Sunsimiao)**: A large language model specialized in the Chinese medical domain, based on Baichuan-7B and ChatGLM-6B.
- **[CareGPT](https://github.com/WangRongsheng/CareGPT)**: A series of large language models for the Chinese medical domain, based on LLaMA2-7B and Baichuan-13B.
- **[MachineMindset](https://github.com/PKU-YuanGroup/Machine-Mindset/)**: A series of MBTI personality large language models, capable of giving any LLM 16 different personality types based on different datasets and training methods.
1. Wang et al. ESRL: Efficient Sampling-based Reinforcement Learning for Sequence Generation. 2023. [[arxiv]](https://arxiv.org/abs/2308.02223)
1. Yu et al. Open, Closed, or Small Language Models for Text Classification? 2023. [[arxiv]](https://arxiv.org/abs/2308.10092)
1. Luceri et al. Leveraging Large Language Models to Detect Influence Campaigns in Social Media. 2023. [[arxiv]](https://arxiv.org/abs/2311.07816)
1. Zhang et al. Alleviating Hallucinations of Large Language Models through Induced Hallucinations. 2023. [[arxiv]](https://arxiv.org/abs/2312.15710)
1. Wang et al. Know Your Needs Better: Towards Structured Understanding of Marketer Demands with Analogical Reasoning Augmented LLMs. 2024. [[arxiv]](https://arxiv.org/abs/2401.04319)
1. Wang et al. CANDLE: Iterative Conceptualization and Instantiation Distillation from Large Language Models for Commonsense Reasoning. 2024. [[arxiv]](https://arxiv.org/abs/2401.07286)
1. Choi et al. FACT-GPT: Fact-Checking Augmentation via Claim Matching with LLMs. 2024. [[arxiv]](https://arxiv.org/abs/2402.05904)
1. Zhang et al. AutoMathText: Autonomous Data Selection with Language Models for Mathematical Texts. 2024. [[arxiv]](https://arxiv.org/abs/2402.07625)
1. Lyu et al. KnowTuning: Knowledge-aware Fine-tuning for Large Language Models. 2024. [[arxiv]](https://arxiv.org/abs/2402.11176)
1. Yang et al. LaCo: Large Language Model Pruning via Layer Collapse. 2024. [[arxiv]](https://arxiv.org/abs/2402.11187)
1. Bhardwaj et al. Language Models are Homer Simpson! Safety Re-Alignment of Fine-tuned Language Models through Task Arithmetic. 2024. [[arxiv]](https://arxiv.org/abs/2402.11746)
1. Yang et al. Enhancing Empathetic Response Generation by Augmenting LLMs with Small-scale Empathetic Models. 2024. [[arxiv]](https://arxiv.org/abs/2402.11801)
1. Yi et al. Generation Meets Verification: Accelerating Large Language Model Inference with Smart Parallel Auto-Correct Decoding. 2024. [[arxiv]](https://arxiv.org/abs/2402.11809)
1. Cao et al. Head-wise Shareable Attention for Large Language Models. 2024. [[arxiv]](https://arxiv.org/abs/2402.11819)
1. Zhang et al. Enhancing Multilingual Capabilities of Large Language Models through Self-Distillation from Resource-Rich Languages. 2024. [[arxiv]](https://arxiv.org/abs/2402.12204)
1. Kim et al. Efficient and Effective Vocabulary Expansion Towards Multilingual Large Language Models. 2024. [[arxiv]](https://arxiv.org/abs/2402.14714)
1. **[StarWhisper](https://github.com/Yu-Yang-Li/StarWhisper)**: A large language model for astronomy, based on ChatGLM2-6B and Qwen-14B.
1. **[DISC-LawLLM](https://github.com/FudanDISC/DISC-LawLLM)**: A large language model specialized in the Chinese legal domain, based on Baichuan-13B, capable of retrieving and reasoning on legal knowledge.
1. **[Sunsimiao](https://github.com/thomas-yanxin/Sunsimiao)**: A large language model specialized in the Chinese medical domain, based on Baichuan-7B and ChatGLM-6B.
1. **[CareGPT](https://github.com/WangRongsheng/CareGPT)**: A series of large language models for the Chinese medical domain, based on LLaMA2-7B and Baichuan-13B.
1. **[MachineMindset](https://github.com/PKU-YuanGroup/Machine-Mindset/)**: A series of MBTI personality large language models, capable of giving any LLM 16 different personality types based on different datasets and training methods.

> [!TIP]
> If you have a project that should be incorporated, please contact us via email or create a pull request.

@@ -581,7 +647,7 @@ CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \

This repository is licensed under the [Apache-2.0 License](LICENSE).

Please follow the model licenses to use the corresponding model weights: [Baichuan2](https://huggingface.co/baichuan-inc/Baichuan2-7B-Base/blob/main/Community%20License%20for%20Baichuan%202%20Model.pdf) / [BLOOM](https://huggingface.co/spaces/bigscience/license) / [ChatGLM3](https://github.com/THUDM/ChatGLM3/blob/main/MODEL_LICENSE) / [DeepSeek](https://github.com/deepseek-ai/DeepSeek-LLM/blob/main/LICENSE-MODEL) / [Falcon](https://huggingface.co/tiiuae/falcon-180B/blob/main/LICENSE.txt) / [InternLM2](https://github.com/InternLM/InternLM#license) / [LLaMA](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) / [LLaMA-2](https://ai.meta.com/llama/license/) / [Mistral](LICENSE) / [Phi-1.5/2](https://huggingface.co/microsoft/phi-1_5/resolve/main/Research%20License.docx) / [Qwen](https://github.com/QwenLM/Qwen/blob/main/Tongyi%20Qianwen%20LICENSE%20AGREEMENT) / [XVERSE](https://github.com/xverse-ai/XVERSE-13B/blob/main/MODEL_LICENSE.pdf) / [Yi](https://huggingface.co/01-ai/Yi-6B/blob/main/LICENSE) / [Yuan](https://github.com/IEIT-Yuan/Yuan-2.0/blob/main/LICENSE-Yuan)
Please follow the model licenses to use the corresponding model weights: [Baichuan2](https://huggingface.co/baichuan-inc/Baichuan2-7B-Base/blob/main/Community%20License%20for%20Baichuan%202%20Model.pdf) / [BLOOM](https://huggingface.co/spaces/bigscience/license) / [ChatGLM3](https://github.com/THUDM/ChatGLM3/blob/main/MODEL_LICENSE) / [DeepSeek](https://github.com/deepseek-ai/DeepSeek-LLM/blob/main/LICENSE-MODEL) / [Falcon](https://huggingface.co/tiiuae/falcon-180B/blob/main/LICENSE.txt) / [Gemma](https://ai.google.dev/gemma/terms) / [InternLM2](https://github.com/InternLM/InternLM#license) / [LLaMA](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) / [LLaMA-2](https://ai.meta.com/llama/license/) / [Mistral](LICENSE) / [Phi-1.5/2](https://huggingface.co/microsoft/phi-1_5/resolve/main/Research%20License.docx) / [Qwen](https://github.com/QwenLM/Qwen/blob/main/Tongyi%20Qianwen%20LICENSE%20AGREEMENT) / [XVERSE](https://github.com/xverse-ai/XVERSE-13B/blob/main/MODEL_LICENSE.pdf) / [Yi](https://huggingface.co/01-ai/Yi-6B/blob/main/LICENSE) / [Yuan](https://github.com/IEIT-Yuan/Yuan-2.0/blob/main/LICENSE-Yuan)

## Citation
116  README_zh.md
@@ -5,6 +5,7 @@
|
||||
[](https://github.com/hiyouga/LLaMA-Factory/commits/main)
|
||||
[](https://pypi.org/project/llmtuner/)
|
||||
[](https://pypi.org/project/llmtuner/)
|
||||
[](#使用了-llama-factory-的项目)
|
||||
[](https://github.com/hiyouga/LLaMA-Factory/pulls)
|
||||
[](https://discord.gg/rKfvV9r9FK)
|
||||
[](https://huggingface.co/spaces/hiyouga/LLaMA-Board)
|
||||
@@ -16,9 +17,7 @@
|
||||
|
||||
## LLaMA Board: 通过一站式网页界面快速上手 LLaMA Factory
|
||||
|
||||
通过 **[🤗 Spaces](https://huggingface.co/spaces/hiyouga/LLaMA-Board)** 或 **[ModelScope](https://modelscope.cn/studios/hiyouga/LLaMA-Board)** 预览 LLaMA Board。
|
||||
|
||||
使用 `CUDA_VISIBLE_DEVICES=0 python src/train_web.py` 启动 LLaMA Board。(该模式目前仅支持单卡训练)
|
||||
通过 **[🤗 Spaces](https://huggingface.co/spaces/hiyouga/LLaMA-Board)** 或 **[ModelScope](https://modelscope.cn/studios/hiyouga/LLaMA-Board)** 预览 LLaMA Board,或者通过命令 `CUDA_VISIBLE_DEVICES=0 python src/train_web.py` 本地启动。
|
||||
|
||||
下面是使用单张 GPU 在 10 分钟内更改对话式大型语言模型自我认知的示例。
|
||||
|
||||
@@ -26,6 +25,7 @@ https://github.com/hiyouga/LLaMA-Factory/assets/16256802/6ba60acc-e2e2-4bec-b846
|
||||
|
||||
## 目录
|
||||
|
||||
- [项目特色](#项目特色)
|
||||
- [性能指标](#性能指标)
|
||||
- [更新日志](#更新日志)
|
||||
- [模型](#模型)
|
||||
@@ -38,6 +38,15 @@ https://github.com/hiyouga/LLaMA-Factory/assets/16256802/6ba60acc-e2e2-4bec-b846
|
||||
- [引用](#引用)
|
||||
- [致谢](#致谢)
|
||||
|
||||
## 项目特色
|
||||
|
||||
- **多种模型**:LLaMA、Mistral、Mixtral-MoE、Qwen、Yi、Gemma、Baichuan、ChatGLM、Phi 等等。
|
||||
- **集成方法**:(增量)预训练、指令监督微调、奖励模型训练、PPO 训练和 DPO 训练。
|
||||
- **多种精度**:32 比特全参数微调、16 比特冻结微调、16 比特 LoRA 微调和基于 AQLM/AWQ/GPTQ/LLM.int8 的 2/4/8 比特 QLoRA 微调。
|
||||
- **先进算法**:DoRA、LongLoRA、LLaMA Pro、LoftQ 和 Agent 微调。
|
||||
- **实用技巧**:FlashAttention-2、Unsloth、RoPE scaling、NEFTune 和 rsLoRA。
|
||||
- **实验监控**:LlamaBoard、TensorBoard、Wandb、MLflow 等等。
|
||||
|
||||
## 性能指标
|
||||
|
||||
与 ChatGLM 官方的 [P-Tuning](https://github.com/THUDM/ChatGLM2-6B/tree/main/ptuning) 微调相比,LLaMA-Factory 的 LoRA 微调提供了 **3.7 倍**的加速比,同时在广告文案生成任务上取得了更高的 Rouge 分数。结合 4 比特量化技术,LLaMA-Factory 的 QLoRA 微调进一步降低了 GPU 显存消耗。
|
||||
@@ -55,14 +64,20 @@ https://github.com/hiyouga/LLaMA-Factory/assets/16256802/6ba60acc-e2e2-4bec-b846
|
||||
|
||||
## 更新日志
|
||||
|
||||
[24/02/28] 我们支持了 **[DoRA](https://arxiv.org/abs/2402.09353)** 微调。请使用 `--use_dora` 参数进行 DoRA 微调。
|
||||
|
||||
[24/02/15] 我们支持了 [LLaMA Pro](https://github.com/TencentARC/LLaMA-Pro) 提出的**块扩展**方法。详细用法请参照 `tests/llama_pro.py`。
|
||||
|
||||
[24/02/05] Qwen1.5(Qwen2 测试版)系列模型已在 LLaMA-Factory 中实现微调支持。详情请查阅该[博客页面](https://qwenlm.github.io/zh/blog/qwen1.5/)。
|
||||
|
||||
<details><summary>展开日志</summary>
|
||||
|
||||
[24/01/18] 我们针对绝大多数模型实现了 **Agent 微调**,微调时指定 `--dataset glaive_toolcall` 即可使模型获得工具调用能力。
|
||||
|
||||
[23/12/23] 我们针对 LLaMA, Mistral 和 Yi 模型支持了 **[unsloth](https://github.com/unslothai/unsloth)** 的 LoRA 训练加速。请使用 `--use_unsloth` 参数启用 unsloth 优化。该方法可提供 1.7 倍的训练速度,详情请查阅[此页面](https://github.com/hiyouga/LLaMA-Factory/wiki/Performance-comparison)。
|
||||
|
||||
[23/12/12] 我们支持了微调最新的混合专家模型 **[Mixtral 8x7B](https://huggingface.co/mistralai/Mixtral-8x7B-v0.1)**。硬件需求请查阅[此处](#硬件依赖)。
|
||||
|
||||
<details><summary>展开日志</summary>
|
||||
|
||||
[23/12/01] 我们支持了从 **[魔搭社区](https://modelscope.cn/models)** 下载预训练模型和数据集。详细用法请参照 [此教程](#使用魔搭社区可跳过)。
|
||||
|
||||
[23/10/21] 我们支持了 **[NEFTune](https://arxiv.org/abs/2310.05914)** 训练技巧。请使用 `--neftune_noise_alpha` 参数启用 NEFTune,例如 `--neftune_noise_alpha 5`。
|
||||
@@ -103,6 +118,7 @@ https://github.com/hiyouga/LLaMA-Factory/assets/16256802/6ba60acc-e2e2-4bec-b846
|
||||
| [ChatGLM3](https://huggingface.co/THUDM/chatglm3-6b) | 6B | query_key_value | chatglm3 |
|
||||
| [DeepSeek (MoE)](https://huggingface.co/deepseek-ai) | 7B/16B/67B | q_proj,v_proj | deepseek |
|
||||
| [Falcon](https://huggingface.co/tiiuae) | 7B/40B/180B | query_key_value | falcon |
|
||||
| [Gemma](https://huggingface.co/google) | 2B/7B | q_proj,v_proj | gemma |
|
||||
| [InternLM2](https://huggingface.co/internlm) | 7B/20B | wqkv | intern2 |
|
||||
| [LLaMA](https://github.com/facebookresearch/llama) | 7B/13B/33B/65B | q_proj,v_proj | - |
|
||||
| [LLaMA-2](https://huggingface.co/meta-llama) | 7B/13B/70B | q_proj,v_proj | llama2 |
|
||||
@@ -110,6 +126,7 @@ https://github.com/hiyouga/LLaMA-Factory/assets/16256802/6ba60acc-e2e2-4bec-b846
|
||||
| [Mixtral](https://huggingface.co/mistralai) | 8x7B | q_proj,v_proj | mistral |
|
||||
| [Phi-1.5/2](https://huggingface.co/microsoft) | 1.3B/2.7B | q_proj,v_proj | - |
|
||||
| [Qwen](https://huggingface.co/Qwen) | 1.8B/7B/14B/72B | c_attn | qwen |
|
||||
| [Qwen1.5](https://huggingface.co/Qwen) | 0.5B/1.8B/4B/7B/14B/72B | q_proj,v_proj | qwen |
|
||||
| [XVERSE](https://huggingface.co/xverse) | 7B/13B/65B | q_proj,v_proj | xverse |
|
||||
| [Yi](https://huggingface.co/01-ai) | 6B/34B | q_proj,v_proj | yi |
|
||||
| [Yuan](https://huggingface.co/IEITYuan) | 2B/51B/102B | q_proj,v_proj | yuan |
|
||||
@@ -154,8 +171,8 @@ https://github.com/hiyouga/LLaMA-Factory/assets/16256802/6ba60acc-e2e2-4bec-b846
|
||||
|
||||
- [Stanford Alpaca (en)](https://github.com/tatsu-lab/stanford_alpaca)
|
||||
- [Stanford Alpaca (zh)](https://github.com/ymcui/Chinese-LLaMA-Alpaca)
|
||||
- [GPT-4 Generated Data (en&zh)](https://github.com/Instruction-Tuning-with-GPT-4/GPT-4-LLM)
|
||||
- [Self-cognition (zh)](data/self_cognition.json)
|
||||
- [Alpaca GPT4 (en&zh)](https://github.com/Instruction-Tuning-with-GPT-4/GPT-4-LLM)
|
||||
- [Self Cognition (zh)](data/self_cognition.json)
|
||||
- [Open Assistant (multilingual)](https://huggingface.co/datasets/OpenAssistant/oasst1)
|
||||
- [ShareGPT (zh)](https://huggingface.co/datasets/QingyiSi/Alpaca-CoT/tree/main/Chinese-instruction-collection)
|
||||
- [Guanaco Dataset (multilingual)](https://huggingface.co/datasets/JosephusCheung/GuanacoDataset)
|
||||
@@ -171,8 +188,10 @@ https://github.com/hiyouga/LLaMA-Factory/assets/16256802/6ba60acc-e2e2-4bec-b846
|
||||
- [CodeAlpaca 20k (en)](https://huggingface.co/datasets/sahil2801/CodeAlpaca-20k)
|
||||
- [Alpaca CoT (multilingual)](https://huggingface.co/datasets/QingyiSi/Alpaca-CoT)
|
||||
- [OpenOrca (en)](https://huggingface.co/datasets/Open-Orca/OpenOrca)
|
||||
- [SlimOrca (en)](https://huggingface.co/datasets/Open-Orca/SlimOrca)
|
||||
- [MathInstruct (en)](https://huggingface.co/datasets/TIGER-Lab/MathInstruct)
|
||||
- [Firefly 1.1M (zh)](https://huggingface.co/datasets/YeungNLP/firefly-train-1.1M)
|
||||
- [Wiki QA (en)](https://huggingface.co/datasets/wiki_qa)
|
||||
- [Web QA (zh)](https://huggingface.co/datasets/suolyer/webqa)
|
||||
- [WebNovel (zh)](https://huggingface.co/datasets/zxbsmk/webnovel_cn)
|
||||
- [Nectar (en)](https://huggingface.co/datasets/berkeley-nest/Nectar)
|
||||
@@ -185,6 +204,15 @@ https://github.com/hiyouga/LLaMA-Factory/assets/16256802/6ba60acc-e2e2-4bec-b846
|
||||
- [LMSYS Chat 1M (en)](https://huggingface.co/datasets/lmsys/lmsys-chat-1m)
|
||||
- [Evol Instruct V2 (en)](https://huggingface.co/datasets/WizardLM/WizardLM_evol_instruct_V2_196k)
|
||||
- [Glaive Function Calling V2 (en)](https://huggingface.co/datasets/glaiveai/glaive-function-calling-v2)
|
||||
- [Open Assistant (de)](https://huggingface.co/datasets/mayflowergmbh/oasst_de)
|
||||
- [Dolly 15k (de)](https://huggingface.co/datasets/mayflowergmbh/dolly-15k_de)
|
||||
- [Alpaca GPT4 (de)](https://huggingface.co/datasets/mayflowergmbh/alpaca-gpt4_de)
|
||||
- [OpenSchnabeltier (de)](https://huggingface.co/datasets/mayflowergmbh/openschnabeltier_de)
|
||||
- [Evol Instruct (de)](https://huggingface.co/datasets/mayflowergmbh/evol-instruct_de)
|
||||
- [Dolphin (de)](https://huggingface.co/datasets/mayflowergmbh/dolphin_de)
|
||||
- [Booksum (de)](https://huggingface.co/datasets/mayflowergmbh/booksum_de)
|
||||
- [Airoboros (de)](https://huggingface.co/datasets/mayflowergmbh/airoboros-3.0_de)
|
||||
- [Ultrachat (de)](https://huggingface.co/datasets/mayflowergmbh/ultra-chat_de)
|
||||
|
||||
</details>
|
||||
|
||||
@@ -194,6 +222,7 @@ https://github.com/hiyouga/LLaMA-Factory/assets/16256802/6ba60acc-e2e2-4bec-b846
|
||||
- [Open Assistant (multilingual)](https://huggingface.co/datasets/OpenAssistant/oasst1)
|
||||
- [GPT-4 Generated Data (en&zh)](https://github.com/Instruction-Tuning-with-GPT-4/GPT-4-LLM)
|
||||
- [Nectar (en)](https://huggingface.co/datasets/berkeley-nest/Nectar)
|
||||
- [Orca DPO (de)](https://huggingface.co/datasets/mayflowergmbh/intel_orca_dpo_pairs_de)
|
||||
|
||||
</details>
|
||||
|
||||
@@ -208,15 +237,27 @@ huggingface-cli login
|
||||
|
||||
## 软硬件依赖
|
||||
|
||||
- Python 3.8+ 和 PyTorch 1.13.1+
|
||||
- 🤗Transformers, Datasets, Accelerate, PEFT 和 TRL
|
||||
- sentencepiece, protobuf 和 tiktoken
|
||||
- jieba, rouge-chinese 和 nltk (用于评估及预测)
|
||||
- gradio 和 matplotlib (用于网页端交互)
|
||||
- uvicorn, fastapi 和 sse-starlette (用于 API)
|
||||
| 必需项 | 至少 | 推荐 |
|
||||
| ------------ | ------- | --------- |
|
||||
| python | 3.8 | 3.10 |
|
||||
| torch | 1.13.1 | 2.2.1 |
|
||||
| transformers | 4.37.2 | 4.38.1 |
|
||||
| datasets | 2.14.3 | 2.17.1 |
|
||||
| accelerate | 0.27.2 | 0.27.2 |
|
||||
| peft | 0.9.0 | 0.9.0 |
|
||||
| trl | 0.7.11 | 0.7.11 |
|
||||
|
||||
| 可选项 | 至少 | 推荐 |
|
||||
| ------------ | ------- | --------- |
|
||||
| CUDA | 11.6 | 12.2 |
|
||||
| deepspeed | 0.10.0 | 0.13.4 |
|
||||
| bitsandbytes | 0.39.0 | 0.41.3 |
|
||||
| flash-attn | 2.3.0 | 2.5.5 |
|
||||
|
||||
### 硬件依赖
|
||||
|
||||
\* *估算值*
|
||||
|
||||
| 训练方法 | 精度 | 7B | 13B | 30B | 65B | 8x7B |
|
||||
| ------- | ---- | ----- | ----- | ----- | ------ | ------ |
|
||||
| 全参数 | 16 | 160GB | 320GB | 600GB | 1200GB | 900GB |
|
||||
@@ -244,12 +285,14 @@ cd LLaMA-Factory
|
||||
pip install -r requirements.txt
|
||||
```
|
||||
|
||||
如果要在 Windows 平台上开启量化 LoRA(QLoRA),需要安装预编译的 `bitsandbytes` 库, 支持 CUDA 11.1 到 12.1.
|
||||
如果要在 Windows 平台上开启量化 LoRA(QLoRA),需要安装预编译的 `bitsandbytes` 库, 支持 CUDA 11.1 到 12.2。
|
||||
|
||||
```bash
|
||||
pip install https://github.com/jllllll/bitsandbytes-windows-webui/releases/download/wheels/bitsandbytes-0.39.1-py3-none-win_amd64.whl
|
||||
pip install https://github.com/jllllll/bitsandbytes-windows-webui/releases/download/wheels/bitsandbytes-0.40.0-py3-none-win_amd64.whl
|
||||
```
|
||||
|
||||
如果要在 Windows 平台上开启 FlashAttention-2,需要安装预编译的 `flash-attn` 库,支持 CUDA 12.1 到 12.2,请根据需求到 [flash-attention](https://github.com/bdashore3/flash-attention/releases) 下载对应版本安装。
|
||||
|
||||
### 使用魔搭社区(可跳过)
|
||||
|
||||
如果您在 Hugging Face 模型和数据集的下载中遇到了问题,可以通过下述方法使用魔搭社区。
|
||||
@@ -377,6 +420,9 @@ CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
|
||||
--fp16
|
||||
```
|
||||
|
||||
> [!TIP]
|
||||
> 使用 `--adapter_name_or_path path_to_sft_checkpoint,path_to_ppo_checkpoint` 来进行微调模型的推理。
|
||||
|
||||
> [!WARNING]
|
||||
> 如果使用 fp16 精度进行 LLaMA-2 模型的 PPO 训练,请使用 `--per_device_train_batch_size=1`。
|
||||
|
||||
@@ -405,6 +451,9 @@ CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
|
||||
--fp16
|
||||
```
|
||||
|
||||
> [!TIP]
|
||||
> 使用 `--adapter_name_or_path path_to_sft_checkpoint,path_to_dpo_checkpoint` 来进行微调模型的推理。
|
||||
|
||||
### 多 GPU 分布式训练
|
||||
|
||||
#### 使用 Huggingface Accelerate
|
||||
@@ -418,6 +467,7 @@ accelerate launch src/train_bash.py # 参数同上
|
||||
|
||||
```yaml
|
||||
compute_environment: LOCAL_MACHINE
|
||||
debug: false
|
||||
distributed_type: MULTI_GPU
|
||||
downcast_bf16: 'no'
|
||||
gpu_ids: all
|
||||
@@ -494,7 +544,7 @@ python src/export_model.py \
|
||||
> [!TIP]
|
||||
> 合并 LoRA 权重之后可再次使用 `--export_quantization_bit 4` 和 `--export_quantization_dataset data/c4_demo.json` 量化模型。
|
||||
|
||||
### API 服务
|
||||
### 使用 OpenAI 风格 API 推理
|
||||
|
||||
```bash
|
||||
python src/api_demo.py \
|
||||
@@ -507,7 +557,7 @@ python src/api_demo.py \
|
||||
> [!TIP]
|
||||
> 关于 API 文档请见 `http://localhost:8000/docs`。
|
||||
|
||||
### 命令行测试
|
||||
### 使用命令行推理
|
||||
|
||||
```bash
|
||||
python src/cli_demo.py \
|
||||
@@ -517,7 +567,7 @@ python src/cli_demo.py \
|
||||
--finetuning_type lora
|
||||
```
|
||||
|
||||
### 浏览器测试
|
||||
### 使用浏览器推理
|
||||
|
||||
```bash
|
||||
python src/web_demo.py \
|
||||
@@ -554,7 +604,7 @@ CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
|
||||
--template default \
|
||||
--finetuning_type lora \
|
||||
--output_dir path_to_predict_result \
|
||||
--per_device_eval_batch_size 8 \
|
||||
--per_device_eval_batch_size 1 \
|
||||
--max_samples 100 \
|
||||
--predict_with_generate \
|
||||
--fp16
|
||||
@@ -568,11 +618,27 @@ CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
|
||||
|
||||
## 使用了 LLaMA Factory 的项目
|
||||
|
||||
- **[StarWhisper](https://github.com/Yu-Yang-Li/StarWhisper)**: 天文大模型 StarWhisper,基于 ChatGLM2-6B 和 Qwen-14B 在天文数据上微调而得。
|
||||
- **[DISC-LawLLM](https://github.com/FudanDISC/DISC-LawLLM)**: 中文法律领域大模型 DISC-LawLLM,基于 Baichuan-13B 微调而得,具有法律推理和知识检索能力。
|
||||
- **[Sunsimiao](https://github.com/thomas-yanxin/Sunsimiao)**: 孙思邈中文医疗大模型 Sunsimiao,基于 Baichuan-7B 和 ChatGLM-6B 在中文医疗数据上微调而得。
|
||||
- **[CareGPT](https://github.com/WangRongsheng/CareGPT)**: 医疗大模型项目 CareGPT,基于 LLaMA2-7B 和 Baichuan-13B 在中文医疗数据上微调而得。
|
||||
- **[MachineMindset](https://github.com/PKU-YuanGroup/Machine-Mindset/)**:MBTI性格大模型项目,根据数据集与训练方式让任意 LLM 拥有 16 个不同的性格类型。
|
||||
1. Wang et al. ESRL: Efficient Sampling-based Reinforcement Learning for Sequence Generation. 2023. [[arxiv]](https://arxiv.org/abs/2308.02223)
|
||||
1. Yu et al. Open, Closed, or Small Language Models for Text Classification? 2023. [[arxiv]](https://arxiv.org/abs/2308.10092)
|
||||
1. Luceri et al. Leveraging Large Language Models to Detect Influence Campaigns in Social Media. 2023. [[arxiv]](https://arxiv.org/abs/2311.07816)
|
||||
1. Zhang et al. Alleviating Hallucinations of Large Language Models through Induced Hallucinations. 2023. [[arxiv]](https://arxiv.org/abs/2312.15710)
|
||||
1. Wang et al. Know Your Needs Better: Towards Structured Understanding of Marketer Demands with Analogical Reasoning Augmented LLMs. 2024. [[arxiv]](https://arxiv.org/abs/2401.04319)
|
||||
1. Wang et al. CANDLE: Iterative Conceptualization and Instantiation Distillation from Large Language Models for Commonsense Reasoning. 2024. [[arxiv]](https://arxiv.org/abs/2401.07286)
|
||||
1. Choi et al. FACT-GPT: Fact-Checking Augmentation via Claim Matching with LLMs. 2024. [[arxiv]](https://arxiv.org/abs/2402.05904)
|
||||
1. Zhang et al. AutoMathText: Autonomous Data Selection with Language Models for Mathematical Texts. 2024. [[arxiv]](https://arxiv.org/abs/2402.07625)
|
||||
1. Lyu et al. KnowTuning: Knowledge-aware Fine-tuning for Large Language Models. 2024. [[arxiv]](https://arxiv.org/abs/2402.11176)
|
||||
1. Yang et al. LaCo: Large Language Model Pruning via Layer Collapse. 2024. [[arxiv]](https://arxiv.org/abs/2402.11187)
|
||||
1. Bhardwaj et al. Language Models are Homer Simpson! Safety Re-Alignment of Fine-tuned Language Models through Task Arithmetic. 2024. [[arxiv]](https://arxiv.org/abs/2402.11746)
|
||||
1. Yang et al. Enhancing Empathetic Response Generation by Augmenting LLMs with Small-scale Empathetic Models. 2024. [[arxiv]](https://arxiv.org/abs/2402.11801)
|
||||
1. Yi et al. Generation Meets Verification: Accelerating Large Language Model Inference with Smart Parallel Auto-Correct Decoding. 2024. [[arxiv]](https://arxiv.org/abs/2402.11809)
|
||||
1. Cao et al. Head-wise Shareable Attention for Large Language Models. 2024. [[arxiv]](https://arxiv.org/abs/2402.11819)
|
||||
1. Zhang et al. Enhancing Multilingual Capabilities of Large Language Models through Self-Distillation from Resource-Rich Languages. 2024. [[arxiv]](https://arxiv.org/abs/2402.12204)
|
||||
1. Kim et al. Efficient and Effective Vocabulary Expansion Towards Multilingual Large Language Models. 2024. [[arxiv]](https://arxiv.org/abs/2402.14714)
|
||||
1. **[StarWhisper](https://github.com/Yu-Yang-Li/StarWhisper)**: 天文大模型 StarWhisper,基于 ChatGLM2-6B 和 Qwen-14B 在天文数据上微调而得。
|
||||
1. **[DISC-LawLLM](https://github.com/FudanDISC/DISC-LawLLM)**: 中文法律领域大模型 DISC-LawLLM,基于 Baichuan-13B 微调而得,具有法律推理和知识检索能力。
|
||||
1. **[Sunsimiao](https://github.com/thomas-yanxin/Sunsimiao)**: 孙思邈中文医疗大模型 Sunsimiao,基于 Baichuan-7B 和 ChatGLM-6B 在中文医疗数据上微调而得。
|
||||
1. **[CareGPT](https://github.com/WangRongsheng/CareGPT)**: 医疗大模型项目 CareGPT,基于 LLaMA2-7B 和 Baichuan-13B 在中文医疗数据上微调而得。
|
||||
1. **[MachineMindset](https://github.com/PKU-YuanGroup/Machine-Mindset/)**:MBTI性格大模型项目,根据数据集与训练方式让任意 LLM 拥有 16 个不同的性格类型。
|
||||
|
||||
> [!TIP]
|
||||
> 如果您有项目希望添加至上述列表,请通过邮件联系或者创建一个 PR。
|
||||
@@ -581,7 +647,7 @@ CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
|
||||
|
||||
本仓库的代码依照 [Apache-2.0](LICENSE) 协议开源。
|
||||
|
||||
使用模型权重时,请遵循对应的模型协议:[Baichuan2](https://huggingface.co/baichuan-inc/Baichuan2-7B-Base/blob/main/Community%20License%20for%20Baichuan%202%20Model.pdf) / [BLOOM](https://huggingface.co/spaces/bigscience/license) / [ChatGLM3](https://github.com/THUDM/ChatGLM3/blob/main/MODEL_LICENSE) / [DeepSeek](https://github.com/deepseek-ai/DeepSeek-LLM/blob/main/LICENSE-MODEL) / [Falcon](https://huggingface.co/tiiuae/falcon-180B/blob/main/LICENSE.txt) / [InternLM2](https://github.com/InternLM/InternLM#license) / [LLaMA](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) / [LLaMA-2](https://ai.meta.com/llama/license/) / [Mistral](LICENSE) / [Phi-1.5/2](https://huggingface.co/microsoft/phi-1_5/resolve/main/Research%20License.docx) / [Qwen](https://github.com/QwenLM/Qwen/blob/main/Tongyi%20Qianwen%20LICENSE%20AGREEMENT) / [XVERSE](https://github.com/xverse-ai/XVERSE-13B/blob/main/MODEL_LICENSE.pdf) / [Yi](https://huggingface.co/01-ai/Yi-6B/blob/main/LICENSE) / [Yuan](https://github.com/IEIT-Yuan/Yuan-2.0/blob/main/LICENSE-Yuan)
|
||||
使用模型权重时,请遵循对应的模型协议:[Baichuan2](https://huggingface.co/baichuan-inc/Baichuan2-7B-Base/blob/main/Community%20License%20for%20Baichuan%202%20Model.pdf) / [BLOOM](https://huggingface.co/spaces/bigscience/license) / [ChatGLM3](https://github.com/THUDM/ChatGLM3/blob/main/MODEL_LICENSE) / [DeepSeek](https://github.com/deepseek-ai/DeepSeek-LLM/blob/main/LICENSE-MODEL) / [Falcon](https://huggingface.co/tiiuae/falcon-180B/blob/main/LICENSE.txt) / [Gemma](https://ai.google.dev/gemma/terms) / [InternLM2](https://github.com/InternLM/InternLM#license) / [LLaMA](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) / [LLaMA-2](https://ai.meta.com/llama/license/) / [Mistral](LICENSE) / [Phi-1.5/2](https://huggingface.co/microsoft/phi-1_5/resolve/main/Research%20License.docx) / [Qwen](https://github.com/QwenLM/Qwen/blob/main/Tongyi%20Qianwen%20LICENSE%20AGREEMENT) / [XVERSE](https://github.com/xverse-ai/XVERSE-13B/blob/main/MODEL_LICENSE.pdf) / [Yi](https://huggingface.co/01-ai/Yi-6B/blob/main/LICENSE) / [Yuan](https://github.com/IEIT-Yuan/Yuan-2.0/blob/main/LICENSE-Yuan)
|
||||
|
||||
## 引用
|
||||
|
||||
|
||||
7  SECURITY.md  (new file)
@@ -0,0 +1,7 @@

# Reporting Security Issues

To report a security issue, please use the GitHub Security Advisory ["Report a Vulnerability"](https://github.com/electron/electron/security/advisories/new) tab.

We will send a response indicating the next steps in handling your report. After the initial reply to your report, the security team will keep you informed of the progress towards a fix and full announcement, and may ask for additional information or guidance.

Report security bugs in third-party modules to the person or team maintaining the module.
data/README.md

@@ -11,15 +11,23 @@ If you are using a custom dataset, please provide your dataset definition in the

  "folder": "the name of the folder of the dataset repository on the Hugging Face hub. (optional, default: None)",
  "ranking": "whether the dataset is a preference dataset or not. (default: false)",
  "formatting": "the format of the dataset. (optional, default: alpaca, can be chosen from {alpaca, sharegpt})",
  "columns": {
    "prompt": "the column name in the dataset containing the prompts. (default: instruction, for alpaca)",
    "query": "the column name in the dataset containing the queries. (default: input, for alpaca)",
    "response": "the column name in the dataset containing the responses. (default: output, for alpaca)",
    "history": "the column name in the dataset containing the histories. (default: None, for alpaca)",
    "messages": "the column name in the dataset containing the messages. (default: conversations, for sharegpt)",
    "role": "the key in the message representing the identity. (default: from, for sharegpt)",
    "content": "the key in the message representing the content. (default: value, for sharegpt)",
    "system": "the column name in the dataset containing the system prompts. (default: None, for both)"
  "columns (optional)": {
    "prompt": "the column name in the dataset containing the prompts. (default: instruction)",
    "query": "the column name in the dataset containing the queries. (default: input)",
    "response": "the column name in the dataset containing the responses. (default: output)",
    "history": "the column name in the dataset containing the histories. (default: None)",
    "messages": "the column name in the dataset containing the messages. (default: conversations)",
    "system": "the column name in the dataset containing the system prompts. (default: None)",
    "tools": "the column name in the dataset containing the tool description. (default: None)"
  },
  "tags (optional, used for the sharegpt format)": {
    "role_tag": "the key in the message representing the identity. (default: from)",
    "content_tag": "the key in the message representing the content. (default: value)",
    "user_tag": "the value of the role_tag representing the user. (default: human)",
    "assistant_tag": "the value of the role_tag representing the assistant. (default: gpt)",
    "observation_tag": "the value of the role_tag representing the tool results. (default: observation)",
    "function_tag": "the value of the role_tag representing the function call. (default: function_call)",
    "system_tag": "the value of the role_tag representing the system prompt. (default: system, can override the system column)"
  }
}
```
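For example, a local alpaca-format file could be registered with an entry like the one below; the dataset name and file name are placeholders, and the `file_name` key is assumed from the part of the schema not shown in this hunk.

```bash
# Illustrative entry to append to data/dataset_info.json for a local
# alpaca-format dataset (shown via a heredoc; normally you edit the JSON by hand).
cat <<'EOF'
"my_dataset": {
  "file_name": "my_dataset.json",
  "formatting": "alpaca",
  "columns": {
    "prompt": "instruction",
    "query": "input",
    "response": "output"
  }
}
EOF
```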
@@ -57,9 +65,9 @@ Regarding the above dataset, the `columns` in `dataset_info.json` should be:

}
```

where the `prompt` and `response` columns should contain non-empty values, representing the instruction and the response respectively. The `query` column will be concatenated with the `prompt` column and used as input for the model.
The `query` column will be concatenated with the `prompt` column and used as the user prompt; the user prompt will then be `prompt\nquery`. The `response` column represents the model response.

The `system` column will be used as the system prompt in the template. The `history` column is a list consisting of string tuples representing query-response pairs in the history. Note that the responses **in each round will be used for training**.
The `system` column will be used as the system prompt. The `history` column is a list consisting of string tuples representing prompt-response pairs in the history. Note that the responses in the history **will also be used for training**.

For the pre-training datasets, only the `prompt` column will be used for training.
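As a sketch of the prompt assembly described above, consider an invented alpaca-format record; the values below are purely illustrative and do not come from any repository file.

```bash
# Given an alpaca-format record:
#   instruction (prompt): "Summarize the following text."
#   input (query):        "LLaMA Factory supports LoRA and QLoRA fine-tuning."
#   output (response):    used as the training target.
# The user prompt handed to the template is "prompt\nquery", i.e.:
printf 'Summarize the following text.\nLLaMA Factory supports LoRA and QLoRA fine-tuning.\n'
```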
@@ -91,7 +99,8 @@ The dataset in sharegpt format should follow the below format:

      "value": "model response"
    }
  ],
  "system": "system prompt (optional)"
  "system": "system prompt (optional)",
  "tools": "tool description (optional)"
}
]
```

@@ -102,13 +111,18 @@ Regarding the above dataset, the `columns` in `dataset_info.json` should be:

"dataset_name": {
  "columns": {
    "messages": "conversations",
    "role": "from",
    "content": "value",
    "system": "system"
    "system": "system",
    "tools": "tools"
  },
  "tags": {
    "role_tag": "from",
    "content_tag": "value",
    "user_tag": "human",
    "assistant_tag": "gpt"
  }
}
```

where the `messages` column should be a list whose length is even and which follows the `u/a/u/a/u/a` order.
where the `messages` column should be a list following the `u/a/u/a/u/a` order.

Pre-training datasets and preference datasets are not yet compatible with the sharegpt format.
@@ -11,15 +11,23 @@
|
||||
"folder": "Hugging Face 仓库的文件夹名称(可选,默认:None)",
|
||||
"ranking": "是否为偏好数据集(可选,默认:False)",
|
||||
"formatting": "数据集格式(可选,默认:alpaca,可以为 alpaca 或 sharegpt)",
|
||||
"columns": {
|
||||
"prompt": "数据集代表提示词的表头名称(默认:instruction,用于 alpaca 格式)",
|
||||
"query": "数据集代表请求的表头名称(默认:input,用于 alpaca 格式)",
|
||||
"response": "数据集代表回答的表头名称(默认:output,用于 alpaca 格式)",
|
||||
"history": "数据集代表历史对话的表头名称(默认:None,用于 alpaca 格式)",
|
||||
"messages": "数据集代表消息列表的表头名称(默认:conversations,用于 sharegpt 格式)",
|
||||
"role": "消息中代表发送者身份的键名(默认:from,用于 sharegpt 格式)",
|
||||
"content": "消息中代表文本内容的键名(默认:value,用于 sharegpt 格式)",
|
||||
"system": "数据集代表系统提示的表头名称(默认:None,用于两种格式)"
|
||||
"columns(可选)": {
|
||||
"prompt": "数据集代表提示词的表头名称(默认:instruction)",
|
||||
"query": "数据集代表请求的表头名称(默认:input)",
|
||||
"response": "数据集代表回答的表头名称(默认:output)",
|
||||
"history": "数据集代表历史对话的表头名称(默认:None)",
|
||||
"messages": "数据集代表消息列表的表头名称(默认:conversations)",
|
||||
"system": "数据集代表系统提示的表头名称(默认:None)",
|
||||
"tools": "数据集代表工具描述的表头名称(默认:None)"
|
||||
},
|
||||
"tags(可选,用于 sharegpt 格式)": {
|
||||
"role_tag": "消息中代表发送者身份的键名(默认:from)",
|
||||
"content_tag": "消息中代表文本内容的键名(默认:value)",
|
||||
"user_tag": "消息中代表用户的 role_tag(默认:human)",
|
||||
"assistant_tag": "消息中代表助手的 role_tag(默认:gpt)",
|
||||
"observation_tag": "消息中代表工具返回结果的 role_tag(默认:observation)",
|
||||
"function_tag": "消息中代表工具调用的 role_tag(默认:function_call)",
|
||||
"system_tag": "消息中代表系统提示的 role_tag(默认:system,会覆盖 system 列)"
|
||||
}
|
||||
}
|
||||
```
|
||||
@@ -57,9 +65,9 @@
|
||||
}
|
||||
```
|
||||
|
||||
其中 `prompt` 和 `response` 列应当是非空的字符串,分别代表用户指令和模型回答。`query` 列的内容将会和 `prompt` 列拼接作为模型输入。
|
||||
其中 `query` 列对应的内容会与 `prompt` 列对应的内容拼接后作为用户指令,即用户指令为 `prompt\nquery`。`response` 列对应的内容为模型回答。
|
||||
|
||||
`system` 为模板中的系统提示词。`history` 列是由多个字符串二元组构成的列表,分别代表历史消息中每轮的指令和回答。注意每轮的模型回答**均会被用于训练**。
|
||||
`system` 列对应的内容将被作为系统提示词。`history` 列是由多个字符串二元组构成的列表,分别代表历史消息中每轮的指令和回答。注意历史消息中的回答**也会被用于训练**。
|
||||
|
||||
对于预训练数据集,仅 `prompt` 列中的内容会用于模型训练。
|
||||
|
||||
@@ -91,7 +99,8 @@
|
||||
"value": "模型回答"
|
||||
}
|
||||
],
|
||||
"system": "系统提示词(选填)"
|
||||
"system": "系统提示词(选填)",
|
||||
"tools": "工具描述(选填)"
|
||||
}
|
||||
]
|
||||
```
|
||||
@@ -102,13 +111,18 @@
|
||||
"数据集名称": {
|
||||
"columns": {
|
||||
"messages": "conversations",
|
||||
"role": "from",
|
||||
"content": "value",
|
||||
"system": "system"
|
||||
"system": "system",
|
||||
"tools": "tools"
|
||||
},
|
||||
"tags": {
|
||||
"role_tag": "from",
|
||||
"content_tag": "value",
|
||||
"user_tag": "human",
|
||||
"assistant_tag": "gpt"
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
其中 `messages` 列必须为偶数长度的列表,且符合 `用户/模型/用户/模型/用户/模型` 的顺序。
|
||||
其中 `messages` 列应当是一个列表,且符合 `用户/模型/用户/模型/用户/模型` 的顺序。
|
||||
|
||||
预训练数据集和偏好数据集尚不支持 sharegpt 格式。
|
||||
|
||||
@@ -1 +1 @@

fc9a6a3458caca2af8dafc6181773fe10c6d8657
34c723573fbc2d7601f6d9c882ccf5aa4f9bcc4b
29  examples/full_multi_gpu/sft.sh  (new file)
@@ -0,0 +1,29 @@
#!/bin/bash

deepspeed --num_gpus 4 ../../src/train_bash.py \
    --deepspeed ds_z3_config.json \
    --stage sft \
    --do_train \
    --model_name_or_path meta-llama/Llama-2-7b-hf \
    --dataset alpaca_gpt4_en \
    --dataset_dir ../../data \
    --template default \
    --finetuning_type full \
    --output_dir ../../saves/LLaMA2-7B/full/sft \
    --overwrite_cache \
    --overwrite_output_dir \
    --cutoff_len 1024 \
    --per_device_train_batch_size 1 \
    --per_device_eval_batch_size 1 \
    --gradient_accumulation_steps 2 \
    --lr_scheduler_type cosine \
    --logging_steps 10 \
    --save_steps 100 \
    --eval_steps 100 \
    --evaluation_strategy steps \
    --learning_rate 5e-5 \
    --num_train_epochs 3.0 \
    --max_samples 3000 \
    --val_size 0.1 \
    --plot_loss \
    --fp16
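The script above references `ds_z3_config.json`, which is not part of this diff. Below is a minimal ZeRO stage-3 configuration sketch of what such a file typically contains; every key and value here is an assumption about a reasonable DeepSpeed setup, not the project's shipped config.

```bash
# Hypothetical ds_z3_config.json for the deepspeed launch above (ZeRO stage 3,
# with batch sizes and fp16 deferred to the HF Trainer via "auto").
cat > ds_z3_config.json <<'EOF'
{
  "train_batch_size": "auto",
  "train_micro_batch_size_per_gpu": "auto",
  "gradient_accumulation_steps": "auto",
  "gradient_clipping": "auto",
  "zero_allow_untested_optimizer": true,
  "fp16": {
    "enabled": "auto"
  },
  "zero_optimization": {
    "stage": 3,
    "overlap_comm": true,
    "contiguous_gradients": true,
    "stage3_gather_16bit_weights_on_model_save": true
  }
}
EOF
```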
16  examples/lora_multi_gpu/config.yaml  (new file)
@@ -0,0 +1,16 @@
compute_environment: LOCAL_MACHINE
debug: false
distributed_type: MULTI_GPU
downcast_bf16: 'no'
gpu_ids: all
machine_rank: 0
main_training_function: main
mixed_precision: fp16
num_machines: 1
num_processes: 4
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false
30  examples/lora_multi_gpu/sft.sh  (new file)
@@ -0,0 +1,30 @@
#!/bin/bash

CUDA_VISIBLE_DEVICES=0,1,2,3 accelerate launch --config_file config.yaml ../../src/train_bash.py \
    --stage sft \
    --do_train \
    --model_name_or_path meta-llama/Llama-2-7b-hf \
    --dataset alpaca_gpt4_en,glaive_toolcall \
    --dataset_dir ../../data \
    --template default \
    --finetuning_type lora \
    --lora_target q_proj,v_proj \
    --output_dir ../../saves/LLaMA2-7B/lora/sft \
    --overwrite_cache \
    --overwrite_output_dir \
    --cutoff_len 1024 \
    --per_device_train_batch_size 1 \
    --per_device_eval_batch_size 1 \
    --gradient_accumulation_steps 2 \
    --lr_scheduler_type cosine \
    --logging_steps 10 \
    --save_steps 100 \
    --eval_steps 100 \
    --evaluation_strategy steps \
    --load_best_model_at_end \
    --learning_rate 5e-5 \
    --num_train_epochs 3.0 \
    --max_samples 3000 \
    --val_size 0.1 \
    --plot_loss \
    --fp16
examples/lora_single_gpu/dpo.sh (new file, 33 lines)
@@ -0,0 +1,33 @@
#!/bin/bash

CUDA_VISIBLE_DEVICES=0 python ../../src/train_bash.py \
    --stage dpo \
    --do_train \
    --model_name_or_path meta-llama/Llama-2-7b-hf \
    --adapter_name_or_path ../../saves/LLaMA2-7B/lora/sft \
    --create_new_adapter \
    --dataset comparison_gpt4_en \
    --dataset_dir ../../data \
    --template default \
    --finetuning_type lora \
    --lora_target q_proj,v_proj \
    --output_dir ../../saves/LLaMA2-7B/lora/dpo \
    --overwrite_cache \
    --overwrite_output_dir \
    --cutoff_len 1024 \
    --per_device_train_batch_size 1 \
    --per_device_eval_batch_size 1 \
    --gradient_accumulation_steps 8 \
    --lr_scheduler_type cosine \
    --logging_steps 10 \
    --save_steps 100 \
    --eval_steps 100 \
    --evaluation_strategy steps \
    --load_best_model_at_end \
    --learning_rate 1e-5 \
    --num_train_epochs 1.0 \
    --max_samples 1000 \
    --val_size 0.1 \
    --dpo_ftx 1.0 \
    --plot_loss \
    --fp16
examples/lora_single_gpu/ppo.sh (new file, 31 lines)
@@ -0,0 +1,31 @@
#!/bin/bash

CUDA_VISIBLE_DEVICES=0 python ../../src/train_bash.py \
    --stage ppo \
    --do_train \
    --model_name_or_path meta-llama/Llama-2-7b-hf \
    --adapter_name_or_path ../../saves/LLaMA2-7B/lora/sft \
    --create_new_adapter \
    --dataset alpaca_gpt4_en \
    --dataset_dir ../../data \
    --template default \
    --finetuning_type lora \
    --lora_target q_proj,v_proj \
    --reward_model ../../saves/LLaMA2-7B/lora/reward \
    --output_dir ../../saves/LLaMA2-7B/lora/ppo \
    --overwrite_cache \
    --overwrite_output_dir \
    --cutoff_len 512 \
    --per_device_train_batch_size 1 \
    --gradient_accumulation_steps 8 \
    --lr_scheduler_type cosine \
    --logging_steps 10 \
    --save_steps 100 \
    --learning_rate 1e-5 \
    --num_train_epochs 1.0 \
    --max_samples 1000 \
    --top_k 0 \
    --top_p 0.9 \
    --max_new_tokens 256 \
    --plot_loss \
    --fp16
examples/lora_single_gpu/predict.sh (new file, 18 lines)
@@ -0,0 +1,18 @@
#!/bin/bash

CUDA_VISIBLE_DEVICES=0 python ../../src/train_bash.py \
    --stage sft \
    --do_predict \
    --model_name_or_path meta-llama/Llama-2-7b-hf \
    --adapter_name_or_path ../../saves/LLaMA2-7B/lora/sft,../../saves/LLaMA2-7B/lora/dpo \
    --dataset alpaca_gpt4_en,glaive_toolcall \
    --dataset_dir ../../data \
    --template default \
    --finetuning_type lora \
    --output_dir ../../saves/LLaMA2-7B/lora/predict \
    --overwrite_cache \
    --overwrite_output_dir \
    --cutoff_len 1024 \
    --per_device_eval_batch_size 1 \
    --max_samples 20 \
    --predict_with_generate
examples/lora_single_gpu/pretrain.sh (new file, 29 lines)
@@ -0,0 +1,29 @@
#!/bin/bash

CUDA_VISIBLE_DEVICES=0 python ../../src/train_bash.py \
    --stage pt \
    --do_train \
    --model_name_or_path meta-llama/Llama-2-7b-hf \
    --dataset c4_demo \
    --dataset_dir ../../data \
    --finetuning_type lora \
    --lora_target q_proj,v_proj \
    --output_dir ../../saves/LLaMA2-7B/lora/pretrain \
    --overwrite_cache \
    --overwrite_output_dir \
    --cutoff_len 1024 \
    --per_device_train_batch_size 1 \
    --per_device_eval_batch_size 1 \
    --gradient_accumulation_steps 8 \
    --lr_scheduler_type cosine \
    --logging_steps 10 \
    --save_steps 100 \
    --eval_steps 100 \
    --evaluation_strategy steps \
    --load_best_model_at_end \
    --learning_rate 5e-5 \
    --num_train_epochs 3.0 \
    --max_samples 10000 \
    --val_size 0.1 \
    --plot_loss \
    --fp16
examples/lora_single_gpu/reward.sh (new file, 31 lines)
@@ -0,0 +1,31 @@
#!/bin/bash

CUDA_VISIBLE_DEVICES=0 python ../../src/train_bash.py \
    --stage rm \
    --do_train \
    --model_name_or_path meta-llama/Llama-2-7b-hf \
    --adapter_name_or_path ../../saves/LLaMA2-7B/lora/sft \
    --create_new_adapter \
    --dataset comparison_gpt4_en \
    --dataset_dir ../../data \
    --template default \
    --finetuning_type lora \
    --lora_target q_proj,v_proj \
    --output_dir ../../saves/LLaMA2-7B/lora/reward \
    --overwrite_cache \
    --overwrite_output_dir \
    --cutoff_len 1024 \
    --per_device_train_batch_size 1 \
    --per_device_eval_batch_size 1 \
    --gradient_accumulation_steps 8 \
    --lr_scheduler_type cosine \
    --logging_steps 10 \
    --save_steps 100 \
    --eval_steps 100 \
    --evaluation_strategy steps \
    --learning_rate 1e-5 \
    --num_train_epochs 1.0 \
    --max_samples 5000 \
    --val_size 0.1 \
    --plot_loss \
    --fp16
examples/lora_single_gpu/sft.sh (new file, 30 lines)
@@ -0,0 +1,30 @@
#!/bin/bash

CUDA_VISIBLE_DEVICES=0 python ../../src/train_bash.py \
    --stage sft \
    --do_train \
    --model_name_or_path meta-llama/Llama-2-7b-hf \
    --dataset alpaca_gpt4_en,glaive_toolcall \
    --dataset_dir ../../data \
    --template default \
    --finetuning_type lora \
    --lora_target q_proj,v_proj \
    --output_dir ../../saves/LLaMA2-7B/lora/sft \
    --overwrite_cache \
    --overwrite_output_dir \
    --cutoff_len 1024 \
    --per_device_train_batch_size 1 \
    --per_device_eval_batch_size 1 \
    --gradient_accumulation_steps 8 \
    --lr_scheduler_type cosine \
    --logging_steps 10 \
    --save_steps 100 \
    --eval_steps 100 \
    --evaluation_strategy steps \
    --load_best_model_at_end \
    --learning_rate 5e-5 \
    --num_train_epochs 3.0 \
    --max_samples 3000 \
    --val_size 0.1 \
    --plot_loss \
    --fp16
examples/qlora_single_gpu/aqlm.sh (new file, 30 lines)
@@ -0,0 +1,30 @@
#!/bin/bash

CUDA_VISIBLE_DEVICES=0 python ../../src/train_bash.py \
    --stage sft \
    --do_train \
    --model_name_or_path BlackSamorez/Llama-2-7b-AQLM-2Bit-1x16-hf \
    --dataset alpaca_gpt4_en,glaive_toolcall \
    --dataset_dir ../../data \
    --template default \
    --finetuning_type lora \
    --lora_target q_proj,v_proj \
    --output_dir ../../saves/LLaMA2-7B/lora/sft \
    --overwrite_cache \
    --overwrite_output_dir \
    --cutoff_len 1024 \
    --per_device_train_batch_size 1 \
    --per_device_eval_batch_size 1 \
    --gradient_accumulation_steps 8 \
    --lr_scheduler_type cosine \
    --logging_steps 10 \
    --save_steps 100 \
    --eval_steps 100 \
    --evaluation_strategy steps \
    --load_best_model_at_end \
    --learning_rate 5e-5 \
    --num_train_epochs 3.0 \
    --max_samples 3000 \
    --val_size 0.1 \
    --plot_loss \
    --fp16
examples/qlora_single_gpu/awq.sh (new file, 30 lines)
@@ -0,0 +1,30 @@
#!/bin/bash

CUDA_VISIBLE_DEVICES=0 python ../../src/train_bash.py \
    --stage sft \
    --do_train \
    --model_name_or_path TheBloke/Llama-2-7B-AWQ \
    --dataset alpaca_gpt4_en,glaive_toolcall \
    --dataset_dir ../../data \
    --template default \
    --finetuning_type lora \
    --lora_target q_proj,v_proj \
    --output_dir ../../saves/LLaMA2-7B/lora/sft \
    --overwrite_cache \
    --overwrite_output_dir \
    --cutoff_len 1024 \
    --per_device_train_batch_size 1 \
    --per_device_eval_batch_size 1 \
    --gradient_accumulation_steps 8 \
    --lr_scheduler_type cosine \
    --logging_steps 10 \
    --save_steps 100 \
    --eval_steps 100 \
    --evaluation_strategy steps \
    --load_best_model_at_end \
    --learning_rate 5e-5 \
    --num_train_epochs 3.0 \
    --max_samples 3000 \
    --val_size 0.1 \
    --plot_loss \
    --fp16
examples/qlora_single_gpu/bitsandbytes.sh (new file, 31 lines)
@@ -0,0 +1,31 @@
#!/bin/bash

CUDA_VISIBLE_DEVICES=0 python ../../src/train_bash.py \
    --stage sft \
    --do_train \
    --model_name_or_path meta-llama/Llama-2-7b-hf \
    --dataset alpaca_gpt4_en,glaive_toolcall \
    --dataset_dir ../../data \
    --template default \
    --finetuning_type lora \
    --lora_target q_proj,v_proj \
    --output_dir ../../saves/LLaMA2-7B/lora/sft \
    --overwrite_cache \
    --overwrite_output_dir \
    --cutoff_len 1024 \
    --per_device_train_batch_size 1 \
    --per_device_eval_batch_size 1 \
    --gradient_accumulation_steps 8 \
    --lr_scheduler_type cosine \
    --logging_steps 10 \
    --save_steps 100 \
    --eval_steps 100 \
    --evaluation_strategy steps \
    --load_best_model_at_end \
    --learning_rate 5e-5 \
    --num_train_epochs 3.0 \
    --max_samples 3000 \
    --val_size 0.1 \
    --quantization_bit 4 \
    --plot_loss \
    --fp16
examples/qlora_single_gpu/gptq.sh (new file, 30 lines)
@@ -0,0 +1,30 @@
#!/bin/bash

CUDA_VISIBLE_DEVICES=0 python ../../src/train_bash.py \
    --stage sft \
    --do_train \
    --model_name_or_path TheBloke/Llama-2-7B-GPTQ \
    --dataset alpaca_gpt4_en,glaive_toolcall \
    --dataset_dir ../../data \
    --template default \
    --finetuning_type lora \
    --lora_target q_proj,v_proj \
    --output_dir ../../saves/LLaMA2-7B/lora/sft \
    --overwrite_cache \
    --overwrite_output_dir \
    --cutoff_len 1024 \
    --per_device_train_batch_size 1 \
    --per_device_eval_batch_size 1 \
    --gradient_accumulation_steps 8 \
    --lr_scheduler_type cosine \
    --logging_steps 10 \
    --save_steps 100 \
    --eval_steps 100 \
    --evaluation_strategy steps \
    --load_best_model_at_end \
    --learning_rate 5e-5 \
    --num_train_epochs 3.0 \
    --max_samples 3000 \
    --val_size 0.1 \
    --plot_loss \
    --fp16
@@ -2,23 +2,19 @@
requires = ["setuptools>=61.0"]
build-backend = "setuptools.build_meta"

[tool.black]
line-length = 119
target-version = ["py38"]

[tool.ruff]
target-version = "py38"
line-length = 119
indent-width = 4

[tool.ruff.lint]
ignore = ["C408", "C901", "E501", "E731", "E741", "W605"]
select = ["C", "E", "F", "I", "W"]
line-length = 119

[tool.ruff.isort]
[tool.ruff.lint.isort]
lines-after-imports = 2
known-first-party = ["llmtuner"]

[isort]
default_section = "FIRSTPARTY"
known_first_party = "llmtuner"
known_third_party = [
known-third-party = [
    "accelerate",
    "datasets",
    "gradio",
@@ -28,10 +24,9 @@ known_third_party = [
    "transformers",
    "trl"
]
line_length = 119
lines_after_imports = 2
multi_line_output = 3
include_trailing_comma = true
force_grid_wrap = 0
use_parentheses = true
ensure_newline_before_comments = true

[tool.ruff.format]
quote-style = "double"
indent-style = "space"
skip-magic-trailing-comma = false
line-ending = "auto"
@@ -1,9 +1,9 @@
torch>=1.13.1
transformers>=4.36.2
transformers>=4.37.2
datasets>=2.14.3
accelerate>=0.21.0
peft>=0.7.0
trl>=0.7.6
accelerate>=0.27.2
peft>=0.9.0
trl>=0.7.11
gradio>=3.38.0,<4.0.0
scipy
einops
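The dependency floor is raised here: transformers, accelerate, peft and trl all move to newer minimum versions. An optional sanity check of an existing environment against these minimums could look like the sketch below (not part of the repository; it assumes the `packaging` package is available, as it usually is alongside pip):

```python
# Verify that the raised minimum versions from requirements.txt are satisfied.
from importlib.metadata import version
from packaging.version import Version

minimums = {"transformers": "4.37.2", "accelerate": "0.27.2", "peft": "0.9.0", "trl": "0.7.11"}
for pkg, minimum in minimums.items():
    installed = Version(version(pkg))
    assert installed >= Version(minimum), f"{pkg} {installed} is older than required {minimum}"
    print(f"{pkg} {installed} >= {minimum}")
```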
@@ -7,5 +7,5 @@ from .train import export_model, run_exp
from .webui import create_ui, create_web_demo


__version__ = "0.5.0"
__version__ = "0.5.3"
__all__ = ["create_app", "ChatModel", "Evaluator", "export_model", "run_exp", "create_ui", "create_web_demo"]
@@ -74,6 +74,13 @@ def create_app(chat_model: "ChatModel") -> "FastAPI":
|
||||
)
|
||||
|
||||
semaphore = asyncio.Semaphore(int(os.environ.get("MAX_CONCURRENT", 1)))
|
||||
role_mapping = {
|
||||
Role.USER: DataRole.USER.value,
|
||||
Role.ASSISTANT: DataRole.ASSISTANT.value,
|
||||
Role.SYSTEM: DataRole.SYSTEM.value,
|
||||
Role.FUNCTION: DataRole.FUNCTION.value,
|
||||
Role.TOOL: DataRole.OBSERVATION.value,
|
||||
}
|
||||
|
||||
@app.get("/v1/models", response_model=ModelList)
|
||||
async def list_models():
|
||||
@@ -85,30 +92,30 @@ def create_app(chat_model: "ChatModel") -> "FastAPI":
|
||||
if not chat_model.can_generate:
|
||||
raise HTTPException(status_code=status.HTTP_405_METHOD_NOT_ALLOWED, detail="Not allowed")
|
||||
|
||||
if len(request.messages) == 0 or request.messages[-1].role != Role.USER:
|
||||
if len(request.messages) == 0:
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid length")
|
||||
|
||||
messages = [dictify(message) for message in request.messages]
|
||||
if len(messages) and messages[0]["role"] == Role.SYSTEM:
|
||||
system = messages.pop(0)["content"]
|
||||
if request.messages[0].role == Role.SYSTEM:
|
||||
system = request.messages.pop(0).content
|
||||
else:
|
||||
system = None
|
||||
system = ""
|
||||
|
||||
if len(messages) % 2 == 0:
|
||||
if len(request.messages) % 2 == 0:
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Only supports u/a/u/a/u...")
|
||||
|
||||
for i in range(len(messages)):
|
||||
if i % 2 == 0 and messages[i]["role"] not in [Role.USER, Role.TOOL]:
|
||||
input_messages = []
|
||||
for i, message in enumerate(request.messages):
|
||||
if i % 2 == 0 and message.role not in [Role.USER, Role.TOOL]:
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid role")
|
||||
elif i % 2 == 1 and messages[i]["role"] not in [Role.ASSISTANT, Role.FUNCTION]:
|
||||
elif i % 2 == 1 and message.role not in [Role.ASSISTANT, Role.FUNCTION]:
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid role")
|
||||
elif messages[i]["role"] == Role.TOOL:
|
||||
messages[i]["role"] = DataRole.OBSERVATION
|
||||
|
||||
input_messages.append({"role": role_mapping[message.role], "content": message.content})
|
||||
|
||||
tool_list = request.tools
|
||||
if len(tool_list):
|
||||
if isinstance(tool_list, list) and len(tool_list):
|
||||
try:
|
||||
tools = json.dumps([tool_list[0]["function"]], ensure_ascii=False)
|
||||
tools = json.dumps([tool["function"] for tool in tool_list], ensure_ascii=False)
|
||||
except Exception:
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid tools")
|
||||
else:
|
||||
@@ -116,10 +123,13 @@ def create_app(chat_model: "ChatModel") -> "FastAPI":
|
||||
|
||||
async with semaphore:
|
||||
loop = asyncio.get_running_loop()
|
||||
return await loop.run_in_executor(None, chat_completion, messages, system, tools, request)
|
||||
return await loop.run_in_executor(None, chat_completion, input_messages, system, tools, request)
|
||||
|
||||
def chat_completion(messages: Sequence[Dict[str, str]], system: str, tools: str, request: ChatCompletionRequest):
|
||||
if request.stream:
|
||||
if tools:
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Cannot stream function calls.")
|
||||
|
||||
generate = stream_chat_completion(messages, system, tools, request)
|
||||
return EventSourceResponse(generate, media_type="text/event-stream")
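The handler above exposes an OpenAI-style chat completion route through a role mapping and a concurrency semaphore (`MAX_CONCURRENT`), and it rejects streaming requests that carry tools. A minimal client-side sketch against such a server follows; the base URL, port and model name are placeholders, and the `/v1/chat/completions` path is assumed from the OpenAI-compatible convention rather than shown in this hunk:

```python
# Hypothetical request to the OpenAI-compatible endpoint served by create_app().
import requests

payload = {
    "model": "llama2-7b-sft",  # placeholder model name
    "messages": [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "Hello!"},
    ],
    "stream": False,  # streaming with tools is rejected by the handler above
}
resp = requests.post("http://localhost:8000/v1/chat/completions", json=payload, timeout=60)
resp.raise_for_status()
print(resp.json()["choices"][0]["message"]["content"])
```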
|
||||
|
||||
|
||||
@@ -28,7 +28,7 @@ class ChatModel:
|
||||
)
|
||||
self.tokenizer.padding_side = "left" if self.can_generate else "right"
|
||||
self.model = dispatch_model(self.model)
|
||||
self.template = get_template_and_fix_tokenizer(data_args.template, self.tokenizer)
|
||||
self.template = get_template_and_fix_tokenizer(self.tokenizer, data_args.template)
|
||||
|
||||
def _process_args(
|
||||
self,
|
||||
@@ -94,6 +94,9 @@ class ChatModel:
|
||||
tools: Optional[str] = None,
|
||||
**input_kwargs,
|
||||
) -> List[Response]:
|
||||
if not self.can_generate:
|
||||
raise ValueError("The current model does not support `chat`.")
|
||||
|
||||
gen_kwargs, prompt_length = self._process_args(messages, system, tools, **input_kwargs)
|
||||
generate_output = self.model.generate(**gen_kwargs)
|
||||
response_ids = generate_output[:, prompt_length:]
|
||||
@@ -123,6 +126,9 @@ class ChatModel:
|
||||
tools: Optional[str] = None,
|
||||
**input_kwargs,
|
||||
) -> Generator[str, None, None]:
|
||||
if not self.can_generate:
|
||||
raise ValueError("The current model does not support `stream_chat`.")
|
||||
|
||||
gen_kwargs, _ = self._process_args(messages, system, tools, **input_kwargs)
|
||||
streamer = TextIteratorStreamer(self.tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True)
|
||||
gen_kwargs["streamer"] = streamer
|
||||
@@ -134,9 +140,11 @@ class ChatModel:
|
||||
|
||||
@torch.inference_mode()
|
||||
def get_scores(self, batch_input: List[str], **input_kwargs) -> List[float]:
|
||||
if self.can_generate:
|
||||
raise ValueError("Cannot get scores using an auto-regressive model.")
|
||||
|
||||
max_length = input_kwargs.pop("max_length", None)
|
||||
device = getattr(self.model.pretrained_model, "device", "cuda")
|
||||
|
||||
inputs = self.tokenizer(
|
||||
batch_input,
|
||||
padding=True,
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
from functools import partial
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Union
|
||||
|
||||
from datasets import Features
|
||||
|
||||
from .utils import Role
|
||||
|
||||
|
||||
@@ -15,23 +17,26 @@ def convert_alpaca(examples: Dict[str, List[Any]], dataset_attr: "DatasetAttr")
|
||||
outputs = {"prompt": [], "response": [], "system": [], "tools": []}
|
||||
for i in range(len(examples[dataset_attr.prompt])):
|
||||
prompt = []
|
||||
if dataset_attr.history:
|
||||
if dataset_attr.history and isinstance(examples[dataset_attr.history][i], list):
|
||||
for old_prompt, old_response in examples[dataset_attr.history][i]:
|
||||
prompt.append({"role": Role.USER, "content": old_prompt})
|
||||
prompt.append({"role": Role.ASSISTANT, "content": old_response})
|
||||
prompt.append({"role": Role.USER.value, "content": old_prompt})
|
||||
prompt.append({"role": Role.ASSISTANT.value, "content": old_response})
|
||||
|
||||
content = []
|
||||
if dataset_attr.prompt and examples[dataset_attr.prompt][i]:
|
||||
content.append(examples[dataset_attr.prompt][i])
|
||||
|
||||
instruction = examples[dataset_attr.prompt][i]
|
||||
if dataset_attr.query and examples[dataset_attr.query][i]:
|
||||
instruction += "\n" + examples[dataset_attr.query][i]
|
||||
prompt.append({"role": Role.USER, "content": instruction})
|
||||
content.append(examples[dataset_attr.query][i])
|
||||
|
||||
if dataset_attr.response:
|
||||
if isinstance(examples[dataset_attr.response][i], list):
|
||||
response = [
|
||||
{"role": Role.ASSISTANT, "content": content} for content in examples[dataset_attr.response][i]
|
||||
]
|
||||
else:
|
||||
response = [{"role": Role.ASSISTANT, "content": examples[dataset_attr.response][i]}]
|
||||
prompt.append({"role": Role.USER.value, "content": "\n".join(content)})
|
||||
|
||||
if dataset_attr.response and isinstance(examples[dataset_attr.response][i], list):
|
||||
response = [
|
||||
{"role": Role.ASSISTANT.value, "content": content} for content in examples[dataset_attr.response][i]
|
||||
]
|
||||
elif dataset_attr.response and isinstance(examples[dataset_attr.response][i], str):
|
||||
response = [{"role": Role.ASSISTANT.value, "content": examples[dataset_attr.response][i]}]
|
||||
else:
|
||||
response = []
|
||||
|
||||
@@ -46,36 +51,38 @@ def convert_alpaca(examples: Dict[str, List[Any]], dataset_attr: "DatasetAttr")
|
||||
def convert_sharegpt(examples: Dict[str, List[Any]], dataset_attr: "DatasetAttr") -> Dict[str, List[Any]]:
|
||||
outputs = {"prompt": [], "response": [], "system": [], "tools": []}
|
||||
tag_mapping = {
|
||||
dataset_attr.user_tag: Role.USER,
|
||||
dataset_attr.assistant_tag: Role.ASSISTANT,
|
||||
dataset_attr.observation_tag: Role.OBSERVATION,
|
||||
dataset_attr.function_tag: Role.FUNCTION,
|
||||
dataset_attr.user_tag: Role.USER.value,
|
||||
dataset_attr.assistant_tag: Role.ASSISTANT.value,
|
||||
dataset_attr.observation_tag: Role.OBSERVATION.value,
|
||||
dataset_attr.function_tag: Role.FUNCTION.value,
|
||||
dataset_attr.system_tag: Role.SYSTEM.value,
|
||||
}
|
||||
odd_tags = (dataset_attr.user_tag, dataset_attr.observation_tag)
|
||||
even_tags = (dataset_attr.assistant_tag, dataset_attr.function_tag)
|
||||
accept_tags = (odd_tags, even_tags)
|
||||
for i, messages in enumerate(examples[dataset_attr.messages]):
|
||||
if dataset_attr.system_tag and messages[0][dataset_attr.role_tag] == dataset_attr.system_tag:
|
||||
system = messages[0][dataset_attr.content_tag]
|
||||
messages = messages[1:]
|
||||
else:
|
||||
system = examples[dataset_attr.system][i] if dataset_attr.system else ""
|
||||
|
||||
messages = messages[: len(messages) // 2 * 2] # should be multiples of 2
|
||||
if len(messages) == 0:
|
||||
continue
|
||||
|
||||
prompt = []
|
||||
response = []
|
||||
aligned_messages = []
|
||||
for turn_idx, message in enumerate(messages):
|
||||
if turn_idx % 2 == 0:
|
||||
accept_tags = [dataset_attr.user_tag, dataset_attr.observation_tag]
|
||||
else:
|
||||
accept_tags = [dataset_attr.assistant_tag, dataset_attr.function_tag]
|
||||
|
||||
if message[dataset_attr.role_tag] not in accept_tags:
|
||||
if message[dataset_attr.role_tag] not in accept_tags[turn_idx % 2]:
|
||||
raise ValueError("Invalid role tag in {}.".format(messages))
|
||||
|
||||
prompt.append(
|
||||
aligned_messages.append(
|
||||
{"role": tag_mapping[message[dataset_attr.role_tag]], "content": message[dataset_attr.content_tag]}
|
||||
)
|
||||
|
||||
last_message = prompt.pop(-1)
|
||||
response.append(last_message)
|
||||
outputs["prompt"].append(prompt)
|
||||
outputs["response"].append(response)
|
||||
outputs["system"].append(examples[dataset_attr.system][i] if dataset_attr.system else "")
|
||||
outputs["prompt"].append(aligned_messages[:-1])
|
||||
outputs["response"].append(aligned_messages[-1:])
|
||||
outputs["system"].append(system)
|
||||
outputs["tools"].append(examples[dataset_attr.tools][i] if dataset_attr.tools else "")
|
||||
|
||||
return outputs
|
||||
@@ -86,8 +93,8 @@ def align_dataset(
|
||||
) -> Union["Dataset", "IterableDataset"]:
|
||||
r"""
|
||||
Aligned dataset:
|
||||
prompt: [{"role": "user", "content": "..."}]
|
||||
response: [{"role": "assistant", "content": "..."}]
|
||||
prompt: [{"role": "user", "content": "..."}] * (2T - 1)
|
||||
response: [{"role": "assistant", "content": "..."}] * N (N > 1 for ranking dataset)
|
||||
system: "..."
|
||||
tools: "..."
|
||||
"""
|
||||
@@ -97,6 +104,18 @@ def align_dataset(
|
||||
convert_func = partial(convert_sharegpt, dataset_attr=dataset_attr)
|
||||
|
||||
column_names = list(next(iter(dataset)).keys())
|
||||
features = Features.from_dict(
|
||||
{
|
||||
"prompt": [
|
||||
{"role": {"dtype": "string", "_type": "Value"}, "content": {"dtype": "string", "_type": "Value"}}
|
||||
],
|
||||
"response": [
|
||||
{"role": {"dtype": "string", "_type": "Value"}, "content": {"dtype": "string", "_type": "Value"}}
|
||||
],
|
||||
"system": {"dtype": "string", "_type": "Value"},
|
||||
"tools": {"dtype": "string", "_type": "Value"},
|
||||
}
|
||||
)
|
||||
kwargs = {}
|
||||
if not data_args.streaming:
|
||||
kwargs = dict(
|
||||
@@ -105,4 +124,10 @@ def align_dataset(
|
||||
desc="Converting format of dataset",
|
||||
)
|
||||
|
||||
return dataset.map(convert_func, batched=True, remove_columns=column_names, **kwargs)
|
||||
return dataset.map(
|
||||
convert_func,
|
||||
batched=True,
|
||||
remove_columns=column_names,
|
||||
features=features,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
@@ -15,11 +15,11 @@ JSON_FORMAT_PROMPT = (
|
||||
|
||||
TOOL_SYSTEM_PROMPT = (
|
||||
"You have access to the following tools:\n{tool_text}"
|
||||
"Use the following format to answer the question:\n"
|
||||
"Use the following format if using a tool:\n"
|
||||
"```\n"
|
||||
"Action: tool name (one of [{tool_names}]).\n"
|
||||
"Action Input: the input to the tool{format_prompt}.\n"
|
||||
"```\n"
|
||||
"Action: the action to take, should be one of [{tool_names}] if using a tool.\n"
|
||||
"Action Input: the input to the action{format_prompt}.\n"
|
||||
"```"
|
||||
)
|
||||
|
||||
|
||||
@@ -31,12 +31,16 @@ def default_tool_formatter(tools: List[Dict[str, Any]]) -> str:
|
||||
for name, param in tool["parameters"]["properties"].items():
|
||||
required = ", required" if name in tool["parameters"].get("required", []) else ""
|
||||
enum = ", should be one of [{}]".format(", ".join(param["enum"])) if param.get("enum", None) else ""
|
||||
param_text += " - {name} ({type}{required}): {desc}{enum}\n".format(
|
||||
items = (
|
||||
", where each item should be {}".format(param["items"].get("type", "")) if param.get("items") else ""
|
||||
)
|
||||
param_text += " - {name} ({type}{required}): {desc}{enum}{items}\n".format(
|
||||
name=name,
|
||||
type=param.get("type", ""),
|
||||
required=required,
|
||||
desc=param.get("description", ""),
|
||||
enum=enum,
|
||||
items=items,
|
||||
)
|
||||
|
||||
tool_text += "> Tool Name: {name}\nTool Description: {desc}\nTool Args:\n{args}\n".format(
|
||||
@@ -91,12 +95,15 @@ class StringFormatter(Formatter):
|
||||
for slot in self.slots:
|
||||
if isinstance(slot, str):
|
||||
for name, value in kwargs.items():
|
||||
if not isinstance(value, str):
|
||||
raise RuntimeError("Expected a string, got {}".format(value))
|
||||
|
||||
slot = slot.replace("{{" + name + "}}", value, 1)
|
||||
elements.append(slot)
|
||||
elif isinstance(slot, (dict, set)):
|
||||
elements.append(slot)
|
||||
else:
|
||||
raise ValueError("Input must be string, set[str] or dict[str, str], got {}".format(type(slot)))
|
||||
raise RuntimeError("Input must be string, set[str] or dict[str, str], got {}".format(type(slot)))
|
||||
|
||||
return elements
|
||||
|
||||
@@ -120,7 +127,7 @@ class FunctionFormatter(Formatter):
|
||||
elif isinstance(slot, (dict, set)):
|
||||
elements.append(slot)
|
||||
else:
|
||||
raise ValueError("Input must be string, set[str] or dict[str, str], got {}".format(type(slot)))
|
||||
raise RuntimeError("Input must be string, set[str] or dict[str, str], got {}".format(type(slot)))
|
||||
|
||||
return elements
|
||||
|
||||
|
||||
@@ -30,6 +30,7 @@ def load_single_dataset(
|
||||
model_args: "ModelArguments",
|
||||
data_args: "DataArguments",
|
||||
):
|
||||
logger.info("Loading dataset {}...".format(dataset_attr))
|
||||
data_path, data_name, data_dir, data_files = None, None, None, None
|
||||
if dataset_attr.load_from in ["hf_hub", "ms_hub"]:
|
||||
data_path = dataset_attr.dataset_name
|
||||
@@ -60,7 +61,7 @@ def load_single_dataset(
|
||||
if data_path is None:
|
||||
raise ValueError("File extension must be txt, csv, json or jsonl.")
|
||||
|
||||
checksum(data_files, dataset_attr.dataset_sha1)
|
||||
checksum(data_files, dataset_attr.file_sha1)
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
@@ -142,7 +143,7 @@ def get_dataset(
|
||||
stage: Literal["pt", "sft", "rm", "ppo"],
|
||||
# split: Optional[str] = "train", # TODO: add split
|
||||
) -> Union["Dataset", "IterableDataset"]:
|
||||
template = get_template_and_fix_tokenizer(data_args.template, tokenizer)
|
||||
template = get_template_and_fix_tokenizer(tokenizer, data_args.template)
|
||||
if data_args.train_on_prompt and template.efficient_eos:
|
||||
raise ValueError("Current template does not support `train_on_prompt`.")
|
||||
|
||||
@@ -155,12 +156,9 @@ def get_dataset(
|
||||
dataset = dataset.to_iterable_dataset()
|
||||
return dataset
|
||||
|
||||
if data_args.streaming:
|
||||
raise ValueError("Turn off dataset streaming to save cache files.")
|
||||
|
||||
with training_args.main_process_first(desc="load dataset"):
|
||||
all_datasets = []
|
||||
for dataset_attr in get_dataset_list(data_args): # TODO: add split
|
||||
for dataset_attr in get_dataset_list(data_args):
|
||||
all_datasets.append(load_single_dataset(dataset_attr, model_args, data_args))
|
||||
dataset = merge_dataset(all_datasets, data_args, training_args)
|
||||
|
||||
@@ -188,6 +186,6 @@ def get_dataset(
|
||||
try:
|
||||
print_function(next(iter(dataset)))
|
||||
except StopIteration:
|
||||
raise RuntimeError("Empty dataset!")
|
||||
raise RuntimeError("Cannot find valid samples, check `data/README.md` for the data format.")
|
||||
|
||||
return dataset
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import json
|
||||
import os
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, List, Literal, Optional
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional
|
||||
|
||||
from ..extras.constants import DATA_CONFIG
|
||||
from ..extras.misc import use_modelscope
|
||||
@@ -13,34 +13,44 @@ if TYPE_CHECKING:
|
||||
|
||||
@dataclass
|
||||
class DatasetAttr:
|
||||
r"""
|
||||
Dataset attributes.
|
||||
"""
|
||||
|
||||
""" basic configs """
|
||||
load_from: Literal["hf_hub", "ms_hub", "script", "file"]
|
||||
dataset_name: Optional[str] = None
|
||||
dataset_sha1: Optional[str] = None
|
||||
""" extra configs """
|
||||
file_sha1: Optional[str] = None
|
||||
subset: Optional[str] = None
|
||||
folder: Optional[str] = None
|
||||
ranking: Optional[bool] = False
|
||||
formatting: Optional[Literal["alpaca", "sharegpt"]] = "alpaca"
|
||||
|
||||
""" columns """
|
||||
system: Optional[str] = None
|
||||
|
||||
""" columns for the alpaca format """
|
||||
prompt: Optional[str] = "instruction"
|
||||
query: Optional[str] = "input"
|
||||
response: Optional[str] = "output"
|
||||
history: Optional[str] = None
|
||||
|
||||
""" columns for the sharegpt format """
|
||||
messages: Optional[str] = "conversations"
|
||||
tools: Optional[str] = None
|
||||
|
||||
""" tags for the sharegpt format """
|
||||
role_tag: Optional[str] = "from"
|
||||
content_tag: Optional[str] = "value"
|
||||
user_tag: Optional[str] = "human"
|
||||
assistant_tag: Optional[str] = "gpt"
|
||||
observation_tag: Optional[str] = "observation"
|
||||
function_tag: Optional[str] = "function_call"
|
||||
system_tag: Optional[str] = "system"
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return self.dataset_name
|
||||
|
||||
def set_attr(self, key: str, obj: Dict[str, Any], default: Optional[Any] = None) -> None:
|
||||
setattr(self, key, obj.get(key, default))
|
||||
|
||||
|
||||
def get_dataset_list(data_args: "DataArguments") -> List["DatasetAttr"]:
|
||||
dataset_names = [ds.strip() for ds in data_args.dataset.split(",")] if data_args.dataset is not None else []
|
||||
@@ -73,30 +83,36 @@ def get_dataset_list(data_args: "DataArguments") -> List["DatasetAttr"]:
|
||||
elif "script_url" in dataset_info[name]:
|
||||
dataset_attr = DatasetAttr("script", dataset_name=dataset_info[name]["script_url"])
|
||||
else:
|
||||
dataset_attr = DatasetAttr(
|
||||
"file",
|
||||
dataset_name=dataset_info[name]["file_name"],
|
||||
dataset_sha1=dataset_info[name].get("file_sha1", None),
|
||||
)
|
||||
dataset_attr = DatasetAttr("file", dataset_name=dataset_info[name]["file_name"])
|
||||
|
||||
dataset_attr.subset = dataset_info[name].get("subset", None)
|
||||
dataset_attr.folder = dataset_info[name].get("folder", None)
|
||||
dataset_attr.ranking = dataset_info[name].get("ranking", False)
|
||||
dataset_attr.formatting = dataset_info[name].get("formatting", "alpaca")
|
||||
dataset_attr.set_attr("file_sha1", dataset_info[name])
|
||||
dataset_attr.set_attr("subset", dataset_info[name])
|
||||
dataset_attr.set_attr("folder", dataset_info[name])
|
||||
dataset_attr.set_attr("ranking", dataset_info[name], default=False)
|
||||
dataset_attr.set_attr("formatting", dataset_info[name], default="alpaca")
|
||||
|
||||
if "columns" in dataset_info[name]:
|
||||
column_names = ["system"]
|
||||
if dataset_attr.formatting == "alpaca":
|
||||
column_names = ["prompt", "query", "response", "history"]
|
||||
column_names.extend(["prompt", "query", "response", "history"])
|
||||
else:
|
||||
column_names = ["messages", "tools"]
|
||||
column_names.extend(["messages", "tools"])
|
||||
|
||||
column_names += ["system"]
|
||||
for column_name in column_names:
|
||||
setattr(dataset_attr, column_name, dataset_info[name]["columns"].get(column_name, None))
|
||||
dataset_attr.set_attr(column_name, dataset_info[name]["columns"])
|
||||
|
||||
if dataset_attr.formatting == "sharegpt" and "tags" in dataset_info[name]:
|
||||
for tag in ["role_tag", "content_tag", "user_tag", "assistant_tag", "observation_tag", "function_tag"]:
|
||||
setattr(dataset_attr, tag, dataset_info[name]["tags"].get(tag, None))
|
||||
tag_names = (
|
||||
"role_tag",
|
||||
"content_tag",
|
||||
"user_tag",
|
||||
"assistant_tag",
|
||||
"observation_tag",
|
||||
"function_tag",
|
||||
"system_tag",
|
||||
)
|
||||
for tag in tag_names:
|
||||
dataset_attr.set_attr(tag, dataset_info[name]["tags"])
|
||||
|
||||
dataset_list.append(dataset_attr)
|
||||
|
||||
|
||||
@@ -4,6 +4,7 @@ from typing import TYPE_CHECKING, Any, Callable, Dict, List, Literal, Tuple
|
||||
|
||||
from ..extras.constants import IGNORE_INDEX
|
||||
from ..extras.logging import get_logger
|
||||
from .utils import Role
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -21,12 +22,8 @@ def preprocess_pretrain_dataset(
|
||||
examples: Dict[str, List[Any]], tokenizer: "PreTrainedTokenizer", data_args: "DataArguments"
|
||||
) -> Dict[str, List[List[int]]]:
|
||||
# build grouped texts with format `X1 X2 X3 ...`
|
||||
text_examples = [examples["prompt"][i][0]["content"] for i in range(len(examples["prompt"]))]
|
||||
text_examples = [messages[0]["content"] + tokenizer.eos_token for messages in examples["prompt"]]
|
||||
tokenized_examples = tokenizer(text_examples, add_special_tokens=False)
|
||||
for i in range(len(tokenized_examples["input_ids"])):
|
||||
tokenized_examples["input_ids"][i] += [tokenizer.eos_token_id]
|
||||
tokenized_examples["attention_mask"][i] += [1]
|
||||
|
||||
concatenated_examples = {k: list(chain(*tokenized_examples[k])) for k in tokenized_examples.keys()}
|
||||
total_length = len(concatenated_examples[list(concatenated_examples.keys())[0]])
|
||||
block_size = data_args.cutoff_len
|
||||
@@ -51,14 +48,19 @@ def preprocess_supervised_dataset(
|
||||
model_inputs = {"input_ids": [], "attention_mask": [], "labels": []}
|
||||
|
||||
for i in range(len(examples["prompt"])):
|
||||
if len(examples["prompt"][i]) == 0 or len(examples["response"][i]) != 1:
|
||||
if len(examples["prompt"][i]) % 2 != 1 or len(examples["response"][i]) != 1:
|
||||
continue
|
||||
|
||||
messages = examples["prompt"][i] + examples["response"][i]
|
||||
input_ids, labels = [], []
|
||||
for turn_idx, (source_ids, target_ids) in enumerate(
|
||||
template.encode_multiturn(
|
||||
tokenizer, messages, examples["system"][i], examples["tools"][i], data_args.cutoff_len
|
||||
tokenizer,
|
||||
messages,
|
||||
examples["system"][i],
|
||||
examples["tools"][i],
|
||||
data_args.cutoff_len,
|
||||
data_args.reserved_label_len,
|
||||
)
|
||||
):
|
||||
if data_args.train_on_prompt:
|
||||
@@ -93,16 +95,16 @@ def preprocess_packed_supervised_dataset(
|
||||
model_inputs = {"input_ids": [], "attention_mask": [], "labels": []}
|
||||
input_ids, labels = [], []
|
||||
for i in range(len(examples["prompt"])):
|
||||
if len(examples["prompt"][i]) == 0 or len(examples["response"][i]) != 1:
|
||||
if len(examples["prompt"][i]) % 2 != 1 or len(examples["response"][i]) != 1:
|
||||
continue
|
||||
|
||||
messages = examples["prompt"][i] + examples["response"][i]
|
||||
for turn_idx, (source_ids, target_ids) in enumerate(
|
||||
template.encode_multiturn(tokenizer, messages, examples["system"][i], examples["tools"][i])
|
||||
for source_ids, target_ids in template.encode_multiturn(
|
||||
tokenizer, messages, examples["system"][i], examples["tools"][i]
|
||||
):
|
||||
if data_args.train_on_prompt:
|
||||
source_mask = source_ids
|
||||
elif turn_idx != 0 and template.efficient_eos:
|
||||
elif len(input_ids) != 0 and template.efficient_eos:
|
||||
source_mask = [tokenizer.eos_token_id] + [IGNORE_INDEX] * (len(source_ids) - 1)
|
||||
else:
|
||||
source_mask = [IGNORE_INDEX] * len(source_ids)
|
||||
@@ -120,9 +122,10 @@ def preprocess_packed_supervised_dataset(
|
||||
total_length = (total_length // block_size) * block_size
|
||||
# split by chunks of cutoff_len
|
||||
for i in range(0, total_length, block_size):
|
||||
model_inputs["input_ids"].append(input_ids[i : i + block_size])
|
||||
model_inputs["attention_mask"].append([1] * block_size)
|
||||
model_inputs["labels"].append(labels[i : i + block_size])
|
||||
if not all(label == IGNORE_INDEX for label in labels[i : i + block_size]):
|
||||
model_inputs["input_ids"].append(input_ids[i : i + block_size])
|
||||
model_inputs["attention_mask"].append([1] * block_size)
|
||||
model_inputs["labels"].append(labels[i : i + block_size])
|
||||
|
||||
return model_inputs
|
||||
|
||||
@@ -137,12 +140,21 @@ def preprocess_unsupervised_dataset(
|
||||
model_inputs = {"input_ids": [], "attention_mask": [], "labels": []}
|
||||
|
||||
for i in range(len(examples["prompt"])):
|
||||
if len(examples["prompt"][i]) == 0 or len(examples["response"][i]) != 1:
|
||||
if len(examples["prompt"][i]) % 2 != 1:
|
||||
continue
|
||||
|
||||
messages = examples["prompt"][i] + examples["response"][i]
|
||||
if len(examples["response"][i]) == 1:
|
||||
messages = examples["prompt"][i] + examples["response"][i]
|
||||
else:
|
||||
messages = examples["prompt"][i] + [{"role": Role.ASSISTANT.value, "content": ""}]
|
||||
|
||||
input_ids, labels = template.encode_oneturn(
|
||||
tokenizer, messages, examples["system"][i], examples["tools"][i], data_args.cutoff_len
|
||||
tokenizer,
|
||||
messages,
|
||||
examples["system"][i],
|
||||
examples["tools"][i],
|
||||
data_args.cutoff_len,
|
||||
data_args.reserved_label_len,
|
||||
)
|
||||
|
||||
if template.efficient_eos:
|
||||
@@ -164,17 +176,26 @@ def preprocess_pairwise_dataset(
|
||||
# build input pairs with format `<bos> X`, `Y1 <eos>` and `Y2 <eos>`
|
||||
model_inputs = {"prompt_ids": [], "chosen_ids": [], "rejected_ids": []}
|
||||
for i in range(len(examples["prompt"])):
|
||||
if len(examples["prompt"][i]) == 0 or len(examples["response"][i]) < 2:
|
||||
if len(examples["prompt"][i]) % 2 != 1 or len(examples["response"][i]) < 2:
|
||||
continue
|
||||
|
||||
chosen_messages = examples["prompt"][i] + [examples["response"][i][0]]
|
||||
rejected_messages = examples["prompt"][i] + [examples["response"][i][1]]
|
||||
|
||||
prompt_ids, chosen_ids = template.encode_oneturn(
|
||||
tokenizer, chosen_messages, examples["system"][i], examples["tools"][i], data_args.cutoff_len
|
||||
tokenizer,
|
||||
chosen_messages,
|
||||
examples["system"][i],
|
||||
examples["tools"][i],
|
||||
data_args.cutoff_len,
|
||||
data_args.reserved_label_len,
|
||||
)
|
||||
_, rejected_ids = template.encode_oneturn(
|
||||
tokenizer, rejected_messages, examples["system"][i], examples["tools"][i], data_args.cutoff_len
|
||||
tokenizer,
|
||||
rejected_messages,
|
||||
examples["system"][i],
|
||||
examples["tools"][i],
|
||||
data_args.cutoff_len,
|
||||
data_args.reserved_label_len,
|
||||
)
|
||||
|
||||
if template.efficient_eos:
|
||||
|
||||
@@ -37,7 +37,7 @@ class Template:
|
||||
system: Optional[str] = None,
|
||||
tools: Optional[str] = None,
|
||||
cutoff_len: Optional[int] = 1_000_000,
|
||||
reserved_label_len: Optional[int] = 16,
|
||||
reserved_label_len: Optional[int] = 1,
|
||||
) -> Tuple[List[int], List[int]]:
|
||||
r"""
|
||||
Returns a single pair of token ids representing prompt and response respectively.
|
||||
@@ -57,7 +57,7 @@ class Template:
|
||||
system: Optional[str] = None,
|
||||
tools: Optional[str] = None,
|
||||
cutoff_len: Optional[int] = 1_000_000,
|
||||
reserved_label_len: Optional[int] = 16,
|
||||
reserved_label_len: Optional[int] = 1,
|
||||
) -> Sequence[Tuple[List[int], List[int]]]:
|
||||
r"""
|
||||
Returns multiple pairs of token ids representing prompts and responses respectively.
|
||||
@@ -88,16 +88,16 @@ class Template:
|
||||
elif i > 0 and i % 2 == 0:
|
||||
elements += self.format_separator.apply()
|
||||
|
||||
if message["role"] == Role.USER:
|
||||
if message["role"] == Role.USER.value:
|
||||
elements += self.format_user.apply(content=message["content"], idx=str(i // 2))
|
||||
elif message["role"] == Role.ASSISTANT:
|
||||
elif message["role"] == Role.ASSISTANT.value:
|
||||
elements += self.format_assistant.apply(content=message["content"])
|
||||
elif message["role"] == Role.OBSERVATION:
|
||||
elif message["role"] == Role.OBSERVATION.value:
|
||||
elements += self.format_observation.apply(content=message["content"])
|
||||
elif message["role"] == Role.FUNCTION:
|
||||
elif message["role"] == Role.FUNCTION.value:
|
||||
elements += self.format_function.apply(content=message["content"])
|
||||
else:
|
||||
raise NotImplementedError
|
||||
raise NotImplementedError("Unexpected role: {}".format(message["role"]))
|
||||
|
||||
encoded_messages.append(self._convert_elements_to_ids(tokenizer, elements))
|
||||
|
||||
@@ -117,9 +117,9 @@ class Template:
|
||||
elif isinstance(elem, dict):
|
||||
token_ids += [tokenizer.convert_tokens_to_ids(elem.get("token"))]
|
||||
elif isinstance(elem, set):
|
||||
if "bos_token" in elem and tokenizer.bos_token_id:
|
||||
if "bos_token" in elem and tokenizer.bos_token_id is not None:
|
||||
token_ids += [tokenizer.bos_token_id]
|
||||
elif "eos_token" in elem and tokenizer.eos_token_id:
|
||||
elif "eos_token" in elem and tokenizer.eos_token_id is not None:
|
||||
token_ids += [tokenizer.eos_token_id]
|
||||
else:
|
||||
raise ValueError("Input must be string, set[str] or dict[str, str], got {}".format(type(elem)))
|
||||
@@ -144,10 +144,10 @@ class Template:
|
||||
max_len=(cutoff_len - total_length),
|
||||
reserved_label_len=reserved_label_len,
|
||||
)
|
||||
encoded_messages[i] = encoded_messages[i][:max_source_len]
|
||||
encoded_messages[i + 1] = encoded_messages[i + 1][:max_target_len]
|
||||
total_length += len(encoded_messages[i]) + len(encoded_messages[i + 1])
|
||||
encoded_pairs.append((encoded_messages[i], encoded_messages[i + 1]))
|
||||
source_ids = encoded_messages[i][:max_source_len]
|
||||
target_ids = encoded_messages[i + 1][:max_target_len]
|
||||
total_length += len(source_ids) + len(target_ids)
|
||||
encoded_pairs.append((source_ids, target_ids))
|
||||
|
||||
return encoded_pairs
|
||||
|
||||
@@ -179,16 +179,16 @@ class Llama2Template(Template):
|
||||
elif i > 0 and i % 2 == 0:
|
||||
elements += self.format_separator.apply()
|
||||
|
||||
if message["role"] == Role.USER:
|
||||
if message["role"] == Role.USER.value:
|
||||
elements += self.format_user.apply(content=system_text + message["content"])
|
||||
elif message["role"] == Role.ASSISTANT:
|
||||
elif message["role"] == Role.ASSISTANT.value:
|
||||
elements += self.format_assistant.apply(content=message["content"])
|
||||
elif message["role"] == Role.OBSERVATION:
|
||||
elif message["role"] == Role.OBSERVATION.value:
|
||||
elements += self.format_observation.apply(content=message["content"])
|
||||
elif message["role"] == Role.FUNCTION:
|
||||
elif message["role"] == Role.FUNCTION.value:
|
||||
elements += self.format_function.apply(content=message["content"])
|
||||
else:
|
||||
raise NotImplementedError
|
||||
raise NotImplementedError("Unexpected role: {}".format(message["role"]))
|
||||
|
||||
encoded_messages.append(self._convert_elements_to_ids(tokenizer, elements))
|
||||
|
||||
@@ -198,7 +198,7 @@ class Llama2Template(Template):
|
||||
templates: Dict[str, Template] = {}
|
||||
|
||||
|
||||
def register_template(
|
||||
def _register_template(
|
||||
name: str,
|
||||
format_user: Optional["Formatter"] = None,
|
||||
format_assistant: Optional["Formatter"] = None,
|
||||
@@ -218,7 +218,7 @@ def register_template(
|
||||
default_user_formatter = StringFormatter(slots=["{{content}}"])
|
||||
default_assistant_formatter = StringFormatter(slots=["{{content}}"] + eos_slots)
|
||||
default_function_formatter = FunctionFormatter(slots=["Action: {{name}}\nAction Input: {{arguments}}"] + eos_slots)
|
||||
default_tool_formatter = ToolFormatter(slots="default")
|
||||
default_tool_formatter = ToolFormatter(tool_format="default")
|
||||
default_separator_formatter = EmptyFormatter()
|
||||
templates[name] = template_class(
|
||||
format_user=format_user or default_user_formatter,
|
||||
@@ -236,29 +236,45 @@ def register_template(
|
||||
)
|
||||
|
||||
|
||||
def get_template_and_fix_tokenizer(name: str, tokenizer: "PreTrainedTokenizer") -> Template:
|
||||
if tokenizer.eos_token_id is None:
|
||||
tokenizer.eos_token = "<|endoftext|>"
|
||||
def _add_or_replace_eos_token(tokenizer: "PreTrainedTokenizer", eos_token: str) -> None:
|
||||
is_added = tokenizer.eos_token_id is None
|
||||
is_oov = eos_token not in tokenizer.get_vocab()
|
||||
tokenizer.add_special_tokens({"eos_token": eos_token})
|
||||
|
||||
if is_added:
|
||||
logger.info("Add eos token: {}".format(tokenizer.eos_token))
|
||||
else:
|
||||
logger.info("Replace eos token: {}".format(tokenizer.eos_token))
|
||||
|
||||
if tokenizer.pad_token_id is None:
|
||||
tokenizer.pad_token = tokenizer.eos_token
|
||||
logger.info("Add pad token: {}".format(tokenizer.pad_token))
|
||||
if is_oov:
|
||||
logger.warning("New tokens have been added, make sure `resize_vocab` is True.")
|
||||
|
||||
if name is None: # for pre-training
|
||||
return None
|
||||
|
||||
template = templates.get(name, None)
|
||||
assert template is not None, "Template {} does not exist.".format(name)
|
||||
def get_template_and_fix_tokenizer(
|
||||
tokenizer: "PreTrainedTokenizer",
|
||||
name: Optional[str] = None,
|
||||
) -> Template:
|
||||
if name is None:
|
||||
template = templates["vanilla"] # placeholder
|
||||
else:
|
||||
template = templates.get(name, None)
|
||||
if template is None:
|
||||
raise ValueError("Template {} does not exist.".format(name))
|
||||
|
||||
stop_words = template.stop_words
|
||||
if template.replace_eos:
|
||||
if not stop_words:
|
||||
raise ValueError("Stop words are required to replace the EOS token.")
|
||||
|
||||
tokenizer.eos_token = stop_words[0]
|
||||
_add_or_replace_eos_token(tokenizer, eos_token=stop_words[0])
|
||||
stop_words = stop_words[1:]
|
||||
logger.info("Replace eos token: {}".format(tokenizer.eos_token))
|
||||
|
||||
if tokenizer.eos_token_id is None:
|
||||
_add_or_replace_eos_token(tokenizer, eos_token="<|endoftext|>")
|
||||
|
||||
if tokenizer.pad_token_id is None:
|
||||
tokenizer.pad_token = tokenizer.eos_token
|
||||
logger.info("Add pad token: {}".format(tokenizer.pad_token))
|
||||
|
||||
if stop_words:
|
||||
tokenizer.add_special_tokens(
|
||||
@@ -269,7 +285,7 @@ def get_template_and_fix_tokenizer(name: str, tokenizer: "PreTrainedTokenizer")
|
||||
return template
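Note that this change flips the argument order of `get_template_and_fix_tokenizer` (tokenizer first, template name second), and the call sites shown earlier in this diff are updated accordingly. A sketch of the new call order, with a placeholder model and template name (the module path is assumed from the repository layout):

```python
# Updated call order: pass the tokenizer first, then the template name.
from transformers import AutoTokenizer
from llmtuner.data.template import get_template_and_fix_tokenizer

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")  # placeholder model
template = get_template_and_fix_tokenizer(tokenizer, "llama2")
```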
|
||||
|
||||
|
||||
register_template(
|
||||
_register_template(
|
||||
name="alpaca",
|
||||
format_user=StringFormatter(slots=["### Instruction:\n{{content}}\n\n### Response:\n"]),
|
||||
format_separator=EmptyFormatter(slots=["\n\n"]),
|
||||
@@ -279,7 +295,7 @@ register_template(
|
||||
)
|
||||
|
||||
|
||||
register_template(
|
||||
_register_template(
|
||||
name="aquila",
|
||||
format_user=StringFormatter(slots=["Human: {{content}}###Assistant:"]),
|
||||
format_separator=EmptyFormatter(slots=["###"]),
|
||||
@@ -292,21 +308,30 @@ register_template(
|
||||
)
|
||||
|
||||
|
||||
register_template(
|
||||
_register_template(
|
||||
name="atom",
|
||||
format_user=StringFormatter(
|
||||
slots=[{"bos_token"}, "Human: {{content}}\n", {"eos_token"}, {"bos_token"}, "Assistant:"]
|
||||
),
|
||||
format_assistant=StringFormatter(slots=["{{content}}\n", {"eos_token"}]),
|
||||
)
|
||||
|
||||
|
||||
_register_template(
|
||||
name="baichuan",
|
||||
format_user=StringFormatter(slots=[{"token": "<reserved_102>"}, "{{content}}", {"token": "<reserved_103>"}]),
|
||||
efficient_eos=True,
|
||||
)
|
||||
|
||||
|
||||
register_template(
|
||||
_register_template(
|
||||
name="baichuan2",
|
||||
format_user=StringFormatter(slots=[{"token": "<reserved_106>"}, "{{content}}", {"token": "<reserved_107>"}]),
|
||||
efficient_eos=True,
|
||||
)
|
||||
|
||||
|
||||
register_template(
|
||||
_register_template(
|
||||
name="belle",
|
||||
format_user=StringFormatter(slots=["Human: {{content}}\n\nBelle: "]),
|
||||
format_system=StringFormatter(slots=[{"bos_token"}, "{{content}}"]),
|
||||
@@ -315,13 +340,13 @@ register_template(
|
||||
)
|
||||
|
||||
|
||||
register_template(
|
||||
_register_template(
|
||||
name="bluelm",
|
||||
format_user=StringFormatter(slots=[{"token": "[|Human|]:"}, "{{content}}", {"token": "[|AI|]:"}]),
|
||||
)
|
||||
|
||||
|
||||
register_template(
|
||||
_register_template(
|
||||
name="chatglm2",
|
||||
format_user=StringFormatter(slots=["[Round {{idx}}]\n\n问:{{content}}\n\n答:"]),
|
||||
format_system=StringFormatter(slots=[{"token": "[gMASK]"}, {"token": "sop"}, "{{content}}"]),
|
||||
@@ -331,15 +356,32 @@ register_template(
|
||||
)
|
||||
|
||||
|
||||
register_template(
|
||||
_register_template(
|
||||
name="chatglm3",
|
||||
format_user=StringFormatter(slots=[{"token": "<|user|>"}, "\n", "{{content}}", {"token": "<|assistant|>"}]),
|
||||
format_assistant=StringFormatter(slots=["\n", "{{content}}"]),
|
||||
format_system=StringFormatter(slots=[{"token": "[gMASK]"}, {"token": "sop"}]),
|
||||
format_function=FunctionFormatter(slots=["{{name}}\n{{arguments}}"]),
|
||||
format_observation=StringFormatter(
|
||||
slots=[{"token": "<|observation|>"}, "\n", "{{content}}", {"token": "<|assistant|>"}]
|
||||
),
|
||||
stop_words=["<|user|>", "<|observation|>"],
|
||||
efficient_eos=True,
|
||||
force_system=True,
|
||||
)
|
||||
|
||||
|
||||
_register_template(
|
||||
name="chatglm3_system",
|
||||
format_user=StringFormatter(slots=[{"token": "<|user|>"}, "\n", "{{content}}", {"token": "<|assistant|>"}]),
|
||||
format_assistant=StringFormatter(slots=["\n", "{{content}}"]),
|
||||
format_system=StringFormatter(
|
||||
slots=[{"token": "[gMASK]"}, {"token": "sop"}, {"token": "<|system|>"}, "\n", "{{content}}"]
|
||||
),
|
||||
format_function=FunctionFormatter(slots=["{{name}}\n{{arguments}}"]),
|
||||
format_observation=StringFormatter(slots=[{"token": "<|observation|>"}, "\n", "{{content}}"]),
|
||||
format_observation=StringFormatter(
|
||||
slots=[{"token": "<|observation|>"}, "\n", "{{content}}", {"token": "<|assistant|>"}]
|
||||
),
|
||||
default_system=(
|
||||
"You are ChatGLM3, a large language model trained by Zhipu.AI. "
|
||||
"Follow the user's instructions carefully. Respond using markdown."
|
||||
@@ -349,14 +391,43 @@ register_template(
|
||||
)
|
||||
|
||||
|
||||
register_template(
|
||||
_register_template(
|
||||
name="chatml",
|
||||
format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
|
||||
format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
|
||||
format_separator=EmptyFormatter(slots=["\n"]),
|
||||
stop_words=["<|im_end|>", "<|im_start|>"],
|
||||
replace_eos=True,
|
||||
)
|
||||
|
||||
|
||||
_register_template(
|
||||
name="chatml_de",
|
||||
format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
|
||||
format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
|
||||
format_separator=EmptyFormatter(slots=["\n"]),
|
||||
default_system="Du bist ein freundlicher und hilfsbereiter KI-Assistent.",
|
||||
stop_words=["<|im_end|>", "<|im_start|>"],
|
||||
replace_eos=True,
|
||||
)
|
||||
|
||||
|
||||
_register_template(
|
||||
name="codegeex2",
|
||||
format_system=StringFormatter(slots=[{"token": "[gMASK]"}, {"token": "sop"}, "{{content}}"]),
|
||||
force_system=True,
|
||||
)
|
||||
|
||||
|
||||
register_template(
|
||||
_register_template(
|
||||
name="cpm",
|
||||
format_user=StringFormatter(slots=["<用户>{{content}}<AI>"]),
|
||||
format_system=StringFormatter(slots=[{"bos_token"}, "{{content}}"]),
|
||||
force_system=True,
|
||||
)
|
||||
|
||||
|
||||
_register_template(
|
||||
name="deepseek",
|
||||
format_user=StringFormatter(slots=["User: {{content}}\n\nAssistant:"]),
|
||||
format_system=StringFormatter(slots=[{"bos_token"}, "{{content}}"]),
|
||||
@@ -364,9 +435,10 @@ register_template(
|
||||
)
|
||||
|
||||
|
||||
register_template(
|
||||
_register_template(
|
||||
name="deepseekcoder",
|
||||
format_user=StringFormatter(slots=["### Instruction:\n{{content}}\n### Response:\n"]),
|
||||
format_user=StringFormatter(slots=["### Instruction:\n{{content}}\n### Response:"]),
|
||||
format_assistant=StringFormatter(slots=["\n", "{{content}}"]),
|
||||
format_separator=EmptyFormatter(slots=["\n", {"token": "<|EOT|>"}, "\n"]),
|
||||
default_system=(
|
||||
"You are an AI programming assistant, utilizing the Deepseek Coder model, "
|
||||
@@ -379,14 +451,15 @@ register_template(
|
||||
)
|
||||
|
||||
|
||||
register_template(
|
||||
_register_template(
|
||||
name="default",
|
||||
format_user=StringFormatter(slots=["Human: {{content}}\nAssistant: "]),
|
||||
format_system=StringFormatter(slots=["{{content}}\n"]),
|
||||
format_separator=EmptyFormatter(slots=["\n"]),
|
||||
)
|
||||
|
||||
|
||||
register_template(
|
||||
_register_template(
|
||||
name="falcon",
|
||||
format_user=StringFormatter(slots=["User: {{content}}\nFalcon:"]),
|
||||
format_separator=EmptyFormatter(slots=["\n"]),
|
||||
@@ -394,7 +467,17 @@ register_template(
|
||||
)
|
||||
|
||||
|
||||
register_template(
|
||||
_register_template(
|
||||
name="gemma",
|
||||
format_user=StringFormatter(slots=["<start_of_turn>user\n{{content}}<end_of_turn>\n<start_of_turn>model\n"]),
|
||||
format_system=StringFormatter(slots=[{"bos_token"}, "{{content}}"]),
|
||||
format_separator=EmptyFormatter(slots=["<end_of_turn>\n"]),
|
||||
efficient_eos=True,
|
||||
force_system=True,
|
||||
)
|
||||
|
||||
|
||||
_register_template(
|
||||
name="intern",
|
||||
format_user=StringFormatter(slots=["<|User|>:{{content}}", {"token": "<eoh>"}, "\n<|Bot|>:"]),
|
||||
format_separator=EmptyFormatter(slots=[{"token": "<eoa>"}, "\n"]),
|
||||
@@ -403,7 +486,7 @@ register_template(
|
||||
)
|
||||
|
||||
|
||||
register_template(
|
||||
_register_template(
|
||||
name="intern2",
|
||||
format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
|
||||
format_system=StringFormatter(slots=[{"bos_token"}, "<|im_start|>system\n{{content}}<|im_end|>\n"]),
|
||||
@@ -420,7 +503,7 @@ register_template(
|
||||
)
|
||||
|
||||
|
||||
register_template(
|
||||
_register_template(
|
||||
name="llama2",
|
||||
format_user=StringFormatter(slots=[{"bos_token"}, "[INST] {{content}} [/INST]"]),
|
||||
format_system=StringFormatter(slots=["<<SYS>>\n{{content}}\n<</SYS>>\n\n"]),
|
||||
@@ -437,7 +520,7 @@ register_template(
|
||||
)
|
||||
|
||||
|
||||
register_template(
|
||||
_register_template(
|
||||
name="llama2_zh",
|
||||
format_user=StringFormatter(slots=[{"bos_token"}, "[INST] {{content}} [/INST]"]),
|
||||
format_system=StringFormatter(slots=["<<SYS>>\n{{content}}\n<</SYS>>\n\n"]),
|
||||
@@ -445,7 +528,7 @@ register_template(
|
||||
)
|
||||
|
||||
|
||||
register_template(
|
||||
_register_template(
|
||||
name="mistral",
|
||||
format_user=StringFormatter(slots=["[INST] {{content}} [/INST]"]),
|
||||
format_system=StringFormatter(slots=[{"bos_token"}, "{{content}}"]),
|
||||
@@ -453,16 +536,24 @@ register_template(
|
||||
)
|
||||
|
||||
|
||||
register_template(
|
||||
_register_template(
|
||||
name="openchat",
|
||||
format_user=StringFormatter(slots=["GPT4 Correct User: {{content}}", {"eos_token"}, "GPT4 Correct Assistant:"]),
|
||||
format_assistant=StringFormatter(slots=["{{content}}"]),
|
||||
format_assistant=StringFormatter(slots=["{{content}}", {"eos_token"}]),
|
||||
format_system=StringFormatter(slots=[{"bos_token"}, "{{content}}"]),
|
||||
force_system=True,
|
||||
)
|
||||
|
||||
|
||||
register_template(
|
||||
_register_template(
|
||||
name="orion",
|
||||
format_user=StringFormatter(slots=["Human: {{content}}\n\nAssistant: ", {"eos_token"}]),
|
||||
format_system=StringFormatter(slots=[{"bos_token"}, "{{content}}"]),
|
||||
force_system=True,
|
||||
)
|
||||
|
||||
|
||||
_register_template(
|
||||
name="qwen",
|
||||
format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
|
||||
format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
|
||||
@@ -473,7 +564,7 @@ register_template(
|
||||
)
|
||||
|
||||
|
||||
register_template(
|
||||
_register_template(
|
||||
name="solar",
|
||||
format_user=StringFormatter(slots=["### User:\n{{content}}\n\n### Assistant:\n"]),
|
||||
format_system=StringFormatter(slots=["### System:\n{{content}}\n\n"]),
|
||||
@@ -481,7 +572,7 @@ register_template(
|
||||
)
|
||||
|
||||
|
||||
register_template(
|
||||
_register_template(
|
||||
name="starchat",
|
||||
format_user=StringFormatter(
|
||||
slots=[{"token": "<|user|>"}, "\n{{content}}", {"token": "<|end|>"}, "\n", {"token": "<|assistant|>"}]
|
||||
@@ -494,10 +585,12 @@ register_template(
|
||||
)
|
||||
|
||||
|
||||
register_template(name="vanilla")
|
||||
_register_template(
|
||||
name="vanilla",
|
||||
)
|
||||
|
||||
|
||||
register_template(
|
||||
_register_template(
|
||||
name="vicuna",
|
||||
format_user=StringFormatter(slots=["USER: {{content}} ASSISTANT:"]),
|
||||
default_system=(
|
||||
@@ -507,7 +600,7 @@ register_template(
|
||||
)
|
||||
|
||||
|
||||
register_template(
|
||||
_register_template(
|
||||
name="xuanyuan",
|
||||
format_user=StringFormatter(slots=["Human: {{content}} Assistant:"]),
|
||||
default_system=(
|
||||
@@ -518,10 +611,13 @@ register_template(
|
||||
)
|
||||
|
||||
|
||||
register_template(name="xverse", format_user=StringFormatter(slots=["Human: {{content}}\n\nAssistant: "]))
|
||||
_register_template(
|
||||
name="xverse",
|
||||
format_user=StringFormatter(slots=["Human: {{content}}\n\nAssistant: "]),
|
||||
)
|
||||
|
||||
|
||||
register_template(
|
||||
_register_template(
|
||||
name="yayi",
|
||||
format_user=StringFormatter(slots=[{"token": "<|Human|>"}, ":\n{{content}}\n\n", {"token": "<|YaYi|>"}, ":"]),
|
||||
format_system=StringFormatter(slots=[{"token": "<|System|>"}, ":\n{{content}}\n\n"]),
|
||||
@@ -541,7 +637,7 @@ register_template(
|
||||
)
|
||||
|
||||
|
||||
register_template(
|
||||
_register_template(
|
||||
name="yi",
|
||||
format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
|
||||
format_separator=EmptyFormatter(slots=["\n"]),
|
||||
@@ -550,7 +646,7 @@ register_template(
|
||||
)
|
||||
|
||||
|
||||
register_template(
|
||||
_register_template(
|
||||
name="yuan",
|
||||
format_user=StringFormatter(slots=["{{content}}", {"token": "<sep>"}]),
|
||||
format_separator=EmptyFormatter(slots=["\n"]),
|
||||
@@ -559,7 +655,7 @@ register_template(
|
||||
)
|
||||
|
||||
|
||||
register_template(
|
||||
_register_template(
|
||||
name="zephyr",
|
||||
format_user=StringFormatter(slots=["<|user|>\n{{content}}", {"eos_token"}, "<|assistant|>"]),
|
||||
format_system=StringFormatter(slots=["<|system|>\n{{content}}", {"eos_token"}]),
|
||||
@@ -567,7 +663,7 @@ register_template(
|
||||
)
|
||||
|
||||
|
||||
register_template(
|
||||
_register_template(
|
||||
name="ziya",
|
||||
format_user=StringFormatter(slots=[{"token": "<human>"}, ":{{content}}\n", {"token": "<bot>"}, ":"]),
|
||||
format_separator=EmptyFormatter(slots=["\n"]),
|
||||
|
||||
@@ -19,8 +19,9 @@ logger = get_logger(__name__)
class Role(str, Enum):
    USER = "user"
    ASSISTANT = "assistant"
    OBSERVATION = "observation"
    SYSTEM = "system"
    FUNCTION = "function"
    OBSERVATION = "observation"


def checksum(data_files: List[str], file_sha1: Optional[str] = None) -> None:
@@ -24,7 +24,7 @@ class Evaluator:
        self.model, self.tokenizer = load_model_and_tokenizer(self.model_args, finetuning_args)
        self.tokenizer.padding_side = "right"  # avoid overflow issue in batched inference for llama2
        self.model = dispatch_model(self.model)
        self.template = get_template_and_fix_tokenizer(self.data_args.template, self.tokenizer)
        self.template = get_template_and_fix_tokenizer(self.tokenizer, self.data_args.template)
        self.eval_template = get_eval_template(self.eval_args.lang)
        self.choice_inputs = [
            self.tokenizer.encode(self.eval_template.prefix + ch, add_special_tokens=False)[-1] for ch in CHOICES
@@ -11,7 +11,14 @@ DEFAULT_MODULE = defaultdict(str)
DEFAULT_TEMPLATE = defaultdict(str)

FILEEXT2TYPE = {"arrow": "arrow", "csv": "csv", "json": "json", "jsonl": "json", "parquet": "parquet", "txt": "text"}
FILEEXT2TYPE = {
    "arrow": "arrow",
    "csv": "csv",
    "json": "json",
    "jsonl": "json",
    "parquet": "parquet",
    "txt": "text",
}

IGNORE_INDEX = -100

@@ -46,7 +53,9 @@ class DownloadSource(str, Enum):


def register_model_group(
    models: Dict[str, Dict[DownloadSource, str]], module: Optional[str] = None, template: Optional[str] = None
    models: Dict[str, Dict[DownloadSource, str]],
    module: Optional[str] = None,
    template: Optional[str] = None,
) -> None:
    prefix = None
    for name, path in models.items():
@@ -219,22 +228,36 @@ register_model_group(
|
||||
|
||||
register_model_group(
|
||||
models={
|
||||
"DeepSeekLLM-7B-Base": {
|
||||
"DeepSeek-LLM-7B-Base": {
|
||||
DownloadSource.DEFAULT: "deepseek-ai/deepseek-llm-7b-base",
|
||||
DownloadSource.MODELSCOPE: "deepseek-ai/deepseek-llm-7b-base",
|
||||
},
|
||||
"DeepSeekLLM-67B-Base": {
|
||||
"DeepSeek-LLM-67B-Base": {
|
||||
DownloadSource.DEFAULT: "deepseek-ai/deepseek-llm-67b-base",
|
||||
DownloadSource.MODELSCOPE: "deepseek-ai/deepseek-llm-67b-base",
|
||||
},
|
||||
"DeepSeekLLM-7B-Chat": {
|
||||
"DeepSeek-LLM-7B-Chat": {
|
||||
DownloadSource.DEFAULT: "deepseek-ai/deepseek-llm-7b-chat",
|
||||
DownloadSource.MODELSCOPE: "deepseek-ai/deepseek-llm-7b-chat",
|
||||
},
|
||||
"DeepSeekLLM-67B-Chat": {
|
||||
"DeepSeek-LLM-67B-Chat": {
|
||||
DownloadSource.DEFAULT: "deepseek-ai/deepseek-llm-67b-chat",
|
||||
DownloadSource.MODELSCOPE: "deepseek-ai/deepseek-llm-67b-chat",
|
||||
},
|
||||
"DeepSeek-Math-7B-Base": {
|
||||
DownloadSource.DEFAULT: "deepseek-ai/deepseek-math-7b-base",
|
||||
},
|
||||
"DeepSeek-Math-7B-Chat": {
|
||||
DownloadSource.DEFAULT: "deepseek-ai/deepseek-math-7b-instruct",
|
||||
},
|
||||
"DeepSeek-MoE-16B-Base": {
|
||||
DownloadSource.DEFAULT: "deepseek-ai/deepseek-moe-16b-base",
|
||||
DownloadSource.MODELSCOPE: "deepseek-ai/deepseek-moe-16b-base",
|
||||
},
|
||||
"DeepSeek-MoE-16B-Chat": {
|
||||
DownloadSource.DEFAULT: "deepseek-ai/deepseek-moe-16b-chat",
|
||||
DownloadSource.MODELSCOPE: "deepseek-ai/deepseek-moe-16b-chat",
|
||||
},
|
||||
},
|
||||
template="deepseek",
|
||||
)
|
||||
@@ -246,6 +269,9 @@ register_model_group(
|
||||
DownloadSource.DEFAULT: "deepseek-ai/deepseek-coder-6.7b-base",
|
||||
DownloadSource.MODELSCOPE: "deepseek-ai/deepseek-coder-6.7b-base",
|
||||
},
|
||||
"DeepSeekCoder-7B-Base": {
|
||||
DownloadSource.DEFAULT: "deepseek-ai/deepseek-coder-7b-base-v1.5",
|
||||
},
|
||||
"DeepSeekCoder-33B-Base": {
|
||||
DownloadSource.DEFAULT: "deepseek-ai/deepseek-coder-33b-base",
|
||||
DownloadSource.MODELSCOPE: "deepseek-ai/deepseek-coder-33b-base",
|
||||
@@ -254,6 +280,9 @@ register_model_group(
|
||||
DownloadSource.DEFAULT: "deepseek-ai/deepseek-coder-6.7b-instruct",
|
||||
DownloadSource.MODELSCOPE: "deepseek-ai/deepseek-coder-6.7b-instruct",
|
||||
},
|
||||
"DeepSeekCoder-7B-Chat": {
|
||||
DownloadSource.DEFAULT: "deepseek-ai/deepseek-coder-7b-instruct-v1.5",
|
||||
},
|
||||
"DeepSeekCoder-33B-Chat": {
|
||||
DownloadSource.DEFAULT: "deepseek-ai/deepseek-coder-33b-instruct",
|
||||
DownloadSource.MODELSCOPE: "deepseek-ai/deepseek-coder-33b-instruct",
|
||||
@@ -263,21 +292,6 @@ register_model_group(
|
||||
)
|
||||
|
||||
|
||||
register_model_group(
|
||||
models={
|
||||
"DeepSeekMoE-16B-Base": {
|
||||
DownloadSource.DEFAULT: "deepseek-ai/deepseek-moe-16b-base",
|
||||
DownloadSource.MODELSCOPE: "deepseek-ai/deepseek-moe-16b-base",
|
||||
},
|
||||
"DeepSeekMoE-16B-Chat": {
|
||||
DownloadSource.DEFAULT: "deepseek-ai/deepseek-moe-16b-chat",
|
||||
DownloadSource.MODELSCOPE: "deepseek-ai/deepseek-moe-16b-chat",
|
||||
},
|
||||
},
|
||||
template="deepseek",
|
||||
)
|
||||
|
||||
|
||||
register_model_group(
|
||||
models={
|
||||
"Falcon-7B": {
|
||||
@@ -310,6 +324,29 @@ register_model_group(
|
||||
)
|
||||
|
||||
|
||||
register_model_group(
|
||||
models={
|
||||
"Gemma-2B": {
|
||||
DownloadSource.DEFAULT: "google/gemma-2b",
|
||||
DownloadSource.MODELSCOPE: "AI-ModelScope/gemma-2b",
|
||||
},
|
||||
"Gemma-7B": {
|
||||
DownloadSource.DEFAULT: "google/gemma-7b",
|
||||
DownloadSource.MODELSCOPE: "AI-ModelScope/gemma-7b",
|
||||
},
|
||||
"Gemma-2B-Chat": {
|
||||
DownloadSource.DEFAULT: "google/gemma-2b-it",
|
||||
DownloadSource.MODELSCOPE: "AI-ModelScope/gemma-2b-it",
|
||||
},
|
||||
"Gemma-7B-Chat": {
|
||||
DownloadSource.DEFAULT: "google/gemma-7b-it",
|
||||
DownloadSource.MODELSCOPE: "AI-ModelScope/gemma-7b-it",
|
||||
},
|
||||
},
|
||||
template="gemma",
|
||||
)
|
||||
|
||||
|
||||
register_model_group(
|
||||
models={
|
||||
"InternLM-7B": {
|
||||
@@ -370,7 +407,10 @@ register_model_group(
|
||||
|
||||
register_model_group(
|
||||
models={
|
||||
"LLaMA-7B": {DownloadSource.DEFAULT: "huggyllama/llama-7b", DownloadSource.MODELSCOPE: "skyline2006/llama-7b"},
|
||||
"LLaMA-7B": {
|
||||
DownloadSource.DEFAULT: "huggyllama/llama-7b",
|
||||
DownloadSource.MODELSCOPE: "skyline2006/llama-7b",
|
||||
},
|
||||
"LLaMA-13B": {
|
||||
DownloadSource.DEFAULT: "huggyllama/llama-13b",
|
||||
DownloadSource.MODELSCOPE: "skyline2006/llama-13b",
|
||||
@@ -455,7 +495,7 @@ register_model_group(
|
||||
register_model_group(
|
||||
models={
|
||||
"OpenChat3.5-7B-Chat": {
|
||||
DownloadSource.DEFAULT: "openchat/openchat_3.5",
|
||||
DownloadSource.DEFAULT: "openchat/openchat-3.5-0106",
|
||||
DownloadSource.MODELSCOPE: "myxiongmodel/openchat_3.5",
|
||||
}
|
||||
},
|
||||
@@ -465,23 +505,71 @@ register_model_group(
|
||||
|
||||
register_model_group(
|
||||
models={
|
||||
"Phi-1.5-1.3B": {DownloadSource.DEFAULT: "microsoft/phi-1_5", DownloadSource.MODELSCOPE: "allspace/PHI_1-5"},
|
||||
"Phi-2-2.7B": {DownloadSource.DEFAULT: "microsoft/phi-2", DownloadSource.MODELSCOPE: "AI-ModelScope/phi-2"},
|
||||
"Orion-14B-Base": {
|
||||
DownloadSource.DEFAULT: "OrionStarAI/Orion-14B-Base",
|
||||
DownloadSource.MODELSCOPE: "OrionStarAI/Orion-14B-Base",
|
||||
},
|
||||
"Orion-14B-Chat": {
|
||||
DownloadSource.DEFAULT: "OrionStarAI/Orion-14B-Chat",
|
||||
DownloadSource.MODELSCOPE: "OrionStarAI/Orion-14B-Chat",
|
||||
},
|
||||
"Orion-14B-Long-Chat": {
|
||||
DownloadSource.DEFAULT: "OrionStarAI/Orion-14B-LongChat",
|
||||
DownloadSource.MODELSCOPE: "OrionStarAI/Orion-14B-LongChat",
|
||||
},
|
||||
"Orion-14B-RAG-Chat": {
|
||||
DownloadSource.DEFAULT: "OrionStarAI/Orion-14B-Chat-RAG",
|
||||
DownloadSource.MODELSCOPE: "OrionStarAI/Orion-14B-Chat-RAG",
|
||||
},
|
||||
"Orion-14B-Plugin-Chat": {
|
||||
DownloadSource.DEFAULT: "OrionStarAI/Orion-14B-Chat-Plugin",
|
||||
DownloadSource.MODELSCOPE: "OrionStarAI/Orion-14B-Chat-Plugin",
|
||||
},
|
||||
},
|
||||
template="orion",
|
||||
)
|
||||
|
||||
|
||||
register_model_group(
|
||||
models={
|
||||
"Phi-1.5-1.3B": {
|
||||
DownloadSource.DEFAULT: "microsoft/phi-1_5",
|
||||
DownloadSource.MODELSCOPE: "allspace/PHI_1-5",
|
||||
},
|
||||
"Phi-2-2.7B": {
|
||||
DownloadSource.DEFAULT: "microsoft/phi-2",
|
||||
DownloadSource.MODELSCOPE: "AI-ModelScope/phi-2",
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
register_model_group(
|
||||
models={
|
||||
"Qwen-1.8B": {DownloadSource.DEFAULT: "Qwen/Qwen-1_8B", DownloadSource.MODELSCOPE: "qwen/Qwen-1_8B"},
|
||||
"Qwen-7B": {DownloadSource.DEFAULT: "Qwen/Qwen-7B", DownloadSource.MODELSCOPE: "qwen/Qwen-7B"},
|
||||
"Qwen-14B": {DownloadSource.DEFAULT: "Qwen/Qwen-14B", DownloadSource.MODELSCOPE: "qwen/Qwen-14B"},
|
||||
"Qwen-72B": {DownloadSource.DEFAULT: "Qwen/Qwen-72B", DownloadSource.MODELSCOPE: "qwen/Qwen-72B"},
|
||||
"Qwen-1.8B": {
|
||||
DownloadSource.DEFAULT: "Qwen/Qwen-1_8B",
|
||||
DownloadSource.MODELSCOPE: "qwen/Qwen-1_8B",
|
||||
},
|
||||
"Qwen-7B": {
|
||||
DownloadSource.DEFAULT: "Qwen/Qwen-7B",
|
||||
DownloadSource.MODELSCOPE: "qwen/Qwen-7B",
|
||||
},
|
||||
"Qwen-14B": {
|
||||
DownloadSource.DEFAULT: "Qwen/Qwen-14B",
|
||||
DownloadSource.MODELSCOPE: "qwen/Qwen-14B",
|
||||
},
|
||||
"Qwen-72B": {
|
||||
DownloadSource.DEFAULT: "Qwen/Qwen-72B",
|
||||
DownloadSource.MODELSCOPE: "qwen/Qwen-72B",
|
||||
},
|
||||
"Qwen-1.8B-Chat": {
|
||||
DownloadSource.DEFAULT: "Qwen/Qwen-1_8B-Chat",
|
||||
DownloadSource.MODELSCOPE: "qwen/Qwen-1_8B-Chat",
|
||||
},
|
||||
"Qwen-7B-Chat": {DownloadSource.DEFAULT: "Qwen/Qwen-7B-Chat", DownloadSource.MODELSCOPE: "qwen/Qwen-7B-Chat"},
|
||||
"Qwen-7B-Chat": {
|
||||
DownloadSource.DEFAULT: "Qwen/Qwen-7B-Chat",
|
||||
DownloadSource.MODELSCOPE: "qwen/Qwen-7B-Chat",
|
||||
},
|
||||
"Qwen-14B-Chat": {
|
||||
DownloadSource.DEFAULT: "Qwen/Qwen-14B-Chat",
|
||||
DownloadSource.MODELSCOPE: "qwen/Qwen-14B-Chat",
|
||||
@@ -530,7 +618,112 @@ register_model_group(
|
||||
|
||||
register_model_group(
|
||||
models={
|
||||
"SOLAR-10.7B": {DownloadSource.DEFAULT: "upstage/SOLAR-10.7B-v1.0"},
|
||||
"Qwen1.5-0.5B": {
|
||||
DownloadSource.DEFAULT: "Qwen/Qwen1.5-0.5B",
|
||||
DownloadSource.MODELSCOPE: "qwen/Qwen1.5-0.5B",
|
||||
},
|
||||
"Qwen1.5-1.8B": {
|
||||
DownloadSource.DEFAULT: "Qwen/Qwen1.5-1.8B",
|
||||
DownloadSource.MODELSCOPE: "qwen/Qwen1.5-1.8B",
|
||||
},
|
||||
"Qwen1.5-4B": {
|
||||
DownloadSource.DEFAULT: "Qwen/Qwen1.5-4B",
|
||||
DownloadSource.MODELSCOPE: "qwen/Qwen1.5-4B",
|
||||
},
|
||||
"Qwen1.5-7B": {
|
||||
DownloadSource.DEFAULT: "Qwen/Qwen1.5-7B",
|
||||
DownloadSource.MODELSCOPE: "qwen/Qwen1.5-7B",
|
||||
},
|
||||
"Qwen1.5-14B": {
|
||||
DownloadSource.DEFAULT: "Qwen/Qwen1.5-14B",
|
||||
DownloadSource.MODELSCOPE: "qwen/Qwen1.5-14B",
|
||||
},
|
||||
"Qwen1.5-72B": {
|
||||
DownloadSource.DEFAULT: "Qwen/Qwen1.5-72B",
|
||||
DownloadSource.MODELSCOPE: "qwen/Qwen1.5-72B",
|
||||
},
|
||||
"Qwen1.5-0.5B-Chat": {
|
||||
DownloadSource.DEFAULT: "Qwen/Qwen1.5-0.5B-Chat",
|
||||
DownloadSource.MODELSCOPE: "qwen/Qwen1.5-0.5B-Chat",
|
||||
},
|
||||
"Qwen1.5-1.8B-Chat": {
|
||||
DownloadSource.DEFAULT: "Qwen/Qwen1.5-1.8B-Chat",
|
||||
DownloadSource.MODELSCOPE: "qwen/Qwen1.5-1.8B-Chat",
|
||||
},
|
||||
"Qwen1.5-4B-Chat": {
|
||||
DownloadSource.DEFAULT: "Qwen/Qwen1.5-4B-Chat",
|
||||
DownloadSource.MODELSCOPE: "qwen/Qwen1.5-4B-Chat",
|
||||
},
|
||||
"Qwen1.5-7B-Chat": {
|
||||
DownloadSource.DEFAULT: "Qwen/Qwen1.5-7B-Chat",
|
||||
DownloadSource.MODELSCOPE: "qwen/Qwen1.5-7B-Chat",
|
||||
},
|
||||
"Qwen1.5-14B-Chat": {
|
||||
DownloadSource.DEFAULT: "Qwen/Qwen1.5-14B-Chat",
|
||||
DownloadSource.MODELSCOPE: "qwen/Qwen1.5-14B-Chat",
|
||||
},
|
||||
"Qwen1.5-72B-Chat": {
|
||||
DownloadSource.DEFAULT: "Qwen/Qwen1.5-72B-Chat",
|
||||
DownloadSource.MODELSCOPE: "qwen/Qwen1.5-72B-Chat",
|
||||
},
|
||||
"Qwen1.5-0.5B-int8-Chat": {
|
||||
DownloadSource.DEFAULT: "Qwen/Qwen1.5-0.5B-Chat-GPTQ-Int8",
|
||||
DownloadSource.MODELSCOPE: "qwen/Qwen1.5-0.5B-Chat-GPTQ-Int8",
|
||||
},
|
||||
"Qwen1.5-0.5B-int4-Chat": {
|
||||
DownloadSource.DEFAULT: "Qwen/Qwen1.5-0.5B-Chat-GPTQ-Int4",
|
||||
DownloadSource.MODELSCOPE: "qwen/Qwen1.5-0.5B-Chat-GPTQ-Int4",
|
||||
},
|
||||
"Qwen1.5-1.8B-int8-Chat": {
|
||||
DownloadSource.DEFAULT: "Qwen/Qwen1.5-1.8B-Chat-GPTQ-Int8",
|
||||
DownloadSource.MODELSCOPE: "qwen/Qwen1.5-1.8B-Chat-GPTQ-Int8",
|
||||
},
|
||||
"Qwen1.5-1.8B-int4-Chat": {
|
||||
DownloadSource.DEFAULT: "Qwen/Qwen1.5-1.8B-Chat-GPTQ-Int4",
|
||||
DownloadSource.MODELSCOPE: "qwen/Qwen1.5-1.8B-Chat-GPTQ-Int4",
|
||||
},
|
||||
"Qwen1.5-4B-int8-Chat": {
|
||||
DownloadSource.DEFAULT: "Qwen/Qwen1.5-4B-Chat-GPTQ-Int8",
|
||||
DownloadSource.MODELSCOPE: "qwen/Qwen1.5-4B-Chat-GPTQ-Int8",
|
||||
},
|
||||
"Qwen1.5-4B-int4-Chat": {
|
||||
DownloadSource.DEFAULT: "Qwen/Qwen1.5-4B-Chat-GPTQ-Int4",
|
||||
DownloadSource.MODELSCOPE: "qwen/Qwen1.5-4B-Chat-GPTQ-Int4",
|
||||
},
|
||||
"Qwen1.5-7B-int8-Chat": {
|
||||
DownloadSource.DEFAULT: "Qwen/Qwen1.5-7B-Chat-GPTQ-Int8",
|
||||
DownloadSource.MODELSCOPE: "qwen/Qwen1.5-7B-Chat-GPTQ-Int8",
|
||||
},
|
||||
"Qwen1.5-7B-int4-Chat": {
|
||||
DownloadSource.DEFAULT: "Qwen/Qwen1.5-7B-Chat-GPTQ-Int4",
|
||||
DownloadSource.MODELSCOPE: "qwen/Qwen1.5-7B-Chat-GPTQ-Int4",
|
||||
},
|
||||
"Qwen1.5-14B-int8-Chat": {
|
||||
DownloadSource.DEFAULT: "Qwen/Qwen1.5-14B-Chat-GPTQ-Int8",
|
||||
DownloadSource.MODELSCOPE: "qwen/Qwen1.5-14B-Chat-GPTQ-Int8",
|
||||
},
|
||||
"Qwen1.5-14B-int4-Chat": {
|
||||
DownloadSource.DEFAULT: "Qwen/Qwen1.5-14B-Chat-GPTQ-Int4",
|
||||
DownloadSource.MODELSCOPE: "qwen/Qwen1.5-14B-Chat-GPTQ-Int4",
|
||||
},
|
||||
"Qwen1.5-72B-int8-Chat": {
|
||||
DownloadSource.DEFAULT: "Qwen/Qwen1.5-72B-Chat-GPTQ-Int8",
|
||||
DownloadSource.MODELSCOPE: "qwen/Qwen1.5-72B-Chat-GPTQ-Int8",
|
||||
},
|
||||
"Qwen1.5-72B-int4-Chat": {
|
||||
DownloadSource.DEFAULT: "Qwen/Qwen1.5-72B-Chat-GPTQ-Int4",
|
||||
DownloadSource.MODELSCOPE: "qwen/Qwen1.5-72B-Chat-GPTQ-Int4",
|
||||
},
|
||||
},
|
||||
template="qwen",
|
||||
)
|
||||
|
||||
|
||||
register_model_group(
|
||||
models={
|
||||
"SOLAR-10.7B": {
|
||||
DownloadSource.DEFAULT: "upstage/SOLAR-10.7B-v1.0",
|
||||
},
|
||||
"SOLAR-10.7B-Chat": {
|
||||
DownloadSource.DEFAULT: "upstage/SOLAR-10.7B-Instruct-v1.0",
|
||||
DownloadSource.MODELSCOPE: "AI-ModelScope/SOLAR-10.7B-Instruct-v1.0",
|
||||
@@ -567,10 +760,18 @@ register_model_group(
|
||||
|
||||
register_model_group(
|
||||
models={
|
||||
"XuanYuan-70B": {DownloadSource.DEFAULT: "Duxiaoman-DI/XuanYuan-70B"},
|
||||
"XuanYuan-70B-Chat": {DownloadSource.DEFAULT: "Duxiaoman-DI/XuanYuan-70B-Chat"},
|
||||
"XuanYuan-70B-int8-Chat": {DownloadSource.DEFAULT: "Duxiaoman-DI/XuanYuan-70B-Chat-8bit"},
|
||||
"XuanYuan-70B-int4-Chat": {DownloadSource.DEFAULT: "Duxiaoman-DI/XuanYuan-70B-Chat-4bit"},
|
||||
"XuanYuan-70B": {
|
||||
DownloadSource.DEFAULT: "Duxiaoman-DI/XuanYuan-70B",
|
||||
},
|
||||
"XuanYuan-70B-Chat": {
|
||||
DownloadSource.DEFAULT: "Duxiaoman-DI/XuanYuan-70B-Chat",
|
||||
},
|
||||
"XuanYuan-70B-int8-Chat": {
|
||||
DownloadSource.DEFAULT: "Duxiaoman-DI/XuanYuan-70B-Chat-8bit",
|
||||
},
|
||||
"XuanYuan-70B-int4-Chat": {
|
||||
DownloadSource.DEFAULT: "Duxiaoman-DI/XuanYuan-70B-Chat-4bit",
|
||||
},
|
||||
},
|
||||
template="xuanyuan",
|
||||
)
|
||||
@@ -578,9 +779,18 @@ register_model_group(
|
||||
|
||||
register_model_group(
|
||||
models={
|
||||
"XVERSE-7B": {DownloadSource.DEFAULT: "xverse/XVERSE-7B", DownloadSource.MODELSCOPE: "xverse/XVERSE-7B"},
|
||||
"XVERSE-13B": {DownloadSource.DEFAULT: "xverse/XVERSE-13B", DownloadSource.MODELSCOPE: "xverse/XVERSE-13B"},
|
||||
"XVERSE-65B": {DownloadSource.DEFAULT: "xverse/XVERSE-65B", DownloadSource.MODELSCOPE: "xverse/XVERSE-65B"},
|
||||
"XVERSE-7B": {
|
||||
DownloadSource.DEFAULT: "xverse/XVERSE-7B",
|
||||
DownloadSource.MODELSCOPE: "xverse/XVERSE-7B",
|
||||
},
|
||||
"XVERSE-13B": {
|
||||
DownloadSource.DEFAULT: "xverse/XVERSE-13B",
|
||||
DownloadSource.MODELSCOPE: "xverse/XVERSE-13B",
|
||||
},
|
||||
"XVERSE-65B": {
|
||||
DownloadSource.DEFAULT: "xverse/XVERSE-65B",
|
||||
DownloadSource.MODELSCOPE: "xverse/XVERSE-65B",
|
||||
},
|
||||
"XVERSE-65B-2": {
|
||||
DownloadSource.DEFAULT: "xverse/XVERSE-65B-2",
|
||||
DownloadSource.MODELSCOPE: "xverse/XVERSE-65B-2",
|
||||
@@ -619,18 +829,38 @@ register_model_group(
|
||||
|
||||
register_model_group(
|
||||
models={
|
||||
"Yi-6B": {DownloadSource.DEFAULT: "01-ai/Yi-6B", DownloadSource.MODELSCOPE: "01ai/Yi-6B"},
|
||||
"Yi-34B": {DownloadSource.DEFAULT: "01-ai/Yi-34B", DownloadSource.MODELSCOPE: "01ai/Yi-34B"},
|
||||
"Yi-6B-Chat": {DownloadSource.DEFAULT: "01-ai/Yi-6B-Chat", DownloadSource.MODELSCOPE: "01ai/Yi-6B-Chat"},
|
||||
"Yi-34B-Chat": {DownloadSource.DEFAULT: "01-ai/Yi-34B-Chat", DownloadSource.MODELSCOPE: "01ai/Yi-34B-Chat"},
|
||||
"Yi-6B": {
|
||||
DownloadSource.DEFAULT: "01-ai/Yi-6B",
|
||||
DownloadSource.MODELSCOPE: "01ai/Yi-6B",
|
||||
},
|
||||
"Yi-34B": {
|
||||
DownloadSource.DEFAULT: "01-ai/Yi-34B",
|
||||
DownloadSource.MODELSCOPE: "01ai/Yi-34B",
|
||||
},
|
||||
"Yi-6B-Chat": {
|
||||
DownloadSource.DEFAULT: "01-ai/Yi-6B-Chat",
|
||||
DownloadSource.MODELSCOPE: "01ai/Yi-6B-Chat",
|
||||
},
|
||||
"Yi-34B-Chat": {
|
||||
DownloadSource.DEFAULT: "01-ai/Yi-34B-Chat",
|
||||
DownloadSource.MODELSCOPE: "01ai/Yi-34B-Chat",
|
||||
},
|
||||
"Yi-6B-int8-Chat": {
|
||||
DownloadSource.DEFAULT: "01-ai/Yi-6B-Chat-8bits",
|
||||
DownloadSource.MODELSCOPE: "01ai/Yi-6B-Chat-8bits",
|
||||
},
|
||||
"Yi-6B-int4-Chat": {
|
||||
DownloadSource.DEFAULT: "01-ai/Yi-6B-Chat-4bits",
|
||||
DownloadSource.MODELSCOPE: "01ai/Yi-6B-Chat-4bits",
|
||||
},
|
||||
"Yi-34B-int8-Chat": {
|
||||
DownloadSource.DEFAULT: "01-ai/Yi-34B-Chat-8bits",
|
||||
DownloadSource.MODELSCOPE: "01ai/Yi-34B-Chat-8bits",
|
||||
},
|
||||
"Yi-34B-int4-Chat": {
|
||||
DownloadSource.DEFAULT: "01-ai/Yi-34B-Chat-4bits",
|
||||
DownloadSource.MODELSCOPE: "01ai/Yi-34B-Chat-4bits",
|
||||
},
|
||||
},
|
||||
template="yi",
|
||||
)
|
||||
@@ -668,3 +898,18 @@ register_model_group(
|
||||
},
|
||||
template="zephyr",
|
||||
)
|
||||
|
||||
|
||||
register_model_group(
|
||||
models={
|
||||
"Atom-7B": {
|
||||
DownloadSource.DEFAULT: "FlagAlpha/Atom-7B",
|
||||
DownloadSource.MODELSCOPE: "FlagAlpha/Atom-7B",
|
||||
},
|
||||
"Atom-7B-Chat": {
|
||||
DownloadSource.DEFAULT: "FlagAlpha/Atom-7B-Chat",
|
||||
DownloadSource.MODELSCOPE: "FlagAlpha/Atom-7B-Chat",
|
||||
},
|
||||
},
|
||||
template="atom",
|
||||
)
|
||||
|
||||
@@ -10,6 +10,7 @@ from transformers.utils import (
    WEIGHTS_NAME,
    is_torch_bf16_gpu_available,
    is_torch_cuda_available,
    is_torch_mps_available,
    is_torch_npu_available,
    is_torch_xpu_available,
)

@@ -133,6 +134,8 @@ def get_current_device() -> torch.device:
        device = "xpu:{}".format(os.environ.get("LOCAL_RANK", "0"))
    elif is_torch_npu_available():
        device = "npu:{}".format(os.environ.get("LOCAL_RANK", "0"))
    elif is_torch_mps_available():
        device = "mps:{}".format(os.environ.get("LOCAL_RANK", "0"))
    elif is_torch_cuda_available():
        device = "cuda:{}".format(os.environ.get("LOCAL_RANK", "0"))
    else:
@@ -2,11 +2,11 @@ import importlib.metadata
import importlib.util


def is_package_available(name: str) -> bool:
def _is_package_available(name: str) -> bool:
    return importlib.util.find_spec(name) is not None


def get_package_version(name: str) -> str:
def _get_package_version(name: str) -> str:
    try:
        return importlib.metadata.version(name)
    except Exception:

@@ -14,36 +14,40 @@ def get_package_version(name: str) -> str:


def is_fastapi_availble():
    return is_package_available("fastapi")
    return _is_package_available("fastapi")


def is_flash_attn2_available():
    return is_package_available("flash_attn") and get_package_version("flash_attn").startswith("2")
    return _is_package_available("flash_attn") and _get_package_version("flash_attn").startswith("2")


def is_jieba_available():
    return is_package_available("jieba")
    return _is_package_available("jieba")


def is_matplotlib_available():
    return is_package_available("matplotlib")
    return _is_package_available("matplotlib")


def is_nltk_available():
    return is_package_available("nltk")
    return _is_package_available("nltk")


def is_requests_available():
    return is_package_available("requests")
    return _is_package_available("requests")


def is_rouge_available():
    return is_package_available("rouge_chinese")
    return _is_package_available("rouge_chinese")


def is_starlette_available():
    return is_package_available("sse_starlette")
    return _is_package_available("sse_starlette")


def is_unsloth_available():
    return _is_package_available("unsloth")


def is_uvicorn_available():
    return is_package_available("uvicorn")
    return _is_package_available("uvicorn")
38
src/llmtuner/extras/patches/mixtral_patch.py
Normal file
@@ -0,0 +1,38 @@
import torch
import torch.nn.functional as F
from transformers.models.mixtral.modeling_mixtral import MixtralBLockSparseTop2MLP, MixtralSparseMoeBlock


def mlp_forward(self: "MixtralBLockSparseTop2MLP", hidden_states: torch.Tensor) -> torch.Tensor:
    current_hidden_states = self.act_fn(self.w1(hidden_states)) * self.w3(hidden_states)
    current_hidden_states = self.w2(current_hidden_states)
    return current_hidden_states


# Modified from: https://huggingface.co/deepseek-ai/deepseek-moe-16b-base/blob/main/modeling_deepseek.py
def moe_forward(self: "MixtralSparseMoeBlock", hidden_states: torch.Tensor) -> torch.Tensor:
    batch_size, sequence_length, hidden_dim = hidden_states.shape
    hidden_states = hidden_states.view(-1, hidden_dim)
    # router_logits: (batch * sequence_length, n_experts)
    router_logits = self.gate(hidden_states)

    routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
    topk_weight, topk_idx = torch.topk(routing_weights, self.top_k, dim=-1, sorted=False)
    topk_weight /= topk_weight.sum(dim=-1, keepdim=True)
    # we cast back to the input dtype
    topk_weight = topk_weight.to(hidden_states.dtype)

    hidden_states = hidden_states.repeat_interleave(self.top_k, dim=0)
    y = torch.empty_like(hidden_states)
    flat_topk_idx = topk_idx.view(-1)
    for i in range(self.num_experts):
        expert = self.experts[i]
        y[flat_topk_idx == i] = expert(hidden_states[flat_topk_idx == i])
    y = (y.view(*topk_weight.shape, -1) * topk_weight.unsqueeze(-1)).sum(dim=1)
    final_hidden_states = y.reshape(batch_size, sequence_length, hidden_dim)
    return final_hidden_states, router_logits


def patch_mixtral_replace_moe_impl() -> None:
    MixtralBLockSparseTop2MLP.forward = mlp_forward
    MixtralSparseMoeBlock.forward = moe_forward
@@ -7,31 +7,42 @@ class DataArguments:
|
||||
r"""
|
||||
Arguments pertaining to what data we are going to input our model for training and evaluation.
|
||||
"""
|
||||
|
||||
template: Optional[str] = field(
|
||||
default=None, metadata={"help": "Which template to use for constructing prompts in training and inference."}
|
||||
default=None,
|
||||
metadata={"help": "Which template to use for constructing prompts in training and inference."},
|
||||
)
|
||||
dataset: Optional[str] = field(
|
||||
default=None,
|
||||
metadata={"help": "The name of provided dataset(s) to use. Use commas to separate multiple datasets."},
|
||||
)
|
||||
dataset_dir: Optional[str] = field(
|
||||
default="data", metadata={"help": "Path to the folder containing the datasets."}
|
||||
default="data",
|
||||
metadata={"help": "Path to the folder containing the datasets."},
|
||||
)
|
||||
split: Optional[str] = field(
|
||||
default="train", metadata={"help": "Which dataset split to use for training and evaluation."}
|
||||
default="train",
|
||||
metadata={"help": "Which dataset split to use for training and evaluation."},
|
||||
)
|
||||
cutoff_len: Optional[int] = field(
|
||||
default=1024, metadata={"help": "The maximum length of the model inputs after tokenization."}
|
||||
default=1024,
|
||||
metadata={"help": "The cutoff length of the model inputs after tokenization."},
|
||||
)
|
||||
reserved_label_len: Optional[int] = field(
|
||||
default=1, metadata={"help": "The maximum length reserved for label after tokenization."}
|
||||
default=1,
|
||||
metadata={"help": "The minimum cutoff length reserved for label after tokenization."},
|
||||
)
|
||||
train_on_prompt: Optional[bool] = field(
|
||||
default=False, metadata={"help": "Whether to disable the mask on the prompt or not."}
|
||||
default=False,
|
||||
metadata={"help": "Whether to disable the mask on the prompt or not."},
|
||||
)
|
||||
streaming: Optional[bool] = field(
|
||||
default=False,
|
||||
metadata={"help": "Enable dataset streaming."},
|
||||
)
|
||||
streaming: Optional[bool] = field(default=False, metadata={"help": "Enable dataset streaming."})
|
||||
buffer_size: Optional[int] = field(
|
||||
default=16384, metadata={"help": "Size of the buffer to randomly sample examples from in dataset streaming."}
|
||||
default=16384,
|
||||
metadata={"help": "Size of the buffer to randomly sample examples from in dataset streaming."},
|
||||
)
|
||||
mix_strategy: Optional[Literal["concat", "interleave_under", "interleave_over"]] = field(
|
||||
default="concat",
|
||||
@@ -42,13 +53,16 @@ class DataArguments:
|
||||
metadata={"help": "Probabilities to sample data from datasets. Use commas to separate multiple datasets."},
|
||||
)
|
||||
overwrite_cache: Optional[bool] = field(
|
||||
default=False, metadata={"help": "Overwrite the cached training and evaluation sets."}
|
||||
default=False,
|
||||
metadata={"help": "Overwrite the cached training and evaluation sets."},
|
||||
)
|
||||
preprocessing_num_workers: Optional[int] = field(
|
||||
default=None, metadata={"help": "The number of processes to use for the preprocessing."}
|
||||
default=None,
|
||||
metadata={"help": "The number of processes to use for the preprocessing."},
|
||||
)
|
||||
max_samples: Optional[int] = field(
|
||||
default=None, metadata={"help": "For debugging purposes, truncate the number of examples for each dataset."}
|
||||
default=None,
|
||||
metadata={"help": "For debugging purposes, truncate the number of examples for each dataset."},
|
||||
)
|
||||
eval_num_beams: Optional[int] = field(
|
||||
default=None,
|
||||
@@ -57,17 +71,20 @@ class DataArguments:
|
||||
ignore_pad_token_for_loss: Optional[bool] = field(
|
||||
default=True,
|
||||
metadata={
|
||||
"help": "Whether to ignore the tokens corresponding to padded labels in the loss computation or not."
|
||||
"help": "Whether or not to ignore the tokens corresponding to padded labels in the loss computation."
|
||||
},
|
||||
)
|
||||
val_size: Optional[float] = field(
|
||||
default=0, metadata={"help": "Size of the development set, should be an integer or a float in range `[0,1)`."}
|
||||
default=0,
|
||||
metadata={"help": "Size of the development set, should be an integer or a float in range `[0,1)`."},
|
||||
)
|
||||
sft_packing: Optional[bool] = field(
|
||||
default=False, metadata={"help": "Packing the questions and answers in the supervised fine-tuning stage."}
|
||||
default=False,
|
||||
metadata={"help": "Packing the questions and answers in the supervised fine-tuning stage."},
|
||||
)
|
||||
cache_path: Optional[str] = field(
|
||||
default=None, metadata={"help": "Path to save or load the preprocessed datasets."}
|
||||
default=None,
|
||||
metadata={"help": "Path to save or load the preprocessed datasets."},
|
||||
)
|
||||
|
||||
def __post_init__(self):
|
||||
|
||||
@@ -10,15 +10,34 @@ class EvaluationArguments:
|
||||
r"""
|
||||
Arguments pertaining to specify the evaluation parameters.
|
||||
"""
|
||||
task: str = field(metadata={"help": "Name of the evaluation task."})
|
||||
task_dir: Optional[str] = field(
|
||||
default="evaluation", metadata={"help": "Path to the folder containing the evaluation datasets."}
|
||||
|
||||
task: str = field(
|
||||
metadata={"help": "Name of the evaluation task."},
|
||||
)
|
||||
task_dir: Optional[str] = field(
|
||||
default="evaluation",
|
||||
metadata={"help": "Path to the folder containing the evaluation datasets."},
|
||||
)
|
||||
batch_size: Optional[int] = field(
|
||||
default=4,
|
||||
metadata={"help": "The batch size per GPU for evaluation."},
|
||||
)
|
||||
seed: Optional[int] = field(
|
||||
default=42,
|
||||
metadata={"help": "Random seed to be used with data loaders."},
|
||||
)
|
||||
lang: Optional[Literal["en", "zh"]] = field(
|
||||
default="en",
|
||||
metadata={"help": "Language used at evaluation."},
|
||||
)
|
||||
n_shot: Optional[int] = field(
|
||||
default=5,
|
||||
metadata={"help": "Number of exemplars for few-shot learning."},
|
||||
)
|
||||
save_dir: Optional[str] = field(
|
||||
default=None,
|
||||
metadata={"help": "Path to save the evaluation results."},
|
||||
)
|
||||
batch_size: Optional[int] = field(default=4, metadata={"help": "The batch size per GPU for evaluation."})
|
||||
seed: Optional[int] = field(default=42, metadata={"help": "Random seed to be used with data loaders."})
|
||||
lang: Optional[Literal["en", "zh"]] = field(default="en", metadata={"help": "Language used at evaluation."})
|
||||
n_shot: Optional[int] = field(default=5, metadata={"help": "Number of exemplars for few-shot learning."})
|
||||
save_dir: Optional[str] = field(default=None, metadata={"help": "Path to save the evaluation results."})
|
||||
download_mode: Optional[DownloadMode] = field(
|
||||
default=DownloadMode.REUSE_DATASET_IF_EXISTS,
|
||||
metadata={"help": "Download mode used for the evaluation datasets."},
|
||||
|
||||
@@ -8,20 +8,23 @@ class FreezeArguments:
|
||||
r"""
|
||||
Arguments pertaining to the freeze (partial-parameter) training.
|
||||
"""
|
||||
|
||||
name_module_trainable: Optional[str] = field(
|
||||
default="mlp",
|
||||
default=None,
|
||||
metadata={
|
||||
"help": 'Name of trainable modules for partial-parameter (freeze) fine-tuning. \
|
||||
Use commas to separate multiple modules. \
|
||||
LLaMA choices: ["mlp", "self_attn"], \
|
||||
BLOOM & Falcon & ChatGLM choices: ["mlp", "self_attention"], \
|
||||
Qwen choices: ["mlp", "attn"], \
|
||||
Phi choices: ["mlp", "mixer"], \
|
||||
Others choices: the same as LLaMA.'
|
||||
"help": """Name of trainable modules for partial-parameter (freeze) fine-tuning. \
|
||||
Use commas to separate multiple modules. \
|
||||
Use "all" to specify all the available modules. \
|
||||
LLaMA choices: ["mlp", "self_attn"], \
|
||||
BLOOM & Falcon & ChatGLM choices: ["mlp", "self_attention"], \
|
||||
Qwen choices: ["mlp", "attn"], \
|
||||
InternLM2 choices: ["feed_forward", "attention"], \
|
||||
Others choices: the same as LLaMA."""
|
||||
},
|
||||
)
|
||||
num_layer_trainable: Optional[int] = field(
|
||||
default=3, metadata={"help": "The number of trainable layers for partial-parameter (freeze) fine-tuning."}
|
||||
default=3,
|
||||
metadata={"help": "The number of trainable layers for partial-parameter (freeze) fine-tuning."},
|
||||
)
|
||||
|
||||
|
||||
@@ -30,6 +33,7 @@ class LoraArguments:
|
||||
r"""
|
||||
Arguments pertaining to the LoRA training.
|
||||
"""
|
||||
|
||||
additional_target: Optional[str] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
@@ -37,27 +41,45 @@ class LoraArguments:
|
||||
},
|
||||
)
|
||||
lora_alpha: Optional[int] = field(
|
||||
default=None, metadata={"help": "The scale factor for LoRA fine-tuning (default: lora_rank * 2)."}
|
||||
default=None,
|
||||
metadata={"help": "The scale factor for LoRA fine-tuning (default: lora_rank * 2)."},
|
||||
)
|
||||
lora_dropout: Optional[float] = field(
|
||||
default=0.0,
|
||||
metadata={"help": "Dropout rate for the LoRA fine-tuning."},
|
||||
)
|
||||
lora_rank: Optional[int] = field(
|
||||
default=8,
|
||||
metadata={"help": "The intrinsic dimension for LoRA fine-tuning."},
|
||||
)
|
||||
lora_dropout: Optional[float] = field(default=0.0, metadata={"help": "Dropout rate for the LoRA fine-tuning."})
|
||||
lora_rank: Optional[int] = field(default=8, metadata={"help": "The intrinsic dimension for LoRA fine-tuning."})
|
||||
lora_target: Optional[str] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": 'Name(s) of target modules to apply LoRA. Use commas to separate multiple modules. \
|
||||
LLaMA choices: ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"], \
|
||||
BLOOM & Falcon & ChatGLM choices: ["query_key_value", "dense", "dense_h_to_4h", "dense_4h_to_h"], \
|
||||
Baichuan choices: ["W_pack", "o_proj", "gate_proj", "up_proj", "down_proj"], \
|
||||
Qwen choices: ["c_attn", "attn.c_proj", "w1", "w2", "mlp.c_proj"], \
|
||||
Phi choices: ["Wqkv", "out_proj", "fc1", "fc2"], \
|
||||
Others choices: the same as LLaMA.'
|
||||
"help": """Name(s) of target modules to apply LoRA. \
|
||||
Use commas to separate multiple modules. \
|
||||
Use "all" to specify all the available modules. \
|
||||
LLaMA choices: ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"], \
|
||||
BLOOM & Falcon & ChatGLM choices: ["query_key_value", "dense", "dense_h_to_4h", "dense_4h_to_h"], \
|
||||
Baichuan choices: ["W_pack", "o_proj", "gate_proj", "up_proj", "down_proj"], \
|
||||
Qwen choices: ["c_attn", "attn.c_proj", "w1", "w2", "mlp.c_proj"], \
|
||||
InternLM2 choices: ["wqkv", "wo", "w1", "w2", "w3"], \
|
||||
Others choices: the same as LLaMA."""
|
||||
},
|
||||
)
|
||||
lora_bf16_mode: Optional[bool] = field(
|
||||
default=False, metadata={"help": "Whether or not to train lora adapters in bf16 precision."}
|
||||
default=False,
|
||||
metadata={"help": "Whether or not to train lora adapters in bf16 precision."},
|
||||
)
|
||||
use_rslora: Optional[bool] = field(
|
||||
default=False,
|
||||
metadata={"help": "Whether or not to use the rank stabilization scaling factor for LoRA layer."},
|
||||
)
|
||||
use_dora: Optional[bool] = field(
|
||||
default=False, metadata={"help": "Whether or not to use the weight-decomposed lora method (DoRA)."}
|
||||
)
|
||||
create_new_adapter: Optional[bool] = field(
|
||||
default=False, metadata={"help": "Whether or not to create a new adapter with randomly initialized weight."}
|
||||
default=False,
|
||||
metadata={"help": "Whether or not to create a new adapter with randomly initialized weight."},
|
||||
)
|
||||
|
||||
|
||||
@@ -66,49 +88,66 @@ class RLHFArguments:
|
||||
r"""
|
||||
Arguments pertaining to the PPO and DPO training.
|
||||
"""
|
||||
dpo_beta: Optional[float] = field(default=0.1, metadata={"help": "The beta parameter for the DPO loss."})
|
||||
dpo_loss: Optional[Literal["sigmoid", "hinge", "ipo", "kto"]] = field(
|
||||
default="sigmoid", metadata={"help": "The type of DPO loss to use."}
|
||||
|
||||
dpo_beta: Optional[float] = field(
|
||||
default=0.1,
|
||||
metadata={"help": "The beta parameter for the DPO loss."},
|
||||
)
|
||||
dpo_loss: Optional[Literal["sigmoid", "hinge", "ipo", "kto_pair"]] = field(
|
||||
default="sigmoid",
|
||||
metadata={"help": "The type of DPO loss to use."},
|
||||
)
|
||||
dpo_ftx: Optional[float] = field(
|
||||
default=0, metadata={"help": "The supervised fine-tuning loss coefficient in DPO training."}
|
||||
default=0,
|
||||
metadata={"help": "The supervised fine-tuning loss coefficient in DPO training."},
|
||||
)
|
||||
ppo_buffer_size: Optional[int] = field(
|
||||
default=1,
|
||||
metadata={"help": "The number of mini-batches to make experience buffer in a PPO optimization step."},
|
||||
)
|
||||
ppo_epochs: Optional[int] = field(
|
||||
default=4, metadata={"help": "The number of epochs to perform in a PPO optimization step."}
|
||||
default=4,
|
||||
metadata={"help": "The number of epochs to perform in a PPO optimization step."},
|
||||
)
|
||||
ppo_logger: Optional[str] = field(
|
||||
default=None, metadata={"help": 'Log with either "wandb" or "tensorboard" in PPO training.'}
|
||||
default=None,
|
||||
metadata={"help": 'Log with either "wandb" or "tensorboard" in PPO training.'},
|
||||
)
|
||||
ppo_score_norm: Optional[bool] = field(
|
||||
default=False, metadata={"help": "Use score normalization in PPO training."}
|
||||
default=False,
|
||||
metadata={"help": "Use score normalization in PPO training."},
|
||||
)
|
||||
ppo_target: Optional[float] = field(
|
||||
default=6.0, metadata={"help": "Target KL value for adaptive KL control in PPO training."}
|
||||
default=6.0,
|
||||
metadata={"help": "Target KL value for adaptive KL control in PPO training."},
|
||||
)
|
||||
ppo_whiten_rewards: Optional[bool] = field(
|
||||
default=False, metadata={"help": "Whiten the rewards before compute advantages in PPO training."}
|
||||
default=False,
|
||||
metadata={"help": "Whiten the rewards before compute advantages in PPO training."},
|
||||
)
|
||||
ref_model: Optional[str] = field(
|
||||
default=None, metadata={"help": "Path to the reference model used for the PPO or DPO training."}
|
||||
default=None,
|
||||
metadata={"help": "Path to the reference model used for the PPO or DPO training."},
|
||||
)
|
||||
ref_model_adapters: Optional[str] = field(
|
||||
default=None, metadata={"help": "Path to the adapters of the reference model."}
|
||||
default=None,
|
||||
metadata={"help": "Path to the adapters of the reference model."},
|
||||
)
|
||||
ref_model_quantization_bit: Optional[int] = field(
|
||||
default=None, metadata={"help": "The number of bits to quantize the reference model."}
|
||||
default=None,
|
||||
metadata={"help": "The number of bits to quantize the reference model."},
|
||||
)
|
||||
reward_model: Optional[str] = field(
|
||||
default=None, metadata={"help": "Path to the reward model used for the PPO training."}
|
||||
default=None,
|
||||
metadata={"help": "Path to the reward model used for the PPO training."},
|
||||
)
|
||||
reward_model_adapters: Optional[str] = field(
|
||||
default=None, metadata={"help": "Path to the adapters of the reward model."}
|
||||
default=None,
|
||||
metadata={"help": "Path to the adapters of the reward model."},
|
||||
)
|
||||
reward_model_quantization_bit: Optional[int] = field(
|
||||
default=None, metadata={"help": "The number of bits to quantize the reward model."}
|
||||
default=None,
|
||||
metadata={"help": "The number of bits to quantize the reward model."},
|
||||
)
|
||||
reward_model_type: Optional[Literal["lora", "full", "api"]] = field(
|
||||
default="lora",
|
||||
@@ -121,14 +160,26 @@ class FinetuningArguments(FreezeArguments, LoraArguments, RLHFArguments):
|
||||
r"""
|
||||
Arguments pertaining to which techniques we are going to fine-tune with.
|
||||
"""
|
||||
|
||||
stage: Optional[Literal["pt", "sft", "rm", "ppo", "dpo"]] = field(
|
||||
default="sft", metadata={"help": "Which stage will be performed in training."}
|
||||
default="sft",
|
||||
metadata={"help": "Which stage will be performed in training."},
|
||||
)
|
||||
finetuning_type: Optional[Literal["lora", "freeze", "full"]] = field(
|
||||
default="lora", metadata={"help": "Which fine-tuning method to use."}
|
||||
default="lora",
|
||||
metadata={"help": "Which fine-tuning method to use."},
|
||||
)
|
||||
use_llama_pro: Optional[bool] = field(
|
||||
default=False,
|
||||
metadata={"help": "Whether or not to make only the parameters in the expanded blocks trainable."},
|
||||
)
|
||||
disable_version_checking: Optional[bool] = field(
|
||||
default=False,
|
||||
metadata={"help": "Whether or not to disable version checking."},
|
||||
)
|
||||
plot_loss: Optional[bool] = field(
|
||||
default=False, metadata={"help": "Whether or not to save the training loss curves."}
|
||||
default=False,
|
||||
metadata={"help": "Whether or not to save the training loss curves."},
|
||||
)
|
||||
|
||||
def __post_init__(self):
|
||||
@@ -147,10 +198,13 @@ class FinetuningArguments(FreezeArguments, LoraArguments, RLHFArguments):
|
||||
assert self.reward_model_quantization_bit in [None, 8, 4], "We only accept 4-bit or 8-bit quantization."
|
||||
|
||||
if self.stage == "ppo" and self.reward_model is None:
|
||||
raise ValueError("Reward model is necessary for PPO training.")
|
||||
raise ValueError("`reward_model` is necessary for PPO training.")
|
||||
|
||||
if self.stage == "ppo" and self.reward_model_type == "lora" and self.finetuning_type != "lora":
|
||||
raise ValueError("Freeze/Full PPO training needs `reward_model_type=full`.")
|
||||
raise ValueError("`reward_model_type` cannot be lora for Freeze/Full PPO training.")
|
||||
|
||||
if self.use_llama_pro and self.finetuning_type == "full":
|
||||
raise ValueError("`use_llama_pro` is only valid for the Freeze or LoRA method.")
|
||||
|
||||
def save_to_json(self, json_path: str):
|
||||
r"""Saves the content of this instance in JSON format inside `json_path`."""
|
||||
|
||||
@@ -7,11 +7,14 @@ class GeneratingArguments:
|
||||
r"""
|
||||
Arguments pertaining to specify the decoding parameters.
|
||||
"""
|
||||
|
||||
do_sample: Optional[bool] = field(
|
||||
default=True, metadata={"help": "Whether or not to use sampling, use greedy decoding otherwise."}
|
||||
default=True,
|
||||
metadata={"help": "Whether or not to use sampling, use greedy decoding otherwise."},
|
||||
)
|
||||
temperature: Optional[float] = field(
|
||||
default=0.95, metadata={"help": "The value used to modulate the next token probabilities."}
|
||||
default=0.95,
|
||||
metadata={"help": "The value used to modulate the next token probabilities."},
|
||||
)
|
||||
top_p: Optional[float] = field(
|
||||
default=0.7,
|
||||
@@ -24,7 +27,8 @@ class GeneratingArguments:
|
||||
metadata={"help": "The number of highest probability vocabulary tokens to keep for top-k filtering."},
|
||||
)
|
||||
num_beams: Optional[int] = field(
|
||||
default=1, metadata={"help": "Number of beams for beam search. 1 means no beam search."}
|
||||
default=1,
|
||||
metadata={"help": "Number of beams for beam search. 1 means no beam search."},
|
||||
)
|
||||
max_length: Optional[int] = field(
|
||||
default=512,
|
||||
@@ -35,10 +39,12 @@ class GeneratingArguments:
|
||||
metadata={"help": "The maximum numbers of tokens to generate, ignoring the number of tokens in the prompt."},
|
||||
)
|
||||
repetition_penalty: Optional[float] = field(
|
||||
default=1.0, metadata={"help": "The parameter for repetition penalty. 1.0 means no penalty."}
|
||||
default=1.0,
|
||||
metadata={"help": "The parameter for repetition penalty. 1.0 means no penalty."},
|
||||
)
|
||||
length_penalty: Optional[float] = field(
|
||||
default=1.0, metadata={"help": "Exponential penalty to the length that is used with beam-based generation."}
|
||||
default=1.0,
|
||||
metadata={"help": "Exponential penalty to the length that is used with beam-based generation."},
|
||||
)
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
|
||||
@@ -7,11 +7,15 @@ class ModelArguments:
|
||||
r"""
|
||||
Arguments pertaining to which model/config/tokenizer we are going to fine-tune.
|
||||
"""
|
||||
|
||||
model_name_or_path: str = field(
|
||||
metadata={"help": "Path to the model weight or identifier from huggingface.co/models or modelscope.cn/models."}
|
||||
metadata={
|
||||
"help": "Path to the model weight or identifier from huggingface.co/models or modelscope.cn/models."
|
||||
},
|
||||
)
|
||||
adapter_name_or_path: Optional[str] = field(
|
||||
default=None, metadata={"help": "Path to the adapter weight or identifier from huggingface.co/models."}
|
||||
default=None,
|
||||
metadata={"help": "Path to the adapter weight or identifier from huggingface.co/models."},
|
||||
)
|
||||
cache_dir: Optional[str] = field(
|
||||
default=None,
|
||||
@@ -22,7 +26,8 @@ class ModelArguments:
|
||||
metadata={"help": "Whether or not to use one of the fast tokenizer (backed by the tokenizers library)."},
|
||||
)
|
||||
resize_vocab: Optional[bool] = field(
|
||||
default=False, metadata={"help": "Whether or not to resize the tokenizer vocab and the embedding layers."}
|
||||
default=False,
|
||||
metadata={"help": "Whether or not to resize the tokenizer vocab and the embedding layers."},
|
||||
)
|
||||
split_special_tokens: Optional[bool] = field(
|
||||
default=False,
|
||||
@@ -33,60 +38,88 @@ class ModelArguments:
|
||||
metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."},
|
||||
)
|
||||
quantization_bit: Optional[int] = field(
|
||||
default=None, metadata={"help": "The number of bits to quantize the model."}
|
||||
default=None,
|
||||
metadata={"help": "The number of bits to quantize the model."},
|
||||
)
|
||||
quantization_type: Optional[Literal["fp4", "nf4"]] = field(
|
||||
default="nf4", metadata={"help": "Quantization data type to use in int4 training."}
|
||||
default="nf4",
|
||||
metadata={"help": "Quantization data type to use in int4 training."},
|
||||
)
|
||||
double_quantization: Optional[bool] = field(
|
||||
default=True, metadata={"help": "Whether or not to use double quantization in int4 training."}
|
||||
default=True,
|
||||
metadata={"help": "Whether or not to use double quantization in int4 training."},
|
||||
)
|
||||
rope_scaling: Optional[Literal["linear", "dynamic"]] = field(
|
||||
default=None, metadata={"help": "Which scaling strategy should be adopted for the RoPE embeddings."}
|
||||
default=None,
|
||||
metadata={"help": "Which scaling strategy should be adopted for the RoPE embeddings."},
|
||||
)
|
||||
flash_attn: Optional[bool] = field(
|
||||
default=False, metadata={"help": "Enable FlashAttention-2 for faster training."}
|
||||
default=False,
|
||||
metadata={"help": "Enable FlashAttention-2 for faster training."},
|
||||
)
|
||||
shift_attn: Optional[bool] = field(
|
||||
default=False, metadata={"help": "Enable shift short attention (S^2-Attn) proposed by LongLoRA."}
|
||||
default=False,
|
||||
metadata={"help": "Enable shift short attention (S^2-Attn) proposed by LongLoRA."},
|
||||
)
|
||||
use_unsloth: Optional[bool] = field(
|
||||
default=False, metadata={"help": "Whether or not to use unsloth's optimization for the LoRA training."}
|
||||
default=False,
|
||||
metadata={"help": "Whether or not to use unsloth's optimization for the LoRA training."},
|
||||
)
|
||||
disable_gradient_checkpointing: Optional[bool] = field(
|
||||
default=False, metadata={"help": "Whether or not to disable gradient checkpointing."}
|
||||
default=False,
|
||||
metadata={"help": "Whether or not to disable gradient checkpointing."},
|
||||
)
|
||||
upcast_layernorm: Optional[bool] = field(
|
||||
default=False, metadata={"help": "Whether or not to upcast the layernorm weights in fp32."}
|
||||
default=False,
|
||||
metadata={"help": "Whether or not to upcast the layernorm weights in fp32."},
|
||||
)
|
||||
upcast_lmhead_output: Optional[bool] = field(
|
||||
default=False, metadata={"help": "Whether or not to upcast the output of lm_head in fp32."}
|
||||
default=False,
|
||||
metadata={"help": "Whether or not to upcast the output of lm_head in fp32."},
|
||||
)
|
||||
hf_hub_token: Optional[str] = field(
|
||||
default=None,
|
||||
metadata={"help": "Auth token to log in with Hugging Face Hub."},
|
||||
)
|
||||
ms_hub_token: Optional[str] = field(
|
||||
default=None,
|
||||
metadata={"help": "Auth token to log in with ModelScope Hub."},
|
||||
)
|
||||
hf_hub_token: Optional[str] = field(default=None, metadata={"help": "Auth token to log in with Hugging Face Hub."})
|
||||
ms_hub_token: Optional[str] = field(default=None, metadata={"help": "Auth token to log in with ModelScope Hub."})
|
||||
export_dir: Optional[str] = field(
|
||||
default=None, metadata={"help": "Path to the directory to save the exported model."}
|
||||
default=None,
|
||||
metadata={"help": "Path to the directory to save the exported model."},
|
||||
)
|
||||
export_size: Optional[int] = field(
|
||||
default=1, metadata={"help": "The file shard size (in GB) of the exported model."}
|
||||
default=1,
|
||||
metadata={"help": "The file shard size (in GB) of the exported model."},
|
||||
)
|
||||
export_quantization_bit: Optional[int] = field(
|
||||
default=None, metadata={"help": "The number of bits to quantize the exported model."}
|
||||
default=None,
|
||||
metadata={"help": "The number of bits to quantize the exported model."},
|
||||
)
|
||||
export_quantization_dataset: Optional[str] = field(
|
||||
default=None, metadata={"help": "Path to the dataset or dataset name to use in quantizing the exported model."}
|
||||
default=None,
|
||||
metadata={"help": "Path to the dataset or dataset name to use in quantizing the exported model."},
|
||||
)
|
||||
export_quantization_nsamples: Optional[int] = field(
|
||||
default=128, metadata={"help": "The number of samples used for quantization."}
|
||||
default=128,
|
||||
metadata={"help": "The number of samples used for quantization."},
|
||||
)
|
||||
export_quantization_maxlen: Optional[int] = field(
|
||||
default=1024, metadata={"help": "The maximum length of the model inputs used for quantization."}
|
||||
default=1024,
|
||||
metadata={"help": "The maximum length of the model inputs used for quantization."},
|
||||
)
|
||||
export_legacy_format: Optional[bool] = field(
|
||||
default=False, metadata={"help": "Whether or not to save the `.bin` files instead of `.safetensors`."}
|
||||
default=False,
|
||||
metadata={"help": "Whether or not to save the `.bin` files instead of `.safetensors`."},
|
||||
)
|
||||
export_hub_model_id: Optional[str] = field(
|
||||
default=None, metadata={"help": "The name of the repository if push the model to the Hugging Face hub."}
|
||||
default=None,
|
||||
metadata={"help": "The name of the repository if push the model to the Hugging Face hub."},
|
||||
)
|
||||
print_param_status: Optional[bool] = field(
|
||||
default=False,
|
||||
metadata={"help": "For debugging purposes, print the status of the parameters in the model."},
|
||||
)
|
||||
|
||||
def __post_init__(self):
|
||||
|
||||
@@ -3,13 +3,14 @@ import os
import sys
from typing import Any, Dict, Optional, Tuple

import datasets
import torch
import transformers
from transformers import HfArgumentParser, Seq2SeqTrainingArguments
from transformers.trainer_utils import get_last_checkpoint
from transformers.utils.versions import require_version

from ..extras.logging import get_logger
from ..extras.packages import is_unsloth_available
from .data_args import DataArguments
from .evaluation_args import EvaluationArguments
from .finetuning_args import FinetuningArguments

@@ -28,6 +29,17 @@ _EVAL_ARGS = [ModelArguments, DataArguments, EvaluationArguments, FinetuningArgu
_EVAL_CLS = Tuple[ModelArguments, DataArguments, EvaluationArguments, FinetuningArguments]


def _check_dependencies(disabled: bool) -> None:
    if disabled:
        logger.warning("Version checking has been disabled, may lead to unexpected behaviors.")
    else:
        require_version("transformers>=4.37.2", "To fix: pip install transformers>=4.37.2")
        require_version("datasets>=2.14.3", "To fix: pip install datasets>=2.14.3")
        require_version("accelerate>=0.27.2", "To fix: pip install accelerate>=0.27.2")
        require_version("peft>=0.9.0", "To fix: pip install peft>=0.9.0")
        require_version("trl>=0.7.11", "To fix: pip install trl>=0.7.11")


def _parse_args(parser: "HfArgumentParser", args: Optional[Dict[str, Any]] = None) -> Tuple[Any]:
    if args is not None:
        return parser.parse_dict(args)
@@ -49,7 +61,6 @@ def _parse_args(parser: "HfArgumentParser", args: Optional[Dict[str, Any]] = Non
|
||||
|
||||
|
||||
def _set_transformers_logging(log_level: Optional[int] = logging.INFO) -> None:
|
||||
datasets.utils.logging.set_verbosity(log_level)
|
||||
transformers.utils.logging.set_verbosity(log_level)
|
||||
transformers.utils.logging.enable_default_handler()
|
||||
transformers.utils.logging.enable_explicit_format()
|
||||
@@ -60,16 +71,15 @@ def _verify_model_args(model_args: "ModelArguments", finetuning_args: "Finetunin
|
||||
if finetuning_args.finetuning_type != "lora":
|
||||
raise ValueError("Quantization is only compatible with the LoRA method.")
|
||||
|
||||
if finetuning_args.create_new_adapter:
|
||||
if model_args.adapter_name_or_path is not None and finetuning_args.create_new_adapter:
|
||||
raise ValueError("Cannot create new adapter upon a quantized model.")
|
||||
|
||||
if model_args.adapter_name_or_path is not None and len(model_args.adapter_name_or_path) != 1:
|
||||
if finetuning_args.finetuning_type != "lora":
|
||||
raise ValueError("Multiple adapters are only available for LoRA tuning.")
|
||||
|
||||
if model_args.quantization_bit is not None:
|
||||
if model_args.adapter_name_or_path is not None and len(model_args.adapter_name_or_path) != 1:
|
||||
raise ValueError("Quantized model only accepts a single adapter. Merge them first.")
|
||||
|
||||
if model_args.adapter_name_or_path is not None and finetuning_args.finetuning_type != "lora":
|
||||
raise ValueError("Adapter is only valid for the LoRA method.")
|
||||
|
||||
|
||||
def _parse_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
|
||||
parser = HfArgumentParser(_TRAIN_ARGS)
|
||||
@@ -121,10 +131,36 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
|
||||
if training_args.do_train and training_args.predict_with_generate:
|
||||
raise ValueError("`predict_with_generate` cannot be set as True while training.")
|
||||
|
||||
if (
|
||||
training_args.do_train
|
||||
and finetuning_args.finetuning_type == "freeze"
|
||||
and finetuning_args.name_module_trainable is None
|
||||
):
|
||||
raise ValueError("Please specify `name_module_trainable` in Freeze training.")
|
||||
|
||||
if training_args.do_train and finetuning_args.finetuning_type == "lora" and finetuning_args.lora_target is None:
|
||||
raise ValueError("Please specify `lora_target` in LoRA training.")
|
||||
|
||||
if training_args.do_train and model_args.use_unsloth and not is_unsloth_available:
|
||||
raise ValueError("Unsloth was not installed: https://github.com/unslothai/unsloth")
|
||||
|
||||
if finetuning_args.use_dora:
|
||||
if model_args.quantization_bit is not None:
|
||||
raise ValueError("DoRA does not support quantization.")
|
||||
|
||||
if model_args.use_unsloth:
|
||||
raise ValueError("Unsloth does not support DoRA.")
|
||||
|
||||
_verify_model_args(model_args, finetuning_args)
|
||||
_check_dependencies(disabled=finetuning_args.disable_version_checking)
|
||||
|
||||
if (
|
||||
training_args.do_train
|
||||
and finetuning_args.finetuning_type == "lora"
|
||||
and model_args.resize_vocab
|
||||
and finetuning_args.additional_target is None
|
||||
):
|
||||
logger.warning("Add token embeddings to `additional_target` to make the added tokens trainable.")
|
||||
|
||||
if training_args.do_train and model_args.quantization_bit is not None and (not model_args.upcast_layernorm):
|
||||
logger.warning("We recommend enable `upcast_layernorm` in quantized training.")
|
||||
@@ -138,7 +174,7 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
|
||||
if (not training_args.do_train) and finetuning_args.stage == "dpo" and finetuning_args.ref_model is None:
|
||||
logger.warning("Specify `ref_model` for computing rewards at evaluation.")
|
||||
|
||||
# postprocess training_args
|
||||
# Post-process training arguments
|
||||
if (
|
||||
training_args.local_rank != -1
|
||||
and training_args.ddp_find_unused_parameters is None
|
||||
@@ -151,7 +187,9 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
|
||||
|
||||
if finetuning_args.stage in ["rm", "ppo"] and finetuning_args.finetuning_type in ["full", "freeze"]:
|
||||
can_resume_from_checkpoint = False
|
||||
training_args.resume_from_checkpoint = None
|
||||
if training_args.resume_from_checkpoint is not None:
|
||||
logger.warning("Cannot resume from checkpoint in current stage.")
|
||||
training_args.resume_from_checkpoint = None
|
||||
else:
|
||||
can_resume_from_checkpoint = True
|
||||
|
||||
@@ -187,7 +225,7 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
|
||||
)
|
||||
)
|
||||
|
||||
# postprocess model_args
|
||||
# Post-process model arguments
|
||||
model_args.compute_dtype = (
|
||||
torch.bfloat16 if training_args.bf16 else (torch.float16 if training_args.fp16 else None)
|
||||
)
|
||||
@@ -203,9 +241,7 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
|
||||
str(model_args.compute_dtype),
|
||||
)
|
||||
)
|
||||
logger.info(f"Training/evaluation parameters {training_args}")
|
||||
|
||||
# Set seed before initializing model.
|
||||
transformers.set_seed(training_args.seed)
|
||||
|
||||
return model_args, data_args, training_args, finetuning_args, generating_args
|
||||
@@ -213,25 +249,27 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
|
||||
|
||||
def get_infer_args(args: Optional[Dict[str, Any]] = None) -> _INFER_CLS:
|
||||
model_args, data_args, finetuning_args, generating_args = _parse_infer_args(args)
|
||||
|
||||
_set_transformers_logging()
|
||||
_verify_model_args(model_args, finetuning_args)
|
||||
_check_dependencies(disabled=finetuning_args.disable_version_checking)
|
||||
|
||||
if data_args.template is None:
|
||||
raise ValueError("Please specify which `template` to use.")
|
||||
|
||||
_verify_model_args(model_args, finetuning_args)
|
||||
|
||||
return model_args, data_args, finetuning_args, generating_args
|
||||
|
||||
|
||||
def get_eval_args(args: Optional[Dict[str, Any]] = None) -> _EVAL_CLS:
|
||||
model_args, data_args, eval_args, finetuning_args = _parse_eval_args(args)
|
||||
|
||||
_set_transformers_logging()
|
||||
_verify_model_args(model_args, finetuning_args)
|
||||
_check_dependencies(disabled=finetuning_args.disable_version_checking)
|
||||
|
||||
if data_args.template is None:
|
||||
raise ValueError("Please specify which `template` to use.")
|
||||
|
||||
_verify_model_args(model_args, finetuning_args)
|
||||
|
||||
transformers.set_seed(eval_args.seed)
|
||||
|
||||
return model_args, data_args, eval_args, finetuning_args
|
||||
|
||||
@@ -1,5 +1,5 @@
from .loader import load_model_and_tokenizer
from .utils import dispatch_model, get_modelcard_args, load_valuehead_params
from .utils import dispatch_model, load_valuehead_params


__all__ = ["load_model_and_tokenizer", "dispatch_model", "get_modelcard_args", "load_valuehead_params"]
__all__ = ["load_model_and_tokenizer", "dispatch_model", "load_valuehead_params"]

@@ -1,12 +1,11 @@
import inspect
from typing import TYPE_CHECKING

import torch
from peft import LoraConfig, PeftModel, TaskType, get_peft_model
from peft import LoraConfig, LoraModel, PeftModel, TaskType, get_peft_model
from transformers.integrations import is_deepspeed_zero3_enabled

from ..extras.logging import get_logger
from .utils import find_all_linear_modules
from .utils import find_all_linear_modules, find_expanded_modules


if TYPE_CHECKING:
@@ -47,24 +46,46 @@ def init_adapter(
if not num_layers:
raise ValueError("Current model does not support freeze tuning.")

if finetuning_args.num_layer_trainable > 0: # fine-tuning the last n layers if num_layer_trainable > 0
trainable_layer_ids = [num_layers - k - 1 for k in range(finetuning_args.num_layer_trainable)]
if finetuning_args.use_llama_pro:
if num_layers % finetuning_args.num_layer_trainable != 0:
raise ValueError(
"`num_layers` {} should be divisible by `num_layer_trainable` {}.".format(
num_layers, finetuning_args.num_layer_trainable
)
)

stride = num_layers // finetuning_args.num_layer_trainable
trainable_layer_ids = range(stride - 1, num_layers + stride - 1, stride)
elif finetuning_args.num_layer_trainable > 0: # fine-tuning the last n layers if num_layer_trainable > 0
trainable_layer_ids = range(num_layers - finetuning_args.num_layer_trainable, num_layers)
else: # fine-tuning the first n layers if num_layer_trainable < 0
trainable_layer_ids = [k for k in range(-finetuning_args.num_layer_trainable)] # noqa: C416
trainable_layer_ids = range(-finetuning_args.num_layer_trainable)

freeze_modules = {"all"}
for name, _ in model.named_modules():
if ".0." in name:
freeze_modules.add(name.split(".0.")[-1].split(".")[0])

trainable_layers = []
for module_name in finetuning_args.name_module_trainable:
if module_name not in freeze_modules:
raise ValueError(
"Module {} is not found, please choose from {}".format(module_name, ", ".join(freeze_modules))
)

for idx in trainable_layer_ids:
trainable_layers.append("{:d}.{}".format(idx, module_name))
trainable_layers.append(".{:d}.{}".format(idx, module_name if module_name != "all" else ""))

for name, param in model.named_parameters():
if not any(trainable_layer in name for trainable_layer in trainable_layers):
param.requires_grad_(False)
else:
if any(trainable_layer in name for trainable_layer in trainable_layers):
param.data = param.data.to(torch.float32)
else:
param.requires_grad_(False)

logger.info("Set trainable layers: {}".format(",".join(map(str, trainable_layer_ids))))

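# Worked example (illustrative, not part of this diff) for the LLaMA Pro branch above:
# with an assumed 32-layer model and num_layer_trainable=8, the stride is 32 // 8 = 4 and
# every stride-th block becomes trainable.
num_layers, num_layer_trainable = 32, 8  # assumed values
stride = num_layers // num_layer_trainable
print(list(range(stride - 1, num_layers + stride - 1, stride)))  # [3, 7, 11, 15, 19, 23, 27, 31]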
if finetuning_args.finetuning_type == "lora":
|
||||
logger.info("Fine-tuning method: LoRA")
|
||||
logger.info("Fine-tuning method: {}".format("DoRA" if finetuning_args.use_dora else "LoRA"))
|
||||
adapter_to_resume = None
|
||||
|
||||
if model_args.adapter_name_or_path is not None:
|
||||
@@ -84,7 +105,7 @@ def init_adapter(
|
||||
adapter_to_merge = model_args.adapter_name_or_path
|
||||
|
||||
for adapter in adapter_to_merge:
|
||||
model = PeftModel.from_pretrained(model, adapter)
|
||||
model: "LoraModel" = PeftModel.from_pretrained(model, adapter)
|
||||
model = model.merge_and_unload()
|
||||
|
||||
if len(adapter_to_merge) > 0:
|
||||
@@ -99,32 +120,32 @@ def init_adapter(
|
||||
else:
|
||||
target_modules = finetuning_args.lora_target
|
||||
|
||||
if finetuning_args.use_llama_pro:
|
||||
target_modules = find_expanded_modules(model, target_modules, finetuning_args.num_layer_trainable)
|
||||
|
||||
if finetuning_args.use_dora:
|
||||
if getattr(model, "quantization_method", None):
|
||||
raise ValueError("DoRA is currently not compatible with quantized models.")
|
||||
|
||||
peft_kwargs = {
|
||||
"r": finetuning_args.lora_rank,
|
||||
"target_modules": target_modules,
|
||||
"lora_alpha": finetuning_args.lora_alpha,
|
||||
"lora_dropout": finetuning_args.lora_dropout,
|
||||
"use_rslora": finetuning_args.use_rslora,
|
||||
}
|
||||
|
||||
if model_args.use_unsloth:
|
||||
from unsloth import FastLlamaModel, FastMistralModel # type: ignore
|
||||
from unsloth import FastLanguageModel # type: ignore
|
||||
|
||||
unsloth_peft_kwargs = {"model": model, "max_seq_length": model_args.model_max_length}
|
||||
if "loftq_config" in inspect.signature(FastLlamaModel.get_peft_model).parameters:
|
||||
unsloth_peft_kwargs["loftq_config"] = {}
|
||||
|
||||
if getattr(model.config, "model_type", None) == "llama":
|
||||
model = FastLlamaModel.get_peft_model(**peft_kwargs, **unsloth_peft_kwargs)
|
||||
elif getattr(model.config, "model_type", None) == "mistral":
|
||||
model = FastMistralModel.get_peft_model(**peft_kwargs, **unsloth_peft_kwargs)
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
model = FastLanguageModel.get_peft_model(**peft_kwargs, **unsloth_peft_kwargs)
|
||||
else:
|
||||
lora_config = LoraConfig(
|
||||
task_type=TaskType.CAUSAL_LM,
|
||||
inference_mode=False,
|
||||
modules_to_save=finetuning_args.additional_target,
|
||||
use_dora=finetuning_args.use_dora,
|
||||
**peft_kwargs,
|
||||
)
|
||||
model = get_peft_model(model, lora_config)
|
||||
@@ -132,7 +153,7 @@ def init_adapter(
|
||||
for param in filter(lambda p: p.requires_grad, model.parameters()):
|
||||
param.data = param.data.to(torch.bfloat16 if finetuning_args.lora_bf16_mode else torch.float32)
|
||||
|
||||
if model_args.adapter_name_or_path is not None:
|
||||
logger.info("Loaded adapter(s): {}".format(",".join(model_args.adapter_name_or_path)))
|
||||
if model_args.adapter_name_or_path is not None:
|
||||
logger.info("Loaded adapter(s): {}".format(",".join(model_args.adapter_name_or_path)))
|
||||
|
||||
return model
|
||||
|
||||
@@ -2,7 +2,6 @@ from typing import TYPE_CHECKING, Optional, Tuple
|
||||
|
||||
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
|
||||
from transformers.integrations import is_deepspeed_zero3_enabled
|
||||
from transformers.utils.versions import require_version
|
||||
from trl import AutoModelForCausalLMWithValueHead
|
||||
|
||||
from ..extras.logging import get_logger
|
||||
@@ -21,13 +20,6 @@ if TYPE_CHECKING:
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
require_version("transformers>=4.36.2", "To fix: pip install transformers>=4.36.2")
|
||||
require_version("datasets>=2.14.3", "To fix: pip install datasets>=2.14.3")
|
||||
require_version("accelerate>=0.21.0", "To fix: pip install accelerate>=0.21.0")
|
||||
require_version("peft>=0.7.0", "To fix: pip install peft>=0.7.0")
|
||||
require_version("trl>=0.7.6", "To fix: pip install trl>=0.7.6")
|
||||
|
||||
|
||||
def load_model_and_tokenizer(
|
||||
model_args: "ModelArguments",
|
||||
finetuning_args: "FinetuningArguments",
|
||||
@@ -63,8 +55,7 @@ def load_model_and_tokenizer(
|
||||
|
||||
model = None
|
||||
if is_trainable and model_args.use_unsloth:
|
||||
require_version("unsloth", "Follow the instructions at: https://github.com/unslothai/unsloth")
|
||||
from unsloth import FastLlamaModel, FastMistralModel # type: ignore
|
||||
from unsloth import FastLanguageModel # type: ignore
|
||||
|
||||
unsloth_kwargs = {
|
||||
"model_name": model_args.model_name_or_path,
|
||||
@@ -72,14 +63,12 @@ def load_model_and_tokenizer(
|
||||
"dtype": model_args.compute_dtype,
|
||||
"load_in_4bit": model_args.quantization_bit == 4,
|
||||
"token": model_args.hf_hub_token,
|
||||
"device_map": get_current_device(),
|
||||
"device_map": {"": get_current_device()},
|
||||
"rope_scaling": getattr(config, "rope_scaling", None),
|
||||
}
|
||||
if getattr(config, "model_type", None) == "llama":
|
||||
model, _ = FastLlamaModel.from_pretrained(**unsloth_kwargs)
|
||||
elif getattr(config, "model_type", None) == "mistral":
|
||||
model, _ = FastMistralModel.from_pretrained(**unsloth_kwargs)
|
||||
else:
|
||||
try:
|
||||
model, _ = FastLanguageModel.from_pretrained(**unsloth_kwargs)
|
||||
except NotImplementedError:
|
||||
logger.warning("Unsloth does not support model type {}.".format(getattr(config, "model_type", None)))
|
||||
model_args.use_unsloth = False
|
||||
|
||||
@@ -132,4 +121,12 @@ def load_model_and_tokenizer(
|
||||
if not is_trainable:
|
||||
logger.info("This IS expected that the trainable params is 0 if you are using model for inference only.")
|
||||
|
||||
if model_args.print_param_status:
|
||||
for name, param in model.named_parameters():
|
||||
print(
|
||||
"name: {}, dtype: {}, device: {}, trainable: {}".format(
|
||||
name, param.dtype, param.device, param.requires_grad
|
||||
)
|
||||
)
|
||||
|
||||
return model, tokenizer
|
||||
|
||||
@@ -7,6 +7,7 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
from datasets import load_dataset
|
||||
from peft import PeftModel
|
||||
from transformers import BitsAndBytesConfig, GPTQConfig, PreTrainedModel, PreTrainedTokenizerBase
|
||||
from transformers.integrations import is_deepspeed_zero3_enabled
|
||||
from transformers.utils.versions import require_version
|
||||
@@ -16,6 +17,7 @@ from ..extras.logging import get_logger
|
||||
from ..extras.misc import get_current_device, infer_optim_dtype
|
||||
from ..extras.packages import is_flash_attn2_available
|
||||
from ..extras.patches.llama_patch import apply_llama_patch
|
||||
from ..extras.patches.mixtral_patch import patch_mixtral_replace_moe_impl
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -100,6 +102,18 @@ def _get_quantization_dataset(tokenizer: "PreTrainedTokenizer", model_args: "Mod
|
||||
return samples
|
||||
|
||||
|
||||
def _configure_attn_implementation(model_args: "ModelArguments", config_kwargs: Dict[str, Any]) -> None:
|
||||
if model_args.flash_attn:
|
||||
if is_flash_attn2_available():
|
||||
config_kwargs["attn_implementation"] = "flash_attention_2"
|
||||
logger.info("Using FlashAttention-2 for faster training and inference.")
|
||||
else:
|
||||
logger.warning("FlashAttention2 is not installed.")
|
||||
config_kwargs["attn_implementation"] = None
|
||||
else:
|
||||
config_kwargs["attn_implementation"] = "eager"
|
||||
|
||||
|
||||
def _configure_rope(config: "PretrainedConfig", model_args: "ModelArguments", is_trainable: bool) -> None:
|
||||
if not hasattr(config, "rope_scaling"):
|
||||
logger.warning("Current model does not support RoPE scaling.")
|
||||
@@ -127,15 +141,6 @@ def _configure_rope(config: "PretrainedConfig", model_args: "ModelArguments", is
|
||||
)
|
||||
|
||||
|
||||
def _configure_flashattn(config_kwargs: Dict[str, Any]) -> None:
|
||||
if not is_flash_attn2_available():
|
||||
logger.warning("FlashAttention2 is not installed.")
|
||||
return
|
||||
|
||||
config_kwargs["use_flash_attention_2"] = True
|
||||
logger.info("Using FlashAttention-2 for faster training and inference.")
|
||||
|
||||
|
||||
def _configure_longlora(config: "PretrainedConfig") -> None:
|
||||
if getattr(config, "model_type", None) in SUPPORTED_CLASS_FOR_S2ATTN:
|
||||
setattr(config, "group_size_ratio", 0.25)
|
||||
@@ -152,7 +157,7 @@ def _configure_quantization(
|
||||
config_kwargs: Dict[str, Any],
|
||||
) -> None:
|
||||
r"""
|
||||
Priority: GPTQ-quantized (training) > AutoGPTQ (export) > Bitsandbytes (training)
|
||||
Priority: PTQ-quantized (training) > AutoGPTQ (export) > Bitsandbytes (training)
|
||||
"""
|
||||
if getattr(config, "quantization_config", None): # gptq
|
||||
if is_deepspeed_zero3_enabled():
|
||||
@@ -162,7 +167,15 @@ def _configure_quantization(
|
||||
quantization_config: Dict[str, Any] = getattr(config, "quantization_config", None)
|
||||
if quantization_config.get("quant_method", None) == "gptq" and quantization_config.get("bits", -1) == 4:
|
||||
quantization_config["use_exllama"] = False # disable exllama
|
||||
logger.info("Loading {}-bit GPTQ-quantized model.".format(quantization_config.get("bits", -1)))
|
||||
|
||||
if quantization_config.get("quant_method", None) == "aqlm":
|
||||
quantization_config["bits"] = 2
|
||||
|
||||
logger.info(
|
||||
"Loading {}-bit {}-quantized model.".format(
|
||||
quantization_config.get("bits", "?"), quantization_config.get("quant_method", None)
|
||||
)
|
||||
)
|
||||
|
||||
elif model_args.export_quantization_bit is not None: # auto-gptq
|
||||
require_version("optimum>=1.16.0", "To fix: pip install optimum>=1.16.0")
|
||||
@@ -222,7 +235,10 @@ def _prepare_model_for_training(
if not getattr(model, "supports_gradient_checkpointing", False):
logger.warning("Current model does not support gradient checkpointing.")
else:
model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False})
# use_reentrant=False might increase VRAM usage (have not been empirically verified yet)
# According to: https://github.com/huggingface/transformers/issues/28339
model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": True})
model.enable_input_require_grads()
model.config.use_cache = False # turn off when gradient checkpointing is enabled
logger.info("Gradient checkpointing enabled.")

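# Minimal sketch of the setup above for any Hugging Face causal LM (the model path is a
# placeholder, and this mirrors the diff rather than prescribing it): reentrant gradient
# checkpointing, input gradients enabled, and the KV cache disabled during training.
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("path/to/model")  # placeholder path
if getattr(model, "supports_gradient_checkpointing", False):
    model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": True})
    model.enable_input_require_grads()
    model.config.use_cache = False  # incompatible with gradient checkpointing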
@@ -255,12 +271,11 @@ def patch_config(
|
||||
for dtype_name, dtype in [("fp16", torch.float16), ("bf16", torch.bfloat16), ("fp32", torch.float32)]:
|
||||
setattr(config, dtype_name, model_args.compute_dtype == dtype)
|
||||
|
||||
_configure_attn_implementation(model_args, config_kwargs)
|
||||
|
||||
if model_args.rope_scaling is not None:
|
||||
_configure_rope(config, model_args, is_trainable)
|
||||
|
||||
if model_args.flash_attn:
|
||||
_configure_flashattn(config_kwargs)
|
||||
|
||||
if is_trainable and model_args.shift_attn:
|
||||
_configure_longlora(config)
|
||||
|
||||
@@ -283,6 +298,21 @@ def patch_model(
|
||||
if is_trainable:
|
||||
_prepare_model_for_training(model, model_args)
|
||||
|
||||
if getattr(model.config, "model_type", None) == "mixtral" and is_deepspeed_zero3_enabled():
|
||||
require_version("deepspeed>=0.13.0", "To fix: pip install deepspeed>=0.13.0")
|
||||
from deepspeed.utils import set_z3_leaf_modules # type: ignore
|
||||
from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock
|
||||
|
||||
set_z3_leaf_modules(model, [MixtralSparseMoeBlock])
|
||||
|
||||
if is_trainable:
|
||||
patch_mixtral_replace_moe_impl()
|
||||
|
||||
try:
|
||||
model.add_model_tags(["llama-factory"])
|
||||
except Exception:
|
||||
logger.warning("Cannot properly tag the model.")
|
||||
|
||||
|
||||
def patch_valuehead_model(model: "AutoModelForCausalLMWithValueHead") -> None:
|
||||
def tie_weights(self: "AutoModelForCausalLMWithValueHead") -> None:
|
||||
@@ -293,7 +323,12 @@ def patch_valuehead_model(model: "AutoModelForCausalLMWithValueHead") -> None:
|
||||
if isinstance(self.pretrained_model, PreTrainedModel):
|
||||
return self.pretrained_model.get_input_embeddings()
|
||||
|
||||
def create_or_update_model_card(self: "AutoModelForCausalLMWithValueHead", output_dir: str) -> None:
|
||||
if isinstance(self.pretrained_model, PeftModel):
|
||||
self.pretrained_model.create_or_update_model_card(output_dir)
|
||||
|
||||
ignore_modules = [name for name, _ in model.named_parameters() if "pretrained_model" in name]
|
||||
setattr(model, "_keys_to_ignore_on_save", ignore_modules)
|
||||
setattr(model, "tie_weights", MethodType(tie_weights, model))
|
||||
setattr(model, "get_input_embeddings", MethodType(get_input_embeddings, model))
|
||||
setattr(model, "create_or_update_model_card", MethodType(create_or_update_model_card, model))
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
import inspect
|
||||
from typing import TYPE_CHECKING, Any, Dict, List
|
||||
from typing import TYPE_CHECKING, Dict, List
|
||||
|
||||
import torch
|
||||
from transformers import PreTrainedModel
|
||||
@@ -13,7 +13,7 @@ from ..extras.misc import get_current_device
|
||||
if TYPE_CHECKING:
|
||||
from transformers import PretrainedConfig, PreTrainedTokenizer
|
||||
|
||||
from ..hparams import DataArguments, FinetuningArguments, ModelArguments
|
||||
from ..hparams import ModelArguments
|
||||
|
||||
|
||||
logger = get_logger(__name__)
|
||||
@@ -41,7 +41,7 @@ def dispatch_model(model: "PreTrainedModel") -> "PreTrainedModel":
|
||||
# Make sure tied weights are tied before creating the device map.
|
||||
model.tie_weights()
|
||||
device_map = infer_auto_device_map(model, max_memory=max_memory, **kwargs)
|
||||
device_map_kwargs = {"device_map": device_map}
|
||||
device_map_kwargs = {"device_map": device_map, "offload_dir": "offload"}
|
||||
if "skip_keys" in inspect.signature(dispatch_model).parameters:
|
||||
device_map_kwargs["skip_keys"] = model._skip_keys_device_placement
|
||||
return dispatch_model(model, **device_map_kwargs)
|
||||
@@ -76,16 +76,31 @@ def find_all_linear_modules(model: "PreTrainedModel") -> List[str]:
return list(module_names)


def get_modelcard_args(
model_args: "ModelArguments", data_args: "DataArguments", finetuning_args: "FinetuningArguments"
) -> Dict[str, Any]:
return {
"tasks": "text-generation",
"license": "other",
"finetuned_from": model_args.model_name_or_path,
"dataset": [dataset.strip() for dataset in data_args.dataset.split(",")],
"tags": ["llama-factory"] + (["lora"] if finetuning_args.finetuning_type == "lora" else []),
}
def find_expanded_modules(model: "PreTrainedModel", target_modules: List[str], num_layer_trainable: int) -> List[str]:
r"""
Finds the modules in the expanded blocks to apply lora.
"""
num_layers = getattr(model.config, "num_hidden_layers", None)
if not num_layers:
raise ValueError("Model was not supported.")

if num_layers % num_layer_trainable != 0:
raise ValueError(
"`num_layers` {} should be divisible by `num_layer_trainable` {}.".format(num_layers, num_layer_trainable)
)

stride = num_layers // num_layer_trainable
trainable_layer_ids = range(stride - 1, num_layers + stride - 1, stride)
trainable_layers = [".{:d}.".format(idx) for idx in trainable_layer_ids]
module_names = []
for name, _ in model.named_modules():
if any(target_module in name for target_module in target_modules) and any(
trainable_layer in name for trainable_layer in trainable_layers
):
module_names.append(name)

logger.info("Apply lora to layers: {}".format(",".join(map(str, trainable_layer_ids))))
return module_names

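# Illustrative check (not from this diff) of the name filtering in find_expanded_modules:
# a module inside an expanded block (layer 3 here) is kept, one outside (layer 4) is not.
# The module paths are typical LLaMA-style names, assumed for the example.
target_modules = ["q_proj", "v_proj"]
trainable_layers = [".3.", ".7."]
for name in ["model.layers.3.self_attn.q_proj", "model.layers.4.self_attn.q_proj"]:
    keep = any(t in name for t in target_modules) and any(s in name for s in trainable_layers)
    print(name, keep)  # layer 3 -> True, layer 4 -> False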
def load_valuehead_params(path_or_repo_id: str, model_args: "ModelArguments") -> Dict[str, torch.Tensor]:
|
||||
|
||||
@@ -18,7 +18,7 @@ class CustomDPOTrainer(DPOTrainer):
def __init__(
self,
beta: float,
loss_type: Literal["sigmoid", "hinge", "ipo", "kto"],
loss_type: Literal["sigmoid", "hinge", "ipo", "kto_pair"],
ftx_gamma: float,
model: Union["PreTrainedModel", torch.nn.Module],
ref_model: Optional[Union["PreTrainedModel", torch.nn.Module]] = None,
@@ -30,6 +30,7 @@ class CustomDPOTrainer(DPOTrainer):
if ref_model is not None:
disable_dropout_in_model(ref_model)

self.reference_free = False
self.use_dpo_data_collator = True # hack to avoid warning
self.generate_during_eval = False # disable at evaluation
self.label_pad_token_id = IGNORE_INDEX

@@ -1,7 +1,9 @@
import json
from contextlib import nullcontext
from typing import TYPE_CHECKING, Dict, List, Literal, Optional

import torch
from transformers.integrations import is_deepspeed_zero3_enabled

from ...extras.packages import is_requests_available

@@ -23,18 +25,22 @@ def get_rewards_from_server(server_url: str, messages: List[str]) -> List[torch.


def replace_model(model: "AutoModelForCausalLMWithValueHead", target: Literal["default", "reward"]) -> None:
if target == "reward": # save default head temporarily
valuehead_state_dict: Dict[str, torch.Tensor] = model.v_head.state_dict()
setattr(model, "default_head_weight", valuehead_state_dict["summary.weight"].detach().clone())
setattr(model, "default_head_bias", valuehead_state_dict["summary.bias"].detach().clone())
if is_deepspeed_zero3_enabled():
import deepspeed # type: ignore

model.pretrained_model.set_adapter(target) # set the LoRA adapter to be active
model.v_head.load_state_dict(
{
"summary.weight": model.get_buffer("{}_head_weight".format(target)).detach().clone(),
"summary.bias": model.get_buffer("{}_head_bias".format(target)).detach().clone(),
}
)
params = [model.v_head.summary.weight, model.v_head.summary.bias]
context_maybe_zero3 = deepspeed.zero.GatheredParameters(params, modifier_rank=0)
else:
context_maybe_zero3 = nullcontext()

with context_maybe_zero3:
if target == "reward": # save default head temporarily
setattr(model, "default_head_weight", model.v_head.summary.weight.data.detach().clone())
setattr(model, "default_head_bias", model.v_head.summary.bias.data.detach().clone())

model.pretrained_model.set_adapter(target) # set the LoRA adapter to be active
model.v_head.summary.weight.data = model.get_buffer("{}_head_weight".format(target)).detach().clone()
model.v_head.summary.bias.data = model.get_buffer("{}_head_bias".format(target)).detach().clone()


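# Minimal sketch of the ZeRO-3-aware swap above (the tensors below are stand-ins; in the
# diff they are the value-head weight and bias). Under ZeRO stage 3, partitioned parameters
# must be gathered before their .data can be read or replaced; otherwise a null context suffices.
from contextlib import nullcontext

import torch
from transformers.integrations import is_deepspeed_zero3_enabled

params = [torch.nn.Parameter(torch.zeros(4)), torch.nn.Parameter(torch.zeros(1))]  # placeholders
if is_deepspeed_zero3_enabled():
    import deepspeed  # type: ignore

    context = deepspeed.zero.GatheredParameters(params, modifier_rank=0)
else:
    context = nullcontext()

with context:
    backup = params[0].data.detach().clone()  # the full tensor is materialized inside the context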
def dump_layernorm(model: "PreTrainedModel") -> Dict[str, torch.Tensor]:
|
||||
|
||||
@@ -61,6 +61,7 @@ def run_ppo(
|
||||
use_score_norm=finetuning_args.ppo_score_norm,
|
||||
whiten_rewards=finetuning_args.ppo_whiten_rewards,
|
||||
accelerator_kwargs={"step_scheduler_with_optimizer": False},
|
||||
project_kwargs={"logging_dir": training_args.logging_dir},
|
||||
)
|
||||
|
||||
# Create optimizer and scheduler
|
||||
|
||||
@@ -56,12 +56,13 @@ def export_model(args: Optional[Dict[str, Any]] = None):
|
||||
if not isinstance(model, PreTrainedModel):
|
||||
raise ValueError("The model is not a `PreTrainedModel`, export aborted.")
|
||||
|
||||
setattr(model.config, "use_cache", True)
|
||||
if getattr(model.config, "torch_dtype", None) == "bfloat16":
|
||||
model = model.to(torch.bfloat16).to("cpu")
|
||||
if getattr(model, "quantization_method", None):
|
||||
model = model.to("cpu")
|
||||
elif hasattr(model.config, "torch_dtype"):
|
||||
model = model.to(getattr(model.config, "torch_dtype")).to("cpu")
|
||||
else:
|
||||
model = model.to(torch.float16).to("cpu")
|
||||
setattr(model.config, "torch_dtype", "float16")
|
||||
setattr(model.config, "torch_dtype", torch.float16)
|
||||
|
||||
model.save_pretrained(
|
||||
save_directory=model_args.export_dir,
|
||||
|
||||
@@ -4,7 +4,7 @@ import torch
|
||||
|
||||
from ..extras.logging import get_logger
|
||||
from ..hparams import FinetuningArguments, ModelArguments
|
||||
from ..model import get_modelcard_args, load_model_and_tokenizer, load_valuehead_params
|
||||
from ..model import load_model_and_tokenizer, load_valuehead_params
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -25,14 +25,18 @@ def create_modelcard_and_push(
|
||||
training_args: "Seq2SeqTrainingArguments",
|
||||
finetuning_args: "FinetuningArguments",
|
||||
) -> None:
|
||||
if training_args.do_train:
|
||||
if training_args.push_to_hub:
|
||||
trainer.push_to_hub(**get_modelcard_args(model_args, data_args, finetuning_args))
|
||||
return
|
||||
try:
|
||||
trainer.create_model_card(**get_modelcard_args(model_args, data_args, finetuning_args))
|
||||
except Exception as err:
|
||||
logger.warning("Failed to create model card: {}".format(str(err)))
|
||||
kwargs = {
|
||||
"tasks": "text-generation",
|
||||
"finetuned_from": model_args.model_name_or_path,
|
||||
"dataset": [dataset.strip() for dataset in data_args.dataset.split(",")],
|
||||
"tags": ["llama-factory", finetuning_args.finetuning_type],
|
||||
}
|
||||
if not training_args.do_train:
|
||||
pass
|
||||
elif training_args.push_to_hub:
|
||||
trainer.push_to_hub(**kwargs)
|
||||
else:
|
||||
trainer.create_model_card(license="other", **kwargs) # prevent from connecting to hub
|
||||
|
||||
|
||||
def create_ref_model(
|
||||
|
||||
@@ -106,6 +106,7 @@ class WebChatModel(ChatModel):
|
||||
def predict(
|
||||
self,
|
||||
chatbot: List[Tuple[str, str]],
|
||||
role: str,
|
||||
query: str,
|
||||
messages: Sequence[Tuple[str, str]],
|
||||
system: str,
|
||||
@@ -115,7 +116,7 @@ class WebChatModel(ChatModel):
|
||||
temperature: float,
|
||||
) -> Generator[Tuple[Sequence[Tuple[str, str]], Sequence[Tuple[str, str]]], None, None]:
|
||||
chatbot.append([query, ""])
|
||||
query_messages = messages + [{"role": Role.USER, "content": query}]
|
||||
query_messages = messages + [{"role": role, "content": query}]
|
||||
response = ""
|
||||
for new_text in self.stream_chat(
|
||||
query_messages, system, tools, max_new_tokens=max_new_tokens, top_p=top_p, temperature=temperature
|
||||
@@ -130,10 +131,10 @@ class WebChatModel(ChatModel):
|
||||
name, arguments = result
|
||||
arguments = json.loads(arguments)
|
||||
tool_call = json.dumps({"name": name, "arguments": arguments}, ensure_ascii=False)
|
||||
output_messages = query_messages + [{"role": Role.FUNCTION, "content": tool_call}]
|
||||
output_messages = query_messages + [{"role": Role.FUNCTION.value, "content": tool_call}]
|
||||
bot_text = "```json\n" + tool_call + "\n```"
|
||||
else:
|
||||
output_messages = query_messages + [{"role": Role.ASSISTANT, "content": result}]
|
||||
output_messages = query_messages + [{"role": Role.ASSISTANT.value, "content": result}]
|
||||
bot_text = result
|
||||
|
||||
chatbot[-1] = [query, self.postprocess(bot_text)]
|
||||
|
||||
@@ -2,6 +2,7 @@ from typing import TYPE_CHECKING, Dict, Optional, Tuple
|
||||
|
||||
import gradio as gr
|
||||
|
||||
from ...data import Role
|
||||
from ..utils import check_json_schema
|
||||
|
||||
|
||||
@@ -20,23 +21,24 @@ def create_chat_box(
|
||||
messages = gr.State([])
|
||||
with gr.Row():
|
||||
with gr.Column(scale=4):
|
||||
role = gr.Dropdown(choices=[Role.USER.value, Role.OBSERVATION.value], value=Role.USER.value)
|
||||
system = gr.Textbox(show_label=False)
|
||||
tools = gr.Textbox(show_label=False, lines=2)
|
||||
query = gr.Textbox(show_label=False, lines=8)
|
||||
submit_btn = gr.Button(variant="primary")
|
||||
|
||||
with gr.Column(scale=1):
|
||||
clear_btn = gr.Button()
|
||||
gen_kwargs = engine.chatter.generating_args
|
||||
max_new_tokens = gr.Slider(10, 2048, value=gen_kwargs.max_new_tokens, step=1)
|
||||
top_p = gr.Slider(0.01, 1, value=gen_kwargs.top_p, step=0.01)
|
||||
temperature = gr.Slider(0.01, 1.5, value=gen_kwargs.temperature, step=0.01)
|
||||
clear_btn = gr.Button()
|
||||
|
||||
tools.input(check_json_schema, [tools, engine.manager.get_elem_by_name("top.lang")])
|
||||
|
||||
submit_btn.click(
|
||||
engine.chatter.predict,
|
||||
[chatbot, query, messages, system, tools, max_new_tokens, top_p, temperature],
|
||||
[chatbot, role, query, messages, system, tools, max_new_tokens, top_p, temperature],
|
||||
[chatbot, messages],
|
||||
show_progress=True,
|
||||
).then(lambda: gr.update(value=""), outputs=[query])
|
||||
@@ -48,13 +50,14 @@ def create_chat_box(
|
||||
chatbot,
|
||||
messages,
|
||||
dict(
|
||||
role=role,
|
||||
system=system,
|
||||
tools=tools,
|
||||
query=query,
|
||||
submit_btn=submit_btn,
|
||||
clear_btn=clear_btn,
|
||||
max_new_tokens=max_new_tokens,
|
||||
top_p=top_p,
|
||||
temperature=temperature,
|
||||
clear_btn=clear_btn,
|
||||
),
|
||||
)
|
||||
|
||||
@@ -26,6 +26,7 @@ def save_model(
|
||||
max_shard_size: int,
|
||||
export_quantization_bit: int,
|
||||
export_quantization_dataset: str,
|
||||
export_legacy_format: bool,
|
||||
export_dir: str,
|
||||
) -> Generator[str, None, None]:
|
||||
error = ""
|
||||
@@ -61,6 +62,7 @@ def save_model(
|
||||
export_size=max_shard_size,
|
||||
export_quantization_bit=int(export_quantization_bit) if export_quantization_bit in GPTQ_BITS else None,
|
||||
export_quantization_dataset=export_quantization_dataset,
|
||||
export_legacy_format=export_legacy_format,
|
||||
)
|
||||
|
||||
yield ALERTS["info_exporting"][lang]
|
||||
@@ -73,6 +75,7 @@ def create_export_tab(engine: "Engine") -> Dict[str, "Component"]:
|
||||
max_shard_size = gr.Slider(value=1, minimum=1, maximum=100)
|
||||
export_quantization_bit = gr.Dropdown(choices=["none", "8", "4", "3", "2"], value="none")
|
||||
export_quantization_dataset = gr.Textbox(value="data/c4_demo.json")
|
||||
export_legacy_format = gr.Checkbox()
|
||||
|
||||
export_dir = gr.Textbox()
|
||||
export_btn = gr.Button()
|
||||
@@ -90,6 +93,7 @@ def create_export_tab(engine: "Engine") -> Dict[str, "Component"]:
|
||||
max_shard_size,
|
||||
export_quantization_bit,
|
||||
export_quantization_dataset,
|
||||
export_legacy_format,
|
||||
export_dir,
|
||||
],
|
||||
[info_box],
|
||||
@@ -99,6 +103,7 @@ def create_export_tab(engine: "Engine") -> Dict[str, "Component"]:
|
||||
max_shard_size=max_shard_size,
|
||||
export_quantization_bit=export_quantization_bit,
|
||||
export_quantization_dataset=export_quantization_dataset,
|
||||
export_legacy_format=export_legacy_format,
|
||||
export_dir=export_dir,
|
||||
export_btn=export_btn,
|
||||
info_box=info_box,
|
||||
|
||||
@@ -16,7 +16,7 @@ def create_top() -> Dict[str, "Component"]:
|
||||
available_models = list(SUPPORTED_MODELS.keys()) + ["Custom"]
|
||||
|
||||
with gr.Row():
|
||||
lang = gr.Dropdown(choices=["en", "zh"], scale=1)
|
||||
lang = gr.Dropdown(choices=["en", "ru", "zh"], scale=1)
|
||||
model_name = gr.Dropdown(choices=available_models, scale=3)
|
||||
model_path = gr.Textbox(scale=3)
|
||||
|
||||
@@ -30,13 +30,11 @@ def create_top() -> Dict[str, "Component"]:
|
||||
quantization_bit = gr.Dropdown(choices=["none", "8", "4"], value="none")
|
||||
template = gr.Dropdown(choices=list(templates.keys()), value="default")
|
||||
rope_scaling = gr.Radio(choices=["none", "linear", "dynamic"], value="none")
|
||||
booster = gr.Radio(choices=["none", "flash_attn", "unsloth"], value="none")
|
||||
booster = gr.Radio(choices=["none", "flashattn", "unsloth"], value="none")
|
||||
|
||||
model_name.change(list_adapters, [model_name, finetuning_type], [adapter_path], queue=False).then(
|
||||
get_model_path, [model_name], [model_path], queue=False
|
||||
).then(
|
||||
get_template, [model_name], [template], queue=False
|
||||
) # do not save config since the below line will save
|
||||
).then(get_template, [model_name], [template], queue=False) # do not save config since the below line will save
|
||||
|
||||
model_path.change(save_config, inputs=[lang, model_name, model_path], queue=False)
|
||||
|
||||
|
||||
@@ -34,7 +34,7 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
|
||||
elem_dict.update(dict(training_stage=training_stage, dataset_dir=dataset_dir, dataset=dataset, **preview_elems))
|
||||
|
||||
with gr.Row():
|
||||
cutoff_len = gr.Slider(value=1024, minimum=4, maximum=8192, step=1)
|
||||
cutoff_len = gr.Slider(value=1024, minimum=4, maximum=16384, step=1)
|
||||
learning_rate = gr.Textbox(value="5e-5")
|
||||
num_train_epochs = gr.Textbox(value="3.0")
|
||||
max_samples = gr.Textbox(value="100000")
|
||||
@@ -52,8 +52,8 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
|
||||
)
|
||||
|
||||
with gr.Row():
|
||||
batch_size = gr.Slider(value=4, minimum=1, maximum=512, step=1)
|
||||
gradient_accumulation_steps = gr.Slider(value=4, minimum=1, maximum=512, step=1)
|
||||
batch_size = gr.Slider(value=2, minimum=1, maximum=1024, step=1)
|
||||
gradient_accumulation_steps = gr.Slider(value=8, minimum=1, maximum=1024, step=1)
|
||||
lr_scheduler_type = gr.Dropdown(choices=[scheduler.value for scheduler in SchedulerType], value="cosine")
|
||||
max_grad_norm = gr.Textbox(value="1.0")
|
||||
val_size = gr.Slider(value=0, minimum=0, maximum=1, step=0.001)
|
||||
@@ -76,11 +76,24 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
|
||||
warmup_steps = gr.Slider(value=0, minimum=0, maximum=5000, step=1)
|
||||
neftune_alpha = gr.Slider(value=0, minimum=0, maximum=10, step=0.1)
|
||||
|
||||
with gr.Column():
|
||||
sft_packing = gr.Checkbox(value=False)
|
||||
upcast_layernorm = gr.Checkbox(value=False)
|
||||
with gr.Row():
|
||||
resize_vocab = gr.Checkbox()
|
||||
sft_packing = gr.Checkbox()
|
||||
upcast_layernorm = gr.Checkbox()
|
||||
use_llama_pro = gr.Checkbox()
|
||||
|
||||
input_elems.update({logging_steps, save_steps, warmup_steps, neftune_alpha, sft_packing, upcast_layernorm})
|
||||
input_elems.update(
|
||||
{
|
||||
logging_steps,
|
||||
save_steps,
|
||||
warmup_steps,
|
||||
neftune_alpha,
|
||||
resize_vocab,
|
||||
sft_packing,
|
||||
upcast_layernorm,
|
||||
use_llama_pro,
|
||||
}
|
||||
)
|
||||
elem_dict.update(
|
||||
dict(
|
||||
extra_tab=extra_tab,
|
||||
@@ -88,28 +101,52 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
|
||||
save_steps=save_steps,
|
||||
warmup_steps=warmup_steps,
|
||||
neftune_alpha=neftune_alpha,
|
||||
resize_vocab=resize_vocab,
|
||||
sft_packing=sft_packing,
|
||||
upcast_layernorm=upcast_layernorm,
|
||||
use_llama_pro=use_llama_pro,
|
||||
)
|
||||
)
|
||||
|
||||
with gr.Accordion(label="Freeze config", open=False) as freeze_tab:
|
||||
with gr.Row():
|
||||
num_layer_trainable = gr.Slider(value=3, minimum=1, maximum=128, step=1, scale=2)
|
||||
name_module_trainable = gr.Textbox(scale=3)
|
||||
|
||||
input_elems.update({num_layer_trainable, name_module_trainable})
|
||||
elem_dict.update(
|
||||
dict(
|
||||
freeze_tab=freeze_tab, num_layer_trainable=num_layer_trainable, name_module_trainable=name_module_trainable
|
||||
)
|
||||
)
|
||||
|
||||
with gr.Accordion(label="LoRA config", open=False) as lora_tab:
|
||||
with gr.Row():
|
||||
lora_rank = gr.Slider(value=8, minimum=1, maximum=1024, step=1, scale=1)
|
||||
lora_alpha = gr.Slider(value=16, minimum=1, maximum=2048, step=0.1, scale=1)
|
||||
lora_dropout = gr.Slider(value=0.1, minimum=0, maximum=1, step=0.01, scale=1)
|
||||
lora_target = gr.Textbox(scale=1)
|
||||
additional_target = gr.Textbox(scale=1)
|
||||
create_new_adapter = gr.Checkbox(scale=1)
|
||||
lora_target = gr.Textbox(scale=2)
|
||||
|
||||
input_elems.update({lora_rank, lora_dropout, lora_target, additional_target, create_new_adapter})
|
||||
with gr.Row():
|
||||
use_rslora = gr.Checkbox(scale=1)
|
||||
use_dora = gr.Checkbox(scale=1)
|
||||
create_new_adapter = gr.Checkbox(scale=1)
|
||||
additional_target = gr.Textbox(scale=2)
|
||||
|
||||
input_elems.update(
|
||||
{lora_rank, lora_alpha, lora_dropout, lora_target, use_rslora, use_dora, create_new_adapter, additional_target}
|
||||
)
|
||||
elem_dict.update(
|
||||
dict(
|
||||
lora_tab=lora_tab,
|
||||
lora_rank=lora_rank,
|
||||
lora_alpha=lora_alpha,
|
||||
lora_dropout=lora_dropout,
|
||||
lora_target=lora_target,
|
||||
additional_target=additional_target,
|
||||
use_rslora=use_rslora,
|
||||
use_dora=use_dora,
|
||||
create_new_adapter=create_new_adapter,
|
||||
additional_target=additional_target,
|
||||
)
|
||||
)
|
||||
|
||||
@@ -143,7 +180,7 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
|
||||
output_dir = gr.Textbox()
|
||||
|
||||
with gr.Row():
|
||||
resume_btn = gr.Checkbox(visible=False, interactive=False, value=False)
|
||||
resume_btn = gr.Checkbox(visible=False, interactive=False)
|
||||
process_bar = gr.Slider(visible=False, interactive=False)
|
||||
|
||||
with gr.Box():
|
||||
|
||||
File diff suppressed because it is too large
@@ -8,6 +8,7 @@ import gradio as gr
|
||||
import transformers
|
||||
from gradio.components import Component # cannot use TYPE_CHECKING here
|
||||
from transformers.trainer import TRAINING_ARGS_NAME
|
||||
from transformers.utils import is_torch_cuda_available
|
||||
|
||||
from ..extras.callbacks import LogCallback
|
||||
from ..extras.constants import TRAINING_STAGES
|
||||
@@ -64,12 +65,15 @@ class Runner:
|
||||
if len(dataset) == 0:
|
||||
return ALERTS["err_no_dataset"][lang]
|
||||
|
||||
if self.demo_mode and (not from_preview):
|
||||
if not from_preview and self.demo_mode:
|
||||
return ALERTS["err_demo"][lang]
|
||||
|
||||
if not from_preview and get_device_count() > 1:
|
||||
return ALERTS["err_device_count"][lang]
|
||||
|
||||
if not from_preview and not is_torch_cuda_available():
|
||||
gr.Warning(ALERTS["warn_no_cuda"][lang])
|
||||
|
||||
self.aborted = False
|
||||
self.logger_handler.reset()
|
||||
self.trainer_callback = LogCallback(self)
|
||||
@@ -125,27 +129,40 @@ class Runner:
|
||||
save_steps=get("train.save_steps"),
|
||||
warmup_steps=get("train.warmup_steps"),
|
||||
neftune_noise_alpha=get("train.neftune_alpha") or None,
|
||||
resize_vocab=get("train.resize_vocab"),
|
||||
sft_packing=get("train.sft_packing"),
|
||||
upcast_layernorm=get("train.upcast_layernorm"),
|
||||
lora_rank=get("train.lora_rank"),
|
||||
lora_dropout=get("train.lora_dropout"),
|
||||
lora_target=get("train.lora_target") or get_module(get("top.model_name")),
|
||||
additional_target=get("train.additional_target") or None,
|
||||
create_new_adapter=get("train.create_new_adapter"),
|
||||
use_llama_pro=get("train.use_llama_pro"),
|
||||
output_dir=get_save_dir(get("top.model_name"), get("top.finetuning_type"), get("train.output_dir")),
|
||||
fp16=(get("train.compute_type") == "fp16"),
|
||||
bf16=(get("train.compute_type") == "bf16"),
|
||||
)
|
||||
args["disable_tqdm"] = True
|
||||
|
||||
if TRAINING_STAGES[get("train.training_stage")] in ["rm", "ppo", "dpo"]:
|
||||
args["create_new_adapter"] = args["quantization_bit"] is None
|
||||
if args["finetuning_type"] == "freeze":
|
||||
args["num_layer_trainable"] = int(get("train.num_layer_trainable"))
|
||||
args["name_module_trainable"] = get("train.name_module_trainable")
|
||||
elif args["finetuning_type"] == "lora":
|
||||
args["lora_rank"] = int(get("train.lora_rank"))
|
||||
args["lora_alpha"] = float(get("train.lora_alpha"))
|
||||
args["lora_dropout"] = float(get("train.lora_dropout"))
|
||||
args["lora_target"] = get("train.lora_target") or get_module(get("top.model_name"))
|
||||
args["use_rslora"] = get("train.use_rslora")
|
||||
args["use_dora"] = get("train.use_dora")
|
||||
args["additional_target"] = get("train.additional_target") or None
|
||||
if args["stage"] in ["rm", "ppo", "dpo"]:
|
||||
args["create_new_adapter"] = args["quantization_bit"] is None
|
||||
else:
|
||||
args["create_new_adapter"] = get("train.create_new_adapter")
|
||||
|
||||
if args["use_llama_pro"]:
|
||||
args["num_layer_trainable"] = int(get("train.num_layer_trainable"))
|
||||
|
||||
if args["stage"] == "ppo":
|
||||
args["reward_model"] = get_save_dir(
|
||||
get("top.model_name"), get("top.finetuning_type"), get("train.reward_model")
|
||||
)
|
||||
args["reward_model_type"] = "lora" if get("top.finetuning_type") == "lora" else "full"
|
||||
args["reward_model_type"] = "lora" if args["finetuning_type"] == "lora" else "full"
|
||||
|
||||
if args["stage"] == "dpo":
|
||||
args["dpo_beta"] = get("train.dpo_beta")
|
||||
@@ -154,8 +171,9 @@ class Runner:
|
||||
if get("train.val_size") > 1e-6 and args["stage"] != "ppo":
|
||||
args["val_size"] = get("train.val_size")
|
||||
args["evaluation_strategy"] = "steps"
|
||||
args["eval_steps"] = get("train.save_steps")
|
||||
args["load_best_model_at_end"] = True
|
||||
args["eval_steps"] = args["save_steps"]
|
||||
args["per_device_eval_batch_size"] = args["per_device_train_batch_size"]
|
||||
args["load_best_model_at_end"] = args["stage"] not in ["rm", "ppo"]
|
||||
|
||||
return args
|
||||
|
||||
|
||||
@@ -44,11 +44,14 @@ def can_quantize(finetuning_type: str) -> Dict[str, Any]:
def check_json_schema(text: str, lang: str) -> None:
try:
tools = json.loads(text)
for tool in tools:
assert "name" in tool
except AssertionError:
if tools:
assert isinstance(tools, list)
for tool in tools:
if "name" not in tool:
raise ValueError("Name not found.")
except ValueError:
gr.Warning(ALERTS["err_tool_name"][lang])
except json.JSONDecodeError:
except Exception:
gr.Warning(ALERTS["err_json_schema"][lang])



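# Illustrative input for the schema check above (not part of this diff): a minimal tools
# definition passes because it parses as a JSON list and every entry carries a "name".
import json

sample = '[{"name": "get_weather", "description": "Query the weather."}]'  # assumed example
tools = json.loads(sample)
assert isinstance(tools, list) and all("name" in tool for tool in tools)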
@@ -10,7 +10,7 @@ import fire
|
||||
import torch
|
||||
from torch.utils.data import DataLoader
|
||||
from tqdm import tqdm
|
||||
from transformers import DataCollatorForSeq2Seq
|
||||
from transformers import DataCollatorForLanguageModeling, DataCollatorForSeq2Seq
|
||||
|
||||
from llmtuner.data import get_dataset
|
||||
from llmtuner.extras.constants import IGNORE_INDEX
|
||||
@@ -24,26 +24,35 @@ BASE_BS = 4_000_000 # from llama paper
|
||||
|
||||
def calculate_lr(
|
||||
model_name_or_path: str,
|
||||
dataset: str,
|
||||
cutoff_len: int, # i.e. maximum input length during training
|
||||
batch_size: int, # total batch size, namely (batch size * gradient accumulation * world size)
|
||||
is_mistral: bool, # mistral model uses a smaller learning rate,
|
||||
stage: Optional[str] = "sft",
|
||||
dataset: Optional[str] = "alpaca_en",
|
||||
dataset_dir: Optional[str] = "data",
|
||||
template: Optional[str] = "default",
|
||||
cutoff_len: Optional[int] = 1024, # i.e. maximum input length during training
|
||||
is_mistral: Optional[bool] = False, # mistral model uses a smaller learning rate,
|
||||
):
|
||||
model_args, data_args, training_args, finetuning_args, _ = get_train_args(
|
||||
dict(
|
||||
stage="sft",
|
||||
stage=stage,
|
||||
model_name_or_path=model_name_or_path,
|
||||
dataset=dataset,
|
||||
dataset_dir=dataset_dir,
|
||||
template="default",
|
||||
template=template,
|
||||
cutoff_len=cutoff_len,
|
||||
output_dir="dummy_dir",
|
||||
overwrite_cache=True,
|
||||
)
|
||||
)
|
||||
_, tokenizer = load_model_and_tokenizer(model_args, finetuning_args, is_trainable=False, add_valuehead=False)
|
||||
trainset = get_dataset(tokenizer, model_args, data_args, training_args, stage="sft")
|
||||
data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, label_pad_token_id=IGNORE_INDEX)
|
||||
trainset = get_dataset(tokenizer, model_args, data_args, training_args, stage=stage)
|
||||
if stage == "pt":
|
||||
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
|
||||
elif stage == "sft":
|
||||
data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, label_pad_token_id=IGNORE_INDEX)
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
dataloader = DataLoader(
|
||||
dataset=trainset, batch_size=batch_size, shuffle=True, collate_fn=data_collator, pin_memory=True
|
||||
)
|
||||
|
||||
52 tests/length_cdf.py Normal file
@@ -0,0 +1,52 @@
|
||||
# coding=utf-8
|
||||
# Calculates the distribution of the input lengths in the dataset.
|
||||
# Usage: python length_cdf.py --model_name_or_path path_to_model --dataset alpaca_en --template default
|
||||
|
||||
from collections import defaultdict
|
||||
from typing import Optional
|
||||
|
||||
import fire
|
||||
from tqdm import tqdm
|
||||
|
||||
from llmtuner.data import get_dataset
|
||||
from llmtuner.hparams import get_train_args
|
||||
from llmtuner.model import load_model_and_tokenizer
|
||||
|
||||
|
||||
def length_cdf(
|
||||
model_name_or_path: str,
|
||||
dataset: Optional[str] = "alpaca_en",
|
||||
dataset_dir: Optional[str] = "data",
|
||||
template: Optional[str] = "default",
|
||||
interval: Optional[int] = 1000,
|
||||
):
|
||||
model_args, data_args, training_args, finetuning_args, _ = get_train_args(
|
||||
dict(
|
||||
stage="sft",
|
||||
model_name_or_path=model_name_or_path,
|
||||
dataset=dataset,
|
||||
dataset_dir=dataset_dir,
|
||||
template=template,
|
||||
cutoff_len=1_000_000,
|
||||
output_dir="dummy_dir",
|
||||
overwrite_cache=True,
|
||||
)
|
||||
)
|
||||
_, tokenizer = load_model_and_tokenizer(model_args, finetuning_args, is_trainable=False, add_valuehead=False)
|
||||
trainset = get_dataset(tokenizer, model_args, data_args, training_args, stage="sft")
|
||||
total_num = len(trainset)
|
||||
length_dict = defaultdict(int)
|
||||
for sample in tqdm(trainset["input_ids"]):
|
||||
length_dict[len(sample) // interval * interval] += 1
|
||||
|
||||
length_tuples = list(length_dict.items())
|
||||
length_tuples.sort()
|
||||
count_accu, prob_accu = 0, 0
|
||||
for length, count in length_tuples:
|
||||
count_accu += count
|
||||
prob_accu += count / total_num * 100
|
||||
print("{:d} ({:.2f}%) samples have length < {}.".format(count_accu, prob_accu, length + interval))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
fire.Fire(length_cdf)
|
||||
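# Worked example (illustrative) for the bucketing in length_cdf above: with interval=1000,
# a sample of 1532 tokens lands in bucket 1532 // 1000 * 1000 = 1000 and is therefore counted
# in the cumulative line "... samples have length < 2000."
interval = 1000
print(1532 // interval * interval)  # 1000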
115 tests/llama_pro.py Normal file
@@ -0,0 +1,115 @@
|
||||
# coding=utf-8
|
||||
# Performs block expansion for LLaMA, Mistral or Qwen1.5 models.
|
||||
# Usage: python llama_pro.py --model_name_or_path meta-llama/Llama-2-7b-hf --output_dir llama2_pro --num_expand 8
|
||||
# Inspired by: https://github.com/TencentARC/LLaMA-Pro/blob/main/scripts/block_expansion.py
|
||||
|
||||
import json
|
||||
import os
|
||||
from collections import OrderedDict
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
|
||||
import fire
|
||||
import torch
|
||||
from safetensors.torch import save_file
|
||||
from tqdm import tqdm
|
||||
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
|
||||
from transformers.modeling_utils import (
|
||||
SAFE_WEIGHTS_INDEX_NAME,
|
||||
SAFE_WEIGHTS_NAME,
|
||||
WEIGHTS_INDEX_NAME,
|
||||
WEIGHTS_NAME,
|
||||
shard_checkpoint,
|
||||
)
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from transformers import PretrainedConfig, PreTrainedModel
|
||||
|
||||
|
||||
def change_name(name: str, old_index: int, new_index: int) -> str:
|
||||
return name.replace(".{:d}.".format(old_index), ".{:d}.".format(new_index))
|
||||
|
||||
|
||||
def block_expansion(
|
||||
model_name_or_path: str,
|
||||
output_dir: str,
|
||||
num_expand: int,
|
||||
shard_size: Optional[str] = "2GB",
|
||||
save_safetensors: Optional[bool] = False,
|
||||
):
|
||||
config: "PretrainedConfig" = AutoConfig.from_pretrained(model_name_or_path)
|
||||
num_layers = getattr(config, "num_hidden_layers")
|
||||
setattr(config, "num_hidden_layers", num_layers + num_expand)
|
||||
config.save_pretrained(output_dir)
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
|
||||
tokenizer.save_pretrained(output_dir)
|
||||
|
||||
config: "PretrainedConfig" = AutoConfig.from_pretrained(model_name_or_path) # load the original one
|
||||
if save_safetensors:
|
||||
setattr(config, "tie_word_embeddings", False) # safetensors does not allow shared weights
|
||||
|
||||
model: "PreTrainedModel" = AutoModelForCausalLM.from_pretrained(
|
||||
model_name_or_path,
|
||||
config=config,
|
||||
torch_dtype="auto",
|
||||
trust_remote_code=True,
|
||||
low_cpu_mem_usage=True,
|
||||
)
|
||||
state_dict = model.state_dict()
|
||||
|
||||
if num_layers % num_expand != 0:
|
||||
raise ValueError("`num_layers` {} should be divisible by `num_expand` {}.".format(num_layers, num_expand))
|
||||
|
||||
split = num_layers // num_expand
|
||||
layer_cnt = 0
|
||||
output_state_dict = OrderedDict()
|
||||
for i in range(num_layers):
|
||||
for key, value in state_dict.items():
|
||||
if ".{:d}.".format(i) in key:
|
||||
output_state_dict[change_name(key, i, layer_cnt)] = value
|
||||
|
||||
print("Add layer {} copied from layer {}".format(layer_cnt, i))
|
||||
layer_cnt += 1
|
||||
if (i + 1) % split == 0:
|
||||
for key, value in state_dict.items():
|
||||
if ".{:d}.".format(i) in key:
|
||||
if "down_proj" in key or "o_proj" in key:
|
||||
output_state_dict[change_name(key, i, layer_cnt)] = torch.zeros_like(value)
|
||||
else:
|
||||
output_state_dict[change_name(key, i, layer_cnt)] = torch.clone(value)
|
||||
|
||||
print("Add layer {} expanded from layer {}".format(layer_cnt, i))
|
||||
layer_cnt += 1
|
||||
|
||||
for key, value in state_dict.items():
|
||||
if key not in output_state_dict:
|
||||
output_state_dict[key] = value
|
||||
|
||||
weights_name = SAFE_WEIGHTS_NAME if save_safetensors else WEIGHTS_NAME
|
||||
shards, index = shard_checkpoint(output_state_dict, max_shard_size=shard_size, weights_name=weights_name)
|
||||
|
||||
for shard_file, shard in tqdm(shards.items(), desc="Save weights"):
|
||||
if save_safetensors:
|
||||
save_file(shard, os.path.join(output_dir, shard_file), metadata={"format": "pt"})
|
||||
else:
|
||||
torch.save(shard, os.path.join(output_dir, shard_file))
|
||||
|
||||
if index is None:
|
||||
print("Model weights saved in {}".format(os.path.join(output_dir, weights_name)))
|
||||
else:
|
||||
index_name = SAFE_WEIGHTS_INDEX_NAME if save_safetensors else WEIGHTS_INDEX_NAME
|
||||
with open(os.path.join(output_dir, index_name), "w", encoding="utf-8") as f:
|
||||
json.dump(index, f, indent=2, sort_keys=True)
|
||||
print("Model weights saved in {}".format(output_dir))
|
||||
|
||||
print("Fine-tune this model with:")
|
||||
print(" --model_name_or_path {} \\".format(output_dir))
|
||||
print(" --finetuning_type freeze \\")
|
||||
print(" --name_module_trainable all \\")
|
||||
print(" --num_layer_trainable {} \\".format(num_expand))
|
||||
print(" --use_llama_pro")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
fire.Fire(block_expansion)
|
||||
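# Small illustration of the renaming helper used by block_expansion above (not part of this
# diff); the parameter name is a typical LLaMA-style name, assumed for the example.
def change_name(name: str, old_index: int, new_index: int) -> str:
    return name.replace(".{:d}.".format(old_index), ".{:d}.".format(new_index))

print(change_name("model.layers.5.self_attn.q_proj.weight", 5, 6))
# -> "model.layers.6.self_attn.q_proj.weight"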
@@ -1,6 +1,6 @@
 # coding=utf-8
 # Converts the Baichuan2-7B model in the same format as LLaMA2-7B.
-# Usage: python llamafy_baichuan2.py --input_dir input --output_dir output --shard_size 10GB
+# Usage: python llamafy_baichuan2.py --input_dir input --output_dir output
 # Inspired by: https://huggingface.co/fireballoon/baichuan-llama-7b/blob/main/convert_baichuan_to_llama.py
 # Converted model: https://huggingface.co/hiyouga/Baichuan2-7B-Base-LLaMAfied

@@ -76,7 +76,9 @@ def save_config(input_dir: str, output_dir: str):
     print("Model config saved in {}".format(os.path.join(output_dir, CONFIG_NAME)))


-def llamafy_baichuan2(input_dir: str, output_dir: str, shard_size: str, save_safetensors: Optional[bool] = False):
+def llamafy_baichuan2(
+    input_dir: str, output_dir: str, shard_size: Optional[str] = "2GB", save_safetensors: Optional[bool] = False
+):
     try:
         os.makedirs(output_dir, exist_ok=False)
     except Exception as e:
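This hunk, like the matching InternLM2 and Qwen hunks below, gives shard_size a "2GB" default so the --shard_size flag becomes optional on the command line. A standalone sketch of how python-fire exposes such a default (illustrative only, using a hypothetical demo function):

from typing import Optional

import fire


def demo(input_dir: str, output_dir: str, shard_size: Optional[str] = "2GB"):
    # fire turns keyword defaults into optional CLI flags:
    #   python demo.py --input_dir in --output_dir out                    -> shard_size == "2GB"
    #   python demo.py --input_dir in --output_dir out --shard_size 10GB  -> shard_size == "10GB"
    print(input_dir, output_dir, shard_size)


if __name__ == "__main__":
    fire.Fire(demo)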
@@ -1,6 +1,6 @@
 # coding=utf-8
 # Converts the InternLM2 model in the same format as LLaMA2.
-# Usage: python llamafy_internlm2.py --input_dir input --output_dir output --shard_size 10GB
+# Usage: python llamafy_internlm2.py --input_dir input --output_dir output
 # Warning: We have found that the converted model cannot infer correctly. It will be fixed later.

 import json
@@ -98,7 +98,9 @@ def save_config(input_dir: str, output_dir: str):
     print("Model config saved in {}".format(os.path.join(output_dir, CONFIG_NAME)))


-def llamafy_internlm2(input_dir: str, output_dir: str, shard_size: str, save_safetensors: Optional[bool] = False):
+def llamafy_internlm2(
+    input_dir: str, output_dir: str, shard_size: Optional[str] = "2GB", save_safetensors: Optional[bool] = False
+):
     try:
         os.makedirs(output_dir, exist_ok=False)
     except Exception as e:
@@ -1,6 +1,6 @@
 # coding=utf-8
 # Converts the Qwen models in the same format as LLaMA2.
-# Usage: python llamafy_qwen.py --input_dir input --output_dir output --shard_size 10GB
+# Usage: python llamafy_qwen.py --input_dir input --output_dir output
 # Converted model: https://huggingface.co/hiyouga/Qwen-14B-Chat-LLaMAfied

 import json
@@ -128,7 +128,9 @@ def save_config(input_dir: str, output_dir: str, torch_dtype: str):
     print("Model config saved in {}".format(os.path.join(output_dir, CONFIG_NAME)))


-def llamafy_qwen(input_dir: str, output_dir: str, shard_size: str, save_safetensors: Optional[bool] = False):
+def llamafy_qwen(
+    input_dir: str, output_dir: str, shard_size: Optional[str] = "2GB", save_safetensors: Optional[bool] = False
+):
     try:
         os.makedirs(output_dir, exist_ok=False)
     except Exception as e:
@@ -26,7 +26,7 @@ class Shell(nn.Module):


 def unwrap_model(model: nn.Module, pattern=".base_layer") -> None:
-    for name in set([k.split(pattern)[0] for k, _ in model.named_modules() if pattern in k]):  # noqa: C403
+    for name in {k.split(pattern)[0] for k, _ in model.named_modules() if pattern in k}:
         parent_name = ".".join(name.split(".")[:-1])
         child_name = name.split(".")[-1]
         parent_module = model.get_submodule(parent_name)
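The only change in this hunk swaps set([...]) for an equivalent set comprehension, which is what ruff's C403 rule flags; the behavior is unchanged. A small self-contained check (illustrative, not from the repository):

modules = ["model.layers.0.q_proj.base_layer", "model.layers.1.v_proj.base_layer"]
names_old = set([k.split(".base_layer")[0] for k in modules if ".base_layer" in k])  # noqa: C403
names_new = {k.split(".base_layer")[0] for k in modules if ".base_layer" in k}
assert names_old == names_new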
57
tests/test_toolcall.py
Normal file
@@ -0,0 +1,57 @@
import json
from typing import Sequence

from openai import OpenAI
from transformers.utils.versions import require_version


require_version("openai>=1.5.0", "To fix: pip install openai>=1.5.0")


def calculate_gpa(grades: Sequence[str], hours: Sequence[int]) -> float:
    grade_to_score = {"A": 4, "B": 3, "C": 2}
    total_score, total_hour = 0, 0
    for grade, hour in zip(grades, hours):
        total_score += grade_to_score[grade] * hour
        total_hour += hour
    return total_score / total_hour


tool_map = {"calculate_gpa": calculate_gpa}


if __name__ == "__main__":
    client = OpenAI(
        api_key="0",
        base_url="http://localhost:8000/v1",
    )
    tools = [
        {
            "type": "function",
            "function": {
                "name": "calculate_gpa",
                "description": "Calculate the Grade Point Average (GPA) based on grades and credit hours",
                "parameters": {
                    "type": "object",
                    "properties": {
                        "grades": {"type": "array", "items": {"type": "string"}, "description": "The grades"},
                        "hours": {"type": "array", "items": {"type": "integer"}, "description": "The credit hours"},
                    },
                    "required": ["grades", "hours"],
                },
            },
        }
    ]
    messages = []
    messages.append({"role": "user", "content": "My grades are A, A, B, and C. The credit hours are 3, 4, 3, and 2."})
    result = client.chat.completions.create(messages=messages, model="test", tools=tools)
    tool_call = result.choices[0].message.tool_calls[0].function
    name, arguments = tool_call.name, json.loads(tool_call.arguments)
    messages.append(
        {"role": "function", "content": json.dumps({"name": name, "argument": arguments}, ensure_ascii=False)}
    )
    tool_result = tool_map[name](**arguments)
    messages.append({"role": "tool", "content": json.dumps({"gpa": tool_result}, ensure_ascii=False)})
    result = client.chat.completions.create(messages=messages, model="test", tools=tools)
    print(result.choices[0].message.content)
    # Based on your grades and credit hours, your calculated Grade Point Average (GPA) is 3.4166666666666665.
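As a sanity check on the expected answer in the final comment, the GPA for grades A, A, B, C with hours 3, 4, 3, 2 can be computed by hand (a worked check, not part of the test file):

# (4*3 + 4*4 + 3*3 + 2*2) / (3 + 4 + 3 + 2) = 41 / 12
print((4 * 3 + 4 * 4 + 3 * 3 + 2 * 2) / (3 + 4 + 3 + 2))  # 3.4166666666666665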