Compare commits
76 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
f0bff18324 | ||
|
|
b631bdc5b7 | ||
|
|
c65f7e9bd5 | ||
|
|
3e0fa4a8da | ||
|
|
235ed85b0f | ||
|
|
1ca639a777 | ||
|
|
e36a994fe6 | ||
|
|
19ffcfea76 | ||
|
|
85f3a09c83 | ||
|
|
60b9a9c1fa | ||
|
|
984e38575c | ||
|
|
665df5d733 | ||
|
|
4bc0bea0e9 | ||
|
|
5cfa342f01 | ||
|
|
c106cc24e4 | ||
|
|
372da52d4a | ||
|
|
875270b851 | ||
|
|
43fab306b6 | ||
|
|
77242f4169 | ||
|
|
60d9896a70 | ||
|
|
485a80d294 | ||
|
|
63bfe9967e | ||
|
|
a720b82e63 | ||
|
|
d3b0048d8c | ||
|
|
9a0aca42a5 | ||
|
|
5e802b0645 | ||
|
|
ca67b7a568 | ||
|
|
76cd879c84 | ||
|
|
e0c049e590 | ||
|
|
727943f078 | ||
|
|
8393b08666 | ||
|
|
9049f72d2f | ||
|
|
32f45c9e91 | ||
|
|
05f3a3c944 | ||
|
|
14f7bfc545 | ||
|
|
7f90b0cd20 | ||
|
|
308abfec6c | ||
|
|
bb88536166 | ||
|
|
d2df3f2d6e | ||
|
|
2abfad9c1f | ||
|
|
2af932d969 | ||
|
|
c29fa61a9c | ||
|
|
a30931fe0f | ||
|
|
3ff9b87012 | ||
|
|
f4f315fd11 | ||
|
|
530165d9a5 | ||
|
|
dbd1458adf | ||
|
|
dedefecd2b | ||
|
|
46f441dd37 | ||
|
|
49b58fd6af | ||
|
|
103a507b39 | ||
|
|
0a75224f62 | ||
|
|
04d7629abf | ||
|
|
1b6786a21f | ||
|
|
5080f2314c | ||
|
|
41beb7f0a3 | ||
|
|
799873aa14 | ||
|
|
fe2c7eaa93 | ||
|
|
6392d45ea7 | ||
|
|
c60ea675d7 | ||
|
|
16c7c92396 | ||
|
|
7598b37543 | ||
|
|
cc9717e2f2 | ||
|
|
08f2f99f4b | ||
|
|
77bf3d66c7 | ||
|
|
f14f67f803 | ||
|
|
820b6e7b32 | ||
|
|
27aece94cf | ||
|
|
3f2508be92 | ||
|
|
fce11bb386 | ||
|
|
2723438531 | ||
|
|
f330b73682 | ||
|
|
bc04ca464a | ||
|
|
44829df762 | ||
|
|
94ddfa66c0 | ||
|
|
8db8ed5a41 |
2
.github/ISSUE_TEMPLATE/bug-report.yml
vendored
2
.github/ISSUE_TEMPLATE/bug-report.yml
vendored
@@ -38,7 +38,9 @@ body:
|
|||||||
请合理使用 Markdown 标签来格式化您的文本。
|
请合理使用 Markdown 标签来格式化您的文本。
|
||||||
|
|
||||||
placeholder: |
|
placeholder: |
|
||||||
|
```bash
|
||||||
llamafactory-cli train ...
|
llamafactory-cli train ...
|
||||||
|
```
|
||||||
|
|
||||||
- type: textarea
|
- type: textarea
|
||||||
id: expected-behavior
|
id: expected-behavior
|
||||||
|
|||||||
1
.github/PULL_REQUEST_TEMPLATE.md
vendored
1
.github/PULL_REQUEST_TEMPLATE.md
vendored
@@ -5,3 +5,4 @@ Fixes # (issue)
|
|||||||
## Before submitting
|
## Before submitting
|
||||||
|
|
||||||
- [ ] Did you read the [contributor guideline](https://github.com/hiyouga/LLaMA-Factory/blob/main/.github/CONTRIBUTING.md)?
|
- [ ] Did you read the [contributor guideline](https://github.com/hiyouga/LLaMA-Factory/blob/main/.github/CONTRIBUTING.md)?
|
||||||
|
- [ ] Did you write any new necessary tests?
|
||||||
|
|||||||
40
.github/workflows/publish.yml
vendored
Normal file
40
.github/workflows/publish.yml
vendored
Normal file
@@ -0,0 +1,40 @@
|
|||||||
|
name: publish
|
||||||
|
|
||||||
|
on:
|
||||||
|
release:
|
||||||
|
types:
|
||||||
|
- published
|
||||||
|
|
||||||
|
jobs:
|
||||||
|
publish:
|
||||||
|
name: Upload release to PyPI
|
||||||
|
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
|
||||||
|
environment:
|
||||||
|
name: release
|
||||||
|
url: https://pypi.org/p/llamafactory
|
||||||
|
|
||||||
|
permissions:
|
||||||
|
id-token: write
|
||||||
|
|
||||||
|
steps:
|
||||||
|
- name: Checkout
|
||||||
|
uses: actions/checkout@v4
|
||||||
|
|
||||||
|
- name: Set up Python
|
||||||
|
uses: actions/setup-python@v5
|
||||||
|
with:
|
||||||
|
python-version: "3.8"
|
||||||
|
|
||||||
|
- name: Install dependencies
|
||||||
|
run: |
|
||||||
|
python -m pip install --upgrade pip
|
||||||
|
python -m pip install build
|
||||||
|
|
||||||
|
- name: Build package
|
||||||
|
run: |
|
||||||
|
python -m build
|
||||||
|
|
||||||
|
- name: Publish package
|
||||||
|
uses: pypa/gh-action-pypi-publish@release/v1
|
||||||
2
.github/workflows/tests.yml
vendored
2
.github/workflows/tests.yml
vendored
@@ -9,8 +9,6 @@ on:
|
|||||||
- "requirements.txt"
|
- "requirements.txt"
|
||||||
- ".github/workflows/*.yml"
|
- ".github/workflows/*.yml"
|
||||||
pull_request:
|
pull_request:
|
||||||
types:
|
|
||||||
- review_requested
|
|
||||||
branches:
|
branches:
|
||||||
- main
|
- main
|
||||||
paths:
|
paths:
|
||||||
|
|||||||
@@ -32,7 +32,7 @@ RUN EXTRA_PACKAGES="metrics"; \
|
|||||||
EXTRA_PACKAGES="${EXTRA_PACKAGES},deepspeed"; \
|
EXTRA_PACKAGES="${EXTRA_PACKAGES},deepspeed"; \
|
||||||
fi; \
|
fi; \
|
||||||
pip install -e .[$EXTRA_PACKAGES] && \
|
pip install -e .[$EXTRA_PACKAGES] && \
|
||||||
pip uninstall -y transformer-engine
|
pip uninstall -y transformer-engine flash-attn
|
||||||
|
|
||||||
# Set up volumes
|
# Set up volumes
|
||||||
VOLUME [ "/root/.cache/huggingface/", "/app/data", "/app/output" ]
|
VOLUME [ "/root/.cache/huggingface/", "/app/data", "/app/output" ]
|
||||||
@@ -42,3 +42,6 @@ EXPOSE 7860
|
|||||||
|
|
||||||
# Expose port 8000 for the API service
|
# Expose port 8000 for the API service
|
||||||
EXPOSE 8000
|
EXPOSE 8000
|
||||||
|
|
||||||
|
# Launch LLaMA Board
|
||||||
|
CMD [ "llamafactory-cli", "webui" ]
|
||||||
|
|||||||
1
MANIFEST.in
Normal file
1
MANIFEST.in
Normal file
@@ -0,0 +1 @@
|
|||||||
|
include LICENSE requirements.txt
|
||||||
2
Makefile
2
Makefile
@@ -11,4 +11,4 @@ style:
|
|||||||
ruff format $(check_dirs)
|
ruff format $(check_dirs)
|
||||||
|
|
||||||
test:
|
test:
|
||||||
pytest tests/
|
CUDA_VISIBLE_DEVICES= pytest tests/
|
||||||
|
|||||||
83
README.md
83
README.md
@@ -49,7 +49,7 @@ Choose your path:
|
|||||||
- **Various models**: LLaMA, LLaVA, Mistral, Mixtral-MoE, Qwen, Yi, Gemma, Baichuan, ChatGLM, Phi, etc.
|
- **Various models**: LLaMA, LLaVA, Mistral, Mixtral-MoE, Qwen, Yi, Gemma, Baichuan, ChatGLM, Phi, etc.
|
||||||
- **Integrated methods**: (Continuous) pre-training, (multimodal) supervised fine-tuning, reward modeling, PPO, DPO, KTO, ORPO, etc.
|
- **Integrated methods**: (Continuous) pre-training, (multimodal) supervised fine-tuning, reward modeling, PPO, DPO, KTO, ORPO, etc.
|
||||||
- **Scalable resources**: 32-bit full-tuning, 16-bit freeze-tuning, 16-bit LoRA and 2/4/8-bit QLoRA via AQLM/AWQ/GPTQ/LLM.int8.
|
- **Scalable resources**: 32-bit full-tuning, 16-bit freeze-tuning, 16-bit LoRA and 2/4/8-bit QLoRA via AQLM/AWQ/GPTQ/LLM.int8.
|
||||||
- **Advanced algorithms**: GaLore, BAdam, DoRA, LongLoRA, LLaMA Pro, Mixture-of-Depths, LoRA+, LoftQ and Agent tuning.
|
- **Advanced algorithms**: GaLore, BAdam, DoRA, LongLoRA, LLaMA Pro, Mixture-of-Depths, LoRA+, LoftQ, PiSSA and Agent tuning.
|
||||||
- **Practical tricks**: FlashAttention-2, Unsloth, RoPE scaling, NEFTune and rsLoRA.
|
- **Practical tricks**: FlashAttention-2, Unsloth, RoPE scaling, NEFTune and rsLoRA.
|
||||||
- **Experiment monitors**: LlamaBoard, TensorBoard, Wandb, MLflow, etc.
|
- **Experiment monitors**: LlamaBoard, TensorBoard, Wandb, MLflow, etc.
|
||||||
- **Faster inference**: OpenAI-style API, Gradio UI and CLI with vLLM worker.
|
- **Faster inference**: OpenAI-style API, Gradio UI and CLI with vLLM worker.
|
||||||
@@ -71,9 +71,9 @@ Compared to ChatGLM's [P-Tuning](https://github.com/THUDM/ChatGLM2-6B/tree/main/
|
|||||||
|
|
||||||
## Changelog
|
## Changelog
|
||||||
|
|
||||||
[24/06/07] We supported fine-tuning the **[Qwen-2](https://qwenlm.github.io/blog/qwen2/)** series models.
|
[24/06/16] We support **[PiSSA](https://arxiv.org/abs/2404.02948)** algorithm. See [examples](examples/README.md) for usage.
|
||||||
|
|
||||||
[24/06/05] We supported fine-tuning the **[GLM-4-9B/GLM-4-9B-Chat](https://github.com/THUDM/GLM-4)** models.
|
[24/06/07] We supported fine-tuning the **[Qwen2](https://qwenlm.github.io/blog/qwen2/)** and **[GLM-4](https://github.com/THUDM/GLM-4)** models.
|
||||||
|
|
||||||
[24/05/26] We supported **[SimPO](https://arxiv.org/abs/2405.14734)** algorithm for preference learning. See [examples](examples/README.md) for usage.
|
[24/05/26] We supported **[SimPO](https://arxiv.org/abs/2405.14734)** algorithm for preference learning. See [examples](examples/README.md) for usage.
|
||||||
|
|
||||||
@@ -151,35 +151,35 @@ Compared to ChatGLM's [P-Tuning](https://github.com/THUDM/ChatGLM2-6B/tree/main/
|
|||||||
|
|
||||||
## Supported Models
|
## Supported Models
|
||||||
|
|
||||||
| Model | Model size | Template |
|
| Model | Model size | Template |
|
||||||
| -------------------------------------------------------- | -------------------------------- | --------- |
|
| --------------------------------------------------------- | -------------------------------- | --------- |
|
||||||
| [Baichuan2](https://huggingface.co/baichuan-inc) | 7B/13B | baichuan2 |
|
| [Baichuan2](https://huggingface.co/baichuan-inc) | 7B/13B | baichuan2 |
|
||||||
| [BLOOM](https://huggingface.co/bigscience) | 560M/1.1B/1.7B/3B/7.1B/176B | - |
|
| [BLOOM](https://huggingface.co/bigscience) | 560M/1.1B/1.7B/3B/7.1B/176B | - |
|
||||||
| [BLOOMZ](https://huggingface.co/bigscience) | 560M/1.1B/1.7B/3B/7.1B/176B | - |
|
| [BLOOMZ](https://huggingface.co/bigscience) | 560M/1.1B/1.7B/3B/7.1B/176B | - |
|
||||||
| [ChatGLM3](https://huggingface.co/THUDM) | 6B | chatglm3 |
|
| [ChatGLM3](https://huggingface.co/THUDM) | 6B | chatglm3 |
|
||||||
| [Command-R](https://huggingface.co/CohereForAI) | 35B/104B | cohere |
|
| [Command-R](https://huggingface.co/CohereForAI) | 35B/104B | cohere |
|
||||||
| [DeepSeek (MoE)](https://huggingface.co/deepseek-ai) | 7B/16B/67B/236B | deepseek |
|
| [DeepSeek (Code/MoE)](https://huggingface.co/deepseek-ai) | 7B/16B/67B/236B | deepseek |
|
||||||
| [Falcon](https://huggingface.co/tiiuae) | 7B/11B/40B/180B | falcon |
|
| [Falcon](https://huggingface.co/tiiuae) | 7B/11B/40B/180B | falcon |
|
||||||
| [Gemma/CodeGemma](https://huggingface.co/google) | 2B/7B | gemma |
|
| [Gemma/CodeGemma](https://huggingface.co/google) | 2B/7B | gemma |
|
||||||
| [GLM4](https://huggingface.co/THUDM) | 9B | glm4 |
|
| [GLM4](https://huggingface.co/THUDM) | 9B | glm4 |
|
||||||
| [InternLM2](https://huggingface.co/internlm) | 7B/20B | intern2 |
|
| [InternLM2](https://huggingface.co/internlm) | 7B/20B | intern2 |
|
||||||
| [LLaMA](https://github.com/facebookresearch/llama) | 7B/13B/33B/65B | - |
|
| [LLaMA](https://github.com/facebookresearch/llama) | 7B/13B/33B/65B | - |
|
||||||
| [LLaMA-2](https://huggingface.co/meta-llama) | 7B/13B/70B | llama2 |
|
| [LLaMA-2](https://huggingface.co/meta-llama) | 7B/13B/70B | llama2 |
|
||||||
| [LLaMA-3](https://huggingface.co/meta-llama) | 8B/70B | llama3 |
|
| [LLaMA-3](https://huggingface.co/meta-llama) | 8B/70B | llama3 |
|
||||||
| [LLaVA-1.5](https://huggingface.co/llava-hf) | 7B/13B | vicuna |
|
| [LLaVA-1.5](https://huggingface.co/llava-hf) | 7B/13B | vicuna |
|
||||||
| [Mistral/Mixtral](https://huggingface.co/mistralai) | 7B/8x7B/8x22B | mistral |
|
| [Mistral/Mixtral](https://huggingface.co/mistralai) | 7B/8x7B/8x22B | mistral |
|
||||||
| [OLMo](https://huggingface.co/allenai) | 1B/7B | - |
|
| [OLMo](https://huggingface.co/allenai) | 1B/7B | - |
|
||||||
| [PaliGemma](https://huggingface.co/google) | 3B | gemma |
|
| [PaliGemma](https://huggingface.co/google) | 3B | gemma |
|
||||||
| [Phi-1.5/2](https://huggingface.co/microsoft) | 1.3B/2.7B | - |
|
| [Phi-1.5/2](https://huggingface.co/microsoft) | 1.3B/2.7B | - |
|
||||||
| [Phi-3](https://huggingface.co/microsoft) | 4B/7B/14B | phi |
|
| [Phi-3](https://huggingface.co/microsoft) | 4B/7B/14B | phi |
|
||||||
| [Qwen](https://huggingface.co/Qwen) | 1.8B/7B/14B/72B | qwen |
|
| [Qwen](https://huggingface.co/Qwen) | 1.8B/7B/14B/72B | qwen |
|
||||||
| [Qwen1.5 (Code/MoE)](https://huggingface.co/Qwen) | 0.5B/1.8B/4B/7B/14B/32B/72B/110B | qwen |
|
| [Qwen1.5 (Code/MoE)](https://huggingface.co/Qwen) | 0.5B/1.8B/4B/7B/14B/32B/72B/110B | qwen |
|
||||||
| [Qwen2 (MoE)](https://huggingface.co/Qwen) | 0.5B/1.5B/7B/57B/72B | qwen |
|
| [Qwen2 (MoE)](https://huggingface.co/Qwen) | 0.5B/1.5B/7B/57B/72B | qwen |
|
||||||
| [StarCoder2](https://huggingface.co/bigcode) | 3B/7B/15B | - |
|
| [StarCoder2](https://huggingface.co/bigcode) | 3B/7B/15B | - |
|
||||||
| [XVERSE](https://huggingface.co/xverse) | 7B/13B/65B | xverse |
|
| [XVERSE](https://huggingface.co/xverse) | 7B/13B/65B | xverse |
|
||||||
| [Yi (1/1.5)](https://huggingface.co/01-ai) | 6B/9B/34B | yi |
|
| [Yi (1/1.5)](https://huggingface.co/01-ai) | 6B/9B/34B | yi |
|
||||||
| [Yi-VL](https://huggingface.co/01-ai) | 6B/34B | yi_vl |
|
| [Yi-VL](https://huggingface.co/01-ai) | 6B/34B | yi_vl |
|
||||||
| [Yuan](https://huggingface.co/IEITYuan) | 2B/51B/102B | yuan |
|
| [Yuan](https://huggingface.co/IEITYuan) | 2B/51B/102B | yuan |
|
||||||
|
|
||||||
> [!NOTE]
|
> [!NOTE]
|
||||||
> For the "base" models, the `template` argument can be chosen from `default`, `alpaca`, `vicuna` etc. But make sure to use the **corresponding template** for the "instruct/chat" models.
|
> For the "base" models, the `template` argument can be chosen from `default`, `alpaca`, `vicuna` etc. But make sure to use the **corresponding template** for the "instruct/chat" models.
|
||||||
@@ -259,6 +259,9 @@ You also can add a custom chat template to [template.py](src/llamafactory/data/t
|
|||||||
- [Cosmopedia (en)](https://huggingface.co/datasets/HuggingFaceTB/cosmopedia)
|
- [Cosmopedia (en)](https://huggingface.co/datasets/HuggingFaceTB/cosmopedia)
|
||||||
- [STEM (zh)](https://huggingface.co/datasets/hfl/stem_zh_instruction)
|
- [STEM (zh)](https://huggingface.co/datasets/hfl/stem_zh_instruction)
|
||||||
- [Ruozhiba (zh)](https://huggingface.co/datasets/hfl/ruozhiba_gpt4_turbo)
|
- [Ruozhiba (zh)](https://huggingface.co/datasets/hfl/ruozhiba_gpt4_turbo)
|
||||||
|
- [Neo-sft (zh)](https://huggingface.co/datasets/m-a-p/neo_sft_phase2)
|
||||||
|
- [WebInstructSub (en)](https://huggingface.co/datasets/TIGER-Lab/WebInstructSub)
|
||||||
|
- [Magpie-Pro-300K-Filtered (en)](https://huggingface.co/datasets/Magpie-Align/Magpie-Pro-300K-Filtered)
|
||||||
- [LLaVA mixed (en&zh)](https://huggingface.co/datasets/BUAADreamer/llava-en-zh-300k)
|
- [LLaVA mixed (en&zh)](https://huggingface.co/datasets/BUAADreamer/llava-en-zh-300k)
|
||||||
- [Open Assistant (de)](https://huggingface.co/datasets/mayflowergmbh/oasst_de)
|
- [Open Assistant (de)](https://huggingface.co/datasets/mayflowergmbh/oasst_de)
|
||||||
- [Dolly 15k (de)](https://huggingface.co/datasets/mayflowergmbh/dolly-15k_de)
|
- [Dolly 15k (de)](https://huggingface.co/datasets/mayflowergmbh/dolly-15k_de)
|
||||||
@@ -405,7 +408,7 @@ Please refer to [data/README.md](data/README.md) for checking the details about
|
|||||||
Use the following 3 commands to run LoRA **fine-tuning**, **inference** and **merging** of the Llama3-8B-Instruct model, respectively.
|
Use the following 3 commands to run LoRA **fine-tuning**, **inference** and **merging** of the Llama3-8B-Instruct model, respectively.
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
llamafactory-cli train examples/lora_single_gpu/llama3_lora_sft.yaml
|
llamafactory-cli train examples/train_lora/llama3_lora_sft.yaml
|
||||||
llamafactory-cli chat examples/inference/llama3_lora_sft.yaml
|
llamafactory-cli chat examples/inference/llama3_lora_sft.yaml
|
||||||
llamafactory-cli export examples/merge_lora/llama3_lora_sft.yaml
|
llamafactory-cli export examples/merge_lora/llama3_lora_sft.yaml
|
||||||
```
|
```
|
||||||
@@ -423,6 +426,8 @@ llamafactory-cli webui
|
|||||||
|
|
||||||
### Build Docker
|
### Build Docker
|
||||||
|
|
||||||
|
#### Use Docker
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
docker build -f ./Dockerfile \
|
docker build -f ./Dockerfile \
|
||||||
--build-arg INSTALL_BNB=false \
|
--build-arg INSTALL_BNB=false \
|
||||||
@@ -442,8 +447,12 @@ docker run -it --gpus=all \
|
|||||||
llamafactory:latest
|
llamafactory:latest
|
||||||
```
|
```
|
||||||
|
|
||||||
> [!TIP]
|
#### Use Docker Compose
|
||||||
> Use Docker Compose to build image via `docker compose up -d`.
|
|
||||||
|
```bash
|
||||||
|
docker-compose up -d
|
||||||
|
docker-compose exec llamafactory bash
|
||||||
|
```
|
||||||
|
|
||||||
<details><summary>Details about volume</summary>
|
<details><summary>Details about volume</summary>
|
||||||
|
|
||||||
@@ -456,7 +465,7 @@ docker run -it --gpus=all \
|
|||||||
### Deploy with OpenAI-style API and vLLM
|
### Deploy with OpenAI-style API and vLLM
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
CUDA_VISIBLE_DEVICES=0,1 API_PORT=8000 llamafactory-cli api examples/inference/llama3_vllm.yaml
|
API_PORT=8000 llamafactory-cli api examples/inference/llama3_vllm.yaml
|
||||||
```
|
```
|
||||||
|
|
||||||
> [!TIP]
|
> [!TIP]
|
||||||
@@ -474,7 +483,7 @@ Train the model by specifying a model ID of the ModelScope Hub as the `model_nam
|
|||||||
|
|
||||||
### Use W&B Logger
|
### Use W&B Logger
|
||||||
|
|
||||||
To use [Weights & Biases](https://wandb.ai) for logging experimental results, you need to add the following arguments.
|
To use [Weights & Biases](https://wandb.ai) for logging experimental results, you need to add the following arguments to yaml files.
|
||||||
|
|
||||||
```yaml
|
```yaml
|
||||||
report_to: wandb
|
report_to: wandb
|
||||||
|
|||||||
83
README_zh.md
83
README_zh.md
@@ -49,7 +49,7 @@ https://github.com/hiyouga/LLaMA-Factory/assets/16256802/ec36a9dd-37f4-4f72-81bd
|
|||||||
- **多种模型**:LLaMA、LLaVA、Mistral、Mixtral-MoE、Qwen、Yi、Gemma、Baichuan、ChatGLM、Phi 等等。
|
- **多种模型**:LLaMA、LLaVA、Mistral、Mixtral-MoE、Qwen、Yi、Gemma、Baichuan、ChatGLM、Phi 等等。
|
||||||
- **集成方法**:(增量)预训练、(多模态)指令监督微调、奖励模型训练、PPO 训练、DPO 训练、KTO 训练、ORPO 训练等等。
|
- **集成方法**:(增量)预训练、(多模态)指令监督微调、奖励模型训练、PPO 训练、DPO 训练、KTO 训练、ORPO 训练等等。
|
||||||
- **多种精度**:32 比特全参数微调、16 比特冻结微调、16 比特 LoRA 微调和基于 AQLM/AWQ/GPTQ/LLM.int8 的 2/4/8 比特 QLoRA 微调。
|
- **多种精度**:32 比特全参数微调、16 比特冻结微调、16 比特 LoRA 微调和基于 AQLM/AWQ/GPTQ/LLM.int8 的 2/4/8 比特 QLoRA 微调。
|
||||||
- **先进算法**:GaLore、BAdam、DoRA、LongLoRA、LLaMA Pro、Mixture-of-Depths、LoRA+、LoftQ 和 Agent 微调。
|
- **先进算法**:GaLore、BAdam、DoRA、LongLoRA、LLaMA Pro、Mixture-of-Depths、LoRA+、LoftQ、PiSSA 和 Agent 微调。
|
||||||
- **实用技巧**:FlashAttention-2、Unsloth、RoPE scaling、NEFTune 和 rsLoRA。
|
- **实用技巧**:FlashAttention-2、Unsloth、RoPE scaling、NEFTune 和 rsLoRA。
|
||||||
- **实验监控**:LlamaBoard、TensorBoard、Wandb、MLflow 等等。
|
- **实验监控**:LlamaBoard、TensorBoard、Wandb、MLflow 等等。
|
||||||
- **极速推理**:基于 vLLM 的 OpenAI 风格 API、浏览器界面和命令行接口。
|
- **极速推理**:基于 vLLM 的 OpenAI 风格 API、浏览器界面和命令行接口。
|
||||||
@@ -71,9 +71,9 @@ https://github.com/hiyouga/LLaMA-Factory/assets/16256802/ec36a9dd-37f4-4f72-81bd
|
|||||||
|
|
||||||
## 更新日志
|
## 更新日志
|
||||||
|
|
||||||
[24/06/07] 我们支持了 **[Qwen-2](https://qwenlm.github.io/blog/qwen2/)** 系列模型的微调。
|
[24/06/16] 我们支持了 **[PiSSA](https://arxiv.org/abs/2404.02948)** 算法。详细用法请参照 [examples](examples/README_zh.md)。
|
||||||
|
|
||||||
[24/06/05] 我们支持了 **[GLM-4-9B/GLM-4-9B-Chat](https://github.com/THUDM/GLM-4)** 模型的微调。
|
[24/06/07] 我们支持了 **[Qwen2](https://qwenlm.github.io/blog/qwen2/)** 和 **[GLM-4](https://github.com/THUDM/GLM-4)** 模型的微调。
|
||||||
|
|
||||||
[24/05/26] 我们支持了 **[SimPO](https://arxiv.org/abs/2405.14734)** 偏好对齐算法。详细用法请参照 [examples](examples/README_zh.md)。
|
[24/05/26] 我们支持了 **[SimPO](https://arxiv.org/abs/2405.14734)** 偏好对齐算法。详细用法请参照 [examples](examples/README_zh.md)。
|
||||||
|
|
||||||
@@ -151,35 +151,35 @@ https://github.com/hiyouga/LLaMA-Factory/assets/16256802/ec36a9dd-37f4-4f72-81bd
|
|||||||
|
|
||||||
## 模型
|
## 模型
|
||||||
|
|
||||||
| 模型名 | 模型大小 | Template |
|
| 模型名 | 模型大小 | Template |
|
||||||
| -------------------------------------------------------- | -------------------------------- | --------- |
|
| --------------------------------------------------------- | -------------------------------- | --------- |
|
||||||
| [Baichuan2](https://huggingface.co/baichuan-inc) | 7B/13B | baichuan2 |
|
| [Baichuan2](https://huggingface.co/baichuan-inc) | 7B/13B | baichuan2 |
|
||||||
| [BLOOM](https://huggingface.co/bigscience) | 560M/1.1B/1.7B/3B/7.1B/176B | - |
|
| [BLOOM](https://huggingface.co/bigscience) | 560M/1.1B/1.7B/3B/7.1B/176B | - |
|
||||||
| [BLOOMZ](https://huggingface.co/bigscience) | 560M/1.1B/1.7B/3B/7.1B/176B | - |
|
| [BLOOMZ](https://huggingface.co/bigscience) | 560M/1.1B/1.7B/3B/7.1B/176B | - |
|
||||||
| [ChatGLM3](https://huggingface.co/THUDM) | 6B | chatglm3 |
|
| [ChatGLM3](https://huggingface.co/THUDM) | 6B | chatglm3 |
|
||||||
| [Command-R](https://huggingface.co/CohereForAI) | 35B/104B | cohere |
|
| [Command-R](https://huggingface.co/CohereForAI) | 35B/104B | cohere |
|
||||||
| [DeepSeek (MoE)](https://huggingface.co/deepseek-ai) | 7B/16B/67B/236B | deepseek |
|
| [DeepSeek (Code/MoE)](https://huggingface.co/deepseek-ai) | 7B/16B/67B/236B | deepseek |
|
||||||
| [Falcon](https://huggingface.co/tiiuae) | 7B/11B/40B/180B | falcon |
|
| [Falcon](https://huggingface.co/tiiuae) | 7B/11B/40B/180B | falcon |
|
||||||
| [Gemma/CodeGemma](https://huggingface.co/google) | 2B/7B | gemma |
|
| [Gemma/CodeGemma](https://huggingface.co/google) | 2B/7B | gemma |
|
||||||
| [GLM4](https://huggingface.co/THUDM) | 9B | glm4 |
|
| [GLM4](https://huggingface.co/THUDM) | 9B | glm4 |
|
||||||
| [InternLM2](https://huggingface.co/internlm) | 7B/20B | intern2 |
|
| [InternLM2](https://huggingface.co/internlm) | 7B/20B | intern2 |
|
||||||
| [LLaMA](https://github.com/facebookresearch/llama) | 7B/13B/33B/65B | - |
|
| [LLaMA](https://github.com/facebookresearch/llama) | 7B/13B/33B/65B | - |
|
||||||
| [LLaMA-2](https://huggingface.co/meta-llama) | 7B/13B/70B | llama2 |
|
| [LLaMA-2](https://huggingface.co/meta-llama) | 7B/13B/70B | llama2 |
|
||||||
| [LLaMA-3](https://huggingface.co/meta-llama) | 8B/70B | llama3 |
|
| [LLaMA-3](https://huggingface.co/meta-llama) | 8B/70B | llama3 |
|
||||||
| [LLaVA-1.5](https://huggingface.co/llava-hf) | 7B/13B | vicuna |
|
| [LLaVA-1.5](https://huggingface.co/llava-hf) | 7B/13B | vicuna |
|
||||||
| [Mistral/Mixtral](https://huggingface.co/mistralai) | 7B/8x7B/8x22B | mistral |
|
| [Mistral/Mixtral](https://huggingface.co/mistralai) | 7B/8x7B/8x22B | mistral |
|
||||||
| [OLMo](https://huggingface.co/allenai) | 1B/7B | - |
|
| [OLMo](https://huggingface.co/allenai) | 1B/7B | - |
|
||||||
| [PaliGemma](https://huggingface.co/google) | 3B | gemma |
|
| [PaliGemma](https://huggingface.co/google) | 3B | gemma |
|
||||||
| [Phi-1.5/2](https://huggingface.co/microsoft) | 1.3B/2.7B | - |
|
| [Phi-1.5/2](https://huggingface.co/microsoft) | 1.3B/2.7B | - |
|
||||||
| [Phi-3](https://huggingface.co/microsoft) | 4B/7B/14B | phi |
|
| [Phi-3](https://huggingface.co/microsoft) | 4B/7B/14B | phi |
|
||||||
| [Qwen](https://huggingface.co/Qwen) | 1.8B/7B/14B/72B | qwen |
|
| [Qwen](https://huggingface.co/Qwen) | 1.8B/7B/14B/72B | qwen |
|
||||||
| [Qwen1.5 (Code/MoE)](https://huggingface.co/Qwen) | 0.5B/1.8B/4B/7B/14B/32B/72B/110B | qwen |
|
| [Qwen1.5 (Code/MoE)](https://huggingface.co/Qwen) | 0.5B/1.8B/4B/7B/14B/32B/72B/110B | qwen |
|
||||||
| [Qwen2 (MoE)](https://huggingface.co/Qwen) | 0.5B/1.5B/7B/57B/72B | qwen |
|
| [Qwen2 (MoE)](https://huggingface.co/Qwen) | 0.5B/1.5B/7B/57B/72B | qwen |
|
||||||
| [StarCoder2](https://huggingface.co/bigcode) | 3B/7B/15B | - |
|
| [StarCoder2](https://huggingface.co/bigcode) | 3B/7B/15B | - |
|
||||||
| [XVERSE](https://huggingface.co/xverse) | 7B/13B/65B | xverse |
|
| [XVERSE](https://huggingface.co/xverse) | 7B/13B/65B | xverse |
|
||||||
| [Yi (1/1.5)](https://huggingface.co/01-ai) | 6B/9B/34B | yi |
|
| [Yi (1/1.5)](https://huggingface.co/01-ai) | 6B/9B/34B | yi |
|
||||||
| [Yi-VL](https://huggingface.co/01-ai) | 6B/34B | yi_vl |
|
| [Yi-VL](https://huggingface.co/01-ai) | 6B/34B | yi_vl |
|
||||||
| [Yuan](https://huggingface.co/IEITYuan) | 2B/51B/102B | yuan |
|
| [Yuan](https://huggingface.co/IEITYuan) | 2B/51B/102B | yuan |
|
||||||
|
|
||||||
> [!NOTE]
|
> [!NOTE]
|
||||||
> 对于所有“基座”(Base)模型,`template` 参数可以是 `default`, `alpaca`, `vicuna` 等任意值。但“对话”(Instruct/Chat)模型请务必使用**对应的模板**。
|
> 对于所有“基座”(Base)模型,`template` 参数可以是 `default`, `alpaca`, `vicuna` 等任意值。但“对话”(Instruct/Chat)模型请务必使用**对应的模板**。
|
||||||
@@ -259,6 +259,9 @@ https://github.com/hiyouga/LLaMA-Factory/assets/16256802/ec36a9dd-37f4-4f72-81bd
|
|||||||
- [Cosmopedia (en)](https://huggingface.co/datasets/HuggingFaceTB/cosmopedia)
|
- [Cosmopedia (en)](https://huggingface.co/datasets/HuggingFaceTB/cosmopedia)
|
||||||
- [STEM (zh)](https://huggingface.co/datasets/hfl/stem_zh_instruction)
|
- [STEM (zh)](https://huggingface.co/datasets/hfl/stem_zh_instruction)
|
||||||
- [Ruozhiba (zh)](https://huggingface.co/datasets/hfl/ruozhiba_gpt4_turbo)
|
- [Ruozhiba (zh)](https://huggingface.co/datasets/hfl/ruozhiba_gpt4_turbo)
|
||||||
|
- [Neo-sft (zh)](https://huggingface.co/datasets/m-a-p/neo_sft_phase2)
|
||||||
|
- [WebInstructSub (en)](https://huggingface.co/datasets/TIGER-Lab/WebInstructSub)
|
||||||
|
- [Magpie-Pro-300K-Filtered (en)](https://huggingface.co/datasets/Magpie-Align/Magpie-Pro-300K-Filtered)
|
||||||
- [LLaVA mixed (en&zh)](https://huggingface.co/datasets/BUAADreamer/llava-en-zh-300k)
|
- [LLaVA mixed (en&zh)](https://huggingface.co/datasets/BUAADreamer/llava-en-zh-300k)
|
||||||
- [Open Assistant (de)](https://huggingface.co/datasets/mayflowergmbh/oasst_de)
|
- [Open Assistant (de)](https://huggingface.co/datasets/mayflowergmbh/oasst_de)
|
||||||
- [Dolly 15k (de)](https://huggingface.co/datasets/mayflowergmbh/dolly-15k_de)
|
- [Dolly 15k (de)](https://huggingface.co/datasets/mayflowergmbh/dolly-15k_de)
|
||||||
@@ -405,7 +408,7 @@ Docker 镜像:
|
|||||||
下面三行命令分别对 Llama3-8B-Instruct 模型进行 LoRA **微调**、**推理**和**合并**。
|
下面三行命令分别对 Llama3-8B-Instruct 模型进行 LoRA **微调**、**推理**和**合并**。
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
llamafactory-cli train examples/lora_single_gpu/llama3_lora_sft.yaml
|
llamafactory-cli train examples/train_lora/llama3_lora_sft.yaml
|
||||||
llamafactory-cli chat examples/inference/llama3_lora_sft.yaml
|
llamafactory-cli chat examples/inference/llama3_lora_sft.yaml
|
||||||
llamafactory-cli export examples/merge_lora/llama3_lora_sft.yaml
|
llamafactory-cli export examples/merge_lora/llama3_lora_sft.yaml
|
||||||
```
|
```
|
||||||
@@ -423,6 +426,8 @@ llamafactory-cli webui
|
|||||||
|
|
||||||
### 构建 Docker
|
### 构建 Docker
|
||||||
|
|
||||||
|
#### 使用 Docker
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
docker build -f ./Dockerfile \
|
docker build -f ./Dockerfile \
|
||||||
--build-arg INSTALL_BNB=false \
|
--build-arg INSTALL_BNB=false \
|
||||||
@@ -442,8 +447,12 @@ docker run -it --gpus=all \
|
|||||||
llamafactory:latest
|
llamafactory:latest
|
||||||
```
|
```
|
||||||
|
|
||||||
> [!TIP]
|
#### 使用 Docker Compose
|
||||||
> 通过 `docker compose up -d` 使用 Docker Compose 构建镜像。
|
|
||||||
|
```bash
|
||||||
|
docker-compose up -d
|
||||||
|
docker-compose exec llamafactory bash
|
||||||
|
```
|
||||||
|
|
||||||
<details><summary>数据卷详情</summary>
|
<details><summary>数据卷详情</summary>
|
||||||
|
|
||||||
@@ -456,7 +465,7 @@ docker run -it --gpus=all \
|
|||||||
### 利用 vLLM 部署 OpenAI API
|
### 利用 vLLM 部署 OpenAI API
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
CUDA_VISIBLE_DEVICES=0,1 API_PORT=8000 llamafactory-cli api examples/inference/llama3_vllm.yaml
|
API_PORT=8000 llamafactory-cli api examples/inference/llama3_vllm.yaml
|
||||||
```
|
```
|
||||||
|
|
||||||
> [!TIP]
|
> [!TIP]
|
||||||
@@ -474,7 +483,7 @@ export USE_MODELSCOPE_HUB=1 # Windows 使用 `set USE_MODELSCOPE_HUB=1`
|
|||||||
|
|
||||||
### 使用 W&B 面板
|
### 使用 W&B 面板
|
||||||
|
|
||||||
若要使用 [Weights & Biases](https://wandb.ai) 记录实验数据,请添加下面的参数。
|
若要使用 [Weights & Biases](https://wandb.ai) 记录实验数据,请在 yaml 文件中添加下面的参数。
|
||||||
|
|
||||||
```yaml
|
```yaml
|
||||||
report_to: wandb
|
report_to: wandb
|
||||||
|
|||||||
@@ -1,5 +1,3 @@
|
|||||||
version: '3.8'
|
|
||||||
|
|
||||||
services:
|
services:
|
||||||
llamafactory:
|
llamafactory:
|
||||||
build:
|
build:
|
||||||
@@ -19,6 +17,9 @@ services:
|
|||||||
- "7860:7860"
|
- "7860:7860"
|
||||||
- "8000:8000"
|
- "8000:8000"
|
||||||
ipc: host
|
ipc: host
|
||||||
|
tty: true
|
||||||
|
stdin_open: true
|
||||||
|
command: bash
|
||||||
deploy:
|
deploy:
|
||||||
resources:
|
resources:
|
||||||
reservations:
|
reservations:
|
||||||
|
|||||||
@@ -11,6 +11,7 @@
|
|||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
import os
|
import os
|
||||||
|
|
||||||
import datasets
|
import datasets
|
||||||
|
|||||||
@@ -11,6 +11,7 @@
|
|||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
import os
|
import os
|
||||||
|
|
||||||
import datasets
|
import datasets
|
||||||
|
|||||||
@@ -11,6 +11,7 @@
|
|||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
import os
|
import os
|
||||||
|
|
||||||
import datasets
|
import datasets
|
||||||
|
|||||||
@@ -4,59 +4,59 @@ Make sure to execute these commands in the `LLaMA-Factory` directory.
|
|||||||
|
|
||||||
## Table of Contents
|
## Table of Contents
|
||||||
|
|
||||||
- [LoRA Fine-Tuning on A Single GPU](#lora-fine-tuning-on-a-single-gpu)
|
- [LoRA Fine-Tuning](#lora-fine-tuning)
|
||||||
- [QLoRA Fine-Tuning on a Single GPU](#qlora-fine-tuning-on-a-single-gpu)
|
- [QLoRA Fine-Tuning](#qlora-fine-tuning)
|
||||||
- [LoRA Fine-Tuning on Multiple GPUs](#lora-fine-tuning-on-multiple-gpus)
|
- [Full-Parameter Fine-Tuning](#full-parameter-fine-tuning)
|
||||||
- [LoRA Fine-Tuning on Multiple NPUs](#lora-fine-tuning-on-multiple-npus)
|
|
||||||
- [Full-Parameter Fine-Tuning on Multiple GPUs](#full-parameter-fine-tuning-on-multiple-gpus)
|
|
||||||
- [Merging LoRA Adapters and Quantization](#merging-lora-adapters-and-quantization)
|
- [Merging LoRA Adapters and Quantization](#merging-lora-adapters-and-quantization)
|
||||||
- [Inferring LoRA Fine-Tuned Models](#inferring-lora-fine-tuned-models)
|
- [Inferring LoRA Fine-Tuned Models](#inferring-lora-fine-tuned-models)
|
||||||
- [Extras](#extras)
|
- [Extras](#extras)
|
||||||
|
|
||||||
|
Use `CUDA_VISIBLE_DEVICES` (GPU) or `ASCEND_RT_VISIBLE_DEVICES` (NPU) to choose computing devices.
|
||||||
|
|
||||||
## Examples
|
## Examples
|
||||||
|
|
||||||
### LoRA Fine-Tuning on A Single GPU
|
### LoRA Fine-Tuning
|
||||||
|
|
||||||
#### (Continuous) Pre-Training
|
#### (Continuous) Pre-Training
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/lora_single_gpu/llama3_lora_pretrain.yaml
|
llamafactory-cli train examples/train_lora/llama3_lora_pretrain.yaml
|
||||||
```
|
```
|
||||||
|
|
||||||
#### Supervised Fine-Tuning
|
#### Supervised Fine-Tuning
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/lora_single_gpu/llama3_lora_sft.yaml
|
llamafactory-cli train examples/train_lora/llama3_lora_sft.yaml
|
||||||
```
|
```
|
||||||
|
|
||||||
#### Multimodal Supervised Fine-Tuning
|
#### Multimodal Supervised Fine-Tuning
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/lora_single_gpu/llava1_5_lora_sft.yaml
|
llamafactory-cli train examples/train_lora/llava1_5_lora_sft.yaml
|
||||||
```
|
```
|
||||||
|
|
||||||
#### Reward Modeling
|
#### Reward Modeling
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/lora_single_gpu/llama3_lora_reward.yaml
|
llamafactory-cli train examples/train_lora/llama3_lora_reward.yaml
|
||||||
```
|
```
|
||||||
|
|
||||||
#### PPO Training
|
#### PPO Training
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/lora_single_gpu/llama3_lora_ppo.yaml
|
llamafactory-cli train examples/train_lora/llama3_lora_ppo.yaml
|
||||||
```
|
```
|
||||||
|
|
||||||
#### DPO/ORPO/SimPO Training
|
#### DPO/ORPO/SimPO Training
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/lora_single_gpu/llama3_lora_dpo.yaml
|
llamafactory-cli train examples/train_lora/llama3_lora_dpo.yaml
|
||||||
```
|
```
|
||||||
|
|
||||||
#### KTO Training
|
#### KTO Training
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/lora_single_gpu/llama3_lora_kto.yaml
|
llamafactory-cli train examples/train_lora/llama3_lora_kto.yaml
|
||||||
```
|
```
|
||||||
|
|
||||||
#### Preprocess Dataset
|
#### Preprocess Dataset
|
||||||
@@ -64,95 +64,79 @@ CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/lora_single_gpu/llama3_lo
|
|||||||
It is useful for large dataset, use `tokenized_path` in config to load the preprocessed dataset.
|
It is useful for large dataset, use `tokenized_path` in config to load the preprocessed dataset.
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/lora_single_gpu/llama3_preprocess.yaml
|
llamafactory-cli train examples/train_lora/llama3_preprocess.yaml
|
||||||
```
|
```
|
||||||
|
|
||||||
#### Evaluating on MMLU/CMMLU/C-Eval Benchmarks
|
#### Evaluating on MMLU/CMMLU/C-Eval Benchmarks
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
CUDA_VISIBLE_DEVICES=0 llamafactory-cli eval examples/lora_single_gpu/llama3_lora_eval.yaml
|
llamafactory-cli eval examples/train_lora/llama3_lora_eval.yaml
|
||||||
```
|
```
|
||||||
|
|
||||||
#### Batch Predicting and Computing BLEU and ROUGE Scores
|
#### Batch Predicting and Computing BLEU and ROUGE Scores
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/lora_single_gpu/llama3_lora_predict.yaml
|
llamafactory-cli train examples/train_lora/llama3_lora_predict.yaml
|
||||||
```
|
|
||||||
|
|
||||||
### QLoRA Fine-Tuning on a Single GPU
|
|
||||||
|
|
||||||
#### Supervised Fine-Tuning with 4/8-bit Bitsandbytes Quantization (Recommended)
|
|
||||||
|
|
||||||
```bash
|
|
||||||
CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/qlora_single_gpu/llama3_lora_sft_bitsandbytes.yaml
|
|
||||||
```
|
|
||||||
|
|
||||||
#### Supervised Fine-Tuning with 4/8-bit GPTQ Quantization
|
|
||||||
|
|
||||||
```bash
|
|
||||||
CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/qlora_single_gpu/llama3_lora_sft_gptq.yaml
|
|
||||||
```
|
|
||||||
|
|
||||||
#### Supervised Fine-Tuning with 4-bit AWQ Quantization
|
|
||||||
|
|
||||||
```bash
|
|
||||||
CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/qlora_single_gpu/llama3_lora_sft_awq.yaml
|
|
||||||
```
|
|
||||||
|
|
||||||
#### Supervised Fine-Tuning with 2-bit AQLM Quantization
|
|
||||||
|
|
||||||
```bash
|
|
||||||
CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/qlora_single_gpu/llama3_lora_sft_aqlm.yaml
|
|
||||||
```
|
|
||||||
|
|
||||||
### LoRA Fine-Tuning on Multiple GPUs
|
|
||||||
|
|
||||||
#### Supervised Fine-Tuning on Single Node
|
|
||||||
|
|
||||||
```bash
|
|
||||||
CUDA_VISIBLE_DEVICES=0,1,2,3 llamafactory-cli train examples/lora_multi_gpu/llama3_lora_sft.yaml
|
|
||||||
```
|
```
|
||||||
|
|
||||||
#### Supervised Fine-Tuning on Multiple Nodes
|
#### Supervised Fine-Tuning on Multiple Nodes
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
CUDA_VISIBLE_DEVICES=0,1,2,3 NNODES=2 RANK=0 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 llamafactory-cli train examples/lora_multi_gpu/llama3_lora_sft.yaml
|
FORCE_TORCHRUN=1 NNODES=2 RANK=0 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 llamafactory-cli train examples/train_lora/llama3_lora_sft.yaml
|
||||||
CUDA_VISIBLE_DEVICES=0,1,2,3 NNODES=2 RANK=1 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 llamafactory-cli train examples/lora_multi_gpu/llama3_lora_sft.yaml
|
FORCE_TORCHRUN=1 NNODES=2 RANK=1 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 llamafactory-cli train examples/train_lora/llama3_lora_sft.yaml
|
||||||
```
|
```
|
||||||
|
|
||||||
#### Supervised Fine-Tuning with DeepSpeed ZeRO-3 (Weight Sharding)
|
#### Supervised Fine-Tuning with DeepSpeed ZeRO-3 (Weight Sharding)
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
CUDA_VISIBLE_DEVICES=0,1,2,3 llamafactory-cli train examples/lora_multi_gpu/llama3_lora_sft_ds.yaml
|
FORCE_TORCHRUN=1 llamafactory-cli train examples/train_lora/llama3_lora_sft_ds3.yaml
|
||||||
```
|
```
|
||||||
|
|
||||||
### LoRA Fine-Tuning on Multiple NPUs
|
### QLoRA Fine-Tuning
|
||||||
|
|
||||||
#### Supervised Fine-Tuning with DeepSpeed ZeRO-0
|
#### Supervised Fine-Tuning with 4/8-bit Bitsandbytes Quantization (Recommended)
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
ASCEND_RT_VISIBLE_DEVICES=0,1,2,3 llamafactory-cli train examples/lora_multi_npu/llama3_lora_sft_ds.yaml
|
llamafactory-cli train examples/train_qlora/llama3_lora_sft_bitsandbytes.yaml
|
||||||
```
|
```
|
||||||
|
|
||||||
### Full-Parameter Fine-Tuning on Multiple GPUs
|
#### Supervised Fine-Tuning with 4/8-bit GPTQ Quantization
|
||||||
|
|
||||||
|
```bash
|
||||||
|
llamafactory-cli train examples/train_qlora/llama3_lora_sft_gptq.yaml
|
||||||
|
```
|
||||||
|
|
||||||
|
#### Supervised Fine-Tuning with 4-bit AWQ Quantization
|
||||||
|
|
||||||
|
```bash
|
||||||
|
llamafactory-cli train examples/train_qlora/llama3_lora_sft_awq.yaml
|
||||||
|
```
|
||||||
|
|
||||||
|
#### Supervised Fine-Tuning with 2-bit AQLM Quantization
|
||||||
|
|
||||||
|
```bash
|
||||||
|
llamafactory-cli train examples/train_qlora/llama3_lora_sft_aqlm.yaml
|
||||||
|
```
|
||||||
|
|
||||||
|
### Full-Parameter Fine-Tuning
|
||||||
|
|
||||||
#### Supervised Fine-Tuning on Single Node
|
#### Supervised Fine-Tuning on Single Node
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
CUDA_VISIBLE_DEVICES=0,1,2,3 llamafactory-cli train examples/full_multi_gpu/llama3_full_sft.yaml
|
FORCE_TORCHRUN=1 llamafactory-cli train examples/train_full/llama3_full_sft_ds3.yaml
|
||||||
```
|
```
|
||||||
|
|
||||||
#### Supervised Fine-Tuning on Multiple Nodes
|
#### Supervised Fine-Tuning on Multiple Nodes
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
CUDA_VISIBLE_DEVICES=0,1,2,3 NNODES=2 RANK=0 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 llamafactory-cli train examples/full_multi_gpu/llama3_full_sft.yaml
|
FORCE_TORCHRUN=1 NNODES=2 RANK=0 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 llamafactory-cli train examples/train_full/llama3_full_sft_ds3.yaml
|
||||||
CUDA_VISIBLE_DEVICES=0,1,2,3 NNODES=2 RANK=1 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 llamafactory-cli train examples/full_multi_gpu/llama3_full_sft.yaml
|
FORCE_TORCHRUN=1 NNODES=2 RANK=1 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 llamafactory-cli train examples/train_full/llama3_full_sft_ds3.yaml
|
||||||
```
|
```
|
||||||
|
|
||||||
#### Batch Predicting and Computing BLEU and ROUGE Scores
|
#### Batch Predicting and Computing BLEU and ROUGE Scores
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
CUDA_VISIBLE_DEVICES=0,1,2,3 llamafactory-cli train examples/full_multi_gpu/llama3_full_predict.yaml
|
llamafactory-cli train examples/train_full/llama3_full_predict.yaml
|
||||||
```
|
```
|
||||||
|
|
||||||
### Merging LoRA Adapters and Quantization
|
### Merging LoRA Adapters and Quantization
|
||||||
@@ -162,35 +146,33 @@ CUDA_VISIBLE_DEVICES=0,1,2,3 llamafactory-cli train examples/full_multi_gpu/llam
|
|||||||
Note: DO NOT use quantized model or `quantization_bit` when merging LoRA adapters.
|
Note: DO NOT use quantized model or `quantization_bit` when merging LoRA adapters.
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
CUDA_VISIBLE_DEVICES=0 llamafactory-cli export examples/merge_lora/llama3_lora_sft.yaml
|
llamafactory-cli export examples/merge_lora/llama3_lora_sft.yaml
|
||||||
```
|
```
|
||||||
|
|
||||||
#### Quantizing Model using AutoGPTQ
|
#### Quantizing Model using AutoGPTQ
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
CUDA_VISIBLE_DEVICES=0 llamafactory-cli export examples/merge_lora/llama3_gptq.yaml
|
llamafactory-cli export examples/merge_lora/llama3_gptq.yaml
|
||||||
```
|
```
|
||||||
|
|
||||||
### Inferring LoRA Fine-Tuned Models
|
### Inferring LoRA Fine-Tuned Models
|
||||||
|
|
||||||
Use `CUDA_VISIBLE_DEVICES=0,1` to infer models on multiple devices.
|
|
||||||
|
|
||||||
#### Use CLI
|
#### Use CLI
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
CUDA_VISIBLE_DEVICES=0 llamafactory-cli chat examples/inference/llama3_lora_sft.yaml
|
llamafactory-cli chat examples/inference/llama3_lora_sft.yaml
|
||||||
```
|
```
|
||||||
|
|
||||||
#### Use Web UI
|
#### Use Web UI
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
CUDA_VISIBLE_DEVICES=0 llamafactory-cli webchat examples/inference/llama3_lora_sft.yaml
|
llamafactory-cli webchat examples/inference/llama3_lora_sft.yaml
|
||||||
```
|
```
|
||||||
|
|
||||||
#### Launch OpenAI-style API
|
#### Launch OpenAI-style API
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
CUDA_VISIBLE_DEVICES=0 llamafactory-cli api examples/inference/llama3_lora_sft.yaml
|
llamafactory-cli api examples/inference/llama3_lora_sft.yaml
|
||||||
```
|
```
|
||||||
|
|
||||||
### Extras
|
### Extras
|
||||||
@@ -198,36 +180,42 @@ CUDA_VISIBLE_DEVICES=0 llamafactory-cli api examples/inference/llama3_lora_sft.y
|
|||||||
#### Full-Parameter Fine-Tuning using GaLore
|
#### Full-Parameter Fine-Tuning using GaLore
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/extras/galore/llama3_full_sft.yaml
|
llamafactory-cli train examples/extras/galore/llama3_full_sft.yaml
|
||||||
```
|
```
|
||||||
|
|
||||||
#### Full-Parameter Fine-Tuning using BAdam
|
#### Full-Parameter Fine-Tuning using BAdam
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/extras/badam/llama3_full_sft.yaml
|
llamafactory-cli train examples/extras/badam/llama3_full_sft.yaml
|
||||||
```
|
```
|
||||||
|
|
||||||
#### LoRA+ Fine-Tuning
|
#### LoRA+ Fine-Tuning
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/extras/loraplus/llama3_lora_sft.yaml
|
llamafactory-cli train examples/extras/loraplus/llama3_lora_sft.yaml
|
||||||
|
```
|
||||||
|
|
||||||
|
#### PiSSA Fine-Tuning
|
||||||
|
|
||||||
|
```bash
|
||||||
|
llamafactory-cli train examples/extras/pissa/llama3_lora_sft.yaml
|
||||||
```
|
```
|
||||||
|
|
||||||
#### Mixture-of-Depths Fine-Tuning
|
#### Mixture-of-Depths Fine-Tuning
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/extras/mod/llama3_full_sft.yaml
|
llamafactory-cli train examples/extras/mod/llama3_full_sft.yaml
|
||||||
```
|
```
|
||||||
|
|
||||||
#### LLaMA-Pro Fine-Tuning
|
#### LLaMA-Pro Fine-Tuning
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
bash examples/extras/llama_pro/expand.sh
|
bash examples/extras/llama_pro/expand.sh
|
||||||
CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/extras/llama_pro/llama3_freeze_sft.yaml
|
llamafactory-cli train examples/extras/llama_pro/llama3_freeze_sft.yaml
|
||||||
```
|
```
|
||||||
|
|
||||||
#### FSDP+QLoRA Fine-Tuning
|
#### FSDP+QLoRA Fine-Tuning
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
bash examples/extras/fsdp_qlora/single_node.sh
|
bash examples/extras/fsdp_qlora/train.sh
|
||||||
```
|
```
|
||||||
|
|||||||
@@ -4,59 +4,59 @@
|
|||||||
|
|
||||||
## 目录
|
## 目录
|
||||||
|
|
||||||
- [单 GPU LoRA 微调](#单-gpu-lora-微调)
|
- [LoRA 微调](#lora-微调)
|
||||||
- [单 GPU QLoRA 微调](#单-gpu-qlora-微调)
|
- [QLoRA 微调](#qlora-微调)
|
||||||
- [多 GPU LoRA 微调](#多-gpu-lora-微调)
|
- [全参数微调](#全参数微调)
|
||||||
- [多 NPU LoRA 微调](#多-npu-lora-微调)
|
|
||||||
- [多 GPU 全参数微调](#多-gpu-全参数微调)
|
|
||||||
- [合并 LoRA 适配器与模型量化](#合并-lora-适配器与模型量化)
|
- [合并 LoRA 适配器与模型量化](#合并-lora-适配器与模型量化)
|
||||||
- [推理 LoRA 模型](#推理-lora-模型)
|
- [推理 LoRA 模型](#推理-lora-模型)
|
||||||
- [杂项](#杂项)
|
- [杂项](#杂项)
|
||||||
|
|
||||||
|
使用 `CUDA_VISIBLE_DEVICES`(GPU)或 `ASCEND_RT_VISIBLE_DEVICES`(NPU)选择计算设备。
|
||||||
|
|
||||||
## 示例
|
## 示例
|
||||||
|
|
||||||
### 单 GPU LoRA 微调
|
### LoRA 微调
|
||||||
|
|
||||||
#### (增量)预训练
|
#### (增量)预训练
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/lora_single_gpu/llama3_lora_pretrain.yaml
|
llamafactory-cli train examples/train_lora/llama3_lora_pretrain.yaml
|
||||||
```
|
```
|
||||||
|
|
||||||
#### 指令监督微调
|
#### 指令监督微调
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/lora_single_gpu/llama3_lora_sft.yaml
|
llamafactory-cli train examples/train_lora/llama3_lora_sft.yaml
|
||||||
```
|
```
|
||||||
|
|
||||||
#### 多模态指令监督微调
|
#### 多模态指令监督微调
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/lora_single_gpu/llava1_5_lora_sft.yaml
|
llamafactory-cli train examples/train_lora/llava1_5_lora_sft.yaml
|
||||||
```
|
```
|
||||||
|
|
||||||
#### 奖励模型训练
|
#### 奖励模型训练
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/lora_single_gpu/llama3_lora_reward.yaml
|
llamafactory-cli train examples/train_lora/llama3_lora_reward.yaml
|
||||||
```
|
```
|
||||||
|
|
||||||
#### PPO 训练
|
#### PPO 训练
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/lora_single_gpu/llama3_lora_ppo.yaml
|
llamafactory-cli train examples/train_lora/llama3_lora_ppo.yaml
|
||||||
```
|
```
|
||||||
|
|
||||||
#### DPO/ORPO/SimPO 训练
|
#### DPO/ORPO/SimPO 训练
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/lora_single_gpu/llama3_lora_dpo.yaml
|
llamafactory-cli train examples/train_lora/llama3_lora_dpo.yaml
|
||||||
```
|
```
|
||||||
|
|
||||||
#### KTO 训练
|
#### KTO 训练
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/lora_single_gpu/llama3_lora_kto.yaml
|
llamafactory-cli train examples/train_lora/llama3_lora_kto.yaml
|
||||||
```
|
```
|
||||||
|
|
||||||
#### 预处理数据集
|
#### 预处理数据集
|
||||||
@@ -64,95 +64,79 @@ CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/lora_single_gpu/llama3_lo
|
|||||||
对于大数据集有帮助,在配置中使用 `tokenized_path` 以加载预处理后的数据集。
|
对于大数据集有帮助,在配置中使用 `tokenized_path` 以加载预处理后的数据集。
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/lora_single_gpu/llama3_preprocess.yaml
|
llamafactory-cli train examples/train_lora/llama3_preprocess.yaml
|
||||||
```
|
```
|
||||||
|
|
||||||
#### 在 MMLU/CMMLU/C-Eval 上评估
|
#### 在 MMLU/CMMLU/C-Eval 上评估
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
CUDA_VISIBLE_DEVICES=0 llamafactory-cli eval examples/lora_single_gpu/llama3_lora_eval.yaml
|
llamafactory-cli eval examples/train_lora/llama3_lora_eval.yaml
|
||||||
```
|
```
|
||||||
|
|
||||||
#### 批量预测并计算 BLEU 和 ROUGE 分数
|
#### 批量预测并计算 BLEU 和 ROUGE 分数
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/lora_single_gpu/llama3_lora_predict.yaml
|
llamafactory-cli train examples/train_lora/llama3_lora_predict.yaml
|
||||||
```
|
```
|
||||||
|
|
||||||
### 单 GPU QLoRA 微调
|
#### 多机指令监督微调
|
||||||
|
|
||||||
#### 基于 4/8 比特 Bitsandbytes 量化进行指令监督微调(推荐)
|
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/qlora_single_gpu/llama3_lora_sft_bitsandbytes.yaml
|
FORCE_TORCHRUN=1 NNODES=2 RANK=0 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 llamafactory-cli train examples/train_lora/llama3_lora_sft.yaml
|
||||||
```
|
FORCE_TORCHRUN=1 NNODES=2 RANK=1 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 llamafactory-cli train examples/train_lora/llama3_lora_sft.yaml
|
||||||
|
|
||||||
#### 基于 4/8 比特 GPTQ 量化进行指令监督微调
|
|
||||||
|
|
||||||
```bash
|
|
||||||
CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/qlora_single_gpu/llama3_lora_sft_gptq.yaml
|
|
||||||
```
|
|
||||||
|
|
||||||
#### 基于 4 比特 AWQ 量化进行指令监督微调
|
|
||||||
|
|
||||||
```bash
|
|
||||||
CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/qlora_single_gpu/llama3_lora_sft_awq.yaml
|
|
||||||
```
|
|
||||||
|
|
||||||
#### 基于 2 比特 AQLM 量化进行指令监督微调
|
|
||||||
|
|
||||||
```bash
|
|
||||||
CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/qlora_single_gpu/llama3_lora_sft_aqlm.yaml
|
|
||||||
```
|
|
||||||
|
|
||||||
### 多 GPU LoRA 微调
|
|
||||||
|
|
||||||
#### 在单机上进行指令监督微调
|
|
||||||
|
|
||||||
```bash
|
|
||||||
CUDA_VISIBLE_DEVICES=0,1,2,3 llamafactory-cli train examples/lora_multi_gpu/llama3_lora_sft.yaml
|
|
||||||
```
|
|
||||||
|
|
||||||
#### 在多机上进行指令监督微调
|
|
||||||
|
|
||||||
```bash
|
|
||||||
CUDA_VISIBLE_DEVICES=0,1,2,3 NNODES=2 RANK=0 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 llamafactory-cli train examples/lora_multi_gpu/llama3_lora_sft.yaml
|
|
||||||
CUDA_VISIBLE_DEVICES=0,1,2,3 NNODES=2 RANK=1 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 llamafactory-cli train examples/lora_multi_gpu/llama3_lora_sft.yaml
|
|
||||||
```
|
```
|
||||||
|
|
||||||
#### 使用 DeepSpeed ZeRO-3 平均分配显存
|
#### 使用 DeepSpeed ZeRO-3 平均分配显存
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
CUDA_VISIBLE_DEVICES=0,1,2,3 llamafactory-cli train examples/lora_multi_gpu/llama3_lora_sft_ds.yaml
|
FORCE_TORCHRUN=1 llamafactory-cli train examples/train_lora/llama3_lora_sft_ds3.yaml
|
||||||
```
|
```
|
||||||
|
|
||||||
### 多 NPU LoRA 微调
|
### QLoRA 微调
|
||||||
|
|
||||||
#### 使用 DeepSpeed ZeRO-0 进行指令监督微调
|
#### 基于 4/8 比特 Bitsandbytes 量化进行指令监督微调(推荐)
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
ASCEND_RT_VISIBLE_DEVICES=0,1,2,3 llamafactory-cli train examples/lora_multi_npu/llama3_lora_sft_ds.yaml
|
llamafactory-cli train examples/train_qlora/llama3_lora_sft_bitsandbytes.yaml
|
||||||
```
|
```
|
||||||
|
|
||||||
### 多 GPU 全参数微调
|
#### 基于 4/8 比特 GPTQ 量化进行指令监督微调
|
||||||
|
|
||||||
|
```bash
|
||||||
|
llamafactory-cli train examples/train_qlora/llama3_lora_sft_gptq.yaml
|
||||||
|
```
|
||||||
|
|
||||||
|
#### 基于 4 比特 AWQ 量化进行指令监督微调
|
||||||
|
|
||||||
|
```bash
|
||||||
|
llamafactory-cli train examples/train_qlora/llama3_lora_sft_awq.yaml
|
||||||
|
```
|
||||||
|
|
||||||
|
#### 基于 2 比特 AQLM 量化进行指令监督微调
|
||||||
|
|
||||||
|
```bash
|
||||||
|
llamafactory-cli train examples/train_qlora/llama3_lora_sft_aqlm.yaml
|
||||||
|
```
|
||||||
|
|
||||||
|
### 全参数微调
|
||||||
|
|
||||||
#### 在单机上进行指令监督微调
|
#### 在单机上进行指令监督微调
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
CUDA_VISIBLE_DEVICES=0,1,2,3 llamafactory-cli train examples/full_multi_gpu/llama3_full_sft.yaml
|
FORCE_TORCHRUN=1 llamafactory-cli train examples/train_full/llama3_full_sft_ds3.yaml
|
||||||
```
|
```
|
||||||
|
|
||||||
#### 在多机上进行指令监督微调
|
#### 在多机上进行指令监督微调
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
CUDA_VISIBLE_DEVICES=0,1,2,3 NNODES=2 RANK=0 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 llamafactory-cli train examples/full_multi_gpu/llama3_full_sft.yaml
|
FORCE_TORCHRUN=1 NNODES=2 RANK=0 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 llamafactory-cli train examples/train_full/llama3_full_sft_ds3.yaml
|
||||||
CUDA_VISIBLE_DEVICES=0,1,2,3 NNODES=2 RANK=1 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 llamafactory-cli train examples/full_multi_gpu/llama3_full_sft.yaml
|
FORCE_TORCHRUN=1 NNODES=2 RANK=1 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 llamafactory-cli train examples/train_full/llama3_full_sft_ds3.yaml
|
||||||
```
|
```
|
||||||
|
|
||||||
#### 批量预测并计算 BLEU 和 ROUGE 分数
|
#### 批量预测并计算 BLEU 和 ROUGE 分数
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
CUDA_VISIBLE_DEVICES=0,1,2,3 llamafactory-cli train examples/full_multi_gpu/llama3_full_predict.yaml
|
llamafactory-cli train examples/train_full/llama3_full_predict.yaml
|
||||||
```
|
```
|
||||||
|
|
||||||
### 合并 LoRA 适配器与模型量化
|
### 合并 LoRA 适配器与模型量化
|
||||||
@@ -162,35 +146,33 @@ CUDA_VISIBLE_DEVICES=0,1,2,3 llamafactory-cli train examples/full_multi_gpu/llam
|
|||||||
注:请勿使用量化后的模型或 `quantization_bit` 参数来合并 LoRA 适配器。
|
注:请勿使用量化后的模型或 `quantization_bit` 参数来合并 LoRA 适配器。
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
CUDA_VISIBLE_DEVICES=0 llamafactory-cli export examples/merge_lora/llama3_lora_sft.yaml
|
llamafactory-cli export examples/merge_lora/llama3_lora_sft.yaml
|
||||||
```
|
```
|
||||||
|
|
||||||
#### 使用 AutoGPTQ 量化模型
|
#### 使用 AutoGPTQ 量化模型
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
CUDA_VISIBLE_DEVICES=0 llamafactory-cli export examples/merge_lora/llama3_gptq.yaml
|
llamafactory-cli export examples/merge_lora/llama3_gptq.yaml
|
||||||
```
|
```
|
||||||
|
|
||||||
### 推理 LoRA 模型
|
### 推理 LoRA 模型
|
||||||
|
|
||||||
使用 `CUDA_VISIBLE_DEVICES=0,1` 进行多卡推理。
|
|
||||||
|
|
||||||
#### 使用命令行接口
|
#### 使用命令行接口
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
CUDA_VISIBLE_DEVICES=0 llamafactory-cli chat examples/inference/llama3_lora_sft.yaml
|
llamafactory-cli chat examples/inference/llama3_lora_sft.yaml
|
||||||
```
|
```
|
||||||
|
|
||||||
#### 使用浏览器界面
|
#### 使用浏览器界面
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
CUDA_VISIBLE_DEVICES=0 llamafactory-cli webchat examples/inference/llama3_lora_sft.yaml
|
llamafactory-cli webchat examples/inference/llama3_lora_sft.yaml
|
||||||
```
|
```
|
||||||
|
|
||||||
#### 启动 OpenAI 风格 API
|
#### 启动 OpenAI 风格 API
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
CUDA_VISIBLE_DEVICES=0 llamafactory-cli api examples/inference/llama3_lora_sft.yaml
|
llamafactory-cli api examples/inference/llama3_lora_sft.yaml
|
||||||
```
|
```
|
||||||
|
|
||||||
### 杂项
|
### 杂项
|
||||||
@@ -198,36 +180,42 @@ CUDA_VISIBLE_DEVICES=0 llamafactory-cli api examples/inference/llama3_lora_sft.y
|
|||||||
#### 使用 GaLore 进行全参数训练
|
#### 使用 GaLore 进行全参数训练
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/extras/galore/llama3_full_sft.yaml
|
llamafactory-cli train examples/extras/galore/llama3_full_sft.yaml
|
||||||
```
|
```
|
||||||
|
|
||||||
#### 使用 BAdam 进行全参数训练
|
#### 使用 BAdam 进行全参数训练
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/extras/badam/llama3_full_sft.yaml
|
llamafactory-cli train examples/extras/badam/llama3_full_sft.yaml
|
||||||
```
|
```
|
||||||
|
|
||||||
#### LoRA+ 微调
|
#### LoRA+ 微调
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/extras/loraplus/llama3_lora_sft.yaml
|
llamafactory-cli train examples/extras/loraplus/llama3_lora_sft.yaml
|
||||||
|
```
|
||||||
|
|
||||||
|
#### PiSSA 微调
|
||||||
|
|
||||||
|
```bash
|
||||||
|
llamafactory-cli train examples/extras/pissa/llama3_lora_sft.yaml
|
||||||
```
|
```
|
||||||
|
|
||||||
#### 深度混合微调
|
#### 深度混合微调
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/extras/mod/llama3_full_sft.yaml
|
llamafactory-cli train examples/extras/mod/llama3_full_sft.yaml
|
||||||
```
|
```
|
||||||
|
|
||||||
#### LLaMA-Pro 微调
|
#### LLaMA-Pro 微调
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
bash examples/extras/llama_pro/expand.sh
|
bash examples/extras/llama_pro/expand.sh
|
||||||
CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/extras/llama_pro/llama3_freeze_sft.yaml
|
llamafactory-cli train examples/extras/llama_pro/llama3_freeze_sft.yaml
|
||||||
```
|
```
|
||||||
|
|
||||||
#### FSDP+QLoRA 微调
|
#### FSDP+QLoRA 微调
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
bash examples/extras/fsdp_qlora/single_node.sh
|
bash examples/extras/fsdp_qlora/train.sh
|
||||||
```
|
```
|
||||||
|
|||||||
@@ -8,9 +8,6 @@ do_train: true
|
|||||||
finetuning_type: lora
|
finetuning_type: lora
|
||||||
lora_target: all
|
lora_target: all
|
||||||
|
|
||||||
### ddp
|
|
||||||
ddp_timeout: 180000000
|
|
||||||
|
|
||||||
### dataset
|
### dataset
|
||||||
dataset: identity,alpaca_en_demo
|
dataset: identity,alpaca_en_demo
|
||||||
template: llama3
|
template: llama3
|
||||||
@@ -34,6 +31,7 @@ num_train_epochs: 3.0
|
|||||||
lr_scheduler_type: cosine
|
lr_scheduler_type: cosine
|
||||||
warmup_ratio: 0.1
|
warmup_ratio: 0.1
|
||||||
fp16: true
|
fp16: true
|
||||||
|
ddp_timeout: 180000000
|
||||||
|
|
||||||
### eval
|
### eval
|
||||||
val_size: 0.1
|
val_size: 0.1
|
||||||
|
|||||||
@@ -32,6 +32,7 @@ num_train_epochs: 3.0
|
|||||||
lr_scheduler_type: cosine
|
lr_scheduler_type: cosine
|
||||||
warmup_ratio: 0.1
|
warmup_ratio: 0.1
|
||||||
fp16: true
|
fp16: true
|
||||||
|
ddp_timeout: 180000000
|
||||||
|
|
||||||
### eval
|
### eval
|
||||||
val_size: 0.1
|
val_size: 0.1
|
||||||
|
|||||||
@@ -31,6 +31,7 @@ num_train_epochs: 3.0
|
|||||||
lr_scheduler_type: cosine
|
lr_scheduler_type: cosine
|
||||||
warmup_ratio: 0.1
|
warmup_ratio: 0.1
|
||||||
fp16: true
|
fp16: true
|
||||||
|
ddp_timeout: 180000000
|
||||||
|
|
||||||
### eval
|
### eval
|
||||||
val_size: 0.1
|
val_size: 0.1
|
||||||
|
|||||||
@@ -31,6 +31,7 @@ num_train_epochs: 3.0
|
|||||||
lr_scheduler_type: cosine
|
lr_scheduler_type: cosine
|
||||||
warmup_ratio: 0.1
|
warmup_ratio: 0.1
|
||||||
pure_bf16: true
|
pure_bf16: true
|
||||||
|
ddp_timeout: 180000000
|
||||||
|
|
||||||
### eval
|
### eval
|
||||||
val_size: 0.1
|
val_size: 0.1
|
||||||
|
|||||||
@@ -6,9 +6,9 @@ stage: sft
|
|||||||
do_train: true
|
do_train: true
|
||||||
finetuning_type: lora
|
finetuning_type: lora
|
||||||
lora_target: all
|
lora_target: all
|
||||||
|
pissa_init: true
|
||||||
### ddp
|
pissa_iter: 4
|
||||||
ddp_timeout: 180000000
|
pissa_convert: true
|
||||||
|
|
||||||
### dataset
|
### dataset
|
||||||
dataset: identity,alpaca_en_demo
|
dataset: identity,alpaca_en_demo
|
||||||
@@ -27,12 +27,13 @@ overwrite_output_dir: true
|
|||||||
|
|
||||||
### train
|
### train
|
||||||
per_device_train_batch_size: 1
|
per_device_train_batch_size: 1
|
||||||
gradient_accumulation_steps: 2
|
gradient_accumulation_steps: 8
|
||||||
learning_rate: 1.0e-4
|
learning_rate: 1.0e-4
|
||||||
num_train_epochs: 3.0
|
num_train_epochs: 3.0
|
||||||
lr_scheduler_type: cosine
|
lr_scheduler_type: cosine
|
||||||
warmup_ratio: 0.1
|
warmup_ratio: 0.1
|
||||||
fp16: true
|
fp16: true
|
||||||
|
ddp_timeout: 180000000
|
||||||
|
|
||||||
### eval
|
### eval
|
||||||
val_size: 0.1
|
val_size: 0.1
|
||||||
@@ -5,9 +5,6 @@ model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct
|
|||||||
stage: sft
|
stage: sft
|
||||||
do_train: true
|
do_train: true
|
||||||
finetuning_type: full
|
finetuning_type: full
|
||||||
|
|
||||||
### ddp
|
|
||||||
ddp_timeout: 180000000
|
|
||||||
deepspeed: examples/deepspeed/ds_z3_config.json
|
deepspeed: examples/deepspeed/ds_z3_config.json
|
||||||
|
|
||||||
### dataset
|
### dataset
|
||||||
@@ -33,6 +30,7 @@ num_train_epochs: 3.0
|
|||||||
lr_scheduler_type: cosine
|
lr_scheduler_type: cosine
|
||||||
warmup_ratio: 0.1
|
warmup_ratio: 0.1
|
||||||
fp16: true
|
fp16: true
|
||||||
|
ddp_timeout: 180000000
|
||||||
|
|
||||||
### eval
|
### eval
|
||||||
val_size: 0.1
|
val_size: 0.1
|
||||||
@@ -32,6 +32,7 @@ num_train_epochs: 3.0
|
|||||||
lr_scheduler_type: cosine
|
lr_scheduler_type: cosine
|
||||||
warmup_ratio: 0.1
|
warmup_ratio: 0.1
|
||||||
fp16: true
|
fp16: true
|
||||||
|
ddp_timeout: 180000000
|
||||||
|
|
||||||
### eval
|
### eval
|
||||||
val_size: 0.1
|
val_size: 0.1
|
||||||
@@ -6,6 +6,7 @@ stage: kto
|
|||||||
do_train: true
|
do_train: true
|
||||||
finetuning_type: lora
|
finetuning_type: lora
|
||||||
lora_target: all
|
lora_target: all
|
||||||
|
pref_beta: 0.1
|
||||||
|
|
||||||
### dataset
|
### dataset
|
||||||
dataset: kto_en_demo
|
dataset: kto_en_demo
|
||||||
@@ -30,6 +31,7 @@ num_train_epochs: 3.0
|
|||||||
lr_scheduler_type: cosine
|
lr_scheduler_type: cosine
|
||||||
warmup_ratio: 0.1
|
warmup_ratio: 0.1
|
||||||
fp16: true
|
fp16: true
|
||||||
|
ddp_timeout: 180000000
|
||||||
|
|
||||||
### eval
|
### eval
|
||||||
val_size: 0.1
|
val_size: 0.1
|
||||||
@@ -31,6 +31,7 @@ num_train_epochs: 3.0
|
|||||||
lr_scheduler_type: cosine
|
lr_scheduler_type: cosine
|
||||||
warmup_ratio: 0.1
|
warmup_ratio: 0.1
|
||||||
fp16: true
|
fp16: true
|
||||||
|
ddp_timeout: 180000000
|
||||||
|
|
||||||
### generate
|
### generate
|
||||||
max_new_tokens: 512
|
max_new_tokens: 512
|
||||||
@@ -22,3 +22,4 @@ overwrite_output_dir: true
|
|||||||
### eval
|
### eval
|
||||||
per_device_eval_batch_size: 1
|
per_device_eval_batch_size: 1
|
||||||
predict_with_generate: true
|
predict_with_generate: true
|
||||||
|
ddp_timeout: 180000000
|
||||||
@@ -29,6 +29,7 @@ num_train_epochs: 3.0
|
|||||||
lr_scheduler_type: cosine
|
lr_scheduler_type: cosine
|
||||||
warmup_ratio: 0.1
|
warmup_ratio: 0.1
|
||||||
fp16: true
|
fp16: true
|
||||||
|
ddp_timeout: 180000000
|
||||||
|
|
||||||
### eval
|
### eval
|
||||||
val_size: 0.1
|
val_size: 0.1
|
||||||
@@ -30,6 +30,7 @@ num_train_epochs: 3.0
|
|||||||
lr_scheduler_type: cosine
|
lr_scheduler_type: cosine
|
||||||
warmup_ratio: 0.1
|
warmup_ratio: 0.1
|
||||||
fp16: true
|
fp16: true
|
||||||
|
ddp_timeout: 180000000
|
||||||
|
|
||||||
### eval
|
### eval
|
||||||
val_size: 0.1
|
val_size: 0.1
|
||||||
@@ -30,6 +30,7 @@ num_train_epochs: 3.0
|
|||||||
lr_scheduler_type: cosine
|
lr_scheduler_type: cosine
|
||||||
warmup_ratio: 0.1
|
warmup_ratio: 0.1
|
||||||
fp16: true
|
fp16: true
|
||||||
|
ddp_timeout: 180000000
|
||||||
|
|
||||||
### eval
|
### eval
|
||||||
val_size: 0.1
|
val_size: 0.1
|
||||||
@@ -6,9 +6,6 @@ stage: sft
|
|||||||
do_train: true
|
do_train: true
|
||||||
finetuning_type: lora
|
finetuning_type: lora
|
||||||
lora_target: all
|
lora_target: all
|
||||||
|
|
||||||
### ddp
|
|
||||||
ddp_timeout: 180000000
|
|
||||||
deepspeed: examples/deepspeed/ds_z0_config.json
|
deepspeed: examples/deepspeed/ds_z0_config.json
|
||||||
|
|
||||||
### dataset
|
### dataset
|
||||||
@@ -34,6 +31,7 @@ num_train_epochs: 3.0
|
|||||||
lr_scheduler_type: cosine
|
lr_scheduler_type: cosine
|
||||||
warmup_ratio: 0.1
|
warmup_ratio: 0.1
|
||||||
fp16: true
|
fp16: true
|
||||||
|
ddp_timeout: 180000000
|
||||||
|
|
||||||
### eval
|
### eval
|
||||||
val_size: 0.1
|
val_size: 0.1
|
||||||
@@ -6,9 +6,6 @@ stage: sft
|
|||||||
do_train: true
|
do_train: true
|
||||||
finetuning_type: lora
|
finetuning_type: lora
|
||||||
lora_target: all
|
lora_target: all
|
||||||
|
|
||||||
### ddp
|
|
||||||
ddp_timeout: 180000000
|
|
||||||
deepspeed: examples/deepspeed/ds_z3_config.json
|
deepspeed: examples/deepspeed/ds_z3_config.json
|
||||||
|
|
||||||
### dataset
|
### dataset
|
||||||
@@ -34,6 +31,7 @@ num_train_epochs: 3.0
|
|||||||
lr_scheduler_type: cosine
|
lr_scheduler_type: cosine
|
||||||
warmup_ratio: 0.1
|
warmup_ratio: 0.1
|
||||||
fp16: true
|
fp16: true
|
||||||
|
ddp_timeout: 180000000
|
||||||
|
|
||||||
### eval
|
### eval
|
||||||
val_size: 0.1
|
val_size: 0.1
|
||||||
@@ -31,6 +31,7 @@ num_train_epochs: 3.0
|
|||||||
lr_scheduler_type: cosine
|
lr_scheduler_type: cosine
|
||||||
warmup_ratio: 0.1
|
warmup_ratio: 0.1
|
||||||
fp16: true
|
fp16: true
|
||||||
|
ddp_timeout: 180000000
|
||||||
|
|
||||||
### eval
|
### eval
|
||||||
val_size: 0.1
|
val_size: 0.1
|
||||||
@@ -30,6 +30,7 @@ num_train_epochs: 3.0
|
|||||||
lr_scheduler_type: cosine
|
lr_scheduler_type: cosine
|
||||||
warmup_ratio: 0.1
|
warmup_ratio: 0.1
|
||||||
fp16: true
|
fp16: true
|
||||||
|
ddp_timeout: 180000000
|
||||||
|
|
||||||
### eval
|
### eval
|
||||||
val_size: 0.1
|
val_size: 0.1
|
||||||
@@ -30,6 +30,7 @@ num_train_epochs: 3.0
|
|||||||
lr_scheduler_type: cosine
|
lr_scheduler_type: cosine
|
||||||
warmup_ratio: 0.1
|
warmup_ratio: 0.1
|
||||||
fp16: true
|
fp16: true
|
||||||
|
ddp_timeout: 180000000
|
||||||
|
|
||||||
### eval
|
### eval
|
||||||
val_size: 0.1
|
val_size: 0.1
|
||||||
@@ -31,6 +31,7 @@ num_train_epochs: 3.0
|
|||||||
lr_scheduler_type: cosine
|
lr_scheduler_type: cosine
|
||||||
warmup_ratio: 0.1
|
warmup_ratio: 0.1
|
||||||
fp16: true
|
fp16: true
|
||||||
|
ddp_timeout: 180000000
|
||||||
|
|
||||||
### eval
|
### eval
|
||||||
val_size: 0.1
|
val_size: 0.1
|
||||||
@@ -30,6 +30,7 @@ num_train_epochs: 3.0
|
|||||||
lr_scheduler_type: cosine
|
lr_scheduler_type: cosine
|
||||||
warmup_ratio: 0.1
|
warmup_ratio: 0.1
|
||||||
fp16: true
|
fp16: true
|
||||||
|
ddp_timeout: 180000000
|
||||||
|
|
||||||
### eval
|
### eval
|
||||||
val_size: 0.1
|
val_size: 0.1
|
||||||
@@ -4,6 +4,7 @@ accelerate>=0.30.1
|
|||||||
peft>=0.11.1
|
peft>=0.11.1
|
||||||
trl>=0.8.6
|
trl>=0.8.6
|
||||||
gradio>=4.0.0
|
gradio>=4.0.0
|
||||||
|
pandas>=2.0.0
|
||||||
scipy
|
scipy
|
||||||
einops
|
einops
|
||||||
sentencepiece
|
sentencepiece
|
||||||
|
|||||||
@@ -1,7 +1,20 @@
|
|||||||
# coding=utf-8
|
# coding=utf-8
|
||||||
# Calculates the flops of pre-trained models.
|
# Copyright 2024 Microsoft Corporation and the LlamaFactory team.
|
||||||
# Usage: python cal_flops.py --model_name_or_path path_to_model --batch_size 1 --seq_length 512
|
#
|
||||||
# Inspired by: https://www.deepspeed.ai/tutorials/flops-profiler/
|
# This code is inspired by the Microsoft's DeepSpeed library.
|
||||||
|
# https://www.deepspeed.ai/tutorials/flops-profiler/
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
import fire
|
import fire
|
||||||
import torch
|
import torch
|
||||||
@@ -17,6 +30,10 @@ def calculate_flops(
|
|||||||
seq_length: int = 256,
|
seq_length: int = 256,
|
||||||
flash_attn: str = "auto",
|
flash_attn: str = "auto",
|
||||||
):
|
):
|
||||||
|
r"""
|
||||||
|
Calculates the flops of pre-trained models.
|
||||||
|
Usage: python cal_flops.py --model_name_or_path path_to_model --batch_size 1 --seq_length 512
|
||||||
|
"""
|
||||||
with get_accelerator().device(0):
|
with get_accelerator().device(0):
|
||||||
chat_model = ChatModel(dict(model_name_or_path=model_name_or_path, template="empty", flash_attn=flash_attn))
|
chat_model = ChatModel(dict(model_name_or_path=model_name_or_path, template="empty", flash_attn=flash_attn))
|
||||||
fake_input = torch.ones((batch_size, seq_length), dtype=torch.long, device=chat_model.model.device)
|
fake_input = torch.ones((batch_size, seq_length), dtype=torch.long, device=chat_model.model.device)
|
||||||
|
|||||||
@@ -1,7 +1,20 @@
|
|||||||
# coding=utf-8
|
# coding=utf-8
|
||||||
# Calculates the optimal learning rate for 7B/13B models using LLaMA's hyper-parameters.
|
# Copyright 2024 imoneoi and the LlamaFactory team.
|
||||||
# Usage: python cal_lr.py --model_name_or_path path_to_model --dataset alpaca_en --cutoff_len 1024 --batch_size 16
|
#
|
||||||
# Inspired by: https://github.com/imoneoi/openchat/blob/master/ochat/training_deepspeed/train.py
|
# This code is inspired by the imoneoi's OpenChat library.
|
||||||
|
# https://github.com/imoneoi/openchat/blob/3.6.0/ochat/training_deepspeed/train.py
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
import math
|
import math
|
||||||
from typing import Literal
|
from typing import Literal
|
||||||
@@ -32,6 +45,10 @@ def calculate_lr(
|
|||||||
cutoff_len: int = 1024, # i.e. maximum input length during training
|
cutoff_len: int = 1024, # i.e. maximum input length during training
|
||||||
is_mistral: bool = False, # mistral model uses a smaller learning rate,
|
is_mistral: bool = False, # mistral model uses a smaller learning rate,
|
||||||
):
|
):
|
||||||
|
r"""
|
||||||
|
Calculates the optimal learning rate for 7B/13B models using LLaMA's hyper-parameters.
|
||||||
|
Usage: python cal_lr.py --model_name_or_path path_to_model --dataset alpaca_en --cutoff_len 1024 --batch_size 16
|
||||||
|
"""
|
||||||
model_args, data_args, training_args, _, _ = get_train_args(
|
model_args, data_args, training_args, _, _ = get_train_args(
|
||||||
dict(
|
dict(
|
||||||
stage=stage,
|
stage=stage,
|
||||||
|
|||||||
@@ -1,6 +1,17 @@
|
|||||||
# coding=utf-8
|
# coding=utf-8
|
||||||
# Calculates the ppl on the dataset of the pre-trained models.
|
# Copyright 2024 the LlamaFactory team.
|
||||||
# Usage: python cal_ppl.py --model_name_or_path path_to_model --save_name ppl.json
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
import json
|
import json
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
@@ -56,6 +67,10 @@ def cal_ppl(
|
|||||||
max_samples: Optional[int] = None,
|
max_samples: Optional[int] = None,
|
||||||
train_on_prompt: bool = False,
|
train_on_prompt: bool = False,
|
||||||
):
|
):
|
||||||
|
r"""
|
||||||
|
Calculates the ppl on the dataset of the pre-trained models.
|
||||||
|
Usage: python cal_ppl.py --model_name_or_path path_to_model --save_name ppl.json
|
||||||
|
"""
|
||||||
model_args, data_args, training_args, finetuning_args, _ = get_train_args(
|
model_args, data_args, training_args, finetuning_args, _ = get_train_args(
|
||||||
dict(
|
dict(
|
||||||
stage=stage,
|
stage=stage,
|
||||||
|
|||||||
@@ -1,6 +1,17 @@
|
|||||||
# coding=utf-8
|
# coding=utf-8
|
||||||
# Calculates the distribution of the input lengths in the dataset.
|
# Copyright 2024 the LlamaFactory team.
|
||||||
# Usage: python length_cdf.py --model_name_or_path path_to_model --dataset alpaca_en --template default
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
|
|
||||||
@@ -19,6 +30,10 @@ def length_cdf(
|
|||||||
template: str = "default",
|
template: str = "default",
|
||||||
interval: int = 1000,
|
interval: int = 1000,
|
||||||
):
|
):
|
||||||
|
r"""
|
||||||
|
Calculates the distribution of the input lengths in the dataset.
|
||||||
|
Usage: python length_cdf.py --model_name_or_path path_to_model --dataset alpaca_en --template default
|
||||||
|
"""
|
||||||
model_args, data_args, training_args, _, _ = get_train_args(
|
model_args, data_args, training_args, _, _ = get_train_args(
|
||||||
dict(
|
dict(
|
||||||
stage="sft",
|
stage="sft",
|
||||||
|
|||||||
@@ -1,7 +1,20 @@
|
|||||||
# coding=utf-8
|
# coding=utf-8
|
||||||
# Performs block expansion for LLaMA, Mistral, Qwen1.5 or Yi models.
|
# Copyright 2024 Tencent Inc. and the LlamaFactory team.
|
||||||
# Usage: python llama_pro.py --model_name_or_path meta-llama/Llama-2-7b-hf --output_dir llama2_pro --num_expand 8
|
#
|
||||||
# Inspired by: https://github.com/TencentARC/LLaMA-Pro/blob/main/scripts/block_expansion.py
|
# This code is inspired by the Tencent's LLaMA-Pro library.
|
||||||
|
# https://github.com/TencentARC/LLaMA-Pro/blob/main/scripts/block_expansion.py
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
@@ -37,6 +50,10 @@ def block_expansion(
|
|||||||
shard_size: Optional[str] = "2GB",
|
shard_size: Optional[str] = "2GB",
|
||||||
save_safetensors: Optional[bool] = False,
|
save_safetensors: Optional[bool] = False,
|
||||||
):
|
):
|
||||||
|
r"""
|
||||||
|
Performs block expansion for LLaMA, Mistral, Qwen1.5 or Yi models.
|
||||||
|
Usage: python llama_pro.py --model_name_or_path meta-llama/Llama-2-7b-hf --output_dir llama2_pro --num_expand 8
|
||||||
|
"""
|
||||||
config: "PretrainedConfig" = AutoConfig.from_pretrained(model_name_or_path)
|
config: "PretrainedConfig" = AutoConfig.from_pretrained(model_name_or_path)
|
||||||
num_layers = getattr(config, "num_hidden_layers")
|
num_layers = getattr(config, "num_hidden_layers")
|
||||||
setattr(config, "num_hidden_layers", num_layers + num_expand)
|
setattr(config, "num_hidden_layers", num_layers + num_expand)
|
||||||
@@ -103,7 +120,7 @@ def block_expansion(
|
|||||||
json.dump(index, f, indent=2, sort_keys=True)
|
json.dump(index, f, indent=2, sort_keys=True)
|
||||||
print("Model weights saved in {}".format(output_dir))
|
print("Model weights saved in {}".format(output_dir))
|
||||||
|
|
||||||
print("Fine-tune this model with:")
|
print("- Fine-tune this model with:")
|
||||||
print("model_name_or_path: {}".format(output_dir))
|
print("model_name_or_path: {}".format(output_dir))
|
||||||
print("finetuning_type: freeze")
|
print("finetuning_type: freeze")
|
||||||
print("freeze_trainable_layers: {}".format(num_expand))
|
print("freeze_trainable_layers: {}".format(num_expand))
|
||||||
|
|||||||
@@ -1,8 +1,17 @@
|
|||||||
# coding=utf-8
|
# coding=utf-8
|
||||||
# Converts the Baichuan2-7B model in the same format as LLaMA2-7B.
|
# Copyright 2024 the LlamaFactory team.
|
||||||
# Usage: python llamafy_baichuan2.py --input_dir input --output_dir output
|
#
|
||||||
# Inspired by: https://huggingface.co/fireballoon/baichuan-llama-7b/blob/main/convert_baichuan_to_llama.py
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
# Converted model: https://huggingface.co/hiyouga/Baichuan2-7B-Base-LLaMAfied
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
@@ -79,6 +88,11 @@ def save_config(input_dir: str, output_dir: str):
|
|||||||
def llamafy_baichuan2(
|
def llamafy_baichuan2(
|
||||||
input_dir: str, output_dir: str, shard_size: Optional[str] = "2GB", save_safetensors: Optional[bool] = False
|
input_dir: str, output_dir: str, shard_size: Optional[str] = "2GB", save_safetensors: Optional[bool] = False
|
||||||
):
|
):
|
||||||
|
r"""
|
||||||
|
Converts the Baichuan2-7B model in the same format as LLaMA2-7B.
|
||||||
|
Usage: python llamafy_baichuan2.py --input_dir input --output_dir output
|
||||||
|
Converted model: https://huggingface.co/hiyouga/Baichuan2-7B-Base-LLaMAfied
|
||||||
|
"""
|
||||||
try:
|
try:
|
||||||
os.makedirs(output_dir, exist_ok=False)
|
os.makedirs(output_dir, exist_ok=False)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|||||||
@@ -1,7 +1,17 @@
|
|||||||
# coding=utf-8
|
# coding=utf-8
|
||||||
# Converts the Qwen models in the same format as LLaMA2.
|
# Copyright 2024 the LlamaFactory team.
|
||||||
# Usage: python llamafy_qwen.py --input_dir input --output_dir output
|
#
|
||||||
# Converted model: https://huggingface.co/hiyouga/Qwen-14B-Chat-LLaMAfied
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
@@ -131,6 +141,11 @@ def save_config(input_dir: str, output_dir: str, torch_dtype: str):
|
|||||||
def llamafy_qwen(
|
def llamafy_qwen(
|
||||||
input_dir: str, output_dir: str, shard_size: Optional[str] = "2GB", save_safetensors: Optional[bool] = False
|
input_dir: str, output_dir: str, shard_size: Optional[str] = "2GB", save_safetensors: Optional[bool] = False
|
||||||
):
|
):
|
||||||
|
r"""
|
||||||
|
Converts the Qwen models in the same format as LLaMA2.
|
||||||
|
Usage: python llamafy_qwen.py --input_dir input --output_dir output
|
||||||
|
Converted model: https://huggingface.co/hiyouga/Qwen-14B-Chat-LLaMAfied
|
||||||
|
"""
|
||||||
try:
|
try:
|
||||||
os.makedirs(output_dir, exist_ok=False)
|
os.makedirs(output_dir, exist_ok=False)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|||||||
@@ -1,14 +1,25 @@
|
|||||||
# coding=utf-8
|
# coding=utf-8
|
||||||
# Initializes LoRA weights with LoRA-fine-tuning-aware Quantization (LoftQ)
|
# Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
|
||||||
# Usage: python loftq_init.py --model_name_or_path path_to_model --save_dir output_dir
|
#
|
||||||
# Inspired by: https://github.com/huggingface/peft/blob/main/examples/loftq_finetuning/quantize_save_load.py
|
# This code is based on the HuggingFace's PEFT library.
|
||||||
|
# https://github.com/huggingface/peft/blob/v0.10.0/examples/loftq_finetuning/quantize_save_load.py
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
import os
|
import os
|
||||||
from typing import TYPE_CHECKING, Optional
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
import fire
|
import fire
|
||||||
import torch
|
|
||||||
import torch.nn as nn
|
|
||||||
from peft import LoftQConfig, LoraConfig, TaskType, get_peft_model
|
from peft import LoftQConfig, LoraConfig, TaskType, get_peft_model
|
||||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||||
|
|
||||||
@@ -17,38 +28,21 @@ if TYPE_CHECKING:
|
|||||||
from transformers import PreTrainedModel
|
from transformers import PreTrainedModel
|
||||||
|
|
||||||
|
|
||||||
class Shell(nn.Module):
|
|
||||||
def __init__(self, weight: torch.Tensor, bias: Optional[torch.Tensor] = None):
|
|
||||||
super().__init__()
|
|
||||||
self.weight = nn.Parameter(weight, requires_grad=False)
|
|
||||||
if bias is not None:
|
|
||||||
self.bias = nn.Parameter(bias, requires_grad=False)
|
|
||||||
|
|
||||||
|
|
||||||
def unwrap_model(model: nn.Module, pattern=".base_layer") -> None:
|
|
||||||
for name in {k.split(pattern)[0] for k, _ in model.named_modules() if pattern in k}:
|
|
||||||
parent_name = ".".join(name.split(".")[:-1])
|
|
||||||
child_name = name.split(".")[-1]
|
|
||||||
parent_module = model.get_submodule(parent_name)
|
|
||||||
child_module = getattr(parent_module, child_name)
|
|
||||||
base_layer = getattr(child_module, "base_layer")
|
|
||||||
weight = getattr(base_layer, "weight", None)
|
|
||||||
bias = getattr(base_layer, "bias", None)
|
|
||||||
setattr(parent_module, child_name, Shell(weight, bias))
|
|
||||||
|
|
||||||
print("Model unwrapped.")
|
|
||||||
|
|
||||||
|
|
||||||
def quantize_loftq(
|
def quantize_loftq(
|
||||||
model_name_or_path: str,
|
model_name_or_path: str,
|
||||||
save_dir: str,
|
output_dir: str,
|
||||||
loftq_bits: Optional[int] = 4,
|
loftq_bits: int = 4,
|
||||||
loftq_iter: Optional[int] = 1,
|
loftq_iter: int = 4,
|
||||||
lora_alpha: Optional[int] = None,
|
lora_alpha: int = None,
|
||||||
lora_rank: Optional[int] = 16,
|
lora_rank: int = 16,
|
||||||
lora_target: Optional[str] = "q_proj,v_proj",
|
lora_dropout: float = 0,
|
||||||
save_safetensors: Optional[bool] = False,
|
lora_target: str = "q_proj,v_proj",
|
||||||
|
save_safetensors: bool = True,
|
||||||
):
|
):
|
||||||
|
r"""
|
||||||
|
Initializes LoRA weights with LoRA-fine-tuning-aware Quantization (LoftQ)
|
||||||
|
Usage: python loftq_init.py --model_name_or_path path_to_model --output_dir output_dir
|
||||||
|
"""
|
||||||
tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, trust_remote_code=True)
|
tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, trust_remote_code=True)
|
||||||
model = AutoModelForCausalLM.from_pretrained(model_name_or_path, trust_remote_code=True, torch_dtype="auto")
|
model = AutoModelForCausalLM.from_pretrained(model_name_or_path, trust_remote_code=True, torch_dtype="auto")
|
||||||
loftq_config = LoftQConfig(loftq_bits=loftq_bits, loftq_iter=loftq_iter)
|
loftq_config = LoftQConfig(loftq_bits=loftq_bits, loftq_iter=loftq_iter)
|
||||||
@@ -57,25 +51,34 @@ def quantize_loftq(
|
|||||||
inference_mode=True,
|
inference_mode=True,
|
||||||
r=lora_rank,
|
r=lora_rank,
|
||||||
lora_alpha=lora_alpha if lora_alpha is not None else lora_rank * 2,
|
lora_alpha=lora_alpha if lora_alpha is not None else lora_rank * 2,
|
||||||
lora_dropout=0.1,
|
lora_dropout=lora_dropout,
|
||||||
target_modules=[name.strip() for name in lora_target.split(",")],
|
target_modules=[name.strip() for name in lora_target.split(",")],
|
||||||
init_lora_weights="loftq",
|
init_lora_weights="loftq",
|
||||||
loftq_config=loftq_config,
|
loftq_config=loftq_config,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Init LoftQ model
|
# Init LoftQ model
|
||||||
lora_model = get_peft_model(model, lora_config)
|
print("Initializing LoftQ weights, it may be take several minutes, wait patiently.")
|
||||||
base_model: "PreTrainedModel" = lora_model.get_base_model()
|
peft_model = get_peft_model(model, lora_config)
|
||||||
|
loftq_dir = os.path.join(output_dir, "loftq_init")
|
||||||
|
|
||||||
# Save LoftQ model
|
# Save LoftQ model
|
||||||
setattr(lora_model.base_model.peft_config["default"], "base_model_name_or_path", save_dir)
|
setattr(peft_model.peft_config["default"], "base_model_name_or_path", output_dir)
|
||||||
setattr(lora_model.base_model.peft_config["default"], "init_lora_weights", True)
|
setattr(peft_model.peft_config["default"], "init_lora_weights", True) # don't apply loftq again
|
||||||
lora_model.save_pretrained(os.path.join(save_dir, "adapters"), safe_serialization=save_safetensors)
|
peft_model.save_pretrained(loftq_dir, safe_serialization=save_safetensors)
|
||||||
|
print("Adapter weights saved in {}".format(loftq_dir))
|
||||||
|
|
||||||
# Save base model
|
# Save base model
|
||||||
unwrap_model(base_model)
|
base_model: "PreTrainedModel" = peft_model.unload()
|
||||||
base_model.save_pretrained(save_dir, safe_serialization=save_safetensors)
|
base_model.save_pretrained(output_dir, safe_serialization=save_safetensors)
|
||||||
tokenizer.save_pretrained(save_dir)
|
tokenizer.save_pretrained(output_dir)
|
||||||
|
print("Model weights saved in {}".format(output_dir))
|
||||||
|
|
||||||
|
print("- Fine-tune this model with:")
|
||||||
|
print("model_name_or_path: {}".format(output_dir))
|
||||||
|
print("adapter_name_or_path: {}".format(loftq_dir))
|
||||||
|
print("finetuning_type: lora")
|
||||||
|
print("quantization_bit: {}".format(loftq_bits))
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|||||||
82
scripts/pissa_init.py
Normal file
82
scripts/pissa_init.py
Normal file
@@ -0,0 +1,82 @@
|
|||||||
|
# coding=utf-8
|
||||||
|
# Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
|
||||||
|
#
|
||||||
|
# This code is based on the HuggingFace's PEFT library.
|
||||||
|
# https://github.com/huggingface/peft/blob/v0.11.0/examples/pissa_finetuning/preprocess.py
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
import os
|
||||||
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
|
import fire
|
||||||
|
from peft import LoraConfig, TaskType, get_peft_model
|
||||||
|
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||||
|
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from transformers import PreTrainedModel
|
||||||
|
|
||||||
|
|
||||||
|
def quantize_pissa(
|
||||||
|
model_name_or_path: str,
|
||||||
|
output_dir: str,
|
||||||
|
pissa_iter: int = 4,
|
||||||
|
lora_alpha: int = None,
|
||||||
|
lora_rank: int = 16,
|
||||||
|
lora_dropout: float = 0,
|
||||||
|
lora_target: str = "q_proj,v_proj",
|
||||||
|
save_safetensors: bool = True,
|
||||||
|
):
|
||||||
|
r"""
|
||||||
|
Initializes LoRA weights with Principal Singular values and Singular vectors Adaptation (PiSSA)
|
||||||
|
Usage: python pissa_init.py --model_name_or_path path_to_model --output_dir output_dir
|
||||||
|
"""
|
||||||
|
tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, trust_remote_code=True)
|
||||||
|
model = AutoModelForCausalLM.from_pretrained(model_name_or_path, trust_remote_code=True, torch_dtype="auto")
|
||||||
|
lora_config = LoraConfig(
|
||||||
|
task_type=TaskType.CAUSAL_LM,
|
||||||
|
r=lora_rank,
|
||||||
|
lora_alpha=lora_alpha if lora_alpha is not None else lora_rank * 2,
|
||||||
|
lora_dropout=lora_dropout,
|
||||||
|
target_modules=[name.strip() for name in lora_target.split(",")],
|
||||||
|
init_lora_weights="pissa" if pissa_iter == -1 else "pissa_niter_{}".format(pissa_iter),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Init PiSSA model
|
||||||
|
peft_model = get_peft_model(model, lora_config)
|
||||||
|
pissa_dir = os.path.join(output_dir, "pissa_init")
|
||||||
|
|
||||||
|
# Save PiSSA model
|
||||||
|
setattr(peft_model.peft_config["default"], "init_lora_weights", True) # don't apply pissa again
|
||||||
|
peft_model.save_pretrained(pissa_dir, safe_serialization=save_safetensors)
|
||||||
|
print("Adapter weights saved in {}".format(pissa_dir))
|
||||||
|
|
||||||
|
# Save base model
|
||||||
|
base_model: "PreTrainedModel" = peft_model.unload()
|
||||||
|
base_model.save_pretrained(output_dir, safe_serialization=save_safetensors)
|
||||||
|
tokenizer.save_pretrained(output_dir)
|
||||||
|
print("Model weights saved in {}".format(output_dir))
|
||||||
|
|
||||||
|
print("- Fine-tune this model with:")
|
||||||
|
print("model_name_or_path: {}".format(output_dir))
|
||||||
|
print("adapter_name_or_path: {}".format(pissa_dir))
|
||||||
|
print("finetuning_type: lora")
|
||||||
|
print("pissa_init: false")
|
||||||
|
print("pissa_convert: true")
|
||||||
|
print("- and optionally with:")
|
||||||
|
print("quantization_bit: 4")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
fire.Fire(quantize_pissa)
|
||||||
@@ -1,3 +1,18 @@
|
|||||||
|
# coding=utf-8
|
||||||
|
# Copyright 2024 the LlamaFactory team.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
from typing import Sequence
|
from typing import Sequence
|
||||||
|
|||||||
16
setup.py
16
setup.py
@@ -1,3 +1,17 @@
|
|||||||
|
# Copyright 2024 the LlamaFactory team.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
import os
|
import os
|
||||||
import re
|
import re
|
||||||
|
|
||||||
@@ -23,7 +37,7 @@ extra_require = {
|
|||||||
"torch": ["torch>=1.13.1"],
|
"torch": ["torch>=1.13.1"],
|
||||||
"torch-npu": ["torch==2.1.0", "torch-npu==2.1.0.post3", "decorator"],
|
"torch-npu": ["torch==2.1.0", "torch-npu==2.1.0.post3", "decorator"],
|
||||||
"metrics": ["nltk", "jieba", "rouge-chinese"],
|
"metrics": ["nltk", "jieba", "rouge-chinese"],
|
||||||
"deepspeed": ["deepspeed>=0.10.0,<=0.14.0"],
|
"deepspeed": ["deepspeed>=0.10.0"],
|
||||||
"bitsandbytes": ["bitsandbytes>=0.39.0"],
|
"bitsandbytes": ["bitsandbytes>=0.39.0"],
|
||||||
"vllm": ["vllm>=0.4.3"],
|
"vllm": ["vllm>=0.4.3"],
|
||||||
"galore": ["galore-torch"],
|
"galore": ["galore-torch"],
|
||||||
|
|||||||
14
src/api.py
14
src/api.py
@@ -1,3 +1,17 @@
|
|||||||
|
# Copyright 2024 the LlamaFactory team.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
import os
|
import os
|
||||||
|
|
||||||
import uvicorn
|
import uvicorn
|
||||||
|
|||||||
@@ -1,3 +1,17 @@
|
|||||||
|
# Copyright 2024 the LlamaFactory team.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
# Level: api, webui > chat, eval, train > data, model > hparams > extras
|
# Level: api, webui > chat, eval, train > data, model > hparams > extras
|
||||||
|
|
||||||
from .cli import VERSION
|
from .cli import VERSION
|
||||||
|
|||||||
@@ -1,3 +1,17 @@
|
|||||||
|
# Copyright 2024 the LlamaFactory team.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
import os
|
import os
|
||||||
from contextlib import asynccontextmanager
|
from contextlib import asynccontextmanager
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|||||||
@@ -1,3 +1,17 @@
|
|||||||
|
# Copyright 2024 the LlamaFactory team.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
import base64
|
import base64
|
||||||
import io
|
import io
|
||||||
import json
|
import json
|
||||||
@@ -78,9 +92,11 @@ def _process_request(
|
|||||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid role")
|
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid role")
|
||||||
|
|
||||||
if message.role == Role.ASSISTANT and isinstance(message.tool_calls, list) and len(message.tool_calls):
|
if message.role == Role.ASSISTANT and isinstance(message.tool_calls, list) and len(message.tool_calls):
|
||||||
name = message.tool_calls[0].function.name
|
tool_calls = [
|
||||||
arguments = message.tool_calls[0].function.arguments
|
{"name": tool_call.function.name, "argument": tool_call.function.arguments}
|
||||||
content = json.dumps({"name": name, "argument": arguments}, ensure_ascii=False)
|
for tool_call in message.tool_calls
|
||||||
|
]
|
||||||
|
content = json.dumps(tool_calls, ensure_ascii=False)
|
||||||
input_messages.append({"role": ROLE_MAPPING[Role.FUNCTION], "content": content})
|
input_messages.append({"role": ROLE_MAPPING[Role.FUNCTION], "content": content})
|
||||||
elif isinstance(message.content, list):
|
elif isinstance(message.content, list):
|
||||||
for input_item in message.content:
|
for input_item in message.content:
|
||||||
@@ -104,7 +120,7 @@ def _process_request(
|
|||||||
if isinstance(tool_list, list) and len(tool_list):
|
if isinstance(tool_list, list) and len(tool_list):
|
||||||
try:
|
try:
|
||||||
tools = json.dumps([dictify(tool.function) for tool in tool_list], ensure_ascii=False)
|
tools = json.dumps([dictify(tool.function) for tool in tool_list], ensure_ascii=False)
|
||||||
except Exception:
|
except json.JSONDecodeError:
|
||||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid tools")
|
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid tools")
|
||||||
else:
|
else:
|
||||||
tools = None
|
tools = None
|
||||||
@@ -146,15 +162,17 @@ async def create_chat_completion_response(
|
|||||||
choices = []
|
choices = []
|
||||||
for i, response in enumerate(responses):
|
for i, response in enumerate(responses):
|
||||||
if tools:
|
if tools:
|
||||||
result = chat_model.engine.template.format_tools.extract(response.response_text)
|
result = chat_model.engine.template.extract_tool(response.response_text)
|
||||||
else:
|
else:
|
||||||
result = response.response_text
|
result = response.response_text
|
||||||
|
|
||||||
if isinstance(result, tuple):
|
if isinstance(result, list):
|
||||||
name, arguments = result
|
tool_calls = []
|
||||||
function = Function(name=name, arguments=arguments)
|
for tool in result:
|
||||||
tool_call = FunctionCall(id="call_{}".format(uuid.uuid4().hex), function=function)
|
function = Function(name=tool[0], arguments=tool[1])
|
||||||
response_message = ChatCompletionMessage(role=Role.ASSISTANT, tool_calls=[tool_call])
|
tool_calls.append(FunctionCall(id="call_{}".format(uuid.uuid4().hex), function=function))
|
||||||
|
|
||||||
|
response_message = ChatCompletionMessage(role=Role.ASSISTANT, tool_calls=tool_calls)
|
||||||
finish_reason = Finish.TOOL
|
finish_reason = Finish.TOOL
|
||||||
else:
|
else:
|
||||||
response_message = ChatCompletionMessage(role=Role.ASSISTANT, content=result)
|
response_message = ChatCompletionMessage(role=Role.ASSISTANT, content=result)
|
||||||
|
|||||||
@@ -1,3 +1,17 @@
|
|||||||
|
# Copyright 2024 the LlamaFactory team.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
import json
|
import json
|
||||||
from typing import TYPE_CHECKING, Any, Dict
|
from typing import TYPE_CHECKING, Any, Dict
|
||||||
|
|
||||||
|
|||||||
@@ -1,3 +1,17 @@
|
|||||||
|
# Copyright 2024 the LlamaFactory team.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
import time
|
import time
|
||||||
from enum import Enum, unique
|
from enum import Enum, unique
|
||||||
from typing import Any, Dict, List, Optional, Union
|
from typing import Any, Dict, List, Optional, Union
|
||||||
|
|||||||
@@ -1,3 +1,17 @@
|
|||||||
|
# Copyright 2024 the LlamaFactory team.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
from .base_engine import BaseEngine
|
from .base_engine import BaseEngine
|
||||||
from .chat_model import ChatModel
|
from .chat_model import ChatModel
|
||||||
|
|
||||||
|
|||||||
@@ -1,3 +1,17 @@
|
|||||||
|
# Copyright 2024 the LlamaFactory team.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import TYPE_CHECKING, Any, AsyncGenerator, Dict, List, Literal, Optional, Sequence, Union
|
from typing import TYPE_CHECKING, Any, AsyncGenerator, Dict, List, Literal, Optional, Sequence, Union
|
||||||
@@ -36,11 +50,6 @@ class BaseEngine(ABC):
|
|||||||
generating_args: "GeneratingArguments",
|
generating_args: "GeneratingArguments",
|
||||||
) -> None: ...
|
) -> None: ...
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
async def start(
|
|
||||||
self,
|
|
||||||
) -> None: ...
|
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
async def chat(
|
async def chat(
|
||||||
self,
|
self,
|
||||||
|
|||||||
@@ -1,3 +1,20 @@
|
|||||||
|
# Copyright 2024 THUDM and the LlamaFactory team.
|
||||||
|
#
|
||||||
|
# This code is inspired by the THUDM's ChatGLM implementation.
|
||||||
|
# https://github.com/THUDM/ChatGLM-6B/blob/main/cli_demo.py
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
from threading import Thread
|
from threading import Thread
|
||||||
from typing import TYPE_CHECKING, Any, AsyncGenerator, Dict, Generator, List, Optional, Sequence
|
from typing import TYPE_CHECKING, Any, AsyncGenerator, Dict, Generator, List, Optional, Sequence
|
||||||
@@ -14,7 +31,7 @@ if TYPE_CHECKING:
|
|||||||
from .base_engine import BaseEngine, Response
|
from .base_engine import BaseEngine, Response
|
||||||
|
|
||||||
|
|
||||||
def _start_background_loop(loop: asyncio.AbstractEventLoop) -> None:
|
def _start_background_loop(loop: "asyncio.AbstractEventLoop") -> None:
|
||||||
asyncio.set_event_loop(loop)
|
asyncio.set_event_loop(loop)
|
||||||
loop.run_forever()
|
loop.run_forever()
|
||||||
|
|
||||||
@@ -32,7 +49,6 @@ class ChatModel:
|
|||||||
self._loop = asyncio.new_event_loop()
|
self._loop = asyncio.new_event_loop()
|
||||||
self._thread = Thread(target=_start_background_loop, args=(self._loop,), daemon=True)
|
self._thread = Thread(target=_start_background_loop, args=(self._loop,), daemon=True)
|
||||||
self._thread.start()
|
self._thread.start()
|
||||||
asyncio.run_coroutine_threadsafe(self.engine.start(), self._loop)
|
|
||||||
|
|
||||||
def chat(
|
def chat(
|
||||||
self,
|
self,
|
||||||
|
|||||||
@@ -1,3 +1,17 @@
|
|||||||
|
# Copyright 2024 the LlamaFactory team.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
import concurrent.futures
|
import concurrent.futures
|
||||||
import os
|
import os
|
||||||
@@ -45,6 +59,14 @@ class HuggingfaceEngine(BaseEngine):
|
|||||||
self.tokenizer, model_args, finetuning_args, is_trainable=False, add_valuehead=(not self.can_generate)
|
self.tokenizer, model_args, finetuning_args, is_trainable=False, add_valuehead=(not self.can_generate)
|
||||||
) # must after fixing tokenizer to resize vocab
|
) # must after fixing tokenizer to resize vocab
|
||||||
self.generating_args = generating_args.to_dict()
|
self.generating_args = generating_args.to_dict()
|
||||||
|
try:
|
||||||
|
asyncio.get_event_loop()
|
||||||
|
except RuntimeError:
|
||||||
|
logger.warning("There is no current event loop, creating a new one.")
|
||||||
|
loop = asyncio.new_event_loop()
|
||||||
|
asyncio.set_event_loop(loop)
|
||||||
|
|
||||||
|
self.semaphore = asyncio.Semaphore(int(os.environ.get("MAX_CONCURRENT", "1")))
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _process_args(
|
def _process_args(
|
||||||
@@ -245,9 +267,6 @@ class HuggingfaceEngine(BaseEngine):
|
|||||||
|
|
||||||
return scores
|
return scores
|
||||||
|
|
||||||
async def start(self) -> None:
|
|
||||||
self._semaphore = asyncio.Semaphore(int(os.environ.get("MAX_CONCURRENT", 1)))
|
|
||||||
|
|
||||||
async def chat(
|
async def chat(
|
||||||
self,
|
self,
|
||||||
messages: Sequence[Dict[str, str]],
|
messages: Sequence[Dict[str, str]],
|
||||||
@@ -272,7 +291,7 @@ class HuggingfaceEngine(BaseEngine):
|
|||||||
image,
|
image,
|
||||||
input_kwargs,
|
input_kwargs,
|
||||||
)
|
)
|
||||||
async with self._semaphore:
|
async with self.semaphore:
|
||||||
with concurrent.futures.ThreadPoolExecutor() as pool:
|
with concurrent.futures.ThreadPoolExecutor() as pool:
|
||||||
return await loop.run_in_executor(pool, self._chat, *input_args)
|
return await loop.run_in_executor(pool, self._chat, *input_args)
|
||||||
|
|
||||||
@@ -300,7 +319,7 @@ class HuggingfaceEngine(BaseEngine):
|
|||||||
image,
|
image,
|
||||||
input_kwargs,
|
input_kwargs,
|
||||||
)
|
)
|
||||||
async with self._semaphore:
|
async with self.semaphore:
|
||||||
with concurrent.futures.ThreadPoolExecutor() as pool:
|
with concurrent.futures.ThreadPoolExecutor() as pool:
|
||||||
stream = self._stream_chat(*input_args)
|
stream = self._stream_chat(*input_args)
|
||||||
while True:
|
while True:
|
||||||
@@ -319,6 +338,6 @@ class HuggingfaceEngine(BaseEngine):
|
|||||||
|
|
||||||
loop = asyncio.get_running_loop()
|
loop = asyncio.get_running_loop()
|
||||||
input_args = (self.model, self.tokenizer, batch_input, input_kwargs)
|
input_args = (self.model, self.tokenizer, batch_input, input_kwargs)
|
||||||
async with self._semaphore:
|
async with self.semaphore:
|
||||||
with concurrent.futures.ThreadPoolExecutor() as pool:
|
with concurrent.futures.ThreadPoolExecutor() as pool:
|
||||||
return await loop.run_in_executor(pool, self._get_scores, *input_args)
|
return await loop.run_in_executor(pool, self._get_scores, *input_args)
|
||||||
|
|||||||
@@ -1,10 +1,24 @@
|
|||||||
|
# Copyright 2024 the LlamaFactory team.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
import uuid
|
import uuid
|
||||||
from typing import TYPE_CHECKING, AsyncGenerator, AsyncIterator, Dict, List, Optional, Sequence, Union
|
from typing import TYPE_CHECKING, AsyncGenerator, AsyncIterator, Dict, List, Optional, Sequence, Union
|
||||||
|
|
||||||
from ..data import get_template_and_fix_tokenizer
|
from ..data import get_template_and_fix_tokenizer
|
||||||
from ..extras.logging import get_logger
|
from ..extras.logging import get_logger
|
||||||
from ..extras.misc import get_device_count
|
from ..extras.misc import get_device_count
|
||||||
from ..extras.packages import is_vllm_available
|
from ..extras.packages import is_vllm_available, is_vllm_version_greater_than_0_5
|
||||||
from ..model import load_config, load_tokenizer
|
from ..model import load_config, load_tokenizer
|
||||||
from ..model.model_utils.visual import LlavaMultiModalProjectorForYiVLForVLLM
|
from ..model.model_utils.visual import LlavaMultiModalProjectorForYiVLForVLLM
|
||||||
from .base_engine import BaseEngine, Response
|
from .base_engine import BaseEngine, Response
|
||||||
@@ -13,7 +27,11 @@ from .base_engine import BaseEngine, Response
|
|||||||
if is_vllm_available():
|
if is_vllm_available():
|
||||||
from vllm import AsyncEngineArgs, AsyncLLMEngine, RequestOutput, SamplingParams
|
from vllm import AsyncEngineArgs, AsyncLLMEngine, RequestOutput, SamplingParams
|
||||||
from vllm.lora.request import LoRARequest
|
from vllm.lora.request import LoRARequest
|
||||||
from vllm.sequence import MultiModalData
|
|
||||||
|
if is_vllm_version_greater_than_0_5():
|
||||||
|
from vllm.multimodal.image import ImagePixelData
|
||||||
|
else:
|
||||||
|
from vllm.sequence import MultiModalData
|
||||||
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
@@ -48,7 +66,7 @@ class VllmEngine(BaseEngine):
|
|||||||
"model": model_args.model_name_or_path,
|
"model": model_args.model_name_or_path,
|
||||||
"trust_remote_code": True,
|
"trust_remote_code": True,
|
||||||
"download_dir": model_args.cache_dir,
|
"download_dir": model_args.cache_dir,
|
||||||
"dtype": model_args.vllm_dtype,
|
"dtype": model_args.infer_dtype,
|
||||||
"max_model_len": model_args.vllm_maxlen,
|
"max_model_len": model_args.vllm_maxlen,
|
||||||
"tensor_parallel_size": get_device_count() or 1,
|
"tensor_parallel_size": get_device_count() or 1,
|
||||||
"gpu_memory_utilization": model_args.vllm_gpu_util,
|
"gpu_memory_utilization": model_args.vllm_gpu_util,
|
||||||
@@ -106,7 +124,10 @@ class VllmEngine(BaseEngine):
|
|||||||
if self.processor is not None and image is not None: # add image features
|
if self.processor is not None and image is not None: # add image features
|
||||||
image_processor: "BaseImageProcessor" = getattr(self.processor, "image_processor")
|
image_processor: "BaseImageProcessor" = getattr(self.processor, "image_processor")
|
||||||
pixel_values = image_processor(image, return_tensors="pt")["pixel_values"]
|
pixel_values = image_processor(image, return_tensors="pt")["pixel_values"]
|
||||||
multi_modal_data = MultiModalData(type=MultiModalData.Type.IMAGE, data=pixel_values)
|
if is_vllm_version_greater_than_0_5():
|
||||||
|
multi_modal_data = ImagePixelData(image=pixel_values)
|
||||||
|
else: # TODO: remove vllm 0.4.3 support
|
||||||
|
multi_modal_data = MultiModalData(type=MultiModalData.Type.IMAGE, data=pixel_values)
|
||||||
else:
|
else:
|
||||||
multi_modal_data = None
|
multi_modal_data = None
|
||||||
|
|
||||||
@@ -162,9 +183,6 @@ class VllmEngine(BaseEngine):
|
|||||||
)
|
)
|
||||||
return result_generator
|
return result_generator
|
||||||
|
|
||||||
async def start(self) -> None:
|
|
||||||
pass
|
|
||||||
|
|
||||||
async def chat(
|
async def chat(
|
||||||
self,
|
self,
|
||||||
messages: Sequence[Dict[str, str]],
|
messages: Sequence[Dict[str, str]],
|
||||||
|
|||||||
@@ -1,3 +1,17 @@
|
|||||||
|
# Copyright 2024 the LlamaFactory team.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
import os
|
import os
|
||||||
import random
|
import random
|
||||||
import subprocess
|
import subprocess
|
||||||
|
|||||||
@@ -1,3 +1,17 @@
|
|||||||
|
# Copyright 2024 the LlamaFactory team.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
from .collator import KTODataCollatorWithPadding, PairwiseDataCollatorWithPadding
|
from .collator import KTODataCollatorWithPadding, PairwiseDataCollatorWithPadding
|
||||||
from .data_utils import Role, split_dataset
|
from .data_utils import Role, split_dataset
|
||||||
from .loader import get_dataset
|
from .loader import get_dataset
|
||||||
|
|||||||
@@ -1,3 +1,17 @@
|
|||||||
|
# Copyright 2024 the LlamaFactory team.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
import os
|
import os
|
||||||
from functools import partial
|
from functools import partial
|
||||||
from typing import TYPE_CHECKING, Any, Dict, List, Union
|
from typing import TYPE_CHECKING, Any, Dict, List, Union
|
||||||
@@ -10,6 +24,7 @@ from .data_utils import Role
|
|||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from datasets import Dataset, IterableDataset
|
from datasets import Dataset, IterableDataset
|
||||||
|
from transformers import Seq2SeqTrainingArguments
|
||||||
|
|
||||||
from ..hparams import DataArguments
|
from ..hparams import DataArguments
|
||||||
from .parser import DatasetAttr
|
from .parser import DatasetAttr
|
||||||
@@ -175,7 +190,10 @@ def convert_sharegpt(
|
|||||||
|
|
||||||
|
|
||||||
def align_dataset(
|
def align_dataset(
|
||||||
dataset: Union["Dataset", "IterableDataset"], dataset_attr: "DatasetAttr", data_args: "DataArguments"
|
dataset: Union["Dataset", "IterableDataset"],
|
||||||
|
dataset_attr: "DatasetAttr",
|
||||||
|
data_args: "DataArguments",
|
||||||
|
training_args: "Seq2SeqTrainingArguments",
|
||||||
) -> Union["Dataset", "IterableDataset"]:
|
) -> Union["Dataset", "IterableDataset"]:
|
||||||
r"""
|
r"""
|
||||||
Aligned dataset:
|
Aligned dataset:
|
||||||
@@ -208,7 +226,7 @@ def align_dataset(
|
|||||||
if not data_args.streaming:
|
if not data_args.streaming:
|
||||||
kwargs = dict(
|
kwargs = dict(
|
||||||
num_proc=data_args.preprocessing_num_workers,
|
num_proc=data_args.preprocessing_num_workers,
|
||||||
load_from_cache_file=(not data_args.overwrite_cache),
|
load_from_cache_file=(not data_args.overwrite_cache) or (training_args.local_process_index != 0),
|
||||||
desc="Converting format of dataset",
|
desc="Converting format of dataset",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -1,3 +1,17 @@
|
|||||||
|
# Copyright 2024 the LlamaFactory team.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Any, Dict, Sequence
|
from typing import Any, Dict, Sequence
|
||||||
|
|
||||||
|
|||||||
@@ -1,3 +1,17 @@
|
|||||||
|
# Copyright 2024 the LlamaFactory team.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
from enum import Enum, unique
|
from enum import Enum, unique
|
||||||
from typing import TYPE_CHECKING, Dict, List, Tuple, Union
|
from typing import TYPE_CHECKING, Dict, List, Tuple, Union
|
||||||
|
|
||||||
|
|||||||
@@ -1,3 +1,17 @@
|
|||||||
|
# Copyright 2024 the LlamaFactory team.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
import json
|
import json
|
||||||
import re
|
import re
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
@@ -8,21 +22,23 @@ from typing import Any, Dict, List, Literal, Optional, Sequence, Set, Tuple, Uni
|
|||||||
SLOTS = Sequence[Union[str, Set[str], Dict[str, str]]]
|
SLOTS = Sequence[Union[str, Set[str], Dict[str, str]]]
|
||||||
|
|
||||||
|
|
||||||
JSON_FORMAT_PROMPT = (
|
DEFAULT_TOOL_PROMPT = (
|
||||||
""", in a JSON format representing the kwargs (e.g. ```{"input": "hello world", "num_beams": 5}```)"""
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
TOOL_SYSTEM_PROMPT = (
|
|
||||||
"You have access to the following tools:\n{tool_text}"
|
"You have access to the following tools:\n{tool_text}"
|
||||||
"Use the following format if using a tool:\n"
|
"Use the following format if using a tool:\n"
|
||||||
"```\n"
|
"```\n"
|
||||||
"Action: tool name (one of [{tool_names}]).\n"
|
"Action: tool name (one of [{tool_names}]).\n"
|
||||||
"Action Input: the input to the tool{format_prompt}.\n"
|
"Action Input: the input to the tool, in a JSON format representing the kwargs "
|
||||||
|
"""(e.g. ```{{"input": "hello world", "num_beams": 5}}```).\n"""
|
||||||
"```\n"
|
"```\n"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
GLM4_TOOL_PROMPT = (
|
||||||
|
"你是一个名为 GLM-4 的人工智能助手。你是基于智谱AI训练的语言模型 GLM-4 模型开发的,"
|
||||||
|
"你的任务是针对用户的问题和要求提供适当的答复和支持。{tool_text}"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def default_tool_formatter(tools: List[Dict[str, Any]]) -> str:
|
def default_tool_formatter(tools: List[Dict[str, Any]]) -> str:
|
||||||
tool_text = ""
|
tool_text = ""
|
||||||
tool_names = []
|
tool_names = []
|
||||||
@@ -48,36 +64,60 @@ def default_tool_formatter(tools: List[Dict[str, Any]]) -> str:
|
|||||||
)
|
)
|
||||||
tool_names.append(tool["name"])
|
tool_names.append(tool["name"])
|
||||||
|
|
||||||
return TOOL_SYSTEM_PROMPT.format(
|
return DEFAULT_TOOL_PROMPT.format(tool_text=tool_text, tool_names=", ".join(tool_names))
|
||||||
tool_text=tool_text, tool_names=", ".join(tool_names), format_prompt=JSON_FORMAT_PROMPT
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def default_tool_extractor(content: str) -> Union[str, Tuple[str, str]]:
|
def default_tool_extractor(content: str) -> Union[str, List[Tuple[str, str]]]:
|
||||||
regex = re.compile(r"Action:\s*([a-zA-Z0-9_]+).*?Action Input:\s*(.*)", re.DOTALL)
|
regex = re.compile(r"Action:\s*([a-zA-Z0-9_]+)\s*Action Input:\s*(.+?)(?=\s*Action:|\s*$)", re.DOTALL)
|
||||||
action_match = re.search(regex, content)
|
action_match: List[Tuple[str, str]] = re.findall(regex, content)
|
||||||
if not action_match:
|
if not action_match:
|
||||||
return content
|
return content
|
||||||
|
|
||||||
tool_name = action_match.group(1).strip()
|
results = []
|
||||||
tool_input = action_match.group(2).strip().strip('"').strip("```")
|
for match in action_match:
|
||||||
|
tool_name = match[0].strip()
|
||||||
|
tool_input = match[1].strip().strip('"').strip("```")
|
||||||
|
try:
|
||||||
|
arguments = json.loads(tool_input)
|
||||||
|
results.append((tool_name, json.dumps(arguments, ensure_ascii=False)))
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
return content
|
||||||
|
|
||||||
|
return results
|
||||||
|
|
||||||
|
|
||||||
|
def glm4_tool_formatter(tools: List[Dict[str, Any]]) -> str:
|
||||||
|
tool_text = ""
|
||||||
|
for tool in tools:
|
||||||
|
tool_text += "\n\n## {name}\n\n{body}\n在调用上述函数时,请使用 Json 格式表示调用的参数。".format(
|
||||||
|
name=tool["name"], body=json.dumps(tool, indent=4, ensure_ascii=False)
|
||||||
|
)
|
||||||
|
|
||||||
|
return GLM4_TOOL_PROMPT.format(tool_text=tool_text)
|
||||||
|
|
||||||
|
|
||||||
|
def glm4_tool_extractor(content: str) -> Union[str, List[Tuple[str, str]]]:
|
||||||
|
if "\n" not in content:
|
||||||
|
return content
|
||||||
|
|
||||||
|
tool_name, tool_input = content.split("\n", maxsplit=1)
|
||||||
try:
|
try:
|
||||||
arguments = json.loads(tool_input)
|
arguments = json.loads(tool_input)
|
||||||
except json.JSONDecodeError:
|
except json.JSONDecodeError:
|
||||||
return content
|
return content
|
||||||
|
|
||||||
return tool_name, json.dumps(arguments, ensure_ascii=False)
|
return [(tool_name, json.dumps(arguments, ensure_ascii=False))]
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class Formatter(ABC):
|
class Formatter(ABC):
|
||||||
slots: SLOTS = field(default_factory=list)
|
slots: SLOTS = field(default_factory=list)
|
||||||
tool_format: Optional[Literal["default"]] = None
|
tool_format: Optional[Literal["default", "glm4"]] = None
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def apply(self, **kwargs) -> SLOTS: ...
|
def apply(self, **kwargs) -> SLOTS: ...
|
||||||
|
|
||||||
def extract(self, content: str) -> Union[str, Tuple[str, str]]:
|
def extract(self, content: str) -> Union[str, List[Tuple[str, str]]]:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
|
|
||||||
@@ -140,22 +180,28 @@ class FunctionFormatter(Formatter):
|
|||||||
|
|
||||||
def apply(self, **kwargs) -> SLOTS:
|
def apply(self, **kwargs) -> SLOTS:
|
||||||
content = kwargs.pop("content")
|
content = kwargs.pop("content")
|
||||||
|
functions: List[Tuple[str, str]] = []
|
||||||
try:
|
try:
|
||||||
function = json.loads(content)
|
tool_calls = json.loads(content)
|
||||||
name = function["name"]
|
if not isinstance(tool_calls, list): # parallel function call
|
||||||
arguments = json.dumps(function["arguments"], ensure_ascii=False)
|
tool_calls = [tool_calls]
|
||||||
except Exception:
|
|
||||||
name, arguments = "", ""
|
for tool_call in tool_calls:
|
||||||
|
functions.append((tool_call["name"], json.dumps(tool_call["arguments"], ensure_ascii=False)))
|
||||||
|
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
functions = []
|
||||||
|
|
||||||
elements = []
|
elements = []
|
||||||
for slot in self.slots:
|
for name, arguments in functions:
|
||||||
if isinstance(slot, str):
|
for slot in self.slots:
|
||||||
slot = slot.replace("{{name}}", name).replace("{{arguments}}", arguments)
|
if isinstance(slot, str):
|
||||||
elements.append(slot)
|
slot = slot.replace("{{name}}", name).replace("{{arguments}}", arguments)
|
||||||
elif isinstance(slot, (dict, set)):
|
elements.append(slot)
|
||||||
elements.append(slot)
|
elif isinstance(slot, (dict, set)):
|
||||||
else:
|
elements.append(slot)
|
||||||
raise RuntimeError("Input must be string, set[str] or dict[str, str], got {}".format(type(slot)))
|
else:
|
||||||
|
raise RuntimeError("Input must be string, set[str] or dict[str, str], got {}".format(type(slot)))
|
||||||
|
|
||||||
return elements
|
return elements
|
||||||
|
|
||||||
@@ -163,25 +209,22 @@ class FunctionFormatter(Formatter):
|
|||||||
@dataclass
|
@dataclass
|
||||||
class ToolFormatter(Formatter):
|
class ToolFormatter(Formatter):
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
if self.tool_format is None:
|
if self.tool_format == "default":
|
||||||
|
self._tool_formatter = default_tool_formatter
|
||||||
|
self._tool_extractor = default_tool_extractor
|
||||||
|
elif self.tool_format == "glm4":
|
||||||
|
self._tool_formatter = glm4_tool_formatter
|
||||||
|
self._tool_extractor = glm4_tool_extractor
|
||||||
|
else:
|
||||||
raise ValueError("Tool format was not found.")
|
raise ValueError("Tool format was not found.")
|
||||||
|
|
||||||
def apply(self, **kwargs) -> SLOTS:
|
def apply(self, **kwargs) -> SLOTS:
|
||||||
content = kwargs.pop("content")
|
content = kwargs.pop("content")
|
||||||
try:
|
try:
|
||||||
tools = json.loads(content)
|
tools = json.loads(content)
|
||||||
if not len(tools):
|
return [self._tool_formatter(tools) if len(tools) != 0 else ""]
|
||||||
return [""]
|
except json.JSONDecodeError:
|
||||||
|
|
||||||
if self.tool_format == "default":
|
|
||||||
return [default_tool_formatter(tools)]
|
|
||||||
else:
|
|
||||||
raise NotImplementedError
|
|
||||||
except Exception:
|
|
||||||
return [""]
|
return [""]
|
||||||
|
|
||||||
def extract(self, content: str) -> Union[str, Tuple[str, str]]:
|
def extract(self, content: str) -> Union[str, List[Tuple[str, str]]]:
|
||||||
if self.tool_format == "default":
|
return self._tool_extractor(content)
|
||||||
return default_tool_extractor(content)
|
|
||||||
else:
|
|
||||||
raise NotImplementedError
|
|
||||||
|
|||||||
@@ -1,3 +1,17 @@
|
|||||||
|
# Copyright 2024 the LlamaFactory team.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
import inspect
|
import inspect
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
@@ -18,8 +32,7 @@ from .template import get_template_and_fix_tokenizer
|
|||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from datasets import Dataset, IterableDataset
|
from datasets import Dataset, IterableDataset
|
||||||
from transformers import ProcessorMixin, Seq2SeqTrainingArguments
|
from transformers import PreTrainedTokenizer, ProcessorMixin, Seq2SeqTrainingArguments
|
||||||
from transformers.tokenization_utils import PreTrainedTokenizer
|
|
||||||
|
|
||||||
from ..hparams import DataArguments, ModelArguments
|
from ..hparams import DataArguments, ModelArguments
|
||||||
from .parser import DatasetAttr
|
from .parser import DatasetAttr
|
||||||
@@ -32,6 +45,7 @@ def load_single_dataset(
|
|||||||
dataset_attr: "DatasetAttr",
|
dataset_attr: "DatasetAttr",
|
||||||
model_args: "ModelArguments",
|
model_args: "ModelArguments",
|
||||||
data_args: "DataArguments",
|
data_args: "DataArguments",
|
||||||
|
training_args: "Seq2SeqTrainingArguments",
|
||||||
) -> Union["Dataset", "IterableDataset"]:
|
) -> Union["Dataset", "IterableDataset"]:
|
||||||
logger.info("Loading dataset {}...".format(dataset_attr))
|
logger.info("Loading dataset {}...".format(dataset_attr))
|
||||||
data_path, data_name, data_dir, data_files = None, None, None, None
|
data_path, data_name, data_dir, data_files = None, None, None, None
|
||||||
@@ -123,7 +137,7 @@ def load_single_dataset(
|
|||||||
max_samples = min(data_args.max_samples, len(dataset))
|
max_samples = min(data_args.max_samples, len(dataset))
|
||||||
dataset = dataset.select(range(max_samples))
|
dataset = dataset.select(range(max_samples))
|
||||||
|
|
||||||
return align_dataset(dataset, dataset_attr, data_args)
|
return align_dataset(dataset, dataset_attr, data_args, training_args)
|
||||||
|
|
||||||
|
|
||||||
def get_dataset(
|
def get_dataset(
|
||||||
@@ -157,7 +171,8 @@ def get_dataset(
|
|||||||
if (stage == "rm" and dataset_attr.ranking is False) or (stage != "rm" and dataset_attr.ranking is True):
|
if (stage == "rm" and dataset_attr.ranking is False) or (stage != "rm" and dataset_attr.ranking is True):
|
||||||
raise ValueError("The dataset is not applicable in the current training stage.")
|
raise ValueError("The dataset is not applicable in the current training stage.")
|
||||||
|
|
||||||
all_datasets.append(load_single_dataset(dataset_attr, model_args, data_args))
|
all_datasets.append(load_single_dataset(dataset_attr, model_args, data_args, training_args))
|
||||||
|
|
||||||
dataset = merge_dataset(all_datasets, data_args, training_args)
|
dataset = merge_dataset(all_datasets, data_args, training_args)
|
||||||
|
|
||||||
with training_args.main_process_first(desc="pre-process dataset"):
|
with training_args.main_process_first(desc="pre-process dataset"):
|
||||||
@@ -169,7 +184,7 @@ def get_dataset(
|
|||||||
if not data_args.streaming:
|
if not data_args.streaming:
|
||||||
kwargs = dict(
|
kwargs = dict(
|
||||||
num_proc=data_args.preprocessing_num_workers,
|
num_proc=data_args.preprocessing_num_workers,
|
||||||
load_from_cache_file=(not data_args.overwrite_cache),
|
load_from_cache_file=(not data_args.overwrite_cache) or (training_args.local_process_index != 0),
|
||||||
desc="Running tokenizer on dataset",
|
desc="Running tokenizer on dataset",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -1,3 +1,17 @@
|
|||||||
|
# Copyright 2024 the LlamaFactory team.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
|
|||||||
@@ -1,3 +1,17 @@
|
|||||||
|
# Copyright 2024 the LlamaFactory team.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
from functools import partial
|
from functools import partial
|
||||||
from typing import TYPE_CHECKING, Callable, Literal, Optional, Tuple
|
from typing import TYPE_CHECKING, Callable, Literal, Optional, Tuple
|
||||||
|
|
||||||
@@ -13,8 +27,7 @@ from .processors.unsupervised import preprocess_unsupervised_dataset, print_unsu
|
|||||||
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from transformers import ProcessorMixin, Seq2SeqTrainingArguments
|
from transformers import PreTrainedTokenizer, ProcessorMixin, Seq2SeqTrainingArguments
|
||||||
from transformers.tokenization_utils import PreTrainedTokenizer
|
|
||||||
|
|
||||||
from ..hparams import DataArguments
|
from ..hparams import DataArguments
|
||||||
from .template import Template
|
from .template import Template
|
||||||
|
|||||||
@@ -1,3 +1,17 @@
|
|||||||
|
# Copyright 2024 the LlamaFactory team.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple
|
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple
|
||||||
|
|
||||||
from ...extras.constants import IGNORE_INDEX
|
from ...extras.constants import IGNORE_INDEX
|
||||||
@@ -6,8 +20,7 @@ from .processor_utils import get_paligemma_token_type_ids, get_pixel_values
|
|||||||
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from transformers import ProcessorMixin
|
from transformers import PreTrainedTokenizer, ProcessorMixin
|
||||||
from transformers.tokenization_utils import PreTrainedTokenizer
|
|
||||||
|
|
||||||
from ...hparams import DataArguments
|
from ...hparams import DataArguments
|
||||||
from ..template import Template
|
from ..template import Template
|
||||||
|
|||||||
@@ -1,3 +1,17 @@
|
|||||||
|
# Copyright 2024 the LlamaFactory team.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple
|
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple
|
||||||
|
|
||||||
from ...extras.constants import IGNORE_INDEX
|
from ...extras.constants import IGNORE_INDEX
|
||||||
@@ -6,8 +20,7 @@ from .processor_utils import get_paligemma_token_type_ids, get_pixel_values
|
|||||||
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from transformers import ProcessorMixin
|
from transformers import PreTrainedTokenizer, ProcessorMixin
|
||||||
from transformers.tokenization_utils import PreTrainedTokenizer
|
|
||||||
|
|
||||||
from ...hparams import DataArguments
|
from ...hparams import DataArguments
|
||||||
from ..template import Template
|
from ..template import Template
|
||||||
|
|||||||
@@ -1,9 +1,26 @@
|
|||||||
|
# Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
|
||||||
|
#
|
||||||
|
# This code is inspired by the HuggingFace's transformers library.
|
||||||
|
# https://github.com/huggingface/transformers/blob/v4.40.0/examples/pytorch/language-modeling/run_clm.py
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
from itertools import chain
|
from itertools import chain
|
||||||
from typing import TYPE_CHECKING, Any, Dict, List
|
from typing import TYPE_CHECKING, Any, Dict, List
|
||||||
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from transformers.tokenization_utils import PreTrainedTokenizer
|
from transformers import PreTrainedTokenizer
|
||||||
|
|
||||||
from ...hparams import DataArguments
|
from ...hparams import DataArguments
|
||||||
|
|
||||||
@@ -12,7 +29,8 @@ def preprocess_pretrain_dataset(
|
|||||||
examples: Dict[str, List[Any]], tokenizer: "PreTrainedTokenizer", data_args: "DataArguments"
|
examples: Dict[str, List[Any]], tokenizer: "PreTrainedTokenizer", data_args: "DataArguments"
|
||||||
) -> Dict[str, List[List[int]]]:
|
) -> Dict[str, List[List[int]]]:
|
||||||
# build grouped texts with format `X1 X2 X3 ...` if packing is enabled
|
# build grouped texts with format `X1 X2 X3 ...` if packing is enabled
|
||||||
text_examples = [messages[0]["content"] + tokenizer.eos_token for messages in examples["prompt"]]
|
eos_token = "<|end_of_text|>" if data_args.template == "llama3" else tokenizer.eos_token
|
||||||
|
text_examples = [messages[0]["content"] + eos_token for messages in examples["prompt"]]
|
||||||
|
|
||||||
if not data_args.packing:
|
if not data_args.packing:
|
||||||
if data_args.template == "gemma":
|
if data_args.template == "gemma":
|
||||||
|
|||||||
@@ -1,3 +1,17 @@
|
|||||||
|
# Copyright 2024 the LlamaFactory team.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
import bisect
|
import bisect
|
||||||
from typing import TYPE_CHECKING, List, Sequence
|
from typing import TYPE_CHECKING, List, Sequence
|
||||||
|
|
||||||
|
|||||||
@@ -1,3 +1,17 @@
|
|||||||
|
# Copyright 2024 the LlamaFactory team.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple
|
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple
|
||||||
|
|
||||||
@@ -7,8 +21,7 @@ from .processor_utils import get_paligemma_token_type_ids, get_pixel_values, gre
|
|||||||
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from transformers import ProcessorMixin
|
from transformers import PreTrainedTokenizer, ProcessorMixin
|
||||||
from transformers.tokenization_utils import PreTrainedTokenizer
|
|
||||||
|
|
||||||
from ...hparams import DataArguments
|
from ...hparams import DataArguments
|
||||||
from ..template import Template
|
from ..template import Template
|
||||||
|
|||||||
@@ -1,3 +1,17 @@
|
|||||||
|
# Copyright 2024 the LlamaFactory team.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple
|
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple
|
||||||
|
|
||||||
from ...extras.logging import get_logger
|
from ...extras.logging import get_logger
|
||||||
@@ -6,8 +20,7 @@ from .processor_utils import get_paligemma_token_type_ids, get_pixel_values
|
|||||||
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from transformers import ProcessorMixin
|
from transformers import PreTrainedTokenizer, ProcessorMixin
|
||||||
from transformers.tokenization_utils import PreTrainedTokenizer
|
|
||||||
|
|
||||||
from ...hparams import DataArguments
|
from ...hparams import DataArguments
|
||||||
from ..template import Template
|
from ..template import Template
|
||||||
|
|||||||
@@ -1,3 +1,17 @@
|
|||||||
|
# Copyright 2024 the LlamaFactory team.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Tuple, Union
|
from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Tuple, Union
|
||||||
|
|
||||||
@@ -24,12 +38,12 @@ class Template:
|
|||||||
format_observation: "Formatter"
|
format_observation: "Formatter"
|
||||||
format_tools: "Formatter"
|
format_tools: "Formatter"
|
||||||
format_separator: "Formatter"
|
format_separator: "Formatter"
|
||||||
|
format_prefix: "Formatter"
|
||||||
default_system: str
|
default_system: str
|
||||||
stop_words: List[str]
|
stop_words: List[str]
|
||||||
image_token: str
|
image_token: str
|
||||||
efficient_eos: bool
|
efficient_eos: bool
|
||||||
replace_eos: bool
|
replace_eos: bool
|
||||||
force_system: bool
|
|
||||||
|
|
||||||
def encode_oneturn(
|
def encode_oneturn(
|
||||||
self,
|
self,
|
||||||
@@ -65,6 +79,12 @@ class Template:
|
|||||||
"""
|
"""
|
||||||
return self._encode(tokenizer, messages, system, tools, cutoff_len, reserved_label_len)
|
return self._encode(tokenizer, messages, system, tools, cutoff_len, reserved_label_len)
|
||||||
|
|
||||||
|
def extract_tool(self, content: str) -> Union[str, List[Tuple[str, str]]]:
|
||||||
|
r"""
|
||||||
|
Extracts tool message.
|
||||||
|
"""
|
||||||
|
return self.format_tools.extract(content)
|
||||||
|
|
||||||
def _encode(
|
def _encode(
|
||||||
self,
|
self,
|
||||||
tokenizer: "PreTrainedTokenizer",
|
tokenizer: "PreTrainedTokenizer",
|
||||||
@@ -83,10 +103,15 @@ class Template:
|
|||||||
encoded_messages = []
|
encoded_messages = []
|
||||||
for i, message in enumerate(messages):
|
for i, message in enumerate(messages):
|
||||||
elements = []
|
elements = []
|
||||||
if i == 0 and (system or tools or self.force_system):
|
|
||||||
|
if i == 0:
|
||||||
|
elements += self.format_prefix.apply()
|
||||||
|
|
||||||
|
if i == 0 and (system or tools):
|
||||||
tool_text = self.format_tools.apply(content=tools)[0] if tools else ""
|
tool_text = self.format_tools.apply(content=tools)[0] if tools else ""
|
||||||
elements += self.format_system.apply(content=(system + tool_text))
|
elements += self.format_system.apply(content=(system + tool_text))
|
||||||
elif i > 0 and i % 2 == 0:
|
|
||||||
|
if i > 0 and i % 2 == 0:
|
||||||
elements += self.format_separator.apply()
|
elements += self.format_separator.apply()
|
||||||
|
|
||||||
if message["role"] == Role.USER.value:
|
if message["role"] == Role.USER.value:
|
||||||
@@ -173,11 +198,16 @@ class Llama2Template(Template):
|
|||||||
encoded_messages = []
|
encoded_messages = []
|
||||||
for i, message in enumerate(messages):
|
for i, message in enumerate(messages):
|
||||||
elements = []
|
elements = []
|
||||||
|
|
||||||
|
if i == 0:
|
||||||
|
elements += self.format_prefix.apply()
|
||||||
|
|
||||||
system_text = ""
|
system_text = ""
|
||||||
if i == 0 and (system or tools or self.force_system):
|
if i == 0 and (system or tools):
|
||||||
tool_text = self.format_tools.apply(content=tools)[0] if tools else ""
|
tool_text = self.format_tools.apply(content=tools)[0] if tools else ""
|
||||||
system_text = self.format_system.apply(content=(system + tool_text))[0]
|
system_text = self.format_system.apply(content=(system + tool_text))[0]
|
||||||
elif i > 0 and i % 2 == 0:
|
|
||||||
|
if i > 0 and i % 2 == 0:
|
||||||
elements += self.format_separator.apply()
|
elements += self.format_separator.apply()
|
||||||
|
|
||||||
if message["role"] == Role.USER.value:
|
if message["role"] == Role.USER.value:
|
||||||
@@ -208,12 +238,12 @@ def _register_template(
|
|||||||
format_observation: Optional["Formatter"] = None,
|
format_observation: Optional["Formatter"] = None,
|
||||||
format_tools: Optional["Formatter"] = None,
|
format_tools: Optional["Formatter"] = None,
|
||||||
format_separator: Optional["Formatter"] = None,
|
format_separator: Optional["Formatter"] = None,
|
||||||
|
format_prefix: Optional["Formatter"] = None,
|
||||||
default_system: str = "",
|
default_system: str = "",
|
||||||
stop_words: List[str] = [],
|
stop_words: List[str] = [],
|
||||||
image_token: str = "<image>",
|
image_token: str = "<image>",
|
||||||
efficient_eos: bool = False,
|
efficient_eos: bool = False,
|
||||||
replace_eos: bool = False,
|
replace_eos: bool = False,
|
||||||
force_system: bool = False,
|
|
||||||
) -> None:
|
) -> None:
|
||||||
r"""
|
r"""
|
||||||
Registers a chat template.
|
Registers a chat template.
|
||||||
@@ -245,9 +275,12 @@ def _register_template(
|
|||||||
template_class = Llama2Template if name.startswith("llama2") else Template
|
template_class = Llama2Template if name.startswith("llama2") else Template
|
||||||
default_user_formatter = StringFormatter(slots=["{{content}}"])
|
default_user_formatter = StringFormatter(slots=["{{content}}"])
|
||||||
default_assistant_formatter = StringFormatter(slots=["{{content}}"] + eos_slots)
|
default_assistant_formatter = StringFormatter(slots=["{{content}}"] + eos_slots)
|
||||||
default_function_formatter = FunctionFormatter(slots=["Action: {{name}}\nAction Input: {{arguments}}"] + eos_slots)
|
default_function_formatter = FunctionFormatter(
|
||||||
|
slots=["Action: {{name}}\nAction Input: {{arguments}}\n"] + eos_slots
|
||||||
|
)
|
||||||
default_tool_formatter = ToolFormatter(tool_format="default")
|
default_tool_formatter = ToolFormatter(tool_format="default")
|
||||||
default_separator_formatter = EmptyFormatter()
|
default_separator_formatter = EmptyFormatter()
|
||||||
|
default_prefix_formatter = EmptyFormatter()
|
||||||
TEMPLATES[name] = template_class(
|
TEMPLATES[name] = template_class(
|
||||||
format_user=format_user or default_user_formatter,
|
format_user=format_user or default_user_formatter,
|
||||||
format_assistant=format_assistant or default_assistant_formatter,
|
format_assistant=format_assistant or default_assistant_formatter,
|
||||||
@@ -256,12 +289,12 @@ def _register_template(
|
|||||||
format_observation=format_observation or format_user or default_user_formatter,
|
format_observation=format_observation or format_user or default_user_formatter,
|
||||||
format_tools=format_tools or default_tool_formatter,
|
format_tools=format_tools or default_tool_formatter,
|
||||||
format_separator=format_separator or default_separator_formatter,
|
format_separator=format_separator or default_separator_formatter,
|
||||||
|
format_prefix=format_prefix or default_prefix_formatter,
|
||||||
default_system=default_system,
|
default_system=default_system,
|
||||||
stop_words=stop_words,
|
stop_words=stop_words,
|
||||||
image_token=image_token,
|
image_token=image_token,
|
||||||
efficient_eos=efficient_eos,
|
efficient_eos=efficient_eos,
|
||||||
replace_eos=replace_eos,
|
replace_eos=replace_eos,
|
||||||
force_system=force_system,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -307,6 +340,10 @@ def _convert_slots_to_jinja(slots: "SLOTS", tokenizer: "PreTrainedTokenizer", pl
|
|||||||
def _get_jinja_template(template: "Template", tokenizer: "PreTrainedTokenizer") -> str:
|
def _get_jinja_template(template: "Template", tokenizer: "PreTrainedTokenizer") -> str:
|
||||||
jinja_template = ""
|
jinja_template = ""
|
||||||
|
|
||||||
|
prefix = _convert_slots_to_jinja(template.format_prefix.apply(), tokenizer)
|
||||||
|
if prefix:
|
||||||
|
jinja_template += "{{ " + prefix + " }}"
|
||||||
|
|
||||||
if template.default_system:
|
if template.default_system:
|
||||||
jinja_template += "{% set system_message = '" + _jinja_escape(template.default_system) + "' %}"
|
jinja_template += "{% set system_message = '" + _jinja_escape(template.default_system) + "' %}"
|
||||||
|
|
||||||
@@ -315,11 +352,7 @@ def _get_jinja_template(template: "Template", tokenizer: "PreTrainedTokenizer")
|
|||||||
)
|
)
|
||||||
|
|
||||||
system_message = _convert_slots_to_jinja(template.format_system.apply(), tokenizer, placeholder="system_message")
|
system_message = _convert_slots_to_jinja(template.format_system.apply(), tokenizer, placeholder="system_message")
|
||||||
if isinstance(template, Llama2Template):
|
if not isinstance(template, Llama2Template):
|
||||||
pass
|
|
||||||
elif template.force_system:
|
|
||||||
jinja_template += "{{ " + system_message + " }}"
|
|
||||||
else:
|
|
||||||
jinja_template += "{% if system_message is defined %}{{ " + system_message + " }}{% endif %}"
|
jinja_template += "{% if system_message is defined %}{{ " + system_message + " }}{% endif %}"
|
||||||
|
|
||||||
jinja_template += "{% for message in messages %}"
|
jinja_template += "{% for message in messages %}"
|
||||||
@@ -435,9 +468,8 @@ _register_template(
|
|||||||
_register_template(
|
_register_template(
|
||||||
name="belle",
|
name="belle",
|
||||||
format_user=StringFormatter(slots=["Human: {{content}}\n\nBelle: "]),
|
format_user=StringFormatter(slots=["Human: {{content}}\n\nBelle: "]),
|
||||||
format_system=StringFormatter(slots=[{"bos_token"}, "{{content}}"]),
|
|
||||||
format_separator=EmptyFormatter(slots=["\n\n"]),
|
format_separator=EmptyFormatter(slots=["\n\n"]),
|
||||||
force_system=True,
|
format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -450,11 +482,7 @@ _register_template(
|
|||||||
_register_template(
|
_register_template(
|
||||||
name="breeze",
|
name="breeze",
|
||||||
format_user=StringFormatter(slots=["[INST] {{content}} [/INST] "]),
|
format_user=StringFormatter(slots=["[INST] {{content}} [/INST] "]),
|
||||||
format_system=StringFormatter(slots=[{"bos_token"}, "{{content}}"]),
|
format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
|
||||||
default_system=(
|
|
||||||
"You are a helpful AI assistant built by MediaTek Research. "
|
|
||||||
"The user you are helping speaks Traditional Chinese and comes from Taiwan."
|
|
||||||
),
|
|
||||||
efficient_eos=True,
|
efficient_eos=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -462,10 +490,9 @@ _register_template(
|
|||||||
_register_template(
|
_register_template(
|
||||||
name="chatglm2",
|
name="chatglm2",
|
||||||
format_user=StringFormatter(slots=["[Round {{idx}}]\n\n问:{{content}}\n\n答:"]),
|
format_user=StringFormatter(slots=["[Round {{idx}}]\n\n问:{{content}}\n\n答:"]),
|
||||||
format_system=StringFormatter(slots=[{"token": "[gMASK]"}, {"token": "sop"}, "{{content}}"]),
|
|
||||||
format_separator=EmptyFormatter(slots=["\n\n"]),
|
format_separator=EmptyFormatter(slots=["\n\n"]),
|
||||||
|
format_prefix=EmptyFormatter(slots=[{"token": "[gMASK]"}, {"token": "sop"}]),
|
||||||
efficient_eos=True,
|
efficient_eos=True,
|
||||||
force_system=True,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -473,14 +500,14 @@ _register_template(
|
|||||||
name="chatglm3",
|
name="chatglm3",
|
||||||
format_user=StringFormatter(slots=[{"token": "<|user|>"}, "\n", "{{content}}", {"token": "<|assistant|>"}]),
|
format_user=StringFormatter(slots=[{"token": "<|user|>"}, "\n", "{{content}}", {"token": "<|assistant|>"}]),
|
||||||
format_assistant=StringFormatter(slots=["\n", "{{content}}"]),
|
format_assistant=StringFormatter(slots=["\n", "{{content}}"]),
|
||||||
format_system=StringFormatter(slots=[{"token": "[gMASK]"}, {"token": "sop"}, "{{content}}"]),
|
format_system=StringFormatter(slots=[{"token": "<|system|>"}, "\n", "{{content}}"]),
|
||||||
format_function=FunctionFormatter(slots=["{{name}}\n{{arguments}}"]),
|
format_function=FunctionFormatter(slots=["{{name}}\n{{arguments}}"]),
|
||||||
format_observation=StringFormatter(
|
format_observation=StringFormatter(
|
||||||
slots=[{"token": "<|observation|>"}, "\n", "{{content}}", {"token": "<|assistant|>"}]
|
slots=[{"token": "<|observation|>"}, "\n", "{{content}}", {"token": "<|assistant|>"}]
|
||||||
),
|
),
|
||||||
|
format_prefix=EmptyFormatter(slots=[{"token": "[gMASK]"}, {"token": "sop"}]),
|
||||||
stop_words=["<|user|>", "<|observation|>"],
|
stop_words=["<|user|>", "<|observation|>"],
|
||||||
efficient_eos=True,
|
efficient_eos=True,
|
||||||
force_system=True,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -488,13 +515,12 @@ _register_template(
|
|||||||
name="chatglm3_system",
|
name="chatglm3_system",
|
||||||
format_user=StringFormatter(slots=[{"token": "<|user|>"}, "\n", "{{content}}", {"token": "<|assistant|>"}]),
|
format_user=StringFormatter(slots=[{"token": "<|user|>"}, "\n", "{{content}}", {"token": "<|assistant|>"}]),
|
||||||
format_assistant=StringFormatter(slots=["\n", "{{content}}"]),
|
format_assistant=StringFormatter(slots=["\n", "{{content}}"]),
|
||||||
format_system=StringFormatter(
|
format_system=StringFormatter(slots=[{"token": "<|system|>"}, "\n", "{{content}}"]),
|
||||||
slots=[{"token": "[gMASK]"}, {"token": "sop"}, {"token": "<|system|>"}, "\n", "{{content}}"]
|
|
||||||
),
|
|
||||||
format_function=FunctionFormatter(slots=["{{name}}\n{{arguments}}"]),
|
format_function=FunctionFormatter(slots=["{{name}}\n{{arguments}}"]),
|
||||||
format_observation=StringFormatter(
|
format_observation=StringFormatter(
|
||||||
slots=[{"token": "<|observation|>"}, "\n", "{{content}}", {"token": "<|assistant|>"}]
|
slots=[{"token": "<|observation|>"}, "\n", "{{content}}", {"token": "<|assistant|>"}]
|
||||||
),
|
),
|
||||||
|
format_prefix=EmptyFormatter(slots=[{"token": "[gMASK]"}, {"token": "sop"}]),
|
||||||
default_system=(
|
default_system=(
|
||||||
"You are ChatGLM3, a large language model trained by Zhipu.AI. "
|
"You are ChatGLM3, a large language model trained by Zhipu.AI. "
|
||||||
"Follow the user's instructions carefully. Respond using markdown."
|
"Follow the user's instructions carefully. Respond using markdown."
|
||||||
@@ -529,8 +555,7 @@ _register_template(
|
|||||||
|
|
||||||
_register_template(
|
_register_template(
|
||||||
name="codegeex2",
|
name="codegeex2",
|
||||||
format_system=StringFormatter(slots=[{"token": "[gMASK]"}, {"token": "sop"}, "{{content}}"]),
|
format_prefix=EmptyFormatter(slots=[{"token": "[gMASK]"}, {"token": "sop"}]),
|
||||||
force_system=True,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -544,21 +569,15 @@ _register_template(
|
|||||||
)
|
)
|
||||||
]
|
]
|
||||||
),
|
),
|
||||||
format_system=StringFormatter(
|
format_system=StringFormatter(slots=["<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>{{content}}<|END_OF_TURN_TOKEN|>"]),
|
||||||
slots=[{"bos_token"}, "<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>{{content}}<|END_OF_TURN_TOKEN|>"]
|
format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
|
||||||
),
|
|
||||||
default_system=(
|
|
||||||
"You are Command-R, a brilliant, sophisticated, AI-assistant trained to assist human users "
|
|
||||||
"by providing thorough responses. You are trained by Cohere."
|
|
||||||
),
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
_register_template(
|
_register_template(
|
||||||
name="cpm",
|
name="cpm",
|
||||||
format_user=StringFormatter(slots=["<用户>{{content}}<AI>"]),
|
format_user=StringFormatter(slots=["<用户>{{content}}<AI>"]),
|
||||||
format_system=StringFormatter(slots=[{"bos_token"}, "{{content}}"]),
|
format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
|
||||||
force_system=True,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -591,8 +610,7 @@ _register_template(
|
|||||||
_register_template(
|
_register_template(
|
||||||
name="deepseek",
|
name="deepseek",
|
||||||
format_user=StringFormatter(slots=["User: {{content}}\n\nAssistant:"]),
|
format_user=StringFormatter(slots=["User: {{content}}\n\nAssistant:"]),
|
||||||
format_system=StringFormatter(slots=[{"bos_token"}, "{{content}}"]),
|
format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
|
||||||
force_system=True,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -622,11 +640,8 @@ _register_template(
|
|||||||
|
|
||||||
_register_template(
|
_register_template(
|
||||||
name="empty",
|
name="empty",
|
||||||
format_user=StringFormatter(slots=["{{content}}"]),
|
format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
|
||||||
format_assistant=StringFormatter(slots=["{{content}}"]),
|
|
||||||
format_system=StringFormatter(slots=[{"bos_token"}, "{{content}}"]),
|
|
||||||
efficient_eos=True,
|
efficient_eos=True,
|
||||||
force_system=True,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -648,13 +663,12 @@ _register_template(
|
|||||||
_register_template(
|
_register_template(
|
||||||
name="gemma",
|
name="gemma",
|
||||||
format_user=StringFormatter(slots=["<start_of_turn>user\n{{content}}<end_of_turn>\n<start_of_turn>model\n"]),
|
format_user=StringFormatter(slots=["<start_of_turn>user\n{{content}}<end_of_turn>\n<start_of_turn>model\n"]),
|
||||||
format_system=StringFormatter(slots=[{"bos_token"}, "{{content}}"]),
|
|
||||||
format_observation=StringFormatter(
|
format_observation=StringFormatter(
|
||||||
slots=["<start_of_turn>tool\n{{content}}<end_of_turn>\n<start_of_turn>model\n"]
|
slots=["<start_of_turn>tool\n{{content}}<end_of_turn>\n<start_of_turn>model\n"]
|
||||||
),
|
),
|
||||||
format_separator=EmptyFormatter(slots=["<end_of_turn>\n"]),
|
format_separator=EmptyFormatter(slots=["<end_of_turn>\n"]),
|
||||||
|
format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
|
||||||
efficient_eos=True,
|
efficient_eos=True,
|
||||||
force_system=True,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -662,36 +676,33 @@ _register_template(
|
|||||||
name="glm4",
|
name="glm4",
|
||||||
format_user=StringFormatter(slots=["<|user|>\n{{content}}<|assistant|>"]),
|
format_user=StringFormatter(slots=["<|user|>\n{{content}}<|assistant|>"]),
|
||||||
format_assistant=StringFormatter(slots=["\n{{content}}"]),
|
format_assistant=StringFormatter(slots=["\n{{content}}"]),
|
||||||
format_system=StringFormatter(slots=["[gMASK]<sop>{{content}}"]),
|
format_system=StringFormatter(slots=["<|system|>\n{{content}}"]),
|
||||||
format_function=FunctionFormatter(slots=["{{name}}\n{{arguments}}"]),
|
format_function=FunctionFormatter(slots=["{{name}}\n{{arguments}}"]),
|
||||||
format_observation=StringFormatter(slots=["<|observation|>\n{{content}}<|assistant|>"]),
|
format_observation=StringFormatter(slots=["<|observation|>\n{{content}}<|assistant|>"]),
|
||||||
|
format_tools=ToolFormatter(tool_format="glm4"),
|
||||||
|
format_prefix=EmptyFormatter(slots=["[gMASK]<sop>"]),
|
||||||
stop_words=["<|user|>", "<|observation|>"],
|
stop_words=["<|user|>", "<|observation|>"],
|
||||||
efficient_eos=True,
|
efficient_eos=True,
|
||||||
force_system=True,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
_register_template(
|
_register_template(
|
||||||
name="intern",
|
name="intern",
|
||||||
format_user=StringFormatter(slots=["<|User|>:{{content}}", {"token": "<eoh>"}, "\n<|Bot|>:"]),
|
format_user=StringFormatter(slots=["<|User|>:{{content}}\n<|Bot|>:"]),
|
||||||
format_separator=EmptyFormatter(slots=[{"token": "<eoa>"}, "\n"]),
|
format_system=StringFormatter(slots=["<|System|>:{{content}}\n"]),
|
||||||
|
format_separator=EmptyFormatter(slots=["<eoa>\n"]),
|
||||||
|
format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
|
||||||
stop_words=["<eoa>"],
|
stop_words=["<eoa>"],
|
||||||
efficient_eos=True,
|
efficient_eos=True, # internlm tokenizer cannot set eos_token_id
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
_register_template(
|
_register_template(
|
||||||
name="intern2",
|
name="intern2",
|
||||||
format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
|
format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
|
||||||
format_system=StringFormatter(slots=[{"bos_token"}, "<|im_start|>system\n{{content}}<|im_end|>\n"]),
|
format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
|
||||||
format_separator=EmptyFormatter(slots=["\n"]),
|
format_separator=EmptyFormatter(slots=["<|im_end|>\n"]),
|
||||||
default_system=(
|
format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
|
||||||
"You are an AI assistant whose name is InternLM (书生·浦语).\n"
|
|
||||||
"- InternLM (书生·浦语) is a conversational language model that is developed "
|
|
||||||
"by Shanghai AI Laboratory (上海人工智能实验室). It is designed to be helpful, honest, and harmless.\n"
|
|
||||||
"- InternLM (书生·浦语) can understand and communicate fluently in the language chosen "
|
|
||||||
"by the user such as English and 中文."
|
|
||||||
),
|
|
||||||
stop_words=["<|im_end|>"],
|
stop_words=["<|im_end|>"],
|
||||||
efficient_eos=True, # internlm2 tokenizer cannot set eos_token_id
|
efficient_eos=True, # internlm2 tokenizer cannot set eos_token_id
|
||||||
)
|
)
|
||||||
@@ -700,7 +711,6 @@ _register_template(
|
|||||||
_register_template(
|
_register_template(
|
||||||
name="llama2",
|
name="llama2",
|
||||||
format_user=StringFormatter(slots=[{"bos_token"}, "[INST] {{content}} [/INST]"]),
|
format_user=StringFormatter(slots=[{"bos_token"}, "[INST] {{content}} [/INST]"]),
|
||||||
format_assistant=StringFormatter(slots=[" {{content}} ", {"eos_token"}]),
|
|
||||||
format_system=StringFormatter(slots=["<<SYS>>\n{{content}}\n<</SYS>>\n\n"]),
|
format_system=StringFormatter(slots=["<<SYS>>\n{{content}}\n<</SYS>>\n\n"]),
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -723,9 +733,7 @@ _register_template(
|
|||||||
)
|
)
|
||||||
]
|
]
|
||||||
),
|
),
|
||||||
format_system=StringFormatter(
|
format_system=StringFormatter(slots=["<|start_header_id|>system<|end_header_id|>\n\n{{content}}<|eot_id|>"]),
|
||||||
slots=[{"bos_token"}, "<|start_header_id|>system<|end_header_id|>\n\n{{content}}<|eot_id|>"]
|
|
||||||
),
|
|
||||||
format_observation=StringFormatter(
|
format_observation=StringFormatter(
|
||||||
slots=[
|
slots=[
|
||||||
(
|
(
|
||||||
@@ -734,7 +742,7 @@ _register_template(
|
|||||||
)
|
)
|
||||||
]
|
]
|
||||||
),
|
),
|
||||||
default_system="You are a helpful assistant.",
|
format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
|
||||||
stop_words=["<|eot_id|>"],
|
stop_words=["<|eot_id|>"],
|
||||||
replace_eos=True,
|
replace_eos=True,
|
||||||
)
|
)
|
||||||
@@ -743,24 +751,21 @@ _register_template(
|
|||||||
_register_template(
|
_register_template(
|
||||||
name="mistral",
|
name="mistral",
|
||||||
format_user=StringFormatter(slots=["[INST] {{content}} [/INST]"]),
|
format_user=StringFormatter(slots=["[INST] {{content}} [/INST]"]),
|
||||||
format_system=StringFormatter(slots=[{"bos_token"}, "{{content}}"]),
|
format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
|
||||||
force_system=True,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
_register_template(
|
_register_template(
|
||||||
name="olmo",
|
name="olmo",
|
||||||
format_user=StringFormatter(slots=["<|user|>\n{{content}}<|assistant|>\n"]),
|
format_user=StringFormatter(slots=["<|user|>\n{{content}}<|assistant|>\n"]),
|
||||||
format_system=StringFormatter(slots=[{"eos_token"}, "{{content}}"]),
|
format_prefix=EmptyFormatter(slots=[{"eos_token"}]),
|
||||||
force_system=True,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
_register_template(
|
_register_template(
|
||||||
name="openchat",
|
name="openchat",
|
||||||
format_user=StringFormatter(slots=["GPT4 Correct User: {{content}}", {"eos_token"}, "GPT4 Correct Assistant:"]),
|
format_user=StringFormatter(slots=["GPT4 Correct User: {{content}}", {"eos_token"}, "GPT4 Correct Assistant:"]),
|
||||||
format_system=StringFormatter(slots=[{"bos_token"}, "{{content}}"]),
|
format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
|
||||||
force_system=True,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -774,27 +779,25 @@ _register_template(
|
|||||||
)
|
)
|
||||||
]
|
]
|
||||||
),
|
),
|
||||||
format_system=StringFormatter(slots=[{"bos_token"}, "{{content}}"]),
|
format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
|
||||||
stop_words=["<|eot_id|>"],
|
stop_words=["<|eot_id|>"],
|
||||||
replace_eos=True,
|
replace_eos=True,
|
||||||
force_system=True,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
_register_template(
|
_register_template(
|
||||||
name="orion",
|
name="orion",
|
||||||
format_user=StringFormatter(slots=["Human: {{content}}\n\nAssistant: ", {"eos_token"}]),
|
format_user=StringFormatter(slots=["Human: {{content}}\n\nAssistant: ", {"eos_token"}]),
|
||||||
format_system=StringFormatter(slots=[{"bos_token"}, "{{content}}"]),
|
format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
|
||||||
force_system=True,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
_register_template(
|
_register_template(
|
||||||
name="phi",
|
name="phi",
|
||||||
format_user=StringFormatter(slots=["<|user|>\n{{content}}<|end|>\n<|assistant|>\n"]),
|
format_user=StringFormatter(slots=["<|user|>\n{{content}}<|end|>\n<|assistant|>\n"]),
|
||||||
format_system=StringFormatter(slots=[{"bos_token"}, "<|system|>\n{{content}}<|end|>\n"]),
|
format_system=StringFormatter(slots=["<|system|>\n{{content}}<|end|>\n"]),
|
||||||
format_separator=EmptyFormatter(slots=["\n"]),
|
format_separator=EmptyFormatter(slots=["\n"]),
|
||||||
default_system="You are a helpful AI assistant.",
|
format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
|
||||||
stop_words=["<|end|>"],
|
stop_words=["<|end|>"],
|
||||||
replace_eos=True,
|
replace_eos=True,
|
||||||
)
|
)
|
||||||
@@ -827,7 +830,6 @@ _register_template(
|
|||||||
format_separator=EmptyFormatter(slots=["\n"]),
|
format_separator=EmptyFormatter(slots=["\n"]),
|
||||||
stop_words=["<|end|>"],
|
stop_words=["<|end|>"],
|
||||||
replace_eos=True,
|
replace_eos=True,
|
||||||
force_system=True,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -1,4 +1,41 @@
|
|||||||
# Inspired by: https://github.com/hendrycks/test/blob/master/evaluate_flan.py
|
# Copyright 2024 the LlamaFactory team.
|
||||||
|
#
|
||||||
|
# This code is inspired by the Dan's test library.
|
||||||
|
# https://github.com/hendrycks/test/blob/master/evaluate_flan.py
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
#
|
||||||
|
# MIT License
|
||||||
|
#
|
||||||
|
# Copyright (c) 2020 Dan Hendrycks
|
||||||
|
#
|
||||||
|
# Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||||
|
# of this software and associated documentation files (the "Software"), to deal
|
||||||
|
# in the Software without restriction, including without limitation the rights
|
||||||
|
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||||
|
# copies of the Software, and to permit persons to whom the Software is
|
||||||
|
# furnished to do so, subject to the following conditions:
|
||||||
|
#
|
||||||
|
# The above copyright notice and this permission notice shall be included in all
|
||||||
|
# copies or substantial portions of the Software.
|
||||||
|
#
|
||||||
|
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||||
|
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||||
|
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||||
|
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||||
|
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||||
|
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||||
|
# SOFTWARE.
|
||||||
|
|
||||||
import inspect
|
import inspect
|
||||||
import json
|
import json
|
||||||
|
|||||||
@@ -1,3 +1,17 @@
|
|||||||
|
# Copyright 2024 the LlamaFactory team.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Dict, List, Sequence, Tuple
|
from typing import Dict, List, Sequence, Tuple
|
||||||
|
|
||||||
|
|||||||
@@ -1,3 +1,17 @@
|
|||||||
|
# Copyright 2024 the LlamaFactory team.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
|
|||||||
@@ -1,3 +1,17 @@
|
|||||||
|
# Copyright 2024 the LlamaFactory team.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
from collections import OrderedDict, defaultdict
|
from collections import OrderedDict, defaultdict
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import Dict, Optional
|
from typing import Dict, Optional
|
||||||
@@ -389,6 +403,18 @@ register_model_group(
|
|||||||
DownloadSource.DEFAULT: "deepseek-ai/DeepSeek-V2-Chat",
|
DownloadSource.DEFAULT: "deepseek-ai/DeepSeek-V2-Chat",
|
||||||
DownloadSource.MODELSCOPE: "deepseek-ai/DeepSeek-V2-Chat",
|
DownloadSource.MODELSCOPE: "deepseek-ai/DeepSeek-V2-Chat",
|
||||||
},
|
},
|
||||||
|
"DeepSeek-MoE-Coder-16B-Base": {
|
||||||
|
DownloadSource.DEFAULT: "deepseek-ai/DeepSeek-Coder-V2-Lite-Base",
|
||||||
|
},
|
||||||
|
"DeepSeek-MoE-Coder-236B-Base": {
|
||||||
|
DownloadSource.DEFAULT: "deepseek-ai/DeepSeek-Coder-V2-Base",
|
||||||
|
},
|
||||||
|
"DeepSeek-MoE-Coder-16B-Chat": {
|
||||||
|
DownloadSource.DEFAULT: "deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct",
|
||||||
|
},
|
||||||
|
"DeepSeek-MoE-Coder-236B-Chat": {
|
||||||
|
DownloadSource.DEFAULT: "deepseek-ai/DeepSeek-Coder-V2-Instruct",
|
||||||
|
},
|
||||||
},
|
},
|
||||||
template="deepseek",
|
template="deepseek",
|
||||||
)
|
)
|
||||||
@@ -668,6 +694,21 @@ register_model_group(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
register_model_group(
|
||||||
|
models={
|
||||||
|
"MiniCPM-2B-SFT-Chat": {
|
||||||
|
DownloadSource.DEFAULT: "openbmb/MiniCPM-2B-sft-bf16",
|
||||||
|
DownloadSource.MODELSCOPE: "OpenBMB/miniCPM-bf16",
|
||||||
|
},
|
||||||
|
"MiniCPM-2B-DPO-Chat": {
|
||||||
|
DownloadSource.DEFAULT: "openbmb/MiniCPM-2B-dpo-bf16",
|
||||||
|
DownloadSource.MODELSCOPE: "OpenBMB/MiniCPM-2B-dpo-bf16",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
template="cpm",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
register_model_group(
|
register_model_group(
|
||||||
models={
|
models={
|
||||||
"Mistral-7B-v0.1": {
|
"Mistral-7B-v0.1": {
|
||||||
|
|||||||
@@ -1,3 +1,17 @@
|
|||||||
|
# Copyright 2024 the LlamaFactory team.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
import platform
|
import platform
|
||||||
|
|
||||||
import accelerate
|
import accelerate
|
||||||
@@ -9,7 +23,7 @@ import trl
|
|||||||
from transformers.utils import is_torch_cuda_available, is_torch_npu_available
|
from transformers.utils import is_torch_cuda_available, is_torch_npu_available
|
||||||
|
|
||||||
|
|
||||||
VERSION = "0.8.1"
|
VERSION = "0.8.2"
|
||||||
|
|
||||||
|
|
||||||
def print_env() -> None:
|
def print_env() -> None:
|
||||||
|
|||||||
@@ -1,3 +1,17 @@
|
|||||||
|
# Copyright 2024 the LlamaFactory team.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
|
|||||||
@@ -1,3 +1,17 @@
|
|||||||
|
# Copyright 2024 the LlamaFactory team.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
import gc
|
import gc
|
||||||
import os
|
import os
|
||||||
from typing import TYPE_CHECKING, Dict, Tuple
|
from typing import TYPE_CHECKING, Dict, Tuple
|
||||||
@@ -8,6 +22,7 @@ from transformers import InfNanRemoveLogitsProcessor, LogitsProcessorList, PreTr
|
|||||||
from transformers.utils import (
|
from transformers.utils import (
|
||||||
SAFE_WEIGHTS_NAME,
|
SAFE_WEIGHTS_NAME,
|
||||||
WEIGHTS_NAME,
|
WEIGHTS_NAME,
|
||||||
|
is_safetensors_available,
|
||||||
is_torch_bf16_gpu_available,
|
is_torch_bf16_gpu_available,
|
||||||
is_torch_cuda_available,
|
is_torch_cuda_available,
|
||||||
is_torch_mps_available,
|
is_torch_mps_available,
|
||||||
@@ -20,6 +35,11 @@ from .constants import V_HEAD_SAFE_WEIGHTS_NAME, V_HEAD_WEIGHTS_NAME
|
|||||||
from .logging import get_logger
|
from .logging import get_logger
|
||||||
|
|
||||||
|
|
||||||
|
if is_safetensors_available():
|
||||||
|
from safetensors import safe_open
|
||||||
|
from safetensors.torch import save_file
|
||||||
|
|
||||||
|
|
||||||
_is_fp16_available = is_torch_npu_available() or is_torch_cuda_available()
|
_is_fp16_available = is_torch_npu_available() or is_torch_cuda_available()
|
||||||
try:
|
try:
|
||||||
_is_bf16_available = is_torch_bf16_gpu_available()
|
_is_bf16_available = is_torch_bf16_gpu_available()
|
||||||
@@ -114,9 +134,6 @@ def fix_valuehead_checkpoint(
|
|||||||
return
|
return
|
||||||
|
|
||||||
if safe_serialization:
|
if safe_serialization:
|
||||||
from safetensors import safe_open
|
|
||||||
from safetensors.torch import save_file
|
|
||||||
|
|
||||||
path_to_checkpoint = os.path.join(output_dir, SAFE_WEIGHTS_NAME)
|
path_to_checkpoint = os.path.join(output_dir, SAFE_WEIGHTS_NAME)
|
||||||
with safe_open(path_to_checkpoint, framework="pt", device="cpu") as f:
|
with safe_open(path_to_checkpoint, framework="pt", device="cpu") as f:
|
||||||
state_dict: Dict[str, torch.Tensor] = {key: f.get_tensor(key) for key in f.keys()}
|
state_dict: Dict[str, torch.Tensor] = {key: f.get_tensor(key) for key in f.keys()}
|
||||||
|
|||||||
@@ -1,5 +1,23 @@
|
|||||||
|
# Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
|
||||||
|
#
|
||||||
|
# This code is inspired by the HuggingFace's transformers library.
|
||||||
|
# https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/utils/import_utils.py
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
import importlib.metadata
|
import importlib.metadata
|
||||||
import importlib.util
|
import importlib.util
|
||||||
|
from functools import lru_cache
|
||||||
from typing import TYPE_CHECKING
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
from packaging import version
|
from packaging import version
|
||||||
@@ -24,10 +42,6 @@ def is_fastapi_available():
|
|||||||
return _is_package_available("fastapi")
|
return _is_package_available("fastapi")
|
||||||
|
|
||||||
|
|
||||||
def is_flash_attn2_available():
|
|
||||||
return _is_package_available("flash_attn") and _get_package_version("flash_attn") > version.parse("2.0.0")
|
|
||||||
|
|
||||||
|
|
||||||
def is_galore_available():
|
def is_galore_available():
|
||||||
return _is_package_available("galore_torch")
|
return _is_package_available("galore_torch")
|
||||||
|
|
||||||
@@ -36,18 +50,10 @@ def is_gradio_available():
|
|||||||
return _is_package_available("gradio")
|
return _is_package_available("gradio")
|
||||||
|
|
||||||
|
|
||||||
def is_jieba_available():
|
|
||||||
return _is_package_available("jieba")
|
|
||||||
|
|
||||||
|
|
||||||
def is_matplotlib_available():
|
def is_matplotlib_available():
|
||||||
return _is_package_available("matplotlib")
|
return _is_package_available("matplotlib")
|
||||||
|
|
||||||
|
|
||||||
def is_nltk_available():
|
|
||||||
return _is_package_available("nltk")
|
|
||||||
|
|
||||||
|
|
||||||
def is_pillow_available():
|
def is_pillow_available():
|
||||||
return _is_package_available("PIL")
|
return _is_package_available("PIL")
|
||||||
|
|
||||||
@@ -60,10 +66,6 @@ def is_rouge_available():
|
|||||||
return _is_package_available("rouge_chinese")
|
return _is_package_available("rouge_chinese")
|
||||||
|
|
||||||
|
|
||||||
def is_sdpa_available():
|
|
||||||
return _get_package_version("torch") > version.parse("2.1.1")
|
|
||||||
|
|
||||||
|
|
||||||
def is_starlette_available():
|
def is_starlette_available():
|
||||||
return _is_package_available("sse_starlette")
|
return _is_package_available("sse_starlette")
|
||||||
|
|
||||||
@@ -74,3 +76,8 @@ def is_uvicorn_available():
|
|||||||
|
|
||||||
def is_vllm_available():
|
def is_vllm_available():
|
||||||
return _is_package_available("vllm")
|
return _is_package_available("vllm")
|
||||||
|
|
||||||
|
|
||||||
|
@lru_cache
|
||||||
|
def is_vllm_version_greater_than_0_5():
|
||||||
|
return _get_package_version("vllm") >= version.parse("0.5.0")
|
||||||
|
|||||||
@@ -1,3 +1,17 @@
|
|||||||
|
# Copyright 2024 the LlamaFactory team.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
import json
|
import json
|
||||||
import math
|
import math
|
||||||
import os
|
import os
|
||||||
|
|||||||
@@ -1,3 +1,17 @@
|
|||||||
|
# Copyright 2024 the LlamaFactory team.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
from .data_args import DataArguments
|
from .data_args import DataArguments
|
||||||
from .evaluation_args import EvaluationArguments
|
from .evaluation_args import EvaluationArguments
|
||||||
from .finetuning_args import FinetuningArguments
|
from .finetuning_args import FinetuningArguments
|
||||||
|
|||||||
@@ -1,3 +1,20 @@
|
|||||||
|
# Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
|
||||||
|
#
|
||||||
|
# This code is inspired by the HuggingFace's transformers library.
|
||||||
|
# https://github.com/huggingface/transformers/blob/v4.40.0/examples/pytorch/language-modeling/run_clm.py
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from typing import Literal, Optional
|
from typing import Literal, Optional
|
||||||
|
|
||||||
|
|||||||
@@ -1,3 +1,17 @@
|
|||||||
|
# Copyright 2024 the LlamaFactory team.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
import os
|
import os
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from typing import Literal, Optional
|
from typing import Literal, Optional
|
||||||
|
|||||||
@@ -1,5 +1,19 @@
|
|||||||
|
# Copyright 2024 the LlamaFactory team.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from typing import Literal, Optional
|
from typing import List, Literal, Optional
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@@ -94,6 +108,18 @@ class LoraArguments:
|
|||||||
default=False,
|
default=False,
|
||||||
metadata={"help": "Whether or not to use the weight-decomposed lora method (DoRA)."},
|
metadata={"help": "Whether or not to use the weight-decomposed lora method (DoRA)."},
|
||||||
)
|
)
|
||||||
|
pissa_init: bool = field(
|
||||||
|
default=False,
|
||||||
|
metadata={"help": "Whether or not to initialize a PiSSA adapter."},
|
||||||
|
)
|
||||||
|
pissa_iter: int = field(
|
||||||
|
default=4,
|
||||||
|
metadata={"help": "The number of iteration steps performed by FSVD in PiSSA. Use -1 to disable it."},
|
||||||
|
)
|
||||||
|
pissa_convert: bool = field(
|
||||||
|
default=False,
|
||||||
|
metadata={"help": "Whether or not to convert the PiSSA adapter to a normal LoRA adapter."},
|
||||||
|
)
|
||||||
create_new_adapter: bool = field(
|
create_new_adapter: bool = field(
|
||||||
default=False,
|
default=False,
|
||||||
metadata={"help": "Whether or not to create a new adapter with randomly initialized weight."},
|
metadata={"help": "Whether or not to create a new adapter with randomly initialized weight."},
|
||||||
@@ -319,20 +345,19 @@ class FinetuningArguments(FreezeArguments, LoraArguments, RLHFArguments, GaloreA
|
|||||||
return [item.strip() for item in arg.split(",")]
|
return [item.strip() for item in arg.split(",")]
|
||||||
return arg
|
return arg
|
||||||
|
|
||||||
self.freeze_trainable_modules = split_arg(self.freeze_trainable_modules)
|
self.freeze_trainable_modules: List[str] = split_arg(self.freeze_trainable_modules)
|
||||||
self.freeze_extra_modules = split_arg(self.freeze_extra_modules)
|
self.freeze_extra_modules: Optional[List[str]] = split_arg(self.freeze_extra_modules)
|
||||||
self.lora_alpha = self.lora_alpha or self.lora_rank * 2
|
self.lora_alpha: int = self.lora_alpha or self.lora_rank * 2
|
||||||
self.lora_target = split_arg(self.lora_target)
|
self.lora_target: List[str] = split_arg(self.lora_target)
|
||||||
self.additional_target = split_arg(self.additional_target)
|
self.additional_target: Optional[List[str]] = split_arg(self.additional_target)
|
||||||
self.galore_target = split_arg(self.galore_target)
|
self.galore_target: List[str] = split_arg(self.galore_target)
|
||||||
self.freeze_vision_tower = self.freeze_vision_tower or self.train_mm_proj_only
|
self.freeze_vision_tower = self.freeze_vision_tower or self.train_mm_proj_only
|
||||||
|
self.use_ref_model = self.stage == "dpo" and self.pref_loss not in ["orpo", "simpo"]
|
||||||
|
|
||||||
assert self.finetuning_type in ["lora", "freeze", "full"], "Invalid fine-tuning method."
|
assert self.finetuning_type in ["lora", "freeze", "full"], "Invalid fine-tuning method."
|
||||||
assert self.ref_model_quantization_bit in [None, 8, 4], "We only accept 4-bit or 8-bit quantization."
|
assert self.ref_model_quantization_bit in [None, 8, 4], "We only accept 4-bit or 8-bit quantization."
|
||||||
assert self.reward_model_quantization_bit in [None, 8, 4], "We only accept 4-bit or 8-bit quantization."
|
assert self.reward_model_quantization_bit in [None, 8, 4], "We only accept 4-bit or 8-bit quantization."
|
||||||
|
|
||||||
self.use_ref_model = self.pref_loss not in ["orpo", "simpo"]
|
|
||||||
|
|
||||||
if self.stage == "ppo" and self.reward_model is None:
|
if self.stage == "ppo" and self.reward_model is None:
|
||||||
raise ValueError("`reward_model` is necessary for PPO training.")
|
raise ValueError("`reward_model` is necessary for PPO training.")
|
||||||
|
|
||||||
@@ -354,5 +379,11 @@ class FinetuningArguments(FreezeArguments, LoraArguments, RLHFArguments, GaloreA
|
|||||||
if self.loraplus_lr_ratio is not None and self.finetuning_type != "lora":
|
if self.loraplus_lr_ratio is not None and self.finetuning_type != "lora":
|
||||||
raise ValueError("`loraplus_lr_ratio` is only valid for LoRA training.")
|
raise ValueError("`loraplus_lr_ratio` is only valid for LoRA training.")
|
||||||
|
|
||||||
|
if self.pissa_convert and self.finetuning_type != "lora":
|
||||||
|
raise ValueError("`pissa_convert` is only valid for LoRA training.")
|
||||||
|
|
||||||
|
if self.pissa_convert and (self.stage in ["rm", "ppo", "kto"] or self.use_ref_model):
|
||||||
|
raise ValueError("Cannot use PiSSA for current training stage.")
|
||||||
|
|
||||||
if self.train_mm_proj_only and self.finetuning_type != "full":
|
if self.train_mm_proj_only and self.finetuning_type != "full":
|
||||||
raise ValueError("`train_mm_proj_only` is only valid for full training.")
|
raise ValueError("`train_mm_proj_only` is only valid for full training.")
|
||||||
|
|||||||
@@ -1,3 +1,17 @@
|
|||||||
|
# Copyright 2024 the LlamaFactory team.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
from dataclasses import asdict, dataclass, field
|
from dataclasses import asdict, dataclass, field
|
||||||
from typing import Any, Dict, Optional
|
from typing import Any, Dict, Optional
|
||||||
|
|
||||||
|
|||||||
@@ -1,5 +1,28 @@
|
|||||||
|
# Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
|
||||||
|
#
|
||||||
|
# This code is inspired by the HuggingFace's transformers library.
|
||||||
|
# https://github.com/huggingface/transformers/blob/v4.40.0/examples/pytorch/language-modeling/run_clm.py
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
from dataclasses import asdict, dataclass, field
|
from dataclasses import asdict, dataclass, field
|
||||||
from typing import Any, Dict, Literal, Optional
|
from typing import TYPE_CHECKING, Any, Dict, Literal, Optional, Union
|
||||||
|
|
||||||
|
from typing_extensions import Self
|
||||||
|
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
import torch
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@@ -22,6 +45,10 @@ class ModelArguments:
|
|||||||
)
|
)
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
adapter_folder: Optional[str] = field(
|
||||||
|
default=None,
|
||||||
|
metadata={"help": "The folder containing the adapter weights to load."},
|
||||||
|
)
|
||||||
cache_dir: Optional[str] = field(
|
cache_dir: Optional[str] = field(
|
||||||
default=None,
|
default=None,
|
||||||
metadata={"help": "Where to store the pre-trained models downloaded from huggingface.co or modelscope.cn."},
|
metadata={"help": "Where to store the pre-trained models downloaded from huggingface.co or modelscope.cn."},
|
||||||
@@ -127,13 +154,9 @@ class ModelArguments:
|
|||||||
metadata={"help": "Whether or not to disable CUDA graph in the vLLM engine."},
|
metadata={"help": "Whether or not to disable CUDA graph in the vLLM engine."},
|
||||||
)
|
)
|
||||||
vllm_max_lora_rank: int = field(
|
vllm_max_lora_rank: int = field(
|
||||||
default=8,
|
default=32,
|
||||||
metadata={"help": "Maximum rank of all LoRAs in the vLLM engine."},
|
metadata={"help": "Maximum rank of all LoRAs in the vLLM engine."},
|
||||||
)
|
)
|
||||||
vllm_dtype: Literal["auto", "float16", "bfloat16", "float32"] = field(
|
|
||||||
default="auto",
|
|
||||||
metadata={"help": "Data type for model weights and activations in the vLLM engine."},
|
|
||||||
)
|
|
||||||
offload_folder: str = field(
|
offload_folder: str = field(
|
||||||
default="offload",
|
default="offload",
|
||||||
metadata={"help": "Path to offload model weights."},
|
metadata={"help": "Path to offload model weights."},
|
||||||
@@ -142,6 +165,10 @@ class ModelArguments:
|
|||||||
default=True,
|
default=True,
|
||||||
metadata={"help": "Whether or not to use KV cache in generation."},
|
metadata={"help": "Whether or not to use KV cache in generation."},
|
||||||
)
|
)
|
||||||
|
infer_dtype: Literal["auto", "float16", "bfloat16", "float32"] = field(
|
||||||
|
default="auto",
|
||||||
|
metadata={"help": "Data type for model weights and activations at inference."},
|
||||||
|
)
|
||||||
hf_hub_token: Optional[str] = field(
|
hf_hub_token: Optional[str] = field(
|
||||||
default=None,
|
default=None,
|
||||||
metadata={"help": "Auth token to log in with Hugging Face Hub."},
|
metadata={"help": "Auth token to log in with Hugging Face Hub."},
|
||||||
@@ -192,9 +219,9 @@ class ModelArguments:
|
|||||||
)
|
)
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
self.compute_dtype = None
|
self.compute_dtype: Optional["torch.dtype"] = None
|
||||||
self.device_map = None
|
self.device_map: Optional[Union[str, Dict[str, Any]]] = None
|
||||||
self.model_max_length = None
|
self.model_max_length: Optional[int] = None
|
||||||
|
|
||||||
if self.split_special_tokens and self.use_fast_tokenizer:
|
if self.split_special_tokens and self.use_fast_tokenizer:
|
||||||
raise ValueError("`split_special_tokens` is only supported for slow tokenizers.")
|
raise ValueError("`split_special_tokens` is only supported for slow tokenizers.")
|
||||||
@@ -216,3 +243,13 @@ class ModelArguments:
|
|||||||
|
|
||||||
def to_dict(self) -> Dict[str, Any]:
|
def to_dict(self) -> Dict[str, Any]:
|
||||||
return asdict(self)
|
return asdict(self)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def copyfrom(cls, old_arg: Self, **kwargs) -> Self:
|
||||||
|
arg_dict = old_arg.to_dict()
|
||||||
|
arg_dict.update(**kwargs)
|
||||||
|
new_arg = cls(**arg_dict)
|
||||||
|
new_arg.compute_dtype = old_arg.compute_dtype
|
||||||
|
new_arg.device_map = old_arg.device_map
|
||||||
|
new_arg.model_max_length = old_arg.model_max_length
|
||||||
|
return new_arg
|
||||||
|
|||||||
@@ -1,3 +1,20 @@
|
|||||||
|
# Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
|
||||||
|
#
|
||||||
|
# This code is inspired by the HuggingFace's transformers library.
|
||||||
|
# https://github.com/huggingface/transformers/blob/v4.40.0/examples/pytorch/language-modeling/run_clm.py
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
@@ -8,6 +25,7 @@ import transformers
|
|||||||
from transformers import HfArgumentParser, Seq2SeqTrainingArguments
|
from transformers import HfArgumentParser, Seq2SeqTrainingArguments
|
||||||
from transformers.integrations import is_deepspeed_zero3_enabled
|
from transformers.integrations import is_deepspeed_zero3_enabled
|
||||||
from transformers.trainer_utils import get_last_checkpoint
|
from transformers.trainer_utils import get_last_checkpoint
|
||||||
|
from transformers.training_args import ParallelMode
|
||||||
from transformers.utils import is_torch_bf16_gpu_available
|
from transformers.utils import is_torch_bf16_gpu_available
|
||||||
from transformers.utils.versions import require_version
|
from transformers.utils.versions import require_version
|
||||||
|
|
||||||
@@ -72,6 +90,9 @@ def _verify_model_args(model_args: "ModelArguments", finetuning_args: "Finetunin
|
|||||||
if finetuning_args.finetuning_type != "lora":
|
if finetuning_args.finetuning_type != "lora":
|
||||||
raise ValueError("Quantization is only compatible with the LoRA method.")
|
raise ValueError("Quantization is only compatible with the LoRA method.")
|
||||||
|
|
||||||
|
if finetuning_args.pissa_init:
|
||||||
|
raise ValueError("Please use scripts/pissa_init.py to initialize PiSSA for a quantized model.")
|
||||||
|
|
||||||
if model_args.resize_vocab:
|
if model_args.resize_vocab:
|
||||||
raise ValueError("Cannot resize embedding layers of a quantized model.")
|
raise ValueError("Cannot resize embedding layers of a quantized model.")
|
||||||
|
|
||||||
@@ -162,6 +183,9 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
|
|||||||
):
|
):
|
||||||
raise ValueError("PPO only accepts wandb or tensorboard logger.")
|
raise ValueError("PPO only accepts wandb or tensorboard logger.")
|
||||||
|
|
||||||
|
if training_args.parallel_mode == ParallelMode.NOT_DISTRIBUTED:
|
||||||
|
raise ValueError("Please launch distributed training with `llamafactory-cli` or `torchrun`.")
|
||||||
|
|
||||||
if training_args.max_steps == -1 and data_args.streaming:
|
if training_args.max_steps == -1 and data_args.streaming:
|
||||||
raise ValueError("Please specify `max_steps` in streaming mode.")
|
raise ValueError("Please specify `max_steps` in streaming mode.")
|
||||||
|
|
||||||
@@ -171,9 +195,6 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
|
|||||||
if training_args.do_train and model_args.quantization_device_map == "auto":
|
if training_args.do_train and model_args.quantization_device_map == "auto":
|
||||||
raise ValueError("Cannot use device map for quantized models in training.")
|
raise ValueError("Cannot use device map for quantized models in training.")
|
||||||
|
|
||||||
if finetuning_args.use_dora and model_args.use_unsloth:
|
|
||||||
raise ValueError("Unsloth does not support DoRA.")
|
|
||||||
|
|
||||||
if finetuning_args.pure_bf16:
|
if finetuning_args.pure_bf16:
|
||||||
if not is_torch_bf16_gpu_available():
|
if not is_torch_bf16_gpu_available():
|
||||||
raise ValueError("This device does not support `pure_bf16`.")
|
raise ValueError("This device does not support `pure_bf16`.")
|
||||||
@@ -184,14 +205,14 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
|
|||||||
if (
|
if (
|
||||||
finetuning_args.use_galore
|
finetuning_args.use_galore
|
||||||
and finetuning_args.galore_layerwise
|
and finetuning_args.galore_layerwise
|
||||||
and training_args.parallel_mode.value == "distributed"
|
and training_args.parallel_mode == ParallelMode.DISTRIBUTED
|
||||||
):
|
):
|
||||||
raise ValueError("Distributed training does not support layer-wise GaLore.")
|
raise ValueError("Distributed training does not support layer-wise GaLore.")
|
||||||
|
|
||||||
if (
|
if (
|
||||||
finetuning_args.use_badam
|
finetuning_args.use_badam
|
||||||
and finetuning_args.badam_mode == "layer"
|
and finetuning_args.badam_mode == "layer"
|
||||||
and training_args.parallel_mode.value == "distributed"
|
and training_args.parallel_mode == ParallelMode.DISTRIBUTED
|
||||||
):
|
):
|
||||||
raise ValueError("Layer-wise BAdam does not yet support distributed training, use ratio-wise BAdam.")
|
raise ValueError("Layer-wise BAdam does not yet support distributed training, use ratio-wise BAdam.")
|
||||||
|
|
||||||
@@ -233,7 +254,7 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
|
|||||||
|
|
||||||
# Post-process training arguments
|
# Post-process training arguments
|
||||||
if (
|
if (
|
||||||
training_args.parallel_mode.value == "distributed"
|
training_args.parallel_mode == ParallelMode.DISTRIBUTED
|
||||||
and training_args.ddp_find_unused_parameters is None
|
and training_args.ddp_find_unused_parameters is None
|
||||||
and finetuning_args.finetuning_type == "lora"
|
and finetuning_args.finetuning_type == "lora"
|
||||||
):
|
):
|
||||||
@@ -293,7 +314,7 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
|
|||||||
training_args.local_rank,
|
training_args.local_rank,
|
||||||
training_args.device,
|
training_args.device,
|
||||||
training_args.n_gpu,
|
training_args.n_gpu,
|
||||||
training_args.parallel_mode.value == "distributed",
|
training_args.parallel_mode == ParallelMode.DISTRIBUTED,
|
||||||
str(model_args.compute_dtype),
|
str(model_args.compute_dtype),
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
@@ -332,6 +353,7 @@ def get_infer_args(args: Optional[Dict[str, Any]] = None) -> _INFER_CLS:
|
|||||||
|
|
||||||
if model_args.export_dir is not None and model_args.export_device == "cpu":
|
if model_args.export_dir is not None and model_args.export_device == "cpu":
|
||||||
model_args.device_map = {"": torch.device("cpu")}
|
model_args.device_map = {"": torch.device("cpu")}
|
||||||
|
model_args.model_max_length = data_args.cutoff_len
|
||||||
else:
|
else:
|
||||||
model_args.device_map = "auto"
|
model_args.device_map = "auto"
|
||||||
|
|
||||||
|
|||||||
@@ -1,3 +1,17 @@
|
|||||||
|
# Copyright 2024 the LlamaFactory team.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
from llamafactory.train.tuner import run_exp
|
from llamafactory.train.tuner import run_exp
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -1,3 +1,17 @@
|
|||||||
|
# Copyright 2024 the LlamaFactory team.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
from .loader import load_config, load_model, load_tokenizer
|
from .loader import load_config, load_model, load_tokenizer
|
||||||
from .model_utils.misc import find_all_linear_modules
|
from .model_utils.misc import find_all_linear_modules
|
||||||
from .model_utils.valuehead import load_valuehead_params
|
from .model_utils.valuehead import load_valuehead_params
|
||||||
|
|||||||
@@ -1,3 +1,17 @@
|
|||||||
|
# Copyright 2024 the LlamaFactory team.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
import re
|
import re
|
||||||
from typing import TYPE_CHECKING
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
@@ -25,8 +39,12 @@ def _setup_full_tuning(
|
|||||||
model: "PreTrainedModel",
|
model: "PreTrainedModel",
|
||||||
model_args: "ModelArguments",
|
model_args: "ModelArguments",
|
||||||
finetuning_args: "FinetuningArguments",
|
finetuning_args: "FinetuningArguments",
|
||||||
|
is_trainable: bool,
|
||||||
cast_trainable_params_to_fp32: bool,
|
cast_trainable_params_to_fp32: bool,
|
||||||
) -> None:
|
) -> None:
|
||||||
|
if not is_trainable:
|
||||||
|
return
|
||||||
|
|
||||||
logger.info("Fine-tuning method: Full")
|
logger.info("Fine-tuning method: Full")
|
||||||
forbidden_modules = set()
|
forbidden_modules = set()
|
||||||
if model_args.visual_inputs and finetuning_args.freeze_vision_tower:
|
if model_args.visual_inputs and finetuning_args.freeze_vision_tower:
|
||||||
@@ -47,8 +65,12 @@ def _setup_freeze_tuning(
|
|||||||
model: "PreTrainedModel",
|
model: "PreTrainedModel",
|
||||||
model_args: "ModelArguments",
|
model_args: "ModelArguments",
|
||||||
finetuning_args: "FinetuningArguments",
|
finetuning_args: "FinetuningArguments",
|
||||||
|
is_trainable: bool,
|
||||||
cast_trainable_params_to_fp32: bool,
|
cast_trainable_params_to_fp32: bool,
|
||||||
) -> None:
|
) -> None:
|
||||||
|
if not is_trainable:
|
||||||
|
return
|
||||||
|
|
||||||
logger.info("Fine-tuning method: Freeze")
|
logger.info("Fine-tuning method: Freeze")
|
||||||
if model_args.visual_inputs:
|
if model_args.visual_inputs:
|
||||||
config = model.config.text_config
|
config = model.config.text_config
|
||||||
@@ -132,7 +154,9 @@ def _setup_lora_tuning(
|
|||||||
is_trainable: bool,
|
is_trainable: bool,
|
||||||
cast_trainable_params_to_fp32: bool,
|
cast_trainable_params_to_fp32: bool,
|
||||||
) -> "PeftModel":
|
) -> "PeftModel":
|
||||||
logger.info("Fine-tuning method: {}".format("DoRA" if finetuning_args.use_dora else "LoRA"))
|
if is_trainable:
|
||||||
|
logger.info("Fine-tuning method: {}".format("DoRA" if finetuning_args.use_dora else "LoRA"))
|
||||||
|
|
||||||
adapter_to_resume = None
|
adapter_to_resume = None
|
||||||
|
|
||||||
if model_args.adapter_name_or_path is not None:
|
if model_args.adapter_name_or_path is not None:
|
||||||
@@ -155,8 +179,16 @@ def _setup_lora_tuning(
|
|||||||
else:
|
else:
|
||||||
adapter_to_merge = model_args.adapter_name_or_path
|
adapter_to_merge = model_args.adapter_name_or_path
|
||||||
|
|
||||||
|
init_kwargs = {
|
||||||
|
"subfolder": model_args.adapter_folder,
|
||||||
|
"offload_folder": model_args.offload_folder,
|
||||||
|
"cache_dir": model_args.cache_dir,
|
||||||
|
"revision": model_args.model_revision,
|
||||||
|
"token": model_args.hf_hub_token,
|
||||||
|
}
|
||||||
|
|
||||||
for adapter in adapter_to_merge:
|
for adapter in adapter_to_merge:
|
||||||
model: "LoraModel" = PeftModel.from_pretrained(model, adapter, offload_folder=model_args.offload_folder)
|
model: "LoraModel" = PeftModel.from_pretrained(model, adapter, **init_kwargs)
|
||||||
model = model.merge_and_unload()
|
model = model.merge_and_unload()
|
||||||
|
|
||||||
if len(adapter_to_merge) > 0:
|
if len(adapter_to_merge) > 0:
|
||||||
@@ -166,12 +198,9 @@ def _setup_lora_tuning(
|
|||||||
if model_args.use_unsloth:
|
if model_args.use_unsloth:
|
||||||
model = load_unsloth_peft_model(config, model_args, is_trainable=is_trainable)
|
model = load_unsloth_peft_model(config, model_args, is_trainable=is_trainable)
|
||||||
else:
|
else:
|
||||||
model = PeftModel.from_pretrained(
|
model = PeftModel.from_pretrained(model, adapter_to_resume, is_trainable=is_trainable, **init_kwargs)
|
||||||
model,
|
|
||||||
adapter_to_resume,
|
logger.info("Loaded adapter(s): {}".format(",".join(model_args.adapter_name_or_path)))
|
||||||
is_trainable=is_trainable,
|
|
||||||
offload_folder=model_args.offload_folder,
|
|
||||||
)
|
|
||||||
|
|
||||||
if is_trainable and adapter_to_resume is None: # create new lora weights while training
|
if is_trainable and adapter_to_resume is None: # create new lora weights while training
|
||||||
if len(finetuning_args.lora_target) == 1 and finetuning_args.lora_target[0] == "all":
|
if len(finetuning_args.lora_target) == 1 and finetuning_args.lora_target[0] == "all":
|
||||||
@@ -216,6 +245,14 @@ def _setup_lora_tuning(
|
|||||||
if model_args.use_unsloth:
|
if model_args.use_unsloth:
|
||||||
model = get_unsloth_peft_model(model, model_args, peft_kwargs)
|
model = get_unsloth_peft_model(model, model_args, peft_kwargs)
|
||||||
else:
|
else:
|
||||||
|
if finetuning_args.pissa_init:
|
||||||
|
if finetuning_args.pissa_iter == -1:
|
||||||
|
logger.info("Using PiSSA initialization.")
|
||||||
|
peft_kwargs["init_lora_weights"] = "pissa"
|
||||||
|
else:
|
||||||
|
logger.info("Using PiSSA initialization with FSVD steps {}.".format(finetuning_args.pissa_iter))
|
||||||
|
peft_kwargs["init_lora_weights"] = "pissa_niter_{}".format(finetuning_args.pissa_iter)
|
||||||
|
|
||||||
lora_config = LoraConfig(
|
lora_config = LoraConfig(
|
||||||
task_type=TaskType.CAUSAL_LM,
|
task_type=TaskType.CAUSAL_LM,
|
||||||
inference_mode=False,
|
inference_mode=False,
|
||||||
@@ -227,9 +264,6 @@ def _setup_lora_tuning(
|
|||||||
for param in filter(lambda p: p.requires_grad, model.parameters()):
|
for param in filter(lambda p: p.requires_grad, model.parameters()):
|
||||||
param.data = param.data.to(torch.float32)
|
param.data = param.data.to(torch.float32)
|
||||||
|
|
||||||
if model_args.adapter_name_or_path is not None:
|
|
||||||
logger.info("Loaded adapter(s): {}".format(",".join(model_args.adapter_name_or_path)))
|
|
||||||
|
|
||||||
return model
|
return model
|
||||||
|
|
||||||
|
|
||||||
@@ -247,29 +281,37 @@ def init_adapter(
|
|||||||
|
|
||||||
Note that the trainable parameters must be cast to float32.
|
Note that the trainable parameters must be cast to float32.
|
||||||
"""
|
"""
|
||||||
if (not is_trainable) and model_args.adapter_name_or_path is None:
|
if is_trainable and getattr(model, "quantization_method", None) is not None:
|
||||||
logger.info("Adapter is not found at evaluation, load the base model.")
|
if finetuning_args.finetuning_type != "lora":
|
||||||
return model
|
raise ValueError("Quantized models can only be used for the LoRA tuning.")
|
||||||
|
|
||||||
if finetuning_args.finetuning_type != "lora" and getattr(model, "quantization_method", None):
|
if finetuning_args.pissa_init:
|
||||||
raise ValueError("You can only use lora for quantized models.")
|
raise ValueError("Cannot initialize PiSSA adapter on quantized models.")
|
||||||
|
|
||||||
if is_deepspeed_zero3_enabled() or is_fsdp_enabled() or finetuning_args.pure_bf16 or finetuning_args.use_badam:
|
# cast trainable parameters to float32 if:
|
||||||
|
# 1. is_trainable and quantization_bit is not None (qlora)
|
||||||
|
# 2. is_trainable and not deepspeed zero3 and not fsdp (zero3 or fsdp already in float32)
|
||||||
|
# 3. is_trainable and not pure_bf16 and not badam
|
||||||
|
if not is_trainable:
|
||||||
|
cast_trainable_params_to_fp32 = False
|
||||||
|
elif model_args.quantization_bit is None and (
|
||||||
|
is_deepspeed_zero3_enabled() or is_fsdp_enabled() or finetuning_args.pure_bf16 or finetuning_args.use_badam
|
||||||
|
):
|
||||||
logger.info("ZeRO3/FSDP/PureBF16/BAdam detected, remaining trainable params as their original precision.")
|
logger.info("ZeRO3/FSDP/PureBF16/BAdam detected, remaining trainable params as their original precision.")
|
||||||
cast_trainable_params_to_fp32 = False
|
cast_trainable_params_to_fp32 = False
|
||||||
else:
|
else:
|
||||||
logger.info("Upcasting trainable params to float32.")
|
logger.info("Upcasting trainable params to float32.")
|
||||||
cast_trainable_params_to_fp32 = True
|
cast_trainable_params_to_fp32 = True
|
||||||
|
|
||||||
if is_trainable and finetuning_args.finetuning_type == "full":
|
if finetuning_args.finetuning_type == "full":
|
||||||
_setup_full_tuning(model, model_args, finetuning_args, cast_trainable_params_to_fp32)
|
_setup_full_tuning(model, model_args, finetuning_args, is_trainable, cast_trainable_params_to_fp32)
|
||||||
|
elif finetuning_args.finetuning_type == "freeze":
|
||||||
if is_trainable and finetuning_args.finetuning_type == "freeze":
|
_setup_freeze_tuning(model, model_args, finetuning_args, is_trainable, cast_trainable_params_to_fp32)
|
||||||
_setup_freeze_tuning(model, model_args, finetuning_args, cast_trainable_params_to_fp32)
|
elif finetuning_args.finetuning_type == "lora":
|
||||||
|
|
||||||
if finetuning_args.finetuning_type == "lora":
|
|
||||||
model = _setup_lora_tuning(
|
model = _setup_lora_tuning(
|
||||||
config, model, model_args, finetuning_args, is_trainable, cast_trainable_params_to_fp32
|
config, model, model_args, finetuning_args, is_trainable, cast_trainable_params_to_fp32
|
||||||
)
|
)
|
||||||
|
else:
|
||||||
|
raise NotImplementedError("Unknown finetuning type: {}.".format(finetuning_args.finetuning_type))
|
||||||
|
|
||||||
return model
|
return model
|
||||||
|
|||||||
@@ -1,3 +1,17 @@
|
|||||||
|
# Copyright 2024 the LlamaFactory team.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
from typing import TYPE_CHECKING, Any, Dict, Optional, TypedDict
|
from typing import TYPE_CHECKING, Any, Dict, Optional, TypedDict
|
||||||
|
|
||||||
from transformers import AutoConfig, AutoModelForCausalLM, AutoModelForVision2Seq, AutoProcessor, AutoTokenizer
|
from transformers import AutoConfig, AutoModelForCausalLM, AutoModelForVision2Seq, AutoProcessor, AutoTokenizer
|
||||||
|
|||||||
@@ -1,7 +1,22 @@
|
|||||||
|
# Copyright 2024 the LlamaFactory team.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
from typing import TYPE_CHECKING
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
|
from transformers.utils import is_flash_attn_2_available, is_torch_sdpa_available
|
||||||
|
|
||||||
from ...extras.logging import get_logger
|
from ...extras.logging import get_logger
|
||||||
from ...extras.packages import is_flash_attn2_available, is_sdpa_available
|
|
||||||
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
@@ -21,13 +36,13 @@ def configure_attn_implementation(config: "PretrainedConfig", model_args: "Model
|
|||||||
requested_attn_implementation = "eager"
|
requested_attn_implementation = "eager"
|
||||||
|
|
||||||
elif model_args.flash_attn == "sdpa":
|
elif model_args.flash_attn == "sdpa":
|
||||||
if not is_sdpa_available():
|
if not is_torch_sdpa_available():
|
||||||
logger.warning("torch>=2.1.1 is required for SDPA attention.")
|
logger.warning("torch>=2.1.1 is required for SDPA attention.")
|
||||||
return
|
return
|
||||||
|
|
||||||
requested_attn_implementation = "sdpa"
|
requested_attn_implementation = "sdpa"
|
||||||
elif model_args.flash_attn == "fa2":
|
elif model_args.flash_attn == "fa2":
|
||||||
if not is_flash_attn2_available():
|
if not is_flash_attn_2_available():
|
||||||
logger.warning("FlashAttention-2 is not installed.")
|
logger.warning("FlashAttention-2 is not installed.")
|
||||||
return
|
return
|
||||||
|
|
||||||
|
|||||||
@@ -1,3 +1,21 @@
|
|||||||
|
# Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
|
||||||
|
#
|
||||||
|
# This code is inspired by the HuggingFace's Transformers and PEFT library.
|
||||||
|
# https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/modeling_utils.py
|
||||||
|
# https://github.com/huggingface/peft/blob/v0.10.0/src/peft/utils/other.py
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
import inspect
|
import inspect
|
||||||
from functools import partial
|
from functools import partial
|
||||||
from types import MethodType
|
from types import MethodType
|
||||||
@@ -68,7 +86,6 @@ def prepare_model_for_training(
|
|||||||
(1) cast the layernorm in fp32
|
(1) cast the layernorm in fp32
|
||||||
(2) make output embedding layer require grads
|
(2) make output embedding layer require grads
|
||||||
(3) add the upcasting of the lm_head in fp32
|
(3) add the upcasting of the lm_head in fp32
|
||||||
Inspired by: https://github.com/huggingface/peft/blob/v0.7.1/src/peft/utils/other.py#L72
|
|
||||||
"""
|
"""
|
||||||
if model_args.upcast_layernorm:
|
if model_args.upcast_layernorm:
|
||||||
logger.info("Upcasting layernorm weights in float32.")
|
logger.info("Upcasting layernorm weights in float32.")
|
||||||
|
|||||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user