Mirror of https://github.com/hiyouga/LlamaFactory.git (synced 2026-01-30 22:32:05 +00:00)

Compare commits (38 commits)
| SHA1 |
|---|
| 762b480131 |
| 9640f79ae5 |
| 7ef19eea00 |
| f9f11dcb97 |
| 641bfdd482 |
| e70651ac58 |
| db2f794f7b |
| 44eadbda1c |
| 9829ae0a77 |
| 958b9c3468 |
| 4d3621e3d3 |
| a296723697 |
| 15b87f3125 |
| 9f73a6eb23 |
| b2effbd77c |
| d7d734d54c |
| 8abb8fb533 |
| 766d5ae6ad |
| 5cccaeec82 |
| 5fb5d7ebd3 |
| 03a70ba8dd |
| 5cfd804b59 |
| 4c1eb922e2 |
| 958fb523a2 |
| b4e051bea4 |
| d43e1007e8 |
| f89d9367e5 |
| d22de0d4bf |
| ea0b4e2466 |
| e944dc442c |
| 68119e5522 |
| f60a6e3d01 |
| 81b8a50aa5 |
| 8600530002 |
| 9ae62c6fc0 |
| 0087bc253b |
| 355d5c5e5a |
| 6fe6bd290b |
.github/workflows/docker.yml (vendored, 2 changes)

@@ -50,7 +50,7 @@ jobs:
          docker-images: false

      - name: Checkout
-       uses: actions/checkout@v4
+       uses: actions/checkout@v6

      - name: Get llamafactory version
        id: version
.github/workflows/publish.yml (vendored, 2 changes)

@@ -21,7 +21,7 @@ jobs:
    steps:
      - name: Checkout
-       uses: actions/checkout@v4
+       uses: actions/checkout@v6

      - name: Install uv
        uses: astral-sh/setup-uv@v7
.github/workflows/tests.yml (vendored, 15 changes)

@@ -54,10 +54,11 @@ jobs:
    env:
      HF_TOKEN: ${{ secrets.HF_TOKEN }}
      OS_NAME: ${{ matrix.os }}
+     UV_NO_SYNC: 1

    steps:
      - name: Checkout
-       uses: actions/checkout@v4
+       uses: actions/checkout@v6

      - name: Install uv
        uses: astral-sh/setup-uv@v7

@@ -70,7 +71,8 @@ jobs:
        run: |
          uv venv
          uv pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu
-         uv pip install -e ".[dev]"
+         uv pip install -e .
+         uv pip install -r requirements/dev.txt

      - name: Install transformers
        if: ${{ matrix.transformers }}

@@ -79,7 +81,7 @@ jobs:
      - name: Cache files
        id: hf-hub-cache
-       uses: actions/cache@v4
+       uses: actions/cache@v5
        with:
          path: ${{ runner.temp }}/huggingface
          key: huggingface-${{ matrix.os }}-${{ matrix.python }}-${{ matrix.transformers }}-${{ hashFiles('tests/version.txt') }}

@@ -87,25 +89,18 @@ jobs:
      - name: Check quality
        run: |
          make style && make quality
-       env:
-         UV_NO_SYNC: 1

      - name: Check license
        run: |
          make license
-       env:
-         UV_NO_SYNC: 1

      - name: Check build
        run: |
          make build
-       env:
-         UV_NO_SYNC: 1

      - name: Test with pytest
        run: |
          make test
        env:
-         UV_NO_SYNC: 1
          HF_HOME: ${{ runner.temp }}/huggingface
          HF_HUB_OFFLINE: "${{ steps.hf-hub-cache.outputs.cache-hit == 'true' && '1' || '0' }}"
.github/workflows/tests_cuda.yml (vendored, 29 changes)

@@ -35,9 +35,16 @@ jobs:
      group: ${{ github.workflow }}-${{ github.ref }}-${{ matrix.os }}-${{ matrix.python }}
      cancel-in-progress: ${{ github.ref != 'refs/heads/main' }}

+   env:
+     HF_HOME: "${{ github.workspace }}/../.runner_cache/huggingface"
+     UV_CACHE_DIR: "${{ github.workspace }}/../.runner_cache/uv"
+     HF_TOKEN: ${{ secrets.HF_TOKEN }}
+     OS_NAME: ${{ matrix.os }}
+     UV_NO_SYNC: 1

    steps:
      - name: Checkout
-       uses: actions/checkout@v4
+       uses: actions/checkout@v6

      - name: Install uv
        uses: astral-sh/setup-uv@v7

@@ -52,37 +59,21 @@ jobs:
      - name: Install dependencies
        run: |
          uv venv
-         uv pip install -e ".[dev]"
-
-     - name: Cache HuggingFace models
-       id: hf-hub-cache
-       uses: actions/cache@v4
-       with:
-         path: ${{ runner.temp }}/huggingface
-         key: hf-cache-${{ runner.os }}-${{ hashFiles('tests/version.txt') }}
+         uv pip install -e .
+         uv pip install -r requirements/dev.txt

      - name: Check quality
        run: |
          make style && make quality
-       env:
-         UV_NO_SYNC: 1

      - name: Check license
        run: |
          make license
-       env:
-         UV_NO_SYNC: 1

      - name: Check build
        run: |
          make build
-       env:
-         UV_NO_SYNC: 1

      - name: Test with pytest
        run: |
          make test
-       env:
-         UV_NO_SYNC: 1
-         HF_HOME: ${{ runner.temp }}/huggingface
-         HF_HUB_OFFLINE: "${{ steps.hf-hub-cache.outputs.cache-hit == 'true' && '1' || '0' }}"
.github/workflows/tests_npu.yml (vendored, 25 changes)

@@ -43,10 +43,11 @@ jobs:
      HF_ENDPOINT: https://hf-mirror.com
      HF_TOKEN: ${{ secrets.HF_TOKEN }}
      OS_NAME: ${{ matrix.os }}
+     UV_NO_SYNC: 1

    steps:
      - name: Checkout
-       uses: actions/checkout@v4
+       uses: actions/checkout@v6

      - name: Install uv
        uses: astral-sh/setup-uv@v7

@@ -58,8 +59,9 @@ jobs:
      - name: Install dependencies
        run: |
          uv venv
-         uv pip install torch-npu==${{matrix.pytorch_npu}}
-         uv pip install -e ".[dev]"
+         uv pip install -r requirements/npu.txt
+         uv pip install -e .
+         uv pip install -r requirements/dev.txt

      - name: Install node
        run: |

@@ -68,35 +70,18 @@ jobs:
          curl -fsSL https://deb.nodesource.com/setup_20.x | bash -
          apt-get install -y nodejs

-     - name: Cache files
-       id: hf-hub-cache
-       uses: actions/cache@v4
-       with:
-         path: ${{ runner.temp }}/huggingface
-         key: huggingface-${{ matrix.os }}-${{ matrix.python }}-${{ hashFiles('tests/version.txt') }}
-
      - name: Check quality
        run: |
          make style && make quality
-       env:
-         UV_NO_SYNC: 1

      - name: Check license
        run: |
          make license
-       env:
-         UV_NO_SYNC: 1

      - name: Check build
        run: |
          make build
-       env:
-         UV_NO_SYNC: 1

      - name: Test with pytest
        run: |
          make test
-       env:
-         UV_NO_SYNC: 1
-         HF_HOME: /root/.cache/huggingface
-         HF_HUB_OFFLINE: "${{ steps.hf-hub-cache.outputs.cache-hit == 'true' && '1' || '0' }}"
.gitignore (vendored, 1 change)

@@ -176,6 +176,7 @@ llamaboard_cache/
  llamaboard_config/
  saves/
  output/
+ outputs/
  wandb/
  swanlog/
  generated_predictions.jsonl
README.md (63 changes)

@@ -92,7 +92,7 @@ Read technical notes:

  ## Features

- - **Various models**: LLaMA, LLaVA, Mistral, Mixtral-MoE, Qwen, Qwen2-VL, DeepSeek, Yi, Gemma, ChatGLM, Phi, etc.
+ - **Various models**: LLaMA, LLaVA, Mistral, Mixtral-MoE, Qwen3, Qwen3-VL, DeepSeek, Gemma, GLM, Phi, etc.
  - **Integrated methods**: (Continuous) pre-training, (multimodal) supervised fine-tuning, reward modeling, PPO, DPO, KTO, ORPO, etc.
  - **Scalable resources**: 16-bit full-tuning, freeze-tuning, LoRA and 2/3/4/5/6/8-bit QLoRA via AQLM/AWQ/GPTQ/LLM.int8/HQQ/EETQ.
  - **Advanced algorithms**: [GaLore](https://github.com/jiaweizzhao/GaLore), [BAdam](https://github.com/Ledzy/BAdam), [APOLLO](https://github.com/zhuhanqing/APOLLO), [Adam-mini](https://github.com/zyushun/Adam-mini), [Muon](https://github.com/KellerJordan/Muon), [OFT](https://github.com/huggingface/peft/tree/main/src/peft/tuners/oft), DoRA, LongLoRA, LLaMA Pro, Mixture-of-Depths, LoRA+, LoftQ and PiSSA.

@@ -279,11 +279,10 @@ Read technical notes:
  | Model | Model size | Template |
  | --- | --- | --- |
  | [BLOOM/BLOOMZ](https://huggingface.co/bigscience) | 560M/1.1B/1.7B/3B/7.1B/176B | - |
  | [Command R](https://huggingface.co/CohereForAI) | 35B/104B | cohere |
  | [DeepSeek (LLM/Code/MoE)](https://huggingface.co/deepseek-ai) | 7B/16B/67B/236B | deepseek |
  | [DeepSeek 3-3.2](https://huggingface.co/deepseek-ai) | 236B/671B | deepseek3 |
  | [DeepSeek R1 (Distill)](https://huggingface.co/deepseek-ai) | 1.5B/7B/8B/14B/32B/70B/671B | deepseekr1 |
- | [ERNIE-4.5](https://huggingface.co/baidu) | 0.3B/21B/300B | ernie/ernie_nothink |
+ | [ERNIE-4.5](https://huggingface.co/baidu) | 0.3B/21B/300B | ernie_nothink |
  | [Falcon/Falcon H1](https://huggingface.co/tiiuae) | 0.5B/1.5B/3B/7B/11B/34B/40B/180B | falcon/falcon_h1 |
  | [Gemma/Gemma 2/CodeGemma](https://huggingface.co/google) | 2B/7B/9B/27B | gemma/gemma2 |
  | [Gemma 3/Gemma 3n](https://huggingface.co/google) | 270M/1B/4B/6B/8B/12B/27B | gemma3/gemma3n |

@@ -292,12 +291,13 @@ Read technical notes:
  | [GPT-2](https://huggingface.co/openai-community) | 0.1B/0.4B/0.8B/1.5B | - |
  | [GPT-OSS](https://huggingface.co/openai) | 20B/120B | gpt_oss |
  | [Granite 3-4](https://huggingface.co/ibm-granite) | 1B/2B/3B/7B/8B | granite3/granite4 |
- | [Hunyuan (MT)](https://huggingface.co/tencent/) | 7B | hunyuan |
+ | [Hunyuan/Hunyuan1.5 (MT)](https://huggingface.co/tencent/) | 0.5B/1.8B/4B/7B/13B | hunyuan/hunyuan_small |
  | [InternLM 2-3](https://huggingface.co/internlm) | 7B/8B/20B | intern2 |
  | [InternVL 2.5-3.5](https://huggingface.co/OpenGVLab) | 1B/2B/4B/8B/14B/30B/38B/78B/241B | intern_vl |
- | [InternLM/Intern-S1-mini](https://huggingface.co/internlm/) | 8B | intern_s1 |
+ | [Intern-S1-mini](https://huggingface.co/internlm/) | 8B | intern_s1 |
  | [Kimi-VL](https://huggingface.co/moonshotai) | 16B | kimi_vl |
  | [Ling 2.0 (mini/flash)](https://huggingface.co/inclusionAI) | 16B/100B | bailing_v2 |
+ | [LFM 2.5 (VL)](https://huggingface.co/LiquidAI) | 1.2B/1.6B | lfm2/lfm2_vl |
  | [Llama](https://github.com/facebookresearch/llama) | 7B/13B/33B/65B | - |
  | [Llama 2](https://huggingface.co/meta-llama) | 7B/13B/70B | llama2 |
  | [Llama 3-3.3](https://huggingface.co/meta-llama) | 1B/3B/8B/70B | llama3 |

@@ -307,18 +307,17 @@ Read technical notes:
  | [LLaVA-NeXT](https://huggingface.co/llava-hf) | 7B/8B/13B/34B/72B/110B | llava_next |
  | [LLaVA-NeXT-Video](https://huggingface.co/llava-hf) | 7B/34B | llava_next_video |
  | [MiMo](https://huggingface.co/XiaomiMiMo) | 7B/309B | mimo/mimo_v2 |
- | [MiniCPM 1-4.1](https://huggingface.co/openbmb) | 0.5B/1B/2B/4B/8B | cpm/cpm3/cpm4 |
+ | [MiniCPM 4](https://huggingface.co/openbmb) | 0.5B/8B | cpm4 |
  | [MiniCPM-o-2.6/MiniCPM-V-2.6](https://huggingface.co/openbmb) | 8B | minicpm_o/minicpm_v |
  | [MiniMax-M1/MiniMax-M2](https://huggingface.co/MiniMaxAI/models) | 229B/456B | minimax1/minimax2 |
  | [Ministral 3](https://huggingface.co/mistralai) | 3B/8B/14B | ministral3 |
  | [Mistral/Mixtral](https://huggingface.co/mistralai) | 7B/8x7B/8x22B | mistral |
  | [OLMo](https://huggingface.co/allenai) | 1B/7B | - |
  | [PaliGemma/PaliGemma2](https://huggingface.co/google) | 3B/10B/28B | paligemma |
  | [Phi-3/Phi-3.5](https://huggingface.co/microsoft) | 4B/14B | phi |
  | [Phi-3-small](https://huggingface.co/microsoft) | 7B | phi_small |
- | [Phi-4](https://huggingface.co/microsoft) | 14B | phi4 |
+ | [Phi-4-mini/Phi-4](https://huggingface.co/microsoft) | 3.8B/14B | phi4_mini/phi4 |
  | [Pixtral](https://huggingface.co/mistralai) | 12B | pixtral |
- | [Qwen (1-2.5) (Code/Math/MoE/QwQ)](https://huggingface.co/Qwen) | 0.5B/1.5B/3B/7B/14B/32B/72B/110B | qwen |
+ | [Qwen2 (Code/Math/MoE/QwQ)](https://huggingface.co/Qwen) | 0.5B/1.5B/3B/7B/14B/32B/72B/110B | qwen |
  | [Qwen3 (MoE/Instruct/Thinking/Next)](https://huggingface.co/Qwen) | 0.6B/1.7B/4B/8B/14B/32B/80B/235B | qwen3/qwen3_nothink |
  | [Qwen2-Audio](https://huggingface.co/Qwen) | 7B | qwen2_audio |
  | [Qwen2.5-Omni](https://huggingface.co/Qwen) | 3B/7B | qwen2_omni |

@@ -327,8 +326,6 @@ Read technical notes:
  | [Qwen3-VL](https://huggingface.co/Qwen) | 2B/4B/8B/30B/32B/235B | qwen3_vl |
  | [Seed (OSS/Coder)](https://huggingface.co/ByteDance-Seed) | 8B/36B | seed_oss/seed_coder |
  | [StarCoder 2](https://huggingface.co/bigcode) | 3B/7B/15B | - |
- | [VibeThinker-1.5B](https://huggingface.co/WeiboAI) | 1.5B | qwen3 |
- | [Yi/Yi-1.5 (Code)](https://huggingface.co/01-ai) | 1.5B/6B/9B/34B | yi |
  | [Yuan 2](https://huggingface.co/IEITYuan) | 2B/51B/102B | yuan |

  > [!NOTE]

@@ -514,12 +511,13 @@ huggingface-cli login
  #### Install from Source

  ```bash
- git clone --depth 1 https://github.com/hiyouga/LLaMA-Factory.git
- cd LLaMA-Factory
- pip install -e ".[metrics]"
+ git clone --depth 1 https://github.com/hiyouga/LlamaFactory.git
+ cd LlamaFactory
+ pip install -e .
+ pip install -r requirements/metrics.txt
  ```

- Optional dependencies available: `metrics`, `deepspeed`. Install with: `pip install -e ".[metrics,deepspeed]"`
+ Optional dependencies available: `metrics`, `deepspeed`. Install with: `pip install -e . && pip install -r requirements/metrics.txt -r requirements/deepspeed.txt`

  Additional dependencies for specific features are available in `examples/requirements/`.

@@ -577,36 +575,21 @@ To enable FlashAttention-2 on the Windows platform, please use the script from [
  <details><summary>For Ascend NPU users</summary>

- To install LLaMA Factory on Ascend NPU devices, please upgrade Python to version 3.10 or higher: `pip install -e . torch-npu==2.7.1`. Additionally, you need to install the **[Ascend CANN Toolkit and Kernels](https://www.hiascend.com/developer/download/community/result?module=cann)**. Please follow the [installation tutorial](https://www.hiascend.com/document/detail/en/CANNCommunityEdition/600alphaX/softwareinstall/instg/atlasdeploy_03_0031.html) or use the following commands:
+ To install LLaMA Factory on Ascend NPU devices, please upgrade Python to version 3.10 or higher: `pip install -r requirements/npu.txt`. Additionally, you need to install the **Ascend CANN Toolkit and Kernels**. Please follow the [installation tutorial](https://llamafactory.readthedocs.io/en/latest/advanced/npu_installation.html).

+ You can also download the pre-built Docker images:

  ```bash
- # replace the url according to your CANN version and devices
- # install CANN Toolkit
- wget https://ascend-repo.obs.cn-east-2.myhuaweicloud.com/Milan-ASL/Milan-ASL%20V100R001C20SPC702/Ascend-cann-toolkit_8.0.0.alpha002_linux-"$(uname -i)".run
- bash Ascend-cann-toolkit_8.0.0.alpha002_linux-"$(uname -i)".run --install
+ # Docker Hub
+ docker pull hiyouga/llamafactory:latest-npu-a2
+ docker pull hiyouga/llamafactory:latest-npu-a3

- # install CANN Kernels
- wget https://ascend-repo.obs.cn-east-2.myhuaweicloud.com/Milan-ASL/Milan-ASL%20V100R001C20SPC702/Ascend-cann-kernels-910b_8.0.0.alpha002_linux-"$(uname -i)".run
- bash Ascend-cann-kernels-910b_8.0.0.alpha002_linux-"$(uname -i)".run --install

- # set env variables
- source /usr/local/Ascend/ascend-toolkit/set_env.sh
+ # quay.io
+ docker pull quay.io/ascend/llamafactory:latest-npu-a2
+ docker pull quay.io/ascend/llamafactory:latest-npu-a3
  ```

  | Requirement | Minimum | Recommend |
  | --- | --- | --- |
  | CANN | 8.0.RC1 | 8.0.0.alpha002 |
  | torch | 2.1.0 | 2.7.1 |
  | torch-npu | 2.1.0 | 2.7.1 |
  | deepspeed | 0.13.2 | 0.13.2 |
  | vllm-ascend | - | 0.7.3 |

  Remember to use `ASCEND_RT_VISIBLE_DEVICES` instead of `CUDA_VISIBLE_DEVICES` to specify the device to use.

  If you cannot infer model on NPU devices, try setting `do_sample: false` in the configurations.

- Download the pre-built Docker images: [32GB](http://mirrors.cn-central-221.ovaijisuan.com/detail/130.html) | [64GB](http://mirrors.cn-central-221.ovaijisuan.com/detail/131.html)

  #### Install BitsAndBytes

  To use QLoRA based on bitsandbytes on Ascend NPU, please follow these 3 steps:
README_zh.md (62 changes): the Chinese README receives the same set of edits as README.md above; the changed lines follow.

@@ -94,7 +94,7 @@ ## Features

- - **Various models**: LLaMA, LLaVA, Mistral, Mixtral-MoE, Qwen, Qwen2-VL, DeepSeek, Yi, Gemma, ChatGLM, Phi, and more.
+ - **Various models**: LLaMA, LLaVA, Mistral, Mixtral-MoE, Qwen3, Qwen3-VL, DeepSeek, Gemma, GLM, Phi, and more.

@@ -281,11 +281,10 @@ (model table)
- | [ERNIE-4.5](https://huggingface.co/baidu) | 0.3B/21B/300B | ernie/ernie_nothink |
+ | [ERNIE-4.5](https://huggingface.co/baidu) | 0.3B/21B/300B | ernie_nothink |

@@ -294,12 +293,13 @@ (model table)
- | [Hunyuan (MT)](https://huggingface.co/tencent/) | 7B | hunyuan |
+ | [Hunyuan/Hunyuan1.5 (MT)](https://huggingface.co/tencent/) | 0.5B/1.8B/4B/7B/13B | hunyuan/hunyuan_small |
- | [InternLM/Intern-S1-mini](https://huggingface.co/internlm/) | 8B | intern_s1 |
+ | [Intern-S1-mini](https://huggingface.co/internlm/) | 8B | intern_s1 |
+ | [LFM 2.5 (VL)](https://huggingface.co/LiquidAI) | 1.2B/1.6B | lfm2/lfm2_vl |

@@ -309,18 +309,17 @@ (model table)
- | [MiniCPM 1-4.1](https://huggingface.co/openbmb) | 0.5B/1B/2B/4B/8B | cpm/cpm3/cpm4 |
+ | [MiniCPM 4](https://huggingface.co/openbmb) | 0.5B/8B | cpm4 |
- | [Phi-4](https://huggingface.co/microsoft) | 14B | phi4 |
+ | [Phi-4-mini/Phi-4](https://huggingface.co/microsoft) | 3.8B/14B | phi4_mini/phi4 |
- | [Qwen (1-2.5) (Code/Math/MoE/QwQ)](https://huggingface.co/Qwen) | 0.5B/1.5B/3B/7B/14B/32B/72B/110B | qwen |
+ | [Qwen2 (Code/Math/MoE/QwQ)](https://huggingface.co/Qwen) | 0.5B/1.5B/3B/7B/14B/32B/72B/110B | qwen |

@@ -329,8 +328,6 @@ (model table)
- | [VibeThinker-1.5B](https://huggingface.co/WeiboAI) | 1.5B | qwen3 |
- | [Yi/Yi-1.5 (Code)](https://huggingface.co/01-ai) | 1.5B/6B/9B/34B | yi |

@@ -516,12 +513,13 @@ huggingface-cli login
  #### Install from Source

  ```bash
- git clone --depth 1 https://github.com/hiyouga/LLaMA-Factory.git
- cd LLaMA-Factory
- pip install -e ".[metrics]"
+ git clone --depth 1 https://github.com/hiyouga/LlamaFactory.git
+ cd LlamaFactory
+ pip install -e .
+ pip install -r requirements/metrics.txt
  ```

- Optional extra dependencies: `metrics`, `deepspeed`. Install with `pip install -e ".[metrics,deepspeed]"`.
+ Optional extra dependencies: `metrics`, `deepspeed`. Install with `pip install -e . && pip install -r requirements/metrics.txt -r requirements/deepspeed.txt`.

  For other optional dependencies, see the files under `examples/requirements/`.

@@ -579,36 +577,20 @@ pip install https://github.com/jllllll/bitsandbytes-windows-webui/releases/download/...
  <details><summary>For Ascend NPU users</summary>

- To install LLaMA Factory on Ascend NPU devices, upgrade Python to 3.10 or higher and install the extra dependencies with `pip install -e . torch-npu==2.7.1`. You also need to install the **[Ascend CANN Toolkit and Kernels](https://www.hiascend.com/developer/download/community/result?module=cann)**; see the [installation tutorial](https://www.hiascend.com/document/detail/zh/CANNCommunityEdition/80RC2alpha002/quickstart/quickstart/quickstart_18_0004.html) or use the following commands:
+ To install LLaMA Factory on Ascend NPU devices, upgrade Python to 3.10 or higher and install the extra dependencies with `pip install -r requirements/npu.txt`. You also need to install the **Ascend CANN Toolkit and Kernels**; see the [installation tutorial](https://llamafactory.readthedocs.io/zh-cn/latest/advanced/npu_installation.html).

+ You can also download the latest pre-built Docker images:

  ```bash
- # replace the URL with the one matching your CANN version and device
- # install CANN Toolkit
- wget https://ascend-repo.obs.cn-east-2.myhuaweicloud.com/Milan-ASL/Milan-ASL%20V100R001C17SPC701/Ascend-cann-toolkit_8.0.RC1.alpha001_linux-"$(uname -i)".run
- bash Ascend-cann-toolkit_8.0.RC1.alpha001_linux-"$(uname -i)".run --install
+ # Docker Hub
+ docker pull hiyouga/llamafactory:latest-npu-a2
+ docker pull hiyouga/llamafactory:latest-npu-a3

- # install CANN Kernels
- wget https://ascend-repo.obs.cn-east-2.myhuaweicloud.com/Milan-ASL/Milan-ASL%20V100R001C17SPC701/Ascend-cann-kernels-910b_8.0.RC1.alpha001_linux.run
- bash Ascend-cann-kernels-910b_8.0.RC1.alpha001_linux.run --install

- # set environment variables
- source /usr/local/Ascend/ascend-toolkit/set_env.sh
+ # quay.io
+ docker pull quay.io/ascend/llamafactory:latest-npu-a2
+ docker pull quay.io/ascend/llamafactory:latest-npu-a3
  ```

  | Requirement | Minimum | Recommended |
  | --- | --- | --- |
  | CANN | 8.0.RC1 | 8.0.0.alpha002 |
  | torch | 2.1.0 | 2.7.1 |
  | torch-npu | 2.1.0 | 2.7.1 |
  | deepspeed | 0.13.2 | 0.13.2 |
  | vllm-ascend | - | 0.7.3 |

  Use `ASCEND_RT_VISIBLE_DEVICES` instead of `CUDA_VISIBLE_DEVICES` to select the compute device.

  If inference fails on NPU devices, try setting `do_sample: false`.

- Download the pre-built Docker images: [32GB](http://mirrors.cn-central-221.ovaijisuan.com/detail/130.html) | [64GB](http://mirrors.cn-central-221.ovaijisuan.com/detail/131.html)

  #### Install BitsAndBytes

  To run bitsandbytes-based QLoRA fine-tuning on Ascend NPU, follow these steps:
File diff suppressed because one or more lines are too long
@@ -32,7 +32,8 @@ RUN pip config set global.index-url "${PIP_INDEX}" && \
  COPY . /app

  # Install LLaMA Factory
- RUN pip install --no-cache-dir --no-build-isolation -e ".[metrics,deepspeed]"
+ RUN pip install --no-cache-dir --no-build-isolation -e . && \
+     pip install --no-cache-dir --no-build-isolation -r requirements/metrics.txt -r requirements/deepspeed.txt

  # Rebuild flash attention
  RUN if [ "${INSTALL_FLASHATTN}" == "true" ]; then \
@@ -1,12 +1,13 @@
- # NVIDIA official image (ubuntu-22.04 + cuda-12.4 + python-3.10)
- # https://docs.nvidia.com/deeplearning/frameworks/pytorch-release-notes/rel-24-08.html
- FROM nvcr.io/nvidia/pytorch:24.05-py3
+ # NVIDIA official image (ubuntu-24.04 + cuda-12.9.1 + python-3.12)
+ # https://docs.nvidia.com/deeplearning/frameworks/pytorch-release-notes/rel-25-06.html
+ FROM nvcr.io/nvidia/pytorch:25.06-py3

  ENV DEBIAN_FRONTEND=noninteractive
  ENV PIP_ROOT_USER_ACTION=ignore
  ENV PYPI_MIRROR=https://mirrors.aliyun.com/pypi/simple/
  ENV PYPI_TRUSTED_HOST=mirrors.aliyun.com
  ENV APT_MIRROR=https://mirrors.tuna.tsinghua.edu.cn/ubuntu/
  ENV PIP_CONSTRAINT=""

  RUN pip install --upgrade pip setuptools wheel "hatchling>=1.18.0" editables --trusted-host ${PYPI_TRUSTED_HOST} --index-url ${PYPI_MIRROR}

@@ -14,20 +15,14 @@ RUN pip uninstall -y torch torchvision torch-tensorrt \
      flash_attn transformer-engine \
      cudf dask-cuda cugraph cugraph-service-server cuml raft-dask cugraph-dgl cugraph-pyg dask-cudf

- RUN pip install torch==2.6.0 torchvision==0.21.0 torchaudio==2.6.0 --index-url https://download.pytorch.org/whl/cu124
+ RUN pip install torch==2.8.0 torchvision==0.23.0 torchaudio==2.8.0 --index-url https://download.pytorch.org/whl/cu129

  RUN pip uninstall -y opencv opencv-python opencv-python-headless && \
-     rm -rf /usr/local/lib/python3.10/dist-packages/cv2/ && \
+     rm -rf /usr/local/lib/python3.12/dist-packages/cv2/ && \
      pip install opencv-python-headless==4.11.0.86 --trusted-host ${PYPI_TRUSTED_HOST} --index-url ${PYPI_MIRROR}

- RUN pip install "numpy==1.26.4" "optree>=0.13.0" "spacy==3.7.5" "weasel==0.4.1" \
-     transformer-engine[pytorch]==2.2.0 megatron-core==0.13.0 deepspeed==0.16.4 \
-     --trusted-host ${PYPI_TRUSTED_HOST} --index-url ${PYPI_MIRROR}
-
- RUN pip install https://github.com/Dao-AILab/flash-attention/releases/download/v2.7.2.post1/flash_attn-2.7.2.post1+cu12torch2.6cxx11abiFALSE-cp310-cp310-linux_x86_64.whl
-
- # RUN pip install vllm==0.8.4 \
- #     --trusted-host ${PYPI_TRUSTED_HOST} --index-url ${PYPI_MIRROR}
+ RUN pip install --trusted-host mirrors.aliyun.com --index-url ${PYPI_MIRROR} \
+     "megatron-core>=0.13.0,<0.14.0" "deepspeed==0.16.4"

  WORKDIR /build

@@ -37,6 +32,8 @@ RUN pip uninstall -y apex && \
      pip install -v --disable-pip-version-check --no-cache-dir --no-build-isolation \
      --config-settings "--build-option=--cpp_ext --cuda_ext --parallel 32" ${apex_url}

+ RUN pip install --no-build-isolation transformer_engine[pytorch]

  RUN rm -rf /build
  WORKDIR /workspace

@@ -53,14 +50,17 @@ RUN apt-get update && apt-get install -y zip
  RUN apt-get install -y openjdk-21-jdk
  ENV JAVA_HOME /usr/lib/jvm/java-21-openjdk-amd64

- # pip install LLaMA-Factory
+ ARG REPO_URL=https://github.com/hiyouga/LlamaFactory.git
+ ARG BRANCH=main
  WORKDIR /app

- # Copy the application into the image
- COPY . /app
+ # Clone the repository
+ RUN git clone --depth 1 --branch ${BRANCH} ${REPO_URL} /app || \
+     git clone --depth 1 ${REPO_URL} /app

  # Install LLaMA Factory
- RUN pip install --no-cache-dir -e ".[metrics]" --no-build-isolation
+ RUN pip install --no-cache-dir -e . --no-build-isolation && \
+     pip install --no-cache-dir -r requirements/metrics.txt --no-build-isolation

  RUN pip install "git+https://github.com/alibaba/roll.git#subdirectory=mcore_adapter"
@@ -35,7 +35,8 @@ COPY . /app
  # Install torch-npu
  RUN pip uninstall -y torch torchvision torchaudio && \
      pip install --no-cache-dir "torch==2.7.1" "torch-npu==2.7.1" "torchvision==0.22.1" "torchaudio==2.7.1" --index-url "${PYTORCH_INDEX}" && \
-     pip install --no-cache-dir -e ".[metrics]" --no-build-isolation
+     pip install --no-cache-dir -e . --no-build-isolation && \
+     pip install --no-cache-dir -r requirements/metrics.txt --no-build-isolation

  # Set up volumes
  # VOLUME [ "/root/.cache/huggingface", "/app/shared_data", "/app/output" ]
@@ -34,7 +34,8 @@ COPY . /app
  # Reinstall pytorch rocm and install LLaMA Factory
  RUN pip uninstall -y torch torchvision torchaudio && \
-     pip install --no-cache-dir --no-build-isolation -e --pre ".[metrics,deepspeed]" --index-url "${PYTORCH_INDEX}"
+     pip install --no-cache-dir --no-build-isolation -e --pre . --index-url "${PYTORCH_INDEX}" && \
+     pip install --no-cache-dir --no-build-isolation -r requirements/metrics.txt -r requirements/deepspeed.txt --index-url "${PYTORCH_INDEX}"

  # Rebuild flash attention
  RUN if [ "${INSTALL_FLASHATTN}" == "true" ]; then \
examples/extras/eaft/qwen25_05b_eaft_full.yaml (new file, 38 lines)

@@ -0,0 +1,38 @@
### model
model_name_or_path: Qwen/Qwen2.5-0.5B-Instruct
trust_remote_code: true

### method
stage: sft
do_train: true
finetuning_type: full
use_eaft_loss: true

### dataset
dataset: identity,alpaca_en_demo
template: qwen
cutoff_len: 2048
max_samples: 1000
overwrite_cache: true
preprocessing_num_workers: 16
dataloader_num_workers: 4

### output
output_dir: qwen2.5-0_5b/full/sft_eaft
logging_steps: 1
save_steps: 500
plot_loss: true
overwrite_output_dir: true
save_only_model: false
report_to: none # choices: [none, wandb, tensorboard, swanlab, mlflow]

### train
per_device_train_batch_size: 2
gradient_accumulation_steps: 8
learning_rate: 1.0e-5
num_train_epochs: 3.0
lr_scheduler_type: cosine
warmup_ratio: 0.1
bf16: true
ddp_timeout: 180000000
@@ -5,6 +5,6 @@ infer_backend: ktransformers # choices: [huggingface, vllm, sglang, ktransformers]
  trust_remote_code: true

  use_kt: true # use KTransformers as LoRA sft backend to inference
- kt_optimize_rule: examples/kt_optimize_rules/DeepSeek-V2-Lite-Chat-sft-amx.yaml
+ kt_optimize_rule: examples/ktransformers/kt_optimize_rules/DeepSeek-V2-Lite-Chat-sft-amx.yaml
  cpu_infer: 32
  chunk_size: 8192

@@ -1,9 +1,9 @@
  model_name_or_path: opensourcerelease/DeepSeek-V3-bf16
- template: deepseek
+ template: deepseek3
  infer_backend: ktransformers # choices: [huggingface, vllm, sglang, ktransformers]
  trust_remote_code: true

  use_kt: true # use KTransformers as LoRA sft backend to inference
- kt_optimize_rule: examples/kt_optimize_rules/DeepSeek-V3-Chat-sft-amx-multi-gpu.yaml
+ kt_optimize_rule: examples/ktransformers/kt_optimize_rules/DeepSeek-V3-Chat-sft-amx-multi-gpu.yaml
  cpu_infer: 32
  chunk_size: 8192

@@ -1,10 +1,10 @@
  model_name_or_path: opensourcerelease/DeepSeek-V3-bf16
  adapter_name_or_path: saves/Kllama_deepseekV3
- template: deepseek
+ template: deepseek3
  infer_backend: ktransformers # choices: [huggingface, vllm, sglang, ktransformers]
  trust_remote_code: true

  use_kt: true # use KTransformers as LoRA sft backend to inference
- kt_optimize_rule: examples/kt_optimize_rules/DeepSeek-V3-Chat-sft-amx-multi-gpu.yaml
+ kt_optimize_rule: examples/ktransformers/kt_optimize_rules/DeepSeek-V3-Chat-sft-amx-multi-gpu.yaml
  cpu_infer: 32
  chunk_size: 8192

@@ -5,6 +5,6 @@ infer_backend: ktransformers # choices: [huggingface, vllm, sglang, ktransformers]
  trust_remote_code: true

  use_kt: true # use KTransformers as LoRA sft backend to inference
- kt_optimize_rule: examples/kt_optimize_rules/Qwen3Moe-sft-amx.yaml
+ kt_optimize_rule: examples/ktransformers/kt_optimize_rules/Qwen3Moe-sft-amx.yaml
  cpu_infer: 32
  chunk_size: 8192

@@ -10,7 +10,7 @@ lora_rank: 8
  lora_target: all

  ### dataset
- dataset: identity
+ dataset: identity, alpaca_en_demo
  template: deepseek
  cutoff_len: 2048
  max_samples: 100000

@@ -40,7 +40,7 @@ resume_from_checkpoint: null
  ### ktransformers
  use_kt: true # use KTransformers as LoRA sft backend
- kt_optimize_rule: examples/kt_optimize_rules/DeepSeek-V2-Lite-Chat-sft-amx.yaml
+ kt_optimize_rule: examples/ktransformers/kt_optimize_rules/DeepSeek-V2-Lite-Chat-sft-amx.yaml
  cpu_infer: 32
  chunk_size: 8192

@@ -10,8 +10,8 @@ lora_rank: 8
  lora_target: all

  ### dataset
- dataset: identity
- template: deepseek
+ dataset: identity, alpaca_en_demo
+ template: deepseek3
  cutoff_len: 2048
  max_samples: 100000
  overwrite_cache: true

@@ -40,7 +40,7 @@ resume_from_checkpoint: null
  ### ktransformers
  use_kt: true # use KTransformers as LoRA sft backend
- kt_optimize_rule: examples/kt_optimize_rules/DeepSeek-V3-Chat-sft-amx-multi-gpu.yaml
+ kt_optimize_rule: examples/ktransformers/kt_optimize_rules/DeepSeek-V3-Chat-sft-amx-multi-gpu.yaml
  cpu_infer: 32
  chunk_size: 8192

@@ -40,7 +40,7 @@ resume_from_checkpoint: null
  ### ktransformers
  use_kt: true # use KTransformers as LoRA sft backend
- kt_optimize_rule: examples/kt_optimize_rules/Qwen3Moe-sft-amx.yaml
+ kt_optimize_rule: examples/ktransformers/kt_optimize_rules/Qwen3Moe-sft-amx.yaml
  cpu_infer: 32
  chunk_size: 8192
examples/v1/train_full/train_full_fsdp2.yaml (new file, 34 lines)

@@ -0,0 +1,34 @@
model: Qwen/Qwen3-0.6B
trust_remote_code: true
model_class: llm

template: qwen3_nothink

kernel_config:
  name: auto
  include_kernels: auto # choice: null/true/false/auto/kernel_id1,kernel_id2,kernel_id3, default is null

quant_config: null

dist_config:
  name: fsdp2
  dcp_path: null # /mnt/f/pretrain_models/Qwen3-0.6B-dcp

init_config:
  name: init_on_meta

### data
train_dataset: data/v1_sft_demo.yaml

### training
output_dir: outputs/test_fsdp2
micro_batch_size: 1
global_batch_size: 1
cutoff_len: 2048
learning_rate: 1.0e-4
bf16: false
max_steps: 10

### sample
sample_backend: hf
max_new_tokens: 128
@@ -30,7 +30,6 @@ classifiers = [
      "License :: OSI Approved :: Apache Software License",
      "Operating System :: OS Independent",
      "Programming Language :: Python :: 3",
-     "Programming Language :: Python :: 3.10",
      "Programming Language :: Python :: 3.11",
      "Programming Language :: Python :: 3.12",
      "Programming Language :: Python :: 3.13",

@@ -63,7 +62,7 @@ dependencies = [
      "hf-transfer",
      "safetensors",
      # python
-     "av",
+     "av>=10.0.0,<=16.0.0",
      "fire",
      "omegaconf",
      "packaging",

@@ -73,14 +72,9 @@ dependencies = [
      # api
      "uvicorn",
      "fastapi",
-     "sse-starlette"
+     "sse-starlette",
  ]

- [project.optional-dependencies]
- dev = ["pre-commit", "ruff", "pytest", "build"]
- metrics = ["nltk", "jieba", "rouge-chinese"]
- deepspeed = ["deepspeed>=0.10.0,<=0.16.9"]

  [project.scripts]
  llamafactory-cli = "llamafactory.cli:main"
  lmf = "llamafactory.cli:main"
requirements/deepspeed.txt (new file, 1 line)

@@ -0,0 +1 @@
deepspeed>=0.10.0,<=0.16.9

requirements/dev.txt (new file, 4 lines)

@@ -0,0 +1,4 @@
pre-commit
ruff
pytest
build

requirements/metrics.txt (new file, 3 lines)

@@ -0,0 +1,3 @@
nltk
jieba
rouge-chinese

requirements/npu.txt (new file, 4 lines)

@@ -0,0 +1,4 @@
torch==2.7.1
torch-npu==2.7.1
torchvision==0.22.1
torchaudio==2.7.1
scripts/convert_ckpt/tiny_qwen3.py (new file, 32 lines)

@@ -0,0 +1,32 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from transformers import AutoTokenizer, Qwen3Config, Qwen3ForCausalLM


if __name__ == "__main__":
    tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-4B-Instruct-2507")
    config = Qwen3Config(
        hidden_size=1408,
        image_size=336,
        intermediate_size=5632,
        num_attention_heads=16,
        num_hidden_layers=4,
        vision_output_dim=4096,
    )
    model = Qwen3ForCausalLM.from_config(config)
    model.save_pretrained("tiny-qwen3")
    tokenizer.save_pretrained("tiny-qwen3")
    model.push_to_hub("llamafactory/tiny-random-qwen3")
    tokenizer.push_to_hub("llamafactory/tiny-random-qwen3")
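The tiny checkpoint produced by this script is meant for fast tests rather than real inference. A minimal sketch of pulling it back from the Hub, assuming the `llamafactory/tiny-random-qwen3` repository pushed above is publicly readable:

```python
# Minimal sketch: load the tiny random Qwen3 checkpoint for quick unit tests.
# Assumes the "llamafactory/tiny-random-qwen3" repo pushed by the script above
# is accessible; the weights are random, so generations are meaningless.
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("llamafactory/tiny-random-qwen3")
model = AutoModelForCausalLM.from_pretrained("llamafactory/tiny-random-qwen3")

inputs = tokenizer("hello", return_tensors="pt")
outputs = model.generate(**inputs, max_new_tokens=8)  # untrained model, random output
print(tokenizer.decode(outputs[0]))
```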
@@ -28,7 +28,7 @@ try:
      jieba.setLogLevel(logging.CRITICAL)
      jieba.initialize()
  except ImportError:
-     print("Please install llamafactory with `pip install -e .[metrics]`.")
+     print("Please install llamafactory with `pip install -r requirements/metrics.txt`.")
      raise
scripts/hf2dcp.py (new file, 55 lines)

@@ -0,0 +1,55 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Convert a HuggingFace model to DCP checkpoint format.

Usage:
    python scripts/hf2dcp.py convert --hf_path=/path/to/hf --dcp_path=/path/to/dcp

Arguments:
    hf_path: Path to the HuggingFace model directory.
    dcp_path: Output path (directory) for DCP checkpoint.
"""

import fire
import torch
import torch.distributed.checkpoint as dcp
from transformers import AutoModelForCausalLM


def convert(hf_path: str, dcp_path: str) -> None:
    """Convert HF model weights to DCP.

    Args:
        hf_path: HuggingFace model directory.
        dcp_path: Output path (directory) for DCP checkpoint.
    """
    if not hf_path or not dcp_path:
        raise ValueError("Both 'hf_path' and 'dcp_path' are required.")

    print(f"Loading HF model from {hf_path}...")
    model = AutoModelForCausalLM.from_pretrained(hf_path, device_map="cpu", torch_dtype=torch.bfloat16)

    print(f"Saving to DCP format at {dcp_path}...")
    dcp.save(model.state_dict(), checkpoint_id=dcp_path)
    print("Done!")


def help() -> None:
    """Show help message."""
    print(__doc__)


if __name__ == "__main__":
    fire.Fire({"convert": convert, "help": help, "--convert": convert})
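For the reverse direction, a hedged sketch of restoring such a DCP checkpoint into a HuggingFace model; it mirrors the single-process `dcp.save` call above and assumes `torch.distributed.checkpoint.load` accepts the same `checkpoint_id` argument in your PyTorch version. The paths reuse the placeholders from the usage docstring.

```python
# Hedged sketch: load a DCP checkpoint written by scripts/hf2dcp.py back into
# an HF model. dcp.load() fills the provided state_dict in place.
import torch
import torch.distributed.checkpoint as dcp
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("/path/to/hf", device_map="cpu", torch_dtype=torch.bfloat16)
state_dict = model.state_dict()
dcp.load(state_dict, checkpoint_id="/path/to/dcp")  # in-place restore of the tensors
model.load_state_dict(state_dict)
```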
@@ -65,11 +65,17 @@ def merge_dataset(
          if not data_args.streaming:
              logger.warning_rank0_once("We recommend using `mix_strategy=concat` in non-streaming mode.")

+         strategy_map: str = {
+             "interleave_under": "first_exhausted",
+             "interleave_over": "all_exhausted",
+             "interleave_once": "all_exhausted_without_replacement",
+         }[data_args.mix_strategy]
+
          return interleave_datasets(
              datasets=all_datasets,
              probabilities=data_args.interleave_probs,
              seed=seed,
-             stopping_strategy="first_exhausted" if data_args.mix_strategy.endswith("under") else "all_exhausted",
+             stopping_strategy=strategy_map,  # type: ignore
          )

      else:
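A small sketch of what the new `mix_strategy` mapping does with `datasets.interleave_datasets`. The datasets below are toy placeholders, and the `all_exhausted_without_replacement` value for `interleave_once` assumes a `datasets` release that supports it; older releases only accept `first_exhausted` and `all_exhausted`.

```python
# Sketch of the mix_strategy -> stopping_strategy mapping used in merge_dataset().
from datasets import Dataset, interleave_datasets

strategy_map = {
    "interleave_under": "first_exhausted",
    "interleave_over": "all_exhausted",
    "interleave_once": "all_exhausted_without_replacement",  # may require a recent `datasets` version
}

ds_a = Dataset.from_dict({"text": ["a1", "a2", "a3"]})
ds_b = Dataset.from_dict({"text": ["b1"]})

mixed = interleave_datasets(
    [ds_a, ds_b],
    probabilities=[0.5, 0.5],
    seed=42,
    stopping_strategy=strategy_map["interleave_under"],  # stops as soon as ds_b is exhausted
)
print(mixed["text"])
```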
@@ -2092,6 +2092,73 @@ class VideoLlavaPlugin(BasePlugin):
          return messages


+ @dataclass
+ class LFMVLPlugin(BasePlugin):
+     r"""Plugin for LFM2.5-VL vision-language models.
+
+     LFM2.5-VL uses dynamic image token counts based on image resolution.
+     The image processor returns spatial_shapes tensor with [height, width] grid dimensions.
+     Token count per image = (spatial_h * spatial_w) / (downsample_factor^2)
+     """
+
+     @override
+     def _get_mm_inputs(
+         self,
+         images: list["ImageInput"],
+         videos: list["VideoInput"],
+         audios: list["AudioInput"],
+         processor: "MMProcessor",
+     ) -> dict[str, "torch.Tensor"]:
+         image_processor: BaseImageProcessor = getattr(processor, "image_processor", None)
+         mm_inputs = {}
+         if len(images) != 0:
+             images = self._regularize_images(
+                 images,
+                 image_max_pixels=getattr(processor, "image_max_pixels", 768 * 768),
+                 image_min_pixels=getattr(processor, "image_min_pixels", 32 * 32),
+             )["images"]
+             mm_inputs.update(image_processor(images, return_tensors="pt"))
+         return mm_inputs
+
+     @override
+     def process_messages(
+         self,
+         messages: list[dict[str, str]],
+         images: list["ImageInput"],
+         videos: list["VideoInput"],
+         audios: list["AudioInput"],
+         processor: Optional["MMProcessor"],
+     ) -> list[dict[str, str]]:
+         self._validate_input(processor, images, videos, audios)
+         self._validate_messages(messages, images, videos, audios)
+         num_image_tokens = 0
+         messages = deepcopy(messages)
+         image_processor: BaseImageProcessor = getattr(processor, "image_processor")
+         downsample_factor: int = getattr(image_processor, "downsample_factor", 2)
+
+         if self.expand_mm_tokens and len(images) > 0:
+             mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
+             spatial_shapes = mm_inputs.get("spatial_shapes", [])
+         else:
+             spatial_shapes = []
+
+         for message in messages:
+             content = message["content"]
+             while IMAGE_PLACEHOLDER in content:
+                 if self.expand_mm_tokens and len(spatial_shapes) > num_image_tokens:
+                     h, w = spatial_shapes[num_image_tokens].tolist()
+                     image_seqlen = (h * w) // (downsample_factor * downsample_factor)
+                 else:
+                     image_seqlen = 1
+
+                 content = content.replace(IMAGE_PLACEHOLDER, "{{image}}" * image_seqlen, 1)
+                 num_image_tokens += 1
+
+             message["content"] = content.replace("{{image}}", self.image_token)
+
+         return messages
+
+
  PLUGINS = {
      "base": BasePlugin,
      "ernie_vl": ErnieVLPlugin,

@@ -2104,6 +2171,7 @@ PLUGINS = {
      "llava": LlavaPlugin,
      "llava_next": LlavaNextPlugin,
      "llava_next_video": LlavaNextVideoPlugin,
+     "lfm2_vl": LFMVLPlugin,
      "minicpm_v": MiniCPMVPlugin,
      "mllama": MllamaPlugin,
      "paligemma": PaliGemmaPlugin,
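A worked example of the token-count rule stated in the LFMVLPlugin docstring; the spatial shape below is an illustrative value, not one taken from the LFM2.5-VL image processor.

```python
# tokens per image = (spatial_h * spatial_w) // downsample_factor**2
downsample_factor = 2
spatial_h, spatial_w = 24, 32  # hypothetical grid from spatial_shapes
image_seqlen = (spatial_h * spatial_w) // (downsample_factor * downsample_factor)
print(image_seqlen)  # 192 image placeholder tokens for this image
```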
@@ -649,42 +649,6 @@ register_template(
  )


- register_template(
-     name="aquila",
-     format_user=StringFormatter(slots=["Human: {{content}}###Assistant:"]),
-     format_assistant=StringFormatter(slots=["{{content}}###"]),
-     format_system=StringFormatter(slots=["System: {{content}}###"]),
-     default_system=(
-         "A chat between a curious human and an artificial intelligence assistant. "
-         "The assistant gives helpful, detailed, and polite answers to the human's questions."
-     ),
-     stop_words=["</s>"],
- )
-
-
- register_template(
-     name="atom",
-     format_user=StringFormatter(
-         slots=[{"bos_token"}, "Human: {{content}}\n", {"eos_token"}, {"bos_token"}, "Assistant:"]
-     ),
-     format_assistant=StringFormatter(slots=["{{content}}\n", {"eos_token"}]),
- )
-
-
- register_template(
-     name="baichuan",
-     format_user=StringFormatter(slots=[{"token": "<reserved_102>"}, "{{content}}", {"token": "<reserved_103>"}]),
-     efficient_eos=True,
- )
-
-
- register_template(
-     name="baichuan2",
-     format_user=StringFormatter(slots=["<reserved_106>{{content}}<reserved_107>"]),
-     efficient_eos=True,
- )


  register_template(
      name="bailing",
      format_user=StringFormatter(slots=["<role>HUMAN</role>{{content}}<role>ASSISTANT</role>"]),

@@ -712,20 +676,6 @@ register_template(
  )


- register_template(
-     name="belle",
-     format_user=StringFormatter(slots=["Human: {{content}}\n\nBelle: "]),
-     format_assistant=StringFormatter(slots=["{{content}}", {"eos_token"}, "\n\n"]),
-     format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
- )
-
-
- register_template(
-     name="bluelm",
-     format_user=StringFormatter(slots=[{"token": "[|Human|]:"}, "{{content}}", {"token": "[|AI|]:"}]),
- )


  register_template(
      name="breeze",
      format_user=StringFormatter(slots=["[INST] {{content}} [/INST] "]),

@@ -734,14 +684,6 @@ register_template(
  )


- register_template(
-     name="chatglm2",
-     format_user=StringFormatter(slots=["[Round {{idx}}]\n\n问：{{content}}\n\n答："]),
-     format_prefix=EmptyFormatter(slots=[{"token": "[gMASK]"}, {"token": "sop"}]),
-     efficient_eos=True,
- )


  register_template(
      name="chatglm3",
      format_user=StringFormatter(slots=[{"token": "<|user|>"}, "\n", "{{content}}", {"token": "<|assistant|>"}]),

@@ -784,29 +726,6 @@ register_template(
  )


- register_template(
-     name="codegeex2",
-     format_prefix=EmptyFormatter(slots=[{"token": "[gMASK]"}, {"token": "sop"}]),
- )
-
-
- register_template(
-     name="codegeex4",
-     format_user=StringFormatter(slots=["<|user|>\n{{content}}<|assistant|>\n"]),
-     format_system=StringFormatter(slots=["<|system|>\n{{content}}"]),
-     format_function=FunctionFormatter(slots=["{{content}}"], tool_format="glm4"),
-     format_observation=StringFormatter(slots=["<|observation|>\n{{content}}<|assistant|>\n"]),
-     format_tools=ToolFormatter(tool_format="glm4"),
-     format_prefix=EmptyFormatter(slots=["[gMASK]<sop>"]),
-     default_system=(
-         "你是一位智能编程助手，你叫CodeGeeX。你会为用户回答关于编程、代码、计算机方面的任何问题，"
-         "并提供格式规范、可以执行、准确安全的代码，并在必要时提供详细的解释。"
-     ),
-     stop_words=["<|user|>", "<|observation|>"],
-     efficient_eos=True,
- )


  register_template(
      name="cohere",
      format_user=StringFormatter(

@@ -822,25 +741,6 @@ register_template(
  )


- register_template(
-     name="cpm",
-     format_user=StringFormatter(slots=["<用户>{{content}}<AI>"]),
-     format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
- )
-
-
- # copied from chatml template
- register_template(
-     name="cpm3",
-     format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
-     format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
-     format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
-     format_observation=StringFormatter(slots=["<|im_start|>tool\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
-     format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
-     stop_words=["<|im_end|>"],
- )


  # copied from chatml template
  register_template(
      name="cpm4",

@@ -1239,19 +1139,12 @@ register_template(


  register_template(
-     name="intern",
-     format_user=StringFormatter(slots=["<|User|>:{{content}}\n<|Bot|>:"]),
-     format_assistant=StringFormatter(slots=["{{content}}<eoa>\n"]),
-     format_system=StringFormatter(slots=["<|System|>:{{content}}\n"]),
-     format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
-     default_system=(
-         "You are an AI assistant whose name is InternLM (书生·浦语).\n"
-         "- InternLM (书生·浦语) is a conversational language model that is developed by Shanghai AI Laboratory "
-         "(上海人工智能实验室). It is designed to be helpful, honest, and harmless.\n"
-         "- InternLM (书生·浦语) can understand and communicate fluently in the language "
-         "chosen by the user such as English and 中文."
-     ),
-     stop_words=["<eoa>"],
+     name="hunyuan_small",
+     format_user=StringFormatter(slots=["<|hy_User|>{{content}}<|hy_place▁holder▁no▁8|>"]),
+     format_assistant=StringFormatter(slots=["{{content}}<|hy_place▁holder▁no▁2|>"]),
+     format_system=StringFormatter(slots=["{{content}}<|hy_place▁holder▁no▁3|>"]),
+     format_prefix=EmptyFormatter(slots=["<|hy_begin▁of▁sentence|>"]),
+     stop_words=["<|hy_place▁holder▁no▁2|>"],
  )
@@ -1330,6 +1223,47 @@ register_template(
  )


+ register_template(
+     name="lfm2",
+     format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
+     format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
+     format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
+     format_function=FunctionFormatter(slots=["{{content}}<|im_end|>\n"], tool_format="lfm2"),
+     format_observation=StringFormatter(
+         slots=[
+             "<|im_start|>tool\n<|tool_response_start|>{{content}}<|tool_response_end|><|im_end|>\n"
+             "<|im_start|>assistant\n"
+         ]
+     ),
+     format_tools=ToolFormatter(tool_format="lfm2"),
+     default_system="You are a helpful AI assistant.",
+     stop_words=["<|im_end|>"],
+     tool_call_words=("<|tool_call_start|>", "<|tool_call_end|>"),
+     replace_eos=True,
+ )
+
+
+ register_template(
+     name="lfm2_vl",
+     format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
+     format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
+     format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
+     format_function=FunctionFormatter(slots=["{{content}}<|im_end|>\n"], tool_format="lfm2"),
+     format_observation=StringFormatter(
+         slots=[
+             "<|im_start|>tool\n<|tool_response_start|>{{content}}<|tool_response_end|><|im_end|>\n"
+             "<|im_start|>assistant\n"
+         ]
+     ),
+     format_tools=ToolFormatter(tool_format="lfm2"),
+     default_system="You are a helpful multimodal assistant by Liquid AI.",
+     stop_words=["<|im_end|>"],
+     tool_call_words=("<|tool_call_start|>", "<|tool_call_end|>"),
+     replace_eos=True,
+     mm_plugin=get_mm_plugin(name="lfm2_vl", image_token="<image>"),
+ )
+
+
  register_template(
      name="llama2",
      format_user=StringFormatter(slots=[{"bos_token"}, "[INST] {{content}} [/INST]"]),
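To see what the new `lfm2` template produces, a rough sketch that concatenates the literal slot strings from the registration above; it bypasses LLaMA Factory's Formatter classes, so treat it only as an approximation of the rendered chat format, with made-up message contents.

```python
# Approximate render of one turn under the "lfm2" template slots.
system = "You are a helpful AI assistant."
user = "What is LoRA?"
assistant = "LoRA adds trainable low-rank adapters to frozen weights."

prompt = (
    f"<|im_start|>system\n{system}<|im_end|>\n"      # format_system slot
    f"<|im_start|>user\n{user}<|im_end|>\n"          # format_user slot (prefix part)
    f"<|im_start|>assistant\n"                       # format_user slot (generation prompt)
)
target = f"{assistant}<|im_end|>\n"                  # format_assistant slot
print(prompt + target)
```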
@@ -1576,23 +1510,6 @@ register_template(
  )


- # copied from chatml template
- register_template(
-     name="marco",
-     format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
-     format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
-     format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
-     format_observation=StringFormatter(slots=["<|im_start|>tool\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
-     default_system=(
-         "你是一个经过良好训练的AI助手，你的名字是Marco-o1."
-         "由阿里国际数字商业集团的AI Business创造.\n## 重要!!!!!\n"
-         "当你回答问题时，你的思考应该在<Thought>内完成，<Output>内输出你的结果。\n"
-         "<Thought>应该尽可能是英文，但是有2个特例，一个是对原文中的引用，另一个是是数学应该使用markdown格式，<Output>内的输出需要遵循用户输入的语言。\n"
-     ),
-     stop_words=["<|im_end|>"],
- )


  # copied from qwen template
  register_template(
      name="mimo",

@@ -1804,13 +1721,6 @@ register_template(
  )


- register_template(
-     name="orion",
-     format_user=StringFormatter(slots=["Human: {{content}}\n\nAssistant: ", {"eos_token"}]),
-     format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
- )


  register_template(
      name="paligemma",
      format_user=StringFormatter(slots=["{{content}}\n"]),

@@ -1869,6 +1779,17 @@ register_template(
  )


+ register_template(
+     name="phi4_mini",
+     format_user=StringFormatter(slots=["<|user|>{{content}}<|end|><|assistant|>"]),
+     format_assistant=StringFormatter(slots=["{{content}}<|end|>"]),
+     format_system=StringFormatter(slots=["<|system|>{{content}}<|end|>"]),
+     format_tools=StringFormatter(slots=["<|tool|>{{content}}<|/tool|>"]),
+     stop_words=["<|end|>"],
+     replace_eos=True,
+ )
+
+
  # copied from ministral template
  register_template(
      name="pixtral",

@@ -2104,41 +2025,6 @@ register_template(
  )


- # copied from llama3 template
- register_template(
-     name="skywork_o1",
-     format_user=StringFormatter(
-         slots=[
-             (
-                 "<|start_header_id|>user<|end_header_id|>\n\n{{content}}<|eot_id|>"
-                 "<|start_header_id|>assistant<|end_header_id|>\n\n"
-             )
-         ]
-     ),
-     format_assistant=StringFormatter(slots=["{{content}}<|eot_id|>"]),
-     format_system=StringFormatter(slots=["<|start_header_id|>system<|end_header_id|>\n\n{{content}}<|eot_id|>"]),
-     format_function=FunctionFormatter(slots=["{{content}}<|eot_id|>"], tool_format="llama3"),
-     format_observation=StringFormatter(
-         slots=[
-             (
-                 "<|start_header_id|>ipython<|end_header_id|>\n\n{{content}}<|eot_id|>"
-                 "<|start_header_id|>assistant<|end_header_id|>\n\n"
-             )
-         ]
-     ),
-     format_tools=ToolFormatter(tool_format="llama3"),
-     format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
-     default_system=(
-         "You are Skywork-o1, a thinking model developed by Skywork AI, specializing in solving complex problems "
-         "involving mathematics, coding, and logical reasoning through deep thought. When faced with a user's request, "
-         "you first engage in a lengthy and in-depth thinking process to explore possible solutions to the problem. "
-         "After completing your thoughts, you then provide a detailed explanation of the solution process "
-         "in your response."
-     ),
-     stop_words=["<|eot_id|>", "<|eom_id|>"],
- )


  register_template(
      name="smollm",
      format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),

@@ -2175,13 +2061,6 @@ register_template(
  )


- register_template(
-     name="telechat",
-     format_user=StringFormatter(slots=["<_user>{{content}}<_bot>"]),
-     format_system=StringFormatter(slots=["<_system>{{content}}<_end>"]),
- )


  register_template(
      name="telechat2",
      format_user=StringFormatter(slots=["<_user>{{content}}<_bot>"]),

@@ -2225,32 +2104,6 @@ register_template(
  )


- register_template(
-     name="xverse",
-     format_user=StringFormatter(slots=["Human: {{content}}\n\nAssistant: "]),
- )
-
-
- register_template(
-     name="yayi",
-     format_user=StringFormatter(slots=[{"token": "<|Human|>"}, ":\n{{content}}\n\n", {"token": "<|YaYi|>"}, ":"]),
-     format_assistant=StringFormatter(slots=["{{content}}\n\n"]),
-     format_system=StringFormatter(slots=[{"token": "<|System|>"}, ":\n{{content}}\n\n"]),
-     default_system=(
-         "You are a helpful, respectful and honest assistant named YaYi "
-         "developed by Beijing Wenge Technology Co.,Ltd. "
-         "Always answer as helpfully as possible, while being safe. "
-         "Your answers should not include any harmful, unethical, "
-         "racist, sexist, toxic, dangerous, or illegal content. "
-         "Please ensure that your responses are socially unbiased and positive in nature.\n\n"
-         "If a question does not make any sense, or is not factually coherent, "
-         "explain why instead of answering something not correct. "
-         "If you don't know the answer to a question, please don't share false information."
-     ),
-     stop_words=["<|End|>"],
- )


  # copied from chatml template
||||
register_template(
|
||||
name="yi",
|
||||
@@ -2278,6 +2131,21 @@ register_template(
|
||||
)
|
||||
|
||||
|
||||
register_template(
|
||||
name="youtu",
|
||||
format_user=StringFormatter(slots=["<|User|>{{content}}<|Assistant|>"]),
|
||||
format_assistant=StringFormatter(slots=["{{content}}<|end_of_text|>"]),
|
||||
format_system=StringFormatter(slots=["{{content}}"]),
|
||||
format_function=FunctionFormatter(slots=["{{content}}"], tool_format="default"),
|
||||
format_observation=StringFormatter(slots=["<tool_response>\n{{content}}\n</tool_response><|Assistant|>"]),
|
||||
format_tools=ToolFormatter(tool_format="default"),
|
||||
format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
|
||||
stop_words=["<|end_of_text|>"],
|
||||
replace_eos=True,
|
||||
template_class=ReasoningTemplate,
|
||||
)
|
||||
|
||||
|
||||
register_template(
|
||||
name="yuan",
|
||||
format_user=StringFormatter(slots=["{{content}}", {"token": "<sep>"}]),
|
||||
@@ -2292,10 +2160,3 @@ register_template(
|
||||
format_system=StringFormatter(slots=["<|system|>\n{{content}}", {"eos_token"}]),
|
||||
default_system="You are Zephyr, a helpful assistant.",
|
||||
)
|
||||
|
||||
|
||||
register_template(
|
||||
name="ziya",
|
||||
format_user=StringFormatter(slots=["<human>:{{content}}\n<bot>:"]),
|
||||
format_assistant=StringFormatter(slots=["{{content}}\n"]),
|
||||
)
|
||||
|
||||
@@ -12,6 +12,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import ast
|
||||
import json
|
||||
import re
|
||||
from abc import ABC, abstractmethod
|
||||
@@ -101,6 +102,8 @@ LING_TOOL_PROMPT = (
|
||||
""""arguments": <args-json-object>}}\n</tool_call>"""
|
||||
)
|
||||
|
||||
LFM2_TOOL_PROMPT = "List of tools: <|tool_list_start|>{tool_text}<|tool_list_end|>"
|
||||
|
||||
|
||||
@dataclass
|
||||
class ToolUtils(ABC):
|
||||
@@ -546,10 +549,115 @@ class LingToolUtils(QwenToolUtils):
|
||||
return LING_TOOL_PROMPT.format(tool_text=tool_text) + "\n" + "detailed thinking off"
|
||||
|
||||
|
||||
class LFM2ToolUtils(ToolUtils):
    r"""LFM2.5 tool-using template with Pythonic function call syntax."""

    @override
    @staticmethod
    def tool_formatter(tools: list[dict[str, Any]]) -> str:
        tool_list = []
        for tool in tools:
            tool = tool.get("function", tool) if tool.get("type") == "function" else tool
            tool_list.append(tool)

        return LFM2_TOOL_PROMPT.format(tool_text=json.dumps(tool_list, ensure_ascii=False))

    @override
    @staticmethod
    def function_formatter(functions: list["FunctionCall"]) -> str:
        calls = []
        for name, args_json in functions:
            args = json.loads(args_json)
            kwargs_parts = []
            for key, value in args.items():
                if isinstance(value, str):
                    kwargs_parts.append(f'{key}="{value}"')
                else:
                    kwargs_parts.append(f"{key}={json.dumps(value, ensure_ascii=False)}")

            calls.append(f"{name}({', '.join(kwargs_parts)})")

        return f"<|tool_call_start|>[{', '.join(calls)}]<|tool_call_end|>"

    @staticmethod
    def _ast_to_value(node: ast.AST) -> Any:
        """Convert an AST node to a Python value, handling JSON-style booleans/null."""
        # Handle JSON-style true/false/null as Name nodes
        if isinstance(node, ast.Name):
            if node.id == "true":
                return True
            elif node.id == "false":
                return False
            elif node.id == "null":
                return None
            else:
                raise ValueError(f"Unknown identifier: {node.id}")

        # Use literal_eval for other cases (strings, numbers, lists, dicts)
        return ast.literal_eval(node)

    @override
    @staticmethod
    def tool_extractor(content: str) -> Union[str, list["FunctionCall"]]:
        # Extract content between tool call markers
        start_marker = "<|tool_call_start|>"
        end_marker = "<|tool_call_end|>"

        start_idx = content.find(start_marker)
        if start_idx == -1:
            return content

        end_idx = content.find(end_marker, start_idx)
        if end_idx == -1:
            return content

        tool_call_str = content[start_idx + len(start_marker) : end_idx].strip()

        # Parse Pythonic function call syntax using AST
        try:
            tree = ast.parse(tool_call_str, mode="eval")
        except SyntaxError:
            return content

        # Handle both single call and list of calls
        if isinstance(tree.body, ast.List):
            call_nodes = tree.body.elts
        elif isinstance(tree.body, ast.Call):
            call_nodes = [tree.body]
        else:
            return content

        results = []
        for node in call_nodes:
            if not isinstance(node, ast.Call):
                return content

            # Extract function name
            if isinstance(node.func, ast.Name):
                func_name = node.func.id
            else:
                return content

            # Extract keyword arguments
            args_dict = {}
            for keyword in node.keywords:
                key = keyword.arg
                try:
                    value = LFM2ToolUtils._ast_to_value(keyword.value)
                except (ValueError, SyntaxError):
                    return content
                args_dict[key] = value

            results.append(FunctionCall(func_name, json.dumps(args_dict, ensure_ascii=False)))

        return results if results else content
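
# Illustrative round trip with a hypothetical get_weather tool: for a model response such as
#   '<|tool_call_start|>[get_weather(city="Paris", days=3)]<|tool_call_end|>'
# tool_extractor returns [FunctionCall("get_weather", '{"city": "Paris", "days": 3}')], and
# function_formatter on that FunctionCall reproduces the same bracketed call string.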
|
||||
|
||||
|
||||
TOOLS = {
|
||||
"default": DefaultToolUtils(),
|
||||
"glm4": GLM4ToolUtils(),
|
||||
"llama3": Llama3ToolUtils(),
|
||||
"lfm2": LFM2ToolUtils(),
|
||||
"minimax1": MiniMaxM1ToolUtils(),
|
||||
"minimax2": MiniMaxM2ToolUtils(),
|
||||
"mistral": MistralToolUtils(),
|
||||
|
||||
@@ -57,6 +57,7 @@ LLAMABOARD_CONFIG = "llamaboard_config.yaml"
|
||||
|
||||
MCA_SUPPORTED_MODELS = {
|
||||
"deepseek_v3",
|
||||
"glm4_moe",
|
||||
"llama",
|
||||
"mistral",
|
||||
"mixtral",
|
||||
@@ -181,51 +182,6 @@ register_model_group(
|
||||
)
|
||||
|
||||
|
||||
register_model_group(
|
||||
models={
|
||||
"Baichuan-7B-Base": {
|
||||
DownloadSource.DEFAULT: "baichuan-inc/Baichuan-7B",
|
||||
DownloadSource.MODELSCOPE: "baichuan-inc/baichuan-7B",
|
||||
},
|
||||
"Baichuan-13B-Base": {
|
||||
DownloadSource.DEFAULT: "baichuan-inc/Baichuan-13B-Base",
|
||||
DownloadSource.MODELSCOPE: "baichuan-inc/Baichuan-13B-Base",
|
||||
},
|
||||
"Baichuan-13B-Chat": {
|
||||
DownloadSource.DEFAULT: "baichuan-inc/Baichuan-13B-Chat",
|
||||
DownloadSource.MODELSCOPE: "baichuan-inc/Baichuan-13B-Chat",
|
||||
},
|
||||
},
|
||||
template="baichuan",
|
||||
)
|
||||
|
||||
|
||||
register_model_group(
|
||||
models={
|
||||
"Baichuan2-7B-Base": {
|
||||
DownloadSource.DEFAULT: "baichuan-inc/Baichuan2-7B-Base",
|
||||
DownloadSource.MODELSCOPE: "baichuan-inc/Baichuan2-7B-Base",
|
||||
},
|
||||
"Baichuan2-13B-Base": {
|
||||
DownloadSource.DEFAULT: "baichuan-inc/Baichuan2-13B-Base",
|
||||
DownloadSource.MODELSCOPE: "baichuan-inc/Baichuan2-13B-Base",
|
||||
DownloadSource.OPENMIND: "Baichuan/Baichuan2_13b_base_pt",
|
||||
},
|
||||
"Baichuan2-7B-Chat": {
|
||||
DownloadSource.DEFAULT: "baichuan-inc/Baichuan2-7B-Chat",
|
||||
DownloadSource.MODELSCOPE: "baichuan-inc/Baichuan2-7B-Chat",
|
||||
DownloadSource.OPENMIND: "Baichuan/Baichuan2_7b_chat_pt",
|
||||
},
|
||||
"Baichuan2-13B-Chat": {
|
||||
DownloadSource.DEFAULT: "baichuan-inc/Baichuan2-13B-Chat",
|
||||
DownloadSource.MODELSCOPE: "baichuan-inc/Baichuan2-13B-Chat",
|
||||
DownloadSource.OPENMIND: "Baichuan/Baichuan2_13b_chat_pt",
|
||||
},
|
||||
},
|
||||
template="baichuan2",
|
||||
)
|
||||
|
||||
|
||||
register_model_group(
|
||||
models={
|
||||
"BLOOM-560M": {
|
||||
@@ -262,21 +218,6 @@ register_model_group(
|
||||
)
|
||||
|
||||
|
||||
register_model_group(
|
||||
models={
|
||||
"BlueLM-7B-Base": {
|
||||
DownloadSource.DEFAULT: "vivo-ai/BlueLM-7B-Base",
|
||||
DownloadSource.MODELSCOPE: "vivo-ai/BlueLM-7B-Base",
|
||||
},
|
||||
"BlueLM-7B-Chat": {
|
||||
DownloadSource.DEFAULT: "vivo-ai/BlueLM-7B-Chat",
|
||||
DownloadSource.MODELSCOPE: "vivo-ai/BlueLM-7B-Chat",
|
||||
},
|
||||
},
|
||||
template="bluelm",
|
||||
)
|
||||
|
||||
|
||||
register_model_group(
|
||||
models={
|
||||
"Breeze-7B": {
|
||||
@@ -290,17 +231,6 @@ register_model_group(
|
||||
)
|
||||
|
||||
|
||||
register_model_group(
|
||||
models={
|
||||
"ChatGLM2-6B-Chat": {
|
||||
DownloadSource.DEFAULT: "zai-org/chatglm2-6b",
|
||||
DownloadSource.MODELSCOPE: "ZhipuAI/chatglm2-6b",
|
||||
}
|
||||
},
|
||||
template="chatglm2",
|
||||
)
|
||||
|
||||
|
||||
register_model_group(
|
||||
models={
|
||||
"ChatGLM3-6B-Base": {
|
||||
@@ -347,17 +277,6 @@ register_model_group(
|
||||
)
|
||||
|
||||
|
||||
register_model_group(
|
||||
models={
|
||||
"CodeGeeX4-9B-Chat": {
|
||||
DownloadSource.DEFAULT: "zai-org/codegeex4-all-9b",
|
||||
DownloadSource.MODELSCOPE: "ZhipuAI/codegeex4-all-9b",
|
||||
},
|
||||
},
|
||||
template="codegeex4",
|
||||
)
|
||||
|
||||
|
||||
register_model_group(
|
||||
models={
|
||||
"CodeGemma-7B": {
|
||||
@@ -642,15 +561,15 @@ register_model_group(
|
||||
|
||||
register_model_group(
|
||||
models={
|
||||
"ERNIE-4.5-0.3B-PT": {
|
||||
"ERNIE-4.5-0.3B-Instruct": {
|
||||
DownloadSource.DEFAULT: "baidu/ERNIE-4.5-0.3B-PT",
|
||||
DownloadSource.MODELSCOPE: "PaddlePaddle/ERNIE-4.5-0.3B-PT",
|
||||
},
|
||||
"ERNIE-4.5-21B-A3B-PT": {
|
||||
"ERNIE-4.5-21B-A3B-Instruct": {
|
||||
DownloadSource.DEFAULT: "baidu/ERNIE-4.5-21B-A3B-PT",
|
||||
DownloadSource.MODELSCOPE: "PaddlePaddle/ERNIE-4.5-21B-A3B-PT",
|
||||
},
|
||||
"ERNIE-4.5-300B-A47B-PT": {
|
||||
"ERNIE-4.5-300B-A47B-Instruct": {
|
||||
DownloadSource.DEFAULT: "baidu/ERNIE-4.5-300B-A47B-PT",
|
||||
DownloadSource.MODELSCOPE: "PaddlePaddle/ERNIE-4.5-300B-A47B-PT",
|
||||
},
|
||||
@@ -661,7 +580,7 @@ register_model_group(
|
||||
|
||||
register_model_group(
|
||||
models={
|
||||
"ERNIE-4.5-VL-28B-A3B-PT": {
|
||||
"ERNIE-4.5-VL-28B-A3B-Instruct": {
|
||||
DownloadSource.DEFAULT: "baidu/ERNIE-4.5-VL-28B-A3B-PT",
|
||||
DownloadSource.MODELSCOPE: "PaddlePaddle/ERNIE-4.5-VL-28B-A3B-PT",
|
||||
},
|
||||
@@ -669,7 +588,7 @@ register_model_group(
|
||||
DownloadSource.DEFAULT: "baidu/ERNIE-4.5-VL-28B-A3B-Thinking",
|
||||
DownloadSource.MODELSCOPE: "PaddlePaddle/ERNIE-4.5-VL-28B-A3B-Thinking",
|
||||
},
|
||||
"ERNIE-4.5-VL-424B-A47B-Base-PT": {
|
||||
"ERNIE-4.5-VL-424B-A47B-Instruct": {
|
||||
DownloadSource.DEFAULT: "baidu/ERNIE-4.5-VL-424B-A47B-PT",
|
||||
DownloadSource.MODELSCOPE: "PaddlePaddle/ERNIE-4.5-VL-424B-A47B-PT",
|
||||
},
|
||||
@@ -1226,19 +1145,50 @@ register_model_group(
|
||||
|
||||
register_model_group(
|
||||
models={
|
||||
"Hunyuan-0.5B-Instruct": {
|
||||
DownloadSource.DEFAULT: "tencent/Hunyuan-0.5B-Instruct",
|
||||
DownloadSource.MODELSCOPE: "Tencent-Hunyuan/Hunyuan-0.5B-Instruct",
|
||||
},
|
||||
"Hunyuan-1.8B-Instruct": {
|
||||
DownloadSource.DEFAULT: "tencent/Hunyuan-1.8B-Instruct",
|
||||
DownloadSource.MODELSCOPE: "Tencent-Hunyuan/Hunyuan-1.8B-Instruct",
|
||||
},
|
||||
"Hunyuan-4B-Instruct": {
|
||||
DownloadSource.DEFAULT: "tencent/Hunyuan-4B-Instruct",
|
||||
DownloadSource.MODELSCOPE: "Tencent-Hunyuan/Hunyuan-4B-Instruct",
|
||||
},
|
||||
"Hunyuan-7B-Instruct": {
|
||||
DownloadSource.DEFAULT: "tencent/Hunyuan-7B-Instruct",
|
||||
DownloadSource.MODELSCOPE: "AI-ModelScope/Hunyuan-7B-Instruct",
|
||||
DownloadSource.MODELSCOPE: "Tencent-Hunyuan/Hunyuan-7B-Instruct",
|
||||
},
|
||||
"Hunyuan-MT-7B-Instruct": {
|
||||
DownloadSource.DEFAULT: "tencent/Hunyuan-MT-7B",
|
||||
DownloadSource.MODELSCOPE: "Tencent-Hunyuan/Hunyuan-MT-7B",
|
||||
},
|
||||
"HY-MT1.5-7B-Instruct": {
|
||||
DownloadSource.DEFAULT: "tencent/HY-MT1.5-7B",
|
||||
DownloadSource.MODELSCOPE: "Tencent-Hunyuan/HY-MT1.5-7B",
|
||||
},
|
||||
"Hunyuan-A13B-Instruct": {
|
||||
DownloadSource.DEFAULT: "tencent/Hunyuan-A13B-Instruct",
|
||||
DownloadSource.MODELSCOPE: "Tencent-Hunyuan/Hunyuan-A13B-Instruct",
|
||||
},
|
||||
},
|
||||
template="hunyuan",
|
||||
)
|
||||
|
||||
|
||||
register_model_group(
|
||||
models={
|
||||
"HY-MT1.5-1.8B-Instruct": {
|
||||
DownloadSource.DEFAULT: "tencent/HY-MT1.5-1.8B",
|
||||
DownloadSource.MODELSCOPE: "Tencent-Hunyuan/HY-MT1.5-1.8B",
|
||||
},
|
||||
},
|
||||
template="hunyuan_small",
|
||||
)
|
||||
|
||||
|
||||
register_model_group(
|
||||
models={
|
||||
"Index-1.9B-Base": {
|
||||
@@ -1266,29 +1216,6 @@ register_model_group(
|
||||
)
|
||||
|
||||
|
||||
register_model_group(
|
||||
models={
|
||||
"InternLM-7B": {
|
||||
DownloadSource.DEFAULT: "internlm/internlm-7b",
|
||||
DownloadSource.MODELSCOPE: "Shanghai_AI_Laboratory/internlm-7b",
|
||||
},
|
||||
"InternLM-20B": {
|
||||
DownloadSource.DEFAULT: "internlm/internlm-20b",
|
||||
DownloadSource.MODELSCOPE: "Shanghai_AI_Laboratory/internlm-20b",
|
||||
},
|
||||
"InternLM-7B-Chat": {
|
||||
DownloadSource.DEFAULT: "internlm/internlm-chat-7b",
|
||||
DownloadSource.MODELSCOPE: "Shanghai_AI_Laboratory/internlm-chat-7b",
|
||||
},
|
||||
"InternLM-20B-Chat": {
|
||||
DownloadSource.DEFAULT: "internlm/internlm-chat-20b",
|
||||
DownloadSource.MODELSCOPE: "Shanghai_AI_Laboratory/internlm-chat-20b",
|
||||
},
|
||||
},
|
||||
template="intern",
|
||||
)
|
||||
|
||||
|
||||
register_model_group(
|
||||
models={
|
||||
"InternLM2-7B": {
|
||||
@@ -1485,11 +1412,25 @@ register_model_group(
|
||||
|
||||
register_model_group(
|
||||
models={
|
||||
"LingoWhale-8B": {
|
||||
DownloadSource.DEFAULT: "deeplang-ai/LingoWhale-8B",
|
||||
DownloadSource.MODELSCOPE: "DeepLang/LingoWhale-8B",
|
||||
}
|
||||
"LFM2.5-1.2B": {
|
||||
DownloadSource.DEFAULT: "LiquidAI/LFM2.5-1.2B-Base",
|
||||
},
|
||||
"LFM2.5-1.2B-Instruct": {
|
||||
DownloadSource.DEFAULT: "LiquidAI/LFM2.5-1.2B-Instruct",
|
||||
},
|
||||
},
|
||||
template="lfm2",
|
||||
)
|
||||
|
||||
|
||||
register_model_group(
|
||||
models={
|
||||
"LFM2.5-VL-1.6B": {
|
||||
DownloadSource.DEFAULT: "LiquidAI/LFM2.5-VL-1.6B",
|
||||
},
|
||||
},
|
||||
template="lfm2_vl",
|
||||
multimodal=True,
|
||||
)
|
||||
|
||||
|
||||
@@ -1804,17 +1745,6 @@ register_model_group(
|
||||
)
|
||||
|
||||
|
||||
register_model_group(
|
||||
models={
|
||||
"Marco-o1-Chat": {
|
||||
DownloadSource.DEFAULT: "AIDC-AI/Marco-o1",
|
||||
DownloadSource.MODELSCOPE: "AIDC-AI/Marco-o1",
|
||||
},
|
||||
},
|
||||
template="marco",
|
||||
)
|
||||
|
||||
|
||||
register_model_group(
|
||||
models={
|
||||
"MiMo-7B-Base": {
|
||||
@@ -1885,33 +1815,6 @@ register_model_group(
|
||||
)
|
||||
|
||||
|
||||
register_model_group(
|
||||
models={
|
||||
"MiniCPM-2B-SFT-Chat": {
|
||||
DownloadSource.DEFAULT: "openbmb/MiniCPM-2B-sft-bf16",
|
||||
DownloadSource.MODELSCOPE: "OpenBMB/miniCPM-bf16",
|
||||
},
|
||||
"MiniCPM-2B-DPO-Chat": {
|
||||
DownloadSource.DEFAULT: "openbmb/MiniCPM-2B-dpo-bf16",
|
||||
DownloadSource.MODELSCOPE: "OpenBMB/MiniCPM-2B-dpo-bf16",
|
||||
},
|
||||
},
|
||||
template="cpm",
|
||||
)
|
||||
|
||||
|
||||
register_model_group(
|
||||
models={
|
||||
"MiniCPM3-4B-Chat": {
|
||||
DownloadSource.DEFAULT: "openbmb/MiniCPM3-4B",
|
||||
DownloadSource.MODELSCOPE: "OpenBMB/MiniCPM3-4B",
|
||||
DownloadSource.OPENMIND: "LlamaFactory/MiniCPM3-4B",
|
||||
},
|
||||
},
|
||||
template="cpm3",
|
||||
)
|
||||
|
||||
|
||||
register_model_group(
|
||||
models={
|
||||
"MiniCPM4-0.5B-Chat": {
|
||||
@@ -1949,26 +1852,10 @@ register_model_group(
|
||||
DownloadSource.DEFAULT: "openbmb/MiniCPM-V-2_6",
|
||||
DownloadSource.MODELSCOPE: "OpenBMB/MiniCPM-V-2_6",
|
||||
},
|
||||
},
|
||||
template="minicpm_v",
|
||||
multimodal=True,
|
||||
)
|
||||
|
||||
|
||||
register_model_group(
|
||||
models={
|
||||
"MiniCPM-V-4": {
|
||||
DownloadSource.DEFAULT: "openbmb/MiniCPM-V-4",
|
||||
DownloadSource.MODELSCOPE: "OpenBMB/MiniCPM-V-4",
|
||||
},
|
||||
},
|
||||
template="minicpm_v",
|
||||
multimodal=True,
|
||||
)
|
||||
|
||||
|
||||
register_model_group(
|
||||
models={
|
||||
"MiniCPM-V-4.5": {
|
||||
DownloadSource.DEFAULT: "openbmb/MiniCPM-V-4_5",
|
||||
DownloadSource.MODELSCOPE: "OpenBMB/MiniCPM-V-4_5",
|
||||
@@ -2226,33 +2113,6 @@ register_model_group(
|
||||
)
|
||||
|
||||
|
||||
register_model_group(
|
||||
models={
|
||||
"Orion-14B-Base": {
|
||||
DownloadSource.DEFAULT: "OrionStarAI/Orion-14B-Base",
|
||||
DownloadSource.MODELSCOPE: "OrionStarAI/Orion-14B-Base",
|
||||
},
|
||||
"Orion-14B-Chat": {
|
||||
DownloadSource.DEFAULT: "OrionStarAI/Orion-14B-Chat",
|
||||
DownloadSource.MODELSCOPE: "OrionStarAI/Orion-14B-Chat",
|
||||
},
|
||||
"Orion-14B-Long-Chat": {
|
||||
DownloadSource.DEFAULT: "OrionStarAI/Orion-14B-LongChat",
|
||||
DownloadSource.MODELSCOPE: "OrionStarAI/Orion-14B-LongChat",
|
||||
},
|
||||
"Orion-14B-RAG-Chat": {
|
||||
DownloadSource.DEFAULT: "OrionStarAI/Orion-14B-Chat-RAG",
|
||||
DownloadSource.MODELSCOPE: "OrionStarAI/Orion-14B-Chat-RAG",
|
||||
},
|
||||
"Orion-14B-Plugin-Chat": {
|
||||
DownloadSource.DEFAULT: "OrionStarAI/Orion-14B-Chat-Plugin",
|
||||
DownloadSource.MODELSCOPE: "OrionStarAI/Orion-14B-Chat-Plugin",
|
||||
},
|
||||
},
|
||||
template="orion",
|
||||
)
|
||||
|
||||
|
||||
register_model_group(
|
||||
models={
|
||||
"PaliGemma-3B-pt-224": {
|
||||
@@ -2349,20 +2209,6 @@ register_model_group(
|
||||
)
|
||||
|
||||
|
||||
register_model_group(
|
||||
models={
|
||||
"Phi-1.5-1.3B": {
|
||||
DownloadSource.DEFAULT: "microsoft/phi-1_5",
|
||||
DownloadSource.MODELSCOPE: "allspace/PHI_1-5",
|
||||
},
|
||||
"Phi-2-2.7B": {
|
||||
DownloadSource.DEFAULT: "microsoft/phi-2",
|
||||
DownloadSource.MODELSCOPE: "AI-ModelScope/phi-2",
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
register_model_group(
|
||||
models={
|
||||
"Phi-3-4B-4k-Instruct": {
|
||||
@@ -2419,6 +2265,15 @@ register_model_group(
|
||||
template="phi4",
|
||||
)
|
||||
|
||||
register_model_group(
|
||||
models={
|
||||
"Phi-4-3.8B-instruct": {
|
||||
DownloadSource.DEFAULT: "microsoft/Phi-4-mini-instruct",
|
||||
DownloadSource.MODELSCOPE: "LLM-Research/Phi-4-mini-instruct",
|
||||
},
|
||||
},
|
||||
template="phi4_mini",
|
||||
)
|
||||
|
||||
register_model_group(
|
||||
models={
|
||||
@@ -2432,228 +2287,6 @@ register_model_group(
|
||||
)
|
||||
|
||||
|
||||
register_model_group(
|
||||
models={
|
||||
"Qwen-1.8B": {
|
||||
DownloadSource.DEFAULT: "Qwen/Qwen-1_8B",
|
||||
DownloadSource.MODELSCOPE: "Qwen/Qwen-1_8B",
|
||||
},
|
||||
"Qwen-7B": {
|
||||
DownloadSource.DEFAULT: "Qwen/Qwen-7B",
|
||||
DownloadSource.MODELSCOPE: "Qwen/Qwen-7B",
|
||||
},
|
||||
"Qwen-14B": {
|
||||
DownloadSource.DEFAULT: "Qwen/Qwen-14B",
|
||||
DownloadSource.MODELSCOPE: "Qwen/Qwen-14B",
|
||||
},
|
||||
"Qwen-72B": {
|
||||
DownloadSource.DEFAULT: "Qwen/Qwen-72B",
|
||||
DownloadSource.MODELSCOPE: "Qwen/Qwen-72B",
|
||||
},
|
||||
"Qwen-1.8B-Chat": {
|
||||
DownloadSource.DEFAULT: "Qwen/Qwen-1_8B-Chat",
|
||||
DownloadSource.MODELSCOPE: "Qwen/Qwen-1_8B-Chat",
|
||||
},
|
||||
"Qwen-7B-Chat": {
|
||||
DownloadSource.DEFAULT: "Qwen/Qwen-7B-Chat",
|
||||
DownloadSource.MODELSCOPE: "Qwen/Qwen-7B-Chat",
|
||||
},
|
||||
"Qwen-14B-Chat": {
|
||||
DownloadSource.DEFAULT: "Qwen/Qwen-14B-Chat",
|
||||
DownloadSource.MODELSCOPE: "Qwen/Qwen-14B-Chat",
|
||||
},
|
||||
"Qwen-72B-Chat": {
|
||||
DownloadSource.DEFAULT: "Qwen/Qwen-72B-Chat",
|
||||
DownloadSource.MODELSCOPE: "Qwen/Qwen-72B-Chat",
|
||||
},
|
||||
"Qwen-1.8B-Chat-Int8": {
|
||||
DownloadSource.DEFAULT: "Qwen/Qwen-1_8B-Chat-Int8",
|
||||
DownloadSource.MODELSCOPE: "Qwen/Qwen-1_8B-Chat-Int8",
|
||||
},
|
||||
"Qwen-1.8B-Chat-Int4": {
|
||||
DownloadSource.DEFAULT: "Qwen/Qwen-1_8B-Chat-Int4",
|
||||
DownloadSource.MODELSCOPE: "Qwen/Qwen-1_8B-Chat-Int4",
|
||||
},
|
||||
"Qwen-7B-Chat-Int8": {
|
||||
DownloadSource.DEFAULT: "Qwen/Qwen-7B-Chat-Int8",
|
||||
DownloadSource.MODELSCOPE: "Qwen/Qwen-7B-Chat-Int8",
|
||||
},
|
||||
"Qwen-7B-Chat-Int4": {
|
||||
DownloadSource.DEFAULT: "Qwen/Qwen-7B-Chat-Int4",
|
||||
DownloadSource.MODELSCOPE: "Qwen/Qwen-7B-Chat-Int4",
|
||||
},
|
||||
"Qwen-14B-Chat-Int8": {
|
||||
DownloadSource.DEFAULT: "Qwen/Qwen-14B-Chat-Int8",
|
||||
DownloadSource.MODELSCOPE: "Qwen/Qwen-14B-Chat-Int8",
|
||||
},
|
||||
"Qwen-14B-Chat-Int4": {
|
||||
DownloadSource.DEFAULT: "Qwen/Qwen-14B-Chat-Int4",
|
||||
DownloadSource.MODELSCOPE: "Qwen/Qwen-14B-Chat-Int4",
|
||||
},
|
||||
"Qwen-72B-Chat-Int8": {
|
||||
DownloadSource.DEFAULT: "Qwen/Qwen-72B-Chat-Int8",
|
||||
DownloadSource.MODELSCOPE: "Qwen/Qwen-72B-Chat-Int8",
|
||||
},
|
||||
"Qwen-72B-Chat-Int4": {
|
||||
DownloadSource.DEFAULT: "Qwen/Qwen-72B-Chat-Int4",
|
||||
DownloadSource.MODELSCOPE: "Qwen/Qwen-72B-Chat-Int4",
|
||||
},
|
||||
},
|
||||
template="qwen",
|
||||
)
|
||||
|
||||
|
||||
register_model_group(
|
||||
models={
|
||||
"Qwen1.5-0.5B": {
|
||||
DownloadSource.DEFAULT: "Qwen/Qwen1.5-0.5B",
|
||||
DownloadSource.MODELSCOPE: "Qwen/Qwen1.5-0.5B",
|
||||
},
|
||||
"Qwen1.5-1.8B": {
|
||||
DownloadSource.DEFAULT: "Qwen/Qwen1.5-1.8B",
|
||||
DownloadSource.MODELSCOPE: "Qwen/Qwen1.5-1.8B",
|
||||
},
|
||||
"Qwen1.5-4B": {
|
||||
DownloadSource.DEFAULT: "Qwen/Qwen1.5-4B",
|
||||
DownloadSource.MODELSCOPE: "Qwen/Qwen1.5-4B",
|
||||
},
|
||||
"Qwen1.5-7B": {
|
||||
DownloadSource.DEFAULT: "Qwen/Qwen1.5-7B",
|
||||
DownloadSource.MODELSCOPE: "Qwen/Qwen1.5-7B",
|
||||
},
|
||||
"Qwen1.5-14B": {
|
||||
DownloadSource.DEFAULT: "Qwen/Qwen1.5-14B",
|
||||
DownloadSource.MODELSCOPE: "Qwen/Qwen1.5-14B",
|
||||
},
|
||||
"Qwen1.5-32B": {
|
||||
DownloadSource.DEFAULT: "Qwen/Qwen1.5-32B",
|
||||
DownloadSource.MODELSCOPE: "Qwen/Qwen1.5-32B",
|
||||
},
|
||||
"Qwen1.5-72B": {
|
||||
DownloadSource.DEFAULT: "Qwen/Qwen1.5-72B",
|
||||
DownloadSource.MODELSCOPE: "Qwen/Qwen1.5-72B",
|
||||
},
|
||||
"Qwen1.5-110B": {
|
||||
DownloadSource.DEFAULT: "Qwen/Qwen1.5-110B",
|
||||
DownloadSource.MODELSCOPE: "Qwen/Qwen1.5-110B",
|
||||
},
|
||||
"Qwen1.5-MoE-A2.7B": {
|
||||
DownloadSource.DEFAULT: "Qwen/Qwen1.5-MoE-A2.7B",
|
||||
DownloadSource.MODELSCOPE: "Qwen/Qwen1.5-MoE-A2.7B",
|
||||
},
|
||||
"Qwen1.5-0.5B-Chat": {
|
||||
DownloadSource.DEFAULT: "Qwen/Qwen1.5-0.5B-Chat",
|
||||
DownloadSource.MODELSCOPE: "Qwen/Qwen1.5-0.5B-Chat",
|
||||
},
|
||||
"Qwen1.5-1.8B-Chat": {
|
||||
DownloadSource.DEFAULT: "Qwen/Qwen1.5-1.8B-Chat",
|
||||
DownloadSource.MODELSCOPE: "Qwen/Qwen1.5-1.8B-Chat",
|
||||
},
|
||||
"Qwen1.5-4B-Chat": {
|
||||
DownloadSource.DEFAULT: "Qwen/Qwen1.5-4B-Chat",
|
||||
DownloadSource.MODELSCOPE: "Qwen/Qwen1.5-4B-Chat",
|
||||
},
|
||||
"Qwen1.5-7B-Chat": {
|
||||
DownloadSource.DEFAULT: "Qwen/Qwen1.5-7B-Chat",
|
||||
DownloadSource.MODELSCOPE: "Qwen/Qwen1.5-7B-Chat",
|
||||
},
|
||||
"Qwen1.5-14B-Chat": {
|
||||
DownloadSource.DEFAULT: "Qwen/Qwen1.5-14B-Chat",
|
||||
DownloadSource.MODELSCOPE: "Qwen/Qwen1.5-14B-Chat",
|
||||
},
|
||||
"Qwen1.5-32B-Chat": {
|
||||
DownloadSource.DEFAULT: "Qwen/Qwen1.5-32B-Chat",
|
||||
DownloadSource.MODELSCOPE: "Qwen/Qwen1.5-32B-Chat",
|
||||
},
|
||||
"Qwen1.5-72B-Chat": {
|
||||
DownloadSource.DEFAULT: "Qwen/Qwen1.5-72B-Chat",
|
||||
DownloadSource.MODELSCOPE: "Qwen/Qwen1.5-72B-Chat",
|
||||
},
|
||||
"Qwen1.5-110B-Chat": {
|
||||
DownloadSource.DEFAULT: "Qwen/Qwen1.5-110B-Chat",
|
||||
DownloadSource.MODELSCOPE: "Qwen/Qwen1.5-110B-Chat",
|
||||
},
|
||||
"Qwen1.5-MoE-A2.7B-Chat": {
|
||||
DownloadSource.DEFAULT: "Qwen/Qwen1.5-MoE-A2.7B-Chat",
|
||||
DownloadSource.MODELSCOPE: "Qwen/Qwen1.5-MoE-A2.7B-Chat",
|
||||
},
|
||||
"Qwen1.5-0.5B-Chat-GPTQ-Int8": {
|
||||
DownloadSource.DEFAULT: "Qwen/Qwen1.5-0.5B-Chat-GPTQ-Int8",
|
||||
DownloadSource.MODELSCOPE: "Qwen/Qwen1.5-0.5B-Chat-GPTQ-Int8",
|
||||
},
|
||||
"Qwen1.5-0.5B-Chat-AWQ": {
|
||||
DownloadSource.DEFAULT: "Qwen/Qwen1.5-0.5B-Chat-AWQ",
|
||||
DownloadSource.MODELSCOPE: "Qwen/Qwen1.5-0.5B-Chat-AWQ",
|
||||
},
|
||||
"Qwen1.5-1.8B-Chat-GPTQ-Int8": {
|
||||
DownloadSource.DEFAULT: "Qwen/Qwen1.5-1.8B-Chat-GPTQ-Int8",
|
||||
DownloadSource.MODELSCOPE: "Qwen/Qwen1.5-1.8B-Chat-GPTQ-Int8",
|
||||
},
|
||||
"Qwen1.5-1.8B-Chat-AWQ": {
|
||||
DownloadSource.DEFAULT: "Qwen/Qwen1.5-1.8B-Chat-AWQ",
|
||||
DownloadSource.MODELSCOPE: "Qwen/Qwen1.5-1.8B-Chat-AWQ",
|
||||
},
|
||||
"Qwen1.5-4B-Chat-GPTQ-Int8": {
|
||||
DownloadSource.DEFAULT: "Qwen/Qwen1.5-4B-Chat-GPTQ-Int8",
|
||||
DownloadSource.MODELSCOPE: "Qwen/Qwen1.5-4B-Chat-GPTQ-Int8",
|
||||
},
|
||||
"Qwen1.5-4B-Chat-AWQ": {
|
||||
DownloadSource.DEFAULT: "Qwen/Qwen1.5-4B-Chat-AWQ",
|
||||
DownloadSource.MODELSCOPE: "Qwen/Qwen1.5-4B-Chat-AWQ",
|
||||
},
|
||||
"Qwen1.5-7B-Chat-GPTQ-Int8": {
|
||||
DownloadSource.DEFAULT: "Qwen/Qwen1.5-7B-Chat-GPTQ-Int8",
|
||||
DownloadSource.MODELSCOPE: "Qwen/Qwen1.5-7B-Chat-GPTQ-Int8",
|
||||
},
|
||||
"Qwen1.5-7B-Chat-AWQ": {
|
||||
DownloadSource.DEFAULT: "Qwen/Qwen1.5-7B-Chat-AWQ",
|
||||
DownloadSource.MODELSCOPE: "Qwen/Qwen1.5-7B-Chat-AWQ",
|
||||
},
|
||||
"Qwen1.5-14B-Chat-GPTQ-Int8": {
|
||||
DownloadSource.DEFAULT: "Qwen/Qwen1.5-14B-Chat-GPTQ-Int8",
|
||||
DownloadSource.MODELSCOPE: "Qwen/Qwen1.5-14B-Chat-GPTQ-Int8",
|
||||
},
|
||||
"Qwen1.5-14B-Chat-AWQ": {
|
||||
DownloadSource.DEFAULT: "Qwen/Qwen1.5-14B-Chat-AWQ",
|
||||
DownloadSource.MODELSCOPE: "Qwen/Qwen1.5-14B-Chat-AWQ",
|
||||
},
|
||||
"Qwen1.5-32B-Chat-AWQ": {
|
||||
DownloadSource.DEFAULT: "Qwen/Qwen1.5-32B-Chat-AWQ",
|
||||
DownloadSource.MODELSCOPE: "Qwen/Qwen1.5-32B-Chat-AWQ",
|
||||
},
|
||||
"Qwen1.5-72B-Chat-GPTQ-Int8": {
|
||||
DownloadSource.DEFAULT: "Qwen/Qwen1.5-72B-Chat-GPTQ-Int8",
|
||||
DownloadSource.MODELSCOPE: "Qwen/Qwen1.5-72B-Chat-GPTQ-Int8",
|
||||
},
|
||||
"Qwen1.5-72B-Chat-AWQ": {
|
||||
DownloadSource.DEFAULT: "Qwen/Qwen1.5-72B-Chat-AWQ",
|
||||
DownloadSource.MODELSCOPE: "Qwen/Qwen1.5-72B-Chat-AWQ",
|
||||
},
|
||||
"Qwen1.5-110B-Chat-AWQ": {
|
||||
DownloadSource.DEFAULT: "Qwen/Qwen1.5-110B-Chat-AWQ",
|
||||
DownloadSource.MODELSCOPE: "Qwen/Qwen1.5-110B-Chat-AWQ",
|
||||
},
|
||||
"Qwen1.5-MoE-A2.7B-Chat-GPTQ-Int4": {
|
||||
DownloadSource.DEFAULT: "Qwen/Qwen1.5-MoE-A2.7B-Chat-GPTQ-Int4",
|
||||
DownloadSource.MODELSCOPE: "Qwen/Qwen1.5-MoE-A2.7B-Chat-GPTQ-Int4",
|
||||
},
|
||||
"CodeQwen1.5-7B": {
|
||||
DownloadSource.DEFAULT: "Qwen/CodeQwen1.5-7B",
|
||||
DownloadSource.MODELSCOPE: "Qwen/CodeQwen1.5-7B",
|
||||
},
|
||||
"CodeQwen1.5-7B-Chat": {
|
||||
DownloadSource.DEFAULT: "Qwen/CodeQwen1.5-7B-Chat",
|
||||
DownloadSource.MODELSCOPE: "Qwen/CodeQwen1.5-7B-Chat",
|
||||
},
|
||||
"CodeQwen1.5-7B-Chat-AWQ": {
|
||||
DownloadSource.DEFAULT: "Qwen/CodeQwen1.5-7B-Chat-AWQ",
|
||||
DownloadSource.MODELSCOPE: "Qwen/CodeQwen1.5-7B-Chat-AWQ",
|
||||
},
|
||||
},
|
||||
template="qwen",
|
||||
)
|
||||
|
||||
|
||||
register_model_group(
|
||||
models={
|
||||
"Qwen2-0.5B": {
|
||||
@@ -3421,27 +3054,6 @@ register_model_group(
|
||||
)
|
||||
|
||||
|
||||
register_model_group(
|
||||
models={
|
||||
"Skywork-13B-Base": {
|
||||
DownloadSource.DEFAULT: "Skywork/Skywork-13B-base",
|
||||
DownloadSource.MODELSCOPE: "skywork/Skywork-13B-base",
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
register_model_group(
|
||||
models={
|
||||
"Skywork-o1-Open-Llama-3.1-8B": {
|
||||
DownloadSource.DEFAULT: "Skywork/Skywork-o1-Open-Llama-3.1-8B",
|
||||
DownloadSource.MODELSCOPE: "AI-ModelScope/Skywork-o1-Open-Llama-3.1-8B",
|
||||
}
|
||||
},
|
||||
template="skywork_o1",
|
||||
)
|
||||
|
||||
|
||||
register_model_group(
|
||||
models={
|
||||
"SmolLM-135M": {
|
||||
@@ -3536,30 +3148,6 @@ register_model_group(
|
||||
)
|
||||
|
||||
|
||||
register_model_group(
|
||||
models={
|
||||
"TeleChat-1B-Chat": {
|
||||
DownloadSource.DEFAULT: "Tele-AI/TeleChat-1B",
|
||||
DownloadSource.MODELSCOPE: "TeleAI/TeleChat-1B",
|
||||
},
|
||||
"TeleChat-7B-Chat": {
|
||||
DownloadSource.DEFAULT: "Tele-AI/telechat-7B",
|
||||
DownloadSource.MODELSCOPE: "TeleAI/telechat-7B",
|
||||
DownloadSource.OPENMIND: "TeleAI/TeleChat-7B-pt",
|
||||
},
|
||||
"TeleChat-12B-Chat": {
|
||||
DownloadSource.DEFAULT: "Tele-AI/TeleChat-12B-v2",
|
||||
DownloadSource.MODELSCOPE: "TeleAI/TeleChat-12B-v2",
|
||||
DownloadSource.OPENMIND: "TeleAI/TeleChat-12B-pt",
|
||||
},
|
||||
"TeleChat-52B-Chat": {
|
||||
DownloadSource.DEFAULT: "Tele-AI/TeleChat-52B",
|
||||
},
|
||||
},
|
||||
template="telechat",
|
||||
)
|
||||
|
||||
|
||||
register_model_group(
|
||||
models={
|
||||
"TeleChat2-3B-Chat": {
|
||||
@@ -3674,80 +3262,6 @@ register_model_group(
|
||||
)
|
||||
|
||||
|
||||
register_model_group(
|
||||
models={
|
||||
"XVERSE-7B": {
|
||||
DownloadSource.DEFAULT: "xverse/XVERSE-7B",
|
||||
DownloadSource.MODELSCOPE: "xverse/XVERSE-7B",
|
||||
},
|
||||
"XVERSE-13B": {
|
||||
DownloadSource.DEFAULT: "xverse/XVERSE-13B",
|
||||
DownloadSource.MODELSCOPE: "xverse/XVERSE-13B",
|
||||
},
|
||||
"XVERSE-65B": {
|
||||
DownloadSource.DEFAULT: "xverse/XVERSE-65B",
|
||||
DownloadSource.MODELSCOPE: "xverse/XVERSE-65B",
|
||||
},
|
||||
"XVERSE-65B-2": {
|
||||
DownloadSource.DEFAULT: "xverse/XVERSE-65B-2",
|
||||
DownloadSource.MODELSCOPE: "xverse/XVERSE-65B-2",
|
||||
},
|
||||
"XVERSE-7B-Chat": {
|
||||
DownloadSource.DEFAULT: "xverse/XVERSE-7B-Chat",
|
||||
DownloadSource.MODELSCOPE: "xverse/XVERSE-7B-Chat",
|
||||
},
|
||||
"XVERSE-13B-Chat": {
|
||||
DownloadSource.DEFAULT: "xverse/XVERSE-13B-Chat",
|
||||
DownloadSource.MODELSCOPE: "xverse/XVERSE-13B-Chat",
|
||||
},
|
||||
"XVERSE-65B-Chat": {
|
||||
DownloadSource.DEFAULT: "xverse/XVERSE-65B-Chat",
|
||||
DownloadSource.MODELSCOPE: "xverse/XVERSE-65B-Chat",
|
||||
},
|
||||
"XVERSE-MoE-A4.2B": {
|
||||
DownloadSource.DEFAULT: "xverse/XVERSE-MoE-A4.2B",
|
||||
DownloadSource.MODELSCOPE: "xverse/XVERSE-MoE-A4.2B",
|
||||
},
|
||||
"XVERSE-7B-Chat-GPTQ-Int8": {
|
||||
DownloadSource.DEFAULT: "xverse/XVERSE-7B-Chat-GPTQ-Int8",
|
||||
DownloadSource.MODELSCOPE: "xverse/XVERSE-7B-Chat-GPTQ-Int8",
|
||||
},
|
||||
"XVERSE-7B-Chat-GPTQ-Int4": {
|
||||
DownloadSource.DEFAULT: "xverse/XVERSE-7B-Chat-GPTQ-Int4",
|
||||
DownloadSource.MODELSCOPE: "xverse/XVERSE-7B-Chat-GPTQ-Int4",
|
||||
},
|
||||
"XVERSE-13B-Chat-GPTQ-Int8": {
|
||||
DownloadSource.DEFAULT: "xverse/XVERSE-13B-Chat-GPTQ-Int8",
|
||||
DownloadSource.MODELSCOPE: "xverse/XVERSE-13B-Chat-GPTQ-Int8",
|
||||
},
|
||||
"XVERSE-13B-Chat-GPTQ-Int4": {
|
||||
DownloadSource.DEFAULT: "xverse/XVERSE-13B-Chat-GPTQ-Int4",
|
||||
DownloadSource.MODELSCOPE: "xverse/XVERSE-13B-Chat-GPTQ-Int4",
|
||||
},
|
||||
"XVERSE-65B-Chat-GPTQ-Int4": {
|
||||
DownloadSource.DEFAULT: "xverse/XVERSE-65B-Chat-GPTQ-Int4",
|
||||
DownloadSource.MODELSCOPE: "xverse/XVERSE-65B-Chat-GPTQ-Int4",
|
||||
},
|
||||
},
|
||||
template="xverse",
|
||||
)
|
||||
|
||||
|
||||
register_model_group(
|
||||
models={
|
||||
"Yayi-7B": {
|
||||
DownloadSource.DEFAULT: "wenge-research/yayi-7b-llama2",
|
||||
DownloadSource.MODELSCOPE: "AI-ModelScope/yayi-7b-llama2",
|
||||
},
|
||||
"Yayi-13B": {
|
||||
DownloadSource.DEFAULT: "wenge-research/yayi-13b-llama2",
|
||||
DownloadSource.MODELSCOPE: "AI-ModelScope/yayi-13b-llama2",
|
||||
},
|
||||
},
|
||||
template="yayi",
|
||||
)
|
||||
|
||||
|
||||
register_model_group(
|
||||
models={
|
||||
"Yi-6B": {
|
||||
@@ -3846,6 +3360,21 @@ register_model_group(
|
||||
)
|
||||
|
||||
|
||||
register_model_group(
|
||||
models={
|
||||
"Youtu-LLM-2B-Instruct": {
|
||||
DownloadSource.DEFAULT: "tencent/Youtu-LLM-2B",
|
||||
DownloadSource.MODELSCOPE: "Tencent-YouTu-Research/Youtu-LLM-2B",
|
||||
},
|
||||
"Youtu-LLM-2B-Base": {
|
||||
DownloadSource.DEFAULT: "tencent/Youtu-LLM-2B-Base",
|
||||
DownloadSource.MODELSCOPE: "Tencent-YouTu-Research/Youtu-LLM-2B-Base",
|
||||
},
|
||||
},
|
||||
template="youtu",
|
||||
)
|
||||
|
||||
|
||||
register_model_group(
|
||||
models={
|
||||
"Yuan2-2B-Chat": {
|
||||
|
||||
@@ -19,7 +19,7 @@
|
||||
from collections import OrderedDict
|
||||
|
||||
|
||||
VERSION = "0.9.4"
|
||||
VERSION = "0.9.5.dev0"
|
||||
|
||||
|
||||
def print_env() -> None:
|
||||
|
||||
@@ -157,6 +157,33 @@ def get_current_device() -> "torch.device":
|
||||
return torch.device(device)
|
||||
|
||||
|
||||
def get_device_name() -> str:
|
||||
r"""Get the name of available devices."""
|
||||
if is_torch_xpu_available():
|
||||
device = "xpu"
|
||||
elif is_torch_npu_available():
|
||||
device = "npu"
|
||||
elif is_torch_mps_available():
|
||||
device = "mps"
|
||||
elif is_torch_cuda_available():
|
||||
device = "gpu"
|
||||
else:
|
||||
device = "cpu"
|
||||
|
||||
return device
|
||||
|
||||
|
||||
def get_torch_device():
|
||||
r"""Get the torch device namespace for the available devices."""
|
||||
device_name = get_device_name()
|
||||
device_name = "cuda" if device_name == "gpu" else device_name
|
||||
try:
|
||||
return getattr(torch, device_name)
|
||||
except AttributeError:
|
||||
logger.warning_rank0(f"Device namespace '{device_name}' not found in torch, try to load torch.cuda.")
|
||||
return torch.cuda
|
||||
|
||||
|
||||
def get_device_count() -> int:
|
||||
r"""Get the number of available devices."""
|
||||
if is_torch_xpu_available():
|
||||
|
||||
@@ -63,9 +63,9 @@ class DataArguments:
|
||||
default=16384,
|
||||
metadata={"help": "Size of the buffer to randomly sample examples from in dataset streaming."},
|
||||
)
|
||||
mix_strategy: Literal["concat", "interleave_under", "interleave_over"] = field(
|
||||
mix_strategy: Literal["concat", "interleave_under", "interleave_over", "interleave_once"] = field(
|
||||
default="concat",
|
||||
metadata={"help": "Strategy to use in dataset mixing (concat/interleave) (undersampling/oversampling)."},
|
||||
metadata={"help": "Strategy to use in dataset mixing (concat/interleave) (undersampling/oversampling/sampling w.o. replacement)."},
|
||||
)
|
||||
interleave_probs: str | None = field(
|
||||
default=None,
|
||||
|
||||
@@ -490,6 +490,14 @@ class FinetuningArguments(
|
||||
default=False,
|
||||
metadata={"help": "Whether to use the DFT loss."},
|
||||
)
|
||||
use_eaft_loss: bool = field(
|
||||
default=False,
|
||||
metadata={"help": "Whether to use the EAFT loss."},
|
||||
)
|
||||
eaft_alpha: float = field(
|
||||
default=1.0,
|
||||
metadata={"help": "The alpha parameter for EAFT loss to control the power of adaptive weight."},
|
||||
)
|
||||
freeze_vision_tower: bool = field(
|
||||
default=True,
|
||||
metadata={"help": "Whether ot not to freeze the vision tower in MLLM training."},
|
||||
|
||||
@@ -298,23 +298,6 @@ class QuantizationArguments:
|
||||
default=None,
|
||||
metadata={"help": "Device map used to infer the 4-bit quantized model, needs bitsandbytes>=0.43.0."},
|
||||
)
|
||||
fp8: bool = field(
|
||||
default=False,
|
||||
metadata={
|
||||
"help": "Enable FP8 mixed precision training via HuggingFace Accelerate. "
|
||||
"Requires PyTorch 2.7+ and Hopper architecture GPUs."
|
||||
},
|
||||
)
|
||||
fp8_backend: str = field(
|
||||
default="auto",
|
||||
metadata={
|
||||
"help": "FP8 backend to use ('auto', 'torchao', 'te', 'msamp'). 'auto' selects best available backend."
|
||||
},
|
||||
)
|
||||
fp8_enable_fsdp_float8_all_gather: bool = field(
|
||||
default=False,
|
||||
metadata={"help": "Enable FP8 optimizations for FSDP2 all-gather operations."},
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
|
||||
@@ -15,6 +15,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
from pathlib import Path
|
||||
@@ -70,13 +71,13 @@ def read_args(args: dict[str, Any] | list[str] | None = None) -> dict[str, Any]
|
||||
if args is not None:
|
||||
return args
|
||||
|
||||
if sys.argv[1].endswith(".yaml") or sys.argv[1].endswith(".yml"):
|
||||
if len(sys.argv) > 1 and (sys.argv[1].endswith(".yaml") or sys.argv[1].endswith(".yml")):
|
||||
override_config = OmegaConf.from_cli(sys.argv[2:])
|
||||
dict_config = OmegaConf.load(Path(sys.argv[1]).absolute())
|
||||
return OmegaConf.to_container(OmegaConf.merge(dict_config, override_config))
|
||||
elif sys.argv[1].endswith(".json"):
|
||||
elif len(sys.argv) > 1 and sys.argv[1].endswith(".json"):
|
||||
override_config = OmegaConf.from_cli(sys.argv[2:])
|
||||
dict_config = OmegaConf.load(Path(sys.argv[1]).absolute())
|
||||
dict_config = OmegaConf.create(json.loads(Path(sys.argv[1]).absolute().read_text()))
|
||||
return OmegaConf.to_container(OmegaConf.merge(dict_config, override_config))
|
||||
else:
|
||||
return sys.argv[1:]
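# Illustrative invocation (hypothetical file name): `llamafactory-cli train config.yaml learning_rate=1e-5`
# puts the YAML path in sys.argv[1]; OmegaConf loads it and merges the remaining CLI arguments on top
# as dotlist overrides.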
|
||||
@@ -142,14 +143,6 @@ def _verify_model_args(
|
||||
logger.warning_rank0("We should use slow tokenizer for the Yi models. Change `use_fast_tokenizer` to False.")
|
||||
model_args.use_fast_tokenizer = False
|
||||
|
||||
# Validate advanced training features
|
||||
if model_args.fp8 and model_args.quantization_bit is not None:
|
||||
raise ValueError("FP8 training is not compatible with quantization. Please disable one of them.")
|
||||
|
||||
if model_args.fp8_enable_fsdp_float8_all_gather and not model_args.fp8:
|
||||
logger.warning_rank0("fp8_enable_fsdp_float8_all_gather requires fp8=True. Setting fp8=True.")
|
||||
model_args.fp8 = True
|
||||
|
||||
|
||||
def _check_extra_dependencies(
|
||||
model_args: "ModelArguments",
|
||||
@@ -347,6 +340,9 @@ def get_train_args(args: dict[str, Any] | list[str] | None = None) -> _TRAIN_CLS
|
||||
if training_args.deepspeed is not None and (finetuning_args.use_galore or finetuning_args.use_apollo):
|
||||
raise ValueError("GaLore and APOLLO are incompatible with DeepSpeed yet.")
|
||||
|
||||
if not finetuning_args.use_mca and training_args.fp8 and model_args.quantization_bit is not None:
|
||||
raise ValueError("FP8 training is not compatible with quantization. Please disable one of them.")
|
||||
|
||||
if model_args.infer_backend != EngineName.HF:
|
||||
raise ValueError("vLLM/SGLang backend is only available for API, CLI and Web.")
|
||||
|
||||
@@ -363,6 +359,10 @@ def get_train_args(args: dict[str, Any] | list[str] | None = None) -> _TRAIN_CLS
|
||||
_verify_model_args(model_args, data_args, finetuning_args)
|
||||
_check_extra_dependencies(model_args, finetuning_args, training_args)
|
||||
|
||||
if not finetuning_args.use_mca and training_args.fp8_enable_fsdp_float8_all_gather and not training_args.fp8:
|
||||
logger.warning_rank0("fp8_enable_fsdp_float8_all_gather requires fp8=True. Setting fp8=True.")
|
||||
training_args.fp8 = True
|
||||
|
||||
if (
|
||||
training_args.do_train
|
||||
and finetuning_args.finetuning_type == "lora"
|
||||
|
||||
@@ -14,7 +14,6 @@
|
||||
|
||||
import json
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Literal
|
||||
|
||||
from transformers import Seq2SeqTrainingArguments
|
||||
from transformers.training_args import _convert_str_dict
|
||||
@@ -40,59 +39,55 @@ else:
|
||||
class RayArguments:
|
||||
r"""Arguments pertaining to the Ray training."""
|
||||
|
||||
ray_run_name: str | None = field(
|
||||
default=None,
|
||||
metadata={"help": "The training results will be saved at `<ray_storage_path>/ray_run_name`."},
|
||||
)
|
||||
ray_storage_path: str = field(
|
||||
default="./saves",
|
||||
metadata={"help": "The storage path to save training results to"},
|
||||
)
|
||||
ray_storage_filesystem: Literal["s3", "gs", "gcs"] | None = field(
|
||||
default=None,
|
||||
metadata={"help": "The storage filesystem to use. If None specified, local filesystem will be used."},
|
||||
)
|
||||
ray_num_workers: int = field(
|
||||
default=1,
|
||||
metadata={"help": "The number of workers for Ray training. Default is 1 worker."},
|
||||
)
|
||||
resources_per_worker: dict | str = field(
|
||||
default_factory=lambda: {"GPU": 1},
|
||||
metadata={"help": "The resources per worker for Ray training. Default is to use 1 GPU per worker."},
|
||||
)
|
||||
placement_strategy: Literal["SPREAD", "PACK", "STRICT_SPREAD", "STRICT_PACK"] = field(
|
||||
default="PACK",
|
||||
metadata={"help": "The placement strategy for Ray training. Default is PACK."},
|
||||
)
|
||||
ray_init_kwargs: dict | str | None = field(
|
||||
default=None,
|
||||
metadata={"help": "The arguments to pass to ray.init for Ray training. Default is None."},
|
||||
)
|
||||
master_addr: str | None = field(
|
||||
default=None,
|
||||
metadata={"help": "The master address for init_process_group"},
|
||||
)
|
||||
master_port: str | None = field(
|
||||
default=None,
|
||||
metadata={"help": "The master port for init_process_group"},
|
||||
)
|
||||
|
||||
def __post_init__(self):
|
||||
self.use_ray = use_ray()
|
||||
if isinstance(self.resources_per_worker, str) and self.resources_per_worker.startswith("{"):
|
||||
self.resources_per_worker = _convert_str_dict(json.loads(self.resources_per_worker))
|
||||
|
||||
if isinstance(self.ray_init_kwargs, str) and self.ray_init_kwargs.startswith("{"):
|
||||
self.ray_init_kwargs = _convert_str_dict(json.loads(self.ray_init_kwargs))
|
||||
|
||||
if self.ray_storage_filesystem is not None:
|
||||
if self.ray_storage_filesystem not in ["s3", "gs", "gcs"]:
|
||||
raise ValueError(
|
||||
f"ray_storage_filesystem must be one of ['s3', 'gs', 'gcs'], got {self.ray_storage_filesystem}."
|
||||
)
|
||||
|
||||
import pyarrow.fs as fs
|
||||
@dataclass
|
||||
class Fp8Arguments:
|
||||
r"""Arguments pertaining to the FP8 training."""
|
||||
|
||||
if self.ray_storage_filesystem == "s3":
|
||||
self.ray_storage_filesystem = fs.S3FileSystem()
|
||||
elif self.ray_storage_filesystem == "gs" or self.ray_storage_filesystem == "gcs":
|
||||
self.ray_storage_filesystem = fs.GcsFileSystem()
|
||||
fp8: bool = field(
|
||||
default=False,
|
||||
metadata={
|
||||
"help": "Enable FP8 mixed precision training via HuggingFace Accelerate. "
|
||||
"Requires PyTorch 2.7+ and Hopper architecture GPUs."
|
||||
},
|
||||
)
|
||||
fp8_backend: str = field(
|
||||
default="auto",
|
||||
metadata={
|
||||
"help": "FP8 backend to use ('auto', 'torchao', 'te', 'msamp'). 'auto' selects best available backend."
|
||||
},
|
||||
)
|
||||
fp8_enable_fsdp_float8_all_gather: bool = field(
|
||||
default=False,
|
||||
metadata={"help": "Enable FP8 optimizations for FSDP2 all-gather operations."},
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class TrainingArguments(RayArguments, BaseTrainingArguments):
|
||||
class TrainingArguments(Fp8Arguments, RayArguments, BaseTrainingArguments):
|
||||
r"""Arguments pertaining to the trainer."""
|
||||
|
||||
overwrite_output_dir: bool = field(
|
||||
|
||||
@@ -15,6 +15,7 @@
|
||||
import os
|
||||
from typing import TYPE_CHECKING, Any, Optional, TypedDict
|
||||
|
||||
import torch
|
||||
from transformers import (
|
||||
AutoConfig,
|
||||
AutoModelForCausalLM,
|
||||
@@ -29,6 +30,7 @@ from trl import AutoModelForCausalLMWithValueHead
|
||||
|
||||
from ..extras import logging
|
||||
from ..extras.misc import count_parameters, skip_check_imports, try_download_model_from_other_hub
|
||||
from ..extras.packages import is_torch_version_greater_than
|
||||
from .adapter import init_adapter
|
||||
from .model_utils.ktransformers import load_kt_pretrained_model
|
||||
from .model_utils.liger_kernel import apply_liger_kernel
|
||||
@@ -203,6 +205,16 @@ def load_model(
|
||||
model.load_state_dict(vhead_params, strict=False)
|
||||
logger.info_rank0(f"Loaded valuehead from checkpoint: {vhead_path}")
|
||||
|
||||
# Conv3D is not recommended when using torch 2.9.x
|
||||
if is_torch_version_greater_than("2.9.0") and not is_torch_version_greater_than("2.10.0"):
|
||||
if any(isinstance(m, torch.nn.Conv3d) for m in model.modules()):
|
||||
raise ValueError(
|
||||
"Unsupported torch version detected: torch 2.9.x with Conv3D. "
|
||||
"This combination is known to cause severe performance regression. "
|
||||
"Please downgrade torch to <2.9 or remove Conv3D. "
|
||||
"See https://github.com/pytorch/pytorch/issues/166122"
|
||||
)
|
||||
|
||||
if not is_trainable:
|
||||
model.requires_grad_(False)
|
||||
model.eval()
|
||||
@@ -218,7 +230,7 @@ def load_model(
|
||||
)
|
||||
from ..v1.plugins.model_plugins.kernels.interface import apply_default_kernels
|
||||
|
||||
model = apply_default_kernels(model=model, include_kernels=model_args.use_v1_kernels)
|
||||
model = apply_default_kernels(model, include_kernels=model_args.use_v1_kernels)
|
||||
|
||||
trainable_params, all_param = count_parameters(model)
|
||||
if is_trainable:
|
||||
|
||||
@@ -356,7 +356,7 @@ _register_composite_model(
|
||||
_register_composite_model(
|
||||
model_type="qwen3_vl",
|
||||
projector_key="visual.merger",
|
||||
vision_model_keys=["visual.patch_embed", "visual.blocks", "visual.deepstack_merger_list"],
|
||||
vision_model_keys=["visual.pos_embed", "visual.patch_embed", "visual.blocks", "visual.deepstack_merger_list"],
|
||||
language_model_keys=["language_model", "lm_head"],
|
||||
lora_conflict_keys=["patch_embed"],
|
||||
)
|
||||
@@ -365,7 +365,7 @@ _register_composite_model(
|
||||
_register_composite_model(
|
||||
model_type="qwen3_vl_moe",
|
||||
projector_key="visual.merger",
|
||||
vision_model_keys=["visual.patch_embed", "visual.blocks", "visual.deepstack_merger_list"],
|
||||
vision_model_keys=["visual.pos_embed", "visual.patch_embed", "visual.blocks", "visual.deepstack_merger_list"],
|
||||
language_model_keys=["language_model", "lm_head"],
|
||||
lora_conflict_keys=["patch_embed"],
|
||||
)
|
||||
@@ -374,7 +374,7 @@ _register_composite_model(
|
||||
_register_composite_model(
|
||||
model_type="qwen3_omni_moe_thinker",
|
||||
projector_key="visual.merger",
|
||||
vision_model_keys=["visual.patch_embed", "visual.blocks", "visual.deepstack_merger_list", "audio_tower"],
|
||||
vision_model_keys=["visual.pos_embed", "visual.patch_embed", "visual.blocks", "visual.deepstack_merger_list", "audio_tower"],
|
||||
language_model_keys=["model", "lm_head"],
|
||||
lora_conflict_keys=["patch_embed"],
|
||||
)
|
||||
|
||||
@@ -138,18 +138,25 @@ def patch_config(
|
||||
if getattr(config, "model_type", None) == "kimi_vl" and is_trainable:
|
||||
setattr(config.text_config, "topk_method", "greedy")
|
||||
|
||||
if "InternVLChatModel" in getattr(config, "architectures", []):
|
||||
architectures = getattr(config, "architectures", None)
|
||||
if isinstance(architectures, list) and "InternVLChatModel" in architectures:
|
||||
raise ValueError(
|
||||
"Please download the internvl models in a Hugging Face–compatible format "
|
||||
"(for example, https://huggingface.co/OpenGVLab/InternVL3-8B-hf)."
|
||||
)
|
||||
|
||||
if "LlavaLlamaForCausalLM" in getattr(config, "architectures", []):
|
||||
if isinstance(architectures, list) and "LlavaLlamaForCausalLM" in architectures:
|
||||
raise ValueError("Please download llava models with hf-compatible format: https://huggingface.co/llava-hf")
|
||||
|
||||
if getattr(config, "model_type", None) == "internlm3" and not is_transformers_version_greater_than("4.47.1"):
|
||||
raise RuntimeError("InternLM3 model requires transformers>=4.47.1, please upgrade it.")
|
||||
|
||||
if getattr(config, "model_type", None) == "lfm2_vl" and not is_transformers_version_greater_than("4.58.0"):
|
||||
raise RuntimeError(
|
||||
"LFM2.5-VL model requires transformers>=4.58.0 or install from commit: "
|
||||
"pip install git+https://github.com/huggingface/transformers.git@3c2517727ce28a30f5044e01663ee204deb1cdbe"
|
||||
)
|
||||
|
||||
if getattr(config, "model_type", None) == "qwen3_omni_moe":
|
||||
patch_qwen3_omni_moe_thinker_text_sparse_moe_block()
|
||||
|
||||
|
||||
@@ -12,35 +12,45 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import os
|
||||
import types
|
||||
from typing import TYPE_CHECKING, Any, Optional
|
||||
|
||||
from ..extras import logging
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..hparams import ModelArguments
|
||||
from ..hparams import TrainingArguments
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
def create_fp8_kwargs(model_args: "ModelArguments") -> list[Any]:
|
||||
def create_fp8_kwargs(training_args: "TrainingArguments") -> list[Any]:
|
||||
"""Create AORecipeKwargs for FP8 training with HuggingFace Accelerate.
|
||||
|
||||
Args:
|
||||
model_args: Model arguments containing FP8 configuration
|
||||
training_args: Training arguments containing FP8 configuration
|
||||
|
||||
Returns:
|
||||
List containing AORecipeKwargs if FP8 is enabled and supported, empty list otherwise
|
||||
"""
|
||||
if not model_args.fp8:
|
||||
if not training_args.fp8:
|
||||
return []
|
||||
|
||||
try:
|
||||
# Check if AORecipeKwargs is available (Accelerate 1.8.0+)
|
||||
from accelerate.utils import AORecipeKwargs
|
||||
backend = getattr(training_args, "fp8_backend", "auto")
|
||||
logger.info_rank0(f"Creating FP8 configuration with backend: {backend}")
|
||||
|
||||
backend = getattr(model_args, "fp8_backend", "auto")
|
||||
logger.info_rank0(f"Creating FP8 configuration with backend: {backend}")
|
||||
try:
|
||||
# Use Transformer Engine backend (optimal for Hopper GPUs)
|
||||
if backend == "te":
|
||||
from accelerate.utils import FP8RecipeKwargs
|
||||
|
||||
logger.info_rank0("Using Transformer Engine FP8 backend")
|
||||
return [FP8RecipeKwargs(backend="TE", fp8_format="HYBRID", amax_history_len=16, amax_compute_algo="max")]
|
||||
|
||||
# Use TorchAO backend (default)
|
||||
from accelerate.utils import AORecipeKwargs
|
||||
|
||||
# Create Float8LinearConfig if torchao backend is used
|
||||
config = None
|
||||
@@ -83,7 +93,10 @@ def create_fp8_kwargs(model_args: "ModelArguments") -> list[Any]:
|
||||
return True
|
||||
|
||||
# Map FSDP all-gather setting if available (this affects the underlying implementation)
|
||||
if hasattr(model_args, "fp8_enable_fsdp_float8_all_gather") and model_args.fp8_enable_fsdp_float8_all_gather:
|
||||
if (
|
||||
hasattr(training_args, "fp8_enable_fsdp_float8_all_gather")
|
||||
and training_args.fp8_enable_fsdp_float8_all_gather
|
||||
):
|
||||
logger.info_rank0("FSDP float8 all-gather optimization requested")
|
||||
|
||||
return [AORecipeKwargs(config=config, module_filter_func=module_filter_func)]
|
||||
@@ -92,19 +105,19 @@ def create_fp8_kwargs(model_args: "ModelArguments") -> list[Any]:
|
||||
return []
|
||||
|
||||
|
||||
def get_fp8_mixed_precision(model_args: "ModelArguments") -> Optional[str]:
|
||||
def get_fp8_mixed_precision(training_args: "TrainingArguments") -> Optional[str]:
|
||||
"""Get the mixed precision setting for Accelerate when using FP8.
|
||||
|
||||
Args:
|
||||
model_args: Model arguments containing FP8 configuration
|
||||
training_args: Training arguments containing FP8 configuration
|
||||
|
||||
Returns:
|
||||
"fp8" if FP8 is enabled, None otherwise
|
||||
"""
|
||||
return "fp8" if model_args.fp8 else None
|
||||
return "fp8" if training_args.fp8 else None
|
||||
|
||||
|
||||
def configure_fp8_environment(model_args: "ModelArguments") -> None:
|
||||
def configure_fp8_environment(training_args: "TrainingArguments") -> None:
|
||||
"""Configure FP8 environment for HuggingFace Accelerate.
|
||||
|
||||
FP8 training is handled entirely through HuggingFace Accelerate, regardless of whether
|
||||
@@ -112,11 +125,9 @@ def configure_fp8_environment(model_args: "ModelArguments") -> None:
|
||||
variables and validates the FP8 configuration.
|
||||
|
||||
Args:
|
||||
model_args: Model arguments containing FP8 configuration
|
||||
training_args: Training arguments containing FP8 configuration
|
||||
"""
|
||||
import os
|
||||
|
||||
if not model_args.fp8:
|
||||
if not training_args.fp8:
|
||||
return
|
||||
|
||||
# Set mixed precision to fp8 for HuggingFace Accelerate
|
||||
@@ -124,38 +135,38 @@ def configure_fp8_environment(model_args: "ModelArguments") -> None:
|
||||
logger.info_rank0("Set ACCELERATE_MIXED_PRECISION=fp8")
|
||||
|
||||
# Configure FP8 backend and options
|
||||
backend = getattr(model_args, "fp8_backend", "auto")
|
||||
backend = getattr(training_args, "fp8_backend", "auto")
|
||||
if backend != "auto":
|
||||
os.environ["FP8_BACKEND"] = backend
|
||||
logger.info_rank0(f"Set FP8_BACKEND={backend}")
|
||||
|
||||
# Create and validate FP8 recipe kwargs (for logging/debugging)
|
||||
fp8_kwargs = create_fp8_kwargs(model_args)
|
||||
fp8_kwargs = create_fp8_kwargs(training_args)
|
||||
logger.info_rank0(f"FP8 AORecipeKwargs created: {len(fp8_kwargs)} items")
|
||||
|
||||
# Enable FSDP float8 all-gather optimization if requested
|
||||
if hasattr(model_args, "fp8_enable_fsdp_float8_all_gather") and model_args.fp8_enable_fsdp_float8_all_gather:
|
||||
if hasattr(training_args, "fp8_enable_fsdp_float8_all_gather") and training_args.fp8_enable_fsdp_float8_all_gather:
|
||||
os.environ["FP8_ENABLE_FSDP_FLOAT8_ALL_GATHER"] = "true"
|
||||
logger.info_rank0("Set FP8_ENABLE_FSDP_FLOAT8_ALL_GATHER=true")
|
||||
|
||||
logger.info_rank0("FP8 environment configured - all FP8 training handled by HuggingFace Accelerate")
|
||||
|
||||
|
||||
def verify_fp8_status(accelerator, model_args: "ModelArguments") -> None:
|
||||
def verify_fp8_status(accelerator, training_args: "TrainingArguments") -> None:
|
||||
"""Verify that FP8 training is actually working after model preparation.
|
||||
|
||||
Args:
|
||||
accelerator: The HuggingFace Accelerator instance
|
||||
model_args: Model arguments containing FP8 configuration
|
||||
training_args: Training arguments containing FP8 configuration
|
||||
"""
|
||||
if not model_args.fp8:
|
||||
if not training_args.fp8:
|
||||
return
|
||||
|
||||
# Check Accelerate's FP8 status
|
||||
fp8_enabled = getattr(accelerator, "fp8_enabled", False)
|
||||
fp8_backend_type = getattr(accelerator, "fp8_backend", "UNKNOWN")
|
||||
|
||||
backend = getattr(model_args, "fp8_backend", "auto")
|
||||
backend = getattr(training_args, "fp8_backend", "auto")
|
||||
if backend == "torchao" or backend == "auto":
|
||||
logger.info_rank0(
|
||||
"FP8 training enabled with TorchAO backend. For optimal performance, "
|
||||
@@ -169,3 +180,50 @@ def verify_fp8_status(accelerator, model_args: "ModelArguments") -> None:
|
||||
|
||||
if not fp8_enabled:
|
||||
logger.info_rank0("WARNING: FP8 was requested but Accelerate shows fp8_enabled=False. FP8 may not be working.")
|
||||
|
||||
|
||||
def patch_accelerator_for_fp8() -> None:
|
||||
"""Patch Accelerator to inject FP8 recipe kwargs.
|
||||
|
||||
This is needed because HuggingFace Trainer doesn't pass kwargs_handlers to Accelerator.
|
||||
We monkey-patch Accelerator.__init__ to inject the FP8 recipe and force mixed_precision='fp8'.
|
||||
"""
|
||||
import transformer_engine.pytorch as te
|
||||
from accelerate import Accelerator
|
||||
|
||||
# Guard against multiple patches
|
||||
if getattr(Accelerator, "_te_fp8_patched", False):
|
||||
return
|
||||
|
||||
# Stub for Accelerate 1.12+ compatibility (te.fp8.check_mxfp8_support doesn't exist yet)
|
||||
if not hasattr(te, "fp8"):
|
||||
te.fp8 = types.ModuleType("fp8")
|
||||
te.fp8.check_mxfp8_support = lambda: (False, "MXFP8 not supported")
|
||||
|
||||
try:
|
||||
from accelerate.utils import TERecipeKwargs as FP8Recipe
|
||||
|
||||
use_te_recipe = True
|
||||
except ImportError:
|
||||
from accelerate.utils import FP8RecipeKwargs as FP8Recipe
|
||||
|
||||
use_te_recipe = False
|
||||
|
||||
original_init = Accelerator.__init__
|
||||
|
||||
def patched_init(self, *args, **kwargs):
|
||||
if "kwargs_handlers" not in kwargs or not kwargs["kwargs_handlers"]:
|
||||
if use_te_recipe:
|
||||
kwargs["kwargs_handlers"] = [
|
||||
FP8Recipe(fp8_format="HYBRID", amax_history_len=16, amax_compute_algo="max")
|
||||
]
|
||||
else:
|
||||
kwargs["kwargs_handlers"] = [
|
||||
FP8Recipe(backend="TE", fp8_format="HYBRID", amax_history_len=16, amax_compute_algo="max")
|
||||
]
|
||||
# Only force mixed_precision when we inject handlers
|
||||
kwargs["mixed_precision"] = "fp8"
|
||||
return original_init(self, *args, **kwargs)
|
||||
|
||||
Accelerator.__init__ = patched_init
|
||||
Accelerator._te_fp8_patched = True
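A small sketch of how these helpers are meant to be sequenced (illustrative only; the module path is assumed from the relative imports later in this diff, and a CUDA build with transformer-engine/torchao is required for the calls to do real work):

from llamafactory.train.fp8_utils import (  # assumed module path
    configure_fp8_environment,
    patch_accelerator_for_fp8,
)


def setup_fp8(training_args):
    """Configure FP8 before the HF Trainer builds its Accelerator."""
    if not getattr(training_args, "fp8", False):
        return
    configure_fp8_environment(training_args)  # sets ACCELERATE_MIXED_PRECISION=fp8 and backend env vars
    if getattr(training_args, "fp8_backend", "auto") == "te":
        patch_accelerator_for_fp8()  # inject the TE recipe, since Trainer does not pass kwargs_handlers
    # after the Trainer is constructed, verify_fp8_status(trainer.accelerator, training_args) reports the result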
|
||||
|
||||
@@ -19,16 +19,15 @@ import torch
|
||||
from transformers import Trainer
|
||||
from typing_extensions import override
|
||||
|
||||
from ...extras.packages import is_transformers_version_greater_than
|
||||
from ..callbacks import SaveProcessorCallback
|
||||
from ..fp8_utils import configure_fp8_environment, verify_fp8_status
|
||||
from ..fp8_utils import configure_fp8_environment, patch_accelerator_for_fp8, verify_fp8_status
|
||||
from ..trainer_utils import create_custom_optimizer, create_custom_scheduler
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from transformers import ProcessorMixin
|
||||
|
||||
from ...hparams import FinetuningArguments, ModelArguments
|
||||
from ...hparams import FinetuningArguments, ModelArguments, TrainingArguments
|
||||
|
||||
|
||||
class CustomTrainer(Trainer):
|
||||
@@ -41,11 +40,13 @@ class CustomTrainer(Trainer):
|
||||
model_args: Optional["ModelArguments"] = None,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
kwargs["processing_class"] = kwargs.pop("tokenizer")
|
||||
# Configure FP8 environment if enabled
|
||||
if model_args is not None and model_args.fp8:
|
||||
configure_fp8_environment(model_args)
|
||||
if is_transformers_version_greater_than("4.46"):
|
||||
kwargs["processing_class"] = kwargs.pop("tokenizer")
|
||||
training_args: TrainingArguments = kwargs.get("args")
|
||||
if training_args.fp8:
|
||||
configure_fp8_environment(training_args)
|
||||
if getattr(training_args, "fp8_backend", "auto") == "te":
|
||||
patch_accelerator_for_fp8()
|
||||
|
||||
super().__init__(**kwargs)
|
||||
if processor is not None:
|
||||
@@ -64,9 +65,8 @@ class CustomTrainer(Trainer):
|
||||
self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, self.accelerator)
|
||||
self.add_callback(BAdamCallback)
|
||||
|
||||
# Verify FP8 status after trainer initialization (accelerator should be available)
|
||||
if model_args is not None and model_args.fp8 and hasattr(self, "accelerator"):
|
||||
verify_fp8_status(self.accelerator, model_args)
|
||||
if training_args.fp8 and hasattr(self, "accelerator"): # verify FP8 status after trainer initialization
|
||||
verify_fp8_status(self.accelerator, training_args)
|
||||
|
||||
@override
|
||||
def create_optimizer(self) -> "torch.optim.Optimizer":
|
||||
|
||||
@@ -109,6 +109,27 @@ class PairwiseTrainer(Trainer):
|
||||
else:
|
||||
return loss
|
||||
|
||||
@override
|
||||
def _save(self, output_dir: Optional[str] = None, state_dict=None):
|
||||
if state_dict is None:
|
||||
state_dict = self.model.state_dict()
|
||||
|
||||
if self.args.save_safetensors:
|
||||
from collections import defaultdict
|
||||
|
||||
ptrs = defaultdict(list)
|
||||
for name, tensor in state_dict.items():
|
||||
if isinstance(tensor, torch.Tensor):
|
||||
ptrs[id(tensor)].append(name)
|
||||
|
||||
for names in ptrs.values():
|
||||
if len(names) > 1:
|
||||
names.sort()
|
||||
for name in names[1:]:
|
||||
state_dict.pop(name, None)
|
||||
|
||||
super()._save(output_dir, state_dict)
|
||||
|
||||
def save_predictions(self, predict_results: "PredictionOutput") -> None:
|
||||
r"""Save model predictions to `output_dir`.
|
||||
|
||||
|
||||
@@ -27,18 +27,17 @@ from typing_extensions import override
|
||||
|
||||
from ...extras import logging
|
||||
from ...extras.constants import IGNORE_INDEX
|
||||
from ...extras.packages import is_transformers_version_greater_than
|
||||
from ..callbacks import SaveProcessorCallback
|
||||
from ..fp8_utils import configure_fp8_environment, verify_fp8_status
|
||||
from ..fp8_utils import configure_fp8_environment, patch_accelerator_for_fp8, verify_fp8_status
|
||||
from ..trainer_utils import create_custom_optimizer, create_custom_scheduler
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from torch.utils.data import Dataset
|
||||
from transformers import PreTrainedTokenizer, ProcessorMixin
|
||||
from transformers import ProcessorMixin
|
||||
from transformers.trainer import PredictionOutput
|
||||
|
||||
from ...hparams import FinetuningArguments, ModelArguments
|
||||
from ...hparams import FinetuningArguments, ModelArguments, TrainingArguments
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
@@ -55,13 +54,13 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer):
|
||||
gen_kwargs: Optional[dict[str, Any]] = None,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
kwargs["processing_class"] = kwargs.pop("tokenizer")
|
||||
# Configure FP8 environment if enabled
|
||||
if model_args is not None and model_args.fp8:
|
||||
configure_fp8_environment(model_args)
|
||||
if is_transformers_version_greater_than("4.46"):
|
||||
kwargs["processing_class"] = kwargs.pop("tokenizer")
|
||||
else:
|
||||
self.processing_class: PreTrainedTokenizer = kwargs.get("tokenizer")
|
||||
training_args: TrainingArguments = kwargs.get("args")
|
||||
if training_args.fp8:
|
||||
configure_fp8_environment(training_args)
|
||||
if getattr(training_args, "fp8_backend", "auto") == "te":
|
||||
patch_accelerator_for_fp8()
|
||||
|
||||
super().__init__(**kwargs)
|
||||
if processor is not None:
|
||||
@@ -88,9 +87,15 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer):
|
||||
|
||||
self.compute_loss_func = dft_loss_func
|
||||
|
||||
# Verify FP8 status after trainer initialization (accelerator should be available)
|
||||
if model_args is not None and model_args.fp8 and hasattr(self, "accelerator"):
|
||||
verify_fp8_status(self.accelerator, model_args)
|
||||
elif finetuning_args.use_eaft_loss:
|
||||
from ..trainer_utils import eaft_loss_func
|
||||
|
||||
self.compute_loss_func = lambda outputs, labels, num_items_in_batch=None: eaft_loss_func(
|
||||
outputs, labels, num_items_in_batch, finetuning_args.eaft_alpha
|
||||
)
|
||||
|
||||
if training_args.fp8 and hasattr(self, "accelerator"): # verify FP8 status after trainer initialization
|
||||
verify_fp8_status(self.accelerator, training_args)
|
||||
|
||||
@override
|
||||
def create_optimizer(self) -> "torch.optim.Optimizer":
|
||||
|
||||
@@ -20,7 +20,6 @@
|
||||
import json
|
||||
import os
|
||||
from collections.abc import Callable, Mapping
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Any, Optional, Union
|
||||
|
||||
import torch
|
||||
@@ -34,6 +33,7 @@ from typing_extensions import override
|
||||
|
||||
from ..extras import logging
|
||||
from ..extras.constants import IGNORE_INDEX, SWANLAB_CONFIG
|
||||
from ..extras.misc import get_device_name
|
||||
from ..extras.packages import is_apollo_available, is_galore_available, is_ray_available
|
||||
from ..hparams import FinetuningArguments, ModelArguments
|
||||
from ..model import find_all_linear_modules, load_model, load_tokenizer, load_valuehead_params
|
||||
@@ -49,15 +49,15 @@ if is_apollo_available():
|
||||
|
||||
if is_ray_available():
|
||||
import ray
|
||||
from ray.train import RunConfig, ScalingConfig
|
||||
from ray.train.torch import TorchTrainer
|
||||
from ray.util.placement_group import PlacementGroup, placement_group
|
||||
from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from transformers import PreTrainedModel, TrainerCallback, TrainerState
|
||||
from trl import AutoModelForCausalLMWithValueHead
|
||||
|
||||
from ..hparams import DataArguments, RayArguments, TrainingArguments
|
||||
from ..hparams import DataArguments, TrainingArguments
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
@@ -634,7 +634,9 @@ def get_batch_logps(
|
||||
return logps, valid_length
|
||||
|
||||
|
||||
def dft_loss_func(outputs, labels, num_items_in_batch=None):
|
||||
def dft_loss_func(
|
||||
outputs: "torch.Tensor", labels: "torch.Tensor", num_items_in_batch: Optional["torch.Tensor"] = None
|
||||
):
|
||||
logits = outputs.get("logits")
|
||||
if logits is None:
|
||||
return outputs.get("loss", torch.tensor(0.0))
|
||||
@@ -652,11 +654,11 @@ def dft_loss_func(outputs, labels, num_items_in_batch=None):
|
||||
|
||||
|
||||
def _dft_cross_entropy(
|
||||
source: torch.Tensor,
|
||||
target: torch.Tensor,
|
||||
num_items_in_batch: Optional[torch.Tensor] = None,
|
||||
source: "torch.Tensor",
|
||||
target: "torch.Tensor",
|
||||
num_items_in_batch: Optional["torch.Tensor"] = None,
|
||||
ignore_index: int = -100,
|
||||
) -> torch.Tensor:
|
||||
) -> "torch.Tensor":
|
||||
per_token_loss = torch.nn.functional.cross_entropy(source, target, ignore_index=ignore_index, reduction="none")
|
||||
valid_mask = target != ignore_index
|
||||
if not valid_mask.any():
|
||||
@@ -679,6 +681,67 @@ def _dft_cross_entropy(
|
||||
return loss
|
||||
|
||||
|
||||
def eaft_loss_func(
|
||||
outputs: "torch.Tensor",
|
||||
labels: "torch.Tensor",
|
||||
num_items_in_batch: Optional["torch.Tensor"] = None,
|
||||
alpha: float = 1.0,
|
||||
) -> "torch.Tensor":
|
||||
logits = outputs.get("logits")
|
||||
if logits is None:
|
||||
return outputs.get("loss", torch.tensor(0.0))
|
||||
|
||||
logits = logits.float()
|
||||
vocab_size = logits.size(-1)
|
||||
labels = torch.nn.functional.pad(labels, (0, 1), value=-100)
|
||||
shift_labels = labels[..., 1:].contiguous()
|
||||
logits = logits.view(-1, vocab_size)
|
||||
shift_labels = shift_labels.view(-1)
|
||||
shift_labels = shift_labels.to(logits.device)
|
||||
|
||||
loss = _eaft_cross_entropy(logits, shift_labels, num_items_in_batch, alpha)
|
||||
return loss
|
||||
|
||||
|
||||
def _eaft_cross_entropy(
|
||||
source: "torch.Tensor",
|
||||
target: "torch.Tensor",
|
||||
num_items_in_batch: Optional["torch.Tensor"] = None,
|
||||
alpha: float = 1.0,
|
||||
ignore_index: int = -100,
|
||||
) -> "torch.Tensor":
|
||||
per_token_loss = torch.nn.functional.cross_entropy(source, target, ignore_index=ignore_index, reduction="none")
|
||||
valid_mask = target != ignore_index
|
||||
if not valid_mask.any():
|
||||
return torch.tensor(0.0, device=source.device, dtype=source.dtype)
|
||||
|
||||
valid_losses = per_token_loss[valid_mask]
|
||||
|
||||
with torch.no_grad():
|
||||
source_detached = source[valid_mask].detach()
|
||||
|
||||
topk_val, _ = torch.topk(source_detached, k=20, dim=-1)
|
||||
logsumexp_topk = torch.logsumexp(topk_val, dim=-1, keepdim=True)
|
||||
log_probs_topk = topk_val - logsumexp_topk
|
||||
probs_topk = torch.exp(log_probs_topk)
|
||||
entropy_approx = -(probs_topk * log_probs_topk).sum(dim=-1)
|
||||
|
||||
entropy_term = entropy_approx / 3.0
|
||||
adaptive_weight = torch.pow(entropy_term, alpha)
|
||||
|
||||
weighted_losses = valid_losses * adaptive_weight
|
||||
|
||||
if num_items_in_batch is not None:
|
||||
total_loss = weighted_losses.sum()
|
||||
if torch.is_tensor(num_items_in_batch):
|
||||
num_items_in_batch = num_items_in_batch.to(total_loss.device)
|
||||
loss = total_loss / num_items_in_batch
|
||||
else:
|
||||
loss = weighted_losses.mean()
|
||||
|
||||
return loss
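A self-contained numeric check of the entropy weighting used above (the shapes and vocabulary size are arbitrary; only the top-k entropy weighting mirrors _eaft_cross_entropy):

import torch

torch.manual_seed(0)
logits = torch.randn(6, 32)  # 6 valid token positions, vocabulary of 32
targets = torch.randint(0, 32, (6,))
alpha = 1.0

per_token_loss = torch.nn.functional.cross_entropy(logits, targets, reduction="none")
topk_val, _ = torch.topk(logits, k=20, dim=-1)
log_probs_topk = topk_val - torch.logsumexp(topk_val, dim=-1, keepdim=True)
entropy_approx = -(log_probs_topk.exp() * log_probs_topk).sum(dim=-1)
adaptive_weight = (entropy_approx / 3.0) ** alpha  # confident (low-entropy) tokens get smaller weights
print((per_token_loss * adaptive_weight).mean())  # mean-reduced EAFT-style loss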
|
||||
|
||||
|
||||
def nested_detach(
|
||||
tensors: Union["torch.Tensor", list["torch.Tensor"], tuple["torch.Tensor"], dict[str, "torch.Tensor"]],
|
||||
clone: bool = False,
|
||||
@@ -744,36 +807,88 @@ def get_swanlab_callback(finetuning_args: "FinetuningArguments") -> "TrainerCall
|
||||
return swanlab_callback
|
||||
|
||||
|
||||
def get_ray_trainer(
|
||||
training_function: Callable,
|
||||
train_loop_config: dict[str, Any],
|
||||
ray_args: "RayArguments",
|
||||
) -> "TorchTrainer":
|
||||
if not ray_args.use_ray:
|
||||
raise ValueError("Ray was not enabled. Please set `USE_RAY=1` to enable ray.")
|
||||
def get_placement_group(num_workers: int) -> tuple["PlacementGroup", dict[str, int]]:
|
||||
r"""Get the Ray placement group for distributed training."""
|
||||
bundle = {"CPU": 10}
|
||||
device_name = get_device_name().upper()
|
||||
if device_name != "CPU":
|
||||
bundle[device_name] = 1
|
||||
bundles = [bundle for _ in range(num_workers)]
|
||||
pg = placement_group(bundles, strategy="PACK")
|
||||
|
||||
if ray_args.ray_init_kwargs is not None:
|
||||
ray.init(**ray_args.ray_init_kwargs)
|
||||
return pg, bundle
|
||||
|
||||
if ray_args.ray_storage_filesystem is not None:
|
||||
# this means we are using s3/gcs
|
||||
storage_path = ray_args.ray_storage_path
|
||||
else:
|
||||
storage_path = Path(ray_args.ray_storage_path).absolute().as_posix()
|
||||
|
||||
trainer = TorchTrainer(
|
||||
training_function,
|
||||
train_loop_config=train_loop_config,
|
||||
scaling_config=ScalingConfig(
|
||||
num_workers=ray_args.ray_num_workers,
|
||||
resources_per_worker=ray_args.resources_per_worker,
|
||||
placement_strategy=ray_args.placement_strategy,
|
||||
use_gpu=True,
|
||||
def get_ray_remote_config_for_worker(
|
||||
placement_group: "PlacementGroup",
|
||||
bundle_idx: int,
|
||||
rank: int,
|
||||
world_size: int,
|
||||
master_addr: str,
|
||||
master_port: str,
|
||||
env: dict[str, str] = None,
|
||||
) -> dict[str, Any]:
|
||||
r"""Get the remote config for a Ray worker."""
|
||||
env_vars = {
|
||||
"RANK": str(rank),
|
||||
"WORLD_SIZE": str(world_size),
|
||||
"MASTER_ADDR": master_addr,
|
||||
"MASTER_PORT": master_port,
|
||||
"TORCHELASTIC_USE_AGENT_STORE": "False",
|
||||
}
|
||||
env.update(env_vars)
|
||||
|
||||
remote_config = {
|
||||
"scheduling_strategy": PlacementGroupSchedulingStrategy(
|
||||
placement_group=placement_group,
|
||||
placement_group_bundle_index=bundle_idx,
|
||||
),
|
||||
run_config=RunConfig(
|
||||
name=ray_args.ray_run_name,
|
||||
storage_filesystem=ray_args.ray_storage_filesystem,
|
||||
storage_path=storage_path,
|
||||
),
|
||||
)
|
||||
return trainer
|
||||
"runtime_env": {"env_vars": env},
|
||||
"num_cpus": 10,
|
||||
}
|
||||
|
||||
device_name = get_device_name()
|
||||
if device_name == "gpu":
|
||||
remote_config["num_gpus"] = 1
|
||||
elif device_name == "npu":
|
||||
remote_config["resources"] = {"NPU": 1}
|
||||
|
||||
return remote_config
|
||||
|
||||
|
||||
def get_ray_head_node_ip() -> str:
|
||||
r"""Get the IP address of the Ray head node."""
|
||||
head_ip = next(node["NodeManagerAddress"] for node in ray.nodes() if node.get("IsHead", False))
|
||||
return head_ip
|
||||
|
||||
|
||||
def sort_placement_group_by_node_ip(placement_group: "PlacementGroup", master_addr: str = None) -> list[int]:
|
||||
r"""Sort the placement group bundles by their node IP addresses."""
|
||||
|
||||
@ray.remote
|
||||
def _get_node_ip():
|
||||
return ray.util.get_node_ip_address().strip("[]")
|
||||
|
||||
tasks = []
|
||||
for bundle_idx in range(placement_group.bundle_count):
|
||||
task = _get_node_ip.options(
|
||||
scheduling_strategy=PlacementGroupSchedulingStrategy(
|
||||
placement_group=placement_group,
|
||||
placement_group_bundle_index=bundle_idx,
|
||||
),
|
||||
).remote()
|
||||
tasks.append(task)
|
||||
|
||||
bundle_ips = ray.get(tasks)
|
||||
bundle_node_ip_list = list(enumerate(bundle_ips))
|
||||
|
||||
sorted_bundle_node_ip_list = sorted(bundle_node_ip_list, key=lambda x: x[1])
|
||||
sorted_bundle_indices = [item[0] for item in sorted_bundle_node_ip_list]
|
||||
|
||||
if master_addr is not None:
|
||||
preferred_indices = [idx for idx, ip in bundle_node_ip_list if ip == master_addr]
|
||||
if preferred_indices:
|
||||
remaining = [i for i in sorted_bundle_indices if i not in preferred_indices]
|
||||
sorted_bundle_indices = preferred_indices + remaining
|
||||
|
||||
return sorted_bundle_indices
|
||||
|
||||
@@ -23,9 +23,9 @@ from transformers import EarlyStoppingCallback, PreTrainedModel
|
||||
from ..data import get_template_and_fix_tokenizer
|
||||
from ..extras import logging
|
||||
from ..extras.constants import V_HEAD_SAFE_WEIGHTS_NAME, V_HEAD_WEIGHTS_NAME
|
||||
from ..extras.misc import infer_optim_dtype
|
||||
from ..extras.misc import find_available_port, get_device_name, get_torch_device, infer_optim_dtype
|
||||
from ..extras.packages import is_mcore_adapter_available, is_ray_available
|
||||
from ..hparams import get_infer_args, get_ray_args, get_train_args, read_args
|
||||
from ..hparams import RayArguments, get_infer_args, get_ray_args, get_train_args, read_args
|
||||
from ..model import load_model, load_tokenizer
|
||||
from .callbacks import LogCallback, PissaConvertCallback, ReporterCallback
|
||||
from .dpo import run_dpo
|
||||
@@ -34,12 +34,17 @@ from .ppo import run_ppo
|
||||
from .pt import run_pt
|
||||
from .rm import run_rm
|
||||
from .sft import run_sft
|
||||
from .trainer_utils import get_ray_trainer, get_swanlab_callback
|
||||
from .trainer_utils import (
|
||||
get_placement_group,
|
||||
get_ray_head_node_ip,
|
||||
get_ray_remote_config_for_worker,
|
||||
get_swanlab_callback,
|
||||
sort_placement_group_by_node_ip,
|
||||
)
|
||||
|
||||
|
||||
if is_ray_available():
|
||||
import ray
|
||||
from ray.train.huggingface.transformers import RayTrainReportCallback
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -115,13 +120,7 @@ def run_exp(args: Optional[dict[str, Any]] = None, callbacks: Optional[list["Tra
|
||||
ray_args = get_ray_args(args)
|
||||
callbacks = callbacks or []
|
||||
if ray_args.use_ray:
|
||||
callbacks.append(RayTrainReportCallback())
|
||||
trainer = get_ray_trainer(
|
||||
training_function=_training_function,
|
||||
train_loop_config={"args": args, "callbacks": callbacks},
|
||||
ray_args=ray_args,
|
||||
)
|
||||
trainer.fit()
|
||||
_ray_training_function(ray_args, config={"args": args, "callbacks": callbacks})
|
||||
else:
|
||||
_training_function(config={"args": args, "callbacks": callbacks})
|
||||
|
||||
@@ -212,3 +211,94 @@ def export_model(args: Optional[dict[str, Any]] = None) -> None:
|
||||
with open(ollama_modelfile, "w", encoding="utf-8") as f:
|
||||
f.write(template.get_ollama_modelfile(tokenizer))
|
||||
logger.info_rank0(f"Ollama modelfile saved in {ollama_modelfile}")
|
||||
|
||||
|
||||
class Worker:
|
||||
def __init__(self):
|
||||
self._setup_env_visible_devices()
|
||||
|
||||
local_rank = os.environ.get("LOCAL_RANK", "0")
|
||||
get_torch_device().set_device(int(local_rank))
|
||||
|
||||
def _setup_env_visible_devices(self) -> None:
|
||||
RAY_NOSET_VISIBLE_DEVICES_LIST = [
|
||||
"RAY_EXPERIMENTAL_NOSET_CUDA_VISIBLE_DEVICES",
|
||||
"RAY_EXPERIMENTAL_NOSET_ASCEND_RT_VISIBLE_DEVICES",
|
||||
]
|
||||
is_ray_noset_visible_devices = any(os.environ.get(env_var, None) for env_var in RAY_NOSET_VISIBLE_DEVICES_LIST)
|
||||
if is_ray_noset_visible_devices:
|
||||
device_name = get_device_name().upper()
|
||||
local_rank = ray.get_runtime_context().get_accelerator_ids()[device_name][0]
|
||||
os.environ["LOCAL_RANK"] = local_rank
|
||||
else:
|
||||
os.environ["LOCAL_RANK"] = "0"
|
||||
|
||||
def _training_function(self, config: dict[str, Any]) -> None:
|
||||
_training_function(config)
|
||||
|
||||
|
||||
def _ray_training_function(ray_args: "RayArguments", config: dict[str, Any]) -> None:
|
||||
num_workers = ray_args.ray_num_workers
|
||||
master_addr = ray_args.master_addr
|
||||
master_port = ray_args.master_port
|
||||
logger.info(f"Using ray.remote mode with {num_workers} workers for distributed training.")
|
||||
|
||||
# initialize ray
|
||||
if not ray.is_initialized():
|
||||
if ray_args.ray_init_kwargs is not None:
|
||||
ray.init(**ray_args.ray_init_kwargs)
|
||||
else:
|
||||
ray.init()
|
||||
|
||||
# verify resources
|
||||
device_name = get_device_name().upper()
|
||||
total_devices = int(ray.cluster_resources().get(device_name, 0))
|
||||
if num_workers > total_devices:
|
||||
raise ValueError(
|
||||
f"The number of devices in the Ray cluster ({total_devices}) should be greater than num_workers ({num_workers})."
|
||||
)
|
||||
|
||||
# verify master_addr
|
||||
if master_addr is None:
|
||||
master_addr = get_ray_head_node_ip()
|
||||
logger.info(f"`master_addr` is not specified, using head node ip: {master_addr}.")
|
||||
else:
|
||||
nodes = [node["NodeManagerAddress"] for node in ray.nodes() if node["Alive"]]
|
||||
if master_addr not in nodes:
|
||||
raise ValueError(f"The `master_addr` ({master_addr}) is not in Ray cluster or not alive ")
|
||||
|
||||
# create placement group for resource management
|
||||
pg, bundle = get_placement_group(total_devices)
|
||||
ray.get(pg.ready())
|
||||
logger.info(f"Create placement group with {num_workers} bundles: {bundle}")
|
||||
|
||||
# get sorted_bundle_indices
|
||||
sorted_bundle_indices = sort_placement_group_by_node_ip(pg, master_addr)
|
||||
|
||||
# get master port
|
||||
if master_port is None:
|
||||
master_port = find_available_port()
|
||||
logger.info(f"`master_port` is not specified, using available port: {master_port}.")
|
||||
master_port = str(master_port)
|
||||
|
||||
# snapshot the current environment variables to forward to each worker
|
||||
current_env = dict(os.environ.items())
|
||||
|
||||
# launch workers
|
||||
RayWorker = ray.remote(Worker)
|
||||
workers = []
|
||||
for rank in range(num_workers):
|
||||
remote_config = get_ray_remote_config_for_worker(
|
||||
placement_group=pg,
|
||||
bundle_idx=sorted_bundle_indices[rank],
|
||||
rank=rank,
|
||||
world_size=num_workers,
|
||||
master_addr=master_addr,
|
||||
master_port=master_port,
|
||||
env=current_env,
|
||||
)
|
||||
worker = RayWorker.options(**remote_config).remote()
|
||||
workers.append(worker)
|
||||
|
||||
ray.get([worker._training_function.remote(config=config) for worker in workers])
|
||||
ray.shutdown()
|
||||
|
||||
@@ -119,9 +119,19 @@ def synchronize() -> None:
|
||||
|
||||
|
||||
@requires_accelerator
|
||||
def set_device() -> None:
|
||||
"""Set current accelerator."""
|
||||
torch.accelerator.set_device_index(get_local_rank())
|
||||
def set_device_index() -> None:
|
||||
"""Set current accelerator index to local rank."""
|
||||
if get_current_accelerator().type != DeviceType.CPU:
|
||||
torch.accelerator.set_device_index(get_local_rank())
|
||||
|
||||
|
||||
@requires_accelerator
|
||||
def get_current_device() -> torch.device:
|
||||
"""Get current accelerator device."""
|
||||
if get_current_accelerator().type == DeviceType.CPU:
|
||||
return torch.device(DeviceType.CPU.value)
|
||||
else:
|
||||
return torch.device(type=get_current_accelerator().type, index=torch.accelerator.current_device_index())
|
||||
|
||||
|
||||
def is_torch_cuda_available():
|
||||
@@ -170,6 +180,16 @@ def operate_tensorlike(fn: Callable[[...], Tensor], data: TensorLike, **kwargs)
|
||||
return result.tolist()
|
||||
|
||||
|
||||
def get_process_group_backend() -> str:
|
||||
"""Get backend for init process group."""
|
||||
if get_current_accelerator().type == DeviceType.NPU:
|
||||
return "hccl"
|
||||
elif get_current_accelerator().type == DeviceType.CUDA:
|
||||
return "nccl"
|
||||
else:
|
||||
return "gloo"
|
||||
|
||||
|
||||
def all_gather(tensor: Tensor, group: Optional[ProcessGroup] = None) -> Tensor:
|
||||
"""Gathers the tensor from all ranks and stacks them at the first dim."""
|
||||
world_size = get_world_size()
|
||||
|
||||
@@ -34,10 +34,14 @@ from typing import Any, Optional
|
||||
from torch.distributed import barrier, destroy_process_group, init_process_group
|
||||
from torch.distributed.device_mesh import DeviceMesh, init_device_mesh
|
||||
|
||||
from ..utils.types import DistributedConfig, ProcessGroup, Tensor, TensorLike
|
||||
from ..utils import logging
|
||||
from ..utils.types import DistributedConfig, ProcessGroup, TensorLike
|
||||
from . import helper
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
class Dim(str, Enum):
|
||||
"""Dimension names."""
|
||||
|
||||
@@ -119,12 +123,13 @@ class DistributedInterface:
|
||||
if self._initialized:
|
||||
return
|
||||
|
||||
helper.set_device_index()
|
||||
self._is_distributed = helper.is_distributed()
|
||||
self._rank = helper.get_rank()
|
||||
self._world_size = helper.get_world_size()
|
||||
self._local_rank = helper.get_local_rank()
|
||||
self._local_world_size = helper.get_local_world_size()
|
||||
self.current_accelerator = helper.get_current_accelerator()
|
||||
self.current_device = helper.get_current_device()
|
||||
self.device_count = helper.get_device_count()
|
||||
|
||||
if config is None:
|
||||
@@ -140,15 +145,14 @@ class DistributedInterface:
|
||||
timeout = config.get("timeout", 18000)
|
||||
|
||||
if self._is_distributed:
|
||||
helper.set_device()
|
||||
init_process_group(timeout=timedelta(seconds=timeout))
|
||||
init_process_group(timeout=timedelta(seconds=timeout), backend=helper.get_process_group_backend())
|
||||
self.model_device_mesh = init_device_mesh(
|
||||
device_type=self.current_accelerator.type,
|
||||
device_type=self.current_device.type,
|
||||
mesh_shape=self.strategy.model_mesh_shape,
|
||||
mesh_dim_names=self.strategy.model_mesh_dim_names,
|
||||
)
|
||||
self.data_device_mesh = init_device_mesh(
|
||||
device_type=self.current_accelerator.type,
|
||||
device_type=self.current_device.type,
|
||||
mesh_shape=self.strategy.data_mesh_shape,
|
||||
mesh_dim_names=self.strategy.data_mesh_dim_names,
|
||||
)
|
||||
@@ -157,11 +161,12 @@ class DistributedInterface:
|
||||
self.data_device_mesh = None
|
||||
|
||||
self._initialized = True
|
||||
logger.info_rank0(f"DistributedInterface initialized: {self}.")
|
||||
|
||||
def __str__(self) -> str:
|
||||
return (
|
||||
f"DistributedInterface(strategy={self.strategy}), is_distributed={self._is_distributed}, "
|
||||
f"current_accelerator={self.current_accelerator}, rank={self._rank}, world_size={self._world_size}, "
|
||||
f"current_device={self.current_device}, rank={self._rank}, world_size={self._world_size}, "
|
||||
f"model_device_mesh={self.model_device_mesh}, data_device_mesh={self.data_device_mesh}"
|
||||
)
|
||||
|
||||
@@ -169,7 +174,7 @@ class DistributedInterface:
|
||||
"""Get device mesh for specified dimension."""
|
||||
if dim is None:
|
||||
raise ValueError("dim must be specified.")
|
||||
elif self.model_device_mesh is None:
|
||||
elif not self._is_distributed:
|
||||
return None
|
||||
elif dim in self.strategy.data_mesh_dim_names:
|
||||
return self.data_device_mesh[dim.value]
|
||||
@@ -178,14 +183,14 @@ class DistributedInterface:
|
||||
|
||||
def get_group(self, dim: Dim | None = None) -> Optional[ProcessGroup]:
|
||||
"""Get process group for specified dimension."""
|
||||
if self.model_device_mesh is None or dim is None:
|
||||
if not self._is_distributed or dim is None:
|
||||
return None
|
||||
else:
|
||||
return self.get_device_mesh(dim).get_group()
|
||||
|
||||
def get_rank(self, dim: Dim | None = None) -> int:
|
||||
"""Get parallel rank for specified dimension."""
|
||||
if self.model_device_mesh is None:
|
||||
if not self._is_distributed:
|
||||
return 0
|
||||
elif dim is None:
|
||||
return self._rank
|
||||
@@ -194,7 +199,7 @@ class DistributedInterface:
|
||||
|
||||
def get_world_size(self, dim: Dim | None = None) -> int:
|
||||
"""Get parallel size for specified dimension."""
|
||||
if self.model_device_mesh is None:
|
||||
if not self._is_distributed:
|
||||
return 1
|
||||
elif dim is None:
|
||||
return self._world_size
|
||||
@@ -209,9 +214,9 @@ class DistributedInterface:
|
||||
"""Get parallel local world size."""
|
||||
return self._local_world_size
|
||||
|
||||
def all_gather(self, data: Tensor, dim: Dim | None = Dim.DP) -> Tensor:
|
||||
def all_gather(self, data: TensorLike, dim: Dim | None = Dim.DP) -> TensorLike:
|
||||
"""Gather tensor across specified parallel group."""
|
||||
if self.model_device_mesh is not None:
|
||||
if self._is_distributed:
|
||||
return helper.operate_tensorlike(helper.all_gather, data, group=self.get_group(dim))
|
||||
else:
|
||||
return data
|
||||
@@ -220,30 +225,36 @@ class DistributedInterface:
|
||||
self, data: TensorLike, op: helper.ReduceOp = helper.ReduceOp.MEAN, dim: Dim | None = Dim.DP
|
||||
) -> TensorLike:
|
||||
"""Reduce tensor across specified parallel group."""
|
||||
if self.model_device_mesh is not None:
|
||||
if self._is_distributed:
|
||||
return helper.operate_tensorlike(helper.all_reduce, data, op=op, group=self.get_group(dim))
|
||||
else:
|
||||
return data
|
||||
|
||||
def broadcast(self, data: TensorLike, src: int = 0, dim: Dim | None = Dim.DP) -> TensorLike:
|
||||
"""Broadcast tensor across specified parallel group."""
|
||||
if self.model_device_mesh is not None:
|
||||
if self._is_distributed:
|
||||
return helper.operate_tensorlike(helper.broadcast, data, src=src, group=self.get_group(dim))
|
||||
else:
|
||||
return data
|
||||
|
||||
def sync(self) -> None:
|
||||
"""Synchronize all processes."""
|
||||
helper.synchronize()
|
||||
if self._is_distributed:
|
||||
helper.synchronize()
|
||||
|
||||
def barrier(self) -> None:
|
||||
"""Barrier all processes."""
|
||||
barrier()
|
||||
if self._is_distributed:
|
||||
barrier()
|
||||
|
||||
def destroy(self) -> None:
|
||||
"""Destroy all processes."""
|
||||
destroy_process_group()
|
||||
if self._is_distributed:
|
||||
destroy_process_group()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
print(DistributedInterface(DistributedStrategy()))
|
||||
"""
|
||||
python -m llamafactory.v1.accelerator.interface
|
||||
"""
|
||||
print(DistributedInterface())
|
||||
|
||||
@@ -0,0 +1,33 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from .arg_parser import InputArgument, get_args
from .arg_utils import BatchingStrategy, ModelClass, SampleBackend
from .data_args import DataArguments
from .model_args import ModelArguments
from .sample_args import SampleArguments
from .training_args import TrainingArguments


__all__ = [
    "BatchingStrategy",
    "DataArguments",
    "InputArgument",
    "ModelArguments",
    "ModelClass",
    "SampleArguments",
    "SampleBackend",
    "TrainingArguments",
    "get_args",
]

@@ -20,7 +20,7 @@ from typing import Any
|
||||
from omegaconf import OmegaConf
|
||||
from transformers import HfArgumentParser
|
||||
|
||||
from ...extras.misc import is_env_enabled
|
||||
from ..utils.env import is_env_enabled
|
||||
from .data_args import DataArguments
|
||||
from .model_args import ModelArguments
|
||||
from .sample_args import SampleArguments
|
||||
@@ -30,24 +30,9 @@ from .training_args import TrainingArguments
|
||||
InputArgument = dict[str, Any] | list[str] | None
|
||||
|
||||
|
||||
def validate_args(
|
||||
data_args: DataArguments,
|
||||
model_args: ModelArguments,
|
||||
training_args: TrainingArguments,
|
||||
sample_args: SampleArguments,
|
||||
):
|
||||
"""Validate arguments."""
|
||||
if (
|
||||
model_args.quant_config is not None
|
||||
and training_args.dist_config is not None
|
||||
and training_args.dist_config.name == "deepspeed"
|
||||
):
|
||||
raise ValueError("Quantization is not supported with deepspeed backend.")
|
||||
|
||||
|
||||
def get_args(args: InputArgument = None) -> tuple[DataArguments, ModelArguments, TrainingArguments, SampleArguments]:
|
||||
def get_args(args: InputArgument = None) -> tuple[ModelArguments, DataArguments, TrainingArguments, SampleArguments]:
|
||||
"""Parse arguments from command line or config file."""
|
||||
parser = HfArgumentParser([DataArguments, ModelArguments, TrainingArguments, SampleArguments])
|
||||
parser = HfArgumentParser([ModelArguments, DataArguments, TrainingArguments, SampleArguments])
|
||||
allow_extra_keys = is_env_enabled("ALLOW_EXTRA_KEYS")
|
||||
|
||||
if args is None:
|
||||
@@ -71,8 +56,6 @@ def get_args(args: InputArgument = None) -> tuple[DataArguments, ModelArguments,
|
||||
print(f"Got unknown args, potentially deprecated arguments: {unknown_args}")
|
||||
raise ValueError(f"Some specified arguments are not used by the HfArgumentParser: {unknown_args}")
|
||||
|
||||
validate_args(*parsed_args)
|
||||
|
||||
return tuple(parsed_args)
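With the reordered parser above, get_args now returns model arguments first. A small usage sketch (the package path and config keys are inferred from this diff and are not an exhaustive schema):

from llamafactory.v1.config import get_args  # assumed package path

model_args, data_args, training_args, sample_args = get_args(
    {"model": "Qwen/Qwen3-4B-Instruct-2507", "train_dataset": "data/v1_sft_demo.yaml", "bf16": True}
)
print(model_args.model, data_args.train_dataset, training_args.output_dir)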
|
||||
|
||||
|
||||
|
||||
@@ -17,7 +17,7 @@
|
||||
|
||||
|
||||
import json
|
||||
from enum import Enum, unique
|
||||
from enum import StrEnum, unique
|
||||
|
||||
|
||||
class PluginConfig(dict):
|
||||
@@ -36,7 +36,7 @@ PluginArgument = PluginConfig | dict | str | None
|
||||
|
||||
|
||||
@unique
|
||||
class ModelClass(str, Enum):
|
||||
class ModelClass(StrEnum):
|
||||
"""Auto class for model config."""
|
||||
|
||||
LLM = "llm"
|
||||
@@ -45,11 +45,19 @@ class ModelClass(str, Enum):
|
||||
|
||||
|
||||
@unique
|
||||
class SampleBackend(str, Enum):
|
||||
class SampleBackend(StrEnum):
|
||||
HF = "hf"
|
||||
VLLM = "vllm"
|
||||
|
||||
|
||||
@unique
|
||||
class BatchingStrategy(StrEnum):
|
||||
NORMAL = "normal"
|
||||
PADDING_FREE = "padding_free"
|
||||
DYNAMIC_BATCHING = "dynamic_batching"
|
||||
DYNAMIC_PADDING_FREE = "dynamic_padding_free"
|
||||
|
||||
|
||||
def _convert_str_dict(data: dict) -> dict:
|
||||
"""Parse string representation inside the dictionary.
|
||||
|
||||
|
||||
@@ -18,11 +18,11 @@ from dataclasses import dataclass, field
|
||||
|
||||
@dataclass
|
||||
class DataArguments:
|
||||
dataset: str | None = field(
|
||||
train_dataset: str | None = field(
|
||||
default=None,
|
||||
metadata={"help": "Path to the dataset."},
|
||||
metadata={"help": "Path to the training dataset."},
|
||||
)
|
||||
cutoff_len: int = field(
|
||||
default=2048,
|
||||
metadata={"help": "Cutoff length for the dataset."},
|
||||
eval_dataset: str | None = field(
|
||||
default=None,
|
||||
metadata={"help": "Path to the evaluation dataset."},
|
||||
)
|
||||
|
||||
@@ -21,20 +21,25 @@ from .arg_utils import ModelClass, PluginConfig, get_plugin_config
|
||||
@dataclass
|
||||
class ModelArguments:
|
||||
model: str = field(
|
||||
default="Qwen/Qwen3-4B-Instruct-2507",
|
||||
metadata={"help": "Path to the model or model identifier from Hugging Face."},
|
||||
)
|
||||
template: str = field(
|
||||
default="qwen3_nothink",
|
||||
metadata={"help": "Template for the model."},
|
||||
)
|
||||
trust_remote_code: bool = field(
|
||||
default=False,
|
||||
metadata={"help": "Trust remote code from Hugging Face."},
|
||||
)
|
||||
use_fast_processor: bool = field(
|
||||
default=True,
|
||||
metadata={"help": "Use fast processor from Hugging Face."},
|
||||
)
|
||||
model_class: ModelClass = field(
|
||||
default=ModelClass.LLM,
|
||||
metadata={"help": "Model class from Hugging Face."},
|
||||
)
|
||||
init_config: PluginConfig | None = field(
|
||||
default=None,
|
||||
metadata={"help": "Initialization configuration for the model."},
|
||||
)
|
||||
peft_config: PluginConfig | None = field(
|
||||
default=None,
|
||||
metadata={"help": "PEFT configuration for the model."},
|
||||
@@ -49,6 +54,7 @@ class ModelArguments:
|
||||
)
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
self.init_config = get_plugin_config(self.init_config)
|
||||
self.peft_config = get_plugin_config(self.peft_config)
|
||||
self.kernel_config = get_plugin_config(self.kernel_config)
|
||||
self.quant_config = get_plugin_config(self.quant_config)
|
||||
|
||||
@@ -16,35 +16,73 @@ import os
|
||||
from dataclasses import dataclass, field
|
||||
from uuid import uuid4
|
||||
|
||||
from .arg_utils import PluginConfig, get_plugin_config
|
||||
from .arg_utils import BatchingStrategy, PluginConfig, get_plugin_config
|
||||
|
||||
|
||||
@dataclass
|
||||
class TrainingArguments:
|
||||
output_dir: str = field(
|
||||
default=os.path.join("outputs", str(uuid4())),
|
||||
default=os.path.join("outputs", str(uuid4().hex)),
|
||||
metadata={"help": "Path to the output directory."},
|
||||
)
|
||||
micro_batch_size: int = field(
|
||||
default=1,
|
||||
metadata={"help": "Micro batch size for training."},
|
||||
)
|
||||
global_batch_size: int = field(
|
||||
default=1,
|
||||
metadata={"help": "Global batch size for training."},
|
||||
global_batch_size: int | None = field(
|
||||
default=None,
|
||||
metadata={"help": "Global batch size for training, default to DP size * micro batch size."},
|
||||
)
|
||||
cutoff_len: int = field(
|
||||
default=2048,
|
||||
metadata={"help": "Maximum sequence length for training."},
|
||||
)
|
||||
learning_rate: float = field(
|
||||
default=1e-4,
|
||||
metadata={"help": "Learning rate for training."},
|
||||
)
|
||||
num_train_epochs: int = field(
|
||||
default=3,
|
||||
metadata={"help": "Number of training epochs."},
|
||||
)
|
||||
max_steps: int | None = field(
|
||||
default=None,
|
||||
metadata={"help": "Maximum number of training steps. If set, overrides num_train_epochs."},
|
||||
)
|
||||
max_grad_norm: float = field(
|
||||
default=1.0,
|
||||
metadata={"help": "Maximum gradient norm for training."},
|
||||
)
|
||||
bf16: bool = field(
|
||||
default=False,
|
||||
metadata={"help": "Use bf16 for training."},
|
||||
)
|
||||
batching_strategy: BatchingStrategy = field(
|
||||
default=BatchingStrategy.NORMAL,
|
||||
metadata={"help": "Batching strategy for training."},
|
||||
)
|
||||
batching_workers: int = field(
|
||||
default=16,
|
||||
metadata={"help": "Number of workers for batching."},
|
||||
)
|
||||
enable_activation_checkpointing: bool = field(
|
||||
default=True,
|
||||
metadata={"help": "Enable activation checkpointing for training."},
|
||||
)
|
||||
dist_config: PluginConfig | None = field(
|
||||
default=None,
|
||||
metadata={"help": "Distribution configuration for training."},
|
||||
)
|
||||
optim_config: PluginConfig | None = field(
|
||||
default=None,
|
||||
metadata={"help": "Optimizer configuration for training."},
|
||||
)
|
||||
lr_scheduler_config: PluginConfig | None = field(
|
||||
default=None,
|
||||
metadata={"help": "Learning rate scheduler configuration for training."},
|
||||
)
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
self.dist_config = get_plugin_config(self.dist_config)
|
||||
self.optim_config = get_plugin_config(self.optim_config)
|
||||
self.lr_scheduler_config = get_plugin_config(self.lr_scheduler_config)
|
||||
|
||||
src/llamafactory/v1/core/base_sampler.py (new file, 67 lines)
@@ -0,0 +1,67 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from collections.abc import AsyncGenerator

from ..config import ModelArguments, SampleArguments, SampleBackend
from ..utils.types import HFModel, Message, Sample, TorchDataset
from .utils.inference_engine import HuggingFaceEngine
from .utils.rendering import Renderer


class BaseSampler:
    """Base sampler.

    Args:
        args: Sample arguments.
        model_args: Model arguments.
        model: Model.
        renderer: Renderer.
    """

    def __init__(
        self,
        args: SampleArguments,
        model_args: ModelArguments,
        model: HFModel,
        renderer: Renderer,
    ) -> None:
        if args.sample_backend == SampleBackend.HF:
            self.engine = HuggingFaceEngine(args, model_args, model, renderer)
        else:
            raise ValueError(f"Unknown sample backend: {args.sample_backend}")

    async def generate(self, messages: list[Message], tools: str | None = None) -> AsyncGenerator[str, None]:
        """Generate tokens asynchronously.

        Args:
            messages: List of messages.
            tools: Tools string.

        Yields:
            Generated tokens.
        """
        async for token in self.engine.generate(messages, tools):
            yield token

    async def batch_infer(self, dataset: TorchDataset) -> list[Sample]:
        """Batch infer samples.

        Args:
            dataset: Torch dataset.

        Returns:
            List of samples.
        """
        return await self.engine.batch_infer(dataset)
||||
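A minimal, hypothetical driver for the streaming API above (the sampler construction is elided because it comes from ModelEngine; the role/content message format is an assumption):

import asyncio

from llamafactory.v1.core.base_sampler import BaseSampler  # path taken from the file header above


async def chat(sampler: BaseSampler, messages) -> None:
    async for token in sampler.generate(messages):
        print(token, end="", flush=True)


# Example (requires a constructed sampler):
# asyncio.run(chat(sampler, [{"role": "user", "content": "Hello"}]))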
@@ -16,20 +16,33 @@
|
||||
|
||||
Init Phase:
|
||||
|
||||
1. Init dataloader.
|
||||
1. Init batch generator.
|
||||
2. Init optimizer (deepspeed).
|
||||
3. Shard model.
|
||||
4. Init optimizer (fsdp).
|
||||
5. Init scheduler.
|
||||
5. Init lr scheduler.
|
||||
|
||||
Train Phase:
|
||||
1. Train Loop
|
||||
|
||||
"""
|
||||
|
||||
from ..config.training_args import TrainingArguments
|
||||
from ..utils.types import HFModel, Processor, TorchDataset
|
||||
from .trainer_utils.data_collator import DataCollator
|
||||
from abc import abstractmethod
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
from ..accelerator.helper import ReduceOp
|
||||
from ..accelerator.interface import Dim, DistributedInterface
|
||||
from ..config import TrainingArguments
|
||||
from ..utils import logging
|
||||
from ..utils.helper import compute_valid_tokens
|
||||
from ..utils.types import BatchInput, HFModel, ModelOutput, Tensor, TorchDataset
|
||||
from .utils.batching import BatchGenerator
|
||||
from .utils.rendering import Renderer
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
class BaseTrainer:
|
||||
@@ -37,22 +50,160 @@ class BaseTrainer:
|
||||
self,
|
||||
args: TrainingArguments,
|
||||
model: HFModel,
|
||||
processor: Processor,
|
||||
dataset: TorchDataset,
|
||||
renderer: Renderer,
|
||||
train_dataset: TorchDataset,
|
||||
) -> None:
|
||||
self.args = args
|
||||
self.model = model
|
||||
self.processor = processor
|
||||
self.dataset = dataset
|
||||
self.data_collator = DataCollator()
|
||||
self.optimizer = None
|
||||
self.lr_scheduler = None
|
||||
self.renderer = renderer
|
||||
self.train_dataset = train_dataset
|
||||
|
||||
def init_model_and_optimizer(self) -> None:
|
||||
pass
|
||||
# info
|
||||
self.global_step = 0
|
||||
|
||||
def create_dataloader(self) -> None:
|
||||
pass
|
||||
# cached variables
|
||||
self.device = DistributedInterface().current_device
|
||||
self.dp_size = DistributedInterface().get_world_size(Dim.DP)
|
||||
self.model_input_names = self.renderer.processor.model_input_names
|
||||
|
||||
self._create_batch_generator()
|
||||
# Calculate num_training_steps: max_steps takes priority if set
|
||||
if self.args.max_steps is not None and self.args.max_steps > 0:
|
||||
self.num_training_steps = self.args.max_steps
|
||||
else:
|
||||
self.num_training_steps = self.args.num_train_epochs * len(self.train_batch_generator)
|
||||
|
||||
if self.args.enable_activation_checkpointing:
|
||||
self.model.gradient_checkpointing_enable({"use_reentrant": False})
|
||||
|
||||
if self.args.dist_config is not None:
|
||||
shard_need_optimizer = self.args.dist_config.name == "deepspeed"
|
||||
else:
|
||||
shard_need_optimizer = False
|
||||
|
||||
if shard_need_optimizer:
|
||||
self._init_optimizer()
|
||||
self._shard_model()
|
||||
else:
|
||||
self._shard_model()
|
||||
self._init_optimizer()
|
||||
|
||||
self._init_lr_scheduler()
|
||||
|
||||
def _create_batch_generator(self) -> None:
|
||||
self.train_batch_generator = BatchGenerator(
|
||||
dataset=self.train_dataset,
|
||||
renderer=self.renderer,
|
||||
micro_batch_size=self.args.micro_batch_size,
|
||||
global_batch_size=self.args.global_batch_size,
|
||||
cutoff_len=self.args.cutoff_len,
|
||||
batching_workers=self.args.batching_workers,
|
||||
batching_strategy=self.args.batching_strategy,
|
||||
)
|
||||
|
||||
def _shard_model(self) -> None:
|
||||
if self.args.dist_config is None:
|
||||
if DistributedInterface().get_world_size(Dim.DP) > 1:
|
||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||
|
||||
logger.warning_rank0(
|
||||
"dist_config is None but distributed training is enabled; falling back to DistributedDataParallel."
|
||||
)
|
||||
device_ids = None if self.device.type == "cpu" else [self.device.index]
|
||||
self.model = DDP(self.model, device_ids=device_ids)
|
||||
else:
|
||||
from ..plugins.trainer_plugins.distributed.hub import DistributedPlugin
|
||||
|
||||
self.model = DistributedPlugin(self.args.dist_config.name)(
|
||||
self.model,
|
||||
self.args.dist_config,
|
||||
)
|
||||
|
||||
def _init_optimizer(self) -> None:
|
||||
"""Init optimizer."""
|
||||
if self.args.optim_config is None:
|
||||
_trainable_params = [p for p in self.model.parameters() if p.requires_grad]
|
||||
self.optimizer = torch.optim.AdamW(_trainable_params, lr=self.args.learning_rate)
|
||||
else:
|
||||
from ..plugins.trainer_plugins.optimizer import OptimizerPlugin
|
||||
|
||||
self.optimizer = OptimizerPlugin(self.args.optim_config.name)(self.model, self.args.optim_config)
|
||||
|
||||
def _init_lr_scheduler(self) -> None:
|
||||
"""Init lr scheduler."""
|
||||
if self.args.lr_scheduler_config is None:
|
||||
self.lr_scheduler = torch.optim.lr_scheduler.LambdaLR(self.optimizer, lr_lambda=lambda x: 1.0)
|
||||
else:
|
||||
from ..plugins.trainer_plugins.lr_scheduler import LRSchedulerPlugin
|
||||
|
||||
self.lr_scheduler = LRSchedulerPlugin(self.args.lr_scheduler_config.name)(
|
||||
self.optimizer, self.num_training_steps, self.args.lr_scheduler_config
|
||||
)
|
||||
|
||||
def compute_log_probs(self, model: HFModel, batch: BatchInput) -> Tensor:
|
||||
"""Compute log probs.
|
||||
|
||||
log_probs: Tensor of shape (batch_size, seq_len - 1)
|
||||
"""
|
||||
batch_size, _ = batch["labels"].shape
|
||||
model_inputs = {
|
||||
k: v.to(self.device, non_blocking=True) for k, v in batch.items() if k in self.model_input_names
|
||||
}
|
||||
labels = batch["labels"].to(self.device, non_blocking=True)
|
||||
outputs: ModelOutput = model(**model_inputs)
|
||||
logits = outputs.logits.float()
|
||||
shift_labels = labels[..., 1:].contiguous().view(-1)
|
||||
shift_logits = logits[..., :-1, :].contiguous().view(shift_labels.size(0), -1)
|
||||
return -F.cross_entropy(shift_logits, shift_labels, reduction="none").view(batch_size, -1)
|
||||
|
||||
@abstractmethod
|
||||
def compute_loss(self, batch: BatchInput) -> Tensor:
|
||||
"""Compute the scalar loss."""
|
||||
...
|
||||
|
||||
def fit(self) -> None:
|
||||
pass
|
||||
"""Train the model."""
|
||||
self.model.train()
|
||||
for epoch in range(self.args.num_train_epochs):
|
||||
self.train_batch_generator.set_epoch(epoch)
|
||||
for micro_batches in self.train_batch_generator:
|
||||
self.global_step += 1
|
||||
step_loss = 0
|
||||
step_valid_tokens = compute_valid_tokens(micro_batches)
|
||||
step_valid_tokens = DistributedInterface().all_reduce(step_valid_tokens, op=ReduceOp.SUM)
|
||||
for micro_batch in micro_batches:
|
||||
loss = self.compute_loss(micro_batch)
|
||||
mini_step_valid_tokens = compute_valid_tokens([micro_batch])
|
||||
# fsdp uses mean reduction so we need to scale the loss by dp_size
|
||||
loss = loss * mini_step_valid_tokens * self.dp_size / (step_valid_tokens + 1e-6)
|
||||
|
||||
loss.backward()
|
||||
step_loss += loss.item()
|
||||
|
||||
grad_norm = torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.args.max_grad_norm).item()
|
||||
|
||||
# grad_norm is a Python float, so wrap it in a tensor before calling torch.isfinite()
|
||||
if not torch.isfinite(torch.tensor(grad_norm)): # type: ignore # pyright: ignore [reportUnknownReturnType]
|
||||
logger.warning_rank0(f"Gradient norm is not finite: {grad_norm}")
|
||||
else:
|
||||
self.optimizer.step()
|
||||
|
||||
self.lr_scheduler.step()
|
||||
self.optimizer.zero_grad()
|
||||
|
||||
step_loss, grad_norm = DistributedInterface().all_reduce([step_loss, grad_norm])
|
||||
DistributedInterface().sync()
|
||||
if DistributedInterface().get_rank() == 0:
|
||||
print(f"Epoch {epoch}, Step {self.global_step}, Loss: {step_loss:.4f}, Grad Norm: {grad_norm:.4f}")
|
||||
|
||||
# Check if max_steps is reached
|
||||
if self.global_step >= self.num_training_steps:
|
||||
logger.info_rank0(f"Reached max_steps ({self.num_training_steps}), stopping training.")
|
||||
return
|
||||
|
||||
def save_model(self) -> None:
|
||||
"""Save the model."""
|
||||
model_to_save = self.model.module if hasattr(self.model, "module") else self.model
|
||||
model_to_save.save_pretrained(self.args.output_dir)
|
||||
self.renderer.processor.save_pretrained(self.args.output_dir)
|
||||
logger.info_rank0(f"Model saved to {self.args.output_dir}")
|
||||
|
||||
@@ -1,44 +0,0 @@
|
||||
# Copyright 2025 the LlamaFactory team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
from ..config.sample_args import SampleArguments, SampleBackend
|
||||
from .model_loader import ModelLoader
|
||||
|
||||
|
||||
class BaseEngine(ABC):
|
||||
@abstractmethod
|
||||
def __init__(self, sample_args: SampleArguments, model_loader: ModelLoader) -> None: ...
|
||||
|
||||
@abstractmethod
|
||||
async def generate(self):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def batch_infer(self):
|
||||
pass
|
||||
|
||||
|
||||
class HuggingFaceEngine(BaseEngine):
|
||||
def __init__(self, model_loader: ModelLoader, sample_args: SampleArguments) -> None:
|
||||
self.args = sample_args
|
||||
|
||||
|
||||
class ChatSampler:
|
||||
def __init__(self, model_loader: ModelLoader, sample_args: SampleArguments) -> None:
|
||||
if sample_args.sample_backend == SampleBackend.HF:
|
||||
self.engine = HuggingFaceEngine(model_loader, sample_args)
|
||||
else:
|
||||
raise ValueError(f"Unknown sample backend: {sample_args.sample_backend}")
|
||||
@@ -14,15 +14,23 @@
|
||||
|
||||
"""The definition of data engine.
|
||||
|
||||
Init Data engine:
|
||||
How to use:
|
||||
data_engine = DataEngine(data_args.train_dataset)
|
||||
data_engine[i]: Get the sample via index.
|
||||
|
||||
Init workflow:
|
||||
1. Parse dataset info from arguments.
|
||||
2. Load datasets according to dataset info.
|
||||
3. Build data index (and reweight samples if necessary).
|
||||
|
||||
Get Data Sample:
|
||||
Get data sample:
|
||||
1. Get sample from data index.
|
||||
2. Convert sample to standard format.
|
||||
3. Return sample.
|
||||
|
||||
Note:
|
||||
1. The data engine is equivalent to the torch dataset.
|
||||
2. The data engine is agnostic to the model used.
|
||||
"""
|
||||
|
||||
import os
|
||||
@@ -33,7 +41,6 @@ from huggingface_hub import hf_hub_download
|
||||
from omegaconf import OmegaConf
|
||||
from torch.utils.data import Dataset
|
||||
|
||||
from ..config.data_args import DataArguments
|
||||
from ..utils.types import DatasetInfo, HFDataset, Sample
|
||||
|
||||
|
||||
@@ -44,9 +51,9 @@ class DataEngine(Dataset):
|
||||
data_args: Data arguments.
|
||||
"""
|
||||
|
||||
def __init__(self, data_args: DataArguments) -> None:
|
||||
self.args = data_args
|
||||
"""Data arguments."""
|
||||
def __init__(self, dataset_path: str) -> None:
|
||||
self.path = dataset_path
|
||||
"""Dataset path."""
|
||||
self.datasets: dict[str, HFDataset] = {}
|
||||
"""Dict of (dataset_name, dataset)"""
|
||||
self.dataset_infos: dict[str, DatasetInfo] = {}
|
||||
@@ -61,27 +68,30 @@ class DataEngine(Dataset):
|
||||
|
||||
def _get_dataset_info(self) -> None:
|
||||
"""Get dataset info from data arguments."""
|
||||
if self.args.dataset.endswith(".yaml") and os.path.isfile(self.args.dataset): # local file
|
||||
self.dataset_infos = OmegaConf.load(self.args.dataset)
|
||||
elif self.args.dataset.endswith(".yaml"): # hf hub uri, e.g. llamafactory/v1-sft-demo/dataset_info.yaml
|
||||
repo_id, filename = os.path.split(self.args.dataset)
|
||||
if self.path.endswith(".yaml") and os.path.isfile(self.path): # local file
|
||||
self.dataset_infos = OmegaConf.load(self.path)
|
||||
elif self.path.endswith(".yaml"): # hf hub uri, e.g. llamafactory/v1-sft-demo/dataset_info.yaml
|
||||
repo_id, filename = os.path.split(self.path)
|
||||
filepath = hf_hub_download(repo_id=repo_id, filename=filename, repo_type="dataset")
|
||||
self.dataset_infos = OmegaConf.load(filepath)
|
||||
elif os.path.exists(self.args.dataset): # local file(s)
|
||||
self.dataset_infos = {"default": {"path": self.args.dataset, "source": "local"}}
|
||||
elif os.path.exists(self.path): # local file(s)
|
||||
self.dataset_infos = {"default": {"path": self.path, "source": "local"}}
|
||||
else: # hf hub dataset, e.g. llamafactory/v1-sft-demo
|
||||
self.dataset_infos = {"default": {"path": self.args.dataset}}
|
||||
self.dataset_infos = {"default": {"path": self.path}}
|
||||
|
||||
def _load_dataset(self) -> None:
|
||||
"""Load datasets according to dataset info."""
|
||||
is_streaming = [dataset_info.get("streaming", False) for dataset_info in self.dataset_infos.values()]
|
||||
self.streaming = any(is_streaming)
|
||||
if all(is_streaming) != any(is_streaming):
|
||||
raise ValueError("All datasets must be streaming or non-streaming.")
|
||||
|
||||
for dataset_name, dataset_info in self.dataset_infos.items():
|
||||
split = dataset_info.get("split", "train")
|
||||
streaming = dataset_info.get("streaming", False)
|
||||
self.streaming |= streaming
|
||||
if dataset_info.get("source", "hf_hub") == "hf_hub":
|
||||
from datasets import load_dataset
|
||||
|
||||
self.datasets[dataset_name] = load_dataset(dataset_info["path"], split=split, streaming=streaming)
|
||||
self.datasets[dataset_name] = load_dataset(dataset_info["path"], split=split, streaming=self.streaming)
|
||||
else: # data loader plugin
|
||||
from ..plugins.data_plugins.loader import DataLoaderPlugin
|
||||
|
||||
@@ -90,18 +100,17 @@ class DataEngine(Dataset):
|
||||
def _build_data_index(self) -> None:
|
||||
"""Build dataset index."""
|
||||
for dataset_name, dataset in self.datasets.items():
|
||||
streaming = self.dataset_infos[dataset_name].get("streaming", False)
|
||||
if streaming:
|
||||
if self.streaming:
|
||||
data_index = [(dataset_name, -1) for _ in range(1000)]
|
||||
else:
|
||||
data_index = [(dataset_name, sample_index) for sample_index in range(len(dataset))]
|
||||
|
||||
size = self.dataset_infos[dataset_name].get("size")
|
||||
weight = self.dataset_infos[dataset_name].get("weight")
|
||||
if size or weight: # data index plugin
|
||||
from ..plugins.data_plugins.loader import DataIndexPlugin
|
||||
if size or weight:
|
||||
from ..plugins.data_plugins.loader import adjust_data_index
|
||||
|
||||
data_index = DataIndexPlugin().adjust_data_index(data_index, size, weight)
|
||||
data_index = adjust_data_index(data_index, size, weight)
|
||||
|
||||
self.data_index.extend(data_index)
|
||||
|
||||
@@ -150,9 +159,9 @@ class DataEngine(Dataset):
|
||||
dataset_name, sample_index = self.data_index[index]
|
||||
return self._convert_data_sample(self.datasets[dataset_name][sample_index], dataset_name)
|
||||
else: # data selector plugin
|
||||
from ..plugins.data_plugins.loader import DataSelectorPlugin
|
||||
from ..plugins.data_plugins.loader import select_data_sample
|
||||
|
||||
selected_index = DataSelectorPlugin().select(self.data_index, index)
|
||||
selected_index = select_data_sample(self.data_index, index)
|
||||
if isinstance(selected_index, list):
|
||||
return [
|
||||
self._convert_data_sample(self.datasets[dataset_name][sample_index], dataset_name)
|
||||
@@ -177,11 +186,11 @@ class DataEngine(Dataset):
|
||||
|
||||
if __name__ == "__main__":
|
||||
"""
|
||||
python -m llamafactory.v1.core.data_engine --model none --dataset data/v1_sft_demo.yaml
|
||||
python -m llamafactory.v1.core.data_engine --model none --dataset data/v1_dpo_demo.yaml
|
||||
python -m llamafactory.v1.core.data_engine --train_dataset data/v1_sft_demo.yaml
|
||||
python -m llamafactory.v1.core.data_engine --train_dataset data/v1_dpo_demo.yaml
|
||||
"""
|
||||
from ..config.arg_parser import get_args
|
||||
|
||||
data_args, *_ = get_args()
|
||||
data_engine = DataEngine(data_args=data_args)
|
||||
_, data_args, *_ = get_args()
|
||||
data_engine = DataEngine(data_args.train_dataset)
|
||||
print(data_engine[0])
|
||||
|
||||
@@ -12,34 +12,44 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""The definition of model loader.
|
||||
"""The definition of model engine.
|
||||
|
||||
Init Phase:
|
||||
How to use:
|
||||
model_engine = ModelEngine(model_args, is_train=True)
|
||||
model_engine.processor: Get the tokenizer or multi-modal processor.
|
||||
model_engine.renderer: Get the renderer.
|
||||
model_engine.model_config: Get the model configuration.
|
||||
model_engine.model: Get the HF model.
|
||||
|
||||
Init workflow:
|
||||
1. Init processor.
|
||||
2. Init renderer.
|
||||
3. Init model config.
|
||||
4. Init model.
|
||||
5. Init adapter.
|
||||
|
||||
"""
|
||||
|
||||
import torch
|
||||
from accelerate import init_empty_weights
|
||||
from transformers import AutoConfig, AutoProcessor
|
||||
|
||||
from ..accelerator.helper import DeviceType
|
||||
from ..accelerator.interface import DistributedInterface
|
||||
from ..config.model_args import ModelArguments, ModelClass
|
||||
from ..utils import logging
|
||||
from ..utils.types import HFConfig, HFModel, Processor
|
||||
from .utils.rendering import Renderer
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
class ModelLoader:
|
||||
"""Model loader.
|
||||
class ModelEngine:
|
||||
"""Model engine.
|
||||
|
||||
Args:
|
||||
model_args: Model arguments.
|
||||
is_trainable: Whether to train the model.
|
||||
is_train: Whether to train the model.
|
||||
"""
|
||||
|
||||
def __init__(self, model_args: ModelArguments, is_train: bool = False) -> None:
|
||||
@@ -49,17 +59,22 @@ class ModelLoader:
|
||||
"""Whether to train the model."""
|
||||
self.processor = self._init_processor()
|
||||
"""Tokenizer or multi-modal processor."""
|
||||
self.renderer = Renderer(self.args.template, self.processor)
|
||||
"""Renderer."""
|
||||
self.model_config = self._init_model_config()
|
||||
"""Model configuration."""
|
||||
self.model = self._init_model()
|
||||
"""HF model."""
|
||||
|
||||
def _init_processor(self) -> Processor:
|
||||
"""Init processor."""
|
||||
"""Init processor.
|
||||
|
||||
NOTE: Transformers v5 always uses a fast tokenizer.
|
||||
https://github.com/huggingface/transformers/blob/v5.0.0rc1/src/transformers/models/auto/tokenization_auto.py#L642
|
||||
"""
|
||||
return AutoProcessor.from_pretrained(
|
||||
self.args.model,
|
||||
trust_remote_code=self.args.trust_remote_code,
|
||||
use_fast=self.args.use_fast_processor,
|
||||
)
|
||||
|
||||
def _init_model_config(self) -> HFConfig:
|
||||
@@ -72,7 +87,7 @@ class ModelLoader:
|
||||
def _init_model(self) -> HFModel:
|
||||
"""Init model.
|
||||
|
||||
Let transformers handle the model init context.
|
||||
Transformers can choose the proper model init context.
|
||||
https://github.com/huggingface/transformers/blob/v5.0.0rc0/src/transformers/modeling_utils.py#L3538
|
||||
"""
|
||||
if self.args.model_class == ModelClass.LLM:
|
||||
@@ -92,14 +107,24 @@ class ModelLoader:
|
||||
|
||||
AutoClass = AutoModel
|
||||
|
||||
# map the entire model to the current accelerator
|
||||
model = AutoClass.from_pretrained(
|
||||
self.args.model,
|
||||
config=self.model_config,
|
||||
dtype="auto",
|
||||
device_map=DistributedInterface().current_accelerator,
|
||||
trust_remote_code=self.args.trust_remote_code,
|
||||
)
|
||||
if self.args.init_config is not None:
|
||||
from ..plugins.model_plugins.initialization import InitPlugin
|
||||
|
||||
init_device = InitPlugin(self.args.init_config.name)()
|
||||
else:
|
||||
init_device = DistributedInterface().current_device
|
||||
|
||||
if init_device.type == DeviceType.META:
|
||||
with init_empty_weights():
|
||||
model = AutoClass.from_config(self.model_config)
|
||||
else:
|
||||
model = AutoClass.from_pretrained(
|
||||
self.args.model,
|
||||
config=self.model_config,
|
||||
dtype="auto",
|
||||
device_map=init_device,
|
||||
trust_remote_code=self.args.trust_remote_code,
|
||||
)
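# In short: a meta init device defers weight allocation (init_empty_weights +
# from_config), while a concrete device loads checkpoint weights directly onto
# that accelerator via from_pretrained with dtype="auto".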
|
||||
|
||||
if self.args.peft_config is None:
|
||||
if self.is_train:
|
||||
@@ -116,7 +141,7 @@ class ModelLoader:
|
||||
from ..plugins.model_plugins.kernels.interface import KernelPlugin
|
||||
|
||||
model = KernelPlugin(self.args.kernel_config.name)(
|
||||
model=model, include_kernels=self.args.kernel_config.get("include_kernels")
|
||||
model, include_kernels=self.args.kernel_config.get("include_kernels")
|
||||
)
|
||||
|
||||
return model
|
||||
@@ -124,12 +149,12 @@ class ModelLoader:
|
||||
|
||||
if __name__ == "__main__":
|
||||
"""
|
||||
python -m llamafactory.v1.core.model_loader --model llamafactory/tiny-random-qwen2.5
|
||||
python -m llamafactory.v1.core.model_engine --model llamafactory/tiny-random-qwen2.5
|
||||
"""
|
||||
from ..config.arg_parser import get_args
|
||||
|
||||
_, model_args, *_ = get_args()
|
||||
model_loader = ModelLoader(model_args=model_args)
|
||||
print(model_loader.processor)
|
||||
print(model_loader.model_config)
|
||||
print(model_loader.model)
|
||||
model_args, *_ = get_args()
|
||||
model_engine = ModelEngine(model_args=model_args)
|
||||
print(model_engine.processor)
|
||||
print(model_engine.model_config)
|
||||
print(model_engine.model)
|
||||
@@ -1,119 +0,0 @@
|
||||
# Copyright 2025 the LlamaFactory team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from collections import defaultdict
|
||||
from collections.abc import Sequence
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch.nn.utils.rnn import pad_sequence
|
||||
from torch.utils.data._utils.collate import default_collate
|
||||
|
||||
from ....extras.constants import IGNORE_INDEX
|
||||
from ...plugins.data_plugins.template import Template
|
||||
from ...utils.types import Processor, Tensor
|
||||
|
||||
|
||||
def len2culen(seqlens: "torch.Tensor") -> "torch.Tensor": # FIXME move to utils
|
||||
"""Convert sequence lengths to cumulative sequence lengths."""
|
||||
return F.pad(torch.cumsum(seqlens, dim=0), (1, 0)).type(torch.int32)
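# e.g. seqlens = tensor([3, 2, 4]) -> tensor([0, 3, 5, 9], dtype=int32), i.e. the
# cumulative offsets used by packed-sequence (varlen) attention.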
|
||||
|
||||
|
||||
class DataCollator:
|
||||
"""Default Data collator."""
|
||||
|
||||
processor: "Processor" # processor name -> map to encode_messages function
|
||||
|
||||
def __post_init__(self):
|
||||
# use the underlying text tokenizer if the processor wraps one
|
||||
self.tokenizer = self.processor.tokenizer if hasattr(self.processor, "tokenizer") else self.processor
|
||||
|
||||
def __call__(self, features: list[dict[str, Any]]) -> dict[str, Tensor]:
|
||||
"""Collate features into a batch."""
|
||||
batch = defaultdict(list)
|
||||
|
||||
# batching features
|
||||
for feature in features:
|
||||
for key in feature.keys():
|
||||
batch[key].append(feature[key])
|
||||
|
||||
for key in batch.keys():
|
||||
# process padding features
|
||||
if key in ["input_ids", "attention_mask", "position_ids"]:
|
||||
padding_value = self.tokenizer.pad_token_id if key == "input_ids" else 0
|
||||
batch[key] = pad_sequence(batch[key], batch_first=True, padding_value=padding_value)
|
||||
elif key in ["labels"]:
|
||||
batch[key] = pad_sequence(batch[key], batch_first=True, padding_value=IGNORE_INDEX)
|
||||
else:
|
||||
batch[key] = default_collate(batch[key])
|
||||
|
||||
return batch
|
||||
# sft: messages
|
||||
# dpo: chosen_messages, rejected_messages
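# Padding sketch (assuming pad_token_id = 0): input_ids [[1, 2, 3], [4, 5]] collate to
# tensor([[1, 2, 3], [4, 5, 0]]); attention_mask/position_ids pad with 0 and labels pad
# with IGNORE_INDEX, so padded positions are excluded from the loss.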
|
||||
|
||||
|
||||
@dataclass
|
||||
class DefaultCollator(DataCollator):
|
||||
"""Example for now."""
|
||||
|
||||
processor: "Processor" # processor name -> map to encode_messages function
|
||||
template: "Template"
|
||||
|
||||
def __call__(self, messages: list[list[dict[str, Any]]]) -> dict[str, Tensor]:
|
||||
features = []
|
||||
|
||||
# Check if data is already tokenized (contains input_ids)
|
||||
if messages and isinstance(messages[0], dict) and "input_ids" in messages[0]:
|
||||
for feature in messages:
|
||||
if not isinstance(feature, dict):
|
||||
raise ValueError(f"Expected dict but got {type(feature)}")
|
||||
tensor_feature = {
|
||||
k: torch.tensor(v, dtype=torch.long) if not isinstance(v, torch.Tensor) else v
|
||||
for k, v in feature.items()
|
||||
}
|
||||
features.append(tensor_feature)
|
||||
else:
|
||||
# raw messages need to be encoded
|
||||
for message in messages:
|
||||
encoded_message = self.template.encode_messages(self.tokenizer, message)
|
||||
encoded_message = {k: torch.tensor(v, dtype=torch.long) for k, v in encoded_message.items()}
|
||||
features.append(encoded_message)
|
||||
|
||||
return super().__call__(features)
|
||||
|
||||
|
||||
@dataclass
|
||||
class PairwiseCollator(DataCollator):
|
||||
pass
|
||||
|
||||
|
||||
@dataclass
|
||||
class DataCollatorWithPacking(DefaultCollator):
|
||||
"""Data collator with packing."""
|
||||
|
||||
processor: "Processor"
|
||||
template: "Template"
|
||||
|
||||
def __call__(self, features: Sequence[dict[str, "torch.Tensor"]]) -> dict[str, "torch.Tensor"]:
|
||||
seqlens = torch.tensor([len(feature["input_ids"]) for feature in features], dtype=torch.long)
|
||||
batch = {"cu_seqlens": len2culen(seqlens)}
|
||||
for input_name in features[0].keys():
|
||||
if input_name in ("input_ids", "attention_mask", "labels"):
|
||||
batch[input_name] = torch.cat([feature[input_name] for feature in features])
|
||||
else:
|
||||
batch[input_name] = default_collate([feature[input_name] for feature in features])
|
||||
|
||||
return batch
|
||||
@@ -1,277 +0,0 @@
|
||||
# Copyright 2025 the LlamaFactory team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
import copy
|
||||
import sys
|
||||
from collections.abc import Generator, Iterator
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional
|
||||
|
||||
from torchdata.stateful_dataloader import StatefulDataLoader
|
||||
from torchdata.stateful_dataloader.sampler import StatefulDistributedSampler
|
||||
|
||||
from ...utils.batching_queue import BaseBatchingQueue
|
||||
from ...utils.logging import get_logger
|
||||
from ...utils.types import Processor, TorchDataset
|
||||
from .data_collator import DataCollator
|
||||
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
# base dataloader
|
||||
class DistributedDataloader(StatefulDataLoader):
|
||||
"""Base Distributed DataLoader."""
|
||||
|
||||
dataset: "TorchDataset"
|
||||
sampler: "StatefulDistributedSampler"
|
||||
|
||||
def set_epoch(self, epoch: int) -> None:
|
||||
if self.sampler is not None and hasattr(self.sampler, "set_epoch"):
|
||||
self.sampler.set_epoch(epoch)
|
||||
elif hasattr(self.dataset, "set_epoch"):
|
||||
self.dataset.set_epoch(epoch)
|
||||
|
||||
|
||||
@dataclass
|
||||
class BaseDataLoader:
|
||||
"""Default DataLoader."""
|
||||
|
||||
processor: Processor
|
||||
|
||||
def __init__(self, dataset: TorchDataset) -> None:
|
||||
self.dataset = dataset
|
||||
# guidelines: fetch until a fixed batch size is obtained.
|
||||
# save state_dict for buffer.
|
||||
# resume with state
|
||||
|
||||
# 1. Init stateful dataloader (tokenize)
|
||||
# 2. Add to buffer (2 * max seq len per device)
|
||||
# 3. Yield batch indexes (micro batch * grad acc)
|
||||
# a ) non pack + non dynamic
|
||||
# b ) non pack + dynamic
|
||||
# c ) pack + non dynamic
|
||||
# d ) pack + dynamic
|
||||
|
||||
def init_dataloader(self) -> None:
|
||||
### init dataloader
|
||||
pass
|
||||
|
||||
def __iter__(self) -> Iterator:
|
||||
pass
|
||||
|
||||
def __next__(self) -> any:
|
||||
pass
|
||||
|
||||
|
||||
@dataclass
|
||||
class DataLoader:
|
||||
"""Default DataLoader."""
|
||||
|
||||
processor: "Processor"
|
||||
dataloader: "DistributedDataloader"
|
||||
batching_queue: "BaseBatchingQueue"
|
||||
collate_fn: "DataCollator"
|
||||
num_micro_batch: int = 1
|
||||
length: int = 0
|
||||
drop_last: bool = True
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dataloader: any,
|
||||
collate_fn: "DataCollator",
|
||||
num_micro_batch: int = 1,
|
||||
length: int = 0,
|
||||
drop_last: bool = True,
|
||||
batching_queue: Optional["BaseBatchingQueue"] = None,
|
||||
) -> None:
|
||||
self.batching_queue = batching_queue
|
||||
self.num_micro_batch = num_micro_batch
|
||||
self.step = 0
|
||||
self._collate_fn = collate_fn
|
||||
self._dataloader = dataloader
|
||||
self._drop_last = drop_last
|
||||
self._data_iter: Iterator
|
||||
self._resume = False
|
||||
self._batch_data_iter: Generator
|
||||
|
||||
if length > 0:
|
||||
self._length = length
|
||||
elif length == -1:
|
||||
self._length = sys.maxsize
|
||||
else:
|
||||
self._length = len(self._dataloader)
|
||||
|
||||
def __len__(self):
|
||||
return self._length
|
||||
|
||||
def __iter__(self) -> Iterator:
|
||||
if not self._resume:
|
||||
self.step = 0
|
||||
self._data_iter = iter(self._dataloader)
|
||||
self._batch_data_iter = self.batch_data_generator()
|
||||
self._resume = False
|
||||
return self
|
||||
|
||||
def __next__(self):
|
||||
return next(self._batch_data_iter) # FIXME maybe we can move origin_batch_data_generator to here
|
||||
|
||||
def origin_batch_data_generator(self):
|
||||
"""Standard pass-through generator if do not use batching queue."""
|
||||
while True:
|
||||
if self._length > 0 and self.step >= self._length:
|
||||
return
|
||||
|
||||
try:
|
||||
batch = []
|
||||
data = next(self._data_iter)
|
||||
# split data into micro batches
|
||||
for i in range(0, len(data), self.num_micro_batch):
|
||||
micro_batch = data[i : i + self.num_micro_batch]
|
||||
if self._collate_fn:
|
||||
micro_batch = self._collate_fn(micro_batch)
|
||||
batch.append(micro_batch)
|
||||
yield batch
|
||||
self.step += 1
|
||||
except StopIteration:
|
||||
if self.step < self._length:
|
||||
# Restart iterator to fill the requested length
|
||||
self._data_iter = iter(self._dataloader)
|
||||
try:
|
||||
batch = []
|
||||
data = next(self._data_iter)
|
||||
for i in range(0, len(data), self.num_micro_batch):
|
||||
micro_batch = data[i : i + self.num_micro_batch]
|
||||
if self._collate_fn:
|
||||
micro_batch = self._collate_fn(micro_batch)
|
||||
batch.append(micro_batch)
|
||||
yield batch
|
||||
self.step += 1
|
||||
except StopIteration:
|
||||
return
|
||||
else:
|
||||
return
|
||||
except Exception as e:
|
||||
logger.error(f"DataLoader origin_batch_data_generator exception: {e}")
|
||||
raise
|
||||
|
||||
def batch_data_generator(self):
|
||||
if self.batching_queue is None:
|
||||
yield from self.origin_batch_data_generator()
|
||||
return
|
||||
|
||||
batch = []
|
||||
|
||||
while True:
|
||||
if self._length and self.step >= self._length:
|
||||
return
|
||||
|
||||
if self.batching_queue.is_full_filled():
|
||||
micro_batch = self.batching_queue.get_micro_batch(self.step)
|
||||
if self._collate_fn:
|
||||
micro_batch = self._collate_fn(micro_batch)
|
||||
batch.append(micro_batch)
|
||||
if len(batch) == self.num_micro_batch:
|
||||
yield batch
|
||||
self.step += 1
|
||||
batch = []
|
||||
|
||||
try:
|
||||
processing_item = next(self._data_iter)
|
||||
except Exception as e:
|
||||
if isinstance(e, StopIteration):
|
||||
if self.step < self._length:
|
||||
# call iter until reach length
|
||||
self._data_iter = iter(self._dataloader)
|
||||
processing_item = next(self._data_iter)
|
||||
elif not self._drop_last and not self.batching_queue.empty():
|
||||
while not self.batching_queue.empty():
|
||||
micro_batch = self.batching_queue.get_micro_batch(self.step)
|
||||
if self._collate_fn:
|
||||
micro_batch = self._collate_fn(micro_batch)
|
||||
batch.append(micro_batch)
|
||||
if len(batch) == self.num_micro_batch:
|
||||
yield batch
|
||||
self.step += 1
|
||||
batch = []
|
||||
|
||||
while len(batch) < self.num_micro_batch:
|
||||
padding_batch = copy.deepcopy(micro_batch)
|
||||
padding_batch["is_padded"] = True
|
||||
batch.append(padding_batch)
|
||||
yield batch
|
||||
self.step += 1
|
||||
return
|
||||
else:
|
||||
return
|
||||
else:
|
||||
logger.error(f"DataLoader iter data exception: {e}")
|
||||
raise
|
||||
|
||||
# put processing_item to buffer
|
||||
if isinstance(processing_item, dict):
|
||||
processing_item = [processing_item]
|
||||
|
||||
for item in processing_item:
|
||||
self.batching_queue.put_item(item)
|
||||
|
||||
def state_dict(self):
|
||||
# save state
|
||||
state = self.__dict__.copy()
|
||||
# remove internal fields
|
||||
for k in list(state.keys()):
|
||||
if k.startswith("_"):
|
||||
del state[k]
|
||||
|
||||
# save dataloader state
|
||||
if hasattr(self._dataloader, "state_dict"):
|
||||
state["dataloader_state"] = self._dataloader.state_dict()
|
||||
elif hasattr(self._dataloader, "__getstate__"):
|
||||
state["dataloader_state"] = self._dataloader.__getstate__()
|
||||
|
||||
batching_strategy = getattr(self, "batching_strategy", None)
|
||||
if batching_strategy and hasattr(batching_strategy, "state_dict"):
|
||||
state["batching_strategy_state"] = batching_strategy.state_dict()
|
||||
if "batching_strategy" in state:
|
||||
del state["batching_strategy"]
|
||||
|
||||
return copy.deepcopy(state)
|
||||
|
||||
def load_state_dict(self, state: dict[str, any]):
|
||||
if state["num_micro_batch"] != self.num_micro_batch:
|
||||
logger.warning(
|
||||
f"num_micro_batch changed: [ {state['num_micro_batch']} -> {self.num_micro_batch} ], will clear prefetch buffer"
|
||||
)
|
||||
del state["num_micro_batch"]
|
||||
self.__dict__.update(state)
|
||||
self._resume = True
|
||||
|
||||
if hasattr(self._dataloader, "load_state_dict"):
|
||||
self._dataloader.load_state_dict(state["dataloader_state"])
|
||||
elif hasattr(self._dataloader, "__getstate__"):
|
||||
self._dataloader.__setstate__(state["dataloader_state"])
|
||||
|
||||
if "batching_strategy_state" in state:
|
||||
batching_strategy = getattr(self, "batching_strategy", None)
|
||||
if batching_strategy:
|
||||
batching_strategy.load_state_dict(state["batching_strategy_state"])
|
||||
del state["batching_strategy_state"]
|
||||
|
||||
self._data_iter = iter(self._dataloader)
|
||||
self._batch_data_iter = self.batch_data_generator()
|
||||
|
||||
def set_epoch(self, epoch: int) -> None:
|
||||
if hasattr(self._dataloader, "set_epoch"):
|
||||
self._dataloader.set_epoch(epoch)
|
||||
src/llamafactory/v1/core/utils/batching.py (new file, 244 lines)
@@ -0,0 +1,244 @@
|
||||
# Copyright 2025 the LlamaFactory team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Batching utils supports stateful dataloader.
|
||||
|
||||
1. Init stateful dataloader (tokenize)
|
||||
2. Add to buffer
|
||||
3. Yield batch indexes (micro batch * grad acc)
|
||||
a) non pack + non dynamic
|
||||
b) non pack + dynamic
|
||||
c) pack + non dynamic
|
||||
d) pack + dynamic
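
For example, per the divisibility check in BatchGenerator below: with dp_size = 2,
micro_batch_size = 2 and global_batch_size = 8, num_micro_batch = 8 // 2 // 2 = 2,
so each rank yields 2 micro-batches of 2 samples per step and one optimizer step
covers 2 * 2 * 2 = 8 samples globally.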
|
||||
"""
|
||||
|
||||
from collections.abc import Iterator
|
||||
from typing import Any
|
||||
|
||||
from torch.utils.data import default_collate
|
||||
from torchdata.stateful_dataloader import StatefulDataLoader
|
||||
from torchdata.stateful_dataloader.sampler import StatefulDistributedSampler
|
||||
|
||||
from ...accelerator.interface import Dim, DistributedInterface
|
||||
from ...config import BatchingStrategy
|
||||
from ...utils import logging
|
||||
from ...utils.helper import pad_and_truncate
|
||||
from ...utils.objects import StatefulBuffer
|
||||
from ...utils.types import BatchInfo, BatchInput, ModelInput, TorchDataset
|
||||
from .rendering import Renderer
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
def default_collate_fn(buffer: StatefulBuffer, batch_info: BatchInfo) -> list[BatchInput] | None:
|
||||
micro_batch_size = batch_info["micro_batch_size"]
|
||||
num_micro_batch = batch_info["num_micro_batch"]
|
||||
cutoff_len = batch_info["cutoff_len"]
|
||||
batch_size = micro_batch_size * num_micro_batch
|
||||
if len(buffer) < batch_size:
|
||||
return None
|
||||
|
||||
samples = buffer.get(batch_size)
|
||||
batch = []
|
||||
for i in range(num_micro_batch):
|
||||
micro_batch = samples[i * micro_batch_size : (i + 1) * micro_batch_size]
|
||||
batch.append(default_collate(pad_and_truncate(micro_batch, cutoff_len)))
|
||||
|
||||
return batch
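# Shape sketch (assuming micro_batch_size=2, num_micro_batch=2): 4 samples are taken
# from the buffer, padded/truncated via pad_and_truncate, and returned as a list of
# 2 collated micro-batches, each with input_ids of shape (2, <=cutoff_len).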
|
||||
|
||||
|
||||
class BatchGenerator(Iterator):
|
||||
def __init__(
|
||||
self,
|
||||
dataset: TorchDataset,
|
||||
renderer: Renderer,
|
||||
micro_batch_size: int = 1,
|
||||
global_batch_size: int | None = None,
|
||||
cutoff_len: int = 2048,
|
||||
batching_workers: int = 0,
|
||||
batching_strategy: BatchingStrategy = BatchingStrategy.NORMAL,
|
||||
pin_memory: bool = True,
|
||||
drop_last: bool = True,
|
||||
) -> None:
|
||||
self.dataset = dataset
|
||||
self.renderer = renderer
|
||||
|
||||
self.micro_batch_size = micro_batch_size
|
||||
self.global_batch_size = global_batch_size
|
||||
self.cutoff_len = cutoff_len
|
||||
self.batching_workers = batching_workers
|
||||
self.batching_strategy = batching_strategy
|
||||
self.pin_memory = pin_memory
|
||||
self.drop_last = drop_last
|
||||
# TODO: support length and infinity
|
||||
dp_size = DistributedInterface().get_world_size(Dim.DP)
|
||||
|
||||
if self.global_batch_size is None:
|
||||
self.global_batch_size = dp_size * micro_batch_size
|
||||
self.num_micro_batch = 1
|
||||
elif self.global_batch_size % (dp_size * micro_batch_size) == 0:
|
||||
self.num_micro_batch = global_batch_size // dp_size // micro_batch_size
|
||||
else:
|
||||
raise ValueError(
|
||||
"Global batch size must be divisible by DP size and micro batch size. "
|
||||
f"Got {global_batch_size} % ({dp_size} * {micro_batch_size}) != 0."
|
||||
)
|
||||
|
||||
if not self.drop_last:
|
||||
raise ValueError("Drop last must be True.")
|
||||
|
||||
self._init_data_provider()
|
||||
|
||||
self._is_resuming: bool = False
|
||||
self._data_iter = iter(self._data_provider)
|
||||
self._buffer = StatefulBuffer()
|
||||
|
||||
self._batch_info: BatchInfo = {
|
||||
"micro_batch_size": self.micro_batch_size,
|
||||
"num_micro_batch": self.num_micro_batch,
|
||||
"cutoff_len": self.cutoff_len,
|
||||
"data_iter": self._data_iter,
|
||||
}
|
||||
|
||||
logger.info_rank0(
|
||||
f"Init unified data loader with global batch size {self.global_batch_size}, "
|
||||
f"micro batch size {self.micro_batch_size}, "
|
||||
f"num micro batch {self.num_micro_batch}, "
|
||||
f"cutoff len {self.cutoff_len}, "
|
||||
f"batching workers {self.batching_workers}, "
|
||||
f"batching strategy {self.batching_strategy}."
|
||||
)
|
||||
|
||||
def _init_data_provider(self) -> None:
|
||||
if len(self.dataset) != -1:
|
||||
sampler = StatefulDistributedSampler(
|
||||
self.dataset,
|
||||
num_replicas=DistributedInterface().get_world_size(Dim.DP),
|
||||
rank=DistributedInterface().get_rank(Dim.DP),
|
||||
shuffle=True,
|
||||
seed=0,
|
||||
drop_last=self.drop_last,
|
||||
)
|
||||
else:
|
||||
raise NotImplementedError("Iterable dataset is not supported yet.")
|
||||
|
||||
self._data_provider = StatefulDataLoader(
|
||||
self.dataset,
|
||||
batch_size=self.micro_batch_size * self.num_micro_batch,
|
||||
sampler=sampler,
|
||||
num_workers=self.batching_workers,
|
||||
collate_fn=self.renderer.process_samples,
|
||||
pin_memory=self.pin_memory,
|
||||
pin_memory_device=DistributedInterface().current_device.type,
|
||||
drop_last=self.drop_last,
|
||||
)
|
||||
if self.batching_strategy == BatchingStrategy.NORMAL:
|
||||
self._length = len(self._data_provider)
|
||||
else:
|
||||
from ...plugins.trainer_plugins.batching import BatchingPlugin
|
||||
|
||||
self._length = BatchingPlugin(self.batching_strategy).compute_length(self._data_provider)
|
||||
raise NotImplementedError("Batching strategy other than NORMAL is not supported yet.")
|
||||
|
||||
def __len__(self) -> int:
|
||||
return self._length
|
||||
|
||||
def __iter__(self):
|
||||
if not self._is_resuming:
|
||||
self._buffer.clear()
|
||||
self._buffer_tokens = 0
|
||||
|
||||
self._data_iter = iter(self._data_provider)
|
||||
self._is_resuming = False
|
||||
return self
|
||||
|
||||
def __next__(self):
|
||||
self._fill_buffer()
|
||||
batch = self._generate_batch()
|
||||
if batch is None:
|
||||
raise StopIteration
|
||||
|
||||
return batch
|
||||
|
||||
def _fill_buffer(self) -> None:
|
||||
if self.batching_strategy == BatchingStrategy.NORMAL:
|
||||
while len(self._buffer) < self.micro_batch_size * self.num_micro_batch:
|
||||
try:
|
||||
samples: list[ModelInput] = next(self._data_iter)
|
||||
except StopIteration:
|
||||
break
|
||||
|
||||
self._buffer.put(samples)
|
||||
else:
|
||||
from ...plugins.trainer_plugins.batching import BatchingPlugin
|
||||
|
||||
BatchingPlugin(self.batching_strategy).fill_buffer(self._buffer, self._batch_info)
|
||||
|
||||
def _generate_batch(self) -> list[BatchInput] | None:
|
||||
if self.batching_strategy == BatchingStrategy.NORMAL:
|
||||
return default_collate_fn(self._buffer, self._batch_info)
|
||||
else:
|
||||
from ...plugins.trainer_plugins.batching import BatchingPlugin
|
||||
|
||||
return BatchingPlugin(self.batching_strategy).generate_batch(self._buffer, self._batch_info)
|
||||
|
||||
def state_dict(self) -> dict[str, Any]:
|
||||
return {
|
||||
"buffer": self._buffer,
|
||||
"buffer_tokens": self._buffer_tokens,
|
||||
"data_provider": self._data_provider.state_dict(),
|
||||
}
|
||||
|
||||
def load_state_dict(self, state: dict[str, Any]) -> None:
|
||||
self._buffer = state["buffer"]
|
||||
self._buffer_tokens = state["buffer_tokens"]
|
||||
self._data_provider.load_state_dict(state["data_provider"])
|
||||
self._is_resuming = True
|
||||
|
||||
def set_epoch(self, epoch: int) -> None:
|
||||
if hasattr(self._data_provider.sampler, "set_epoch"):
|
||||
self._data_provider.sampler.set_epoch(epoch)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
"""
|
||||
python -m llamafactory.v1.core.utils.batching \
|
||||
--model llamafactory/tiny-random-qwen2.5 \
|
||||
--train_dataset data/v1_sft_demo.yaml \
|
||||
--micro_batch_size 2 \
|
||||
--global_batch_size 4 \
|
||||
--batching_workers 0
|
||||
"""
|
||||
from ...config.arg_parser import get_args
|
||||
from ..data_engine import DataEngine
|
||||
from ..model_engine import ModelEngine
|
||||
|
||||
model_args, data_args, training_args, _ = get_args()
|
||||
data_engine = DataEngine(data_args.train_dataset)
|
||||
model_engine = ModelEngine(model_args=model_args)
|
||||
batch_generator = BatchGenerator(
|
||||
data_engine,
|
||||
model_engine.renderer,
|
||||
micro_batch_size=training_args.micro_batch_size,
|
||||
global_batch_size=training_args.global_batch_size,
|
||||
cutoff_len=training_args.cutoff_len,
|
||||
batching_workers=training_args.batching_workers,
|
||||
batching_strategy=training_args.batching_strategy,
|
||||
)
|
||||
for batch in batch_generator:
|
||||
print(batch)
|
||||
print(len(batch))
|
||||
print(batch[0]["input_ids"].shape)
|
||||
break
|
||||
src/llamafactory/v1/core/utils/inference_engine.py (new file, 121 lines)
@@ -0,0 +1,121 @@
|
||||
# Copyright 2025 the LlamaFactory team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import AsyncGenerator
|
||||
from threading import Thread
|
||||
|
||||
import torch
|
||||
from transformers import AsyncTextIteratorStreamer
|
||||
|
||||
from ...accelerator.interface import DistributedInterface
|
||||
from ...config import ModelArguments, SampleArguments
|
||||
from ...utils.helper import get_tokenizer
|
||||
from ...utils.types import HFModel, Message, Sample, TorchDataset
|
||||
from .rendering import Renderer
|
||||
|
||||
|
||||
class BaseEngine(ABC):
|
||||
@abstractmethod
|
||||
def __init__(
|
||||
self,
|
||||
args: SampleArguments,
|
||||
model_args: ModelArguments,
|
||||
model: HFModel,
|
||||
renderer: Renderer,
|
||||
) -> None:
|
||||
"""Initialize the engine.
|
||||
|
||||
Args:
|
||||
args: Sample arguments.
|
||||
model_args: Model arguments.
|
||||
model: Model.
|
||||
renderer: Renderer.
|
||||
"""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
async def generate(self, messages: list[Message], tools: str | None = None) -> AsyncGenerator[str, None]:
|
||||
"""Generate tokens asynchronously.
|
||||
|
||||
Args:
|
||||
messages: List of messages.
|
||||
tools: Tools string.
|
||||
|
||||
Yields:
|
||||
Generated tokens.
|
||||
"""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
async def batch_infer(self, dataset: TorchDataset) -> list[Sample]:
|
||||
"""Batch infer samples.
|
||||
|
||||
Args:
|
||||
dataset: Torch dataset.
|
||||
|
||||
Returns:
|
||||
List of samples.
|
||||
"""
|
||||
...
|
||||
|
||||
|
||||
class HuggingFaceEngine(BaseEngine):
|
||||
def __init__(
|
||||
self,
|
||||
args: SampleArguments,
|
||||
model_args: ModelArguments,
|
||||
model: HFModel,
|
||||
renderer: Renderer,
|
||||
) -> None:
|
||||
self.args = args
|
||||
self.model_args = model_args
|
||||
self.model = model
|
||||
self.renderer = renderer
|
||||
self.semaphore = asyncio.Semaphore(int(os.getenv("MAX_CONCURRENT", "1")))
|
||||
|
||||
@torch.inference_mode()
|
||||
async def generate(self, messages: list[Message], tools: str | None = None) -> AsyncGenerator[str, None]:
|
||||
async with self.semaphore:
|
||||
model_inputs = self.renderer.render_messages(messages, tools, is_generate=True)
|
||||
streamer = AsyncTextIteratorStreamer(
|
||||
tokenizer=get_tokenizer(self.renderer.processor),
|
||||
skip_prompt=True,
|
||||
skip_special_tokens=True, # TODO: configurable
|
||||
)
|
||||
device = DistributedInterface().current_device
|
||||
kwargs = {
|
||||
"input_ids": torch.tensor([model_inputs["input_ids"]]).to(device),
|
||||
"attention_mask": torch.tensor([model_inputs["attention_mask"]]).to(device),
|
||||
"max_new_tokens": self.args.max_new_tokens,
|
||||
"streamer": streamer,
|
||||
}
|
||||
thread = Thread(target=self.model.generate, kwargs=kwargs, daemon=True)
|
||||
thread.start()
|
||||
|
||||
async for token in streamer:
|
||||
yield token
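# Usage sketch (assuming a running event loop and an initialized engine):
#     async for token in engine.generate(
#         [{"role": "user", "content": [{"type": "text", "value": "Hi"}]}]
#     ):
#         print(token, end="", flush=True)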
|
||||
|
||||
async def batch_infer(self, dataset: TorchDataset) -> list[Sample]:
|
||||
"""Batch infer samples.
|
||||
|
||||
Args:
|
||||
dataset: Torch dataset.
|
||||
|
||||
Returns:
|
||||
List of samples.
|
||||
"""
|
||||
raise NotImplementedError("Batch infer is not implemented.")
|
||||
src/llamafactory/v1/core/utils/rendering.py (new file, 169 lines)
@@ -0,0 +1,169 @@
|
||||
# Copyright 2025 the LlamaFactory team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Rendering utils.
|
||||
|
||||
How to use:
|
||||
renderer = Renderer(template, processor)
|
||||
renderer.render_messages(messages: list[Message], tools: str | None) -> ModelInputs
|
||||
renderer.parse_message(text: str) -> Message
|
||||
renderer.process_samples(samples: list[Sample]) -> list[ModelInput]
|
||||
"""
|
||||
|
||||
import numpy as np
|
||||
|
||||
from ...utils.constants import IGNORE_INDEX
|
||||
from ...utils.helper import get_tokenizer
|
||||
from ...utils.types import Message, ModelInput, Processor, Sample
|
||||
|
||||
|
||||
def render_chatml_messages(
|
||||
processor: Processor,
|
||||
messages: list[Message],
|
||||
tools: str | None = None,
|
||||
is_generate: bool = False,
|
||||
) -> ModelInput:
|
||||
"""Apply chatml template to messages and convert them to model input.
|
||||
|
||||
See https://huggingface.co/spaces/huggingfacejs/chat-template-playground?modelId=Qwen/Qwen2-7B-Instruct
|
||||
"""
|
||||
tokenizer = get_tokenizer(processor)
|
||||
input_ids, labels, loss_weights = [], [], []
|
||||
|
||||
for message in messages:
|
||||
temp_str = "<|im_start|>" + message["role"] + "\n"
|
||||
for content in message["content"]:
|
||||
if content["type"] == "text":
|
||||
temp_str += content["value"]
|
||||
else:
|
||||
raise ValueError(f"Unsupported content type: {content['type']}")
|
||||
|
||||
temp_str += "<|im_end|>\n"
|
||||
temp_weight = message.get("loss_weight", 1.0 if message["role"] == "assistant" else 0.0)
|
||||
temp_ids = tokenizer.encode(temp_str, add_special_tokens=False)
|
||||
input_ids.extend(temp_ids)
|
||||
loss_weights.extend([temp_weight] * len(temp_ids))
|
||||
if temp_weight > 1e-6:
|
||||
labels.extend(temp_ids)
|
||||
else:
|
||||
labels.extend([IGNORE_INDEX] * len(temp_ids))
|
||||
|
||||
if is_generate:
|
||||
temp_ids = tokenizer.encode("<|im_start|>assistant\n", add_special_tokens=False)
|
||||
input_ids.extend(temp_ids)
|
||||
loss_weights.extend([0.0] * len(temp_ids))
|
||||
labels.extend([IGNORE_INDEX] * len(temp_ids))
|
||||
|
||||
return ModelInput(
|
||||
input_ids=input_ids,
|
||||
attention_mask=[1] * len(input_ids),
|
||||
labels=labels,
|
||||
loss_weights=loss_weights,
|
||||
)
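# Rendering sketch: [{"role": "user", "content": [{"type": "text", "value": "Hi"}]},
# {"role": "assistant", "content": [{"type": "text", "value": "Hello!"}]}] is encoded from
# "<|im_start|>user\nHi<|im_end|>\n<|im_start|>assistant\nHello!<|im_end|>\n", with labels
# kept only on the assistant span (loss_weight 1.0) and IGNORE_INDEX elsewhere; when
# is_generate=True, a trailing "<|im_start|>assistant\n" prompt is appended.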
|
||||
|
||||
|
||||
def parse_chatml_message(generated_text: str) -> Message:
|
||||
"""Parse a message in ChatML format.
|
||||
|
||||
Args:
|
||||
generated_text (str): The generated text in ChatML format.
|
||||
|
||||
Returns:
|
||||
Message: The parsed message.
|
||||
"""
|
||||
return Message(role="assistant", content=[{"type": "text", "value": generated_text}])
|
||||
|
||||
|
||||
class Renderer:
|
||||
def __init__(self, template: str, processor: Processor):
|
||||
self.template = template
|
||||
self.processor = processor
|
||||
|
||||
def render_messages(
|
||||
self, messages: list[Message], tools: str | None = None, is_generate: bool = False
|
||||
) -> ModelInput:
|
||||
"""Apply template to messages and convert them to model input.
|
||||
|
||||
Args:
|
||||
messages (list[Message]): The messages to render.
|
||||
tools (str | None, optional): The tools to use. Defaults to None.
|
||||
is_generate (bool, optional): Whether to render for generation. Defaults to False.
|
||||
|
||||
Returns:
|
||||
ModelInput: The rendered model input.
|
||||
"""
|
||||
if self.template == "chatml":
|
||||
return render_chatml_messages(self.processor, messages, tools, is_generate)
|
||||
else:
|
||||
from ...plugins.model_plugins.rendering import RenderingPlugin
|
||||
|
||||
return RenderingPlugin(self.template).render_messages(self.processor, messages, tools, is_generate)
|
||||
|
||||
def parse_message(self, generated_text: str) -> Message:
|
||||
"""Parse a message in the template format.
|
||||
|
||||
Args:
|
||||
generated_text (str): The generated text in the template format.
|
||||
|
||||
Returns:
|
||||
Message: The parsed message.
|
||||
"""
|
||||
if self.template == "chatml":
|
||||
return parse_chatml_message(generated_text)
|
||||
else:
|
||||
from ...plugins.model_plugins.rendering import RenderingPlugin
|
||||
|
||||
return RenderingPlugin(self.template).parse_message(generated_text)
|
||||
|
||||
def process_samples(self, samples: list[Sample]) -> list[ModelInput]:
|
||||
"""Process samples to model input.
|
||||
|
||||
Args:
|
||||
samples (list[Sample]): The samples to process.
|
||||
|
||||
Returns:
|
||||
list[ModelInput]: The processed model inputs.
|
||||
"""
|
||||
model_inputs = []
|
||||
for sample in samples:
|
||||
if "messages" in sample:
|
||||
model_input = self.render_messages(sample["messages"], sample.get("tools"))
|
||||
elif "chosen_messages" in sample and "rejected_messages" in sample:
|
||||
chosen_input = self.render_messages(sample["chosen_messages"], sample.get("tools"))
|
||||
rejected_input = self.render_messages(sample["rejected_messages"], sample.get("tools"))
|
||||
chosen_input["token_type_ids"] = [1] * len(chosen_input["input_ids"])
|
||||
rejected_input["token_type_ids"] = [2] * len(rejected_input["input_ids"])
|
||||
model_input = ModelInput(
|
||||
input_ids=chosen_input["input_ids"] + rejected_input["input_ids"],
|
||||
attention_mask=chosen_input["attention_mask"] + rejected_input["attention_mask"],
|
||||
labels=chosen_input["labels"] + rejected_input["labels"],
|
||||
loss_weights=chosen_input["loss_weights"] + rejected_input["loss_weights"],
|
||||
token_type_ids=chosen_input["token_type_ids"] + rejected_input["token_type_ids"],
|
||||
)
|
||||
if "position_ids" in chosen_input:
|
||||
model_input["position_ids"] = np.concatenate(
|
||||
[chosen_input["position_ids"], rejected_input["position_ids"]], axis=-1
|
||||
)
|
||||
else:
|
||||
raise ValueError("No valid messages or chosen_messages/rejected_messages found in sample.")
|
||||
|
||||
if "extra_info" in sample:
|
||||
model_input["extra_info"] = sample["extra_info"]
|
||||
|
||||
if "_dataset_name" in sample:
|
||||
model_input["_dataset_name"] = sample["_dataset_name"]
|
||||
|
||||
model_inputs.append(model_input)
|
||||
|
||||
return model_inputs
|
||||
@@ -12,9 +12,10 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import os
|
||||
import subprocess
|
||||
import sys
|
||||
|
||||
from ..extras.env import VERSION, print_env
|
||||
from copy import deepcopy
|
||||
|
||||
|
||||
USAGE = (
|
||||
@@ -27,40 +28,152 @@ USAGE = (
|
||||
+ "-" * 70
|
||||
)
|
||||
|
||||
|
||||
WELCOME = (
|
||||
"-" * 58
|
||||
+ "\n"
|
||||
+ f"| Welcome to LLaMA Factory, version {VERSION}"
|
||||
+ " " * (21 - len(VERSION))
|
||||
+ "|\n|"
|
||||
+ " " * 56
|
||||
+ "|\n"
|
||||
+ "| Project page: https://github.com/hiyouga/LLaMA-Factory |\n"
|
||||
+ "-" * 58
|
||||
)
|
||||
_DIST_TRAIN_COMMANDS = ("train", "sft", "dpo", "rm")
|
||||
|
||||
|
||||
def launch():
|
||||
from .accelerator.helper import get_device_count
|
||||
from .utils.env import find_available_port, is_env_enabled, use_kt, use_ray
|
||||
from .utils.logging import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
# NOTE:
|
||||
# `llamafactory-cli <command> ...` enters here first.
|
||||
# We may re-launch via `torchrun` for distributed training. In that case we must
|
||||
# forward `<command>` as argv[1] to the re-executed script, otherwise the script
|
||||
# will misinterpret the first user argument (e.g. yaml config) as the command.
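# For example (illustrative values): `FORCE_TORCHRUN=1 llamafactory-cli sft train.yaml`
# on an 8-GPU node is re-executed roughly as
# `torchrun --nproc-per-node 8 --nnodes 1 --node_rank 0 --master_addr 127.0.0.1
#  --master_port <random port> <this launcher script> sft train.yaml`.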
|
||||
command = sys.argv.pop(1) if len(sys.argv) > 1 else "help"
|
||||
|
||||
if command == "sft": # train command will fallback to sft command
|
||||
from .trainers.sft_trainer import run_sft
|
||||
if command in _DIST_TRAIN_COMMANDS and (
|
||||
is_env_enabled("FORCE_TORCHRUN") or (get_device_count() > 1 and not use_ray() and not use_kt())
|
||||
):
|
||||
nnodes = os.getenv("NNODES", "1")
|
||||
node_rank = os.getenv("NODE_RANK", "0")
|
||||
nproc_per_node = os.getenv("NPROC_PER_NODE", str(get_device_count()))
|
||||
master_addr = os.getenv("MASTER_ADDR", "127.0.0.1")
|
||||
master_port = os.getenv("MASTER_PORT", str(find_available_port()))
|
||||
logger.info_rank0(f"Initializing {nproc_per_node} distributed tasks at: {master_addr}:{master_port}")
|
||||
if int(nnodes) > 1:
|
||||
logger.info_rank0(f"Multi-node training enabled: num nodes: {nnodes}, node rank: {node_rank}")
|
||||
|
||||
run_sft()
|
||||
# elastic launch support
|
||||
max_restarts = os.getenv("MAX_RESTARTS", "0")
|
||||
rdzv_id = os.getenv("RDZV_ID")
|
||||
min_nnodes = os.getenv("MIN_NNODES")
|
||||
max_nnodes = os.getenv("MAX_NNODES")
|
||||
|
||||
env = deepcopy(os.environ)
|
||||
if is_env_enabled("OPTIM_TORCH", "1"):
|
||||
# optimize DDP, see https://zhuanlan.zhihu.com/p/671834539
|
||||
env["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
|
||||
env["TORCH_NCCL_AVOID_RECORD_STREAMS"] = "1"
|
||||
|
||||
torchrun_args = [
|
||||
"torchrun",
|
||||
"--nproc-per-node",
|
||||
nproc_per_node,
|
||||
]
|
||||
if rdzv_id is not None:
|
||||
# launch elastic job with fault tolerant support when possible
|
||||
# see also https://docs.pytorch.org/docs/stable/elastic/train_script.html
|
||||
rdzv_nnodes = nnodes
|
||||
# elastic number of nodes if MIN_NNODES and MAX_NNODES are set
|
||||
if min_nnodes is not None and max_nnodes is not None:
|
||||
rdzv_nnodes = f"{min_nnodes}:{max_nnodes}"
|
||||
|
||||
torchrun_args.extend(
|
||||
[
|
||||
"--nnodes",
|
||||
rdzv_nnodes,
|
||||
"--rdzv-id",
|
||||
rdzv_id,
|
||||
"--rdzv-backend",
|
||||
"c10d",
|
||||
"--rdzv-endpoint",
|
||||
f"{master_addr}:{master_port}",
|
||||
"--max-restarts",
|
||||
max_restarts,
|
||||
]
|
||||
)
|
||||
else:
|
||||
# NOTE: DO NOT USE shell=True to avoid security risk
|
||||
torchrun_args.extend(
|
||||
[
|
||||
"--nnodes",
|
||||
nnodes,
|
||||
"--node_rank",
|
||||
node_rank,
|
||||
"--master_addr",
|
||||
master_addr,
|
||||
"--master_port",
|
||||
master_port,
|
||||
]
|
||||
)
|
||||
|
||||
script_args = [__file__, command] + sys.argv[1:]
|
||||
process = subprocess.run(
|
||||
torchrun_args + script_args,
|
||||
env=env,
|
||||
check=True,
|
||||
)
|
||||
|
||||
sys.exit(process.returncode)
|
||||
|
||||
elif command == "chat":
|
||||
from .samplers.cli_sampler import run_chat
|
||||
|
||||
run_chat()
|
||||
|
||||
elif command == "env":
|
||||
print_env()
|
||||
raise NotImplementedError("Environment information is not implemented yet.")
|
||||
|
||||
elif command == "version":
|
||||
print(WELCOME)
|
||||
raise NotImplementedError("Version information is not implemented yet.")
|
||||
|
||||
elif command == "help":
|
||||
print(USAGE)
|
||||
|
||||
elif command in _DIST_TRAIN_COMMANDS:
|
||||
# Single GPU training without torchrun
|
||||
if command in ("train", "sft"):
|
||||
from llamafactory.v1.trainers.sft_trainer import run_sft
|
||||
|
||||
run_sft()
|
||||
elif command == "dpo":
|
||||
raise NotImplementedError("DPO trainer is not implemented yet.")
|
||||
elif command == "rm":
|
||||
raise NotImplementedError("RM trainer is not implemented yet.")
|
||||
|
||||
else:
|
||||
print(f"Unknown command: {command}.\n{USAGE}")
|
||||
|
||||
|
||||
def main():
|
||||
# sys.argv[1] contains the command (sft/dpo/rm/train), sys.argv[2:] contains the remaining arguments
|
||||
command = sys.argv[1] if len(sys.argv) > 1 else "sft"
|
||||
|
||||
# Routing needs the sub-command, but downstream trainers usually expect argv without it.
|
||||
if command in _DIST_TRAIN_COMMANDS:
|
||||
sys.argv.pop(1)
|
||||
else:
|
||||
# Backward-compat: if someone runs `torchrun launcher.py config.yaml`,
|
||||
# treat it as sft by default.
|
||||
if len(sys.argv) > 1 and sys.argv[1].endswith((".yaml", ".yml")):
|
||||
command = "sft"
|
||||
if command in ("train", "sft"):
|
||||
from llamafactory.v1.trainers.sft_trainer import run_sft
|
||||
|
||||
run_sft()
|
||||
elif command == "dpo":
|
||||
# from llamafactory.v1.trainers.dpo_trainer import run_dpo
|
||||
# run_dpo()
|
||||
raise NotImplementedError("DPO trainer is not implemented yet.")
|
||||
elif command == "rm":
|
||||
# from llamafactory.v1.trainers.rm_trainer import run_rm
|
||||
# run_rm()
|
||||
raise NotImplementedError("RM trainer is not implemented yet.")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pass
|
||||
main()
|
||||
|
||||
@@ -13,11 +13,12 @@
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
import json
|
||||
from typing import Any, Literal, NotRequired, TypedDict
|
||||
|
||||
from ...utils import logging
|
||||
from ...utils.plugin import BasePlugin
|
||||
from ...utils.types import DPOSample, Sample, SFTSample
|
||||
from ...utils.types import DPOSample, Sample, SFTSample, ToolCall
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
@@ -31,7 +32,8 @@ class AlpacaSample(TypedDict, total=False):
|
||||
|
||||
|
||||
SharegptMessage = TypedDict(
|
||||
"SharegptMessage", {"from": Literal["human", "gpt", "system", "function_call", "observation"], "value": str}
|
||||
"SharegptMessage",
|
||||
{"from": Literal["human", "gpt", "system", "function_call", "observation"], "value": str},
|
||||
)
|
||||
|
||||
|
||||
@@ -61,7 +63,7 @@ class DataConverterPlugin(BasePlugin):
|
||||
return super().__call__(raw_sample)
|
||||
|
||||
|
||||
@DataConverterPlugin("alpaca").register
|
||||
@DataConverterPlugin("alpaca").register()
|
||||
def alpaca_converter(raw_sample: AlpacaSample) -> SFTSample:
|
||||
"""Convert Alpaca sample to SFT sample.
|
||||
|
||||
@@ -98,7 +100,7 @@ def alpaca_converter(raw_sample: AlpacaSample) -> SFTSample:
|
||||
return {"messages": messages}
|
||||
|
||||
|
||||
@DataConverterPlugin("sharegpt").register
|
||||
@DataConverterPlugin("sharegpt").register()
|
||||
def sharegpt_converter(raw_sample: SharegptSample) -> SFTSample:
|
||||
"""Convert ShareGPT sample to SFT sample.
|
||||
|
||||
@@ -117,18 +119,26 @@ def sharegpt_converter(raw_sample: SharegptSample) -> SFTSample:
|
||||
"observation": "tool",
|
||||
"function_call": "assistant",
|
||||
}
|
||||
sample = {}
|
||||
messages = []
|
||||
tools = raw_sample.get("tools", "")
|
||||
|
||||
for message in raw_sample.get("conversations", []):
|
||||
tag = message["from"]
|
||||
if tag not in tag_mapping:
|
||||
logger.warning_rank0(f"Unsupported role tag {tag} in message: {message}")
|
||||
elif tag == "function_call":
|
||||
try:
|
||||
tool_calls: ToolCall | list[ToolCall] = json.loads(message["value"])
|
||||
except json.JSONDecodeError:
|
||||
logger.warning_rank0(f"Invalid tool call format: {str(message['value'])}")
|
||||
continue
|
||||
|
||||
if not isinstance(tool_calls, list):
|
||||
tool_calls = [tool_calls]
|
||||
|
||||
messages.append(
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": [{"type": "tool_calls", "value": message["value"]}],
|
||||
"content": [{"type": "tool_call", "value": json.dumps(tool_call)} for tool_call in tool_calls],
|
||||
"loss_weight": 1.0,
|
||||
}
|
||||
)
|
||||
@@ -141,16 +151,20 @@ def sharegpt_converter(raw_sample: SharegptSample) -> SFTSample:
|
||||
}
|
||||
)
|
||||
|
||||
sample["messages"] = messages
|
||||
|
||||
tools = raw_sample.get("tools")
|
||||
if tools:
|
||||
if messages and messages[0]["role"] == "system":
|
||||
messages[0]["content"].append({"type": "tools", "value": tools})
|
||||
else:
|
||||
messages.insert(0, {"role": "system", "content": [{"type": "tools", "value": tools}], "loss_weight": 0.0})
|
||||
try:
|
||||
tools: list[dict[str, Any]] = json.loads(tools)
|
||||
sample["tools"] = json.dumps(tools)
|
||||
except json.JSONDecodeError:
|
||||
logger.warning_rank0(f"Invalid tools format: {str(tools)}")
|
||||
|
||||
return {"messages": messages}
|
||||
return sample
|
||||
|
||||
|
||||
@DataConverterPlugin("pair").register
|
||||
@DataConverterPlugin("pair").register()
|
||||
def pair_converter(raw_sample: PairSample) -> DPOSample:
|
||||
"""Convert Pair sample to DPO sample.
|
||||
|
||||
@@ -166,17 +180,44 @@ def pair_converter(raw_sample: PairSample) -> DPOSample:
|
||||
def process_message(raw_messages: list[OpenaiMessage]):
|
||||
messages = []
|
||||
for message in raw_messages:
|
||||
messages.append(
|
||||
{
|
||||
"role": message["role"],
|
||||
"content": [{"type": "text", "value": message["content"]}],
|
||||
"loss_weight": 1.0 if message["role"] == "assistant" else 0.0,
|
||||
}
|
||||
)
|
||||
if message["role"] == "tool":
|
||||
try:
|
||||
tool_calls: ToolCall | list[ToolCall] = json.loads(message["content"])
|
||||
except json.JSONDecodeError:
|
||||
logger.warning_rank0(f"Invalid tool call format: {str(message['content'])}")
|
||||
continue
|
||||
|
||||
if not isinstance(tool_calls, list):
|
||||
tool_calls = [tool_calls]
|
||||
|
||||
messages.append(
|
||||
{
|
||||
"role": message["role"],
|
||||
"content": [{"type": "tool_call", "value": json.dumps(tool_call)} for tool_call in tool_calls],
|
||||
"loss_weight": 1.0 if message["role"] == "assistant" else 0.0,
|
||||
}
|
||||
)
|
||||
else:
|
||||
messages.append(
|
||||
{
|
||||
"role": message["role"],
|
||||
"content": [{"type": "text", "value": message["content"]}],
|
||||
"loss_weight": 1.0 if message["role"] == "assistant" else 0.0,
|
||||
}
|
||||
)
|
||||
|
||||
return messages
|
||||
|
||||
chosen_messages = process_message(raw_sample.get("chosen", []))
|
||||
rejected_messages = process_message(raw_sample.get("rejected", []))
|
||||
sample = {}
|
||||
sample["chosen_messages"] = process_message(raw_sample.get("chosen", []))
|
||||
sample["rejected_messages"] = process_message(raw_sample.get("rejected", []))
|
||||
|
||||
return {"chosen_messages": chosen_messages, "rejected_messages": rejected_messages}
|
||||
tools = raw_sample.get("tools")
|
||||
if tools:
|
||||
try:
|
||||
tools: list[dict[str, Any]] = json.loads(tools)
|
||||
sample["tools"] = json.dumps(tools)
|
||||
except json.JSONDecodeError:
|
||||
logger.warning_rank0(f"Invalid tools format: {str(tools)}")
|
||||
|
||||
return sample
|
||||
|
||||
@@ -49,7 +49,7 @@ def _get_builder_name(path: str) -> Literal["arrow", "csv", "json", "parquet", "
|
||||
raise ValueError(f"Unknown dataset filetype: {filetype}.")
|
||||
|
||||
|
||||
@DataLoaderPlugin("local").register
|
||||
@DataLoaderPlugin("local").register()
|
||||
def load_data_from_file(filepath: str, split: str, streaming: bool) -> HFDataset:
|
||||
if os.path.isdir(filepath):
|
||||
filetype = _get_builder_name(os.listdir(filepath)[0])
|
||||
@@ -66,49 +66,43 @@ def load_data_from_file(filepath: str, split: str, streaming: bool) -> HFDataset
|
||||
return dataset
|
||||
|
||||
|
||||
class DataIndexPlugin(BasePlugin):
|
||||
"""Plugin for adjusting dataset index."""
|
||||
def adjust_data_index(
|
||||
data_index: list[tuple[str, int]], size: int | None, weight: float | None
|
||||
) -> list[tuple[str, int]]:
|
||||
"""Adjust dataset index by size and weight.
|
||||
|
||||
def adjust_data_index(
|
||||
self, data_index: list[tuple[str, int]], size: int | None, weight: float | None
|
||||
) -> list[tuple[str, int]]:
|
||||
"""Adjust dataset index by size and weight.
|
||||
Args:
|
||||
data_index (list[tuple[str, int]]): List of (dataset_name, sample_index).
|
||||
size (Optional[int]): Desired dataset size.
|
||||
weight (Optional[float]): Desired dataset weight.
|
||||
|
||||
Args:
|
||||
data_index (list[tuple[str, int]]): List of (dataset_name, sample_index).
|
||||
size (Optional[int]): Desired dataset size.
|
||||
weight (Optional[float]): Desired dataset weight.
|
||||
Returns:
|
||||
list[tuple[str, int]]: Adjusted dataset index.
|
||||
"""
|
||||
if size is not None:
|
||||
data_index = random.choices(data_index, k=size)
|
||||
|
||||
Returns:
|
||||
list[tuple[str, int]]: Adjusted dataset index.
|
||||
"""
|
||||
if size is not None:
|
||||
data_index = random.choices(data_index, k=size)
|
||||
if weight is not None:
|
||||
data_index = random.choices(data_index, k=int(len(data_index) * weight))
|
||||
|
||||
if weight is not None:
|
||||
data_index = random.choices(data_index, k=int(len(data_index) * weight))
|
||||
|
||||
return data_index
|
||||
return data_index
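# Behaviour sketch: given a 100-entry index, size=30 resamples 30 entries with
# replacement via random.choices, then weight=0.5 resamples to int(30 * 0.5) = 15
# entries; either argument may also be used on its own.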
|
||||
|
||||
|
||||
class DataSelectorPlugin(BasePlugin):
|
||||
"""Plugin for selecting dataset samples."""
|
||||
def select_data_sample(
|
||||
data_index: list[tuple[str, int]], index: slice | list[int] | Any
|
||||
) -> tuple[str, int] | list[tuple[str, int]]:
|
||||
"""Select dataset samples.
|
||||
|
||||
def select(
|
||||
self, data_index: list[tuple[str, int]], index: slice | list[int] | Any
|
||||
) -> tuple[str, int] | list[tuple[str, int]]:
|
||||
"""Select dataset samples.
|
||||
Args:
|
||||
data_index (list[tuple[str, int]]): List of (dataset_name, sample_index).
|
||||
index (Union[slice, list[int], Any]): Index of dataset samples.
|
||||
|
||||
Args:
|
||||
data_index (list[tuple[str, int]]): List of (dataset_name, sample_index).
|
||||
index (Union[slice, list[int], Any]): Index of dataset samples.
|
||||
|
||||
Returns:
|
||||
Union[tuple[str, int], list[tuple[str, int]]]: Selected dataset samples.
|
||||
"""
|
||||
if isinstance(index, slice):
|
||||
return [data_index[i] for i in range(*index.indices(len(data_index)))]
|
||||
elif isinstance(index, list):
|
||||
return [data_index[i] for i in index]
|
||||
else:
|
||||
raise ValueError(f"Invalid index type {type(index)}.")
|
||||
Returns:
|
||||
Union[tuple[str, int], list[tuple[str, int]]]: Selected dataset samples.
|
||||
"""
|
||||
if isinstance(index, slice):
|
||||
return [data_index[i] for i in range(*index.indices(len(data_index)))]
|
||||
elif isinstance(index, list):
|
||||
return [data_index[i] for i in index]
|
||||
else:
|
||||
raise ValueError(f"Invalid index type {type(index)}.")
|
||||
|
||||
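A usage sketch of the size/weight resampling and slice/list selection described in the docstrings above; the plugin wiring is omitted and the dataset name is illustrative, while the sampling and indexing calls mirror the hunk directly.

import random

data_index = [("alpaca_en", i) for i in range(10)]

# size: resample to an exact number of entries (with replacement, as random.choices does)
resized = random.choices(data_index, k=5)

# weight: scale the dataset to a fraction or multiple of its current length
weighted = random.choices(data_index, k=int(len(data_index) * 0.5))

# selection: a slice or an explicit list of positions; anything else raises in the plugin
selected = [data_index[i] for i in range(*slice(0, 4).indices(len(data_index)))]
picked = [data_index[i] for i in [1, 3, 5]]

print(len(resized), len(weighted), len(selected), len(picked))  # 5 5 4 3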
@@ -1,133 +0,0 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from dataclasses import dataclass


@dataclass
class Template:
    user_template: str
    assistant_template: str
    system_template: str

    def render_message(self, message: dict[str, str]) -> str:
        return self.user_template.format(**message)


@dataclass
class QwenTemplate:
    message_template: str = "<|im_start|>{role}\n{content}<|im_end|>\n"  # FIXME if role: tool
    thinking_template: str = "<think>\n{content}\n</think>\n\n"

    def _extract_content(self, content_data: str | list[dict[str, str]]) -> str:
        if isinstance(content_data, str):
            return content_data.strip()

        if isinstance(content_data, list):
            parts = []
            for item in content_data:
                if item.get("type") == "text":
                    parts.append(item.get("value", ""))
                elif item.get("type") == "image_url":
                    pass

            return "\n".join(parts).strip()

        return ""

    def render_message(self, message: dict[str, str | list[dict[str, str]]]) -> str:
        role = message["role"]
        content = self._extract_content(message.get("content", ""))

        if role == "assistant":
            reasoning_content = message.get("reasoning_content", "")
            if reasoning_content:
                reasoning_content = self.thinking_template.format(content=str(reasoning_content).strip())

            return self.message_template.format(role="assistant", content=reasoning_content + content)
        else:
            return self.message_template.format(role=role, content=content)

    def encode_messages(self, tokenizer, messages: list[dict[str, str]], max_seq_len: int = 8192) -> any:
        """Encode one message."""
        input_ids, attention_mask, labels = [], [], []
        for message in messages:
            content_str = self.render_message(message)
            content_ids = tokenizer.encode(content_str, add_special_tokens=False)
            input_ids += content_ids
            attention_mask += [1] * len(content_ids)

            if hasattr(message, "loss_weight"):
                loss_weight = message["loss_weight"]
            else:
                loss_weight = 1 if message["role"] == "assistant" else 0

            if loss_weight == 1:
                labels += content_ids
            else:
                labels += [-100] * len(content_ids)

        model_inputs = {"input_ids": input_ids, "attention_mask": attention_mask, "labels": labels}
        model_inputs.update({"position_ids": list(range(len(input_ids)))})
        model_inputs = {k: v[-max_seq_len:] for k, v in model_inputs.items()}
        return model_inputs


if __name__ == "__main__":

    def to_qwen3_messages(template: QwenTemplate, messages: list[dict]):
        out = []
        for m in messages:
            role = m["role"]
            content = template._extract_content(m.get("content", ""))
            if role == "assistant":
                reasoning = (m.get("reasoning_content") or "").strip()
                if reasoning:
                    content = template.thinking_template.format(content=reasoning) + content

            out.append({"role": role, "content": content})

        return out

    from transformers import AutoTokenizer

    tok = AutoTokenizer.from_pretrained(
        "Qwen/Qwen3-30B-A3B-Thinking-2507",
        trust_remote_code=True,
    )

    test_messages = [
        {"role": "system", "content": "You are a helpful assistant."},
        {
            "role": "user",
            "content": [{"type": "text", "text": "What is 1+1?"}, {"type": "text", "text": "What is 2+2?"}],
        },
        {
            "role": "assistant",
            "reasoning_content": "This is a simple arithmetic question. 1 plus 1 equals 2.",
            "content": [{"type": "text", "text": "1+1=2"}, {"type": "text", "text": "2+2=4"}],
        },
    ]

    template = QwenTemplate()
    rendered_custom = "".join([template.render_message(m) for m in test_messages])

    qwen3_messages = to_qwen3_messages(template, test_messages)
    rendered_hf = tok.apply_chat_template(qwen3_messages, tokenize=False, add_generation_prompt=False)

    print("==== custom ====")
    print(rendered_custom)
    print("==== hf ====")
    print(rendered_hf)

    assert rendered_custom.strip() == rendered_hf.strip(), "Rendered text mismatch"

    ids_custom = tok.encode(rendered_custom, add_special_tokens=False)
    ids_hf = tok.apply_chat_template(qwen3_messages, tokenize=True, add_generation_prompt=False)
    assert ids_custom == ids_hf, f"Token ids mismatch: custom={len(ids_custom)} hf={len(ids_hf)}"
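The deleted encode_messages above masks non-assistant tokens with -100; a minimal self-contained restatement of that labeling rule, with a toy whitespace splitter standing in for the tokenizer.

IGNORE_INDEX = -100

def fake_encode(text: str) -> list[int]:  # stand-in for tokenizer.encode
    return [len(tok) for tok in text.split()]

messages = [
    {"role": "user", "content": "1+1 equals what"},
    {"role": "assistant", "content": "1+1 equals 2"},
]

input_ids, labels = [], []
for message in messages:
    ids = fake_encode(message["content"])
    input_ids += ids
    # only assistant turns contribute to the loss, mirroring the loss_weight logic above
    labels += ids if message["role"] == "assistant" else [IGNORE_INDEX] * len(ids)

print(input_ids)
print(labels)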
@@ -0,0 +1,43 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import torch

from ...accelerator.helper import DeviceType
from ...accelerator.interface import DistributedInterface
from ...utils.plugin import BasePlugin


class InitPlugin(BasePlugin):
    def __call__(self) -> torch.device:
        return super().__call__()


@InitPlugin("init_on_meta").register()
def init_on_meta() -> torch.device:
    return torch.device(DeviceType.META.value)


@InitPlugin("init_on_rank0").register()
def init_on_rank0() -> torch.device:
    if DistributedInterface().get_rank() == 0:
        return torch.device(DeviceType.CPU.value)
    else:
        return torch.device(DeviceType.META.value)


@InitPlugin("init_on_default").register()
def init_on_default() -> torch.device:
    return DistributedInterface().current_device
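A small illustration of what the meta-device strategies above buy: parameters get shape and dtype but no real storage until they are materialized; the module below is illustrative and not part of the plugin.

import torch
import torch.nn as nn

# init_on_meta-style construction: weights live on the meta device, so no host
# memory is spent; non-zero ranks can defer materialization to a later load step.
with torch.device("meta"):
    layer = nn.Linear(4096, 4096)

print(layer.weight.device)   # meta
print(layer.weight.is_meta)  # True

# init_on_rank0-style construction instead builds real CPU weights on rank 0.
with torch.device("cpu"):
    layer0 = nn.Linear(8, 8)
print(layer0.weight.device)  # cpu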
@@ -38,17 +38,17 @@ class BaseKernel(ABC):

    @classmethod
    def get_kernel_id(cls) -> str:
-        r"""Returns the unique identifier for the kernel."""
+        """Returns the unique identifier for the kernel."""
        return cls._kernel_id

    @classmethod
    def get_device(cls) -> str:
-        r"""Returns the device type associated with the kernel (e.g., "cuda", "npu", "cpu")."""
+        """Returns the device type associated with the kernel (e.g., "cuda", "npu", "cpu")."""
        return cls._device

    @classmethod
    def check_deps(cls) -> bool:
-        r"""Checks if the required dependencies for the kernel are available.
+        """Checks if the required dependencies for the kernel are available.

        Returns:
            bool: ``True`` if dependencies are met, ``False`` otherwise.
@@ -65,7 +65,7 @@ class BaseKernel(ABC):
    @classmethod
    @abstractmethod
    def apply(cls, **kwargs) -> HFModel:
-        r"""Applies the kernel optimization to the model.
+        """Applies the kernel optimization to the model.

        Args:
            **kwargs: Arbitrary keyword arguments, usually containing the model instance and the kernel configuration.
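A hedged sketch of the subclass contract those docstrings describe (kernel id, device tag, dependency check, apply); the classes here are illustrative stand-ins, not the project's BaseKernel.

from abc import ABC, abstractmethod


class MiniKernel(ABC):  # illustrative stand-in for BaseKernel
    _kernel_id: str = ""
    _device: str = ""

    @classmethod
    def get_kernel_id(cls) -> str:
        return cls._kernel_id

    @classmethod
    def check_deps(cls) -> bool:
        return True  # a real kernel would probe for its backend here

    @classmethod
    @abstractmethod
    def apply(cls, **kwargs):
        ...


class NoopKernel(MiniKernel):
    _kernel_id = "noop"
    _device = "cpu"

    @classmethod
    def apply(cls, **kwargs):
        return kwargs.get("model")  # leave the model untouched


assert NoopKernel.get_kernel_id() == "noop"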
@@ -24,16 +24,17 @@ Init Phase:
import importlib
from pathlib import Path

-from ....utils.logging import get_logger
+from ....utils import logging
from ....utils.plugin import BasePlugin
from ....utils.types import HFModel
from .registry import Registry


-logger = get_logger(__name__)
+logger = logging.get_logger(__name__)


def scan_all_kernels():
-    r"""Scan all kernels in the ``ops`` directory.
+    """Scan all kernels in the ``ops`` directory.

    Scans the ``ops`` directory for all ``.py`` files and attempts to import them.
    Importing triggers the :func:`~registry.register_kernel` decorator, which automatically registers the kernels.
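A rough sketch of the scan-and-import pattern that docstring describes; the package and directory names are assumed, and in the real code importing a module runs its register_kernel decorators as a side effect.

import importlib
from pathlib import Path


def scan_ops_directory(package: str, ops_dir: Path) -> list[str]:
    """Import every .py module directly under ops_dir so module-level decorators run."""
    imported = []
    for py_file in sorted(ops_dir.glob("*.py")):
        if py_file.name == "__init__.py":
            continue
        module_name = f"{package}.{py_file.stem}"
        importlib.import_module(module_name)  # side effect: kernel registration
        imported.append(module_name)
    return imported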
@@ -77,7 +78,7 @@ default_kernels = scan_all_kernels()


def get_default_kernels():
-    r"""Get a list of default registered kernel IDs.
+    """Get a list of default registered kernel IDs.

    Returns:
        list[str]: List of kernel IDs.
@@ -86,7 +87,7 @@ def get_default_kernels():


def apply_kernel(kernel_id: str, **kwargs):
-    r"""Applies a specific kernel to the model.
+    """Applies a specific kernel to the model.

    Args:
        kernel_id (str): The ID of the kernel to apply.
@@ -99,34 +100,41 @@ def apply_kernel(kernel_id: str, **kwargs):
    kernel = default_kernels.get(kernel_id)
    if kernel is None:
        raise ValueError(f"Kernel {kernel_id} not found")

    kernel.apply(**kwargs)


class KernelPlugin(BasePlugin):
-    r"""Plugin for managing kernel optimizations."""
+    """Plugin for managing kernel optimizations."""

    pass


-@KernelPlugin("auto").register
-def apply_default_kernels(**kwargs):
-    r"""Applies all default registered kernels to the model.
+@KernelPlugin("auto").register()
+def apply_default_kernels(model: HFModel, include_kernels: str = None) -> HFModel:
+    """Applies all default registered kernels to the model.

    Args:
-        **kwargs: Keyword arguments passed to the kernel application function.
-            Typically includes the model instance and the include_kernels configuration.
+        model (HFModel): The model instance to apply kernels to.
+        include_kernels (str, optional): Comma-separated list of kernel IDs to apply.
+            If "auto" or True, applies all default kernels.
+            If None or False, no kernels are applied.
+            Defaults to None.

    Returns:
        HFModel: The model with applied kernels.
    """
-    if not kwargs.get("include_kernels"):  # None/False/empty string
-        return kwargs.get("model")
-    elif kwargs.get("include_kernels") == "auto" or kwargs.get("include_kernels") is True:  # True/auto
+    if not include_kernels:
+        return model
+    elif include_kernels == "auto" or include_kernels is True:
        use_kernels = default_kernels.keys()
    else:
-        use_kernels = kwargs.get("include_kernels").split(",")  # "kernel_id1,kernel_id2,kernel_id3"
+        use_kernels = include_kernels.split(",")  # "kernel_id1,kernel_id2,kernel_id3"

    for kernel in use_kernels:
        if kernel not in default_kernels:
            raise ValueError(f"Kernel {kernel} not found")
-        apply_kernel(kernel, **kwargs)
-    return kwargs.get("model")
+        apply_kernel(kernel, model=model)
+
+    return model
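A minimal restatement of the include_kernels dispatch introduced above; the registry is replaced by a plain dict, and the second kernel id is made up for the example.

default_kernels = {"npu_fused_moe": object(), "npu_fused_swiglu": object()}  # stand-in registry


def resolve_kernels(include_kernels):
    """Mirror the dispatch above: falsy -> nothing, "auto"/True -> all, otherwise a CSV list."""
    if not include_kernels:
        return []
    if include_kernels == "auto" or include_kernels is True:
        return list(default_kernels)
    requested = include_kernels.split(",")
    for kernel in requested:
        if kernel not in default_kernels:
            raise ValueError(f"Kernel {kernel} not found")
    return requested


assert resolve_kernels(None) == []
assert resolve_kernels("auto") == ["npu_fused_moe", "npu_fused_swiglu"]
assert resolve_kernels("npu_fused_moe") == ["npu_fused_moe"]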
@@ -40,11 +40,11 @@ from ...registry import register_kernel


class GmmFunction(torch.autograd.Function):
-    r"""Custom autograd function for NPU Grouped Matrix Multiplication (GMM)."""
+    """Custom autograd function for NPU Grouped Matrix Multiplication (GMM)."""

    @staticmethod
    def forward(ctx, x, weight, group_list):
-        r"""Performs the forward pass of Grouped Matrix Multiplication.
+        """Performs the forward pass of Grouped Matrix Multiplication.

        Args:
            ctx: Context object to save tensors for backward pass.
@@ -65,7 +65,7 @@ class GmmFunction(torch.autograd.Function):

    @staticmethod
    def backward(ctx, grad_output):
-        r"""Performs the backward pass of Grouped Matrix Multiplication.
+        """Performs the backward pass of Grouped Matrix Multiplication.

        Args:
            ctx: Context object containing saved tensors.
@@ -94,11 +94,11 @@ class GmmFunction(torch.autograd.Function):


class HybridGmmFunction(torch.autograd.Function):
-    r"""Custom autograd function for Hybrid Grouped Matrix Multiplication on NPU."""
+    """Custom autograd function for Hybrid Grouped Matrix Multiplication on NPU."""

    @staticmethod
    def forward(ctx, num_experts, *args):
-        r"""Performs the forward pass of Hybrid GMM.
+        """Performs the forward pass of Hybrid GMM.

        Args:
            ctx: Context object to save tensors.
@@ -124,7 +124,7 @@ class HybridGmmFunction(torch.autograd.Function):

    @staticmethod
    def backward(ctx, *grad_outputs):
-        r"""Performs the backward pass of Hybrid GMM.
+        """Performs the backward pass of Hybrid GMM.

        Args:
            ctx: Context object containing saved tensors.
@@ -176,13 +176,13 @@ class HybridGmmFunction(torch.autograd.Function):


class NpuMoeFused:
-    r"""Container for NPU fused MoE forward functions."""
+    """Container for NPU fused MoE forward functions."""

    @staticmethod
    def npu_moe_experts_forward(
        self, hidden_states: torch.Tensor, routing_weights: torch.Tensor, router_indices: torch.Tensor
    ) -> torch.Tensor:
-        r"""Forward pass for MoE experts using NPU fused operations.
+        """Forward pass for MoE experts using NPU fused operations.

        Args:
            self: The MoE layer instance.
@@ -230,11 +230,11 @@ class NpuMoeFused:


class Qwen3NpuMoeFused:
-    r"""Container for Qwen3 NPU fused MoE forward functions."""
+    """Container for Qwen3 NPU fused MoE forward functions."""

    @staticmethod
    def qwen3moe_sparse_moe_block_forward(self, hidden_states: torch.Tensor):
-        r"""Forward pass for Qwen3 sparse MoE block using NPU fused operations.
+        """Forward pass for Qwen3 sparse MoE block using NPU fused operations.

        Args:
            self: The Qwen3 MoE block instance.
@@ -298,14 +298,14 @@ if not is_transformers_version_greater_than("5.0.0"):

@register_kernel
class NpuFusedMoEKernel(BaseKernel):
-    r"""NPU Fused MoE Kernel implementation."""
+    """NPU Fused MoE Kernel implementation."""

    _kernel_id = "npu_fused_moe"
    _device = DeviceType.NPU

    @classmethod
    def apply(cls, **kwargs) -> HFModel:
-        r"""Applies the NPU fused MoE kernel to the model.
+        """Applies the NPU fused MoE kernel to the model.

        Args:
            **kwargs: Keyword arguments containing the model.
@@ -324,7 +324,7 @@ class NpuFusedMoEKernel(BaseKernel):
        if not cls.check_deps():
            raise RuntimeError("torch_npu is not available but NpuMoEFusedMoEKernel was called.")

-        archs = getattr(model.config, "architectures", [])
+        archs = getattr(model.config, "architectures", None) or []
        target_moe_mapping = None
        for arch in archs:
            if arch in kernel_moe_mapping:
@@ -333,6 +333,7 @@ class NpuFusedMoEKernel(BaseKernel):

        if target_moe_mapping is None:
            return model
+
        for module in model.modules():
            class_name = module.__class__.__name__
            if class_name in target_moe_mapping:
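For orientation, a plain-PyTorch baseline of what a grouped matmul over a cumulative group_list computes; the boundary layout is an assumption here, and the NPU kernel fuses this work rather than running the loop below.

import torch


def grouped_matmul_reference(x: torch.Tensor, weight: torch.Tensor, group_list: list[int]) -> torch.Tensor:
    """x: [num_tokens, in_dim]; weight: [num_groups, in_dim, out_dim];
    group_list: cumulative token counts per expert group (assumed layout)."""
    outputs, start = [], 0
    for group_idx, end in enumerate(group_list):
        outputs.append(x[start:end] @ weight[group_idx])  # each group uses its own expert weight
        start = end
    return torch.cat(outputs, dim=0)


x = torch.randn(6, 4)
w = torch.randn(2, 4, 3)
out = grouped_matmul_reference(x, w, group_list=[2, 6])
print(out.shape)  # torch.Size([6, 3])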
@@ -38,7 +38,7 @@ except ImportError:


def npu_swiglu_forward(self, hidden_state):
-    r"""SwiGLU forward pass for NPU.
+    """SwiGLU forward pass for NPU.

    Args:
        self: The MLP layer instance.
@@ -53,7 +53,7 @@ def npu_swiglu_forward(self, hidden_state):


def _npu_swiglu_glm4_forward(self, hidden_states):
-    r"""SwiGLU forward pass for GLM4 on NPU.
+    """SwiGLU forward pass for GLM4 on NPU.

    Args:
        self: The GLM4 MLP layer instance.
@@ -68,7 +68,7 @@ def _npu_swiglu_glm4_forward(self, hidden_states):


def _npu_swiglu_gemma3ntext_forward(self, hidden_states):
-    r"""SwiGLU forward pass for Gemma3nText on NPU.
+    """SwiGLU forward pass for Gemma3nText on NPU.

    Args:
        self: The Gemma3nText MLP layer instance.
@@ -88,7 +88,7 @@ def _npu_swiglu_gemma3ntext_forward(self, hidden_states):

@register_kernel
class NpuSwiGluKernel(BaseKernel):
-    r"""NPU Kernel for fused SwiGLU activation."""
+    """NPU Kernel for fused SwiGLU activation."""

    # just support apply to the following module layers
    expect_modules = frozenset(
@@ -126,7 +126,7 @@ class NpuSwiGluKernel(BaseKernel):

    @classmethod
    def apply(cls, **kwargs) -> "HFModel":
-        r"""Applies the NPU fused SwiGLU kernel to the model.
+        """Applies the NPU fused SwiGLU kernel to the model.

        Args:
            **kwargs: Keyword arguments containing the model.
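As a reference point for the fused kernel above, the unfused SwiGLU computation it is generally expected to match; the gate/up/down projection layout is assumed from the usual Hugging Face MLP modules.

import torch
import torch.nn.functional as F


def swiglu_forward_reference(gate_proj, up_proj, down_proj, hidden_state: torch.Tensor) -> torch.Tensor:
    # silu(gate(x)) * up(x), then the down projection -- the pattern the NPU kernel fuses
    return down_proj(F.silu(gate_proj(hidden_state)) * up_proj(hidden_state))


gate = torch.nn.Linear(8, 16, bias=False)
up = torch.nn.Linear(8, 16, bias=False)
down = torch.nn.Linear(16, 8, bias=False)
print(swiglu_forward_reference(gate, up, down, torch.randn(2, 8)).shape)  # torch.Size([2, 8])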
Some files were not shown because too many files have changed in this diff.