Compare commits
187 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
c0c387e4db | ||
|
|
ae60ea15da | ||
|
|
72cd1123a8 | ||
|
|
1364190a66 | ||
|
|
6d17c59090 | ||
|
|
e0f2c0b5dc | ||
|
|
073e34855d | ||
|
|
ff9ba70bb8 | ||
|
|
adbebb0e3f | ||
|
|
3f6b3eed98 | ||
|
|
f45e81e186 | ||
|
|
ba648fd003 | ||
|
|
b0e5a76f4c | ||
|
|
8692796c9b | ||
|
|
d0edcde4ea | ||
|
|
8c4c2e580c | ||
|
|
07f33e7641 | ||
|
|
1998c641af | ||
|
|
be1e5f9d62 | ||
|
|
fdeec6db52 | ||
|
|
a4d335b42f | ||
|
|
fcb134e144 | ||
|
|
a47e24222a | ||
|
|
b96b995620 | ||
|
|
c231706aa5 | ||
|
|
35b5117a59 | ||
|
|
80f716bc10 | ||
|
|
ca95e98ca0 | ||
|
|
d5559461c1 | ||
|
|
f4acd81e2f | ||
|
|
31feb6e26c | ||
|
|
7d5c0a069c | ||
|
|
937f49ec3d | ||
|
|
abc2a73a33 | ||
|
|
5e1bf7572c | ||
|
|
8fdb32d0a3 | ||
|
|
c709d5f7db | ||
|
|
f5b2749ec2 | ||
|
|
ee5853c565 | ||
|
|
6ec6df8a5f | ||
|
|
fc95800840 | ||
|
|
765715af21 | ||
|
|
639a7f6796 | ||
|
|
35379c7c0e | ||
|
|
d992f5353f | ||
|
|
875eef45f3 | ||
|
|
556a4aa972 | ||
|
|
8dc1969111 | ||
|
|
b74c229498 | ||
|
|
3dbca466fd | ||
|
|
ce6f7fdb82 | ||
|
|
7528bc1bc0 | ||
|
|
9dd5f7d642 | ||
|
|
99ecb0daaf | ||
|
|
39d8d7995a | ||
|
|
2ac2cde03e | ||
|
|
aa6c3766de | ||
|
|
f4f5d7e3ce | ||
|
|
efbf6018d3 | ||
|
|
1090bb8bf3 | ||
|
|
26bc79f971 | ||
|
|
4c1f015eca | ||
|
|
0655a183d3 | ||
|
|
7754024e9b | ||
|
|
b4913569a8 | ||
|
|
eae9f09ca8 | ||
|
|
8264e5ceaa | ||
|
|
b76f319e45 | ||
|
|
82d744716a | ||
|
|
1a3764ab8f | ||
|
|
d2ede9d393 | ||
|
|
5690f513fc | ||
|
|
123a845209 | ||
|
|
b1b7d735b3 | ||
|
|
230c69f7ce | ||
|
|
bfc43558ef | ||
|
|
f2ae2cc04d | ||
|
|
6e9c03f958 | ||
|
|
2696f614a7 | ||
|
|
070b944895 | ||
|
|
f5f091d390 | ||
|
|
14ab14a0e6 | ||
|
|
4f7c850115 | ||
|
|
391eca66cf | ||
|
|
a67199246d | ||
|
|
5f67fdaac9 | ||
|
|
05e6fe4287 | ||
|
|
91cc571e6e | ||
|
|
890926e60c | ||
|
|
87aa332583 | ||
|
|
f90c4ca672 | ||
|
|
a922e85a5c | ||
|
|
9a65820592 | ||
|
|
f4e16ae373 | ||
|
|
e2cfd34da0 | ||
|
|
668dea9706 | ||
|
|
084be442f2 | ||
|
|
29cb4a1327 | ||
|
|
81a61134b8 | ||
|
|
cb1a49aa02 | ||
|
|
351b4efc6c | ||
|
|
9b551309de | ||
|
|
9fed4a2ef4 | ||
|
|
bceac4f554 | ||
|
|
ae3a88d3a7 | ||
|
|
9138a7a5ba | ||
|
|
9912b43fcc | ||
|
|
5ac37555a4 | ||
|
|
34bdc730a6 | ||
|
|
e45a9d70fc | ||
|
|
232b36059c | ||
|
|
d9fbd675d5 | ||
|
|
0206e7b9de | ||
|
|
a886544d3d | ||
|
|
8c9b929bb0 | ||
|
|
1bb1ae834e | ||
|
|
0d9e364a90 | ||
|
|
3b28c003dd | ||
|
|
48ff9fb150 | ||
|
|
c43bc74fe6 | ||
|
|
eaf9cc2195 | ||
|
|
4bd276f58f | ||
|
|
f8cf0d5e5d | ||
|
|
79bc60db33 | ||
|
|
dc7c54067e | ||
|
|
932f0d5c20 | ||
|
|
9670f5e41a | ||
|
|
97a23e1cbe | ||
|
|
11fcd055ec | ||
|
|
b0d9966663 | ||
|
|
5c51ab7e1f | ||
|
|
26f293d587 | ||
|
|
a3b52fd380 | ||
|
|
27d8706d6d | ||
|
|
bf59383783 | ||
|
|
1078611259 | ||
|
|
e6fc0ac8fe | ||
|
|
554ca3d8dc | ||
|
|
86dfdf956d | ||
|
|
c0e4475485 | ||
|
|
2b65f8bd5c | ||
|
|
09e78272c2 | ||
|
|
cccce564bd | ||
|
|
4adec327de | ||
|
|
1f093334d1 | ||
|
|
e0e8507108 | ||
|
|
f5962f8128 | ||
|
|
b31d808655 | ||
|
|
247cda4b68 | ||
|
|
e30975e9a2 | ||
|
|
de9f1583c2 | ||
|
|
ab48653e63 | ||
|
|
6d7a1e3f8f | ||
|
|
e093dad7cb | ||
|
|
b103a121f0 | ||
|
|
3578abc7a4 | ||
|
|
17d398f419 | ||
|
|
3453a8eebb | ||
|
|
77a089c35c | ||
|
|
516d83c946 | ||
|
|
fd02c9f973 | ||
|
|
351e80a656 | ||
|
|
4f04e2ed93 | ||
|
|
a810d1b98e | ||
|
|
fbe963a96a | ||
|
|
d13b8bee8a | ||
|
|
0aa072a155 | ||
|
|
57dde7c3bc | ||
|
|
6b9003f781 | ||
|
|
9c1c59e481 | ||
|
|
31daec2749 | ||
|
|
2bff90719b | ||
|
|
e4570e28a8 | ||
|
|
d84a730daa | ||
|
|
0fd1a05cec | ||
|
|
6373d307ec | ||
|
|
a32c3a50fc | ||
|
|
66b5634ebf | ||
|
|
92b3697e2c | ||
|
|
969e605c7e | ||
|
|
a3320f26cf | ||
|
|
45329d9e3c | ||
|
|
6481321470 | ||
|
|
efcf5e050d | ||
|
|
dfa686b617 | ||
|
|
fe638cf11f | ||
|
|
7cdc16abdf |
@@ -4,6 +4,8 @@
|
||||
.venv
|
||||
cache
|
||||
data
|
||||
hf_cache
|
||||
output
|
||||
examples
|
||||
.dockerignore
|
||||
.gitattributes
|
||||
|
||||
26
.github/ISSUE_TEMPLATE/bug-report.yml
vendored
26
.github/ISSUE_TEMPLATE/bug-report.yml
vendored
@@ -13,6 +13,18 @@ body:
|
||||
- label: I have read the README and searched the existing issues.
|
||||
required: true
|
||||
|
||||
- type: textarea
|
||||
id: system-info
|
||||
validations:
|
||||
required: true
|
||||
attributes:
|
||||
label: System Info
|
||||
description: |
|
||||
Please share your system info with us. You can run the command **llamafactory-cli env** and copy-paste its output below.
|
||||
请提供您的系统信息。您可以在命令行运行 **llamafactory-cli env** 并将其输出复制到该文本框中。
|
||||
|
||||
placeholder: llamafactory version, platform, python version, ...
|
||||
|
||||
- type: textarea
|
||||
id: reproduction
|
||||
validations:
|
||||
@@ -26,7 +38,7 @@ body:
|
||||
请合理使用 Markdown 标签来格式化您的文本。
|
||||
|
||||
placeholder: |
|
||||
python src/train_bash.py ...
|
||||
llamafactory-cli train ...
|
||||
|
||||
- type: textarea
|
||||
id: expected-behavior
|
||||
@@ -38,18 +50,6 @@ body:
|
||||
Please provide a clear and concise description of what you would expect to happen.
|
||||
请提供您原本的目的,即这段代码的期望行为。
|
||||
|
||||
- type: textarea
|
||||
id: system-info
|
||||
validations:
|
||||
required: false
|
||||
attributes:
|
||||
label: System Info
|
||||
description: |
|
||||
Please share your system info with us. You can run the command **transformers-cli env** and copy-paste its output below.
|
||||
请提供您的系统信息。您可以在命令行运行 **transformers-cli env** 并将其输出复制到该文本框中。
|
||||
|
||||
placeholder: transformers version, platform, python version, ...
|
||||
|
||||
- type: textarea
|
||||
id: others
|
||||
validations:
|
||||
|
||||
28
.github/workflows/tests.yml
vendored
28
.github/workflows/tests.yml
vendored
@@ -2,28 +2,38 @@ name: tests
|
||||
|
||||
on:
|
||||
push:
|
||||
branches: [ "main" ]
|
||||
branches:
|
||||
- main
|
||||
paths:
|
||||
- "**.py"
|
||||
- "requirements.txt"
|
||||
- ".github/workflows/*.yml"
|
||||
pull_request:
|
||||
branches: [ "main" ]
|
||||
branches:
|
||||
- main
|
||||
paths:
|
||||
- "**.py"
|
||||
- "requirements.txt"
|
||||
- ".github/workflows/*.yml"
|
||||
|
||||
jobs:
|
||||
check_code_quality:
|
||||
|
||||
tests:
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: "3.8"
|
||||
|
||||
cache: "pip"
|
||||
cache-dependency-path: "setup.py"
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
python -m pip install --upgrade pip
|
||||
python -m pip install ruff
|
||||
|
||||
python -m pip install .[torch,dev]
|
||||
- name: Check quality
|
||||
run: |
|
||||
make style && make quality
|
||||
- name: Test with pytest
|
||||
run: |
|
||||
make test
|
||||
|
||||
@@ -6,7 +6,7 @@ COPY requirements.txt /app/
|
||||
RUN pip install -r requirements.txt
|
||||
|
||||
COPY . /app/
|
||||
RUN pip install -e .[deepspeed,metrics,bitsandbytes,qwen]
|
||||
RUN pip install -e .[metrics,bitsandbytes,qwen]
|
||||
|
||||
VOLUME [ "/root/.cache/huggingface/", "/app/data", "/app/output" ]
|
||||
EXPOSE 7860
|
||||
|
||||
5
Makefile
5
Makefile
@@ -1,4 +1,4 @@
|
||||
.PHONY: quality style
|
||||
.PHONY: quality style test
|
||||
|
||||
check_dirs := scripts src tests
|
||||
|
||||
@@ -9,3 +9,6 @@ quality:
|
||||
style:
|
||||
ruff check $(check_dirs) --fix
|
||||
ruff format $(check_dirs)
|
||||
|
||||
test:
|
||||
pytest tests/
|
||||
|
||||
206
README.md
206
README.md
@@ -3,15 +3,15 @@
|
||||
[](https://github.com/hiyouga/LLaMA-Factory/stargazers)
|
||||
[](LICENSE)
|
||||
[](https://github.com/hiyouga/LLaMA-Factory/commits/main)
|
||||
[](https://pypi.org/project/llmtuner/)
|
||||
[](https://pypi.org/project/llmtuner/)
|
||||
[](https://pypi.org/project/llamafactory/)
|
||||
[](#projects-using-llama-factory)
|
||||
[](https://github.com/hiyouga/LLaMA-Factory/pulls)
|
||||
[](https://discord.gg/rKfvV9r9FK)
|
||||
[](https://twitter.com/llamafactory_ai)
|
||||
[](https://colab.research.google.com/drive/1eRTPn37ltBbYsISy9Aw2NuI2Aq5CQrD9?usp=sharing)
|
||||
[](https://gallery.pai-ml.com/#/preview/deepLearning/nlp/llama_factory)
|
||||
[](https://huggingface.co/spaces/hiyouga/LLaMA-Board)
|
||||
[](https://modelscope.cn/studios/hiyouga/LLaMA-Board)
|
||||
[](https://colab.research.google.com/drive/1eRTPn37ltBbYsISy9Aw2NuI2Aq5CQrD9?usp=sharing)
|
||||
|
||||
[](https://trendshift.io/repositories/4535)
|
||||
|
||||
@@ -26,6 +26,7 @@ https://github.com/hiyouga/LLaMA-Factory/assets/16256802/9840a653-7e9c-41c8-ae89
|
||||
Choose your path:
|
||||
|
||||
- **Colab**: https://colab.research.google.com/drive/1eRTPn37ltBbYsISy9Aw2NuI2Aq5CQrD9?usp=sharing
|
||||
- **PAI-DSW**: https://gallery.pai-ml.com/#/preview/deepLearning/nlp/llama_factory
|
||||
- **Local machine**: Please refer to [usage](#getting-started)
|
||||
|
||||
## Table of Contents
|
||||
@@ -46,7 +47,7 @@ Choose your path:
|
||||
## Features
|
||||
|
||||
- **Various models**: LLaMA, LLaVA, Mistral, Mixtral-MoE, Qwen, Yi, Gemma, Baichuan, ChatGLM, Phi, etc.
|
||||
- **Integrated methods**: (Continuous) pre-training, (multimodal) supervised fine-tuning, reward modeling, PPO, DPO and ORPO.
|
||||
- **Integrated methods**: (Continuous) pre-training, (multimodal) supervised fine-tuning, reward modeling, PPO, DPO, KTO, ORPO, etc.
|
||||
- **Scalable resources**: 32-bit full-tuning, 16-bit freeze-tuning, 16-bit LoRA and 2/4/8-bit QLoRA via AQLM/AWQ/GPTQ/LLM.int8.
|
||||
- **Advanced algorithms**: GaLore, BAdam, DoRA, LongLoRA, LLaMA Pro, Mixture-of-Depths, LoRA+, LoftQ and Agent tuning.
|
||||
- **Practical tricks**: FlashAttention-2, Unsloth, RoPE scaling, NEFTune and rsLoRA.
|
||||
@@ -70,14 +71,22 @@ Compared to ChatGLM's [P-Tuning](https://github.com/THUDM/ChatGLM2-6B/tree/main/
|
||||
|
||||
## Changelog
|
||||
|
||||
[24/05/14] We supported training and inference on the Ascend NPU devices. Check [installation](#installation) section for details.
|
||||
[24/06/07] We supported fine-tuning the **[Qwen-2](https://qwenlm.github.io/blog/qwen2/)** series models.
|
||||
|
||||
[24/05/13] We supported fine-tuning the **Yi-1.5** series models.
|
||||
[24/06/05] We supported fine-tuning the **[GLM-4-9B/GLM-4-9B-Chat](https://github.com/THUDM/GLM-4)** models.
|
||||
|
||||
[24/04/26] We supported fine-tuning the **LLaVA-1.5** multimodal LLMs. See [examples](examples/README.md) for usage.
|
||||
[24/05/26] We supported **[SimPO](https://arxiv.org/abs/2405.14734)** algorithm for preference learning. See [examples](examples/README.md) for usage.
|
||||
|
||||
<details><summary>Full Changelog</summary>
|
||||
|
||||
[24/05/20] We supported fine-tuning the **PaliGemma** series models. Note that the PaliGemma models are pre-trained models, you need to fine-tune them with `gemma` template for chat completion.
|
||||
|
||||
[24/05/18] We supported **[KTO](https://arxiv.org/abs/2402.01306)** algorithm for preference learning. See [examples](examples/README.md) for usage.
|
||||
|
||||
[24/05/14] We supported training and inference on the Ascend NPU devices. Check [installation](#installation) section for details.
|
||||
|
||||
[24/04/26] We supported fine-tuning the **LLaVA-1.5** multimodal LLMs. See [examples](examples/README.md) for usage.
|
||||
|
||||
[24/04/22] We provided a **[Colab notebook](https://colab.research.google.com/drive/1eRTPn37ltBbYsISy9Aw2NuI2Aq5CQrD9?usp=sharing)** for fine-tuning the Llama-3 model on a free T4 GPU. Two Llama-3-derived models fine-tuned using LLaMA Factory are available at Hugging Face, check [Llama3-8B-Chinese-Chat](https://huggingface.co/shenzhi-wang/Llama3-8B-Chinese-Chat) and [Llama3-Chinese](https://huggingface.co/zhichen/Llama3-Chinese) for details.
|
||||
|
||||
[24/04/21] We supported **[Mixture-of-Depths](https://arxiv.org/abs/2404.02258)** according to [AstraMindAI's implementation](https://github.com/astramind-ai/Mixture-of-depths). See [examples](examples/README.md) for usage.
|
||||
@@ -104,7 +113,7 @@ Compared to ChatGLM's [P-Tuning](https://github.com/THUDM/ChatGLM2-6B/tree/main/
|
||||
|
||||
[24/02/05] Qwen1.5 (Qwen2 beta version) series models are supported in LLaMA-Factory. Check this [blog post](https://qwenlm.github.io/blog/qwen1.5/) for details.
|
||||
|
||||
[24/01/18] We supported **agent tuning** for most models, equipping model with tool using abilities by fine-tuning with `dataset: glaive_toolcall`.
|
||||
[24/01/18] We supported **agent tuning** for most models, equipping model with tool using abilities by fine-tuning with `dataset: glaive_toolcall_en`.
|
||||
|
||||
[23/12/23] We supported **[unsloth](https://github.com/unslothai/unsloth)**'s implementation to boost LoRA tuning for the LLaMA, Mistral and Yi models. Try `use_unsloth: true` argument to activate unsloth patch. It achieves **170%** speed in our benchmark, check [this page](https://github.com/hiyouga/LLaMA-Factory/wiki/Performance-comparison) for details.
|
||||
|
||||
@@ -142,43 +151,44 @@ Compared to ChatGLM's [P-Tuning](https://github.com/THUDM/ChatGLM2-6B/tree/main/
|
||||
|
||||
## Supported Models
|
||||
|
||||
| Model | Model size | Default module | Template |
|
||||
| -------------------------------------------------------- | -------------------------------- | ----------------- | --------- |
|
||||
| [Baichuan2](https://huggingface.co/baichuan-inc) | 7B/13B | W_pack | baichuan2 |
|
||||
| [BLOOM](https://huggingface.co/bigscience) | 560M/1.1B/1.7B/3B/7.1B/176B | query_key_value | - |
|
||||
| [BLOOMZ](https://huggingface.co/bigscience) | 560M/1.1B/1.7B/3B/7.1B/176B | query_key_value | - |
|
||||
| [ChatGLM3](https://huggingface.co/THUDM) | 6B | query_key_value | chatglm3 |
|
||||
| [Command-R](https://huggingface.co/CohereForAI) | 35B/104B | q_proj,v_proj | cohere |
|
||||
| [DeepSeek (MoE)](https://huggingface.co/deepseek-ai) | 7B/16B/67B/236B | q_proj,v_proj | deepseek |
|
||||
| [Falcon](https://huggingface.co/tiiuae) | 7B/40B/180B | query_key_value | falcon |
|
||||
| [Gemma/CodeGemma](https://huggingface.co/google) | 2B/7B | q_proj,v_proj | gemma |
|
||||
| [InternLM2](https://huggingface.co/internlm) | 7B/20B | wqkv | intern2 |
|
||||
| [LLaMA](https://github.com/facebookresearch/llama) | 7B/13B/33B/65B | q_proj,v_proj | - |
|
||||
| [LLaMA-2](https://huggingface.co/meta-llama) | 7B/13B/70B | q_proj,v_proj | llama2 |
|
||||
| [LLaMA-3](https://huggingface.co/meta-llama) | 8B/70B | q_proj,v_proj | llama3 |
|
||||
| [LLaVA-1.5](https://huggingface.co/llava-hf) | 7B/13B | q_proj,v_proj | vicuna |
|
||||
| [Mistral/Mixtral](https://huggingface.co/mistralai) | 7B/8x7B/8x22B | q_proj,v_proj | mistral |
|
||||
| [OLMo](https://huggingface.co/allenai) | 1B/7B | q_proj,v_proj | - |
|
||||
| [Phi-1.5/2](https://huggingface.co/microsoft) | 1.3B/2.7B | q_proj,v_proj | - |
|
||||
| [Phi-3](https://huggingface.co/microsoft) | 3.8B | qkv_proj | phi |
|
||||
| [Qwen](https://huggingface.co/Qwen) | 1.8B/7B/14B/72B | c_attn | qwen |
|
||||
| [Qwen1.5 (Code/MoE)](https://huggingface.co/Qwen) | 0.5B/1.8B/4B/7B/14B/32B/72B/110B | q_proj,v_proj | qwen |
|
||||
| [StarCoder2](https://huggingface.co/bigcode) | 3B/7B/15B | q_proj,v_proj | - |
|
||||
| [XVERSE](https://huggingface.co/xverse) | 7B/13B/65B | q_proj,v_proj | xverse |
|
||||
| [Yi (1/1.5)](https://huggingface.co/01-ai) | 6B/9B/34B | q_proj,v_proj | yi |
|
||||
| [Yi-VL](https://huggingface.co/01-ai) | 6B/34B | q_proj,v_proj | yi_vl |
|
||||
| [Yuan](https://huggingface.co/IEITYuan) | 2B/51B/102B | q_proj,v_proj | yuan |
|
||||
| Model | Model size | Template |
|
||||
| -------------------------------------------------------- | -------------------------------- | --------- |
|
||||
| [Baichuan2](https://huggingface.co/baichuan-inc) | 7B/13B | baichuan2 |
|
||||
| [BLOOM](https://huggingface.co/bigscience) | 560M/1.1B/1.7B/3B/7.1B/176B | - |
|
||||
| [BLOOMZ](https://huggingface.co/bigscience) | 560M/1.1B/1.7B/3B/7.1B/176B | - |
|
||||
| [ChatGLM3](https://huggingface.co/THUDM) | 6B | chatglm3 |
|
||||
| [Command-R](https://huggingface.co/CohereForAI) | 35B/104B | cohere |
|
||||
| [DeepSeek (MoE)](https://huggingface.co/deepseek-ai) | 7B/16B/67B/236B | deepseek |
|
||||
| [Falcon](https://huggingface.co/tiiuae) | 7B/11B/40B/180B | falcon |
|
||||
| [Gemma/CodeGemma](https://huggingface.co/google) | 2B/7B | gemma |
|
||||
| [GLM4](https://huggingface.co/THUDM) | 9B | glm4 |
|
||||
| [InternLM2](https://huggingface.co/internlm) | 7B/20B | intern2 |
|
||||
| [LLaMA](https://github.com/facebookresearch/llama) | 7B/13B/33B/65B | - |
|
||||
| [LLaMA-2](https://huggingface.co/meta-llama) | 7B/13B/70B | llama2 |
|
||||
| [LLaMA-3](https://huggingface.co/meta-llama) | 8B/70B | llama3 |
|
||||
| [LLaVA-1.5](https://huggingface.co/llava-hf) | 7B/13B | vicuna |
|
||||
| [Mistral/Mixtral](https://huggingface.co/mistralai) | 7B/8x7B/8x22B | mistral |
|
||||
| [OLMo](https://huggingface.co/allenai) | 1B/7B | - |
|
||||
| [PaliGemma](https://huggingface.co/google) | 3B | gemma |
|
||||
| [Phi-1.5/2](https://huggingface.co/microsoft) | 1.3B/2.7B | - |
|
||||
| [Phi-3](https://huggingface.co/microsoft) | 4B/7B/14B | phi |
|
||||
| [Qwen](https://huggingface.co/Qwen) | 1.8B/7B/14B/72B | qwen |
|
||||
| [Qwen1.5 (Code/MoE)](https://huggingface.co/Qwen) | 0.5B/1.8B/4B/7B/14B/32B/72B/110B | qwen |
|
||||
| [Qwen2 (MoE)](https://huggingface.co/Qwen) | 0.5B/1.5B/7B/57B/72B | qwen |
|
||||
| [StarCoder2](https://huggingface.co/bigcode) | 3B/7B/15B | - |
|
||||
| [XVERSE](https://huggingface.co/xverse) | 7B/13B/65B | xverse |
|
||||
| [Yi (1/1.5)](https://huggingface.co/01-ai) | 6B/9B/34B | yi |
|
||||
| [Yi-VL](https://huggingface.co/01-ai) | 6B/34B | yi_vl |
|
||||
| [Yuan](https://huggingface.co/IEITYuan) | 2B/51B/102B | yuan |
|
||||
|
||||
> [!NOTE]
|
||||
> **Default module** is used for the `--lora_target` argument, you can use `--lora_target all` to specify all the available modules for better convergence.
|
||||
>
|
||||
> For the "base" models, the `--template` argument can be chosen from `default`, `alpaca`, `vicuna` etc. But make sure to use the **corresponding template** for the "instruct/chat" models.
|
||||
> For the "base" models, the `template` argument can be chosen from `default`, `alpaca`, `vicuna` etc. But make sure to use the **corresponding template** for the "instruct/chat" models.
|
||||
>
|
||||
> Remember to use the **SAME** template in training and inference.
|
||||
|
||||
Please refer to [constants.py](src/llmtuner/extras/constants.py) for a full list of models we supported.
|
||||
Please refer to [constants.py](src/llamafactory/extras/constants.py) for a full list of models we supported.
|
||||
|
||||
You also can add a custom chat template to [template.py](src/llmtuner/data/template.py).
|
||||
You also can add a custom chat template to [template.py](src/llamafactory/data/template.py).
|
||||
|
||||
## Supported Training Approaches
|
||||
|
||||
@@ -189,7 +199,9 @@ You also can add a custom chat template to [template.py](src/llmtuner/data/templ
|
||||
| Reward Modeling | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
|
||||
| PPO Training | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
|
||||
| DPO Training | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
|
||||
| KTO Training | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
|
||||
| ORPO Training | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
|
||||
| SimPO Training | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
|
||||
|
||||
## Provided Datasets
|
||||
|
||||
@@ -202,6 +214,8 @@ You also can add a custom chat template to [template.py](src/llmtuner/data/templ
|
||||
- [Wikipedia (zh)](https://huggingface.co/datasets/pleisto/wikipedia-cn-20230720-filtered)
|
||||
- [Pile (en)](https://huggingface.co/datasets/EleutherAI/pile)
|
||||
- [SkyPile (zh)](https://huggingface.co/datasets/Skywork/SkyPile-150B)
|
||||
- [FineWeb (en)](https://huggingface.co/datasets/HuggingFaceFW/fineweb)
|
||||
- [FineWeb-Edu (en)](https://huggingface.co/datasets/HuggingFaceFW/fineweb-edu)
|
||||
- [The Stack (en)](https://huggingface.co/datasets/bigcode/the-stack)
|
||||
- [StarCoder (en)](https://huggingface.co/datasets/bigcode/starcoderdata)
|
||||
|
||||
@@ -209,12 +223,12 @@ You also can add a custom chat template to [template.py](src/llmtuner/data/templ
|
||||
|
||||
<details><summary>Supervised fine-tuning datasets</summary>
|
||||
|
||||
- [Stanford Alpaca (en)](https://github.com/tatsu-lab/stanford_alpaca)
|
||||
- [Stanford Alpaca (zh)](https://github.com/ymcui/Chinese-LLaMA-Alpaca)
|
||||
- [Alpaca GPT4 (en&zh)](https://github.com/Instruction-Tuning-with-GPT-4/GPT-4-LLM)
|
||||
- [Identity (en&zh)](data/identity.json)
|
||||
- [Open Assistant (zh)](https://huggingface.co/datasets/OpenAssistant/oasst1)
|
||||
- [ShareGPT (zh)](https://huggingface.co/datasets/QingyiSi/Alpaca-CoT/tree/main/Chinese-instruction-collection)
|
||||
- [Stanford Alpaca (en)](https://github.com/tatsu-lab/stanford_alpaca)
|
||||
- [Stanford Alpaca (zh)](https://github.com/ymcui/Chinese-LLaMA-Alpaca-3)
|
||||
- [Alpaca GPT4 (en&zh)](https://github.com/Instruction-Tuning-with-GPT-4/GPT-4-LLM)
|
||||
- [Glaive Function Calling V2 (en&zh)](https://huggingface.co/datasets/glaiveai/glaive-function-calling-v2)
|
||||
- [LIMA (en)](https://huggingface.co/datasets/GAIR/lima)
|
||||
- [Guanaco Dataset (multilingual)](https://huggingface.co/datasets/JosephusCheung/GuanacoDataset)
|
||||
- [BELLE 2M (zh)](https://huggingface.co/datasets/BelleGroup/train_2M_CN)
|
||||
- [BELLE 1M (zh)](https://huggingface.co/datasets/BelleGroup/train_1M_CN)
|
||||
@@ -223,7 +237,6 @@ You also can add a custom chat template to [template.py](src/llmtuner/data/templ
|
||||
- [BELLE School Math 0.25M (zh)](https://huggingface.co/datasets/BelleGroup/school_math_0.25M)
|
||||
- [BELLE Multiturn Chat 0.8M (zh)](https://huggingface.co/datasets/BelleGroup/multiturn_chat_0.8M)
|
||||
- [UltraChat (en)](https://github.com/thunlp/UltraChat)
|
||||
- [LIMA (en)](https://huggingface.co/datasets/GAIR/lima)
|
||||
- [OpenPlatypus (en)](https://huggingface.co/datasets/garage-bAInd/Open-Platypus)
|
||||
- [CodeAlpaca 20k (en)](https://huggingface.co/datasets/sahil2801/CodeAlpaca-20k)
|
||||
- [Alpaca CoT (multilingual)](https://huggingface.co/datasets/QingyiSi/Alpaca-CoT)
|
||||
@@ -236,15 +249,16 @@ You also can add a custom chat template to [template.py](src/llmtuner/data/templ
|
||||
- [WebNovel (zh)](https://huggingface.co/datasets/zxbsmk/webnovel_cn)
|
||||
- [Nectar (en)](https://huggingface.co/datasets/berkeley-nest/Nectar)
|
||||
- [deepctrl (en&zh)](https://www.modelscope.cn/datasets/deepctrl/deepctrl-sft-data)
|
||||
- [Ad Gen (zh)](https://huggingface.co/datasets/HasturOfficial/adgen)
|
||||
- [Advertise Generating (zh)](https://huggingface.co/datasets/HasturOfficial/adgen)
|
||||
- [ShareGPT Hyperfiltered (en)](https://huggingface.co/datasets/totally-not-an-llm/sharegpt-hyperfiltered-3k)
|
||||
- [ShareGPT4 (en&zh)](https://huggingface.co/datasets/shibing624/sharegpt_gpt4)
|
||||
- [UltraChat 200k (en)](https://huggingface.co/datasets/HuggingFaceH4/ultrachat_200k)
|
||||
- [AgentInstruct (en)](https://huggingface.co/datasets/THUDM/AgentInstruct)
|
||||
- [LMSYS Chat 1M (en)](https://huggingface.co/datasets/lmsys/lmsys-chat-1m)
|
||||
- [Evol Instruct V2 (en)](https://huggingface.co/datasets/WizardLM/WizardLM_evol_instruct_V2_196k)
|
||||
- [Glaive Function Calling V2 (en)](https://huggingface.co/datasets/glaiveai/glaive-function-calling-v2)
|
||||
- [Cosmopedia (en)](https://huggingface.co/datasets/HuggingFaceTB/cosmopedia)
|
||||
- [STEM (zh)](https://huggingface.co/datasets/hfl/stem_zh_instruction)
|
||||
- [Ruozhiba (zh)](https://huggingface.co/datasets/hfl/ruozhiba_gpt4_turbo)
|
||||
- [LLaVA mixed (en&zh)](https://huggingface.co/datasets/BUAADreamer/llava-en-zh-300k)
|
||||
- [Open Assistant (de)](https://huggingface.co/datasets/mayflowergmbh/oasst_de)
|
||||
- [Dolly 15k (de)](https://huggingface.co/datasets/mayflowergmbh/dolly-15k_de)
|
||||
@@ -260,13 +274,13 @@ You also can add a custom chat template to [template.py](src/llmtuner/data/templ
|
||||
|
||||
<details><summary>Preference datasets</summary>
|
||||
|
||||
- [HH-RLHF (en)](https://huggingface.co/datasets/Anthropic/hh-rlhf)
|
||||
- [GPT-4 Generated Data (en&zh)](https://github.com/Instruction-Tuning-with-GPT-4/GPT-4-LLM)
|
||||
- [Orca DPO (en)](https://huggingface.co/datasets/Intel/orca_dpo_pairs)
|
||||
- [Nectar (en)](https://huggingface.co/datasets/berkeley-nest/Nectar)
|
||||
- [DPO mixed (en&zh)](https://huggingface.co/datasets/hiyouga/DPO-En-Zh-20k)
|
||||
- [Open Assistant (zh)](https://huggingface.co/datasets/OpenAssistant/oasst1)
|
||||
- [UltraFeedback (en)](https://huggingface.co/datasets/HuggingFaceH4/ultrafeedback_binarized)
|
||||
- [Orca DPO Pairs (en)](https://huggingface.co/datasets/Intel/orca_dpo_pairs)
|
||||
- [HH-RLHF (en)](https://huggingface.co/datasets/Anthropic/hh-rlhf)
|
||||
- [Nectar (en)](https://huggingface.co/datasets/berkeley-nest/Nectar)
|
||||
- [Orca DPO (de)](https://huggingface.co/datasets/mayflowergmbh/intel_orca_dpo_pairs_de)
|
||||
- [KTO mixed (en)](https://huggingface.co/datasets/argilla/kto-mix-15k)
|
||||
|
||||
</details>
|
||||
|
||||
@@ -281,21 +295,21 @@ huggingface-cli login
|
||||
|
||||
| Mandatory | Minimum | Recommend |
|
||||
| ------------ | ------- | --------- |
|
||||
| python | 3.8 | 3.10 |
|
||||
| torch | 1.13.1 | 2.2.0 |
|
||||
| transformers | 4.37.2 | 4.40.1 |
|
||||
| datasets | 2.14.3 | 2.19.1 |
|
||||
| accelerate | 0.27.2 | 0.30.0 |
|
||||
| peft | 0.9.0 | 0.10.0 |
|
||||
| trl | 0.8.1 | 0.8.6 |
|
||||
| python | 3.8 | 3.11 |
|
||||
| torch | 1.13.1 | 2.3.0 |
|
||||
| transformers | 4.41.2 | 4.41.2 |
|
||||
| datasets | 2.16.0 | 2.19.2 |
|
||||
| accelerate | 0.30.1 | 0.30.1 |
|
||||
| peft | 0.11.1 | 0.11.1 |
|
||||
| trl | 0.8.6 | 0.9.4 |
|
||||
|
||||
| Optional | Minimum | Recommend |
|
||||
| ------------ | ------- | --------- |
|
||||
| CUDA | 11.6 | 12.2 |
|
||||
| deepspeed | 0.10.0 | 0.14.0 |
|
||||
| bitsandbytes | 0.39.0 | 0.43.1 |
|
||||
| vllm | 0.4.0 | 0.4.2 |
|
||||
| flash-attn | 2.3.0 | 2.5.8 |
|
||||
| vllm | 0.4.3 | 0.4.3 |
|
||||
| flash-attn | 2.3.0 | 2.5.9 |
|
||||
|
||||
### Hardware Requirement
|
||||
|
||||
@@ -319,12 +333,12 @@ huggingface-cli login
|
||||
> Installation is mandatory.
|
||||
|
||||
```bash
|
||||
git clone https://github.com/hiyouga/LLaMA-Factory.git
|
||||
git clone --depth 1 https://github.com/hiyouga/LLaMA-Factory.git
|
||||
cd LLaMA-Factory
|
||||
pip install -e .[torch,metrics]
|
||||
pip install -e '.[torch,metrics]'
|
||||
```
|
||||
|
||||
Extra dependencies available: torch, metrics, deepspeed, bitsandbytes, vllm, galore, badam, gptq, awq, aqlm, qwen, modelscope, quality
|
||||
Extra dependencies available: torch, torch_npu, metrics, deepspeed, bitsandbytes, vllm, galore, badam, gptq, awq, aqlm, qwen, modelscope, quality
|
||||
|
||||
> [!TIP]
|
||||
> Use `pip install --no-deps -e .` to resolve package conflicts.
|
||||
@@ -343,19 +357,35 @@ To enable FlashAttention-2 on the Windows platform, you need to install the prec
|
||||
|
||||
<details><summary>For Ascend NPU users</summary>
|
||||
|
||||
To utilize Ascend NPU devices for (distributed) training and inference, you need to install the **[torch-npu](https://gitee.com/ascend/pytorch)** library and the **[Ascend CANN Kernels](https://www.hiascend.com/developer/download/community/result?module=cann)**.
|
||||
Join [NPU user group](assets/wechat_npu.jpg).
|
||||
|
||||
To install LLaMA Factory on Ascend NPU devices, please specify extra dependencies: `pip install -e '.[torch-npu,metrics]'`. Additionally, you need to install the **[Ascend CANN Toolkit and Kernels](https://www.hiascend.com/developer/download/community/result?module=cann)**. Please follow the [installation tutorial](https://www.hiascend.com/document/detail/en/CANNCommunityEdition/600alphaX/softwareinstall/instg/atlasdeploy_03_0031.html) or use the following commands:
|
||||
|
||||
```bash
|
||||
# replace the url according to your CANN version and devices
|
||||
# install CANN Toolkit
|
||||
wget https://ascend-repo.obs.cn-east-2.myhuaweicloud.com/Milan-ASL/Milan-ASL%20V100R001C17SPC701/Ascend-cann-toolkit_8.0.RC1.alpha001_linux-"$(uname -i)".run
|
||||
bash Ascend-cann-toolkit_8.0.RC1.alpha001_linux-"$(uname -i)".run --install
|
||||
|
||||
# install CANN Kernels
|
||||
wget https://ascend-repo.obs.cn-east-2.myhuaweicloud.com/Milan-ASL/Milan-ASL%20V100R001C17SPC701/Ascend-cann-kernels-910b_8.0.RC1.alpha001_linux.run
|
||||
bash Ascend-cann-kernels-910b_8.0.RC1.alpha001_linux.run --install
|
||||
|
||||
# set env variables
|
||||
source /usr/local/Ascend/ascend-toolkit/set_env.sh
|
||||
```
|
||||
|
||||
| Requirement | Minimum | Recommend |
|
||||
| ------------ | ------- | --------- |
|
||||
| ------------ | ------- | ----------- |
|
||||
| CANN | 8.0.RC1 | 8.0.RC1 |
|
||||
| torch | 2.2.0 | 2.2.0 |
|
||||
| torch-npu | 2.2.0 | 2.2.0 |
|
||||
| torch | 2.1.0 | 2.1.0 |
|
||||
| torch-npu | 2.1.0 | 2.1.0.post3 |
|
||||
| deepspeed | 0.13.2 | 0.13.2 |
|
||||
|
||||
Docker image:
|
||||
|
||||
- 32GB: [Download page](http://mirrors.cn-central-221.ovaijisuan.com/detail/130.html)
|
||||
- 64GB: Coming soon
|
||||
- 64GB: [Download page](http://mirrors.cn-central-221.ovaijisuan.com/detail/131.html)
|
||||
|
||||
Remember to use `ASCEND_RT_VISIBLE_DEVICES` instead of `CUDA_VISIBLE_DEVICES` to specify the device to use.
|
||||
|
||||
@@ -387,29 +417,12 @@ See [examples/README.md](examples/README.md) for advanced usage (including distr
|
||||
|
||||
### Fine-Tuning with LLaMA Board GUI (powered by [Gradio](https://github.com/gradio-app/gradio))
|
||||
|
||||
> [!IMPORTANT]
|
||||
> LLaMA Board GUI only supports training on a single GPU.
|
||||
|
||||
#### Use local environment
|
||||
|
||||
```bash
|
||||
CUDA_VISIBLE_DEVICES=0 GRADIO_SHARE=1 llamafactory-cli webui
|
||||
```
|
||||
|
||||
<details><summary>For Alibaba Cloud PAI or AutoDL users</summary>
|
||||
|
||||
If you encountered display problems in LLaMA Board on Alibaba Cloud PAI, try using the following command to set environment variables before starting LLaMA Board:
|
||||
|
||||
```bash
|
||||
export GRADIO_SERVER_PORT=7860 GRADIO_ROOT_PATH=/${JUPYTER_NAME}/proxy/7860/
|
||||
```
|
||||
|
||||
If you are using AutoDL, please install a specific version of Gradio:
|
||||
|
||||
```bash
|
||||
pip install gradio==4.10.0
|
||||
```
|
||||
|
||||
</details>
|
||||
|
||||
#### Use Docker
|
||||
@@ -420,7 +433,6 @@ docker run --gpus=all \
|
||||
-v ./hf_cache:/root/.cache/huggingface/ \
|
||||
-v ./data:/app/data \
|
||||
-v ./output:/app/output \
|
||||
-e CUDA_VISIBLE_DEVICES=0 \
|
||||
-p 7860:7860 \
|
||||
--shm-size 16G \
|
||||
--name llama_factory \
|
||||
@@ -447,6 +459,9 @@ docker compose -f ./docker-compose.yml up -d
|
||||
CUDA_VISIBLE_DEVICES=0,1 API_PORT=8000 llamafactory-cli api examples/inference/llama3_vllm.yaml
|
||||
```
|
||||
|
||||
> [!TIP]
|
||||
> Visit https://platform.openai.com/docs/api-reference/chat/create for API document.
|
||||
|
||||
### Download from ModelScope Hub
|
||||
|
||||
If you have trouble with downloading models and datasets from Hugging Face, you can use ModelScope.
|
||||
@@ -455,7 +470,18 @@ If you have trouble with downloading models and datasets from Hugging Face, you
|
||||
export USE_MODELSCOPE_HUB=1 # `set USE_MODELSCOPE_HUB=1` for Windows
|
||||
```
|
||||
|
||||
Train the model by specifying a model ID of the ModelScope Hub as the `--model_name_or_path`. You can find a full list of model IDs at [ModelScope Hub](https://modelscope.cn/models), e.g., `LLM-Research/Meta-Llama-3-8B-Instruct`.
|
||||
Train the model by specifying a model ID of the ModelScope Hub as the `model_name_or_path`. You can find a full list of model IDs at [ModelScope Hub](https://modelscope.cn/models), e.g., `LLM-Research/Meta-Llama-3-8B-Instruct`.
|
||||
|
||||
### Use W&B Logger
|
||||
|
||||
To use [Weights & Biases](https://wandb.ai) for logging experimental results, you need to add the following arguments.
|
||||
|
||||
```yaml
|
||||
report_to: wandb
|
||||
run_name: test_run # optional
|
||||
```
|
||||
|
||||
Set `WANDB_API_KEY` to [your key](https://wandb.ai/authorize) when launching training tasks to log in with your W&B account.
|
||||
|
||||
## Projects using LLaMA Factory
|
||||
|
||||
@@ -502,7 +528,7 @@ If you have a project that should be incorporated, please contact via email or c
|
||||
1. Zhou et al. FREB-TQA: A Fine-Grained Robustness Evaluation Benchmark for Table Question Answering. 2024. [[arxiv]](https://arxiv.org/abs/2404.18585)
|
||||
1. **[StarWhisper](https://github.com/Yu-Yang-Li/StarWhisper)**: A large language model for Astronomy, based on ChatGLM2-6B and Qwen-14B.
|
||||
1. **[DISC-LawLLM](https://github.com/FudanDISC/DISC-LawLLM)**: A large language model specialized in Chinese legal domain, based on Baichuan-13B, is capable of retrieving and reasoning on legal knowledge.
|
||||
1. **[Sunsimiao](https://github.com/thomas-yanxin/Sunsimiao)**: A large language model specialized in Chinese medical domain, based on Baichuan-7B and ChatGLM-6B.
|
||||
1. **[Sunsimiao](https://github.com/X-D-Lab/Sunsimiao)**: A large language model specialized in Chinese medical domain, based on Baichuan-7B and ChatGLM-6B.
|
||||
1. **[CareGPT](https://github.com/WangRongsheng/CareGPT)**: A series of large language models for Chinese medical domain, based on LLaMA2-7B and Baichuan-13B.
|
||||
1. **[MachineMindset](https://github.com/PKU-YuanGroup/Machine-Mindset/)**: A series of MBTI Personality large language models, capable of giving any LLM 16 different personality types based on different datasets and training methods.
|
||||
1. **[Luminia-13B-v3](https://huggingface.co/Nekochu/Luminia-13B-v3)**: A large language model specialized in generate metadata for stable diffusion. [[🤗Demo]](https://huggingface.co/spaces/Nekochu/Luminia-13B_SD_Prompt)
|
||||
@@ -514,7 +540,7 @@ If you have a project that should be incorporated, please contact via email or c
|
||||
|
||||
This repository is licensed under the [Apache-2.0 License](LICENSE).
|
||||
|
||||
Please follow the model licenses to use the corresponding model weights: [Baichuan2](https://huggingface.co/baichuan-inc/Baichuan2-7B-Base/blob/main/Community%20License%20for%20Baichuan%202%20Model.pdf) / [BLOOM](https://huggingface.co/spaces/bigscience/license) / [ChatGLM3](https://github.com/THUDM/ChatGLM3/blob/main/MODEL_LICENSE) / [Command-R](https://cohere.com/c4ai-cc-by-nc-license) / [DeepSeek](https://github.com/deepseek-ai/DeepSeek-LLM/blob/main/LICENSE-MODEL) / [Falcon](https://huggingface.co/tiiuae/falcon-180B/blob/main/LICENSE.txt) / [Gemma](https://ai.google.dev/gemma/terms) / [InternLM2](https://github.com/InternLM/InternLM#license) / [LLaMA](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) / [LLaMA-2 (LLaVA-1.5)](https://ai.meta.com/llama/license/) / [LLaMA-3](https://llama.meta.com/llama3/license/) / [Mistral](LICENSE) / [OLMo](LICENSE) / [Phi-1.5/2](https://huggingface.co/microsoft/phi-1_5/resolve/main/Research%20License.docx) / [Phi-3](https://huggingface.co/microsoft/Phi-3-mini-4k-instruct/blob/main/LICENSE) / [Qwen](https://github.com/QwenLM/Qwen/blob/main/Tongyi%20Qianwen%20LICENSE%20AGREEMENT) / [StarCoder2](https://huggingface.co/spaces/bigcode/bigcode-model-license-agreement) / [XVERSE](https://github.com/xverse-ai/XVERSE-13B/blob/main/MODEL_LICENSE.pdf) / [Yi](https://huggingface.co/01-ai/Yi-6B/blob/main/LICENSE) / [Yi-1.5](LICENSE) / [Yuan](https://github.com/IEIT-Yuan/Yuan-2.0/blob/main/LICENSE-Yuan)
|
||||
Please follow the model licenses to use the corresponding model weights: [Baichuan2](https://huggingface.co/baichuan-inc/Baichuan2-7B-Base/blob/main/Community%20License%20for%20Baichuan%202%20Model.pdf) / [BLOOM](https://huggingface.co/spaces/bigscience/license) / [ChatGLM3](https://github.com/THUDM/ChatGLM3/blob/main/MODEL_LICENSE) / [Command-R](https://cohere.com/c4ai-cc-by-nc-license) / [DeepSeek](https://github.com/deepseek-ai/DeepSeek-LLM/blob/main/LICENSE-MODEL) / [Falcon](https://huggingface.co/tiiuae/falcon-180B/blob/main/LICENSE.txt) / [Gemma](https://ai.google.dev/gemma/terms) / [GLM4](https://huggingface.co/THUDM/glm-4-9b/blob/main/LICENSE) / [InternLM2](https://github.com/InternLM/InternLM#license) / [LLaMA](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) / [LLaMA-2 (LLaVA-1.5)](https://ai.meta.com/llama/license/) / [LLaMA-3](https://llama.meta.com/llama3/license/) / [Mistral](LICENSE) / [OLMo](LICENSE) / [Phi-1.5/2](https://huggingface.co/microsoft/phi-1_5/resolve/main/Research%20License.docx) / [Phi-3](https://huggingface.co/microsoft/Phi-3-mini-4k-instruct/blob/main/LICENSE) / [Qwen](https://github.com/QwenLM/Qwen/blob/main/Tongyi%20Qianwen%20LICENSE%20AGREEMENT) / [StarCoder2](https://huggingface.co/spaces/bigcode/bigcode-model-license-agreement) / [XVERSE](https://github.com/xverse-ai/XVERSE-13B/blob/main/MODEL_LICENSE.pdf) / [Yi](https://huggingface.co/01-ai/Yi-6B/blob/main/LICENSE) / [Yi-1.5](LICENSE) / [Yuan](https://github.com/IEIT-Yuan/Yuan-2.0/blob/main/LICENSE-Yuan)
|
||||
|
||||
## Citation
|
||||
|
||||
|
||||
212
README_zh.md
212
README_zh.md
@@ -3,15 +3,15 @@
|
||||
[](https://github.com/hiyouga/LLaMA-Factory/stargazers)
|
||||
[](LICENSE)
|
||||
[](https://github.com/hiyouga/LLaMA-Factory/commits/main)
|
||||
[](https://pypi.org/project/llmtuner/)
|
||||
[](https://pypi.org/project/llmtuner/)
|
||||
[](https://pypi.org/project/llamafactory/)
|
||||
[](#使用了-llama-factory-的项目)
|
||||
[](https://github.com/hiyouga/LLaMA-Factory/pulls)
|
||||
[](https://discord.gg/rKfvV9r9FK)
|
||||
[](https://twitter.com/llamafactory_ai)
|
||||
[](https://colab.research.google.com/drive/1d5KQtbemerlSDSxZIfAaWXhKr30QypiK?usp=sharing)
|
||||
[](https://gallery.pai-ml.com/#/preview/deepLearning/nlp/llama_factory)
|
||||
[](https://huggingface.co/spaces/hiyouga/LLaMA-Board)
|
||||
[](https://modelscope.cn/studios/hiyouga/LLaMA-Board)
|
||||
[](https://colab.research.google.com/drive/1d5KQtbemerlSDSxZIfAaWXhKr30QypiK?usp=sharing)
|
||||
|
||||
[](https://trendshift.io/repositories/4535)
|
||||
|
||||
@@ -26,6 +26,7 @@ https://github.com/hiyouga/LLaMA-Factory/assets/16256802/ec36a9dd-37f4-4f72-81bd
|
||||
选择你的打开方式:
|
||||
|
||||
- **Colab**:https://colab.research.google.com/drive/1d5KQtbemerlSDSxZIfAaWXhKr30QypiK?usp=sharing
|
||||
- **PAI-DSW**: https://gallery.pai-ml.com/#/preview/deepLearning/nlp/llama_factory
|
||||
- **本地机器**:请见[如何使用](#如何使用)
|
||||
|
||||
## 目录
|
||||
@@ -46,7 +47,7 @@ https://github.com/hiyouga/LLaMA-Factory/assets/16256802/ec36a9dd-37f4-4f72-81bd
|
||||
## 项目特色
|
||||
|
||||
- **多种模型**:LLaMA、LLaVA、Mistral、Mixtral-MoE、Qwen、Yi、Gemma、Baichuan、ChatGLM、Phi 等等。
|
||||
- **集成方法**:(增量)预训练、(多模态)指令监督微调、奖励模型训练、PPO 训练、DPO 训练和 ORPO 训练。
|
||||
- **集成方法**:(增量)预训练、(多模态)指令监督微调、奖励模型训练、PPO 训练、DPO 训练、KTO 训练、ORPO 训练等等。
|
||||
- **多种精度**:32 比特全参数微调、16 比特冻结微调、16 比特 LoRA 微调和基于 AQLM/AWQ/GPTQ/LLM.int8 的 2/4/8 比特 QLoRA 微调。
|
||||
- **先进算法**:GaLore、BAdam、DoRA、LongLoRA、LLaMA Pro、Mixture-of-Depths、LoRA+、LoftQ 和 Agent 微调。
|
||||
- **实用技巧**:FlashAttention-2、Unsloth、RoPE scaling、NEFTune 和 rsLoRA。
|
||||
@@ -70,14 +71,22 @@ https://github.com/hiyouga/LLaMA-Factory/assets/16256802/ec36a9dd-37f4-4f72-81bd
|
||||
|
||||
## 更新日志
|
||||
|
||||
[24/05/14] 我们支持了昇腾 NPU 设备的训练和推理。详情请查阅[安装](#安装-llama-factory)部分。
|
||||
[24/06/07] 我们支持了 **[Qwen-2](https://qwenlm.github.io/blog/qwen2/)** 系列模型的微调。
|
||||
|
||||
[24/05/13] 我们支持了 Yi-1.5 系列模型的微调。
|
||||
[24/06/05] 我们支持了 **[GLM-4-9B/GLM-4-9B-Chat](https://github.com/THUDM/GLM-4)** 模型的微调。
|
||||
|
||||
[24/04/26] 我们支持了多模态模型 **LLaVA-1.5** 的微调。详细用法请参照 [examples](examples/README_zh.md)。
|
||||
[24/05/26] 我们支持了 **[SimPO](https://arxiv.org/abs/2405.14734)** 偏好对齐算法。详细用法请参照 [examples](examples/README_zh.md)。
|
||||
|
||||
<details><summary>展开日志</summary>
|
||||
|
||||
[24/05/20] 我们支持了 **PaliGemma** 系列模型的微调。注意 PaliGemma 是预训练模型,你需要使用 `gemma` 模板进行微调使其获得对话能力。
|
||||
|
||||
[24/05/18] 我们支持了 **[KTO](https://arxiv.org/abs/2402.01306)** 偏好对齐算法。详细用法请参照 [examples](examples/README_zh.md)。
|
||||
|
||||
[24/05/14] 我们支持了昇腾 NPU 设备的训练和推理。详情请查阅[安装](#安装-llama-factory)部分。
|
||||
|
||||
[24/04/26] 我们支持了多模态模型 **LLaVA-1.5** 的微调。详细用法请参照 [examples](examples/README_zh.md)。
|
||||
|
||||
[24/04/22] 我们提供了在免费 T4 GPU 上微调 Llama-3 模型的 **[Colab 笔记本](https://colab.research.google.com/drive/1d5KQtbemerlSDSxZIfAaWXhKr30QypiK?usp=sharing)**。Hugging Face 社区公开了两个利用 LLaMA Factory 微调的 Llama-3 模型,详情请见 [Llama3-8B-Chinese-Chat](https://huggingface.co/shenzhi-wang/Llama3-8B-Chinese-Chat) 和 [Llama3-Chinese](https://huggingface.co/zhichen/Llama3-Chinese)。
|
||||
|
||||
[24/04/21] 我们基于 [AstraMindAI 的仓库](https://github.com/astramind-ai/Mixture-of-depths)支持了 **[混合深度训练](https://arxiv.org/abs/2404.02258)**。详细用法请参照 [examples](examples/README_zh.md)。
|
||||
@@ -104,7 +113,7 @@ https://github.com/hiyouga/LLaMA-Factory/assets/16256802/ec36a9dd-37f4-4f72-81bd
|
||||
|
||||
[24/02/05] Qwen1.5(Qwen2 测试版)系列模型已在 LLaMA-Factory 中实现微调支持。详情请查阅该[博客页面](https://qwenlm.github.io/zh/blog/qwen1.5/)。
|
||||
|
||||
[24/01/18] 我们针对绝大多数模型实现了 **Agent 微调**,微调时指定 `dataset: glaive_toolcall` 即可使模型获得工具调用能力。
|
||||
[24/01/18] 我们针对绝大多数模型实现了 **Agent 微调**,微调时指定 `dataset: glaive_toolcall_zh` 即可使模型获得工具调用能力。
|
||||
|
||||
[23/12/23] 我们针对 LLaMA, Mistral 和 Yi 模型支持了 **[unsloth](https://github.com/unslothai/unsloth)** 的 LoRA 训练加速。请使用 `use_unsloth: true` 参数启用 unsloth 优化。该方法可提供 **170%** 的训练速度,详情请查阅[此页面](https://github.com/hiyouga/LLaMA-Factory/wiki/Performance-comparison)。
|
||||
|
||||
@@ -142,43 +151,44 @@ https://github.com/hiyouga/LLaMA-Factory/assets/16256802/ec36a9dd-37f4-4f72-81bd
|
||||
|
||||
## 模型
|
||||
|
||||
| 模型名 | 模型大小 | 默认模块 | Template |
|
||||
| -------------------------------------------------------- | -------------------------------- | ----------------- | --------- |
|
||||
| [Baichuan2](https://huggingface.co/baichuan-inc) | 7B/13B | W_pack | baichuan2 |
|
||||
| [BLOOM](https://huggingface.co/bigscience) | 560M/1.1B/1.7B/3B/7.1B/176B | query_key_value | - |
|
||||
| [BLOOMZ](https://huggingface.co/bigscience) | 560M/1.1B/1.7B/3B/7.1B/176B | query_key_value | - |
|
||||
| [ChatGLM3](https://huggingface.co/THUDM) | 6B | query_key_value | chatglm3 |
|
||||
| [Command-R](https://huggingface.co/CohereForAI) | 35B/104B | q_proj,v_proj | cohere |
|
||||
| [DeepSeek (MoE)](https://huggingface.co/deepseek-ai) | 7B/16B/67B/236B | q_proj,v_proj | deepseek |
|
||||
| [Falcon](https://huggingface.co/tiiuae) | 7B/40B/180B | query_key_value | falcon |
|
||||
| [Gemma/CodeGemma](https://huggingface.co/google) | 2B/7B | q_proj,v_proj | gemma |
|
||||
| [InternLM2](https://huggingface.co/internlm) | 7B/20B | wqkv | intern2 |
|
||||
| [LLaMA](https://github.com/facebookresearch/llama) | 7B/13B/33B/65B | q_proj,v_proj | - |
|
||||
| [LLaMA-2](https://huggingface.co/meta-llama) | 7B/13B/70B | q_proj,v_proj | llama2 |
|
||||
| [LLaMA-3](https://huggingface.co/meta-llama) | 8B/70B | q_proj,v_proj | llama3 |
|
||||
| [LLaVA-1.5](https://huggingface.co/llava-hf) | 7B/13B | q_proj,v_proj | vicuna |
|
||||
| [Mistral/Mixtral](https://huggingface.co/mistralai) | 7B/8x7B/8x22B | q_proj,v_proj | mistral |
|
||||
| [OLMo](https://huggingface.co/allenai) | 1B/7B | q_proj,v_proj | - |
|
||||
| [Phi-1.5/2](https://huggingface.co/microsoft) | 1.3B/2.7B | q_proj,v_proj | - |
|
||||
| [Phi-3](https://huggingface.co/microsoft) | 3.8B | qkv_proj | phi |
|
||||
| [Qwen](https://huggingface.co/Qwen) | 1.8B/7B/14B/72B | c_attn | qwen |
|
||||
| [Qwen1.5 (Code/MoE)](https://huggingface.co/Qwen) | 0.5B/1.8B/4B/7B/14B/32B/72B/110B | q_proj,v_proj | qwen |
|
||||
| [StarCoder2](https://huggingface.co/bigcode) | 3B/7B/15B | q_proj,v_proj | - |
|
||||
| [XVERSE](https://huggingface.co/xverse) | 7B/13B/65B | q_proj,v_proj | xverse |
|
||||
| [Yi (1/1.5)](https://huggingface.co/01-ai) | 6B/9B/34B | q_proj,v_proj | yi |
|
||||
| [Yi-VL](https://huggingface.co/01-ai) | 6B/34B | q_proj,v_proj | yi_vl |
|
||||
| [Yuan](https://huggingface.co/IEITYuan) | 2B/51B/102B | q_proj,v_proj | yuan |
|
||||
| 模型名 | 模型大小 | Template |
|
||||
| -------------------------------------------------------- | -------------------------------- | --------- |
|
||||
| [Baichuan2](https://huggingface.co/baichuan-inc) | 7B/13B | baichuan2 |
|
||||
| [BLOOM](https://huggingface.co/bigscience) | 560M/1.1B/1.7B/3B/7.1B/176B | - |
|
||||
| [BLOOMZ](https://huggingface.co/bigscience) | 560M/1.1B/1.7B/3B/7.1B/176B | - |
|
||||
| [ChatGLM3](https://huggingface.co/THUDM) | 6B | chatglm3 |
|
||||
| [Command-R](https://huggingface.co/CohereForAI) | 35B/104B | cohere |
|
||||
| [DeepSeek (MoE)](https://huggingface.co/deepseek-ai) | 7B/16B/67B/236B | deepseek |
|
||||
| [Falcon](https://huggingface.co/tiiuae) | 7B/11B/40B/180B | falcon |
|
||||
| [Gemma/CodeGemma](https://huggingface.co/google) | 2B/7B | gemma |
|
||||
| [GLM4](https://huggingface.co/THUDM) | 9B | glm4 |
|
||||
| [InternLM2](https://huggingface.co/internlm) | 7B/20B | intern2 |
|
||||
| [LLaMA](https://github.com/facebookresearch/llama) | 7B/13B/33B/65B | - |
|
||||
| [LLaMA-2](https://huggingface.co/meta-llama) | 7B/13B/70B | llama2 |
|
||||
| [LLaMA-3](https://huggingface.co/meta-llama) | 8B/70B | llama3 |
|
||||
| [LLaVA-1.5](https://huggingface.co/llava-hf) | 7B/13B | vicuna |
|
||||
| [Mistral/Mixtral](https://huggingface.co/mistralai) | 7B/8x7B/8x22B | mistral |
|
||||
| [OLMo](https://huggingface.co/allenai) | 1B/7B | - |
|
||||
| [PaliGemma](https://huggingface.co/google) | 3B | gemma |
|
||||
| [Phi-1.5/2](https://huggingface.co/microsoft) | 1.3B/2.7B | - |
|
||||
| [Phi-3](https://huggingface.co/microsoft) | 4B/7B/14B | phi |
|
||||
| [Qwen](https://huggingface.co/Qwen) | 1.8B/7B/14B/72B | qwen |
|
||||
| [Qwen1.5 (Code/MoE)](https://huggingface.co/Qwen) | 0.5B/1.8B/4B/7B/14B/32B/72B/110B | qwen |
|
||||
| [Qwen2 (MoE)](https://huggingface.co/Qwen) | 0.5B/1.5B/7B/57B/72B | qwen |
|
||||
| [StarCoder2](https://huggingface.co/bigcode) | 3B/7B/15B | - |
|
||||
| [XVERSE](https://huggingface.co/xverse) | 7B/13B/65B | xverse |
|
||||
| [Yi (1/1.5)](https://huggingface.co/01-ai) | 6B/9B/34B | yi |
|
||||
| [Yi-VL](https://huggingface.co/01-ai) | 6B/34B | yi_vl |
|
||||
| [Yuan](https://huggingface.co/IEITYuan) | 2B/51B/102B | yuan |
|
||||
|
||||
> [!NOTE]
|
||||
> **默认模块**应作为 `--lora_target` 参数的默认值,可使用 `--lora_target all` 参数指定全部模块以取得更好的效果。
|
||||
> 对于所有“基座”(Base)模型,`template` 参数可以是 `default`, `alpaca`, `vicuna` 等任意值。但“对话”(Instruct/Chat)模型请务必使用**对应的模板**。
|
||||
>
|
||||
> 对于所有“基座”(Base)模型,`--template` 参数可以是 `default`, `alpaca`, `vicuna` 等任意值。但“对话”(Instruct/Chat)模型请务必使用**对应的模板**。
|
||||
>
|
||||
> 请务必在训练和推理时使用**完全一致**的模板。
|
||||
> 请务必在训练和推理时采用**完全一致**的模板。
|
||||
|
||||
项目所支持模型的完整列表请参阅 [constants.py](src/llmtuner/extras/constants.py)。
|
||||
项目所支持模型的完整列表请参阅 [constants.py](src/llamafactory/extras/constants.py)。
|
||||
|
||||
您也可以在 [template.py](src/llmtuner/data/template.py) 中添加自己的对话模板。
|
||||
您也可以在 [template.py](src/llamafactory/data/template.py) 中添加自己的对话模板。
|
||||
|
||||
## 训练方法
|
||||
|
||||
@@ -189,7 +199,9 @@ https://github.com/hiyouga/LLaMA-Factory/assets/16256802/ec36a9dd-37f4-4f72-81bd
|
||||
| 奖励模型训练 | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
|
||||
| PPO 训练 | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
|
||||
| DPO 训练 | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
|
||||
| KTO 训练 | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
|
||||
| ORPO 训练 | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
|
||||
| SimPO 训练 | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
|
||||
|
||||
## 数据集
|
||||
|
||||
@@ -202,6 +214,8 @@ https://github.com/hiyouga/LLaMA-Factory/assets/16256802/ec36a9dd-37f4-4f72-81bd
|
||||
- [Wikipedia (zh)](https://huggingface.co/datasets/pleisto/wikipedia-cn-20230720-filtered)
|
||||
- [Pile (en)](https://huggingface.co/datasets/EleutherAI/pile)
|
||||
- [SkyPile (zh)](https://huggingface.co/datasets/Skywork/SkyPile-150B)
|
||||
- [FineWeb (en)](https://huggingface.co/datasets/HuggingFaceFW/fineweb)
|
||||
- [FineWeb-Edu (en)](https://huggingface.co/datasets/HuggingFaceFW/fineweb-edu)
|
||||
- [The Stack (en)](https://huggingface.co/datasets/bigcode/the-stack)
|
||||
- [StarCoder (en)](https://huggingface.co/datasets/bigcode/starcoderdata)
|
||||
|
||||
@@ -209,12 +223,12 @@ https://github.com/hiyouga/LLaMA-Factory/assets/16256802/ec36a9dd-37f4-4f72-81bd
|
||||
|
||||
<details><summary>指令微调数据集</summary>
|
||||
|
||||
- [Stanford Alpaca (en)](https://github.com/tatsu-lab/stanford_alpaca)
|
||||
- [Stanford Alpaca (zh)](https://github.com/ymcui/Chinese-LLaMA-Alpaca)
|
||||
- [Alpaca GPT4 (en&zh)](https://github.com/Instruction-Tuning-with-GPT-4/GPT-4-LLM)
|
||||
- [Identity (en&zh)](data/identity.json)
|
||||
- [Open Assistant (zh)](https://huggingface.co/datasets/OpenAssistant/oasst1)
|
||||
- [ShareGPT (zh)](https://huggingface.co/datasets/QingyiSi/Alpaca-CoT/tree/main/Chinese-instruction-collection)
|
||||
- [Stanford Alpaca (en)](https://github.com/tatsu-lab/stanford_alpaca)
|
||||
- [Stanford Alpaca (zh)](https://github.com/ymcui/Chinese-LLaMA-Alpaca-3)
|
||||
- [Alpaca GPT4 (en&zh)](https://github.com/Instruction-Tuning-with-GPT-4/GPT-4-LLM)
|
||||
- [Glaive Function Calling V2 (en&zh)](https://huggingface.co/datasets/glaiveai/glaive-function-calling-v2)
|
||||
- [LIMA (en)](https://huggingface.co/datasets/GAIR/lima)
|
||||
- [Guanaco Dataset (multilingual)](https://huggingface.co/datasets/JosephusCheung/GuanacoDataset)
|
||||
- [BELLE 2M (zh)](https://huggingface.co/datasets/BelleGroup/train_2M_CN)
|
||||
- [BELLE 1M (zh)](https://huggingface.co/datasets/BelleGroup/train_1M_CN)
|
||||
@@ -223,7 +237,6 @@ https://github.com/hiyouga/LLaMA-Factory/assets/16256802/ec36a9dd-37f4-4f72-81bd
|
||||
- [BELLE School Math 0.25M (zh)](https://huggingface.co/datasets/BelleGroup/school_math_0.25M)
|
||||
- [BELLE Multiturn Chat 0.8M (zh)](https://huggingface.co/datasets/BelleGroup/multiturn_chat_0.8M)
|
||||
- [UltraChat (en)](https://github.com/thunlp/UltraChat)
|
||||
- [LIMA (en)](https://huggingface.co/datasets/GAIR/lima)
|
||||
- [OpenPlatypus (en)](https://huggingface.co/datasets/garage-bAInd/Open-Platypus)
|
||||
- [CodeAlpaca 20k (en)](https://huggingface.co/datasets/sahil2801/CodeAlpaca-20k)
|
||||
- [Alpaca CoT (multilingual)](https://huggingface.co/datasets/QingyiSi/Alpaca-CoT)
|
||||
@@ -236,15 +249,16 @@ https://github.com/hiyouga/LLaMA-Factory/assets/16256802/ec36a9dd-37f4-4f72-81bd
|
||||
- [WebNovel (zh)](https://huggingface.co/datasets/zxbsmk/webnovel_cn)
|
||||
- [Nectar (en)](https://huggingface.co/datasets/berkeley-nest/Nectar)
|
||||
- [deepctrl (en&zh)](https://www.modelscope.cn/datasets/deepctrl/deepctrl-sft-data)
|
||||
- [Ad Gen (zh)](https://huggingface.co/datasets/HasturOfficial/adgen)
|
||||
- [Advertise Generating (zh)](https://huggingface.co/datasets/HasturOfficial/adgen)
|
||||
- [ShareGPT Hyperfiltered (en)](https://huggingface.co/datasets/totally-not-an-llm/sharegpt-hyperfiltered-3k)
|
||||
- [ShareGPT4 (en&zh)](https://huggingface.co/datasets/shibing624/sharegpt_gpt4)
|
||||
- [UltraChat 200k (en)](https://huggingface.co/datasets/HuggingFaceH4/ultrachat_200k)
|
||||
- [AgentInstruct (en)](https://huggingface.co/datasets/THUDM/AgentInstruct)
|
||||
- [LMSYS Chat 1M (en)](https://huggingface.co/datasets/lmsys/lmsys-chat-1m)
|
||||
- [Evol Instruct V2 (en)](https://huggingface.co/datasets/WizardLM/WizardLM_evol_instruct_V2_196k)
|
||||
- [Glaive Function Calling V2 (en)](https://huggingface.co/datasets/glaiveai/glaive-function-calling-v2)
|
||||
- [Cosmopedia (en)](https://huggingface.co/datasets/HuggingFaceTB/cosmopedia)
|
||||
- [STEM (zh)](https://huggingface.co/datasets/hfl/stem_zh_instruction)
|
||||
- [Ruozhiba (zh)](https://huggingface.co/datasets/hfl/ruozhiba_gpt4_turbo)
|
||||
- [LLaVA mixed (en&zh)](https://huggingface.co/datasets/BUAADreamer/llava-en-zh-300k)
|
||||
- [Open Assistant (de)](https://huggingface.co/datasets/mayflowergmbh/oasst_de)
|
||||
- [Dolly 15k (de)](https://huggingface.co/datasets/mayflowergmbh/dolly-15k_de)
|
||||
@@ -260,13 +274,13 @@ https://github.com/hiyouga/LLaMA-Factory/assets/16256802/ec36a9dd-37f4-4f72-81bd
|
||||
|
||||
<details><summary>偏好数据集</summary>
|
||||
|
||||
- [HH-RLHF (en)](https://huggingface.co/datasets/Anthropic/hh-rlhf)
|
||||
- [GPT-4 Generated Data (en&zh)](https://github.com/Instruction-Tuning-with-GPT-4/GPT-4-LLM)
|
||||
- [Orca DPO (en)](https://huggingface.co/datasets/Intel/orca_dpo_pairs)
|
||||
- [Nectar (en)](https://huggingface.co/datasets/berkeley-nest/Nectar)
|
||||
- [DPO mixed (en&zh)](https://huggingface.co/datasets/hiyouga/DPO-En-Zh-20k)
|
||||
- [Open Assistant (zh)](https://huggingface.co/datasets/OpenAssistant/oasst1)
|
||||
- [UltraFeedback (en)](https://huggingface.co/datasets/HuggingFaceH4/ultrafeedback_binarized)
|
||||
- [Orca DPO Pairs (en)](https://huggingface.co/datasets/Intel/orca_dpo_pairs)
|
||||
- [HH-RLHF (en)](https://huggingface.co/datasets/Anthropic/hh-rlhf)
|
||||
- [Nectar (en)](https://huggingface.co/datasets/berkeley-nest/Nectar)
|
||||
- [Orca DPO (de)](https://huggingface.co/datasets/mayflowergmbh/intel_orca_dpo_pairs_de)
|
||||
- [KTO mixed (en)](https://huggingface.co/datasets/argilla/kto-mix-15k)
|
||||
|
||||
</details>
|
||||
|
||||
@@ -281,21 +295,21 @@ huggingface-cli login
|
||||
|
||||
| 必需项 | 至少 | 推荐 |
|
||||
| ------------ | ------- | --------- |
|
||||
| python | 3.8 | 3.10 |
|
||||
| torch | 1.13.1 | 2.2.0 |
|
||||
| transformers | 4.37.2 | 4.40.1 |
|
||||
| datasets | 2.14.3 | 2.19.1 |
|
||||
| accelerate | 0.27.2 | 0.30.0 |
|
||||
| peft | 0.9.0 | 0.10.0 |
|
||||
| trl | 0.8.1 | 0.8.6 |
|
||||
| python | 3.8 | 3.11 |
|
||||
| torch | 1.13.1 | 2.3.0 |
|
||||
| transformers | 4.41.2 | 4.41.2 |
|
||||
| datasets | 2.16.0 | 2.19.2 |
|
||||
| accelerate | 0.30.1 | 0.30.1 |
|
||||
| peft | 0.11.1 | 0.11.1 |
|
||||
| trl | 0.8.6 | 0.9.4 |
|
||||
|
||||
| 可选项 | 至少 | 推荐 |
|
||||
| ------------ | ------- | --------- |
|
||||
| CUDA | 11.6 | 12.2 |
|
||||
| deepspeed | 0.10.0 | 0.14.0 |
|
||||
| bitsandbytes | 0.39.0 | 0.43.1 |
|
||||
| vllm | 0.4.0 | 0.4.2 |
|
||||
| flash-attn | 2.3.0 | 2.5.8 |
|
||||
| vllm | 0.4.3 | 0.4.3 |
|
||||
| flash-attn | 2.3.0 | 2.5.9 |
|
||||
|
||||
### 硬件依赖
|
||||
|
||||
@@ -319,12 +333,12 @@ huggingface-cli login
|
||||
> 此步骤为必需。
|
||||
|
||||
```bash
|
||||
git clone https://github.com/hiyouga/LLaMA-Factory.git
|
||||
git clone --depth 1 https://github.com/hiyouga/LLaMA-Factory.git
|
||||
cd LLaMA-Factory
|
||||
pip install -e .[torch,metrics]
|
||||
pip install -e '.[torch,metrics]'
|
||||
```
|
||||
|
||||
可选的额外依赖项:torch、metrics、deepspeed、bitsandbytes、vllm、galore、badam、gptq、awq、aqlm、qwen、modelscope、quality
|
||||
可选的额外依赖项:torch、torch_npu、metrics、deepspeed、bitsandbytes、vllm、galore、badam、gptq、awq、aqlm、qwen、modelscope、quality
|
||||
|
||||
> [!TIP]
|
||||
> 遇到包冲突时,可使用 `pip install --no-deps -e .` 解决。
|
||||
@@ -343,21 +357,37 @@ pip install https://github.com/jllllll/bitsandbytes-windows-webui/releases/downl
|
||||
|
||||
<details><summary>昇腾 NPU 用户指南</summary>
|
||||
|
||||
如果使用昇腾 NPU 设备进行(分布式)训练或推理,需要安装 **[torch-npu](https://gitee.com/ascend/pytorch)** 库和 **[Ascend CANN Kernels](https://www.hiascend.com/developer/download/community/result?module=cann)**。
|
||||
加入 [NPU 用户群](assets/wechat_npu.jpg)。
|
||||
|
||||
在昇腾 NPU 设备上安装 LLaMA Factory 时,需要指定额外依赖项,使用 `pip install -e '.[torch-npu,metrics]'` 命令安装。此外,还需要安装 **[Ascend CANN Toolkit and Kernels](https://www.hiascend.com/developer/download/community/result?module=cann)**,安装方法请参考[安装教程](https://www.hiascend.com/document/detail/zh/CANNCommunityEdition/80RC2alpha002/quickstart/quickstart/quickstart_18_0004.html)或使用以下命令:
|
||||
|
||||
```bash
|
||||
# 请替换 URL 为 CANN 版本和设备型号对应的 URL
|
||||
# 安装 CANN Toolkit
|
||||
wget https://ascend-repo.obs.cn-east-2.myhuaweicloud.com/Milan-ASL/Milan-ASL%20V100R001C17SPC701/Ascend-cann-toolkit_8.0.RC1.alpha001_linux-"$(uname -i)".run
|
||||
bash Ascend-cann-toolkit_8.0.RC1.alpha001_linux-"$(uname -i)".run --install
|
||||
|
||||
# 安装 CANN Kernels
|
||||
wget https://ascend-repo.obs.cn-east-2.myhuaweicloud.com/Milan-ASL/Milan-ASL%20V100R001C17SPC701/Ascend-cann-kernels-910b_8.0.RC1.alpha001_linux.run
|
||||
bash Ascend-cann-kernels-910b_8.0.RC1.alpha001_linux.run --install
|
||||
|
||||
# 设置环境变量
|
||||
source /usr/local/Ascend/ascend-toolkit/set_env.sh
|
||||
```
|
||||
|
||||
| 依赖项 | 至少 | 推荐 |
|
||||
| ------------ | ------- | --------- |
|
||||
| ------------ | ------- | ----------- |
|
||||
| CANN | 8.0.RC1 | 8.0.RC1 |
|
||||
| torch | 2.2.0 | 2.2.0 |
|
||||
| torch-npu | 2.2.0 | 2.2.0 |
|
||||
| torch | 2.1.0 | 2.1.0 |
|
||||
| torch-npu | 2.1.0 | 2.1.0.post3 |
|
||||
| deepspeed | 0.13.2 | 0.13.2 |
|
||||
|
||||
Docker 镜像:
|
||||
|
||||
- 32GB:[下载地址](http://mirrors.cn-central-221.ovaijisuan.com/detail/130.html)
|
||||
- 64GB:敬请期待
|
||||
- 64GB:[下载地址](http://mirrors.cn-central-221.ovaijisuan.com/detail/131.html)
|
||||
|
||||
请记得使用 `ASCEND_RT_VISIBLE_DEVICES` 而非 `CUDA_VISIBLE_DEVICES` 来指定您使用的设备。
|
||||
请使用 `ASCEND_RT_VISIBLE_DEVICES` 而非 `CUDA_VISIBLE_DEVICES` 来指定运算设备。
|
||||
|
||||
如果遇到无法正常推理的情况,请尝试设置 `do_sample: false`。
|
||||
|
||||
@@ -387,31 +417,12 @@ CUDA_VISIBLE_DEVICES=0 llamafactory-cli export examples/merge_lora/llama3_lora_s
|
||||
|
||||
### LLaMA Board 可视化微调(由 [Gradio](https://github.com/gradio-app/gradio) 驱动)
|
||||
|
||||
> [!IMPORTANT]
|
||||
> LLaMA Board 可视化界面目前仅支持单 GPU 训练。
|
||||
|
||||
#### 使用本地环境
|
||||
|
||||
```bash
|
||||
CUDA_VISIBLE_DEVICES=0 GRADIO_SHARE=1 llamafactory-cli webui
|
||||
```
|
||||
|
||||
<details><summary>阿里云 PAI 和 AutoDL 用户指南</summary>
|
||||
|
||||
如果您在阿里云 PAI 上使用 LLaMA Board 时遇到显示问题,请尝试在启动前使用以下命令设置环境变量:
|
||||
|
||||
```bash
|
||||
export GRADIO_SERVER_PORT=7860 GRADIO_ROOT_PATH=/${JUPYTER_NAME}/proxy/7860/
|
||||
```
|
||||
|
||||
如果您正在使用 AutoDL,请安装下述 Gradio 版本:
|
||||
|
||||
```bash
|
||||
pip install gradio==4.10.0
|
||||
```
|
||||
|
||||
</details>
|
||||
|
||||
#### 使用 Docker
|
||||
|
||||
```bash
|
||||
@@ -420,7 +431,6 @@ docker run --gpus=all \
|
||||
-v ./hf_cache:/root/.cache/huggingface/ \
|
||||
-v ./data:/app/data \
|
||||
-v ./output:/app/output \
|
||||
-e CUDA_VISIBLE_DEVICES=0 \
|
||||
-p 7860:7860 \
|
||||
--shm-size 16G \
|
||||
--name llama_factory \
|
||||
@@ -447,6 +457,9 @@ docker compose -f ./docker-compose.yml up -d
|
||||
CUDA_VISIBLE_DEVICES=0,1 API_PORT=8000 llamafactory-cli api examples/inference/llama3_vllm.yaml
|
||||
```
|
||||
|
||||
> [!TIP]
|
||||
> API 文档请查阅 https://platform.openai.com/docs/api-reference/chat/create。
|
||||
|
||||
### 从魔搭社区下载
|
||||
|
||||
如果您在 Hugging Face 模型和数据集的下载中遇到了问题,可以通过下述方法使用魔搭社区。
|
||||
@@ -455,7 +468,18 @@ CUDA_VISIBLE_DEVICES=0,1 API_PORT=8000 llamafactory-cli api examples/inference/l
|
||||
export USE_MODELSCOPE_HUB=1 # Windows 使用 `set USE_MODELSCOPE_HUB=1`
|
||||
```
|
||||
|
||||
将 `--model_name_or_path` 设置为模型 ID 来加载对应的模型。在[魔搭社区](https://modelscope.cn/models)查看所有可用的模型,例如 `LLM-Research/Meta-Llama-3-8B-Instruct`。
|
||||
将 `model_name_or_path` 设置为模型 ID 来加载对应的模型。在[魔搭社区](https://modelscope.cn/models)查看所有可用的模型,例如 `LLM-Research/Meta-Llama-3-8B-Instruct`。
|
||||
|
||||
### 使用 W&B 面板
|
||||
|
||||
若要使用 [Weights & Biases](https://wandb.ai) 记录实验数据,请添加下面的参数。
|
||||
|
||||
```yaml
|
||||
report_to: wandb
|
||||
run_name: test_run # 可选
|
||||
```
|
||||
|
||||
在启动训练任务时,将 `WANDB_API_KEY` 设置为[密钥](https://wandb.ai/authorize)来登录 W&B 账户。
|
||||
|
||||
## 使用了 LLaMA Factory 的项目
|
||||
|
||||
@@ -502,7 +526,7 @@ export USE_MODELSCOPE_HUB=1 # Windows 使用 `set USE_MODELSCOPE_HUB=1`
|
||||
1. Zhou et al. FREB-TQA: A Fine-Grained Robustness Evaluation Benchmark for Table Question Answering. 2024. [[arxiv]](https://arxiv.org/abs/2404.18585)
|
||||
1. **[StarWhisper](https://github.com/Yu-Yang-Li/StarWhisper)**: 天文大模型 StarWhisper,基于 ChatGLM2-6B 和 Qwen-14B 在天文数据上微调而得。
|
||||
1. **[DISC-LawLLM](https://github.com/FudanDISC/DISC-LawLLM)**: 中文法律领域大模型 DISC-LawLLM,基于 Baichuan-13B 微调而得,具有法律推理和知识检索能力。
|
||||
1. **[Sunsimiao](https://github.com/thomas-yanxin/Sunsimiao)**: 孙思邈中文医疗大模型 Sumsimiao,基于 Baichuan-7B 和 ChatGLM-6B 在中文医疗数据上微调而得。
|
||||
1. **[Sunsimiao](https://github.com/X-D-Lab/Sunsimiao)**: 孙思邈中文医疗大模型 Sumsimiao,基于 Baichuan-7B 和 ChatGLM-6B 在中文医疗数据上微调而得。
|
||||
1. **[CareGPT](https://github.com/WangRongsheng/CareGPT)**: 医疗大模型项目 CareGPT,基于 LLaMA2-7B 和 Baichuan-13B 在中文医疗数据上微调而得。
|
||||
1. **[MachineMindset](https://github.com/PKU-YuanGroup/Machine-Mindset/)**:MBTI性格大模型项目,根据数据集与训练方式让任意 LLM 拥有 16 个不同的性格类型。
|
||||
1. **[Luminia-13B-v3](https://huggingface.co/Nekochu/Luminia-13B-v3)**:一个用于生成 Stable Diffusion 提示词的大型语言模型。[[🤗Demo]](https://huggingface.co/spaces/Nekochu/Luminia-13B_SD_Prompt)
|
||||
@@ -514,7 +538,7 @@ export USE_MODELSCOPE_HUB=1 # Windows 使用 `set USE_MODELSCOPE_HUB=1`
|
||||
|
||||
本仓库的代码依照 [Apache-2.0](LICENSE) 协议开源。
|
||||
|
||||
使用模型权重时,请遵循对应的模型协议:[Baichuan2](https://huggingface.co/baichuan-inc/Baichuan2-7B-Base/blob/main/Community%20License%20for%20Baichuan%202%20Model.pdf) / [BLOOM](https://huggingface.co/spaces/bigscience/license) / [ChatGLM3](https://github.com/THUDM/ChatGLM3/blob/main/MODEL_LICENSE) / [Command-R](https://cohere.com/c4ai-cc-by-nc-license) / [DeepSeek](https://github.com/deepseek-ai/DeepSeek-LLM/blob/main/LICENSE-MODEL) / [Falcon](https://huggingface.co/tiiuae/falcon-180B/blob/main/LICENSE.txt) / [Gemma](https://ai.google.dev/gemma/terms) / [InternLM2](https://github.com/InternLM/InternLM#license) / [LLaMA](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) / [LLaMA-2 (LLaVA-1.5)](https://ai.meta.com/llama/license/) / [LLaMA-3](https://llama.meta.com/llama3/license/) / [Mistral](LICENSE) / [OLMo](LICENSE) / [Phi-1.5/2](https://huggingface.co/microsoft/phi-1_5/resolve/main/Research%20License.docx) / [Phi-3](https://huggingface.co/microsoft/Phi-3-mini-4k-instruct/blob/main/LICENSE) / [Qwen](https://github.com/QwenLM/Qwen/blob/main/Tongyi%20Qianwen%20LICENSE%20AGREEMENT) / [StarCoder2](https://huggingface.co/spaces/bigcode/bigcode-model-license-agreement) / [XVERSE](https://github.com/xverse-ai/XVERSE-13B/blob/main/MODEL_LICENSE.pdf) / [Yi](https://huggingface.co/01-ai/Yi-6B/blob/main/LICENSE) / [Yi-1.5](LICENSE) / [Yuan](https://github.com/IEIT-Yuan/Yuan-2.0/blob/main/LICENSE-Yuan)
|
||||
使用模型权重时,请遵循对应的模型协议:[Baichuan2](https://huggingface.co/baichuan-inc/Baichuan2-7B-Base/blob/main/Community%20License%20for%20Baichuan%202%20Model.pdf) / [BLOOM](https://huggingface.co/spaces/bigscience/license) / [ChatGLM3](https://github.com/THUDM/ChatGLM3/blob/main/MODEL_LICENSE) / [Command-R](https://cohere.com/c4ai-cc-by-nc-license) / [DeepSeek](https://github.com/deepseek-ai/DeepSeek-LLM/blob/main/LICENSE-MODEL) / [Falcon](https://huggingface.co/tiiuae/falcon-180B/blob/main/LICENSE.txt) / [Gemma](https://ai.google.dev/gemma/terms) / [GLM4](https://huggingface.co/THUDM/glm-4-9b/blob/main/LICENSE) / [InternLM2](https://github.com/InternLM/InternLM#license) / [LLaMA](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) / [LLaMA-2 (LLaVA-1.5)](https://ai.meta.com/llama/license/) / [LLaMA-3](https://llama.meta.com/llama3/license/) / [Mistral](LICENSE) / [OLMo](LICENSE) / [Phi-1.5/2](https://huggingface.co/microsoft/phi-1_5/resolve/main/Research%20License.docx) / [Phi-3](https://huggingface.co/microsoft/Phi-3-mini-4k-instruct/blob/main/LICENSE) / [Qwen](https://github.com/QwenLM/Qwen/blob/main/Tongyi%20Qianwen%20LICENSE%20AGREEMENT) / [StarCoder2](https://huggingface.co/spaces/bigcode/bigcode-model-license-agreement) / [XVERSE](https://github.com/xverse-ai/XVERSE-13B/blob/main/MODEL_LICENSE.pdf) / [Yi](https://huggingface.co/01-ai/Yi-6B/blob/main/LICENSE) / [Yi-1.5](LICENSE) / [Yuan](https://github.com/IEIT-Yuan/Yuan-2.0/blob/main/LICENSE-Yuan)
|
||||
|
||||
## 引用
|
||||
|
||||
|
||||
222
data/README.md
222
data/README.md
@@ -1,16 +1,18 @@
|
||||
If you are using a custom dataset, please add your **dataset description** to `dataset_info.json` according to the following format. We also provide several examples in the next section.
|
||||
The [dataset_info.json](dataset_info.json) contains all available datasets. If you are using a custom dataset, please **make sure** to add a *dataset description* in `dataset_info.json` and specify `dataset: dataset_name` before training to use it.
|
||||
|
||||
Currently we support datasets in **alpaca** and **sharegpt** format.
|
||||
|
||||
```json
|
||||
"dataset_name": {
|
||||
"hf_hub_url": "the name of the dataset repository on the Hugging Face hub. (if specified, ignore script_url and file_name)",
|
||||
"ms_hub_url": "the name of the dataset repository on the Model Scope hub. (if specified, ignore script_url and file_name)",
|
||||
"script_url": "the name of the directory containing a dataset loading script. (if specified, ignore file_name)",
|
||||
"file_name": "the name of the dataset file in this directory. (required if above are not specified)",
|
||||
"file_sha1": "the SHA-1 hash value of the dataset file. (optional, does not affect training)",
|
||||
"file_name": "the name of the dataset folder or dataset file in this directory. (required if above are not specified)",
|
||||
"formatting": "the format of the dataset. (optional, default: alpaca, can be chosen from {alpaca, sharegpt})",
|
||||
"ranking": "whether the dataset is a preference dataset or not. (default: False)",
|
||||
"subset": "the name of the subset. (optional, default: None)",
|
||||
"folder": "the name of the folder of the dataset repository on the Hugging Face hub. (optional, default: None)",
|
||||
"ranking": "whether the dataset is a preference dataset or not. (default: false)",
|
||||
"formatting": "the format of the dataset. (optional, default: alpaca, can be chosen from {alpaca, sharegpt})",
|
||||
"num_samples": "the number of samples in the dataset used for training. (optional, default: None)",
|
||||
"columns (optional)": {
|
||||
"prompt": "the column name in the dataset containing the prompts. (default: instruction)",
|
||||
"query": "the column name in the dataset containing the queries. (default: input)",
|
||||
@@ -19,7 +21,10 @@ If you are using a custom dataset, please add your **dataset description** to `d
|
||||
"messages": "the column name in the dataset containing the messages. (default: conversations)",
|
||||
"system": "the column name in the dataset containing the system prompts. (default: None)",
|
||||
"tools": "the column name in the dataset containing the tool description. (default: None)",
|
||||
"images": "the column name in the dataset containing the image inputs. (default: None)"
|
||||
"images": "the column name in the dataset containing the image inputs. (default: None)",
|
||||
"chosen": "the column name in the dataset containing the chosen answers. (default: None)",
|
||||
"rejected": "the column name in the dataset containing the rejected answers. (default: None)",
|
||||
"kto_tag": "the column name in the dataset containing the kto tags. (default: None)"
|
||||
},
|
||||
"tags (optional, used for the sharegpt format)": {
|
||||
"role_tag": "the key in the message represents the identity. (default: from)",
|
||||
@@ -33,28 +38,34 @@ If you are using a custom dataset, please add your **dataset description** to `d
|
||||
}
|
||||
```
|
||||
|
||||
After that, you can load the custom dataset by specifying `--dataset dataset_name`.
|
||||
## Alpaca Format
|
||||
|
||||
----
|
||||
### Supervised Fine-Tuning Dataset
|
||||
|
||||
Currently we support dataset in **alpaca** or **sharegpt** format, the dataset in alpaca format should follow the below format:
|
||||
* [Example dataset](alpaca_en_demo.json)
|
||||
|
||||
In supervised fine-tuning, the `instruction` column will be concatenated with the `input` column and used as the human prompt, then the human prompt would be `instruction\ninput`. The `output` column represents the model response.
|
||||
|
||||
The `system` column will be used as the system prompt if specified.
|
||||
|
||||
The `history` column is a list consisting of string tuples representing prompt-response pairs in the history messages. Note that the responses in the history **will also be learned by the model** in supervised fine-tuning.
|
||||
|
||||
```json
|
||||
[
|
||||
{
|
||||
"instruction": "user instruction (required)",
|
||||
"input": "user input (optional)",
|
||||
"instruction": "human instruction (required)",
|
||||
"input": "human input (optional)",
|
||||
"output": "model response (required)",
|
||||
"system": "system prompt (optional)",
|
||||
"history": [
|
||||
["user instruction in the first round (optional)", "model response in the first round (optional)"],
|
||||
["user instruction in the second round (optional)", "model response in the second round (optional)"]
|
||||
["human instruction in the first round (optional)", "model response in the first round (optional)"],
|
||||
["human instruction in the second round (optional)", "model response in the second round (optional)"]
|
||||
]
|
||||
}
|
||||
]
|
||||
```
|
||||
|
||||
Regarding the above dataset, the description in `dataset_info.json` should be:
|
||||
Regarding the above dataset, the *dataset description* in `dataset_info.json` should be:
|
||||
|
||||
```json
|
||||
"dataset_name": {
|
||||
@@ -69,11 +80,11 @@ Regarding the above dataset, the description in `dataset_info.json` should be:
|
||||
}
|
||||
```
|
||||
|
||||
The `query` column will be concatenated with the `prompt` column and used as the user prompt, then the user prompt would be `prompt\nquery`. The `response` column represents the model response.
|
||||
### Pre-training Dataset
|
||||
|
||||
The `system` column will be used as the system prompt. The `history` column is a list consisting string tuples representing prompt-response pairs in the history. Note that the responses in the history **will also be used for training** in supervised fine-tuning.
|
||||
- [Example dataset](c4_demo.json)
|
||||
|
||||
For the **pre-training datasets**, only the `prompt` column will be used for training, for example:
|
||||
In pre-training, only the `text` column will be used for model learning.
|
||||
|
||||
```json
|
||||
[
|
||||
@@ -82,7 +93,7 @@ For the **pre-training datasets**, only the `prompt` column will be used for tra
|
||||
]
|
||||
```
|
||||
|
||||
Regarding the above dataset, the description in `dataset_info.json` should be:
|
||||
Regarding the above dataset, the *dataset description* in `dataset_info.json` should be:
|
||||
|
||||
```json
|
||||
"dataset_name": {
|
||||
@@ -93,22 +104,24 @@ Regarding the above dataset, the description in `dataset_info.json` should be:
|
||||
}
|
||||
```
|
||||
|
||||
For the **preference datasets**, the `response` column should be a string list whose length is 2, with the preferred answers appearing first, for example:
|
||||
### Preference Dataset
|
||||
|
||||
Preference datasets are used for reward modeling, DPO training and ORPO training.
|
||||
|
||||
It requires a better response in `chosen` column and a worse response in `rejected` column.
|
||||
|
||||
```json
|
||||
[
|
||||
{
|
||||
"instruction": "user instruction",
|
||||
"input": "user input",
|
||||
"output": [
|
||||
"chosen answer",
|
||||
"rejected answer"
|
||||
]
|
||||
"instruction": "human instruction (required)",
|
||||
"input": "human input (optional)",
|
||||
"chosen": "chosen answer (required)",
|
||||
"rejected": "rejected answer (required)"
|
||||
}
|
||||
]
|
||||
```
|
||||
|
||||
Regarding the above dataset, the description in `dataset_info.json` should be:
|
||||
Regarding the above dataset, the *dataset description* in `dataset_info.json` should be:
|
||||
|
||||
```json
|
||||
"dataset_name": {
|
||||
@@ -117,14 +130,85 @@ Regarding the above dataset, the description in `dataset_info.json` should be:
|
||||
"columns": {
|
||||
"prompt": "instruction",
|
||||
"query": "input",
|
||||
"response": "output",
|
||||
"chosen": "chosen",
|
||||
"rejected": "rejected"
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
----
|
||||
### KTO Dataset
|
||||
|
||||
The dataset in **sharegpt** format should follow the below format:
|
||||
- [Example dataset](kto_en_demo.json)
|
||||
|
||||
KTO datasets require a extra `kto_tag` column containing the boolean human feedback.
|
||||
|
||||
```json
|
||||
[
|
||||
{
|
||||
"instruction": "human instruction (required)",
|
||||
"input": "human input (optional)",
|
||||
"output": "model response (required)",
|
||||
"kto_tag": "human feedback [true/false] (required)"
|
||||
}
|
||||
]
|
||||
```
|
||||
|
||||
Regarding the above dataset, the *dataset description* in `dataset_info.json` should be:
|
||||
|
||||
```json
|
||||
"dataset_name": {
|
||||
"file_name": "data.json",
|
||||
"columns": {
|
||||
"prompt": "instruction",
|
||||
"query": "input",
|
||||
"response": "output",
|
||||
"kto_tag": "kto_tag"
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
### Multimodal Dataset
|
||||
|
||||
- [Example dataset](mllm_demo.json)
|
||||
|
||||
Multimodal datasets require a `images` column containing the paths to the input images. Currently we only support one image.
|
||||
|
||||
```json
|
||||
[
|
||||
{
|
||||
"instruction": "human instruction (required)",
|
||||
"input": "human input (optional)",
|
||||
"output": "model response (required)",
|
||||
"images": [
|
||||
"image path (required)"
|
||||
]
|
||||
}
|
||||
]
|
||||
```
|
||||
|
||||
Regarding the above dataset, the *dataset description* in `dataset_info.json` should be:
|
||||
|
||||
```json
|
||||
"dataset_name": {
|
||||
"file_name": "data.json",
|
||||
"columns": {
|
||||
"prompt": "instruction",
|
||||
"query": "input",
|
||||
"response": "output",
|
||||
"images": "images"
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
## Sharegpt Format
|
||||
|
||||
### Supervised Fine-Tuning Dataset
|
||||
|
||||
- [Example dataset](glaive_toolcall_en_demo.json)
|
||||
|
||||
Compared to the alpaca format, the sharegpt format allows the datasets have **more roles**, such as human, gpt, observation and function. They are presented in a list of objects in the `conversations` column.
|
||||
|
||||
Note that the human and observation should appear in odd positions, while gpt and function should appear in even positions.
|
||||
|
||||
```json
|
||||
[
|
||||
@@ -132,7 +216,15 @@ The dataset in **sharegpt** format should follow the below format:
|
||||
"conversations": [
|
||||
{
|
||||
"from": "human",
|
||||
"value": "user instruction"
|
||||
"value": "human instruction"
|
||||
},
|
||||
{
|
||||
"from": "function_call",
|
||||
"value": "tool arguments"
|
||||
},
|
||||
{
|
||||
"from": "observation",
|
||||
"value": "tool result"
|
||||
},
|
||||
{
|
||||
"from": "gpt",
|
||||
@@ -145,7 +237,7 @@ The dataset in **sharegpt** format should follow the below format:
|
||||
]
|
||||
```
|
||||
|
||||
Regarding the above dataset, the description in `dataset_info.json` should be:
|
||||
Regarding the above dataset, the *dataset description* in `dataset_info.json` should be:
|
||||
|
||||
```json
|
||||
"dataset_name": {
|
||||
@@ -155,19 +247,63 @@ Regarding the above dataset, the description in `dataset_info.json` should be:
|
||||
"messages": "conversations",
|
||||
"system": "system",
|
||||
"tools": "tools"
|
||||
},
|
||||
"tags": {
|
||||
"role_tag": "from",
|
||||
"content_tag": "value",
|
||||
"user_tag": "human",
|
||||
"assistant_tag": "gpt"
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
where the `messages` column should be a list following the `u/a/u/a/u/a` order.
|
||||
### Preference Dataset
|
||||
|
||||
We also supports the dataset in the **openai** format:
|
||||
- [Example dataset](dpo_en_demo.json)
|
||||
|
||||
Preference datasets in sharegpt format also require a better message in `chosen` column and a worse message in `rejected` column.
|
||||
|
||||
```json
|
||||
[
|
||||
{
|
||||
"conversations": [
|
||||
{
|
||||
"from": "human",
|
||||
"value": "human instruction"
|
||||
},
|
||||
{
|
||||
"from": "gpt",
|
||||
"value": "model response"
|
||||
},
|
||||
{
|
||||
"from": "human",
|
||||
"value": "human instruction"
|
||||
}
|
||||
],
|
||||
"chosen": {
|
||||
"from": "gpt",
|
||||
"value": "chosen answer (required)"
|
||||
},
|
||||
"rejected": {
|
||||
"from": "gpt",
|
||||
"value": "rejected answer (required)"
|
||||
}
|
||||
}
|
||||
]
|
||||
```
|
||||
|
||||
Regarding the above dataset, the *dataset description* in `dataset_info.json` should be:
|
||||
|
||||
```json
|
||||
"dataset_name": {
|
||||
"file_name": "data.json",
|
||||
"formatting": "sharegpt",
|
||||
"ranking": true,
|
||||
"columns": {
|
||||
"messages": "conversations",
|
||||
"chosen": "chosen",
|
||||
"rejected": "rejected"
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
### OpenAI Format
|
||||
|
||||
The openai format is simply a special case of the sharegpt format, where the first message may be a system prompt.
|
||||
|
||||
```json
|
||||
[
|
||||
@@ -179,7 +315,7 @@ We also supports the dataset in the **openai** format:
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": "user instruction"
|
||||
"content": "human instruction"
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
@@ -190,7 +326,7 @@ We also supports the dataset in the **openai** format:
|
||||
]
|
||||
```
|
||||
|
||||
Regarding the above dataset, the description in `dataset_info.json` should be:
|
||||
Regarding the above dataset, the *dataset description* in `dataset_info.json` should be:
|
||||
|
||||
```json
|
||||
"dataset_name": {
|
||||
@@ -209,4 +345,6 @@ Regarding the above dataset, the description in `dataset_info.json` should be:
|
||||
}
|
||||
```
|
||||
|
||||
Pre-training datasets and preference datasets are **incompatible** with the sharegpt format yet.
|
||||
The KTO datasets and multimodal datasets in sharegpt format are similar to the alpaca format.
|
||||
|
||||
Pre-training datasets are **incompatible** with the sharegpt format.
|
||||
|
||||
@@ -1,16 +1,18 @@
|
||||
如果您使用自定义数据集,请务必按照以下格式在 `dataset_info.json` 文件中添加**数据集描述**。我们在下面也提供了一些例子。
|
||||
[dataset_info.json](dataset_info.json) 包含了所有可用的数据集。如果您希望使用自定义数据集,请**务必**在 `dataset_info.json` 文件中添加*数据集描述*,并通过修改 `dataset: 数据集名称` 配置来使用数据集。
|
||||
|
||||
目前我们支持 **alpaca** 格式和 **sharegpt** 格式的数据集。
|
||||
|
||||
```json
|
||||
"数据集名称": {
|
||||
"hf_hub_url": "Hugging Face 的数据集仓库地址(若指定,则忽略 script_url 和 file_name)",
|
||||
"ms_hub_url": "ModelScope 的数据集仓库地址(若指定,则忽略 script_url 和 file_name)",
|
||||
"script_url": "包含数据加载脚本的本地文件夹名称(若指定,则忽略 file_name)",
|
||||
"file_name": "该目录下数据集文件的名称(若上述参数未指定,则此项必需)",
|
||||
"file_sha1": "数据集文件的 SHA-1 哈希值(可选,留空不影响训练)",
|
||||
"file_name": "该目录下数据集文件夹或文件的名称(若上述参数未指定,则此项必需)",
|
||||
"formatting": "数据集格式(可选,默认:alpaca,可以为 alpaca 或 sharegpt)",
|
||||
"ranking": "是否为偏好数据集(可选,默认:False)",
|
||||
"subset": "数据集子集的名称(可选,默认:None)",
|
||||
"folder": "Hugging Face 仓库的文件夹名称(可选,默认:None)",
|
||||
"ranking": "是否为偏好数据集(可选,默认:False)",
|
||||
"formatting": "数据集格式(可选,默认:alpaca,可以为 alpaca 或 sharegpt)",
|
||||
"num_samples": "该数据集中用于训练的样本数量。(可选,默认:None)",
|
||||
"columns(可选)": {
|
||||
"prompt": "数据集代表提示词的表头名称(默认:instruction)",
|
||||
"query": "数据集代表请求的表头名称(默认:input)",
|
||||
@@ -19,7 +21,10 @@
|
||||
"messages": "数据集代表消息列表的表头名称(默认:conversations)",
|
||||
"system": "数据集代表系统提示的表头名称(默认:None)",
|
||||
"tools": "数据集代表工具描述的表头名称(默认:None)",
|
||||
"images": "数据集代表图像输入的表头名称(默认:None)"
|
||||
"images": "数据集代表图像输入的表头名称(默认:None)",
|
||||
"chosen": "数据集代表更优回答的表头名称(默认:None)",
|
||||
"rejected": "数据集代表更差回答的表头名称(默认:None)",
|
||||
"kto_tag": "数据集代表 KTO 标签的表头名称(默认:None)"
|
||||
},
|
||||
"tags(可选,用于 sharegpt 格式)": {
|
||||
"role_tag": "消息中代表发送者身份的键名(默认:from)",
|
||||
@@ -28,22 +33,28 @@
|
||||
"assistant_tag": "消息中代表助手的 role_tag(默认:gpt)",
|
||||
"observation_tag": "消息中代表工具返回结果的 role_tag(默认:observation)",
|
||||
"function_tag": "消息中代表工具调用的 role_tag(默认:function_call)",
|
||||
"system_tag": "消息中代表系统提示的 role_tag(默认:system,会覆盖 system 列)"
|
||||
"system_tag": "消息中代表系统提示的 role_tag(默认:system,会覆盖 system column)"
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
然后,可通过使用 `--dataset 数据集名称` 参数加载自定义数据集。
|
||||
## Alpaca 格式
|
||||
|
||||
----
|
||||
### 指令监督微调数据集
|
||||
|
||||
该项目目前支持两种格式的数据集:**alpaca** 和 **sharegpt**,其中 alpaca 格式的数据集按照以下方式组织:
|
||||
- [样例数据集](alpaca_zh_demo.json)
|
||||
|
||||
在指令监督微调时,`instruction` 列对应的内容会与 `input` 列对应的内容拼接后作为人类指令,即人类指令为 `instruction\ninput`。而 `output` 列对应的内容为模型回答。
|
||||
|
||||
如果指定,`system` 列对应的内容将被作为系统提示词。
|
||||
|
||||
`history` 列是由多个字符串二元组构成的列表,分别代表历史消息中每轮对话的指令和回答。注意在指令监督微调时,历史消息中的回答内容**也会被用于模型学习**。
|
||||
|
||||
```json
|
||||
[
|
||||
{
|
||||
"instruction": "用户指令(必填)",
|
||||
"input": "用户输入(选填)",
|
||||
"instruction": "人类指令(必填)",
|
||||
"input": "人类输入(选填)",
|
||||
"output": "模型回答(必填)",
|
||||
"system": "系统提示词(选填)",
|
||||
"history": [
|
||||
@@ -54,7 +65,7 @@
|
||||
]
|
||||
```
|
||||
|
||||
对于上述格式的数据,`dataset_info.json` 中的描述应为:
|
||||
对于上述格式的数据,`dataset_info.json` 中的*数据集描述*应为:
|
||||
|
||||
```json
|
||||
"数据集名称": {
|
||||
@@ -69,11 +80,11 @@
|
||||
}
|
||||
```
|
||||
|
||||
其中 `query` 列对应的内容会与 `prompt` 列对应的内容拼接后作为用户指令,即用户指令为 `prompt\nquery`。`response` 列对应的内容为模型回答。
|
||||
### 预训练数据集
|
||||
|
||||
`system` 列对应的内容将被作为系统提示词。`history` 列是由多个字符串二元组构成的列表,分别代表历史消息中每轮的指令和回答。注意在指令监督学习时,历史消息中的回答**也会被用于训练**。
|
||||
- [样例数据集](c4_demo.json)
|
||||
|
||||
对于**预训练数据集**,仅 `prompt` 列中的内容会用于模型训练,例如:
|
||||
在预训练时,只有 `text` 列中的内容会用于模型学习。
|
||||
|
||||
```json
|
||||
[
|
||||
@@ -82,7 +93,7 @@
|
||||
]
|
||||
```
|
||||
|
||||
对于上述格式的数据,`dataset_info.json` 中的描述应为:
|
||||
对于上述格式的数据,`dataset_info.json` 中的*数据集描述*应为:
|
||||
|
||||
```json
|
||||
"数据集名称": {
|
||||
@@ -93,22 +104,24 @@
|
||||
}
|
||||
```
|
||||
|
||||
对于**偏好数据集**,`response` 列应当是一个长度为 2 的字符串列表,排在前面的代表更优的回答,例如:
|
||||
### 偏好数据集
|
||||
|
||||
偏好数据集用于奖励模型训练、DPO 训练和 ORPO 训练。
|
||||
|
||||
它需要在 `chosen` 列中提供更优的回答,并在 `rejected` 列中提供更差的回答。
|
||||
|
||||
```json
|
||||
[
|
||||
{
|
||||
"instruction": "用户指令",
|
||||
"input": "用户输入",
|
||||
"output": [
|
||||
"优质回答",
|
||||
"劣质回答"
|
||||
]
|
||||
"instruction": "人类指令(必填)",
|
||||
"input": "人类输入(选填)",
|
||||
"chosen": "优质回答(必填)",
|
||||
"rejected": "劣质回答(必填)"
|
||||
}
|
||||
]
|
||||
```
|
||||
|
||||
对于上述格式的数据,`dataset_info.json` 中的描述应为:
|
||||
对于上述格式的数据,`dataset_info.json` 中的*数据集描述*应为:
|
||||
|
||||
```json
|
||||
"数据集名称": {
|
||||
@@ -117,14 +130,85 @@
|
||||
"columns": {
|
||||
"prompt": "instruction",
|
||||
"query": "input",
|
||||
"response": "output",
|
||||
"chosen": "chosen",
|
||||
"rejected": "rejected"
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
----
|
||||
### KTO 数据集
|
||||
|
||||
而 **sharegpt** 格式的数据集按照以下方式组织:
|
||||
- [样例数据集](kto_en_demo.json)
|
||||
|
||||
KTO 数据集需要额外添加一个 `kto_tag` 列,包含 bool 类型的人类反馈。
|
||||
|
||||
```json
|
||||
[
|
||||
{
|
||||
"instruction": "人类指令(必填)",
|
||||
"input": "人类输入(选填)",
|
||||
"output": "模型回答(必填)",
|
||||
"kto_tag": "人类反馈 [true/false](必填)"
|
||||
}
|
||||
]
|
||||
```
|
||||
|
||||
对于上述格式的数据,`dataset_info.json` 中的*数据集描述*应为:
|
||||
|
||||
```json
|
||||
"数据集名称": {
|
||||
"file_name": "data.json",
|
||||
"columns": {
|
||||
"prompt": "instruction",
|
||||
"query": "input",
|
||||
"response": "output",
|
||||
"kto_tag": "kto_tag"
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
### 多模态数据集
|
||||
|
||||
- [样例数据集](mllm_demo.json)
|
||||
|
||||
多模态数据集需要额外添加一个 `images` 列,包含输入图像的路径。目前我们仅支持单张图像输入。
|
||||
|
||||
```json
|
||||
[
|
||||
{
|
||||
"instruction": "人类指令(必填)",
|
||||
"input": "人类输入(选填)",
|
||||
"output": "模型回答(必填)",
|
||||
"images": [
|
||||
"图像路径(必填)"
|
||||
]
|
||||
}
|
||||
]
|
||||
```
|
||||
|
||||
对于上述格式的数据,`dataset_info.json` 中的*数据集描述*应为:
|
||||
|
||||
```json
|
||||
"数据集名称": {
|
||||
"file_name": "data.json",
|
||||
"columns": {
|
||||
"prompt": "instruction",
|
||||
"query": "input",
|
||||
"response": "output",
|
||||
"images": "images"
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
## Sharegpt 格式
|
||||
|
||||
### 指令监督微调数据集
|
||||
|
||||
- [样例数据集](glaive_toolcall_zh_demo.json)
|
||||
|
||||
相比 alpaca 格式的数据集,sharegpt 格式支持**更多的角色种类**,例如 human、gpt、observation、function 等等。它们构成一个对象列表呈现在 `conversations` 列中。
|
||||
|
||||
注意其中 human 和 observation 必须出现在奇数位置,gpt 和 function 必须出现在偶数位置。
|
||||
|
||||
```json
|
||||
[
|
||||
@@ -132,7 +216,15 @@
|
||||
"conversations": [
|
||||
{
|
||||
"from": "human",
|
||||
"value": "用户指令"
|
||||
"value": "人类指令"
|
||||
},
|
||||
{
|
||||
"from": "function_call",
|
||||
"value": "工具参数"
|
||||
},
|
||||
{
|
||||
"from": "observation",
|
||||
"value": "工具结果"
|
||||
},
|
||||
{
|
||||
"from": "gpt",
|
||||
@@ -145,7 +237,7 @@
|
||||
]
|
||||
```
|
||||
|
||||
对于上述格式的数据,`dataset_info.json` 中的描述应为:
|
||||
对于上述格式的数据,`dataset_info.json` 中的*数据集描述*应为:
|
||||
|
||||
```json
|
||||
"数据集名称": {
|
||||
@@ -155,19 +247,63 @@
|
||||
"messages": "conversations",
|
||||
"system": "system",
|
||||
"tools": "tools"
|
||||
},
|
||||
"tags": {
|
||||
"role_tag": "from",
|
||||
"content_tag": "value",
|
||||
"user_tag": "human",
|
||||
"assistant_tag": "gpt"
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
其中 `messages` 列应当是一个列表,且符合 `用户/模型/用户/模型/用户/模型` 的顺序。
|
||||
### 偏好数据集
|
||||
|
||||
我们同样支持 **openai** 格式的数据集:
|
||||
- [样例数据集](dpo_zh_demo.json)
|
||||
|
||||
Sharegpt 格式的偏好数据集同样需要在 `chosen` 列中提供更优的消息,并在 `rejected` 列中提供更差的消息。
|
||||
|
||||
```json
|
||||
[
|
||||
{
|
||||
"conversations": [
|
||||
{
|
||||
"from": "human",
|
||||
"value": "人类指令"
|
||||
},
|
||||
{
|
||||
"from": "gpt",
|
||||
"value": "模型回答"
|
||||
},
|
||||
{
|
||||
"from": "human",
|
||||
"value": "人类指令"
|
||||
}
|
||||
],
|
||||
"chosen": {
|
||||
"from": "gpt",
|
||||
"value": "优质回答"
|
||||
},
|
||||
"rejected": {
|
||||
"from": "gpt",
|
||||
"value": "劣质回答"
|
||||
}
|
||||
}
|
||||
]
|
||||
```
|
||||
|
||||
对于上述格式的数据,`dataset_info.json` 中的*数据集描述*应为:
|
||||
|
||||
```json
|
||||
"数据集名称": {
|
||||
"file_name": "data.json",
|
||||
"formatting": "sharegpt",
|
||||
"ranking": true,
|
||||
"columns": {
|
||||
"messages": "conversations",
|
||||
"chosen": "chosen",
|
||||
"rejected": "rejected"
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
### OpenAI 格式
|
||||
|
||||
OpenAI 格式仅仅是 sharegpt 格式的一种特殊情况,其中第一条消息可能是系统提示词。
|
||||
|
||||
```json
|
||||
[
|
||||
@@ -179,7 +315,7 @@
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": "用户指令"
|
||||
"content": "人类指令"
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
@@ -190,7 +326,7 @@
|
||||
]
|
||||
```
|
||||
|
||||
对于上述格式的数据,`dataset_info.json` 中的描述应为:
|
||||
对于上述格式的数据,`dataset_info.json` 中的*数据集描述*应为:
|
||||
|
||||
```json
|
||||
"数据集名称": {
|
||||
@@ -209,4 +345,6 @@
|
||||
}
|
||||
```
|
||||
|
||||
预训练数据集和偏好数据集**尚不支持** sharegpt 格式。
|
||||
Sharegpt 格式中的 KTO 数据集和多模态数据集与 alpaca 格式的类似。
|
||||
|
||||
预训练数据集**不支持** sharegpt 格式。
|
||||
|
||||
@@ -1 +0,0 @@
|
||||
3779ddbc040543ab1834ef216c983d6fcc06cc9a
|
||||
@@ -1 +0,0 @@
|
||||
a97cf9475291591843976554878568e046d8a46d
|
||||
@@ -1 +0,0 @@
|
||||
25508714b7879a1e5a6764ba7f979a980f549f1a
|
||||
@@ -1 +0,0 @@
|
||||
7cb6a7d11455bddc3d495750a2392683d775b184
|
||||
@@ -1 +0,0 @@
|
||||
f5cb08305ff5dc9c17a09809c54c8c8834aadc70
|
||||
@@ -1 +0,0 @@
|
||||
aee47b7b443496e37808d7f34ef10403ff99bcc3
|
||||
@@ -1,37 +0,0 @@
|
||||
import json
|
||||
from typing import Any, Dict, Generator, List, Tuple
|
||||
|
||||
import datasets
|
||||
|
||||
|
||||
_DESCRIPTION = "An example of dataset."
|
||||
_CITATION = ""
|
||||
_HOMEPAGE = ""
|
||||
_LICENSE = ""
|
||||
_URL = "examples.json"
|
||||
|
||||
|
||||
class ExampleDataset(datasets.GeneratorBasedBuilder):
|
||||
VERSION = datasets.Version("0.0.0")
|
||||
|
||||
def _info(self) -> datasets.DatasetInfo:
|
||||
features = datasets.Features(
|
||||
{
|
||||
"instruction": datasets.Value("string"),
|
||||
"input": datasets.Value("string"),
|
||||
"output": datasets.Value("string"),
|
||||
"history": datasets.Sequence(datasets.Sequence(datasets.Value("string"))),
|
||||
}
|
||||
)
|
||||
return datasets.DatasetInfo(
|
||||
description=_DESCRIPTION, features=features, homepage=_HOMEPAGE, license=_LICENSE, citation=_CITATION
|
||||
)
|
||||
|
||||
def _split_generators(self, dl_manager: datasets.DownloadManager) -> List[datasets.SplitGenerator]:
|
||||
file_path = dl_manager.download(_URL)
|
||||
return [datasets.SplitGenerator(name=datasets.Split.TRAIN, gen_kwargs={"filepath": file_path})]
|
||||
|
||||
def _generate_examples(self, filepath: str) -> Generator[Tuple[int, Dict[str, Any]], None, None]:
|
||||
example_dataset = json.load(open(filepath, "r", encoding="utf-8"))
|
||||
for key, example in enumerate(example_dataset):
|
||||
yield key, example
|
||||
@@ -1 +0,0 @@
|
||||
4748dff00d1dc42768a5b6cc772143c313017812
|
||||
@@ -34,7 +34,8 @@ class HhRlhfEn(datasets.GeneratorBasedBuilder):
|
||||
features = datasets.Features(
|
||||
{
|
||||
"instruction": datasets.Value("string"),
|
||||
"output": datasets.Sequence(datasets.Value("string")),
|
||||
"chosen": datasets.Value("string"),
|
||||
"rejected": datasets.Value("string"),
|
||||
"history": datasets.Sequence(datasets.Sequence(datasets.Value("string"))),
|
||||
}
|
||||
)
|
||||
@@ -79,5 +80,5 @@ class HhRlhfEn(datasets.GeneratorBasedBuilder):
|
||||
break
|
||||
prompt = prompt[:human_idx]
|
||||
|
||||
yield key, {"instruction": query, "output": [r_accept, r_reject], "history": history}
|
||||
yield key, {"instruction": query, "chosen": r_accept, "rejected": r_reject, "history": history}
|
||||
key += 1
|
||||
|
||||
@@ -1 +0,0 @@
|
||||
736bcedea2b24a1414765c6d69cbdafaea839f3c
|
||||
30
data/wiki_demo.txt
Normal file
30
data/wiki_demo.txt
Normal file
File diff suppressed because one or more lines are too long
@@ -1 +0,0 @@
|
||||
c9cf509b7fdac5490cfd6dae72c2d7b8a60af6cb
|
||||
@@ -10,8 +10,6 @@ services:
|
||||
- ./hf_cache:/root/.cache/huggingface/
|
||||
- ./data:/app/data
|
||||
- ./output:/app/output
|
||||
environment:
|
||||
- CUDA_VISIBLE_DEVICES=0
|
||||
ports:
|
||||
- "7860:7860"
|
||||
ipc: host
|
||||
|
||||
@@ -154,7 +154,7 @@ class MMLU(datasets.GeneratorBasedBuilder):
|
||||
]
|
||||
|
||||
def _generate_examples(self, filepath):
|
||||
df = pd.read_csv(filepath)
|
||||
df = pd.read_csv(filepath, header=None)
|
||||
df.columns = ["question", "A", "B", "C", "D", "answer"]
|
||||
|
||||
for i, instance in enumerate(df.to_dict(orient="records")):
|
||||
|
||||
@@ -47,16 +47,16 @@ CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/lora_single_gpu/llama3_lo
|
||||
CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/lora_single_gpu/llama3_lora_ppo.yaml
|
||||
```
|
||||
|
||||
#### DPO Training
|
||||
#### DPO/ORPO/SimPO Training
|
||||
|
||||
```bash
|
||||
CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/lora_single_gpu/llama3_lora_dpo.yaml
|
||||
```
|
||||
|
||||
#### ORPO Training
|
||||
#### KTO Training
|
||||
|
||||
```bash
|
||||
CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/lora_single_gpu/llama3_lora_orpo.yaml
|
||||
CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/lora_single_gpu/llama3_lora_kto.yaml
|
||||
```
|
||||
|
||||
#### Preprocess Dataset
|
||||
@@ -107,22 +107,23 @@ CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/qlora_single_gpu/llama3_l
|
||||
|
||||
### LoRA Fine-Tuning on Multiple GPUs
|
||||
|
||||
#### Supervised Fine-Tuning with Accelerate on Single Node
|
||||
#### Supervised Fine-Tuning on Single Node
|
||||
|
||||
```bash
|
||||
bash examples/lora_multi_gpu/single_node.sh
|
||||
CUDA_VISIBLE_DEVICES=0,1,2,3 llamafactory-cli train examples/lora_multi_gpu/llama3_lora_sft.yaml
|
||||
```
|
||||
|
||||
#### Supervised Fine-Tuning with Accelerate on Multiple Nodes
|
||||
#### Supervised Fine-Tuning on Multiple Nodes
|
||||
|
||||
```bash
|
||||
bash examples/lora_multi_gpu/multi_node.sh
|
||||
CUDA_VISIBLE_DEVICES=0,1,2,3 NNODES=2 RANK=0 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 llamafactory-cli train examples/lora_multi_gpu/llama3_lora_sft.yaml
|
||||
CUDA_VISIBLE_DEVICES=0,1,2,3 NNODES=2 RANK=1 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 llamafactory-cli train examples/lora_multi_gpu/llama3_lora_sft.yaml
|
||||
```
|
||||
|
||||
#### Supervised Fine-Tuning with DeepSpeed ZeRO-3 (Weight Sharding)
|
||||
|
||||
```bash
|
||||
bash examples/lora_multi_gpu/ds_zero3.sh
|
||||
CUDA_VISIBLE_DEVICES=0,1,2,3 llamafactory-cli train examples/lora_multi_gpu/llama3_lora_sft_ds.yaml
|
||||
```
|
||||
|
||||
### LoRA Fine-Tuning on Multiple NPUs
|
||||
@@ -130,27 +131,28 @@ bash examples/lora_multi_gpu/ds_zero3.sh
|
||||
#### Supervised Fine-Tuning with DeepSpeed ZeRO-0
|
||||
|
||||
```bash
|
||||
bash examples/lora_multi_npu/ds_zero0.sh
|
||||
ASCEND_RT_VISIBLE_DEVICES=0,1,2,3 llamafactory-cli train examples/lora_multi_npu/llama3_lora_sft_ds.yaml
|
||||
```
|
||||
|
||||
### Full-Parameter Fine-Tuning on Multiple GPUs
|
||||
|
||||
#### Supervised Fine-Tuning with Accelerate on Single Node
|
||||
#### Supervised Fine-Tuning on Single Node
|
||||
|
||||
```bash
|
||||
bash examples/full_multi_gpu/single_node.sh
|
||||
CUDA_VISIBLE_DEVICES=0,1,2,3 llamafactory-cli train examples/full_multi_gpu/llama3_full_sft.yaml
|
||||
```
|
||||
|
||||
#### Supervised Fine-Tuning with Accelerate on Multiple Nodes
|
||||
#### Supervised Fine-Tuning on Multiple Nodes
|
||||
|
||||
```bash
|
||||
bash examples/full_multi_gpu/multi_node.sh
|
||||
CUDA_VISIBLE_DEVICES=0,1,2,3 NNODES=2 RANK=0 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 llamafactory-cli train examples/full_multi_gpu/llama3_full_sft.yaml
|
||||
CUDA_VISIBLE_DEVICES=0,1,2,3 NNODES=2 RANK=1 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 llamafactory-cli train examples/full_multi_gpu/llama3_full_sft.yaml
|
||||
```
|
||||
|
||||
#### Batch Predicting and Computing BLEU and ROUGE Scores
|
||||
|
||||
```bash
|
||||
bash examples/full_multi_gpu/predict.sh
|
||||
CUDA_VISIBLE_DEVICES=0,1,2,3 llamafactory-cli train examples/full_multi_gpu/llama3_full_predict.yaml
|
||||
```
|
||||
|
||||
### Merging LoRA Adapters and Quantization
|
||||
@@ -171,22 +173,24 @@ CUDA_VISIBLE_DEVICES=0 llamafactory-cli export examples/merge_lora/llama3_gptq.y
|
||||
|
||||
### Inferring LoRA Fine-Tuned Models
|
||||
|
||||
Use `CUDA_VISIBLE_DEVICES=0,1` to infer models on multiple devices.
|
||||
|
||||
#### Use CLI
|
||||
|
||||
```bash
|
||||
CUDA_VISIBLE_DEVICES=0 llamafactory-cli chat examples/merge_lora/llama3_lora_sft.yaml
|
||||
CUDA_VISIBLE_DEVICES=0 llamafactory-cli chat examples/inference/llama3_lora_sft.yaml
|
||||
```
|
||||
|
||||
#### Use Web UI
|
||||
|
||||
```bash
|
||||
CUDA_VISIBLE_DEVICES=0 llamafactory-cli webchat examples/merge_lora/llama3_lora_sft.yaml
|
||||
CUDA_VISIBLE_DEVICES=0 llamafactory-cli webchat examples/inference/llama3_lora_sft.yaml
|
||||
```
|
||||
|
||||
#### Launch OpenAI-style API
|
||||
|
||||
```bash
|
||||
CUDA_VISIBLE_DEVICES=0 llamafactory-cli api examples/merge_lora/llama3_lora_sft.yaml
|
||||
CUDA_VISIBLE_DEVICES=0 llamafactory-cli api examples/inference/llama3_lora_sft.yaml
|
||||
```
|
||||
|
||||
### Extras
|
||||
|
||||
@@ -47,16 +47,16 @@ CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/lora_single_gpu/llama3_lo
|
||||
CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/lora_single_gpu/llama3_lora_ppo.yaml
|
||||
```
|
||||
|
||||
#### DPO 训练
|
||||
#### DPO/ORPO/SimPO 训练
|
||||
|
||||
```bash
|
||||
CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/lora_single_gpu/llama3_lora_dpo.yaml
|
||||
```
|
||||
|
||||
#### ORPO 训练
|
||||
#### KTO 训练
|
||||
|
||||
```bash
|
||||
CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/lora_single_gpu/llama3_lora_orpo.yaml
|
||||
CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/lora_single_gpu/llama3_lora_kto.yaml
|
||||
```
|
||||
|
||||
#### 预处理数据集
|
||||
@@ -107,50 +107,52 @@ CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/qlora_single_gpu/llama3_l
|
||||
|
||||
### 多 GPU LoRA 微调
|
||||
|
||||
#### 使用 Accelerate 进行单节点训练
|
||||
#### 在单机上进行指令监督微调
|
||||
|
||||
```bash
|
||||
bash examples/lora_multi_gpu/single_node.sh
|
||||
CUDA_VISIBLE_DEVICES=0,1,2,3 llamafactory-cli train examples/lora_multi_gpu/llama3_lora_sft.yaml
|
||||
```
|
||||
|
||||
#### 使用 Accelerate 进行多节点训练
|
||||
#### 在多机上进行指令监督微调
|
||||
|
||||
```bash
|
||||
bash examples/lora_multi_gpu/multi_node.sh
|
||||
CUDA_VISIBLE_DEVICES=0,1,2,3 NNODES=2 RANK=0 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 llamafactory-cli train examples/lora_multi_gpu/llama3_lora_sft.yaml
|
||||
CUDA_VISIBLE_DEVICES=0,1,2,3 NNODES=2 RANK=1 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 llamafactory-cli train examples/lora_multi_gpu/llama3_lora_sft.yaml
|
||||
```
|
||||
|
||||
#### 使用 DeepSpeed ZeRO-3 平均分配显存
|
||||
|
||||
```bash
|
||||
bash examples/lora_multi_gpu/ds_zero3.sh
|
||||
CUDA_VISIBLE_DEVICES=0,1,2,3 llamafactory-cli train examples/lora_multi_gpu/llama3_lora_sft_ds.yaml
|
||||
```
|
||||
|
||||
### 多 NPU LoRA 微调
|
||||
|
||||
#### 使用 DeepSpeed ZeRO-0 训练
|
||||
#### 使用 DeepSpeed ZeRO-0 进行指令监督微调
|
||||
|
||||
```bash
|
||||
bash examples/lora_multi_npu/ds_zero0.sh
|
||||
ASCEND_RT_VISIBLE_DEVICES=0,1,2,3 llamafactory-cli train examples/lora_multi_npu/llama3_lora_sft_ds.yaml
|
||||
```
|
||||
|
||||
### 多 GPU 全参数微调
|
||||
|
||||
#### 使用 DeepSpeed 进行单节点训练
|
||||
#### 在单机上进行指令监督微调
|
||||
|
||||
```bash
|
||||
bash examples/full_multi_gpu/single_node.sh
|
||||
CUDA_VISIBLE_DEVICES=0,1,2,3 llamafactory-cli train examples/full_multi_gpu/llama3_full_sft.yaml
|
||||
```
|
||||
|
||||
#### 使用 DeepSpeed 进行多节点训练
|
||||
#### 在多机上进行指令监督微调
|
||||
|
||||
```bash
|
||||
bash examples/full_multi_gpu/multi_node.sh
|
||||
CUDA_VISIBLE_DEVICES=0,1,2,3 NNODES=2 RANK=0 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 llamafactory-cli train examples/full_multi_gpu/llama3_full_sft.yaml
|
||||
CUDA_VISIBLE_DEVICES=0,1,2,3 NNODES=2 RANK=1 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 llamafactory-cli train examples/full_multi_gpu/llama3_full_sft.yaml
|
||||
```
|
||||
|
||||
#### 批量预测并计算 BLEU 和 ROUGE 分数
|
||||
|
||||
```bash
|
||||
bash examples/full_multi_gpu/predict.sh
|
||||
CUDA_VISIBLE_DEVICES=0,1,2,3 llamafactory-cli train examples/full_multi_gpu/llama3_full_predict.yaml
|
||||
```
|
||||
|
||||
### 合并 LoRA 适配器与模型量化
|
||||
@@ -171,22 +173,24 @@ CUDA_VISIBLE_DEVICES=0 llamafactory-cli export examples/merge_lora/llama3_gptq.y
|
||||
|
||||
### 推理 LoRA 模型
|
||||
|
||||
使用 `CUDA_VISIBLE_DEVICES=0,1` 进行多卡推理。
|
||||
|
||||
#### 使用命令行接口
|
||||
|
||||
```bash
|
||||
CUDA_VISIBLE_DEVICES=0 llamafactory-cli chat examples/merge_lora/llama3_lora_sft.yaml
|
||||
CUDA_VISIBLE_DEVICES=0 llamafactory-cli chat examples/inference/llama3_lora_sft.yaml
|
||||
```
|
||||
|
||||
#### 使用浏览器界面
|
||||
|
||||
```bash
|
||||
CUDA_VISIBLE_DEVICES=0 llamafactory-cli webchat examples/merge_lora/llama3_lora_sft.yaml
|
||||
CUDA_VISIBLE_DEVICES=0 llamafactory-cli webchat examples/inference/llama3_lora_sft.yaml
|
||||
```
|
||||
|
||||
#### 启动 OpenAI 风格 API
|
||||
|
||||
```bash
|
||||
CUDA_VISIBLE_DEVICES=0 llamafactory-cli api examples/merge_lora/llama3_lora_sft.yaml
|
||||
CUDA_VISIBLE_DEVICES=0 llamafactory-cli api examples/inference/llama3_lora_sft.yaml
|
||||
```
|
||||
|
||||
### 杂项
|
||||
|
||||
@@ -5,16 +5,16 @@ downcast_bf16: 'no'
|
||||
fsdp_config:
|
||||
fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
|
||||
fsdp_backward_prefetch: BACKWARD_PRE
|
||||
fsdp_cpu_ram_efficient_loading: true
|
||||
fsdp_forward_prefetch: false
|
||||
fsdp_offload_params: true
|
||||
fsdp_cpu_ram_efficient_loading: true
|
||||
fsdp_offload_params: true # offload may affect training speed
|
||||
fsdp_sharding_strategy: FULL_SHARD
|
||||
fsdp_state_dict_type: FULL_STATE_DICT
|
||||
fsdp_sync_module_states: true
|
||||
fsdp_use_orig_params: false
|
||||
fsdp_use_orig_params: true
|
||||
machine_rank: 0
|
||||
main_training_function: main
|
||||
mixed_precision: fp16
|
||||
mixed_precision: fp16 # or bf16
|
||||
num_machines: 1 # the number of nodes
|
||||
num_processes: 2 # the number of GPUs in all nodes
|
||||
rdzv_backend: static
|
||||
|
||||
@@ -1,18 +0,0 @@
|
||||
compute_environment: LOCAL_MACHINE
|
||||
debug: false
|
||||
distributed_type: MULTI_GPU
|
||||
downcast_bf16: 'no'
|
||||
gpu_ids: all
|
||||
machine_rank: 0
|
||||
main_process_ip: 192.168.0.1
|
||||
main_process_port: 29555
|
||||
main_training_function: main
|
||||
mixed_precision: fp16
|
||||
num_machines: 2 # the number of nodes
|
||||
num_processes: 8 # the number of GPUs in all nodes
|
||||
rdzv_backend: static
|
||||
same_network: true
|
||||
tpu_env: []
|
||||
tpu_use_cluster: false
|
||||
tpu_use_sudo: false
|
||||
use_cpu: false
|
||||
@@ -1,16 +0,0 @@
|
||||
compute_environment: LOCAL_MACHINE
|
||||
debug: false
|
||||
distributed_type: MULTI_GPU
|
||||
downcast_bf16: 'no'
|
||||
gpu_ids: all
|
||||
machine_rank: 0
|
||||
main_training_function: main
|
||||
mixed_precision: fp16
|
||||
num_machines: 1 # the number of nodes
|
||||
num_processes: 4 # the number of GPUs in all nodes
|
||||
rdzv_backend: static
|
||||
same_network: true
|
||||
tpu_env: []
|
||||
tpu_use_cluster: false
|
||||
tpu_use_sudo: false
|
||||
use_cpu: false
|
||||
@@ -1,18 +0,0 @@
|
||||
compute_environment: LOCAL_MACHINE
|
||||
debug: false
|
||||
distributed_type: MULTI_GPU
|
||||
downcast_bf16: 'no'
|
||||
gpu_ids: all
|
||||
machine_rank: 1
|
||||
main_process_ip: 192.168.0.1
|
||||
main_process_port: 29555
|
||||
main_training_function: main
|
||||
mixed_precision: fp16
|
||||
num_machines: 2 # the number of nodes
|
||||
num_processes: 8 # the number of GPUs in all nodes
|
||||
rdzv_backend: static
|
||||
same_network: true
|
||||
tpu_env: []
|
||||
tpu_use_cluster: false
|
||||
tpu_use_sudo: false
|
||||
use_cpu: false
|
||||
@@ -1,41 +1,41 @@
|
||||
# model
|
||||
### model
|
||||
model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct
|
||||
|
||||
# method
|
||||
### method
|
||||
stage: sft
|
||||
do_train: true
|
||||
finetuning_type: full
|
||||
use_badam: true
|
||||
badam_switch_mode: descending
|
||||
badam_switch_mode: ascending
|
||||
badam_switch_interval: 50
|
||||
badam_verbose: 2
|
||||
|
||||
# dataset
|
||||
dataset: identity,alpaca_gpt4_en
|
||||
### dataset
|
||||
dataset: identity,alpaca_en_demo
|
||||
template: llama3
|
||||
cutoff_len: 1024
|
||||
max_samples: 1000
|
||||
overwrite_cache: true
|
||||
preprocessing_num_workers: 16
|
||||
|
||||
# output
|
||||
### output
|
||||
output_dir: saves/llama3-8b/full/sft
|
||||
logging_steps: 10
|
||||
save_steps: 500
|
||||
plot_loss: true
|
||||
overwrite_output_dir: true
|
||||
|
||||
# train
|
||||
### train
|
||||
per_device_train_batch_size: 1
|
||||
gradient_accumulation_steps: 8
|
||||
learning_rate: 0.0001
|
||||
learning_rate: 1.0e-4
|
||||
num_train_epochs: 3.0
|
||||
lr_scheduler_type: cosine
|
||||
warmup_steps: 0.1
|
||||
warmup_ratio: 0.1
|
||||
pure_bf16: true
|
||||
|
||||
# eval
|
||||
### eval
|
||||
val_size: 0.1
|
||||
per_device_eval_batch_size: 1
|
||||
evaluation_strategy: steps
|
||||
eval_strategy: steps
|
||||
eval_steps: 500
|
||||
|
||||
@@ -1,42 +1,42 @@
|
||||
# model
|
||||
### model
|
||||
model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct
|
||||
quantization_bit: 4
|
||||
|
||||
# method
|
||||
### method
|
||||
stage: sft
|
||||
do_train: true
|
||||
finetuning_type: lora
|
||||
lora_target: q_proj,v_proj
|
||||
lora_target: all
|
||||
|
||||
# ddp
|
||||
### ddp
|
||||
ddp_timeout: 180000000
|
||||
|
||||
# dataset
|
||||
dataset: identity,alpaca_gpt4_en
|
||||
### dataset
|
||||
dataset: identity,alpaca_en_demo
|
||||
template: llama3
|
||||
cutoff_len: 1024
|
||||
max_samples: 1000
|
||||
overwrite_cache: true
|
||||
preprocessing_num_workers: 16
|
||||
|
||||
# output
|
||||
### output
|
||||
output_dir: saves/llama3-8b/lora/sft
|
||||
logging_steps: 10
|
||||
save_steps: 500
|
||||
plot_loss: true
|
||||
overwrite_output_dir: true
|
||||
|
||||
# train
|
||||
### train
|
||||
per_device_train_batch_size: 1
|
||||
gradient_accumulation_steps: 8
|
||||
learning_rate: 0.0001
|
||||
learning_rate: 1.0e-4
|
||||
num_train_epochs: 3.0
|
||||
lr_scheduler_type: cosine
|
||||
warmup_steps: 0.1
|
||||
warmup_ratio: 0.1
|
||||
fp16: true
|
||||
|
||||
# eval
|
||||
### eval
|
||||
val_size: 0.1
|
||||
per_device_eval_batch_size: 1
|
||||
evaluation_strategy: steps
|
||||
eval_strategy: steps
|
||||
eval_steps: 500
|
||||
|
||||
@@ -1,10 +1,6 @@
|
||||
#!/bin/bash
|
||||
# DO NOT use GPTQ/AWQ model in FSDP+QLoRA
|
||||
|
||||
pip install "transformers>=4.39.1"
|
||||
pip install "accelerate>=0.28.0"
|
||||
pip install "bitsandbytes>=0.43.0"
|
||||
|
||||
CUDA_VISIBLE_DEVICES=0,1 accelerate launch \
|
||||
--config_file examples/accelerate/fsdp_config.yaml \
|
||||
src/train.py examples/extras/fsdp_qlora/llama3_lora_sft.yaml
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
# model
|
||||
### model
|
||||
model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct
|
||||
|
||||
# method
|
||||
### method
|
||||
stage: sft
|
||||
do_train: true
|
||||
finetuning_type: full
|
||||
@@ -11,32 +11,32 @@ galore_target: mlp,self_attn
|
||||
galore_rank: 128
|
||||
galore_scale: 2.0
|
||||
|
||||
# dataset
|
||||
dataset: identity,alpaca_gpt4_en
|
||||
### dataset
|
||||
dataset: identity,alpaca_en_demo
|
||||
template: llama3
|
||||
cutoff_len: 1024
|
||||
max_samples: 1000
|
||||
overwrite_cache: true
|
||||
preprocessing_num_workers: 16
|
||||
|
||||
# output
|
||||
### output
|
||||
output_dir: saves/llama3-8b/full/sft
|
||||
logging_steps: 10
|
||||
save_steps: 500
|
||||
plot_loss: true
|
||||
overwrite_output_dir: true
|
||||
|
||||
# train
|
||||
### train
|
||||
per_device_train_batch_size: 1
|
||||
gradient_accumulation_steps: 1
|
||||
learning_rate: 0.0001
|
||||
learning_rate: 1.0e-4
|
||||
num_train_epochs: 3.0
|
||||
lr_scheduler_type: cosine
|
||||
warmup_steps: 0.1
|
||||
warmup_ratio: 0.1
|
||||
pure_bf16: true
|
||||
|
||||
# eval
|
||||
### eval
|
||||
val_size: 0.1
|
||||
per_device_eval_batch_size: 1
|
||||
evaluation_strategy: steps
|
||||
eval_strategy: steps
|
||||
eval_steps: 500
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
# model
|
||||
### model
|
||||
model_name_or_path: models/llama3-8b-instruct-pro
|
||||
|
||||
# method
|
||||
### method
|
||||
stage: sft
|
||||
do_train: true
|
||||
finetuning_type: freeze
|
||||
@@ -9,32 +9,32 @@ freeze_trainable_layers: 8
|
||||
freeze_trainable_modules: all
|
||||
use_llama_pro: true
|
||||
|
||||
# dataset
|
||||
dataset: identity,alpaca_gpt4_en
|
||||
### dataset
|
||||
dataset: identity,alpaca_en_demo
|
||||
template: llama3
|
||||
cutoff_len: 1024
|
||||
max_samples: 1000
|
||||
overwrite_cache: true
|
||||
preprocessing_num_workers: 16
|
||||
|
||||
# output
|
||||
### output
|
||||
output_dir: saves/llama3-8b-instruct-pro/freeze/sft
|
||||
logging_steps: 10
|
||||
save_steps: 500
|
||||
plot_loss: true
|
||||
overwrite_output_dir: true
|
||||
|
||||
# train
|
||||
### train
|
||||
per_device_train_batch_size: 1
|
||||
gradient_accumulation_steps: 8
|
||||
learning_rate: 0.0001
|
||||
learning_rate: 1.0e-4
|
||||
num_train_epochs: 3.0
|
||||
lr_scheduler_type: cosine
|
||||
warmup_steps: 0.1
|
||||
warmup_ratio: 0.1
|
||||
fp16: true
|
||||
|
||||
# eval
|
||||
### eval
|
||||
val_size: 0.1
|
||||
per_device_eval_batch_size: 1
|
||||
evaluation_strategy: steps
|
||||
eval_strategy: steps
|
||||
eval_steps: 500
|
||||
|
||||
@@ -1,39 +1,39 @@
|
||||
# model
|
||||
### model
|
||||
model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct
|
||||
|
||||
# method
|
||||
### method
|
||||
stage: sft
|
||||
do_train: true
|
||||
finetuning_type: lora
|
||||
lora_target: q_proj,v_proj
|
||||
lora_target: all
|
||||
loraplus_lr_ratio: 16.0
|
||||
|
||||
# dataset
|
||||
dataset: identity,alpaca_gpt4_en
|
||||
### dataset
|
||||
dataset: identity,alpaca_en_demo
|
||||
template: llama3
|
||||
cutoff_len: 1024
|
||||
max_samples: 1000
|
||||
overwrite_cache: true
|
||||
preprocessing_num_workers: 16
|
||||
|
||||
# output
|
||||
### output
|
||||
output_dir: saves/llama3-8b/lora/sft
|
||||
logging_steps: 10
|
||||
save_steps: 500
|
||||
plot_loss: true
|
||||
overwrite_output_dir: true
|
||||
|
||||
# train
|
||||
### train
|
||||
per_device_train_batch_size: 1
|
||||
gradient_accumulation_steps: 8
|
||||
learning_rate: 0.0001
|
||||
learning_rate: 1.0e-4
|
||||
num_train_epochs: 3.0
|
||||
lr_scheduler_type: cosine
|
||||
warmup_steps: 0.1
|
||||
warmup_ratio: 0.1
|
||||
fp16: true
|
||||
|
||||
# eval
|
||||
### eval
|
||||
val_size: 0.1
|
||||
per_device_eval_batch_size: 1
|
||||
evaluation_strategy: steps
|
||||
eval_strategy: steps
|
||||
eval_steps: 500
|
||||
|
||||
@@ -1,39 +1,39 @@
|
||||
# model
|
||||
### model
|
||||
model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct
|
||||
|
||||
# method
|
||||
### method
|
||||
stage: sft
|
||||
do_train: true
|
||||
finetuning_type: full
|
||||
mixture_of_depths: convert
|
||||
|
||||
# dataset
|
||||
dataset: identity,alpaca_gpt4_en
|
||||
### dataset
|
||||
dataset: identity,alpaca_en_demo
|
||||
template: llama3
|
||||
cutoff_len: 1024
|
||||
max_samples: 1000
|
||||
overwrite_cache: true
|
||||
preprocessing_num_workers: 16
|
||||
|
||||
# output
|
||||
### output
|
||||
output_dir: saves/llama3-8b-mod/full/sft
|
||||
logging_steps: 10
|
||||
save_steps: 500
|
||||
plot_loss: true
|
||||
overwrite_output_dir: true
|
||||
|
||||
# train
|
||||
### train
|
||||
per_device_train_batch_size: 1
|
||||
gradient_accumulation_steps: 8
|
||||
optim: paged_adamw_8bit
|
||||
learning_rate: 0.0001
|
||||
learning_rate: 1.0e-4
|
||||
num_train_epochs: 3.0
|
||||
lr_scheduler_type: cosine
|
||||
warmup_steps: 0.1
|
||||
warmup_ratio: 0.1
|
||||
pure_bf16: true
|
||||
|
||||
# eval
|
||||
### eval
|
||||
val_size: 0.1
|
||||
per_device_eval_batch_size: 1
|
||||
evaluation_strategy: steps
|
||||
eval_strategy: steps
|
||||
eval_steps: 500
|
||||
|
||||
@@ -1,23 +1,23 @@
|
||||
# model
|
||||
### model
|
||||
model_name_or_path: saves/llama3-8b/full/sft
|
||||
|
||||
# method
|
||||
### method
|
||||
stage: sft
|
||||
do_predict: true
|
||||
finetuning_type: full
|
||||
|
||||
# dataset
|
||||
dataset: identity,alpaca_gpt4_en
|
||||
### dataset
|
||||
dataset: identity,alpaca_en_demo
|
||||
template: llama3
|
||||
cutoff_len: 1024
|
||||
max_samples: 50
|
||||
overwrite_cache: true
|
||||
preprocessing_num_workers: 16
|
||||
|
||||
# output
|
||||
### output
|
||||
output_dir: saves/llama3-8b/full/predict
|
||||
overwrite_output_dir: true
|
||||
|
||||
# eval
|
||||
### eval
|
||||
per_device_eval_batch_size: 1
|
||||
predict_with_generate: true
|
||||
|
||||
@@ -1,41 +1,41 @@
|
||||
# model
|
||||
### model
|
||||
model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct
|
||||
|
||||
# method
|
||||
### method
|
||||
stage: sft
|
||||
do_train: true
|
||||
finetuning_type: full
|
||||
|
||||
# ddp
|
||||
### ddp
|
||||
ddp_timeout: 180000000
|
||||
deepspeed: examples/deepspeed/ds_z3_config.json
|
||||
|
||||
# dataset
|
||||
dataset: identity,alpaca_gpt4_en
|
||||
### dataset
|
||||
dataset: identity,alpaca_en_demo
|
||||
template: llama3
|
||||
cutoff_len: 1024
|
||||
max_samples: 1000
|
||||
overwrite_cache: true
|
||||
preprocessing_num_workers: 16
|
||||
|
||||
# output
|
||||
### output
|
||||
output_dir: saves/llama3-8b/full/sft
|
||||
logging_steps: 10
|
||||
save_steps: 500
|
||||
plot_loss: true
|
||||
overwrite_output_dir: true
|
||||
|
||||
# train
|
||||
### train
|
||||
per_device_train_batch_size: 1
|
||||
gradient_accumulation_steps: 2
|
||||
learning_rate: 0.0001
|
||||
learning_rate: 1.0e-4
|
||||
num_train_epochs: 3.0
|
||||
lr_scheduler_type: cosine
|
||||
warmup_steps: 0.1
|
||||
warmup_ratio: 0.1
|
||||
fp16: true
|
||||
|
||||
# eval
|
||||
### eval
|
||||
val_size: 0.1
|
||||
per_device_eval_batch_size: 1
|
||||
evaluation_strategy: steps
|
||||
eval_strategy: steps
|
||||
eval_steps: 500
|
||||
|
||||
@@ -1,15 +0,0 @@
|
||||
#!/bin/bash
|
||||
|
||||
NPROC_PER_NODE=4
|
||||
NNODES=2
|
||||
RANK=0
|
||||
MASTER_ADDR=192.168.0.1
|
||||
MASTER_PORT=29500
|
||||
|
||||
CUDA_VISIBLE_DEVICES=0,1,2,3 torchrun \
|
||||
--nproc_per_node $NPROC_PER_NODE \
|
||||
--nnodes $NNODES \
|
||||
--node_rank $RANK \
|
||||
--master_addr $MASTER_ADDR \
|
||||
--master_port $MASTER_PORT \
|
||||
src/train.py examples/full_multi_gpu/llama3_full_sft.yaml
|
||||
@@ -1,5 +0,0 @@
|
||||
#!/bin/bash
|
||||
|
||||
CUDA_VISIBLE_DEVICES=0,1,2,3 accelerate launch \
|
||||
--config_file examples/accelerate/single_config.yaml \
|
||||
src/train.py examples/full_multi_gpu/llama3_full_predict.yaml
|
||||
@@ -1,15 +0,0 @@
|
||||
#!/bin/bash
|
||||
|
||||
NPROC_PER_NODE=4
|
||||
NNODES=1
|
||||
RANK=0
|
||||
MASTER_ADDR=127.0.0.1
|
||||
MASTER_PORT=29500
|
||||
|
||||
CUDA_VISIBLE_DEVICES=0,1,2,3 torchrun \
|
||||
--nproc_per_node $NPROC_PER_NODE \
|
||||
--nnodes $NNODES \
|
||||
--node_rank $RANK \
|
||||
--master_addr $MASTER_ADDR \
|
||||
--master_port $MASTER_PORT \
|
||||
src/train.py examples/full_multi_gpu/llama3_full_sft.yaml
|
||||
@@ -1,15 +0,0 @@
|
||||
#!/bin/bash
|
||||
|
||||
NPROC_PER_NODE=4
|
||||
NNODES=1
|
||||
RANK=0
|
||||
MASTER_ADDR=127.0.0.1
|
||||
MASTER_PORT=29500
|
||||
|
||||
CUDA_VISIBLE_DEVICES=0,1,2,3 torchrun \
|
||||
--nproc_per_node $NPROC_PER_NODE \
|
||||
--nnodes $NNODES \
|
||||
--node_rank $RANK \
|
||||
--master_addr $MASTER_ADDR \
|
||||
--master_port $MASTER_PORT \
|
||||
src/train.py examples/lora_multi_gpu/llama3_lora_sft_ds.yaml
|
||||
@@ -1,41 +1,41 @@
|
||||
# model
|
||||
### model
|
||||
model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct
|
||||
|
||||
# method
|
||||
### method
|
||||
stage: sft
|
||||
do_train: true
|
||||
finetuning_type: lora
|
||||
lora_target: q_proj,v_proj
|
||||
lora_target: all
|
||||
|
||||
# ddp
|
||||
### ddp
|
||||
ddp_timeout: 180000000
|
||||
|
||||
# dataset
|
||||
dataset: identity,alpaca_gpt4_en
|
||||
### dataset
|
||||
dataset: identity,alpaca_en_demo
|
||||
template: llama3
|
||||
cutoff_len: 1024
|
||||
max_samples: 1000
|
||||
overwrite_cache: true
|
||||
preprocessing_num_workers: 16
|
||||
|
||||
# output
|
||||
### output
|
||||
output_dir: saves/llama3-8b/lora/sft
|
||||
logging_steps: 10
|
||||
save_steps: 500
|
||||
plot_loss: true
|
||||
overwrite_output_dir: true
|
||||
|
||||
# train
|
||||
### train
|
||||
per_device_train_batch_size: 1
|
||||
gradient_accumulation_steps: 2
|
||||
learning_rate: 0.0001
|
||||
learning_rate: 1.0e-4
|
||||
num_train_epochs: 3.0
|
||||
lr_scheduler_type: cosine
|
||||
warmup_steps: 0.1
|
||||
warmup_ratio: 0.1
|
||||
fp16: true
|
||||
|
||||
# eval
|
||||
### eval
|
||||
val_size: 0.1
|
||||
per_device_eval_batch_size: 1
|
||||
evaluation_strategy: steps
|
||||
eval_strategy: steps
|
||||
eval_steps: 500
|
||||
|
||||
@@ -1,42 +1,42 @@
|
||||
# model
|
||||
### model
|
||||
model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct
|
||||
|
||||
# method
|
||||
### method
|
||||
stage: sft
|
||||
do_train: true
|
||||
finetuning_type: lora
|
||||
lora_target: q_proj,v_proj
|
||||
lora_target: all
|
||||
|
||||
# ddp
|
||||
### ddp
|
||||
ddp_timeout: 180000000
|
||||
deepspeed: examples/deepspeed/ds_z3_config.json
|
||||
|
||||
# dataset
|
||||
dataset: identity,alpaca_gpt4_en
|
||||
### dataset
|
||||
dataset: identity,alpaca_en_demo
|
||||
template: llama3
|
||||
cutoff_len: 1024
|
||||
max_samples: 1000
|
||||
overwrite_cache: true
|
||||
preprocessing_num_workers: 16
|
||||
|
||||
# output
|
||||
### output
|
||||
output_dir: saves/llama3-8b/lora/sft
|
||||
logging_steps: 10
|
||||
save_steps: 500
|
||||
plot_loss: true
|
||||
overwrite_output_dir: true
|
||||
|
||||
# train
|
||||
### train
|
||||
per_device_train_batch_size: 1
|
||||
gradient_accumulation_steps: 2
|
||||
learning_rate: 0.0001
|
||||
learning_rate: 1.0e-4
|
||||
num_train_epochs: 3.0
|
||||
lr_scheduler_type: cosine
|
||||
warmup_steps: 0.1
|
||||
warmup_ratio: 0.1
|
||||
fp16: true
|
||||
|
||||
# eval
|
||||
### eval
|
||||
val_size: 0.1
|
||||
per_device_eval_batch_size: 1
|
||||
evaluation_strategy: steps
|
||||
eval_strategy: steps
|
||||
eval_steps: 500
|
||||
|
||||
@@ -1,6 +0,0 @@
|
||||
#!/bin/bash
|
||||
# also launch it on slave machine using slave_config.yaml
|
||||
|
||||
CUDA_VISIBLE_DEVICES=0,1,2,3 accelerate launch \
|
||||
--config_file examples/accelerate/master_config.yaml \
|
||||
src/train.py examples/lora_multi_gpu/llama3_lora_sft.yaml
|
||||
@@ -1,5 +0,0 @@
|
||||
#!/bin/bash
|
||||
|
||||
CUDA_VISIBLE_DEVICES=0,1,2,3 accelerate launch \
|
||||
--config_file examples/accelerate/single_config.yaml \
|
||||
src/train.py examples/lora_multi_gpu/llama3_lora_sft.yaml
|
||||
@@ -1,15 +0,0 @@
|
||||
#!/bin/bash
|
||||
|
||||
NPROC_PER_NODE=4
|
||||
NNODES=1
|
||||
RANK=0
|
||||
MASTER_ADDR=127.0.0.1
|
||||
MASTER_PORT=29500
|
||||
|
||||
ASCEND_RT_VISIBLE_DEVICES=0,1,2,3 torchrun \
|
||||
--nproc_per_node $NPROC_PER_NODE \
|
||||
--nnodes $NNODES \
|
||||
--node_rank $RANK \
|
||||
--master_addr $MASTER_ADDR \
|
||||
--master_port $MASTER_PORT \
|
||||
src/train.py examples/lora_multi_npu/llama3_lora_sft_ds.yaml
|
||||
@@ -1,42 +1,42 @@
|
||||
# model
|
||||
### model
|
||||
model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct
|
||||
|
||||
# method
|
||||
### method
|
||||
stage: sft
|
||||
do_train: true
|
||||
finetuning_type: lora
|
||||
lora_target: q_proj,v_proj
|
||||
lora_target: all
|
||||
|
||||
# ddp
|
||||
### ddp
|
||||
ddp_timeout: 180000000
|
||||
deepspeed: examples/deepspeed/ds_z0_config.json
|
||||
|
||||
# dataset
|
||||
dataset: identity,alpaca_gpt4_en
|
||||
### dataset
|
||||
dataset: identity,alpaca_en_demo
|
||||
template: llama3
|
||||
cutoff_len: 1024
|
||||
max_samples: 1000
|
||||
overwrite_cache: true
|
||||
preprocessing_num_workers: 16
|
||||
|
||||
# output
|
||||
### output
|
||||
output_dir: saves/llama3-8b/lora/sft
|
||||
logging_steps: 10
|
||||
save_steps: 500
|
||||
plot_loss: true
|
||||
overwrite_output_dir: true
|
||||
|
||||
# train
|
||||
### train
|
||||
per_device_train_batch_size: 1
|
||||
gradient_accumulation_steps: 2
|
||||
learning_rate: 0.0001
|
||||
learning_rate: 1.0e-4
|
||||
num_train_epochs: 3.0
|
||||
lr_scheduler_type: cosine
|
||||
warmup_steps: 0.1
|
||||
warmup_ratio: 0.1
|
||||
fp16: true
|
||||
|
||||
# eval
|
||||
### eval
|
||||
val_size: 0.1
|
||||
per_device_eval_batch_size: 1
|
||||
evaluation_strategy: steps
|
||||
eval_strategy: steps
|
||||
eval_steps: 500
|
||||
|
||||
@@ -1,39 +1,40 @@
|
||||
# model
|
||||
### model
|
||||
model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct
|
||||
|
||||
# method
|
||||
### method
|
||||
stage: dpo
|
||||
do_train: true
|
||||
finetuning_type: lora
|
||||
lora_target: q_proj,v_proj
|
||||
dpo_ftx: 1.0
|
||||
lora_target: all
|
||||
pref_beta: 0.1
|
||||
pref_loss: sigmoid # [sigmoid (dpo), orpo, simpo]
|
||||
|
||||
# dataset
|
||||
dataset: orca_rlhf
|
||||
### dataset
|
||||
dataset: dpo_en_demo
|
||||
template: llama3
|
||||
cutoff_len: 1024
|
||||
max_samples: 1000
|
||||
overwrite_cache: true
|
||||
preprocessing_num_workers: 16
|
||||
|
||||
# output
|
||||
### output
|
||||
output_dir: saves/llama3-8b/lora/dpo
|
||||
logging_steps: 10
|
||||
save_steps: 500
|
||||
plot_loss: true
|
||||
overwrite_output_dir: true
|
||||
|
||||
# train
|
||||
### train
|
||||
per_device_train_batch_size: 1
|
||||
gradient_accumulation_steps: 8
|
||||
learning_rate: 0.00001
|
||||
learning_rate: 5.0e-6
|
||||
num_train_epochs: 3.0
|
||||
lr_scheduler_type: cosine
|
||||
warmup_steps: 0.1
|
||||
warmup_ratio: 0.1
|
||||
fp16: true
|
||||
|
||||
# eval
|
||||
### eval
|
||||
val_size: 0.1
|
||||
per_device_eval_batch_size: 1
|
||||
evaluation_strategy: steps
|
||||
eval_strategy: steps
|
||||
eval_steps: 500
|
||||
|
||||
@@ -1,19 +1,19 @@
|
||||
# model
|
||||
### model
|
||||
model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct
|
||||
adapter_name_or_path: saves/llama3-8b/lora/sft
|
||||
|
||||
# method
|
||||
### method
|
||||
finetuning_type: lora
|
||||
|
||||
# dataset
|
||||
### dataset
|
||||
task: mmlu
|
||||
split: test
|
||||
template: fewshot
|
||||
lang: en
|
||||
n_shot: 5
|
||||
|
||||
# output
|
||||
### output
|
||||
save_dir: saves/llama3-8b/lora/eval
|
||||
|
||||
# eval
|
||||
### eval
|
||||
batch_size: 4
|
||||
|
||||
@@ -1,38 +1,38 @@
|
||||
# model
|
||||
### model
|
||||
model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct
|
||||
|
||||
# method
|
||||
stage: orpo
|
||||
### method
|
||||
stage: kto
|
||||
do_train: true
|
||||
finetuning_type: lora
|
||||
lora_target: q_proj,v_proj
|
||||
lora_target: all
|
||||
|
||||
# dataset
|
||||
dataset: orca_rlhf
|
||||
### dataset
|
||||
dataset: kto_en_demo
|
||||
template: llama3
|
||||
cutoff_len: 1024
|
||||
max_samples: 1000
|
||||
overwrite_cache: true
|
||||
preprocessing_num_workers: 16
|
||||
|
||||
# output
|
||||
output_dir: saves/llama3-8b/lora/orpo
|
||||
### output
|
||||
output_dir: saves/llama3-8b/lora/kto
|
||||
logging_steps: 10
|
||||
save_steps: 500
|
||||
plot_loss: true
|
||||
overwrite_output_dir: true
|
||||
|
||||
# train
|
||||
### train
|
||||
per_device_train_batch_size: 1
|
||||
gradient_accumulation_steps: 8
|
||||
learning_rate: 0.00001
|
||||
learning_rate: 5.0e-6
|
||||
num_train_epochs: 3.0
|
||||
lr_scheduler_type: cosine
|
||||
warmup_steps: 0.1
|
||||
warmup_ratio: 0.1
|
||||
fp16: true
|
||||
|
||||
# eval
|
||||
### eval
|
||||
val_size: 0.1
|
||||
per_device_eval_batch_size: 1
|
||||
evaluation_strategy: steps
|
||||
eval_strategy: steps
|
||||
eval_steps: 500
|
||||
@@ -1,38 +1,38 @@
|
||||
# model
|
||||
### model
|
||||
model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct
|
||||
reward_model: saves/llama3-8b/lora/reward
|
||||
|
||||
# method
|
||||
### method
|
||||
stage: ppo
|
||||
do_train: true
|
||||
finetuning_type: lora
|
||||
lora_target: q_proj,v_proj
|
||||
lora_target: all
|
||||
|
||||
# dataset
|
||||
dataset: identity,alpaca_gpt4_en
|
||||
### dataset
|
||||
dataset: identity,alpaca_en_demo
|
||||
template: llama3
|
||||
cutoff_len: 1024
|
||||
max_samples: 1000
|
||||
overwrite_cache: true
|
||||
preprocessing_num_workers: 16
|
||||
|
||||
# output
|
||||
### output
|
||||
output_dir: saves/llama3-8b/lora/ppo
|
||||
logging_steps: 10
|
||||
save_steps: 500
|
||||
plot_loss: true
|
||||
overwrite_output_dir: true
|
||||
|
||||
# train
|
||||
### train
|
||||
per_device_train_batch_size: 1
|
||||
gradient_accumulation_steps: 8
|
||||
learning_rate: 0.00001
|
||||
learning_rate: 1.0e-5
|
||||
num_train_epochs: 3.0
|
||||
lr_scheduler_type: cosine
|
||||
warmup_steps: 0.1
|
||||
warmup_ratio: 0.1
|
||||
fp16: true
|
||||
|
||||
# generate
|
||||
### generate
|
||||
max_new_tokens: 512
|
||||
top_k: 0
|
||||
top_p: 0.9
|
||||
|
||||
@@ -1,24 +1,24 @@
|
||||
# model
|
||||
### model
|
||||
model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct
|
||||
adapter_name_or_path: saves/llama3-8b/lora/sft
|
||||
|
||||
# method
|
||||
### method
|
||||
stage: sft
|
||||
do_predict: true
|
||||
finetuning_type: lora
|
||||
|
||||
# dataset
|
||||
dataset: identity,alpaca_gpt4_en
|
||||
### dataset
|
||||
dataset: identity,alpaca_en_demo
|
||||
template: llama3
|
||||
cutoff_len: 1024
|
||||
max_samples: 50
|
||||
overwrite_cache: true
|
||||
preprocessing_num_workers: 16
|
||||
|
||||
# output
|
||||
### output
|
||||
output_dir: saves/llama3-8b/lora/predict
|
||||
overwrite_output_dir: true
|
||||
|
||||
# eval
|
||||
### eval
|
||||
per_device_eval_batch_size: 1
|
||||
predict_with_generate: true
|
||||
|
||||
@@ -1,37 +1,37 @@
|
||||
# model
|
||||
### model
|
||||
model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct
|
||||
|
||||
# method
|
||||
### method
|
||||
stage: pt
|
||||
do_train: true
|
||||
finetuning_type: lora
|
||||
lora_target: q_proj,v_proj
|
||||
lora_target: all
|
||||
|
||||
# dataset
|
||||
### dataset
|
||||
dataset: c4_demo
|
||||
cutoff_len: 1024
|
||||
max_samples: 1000
|
||||
overwrite_cache: true
|
||||
preprocessing_num_workers: 16
|
||||
|
||||
# output
|
||||
### output
|
||||
output_dir: saves/llama3-8b/lora/sft
|
||||
logging_steps: 10
|
||||
save_steps: 500
|
||||
plot_loss: true
|
||||
overwrite_output_dir: true
|
||||
|
||||
# train
|
||||
### train
|
||||
per_device_train_batch_size: 1
|
||||
gradient_accumulation_steps: 8
|
||||
learning_rate: 0.0001
|
||||
learning_rate: 1.0e-4
|
||||
num_train_epochs: 3.0
|
||||
lr_scheduler_type: cosine
|
||||
warmup_steps: 0.1
|
||||
warmup_ratio: 0.1
|
||||
fp16: true
|
||||
|
||||
# eval
|
||||
### eval
|
||||
val_size: 0.1
|
||||
per_device_eval_batch_size: 1
|
||||
evaluation_strategy: steps
|
||||
eval_strategy: steps
|
||||
eval_steps: 500
|
||||
|
||||
@@ -1,38 +1,38 @@
|
||||
# model
|
||||
### model
|
||||
model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct
|
||||
|
||||
# method
|
||||
### method
|
||||
stage: rm
|
||||
do_train: true
|
||||
finetuning_type: lora
|
||||
lora_target: q_proj,v_proj
|
||||
lora_target: all
|
||||
|
||||
# dataset
|
||||
dataset: orca_rlhf
|
||||
### dataset
|
||||
dataset: dpo_en_demo
|
||||
template: llama3
|
||||
cutoff_len: 1024
|
||||
max_samples: 1000
|
||||
overwrite_cache: true
|
||||
preprocessing_num_workers: 16
|
||||
|
||||
# output
|
||||
### output
|
||||
output_dir: saves/llama3-8b/lora/reward
|
||||
logging_steps: 10
|
||||
save_steps: 500
|
||||
plot_loss: true
|
||||
overwrite_output_dir: true
|
||||
|
||||
# train
|
||||
### train
|
||||
per_device_train_batch_size: 1
|
||||
gradient_accumulation_steps: 8
|
||||
learning_rate: 0.00001
|
||||
learning_rate: 1.0e-5
|
||||
num_train_epochs: 3.0
|
||||
lr_scheduler_type: cosine
|
||||
warmup_steps: 0.1
|
||||
warmup_ratio: 0.1
|
||||
fp16: true
|
||||
|
||||
# eval
|
||||
### eval
|
||||
val_size: 0.1
|
||||
per_device_eval_batch_size: 1
|
||||
evaluation_strategy: steps
|
||||
eval_strategy: steps
|
||||
eval_steps: 500
|
||||
|
||||
@@ -1,38 +1,38 @@
|
||||
# model
|
||||
### model
|
||||
model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct
|
||||
|
||||
# method
|
||||
### method
|
||||
stage: sft
|
||||
do_train: true
|
||||
finetuning_type: lora
|
||||
lora_target: q_proj,v_proj
|
||||
lora_target: all
|
||||
|
||||
# dataset
|
||||
dataset: identity,alpaca_gpt4_en
|
||||
### dataset
|
||||
dataset: identity,alpaca_en_demo
|
||||
template: llama3
|
||||
cutoff_len: 1024
|
||||
max_samples: 1000
|
||||
overwrite_cache: true
|
||||
preprocessing_num_workers: 16
|
||||
|
||||
# output
|
||||
### output
|
||||
output_dir: saves/llama3-8b/lora/sft
|
||||
logging_steps: 10
|
||||
save_steps: 500
|
||||
plot_loss: true
|
||||
overwrite_output_dir: true
|
||||
|
||||
# train
|
||||
### train
|
||||
per_device_train_batch_size: 1
|
||||
gradient_accumulation_steps: 8
|
||||
learning_rate: 0.0001
|
||||
learning_rate: 1.0e-4
|
||||
num_train_epochs: 3.0
|
||||
lr_scheduler_type: cosine
|
||||
warmup_steps: 0.1
|
||||
warmup_ratio: 0.1
|
||||
fp16: true
|
||||
|
||||
# eval
|
||||
### eval
|
||||
val_size: 0.1
|
||||
per_device_eval_batch_size: 1
|
||||
evaluation_strategy: steps
|
||||
eval_strategy: steps
|
||||
eval_steps: 500
|
||||
|
||||
@@ -1,14 +1,14 @@
|
||||
# model
|
||||
### model
|
||||
model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct
|
||||
|
||||
# method
|
||||
### method
|
||||
stage: sft
|
||||
do_train: true
|
||||
finetuning_type: lora
|
||||
lora_target: q_proj,v_proj
|
||||
lora_target: all
|
||||
|
||||
# dataset
|
||||
dataset: identity,alpaca_gpt4_en
|
||||
### dataset
|
||||
dataset: identity,alpaca_en_demo
|
||||
template: llama3
|
||||
cutoff_len: 1024
|
||||
max_samples: 1000
|
||||
@@ -16,6 +16,6 @@ overwrite_cache: true
|
||||
preprocessing_num_workers: 16
|
||||
tokenized_path: saves/llama3-8b/dataset/sft
|
||||
|
||||
# output
|
||||
### output
|
||||
output_dir: saves/llama3-8b/lora/sft
|
||||
overwrite_output_dir: true
|
||||
|
||||
@@ -1,14 +1,14 @@
|
||||
# model
|
||||
### model
|
||||
model_name_or_path: llava-hf/llava-1.5-7b-hf
|
||||
visual_inputs: true
|
||||
|
||||
# method
|
||||
### method
|
||||
stage: sft
|
||||
do_train: true
|
||||
finetuning_type: lora
|
||||
lora_target: q_proj,v_proj
|
||||
lora_target: all
|
||||
|
||||
# dataset
|
||||
### dataset
|
||||
dataset: mllm_demo
|
||||
template: vicuna
|
||||
cutoff_len: 1024
|
||||
@@ -16,24 +16,24 @@ max_samples: 1000
|
||||
overwrite_cache: true
|
||||
preprocessing_num_workers: 16
|
||||
|
||||
# output
|
||||
### output
|
||||
output_dir: saves/llava1_5-7b/lora/sft
|
||||
logging_steps: 10
|
||||
save_steps: 500
|
||||
plot_loss: true
|
||||
overwrite_output_dir: true
|
||||
|
||||
# train
|
||||
### train
|
||||
per_device_train_batch_size: 1
|
||||
gradient_accumulation_steps: 8
|
||||
learning_rate: 0.0001
|
||||
learning_rate: 1.0e-4
|
||||
num_train_epochs: 3.0
|
||||
lr_scheduler_type: cosine
|
||||
warmup_steps: 0.1
|
||||
warmup_ratio: 0.1
|
||||
fp16: true
|
||||
|
||||
# eval
|
||||
### eval
|
||||
val_size: 0.1
|
||||
per_device_eval_batch_size: 1
|
||||
evaluation_strategy: steps
|
||||
eval_strategy: steps
|
||||
eval_steps: 500
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
# model
|
||||
### model
|
||||
model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct
|
||||
template: llama3
|
||||
|
||||
# export
|
||||
### export
|
||||
export_dir: models/llama3_gptq
|
||||
export_quantization_bit: 4
|
||||
export_quantization_dataset: data/c4_demo.json
|
||||
|
||||
@@ -1,12 +1,12 @@
|
||||
# Note: DO NOT use quantized model or quantization_bit when merging lora adapters
|
||||
### Note: DO NOT use quantized model or quantization_bit when merging lora adapters
|
||||
|
||||
# model
|
||||
### model
|
||||
model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct
|
||||
adapter_name_or_path: saves/llama3-8b/lora/sft
|
||||
template: llama3
|
||||
finetuning_type: lora
|
||||
|
||||
# export
|
||||
### export
|
||||
export_dir: models/llama3_lora_sft
|
||||
export_size: 2
|
||||
export_device: cpu
|
||||
|
||||
@@ -1,38 +1,38 @@
|
||||
# model
|
||||
### model
|
||||
model_name_or_path: ISTA-DASLab/Meta-Llama-3-8B-Instruct-AQLM-2Bit-1x16
|
||||
|
||||
# method
|
||||
### method
|
||||
stage: sft
|
||||
do_train: true
|
||||
finetuning_type: lora
|
||||
lora_target: q_proj,v_proj
|
||||
lora_target: all
|
||||
|
||||
# dataset
|
||||
dataset: identity,alpaca_gpt4_en
|
||||
### dataset
|
||||
dataset: identity,alpaca_en_demo
|
||||
template: llama3
|
||||
cutoff_len: 1024
|
||||
max_samples: 1000
|
||||
overwrite_cache: true
|
||||
preprocessing_num_workers: 16
|
||||
|
||||
# output
|
||||
### output
|
||||
output_dir: saves/llama3-8b/lora/sft
|
||||
logging_steps: 10
|
||||
save_steps: 500
|
||||
plot_loss: true
|
||||
overwrite_output_dir: true
|
||||
|
||||
# train
|
||||
### train
|
||||
per_device_train_batch_size: 1
|
||||
gradient_accumulation_steps: 8
|
||||
learning_rate: 0.0001
|
||||
learning_rate: 1.0e-4
|
||||
num_train_epochs: 3.0
|
||||
lr_scheduler_type: cosine
|
||||
warmup_steps: 0.1
|
||||
warmup_ratio: 0.1
|
||||
fp16: true
|
||||
|
||||
# eval
|
||||
### eval
|
||||
val_size: 0.1
|
||||
per_device_eval_batch_size: 1
|
||||
evaluation_strategy: steps
|
||||
eval_strategy: steps
|
||||
eval_steps: 500
|
||||
|
||||
@@ -1,38 +1,38 @@
|
||||
# model
|
||||
### model
|
||||
model_name_or_path: TechxGenus/Meta-Llama-3-8B-Instruct-AWQ
|
||||
|
||||
# method
|
||||
### method
|
||||
stage: sft
|
||||
do_train: true
|
||||
finetuning_type: lora
|
||||
lora_target: q_proj,v_proj
|
||||
lora_target: all
|
||||
|
||||
# dataset
|
||||
dataset: identity,alpaca_gpt4_en
|
||||
### dataset
|
||||
dataset: identity,alpaca_en_demo
|
||||
template: llama3
|
||||
cutoff_len: 1024
|
||||
max_samples: 1000
|
||||
overwrite_cache: true
|
||||
preprocessing_num_workers: 16
|
||||
|
||||
# output
|
||||
### output
|
||||
output_dir: saves/llama3-8b/lora/sft
|
||||
logging_steps: 10
|
||||
save_steps: 500
|
||||
plot_loss: true
|
||||
overwrite_output_dir: true
|
||||
|
||||
# train
|
||||
### train
|
||||
per_device_train_batch_size: 1
|
||||
gradient_accumulation_steps: 8
|
||||
learning_rate: 0.0001
|
||||
learning_rate: 1.0e-4
|
||||
num_train_epochs: 3.0
|
||||
lr_scheduler_type: cosine
|
||||
warmup_steps: 0.1
|
||||
warmup_ratio: 0.1
|
||||
fp16: true
|
||||
|
||||
# eval
|
||||
### eval
|
||||
val_size: 0.1
|
||||
per_device_eval_batch_size: 1
|
||||
evaluation_strategy: steps
|
||||
eval_strategy: steps
|
||||
eval_steps: 500
|
||||
|
||||
@@ -1,39 +1,39 @@
|
||||
# model
|
||||
### model
|
||||
model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct
|
||||
quantization_bit: 4
|
||||
|
||||
# method
|
||||
### method
|
||||
stage: sft
|
||||
do_train: true
|
||||
finetuning_type: lora
|
||||
lora_target: q_proj,v_proj
|
||||
lora_target: all
|
||||
|
||||
# dataset
|
||||
dataset: identity,alpaca_gpt4_en
|
||||
### dataset
|
||||
dataset: identity,alpaca_en_demo
|
||||
template: llama3
|
||||
cutoff_len: 1024
|
||||
max_samples: 1000
|
||||
overwrite_cache: true
|
||||
preprocessing_num_workers: 16
|
||||
|
||||
# output
|
||||
### output
|
||||
output_dir: saves/llama3-8b/lora/sft
|
||||
logging_steps: 10
|
||||
save_steps: 500
|
||||
plot_loss: true
|
||||
overwrite_output_dir: true
|
||||
|
||||
# train
|
||||
### train
|
||||
per_device_train_batch_size: 1
|
||||
gradient_accumulation_steps: 8
|
||||
learning_rate: 0.0001
|
||||
learning_rate: 1.0e-4
|
||||
num_train_epochs: 3.0
|
||||
lr_scheduler_type: cosine
|
||||
warmup_steps: 0.1
|
||||
warmup_ratio: 0.1
|
||||
fp16: true
|
||||
|
||||
# eval
|
||||
### eval
|
||||
val_size: 0.1
|
||||
per_device_eval_batch_size: 1
|
||||
evaluation_strategy: steps
|
||||
eval_strategy: steps
|
||||
eval_steps: 500
|
||||
|
||||
@@ -1,38 +1,38 @@
|
||||
# model
|
||||
### model
|
||||
model_name_or_path: TechxGenus/Meta-Llama-3-8B-Instruct-GPTQ
|
||||
|
||||
# method
|
||||
### method
|
||||
stage: sft
|
||||
do_train: true
|
||||
finetuning_type: lora
|
||||
lora_target: q_proj,v_proj
|
||||
lora_target: all
|
||||
|
||||
# dataset
|
||||
dataset: identity,alpaca_gpt4_en
|
||||
### dataset
|
||||
dataset: identity,alpaca_en_demo
|
||||
template: llama3
|
||||
cutoff_len: 1024
|
||||
max_samples: 1000
|
||||
overwrite_cache: true
|
||||
preprocessing_num_workers: 16
|
||||
|
||||
# output
|
||||
### output
|
||||
output_dir: saves/llama3-8b/lora/sft
|
||||
logging_steps: 10
|
||||
save_steps: 500
|
||||
plot_loss: true
|
||||
overwrite_output_dir: true
|
||||
|
||||
# train
|
||||
### train
|
||||
per_device_train_batch_size: 1
|
||||
gradient_accumulation_steps: 8
|
||||
learning_rate: 0.0001
|
||||
learning_rate: 1.0e-4
|
||||
num_train_epochs: 3.0
|
||||
lr_scheduler_type: cosine
|
||||
warmup_steps: 0.1
|
||||
warmup_ratio: 0.1
|
||||
fp16: true
|
||||
|
||||
# eval
|
||||
### eval
|
||||
val_size: 0.1
|
||||
per_device_eval_batch_size: 1
|
||||
evaluation_strategy: steps
|
||||
eval_strategy: steps
|
||||
eval_steps: 500
|
||||
|
||||
@@ -13,7 +13,7 @@ select = ["C", "E", "F", "I", "W"]
|
||||
|
||||
[tool.ruff.lint.isort]
|
||||
lines-after-imports = 2
|
||||
known-first-party = ["llmtuner"]
|
||||
known-first-party = ["llamafactory"]
|
||||
known-third-party = [
|
||||
"accelerate",
|
||||
"datasets",
|
||||
|
||||
@@ -1,12 +1,13 @@
|
||||
transformers>=4.37.2
|
||||
datasets>=2.14.3
|
||||
accelerate>=0.27.2
|
||||
peft>=0.10.0
|
||||
trl>=0.8.1
|
||||
transformers>=4.41.2
|
||||
datasets>=2.16.0
|
||||
accelerate>=0.30.1
|
||||
peft>=0.11.1
|
||||
trl>=0.8.6
|
||||
gradio>=4.0.0
|
||||
scipy
|
||||
einops
|
||||
sentencepiece
|
||||
tiktoken
|
||||
protobuf
|
||||
uvicorn
|
||||
pydantic
|
||||
|
||||
@@ -8,7 +8,7 @@ import torch
|
||||
from deepspeed.accelerator import get_accelerator # type: ignore
|
||||
from deepspeed.profiling.flops_profiler import get_model_profile # type: ignore
|
||||
|
||||
from llmtuner.chat import ChatModel
|
||||
from llamafactory.chat import ChatModel
|
||||
|
||||
|
||||
def calculate_flops(
|
||||
|
||||
@@ -12,10 +12,10 @@ from torch.utils.data import DataLoader
|
||||
from tqdm import tqdm
|
||||
from transformers import DataCollatorForLanguageModeling, DataCollatorForSeq2Seq
|
||||
|
||||
from llmtuner.data import get_dataset
|
||||
from llmtuner.extras.constants import IGNORE_INDEX
|
||||
from llmtuner.hparams import get_train_args
|
||||
from llmtuner.model import load_tokenizer
|
||||
from llamafactory.data import get_dataset
|
||||
from llamafactory.extras.constants import IGNORE_INDEX
|
||||
from llamafactory.hparams import get_train_args
|
||||
from llamafactory.model import load_tokenizer
|
||||
|
||||
|
||||
BASE_LR = 3e-4 # 1.5e-4 for 30B-70B models
|
||||
|
||||
@@ -12,10 +12,10 @@ from torch.utils.data import DataLoader
|
||||
from tqdm import tqdm
|
||||
from transformers import DataCollatorForLanguageModeling, DataCollatorForSeq2Seq
|
||||
|
||||
from llmtuner.data import get_dataset
|
||||
from llmtuner.extras.constants import IGNORE_INDEX
|
||||
from llmtuner.hparams import get_train_args
|
||||
from llmtuner.model import load_model, load_tokenizer
|
||||
from llamafactory.data import get_dataset
|
||||
from llamafactory.extras.constants import IGNORE_INDEX
|
||||
from llamafactory.hparams import get_train_args
|
||||
from llamafactory.model import load_model, load_tokenizer
|
||||
|
||||
|
||||
@dataclass
|
||||
|
||||
@@ -7,9 +7,9 @@ from collections import defaultdict
|
||||
import fire
|
||||
from tqdm import tqdm
|
||||
|
||||
from llmtuner.data import get_dataset
|
||||
from llmtuner.hparams import get_train_args
|
||||
from llmtuner.model import load_tokenizer
|
||||
from llamafactory.data import get_dataset
|
||||
from llamafactory.hparams import get_train_args
|
||||
from llamafactory.model import load_tokenizer
|
||||
|
||||
|
||||
def length_cdf(
|
||||
|
||||
@@ -104,10 +104,10 @@ def block_expansion(
|
||||
print("Model weights saved in {}".format(output_dir))
|
||||
|
||||
print("Fine-tune this model with:")
|
||||
print(" --model_name_or_path {} \\".format(output_dir))
|
||||
print(" --finetuning_type freeze \\")
|
||||
print(" --freeze_trainable_layers {} \\".format(num_expand))
|
||||
print(" --use_llama_pro")
|
||||
print("model_name_or_path: {}".format(output_dir))
|
||||
print("finetuning_type: freeze")
|
||||
print("freeze_trainable_layers: {}".format(num_expand))
|
||||
print("use_llama_pro: true")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@@ -20,7 +20,7 @@ def calculate_gpa(grades: Sequence[str], hours: Sequence[int]) -> float:
|
||||
|
||||
def main():
|
||||
client = OpenAI(
|
||||
api_key="0",
|
||||
api_key="{}".format(os.environ.get("API_KEY", "0")),
|
||||
base_url="http://localhost:{}/v1".format(os.environ.get("API_PORT", 8000)),
|
||||
)
|
||||
tools = [
|
||||
13
setup.py
13
setup.py
@@ -5,7 +5,7 @@ from setuptools import find_packages, setup
|
||||
|
||||
|
||||
def get_version():
|
||||
with open(os.path.join("src", "llmtuner", "cli.py"), "r", encoding="utf-8") as f:
|
||||
with open(os.path.join("src", "llamafactory", "extras", "env.py"), "r", encoding="utf-8") as f:
|
||||
file_content = f.read()
|
||||
pattern = r"{}\W*=\W*\"([^\"]+)\"".format("VERSION")
|
||||
(version,) = re.findall(pattern, file_content)
|
||||
@@ -21,24 +21,25 @@ def get_requires():
|
||||
|
||||
extra_require = {
|
||||
"torch": ["torch>=1.13.1"],
|
||||
"torch-npu": ["torch==2.1.0", "torch-npu==2.1.0.post3", "decorator"],
|
||||
"metrics": ["nltk", "jieba", "rouge-chinese"],
|
||||
"deepspeed": ["deepspeed>=0.10.0,<=0.14.0"],
|
||||
"bitsandbytes": ["bitsandbytes>=0.39.0"],
|
||||
"vllm": ["vllm>=0.4.0"],
|
||||
"vllm": ["vllm>=0.4.3"],
|
||||
"galore": ["galore-torch"],
|
||||
"badam": ["badam"],
|
||||
"gptq": ["optimum>=1.16.0", "auto-gptq>=0.5.0"],
|
||||
"awq": ["autoawq"],
|
||||
"aqlm": ["aqlm[gpu]>=1.1.0"],
|
||||
"qwen": ["tiktoken", "transformers_stream_generator"],
|
||||
"qwen": ["transformers_stream_generator"],
|
||||
"modelscope": ["modelscope"],
|
||||
"quality": ["ruff"],
|
||||
"dev": ["ruff", "pytest"],
|
||||
}
|
||||
|
||||
|
||||
def main():
|
||||
setup(
|
||||
name="llmtuner",
|
||||
name="llamafactory",
|
||||
version=get_version(),
|
||||
author="hiyouga",
|
||||
author_email="hiyouga" "@" "buaa.edu.cn",
|
||||
@@ -53,7 +54,7 @@ def main():
|
||||
python_requires=">=3.8.0",
|
||||
install_requires=get_requires(),
|
||||
extras_require=extra_require,
|
||||
entry_points={"console_scripts": ["llamafactory-cli = llmtuner.cli:main"]},
|
||||
entry_points={"console_scripts": ["llamafactory-cli = llamafactory.cli:main"]},
|
||||
classifiers=[
|
||||
"Development Status :: 4 - Beta",
|
||||
"Intended Audience :: Developers",
|
||||
|
||||
@@ -2,8 +2,8 @@ import os
|
||||
|
||||
import uvicorn
|
||||
|
||||
from llmtuner.api.app import create_app
|
||||
from llmtuner.chat import ChatModel
|
||||
from llamafactory.api.app import create_app
|
||||
from llamafactory.chat import ChatModel
|
||||
|
||||
|
||||
def main():
|
||||
|
||||
6
src/llamafactory/__init__.py
Normal file
6
src/llamafactory/__init__.py
Normal file
@@ -0,0 +1,6 @@
|
||||
# Level: api, webui > chat, eval, train > data, model > hparams > extras
|
||||
|
||||
from .cli import VERSION
|
||||
|
||||
|
||||
__version__ = VERSION
|
||||
@@ -1,10 +1,13 @@
|
||||
import base64
|
||||
import io
|
||||
import json
|
||||
import os
|
||||
import uuid
|
||||
from typing import TYPE_CHECKING, AsyncGenerator, Dict, List, Optional, Tuple
|
||||
|
||||
from ..data import Role as DataRole
|
||||
from ..extras.logging import get_logger
|
||||
from ..extras.packages import is_fastapi_available
|
||||
from ..extras.packages import is_fastapi_available, is_pillow_available, is_requests_available
|
||||
from .common import dictify, jsonify
|
||||
from .protocol import (
|
||||
ChatCompletionMessage,
|
||||
@@ -25,7 +28,17 @@ if is_fastapi_available():
|
||||
from fastapi import HTTPException, status
|
||||
|
||||
|
||||
if is_pillow_available():
|
||||
from PIL import Image
|
||||
|
||||
|
||||
if is_requests_available():
|
||||
import requests
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from numpy.typing import NDArray
|
||||
|
||||
from ..chat import ChatModel
|
||||
from .protocol import ChatCompletionRequest, ScoreEvaluationRequest
|
||||
|
||||
@@ -40,7 +53,9 @@ ROLE_MAPPING = {
|
||||
}
|
||||
|
||||
|
||||
def _process_request(request: "ChatCompletionRequest") -> Tuple[List[Dict[str, str]], str, str]:
|
||||
def _process_request(
|
||||
request: "ChatCompletionRequest",
|
||||
) -> Tuple[List[Dict[str, str]], Optional[str], Optional[str], Optional["NDArray"]]:
|
||||
logger.info("==== request ====\n{}".format(json.dumps(dictify(request), indent=2, ensure_ascii=False)))
|
||||
|
||||
if len(request.messages) == 0:
|
||||
@@ -49,12 +64,13 @@ def _process_request(request: "ChatCompletionRequest") -> Tuple[List[Dict[str, s
|
||||
if request.messages[0].role == Role.SYSTEM:
|
||||
system = request.messages.pop(0).content
|
||||
else:
|
||||
system = ""
|
||||
system = None
|
||||
|
||||
if len(request.messages) % 2 == 0:
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Only supports u/a/u/a/u...")
|
||||
|
||||
input_messages = []
|
||||
image = None
|
||||
for i, message in enumerate(request.messages):
|
||||
if i % 2 == 0 and message.role not in [Role.USER, Role.TOOL]:
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid role")
|
||||
@@ -66,6 +82,21 @@ def _process_request(request: "ChatCompletionRequest") -> Tuple[List[Dict[str, s
|
||||
arguments = message.tool_calls[0].function.arguments
|
||||
content = json.dumps({"name": name, "argument": arguments}, ensure_ascii=False)
|
||||
input_messages.append({"role": ROLE_MAPPING[Role.FUNCTION], "content": content})
|
||||
elif isinstance(message.content, list):
|
||||
for input_item in message.content:
|
||||
if input_item.type == "text":
|
||||
input_messages.append({"role": ROLE_MAPPING[message.role], "content": input_item.text})
|
||||
else:
|
||||
image_url = input_item.image_url.url
|
||||
if image_url.startswith("data:image"): # base64 image
|
||||
image_data = base64.b64decode(image_url.split(",", maxsplit=1)[1])
|
||||
image_path = io.BytesIO(image_data)
|
||||
elif os.path.isfile(image_url): # local file
|
||||
image_path = open(image_url, "rb")
|
||||
else: # web uri
|
||||
image_path = requests.get(image_url, stream=True).raw
|
||||
|
||||
image = Image.open(image_path).convert("RGB")
|
||||
else:
|
||||
input_messages.append({"role": ROLE_MAPPING[message.role], "content": message.content})
|
||||
|
||||
@@ -76,9 +107,9 @@ def _process_request(request: "ChatCompletionRequest") -> Tuple[List[Dict[str, s
|
||||
except Exception:
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid tools")
|
||||
else:
|
||||
tools = ""
|
||||
tools = None
|
||||
|
||||
return input_messages, system, tools
|
||||
return input_messages, system, tools, image
|
||||
|
||||
|
||||
def _create_stream_chat_completion_chunk(
|
||||
@@ -97,11 +128,12 @@ async def create_chat_completion_response(
|
||||
request: "ChatCompletionRequest", chat_model: "ChatModel"
|
||||
) -> "ChatCompletionResponse":
|
||||
completion_id = "chatcmpl-{}".format(uuid.uuid4().hex)
|
||||
input_messages, system, tools = _process_request(request)
|
||||
input_messages, system, tools, image = _process_request(request)
|
||||
responses = await chat_model.achat(
|
||||
input_messages,
|
||||
system,
|
||||
tools,
|
||||
image,
|
||||
do_sample=request.do_sample,
|
||||
temperature=request.temperature,
|
||||
top_p=request.top_p,
|
||||
@@ -145,7 +177,7 @@ async def create_stream_chat_completion_response(
|
||||
request: "ChatCompletionRequest", chat_model: "ChatModel"
|
||||
) -> AsyncGenerator[str, None]:
|
||||
completion_id = "chatcmpl-{}".format(uuid.uuid4().hex)
|
||||
input_messages, system, tools = _process_request(request)
|
||||
input_messages, system, tools, image = _process_request(request)
|
||||
if tools:
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Cannot stream function calls.")
|
||||
|
||||
@@ -159,6 +191,7 @@ async def create_stream_chat_completion_response(
|
||||
input_messages,
|
||||
system,
|
||||
tools,
|
||||
image,
|
||||
do_sample=request.do_sample,
|
||||
temperature=request.temperature,
|
||||
top_p=request.top_p,
|
||||
@@ -56,9 +56,19 @@ class FunctionCall(BaseModel):
|
||||
function: Function
|
||||
|
||||
|
||||
class ImageURL(BaseModel):
|
||||
url: str
|
||||
|
||||
|
||||
class MultimodalInputItem(BaseModel):
|
||||
type: Literal["text", "image_url"]
|
||||
text: Optional[str] = None
|
||||
image_url: Optional[ImageURL] = None
|
||||
|
||||
|
||||
class ChatMessage(BaseModel):
|
||||
role: Role
|
||||
content: Optional[str] = None
|
||||
content: Optional[Union[str, List[MultimodalInputItem]]] = None
|
||||
tool_calls: Optional[List[FunctionCall]] = None
|
||||
|
||||
|
||||
@@ -2,12 +2,13 @@ import asyncio
|
||||
import concurrent.futures
|
||||
import os
|
||||
from threading import Thread
|
||||
from typing import TYPE_CHECKING, Any, AsyncGenerator, Callable, Dict, List, Optional, Sequence, Tuple
|
||||
from typing import TYPE_CHECKING, Any, AsyncGenerator, Callable, Dict, List, Optional, Sequence, Tuple, Union
|
||||
|
||||
import torch
|
||||
from transformers import GenerationConfig, TextIteratorStreamer
|
||||
|
||||
from ..data import get_template_and_fix_tokenizer
|
||||
from ..extras.logging import get_logger
|
||||
from ..extras.misc import get_logits_processor
|
||||
from ..model import load_model, load_tokenizer
|
||||
from .base_engine import BaseEngine, Response
|
||||
@@ -23,6 +24,9 @@ if TYPE_CHECKING:
|
||||
from ..hparams import DataArguments, FinetuningArguments, GeneratingArguments, ModelArguments
|
||||
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class HuggingfaceEngine(BaseEngine):
|
||||
def __init__(
|
||||
self,
|
||||
@@ -55,47 +59,69 @@ class HuggingfaceEngine(BaseEngine):
|
||||
image: Optional["NDArray"] = None,
|
||||
input_kwargs: Optional[Dict[str, Any]] = {},
|
||||
) -> Tuple[Dict[str, Any], int]:
|
||||
if processor is not None and image is not None and "<image>" not in messages[0]["content"]:
|
||||
messages[0]["content"] = "<image>" + messages[0]["content"]
|
||||
if (
|
||||
processor is not None
|
||||
and image is not None
|
||||
and not hasattr(processor, "image_seq_length")
|
||||
and template.image_token not in messages[0]["content"]
|
||||
): # llava-like models
|
||||
messages[0]["content"] = template.image_token + messages[0]["content"]
|
||||
|
||||
paired_messages = messages + [{"role": "assistant", "content": ""}]
|
||||
system = system or generating_args["default_system"]
|
||||
pixel_values = None
|
||||
prompt_ids, _ = template.encode_oneturn(
|
||||
tokenizer=tokenizer, messages=paired_messages, system=system, tools=tools
|
||||
)
|
||||
if processor is not None and image is not None: # add image features
|
||||
image_processor: "BaseImageProcessor" = getattr(processor, "image_processor")
|
||||
batch_feature = image_processor(image, return_tensors="pt")
|
||||
pixel_values = batch_feature.to(model.device)["pixel_values"] # shape (B, C, H, W)
|
||||
if hasattr(processor, "image_seq_length"): # paligemma models
|
||||
image_token_id = tokenizer.convert_tokens_to_ids(template.image_token)
|
||||
prompt_ids = [image_token_id] * getattr(processor, "image_seq_length") + prompt_ids
|
||||
|
||||
prompt_length = len(prompt_ids)
|
||||
inputs = torch.tensor([prompt_ids], device=model.device)
|
||||
attention_mask = torch.ones_like(inputs, dtype=torch.bool)
|
||||
|
||||
do_sample = input_kwargs.pop("do_sample", generating_args["do_sample"])
|
||||
temperature = input_kwargs.pop("temperature", generating_args["temperature"])
|
||||
top_p = input_kwargs.pop("top_p", generating_args["top_p"])
|
||||
top_k = input_kwargs.pop("top_k", generating_args["top_k"])
|
||||
num_return_sequences = input_kwargs.pop("num_return_sequences", 1)
|
||||
repetition_penalty = input_kwargs.pop("repetition_penalty", generating_args["repetition_penalty"])
|
||||
length_penalty = input_kwargs.pop("length_penalty", generating_args["length_penalty"])
|
||||
max_length = input_kwargs.pop("max_length", None)
|
||||
max_new_tokens = input_kwargs.pop("max_new_tokens", None)
|
||||
stop = input_kwargs.pop("stop", None)
|
||||
do_sample: Optional[bool] = input_kwargs.pop("do_sample", None)
|
||||
temperature: Optional[float] = input_kwargs.pop("temperature", None)
|
||||
top_p: Optional[float] = input_kwargs.pop("top_p", None)
|
||||
top_k: Optional[float] = input_kwargs.pop("top_k", None)
|
||||
num_return_sequences: int = input_kwargs.pop("num_return_sequences", 1)
|
||||
repetition_penalty: Optional[float] = input_kwargs.pop("repetition_penalty", None)
|
||||
length_penalty: Optional[float] = input_kwargs.pop("length_penalty", None)
|
||||
max_length: Optional[int] = input_kwargs.pop("max_length", None)
|
||||
max_new_tokens: Optional[int] = input_kwargs.pop("max_new_tokens", None)
|
||||
stop: Optional[Union[str, List[str]]] = input_kwargs.pop("stop", None)
|
||||
|
||||
if stop is not None:
|
||||
raise ValueError("Stop parameter is not supported in Huggingface engine yet.")
|
||||
logger.warning("Stop parameter is not supported in Huggingface engine yet.")
|
||||
|
||||
generating_args = generating_args.copy()
|
||||
generating_args.update(
|
||||
dict(
|
||||
do_sample=do_sample,
|
||||
temperature=temperature,
|
||||
top_p=top_p,
|
||||
top_k=top_k,
|
||||
do_sample=do_sample if do_sample is not None else generating_args["do_sample"],
|
||||
temperature=temperature if temperature is not None else generating_args["temperature"],
|
||||
top_p=top_p if top_p is not None else generating_args["top_p"],
|
||||
top_k=top_k if top_k is not None else generating_args["top_k"],
|
||||
num_return_sequences=num_return_sequences,
|
||||
repetition_penalty=repetition_penalty,
|
||||
length_penalty=length_penalty,
|
||||
repetition_penalty=repetition_penalty
|
||||
if repetition_penalty is not None
|
||||
else generating_args["repetition_penalty"],
|
||||
length_penalty=length_penalty if length_penalty is not None else generating_args["length_penalty"],
|
||||
eos_token_id=[tokenizer.eos_token_id] + tokenizer.additional_special_tokens_ids,
|
||||
pad_token_id=tokenizer.pad_token_id,
|
||||
)
|
||||
)
|
||||
|
||||
if isinstance(num_return_sequences, int) and num_return_sequences > 1:
|
||||
if isinstance(num_return_sequences, int) and num_return_sequences > 1: # do_sample needs temperature > 0
|
||||
generating_args["do_sample"] = True
|
||||
generating_args["temperature"] = generating_args["temperature"] or 1.0
|
||||
|
||||
if not generating_args["temperature"]:
|
||||
generating_args["do_sample"] = False
|
||||
|
||||
if not generating_args["do_sample"]:
|
||||
generating_args.pop("temperature", None)
|
||||
@@ -111,14 +137,13 @@ class HuggingfaceEngine(BaseEngine):
|
||||
|
||||
gen_kwargs = dict(
|
||||
inputs=inputs,
|
||||
attention_mask=attention_mask,
|
||||
generation_config=GenerationConfig(**generating_args),
|
||||
logits_processor=get_logits_processor(),
|
||||
)
|
||||
|
||||
if processor is not None and image is not None:
|
||||
image_processor: "BaseImageProcessor" = getattr(processor, "image_processor")
|
||||
pixel_values: "torch.Tensor" = image_processor(image, return_tensors="pt")["pixel_values"]
|
||||
gen_kwargs["pixel_values"] = pixel_values.to(model.device)
|
||||
if pixel_values is not None:
|
||||
gen_kwargs["pixel_values"] = pixel_values
|
||||
|
||||
return gen_kwargs, prompt_length
|
||||
|
||||
@@ -1,12 +1,12 @@
|
||||
import uuid
|
||||
from typing import TYPE_CHECKING, AsyncGenerator, AsyncIterator, Dict, List, Optional, Sequence
|
||||
from typing import TYPE_CHECKING, AsyncGenerator, AsyncIterator, Dict, List, Optional, Sequence, Union
|
||||
|
||||
from ..data import get_template_and_fix_tokenizer
|
||||
from ..extras.logging import get_logger
|
||||
from ..extras.misc import get_device_count, infer_optim_dtype
|
||||
from ..extras.misc import get_device_count
|
||||
from ..extras.packages import is_vllm_available
|
||||
from ..model import load_config, load_tokenizer
|
||||
from ..model.utils.visual import LlavaMultiModalProjectorForYiVLForVLLM
|
||||
from ..model.model_utils.visual import LlavaMultiModalProjectorForYiVLForVLLM
|
||||
from .base_engine import BaseEngine, Response
|
||||
|
||||
|
||||
@@ -17,7 +17,6 @@ if is_vllm_available():
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import torch
|
||||
from numpy.typing import NDArray
|
||||
from transformers.image_processing_utils import BaseImageProcessor
|
||||
|
||||
@@ -36,8 +35,6 @@ class VllmEngine(BaseEngine):
|
||||
generating_args: "GeneratingArguments",
|
||||
) -> None:
|
||||
config = load_config(model_args) # may download model from ms hub
|
||||
infer_dtype = infer_optim_dtype(model_dtype=getattr(config, "torch_dtype", None))
|
||||
infer_dtype = str(infer_dtype).split(".")[-1]
|
||||
|
||||
self.can_generate = finetuning_args.stage == "sft"
|
||||
tokenizer_module = load_tokenizer(model_args)
|
||||
@@ -51,7 +48,7 @@ class VllmEngine(BaseEngine):
|
||||
"model": model_args.model_name_or_path,
|
||||
"trust_remote_code": True,
|
||||
"download_dir": model_args.cache_dir,
|
||||
"dtype": infer_dtype,
|
||||
"dtype": model_args.vllm_dtype,
|
||||
"max_model_len": model_args.vllm_maxlen,
|
||||
"tensor_parallel_size": get_device_count() or 1,
|
||||
"gpu_memory_utilization": model_args.vllm_gpu_util,
|
||||
@@ -59,6 +56,7 @@ class VllmEngine(BaseEngine):
|
||||
"disable_log_requests": True,
|
||||
"enforce_eager": model_args.vllm_enforce_eager,
|
||||
"enable_lora": model_args.adapter_name_or_path is not None,
|
||||
"max_lora_rank": model_args.vllm_max_lora_rank,
|
||||
}
|
||||
|
||||
if model_args.visual_inputs:
|
||||
@@ -66,11 +64,10 @@ class VllmEngine(BaseEngine):
|
||||
patch_size = config.vision_config.patch_size
|
||||
self.image_feature_size = (image_size // patch_size) ** 2
|
||||
engine_args["image_input_type"] = "pixel_values"
|
||||
engine_args["image_token_id"] = self.tokenizer.convert_tokens_to_ids("<image>")
|
||||
engine_args["image_token_id"] = self.tokenizer.convert_tokens_to_ids(self.template.image_token)
|
||||
engine_args["image_input_shape"] = "1,3,{},{}".format(image_size, image_size)
|
||||
engine_args["image_feature_size"] = self.image_feature_size
|
||||
if getattr(config, "is_yi_vl_derived_model", None):
|
||||
# bug in vllm 0.4.2, see: https://github.com/vllm-project/vllm/pull/4828
|
||||
import vllm.model_executor.models.llava
|
||||
|
||||
logger.info("Detected Yi-VL model, applying projector patch.")
|
||||
@@ -91,27 +88,49 @@ class VllmEngine(BaseEngine):
|
||||
**input_kwargs,
|
||||
) -> AsyncIterator["RequestOutput"]:
|
||||
request_id = "chatcmpl-{}".format(uuid.uuid4().hex)
|
||||
if self.processor is not None and image is not None and "<image>" not in messages[0]["content"]:
|
||||
messages[0]["content"] = "<image>" * self.image_feature_size + messages[0]["content"]
|
||||
|
||||
if (
|
||||
self.processor is not None
|
||||
and image is not None
|
||||
and not hasattr(self.processor, "image_seq_length")
|
||||
and self.template.image_token not in messages[0]["content"]
|
||||
): # llava-like models (TODO: paligemma models)
|
||||
messages[0]["content"] = self.template.image_token * self.image_feature_size + messages[0]["content"]
|
||||
|
||||
paired_messages = messages + [{"role": "assistant", "content": ""}]
|
||||
system = system or self.generating_args["default_system"]
|
||||
prompt_ids, _ = self.template.encode_oneturn(
|
||||
tokenizer=self.tokenizer, messages=paired_messages, system=system, tools=tools
|
||||
)
|
||||
|
||||
if self.processor is not None and image is not None: # add image features
|
||||
image_processor: "BaseImageProcessor" = getattr(self.processor, "image_processor")
|
||||
pixel_values = image_processor(image, return_tensors="pt")["pixel_values"]
|
||||
multi_modal_data = MultiModalData(type=MultiModalData.Type.IMAGE, data=pixel_values)
|
||||
else:
|
||||
multi_modal_data = None
|
||||
|
||||
prompt_length = len(prompt_ids)
|
||||
|
||||
use_beam_search = self.generating_args["num_beams"] > 1
|
||||
temperature = input_kwargs.pop("temperature", self.generating_args["temperature"])
|
||||
top_p = input_kwargs.pop("top_p", self.generating_args["top_p"])
|
||||
top_k = input_kwargs.pop("top_k", self.generating_args["top_k"])
|
||||
num_return_sequences = input_kwargs.pop("num_return_sequences", 1)
|
||||
repetition_penalty = input_kwargs.pop("repetition_penalty", self.generating_args["repetition_penalty"])
|
||||
length_penalty = input_kwargs.pop("length_penalty", self.generating_args["length_penalty"])
|
||||
max_length = input_kwargs.pop("max_length", None)
|
||||
max_new_tokens = input_kwargs.pop("max_new_tokens", None)
|
||||
stop = input_kwargs.pop("stop", None)
|
||||
use_beam_search: bool = self.generating_args["num_beams"] > 1
|
||||
temperature: Optional[float] = input_kwargs.pop("temperature", None)
|
||||
top_p: Optional[float] = input_kwargs.pop("top_p", None)
|
||||
top_k: Optional[float] = input_kwargs.pop("top_k", None)
|
||||
num_return_sequences: int = input_kwargs.pop("num_return_sequences", 1)
|
||||
repetition_penalty: Optional[float] = input_kwargs.pop("repetition_penalty", None)
|
||||
length_penalty: Optional[float] = input_kwargs.pop("length_penalty", None)
|
||||
max_length: Optional[int] = input_kwargs.pop("max_length", None)
|
||||
max_new_tokens: Optional[int] = input_kwargs.pop("max_new_tokens", None)
|
||||
stop: Optional[Union[str, List[str]]] = input_kwargs.pop("stop", None)
|
||||
|
||||
if "max_new_tokens" in self.generating_args:
|
||||
max_tokens = self.generating_args["max_new_tokens"]
|
||||
elif "max_length" in self.generating_args:
|
||||
if self.generating_args["max_length"] > prompt_length:
|
||||
max_tokens = self.generating_args["max_length"] - prompt_length
|
||||
else:
|
||||
max_tokens = 1
|
||||
|
||||
max_tokens = self.generating_args["max_new_tokens"] or self.generating_args["max_length"]
|
||||
if max_length:
|
||||
max_tokens = max_length - prompt_length if max_length > prompt_length else 1
|
||||
|
||||
@@ -120,32 +139,26 @@ class VllmEngine(BaseEngine):
|
||||
|
||||
sampling_params = SamplingParams(
|
||||
n=num_return_sequences,
|
||||
repetition_penalty=repetition_penalty,
|
||||
temperature=temperature,
|
||||
top_p=top_p,
|
||||
top_k=top_k,
|
||||
repetition_penalty=(
|
||||
repetition_penalty if repetition_penalty is not None else self.generating_args["repetition_penalty"]
|
||||
)
|
||||
or 1.0, # repetition_penalty must > 0
|
||||
temperature=temperature if temperature is not None else self.generating_args["temperature"],
|
||||
top_p=(top_p if top_p is not None else self.generating_args["top_p"]) or 1.0, # top_p must > 0
|
||||
top_k=top_k if top_k is not None else self.generating_args["top_k"],
|
||||
use_beam_search=use_beam_search,
|
||||
length_penalty=length_penalty,
|
||||
length_penalty=length_penalty if length_penalty is not None else self.generating_args["length_penalty"],
|
||||
stop=stop,
|
||||
stop_token_ids=[self.tokenizer.eos_token_id] + self.tokenizer.additional_special_tokens_ids,
|
||||
max_tokens=max_tokens,
|
||||
skip_special_tokens=True,
|
||||
)
|
||||
|
||||
if self.processor is not None and image is not None:
|
||||
image_processor: "BaseImageProcessor" = getattr(self.processor, "image_processor")
|
||||
pixel_values: "torch.Tensor" = image_processor(image, return_tensors="pt")["pixel_values"]
|
||||
multi_modal_data = MultiModalData(type=MultiModalData.Type.IMAGE, data=pixel_values)
|
||||
else:
|
||||
multi_modal_data = None
|
||||
|
||||
result_generator = self.model.generate(
|
||||
prompt=None,
|
||||
inputs={"prompt_token_ids": prompt_ids, "multi_modal_data": multi_modal_data},
|
||||
sampling_params=sampling_params,
|
||||
request_id=request_id,
|
||||
prompt_token_ids=prompt_ids,
|
||||
lora_request=self.lora_request,
|
||||
multi_modal_data=multi_modal_data,
|
||||
)
|
||||
return result_generator
|
||||
|
||||
@@ -1,9 +1,16 @@
|
||||
import os
|
||||
import random
|
||||
import subprocess
|
||||
import sys
|
||||
from enum import Enum, unique
|
||||
|
||||
from . import launcher
|
||||
from .api.app import run_api
|
||||
from .chat.chat_model import run_chat
|
||||
from .eval.evaluator import run_eval
|
||||
from .extras.env import VERSION, print_env
|
||||
from .extras.logging import get_logger
|
||||
from .extras.misc import get_device_count
|
||||
from .train.tuner import export_model, run_exp
|
||||
from .webui.interface import run_web_demo, run_web_ui
|
||||
|
||||
@@ -23,8 +30,6 @@ USAGE = (
|
||||
+ "-" * 70
|
||||
)
|
||||
|
||||
VERSION = "0.7.1"
|
||||
|
||||
WELCOME = (
|
||||
"-" * 58
|
||||
+ "\n"
|
||||
@@ -37,11 +42,14 @@ WELCOME = (
|
||||
+ "-" * 58
|
||||
)
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@unique
|
||||
class Command(str, Enum):
|
||||
API = "api"
|
||||
CHAT = "chat"
|
||||
ENV = "env"
|
||||
EVAL = "eval"
|
||||
EXPORT = "export"
|
||||
TRAIN = "train"
|
||||
@@ -57,11 +65,34 @@ def main():
|
||||
run_api()
|
||||
elif command == Command.CHAT:
|
||||
run_chat()
|
||||
elif command == Command.ENV:
|
||||
print_env()
|
||||
elif command == Command.EVAL:
|
||||
run_eval()
|
||||
elif command == Command.EXPORT:
|
||||
export_model()
|
||||
elif command == Command.TRAIN:
|
||||
force_torchrun = os.environ.get("FORCE_TORCHRUN", "0").lower() in ["true", "1"]
|
||||
if force_torchrun or get_device_count() > 1:
|
||||
master_addr = os.environ.get("MASTER_ADDR", "127.0.0.1")
|
||||
master_port = os.environ.get("MASTER_PORT", str(random.randint(20001, 29999)))
|
||||
logger.info("Initializing distributed tasks at: {}:{}".format(master_addr, master_port))
|
||||
subprocess.run(
|
||||
(
|
||||
"torchrun --nnodes {nnodes} --node_rank {node_rank} --nproc_per_node {nproc_per_node} "
|
||||
"--master_addr {master_addr} --master_port {master_port} {file_name} {args}"
|
||||
).format(
|
||||
nnodes=os.environ.get("NNODES", "1"),
|
||||
node_rank=os.environ.get("RANK", "0"),
|
||||
nproc_per_node=os.environ.get("NPROC_PER_NODE", str(get_device_count())),
|
||||
master_addr=master_addr,
|
||||
master_port=master_port,
|
||||
file_name=launcher.__file__,
|
||||
args=" ".join(sys.argv[1:]),
|
||||
),
|
||||
shell=True,
|
||||
)
|
||||
else:
|
||||
run_exp()
|
||||
elif command == Command.WEBDEMO:
|
||||
run_web_demo()
|
||||
16
src/llamafactory/data/__init__.py
Normal file
16
src/llamafactory/data/__init__.py
Normal file
@@ -0,0 +1,16 @@
|
||||
from .collator import KTODataCollatorWithPadding, PairwiseDataCollatorWithPadding
|
||||
from .data_utils import Role, split_dataset
|
||||
from .loader import get_dataset
|
||||
from .template import TEMPLATES, Template, get_template_and_fix_tokenizer
|
||||
|
||||
|
||||
__all__ = [
|
||||
"KTODataCollatorWithPadding",
|
||||
"PairwiseDataCollatorWithPadding",
|
||||
"Role",
|
||||
"split_dataset",
|
||||
"get_dataset",
|
||||
"TEMPLATES",
|
||||
"Template",
|
||||
"get_template_and_fix_tokenizer",
|
||||
]
|
||||
@@ -4,7 +4,8 @@ from typing import TYPE_CHECKING, Any, Dict, List, Union
|
||||
|
||||
from datasets import Features
|
||||
|
||||
from .utils import Role
|
||||
from ..extras.logging import get_logger
|
||||
from .data_utils import Role
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -14,7 +15,13 @@ if TYPE_CHECKING:
|
||||
from .parser import DatasetAttr
|
||||
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
def _convert_images(images: List[Any], dataset_attr: "DatasetAttr", data_args: "DataArguments") -> List[Any]:
|
||||
r"""
|
||||
Optionally concatenates image path to dataset dir when loading from local disk.
|
||||
"""
|
||||
outputs = []
|
||||
if dataset_attr.load_from in ["script", "file"]:
|
||||
for image in images:
|
||||
@@ -29,6 +36,9 @@ def _convert_images(images: List[Any], dataset_attr: "DatasetAttr", data_args: "
|
||||
def convert_alpaca(
|
||||
examples: Dict[str, List[Any]], dataset_attr: "DatasetAttr", data_args: "DataArguments"
|
||||
) -> Dict[str, List[Any]]:
|
||||
r"""
|
||||
Converts alpaca format dataset to the standard format.
|
||||
"""
|
||||
outputs = {"prompt": [], "response": [], "system": [], "tools": [], "images": []}
|
||||
convert_images = partial(_convert_images, dataset_attr=dataset_attr, data_args=data_args)
|
||||
for i in range(len(examples[dataset_attr.prompt])):
|
||||
@@ -45,21 +55,32 @@ def convert_alpaca(
|
||||
if dataset_attr.query and examples[dataset_attr.query][i]:
|
||||
content.append(examples[dataset_attr.query][i])
|
||||
|
||||
prompt.append({"role": Role.USER.value, "content": "\n".join(content)})
|
||||
prompt.append({"role": Role.USER.value, "content": "\n".join(content)}) # "prompt\nquery"
|
||||
|
||||
if dataset_attr.response and isinstance(examples[dataset_attr.response][i], list):
|
||||
response = [
|
||||
{"role": Role.ASSISTANT.value, "content": content} for content in examples[dataset_attr.response][i]
|
||||
]
|
||||
elif dataset_attr.response and isinstance(examples[dataset_attr.response][i], str):
|
||||
if dataset_attr.kto_tag and isinstance(examples[dataset_attr.kto_tag][i], bool): # kto example
|
||||
response = [{"role": Role.ASSISTANT.value, "content": examples[dataset_attr.response][i]}]
|
||||
if examples[dataset_attr.kto_tag][i]:
|
||||
response = response + [{"role": Role.ASSISTANT.value, "content": ""}]
|
||||
else:
|
||||
response = [{"role": Role.ASSISTANT.value, "content": ""}] + response
|
||||
elif (
|
||||
dataset_attr.ranking
|
||||
and isinstance(examples[dataset_attr.chosen][i], str)
|
||||
and isinstance(examples[dataset_attr.rejected][i], str)
|
||||
): # pairwise example
|
||||
response = [
|
||||
{"role": Role.ASSISTANT.value, "content": examples[dataset_attr.chosen][i]},
|
||||
{"role": Role.ASSISTANT.value, "content": examples[dataset_attr.rejected][i]},
|
||||
]
|
||||
elif dataset_attr.response and isinstance(examples[dataset_attr.response][i], str): # normal example
|
||||
response = [{"role": Role.ASSISTANT.value, "content": examples[dataset_attr.response][i]}]
|
||||
else: # unsupervised
|
||||
response = []
|
||||
|
||||
outputs["prompt"].append(prompt)
|
||||
outputs["response"].append(response)
|
||||
outputs["system"].append(examples[dataset_attr.system][i] if dataset_attr.system else "")
|
||||
outputs["tools"].append("")
|
||||
outputs["tools"].append(examples[dataset_attr.tools][i] if dataset_attr.tools else "")
|
||||
outputs["images"].append(convert_images(examples[dataset_attr.images][i]) if dataset_attr.images else [])
|
||||
|
||||
return outputs
|
||||
@@ -68,6 +89,9 @@ def convert_alpaca(
|
||||
def convert_sharegpt(
|
||||
examples: Dict[str, List[Any]], dataset_attr: "DatasetAttr", data_args: "DataArguments"
|
||||
) -> Dict[str, List[Any]]:
|
||||
r"""
|
||||
Converts sharegpt format dataset to the standard format.
|
||||
"""
|
||||
outputs = {"prompt": [], "response": [], "system": [], "tools": [], "images": []}
|
||||
convert_images = partial(_convert_images, dataset_attr=dataset_attr, data_args=data_args)
|
||||
tag_mapping = {
|
||||
@@ -87,21 +111,62 @@ def convert_sharegpt(
|
||||
else:
|
||||
system = examples[dataset_attr.system][i] if dataset_attr.system else ""
|
||||
|
||||
messages = messages[: len(messages) // 2 * 2] # should be multiples of 2
|
||||
if len(messages) == 0:
|
||||
continue
|
||||
|
||||
aligned_messages = []
|
||||
broken_data = False
|
||||
for turn_idx, message in enumerate(messages):
|
||||
if message[dataset_attr.role_tag] not in accept_tags[turn_idx % 2]:
|
||||
raise ValueError("Invalid role tag in {}.".format(messages))
|
||||
logger.warning("Invalid role tag in {}.".format(messages))
|
||||
broken_data = True
|
||||
|
||||
aligned_messages.append(
|
||||
{"role": tag_mapping[message[dataset_attr.role_tag]], "content": message[dataset_attr.content_tag]}
|
||||
)
|
||||
|
||||
outputs["prompt"].append(aligned_messages[:-1])
|
||||
outputs["response"].append(aligned_messages[-1:])
|
||||
if (not dataset_attr.ranking and len(aligned_messages) % 2 != 0) or (
|
||||
dataset_attr.ranking and len(aligned_messages) % 2 == 0
|
||||
):
|
||||
logger.warning("Invalid message count in {}.".format(messages))
|
||||
broken_data = True
|
||||
|
||||
if dataset_attr.kto_tag and isinstance(examples[dataset_attr.kto_tag][i], bool): # kto example
|
||||
prompt = aligned_messages[:-1]
|
||||
response = aligned_messages[-1:]
|
||||
if examples[dataset_attr.kto_tag][i]:
|
||||
response = response + [{"role": Role.ASSISTANT.value, "content": ""}]
|
||||
else:
|
||||
response = [{"role": Role.ASSISTANT.value, "content": ""}] + response
|
||||
elif (
|
||||
dataset_attr.ranking
|
||||
and isinstance(examples[dataset_attr.chosen][i], dict)
|
||||
and isinstance(examples[dataset_attr.rejected][i], dict)
|
||||
): # pairwise example
|
||||
chosen = examples[dataset_attr.chosen][i]
|
||||
rejected = examples[dataset_attr.rejected][i]
|
||||
if (
|
||||
chosen[dataset_attr.role_tag] not in accept_tags[-1]
|
||||
or rejected[dataset_attr.role_tag] not in accept_tags[-1]
|
||||
):
|
||||
logger.warning("Invalid role tag in {}.".format([chosen, rejected]))
|
||||
broken_data = True
|
||||
|
||||
prompt = aligned_messages
|
||||
response = [
|
||||
{"role": tag_mapping[chosen[dataset_attr.role_tag]], "content": chosen[dataset_attr.content_tag]},
|
||||
{"role": tag_mapping[rejected[dataset_attr.role_tag]], "content": rejected[dataset_attr.content_tag]},
|
||||
]
|
||||
else: # normal example
|
||||
prompt = aligned_messages[:-1]
|
||||
response = aligned_messages[-1:]
|
||||
|
||||
if broken_data:
|
||||
logger.warning("Skipping this abnormal example.")
|
||||
continue
|
||||
|
||||
outputs["prompt"].append(prompt)
|
||||
outputs["response"].append(response)
|
||||
outputs["system"].append(system)
|
||||
outputs["tools"].append(examples[dataset_attr.tools][i] if dataset_attr.tools else "")
|
||||
outputs["images"].append(convert_images(examples[dataset_attr.images][i]) if dataset_attr.images else [])
|
||||
81
src/llamafactory/data/collator.py
Normal file
81
src/llamafactory/data/collator.py
Normal file
@@ -0,0 +1,81 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, Sequence
|
||||
|
||||
import torch
|
||||
from transformers import DataCollatorForSeq2Seq
|
||||
|
||||
|
||||
@dataclass
|
||||
class PairwiseDataCollatorWithPadding(DataCollatorForSeq2Seq):
|
||||
r"""
|
||||
Data collator for pairwise data.
|
||||
"""
|
||||
|
||||
def __call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, torch.Tensor]:
|
||||
r"""
|
||||
Pads batched data to the longest sequence in the batch.
|
||||
|
||||
We generate 2 * n examples where the first n examples represent chosen examples and
|
||||
the last n examples represent rejected examples.
|
||||
"""
|
||||
concatenated_features = []
|
||||
for key in ("chosen", "rejected"):
|
||||
for feature in features:
|
||||
target_feature = {
|
||||
"input_ids": feature["{}_input_ids".format(key)],
|
||||
"attention_mask": feature["{}_attention_mask".format(key)],
|
||||
"labels": feature["{}_labels".format(key)],
|
||||
}
|
||||
if "pixel_values" in feature:
|
||||
target_feature["pixel_values"] = feature["pixel_values"]
|
||||
|
||||
if "{}_token_type_ids".format(key) in feature:
|
||||
target_feature["token_type_ids"] = feature["{}_token_type_ids".format(key)]
|
||||
|
||||
concatenated_features.append(target_feature)
|
||||
|
||||
return super().__call__(concatenated_features)
|
||||
|
||||
|
||||
@dataclass
|
||||
class KTODataCollatorWithPadding(DataCollatorForSeq2Seq):
|
||||
r"""
|
||||
Data collator for KTO data.
|
||||
"""
|
||||
|
||||
def __call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, torch.Tensor]:
|
||||
target_features = []
|
||||
kl_features = []
|
||||
kto_tags = []
|
||||
for feature in features:
|
||||
target_feature = {
|
||||
"input_ids": feature["input_ids"],
|
||||
"attention_mask": feature["attention_mask"],
|
||||
"labels": feature["labels"],
|
||||
}
|
||||
kl_feature = {
|
||||
"input_ids": feature["kl_input_ids"],
|
||||
"attention_mask": feature["kl_attention_mask"],
|
||||
"labels": feature["kl_labels"],
|
||||
}
|
||||
if "pixel_values" in feature:
|
||||
target_feature["pixel_values"] = feature["pixel_values"]
|
||||
|
||||
if "token_type_ids" in feature:
|
||||
target_feature["token_type_ids"] = feature["token_type_ids"]
|
||||
kl_feature["token_type_ids"] = feature["kl_token_type_ids"]
|
||||
|
||||
target_features.append(target_feature)
|
||||
kl_features.append(kl_feature)
|
||||
kto_tags.append(feature["kto_tags"])
|
||||
|
||||
batch = super().__call__(target_features)
|
||||
kl_batch = super().__call__(kl_features)
|
||||
batch["kl_input_ids"] = kl_batch["input_ids"]
|
||||
batch["kl_attention_mask"] = kl_batch["attention_mask"]
|
||||
batch["kl_labels"] = kl_batch["labels"]
|
||||
if "token_type_ids" in batch:
|
||||
batch["kl_token_type_ids"] = kl_batch["token_type_ids"]
|
||||
|
||||
batch["kto_tags"] = torch.tensor(kto_tags)
|
||||
return batch
|
||||
@@ -10,7 +10,7 @@ if TYPE_CHECKING:
|
||||
from datasets import Dataset, IterableDataset
|
||||
from transformers import Seq2SeqTrainingArguments
|
||||
|
||||
from llmtuner.hparams import DataArguments
|
||||
from ..hparams import DataArguments
|
||||
|
||||
|
||||
logger = get_logger(__name__)
|
||||
@@ -1,17 +1,19 @@
|
||||
import inspect
|
||||
import os
|
||||
import sys
|
||||
from typing import TYPE_CHECKING, Literal, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
from datasets import load_dataset, load_from_disk
|
||||
|
||||
from ..extras.constants import FILEEXT2TYPE
|
||||
from ..extras.logging import get_logger
|
||||
from ..extras.misc import has_tokenized_data
|
||||
from .aligner import align_dataset
|
||||
from .data_utils import merge_dataset
|
||||
from .parser import get_dataset_list
|
||||
from .preprocess import get_preprocess_and_print_func
|
||||
from .template import get_template_and_fix_tokenizer
|
||||
from .utils import merge_dataset
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -57,12 +59,12 @@ def load_single_dataset(
|
||||
data_files.append(local_path)
|
||||
data_path = FILEEXT2TYPE.get(local_path.split(".")[-1], None)
|
||||
else:
|
||||
raise ValueError("File not found.")
|
||||
raise ValueError("File {} not found.".format(local_path))
|
||||
|
||||
if data_path is None:
|
||||
raise ValueError("File extension must be txt, csv, json or jsonl.")
|
||||
raise ValueError("Allowed file types: {}.".format(",".join(FILEEXT2TYPE.keys())))
|
||||
else:
|
||||
raise NotImplementedError
|
||||
raise NotImplementedError("Unknown load type: {}.".format(dataset_attr.load_from))
|
||||
|
||||
if dataset_attr.load_from == "ms_hub":
|
||||
try:
|
||||
@@ -105,9 +107,21 @@ def load_single_dataset(
|
||||
if data_args.streaming and (dataset_attr.load_from == "file"): # faster than specifying streaming=True
|
||||
dataset = dataset.to_iterable_dataset() # TODO: add num shards parameter
|
||||
|
||||
if dataset_attr.num_samples is not None and not data_args.streaming:
|
||||
target_num = dataset_attr.num_samples
|
||||
indexes = np.random.permutation(len(dataset))[:target_num]
|
||||
target_num -= len(indexes)
|
||||
if target_num > 0:
|
||||
expand_indexes = np.random.choice(len(dataset), target_num)
|
||||
indexes = np.concatenate((indexes, expand_indexes), axis=0)
|
||||
|
||||
assert len(indexes) == dataset_attr.num_samples, "Sample num mismatched."
|
||||
dataset = dataset.select(indexes)
|
||||
logger.info("Sampled {} examples from dataset {}.".format(dataset_attr.num_samples, dataset_attr))
|
||||
|
||||
if data_args.max_samples is not None: # truncate dataset
|
||||
num_samples = min(data_args.max_samples, len(dataset))
|
||||
dataset = dataset.select(range(num_samples))
|
||||
max_samples = min(data_args.max_samples, len(dataset))
|
||||
dataset = dataset.select(range(max_samples))
|
||||
|
||||
return align_dataset(dataset, dataset_attr, data_args)
|
||||
|
||||
@@ -116,7 +130,7 @@ def get_dataset(
|
||||
model_args: "ModelArguments",
|
||||
data_args: "DataArguments",
|
||||
training_args: "Seq2SeqTrainingArguments",
|
||||
stage: Literal["pt", "sft", "rm", "ppo"],
|
||||
stage: Literal["pt", "sft", "rm", "ppo", "kto"],
|
||||
tokenizer: "PreTrainedTokenizer",
|
||||
processor: Optional["ProcessorMixin"] = None,
|
||||
) -> Union["Dataset", "IterableDataset"]:
|
||||
@@ -165,14 +179,17 @@ def get_dataset(
|
||||
if training_args.should_save:
|
||||
dataset.save_to_disk(data_args.tokenized_path)
|
||||
logger.info("Tokenized dataset saved at {}.".format(data_args.tokenized_path))
|
||||
logger.info("Please restart the training with `--tokenized_path {}`.".format(data_args.tokenized_path))
|
||||
logger.info("Please restart the training with `tokenized_path: {}`.".format(data_args.tokenized_path))
|
||||
|
||||
exit(0)
|
||||
sys.exit(0)
|
||||
|
||||
if training_args.should_log:
|
||||
try:
|
||||
print_function(next(iter(dataset)))
|
||||
except StopIteration:
|
||||
if stage == "pt":
|
||||
raise RuntimeError("Cannot find sufficient samples, consider increasing dataset size.")
|
||||
else:
|
||||
raise RuntimeError("Cannot find valid samples, check `data/README.md` for the data format.")
|
||||
|
||||
return dataset
|
||||
@@ -20,23 +20,28 @@ class DatasetAttr:
|
||||
""" basic configs """
|
||||
load_from: Literal["hf_hub", "ms_hub", "script", "file"]
|
||||
dataset_name: str
|
||||
formatting: Literal["alpaca", "sharegpt"] = "alpaca"
|
||||
ranking: bool = False
|
||||
""" extra configs """
|
||||
subset: Optional[str] = None
|
||||
folder: Optional[str] = None
|
||||
ranking: bool = False
|
||||
formatting: Literal["alpaca", "sharegpt"] = "alpaca"
|
||||
""" columns """
|
||||
num_samples: Optional[int] = None
|
||||
""" common columns """
|
||||
system: Optional[str] = None
|
||||
tools: Optional[str] = None
|
||||
images: Optional[str] = None
|
||||
""" columns for the alpaca format """
|
||||
""" rlhf columns """
|
||||
chosen: Optional[str] = None
|
||||
rejected: Optional[str] = None
|
||||
kto_tag: Optional[str] = None
|
||||
""" alpaca columns """
|
||||
prompt: Optional[str] = "instruction"
|
||||
query: Optional[str] = "input"
|
||||
response: Optional[str] = "output"
|
||||
history: Optional[str] = None
|
||||
""" columns for the sharegpt format """
|
||||
""" sharegpt columns """
|
||||
messages: Optional[str] = "conversations"
|
||||
tools: Optional[str] = None
|
||||
""" tags for the sharegpt format """
|
||||
""" sharegpt tags """
|
||||
role_tag: Optional[str] = "from"
|
||||
content_tag: Optional[str] = "value"
|
||||
user_tag: Optional[str] = "human"
|
||||
@@ -98,17 +103,18 @@ def get_dataset_list(data_args: "DataArguments") -> List["DatasetAttr"]:
|
||||
else:
|
||||
dataset_attr = DatasetAttr("file", dataset_name=dataset_info[name]["file_name"])
|
||||
|
||||
dataset_attr.set_attr("formatting", dataset_info[name], default="alpaca")
|
||||
dataset_attr.set_attr("ranking", dataset_info[name], default=False)
|
||||
dataset_attr.set_attr("subset", dataset_info[name])
|
||||
dataset_attr.set_attr("folder", dataset_info[name])
|
||||
dataset_attr.set_attr("ranking", dataset_info[name], default=False)
|
||||
dataset_attr.set_attr("formatting", dataset_info[name], default="alpaca")
|
||||
dataset_attr.set_attr("num_samples", dataset_info[name])
|
||||
|
||||
if "columns" in dataset_info[name]:
|
||||
column_names = ["system", "images"]
|
||||
column_names = ["system", "tools", "images", "chosen", "rejected", "kto_tag"]
|
||||
if dataset_attr.formatting == "alpaca":
|
||||
column_names.extend(["prompt", "query", "response", "history"])
|
||||
else:
|
||||
column_names.extend(["messages", "tools"])
|
||||
column_names.extend(["messages"])
|
||||
|
||||
for column_name in column_names:
|
||||
dataset_attr.set_attr(column_name, dataset_info[name]["columns"])
|
||||
84
src/llamafactory/data/preprocess.py
Normal file
84
src/llamafactory/data/preprocess.py
Normal file
@@ -0,0 +1,84 @@
|
||||
from functools import partial
|
||||
from typing import TYPE_CHECKING, Callable, Literal, Optional, Tuple
|
||||
|
||||
from .processors.feedback import preprocess_feedback_dataset
|
||||
from .processors.pairwise import preprocess_pairwise_dataset, print_pairwise_dataset_example
|
||||
from .processors.pretrain import preprocess_pretrain_dataset
|
||||
from .processors.supervised import (
|
||||
preprocess_packed_supervised_dataset,
|
||||
preprocess_supervised_dataset,
|
||||
print_supervised_dataset_example,
|
||||
)
|
||||
from .processors.unsupervised import preprocess_unsupervised_dataset, print_unsupervised_dataset_example
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from transformers import ProcessorMixin, Seq2SeqTrainingArguments
|
||||
from transformers.tokenization_utils import PreTrainedTokenizer
|
||||
|
||||
from ..hparams import DataArguments
|
||||
from .template import Template
|
||||
|
||||
|
||||
def get_preprocess_and_print_func(
|
||||
data_args: "DataArguments",
|
||||
training_args: "Seq2SeqTrainingArguments",
|
||||
stage: Literal["pt", "sft", "rm", "ppo", "kto"],
|
||||
template: "Template",
|
||||
tokenizer: "PreTrainedTokenizer",
|
||||
processor: Optional["ProcessorMixin"],
|
||||
) -> Tuple[Callable, Callable]:
|
||||
if stage == "pt":
|
||||
preprocess_func = partial(
|
||||
preprocess_pretrain_dataset,
|
||||
tokenizer=tokenizer,
|
||||
data_args=data_args,
|
||||
)
|
||||
print_function = partial(print_unsupervised_dataset_example, tokenizer=tokenizer)
|
||||
elif stage == "sft" and not training_args.predict_with_generate:
|
||||
if data_args.packing:
|
||||
preprocess_func = partial(
|
||||
preprocess_packed_supervised_dataset,
|
||||
template=template,
|
||||
tokenizer=tokenizer,
|
||||
data_args=data_args,
|
||||
)
|
||||
else:
|
||||
preprocess_func = partial(
|
||||
preprocess_supervised_dataset,
|
||||
template=template,
|
||||
tokenizer=tokenizer,
|
||||
processor=processor,
|
||||
data_args=data_args,
|
||||
)
|
||||
|
||||
print_function = partial(print_supervised_dataset_example, tokenizer=tokenizer)
|
||||
elif stage == "rm":
|
||||
preprocess_func = partial(
|
||||
preprocess_pairwise_dataset,
|
||||
template=template,
|
||||
tokenizer=tokenizer,
|
||||
processor=processor,
|
||||
data_args=data_args,
|
||||
)
|
||||
print_function = partial(print_pairwise_dataset_example, tokenizer=tokenizer)
|
||||
elif stage == "kto":
|
||||
preprocess_func = partial(
|
||||
preprocess_feedback_dataset,
|
||||
template=template,
|
||||
tokenizer=tokenizer,
|
||||
processor=processor,
|
||||
data_args=data_args,
|
||||
)
|
||||
print_function = partial(print_supervised_dataset_example, tokenizer=tokenizer)
|
||||
else:
|
||||
preprocess_func = partial(
|
||||
preprocess_unsupervised_dataset,
|
||||
template=template,
|
||||
tokenizer=tokenizer,
|
||||
processor=processor,
|
||||
data_args=data_args,
|
||||
)
|
||||
print_function = partial(print_unsupervised_dataset_example, tokenizer=tokenizer)
|
||||
|
||||
return preprocess_func, print_function
|
||||
126
src/llamafactory/data/processors/feedback.py
Normal file
126
src/llamafactory/data/processors/feedback.py
Normal file
@@ -0,0 +1,126 @@
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple
|
||||
|
||||
from ...extras.constants import IGNORE_INDEX
|
||||
from ...extras.logging import get_logger
|
||||
from .processor_utils import get_paligemma_token_type_ids, get_pixel_values
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from transformers import ProcessorMixin
|
||||
from transformers.tokenization_utils import PreTrainedTokenizer
|
||||
|
||||
from ...hparams import DataArguments
|
||||
from ..template import Template
|
||||
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
def _encode_feedback_example(
|
||||
prompt: Sequence[Dict[str, str]],
|
||||
response: Sequence[Dict[str, str]],
|
||||
kl_response: Sequence[Dict[str, str]],
|
||||
system: Optional[str],
|
||||
tools: Optional[str],
|
||||
template: "Template",
|
||||
tokenizer: "PreTrainedTokenizer",
|
||||
processor: Optional["ProcessorMixin"],
|
||||
data_args: "DataArguments",
|
||||
) -> Tuple[List[int], List[int], List[int], List[int], bool]:
|
||||
if processor is not None and not hasattr(processor, "image_seq_length"): # llava-like models
|
||||
prompt[0]["content"] = template.image_token + prompt[0]["content"]
|
||||
|
||||
if response[0]["content"]: # desired example
|
||||
kto_tag = True
|
||||
messages = prompt + [response[0]]
|
||||
else: # undesired example
|
||||
kto_tag = False
|
||||
messages = prompt + [response[1]]
|
||||
|
||||
if kl_response[0]["content"]:
|
||||
kl_messages = prompt + [kl_response[0]]
|
||||
else:
|
||||
kl_messages = prompt + [kl_response[1]]
|
||||
|
||||
prompt_ids, response_ids = template.encode_oneturn(
|
||||
tokenizer, messages, system, tools, data_args.cutoff_len, data_args.reserved_label_len
|
||||
)
|
||||
_, kl_response_ids = template.encode_oneturn(
|
||||
tokenizer, kl_messages, system, tools, data_args.cutoff_len, data_args.reserved_label_len
|
||||
)
|
||||
|
||||
if template.efficient_eos:
|
||||
response_ids += [tokenizer.eos_token_id]
|
||||
kl_response_ids += [tokenizer.eos_token_id]
|
||||
|
||||
if processor is not None and hasattr(processor, "image_seq_length"): # paligemma models
|
||||
image_token_id = tokenizer.convert_tokens_to_ids(template.image_token)
|
||||
prompt_ids = [image_token_id] * getattr(processor, "image_seq_length") + prompt_ids
|
||||
|
||||
input_ids = prompt_ids + response_ids
|
||||
labels = [IGNORE_INDEX] * len(prompt_ids) + response_ids
|
||||
kl_input_ids = prompt_ids + kl_response_ids
|
||||
kl_labels = [IGNORE_INDEX] * len(prompt_ids) + kl_response_ids
|
||||
|
||||
return input_ids, labels, kl_input_ids, kl_labels, kto_tag
|
||||
|
||||
|
||||
def preprocess_feedback_dataset(
|
||||
examples: Dict[str, List[Any]],
|
||||
template: "Template",
|
||||
tokenizer: "PreTrainedTokenizer",
|
||||
processor: Optional["ProcessorMixin"],
|
||||
data_args: "DataArguments",
|
||||
) -> Dict[str, List[List[int]]]:
|
||||
# create unrelated input-output pairs for estimating the KL term by flipping the matched pairs
|
||||
kl_response = examples["response"][::-1]
|
||||
model_inputs = {
|
||||
"input_ids": [],
|
||||
"attention_mask": [],
|
||||
"labels": [],
|
||||
"kl_input_ids": [],
|
||||
"kl_attention_mask": [],
|
||||
"kl_labels": [],
|
||||
"kto_tags": [],
|
||||
}
|
||||
if processor is not None:
|
||||
model_inputs["pixel_values"] = []
|
||||
if hasattr(processor, "image_seq_length"): # paligemma models
|
||||
model_inputs["token_type_ids"] = []
|
||||
model_inputs["kl_token_type_ids"] = []
|
||||
|
||||
for i in range(len(examples["prompt"])):
|
||||
if len(examples["prompt"][i]) % 2 != 1 or len(examples["response"][i]) < 2:
|
||||
logger.warning("Dropped invalid example: {}".format(examples["prompt"][i] + examples["response"][i]))
|
||||
continue
|
||||
|
||||
input_ids, labels, kl_input_ids, kl_labels, kto_tag = _encode_feedback_example(
|
||||
prompt=examples["prompt"][i],
|
||||
response=examples["response"][i],
|
||||
kl_response=kl_response[i],
|
||||
system=examples["system"][i],
|
||||
tools=examples["tools"][i],
|
||||
template=template,
|
||||
tokenizer=tokenizer,
|
||||
processor=processor,
|
||||
data_args=data_args,
|
||||
)
|
||||
model_inputs["input_ids"].append(input_ids)
|
||||
model_inputs["attention_mask"].append([1] * len(input_ids))
|
||||
model_inputs["labels"].append(labels)
|
||||
model_inputs["kl_input_ids"].append(kl_input_ids)
|
||||
model_inputs["kl_attention_mask"].append([1] * len(kl_input_ids))
|
||||
model_inputs["kl_labels"].append(kl_labels)
|
||||
model_inputs["kto_tags"].append(kto_tag)
|
||||
if processor is not None:
|
||||
model_inputs["pixel_values"].append(get_pixel_values(examples["images"][i], processor))
|
||||
if hasattr(processor, "image_seq_length"): # paligemma models
|
||||
model_inputs["token_type_ids"].append(get_paligemma_token_type_ids(len(input_ids), processor))
|
||||
model_inputs["kl_token_type_ids"].append(get_paligemma_token_type_ids(len(kl_input_ids), processor))
|
||||
|
||||
desirable_num = sum([1 for tag in model_inputs["kto_tags"] if tag])
|
||||
undesirable_num = len(model_inputs["kto_tags"]) - desirable_num
|
||||
if desirable_num == 0 or undesirable_num == 0:
|
||||
logger.warning("Your dataset only has one preference type.")
|
||||
|
||||
return model_inputs
|
||||
123
src/llamafactory/data/processors/pairwise.py
Normal file
123
src/llamafactory/data/processors/pairwise.py
Normal file
@@ -0,0 +1,123 @@
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple
|
||||
|
||||
from ...extras.constants import IGNORE_INDEX
|
||||
from ...extras.logging import get_logger
|
||||
from .processor_utils import get_paligemma_token_type_ids, get_pixel_values
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from transformers import ProcessorMixin
|
||||
from transformers.tokenization_utils import PreTrainedTokenizer
|
||||
|
||||
from ...hparams import DataArguments
|
||||
from ..template import Template
|
||||
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
def _encode_pairwise_example(
|
||||
prompt: Sequence[Dict[str, str]],
|
||||
response: Sequence[Dict[str, str]],
|
||||
system: Optional[str],
|
||||
tools: Optional[str],
|
||||
template: "Template",
|
||||
tokenizer: "PreTrainedTokenizer",
|
||||
processor: Optional["ProcessorMixin"],
|
||||
data_args: "DataArguments",
|
||||
) -> Tuple[List[int], List[int], List[int], List[int]]:
|
||||
if processor is not None and not hasattr(processor, "image_seq_length"): # llava-like models
|
||||
prompt[0]["content"] = template.image_token + prompt[0]["content"]
|
||||
|
||||
chosen_messages = prompt + [response[0]]
|
||||
rejected_messages = prompt + [response[1]]
|
||||
prompt_ids, chosen_ids = template.encode_oneturn(
|
||||
tokenizer, chosen_messages, system, tools, data_args.cutoff_len, data_args.reserved_label_len
|
||||
)
|
||||
_, rejected_ids = template.encode_oneturn(
|
||||
tokenizer, rejected_messages, system, tools, data_args.cutoff_len, data_args.reserved_label_len
|
||||
)
|
||||
|
||||
if template.efficient_eos:
|
||||
chosen_ids += [tokenizer.eos_token_id]
|
||||
rejected_ids += [tokenizer.eos_token_id]
|
||||
|
||||
if processor is not None and hasattr(processor, "image_seq_length"): # paligemma models
|
||||
image_token_id = tokenizer.convert_tokens_to_ids(template.image_token)
|
||||
prompt_ids = [image_token_id] * getattr(processor, "image_seq_length") + prompt_ids
|
||||
|
||||
chosen_input_ids = prompt_ids + chosen_ids
|
||||
chosen_labels = [IGNORE_INDEX] * len(prompt_ids) + chosen_ids
|
||||
rejected_input_ids = prompt_ids + rejected_ids
|
||||
rejected_labels = [IGNORE_INDEX] * len(prompt_ids) + rejected_ids
|
||||
|
||||
return chosen_input_ids, chosen_labels, rejected_input_ids, rejected_labels
|
||||
|
||||
|
||||
def preprocess_pairwise_dataset(
|
||||
examples: Dict[str, List[Any]],
|
||||
template: "Template",
|
||||
tokenizer: "PreTrainedTokenizer",
|
||||
processor: Optional["ProcessorMixin"],
|
||||
data_args: "DataArguments",
|
||||
) -> Dict[str, List[List[int]]]:
|
||||
# build input pairs with format `<bos> X`, `Y1 <eos>` and `Y2 <eos>`
|
||||
model_inputs = {
|
||||
"chosen_input_ids": [],
|
||||
"chosen_attention_mask": [],
|
||||
"chosen_labels": [],
|
||||
"rejected_input_ids": [],
|
||||
"rejected_attention_mask": [],
|
||||
"rejected_labels": [],
|
||||
}
|
||||
if processor is not None:
|
||||
model_inputs["pixel_values"] = []
|
||||
if hasattr(processor, "image_seq_length"): # paligemma models
|
||||
model_inputs["chosen_token_type_ids"] = []
|
||||
model_inputs["rejected_token_type_ids"] = []
|
||||
|
||||
for i in range(len(examples["prompt"])):
|
||||
if len(examples["prompt"][i]) % 2 != 1 or len(examples["response"][i]) < 2:
|
||||
logger.warning("Dropped invalid example: {}".format(examples["prompt"][i] + examples["response"][i]))
|
||||
continue
|
||||
|
||||
chosen_input_ids, chosen_labels, rejected_input_ids, rejected_labels = _encode_pairwise_example(
|
||||
prompt=examples["prompt"][i],
|
||||
response=examples["response"][i],
|
||||
system=examples["system"][i],
|
||||
tools=examples["tools"][i],
|
||||
template=template,
|
||||
tokenizer=tokenizer,
|
||||
processor=processor,
|
||||
data_args=data_args,
|
||||
)
|
||||
model_inputs["chosen_input_ids"].append(chosen_input_ids)
|
||||
model_inputs["chosen_attention_mask"].append([1] * len(chosen_input_ids))
|
||||
model_inputs["chosen_labels"].append(chosen_labels)
|
||||
model_inputs["rejected_input_ids"].append(rejected_input_ids)
|
||||
model_inputs["rejected_attention_mask"].append([1] * len(rejected_input_ids))
|
||||
model_inputs["rejected_labels"].append(rejected_labels)
|
||||
if processor is not None:
|
||||
model_inputs["pixel_values"].append(get_pixel_values(examples["images"][i], processor))
|
||||
if hasattr(processor, "image_seq_length"): # paligemma models
|
||||
model_inputs["chosen_token_type_ids"].append(
|
||||
get_paligemma_token_type_ids(len(chosen_input_ids), processor)
|
||||
)
|
||||
model_inputs["rejected_token_type_ids"].append(
|
||||
get_paligemma_token_type_ids(len(rejected_input_ids), processor)
|
||||
)
|
||||
|
||||
return model_inputs
|
||||
|
||||
|
||||
def print_pairwise_dataset_example(example: Dict[str, List[int]], tokenizer: "PreTrainedTokenizer") -> None:
|
||||
valid_chosen_labels = list(filter(lambda x: x != IGNORE_INDEX, example["chosen_labels"]))
|
||||
valid_rejected_labels = list(filter(lambda x: x != IGNORE_INDEX, example["rejected_labels"]))
|
||||
print("chosen_input_ids:\n{}".format(example["chosen_input_ids"]))
|
||||
print("chosen_inputs:\n{}".format(tokenizer.decode(example["chosen_input_ids"], skip_special_tokens=False)))
|
||||
print("chosen_label_ids:\n{}".format(example["chosen_labels"]))
|
||||
print("chosen_labels:\n{}".format(tokenizer.decode(valid_chosen_labels, skip_special_tokens=False)))
|
||||
print("rejected_input_ids:\n{}".format(example["rejected_input_ids"]))
|
||||
print("rejected_inputs:\n{}".format(tokenizer.decode(example["rejected_input_ids"], skip_special_tokens=False)))
|
||||
print("rejected_label_ids:\n{}".format(example["rejected_labels"]))
|
||||
print("rejected_labels:\n{}".format(tokenizer.decode(valid_rejected_labels, skip_special_tokens=False)))
|
||||
36
src/llamafactory/data/processors/pretrain.py
Normal file
36
src/llamafactory/data/processors/pretrain.py
Normal file
@@ -0,0 +1,36 @@
|
||||
from itertools import chain
|
||||
from typing import TYPE_CHECKING, Any, Dict, List
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from transformers.tokenization_utils import PreTrainedTokenizer
|
||||
|
||||
from ...hparams import DataArguments
|
||||
|
||||
|
||||
def preprocess_pretrain_dataset(
|
||||
examples: Dict[str, List[Any]], tokenizer: "PreTrainedTokenizer", data_args: "DataArguments"
|
||||
) -> Dict[str, List[List[int]]]:
|
||||
# build grouped texts with format `X1 X2 X3 ...` if packing is enabled
|
||||
text_examples = [messages[0]["content"] + tokenizer.eos_token for messages in examples["prompt"]]
|
||||
|
||||
if not data_args.packing:
|
||||
if data_args.template == "gemma":
|
||||
text_examples = [tokenizer.bos_token + example for example in text_examples]
|
||||
|
||||
result = tokenizer(text_examples, add_special_tokens=False, max_length=data_args.cutoff_len, truncation=True)
|
||||
else:
|
||||
tokenized_examples = tokenizer(text_examples, add_special_tokens=False)
|
||||
concatenated_examples = {k: list(chain(*tokenized_examples[k])) for k in tokenized_examples.keys()}
|
||||
total_length = len(concatenated_examples[list(concatenated_examples.keys())[0]])
|
||||
block_size = data_args.cutoff_len
|
||||
total_length = (total_length // block_size) * block_size
|
||||
result = {
|
||||
k: [t[i : i + block_size] for i in range(0, total_length, block_size)]
|
||||
for k, t in concatenated_examples.items()
|
||||
}
|
||||
if data_args.template == "gemma":
|
||||
for i in range(len(result["input_ids"])):
|
||||
result["input_ids"][i][0] = tokenizer.bos_token_id
|
||||
|
||||
return result
|
||||
64
src/llamafactory/data/processors/processor_utils.py
Normal file
64
src/llamafactory/data/processors/processor_utils.py
Normal file
@@ -0,0 +1,64 @@
|
||||
import bisect
|
||||
from typing import TYPE_CHECKING, List, Sequence
|
||||
|
||||
from ...extras.packages import is_pillow_available
|
||||
|
||||
|
||||
if is_pillow_available():
|
||||
from PIL import Image
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from numpy.typing import NDArray
|
||||
from PIL.Image import Image as ImageObject
|
||||
from transformers import ProcessorMixin
|
||||
from transformers.image_processing_utils import BaseImageProcessor
|
||||
|
||||
|
||||
def search_for_fit(numbers: Sequence[int], capacity: int) -> int:
|
||||
r"""
|
||||
Finds the index of largest number that fits into the knapsack with the given capacity.
|
||||
"""
|
||||
index = bisect.bisect(numbers, capacity)
|
||||
return -1 if index == 0 else (index - 1)
|
||||
|
||||
|
||||
def greedy_knapsack(numbers: List[int], capacity: int) -> List[List[int]]:
|
||||
r"""
|
||||
An efficient greedy algorithm with binary search for the knapsack problem.
|
||||
"""
|
||||
numbers.sort() # sort numbers in ascending order for binary search
|
||||
knapsacks = []
|
||||
|
||||
while numbers:
|
||||
current_knapsack = []
|
||||
remaining_capacity = capacity
|
||||
|
||||
while True:
|
||||
index = search_for_fit(numbers, remaining_capacity)
|
||||
if index == -1:
|
||||
break # no more numbers fit in this knapsack
|
||||
|
||||
remaining_capacity -= numbers[index] # update the remaining capacity
|
||||
current_knapsack.append(numbers.pop(index)) # add the number to knapsack
|
||||
|
||||
knapsacks.append(current_knapsack)
|
||||
|
||||
return knapsacks
|
||||
|
||||
|
||||
def get_pixel_values(images: Sequence["ImageObject"], processor: "ProcessorMixin") -> "NDArray":
|
||||
r"""
|
||||
Processes visual inputs. (currently only supports a single image)
|
||||
"""
|
||||
image_processor: "BaseImageProcessor" = getattr(processor, "image_processor")
|
||||
image = images[0] if len(images) != 0 else Image.new("RGB", (100, 100), (255, 255, 255))
|
||||
return image_processor(image, return_tensors="pt")["pixel_values"][0] # shape (C, H, W)
|
||||
|
||||
|
||||
def get_paligemma_token_type_ids(input_len: int, processor: "ProcessorMixin") -> List[int]:
|
||||
r"""
|
||||
Gets paligemma token type ids for computing loss.
|
||||
"""
|
||||
image_seq_length = getattr(processor, "image_seq_length")
|
||||
return [0] * image_seq_length + [1] * (input_len - image_seq_length)
|
||||
169
src/llamafactory/data/processors/supervised.py
Normal file
169
src/llamafactory/data/processors/supervised.py
Normal file
@@ -0,0 +1,169 @@
|
||||
from collections import defaultdict
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple
|
||||
|
||||
from ...extras.constants import IGNORE_INDEX
|
||||
from ...extras.logging import get_logger
|
||||
from .processor_utils import get_paligemma_token_type_ids, get_pixel_values, greedy_knapsack
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from transformers import ProcessorMixin
|
||||
from transformers.tokenization_utils import PreTrainedTokenizer
|
||||
|
||||
from ...hparams import DataArguments
|
||||
from ..template import Template
|
||||
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
def _encode_supervised_example(
|
||||
prompt: Sequence[Dict[str, str]],
|
||||
response: Sequence[Dict[str, str]],
|
||||
system: Optional[str],
|
||||
tools: Optional[str],
|
||||
template: "Template",
|
||||
tokenizer: "PreTrainedTokenizer",
|
||||
processor: Optional["ProcessorMixin"],
|
||||
data_args: "DataArguments",
|
||||
) -> Tuple[List[int], List[int]]:
|
||||
if processor is not None and not hasattr(processor, "image_seq_length"): # llava-like models
|
||||
prompt[0]["content"] = template.image_token + prompt[0]["content"]
|
||||
|
||||
messages = prompt + response
|
||||
input_ids, labels = [], []
|
||||
|
||||
if processor is not None and hasattr(processor, "image_seq_length"): # paligemma models
|
||||
image_token_id = tokenizer.convert_tokens_to_ids(template.image_token)
|
||||
input_ids += [image_token_id] * getattr(processor, "image_seq_length")
|
||||
labels += [IGNORE_INDEX] * getattr(processor, "image_seq_length")
|
||||
|
||||
encoded_pairs = template.encode_multiturn(
|
||||
tokenizer, messages, system, tools, data_args.cutoff_len, data_args.reserved_label_len
|
||||
)
|
||||
for turn_idx, (source_ids, target_ids) in enumerate(encoded_pairs):
|
||||
if data_args.train_on_prompt:
|
||||
source_mask = source_ids
|
||||
elif turn_idx != 0 and template.efficient_eos:
|
||||
source_mask = [tokenizer.eos_token_id] + [IGNORE_INDEX] * (len(source_ids) - 1)
|
||||
else:
|
||||
source_mask = [IGNORE_INDEX] * len(source_ids)
|
||||
|
||||
input_ids += source_ids + target_ids
|
||||
labels += source_mask + target_ids
|
||||
|
||||
if template.efficient_eos:
|
||||
input_ids += [tokenizer.eos_token_id]
|
||||
labels += [tokenizer.eos_token_id]
|
||||
|
||||
return input_ids, labels
|
||||
|
||||
|
||||
def preprocess_supervised_dataset(
|
||||
examples: Dict[str, List[Any]],
|
||||
template: "Template",
|
||||
tokenizer: "PreTrainedTokenizer",
|
||||
processor: Optional["ProcessorMixin"],
|
||||
data_args: "DataArguments",
|
||||
) -> Dict[str, List[List[int]]]:
|
||||
# build inputs with format `<bos> X Y <eos>` and labels with format `<ignore> ... <ignore> Y <eos>`
|
||||
# for multiturn examples, we only mask the prompt part in each prompt-response pair.
|
||||
model_inputs = {"input_ids": [], "attention_mask": [], "labels": []}
|
||||
if processor is not None:
|
||||
model_inputs["pixel_values"] = []
|
||||
if hasattr(processor, "image_seq_length"): # paligemma models
|
||||
model_inputs["token_type_ids"] = []
|
||||
|
||||
for i in range(len(examples["prompt"])):
|
||||
if len(examples["prompt"][i]) % 2 != 1 or len(examples["response"][i]) != 1:
|
||||
logger.warning("Dropped invalid example: {}".format(examples["prompt"][i] + examples["response"][i]))
|
||||
continue
|
||||
|
||||
input_ids, labels = _encode_supervised_example(
|
||||
prompt=examples["prompt"][i],
|
||||
response=examples["response"][i],
|
||||
system=examples["system"][i],
|
||||
tools=examples["tools"][i],
|
||||
template=template,
|
||||
tokenizer=tokenizer,
|
||||
processor=processor,
|
||||
data_args=data_args,
|
||||
)
|
||||
model_inputs["input_ids"].append(input_ids)
|
||||
model_inputs["attention_mask"].append([1] * len(input_ids))
|
||||
model_inputs["labels"].append(labels)
|
||||
if processor is not None:
|
||||
model_inputs["pixel_values"].append(get_pixel_values(examples["images"][i], processor))
|
||||
if hasattr(processor, "image_seq_length"): # paligemma models
|
||||
model_inputs["token_type_ids"].append(get_paligemma_token_type_ids(len(input_ids), processor))
|
||||
|
||||
return model_inputs
|
||||
|
||||
|
||||
def preprocess_packed_supervised_dataset(
|
||||
examples: Dict[str, List[Any]],
|
||||
template: "Template",
|
||||
tokenizer: "PreTrainedTokenizer",
|
||||
data_args: "DataArguments",
|
||||
) -> Dict[str, List[List[int]]]:
|
||||
# build inputs with format `<bos> X1 Y1 <eos> <bos> X2 Y2 <eos>`
|
||||
# and labels with format `<ignore> ... <ignore> Y1 <eos> <ignore> ... <ignore> Y2 <eos>`
|
||||
valid_num = 0
|
||||
batch_input_ids, batch_labels = [], []
|
||||
lengths = []
|
||||
length2indexes = defaultdict(list)
|
||||
for i in range(len(examples["prompt"])):
|
||||
if len(examples["prompt"][i]) % 2 != 1 or len(examples["response"][i]) != 1:
|
||||
logger.warning("Dropped invalid example: {}".format(examples["prompt"][i] + examples["response"][i]))
|
||||
continue
|
||||
|
||||
input_ids, labels = _encode_supervised_example(
|
||||
prompt=examples["prompt"][i],
|
||||
response=examples["response"][i],
|
||||
system=examples["system"][i],
|
||||
tools=examples["tools"][i],
|
||||
template=template,
|
||||
tokenizer=tokenizer,
|
||||
processor=None,
|
||||
data_args=data_args,
|
||||
)
|
||||
length = len(input_ids)
|
||||
if length > data_args.cutoff_len:
|
||||
logger.warning("Dropped lengthy example with length {} > {}.".format(length, data_args.cutoff_len))
|
||||
else:
|
||||
lengths.append(length)
|
||||
length2indexes[length].append(valid_num)
|
||||
batch_input_ids.append(input_ids)
|
||||
batch_labels.append(labels)
|
||||
valid_num += 1
|
||||
|
||||
model_inputs = {"input_ids": [], "attention_mask": [], "labels": []}
|
||||
knapsacks = greedy_knapsack(lengths, data_args.cutoff_len)
|
||||
for knapsack in knapsacks:
|
||||
packed_input_ids, packed_labels = [], []
|
||||
for length in knapsack:
|
||||
index = length2indexes[length].pop()
|
||||
packed_input_ids += batch_input_ids[index]
|
||||
packed_labels += batch_labels[index]
|
||||
|
||||
if len(packed_input_ids) < data_args.cutoff_len:
|
||||
pad_length = data_args.cutoff_len - len(packed_input_ids)
|
||||
packed_input_ids += [tokenizer.pad_token_id] * pad_length
|
||||
packed_labels += [IGNORE_INDEX] * pad_length
|
||||
|
||||
if len(packed_input_ids) != data_args.cutoff_len:
|
||||
raise ValueError("The length of packed example should be identical to the cutoff length.")
|
||||
|
||||
model_inputs["input_ids"].append(packed_input_ids)
|
||||
model_inputs["attention_mask"].append([1] * data_args.cutoff_len)
|
||||
model_inputs["labels"].append(packed_labels)
|
||||
|
||||
return model_inputs
|
||||
|
||||
|
||||
def print_supervised_dataset_example(example: Dict[str, List[int]], tokenizer: "PreTrainedTokenizer") -> None:
|
||||
valid_labels = list(filter(lambda x: x != IGNORE_INDEX, example["labels"]))
|
||||
print("input_ids:\n{}".format(example["input_ids"]))
|
||||
print("inputs:\n{}".format(tokenizer.decode(example["input_ids"], skip_special_tokens=False)))
|
||||
print("label_ids:\n{}".format(example["labels"]))
|
||||
print("labels:\n{}".format(tokenizer.decode(valid_labels, skip_special_tokens=False)))
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user