37 Commits

Author SHA1 Message Date
Kingsley
833f6027b1 [fix] fit neat_packing & mrope model packing (#10283)
Co-authored-by: Yaowei Zheng <hiyouga@buaa.edu.cn>
2026-03-20 16:50:11 +08:00
robertglools
d91d8af89e [data] add SGSC zero-hallucination B2B dataset (NOO-Protocol) (#10284)
Co-authored-by: GloolsGuan <GloolsGuan@gmail.com>
2026-03-20 15:49:03 +08:00
xxddccaa
e67ab9e2f2 fix:MiniCPMVPlugin IndexError in process_messages when training with video (#10276)
Co-authored-by: xxddccaa <xxddccaa@users.noreply.github.com>
2026-03-18 19:18:06 +08:00
LincolnBurrows2017
2c4f121817 [fix] handle empty content list in system message (#10291)
Co-authored-by: AI Assistant <assistant@example.com>
2026-03-18 12:05:49 +08:00
xvxuopop
487f8b8191 [v1] add qwen3 templates and fix rendering plugin. (#10212)
Co-authored-by: Yaowei Zheng <hiyouga@buaa.edu.cn>
2026-03-18 11:30:50 +08:00
SnowCharm
78cad1e332 [fix] unused keys in ray example (#10290) 2026-03-18 00:23:53 +08:00
LincolnBurrows2017
70653026f5 [fix] make position_id_per_seconds configurable for Qwen2OmniPlugin (#10281)
Co-authored-by: LincolnBurrows2017 <lincoln@example.com>
2026-03-16 19:42:38 +08:00
Ruijie Hou
246192abd2 [data] correct gpt_oss template format_assistant (#10269) 2026-03-10 21:36:38 +08:00
浮梦
0258dc14d0 [docker] update npu docker (#10268)
Co-authored-by: frozenleaves <frozen@Mac.local>
2026-03-10 19:37:27 +08:00
xxddccaa
3045adf0ba [fix] fallback to audio_processor when feature_extractor is missing (#10267)
Co-authored-by: kevin <742971636@qq.com>
2026-03-10 19:36:41 +08:00
Kingsley
a3d44e3152 [mca] support qwen3.5 (#10265)
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
2026-03-10 10:55:16 +08:00
JiangNan
edeb953bc7 [data] convert filter() to list in read_cloud_json to fix broken empty-check (#10260)
Signed-off-by: JiangNan <1394485448@qq.com>
2026-03-09 17:12:53 +08:00
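Commit edeb953bc7 above fixes a classic Python pitfall: `filter()` returns a lazy iterator object, which is always truthy even when it yields nothing, so an emptiness check like `if not data:` silently passes. A minimal sketch of the bug and the fix (record names are illustrative, not the project's data):

```python
# filter() returns a lazy iterator -- it is truthy even when it yields zero
# items, so an "is this empty?" check on the raw filter object never fires.
records = [{"text": ""}, {"text": ""}]

filtered = filter(lambda r: r["text"], records)
assert bool(filtered) is True  # truthy despite yielding nothing

# Fix (as in the commit): materialize into a list before the empty-check.
filtered_list = list(filter(lambda r: r["text"], records))
assert filtered_list == []  # now the emptiness check behaves as expected
```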
yizhouChen
d045794387 [docs] fix Python version requirement from 3.10 to >=3.11.0 (#10259)
Co-authored-by: chaiyzh <chaiyzh@126.com>
2026-03-09 16:44:07 +08:00
pyx
9501c3308a [train] fix compatibility issue with HuggingFace Dataset Column when sav… (#10254)
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
2026-03-06 18:44:57 +08:00
jiaqiw09
0ee1c42c2b [v1] Support meta loading for full and free (#10236) 2026-03-05 23:15:27 +08:00
SnowCharm
3061f48d55 [ray] fix get ray head ip (#10252) 2026-03-05 23:14:38 +08:00
LittleYanlin
2d9bd2aa14 [fix] qwen3.5 projector path (#10242)
Co-authored-by: Yaowei Zheng <hiyouga@buaa.edu.cn>
2026-03-04 01:31:09 +08:00
Hertz
c0245c43fc [model] support Qwen3.5 all series models (#10237)
Co-authored-by: gatilin <gatilin@tencent.com>
Co-authored-by: Yaowei Zheng <hiyouga@buaa.edu.cn>
2026-03-03 17:34:59 +08:00
Parag Ekbote
eb976d75a2 [tracker] Add Trackio Integration for LlamaFactory (#10165)
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
2026-03-03 17:19:37 +08:00
Yaowei Zheng
b5cb7cb0e6 [misc] fix constants (#10232) 2026-03-02 11:10:48 +08:00
Philip Ottesen
0779846513 [infer] support mixed multimodal payloads (#10225)
Signed-off-by: Philip Ottesen <phiott256@gmail.com>
2026-02-28 20:26:53 +08:00
jiaqiw09
45d335c709 [v1] add seed for training and fix gradient checkpointing (#10211) 2026-02-28 18:16:06 +08:00
Kingsley
816480012f [fix] register visual part for Qwen3.5 (#10227) 2026-02-28 16:39:24 +08:00
Mikko Tukiainen
d3bf882e87 [docker] upgrade to ROCm 7.2 base image, drop PyTorch reinstall (#10223)
Co-authored-by: Mikko Tukiainen <mtukiain@chi-mi300x-012.ord.vultr.cpe.ice.amd.com>
2026-02-27 20:16:33 +08:00
娄宗志
589da21d32 [model] support Aeva (#10214) 2026-02-26 23:03:13 +08:00
Yaowei Zheng
122cd46084 [model] update constants (#10220) 2026-02-26 21:13:56 +08:00
浮梦
2b8b871475 [model] Adapt Qwen3.5 (#10213)
Co-authored-by: frozenleaves <frozen@Mac.local>
Co-authored-by: Yaowei Zheng <hiyouga@buaa.edu.cn>
2026-02-26 20:45:02 +08:00
Shanay Mehta
aab9b400bb [model] Add DeepSpeed Z3 leaf module for Qwen3-Next (#10194)
Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-24 19:54:37 +08:00
P. Clawmogorov
50599c719b [misc] remove safe_serialization arg for transformers v5 compatibility (#10208)
Co-authored-by: P. Clawmogorov <262173731+Alm0stSurely@users.noreply.github.com>
2026-02-24 11:14:19 +08:00
Kingsley
a0f3ad0cee [mca] update supported models (#10196) 2026-02-20 22:02:49 +08:00
jiaqiw09
f80e15dbb4 [ci] fix ut huggingface hub 429 error when transformers>=5.0.0 (#10155) 2026-02-12 22:14:10 +08:00
sunyi0505
991267fd3b [v1] support quantization (#10161) 2026-02-12 20:37:41 +08:00
浮梦
5c52afa30d [v1] support deepspeed (#10181) 2026-02-12 17:24:30 +08:00
Junyou Su
675ce8cc7f [algo] add ASFT (#10174) 2026-02-12 13:12:14 +08:00
jiaqiw09
ab073f4c13 [v1] add LoRA/Freeze support and merge workflow (#10157) 2026-02-12 13:02:09 +08:00
Shanay Mehta
184304b5b4 [model] add liger kernel support for Qwen3-Next (#10176)
Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-10 21:47:48 +08:00
Xue Yadong
d3ebd5678d [model] support GLM-OCR SFT (#10183) 2026-02-10 21:41:01 +08:00
78 changed files with 3169 additions and 506 deletions

View File

@@ -25,16 +25,16 @@ jobs:
     runs-on: ubuntu-latest
     steps:
       - uses: actions/checkout@v4
       - name: Set up Python
         uses: actions/setup-python@v5
         with:
           python-version: '3.10'
       - name: Install dependencies
         run: |
           pip install -r docs/requirements.txt
       - name: Build Sphinx
         run: |
           sphinx-build -b html docs/zh docs/_build/html/zh
@@ -56,10 +56,10 @@ jobs:
           > docs/_build/html/index.html
           touch docs/_build/html/.nojekyll
       - name: Setup Pages
         uses: actions/configure-pages@v5
       - name: Upload artifact
         uses: actions/upload-pages-artifact@v3
         with:

View File

@@ -35,15 +35,12 @@ jobs:
       transformers:
         - ""
       include: # test backward compatibility
-        - python: "3.11"
-          os: "ubuntu-latest"
-          transformers: "4.51.0"
-        - python: "3.11"
-          os: "ubuntu-latest"
-          transformers: "4.53.0"
         - python: "3.11"
           os: "ubuntu-latest"
           transformers: "4.55.0"
+        - python: "3.11"
+          os: "ubuntu-latest"
+          transformers: "4.57.1"
     runs-on: ${{ matrix.os }}

View File

@@ -61,6 +61,7 @@ jobs:
           uv venv
           uv pip install -e .
           uv pip install -r requirements/dev.txt
+          uv pip install -r requirements/bitsandbytes.txt
       - name: Check quality
         run: |

View File

@@ -291,7 +291,7 @@ Read technical notes:
 | [GPT-2](https://huggingface.co/openai-community) | 0.1B/0.4B/0.8B/1.5B | - |
 | [GPT-OSS](https://huggingface.co/openai) | 20B/120B | gpt_oss |
 | [Granite 3-4](https://huggingface.co/ibm-granite) | 1B/2B/3B/7B/8B | granite3/granite4 |
 | [Hunyuan/Hunyuan1.5 (MT)](https://huggingface.co/tencent/) | 0.5B/1.8B/4B/7B/13B | hunyuan/hunyuan_small |
 | [InternLM 2-3](https://huggingface.co/internlm) | 7B/8B/20B | intern2 |
 | [InternVL 2.5-3.5](https://huggingface.co/OpenGVLab) | 1B/2B/4B/8B/14B/30B/38B/78B/241B | intern_vl |
 | [Intern-S1-mini](https://huggingface.co/internlm/) | 8B | intern_s1 |
@@ -319,6 +319,7 @@ Read technical notes:
 | [Pixtral](https://huggingface.co/mistralai) | 12B | pixtral |
 | [Qwen2 (Code/Math/MoE/QwQ)](https://huggingface.co/Qwen) | 0.5B/1.5B/3B/7B/14B/32B/72B/110B | qwen |
 | [Qwen3 (MoE/Instruct/Thinking/Next)](https://huggingface.co/Qwen) | 0.6B/1.7B/4B/8B/14B/32B/80B/235B | qwen3/qwen3_nothink |
+| [Qwen3.5](https://huggingface.co/Qwen) | 0.8B/2B/4B/9B/27B/35B/122B/397B | qwen3_5 |
 | [Qwen2-Audio](https://huggingface.co/Qwen) | 7B | qwen2_audio |
 | [Qwen2.5-Omni](https://huggingface.co/Qwen) | 3B/7B | qwen2_omni |
 | [Qwen3-Omni](https://huggingface.co/Qwen) | 30B | qwen3_omni |
@@ -472,7 +473,7 @@ huggingface-cli login
 | Mandatory | Minimum | Recommend |
 | ------------ | ------- | --------- |
-| python       | 3.9     | 3.10      |
+| python       | 3.11    | >=3.11    |
 | torch        | 2.0.0   | 2.6.0     |
 | torchvision  | 0.15.0  | 0.21.0    |
 | transformers | 4.49.0  | 4.50.0    |

View File

@@ -293,7 +293,7 @@ https://github.com/user-attachments/assets/43b700c6-a178-41db-b1f8-8190a5d3fcfc
 | [GPT-2](https://huggingface.co/openai-community) | 0.1B/0.4B/0.8B/1.5B | - |
 | [GPT-OSS](https://huggingface.co/openai) | 20B/120B | gpt_oss |
 | [Granite 3-4](https://huggingface.co/ibm-granite) | 1B/2B/3B/7B/8B | granite3/granite4 |
 | [Hunyuan/Hunyuan1.5 (MT)](https://huggingface.co/tencent/) | 0.5B/1.8B/4B/7B/13B | hunyuan/hunyuan_small |
 | [InternLM 2-3](https://huggingface.co/internlm) | 7B/8B/20B | intern2 |
 | [InternVL 2.5-3.5](https://huggingface.co/OpenGVLab) | 1B/2B/4B/8B/14B/30B/38B/78B/241B | intern_vl |
 | [Intern-S1-mini](https://huggingface.co/internlm/) | 8B | intern_s1 |
@@ -321,6 +321,7 @@ https://github.com/user-attachments/assets/43b700c6-a178-41db-b1f8-8190a5d3fcfc
 | [Pixtral](https://huggingface.co/mistralai) | 12B | pixtral |
 | [Qwen2 (Code/Math/MoE/QwQ)](https://huggingface.co/Qwen) | 0.5B/1.5B/3B/7B/14B/32B/72B/110B | qwen |
 | [Qwen3 (MoE/Instruct/Thinking/Next)](https://huggingface.co/Qwen) | 0.6B/1.7B/4B/8B/14B/32B/80B/235B | qwen3/qwen3_nothink |
+| [Qwen3.5](https://huggingface.co/Qwen) | 0.8B/2B/4B/9B/27B/35B/122B/397B | qwen3_5 |
 | [Qwen2-Audio](https://huggingface.co/Qwen) | 7B | qwen2_audio |
 | [Qwen2.5-Omni](https://huggingface.co/Qwen) | 3B/7B | qwen2_omni |
 | [Qwen3-Omni](https://huggingface.co/Qwen) | 30B | qwen3_omni |
@@ -474,7 +475,7 @@ huggingface-cli login
 | 必需项       | 至少    | 推荐      |
 | ------------ | ------- | --------- |
-| python       | 3.9     | 3.10      |
+| python       | 3.11    | >=3.11    |
 | torch        | 2.0.0   | 2.6.0     |
 | torchvision  | 0.15.0  | 0.21.0    |
 | transformers | 4.49.0  | 4.50.0    |

View File

@@ -236,6 +236,13 @@
     "ms_hub_url": "AI-ModelScope/sharegpt_gpt4",
     "formatting": "sharegpt"
   },
+  "sgsc_b2b_entities": {
+    "hf_hub_url": "Nooxus-AI/NOO-Verified-Global-Entities",
+    "formatting": "sharegpt",
+    "columns": {
+      "messages": "messages"
+    }
+  },
   "ultrachat_200k": {
     "hf_hub_url": "HuggingFaceH4/ultrachat_200k",
     "ms_hub_url": "AI-ModelScope/ultrachat_200k",

View File

@@ -1,6 +1,6 @@
 # https://hub.docker.com/r/ascendai/cann/tags
-ARG BASE_IMAGE=quay.io/ascend/cann:8.3.rc2-910b-ubuntu22.04-py3.11
+ARG BASE_IMAGE=quay.io/ascend/cann:8.5.1-910b-ubuntu22.04-py3.11
 FROM ${BASE_IMAGE}
 
 # Installation arguments
@@ -33,9 +33,11 @@ RUN pip config set global.index-url "${PIP_INDEX}" && \
 COPY . /app
 
 # Install torch-npu
-RUN pip uninstall -y torch torchvision torchaudio && \
-    pip install --no-cache-dir "torch==2.7.1" "torch-npu==2.7.1" "torchvision==0.22.1" "torchaudio==2.7.1" --index-url "${PYTORCH_INDEX}" && \
-    pip install --no-cache-dir -e . --no-build-isolation && \
+RUN source /usr/local/Ascend/ascend-toolkit/set_env.sh
+RUN pip uninstall -y torch torchvision torchaudio
+RUN pip install --no-cache-dir -r requirements/npu.txt --index-url "${PYTORCH_INDEX}"
+RUN pip install --no-cache-dir -r requirements/deepspeed.txt
+RUN pip install --no-cache-dir -e . --no-build-isolation && \
     pip install --no-cache-dir -r requirements/metrics.txt --no-build-isolation
 
 # Set up volumes

View File

@@ -33,7 +33,7 @@ services:
       dockerfile: ./docker/docker-npu/Dockerfile
       context: ../..
       args:
-        BASE_IMAGE: quay.io/ascend/cann:8.3.rc2-a3-ubuntu22.04-py3.11
+        BASE_IMAGE: quay.io/ascend/cann:8.5.1-a3-ubuntu22.04-py3.11
         PIP_INDEX: https://pypi.org/simple
     container_name: llamafactory-a3
     image: llamafactory:npu-a3

View File

@@ -1,12 +1,12 @@
 # https://hub.docker.com/r/rocm/pytorch/tags
-ARG BASE_IMAGE=rocm/pytorch:rocm6.4.1_ubuntu22.04_py3.10_pytorch_release_2.6.0
+# ROCm 7.2 + PyTorch 2.7.1 (Python 3.12). Keep base image's PyTorch; do not reinstall.
+ARG BASE_IMAGE=rocm/pytorch:rocm7.2_ubuntu24.04_py3.12_pytorch_release_2.7.1
 FROM ${BASE_IMAGE}
 
 # Installation arguments
 ARG PIP_INDEX=https://pypi.org/simple
 ARG INSTALL_FLASHATTN=false
 ARG HTTP_PROXY=""
-ARG PYTORCH_INDEX=https://download.pytorch.org/whl/rocm6.3
 
 # Define environments
 ENV MAX_JOBS=16
@@ -32,10 +32,9 @@ RUN pip config set global.index-url "${PIP_INDEX}" && \
 # Copy the application into the image
 COPY . /app
 
-# Reinstall pytorch rocm and install LLaMA Factory
-RUN pip uninstall -y torch torchvision torchaudio && \
-    pip install --no-cache-dir --no-build-isolation -e --pre . --index-url "${PYTORCH_INDEX}" && \
-    pip install --no-cache-dir --no-build-isolation -r requirements/metrics.txt -r requirements/deepspeed.txt --index-url "${PYTORCH_INDEX}"
+# Install LLaMA Factory (use base image's PyTorch/ROCm; do not reinstall)
+RUN pip install --no-cache-dir -e . --pre && \
+    pip install --no-cache-dir -r requirements/deepspeed.txt -r requirements/liger-kernel.txt -r requirements/bitsandbytes.txt
 
 # Rebuild flash attention
 RUN if [ "${INSTALL_FLASHATTN}" == "true" ]; then \

View File

@@ -47,4 +47,3 @@
   border-color: rgba(255, 255, 255, 0.45);
   box-shadow: 0 0 0 3px rgba(255, 255, 255, 0.12);
 }
-

View File

@@ -1,33 +1,31 @@
 # Configuration file for the Sphinx documentation builder.
-import os
-import sys
 
 # Define common settings here
-project = 'LlamaFactory'
-copyright = '2024, LlamaFactory Team'
-author = 'LlamaFactory Team'
+project = "LlamaFactory"
+copyright = "2024, LlamaFactory Team"
+author = "LlamaFactory Team"
 
 extensions = [
-    'sphinx.ext.autodoc',
-    'sphinx.ext.viewcode',
-    'sphinx.ext.napoleon',
-    'myst_parser',
+    "sphinx.ext.autodoc",
+    "sphinx.ext.viewcode",
+    "sphinx.ext.napoleon",
+    "myst_parser",
 ]
 
-templates_path = ['_templates']
-exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store']
+templates_path = ["_templates"]
+exclude_patterns = ["_build", "Thumbs.db", ".DS_Store"]
 
-html_theme = 'sphinx_rtd_theme'
-html_static_path = ['_static']
+html_theme = "sphinx_rtd_theme"
+html_static_path = ["_static"]
 html_js_files = [
-    'js/switcher.js',
+    "js/switcher.js",
 ]
 html_css_files = [
-    'css/lang-switcher.css',
+    "css/lang-switcher.css",
 ]
 
 myst_enable_extensions = [

View File

@@ -1,20 +1,22 @@
 import os
 import sys
-sys.path.insert(0, os.path.abspath('..'))
-from conf import *
+
+# Add parent dir to path to allow importing conf.py
+sys.path.insert(0, os.path.abspath(".."))
+from conf import *  # noqa: F403
 
 # Language settings
-language = 'en'
-html_search_language = 'en'
+language = "en"
+html_search_language = "en"
 
 # Static files
 # Point to the root _static directory
-html_static_path = ['../_static']
+html_static_path = ["../_static"]
 
 # Add custom JS for language switcher
 html_js_files = [
-    'js/switcher.js',
+    "js/switcher.js",
 ]

View File

@@ -1,20 +1,22 @@
 import os
 import sys
-sys.path.insert(0, os.path.abspath('..'))
-from conf import *
+
+# Add parent dir to path to allow importing conf.py
+sys.path.insert(0, os.path.abspath(".."))
+from conf import *  # noqa: F403
 
 # Language settings
-language = 'zh_CN'
-html_search_language = 'zh'
+language = "zh_CN"
+html_search_language = "zh"
 
 # Static files
 # Point to the root _static directory
-html_static_path = ['../_static']
+html_static_path = ["../_static"]
 
 # Add custom JS for language switcher
 html_js_files = [
-    'js/switcher.js',
+    "js/switcher.js",
 ]

View File

@@ -0,0 +1,45 @@
### model
model_name_or_path: models/Llama-2-7b
trust_remote_code: true
### method
stage: sft
do_train: true
finetuning_type: full
deepspeed: examples/deepspeed/ds_z0_config.json
use_asft_loss: true
asft_alpha: 0.1
### dataset
dataset: med
template: llama2
cutoff_len: 2048
max_samples: 10000
overwrite_cache: true
preprocessing_num_workers: 16
dataloader_num_workers: 4
### output
output_dir: saves/llama2-7b/full/asft2
logging_steps: 1
save_steps: 500
plot_loss: true
overwrite_output_dir: true
save_only_model: false
report_to: none # choices: [none, wandb, tensorboard, swanlab, mlflow]
### train
per_device_train_batch_size: 4
gradient_accumulation_steps: 8
learning_rate: 2.0e-5
num_train_epochs: 3.0
lr_scheduler_type: cosine
warmup_ratio: 0.1
bf16: true
ddp_timeout: 180000000
### eval
# val_size: 0.1
# per_device_eval_batch_size: 1
# eval_strategy: steps
# eval_steps: 500

View File

@@ -0,0 +1,45 @@
### model
model_name_or_path: models/Qwen2.5-7B
trust_remote_code: true
### method
stage: sft
do_train: true
finetuning_type: full
deepspeed: examples/deepspeed/ds_z0_config.json
use_asft_loss: true
asft_alpha: 0.05
### dataset
dataset: math
template: qwen
cutoff_len: 2048
max_samples: 10000
overwrite_cache: true
preprocessing_num_workers: 16
dataloader_num_workers: 4
### output
output_dir: saves/qwen2-7b/full/asft
logging_steps: 10
save_steps: 500
plot_loss: true
overwrite_output_dir: true
save_only_model: false
report_to: none # choices: [none, wandb, tensorboard, swanlab, mlflow]
### train
per_device_train_batch_size: 4
gradient_accumulation_steps: 8
learning_rate: 5.0e-5
num_train_epochs: 1.0
lr_scheduler_type: cosine
warmup_ratio: 0.1
bf16: true
ddp_timeout: 180000000
### eval
# val_size: 0.1
# per_device_eval_batch_size: 1
# eval_strategy: steps
# eval_steps: 500

View File

@@ -28,12 +28,7 @@ save_only_model: false
 report_to: none  # choices: [none, wandb, tensorboard, swanlab, mlflow]
 
 ### ray
-ray_run_name: qwen3_4b_sft_lora
-ray_storage_path: ./saves
 ray_num_workers: 4  # Number of GPUs to use.
-placement_strategy: PACK
-resources_per_worker:
-  GPU: 1
 # ray_init_kwargs:
 #   runtime_env:
 #     env_vars:

View File

@@ -0,0 +1,38 @@
model: Qwen/Qwen3-4B
trust_remote_code: true
model_class: llm
template: qwen3_nothink
# Freeze Configuration
peft_config:
  name: freeze
  freeze_trainable_layers: 2  # Train the last 2 layers
  freeze_trainable_modules: all  # In these layers, train specific modules
  freeze_extra_modules: null  # Extra modules to train (e.g. embed_tokens, lm_head)

# Kernel Config
kernel_config:
  name: auto
  include_kernels: auto

# FSDP Config
dist_config:
  name: fsdp2
  dcp_path: null
### data
train_dataset: data/v1_sft_demo.yaml
### training
output_dir: ./outputs/test_freeze
micro_batch_size: 1
global_batch_size: 4
cutoff_len: 2048
learning_rate: 2.0e-5
bf16: false
max_steps: 10
### sample
sample_backend: hf
max_new_tokens: 128

View File

@@ -0,0 +1,24 @@
model: Qwen/Qwen3-0.6B
model_class: llm
template: qwen3_nothink
kernel_config:
  name: auto
  include_kernels: auto
dist_config:
  name: deepspeed
  config_file: examples/deepspeed/ds_z3_config.json
### data
train_dataset: data/v1_sft_demo.yaml
### training
output_dir: outputs/Qwen3-0.6B-deepspeed
micro_batch_size: 1
cutoff_len: 2048
learning_rate: 1.0e-4
bf16: true
max_steps: 10

View File

@@ -14,16 +14,12 @@ dist_config:
   name: fsdp2
   dcp_path: null  # /mnt/f/pretrain_models/Qwen3-0.6B-dcp
 
-init_config:
-  name: init_on_meta
-
 ### data
 train_dataset: data/v1_sft_demo.yaml
 
 ### training
 output_dir: outputs/test_fsdp2
 micro_batch_size: 1
-global_batch_size: 1
 cutoff_len: 2048
 learning_rate: 1.0e-4
 bf16: false

View File

@@ -0,0 +1,7 @@
model: Qwen/Qwen3-4B
peft_config:
  name: lora
  adapter_name_or_path: ./outputs/test_lora
export_dir: ./merge_lora_model
export_size: 5
infer_dtype: auto

View File

@@ -0,0 +1,39 @@
model: Qwen/Qwen3-4B
trust_remote_code: true
model_class: llm
template: qwen3_nothink
# PEFT Configuration
peft_config:
  name: lora
  r: 16
  lora_alpha: 32
  lora_dropout: 0.05
  target_modules: all

# Kernel Config
kernel_config:
  name: auto
  include_kernels: auto

# FSDP Config
dist_config:
  name: fsdp2
  dcp_path: null
### data
train_dataset: data/v1_sft_demo.yaml
### training
output_dir: ./outputs/test_lora
micro_batch_size: 1
global_batch_size: 4
cutoff_len: 2048
learning_rate: 1.0e-4
bf16: true
max_steps: 10
### sample
sample_backend: hf
max_new_tokens: 128

View File

@@ -0,0 +1,43 @@
model: Qwen/Qwen3-0.6B
trust_remote_code: true
model_class: llm
template: qwen3_nothink
# PEFT Configuration
peft_config:
  name: lora
  r: 16
  lora_alpha: 32
  lora_dropout: 0.05
  target_modules: all

# Kernel Config
kernel_config:
  name: auto
  include_kernels: auto

# FSDP Config
dist_config:
  name: fsdp2
  dcp_path: null

# Quantization Config
quant_config:
  name: bnb  # choices: auto/bnb. If auto is selected, the quantization method is chosen automatically based on the model and environment.
  quantization_bit: 4  # choices: 8/4 (bnb)
### data
train_dataset: data/v1_sft_demo.yaml
### training
output_dir: outputs/test_quantization
micro_batch_size: 1
cutoff_len: 2048
learning_rate: 1.0e-4
bf16: false
max_steps: 10
### sample
sample_backend: hf
max_new_tokens: 128

View File

@@ -40,7 +40,7 @@ dependencies = [
     "torch>=2.4.0",
     "torchvision>=0.19.0",
     "torchaudio>=2.4.0",
-    "transformers>=4.51.0,<=5.0.0,!=4.52.0,!=4.57.0",
+    "transformers>=4.55.0,<=5.2.0,!=4.52.0,!=4.57.0",
     "datasets>=2.16.0,<=4.0.0",
     "accelerate>=1.3.0,<=1.11.0",
     "peft>=0.18.0,<=0.18.1",
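The tightened `transformers` constraint above can be checked locally with the `packaging` library (a sketch; `packaging` is third-party, typically already installed alongside pip):

```python
from packaging.specifiers import SpecifierSet
from packaging.version import Version

# The new constraint from pyproject.toml.
spec = SpecifierSet(">=4.55.0,<=5.2.0,!=4.52.0,!=4.57.0")

assert Version("4.55.0") in spec   # new floor is included
assert Version("4.51.0") not in spec  # old floor no longer accepted
assert Version("4.57.0") not in spec  # explicitly excluded release
assert Version("5.2.0") in spec    # new ceiling is included
```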

View File

@@ -1 +1 @@
-liger-kernel>=0.5.5
+liger-kernel>=0.6.3

View File

@@ -1,4 +1,4 @@
 torch==2.7.1
-torch-npu==2.7.1
+torch-npu==2.7.1.post2
 torchvision==0.22.1
 torchaudio==2.7.1

View File

@@ -71,6 +71,7 @@ def convert(
     pipeline_model_parallel_size: int = 1,
     expert_model_parallel_size: int = 1,
     virtual_pipeline_model_parallel_size: int | None = None,
+    moe_grouped_gemm: bool | None = None,
 ):
     """Convert checkpoint between MCA and HuggingFace formats.
@@ -84,6 +85,10 @@ def convert(
         pipeline_model_parallel_size: Pipeline model parallel size
         expert_model_parallel_size: Expert model parallel size
         virtual_pipeline_model_parallel_size: Virtual pipeline model parallel size
+        moe_grouped_gemm: Use grouped gemm for MoE experts. When enabled, expert
+            weights are stored in a flattened format (linear_fc1.weight0, weight1, ...)
+            rather than per-expert format (local_experts.0.linear_fc1.weight, ...).
+            Must match the format used when saving the checkpoint.
     """
     if bf16 and fp16:
         raise ValueError("bf16 and fp16 cannot be both True.")
@@ -97,8 +102,9 @@ def convert(
         pipeline_model_parallel_size=pipeline_model_parallel_size,
         expert_model_parallel_size=expert_model_parallel_size,
         virtual_pipeline_model_parallel_size=virtual_pipeline_model_parallel_size,
+        moe_grouped_gemm=moe_grouped_gemm,
+        transformer_impl="transformer_engine",  # hard code here since we default using te for training
     )
     convert_checkpoint_to_mca(
         checkpoint_path,
         output_path,
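The docstring added above describes two on-disk layouts for MoE expert weights. A hypothetical sketch of the renaming between them, using the key patterns quoted in the docstring (the helper and the key prefixes are illustrative, not the project's API):

```python
import re

def per_expert_to_grouped(keys: list[str]) -> list[str]:
    # "...local_experts.<i>.linear_fc1.weight" -> "...linear_fc1.weight<i>"
    # (flattened grouped-gemm layout, as described in the docstring)
    out = []
    for key in keys:
        m = re.match(r"(.*)\.local_experts\.(\d+)\.(linear_fc[12])\.weight$", key)
        out.append(f"{m.group(1)}.{m.group(3)}.weight{m.group(2)}" if m else key)
    return out

keys = [
    "decoder.layers.0.mlp.experts.local_experts.0.linear_fc1.weight",
    "decoder.layers.0.mlp.experts.local_experts.1.linear_fc1.weight",
]
assert per_expert_to_grouped(keys) == [
    "decoder.layers.0.mlp.experts.linear_fc1.weight0",
    "decoder.layers.0.mlp.experts.linear_fc1.weight1",
]
```

This is why `moe_grouped_gemm` must match the flag used when the checkpoint was saved: the two layouts use different state-dict key names for the same tensors.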

View File

@@ -154,25 +154,24 @@ def vllm_infer(
     batch = train_dataset[i : min(i + batch_size, len(train_dataset))]
     for j in range(len(batch["input_ids"])):
+        multi_modal_data = {}
+        video_metadata_kwargs = None
         if batch["images"][j] is not None:
             image = batch["images"][j]
-            multi_modal_data = {
-                "image": template_obj.mm_plugin._regularize_images(
-                    image, image_max_pixels=image_max_pixels, image_min_pixels=image_min_pixels
-                )["images"]
-            }
-        elif batch["videos"][j] is not None:
-            video_metadata, video_metadata_kwargs = None, None
+            multi_modal_data["image"] = template_obj.mm_plugin._regularize_images(
+                image, image_max_pixels=image_max_pixels, image_min_pixels=image_min_pixels
+            )["images"]
+
+        if batch["videos"][j] is not None:
             video = batch["videos"][j]
-            multi_modal_data = {
-                "video": template_obj.mm_plugin._regularize_videos(
-                    video,
-                    image_max_pixels=image_max_pixels,
-                    image_min_pixels=image_min_pixels,
-                    video_fps=video_fps,
-                    video_maxlen=video_maxlen,
-                )["videos"]
-            }
+            multi_modal_data["video"] = template_obj.mm_plugin._regularize_videos(
+                video,
+                image_max_pixels=image_max_pixels,
+                image_min_pixels=image_min_pixels,
+                video_fps=video_fps,
+                video_maxlen=video_maxlen,
+            )["videos"]
             if need_video_kwargs:
                 container = av.open(video[0], "r")
                 video_stream = next(stream for stream in container.streams if stream.type == "video")
@@ -192,18 +191,17 @@ def vllm_infer(
                     video_backend="opencv",
                 )
                 multi_modal_data["video"] = (multi_modal_data["video"], video_metadata)
-        elif batch["audios"][j] is not None:
+
+        if batch["audios"][j] is not None:
             audio = batch["audios"][j]
             audio_data = template_obj.mm_plugin._regularize_audios(
                 audio,
                 sampling_rate=16000,
             )
-            multi_modal_data = {"audio": zip(audio_data["audios"], audio_data["sampling_rates"])}
-        else:
-            multi_modal_data = None
+            multi_modal_data["audio"] = zip(audio_data["audios"], audio_data["sampling_rates"])
 
-        vllm_input_data = {"prompt_token_ids": batch["input_ids"][j], "multi_modal_data": multi_modal_data}
-        if "video_metadata_kwargs" in locals() and video_metadata_kwargs is not None:
+        vllm_input_data = {"prompt_token_ids": batch["input_ids"][j], "multi_modal_data": multi_modal_data or None}
+        if video_metadata_kwargs is not None:
             vllm_input_data["mm_processor_kwargs"] = video_metadata_kwargs
         vllm_inputs.append(vllm_input_data)
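The refactor above replaces mutually exclusive `elif` branches with independent `if` blocks over a shared dict, so a sample carrying, say, both an image and an audio clip keeps both modalities instead of silently dropping one. The pattern in isolation (modality keys mirror the diff; the payload values are stand-ins for the regularized media):

```python
def build_multi_modal_data(image=None, video=None, audio=None):
    # Build the payload incrementally: each modality is checked independently,
    # so mixed inputs (image + audio, image + video, ...) are all preserved.
    multi_modal_data = {}
    if image is not None:
        multi_modal_data["image"] = image
    if video is not None:
        multi_modal_data["video"] = video
    if audio is not None:
        multi_modal_data["audio"] = audio
    # Pass None (not an empty dict) for text-only prompts, as in the diff.
    return multi_modal_data or None

assert build_multi_modal_data() is None
assert build_multi_modal_data(image="img", audio="wav") == {"image": "img", "audio": "wav"}
```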

View File

@@ -88,7 +88,10 @@ def _process_request(
         if request.messages[0].role == Role.SYSTEM:
             content = request.messages.pop(0).content
-            system = content[0].text if isinstance(content, list) else content
+            if isinstance(content, list):
+                system = content[0].text if content else ""
+            else:
+                system = content
         else:
             system = None
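The guard above handles OpenAI-style requests whose system message carries an empty `content` list; previously `content[0].text` raised `IndexError`. The same logic as a standalone sketch, with plain strings standing in for the API's content-part objects:

```python
def extract_system(content):
    # content may be a plain string or a list of content parts;
    # an empty list previously crashed on content[0].
    if isinstance(content, list):
        return content[0] if content else ""
    return content

assert extract_system("be brief") == "be brief"
assert extract_system(["be brief"]) == "be brief"
assert extract_system([]) == ""  # no IndexError
```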

View File

@@ -180,35 +180,32 @@ class VllmEngine(BaseEngine):
             else self.generating_args["skip_special_tokens"],
         )
 
+        multi_modal_data = {}
         if images is not None:  # add image features
-            multi_modal_data = {
-                "image": self.template.mm_plugin._regularize_images(
-                    images,
-                    image_max_pixels=self.model_args.image_max_pixels,
-                    image_min_pixels=self.model_args.image_min_pixels,
-                )["images"]
-            }
-        elif videos is not None:
-            multi_modal_data = {
-                "video": self.template.mm_plugin._regularize_videos(
-                    videos,
-                    image_max_pixels=self.model_args.video_max_pixels,
-                    image_min_pixels=self.model_args.video_min_pixels,
-                    video_fps=self.model_args.video_fps,
-                    video_maxlen=self.model_args.video_maxlen,
-                )["videos"]
-            }
-        elif audios is not None:
+            multi_modal_data["image"] = self.template.mm_plugin._regularize_images(
+                images,
+                image_max_pixels=self.model_args.image_max_pixels,
+                image_min_pixels=self.model_args.image_min_pixels,
+            )["images"]
+
+        if videos is not None:
+            multi_modal_data["video"] = self.template.mm_plugin._regularize_videos(
+                videos,
+                image_max_pixels=self.model_args.video_max_pixels,
+                image_min_pixels=self.model_args.video_min_pixels,
+                video_fps=self.model_args.video_fps,
+                video_maxlen=self.model_args.video_maxlen,
+            )["videos"]
+
+        if audios is not None:
             audio_data = self.template.mm_plugin._regularize_audios(
                 audios,
                 sampling_rate=self.model_args.audio_sampling_rate,
             )
-            multi_modal_data = {"audio": zip(audio_data["audios"], audio_data["sampling_rates"])}
-        else:
-            multi_modal_data = None
+            multi_modal_data["audio"] = zip(audio_data["audios"], audio_data["sampling_rates"])
 
         result_generator = self.model.generate(
-            {"prompt_token_ids": prompt_ids, "multi_modal_data": multi_modal_data},
+            {"prompt_token_ids": prompt_ids, "multi_modal_data": multi_modal_data or None},
             sampling_params=sampling_params,
             request_id=request_id,
             lora_request=self.lora_request,

View File

@@ -15,6 +15,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+import copy
+import inspect
 from dataclasses import dataclass
 from typing import TYPE_CHECKING, Any, Literal, Optional

@@ -24,7 +26,7 @@ import torch.nn.functional as F
 from peft import PeftModel
 from transformers import DataCollatorForSeq2Seq

-from ..extras.constants import AUDIO_PLACEHOLDER, IGNORE_INDEX, IMAGE_PLACEHOLDER
+from ..extras.constants import AUDIO_PLACEHOLDER, IGNORE_INDEX, IMAGE_PLACEHOLDER, MROPE_MODELS
 from ..extras.packages import is_pillow_available

@@ -38,6 +40,56 @@ if TYPE_CHECKING:
     from .template import Template
+
+
+def _slice_mm_inputs_for_sample(
+    mm_inputs: dict[str, Any],
+    batch_imglens: list[int],
+    batch_vidlens: list[int],
+    batch_idx: int,
+    images_per_subseq: Optional[list[int]] = None,
+    videos_per_subseq: Optional[list[int]] = None,
+    subseq_idx: Optional[int] = None,
+) -> dict[str, Any]:
+    r"""Slice mm_inputs for one batch sample, optionally for a single sub-sequence when packing.
+
+    image_grid_thw / video_grid_thw have shape [num_items, 3]. Sample batch_idx owns
+    batch_imglens[batch_idx] images and batch_vidlens[batch_idx] videos. When subseq_idx
+    is given, further restrict to that sub-sequence's counts via images_per_subseq / videos_per_subseq.
+    Note: has_dummy_image=True in the collator means only batch[0] was concatenated with a
+    fake image and there is no real multimodal data.
+    """
+    image_start_idx = sum(batch_imglens[:batch_idx])
+    image_end_idx = sum(batch_imglens[: batch_idx + 1])
+    video_start_idx = sum(batch_vidlens[:batch_idx])
+    video_end_idx = sum(batch_vidlens[: batch_idx + 1])
+    if subseq_idx is not None and images_per_subseq is not None:
+        image_start_idx += sum(images_per_subseq[:subseq_idx])
+        image_end_idx = image_start_idx + images_per_subseq[subseq_idx]
+
+    if subseq_idx is not None and videos_per_subseq is not None:
+        video_start_idx += sum(videos_per_subseq[:subseq_idx])
+        video_end_idx = video_start_idx + videos_per_subseq[subseq_idx]
+
+    sliced_mm_inputs: dict[str, Any] = {}
+    key_to_slice_meta = {
+        "image_grid_thw": (image_start_idx, image_end_idx, True),
+        "video_grid_thw": (video_start_idx, video_end_idx, True),
+        "second_per_grid_ts": (video_start_idx, video_end_idx, False),  # qwen2.5vl
+        "video_second_per_grid": (video_start_idx, video_end_idx, False),  # qwen omni
+    }
+    for key, (start_idx, end_idx, assign_none_when_empty) in key_to_slice_meta.items():
+        if key not in mm_inputs:
+            continue
+
+        mm_value = mm_inputs[key]
+        if mm_value is not None and end_idx > start_idx:
+            sliced_mm_inputs[key] = mm_value[start_idx:end_idx]
+        elif assign_none_when_empty:
+            sliced_mm_inputs[key] = None
+
+    return sliced_mm_inputs
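The start/end arithmetic above is plain prefix-sum indexing over a flat list of multimodal items. A standalone sketch (names are illustrative, not the project's API):

```python
def slice_for_sample(items, per_sample_counts, sample_idx):
    # Prefix sums turn per-sample counts into [start, end) ranges over the flat list,
    # the same way batch_imglens/batch_vidlens index image_grid_thw/video_grid_thw.
    start = sum(per_sample_counts[:sample_idx])
    end = start + per_sample_counts[sample_idx]
    return items[start:end]


grids = ["g0", "g1", "g2", "g3", "g4"]
counts = [2, 0, 3]  # sample 0 owns 2 grids, sample 1 none, sample 2 three
print(slice_for_sample(grids, counts, 0))  # ['g0', 'g1']
print(slice_for_sample(grids, counts, 1))  # []
print(slice_for_sample(grids, counts, 2))  # ['g2', 'g3', 'g4']
```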
 def prepare_4d_attention_mask(attention_mask_with_indices: "torch.Tensor", dtype: "torch.dtype") -> "torch.Tensor":
     r"""Expand 2d attention mask to 4d attention mask.

@@ -105,9 +157,154 @@ class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
         else:
             self.get_rope_func = None

+    def _compute_rope_position_ids(
+        self, features: dict[str, "torch.Tensor"], mm_inputs: dict[str, Any]
+    ) -> None:
+        r"""Compute position_ids and rope_deltas via get_rope_func for VLMs."""
+        rope_index_kwargs = {
+            "input_ids": features["input_ids"],
+            "image_grid_thw": mm_inputs.get("image_grid_thw"),
+            "video_grid_thw": mm_inputs.get("video_grid_thw"),
+            "attention_mask": (features["attention_mask"] >= 1).float(),
+        }
+        if features["attention_mask"].sum() == 0:
+            features["position_ids"] = torch.zeros((3, *features["input_ids"].shape))
+            features["rope_deltas"] = torch.zeros(features["input_ids"].shape[0])
+            return
+
+        if "mm_token_type_ids" in inspect.signature(self.get_rope_func).parameters:
+            image_token_id = getattr(self.model.config, "image_token_id", None)
+            video_token_id = getattr(self.model.config, "video_token_id", None)
+            if image_token_id is not None or video_token_id is not None:
+                mm_token_type_ids = torch.zeros_like(features["input_ids"])
+                if image_token_id is not None:
+                    mm_token_type_ids[features["input_ids"] == image_token_id] = 1
+                if video_token_id is not None:
+                    mm_token_type_ids[features["input_ids"] == video_token_id] = 2
+
+                rope_index_kwargs["mm_token_type_ids"] = mm_token_type_ids
+
+        if "second_per_grid_ts" in mm_inputs:  # for qwen2vl
+            rope_index_kwargs["second_per_grid_ts"] = mm_inputs.get("second_per_grid_ts")
+        elif "video_second_per_grid" in mm_inputs:  # for qwen2.5 omni
+            rope_index_kwargs["second_per_grids"] = mm_inputs.get("video_second_per_grid")
+
+        if getattr(self.model.config, "model_type", None) in ["qwen2_5_omni_thinker", "qwen3_omni_moe_thinker"]:
+            rope_index_kwargs["use_audio_in_video"] = getattr(self.processor, "use_audio_in_video", False)
+            feature_attention_mask = mm_inputs.get("feature_attention_mask", None)
+            if feature_attention_mask is not None:  # FIXME: need to get video image lengths
+                audio_feature_lengths = torch.sum(feature_attention_mask, dim=1)
+                rope_index_kwargs["audio_seqlens"] = audio_feature_lengths  # prepare for input
+
+            features["position_ids"], rope_deltas = self.get_rope_func(**rope_index_kwargs)
+            features["rope_deltas"] = rope_deltas - (1 - rope_index_kwargs["attention_mask"]).sum(
+                dim=-1
+            ).unsqueeze(-1)
+        else:  # for qwen vl
+            features["position_ids"], features["rope_deltas"] = self.get_rope_func(**rope_index_kwargs)
+
+    def _compute_rope_position_ids_with_packing(
+        self,
+        features: dict[str, "torch.Tensor"],
+        mm_inputs: dict[str, Any],
+        packing_params_list: list[dict[str, Any] | None],
+        batch_imglens: list[int],
+        batch_vidlens: list[int],
+        batch_audlens: list[int],
+        has_dummy_image: bool,
+    ) -> None:
+        r"""Compute position_ids and rope_deltas per sample (or per sub-sequence when packed), then merge and validate."""
+        bsz = features["input_ids"].size(0)
+        seq_len = features["input_ids"].size(1)
+        all_position_ids: list[torch.Tensor] = []
+        all_rope_deltas: list[torch.Tensor] = []
+        if has_dummy_image:
+            # [0, seq_len] = [0, unpadded_length + right_padding_length + fake_input_ids_len + collator_padding_length]
+            # FIXME: right_padding_length may be large when max_cutoff_len is improper
+            unpadded_length = int(features["attention_mask"][0].bool().sum().item())
+            right_padding_length = int((packing_params_list[0] or {}).get("right_padding_length") or 0)
+            fake_input_padding_length = max(0, seq_len - unpadded_length - right_padding_length)
+            dummy_image_right_padding_mrope = torch.zeros((3, bsz, fake_input_padding_length))
+            dummy_image_right_padding_attention_mask = torch.zeros((bsz, fake_input_padding_length))
+            assert self.tokenizer.padding_side == "right", "padding_side should be right when fake image is injected"
+            dummy_mm_inputs = copy.deepcopy(mm_inputs)
+
+        for sample_idx in range(bsz):
+            sample_packing = (packing_params_list[sample_idx] or {}) if sample_idx < len(packing_params_list) else {}
+            sequence_boundaries = sample_packing.get("sequence_boundaries")
+            num_sub_seqs = (len(sequence_boundaries) - 1) if sequence_boundaries and len(sequence_boundaries) > 1 else 1
+            image_subseq_ids = sample_packing.get("image_subseq_ids") or []
+            video_subseq_ids = sample_packing.get("video_subseq_ids") or []
+            images_per_subseq = (
+                [image_subseq_ids.count(i) for i in range(num_sub_seqs)]
+                if image_subseq_ids and num_sub_seqs > 1
+                else None
+            )
+            videos_per_subseq = (
+                [video_subseq_ids.count(i) for i in range(num_sub_seqs)]
+                if video_subseq_ids and num_sub_seqs > 1
+                else None
+            )
+            if has_dummy_image:
+                mm_inputs = {}
+
+            if num_sub_seqs <= 1:
+                sample_features = {
+                    "input_ids": features["input_ids"][sample_idx : sample_idx + 1],
+                    "attention_mask": features["attention_mask"][sample_idx : sample_idx + 1],
+                }
+                mm_inputs_for_sample = _slice_mm_inputs_for_sample(
+                    mm_inputs, batch_imglens, batch_vidlens, batch_idx=sample_idx
+                )
+                self._compute_rope_position_ids(sample_features, mm_inputs_for_sample)
+                all_position_ids.append(sample_features["position_ids"])
+                all_rope_deltas.append(sample_features["rope_deltas"])
+            else:
+                # when packing, rope_deltas are not needed during training.
+                sample_position_ids: list[torch.Tensor] = []
+                for subseq_idx in range(num_sub_seqs):
+                    subseq_start = sequence_boundaries[subseq_idx]
+                    subseq_end = sequence_boundaries[subseq_idx + 1]
+                    subseq_features = {
+                        "input_ids": features["input_ids"][sample_idx : sample_idx + 1, subseq_start:subseq_end],
+                        "attention_mask": features["attention_mask"][sample_idx : sample_idx + 1, subseq_start:subseq_end],
+                    }
+                    mm_inputs_for_subseq = _slice_mm_inputs_for_sample(
+                        mm_inputs,
+                        batch_imglens,
+                        batch_vidlens,
+                        sample_idx,
+                        images_per_subseq,
+                        videos_per_subseq,
+                        subseq_idx,
+                    )
+                    self._compute_rope_position_ids(subseq_features, mm_inputs_for_subseq)
+                    sample_position_ids.append(subseq_features["position_ids"])
+
+                all_position_ids.append(torch.cat(sample_position_ids, dim=-1))
+
+        batch_dim_for_position_ids = 1 if all_position_ids[0].dim() == 3 else 0
+        features["position_ids"] = torch.cat(all_position_ids, dim=batch_dim_for_position_ids)
+        if has_dummy_image:
+            mm_inputs = dummy_mm_inputs
+
+        expected_position_ids_shape = (
+            (bsz, seq_len)
+            if all_position_ids[0].dim() == 2
+            else (all_position_ids[0].size(0), bsz, seq_len)
+        )
+        # For further usage, pad position_ids to the right when padding tokens remain on the right,
+        # then check that the merged shape matches the expected shape.
+        if has_dummy_image:
+            features["position_ids"] = torch.cat([features["position_ids"], dummy_image_right_padding_mrope], dim=-1)
+            features["attention_mask"] = torch.cat(
+                [features["attention_mask"], dummy_image_right_padding_attention_mask], dim=-1
+            )
+
+        if features["position_ids"].shape != expected_position_ids_shape:
+            raise ValueError(
+                "Merged position_ids shape mismatch: "
+                f"got {features['position_ids'].shape}, expected {expected_position_ids_shape}."
+            )
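The core idea of the per-sub-sequence path above is that each packed sub-sequence restarts its positions at 0. A minimal, torch-free sketch of that merge (illustrative only; real mrope produces 3D position ids):

```python
def packed_position_ids(sequence_boundaries):
    # Each sub-sequence [start, end) gets positions restarting from 0,
    # mirroring the per-sub-sequence rope computation, then the pieces
    # are concatenated along the sequence dimension.
    position_ids = []
    for start, end in zip(sequence_boundaries, sequence_boundaries[1:]):
        position_ids.extend(range(end - start))
    return position_ids


# boundaries [0, 3, 7] -> two sub-sequences of lengths 3 and 4
print(packed_position_ids([0, 3, 7]))  # [0, 1, 2, 0, 1, 2, 3]
```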
     def __call__(self, features: list[dict[str, Any]]) -> dict[str, "torch.Tensor"]:
         batch_images, batch_videos, batch_audios = [], [], []
         batch_imglens, batch_vidlens, batch_audlens, batch_input_ids = [], [], [], []
+        packing_params_list: list[dict[str, Any] | None] = []
         for feature in features:
             images = feature.pop("images", None) or []
             videos = feature.pop("videos", None) or []

@@ -119,8 +316,10 @@ class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
             batch_vidlens.append(len(videos))
             batch_audlens.append(len(audios))
             batch_input_ids.append(feature["input_ids"])
+            packing_params_list.append(feature.pop("packing_params", None))

         fake_input_ids = []
+        has_dummy_image = False
         if (
             self.template.mm_plugin.image_token is not None and sum(batch_imglens) == 0 and sum(batch_vidlens) == 0
         ):  # avoid process hanging in zero3/fsdp case

@@ -136,6 +335,7 @@ class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
             fake_input_ids.extend(_fake_input_ids)
             batch_images = fake_images
             batch_imglens[0] = 1
+            has_dummy_image = True

         if (
             self.template.mm_plugin.audio_token is not None and sum(batch_audlens) == 0
@@ -182,45 +382,50 @@ class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
         features: dict[str, torch.Tensor] = super().__call__(features)

+        bsz, seq_len = features["input_ids"].shape[:2]
+        model_type = getattr(self.model.config, "model_type", None) if self.model is not None else None
+        is_omni = model_type in [
+            "qwen2_5_omni_thinker",
+            "qwen3_omni_moe_thinker",
+        ]
         if self.get_rope_func is not None:
-            rope_index_kwargs = {
-                "input_ids": features["input_ids"],
-                "image_grid_thw": mm_inputs.get("image_grid_thw"),
-                "video_grid_thw": mm_inputs.get("video_grid_thw"),
-                "attention_mask": (features["attention_mask"] >= 1).float(),
-            }
-            if "second_per_grid_ts" in mm_inputs:  # for qwen2vl
-                rope_index_kwargs["second_per_grid_ts"] = mm_inputs.get("second_per_grid_ts")
-            elif "video_second_per_grid" in mm_inputs:  # for qwen2.5 omni
-                rope_index_kwargs["second_per_grids"] = mm_inputs.get("video_second_per_grid")
-            if getattr(self.model.config, "model_type", None) in ["qwen2_5_omni_thinker", "qwen3_omni_moe_thinker"]:
-                rope_index_kwargs["use_audio_in_video"] = getattr(self.processor, "use_audio_in_video", False)
-                feature_attention_mask = mm_inputs.get("feature_attention_mask", None)
-                if feature_attention_mask is not None:  # FIXME: need to get video image lengths
-                    audio_feature_lengths = torch.sum(feature_attention_mask, dim=1)
-                    rope_index_kwargs["audio_seqlens"] = audio_feature_lengths  # prepare for input
-                features["position_ids"], rope_deltas = self.get_rope_func(**rope_index_kwargs)
-                features["rope_deltas"] = rope_deltas - (1 - rope_index_kwargs["attention_mask"]).sum(
-                    dim=-1
-                ).unsqueeze(-1)
-            else:  # for qwen vl
-                features["position_ids"], features["rope_deltas"] = self.get_rope_func(**rope_index_kwargs)
+            # For mrope, we should calculate position_ids and rope_deltas per sample.
+            # When neat_packing is on, each sample has packing_params; None means no packing for that sample.
+            boundaries_list = [
+                p.get("sequence_boundaries") if p is not None else None for p in packing_params_list
+            ]
+            has_packing = any(b is not None and len(b) > 2 for b in boundaries_list)
+            if has_dummy_image and has_packing:
+                # FIXME: too tricky, needs to be refactored
+                features["has_dummy_image"] = True
+
+            # When a fake image/audio was injected, sequence_boundaries no longer match the tensor; use the non-packing path.
+            if not has_packing:
+                self._compute_rope_position_ids(features, mm_inputs)
+            else:
+                if is_omni:
+                    raise RuntimeError("Omni models are not supported for packed sequences for now.")
+                self._compute_rope_position_ids_with_packing(
+                    features,
+                    mm_inputs,
+                    packing_params_list,
+                    batch_imglens,
+                    batch_vidlens,
+                    batch_audlens,
+                    has_dummy_image,
+                )
+
+            # For transformers compatibility, after https://github.com/huggingface/transformers/issues/39400
+            if features["position_ids"].dim() == 3:
+                features["position_ids"] = torch.cat(
+                    [features["position_ids"][0].unsqueeze(0), features["position_ids"]], dim=0
+                )

         if (
             self.model is not None
-            and getattr(self.model.config, "model_type", None)
-            in [
-                "glm4v",
-                "Keye",
-                "qwen2_vl",
-                "qwen2_5_vl",
-                "qwen2_5_omni_thinker",
-                "qwen3_omni_moe_thinker",
-                "qwen3_vl",
-                "qwen3_vl_moe",
-            ]
+            and getattr(self.model.config, "model_type", None) in MROPE_MODELS
             and ("position_ids" not in features or features["position_ids"].dim() != 3)
         ):
             raise ValueError(f"{self.model.config.model_type} requires 3D position ids for mrope.")
@@ -248,12 +453,51 @@ class SFTDataCollatorWith4DAttentionMask(MultiModalDataCollatorForSeq2Seq):
     block_diag_attn: bool = False
     attn_implementation: Literal["eager", "sdpa", "flash_attention_2"] = "eager"
     compute_dtype: "torch.dtype" = torch.float32
+    neat_packing: bool = False
+
+    def __post_init__(self):
+        super().__post_init__()
+        if self.neat_packing and self.attn_implementation == "flash_attention_2":
+            if self.model is not None and getattr(self.model.config, "model_type", None) in ["qwen3_5", "qwen3_5_moe", "gpt_oss"]:
+                raise ValueError("Neat packing is not supported for qwen3_5, qwen3_5_moe, gpt_oss models for now.")
+
+    @staticmethod
+    def _unpad_packed_features(features: dict[str, Any]) -> None:
+        r"""Trim padded positions for packed FA2 batches."""
+        attention_mask = features.get("attention_mask")
+        if not torch.is_tensor(attention_mask) or attention_mask.dim() != 2 or attention_mask.size(0) != 1:
+            return
+
+        seq_len = attention_mask.size(1)
+        non_padding_indices = torch.nonzero(attention_mask[0] != 0, as_tuple=False).flatten()
+        if non_padding_indices.numel() == seq_len:
+            return
+
+        keys_on_seq_dim_1 = {"input_ids", "labels", "attention_mask", "token_type_ids"}
+        for key, value in list(features.items()):
+            if not torch.is_tensor(value):
+                continue
+
+            if key == "position_ids" and value.size(-1) == seq_len:
+                features[key] = value.index_select(-1, non_padding_indices)
+            elif key == "cross_attention_mask" and value.dim() >= 2 and value.size(0) == 1 and value.size(1) == seq_len:
+                features[key] = value.index_select(1, non_padding_indices)
+            elif key in keys_on_seq_dim_1 and value.dim() == 2 and value.size(0) == 1 and value.size(1) == seq_len:
+                features[key] = value.index_select(1, non_padding_indices)

     def __call__(self, features: list[dict[str, Any]]) -> dict[str, "torch.Tensor"]:
         features = super().__call__(features)
+        has_dummy_image = features.pop("has_dummy_image", False)
         if self.block_diag_attn and self.attn_implementation != "flash_attention_2":
             features["attention_mask"] = prepare_4d_attention_mask(features["attention_mask"], self.compute_dtype)

+        if self.neat_packing and self.attn_implementation == "flash_attention_2":  # FIXME: fa3/fa4 compatibility
+            assert features["input_ids"].shape[0] == 1, "bsz should be 1 for neat packing"
+            if not has_dummy_image:
+                self._unpad_packed_features(features)
+            features["attention_mask"] = None  # let transformers build the causal packed mask
+
         for key, value in features.items():  # cast data dtype for paligemma
             if torch.is_tensor(value) and torch.is_floating_point(value):
                 features[key] = value.to(self.compute_dtype)
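The trimming in `_unpad_packed_features` is mask-driven index selection. A stdlib-only sketch of the same idea (the real code uses `torch.index_select` on tensors):

```python
def unpad_by_mask(values, attention_mask):
    # Keep only positions whose mask entry is nonzero, mirroring the
    # index_select-based trimming applied to input_ids/labels/position_ids.
    keep = [i for i, m in enumerate(attention_mask) if m != 0]
    return [values[i] for i in keep]


mask = [1, 1, 2, 2, 0, 0]  # neat-packing masks count sub-sequences from 1; 0 marks padding
tokens = [10, 11, 12, 13, 0, 0]
print(unpad_by_mask(tokens, mask))  # [10, 11, 12, 13]
```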

View File

@@ -196,7 +196,7 @@ def read_cloud_json(cloud_path: str) -> list[Any]:
     # filter out non-JSON files
     files = [x["Key"] for x in fs.listdir(cloud_path)] if fs.isdir(cloud_path) else [cloud_path]
-    files = filter(lambda file: file.endswith(".json") or file.endswith(".jsonl"), files)
+    files = list(filter(lambda file: file.endswith(".json") or file.endswith(".jsonl"), files))
     if not files:
         raise ValueError(f"No JSON/JSONL files found in the specified path: {cloud_path}.")
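The bug being fixed: a `filter` object is always truthy, even when it would yield nothing, so the `if not files:` emptiness check could never fire. Materializing with `list()` restores it:

```python
# A lazy filter object is truthy regardless of its contents:
lazy = filter(lambda f: f.endswith(".json"), ["notes.txt"])
print(bool(lazy))  # True — the empty-check `if not files:` would be silently skipped

# After list(), emptiness is observable again:
materialized = list(filter(lambda f: f.endswith(".json"), ["notes.txt"]))
print(materialized, bool(materialized))  # [] False — `if not files:` now raises as intended
```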

View File

@@ -27,11 +27,12 @@ from typing import TYPE_CHECKING, BinaryIO, Literal, NotRequired, Optional, Type
 import numpy as np
 import torch
 import torchaudio
-from transformers.image_utils import get_image_size, is_valid_image, to_numpy_array
+from transformers.image_utils import get_image_size, is_valid_image, make_flat_list_of_images, to_numpy_array
 from transformers.models.mllama.processing_mllama import (
     convert_sparse_cross_attention_mask_to_dense,
     get_cross_attention_token_mask,
 )
+from transformers.video_utils import make_batched_videos
 from typing_extensions import override

 from ..extras.constants import AUDIO_PLACEHOLDER, IGNORE_INDEX, IMAGE_PLACEHOLDER, VIDEO_PLACEHOLDER

@@ -47,13 +48,6 @@ if is_pyav_available():
     import av

-if is_transformers_version_greater_than("4.52.0"):
-    from transformers.image_utils import make_flat_list_of_images
-    from transformers.video_utils import make_batched_videos
-else:
-    from transformers.image_utils import make_batched_videos, make_flat_list_of_images
-
 if TYPE_CHECKING:
     from av.stream import Stream
     from numpy.typing import NDArray
@@ -161,7 +155,9 @@ class MMPluginMixin:
         video_processor: BaseImageProcessor = getattr(
             processor, "video_processor", getattr(processor, "image_processor", None)
         )
-        feature_extractor: SequenceFeatureExtractor = getattr(processor, "feature_extractor", None)
+        feature_extractor: SequenceFeatureExtractor = getattr(processor, "feature_extractor", None) or getattr(
+            processor, "audio_processor", None
+        )
         if len(images) != 0 and self.image_token is None:
             raise ValueError(
                 "This model does not support image input. Please check whether the correct `template` is used."
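The fallback works because `getattr(..., None)` yields a falsy `None` when the attribute is missing, letting `or` try the next name. A tiny sketch with a hypothetical processor class:

```python
class ProcessorWithAudioProcessor:
    # Hypothetical stand-in: some processors expose `audio_processor`
    # instead of `feature_extractor`.
    audio_processor = "audio-proc"


proc = ProcessorWithAudioProcessor()
# getattr returns None for the missing attribute, so `or` falls through:
feature_extractor = getattr(proc, "feature_extractor", None) or getattr(proc, "audio_processor", None)
print(feature_extractor)  # audio-proc
```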
@@ -390,7 +386,9 @@ class MMPluginMixin:
         mm_inputs.update(video_processor(videos, return_tensors="pt"))

     if len(audios) != 0:
-        feature_extractor: SequenceFeatureExtractor = getattr(processor, "feature_extractor", None)
+        feature_extractor: SequenceFeatureExtractor = getattr(processor, "feature_extractor", None) or getattr(
+            processor, "audio_processor", None
+        )
audios = self._regularize_audios( audios = self._regularize_audios(
audios, audios,
sampling_rate=getattr(processor, "audio_sampling_rate", 16000), sampling_rate=getattr(processor, "audio_sampling_rate", 16000),
@@ -1054,7 +1052,9 @@ class MiniCPMVPlugin(BasePlugin):
                 chunk_input=True,
                 sampling_rate=getattr(processor, "audio_sampling_rate", 16000),
             )
-            audio_feature_lens = [torch.tensor(audio_feature_len) for audio_feature_len in audio_feature_lens]
+            audio_feature_lens = [
+                x.clone().detach() if isinstance(x, torch.Tensor) else torch.tensor(x) for x in audio_feature_lens
+            ]
mm_inputs.update({"audio_features": audio_features, "audio_feature_lens": audio_feature_lens}) mm_inputs.update({"audio_features": audio_features, "audio_feature_lens": audio_feature_lens})
if kwargs.get("ret_phs", False): if kwargs.get("ret_phs", False):
mm_inputs.update({"audio_phs": audio_phs}) mm_inputs.update({"audio_phs": audio_phs})
@@ -1094,7 +1094,7 @@ class MiniCPMVPlugin(BasePlugin):
             num_image_tokens += 1

         while VIDEO_PLACEHOLDER in content:
-            video_seqlen = len(mm_inputs["pixel_values"][num_video_tokens]) if self.expand_mm_tokens else 1
+            video_seqlen = len(mm_inputs["image_sizes"][num_video_tokens]) if self.expand_mm_tokens else 1
             content = content.replace(VIDEO_PLACEHOLDER, "{{image}}" * video_seqlen, 1)
             num_video_tokens += 1
@@ -1876,7 +1876,9 @@ class Qwen2OmniPlugin(Qwen2VLPlugin):
     ) -> dict[str, "torch.Tensor"]:
         image_processor: BaseImageProcessor = getattr(processor, "image_processor", None)
         video_processor: BaseVideoProcessor = getattr(processor, "video_processor", None)
-        feature_extractor: SequenceFeatureExtractor = getattr(processor, "feature_extractor", None)
+        feature_extractor: SequenceFeatureExtractor = getattr(processor, "feature_extractor", None) or getattr(
+            processor, "audio_processor", None
+        )
         mm_inputs = {}
         if len(images) != 0:
             images = self._regularize_images(

@@ -1981,6 +1983,7 @@ class Qwen2OmniPlugin(Qwen2VLPlugin):
                         f"Each {VIDEO_PLACEHOLDER} must be followed by an {AUDIO_PLACEHOLDER} when using audio in video."
                     )

+                position_id_per_seconds: int = getattr(processor, "position_id_per_seconds", 25)
                 audio_t_index = torch.arange(audio_lengths[num_audio_tokens])
                 video_t_index = (
                     torch.arange(video_grid_thw[num_video_tokens][0])

@@ -1992,9 +1995,9 @@ class Qwen2OmniPlugin(Qwen2VLPlugin):
                     )
                     .flatten()
                     * mm_inputs["video_second_per_grid"][num_video_tokens]
-                    * 25  # FIXME hardcode of position_id_per_seconds=25
+                    * position_id_per_seconds
                 ).long()
-                t_ntoken_per_chunk = 50  # FIXME hardcode: [25 * 2]
+                t_ntoken_per_chunk = position_id_per_seconds * 2
                 video_chunk_indices = processor.get_chunked_index(video_t_index, t_ntoken_per_chunk)
                 audio_chunk_indices = processor.get_chunked_index(audio_t_index, t_ntoken_per_chunk)
                 placeholder_string = ""
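With `t_ntoken_per_chunk = position_id_per_seconds * 2`, each interleaving chunk spans two seconds of temporal positions. A rough, stdlib-only sketch of grouping sorted temporal indices into per-chunk (start, end) slices — an assumption about the behavior, not the processor's actual `get_chunked_index` implementation:

```python
def chunked_index(t_index, tokens_per_chunk):
    # Group a sorted list of temporal positions into (start, end) index ranges,
    # one range per window of `tokens_per_chunk` position values.
    ranges = []
    i, n = 0, len(t_index)
    while i < n:
        chunk_id = t_index[i] // tokens_per_chunk
        j = i
        while j < n and t_index[j] // tokens_per_chunk == chunk_id:
            j += 1
        ranges.append((i, j))
        i = j
    return ranges


# position_id_per_seconds=25 -> tokens_per_chunk=50; 120 positions span three chunks
print(chunked_index(list(range(120)), 50))  # [(0, 50), (50, 100), (100, 120)]
```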

View File

@@ -13,7 +13,7 @@
# limitations under the License. # limitations under the License.
 from collections import defaultdict
-from dataclasses import dataclass
+from dataclasses import asdict, dataclass
 from typing import TYPE_CHECKING, Any, Optional

 from ...extras import logging

@@ -27,6 +27,23 @@ if TYPE_CHECKING:
 logger = logging.get_logger(__name__)

+MAX_SU_SEQ_IDX = 2**32  # maximum sub-sequence index
+
+
+@dataclass
+class PackingParams:
+    r"""Metadata for a packed sequence: sub-sequence boundaries and multimodal data indices.
+
+    - sequence_boundaries: cumulative token positions, e.g. [0, 100, 250, 512] means 3 sub-seqs
+      with token ranges [0,100), [100,250), [250,512). Length = num_sub_seqs + 1.
+    - image_subseq_ids / video_subseq_ids / audio_subseq_ids: for each mm item, the 0-based
+      sub-sequence index it belongs to. Length = total number of that mm type in the packed sample.
+    """
+
+    sequence_boundaries: list[int]
+    image_subseq_ids: list[int]
+    video_subseq_ids: list[int]
+    audio_subseq_ids: list[int]
+    right_padding_length: int
+

 @dataclass
 class SupervisedDatasetProcessor(DatasetProcessor):

@@ -162,10 +179,17 @@ class PackedSupervisedDatasetProcessor(SupervisedDatasetProcessor):
             valid_num += 1

         model_inputs = defaultdict(list)
+        requires_packing_params = self.data_args.neat_packing
         knapsacks = greedy_knapsack(lengths, self.data_args.cutoff_len)
         for knapsack in knapsacks:
             packed_input_ids, packed_attention_masks, packed_position_ids, packed_labels = [], [], [], []
             packed_images, packed_videos, packed_audios = [], [], []
+            if requires_packing_params:
+                sequence_boundaries = [0]
+                image_subseq_ids: list[int] = []
+                video_subseq_ids: list[int] = []
+                audio_subseq_ids: list[int] = []
+
             for i, length in enumerate(knapsack):
                 index = length2indexes[length].pop()
                 packed_input_ids += batch_input_ids[index]

@@ -174,6 +198,15 @@ class PackedSupervisedDatasetProcessor(SupervisedDatasetProcessor):
                 packed_images += batch_images[index]
                 packed_videos += batch_videos[index]
                 packed_audios += batch_audios[index]
+                if requires_packing_params:
+                    n_img = len(batch_images[index])
+                    n_vid = len(batch_videos[index])
+                    n_aud = len(batch_audios[index])
+                    sequence_boundaries.append(sequence_boundaries[-1] + len(batch_input_ids[index]))
+                    image_subseq_ids.extend([i] * n_img)
+                    video_subseq_ids.extend([i] * n_vid)
+                    audio_subseq_ids.extend([i] * n_aud)
+
                 if self.data_args.neat_packing:
                     packed_attention_masks += [i + 1] * len(batch_input_ids[index])  # start from 1
                 else:

@@ -189,10 +222,23 @@ class PackedSupervisedDatasetProcessor(SupervisedDatasetProcessor):
                 else:
                     packed_attention_masks += [1] * pad_length  # more efficient flash_attn

+            if requires_packing_params:
+                sequence_boundaries.append(sequence_boundaries[-1] + pad_length)
+
             if len(packed_input_ids) != self.data_args.cutoff_len + 1:
                 raise ValueError("The length of packed example should be identical to the cutoff length.")

             model_inputs["input_ids"].append(packed_input_ids)
+            if requires_packing_params:
+                packing_params = PackingParams(
+                    sequence_boundaries=sequence_boundaries,
+                    image_subseq_ids=image_subseq_ids or [MAX_SU_SEQ_IDX],  # avoid dataset concat error
+                    video_subseq_ids=video_subseq_ids or [MAX_SU_SEQ_IDX],
+                    audio_subseq_ids=audio_subseq_ids or [MAX_SU_SEQ_IDX],
+                    right_padding_length=pad_length,
+                )
+                model_inputs["packing_params"].append(asdict(packing_params))
+
             model_inputs["attention_mask"].append(packed_attention_masks)
             model_inputs["position_ids"].append(packed_position_ids)
             model_inputs["labels"].append(packed_labels)
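The `sequence_boundaries` construction above is a running sum of sub-sequence lengths, matching the `PackingParams` docstring's example. The same result can be sketched with `itertools.accumulate`:

```python
from itertools import accumulate


def boundaries_from_lengths(subseq_lengths):
    # Cumulative token positions with a leading 0, as in the PackingParams
    # docstring: lengths [100, 150, 262] -> boundaries [0, 100, 250, 512].
    return list(accumulate(subseq_lengths, initial=0))


print(boundaries_from_lengths([100, 150, 262]))  # [0, 100, 250, 512]
```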

View File

@@ -1061,6 +1061,22 @@ register_template(
)

+# copied from glm4 template
+register_template(
+    name="glm_ocr",
+    format_user=StringFormatter(slots=["<|user|>\n{{content}}<|assistant|>"]),
+    format_assistant=StringFormatter(slots=["\n{{content}}"]),
+    format_system=StringFormatter(slots=["<|system|>\n{{content}}"]),
+    format_function=FunctionFormatter(slots=["{{content}}"], tool_format="glm4"),
+    format_observation=StringFormatter(slots=["<|observation|>\n{{content}}<|assistant|>"]),
+    format_tools=ToolFormatter(tool_format="glm4"),
+    format_prefix=EmptyFormatter(slots=["[gMASK]<sop>"]),
+    stop_words=["<|user|>", "<|observation|>"],
+    efficient_eos=True,
+    mm_plugin=get_mm_plugin(name="glm4v", image_token="<|image|>", video_token="<|video|>"),
+)

# copied from glm4_moe template
register_template(
    name="glm4_7",
@@ -1097,7 +1113,7 @@ register_template(
register_template(
    name="gpt_oss",
    format_user=StringFormatter(slots=["<|start|>user<|message|>{{content}}<|end|><|start|>assistant"]),
-    format_assistant=StringFormatter(slots=["{{content}}<|end|>"]),
+    format_assistant=StringFormatter(slots=["{{content}}"]),
    format_system=StringFormatter(slots=["<|start|>system<|message|>{{content}}<|end|>"]),
    default_system="You are ChatGPT, a large language model trained by OpenAI.",
    thought_words=("<|channel|>analysis<|message|>", "<|end|><|start|>assistant<|channel|>final<|message|>"),
@@ -2013,6 +2029,39 @@ register_template(
)

+register_template(
+    name="qwen3_5",
+    format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
+    format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
+    format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
+    format_function=FunctionFormatter(slots=["{{content}}<|im_end|>\n"], tool_format="qwen3_5"),
+    format_observation=StringFormatter(
+        slots=["<|im_start|>user\n<tool_response>\n{{content}}\n</tool_response><|im_end|>\n<|im_start|>assistant\n"]
+    ),
+    format_tools=ToolFormatter(tool_format="qwen3_5"),
+    stop_words=["<|im_end|>"],
+    replace_eos=True,
+    mm_plugin=get_mm_plugin(name="qwen3_vl", image_token="<|image_pad|>", video_token="<|video_pad|>"),
+    template_class=ReasoningTemplate,
+)
+
+register_template(
+    name="qwen3_5_nothink",
+    format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
+    format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
+    format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
+    format_function=FunctionFormatter(slots=["{{content}}<|im_end|>\n"], tool_format="qwen3_5"),
+    format_observation=StringFormatter(
+        slots=["<|im_start|>user\n<tool_response>\n{{content}}\n</tool_response><|im_end|>\n<|im_start|>assistant\n"]
+    ),
+    format_tools=ToolFormatter(tool_format="qwen3_5"),
+    stop_words=["<|im_end|>"],
+    replace_eos=True,
+    mm_plugin=get_mm_plugin(name="qwen3_vl", image_token="<|image_pad|>", video_token="<|video_pad|>"),
+)
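The two qwen3_5 registrations above differ only in the reasoning template class. How the ChatML-style slots render a single turn can be sketched with plain string substitution (a simplification of the project's StringFormatter, not its actual implementation):

```python
# Simplified rendering of the qwen3_5 template slots for one system/user/assistant turn.
FORMAT_SYSTEM = "<|im_start|>system\n{content}<|im_end|>\n"
FORMAT_USER = "<|im_start|>user\n{content}<|im_end|>\n<|im_start|>assistant\n"
FORMAT_ASSISTANT = "{content}<|im_end|>\n"

def render(system: str, user: str, assistant: str) -> str:
    return (
        FORMAT_SYSTEM.format(content=system)
        + FORMAT_USER.format(content=user)
        + FORMAT_ASSISTANT.format(content=assistant)
    )

print(render("You are helpful.", "Hi", "Hello!"))
```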
register_template(
    name="sailor",
    format_user=StringFormatter(slots=["<|im_start|>question\n{{content}}<|im_end|>\n<|im_start|>answer\n"]),
@@ -2202,3 +2251,24 @@ register_template(
    format_system=StringFormatter(slots=["<|system|>\n{{content}}", {"eos_token"}]),
    default_system="You are Zephyr, a helpful assistant.",
)

+# copied from glm4_7 template
+register_template(
+    name="aeva",
+    format_user=StringFormatter(slots=["<|user|>\n{{content}}<|assistant|>"]),
+    format_assistant=StringFormatter(slots=["\n{{content}}"]),
+    format_system=StringFormatter(slots=["<|system|>\n{{content}}"]),
+    format_function=FunctionFormatter(slots=["{{content}}"], tool_format="glm4_moe"),
+    format_observation=StringFormatter(slots=["<|observation|>\n{{content}}<|assistant|>"]),
+    format_tools=ToolFormatter(tool_format="glm4_moe"),
+    format_prefix=EmptyFormatter(slots=["[gMASK]<sop>"]),
+    default_system=(
+        "You are an AI assistant named Aeva created by Zongzhi Lou. "
+        "Your answer should be friendly, unbiased, faithful, informative and detailed."
+    ),
+    stop_words=["<|user|>", "<|observation|>"],
+    thought_words=("<think>", "</think>"),
+    efficient_eos=True,
+    template_class=Glm47ReasoningTemplate,
+)


@@ -85,6 +85,21 @@ QWEN_TOOL_PROMPT = (
    """"arguments": <args-json-object>}}\n</tool_call>"""
)

+QWEN35_TOOL_PROMPT = (
+    "\n\n# Tools\n\nYou have access to the following functions:\n\n<tools>{tool_text}"
+    "\n</tools>\n\nIf you choose to call a function ONLY reply in the following format with NO suffix:\n\n"
+    "<tool_call>\n<function=example_function_name>\n<parameter=example_parameter_1>\nvalue_1\n</parameter>\n"
+    "<parameter=example_parameter_2>\nThis is the value for the second parameter\nthat can span\nmultiple lines\n"
+    "</parameter>\n</function>\n</tool_call>\n\n<IMPORTANT>\nReminder:\n"
+    "- Function calls MUST follow the specified format: "
+    "an inner <function=...></function> block must be nested within <tool_call></tool_call> XML tags\n"
+    "- Required parameters MUST be specified\n"
+    "- You may provide optional reasoning for your function call in natural language "
+    "BEFORE the function call, but NOT after\n"
+    "- If there is no function call available, answer the question like normal with your current knowledge "
+    "and do not tell the user about function calls\n</IMPORTANT>"
+)

SEED_TOOL_PROMPT = (
    "system\nYou are Doubao, a helpful AI assistant. You may call one or more functions to assist with the user query."
    "Tool List:\nYou are authorized to use the following tools (described in JSON Schema format). Before performing "
@@ -453,6 +468,57 @@ class QwenToolUtils(ToolUtils):
        return results

+class Qwen35ToolUtils(ToolUtils):
+    r"""Qwen 3.5 tool using template."""
+
+    @override
+    @staticmethod
+    def tool_formatter(tools: list[dict[str, Any]]) -> str:
+        tool_text = ""
+        for tool in tools:
+            tool = tool.get("function", tool) if tool.get("type") == "function" else tool
+            tool_text += "\n" + json.dumps(tool, ensure_ascii=False)
+
+        return QWEN35_TOOL_PROMPT.format(tool_text=tool_text)
+
+    @override
+    @staticmethod
+    def function_formatter(functions: list["FunctionCall"]) -> str:
+        function_texts = []
+        for func in functions:
+            name, arguments = func.name, json.loads(func.arguments)
+            prompt = f"<tool_call>\n<function={name}>"
+            for key, value in arguments.items():
+                prompt += f"\n<parameter={key}>"
+                if not isinstance(value, str):
+                    value = json.dumps(value, ensure_ascii=False)
+
+                prompt += f"\n{value}\n</parameter>"
+
+            prompt += "\n</function>\n</tool_call>"
+            function_texts.append(prompt)
+
+        return "\n".join(function_texts)
+
+    @override
+    @staticmethod
+    def tool_extractor(content: str) -> Union[str, list["FunctionCall"]]:
+        results = []
+        regex = re.compile(r"<tool_call>\s*<function=\s*([^\s<>]+)\s*(.*?)\s*</function>\s*</tool_call>", re.DOTALL)
+        for func_name, params_block in re.findall(regex, content):
+            args_dict = {}
+            param_pattern = re.compile(r"<parameter=(.*?)>(.*?)</parameter>", re.DOTALL)
+            for key, raw_value in re.findall(param_pattern, params_block.strip()):
+                value = raw_value.strip()
+                try:
+                    parsed_value = json.loads(value)
+                except json.JSONDecodeError:
+                    parsed_value = raw_value.strip()
+                args_dict[key] = parsed_value
+
+            results.append(FunctionCall(func_name.strip(), json.dumps(args_dict, ensure_ascii=False)))
+
+        return results if results else content
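The formatter and extractor above are intended to be inverses on well-formed calls. A standalone sketch of the extraction regexes applied to one formatted call, with the project's FunctionCall dataclass replaced by a plain tuple:

```python
import json
import re

# Standalone demo of the qwen3_5 tool-call extraction shown above
# (FunctionCall is a plain tuple here instead of the project's dataclass).
content = (
    "<tool_call>\n<function=get_weather>\n"
    "<parameter=city>\nBeijing\n</parameter>\n"
    "<parameter=days>\n3\n</parameter>\n"
    "</function>\n</tool_call>"
)

regex = re.compile(r"<tool_call>\s*<function=\s*([^\s<>]+)\s*(.*?)\s*</function>\s*</tool_call>", re.DOTALL)
param_pattern = re.compile(r"<parameter=(.*?)>(.*?)</parameter>", re.DOTALL)

results = []
for func_name, params_block in regex.findall(content):
    args = {}
    for key, raw_value in param_pattern.findall(params_block.strip()):
        value = raw_value.strip()
        try:
            args[key] = json.loads(value)  # "3" parses to int 3
        except json.JSONDecodeError:
            args[key] = value  # non-JSON values stay strings
    results.append((func_name.strip(), json.dumps(args, ensure_ascii=False)))

print(results)  # [('get_weather', '{"city": "Beijing", "days": 3}')]
```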
class GLM4MOEToolUtils(QwenToolUtils):
    r"""GLM-4-MOE tool using template."""
@@ -662,6 +728,7 @@ TOOLS = {
    "minimax2": MiniMaxM2ToolUtils(),
    "mistral": MistralToolUtils(),
    "qwen": QwenToolUtils(),
+    "qwen3_5": Qwen35ToolUtils(),
    "glm4_moe": GLM4MOEToolUtils(),
    "seed_oss": SeedToolUtils(),
    "ling": LingToolUtils(),


@@ -65,15 +65,32 @@ MCA_SUPPORTED_MODELS = {
    "qwen2_vl",
    "qwen2_5_vl",
    "qwen3_vl",
+    "qwen3_vl_moe",
    "qwen3",
    "qwen3_moe",
    "qwen3_next",
+    "qwen3_5",
+    "qwen3_5_moe",
}

METHODS = ["full", "freeze", "lora", "oft"]

MOD_SUPPORTED_MODELS = {"bloom", "falcon", "gemma", "llama", "mistral", "mixtral", "phi", "starcoder2"}

+MROPE_MODELS = {
+    "glm4v",
+    "glm_ocr",
+    "Keye",
+    "qwen2_vl",
+    "qwen2_5_vl",
+    "qwen2_5_omni_thinker",
+    "qwen3_omni_moe_thinker",
+    "qwen3_vl",
+    "qwen3_vl_moe",
+    "qwen3_5",
+    "qwen3_5_moe",
+}

MULTIMODAL_SUPPORTED_MODELS = set()

PEFT_METHODS = {"lora", "oft"}
@@ -950,6 +967,18 @@ register_model_group(
)

+register_model_group(
+    models={
+        "GLM-OCR": {
+            DownloadSource.DEFAULT: "zai-org/GLM-OCR",
+            DownloadSource.MODELSCOPE: "ZhipuAI/GLM-OCR",
+        },
+    },
+    template="glm_ocr",
+    multimodal=True,
+)

register_model_group(
    models={
        "GLM-Z1-0414-9B-Chat": {
@@ -2797,6 +2826,66 @@ register_model_group(
)

+register_model_group(
+    models={
+        "Qwen3.5-0.8B-Base": {
+            DownloadSource.DEFAULT: "Qwen/Qwen3.5-0.8B-Base",
+            DownloadSource.MODELSCOPE: "Qwen/Qwen3.5-0.8B-Base",
+        },
+        "Qwen3.5-2B-Base": {
+            DownloadSource.DEFAULT: "Qwen/Qwen3.5-2B-Base",
+            DownloadSource.MODELSCOPE: "Qwen/Qwen3.5-2B-Base",
+        },
+        "Qwen3.5-4B-Base": {
+            DownloadSource.DEFAULT: "Qwen/Qwen3.5-4B-Base",
+            DownloadSource.MODELSCOPE: "Qwen/Qwen3.5-4B-Base",
+        },
+        "Qwen3.5-9B-Base": {
+            DownloadSource.DEFAULT: "Qwen/Qwen3.5-9B-Base",
+            DownloadSource.MODELSCOPE: "Qwen/Qwen3.5-9B-Base",
+        },
+        "Qwen3.5-35B-A3B-Base": {
+            DownloadSource.DEFAULT: "Qwen/Qwen3.5-35B-A3B-Base",
+            DownloadSource.MODELSCOPE: "Qwen/Qwen3.5-35B-A3B-Base",
+        },
+        "Qwen3.5-0.8B-Thinking": {
+            DownloadSource.DEFAULT: "Qwen/Qwen3.5-0.8B",
+            DownloadSource.MODELSCOPE: "Qwen/Qwen3.5-0.8B",
+        },
+        "Qwen3.5-2B-Thinking": {
+            DownloadSource.DEFAULT: "Qwen/Qwen3.5-2B",
+            DownloadSource.MODELSCOPE: "Qwen/Qwen3.5-2B",
+        },
+        "Qwen3.5-4B-Thinking": {
+            DownloadSource.DEFAULT: "Qwen/Qwen3.5-4B",
+            DownloadSource.MODELSCOPE: "Qwen/Qwen3.5-4B",
+        },
+        "Qwen3.5-9B-Thinking": {
+            DownloadSource.DEFAULT: "Qwen/Qwen3.5-9B",
+            DownloadSource.MODELSCOPE: "Qwen/Qwen3.5-9B",
+        },
+        "Qwen3.5-27B-Thinking": {
+            DownloadSource.DEFAULT: "Qwen/Qwen3.5-27B",
+            DownloadSource.MODELSCOPE: "Qwen/Qwen3.5-27B",
+        },
+        "Qwen3.5-35B-A3B-Thinking": {
+            DownloadSource.DEFAULT: "Qwen/Qwen3.5-35B-A3B",
+            DownloadSource.MODELSCOPE: "Qwen/Qwen3.5-35B-A3B",
+        },
+        "Qwen3.5-122B-A10B-Thinking": {
+            DownloadSource.DEFAULT: "Qwen/Qwen3.5-122B-A10B",
+            DownloadSource.MODELSCOPE: "Qwen/Qwen3.5-122B-A10B",
+        },
+        "Qwen3.5-397B-A17B-Thinking": {
+            DownloadSource.DEFAULT: "Qwen/Qwen3.5-397B-A17B",
+            DownloadSource.MODELSCOPE: "Qwen/Qwen3.5-397B-A17B",
+        },
+    },
+    template="qwen3_5",
+    multimodal=True,
+)

register_model_group(
    models={
        "Qwen2-Audio-7B": {
@@ -3438,3 +3527,35 @@ register_model_group(
    },
    template="zephyr",
)

+register_model_group(
+    models={
+        "Aeva-Flash-Chat": {
+            DownloadSource.DEFAULT: "louzongzhi/Aeva-Flash",
+            DownloadSource.MODELSCOPE: "louzongktsi/Aeva-Flash",
+            DownloadSource.OPENMIND: "louzongzhi/Aeva-Flash",
+        },
+        "Aeva-Air-Chat": {
+            DownloadSource.DEFAULT: "louzongzhi/Aeva-Air",
+            DownloadSource.MODELSCOPE: "louzongktsi/Aeva-Air",
+            DownloadSource.OPENMIND: "louzongzhi/Aeva-Air",
+        },
+        "Aeva-Chat": {
+            DownloadSource.DEFAULT: "louzongzhi/Aeva",
+            DownloadSource.MODELSCOPE: "louzongktsi/Aeva",
+            DownloadSource.OPENMIND: "louzongzhi/Aeva",
+        },
+        "Aeva-Pro-Chat": {
+            DownloadSource.DEFAULT: "louzongzhi/Aeva-Pro",
+            DownloadSource.MODELSCOPE: "louzongktsi/Aeva-Pro",
+            DownloadSource.OPENMIND: "louzongzhi/Aeva-Pro",
+        },
+        "Aeva-Max-Chat": {
+            DownloadSource.DEFAULT: "louzongzhi/Aeva-Max",
+            DownloadSource.MODELSCOPE: "louzongktsi/Aeva-Max",
+            DownloadSource.OPENMIND: "louzongzhi/Aeva-Max",
+        },
+    },
+    template="aeva",
+)


@@ -94,7 +94,7 @@ def check_version(requirement: str, mandatory: bool = False) -> None:

def check_dependencies() -> None:
    r"""Check the version of the required packages."""
-    check_version("transformers>=4.51.0,<=5.0.0")
+    check_version("transformers>=4.55.0,<=5.2.0")
    check_version("datasets>=2.16.0,<=4.0.0")
    check_version("accelerate>=1.3.0,<=1.11.0")
    check_version("peft>=0.18.0,<=0.18.1")


@@ -490,6 +490,14 @@ class FinetuningArguments(
        default=False,
        metadata={"help": "Whether to use the DFT loss."},
    )
+    use_asft_loss: bool = field(
+        default=False,
+        metadata={"help": "Whether to use the ASFT loss."},
+    )
+    asft_alpha: float = field(
+        default=0.1,
+        metadata={"help": "The alpha parameter for ASFT loss to control the power of adaptive weight."},
+    )
    use_eaft_loss: bool = field(
        default=False,
        metadata={"help": "Whether to use the EAFT loss."},


@@ -33,7 +33,7 @@ from transformers.utils import is_torch_bf16_gpu_available, is_torch_npu_availab
from ..extras import logging
from ..extras.constants import CHECKPOINT_NAMES, EngineName
from ..extras.misc import check_dependencies, check_version, get_current_device, is_env_enabled
-from ..extras.packages import is_mcore_adapter_available, is_transformers_version_greater_than
+from ..extras.packages import is_mcore_adapter_available
from .data_args import DataArguments
from .evaluation_args import EvaluationArguments
from .finetuning_args import FinetuningArguments
@@ -100,6 +100,52 @@ def _parse_args(
    return tuple(parsed_args)

+def _verify_trackio_args(training_args: "TrainingArguments") -> None:
+    """Validates Trackio-specific arguments.
+
+    Args:
+        training_args: TrainingArguments instance (not a dictionary)
+    """
+    report_to = training_args.report_to
+    if not report_to:
+        return
+
+    if isinstance(report_to, str):
+        report_to = [report_to]
+
+    if "trackio" not in report_to:
+        return
+
+    # --- Enforce project (required by Trackio) ---
+    if not training_args.project:
+        raise ValueError("`--project` must be specified when using Trackio.")
+
+    # --- Validate trackio_space_id format ---
+    space_id = training_args.trackio_space_id
+    if space_id:
+        if space_id != "trackio" and "/" not in space_id:
+            logger.warning(
+                f"trackio_space_id '{space_id}' should typically be in format "
+                "'org/space' for Hugging Face Spaces deployment."
+            )
+
+    # --- Inform about default project usage ---
+    if training_args.project == "huggingface":
+        logger.info(
+            "Using default project name 'huggingface'. "
+            "Consider setting a custom project name with --project "
+            "for better organization."
+        )
+
+    # --- Validate hub repo privacy flag ---
+    if training_args.hub_private_repo:
+        logger.info("Repository will be created as private on Hugging Face Hub.")
+
+    # --- Recommend run_name for experiment clarity ---
+    if not training_args.run_name:
+        logger.warning("Consider setting --run_name for better experiment tracking clarity.")
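The subtle part of the validator above is that report_to may arrive as a single string or a list. A minimal sketch of the normalization and the required-project check, using a hypothetical stand-in for TrainingArguments (the Args class below is illustrative only):

```python
from dataclasses import dataclass
from typing import Optional, Union

# Hypothetical stand-in for TrainingArguments; field names mirror the ones used above.
@dataclass
class Args:
    report_to: Optional[Union[str, list]] = None
    project: str = ""

def verify_trackio(args: Args) -> None:
    report_to = args.report_to
    if not report_to:
        return
    if isinstance(report_to, str):  # accept both "trackio" and ["trackio"]
        report_to = [report_to]
    if "trackio" not in report_to:
        return
    if not args.project:
        raise ValueError("`--project` must be specified when using Trackio.")

verify_trackio(Args(report_to="wandb"))  # passes: trackio not requested
try:
    verify_trackio(Args(report_to="trackio"))
except ValueError as exc:
    print(exc)  # `--project` must be specified when using Trackio.
```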
def _set_transformers_logging() -> None:
    if os.getenv("LLAMAFACTORY_VERBOSITY", "INFO") in ["DEBUG", "INFO"]:
        transformers.utils.logging.set_verbosity_info()
@@ -278,8 +324,10 @@ def get_train_args(args: dict[str, Any] | list[str] | None = None) -> _TRAIN_CLS
        if finetuning_args.reward_model_type == "lora" and model_args.use_unsloth:
            raise ValueError("Unsloth does not support lora reward model.")

-        if training_args.report_to and training_args.report_to[0] not in ["wandb", "tensorboard"]:
-            raise ValueError("PPO only accepts wandb or tensorboard logger.")
+        if training_args.report_to and any(
+            logger not in ("wandb", "tensorboard", "trackio", "none") for logger in training_args.report_to
+        ):
+            raise ValueError("PPO only accepts wandb, tensorboard, or trackio logger.")

        if not model_args.use_kt and training_args.parallel_mode == ParallelMode.NOT_DISTRIBUTED:
            raise ValueError("Please launch distributed training with `llamafactory-cli` or `torchrun`.")
@@ -346,12 +394,10 @@ def get_train_args(args: dict[str, Any] | list[str] | None = None) -> _TRAIN_CLS
    if model_args.use_kt and is_deepspeed_zero3_enabled():
        raise ValueError("KTransformers is incompatible with DeepSpeed ZeRO-3.")

-    if data_args.neat_packing and is_transformers_version_greater_than("4.53.0"):
-        raise ValueError("Neat packing is incompatible with transformers>=4.53.0.")

    _set_env_vars()
    _verify_model_args(model_args, data_args, finetuning_args)
    _check_extra_dependencies(model_args, finetuning_args, training_args)
+    _verify_trackio_args(training_args)

    if not finetuning_args.use_mca and training_args.fp8_enable_fsdp_float8_all_gather and not training_args.fp8:
        logger.warning_rank0("fp8_enable_fsdp_float8_all_gather requires fp8=True. Setting fp8=True.")
@@ -421,7 +467,7 @@ def get_train_args(args: dict[str, Any] | list[str] | None = None) -> _TRAIN_CLS
        training_args.resume_from_checkpoint is None
        and training_args.do_train
        and os.path.isdir(training_args.output_dir)
-        and not training_args.overwrite_output_dir
+        and not getattr(training_args, "overwrite_output_dir", False)  # for mca training args and transformers >= 5.0
        and can_resume_from_checkpoint
    ):
        last_checkpoint = get_last_checkpoint(training_args.output_dir)


@@ -77,6 +77,8 @@ def apply_liger_kernel(
        from liger_kernel.transformers import apply_liger_kernel_to_qwen3 as apply_liger_kernel
    elif model_type == "qwen3_moe":
        from liger_kernel.transformers import apply_liger_kernel_to_qwen3_moe as apply_liger_kernel
+    elif model_type == "qwen3_next":
+        from liger_kernel.transformers import apply_liger_kernel_to_qwen3_next as apply_liger_kernel
    elif model_type == "gpt_oss":
        try:
            from liger_kernel.transformers import apply_liger_kernel_to_gpt_oss as apply_liger_kernel


@@ -142,6 +142,11 @@ def add_z3_leaf_module(model: "PreTrainedModel") -> None:
        _set_z3_leaf_modules(model, [Qwen3OmniMoeThinkerTextSparseMoeBlock])

+    if model_type == "qwen3_next":
+        from transformers.models.qwen3_next.modeling_qwen3_next import Qwen3NextSparseMoeBlock
+
+        _set_z3_leaf_modules(model, [Qwen3NextSparseMoeBlock])

def configure_moe(config: "PretrainedConfig", model_args: "ModelArguments", is_trainable: bool) -> None:
    if not is_trainable or not model_args.moe_aux_loss_coef:


@@ -37,7 +37,6 @@
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

-from typing import TYPE_CHECKING

import torch
import torch.nn.functional as F
@@ -45,10 +44,6 @@ import torch.nn.functional as F
from ...extras import logging

-if TYPE_CHECKING:
-    from ...hparams import ModelArguments

logger = logging.get_logger(__name__)
@@ -105,13 +100,3 @@ def get_unpad_data(attention_mask: "torch.Tensor") -> tuple["torch.Tensor", "tor
    max_seqlen_in_batch = seqlens_in_batch.max().item()
    cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
    return indices, cu_seqlens, max_seqlen_in_batch
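`get_unpad_data` above turns per-sequence lengths into cumulative offsets (`cu_seqlens`) with a leading zero, which is the `F.pad(..., (1, 0))` step. The same computation in plain Python, assuming a neat-packing mask whose entries are 1-based sequence ids with 0 for padding (a sketch, not the torch implementation):

```python
from itertools import groupby

# Pure-Python sketch of the cu_seqlens computation done by get_unpad_data,
# using lists instead of torch tensors (illustration only).
def cu_seqlens_from_mask(mask: list[int]) -> list[int]:
    """Given a neat-packing mask (sequence ids 1, 2, ... and 0 for padding),
    return cumulative sequence lengths, starting with 0."""
    lengths = [len(list(group)) for seq_id, group in groupby(mask) if seq_id != 0]
    cu = [0]
    for n in lengths:
        cu.append(cu[-1] + n)  # running total, like torch.cumsum
    return cu

print(cu_seqlens_from_mask([1, 1, 1, 2, 2, 0]))  # [0, 3, 5]
```

Flash-attention kernels consume exactly this offsets array to treat one packed row as several independent sequences.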
-def configure_packing(model_args: "ModelArguments", is_trainable: bool) -> None:
-    if not is_trainable or not model_args.block_diag_attn:
-        return
-
-    import transformers.modeling_flash_attention_utils
-
-    transformers.modeling_flash_attention_utils._get_unpad_data = get_unpad_data
-    logger.info_rank0("Using block diagonal attention for sequence packing without cross-attention.")


@@ -24,7 +24,6 @@ import transformers.models
from transformers.activations import ACT2FN

from ...extras import logging
-from ...extras.packages import is_transformers_version_greater_than

if TYPE_CHECKING:
@@ -239,6 +238,15 @@ _register_composite_model(
)

+_register_composite_model(
+    model_type="glm_ocr",
+    projector_key="visual.merger",
+    vision_model_keys=["visual.patch_embed", "visual.blocks"],
+    language_model_keys=["language_model", "lm_head"],
+    lora_conflict_keys=["patch_embed"],
+)

_register_composite_model(
    model_type="internvl",
)
@@ -335,9 +343,7 @@ _register_composite_model(
    model_type="qwen2_vl",
    projector_key="visual.merger",
    vision_model_keys=["visual.patch_embed", "visual.blocks"],
-    language_model_keys=["language_model", "lm_head"]
-    if is_transformers_version_greater_than("4.52.0")
-    else ["model", "lm_head"],
+    language_model_keys=["language_model", "lm_head"],
    lora_conflict_keys=["patch_embed"],
)
@@ -346,9 +352,7 @@ _register_composite_model(
    model_type="qwen2_5_vl",
    projector_key="visual.merger",
    vision_model_keys=["visual.patch_embed", "visual.blocks"],
-    language_model_keys=["language_model", "lm_head"]
-    if is_transformers_version_greater_than("4.52.0")
-    else ["model", "lm_head"],
+    language_model_keys=["language_model", "lm_head"],
    lora_conflict_keys=["patch_embed"],
)
@@ -381,7 +385,25 @@ _register_composite_model(
        "visual.deepstack_merger_list",
        "audio_tower",
    ],
-    language_model_keys=["model", "lm_head"],
+    language_model_keys=["language_model", "lm_head"],
+    lora_conflict_keys=["patch_embed"],
+)
+
+_register_composite_model(
+    model_type="qwen3_5",
+    projector_key="model.visual.merger",
+    vision_model_keys=["visual.pos_embed", "visual.patch_embed", "visual.blocks"],
+    language_model_keys=["language_model", "lm_head"],
+    lora_conflict_keys=["patch_embed"],
+)
+
+_register_composite_model(
+    model_type="qwen3_5_moe",
+    projector_key="model.visual.merger",
+    vision_model_keys=["visual.pos_embed", "visual.patch_embed", "visual.blocks"],
+    language_model_keys=["language_model", "lm_head"],
    lora_conflict_keys=["patch_embed"],
)


@@ -30,7 +30,6 @@ from .model_utils.embedding import resize_embedding_layer
from .model_utils.kv_cache import configure_kv_cache
from .model_utils.longlora import configure_longlora
from .model_utils.moe import add_z3_leaf_module, configure_moe
-from .model_utils.packing import configure_packing
from .model_utils.quantization import configure_quantization
from .model_utils.rope import configure_rope
from .model_utils.valuehead import prepare_valuehead_model
@@ -142,7 +141,6 @@ def patch_config(
    configure_quantization(config, tokenizer, model_args, is_trainable, init_kwargs)
    configure_moe(config, model_args, is_trainable)
    configure_visual_model(config)
-    configure_packing(model_args, is_trainable)
    configure_kv_cache(config, model_args, is_trainable)

    if getattr(config, "model_type", None) == "qwen":


@@ -228,7 +228,7 @@ class LogCallback(TrainerCallback):
        if (
            args.should_save
            and os.path.exists(os.path.join(args.output_dir, TRAINER_LOG))
-            and args.overwrite_output_dir
+            and getattr(args, "overwrite_output_dir", False)
        ):
            logger.warning_rank0_once("Previous trainer log in this folder will be deleted.")
            os.remove(os.path.join(args.output_dir, TRAINER_LOG))
@@ -371,6 +371,18 @@ class ReporterCallback(TrainerCallback):
                }
            )

+        if "trackio" in args.report_to:
+            import trackio
+
+            trackio.config.update(
+                {
+                    "model_args": self.model_args.to_dict(),
+                    "data_args": self.data_args.to_dict(),
+                    "finetuning_args": self.finetuning_args.to_dict(),
+                    "generating_args": self.generating_args.to_dict(),
+                }
+            )

        if self.finetuning_args.use_swanlab:
            import swanlab  # type: ignore


@@ -13,6 +13,8 @@
# limitations under the License.

import functools
+import json
+import os
from collections.abc import Sequence
from copy import deepcopy
from typing import TYPE_CHECKING, Any, Optional
@@ -77,12 +79,43 @@ def _data_collator_wrapper(data_collator: Any):

def _check_model_support(model_args: "ModelArguments"):
    from transformers import AutoConfig as HfAutoConfig

-    config = HfAutoConfig.from_pretrained(
-        model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code
-    )
-    if config.model_type not in MCA_SUPPORTED_MODELS:
-        raise ValueError(f"Model {config.model_type} is not supported by MCA.")
+    if os.path.exists(os.path.join(model_args.model_name_or_path, "mca_config.json")):  # load from mcore ckpt
+        mca_config = json.load(open(os.path.join(model_args.model_name_or_path, "mca_config.json")))
+        model_type = mca_config.get("hf_model_type", None)
+    else:
+        config = HfAutoConfig.from_pretrained(
+            model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code
+        )
+        model_type = config.model_type
+
+    if model_type not in MCA_SUPPORTED_MODELS:
+        raise ValueError(
+            f"Model {model_type} is not supported by mcore_adapter. "
+            "You can try to upgrade mcore_adapter to the latest version for more supported models."
+        )

+def _freeze_model_parameters(model: Any, finetuning_args: "FinetuningArguments"):
+    """Freeze model parameters for qwen_vl series models based on finetuning arguments."""
+    if getattr(model.config, "hf_model_type", None) not in ["qwen2_vl", "qwen2_5_vl", "qwen3_vl", "qwen3_vl_moe", "qwen3_5", "qwen3_5_moe"]:
+        return
+
+    params_to_freeze = []
+    if finetuning_args.freeze_vision_tower:
+        params_to_freeze.extend(["vision_model.blocks", "vision_model.patch_embed"])
+        if getattr(model.config, "hf_model_type", None) in ["qwen3_vl", "qwen3_vl_moe"]:
+            params_to_freeze.extend(["vision_model.pos_embed"])
+
+    if finetuning_args.freeze_multi_modal_projector:
+        params_to_freeze.extend(["multi_modal_projector"])
+
+    if finetuning_args.freeze_language_model:
+        params_to_freeze.extend(["embedding", "decoder", "output_layer"])
+
+    if params_to_freeze:
+        for name, p in model.named_parameters():
+            if any(name.startswith(k) for k in params_to_freeze):
+                p.requires_grad_(False)
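The freeze helper above matches parameter names by prefix. The matching rule can be demonstrated without a real model; the parameter names and the Param stand-in below are illustrative only:

```python
# Illustration of prefix-based parameter freezing (stand-in objects, not real modules).
class Param:
    def __init__(self):
        self.requires_grad = True

    def requires_grad_(self, flag: bool):
        self.requires_grad = flag

params = {
    "vision_model.blocks.0.attn.qkv": Param(),
    "vision_model.patch_embed.proj": Param(),
    "decoder.layers.0.mlp": Param(),
}
params_to_freeze = ["vision_model.blocks", "vision_model.patch_embed"]

# Same rule as in _freeze_model_parameters: freeze anything whose name
# starts with one of the listed prefixes.
for name, p in params.items():
    if any(name.startswith(k) for k in params_to_freeze):
        p.requires_grad_(False)

print([name for name, p in params.items() if p.requires_grad])  # ['decoder.layers.0.mlp']
```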
def run_pt(
@@ -161,22 +194,8 @@ def run_sft(
    _check_model_support(model_args)
    model = AutoModel.from_pretrained(model_args.model_name_or_path, training_args)

-    # optional freezing for qwen2_vl, qwen2_5_vl
-    if getattr(model.config, "hf_model_type", None) in ["qwen2_vl", "qwen2_5_vl", "qwen3_vl"]:
-        params_to_freeze = []
-        if finetuning_args.freeze_vision_tower:
-            params_to_freeze.extend(["vision_model.blocks", "vision_model.patch_embed"])
-
-        if finetuning_args.freeze_multi_modal_projector:
-            params_to_freeze.extend(["multi_modal_projector"])
-
-        if finetuning_args.freeze_language_model:
-            params_to_freeze.extend(["embedding", "decoder", "output_layer"])
-
-        if params_to_freeze:
-            for name, p in model.named_parameters():
-                if any(name.startswith(k) for k in params_to_freeze):
-                    p.requires_grad_(False)
+    # optional freezing for qwen_vl series
+    _freeze_model_parameters(model, finetuning_args)

    pad_to_max = training_args.expert_model_parallel_size is not None and training_args.expert_model_parallel_size > 1
    data_collator = SFTDataCollatorWith4DAttentionMask(
@@ -229,6 +248,8 @@ def run_dpo(
    _check_model_support(model_args)
    model = AutoModel.from_pretrained(model_args.model_name_or_path, training_args)
+    _freeze_model_parameters(model, finetuning_args)

    if finetuning_args.use_ref_model:
        ref_config = AutoConfig.from_pretrained(model_args.model_name_or_path, training_args)
        ref_model = AutoModel.from_config(ref_config)


@@ -17,6 +17,7 @@
 import json
 import os
+from functools import partial
 from types import MethodType
 from typing import TYPE_CHECKING, Any, Optional, Union
@@ -52,6 +53,7 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer):
         processor: Optional["ProcessorMixin"],
         model_args: Optional["ModelArguments"] = None,
         gen_kwargs: Optional[dict[str, Any]] = None,
+        ref_model: Optional["torch.nn.Module"] = None,
         **kwargs,
     ) -> None:
         kwargs["processing_class"] = kwargs.pop("tokenizer")
@@ -82,6 +84,27 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer):
             self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, self.accelerator)
             self.add_callback(BAdamCallback)

+        self.ref_model = ref_model
+        if ref_model is not None:
+            from trl.models.utils import prepare_deepspeed, prepare_fsdp
+
+            if getattr(self.accelerator.state, "deepspeed_plugin", None) is not None:
+                if not (
+                    getattr(ref_model, "is_loaded_in_8bit", False) or getattr(ref_model, "is_loaded_in_4bit", False)
+                ):  # quantized models are already set on the correct device
+                    self.ref_model = prepare_deepspeed(self.ref_model, self.accelerator)
+            elif getattr(self.accelerator.state, "fsdp_plugin", None) is not None:
+                if self.accelerator.is_fsdp2:
+                    from accelerate.utils.fsdp_utils import fsdp2_prepare_model
+
+                    self.ref_model = fsdp2_prepare_model(self.accelerator, self.ref_model)
+                else:
+                    self.ref_model = prepare_fsdp(self.ref_model, self.accelerator)
+            else:
+                self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True)
+
+            self.ref_model.eval()
+
         if finetuning_args.use_dft_loss:
             from ..trainer_utils import dft_loss_func
@@ -93,6 +116,13 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer):
             self.compute_loss_func = lambda outputs, labels, num_items_in_batch=None: eaft_loss_func(
                 outputs, labels, num_items_in_batch, finetuning_args.eaft_alpha
             )
+        elif finetuning_args.use_asft_loss:
+            from ..trainer_utils import asft_loss_func
+
+            self.compute_loss_func = partial(
+                asft_loss_func,
+                asft_alpha=finetuning_args.asft_alpha,
+            )

         if training_args.fp8 and hasattr(self, "accelerator"):  # verify FP8 status after trainer initialization
             verify_fp8_status(self.accelerator, training_args)
@@ -119,7 +149,17 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer):
     @override
     def compute_loss(self, model, inputs, *args, **kwargs):
-        return super().compute_loss(model, inputs, *args, **kwargs)
+        if self.finetuning_args.use_asft_loss:
+            with torch.no_grad():
+                ref_outputs = self.ref_model(
+                    input_ids=inputs["input_ids"],
+                    attention_mask=inputs.get("attention_mask", None),
+                )
+                ref_logits = ref_outputs.logits
+
+            outputs = model(**inputs)
+            return self.compute_loss_func(outputs, inputs["labels"], ref_logits)
+        else:
+            return super().compute_loss(model, inputs, *args, **kwargs)

     @override
     def prediction_step(
@@ -175,7 +215,13 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer):
             if len(pad_len):  # move pad token to last
                 preds[i] = np.concatenate((preds[i][pad_len[0] :], preds[i][: pad_len[0]]), axis=-1)

-        decoded_inputs = self.processing_class.batch_decode(dataset["input_ids"], skip_special_tokens=False)
+        input_ids_column = dataset["input_ids"]
+        try:
+            input_ids_list = input_ids_column.to_pylist()
+        except AttributeError:
+            input_ids_list = list(input_ids_column)
+
+        decoded_inputs = self.processing_class.batch_decode(input_ids_list, skip_special_tokens=False)
         decoded_preds = self.processing_class.batch_decode(preds, skip_special_tokens=skip_special_tokens)
         decoded_labels = self.processing_class.batch_decode(labels, skip_special_tokens=skip_special_tokens)

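The `to_pylist` fallback in the hunk above handles both Arrow-backed dataset columns and plain Python sequences. A minimal standalone sketch of the same pattern (the `to_python_list` helper name is ours, for illustration only — no pyarrow needed, since a plain list simply lacks the method):

```python
def to_python_list(column):
    # Arrow-backed columns (e.g. a pyarrow ChunkedArray) expose to_pylist();
    # plain Python sequences do not, so fall back to list().
    try:
        return column.to_pylist()
    except AttributeError:
        return list(column)

print(to_python_list([[1, 2], [3]]))  # plain list passes through unchanged
```

Catching `AttributeError` rather than type-checking keeps the call site agnostic to whichever backend the `datasets` column happens to use.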

@@ -24,7 +24,7 @@ from ...extras.misc import calculate_tps
 from ...extras.packages import is_transformers_version_greater_than
 from ...extras.ploting import plot_loss
 from ...model import load_model, load_tokenizer
-from ..trainer_utils import create_modelcard_and_push
+from ..trainer_utils import create_modelcard_and_push, create_ref_model
 from .metric import ComputeAccuracy, ComputeSimilarity, eval_logit_processor
 from .trainer import CustomSeq2SeqTrainer
@@ -52,6 +52,10 @@ def run_sft(
     dataset_module = get_dataset(template, model_args, data_args, training_args, stage="sft", **tokenizer_module)
     model = load_model(tokenizer, model_args, finetuning_args, training_args.do_train)

+    ref_model = None
+    if finetuning_args.use_asft_loss:
+        ref_model = create_ref_model(model_args, finetuning_args)
+
     if getattr(model, "is_quantized", False) and not training_args.do_train:
         setattr(model, "_hf_peft_config_loaded", True)  # hack here: make model compatible with prediction
@@ -61,6 +65,7 @@ def run_sft(
         pad_to_multiple_of=8 if training_args.do_train else None,  # for shift short attention
         label_pad_token_id=IGNORE_INDEX if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id,
         block_diag_attn=model_args.block_diag_attn,
+        neat_packing=data_args.neat_packing,
         attn_implementation=getattr(model.config, "_attn_implementation", None),
         compute_dtype=model_args.compute_dtype,
         **tokenizer_module,
@@ -124,6 +129,7 @@ def run_sft(
         data_collator=data_collator,
         callbacks=callbacks,
         gen_kwargs=gen_kwargs,
+        ref_model=ref_model,
         **dataset_module,
         **tokenizer_module,
         **metric_module,


@@ -23,6 +23,7 @@ from collections.abc import Callable, Mapping
 from typing import TYPE_CHECKING, Any, Optional, Union

 import torch
+import torch.nn.functional as F
 from transformers import Trainer
 from transformers.integrations import is_deepspeed_zero3_enabled
 from transformers.modeling_utils import is_fsdp_enabled
@@ -51,6 +52,7 @@ if is_ray_available():
     import ray
     from ray.util.placement_group import PlacementGroup, placement_group
     from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy
+    from ray.util.state import list_nodes

 if TYPE_CHECKING:
@@ -681,6 +683,88 @@ def _dft_cross_entropy(
     return loss


+def asft_loss_func(
+    outputs,
+    labels: torch.Tensor,
+    ref_logits: torch.Tensor,
+    asft_alpha: float = 0.1,
+    ignore_index: int = -100,
+) -> torch.Tensor:
+    logits = outputs.get("logits")
+    if logits is None:
+        return outputs.get("loss", torch.tensor(0.0))
+
+    logits = logits.float()
+    # shift for causal LM
+    shift_logits = logits[..., :-1, :].contiguous()
+    shift_labels = labels[..., 1:].contiguous()
+    shift_ref_logits = ref_logits[..., :-1, :].contiguous()
+    vocab_size = shift_logits.size(-1)
+
+    # flatten
+    shift_logits = shift_logits.view(-1, vocab_size)
+    shift_ref_logits = shift_ref_logits.view(-1, vocab_size)
+    shift_labels = shift_labels.view(-1).to(shift_logits.device)
+    return _asft_cross_entropy(
+        policy_logits=shift_logits,
+        policy_labels=shift_labels,
+        ref_logits=shift_ref_logits,
+        asft_alpha=asft_alpha,
+        ignore_index=ignore_index,
+    )
+
+
+def _asft_cross_entropy(
+    policy_logits: torch.Tensor,
+    policy_labels: torch.Tensor,
+    ref_logits: torch.Tensor,
+    asft_alpha: float = 0.1,
+    ignore_index: int = -100,
+) -> torch.Tensor:
+    dft_loss = _dft_cross_entropy(
+        policy_logits,
+        policy_labels,
+        ignore_index=ignore_index,
+    )
+    kl_loss = _kl_divergence(
+        policy_logits,
+        ref_logits,
+        policy_labels,
+        ignore_index=ignore_index,
+    )
+    return dft_loss + asft_alpha * kl_loss
+
+
+def _kl_divergence(
+    policy_logits: torch.Tensor,
+    ref_logits: torch.Tensor,
+    labels: torch.Tensor,
+    ignore_index: int = -100,
+) -> torch.Tensor:
+    # log p(y|x)
+    log_p = F.log_softmax(policy_logits, dim=-1)
+    # q(y|x)
+    q = F.softmax(ref_logits, dim=-1)
+    # token-wise KL
+    kl = F.kl_div(
+        log_p,
+        q,
+        reduction="none",
+    ).sum(dim=-1)  # [N]
+    # mask padding tokens
+    mask = (labels != ignore_index).float()
+    return (kl * mask).sum() / mask.sum()
+
+
 def eaft_loss_func(
     outputs: "torch.Tensor",
     labels: "torch.Tensor",
@@ -858,7 +942,7 @@ def get_ray_remote_config_for_worker(
 def get_ray_head_node_ip() -> str:
     r"""Get the IP address of the Ray head node."""
-    head_ip = next(node["NodeManagerAddress"] for node in ray.nodes() if node.get("IsHead", False))
+    head_ip = next(node["node_ip"] for node in list_nodes() if node.get("is_head_node", False))
     return head_ip

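The new ASFT objective above combines the DFT cross-entropy with a KL penalty against a frozen reference model, masked at `ignore_index` positions. A minimal self-contained sketch of the per-token KL term (the `token_kl` helper name is ours; it mirrors `_kl_divergence`'s convention of `F.kl_div` over log-policy and reference probabilities):

```python
import torch
import torch.nn.functional as F


def token_kl(policy_logits, ref_logits, labels, ignore_index=-100):
    # KL(ref || policy) per token, averaged over non-ignored positions.
    log_p = F.log_softmax(policy_logits, dim=-1)  # policy log-probs
    q = F.softmax(ref_logits, dim=-1)             # reference probs
    kl = F.kl_div(log_p, q, reduction="none").sum(dim=-1)  # [N]
    mask = (labels != ignore_index).float()
    return (kl * mask).sum() / mask.sum()


torch.manual_seed(0)
logits = torch.randn(4, 10)
# identical policy and reference distributions give zero KL
kl_self = token_kl(logits, logits, torch.tensor([1, 2, -100, 3]))
print(float(kl_self))
```

Note that `F.kl_div` expects log-probabilities as its first argument and probabilities as its second, which is why the policy side is log-softmaxed and the reference side is softmaxed.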

@@ -24,7 +24,7 @@ from ..data import get_template_and_fix_tokenizer
 from ..extras import logging
 from ..extras.constants import V_HEAD_SAFE_WEIGHTS_NAME, V_HEAD_WEIGHTS_NAME
 from ..extras.misc import find_available_port, get_device_name, get_torch_device, infer_optim_dtype
-from ..extras.packages import is_mcore_adapter_available, is_ray_available
+from ..extras.packages import is_mcore_adapter_available, is_ray_available, is_transformers_version_greater_than
 from ..hparams import RayArguments, get_infer_args, get_ray_args, get_train_args, read_args
 from ..model import load_model, load_tokenizer
 from .callbacks import LogCallback, PissaConvertCallback, ReporterCallback
@@ -160,17 +160,28 @@ def export_model(args: Optional[dict[str, Any]] = None) -> None:
         model = model.to(output_dtype)
         logger.info_rank0(f"Convert model dtype to: {output_dtype}.")

-    model.save_pretrained(
-        save_directory=model_args.export_dir,
-        max_shard_size=f"{model_args.export_size}GB",
-        safe_serialization=(not model_args.export_legacy_format),
-    )
+    # Prepare save arguments (safe_serialization removed in transformers v5.0.0)
+    save_kwargs = {
+        "save_directory": model_args.export_dir,
+        "max_shard_size": f"{model_args.export_size}GB",
+    }
+    if not is_transformers_version_greater_than("5.0.0"):
+        save_kwargs["safe_serialization"] = not model_args.export_legacy_format
+
+    model.save_pretrained(**save_kwargs)

     if model_args.export_hub_model_id is not None:
+        # Prepare push arguments (safe_serialization removed in transformers v5.0.0)
+        push_kwargs = {
+            "max_shard_size": f"{model_args.export_size}GB",
+        }
+        if not is_transformers_version_greater_than("5.0.0"):
+            push_kwargs["safe_serialization"] = not model_args.export_legacy_format
+
         model.push_to_hub(
             model_args.export_hub_model_id,
             token=model_args.hf_hub_token,
-            max_shard_size=f"{model_args.export_size}GB",
-            safe_serialization=(not model_args.export_legacy_format),
+            **push_kwargs,
         )

     if finetuning_args.stage == "rm":


@@ -21,6 +21,7 @@ from omegaconf import OmegaConf
 from transformers import HfArgumentParser

 from ..utils.env import is_env_enabled
+from ..utils.helper import set_seed
 from .data_args import DataArguments
 from .model_args import ModelArguments
 from .sample_args import SampleArguments
@@ -56,6 +57,14 @@ def get_args(args: InputArgument = None) -> tuple[ModelArguments, DataArguments,
         print(f"Got unknown args, potentially deprecated arguments: {unknown_args}")
         raise ValueError(f"Some specified arguments are not used by the HfArgumentParser: {unknown_args}")

+    # Seed as early as possible after argument parsing so all downstream
+    # components (dist init, dataloader, model init in run_* entrypoints) share the same RNG state.
+    for arg in parsed_args:
+        seed = getattr(arg, "seed", None)
+        if seed is not None:
+            set_seed(seed)
+            break
+
     return tuple(parsed_args)


@@ -66,7 +66,7 @@ class TrainingArguments:
         metadata={"help": "Number of workers for batching."},
     )
     enable_activation_checkpointing: bool = field(
-        default=True,
+        default=False,
         metadata={"help": "Enable activation checkpointing for training."},
     )
     dist_config: PluginConfig | None = field(
@@ -81,6 +81,10 @@ class TrainingArguments:
         default=None,
         metadata={"help": "Learning rate scheduler configuration for training."},
     )
+    seed: int = field(
+        default=42,
+        metadata={"help": "Random seed that will be set at the beginning of training."},
+    )

     def __post_init__(self) -> None:
         self.dist_config = get_plugin_config(self.dist_config)


@@ -76,19 +76,28 @@ class BaseTrainer:
         if self.args.enable_activation_checkpointing:
             self.model.gradient_checkpointing_enable({"use_reentrant": False})

-        if self.args.dist_config is not None:
-            shard_need_optimizer = self.args.dist_config.name == "deepspeed"
-        else:
-            shard_need_optimizer = False
-
-        if shard_need_optimizer:
-            self._init_optimizer()
-            self._shard_model()
-        else:
-            self._shard_model()
-            self._init_optimizer()
-
-        self._init_lr_scheduler()
+        self._deepspeed_engine = None
+        dist_name = self.args.dist_config.name if self.args.dist_config is not None else None
+
+        if dist_name == "deepspeed":
+            from ..plugins.trainer_plugins.distributed.hub import DistributedPlugin
+
+            self._deepspeed_engine = DistributedPlugin("deepspeed")(
+                self.model,
+                self.args.dist_config,
+                num_micro_batch=self.train_batch_generator.num_micro_batch,
+                micro_batch_size=self.args.micro_batch_size,
+            )
+            self._init_optimizer()
+            self._init_lr_scheduler()
+            self.model, self.optimizer, self.lr_scheduler = self._deepspeed_engine.prepare(
+                self.model, self.optimizer, self.lr_scheduler
+            )
+        else:
+            # fsdp2 / DDP / no dist
+            self._shard_model()
+            self._init_optimizer()
+            self._init_lr_scheduler()

     def _create_batch_generator(self) -> None:
         self.train_batch_generator = BatchGenerator(
@@ -99,6 +108,7 @@ class BaseTrainer:
             cutoff_len=self.args.cutoff_len,
             batching_workers=self.args.batching_workers,
             batching_strategy=self.args.batching_strategy,
+            seed=self.args.seed,
         )

     def _shard_model(self) -> None:
@@ -171,25 +181,35 @@ class BaseTrainer:
             step_loss = 0
             step_valid_tokens = compute_valid_tokens(micro_batches)
             step_valid_tokens = DistributedInterface().all_reduce(step_valid_tokens, op=ReduceOp.SUM)
-            for micro_batch in micro_batches:
+            num_micro = len(micro_batches)
+            for i, micro_batch in enumerate(micro_batches):
                 loss = self.compute_loss(micro_batch)
                 mini_step_valid_tokens = compute_valid_tokens([micro_batch])
                 # fsdp uses mean reduction so we need to scale the loss by dp_size
                 loss = loss * mini_step_valid_tokens * self.dp_size / (step_valid_tokens + 1e-6)
-                loss.backward()
+                if self._deepspeed_engine is not None:
+                    # deepspeed: set sync_gradients so engine.step() only fires on last micro-batch
+                    self._deepspeed_engine.accelerator.sync_gradients = i == num_micro - 1
+                    self._deepspeed_engine.backward(loss)
+                else:
+                    loss.backward()
                 step_loss += loss.item()

-            grad_norm = torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.args.max_grad_norm).item()
-            # isfinite(): argument 'input' (position 1) must be Tensor, not float
-            if not torch.isfinite(torch.tensor(grad_norm)):  # type: ignore  # pyright: ignore [reportUnknownReturnType]
-                logger.warning_rank0(f"Gradient norm is not finite: {grad_norm}")
-            else:
-                self.optimizer.step()
-                self.lr_scheduler.step()
-            self.optimizer.zero_grad()
+            if self._deepspeed_engine is not None:
+                # deepspeed: engine.step() already ran inside backward at the sync boundary
+                grad_norm = self._deepspeed_engine.get_grad_norm()
+            else:
+                grad_norm = torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.args.max_grad_norm).item()
+                # isfinite(): argument 'input' (position 1) must be Tensor, not float
+                if not torch.isfinite(torch.tensor(grad_norm)):  # type: ignore  # pyright: ignore [reportUnknownReturnType]
+                    logger.warning_rank0(f"Gradient norm is not finite: {grad_norm}")
+                else:
+                    self.optimizer.step()
+                    self.lr_scheduler.step()
+                self.optimizer.zero_grad()

             step_loss, grad_norm = DistributedInterface().all_reduce([step_loss, grad_norm])
             DistributedInterface().sync()
@@ -203,7 +223,14 @@ class BaseTrainer:
     def save_model(self) -> None:
         """Save the model."""
-        model_to_save = self.model.module if hasattr(self.model, "module") else self.model
-        model_to_save.save_pretrained(self.args.output_dir)
-        self.renderer.processor.save_pretrained(self.args.output_dir)
-        logger.info_rank0(f"Model saved to {self.args.output_dir}")
+        if self.args.dist_config is not None and self.args.dist_config.name in ("deepspeed", "fsdp2"):
+            from ..plugins.trainer_plugins.distributed.hub import DistributedPlugin
+
+            DistributedPlugin(self.args.dist_config.name).save_model(
+                self.model, self.args.output_dir, self.renderer.processor
+            )
+        else:
+            model_to_save = self.model.module if hasattr(self.model, "module") else self.model
+            model_to_save.save_pretrained(self.args.output_dir, max_shard_size="4GB")
+            self.renderer.processor.save_pretrained(self.args.output_dir, max_shard_size="4GB")
+
+        logger.info_rank0(f"Model saved to {self.args.output_dir}")


@@ -90,6 +90,26 @@ class ModelEngine:
         Transformers can choose the proper model init context.
         https://github.com/huggingface/transformers/blob/v5.0.0rc0/src/transformers/modeling_utils.py#L3538
         """
+        if self.args.init_config is not None:
+            from ..plugins.model_plugins.initialization import InitPlugin
+
+            init_device = InitPlugin(self.args.init_config.name)()
+        else:
+            init_device = DistributedInterface().current_device
+
+        init_kwargs = {"device_map": init_device}
+        if self.args.quant_config is not None:
+            from ..plugins.model_plugins.quantization import QuantizationPlugin
+
+            init_kwargs = QuantizationPlugin(self.args.quant_config.name)(
+                init_kwargs=init_kwargs,
+                config=self.model_config,
+                tokenizer=self.processor,
+                model_args=self.args,
+                is_trainable=self.is_train,
+            )
+
         if self.args.model_class == ModelClass.LLM:
             from transformers import AutoModelForCausalLM, AutoModelForImageTextToText
@@ -107,14 +127,8 @@ class ModelEngine:
             AutoClass = AutoModel

-        if self.args.init_config is not None:
-            from ..plugins.model_plugins.initialization import InitPlugin
-
-            init_device = InitPlugin(self.args.init_config.name)()
-        else:
-            init_device = DistributedInterface().current_device
-
         if init_device.type == DeviceType.META:
+            assert self.args.quant_config is None, "Quantization is not supported with meta device."
             with init_empty_weights():
                 model = AutoClass.from_config(self.model_config)
         else:
@@ -122,8 +136,8 @@ class ModelEngine:
                 self.args.model,
                 config=self.model_config,
                 dtype="auto",
-                device_map=init_device,
                 trust_remote_code=self.args.trust_remote_code,
+                **init_kwargs,
             )

         if self.args.peft_config is None:


@@ -26,6 +26,7 @@
 from collections.abc import Iterator
 from typing import Any

+import torch
 from torch.utils.data import default_collate
 from torchdata.stateful_dataloader import StatefulDataLoader
 from torchdata.stateful_dataloader.sampler import StatefulDistributedSampler
@@ -71,6 +72,7 @@ class BatchGenerator(Iterator):
         batching_strategy: BatchingStrategy = BatchingStrategy.NORMAL,
         pin_memory: bool = True,
         drop_last: bool = True,
+        seed: int = 42,
     ) -> None:
         self.dataset = dataset
         self.renderer = renderer
@@ -82,6 +84,7 @@ class BatchGenerator(Iterator):
         self.batching_strategy = batching_strategy
         self.pin_memory = pin_memory
         self.drop_last = drop_last
+        self.seed = seed

         # TODO: support length and infinity
         dp_size = DistributedInterface().get_world_size(Dim.DP)
@@ -128,12 +131,15 @@ class BatchGenerator(Iterator):
                 num_replicas=DistributedInterface().get_world_size(Dim.DP),
                 rank=DistributedInterface().get_rank(Dim.DP),
                 shuffle=True,
-                seed=0,
+                seed=self.seed,
                 drop_last=self.drop_last,
             )
         else:
             raise NotImplementedError("Iterable dataset is not supported yet.")

+        generato_seed = torch.Generator()
+        generato_seed.manual_seed(self.seed)
+
         self._data_provider = StatefulDataLoader(
             self.dataset,
             batch_size=self.micro_batch_size * self.num_micro_batch,
@@ -143,6 +149,7 @@ class BatchGenerator(Iterator):
             pin_memory=self.pin_memory,
             pin_memory_device=DistributedInterface().current_device.type,
             drop_last=self.drop_last,
+            generator=generato_seed,
         )

         if self.batching_strategy == BatchingStrategy.NORMAL:
             self._length = len(self._data_provider)

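Wiring `seed` into both the sampler and the dataloader's `torch.Generator`, as the hunk above does, is what makes shuffle order reproducible across runs. A minimal sketch of the generator half with a plain `DataLoader` (the `make_loader` helper is ours, for illustration):

```python
import torch
from torch.utils.data import DataLoader


def make_loader(seed: int) -> DataLoader:
    # A seeded generator pins the shuffle permutation for this loader.
    gen = torch.Generator()
    gen.manual_seed(seed)
    return DataLoader(list(range(8)), batch_size=2, shuffle=True, generator=gen)


order_a = [int(x) for batch in make_loader(42) for x in batch]
order_b = [int(x) for batch in make_loader(42) for x in batch]
assert order_a == order_b  # same seed, same shuffle order
```

Without an explicit generator, `shuffle=True` draws from the global RNG, so the order would drift with any other random op that runs first.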

@@ -91,7 +91,11 @@ class Renderer:
         self.processor = processor

     def render_messages(
-        self, messages: list[Message], tools: str | None = None, is_generate: bool = False
+        self,
+        messages: list[Message],
+        tools: str | None = None,
+        is_generate: bool = False,
+        enable_thinking: bool = False,
     ) -> ModelInput:
         """Apply template to messages and convert them to model input.
@@ -99,6 +103,7 @@ class Renderer:
             messages (list[Message]): The messages to render.
             tools (str | None, optional): The tools to use. Defaults to None.
             is_generate (bool, optional): Whether to render for generation. Defaults to False.
+            enable_thinking (bool, optional): Whether to enable thinking mode for generation. Defaults to False.

         Returns:
             ModelInput: The rendered model input.
@@ -108,7 +113,9 @@ class Renderer:
         else:
             from ...plugins.model_plugins.rendering import RenderingPlugin

-            return RenderingPlugin(self.template).render_messages(self.processor, messages, tools, is_generate)
+            return RenderingPlugin(self.template).render_messages(
+                self.processor, messages, tools, is_generate, enable_thinking
+            )

     def parse_message(self, generated_text: str) -> Message:
         """Parse a message in the template format.


@@ -125,6 +125,11 @@ def launch():
         run_chat()

+    elif command == "merge":
+        from llamafactory.v1.plugins.model_plugins.peft import merge_and_export_model
+
+        merge_and_export_model()
+
     elif command == "env":
         raise NotImplementedError("Environment information is not implemented yet.")


@@ -12,14 +12,22 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from typing import Literal, TypedDict
+import re
+from typing import Literal, TypedDict, Union

-from peft import LoraConfig, PeftModel, get_peft_model
+import torch
+from peft import LoraConfig, PeftModel, TaskType, get_peft_model

+from ...config import InputArgument, get_args
+from ...core.model_engine import ModelEngine
+from ...utils import logging
 from ...utils.plugin import BasePlugin
 from ...utils.types import HFModel

+logger = logging.get_logger(__name__)
+

 class LoraConfigDict(TypedDict, total=False):
     name: Literal["lora"]
     """Plugin name."""
@@ -27,8 +35,28 @@ class LoraConfigDict(TypedDict, total=False):
     """Lora rank."""
     lora_alpha: int
     """Lora alpha."""
-    target_modules: list[str]
+    lora_dropout: float
+    """Lora dropout."""
+    target_modules: Union[list[str], str]
     """Target modules."""
+    use_rslora: bool
+    """Use RS-LoRA."""
+    use_dora: bool
+    """Use DoRA."""
+    modules_to_save: list[str]
+    """Modules to save."""
+    adapter_name_or_path: Union[list[str], str]
+    """Path to the adapter(s)."""
+    export_dir: str
+    """Path to the export directory."""
+    export_size: int
+    """Shard size for the export model."""
+    export_hub_model_id: str
+    """Hub model ID for the export model."""
+    infer_dtype: Literal["auto", "float16", "float32", "bfloat16"]
+    """Inference data type for the export model."""
+    export_legacy_format: bool
+    """Use legacy format for the export model."""


 class FreezeConfigDict(TypedDict, total=False):
@@ -36,22 +64,283 @@ class FreezeConfigDict(TypedDict, total=False):
     """Plugin name."""
     freeze_trainable_layers: int
     """Freeze trainable layers."""
-    freeze_trainable_modules: list[str] | None
+    freeze_trainable_modules: Union[list[str], str]
     """Freeze trainable modules."""
+    freeze_extra_modules: list[str]
+    """Freeze extra modules."""
+    cast_trainable_params_to_fp32: bool
+    """Cast trainable params to fp32."""


 class PeftPlugin(BasePlugin):
     def __call__(self, model: HFModel, config: dict, is_train: bool) -> HFModel:
-        return super().__call__(model, config)
+        return super().__call__(model, config, is_train)


+def _find_all_linear_modules(model: HFModel) -> list[str]:
+    r"""Find all available modules to apply LoRA."""
+    forbidden_modules = {"lm_head", "output_layer", "output"}
+    module_names = set()
+    for name, module in model.named_modules():
+        if any(forbidden_module in name for forbidden_module in forbidden_modules):
+            continue
+
+        if "Linear" in module.__class__.__name__ and "Embedding" not in module.__class__.__name__:
+            module_names.add(name.split(".")[-1])
+
+    return list(module_names)
+
+
+def merge_adapters(model: HFModel, adapter_name_or_path: Union[list[str], str]) -> HFModel:
+    if not isinstance(adapter_name_or_path, list):
+        adapter_name_or_path = [adapter_name_or_path]
+
+    for adapter_path in adapter_name_or_path:
+        model = PeftModel.from_pretrained(model, adapter_path)
+        model = model.merge_and_unload()
+        logger.info_rank0(f"Merged adapter from {adapter_path}")
+
+    return model
+
+
+def load_adapter(model: HFModel, adapter_name_or_path: Union[list[str], str], is_train: bool) -> HFModel:
+    r"""Loads adapter(s) into the model.
+
+    Determine adapter usage based on mode:
+    - Training: Load the single adapter for continued training.
+    - Inference: Merge all adapters to clean up the model.
+    - Unmergeable: Keep the single adapter active without merging.
+    """
+    if not isinstance(adapter_name_or_path, list):
+        adapter_name_or_path = [adapter_name_or_path]
+
+    # TODO
+    # Adapters fix for deepspeed and quant
+    # Adapters fix for vision
+    if is_train and len(adapter_name_or_path) > 1:
+        raise ValueError(
+            "When `adapter_name_or_path` is provided for training, only a single LoRA adapter is supported. "
+            "Training will continue on the specified adapter. "
+            "Please merge multiple adapters before starting a new LoRA adapter."
+        )
+
+    if is_train:
+        adapter_to_merge = []
+        adapter_to_resume = adapter_name_or_path[0]
+    else:
+        adapter_to_merge = adapter_name_or_path
+        adapter_to_resume = None
+
+    if adapter_to_merge:
+        model = merge_adapters(model, adapter_to_merge)
+
+    if adapter_to_resume is not None:
+        model = PeftModel.from_pretrained(model, adapter_to_resume, is_trainable=is_train)
+        if is_train:
+            logger.info_rank0(
+                f"Resuming training from existing LoRA adapter at {adapter_to_resume}. "
+                "LoRA hyperparameters will be loaded from the adapter itself; "
+                "the current LoRA configuration will be ignored. "
+                "Merge the adapter into the base model before training if you want to start a new adapter."
+            )
+
+    return model
+
+
 @PeftPlugin("lora").register()
-def get_lora_model(model: HFModel, config: LoraConfigDict, is_train: bool) -> PeftModel:
-    peft_config = LoraConfig(**config)
+def get_lora_model(model: HFModel, config: LoraConfigDict, is_train: bool = False) -> HFModel:
+    if model.device.type == "meta":
+        raise ValueError("Currently lora stage does not support loading model by meta.")
+
+    adapter_name_or_path = config.get("adapter_name_or_path")
+    if adapter_name_or_path:
+        return load_adapter(model, adapter_name_or_path, is_train)
+
+    logger.info_rank0("Fine-tuning method: LoRA")
+    target_modules = config.get("target_modules", "all")
+
+    # Handle target modules
+    if target_modules == "all":
+        target_modules = _find_all_linear_modules(model)
+    elif isinstance(target_modules, str):
+        target_modules = [target_modules]
+
+    logger.info_rank0(f"LoRA target modules: {target_modules}")
+    peft_config = LoraConfig(
+        task_type=TaskType.CAUSAL_LM,
+        inference_mode=not is_train,
+        r=config.get("r", 8),
+        lora_alpha=config.get("lora_alpha", 16),
+        lora_dropout=config.get("lora_dropout", 0.05),
+        use_rslora=config.get("use_rslora", False),
+        use_dora=config.get("use_dora", False),
+        target_modules=target_modules,
+        modules_to_save=config.get("modules_to_save", None),
)
model = get_peft_model(model, peft_config)
if is_train:
model.print_trainable_parameters()
return model
@PeftPlugin("freeze").register()
def get_freeze_model(model: HFModel, config: FreezeConfigDict, is_train: bool = False) -> HFModel:
logger.info_rank0("Fine-tuning method: Freeze")
if not is_train:
return model
freeze_trainable_layers = config.get("freeze_trainable_layers", 2)
freeze_trainable_modules = config.get("freeze_trainable_modules", ["all"])
freeze_extra_modules = config.get("freeze_extra_modules", [])
cast_trainable_params_to_fp32 = config.get("cast_trainable_params_to_fp32", True)
if isinstance(freeze_trainable_modules, str):
freeze_trainable_modules = [module.strip() for module in freeze_trainable_modules.split(",")]
if isinstance(freeze_extra_modules, str):
freeze_extra_modules = [module.strip() for module in freeze_extra_modules.split(",")]
# Get number of layers
num_layers = (
getattr(model.config, "num_hidden_layers", None)
or getattr(model.config, "num_layers", None)
or getattr(model.config, "n_layer", None)
)
if not num_layers:
raise ValueError("Current model does not support freeze tuning.")
if freeze_trainable_layers > 0:
# last n layers
trainable_layer_ids = range(max(0, num_layers - freeze_trainable_layers), num_layers)
else:
# first n layers
trainable_layer_ids = range(min(-freeze_trainable_layers, num_layers))
# Identify hidden and non-hidden modules
hidden_modules = set()
non_hidden_modules = set()
for name, _ in model.named_parameters():
if ".0." in name:
hidden_modules.add(name.split(".0.")[-1].split(".")[0])
elif ".1." in name:
hidden_modules.add(name.split(".1.")[-1].split(".")[0])
if re.search(r"\.\d+\.", name) is None:
non_hidden_modules.add(name.split(".")[-2])
# Build list of trainable layer patterns
trainable_layers = []
for module_name in freeze_trainable_modules:
if module_name == "all":
for idx in trainable_layer_ids:
trainable_layers.append(f".{idx:d}.")
elif module_name in hidden_modules:
for idx in trainable_layer_ids:
trainable_layers.append(f".{idx:d}.{module_name}")
else:
raise ValueError(f"Module {module_name} not found in hidden modules: {hidden_modules}")
# Add extra modules
if freeze_extra_modules:
for module_name in freeze_extra_modules:
if module_name in non_hidden_modules:
trainable_layers.append(module_name)
else:
raise ValueError(f"Module {module_name} not found in non-hidden modules: {non_hidden_modules}")
# TODO
# Multi-modal special handling
# Set requires_grad
forbidden_modules = {"quant_state", "quantization_weight", "qweight", "qzeros", "scales"}
for name, param in model.named_parameters():
if any(trainable_layer in name for trainable_layer in trainable_layers) and not any(
forbidden_module in name for forbidden_module in forbidden_modules
):
param.requires_grad_(True)
if cast_trainable_params_to_fp32:
param.data = param.data.to(torch.float32) # Cast to fp32 for stability
else:
param.requires_grad_(False)
logger.info_rank0(f"Set trainable layers: {trainable_layers}")
# Count trainable params for verification
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
all_params = sum(p.numel() for p in model.parameters())
logger.info_rank0(
f"trainable params: {trainable_params} || all params: {all_params} || trainable%: {100 * trainable_params / all_params:.4f}"
)
return model
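The two non-obvious pieces of the freeze logic above are the layer-id selection (positive values train the last n layers, negative values the first |n|) and the split of parameter names into per-layer hidden modules versus top-level modules. Both can be checked with a small sketch using illustrative parameter names:

```python
import re

def trainable_layer_ids(num_layers, freeze_trainable_layers):
    """Positive n -> train the last n layers; negative n -> train the first |n|."""
    if freeze_trainable_layers > 0:
        return list(range(max(0, num_layers - freeze_trainable_layers), num_layers))
    return list(range(min(-freeze_trainable_layers, num_layers)))

def split_modules(param_names):
    """Split names into per-layer (hidden) and top-level (non-hidden) modules,
    using layer index 0/1 as a probe for the per-layer naming pattern."""
    hidden, non_hidden = set(), set()
    for name in param_names:
        if ".0." in name:
            hidden.add(name.split(".0.")[-1].split(".")[0])
        elif ".1." in name:
            hidden.add(name.split(".1.")[-1].split(".")[0])
        if re.search(r"\.\d+\.", name) is None:  # no layer index -> top-level
            non_hidden.add(name.split(".")[-2])
    return hidden, non_hidden

print(trainable_layer_ids(32, 2))   # [30, 31] -> last two layers trainable
print(trainable_layer_ids(32, -2))  # [0, 1]   -> first two layers trainable

names = [
    "model.layers.0.self_attn.q_proj.weight",
    "model.layers.0.mlp.gate_proj.weight",
    "model.embed_tokens.weight",
    "lm_head.weight",
]
hidden, non_hidden = split_modules(names)
print(sorted(hidden))      # ['mlp', 'self_attn']
print(sorted(non_hidden))  # ['embed_tokens', 'lm_head']
```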
def merge_and_export_model(args: InputArgument = None):
model_args, _, _, _ = get_args(args)
export_config = model_args.peft_config
if export_config is None:
raise ValueError("Please specify peft_config to merge and export model.")
export_dir = export_config.get("export_dir")
if export_dir is None:
raise ValueError("Please specify export_dir.")
export_size = export_config.get("export_size", 5)
export_hub_model_id = export_config.get("export_hub_model_id")
infer_dtype = export_config.get("infer_dtype", "auto")
export_legacy_format = export_config.get("export_legacy_format", False)
adapters = None
if export_config.get("name") == "lora":
adapters = export_config.get("adapter_name_or_path")
else:
raise ValueError("Currently merge and export model function is only supported for lora.")
if adapters is None:
raise ValueError("Please set adapter_name_or_path to merge adapters into base model.")
logger.info_rank0("Loading model for export...")
model_engine = ModelEngine(model_args, is_train=False)
model = model_engine.model
tokenizer = model_engine.processor
if infer_dtype == "auto":
if model.config.torch_dtype == torch.float32 and torch.cuda.is_bf16_supported():
model = model.to(torch.bfloat16)
logger.info_rank0("Converted model to bfloat16.")
else:
target_dtype = getattr(torch, infer_dtype)
model = model.to(target_dtype)
logger.info_rank0(f"Converted model to {infer_dtype}.")
logger.info_rank0(f"Exporting model to {export_dir}...")
model.save_pretrained(
export_dir,
max_shard_size=f"{export_size}GB",
safe_serialization=not export_legacy_format,
)
if tokenizer is not None:
try:
if hasattr(tokenizer, "padding_side"):
tokenizer.padding_side = "left"
tokenizer.save_pretrained(export_dir)
except Exception as e:
logger.warning(f"Failed to save tokenizer: {e}")
if export_hub_model_id:
logger.info_rank0(f"Pushing to hub: {export_hub_model_id}...")
model.push_to_hub(export_hub_model_id)
if tokenizer is not None:
tokenizer.push_to_hub(export_hub_model_id)
logger.info_rank0("Model exported successfully.")
@@ -0,0 +1,122 @@
# Copyright 2025 HuggingFace Inc., the KVCache.AI team, Approaching AI, and the LlamaFactory team.
#
# This code is inspired by the HuggingFace's transformers library.
# https://github.com/huggingface/transformers/blob/v4.40.0/examples/pytorch/language-modeling/run_clm.py
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING, Any
import torch
from transformers import BitsAndBytesConfig
from ...accelerator.helper import get_current_device
from ...config.model_args import ModelArguments
from ...utils import logging
from ...utils.packages import check_version
from ...utils.plugin import BasePlugin
if TYPE_CHECKING:
from transformers import PretrainedConfig, PreTrainedTokenizer
logger = logging.get_logger(__name__)
class QuantizationPlugin(BasePlugin):
r"""Plugin for model quantization."""
def __call__(
self,
init_kwargs: dict[str, Any] = None,
config: "PretrainedConfig" = None,
tokenizer: "PreTrainedTokenizer" = None,
model_args: "ModelArguments" = None,
is_trainable: bool = False,
) -> dict[str, Any]:
return super().__call__(
init_kwargs, config=config, tokenizer=tokenizer, model_args=model_args, is_trainable=is_trainable
)
@QuantizationPlugin("auto").register()
def quantization_auto(
init_kwargs: dict[str, Any],
**kwargs,
) -> dict[str, Any]:
"""Automatic quantization selection, only support bnb currently.
Args:
init_kwargs (dict[str, Any]): The kwargs for model initialization.
**kwargs: Keyword arguments containing the model.
Returns:
dict[str, Any]: The updated kwargs for model initialization.
"""
model_args: ModelArguments = kwargs.get("model_args", None)
quant_config = model_args.quant_config
quantization_bit = quant_config.get("quantization_bit", None)
if quantization_bit is not None:
logger.info_rank0(f"Loading {quantization_bit}-bit quantized model.")
if quantization_bit in [8, 4]:
return quantization_with_bnb(init_kwargs, **kwargs)
else:
raise ValueError(f"Unsupported quantization bit: {quantization_bit} for auto quantization.")
logger.warning_rank0("No quantization method applied.")
return init_kwargs
@QuantizationPlugin("bnb").register()
def quantization_with_bnb(
init_kwargs: dict[str, Any],
model_args: "ModelArguments" = None,
**kwargs,
) -> dict[str, Any]:
r"""Quantization with BNB."""
logger.info_rank0("Using Bitsandbytes quantization.")
quantization_bit = model_args.quant_config.get("quantization_bit", None)
if quantization_bit is None:
logger.warning_rank0("quantization_bit is not specified, defaulting to 4-bit quantization.")
quantization_bit = 4
assert quantization_bit in [8, 4], "Bitsandbytes only accepts 4-bit or 8-bit quantization."
if quantization_bit == 8:
check_version("bitsandbytes>=0.37.0", mandatory=True)
init_kwargs["quantization_config"] = BitsAndBytesConfig(load_in_8bit=True)
elif quantization_bit == 4:
check_version("bitsandbytes>=0.39.0", mandatory=True)
init_kwargs["quantization_config"] = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_compute_dtype=model_args.quant_config.get("compute_dtype", torch.float16),
bnb_4bit_use_double_quant=model_args.quant_config.get("double_quantization", True),
bnb_4bit_quant_type=model_args.quant_config.get("quantization_type", "nf4"),
bnb_4bit_quant_storage=model_args.quant_config.get(
"compute_dtype", torch.float16
), # crucial for fsdp+qlora
)
else:
raise ValueError("Bitsandbytes only accepts 4-bit or 8-bit quantization.")
# TODO: improve deepspeed zero3 and fsdp detection.
if kwargs.get("is_trainable", False):
logger.info_rank0("Detected inference mode, setting device_map for bitsandbytes quantization.")
init_kwargs["device_map"] = {"": get_current_device()} # change auto device map for inference
else:
logger.info_rank0("Detected training mode, skip setting device_map for bitsandbytes quantization.")
if model_args.quant_config.get("quantization_bit") != 4:
raise ValueError("Only 4-bit quantized model can use fsdp+qlora or auto device map.")
check_version("bitsandbytes>=0.43.0", mandatory=True)
logger.info_rank0(f"Quantizing model to {model_args.quant_config.get('quantization_bit')} bit with bitsandbytes.")
return init_kwargs
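The bit-width dispatch above can be summarized with a dependency-free sketch; this is a hypothetical plain-dict stand-in, whereas the real code builds a `transformers.BitsAndBytesConfig`:

```python
def bnb_quant_kwargs(quantization_bit=None, compute_dtype="float16"):
    """Map a requested bit width to bitsandbytes loading kwargs."""
    if quantization_bit is None:
        quantization_bit = 4  # same default as quantization_with_bnb
    if quantization_bit == 8:
        return {"load_in_8bit": True}
    if quantization_bit == 4:
        return {
            "load_in_4bit": True,
            "bnb_4bit_compute_dtype": compute_dtype,
            "bnb_4bit_use_double_quant": True,  # nested quantization
            "bnb_4bit_quant_type": "nf4",
        }
    raise ValueError("Bitsandbytes only accepts 4-bit or 8-bit quantization.")

print(bnb_quant_kwargs(8))  # {'load_in_8bit': True}
print(bnb_quant_kwargs()["bnb_4bit_quant_type"])  # nf4
```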
@@ -12,224 +12,45 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import importlib
import re
from ...utils import logging
from ...utils.helper import get_tokenizer
from ...utils.plugin import BasePlugin
from ...utils.types import Message, ModelInput, Processor
logger = logging.get_logger(__name__)
class RenderingPlugin(BasePlugin):
_attempted_template_imports: set[str] = set()
def _ensure_template_imported(self) -> None:
if self.name is None or self.name in self._attempted_template_imports:
return
full_module_name = f"{__package__}.templates.{self.name}"
self._attempted_template_imports.add(self.name)
try:
importlib.import_module(full_module_name)
except Exception as exc:
logger.warning(f"[Template Registry] Failed to import {full_module_name}: {exc}")
def __getitem__(self, method_name: str):
self._ensure_template_imported()
return super().__getitem__(method_name)
def render_messages(
self,
processor: Processor,
messages: list[Message],
tools: str | None = None,
is_generate: bool = False,
enable_thinking: bool = False,
) -> ModelInput:
"""Render messages in the template format."""
return self["render_messages"](processor, messages, tools, is_generate, enable_thinking)
def parse_messages(self, generated_text: str) -> Message:
"""Parse messages in the template format."""
return self["parse_messages"](generated_text)
def _update_model_input(
processor: Processor,
input_ids: list[int],
labels: list[int],
loss_weights: list[float],
temp_str: str,
temp_weight: float,
) -> str:
"""Update model input with temporary string."""
if not temp_str:
return ""
tokenizer = get_tokenizer(processor)
temp_ids = tokenizer.encode(temp_str, add_special_tokens=False)
input_ids.extend(temp_ids)
loss_weights.extend([temp_weight] * len(temp_ids))
if temp_weight > 1e-6:
labels.extend(temp_ids)
else:
labels.extend([IGNORE_INDEX] * len(temp_ids))
return ""
@RenderingPlugin("qwen3_nothink").register("render_messages")
def render_qwen3_nothink_messages(
processor: Processor,
messages: list[Message],
tools: str | None = None,
is_generate: bool = False,
) -> ModelInput:
"""Render messages in the Qwen3 nothink template format.
See https://huggingface.co/spaces/huggingfacejs/chat-template-playground?modelId=Qwen/Qwen3-4B-Instruct-2507
"""
input_ids, labels, loss_weights = [], [], []
temp_str, temp_weight = "", 0.0
if tools:
temp_str += "<|im_start|>system\n"
if messages[0]["role"] == "system":
for content in messages[0]["content"]:
if content["type"] == "text":
temp_str += content["value"]
else:
raise ValueError(f"Unsupported content type: {content['type']}")
temp_str += "\n\n"
temp_weight = messages[0].get("loss_weight", 0.0)
temp_str += (
"# Tools\n\nYou may call one or more functions to assist with the user query.\n\n"
"You are provided with function signatures within <tools></tools> XML tags:\n<tools>"
)
try:
tools = json.loads(tools)
except json.JSONDecodeError:
raise ValueError(f"Invalid tools format: {str(tools)}.")
if not isinstance(tools, list):
tools = [tools]
for tool in tools:
temp_str += "\n" + json.dumps(tool, ensure_ascii=False)
temp_str += (
"\n</tools>\n\nFor each function call, return a json object with function name "
'and arguments within <tool_call></tool_call> XML tags:\n<tool_call>\n{"name": '
'<function-name>, "arguments": <args-json-object>}\n</tool_call><|im_end|>\n'
)
elif messages[0]["role"] == "system":
temp_str += "<|im_start|>system\n"
for content in messages[0]["content"]:
if content["type"] == "text":
temp_str += content["value"]
else:
raise ValueError(f"Unsupported content type: {content['type']}")
temp_str += "<|im_end|>\n"
temp_weight = messages[0].get("loss_weight", 0.0)
temp_str = _update_model_input(processor, input_ids, labels, loss_weights, temp_str, temp_weight)
for turn_idx, message in enumerate(messages):
if message["role"] == "user" or (message["role"] == "system" and turn_idx != 0):
temp_str += "<|im_start|>" + message["role"] + "\n"
for content in message["content"]:
if content["type"] == "text":
temp_str += content["value"]
else:
raise ValueError(f"Unsupported content type: {content['type']}")
temp_str += "<|im_end|>\n"
temp_weight = message.get("loss_weight", 0.0)
elif message["role"] == "assistant":
temp_str += "<|im_start|>" + message["role"] + "\n"
for val_idx, content in enumerate(message["content"]):
if content["type"] == "text":
temp_str += content["value"]
elif content["type"] == "reasoning":
temp_str += "<thinking>\n" + content["value"] + "\n</thinking>\n\n" # avoid using special tokens
elif content["type"] == "tool_call":
if val_idx != 0 and message["content"][val_idx - 1]["type"] in ["text", "tool_call"]:
temp_str += "\n"
try:
tool_call: ToolCall = json.loads(content["value"])
except json.JSONDecodeError:
raise ValueError(f"Invalid tool call format: {content['value']}.")
temp_str += (
'<tool_call>\n{"name": "'
+ tool_call["name"]
+ '", "arguments": '
+ json.dumps(tool_call["arguments"], ensure_ascii=False)
+ "}\n</tool_call>"
)
else:
raise ValueError(f"Unsupported content type: {content['type']}")
temp_str += "<|im_end|>\n"
temp_weight = message.get("loss_weight", 1.0)
elif message["role"] == "tool":
if turn_idx == 0 or messages[turn_idx - 1]["role"] != "tool":
temp_str += "<|im_start|>user"
temp_str += "\n<tool_response>\n"
for content in message["content"]:
if content["type"] == "text":
temp_str += content["value"]
else:
raise ValueError(f"Unsupported content type: {content['type']}")
temp_str += "\n</tool_response>"
if turn_idx == len(messages) - 1 or messages[turn_idx + 1]["role"] != "tool":
temp_str += "<|im_end|>\n"
temp_weight = message.get("loss_weight", 0.0)
temp_str = _update_model_input(processor, input_ids, labels, loss_weights, temp_str, temp_weight)
if is_generate:
temp_str += "<|im_start|>assistant\n"
temp_weight = 0.0
temp_str = _update_model_input(processor, input_ids, labels, loss_weights, temp_str, temp_weight)
attention_mask = [1] * len(input_ids)
return ModelInput(
input_ids=input_ids,
attention_mask=attention_mask,
labels=labels,
loss_weights=loss_weights,
)
@RenderingPlugin("qwen3_nothink").register("parse_message")
def parse_qwen3_nothink_message(generated_text: str) -> Message:
"""Parse a message in the Qwen3 nothink template format. Supports interleaved reasoning and tool calls.
Args:
generated_text (str): The generated text in the Qwen3 nothink template format.
Returns:
Message: The parsed message.
"""
pattern = re.compile(r"<(thinking|tool_call)>\s*(.*?)\s*</\1>\s*", re.DOTALL)
content = []
last_end = 0
for match in pattern.finditer(generated_text):
start, end = match.span()
if start > last_end:
text = generated_text[last_end:start].strip()
if text:
content.append({"type": "text", "value": text})
tag_type = match.group(1)
tag_value = match.group(2).strip()
if tag_type == "thinking":
content.append({"type": "reasoning", "value": tag_value.strip()})
elif tag_type == "tool_call":
try:
json.loads(tag_value.strip())
except json.JSONDecodeError:
raise ValueError(f"Invalid tool call format: {tag_value.strip()}.")
content.append({"type": "tool_call", "value": tag_value.strip()})
last_end = end
if last_end < len(generated_text):
text = generated_text[last_end:].strip()
if text:
content.append({"type": "text", "value": text})
return Message(role="assistant", content=content)
@@ -0,0 +1,13 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

@@ -0,0 +1,259 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import re
from ....utils.constants import IGNORE_INDEX
from ....utils.helper import get_tokenizer
from ....utils.types import Message, ModelInput, Processor, ToolCall
from ..rendering import RenderingPlugin
def _update_model_input(
processor: Processor,
input_ids: list[int],
labels: list[int],
loss_weights: list[float],
temp_str: str,
temp_weight: float,
) -> str:
"""Update model input with temporary string."""
if not temp_str:
return ""
tokenizer = get_tokenizer(processor)
temp_ids = tokenizer.encode(temp_str, add_special_tokens=False)
input_ids.extend(temp_ids)
loss_weights.extend([temp_weight] * len(temp_ids))
if temp_weight > 1e-6:
labels.extend(temp_ids)
else:
labels.extend([IGNORE_INDEX] * len(temp_ids))
return ""
def _concat_text_content(message: Message) -> str:
"""Concatenate text fields in a message."""
message_text = ""
for content in message["content"]:
if content["type"] == "text":
message_text += content["value"]
else:
raise ValueError(f"Unsupported content type: {content['type']}")
return message_text
def _get_last_query_index(messages: list[Message]) -> int:
"""Find the last user query index, excluding wrapped tool responses."""
last_query_index = len(messages) - 1
for idx in range(len(messages) - 1, -1, -1):
message = messages[idx]
if message["role"] != "user":
continue
user_text = ""
is_plain_text = True
for content in message["content"]:
if content["type"] != "text":
is_plain_text = False
break
user_text += content["value"]
if not is_plain_text:
continue
if not (user_text.startswith("<tool_response>") and user_text.endswith("</tool_response>")):
last_query_index = idx
break
return last_query_index
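The backward scan above decides which assistant turn gets the `<think>` block: it looks for the last *genuine* user question, skipping user turns that merely wrap a tool response. A dependency-free sketch of the same scan, with illustrative messages:

```python
def last_query_index(messages):
    """Index of the last plain-text user turn that is not a wrapped
    <tool_response>; falls back to the final message."""
    for idx in range(len(messages) - 1, -1, -1):
        msg = messages[idx]
        if msg["role"] != "user":
            continue
        if any(c["type"] != "text" for c in msg["content"]):
            continue  # multimodal content: not a plain-text query
        text = "".join(c["value"] for c in msg["content"])
        if not (text.startswith("<tool_response>") and text.endswith("</tool_response>")):
            return idx
    return len(messages) - 1

messages = [
    {"role": "user", "content": [{"type": "text", "value": "What's the weather in Paris?"}]},
    {"role": "assistant", "content": [{"type": "tool_call", "value": "{}"}]},
    {"role": "user", "content": [{"type": "text", "value": "<tool_response>sunny</tool_response>"}]},
    {"role": "assistant", "content": [{"type": "text", "value": "It is sunny."}]},
]
print(last_query_index(messages))  # 0 -> the real question, not the tool response
```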
def _split_assistant_content(message: Message) -> tuple[str, str, list[ToolCall]]:
"""Split assistant message into text, reasoning and tool calls."""
text_content = ""
reasoning_content = ""
tool_calls: list[ToolCall] = []
for content in message["content"]:
if content["type"] == "text":
text_content += content["value"]
elif content["type"] == "reasoning":
reasoning_content += content["value"]
elif content["type"] == "tool_call":
try:
tool_call: ToolCall = json.loads(content["value"])
except json.JSONDecodeError:
raise ValueError(f"Invalid tool call format: {content['value']}.")
tool_calls.append(tool_call)
else:
raise ValueError(f"Unsupported content type: {content['type']}")
return text_content, reasoning_content, tool_calls
@RenderingPlugin("qwen3").register("render_messages")
def render_qwen3_messages(
processor: Processor,
messages: list[Message],
tools: str | None = None,
is_generate: bool = False,
enable_thinking: bool = False,
) -> ModelInput:
"""Render messages in the Qwen3 template format.
See https://huggingface.co/spaces/huggingfacejs/chat-template-playground?modelId=Qwen/Qwen3-8B
"""
input_ids, labels, loss_weights = [], [], []
temp_str, temp_weight = "", 0.0
if tools:
temp_str += "<|im_start|>system\n"
if messages[0]["role"] == "system":
temp_str += _concat_text_content(messages[0]) + "\n\n"
temp_weight = messages[0].get("loss_weight", 0.0)
temp_str += (
"# Tools\n\nYou may call one or more functions to assist with the user query.\n\n"
"You are provided with function signatures within <tools></tools> XML tags:\n<tools>"
)
try:
tools = json.loads(tools)
except json.JSONDecodeError:
raise ValueError(f"Invalid tools format: {str(tools)}.")
if not isinstance(tools, list):
tools = [tools]
for tool in tools:
temp_str += "\n" + json.dumps(tool, ensure_ascii=False)
temp_str += (
"\n</tools>\n\nFor each function call, return a json object with function name "
'and arguments within <tool_call></tool_call> XML tags:\n<tool_call>\n{"name": '
'<function-name>, "arguments": <args-json-object>}\n</tool_call><|im_end|>\n'
)
elif messages[0]["role"] == "system":
temp_str += "<|im_start|>system\n" + _concat_text_content(messages[0]) + "<|im_end|>\n"
temp_weight = messages[0].get("loss_weight", 0.0)
temp_str = _update_model_input(processor, input_ids, labels, loss_weights, temp_str, temp_weight)
last_query_index = _get_last_query_index(messages)
for turn_idx, message in enumerate(messages):
if message["role"] == "user" or (message["role"] == "system" and turn_idx != 0):
temp_str += "<|im_start|>" + message["role"] + "\n" + _concat_text_content(message) + "<|im_end|>\n"
temp_weight = message.get("loss_weight", 0.0)
elif message["role"] == "assistant":
temp_str += "<|im_start|>" + message["role"] + "\n"
text_content, reasoning_content, tool_calls = _split_assistant_content(message)
if turn_idx > last_query_index and (turn_idx == len(messages) - 1 or reasoning_content):
temp_str += "<think>\n" + reasoning_content.strip("\n") + "\n</think>\n\n" + text_content.lstrip("\n")
else:
temp_str += text_content
for tool_call_idx, tool_call in enumerate(tool_calls):
if (tool_call_idx == 0 and text_content) or tool_call_idx > 0:
temp_str += "\n"
arguments = tool_call.get("arguments")
if isinstance(arguments, str):
arguments_str = arguments
else:
arguments_str = json.dumps(arguments, ensure_ascii=False)
temp_str += (
'<tool_call>\n{"name": "'
+ tool_call["name"]
+ '", "arguments": '
+ arguments_str
+ "}\n</tool_call>"
)
temp_str += "<|im_end|>\n"
temp_weight = message.get("loss_weight", 1.0)
elif message["role"] == "tool":
if turn_idx == 0 or messages[turn_idx - 1]["role"] != "tool":
temp_str += "<|im_start|>user"
temp_str += "\n<tool_response>\n" + _concat_text_content(message) + "\n</tool_response>"
if turn_idx == len(messages) - 1 or messages[turn_idx + 1]["role"] != "tool":
temp_str += "<|im_end|>\n"
temp_weight = message.get("loss_weight", 0.0)
temp_str = _update_model_input(processor, input_ids, labels, loss_weights, temp_str, temp_weight)
if is_generate:
temp_str += "<|im_start|>assistant\n"
temp_weight = 0.0
if enable_thinking is False:
temp_str += "<think>\n\n</think>\n\n"
temp_str = _update_model_input(processor, input_ids, labels, loss_weights, temp_str, temp_weight)
attention_mask = [1] * len(input_ids)
return ModelInput(
input_ids=input_ids,
attention_mask=attention_mask,
labels=labels,
loss_weights=loss_weights,
)
@RenderingPlugin("qwen3").register("parse_message")
def parse_qwen3_message(generated_text: str) -> Message:
"""Parse a message in the Qwen3 template format. Supports interleaved reasoning and tool calls.
Args:
generated_text (str): The generated text in the Qwen3 template format.
Returns:
Message: The parsed message.
"""
pattern = re.compile(r"<(think|tool_call)>\s*(.*?)\s*</\1>\s*", re.DOTALL)
content = []
last_end = 0
for match in pattern.finditer(generated_text):
start, end = match.span()
if start > last_end:
text = generated_text[last_end:start].strip()
if text:
content.append({"type": "text", "value": text})
tag_type = match.group(1)
tag_value = match.group(2).strip()
if tag_type == "think":
content.append({"type": "reasoning", "value": tag_value.strip()})
elif tag_type == "tool_call":
try:
json.loads(tag_value.strip())
except json.JSONDecodeError:
raise ValueError(f"Invalid tool call format: {tag_value.strip()}.")
content.append({"type": "tool_call", "value": tag_value.strip()})
last_end = end
if last_end < len(generated_text):
text = generated_text[last_end:].strip()
if text:
content.append({"type": "text", "value": text})
return Message(role="assistant", content=content)
@@ -0,0 +1,209 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import re
from ....utils.constants import IGNORE_INDEX
from ....utils.helper import get_tokenizer
from ....utils.types import Message, ModelInput, Processor, ToolCall
from ..rendering import RenderingPlugin
def _update_model_input(
processor: Processor,
input_ids: list[int],
labels: list[int],
loss_weights: list[float],
temp_str: str,
temp_weight: float,
) -> str:
"""Update model input with temporary string."""
if not temp_str:
return ""
tokenizer = get_tokenizer(processor)
temp_ids = tokenizer.encode(temp_str, add_special_tokens=False)
input_ids.extend(temp_ids)
loss_weights.extend([temp_weight] * len(temp_ids))
if temp_weight > 1e-6:
labels.extend(temp_ids)
else:
labels.extend([IGNORE_INDEX] * len(temp_ids))
return ""
def _concat_text_content(message: Message) -> str:
"""Concatenate text fields in a message."""
message_text = ""
for content in message["content"]:
if content["type"] == "text":
message_text += content["value"]
else:
raise ValueError(f"Unsupported content type: {content['type']}")
return message_text
@RenderingPlugin("qwen3_nothink").register("render_messages")
def render_qwen3_nothink_messages(
processor: Processor,
messages: list[Message],
tools: str | None = None,
is_generate: bool = False,
enable_thinking: bool = False,
) -> ModelInput:
"""Render messages in the Qwen3 nothink template format.
See https://huggingface.co/spaces/huggingfacejs/chat-template-playground?modelId=Qwen/Qwen3-4B-Instruct-2507
"""
input_ids, labels, loss_weights = [], [], []
temp_str, temp_weight = "", 0.0
if tools:
temp_str += "<|im_start|>system\n"
if messages[0]["role"] == "system":
temp_str += _concat_text_content(messages[0]) + "\n\n"
temp_weight = messages[0].get("loss_weight", 0.0)
temp_str += (
"# Tools\n\nYou may call one or more functions to assist with the user query.\n\n"
"You are provided with function signatures within <tools></tools> XML tags:\n<tools>"
)
try:
tools = json.loads(tools)
except json.JSONDecodeError:
raise ValueError(f"Invalid tools format: {str(tools)}.")
if not isinstance(tools, list):
tools = [tools]
for tool in tools:
temp_str += "\n" + json.dumps(tool, ensure_ascii=False)
temp_str += (
"\n</tools>\n\nFor each function call, return a json object with function name "
'and arguments within <tool_call></tool_call> XML tags:\n<tool_call>\n{"name": '
'<function-name>, "arguments": <args-json-object>}\n</tool_call><|im_end|>\n'
)
elif messages[0]["role"] == "system":
temp_str += "<|im_start|>system\n" + _concat_text_content(messages[0]) + "<|im_end|>\n"
temp_weight = messages[0].get("loss_weight", 0.0)
temp_str = _update_model_input(processor, input_ids, labels, loss_weights, temp_str, temp_weight)
for turn_idx, message in enumerate(messages):
if message["role"] == "user" or (message["role"] == "system" and turn_idx != 0):
temp_str += "<|im_start|>" + message["role"] + "\n" + _concat_text_content(message) + "<|im_end|>\n"
temp_weight = message.get("loss_weight", 0.0)
elif message["role"] == "assistant":
temp_str += "<|im_start|>" + message["role"] + "\n"
for val_idx, content in enumerate(message["content"]):
if content["type"] == "text":
temp_str += content["value"]
elif content["type"] == "reasoning":
temp_str += "<thinking>\n" + content["value"] + "\n</thinking>\n\n" # avoid using special tokens
elif content["type"] == "tool_call":
if val_idx != 0 and message["content"][val_idx - 1]["type"] in ["text", "tool_call"]:
temp_str += "\n"
try:
tool_call: ToolCall = json.loads(content["value"])
except json.JSONDecodeError:
raise ValueError(f"Invalid tool call format: {content['value']}.")
temp_str += (
'<tool_call>\n{"name": "'
+ tool_call["name"]
+ '", "arguments": '
+ json.dumps(tool_call["arguments"], ensure_ascii=False)
+ "}\n</tool_call>"
)
else:
raise ValueError(f"Unsupported content type: {content['type']}")
temp_str += "<|im_end|>\n"
temp_weight = message.get("loss_weight", 1.0)
elif message["role"] == "tool":
if turn_idx == 0 or messages[turn_idx - 1]["role"] != "tool":
temp_str += "<|im_start|>user"
temp_str += "\n<tool_response>\n" + _concat_text_content(message) + "\n</tool_response>"
if turn_idx == len(messages) - 1 or messages[turn_idx + 1]["role"] != "tool":
temp_str += "<|im_end|>\n"
temp_weight = message.get("loss_weight", 0.0)
temp_str = _update_model_input(processor, input_ids, labels, loss_weights, temp_str, temp_weight)
if is_generate:
temp_str += "<|im_start|>assistant\n"
temp_weight = 0.0
if enable_thinking:
raise ValueError("The qwen3_nothink template does not support thinking mode.")
temp_str = _update_model_input(processor, input_ids, labels, loss_weights, temp_str, temp_weight)
attention_mask = [1] * len(input_ids)
return ModelInput(
input_ids=input_ids,
attention_mask=attention_mask,
labels=labels,
loss_weights=loss_weights,
)
@RenderingPlugin("qwen3_nothink").register("parse_message")
def parse_qwen3_nothink_message(generated_text: str) -> Message:
"""Parse a message in the Qwen3 nothink template format. Supports interleaved reasoning and tool calls.
Args:
generated_text (str): The generated text in the Qwen3 nothink template format.
Returns:
Message: The parsed message.
"""
pattern = re.compile(r"<(thinking|tool_call)>\s*(.*?)\s*</\1>\s*", re.DOTALL)
content = []
last_end = 0
for match in pattern.finditer(generated_text):
start, end = match.span()
if start > last_end:
text = generated_text[last_end:start].strip()
if text:
content.append({"type": "text", "value": text})
tag_type = match.group(1)
tag_value = match.group(2).strip()
if tag_type == "thinking":
content.append({"type": "reasoning", "value": tag_value.strip()})
elif tag_type == "tool_call":
try:
json.loads(tag_value.strip())
except json.JSONDecodeError:
raise ValueError(f"Invalid tool call format: {tag_value.strip()}.")
content.append({"type": "tool_call", "value": tag_value.strip()})
last_end = end
if last_end < len(generated_text):
text = generated_text[last_end:].strip()
if text:
content.append({"type": "text", "value": text})
return Message(role="assistant", content=content)

View File

@@ -0,0 +1,129 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""DeepSpeed integration via accelerate's built-in capabilities.
Instead of manually calling deepspeed.initialize() and syncing config,
this module leverages accelerate's Accelerator + DeepSpeedPlugin to handle
initialization, backward, gradient accumulation, and model saving.
"""
from typing import Any, Optional
import torch
from accelerate import Accelerator
from accelerate.utils import DeepSpeedPlugin
from ....utils.logging import get_logger
from ....utils.types import HFModel, Processor
logger = get_logger(__name__)
class DeepSpeedEngine:
"""DeepSpeed integration using accelerate's built-in capabilities.
This replaces the manual DeepSpeedConfigHelper / DeepSpeedEngine approach
with accelerate's Accelerator + DeepSpeedPlugin, which handles:
- Config syncing (auto values, batch size, lr, etc.)
- deepspeed.initialize() call
- Optimizer / LR scheduler wrapping
- Backward + gradient accumulation boundary
- ZeRO-3 parameter gathering for saving
"""
def __init__(self, dist_config: dict[str, Any], num_micro_batch: int = 1, micro_batch_size: int = 1):
config_file = dist_config.get("config_file")
if not config_file:
raise ValueError("DeepSpeed config_file is required in dist_config")
ds_plugin = DeepSpeedPlugin(hf_ds_config=config_file)
self.accelerator = Accelerator(
deepspeed_plugin=ds_plugin,
gradient_accumulation_steps=num_micro_batch,
)
# Resolve "auto" for train_micro_batch_size_per_gpu so that
# accelerate.prepare() does not require a DataLoader to infer it.
ds_config = self.accelerator.state.deepspeed_plugin.deepspeed_config
if ds_config.get("train_micro_batch_size_per_gpu") in (None, "auto"):
ds_config["train_micro_batch_size_per_gpu"] = micro_batch_size
logger.info_rank0(f"DeepSpeedEngine initialized with config: {config_file}")
def shard_model(self, model: HFModel) -> "DeepSpeedEngine":
"""No-op shard — actual model wrapping happens in prepare().
Returns self so the caller gets the engine instance via the hub interface.
"""
return self
def prepare(
self,
model: HFModel,
optimizer: torch.optim.Optimizer,
lr_scheduler: Optional[Any] = None,
) -> tuple[HFModel, torch.optim.Optimizer, Any]:
"""Prepare model, optimizer, and lr_scheduler using accelerate.
Internally calls deepspeed.initialize() and wraps the returned objects.
"""
if lr_scheduler is not None:
model, optimizer, lr_scheduler = self.accelerator.prepare(model, optimizer, lr_scheduler)
else:
model, optimizer = self.accelerator.prepare(model, optimizer)
model._accelerator = self.accelerator # type: ignore[assignment]
logger.info_rank0("Model, optimizer, and lr_scheduler prepared via accelerate")
return model, optimizer, lr_scheduler
def backward(self, loss: torch.Tensor) -> None:
"""Backward pass using accelerate.
Delegates to DeepSpeedEngineWrapper.backward() which respects
sync_gradients to control gradient accumulation boundaries.
When sync_gradients=True: engine.backward(loss) + engine.step()
When sync_gradients=False: engine.backward(loss) only
"""
self.accelerator.backward(loss)
def get_grad_norm(self) -> float:
"""Get the global gradient norm from the DeepSpeed engine."""
engine_wrapper = getattr(self.accelerator, "deepspeed_engine_wrapped", None)
if engine_wrapper is not None:
return engine_wrapper.engine.get_global_grad_norm() or 0.0
return 0.0
def save_model(model: HFModel, output_dir: str, processor: Processor) -> None:
"""Save model using accelerate's built-in ZeRO-aware utilities.
Expects model._accelerator to be set during prepare().
Handles ZeRO-3 parameter gathering automatically via
accelerator.get_state_dict().
"""
accelerator: Accelerator = model._accelerator # type: ignore[union-attr]
unwrapped_model = accelerator.unwrap_model(model)
state_dict = accelerator.get_state_dict(model)
if accelerator.is_main_process:
unwrapped_model.save_pretrained(output_dir, state_dict=state_dict, max_shard_size="4GB")
processor.save_pretrained(output_dir, max_shard_size="4GB")
accelerator.wait_for_everyone()
logger.info_rank0(f"Model saved to {output_dir}")

View File

@@ -12,28 +12,30 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import copy
import gc import gc
import os import os
import torch import torch
import torch.nn as nn import torch.nn as nn
from peft.tuners.lora import LoraLayer
from torch.distributed.checkpoint.state_dict import StateDictOptions, get_model_state_dict, set_model_state_dict from torch.distributed.checkpoint.state_dict import StateDictOptions, get_model_state_dict, set_model_state_dict
from torch.distributed.fsdp import ( from torch.distributed.fsdp import (
CPUOffloadPolicy, CPUOffloadPolicy,
MixedPrecisionPolicy, MixedPrecisionPolicy,
fully_shard, fully_shard,
) )
from transformers import PreTrainedModel
from ....accelerator.helper import get_current_accelerator from ....accelerator.helper import get_current_accelerator
from ....accelerator.interface import DistributedInterface from ....accelerator.interface import DistributedInterface
from ....utils.logging import get_logger from ....utils.logging import get_logger
from ....utils.types import HFModel, Processor
logger = get_logger(__name__) logger = get_logger(__name__)
def get_transformer_layer_cls(model: PreTrainedModel) -> type[nn.Module] | None: def get_transformer_layer_cls(model: HFModel) -> type[nn.Module] | None:
no_split_modules = getattr(model, "_no_split_modules", None) no_split_modules = getattr(model, "_no_split_modules", None)
if no_split_modules: if no_split_modules:
if isinstance(no_split_modules, (list, tuple)): if isinstance(no_split_modules, (list, tuple)):
@@ -49,6 +51,20 @@ def get_transformer_layer_cls(model: PreTrainedModel) -> type[nn.Module] | None:
return None return None
def save_model(model: HFModel, output_dir: str, processor: Processor) -> None:
if DistributedInterface().get_rank() == 0:
logger.info("Gathering state dict for saving...")
options = StateDictOptions(full_state_dict=True, cpu_offload=True)
state_dict = get_model_state_dict(model, options=options)
if DistributedInterface().get_rank() == 0:
model_to_save = model.module if hasattr(model, "module") else model
model_to_save.save_pretrained(output_dir, state_dict=state_dict, max_shard_size="4GB")
processor.save_pretrained(output_dir, max_shard_size="4GB")
logger.info(f"Model saved to {output_dir}")
class FSDP2Engine: class FSDP2Engine:
def __init__(self, dist_config: dict): def __init__(self, dist_config: dict):
self.dist_interface = DistributedInterface() self.dist_interface = DistributedInterface()
@@ -94,7 +110,10 @@ class FSDP2Engine:
cast_forward_inputs=True, cast_forward_inputs=True,
) )
def prepare_model(self, model: PreTrainedModel) -> PreTrainedModel: def is_lora_module_wrap(self, model) -> bool:
return any(isinstance(module, LoraLayer) for module in model.modules())
def prepare_model(self, model: HFModel) -> HFModel:
if self.fsdp_mesh is None: if self.fsdp_mesh is None:
logger.warning("No FSDP Mesh available, skipping FSDP wrapping.") logger.warning("No FSDP Mesh available, skipping FSDP wrapping.")
return model return model
@@ -111,6 +130,25 @@ class FSDP2Engine:
logger.info(f"Applying per-layer FSDP to {layer_cls.__name__}") logger.info(f"Applying per-layer FSDP to {layer_cls.__name__}")
transformer_layer_cls_to_wrap = {layer_cls} transformer_layer_cls_to_wrap = {layer_cls}
if self.is_lora_module_wrap(model):
lora_modules = []
for module in model.modules():
if len(list(module.children())) != 0:
continue
if any(param.requires_grad for param in module.parameters(recurse=False)):
lora_modules.append(module)
for module in lora_modules:
fully_shard(
module,
mesh=self.fsdp_mesh,
reshard_after_forward=self.reshard_after_forward,
mp_policy=mp_policy,
offload_policy=CPUOffloadPolicy(pin_memory=self.pin_memory) if self.offload_params else None,
)
logger.info("Applying FSDP wrap for LoRA layer separately.")
for name, module in model.named_modules(): for name, module in model.named_modules():
should_wrap = False should_wrap = False
@@ -129,12 +167,11 @@ class FSDP2Engine:
offload_policy=CPUOffloadPolicy(pin_memory=self.pin_memory) if self.offload_params else None, offload_policy=CPUOffloadPolicy(pin_memory=self.pin_memory) if self.offload_params else None,
) )
use_gradient_checkpointing = True # Could be configurable # BaseTrainer is the single source of truth for gradient checkpointing.
if use_gradient_checkpointing: # FSDP2 only applies the input-grad compatibility hook when checkpointing is already enabled.
if getattr(model, "is_gradient_checkpointing", False):
if self.rank == 0: if self.rank == 0:
logger.info("Enabling gradient checkpointing (transformers native)...") logger.info("Gradient checkpointing is enabled. Applying FSDP2 input grad preparation.")
model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False})
if hasattr(model, "enable_input_require_grads"): if hasattr(model, "enable_input_require_grads"):
model.enable_input_require_grads() model.enable_input_require_grads()
@@ -156,7 +193,7 @@ class FSDP2Engine:
return model return model
@torch.no_grad() @torch.no_grad()
def materialize_and_load(self, model: PreTrainedModel, hf_model_path: str, dcp_path: str = None): def materialize_and_load(self, model: HFModel, hf_model_path: str, dcp_path: str = None):
if self.rank == 0: if self.rank == 0:
logger.info("Materializing sharded model params...") logger.info("Materializing sharded model params...")
@@ -176,15 +213,57 @@ class FSDP2Engine:
return model return model
def shard_model(self, model: PreTrainedModel) -> PreTrainedModel: def _save_non_persistent_buffers(self, model: HFModel) -> dict:
"""Save non-persistent buffers, such as inv_freq."""
saved = {}
for mod_name, module in model.named_modules():
for buf_name in module._non_persistent_buffers_set:
fqn = f"{mod_name}.{buf_name}" if mod_name else buf_name
buf = getattr(module, buf_name, None)
if buf is not None:
saved[fqn] = copy.deepcopy(buf)
if self.rank == 0 and saved:
logger.info(f"Saved {len(saved)} non-persistent buffers")
return saved
def _restore_non_persistent_buffers(self, model: HFModel, saved_buffers: dict):
"""Register saved non-persistent buffers to model."""
if not saved_buffers:
return
device = get_current_accelerator()
for fqn, buf in saved_buffers.items():
buf = buf.to(device)
if "." in fqn:
parent_fqn, buf_name = fqn.rsplit(".", 1)
parent_module = model.get_submodule(parent_fqn)
else:
buf_name = fqn
parent_module = model
parent_module.register_buffer(buf_name, buf, persistent=False)
if self.rank == 0:
logger.info(f"Restored {len(saved_buffers)} non-persistent buffers")
def shard_model(self, model: HFModel) -> HFModel:
if model.device.type == "meta": if model.device.type == "meta":
non_persistent_buffers = self._save_non_persistent_buffers(model)
if getattr(model.config, "tie_word_embeddings", None):
model.tie_weights()
model = self.prepare_model(model) model = self.prepare_model(model)
model = self.materialize_and_load(model, hf_model_path=model.config.name_or_path, dcp_path=self.dcp_path) model = self.materialize_and_load(model, hf_model_path=model.config.name_or_path, dcp_path=self.dcp_path)
# fix tied broken for no-fsdp-wrap case
if getattr(model.config, "tie_word_embeddings", None):
model.tie_weights()
self._restore_non_persistent_buffers(model, non_persistent_buffers)
else: else:
model = self.prepare_model(model) model = self.prepare_model(model)
return model return model
def _load_from_dcp(self, model: PreTrainedModel, dcp_path: str): def _load_from_dcp(self, model: HFModel, dcp_path: str):
import torch.distributed.checkpoint as dcp import torch.distributed.checkpoint as dcp
try: try:
@@ -203,7 +282,7 @@ class FSDP2Engine:
logger.error(f"Failed to load from DCP: {e}") logger.error(f"Failed to load from DCP: {e}")
raise e raise e
def _load_weights_from_hf_checkpoint(self, model, hf_model_path): def _load_weights_from_hf_checkpoint(self, model: HFModel, hf_model_path: str):
import glob import glob
import json import json

View File

@@ -12,9 +12,16 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from __future__ import annotations
from typing import TYPE_CHECKING
from ....config.arg_utils import PluginConfig from ....config.arg_utils import PluginConfig
from ....utils.plugin import BasePlugin from ....utils.plugin import BasePlugin
from ....utils.types import HFModel
if TYPE_CHECKING:
from ....utils.types import HFModel, Processor
class DistributedPlugin(BasePlugin): class DistributedPlugin(BasePlugin):
@@ -23,12 +30,32 @@ class DistributedPlugin(BasePlugin):
@DistributedPlugin("fsdp2").register() @DistributedPlugin("fsdp2").register()
def shard_model_fsdp2(model: HFModel, dist_config: PluginConfig) -> HFModel: def shard_model_fsdp2(model: HFModel, dist_config: PluginConfig, **kwargs) -> HFModel:
from .fsdp2 import FSDP2Engine from .fsdp2 import FSDP2Engine
return FSDP2Engine(dist_config).shard_model(model) return FSDP2Engine(dist_config).shard_model(model)
@DistributedPlugin("fsdp2").register("save_model")
def save_model_fsdp2(model: HFModel, output_dir: str, processor: Processor) -> None:
from .fsdp2 import save_model
return save_model(model, output_dir, processor)
@DistributedPlugin("deepspeed").register() @DistributedPlugin("deepspeed").register()
def shard_model_deepspeed(model: HFModel, dist_config: PluginConfig) -> HFModel: def shard_model_deepspeed(model: HFModel, dist_config: PluginConfig, **kwargs) -> HFModel:
return model from .deepspeed import DeepSpeedEngine
return DeepSpeedEngine(
dist_config,
num_micro_batch=kwargs.get("num_micro_batch"),
micro_batch_size=kwargs.get("micro_batch_size"),
).shard_model(model)
@DistributedPlugin("deepspeed").register("save_model")
def save_model_deepspeed(model: HFModel, output_dir: str, processor: Processor) -> None:
from .deepspeed import save_model
return save_model(model, output_dir, processor)

View File

@@ -33,7 +33,7 @@ def run_sft(args: InputArgument = None):
model_args, data_args, training_args, _ = get_args(args) model_args, data_args, training_args, _ = get_args(args)
DistributedInterface(training_args.dist_config) DistributedInterface(training_args.dist_config)
train_dataset = DataEngine(data_args.train_dataset) train_dataset = DataEngine(data_args.train_dataset)
model_engine = ModelEngine(model_args) model_engine = ModelEngine(model_args, is_train=True)
trainer = SFTTrainer( trainer = SFTTrainer(
args=training_args, args=training_args,
model=model_engine.model, model=model_engine.model,

View File

@@ -15,12 +15,22 @@
import torch import torch
from transformers import PreTrainedTokenizer from transformers import PreTrainedTokenizer
from transformers import set_seed as hf_set_seed
from ..accelerator.interface import DistributedInterface from ..accelerator.interface import DistributedInterface
from .constants import IGNORE_INDEX from .constants import IGNORE_INDEX
from .types import BatchInput, ModelInput, Processor, Tensor from .types import BatchInput, ModelInput, Processor, Tensor
def set_seed(seed: int) -> None:
"""Set seed for reproducibility.
Args:
seed: Random seed.
"""
hf_set_seed(seed)
def is_tokenizer(processor: Processor) -> bool: def is_tokenizer(processor: Processor) -> bool:
"""Check if processor is tokenizer. """Check if processor is tokenizer.

View File

@@ -21,6 +21,13 @@ from functools import lru_cache
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
from packaging import version from packaging import version
from transformers.utils.versions import require_version
from . import logging
from .env import is_env_enabled
logger = logging.get_logger(__name__)
if TYPE_CHECKING: if TYPE_CHECKING:
@@ -41,3 +48,22 @@ def _get_package_version(name: str) -> "Version":
@lru_cache @lru_cache
def is_transformers_version_greater_than(content: str): def is_transformers_version_greater_than(content: str):
return _get_package_version("transformers") >= version.parse(content) return _get_package_version("transformers") >= version.parse(content)
def check_version(requirement: str, mandatory: bool = False) -> None:
r"""Optionally check the package version."""
if is_env_enabled("DISABLE_VERSION_CHECK") and not mandatory:
logger.warning_rank0_once("Version checking has been disabled, may lead to unexpected behaviors.")
return
if "gptqmodel" in requirement or "autoawq" in requirement:
pip_command = f"pip install {requirement} --no-build-isolation"
else:
pip_command = f"pip install {requirement}"
if mandatory:
hint = f"To fix: run `{pip_command}`."
else:
hint = f"To fix: run `{pip_command}` or set `DISABLE_VERSION_CHECK=1` to skip this check."
require_version(requirement, hint)

View File

@@ -85,7 +85,7 @@ class DistributedConfig(TypedDict, total=False):
class Content(TypedDict): class Content(TypedDict):
type: Literal["text", "reasoning", "tool_call", "image_url"] type: Literal["text", "reasoning", "tool_call", "image_url", "video_url", "audio_url"]
"""Type of the content.""" """Type of the content."""
value: str value: str
"""Value of the content.""" """Value of the content."""

View File

@@ -108,11 +108,26 @@ def create_train_tab(engine: "Engine") -> dict[str, "Component"]:
with gr.Column(): with gr.Column():
enable_thinking = gr.Checkbox(value=True) enable_thinking = gr.Checkbox(value=True)
report_to = gr.Dropdown( report_to = gr.Dropdown(
choices=["none", "wandb", "mlflow", "neptune", "tensorboard", "all"], choices=["none", "wandb", "mlflow", "neptune", "tensorboard", "trackio", "all"],
value="none", value="none",
allow_custom_value=True, allow_custom_value=True,
) )
with gr.Accordion("Trackio Settings", open=False):
project = gr.Textbox(
value="huggingface",
label="Project Name",
info="Project name for experiment tracking (used by Trackio, W&B, etc.)",
)
trackio_space_id = gr.Textbox(
value="trackio", label="Trackio Space ID", info="Hugging Face Space ID for Trackio deployment"
)
hub_private_repo = gr.Checkbox(
value=False, label="Private Repository", info="Make the Hugging Face repository private"
)
input_elems.update( input_elems.update(
{ {
logging_steps, logging_steps,
@@ -128,6 +143,9 @@ def create_train_tab(engine: "Engine") -> dict[str, "Component"]:
use_llama_pro, use_llama_pro,
enable_thinking, enable_thinking,
report_to, report_to,
project,
trackio_space_id,
hub_private_repo,
} }
) )
elem_dict.update( elem_dict.update(
@@ -146,6 +164,9 @@ def create_train_tab(engine: "Engine") -> dict[str, "Component"]:
use_llama_pro=use_llama_pro, use_llama_pro=use_llama_pro,
enable_thinking=enable_thinking, enable_thinking=enable_thinking,
report_to=report_to, report_to=report_to,
project=project,
trackio_space_id=trackio_space_id,
hub_private_repo=hub_private_repo,
) )
) )

View File

@@ -166,3 +166,33 @@ def _manage_distributed_env(request: FixtureRequest, monkeypatch: MonkeyPatch) -
def fix_valuehead_cpu_loading(): def fix_valuehead_cpu_loading():
"""Fix valuehead model loading.""" """Fix valuehead model loading."""
patch_valuehead_model() patch_valuehead_model()
@pytest.fixture(scope="session", autouse=True)
def bypass_mistral_regex_check():
"""Disable Mistral regex network check.
Monkey-patch TokenizersBackend._patch_mistral_regex into a no-op.
"""
try:
from transformers.tokenization_utils_fast import TokenizersBackend
except ImportError:
# Very old transformers, nothing to patch
yield
return
if not hasattr(TokenizersBackend, "_patch_mistral_regex"):
# Method does not exist in this version
yield
return
# Backup original method
original = TokenizersBackend._patch_mistral_regex
# Replace with no-op
TokenizersBackend._patch_mistral_regex = lambda cls, tokenizer, *args, **kwargs: tokenizer
yield
# Restore original method
TokenizersBackend._patch_mistral_regex = original

View File

@@ -13,6 +13,7 @@
# limitations under the License. # limitations under the License.
import os import os
from collections import Counter
import pytest import pytest
import torch import torch
@@ -22,6 +23,7 @@ from transformers import AutoConfig, AutoModelForImageTextToText
from llamafactory.data import get_template_and_fix_tokenizer from llamafactory.data import get_template_and_fix_tokenizer
from llamafactory.data.collator import MultiModalDataCollatorForSeq2Seq, prepare_4d_attention_mask from llamafactory.data.collator import MultiModalDataCollatorForSeq2Seq, prepare_4d_attention_mask
from llamafactory.extras.constants import IGNORE_INDEX from llamafactory.extras.constants import IGNORE_INDEX
from llamafactory.extras.packages import is_transformers_version_greater_than
from llamafactory.hparams import get_infer_args from llamafactory.hparams import get_infer_args
from llamafactory.model import load_tokenizer from llamafactory.model import load_tokenizer
@@ -116,19 +118,189 @@ def test_multimodal_collator():
"labels": [ "labels": [
[0, 1, 2, 3, q, q, q, q, q, q, q, q], [0, 1, 2, 3, q, q, q, q, q, q, q, q],
], ],
"position_ids": [ "position_ids": [[[0, 1, 2, 3, 0, 0, 0, 0, 0, 0, 0, 0]]] * 3,
[[0, 1, 2, 3, 1, 1, 1, 1, 1, 1, 1, 1]], "rope_deltas": [[0]],
[[0, 1, 2, 3, 1, 1, 1, 1, 1, 1, 1, 1]],
[[0, 1, 2, 3, 1, 1, 1, 1, 1, 1, 1, 1]],
],
"rope_deltas": [[-8]],
**tokenizer_module["processor"].image_processor(fake_image), **tokenizer_module["processor"].image_processor(fake_image),
} }
if not is_transformers_version_greater_than("5.0.0"):
# adapt position_ids and rope_deltas for transformers < 5.0.0
# https://github.com/huggingface/transformers/pull/43972
expected_input["position_ids"] = [[[0, 1, 2, 3, 1, 1, 1, 1, 1, 1, 1, 1]]] * 3
expected_input["rope_deltas"] = [[-8]]
assert batch_input.keys() == expected_input.keys() assert batch_input.keys() == expected_input.keys()
for k in batch_input.keys(): for k in batch_input.keys():
if k == "position_ids" and batch_input[k].dim() == 3 and batch_input[k].shape[0] == 4:
batch_input[k] = batch_input[k][1:]
assert batch_input[k].eq(torch.tensor(expected_input[k])).all() assert batch_input[k].eq(torch.tensor(expected_input[k])).all()
def _make_packed_feature(
*,
packing_params: dict,
pad_token_id: int,
label_ignore_id: int,
fake_image: Image.Image,
vision_start_id: int | None = None,
vision_end_id: int | None = None,
image_pad_id: int | None = None,
) -> dict:
r"""Build one packed sample using the new PackingParams schema."""
sequence_boundaries = packing_params["sequence_boundaries"]
image_subseq_ids = packing_params["image_subseq_ids"]
video_subseq_ids = packing_params["video_subseq_ids"]
audio_subseq_ids = packing_params["audio_subseq_ids"]
unpadded_length = packing_params["unpadded_length"]
right_padding_length = packing_params["right_padding_length"] # which only preserved in tests
cutoff_plus_one = sequence_boundaries[-1]
content_len = unpadded_length
pad_len = right_padding_length
assert content_len + pad_len == cutoff_plus_one
assert sequence_boundaries[0] == 0
assert sequence_boundaries[-1] == cutoff_plus_one
content_ids = list(range(100, 100 + content_len))
if vision_start_id is not None and vision_end_id is not None and image_pad_id is not None:
image_counts_by_subseq = Counter(image_subseq_ids)
for subseq_idx, image_count in sorted(image_counts_by_subseq.items()):
if subseq_idx >= len(sequence_boundaries) - 1:
continue
subseq_start = sequence_boundaries[subseq_idx]
subseq_end = sequence_boundaries[subseq_idx + 1]
subseq_len = subseq_end - subseq_start
if subseq_len < 3:
continue
# Build repeated image groups while preserving at least 3 tokens for each remaining image.
injected_tokens: list[int] = []
remaining = subseq_len
for image_idx in range(image_count):
remaining_images = image_count - image_idx
min_reserved_for_rest = 3 * (remaining_images - 1)
current_group_len = min(6, remaining - min_reserved_for_rest)
if current_group_len < 3:
break
group = [vision_start_id] + [image_pad_id] * max(1, current_group_len - 2) + [vision_end_id]
injected_tokens.extend(group[:current_group_len])
remaining -= current_group_len
if injected_tokens:
insert_end = subseq_start + len(injected_tokens)
content_ids[subseq_start:insert_end] = injected_tokens
input_ids = content_ids + [pad_token_id] * pad_len
attention_mask = [1] * content_len + [0] * pad_len
labels = [label_ignore_id] * cutoff_plus_one
return {
"input_ids": input_ids,
"attention_mask": attention_mask,
"labels": labels,
"images": [fake_image] * len(image_subseq_ids),
"videos": [None] * len(video_subseq_ids),
"audios": [None] * len(audio_subseq_ids),
"packing_params": packing_params,
}
def _make_packed_features(
*,
packing_params: dict,
pad_token_id: int,
label_ignore_id: int,
fake_image: Image.Image,
vision_start_id: int,
vision_end_id: int,
image_pad_id: int,
) -> list[dict]:
r"""Build packed features from caller-provided packing_params."""
return [
_make_packed_feature(
packing_params=packing_params,
pad_token_id=pad_token_id,
label_ignore_id=label_ignore_id,
fake_image=fake_image,
vision_start_id=vision_start_id,
vision_end_id=vision_end_id,
image_pad_id=image_pad_id,
)
]
def _get_expected_position_ids(packing_params, get_rope_func, input_ids, attention_mask) -> torch.Tensor:
bound_list = packing_params["sequence_boundaries"]
input_ids_slices = [input_ids[bound_list[i]:bound_list[i+1]] for i in range(len(bound_list) - 1)]
attention_mask_slices = [attention_mask[bound_list[i]:bound_list[i+1]] for i in range(len(bound_list) - 1)]
img_counts_by_subseq = Counter(packing_params["image_subseq_ids"])
all_position_ids = []
for i, input_ids_slice in enumerate(input_ids_slices):
img_cnt = img_counts_by_subseq[i]
if sum(attention_mask_slices[i]) == 0:
continue
rope_func_kwargs = {
"input_ids": torch.tensor(input_ids_slice).unsqueeze(0),
"attention_mask": torch.tensor(attention_mask_slices[i]).unsqueeze(0),
"image_grid_thw": [torch.tensor([1, 4, 4])] * img_cnt,
}
position_ids, _ = get_rope_func(**rope_func_kwargs)
all_position_ids.append(position_ids)
return torch.cat(all_position_ids, dim=-1)
@pytest.mark.runs_on(["cpu", "mps"])
def test_multimodal_collator_with_packing():
model_args, data_args, *_ = get_infer_args(
{"model_name_or_path": "Qwen/Qwen2-VL-2B-Instruct", "template": "qwen2_vl"}
)
tokenizer_module = load_tokenizer(model_args)
template = get_template_and_fix_tokenizer(tokenizer_module["tokenizer"], data_args)
tokenizer_module["tokenizer"].padding_side = "right"
config = AutoConfig.from_pretrained(model_args.model_name_or_path)
with torch.device("meta"):
model = AutoModelForImageTextToText.from_config(config)
data_collator = MultiModalDataCollatorForSeq2Seq(
template=template,
model=model,
pad_to_multiple_of=4,
label_pad_token_id=IGNORE_INDEX,
**tokenizer_module,
)
tokenizer = tokenizer_module["tokenizer"]
packing_params = {
"sequence_boundaries": [0, 2, 10, 18, 28, 32],
"image_subseq_ids": [1, 2, 3],
"video_subseq_ids": [],
"audio_subseq_ids": [],
"unpadded_length": 28,
"right_padding_length": 4,
}
fake_image = Image.new("RGB", (64, 64), (255, 255, 255))
features = _make_packed_features(
packing_params=packing_params,
pad_token_id=tokenizer.pad_token_id,
label_ignore_id=IGNORE_INDEX,
fake_image=fake_image,
vision_start_id=tokenizer.convert_tokens_to_ids("<|vision_start|>"),
vision_end_id=tokenizer.convert_tokens_to_ids("<|vision_end|>"),
image_pad_id=tokenizer.convert_tokens_to_ids("<|image_pad|>"),
)
expected_position_ids = _get_expected_position_ids(
packing_params,
data_collator.get_rope_func,
features[0]["input_ids"],
features[0]["attention_mask"],
)
batch_input = data_collator(features) # [3, bsz, seq_len]
valid_len = expected_position_ids.shape[-1]
assert batch_input["position_ids"][1:, :, :valid_len].eq(expected_position_ids).all()
@pytest.mark.runs_on(["cpu"]) @pytest.mark.runs_on(["cpu"])
def test_4d_attention_mask(): def test_4d_attention_mask():
o = 0.0 o = 0.0

View File

@@ -1,2 +1,2 @@
# change if test fails or cache is outdated # change if test fails or cache is outdated
0.9.5.106 0.9.5.107

View File

@@ -172,3 +172,33 @@ def _manage_distributed_env(request: FixtureRequest, monkeypatch: MonkeyPatch) -
monkeypatch.setattr(torch.cuda, "device_count", lambda: 1)
elif CURRENT_DEVICE == "npu":
    monkeypatch.setattr(torch.npu, "device_count", lambda: 1)
@pytest.fixture(scope="session", autouse=True)
def bypass_mistral_regex_check():
"""Disable Mistral regex network check.
Monkey-patch TokenizersBackend._patch_mistral_regex into a no-op.
"""
try:
from transformers.tokenization_utils_fast import TokenizersBackend
except ImportError:
# Very old transformers, nothing to patch
yield
return
if not hasattr(TokenizersBackend, "_patch_mistral_regex"):
# Method does not exist in this version
yield
return
# Backup original method
original = TokenizersBackend._patch_mistral_regex
# Replace with no-op
TokenizersBackend._patch_mistral_regex = lambda cls, tokenizer, *args, **kwargs: tokenizer
yield
# Restore original method
TokenizersBackend._patch_mistral_regex = original
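The fixture above follows the classic backup/patch/restore pattern. As a minimal, self-contained sketch (the `patched_attr` helper and `Demo` class are hypothetical, not part of the code base), the same idea can be written as a generic context manager:

```python
from contextlib import contextmanager

@contextmanager
def patched_attr(owner, name, replacement):
    # Back up the original attribute, install the replacement,
    # and restore the original on exit (even if the body raises).
    original = getattr(owner, name)
    setattr(owner, name, replacement)
    try:
        yield
    finally:
        setattr(owner, name, original)

class Demo:
    @staticmethod
    def greet():
        return "real"

with patched_attr(Demo, "greet", staticmethod(lambda: "noop")):
    assert Demo.greet() == "noop"  # replacement active inside the block
assert Demo.greet() == "real"  # original restored afterwards
```

The `try/finally` guarantees restoration even when the patched code raises, which the yield-based fixture achieves only for tests that complete teardown.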

View File

@@ -0,0 +1,156 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import pytest
from peft import LoraConfig, PeftModel, get_peft_model
from transformers import AutoModelForCausalLM, AutoTokenizer
from llamafactory.v1.plugins.model_plugins import peft as peft_module
from llamafactory.v1.plugins.model_plugins.peft import merge_and_export_model
TINY_MODEL = "llamafactory/tiny-random-qwen3"
@pytest.fixture(scope="module")
def model_path():
return TINY_MODEL
@pytest.fixture(scope="function")
def model(model_path):
return AutoModelForCausalLM.from_pretrained(model_path)
@pytest.fixture(scope="function")
def tokenizer(model_path):
return AutoTokenizer.from_pretrained(model_path)
@pytest.fixture(scope="function")
def adapter_path(tmp_path):
# Create a dummy adapter
lora_config = LoraConfig(
r=8,
lora_alpha=16,
target_modules=["q_proj", "v_proj"],
lora_dropout=0.05,
bias="none",
task_type="CAUSAL_LM",
)
base_model = AutoModelForCausalLM.from_pretrained(TINY_MODEL)
peft_model = get_peft_model(base_model, lora_config)
save_path = tmp_path / "test_adapter"
peft_model.save_pretrained(save_path)
return str(save_path)
def test_find_all_linear_modules(model):
"""Verify linear modules are discoverable and include q_proj / v_proj for tiny-random-qwen3."""
modules = peft_module._find_all_linear_modules(model)
expected_subset = {"q_proj", "v_proj"}
assert expected_subset.issubset(set(modules))
def test_get_lora_model(model):
"""Verify a PeftModel is returned and LoRA config takes effect."""
config = {"name": "lora", "r": 8, "target_modules": "all", "lora_alpha": 16}
model = peft_module.get_lora_model(model, config, is_train=True)
assert isinstance(model, PeftModel)
assert model.peft_config["default"].r == 8
assert "q_proj" in model.peft_config["default"].target_modules
def test_get_freeze_model_layers(model):
"""Verify layer-wise freezing: only the last layer stays trainable."""
# Freeze all but last layer
config = {"name": "freeze", "freeze_trainable_layers": 1, "freeze_trainable_modules": "all"}
# Ensure we start with something known
model = peft_module.get_freeze_model(model, config, is_train=True)
num_layers = model.config.num_hidden_layers
assert num_layers > 0
for name, param in model.named_parameters():
if f"layers.{num_layers - 1}" in name:
assert param.requires_grad, f"{name} should be trainable"
elif "layers.0" in name and num_layers > 1:
assert not param.requires_grad, f"{name} should be frozen"
def test_get_freeze_model_modules(model):
"""Verify module-wise freezing: only last-layer self_attn is trainable."""
# Freeze specific modules (e.g. only self_attn)
config = {"name": "freeze", "freeze_trainable_layers": 1, "freeze_trainable_modules": "self_attn"}
model = peft_module.get_freeze_model(model, config, is_train=True)
num_layers = model.config.num_hidden_layers
for name, param in model.named_parameters():
if f"layers.{num_layers - 1}" in name and "self_attn" in name:
assert param.requires_grad, f"{name} should be trainable"
else:
assert not param.requires_grad, f"{name} should be frozen"
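The freezing behavior these tests assert boils down to toggling `requires_grad` per parameter. A minimal sketch of the same pattern on a toy model (a plain `nn.Sequential`, not the plugin's actual implementation):

```python
import torch.nn as nn

model = nn.Sequential(nn.Linear(2, 2), nn.Linear(2, 2))

# Freeze everything, then re-enable only the last layer -- the same
# shape of logic the freeze plugin applies per transformer layer.
for p in model.parameters():
    p.requires_grad_(False)
for p in model[-1].parameters():
    p.requires_grad_(True)

assert not any(p.requires_grad for p in model[0].parameters())
assert all(p.requires_grad for p in model[-1].parameters())
```

An optimizer built afterwards should receive only `filter(lambda p: p.requires_grad, model.parameters())`, so frozen weights never accumulate updates.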
def test_load_adapter_single_for_inference(model, adapter_path):
"""Verify single adapter is merged+unloaded in inference mode."""
# Test loading single adapter for inference (merge and unload)
model_result = peft_module.load_adapter(model, adapter_path, is_train=False)
assert not isinstance(model_result, PeftModel)
def test_load_adapter_resume_train(model, adapter_path):
"""Verify training mode returns a trainable PeftModel."""
# Test loading for training
model_result = peft_module.load_adapter(model, adapter_path, is_train=True)
assert isinstance(model_result, PeftModel)
def test_load_adapter_train_multiple_disallowed(model, adapter_path):
"""Verify multiple adapters are rejected in training mode."""
with pytest.raises(ValueError, match="only a single LoRA adapter"):
peft_module.load_adapter(model, [adapter_path, adapter_path], is_train=True)
def test_load_adapter_infer_multiple_merges(model, adapter_path):
"""Verify multiple adapters are merged in inference mode."""
# Test merging multiple adapters
model_result = peft_module.load_adapter(model, [adapter_path, adapter_path], is_train=False)
assert not isinstance(model_result, PeftModel)
def test_merge_and_export_model(tmp_path, adapter_path):
"""Verify merge_and_export_model produces export artifacts."""
export_dir = tmp_path / "export"
args_dict = {
"model": TINY_MODEL,
"peft_config": {
"name": "lora",
"adapter_name_or_path": adapter_path,
"export_dir": str(export_dir),
"export_size": 1,
"infer_dtype": "float16",
},
}
merge_and_export_model(args_dict)
assert export_dir.exists()
assert (export_dir / "config.json").exists()
assert (export_dir / "model.safetensors").exists()
assert (export_dir / "tokenizer_config.json").exists()

View File

@@ -0,0 +1,51 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import pytest
from llamafactory.v1.config.model_args import ModelArguments
from llamafactory.v1.core.model_engine import ModelEngine
bitsandbytes = pytest.importorskip("bitsandbytes")
def check_quantization_status(model):
quantized_info = {"bnb": []}
for name, module in model.named_modules():
# check BitsAndBytes quantization
if isinstance(module, bitsandbytes.nn.modules.Linear8bitLt) or isinstance(
module, bitsandbytes.nn.modules.Linear4bit
):
quantized_info["bnb"].append(name)
return quantized_info
@pytest.mark.runs_on(["cuda"])
@pytest.mark.parametrize("name, quantization_bit", [("bnb", 4), ("auto", 4)])
def test_quantization_plugin(name, quantization_bit):
model_args = ModelArguments(
model="llamafactory/tiny-random-qwen3",
quant_config={
"name": name,
"quantization_bit": quantization_bit,
},
)
model_engine = ModelEngine(model_args=model_args)
quantized_info = check_quantization_status(model_engine.model)
print(f"Quantized weights for method {name} with {quantization_bit} bit: {quantized_info}")
assert any(v for v in quantized_info.values()), "model is not quantized properly."

View File

@@ -0,0 +1,104 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Unit tests: FSDP2 meta-device loading vs normal loading consistency.
Validates that the FSDP2 meta loading path behaves correctly for tied weights
and non-persistent buffers by comparing it with the standard non-meta path.
"""
import torch
from transformers import AutoConfig
from llamafactory.v1.accelerator.interface import DistributedInterface
from llamafactory.v1.config.arg_parser import get_args
from llamafactory.v1.core.model_engine import ModelEngine
from llamafactory.v1.plugins.trainer_plugins.distributed.fsdp2 import FSDP2Engine
TINY_MODEL = "llamafactory/tiny-random-qwen3"
def collect_non_persistent_buffers(model):
"""Collect all non-persistent buffers from model."""
result = {}
for mod_name, module in model.named_modules():
for buf_name in getattr(module, "_non_persistent_buffers_set", set()):
fqn = f"{mod_name}.{buf_name}" if mod_name else buf_name
buf = getattr(module, buf_name, None)
if buf is not None:
result[fqn] = buf.detach().cpu().clone()
return result
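The helper exists because buffers registered with `persistent=False` are excluded from `state_dict()`, so a checkpoint load cannot restore them and they must be recomputed or carried over explicitly. A minimal sketch (the `Rotary` module is a stand-in, not the tested model):

```python
import torch
import torch.nn as nn

class Rotary(nn.Module):
    def __init__(self):
        super().__init__()
        # persistent=False keeps the buffer out of state_dict, mirroring
        # buffers like rotary inv_freq in real transformer models.
        self.register_buffer("inv_freq", torch.ones(4), persistent=False)

m = Rotary()
assert "inv_freq" in m._non_persistent_buffers_set  # tracked by nn.Module
assert "inv_freq" not in m.state_dict()  # absent from checkpoints
```

This is exactly the `_non_persistent_buffers_set` attribute the collector above iterates over.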
def test_fsdp2_meta_loading_buffers_and_tied_weights():
"""Verify non-persistent buffers and tied weights consistency after meta load."""
# 1. Initialize DistributedInterface for single process
DistributedInterface()
# 2. Build FSDP2Engine config
engine = FSDP2Engine(
{
"name": "fsdp2",
"mixed_precision": "bf16",
"reshard_after_forward": True,
"offload_params": False,
"pin_memory": False,
"dcp_path": None,
}
)
config = AutoConfig.from_pretrained(TINY_MODEL)
# --- NORMAL PATH ---
normal_args, *_ = get_args(dict(model=TINY_MODEL, init_config=None))
normal_engine = ModelEngine(model_args=normal_args)
normal_model = normal_engine.model.to(torch.bfloat16)
normal_model = engine.shard_model(normal_model)
normal_non_persistent = collect_non_persistent_buffers(normal_model)
del normal_model
# --- META PATH ---
meta_args, *_ = get_args(dict(model=TINY_MODEL, init_config={"name": "init_on_meta"}))
meta_model_engine = ModelEngine(model_args=meta_args)
meta_model = meta_model_engine.model
assert meta_model.device.type == "meta", "Model should be on meta device"
# Process meta device: save buffers -> tie_weights -> load from checkpoint -> restore buffers
meta_model = engine.shard_model(meta_model)
meta_non_persistent = collect_non_persistent_buffers(meta_model)
# 3. Tied weights (embed_tokens.weight and lm_head.weight)
tie_word_embeddings = getattr(config, "tie_word_embeddings", False)
if tie_word_embeddings:
assert meta_model.lm_head.weight is meta_model.model.embed_tokens.weight, (
"Weights should be tied after loading"
)
del meta_model
# 4. Non-persistent buffers (e.g., inv_freq)
normal_buf_keys = set(normal_non_persistent.keys())
meta_buf_keys = set(meta_non_persistent.keys())
assert normal_buf_keys == meta_buf_keys, "Non-persistent buffer keys mismatch"
for key in sorted(normal_buf_keys & meta_buf_keys):
nb = normal_non_persistent[key]
mb = meta_non_persistent[key]
assert nb.shape == mb.shape, f"Buffer shape mismatch: {key}"
assert torch.allclose(nb.float(), mb.float(), atol=1e-5), f"Buffer value mismatch: {key}"
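The two conditions this test checks, meta-device materialization and weight tying, can each be reproduced in isolation. A minimal sketch using toy modules (not the engine's actual loading path):

```python
import torch
import torch.nn as nn

# Parameters created under the meta device allocate no storage and must
# be materialized (e.g. from a checkpoint) before use.
with torch.device("meta"):
    layer = nn.Linear(4, 4)
assert layer.weight.device.type == "meta"

# to_empty() moves the module to a real device with uninitialized
# storage; a checkpoint load would then fill the values.
layer = layer.to_empty(device="cpu")
assert layer.weight.device.type == "cpu"

# Weight tying is identity of Parameter objects, not value equality --
# which is what the `lm_head.weight is embed_tokens.weight` assert checks.
emb = nn.Embedding(10, 4)
head = nn.Linear(4, 10, bias=False)
head.weight = emb.weight
assert head.weight is emb.weight
```

Because tying is object identity, any loading path that replaces one of the tied tensors (as a naive meta materialization can) silently untties them, which is why the test asserts identity after the full shard/load sequence.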