38 Commits
v0.9.4 ... main

Author SHA1 Message Date
xvxuopop
762b480131 [feature] support using ray.remote to start distributed training. (#10109) 2026-01-28 16:05:29 +08:00
Jewon Lee
9640f79ae5 [fix] add visual.pos_embed to Qwen3-VL visual model keys (#10139) 2026-01-27 16:33:01 +08:00
jiaqiw09
7ef19eea00 [v0] Fix reward model training safetensors saving (#10137) 2026-01-27 16:27:14 +08:00
浮梦
f9f11dcb97 [v1] support training with fsdp2 (#9773)
Co-authored-by: frozenleaves <frozen@Mac.local>
Co-authored-by: Yaowei Zheng <hiyouga@buaa.edu.cn>
2026-01-25 19:41:58 +08:00
Pádraic Slattery
641bfdd482 chore: Update outdated GitHub Actions versions (#10123) 2026-01-25 19:12:39 +08:00
Meng WANG
e70651ac58 [feat] support all_exhausted_without_replacement in datasets.interleave_datasets (#10112) 2026-01-20 15:54:07 +08:00
Kingsley
db2f794f7b [misc] update mcore related docker and mca supported models (#10114) 2026-01-19 14:55:16 +08:00
jiaqiw09
44eadbda1c [v1] fix kernel moe patch (#9867) 2026-01-17 09:24:54 +08:00
浮梦
9829ae0a77 [ci] using mp to run kernel test (#9754)
Co-authored-by: frozenleaves <frozen@Mac.local>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Co-authored-by: Yaowei Zheng <hiyouga@buaa.edu.cn>
2026-01-13 19:43:59 +08:00
Yaowei Zheng
958b9c3468 [v1] add sft (#9752) 2026-01-12 03:15:01 +08:00
Hertz
4d3621e3d3 [model] fixed&added Hunyuan models (#9750) 2026-01-12 01:15:00 +08:00
Yaowei Zheng
a296723697 [v1] upgrade batching (#9751)
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
2026-01-12 00:21:36 +08:00
Hertz
15b87f3125 [model] support HY-MT model (#9746)
Co-authored-by: Yaowei Zheng <hiyouga@buaa.edu.cn>
2026-01-11 16:25:56 +08:00
Yaowei Zheng
9f73a6eb23 [deps] fix package (#9745) 2026-01-10 04:27:53 +08:00
Yaowei Zheng
b2effbd77c [v1] add batch generator (#9744) 2026-01-10 04:24:09 +08:00
Yaowei Zheng
d7d734d54c [misc] fix fp8 (#9742) 2026-01-09 16:17:26 +08:00
Yaowei Zheng
8abb8fb533 [v1] use async streamer (#9741) 2026-01-09 16:12:07 +08:00
Yaowei Zheng
766d5ae6ad [ci] fix workflow (#9738) 2026-01-09 16:12:07 +08:00
Yaowei Zheng
5cccaeec82 [model] clean obsolete models (#9736) 2026-01-09 16:12:07 +08:00
Jackey
5fb5d7ebd3 [model] support for microsoft's Phi-4-mini (#9734) 2026-01-09 12:24:45 +08:00
Peilin Li
03a70ba8dd [fix] correct ktransformers example config paths and templates (#9732) 2026-01-08 10:52:50 +08:00
Vo Van Phuc
5cfd804b59 [refactor] rename lfm template to lfm2 and add LFM 2.5 to README (#9731) 2026-01-07 19:25:04 +08:00
Yaowei Zheng
4c1eb922e2 [misc] fix parser (#9730) 2026-01-07 17:36:08 +08:00
Vo Van Phuc
958fb523a2 [model] support LiquidAI's LFM2.5-VL vision-language model (#9729) 2026-01-07 17:20:29 +08:00
Vo Van Phuc
b4e051bea4 [model] support for LiquidAI's LFM2.5 (Liquid Foundation Models) (#9726) 2026-01-07 14:14:47 +08:00
浮梦
d43e1007e8 [ci] improve cuda ci cache (#9725)
Co-authored-by: frozenleaves <frozen@Mac.local>
2026-01-07 12:34:40 +08:00
Xunpeng Xiao
f89d9367e5 [assets] update README.md (#9724) 2026-01-07 12:11:50 +08:00
Yaowei Zheng
d22de0d4bf [v1] add renderer ut (#9722) 2026-01-07 02:06:07 +08:00
Yaowei Zheng
ea0b4e2466 [v1] add cli sampler (#9721) 2026-01-06 23:31:27 +08:00
yanglele
e944dc442c [feature] add support for EAFT loss (#9720)
Co-authored-by: Yaowei Zheng <hiyouga@buaa.edu.cn>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
2026-01-06 23:07:12 +08:00
Xunpeng Xiao
68119e5522 [misc] Add a PyTorch version warning for Conv3D. (#9715) 2026-01-05 13:26:29 +08:00
Yaowei Zheng
f60a6e3d01 [v1] add init plugin (#9716) 2026-01-04 20:51:46 +08:00
jiaqiw09
81b8a50aa5 [deps] Update pyproject.toml and requirements (#9714)
Co-authored-by: Yaowei Zheng <hiyouga@buaa.edu.cn>
2026-01-04 19:52:16 +08:00
Yaowei Zheng
8600530002 [misc] lint (#9710) 2026-01-04 13:47:56 +08:00
Hertz
9ae62c6fc0 [model] support Youtu-LLM-2B (#9707) 2026-01-04 13:17:57 +08:00
Xunpeng Xiao
0087bc253b [misc] Compatible with an empty architectures field in config.json (#9709) 2026-01-04 12:11:35 +08:00
Santosh Bhavani
355d5c5e5a [fix] fp8: add Transformer Engine backend support (#9705)
Co-authored-by: Yaowei Zheng <hiyouga@buaa.edu.cn>
2026-01-01 10:18:02 +08:00
Yaowei Zheng
6fe6bd290b [misc] set dev version (#9703) 2025-12-31 23:41:40 +08:00
139 changed files with 4439 additions and 2667 deletions

View File

@@ -50,7 +50,7 @@ jobs:
docker-images: false
- name: Checkout
uses: actions/checkout@v4
uses: actions/checkout@v6
- name: Get llamafactory version
id: version

View File

@@ -21,7 +21,7 @@ jobs:
steps:
- name: Checkout
uses: actions/checkout@v4
uses: actions/checkout@v6
- name: Install uv
uses: astral-sh/setup-uv@v7

View File

@@ -54,10 +54,11 @@ jobs:
env:
HF_TOKEN: ${{ secrets.HF_TOKEN }}
OS_NAME: ${{ matrix.os }}
UV_NO_SYNC: 1
steps:
- name: Checkout
uses: actions/checkout@v4
uses: actions/checkout@v6
- name: Install uv
uses: astral-sh/setup-uv@v7
@@ -70,7 +71,8 @@ jobs:
run: |
uv venv
uv pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu
uv pip install -e ".[dev]"
uv pip install -e .
uv pip install -r requirements/dev.txt
- name: Install transformers
if: ${{ matrix.transformers }}
@@ -79,7 +81,7 @@ jobs:
- name: Cache files
id: hf-hub-cache
uses: actions/cache@v4
uses: actions/cache@v5
with:
path: ${{ runner.temp }}/huggingface
key: huggingface-${{ matrix.os }}-${{ matrix.python }}-${{ matrix.transformers }}-${{ hashFiles('tests/version.txt') }}
@@ -87,25 +89,18 @@ jobs:
- name: Check quality
run: |
make style && make quality
env:
UV_NO_SYNC: 1
- name: Check license
run: |
make license
env:
UV_NO_SYNC: 1
- name: Check build
run: |
make build
env:
UV_NO_SYNC: 1
- name: Test with pytest
run: |
make test
env:
UV_NO_SYNC: 1
HF_HOME: ${{ runner.temp }}/huggingface
HF_HUB_OFFLINE: "${{ steps.hf-hub-cache.outputs.cache-hit == 'true' && '1' || '0' }}"

View File

@@ -35,9 +35,16 @@ jobs:
group: ${{ github.workflow }}-${{ github.ref }}-${{ matrix.os }}-${{ matrix.python }}
cancel-in-progress: ${{ github.ref != 'refs/heads/main' }}
env:
HF_HOME: "${{ github.workspace }}/../.runner_cache/huggingface"
UV_CACHE_DIR: "${{ github.workspace }}/../.runner_cache/uv"
HF_TOKEN: ${{ secrets.HF_TOKEN }}
OS_NAME: ${{ matrix.os }}
UV_NO_SYNC: 1
steps:
- name: Checkout
uses: actions/checkout@v4
uses: actions/checkout@v6
- name: Install uv
uses: astral-sh/setup-uv@v7
@@ -52,37 +59,21 @@ jobs:
- name: Install dependencies
run: |
uv venv
uv pip install -e ".[dev]"
- name: Cache HuggingFace models
id: hf-hub-cache
uses: actions/cache@v4
with:
path: ${{ runner.temp }}/huggingface
key: hf-cache-${{ runner.os }}-${{ hashFiles('tests/version.txt') }}
uv pip install -e .
uv pip install -r requirements/dev.txt
- name: Check quality
run: |
make style && make quality
env:
UV_NO_SYNC: 1
- name: Check license
run: |
make license
env:
UV_NO_SYNC: 1
- name: Check build
run: |
make build
env:
UV_NO_SYNC: 1
- name: Test with pytest
run: |
make test
env:
UV_NO_SYNC: 1
HF_HOME: ${{ runner.temp }}/huggingface
HF_HUB_OFFLINE: "${{ steps.hf-hub-cache.outputs.cache-hit == 'true' && '1' || '0' }}"

View File

@@ -43,10 +43,11 @@ jobs:
HF_ENDPOINT: https://hf-mirror.com
HF_TOKEN: ${{ secrets.HF_TOKEN }}
OS_NAME: ${{ matrix.os }}
UV_NO_SYNC: 1
steps:
- name: Checkout
uses: actions/checkout@v4
uses: actions/checkout@v6
- name: Install uv
uses: astral-sh/setup-uv@v7
@@ -58,8 +59,9 @@ jobs:
- name: Install dependencies
run: |
uv venv
uv pip install torch-npu==${{matrix.pytorch_npu}}
uv pip install -e ".[dev]"
uv pip install -r requirements/npu.txt
uv pip install -e .
uv pip install -r requirements/dev.txt
- name: Install node
run: |
@@ -68,35 +70,18 @@ jobs:
curl -fsSL https://deb.nodesource.com/setup_20.x | bash -
apt-get install -y nodejs
- name: Cache files
id: hf-hub-cache
uses: actions/cache@v4
with:
path: ${{ runner.temp }}/huggingface
key: huggingface-${{ matrix.os }}-${{ matrix.python }}-${{ hashFiles('tests/version.txt') }}
- name: Check quality
run: |
make style && make quality
env:
UV_NO_SYNC: 1
- name: Check license
run: |
make license
env:
UV_NO_SYNC: 1
- name: Check build
run: |
make build
env:
UV_NO_SYNC: 1
- name: Test with pytest
run: |
make test
env:
UV_NO_SYNC: 1
HF_HOME: /root/.cache/huggingface
HF_HUB_OFFLINE: "${{ steps.hf-hub-cache.outputs.cache-hit == 'true' && '1' || '0' }}"

1
.gitignore vendored
View File

@@ -176,6 +176,7 @@ llamaboard_cache/
llamaboard_config/
saves/
output/
outputs/
wandb/
swanlog/
generated_predictions.jsonl

View File

@@ -92,7 +92,7 @@ Read technical notes:
## Features
- **Various models**: LLaMA, LLaVA, Mistral, Mixtral-MoE, Qwen, Qwen2-VL, DeepSeek, Yi, Gemma, ChatGLM, Phi, etc.
- **Various models**: LLaMA, LLaVA, Mistral, Mixtral-MoE, Qwen3, Qwen3-VL, DeepSeek, Gemma, GLM, Phi, etc.
- **Integrated methods**: (Continuous) pre-training, (multimodal) supervised fine-tuning, reward modeling, PPO, DPO, KTO, ORPO, etc.
- **Scalable resources**: 16-bit full-tuning, freeze-tuning, LoRA and 2/3/4/5/6/8-bit QLoRA via AQLM/AWQ/GPTQ/LLM.int8/HQQ/EETQ.
- **Advanced algorithms**: [GaLore](https://github.com/jiaweizzhao/GaLore), [BAdam](https://github.com/Ledzy/BAdam), [APOLLO](https://github.com/zhuhanqing/APOLLO), [Adam-mini](https://github.com/zyushun/Adam-mini), [Muon](https://github.com/KellerJordan/Muon), [OFT](https://github.com/huggingface/peft/tree/main/src/peft/tuners/oft), DoRA, LongLoRA, LLaMA Pro, Mixture-of-Depths, LoRA+, LoftQ and PiSSA.
@@ -279,11 +279,10 @@ Read technical notes:
| Model | Model size | Template |
| ----------------------------------------------------------------- | -------------------------------- | -------------------- |
| [BLOOM/BLOOMZ](https://huggingface.co/bigscience) | 560M/1.1B/1.7B/3B/7.1B/176B | - |
| [Command R](https://huggingface.co/CohereForAI) | 35B/104B | cohere |
| [DeepSeek (LLM/Code/MoE)](https://huggingface.co/deepseek-ai) | 7B/16B/67B/236B | deepseek |
| [DeepSeek 3-3.2](https://huggingface.co/deepseek-ai) | 236B/671B | deepseek3 |
| [DeepSeek R1 (Distill)](https://huggingface.co/deepseek-ai) | 1.5B/7B/8B/14B/32B/70B/671B | deepseekr1 |
| [ERNIE-4.5](https://huggingface.co/baidu) | 0.3B/21B/300B | ernie/ernie_nothink |
| [ERNIE-4.5](https://huggingface.co/baidu) | 0.3B/21B/300B | ernie_nothink |
| [Falcon/Falcon H1](https://huggingface.co/tiiuae) | 0.5B/1.5B/3B/7B/11B/34B/40B/180B | falcon/falcon_h1 |
| [Gemma/Gemma 2/CodeGemma](https://huggingface.co/google) | 2B/7B/9B/27B | gemma/gemma2 |
| [Gemma 3/Gemma 3n](https://huggingface.co/google) | 270M/1B/4B/6B/8B/12B/27B | gemma3/gemma3n |
@@ -292,12 +291,13 @@ Read technical notes:
| [GPT-2](https://huggingface.co/openai-community) | 0.1B/0.4B/0.8B/1.5B | - |
| [GPT-OSS](https://huggingface.co/openai) | 20B/120B | gpt_oss |
| [Granite 3-4](https://huggingface.co/ibm-granite) | 1B/2B/3B/7B/8B | granite3/granite4 |
| [Hunyuan (MT)](https://huggingface.co/tencent/) | 7B | hunyuan |
| [Hunyuan/Hunyuan1.5 (MT)](https://huggingface.co/tencent/) | 0.5B/1.8B/4B/7B/13B | hunyuan/hunyuan_small |
| [InternLM 2-3](https://huggingface.co/internlm) | 7B/8B/20B | intern2 |
| [InternVL 2.5-3.5](https://huggingface.co/OpenGVLab) | 1B/2B/4B/8B/14B/30B/38B/78B/241B | intern_vl |
| [InternLM/Intern-S1-mini](https://huggingface.co/internlm/) | 8B | intern_s1 |
| [Intern-S1-mini](https://huggingface.co/internlm/) | 8B | intern_s1 |
| [Kimi-VL](https://huggingface.co/moonshotai) | 16B | kimi_vl |
| [Ling 2.0 (mini/flash)](https://huggingface.co/inclusionAI) | 16B/100B | bailing_v2 |
| [LFM 2.5 (VL)](https://huggingface.co/LiquidAI) | 1.2B/1.6B | lfm2/lfm2_vl |
| [Llama](https://github.com/facebookresearch/llama) | 7B/13B/33B/65B | - |
| [Llama 2](https://huggingface.co/meta-llama) | 7B/13B/70B | llama2 |
| [Llama 3-3.3](https://huggingface.co/meta-llama) | 1B/3B/8B/70B | llama3 |
@@ -307,18 +307,17 @@ Read technical notes:
| [LLaVA-NeXT](https://huggingface.co/llava-hf) | 7B/8B/13B/34B/72B/110B | llava_next |
| [LLaVA-NeXT-Video](https://huggingface.co/llava-hf) | 7B/34B | llava_next_video |
| [MiMo](https://huggingface.co/XiaomiMiMo) | 7B/309B | mimo/mimo_v2 |
| [MiniCPM 1-4.1](https://huggingface.co/openbmb) | 0.5B/1B/2B/4B/8B | cpm/cpm3/cpm4 |
| [MiniCPM 4](https://huggingface.co/openbmb) | 0.5B/8B | cpm4 |
| [MiniCPM-o-2.6/MiniCPM-V-2.6](https://huggingface.co/openbmb) | 8B | minicpm_o/minicpm_v |
| [MiniMax-M1/MiniMax-M2](https://huggingface.co/MiniMaxAI/models) | 229B/456B | minimax1/minimax2 |
| [Ministral 3](https://huggingface.co/mistralai) | 3B/8B/14B | ministral3 |
| [Mistral/Mixtral](https://huggingface.co/mistralai) | 7B/8x7B/8x22B | mistral |
| [OLMo](https://huggingface.co/allenai) | 1B/7B | - |
| [PaliGemma/PaliGemma2](https://huggingface.co/google) | 3B/10B/28B | paligemma |
| [Phi-3/Phi-3.5](https://huggingface.co/microsoft) | 4B/14B | phi |
| [Phi-3-small](https://huggingface.co/microsoft) | 7B | phi_small |
| [Phi-4](https://huggingface.co/microsoft) | 14B | phi4 |
| [Phi-4-mini/Phi-4](https://huggingface.co/microsoft) | 3.8B/14B | phi4_mini/phi4 |
| [Pixtral](https://huggingface.co/mistralai) | 12B | pixtral |
| [Qwen (1-2.5) (Code/Math/MoE/QwQ)](https://huggingface.co/Qwen) | 0.5B/1.5B/3B/7B/14B/32B/72B/110B | qwen |
| [Qwen2 (Code/Math/MoE/QwQ)](https://huggingface.co/Qwen) | 0.5B/1.5B/3B/7B/14B/32B/72B/110B | qwen |
| [Qwen3 (MoE/Instruct/Thinking/Next)](https://huggingface.co/Qwen) | 0.6B/1.7B/4B/8B/14B/32B/80B/235B | qwen3/qwen3_nothink |
| [Qwen2-Audio](https://huggingface.co/Qwen) | 7B | qwen2_audio |
| [Qwen2.5-Omni](https://huggingface.co/Qwen) | 3B/7B | qwen2_omni |
@@ -327,8 +326,6 @@ Read technical notes:
| [Qwen3-VL](https://huggingface.co/Qwen) | 2B/4B/8B/30B/32B/235B | qwen3_vl |
| [Seed (OSS/Coder)](https://huggingface.co/ByteDance-Seed) | 8B/36B | seed_oss/seed_coder |
| [StarCoder 2](https://huggingface.co/bigcode) | 3B/7B/15B | - |
| [VibeThinker-1.5B](https://huggingface.co/WeiboAI) | 1.5B | qwen3 |
| [Yi/Yi-1.5 (Code)](https://huggingface.co/01-ai) | 1.5B/6B/9B/34B | yi |
| [Yuan 2](https://huggingface.co/IEITYuan) | 2B/51B/102B | yuan |
> [!NOTE]
@@ -514,12 +511,13 @@ huggingface-cli login
#### Install from Source
```bash
git clone --depth 1 https://github.com/hiyouga/LLaMA-Factory.git
cd LLaMA-Factory
pip install -e ".[metrics]"
git clone --depth 1 https://github.com/hiyouga/LlamaFactory.git
cd LlamaFactory
pip install -e .
pip install -r requirements/metrics.txt
```
Optional dependencies available: `metrics`, `deepspeed`. Install with: `pip install -e ".[metrics,deepspeed]"`
Optional dependencies available: `metrics`, `deepspeed`. Install with: `pip install -e . && pip install -r requirements/metrics.txt -r requirements/deepspeed.txt`
Additional dependencies for specific features are available in `examples/requirements/`.
@@ -577,36 +575,21 @@ To enable FlashAttention-2 on the Windows platform, please use the script from [
<details><summary>For Ascend NPU users</summary>
To install LLaMA Factory on Ascend NPU devices, please upgrade Python to version 3.10 or higher: `pip install -e . torch-npu==2.7.1`. Additionally, you need to install the **[Ascend CANN Toolkit and Kernels](https://www.hiascend.com/developer/download/community/result?module=cann)**. Please follow the [installation tutorial](https://www.hiascend.com/document/detail/en/CANNCommunityEdition/600alphaX/softwareinstall/instg/atlasdeploy_03_0031.html) or use the following commands:
To install LLaMA Factory on Ascend NPU devices, please upgrade Python to version 3.10 or higher: `pip install -r requirements/npu.txt`. Additionally, you need to install the **Ascend CANN Toolkit and Kernels**. Please follow the [installation tutorial](https://llamafactory.readthedocs.io/en/latest/advanced/npu_installation.html).
You can also download the pre-built Docker images:
```bash
# replace the url according to your CANN version and devices
# install CANN Toolkit
wget https://ascend-repo.obs.cn-east-2.myhuaweicloud.com/Milan-ASL/Milan-ASL%20V100R001C20SPC702/Ascend-cann-toolkit_8.0.0.alpha002_linux-"$(uname -i)".run
bash Ascend-cann-toolkit_8.0.0.alpha002_linux-"$(uname -i)".run --install
# Docker Hub
docker pull hiyouga/llamafactory:latest-npu-a2
docker pull hiyouga/llamafactory:latest-npu-a3
# install CANN Kernels
wget https://ascend-repo.obs.cn-east-2.myhuaweicloud.com/Milan-ASL/Milan-ASL%20V100R001C20SPC702/Ascend-cann-kernels-910b_8.0.0.alpha002_linux-"$(uname -i)".run
bash Ascend-cann-kernels-910b_8.0.0.alpha002_linux-"$(uname -i)".run --install
# set env variables
source /usr/local/Ascend/ascend-toolkit/set_env.sh
# quay.io
docker pull quay.io/ascend/llamafactory:latest-npu-a2
docker pull quay.io/ascend/llamafactory:latest-npu-a3
```
| Requirement | Minimum | Recommend |
| ------------ | ------- | -------------- |
| CANN | 8.0.RC1 | 8.0.0.alpha002 |
| torch | 2.1.0 | 2.7.1 |
| torch-npu | 2.1.0 | 2.7.1 |
| deepspeed | 0.13.2 | 0.13.2 |
| vllm-ascend | - | 0.7.3 |
Remember to use `ASCEND_RT_VISIBLE_DEVICES` instead of `CUDA_VISIBLE_DEVICES` to specify the device to use.
If you cannot infer model on NPU devices, try setting `do_sample: false` in the configurations.
Download the pre-built Docker images: [32GB](http://mirrors.cn-central-221.ovaijisuan.com/detail/130.html) | [64GB](http://mirrors.cn-central-221.ovaijisuan.com/detail/131.html)
#### Install BitsAndBytes
To use QLoRA based on bitsandbytes on Ascend NPU, please follow these 3 steps:

View File

@@ -94,7 +94,7 @@ https://github.com/user-attachments/assets/43b700c6-a178-41db-b1f8-8190a5d3fcfc
## 项目特色
- **多种模型**LLaMA、LLaVA、Mistral、Mixtral-MoE、Qwen、Qwen2-VL、DeepSeek、Yi、Gemma、ChatGLM、Phi 等等。
- **多种模型**LLaMA、LLaVA、Mistral、Mixtral-MoE、Qwen3、Qwen3-VL、DeepSeek、Gemma、GLM、Phi 等等。
- **集成方法**增量预训练、多模态指令监督微调、奖励模型训练、PPO 训练、DPO 训练、KTO 训练、ORPO 训练等等。
- **多种精度**16 比特全参数微调、冻结微调、LoRA 微调和基于 AQLM/AWQ/GPTQ/LLM.int8/HQQ/EETQ 的 2/3/4/5/6/8 比特 QLoRA 微调。
- **先进算法**[GaLore](https://github.com/jiaweizzhao/GaLore)、[BAdam](https://github.com/Ledzy/BAdam)、[APOLLO](https://github.com/zhuhanqing/APOLLO)、[Adam-mini](https://github.com/zyushun/Adam-mini)、[Muon](https://github.com/KellerJordan/Muon)、[OFT](https://github.com/huggingface/peft/tree/main/src/peft/tuners/oft)、DoRA、LongLoRA、LLaMA Pro、Mixture-of-Depths、LoRA+、LoftQ 和 PiSSA。
@@ -281,11 +281,10 @@ https://github.com/user-attachments/assets/43b700c6-a178-41db-b1f8-8190a5d3fcfc
| 模型名 | 参数量 | Template |
| ----------------------------------------------------------------- | -------------------------------- | -------------------- |
| [BLOOM/BLOOMZ](https://huggingface.co/bigscience) | 560M/1.1B/1.7B/3B/7.1B/176B | - |
| [Command R](https://huggingface.co/CohereForAI) | 35B/104B | cohere |
| [DeepSeek (LLM/Code/MoE)](https://huggingface.co/deepseek-ai) | 7B/16B/67B/236B | deepseek |
| [DeepSeek 3-3.2](https://huggingface.co/deepseek-ai) | 236B/671B | deepseek3 |
| [DeepSeek R1 (Distill)](https://huggingface.co/deepseek-ai) | 1.5B/7B/8B/14B/32B/70B/671B | deepseekr1 |
| [ERNIE-4.5](https://huggingface.co/baidu) | 0.3B/21B/300B | ernie/ernie_nothink |
| [ERNIE-4.5](https://huggingface.co/baidu) | 0.3B/21B/300B | ernie_nothink |
| [Falcon/Falcon H1](https://huggingface.co/tiiuae) | 0.5B/1.5B/3B/7B/11B/34B/40B/180B | falcon/falcon_h1 |
| [Gemma/Gemma 2/CodeGemma](https://huggingface.co/google) | 2B/7B/9B/27B | gemma/gemma2 |
| [Gemma 3/Gemma 3n](https://huggingface.co/google) | 270M/1B/4B/6B/8B/12B/27B | gemma3/gemma3n |
@@ -294,12 +293,13 @@ https://github.com/user-attachments/assets/43b700c6-a178-41db-b1f8-8190a5d3fcfc
| [GPT-2](https://huggingface.co/openai-community) | 0.1B/0.4B/0.8B/1.5B | - |
| [GPT-OSS](https://huggingface.co/openai) | 20B/120B | gpt_oss |
| [Granite 3-4](https://huggingface.co/ibm-granite) | 1B/2B/3B/7B/8B | granite3/granite4 |
| [Hunyuan (MT)](https://huggingface.co/tencent/) | 7B | hunyuan |
| [Hunyuan/Hunyuan1.5 (MT)](https://huggingface.co/tencent/) | 0.5B/1.8B/4B/7B/13B | hunyuan/hunyuan_small |
| [InternLM 2-3](https://huggingface.co/internlm) | 7B/8B/20B | intern2 |
| [InternVL 2.5-3.5](https://huggingface.co/OpenGVLab) | 1B/2B/4B/8B/14B/30B/38B/78B/241B | intern_vl |
| [InternLM/Intern-S1-mini](https://huggingface.co/internlm/) | 8B | intern_s1 |
| [Intern-S1-mini](https://huggingface.co/internlm/) | 8B | intern_s1 |
| [Kimi-VL](https://huggingface.co/moonshotai) | 16B | kimi_vl |
| [Ling 2.0 (mini/flash)](https://huggingface.co/inclusionAI) | 16B/100B | bailing_v2 |
| [LFM 2.5 (VL)](https://huggingface.co/LiquidAI) | 1.2B/1.6B | lfm2/lfm2_vl |
| [Llama](https://github.com/facebookresearch/llama) | 7B/13B/33B/65B | - |
| [Llama 2](https://huggingface.co/meta-llama) | 7B/13B/70B | llama2 |
| [Llama 3-3.3](https://huggingface.co/meta-llama) | 1B/3B/8B/70B | llama3 |
@@ -309,18 +309,17 @@ https://github.com/user-attachments/assets/43b700c6-a178-41db-b1f8-8190a5d3fcfc
| [LLaVA-NeXT](https://huggingface.co/llava-hf) | 7B/8B/13B/34B/72B/110B | llava_next |
| [LLaVA-NeXT-Video](https://huggingface.co/llava-hf) | 7B/34B | llava_next_video |
| [MiMo](https://huggingface.co/XiaomiMiMo) | 7B/309B | mimo/mimo_v2 |
| [MiniCPM 1-4.1](https://huggingface.co/openbmb) | 0.5B/1B/2B/4B/8B | cpm/cpm3/cpm4 |
| [MiniCPM 4](https://huggingface.co/openbmb) | 0.5B/8B | cpm4 |
| [MiniCPM-o-2.6/MiniCPM-V-2.6](https://huggingface.co/openbmb) | 8B | minicpm_o/minicpm_v |
| [MiniMax-M1/MiniMax-M2](https://huggingface.co/MiniMaxAI/models) | 229B/456B | minimax1/minimax2 |
| [Ministral 3](https://huggingface.co/mistralai) | 3B/8B/14B | ministral3 |
| [Mistral/Mixtral](https://huggingface.co/mistralai) | 7B/8x7B/8x22B | mistral |
| [OLMo](https://huggingface.co/allenai) | 1B/7B | - |
| [PaliGemma/PaliGemma2](https://huggingface.co/google) | 3B/10B/28B | paligemma |
| [Phi-3/Phi-3.5](https://huggingface.co/microsoft) | 4B/14B | phi |
| [Phi-3-small](https://huggingface.co/microsoft) | 7B | phi_small |
| [Phi-4](https://huggingface.co/microsoft) | 14B | phi4 |
| [Phi-4-mini/Phi-4](https://huggingface.co/microsoft) | 3.8B/14B | phi4_mini/phi4 |
| [Pixtral](https://huggingface.co/mistralai) | 12B | pixtral |
| [Qwen (1-2.5) (Code/Math/MoE/QwQ)](https://huggingface.co/Qwen) | 0.5B/1.5B/3B/7B/14B/32B/72B/110B | qwen |
| [Qwen2 (Code/Math/MoE/QwQ)](https://huggingface.co/Qwen) | 0.5B/1.5B/3B/7B/14B/32B/72B/110B | qwen |
| [Qwen3 (MoE/Instruct/Thinking/Next)](https://huggingface.co/Qwen) | 0.6B/1.7B/4B/8B/14B/32B/80B/235B | qwen3/qwen3_nothink |
| [Qwen2-Audio](https://huggingface.co/Qwen) | 7B | qwen2_audio |
| [Qwen2.5-Omni](https://huggingface.co/Qwen) | 3B/7B | qwen2_omni |
@@ -329,8 +328,6 @@ https://github.com/user-attachments/assets/43b700c6-a178-41db-b1f8-8190a5d3fcfc
| [Qwen3-VL](https://huggingface.co/Qwen) | 2B/4B/8B/30B/32B/235B | qwen3_vl |
| [Seed (OSS/Coder)](https://huggingface.co/ByteDance-Seed) | 8B/36B | seed_oss/seed_coder |
| [StarCoder 2](https://huggingface.co/bigcode) | 3B/7B/15B | - |
| [VibeThinker-1.5B](https://huggingface.co/WeiboAI) | 1.5B | qwen3 |
| [Yi/Yi-1.5 (Code)](https://huggingface.co/01-ai) | 1.5B/6B/9B/34B | yi |
| [Yuan 2](https://huggingface.co/IEITYuan) | 2B/51B/102B | yuan |
> [!NOTE]
@@ -516,12 +513,13 @@ huggingface-cli login
#### 从源码安装
```bash
git clone --depth 1 https://github.com/hiyouga/LLaMA-Factory.git
cd LLaMA-Factory
pip install -e ".[metrics]"
git clone --depth 1 https://github.com/hiyouga/LlamaFactory.git
cd LlamaFactory
pip install -e .
pip install -r requirements/metrics.txt
```
可选的额外依赖项:`metrics``deepspeed`。使用 `pip install -e ".[metrics,deepspeed]"` 安装。
可选的额外依赖项:`metrics``deepspeed`。使用 `pip install -e . && pip install -r requirements/metrics.txt -r requirements/deepspeed.txt` 安装。
其他可选依赖项请参考 `examples/requirements/` 目录下的文件。
@@ -579,36 +577,20 @@ pip install https://github.com/jllllll/bitsandbytes-windows-webui/releases/downl
<details><summary>昇腾 NPU 用户指南</summary>
在昇腾 NPU 设备上安装 LLaMA Factory 时,请升级 Python 到 3.10 及以上,并需要指定额外依赖项,使用 `pip install -e . torch-npu==2.7.1` 命令安装。此外,还需要安装 **[Ascend CANN Toolkit 与 Kernels](https://www.hiascend.com/developer/download/community/result?module=cann)**,安装方法请参考[安装教程](https://www.hiascend.com/document/detail/zh/CANNCommunityEdition/80RC2alpha002/quickstart/quickstart/quickstart_18_0004.html)或使用以下命令:
在昇腾 NPU 设备上安装 LLaMA Factory 时,请升级 Python 到 3.10 及以上,并需要指定额外依赖项,使用 `pip install -r requirements/npu.txt` 命令安装。此外,还需要安装 **Ascend CANN Toolkit 与 Kernels**,安装方法请参考[安装教程](https://llamafactory.readthedocs.io/zh-cn/latest/advanced/npu_installation.html)。
您可以直接下载预安装的最新docker镜像
```bash
# 请替换 URL 为 CANN 版本和设备型号对应的 URL
# 安装 CANN Toolkit
wget https://ascend-repo.obs.cn-east-2.myhuaweicloud.com/Milan-ASL/Milan-ASL%20V100R001C17SPC701/Ascend-cann-toolkit_8.0.RC1.alpha001_linux-"$(uname -i)".run
bash Ascend-cann-toolkit_8.0.RC1.alpha001_linux-"$(uname -i)".run --install
# Docker Hub
docker pull hiyouga/llamafactory:latest-npu-a2
docker pull hiyouga/llamafactory:latest-npu-a3
# 安装 CANN Kernels
wget https://ascend-repo.obs.cn-east-2.myhuaweicloud.com/Milan-ASL/Milan-ASL%20V100R001C17SPC701/Ascend-cann-kernels-910b_8.0.RC1.alpha001_linux.run
bash Ascend-cann-kernels-910b_8.0.RC1.alpha001_linux.run --install
# 设置环境变量
source /usr/local/Ascend/ascend-toolkit/set_env.sh
# quay.io
docker pull quay.io/ascend/llamafactory:latest-npu-a2
docker pull quay.io/ascend/llamafactory:latest-npu-a3
```
| 依赖项 | 至少 | 推荐 |
| ------------ | ------- | -------------- |
| CANN | 8.0.RC1 | 8.0.0.alpha002 |
| torch | 2.1.0 | 2.7.1 |
| torch-npu | 2.1.0 | 2.7.1 |
| deepspeed | 0.13.2 | 0.13.2 |
| vllm-ascend | - | 0.7.3 |
请使用 `ASCEND_RT_VISIBLE_DEVICES` 而非 `CUDA_VISIBLE_DEVICES` 来指定运算设备。
如果遇到无法正常推理的情况,请尝试设置 `do_sample: false`
下载预构建 Docker 镜像:[32GB](http://mirrors.cn-central-221.ovaijisuan.com/detail/130.html) | [64GB](http://mirrors.cn-central-221.ovaijisuan.com/detail/131.html)
#### 安装 BitsAndBytes
如果要在 Ascend NPU 上进行基于 bitsandbytes 的 QLoRA 量化微调,请执行如下步骤:

File diff suppressed because one or more lines are too long

View File

@@ -32,7 +32,8 @@ RUN pip config set global.index-url "${PIP_INDEX}" && \
COPY . /app
# Install LLaMA Factory
RUN pip install --no-cache-dir --no-build-isolation -e ".[metrics,deepspeed]"
RUN pip install --no-cache-dir --no-build-isolation -e . && \
pip install --no-cache-dir --no-build-isolation -r requirements/metrics.txt -r requirements/deepspeed.txt
# Rebuild flash attention
RUN if [ "${INSTALL_FLASHATTN}" == "true" ]; then \

View File

@@ -1,12 +1,13 @@
# NVIDIA official image (ubuntu-22.04 + cuda-12.4 + python-3.10)
# https://docs.nvidia.com/deeplearning/frameworks/pytorch-release-notes/rel-24-08.html
FROM nvcr.io/nvidia/pytorch:24.05-py3
# NVIDIA official image (ubuntu-24.04 + cuda-12.9.1 + python-3.12)
# https://docs.nvidia.com/deeplearning/frameworks/pytorch-release-notes/rel-25-06.html
FROM nvcr.io/nvidia/pytorch:25.06-py3
ENV DEBIAN_FRONTEND=noninteractive
ENV PIP_ROOT_USER_ACTION=ignore
ENV PYPI_MIRROR=https://mirrors.aliyun.com/pypi/simple/
ENV PYPI_TRUSTED_HOST=mirrors.aliyun.com
ENV APT_MIRROR=https://mirrors.tuna.tsinghua.edu.cn/ubuntu/
ENV PIP_CONSTRAINT=""
RUN pip install --upgrade pip setuptools wheel "hatchling>=1.18.0" editables --trusted-host ${PYPI_TRUSTED_HOST} --index-url ${PYPI_MIRROR}
@@ -14,20 +15,14 @@ RUN pip uninstall -y torch torchvision torch-tensorrt \
flash_attn transformer-engine \
cudf dask-cuda cugraph cugraph-service-server cuml raft-dask cugraph-dgl cugraph-pyg dask-cudf
RUN pip install torch==2.6.0 torchvision==0.21.0 torchaudio==2.6.0 --index-url https://download.pytorch.org/whl/cu124
RUN pip install torch==2.8.0 torchvision==0.23.0 torchaudio==2.8.0 --index-url https://download.pytorch.org/whl/cu129
RUN pip uninstall -y opencv opencv-python opencv-python-headless && \
rm -rf /usr/local/lib/python3.10/dist-packages/cv2/ && \
rm -rf /usr/local/lib/python3.12/dist-packages/cv2/ && \
pip install opencv-python-headless==4.11.0.86 --trusted-host ${PYPI_TRUSTED_HOST} --index-url ${PYPI_MIRROR}
RUN pip install "numpy==1.26.4" "optree>=0.13.0" "spacy==3.7.5" "weasel==0.4.1" \
transformer-engine[pytorch]==2.2.0 megatron-core==0.13.0 deepspeed==0.16.4 \
--trusted-host ${PYPI_TRUSTED_HOST} --index-url ${PYPI_MIRROR}
RUN pip install https://github.com/Dao-AILab/flash-attention/releases/download/v2.7.2.post1/flash_attn-2.7.2.post1+cu12torch2.6cxx11abiFALSE-cp310-cp310-linux_x86_64.whl
# RUN pip install vllm==0.8.4 \
# --trusted-host ${PYPI_TRUSTED_HOST} --index-url ${PYPI_MIRROR}
RUN pip install --trusted-host mirrors.aliyun.com --index-url ${PYPI_MIRROR} \
"megatron-core>=0.13.0,<0.14.0" "deepspeed==0.16.4"
WORKDIR /build
@@ -37,6 +32,8 @@ RUN pip uninstall -y apex && \
pip install -v --disable-pip-version-check --no-cache-dir --no-build-isolation \
--config-settings "--build-option=--cpp_ext --cuda_ext --parallel 32" ${apex_url}
RUN pip install --no-build-isolation transformer_engine[pytorch]
RUN rm -rf /build
WORKDIR /workspace
@@ -53,14 +50,17 @@ RUN apt-get update && apt-get install -y zip
RUN apt-get install -y openjdk-21-jdk
ENV JAVA_HOME /usr/lib/jvm/java-21-openjdk-amd64
# pip install LLaMA-Factory
ARG REPO_URL=https://github.com/hiyouga/LlamaFactory.git
ARG BRANCH=main
WORKDIR /app
# Copy the application into the image
COPY . /app
# Clone the repository
RUN git clone --depth 1 --branch ${BRANCH} ${REPO_URL} /app || \
git clone --depth 1 ${REPO_URL} /app
# Install LLaMA Factory
RUN pip install --no-cache-dir -e ".[metrics]" --no-build-isolation
RUN pip install --no-cache-dir -e . --no-build-isolation && \
pip install --no-cache-dir -r requirements/metrics.txt --no-build-isolation
RUN pip install "git+https://github.com/alibaba/roll.git#subdirectory=mcore_adapter"

View File

@@ -35,7 +35,8 @@ COPY . /app
# Install torch-npu
RUN pip uninstall -y torch torchvision torchaudio && \
pip install --no-cache-dir "torch==2.7.1" "torch-npu==2.7.1" "torchvision==0.22.1" "torchaudio==2.7.1" --index-url "${PYTORCH_INDEX}" && \
pip install --no-cache-dir -e ".[metrics]" --no-build-isolation
pip install --no-cache-dir -e . --no-build-isolation && \
pip install --no-cache-dir -r requirements/metrics.txt --no-build-isolation
# Set up volumes
# VOLUME [ "/root/.cache/huggingface", "/app/shared_data", "/app/output" ]

View File

@@ -34,7 +34,8 @@ COPY . /app
# Reinstall pytorch rocm and install LLaMA Factory
RUN pip uninstall -y torch torchvision torchaudio && \
pip install --no-cache-dir --no-build-isolation -e --pre ".[metrics,deepspeed]" --index-url "${PYTORCH_INDEX}"
pip install --no-cache-dir --no-build-isolation -e --pre . --index-url "${PYTORCH_INDEX}" && \
pip install --no-cache-dir --no-build-isolation -r requirements/metrics.txt -r requirements/deepspeed.txt --index-url "${PYTORCH_INDEX}"
# Rebuild flash attention
RUN if [ "${INSTALL_FLASHATTN}" == "true" ]; then \

View File

@@ -0,0 +1,38 @@
### model
model_name_or_path: Qwen/Qwen2.5-0.5B-Instruct
trust_remote_code: true
### method
stage: sft
do_train: true
finetuning_type: full
use_eaft_loss: true
### dataset
dataset: identity,alpaca_en_demo
template: qwen
cutoff_len: 2048
max_samples: 1000
overwrite_cache: true
preprocessing_num_workers: 16
dataloader_num_workers: 4
### output
output_dir: qwen2.5-0_5b/full/sft_eaft
logging_steps: 1
save_steps: 500
plot_loss: true
overwrite_output_dir: true
save_only_model: false
report_to: none # choices: [none, wandb, tensorboard, swanlab, mlflow]
### train
per_device_train_batch_size: 2
gradient_accumulation_steps: 8
learning_rate: 1.0e-5
num_train_epochs: 3.0
lr_scheduler_type: cosine
warmup_ratio: 0.1
bf16: true
ddp_timeout: 180000000

View File

@@ -5,6 +5,6 @@ infer_backend: ktransformers # choices: [huggingface, vllm, sglang, ktransforme
trust_remote_code: true
use_kt: true # use KTransformers as LoRA sft backend to inference
kt_optimize_rule: examples/kt_optimize_rules/DeepSeek-V2-Lite-Chat-sft-amx.yaml
kt_optimize_rule: examples/ktransformers/kt_optimize_rules/DeepSeek-V2-Lite-Chat-sft-amx.yaml
cpu_infer: 32
chunk_size: 8192

View File

@@ -1,9 +1,9 @@
model_name_or_path: opensourcerelease/DeepSeek-V3-bf16
template: deepseek
template: deepseek3
infer_backend: ktransformers # choices: [huggingface, vllm, sglang, ktransformers]
trust_remote_code: true
use_kt: true # use KTransformers as LoRA sft backend to inference
kt_optimize_rule: examples/kt_optimize_rules/DeepSeek-V3-Chat-sft-amx-multi-gpu.yaml
kt_optimize_rule: examples/ktransformers/kt_optimize_rules/DeepSeek-V3-Chat-sft-amx-multi-gpu.yaml
cpu_infer: 32
chunk_size: 8192

View File

@@ -1,10 +1,10 @@
model_name_or_path: opensourcerelease/DeepSeek-V3-bf16
adapter_name_or_path: saves/Kllama_deepseekV3
template: deepseek
template: deepseek3
infer_backend: ktransformers # choices: [huggingface, vllm, sglang, ktransformers]
trust_remote_code: true
use_kt: true # use KTransformers as LoRA sft backend to inference
kt_optimize_rule: examples/kt_optimize_rules/DeepSeek-V3-Chat-sft-amx-multi-gpu.yaml
kt_optimize_rule: examples/ktransformers/kt_optimize_rules/DeepSeek-V3-Chat-sft-amx-multi-gpu.yaml
cpu_infer: 32
chunk_size: 8192

View File

@@ -5,6 +5,6 @@ infer_backend: ktransformers # choices: [huggingface, vllm, sglang, ktransforme
trust_remote_code: true
use_kt: true # use KTransformers as LoRA sft backend to inference
kt_optimize_rule: examples/kt_optimize_rules/Qwen3Moe-sft-amx.yaml
kt_optimize_rule: examples/ktransformers/kt_optimize_rules/Qwen3Moe-sft-amx.yaml
cpu_infer: 32
chunk_size: 8192

View File

@@ -10,7 +10,7 @@ lora_rank: 8
lora_target: all
### dataset
dataset: identity
dataset: identity, alpaca_en_demo
template: deepseek
cutoff_len: 2048
max_samples: 100000
@@ -40,7 +40,7 @@ resume_from_checkpoint: null
### ktransformers
use_kt: true # use KTransformers as LoRA sft backend
kt_optimize_rule: examples/kt_optimize_rules/DeepSeek-V2-Lite-Chat-sft-amx.yaml
kt_optimize_rule: examples/ktransformers/kt_optimize_rules/DeepSeek-V2-Lite-Chat-sft-amx.yaml
cpu_infer: 32
chunk_size: 8192

View File

@@ -10,8 +10,8 @@ lora_rank: 8
lora_target: all
### dataset
dataset: identity
template: deepseek
dataset: identity, alpaca_en_demo
template: deepseek3
cutoff_len: 2048
max_samples: 100000
overwrite_cache: true
@@ -40,7 +40,7 @@ resume_from_checkpoint: null
### ktransformers
use_kt: true # use KTransformers as LoRA sft backend
kt_optimize_rule: examples/kt_optimize_rules/DeepSeek-V3-Chat-sft-amx-multi-gpu.yaml
kt_optimize_rule: examples/ktransformers/kt_optimize_rules/DeepSeek-V3-Chat-sft-amx-multi-gpu.yaml
cpu_infer: 32
chunk_size: 8192

View File

@@ -40,7 +40,7 @@ resume_from_checkpoint: null
### ktransformers
use_kt: true # use KTransformers as LoRA sft backend
kt_optimize_rule: examples/kt_optimize_rules/Qwen3Moe-sft-amx.yaml
kt_optimize_rule: examples/ktransformers/kt_optimize_rules/Qwen3Moe-sft-amx.yaml
cpu_infer: 32
chunk_size: 8192

View File

@@ -0,0 +1,34 @@
model: Qwen/Qwen3-0.6B
trust_remote_code: true
model_class: llm
template: qwen3_nothink
kernel_config:
name: auto
include_kernels: auto # choice: null/true/false/auto/kernel_id1,kernel_id2,kernel_id3, default is null
quant_config: null
dist_config:
name: fsdp2
dcp_path: null # /mnt/f/pretrain_models/Qwen3-0.6B-dcp
init_config:
name: init_on_meta
### data
train_dataset: data/v1_sft_demo.yaml
### training
output_dir: outputs/test_fsdp2
micro_batch_size: 1
global_batch_size: 1
cutoff_len: 2048
learning_rate: 1.0e-4
bf16: false
max_steps: 10
### sample
sample_backend: hf
max_new_tokens: 128

View File

@@ -30,7 +30,6 @@ classifiers = [
"License :: OSI Approved :: Apache Software License",
"Operating System :: OS Independent",
"Programming Language :: Python :: 3",
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.11",
"Programming Language :: Python :: 3.12",
"Programming Language :: Python :: 3.13",
@@ -63,7 +62,7 @@ dependencies = [
"hf-transfer",
"safetensors",
# python
"av",
"av>=10.0.0,<=16.0.0",
"fire",
"omegaconf",
"packaging",
@@ -73,14 +72,9 @@ dependencies = [
# api
"uvicorn",
"fastapi",
"sse-starlette"
"sse-starlette",
]
[project.optional-dependencies]
dev = ["pre-commit", "ruff", "pytest", "build"]
metrics = ["nltk", "jieba", "rouge-chinese"]
deepspeed = ["deepspeed>=0.10.0,<=0.16.9"]
[project.scripts]
llamafactory-cli = "llamafactory.cli:main"
lmf = "llamafactory.cli:main"

View File

@@ -0,0 +1 @@
deepspeed>=0.10.0,<=0.16.9

4
requirements/dev.txt Normal file
View File

@@ -0,0 +1,4 @@
pre-commit
ruff
pytest
build

3
requirements/metrics.txt Normal file
View File

@@ -0,0 +1,3 @@
nltk
jieba
rouge-chinese

4
requirements/npu.txt Normal file
View File

@@ -0,0 +1,4 @@
torch==2.7.1
torch-npu==2.7.1
torchvision==0.22.1
torchaudio==2.7.1

View File

@@ -0,0 +1,32 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from transformers import AutoTokenizer, Qwen3Config, Qwen3ForCausalLM
if __name__ == "__main__":
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-4B-Instruct-2507")
config = Qwen3Config(
hidden_size=1408,
image_size=336,
intermediate_size=5632,
num_attention_heads=16,
num_hidden_layers=4,
vision_output_dim=4096,
)
model = Qwen3ForCausalLM.from_config(config)
model.save_pretrained("tiny-qwen3")
tokenizer.save_pretrained("tiny-qwen3")
model.push_to_hub("llamafactory/tiny-random-qwen3")
tokenizer.push_to_hub("llamafactory/tiny-random-qwen3")

View File

@@ -28,7 +28,7 @@ try:
jieba.setLogLevel(logging.CRITICAL)
jieba.initialize()
except ImportError:
print("Please install llamafactory with `pip install -e .[metrics]`.")
print("Please install llamafactory with `pip install -r requirements/metrics.txt`.")
raise

55
scripts/hf2dcp.py Normal file
View File

@@ -0,0 +1,55 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Convert a HuggingFace model to DCP checkpoint format.
Usage:
python scripts/hf2dcp.py convert --hf_path=/path/to/hf --dcp_path=/path/to/dcp
Arguments:
hf_path: Path to the HuggingFace model directory.
dcp_path: Output path (directory) for DCP checkpoint.
"""
import fire
import torch
import torch.distributed.checkpoint as dcp
from transformers import AutoModelForCausalLM
def convert(hf_path: str, dcp_path: str) -> None:
"""Convert HF model weights to DCP.
Args:
hf_path: HuggingFace model directory.
dcp_path: Output path (directory) for DCP checkpoint.
"""
if not hf_path or not dcp_path:
raise ValueError("Both 'hf_path' and 'dcp_path' are required.")
print(f"Loading HF model from {hf_path}...")
model = AutoModelForCausalLM.from_pretrained(hf_path, device_map="cpu", torch_dtype=torch.bfloat16)
print(f"Saving to DCP format at {dcp_path}...")
dcp.save(model.state_dict(), checkpoint_id=dcp_path)
print("Done!")
def help() -> None:
"""Show help message."""
print(__doc__)
if __name__ == "__main__":
fire.Fire({"convert": convert, "help": help, "--convert": convert})

View File

@@ -65,11 +65,17 @@ def merge_dataset(
if not data_args.streaming:
logger.warning_rank0_once("We recommend using `mix_strategy=concat` in non-streaming mode.")
strategy_map: str = {
"interleave_under": "first_exhausted",
"interleave_over": "all_exhausted",
"interleave_once": "all_exhausted_without_replacement",
}[data_args.mix_strategy]
return interleave_datasets(
datasets=all_datasets,
probabilities=data_args.interleave_probs,
seed=seed,
stopping_strategy="first_exhausted" if data_args.mix_strategy.endswith("under") else "all_exhausted",
stopping_strategy=strategy_map, # type: ignore
)
else:

View File

@@ -2092,6 +2092,73 @@ class VideoLlavaPlugin(BasePlugin):
return messages
@dataclass
class LFMVLPlugin(BasePlugin):
r"""Plugin for LFM2.5-VL vision-language models.
LFM2.5-VL uses dynamic image token counts based on image resolution.
The image processor returns spatial_shapes tensor with [height, width] grid dimensions.
Token count per image = (spatial_h * spatial_w) / (downsample_factor^2)
"""
@override
def _get_mm_inputs(
self,
images: list["ImageInput"],
videos: list["VideoInput"],
audios: list["AudioInput"],
processor: "MMProcessor",
) -> dict[str, "torch.Tensor"]:
image_processor: BaseImageProcessor = getattr(processor, "image_processor", None)
mm_inputs = {}
if len(images) != 0:
images = self._regularize_images(
images,
image_max_pixels=getattr(processor, "image_max_pixels", 768 * 768),
image_min_pixels=getattr(processor, "image_min_pixels", 32 * 32),
)["images"]
mm_inputs.update(image_processor(images, return_tensors="pt"))
return mm_inputs
@override
def process_messages(
self,
messages: list[dict[str, str]],
images: list["ImageInput"],
videos: list["VideoInput"],
audios: list["AudioInput"],
processor: Optional["MMProcessor"],
) -> list[dict[str, str]]:
self._validate_input(processor, images, videos, audios)
self._validate_messages(messages, images, videos, audios)
num_image_tokens = 0
messages = deepcopy(messages)
image_processor: BaseImageProcessor = getattr(processor, "image_processor")
downsample_factor: int = getattr(image_processor, "downsample_factor", 2)
if self.expand_mm_tokens and len(images) > 0:
mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
spatial_shapes = mm_inputs.get("spatial_shapes", [])
else:
spatial_shapes = []
for message in messages:
content = message["content"]
while IMAGE_PLACEHOLDER in content:
if self.expand_mm_tokens and len(spatial_shapes) > num_image_tokens:
h, w = spatial_shapes[num_image_tokens].tolist()
image_seqlen = (h * w) // (downsample_factor * downsample_factor)
else:
image_seqlen = 1
content = content.replace(IMAGE_PLACEHOLDER, "{{image}}" * image_seqlen, 1)
num_image_tokens += 1
message["content"] = content.replace("{{image}}", self.image_token)
return messages
PLUGINS = {
"base": BasePlugin,
"ernie_vl": ErnieVLPlugin,
@@ -2104,6 +2171,7 @@ PLUGINS = {
"llava": LlavaPlugin,
"llava_next": LlavaNextPlugin,
"llava_next_video": LlavaNextVideoPlugin,
"lfm2_vl": LFMVLPlugin,
"minicpm_v": MiniCPMVPlugin,
"mllama": MllamaPlugin,
"paligemma": PaliGemmaPlugin,

View File

@@ -649,42 +649,6 @@ register_template(
)
register_template(
name="aquila",
format_user=StringFormatter(slots=["Human: {{content}}###Assistant:"]),
format_assistant=StringFormatter(slots=["{{content}}###"]),
format_system=StringFormatter(slots=["System: {{content}}###"]),
default_system=(
"A chat between a curious human and an artificial intelligence assistant. "
"The assistant gives helpful, detailed, and polite answers to the human's questions."
),
stop_words=["</s>"],
)
register_template(
name="atom",
format_user=StringFormatter(
slots=[{"bos_token"}, "Human: {{content}}\n", {"eos_token"}, {"bos_token"}, "Assistant:"]
),
format_assistant=StringFormatter(slots=["{{content}}\n", {"eos_token"}]),
)
register_template(
name="baichuan",
format_user=StringFormatter(slots=[{"token": "<reserved_102>"}, "{{content}}", {"token": "<reserved_103>"}]),
efficient_eos=True,
)
register_template(
name="baichuan2",
format_user=StringFormatter(slots=["<reserved_106>{{content}}<reserved_107>"]),
efficient_eos=True,
)
register_template(
name="bailing",
format_user=StringFormatter(slots=["<role>HUMAN</role>{{content}}<role>ASSISTANT</role>"]),
@@ -712,20 +676,6 @@ register_template(
)
register_template(
name="belle",
format_user=StringFormatter(slots=["Human: {{content}}\n\nBelle: "]),
format_assistant=StringFormatter(slots=["{{content}}", {"eos_token"}, "\n\n"]),
format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
)
register_template(
name="bluelm",
format_user=StringFormatter(slots=[{"token": "[|Human|]:"}, "{{content}}", {"token": "[|AI|]:"}]),
)
register_template(
name="breeze",
format_user=StringFormatter(slots=["[INST] {{content}} [/INST] "]),
@@ -734,14 +684,6 @@ register_template(
)
register_template(
name="chatglm2",
format_user=StringFormatter(slots=["[Round {{idx}}]\n\n问:{{content}}\n\n答:"]),
format_prefix=EmptyFormatter(slots=[{"token": "[gMASK]"}, {"token": "sop"}]),
efficient_eos=True,
)
register_template(
name="chatglm3",
format_user=StringFormatter(slots=[{"token": "<|user|>"}, "\n", "{{content}}", {"token": "<|assistant|>"}]),
@@ -784,29 +726,6 @@ register_template(
)
register_template(
name="codegeex2",
format_prefix=EmptyFormatter(slots=[{"token": "[gMASK]"}, {"token": "sop"}]),
)
register_template(
name="codegeex4",
format_user=StringFormatter(slots=["<|user|>\n{{content}}<|assistant|>\n"]),
format_system=StringFormatter(slots=["<|system|>\n{{content}}"]),
format_function=FunctionFormatter(slots=["{{content}}"], tool_format="glm4"),
format_observation=StringFormatter(slots=["<|observation|>\n{{content}}<|assistant|>\n"]),
format_tools=ToolFormatter(tool_format="glm4"),
format_prefix=EmptyFormatter(slots=["[gMASK]<sop>"]),
default_system=(
"你是一位智能编程助手你叫CodeGeeX。你会为用户回答关于编程、代码、计算机方面的任何问题"
"并提供格式规范、可以执行、准确安全的代码,并在必要时提供详细的解释。"
),
stop_words=["<|user|>", "<|observation|>"],
efficient_eos=True,
)
register_template(
name="cohere",
format_user=StringFormatter(
@@ -822,25 +741,6 @@ register_template(
)
register_template(
name="cpm",
format_user=StringFormatter(slots=["<用户>{{content}}<AI>"]),
format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
)
# copied from chatml template
register_template(
name="cpm3",
format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
format_observation=StringFormatter(slots=["<|im_start|>tool\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
stop_words=["<|im_end|>"],
)
# copied from chatml template
register_template(
name="cpm4",
@@ -1239,19 +1139,12 @@ register_template(
register_template(
name="intern",
format_user=StringFormatter(slots=["<|User|>:{{content}}\n<|Bot|>:"]),
format_assistant=StringFormatter(slots=["{{content}}<eoa>\n"]),
format_system=StringFormatter(slots=["<|System|>:{{content}}\n"]),
format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
default_system=(
"You are an AI assistant whose name is InternLM (书生·浦语).\n"
"- InternLM (书生·浦语) is a conversational language model that is developed by Shanghai AI Laboratory "
"(上海人工智能实验室). It is designed to be helpful, honest, and harmless.\n"
"- InternLM (书生·浦语) can understand and communicate fluently in the language "
"chosen by the user such as English and 中文."
),
stop_words=["<eoa>"],
name="hunyuan_small",
format_user=StringFormatter(slots=["<hy_User>{{content}}<hy_place▁holder▁no▁8>"]),
format_assistant=StringFormatter(slots=["{{content}}<hy_place▁holder▁no▁2>"]),
format_system=StringFormatter(slots=["{{content}}<hy_place▁holder▁no▁3>"]),
format_prefix=EmptyFormatter(slots=["<hy_begin▁of▁sentence>"]),
stop_words=["<hy_place▁holder▁no▁2>"],
)
@@ -1330,6 +1223,47 @@ register_template(
)
register_template(
name="lfm2",
format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
format_function=FunctionFormatter(slots=["{{content}}<|im_end|>\n"], tool_format="lfm2"),
format_observation=StringFormatter(
slots=[
"<|im_start|>tool\n<|tool_response_start|>{{content}}<|tool_response_end|><|im_end|>\n"
"<|im_start|>assistant\n"
]
),
format_tools=ToolFormatter(tool_format="lfm2"),
default_system="You are a helpful AI assistant.",
stop_words=["<|im_end|>"],
tool_call_words=("<|tool_call_start|>", "<|tool_call_end|>"),
replace_eos=True,
)
register_template(
name="lfm2_vl",
format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
format_function=FunctionFormatter(slots=["{{content}}<|im_end|>\n"], tool_format="lfm2"),
format_observation=StringFormatter(
slots=[
"<|im_start|>tool\n<|tool_response_start|>{{content}}<|tool_response_end|><|im_end|>\n"
"<|im_start|>assistant\n"
]
),
format_tools=ToolFormatter(tool_format="lfm2"),
default_system="You are a helpful multimodal assistant by Liquid AI.",
stop_words=["<|im_end|>"],
tool_call_words=("<|tool_call_start|>", "<|tool_call_end|>"),
replace_eos=True,
mm_plugin=get_mm_plugin(name="lfm2_vl", image_token="<image>"),
)
register_template(
name="llama2",
format_user=StringFormatter(slots=[{"bos_token"}, "[INST] {{content}} [/INST]"]),
@@ -1576,23 +1510,6 @@ register_template(
)
# copied from chatml template
register_template(
name="marco",
format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
format_observation=StringFormatter(slots=["<|im_start|>tool\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
default_system=(
"你是一个经过良好训练的AI助手你的名字是Marco-o1."
"由阿里国际数字商业集团的AI Business创造.\n## 重要!!!!!\n"
"当你回答问题时,你的思考应该在<Thought>内完成,<Output>内输出你的结果。\n"
"<Thought>应该尽可能是英文但是有2个特例一个是对原文中的引用另一个是是数学应该使用markdown格式<Output>内的输出需要遵循用户输入的语言。\n"
),
stop_words=["<|im_end|>"],
)
# copied from qwen template
register_template(
name="mimo",
@@ -1804,13 +1721,6 @@ register_template(
)
register_template(
name="orion",
format_user=StringFormatter(slots=["Human: {{content}}\n\nAssistant: ", {"eos_token"}]),
format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
)
register_template(
name="paligemma",
format_user=StringFormatter(slots=["{{content}}\n"]),
@@ -1869,6 +1779,17 @@ register_template(
)
register_template(
name="phi4_mini",
format_user=StringFormatter(slots=["<|user|>{{content}}<|end|><|assistant|>"]),
format_assistant=StringFormatter(slots=["{{content}}<|end|>"]),
format_system=StringFormatter(slots=["<|system|>{{content}}<|end|>"]),
format_tools=StringFormatter(slots=["<|tool|>{{content}}<|/tool|>"]),
stop_words=["<|end|>"],
replace_eos=True,
)
# copied from ministral template
register_template(
name="pixtral",
@@ -2104,41 +2025,6 @@ register_template(
)
# copied from llama3 template
register_template(
name="skywork_o1",
format_user=StringFormatter(
slots=[
(
"<|start_header_id|>user<|end_header_id|>\n\n{{content}}<|eot_id|>"
"<|start_header_id|>assistant<|end_header_id|>\n\n"
)
]
),
format_assistant=StringFormatter(slots=["{{content}}<|eot_id|>"]),
format_system=StringFormatter(slots=["<|start_header_id|>system<|end_header_id|>\n\n{{content}}<|eot_id|>"]),
format_function=FunctionFormatter(slots=["{{content}}<|eot_id|>"], tool_format="llama3"),
format_observation=StringFormatter(
slots=[
(
"<|start_header_id|>ipython<|end_header_id|>\n\n{{content}}<|eot_id|>"
"<|start_header_id|>assistant<|end_header_id|>\n\n"
)
]
),
format_tools=ToolFormatter(tool_format="llama3"),
format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
default_system=(
"You are Skywork-o1, a thinking model developed by Skywork AI, specializing in solving complex problems "
"involving mathematics, coding, and logical reasoning through deep thought. When faced with a user's request, "
"you first engage in a lengthy and in-depth thinking process to explore possible solutions to the problem. "
"After completing your thoughts, you then provide a detailed explanation of the solution process "
"in your response."
),
stop_words=["<|eot_id|>", "<|eom_id|>"],
)
register_template(
name="smollm",
format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
@@ -2175,13 +2061,6 @@ register_template(
)
register_template(
name="telechat",
format_user=StringFormatter(slots=["<_user>{{content}}<_bot>"]),
format_system=StringFormatter(slots=["<_system>{{content}}<_end>"]),
)
register_template(
name="telechat2",
format_user=StringFormatter(slots=["<_user>{{content}}<_bot>"]),
@@ -2225,32 +2104,6 @@ register_template(
)
register_template(
name="xverse",
format_user=StringFormatter(slots=["Human: {{content}}\n\nAssistant: "]),
)
register_template(
name="yayi",
format_user=StringFormatter(slots=[{"token": "<|Human|>"}, ":\n{{content}}\n\n", {"token": "<|YaYi|>"}, ":"]),
format_assistant=StringFormatter(slots=["{{content}}\n\n"]),
format_system=StringFormatter(slots=[{"token": "<|System|>"}, ":\n{{content}}\n\n"]),
default_system=(
"You are a helpful, respectful and honest assistant named YaYi "
"developed by Beijing Wenge Technology Co.,Ltd. "
"Always answer as helpfully as possible, while being safe. "
"Your answers should not include any harmful, unethical, "
"racist, sexist, toxic, dangerous, or illegal content. "
"Please ensure that your responses are socially unbiased and positive in nature.\n\n"
"If a question does not make any sense, or is not factually coherent, "
"explain why instead of answering something not correct. "
"If you don't know the answer to a question, please don't share false information."
),
stop_words=["<|End|>"],
)
# copied from chatml template
register_template(
name="yi",
@@ -2278,6 +2131,21 @@ register_template(
)
register_template(
name="youtu",
format_user=StringFormatter(slots=["<|User|>{{content}}<|Assistant|>"]),
format_assistant=StringFormatter(slots=["{{content}}<|end_of_text|>"]),
format_system=StringFormatter(slots=["{{content}}"]),
format_function=FunctionFormatter(slots=["{{content}}"], tool_format="default"),
format_observation=StringFormatter(slots=["<tool_response>\n{{content}}\n</tool_response><|Assistant|>"]),
format_tools=ToolFormatter(tool_format="default"),
format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
stop_words=["<|end_of_text|>"],
replace_eos=True,
template_class=ReasoningTemplate,
)
register_template(
name="yuan",
format_user=StringFormatter(slots=["{{content}}", {"token": "<sep>"}]),
@@ -2292,10 +2160,3 @@ register_template(
format_system=StringFormatter(slots=["<|system|>\n{{content}}", {"eos_token"}]),
default_system="You are Zephyr, a helpful assistant.",
)
register_template(
name="ziya",
format_user=StringFormatter(slots=["<human>:{{content}}\n<bot>:"]),
format_assistant=StringFormatter(slots=["{{content}}\n"]),
)

View File

@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import ast
import json
import re
from abc import ABC, abstractmethod
@@ -101,6 +102,8 @@ LING_TOOL_PROMPT = (
""""arguments": <args-json-object>}}\n</tool_call>"""
)
LFM2_TOOL_PROMPT = "List of tools: <|tool_list_start|>{tool_text}<|tool_list_end|>"
@dataclass
class ToolUtils(ABC):
@@ -546,10 +549,115 @@ class LingToolUtils(QwenToolUtils):
return LING_TOOL_PROMPT.format(tool_text=tool_text) + "\n" + "detailed thinking off"
class LFM2ToolUtils(ToolUtils):
r"""LFM2.5 tool using template with Pythonic function call syntax."""
@override
@staticmethod
def tool_formatter(tools: list[dict[str, Any]]) -> str:
tool_list = []
for tool in tools:
tool = tool.get("function", tool) if tool.get("type") == "function" else tool
tool_list.append(tool)
return LFM2_TOOL_PROMPT.format(tool_text=json.dumps(tool_list, ensure_ascii=False))
@override
@staticmethod
def function_formatter(functions: list["FunctionCall"]) -> str:
calls = []
for name, args_json in functions:
args = json.loads(args_json)
kwargs_parts = []
for key, value in args.items():
if isinstance(value, str):
kwargs_parts.append(f'{key}="{value}"')
else:
kwargs_parts.append(f"{key}={json.dumps(value, ensure_ascii=False)}")
calls.append(f"{name}({', '.join(kwargs_parts)})")
return f"<|tool_call_start|>[{', '.join(calls)}]<|tool_call_end|>"
@staticmethod
def _ast_to_value(node: ast.AST) -> Any:
"""Convert an AST node to a Python value, handling JSON-style booleans/null."""
# Handle JSON-style true/false/null as Name nodes
if isinstance(node, ast.Name):
if node.id == "true":
return True
elif node.id == "false":
return False
elif node.id == "null":
return None
else:
raise ValueError(f"Unknown identifier: {node.id}")
# Use literal_eval for other cases (strings, numbers, lists, dicts)
return ast.literal_eval(node)
@override
@staticmethod
def tool_extractor(content: str) -> Union[str, list["FunctionCall"]]:
# Extract content between tool call markers
start_marker = "<|tool_call_start|>"
end_marker = "<|tool_call_end|>"
start_idx = content.find(start_marker)
if start_idx == -1:
return content
end_idx = content.find(end_marker, start_idx)
if end_idx == -1:
return content
tool_call_str = content[start_idx + len(start_marker) : end_idx].strip()
# Parse Pythonic function call syntax using AST
try:
tree = ast.parse(tool_call_str, mode="eval")
except SyntaxError:
return content
# Handle both single call and list of calls
if isinstance(tree.body, ast.List):
call_nodes = tree.body.elts
elif isinstance(tree.body, ast.Call):
call_nodes = [tree.body]
else:
return content
results = []
for node in call_nodes:
if not isinstance(node, ast.Call):
return content
# Extract function name
if isinstance(node.func, ast.Name):
func_name = node.func.id
else:
return content
# Extract keyword arguments
args_dict = {}
for keyword in node.keywords:
key = keyword.arg
try:
value = LFM2ToolUtils._ast_to_value(keyword.value)
except (ValueError, SyntaxError):
return content
args_dict[key] = value
results.append(FunctionCall(func_name, json.dumps(args_dict, ensure_ascii=False)))
return results if results else content
TOOLS = {
"default": DefaultToolUtils(),
"glm4": GLM4ToolUtils(),
"llama3": Llama3ToolUtils(),
"lfm2": LFM2ToolUtils(),
"minimax1": MiniMaxM1ToolUtils(),
"minimax2": MiniMaxM2ToolUtils(),
"mistral": MistralToolUtils(),

View File

@@ -57,6 +57,7 @@ LLAMABOARD_CONFIG = "llamaboard_config.yaml"
MCA_SUPPORTED_MODELS = {
"deepseek_v3",
"glm4_moe",
"llama",
"mistral",
"mixtral",
@@ -181,51 +182,6 @@ register_model_group(
)
register_model_group(
models={
"Baichuan-7B-Base": {
DownloadSource.DEFAULT: "baichuan-inc/Baichuan-7B",
DownloadSource.MODELSCOPE: "baichuan-inc/baichuan-7B",
},
"Baichuan-13B-Base": {
DownloadSource.DEFAULT: "baichuan-inc/Baichuan-13B-Base",
DownloadSource.MODELSCOPE: "baichuan-inc/Baichuan-13B-Base",
},
"Baichuan-13B-Chat": {
DownloadSource.DEFAULT: "baichuan-inc/Baichuan-13B-Chat",
DownloadSource.MODELSCOPE: "baichuan-inc/Baichuan-13B-Chat",
},
},
template="baichuan",
)
register_model_group(
models={
"Baichuan2-7B-Base": {
DownloadSource.DEFAULT: "baichuan-inc/Baichuan2-7B-Base",
DownloadSource.MODELSCOPE: "baichuan-inc/Baichuan2-7B-Base",
},
"Baichuan2-13B-Base": {
DownloadSource.DEFAULT: "baichuan-inc/Baichuan2-13B-Base",
DownloadSource.MODELSCOPE: "baichuan-inc/Baichuan2-13B-Base",
DownloadSource.OPENMIND: "Baichuan/Baichuan2_13b_base_pt",
},
"Baichuan2-7B-Chat": {
DownloadSource.DEFAULT: "baichuan-inc/Baichuan2-7B-Chat",
DownloadSource.MODELSCOPE: "baichuan-inc/Baichuan2-7B-Chat",
DownloadSource.OPENMIND: "Baichuan/Baichuan2_7b_chat_pt",
},
"Baichuan2-13B-Chat": {
DownloadSource.DEFAULT: "baichuan-inc/Baichuan2-13B-Chat",
DownloadSource.MODELSCOPE: "baichuan-inc/Baichuan2-13B-Chat",
DownloadSource.OPENMIND: "Baichuan/Baichuan2_13b_chat_pt",
},
},
template="baichuan2",
)
register_model_group(
models={
"BLOOM-560M": {
@@ -262,21 +218,6 @@ register_model_group(
)
register_model_group(
models={
"BlueLM-7B-Base": {
DownloadSource.DEFAULT: "vivo-ai/BlueLM-7B-Base",
DownloadSource.MODELSCOPE: "vivo-ai/BlueLM-7B-Base",
},
"BlueLM-7B-Chat": {
DownloadSource.DEFAULT: "vivo-ai/BlueLM-7B-Chat",
DownloadSource.MODELSCOPE: "vivo-ai/BlueLM-7B-Chat",
},
},
template="bluelm",
)
register_model_group(
models={
"Breeze-7B": {
@@ -290,17 +231,6 @@ register_model_group(
)
register_model_group(
models={
"ChatGLM2-6B-Chat": {
DownloadSource.DEFAULT: "zai-org/chatglm2-6b",
DownloadSource.MODELSCOPE: "ZhipuAI/chatglm2-6b",
}
},
template="chatglm2",
)
register_model_group(
models={
"ChatGLM3-6B-Base": {
@@ -347,17 +277,6 @@ register_model_group(
)
register_model_group(
models={
"CodeGeeX4-9B-Chat": {
DownloadSource.DEFAULT: "zai-org/codegeex4-all-9b",
DownloadSource.MODELSCOPE: "ZhipuAI/codegeex4-all-9b",
},
},
template="codegeex4",
)
register_model_group(
models={
"CodeGemma-7B": {
@@ -642,15 +561,15 @@ register_model_group(
register_model_group(
models={
"ERNIE-4.5-0.3B-PT": {
"ERNIE-4.5-0.3B-Instruct": {
DownloadSource.DEFAULT: "baidu/ERNIE-4.5-0.3B-PT",
DownloadSource.MODELSCOPE: "PaddlePaddle/ERNIE-4.5-0.3B-PT",
},
"ERNIE-4.5-21B-A3B-PT": {
"ERNIE-4.5-21B-A3B-Instruct": {
DownloadSource.DEFAULT: "baidu/ERNIE-4.5-21B-A3B-PT",
DownloadSource.MODELSCOPE: "PaddlePaddle/ERNIE-4.5-21B-A3B-PT",
},
"ERNIE-4.5-300B-A47B-PT": {
"ERNIE-4.5-300B-A47B-Instruct": {
DownloadSource.DEFAULT: "baidu/ERNIE-4.5-300B-A47B-PT",
DownloadSource.MODELSCOPE: "PaddlePaddle/ERNIE-4.5-300B-A47B-PT",
},
@@ -661,7 +580,7 @@ register_model_group(
register_model_group(
models={
"ERNIE-4.5-VL-28B-A3B-PT": {
"ERNIE-4.5-VL-28B-A3B-Instruct": {
DownloadSource.DEFAULT: "baidu/ERNIE-4.5-VL-28B-A3B-PT",
DownloadSource.MODELSCOPE: "PaddlePaddle/ERNIE-4.5-VL-28B-A3B-PT",
},
@@ -669,7 +588,7 @@ register_model_group(
DownloadSource.DEFAULT: "baidu/ERNIE-4.5-VL-28B-A3B-Thinking",
DownloadSource.MODELSCOPE: "PaddlePaddle/ERNIE-4.5-VL-28B-A3B-Thinking",
},
"ERNIE-4.5-VL-424B-A47B-Base-PT": {
"ERNIE-4.5-VL-424B-A47B-Instruct": {
DownloadSource.DEFAULT: "baidu/ERNIE-4.5-VL-424B-A47B-PT",
DownloadSource.MODELSCOPE: "PaddlePaddle/ERNIE-4.5-VL-424B-A47B-PT",
},
@@ -1226,19 +1145,50 @@ register_model_group(
register_model_group(
models={
"Hunyuan-0.5B-Instruct": {
DownloadSource.DEFAULT: "tencent/Hunyuan-0.5B-Instruct",
DownloadSource.MODELSCOPE: "Tencent-Hunyuan/Hunyuan-0.5B-Instruct",
},
"Hunyuan-1.8B-Instruct": {
DownloadSource.DEFAULT: "tencent/Hunyuan-1.8B-Instruct",
DownloadSource.MODELSCOPE: "Tencent-Hunyuan/Hunyuan-1.8B-Instruct",
},
"Hunyuan-4B-Instruct": {
DownloadSource.DEFAULT: "tencent/Hunyuan-4B-Instruct",
DownloadSource.MODELSCOPE: "Tencent-Hunyuan/Hunyuan-4B-Instruct",
},
"Hunyuan-7B-Instruct": {
DownloadSource.DEFAULT: "tencent/Hunyuan-7B-Instruct",
DownloadSource.MODELSCOPE: "AI-ModelScope/Hunyuan-7B-Instruct",
DownloadSource.MODELSCOPE: "Tencent-Hunyuan/Hunyuan-7B-Instruct",
},
"Hunyuan-MT-7B-Instruct": {
DownloadSource.DEFAULT: "tencent/Hunyuan-MT-7B",
DownloadSource.MODELSCOPE: "Tencent-Hunyuan/Hunyuan-MT-7B",
},
"HY-MT1.5-7B-Instruct": {
DownloadSource.DEFAULT: "tencent/HY-MT1.5-7B",
DownloadSource.MODELSCOPE: "Tencent-Hunyuan/HY-MT1.5-7B",
},
"Hunyuan-A13B-Instruct": {
DownloadSource.DEFAULT: "tencent/Hunyuan-A13B-Instruct",
DownloadSource.MODELSCOPE: "Tencent-Hunyuan/Hunyuan-A13B-Instruct",
},
},
template="hunyuan",
)
register_model_group(
models={
"HY-MT1.5-1.8B-Instruct": {
DownloadSource.DEFAULT: "tencent/HY-MT1.5-1.8B",
DownloadSource.MODELSCOPE: "Tencent-Hunyuan/HY-MT1.5-1.8B",
},
},
template="hunyuan_small",
)
register_model_group(
models={
"Index-1.9B-Base": {
@@ -1266,29 +1216,6 @@ register_model_group(
)
register_model_group(
models={
"InternLM-7B": {
DownloadSource.DEFAULT: "internlm/internlm-7b",
DownloadSource.MODELSCOPE: "Shanghai_AI_Laboratory/internlm-7b",
},
"InternLM-20B": {
DownloadSource.DEFAULT: "internlm/internlm-20b",
DownloadSource.MODELSCOPE: "Shanghai_AI_Laboratory/internlm-20b",
},
"InternLM-7B-Chat": {
DownloadSource.DEFAULT: "internlm/internlm-chat-7b",
DownloadSource.MODELSCOPE: "Shanghai_AI_Laboratory/internlm-chat-7b",
},
"InternLM-20B-Chat": {
DownloadSource.DEFAULT: "internlm/internlm-chat-20b",
DownloadSource.MODELSCOPE: "Shanghai_AI_Laboratory/internlm-chat-20b",
},
},
template="intern",
)
register_model_group(
models={
"InternLM2-7B": {
@@ -1485,11 +1412,25 @@ register_model_group(
register_model_group(
models={
"LingoWhale-8B": {
DownloadSource.DEFAULT: "deeplang-ai/LingoWhale-8B",
DownloadSource.MODELSCOPE: "DeepLang/LingoWhale-8B",
}
"LFM2.5-1.2B": {
DownloadSource.DEFAULT: "LiquidAI/LFM2.5-1.2B-Base",
},
"LFM2.5-1.2B-Instruct": {
DownloadSource.DEFAULT: "LiquidAI/LFM2.5-1.2B-Instruct",
},
},
template="lfm2",
)
register_model_group(
models={
"LFM2.5-VL-1.6B": {
DownloadSource.DEFAULT: "LiquidAI/LFM2.5-VL-1.6B",
},
},
template="lfm2_vl",
multimodal=True,
)
@@ -1804,17 +1745,6 @@ register_model_group(
)
register_model_group(
models={
"Marco-o1-Chat": {
DownloadSource.DEFAULT: "AIDC-AI/Marco-o1",
DownloadSource.MODELSCOPE: "AIDC-AI/Marco-o1",
},
},
template="marco",
)
register_model_group(
models={
"MiMo-7B-Base": {
@@ -1885,33 +1815,6 @@ register_model_group(
)
register_model_group(
models={
"MiniCPM-2B-SFT-Chat": {
DownloadSource.DEFAULT: "openbmb/MiniCPM-2B-sft-bf16",
DownloadSource.MODELSCOPE: "OpenBMB/miniCPM-bf16",
},
"MiniCPM-2B-DPO-Chat": {
DownloadSource.DEFAULT: "openbmb/MiniCPM-2B-dpo-bf16",
DownloadSource.MODELSCOPE: "OpenBMB/MiniCPM-2B-dpo-bf16",
},
},
template="cpm",
)
register_model_group(
models={
"MiniCPM3-4B-Chat": {
DownloadSource.DEFAULT: "openbmb/MiniCPM3-4B",
DownloadSource.MODELSCOPE: "OpenBMB/MiniCPM3-4B",
DownloadSource.OPENMIND: "LlamaFactory/MiniCPM3-4B",
},
},
template="cpm3",
)
register_model_group(
models={
"MiniCPM4-0.5B-Chat": {
@@ -1949,26 +1852,10 @@ register_model_group(
DownloadSource.DEFAULT: "openbmb/MiniCPM-V-2_6",
DownloadSource.MODELSCOPE: "OpenBMB/MiniCPM-V-2_6",
},
},
template="minicpm_v",
multimodal=True,
)
register_model_group(
models={
"MiniCPM-V-4": {
DownloadSource.DEFAULT: "openbmb/MiniCPM-V-4",
DownloadSource.MODELSCOPE: "OpenBMB/MiniCPM-V-4",
},
},
template="minicpm_v",
multimodal=True,
)
register_model_group(
models={
"MiniCPM-V-4.5": {
DownloadSource.DEFAULT: "openbmb/MiniCPM-V-4_5",
DownloadSource.MODELSCOPE: "OpenBMB/MiniCPM-V-4_5",
@@ -2226,33 +2113,6 @@ register_model_group(
)
register_model_group(
models={
"Orion-14B-Base": {
DownloadSource.DEFAULT: "OrionStarAI/Orion-14B-Base",
DownloadSource.MODELSCOPE: "OrionStarAI/Orion-14B-Base",
},
"Orion-14B-Chat": {
DownloadSource.DEFAULT: "OrionStarAI/Orion-14B-Chat",
DownloadSource.MODELSCOPE: "OrionStarAI/Orion-14B-Chat",
},
"Orion-14B-Long-Chat": {
DownloadSource.DEFAULT: "OrionStarAI/Orion-14B-LongChat",
DownloadSource.MODELSCOPE: "OrionStarAI/Orion-14B-LongChat",
},
"Orion-14B-RAG-Chat": {
DownloadSource.DEFAULT: "OrionStarAI/Orion-14B-Chat-RAG",
DownloadSource.MODELSCOPE: "OrionStarAI/Orion-14B-Chat-RAG",
},
"Orion-14B-Plugin-Chat": {
DownloadSource.DEFAULT: "OrionStarAI/Orion-14B-Chat-Plugin",
DownloadSource.MODELSCOPE: "OrionStarAI/Orion-14B-Chat-Plugin",
},
},
template="orion",
)
register_model_group(
models={
"PaliGemma-3B-pt-224": {
@@ -2349,20 +2209,6 @@ register_model_group(
)
register_model_group(
models={
"Phi-1.5-1.3B": {
DownloadSource.DEFAULT: "microsoft/phi-1_5",
DownloadSource.MODELSCOPE: "allspace/PHI_1-5",
},
"Phi-2-2.7B": {
DownloadSource.DEFAULT: "microsoft/phi-2",
DownloadSource.MODELSCOPE: "AI-ModelScope/phi-2",
},
}
)
register_model_group(
models={
"Phi-3-4B-4k-Instruct": {
@@ -2419,6 +2265,15 @@ register_model_group(
template="phi4",
)
register_model_group(
models={
"Phi-4-3.8B-instruct": {
DownloadSource.DEFAULT: "microsoft/Phi-4-mini-instruct",
DownloadSource.MODELSCOPE: "LLM-Research/Phi-4-mini-instruct",
},
},
template="phi4_mini",
)
register_model_group(
models={
@@ -2432,228 +2287,6 @@ register_model_group(
)
register_model_group(
models={
"Qwen-1.8B": {
DownloadSource.DEFAULT: "Qwen/Qwen-1_8B",
DownloadSource.MODELSCOPE: "Qwen/Qwen-1_8B",
},
"Qwen-7B": {
DownloadSource.DEFAULT: "Qwen/Qwen-7B",
DownloadSource.MODELSCOPE: "Qwen/Qwen-7B",
},
"Qwen-14B": {
DownloadSource.DEFAULT: "Qwen/Qwen-14B",
DownloadSource.MODELSCOPE: "Qwen/Qwen-14B",
},
"Qwen-72B": {
DownloadSource.DEFAULT: "Qwen/Qwen-72B",
DownloadSource.MODELSCOPE: "Qwen/Qwen-72B",
},
"Qwen-1.8B-Chat": {
DownloadSource.DEFAULT: "Qwen/Qwen-1_8B-Chat",
DownloadSource.MODELSCOPE: "Qwen/Qwen-1_8B-Chat",
},
"Qwen-7B-Chat": {
DownloadSource.DEFAULT: "Qwen/Qwen-7B-Chat",
DownloadSource.MODELSCOPE: "Qwen/Qwen-7B-Chat",
},
"Qwen-14B-Chat": {
DownloadSource.DEFAULT: "Qwen/Qwen-14B-Chat",
DownloadSource.MODELSCOPE: "Qwen/Qwen-14B-Chat",
},
"Qwen-72B-Chat": {
DownloadSource.DEFAULT: "Qwen/Qwen-72B-Chat",
DownloadSource.MODELSCOPE: "Qwen/Qwen-72B-Chat",
},
"Qwen-1.8B-Chat-Int8": {
DownloadSource.DEFAULT: "Qwen/Qwen-1_8B-Chat-Int8",
DownloadSource.MODELSCOPE: "Qwen/Qwen-1_8B-Chat-Int8",
},
"Qwen-1.8B-Chat-Int4": {
DownloadSource.DEFAULT: "Qwen/Qwen-1_8B-Chat-Int4",
DownloadSource.MODELSCOPE: "Qwen/Qwen-1_8B-Chat-Int4",
},
"Qwen-7B-Chat-Int8": {
DownloadSource.DEFAULT: "Qwen/Qwen-7B-Chat-Int8",
DownloadSource.MODELSCOPE: "Qwen/Qwen-7B-Chat-Int8",
},
"Qwen-7B-Chat-Int4": {
DownloadSource.DEFAULT: "Qwen/Qwen-7B-Chat-Int4",
DownloadSource.MODELSCOPE: "Qwen/Qwen-7B-Chat-Int4",
},
"Qwen-14B-Chat-Int8": {
DownloadSource.DEFAULT: "Qwen/Qwen-14B-Chat-Int8",
DownloadSource.MODELSCOPE: "Qwen/Qwen-14B-Chat-Int8",
},
"Qwen-14B-Chat-Int4": {
DownloadSource.DEFAULT: "Qwen/Qwen-14B-Chat-Int4",
DownloadSource.MODELSCOPE: "Qwen/Qwen-14B-Chat-Int4",
},
"Qwen-72B-Chat-Int8": {
DownloadSource.DEFAULT: "Qwen/Qwen-72B-Chat-Int8",
DownloadSource.MODELSCOPE: "Qwen/Qwen-72B-Chat-Int8",
},
"Qwen-72B-Chat-Int4": {
DownloadSource.DEFAULT: "Qwen/Qwen-72B-Chat-Int4",
DownloadSource.MODELSCOPE: "Qwen/Qwen-72B-Chat-Int4",
},
},
template="qwen",
)
register_model_group(
models={
"Qwen1.5-0.5B": {
DownloadSource.DEFAULT: "Qwen/Qwen1.5-0.5B",
DownloadSource.MODELSCOPE: "Qwen/Qwen1.5-0.5B",
},
"Qwen1.5-1.8B": {
DownloadSource.DEFAULT: "Qwen/Qwen1.5-1.8B",
DownloadSource.MODELSCOPE: "Qwen/Qwen1.5-1.8B",
},
"Qwen1.5-4B": {
DownloadSource.DEFAULT: "Qwen/Qwen1.5-4B",
DownloadSource.MODELSCOPE: "Qwen/Qwen1.5-4B",
},
"Qwen1.5-7B": {
DownloadSource.DEFAULT: "Qwen/Qwen1.5-7B",
DownloadSource.MODELSCOPE: "Qwen/Qwen1.5-7B",
},
"Qwen1.5-14B": {
DownloadSource.DEFAULT: "Qwen/Qwen1.5-14B",
DownloadSource.MODELSCOPE: "Qwen/Qwen1.5-14B",
},
"Qwen1.5-32B": {
DownloadSource.DEFAULT: "Qwen/Qwen1.5-32B",
DownloadSource.MODELSCOPE: "Qwen/Qwen1.5-32B",
},
"Qwen1.5-72B": {
DownloadSource.DEFAULT: "Qwen/Qwen1.5-72B",
DownloadSource.MODELSCOPE: "Qwen/Qwen1.5-72B",
},
"Qwen1.5-110B": {
DownloadSource.DEFAULT: "Qwen/Qwen1.5-110B",
DownloadSource.MODELSCOPE: "Qwen/Qwen1.5-110B",
},
"Qwen1.5-MoE-A2.7B": {
DownloadSource.DEFAULT: "Qwen/Qwen1.5-MoE-A2.7B",
DownloadSource.MODELSCOPE: "Qwen/Qwen1.5-MoE-A2.7B",
},
"Qwen1.5-0.5B-Chat": {
DownloadSource.DEFAULT: "Qwen/Qwen1.5-0.5B-Chat",
DownloadSource.MODELSCOPE: "Qwen/Qwen1.5-0.5B-Chat",
},
"Qwen1.5-1.8B-Chat": {
DownloadSource.DEFAULT: "Qwen/Qwen1.5-1.8B-Chat",
DownloadSource.MODELSCOPE: "Qwen/Qwen1.5-1.8B-Chat",
},
"Qwen1.5-4B-Chat": {
DownloadSource.DEFAULT: "Qwen/Qwen1.5-4B-Chat",
DownloadSource.MODELSCOPE: "Qwen/Qwen1.5-4B-Chat",
},
"Qwen1.5-7B-Chat": {
DownloadSource.DEFAULT: "Qwen/Qwen1.5-7B-Chat",
DownloadSource.MODELSCOPE: "Qwen/Qwen1.5-7B-Chat",
},
"Qwen1.5-14B-Chat": {
DownloadSource.DEFAULT: "Qwen/Qwen1.5-14B-Chat",
DownloadSource.MODELSCOPE: "Qwen/Qwen1.5-14B-Chat",
},
"Qwen1.5-32B-Chat": {
DownloadSource.DEFAULT: "Qwen/Qwen1.5-32B-Chat",
DownloadSource.MODELSCOPE: "Qwen/Qwen1.5-32B-Chat",
},
"Qwen1.5-72B-Chat": {
DownloadSource.DEFAULT: "Qwen/Qwen1.5-72B-Chat",
DownloadSource.MODELSCOPE: "Qwen/Qwen1.5-72B-Chat",
},
"Qwen1.5-110B-Chat": {
DownloadSource.DEFAULT: "Qwen/Qwen1.5-110B-Chat",
DownloadSource.MODELSCOPE: "Qwen/Qwen1.5-110B-Chat",
},
"Qwen1.5-MoE-A2.7B-Chat": {
DownloadSource.DEFAULT: "Qwen/Qwen1.5-MoE-A2.7B-Chat",
DownloadSource.MODELSCOPE: "Qwen/Qwen1.5-MoE-A2.7B-Chat",
},
"Qwen1.5-0.5B-Chat-GPTQ-Int8": {
DownloadSource.DEFAULT: "Qwen/Qwen1.5-0.5B-Chat-GPTQ-Int8",
DownloadSource.MODELSCOPE: "Qwen/Qwen1.5-0.5B-Chat-GPTQ-Int8",
},
"Qwen1.5-0.5B-Chat-AWQ": {
DownloadSource.DEFAULT: "Qwen/Qwen1.5-0.5B-Chat-AWQ",
DownloadSource.MODELSCOPE: "Qwen/Qwen1.5-0.5B-Chat-AWQ",
},
"Qwen1.5-1.8B-Chat-GPTQ-Int8": {
DownloadSource.DEFAULT: "Qwen/Qwen1.5-1.8B-Chat-GPTQ-Int8",
DownloadSource.MODELSCOPE: "Qwen/Qwen1.5-1.8B-Chat-GPTQ-Int8",
},
"Qwen1.5-1.8B-Chat-AWQ": {
DownloadSource.DEFAULT: "Qwen/Qwen1.5-1.8B-Chat-AWQ",
DownloadSource.MODELSCOPE: "Qwen/Qwen1.5-1.8B-Chat-AWQ",
},
"Qwen1.5-4B-Chat-GPTQ-Int8": {
DownloadSource.DEFAULT: "Qwen/Qwen1.5-4B-Chat-GPTQ-Int8",
DownloadSource.MODELSCOPE: "Qwen/Qwen1.5-4B-Chat-GPTQ-Int8",
},
"Qwen1.5-4B-Chat-AWQ": {
DownloadSource.DEFAULT: "Qwen/Qwen1.5-4B-Chat-AWQ",
DownloadSource.MODELSCOPE: "Qwen/Qwen1.5-4B-Chat-AWQ",
},
"Qwen1.5-7B-Chat-GPTQ-Int8": {
DownloadSource.DEFAULT: "Qwen/Qwen1.5-7B-Chat-GPTQ-Int8",
DownloadSource.MODELSCOPE: "Qwen/Qwen1.5-7B-Chat-GPTQ-Int8",
},
"Qwen1.5-7B-Chat-AWQ": {
DownloadSource.DEFAULT: "Qwen/Qwen1.5-7B-Chat-AWQ",
DownloadSource.MODELSCOPE: "Qwen/Qwen1.5-7B-Chat-AWQ",
},
"Qwen1.5-14B-Chat-GPTQ-Int8": {
DownloadSource.DEFAULT: "Qwen/Qwen1.5-14B-Chat-GPTQ-Int8",
DownloadSource.MODELSCOPE: "Qwen/Qwen1.5-14B-Chat-GPTQ-Int8",
},
"Qwen1.5-14B-Chat-AWQ": {
DownloadSource.DEFAULT: "Qwen/Qwen1.5-14B-Chat-AWQ",
DownloadSource.MODELSCOPE: "Qwen/Qwen1.5-14B-Chat-AWQ",
},
"Qwen1.5-32B-Chat-AWQ": {
DownloadSource.DEFAULT: "Qwen/Qwen1.5-32B-Chat-AWQ",
DownloadSource.MODELSCOPE: "Qwen/Qwen1.5-32B-Chat-AWQ",
},
"Qwen1.5-72B-Chat-GPTQ-Int8": {
DownloadSource.DEFAULT: "Qwen/Qwen1.5-72B-Chat-GPTQ-Int8",
DownloadSource.MODELSCOPE: "Qwen/Qwen1.5-72B-Chat-GPTQ-Int8",
},
"Qwen1.5-72B-Chat-AWQ": {
DownloadSource.DEFAULT: "Qwen/Qwen1.5-72B-Chat-AWQ",
DownloadSource.MODELSCOPE: "Qwen/Qwen1.5-72B-Chat-AWQ",
},
"Qwen1.5-110B-Chat-AWQ": {
DownloadSource.DEFAULT: "Qwen/Qwen1.5-110B-Chat-AWQ",
DownloadSource.MODELSCOPE: "Qwen/Qwen1.5-110B-Chat-AWQ",
},
"Qwen1.5-MoE-A2.7B-Chat-GPTQ-Int4": {
DownloadSource.DEFAULT: "Qwen/Qwen1.5-MoE-A2.7B-Chat-GPTQ-Int4",
DownloadSource.MODELSCOPE: "Qwen/Qwen1.5-MoE-A2.7B-Chat-GPTQ-Int4",
},
"CodeQwen1.5-7B": {
DownloadSource.DEFAULT: "Qwen/CodeQwen1.5-7B",
DownloadSource.MODELSCOPE: "Qwen/CodeQwen1.5-7B",
},
"CodeQwen1.5-7B-Chat": {
DownloadSource.DEFAULT: "Qwen/CodeQwen1.5-7B-Chat",
DownloadSource.MODELSCOPE: "Qwen/CodeQwen1.5-7B-Chat",
},
"CodeQwen1.5-7B-Chat-AWQ": {
DownloadSource.DEFAULT: "Qwen/CodeQwen1.5-7B-Chat-AWQ",
DownloadSource.MODELSCOPE: "Qwen/CodeQwen1.5-7B-Chat-AWQ",
},
},
template="qwen",
)
register_model_group(
models={
"Qwen2-0.5B": {
@@ -3421,27 +3054,6 @@ register_model_group(
)
register_model_group(
models={
"Skywork-13B-Base": {
DownloadSource.DEFAULT: "Skywork/Skywork-13B-base",
DownloadSource.MODELSCOPE: "skywork/Skywork-13B-base",
}
}
)
register_model_group(
models={
"Skywork-o1-Open-Llama-3.1-8B": {
DownloadSource.DEFAULT: "Skywork/Skywork-o1-Open-Llama-3.1-8B",
DownloadSource.MODELSCOPE: "AI-ModelScope/Skywork-o1-Open-Llama-3.1-8B",
}
},
template="skywork_o1",
)
register_model_group(
models={
"SmolLM-135M": {
@@ -3536,30 +3148,6 @@ register_model_group(
)
register_model_group(
models={
"TeleChat-1B-Chat": {
DownloadSource.DEFAULT: "Tele-AI/TeleChat-1B",
DownloadSource.MODELSCOPE: "TeleAI/TeleChat-1B",
},
"TeleChat-7B-Chat": {
DownloadSource.DEFAULT: "Tele-AI/telechat-7B",
DownloadSource.MODELSCOPE: "TeleAI/telechat-7B",
DownloadSource.OPENMIND: "TeleAI/TeleChat-7B-pt",
},
"TeleChat-12B-Chat": {
DownloadSource.DEFAULT: "Tele-AI/TeleChat-12B-v2",
DownloadSource.MODELSCOPE: "TeleAI/TeleChat-12B-v2",
DownloadSource.OPENMIND: "TeleAI/TeleChat-12B-pt",
},
"TeleChat-52B-Chat": {
DownloadSource.DEFAULT: "Tele-AI/TeleChat-52B",
},
},
template="telechat",
)
register_model_group(
models={
"TeleChat2-3B-Chat": {
@@ -3674,80 +3262,6 @@ register_model_group(
)
register_model_group(
models={
"XVERSE-7B": {
DownloadSource.DEFAULT: "xverse/XVERSE-7B",
DownloadSource.MODELSCOPE: "xverse/XVERSE-7B",
},
"XVERSE-13B": {
DownloadSource.DEFAULT: "xverse/XVERSE-13B",
DownloadSource.MODELSCOPE: "xverse/XVERSE-13B",
},
"XVERSE-65B": {
DownloadSource.DEFAULT: "xverse/XVERSE-65B",
DownloadSource.MODELSCOPE: "xverse/XVERSE-65B",
},
"XVERSE-65B-2": {
DownloadSource.DEFAULT: "xverse/XVERSE-65B-2",
DownloadSource.MODELSCOPE: "xverse/XVERSE-65B-2",
},
"XVERSE-7B-Chat": {
DownloadSource.DEFAULT: "xverse/XVERSE-7B-Chat",
DownloadSource.MODELSCOPE: "xverse/XVERSE-7B-Chat",
},
"XVERSE-13B-Chat": {
DownloadSource.DEFAULT: "xverse/XVERSE-13B-Chat",
DownloadSource.MODELSCOPE: "xverse/XVERSE-13B-Chat",
},
"XVERSE-65B-Chat": {
DownloadSource.DEFAULT: "xverse/XVERSE-65B-Chat",
DownloadSource.MODELSCOPE: "xverse/XVERSE-65B-Chat",
},
"XVERSE-MoE-A4.2B": {
DownloadSource.DEFAULT: "xverse/XVERSE-MoE-A4.2B",
DownloadSource.MODELSCOPE: "xverse/XVERSE-MoE-A4.2B",
},
"XVERSE-7B-Chat-GPTQ-Int8": {
DownloadSource.DEFAULT: "xverse/XVERSE-7B-Chat-GPTQ-Int8",
DownloadSource.MODELSCOPE: "xverse/XVERSE-7B-Chat-GPTQ-Int8",
},
"XVERSE-7B-Chat-GPTQ-Int4": {
DownloadSource.DEFAULT: "xverse/XVERSE-7B-Chat-GPTQ-Int4",
DownloadSource.MODELSCOPE: "xverse/XVERSE-7B-Chat-GPTQ-Int4",
},
"XVERSE-13B-Chat-GPTQ-Int8": {
DownloadSource.DEFAULT: "xverse/XVERSE-13B-Chat-GPTQ-Int8",
DownloadSource.MODELSCOPE: "xverse/XVERSE-13B-Chat-GPTQ-Int8",
},
"XVERSE-13B-Chat-GPTQ-Int4": {
DownloadSource.DEFAULT: "xverse/XVERSE-13B-Chat-GPTQ-Int4",
DownloadSource.MODELSCOPE: "xverse/XVERSE-13B-Chat-GPTQ-Int4",
},
"XVERSE-65B-Chat-GPTQ-Int4": {
DownloadSource.DEFAULT: "xverse/XVERSE-65B-Chat-GPTQ-Int4",
DownloadSource.MODELSCOPE: "xverse/XVERSE-65B-Chat-GPTQ-Int4",
},
},
template="xverse",
)
register_model_group(
models={
"Yayi-7B": {
DownloadSource.DEFAULT: "wenge-research/yayi-7b-llama2",
DownloadSource.MODELSCOPE: "AI-ModelScope/yayi-7b-llama2",
},
"Yayi-13B": {
DownloadSource.DEFAULT: "wenge-research/yayi-13b-llama2",
DownloadSource.MODELSCOPE: "AI-ModelScope/yayi-13b-llama2",
},
},
template="yayi",
)
register_model_group(
models={
"Yi-6B": {
@@ -3846,6 +3360,21 @@ register_model_group(
)
register_model_group(
models={
"Youtu-LLM-2B-Instruct": {
DownloadSource.DEFAULT: "tencent/Youtu-LLM-2B",
DownloadSource.MODELSCOPE: "Tencent-YouTu-Research/Youtu-LLM-2B",
},
"Youtu-LLM-2B-Base": {
DownloadSource.DEFAULT: "tencent/Youtu-LLM-2B-Base",
DownloadSource.MODELSCOPE: "Tencent-YouTu-Research/Youtu-LLM-2B-Base",
},
},
template="youtu",
)
register_model_group(
models={
"Yuan2-2B-Chat": {

View File

@@ -19,7 +19,7 @@
from collections import OrderedDict
VERSION = "0.9.4"
VERSION = "0.9.5.dev0"
def print_env() -> None:

View File

@@ -157,6 +157,33 @@ def get_current_device() -> "torch.device":
return torch.device(device)
def get_device_name() -> str:
r"""Get the name of available devices."""
if is_torch_xpu_available():
device = "xpu"
elif is_torch_npu_available():
device = "npu"
elif is_torch_mps_available():
device = "mps"
elif is_torch_cuda_available():
device = "gpu"
else:
device = "cpu"
return device
def get_torch_device():
r"""Get the torch device namespace for the available devices."""
device_name = get_device_name()
device_name = "cuda" if device_name == "gpu" else device_name
try:
return getattr(torch, device_name)
except AttributeError:
logger.warning_rank0(f"Device namespace '{device_name}' not found in torch, try to load torch.cuda.")
return torch.cuda
def get_device_count() -> int:
r"""Get the number of available devices."""
if is_torch_xpu_available():

View File

@@ -63,9 +63,9 @@ class DataArguments:
default=16384,
metadata={"help": "Size of the buffer to randomly sample examples from in dataset streaming."},
)
mix_strategy: Literal["concat", "interleave_under", "interleave_over"] = field(
mix_strategy: Literal["concat", "interleave_under", "interleave_over", "interleave_once"] = field(
default="concat",
metadata={"help": "Strategy to use in dataset mixing (concat/interleave) (undersampling/oversampling)."},
metadata={"help": "Strategy to use in dataset mixing (concat/interleave) (undersampling/oversampling/sampling w.o. replacement)."},
)
interleave_probs: str | None = field(
default=None,

View File

@@ -490,6 +490,14 @@ class FinetuningArguments(
default=False,
metadata={"help": "Whether to use the DFT loss."},
)
use_eaft_loss: bool = field(
default=False,
metadata={"help": "Whether to use the EAFT loss."},
)
eaft_alpha: float = field(
default=1.0,
metadata={"help": "The alpha parameter for EAFT loss to control the power of adaptive weight."},
)
freeze_vision_tower: bool = field(
default=True,
metadata={"help": "Whether ot not to freeze the vision tower in MLLM training."},

View File

@@ -298,23 +298,6 @@ class QuantizationArguments:
default=None,
metadata={"help": "Device map used to infer the 4-bit quantized model, needs bitsandbytes>=0.43.0."},
)
fp8: bool = field(
default=False,
metadata={
"help": "Enable FP8 mixed precision training via HuggingFace Accelerate. "
"Requires PyTorch 2.7+ and Hopper architecture GPUs."
},
)
fp8_backend: str = field(
default="auto",
metadata={
"help": "FP8 backend to use ('auto', 'torchao', 'te', 'msamp'). 'auto' selects best available backend."
},
)
fp8_enable_fsdp_float8_all_gather: bool = field(
default=False,
metadata={"help": "Enable FP8 optimizations for FSDP2 all-gather operations."},
)
@dataclass

View File

@@ -15,6 +15,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import os
import sys
from pathlib import Path
@@ -70,13 +71,13 @@ def read_args(args: dict[str, Any] | list[str] | None = None) -> dict[str, Any]
if args is not None:
return args
if sys.argv[1].endswith(".yaml") or sys.argv[1].endswith(".yml"):
if len(sys.argv) > 1 and (sys.argv[1].endswith(".yaml") or sys.argv[1].endswith(".yml")):
override_config = OmegaConf.from_cli(sys.argv[2:])
dict_config = OmegaConf.load(Path(sys.argv[1]).absolute())
return OmegaConf.to_container(OmegaConf.merge(dict_config, override_config))
elif sys.argv[1].endswith(".json"):
elif len(sys.argv) > 1 and sys.argv[1].endswith(".json"):
override_config = OmegaConf.from_cli(sys.argv[2:])
dict_config = OmegaConf.load(Path(sys.argv[1]).absolute())
dict_config = OmegaConf.create(json.load(Path(sys.argv[1]).absolute()))
return OmegaConf.to_container(OmegaConf.merge(dict_config, override_config))
else:
return sys.argv[1:]
@@ -142,14 +143,6 @@ def _verify_model_args(
logger.warning_rank0("We should use slow tokenizer for the Yi models. Change `use_fast_tokenizer` to False.")
model_args.use_fast_tokenizer = False
# Validate advanced training features
if model_args.fp8 and model_args.quantization_bit is not None:
raise ValueError("FP8 training is not compatible with quantization. Please disable one of them.")
if model_args.fp8_enable_fsdp_float8_all_gather and not model_args.fp8:
logger.warning_rank0("fp8_enable_fsdp_float8_all_gather requires fp8=True. Setting fp8=True.")
model_args.fp8 = True
def _check_extra_dependencies(
model_args: "ModelArguments",
@@ -347,6 +340,9 @@ def get_train_args(args: dict[str, Any] | list[str] | None = None) -> _TRAIN_CLS
if training_args.deepspeed is not None and (finetuning_args.use_galore or finetuning_args.use_apollo):
raise ValueError("GaLore and APOLLO are incompatible with DeepSpeed yet.")
if not finetuning_args.use_mca and training_args.fp8 and model_args.quantization_bit is not None:
raise ValueError("FP8 training is not compatible with quantization. Please disable one of them.")
if model_args.infer_backend != EngineName.HF:
raise ValueError("vLLM/SGLang backend is only available for API, CLI and Web.")
@@ -363,6 +359,10 @@ def get_train_args(args: dict[str, Any] | list[str] | None = None) -> _TRAIN_CLS
_verify_model_args(model_args, data_args, finetuning_args)
_check_extra_dependencies(model_args, finetuning_args, training_args)
if not finetuning_args.use_mca and training_args.fp8_enable_fsdp_float8_all_gather and not training_args.fp8:
logger.warning_rank0("fp8_enable_fsdp_float8_all_gather requires fp8=True. Setting fp8=True.")
model_args.fp8 = True
if (
training_args.do_train
and finetuning_args.finetuning_type == "lora"

View File

@@ -14,7 +14,6 @@
import json
from dataclasses import dataclass, field
from typing import Literal
from transformers import Seq2SeqTrainingArguments
from transformers.training_args import _convert_str_dict
@@ -40,59 +39,55 @@ else:
class RayArguments:
r"""Arguments pertaining to the Ray training."""
ray_run_name: str | None = field(
default=None,
metadata={"help": "The training results will be saved at `<ray_storage_path>/ray_run_name`."},
)
ray_storage_path: str = field(
default="./saves",
metadata={"help": "The storage path to save training results to"},
)
ray_storage_filesystem: Literal["s3", "gs", "gcs"] | None = field(
default=None,
metadata={"help": "The storage filesystem to use. If None specified, local filesystem will be used."},
)
ray_num_workers: int = field(
default=1,
metadata={"help": "The number of workers for Ray training. Default is 1 worker."},
)
resources_per_worker: dict | str = field(
default_factory=lambda: {"GPU": 1},
metadata={"help": "The resources per worker for Ray training. Default is to use 1 GPU per worker."},
)
placement_strategy: Literal["SPREAD", "PACK", "STRICT_SPREAD", "STRICT_PACK"] = field(
default="PACK",
metadata={"help": "The placement strategy for Ray training. Default is PACK."},
)
ray_init_kwargs: dict | str | None = field(
default=None,
metadata={"help": "The arguments to pass to ray.init for Ray training. Default is None."},
)
master_addr: str | None = field(
default=None,
metadata={"help": "The master address for init_process_group"},
)
master_port: str | None = field(
default=None,
metadata={"help": "The master port for init_process_group"},
)
def __post_init__(self):
self.use_ray = use_ray()
if isinstance(self.resources_per_worker, str) and self.resources_per_worker.startswith("{"):
self.resources_per_worker = _convert_str_dict(json.loads(self.resources_per_worker))
if isinstance(self.ray_init_kwargs, str) and self.ray_init_kwargs.startswith("{"):
self.ray_init_kwargs = _convert_str_dict(json.loads(self.ray_init_kwargs))
if self.ray_storage_filesystem is not None:
if self.ray_storage_filesystem not in ["s3", "gs", "gcs"]:
raise ValueError(
f"ray_storage_filesystem must be one of ['s3', 'gs', 'gcs'], got {self.ray_storage_filesystem}."
)
import pyarrow.fs as fs
@dataclass
class Fp8Arguments:
r"""Arguments pertaining to the FP8 training."""
if self.ray_storage_filesystem == "s3":
self.ray_storage_filesystem = fs.S3FileSystem()
elif self.ray_storage_filesystem == "gs" or self.ray_storage_filesystem == "gcs":
self.ray_storage_filesystem = fs.GcsFileSystem()
fp8: bool = field(
default=False,
metadata={
"help": "Enable FP8 mixed precision training via HuggingFace Accelerate. "
"Requires PyTorch 2.7+ and Hopper architecture GPUs."
},
)
fp8_backend: str = field(
default="auto",
metadata={
"help": "FP8 backend to use ('auto', 'torchao', 'te', 'msamp'). 'auto' selects best available backend."
},
)
fp8_enable_fsdp_float8_all_gather: bool = field(
default=False,
metadata={"help": "Enable FP8 optimizations for FSDP2 all-gather operations."},
)
@dataclass
class TrainingArguments(RayArguments, BaseTrainingArguments):
class TrainingArguments(Fp8Arguments, RayArguments, BaseTrainingArguments):
r"""Arguments pertaining to the trainer."""
overwrite_output_dir: bool = field(

View File

@@ -15,6 +15,7 @@
import os
from typing import TYPE_CHECKING, Any, Optional, TypedDict
import torch
from transformers import (
AutoConfig,
AutoModelForCausalLM,
@@ -29,6 +30,7 @@ from trl import AutoModelForCausalLMWithValueHead
from ..extras import logging
from ..extras.misc import count_parameters, skip_check_imports, try_download_model_from_other_hub
from ..extras.packages import is_torch_version_greater_than
from .adapter import init_adapter
from .model_utils.ktransformers import load_kt_pretrained_model
from .model_utils.liger_kernel import apply_liger_kernel
@@ -203,6 +205,16 @@ def load_model(
model.load_state_dict(vhead_params, strict=False)
logger.info_rank0(f"Loaded valuehead from checkpoint: {vhead_path}")
# Conv3D is not recommended when using torch 2.9.x
if is_torch_version_greater_than("2.9.0") and not is_torch_version_greater_than("2.10.0"):
if any(isinstance(m, torch.nn.Conv3d) for m in model.modules()):
raise ValueError(
"Unsupported torch version detected: torch 2.9.x with Conv3D. "
"This combination is known to cause severe performance regression. "
"Please downgrade torch to <2.9 or remove Conv3D. "
"See https://github.com/pytorch/pytorch/issues/166122"
)
if not is_trainable:
model.requires_grad_(False)
model.eval()
@@ -218,7 +230,7 @@ def load_model(
)
from ..v1.plugins.model_plugins.kernels.interface import apply_default_kernels
model = apply_default_kernels(model=model, include_kernels=model_args.use_v1_kernels)
model = apply_default_kernels(model, include_kernels=model_args.use_v1_kernels)
trainable_params, all_param = count_parameters(model)
if is_trainable:

View File

@@ -356,7 +356,7 @@ _register_composite_model(
_register_composite_model(
model_type="qwen3_vl",
projector_key="visual.merger",
vision_model_keys=["visual.patch_embed", "visual.blocks", "visual.deepstack_merger_list"],
vision_model_keys=["visual.pos_embed", "visual.patch_embed", "visual.blocks", "visual.deepstack_merger_list"],
language_model_keys=["language_model", "lm_head"],
lora_conflict_keys=["patch_embed"],
)
@@ -365,7 +365,7 @@ _register_composite_model(
_register_composite_model(
model_type="qwen3_vl_moe",
projector_key="visual.merger",
vision_model_keys=["visual.patch_embed", "visual.blocks", "visual.deepstack_merger_list"],
vision_model_keys=["visual.pos_embed", "visual.patch_embed", "visual.blocks", "visual.deepstack_merger_list"],
language_model_keys=["language_model", "lm_head"],
lora_conflict_keys=["patch_embed"],
)
@@ -374,7 +374,7 @@ _register_composite_model(
_register_composite_model(
model_type="qwen3_omni_moe_thinker",
projector_key="visual.merger",
vision_model_keys=["visual.patch_embed", "visual.blocks", "visual.deepstack_merger_list", "audio_tower"],
vision_model_keys=["visual.pos_embed", "visual.patch_embed", "visual.blocks", "visual.deepstack_merger_list", "audio_tower"],
language_model_keys=["model", "lm_head"],
lora_conflict_keys=["patch_embed"],
)

View File

@@ -138,18 +138,25 @@ def patch_config(
if getattr(config, "model_type", None) == "kimi_vl" and is_trainable:
setattr(config.text_config, "topk_method", "greedy")
if "InternVLChatModel" in getattr(config, "architectures", []):
architectures = getattr(config, "architectures", None)
if isinstance(architectures, list) and "InternVLChatModel" in architectures:
raise ValueError(
"Please download the internvl models in a Hugging Facecompatible format "
"(for example, https://huggingface.co/OpenGVLab/InternVL3-8B-hf)."
)
if "LlavaLlamaForCausalLM" in getattr(config, "architectures", []):
if isinstance(architectures, list) and "LlavaLlamaForCausalLM" in architectures:
raise ValueError("Please download llava models with hf-compatible format: https://huggingface.co/llava-hf")
if getattr(config, "model_type", None) == "internlm3" and not is_transformers_version_greater_than("4.47.1"):
raise RuntimeError("InternLM3 model requires transformers>=4.47.1, please upgrade it.")
if getattr(config, "model_type", None) == "lfm2_vl" and not is_transformers_version_greater_than("4.58.0"):
raise RuntimeError(
"LFM2.5-VL model requires transformers>=4.58.0 or install from commit: "
"pip install git+https://github.com/huggingface/transformers.git@3c2517727ce28a30f5044e01663ee204deb1cdbe"
)
if getattr(config, "model_type", None) == "qwen3_omni_moe":
patch_qwen3_omni_moe_thinker_text_sparse_moe_block()

View File

@@ -12,35 +12,45 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import types
from typing import TYPE_CHECKING, Any, Optional
from ..extras import logging
if TYPE_CHECKING:
from ..hparams import ModelArguments
from ..hparams import TrainingArguments
logger = logging.get_logger(__name__)
def create_fp8_kwargs(model_args: "ModelArguments") -> list[Any]:
def create_fp8_kwargs(training_args: "TrainingArguments") -> list[Any]:
"""Create AORecipeKwargs for FP8 training with HuggingFace Accelerate.
Args:
model_args: Model arguments containing FP8 configuration
training_args: Training arguments containing FP8 configuration
Returns:
List containing AORecipeKwargs if FP8 is enabled and supported, empty list otherwise
"""
if not model_args.fp8:
if not training_args.fp8:
return []
try:
# Check if AORecipeKwargs is available (Accelerate 1.8.0+)
from accelerate.utils import AORecipeKwargs
backend = getattr(training_args, "fp8_backend", "auto")
logger.info_rank0(f"Creating FP8 configuration with backend: {backend}")
backend = getattr(model_args, "fp8_backend", "auto")
logger.info_rank0(f"Creating FP8 configuration with backend: {backend}")
try:
# Use Transformer Engine backend (optimal for Hopper GPUs)
if backend == "te":
from accelerate.utils import FP8RecipeKwargs
logger.info_rank0("Using Transformer Engine FP8 backend")
return [FP8RecipeKwargs(backend="TE", fp8_format="HYBRID", amax_history_len=16, amax_compute_algo="max")]
# Use TorchAO backend (default)
from accelerate.utils import AORecipeKwargs
# Create Float8LinearConfig if torchao backend is used
config = None
@@ -83,7 +93,10 @@ def create_fp8_kwargs(model_args: "ModelArguments") -> list[Any]:
return True
# Map FSDP all-gather setting if available (this affects the underlying implementation)
if hasattr(model_args, "fp8_enable_fsdp_float8_all_gather") and model_args.fp8_enable_fsdp_float8_all_gather:
if (
hasattr(training_args, "fp8_enable_fsdp_float8_all_gather")
and training_args.fp8_enable_fsdp_float8_all_gather
):
logger.info_rank0("FSDP float8 all-gather optimization requested")
return [AORecipeKwargs(config=config, module_filter_func=module_filter_func)]
@@ -92,19 +105,19 @@ def create_fp8_kwargs(model_args: "ModelArguments") -> list[Any]:
return []
def get_fp8_mixed_precision(model_args: "ModelArguments") -> Optional[str]:
def get_fp8_mixed_precision(training_args: "TrainingArguments") -> Optional[str]:
"""Get the mixed precision setting for Accelerate when using FP8.
Args:
model_args: Model arguments containing FP8 configuration
training_args: Training arguments containing FP8 configuration
Returns:
"fp8" if FP8 is enabled, None otherwise
"""
return "fp8" if model_args.fp8 else None
return "fp8" if training_args.fp8 else None
def configure_fp8_environment(model_args: "ModelArguments") -> None:
def configure_fp8_environment(training_args: "TrainingArguments") -> None:
"""Configure FP8 environment for HuggingFace Accelerate.
FP8 training is handled entirely through HuggingFace Accelerate, regardless of whether
@@ -112,11 +125,9 @@ def configure_fp8_environment(model_args: "ModelArguments") -> None:
variables and validates the FP8 configuration.
Args:
model_args: Model arguments containing FP8 configuration
training_args: Training arguments containing FP8 configuration
"""
import os
if not model_args.fp8:
if not training_args.fp8:
return
# Set mixed precision to fp8 for HuggingFace Accelerate
@@ -124,38 +135,38 @@ def configure_fp8_environment(model_args: "ModelArguments") -> None:
logger.info_rank0("Set ACCELERATE_MIXED_PRECISION=fp8")
# Configure FP8 backend and options
backend = getattr(model_args, "fp8_backend", "auto")
backend = getattr(training_args, "fp8_backend", "auto")
if backend != "auto":
os.environ["FP8_BACKEND"] = backend
logger.info_rank0(f"Set FP8_BACKEND={backend}")
# Create and validate FP8 recipe kwargs (for logging/debugging)
fp8_kwargs = create_fp8_kwargs(model_args)
fp8_kwargs = create_fp8_kwargs(training_args)
logger.info_rank0(f"FP8 AORecipeKwargs created: {len(fp8_kwargs)} items")
# Enable FSDP float8 all-gather optimization if requested
if hasattr(model_args, "fp8_enable_fsdp_float8_all_gather") and model_args.fp8_enable_fsdp_float8_all_gather:
if hasattr(training_args, "fp8_enable_fsdp_float8_all_gather") and training_args.fp8_enable_fsdp_float8_all_gather:
os.environ["FP8_ENABLE_FSDP_FLOAT8_ALL_GATHER"] = "true"
logger.info_rank0("Set FP8_ENABLE_FSDP_FLOAT8_ALL_GATHER=true")
logger.info_rank0("FP8 environment configured - all FP8 training handled by HuggingFace Accelerate")
def verify_fp8_status(accelerator, model_args: "ModelArguments") -> None:
def verify_fp8_status(accelerator, training_args: "TrainingArguments") -> None:
"""Verify that FP8 training is actually working after model preparation.
Args:
accelerator: The HuggingFace Accelerator instance
model_args: Model arguments containing FP8 configuration
training_args: Training arguments containing FP8 configuration
"""
if not model_args.fp8:
if not training_args.fp8:
return
# Check Accelerate's FP8 status
fp8_enabled = getattr(accelerator, "fp8_enabled", False)
fp8_backend_type = getattr(accelerator, "fp8_backend", "UNKNOWN")
backend = getattr(model_args, "fp8_backend", "auto")
backend = getattr(training_args, "fp8_backend", "auto")
if backend == "torchao" or backend == "auto":
logger.info_rank0(
"FP8 training enabled with TorchAO backend. For optimal performance, "
@@ -169,3 +180,50 @@ def verify_fp8_status(accelerator, model_args: "ModelArguments") -> None:
if not fp8_enabled:
logger.info_rank0("WARNING: FP8 was requested but Accelerate shows fp8_enabled=False. FP8 may not be working.")
def patch_accelerator_for_fp8() -> None:
"""Patch Accelerator to inject FP8 recipe kwargs.
This is needed because HuggingFace Trainer doesn't pass kwargs_handlers to Accelerator.
We monkey-patch Accelerator.__init__ to inject the FP8 recipe and force mixed_precision='fp8'.
"""
import transformer_engine.pytorch as te
from accelerate import Accelerator
# Guard against multiple patches
if getattr(Accelerator, "_te_fp8_patched", False):
return
# Stub for Accelerate 1.12+ compatibility (te.fp8.check_mxfp8_support doesn't exist yet)
if not hasattr(te, "fp8"):
te.fp8 = types.ModuleType("fp8")
te.fp8.check_mxfp8_support = lambda: (False, "MXFP8 not supported")
try:
from accelerate.utils import TERecipeKwargs as FP8Recipe
use_te_recipe = True
except ImportError:
from accelerate.utils import FP8RecipeKwargs as FP8Recipe
use_te_recipe = False
original_init = Accelerator.__init__
def patched_init(self, *args, **kwargs):
if "kwargs_handlers" not in kwargs or not kwargs["kwargs_handlers"]:
if use_te_recipe:
kwargs["kwargs_handlers"] = [
FP8Recipe(fp8_format="HYBRID", amax_history_len=16, amax_compute_algo="max")
]
else:
kwargs["kwargs_handlers"] = [
FP8Recipe(backend="TE", fp8_format="HYBRID", amax_history_len=16, amax_compute_algo="max")
]
# Only force mixed_precision when we inject handlers
kwargs["mixed_precision"] = "fp8"
return original_init(self, *args, **kwargs)
Accelerator.__init__ = patched_init
Accelerator._te_fp8_patched = True

View File

@@ -19,16 +19,15 @@ import torch
from transformers import Trainer
from typing_extensions import override
from ...extras.packages import is_transformers_version_greater_than
from ..callbacks import SaveProcessorCallback
from ..fp8_utils import configure_fp8_environment, verify_fp8_status
from ..fp8_utils import configure_fp8_environment, patch_accelerator_for_fp8, verify_fp8_status
from ..trainer_utils import create_custom_optimizer, create_custom_scheduler
if TYPE_CHECKING:
from transformers import ProcessorMixin
from ...hparams import FinetuningArguments, ModelArguments
from ...hparams import FinetuningArguments, ModelArguments, TrainingArguments
class CustomTrainer(Trainer):
@@ -41,11 +40,13 @@ class CustomTrainer(Trainer):
model_args: Optional["ModelArguments"] = None,
**kwargs,
) -> None:
kwargs["processing_class"] = kwargs.pop("tokenizer")
# Configure FP8 environment if enabled
if model_args is not None and model_args.fp8:
configure_fp8_environment(model_args)
if is_transformers_version_greater_than("4.46"):
kwargs["processing_class"] = kwargs.pop("tokenizer")
training_args: TrainingArguments = kwargs.get("args")
if training_args.fp8:
configure_fp8_environment(training_args)
if getattr(training_args, "fp8_backend", "auto") == "te":
patch_accelerator_for_fp8()
super().__init__(**kwargs)
if processor is not None:
@@ -64,9 +65,8 @@ class CustomTrainer(Trainer):
self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, self.accelerator)
self.add_callback(BAdamCallback)
# Verify FP8 status after trainer initialization (accelerator should be available)
if model_args is not None and model_args.fp8 and hasattr(self, "accelerator"):
verify_fp8_status(self.accelerator, model_args)
if training_args.fp8 and hasattr(self, "accelerator"): # verify FP8 status after trainer initialization
verify_fp8_status(self.accelerator, training_args)
@override
def create_optimizer(self) -> "torch.optim.Optimizer":

View File

@@ -109,6 +109,27 @@ class PairwiseTrainer(Trainer):
else:
return loss
@override
def _save(self, output_dir: Optional[str] = None, state_dict=None):
if state_dict is None:
state_dict = self.model.state_dict()
if self.args.save_safetensors:
from collections import defaultdict
ptrs = defaultdict(list)
for name, tensor in state_dict.items():
if isinstance(tensor, torch.Tensor):
ptrs[id(tensor)].append(name)
for names in ptrs.values():
if len(names) > 1:
names.sort()
for name in names[1:]:
state_dict.pop(name, None)
super()._save(output_dir, state_dict)
def save_predictions(self, predict_results: "PredictionOutput") -> None:
r"""Save model predictions to `output_dir`.

View File

@@ -27,18 +27,17 @@ from typing_extensions import override
from ...extras import logging
from ...extras.constants import IGNORE_INDEX
from ...extras.packages import is_transformers_version_greater_than
from ..callbacks import SaveProcessorCallback
from ..fp8_utils import configure_fp8_environment, verify_fp8_status
from ..fp8_utils import configure_fp8_environment, patch_accelerator_for_fp8, verify_fp8_status
from ..trainer_utils import create_custom_optimizer, create_custom_scheduler
if TYPE_CHECKING:
from torch.utils.data import Dataset
from transformers import PreTrainedTokenizer, ProcessorMixin
from transformers import ProcessorMixin
from transformers.trainer import PredictionOutput
from ...hparams import FinetuningArguments, ModelArguments
from ...hparams import FinetuningArguments, ModelArguments, TrainingArguments
logger = logging.get_logger(__name__)
@@ -55,13 +54,13 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer):
gen_kwargs: Optional[dict[str, Any]] = None,
**kwargs,
) -> None:
kwargs["processing_class"] = kwargs.pop("tokenizer")
# Configure FP8 environment if enabled
if model_args is not None and model_args.fp8:
configure_fp8_environment(model_args)
if is_transformers_version_greater_than("4.46"):
kwargs["processing_class"] = kwargs.pop("tokenizer")
else:
self.processing_class: PreTrainedTokenizer = kwargs.get("tokenizer")
training_args: TrainingArguments = kwargs.get("args")
if training_args.fp8:
configure_fp8_environment(training_args)
if getattr(training_args, "fp8_backend", "auto") == "te":
patch_accelerator_for_fp8()
super().__init__(**kwargs)
if processor is not None:
@@ -88,9 +87,15 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer):
self.compute_loss_func = dft_loss_func
# Verify FP8 status after trainer initialization (accelerator should be available)
if model_args is not None and model_args.fp8 and hasattr(self, "accelerator"):
verify_fp8_status(self.accelerator, model_args)
elif finetuning_args.use_eaft_loss:
from ..trainer_utils import eaft_loss_func
self.compute_loss_func = lambda outputs, labels, num_items_in_batch=None: eaft_loss_func(
outputs, labels, num_items_in_batch, finetuning_args.eaft_alpha
)
if training_args.fp8 and hasattr(self, "accelerator"): # verify FP8 status after trainer initialization
verify_fp8_status(self.accelerator, training_args)
@override
def create_optimizer(self) -> "torch.optim.Optimizer":

View File

@@ -20,7 +20,6 @@
import json
import os
from collections.abc import Callable, Mapping
from pathlib import Path
from typing import TYPE_CHECKING, Any, Optional, Union
import torch
@@ -34,6 +33,7 @@ from typing_extensions import override
from ..extras import logging
from ..extras.constants import IGNORE_INDEX, SWANLAB_CONFIG
from ..extras.misc import get_device_name
from ..extras.packages import is_apollo_available, is_galore_available, is_ray_available
from ..hparams import FinetuningArguments, ModelArguments
from ..model import find_all_linear_modules, load_model, load_tokenizer, load_valuehead_params
@@ -49,15 +49,15 @@ if is_apollo_available():
if is_ray_available():
import ray
from ray.train import RunConfig, ScalingConfig
from ray.train.torch import TorchTrainer
from ray.util.placement_group import PlacementGroup, placement_group
from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy
if TYPE_CHECKING:
from transformers import PreTrainedModel, TrainerCallback, TrainerState
from trl import AutoModelForCausalLMWithValueHead
from ..hparams import DataArguments, RayArguments, TrainingArguments
from ..hparams import DataArguments, TrainingArguments
logger = logging.get_logger(__name__)
@@ -634,7 +634,9 @@ def get_batch_logps(
return logps, valid_length
def dft_loss_func(outputs, labels, num_items_in_batch=None):
def dft_loss_func(
outputs: "torch.Tensor", labels: "torch.Tensor", num_items_in_batch: Optional["torch.Tensor"] = None
):
logits = outputs.get("logits")
if logits is None:
return outputs.get("loss", torch.tensor(0.0))
@@ -652,11 +654,11 @@ def dft_loss_func(outputs, labels, num_items_in_batch=None):
def _dft_cross_entropy(
source: torch.Tensor,
target: torch.Tensor,
num_items_in_batch: Optional[torch.Tensor] = None,
source: "torch.Tensor",
target: "torch.Tensor",
num_items_in_batch: Optional["torch.Tensor"] = None,
ignore_index: int = -100,
) -> torch.Tensor:
) -> "torch.Tensor":
per_token_loss = torch.nn.functional.cross_entropy(source, target, ignore_index=ignore_index, reduction="none")
valid_mask = target != ignore_index
if not valid_mask.any():
@@ -679,6 +681,67 @@ def _dft_cross_entropy(
return loss
def eaft_loss_func(
outputs: "torch.Tensor",
labels: "torch.Tensor",
num_items_in_batch: Optional["torch.Tensor"] = None,
alpha: float = 1.0,
) -> "torch.Tensor":
logits = outputs.get("logits")
if logits is None:
return outputs.get("loss", torch.tensor(0.0))
logits = logits.float()
vocab_size = logits.size(-1)
labels = torch.nn.functional.pad(labels, (0, 1), value=-100)
shift_labels = labels[..., 1:].contiguous()
logits = logits.view(-1, vocab_size)
shift_labels = shift_labels.view(-1)
shift_labels = shift_labels.to(logits.device)
loss = _eaft_cross_entropy(logits, shift_labels, num_items_in_batch, alpha)
return loss
def _eaft_cross_entropy(
source: "torch.Tensor",
target: "torch.Tensor",
num_items_in_batch: Optional["torch.Tensor"] = None,
alpha: float = 1.0,
ignore_index: int = -100,
) -> "torch.Tensor":
per_token_loss = torch.nn.functional.cross_entropy(source, target, ignore_index=ignore_index, reduction="none")
valid_mask = target != ignore_index
if not valid_mask.any():
return torch.tensor(0.0, device=source.device, dtype=source.dtype)
valid_losses = per_token_loss[valid_mask]
with torch.no_grad():
source_detached = source[valid_mask].detach()
topk_val, _ = torch.topk(source_detached, k=20, dim=-1)
logsumexp_topk = torch.logsumexp(topk_val, dim=-1, keepdim=True)
log_probs_topk = topk_val - logsumexp_topk
probs_topk = torch.exp(log_probs_topk)
entropy_approx = -(probs_topk * log_probs_topk).sum(dim=-1)
entropy_term = entropy_approx / 3.0
adaptive_weight = torch.pow(entropy_term, alpha)
weighted_losses = valid_losses * adaptive_weight
if num_items_in_batch is not None:
total_loss = weighted_losses.sum()
if torch.is_tensor(num_items_in_batch):
num_items_in_batch = num_items_in_batch.to(total_loss.device)
loss = total_loss / num_items_in_batch
else:
loss = weighted_losses.mean()
return loss
def nested_detach(
tensors: Union["torch.Tensor", list["torch.Tensor"], tuple["torch.Tensor"], dict[str, "torch.Tensor"]],
clone: bool = False,
@@ -744,36 +807,88 @@ def get_swanlab_callback(finetuning_args: "FinetuningArguments") -> "TrainerCall
return swanlab_callback
def get_ray_trainer(
training_function: Callable,
train_loop_config: dict[str, Any],
ray_args: "RayArguments",
) -> "TorchTrainer":
if not ray_args.use_ray:
raise ValueError("Ray was not enabled. Please set `USE_RAY=1` to enable ray.")
def get_placement_group(num_workers: int) -> tuple["PlacementGroup", dict[str, int]]:
r"""Get the Ray placement group for distributed training."""
bundle = {"CPU": 10}
device_name = get_device_name().upper()
if device_name != "CPU":
bundle[device_name] = 1
bundles = [bundle for _ in range(num_workers)]
pg = placement_group(bundles, strategy="PACK")
if ray_args.ray_init_kwargs is not None:
ray.init(**ray_args.ray_init_kwargs)
return pg, bundle
if ray_args.ray_storage_filesystem is not None:
# this means we are using s3/gcs
storage_path = ray_args.ray_storage_path
else:
storage_path = Path(ray_args.ray_storage_path).absolute().as_posix()
trainer = TorchTrainer(
training_function,
train_loop_config=train_loop_config,
scaling_config=ScalingConfig(
num_workers=ray_args.ray_num_workers,
resources_per_worker=ray_args.resources_per_worker,
placement_strategy=ray_args.placement_strategy,
use_gpu=True,
def get_ray_remote_config_for_worker(
placement_group: "PlacementGroup",
bundle_idx: int,
rank: int,
world_size: int,
master_addr: str,
master_port: str,
env: dict[str, str] = None,
) -> dict[str, Any]:
r"""Get the remote config for a Ray worker."""
env_vars = {
"RANK": str(rank),
"WORLD_SIZE": str(world_size),
"MASTER_ADDR": master_addr,
"MASTER_PORT": master_port,
"TORCHELASTIC_USE_AGENT_STORE": "False",
}
env.update(env_vars)
remote_config = {
"scheduling_strategy": PlacementGroupSchedulingStrategy(
placement_group=placement_group,
placement_group_bundle_index=bundle_idx,
),
run_config=RunConfig(
name=ray_args.ray_run_name,
storage_filesystem=ray_args.ray_storage_filesystem,
storage_path=storage_path,
),
)
return trainer
"runtime_env": {"env_vars": env},
"num_cpus": 10,
}
device_name = get_device_name()
if device_name == "gpu":
remote_config["num_gpus"] = 1
elif device_name == "npu":
remote_config["resources"] = {"NPU": 1}
return remote_config
def get_ray_head_node_ip() -> str:
r"""Get the IP address of the Ray head node."""
head_ip = next(node["NodeManagerAddress"] for node in ray.nodes() if node.get("IsHead", False))
return head_ip
def sort_placement_group_by_node_ip(placement_group: "PlacementGroup", master_addr: str = None) -> list[int]:
r"""Sort the placement group bundles by their node IP addresses."""
@ray.remote
def _get_node_ip():
return ray.util.get_node_ip_address().strip("[]")
tasks = []
for bundle_idx in range(placement_group.bundle_count):
task = _get_node_ip.options(
scheduling_strategy=PlacementGroupSchedulingStrategy(
placement_group=placement_group,
placement_group_bundle_index=bundle_idx,
),
).remote()
tasks.append(task)
bundle_ips = ray.get(tasks)
bundle_node_ip_list = list(enumerate(bundle_ips))
sorted_bundle_node_ip_list = sorted(bundle_node_ip_list, key=lambda x: x[1])
sorted_bundle_indices = [item[0] for item in sorted_bundle_node_ip_list]
if master_addr is not None:
preferred_indices = [idx for idx, ip in bundle_node_ip_list if ip == master_addr]
if preferred_indices:
remaining = [i for i in sorted_bundle_indices if i not in preferred_indices]
sorted_bundle_indices = preferred_indices + remaining
return sorted_bundle_indices

View File

@@ -23,9 +23,9 @@ from transformers import EarlyStoppingCallback, PreTrainedModel
from ..data import get_template_and_fix_tokenizer
from ..extras import logging
from ..extras.constants import V_HEAD_SAFE_WEIGHTS_NAME, V_HEAD_WEIGHTS_NAME
from ..extras.misc import infer_optim_dtype
from ..extras.misc import find_available_port, get_device_name, get_torch_device, infer_optim_dtype
from ..extras.packages import is_mcore_adapter_available, is_ray_available
from ..hparams import get_infer_args, get_ray_args, get_train_args, read_args
from ..hparams import RayArguments, get_infer_args, get_ray_args, get_train_args, read_args
from ..model import load_model, load_tokenizer
from .callbacks import LogCallback, PissaConvertCallback, ReporterCallback
from .dpo import run_dpo
@@ -34,12 +34,17 @@ from .ppo import run_ppo
from .pt import run_pt
from .rm import run_rm
from .sft import run_sft
from .trainer_utils import get_ray_trainer, get_swanlab_callback
from .trainer_utils import (
get_placement_group,
get_ray_head_node_ip,
get_ray_remote_config_for_worker,
get_swanlab_callback,
sort_placement_group_by_node_ip,
)
if is_ray_available():
import ray
from ray.train.huggingface.transformers import RayTrainReportCallback
if TYPE_CHECKING:
@@ -115,13 +120,7 @@ def run_exp(args: Optional[dict[str, Any]] = None, callbacks: Optional[list["Tra
ray_args = get_ray_args(args)
callbacks = callbacks or []
if ray_args.use_ray:
callbacks.append(RayTrainReportCallback())
trainer = get_ray_trainer(
training_function=_training_function,
train_loop_config={"args": args, "callbacks": callbacks},
ray_args=ray_args,
)
trainer.fit()
_ray_training_function(ray_args, config={"args": args, "callbacks": callbacks})
else:
_training_function(config={"args": args, "callbacks": callbacks})
@@ -212,3 +211,94 @@ def export_model(args: Optional[dict[str, Any]] = None) -> None:
with open(ollama_modelfile, "w", encoding="utf-8") as f:
f.write(template.get_ollama_modelfile(tokenizer))
logger.info_rank0(f"Ollama modelfile saved in {ollama_modelfile}")
class Worker:
def __init__(self):
self._setup_env_visible_devices()
local_rank = os.environ.get("LOCAL_RANK", "0")
get_torch_device().set_device(int(local_rank))
def _setup_env_visible_devices(self) -> None:
RAY_NOSET_VISIBLE_DEVICES_LIST = [
"RAY_EXPERIMENTAL_NOSET_CUDA_VISIBLE_DEVICES",
"RAY_EXPERIMENTAL_NOSET_ASCEND_RT_VISIBLE_DEVICES",
]
is_ray_noset_visible_devices = any(os.environ.get(env_var, None) for env_var in RAY_NOSET_VISIBLE_DEVICES_LIST)
if is_ray_noset_visible_devices:
device_name = get_device_name().upper()
local_rank = ray.get_runtime_context().get_accelerator_ids()[device_name][0]
os.environ["LOCAL_RANK"] = local_rank
else:
os.environ["LOCAL_RANK"] = "0"
def _training_function(self, config: dict[str, Any]) -> None:
_training_function(config)
def _ray_training_function(ray_args: "RayArguments", config: dict[str, Any]) -> None:
num_workers = ray_args.ray_num_workers
master_addr = ray_args.master_addr
master_port = ray_args.master_port
logger.info(f"Using ray.remote mode with {num_workers} workers for distributed training.")
# initialize ray
if not ray.is_initialized():
if ray_args.ray_init_kwargs is not None:
ray.init(**ray_args.ray_init_kwargs)
else:
ray.init()
# verify resources
device_name = get_device_name().upper()
total_devices = int(ray.cluster_resources().get(device_name, 0))
if num_workers > total_devices:
raise ValueError(
f"The number of devices in the Ray cluster ({total_devices}) should be greater than num_workers ({num_workers})."
)
# verify master_addr
if master_addr is None:
master_addr = get_ray_head_node_ip()
logger.info(f"`master_addr` is not specified, using head node ip: {master_addr}.")
else:
nodes = [node["NodeManagerAddress"] for node in ray.nodes() if node["Alive"]]
if master_addr not in nodes:
raise ValueError(f"The `master_addr` ({master_addr}) is not in Ray cluster or not alive ")
# create placementgroup for resource management
pg, bundle = get_placement_group(total_devices)
ray.get(pg.ready())
logger.info(f"Create placement group with {num_workers} bundles: {bundle}")
# get sorted_bundle_indices
sorted_bundle_indices = sort_placement_group_by_node_ip(pg, master_addr)
# get master port
if master_port is None:
master_port = find_available_port()
logger.info(f"`master_port` is not specified, using available port: {master_port}.")
master_port = str(master_port)
# backing up environment variables
current_env = dict(os.environ.items())
# launch workers
RayWorker = ray.remote(Worker)
workers = []
for rank in range(num_workers):
remote_config = get_ray_remote_config_for_worker(
placement_group=pg,
bundle_idx=sorted_bundle_indices[rank],
rank=rank,
world_size=num_workers,
master_addr=master_addr,
master_port=master_port,
env=current_env,
)
worker = RayWorker.options(**remote_config).remote()
workers.append(worker)
ray.get([worker._training_function.remote(config=config) for worker in workers])
ray.shutdown()

View File

@@ -119,9 +119,19 @@ def synchronize() -> None:
@requires_accelerator
def set_device() -> None:
"""Set current accelerator."""
torch.accelerator.set_device_index(get_local_rank())
def set_device_index() -> None:
"""Set current accelerator index to local rank."""
if get_current_accelerator().type != DeviceType.CPU:
torch.accelerator.set_device_index(get_local_rank())
@requires_accelerator
def get_current_device() -> torch.device:
"""Get current accelerator device."""
if get_current_accelerator().type == DeviceType.CPU:
return torch.device(DeviceType.CPU.value)
else:
return torch.device(type=get_current_accelerator().type, index=torch.accelerator.current_device_index())
def is_torch_cuda_available():
@@ -170,6 +180,16 @@ def operate_tensorlike(fn: Callable[[...], Tensor], data: TensorLike, **kwargs)
return result.tolist()
def get_process_group_backend() -> str:
"""Get backend for init process group."""
if get_current_accelerator().type == DeviceType.NPU:
return "hccl"
elif get_current_accelerator().type == DeviceType.CUDA:
return "nccl"
else:
return "gloo"
def all_gather(tensor: Tensor, group: Optional[ProcessGroup] = None) -> Tensor:
"""Gathers the tensor from all ranks and stacks them at the first dim."""
world_size = get_world_size()

View File

@@ -34,10 +34,14 @@ from typing import Any, Optional
from torch.distributed import barrier, destroy_process_group, init_process_group
from torch.distributed.device_mesh import DeviceMesh, init_device_mesh
from ..utils.types import DistributedConfig, ProcessGroup, Tensor, TensorLike
from ..utils import logging
from ..utils.types import DistributedConfig, ProcessGroup, TensorLike
from . import helper
logger = logging.get_logger(__name__)
class Dim(str, Enum):
"""Dimension names."""
@@ -119,12 +123,13 @@ class DistributedInterface:
if self._initialized:
return
helper.set_device_index()
self._is_distributed = helper.is_distributed()
self._rank = helper.get_rank()
self._world_size = helper.get_world_size()
self._local_rank = helper.get_local_rank()
self._local_world_size = helper.get_local_world_size()
self.current_accelerator = helper.get_current_accelerator()
self.current_device = helper.get_current_device()
self.device_count = helper.get_device_count()
if config is None:
@@ -140,15 +145,14 @@ class DistributedInterface:
timeout = config.get("timeout", 18000)
if self._is_distributed:
helper.set_device()
init_process_group(timeout=timedelta(seconds=timeout))
init_process_group(timeout=timedelta(seconds=timeout), backend=helper.get_process_group_backend())
self.model_device_mesh = init_device_mesh(
device_type=self.current_accelerator.type,
device_type=self.current_device.type,
mesh_shape=self.strategy.model_mesh_shape,
mesh_dim_names=self.strategy.model_mesh_dim_names,
)
self.data_device_mesh = init_device_mesh(
device_type=self.current_accelerator.type,
device_type=self.current_device.type,
mesh_shape=self.strategy.data_mesh_shape,
mesh_dim_names=self.strategy.data_mesh_dim_names,
)
@@ -157,11 +161,12 @@ class DistributedInterface:
self.data_device_mesh = None
self._initialized = True
logger.info_rank0(f"DistributedInterface initialized: {self}.")
def __str__(self) -> str:
return (
f"DistributedInterface(strategy={self.strategy}), is_distributed={self._is_distributed}, "
f"current_accelerator={self.current_accelerator}, rank={self._rank}, world_size={self._world_size}, "
f"current_device={self.current_device}, rank={self._rank}, world_size={self._world_size}, "
f"model_device_mesh={self.model_device_mesh}, data_device_mesh={self.data_device_mesh}"
)
@@ -169,7 +174,7 @@ class DistributedInterface:
"""Get device mesh for specified dimension."""
if dim is None:
raise ValueError("dim must be specified.")
elif self.model_device_mesh is None:
elif not self._is_distributed:
return None
elif dim in self.strategy.data_mesh_dim_names:
return self.data_device_mesh[dim.value]
@@ -178,14 +183,14 @@ class DistributedInterface:
def get_group(self, dim: Dim | None = None) -> Optional[ProcessGroup]:
"""Get process group for specified dimension."""
if self.model_device_mesh is None or dim is None:
if not self._is_distributed or dim is None:
return None
else:
return self.get_device_mesh(dim).get_group()
def get_rank(self, dim: Dim | None = None) -> int:
"""Get parallel rank for specified dimension."""
if self.model_device_mesh is None:
if not self._is_distributed:
return 0
elif dim is None:
return self._rank
@@ -194,7 +199,7 @@ class DistributedInterface:
def get_world_size(self, dim: Dim | None = None) -> int:
"""Get parallel size for specified dimension."""
if self.model_device_mesh is None:
if not self._is_distributed:
return 1
elif dim is None:
return self._world_size
@@ -209,9 +214,9 @@ class DistributedInterface:
"""Get parallel local world size."""
return self._local_world_size
def all_gather(self, data: Tensor, dim: Dim | None = Dim.DP) -> Tensor:
def all_gather(self, data: TensorLike, dim: Dim | None = Dim.DP) -> TensorLike:
"""Gather tensor across specified parallel group."""
if self.model_device_mesh is not None:
if self._is_distributed:
return helper.operate_tensorlike(helper.all_gather, data, group=self.get_group(dim))
else:
return data
@@ -220,30 +225,36 @@ class DistributedInterface:
self, data: TensorLike, op: helper.ReduceOp = helper.ReduceOp.MEAN, dim: Dim | None = Dim.DP
) -> TensorLike:
"""Reduce tensor across specified parallel group."""
if self.model_device_mesh is not None:
if self._is_distributed:
return helper.operate_tensorlike(helper.all_reduce, data, op=op, group=self.get_group(dim))
else:
return data
def broadcast(self, data: TensorLike, src: int = 0, dim: Dim | None = Dim.DP) -> TensorLike:
"""Broadcast tensor across specified parallel group."""
if self.model_device_mesh is not None:
if self._is_distributed:
return helper.operate_tensorlike(helper.broadcast, data, src=src, group=self.get_group(dim))
else:
return data
def sync(self) -> None:
"""Synchronize all processes."""
helper.synchronize()
if self._is_distributed:
helper.synchronize()
def barrier(self) -> None:
"""Barrier all processes."""
barrier()
if self._is_distributed:
barrier()
def destroy(self) -> None:
"""Destroy all processes."""
destroy_process_group()
if self._is_distributed:
destroy_process_group()
if __name__ == "__main__":
print(DistributedInterface(DistributedStrategy()))
"""
python -m llamafactory.v1.accelerator.interface
"""
print(DistributedInterface())

View File

@@ -0,0 +1,33 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .arg_parser import InputArgument, get_args
from .arg_utils import BatchingStrategy, ModelClass, SampleBackend
from .data_args import DataArguments
from .model_args import ModelArguments
from .sample_args import SampleArguments
from .training_args import TrainingArguments
__all__ = [
"BatchingStrategy",
"DataArguments",
"InputArgument",
"ModelArguments",
"ModelClass",
"SampleArguments",
"SampleBackend",
"TrainingArguments",
"get_args",
]

View File

@@ -20,7 +20,7 @@ from typing import Any
from omegaconf import OmegaConf
from transformers import HfArgumentParser
from ...extras.misc import is_env_enabled
from ..utils.env import is_env_enabled
from .data_args import DataArguments
from .model_args import ModelArguments
from .sample_args import SampleArguments
@@ -30,24 +30,9 @@ from .training_args import TrainingArguments
InputArgument = dict[str, Any] | list[str] | None
def validate_args(
data_args: DataArguments,
model_args: ModelArguments,
training_args: TrainingArguments,
sample_args: SampleArguments,
):
"""Validate arguments."""
if (
model_args.quant_config is not None
and training_args.dist_config is not None
and training_args.dist_config.name == "deepspeed"
):
raise ValueError("Quantization is not supported with deepspeed backend.")
def get_args(args: InputArgument = None) -> tuple[DataArguments, ModelArguments, TrainingArguments, SampleArguments]:
def get_args(args: InputArgument = None) -> tuple[ModelArguments, DataArguments, TrainingArguments, SampleArguments]:
"""Parse arguments from command line or config file."""
parser = HfArgumentParser([DataArguments, ModelArguments, TrainingArguments, SampleArguments])
parser = HfArgumentParser([ModelArguments, DataArguments, TrainingArguments, SampleArguments])
allow_extra_keys = is_env_enabled("ALLOW_EXTRA_KEYS")
if args is None:
@@ -71,8 +56,6 @@ def get_args(args: InputArgument = None) -> tuple[DataArguments, ModelArguments,
print(f"Got unknown args, potentially deprecated arguments: {unknown_args}")
raise ValueError(f"Some specified arguments are not used by the HfArgumentParser: {unknown_args}")
validate_args(*parsed_args)
return tuple(parsed_args)

View File

@@ -17,7 +17,7 @@
import json
from enum import Enum, unique
from enum import StrEnum, unique
class PluginConfig(dict):
@@ -36,7 +36,7 @@ PluginArgument = PluginConfig | dict | str | None
@unique
class ModelClass(str, Enum):
class ModelClass(StrEnum):
"""Auto class for model config."""
LLM = "llm"
@@ -45,11 +45,19 @@ class ModelClass(str, Enum):
@unique
class SampleBackend(str, Enum):
class SampleBackend(StrEnum):
HF = "hf"
VLLM = "vllm"
@unique
class BatchingStrategy(StrEnum):
NORMAL = "normal"
PADDING_FREE = "padding_free"
DYNAMIC_BATCHING = "dynamic_batching"
DYNAMIC_PADDING_FREE = "dynamic_padding_free"
def _convert_str_dict(data: dict) -> dict:
"""Parse string representation inside the dictionary.

View File

@@ -18,11 +18,11 @@ from dataclasses import dataclass, field
@dataclass
class DataArguments:
dataset: str | None = field(
train_dataset: str | None = field(
default=None,
metadata={"help": "Path to the dataset."},
metadata={"help": "Path to the training dataset."},
)
cutoff_len: int = field(
default=2048,
metadata={"help": "Cutoff length for the dataset."},
eval_dataset: str | None = field(
default=None,
metadata={"help": "Path to the evaluation dataset."},
)

View File

@@ -21,20 +21,25 @@ from .arg_utils import ModelClass, PluginConfig, get_plugin_config
@dataclass
class ModelArguments:
model: str = field(
default="Qwen/Qwen3-4B-Instruct-2507",
metadata={"help": "Path to the model or model identifier from Hugging Face."},
)
template: str = field(
default="qwen3_nothink",
metadata={"help": "Template for the model."},
)
trust_remote_code: bool = field(
default=False,
metadata={"help": "Trust remote code from Hugging Face."},
)
use_fast_processor: bool = field(
default=True,
metadata={"help": "Use fast processor from Hugging Face."},
)
model_class: ModelClass = field(
default=ModelClass.LLM,
metadata={"help": "Model class from Hugging Face."},
)
init_config: PluginConfig | None = field(
default=None,
metadata={"help": "Initialization configuration for the model."},
)
peft_config: PluginConfig | None = field(
default=None,
metadata={"help": "PEFT configuration for the model."},
@@ -49,6 +54,7 @@ class ModelArguments:
)
def __post_init__(self) -> None:
self.init_config = get_plugin_config(self.init_config)
self.peft_config = get_plugin_config(self.peft_config)
self.kernel_config = get_plugin_config(self.kernel_config)
self.quant_config = get_plugin_config(self.quant_config)

View File

@@ -16,35 +16,73 @@ import os
from dataclasses import dataclass, field
from uuid import uuid4
from .arg_utils import PluginConfig, get_plugin_config
from .arg_utils import BatchingStrategy, PluginConfig, get_plugin_config
@dataclass
class TrainingArguments:
output_dir: str = field(
default=os.path.join("outputs", str(uuid4())),
default=os.path.join("outputs", str(uuid4().hex)),
metadata={"help": "Path to the output directory."},
)
micro_batch_size: int = field(
default=1,
metadata={"help": "Micro batch size for training."},
)
global_batch_size: int = field(
default=1,
metadata={"help": "Global batch size for training."},
global_batch_size: int | None = field(
default=None,
metadata={"help": "Global batch size for training, default to DP size * micro batch size."},
)
cutoff_len: int = field(
default=2048,
metadata={"help": "Maximum sequence length for training."},
)
learning_rate: float = field(
default=1e-4,
metadata={"help": "Learning rate for training."},
)
num_train_epochs: int = field(
default=3,
metadata={"help": "Number of training epochs."},
)
max_steps: int | None = field(
default=None,
metadata={"help": "Maximum number of training steps. If set, overrides num_train_epochs."},
)
max_grad_norm: float = field(
default=1.0,
metadata={"help": "Maximum gradient norm for training."},
)
bf16: bool = field(
default=False,
metadata={"help": "Use bf16 for training."},
)
batching_strategy: BatchingStrategy = field(
default=BatchingStrategy.NORMAL,
metadata={"help": "Batching strategy for training."},
)
batching_workers: int = field(
default=16,
metadata={"help": "Number of workers for batching."},
)
enable_activation_checkpointing: bool = field(
default=True,
metadata={"help": "Enable activation checkpointing for training."},
)
dist_config: PluginConfig | None = field(
default=None,
metadata={"help": "Distribution configuration for training."},
)
optim_config: PluginConfig | None = field(
default=None,
metadata={"help": "Optimizer configuration for training."},
)
lr_scheduler_config: PluginConfig | None = field(
default=None,
metadata={"help": "Learning rate scheduler configuration for training."},
)
def __post_init__(self) -> None:
self.dist_config = get_plugin_config(self.dist_config)
self.optim_config = get_plugin_config(self.optim_config)
self.lr_scheduler_config = get_plugin_config(self.lr_scheduler_config)

View File

@@ -0,0 +1,67 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from collections.abc import AsyncGenerator
from ..config import ModelArguments, SampleArguments, SampleBackend
from ..utils.types import HFModel, Message, Sample, TorchDataset
from .utils.inference_engine import HuggingFaceEngine
from .utils.rendering import Renderer
class BaseSampler:
"""Base sampler.
Args:
args: Sample arguments.
model_args: Model arguments.
model: Model.
renderer: Renderer.
"""
def __init__(
self,
args: SampleArguments,
model_args: ModelArguments,
model: HFModel,
renderer: Renderer,
) -> None:
if args.sample_backend == SampleBackend.HF:
self.engine = HuggingFaceEngine(args, model_args, model, renderer)
else:
raise ValueError(f"Unknown sample backend: {args.sample_backend}")
async def generate(self, messages: list[Message], tools: str | None = None) -> AsyncGenerator[str, None]:
"""Generate tokens asynchronously.
Args:
messages: List of messages.
tools: Tools string.
Yields:
Generated tokens.
"""
async for token in self.engine.generate(messages, tools):
yield token
async def batch_infer(self, dataset: TorchDataset) -> list[Sample]:
"""Batch infer samples.
Args:
dataset: Torch dataset.
Returns:
List of samples.
"""
return await self.engine.batch_infer(dataset)

View File

@@ -16,20 +16,33 @@
Init Phase:
1. Init dataloader.
1. Init batch generator.
2. Init optimizer (deepspeed).
3. Shard model.
4. Init optimizer (fsdp).
5. Init scheduler.
5. Init lr scheduler.
Train Phase:
1. Train Loop
"""
from ..config.training_args import TrainingArguments
from ..utils.types import HFModel, Processor, TorchDataset
from .trainer_utils.data_collator import DataCollator
from abc import abstractmethod
import torch
import torch.nn.functional as F
from ..accelerator.helper import ReduceOp
from ..accelerator.interface import Dim, DistributedInterface
from ..config import TrainingArguments
from ..utils import logging
from ..utils.helper import compute_valid_tokens
from ..utils.types import BatchInput, HFModel, ModelOutput, Tensor, TorchDataset
from .utils.batching import BatchGenerator
from .utils.rendering import Renderer
logger = logging.get_logger(__name__)
class BaseTrainer:
@@ -37,22 +50,160 @@ class BaseTrainer:
self,
args: TrainingArguments,
model: HFModel,
processor: Processor,
dataset: TorchDataset,
renderer: Renderer,
train_dataset: TorchDataset,
) -> None:
self.args = args
self.model = model
self.processor = processor
self.dataset = dataset
self.data_collator = DataCollator()
self.optimizer = None
self.lr_scheduler = None
self.renderer = renderer
self.train_dataset = train_dataset
def init_model_and_optimizer(self) -> None:
pass
# info
self.global_step = 0
def create_dataloader(self) -> None:
pass
# cached variables
self.device = DistributedInterface().current_device
self.dp_size = DistributedInterface().get_world_size(Dim.DP)
self.model_input_names = self.renderer.processor.model_input_names
self._create_batch_generator()
# Calculate num_training_steps: max_steps takes priority if set
if self.args.max_steps is not None and self.args.max_steps > 0:
self.num_training_steps = self.args.max_steps
else:
self.num_training_steps = self.args.num_train_epochs * len(self.train_batch_generator)
if self.args.enable_activation_checkpointing:
self.model.gradient_checkpointing_enable({"use_reentrant": False})
if self.args.dist_config is not None:
shard_need_optimizer = self.args.dist_config.name == "deepspeed"
else:
shard_need_optimizer = False
if shard_need_optimizer:
self._init_optimizer()
self._shard_model()
else:
self._shard_model()
self._init_optimizer()
self._init_lr_scheduler()
def _create_batch_generator(self) -> None:
self.train_batch_generator = BatchGenerator(
dataset=self.train_dataset,
renderer=self.renderer,
micro_batch_size=self.args.micro_batch_size,
global_batch_size=self.args.global_batch_size,
cutoff_len=self.args.cutoff_len,
batching_workers=self.args.batching_workers,
batching_strategy=self.args.batching_strategy,
)
def _shard_model(self) -> None:
if self.args.dist_config is None:
if DistributedInterface().get_world_size(Dim.DP) > 1:
from torch.nn.parallel import DistributedDataParallel as DDP
logger.warning_rank0(
"dist_config is None but distributed training is enabled; falling back to DistributedDataParallel."
)
device_ids = None if self.device.type == "cpu" else [self.device.index]
self.model = DDP(self.model, device_ids=device_ids)
else:
from ..plugins.trainer_plugins.distributed.hub import DistributedPlugin
self.model = DistributedPlugin(self.args.dist_config.name)(
self.model,
self.args.dist_config,
)
def _init_optimizer(self) -> None:
"""Init optimizer."""
if self.args.optim_config is None:
_trainable_params = [p for p in self.model.parameters() if p.requires_grad]
self.optimizer = torch.optim.AdamW(_trainable_params, lr=self.args.learning_rate)
else:
from ..plugins.trainer_plugins.optimizer import OptimizerPlugin
self.optimizer = OptimizerPlugin(self.args.optim_config.name)(self.model, self.args.optim_config)
def _init_lr_scheduler(self) -> None:
"""Init lr scheduler."""
if self.args.lr_scheduler_config is None:
self.lr_scheduler = torch.optim.lr_scheduler.LambdaLR(self.optimizer, lr_lambda=lambda x: 1.0)
else:
from ..plugins.trainer_plugins.lr_scheduler import LRSchedulerPlugin
self.lr_scheduler = LRSchedulerPlugin(self.args.lr_scheduler_config.name)(
self.optimizer, self.num_training_steps, self.args.lr_scheduler_config
)
def compute_log_probs(self, model: HFModel, batch: BatchInput) -> Tensor:
"""Compute log probs.
log_probs: Tensor of shape (batch_size, seq_len - 1)
"""
batch_size, _ = batch["labels"].shape
model_inputs = {
k: v.to(self.device, non_blocking=True) for k, v in batch.items() if k in self.model_input_names
}
labels = batch["labels"].to(self.device, non_blocking=True)
outputs: ModelOutput = model(**model_inputs)
logits = outputs.logits.float()
shift_labels = labels[..., 1:].contiguous().view(-1)
shift_logits = logits[..., :-1, :].contiguous().view(shift_labels.size(0), -1)
return -F.cross_entropy(shift_logits, shift_labels, reduction="none").view(batch_size, -1)
@abstractmethod
def compute_loss(self, batch: BatchInput) -> Tensor:
"""Compute the scalar loss."""
...
def fit(self) -> None:
pass
"""Train the model."""
self.model.train()
for epoch in range(self.args.num_train_epochs):
self.train_batch_generator.set_epoch(epoch)
for micro_batches in self.train_batch_generator:
self.global_step += 1
step_loss = 0
step_valid_tokens = compute_valid_tokens(micro_batches)
step_valid_tokens = DistributedInterface().all_reduce(step_valid_tokens, op=ReduceOp.SUM)
for micro_batch in micro_batches:
loss = self.compute_loss(micro_batch)
mini_step_valid_tokens = compute_valid_tokens([micro_batch])
# fsdp uses mean reduction so we need to scale the loss by dp_size
loss = loss * mini_step_valid_tokens * self.dp_size / (step_valid_tokens + 1e-6)
loss.backward()
step_loss += loss.item()
grad_norm = torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.args.max_grad_norm).item()
# isfinite(): argument 'input' (position 1) must be Tensor, not float
if not torch.isfinite(torch.tensor(grad_norm)): # type: ignore # pyright: ignore [reportUnknownReturnType]
logger.warning_rank0(f"Gradient norm is not finite: {grad_norm}")
else:
self.optimizer.step()
self.lr_scheduler.step()
self.optimizer.zero_grad()
step_loss, grad_norm = DistributedInterface().all_reduce([step_loss, grad_norm])
DistributedInterface().sync()
if DistributedInterface().get_rank() == 0:
print(f"Epoch {epoch}, Step {self.global_step}, Loss: {step_loss:.4f}, Grad Norm: {grad_norm:.4f}")
# Check if max_steps is reached
if self.global_step >= self.num_training_steps:
logger.info_rank0(f"Reached max_steps ({self.num_training_steps}), stopping training.")
return
def save_model(self) -> None:
"""Save the model."""
model_to_save = self.model.module if hasattr(self.model, "module") else self.model
model_to_save.save_pretrained(self.args.output_dir)
self.renderer.processor.save_pretrained(self.args.output_dir)
logger.info_rank0(f"Model saved to {self.args.output_dir}")

View File

@@ -1,44 +0,0 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from abc import ABC, abstractmethod
from ..config.sample_args import SampleArguments, SampleBackend
from .model_loader import ModelLoader
class BaseEngine(ABC):
@abstractmethod
def __init__(self, sample_args: SampleArguments, model_loader: ModelLoader) -> None: ...
@abstractmethod
async def generate(self):
pass
@abstractmethod
async def batch_infer(self):
pass
class HuggingFaceEngine(BaseEngine):
def __init__(self, model_loader: ModelLoader, sample_args: SampleArguments) -> None:
self.args = sample_args
class ChatSampler:
def __init__(self, model_loader: ModelLoader, sample_args: SampleArguments) -> None:
if sample_args.sample_backend == SampleBackend.HF:
self.engine = HuggingFaceEngine(model_loader, sample_args)
else:
raise ValueError(f"Unknown sample backend: {sample_args.sample_backend}")

View File

@@ -14,15 +14,23 @@
"""The definition of data engine.
Init Data engine:
How to use:
data_engine = DataEngine(data_args.train_dataset)
data_engine[i]: Get the sample via index.
Init workflow:
1. Parse dataset info from arguments.
2. Load datasets according to dataset info.
3. Build data index (and reweight samples if necessary).
Get Data Sample:
Get data sample:
1. Get sample from data index.
2. Convert sample to standard format.
3. Return sample.
Note:
1. The data engine is equivalent to the torch dataset.
2. The data engine is agnostic to the model used.
"""
import os
@@ -33,7 +41,6 @@ from huggingface_hub import hf_hub_download
from omegaconf import OmegaConf
from torch.utils.data import Dataset
from ..config.data_args import DataArguments
from ..utils.types import DatasetInfo, HFDataset, Sample
@@ -44,9 +51,9 @@ class DataEngine(Dataset):
data_args: Data arguments.
"""
def __init__(self, data_args: DataArguments) -> None:
self.args = data_args
"""Data arguments."""
def __init__(self, dataset_path: str) -> None:
self.path = dataset_path
"""Dataset path."""
self.datasets: dict[str, HFDataset] = {}
"""Dict of (dataset_name, dataset)"""
self.dataset_infos: dict[str, DatasetInfo] = {}
@@ -61,27 +68,30 @@ class DataEngine(Dataset):
def _get_dataset_info(self) -> None:
"""Get dataset info from data arguments."""
if self.args.dataset.endswith(".yaml") and os.path.isfile(self.args.dataset): # local file
self.dataset_infos = OmegaConf.load(self.args.dataset)
elif self.args.dataset.endswith(".yaml"): # hf hub uri, e.g. llamafactory/v1-sft-demo/dataset_info.yaml
repo_id, filename = os.path.split(self.args.dataset)
if self.path.endswith(".yaml") and os.path.isfile(self.path): # local file
self.dataset_infos = OmegaConf.load(self.path)
elif self.path.endswith(".yaml"): # hf hub uri, e.g. llamafactory/v1-sft-demo/dataset_info.yaml
repo_id, filename = os.path.split(self.path)
filepath = hf_hub_download(repo_id=repo_id, filename=filename, repo_type="dataset")
self.dataset_infos = OmegaConf.load(filepath)
elif os.path.exists(self.args.dataset): # local file(s)
self.dataset_infos = {"default": {"path": self.args.dataset, "source": "local"}}
elif os.path.exists(self.path): # local file(s)
self.dataset_infos = {"default": {"path": self.path, "source": "local"}}
else: # hf hub dataset, e.g. llamafactory/v1-sft-demo
self.dataset_infos = {"default": {"path": self.args.dataset}}
self.dataset_infos = {"default": {"path": self.path}}
def _load_dataset(self) -> None:
"""Load datasets according to dataset info."""
is_streaming = [dataset_info.get("streaming", False) for dataset_info in self.dataset_infos.values()]
self.streaming = any(is_streaming)
if all(is_streaming) != any(is_streaming):
raise ValueError("All datasets must be streaming or non-streaming.")
for dataset_name, dataset_info in self.dataset_infos.items():
split = dataset_info.get("split", "train")
streaming = dataset_info.get("streaming", False)
self.streaming |= streaming
if dataset_info.get("source", "hf_hub") == "hf_hub":
from datasets import load_dataset
self.datasets[dataset_name] = load_dataset(dataset_info["path"], split=split, streaming=streaming)
self.datasets[dataset_name] = load_dataset(dataset_info["path"], split=split, streaming=self.streaming)
else: # data loader plugin
from ..plugins.data_plugins.loader import DataLoaderPlugin
@@ -90,18 +100,17 @@ class DataEngine(Dataset):
def _build_data_index(self) -> None:
"""Build dataset index."""
for dataset_name, dataset in self.datasets.items():
streaming = self.dataset_infos[dataset_name].get("streaming", False)
if streaming:
if self.streaming:
data_index = [(dataset_name, -1) for _ in range(1000)]
else:
data_index = [(dataset_name, sample_index) for sample_index in range(len(dataset))]
size = self.dataset_infos[dataset_name].get("size")
weight = self.dataset_infos[dataset_name].get("weight")
if size or weight: # data index plugin
from ..plugins.data_plugins.loader import DataIndexPlugin
if size or weight:
from ..plugins.data_plugins.loader import adjust_data_index
data_index = DataIndexPlugin().adjust_data_index(data_index, size, weight)
data_index = adjust_data_index(data_index, size, weight)
self.data_index.extend(data_index)
@@ -150,9 +159,9 @@ class DataEngine(Dataset):
dataset_name, sample_index = self.data_index[index]
return self._convert_data_sample(self.datasets[dataset_name][sample_index], dataset_name)
else: # data selector plugin
from ..plugins.data_plugins.loader import DataSelectorPlugin
from ..plugins.data_plugins.loader import select_data_sample
selected_index = DataSelectorPlugin().select(self.data_index, index)
selected_index = select_data_sample(self.data_index, index)
if isinstance(selected_index, list):
return [
self._convert_data_sample(self.datasets[dataset_name][sample_index], dataset_name)
@@ -177,11 +186,11 @@ class DataEngine(Dataset):
if __name__ == "__main__":
"""
python -m llamafactory.v1.core.data_engine --model none --dataset data/v1_sft_demo.yaml
python -m llamafactory.v1.core.data_engine --model none --dataset data/v1_dpo_demo.yaml
python -m llamafactory.v1.core.data_engine --train_dataset data/v1_sft_demo.yaml
python -m llamafactory.v1.core.data_engine --train_dataset data/v1_dpo_demo.yaml
"""
from ..config.arg_parser import get_args
data_args, *_ = get_args()
data_engine = DataEngine(data_args=data_args)
_, data_args, *_ = get_args()
data_engine = DataEngine(data_args.train_dataset)
print(data_engine[0])

View File

@@ -12,34 +12,44 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""The definition of model loader.
"""The definition of model engine.
Init Phase:
How to use:
model_engine = ModelEngine(model_args, is_train=True)
model_engine.processor: Get the tokenizer or multi-modal processor.
model_engine.renderer: Get the renderer.
model_engine.model_config: Get the model configuration.
model_engine.model: Get the HF model.
Init workflow:
1. Init processor.
2. Init render.
2. Init model config.
3. Init model.
4. Init adapter.
"""
import torch
from accelerate import init_empty_weights
from transformers import AutoConfig, AutoProcessor
from ..accelerator.helper import DeviceType
from ..accelerator.interface import DistributedInterface
from ..config.model_args import ModelArguments, ModelClass
from ..utils import logging
from ..utils.types import HFConfig, HFModel, Processor
from .utils.rendering import Renderer
logger = logging.get_logger(__name__)
class ModelLoader:
"""Model loader.
class ModelEngine:
"""Model engine.
Args:
model_args: Model arguments.
is_trainable: Whether to train the model.
is_train: Whether to train the model.
"""
def __init__(self, model_args: ModelArguments, is_train: bool = False) -> None:
@@ -49,17 +59,22 @@ class ModelLoader:
"""Whether to train the model."""
self.processor = self._init_processor()
"""Tokenizer or multi-modal processor."""
self.renderer = Renderer(self.args.template, self.processor)
"""Renderer."""
self.model_config = self._init_model_config()
"""Model configuration."""
self.model = self._init_model()
"""HF model."""
def _init_processor(self) -> Processor:
"""Init processor."""
"""Init processor.
NOTE: Transformers v5 always use fast tokenizer.
https://github.com/huggingface/transformers/blob/v5.0.0rc1/src/transformers/models/auto/tokenization_auto.py#L642
"""
return AutoProcessor.from_pretrained(
self.args.model,
trust_remote_code=self.args.trust_remote_code,
use_fast=self.args.use_fast_processor,
)
def _init_model_config(self) -> HFConfig:
@@ -72,7 +87,7 @@ class ModelLoader:
def _init_model(self) -> HFModel:
"""Init model.
Let transformers handle the model init context.
Transformers can choose the proper model init context.
https://github.com/huggingface/transformers/blob/v5.0.0rc0/src/transformers/modeling_utils.py#L3538
"""
if self.args.model_class == ModelClass.LLM:
@@ -92,14 +107,24 @@ class ModelLoader:
AutoClass = AutoModel
# map the entire model to the current accelerator
model = AutoClass.from_pretrained(
self.args.model,
config=self.model_config,
dtype="auto",
device_map=DistributedInterface().current_accelerator,
trust_remote_code=self.args.trust_remote_code,
)
if self.args.init_config is not None:
from ..plugins.model_plugins.initialization import InitPlugin
init_device = InitPlugin(self.args.init_config.name)()
else:
init_device = DistributedInterface().current_device
if init_device.type == DeviceType.META:
with init_empty_weights():
model = AutoClass.from_config(self.model_config)
else:
model = AutoClass.from_pretrained(
self.args.model,
config=self.model_config,
dtype="auto",
device_map=init_device,
trust_remote_code=self.args.trust_remote_code,
)
if self.args.peft_config is None:
if self.is_train:
@@ -116,7 +141,7 @@ class ModelLoader:
from ..plugins.model_plugins.kernels.interface import KernelPlugin
model = KernelPlugin(self.args.kernel_config.name)(
model=model, include_kernels=self.args.kernel_config.get("include_kernels")
model, include_kernels=self.args.kernel_config.get("include_kernels")
)
return model
@@ -124,12 +149,12 @@ class ModelLoader:
if __name__ == "__main__":
"""
python -m llamafactory.v1.core.model_loader --model llamafactory/tiny-random-qwen2.5
python -m llamafactory.v1.core.model_engine --model llamafactory/tiny-random-qwen2.5
"""
from ..config.arg_parser import get_args
_, model_args, *_ = get_args()
model_loader = ModelLoader(model_args=model_args)
print(model_loader.processor)
print(model_loader.model_config)
print(model_loader.model)
model_args, *_ = get_args()
model_engine = ModelEngine(model_args=model_args)
print(model_engine.processor)
print(model_engine.model_config)
print(model_engine.model)

View File

@@ -1,119 +0,0 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from collections import defaultdict
from collections.abc import Sequence
from dataclasses import dataclass
from typing import Any
import torch
import torch.nn.functional as F
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data._utils.collate import default_collate
from ....extras.constants import IGNORE_INDEX
from ...plugins.data_plugins.template import Template
from ...utils.types import Processor, Tensor
def len2culen(seqlens: "torch.Tensor") -> "torch.Tensor": # FIXME move to utils
"""Convert sequence lengths to cumulative sequence lengths."""
return F.pad(torch.cumsum(seqlens, dim=0), (1, 0)).type(torch.int32)
class DataCollator:
"""Default Data collator."""
processor: "Processor" # processor name -> map to encode_messages function
def __post_init__(self):
# callback for text tokenizer
self.tokenizer = self.processor.tokenizer if hasattr(self.processor, "tokenizer") else self.processor
def __call__(self, features: list[dict[str, Any]]) -> dict[str, Tensor]:
"""Collate features into a batch."""
batch = defaultdict(list)
# batching features
for feature in features:
for key in feature.keys():
batch[key].append(feature[key])
for key in batch.keys():
# process padding features
if key in ["input_ids", "attention_mask", "position_ids"]:
padding_value = self.tokenizer.pad_token_id if key == "input_ids" else 0
batch[key] = pad_sequence(batch[key], batch_first=True, padding_value=padding_value)
elif key in ["labels"]:
batch[key] = pad_sequence(batch[key], batch_first=True, padding_value=IGNORE_INDEX)
else:
batch[key] = default_collate(batch[key])
return batch
# sft: messages
# dpo: chosen_messages, rejected_messages
@dataclass
class DefaultCollator(DataCollator):
"""Example for now."""
processor: "Processor" # processor name -> map to encode_messages function
template: "Template"
def __call__(self, messages: list[list[dict[str, Any]]]) -> dict[str, Tensor]:
features = []
# Check if data is already tokenized (contains input_ids)
if messages and isinstance(messages[0], dict) and "input_ids" in messages[0]:
for feature in messages:
if not isinstance(feature, dict):
raise ValueError(f"Expected dict but got {type(feature)}")
tensor_feature = {
k: torch.tensor(v, dtype=torch.long) if not isinstance(v, torch.Tensor) else v
for k, v in feature.items()
}
features.append(tensor_feature)
else:
# raw messages need to be encoded
for message in messages:
encoded_message = self.template.encode_messages(self.tokenizer, message)
encoded_message = {k: torch.tensor(v, dtype=torch.long) for k, v in encoded_message.items()}
features.append(encoded_message)
return super().__call__(features)
@dataclass
class PairwiseCollator(DataCollator):
pass
@dataclass
class DataCollatorWithPacking(DefaultCollator):
"""Data collator with packing."""
processor: "Processor"
template: "Template"
def __call__(self, features: Sequence[dict[str, "torch.Tensor"]]) -> dict[str, "torch.Tensor"]:
seqlens = torch.tensor([len(feature["input_ids"]) for feature in features], dtype=torch.long)
batch = {"cu_seqlens": len2culen(seqlens)}
for input_name in features[0].keys():
if input_name in ("input_ids", "attention_mask", "labels"):
batch[input_name] = torch.cat([feature[input_name] for feature in features])
else:
batch[input_name] = default_collate([feature[input_name] for feature in features])
return batch

View File

@@ -1,277 +0,0 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
import sys
from collections.abc import Generator, Iterator
from dataclasses import dataclass
from typing import Optional
from torchdata.stateful_dataloader import StatefulDataLoader
from torchdata.stateful_dataloader.sampler import StatefulDistributedSampler
from ...utils.batching_queue import BaseBatchingQueue
from ...utils.logging import get_logger
from ...utils.types import Processor, TorchDataset
from .data_collator import DataCollator
logger = get_logger(__name__)
# base dataloader
class DistributedDataloader(StatefulDataLoader):
"""Base Distributed DataLoader."""
dataset: "TorchDataset"
sampler: "StatefulDistributedSampler"
def set_epoch(self, epoch: int) -> None:
if self.sampler is not None and hasattr(self.sampler, "set_epoch"):
self.sampler.set_epoch(epoch)
elif hasattr(self.dataset, "set_epoch"):
self.dataset.set_epoch(epoch)
@dataclass
class BaseDataLoader:
"""Default DataLoader."""
processor: Processor
def __init__(self, dataset: TorchDataset) -> None:
self.dataset = dataset
# guidlines: fetch until get fixed batchsize.
# save state_dict for buffer.
# resume with state
# 1. Init stateful dataloader (tokenize)
# 2. Add to buffer (2 * max seq len per device)
# 3. Yield batch indexes (micro batch * grad acc)
# a ) non pack + non dynamic
# b ) non pack + dynamic
# c ) pack + non dynamic
# d ) pack + dynamic
def init_dataloader(self) -> None:
### init dataloader
pass
def __iter__(self) -> Iterator:
pass
def __next__(self) -> any:
pass
@dataclass
class DataLoader:
"""Default DataLoader."""
processor: "Processor"
dataloader: "DistributedDataloader"
batching_queue: "BaseBatchingQueue"
collate_fn: "DataCollator"
num_micro_batch: int = 1
length: int = 0
drop_last: bool = True
def __init__(
self,
dataloader: any,
collate_fn: "DataCollator",
num_micro_batch: int = 1,
length: int = 0,
drop_last: bool = True,
batching_queue: Optional["BaseBatchingQueue"] = None,
) -> None:
self.batching_queue = batching_queue
self.num_micro_batch = num_micro_batch
self.step = 0
self._collate_fn = collate_fn
self._dataloader = dataloader
self._drop_last = drop_last
self._data_iter: Iterator
self._resume = False
self._batch_data_iter: Generator
if length > 0:
self._length = length
elif length == -1:
self._length = sys.maxsize
else:
self._length = len(self._dataloader)
def __len__(self):
return self._length
def __iter__(self) -> Iterator:
if not self._resume:
self.step = 0
self._data_iter = iter(self._dataloader)
self._batch_data_iter = self.batch_data_generator()
self._resume = False
return self
def __next__(self):
return next(self._batch_data_iter) # FIXME maybe we can move origin_batch_data_generator to here
def origin_batch_data_generator(self):
"""Standard pass-through generator if do not use batching queue."""
while True:
if self._length > 0 and self.step >= self._length:
return
try:
batch = []
data = next(self._data_iter)
# split data into micro batches
for i in range(0, len(data), self.num_micro_batch):
micro_batch = data[i : i + self.num_micro_batch]
if self._collate_fn:
micro_batch = self._collate_fn(micro_batch)
batch.append(micro_batch)
yield batch
self.step += 1
except StopIteration:
if self.step < self._length:
# Restart iterator to fill the requested length
self._data_iter = iter(self._dataloader)
try:
batch = []
data = next(self._data_iter)
for i in range(0, len(data), self.num_micro_batch):
micro_batch = data[i : i + self.num_micro_batch]
if self._collate_fn:
micro_batch = self._collate_fn(micro_batch)
batch.append(micro_batch)
yield batch
self.step += 1
except StopIteration:
return
else:
return
except Exception as e:
logger.error(f"DataLoader origin_batch_data_generator exception: {e}")
raise
def batch_data_generator(self):
if self.batching_queue is None:
yield from self.origin_batch_data_generator()
return
batch = []
while True:
if self._length and self.step >= self._length:
return
if self.batching_queue.is_full_filled():
micro_batch = self.batching_queue.get_micro_batch(self.step)
if self._collate_fn:
micro_batch = self._collate_fn(micro_batch)
batch.append(micro_batch)
if len(batch) == self.num_micro_batch:
yield batch
self.step += 1
batch = []
try:
processing_item = next(self._data_iter)
except Exception as e:
if isinstance(e, StopIteration):
if self.step < self._length:
# call iter until reach length
self._data_iter = iter(self._dataloader)
processing_item = next(self._data_iter)
elif not self._drop_last and not self.batching_queue.empty():
while not self.batching_queue.empty():
micro_batch = self.batching_queue.get_micro_batch(self.step)
if self._collate_fn:
micro_batch = self._collate_fn(micro_batch)
batch.append(micro_batch)
if len(batch) == self.num_micro_batch:
yield batch
self.step += 1
batch = []
while len(batch) < self.num_micro_batch:
padding_batch = copy.deepcopy(micro_batch)
padding_batch["is_padded"] = True
batch.append(padding_batch)
yield batch
self.step += 1
return
else:
return
else:
logger.error(f"DataLoader iter data exception: {e}")
raise
# put processing_item to buffer
if isinstance(processing_item, dict):
processing_item = [processing_item]
for item in processing_item:
self.batching_queue.put_item(item)
def state_dict(self):
# save state
state = self.__dict__.copy()
# remove internal fields
for k in list(state.keys()):
if k.startswith("_"):
del state[k]
# save dataloader state
if hasattr(self._dataloader, "state_dict"):
state["dataloader_state"] = self._dataloader.state_dict()
elif hasattr(self._dataloader, "__getstate__"):
state["dataloader_state"] = self._dataloader.__getstate__()
batching_strategy = getattr(self, "batching_strategy", None)
if batching_strategy and hasattr(batching_strategy, "state_dict"):
state["batching_strategy_state"] = batching_strategy.state_dict()
if "batching_strategy" in state:
del state["batching_strategy"]
return copy.deepcopy(state)
def load_state_dict(self, state: dict[str, any]):
if state["num_micro_batch"] != self.num_micro_batch:
logger.warning(
f"num_micro_batch changed: [ {state['num_micro_batch']} -> {self.num_micro_batch} ], will clear prefetch buffer"
)
del state["num_micro_batch"]
self.__dict__.update(state)
self._resume = True
if hasattr(self._dataloader, "load_state_dict"):
self._dataloader.load_state_dict(state["dataloader_state"])
elif hasattr(self._dataloader, "__getstate__"):
self._dataloader.__setstate__(state["dataloader_state"])
if "batching_strategy_state" in state:
batching_strategy = getattr(self, "batching_strategy", None)
if batching_strategy:
batching_strategy.load_state_dict(state["batching_strategy_state"])
del state["batching_strategy_state"]
self._data_iter = iter(self._dataloader)
self._batch_data_iter = self.batch_data_generator()
def set_epoch(self, epoch: int) -> None:
if hasattr(self._dataloader, "set_epoch"):
self._dataloader.set_epoch(epoch)

View File

@@ -0,0 +1,244 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Batching utils supports stateful dataloader.
1. Init stateful dataloader (tokenize)
2. Add to buffer
3. Yield batch indexes (micro batch * grad acc)
a) non pack + non dynamic
b) non pack + dynamic
c) pack + non dynamic
d) pack + dynamic
"""
from collections.abc import Iterator
from typing import Any
from torch.utils.data import default_collate
from torchdata.stateful_dataloader import StatefulDataLoader
from torchdata.stateful_dataloader.sampler import StatefulDistributedSampler
from ...accelerator.interface import Dim, DistributedInterface
from ...config import BatchingStrategy
from ...utils import logging
from ...utils.helper import pad_and_truncate
from ...utils.objects import StatefulBuffer
from ...utils.types import BatchInfo, BatchInput, ModelInput, TorchDataset
from .rendering import Renderer
logger = logging.get_logger(__name__)
def default_collate_fn(buffer: StatefulBuffer, batch_info: BatchInfo) -> list[BatchInput] | None:
micro_batch_size = batch_info["micro_batch_size"]
num_micro_batch = batch_info["num_micro_batch"]
cutoff_len = batch_info["cutoff_len"]
batch_size = micro_batch_size * num_micro_batch
if len(buffer) < batch_size:
return None
samples = buffer.get(batch_size)
batch = []
for i in range(num_micro_batch):
micro_batch = samples[i * micro_batch_size : (i + 1) * micro_batch_size]
batch.append(default_collate(pad_and_truncate(micro_batch, cutoff_len)))
return batch
class BatchGenerator(Iterator):
def __init__(
self,
dataset: TorchDataset,
renderer: Renderer,
micro_batch_size: int = 1,
global_batch_size: int | None = None,
cutoff_len: int = 2048,
batching_workers: int = 0,
batching_strategy: BatchingStrategy = BatchingStrategy.NORMAL,
pin_memory: bool = True,
drop_last: bool = True,
) -> None:
self.dataset = dataset
self.renderer = renderer
self.micro_batch_size = micro_batch_size
self.global_batch_size = global_batch_size
self.cutoff_len = cutoff_len
self.batching_workers = batching_workers
self.batching_strategy = batching_strategy
self.pin_memory = pin_memory
self.drop_last = drop_last
# TODO: support length and infinity
dp_size = DistributedInterface().get_world_size(Dim.DP)
if self.global_batch_size is None:
self.global_batch_size = dp_size * micro_batch_size
self.num_micro_batch = 1
elif self.global_batch_size % (dp_size * micro_batch_size) == 0:
self.num_micro_batch = global_batch_size // dp_size // micro_batch_size
else:
raise ValueError(
"Global batch size must be divisible by DP size and micro batch size. "
f"Got {global_batch_size} % ({dp_size} * {micro_batch_size}) != 0."
)
if not self.drop_last:
raise ValueError("Drop last must be True.")
self._init_data_provider()
self._is_resuming: bool = False
self._data_iter = iter(self._data_provider)
self._buffer = StatefulBuffer()
self._batch_info: BatchInfo = {
"micro_batch_size": self.micro_batch_size,
"num_micro_batch": self.num_micro_batch,
"cutoff_len": self.cutoff_len,
"data_iter": self._data_iter,
}
logger.info_rank0(
f"Init unified data loader with global batch size {self.global_batch_size}, "
f"micro batch size {self.micro_batch_size}, "
f"num micro batch {self.num_micro_batch}, "
f"cutoff len {self.cutoff_len}, "
f"batching workers {self.batching_workers}, "
f"batching strategy {self.batching_strategy}."
)
def _init_data_provider(self) -> None:
if len(self.dataset) != -1:
sampler = StatefulDistributedSampler(
self.dataset,
num_replicas=DistributedInterface().get_world_size(Dim.DP),
rank=DistributedInterface().get_rank(Dim.DP),
shuffle=True,
seed=0,
drop_last=self.drop_last,
)
else:
raise NotImplementedError("Iterable dataset is not supported yet.")
self._data_provider = StatefulDataLoader(
self.dataset,
batch_size=self.micro_batch_size * self.num_micro_batch,
sampler=sampler,
num_workers=self.batching_workers,
collate_fn=self.renderer.process_samples,
pin_memory=self.pin_memory,
pin_memory_device=DistributedInterface().current_device.type,
drop_last=self.drop_last,
)
if self.batching_strategy == BatchingStrategy.NORMAL:
self._length = len(self._data_provider)
else:
from ...plugins.trainer_plugins.batching import BatchingPlugin
self._length = BatchingPlugin(self.batching_strategy).compute_length(self._data_provider)
raise NotImplementedError("Batching strategy other than NORMAL is not supported yet.")
def __len__(self) -> int:
return self._length
def __iter__(self):
if not self._is_resuming:
self._buffer.clear()
self._buffer_tokens = 0
self._data_iter = iter(self._data_provider)
self._is_resuming = False
return self
def __next__(self):
self._fill_buffer()
batch = self._generate_batch()
if batch is None:
raise StopIteration
return batch
def _fill_buffer(self) -> None:
if self.batching_strategy == BatchingStrategy.NORMAL:
while len(self._buffer) < self.micro_batch_size * self.num_micro_batch:
try:
samples: list[ModelInput] = next(self._data_iter)
except StopIteration:
break
self._buffer.put(samples)
else:
from ...plugins.trainer_plugins.batching import BatchingPlugin
BatchingPlugin(self.batching_strategy).fill_buffer(self._buffer, self._batch_info)
def _generate_batch(self) -> list[BatchInput] | None:
if self.batching_strategy == BatchingStrategy.NORMAL:
return default_collate_fn(self._buffer, self._batch_info)
else:
from ...plugins.trainer_plugins.batching import BatchingPlugin
return BatchingPlugin(self.batching_strategy).generate_batch(self._buffer, self._batch_info)
def state_dict(self) -> dict[str, Any]:
return {
"buffer": self._buffer,
"buffer_tokens": self._buffer_tokens,
"data_provider": self._data_provider.state_dict(),
}
def load_state_dict(self, state: dict[str, Any]) -> None:
self._buffer = state["buffer"]
self._buffer_tokens = state["buffer_tokens"]
self._data_provider.load_state_dict(state["data_provider"])
self._is_resuming = True
def set_epoch(self, epoch: int) -> None:
if hasattr(self._data_provider.sampler, "set_epoch"):
self._data_provider.sampler.set_epoch(epoch)
if __name__ == "__main__":
"""
python -m llamafactory.v1.core.utils.batching \
--model llamafactory/tiny-random-qwen2.5 \
--train_dataset data/v1_sft_demo.yaml \
--micro_batch_size 2 \
--global_batch_size 4 \
--batching_workers 0
"""
from ...config.arg_parser import get_args
from ..data_engine import DataEngine
from ..model_engine import ModelEngine
model_args, data_args, training_args, _ = get_args()
data_engine = DataEngine(data_args.train_dataset)
model_engine = ModelEngine(model_args=model_args)
batch_generator = BatchGenerator(
data_engine,
model_engine.renderer,
micro_batch_size=training_args.micro_batch_size,
global_batch_size=training_args.global_batch_size,
cutoff_len=training_args.cutoff_len,
batching_workers=training_args.batching_workers,
batching_strategy=training_args.batching_strategy,
)
for batch in batch_generator:
print(batch)
print(len(batch))
print(batch[0]["input_ids"].shape)
break

View File

@@ -0,0 +1,121 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
import os
from abc import ABC, abstractmethod
from collections.abc import AsyncGenerator
from threading import Thread
import torch
from transformers import AsyncTextIteratorStreamer
from ...accelerator.interface import DistributedInterface
from ...config import ModelArguments, SampleArguments
from ...utils.helper import get_tokenizer
from ...utils.types import HFModel, Message, Sample, TorchDataset
from .rendering import Renderer
class BaseEngine(ABC):
@abstractmethod
def __init__(
self,
args: SampleArguments,
model_args: ModelArguments,
model: HFModel,
renderer: Renderer,
) -> None:
"""Initialize the engine.
Args:
args: Sample arguments.
model_args: Model arguments.
model: Model.
renderer: Renderer.
"""
...
@abstractmethod
async def generate(self, messages: list[Message], tools: str | None = None) -> AsyncGenerator[str, None]:
"""Generate tokens asynchronously.
Args:
messages: List of messages.
tools: Tools string.
Yields:
Generated tokens.
"""
...
@abstractmethod
async def batch_infer(self, dataset: TorchDataset) -> list[Sample]:
"""Batch infer samples.
Args:
dataset: Torch dataset.
Returns:
List of samples.
"""
...
class HuggingFaceEngine(BaseEngine):
def __init__(
self,
args: SampleArguments,
model_args: ModelArguments,
model: HFModel,
renderer: Renderer,
) -> None:
self.args = args
self.model_args = model_args
self.model = model
self.renderer = renderer
self.semaphore = asyncio.Semaphore(int(os.getenv("MAX_CONCURRENT", "1")))
@torch.inference_mode()
async def generate(self, messages: list[Message], tools: str | None = None) -> AsyncGenerator[str, None]:
async with self.semaphore:
model_inputs = self.renderer.render_messages(messages, tools, is_generate=True)
streamer = AsyncTextIteratorStreamer(
tokenizer=get_tokenizer(self.renderer.processor),
skip_prompt=True,
skip_special_tokens=True, # TODO: configurable
)
device = DistributedInterface().current_device
kwargs = {
"input_ids": torch.tensor([model_inputs["input_ids"]]).to(device),
"attention_mask": torch.tensor([model_inputs["attention_mask"]]).to(device),
"max_new_tokens": self.args.max_new_tokens,
"streamer": streamer,
}
thread = Thread(target=self.model.generate, kwargs=kwargs, daemon=True)
thread.start()
async for token in streamer:
yield token
async def batch_infer(self, dataset: TorchDataset) -> list[Sample]:
"""Batch infer samples.
Args:
dataset: Torch dataset.
Returns:
List of samples.
"""
raise NotImplementedError("Batch infer is not implemented.")

View File

@@ -0,0 +1,169 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Rendering utils.
How to use:
renderer = Renderer(template, processor)
renderer.render_messages(messages: list[Message], tools: str | None) -> ModelInputs
renderer.parse_message(text: str) -> Message
renderer.process_samples(samples: list[Sample]) -> list[ModelInput]
"""
import numpy as np
from ...utils.constants import IGNORE_INDEX
from ...utils.helper import get_tokenizer
from ...utils.types import Message, ModelInput, Processor, Sample
def render_chatml_messages(
processor: Processor,
messages: list[Message],
tools: str | None = None,
is_generate: bool = False,
) -> ModelInput:
"""Apply chatml template to messages and convert them to model input.
See https://huggingface.co/spaces/huggingfacejs/chat-template-playground?modelId=Qwen/Qwen2-7B-Instruct
"""
tokenizer = get_tokenizer(processor)
input_ids, labels, loss_weights = [], [], []
for message in messages:
temp_str = "<|im_start|>" + message["role"] + "\n"
for content in message["content"]:
if content["type"] == "text":
temp_str += content["value"]
else:
raise ValueError(f"Unsupported content type: {content['type']}")
temp_str += "<|im_end|>\n"
temp_weight = message.get("loss_weight", 1.0 if message["role"] == "assistant" else 0.0)
temp_ids = tokenizer.encode(temp_str, add_special_tokens=False)
input_ids.extend(temp_ids)
loss_weights.extend([temp_weight] * len(temp_ids))
if temp_weight > 1e-6:
labels.extend(temp_ids)
else:
labels.extend([IGNORE_INDEX] * len(temp_ids))
if is_generate:
temp_ids = tokenizer.encode("<|im_start|>assistant\n", add_special_tokens=False)
input_ids.extend(temp_ids)
loss_weights.extend([0.0] * len(temp_ids))
labels.extend([IGNORE_INDEX] * len(temp_ids))
return ModelInput(
input_ids=input_ids,
attention_mask=[1] * len(input_ids),
labels=labels,
loss_weights=loss_weights,
)
def parse_chatml_message(generated_text: str) -> Message:
"""Parse a message in ChatML format.
Args:
generated_text (str): The generated text in ChatML format.
Returns:
Message: The parsed message.
"""
return Message(role="assistant", content=[{"type": "text", "value": generated_text}])
class Renderer:
def __init__(self, template: str, processor: Processor):
self.template = template
self.processor = processor
def render_messages(
self, messages: list[Message], tools: str | None = None, is_generate: bool = False
) -> ModelInput:
"""Apply template to messages and convert them to model input.
Args:
messages (list[Message]): The messages to render.
tools (str | None, optional): The tools to use. Defaults to None.
is_generate (bool, optional): Whether to render for generation. Defaults to False.
Returns:
ModelInput: The rendered model input.
"""
if self.template == "chatml":
return render_chatml_messages(self.processor, messages, tools, is_generate)
else:
from ...plugins.model_plugins.rendering import RenderingPlugin
return RenderingPlugin(self.template).render_messages(self.processor, messages, tools, is_generate)
def parse_message(self, generated_text: str) -> Message:
"""Parse a message in the template format.
Args:
generated_text (str): The generated text in the template format.
Returns:
Message: The parsed message.
"""
if self.template == "chatml":
return parse_chatml_message(generated_text)
else:
from ...plugins.model_plugins.rendering import RenderingPlugin
return RenderingPlugin(self.template).parse_message(generated_text)
def process_samples(self, samples: list[Sample]) -> list[ModelInput]:
"""Process samples to model input.
Args:
samples (list[Sample]): The samples to process.
Returns:
list[ModelInput]: The processed model inputs.
"""
model_inputs = []
for sample in samples:
if "messages" in sample:
model_input = self.render_messages(sample["messages"], sample.get("tools"))
elif "chosen_messages" in sample and "rejected_messages" in sample:
chosen_input = self.render_messages(sample["chosen_messages"], sample.get("tools"))
rejected_input = self.render_messages(sample["rejected_messages"], sample.get("tools"))
chosen_input["token_type_ids"] = [1] * len(chosen_input["input_ids"])
rejected_input["token_type_ids"] = [2] * len(rejected_input["input_ids"])
model_input = ModelInput(
input_ids=chosen_input["input_ids"] + rejected_input["input_ids"],
attention_mask=chosen_input["attention_mask"] + rejected_input["attention_mask"],
labels=chosen_input["labels"] + rejected_input["labels"],
loss_weights=chosen_input["loss_weights"] + rejected_input["loss_weights"],
token_type_ids=chosen_input["token_type_ids"] + rejected_input["token_type_ids"],
)
if "position_ids" in chosen_input:
model_input["position_ids"] = np.concatenate(
[chosen_input["position_ids"], rejected_input["position_ids"]], axis=-1
)
else:
raise ValueError("No valid messages or chosen_messages/rejected_messages found in sample.")
if "extra_info" in sample:
model_input["extra_info"] = sample["extra_info"]
if "_dataset_name" in sample:
model_input["_dataset_name"] = sample["_dataset_name"]
model_inputs.append(model_input)
return model_inputs

View File

@@ -12,9 +12,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import subprocess
import sys
from ..extras.env import VERSION, print_env
from copy import deepcopy
USAGE = (
@@ -27,40 +28,152 @@ USAGE = (
+ "-" * 70
)
WELCOME = (
"-" * 58
+ "\n"
+ f"| Welcome to LLaMA Factory, version {VERSION}"
+ " " * (21 - len(VERSION))
+ "|\n|"
+ " " * 56
+ "|\n"
+ "| Project page: https://github.com/hiyouga/LLaMA-Factory |\n"
+ "-" * 58
)
_DIST_TRAIN_COMMANDS = ("train", "sft", "dpo", "rm")
def launch():
from .accelerator.helper import get_device_count
from .utils.env import find_available_port, is_env_enabled, use_kt, use_ray
from .utils.logging import get_logger
logger = get_logger(__name__)
# NOTE:
# `llamafactory-cli <command> ...` enters here first.
# We may re-launch via `torchrun` for distributed training. In that case we must
# forward `<command>` as argv[1] to the re-executed script, otherwise the script
# will misinterpret the first user argument (e.g. yaml config) as the command.
command = sys.argv.pop(1) if len(sys.argv) > 1 else "help"
if command == "sft": # train command will fallback to sft command
from .trainers.sft_trainer import run_sft
if command in _DIST_TRAIN_COMMANDS and (
is_env_enabled("FORCE_TORCHRUN") or (get_device_count() > 1 and not use_ray() and not use_kt())
):
nnodes = os.getenv("NNODES", "1")
node_rank = os.getenv("NODE_RANK", "0")
nproc_per_node = os.getenv("NPROC_PER_NODE", str(get_device_count()))
master_addr = os.getenv("MASTER_ADDR", "127.0.0.1")
master_port = os.getenv("MASTER_PORT", str(find_available_port()))
logger.info_rank0(f"Initializing {nproc_per_node} distributed tasks at: {master_addr}:{master_port}")
if int(nnodes) > 1:
logger.info_rank0(f"Multi-node training enabled: num nodes: {nnodes}, node rank: {node_rank}")
run_sft()
# elastic launch support
max_restarts = os.getenv("MAX_RESTARTS", "0")
rdzv_id = os.getenv("RDZV_ID")
min_nnodes = os.getenv("MIN_NNODES")
max_nnodes = os.getenv("MAX_NNODES")
env = deepcopy(os.environ)
if is_env_enabled("OPTIM_TORCH", "1"):
# optimize DDP, see https://zhuanlan.zhihu.com/p/671834539
env["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
env["TORCH_NCCL_AVOID_RECORD_STREAMS"] = "1"
torchrun_args = [
"torchrun",
"--nproc-per-node",
nproc_per_node,
]
if rdzv_id is not None:
# launch elastic job with fault tolerant support when possible
# see also https://docs.pytorch.org/docs/stable/elastic/train_script.html
rdzv_nnodes = nnodes
# elastic number of nodes if MIN_NNODES and MAX_NNODES are set
if min_nnodes is not None and max_nnodes is not None:
rdzv_nnodes = f"{min_nnodes}:{max_nnodes}"
torchrun_args.extend(
[
"--nnodes",
rdzv_nnodes,
"--rdzv-id",
rdzv_id,
"--rdzv-backend",
"c10d",
"--rdzv-endpoint",
f"{master_addr}:{master_port}",
"--max-restarts",
max_restarts,
]
)
else:
# NOTE: DO NOT USE shell=True to avoid security risk
torchrun_args.extend(
[
"--nnodes",
nnodes,
"--node_rank",
node_rank,
"--master_addr",
master_addr,
"--master_port",
master_port,
]
)
script_args = [__file__, command] + sys.argv[1:]
process = subprocess.run(
torchrun_args + script_args,
env=env,
check=True,
)
sys.exit(process.returncode)
elif command == "chat":
from .samplers.cli_sampler import run_chat
run_chat()
elif command == "env":
print_env()
raise NotImplementedError("Environment information is not implemented yet.")
elif command == "version":
print(WELCOME)
raise NotImplementedError("Version information is not implemented yet.")
elif command == "help":
print(USAGE)
elif command in _DIST_TRAIN_COMMANDS:
# Single GPU training without torchrun
if command in ("train", "sft"):
from llamafactory.v1.trainers.sft_trainer import run_sft
run_sft()
elif command == "dpo":
raise NotImplementedError("DPO trainer is not implemented yet.")
elif command == "rm":
raise NotImplementedError("RM trainer is not implemented yet.")
else:
print(f"Unknown command: {command}.\n{USAGE}")
def main():
# sys.argv[1] contains the command (sft/dpo/rm/train), sys.argv[2:] contains the rest args
command = sys.argv[1] if len(sys.argv) > 1 else "sft"
# Routing needs the sub-command, but downstream trainers usually expect argv without it.
if command in _DIST_TRAIN_COMMANDS:
sys.argv.pop(1)
else:
# Backward-compat: if someone runs `torchrun launcher.py config.yaml`,
# treat it as sft by default.
if len(sys.argv) > 1 and sys.argv[1].endswith((".yaml", ".yml")):
command = "sft"
if command in ("train", "sft"):
from llamafactory.v1.trainers.sft_trainer import run_sft
run_sft()
elif command == "dpo":
# from llamafactory.v1.trainers.dpo_trainer import run_dpo
# run_dpo()
raise NotImplementedError("DPO trainer is not implemented yet.")
elif command == "rm":
# from llamafactory.v1.trainers.rm_trainer import run_rm
# run_rm()
raise NotImplementedError("RM trainer is not implemented yet.")
if __name__ == "__main__":
pass
main()

View File

@@ -13,11 +13,12 @@
# limitations under the License.
import json
from typing import Any, Literal, NotRequired, TypedDict
from ...utils import logging
from ...utils.plugin import BasePlugin
from ...utils.types import DPOSample, Sample, SFTSample
from ...utils.types import DPOSample, Sample, SFTSample, ToolCall
logger = logging.get_logger(__name__)
@@ -31,7 +32,8 @@ class AlpacaSample(TypedDict, total=False):
SharegptMessage = TypedDict(
"SharegptMessage", {"from": Literal["human", "gpt", "system", "function_call", "observation"], "value": str}
"SharegptMessage",
{"from": Literal["human", "gpt", "system", "function_call", "observation"], "value": str},
)
@@ -61,7 +63,7 @@ class DataConverterPlugin(BasePlugin):
return super().__call__(raw_sample)
@DataConverterPlugin("alpaca").register
@DataConverterPlugin("alpaca").register()
def alpaca_converter(raw_sample: AlpacaSample) -> SFTSample:
"""Convert Alpaca sample to SFT sample.
@@ -98,7 +100,7 @@ def alpaca_converter(raw_sample: AlpacaSample) -> SFTSample:
return {"messages": messages}
@DataConverterPlugin("sharegpt").register
@DataConverterPlugin("sharegpt").register()
def sharegpt_converter(raw_sample: SharegptSample) -> SFTSample:
"""Convert ShareGPT sample to SFT sample.
@@ -117,18 +119,26 @@ def sharegpt_converter(raw_sample: SharegptSample) -> SFTSample:
"observation": "tool",
"function_call": "assistant",
}
sample = {}
messages = []
tools = raw_sample.get("tools", "")
for message in raw_sample.get("conversations", []):
tag = message["from"]
if tag not in tag_mapping:
logger.warning_rank0(f"Unsupported role tag {tag} in message: {message}")
elif tag == "function_call":
try:
tool_calls: ToolCall | list[ToolCall] = json.loads(message["value"])
except json.JSONDecodeError:
logger.warning_rank0(f"Invalid tool call format: {str(message['value'])}")
continue
if not isinstance(tool_calls, list):
tool_calls = [tool_calls]
messages.append(
{
"role": "assistant",
"content": [{"type": "tool_calls", "value": message["value"]}],
"content": [{"type": "tool_call", "value": json.dumps(tool_call)} for tool_call in tool_calls],
"loss_weight": 1.0,
}
)
@@ -141,16 +151,20 @@ def sharegpt_converter(raw_sample: SharegptSample) -> SFTSample:
}
)
sample["messages"] = messages
tools = raw_sample.get("tools")
if tools:
if messages and messages[0]["role"] == "system":
messages[0]["content"].append({"type": "tools", "value": tools})
else:
messages.insert(0, {"role": "system", "content": [{"type": "tools", "value": tools}], "loss_weight": 0.0})
try:
tools: list[dict[str, Any]] = json.loads(tools)
sample["tools"] = json.dumps(tools)
except json.JSONDecodeError:
logger.warning_rank0(f"Invalid tools format: {str(tools)}")
return {"messages": messages}
return sample
@DataConverterPlugin("pair").register
@DataConverterPlugin("pair").register()
def pair_converter(raw_sample: PairSample) -> DPOSample:
"""Convert Pair sample to DPO sample.
@@ -166,17 +180,44 @@ def pair_converter(raw_sample: PairSample) -> DPOSample:
def process_message(raw_messages: list[OpenaiMessage]):
messages = []
for message in raw_messages:
messages.append(
{
"role": message["role"],
"content": [{"type": "text", "value": message["content"]}],
"loss_weight": 1.0 if message["role"] == "assistant" else 0.0,
}
)
if message["role"] == "tool":
try:
tool_calls: ToolCall | list[ToolCall] = json.loads(message["content"])
except json.JSONDecodeError:
logger.warning_rank0(f"Invalid tool call format: {str(message['content'])}")
continue
if not isinstance(tool_calls, list):
tool_calls = [tool_calls]
messages.append(
{
"role": message["role"],
"content": [{"type": "tool_call", "value": json.dumps(tool_call)} for tool_call in tool_calls],
"loss_weight": 1.0 if message["role"] == "assistant" else 0.0,
}
)
else:
messages.append(
{
"role": message["role"],
"content": [{"type": "text", "value": message["content"]}],
"loss_weight": 1.0 if message["role"] == "assistant" else 0.0,
}
)
return messages
chosen_messages = process_message(raw_sample.get("chosen", []))
rejected_messages = process_message(raw_sample.get("rejected", []))
sample = {}
sample["chosen_messages"] = process_message(raw_sample.get("chosen", []))
sample["rejected_messages"] = process_message(raw_sample.get("rejected", []))
return {"chosen_messages": chosen_messages, "rejected_messages": rejected_messages}
tools = raw_sample.get("tools")
if tools:
try:
tools: list[dict[str, Any]] = json.loads(tools)
sample["tools"] = json.dumps(tools)
except json.JSONDecodeError:
logger.warning_rank0(f"Invalid tools format: {str(tools)}")
return sample

View File

@@ -49,7 +49,7 @@ def _get_builder_name(path: str) -> Literal["arrow", "csv", "json", "parquet", "
raise ValueError(f"Unknown dataset filetype: {filetype}.")
@DataLoaderPlugin("local").register
@DataLoaderPlugin("local").register()
def load_data_from_file(filepath: str, split: str, streaming: bool) -> HFDataset:
if os.path.isdir(filepath):
filetype = _get_builder_name(os.listdir(filepath)[0])
@@ -66,49 +66,43 @@ def load_data_from_file(filepath: str, split: str, streaming: bool) -> HFDataset
return dataset
class DataIndexPlugin(BasePlugin):
"""Plugin for adjusting dataset index."""
def adjust_data_index(
data_index: list[tuple[str, int]], size: int | None, weight: float | None
) -> list[tuple[str, int]]:
"""Adjust dataset index by size and weight.
def adjust_data_index(
self, data_index: list[tuple[str, int]], size: int | None, weight: float | None
) -> list[tuple[str, int]]:
"""Adjust dataset index by size and weight.
Args:
data_index (list[tuple[str, int]]): List of (dataset_name, sample_index).
size (Optional[int]): Desired dataset size.
weight (Optional[float]): Desired dataset weight.
Args:
data_index (list[tuple[str, int]]): List of (dataset_name, sample_index).
size (Optional[int]): Desired dataset size.
weight (Optional[float]): Desired dataset weight.
Returns:
list[tuple[str, int]]: Adjusted dataset index.
"""
if size is not None:
data_index = random.choices(data_index, k=size)
Returns:
list[tuple[str, int]]: Adjusted dataset index.
"""
if size is not None:
data_index = random.choices(data_index, k=size)
if weight is not None:
data_index = random.choices(data_index, k=int(len(data_index) * weight))
if weight is not None:
data_index = random.choices(data_index, k=int(len(data_index) * weight))
return data_index
return data_index
class DataSelectorPlugin(BasePlugin):
"""Plugin for selecting dataset samples."""
def select_data_sample(
data_index: list[tuple[str, int]], index: slice | list[int] | Any
) -> tuple[str, int] | list[tuple[str, int]]:
"""Select dataset samples.
def select(
self, data_index: list[tuple[str, int]], index: slice | list[int] | Any
) -> tuple[str, int] | list[tuple[str, int]]:
"""Select dataset samples.
Args:
data_index (list[tuple[str, int]]): List of (dataset_name, sample_index).
index (Union[slice, list[int], Any]): Index of dataset samples.
Args:
data_index (list[tuple[str, int]]): List of (dataset_name, sample_index).
index (Union[slice, list[int], Any]): Index of dataset samples.
Returns:
Union[tuple[str, int], list[tuple[str, int]]]: Selected dataset samples.
"""
if isinstance(index, slice):
return [data_index[i] for i in range(*index.indices(len(data_index)))]
elif isinstance(index, list):
return [data_index[i] for i in index]
else:
raise ValueError(f"Invalid index type {type(index)}.")
Returns:
Union[tuple[str, int], list[tuple[str, int]]]: Selected dataset samples.
"""
if isinstance(index, slice):
return [data_index[i] for i in range(*index.indices(len(data_index)))]
elif isinstance(index, list):
return [data_index[i] for i in index]
else:
raise ValueError(f"Invalid index type {type(index)}.")

View File

@@ -1,133 +0,0 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
@dataclass
class Template:
user_template: str
assistant_template: str
system_template: str
def render_message(self, message: dict[str, str]) -> str:
return self.user_template.format(**message)
@dataclass
class QwenTemplate:
message_template: str = "<|im_start|>{role}\n{content}<|im_end|>\n" # FIXME if role: tool
thinking_template: str = "<think>\n{content}\n</think>\n\n"
def _extract_content(self, content_data: str | list[dict[str, str]]) -> str:
if isinstance(content_data, str):
return content_data.strip()
if isinstance(content_data, list):
parts = []
for item in content_data:
if item.get("type") == "text":
parts.append(item.get("value", ""))
elif item.get("type") == "image_url":
pass
return "\n".join(parts).strip()
return ""
def render_message(self, message: dict[str, str | list[dict[str, str]]]) -> str:
role = message["role"]
content = self._extract_content(message.get("content", ""))
if role == "assistant":
reasoning_content = message.get("reasoning_content", "")
if reasoning_content:
reasoning_content = self.thinking_template.format(content=str(reasoning_content).strip())
return self.message_template.format(role="assistant", content=reasoning_content + content)
else:
return self.message_template.format(role=role, content=content)
def encode_messages(self, tokenizer, messages: list[dict[str, str]], max_seq_len: int = 8192) -> any:
"""Encode one message."""
input_ids, attention_mask, labels = [], [], []
for message in messages:
content_str = self.render_message(message)
content_ids = tokenizer.encode(content_str, add_special_tokens=False)
input_ids += content_ids
attention_mask += [1] * len(content_ids)
if hasattr(message, "loss_weight"):
loss_weight = message["loss_weight"]
else:
loss_weight = 1 if message["role"] == "assistant" else 0
if loss_weight == 1:
labels += content_ids
else:
labels += [-100] * len(content_ids)
model_inputs = {"input_ids": input_ids, "attention_mask": attention_mask, "labels": labels}
model_inputs.update({"position_ids": list(range(len(input_ids)))})
model_inputs = {k: v[-max_seq_len:] for k, v in model_inputs.items()}
return model_inputs
if __name__ == "__main__":
def to_qwen3_messages(template: QwenTemplate, messages: list[dict]):
out = []
for m in messages:
role = m["role"]
content = template._extract_content(m.get("content", ""))
if role == "assistant":
reasoning = (m.get("reasoning_content") or "").strip()
if reasoning:
content = template.thinking_template.format(content=reasoning) + content
out.append({"role": role, "content": content})
return out
from transformers import AutoTokenizer
tok = AutoTokenizer.from_pretrained(
"Qwen/Qwen3-30B-A3B-Thinking-2507",
trust_remote_code=True,
)
test_messages = [
{"role": "system", "content": "You are a helpful assistant."},
{
"role": "user",
"content": [{"type": "text", "text": "1+1等于几"}, {"type": "text", "text": "2+2等于几"}],
},
{
"role": "assistant",
"reasoning_content": "这是一个简单的数学问题。1加1的结果是2。",
"content": [{"type": "text", "text": "1+1=2"}, {"type": "text", "text": "2+2=4"}],
},
]
template = QwenTemplate()
rendered_custom = "".join([template.render_message(m) for m in test_messages])
qwen3_messages = to_qwen3_messages(template, test_messages)
rendered_hf = tok.apply_chat_template(qwen3_messages, tokenize=False, add_generation_prompt=False)
print("==== custom ====")
print(rendered_custom)
print("==== hf ====")
print(rendered_hf)
assert rendered_custom.strip() == rendered_hf.strip(), "Rendered text mismatch"
ids_custom = tok.encode(rendered_custom, add_special_tokens=False)
ids_hf = tok.apply_chat_template(qwen3_messages, tokenize=True, add_generation_prompt=False)
assert ids_custom == ids_hf, f"Token ids mismatch: custom={len(ids_custom)} hf={len(ids_hf)}"

View File

@@ -0,0 +1,43 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from ...accelerator.helper import DeviceType
from ...accelerator.interface import DistributedInterface
from ...utils.plugin import BasePlugin
class InitPlugin(BasePlugin):
def __call__(self) -> torch.device:
return super().__call__()
@InitPlugin("init_on_meta").register()
def init_on_meta() -> torch.device:
return torch.device(DeviceType.META.value)
@InitPlugin("init_on_rank0").register()
def init_on_rank0() -> torch.device:
if DistributedInterface().get_rank() == 0:
return torch.device(DeviceType.CPU.value)
else:
return torch.device(DeviceType.META.value)
@InitPlugin("init_on_default").register()
def init_on_default() -> torch.device:
return DistributedInterface().current_device

View File

@@ -38,17 +38,17 @@ class BaseKernel(ABC):
@classmethod
def get_kernel_id(cls) -> str:
r"""Returns the unique identifier for the kernel."""
"""Returns the unique identifier for the kernel."""
return cls._kernel_id
@classmethod
def get_device(cls) -> str:
r"""Returns the device type associated with the kernel (e.g., "cuda", "npu", "cpu")."""
"""Returns the device type associated with the kernel (e.g., "cuda", "npu", "cpu")."""
return cls._device
@classmethod
def check_deps(cls) -> bool:
r"""Checks if the required dependencies for the kernel are available.
"""Checks if the required dependencies for the kernel are available.
Returns:
bool: ``True`` if dependencies are met, ``False`` otherwise.
@@ -65,7 +65,7 @@ class BaseKernel(ABC):
@classmethod
@abstractmethod
def apply(cls, **kwargs) -> HFModel:
r"""Applies the kernel optimization to the model.
"""Applies the kernel optimization to the model.
Args:
**kwargs: Arbitrary keyword arguments, usually containing the model instance and the kernel configuration.

View File

@@ -24,16 +24,17 @@ Init Phase:
import importlib
from pathlib import Path
from ....utils.logging import get_logger
from ....utils import logging
from ....utils.plugin import BasePlugin
from ....utils.types import HFModel
from .registry import Registry
logger = get_logger(__name__)
logger = logging.get_logger(__name__)
def scan_all_kernels():
r"""Scan all kernels in the ``ops`` directory.
"""Scan all kernels in the ``ops`` directory.
Scans the ``ops`` directory for all ``.py`` files and attempts to import them.
Importing triggers the :func:`~registry.register_kernel` decorator, which automatically registers the kernels.
@@ -77,7 +78,7 @@ default_kernels = scan_all_kernels()
def get_default_kernels():
r"""Get a list of default registered kernel IDs.
"""Get a list of default registered kernel IDs.
Returns:
list[str]: List of kernel IDs.
@@ -86,7 +87,7 @@ def get_default_kernels():
def apply_kernel(kernel_id: str, **kwargs):
r"""Applies a specific kernel to the model.
"""Applies a specific kernel to the model.
Args:
kernel_id (str): The ID of the kernel to apply.
@@ -99,34 +100,41 @@ def apply_kernel(kernel_id: str, **kwargs):
kernel = default_kernels.get(kernel_id)
if kernel is None:
raise ValueError(f"Kernel {kernel_id} not found")
kernel.apply(**kwargs)
class KernelPlugin(BasePlugin):
r"""Plugin for managing kernel optimizations."""
"""Plugin for managing kernel optimizations."""
pass
@KernelPlugin("auto").register
def apply_default_kernels(**kwargs):
r"""Applies all default registered kernels to the model.
@KernelPlugin("auto").register()
def apply_default_kernels(model: HFModel, include_kernels: str = None) -> HFModel:
"""Applies all default registered kernels to the model.
Args:
**kwargs: Keyword arguments passed to the kernel application function.
Typically includes the model instance and the include_kernels configuration.
model (HFModel): The model instance to apply kernels to.
include_kernels (str, optional): Comma-separated list of kernel IDs to apply.
If "auto" or True, applies all default kernels.
If None or False, no kernels are applied.
Defaults to None.
Returns:
HFModel: The model with applied kernels.
"""
if not kwargs.get("include_kernels"): # None/False/empty string
return kwargs.get("model")
elif kwargs.get("include_kernels") == "auto" or kwargs.get("include_kernels") is True: # True/auto
if not include_kernels:
return model
elif include_kernels == "auto" or include_kernels is True:
use_kernels = default_kernels.keys()
else:
use_kernels = kwargs.get("include_kernels").split(",") # "kernel_id1,kernel_id2,kernel_id3"
use_kernels = include_kernels.split(",") # "kernel_id1,kernel_id2,kernel_id3"
for kernel in use_kernels:
if kernel not in default_kernels:
raise ValueError(f"Kernel {kernel} not found")
apply_kernel(kernel, **kwargs)
return kwargs.get("model")
apply_kernel(kernel, model=model)
return model

View File

@@ -40,11 +40,11 @@ from ...registry import register_kernel
class GmmFunction(torch.autograd.Function):
r"""Custom autograd function for NPU Grouped Matrix Multiplication (GMM)."""
"""Custom autograd function for NPU Grouped Matrix Multiplication (GMM)."""
@staticmethod
def forward(ctx, x, weight, group_list):
r"""Performs the forward pass of Grouped Matrix Multiplication.
"""Performs the forward pass of Grouped Matrix Multiplication.
Args:
ctx: Context object to save tensors for backward pass.
@@ -65,7 +65,7 @@ class GmmFunction(torch.autograd.Function):
@staticmethod
def backward(ctx, grad_output):
r"""Performs the backward pass of Grouped Matrix Multiplication.
"""Performs the backward pass of Grouped Matrix Multiplication.
Args:
ctx: Context object containing saved tensors.
@@ -94,11 +94,11 @@ class GmmFunction(torch.autograd.Function):
class HybridGmmFunction(torch.autograd.Function):
r"""Custom autograd function for Hybrid Grouped Matrix Multiplication on NPU."""
"""Custom autograd function for Hybrid Grouped Matrix Multiplication on NPU."""
@staticmethod
def forward(ctx, num_experts, *args):
r"""Performs the forward pass of Hybrid GMM.
"""Performs the forward pass of Hybrid GMM.
Args:
ctx: Context object to save tensors.
@@ -124,7 +124,7 @@ class HybridGmmFunction(torch.autograd.Function):
@staticmethod
def backward(ctx, *grad_outputs):
r"""Performs the backward pass of Hybrid GMM.
"""Performs the backward pass of Hybrid GMM.
Args:
ctx: Context object containing saved tensors.
@@ -176,13 +176,13 @@ class HybridGmmFunction(torch.autograd.Function):
class NpuMoeFused:
r"""Container for NPU fused MoE forward functions."""
"""Container for NPU fused MoE forward functions."""
@staticmethod
def npu_moe_experts_forward(
self, hidden_states: torch.Tensor, routing_weights: torch.Tensor, router_indices: torch.Tensor
) -> torch.Tensor:
r"""Forward pass for MoE experts using NPU fused operations.
"""Forward pass for MoE experts using NPU fused operations.
Args:
self: The MoE layer instance.
@@ -230,11 +230,11 @@ class NpuMoeFused:
class Qwen3NpuMoeFused:
r"""Container for Qwen3 NPU fused MoE forward functions."""
"""Container for Qwen3 NPU fused MoE forward functions."""
@staticmethod
def qwen3moe_sparse_moe_block_forward(self, hidden_states: torch.Tensor):
r"""Forward pass for Qwen3 sparse MoE block using NPU fused operations.
"""Forward pass for Qwen3 sparse MoE block using NPU fused operations.
Args:
self: The Qwen3 MoE block instance.
@@ -298,14 +298,14 @@ if not is_transformers_version_greater_than("5.0.0"):
@register_kernel
class NpuFusedMoEKernel(BaseKernel):
r"""NPU Fused MoE Kernel implementation."""
"""NPU Fused MoE Kernel implementation."""
_kernel_id = "npu_fused_moe"
_device = DeviceType.NPU
@classmethod
def apply(cls, **kwargs) -> HFModel:
r"""Applies the NPU fused MoE kernel to the model.
"""Applies the NPU fused MoE kernel to the model.
Args:
**kwargs: Keyword arguments containing the model.
@@ -324,7 +324,7 @@ class NpuFusedMoEKernel(BaseKernel):
if not cls.check_deps():
raise RuntimeError("torch_npu is not available but NpuMoEFusedMoEKernel was called.")
archs = getattr(model.config, "architectures", [])
archs = getattr(model.config, "architectures", None) or []
target_moe_mapping = None
for arch in archs:
if arch in kernel_moe_mapping:
@@ -333,6 +333,7 @@ class NpuFusedMoEKernel(BaseKernel):
if target_moe_mapping is None:
return model
for module in model.modules():
class_name = module.__class__.__name__
if class_name in target_moe_mapping:

View File

@@ -38,7 +38,7 @@ except ImportError:
def npu_swiglu_forward(self, hidden_state):
r"""SwiGLU forward pass for NPU.
"""SwiGLU forward pass for NPU.
Args:
self: The MLP layer instance.
@@ -53,7 +53,7 @@ def npu_swiglu_forward(self, hidden_state):
def _npu_swiglu_glm4_forward(self, hidden_states):
r"""SwiGLU forward pass for GLM4 on NPU.
"""SwiGLU forward pass for GLM4 on NPU.
Args:
self: The GLM4 MLP layer instance.
@@ -68,7 +68,7 @@ def _npu_swiglu_glm4_forward(self, hidden_states):
def _npu_swiglu_gemma3ntext_forward(self, hidden_states):
r"""SwiGLU forward pass for Gemma3nText on NPU.
"""SwiGLU forward pass for Gemma3nText on NPU.
Args:
self: The Gemma3nText MLP layer instance.
@@ -88,7 +88,7 @@ def _npu_swiglu_gemma3ntext_forward(self, hidden_states):
@register_kernel
class NpuSwiGluKernel(BaseKernel):
r"""NPU Kernel for fused SwiGLU activation."""
"""NPU Kernel for fused SwiGLU activation."""
# just support apply to the following module layers
expect_modules = frozenset(
@@ -126,7 +126,7 @@ class NpuSwiGluKernel(BaseKernel):
@classmethod
def apply(cls, **kwargs) -> "HFModel":
r"""Applies the NPU fused SwiGLU kernel to the model.
"""Applies the NPU fused SwiGLU kernel to the model.
Args:
**kwargs: Keyword arguments containing the model.

Some files were not shown because too many files have changed in this diff Show More