38 Commits
v0.9.4 ... main

Author SHA1 Message Date
xvxuopop
762b480131 [feature] support using ray.remote to start distributed training. (#10109) 2026-01-28 16:05:29 +08:00
Jewon Lee
9640f79ae5 [fix] add visual.pos_embed to Qwen3-VL visual model keys (#10139) 2026-01-27 16:33:01 +08:00
jiaqiw09
7ef19eea00 [v0] Fix reward model training safetensors saving (#10137) 2026-01-27 16:27:14 +08:00
浮梦
f9f11dcb97 [v1] support training with fsdp2 (#9773)
Co-authored-by: frozenleaves <frozen@Mac.local>
Co-authored-by: Yaowei Zheng <hiyouga@buaa.edu.cn>
2026-01-25 19:41:58 +08:00
Pádraic Slattery
641bfdd482 chore: Update outdated GitHub Actions versions (#10123) 2026-01-25 19:12:39 +08:00
Meng WANG
e70651ac58 [feat] support all_exhausted_without_replacement in datasets.interleave_datasets (#10112) 2026-01-20 15:54:07 +08:00
Kingsley
db2f794f7b [misc] update mcore related docker and mca supported models (#10114) 2026-01-19 14:55:16 +08:00
jiaqiw09
44eadbda1c [v1] fix kernel moe patch (#9867) 2026-01-17 09:24:54 +08:00
浮梦
9829ae0a77 [ci] using mp to run kernel test (#9754)
Co-authored-by: frozenleaves <frozen@Mac.local>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Co-authored-by: Yaowei Zheng <hiyouga@buaa.edu.cn>
2026-01-13 19:43:59 +08:00
Yaowei Zheng
958b9c3468 [v1] add sft (#9752) 2026-01-12 03:15:01 +08:00
Hertz
4d3621e3d3 [model] fixed&added Hunyuan models (#9750) 2026-01-12 01:15:00 +08:00
Yaowei Zheng
a296723697 [v1] upgrade batching (#9751)
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
2026-01-12 00:21:36 +08:00
Hertz
15b87f3125 [model] support HY-MT model (#9746)
Co-authored-by: Yaowei Zheng <hiyouga@buaa.edu.cn>
2026-01-11 16:25:56 +08:00
Yaowei Zheng
9f73a6eb23 [deps] fix package (#9745) 2026-01-10 04:27:53 +08:00
Yaowei Zheng
b2effbd77c [v1] add batch generator (#9744) 2026-01-10 04:24:09 +08:00
Yaowei Zheng
d7d734d54c [misc] fix fp8 (#9742) 2026-01-09 16:17:26 +08:00
Yaowei Zheng
8abb8fb533 [v1] use async streamer (#9741) 2026-01-09 16:12:07 +08:00
Yaowei Zheng
766d5ae6ad [ci] fix workflow (#9738) 2026-01-09 16:12:07 +08:00
Yaowei Zheng
5cccaeec82 [model] clean obsolete models (#9736) 2026-01-09 16:12:07 +08:00
Jackey
5fb5d7ebd3 [model] support for microsoft's Phi-4-mini (#9734) 2026-01-09 12:24:45 +08:00
Peilin Li
03a70ba8dd [fix] correct ktransformers example config paths and templates (#9732) 2026-01-08 10:52:50 +08:00
Vo Van Phuc
5cfd804b59 [refactor] rename lfm template to lfm2 and add LFM 2.5 to README (#9731) 2026-01-07 19:25:04 +08:00
Yaowei Zheng
4c1eb922e2 [misc] fix parser (#9730) 2026-01-07 17:36:08 +08:00
Vo Van Phuc
958fb523a2 [model] support LiquidAI's LFM2.5-VL vision-language model (#9729) 2026-01-07 17:20:29 +08:00
Vo Van Phuc
b4e051bea4 [model] support for LiquidAI's LFM2.5 (Liquid Foundation Models) (#9726) 2026-01-07 14:14:47 +08:00
浮梦
d43e1007e8 [ci] improve cuda ci cache (#9725)
Co-authored-by: frozenleaves <frozen@Mac.local>
2026-01-07 12:34:40 +08:00
Xunpeng Xiao
f89d9367e5 [assets] update README.md (#9724) 2026-01-07 12:11:50 +08:00
Yaowei Zheng
d22de0d4bf [v1] add renderer ut (#9722) 2026-01-07 02:06:07 +08:00
Yaowei Zheng
ea0b4e2466 [v1] add cli sampler (#9721) 2026-01-06 23:31:27 +08:00
yanglele
e944dc442c [feature] add support for EAFT loss (#9720)
Co-authored-by: Yaowei Zheng <hiyouga@buaa.edu.cn>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
2026-01-06 23:07:12 +08:00
Xunpeng Xiao
68119e5522 [misc] Add a PyTorch version warning for Conv3D. (#9715) 2026-01-05 13:26:29 +08:00
Yaowei Zheng
f60a6e3d01 [v1] add init plugin (#9716) 2026-01-04 20:51:46 +08:00
jiaqiw09
81b8a50aa5 [deps] Update pyproject.toml and requirements (#9714)
Co-authored-by: Yaowei Zheng <hiyouga@buaa.edu.cn>
2026-01-04 19:52:16 +08:00
Yaowei Zheng
8600530002 [misc] lint (#9710) 2026-01-04 13:47:56 +08:00
Hertz
9ae62c6fc0 [model] support Youtu-LLM-2B (#9707) 2026-01-04 13:17:57 +08:00
Xunpeng Xiao
0087bc253b [misc] Compatible with an empty architectures field in config.json (#9709) 2026-01-04 12:11:35 +08:00
Santosh Bhavani
355d5c5e5a [fix] fp8: add Transformer Engine backend support (#9705)
Co-authored-by: Yaowei Zheng <hiyouga@buaa.edu.cn>
2026-01-01 10:18:02 +08:00
Yaowei Zheng
6fe6bd290b [misc] set dev version (#9703) 2025-12-31 23:41:40 +08:00
139 changed files with 4439 additions and 2667 deletions

View File

@@ -50,7 +50,7 @@ jobs:
     docker-images: false
 - name: Checkout
-  uses: actions/checkout@v4
+  uses: actions/checkout@v6
 - name: Get llamafactory version
   id: version

View File

@@ -21,7 +21,7 @@ jobs:
   steps:
   - name: Checkout
-    uses: actions/checkout@v4
+    uses: actions/checkout@v6
   - name: Install uv
     uses: astral-sh/setup-uv@v7

View File

@@ -54,10 +54,11 @@ jobs:
   env:
     HF_TOKEN: ${{ secrets.HF_TOKEN }}
     OS_NAME: ${{ matrix.os }}
+    UV_NO_SYNC: 1
   steps:
   - name: Checkout
-    uses: actions/checkout@v4
+    uses: actions/checkout@v6
   - name: Install uv
     uses: astral-sh/setup-uv@v7
@@ -70,7 +71,8 @@ jobs:
     run: |
       uv venv
      uv pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu
-      uv pip install -e ".[dev]"
+      uv pip install -e .
+      uv pip install -r requirements/dev.txt
   - name: Install transformers
     if: ${{ matrix.transformers }}
@@ -79,7 +81,7 @@ jobs:
   - name: Cache files
     id: hf-hub-cache
-    uses: actions/cache@v4
+    uses: actions/cache@v5
     with:
       path: ${{ runner.temp }}/huggingface
       key: huggingface-${{ matrix.os }}-${{ matrix.python }}-${{ matrix.transformers }}-${{ hashFiles('tests/version.txt') }}
@@ -87,25 +89,18 @@ jobs:
   - name: Check quality
     run: |
       make style && make quality
-    env:
-      UV_NO_SYNC: 1
   - name: Check license
     run: |
       make license
-    env:
-      UV_NO_SYNC: 1
   - name: Check build
     run: |
       make build
-    env:
-      UV_NO_SYNC: 1
   - name: Test with pytest
     run: |
       make test
     env:
-      UV_NO_SYNC: 1
       HF_HOME: ${{ runner.temp }}/huggingface
       HF_HUB_OFFLINE: "${{ steps.hf-hub-cache.outputs.cache-hit == 'true' && '1' || '0' }}"

View File

@@ -35,9 +35,16 @@ jobs:
   group: ${{ github.workflow }}-${{ github.ref }}-${{ matrix.os }}-${{ matrix.python }}
   cancel-in-progress: ${{ github.ref != 'refs/heads/main' }}
+  env:
+    HF_HOME: "${{ github.workspace }}/../.runner_cache/huggingface"
+    UV_CACHE_DIR: "${{ github.workspace }}/../.runner_cache/uv"
+    HF_TOKEN: ${{ secrets.HF_TOKEN }}
+    OS_NAME: ${{ matrix.os }}
+    UV_NO_SYNC: 1
   steps:
   - name: Checkout
-    uses: actions/checkout@v4
+    uses: actions/checkout@v6
   - name: Install uv
     uses: astral-sh/setup-uv@v7
@@ -52,37 +59,21 @@ jobs:
   - name: Install dependencies
     run: |
       uv venv
-      uv pip install -e ".[dev]"
+      uv pip install -e .
+      uv pip install -r requirements/dev.txt
-  - name: Cache HuggingFace models
-    id: hf-hub-cache
-    uses: actions/cache@v4
-    with:
-      path: ${{ runner.temp }}/huggingface
-      key: hf-cache-${{ runner.os }}-${{ hashFiles('tests/version.txt') }}
   - name: Check quality
     run: |
       make style && make quality
-    env:
-      UV_NO_SYNC: 1
   - name: Check license
     run: |
       make license
-    env:
-      UV_NO_SYNC: 1
   - name: Check build
     run: |
       make build
-    env:
-      UV_NO_SYNC: 1
   - name: Test with pytest
     run: |
       make test
-    env:
-      UV_NO_SYNC: 1
-      HF_HOME: ${{ runner.temp }}/huggingface
-      HF_HUB_OFFLINE: "${{ steps.hf-hub-cache.outputs.cache-hit == 'true' && '1' || '0' }}"

View File

@@ -43,10 +43,11 @@ jobs:
     HF_ENDPOINT: https://hf-mirror.com
     HF_TOKEN: ${{ secrets.HF_TOKEN }}
     OS_NAME: ${{ matrix.os }}
+    UV_NO_SYNC: 1
   steps:
   - name: Checkout
-    uses: actions/checkout@v4
+    uses: actions/checkout@v6
   - name: Install uv
     uses: astral-sh/setup-uv@v7
@@ -58,8 +59,9 @@ jobs:
   - name: Install dependencies
     run: |
       uv venv
-      uv pip install torch-npu==${{matrix.pytorch_npu}}
-      uv pip install -e ".[dev]"
+      uv pip install -r requirements/npu.txt
+      uv pip install -e .
+      uv pip install -r requirements/dev.txt
   - name: Install node
     run: |
@@ -68,35 +70,18 @@ jobs:
       curl -fsSL https://deb.nodesource.com/setup_20.x | bash -
       apt-get install -y nodejs
-  - name: Cache files
-    id: hf-hub-cache
-    uses: actions/cache@v4
-    with:
-      path: ${{ runner.temp }}/huggingface
-      key: huggingface-${{ matrix.os }}-${{ matrix.python }}-${{ hashFiles('tests/version.txt') }}
   - name: Check quality
     run: |
       make style && make quality
-    env:
-      UV_NO_SYNC: 1
   - name: Check license
     run: |
       make license
-    env:
-      UV_NO_SYNC: 1
   - name: Check build
     run: |
       make build
-    env:
-      UV_NO_SYNC: 1
   - name: Test with pytest
     run: |
       make test
-    env:
-      UV_NO_SYNC: 1
-      HF_HOME: /root/.cache/huggingface
-      HF_HUB_OFFLINE: "${{ steps.hf-hub-cache.outputs.cache-hit == 'true' && '1' || '0' }}"

.gitignore (vendored, 1 line added)
View File

@@ -176,6 +176,7 @@ llamaboard_cache/
 llamaboard_config/
 saves/
 output/
+outputs/
 wandb/
 swanlog/
 generated_predictions.jsonl

View File

@@ -92,7 +92,7 @@ Read technical notes:
 ## Features

-- **Various models**: LLaMA, LLaVA, Mistral, Mixtral-MoE, Qwen, Qwen2-VL, DeepSeek, Yi, Gemma, ChatGLM, Phi, etc.
+- **Various models**: LLaMA, LLaVA, Mistral, Mixtral-MoE, Qwen3, Qwen3-VL, DeepSeek, Gemma, GLM, Phi, etc.
 - **Integrated methods**: (Continuous) pre-training, (multimodal) supervised fine-tuning, reward modeling, PPO, DPO, KTO, ORPO, etc.
 - **Scalable resources**: 16-bit full-tuning, freeze-tuning, LoRA and 2/3/4/5/6/8-bit QLoRA via AQLM/AWQ/GPTQ/LLM.int8/HQQ/EETQ.
 - **Advanced algorithms**: [GaLore](https://github.com/jiaweizzhao/GaLore), [BAdam](https://github.com/Ledzy/BAdam), [APOLLO](https://github.com/zhuhanqing/APOLLO), [Adam-mini](https://github.com/zyushun/Adam-mini), [Muon](https://github.com/KellerJordan/Muon), [OFT](https://github.com/huggingface/peft/tree/main/src/peft/tuners/oft), DoRA, LongLoRA, LLaMA Pro, Mixture-of-Depths, LoRA+, LoftQ and PiSSA.
@@ -279,11 +279,10 @@ Read technical notes:
 | Model | Model size | Template |
 | ----------------------------------------------------------------- | -------------------------------- | -------------------- |
 | [BLOOM/BLOOMZ](https://huggingface.co/bigscience) | 560M/1.1B/1.7B/3B/7.1B/176B | - |
-| [Command R](https://huggingface.co/CohereForAI) | 35B/104B | cohere |
 | [DeepSeek (LLM/Code/MoE)](https://huggingface.co/deepseek-ai) | 7B/16B/67B/236B | deepseek |
 | [DeepSeek 3-3.2](https://huggingface.co/deepseek-ai) | 236B/671B | deepseek3 |
 | [DeepSeek R1 (Distill)](https://huggingface.co/deepseek-ai) | 1.5B/7B/8B/14B/32B/70B/671B | deepseekr1 |
-| [ERNIE-4.5](https://huggingface.co/baidu) | 0.3B/21B/300B | ernie/ernie_nothink |
+| [ERNIE-4.5](https://huggingface.co/baidu) | 0.3B/21B/300B | ernie_nothink |
 | [Falcon/Falcon H1](https://huggingface.co/tiiuae) | 0.5B/1.5B/3B/7B/11B/34B/40B/180B | falcon/falcon_h1 |
 | [Gemma/Gemma 2/CodeGemma](https://huggingface.co/google) | 2B/7B/9B/27B | gemma/gemma2 |
 | [Gemma 3/Gemma 3n](https://huggingface.co/google) | 270M/1B/4B/6B/8B/12B/27B | gemma3/gemma3n |
@@ -292,12 +291,13 @@ Read technical notes:
 | [GPT-2](https://huggingface.co/openai-community) | 0.1B/0.4B/0.8B/1.5B | - |
 | [GPT-OSS](https://huggingface.co/openai) | 20B/120B | gpt_oss |
 | [Granite 3-4](https://huggingface.co/ibm-granite) | 1B/2B/3B/7B/8B | granite3/granite4 |
-| [Hunyuan (MT)](https://huggingface.co/tencent/) | 7B | hunyuan |
+| [Hunyuan/Hunyuan1.5 (MT)](https://huggingface.co/tencent/) | 0.5B/1.8B/4B/7B/13B | hunyuan/hunyuan_small |
 | [InternLM 2-3](https://huggingface.co/internlm) | 7B/8B/20B | intern2 |
 | [InternVL 2.5-3.5](https://huggingface.co/OpenGVLab) | 1B/2B/4B/8B/14B/30B/38B/78B/241B | intern_vl |
-| [InternLM/Intern-S1-mini](https://huggingface.co/internlm/) | 8B | intern_s1 |
+| [Intern-S1-mini](https://huggingface.co/internlm/) | 8B | intern_s1 |
 | [Kimi-VL](https://huggingface.co/moonshotai) | 16B | kimi_vl |
 | [Ling 2.0 (mini/flash)](https://huggingface.co/inclusionAI) | 16B/100B | bailing_v2 |
+| [LFM 2.5 (VL)](https://huggingface.co/LiquidAI) | 1.2B/1.6B | lfm2/lfm2_vl |
 | [Llama](https://github.com/facebookresearch/llama) | 7B/13B/33B/65B | - |
 | [Llama 2](https://huggingface.co/meta-llama) | 7B/13B/70B | llama2 |
 | [Llama 3-3.3](https://huggingface.co/meta-llama) | 1B/3B/8B/70B | llama3 |
@@ -307,18 +307,17 @@ Read technical notes:
 | [LLaVA-NeXT](https://huggingface.co/llava-hf) | 7B/8B/13B/34B/72B/110B | llava_next |
 | [LLaVA-NeXT-Video](https://huggingface.co/llava-hf) | 7B/34B | llava_next_video |
 | [MiMo](https://huggingface.co/XiaomiMiMo) | 7B/309B | mimo/mimo_v2 |
-| [MiniCPM 1-4.1](https://huggingface.co/openbmb) | 0.5B/1B/2B/4B/8B | cpm/cpm3/cpm4 |
+| [MiniCPM 4](https://huggingface.co/openbmb) | 0.5B/8B | cpm4 |
 | [MiniCPM-o-2.6/MiniCPM-V-2.6](https://huggingface.co/openbmb) | 8B | minicpm_o/minicpm_v |
 | [MiniMax-M1/MiniMax-M2](https://huggingface.co/MiniMaxAI/models) | 229B/456B | minimax1/minimax2 |
 | [Ministral 3](https://huggingface.co/mistralai) | 3B/8B/14B | ministral3 |
 | [Mistral/Mixtral](https://huggingface.co/mistralai) | 7B/8x7B/8x22B | mistral |
-| [OLMo](https://huggingface.co/allenai) | 1B/7B | - |
 | [PaliGemma/PaliGemma2](https://huggingface.co/google) | 3B/10B/28B | paligemma |
 | [Phi-3/Phi-3.5](https://huggingface.co/microsoft) | 4B/14B | phi |
 | [Phi-3-small](https://huggingface.co/microsoft) | 7B | phi_small |
-| [Phi-4](https://huggingface.co/microsoft) | 14B | phi4 |
+| [Phi-4-mini/Phi-4](https://huggingface.co/microsoft) | 3.8B/14B | phi4_mini/phi4 |
 | [Pixtral](https://huggingface.co/mistralai) | 12B | pixtral |
-| [Qwen (1-2.5) (Code/Math/MoE/QwQ)](https://huggingface.co/Qwen) | 0.5B/1.5B/3B/7B/14B/32B/72B/110B | qwen |
+| [Qwen2 (Code/Math/MoE/QwQ)](https://huggingface.co/Qwen) | 0.5B/1.5B/3B/7B/14B/32B/72B/110B | qwen |
 | [Qwen3 (MoE/Instruct/Thinking/Next)](https://huggingface.co/Qwen) | 0.6B/1.7B/4B/8B/14B/32B/80B/235B | qwen3/qwen3_nothink |
 | [Qwen2-Audio](https://huggingface.co/Qwen) | 7B | qwen2_audio |
 | [Qwen2.5-Omni](https://huggingface.co/Qwen) | 3B/7B | qwen2_omni |
@@ -327,8 +326,6 @@ Read technical notes:
 | [Qwen3-VL](https://huggingface.co/Qwen) | 2B/4B/8B/30B/32B/235B | qwen3_vl |
 | [Seed (OSS/Coder)](https://huggingface.co/ByteDance-Seed) | 8B/36B | seed_oss/seed_coder |
 | [StarCoder 2](https://huggingface.co/bigcode) | 3B/7B/15B | - |
-| [VibeThinker-1.5B](https://huggingface.co/WeiboAI) | 1.5B | qwen3 |
-| [Yi/Yi-1.5 (Code)](https://huggingface.co/01-ai) | 1.5B/6B/9B/34B | yi |
 | [Yuan 2](https://huggingface.co/IEITYuan) | 2B/51B/102B | yuan |

 > [!NOTE]
@@ -514,12 +511,13 @@ huggingface-cli login
 #### Install from Source

 ```bash
-git clone --depth 1 https://github.com/hiyouga/LLaMA-Factory.git
-cd LLaMA-Factory
-pip install -e ".[metrics]"
+git clone --depth 1 https://github.com/hiyouga/LlamaFactory.git
+cd LlamaFactory
+pip install -e .
+pip install -r requirements/metrics.txt
 ```

-Optional dependencies available: `metrics`, `deepspeed`. Install with: `pip install -e ".[metrics,deepspeed]"`
+Optional dependencies available: `metrics`, `deepspeed`. Install with: `pip install -e . && pip install -r requirements/metrics.txt -r requirements/deepspeed.txt`

 Additional dependencies for specific features are available in `examples/requirements/`.
@@ -577,36 +575,21 @@ To enable FlashAttention-2 on the Windows platform, please use the script from [
 <details><summary>For Ascend NPU users</summary>

-To install LLaMA Factory on Ascend NPU devices, please upgrade Python to version 3.10 or higher: `pip install -e . torch-npu==2.7.1`. Additionally, you need to install the **[Ascend CANN Toolkit and Kernels](https://www.hiascend.com/developer/download/community/result?module=cann)**. Please follow the [installation tutorial](https://www.hiascend.com/document/detail/en/CANNCommunityEdition/600alphaX/softwareinstall/instg/atlasdeploy_03_0031.html) or use the following commands:
+To install LLaMA Factory on Ascend NPU devices, please upgrade Python to version 3.10 or higher: `pip install -r requirements/npu.txt`. Additionally, you need to install the **Ascend CANN Toolkit and Kernels**. Please follow the [installation tutorial](https://llamafactory.readthedocs.io/en/latest/advanced/npu_installation.html).
+
+You can also download the pre-built Docker images:

 ```bash
-# replace the url according to your CANN version and devices
-# install CANN Toolkit
-wget https://ascend-repo.obs.cn-east-2.myhuaweicloud.com/Milan-ASL/Milan-ASL%20V100R001C20SPC702/Ascend-cann-toolkit_8.0.0.alpha002_linux-"$(uname -i)".run
-bash Ascend-cann-toolkit_8.0.0.alpha002_linux-"$(uname -i)".run --install
-# install CANN Kernels
-wget https://ascend-repo.obs.cn-east-2.myhuaweicloud.com/Milan-ASL/Milan-ASL%20V100R001C20SPC702/Ascend-cann-kernels-910b_8.0.0.alpha002_linux-"$(uname -i)".run
-bash Ascend-cann-kernels-910b_8.0.0.alpha002_linux-"$(uname -i)".run --install
-# set env variables
-source /usr/local/Ascend/ascend-toolkit/set_env.sh
+# Docker Hub
+docker pull hiyouga/llamafactory:latest-npu-a2
+docker pull hiyouga/llamafactory:latest-npu-a3
+
+# quay.io
+docker pull quay.io/ascend/llamafactory:latest-npu-a2
+docker pull quay.io/ascend/llamafactory:latest-npu-a3
 ```

-| Requirement | Minimum | Recommend |
-| ------------ | ------- | -------------- |
-| CANN | 8.0.RC1 | 8.0.0.alpha002 |
-| torch | 2.1.0 | 2.7.1 |
-| torch-npu | 2.1.0 | 2.7.1 |
-| deepspeed | 0.13.2 | 0.13.2 |
-| vllm-ascend | - | 0.7.3 |
-
-Remember to use `ASCEND_RT_VISIBLE_DEVICES` instead of `CUDA_VISIBLE_DEVICES` to specify the device to use.
-
-If you cannot infer model on NPU devices, try setting `do_sample: false` in the configurations.
-
-Download the pre-built Docker images: [32GB](http://mirrors.cn-central-221.ovaijisuan.com/detail/130.html) | [64GB](http://mirrors.cn-central-221.ovaijisuan.com/detail/131.html)
-
 #### Install BitsAndBytes

 To use QLoRA based on bitsandbytes on Ascend NPU, please follow these 3 steps:

View File

@@ -94,7 +94,7 @@ https://github.com/user-attachments/assets/43b700c6-a178-41db-b1f8-8190a5d3fcfc
 ## 项目特色

-- **多种模型**:LLaMA、LLaVA、Mistral、Mixtral-MoE、Qwen、Qwen2-VL、DeepSeek、Yi、Gemma、ChatGLM、Phi 等等。
+- **多种模型**:LLaMA、LLaVA、Mistral、Mixtral-MoE、Qwen3、Qwen3-VL、DeepSeek、Gemma、GLM、Phi 等等。
 - **集成方法**:增量预训练、多模态指令监督微调、奖励模型训练、PPO 训练、DPO 训练、KTO 训练、ORPO 训练等等。
 - **多种精度**:16 比特全参数微调、冻结微调、LoRA 微调和基于 AQLM/AWQ/GPTQ/LLM.int8/HQQ/EETQ 的 2/3/4/5/6/8 比特 QLoRA 微调。
 - **先进算法**:[GaLore](https://github.com/jiaweizzhao/GaLore)、[BAdam](https://github.com/Ledzy/BAdam)、[APOLLO](https://github.com/zhuhanqing/APOLLO)、[Adam-mini](https://github.com/zyushun/Adam-mini)、[Muon](https://github.com/KellerJordan/Muon)、[OFT](https://github.com/huggingface/peft/tree/main/src/peft/tuners/oft)、DoRA、LongLoRA、LLaMA Pro、Mixture-of-Depths、LoRA+、LoftQ 和 PiSSA。
@@ -281,11 +281,10 @@ https://github.com/user-attachments/assets/43b700c6-a178-41db-b1f8-8190a5d3fcfc
 | 模型名 | 参数量 | Template |
 | ----------------------------------------------------------------- | -------------------------------- | -------------------- |
 | [BLOOM/BLOOMZ](https://huggingface.co/bigscience) | 560M/1.1B/1.7B/3B/7.1B/176B | - |
-| [Command R](https://huggingface.co/CohereForAI) | 35B/104B | cohere |
 | [DeepSeek (LLM/Code/MoE)](https://huggingface.co/deepseek-ai) | 7B/16B/67B/236B | deepseek |
 | [DeepSeek 3-3.2](https://huggingface.co/deepseek-ai) | 236B/671B | deepseek3 |
 | [DeepSeek R1 (Distill)](https://huggingface.co/deepseek-ai) | 1.5B/7B/8B/14B/32B/70B/671B | deepseekr1 |
-| [ERNIE-4.5](https://huggingface.co/baidu) | 0.3B/21B/300B | ernie/ernie_nothink |
+| [ERNIE-4.5](https://huggingface.co/baidu) | 0.3B/21B/300B | ernie_nothink |
 | [Falcon/Falcon H1](https://huggingface.co/tiiuae) | 0.5B/1.5B/3B/7B/11B/34B/40B/180B | falcon/falcon_h1 |
 | [Gemma/Gemma 2/CodeGemma](https://huggingface.co/google) | 2B/7B/9B/27B | gemma/gemma2 |
 | [Gemma 3/Gemma 3n](https://huggingface.co/google) | 270M/1B/4B/6B/8B/12B/27B | gemma3/gemma3n |
@@ -294,12 +293,13 @@ https://github.com/user-attachments/assets/43b700c6-a178-41db-b1f8-8190a5d3fcfc
 | [GPT-2](https://huggingface.co/openai-community) | 0.1B/0.4B/0.8B/1.5B | - |
 | [GPT-OSS](https://huggingface.co/openai) | 20B/120B | gpt_oss |
 | [Granite 3-4](https://huggingface.co/ibm-granite) | 1B/2B/3B/7B/8B | granite3/granite4 |
-| [Hunyuan (MT)](https://huggingface.co/tencent/) | 7B | hunyuan |
+| [Hunyuan/Hunyuan1.5 (MT)](https://huggingface.co/tencent/) | 0.5B/1.8B/4B/7B/13B | hunyuan/hunyuan_small |
 | [InternLM 2-3](https://huggingface.co/internlm) | 7B/8B/20B | intern2 |
 | [InternVL 2.5-3.5](https://huggingface.co/OpenGVLab) | 1B/2B/4B/8B/14B/30B/38B/78B/241B | intern_vl |
-| [InternLM/Intern-S1-mini](https://huggingface.co/internlm/) | 8B | intern_s1 |
+| [Intern-S1-mini](https://huggingface.co/internlm/) | 8B | intern_s1 |
 | [Kimi-VL](https://huggingface.co/moonshotai) | 16B | kimi_vl |
 | [Ling 2.0 (mini/flash)](https://huggingface.co/inclusionAI) | 16B/100B | bailing_v2 |
+| [LFM 2.5 (VL)](https://huggingface.co/LiquidAI) | 1.2B/1.6B | lfm2/lfm2_vl |
 | [Llama](https://github.com/facebookresearch/llama) | 7B/13B/33B/65B | - |
 | [Llama 2](https://huggingface.co/meta-llama) | 7B/13B/70B | llama2 |
 | [Llama 3-3.3](https://huggingface.co/meta-llama) | 1B/3B/8B/70B | llama3 |
@@ -309,18 +309,17 @@ https://github.com/user-attachments/assets/43b700c6-a178-41db-b1f8-8190a5d3fcfc
 | [LLaVA-NeXT](https://huggingface.co/llava-hf) | 7B/8B/13B/34B/72B/110B | llava_next |
 | [LLaVA-NeXT-Video](https://huggingface.co/llava-hf) | 7B/34B | llava_next_video |
 | [MiMo](https://huggingface.co/XiaomiMiMo) | 7B/309B | mimo/mimo_v2 |
-| [MiniCPM 1-4.1](https://huggingface.co/openbmb) | 0.5B/1B/2B/4B/8B | cpm/cpm3/cpm4 |
+| [MiniCPM 4](https://huggingface.co/openbmb) | 0.5B/8B | cpm4 |
 | [MiniCPM-o-2.6/MiniCPM-V-2.6](https://huggingface.co/openbmb) | 8B | minicpm_o/minicpm_v |
 | [MiniMax-M1/MiniMax-M2](https://huggingface.co/MiniMaxAI/models) | 229B/456B | minimax1/minimax2 |
 | [Ministral 3](https://huggingface.co/mistralai) | 3B/8B/14B | ministral3 |
 | [Mistral/Mixtral](https://huggingface.co/mistralai) | 7B/8x7B/8x22B | mistral |
-| [OLMo](https://huggingface.co/allenai) | 1B/7B | - |
 | [PaliGemma/PaliGemma2](https://huggingface.co/google) | 3B/10B/28B | paligemma |
 | [Phi-3/Phi-3.5](https://huggingface.co/microsoft) | 4B/14B | phi |
 | [Phi-3-small](https://huggingface.co/microsoft) | 7B | phi_small |
-| [Phi-4](https://huggingface.co/microsoft) | 14B | phi4 |
+| [Phi-4-mini/Phi-4](https://huggingface.co/microsoft) | 3.8B/14B | phi4_mini/phi4 |
 | [Pixtral](https://huggingface.co/mistralai) | 12B | pixtral |
-| [Qwen (1-2.5) (Code/Math/MoE/QwQ)](https://huggingface.co/Qwen) | 0.5B/1.5B/3B/7B/14B/32B/72B/110B | qwen |
+| [Qwen2 (Code/Math/MoE/QwQ)](https://huggingface.co/Qwen) | 0.5B/1.5B/3B/7B/14B/32B/72B/110B | qwen |
 | [Qwen3 (MoE/Instruct/Thinking/Next)](https://huggingface.co/Qwen) | 0.6B/1.7B/4B/8B/14B/32B/80B/235B | qwen3/qwen3_nothink |
 | [Qwen2-Audio](https://huggingface.co/Qwen) | 7B | qwen2_audio |
 | [Qwen2.5-Omni](https://huggingface.co/Qwen) | 3B/7B | qwen2_omni |
@@ -329,8 +328,6 @@ https://github.com/user-attachments/assets/43b700c6-a178-41db-b1f8-8190a5d3fcfc
 | [Qwen3-VL](https://huggingface.co/Qwen) | 2B/4B/8B/30B/32B/235B | qwen3_vl |
 | [Seed (OSS/Coder)](https://huggingface.co/ByteDance-Seed) | 8B/36B | seed_oss/seed_coder |
 | [StarCoder 2](https://huggingface.co/bigcode) | 3B/7B/15B | - |
-| [VibeThinker-1.5B](https://huggingface.co/WeiboAI) | 1.5B | qwen3 |
-| [Yi/Yi-1.5 (Code)](https://huggingface.co/01-ai) | 1.5B/6B/9B/34B | yi |
 | [Yuan 2](https://huggingface.co/IEITYuan) | 2B/51B/102B | yuan |

 > [!NOTE]
@@ -516,12 +513,13 @@ huggingface-cli login
 #### 从源码安装

 ```bash
-git clone --depth 1 https://github.com/hiyouga/LLaMA-Factory.git
-cd LLaMA-Factory
-pip install -e ".[metrics]"
+git clone --depth 1 https://github.com/hiyouga/LlamaFactory.git
+cd LlamaFactory
+pip install -e .
+pip install -r requirements/metrics.txt
 ```

-可选的额外依赖项:`metrics`、`deepspeed`。使用 `pip install -e ".[metrics,deepspeed]"` 安装。
+可选的额外依赖项:`metrics`、`deepspeed`。使用 `pip install -e . && pip install -r requirements/metrics.txt -r requirements/deepspeed.txt` 安装。

 其他可选依赖项请参考 `examples/requirements/` 目录下的文件。
@@ -579,36 +577,20 @@ pip install https://github.com/jllllll/bitsandbytes-windows-webui/releases/downl
 <details><summary>昇腾 NPU 用户指南</summary>

-在昇腾 NPU 设备上安装 LLaMA Factory 时,请升级 Python 到 3.10 及以上,并需要指定额外依赖项,使用 `pip install -e . torch-npu==2.7.1` 命令安装。此外,还需要安装 **[Ascend CANN Toolkit 与 Kernels](https://www.hiascend.com/developer/download/community/result?module=cann)**,安装方法请参考[安装教程](https://www.hiascend.com/document/detail/zh/CANNCommunityEdition/80RC2alpha002/quickstart/quickstart/quickstart_18_0004.html)或使用以下命令:
+在昇腾 NPU 设备上安装 LLaMA Factory 时,请升级 Python 到 3.10 及以上,并需要指定额外依赖项,使用 `pip install -r requirements/npu.txt` 命令安装。此外,还需要安装 **Ascend CANN Toolkit 与 Kernels**,安装方法请参考[安装教程](https://llamafactory.readthedocs.io/zh-cn/latest/advanced/npu_installation.html)。
+
+您可以直接下载预安装的最新docker镜像:

 ```bash
-# 请替换 URL 为 CANN 版本和设备型号对应的 URL
-# 安装 CANN Toolkit
-wget https://ascend-repo.obs.cn-east-2.myhuaweicloud.com/Milan-ASL/Milan-ASL%20V100R001C17SPC701/Ascend-cann-toolkit_8.0.RC1.alpha001_linux-"$(uname -i)".run
-bash Ascend-cann-toolkit_8.0.RC1.alpha001_linux-"$(uname -i)".run --install
-# 安装 CANN Kernels
-wget https://ascend-repo.obs.cn-east-2.myhuaweicloud.com/Milan-ASL/Milan-ASL%20V100R001C17SPC701/Ascend-cann-kernels-910b_8.0.RC1.alpha001_linux.run
-bash Ascend-cann-kernels-910b_8.0.RC1.alpha001_linux.run --install
-# 设置环境变量
-source /usr/local/Ascend/ascend-toolkit/set_env.sh
+# Docker Hub
+docker pull hiyouga/llamafactory:latest-npu-a2
+docker pull hiyouga/llamafactory:latest-npu-a3
+
+# quay.io
+docker pull quay.io/ascend/llamafactory:latest-npu-a2
+docker pull quay.io/ascend/llamafactory:latest-npu-a3
 ```

-| 依赖项 | 至少 | 推荐 |
-| ------------ | ------- | -------------- |
-| CANN | 8.0.RC1 | 8.0.0.alpha002 |
-| torch | 2.1.0 | 2.7.1 |
-| torch-npu | 2.1.0 | 2.7.1 |
-| deepspeed | 0.13.2 | 0.13.2 |
-| vllm-ascend | - | 0.7.3 |
-
-请使用 `ASCEND_RT_VISIBLE_DEVICES` 而非 `CUDA_VISIBLE_DEVICES` 来指定运算设备。
-
-如果遇到无法正常推理的情况,请尝试设置 `do_sample: false`。
-
-下载预构建 Docker 镜像:[32GB](http://mirrors.cn-central-221.ovaijisuan.com/detail/130.html) | [64GB](http://mirrors.cn-central-221.ovaijisuan.com/detail/131.html)
-
 #### 安装 BitsAndBytes

 如果要在 Ascend NPU 上进行基于 bitsandbytes 的 QLoRA 量化微调,请执行如下步骤:

File diff suppressed because one or more lines are too long

View File

@@ -32,7 +32,8 @@ RUN pip config set global.index-url "${PIP_INDEX}" && \
 COPY . /app

 # Install LLaMA Factory
-RUN pip install --no-cache-dir --no-build-isolation -e ".[metrics,deepspeed]"
+RUN pip install --no-cache-dir --no-build-isolation -e . && \
+    pip install --no-cache-dir --no-build-isolation -r requirements/metrics.txt -r requirements/deepspeed.txt

 # Rebuild flash attention
 RUN if [ "${INSTALL_FLASHATTN}" == "true" ]; then \

View File

@@ -1,12 +1,13 @@
-# NVIDIA official image (ubuntu-22.04 + cuda-12.4 + python-3.10)
-# https://docs.nvidia.com/deeplearning/frameworks/pytorch-release-notes/rel-24-08.html
-FROM nvcr.io/nvidia/pytorch:24.05-py3
+# NVIDIA official image (ubuntu-24.04 + cuda-12.9.1 + python-3.12)
+# https://docs.nvidia.com/deeplearning/frameworks/pytorch-release-notes/rel-25-06.html
+FROM nvcr.io/nvidia/pytorch:25.06-py3

 ENV DEBIAN_FRONTEND=noninteractive
 ENV PIP_ROOT_USER_ACTION=ignore
 ENV PYPI_MIRROR=https://mirrors.aliyun.com/pypi/simple/
 ENV PYPI_TRUSTED_HOST=mirrors.aliyun.com
 ENV APT_MIRROR=https://mirrors.tuna.tsinghua.edu.cn/ubuntu/
+ENV PIP_CONSTRAINT=""

 RUN pip install --upgrade pip setuptools wheel "hatchling>=1.18.0" editables --trusted-host ${PYPI_TRUSTED_HOST} --index-url ${PYPI_MIRROR}
@@ -14,20 +15,14 @@ RUN pip uninstall -y torch torchvision torch-tensorrt \
     flash_attn transformer-engine \
     cudf dask-cuda cugraph cugraph-service-server cuml raft-dask cugraph-dgl cugraph-pyg dask-cudf

-RUN pip install torch==2.6.0 torchvision==0.21.0 torchaudio==2.6.0 --index-url https://download.pytorch.org/whl/cu124
+RUN pip install torch==2.8.0 torchvision==0.23.0 torchaudio==2.8.0 --index-url https://download.pytorch.org/whl/cu129

 RUN pip uninstall -y opencv opencv-python opencv-python-headless && \
-    rm -rf /usr/local/lib/python3.10/dist-packages/cv2/ && \
+    rm -rf /usr/local/lib/python3.12/dist-packages/cv2/ && \
     pip install opencv-python-headless==4.11.0.86 --trusted-host ${PYPI_TRUSTED_HOST} --index-url ${PYPI_MIRROR}

-RUN pip install "numpy==1.26.4" "optree>=0.13.0" "spacy==3.7.5" "weasel==0.4.1" \
-    transformer-engine[pytorch]==2.2.0 megatron-core==0.13.0 deepspeed==0.16.4 \
-    --trusted-host ${PYPI_TRUSTED_HOST} --index-url ${PYPI_MIRROR}
-
-RUN pip install https://github.com/Dao-AILab/flash-attention/releases/download/v2.7.2.post1/flash_attn-2.7.2.post1+cu12torch2.6cxx11abiFALSE-cp310-cp310-linux_x86_64.whl
-
-# RUN pip install vllm==0.8.4 \
-#     --trusted-host ${PYPI_TRUSTED_HOST} --index-url ${PYPI_MIRROR}
+RUN pip install --trusted-host mirrors.aliyun.com --index-url ${PYPI_MIRROR} \
+    "megatron-core>=0.13.0,<0.14.0" "deepspeed==0.16.4"

 WORKDIR /build
@@ -37,6 +32,8 @@ RUN pip uninstall -y apex && \
     pip install -v --disable-pip-version-check --no-cache-dir --no-build-isolation \
     --config-settings "--build-option=--cpp_ext --cuda_ext --parallel 32" ${apex_url}

+RUN pip install --no-build-isolation transformer_engine[pytorch]
+
 RUN rm -rf /build

 WORKDIR /workspace
@@ -53,14 +50,17 @@ RUN apt-get update && apt-get install -y zip
 RUN apt-get install -y openjdk-21-jdk
 ENV JAVA_HOME /usr/lib/jvm/java-21-openjdk-amd64

-# pip install LLaMA-Factory
+ARG REPO_URL=https://github.com/hiyouga/LlamaFactory.git
+ARG BRANCH=main

 WORKDIR /app

-# Copy the application into the image
-COPY . /app
+# Clone the repository
+RUN git clone --depth 1 --branch ${BRANCH} ${REPO_URL} /app || \
+    git clone --depth 1 ${REPO_URL} /app

 # Install LLaMA Factory
-RUN pip install --no-cache-dir -e ".[metrics]" --no-build-isolation
+RUN pip install --no-cache-dir -e . --no-build-isolation && \
+    pip install --no-cache-dir -r requirements/metrics.txt --no-build-isolation

 RUN pip install "git+https://github.com/alibaba/roll.git#subdirectory=mcore_adapter"

View File

@@ -35,7 +35,8 @@ COPY . /app
 # Install torch-npu
 RUN pip uninstall -y torch torchvision torchaudio && \
     pip install --no-cache-dir "torch==2.7.1" "torch-npu==2.7.1" "torchvision==0.22.1" "torchaudio==2.7.1" --index-url "${PYTORCH_INDEX}" && \
-    pip install --no-cache-dir -e ".[metrics]" --no-build-isolation
+    pip install --no-cache-dir -e . --no-build-isolation && \
+    pip install --no-cache-dir -r requirements/metrics.txt --no-build-isolation

 # Set up volumes
 # VOLUME [ "/root/.cache/huggingface", "/app/shared_data", "/app/output" ]

View File

@@ -34,7 +34,8 @@ COPY . /app
 # Reinstall pytorch rocm and install LLaMA Factory
 RUN pip uninstall -y torch torchvision torchaudio && \
-    pip install --no-cache-dir --no-build-isolation -e --pre ".[metrics,deepspeed]" --index-url "${PYTORCH_INDEX}"
+    pip install --no-cache-dir --no-build-isolation -e --pre . --index-url "${PYTORCH_INDEX}" && \
+    pip install --no-cache-dir --no-build-isolation -r requirements/metrics.txt -r requirements/deepspeed.txt --index-url "${PYTORCH_INDEX}"

 # Rebuild flash attention
 RUN if [ "${INSTALL_FLASHATTN}" == "true" ]; then \

View File

@@ -0,0 +1,38 @@
### model
model_name_or_path: Qwen/Qwen2.5-0.5B-Instruct
trust_remote_code: true
### method
stage: sft
do_train: true
finetuning_type: full
use_eaft_loss: true
### dataset
dataset: identity,alpaca_en_demo
template: qwen
cutoff_len: 2048
max_samples: 1000
overwrite_cache: true
preprocessing_num_workers: 16
dataloader_num_workers: 4
### output
output_dir: qwen2.5-0_5b/full/sft_eaft
logging_steps: 1
save_steps: 500
plot_loss: true
overwrite_output_dir: true
save_only_model: false
report_to: none # choices: [none, wandb, tensorboard, swanlab, mlflow]
### train
per_device_train_batch_size: 2
gradient_accumulation_steps: 8
learning_rate: 1.0e-5
num_train_epochs: 3.0
lr_scheduler_type: cosine
warmup_ratio: 0.1
bf16: true
ddp_timeout: 180000000

View File

@@ -5,6 +5,6 @@ infer_backend: ktransformers # choices: [huggingface, vllm, sglang, ktransforme
 trust_remote_code: true
 use_kt: true # use KTransformers as LoRA sft backend to inference
-kt_optimize_rule: examples/kt_optimize_rules/DeepSeek-V2-Lite-Chat-sft-amx.yaml
+kt_optimize_rule: examples/ktransformers/kt_optimize_rules/DeepSeek-V2-Lite-Chat-sft-amx.yaml
 cpu_infer: 32
 chunk_size: 8192

View File

@@ -1,9 +1,9 @@
 model_name_or_path: opensourcerelease/DeepSeek-V3-bf16
-template: deepseek
+template: deepseek3
 infer_backend: ktransformers # choices: [huggingface, vllm, sglang, ktransformers]
 trust_remote_code: true
 use_kt: true # use KTransformers as LoRA sft backend to inference
-kt_optimize_rule: examples/kt_optimize_rules/DeepSeek-V3-Chat-sft-amx-multi-gpu.yaml
+kt_optimize_rule: examples/ktransformers/kt_optimize_rules/DeepSeek-V3-Chat-sft-amx-multi-gpu.yaml
 cpu_infer: 32
 chunk_size: 8192

View File

@@ -1,10 +1,10 @@
 model_name_or_path: opensourcerelease/DeepSeek-V3-bf16
 adapter_name_or_path: saves/Kllama_deepseekV3
-template: deepseek
+template: deepseek3
 infer_backend: ktransformers # choices: [huggingface, vllm, sglang, ktransformers]
 trust_remote_code: true
 use_kt: true # use KTransformers as LoRA sft backend to inference
-kt_optimize_rule: examples/kt_optimize_rules/DeepSeek-V3-Chat-sft-amx-multi-gpu.yaml
+kt_optimize_rule: examples/ktransformers/kt_optimize_rules/DeepSeek-V3-Chat-sft-amx-multi-gpu.yaml
 cpu_infer: 32
 chunk_size: 8192

View File

@@ -5,6 +5,6 @@ infer_backend: ktransformers # choices: [huggingface, vllm, sglang, ktransforme
 trust_remote_code: true
 use_kt: true # use KTransformers as LoRA sft backend to inference
-kt_optimize_rule: examples/kt_optimize_rules/Qwen3Moe-sft-amx.yaml
+kt_optimize_rule: examples/ktransformers/kt_optimize_rules/Qwen3Moe-sft-amx.yaml
 cpu_infer: 32
 chunk_size: 8192

View File

@@ -10,7 +10,7 @@ lora_rank: 8
 lora_target: all

 ### dataset
-dataset: identity
+dataset: identity, alpaca_en_demo
 template: deepseek
 cutoff_len: 2048
 max_samples: 100000
@@ -40,7 +40,7 @@ resume_from_checkpoint: null

 ### ktransformers
 use_kt: true # use KTransformers as LoRA sft backend
-kt_optimize_rule: examples/kt_optimize_rules/DeepSeek-V2-Lite-Chat-sft-amx.yaml
+kt_optimize_rule: examples/ktransformers/kt_optimize_rules/DeepSeek-V2-Lite-Chat-sft-amx.yaml
 cpu_infer: 32
 chunk_size: 8192

View File

@@ -10,8 +10,8 @@ lora_rank: 8
 lora_target: all

 ### dataset
-dataset: identity
-template: deepseek
+dataset: identity, alpaca_en_demo
+template: deepseek3
 cutoff_len: 2048
 max_samples: 100000
 overwrite_cache: true
@@ -40,7 +40,7 @@ resume_from_checkpoint: null

 ### ktransformers
 use_kt: true # use KTransformers as LoRA sft backend
-kt_optimize_rule: examples/kt_optimize_rules/DeepSeek-V3-Chat-sft-amx-multi-gpu.yaml
+kt_optimize_rule: examples/ktransformers/kt_optimize_rules/DeepSeek-V3-Chat-sft-amx-multi-gpu.yaml
 cpu_infer: 32
 chunk_size: 8192

View File

@@ -40,7 +40,7 @@ resume_from_checkpoint: null

 ### ktransformers
 use_kt: true # use KTransformers as LoRA sft backend
-kt_optimize_rule: examples/kt_optimize_rules/Qwen3Moe-sft-amx.yaml
+kt_optimize_rule: examples/ktransformers/kt_optimize_rules/Qwen3Moe-sft-amx.yaml
 cpu_infer: 32
 chunk_size: 8192

View File

@@ -0,0 +1,34 @@
model: Qwen/Qwen3-0.6B
trust_remote_code: true
model_class: llm
template: qwen3_nothink
kernel_config:
name: auto
include_kernels: auto # choice: null/true/false/auto/kernel_id1,kernel_id2,kernel_id3, default is null
quant_config: null
dist_config:
name: fsdp2
dcp_path: null # /mnt/f/pretrain_models/Qwen3-0.6B-dcp
init_config:
name: init_on_meta
### data
train_dataset: data/v1_sft_demo.yaml
### training
output_dir: outputs/test_fsdp2
micro_batch_size: 1
global_batch_size: 1
cutoff_len: 2048
learning_rate: 1.0e-4
bf16: false
max_steps: 10
### sample
sample_backend: hf
max_new_tokens: 128

View File

@@ -30,7 +30,6 @@ classifiers = [
     "License :: OSI Approved :: Apache Software License",
     "Operating System :: OS Independent",
     "Programming Language :: Python :: 3",
-    "Programming Language :: Python :: 3.10",
     "Programming Language :: Python :: 3.11",
     "Programming Language :: Python :: 3.12",
     "Programming Language :: Python :: 3.13",
@@ -63,7 +62,7 @@ dependencies = [
     "hf-transfer",
     "safetensors",
     # python
-    "av",
+    "av>=10.0.0,<=16.0.0",
     "fire",
     "omegaconf",
     "packaging",
@@ -73,14 +72,9 @@ dependencies = [
     # api
     "uvicorn",
     "fastapi",
-    "sse-starlette"
+    "sse-starlette",
 ]

-[project.optional-dependencies]
-dev = ["pre-commit", "ruff", "pytest", "build"]
-metrics = ["nltk", "jieba", "rouge-chinese"]
-deepspeed = ["deepspeed>=0.10.0,<=0.16.9"]
-
 [project.scripts]
 llamafactory-cli = "llamafactory.cli:main"
 lmf = "llamafactory.cli:main"

View File

@@ -0,0 +1 @@
deepspeed>=0.10.0,<=0.16.9

requirements/dev.txt (new file, 4 lines added)
View File

@@ -0,0 +1,4 @@
pre-commit
ruff
pytest
build

requirements/metrics.txt (new file, 3 lines added)
View File

@@ -0,0 +1,3 @@
nltk
jieba
rouge-chinese

requirements/npu.txt (new file, 4 lines added)
View File

@@ -0,0 +1,4 @@
torch==2.7.1
torch-npu==2.7.1
torchvision==0.22.1
torchaudio==2.7.1

View File

@@ -0,0 +1,32 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from transformers import AutoTokenizer, Qwen3Config, Qwen3ForCausalLM
if __name__ == "__main__":
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-4B-Instruct-2507")
config = Qwen3Config(
hidden_size=1408,
image_size=336,
intermediate_size=5632,
num_attention_heads=16,
num_hidden_layers=4,
vision_output_dim=4096,
)
model = Qwen3ForCausalLM.from_config(config)
model.save_pretrained("tiny-qwen3")
tokenizer.save_pretrained("tiny-qwen3")
model.push_to_hub("llamafactory/tiny-random-qwen3")
tokenizer.push_to_hub("llamafactory/tiny-random-qwen3")
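For orientation, a tiny checkpoint like the one pushed above is typically pulled back down in fast unit tests. The loading sketch below is an illustration under that assumption and is not part of this commit; it only assumes the Hub repo id used in the script.

```python
# Sketch: consume the tiny random Qwen3 checkpoint in a quick smoke test.
# Requires Hub access; generation settings are arbitrary.
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("llamafactory/tiny-random-qwen3")
model = AutoModelForCausalLM.from_pretrained("llamafactory/tiny-random-qwen3")

inputs = tokenizer("Hello", return_tensors="pt")
outputs = model.generate(**inputs, max_new_tokens=8)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```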

View File

@@ -28,7 +28,7 @@ try:
     jieba.setLogLevel(logging.CRITICAL)
     jieba.initialize()
 except ImportError:
-    print("Please install llamafactory with `pip install -e .[metrics]`.")
+    print("Please install llamafactory with `pip install -r requirements/metrics.txt`.")
     raise

scripts/hf2dcp.py (new file, 55 lines added)
View File

@@ -0,0 +1,55 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Convert a HuggingFace model to DCP checkpoint format.
Usage:
python scripts/hf2dcp.py convert --hf_path=/path/to/hf --dcp_path=/path/to/dcp
Arguments:
hf_path: Path to the HuggingFace model directory.
dcp_path: Output path (directory) for DCP checkpoint.
"""
import fire
import torch
import torch.distributed.checkpoint as dcp
from transformers import AutoModelForCausalLM
def convert(hf_path: str, dcp_path: str) -> None:
"""Convert HF model weights to DCP.
Args:
hf_path: HuggingFace model directory.
dcp_path: Output path (directory) for DCP checkpoint.
"""
if not hf_path or not dcp_path:
raise ValueError("Both 'hf_path' and 'dcp_path' are required.")
print(f"Loading HF model from {hf_path}...")
model = AutoModelForCausalLM.from_pretrained(hf_path, device_map="cpu", torch_dtype=torch.bfloat16)
print(f"Saving to DCP format at {dcp_path}...")
dcp.save(model.state_dict(), checkpoint_id=dcp_path)
print("Done!")
def help() -> None:
"""Show help message."""
print(__doc__)
if __name__ == "__main__":
fire.Fire({"convert": convert, "help": help, "--convert": convert})
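A quick way to sanity-check the conversion is to load the DCP checkpoint back into a freshly loaded model. The round-trip snippet below is a sketch, not part of the commit; it reuses the placeholder paths from the script's docstring and assumes a torch version whose `torch.distributed.checkpoint.load` accepts `checkpoint_id` (2.3+).

```python
# Sketch: round-trip check for a checkpoint written by scripts/hf2dcp.py.
import torch
import torch.distributed.checkpoint as dcp
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("/path/to/hf", device_map="cpu", torch_dtype=torch.bfloat16)
state_dict = model.state_dict()
dcp.load(state_dict, checkpoint_id="/path/to/dcp")  # fills the tensors in place
model.load_state_dict(state_dict)
print("DCP checkpoint loaded back into the HF model without key mismatches.")
```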

View File

@@ -65,11 +65,17 @@ def merge_dataset(
         if not data_args.streaming:
             logger.warning_rank0_once("We recommend using `mix_strategy=concat` in non-streaming mode.")

+        strategy_map: str = {
+            "interleave_under": "first_exhausted",
+            "interleave_over": "all_exhausted",
+            "interleave_once": "all_exhausted_without_replacement",
+        }[data_args.mix_strategy]
         return interleave_datasets(
             datasets=all_datasets,
             probabilities=data_args.interleave_probs,
             seed=seed,
-            stopping_strategy="first_exhausted" if data_args.mix_strategy.endswith("under") else "all_exhausted",
+            stopping_strategy=strategy_map,  # type: ignore
         )
     else:
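For reference, a minimal sketch of how the first two stopping strategies behave in `datasets.interleave_datasets`. The third value, `all_exhausted_without_replacement`, is assumed to require a `datasets` release that ships the feature referenced in #10112 and is therefore not exercised here.

```python
# Sketch: stopping strategies mapped above, with round-robin sampling (no probabilities).
from datasets import Dataset, interleave_datasets

small = Dataset.from_dict({"text": ["a", "b"]})
large = Dataset.from_dict({"text": ["1", "2", "3", "4"]})

under = interleave_datasets([small, large], stopping_strategy="first_exhausted")  # stop at the shortest
over = interleave_datasets([small, large], stopping_strategy="all_exhausted")     # oversample the shortest
print(len(under), len(over))  # 4 and 8 when alternating between the two datasets
```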

View File

@@ -2092,6 +2092,73 @@ class VideoLlavaPlugin(BasePlugin):
         return messages


+@dataclass
+class LFMVLPlugin(BasePlugin):
+    r"""Plugin for LFM2.5-VL vision-language models.
+
+    LFM2.5-VL uses dynamic image token counts based on image resolution.
+    The image processor returns spatial_shapes tensor with [height, width] grid dimensions.
+    Token count per image = (spatial_h * spatial_w) / (downsample_factor^2)
+    """
+
+    @override
+    def _get_mm_inputs(
+        self,
+        images: list["ImageInput"],
+        videos: list["VideoInput"],
+        audios: list["AudioInput"],
+        processor: "MMProcessor",
+    ) -> dict[str, "torch.Tensor"]:
+        image_processor: BaseImageProcessor = getattr(processor, "image_processor", None)
+        mm_inputs = {}
+        if len(images) != 0:
+            images = self._regularize_images(
+                images,
+                image_max_pixels=getattr(processor, "image_max_pixels", 768 * 768),
+                image_min_pixels=getattr(processor, "image_min_pixels", 32 * 32),
+            )["images"]
+            mm_inputs.update(image_processor(images, return_tensors="pt"))
+
+        return mm_inputs
+
+    @override
+    def process_messages(
+        self,
+        messages: list[dict[str, str]],
+        images: list["ImageInput"],
+        videos: list["VideoInput"],
+        audios: list["AudioInput"],
+        processor: Optional["MMProcessor"],
+    ) -> list[dict[str, str]]:
+        self._validate_input(processor, images, videos, audios)
+        self._validate_messages(messages, images, videos, audios)
+        num_image_tokens = 0
+        messages = deepcopy(messages)
+        image_processor: BaseImageProcessor = getattr(processor, "image_processor")
+        downsample_factor: int = getattr(image_processor, "downsample_factor", 2)
+
+        if self.expand_mm_tokens and len(images) > 0:
+            mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
+            spatial_shapes = mm_inputs.get("spatial_shapes", [])
+        else:
+            spatial_shapes = []
+
+        for message in messages:
+            content = message["content"]
+            while IMAGE_PLACEHOLDER in content:
+                if self.expand_mm_tokens and len(spatial_shapes) > num_image_tokens:
+                    h, w = spatial_shapes[num_image_tokens].tolist()
+                    image_seqlen = (h * w) // (downsample_factor * downsample_factor)
+                else:
+                    image_seqlen = 1
+
+                content = content.replace(IMAGE_PLACEHOLDER, "{{image}}" * image_seqlen, 1)
+                num_image_tokens += 1
+
+            message["content"] = content.replace("{{image}}", self.image_token)
+
+        return messages
+
+
 PLUGINS = {
     "base": BasePlugin,
     "ernie_vl": ErnieVLPlugin,
@@ -2104,6 +2171,7 @@ PLUGINS = {
     "llava": LlavaPlugin,
     "llava_next": LlavaNextPlugin,
     "llava_next_video": LlavaNextVideoPlugin,
+    "lfm2_vl": LFMVLPlugin,
     "minicpm_v": MiniCPMVPlugin,
     "mllama": MllamaPlugin,
     "paligemma": PaliGemmaPlugin,

View File

@@ -649,42 +649,6 @@ register_template(
) )
register_template(
name="aquila",
format_user=StringFormatter(slots=["Human: {{content}}###Assistant:"]),
format_assistant=StringFormatter(slots=["{{content}}###"]),
format_system=StringFormatter(slots=["System: {{content}}###"]),
default_system=(
"A chat between a curious human and an artificial intelligence assistant. "
"The assistant gives helpful, detailed, and polite answers to the human's questions."
),
stop_words=["</s>"],
)
register_template(
name="atom",
format_user=StringFormatter(
slots=[{"bos_token"}, "Human: {{content}}\n", {"eos_token"}, {"bos_token"}, "Assistant:"]
),
format_assistant=StringFormatter(slots=["{{content}}\n", {"eos_token"}]),
)
register_template(
name="baichuan",
format_user=StringFormatter(slots=[{"token": "<reserved_102>"}, "{{content}}", {"token": "<reserved_103>"}]),
efficient_eos=True,
)
register_template(
name="baichuan2",
format_user=StringFormatter(slots=["<reserved_106>{{content}}<reserved_107>"]),
efficient_eos=True,
)
register_template(
name="bailing",
format_user=StringFormatter(slots=["<role>HUMAN</role>{{content}}<role>ASSISTANT</role>"]),
@@ -712,20 +676,6 @@ register_template(
)
register_template(
name="belle",
format_user=StringFormatter(slots=["Human: {{content}}\n\nBelle: "]),
format_assistant=StringFormatter(slots=["{{content}}", {"eos_token"}, "\n\n"]),
format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
)
register_template(
name="bluelm",
format_user=StringFormatter(slots=[{"token": "[|Human|]:"}, "{{content}}", {"token": "[|AI|]:"}]),
)
register_template(
name="breeze",
format_user=StringFormatter(slots=["[INST] {{content}} [/INST] "]),
@@ -734,14 +684,6 @@ register_template(
)
register_template(
name="chatglm2",
format_user=StringFormatter(slots=["[Round {{idx}}]\n\n问:{{content}}\n\n答:"]),
format_prefix=EmptyFormatter(slots=[{"token": "[gMASK]"}, {"token": "sop"}]),
efficient_eos=True,
)
register_template(
name="chatglm3",
format_user=StringFormatter(slots=[{"token": "<|user|>"}, "\n", "{{content}}", {"token": "<|assistant|>"}]),
@@ -784,29 +726,6 @@ register_template(
)
register_template(
name="codegeex2",
format_prefix=EmptyFormatter(slots=[{"token": "[gMASK]"}, {"token": "sop"}]),
)
register_template(
name="codegeex4",
format_user=StringFormatter(slots=["<|user|>\n{{content}}<|assistant|>\n"]),
format_system=StringFormatter(slots=["<|system|>\n{{content}}"]),
format_function=FunctionFormatter(slots=["{{content}}"], tool_format="glm4"),
format_observation=StringFormatter(slots=["<|observation|>\n{{content}}<|assistant|>\n"]),
format_tools=ToolFormatter(tool_format="glm4"),
format_prefix=EmptyFormatter(slots=["[gMASK]<sop>"]),
default_system=(
"你是一位智能编程助手你叫CodeGeeX。你会为用户回答关于编程、代码、计算机方面的任何问题"
"并提供格式规范、可以执行、准确安全的代码,并在必要时提供详细的解释。"
),
stop_words=["<|user|>", "<|observation|>"],
efficient_eos=True,
)
register_template(
name="cohere",
format_user=StringFormatter(
@@ -822,25 +741,6 @@ register_template(
)
register_template(
name="cpm",
format_user=StringFormatter(slots=["<用户>{{content}}<AI>"]),
format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
)
# copied from chatml template
register_template(
name="cpm3",
format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
format_observation=StringFormatter(slots=["<|im_start|>tool\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
stop_words=["<|im_end|>"],
)
# copied from chatml template
register_template(
name="cpm4",
@@ -1239,19 +1139,12 @@ register_template(
register_template(
name="intern",
format_user=StringFormatter(slots=["<|User|>:{{content}}\n<|Bot|>:"]),
format_assistant=StringFormatter(slots=["{{content}}<eoa>\n"]),
format_system=StringFormatter(slots=["<|System|>:{{content}}\n"]),
format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
default_system=(
"You are an AI assistant whose name is InternLM (书生·浦语).\n"
"- InternLM (书生·浦语) is a conversational language model that is developed by Shanghai AI Laboratory "
"(上海人工智能实验室). It is designed to be helpful, honest, and harmless.\n"
"- InternLM (书生·浦语) can understand and communicate fluently in the language "
"chosen by the user such as English and 中文."
),
stop_words=["<eoa>"],
)
register_template(
name="hunyuan_small",
format_user=StringFormatter(slots=["<hy_User>{{content}}<hy_place▁holder▁no▁8>"]),
format_assistant=StringFormatter(slots=["{{content}}<hy_place▁holder▁no▁2>"]),
format_system=StringFormatter(slots=["{{content}}<hy_place▁holder▁no▁3>"]),
format_prefix=EmptyFormatter(slots=["<hy_begin▁of▁sentence>"]),
stop_words=["<hy_place▁holder▁no▁2>"],
)
@@ -1330,6 +1223,47 @@ register_template(
)
register_template(
name="lfm2",
format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
format_function=FunctionFormatter(slots=["{{content}}<|im_end|>\n"], tool_format="lfm2"),
format_observation=StringFormatter(
slots=[
"<|im_start|>tool\n<|tool_response_start|>{{content}}<|tool_response_end|><|im_end|>\n"
"<|im_start|>assistant\n"
]
),
format_tools=ToolFormatter(tool_format="lfm2"),
default_system="You are a helpful AI assistant.",
stop_words=["<|im_end|>"],
tool_call_words=("<|tool_call_start|>", "<|tool_call_end|>"),
replace_eos=True,
)
register_template(
name="lfm2_vl",
format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
format_function=FunctionFormatter(slots=["{{content}}<|im_end|>\n"], tool_format="lfm2"),
format_observation=StringFormatter(
slots=[
"<|im_start|>tool\n<|tool_response_start|>{{content}}<|tool_response_end|><|im_end|>\n"
"<|im_start|>assistant\n"
]
),
format_tools=ToolFormatter(tool_format="lfm2"),
default_system="You are a helpful multimodal assistant by Liquid AI.",
stop_words=["<|im_end|>"],
tool_call_words=("<|tool_call_start|>", "<|tool_call_end|>"),
replace_eos=True,
mm_plugin=get_mm_plugin(name="lfm2_vl", image_token="<image>"),
)
register_template(
name="llama2",
format_user=StringFormatter(slots=[{"bos_token"}, "[INST] {{content}} [/INST]"]),
@@ -1576,23 +1510,6 @@ register_template(
)
# copied from chatml template
register_template(
name="marco",
format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
format_observation=StringFormatter(slots=["<|im_start|>tool\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
default_system=(
"你是一个经过良好训练的AI助手你的名字是Marco-o1."
"由阿里国际数字商业集团的AI Business创造.\n## 重要!!!!!\n"
"当你回答问题时,你的思考应该在<Thought>内完成,<Output>内输出你的结果。\n"
"<Thought>应该尽可能是英文但是有2个特例一个是对原文中的引用另一个是是数学应该使用markdown格式<Output>内的输出需要遵循用户输入的语言。\n"
),
stop_words=["<|im_end|>"],
)
# copied from qwen template
register_template(
name="mimo",
@@ -1804,13 +1721,6 @@ register_template(
)
register_template(
name="orion",
format_user=StringFormatter(slots=["Human: {{content}}\n\nAssistant: ", {"eos_token"}]),
format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
)
register_template(
name="paligemma",
format_user=StringFormatter(slots=["{{content}}\n"]),
@@ -1869,6 +1779,17 @@ register_template(
)
register_template(
name="phi4_mini",
format_user=StringFormatter(slots=["<|user|>{{content}}<|end|><|assistant|>"]),
format_assistant=StringFormatter(slots=["{{content}}<|end|>"]),
format_system=StringFormatter(slots=["<|system|>{{content}}<|end|>"]),
format_tools=StringFormatter(slots=["<|tool|>{{content}}<|/tool|>"]),
stop_words=["<|end|>"],
replace_eos=True,
)
# copied from ministral template
register_template(
name="pixtral",
@@ -2104,41 +2025,6 @@ register_template(
)
# copied from llama3 template
register_template(
name="skywork_o1",
format_user=StringFormatter(
slots=[
(
"<|start_header_id|>user<|end_header_id|>\n\n{{content}}<|eot_id|>"
"<|start_header_id|>assistant<|end_header_id|>\n\n"
)
]
),
format_assistant=StringFormatter(slots=["{{content}}<|eot_id|>"]),
format_system=StringFormatter(slots=["<|start_header_id|>system<|end_header_id|>\n\n{{content}}<|eot_id|>"]),
format_function=FunctionFormatter(slots=["{{content}}<|eot_id|>"], tool_format="llama3"),
format_observation=StringFormatter(
slots=[
(
"<|start_header_id|>ipython<|end_header_id|>\n\n{{content}}<|eot_id|>"
"<|start_header_id|>assistant<|end_header_id|>\n\n"
)
]
),
format_tools=ToolFormatter(tool_format="llama3"),
format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
default_system=(
"You are Skywork-o1, a thinking model developed by Skywork AI, specializing in solving complex problems "
"involving mathematics, coding, and logical reasoning through deep thought. When faced with a user's request, "
"you first engage in a lengthy and in-depth thinking process to explore possible solutions to the problem. "
"After completing your thoughts, you then provide a detailed explanation of the solution process "
"in your response."
),
stop_words=["<|eot_id|>", "<|eom_id|>"],
)
register_template(
name="smollm",
format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
@@ -2175,13 +2061,6 @@ register_template(
)
register_template(
name="telechat",
format_user=StringFormatter(slots=["<_user>{{content}}<_bot>"]),
format_system=StringFormatter(slots=["<_system>{{content}}<_end>"]),
)
register_template(
name="telechat2",
format_user=StringFormatter(slots=["<_user>{{content}}<_bot>"]),
@@ -2225,32 +2104,6 @@ register_template(
)
register_template(
name="xverse",
format_user=StringFormatter(slots=["Human: {{content}}\n\nAssistant: "]),
)
register_template(
name="yayi",
format_user=StringFormatter(slots=[{"token": "<|Human|>"}, ":\n{{content}}\n\n", {"token": "<|YaYi|>"}, ":"]),
format_assistant=StringFormatter(slots=["{{content}}\n\n"]),
format_system=StringFormatter(slots=[{"token": "<|System|>"}, ":\n{{content}}\n\n"]),
default_system=(
"You are a helpful, respectful and honest assistant named YaYi "
"developed by Beijing Wenge Technology Co.,Ltd. "
"Always answer as helpfully as possible, while being safe. "
"Your answers should not include any harmful, unethical, "
"racist, sexist, toxic, dangerous, or illegal content. "
"Please ensure that your responses are socially unbiased and positive in nature.\n\n"
"If a question does not make any sense, or is not factually coherent, "
"explain why instead of answering something not correct. "
"If you don't know the answer to a question, please don't share false information."
),
stop_words=["<|End|>"],
)
# copied from chatml template
register_template(
name="yi",
@@ -2278,6 +2131,21 @@ register_template(
)
register_template(
name="youtu",
format_user=StringFormatter(slots=["<|User|>{{content}}<|Assistant|>"]),
format_assistant=StringFormatter(slots=["{{content}}<|end_of_text|>"]),
format_system=StringFormatter(slots=["{{content}}"]),
format_function=FunctionFormatter(slots=["{{content}}"], tool_format="default"),
format_observation=StringFormatter(slots=["<tool_response>\n{{content}}\n</tool_response><|Assistant|>"]),
format_tools=ToolFormatter(tool_format="default"),
format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
stop_words=["<|end_of_text|>"],
replace_eos=True,
template_class=ReasoningTemplate,
)
register_template(
name="yuan",
format_user=StringFormatter(slots=["{{content}}", {"token": "<sep>"}]),
@@ -2292,10 +2160,3 @@ register_template(
format_system=StringFormatter(slots=["<|system|>\n{{content}}", {"eos_token"}]), format_system=StringFormatter(slots=["<|system|>\n{{content}}", {"eos_token"}]),
default_system="You are Zephyr, a helpful assistant.", default_system="You are Zephyr, a helpful assistant.",
) )
register_template(
name="ziya",
format_user=StringFormatter(slots=["<human>:{{content}}\n<bot>:"]),
format_assistant=StringFormatter(slots=["{{content}}\n"]),
)

View File

@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import ast
import json
import re
from abc import ABC, abstractmethod
@@ -101,6 +102,8 @@ LING_TOOL_PROMPT = (
""""arguments": <args-json-object>}}\n</tool_call>""" """"arguments": <args-json-object>}}\n</tool_call>"""
) )
LFM2_TOOL_PROMPT = "List of tools: <|tool_list_start|>{tool_text}<|tool_list_end|>"
@dataclass @dataclass
class ToolUtils(ABC): class ToolUtils(ABC):
@@ -546,10 +549,115 @@ class LingToolUtils(QwenToolUtils):
return LING_TOOL_PROMPT.format(tool_text=tool_text) + "\n" + "detailed thinking off"
class LFM2ToolUtils(ToolUtils):
r"""LFM2.5 tool using template with Pythonic function call syntax."""
@override
@staticmethod
def tool_formatter(tools: list[dict[str, Any]]) -> str:
tool_list = []
for tool in tools:
tool = tool.get("function", tool) if tool.get("type") == "function" else tool
tool_list.append(tool)
return LFM2_TOOL_PROMPT.format(tool_text=json.dumps(tool_list, ensure_ascii=False))
@override
@staticmethod
def function_formatter(functions: list["FunctionCall"]) -> str:
calls = []
for name, args_json in functions:
args = json.loads(args_json)
kwargs_parts = []
for key, value in args.items():
if isinstance(value, str):
kwargs_parts.append(f'{key}="{value}"')
else:
kwargs_parts.append(f"{key}={json.dumps(value, ensure_ascii=False)}")
calls.append(f"{name}({', '.join(kwargs_parts)})")
return f"<|tool_call_start|>[{', '.join(calls)}]<|tool_call_end|>"
@staticmethod
def _ast_to_value(node: ast.AST) -> Any:
"""Convert an AST node to a Python value, handling JSON-style booleans/null."""
# Handle JSON-style true/false/null as Name nodes
if isinstance(node, ast.Name):
if node.id == "true":
return True
elif node.id == "false":
return False
elif node.id == "null":
return None
else:
raise ValueError(f"Unknown identifier: {node.id}")
# Use literal_eval for other cases (strings, numbers, lists, dicts)
return ast.literal_eval(node)
@override
@staticmethod
def tool_extractor(content: str) -> Union[str, list["FunctionCall"]]:
# Extract content between tool call markers
start_marker = "<|tool_call_start|>"
end_marker = "<|tool_call_end|>"
start_idx = content.find(start_marker)
if start_idx == -1:
return content
end_idx = content.find(end_marker, start_idx)
if end_idx == -1:
return content
tool_call_str = content[start_idx + len(start_marker) : end_idx].strip()
# Parse Pythonic function call syntax using AST
try:
tree = ast.parse(tool_call_str, mode="eval")
except SyntaxError:
return content
# Handle both single call and list of calls
if isinstance(tree.body, ast.List):
call_nodes = tree.body.elts
elif isinstance(tree.body, ast.Call):
call_nodes = [tree.body]
else:
return content
results = []
for node in call_nodes:
if not isinstance(node, ast.Call):
return content
# Extract function name
if isinstance(node.func, ast.Name):
func_name = node.func.id
else:
return content
# Extract keyword arguments
args_dict = {}
for keyword in node.keywords:
key = keyword.arg
try:
value = LFM2ToolUtils._ast_to_value(keyword.value)
except (ValueError, SyntaxError):
return content
args_dict[key] = value
results.append(FunctionCall(func_name, json.dumps(args_dict, ensure_ascii=False)))
return results if results else content
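A quick round-trip sketch of the two LFM2 helpers above (illustrative; `FunctionCall` is assumed to be the project's `(name, arguments)` named tuple):

call = FunctionCall("get_weather", json.dumps({"city": "Paris", "unit": "celsius"}))
text = LFM2ToolUtils.function_formatter([call])
# text == '<|tool_call_start|>[get_weather(city="Paris", unit="celsius")]<|tool_call_end|>'
parsed = LFM2ToolUtils.tool_extractor(text)
# parsed == [FunctionCall("get_weather", '{"city": "Paris", "unit": "celsius"}')]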
TOOLS = {
"default": DefaultToolUtils(),
"glm4": GLM4ToolUtils(),
"llama3": Llama3ToolUtils(),
"lfm2": LFM2ToolUtils(),
"minimax1": MiniMaxM1ToolUtils(),
"minimax2": MiniMaxM2ToolUtils(),
"mistral": MistralToolUtils(),

View File

@@ -57,6 +57,7 @@ LLAMABOARD_CONFIG = "llamaboard_config.yaml"
MCA_SUPPORTED_MODELS = {
"deepseek_v3",
"glm4_moe",
"llama",
"mistral",
"mixtral",
@@ -181,51 +182,6 @@ register_model_group(
)
register_model_group(
models={
"Baichuan-7B-Base": {
DownloadSource.DEFAULT: "baichuan-inc/Baichuan-7B",
DownloadSource.MODELSCOPE: "baichuan-inc/baichuan-7B",
},
"Baichuan-13B-Base": {
DownloadSource.DEFAULT: "baichuan-inc/Baichuan-13B-Base",
DownloadSource.MODELSCOPE: "baichuan-inc/Baichuan-13B-Base",
},
"Baichuan-13B-Chat": {
DownloadSource.DEFAULT: "baichuan-inc/Baichuan-13B-Chat",
DownloadSource.MODELSCOPE: "baichuan-inc/Baichuan-13B-Chat",
},
},
template="baichuan",
)
register_model_group(
models={
"Baichuan2-7B-Base": {
DownloadSource.DEFAULT: "baichuan-inc/Baichuan2-7B-Base",
DownloadSource.MODELSCOPE: "baichuan-inc/Baichuan2-7B-Base",
},
"Baichuan2-13B-Base": {
DownloadSource.DEFAULT: "baichuan-inc/Baichuan2-13B-Base",
DownloadSource.MODELSCOPE: "baichuan-inc/Baichuan2-13B-Base",
DownloadSource.OPENMIND: "Baichuan/Baichuan2_13b_base_pt",
},
"Baichuan2-7B-Chat": {
DownloadSource.DEFAULT: "baichuan-inc/Baichuan2-7B-Chat",
DownloadSource.MODELSCOPE: "baichuan-inc/Baichuan2-7B-Chat",
DownloadSource.OPENMIND: "Baichuan/Baichuan2_7b_chat_pt",
},
"Baichuan2-13B-Chat": {
DownloadSource.DEFAULT: "baichuan-inc/Baichuan2-13B-Chat",
DownloadSource.MODELSCOPE: "baichuan-inc/Baichuan2-13B-Chat",
DownloadSource.OPENMIND: "Baichuan/Baichuan2_13b_chat_pt",
},
},
template="baichuan2",
)
register_model_group(
models={
"BLOOM-560M": {
@@ -262,21 +218,6 @@ register_model_group(
)
register_model_group(
models={
"BlueLM-7B-Base": {
DownloadSource.DEFAULT: "vivo-ai/BlueLM-7B-Base",
DownloadSource.MODELSCOPE: "vivo-ai/BlueLM-7B-Base",
},
"BlueLM-7B-Chat": {
DownloadSource.DEFAULT: "vivo-ai/BlueLM-7B-Chat",
DownloadSource.MODELSCOPE: "vivo-ai/BlueLM-7B-Chat",
},
},
template="bluelm",
)
register_model_group(
models={
"Breeze-7B": {
@@ -290,17 +231,6 @@ register_model_group(
)
register_model_group(
models={
"ChatGLM2-6B-Chat": {
DownloadSource.DEFAULT: "zai-org/chatglm2-6b",
DownloadSource.MODELSCOPE: "ZhipuAI/chatglm2-6b",
}
},
template="chatglm2",
)
register_model_group(
models={
"ChatGLM3-6B-Base": {
@@ -347,17 +277,6 @@ register_model_group(
)
register_model_group(
models={
"CodeGeeX4-9B-Chat": {
DownloadSource.DEFAULT: "zai-org/codegeex4-all-9b",
DownloadSource.MODELSCOPE: "ZhipuAI/codegeex4-all-9b",
},
},
template="codegeex4",
)
register_model_group(
models={
"CodeGemma-7B": {
@@ -642,15 +561,15 @@ register_model_group(
register_model_group(
models={
"ERNIE-4.5-0.3B-PT": {
"ERNIE-4.5-0.3B-Instruct": {
DownloadSource.DEFAULT: "baidu/ERNIE-4.5-0.3B-PT",
DownloadSource.MODELSCOPE: "PaddlePaddle/ERNIE-4.5-0.3B-PT",
},
"ERNIE-4.5-21B-A3B-PT": {
"ERNIE-4.5-21B-A3B-Instruct": {
DownloadSource.DEFAULT: "baidu/ERNIE-4.5-21B-A3B-PT",
DownloadSource.MODELSCOPE: "PaddlePaddle/ERNIE-4.5-21B-A3B-PT",
},
"ERNIE-4.5-300B-A47B-PT": {
"ERNIE-4.5-300B-A47B-Instruct": {
DownloadSource.DEFAULT: "baidu/ERNIE-4.5-300B-A47B-PT",
DownloadSource.MODELSCOPE: "PaddlePaddle/ERNIE-4.5-300B-A47B-PT",
},
@@ -661,7 +580,7 @@ register_model_group(
register_model_group(
models={
"ERNIE-4.5-VL-28B-A3B-PT": {
"ERNIE-4.5-VL-28B-A3B-Instruct": {
DownloadSource.DEFAULT: "baidu/ERNIE-4.5-VL-28B-A3B-PT",
DownloadSource.MODELSCOPE: "PaddlePaddle/ERNIE-4.5-VL-28B-A3B-PT",
},
@@ -669,7 +588,7 @@ register_model_group(
DownloadSource.DEFAULT: "baidu/ERNIE-4.5-VL-28B-A3B-Thinking",
DownloadSource.MODELSCOPE: "PaddlePaddle/ERNIE-4.5-VL-28B-A3B-Thinking",
},
"ERNIE-4.5-VL-424B-A47B-Base-PT": {
"ERNIE-4.5-VL-424B-A47B-Instruct": {
DownloadSource.DEFAULT: "baidu/ERNIE-4.5-VL-424B-A47B-PT",
DownloadSource.MODELSCOPE: "PaddlePaddle/ERNIE-4.5-VL-424B-A47B-PT",
},
@@ -1226,19 +1145,50 @@ register_model_group(
register_model_group(
models={
"Hunyuan-0.5B-Instruct": {
DownloadSource.DEFAULT: "tencent/Hunyuan-0.5B-Instruct",
DownloadSource.MODELSCOPE: "Tencent-Hunyuan/Hunyuan-0.5B-Instruct",
},
"Hunyuan-1.8B-Instruct": {
DownloadSource.DEFAULT: "tencent/Hunyuan-1.8B-Instruct",
DownloadSource.MODELSCOPE: "Tencent-Hunyuan/Hunyuan-1.8B-Instruct",
},
"Hunyuan-4B-Instruct": {
DownloadSource.DEFAULT: "tencent/Hunyuan-4B-Instruct",
DownloadSource.MODELSCOPE: "Tencent-Hunyuan/Hunyuan-4B-Instruct",
},
"Hunyuan-7B-Instruct": { "Hunyuan-7B-Instruct": {
DownloadSource.DEFAULT: "tencent/Hunyuan-7B-Instruct", DownloadSource.DEFAULT: "tencent/Hunyuan-7B-Instruct",
DownloadSource.MODELSCOPE: "AI-ModelScope/Hunyuan-7B-Instruct", DownloadSource.MODELSCOPE: "Tencent-Hunyuan/Hunyuan-7B-Instruct",
}, },
"Hunyuan-MT-7B-Instruct": { "Hunyuan-MT-7B-Instruct": {
DownloadSource.DEFAULT: "tencent/Hunyuan-MT-7B", DownloadSource.DEFAULT: "tencent/Hunyuan-MT-7B",
DownloadSource.MODELSCOPE: "Tencent-Hunyuan/Hunyuan-MT-7B", DownloadSource.MODELSCOPE: "Tencent-Hunyuan/Hunyuan-MT-7B",
}, },
"HY-MT1.5-7B-Instruct": {
DownloadSource.DEFAULT: "tencent/HY-MT1.5-7B",
DownloadSource.MODELSCOPE: "Tencent-Hunyuan/HY-MT1.5-7B",
},
"Hunyuan-A13B-Instruct": {
DownloadSource.DEFAULT: "tencent/Hunyuan-A13B-Instruct",
DownloadSource.MODELSCOPE: "Tencent-Hunyuan/Hunyuan-A13B-Instruct",
},
},
template="hunyuan",
)
register_model_group(
models={
"HY-MT1.5-1.8B-Instruct": {
DownloadSource.DEFAULT: "tencent/HY-MT1.5-1.8B",
DownloadSource.MODELSCOPE: "Tencent-Hunyuan/HY-MT1.5-1.8B",
},
},
template="hunyuan_small",
)
register_model_group(
models={
"Index-1.9B-Base": {
@@ -1266,29 +1216,6 @@ register_model_group(
)
register_model_group(
models={
"InternLM-7B": {
DownloadSource.DEFAULT: "internlm/internlm-7b",
DownloadSource.MODELSCOPE: "Shanghai_AI_Laboratory/internlm-7b",
},
"InternLM-20B": {
DownloadSource.DEFAULT: "internlm/internlm-20b",
DownloadSource.MODELSCOPE: "Shanghai_AI_Laboratory/internlm-20b",
},
"InternLM-7B-Chat": {
DownloadSource.DEFAULT: "internlm/internlm-chat-7b",
DownloadSource.MODELSCOPE: "Shanghai_AI_Laboratory/internlm-chat-7b",
},
"InternLM-20B-Chat": {
DownloadSource.DEFAULT: "internlm/internlm-chat-20b",
DownloadSource.MODELSCOPE: "Shanghai_AI_Laboratory/internlm-chat-20b",
},
},
template="intern",
)
register_model_group(
models={
"InternLM2-7B": {
@@ -1485,11 +1412,25 @@ register_model_group(
register_model_group(
models={
"LingoWhale-8B": {
DownloadSource.DEFAULT: "deeplang-ai/LingoWhale-8B",
DownloadSource.MODELSCOPE: "DeepLang/LingoWhale-8B",
}
},
)
register_model_group(
models={
"LFM2.5-1.2B": {
DownloadSource.DEFAULT: "LiquidAI/LFM2.5-1.2B-Base",
},
"LFM2.5-1.2B-Instruct": {
DownloadSource.DEFAULT: "LiquidAI/LFM2.5-1.2B-Instruct",
},
},
template="lfm2",
)
register_model_group(
models={
"LFM2.5-VL-1.6B": {
DownloadSource.DEFAULT: "LiquidAI/LFM2.5-VL-1.6B",
},
},
template="lfm2_vl",
multimodal=True,
)
@@ -1804,17 +1745,6 @@ register_model_group(
)
register_model_group(
models={
"Marco-o1-Chat": {
DownloadSource.DEFAULT: "AIDC-AI/Marco-o1",
DownloadSource.MODELSCOPE: "AIDC-AI/Marco-o1",
},
},
template="marco",
)
register_model_group(
models={
"MiMo-7B-Base": {
@@ -1885,33 +1815,6 @@ register_model_group(
)
register_model_group(
models={
"MiniCPM-2B-SFT-Chat": {
DownloadSource.DEFAULT: "openbmb/MiniCPM-2B-sft-bf16",
DownloadSource.MODELSCOPE: "OpenBMB/miniCPM-bf16",
},
"MiniCPM-2B-DPO-Chat": {
DownloadSource.DEFAULT: "openbmb/MiniCPM-2B-dpo-bf16",
DownloadSource.MODELSCOPE: "OpenBMB/MiniCPM-2B-dpo-bf16",
},
},
template="cpm",
)
register_model_group(
models={
"MiniCPM3-4B-Chat": {
DownloadSource.DEFAULT: "openbmb/MiniCPM3-4B",
DownloadSource.MODELSCOPE: "OpenBMB/MiniCPM3-4B",
DownloadSource.OPENMIND: "LlamaFactory/MiniCPM3-4B",
},
},
template="cpm3",
)
register_model_group(
models={
"MiniCPM4-0.5B-Chat": {
@@ -1949,26 +1852,10 @@ register_model_group(
DownloadSource.DEFAULT: "openbmb/MiniCPM-V-2_6", DownloadSource.DEFAULT: "openbmb/MiniCPM-V-2_6",
DownloadSource.MODELSCOPE: "OpenBMB/MiniCPM-V-2_6", DownloadSource.MODELSCOPE: "OpenBMB/MiniCPM-V-2_6",
}, },
},
template="minicpm_v",
multimodal=True,
)
register_model_group(
models={
"MiniCPM-V-4": { "MiniCPM-V-4": {
DownloadSource.DEFAULT: "openbmb/MiniCPM-V-4", DownloadSource.DEFAULT: "openbmb/MiniCPM-V-4",
DownloadSource.MODELSCOPE: "OpenBMB/MiniCPM-V-4", DownloadSource.MODELSCOPE: "OpenBMB/MiniCPM-V-4",
}, },
},
template="minicpm_v",
multimodal=True,
)
register_model_group(
models={
"MiniCPM-V-4.5": { "MiniCPM-V-4.5": {
DownloadSource.DEFAULT: "openbmb/MiniCPM-V-4_5", DownloadSource.DEFAULT: "openbmb/MiniCPM-V-4_5",
DownloadSource.MODELSCOPE: "OpenBMB/MiniCPM-V-4_5", DownloadSource.MODELSCOPE: "OpenBMB/MiniCPM-V-4_5",
@@ -2226,33 +2113,6 @@ register_model_group(
)
register_model_group(
models={
"Orion-14B-Base": {
DownloadSource.DEFAULT: "OrionStarAI/Orion-14B-Base",
DownloadSource.MODELSCOPE: "OrionStarAI/Orion-14B-Base",
},
"Orion-14B-Chat": {
DownloadSource.DEFAULT: "OrionStarAI/Orion-14B-Chat",
DownloadSource.MODELSCOPE: "OrionStarAI/Orion-14B-Chat",
},
"Orion-14B-Long-Chat": {
DownloadSource.DEFAULT: "OrionStarAI/Orion-14B-LongChat",
DownloadSource.MODELSCOPE: "OrionStarAI/Orion-14B-LongChat",
},
"Orion-14B-RAG-Chat": {
DownloadSource.DEFAULT: "OrionStarAI/Orion-14B-Chat-RAG",
DownloadSource.MODELSCOPE: "OrionStarAI/Orion-14B-Chat-RAG",
},
"Orion-14B-Plugin-Chat": {
DownloadSource.DEFAULT: "OrionStarAI/Orion-14B-Chat-Plugin",
DownloadSource.MODELSCOPE: "OrionStarAI/Orion-14B-Chat-Plugin",
},
},
template="orion",
)
register_model_group(
models={
"PaliGemma-3B-pt-224": {
@@ -2349,20 +2209,6 @@ register_model_group(
)
register_model_group(
models={
"Phi-1.5-1.3B": {
DownloadSource.DEFAULT: "microsoft/phi-1_5",
DownloadSource.MODELSCOPE: "allspace/PHI_1-5",
},
"Phi-2-2.7B": {
DownloadSource.DEFAULT: "microsoft/phi-2",
DownloadSource.MODELSCOPE: "AI-ModelScope/phi-2",
},
}
)
register_model_group(
models={
"Phi-3-4B-4k-Instruct": {
@@ -2419,6 +2265,15 @@ register_model_group(
template="phi4", template="phi4",
) )
register_model_group(
models={
"Phi-4-3.8B-instruct": {
DownloadSource.DEFAULT: "microsoft/Phi-4-mini-instruct",
DownloadSource.MODELSCOPE: "LLM-Research/Phi-4-mini-instruct",
},
},
template="phi4_mini",
)
register_model_group(
models={
@@ -2432,228 +2287,6 @@ register_model_group(
)
register_model_group(
models={
"Qwen-1.8B": {
DownloadSource.DEFAULT: "Qwen/Qwen-1_8B",
DownloadSource.MODELSCOPE: "Qwen/Qwen-1_8B",
},
"Qwen-7B": {
DownloadSource.DEFAULT: "Qwen/Qwen-7B",
DownloadSource.MODELSCOPE: "Qwen/Qwen-7B",
},
"Qwen-14B": {
DownloadSource.DEFAULT: "Qwen/Qwen-14B",
DownloadSource.MODELSCOPE: "Qwen/Qwen-14B",
},
"Qwen-72B": {
DownloadSource.DEFAULT: "Qwen/Qwen-72B",
DownloadSource.MODELSCOPE: "Qwen/Qwen-72B",
},
"Qwen-1.8B-Chat": {
DownloadSource.DEFAULT: "Qwen/Qwen-1_8B-Chat",
DownloadSource.MODELSCOPE: "Qwen/Qwen-1_8B-Chat",
},
"Qwen-7B-Chat": {
DownloadSource.DEFAULT: "Qwen/Qwen-7B-Chat",
DownloadSource.MODELSCOPE: "Qwen/Qwen-7B-Chat",
},
"Qwen-14B-Chat": {
DownloadSource.DEFAULT: "Qwen/Qwen-14B-Chat",
DownloadSource.MODELSCOPE: "Qwen/Qwen-14B-Chat",
},
"Qwen-72B-Chat": {
DownloadSource.DEFAULT: "Qwen/Qwen-72B-Chat",
DownloadSource.MODELSCOPE: "Qwen/Qwen-72B-Chat",
},
"Qwen-1.8B-Chat-Int8": {
DownloadSource.DEFAULT: "Qwen/Qwen-1_8B-Chat-Int8",
DownloadSource.MODELSCOPE: "Qwen/Qwen-1_8B-Chat-Int8",
},
"Qwen-1.8B-Chat-Int4": {
DownloadSource.DEFAULT: "Qwen/Qwen-1_8B-Chat-Int4",
DownloadSource.MODELSCOPE: "Qwen/Qwen-1_8B-Chat-Int4",
},
"Qwen-7B-Chat-Int8": {
DownloadSource.DEFAULT: "Qwen/Qwen-7B-Chat-Int8",
DownloadSource.MODELSCOPE: "Qwen/Qwen-7B-Chat-Int8",
},
"Qwen-7B-Chat-Int4": {
DownloadSource.DEFAULT: "Qwen/Qwen-7B-Chat-Int4",
DownloadSource.MODELSCOPE: "Qwen/Qwen-7B-Chat-Int4",
},
"Qwen-14B-Chat-Int8": {
DownloadSource.DEFAULT: "Qwen/Qwen-14B-Chat-Int8",
DownloadSource.MODELSCOPE: "Qwen/Qwen-14B-Chat-Int8",
},
"Qwen-14B-Chat-Int4": {
DownloadSource.DEFAULT: "Qwen/Qwen-14B-Chat-Int4",
DownloadSource.MODELSCOPE: "Qwen/Qwen-14B-Chat-Int4",
},
"Qwen-72B-Chat-Int8": {
DownloadSource.DEFAULT: "Qwen/Qwen-72B-Chat-Int8",
DownloadSource.MODELSCOPE: "Qwen/Qwen-72B-Chat-Int8",
},
"Qwen-72B-Chat-Int4": {
DownloadSource.DEFAULT: "Qwen/Qwen-72B-Chat-Int4",
DownloadSource.MODELSCOPE: "Qwen/Qwen-72B-Chat-Int4",
},
},
template="qwen",
)
register_model_group(
models={
"Qwen1.5-0.5B": {
DownloadSource.DEFAULT: "Qwen/Qwen1.5-0.5B",
DownloadSource.MODELSCOPE: "Qwen/Qwen1.5-0.5B",
},
"Qwen1.5-1.8B": {
DownloadSource.DEFAULT: "Qwen/Qwen1.5-1.8B",
DownloadSource.MODELSCOPE: "Qwen/Qwen1.5-1.8B",
},
"Qwen1.5-4B": {
DownloadSource.DEFAULT: "Qwen/Qwen1.5-4B",
DownloadSource.MODELSCOPE: "Qwen/Qwen1.5-4B",
},
"Qwen1.5-7B": {
DownloadSource.DEFAULT: "Qwen/Qwen1.5-7B",
DownloadSource.MODELSCOPE: "Qwen/Qwen1.5-7B",
},
"Qwen1.5-14B": {
DownloadSource.DEFAULT: "Qwen/Qwen1.5-14B",
DownloadSource.MODELSCOPE: "Qwen/Qwen1.5-14B",
},
"Qwen1.5-32B": {
DownloadSource.DEFAULT: "Qwen/Qwen1.5-32B",
DownloadSource.MODELSCOPE: "Qwen/Qwen1.5-32B",
},
"Qwen1.5-72B": {
DownloadSource.DEFAULT: "Qwen/Qwen1.5-72B",
DownloadSource.MODELSCOPE: "Qwen/Qwen1.5-72B",
},
"Qwen1.5-110B": {
DownloadSource.DEFAULT: "Qwen/Qwen1.5-110B",
DownloadSource.MODELSCOPE: "Qwen/Qwen1.5-110B",
},
"Qwen1.5-MoE-A2.7B": {
DownloadSource.DEFAULT: "Qwen/Qwen1.5-MoE-A2.7B",
DownloadSource.MODELSCOPE: "Qwen/Qwen1.5-MoE-A2.7B",
},
"Qwen1.5-0.5B-Chat": {
DownloadSource.DEFAULT: "Qwen/Qwen1.5-0.5B-Chat",
DownloadSource.MODELSCOPE: "Qwen/Qwen1.5-0.5B-Chat",
},
"Qwen1.5-1.8B-Chat": {
DownloadSource.DEFAULT: "Qwen/Qwen1.5-1.8B-Chat",
DownloadSource.MODELSCOPE: "Qwen/Qwen1.5-1.8B-Chat",
},
"Qwen1.5-4B-Chat": {
DownloadSource.DEFAULT: "Qwen/Qwen1.5-4B-Chat",
DownloadSource.MODELSCOPE: "Qwen/Qwen1.5-4B-Chat",
},
"Qwen1.5-7B-Chat": {
DownloadSource.DEFAULT: "Qwen/Qwen1.5-7B-Chat",
DownloadSource.MODELSCOPE: "Qwen/Qwen1.5-7B-Chat",
},
"Qwen1.5-14B-Chat": {
DownloadSource.DEFAULT: "Qwen/Qwen1.5-14B-Chat",
DownloadSource.MODELSCOPE: "Qwen/Qwen1.5-14B-Chat",
},
"Qwen1.5-32B-Chat": {
DownloadSource.DEFAULT: "Qwen/Qwen1.5-32B-Chat",
DownloadSource.MODELSCOPE: "Qwen/Qwen1.5-32B-Chat",
},
"Qwen1.5-72B-Chat": {
DownloadSource.DEFAULT: "Qwen/Qwen1.5-72B-Chat",
DownloadSource.MODELSCOPE: "Qwen/Qwen1.5-72B-Chat",
},
"Qwen1.5-110B-Chat": {
DownloadSource.DEFAULT: "Qwen/Qwen1.5-110B-Chat",
DownloadSource.MODELSCOPE: "Qwen/Qwen1.5-110B-Chat",
},
"Qwen1.5-MoE-A2.7B-Chat": {
DownloadSource.DEFAULT: "Qwen/Qwen1.5-MoE-A2.7B-Chat",
DownloadSource.MODELSCOPE: "Qwen/Qwen1.5-MoE-A2.7B-Chat",
},
"Qwen1.5-0.5B-Chat-GPTQ-Int8": {
DownloadSource.DEFAULT: "Qwen/Qwen1.5-0.5B-Chat-GPTQ-Int8",
DownloadSource.MODELSCOPE: "Qwen/Qwen1.5-0.5B-Chat-GPTQ-Int8",
},
"Qwen1.5-0.5B-Chat-AWQ": {
DownloadSource.DEFAULT: "Qwen/Qwen1.5-0.5B-Chat-AWQ",
DownloadSource.MODELSCOPE: "Qwen/Qwen1.5-0.5B-Chat-AWQ",
},
"Qwen1.5-1.8B-Chat-GPTQ-Int8": {
DownloadSource.DEFAULT: "Qwen/Qwen1.5-1.8B-Chat-GPTQ-Int8",
DownloadSource.MODELSCOPE: "Qwen/Qwen1.5-1.8B-Chat-GPTQ-Int8",
},
"Qwen1.5-1.8B-Chat-AWQ": {
DownloadSource.DEFAULT: "Qwen/Qwen1.5-1.8B-Chat-AWQ",
DownloadSource.MODELSCOPE: "Qwen/Qwen1.5-1.8B-Chat-AWQ",
},
"Qwen1.5-4B-Chat-GPTQ-Int8": {
DownloadSource.DEFAULT: "Qwen/Qwen1.5-4B-Chat-GPTQ-Int8",
DownloadSource.MODELSCOPE: "Qwen/Qwen1.5-4B-Chat-GPTQ-Int8",
},
"Qwen1.5-4B-Chat-AWQ": {
DownloadSource.DEFAULT: "Qwen/Qwen1.5-4B-Chat-AWQ",
DownloadSource.MODELSCOPE: "Qwen/Qwen1.5-4B-Chat-AWQ",
},
"Qwen1.5-7B-Chat-GPTQ-Int8": {
DownloadSource.DEFAULT: "Qwen/Qwen1.5-7B-Chat-GPTQ-Int8",
DownloadSource.MODELSCOPE: "Qwen/Qwen1.5-7B-Chat-GPTQ-Int8",
},
"Qwen1.5-7B-Chat-AWQ": {
DownloadSource.DEFAULT: "Qwen/Qwen1.5-7B-Chat-AWQ",
DownloadSource.MODELSCOPE: "Qwen/Qwen1.5-7B-Chat-AWQ",
},
"Qwen1.5-14B-Chat-GPTQ-Int8": {
DownloadSource.DEFAULT: "Qwen/Qwen1.5-14B-Chat-GPTQ-Int8",
DownloadSource.MODELSCOPE: "Qwen/Qwen1.5-14B-Chat-GPTQ-Int8",
},
"Qwen1.5-14B-Chat-AWQ": {
DownloadSource.DEFAULT: "Qwen/Qwen1.5-14B-Chat-AWQ",
DownloadSource.MODELSCOPE: "Qwen/Qwen1.5-14B-Chat-AWQ",
},
"Qwen1.5-32B-Chat-AWQ": {
DownloadSource.DEFAULT: "Qwen/Qwen1.5-32B-Chat-AWQ",
DownloadSource.MODELSCOPE: "Qwen/Qwen1.5-32B-Chat-AWQ",
},
"Qwen1.5-72B-Chat-GPTQ-Int8": {
DownloadSource.DEFAULT: "Qwen/Qwen1.5-72B-Chat-GPTQ-Int8",
DownloadSource.MODELSCOPE: "Qwen/Qwen1.5-72B-Chat-GPTQ-Int8",
},
"Qwen1.5-72B-Chat-AWQ": {
DownloadSource.DEFAULT: "Qwen/Qwen1.5-72B-Chat-AWQ",
DownloadSource.MODELSCOPE: "Qwen/Qwen1.5-72B-Chat-AWQ",
},
"Qwen1.5-110B-Chat-AWQ": {
DownloadSource.DEFAULT: "Qwen/Qwen1.5-110B-Chat-AWQ",
DownloadSource.MODELSCOPE: "Qwen/Qwen1.5-110B-Chat-AWQ",
},
"Qwen1.5-MoE-A2.7B-Chat-GPTQ-Int4": {
DownloadSource.DEFAULT: "Qwen/Qwen1.5-MoE-A2.7B-Chat-GPTQ-Int4",
DownloadSource.MODELSCOPE: "Qwen/Qwen1.5-MoE-A2.7B-Chat-GPTQ-Int4",
},
"CodeQwen1.5-7B": {
DownloadSource.DEFAULT: "Qwen/CodeQwen1.5-7B",
DownloadSource.MODELSCOPE: "Qwen/CodeQwen1.5-7B",
},
"CodeQwen1.5-7B-Chat": {
DownloadSource.DEFAULT: "Qwen/CodeQwen1.5-7B-Chat",
DownloadSource.MODELSCOPE: "Qwen/CodeQwen1.5-7B-Chat",
},
"CodeQwen1.5-7B-Chat-AWQ": {
DownloadSource.DEFAULT: "Qwen/CodeQwen1.5-7B-Chat-AWQ",
DownloadSource.MODELSCOPE: "Qwen/CodeQwen1.5-7B-Chat-AWQ",
},
},
template="qwen",
)
register_model_group(
models={
"Qwen2-0.5B": {
@@ -3421,27 +3054,6 @@ register_model_group(
)
register_model_group(
models={
"Skywork-13B-Base": {
DownloadSource.DEFAULT: "Skywork/Skywork-13B-base",
DownloadSource.MODELSCOPE: "skywork/Skywork-13B-base",
}
}
)
register_model_group(
models={
"Skywork-o1-Open-Llama-3.1-8B": {
DownloadSource.DEFAULT: "Skywork/Skywork-o1-Open-Llama-3.1-8B",
DownloadSource.MODELSCOPE: "AI-ModelScope/Skywork-o1-Open-Llama-3.1-8B",
}
},
template="skywork_o1",
)
register_model_group(
models={
"SmolLM-135M": {
@@ -3536,30 +3148,6 @@ register_model_group(
)
register_model_group(
models={
"TeleChat-1B-Chat": {
DownloadSource.DEFAULT: "Tele-AI/TeleChat-1B",
DownloadSource.MODELSCOPE: "TeleAI/TeleChat-1B",
},
"TeleChat-7B-Chat": {
DownloadSource.DEFAULT: "Tele-AI/telechat-7B",
DownloadSource.MODELSCOPE: "TeleAI/telechat-7B",
DownloadSource.OPENMIND: "TeleAI/TeleChat-7B-pt",
},
"TeleChat-12B-Chat": {
DownloadSource.DEFAULT: "Tele-AI/TeleChat-12B-v2",
DownloadSource.MODELSCOPE: "TeleAI/TeleChat-12B-v2",
DownloadSource.OPENMIND: "TeleAI/TeleChat-12B-pt",
},
"TeleChat-52B-Chat": {
DownloadSource.DEFAULT: "Tele-AI/TeleChat-52B",
},
},
template="telechat",
)
register_model_group(
models={
"TeleChat2-3B-Chat": {
@@ -3674,80 +3262,6 @@ register_model_group(
)
register_model_group(
models={
"XVERSE-7B": {
DownloadSource.DEFAULT: "xverse/XVERSE-7B",
DownloadSource.MODELSCOPE: "xverse/XVERSE-7B",
},
"XVERSE-13B": {
DownloadSource.DEFAULT: "xverse/XVERSE-13B",
DownloadSource.MODELSCOPE: "xverse/XVERSE-13B",
},
"XVERSE-65B": {
DownloadSource.DEFAULT: "xverse/XVERSE-65B",
DownloadSource.MODELSCOPE: "xverse/XVERSE-65B",
},
"XVERSE-65B-2": {
DownloadSource.DEFAULT: "xverse/XVERSE-65B-2",
DownloadSource.MODELSCOPE: "xverse/XVERSE-65B-2",
},
"XVERSE-7B-Chat": {
DownloadSource.DEFAULT: "xverse/XVERSE-7B-Chat",
DownloadSource.MODELSCOPE: "xverse/XVERSE-7B-Chat",
},
"XVERSE-13B-Chat": {
DownloadSource.DEFAULT: "xverse/XVERSE-13B-Chat",
DownloadSource.MODELSCOPE: "xverse/XVERSE-13B-Chat",
},
"XVERSE-65B-Chat": {
DownloadSource.DEFAULT: "xverse/XVERSE-65B-Chat",
DownloadSource.MODELSCOPE: "xverse/XVERSE-65B-Chat",
},
"XVERSE-MoE-A4.2B": {
DownloadSource.DEFAULT: "xverse/XVERSE-MoE-A4.2B",
DownloadSource.MODELSCOPE: "xverse/XVERSE-MoE-A4.2B",
},
"XVERSE-7B-Chat-GPTQ-Int8": {
DownloadSource.DEFAULT: "xverse/XVERSE-7B-Chat-GPTQ-Int8",
DownloadSource.MODELSCOPE: "xverse/XVERSE-7B-Chat-GPTQ-Int8",
},
"XVERSE-7B-Chat-GPTQ-Int4": {
DownloadSource.DEFAULT: "xverse/XVERSE-7B-Chat-GPTQ-Int4",
DownloadSource.MODELSCOPE: "xverse/XVERSE-7B-Chat-GPTQ-Int4",
},
"XVERSE-13B-Chat-GPTQ-Int8": {
DownloadSource.DEFAULT: "xverse/XVERSE-13B-Chat-GPTQ-Int8",
DownloadSource.MODELSCOPE: "xverse/XVERSE-13B-Chat-GPTQ-Int8",
},
"XVERSE-13B-Chat-GPTQ-Int4": {
DownloadSource.DEFAULT: "xverse/XVERSE-13B-Chat-GPTQ-Int4",
DownloadSource.MODELSCOPE: "xverse/XVERSE-13B-Chat-GPTQ-Int4",
},
"XVERSE-65B-Chat-GPTQ-Int4": {
DownloadSource.DEFAULT: "xverse/XVERSE-65B-Chat-GPTQ-Int4",
DownloadSource.MODELSCOPE: "xverse/XVERSE-65B-Chat-GPTQ-Int4",
},
},
template="xverse",
)
register_model_group(
models={
"Yayi-7B": {
DownloadSource.DEFAULT: "wenge-research/yayi-7b-llama2",
DownloadSource.MODELSCOPE: "AI-ModelScope/yayi-7b-llama2",
},
"Yayi-13B": {
DownloadSource.DEFAULT: "wenge-research/yayi-13b-llama2",
DownloadSource.MODELSCOPE: "AI-ModelScope/yayi-13b-llama2",
},
},
template="yayi",
)
register_model_group(
models={
"Yi-6B": {
@@ -3846,6 +3360,21 @@ register_model_group(
)
register_model_group(
models={
"Youtu-LLM-2B-Instruct": {
DownloadSource.DEFAULT: "tencent/Youtu-LLM-2B",
DownloadSource.MODELSCOPE: "Tencent-YouTu-Research/Youtu-LLM-2B",
},
"Youtu-LLM-2B-Base": {
DownloadSource.DEFAULT: "tencent/Youtu-LLM-2B-Base",
DownloadSource.MODELSCOPE: "Tencent-YouTu-Research/Youtu-LLM-2B-Base",
},
},
template="youtu",
)
register_model_group(
models={
"Yuan2-2B-Chat": {

View File

@@ -19,7 +19,7 @@
from collections import OrderedDict
VERSION = "0.9.4"
VERSION = "0.9.5.dev0"
def print_env() -> None:

View File

@@ -157,6 +157,33 @@ def get_current_device() -> "torch.device":
return torch.device(device)
def get_device_name() -> str:
r"""Get the name of available devices."""
if is_torch_xpu_available():
device = "xpu"
elif is_torch_npu_available():
device = "npu"
elif is_torch_mps_available():
device = "mps"
elif is_torch_cuda_available():
device = "gpu"
else:
device = "cpu"
return device
def get_torch_device():
r"""Get the torch device namespace for the available devices."""
device_name = get_device_name()
device_name = "cuda" if device_name == "gpu" else device_name
try:
return getattr(torch, device_name)
except AttributeError:
logger.warning_rank0(f"Device namespace '{device_name}' not found in torch, try to load torch.cuda.")
return torch.cuda
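As an example of why the helper is useful, a device-agnostic cache clear could be written as follows (a sketch; it assumes the resolved torch namespace exposes `empty_cache`, as cuda/xpu/npu/mps do on recent builds):

if get_device_name() != "cpu":
    get_torch_device().empty_cache()  # resolves to torch.cuda / torch.xpu / torch.npu / torch.mps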
def get_device_count() -> int:
r"""Get the number of available devices."""
if is_torch_xpu_available():

View File

@@ -63,9 +63,9 @@ class DataArguments:
default=16384,
metadata={"help": "Size of the buffer to randomly sample examples from in dataset streaming."},
)
mix_strategy: Literal["concat", "interleave_under", "interleave_over"] = field(
mix_strategy: Literal["concat", "interleave_under", "interleave_over", "interleave_once"] = field(
default="concat",
metadata={"help": "Strategy to use in dataset mixing (concat/interleave) (undersampling/oversampling)."},
metadata={"help": "Strategy to use in dataset mixing (concat/interleave) (undersampling/oversampling/sampling w.o. replacement)."},
)
interleave_probs: str | None = field(
default=None,

View File

@@ -490,6 +490,14 @@ class FinetuningArguments(
default=False,
metadata={"help": "Whether to use the DFT loss."},
)
use_eaft_loss: bool = field(
default=False,
metadata={"help": "Whether to use the EAFT loss."},
)
eaft_alpha: float = field(
default=1.0,
metadata={"help": "The alpha parameter for EAFT loss to control the power of adaptive weight."},
)
freeze_vision_tower: bool = field(
default=True,
metadata={"help": "Whether or not to freeze the vision tower in MLLM training."},

View File

@@ -298,23 +298,6 @@ class QuantizationArguments:
default=None,
metadata={"help": "Device map used to infer the 4-bit quantized model, needs bitsandbytes>=0.43.0."},
)
fp8: bool = field(
default=False,
metadata={
"help": "Enable FP8 mixed precision training via HuggingFace Accelerate. "
"Requires PyTorch 2.7+ and Hopper architecture GPUs."
},
)
fp8_backend: str = field(
default="auto",
metadata={
"help": "FP8 backend to use ('auto', 'torchao', 'te', 'msamp'). 'auto' selects best available backend."
},
)
fp8_enable_fsdp_float8_all_gather: bool = field(
default=False,
metadata={"help": "Enable FP8 optimizations for FSDP2 all-gather operations."},
)
@dataclass

View File

@@ -15,6 +15,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import os
import sys
from pathlib import Path
@@ -70,13 +71,13 @@ def read_args(args: dict[str, Any] | list[str] | None = None) -> dict[str, Any]
if args is not None:
return args
if sys.argv[1].endswith(".yaml") or sys.argv[1].endswith(".yml"):
if len(sys.argv) > 1 and (sys.argv[1].endswith(".yaml") or sys.argv[1].endswith(".yml")):
override_config = OmegaConf.from_cli(sys.argv[2:])
dict_config = OmegaConf.load(Path(sys.argv[1]).absolute())
return OmegaConf.to_container(OmegaConf.merge(dict_config, override_config))
elif sys.argv[1].endswith(".json"):
elif len(sys.argv) > 1 and sys.argv[1].endswith(".json"):
override_config = OmegaConf.from_cli(sys.argv[2:])
dict_config = OmegaConf.load(Path(sys.argv[1]).absolute())
dict_config = OmegaConf.create(json.load(Path(sys.argv[1]).absolute()))
return OmegaConf.to_container(OmegaConf.merge(dict_config, override_config))
else:
return sys.argv[1:]
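A sketch of the merge behaviour implemented above, using a hypothetical config file and override (the CLI entry point forwards its arguments roughly as the code above shows):

from omegaconf import OmegaConf

# e.g. `llamafactory-cli train sft.yaml learning_rate=5e-6`
dict_config = OmegaConf.load("sft.yaml")                      # hypothetical YAML file
override_config = OmegaConf.from_cli(["learning_rate=5e-6"])  # CLI value wins on conflict
merged = OmegaConf.to_container(OmegaConf.merge(dict_config, override_config))
# merged["learning_rate"] == 5e-6; every other key keeps its YAML value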
@@ -142,14 +143,6 @@ def _verify_model_args(
logger.warning_rank0("We should use slow tokenizer for the Yi models. Change `use_fast_tokenizer` to False.") logger.warning_rank0("We should use slow tokenizer for the Yi models. Change `use_fast_tokenizer` to False.")
model_args.use_fast_tokenizer = False model_args.use_fast_tokenizer = False
# Validate advanced training features
if model_args.fp8 and model_args.quantization_bit is not None:
raise ValueError("FP8 training is not compatible with quantization. Please disable one of them.")
if model_args.fp8_enable_fsdp_float8_all_gather and not model_args.fp8:
logger.warning_rank0("fp8_enable_fsdp_float8_all_gather requires fp8=True. Setting fp8=True.")
model_args.fp8 = True
def _check_extra_dependencies(
model_args: "ModelArguments",
@@ -347,6 +340,9 @@ def get_train_args(args: dict[str, Any] | list[str] | None = None) -> _TRAIN_CLS
if training_args.deepspeed is not None and (finetuning_args.use_galore or finetuning_args.use_apollo):
raise ValueError("GaLore and APOLLO are incompatible with DeepSpeed yet.")
if not finetuning_args.use_mca and training_args.fp8 and model_args.quantization_bit is not None:
raise ValueError("FP8 training is not compatible with quantization. Please disable one of them.")
if model_args.infer_backend != EngineName.HF:
raise ValueError("vLLM/SGLang backend is only available for API, CLI and Web.")
@@ -363,6 +359,10 @@ def get_train_args(args: dict[str, Any] | list[str] | None = None) -> _TRAIN_CLS
_verify_model_args(model_args, data_args, finetuning_args)
_check_extra_dependencies(model_args, finetuning_args, training_args)
if not finetuning_args.use_mca and training_args.fp8_enable_fsdp_float8_all_gather and not training_args.fp8:
logger.warning_rank0("fp8_enable_fsdp_float8_all_gather requires fp8=True. Setting fp8=True.")
model_args.fp8 = True
if (
training_args.do_train
and finetuning_args.finetuning_type == "lora"

View File

@@ -14,7 +14,6 @@
import json
from dataclasses import dataclass, field
from typing import Literal
from transformers import Seq2SeqTrainingArguments
from transformers.training_args import _convert_str_dict
@@ -40,59 +39,55 @@ else:
class RayArguments:
r"""Arguments pertaining to the Ray training."""
ray_run_name: str | None = field(
default=None,
metadata={"help": "The training results will be saved at `<ray_storage_path>/ray_run_name`."},
)
ray_storage_path: str = field(
default="./saves",
metadata={"help": "The storage path to save training results to"},
)
ray_storage_filesystem: Literal["s3", "gs", "gcs"] | None = field(
default=None,
metadata={"help": "The storage filesystem to use. If None specified, local filesystem will be used."},
)
ray_num_workers: int = field(
default=1,
metadata={"help": "The number of workers for Ray training. Default is 1 worker."},
)
resources_per_worker: dict | str = field(
default_factory=lambda: {"GPU": 1},
metadata={"help": "The resources per worker for Ray training. Default is to use 1 GPU per worker."},
)
placement_strategy: Literal["SPREAD", "PACK", "STRICT_SPREAD", "STRICT_PACK"] = field(
default="PACK",
metadata={"help": "The placement strategy for Ray training. Default is PACK."},
)
ray_init_kwargs: dict | str | None = field(
default=None,
metadata={"help": "The arguments to pass to ray.init for Ray training. Default is None."},
)
master_addr: str | None = field(
default=None,
metadata={"help": "The master address for init_process_group"},
)
master_port: str | None = field(
default=None,
metadata={"help": "The master port for init_process_group"},
)
def __post_init__(self):
self.use_ray = use_ray()
if isinstance(self.resources_per_worker, str) and self.resources_per_worker.startswith("{"):
self.resources_per_worker = _convert_str_dict(json.loads(self.resources_per_worker))
if isinstance(self.ray_init_kwargs, str) and self.ray_init_kwargs.startswith("{"):
self.ray_init_kwargs = _convert_str_dict(json.loads(self.ray_init_kwargs))
if self.ray_storage_filesystem is not None:
if self.ray_storage_filesystem not in ["s3", "gs", "gcs"]:
raise ValueError(
f"ray_storage_filesystem must be one of ['s3', 'gs', 'gcs'], got {self.ray_storage_filesystem}."
)
import pyarrow.fs as fs
if self.ray_storage_filesystem == "s3":
self.ray_storage_filesystem = fs.S3FileSystem()
elif self.ray_storage_filesystem == "gs" or self.ray_storage_filesystem == "gcs":
self.ray_storage_filesystem = fs.GcsFileSystem()
@dataclass
class Fp8Arguments:
r"""Arguments pertaining to the FP8 training."""
fp8: bool = field(
default=False,
metadata={
"help": "Enable FP8 mixed precision training via HuggingFace Accelerate. "
"Requires PyTorch 2.7+ and Hopper architecture GPUs."
},
)
fp8_backend: str = field(
default="auto",
metadata={
"help": "FP8 backend to use ('auto', 'torchao', 'te', 'msamp'). 'auto' selects best available backend."
},
)
fp8_enable_fsdp_float8_all_gather: bool = field(
default=False,
metadata={"help": "Enable FP8 optimizations for FSDP2 all-gather operations."},
)
@dataclass
class TrainingArguments(RayArguments, BaseTrainingArguments):
class TrainingArguments(Fp8Arguments, RayArguments, BaseTrainingArguments):
r"""Arguments pertaining to the trainer."""
overwrite_output_dir: bool = field(

View File

@@ -15,6 +15,7 @@
import os
from typing import TYPE_CHECKING, Any, Optional, TypedDict
import torch
from transformers import (
AutoConfig,
AutoModelForCausalLM,
@@ -29,6 +30,7 @@ from trl import AutoModelForCausalLMWithValueHead
from ..extras import logging
from ..extras.misc import count_parameters, skip_check_imports, try_download_model_from_other_hub
from ..extras.packages import is_torch_version_greater_than
from .adapter import init_adapter
from .model_utils.ktransformers import load_kt_pretrained_model
from .model_utils.liger_kernel import apply_liger_kernel
@@ -203,6 +205,16 @@ def load_model(
model.load_state_dict(vhead_params, strict=False)
logger.info_rank0(f"Loaded valuehead from checkpoint: {vhead_path}")
# Conv3D is not recommended when using torch 2.9.x
if is_torch_version_greater_than("2.9.0") and not is_torch_version_greater_than("2.10.0"):
if any(isinstance(m, torch.nn.Conv3d) for m in model.modules()):
raise ValueError(
"Unsupported torch version detected: torch 2.9.x with Conv3D. "
"This combination is known to cause severe performance regression. "
"Please downgrade torch to <2.9 or remove Conv3D. "
"See https://github.com/pytorch/pytorch/issues/166122"
)
if not is_trainable:
model.requires_grad_(False)
model.eval()
@@ -218,7 +230,7 @@ def load_model(
)
from ..v1.plugins.model_plugins.kernels.interface import apply_default_kernels
model = apply_default_kernels(model=model, include_kernels=model_args.use_v1_kernels)
model = apply_default_kernels(model, include_kernels=model_args.use_v1_kernels)
trainable_params, all_param = count_parameters(model)
if is_trainable:

View File

@@ -356,7 +356,7 @@ _register_composite_model(
_register_composite_model(
model_type="qwen3_vl",
projector_key="visual.merger",
vision_model_keys=["visual.patch_embed", "visual.blocks", "visual.deepstack_merger_list"],
vision_model_keys=["visual.pos_embed", "visual.patch_embed", "visual.blocks", "visual.deepstack_merger_list"],
language_model_keys=["language_model", "lm_head"],
lora_conflict_keys=["patch_embed"],
)
@@ -365,7 +365,7 @@ _register_composite_model(
_register_composite_model(
model_type="qwen3_vl_moe",
projector_key="visual.merger",
vision_model_keys=["visual.patch_embed", "visual.blocks", "visual.deepstack_merger_list"],
vision_model_keys=["visual.pos_embed", "visual.patch_embed", "visual.blocks", "visual.deepstack_merger_list"],
language_model_keys=["language_model", "lm_head"],
lora_conflict_keys=["patch_embed"],
)
@@ -374,7 +374,7 @@ _register_composite_model(
_register_composite_model(
model_type="qwen3_omni_moe_thinker",
projector_key="visual.merger",
vision_model_keys=["visual.patch_embed", "visual.blocks", "visual.deepstack_merger_list", "audio_tower"],
vision_model_keys=["visual.pos_embed", "visual.patch_embed", "visual.blocks", "visual.deepstack_merger_list", "audio_tower"],
language_model_keys=["model", "lm_head"],
lora_conflict_keys=["patch_embed"],
)

View File

@@ -138,18 +138,25 @@ def patch_config(
if getattr(config, "model_type", None) == "kimi_vl" and is_trainable: if getattr(config, "model_type", None) == "kimi_vl" and is_trainable:
setattr(config.text_config, "topk_method", "greedy") setattr(config.text_config, "topk_method", "greedy")
if "InternVLChatModel" in getattr(config, "architectures", []): architectures = getattr(config, "architectures", None)
if isinstance(architectures, list) and "InternVLChatModel" in architectures:
raise ValueError( raise ValueError(
"Please download the internvl models in a Hugging Facecompatible format " "Please download the internvl models in a Hugging Facecompatible format "
"(for example, https://huggingface.co/OpenGVLab/InternVL3-8B-hf)." "(for example, https://huggingface.co/OpenGVLab/InternVL3-8B-hf)."
) )
if "LlavaLlamaForCausalLM" in getattr(config, "architectures", []): if isinstance(architectures, list) and "LlavaLlamaForCausalLM" in architectures:
raise ValueError("Please download llava models with hf-compatible format: https://huggingface.co/llava-hf") raise ValueError("Please download llava models with hf-compatible format: https://huggingface.co/llava-hf")
if getattr(config, "model_type", None) == "internlm3" and not is_transformers_version_greater_than("4.47.1"): if getattr(config, "model_type", None) == "internlm3" and not is_transformers_version_greater_than("4.47.1"):
raise RuntimeError("InternLM3 model requires transformers>=4.47.1, please upgrade it.") raise RuntimeError("InternLM3 model requires transformers>=4.47.1, please upgrade it.")
if getattr(config, "model_type", None) == "lfm2_vl" and not is_transformers_version_greater_than("4.58.0"):
raise RuntimeError(
"LFM2.5-VL model requires transformers>=4.58.0 or install from commit: "
"pip install git+https://github.com/huggingface/transformers.git@3c2517727ce28a30f5044e01663ee204deb1cdbe"
)
if getattr(config, "model_type", None) == "qwen3_omni_moe": if getattr(config, "model_type", None) == "qwen3_omni_moe":
patch_qwen3_omni_moe_thinker_text_sparse_moe_block() patch_qwen3_omni_moe_thinker_text_sparse_moe_block()

View File

@@ -12,35 +12,45 @@
# See the License for the specific language governing permissions and
# limitations under the License.

+import os
+import types
from typing import TYPE_CHECKING, Any, Optional

from ..extras import logging


if TYPE_CHECKING:
-   from ..hparams import ModelArguments
+   from ..hparams import TrainingArguments


logger = logging.get_logger(__name__)


-def create_fp8_kwargs(model_args: "ModelArguments") -> list[Any]:
+def create_fp8_kwargs(training_args: "TrainingArguments") -> list[Any]:
    """Create AORecipeKwargs for FP8 training with HuggingFace Accelerate.

    Args:
-       model_args: Model arguments containing FP8 configuration
+       training_args: Training arguments containing FP8 configuration

    Returns:
        List containing AORecipeKwargs if FP8 is enabled and supported, empty list otherwise
    """
-   if not model_args.fp8:
+   if not training_args.fp8:
        return []

-   try:
-       # Check if AORecipeKwargs is available (Accelerate 1.8.0+)
-       from accelerate.utils import AORecipeKwargs
-
-       backend = getattr(model_args, "fp8_backend", "auto")
-       logger.info_rank0(f"Creating FP8 configuration with backend: {backend}")
+   backend = getattr(training_args, "fp8_backend", "auto")
+   logger.info_rank0(f"Creating FP8 configuration with backend: {backend}")
+
+   try:
+       # Use Transformer Engine backend (optimal for Hopper GPUs)
+       if backend == "te":
+           from accelerate.utils import FP8RecipeKwargs
+
+           logger.info_rank0("Using Transformer Engine FP8 backend")
+           return [FP8RecipeKwargs(backend="TE", fp8_format="HYBRID", amax_history_len=16, amax_compute_algo="max")]
+
+       # Use TorchAO backend (default)
+       from accelerate.utils import AORecipeKwargs

        # Create Float8LinearConfig if torchao backend is used
        config = None
@@ -83,7 +93,10 @@ def create_fp8_kwargs(model_args: "ModelArguments") -> list[Any]:
            return True

        # Map FSDP all-gather setting if available (this affects the underlying implementation)
-       if hasattr(model_args, "fp8_enable_fsdp_float8_all_gather") and model_args.fp8_enable_fsdp_float8_all_gather:
+       if (
+           hasattr(training_args, "fp8_enable_fsdp_float8_all_gather")
+           and training_args.fp8_enable_fsdp_float8_all_gather
+       ):
            logger.info_rank0("FSDP float8 all-gather optimization requested")

        return [AORecipeKwargs(config=config, module_filter_func=module_filter_func)]
@@ -92,19 +105,19 @@ def create_fp8_kwargs(model_args: "ModelArguments") -> list[Any]:
        return []


-def get_fp8_mixed_precision(model_args: "ModelArguments") -> Optional[str]:
+def get_fp8_mixed_precision(training_args: "TrainingArguments") -> Optional[str]:
    """Get the mixed precision setting for Accelerate when using FP8.

    Args:
-       model_args: Model arguments containing FP8 configuration
+       training_args: Training arguments containing FP8 configuration

    Returns:
        "fp8" if FP8 is enabled, None otherwise
    """
-   return "fp8" if model_args.fp8 else None
+   return "fp8" if training_args.fp8 else None


-def configure_fp8_environment(model_args: "ModelArguments") -> None:
+def configure_fp8_environment(training_args: "TrainingArguments") -> None:
    """Configure FP8 environment for HuggingFace Accelerate.

    FP8 training is handled entirely through HuggingFace Accelerate, regardless of whether
@@ -112,11 +125,9 @@ def configure_fp8_environment(model_args: "ModelArguments") -> None:
    variables and validates the FP8 configuration.

    Args:
-       model_args: Model arguments containing FP8 configuration
+       training_args: Training arguments containing FP8 configuration
    """
-   import os
-
-   if not model_args.fp8:
+   if not training_args.fp8:
        return

    # Set mixed precision to fp8 for HuggingFace Accelerate
@@ -124,38 +135,38 @@ def configure_fp8_environment(model_args: "ModelArguments") -> None:
    logger.info_rank0("Set ACCELERATE_MIXED_PRECISION=fp8")

    # Configure FP8 backend and options
-   backend = getattr(model_args, "fp8_backend", "auto")
+   backend = getattr(training_args, "fp8_backend", "auto")
    if backend != "auto":
        os.environ["FP8_BACKEND"] = backend
        logger.info_rank0(f"Set FP8_BACKEND={backend}")

    # Create and validate FP8 recipe kwargs (for logging/debugging)
-   fp8_kwargs = create_fp8_kwargs(model_args)
+   fp8_kwargs = create_fp8_kwargs(training_args)
    logger.info_rank0(f"FP8 AORecipeKwargs created: {len(fp8_kwargs)} items")

    # Enable FSDP float8 all-gather optimization if requested
-   if hasattr(model_args, "fp8_enable_fsdp_float8_all_gather") and model_args.fp8_enable_fsdp_float8_all_gather:
+   if hasattr(training_args, "fp8_enable_fsdp_float8_all_gather") and training_args.fp8_enable_fsdp_float8_all_gather:
        os.environ["FP8_ENABLE_FSDP_FLOAT8_ALL_GATHER"] = "true"
        logger.info_rank0("Set FP8_ENABLE_FSDP_FLOAT8_ALL_GATHER=true")

    logger.info_rank0("FP8 environment configured - all FP8 training handled by HuggingFace Accelerate")


-def verify_fp8_status(accelerator, model_args: "ModelArguments") -> None:
+def verify_fp8_status(accelerator, training_args: "TrainingArguments") -> None:
    """Verify that FP8 training is actually working after model preparation.

    Args:
        accelerator: The HuggingFace Accelerator instance
-       model_args: Model arguments containing FP8 configuration
+       training_args: Training arguments containing FP8 configuration
    """
-   if not model_args.fp8:
+   if not training_args.fp8:
        return

    # Check Accelerate's FP8 status
    fp8_enabled = getattr(accelerator, "fp8_enabled", False)
    fp8_backend_type = getattr(accelerator, "fp8_backend", "UNKNOWN")
-   backend = getattr(model_args, "fp8_backend", "auto")
+   backend = getattr(training_args, "fp8_backend", "auto")

    if backend == "torchao" or backend == "auto":
        logger.info_rank0(
            "FP8 training enabled with TorchAO backend. For optimal performance, "
@@ -169,3 +180,50 @@ def verify_fp8_status(accelerator, model_args: "ModelArguments") -> None:
    if not fp8_enabled:
        logger.info_rank0("WARNING: FP8 was requested but Accelerate shows fp8_enabled=False. FP8 may not be working.")
+def patch_accelerator_for_fp8() -> None:
+   """Patch Accelerator to inject FP8 recipe kwargs.
+
+   This is needed because HuggingFace Trainer doesn't pass kwargs_handlers to Accelerator.
+   We monkey-patch Accelerator.__init__ to inject the FP8 recipe and force mixed_precision='fp8'.
+   """
+   import transformer_engine.pytorch as te
+   from accelerate import Accelerator
+
+   # Guard against multiple patches
+   if getattr(Accelerator, "_te_fp8_patched", False):
+       return
+
+   # Stub for Accelerate 1.12+ compatibility (te.fp8.check_mxfp8_support doesn't exist yet)
+   if not hasattr(te, "fp8"):
+       te.fp8 = types.ModuleType("fp8")
+       te.fp8.check_mxfp8_support = lambda: (False, "MXFP8 not supported")
+
+   try:
+       from accelerate.utils import TERecipeKwargs as FP8Recipe
+
+       use_te_recipe = True
+   except ImportError:
+       from accelerate.utils import FP8RecipeKwargs as FP8Recipe
+
+       use_te_recipe = False
+
+   original_init = Accelerator.__init__
+
+   def patched_init(self, *args, **kwargs):
+       if "kwargs_handlers" not in kwargs or not kwargs["kwargs_handlers"]:
+           if use_te_recipe:
+               kwargs["kwargs_handlers"] = [
+                   FP8Recipe(fp8_format="HYBRID", amax_history_len=16, amax_compute_algo="max")
+               ]
+           else:
+               kwargs["kwargs_handlers"] = [
+                   FP8Recipe(backend="TE", fp8_format="HYBRID", amax_history_len=16, amax_compute_algo="max")
+               ]
+
+           # Only force mixed_precision when we inject handlers
+           kwargs["mixed_precision"] = "fp8"
+
+       return original_init(self, *args, **kwargs)
+
+   Accelerator.__init__ = patched_init
+   Accelerator._te_fp8_patched = True
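
The patch above exists only because HuggingFace Trainer builds its own Accelerator and does not forward kwargs_handlers. Below is a minimal, self-contained sketch of the same monkey-patching pattern on a toy class, so the mechanism can be read in isolation; the names (Engine, recipe) are illustrative and not part of any real library.

class Engine:
    def __init__(self, *, recipe=None, mixed_precision="no"):
        self.recipe = recipe
        self.mixed_precision = mixed_precision

def patch_engine_defaults() -> None:
    if getattr(Engine, "_patched", False):  # guard against double patching
        return

    original_init = Engine.__init__

    def patched_init(self, *args, **kwargs):
        if not kwargs.get("recipe"):  # only inject when the caller passed nothing
            kwargs["recipe"] = {"fp8_format": "HYBRID"}
            kwargs["mixed_precision"] = "fp8"  # force fp8 only when we inject
        return original_init(self, *args, **kwargs)

    Engine.__init__ = patched_init
    Engine._patched = True

patch_engine_defaults()
assert Engine().mixed_precision == "fp8"                       # default injected
assert Engine(recipe={"fp8_format": "E4M3"}).mixed_precision == "no"  # caller wins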

View File

@@ -19,16 +19,15 @@ import torch
from transformers import Trainer
from typing_extensions import override

-from ...extras.packages import is_transformers_version_greater_than
from ..callbacks import SaveProcessorCallback
-from ..fp8_utils import configure_fp8_environment, verify_fp8_status
+from ..fp8_utils import configure_fp8_environment, patch_accelerator_for_fp8, verify_fp8_status
from ..trainer_utils import create_custom_optimizer, create_custom_scheduler


if TYPE_CHECKING:
    from transformers import ProcessorMixin

-   from ...hparams import FinetuningArguments, ModelArguments
+   from ...hparams import FinetuningArguments, ModelArguments, TrainingArguments


class CustomTrainer(Trainer):
@@ -41,11 +40,13 @@ class CustomTrainer(Trainer):
        model_args: Optional["ModelArguments"] = None,
        **kwargs,
    ) -> None:
+       kwargs["processing_class"] = kwargs.pop("tokenizer")
        # Configure FP8 environment if enabled
-       if model_args is not None and model_args.fp8:
-           configure_fp8_environment(model_args)
-
-       if is_transformers_version_greater_than("4.46"):
-           kwargs["processing_class"] = kwargs.pop("tokenizer")
+       training_args: TrainingArguments = kwargs.get("args")
+       if training_args.fp8:
+           configure_fp8_environment(training_args)
+           if getattr(training_args, "fp8_backend", "auto") == "te":
+               patch_accelerator_for_fp8()

        super().__init__(**kwargs)
        if processor is not None:
@@ -64,9 +65,8 @@ class CustomTrainer(Trainer):
            self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, self.accelerator)
            self.add_callback(BAdamCallback)

-       # Verify FP8 status after trainer initialization (accelerator should be available)
-       if model_args is not None and model_args.fp8 and hasattr(self, "accelerator"):
-           verify_fp8_status(self.accelerator, model_args)
+       if training_args.fp8 and hasattr(self, "accelerator"):  # verify FP8 status after trainer initialization
+           verify_fp8_status(self.accelerator, training_args)

    @override
    def create_optimizer(self) -> "torch.optim.Optimizer":

View File

@@ -109,6 +109,27 @@ class PairwiseTrainer(Trainer):
        else:
            return loss

+   @override
+   def _save(self, output_dir: Optional[str] = None, state_dict=None):
+       if state_dict is None:
+           state_dict = self.model.state_dict()
+
+       if self.args.save_safetensors:
+           from collections import defaultdict
+
+           ptrs = defaultdict(list)
+           for name, tensor in state_dict.items():
+               if isinstance(tensor, torch.Tensor):
+                   ptrs[id(tensor)].append(name)
+
+           for names in ptrs.values():
+               if len(names) > 1:
+                   names.sort()
+                   for name in names[1:]:
+                       state_dict.pop(name, None)
+
+       super()._save(output_dir, state_dict)
+
    def save_predictions(self, predict_results: "PredictionOutput") -> None:
        r"""Save model predictions to `output_dir`.

View File

@@ -27,18 +27,17 @@ from typing_extensions import override
from ...extras import logging
from ...extras.constants import IGNORE_INDEX
-from ...extras.packages import is_transformers_version_greater_than
from ..callbacks import SaveProcessorCallback
-from ..fp8_utils import configure_fp8_environment, verify_fp8_status
+from ..fp8_utils import configure_fp8_environment, patch_accelerator_for_fp8, verify_fp8_status
from ..trainer_utils import create_custom_optimizer, create_custom_scheduler


if TYPE_CHECKING:
    from torch.utils.data import Dataset
-   from transformers import PreTrainedTokenizer, ProcessorMixin
+   from transformers import ProcessorMixin
    from transformers.trainer import PredictionOutput

-   from ...hparams import FinetuningArguments, ModelArguments
+   from ...hparams import FinetuningArguments, ModelArguments, TrainingArguments


logger = logging.get_logger(__name__)
@@ -55,13 +54,13 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer):
        gen_kwargs: Optional[dict[str, Any]] = None,
        **kwargs,
    ) -> None:
+       kwargs["processing_class"] = kwargs.pop("tokenizer")
        # Configure FP8 environment if enabled
-       if model_args is not None and model_args.fp8:
-           configure_fp8_environment(model_args)
-
-       if is_transformers_version_greater_than("4.46"):
-           kwargs["processing_class"] = kwargs.pop("tokenizer")
-       else:
-           self.processing_class: PreTrainedTokenizer = kwargs.get("tokenizer")
+       training_args: TrainingArguments = kwargs.get("args")
+       if training_args.fp8:
+           configure_fp8_environment(training_args)
+           if getattr(training_args, "fp8_backend", "auto") == "te":
+               patch_accelerator_for_fp8()

        super().__init__(**kwargs)
        if processor is not None:
@@ -88,9 +87,15 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer):
            self.compute_loss_func = dft_loss_func
+       elif finetuning_args.use_eaft_loss:
+           from ..trainer_utils import eaft_loss_func
+
+           self.compute_loss_func = lambda outputs, labels, num_items_in_batch=None: eaft_loss_func(
+               outputs, labels, num_items_in_batch, finetuning_args.eaft_alpha
+           )

-       # Verify FP8 status after trainer initialization (accelerator should be available)
-       if model_args is not None and model_args.fp8 and hasattr(self, "accelerator"):
-           verify_fp8_status(self.accelerator, model_args)
+       if training_args.fp8 and hasattr(self, "accelerator"):  # verify FP8 status after trainer initialization
+           verify_fp8_status(self.accelerator, training_args)

    @override
    def create_optimizer(self) -> "torch.optim.Optimizer":

View File

@@ -20,7 +20,6 @@
import json
import os
from collections.abc import Callable, Mapping
-from pathlib import Path
from typing import TYPE_CHECKING, Any, Optional, Union

import torch
@@ -34,6 +33,7 @@ from typing_extensions import override
from ..extras import logging
from ..extras.constants import IGNORE_INDEX, SWANLAB_CONFIG
+from ..extras.misc import get_device_name
from ..extras.packages import is_apollo_available, is_galore_available, is_ray_available
from ..hparams import FinetuningArguments, ModelArguments
from ..model import find_all_linear_modules, load_model, load_tokenizer, load_valuehead_params
@@ -49,15 +49,15 @@ if is_apollo_available():
if is_ray_available():
    import ray
-   from ray.train import RunConfig, ScalingConfig
-   from ray.train.torch import TorchTrainer
+   from ray.util.placement_group import PlacementGroup, placement_group
+   from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy


if TYPE_CHECKING:
    from transformers import PreTrainedModel, TrainerCallback, TrainerState
    from trl import AutoModelForCausalLMWithValueHead

-   from ..hparams import DataArguments, RayArguments, TrainingArguments
+   from ..hparams import DataArguments, TrainingArguments


logger = logging.get_logger(__name__)
@@ -634,7 +634,9 @@ def get_batch_logps(
    return logps, valid_length


-def dft_loss_func(outputs, labels, num_items_in_batch=None):
+def dft_loss_func(
+   outputs: "torch.Tensor", labels: "torch.Tensor", num_items_in_batch: Optional["torch.Tensor"] = None
+):
    logits = outputs.get("logits")
    if logits is None:
        return outputs.get("loss", torch.tensor(0.0))
@@ -652,11 +654,11 @@ def dft_loss_func(outputs, labels, num_items_in_batch=None):
def _dft_cross_entropy(
-   source: torch.Tensor,
-   target: torch.Tensor,
-   num_items_in_batch: Optional[torch.Tensor] = None,
+   source: "torch.Tensor",
+   target: "torch.Tensor",
+   num_items_in_batch: Optional["torch.Tensor"] = None,
    ignore_index: int = -100,
-) -> torch.Tensor:
+) -> "torch.Tensor":
    per_token_loss = torch.nn.functional.cross_entropy(source, target, ignore_index=ignore_index, reduction="none")
    valid_mask = target != ignore_index
    if not valid_mask.any():
@@ -679,6 +681,67 @@ def _dft_cross_entropy(
    return loss
+def eaft_loss_func(
+   outputs: "torch.Tensor",
+   labels: "torch.Tensor",
+   num_items_in_batch: Optional["torch.Tensor"] = None,
+   alpha: float = 1.0,
+) -> "torch.Tensor":
+   logits = outputs.get("logits")
+   if logits is None:
+       return outputs.get("loss", torch.tensor(0.0))
+
+   logits = logits.float()
+   vocab_size = logits.size(-1)
+   labels = torch.nn.functional.pad(labels, (0, 1), value=-100)
+   shift_labels = labels[..., 1:].contiguous()
+   logits = logits.view(-1, vocab_size)
+   shift_labels = shift_labels.view(-1)
+   shift_labels = shift_labels.to(logits.device)
+   loss = _eaft_cross_entropy(logits, shift_labels, num_items_in_batch, alpha)
+   return loss
+
+
+def _eaft_cross_entropy(
+   source: "torch.Tensor",
+   target: "torch.Tensor",
+   num_items_in_batch: Optional["torch.Tensor"] = None,
+   alpha: float = 1.0,
+   ignore_index: int = -100,
+) -> "torch.Tensor":
+   per_token_loss = torch.nn.functional.cross_entropy(source, target, ignore_index=ignore_index, reduction="none")
+   valid_mask = target != ignore_index
+   if not valid_mask.any():
+       return torch.tensor(0.0, device=source.device, dtype=source.dtype)
+
+   valid_losses = per_token_loss[valid_mask]
+   with torch.no_grad():
+       source_detached = source[valid_mask].detach()
+       topk_val, _ = torch.topk(source_detached, k=20, dim=-1)
+       logsumexp_topk = torch.logsumexp(topk_val, dim=-1, keepdim=True)
+       log_probs_topk = topk_val - logsumexp_topk
+       probs_topk = torch.exp(log_probs_topk)
+       entropy_approx = -(probs_topk * log_probs_topk).sum(dim=-1)
+       entropy_term = entropy_approx / 3.0
+       adaptive_weight = torch.pow(entropy_term, alpha)
+
+   weighted_losses = valid_losses * adaptive_weight
+   if num_items_in_batch is not None:
+       total_loss = weighted_losses.sum()
+       if torch.is_tensor(num_items_in_batch):
+           num_items_in_batch = num_items_in_batch.to(total_loss.device)
+       loss = total_loss / num_items_in_batch
+   else:
+       loss = weighted_losses.mean()
+
+   return loss
def nested_detach( def nested_detach(
tensors: Union["torch.Tensor", list["torch.Tensor"], tuple["torch.Tensor"], dict[str, "torch.Tensor"]], tensors: Union["torch.Tensor", list["torch.Tensor"], tuple["torch.Tensor"], dict[str, "torch.Tensor"]],
clone: bool = False, clone: bool = False,
@@ -744,36 +807,88 @@ def get_swanlab_callback(finetuning_args: "FinetuningArguments") -> "TrainerCall
    return swanlab_callback


-def get_ray_trainer(
-   training_function: Callable,
-   train_loop_config: dict[str, Any],
-   ray_args: "RayArguments",
-) -> "TorchTrainer":
-   if not ray_args.use_ray:
-       raise ValueError("Ray was not enabled. Please set `USE_RAY=1` to enable ray.")
-
-   if ray_args.ray_init_kwargs is not None:
-       ray.init(**ray_args.ray_init_kwargs)
-
-   if ray_args.ray_storage_filesystem is not None:
-       # this means we are using s3/gcs
-       storage_path = ray_args.ray_storage_path
-   else:
-       storage_path = Path(ray_args.ray_storage_path).absolute().as_posix()
-
-   trainer = TorchTrainer(
-       training_function,
-       train_loop_config=train_loop_config,
-       scaling_config=ScalingConfig(
-           num_workers=ray_args.ray_num_workers,
-           resources_per_worker=ray_args.resources_per_worker,
-           placement_strategy=ray_args.placement_strategy,
-           use_gpu=True,
-       ),
-       run_config=RunConfig(
-           name=ray_args.ray_run_name,
-           storage_filesystem=ray_args.ray_storage_filesystem,
-           storage_path=storage_path,
-       ),
-   )
-   return trainer
+def get_placement_group(num_workers: int) -> tuple["PlacementGroup", dict[str, int]]:
+   r"""Get the Ray placement group for distributed training."""
+   bundle = {"CPU": 10}
+   device_name = get_device_name().upper()
+   if device_name != "CPU":
+       bundle[device_name] = 1
+
+   bundles = [bundle for _ in range(num_workers)]
+   pg = placement_group(bundles, strategy="PACK")
+   return pg, bundle
+
+
+def get_ray_remote_config_for_worker(
+   placement_group: "PlacementGroup",
+   bundle_idx: int,
+   rank: int,
+   world_size: int,
+   master_addr: str,
+   master_port: str,
+   env: dict[str, str] = None,
+) -> dict[str, Any]:
+   r"""Get the remote config for a Ray worker."""
+   env_vars = {
+       "RANK": str(rank),
+       "WORLD_SIZE": str(world_size),
+       "MASTER_ADDR": master_addr,
+       "MASTER_PORT": master_port,
+       "TORCHELASTIC_USE_AGENT_STORE": "False",
+   }
+   env.update(env_vars)
+   remote_config = {
+       "scheduling_strategy": PlacementGroupSchedulingStrategy(
+           placement_group=placement_group,
+           placement_group_bundle_index=bundle_idx,
+       ),
+       "runtime_env": {"env_vars": env},
+       "num_cpus": 10,
+   }
+
+   device_name = get_device_name()
+   if device_name == "gpu":
+       remote_config["num_gpus"] = 1
+   elif device_name == "npu":
+       remote_config["resources"] = {"NPU": 1}
+
+   return remote_config
+
+
+def get_ray_head_node_ip() -> str:
+   r"""Get the IP address of the Ray head node."""
+   head_ip = next(node["NodeManagerAddress"] for node in ray.nodes() if node.get("IsHead", False))
+   return head_ip
+
+
+def sort_placement_group_by_node_ip(placement_group: "PlacementGroup", master_addr: str = None) -> list[int]:
+   r"""Sort the placement group bundles by their node IP addresses."""
+
+   @ray.remote
+   def _get_node_ip():
+       return ray.util.get_node_ip_address().strip("[]")
+
+   tasks = []
+   for bundle_idx in range(placement_group.bundle_count):
+       task = _get_node_ip.options(
+           scheduling_strategy=PlacementGroupSchedulingStrategy(
+               placement_group=placement_group,
+               placement_group_bundle_index=bundle_idx,
+           ),
+       ).remote()
+       tasks.append(task)
+
+   bundle_ips = ray.get(tasks)
+   bundle_node_ip_list = list(enumerate(bundle_ips))
+   sorted_bundle_node_ip_list = sorted(bundle_node_ip_list, key=lambda x: x[1])
+   sorted_bundle_indices = [item[0] for item in sorted_bundle_node_ip_list]
+   if master_addr is not None:
+       preferred_indices = [idx for idx, ip in bundle_node_ip_list if ip == master_addr]
+       if preferred_indices:
+           remaining = [i for i in sorted_bundle_indices if i not in preferred_indices]
+           sorted_bundle_indices = preferred_indices + remaining
+
+   return sorted_bundle_indices
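
Before moving on, a rough numeric sketch of the EAFT weighting introduced in eaft_loss_func earlier in this file: the predictive entropy is approximated from the top-k logits, normalized by roughly log(20) ≈ 3.0, and entropy**alpha is used as a per-token weight on the cross-entropy loss. Shapes and k are illustrative.

import torch

logits = torch.randn(4, 32000)  # (num_valid_tokens, vocab_size)
alpha = 1.0

topk_val, _ = torch.topk(logits, k=20, dim=-1)
log_probs_topk = topk_val - torch.logsumexp(topk_val, dim=-1, keepdim=True)
probs_topk = log_probs_topk.exp()
entropy_approx = -(probs_topk * log_probs_topk).sum(dim=-1)  # in [0, log(20) ~ 3.0]
weight = (entropy_approx / 3.0) ** alpha                     # high-entropy tokens weigh more

print(weight.shape, float(weight.min()), float(weight.max()))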

View File

@@ -23,9 +23,9 @@ from transformers import EarlyStoppingCallback, PreTrainedModel
from ..data import get_template_and_fix_tokenizer
from ..extras import logging
from ..extras.constants import V_HEAD_SAFE_WEIGHTS_NAME, V_HEAD_WEIGHTS_NAME
-from ..extras.misc import infer_optim_dtype
+from ..extras.misc import find_available_port, get_device_name, get_torch_device, infer_optim_dtype
from ..extras.packages import is_mcore_adapter_available, is_ray_available
-from ..hparams import get_infer_args, get_ray_args, get_train_args, read_args
+from ..hparams import RayArguments, get_infer_args, get_ray_args, get_train_args, read_args
from ..model import load_model, load_tokenizer
from .callbacks import LogCallback, PissaConvertCallback, ReporterCallback
from .dpo import run_dpo
@@ -34,12 +34,17 @@ from .ppo import run_ppo
from .pt import run_pt
from .rm import run_rm
from .sft import run_sft
-from .trainer_utils import get_ray_trainer, get_swanlab_callback
+from .trainer_utils import (
+   get_placement_group,
+   get_ray_head_node_ip,
+   get_ray_remote_config_for_worker,
+   get_swanlab_callback,
+   sort_placement_group_by_node_ip,
+)


if is_ray_available():
    import ray
-   from ray.train.huggingface.transformers import RayTrainReportCallback


if TYPE_CHECKING:
@@ -115,13 +120,7 @@ def run_exp(args: Optional[dict[str, Any]] = None, callbacks: Optional[list["Tra
    ray_args = get_ray_args(args)
    callbacks = callbacks or []
    if ray_args.use_ray:
-       callbacks.append(RayTrainReportCallback())
-       trainer = get_ray_trainer(
-           training_function=_training_function,
-           train_loop_config={"args": args, "callbacks": callbacks},
-           ray_args=ray_args,
-       )
-       trainer.fit()
+       _ray_training_function(ray_args, config={"args": args, "callbacks": callbacks})
    else:
        _training_function(config={"args": args, "callbacks": callbacks})
@@ -212,3 +211,94 @@ def export_model(args: Optional[dict[str, Any]] = None) -> None:
        with open(ollama_modelfile, "w", encoding="utf-8") as f:
            f.write(template.get_ollama_modelfile(tokenizer))

        logger.info_rank0(f"Ollama modelfile saved in {ollama_modelfile}")
+class Worker:
+   def __init__(self):
+       self._setup_env_visible_devices()
+       local_rank = os.environ.get("LOCAL_RANK", "0")
+       get_torch_device().set_device(int(local_rank))
+
+   def _setup_env_visible_devices(self) -> None:
+       RAY_NOSET_VISIBLE_DEVICES_LIST = [
+           "RAY_EXPERIMENTAL_NOSET_CUDA_VISIBLE_DEVICES",
+           "RAY_EXPERIMENTAL_NOSET_ASCEND_RT_VISIBLE_DEVICES",
+       ]
+       is_ray_noset_visible_devices = any(os.environ.get(env_var, None) for env_var in RAY_NOSET_VISIBLE_DEVICES_LIST)
+       if is_ray_noset_visible_devices:
+           device_name = get_device_name().upper()
+           local_rank = ray.get_runtime_context().get_accelerator_ids()[device_name][0]
+           os.environ["LOCAL_RANK"] = local_rank
+       else:
+           os.environ["LOCAL_RANK"] = "0"
+
+   def _training_function(self, config: dict[str, Any]) -> None:
+       _training_function(config)
+
+
+def _ray_training_function(ray_args: "RayArguments", config: dict[str, Any]) -> None:
+   num_workers = ray_args.ray_num_workers
+   master_addr = ray_args.master_addr
+   master_port = ray_args.master_port
+   logger.info(f"Using ray.remote mode with {num_workers} workers for distributed training.")
+
+   # initialize ray
+   if not ray.is_initialized():
+       if ray_args.ray_init_kwargs is not None:
+           ray.init(**ray_args.ray_init_kwargs)
+       else:
+           ray.init()
+
+   # verify resources
+   device_name = get_device_name().upper()
+   total_devices = int(ray.cluster_resources().get(device_name, 0))
+   if num_workers > total_devices:
+       raise ValueError(
+           f"The number of devices in the Ray cluster ({total_devices}) should be at least num_workers ({num_workers})."
+       )
+
+   # verify master_addr
+   if master_addr is None:
+       master_addr = get_ray_head_node_ip()
+       logger.info(f"`master_addr` is not specified, using head node ip: {master_addr}.")
+   else:
+       nodes = [node["NodeManagerAddress"] for node in ray.nodes() if node["Alive"]]
+       if master_addr not in nodes:
+           raise ValueError(f"The `master_addr` ({master_addr}) is not in the Ray cluster or not alive.")
+
+   # create placement group for resource management
+   pg, bundle = get_placement_group(total_devices)
+   ray.get(pg.ready())
+   logger.info(f"Create placement group with {num_workers} bundles: {bundle}")
+
+   # get sorted bundle indices
+   sorted_bundle_indices = sort_placement_group_by_node_ip(pg, master_addr)
+
+   # get master port
+   if master_port is None:
+       master_port = find_available_port()
+       logger.info(f"`master_port` is not specified, using available port: {master_port}.")
+
+   master_port = str(master_port)
+
+   # back up environment variables
+   current_env = dict(os.environ.items())
+
+   # launch workers
+   RayWorker = ray.remote(Worker)
+   workers = []
+   for rank in range(num_workers):
+       remote_config = get_ray_remote_config_for_worker(
+           placement_group=pg,
+           bundle_idx=sorted_bundle_indices[rank],
+           rank=rank,
+           world_size=num_workers,
+           master_addr=master_addr,
+           master_port=master_port,
+           env=current_env,
+       )
+       worker = RayWorker.options(**remote_config).remote()
+       workers.append(worker)
+
+   ray.get([worker._training_function.remote(config=config) for worker in workers])
+   ray.shutdown()
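
The launch pattern above (one ray.remote actor per placement-group bundle, each handed the usual torch-distributed env vars) can be reduced to a small standalone sketch. The snippet below is illustrative only: it requires `pip install ray`, uses CPU-only bundles, and the world size and port are arbitrary.

import os

import ray
from ray.util.placement_group import placement_group
from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy


@ray.remote(num_cpus=1)
class EchoWorker:
    def run(self) -> str:
        return f"rank={os.environ['RANK']} world_size={os.environ['WORLD_SIZE']}"


if __name__ == "__main__":
    ray.init()
    world_size = 2
    pg = placement_group([{"CPU": 1}] * world_size, strategy="PACK")
    ray.get(pg.ready())

    workers = []
    for rank in range(world_size):
        env = {"RANK": str(rank), "WORLD_SIZE": str(world_size), "MASTER_ADDR": "127.0.0.1", "MASTER_PORT": "29500"}
        worker = EchoWorker.options(
            scheduling_strategy=PlacementGroupSchedulingStrategy(placement_group=pg, placement_group_bundle_index=rank),
            runtime_env={"env_vars": env},
        ).remote()
        workers.append(worker)

    print(ray.get([w.run.remote() for w in workers]))
    ray.shutdown()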

View File

@@ -119,9 +119,19 @@ def synchronize() -> None:
@requires_accelerator
-def set_device() -> None:
-   """Set current accelerator."""
-   torch.accelerator.set_device_index(get_local_rank())
+def set_device_index() -> None:
+   """Set current accelerator index to local rank."""
+   if get_current_accelerator().type != DeviceType.CPU:
+       torch.accelerator.set_device_index(get_local_rank())
+
+
+@requires_accelerator
+def get_current_device() -> torch.device:
+   """Get current accelerator device."""
+   if get_current_accelerator().type == DeviceType.CPU:
+       return torch.device(DeviceType.CPU.value)
+   else:
+       return torch.device(type=get_current_accelerator().type, index=torch.accelerator.current_device_index())


def is_torch_cuda_available():
@@ -170,6 +180,16 @@ def operate_tensorlike(fn: Callable[[...], Tensor], data: TensorLike, **kwargs)
    return result.tolist()


+def get_process_group_backend() -> str:
+   """Get backend for init process group."""
+   if get_current_accelerator().type == DeviceType.NPU:
+       return "hccl"
+   elif get_current_accelerator().type == DeviceType.CUDA:
+       return "nccl"
+   else:
+       return "gloo"
+
+
def all_gather(tensor: Tensor, group: Optional[ProcessGroup] = None) -> Tensor:
    """Gathers the tensor from all ranks and stacks them at the first dim."""
    world_size = get_world_size()
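
The new get_process_group_backend() feeds the backend string straight into init_process_group. A standalone sketch of that handoff is below; gloo is hard-coded so the example runs on a CPU-only machine, whereas the helper picks hccl/nccl/gloo from the detected accelerator, and the address/port values are placeholders.

import os

import torch.distributed as dist

os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29501")

dist.init_process_group(backend="gloo", rank=0, world_size=1)  # single-process group
print(dist.get_backend())  # "gloo"
dist.destroy_process_group()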

View File

@@ -34,10 +34,14 @@ from typing import Any, Optional
from torch.distributed import barrier, destroy_process_group, init_process_group
from torch.distributed.device_mesh import DeviceMesh, init_device_mesh

-from ..utils.types import DistributedConfig, ProcessGroup, Tensor, TensorLike
+from ..utils import logging
+from ..utils.types import DistributedConfig, ProcessGroup, TensorLike
from . import helper


+logger = logging.get_logger(__name__)
+
+
class Dim(str, Enum):
    """Dimension names."""
@@ -119,12 +123,13 @@ class DistributedInterface:
        if self._initialized:
            return

+       helper.set_device_index()
        self._is_distributed = helper.is_distributed()
        self._rank = helper.get_rank()
        self._world_size = helper.get_world_size()
        self._local_rank = helper.get_local_rank()
        self._local_world_size = helper.get_local_world_size()
-       self.current_accelerator = helper.get_current_accelerator()
+       self.current_device = helper.get_current_device()
        self.device_count = helper.get_device_count()

        if config is None:
@@ -140,15 +145,14 @@ class DistributedInterface:
        timeout = config.get("timeout", 18000)
        if self._is_distributed:
-           helper.set_device()
-           init_process_group(timeout=timedelta(seconds=timeout))
+           init_process_group(timeout=timedelta(seconds=timeout), backend=helper.get_process_group_backend())
            self.model_device_mesh = init_device_mesh(
-               device_type=self.current_accelerator.type,
+               device_type=self.current_device.type,
                mesh_shape=self.strategy.model_mesh_shape,
                mesh_dim_names=self.strategy.model_mesh_dim_names,
            )
            self.data_device_mesh = init_device_mesh(
-               device_type=self.current_accelerator.type,
+               device_type=self.current_device.type,
                mesh_shape=self.strategy.data_mesh_shape,
                mesh_dim_names=self.strategy.data_mesh_dim_names,
            )
@@ -157,11 +161,12 @@ class DistributedInterface:
            self.data_device_mesh = None

        self._initialized = True
+       logger.info_rank0(f"DistributedInterface initialized: {self}.")

    def __str__(self) -> str:
        return (
            f"DistributedInterface(strategy={self.strategy}), is_distributed={self._is_distributed}, "
-           f"current_accelerator={self.current_accelerator}, rank={self._rank}, world_size={self._world_size}, "
+           f"current_device={self.current_device}, rank={self._rank}, world_size={self._world_size}, "
            f"model_device_mesh={self.model_device_mesh}, data_device_mesh={self.data_device_mesh}"
        )
@@ -169,7 +174,7 @@ class DistributedInterface:
        """Get device mesh for specified dimension."""
        if dim is None:
            raise ValueError("dim must be specified.")
-       elif self.model_device_mesh is None:
+       elif not self._is_distributed:
            return None
        elif dim in self.strategy.data_mesh_dim_names:
            return self.data_device_mesh[dim.value]
@@ -178,14 +183,14 @@ class DistributedInterface:
    def get_group(self, dim: Dim | None = None) -> Optional[ProcessGroup]:
        """Get process group for specified dimension."""
-       if self.model_device_mesh is None or dim is None:
+       if not self._is_distributed or dim is None:
            return None
        else:
            return self.get_device_mesh(dim).get_group()

    def get_rank(self, dim: Dim | None = None) -> int:
        """Get parallel rank for specified dimension."""
-       if self.model_device_mesh is None:
+       if not self._is_distributed:
            return 0
        elif dim is None:
            return self._rank
@@ -194,7 +199,7 @@ class DistributedInterface:
    def get_world_size(self, dim: Dim | None = None) -> int:
        """Get parallel size for specified dimension."""
-       if self.model_device_mesh is None:
+       if not self._is_distributed:
            return 1
        elif dim is None:
            return self._world_size
@@ -209,9 +214,9 @@ class DistributedInterface:
        """Get parallel local world size."""
        return self._local_world_size

-   def all_gather(self, data: Tensor, dim: Dim | None = Dim.DP) -> Tensor:
+   def all_gather(self, data: TensorLike, dim: Dim | None = Dim.DP) -> TensorLike:
        """Gather tensor across specified parallel group."""
-       if self.model_device_mesh is not None:
+       if self._is_distributed:
            return helper.operate_tensorlike(helper.all_gather, data, group=self.get_group(dim))
        else:
            return data
@@ -220,30 +225,36 @@ class DistributedInterface:
        self, data: TensorLike, op: helper.ReduceOp = helper.ReduceOp.MEAN, dim: Dim | None = Dim.DP
    ) -> TensorLike:
        """Reduce tensor across specified parallel group."""
-       if self.model_device_mesh is not None:
+       if self._is_distributed:
            return helper.operate_tensorlike(helper.all_reduce, data, op=op, group=self.get_group(dim))
        else:
            return data

    def broadcast(self, data: TensorLike, src: int = 0, dim: Dim | None = Dim.DP) -> TensorLike:
        """Broadcast tensor across specified parallel group."""
-       if self.model_device_mesh is not None:
+       if self._is_distributed:
            return helper.operate_tensorlike(helper.broadcast, data, src=src, group=self.get_group(dim))
        else:
            return data

    def sync(self) -> None:
        """Synchronize all processes."""
-       helper.synchronize()
+       if self._is_distributed:
+           helper.synchronize()

    def barrier(self) -> None:
        """Barrier all processes."""
-       barrier()
+       if self._is_distributed:
+           barrier()

    def destroy(self) -> None:
        """Destroy all processes."""
-       destroy_process_group()
+       if self._is_distributed:
+           destroy_process_group()


if __name__ == "__main__":
-   print(DistributedInterface(DistributedStrategy()))
+   """
+   python -m llamafactory.v1.accelerator.interface
+   """
+   print(DistributedInterface())
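
A hedged single-process usage sketch of the interface above (no torchrun launcher): when the process is not distributed, the collective helpers simply return their input, so the same training code runs with or without a process group. The module path follows the `python -m llamafactory.v1.accelerator.interface` entry point shown above.

from llamafactory.v1.accelerator.interface import Dim, DistributedInterface

di = DistributedInterface()
print(di.current_device, di.get_rank(), di.get_world_size(Dim.DP))
print(di.all_reduce([1.0, 2.0]))  # returned as-is when world_size == 1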

View File

@@ -0,0 +1,33 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .arg_parser import InputArgument, get_args
from .arg_utils import BatchingStrategy, ModelClass, SampleBackend
from .data_args import DataArguments
from .model_args import ModelArguments
from .sample_args import SampleArguments
from .training_args import TrainingArguments
__all__ = [
    "BatchingStrategy",
    "DataArguments",
    "InputArgument",
    "ModelArguments",
    "ModelClass",
    "SampleArguments",
    "SampleBackend",
    "TrainingArguments",
    "get_args",
]

View File

@@ -20,7 +20,7 @@ from typing import Any
from omegaconf import OmegaConf
from transformers import HfArgumentParser

-from ...extras.misc import is_env_enabled
+from ..utils.env import is_env_enabled
from .data_args import DataArguments
from .model_args import ModelArguments
from .sample_args import SampleArguments
@@ -30,24 +30,9 @@ from .training_args import TrainingArguments
InputArgument = dict[str, Any] | list[str] | None


-def validate_args(
-   data_args: DataArguments,
-   model_args: ModelArguments,
-   training_args: TrainingArguments,
-   sample_args: SampleArguments,
-):
-   """Validate arguments."""
-   if (
-       model_args.quant_config is not None
-       and training_args.dist_config is not None
-       and training_args.dist_config.name == "deepspeed"
-   ):
-       raise ValueError("Quantization is not supported with deepspeed backend.")
-
-
-def get_args(args: InputArgument = None) -> tuple[DataArguments, ModelArguments, TrainingArguments, SampleArguments]:
+def get_args(args: InputArgument = None) -> tuple[ModelArguments, DataArguments, TrainingArguments, SampleArguments]:
    """Parse arguments from command line or config file."""
-   parser = HfArgumentParser([DataArguments, ModelArguments, TrainingArguments, SampleArguments])
+   parser = HfArgumentParser([ModelArguments, DataArguments, TrainingArguments, SampleArguments])
    allow_extra_keys = is_env_enabled("ALLOW_EXTRA_KEYS")
    if args is None:
@@ -71,8 +56,6 @@ def get_args(args: InputArgument = None) -> tuple[DataArguments, ModelArguments,
        print(f"Got unknown args, potentially deprecated arguments: {unknown_args}")
        raise ValueError(f"Some specified arguments are not used by the HfArgumentParser: {unknown_args}")

-   validate_args(*parsed_args)
    return tuple(parsed_args)

View File

@@ -17,7 +17,7 @@
import json
-from enum import Enum, unique
+from enum import StrEnum, unique


class PluginConfig(dict):
@@ -36,7 +36,7 @@ PluginArgument = PluginConfig | dict | str | None
@unique
-class ModelClass(str, Enum):
+class ModelClass(StrEnum):
    """Auto class for model config."""

    LLM = "llm"
@@ -45,11 +45,19 @@ class ModelClass(str, Enum):
@unique
-class SampleBackend(str, Enum):
+class SampleBackend(StrEnum):
    HF = "hf"
    VLLM = "vllm"


+@unique
+class BatchingStrategy(StrEnum):
+   NORMAL = "normal"
+   PADDING_FREE = "padding_free"
+   DYNAMIC_BATCHING = "dynamic_batching"
+   DYNAMIC_PADDING_FREE = "dynamic_padding_free"
+
+
def _convert_str_dict(data: dict) -> dict:
    """Parse string representation inside the dictionary.

View File

@@ -18,11 +18,11 @@ from dataclasses import dataclass, field
@dataclass
class DataArguments:
-   dataset: str | None = field(
-       default=None,
-       metadata={"help": "Path to the dataset."},
-   )
-   cutoff_len: int = field(
-       default=2048,
-       metadata={"help": "Cutoff length for the dataset."},
-   )
+   train_dataset: str | None = field(
+       default=None,
+       metadata={"help": "Path to the training dataset."},
+   )
+   eval_dataset: str | None = field(
+       default=None,
+       metadata={"help": "Path to the evaluation dataset."},
+   )

View File

@@ -21,20 +21,25 @@ from .arg_utils import ModelClass, PluginConfig, get_plugin_config
@dataclass
class ModelArguments:
    model: str = field(
+       default="Qwen/Qwen3-4B-Instruct-2507",
        metadata={"help": "Path to the model or model identifier from Hugging Face."},
    )
+   template: str = field(
+       default="qwen3_nothink",
+       metadata={"help": "Template for the model."},
+   )
    trust_remote_code: bool = field(
        default=False,
        metadata={"help": "Trust remote code from Hugging Face."},
    )
+   use_fast_processor: bool = field(
+       default=True,
+       metadata={"help": "Use fast processor from Hugging Face."},
+   )
    model_class: ModelClass = field(
        default=ModelClass.LLM,
        metadata={"help": "Model class from Hugging Face."},
    )
+   init_config: PluginConfig | None = field(
+       default=None,
+       metadata={"help": "Initialization configuration for the model."},
+   )
    peft_config: PluginConfig | None = field(
        default=None,
        metadata={"help": "PEFT configuration for the model."},
@@ -49,6 +54,7 @@ class ModelArguments:
    )

    def __post_init__(self) -> None:
+       self.init_config = get_plugin_config(self.init_config)
        self.peft_config = get_plugin_config(self.peft_config)
        self.kernel_config = get_plugin_config(self.kernel_config)
        self.quant_config = get_plugin_config(self.quant_config)

View File

@@ -16,35 +16,73 @@ import os
from dataclasses import dataclass, field
from uuid import uuid4

-from .arg_utils import PluginConfig, get_plugin_config
+from .arg_utils import BatchingStrategy, PluginConfig, get_plugin_config


@dataclass
class TrainingArguments:
    output_dir: str = field(
-       default=os.path.join("outputs", str(uuid4())),
+       default=os.path.join("outputs", str(uuid4().hex)),
        metadata={"help": "Path to the output directory."},
    )
    micro_batch_size: int = field(
        default=1,
        metadata={"help": "Micro batch size for training."},
    )
-   global_batch_size: int = field(
-       default=1,
-       metadata={"help": "Global batch size for training."},
+   global_batch_size: int | None = field(
+       default=None,
+       metadata={"help": "Global batch size for training, default to DP size * micro batch size."},
    )
+   cutoff_len: int = field(
+       default=2048,
+       metadata={"help": "Maximum sequence length for training."},
+   )
    learning_rate: float = field(
        default=1e-4,
        metadata={"help": "Learning rate for training."},
    )
+   num_train_epochs: int = field(
+       default=3,
+       metadata={"help": "Number of training epochs."},
+   )
+   max_steps: int | None = field(
+       default=None,
+       metadata={"help": "Maximum number of training steps. If set, overrides num_train_epochs."},
+   )
+   max_grad_norm: float = field(
+       default=1.0,
+       metadata={"help": "Maximum gradient norm for training."},
+   )
    bf16: bool = field(
        default=False,
        metadata={"help": "Use bf16 for training."},
    )
+   batching_strategy: BatchingStrategy = field(
+       default=BatchingStrategy.NORMAL,
+       metadata={"help": "Batching strategy for training."},
+   )
+   batching_workers: int = field(
+       default=16,
+       metadata={"help": "Number of workers for batching."},
+   )
+   enable_activation_checkpointing: bool = field(
+       default=True,
+       metadata={"help": "Enable activation checkpointing for training."},
+   )
    dist_config: PluginConfig | None = field(
        default=None,
        metadata={"help": "Distribution configuration for training."},
    )
+   optim_config: PluginConfig | None = field(
+       default=None,
+       metadata={"help": "Optimizer configuration for training."},
+   )
+   lr_scheduler_config: PluginConfig | None = field(
+       default=None,
+       metadata={"help": "Learning rate scheduler configuration for training."},
+   )

    def __post_init__(self) -> None:
        self.dist_config = get_plugin_config(self.dist_config)
+       self.optim_config = get_plugin_config(self.optim_config)
+       self.lr_scheduler_config = get_plugin_config(self.lr_scheduler_config)
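
A hedged construction example for the expanded TrainingArguments above; the import path follows the v1 config package exported earlier in this diff, and the values are placeholders rather than a recommended recipe.

from llamafactory.v1.config import BatchingStrategy, TrainingArguments

args = TrainingArguments(
    output_dir="outputs/sft_demo",
    micro_batch_size=2,
    global_batch_size=None,  # falls back to DP size * micro_batch_size
    cutoff_len=2048,
    num_train_epochs=1,
    max_steps=100,  # overrides num_train_epochs when set
    batching_strategy=BatchingStrategy.PADDING_FREE,
)
print(args.output_dir, args.batching_strategy)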

View File

@@ -0,0 +1,67 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from collections.abc import AsyncGenerator
from ..config import ModelArguments, SampleArguments, SampleBackend
from ..utils.types import HFModel, Message, Sample, TorchDataset
from .utils.inference_engine import HuggingFaceEngine
from .utils.rendering import Renderer
class BaseSampler:
    """Base sampler.

    Args:
        args: Sample arguments.
        model_args: Model arguments.
        model: Model.
        renderer: Renderer.
    """

    def __init__(
        self,
        args: SampleArguments,
        model_args: ModelArguments,
        model: HFModel,
        renderer: Renderer,
    ) -> None:
        if args.sample_backend == SampleBackend.HF:
            self.engine = HuggingFaceEngine(args, model_args, model, renderer)
        else:
            raise ValueError(f"Unknown sample backend: {args.sample_backend}")

    async def generate(self, messages: list[Message], tools: str | None = None) -> AsyncGenerator[str, None]:
        """Generate tokens asynchronously.

        Args:
            messages: List of messages.
            tools: Tools string.

        Yields:
            Generated tokens.
        """
        async for token in self.engine.generate(messages, tools):
            yield token

    async def batch_infer(self, dataset: TorchDataset) -> list[Sample]:
        """Batch infer samples.

        Args:
            dataset: Torch dataset.

        Returns:
            List of samples.
        """
        return await self.engine.batch_infer(dataset)
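
A hedged usage sketch for the async generate() stream above; constructing the model, renderer, and argument objects is elided because it depends on the v1 loaders, and the message format (an OpenAI-style role/content dict) is an assumption.

import asyncio


async def chat(sampler, prompt: str) -> str:
    chunks = []
    async for token in sampler.generate([{"role": "user", "content": prompt}]):
        chunks.append(token)
    return "".join(chunks)

# asyncio.run(chat(sampler, "Hello"))  # with `sampler` a constructed BaseSampler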

View File

@@ -16,20 +16,33 @@
Init Phase: Init Phase:
1. Init dataloader. 1. Init batch generator.
2. Init optimizer (deepspeed). 2. Init optimizer (deepspeed).
3. Shard model. 3. Shard model.
4. Init optimizer (fsdp). 4. Init optimizer (fsdp).
5. Init scheduler. 5. Init lr scheduler.
Train Phase: Train Phase:
1. Train Loop 1. Train Loop
""" """
from ..config.training_args import TrainingArguments from abc import abstractmethod
from ..utils.types import HFModel, Processor, TorchDataset
from .trainer_utils.data_collator import DataCollator import torch
import torch.nn.functional as F
from ..accelerator.helper import ReduceOp
from ..accelerator.interface import Dim, DistributedInterface
from ..config import TrainingArguments
from ..utils import logging
from ..utils.helper import compute_valid_tokens
from ..utils.types import BatchInput, HFModel, ModelOutput, Tensor, TorchDataset
from .utils.batching import BatchGenerator
from .utils.rendering import Renderer
logger = logging.get_logger(__name__)
class BaseTrainer: class BaseTrainer:
@@ -37,22 +50,160 @@ class BaseTrainer:
self, self,
args: TrainingArguments, args: TrainingArguments,
model: HFModel, model: HFModel,
processor: Processor, renderer: Renderer,
dataset: TorchDataset, train_dataset: TorchDataset,
) -> None: ) -> None:
self.args = args self.args = args
self.model = model self.model = model
self.processor = processor self.renderer = renderer
self.dataset = dataset self.train_dataset = train_dataset
self.data_collator = DataCollator()
self.optimizer = None
self.lr_scheduler = None
def init_model_and_optimizer(self) -> None: # info
pass self.global_step = 0
def create_dataloader(self) -> None: # cached variables
pass self.device = DistributedInterface().current_device
self.dp_size = DistributedInterface().get_world_size(Dim.DP)
self.model_input_names = self.renderer.processor.model_input_names
self._create_batch_generator()
# Calculate num_training_steps: max_steps takes priority if set
if self.args.max_steps is not None and self.args.max_steps > 0:
self.num_training_steps = self.args.max_steps
else:
self.num_training_steps = self.args.num_train_epochs * len(self.train_batch_generator)
if self.args.enable_activation_checkpointing:
self.model.gradient_checkpointing_enable({"use_reentrant": False})
if self.args.dist_config is not None:
shard_need_optimizer = self.args.dist_config.name == "deepspeed"
else:
shard_need_optimizer = False
if shard_need_optimizer:
self._init_optimizer()
self._shard_model()
else:
self._shard_model()
self._init_optimizer()
self._init_lr_scheduler()
def _create_batch_generator(self) -> None:
self.train_batch_generator = BatchGenerator(
dataset=self.train_dataset,
renderer=self.renderer,
micro_batch_size=self.args.micro_batch_size,
global_batch_size=self.args.global_batch_size,
cutoff_len=self.args.cutoff_len,
batching_workers=self.args.batching_workers,
batching_strategy=self.args.batching_strategy,
)
def _shard_model(self) -> None:
if self.args.dist_config is None:
if DistributedInterface().get_world_size(Dim.DP) > 1:
from torch.nn.parallel import DistributedDataParallel as DDP
logger.warning_rank0(
"dist_config is None but distributed training is enabled; falling back to DistributedDataParallel."
)
device_ids = None if self.device.type == "cpu" else [self.device.index]
self.model = DDP(self.model, device_ids=device_ids)
else:
from ..plugins.trainer_plugins.distributed.hub import DistributedPlugin
self.model = DistributedPlugin(self.args.dist_config.name)(
self.model,
self.args.dist_config,
)
def _init_optimizer(self) -> None:
"""Init optimizer."""
if self.args.optim_config is None:
_trainable_params = [p for p in self.model.parameters() if p.requires_grad]
self.optimizer = torch.optim.AdamW(_trainable_params, lr=self.args.learning_rate)
else:
from ..plugins.trainer_plugins.optimizer import OptimizerPlugin
self.optimizer = OptimizerPlugin(self.args.optim_config.name)(self.model, self.args.optim_config)
def _init_lr_scheduler(self) -> None:
"""Init lr scheduler."""
if self.args.lr_scheduler_config is None:
self.lr_scheduler = torch.optim.lr_scheduler.LambdaLR(self.optimizer, lr_lambda=lambda x: 1.0)
else:
from ..plugins.trainer_plugins.lr_scheduler import LRSchedulerPlugin
self.lr_scheduler = LRSchedulerPlugin(self.args.lr_scheduler_config.name)(
self.optimizer, self.num_training_steps, self.args.lr_scheduler_config
)
def compute_log_probs(self, model: HFModel, batch: BatchInput) -> Tensor:
"""Compute log probs.
log_probs: Tensor of shape (batch_size, seq_len - 1)
"""
batch_size, _ = batch["labels"].shape
model_inputs = {
k: v.to(self.device, non_blocking=True) for k, v in batch.items() if k in self.model_input_names
}
labels = batch["labels"].to(self.device, non_blocking=True)
outputs: ModelOutput = model(**model_inputs)
logits = outputs.logits.float()
shift_labels = labels[..., 1:].contiguous().view(-1)
shift_logits = logits[..., :-1, :].contiguous().view(shift_labels.size(0), -1)
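# shift_logits is flattened to (batch_size * (seq_len - 1), vocab_size); the unreduced
# cross entropy below is negated to turn per-token NLL back into log probabilities.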
return -F.cross_entropy(shift_logits, shift_labels, reduction="none").view(batch_size, -1)
@abstractmethod
def compute_loss(self, batch: BatchInput) -> Tensor:
"""Compute the scalar loss."""
...
def fit(self) -> None: def fit(self) -> None:
pass """Train the model."""
self.model.train()
for epoch in range(self.args.num_train_epochs):
self.train_batch_generator.set_epoch(epoch)
for micro_batches in self.train_batch_generator:
self.global_step += 1
step_loss = 0
step_valid_tokens = compute_valid_tokens(micro_batches)
step_valid_tokens = DistributedInterface().all_reduce(step_valid_tokens, op=ReduceOp.SUM)
for micro_batch in micro_batches:
loss = self.compute_loss(micro_batch)
mini_step_valid_tokens = compute_valid_tokens([micro_batch])
# FSDP averages gradients across DP ranks (mean reduction), so scale each micro-batch loss by dp_size to recover a token-weighted global mean
loss = loss * mini_step_valid_tokens * self.dp_size / (step_valid_tokens + 1e-6)
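# Worked example with illustrative numbers: dp_size=2, one rank holds 30 of 40 valid tokens,
# the other holds 10. The ranks scale their losses by 30*2/40=1.5 and 10*2/40=0.5, so the
# mean all-reduce of gradients yields (30*g0 + 10*g1)/40, i.e. a token-weighted global mean.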
loss.backward()
step_loss += loss.item()
grad_norm = torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.args.max_grad_norm).item()
# torch.isfinite() only accepts a Tensor, so wrap the Python float grad_norm before checking
if not torch.isfinite(torch.tensor(grad_norm)): # type: ignore # pyright: ignore [reportUnknownReturnType]
logger.warning_rank0(f"Gradient norm is not finite: {grad_norm}")
else:
self.optimizer.step()
self.lr_scheduler.step()
self.optimizer.zero_grad()
step_loss, grad_norm = DistributedInterface().all_reduce([step_loss, grad_norm])
DistributedInterface().sync()
if DistributedInterface().get_rank() == 0:
print(f"Epoch {epoch}, Step {self.global_step}, Loss: {step_loss:.4f}, Grad Norm: {grad_norm:.4f}")
# Check if max_steps is reached
if self.global_step >= self.num_training_steps:
logger.info_rank0(f"Reached max_steps ({self.num_training_steps}), stopping training.")
return
def save_model(self) -> None:
"""Save the model."""
model_to_save = self.model.module if hasattr(self.model, "module") else self.model
model_to_save.save_pretrained(self.args.output_dir)
self.renderer.processor.save_pretrained(self.args.output_dir)
logger.info_rank0(f"Model saved to {self.args.output_dir}")

View File

@@ -1,44 +0,0 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from abc import ABC, abstractmethod
from ..config.sample_args import SampleArguments, SampleBackend
from .model_loader import ModelLoader
class BaseEngine(ABC):
@abstractmethod
def __init__(self, sample_args: SampleArguments, model_loader: ModelLoader) -> None: ...
@abstractmethod
async def generate(self):
pass
@abstractmethod
async def batch_infer(self):
pass
class HuggingFaceEngine(BaseEngine):
def __init__(self, model_loader: ModelLoader, sample_args: SampleArguments) -> None:
self.args = sample_args
class ChatSampler:
def __init__(self, model_loader: ModelLoader, sample_args: SampleArguments) -> None:
if sample_args.sample_backend == SampleBackend.HF:
self.engine = HuggingFaceEngine(model_loader, sample_args)
else:
raise ValueError(f"Unknown sample backend: {sample_args.sample_backend}")

View File

@@ -14,15 +14,23 @@
"""The definition of data engine. """The definition of data engine.
Init Data engine: How to use:
data_engine = DataEngine(data_args.train_dataset)
data_engine[i]: Get the sample via index.
Init workflow:
1. Parse dataset info from arguments. 1. Parse dataset info from arguments.
2. Load datasets according to dataset info. 2. Load datasets according to dataset info.
3. Build data index (and reweight samples if necessary). 3. Build data index (and reweight samples if necessary).
Get Data Sample: Get data sample:
1. Get sample from data index. 1. Get sample from data index.
2. Convert sample to standard format. 2. Convert sample to standard format.
3. Return sample. 3. Return sample.
Note:
1. The data engine is equivalent to the torch dataset.
2. The data engine is agnostic to the model used.
""" """
import os import os
@@ -33,7 +41,6 @@ from huggingface_hub import hf_hub_download
from omegaconf import OmegaConf from omegaconf import OmegaConf
from torch.utils.data import Dataset from torch.utils.data import Dataset
from ..config.data_args import DataArguments
from ..utils.types import DatasetInfo, HFDataset, Sample from ..utils.types import DatasetInfo, HFDataset, Sample
@@ -44,9 +51,9 @@ class DataEngine(Dataset):
data_args: Data arguments. data_args: Data arguments.
""" """
def __init__(self, data_args: DataArguments) -> None: def __init__(self, dataset_path: str) -> None:
self.args = data_args self.path = dataset_path
"""Data arguments.""" """Dataset path."""
self.datasets: dict[str, HFDataset] = {} self.datasets: dict[str, HFDataset] = {}
"""Dict of (dataset_name, dataset)""" """Dict of (dataset_name, dataset)"""
self.dataset_infos: dict[str, DatasetInfo] = {} self.dataset_infos: dict[str, DatasetInfo] = {}
@@ -61,27 +68,30 @@ class DataEngine(Dataset):
def _get_dataset_info(self) -> None: def _get_dataset_info(self) -> None:
"""Get dataset info from data arguments.""" """Get dataset info from data arguments."""
if self.args.dataset.endswith(".yaml") and os.path.isfile(self.args.dataset): # local file if self.path.endswith(".yaml") and os.path.isfile(self.path): # local file
self.dataset_infos = OmegaConf.load(self.args.dataset) self.dataset_infos = OmegaConf.load(self.path)
elif self.args.dataset.endswith(".yaml"): # hf hub uri, e.g. llamafactory/v1-sft-demo/dataset_info.yaml elif self.path.endswith(".yaml"): # hf hub uri, e.g. llamafactory/v1-sft-demo/dataset_info.yaml
repo_id, filename = os.path.split(self.args.dataset) repo_id, filename = os.path.split(self.path)
filepath = hf_hub_download(repo_id=repo_id, filename=filename, repo_type="dataset") filepath = hf_hub_download(repo_id=repo_id, filename=filename, repo_type="dataset")
self.dataset_infos = OmegaConf.load(filepath) self.dataset_infos = OmegaConf.load(filepath)
elif os.path.exists(self.args.dataset): # local file(s) elif os.path.exists(self.path): # local file(s)
self.dataset_infos = {"default": {"path": self.args.dataset, "source": "local"}} self.dataset_infos = {"default": {"path": self.path, "source": "local"}}
else: # hf hub dataset, e.g. llamafactory/v1-sft-demo else: # hf hub dataset, e.g. llamafactory/v1-sft-demo
self.dataset_infos = {"default": {"path": self.args.dataset}} self.dataset_infos = {"default": {"path": self.path}}
def _load_dataset(self) -> None: def _load_dataset(self) -> None:
"""Load datasets according to dataset info.""" """Load datasets according to dataset info."""
is_streaming = [dataset_info.get("streaming", False) for dataset_info in self.dataset_infos.values()]
self.streaming = any(is_streaming)
if all(is_streaming) != any(is_streaming):
raise ValueError("All datasets must be streaming or non-streaming.")
for dataset_name, dataset_info in self.dataset_infos.items(): for dataset_name, dataset_info in self.dataset_infos.items():
split = dataset_info.get("split", "train") split = dataset_info.get("split", "train")
streaming = dataset_info.get("streaming", False)
self.streaming |= streaming
if dataset_info.get("source", "hf_hub") == "hf_hub": if dataset_info.get("source", "hf_hub") == "hf_hub":
from datasets import load_dataset from datasets import load_dataset
self.datasets[dataset_name] = load_dataset(dataset_info["path"], split=split, streaming=streaming) self.datasets[dataset_name] = load_dataset(dataset_info["path"], split=split, streaming=self.streaming)
else: # data loader plugin else: # data loader plugin
from ..plugins.data_plugins.loader import DataLoaderPlugin from ..plugins.data_plugins.loader import DataLoaderPlugin
@@ -90,18 +100,17 @@ class DataEngine(Dataset):
def _build_data_index(self) -> None: def _build_data_index(self) -> None:
"""Build dataset index.""" """Build dataset index."""
for dataset_name, dataset in self.datasets.items(): for dataset_name, dataset in self.datasets.items():
streaming = self.dataset_infos[dataset_name].get("streaming", False) if self.streaming:
if streaming:
data_index = [(dataset_name, -1) for _ in range(1000)] data_index = [(dataset_name, -1) for _ in range(1000)]
else: else:
data_index = [(dataset_name, sample_index) for sample_index in range(len(dataset))] data_index = [(dataset_name, sample_index) for sample_index in range(len(dataset))]
size = self.dataset_infos[dataset_name].get("size") size = self.dataset_infos[dataset_name].get("size")
weight = self.dataset_infos[dataset_name].get("weight") weight = self.dataset_infos[dataset_name].get("weight")
if size or weight: # data index plugin if size or weight:
from ..plugins.data_plugins.loader import DataIndexPlugin from ..plugins.data_plugins.loader import adjust_data_index
data_index = DataIndexPlugin().adjust_data_index(data_index, size, weight) data_index = adjust_data_index(data_index, size, weight)
self.data_index.extend(data_index) self.data_index.extend(data_index)
@@ -150,9 +159,9 @@ class DataEngine(Dataset):
dataset_name, sample_index = self.data_index[index] dataset_name, sample_index = self.data_index[index]
return self._convert_data_sample(self.datasets[dataset_name][sample_index], dataset_name) return self._convert_data_sample(self.datasets[dataset_name][sample_index], dataset_name)
else: # data selector plugin else: # data selector plugin
from ..plugins.data_plugins.loader import DataSelectorPlugin from ..plugins.data_plugins.loader import select_data_sample
selected_index = DataSelectorPlugin().select(self.data_index, index) selected_index = select_data_sample(self.data_index, index)
if isinstance(selected_index, list): if isinstance(selected_index, list):
return [ return [
self._convert_data_sample(self.datasets[dataset_name][sample_index], dataset_name) self._convert_data_sample(self.datasets[dataset_name][sample_index], dataset_name)
@@ -177,11 +186,11 @@ class DataEngine(Dataset):
if __name__ == "__main__": if __name__ == "__main__":
""" """
python -m llamafactory.v1.core.data_engine --model none --dataset data/v1_sft_demo.yaml python -m llamafactory.v1.core.data_engine --train_dataset data/v1_sft_demo.yaml
python -m llamafactory.v1.core.data_engine --model none --dataset data/v1_dpo_demo.yaml python -m llamafactory.v1.core.data_engine --train_dataset data/v1_dpo_demo.yaml
""" """
from ..config.arg_parser import get_args from ..config.arg_parser import get_args
data_args, *_ = get_args() _, data_args, *_ = get_args()
data_engine = DataEngine(data_args=data_args) data_engine = DataEngine(data_args.train_dataset)
print(data_engine[0]) print(data_engine[0])

View File

@@ -12,34 +12,44 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
"""The definition of model loader. """The definition of model engine.
Init Phase: How to use:
model_engine = ModelEngine(model_args, is_train=True)
model_engine.processor: Get the tokenizer or multi-modal processor.
model_engine.renderer: Get the renderer.
model_engine.model_config: Get the model configuration.
model_engine.model: Get the HF model.
Init workflow:
1. Init processor. 1. Init processor.
2. Init renderer.
2. Init model config. 3. Init model config.
3. Init model. 4. Init model.
4. Init adapter. 5. Init adapter.
""" """
import torch import torch
from accelerate import init_empty_weights
from transformers import AutoConfig, AutoProcessor from transformers import AutoConfig, AutoProcessor
from ..accelerator.helper import DeviceType
from ..accelerator.interface import DistributedInterface from ..accelerator.interface import DistributedInterface
from ..config.model_args import ModelArguments, ModelClass from ..config.model_args import ModelArguments, ModelClass
from ..utils import logging from ..utils import logging
from ..utils.types import HFConfig, HFModel, Processor from ..utils.types import HFConfig, HFModel, Processor
from .utils.rendering import Renderer
logger = logging.get_logger(__name__) logger = logging.get_logger(__name__)
class ModelLoader: class ModelEngine:
"""Model loader. """Model engine.
Args: Args:
model_args: Model arguments. model_args: Model arguments.
is_trainable: Whether to train the model. is_train: Whether to train the model.
""" """
def __init__(self, model_args: ModelArguments, is_train: bool = False) -> None: def __init__(self, model_args: ModelArguments, is_train: bool = False) -> None:
@@ -49,17 +59,22 @@ class ModelLoader:
"""Whether to train the model.""" """Whether to train the model."""
self.processor = self._init_processor() self.processor = self._init_processor()
"""Tokenizer or multi-modal processor.""" """Tokenizer or multi-modal processor."""
self.renderer = Renderer(self.args.template, self.processor)
"""Renderer."""
self.model_config = self._init_model_config() self.model_config = self._init_model_config()
"""Model configuration.""" """Model configuration."""
self.model = self._init_model() self.model = self._init_model()
"""HF model.""" """HF model."""
def _init_processor(self) -> Processor: def _init_processor(self) -> Processor:
"""Init processor.""" """Init processor.
NOTE: Transformers v5 always uses a fast tokenizer.
https://github.com/huggingface/transformers/blob/v5.0.0rc1/src/transformers/models/auto/tokenization_auto.py#L642
"""
return AutoProcessor.from_pretrained( return AutoProcessor.from_pretrained(
self.args.model, self.args.model,
trust_remote_code=self.args.trust_remote_code, trust_remote_code=self.args.trust_remote_code,
use_fast=self.args.use_fast_processor,
) )
def _init_model_config(self) -> HFConfig: def _init_model_config(self) -> HFConfig:
@@ -72,7 +87,7 @@ class ModelLoader:
def _init_model(self) -> HFModel: def _init_model(self) -> HFModel:
"""Init model. """Init model.
Let transformers handle the model init context. Transformers can choose the proper model init context.
https://github.com/huggingface/transformers/blob/v5.0.0rc0/src/transformers/modeling_utils.py#L3538 https://github.com/huggingface/transformers/blob/v5.0.0rc0/src/transformers/modeling_utils.py#L3538
""" """
if self.args.model_class == ModelClass.LLM: if self.args.model_class == ModelClass.LLM:
@@ -92,14 +107,24 @@ class ModelLoader:
AutoClass = AutoModel AutoClass = AutoModel
# map the entire model to the current accelerator if self.args.init_config is not None:
model = AutoClass.from_pretrained( from ..plugins.model_plugins.initialization import InitPlugin
self.args.model,
config=self.model_config, init_device = InitPlugin(self.args.init_config.name)()
dtype="auto", else:
device_map=DistributedInterface().current_accelerator, init_device = DistributedInterface().current_device
trust_remote_code=self.args.trust_remote_code,
) if init_device.type == DeviceType.META:
with init_empty_weights():
model = AutoClass.from_config(self.model_config)
else:
model = AutoClass.from_pretrained(
self.args.model,
config=self.model_config,
dtype="auto",
device_map=init_device,
trust_remote_code=self.args.trust_remote_code,
)
if self.args.peft_config is None: if self.args.peft_config is None:
if self.is_train: if self.is_train:
@@ -116,7 +141,7 @@ class ModelLoader:
from ..plugins.model_plugins.kernels.interface import KernelPlugin from ..plugins.model_plugins.kernels.interface import KernelPlugin
model = KernelPlugin(self.args.kernel_config.name)( model = KernelPlugin(self.args.kernel_config.name)(
model=model, include_kernels=self.args.kernel_config.get("include_kernels") model, include_kernels=self.args.kernel_config.get("include_kernels")
) )
return model return model
@@ -124,12 +149,12 @@ class ModelLoader:
if __name__ == "__main__": if __name__ == "__main__":
""" """
python -m llamafactory.v1.core.model_loader --model llamafactory/tiny-random-qwen2.5 python -m llamafactory.v1.core.model_engine --model llamafactory/tiny-random-qwen2.5
""" """
from ..config.arg_parser import get_args from ..config.arg_parser import get_args
_, model_args, *_ = get_args() model_args, *_ = get_args()
model_loader = ModelLoader(model_args=model_args) model_engine = ModelEngine(model_args=model_args)
print(model_loader.processor) print(model_engine.processor)
print(model_loader.model_config) print(model_engine.model_config)
print(model_loader.model) print(model_engine.model)

View File

@@ -1,119 +0,0 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from collections import defaultdict
from collections.abc import Sequence
from dataclasses import dataclass
from typing import Any
import torch
import torch.nn.functional as F
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data._utils.collate import default_collate
from ....extras.constants import IGNORE_INDEX
from ...plugins.data_plugins.template import Template
from ...utils.types import Processor, Tensor
def len2culen(seqlens: "torch.Tensor") -> "torch.Tensor": # FIXME move to utils
"""Convert sequence lengths to cumulative sequence lengths."""
return F.pad(torch.cumsum(seqlens, dim=0), (1, 0)).type(torch.int32)
class DataCollator:
"""Default Data collator."""
processor: "Processor" # processor name -> map to encode_messages function
def __post_init__(self):
# callback for text tokenizer
self.tokenizer = self.processor.tokenizer if hasattr(self.processor, "tokenizer") else self.processor
def __call__(self, features: list[dict[str, Any]]) -> dict[str, Tensor]:
"""Collate features into a batch."""
batch = defaultdict(list)
# batching features
for feature in features:
for key in feature.keys():
batch[key].append(feature[key])
for key in batch.keys():
# process padding features
if key in ["input_ids", "attention_mask", "position_ids"]:
padding_value = self.tokenizer.pad_token_id if key == "input_ids" else 0
batch[key] = pad_sequence(batch[key], batch_first=True, padding_value=padding_value)
elif key in ["labels"]:
batch[key] = pad_sequence(batch[key], batch_first=True, padding_value=IGNORE_INDEX)
else:
batch[key] = default_collate(batch[key])
return batch
# sft: messages
# dpo: chosen_messages, rejected_messages
@dataclass
class DefaultCollator(DataCollator):
"""Example for now."""
processor: "Processor" # processor name -> map to encode_messages function
template: "Template"
def __call__(self, messages: list[list[dict[str, Any]]]) -> dict[str, Tensor]:
features = []
# Check if data is already tokenized (contains input_ids)
if messages and isinstance(messages[0], dict) and "input_ids" in messages[0]:
for feature in messages:
if not isinstance(feature, dict):
raise ValueError(f"Expected dict but got {type(feature)}")
tensor_feature = {
k: torch.tensor(v, dtype=torch.long) if not isinstance(v, torch.Tensor) else v
for k, v in feature.items()
}
features.append(tensor_feature)
else:
# raw messages need to be encoded
for message in messages:
encoded_message = self.template.encode_messages(self.tokenizer, message)
encoded_message = {k: torch.tensor(v, dtype=torch.long) for k, v in encoded_message.items()}
features.append(encoded_message)
return super().__call__(features)
@dataclass
class PairwiseCollator(DataCollator):
pass
@dataclass
class DataCollatorWithPacking(DefaultCollator):
"""Data collator with packing."""
processor: "Processor"
template: "Template"
def __call__(self, features: Sequence[dict[str, "torch.Tensor"]]) -> dict[str, "torch.Tensor"]:
seqlens = torch.tensor([len(feature["input_ids"]) for feature in features], dtype=torch.long)
batch = {"cu_seqlens": len2culen(seqlens)}
for input_name in features[0].keys():
if input_name in ("input_ids", "attention_mask", "labels"):
batch[input_name] = torch.cat([feature[input_name] for feature in features])
else:
batch[input_name] = default_collate([feature[input_name] for feature in features])
return batch

View File

@@ -1,277 +0,0 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
import sys
from collections.abc import Generator, Iterator
from dataclasses import dataclass
from typing import Optional
from torchdata.stateful_dataloader import StatefulDataLoader
from torchdata.stateful_dataloader.sampler import StatefulDistributedSampler
from ...utils.batching_queue import BaseBatchingQueue
from ...utils.logging import get_logger
from ...utils.types import Processor, TorchDataset
from .data_collator import DataCollator
logger = get_logger(__name__)
# base dataloader
class DistributedDataloader(StatefulDataLoader):
"""Base Distributed DataLoader."""
dataset: "TorchDataset"
sampler: "StatefulDistributedSampler"
def set_epoch(self, epoch: int) -> None:
if self.sampler is not None and hasattr(self.sampler, "set_epoch"):
self.sampler.set_epoch(epoch)
elif hasattr(self.dataset, "set_epoch"):
self.dataset.set_epoch(epoch)
@dataclass
class BaseDataLoader:
"""Default DataLoader."""
processor: Processor
def __init__(self, dataset: TorchDataset) -> None:
self.dataset = dataset
# guidelines: fetch until a fixed batch size is reached.
# save state_dict for buffer.
# resume with state
# 1. Init stateful dataloader (tokenize)
# 2. Add to buffer (2 * max seq len per device)
# 3. Yield batch indexes (micro batch * grad acc)
# a ) non pack + non dynamic
# b ) non pack + dynamic
# c ) pack + non dynamic
# d ) pack + dynamic
def init_dataloader(self) -> None:
### init dataloader
pass
def __iter__(self) -> Iterator:
pass
def __next__(self) -> any:
pass
@dataclass
class DataLoader:
"""Default DataLoader."""
processor: "Processor"
dataloader: "DistributedDataloader"
batching_queue: "BaseBatchingQueue"
collate_fn: "DataCollator"
num_micro_batch: int = 1
length: int = 0
drop_last: bool = True
def __init__(
self,
dataloader: any,
collate_fn: "DataCollator",
num_micro_batch: int = 1,
length: int = 0,
drop_last: bool = True,
batching_queue: Optional["BaseBatchingQueue"] = None,
) -> None:
self.batching_queue = batching_queue
self.num_micro_batch = num_micro_batch
self.step = 0
self._collate_fn = collate_fn
self._dataloader = dataloader
self._drop_last = drop_last
self._data_iter: Iterator
self._resume = False
self._batch_data_iter: Generator
if length > 0:
self._length = length
elif length == -1:
self._length = sys.maxsize
else:
self._length = len(self._dataloader)
def __len__(self):
return self._length
def __iter__(self) -> Iterator:
if not self._resume:
self.step = 0
self._data_iter = iter(self._dataloader)
self._batch_data_iter = self.batch_data_generator()
self._resume = False
return self
def __next__(self):
return next(self._batch_data_iter) # FIXME maybe we can move origin_batch_data_generator to here
def origin_batch_data_generator(self):
"""Standard pass-through generator if do not use batching queue."""
while True:
if self._length > 0 and self.step >= self._length:
return
try:
batch = []
data = next(self._data_iter)
# split data into micro batches
for i in range(0, len(data), self.num_micro_batch):
micro_batch = data[i : i + self.num_micro_batch]
if self._collate_fn:
micro_batch = self._collate_fn(micro_batch)
batch.append(micro_batch)
yield batch
self.step += 1
except StopIteration:
if self.step < self._length:
# Restart iterator to fill the requested length
self._data_iter = iter(self._dataloader)
try:
batch = []
data = next(self._data_iter)
for i in range(0, len(data), self.num_micro_batch):
micro_batch = data[i : i + self.num_micro_batch]
if self._collate_fn:
micro_batch = self._collate_fn(micro_batch)
batch.append(micro_batch)
yield batch
self.step += 1
except StopIteration:
return
else:
return
except Exception as e:
logger.error(f"DataLoader origin_batch_data_generator exception: {e}")
raise
def batch_data_generator(self):
if self.batching_queue is None:
yield from self.origin_batch_data_generator()
return
batch = []
while True:
if self._length and self.step >= self._length:
return
if self.batching_queue.is_full_filled():
micro_batch = self.batching_queue.get_micro_batch(self.step)
if self._collate_fn:
micro_batch = self._collate_fn(micro_batch)
batch.append(micro_batch)
if len(batch) == self.num_micro_batch:
yield batch
self.step += 1
batch = []
try:
processing_item = next(self._data_iter)
except Exception as e:
if isinstance(e, StopIteration):
if self.step < self._length:
# call iter until reach length
self._data_iter = iter(self._dataloader)
processing_item = next(self._data_iter)
elif not self._drop_last and not self.batching_queue.empty():
while not self.batching_queue.empty():
micro_batch = self.batching_queue.get_micro_batch(self.step)
if self._collate_fn:
micro_batch = self._collate_fn(micro_batch)
batch.append(micro_batch)
if len(batch) == self.num_micro_batch:
yield batch
self.step += 1
batch = []
while len(batch) < self.num_micro_batch:
padding_batch = copy.deepcopy(micro_batch)
padding_batch["is_padded"] = True
batch.append(padding_batch)
yield batch
self.step += 1
return
else:
return
else:
logger.error(f"DataLoader iter data exception: {e}")
raise
# put processing_item to buffer
if isinstance(processing_item, dict):
processing_item = [processing_item]
for item in processing_item:
self.batching_queue.put_item(item)
def state_dict(self):
# save state
state = self.__dict__.copy()
# remove internal fields
for k in list(state.keys()):
if k.startswith("_"):
del state[k]
# save dataloader state
if hasattr(self._dataloader, "state_dict"):
state["dataloader_state"] = self._dataloader.state_dict()
elif hasattr(self._dataloader, "__getstate__"):
state["dataloader_state"] = self._dataloader.__getstate__()
batching_strategy = getattr(self, "batching_strategy", None)
if batching_strategy and hasattr(batching_strategy, "state_dict"):
state["batching_strategy_state"] = batching_strategy.state_dict()
if "batching_strategy" in state:
del state["batching_strategy"]
return copy.deepcopy(state)
def load_state_dict(self, state: dict[str, any]):
if state["num_micro_batch"] != self.num_micro_batch:
logger.warning(
f"num_micro_batch changed: [ {state['num_micro_batch']} -> {self.num_micro_batch} ], will clear prefetch buffer"
)
del state["num_micro_batch"]
self.__dict__.update(state)
self._resume = True
if hasattr(self._dataloader, "load_state_dict"):
self._dataloader.load_state_dict(state["dataloader_state"])
elif hasattr(self._dataloader, "__getstate__"):
self._dataloader.__setstate__(state["dataloader_state"])
if "batching_strategy_state" in state:
batching_strategy = getattr(self, "batching_strategy", None)
if batching_strategy:
batching_strategy.load_state_dict(state["batching_strategy_state"])
del state["batching_strategy_state"]
self._data_iter = iter(self._dataloader)
self._batch_data_iter = self.batch_data_generator()
def set_epoch(self, epoch: int) -> None:
if hasattr(self._dataloader, "set_epoch"):
self._dataloader.set_epoch(epoch)

View File

@@ -0,0 +1,244 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Batching utils supports stateful dataloader.
1. Init stateful dataloader (tokenize)
2. Add to buffer
3. Yield batch indexes (micro batch * grad acc)
a) non pack + non dynamic
b) non pack + dynamic
c) pack + non dynamic
d) pack + dynamic
"""
from collections.abc import Iterator
from typing import Any
from torch.utils.data import default_collate
from torchdata.stateful_dataloader import StatefulDataLoader
from torchdata.stateful_dataloader.sampler import StatefulDistributedSampler
from ...accelerator.interface import Dim, DistributedInterface
from ...config import BatchingStrategy
from ...utils import logging
from ...utils.helper import pad_and_truncate
from ...utils.objects import StatefulBuffer
from ...utils.types import BatchInfo, BatchInput, ModelInput, TorchDataset
from .rendering import Renderer
logger = logging.get_logger(__name__)
def default_collate_fn(buffer: StatefulBuffer, batch_info: BatchInfo) -> list[BatchInput] | None:
micro_batch_size = batch_info["micro_batch_size"]
num_micro_batch = batch_info["num_micro_batch"]
cutoff_len = batch_info["cutoff_len"]
batch_size = micro_batch_size * num_micro_batch
if len(buffer) < batch_size:
return None
samples = buffer.get(batch_size)
batch = []
for i in range(num_micro_batch):
micro_batch = samples[i * micro_batch_size : (i + 1) * micro_batch_size]
batch.append(default_collate(pad_and_truncate(micro_batch, cutoff_len)))
return batch
class BatchGenerator(Iterator):
def __init__(
self,
dataset: TorchDataset,
renderer: Renderer,
micro_batch_size: int = 1,
global_batch_size: int | None = None,
cutoff_len: int = 2048,
batching_workers: int = 0,
batching_strategy: BatchingStrategy = BatchingStrategy.NORMAL,
pin_memory: bool = True,
drop_last: bool = True,
) -> None:
self.dataset = dataset
self.renderer = renderer
self.micro_batch_size = micro_batch_size
self.global_batch_size = global_batch_size
self.cutoff_len = cutoff_len
self.batching_workers = batching_workers
self.batching_strategy = batching_strategy
self.pin_memory = pin_memory
self.drop_last = drop_last
# TODO: support length and infinity
dp_size = DistributedInterface().get_world_size(Dim.DP)
if self.global_batch_size is None:
self.global_batch_size = dp_size * micro_batch_size
self.num_micro_batch = 1
elif self.global_batch_size % (dp_size * micro_batch_size) == 0:
self.num_micro_batch = global_batch_size // dp_size // micro_batch_size
else:
raise ValueError(
"Global batch size must be divisible by DP size and micro batch size. "
f"Got {global_batch_size} % ({dp_size} * {micro_batch_size}) != 0."
)
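# e.g. global_batch_size=16, dp_size=2, micro_batch_size=2 -> num_micro_batch = 16 // 2 // 2 = 4,
# i.e. four gradient-accumulation micro-steps per optimizer step on each DP rank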
if not self.drop_last:
raise ValueError("Drop last must be True.")
self._init_data_provider()
self._is_resuming: bool = False
self._data_iter = iter(self._data_provider)
self._buffer = StatefulBuffer()
self._batch_info: BatchInfo = {
"micro_batch_size": self.micro_batch_size,
"num_micro_batch": self.num_micro_batch,
"cutoff_len": self.cutoff_len,
"data_iter": self._data_iter,
}
logger.info_rank0(
f"Init unified data loader with global batch size {self.global_batch_size}, "
f"micro batch size {self.micro_batch_size}, "
f"num micro batch {self.num_micro_batch}, "
f"cutoff len {self.cutoff_len}, "
f"batching workers {self.batching_workers}, "
f"batching strategy {self.batching_strategy}."
)
def _init_data_provider(self) -> None:
if len(self.dataset) != -1:
sampler = StatefulDistributedSampler(
self.dataset,
num_replicas=DistributedInterface().get_world_size(Dim.DP),
rank=DistributedInterface().get_rank(Dim.DP),
shuffle=True,
seed=0,
drop_last=self.drop_last,
)
else:
raise NotImplementedError("Iterable dataset is not supported yet.")
self._data_provider = StatefulDataLoader(
self.dataset,
batch_size=self.micro_batch_size * self.num_micro_batch,
sampler=sampler,
num_workers=self.batching_workers,
collate_fn=self.renderer.process_samples,
pin_memory=self.pin_memory,
pin_memory_device=DistributedInterface().current_device.type,
drop_last=self.drop_last,
)
if self.batching_strategy == BatchingStrategy.NORMAL:
self._length = len(self._data_provider)
else:
from ...plugins.trainer_plugins.batching import BatchingPlugin
self._length = BatchingPlugin(self.batching_strategy).compute_length(self._data_provider)
raise NotImplementedError("Batching strategy other than NORMAL is not supported yet.")
def __len__(self) -> int:
return self._length
def __iter__(self):
if not self._is_resuming:
self._buffer.clear()
self._buffer_tokens = 0
self._data_iter = iter(self._data_provider)
self._is_resuming = False
return self
def __next__(self):
self._fill_buffer()
batch = self._generate_batch()
if batch is None:
raise StopIteration
return batch
def _fill_buffer(self) -> None:
if self.batching_strategy == BatchingStrategy.NORMAL:
while len(self._buffer) < self.micro_batch_size * self.num_micro_batch:
try:
samples: list[ModelInput] = next(self._data_iter)
except StopIteration:
break
self._buffer.put(samples)
else:
from ...plugins.trainer_plugins.batching import BatchingPlugin
BatchingPlugin(self.batching_strategy).fill_buffer(self._buffer, self._batch_info)
def _generate_batch(self) -> list[BatchInput] | None:
if self.batching_strategy == BatchingStrategy.NORMAL:
return default_collate_fn(self._buffer, self._batch_info)
else:
from ...plugins.trainer_plugins.batching import BatchingPlugin
return BatchingPlugin(self.batching_strategy).generate_batch(self._buffer, self._batch_info)
def state_dict(self) -> dict[str, Any]:
return {
"buffer": self._buffer,
"buffer_tokens": self._buffer_tokens,
"data_provider": self._data_provider.state_dict(),
}
def load_state_dict(self, state: dict[str, Any]) -> None:
self._buffer = state["buffer"]
self._buffer_tokens = state["buffer_tokens"]
self._data_provider.load_state_dict(state["data_provider"])
self._is_resuming = True
def set_epoch(self, epoch: int) -> None:
if hasattr(self._data_provider.sampler, "set_epoch"):
self._data_provider.sampler.set_epoch(epoch)
if __name__ == "__main__":
"""
python -m llamafactory.v1.core.utils.batching \
--model llamafactory/tiny-random-qwen2.5 \
--train_dataset data/v1_sft_demo.yaml \
--micro_batch_size 2 \
--global_batch_size 4 \
--batching_workers 0
"""
from ...config.arg_parser import get_args
from ..data_engine import DataEngine
from ..model_engine import ModelEngine
model_args, data_args, training_args, _ = get_args()
data_engine = DataEngine(data_args.train_dataset)
model_engine = ModelEngine(model_args=model_args)
batch_generator = BatchGenerator(
data_engine,
model_engine.renderer,
micro_batch_size=training_args.micro_batch_size,
global_batch_size=training_args.global_batch_size,
cutoff_len=training_args.cutoff_len,
batching_workers=training_args.batching_workers,
batching_strategy=training_args.batching_strategy,
)
for batch in batch_generator:
print(batch)
print(len(batch))
print(batch[0]["input_ids"].shape)
break

View File

@@ -0,0 +1,121 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
import os
from abc import ABC, abstractmethod
from collections.abc import AsyncGenerator
from threading import Thread
import torch
from transformers import AsyncTextIteratorStreamer
from ...accelerator.interface import DistributedInterface
from ...config import ModelArguments, SampleArguments
from ...utils.helper import get_tokenizer
from ...utils.types import HFModel, Message, Sample, TorchDataset
from .rendering import Renderer
class BaseEngine(ABC):
@abstractmethod
def __init__(
self,
args: SampleArguments,
model_args: ModelArguments,
model: HFModel,
renderer: Renderer,
) -> None:
"""Initialize the engine.
Args:
args: Sample arguments.
model_args: Model arguments.
model: Model.
renderer: Renderer.
"""
...
@abstractmethod
async def generate(self, messages: list[Message], tools: str | None = None) -> AsyncGenerator[str, None]:
"""Generate tokens asynchronously.
Args:
messages: List of messages.
tools: Tools string.
Yields:
Generated tokens.
"""
...
@abstractmethod
async def batch_infer(self, dataset: TorchDataset) -> list[Sample]:
"""Batch infer samples.
Args:
dataset: Torch dataset.
Returns:
List of samples.
"""
...
class HuggingFaceEngine(BaseEngine):
def __init__(
self,
args: SampleArguments,
model_args: ModelArguments,
model: HFModel,
renderer: Renderer,
) -> None:
self.args = args
self.model_args = model_args
self.model = model
self.renderer = renderer
self.semaphore = asyncio.Semaphore(int(os.getenv("MAX_CONCURRENT", "1")))
@torch.inference_mode()
async def generate(self, messages: list[Message], tools: str | None = None) -> AsyncGenerator[str, None]:
async with self.semaphore:
model_inputs = self.renderer.render_messages(messages, tools, is_generate=True)
streamer = AsyncTextIteratorStreamer(
tokenizer=get_tokenizer(self.renderer.processor),
skip_prompt=True,
skip_special_tokens=True, # TODO: configurable
)
device = DistributedInterface().current_device
kwargs = {
"input_ids": torch.tensor([model_inputs["input_ids"]]).to(device),
"attention_mask": torch.tensor([model_inputs["attention_mask"]]).to(device),
"max_new_tokens": self.args.max_new_tokens,
"streamer": streamer,
}
thread = Thread(target=self.model.generate, kwargs=kwargs, daemon=True)
thread.start()
async for token in streamer:
yield token
async def batch_infer(self, dataset: TorchDataset) -> list[Sample]:
"""Batch infer samples.
Args:
dataset: Torch dataset.
Returns:
List of samples.
"""
raise NotImplementedError("Batch infer is not implemented.")

View File

@@ -0,0 +1,169 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Rendering utils.
How to use:
renderer = Renderer(template, processor)
renderer.render_messages(messages: list[Message], tools: str | None) -> ModelInput
renderer.parse_message(text: str) -> Message
renderer.process_samples(samples: list[Sample]) -> list[ModelInput]
"""
import numpy as np
from ...utils.constants import IGNORE_INDEX
from ...utils.helper import get_tokenizer
from ...utils.types import Message, ModelInput, Processor, Sample
def render_chatml_messages(
processor: Processor,
messages: list[Message],
tools: str | None = None,
is_generate: bool = False,
) -> ModelInput:
"""Apply chatml template to messages and convert them to model input.
See https://huggingface.co/spaces/huggingfacejs/chat-template-playground?modelId=Qwen/Qwen2-7B-Instruct
"""
tokenizer = get_tokenizer(processor)
input_ids, labels, loss_weights = [], [], []
for message in messages:
temp_str = "<|im_start|>" + message["role"] + "\n"
for content in message["content"]:
if content["type"] == "text":
temp_str += content["value"]
else:
raise ValueError(f"Unsupported content type: {content['type']}")
temp_str += "<|im_end|>\n"
temp_weight = message.get("loss_weight", 1.0 if message["role"] == "assistant" else 0.0)
temp_ids = tokenizer.encode(temp_str, add_special_tokens=False)
input_ids.extend(temp_ids)
loss_weights.extend([temp_weight] * len(temp_ids))
if temp_weight > 1e-6:
labels.extend(temp_ids)
else:
labels.extend([IGNORE_INDEX] * len(temp_ids))
if is_generate:
temp_ids = tokenizer.encode("<|im_start|>assistant\n", add_special_tokens=False)
input_ids.extend(temp_ids)
loss_weights.extend([0.0] * len(temp_ids))
labels.extend([IGNORE_INDEX] * len(temp_ids))
return ModelInput(
input_ids=input_ids,
attention_mask=[1] * len(input_ids),
labels=labels,
loss_weights=loss_weights,
)
def parse_chatml_message(generated_text: str) -> Message:
"""Parse a message in ChatML format.
Args:
generated_text (str): The generated text in ChatML format.
Returns:
Message: The parsed message.
"""
return Message(role="assistant", content=[{"type": "text", "value": generated_text}])
class Renderer:
def __init__(self, template: str, processor: Processor):
self.template = template
self.processor = processor
def render_messages(
self, messages: list[Message], tools: str | None = None, is_generate: bool = False
) -> ModelInput:
"""Apply template to messages and convert them to model input.
Args:
messages (list[Message]): The messages to render.
tools (str | None, optional): The tools to use. Defaults to None.
is_generate (bool, optional): Whether to render for generation. Defaults to False.
Returns:
ModelInput: The rendered model input.
"""
if self.template == "chatml":
return render_chatml_messages(self.processor, messages, tools, is_generate)
else:
from ...plugins.model_plugins.rendering import RenderingPlugin
return RenderingPlugin(self.template).render_messages(self.processor, messages, tools, is_generate)
def parse_message(self, generated_text: str) -> Message:
"""Parse a message in the template format.
Args:
generated_text (str): The generated text in the template format.
Returns:
Message: The parsed message.
"""
if self.template == "chatml":
return parse_chatml_message(generated_text)
else:
from ...plugins.model_plugins.rendering import RenderingPlugin
return RenderingPlugin(self.template).parse_message(generated_text)
def process_samples(self, samples: list[Sample]) -> list[ModelInput]:
"""Process samples to model input.
Args:
samples (list[Sample]): The samples to process.
Returns:
list[ModelInput]: The processed model inputs.
"""
model_inputs = []
for sample in samples:
if "messages" in sample:
model_input = self.render_messages(sample["messages"], sample.get("tools"))
elif "chosen_messages" in sample and "rejected_messages" in sample:
chosen_input = self.render_messages(sample["chosen_messages"], sample.get("tools"))
rejected_input = self.render_messages(sample["rejected_messages"], sample.get("tools"))
chosen_input["token_type_ids"] = [1] * len(chosen_input["input_ids"])
rejected_input["token_type_ids"] = [2] * len(rejected_input["input_ids"])
model_input = ModelInput(
input_ids=chosen_input["input_ids"] + rejected_input["input_ids"],
attention_mask=chosen_input["attention_mask"] + rejected_input["attention_mask"],
labels=chosen_input["labels"] + rejected_input["labels"],
loss_weights=chosen_input["loss_weights"] + rejected_input["loss_weights"],
token_type_ids=chosen_input["token_type_ids"] + rejected_input["token_type_ids"],
)
if "position_ids" in chosen_input:
model_input["position_ids"] = np.concatenate(
[chosen_input["position_ids"], rejected_input["position_ids"]], axis=-1
)
else:
raise ValueError("No valid messages or chosen_messages/rejected_messages found in sample.")
if "extra_info" in sample:
model_input["extra_info"] = sample["extra_info"]
if "_dataset_name" in sample:
model_input["_dataset_name"] = sample["_dataset_name"]
model_inputs.append(model_input)
return model_inputs

View File

@@ -12,9 +12,10 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import os
import subprocess
import sys import sys
from copy import deepcopy
from ..extras.env import VERSION, print_env
USAGE = ( USAGE = (
@@ -27,40 +28,152 @@ USAGE = (
+ "-" * 70 + "-" * 70
) )
_DIST_TRAIN_COMMANDS = ("train", "sft", "dpo", "rm")
WELCOME = (
"-" * 58
+ "\n"
+ f"| Welcome to LLaMA Factory, version {VERSION}"
+ " " * (21 - len(VERSION))
+ "|\n|"
+ " " * 56
+ "|\n"
+ "| Project page: https://github.com/hiyouga/LLaMA-Factory |\n"
+ "-" * 58
)
def launch(): def launch():
from .accelerator.helper import get_device_count
from .utils.env import find_available_port, is_env_enabled, use_kt, use_ray
from .utils.logging import get_logger
logger = get_logger(__name__)
# NOTE:
# `llamafactory-cli <command> ...` enters here first.
# We may re-launch via `torchrun` for distributed training. In that case we must
# forward `<command>` as argv[1] to the re-executed script, otherwise the script
# will misinterpret the first user argument (e.g. yaml config) as the command.
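# e.g. `llamafactory-cli sft path/to/config.yaml` re-executes (roughly) as
# `torchrun --nproc-per-node <N> <this launcher file> sft path/to/config.yaml`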
command = sys.argv.pop(1) if len(sys.argv) > 1 else "help" command = sys.argv.pop(1) if len(sys.argv) > 1 else "help"
if command == "sft": # train command will fallback to sft command if command in _DIST_TRAIN_COMMANDS and (
from .trainers.sft_trainer import run_sft is_env_enabled("FORCE_TORCHRUN") or (get_device_count() > 1 and not use_ray() and not use_kt())
):
nnodes = os.getenv("NNODES", "1")
node_rank = os.getenv("NODE_RANK", "0")
nproc_per_node = os.getenv("NPROC_PER_NODE", str(get_device_count()))
master_addr = os.getenv("MASTER_ADDR", "127.0.0.1")
master_port = os.getenv("MASTER_PORT", str(find_available_port()))
logger.info_rank0(f"Initializing {nproc_per_node} distributed tasks at: {master_addr}:{master_port}")
if int(nnodes) > 1:
logger.info_rank0(f"Multi-node training enabled: num nodes: {nnodes}, node rank: {node_rank}")
run_sft() # elastic launch support
max_restarts = os.getenv("MAX_RESTARTS", "0")
rdzv_id = os.getenv("RDZV_ID")
min_nnodes = os.getenv("MIN_NNODES")
max_nnodes = os.getenv("MAX_NNODES")
env = deepcopy(os.environ)
if is_env_enabled("OPTIM_TORCH", "1"):
# optimize DDP, see https://zhuanlan.zhihu.com/p/671834539
env["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
env["TORCH_NCCL_AVOID_RECORD_STREAMS"] = "1"
torchrun_args = [
"torchrun",
"--nproc-per-node",
nproc_per_node,
]
if rdzv_id is not None:
# launch elastic job with fault tolerant support when possible
# see also https://docs.pytorch.org/docs/stable/elastic/train_script.html
rdzv_nnodes = nnodes
# elastic number of nodes if MIN_NNODES and MAX_NNODES are set
if min_nnodes is not None and max_nnodes is not None:
rdzv_nnodes = f"{min_nnodes}:{max_nnodes}"
torchrun_args.extend(
[
"--nnodes",
rdzv_nnodes,
"--rdzv-id",
rdzv_id,
"--rdzv-backend",
"c10d",
"--rdzv-endpoint",
f"{master_addr}:{master_port}",
"--max-restarts",
max_restarts,
]
)
else:
# NOTE: DO NOT USE shell=True to avoid security risk
torchrun_args.extend(
[
"--nnodes",
nnodes,
"--node_rank",
node_rank,
"--master_addr",
master_addr,
"--master_port",
master_port,
]
)
script_args = [__file__, command] + sys.argv[1:]
process = subprocess.run(
torchrun_args + script_args,
env=env,
check=True,
)
sys.exit(process.returncode)
elif command == "chat":
from .samplers.cli_sampler import run_chat
run_chat()
elif command == "env": elif command == "env":
print_env() raise NotImplementedError("Environment information is not implemented yet.")
elif command == "version": elif command == "version":
print(WELCOME) raise NotImplementedError("Version information is not implemented yet.")
elif command == "help": elif command == "help":
print(USAGE) print(USAGE)
elif command in _DIST_TRAIN_COMMANDS:
# Single GPU training without torchrun
if command in ("train", "sft"):
from llamafactory.v1.trainers.sft_trainer import run_sft
run_sft()
elif command == "dpo":
raise NotImplementedError("DPO trainer is not implemented yet.")
elif command == "rm":
raise NotImplementedError("RM trainer is not implemented yet.")
else: else:
print(f"Unknown command: {command}.\n{USAGE}") print(f"Unknown command: {command}.\n{USAGE}")
def main():
# sys.argv[1] contains the command (sft/dpo/rm/train), sys.argv[2:] contains the rest args
command = sys.argv[1] if len(sys.argv) > 1 else "sft"
# Routing needs the sub-command, but downstream trainers usually expect argv without it.
if command in _DIST_TRAIN_COMMANDS:
sys.argv.pop(1)
else:
# Backward-compat: if someone runs `torchrun launcher.py config.yaml`,
# treat it as sft by default.
if len(sys.argv) > 1 and sys.argv[1].endswith((".yaml", ".yml")):
command = "sft"
if command in ("train", "sft"):
from llamafactory.v1.trainers.sft_trainer import run_sft
run_sft()
elif command == "dpo":
# from llamafactory.v1.trainers.dpo_trainer import run_dpo
# run_dpo()
raise NotImplementedError("DPO trainer is not implemented yet.")
elif command == "rm":
# from llamafactory.v1.trainers.rm_trainer import run_rm
# run_rm()
raise NotImplementedError("RM trainer is not implemented yet.")
if __name__ == "__main__": if __name__ == "__main__":
pass main()

View File

@@ -13,11 +13,12 @@
# limitations under the License. # limitations under the License.
import json
from typing import Any, Literal, NotRequired, TypedDict from typing import Any, Literal, NotRequired, TypedDict
from ...utils import logging from ...utils import logging
from ...utils.plugin import BasePlugin from ...utils.plugin import BasePlugin
from ...utils.types import DPOSample, Sample, SFTSample from ...utils.types import DPOSample, Sample, SFTSample, ToolCall
logger = logging.get_logger(__name__) logger = logging.get_logger(__name__)
@@ -31,7 +32,8 @@ class AlpacaSample(TypedDict, total=False):
SharegptMessage = TypedDict( SharegptMessage = TypedDict(
"SharegptMessage", {"from": Literal["human", "gpt", "system", "function_call", "observation"], "value": str} "SharegptMessage",
{"from": Literal["human", "gpt", "system", "function_call", "observation"], "value": str},
) )
@@ -61,7 +63,7 @@ class DataConverterPlugin(BasePlugin):
return super().__call__(raw_sample) return super().__call__(raw_sample)
@DataConverterPlugin("alpaca").register @DataConverterPlugin("alpaca").register()
def alpaca_converter(raw_sample: AlpacaSample) -> SFTSample: def alpaca_converter(raw_sample: AlpacaSample) -> SFTSample:
"""Convert Alpaca sample to SFT sample. """Convert Alpaca sample to SFT sample.
@@ -98,7 +100,7 @@ def alpaca_converter(raw_sample: AlpacaSample) -> SFTSample:
return {"messages": messages} return {"messages": messages}
@DataConverterPlugin("sharegpt").register @DataConverterPlugin("sharegpt").register()
def sharegpt_converter(raw_sample: SharegptSample) -> SFTSample: def sharegpt_converter(raw_sample: SharegptSample) -> SFTSample:
"""Convert ShareGPT sample to SFT sample. """Convert ShareGPT sample to SFT sample.
@@ -117,18 +119,26 @@ def sharegpt_converter(raw_sample: SharegptSample) -> SFTSample:
"observation": "tool", "observation": "tool",
"function_call": "assistant", "function_call": "assistant",
} }
sample = {}
messages = [] messages = []
tools = raw_sample.get("tools", "")
for message in raw_sample.get("conversations", []): for message in raw_sample.get("conversations", []):
tag = message["from"] tag = message["from"]
if tag not in tag_mapping: if tag not in tag_mapping:
logger.warning_rank0(f"Unsupported role tag {tag} in message: {message}") logger.warning_rank0(f"Unsupported role tag {tag} in message: {message}")
elif tag == "function_call": elif tag == "function_call":
try:
tool_calls: ToolCall | list[ToolCall] = json.loads(message["value"])
except json.JSONDecodeError:
logger.warning_rank0(f"Invalid tool call format: {str(message['value'])}")
continue
if not isinstance(tool_calls, list):
tool_calls = [tool_calls]
messages.append( messages.append(
{ {
"role": "assistant", "role": "assistant",
"content": [{"type": "tool_calls", "value": message["value"]}], "content": [{"type": "tool_call", "value": json.dumps(tool_call)} for tool_call in tool_calls],
"loss_weight": 1.0, "loss_weight": 1.0,
} }
) )
@@ -141,16 +151,20 @@ def sharegpt_converter(raw_sample: SharegptSample) -> SFTSample:
} }
) )
sample["messages"] = messages
tools = raw_sample.get("tools")
if tools: if tools:
if messages and messages[0]["role"] == "system": try:
messages[0]["content"].append({"type": "tools", "value": tools}) tools: list[dict[str, Any]] = json.loads(tools)
else: sample["tools"] = json.dumps(tools)
messages.insert(0, {"role": "system", "content": [{"type": "tools", "value": tools}], "loss_weight": 0.0}) except json.JSONDecodeError:
logger.warning_rank0(f"Invalid tools format: {str(tools)}")
return {"messages": messages} return sample
@DataConverterPlugin("pair").register @DataConverterPlugin("pair").register()
def pair_converter(raw_sample: PairSample) -> DPOSample: def pair_converter(raw_sample: PairSample) -> DPOSample:
"""Convert Pair sample to DPO sample. """Convert Pair sample to DPO sample.
@@ -166,17 +180,44 @@ def pair_converter(raw_sample: PairSample) -> DPOSample:
def process_message(raw_messages: list[OpenaiMessage]): def process_message(raw_messages: list[OpenaiMessage]):
messages = [] messages = []
for message in raw_messages: for message in raw_messages:
messages.append( if message["role"] == "tool":
{ try:
"role": message["role"], tool_calls: ToolCall | list[ToolCall] = json.loads(message["content"])
"content": [{"type": "text", "value": message["content"]}], except json.JSONDecodeError:
"loss_weight": 1.0 if message["role"] == "assistant" else 0.0, logger.warning_rank0(f"Invalid tool call format: {str(message['content'])}")
} continue
)
if not isinstance(tool_calls, list):
tool_calls = [tool_calls]
messages.append(
{
"role": message["role"],
"content": [{"type": "tool_call", "value": json.dumps(tool_call)} for tool_call in tool_calls],
"loss_weight": 1.0 if message["role"] == "assistant" else 0.0,
}
)
else:
messages.append(
{
"role": message["role"],
"content": [{"type": "text", "value": message["content"]}],
"loss_weight": 1.0 if message["role"] == "assistant" else 0.0,
}
)
return messages return messages
chosen_messages = process_message(raw_sample.get("chosen", [])) sample = {}
rejected_messages = process_message(raw_sample.get("rejected", [])) sample["chosen_messages"] = process_message(raw_sample.get("chosen", []))
sample["rejected_messages"] = process_message(raw_sample.get("rejected", []))
return {"chosen_messages": chosen_messages, "rejected_messages": rejected_messages} tools = raw_sample.get("tools")
if tools:
try:
tools: list[dict[str, Any]] = json.loads(tools)
sample["tools"] = json.dumps(tools)
except json.JSONDecodeError:
logger.warning_rank0(f"Invalid tools format: {str(tools)}")
return sample

View File

@@ -49,7 +49,7 @@ def _get_builder_name(path: str) -> Literal["arrow", "csv", "json", "parquet", "
raise ValueError(f"Unknown dataset filetype: {filetype}.") raise ValueError(f"Unknown dataset filetype: {filetype}.")
@DataLoaderPlugin("local").register @DataLoaderPlugin("local").register()
def load_data_from_file(filepath: str, split: str, streaming: bool) -> HFDataset: def load_data_from_file(filepath: str, split: str, streaming: bool) -> HFDataset:
if os.path.isdir(filepath): if os.path.isdir(filepath):
filetype = _get_builder_name(os.listdir(filepath)[0]) filetype = _get_builder_name(os.listdir(filepath)[0])
@@ -66,49 +66,43 @@ def load_data_from_file(filepath: str, split: str, streaming: bool) -> HFDataset
     return dataset


-class DataIndexPlugin(BasePlugin):
-    """Plugin for adjusting dataset index."""
-
-    def adjust_data_index(
-        self, data_index: list[tuple[str, int]], size: int | None, weight: float | None
-    ) -> list[tuple[str, int]]:
-        """Adjust dataset index by size and weight.
-
-        Args:
-            data_index (list[tuple[str, int]]): List of (dataset_name, sample_index).
-            size (Optional[int]): Desired dataset size.
-            weight (Optional[float]): Desired dataset weight.
-
-        Returns:
-            list[tuple[str, int]]: Adjusted dataset index.
-        """
-        if size is not None:
-            data_index = random.choices(data_index, k=size)
-
-        if weight is not None:
-            data_index = random.choices(data_index, k=int(len(data_index) * weight))
-
-        return data_index
-
-
-class DataSelectorPlugin(BasePlugin):
-    """Plugin for selecting dataset samples."""
-
-    def select(
-        self, data_index: list[tuple[str, int]], index: slice | list[int] | Any
-    ) -> tuple[str, int] | list[tuple[str, int]]:
-        """Select dataset samples.
-
-        Args:
-            data_index (list[tuple[str, int]]): List of (dataset_name, sample_index).
-            index (Union[slice, list[int], Any]): Index of dataset samples.
-
-        Returns:
-            Union[tuple[str, int], list[tuple[str, int]]]: Selected dataset samples.
-        """
-        if isinstance(index, slice):
-            return [data_index[i] for i in range(*index.indices(len(data_index)))]
-        elif isinstance(index, list):
-            return [data_index[i] for i in index]
-        else:
-            raise ValueError(f"Invalid index type {type(index)}.")
+def adjust_data_index(
+    data_index: list[tuple[str, int]], size: int | None, weight: float | None
+) -> list[tuple[str, int]]:
+    """Adjust dataset index by size and weight.
+
+    Args:
+        data_index (list[tuple[str, int]]): List of (dataset_name, sample_index).
+        size (Optional[int]): Desired dataset size.
+        weight (Optional[float]): Desired dataset weight.
+
+    Returns:
+        list[tuple[str, int]]: Adjusted dataset index.
+    """
+    if size is not None:
+        data_index = random.choices(data_index, k=size)
+
+    if weight is not None:
+        data_index = random.choices(data_index, k=int(len(data_index) * weight))
+
+    return data_index
+
+
+def select_data_sample(
+    data_index: list[tuple[str, int]], index: slice | list[int] | Any
+) -> tuple[str, int] | list[tuple[str, int]]:
+    """Select dataset samples.
+
+    Args:
+        data_index (list[tuple[str, int]]): List of (dataset_name, sample_index).
+        index (Union[slice, list[int], Any]): Index of dataset samples.
+
+    Returns:
+        Union[tuple[str, int], list[tuple[str, int]]]: Selected dataset samples.
+    """
+    if isinstance(index, slice):
+        return [data_index[i] for i in range(*index.indices(len(data_index)))]
+    elif isinstance(index, list):
+        return [data_index[i] for i in index]
+    else:
+        raise ValueError(f"Invalid index type {type(index)}.")


@@ -1,133 +0,0 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from dataclasses import dataclass


@dataclass
class Template:
    user_template: str
    assistant_template: str
    system_template: str

    def render_message(self, message: dict[str, str]) -> str:
        return self.user_template.format(**message)


@dataclass
class QwenTemplate:
    message_template: str = "<|im_start|>{role}\n{content}<|im_end|>\n"  # FIXME if role: tool
    thinking_template: str = "<think>\n{content}\n</think>\n\n"

    def _extract_content(self, content_data: str | list[dict[str, str]]) -> str:
        if isinstance(content_data, str):
            return content_data.strip()

        if isinstance(content_data, list):
            parts = []
            for item in content_data:
                if item.get("type") == "text":
                    parts.append(item.get("value", ""))
                elif item.get("type") == "image_url":
                    pass

            return "\n".join(parts).strip()

        return ""

    def render_message(self, message: dict[str, str | list[dict[str, str]]]) -> str:
        role = message["role"]
        content = self._extract_content(message.get("content", ""))
        if role == "assistant":
            reasoning_content = message.get("reasoning_content", "")
            if reasoning_content:
                reasoning_content = self.thinking_template.format(content=str(reasoning_content).strip())

            return self.message_template.format(role="assistant", content=reasoning_content + content)
        else:
            return self.message_template.format(role=role, content=content)

    def encode_messages(self, tokenizer, messages: list[dict[str, str]], max_seq_len: int = 8192) -> any:
        """Encode one message."""
        input_ids, attention_mask, labels = [], [], []
        for message in messages:
            content_str = self.render_message(message)
            content_ids = tokenizer.encode(content_str, add_special_tokens=False)
            input_ids += content_ids
            attention_mask += [1] * len(content_ids)
            if hasattr(message, "loss_weight"):
                loss_weight = message["loss_weight"]
            else:
                loss_weight = 1 if message["role"] == "assistant" else 0

            if loss_weight == 1:
                labels += content_ids
            else:
                labels += [-100] * len(content_ids)

        model_inputs = {"input_ids": input_ids, "attention_mask": attention_mask, "labels": labels}
        model_inputs.update({"position_ids": list(range(len(input_ids)))})
        model_inputs = {k: v[-max_seq_len:] for k, v in model_inputs.items()}
        return model_inputs


if __name__ == "__main__":

    def to_qwen3_messages(template: QwenTemplate, messages: list[dict]):
        out = []
        for m in messages:
            role = m["role"]
            content = template._extract_content(m.get("content", ""))
            if role == "assistant":
                reasoning = (m.get("reasoning_content") or "").strip()
                if reasoning:
                    content = template.thinking_template.format(content=reasoning) + content

            out.append({"role": role, "content": content})

        return out

    from transformers import AutoTokenizer

    tok = AutoTokenizer.from_pretrained(
        "Qwen/Qwen3-30B-A3B-Thinking-2507",
        trust_remote_code=True,
    )

    test_messages = [
        {"role": "system", "content": "You are a helpful assistant."},
        {
            "role": "user",
            "content": [{"type": "text", "text": "1+1等于几"}, {"type": "text", "text": "2+2等于几"}],
        },
        {
            "role": "assistant",
            "reasoning_content": "这是一个简单的数学问题。1加1的结果是2。",
            "content": [{"type": "text", "text": "1+1=2"}, {"type": "text", "text": "2+2=4"}],
        },
    ]

    template = QwenTemplate()
    rendered_custom = "".join([template.render_message(m) for m in test_messages])
    qwen3_messages = to_qwen3_messages(template, test_messages)
    rendered_hf = tok.apply_chat_template(qwen3_messages, tokenize=False, add_generation_prompt=False)
    print("==== custom ====")
    print(rendered_custom)
    print("==== hf ====")
    print(rendered_hf)
    assert rendered_custom.strip() == rendered_hf.strip(), "Rendered text mismatch"

    ids_custom = tok.encode(rendered_custom, add_special_tokens=False)
    ids_hf = tok.apply_chat_template(qwen3_messages, tokenize=True, add_generation_prompt=False)
    assert ids_custom == ids_hf, f"Token ids mismatch: custom={len(ids_custom)} hf={len(ids_hf)}"
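
The encode_messages helper in this removed file marks untrained tokens with the label value -100, the same loss_weight convention the pair converter above uses. A standalone, tokenizer-free sketch of that masking rule (names here are illustrative):

# Label -100 is the default ignore_index of torch.nn.CrossEntropyLoss, so only
# tokens from turns with loss_weight == 1 (assistant turns) contribute to the loss.
IGNORE_INDEX = -100


def mask_labels(token_ids: list[int], loss_weight: float) -> list[int]:
    return list(token_ids) if loss_weight == 1 else [IGNORE_INDEX] * len(token_ids)


assert mask_labels([5, 6, 7], loss_weight=1) == [5, 6, 7]
assert mask_labels([5, 6, 7], loss_weight=0) == [-100, -100, -100]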


@@ -0,0 +1,43 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch

from ...accelerator.helper import DeviceType
from ...accelerator.interface import DistributedInterface
from ...utils.plugin import BasePlugin


class InitPlugin(BasePlugin):
    def __call__(self) -> torch.device:
        return super().__call__()


@InitPlugin("init_on_meta").register()
def init_on_meta() -> torch.device:
    return torch.device(DeviceType.META.value)


@InitPlugin("init_on_rank0").register()
def init_on_rank0() -> torch.device:
    if DistributedInterface().get_rank() == 0:
        return torch.device(DeviceType.CPU.value)
    else:
        return torch.device(DeviceType.META.value)


@InitPlugin("init_on_default").register()
def init_on_default() -> torch.device:
    return DistributedInterface().current_device
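
For context on why the meta device appears here: tensors created on "meta" carry shapes and dtypes but no storage, so non-rank-0 processes can build the module graph without allocating memory and materialize weights later. A plain-PyTorch sketch, independent of the plugin machinery above:

import torch
import torch.nn as nn

# Modules created under the meta device have no real storage, so even a very
# large layer is "initialized" instantly.
with torch.device("meta"):
    layer = nn.Linear(4096, 4096)

print(layer.weight.device)  # meta
print(layer.weight.element_size() * layer.weight.nelement())  # bytes it *would* occupy

# Allocate real (uninitialized) storage later, e.g. before loading a checkpoint.
materialized = layer.to_empty(device="cpu")
print(materialized.weight.device)  # cpu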


@@ -38,17 +38,17 @@ class BaseKernel(ABC):
     @classmethod
     def get_kernel_id(cls) -> str:
-        r"""Returns the unique identifier for the kernel."""
+        """Returns the unique identifier for the kernel."""
         return cls._kernel_id

     @classmethod
     def get_device(cls) -> str:
-        r"""Returns the device type associated with the kernel (e.g., "cuda", "npu", "cpu")."""
+        """Returns the device type associated with the kernel (e.g., "cuda", "npu", "cpu")."""
         return cls._device

     @classmethod
     def check_deps(cls) -> bool:
-        r"""Checks if the required dependencies for the kernel are available.
+        """Checks if the required dependencies for the kernel are available.

         Returns:
             bool: ``True`` if dependencies are met, ``False`` otherwise.
@@ -65,7 +65,7 @@ class BaseKernel(ABC):
     @classmethod
     @abstractmethod
     def apply(cls, **kwargs) -> HFModel:
-        r"""Applies the kernel optimization to the model.
+        """Applies the kernel optimization to the model.

         Args:
             **kwargs: Arbitrary keyword arguments, usually containing the model instance and the kernel configuration.


@@ -24,16 +24,17 @@ Init Phase:
 import importlib
 from pathlib import Path

-from ....utils.logging import get_logger
+from ....utils import logging
 from ....utils.plugin import BasePlugin
+from ....utils.types import HFModel
 from .registry import Registry


-logger = get_logger(__name__)
+logger = logging.get_logger(__name__)


 def scan_all_kernels():
-    r"""Scan all kernels in the ``ops`` directory.
+    """Scan all kernels in the ``ops`` directory.

     Scans the ``ops`` directory for all ``.py`` files and attempts to import them.
     Importing triggers the :func:`~registry.register_kernel` decorator, which automatically registers the kernels.
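
A minimal sketch of the scan-and-import pattern this docstring describes: walk a directory for .py files and import each module so its registration decorators run at import time. The package and directory names here are illustrative, not the project's real layout:

import importlib
from pathlib import Path


def scan_modules(package: str, ops_dir: Path) -> list[str]:
    # Importing a module executes its top-level decorators, which is how a
    # decorator-based registry gets populated without a hand-maintained list.
    imported = []
    for py_file in sorted(ops_dir.glob("*.py")):
        if py_file.name == "__init__.py":
            continue
        module_name = f"{package}.{py_file.stem}"
        importlib.import_module(module_name)
        imported.append(module_name)
    return imported


# e.g. scan_modules("myproject.kernels.ops", Path(__file__).parent / "ops")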
@@ -77,7 +78,7 @@ default_kernels = scan_all_kernels()
 def get_default_kernels():
-    r"""Get a list of default registered kernel IDs.
+    """Get a list of default registered kernel IDs.

     Returns:
         list[str]: List of kernel IDs.
@@ -86,7 +87,7 @@ def get_default_kernels():
 def apply_kernel(kernel_id: str, **kwargs):
-    r"""Applies a specific kernel to the model.
+    """Applies a specific kernel to the model.

     Args:
         kernel_id (str): The ID of the kernel to apply.
@@ -99,34 +100,41 @@ def apply_kernel(kernel_id: str, **kwargs):
     kernel = default_kernels.get(kernel_id)
     if kernel is None:
         raise ValueError(f"Kernel {kernel_id} not found")

     kernel.apply(**kwargs)


 class KernelPlugin(BasePlugin):
-    r"""Plugin for managing kernel optimizations."""
+    """Plugin for managing kernel optimizations."""

     pass


-@KernelPlugin("auto").register
-def apply_default_kernels(**kwargs):
-    r"""Applies all default registered kernels to the model.
+@KernelPlugin("auto").register()
+def apply_default_kernels(model: HFModel, include_kernels: str = None) -> HFModel:
+    """Applies all default registered kernels to the model.

     Args:
-        **kwargs: Keyword arguments passed to the kernel application function.
-            Typically includes the model instance and the include_kernels configuration.
+        model (HFModel): The model instance to apply kernels to.
+        include_kernels (str, optional): Comma-separated list of kernel IDs to apply.
+            If "auto" or True, applies all default kernels.
+            If None or False, no kernels are applied.
+            Defaults to None.

     Returns:
         HFModel: The model with applied kernels.
     """
-    if not kwargs.get("include_kernels"):  # None/False/empty string
-        return kwargs.get("model")
-    elif kwargs.get("include_kernels") == "auto" or kwargs.get("include_kernels") is True:  # True/auto
+    if not include_kernels:
+        return model
+    elif include_kernels == "auto" or include_kernels is True:
         use_kernels = default_kernels.keys()
     else:
-        use_kernels = kwargs.get("include_kernels").split(",")  # "kernel_id1,kernel_id2,kernel_id3"
+        use_kernels = include_kernels.split(",")  # "kernel_id1,kernel_id2,kernel_id3"

     for kernel in use_kernels:
         if kernel not in default_kernels:
             raise ValueError(f"Kernel {kernel} not found")

-        apply_kernel(kernel, **kwargs)
-
-    return kwargs.get("model")
+        apply_kernel(kernel, model=model)
+
+    return model
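
As the new signature documents, include_kernels accepts None/False (no-op), "auto"/True (all registered kernels), or a comma-separated list of kernel IDs. A sketch of that dispatch logic in isolation; "npu_fused_moe" appears in the diff below, while "npu_fused_swiglu" and the helper name are placeholders:

def resolve_kernels(include_kernels, registered: dict[str, object]) -> list[str]:
    # Mirrors the branching in apply_default_kernels: falsy -> nothing,
    # "auto"/True -> everything registered, otherwise a comma-separated id list.
    if not include_kernels:
        return []
    if include_kernels == "auto" or include_kernels is True:
        return list(registered)
    return include_kernels.split(",")


registry = {"npu_fused_moe": object(), "npu_fused_swiglu": object()}  # placeholder registry
assert resolve_kernels(None, registry) == []
assert resolve_kernels("auto", registry) == ["npu_fused_moe", "npu_fused_swiglu"]
assert resolve_kernels("npu_fused_moe", registry) == ["npu_fused_moe"]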


@@ -40,11 +40,11 @@ from ...registry import register_kernel
 class GmmFunction(torch.autograd.Function):
-    r"""Custom autograd function for NPU Grouped Matrix Multiplication (GMM)."""
+    """Custom autograd function for NPU Grouped Matrix Multiplication (GMM)."""

     @staticmethod
     def forward(ctx, x, weight, group_list):
-        r"""Performs the forward pass of Grouped Matrix Multiplication.
+        """Performs the forward pass of Grouped Matrix Multiplication.

         Args:
             ctx: Context object to save tensors for backward pass.
@@ -65,7 +65,7 @@ class GmmFunction(torch.autograd.Function):
     @staticmethod
     def backward(ctx, grad_output):
-        r"""Performs the backward pass of Grouped Matrix Multiplication.
+        """Performs the backward pass of Grouped Matrix Multiplication.

         Args:
             ctx: Context object containing saved tensors.
@@ -94,11 +94,11 @@ class GmmFunction(torch.autograd.Function):
 class HybridGmmFunction(torch.autograd.Function):
-    r"""Custom autograd function for Hybrid Grouped Matrix Multiplication on NPU."""
+    """Custom autograd function for Hybrid Grouped Matrix Multiplication on NPU."""

     @staticmethod
     def forward(ctx, num_experts, *args):
-        r"""Performs the forward pass of Hybrid GMM.
+        """Performs the forward pass of Hybrid GMM.

         Args:
             ctx: Context object to save tensors.
@@ -124,7 +124,7 @@ class HybridGmmFunction(torch.autograd.Function):
     @staticmethod
     def backward(ctx, *grad_outputs):
-        r"""Performs the backward pass of Hybrid GMM.
+        """Performs the backward pass of Hybrid GMM.

         Args:
             ctx: Context object containing saved tensors.
@@ -176,13 +176,13 @@ class HybridGmmFunction(torch.autograd.Function):
 class NpuMoeFused:
-    r"""Container for NPU fused MoE forward functions."""
+    """Container for NPU fused MoE forward functions."""

     @staticmethod
     def npu_moe_experts_forward(
         self, hidden_states: torch.Tensor, routing_weights: torch.Tensor, router_indices: torch.Tensor
     ) -> torch.Tensor:
-        r"""Forward pass for MoE experts using NPU fused operations.
+        """Forward pass for MoE experts using NPU fused operations.

         Args:
             self: The MoE layer instance.
@@ -230,11 +230,11 @@ class NpuMoeFused:
 class Qwen3NpuMoeFused:
-    r"""Container for Qwen3 NPU fused MoE forward functions."""
+    """Container for Qwen3 NPU fused MoE forward functions."""

     @staticmethod
     def qwen3moe_sparse_moe_block_forward(self, hidden_states: torch.Tensor):
-        r"""Forward pass for Qwen3 sparse MoE block using NPU fused operations.
+        """Forward pass for Qwen3 sparse MoE block using NPU fused operations.

         Args:
             self: The Qwen3 MoE block instance.
@@ -298,14 +298,14 @@ if not is_transformers_version_greater_than("5.0.0"):
 @register_kernel
 class NpuFusedMoEKernel(BaseKernel):
-    r"""NPU Fused MoE Kernel implementation."""
+    """NPU Fused MoE Kernel implementation."""

     _kernel_id = "npu_fused_moe"
     _device = DeviceType.NPU

     @classmethod
     def apply(cls, **kwargs) -> HFModel:
-        r"""Applies the NPU fused MoE kernel to the model.
+        """Applies the NPU fused MoE kernel to the model.

         Args:
             **kwargs: Keyword arguments containing the model.
@@ -324,7 +324,7 @@ class NpuFusedMoEKernel(BaseKernel):
         if not cls.check_deps():
             raise RuntimeError("torch_npu is not available but NpuMoEFusedMoEKernel was called.")

-        archs = getattr(model.config, "architectures", [])
+        archs = getattr(model.config, "architectures", None) or []
         target_moe_mapping = None
         for arch in archs:
             if arch in kernel_moe_mapping:
@@ -333,6 +333,7 @@ class NpuFusedMoEKernel(BaseKernel):
         if target_moe_mapping is None:
             return model

         for module in model.modules():
             class_name = module.__class__.__name__
             if class_name in target_moe_mapping:
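
The `archs` change above matters because getattr's default is used only when the attribute is missing, not when it exists with value None. A short illustration of the difference (SimpleNamespace stands in for a model config):

from types import SimpleNamespace

# If the attribute exists but is None, getattr ignores the default,
# and iterating over it would raise TypeError; `or []` covers that case.
config = SimpleNamespace(architectures=None)

archs_old = getattr(config, "architectures", [])          # -> None
archs_new = getattr(config, "architectures", None) or []  # -> []

assert archs_old is None
assert archs_new == []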


@@ -38,7 +38,7 @@ except ImportError:
 def npu_swiglu_forward(self, hidden_state):
-    r"""SwiGLU forward pass for NPU.
+    """SwiGLU forward pass for NPU.

     Args:
         self: The MLP layer instance.
@@ -53,7 +53,7 @@ def npu_swiglu_forward(self, hidden_state):
 def _npu_swiglu_glm4_forward(self, hidden_states):
-    r"""SwiGLU forward pass for GLM4 on NPU.
+    """SwiGLU forward pass for GLM4 on NPU.

     Args:
         self: The GLM4 MLP layer instance.
@@ -68,7 +68,7 @@ def _npu_swiglu_glm4_forward(self, hidden_states):
 def _npu_swiglu_gemma3ntext_forward(self, hidden_states):
-    r"""SwiGLU forward pass for Gemma3nText on NPU.
+    """SwiGLU forward pass for Gemma3nText on NPU.

     Args:
         self: The Gemma3nText MLP layer instance.
@@ -88,7 +88,7 @@ def _npu_swiglu_gemma3ntext_forward(self, hidden_states):
 @register_kernel
 class NpuSwiGluKernel(BaseKernel):
-    r"""NPU Kernel for fused SwiGLU activation."""
+    """NPU Kernel for fused SwiGLU activation."""

     # just support apply to the following module layers
     expect_modules = frozenset(
@@ -126,7 +126,7 @@ class NpuSwiGluKernel(BaseKernel):
     @classmethod
     def apply(cls, **kwargs) -> "HFModel":
-        r"""Applies the NPU fused SwiGLU kernel to the model.
+        """Applies the NPU fused SwiGLU kernel to the model.

         Args:
             **kwargs: Keyword arguments containing the model.
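
For context on what this kernel fuses: a SwiGLU MLP computes down_proj(silu(gate_proj(x)) * up_proj(x)). A plain-PyTorch reference of those semantics (not the NPU code path, which dispatches to the fused torch_npu operator instead); shapes and weights below are arbitrary:

import torch
import torch.nn.functional as F


def swiglu_reference(
    x: torch.Tensor, w_gate: torch.Tensor, w_up: torch.Tensor, w_down: torch.Tensor
) -> torch.Tensor:
    # Eager reference for the activation the NPU kernel fuses into one call:
    # SiLU-gated linear unit followed by the down projection.
    return (F.silu(x @ w_gate) * (x @ w_up)) @ w_down


x = torch.randn(2, 8)
w_gate, w_up = torch.randn(8, 16), torch.randn(8, 16)
w_down = torch.randn(16, 8)
print(swiglu_reference(x, w_gate, w_up, w_down).shape)  # torch.Size([2, 8])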

Some files were not shown because too many files have changed in this diff.