28 Commits

Author SHA1 Message Date
Kingsley
ffbff33af3 chore: mca workflow compatible with qwen-vl series (#10303) 2026-03-22 02:28:52 +08:00
Kingsley
833f6027b1 [fix] fit neat_packing & mrope model packing (#10283)
Co-authored-by: Yaowei Zheng <hiyouga@buaa.edu.cn>
2026-03-20 16:50:11 +08:00
robertglools
d91d8af89e [data] add SGSC zero-hallucination B2B dataset (NOO-Protocol) (#10284)
Co-authored-by: GloolsGuan <GloolsGuan@gmail.com>
2026-03-20 15:49:03 +08:00
xxddccaa
e67ab9e2f2 fix:MiniCPMVPlugin IndexError in process_messages when training with video (#10276)
Co-authored-by: xxddccaa <xxddccaa@users.noreply.github.com>
2026-03-18 19:18:06 +08:00
LincolnBurrows2017
2c4f121817 [fix] handle empty content list in system message (#10291)
Co-authored-by: AI Assistant <assistant@example.com>
2026-03-18 12:05:49 +08:00
xvxuopop
487f8b8191 [v1] add qwen3 templates and fix rendering plugin. (#10212)
Co-authored-by: Yaowei Zheng <hiyouga@buaa.edu.cn>
2026-03-18 11:30:50 +08:00
SnowCharm
78cad1e332 [fix] unused keys in ray example (#10290) 2026-03-18 00:23:53 +08:00
LincolnBurrows2017
70653026f5 [fix] make position_id_per_seconds configurable for Qwen2OmniPlugin (#10281)
Co-authored-by: LincolnBurrows2017 <lincoln@example.com>
2026-03-16 19:42:38 +08:00
Ruijie Hou
246192abd2 [data] correct gpt_oss template format_assistant (#10269) 2026-03-10 21:36:38 +08:00
浮梦
0258dc14d0 [docker] update npu docker (#10268)
Co-authored-by: frozenleaves <frozen@Mac.local>
2026-03-10 19:37:27 +08:00
xxddccaa
3045adf0ba [fix] fallback to audio_processor when feature_extractor is missing (#10267)
Co-authored-by: kevin <742971636@qq.com>
2026-03-10 19:36:41 +08:00
Kingsley
a3d44e3152 [mca] support qwen3.5 (#10265)
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
2026-03-10 10:55:16 +08:00
JiangNan
edeb953bc7 [data] convert filter() to list in read_cloud_json to fix broken empty-check (#10260)
Signed-off-by: JiangNan <1394485448@qq.com>
2026-03-09 17:12:53 +08:00
yizhouChen
d045794387 [docs] fix Python version requirement from 3.10 to >=3.11.0 (#10259)
Co-authored-by: chaiyzh <chaiyzh@126.com>
2026-03-09 16:44:07 +08:00
pyx
9501c3308a [train] fix compatibility issue with HuggingFace Dataset Column when sav… (#10254)
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
2026-03-06 18:44:57 +08:00
jiaqiw09
0ee1c42c2b [v1] Support meta loading for full and free (#10236) 2026-03-05 23:15:27 +08:00
SnowCharm
3061f48d55 [ray] fix get ray head ip (#10252) 2026-03-05 23:14:38 +08:00
LittleYanlin
2d9bd2aa14 [fix] qwen3.5 projector path (#10242)
Co-authored-by: Yaowei Zheng <hiyouga@buaa.edu.cn>
2026-03-04 01:31:09 +08:00
Hertz
c0245c43fc [model] support Qwen3.5 all series models (#10237)
Co-authored-by: gatilin <gatilin@tencent.com>
Co-authored-by: Yaowei Zheng <hiyouga@buaa.edu.cn>
2026-03-03 17:34:59 +08:00
Parag Ekbote
eb976d75a2 [tracker] Add Trackio Integration for LlamaFactory (#10165)
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
2026-03-03 17:19:37 +08:00
Yaowei Zheng
b5cb7cb0e6 [misc] fix constants (#10232) 2026-03-02 11:10:48 +08:00
Philip Ottesen
0779846513 [infer] support mixed multimodal payloads (#10225)
Signed-off-by: Philip Ottesen <phiott256@gmail.com>
2026-02-28 20:26:53 +08:00
jiaqiw09
45d335c709 [v1] add seed for training and fix gradient checkpointing (#10211) 2026-02-28 18:16:06 +08:00
Kingsley
816480012f [fix] register visual part for Qwen3.5 (#10227) 2026-02-28 16:39:24 +08:00
Mikko Tukiainen
d3bf882e87 [docker] upgrade to ROCm 7.2 base image, drop PyTorch reinstall (#10223)
Co-authored-by: Mikko Tukiainen <mtukiain@chi-mi300x-012.ord.vultr.cpe.ice.amd.com>
2026-02-27 20:16:33 +08:00
娄宗志
589da21d32 [model] support Aeva (#10214) 2026-02-26 23:03:13 +08:00
Yaowei Zheng
122cd46084 [model] update constants (#10220) 2026-02-26 21:13:56 +08:00
浮梦
2b8b871475 [model] Adapt Qwen3.5 (#10213)
Co-authored-by: frozenleaves <frozen@Mac.local>
Co-authored-by: Yaowei Zheng <hiyouga@buaa.edu.cn>
2026-02-26 20:45:02 +08:00
57 changed files with 1813 additions and 441 deletions

View File

@@ -35,15 +35,12 @@ jobs:
         transformers:
           - ""
         include: # test backward compatibility
-          - python: "3.11"
-            os: "ubuntu-latest"
-            transformers: "4.51.0"
-          - python: "3.11"
-            os: "ubuntu-latest"
-            transformers: "4.53.0"
           - python: "3.11"
             os: "ubuntu-latest"
             transformers: "4.55.0"
+          - python: "3.11"
+            os: "ubuntu-latest"
+            transformers: "4.57.1"
     runs-on: ${{ matrix.os }}

View File

@@ -291,7 +291,7 @@ Read technical notes:
 | [GPT-2](https://huggingface.co/openai-community) | 0.1B/0.4B/0.8B/1.5B | - |
 | [GPT-OSS](https://huggingface.co/openai) | 20B/120B | gpt_oss |
 | [Granite 3-4](https://huggingface.co/ibm-granite) | 1B/2B/3B/7B/8B | granite3/granite4 |
-| [Hunyuan/Hunyuan1.5 (MT)](https://huggingface.co/tencent/) | 0.5B/1.8B/4B/7B/13B | hunyuan/hunyuan_small |
+| [Hunyuan/Hunyuan1.5 (MT)](https://huggingface.co/tencent/) | 0.5B/1.8B/4B/7B/13B | hunyuan/hunyuan_small|
 | [InternLM 2-3](https://huggingface.co/internlm) | 7B/8B/20B | intern2 |
 | [InternVL 2.5-3.5](https://huggingface.co/OpenGVLab) | 1B/2B/4B/8B/14B/30B/38B/78B/241B | intern_vl |
 | [Intern-S1-mini](https://huggingface.co/internlm/) | 8B | intern_s1 |
@@ -319,6 +319,7 @@ Read technical notes:
 | [Pixtral](https://huggingface.co/mistralai) | 12B | pixtral |
 | [Qwen2 (Code/Math/MoE/QwQ)](https://huggingface.co/Qwen) | 0.5B/1.5B/3B/7B/14B/32B/72B/110B | qwen |
 | [Qwen3 (MoE/Instruct/Thinking/Next)](https://huggingface.co/Qwen) | 0.6B/1.7B/4B/8B/14B/32B/80B/235B | qwen3/qwen3_nothink |
+| [Qwen3.5](https://huggingface.co/Qwen) | 0.8B/2B/4B/9B/27B/35B/122B/397B | qwen3_5 |
 | [Qwen2-Audio](https://huggingface.co/Qwen) | 7B | qwen2_audio |
 | [Qwen2.5-Omni](https://huggingface.co/Qwen) | 3B/7B | qwen2_omni |
 | [Qwen3-Omni](https://huggingface.co/Qwen) | 30B | qwen3_omni |
@@ -472,7 +473,7 @@ huggingface-cli login
 | Mandatory | Minimum | Recommend |
 | ------------ | ------- | --------- |
-| python | 3.9 | 3.10 |
+| python | 3.11 | >=3.11 |
 | torch | 2.0.0 | 2.6.0 |
 | torchvision | 0.15.0 | 0.21.0 |
 | transformers | 4.49.0 | 4.50.0 |

View File

@@ -293,7 +293,7 @@ https://github.com/user-attachments/assets/43b700c6-a178-41db-b1f8-8190a5d3fcfc
 | [GPT-2](https://huggingface.co/openai-community) | 0.1B/0.4B/0.8B/1.5B | - |
 | [GPT-OSS](https://huggingface.co/openai) | 20B/120B | gpt_oss |
 | [Granite 3-4](https://huggingface.co/ibm-granite) | 1B/2B/3B/7B/8B | granite3/granite4 |
-| [Hunyuan/Hunyuan1.5 (MT)](https://huggingface.co/tencent/) | 0.5B/1.8B/4B/7B/13B | hunyuan/hunyuan_small |
+| [Hunyuan/Hunyuan1.5 (MT)](https://huggingface.co/tencent/) | 0.5B/1.8B/4B/7B/13B | hunyuan/hunyuan_small|
 | [InternLM 2-3](https://huggingface.co/internlm) | 7B/8B/20B | intern2 |
 | [InternVL 2.5-3.5](https://huggingface.co/OpenGVLab) | 1B/2B/4B/8B/14B/30B/38B/78B/241B | intern_vl |
 | [Intern-S1-mini](https://huggingface.co/internlm/) | 8B | intern_s1 |
@@ -321,6 +321,7 @@ https://github.com/user-attachments/assets/43b700c6-a178-41db-b1f8-8190a5d3fcfc
 | [Pixtral](https://huggingface.co/mistralai) | 12B | pixtral |
 | [Qwen2 (Code/Math/MoE/QwQ)](https://huggingface.co/Qwen) | 0.5B/1.5B/3B/7B/14B/32B/72B/110B | qwen |
 | [Qwen3 (MoE/Instruct/Thinking/Next)](https://huggingface.co/Qwen) | 0.6B/1.7B/4B/8B/14B/32B/80B/235B | qwen3/qwen3_nothink |
+| [Qwen3.5](https://huggingface.co/Qwen) | 0.8B/2B/4B/9B/27B/35B/122B/397B | qwen3_5 |
 | [Qwen2-Audio](https://huggingface.co/Qwen) | 7B | qwen2_audio |
 | [Qwen2.5-Omni](https://huggingface.co/Qwen) | 3B/7B | qwen2_omni |
 | [Qwen3-Omni](https://huggingface.co/Qwen) | 30B | qwen3_omni |
@@ -474,7 +475,7 @@ huggingface-cli login
 | 必需项 | 至少 | 推荐 |
 | ------------ | ------- | --------- |
-| python | 3.9 | 3.10 |
+| python | 3.11 | >=3.11 |
 | torch | 2.0.0 | 2.6.0 |
 | torchvision | 0.15.0 | 0.21.0 |
 | transformers | 4.49.0 | 4.50.0 |

View File

@@ -236,6 +236,13 @@
     "ms_hub_url": "AI-ModelScope/sharegpt_gpt4",
     "formatting": "sharegpt"
   },
+  "sgsc_b2b_entities": {
+    "hf_hub_url": "Nooxus-AI/NOO-Verified-Global-Entities",
+    "formatting": "sharegpt",
+    "columns": {
+      "messages": "messages"
+    }
+  },
   "ultrachat_200k": {
     "hf_hub_url": "HuggingFaceH4/ultrachat_200k",
     "ms_hub_url": "AI-ModelScope/ultrachat_200k",
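The new `sgsc_b2b_entities` entry maps the dataset's `messages` column onto the sharegpt format. A minimal sketch of what that column mapping implies when a record is read back (the sample record and helper below are illustrative, not taken from the actual dataset or loader):

```python
# Sketch: pull the conversation out of a sharegpt-style record via the
# "columns" mapping declared in dataset_info.json. Record shape is hypothetical.
def extract_conversation(record: dict, messages_key: str = "messages") -> list[dict]:
    """Return the list of {"from", "value"} turns named by the column mapping."""
    turns = record.get(messages_key, [])
    for turn in turns:
        if "from" not in turn or "value" not in turn:
            raise ValueError(f"malformed sharegpt turn: {turn}")
    return turns

record = {
    "messages": [
        {"from": "human", "value": "Who owns entity X?"},
        {"from": "gpt", "value": "Entity X is owned by Y."},
    ]
}
conv = extract_conversation(record)
print(len(conv))  # 2
```

The `"columns": {"messages": "messages"}` block is what lets the loader find the conversation even when the hub column name differs from the default.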

View File

@@ -1,6 +1,6 @@
# https://hub.docker.com/r/ascendai/cann/tags # https://hub.docker.com/r/ascendai/cann/tags
ARG BASE_IMAGE=quay.io/ascend/cann:8.3.rc2-910b-ubuntu22.04-py3.11 ARG BASE_IMAGE=quay.io/ascend/cann:8.5.1-910b-ubuntu22.04-py3.11
FROM ${BASE_IMAGE} FROM ${BASE_IMAGE}
# Installation arguments # Installation arguments
@@ -33,9 +33,11 @@ RUN pip config set global.index-url "${PIP_INDEX}" && \
COPY . /app COPY . /app
# Install torch-npu # Install torch-npu
RUN pip uninstall -y torch torchvision torchaudio && \ RUN source /usr/local/Ascend/ascend-toolkit/set_env.sh
pip install --no-cache-dir "torch==2.7.1" "torch-npu==2.7.1" "torchvision==0.22.1" "torchaudio==2.7.1" --index-url "${PYTORCH_INDEX}" && \ RUN pip uninstall -y torch torchvision torchaudio
pip install --no-cache-dir -e . --no-build-isolation && \ RUN pip install --no-cache-dir -r requirements/npu.txt --index-url "${PYTORCH_INDEX}"
RUN pip install --no-cache-dir -r requirements/deepspeed.txt
RUN pip install --no-cache-dir -e . --no-build-isolation && \
pip install --no-cache-dir -r requirements/metrics.txt --no-build-isolation pip install --no-cache-dir -r requirements/metrics.txt --no-build-isolation
# Set up volumes # Set up volumes

View File

@@ -33,7 +33,7 @@ services:
       dockerfile: ./docker/docker-npu/Dockerfile
       context: ../..
       args:
-        BASE_IMAGE: quay.io/ascend/cann:8.3.rc2-a3-ubuntu22.04-py3.11
+        BASE_IMAGE: quay.io/ascend/cann:8.5.1-a3-ubuntu22.04-py3.11
         PIP_INDEX: https://pypi.org/simple
     container_name: llamafactory-a3
     image: llamafactory:npu-a3

View File

@@ -1,12 +1,12 @@
 # https://hub.docker.com/r/rocm/pytorch/tags
-ARG BASE_IMAGE=rocm/pytorch:rocm6.4.1_ubuntu22.04_py3.10_pytorch_release_2.6.0
+# ROCm 7.2 + PyTorch 2.7.1 (Python 3.12). Keep base image's PyTorch; do not reinstall.
+ARG BASE_IMAGE=rocm/pytorch:rocm7.2_ubuntu24.04_py3.12_pytorch_release_2.7.1
 FROM ${BASE_IMAGE}
 
 # Installation arguments
 ARG PIP_INDEX=https://pypi.org/simple
 ARG INSTALL_FLASHATTN=false
 ARG HTTP_PROXY=""
-ARG PYTORCH_INDEX=https://download.pytorch.org/whl/rocm6.3
 
 # Define environments
 ENV MAX_JOBS=16
@@ -32,10 +32,9 @@ RUN pip config set global.index-url "${PIP_INDEX}" && \
 # Copy the application into the image
 COPY . /app
 
-# Reinstall pytorch rocm and install LLaMA Factory
-RUN pip uninstall -y torch torchvision torchaudio && \
-    pip install --no-cache-dir --no-build-isolation -e --pre . --index-url "${PYTORCH_INDEX}" && \
-    pip install --no-cache-dir --no-build-isolation -r requirements/metrics.txt -r requirements/deepspeed.txt --index-url "${PYTORCH_INDEX}"
+# Install LLaMA Factory (use base image's PyTorch/ROCm; do not reinstall)
+RUN pip install --no-cache-dir -e . --pre && \
+    pip install --no-cache-dir -r requirements/deepspeed.txt -r requirements/liger-kernel.txt -r requirements/bitsandbytes.txt
 
 # Rebuild flash attention
 RUN if [ "${INSTALL_FLASHATTN}" == "true" ]; then \
View File

@@ -47,4 +47,3 @@
   border-color: rgba(255, 255, 255, 0.45);
   box-shadow: 0 0 0 3px rgba(255, 255, 255, 0.12);
 }
-

View File

@@ -1,33 +1,31 @@
 # Configuration file for the Sphinx documentation builder.
-import os
-import sys
 
 # Define common settings here
-project = 'LlamaFactory'
-copyright = '2024, LlamaFactory Team'
-author = 'LlamaFactory Team'
+project = "LlamaFactory"
+copyright = "2024, LlamaFactory Team"
+author = "LlamaFactory Team"
 
 extensions = [
-    'sphinx.ext.autodoc',
-    'sphinx.ext.viewcode',
-    'sphinx.ext.napoleon',
-    'myst_parser',
+    "sphinx.ext.autodoc",
+    "sphinx.ext.viewcode",
+    "sphinx.ext.napoleon",
+    "myst_parser",
 ]
 
-templates_path = ['_templates']
-exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store']
+templates_path = ["_templates"]
+exclude_patterns = ["_build", "Thumbs.db", ".DS_Store"]
 
-html_theme = 'sphinx_rtd_theme'
-html_static_path = ['_static']
+html_theme = "sphinx_rtd_theme"
+html_static_path = ["_static"]
 
 html_js_files = [
-    'js/switcher.js',
+    "js/switcher.js",
 ]
 
 html_css_files = [
-    'css/lang-switcher.css',
+    "css/lang-switcher.css",
 ]
 
 myst_enable_extensions = [

View File

@@ -1,20 +1,22 @@
import os import os
import sys import sys
# Add parent dir to path to allow importing conf.py
sys.path.insert(0, os.path.abspath('..'))
from conf import * # Add parent dir to path to allow importing conf.py
sys.path.insert(0, os.path.abspath(".."))
from conf import * # noqa: F403
# Language settings # Language settings
language = 'en' language = "en"
html_search_language = 'en' html_search_language = "en"
# Static files # Static files
# Point to the root _static directory # Point to the root _static directory
html_static_path = ['../_static'] html_static_path = ["../_static"]
# Add custom JS for language switcher # Add custom JS for language switcher
html_js_files = [ html_js_files = [
'js/switcher.js', "js/switcher.js",
] ]

View File

@@ -1,20 +1,22 @@
import os import os
import sys import sys
# Add parent dir to path to allow importing conf.py
sys.path.insert(0, os.path.abspath('..'))
from conf import * # Add parent dir to path to allow importing conf.py
sys.path.insert(0, os.path.abspath(".."))
from conf import * # noqa: F403
# Language settings # Language settings
language = 'zh_CN' language = "zh_CN"
html_search_language = 'zh' html_search_language = "zh"
# Static files # Static files
# Point to the root _static directory # Point to the root _static directory
html_static_path = ['../_static'] html_static_path = ["../_static"]
# Add custom JS for language switcher # Add custom JS for language switcher
html_js_files = [ html_js_files = [
'js/switcher.js', "js/switcher.js",
] ]

View File

@@ -28,12 +28,7 @@ save_only_model: false
 report_to: none # choices: [none, wandb, tensorboard, swanlab, mlflow]
 
 ### ray
-ray_run_name: qwen3_4b_sft_lora
-ray_storage_path: ./saves
 ray_num_workers: 4 # Number of GPUs to use.
-placement_strategy: PACK
-resources_per_worker:
-  GPU: 1
 # ray_init_kwargs:
 #   runtime_env:
 #     env_vars:

View File

@@ -22,4 +22,3 @@ cutoff_len: 2048
 learning_rate: 1.0e-4
 bf16: true
 max_steps: 10
-

View File

@@ -14,16 +14,12 @@ dist_config:
   name: fsdp2
   dcp_path: null # /mnt/f/pretrain_models/Qwen3-0.6B-dcp
-init_config:
-  name: init_on_meta
 
 ### data
 train_dataset: data/v1_sft_demo.yaml
 
 ### training
 output_dir: outputs/test_fsdp2
 micro_batch_size: 1
-global_batch_size: 1
 cutoff_len: 2048
 learning_rate: 1.0e-4
 bf16: false

View File

@@ -40,7 +40,7 @@ dependencies = [
     "torch>=2.4.0",
     "torchvision>=0.19.0",
     "torchaudio>=2.4.0",
-    "transformers>=4.51.0,<=5.0.0,!=4.52.0,!=4.57.0",
+    "transformers>=4.55.0,<=5.2.0,!=4.52.0,!=4.57.0",
     "datasets>=2.16.0,<=4.0.0",
     "accelerate>=1.3.0,<=1.11.0",
     "peft>=0.18.0,<=0.18.1",

View File

@@ -1,4 +1,4 @@
 torch==2.7.1
-torch-npu==2.7.1
+torch-npu==2.7.1.post2
 torchvision==0.22.1
 torchaudio==2.7.1

View File

@@ -71,6 +71,7 @@ def convert(
     pipeline_model_parallel_size: int = 1,
     expert_model_parallel_size: int = 1,
     virtual_pipeline_model_parallel_size: int | None = None,
+    moe_grouped_gemm: bool | None = None,
 ):
     """Convert checkpoint between MCA and HuggingFace formats.
@@ -84,6 +85,10 @@ def convert(
         pipeline_model_parallel_size: Pipeline model parallel size
         expert_model_parallel_size: Expert model parallel size
         virtual_pipeline_model_parallel_size: Virtual pipeline model parallel size
+        moe_grouped_gemm: Use grouped gemm for MoE experts. When enabled, expert
+            weights are stored in a flattened format (linear_fc1.weight0, weight1, ...)
+            rather than per-expert format (local_experts.0.linear_fc1.weight, ...).
+            Must match the format used when saving the checkpoint.
     """
     if bf16 and fp16:
         raise ValueError("bf16 and fp16 cannot be both True.")
@@ -97,8 +102,9 @@ def convert(
         pipeline_model_parallel_size=pipeline_model_parallel_size,
         expert_model_parallel_size=expert_model_parallel_size,
         virtual_pipeline_model_parallel_size=virtual_pipeline_model_parallel_size,
+        moe_grouped_gemm=moe_grouped_gemm,
+        transformer_impl="transformer_engine",  # hard code here since we default to using te for training
     )
     convert_checkpoint_to_mca(
         checkpoint_path,
         output_path,
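The two key layouts named in the `moe_grouped_gemm` docstring can be illustrated with a small key-rewriting sketch. The state-dict key strings below are schematic, following the two formats the docstring names; this is not the converter's actual code:

```python
import re

def group_expert_keys(keys: list[str]) -> list[str]:
    """Rewrite per-expert keys (local_experts.N.linear_fc1.weight) into the
    flattened grouped-GEMM naming (linear_fc1.weightN) described above."""
    pattern = re.compile(r"local_experts\.(\d+)\.(linear_fc[12])\.weight")
    return [pattern.sub(lambda m: f"{m.group(2)}.weight{m.group(1)}", k) for k in keys]

keys = [
    "mlp.experts.local_experts.0.linear_fc1.weight",
    "mlp.experts.local_experts.1.linear_fc1.weight",
]
print(group_expert_keys(keys))
# ['mlp.experts.linear_fc1.weight0', 'mlp.experts.linear_fc1.weight1']
```

Because the two layouts name tensors differently, a checkpoint saved in one format cannot be loaded under the other setting, which is why the flag must match the save-time value.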

View File

@@ -154,25 +154,24 @@ def vllm_infer(
         batch = train_dataset[i : min(i + batch_size, len(train_dataset))]
         for j in range(len(batch["input_ids"])):
+            multi_modal_data = {}
+            video_metadata_kwargs = None
             if batch["images"][j] is not None:
                 image = batch["images"][j]
-                multi_modal_data = {
-                    "image": template_obj.mm_plugin._regularize_images(
-                        image, image_max_pixels=image_max_pixels, image_min_pixels=image_min_pixels
-                    )["images"]
-                }
-            elif batch["videos"][j] is not None:
-                video_metadata, video_metadata_kwargs = None, None
+                multi_modal_data["image"] = template_obj.mm_plugin._regularize_images(
+                    image, image_max_pixels=image_max_pixels, image_min_pixels=image_min_pixels
+                )["images"]
+
+            if batch["videos"][j] is not None:
                 video = batch["videos"][j]
-                multi_modal_data = {
-                    "video": template_obj.mm_plugin._regularize_videos(
-                        video,
-                        image_max_pixels=image_max_pixels,
-                        image_min_pixels=image_min_pixels,
-                        video_fps=video_fps,
-                        video_maxlen=video_maxlen,
-                    )["videos"]
-                }
+                multi_modal_data["video"] = template_obj.mm_plugin._regularize_videos(
+                    video,
+                    image_max_pixels=image_max_pixels,
+                    image_min_pixels=image_min_pixels,
+                    video_fps=video_fps,
+                    video_maxlen=video_maxlen,
+                )["videos"]
                 if need_video_kwargs:
                     container = av.open(video[0], "r")
                     video_stream = next(stream for stream in container.streams if stream.type == "video")
@@ -192,18 +191,17 @@ def vllm_infer(
                         video_backend="opencv",
                     )
                     multi_modal_data["video"] = (multi_modal_data["video"], video_metadata)
-            elif batch["audios"][j] is not None:
+
+            if batch["audios"][j] is not None:
                 audio = batch["audios"][j]
                 audio_data = template_obj.mm_plugin._regularize_audios(
                     audio,
                     sampling_rate=16000,
                 )
-                multi_modal_data = {"audio": zip(audio_data["audios"], audio_data["sampling_rates"])}
-            else:
-                multi_modal_data = None
+                multi_modal_data["audio"] = zip(audio_data["audios"], audio_data["sampling_rates"])
 
-            vllm_input_data = {"prompt_token_ids": batch["input_ids"][j], "multi_modal_data": multi_modal_data}
-            if "video_metadata_kwargs" in locals() and video_metadata_kwargs is not None:
+            vllm_input_data = {"prompt_token_ids": batch["input_ids"][j], "multi_modal_data": multi_modal_data or None}
+            if video_metadata_kwargs is not None:
                 vllm_input_data["mm_processor_kwargs"] = video_metadata_kwargs
             vllm_inputs.append(vllm_input_data)

View File

@@ -88,7 +88,10 @@ def _process_request(
     if request.messages[0].role == Role.SYSTEM:
         content = request.messages.pop(0).content
-        system = content[0].text if isinstance(content, list) else content
+        if isinstance(content, list):
+            system = content[0].text if content else ""
+        else:
+            system = content
     else:
         system = None

View File

@@ -180,35 +180,32 @@ class VllmEngine(BaseEngine):
             else self.generating_args["skip_special_tokens"],
         )
 
+        multi_modal_data = {}
         if images is not None:  # add image features
-            multi_modal_data = {
-                "image": self.template.mm_plugin._regularize_images(
-                    images,
-                    image_max_pixels=self.model_args.image_max_pixels,
-                    image_min_pixels=self.model_args.image_min_pixels,
-                )["images"]
-            }
-        elif videos is not None:
-            multi_modal_data = {
-                "video": self.template.mm_plugin._regularize_videos(
-                    videos,
-                    image_max_pixels=self.model_args.video_max_pixels,
-                    image_min_pixels=self.model_args.video_min_pixels,
-                    video_fps=self.model_args.video_fps,
-                    video_maxlen=self.model_args.video_maxlen,
-                )["videos"]
-            }
-        elif audios is not None:
+            multi_modal_data["image"] = self.template.mm_plugin._regularize_images(
+                images,
+                image_max_pixels=self.model_args.image_max_pixels,
+                image_min_pixels=self.model_args.image_min_pixels,
+            )["images"]
+
+        if videos is not None:
+            multi_modal_data["video"] = self.template.mm_plugin._regularize_videos(
+                videos,
+                image_max_pixels=self.model_args.video_max_pixels,
+                image_min_pixels=self.model_args.video_min_pixels,
+                video_fps=self.model_args.video_fps,
+                video_maxlen=self.model_args.video_maxlen,
+            )["videos"]
+
+        if audios is not None:
             audio_data = self.template.mm_plugin._regularize_audios(
                 audios,
                 sampling_rate=self.model_args.audio_sampling_rate,
             )
-            multi_modal_data = {"audio": zip(audio_data["audios"], audio_data["sampling_rates"])}
-        else:
-            multi_modal_data = None
+            multi_modal_data["audio"] = zip(audio_data["audios"], audio_data["sampling_rates"])
 
         result_generator = self.model.generate(
-            {"prompt_token_ids": prompt_ids, "multi_modal_data": multi_modal_data},
+            {"prompt_token_ids": prompt_ids, "multi_modal_data": multi_modal_data or None},
             sampling_params=sampling_params,
             request_id=request_id,
             lora_request=self.lora_request,

View File

@@ -15,6 +15,8 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import copy
import inspect
from dataclasses import dataclass from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Literal, Optional from typing import TYPE_CHECKING, Any, Literal, Optional
@@ -24,7 +26,7 @@ import torch.nn.functional as F
from peft import PeftModel from peft import PeftModel
from transformers import DataCollatorForSeq2Seq from transformers import DataCollatorForSeq2Seq
from ..extras.constants import AUDIO_PLACEHOLDER, IGNORE_INDEX, IMAGE_PLACEHOLDER from ..extras.constants import AUDIO_PLACEHOLDER, IGNORE_INDEX, IMAGE_PLACEHOLDER, MROPE_MODELS
from ..extras.packages import is_pillow_available from ..extras.packages import is_pillow_available
@@ -38,6 +40,56 @@ if TYPE_CHECKING:
     from .template import Template
 
 
+def _slice_mm_inputs_for_sample(
+    mm_inputs: dict[str, Any],
+    batch_imglens: list[int],
+    batch_vidlens: list[int],
+    batch_idx: int,
+    images_per_subseq: Optional[list[int]] = None,
+    videos_per_subseq: Optional[list[int]] = None,
+    subseq_idx: Optional[int] = None,
+) -> dict[str, Any]:
+    r"""Slice mm_inputs for one batch sample, optionally for a single sub-sequence when packing.
+
+    image_grid_thw / video_grid_thw have shape [num_items, 3]. Indices for sample batch_idx
+    are batch_imglens[batch_idx] images and batch_vidlens[batch_idx] videos. When subseq_idx
+    is given, further restrict to that sub-seq's counts via packed_*_counts.
+
+    has_dummy_image=True means only batch[0] will be concated with fake image and no multimodal data.
+    """
+    image_start_idx = sum(batch_imglens[:batch_idx])
+    image_end_idx = sum(batch_imglens[: batch_idx + 1])
+    video_start_idx = sum(batch_vidlens[:batch_idx])
+    video_end_idx = sum(batch_vidlens[: batch_idx + 1])
+    if subseq_idx is not None and images_per_subseq is not None:
+        image_start_idx += sum(images_per_subseq[:subseq_idx])
+        image_end_idx = image_start_idx + images_per_subseq[subseq_idx]
+    if subseq_idx is not None and videos_per_subseq is not None:
+        video_start_idx += sum(videos_per_subseq[:subseq_idx])
+        video_end_idx = video_start_idx + videos_per_subseq[subseq_idx]
+
+    sliced_mm_inputs: dict[str, Any] = {}
+    key_to_slice_meta = {
+        "image_grid_thw": (image_start_idx, image_end_idx, True),
+        "video_grid_thw": (video_start_idx, video_end_idx, True),
+        "second_per_grid_ts": (video_start_idx, video_end_idx, False),  # qwen2.5vl
+        "video_second_per_grid": (video_start_idx, video_end_idx, False),  # qwen omni
+    }
+    for key, (start_idx, end_idx, assign_none_when_empty) in key_to_slice_meta.items():
+        if key not in mm_inputs:
+            continue
+        mm_value = mm_inputs[key]
+        if mm_value is not None and end_idx > start_idx:
+            sliced_mm_inputs[key] = mm_value[start_idx:end_idx]
+        elif assign_none_when_empty:
+            sliced_mm_inputs[key] = None
+
+    return sliced_mm_inputs
+
+
 def prepare_4d_attention_mask(attention_mask_with_indices: "torch.Tensor", dtype: "torch.dtype") -> "torch.Tensor":
     r"""Expand 2d attention mask to 4d attention mask.
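The index arithmetic in `_slice_mm_inputs_for_sample` is a prefix sum over per-sample item counts. A minimal standalone check of that arithmetic, using plain lists in place of tensors (this helper is a sketch, not part of the collator):

```python
def sample_slice(counts: list[int], batch_idx: int) -> tuple[int, int]:
    """Start/end indices of sample batch_idx's items in the flattened
    [num_items, 3] grid tensor, matching the sums in _slice_mm_inputs_for_sample."""
    start = sum(counts[:batch_idx])
    end = sum(counts[: batch_idx + 1])
    return start, end

# Three packed samples contributing 2, 0, and 3 images respectively.
batch_imglens = [2, 0, 3]
print(sample_slice(batch_imglens, 0))  # (0, 2)
print(sample_slice(batch_imglens, 1))  # (2, 2) -> empty slice, key assigned None
print(sample_slice(batch_imglens, 2))  # (2, 5)
```

The `end_idx > start_idx` check in the real function is what turns the empty slice for a sample with no images into an explicit `None`.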
@@ -105,9 +157,154 @@ class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
else: else:
self.get_rope_func = None self.get_rope_func = None
def _compute_rope_position_ids(
self, features: dict[str, "torch.Tensor"], mm_inputs: dict[str, Any]
) -> None:
r"""Compute position_ids and rope_deltas via get_rope_func for VLMs."""
rope_index_kwargs = {
"input_ids": features["input_ids"],
"image_grid_thw": mm_inputs.get("image_grid_thw"),
"video_grid_thw": mm_inputs.get("video_grid_thw"),
"attention_mask": (features["attention_mask"] >= 1).float(),
}
if features["attention_mask"].sum() == 0:
features["position_ids"] = torch.zeros((3, *features["input_ids"].shape))
features["rope_deltas"] = torch.zeros(features["input_ids"].shape[0])
return
if "mm_token_type_ids" in inspect.signature(self.get_rope_func).parameters:
image_token_id = getattr(self.model.config, "image_token_id", None)
video_token_id = getattr(self.model.config, "video_token_id", None)
if image_token_id is not None or video_token_id is not None:
mm_token_type_ids = torch.zeros_like(features["input_ids"])
if image_token_id is not None:
mm_token_type_ids[features["input_ids"] == image_token_id] = 1
if video_token_id is not None:
mm_token_type_ids[features["input_ids"] == video_token_id] = 2
rope_index_kwargs["mm_token_type_ids"] = mm_token_type_ids
if "second_per_grid_ts" in mm_inputs: # for qwen2vl
rope_index_kwargs["second_per_grid_ts"] = mm_inputs.get("second_per_grid_ts")
elif "video_second_per_grid" in mm_inputs: # for qwen2.5 omni
rope_index_kwargs["second_per_grids"] = mm_inputs.get("video_second_per_grid")
if getattr(self.model.config, "model_type", None) in ["qwen2_5_omni_thinker", "qwen3_omni_moe_thinker"]:
rope_index_kwargs["use_audio_in_video"] = getattr(self.processor, "use_audio_in_video", False)
feature_attention_mask = mm_inputs.get("feature_attention_mask", None)
if feature_attention_mask is not None: # FIXME: need to get video image lengths
audio_feature_lengths = torch.sum(feature_attention_mask, dim=1)
rope_index_kwargs["audio_seqlens"] = audio_feature_lengths # prepare for input
features["position_ids"], rope_deltas = self.get_rope_func(**rope_index_kwargs)
features["rope_deltas"] = rope_deltas - (1 - rope_index_kwargs["attention_mask"]).sum(
dim=-1
).unsqueeze(-1)
else: # for qwen vl
features["position_ids"], features["rope_deltas"] = self.get_rope_func(**rope_index_kwargs)
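The `mm_token_type_ids` construction above can be sketched in isolation; the token ids below are hypothetical placeholders, not the real config values:

```python
import torch

# Mark image tokens as 1, video tokens as 2, text as 0 (ids are made up).
IMAGE_TOKEN_ID, VIDEO_TOKEN_ID = 151655, 151656
input_ids = torch.tensor([[1, IMAGE_TOKEN_ID, IMAGE_TOKEN_ID, 2, VIDEO_TOKEN_ID]])

mm_token_type_ids = torch.zeros_like(input_ids)
mm_token_type_ids[input_ids == IMAGE_TOKEN_ID] = 1
mm_token_type_ids[input_ids == VIDEO_TOKEN_ID] = 2

print(mm_token_type_ids.tolist())  # [[0, 1, 1, 0, 2]]
```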
def _compute_rope_position_ids_with_packing(
self,
features: dict[str, "torch.Tensor"],
mm_inputs: dict[str, Any],
packing_params_list: list[dict[str, Any] | None],
batch_imglens: list[int],
batch_vidlens: list[int],
batch_audlens: list[int],
has_dummy_image: bool,
) -> None:
r"""Compute position_ids and rope_deltas per sample (or per sub-sequence when packed), then merge and validate."""
bsz = features["input_ids"].size(0)
seq_len = features["input_ids"].size(1)
all_position_ids: list[torch.Tensor] = []
all_rope_deltas: list[torch.Tensor] = []
if has_dummy_image:
# seq_len = unpadded_length + right_padding_length + fake_input_ids_len + collator_padding_length
# FIXME: maybe right_padding_length is large, with improper max_cutoff_len
unpadded_length = int(features["attention_mask"][0].bool().sum().item())
right_padding_length = int((packing_params_list[0] or {}).get("right_padding_length") or 0)
fake_input_padding_length = max(0, seq_len - unpadded_length - right_padding_length)
dummy_image_right_padding_mrope = torch.zeros((3, bsz, fake_input_padding_length))
dummy_image_right_padding_attention_mask = torch.zeros((bsz, fake_input_padding_length))
assert self.tokenizer.padding_side == "right", "padding_side should be right when fake image is injected"
dummy_mm_inputs = copy.deepcopy(mm_inputs)
for sample_idx in range(bsz):
sample_packing = (packing_params_list[sample_idx] or {}) if sample_idx < len(packing_params_list) else {}
sequence_boundaries = sample_packing.get("sequence_boundaries")
num_sub_seqs = (len(sequence_boundaries) - 1) if sequence_boundaries and len(sequence_boundaries) > 1 else 1
image_subseq_ids = sample_packing.get("image_subseq_ids") or []
video_subseq_ids = sample_packing.get("video_subseq_ids") or []
images_per_subseq = (
[image_subseq_ids.count(i) for i in range(num_sub_seqs)] if image_subseq_ids and num_sub_seqs > 1 else None
)
videos_per_subseq = (
[video_subseq_ids.count(i) for i in range(num_sub_seqs)] if video_subseq_ids and num_sub_seqs > 1 else None
)
if has_dummy_image:
mm_inputs = {}
if num_sub_seqs <= 1:
sample_features = {
"input_ids": features["input_ids"][sample_idx : sample_idx + 1],
"attention_mask": features["attention_mask"][sample_idx : sample_idx + 1],
}
mm_inputs_for_sample = _slice_mm_inputs_for_sample(
mm_inputs, batch_imglens, batch_vidlens, sample_idx=sample_idx
)
self._compute_rope_position_ids(sample_features, mm_inputs_for_sample)
all_position_ids.append(sample_features["position_ids"])
all_rope_deltas.append(sample_features["rope_deltas"])
else:
# when packing is enabled, rope_deltas are not needed during training.
sample_position_ids: list[torch.Tensor] = []
for subseq_idx in range(num_sub_seqs):
subseq_start = sequence_boundaries[subseq_idx]
subseq_end = sequence_boundaries[subseq_idx + 1]
subseq_features = {
"input_ids": features["input_ids"][sample_idx : sample_idx + 1, subseq_start:subseq_end],
"attention_mask": features["attention_mask"][sample_idx : sample_idx + 1, subseq_start:subseq_end],
}
mm_inputs_for_subseq = _slice_mm_inputs_for_sample(
mm_inputs,
batch_imglens,
batch_vidlens,
sample_idx,
images_per_subseq,
videos_per_subseq,
subseq_idx
)
self._compute_rope_position_ids(subseq_features, mm_inputs_for_subseq)
sample_position_ids.append(subseq_features["position_ids"])
all_position_ids.append(torch.cat(sample_position_ids, dim=-1))
batch_dim_for_position_ids = 1 if all_position_ids[0].dim() == 3 else 0
features["position_ids"] = torch.cat(all_position_ids, dim=batch_dim_for_position_ids)
if has_dummy_image:
mm_inputs = dummy_mm_inputs
expected_position_ids_shape = (bsz, seq_len) if all_position_ids[0].dim() == 2 else (
all_position_ids[0].size(0),
bsz,
seq_len,
)
# Check that the merged position_ids match the expected shape.
# When padding tokens remain on the right, pad position_ids on the right as well for later use.
if has_dummy_image:
features["position_ids"] = torch.cat([features["position_ids"], dummy_image_right_padding_mrope], dim=-1)
features["attention_mask"] = torch.cat([features["attention_mask"], dummy_image_right_padding_attention_mask], dim=-1)
if features["position_ids"].shape != expected_position_ids_shape:
raise ValueError(
"Merged position_ids shape mismatch: "
f"got {features['position_ids'].shape}, expected {expected_position_ids_shape}."
)
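Conceptually, the per-sub-sequence loop above restarts position ids at every packing boundary before concatenating; a minimal 1-D sketch, ignoring the 3-D mrope axes and rope_deltas:

```python
# Position ids restart at each sub-sequence boundary of a packed row.
boundaries = [0, 3, 5, 8]  # hypothetical sequence_boundaries: three sub-seqs
position_ids: list[int] = []
for start, end in zip(boundaries[:-1], boundaries[1:]):
    position_ids.extend(range(end - start))  # 0..len-1 within each sub-sequence

print(position_ids)  # [0, 1, 2, 0, 1, 0, 1, 2]
```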
def __call__(self, features: list[dict[str, Any]]) -> dict[str, "torch.Tensor"]:
batch_images, batch_videos, batch_audios = [], [], []
batch_imglens, batch_vidlens, batch_audlens, batch_input_ids = [], [], [], []
packing_params_list: list[dict[str, Any] | None] = []
for feature in features:
images = feature.pop("images", None) or []
videos = feature.pop("videos", None) or []
@@ -119,8 +316,10 @@ class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
batch_vidlens.append(len(videos))
batch_audlens.append(len(audios))
batch_input_ids.append(feature["input_ids"])
packing_params_list.append(feature.pop("packing_params", None))
fake_input_ids = []
has_dummy_image = False
if (
self.template.mm_plugin.image_token is not None and sum(batch_imglens) == 0 and sum(batch_vidlens) == 0
): # avoid process hanging in zero3/fsdp case
@@ -136,6 +335,7 @@ class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
fake_input_ids.extend(_fake_input_ids)
batch_images = fake_images
batch_imglens[0] = 1
has_dummy_image = True
if (
self.template.mm_plugin.audio_token is not None and sum(batch_audlens) == 0
@@ -182,46 +382,50 @@ class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
features: dict[str, torch.Tensor] = super().__call__(features)
bsz, seq_len = features["input_ids"].shape[:2]
model_type = getattr(self.model.config, "model_type", None) if self.model is not None else None
is_omni = model_type in [
"qwen2_5_omni_thinker",
"qwen3_omni_moe_thinker",
]
if self.get_rope_func is not None:
# for the mrope case, position_ids and rope_deltas should be calculated per sample.
# When neat_packing is on, each sample has packing_params; None means no packing for that sample.
boundaries_list = [
p.get("sequence_boundaries") if p is not None else None for p in packing_params_list
]
has_packing = any(b is not None and len(b) > 2 for b in boundaries_list)
if has_dummy_image and has_packing:
# FIXME: too tricky, need to be refactored
features["has_dummy_image"] = True
# When a fake image/audio was injected, sequence_boundaries no longer match the tensor; use the non-packing path.
if not has_packing:
self._compute_rope_position_ids(features, mm_inputs)
else:
if is_omni:
raise RuntimeError("Omni models are not supported for packed sequences for now.")
self._compute_rope_position_ids_with_packing(
features,
mm_inputs,
packing_params_list,
batch_imglens,
batch_vidlens,
batch_audlens,
has_dummy_image,
)
# For transformers compatibility, after https://github.com/huggingface/transformers/issues/39400
if features["position_ids"].dim() == 3:
features["position_ids"] = torch.cat(
[features["position_ids"][0].unsqueeze(0), features["position_ids"]], dim=0
)
if (
self.model is not None
and getattr(self.model.config, "model_type", None) in MROPE_MODELS
and ("position_ids" not in features or features["position_ids"].dim() != 3)
):
raise ValueError(f"{self.model.config.model_type} requires 3D position ids for mrope.")
@@ -249,12 +453,51 @@ class SFTDataCollatorWith4DAttentionMask(MultiModalDataCollatorForSeq2Seq):
block_diag_attn: bool = False
attn_implementation: Literal["eager", "sdpa", "flash_attention_2"] = "eager"
compute_dtype: "torch.dtype" = torch.float32
neat_packing: bool = False
def __post_init__(self):
super().__post_init__()
if self.neat_packing and self.attn_implementation == "flash_attention_2":
if self.model is not None and getattr(self.model.config, "model_type", None) in ["qwen3_5", "qwen3_5_moe", "gpt_oss"]:
raise ValueError("Neat packing is not supported for qwen3_5, qwen3_5_moe, gpt_oss models for now.")
@staticmethod
def _unpad_packed_features(features: dict[str, Any]) -> None:
r"""Trim padded positions for packed FA2 batches."""
attention_mask = features.get("attention_mask")
if not torch.is_tensor(attention_mask) or attention_mask.dim() != 2 or attention_mask.size(0) != 1:
return
seq_len = attention_mask.size(1)
non_padding_indices = torch.nonzero(attention_mask[0] != 0, as_tuple=False).flatten()
if non_padding_indices.numel() == seq_len:
return
keys_on_seq_dim_1 = {"input_ids", "labels", "attention_mask", "token_type_ids"}
for key, value in list(features.items()):
if not torch.is_tensor(value):
continue
if key == "position_ids" and value.size(-1) == seq_len:
features[key] = value.index_select(-1, non_padding_indices)
elif key == "cross_attention_mask" and value.dim() >= 2 and value.size(0) == 1 and value.size(1) == seq_len:
features[key] = value.index_select(1, non_padding_indices)
elif key in keys_on_seq_dim_1 and value.dim() == 2 and value.size(0) == 1 and value.size(1) == seq_len:
features[key] = value.index_select(1, non_padding_indices)
def __call__(self, features: list[dict[str, Any]]) -> dict[str, "torch.Tensor"]:
features = super().__call__(features)
has_dummy_image = features.pop("has_dummy_image", False)
if self.block_diag_attn and self.attn_implementation != "flash_attention_2":
features["attention_mask"] = prepare_4d_attention_mask(features["attention_mask"], self.compute_dtype)
if self.neat_packing and self.attn_implementation == "flash_attention_2": # FIXME compatibility fa3/fa4
assert features["input_ids"].shape[0] == 1, "bsz should be 1 for neat packing"
if not has_dummy_image:
self._unpad_packed_features(features)
features["attention_mask"] = None # let transformers handle causal packed mask.
for key, value in features.items(): # cast data dtype for paligemma
if torch.is_tensor(value) and torch.is_floating_point(value):
features[key] = value.to(self.compute_dtype)

View File

@@ -196,7 +196,7 @@ def read_cloud_json(cloud_path: str) -> list[Any]:
# filter out non-JSON files
files = [x["Key"] for x in fs.listdir(cloud_path)] if fs.isdir(cloud_path) else [cloud_path]
files = list(filter(lambda file: file.endswith(".json") or file.endswith(".jsonl"), files))
if not files:
raise ValueError(f"No JSON/JSONL files found in the specified path: {cloud_path}.")

View File

@@ -27,11 +27,12 @@ from typing import TYPE_CHECKING, BinaryIO, Literal, NotRequired, Optional, Type
import numpy as np
import torch
import torchaudio
from transformers.image_utils import get_image_size, is_valid_image, make_flat_list_of_images, to_numpy_array
from transformers.models.mllama.processing_mllama import (
convert_sparse_cross_attention_mask_to_dense,
get_cross_attention_token_mask,
)
from transformers.video_utils import make_batched_videos
from typing_extensions import override
from ..extras.constants import AUDIO_PLACEHOLDER, IGNORE_INDEX, IMAGE_PLACEHOLDER, VIDEO_PLACEHOLDER
@@ -47,13 +48,6 @@ if is_pyav_available():
import av
if TYPE_CHECKING:
from av.stream import Stream
from numpy.typing import NDArray
@@ -161,7 +155,9 @@ class MMPluginMixin:
video_processor: BaseImageProcessor = getattr(
processor, "video_processor", getattr(processor, "image_processor", None)
)
feature_extractor: SequenceFeatureExtractor = getattr(processor, "feature_extractor", None) or getattr(
processor, "audio_processor", None
)
if len(images) != 0 and self.image_token is None:
raise ValueError(
"This model does not support image input. Please check whether the correct `template` is used."
@@ -390,7 +386,9 @@ class MMPluginMixin:
mm_inputs.update(video_processor(videos, return_tensors="pt"))
if len(audios) != 0:
feature_extractor: SequenceFeatureExtractor = getattr(processor, "feature_extractor", None) or getattr(
processor, "audio_processor", None
)
audios = self._regularize_audios(
audios,
sampling_rate=getattr(processor, "audio_sampling_rate", 16000),
@@ -1054,7 +1052,9 @@ class MiniCPMVPlugin(BasePlugin):
chunk_input=True,
sampling_rate=getattr(processor, "audio_sampling_rate", 16000),
)
audio_feature_lens = [
x.clone().detach() if isinstance(x, torch.Tensor) else torch.tensor(x) for x in audio_feature_lens
]
mm_inputs.update({"audio_features": audio_features, "audio_feature_lens": audio_feature_lens})
if kwargs.get("ret_phs", False):
mm_inputs.update({"audio_phs": audio_phs})
@@ -1094,7 +1094,7 @@ class MiniCPMVPlugin(BasePlugin):
num_image_tokens += 1
while VIDEO_PLACEHOLDER in content:
video_seqlen = len(mm_inputs["image_sizes"][num_video_tokens]) if self.expand_mm_tokens else 1
content = content.replace(VIDEO_PLACEHOLDER, "{{image}}" * video_seqlen, 1)
num_video_tokens += 1
@@ -1876,7 +1876,9 @@ class Qwen2OmniPlugin(Qwen2VLPlugin):
) -> dict[str, "torch.Tensor"]:
image_processor: BaseImageProcessor = getattr(processor, "image_processor", None)
video_processor: BaseVideoProcessor = getattr(processor, "video_processor", None)
feature_extractor: SequenceFeatureExtractor = getattr(processor, "feature_extractor", None) or getattr(
processor, "audio_processor", None
)
mm_inputs = {}
if len(images) != 0:
images = self._regularize_images(
@@ -1981,6 +1983,7 @@ class Qwen2OmniPlugin(Qwen2VLPlugin):
f"Each {VIDEO_PLACEHOLDER} must be followed by an {AUDIO_PLACEHOLDER} when using audio in video."
)
position_id_per_seconds: int = getattr(processor, "position_id_per_seconds", 25)
audio_t_index = torch.arange(audio_lengths[num_audio_tokens])
video_t_index = (
torch.arange(video_grid_thw[num_video_tokens][0])
@@ -1992,9 +1995,9 @@ class Qwen2OmniPlugin(Qwen2VLPlugin):
)
.flatten()
* mm_inputs["video_second_per_grid"][num_video_tokens]
* position_id_per_seconds
).long()
t_ntoken_per_chunk = position_id_per_seconds * 2
video_chunk_indices = processor.get_chunked_index(video_t_index, t_ntoken_per_chunk)
audio_chunk_indices = processor.get_chunked_index(audio_t_index, t_ntoken_per_chunk)
placeholder_string = ""

View File

@@ -13,7 +13,7 @@
# limitations under the License.
from collections import defaultdict
from dataclasses import asdict, dataclass
from typing import TYPE_CHECKING, Any, Optional
from ...extras import logging
@@ -27,6 +27,23 @@ if TYPE_CHECKING:
logger = logging.get_logger(__name__)
MAX_SU_SEQ_IDX = 2**32 # maximum sub-sequence index
@dataclass
class PackingParams:
r"""Metadata for a packed sequence: sub-sequence boundaries and multimodal data indices.
- sequence_boundaries: cumulative token positions, e.g. [0, 100, 250, 512] means 3 sub-seqs
with token ranges [0,100), [100,250), [250,512). Length = num_sub_seqs + 1.
- image_subseq_ids / video_subseq_ids / audio_subseq_ids: for each mm item, the 0-based
sub-sequence index it belongs to. Length = total number of that mm type in the packed sample.
"""
sequence_boundaries: list[int]
image_subseq_ids: list[int]
video_subseq_ids: list[int]
audio_subseq_ids: list[int]
right_padding_length: int
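A worked example of the semantics documented above, using a minimal stand-in for the dataclass (all values are hypothetical):

```python
from dataclasses import asdict, dataclass


# Illustrative mirror of the PackingParams fields documented above.
@dataclass
class PackingParams:
    sequence_boundaries: list[int]
    image_subseq_ids: list[int]
    video_subseq_ids: list[int]
    audio_subseq_ids: list[int]
    right_padding_length: int


# One packed 512-token sample: two real sub-sequences plus right padding.
params = PackingParams(
    sequence_boundaries=[0, 100, 250, 512],  # sub-seqs [0,100), [100,250), padding [250,512)
    image_subseq_ids=[0, 0, 1],              # two images in sub-seq 0, one in sub-seq 1
    video_subseq_ids=[],
    audio_subseq_ids=[],
    right_padding_length=262,
)
num_sub_seqs = len(params.sequence_boundaries) - 1
print(num_sub_seqs)                        # 3 (the padding block adds a boundary)
print(asdict(params)["image_subseq_ids"])  # [0, 0, 1]
```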
@dataclass
class SupervisedDatasetProcessor(DatasetProcessor):
@@ -162,10 +179,17 @@ class PackedSupervisedDatasetProcessor(SupervisedDatasetProcessor):
valid_num += 1
model_inputs = defaultdict(list)
requires_packing_params = self.data_args.neat_packing
knapsacks = greedy_knapsack(lengths, self.data_args.cutoff_len)
for knapsack in knapsacks:
packed_input_ids, packed_attention_masks, packed_position_ids, packed_labels = [], [], [], []
packed_images, packed_videos, packed_audios = [], [], []
if requires_packing_params:
sequence_boundaries = [0]
image_subseq_ids: list[int] = []
video_subseq_ids: list[int] = []
audio_subseq_ids: list[int] = []
for i, length in enumerate(knapsack):
index = length2indexes[length].pop()
packed_input_ids += batch_input_ids[index]
@@ -174,6 +198,15 @@ class PackedSupervisedDatasetProcessor(SupervisedDatasetProcessor):
packed_images += batch_images[index]
packed_videos += batch_videos[index]
packed_audios += batch_audios[index]
if requires_packing_params:
n_img = len(batch_images[index])
n_vid = len(batch_videos[index])
n_aud = len(batch_audios[index])
sequence_boundaries.append(sequence_boundaries[-1] + len(batch_input_ids[index]))
image_subseq_ids.extend([i] * n_img)
video_subseq_ids.extend([i] * n_vid)
audio_subseq_ids.extend([i] * n_aud)
if self.data_args.neat_packing:
packed_attention_masks += [i + 1] * len(batch_input_ids[index]) # start from 1
else:
@@ -189,10 +222,23 @@ class PackedSupervisedDatasetProcessor(SupervisedDatasetProcessor):
else:
packed_attention_masks += [1] * pad_length # more efficient flash_attn
if requires_packing_params:
sequence_boundaries.append(sequence_boundaries[-1] + pad_length)
if len(packed_input_ids) != self.data_args.cutoff_len + 1:
raise ValueError("The length of packed example should be identical to the cutoff length.")
model_inputs["input_ids"].append(packed_input_ids)
if requires_packing_params:
packing_params = PackingParams(
sequence_boundaries=sequence_boundaries,
image_subseq_ids=image_subseq_ids or [MAX_SU_SEQ_IDX], # avoid dataset concat error
video_subseq_ids=video_subseq_ids or [MAX_SU_SEQ_IDX],
audio_subseq_ids=audio_subseq_ids or [MAX_SU_SEQ_IDX],
right_padding_length=pad_length,
)
model_inputs["packing_params"].append(asdict(packing_params))
model_inputs["attention_mask"].append(packed_attention_masks)
model_inputs["position_ids"].append(packed_position_ids)
model_inputs["labels"].append(packed_labels)

View File

@@ -1113,7 +1113,7 @@ register_template(
register_template(
name="gpt_oss",
format_user=StringFormatter(slots=["<|start|>user<|message|>{{content}}<|end|><|start|>assistant"]),
format_assistant=StringFormatter(slots=["{{content}}"]),
format_system=StringFormatter(slots=["<|start|>system<|message|>{{content}}<|end|>"]),
default_system="You are ChatGPT, a large language model trained by OpenAI.",
thought_words=("<|channel|>analysis<|message|>", "<|end|><|start|>assistant<|channel|>final<|message|>"),
@@ -2029,6 +2029,39 @@ register_template(
)
register_template(
name="qwen3_5",
format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
format_function=FunctionFormatter(slots=["{{content}}<|im_end|>\n"], tool_format="qwen3_5"),
format_observation=StringFormatter(
slots=["<|im_start|>user\n<tool_response>\n{{content}}\n</tool_response><|im_end|>\n<|im_start|>assistant\n"]
),
format_tools=ToolFormatter(tool_format="qwen3_5"),
stop_words=["<|im_end|>"],
replace_eos=True,
mm_plugin=get_mm_plugin(name="qwen3_vl", image_token="<|image_pad|>", video_token="<|video_pad|>"),
template_class=ReasoningTemplate,
)
register_template(
name="qwen3_5_nothink",
format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
format_function=FunctionFormatter(slots=["{{content}}<|im_end|>\n"], tool_format="qwen3_5"),
format_observation=StringFormatter(
slots=["<|im_start|>user\n<tool_response>\n{{content}}\n</tool_response><|im_end|>\n<|im_start|>assistant\n"]
),
format_tools=ToolFormatter(tool_format="qwen3_5"),
stop_words=["<|im_end|>"],
replace_eos=True,
mm_plugin=get_mm_plugin(name="qwen3_vl", image_token="<|image_pad|>", video_token="<|video_pad|>"),
)
register_template(
name="sailor",
format_user=StringFormatter(slots=["<|im_start|>question\n{{content}}<|im_end|>\n<|im_start|>answer\n"]),
@@ -2218,3 +2251,24 @@ register_template(
format_system=StringFormatter(slots=["<|system|>\n{{content}}", {"eos_token"}]),
default_system="You are Zephyr, a helpful assistant.",
)
# copied from glm4_7 template
register_template(
name="aeva",
format_user=StringFormatter(slots=["<|user|>\n{{content}}<|assistant|>"]),
format_assistant=StringFormatter(slots=["\n{{content}}"]),
format_system=StringFormatter(slots=["<|system|>\n{{content}}"]),
format_function=FunctionFormatter(slots=["{{content}}"], tool_format="glm4_moe"),
format_observation=StringFormatter(slots=["<|observation|>\n{{content}}<|assistant|>"]),
format_tools=ToolFormatter(tool_format="glm4_moe"),
format_prefix=EmptyFormatter(slots=["[gMASK]<sop>"]),
default_system=(
"You are an AI assistant named Aeva created by Zongzhi Lou. "
"Your answer should be friendly, unbiased, faithful, informative and detailed."
),
stop_words=["<|user|>", "<|observation|>"],
thought_words=("<think>", "</think>"),
efficient_eos=True,
template_class=Glm47ReasoningTemplate,
)

View File

@@ -85,6 +85,21 @@ QWEN_TOOL_PROMPT = (
""""arguments": <args-json-object>}}\n</tool_call>"""
)
QWEN35_TOOL_PROMPT = (
"\n\n# Tools\n\nYou have access to the following functions:\n\n<tools>{tool_text}"
"\n</tools>\n\nIf you choose to call a function ONLY reply in the following format with NO suffix:\n\n"
"<tool_call>\n<function=example_function_name>\n<parameter=example_parameter_1>\nvalue_1\n</parameter>\n"
"<parameter=example_parameter_2>\nThis is the value for the second parameter\nthat can span\nmultiple lines\n"
"</parameter>\n</function>\n</tool_call>\n\n<IMPORTANT>\nReminder:\n"
"- Function calls MUST follow the specified format: "
"an inner <function=...></function> block must be nested within <tool_call></tool_call> XML tags\n"
"- Required parameters MUST be specified\n"
"- You may provide optional reasoning for your function call in natural language "
"BEFORE the function call, but NOT after\n"
"- If there is no function call available, answer the question like normal with your current knowledge "
"and do not tell the user about function calls\n</IMPORTANT>"
)
SEED_TOOL_PROMPT = (
"system\nYou are Doubao, a helpful AI assistant. You may call one or more functions to assist with the user query."
"Tool List:\nYou are authorized to use the following tools (described in JSON Schema format). Before performing "
@@ -453,6 +468,57 @@ class QwenToolUtils(ToolUtils):
return results
class Qwen35ToolUtils(ToolUtils):
    r"""Qwen 3.5 tool using template."""

    @override
    @staticmethod
    def tool_formatter(tools: list[dict[str, Any]]) -> str:
        tool_text = ""
        for tool in tools:
            tool = tool.get("function", tool) if tool.get("type") == "function" else tool
            tool_text += "\n" + json.dumps(tool, ensure_ascii=False)

        return QWEN35_TOOL_PROMPT.format(tool_text=tool_text)

    @override
    @staticmethod
    def function_formatter(functions: list["FunctionCall"]) -> str:
        function_texts = []
        for func in functions:
            name, arguments = func.name, json.loads(func.arguments)
            prompt = f"<tool_call>\n<function={name}>"
            for key, value in arguments.items():
                prompt += f"\n<parameter={key}>"
                if not isinstance(value, str):
                    value = json.dumps(value, ensure_ascii=False)

                prompt += f"\n{value}\n</parameter>"

            prompt += "\n</function>\n</tool_call>"
            function_texts.append(prompt)

        return "\n".join(function_texts)

    @override
    @staticmethod
    def tool_extractor(content: str) -> Union[str, list["FunctionCall"]]:
        results = []
        regex = re.compile(r"<tool_call>\s*<function=\s*([^\s<>]+)\s*(.*?)\s*</function>\s*</tool_call>", re.DOTALL)
        for func_name, params_block in re.findall(regex, content):
            args_dict = {}
            param_pattern = re.compile(r"<parameter=(.*?)>(.*?)</parameter>", re.DOTALL)
            for key, raw_value in re.findall(param_pattern, params_block.strip()):
                value = raw_value.strip()
                try:
                    parsed_value = json.loads(value)
                except json.JSONDecodeError:
                    parsed_value = value

                args_dict[key] = parsed_value

            results.append(FunctionCall(func_name.strip(), json.dumps(args_dict, ensure_ascii=False)))

        return results if results else content
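As a quick illustration of the `<function=`/`<parameter=` wire format this extractor parses, here is a minimal standalone sketch. The `FunctionCall` named tuple is a stand-in for the project's own type, and `extract` mirrors the parsing scheme above; it is not part of the diff itself:

```python
import json
import re
from typing import NamedTuple, Union


class FunctionCall(NamedTuple):
    """Minimal stand-in for the project's FunctionCall type."""

    name: str
    arguments: str


def extract(content: str) -> Union[str, list[FunctionCall]]:
    # Same parsing scheme as Qwen35ToolUtils.tool_extractor: find each
    # <tool_call> block, then each <parameter=...> inside it, and JSON-decode
    # parameter values when possible (falling back to the raw string).
    results = []
    regex = re.compile(r"<tool_call>\s*<function=\s*([^\s<>]+)\s*(.*?)\s*</function>\s*</tool_call>", re.DOTALL)
    for func_name, params_block in re.findall(regex, content):
        args = {}
        for key, raw in re.findall(r"<parameter=(.*?)>(.*?)</parameter>", params_block.strip(), re.DOTALL):
            value = raw.strip()
            try:
                args[key] = json.loads(value)
            except json.JSONDecodeError:
                args[key] = value
        results.append(FunctionCall(func_name.strip(), json.dumps(args, ensure_ascii=False)))
    return results or content
```

Non-matching content falls through unchanged, which is how the caller distinguishes plain assistant text from tool calls.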
class GLM4MOEToolUtils(QwenToolUtils):
    r"""GLM-4-MOE tool using template."""
@@ -662,6 +728,7 @@ TOOLS = {
    "minimax2": MiniMaxM2ToolUtils(),
    "mistral": MistralToolUtils(),
    "qwen": QwenToolUtils(),
    "qwen3_5": Qwen35ToolUtils(),
    "glm4_moe": GLM4MOEToolUtils(),
    "seed_oss": SeedToolUtils(),
    "ling": LingToolUtils(),


@@ -69,12 +69,28 @@ MCA_SUPPORTED_MODELS = {
    "qwen3",
    "qwen3_moe",
    "qwen3_next",
    "qwen3_5",
    "qwen3_5_moe",
}

METHODS = ["full", "freeze", "lora", "oft"]

MOD_SUPPORTED_MODELS = {"bloom", "falcon", "gemma", "llama", "mistral", "mixtral", "phi", "starcoder2"}
MROPE_MODELS = {
    "glm4v",
    "glm_ocr",
    "Keye",
    "qwen2_vl",
    "qwen2_5_vl",
    "qwen2_5_omni_thinker",
    "qwen3_omni_moe_thinker",
    "qwen3_vl",
    "qwen3_vl_moe",
    "qwen3_5",
    "qwen3_5_moe",
}
MULTIMODAL_SUPPORTED_MODELS = set()

PEFT_METHODS = {"lora", "oft"}
@@ -2810,6 +2826,66 @@ register_model_group(
)
register_model_group(
    models={
        "Qwen3.5-0.8B-Base": {
            DownloadSource.DEFAULT: "Qwen/Qwen3.5-0.8B-Base",
            DownloadSource.MODELSCOPE: "Qwen/Qwen3.5-0.8B-Base",
        },
        "Qwen3.5-2B-Base": {
            DownloadSource.DEFAULT: "Qwen/Qwen3.5-2B-Base",
            DownloadSource.MODELSCOPE: "Qwen/Qwen3.5-2B-Base",
        },
        "Qwen3.5-4B-Base": {
            DownloadSource.DEFAULT: "Qwen/Qwen3.5-4B-Base",
            DownloadSource.MODELSCOPE: "Qwen/Qwen3.5-4B-Base",
        },
        "Qwen3.5-9B-Base": {
            DownloadSource.DEFAULT: "Qwen/Qwen3.5-9B-Base",
            DownloadSource.MODELSCOPE: "Qwen/Qwen3.5-9B-Base",
        },
        "Qwen3.5-35B-A3B-Base": {
            DownloadSource.DEFAULT: "Qwen/Qwen3.5-35B-A3B-Base",
            DownloadSource.MODELSCOPE: "Qwen/Qwen3.5-35B-A3B-Base",
        },
        "Qwen3.5-0.8B-Thinking": {
            DownloadSource.DEFAULT: "Qwen/Qwen3.5-0.8B",
            DownloadSource.MODELSCOPE: "Qwen/Qwen3.5-0.8B",
        },
        "Qwen3.5-2B-Thinking": {
            DownloadSource.DEFAULT: "Qwen/Qwen3.5-2B",
            DownloadSource.MODELSCOPE: "Qwen/Qwen3.5-2B",
        },
        "Qwen3.5-4B-Thinking": {
            DownloadSource.DEFAULT: "Qwen/Qwen3.5-4B",
            DownloadSource.MODELSCOPE: "Qwen/Qwen3.5-4B",
        },
        "Qwen3.5-9B-Thinking": {
            DownloadSource.DEFAULT: "Qwen/Qwen3.5-9B",
            DownloadSource.MODELSCOPE: "Qwen/Qwen3.5-9B",
        },
        "Qwen3.5-27B-Thinking": {
            DownloadSource.DEFAULT: "Qwen/Qwen3.5-27B",
            DownloadSource.MODELSCOPE: "Qwen/Qwen3.5-27B",
        },
        "Qwen3.5-35B-A3B-Thinking": {
            DownloadSource.DEFAULT: "Qwen/Qwen3.5-35B-A3B",
            DownloadSource.MODELSCOPE: "Qwen/Qwen3.5-35B-A3B",
        },
        "Qwen3.5-122B-A10B-Thinking": {
            DownloadSource.DEFAULT: "Qwen/Qwen3.5-122B-A10B",
            DownloadSource.MODELSCOPE: "Qwen/Qwen3.5-122B-A10B",
        },
        "Qwen3.5-397B-A17B-Thinking": {
            DownloadSource.DEFAULT: "Qwen/Qwen3.5-397B-A17B",
            DownloadSource.MODELSCOPE: "Qwen/Qwen3.5-397B-A17B",
        },
    },
    template="qwen3_5",
    multimodal=True,
)
register_model_group(
    models={
        "Qwen2-Audio-7B": {
@@ -3451,3 +3527,35 @@ register_model_group(
    },
    template="zephyr",
)
register_model_group(
    models={
        "Aeva-Flash-Chat": {
            DownloadSource.DEFAULT: "louzongzhi/Aeva-Flash",
            DownloadSource.MODELSCOPE: "louzongktsi/Aeva-Flash",
            DownloadSource.OPENMIND: "louzongzhi/Aeva-Flash",
        },
        "Aeva-Air-Chat": {
            DownloadSource.DEFAULT: "louzongzhi/Aeva-Air",
            DownloadSource.MODELSCOPE: "louzongktsi/Aeva-Air",
            DownloadSource.OPENMIND: "louzongzhi/Aeva-Air",
        },
        "Aeva-Chat": {
            DownloadSource.DEFAULT: "louzongzhi/Aeva",
            DownloadSource.MODELSCOPE: "louzongktsi/Aeva",
            DownloadSource.OPENMIND: "louzongzhi/Aeva",
        },
        "Aeva-Pro-Chat": {
            DownloadSource.DEFAULT: "louzongzhi/Aeva-Pro",
            DownloadSource.MODELSCOPE: "louzongktsi/Aeva-Pro",
            DownloadSource.OPENMIND: "louzongzhi/Aeva-Pro",
        },
        "Aeva-Max-Chat": {
            DownloadSource.DEFAULT: "louzongzhi/Aeva-Max",
            DownloadSource.MODELSCOPE: "louzongktsi/Aeva-Max",
            DownloadSource.OPENMIND: "louzongzhi/Aeva-Max",
        },
    },
    template="aeva",
)


@@ -94,7 +94,7 @@ def check_version(requirement: str, mandatory: bool = False) -> None:
def check_dependencies() -> None:
    r"""Check the version of the required packages."""
    check_version("transformers>=4.55.0,<=5.2.0")
    check_version("datasets>=2.16.0,<=4.0.0")
    check_version("accelerate>=1.3.0,<=1.11.0")
    check_version("peft>=0.18.0,<=0.18.1")


@@ -33,7 +33,7 @@ from transformers.utils import is_torch_bf16_gpu_available, is_torch_npu_availab
from ..extras import logging
from ..extras.constants import CHECKPOINT_NAMES, EngineName
from ..extras.misc import check_dependencies, check_version, get_current_device, is_env_enabled
from ..extras.packages import is_mcore_adapter_available
from .data_args import DataArguments
from .evaluation_args import EvaluationArguments
from .finetuning_args import FinetuningArguments
@@ -100,6 +100,52 @@ def _parse_args(
    return tuple(parsed_args)
def _verify_trackio_args(training_args: "TrainingArguments") -> None:
    """Validates Trackio-specific arguments.

    Args:
        training_args: TrainingArguments instance (not a dictionary)
    """
    report_to = training_args.report_to
    if not report_to:
        return

    if isinstance(report_to, str):
        report_to = [report_to]

    if "trackio" not in report_to:
        return

    # --- Enforce project (required by Trackio) ---
    if not training_args.project:
        raise ValueError("`--project` must be specified when using Trackio.")

    # --- Validate trackio_space_id format ---
    space_id = training_args.trackio_space_id
    if space_id:
        if space_id != "trackio" and "/" not in space_id:
            logger.warning(
                f"trackio_space_id '{space_id}' should typically be in format "
                "'org/space' for Hugging Face Spaces deployment."
            )

    # --- Inform about default project usage ---
    if training_args.project == "huggingface":
        logger.info(
            "Using default project name 'huggingface'. "
            "Consider setting a custom project name with --project "
            "for better organization."
        )

    # --- Validate hub repo privacy flag ---
    if training_args.hub_private_repo:
        logger.info("Repository will be created as private on Hugging Face Hub.")

    # --- Recommend run_name for experiment clarity ---
    if not training_args.run_name:
        logger.warning("Consider setting --run_name for better experiment tracking clarity.")
def _set_transformers_logging() -> None:
    if os.getenv("LLAMAFACTORY_VERBOSITY", "INFO") in ["DEBUG", "INFO"]:
        transformers.utils.logging.set_verbosity_info()
@@ -278,8 +324,10 @@ def get_train_args(args: dict[str, Any] | list[str] | None = None) -> _TRAIN_CLS
    if finetuning_args.reward_model_type == "lora" and model_args.use_unsloth:
        raise ValueError("Unsloth does not support lora reward model.")

    if training_args.report_to and any(
        logger not in ("wandb", "tensorboard", "trackio", "none") for logger in training_args.report_to
    ):
        raise ValueError("PPO only accepts wandb, tensorboard, or trackio logger.")

    if not model_args.use_kt and training_args.parallel_mode == ParallelMode.NOT_DISTRIBUTED:
        raise ValueError("Please launch distributed training with `llamafactory-cli` or `torchrun`.")
@@ -346,12 +394,10 @@ def get_train_args(args: dict[str, Any] | list[str] | None = None) -> _TRAIN_CLS
    if model_args.use_kt and is_deepspeed_zero3_enabled():
        raise ValueError("KTransformers is incompatible with DeepSpeed ZeRO-3.")

    _set_env_vars()
    _verify_model_args(model_args, data_args, finetuning_args)
    _check_extra_dependencies(model_args, finetuning_args, training_args)
    _verify_trackio_args(training_args)

    if not finetuning_args.use_mca and training_args.fp8_enable_fsdp_float8_all_gather and not training_args.fp8:
        logger.warning_rank0("fp8_enable_fsdp_float8_all_gather requires fp8=True. Setting fp8=True.")
@@ -421,7 +467,7 @@ def get_train_args(args: dict[str, Any] | list[str] | None = None) -> _TRAIN_CLS
        training_args.resume_from_checkpoint is None
        and training_args.do_train
        and os.path.isdir(training_args.output_dir)
        and not getattr(training_args, "overwrite_output_dir", False)  # for mca training args and transformers >= 5.0
        and can_resume_from_checkpoint
    ):
        last_checkpoint = get_last_checkpoint(training_args.output_dir)


@@ -147,6 +147,7 @@ def add_z3_leaf_module(model: "PreTrainedModel") -> None:
        _set_z3_leaf_modules(model, [Qwen3NextSparseMoeBlock])


def configure_moe(config: "PretrainedConfig", model_args: "ModelArguments", is_trainable: bool) -> None:
    if not is_trainable or not model_args.moe_aux_loss_coef:
        return


@@ -37,7 +37,6 @@
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

import torch
import torch.nn.functional as F
@@ -45,10 +44,6 @@ import torch.nn.functional as F
from ...extras import logging


logger = logging.get_logger(__name__)
@@ -105,13 +100,3 @@ def get_unpad_data(attention_mask: "torch.Tensor") -> tuple["torch.Tensor", "tor
    max_seqlen_in_batch = seqlens_in_batch.max().item()
    cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
    return indices, cu_seqlens, max_seqlen_in_batch


@@ -24,7 +24,6 @@ import transformers.models
from transformers.activations import ACT2FN

from ...extras import logging


if TYPE_CHECKING:
@@ -344,9 +343,7 @@ _register_composite_model(
    model_type="qwen2_vl",
    projector_key="visual.merger",
    vision_model_keys=["visual.patch_embed", "visual.blocks"],
    language_model_keys=["language_model", "lm_head"],
    lora_conflict_keys=["patch_embed"],
)
@@ -355,9 +352,7 @@ _register_composite_model(
    model_type="qwen2_5_vl",
    projector_key="visual.merger",
    vision_model_keys=["visual.patch_embed", "visual.blocks"],
    language_model_keys=["language_model", "lm_head"],
    lora_conflict_keys=["patch_embed"],
)
@@ -390,7 +385,25 @@ _register_composite_model(
        "visual.deepstack_merger_list",
        "audio_tower",
    ],
    language_model_keys=["language_model", "lm_head"],
    lora_conflict_keys=["patch_embed"],
)


_register_composite_model(
    model_type="qwen3_5",
    projector_key="model.visual.merger",
    vision_model_keys=["visual.pos_embed", "visual.patch_embed", "visual.blocks"],
    language_model_keys=["language_model", "lm_head"],
    lora_conflict_keys=["patch_embed"],
)


_register_composite_model(
    model_type="qwen3_5_moe",
    projector_key="model.visual.merger",
    vision_model_keys=["visual.pos_embed", "visual.patch_embed", "visual.blocks"],
    language_model_keys=["language_model", "lm_head"],
    lora_conflict_keys=["patch_embed"],
)


@@ -30,7 +30,6 @@ from .model_utils.embedding import resize_embedding_layer
from .model_utils.kv_cache import configure_kv_cache
from .model_utils.longlora import configure_longlora
from .model_utils.moe import add_z3_leaf_module, configure_moe
from .model_utils.quantization import configure_quantization
from .model_utils.rope import configure_rope
from .model_utils.valuehead import prepare_valuehead_model
@@ -142,7 +141,6 @@ def patch_config(
    configure_quantization(config, tokenizer, model_args, is_trainable, init_kwargs)
    configure_moe(config, model_args, is_trainable)
    configure_visual_model(config)
    configure_kv_cache(config, model_args, is_trainable)

    if getattr(config, "model_type", None) == "qwen":


@@ -228,7 +228,7 @@ class LogCallback(TrainerCallback):
        if (
            args.should_save
            and os.path.exists(os.path.join(args.output_dir, TRAINER_LOG))
            and getattr(args, "overwrite_output_dir", False)
        ):
            logger.warning_rank0_once("Previous trainer log in this folder will be deleted.")
            os.remove(os.path.join(args.output_dir, TRAINER_LOG))
@@ -371,6 +371,18 @@ class ReporterCallback(TrainerCallback):
                }
            )

        if "trackio" in args.report_to:
            import trackio

            trackio.config.update(
                {
                    "model_args": self.model_args.to_dict(),
                    "data_args": self.data_args.to_dict(),
                    "finetuning_args": self.finetuning_args.to_dict(),
                    "generating_args": self.generating_args.to_dict(),
                }
            )

        if self.finetuning_args.use_swanlab:
            import swanlab  # type: ignore


@@ -12,4 +12,62 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any
import torch.nn.functional as F
from mcore_adapter.trainer import McaTrainer
from torch import Tensor
from transformers import PreTrainedTokenizerBase
from typing_extensions import override

from ...extras.constants import IGNORE_INDEX


class CustomMcaTrainer(McaTrainer):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    @override
    def _pad_batched_inputs(self, inputs: dict[str, Tensor | Any], seq_length: int):
        r"""Override to avoid padding error when handling 3d posids."""
        padding_inputs = {
            k: v.tolist() if v is not None and isinstance(v, Tensor) else v
            for k, v in inputs.items()
            if k in self._language_input_names
        }
        position_ids_3d = None
        if isinstance(inputs.get("position_ids"), Tensor) and inputs["position_ids"].dim() == 3:
            position_ids_3d = inputs["position_ids"]
            padding_inputs.pop("position_ids", None)

        if "labels" in padding_inputs:
            padding_inputs["labels"] = [
                labels + [IGNORE_INDEX] * (seq_length - len(labels)) for labels in padding_inputs["labels"]
            ]

        tokenizer = (
            self.processing_class
            if isinstance(self.processing_class, PreTrainedTokenizerBase)
            else getattr(self.processing_class, "tokenizer", self.processing_class)
        )
        padding_side = getattr(tokenizer, "padding_side", "right")
        padding_inputs = tokenizer.pad(
            padding_inputs,
            padding="max_length",
            max_length=seq_length,
            return_tensors="pt",
        ).to(self.args.device)
        inputs.update(padding_inputs)
        if position_ids_3d is not None:
            current_seq_len = position_ids_3d.size(-1)
            if current_seq_len < seq_length:
                pad_len = seq_length - current_seq_len
                if padding_side == "left":
                    position_ids_3d = F.pad(position_ids_3d, (pad_len, 0), value=0)
                else:
                    position_ids_3d = F.pad(position_ids_3d, (0, pad_len), value=0)

            inputs["position_ids"] = position_ids_3d.to(self.args.device)

        return inputs
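For context on the 3-D branch above: mrope-style `position_ids` have shape `(3, batch, seq_len)`, and `F.pad` with a two-element pad tuple only touches the last (sequence) dimension, which is exactly what the override relies on. A minimal sketch with a hypothetical tensor:

```python
import torch
import torch.nn.functional as F

# hypothetical 3-D rope position ids: (3 sections, batch=1, seq_len=5)
position_ids = torch.arange(5).expand(3, 1, 5)
seq_length = 8
pad_len = seq_length - position_ids.size(-1)

# a (left, right) pad tuple pads only the last dimension, leaving the
# section and batch dimensions untouched
right_padded = F.pad(position_ids, (0, pad_len), value=0)
left_padded = F.pad(position_ids, (pad_len, 0), value=0)
```

This is why the trainer can pad the 3-D tensor separately from the 2-D inputs that `tokenizer.pad` handles.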


@@ -13,10 +13,13 @@
# limitations under the License.

import functools
import json
import os
from collections.abc import Sequence
from copy import deepcopy
from typing import TYPE_CHECKING, Any, Optional

import torch
from transformers import DataCollatorForSeq2Seq

from ...data import (
@@ -41,9 +44,10 @@ if not is_mcore_adapter_available():
from mcore_adapter.models import AutoConfig, AutoModel
from mcore_adapter.trainer import DPOTrainer as McaDPOTrainer
from mcore_adapter.trainer.dpo_config import DPOConfig

from .trainer import CustomMcaTrainer


if TYPE_CHECKING:
    from mcore_adapter.training_args import Seq2SeqTrainingArguments as McaSeq2SeqTrainingArguments
@@ -70,37 +74,53 @@ def _data_collator_wrapper(data_collator: Any):
            for k in ["attention_mask", "position_ids"]:
                if k in feature:
                    feature[k] = feature[k][:-1]

        # for qwen vl series model
        tmp_features = data_collator(features)
        tmp_features.pop("rope_deltas", None)
        position_ids = tmp_features.get("position_ids", None)
        if position_ids is not None and position_ids.dim() == 3:
            if position_ids.shape[0] == 4:
                position_ids = position_ids[1:]

            tmp_features["position_ids"] = position_ids

        return tmp_features

    return wrapper
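The 4-row handling in the wrapper can be illustrated in isolation. Assuming (hypothetically) a collator that stacks an extra row on top of the three mrope rows, the wrapper keeps only the trailing three:

```python
import torch

# hypothetical (4, batch, seq) position-id stack from a qwen-vl style collator
position_ids = torch.zeros(4, 2, 10, dtype=torch.long)

# mirror of the wrapper's check: drop the extra leading row, keep 3 mrope rows
if position_ids.dim() == 3 and position_ids.shape[0] == 4:
    position_ids = position_ids[1:]
```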
def _check_model_support(model_args: "ModelArguments"):
    from transformers import AutoConfig as HfAutoConfig

    if os.path.exists(os.path.join(model_args.model_name_or_path, "mca_config.json")):  # load from mcore ckpt
        mca_config = json.load(open(os.path.join(model_args.model_name_or_path, "mca_config.json")))
        model_type = mca_config.get("hf_model_type", None)
    else:
        config = HfAutoConfig.from_pretrained(
            model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code
        )
        model_type = config.model_type

    if model_type not in MCA_SUPPORTED_MODELS:
        raise ValueError(
            f"Model {model_type} is not supported by mcore_adapter."
            "You can try to upgrade mcore_adapter to the latest version for more supported models."
        )
def _freeze_model_parameters(model: Any, finetuning_args: "FinetuningArguments"):
    """Freeze model parameters for qwen_vl series models based on finetuning arguments."""
    if getattr(model.config, "hf_model_type", None) not in ["qwen2_vl", "qwen2_5_vl", "qwen3_vl", "qwen3_vl_moe", "qwen3_5", "qwen3_5_moe"]:
        return

    params_to_freeze = []
    if finetuning_args.freeze_vision_tower:
        params_to_freeze.extend(["vision_model.blocks", "vision_model.patch_embed"])
        if getattr(model.config, "hf_model_type", None) in ["qwen3_vl", "qwen3_vl_moe", "qwen3_5", "qwen3_5_moe"]:
            params_to_freeze.extend(["vision_model.pos_embed"])

    if finetuning_args.freeze_multi_modal_projector:
        params_to_freeze.extend(["vision_model.merger"])

    if finetuning_args.freeze_language_model:
        params_to_freeze.extend(["embedding", "decoder", "output_layer"])
@@ -110,6 +130,28 @@ def _freeze_model_parameters(model: Any, finetuning_args: "FinetuningArguments")
        if any(name.startswith(k) for k in params_to_freeze):
            p.requires_grad_(False)
def _build_meta_hf_model_for_collator(model_args: "ModelArguments") -> Any | None:
    r"""Build a lightweight HF model on meta device for compatibility with collator."""
    from transformers import AutoConfig as HfAutoConfig
    from transformers import AutoModel as HfAutoModel
    from transformers import AutoModelForImageTextToText

    try:
        config = HfAutoConfig.from_pretrained(
            model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code
        )
        with torch.device("meta"):
            try:
                # Prefer multimodal auto class for VLMs (e.g. qwen2-vl), so get_rope_index is available.
                return AutoModelForImageTextToText.from_config(config)
            except Exception:
                return HfAutoModel.from_config(config)
    except Exception as exc:
        logger.warning("Failed to build meta HF model for collator, fallback to no model. Error: %s", exc)
        return None
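The reason the meta device makes this helper "lightweight": instantiating a module under `torch.device("meta")` records parameter shapes and dtypes but allocates no real storage, so even a large model is essentially free to build. A minimal sketch with a plain `Linear` layer:

```python
import torch

# Under the meta device, parameters carry only metadata (shape, dtype),
# no tensor data, so this costs almost nothing regardless of size.
with torch.device("meta"):
    layer = torch.nn.Linear(4096, 4096)

assert layer.weight.is_meta  # shapes exist, storage does not
```

The collator only needs attributes and methods hanging off the model (such as rope-index helpers), never the weights, which is why a meta-device instance suffices.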
def run_pt(
    model_args: "ModelArguments",
    data_args: "DataArguments",
@@ -135,7 +177,7 @@ def run_pt(
    )
    data_collator = _data_collator_wrapper(data_collator)

    trainer = CustomMcaTrainer(
        model=model,
        args=training_args,
        tokenizer=tokenizer,
@@ -185,6 +227,7 @@ def run_sft(
    _check_model_support(model_args)
    model = AutoModel.from_pretrained(model_args.model_name_or_path, training_args)
    collator_model = _build_meta_hf_model_for_collator(model_args)

    # optional freezing for qwen_vl series
    _freeze_model_parameters(model, finetuning_args)
@@ -192,6 +235,7 @@ def run_sft(
    pad_to_max = training_args.expert_model_parallel_size is not None and training_args.expert_model_parallel_size > 1
    data_collator = SFTDataCollatorWith4DAttentionMask(
        template=template,
        model=collator_model,
        padding="max_length" if pad_to_max else "longest",
        max_length=data_args.cutoff_len if pad_to_max else None,
        pad_to_multiple_of=64,
@@ -200,7 +244,7 @@ def run_sft(
    )
    data_collator = _data_collator_wrapper(data_collator)

    trainer = CustomMcaTrainer(
        model=model,
        args=training_args,
        tokenizer=tokenizer,
@@ -239,6 +283,7 @@ def run_dpo(
    _check_model_support(model_args)
    model = AutoModel.from_pretrained(model_args.model_name_or_path, training_args)
    collator_model = _build_meta_hf_model_for_collator(model_args)

    _freeze_model_parameters(model, finetuning_args)
@@ -262,6 +307,7 @@ def run_dpo(
    )
    data_collator = PairwiseDataCollatorWithPadding(
        template=template,
        model=collator_model,
        pad_to_multiple_of=64,
        padding="max_length" if pad_to_max else "longest",
        max_length=data_args.cutoff_len if pad_to_max else None,


@@ -215,7 +215,13 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer):
            if len(pad_len):  # move pad token to last
                preds[i] = np.concatenate((preds[i][pad_len[0] :], preds[i][: pad_len[0]]), axis=-1)

        input_ids_column = dataset["input_ids"]
        try:
            input_ids_list = input_ids_column.to_pylist()
        except AttributeError:
            input_ids_list = list(input_ids_column)

        decoded_inputs = self.processing_class.batch_decode(input_ids_list, skip_special_tokens=False)
        decoded_preds = self.processing_class.batch_decode(preds, skip_special_tokens=skip_special_tokens)
        decoded_labels = self.processing_class.batch_decode(labels, skip_special_tokens=skip_special_tokens)
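The `to_pylist()`/`list()` fallback above is a duck-typing pattern: Arrow-backed dataset columns expose `to_pylist()`, while plain Python sequences do not. A standalone sketch with a hypothetical stub column class (standing in for a pyarrow-backed column):

```python
class ArrowLikeColumn:
    """Stub standing in for an Arrow-backed dataset column (hypothetical)."""

    def __init__(self, rows):
        self._rows = rows

    def to_pylist(self):
        return list(self._rows)


def column_to_list(column):
    # Prefer the Arrow fast path; fall back for plain sequences.
    try:
        return column.to_pylist()
    except AttributeError:
        return list(column)
```

Catching `AttributeError` (rather than type-checking) keeps the helper working for any backend that grows or drops the Arrow interface.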


@@ -65,6 +65,7 @@ def run_sft(
        pad_to_multiple_of=8 if training_args.do_train else None,  # for shift short attention
        label_pad_token_id=IGNORE_INDEX if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id,
        block_diag_attn=model_args.block_diag_attn,
        neat_packing=data_args.neat_packing,
        attn_implementation=getattr(model.config, "_attn_implementation", None),
        compute_dtype=model_args.compute_dtype,
        **tokenizer_module,


@@ -52,6 +52,7 @@ if is_ray_available():
    import ray
    from ray.util.placement_group import PlacementGroup, placement_group
    from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy
    from ray.util.state import list_nodes


if TYPE_CHECKING:
@@ -941,7 +942,7 @@ def get_ray_remote_config_for_worker(
def get_ray_head_node_ip() -> str:
    r"""Get the IP address of the Ray head node."""
    head_ip = next(node["node_ip"] for node in list_nodes() if node.get("is_head_node", False))
    return head_ip


@@ -21,6 +21,7 @@ from omegaconf import OmegaConf
from transformers import HfArgumentParser

from ..utils.env import is_env_enabled
from ..utils.helper import set_seed
from .data_args import DataArguments
from .model_args import ModelArguments
from .sample_args import SampleArguments
@@ -56,6 +57,14 @@ def get_args(args: InputArgument = None) -> tuple[ModelArguments, DataArguments,
print(f"Got unknown args, potentially deprecated arguments: {unknown_args}") print(f"Got unknown args, potentially deprecated arguments: {unknown_args}")
raise ValueError(f"Some specified arguments are not used by the HfArgumentParser: {unknown_args}") raise ValueError(f"Some specified arguments are not used by the HfArgumentParser: {unknown_args}")
# Seed as early as possible after argument parsing so all downstream
# components (dist init, dataloader, model init in run_* entrypoints) share the same RNG state.
for arg in parsed_args:
seed = getattr(arg, "seed", None)
if seed is not None:
set_seed(seed)
break
return tuple(parsed_args)

View File

@@ -66,7 +66,7 @@ class TrainingArguments:
metadata={"help": "Number of workers for batching."},
)
enable_activation_checkpointing: bool = field(
-default=True,
+default=False,
metadata={"help": "Enable activation checkpointing for training."},
)
dist_config: PluginConfig | None = field(
@@ -81,6 +81,10 @@ class TrainingArguments:
default=None,
metadata={"help": "Learning rate scheduler configuration for training."},
)
seed: int = field(
default=42,
metadata={"help": "Random seed that will be set at the beginning of training."},
)
def __post_init__(self) -> None:
self.dist_config = get_plugin_config(self.dist_config)

View File

@@ -76,7 +76,7 @@ class BaseTrainer:
if self.args.enable_activation_checkpointing:
self.model.gradient_checkpointing_enable({"use_reentrant": False})
-self._accelerate_engine = None
+self._deepspeed_engine = None
dist_name = self.args.dist_config.name if self.args.dist_config is not None else None
if dist_name == "deepspeed":
@@ -108,6 +108,7 @@ class BaseTrainer:
cutoff_len=self.args.cutoff_len,
batching_workers=self.args.batching_workers,
batching_strategy=self.args.batching_strategy,
seed=self.args.seed,
)
def _shard_model(self) -> None:

View File

@@ -26,6 +26,7 @@
from collections.abc import Iterator
from typing import Any
import torch
from torch.utils.data import default_collate
from torchdata.stateful_dataloader import StatefulDataLoader
from torchdata.stateful_dataloader.sampler import StatefulDistributedSampler
@@ -71,6 +72,7 @@ class BatchGenerator(Iterator):
batching_strategy: BatchingStrategy = BatchingStrategy.NORMAL,
pin_memory: bool = True,
drop_last: bool = True,
seed: int = 42,
) -> None:
self.dataset = dataset
self.renderer = renderer
@@ -82,6 +84,7 @@ class BatchGenerator(Iterator):
self.batching_strategy = batching_strategy
self.pin_memory = pin_memory
self.drop_last = drop_last
self.seed = seed
# TODO: support length and infinity
dp_size = DistributedInterface().get_world_size(Dim.DP)
@@ -128,12 +131,15 @@ class BatchGenerator(Iterator):
num_replicas=DistributedInterface().get_world_size(Dim.DP),
rank=DistributedInterface().get_rank(Dim.DP),
shuffle=True,
-seed=0,
+seed=self.seed,
drop_last=self.drop_last,
)
else:
raise NotImplementedError("Iterable dataset is not supported yet.")
generator = torch.Generator()
generator.manual_seed(self.seed)
self._data_provider = StatefulDataLoader(
self.dataset,
batch_size=self.micro_batch_size * self.num_micro_batch,
@@ -143,6 +149,7 @@
pin_memory=self.pin_memory,
pin_memory_device=DistributedInterface().current_device.type,
drop_last=self.drop_last,
generator=generator,
)
if self.batching_strategy == BatchingStrategy.NORMAL:
self._length = len(self._data_provider)

View File

@@ -91,7 +91,11 @@ class Renderer:
self.processor = processor
def render_messages(
-self, messages: list[Message], tools: str | None = None, is_generate: bool = False
+self,
+messages: list[Message],
+tools: str | None = None,
+is_generate: bool = False,
+enable_thinking: bool = False,
) -> ModelInput:
"""Apply template to messages and convert them to model input.
@@ -99,6 +103,7 @@ class Renderer:
messages (list[Message]): The messages to render.
tools (str | None, optional): The tools to use. Defaults to None.
is_generate (bool, optional): Whether to render for generation. Defaults to False.
enable_thinking (bool, optional): Whether to enable thinking mode for generation. Defaults to False.
Returns:
ModelInput: The rendered model input.
@@ -108,7 +113,9 @@ class Renderer:
else:
from ...plugins.model_plugins.rendering import RenderingPlugin
-return RenderingPlugin(self.template).render_messages(self.processor, messages, tools, is_generate)
+return RenderingPlugin(self.template).render_messages(
+self.processor, messages, tools, is_generate, enable_thinking
+)
def parse_message(self, generated_text: str) -> Message:
"""Parse a message in the template format.

View File

@@ -150,6 +150,9 @@ def load_adapter(model: HFModel, adapter_name_or_path: Union[list[str], str], is
@PeftPlugin("lora").register()
def get_lora_model(model: HFModel, config: LoraConfigDict, is_train: bool = False) -> HFModel:
if model.device.type == "meta":
raise ValueError("Currently lora stage does not support loading model by meta.")
adapter_name_or_path = config.get("adapter_name_or_path")
if adapter_name_or_path:

View File

@@ -12,224 +12,45 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-import json
-import re
+import importlib
-from ...utils.constants import IGNORE_INDEX
-from ...utils.helper import get_tokenizer
+from ...utils import logging
from ...utils.plugin import BasePlugin
-from ...utils.types import Message, ModelInput, Processor, ToolCall
+from ...utils.types import Message, ModelInput, Processor
logger = logging.get_logger(__name__)
class RenderingPlugin(BasePlugin):
_attempted_template_imports: set[str] = set()
def _ensure_template_imported(self) -> None:
if self.name is None or self.name in self._attempted_template_imports:
return
full_module_name = f"{__package__}.templates.{self.name}"
self._attempted_template_imports.add(self.name)
try:
importlib.import_module(full_module_name)
except Exception as exc:
logger.warning(f"[Template Registry] Failed to import {full_module_name}: {exc}")
def __getitem__(self, method_name: str):
self._ensure_template_imported()
return super().__getitem__(method_name)
def render_messages(
self,
processor: Processor,
messages: list[Message],
tools: str | None = None,
is_generate: bool = False,
enable_thinking: bool = False,
) -> ModelInput:
"""Render messages in the template format."""
-return self["render_messages"](processor, messages, tools, is_generate)
+return self["render_messages"](processor, messages, tools, is_generate, enable_thinking)
def parse_messages(self, generated_text: str) -> Message:
"""Parse messages in the template format."""
return self["parse_messages"](generated_text)
def _update_model_input(
processor: Processor,
input_ids: list[int],
labels: list[int],
loss_weights: list[int],
temp_str: str,
temp_weight: float,
) -> str:
"""Update model input with temporary string."""
if not temp_str:
return ""
tokenizer = get_tokenizer(processor)
temp_ids = tokenizer.encode(temp_str, add_special_tokens=False)
input_ids.extend(temp_ids)
loss_weights.extend([temp_weight] * len(temp_ids))
if temp_weight > 1e-6:
labels.extend(temp_ids)
else:
labels.extend([IGNORE_INDEX] * len(temp_ids))
return ""
@RenderingPlugin("qwen3_nothink").register("render_messages")
def render_qwen3_nothink_messages(
processor: Processor,
messages: list[Message],
tools: str | None = None,
is_generate: bool = False,
) -> ModelInput:
"""Render messages in the Qwen3 nothink template format.
See https://huggingface.co/spaces/huggingfacejs/chat-template-playground?modelId=Qwen/Qwen3-4B-Instruct-2507
"""
input_ids, labels, loss_weights = [], [], []
temp_str, temp_weight = "", 0.0
if tools:
temp_str += "<|im_start|>system\n"
if messages[0]["role"] == "system":
for content in messages[0]["content"]:
if content["type"] == "text":
temp_str += content["value"]
else:
raise ValueError(f"Unsupported content type: {content['type']}")
temp_str += "\n\n"
temp_weight = messages[0].get("loss_weight", 0.0)
temp_str += (
"# Tools\n\nYou may call one or more functions to assist with the user query.\n\n"
"You are provided with function signatures within <tools></tools> XML tags:\n<tools>"
)
try:
tools = json.loads(tools)
except json.JSONDecodeError:
raise ValueError(f"Invalid tools format: {str(tools)}.")
if not isinstance(tools, list):
tools = [tools]
for tool in tools:
temp_str += "\n" + json.dumps(tool, ensure_ascii=False)
temp_str += (
"\n</tools>\n\nFor each function call, return a json object with function name "
'and arguments within <tool_call></tool_call> XML tags:\n<tool_call>\n{"name": '
'<function-name>, "arguments": <args-json-object>}\n</tool_call><|im_end|>\n'
)
elif messages[0]["role"] == "system":
temp_str += "<|im_start|>system\n"
for content in messages[0]["content"]:
if content["type"] == "text":
temp_str += content["value"]
else:
raise ValueError(f"Unsupported content type: {content['type']}")
temp_str += "<|im_end|>\n"
temp_weight = messages[0].get("loss_weight", 0.0)
temp_str = _update_model_input(processor, input_ids, labels, loss_weights, temp_str, temp_weight)
for turn_idx, message in enumerate(messages):
if message["role"] == "user" or (message["role"] == "system" and turn_idx != 0):
temp_str += "<|im_start|>" + message["role"] + "\n"
for content in message["content"]:
if content["type"] == "text":
temp_str += content["value"]
else:
raise ValueError(f"Unsupported content type: {content['type']}")
temp_str += "<|im_end|>\n"
temp_weight = message.get("loss_weight", 0.0)
elif message["role"] == "assistant":
temp_str += "<|im_start|>" + message["role"] + "\n"
for val_idx, content in enumerate(message["content"]):
if content["type"] == "text":
temp_str += content["value"]
elif content["type"] == "reasoning":
temp_str += "<thinking>\n" + content["value"] + "\n</thinking>\n\n" # avoid using special tokens
elif content["type"] == "tool_call":
if val_idx != 0 and message["content"][val_idx - 1]["type"] in ["text", "tool_call"]:
temp_str += "\n"
try:
tool_call: ToolCall = json.loads(content["value"])
except json.JSONDecodeError:
raise ValueError(f"Invalid tool call format: {content['value']}.")
temp_str += (
'<tool_call>\n{"name": "'
+ tool_call["name"]
+ '", "arguments": '
+ json.dumps(tool_call["arguments"], ensure_ascii=False)
+ "}\n</tool_call>"
)
else:
raise ValueError(f"Unsupported content type: {content['type']}")
temp_str += "<|im_end|>\n"
temp_weight = message.get("loss_weight", 1.0)
elif message["role"] == "tool":
if turn_idx == 0 or messages[turn_idx - 1]["role"] != "tool":
temp_str += "<|im_start|>user"
temp_str += "\n<tool_response>\n"
for content in message["content"]:
if content["type"] == "text":
temp_str += content["value"]
else:
raise ValueError(f"Unsupported content type: {content['type']}")
temp_str += "\n</tool_response>"
if turn_idx == len(messages) - 1 or messages[turn_idx + 1]["role"] != "tool":
temp_str += "<|im_end|>\n"
temp_weight = message.get("loss_weight", 0.0)
temp_str = _update_model_input(processor, input_ids, labels, loss_weights, temp_str, temp_weight)
if is_generate:
temp_str += "<|im_start|>assistant\n"
temp_weight = 0.0
temp_str = _update_model_input(processor, input_ids, labels, loss_weights, temp_str, temp_weight)
attention_mask = [1] * len(input_ids)
return ModelInput(
input_ids=input_ids,
attention_mask=attention_mask,
labels=labels,
loss_weights=loss_weights,
)
@RenderingPlugin("qwen3_nothink").register("parse_message")
def parse_qwen3_nothink_message(generated_text: str) -> Message:
"""Parse a message in the Qwen3 nothink template format. Supports interleaved reasoning and tool calls.
Args:
generated_text (str): The generated text in the Qwen3 nothink template format.
Returns:
Message: The parsed message.
"""
pattern = re.compile(r"<(thinking|tool_call)>\s*(.*?)\s*</\1>\s*", re.DOTALL)
content = []
last_end = 0
for match in pattern.finditer(generated_text):
start, end = match.span()
if start > last_end:
text = generated_text[last_end:start].strip()
if text:
content.append({"type": "text", "value": text})
tag_type = match.group(1)
tag_value = match.group(2).strip()
if tag_type == "thinking":
content.append({"type": "reasoning", "value": tag_value.strip()})
elif tag_type == "tool_call":
try:
json.loads(tag_value.strip())
except json.JSONDecodeError:
raise ValueError(f"Invalid tool call format: {tag_value.strip()}.")
content.append({"type": "tool_call", "value": tag_value.strip()})
last_end = end
if last_end < len(generated_text):
text = generated_text[last_end:].strip()
if text:
content.append({"type": "text", "value": text})
return Message(role="assistant", content=content)

View File

@@ -0,0 +1,13 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

View File

@@ -0,0 +1,259 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import re
from ....utils.constants import IGNORE_INDEX
from ....utils.helper import get_tokenizer
from ....utils.types import Message, ModelInput, Processor, ToolCall
from ..rendering import RenderingPlugin
def _update_model_input(
processor: Processor,
input_ids: list[int],
labels: list[int],
loss_weights: list[float],
temp_str: str,
temp_weight: float,
) -> str:
"""Update model input with temporary string."""
if not temp_str:
return ""
tokenizer = get_tokenizer(processor)
temp_ids = tokenizer.encode(temp_str, add_special_tokens=False)
input_ids.extend(temp_ids)
loss_weights.extend([temp_weight] * len(temp_ids))
if temp_weight > 1e-6:
labels.extend(temp_ids)
else:
labels.extend([IGNORE_INDEX] * len(temp_ids))
return ""
def _concat_text_content(message: Message) -> str:
"""Concatenate text fields in a message."""
message_text = ""
for content in message["content"]:
if content["type"] == "text":
message_text += content["value"]
else:
raise ValueError(f"Unsupported content type: {content['type']}")
return message_text
def _get_last_query_index(messages: list[Message]) -> int:
"""Find the last user query index, excluding wrapped tool responses."""
last_query_index = len(messages) - 1
for idx in range(len(messages) - 1, -1, -1):
message = messages[idx]
if message["role"] != "user":
continue
user_text = ""
is_plain_text = True
for content in message["content"]:
if content["type"] != "text":
is_plain_text = False
break
user_text += content["value"]
if not is_plain_text:
continue
if not (user_text.startswith("<tool_response>") and user_text.endswith("</tool_response>")):
last_query_index = idx
break
return last_query_index
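The helper above walks the conversation backwards to find the last *real* user query, treating a user turn whose entire text is a wrapped `<tool_response>` as a continuation rather than a query. A self-contained, slightly simplified restatement (text-only content, plain dicts standing in for the `Message` type, and the non-text guard dropped):

```python
def get_last_query_index(messages: list[dict]) -> int:
    # Mirror of _get_last_query_index: walk backwards to the last user turn
    # whose text is not a wrapped <tool_response>...</tool_response>.
    last_query_index = len(messages) - 1
    for idx in range(len(messages) - 1, -1, -1):
        message = messages[idx]
        if message["role"] != "user":
            continue
        user_text = "".join(c["value"] for c in message["content"] if c["type"] == "text")
        if not (user_text.startswith("<tool_response>") and user_text.endswith("</tool_response>")):
            last_query_index = idx
            break
    return last_query_index

messages = [
    {"role": "user", "content": [{"type": "text", "value": "What is the weather?"}]},
    {"role": "assistant", "content": [{"type": "tool_call", "value": "{}"}]},
    {"role": "user", "content": [{"type": "text", "value": "<tool_response>sunny</tool_response>"}]},
    {"role": "assistant", "content": [{"type": "text", "value": "It is sunny."}]},
]
print(get_last_query_index(messages))  # → 0: the wrapped tool response at index 2 is skipped
```

This index decides which assistant turns get an explicit `<think>` block during rendering: only turns after the last genuine user query are eligible.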
def _split_assistant_content(message: Message) -> tuple[str, str, list[ToolCall]]:
"""Split assistant message into text, reasoning and tool calls."""
text_content = ""
reasoning_content = ""
tool_calls: list[ToolCall] = []
for content in message["content"]:
if content["type"] == "text":
text_content += content["value"]
elif content["type"] == "reasoning":
reasoning_content += content["value"]
elif content["type"] == "tool_call":
try:
tool_call: ToolCall = json.loads(content["value"])
except json.JSONDecodeError:
raise ValueError(f"Invalid tool call format: {content['value']}.")
tool_calls.append(tool_call)
else:
raise ValueError(f"Unsupported content type: {content['type']}")
return text_content, reasoning_content, tool_calls
@RenderingPlugin("qwen3").register("render_messages")
def render_qwen3_messages(
processor: Processor,
messages: list[Message],
tools: str | None = None,
is_generate: bool = False,
enable_thinking: bool = False,
) -> ModelInput:
"""Render messages in the Qwen3 template format.
See https://huggingface.co/spaces/huggingfacejs/chat-template-playground?modelId=Qwen/Qwen3-8B
"""
input_ids, labels, loss_weights = [], [], []
temp_str, temp_weight = "", 0.0
if tools:
temp_str += "<|im_start|>system\n"
if messages[0]["role"] == "system":
temp_str += _concat_text_content(messages[0]) + "\n\n"
temp_weight = messages[0].get("loss_weight", 0.0)
temp_str += (
"# Tools\n\nYou may call one or more functions to assist with the user query.\n\n"
"You are provided with function signatures within <tools></tools> XML tags:\n<tools>"
)
try:
tools = json.loads(tools)
except json.JSONDecodeError:
raise ValueError(f"Invalid tools format: {str(tools)}.")
if not isinstance(tools, list):
tools = [tools]
for tool in tools:
temp_str += "\n" + json.dumps(tool, ensure_ascii=False)
temp_str += (
"\n</tools>\n\nFor each function call, return a json object with function name "
'and arguments within <tool_call></tool_call> XML tags:\n<tool_call>\n{"name": '
'<function-name>, "arguments": <args-json-object>}\n</tool_call><|im_end|>\n'
)
elif messages[0]["role"] == "system":
temp_str += "<|im_start|>system\n" + _concat_text_content(messages[0]) + "<|im_end|>\n"
temp_weight = messages[0].get("loss_weight", 0.0)
temp_str = _update_model_input(processor, input_ids, labels, loss_weights, temp_str, temp_weight)
last_query_index = _get_last_query_index(messages)
for turn_idx, message in enumerate(messages):
if message["role"] == "user" or (message["role"] == "system" and turn_idx != 0):
temp_str += "<|im_start|>" + message["role"] + "\n" + _concat_text_content(message) + "<|im_end|>\n"
temp_weight = message.get("loss_weight", 0.0)
elif message["role"] == "assistant":
temp_str += "<|im_start|>" + message["role"] + "\n"
text_content, reasoning_content, tool_calls = _split_assistant_content(message)
if turn_idx > last_query_index and (turn_idx == len(messages) - 1 or reasoning_content):
temp_str += "<think>\n" + reasoning_content.strip("\n") + "\n</think>\n\n" + text_content.lstrip("\n")
else:
temp_str += text_content
for tool_call_idx, tool_call in enumerate(tool_calls):
if (tool_call_idx == 0 and text_content) or tool_call_idx > 0:
temp_str += "\n"
arguments = tool_call.get("arguments")
if isinstance(arguments, str):
arguments_str = arguments
else:
arguments_str = json.dumps(arguments, ensure_ascii=False)
temp_str += (
'<tool_call>\n{"name": "'
+ tool_call["name"]
+ '", "arguments": '
+ arguments_str
+ "}\n</tool_call>"
)
temp_str += "<|im_end|>\n"
temp_weight = message.get("loss_weight", 1.0)
elif message["role"] == "tool":
if turn_idx == 0 or messages[turn_idx - 1]["role"] != "tool":
temp_str += "<|im_start|>user"
temp_str += "\n<tool_response>\n" + _concat_text_content(message) + "\n</tool_response>"
if turn_idx == len(messages) - 1 or messages[turn_idx + 1]["role"] != "tool":
temp_str += "<|im_end|>\n"
temp_weight = message.get("loss_weight", 0.0)
temp_str = _update_model_input(processor, input_ids, labels, loss_weights, temp_str, temp_weight)
if is_generate:
temp_str += "<|im_start|>assistant\n"
temp_weight = 0.0
if enable_thinking is False:
temp_str += "<think>\n\n</think>\n\n"
temp_str = _update_model_input(processor, input_ids, labels, loss_weights, temp_str, temp_weight)
attention_mask = [1] * len(input_ids)
return ModelInput(
input_ids=input_ids,
attention_mask=attention_mask,
labels=labels,
loss_weights=loss_weights,
)
@RenderingPlugin("qwen3").register("parse_message")
def parse_qwen3_message(generated_text: str) -> Message:
"""Parse a message in the Qwen3 template format. Supports interleaved reasoning and tool calls.
Args:
generated_text (str): The generated text in the Qwen3 template format.
Returns:
Message: The parsed message.
"""
pattern = re.compile(r"<(think|tool_call)>\s*(.*?)\s*</\1>\s*", re.DOTALL)
content = []
last_end = 0
for match in pattern.finditer(generated_text):
start, end = match.span()
if start > last_end:
text = generated_text[last_end:start].strip()
if text:
content.append({"type": "text", "value": text})
tag_type = match.group(1)
tag_value = match.group(2).strip()
if tag_type == "think":
content.append({"type": "reasoning", "value": tag_value.strip()})
elif tag_type == "tool_call":
try:
json.loads(tag_value.strip())
except json.JSONDecodeError:
raise ValueError(f"Invalid tool call format: {tag_value.strip()}.")
content.append({"type": "tool_call", "value": tag_value.strip()})
last_end = end
if last_end < len(generated_text):
text = generated_text[last_end:].strip()
if text:
content.append({"type": "text", "value": text})
return Message(role="assistant", content=content)
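The parser above can be exercised standalone: the regex and the text/reasoning/tool_call interleaving reduce to the following sketch (same pattern, with the `Message` wrapper left out):

```python
import json
import re

pattern = re.compile(r"<(think|tool_call)>\s*(.*?)\s*</\1>\s*", re.DOTALL)

def parse(generated_text: str) -> list[dict]:
    # Same walk as parse_qwen3_message: interleave plain text with tagged spans.
    content, last_end = [], 0
    for match in pattern.finditer(generated_text):
        start, end = match.span()
        text = generated_text[last_end:start].strip()
        if text:
            content.append({"type": "text", "value": text})
        tag, value = match.group(1), match.group(2).strip()
        if tag == "think":
            content.append({"type": "reasoning", "value": value})
        else:  # tool_call payloads must be valid JSON
            json.loads(value)
            content.append({"type": "tool_call", "value": value})
        last_end = end
    tail = generated_text[last_end:].strip()
    if tail:
        content.append({"type": "text", "value": tail})
    return content

out = parse('<think>\nneed the weather tool\n</think>\n\n'
            '<tool_call>\n{"name": "get_weather", "arguments": {}}\n</tool_call>')
print([c["type"] for c in out])  # → ['reasoning', 'tool_call']
```

The backreference `</\1>` forces the closing tag to match the opening one, so a `<think>` block closed by `</tool_call>` is simply left as plain text rather than misparsed.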

View File

@@ -0,0 +1,209 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import re
from ....utils.constants import IGNORE_INDEX
from ....utils.helper import get_tokenizer
from ....utils.types import Message, ModelInput, Processor, ToolCall
from ..rendering import RenderingPlugin
def _update_model_input(
processor: Processor,
input_ids: list[int],
labels: list[int],
loss_weights: list[float],
temp_str: str,
temp_weight: float,
) -> str:
"""Update model input with temporary string."""
if not temp_str:
return ""
tokenizer = get_tokenizer(processor)
temp_ids = tokenizer.encode(temp_str, add_special_tokens=False)
input_ids.extend(temp_ids)
loss_weights.extend([temp_weight] * len(temp_ids))
if temp_weight > 1e-6:
labels.extend(temp_ids)
else:
labels.extend([IGNORE_INDEX] * len(temp_ids))
return ""
def _concat_text_content(message: Message) -> str:
"""Concatenate text fields in a message."""
message_text = ""
for content in message["content"]:
if content["type"] == "text":
message_text += content["value"]
else:
raise ValueError(f"Unsupported content type: {content['type']}")
return message_text
@RenderingPlugin("qwen3_nothink").register("render_messages")
def render_qwen3_nothink_messages(
processor: Processor,
messages: list[Message],
tools: str | None = None,
is_generate: bool = False,
enable_thinking: bool = False,
) -> ModelInput:
"""Render messages in the Qwen3 nothink template format.
See https://huggingface.co/spaces/huggingfacejs/chat-template-playground?modelId=Qwen/Qwen3-4B-Instruct-2507
"""
input_ids, labels, loss_weights = [], [], []
temp_str, temp_weight = "", 0.0
if tools:
temp_str += "<|im_start|>system\n"
if messages[0]["role"] == "system":
temp_str += _concat_text_content(messages[0]) + "\n\n"
temp_weight = messages[0].get("loss_weight", 0.0)
temp_str += (
"# Tools\n\nYou may call one or more functions to assist with the user query.\n\n"
"You are provided with function signatures within <tools></tools> XML tags:\n<tools>"
)
try:
tools = json.loads(tools)
except json.JSONDecodeError:
raise ValueError(f"Invalid tools format: {str(tools)}.")
if not isinstance(tools, list):
tools = [tools]
for tool in tools:
temp_str += "\n" + json.dumps(tool, ensure_ascii=False)
temp_str += (
"\n</tools>\n\nFor each function call, return a json object with function name "
'and arguments within <tool_call></tool_call> XML tags:\n<tool_call>\n{"name": '
'<function-name>, "arguments": <args-json-object>}\n</tool_call><|im_end|>\n'
)
elif messages[0]["role"] == "system":
temp_str += "<|im_start|>system\n" + _concat_text_content(messages[0]) + "<|im_end|>\n"
temp_weight = messages[0].get("loss_weight", 0.0)
temp_str = _update_model_input(processor, input_ids, labels, loss_weights, temp_str, temp_weight)
for turn_idx, message in enumerate(messages):
if message["role"] == "user" or (message["role"] == "system" and turn_idx != 0):
temp_str += "<|im_start|>" + message["role"] + "\n" + _concat_text_content(message) + "<|im_end|>\n"
temp_weight = message.get("loss_weight", 0.0)
elif message["role"] == "assistant":
temp_str += "<|im_start|>" + message["role"] + "\n"
for val_idx, content in enumerate(message["content"]):
if content["type"] == "text":
temp_str += content["value"]
elif content["type"] == "reasoning":
temp_str += "<thinking>\n" + content["value"] + "\n</thinking>\n\n" # avoid using special tokens
elif content["type"] == "tool_call":
if val_idx != 0 and message["content"][val_idx - 1]["type"] in ["text", "tool_call"]:
temp_str += "\n"
try:
tool_call: ToolCall = json.loads(content["value"])
except json.JSONDecodeError:
raise ValueError(f"Invalid tool call format: {content['value']}.")
temp_str += (
'<tool_call>\n{"name": "'
+ tool_call["name"]
+ '", "arguments": '
+ json.dumps(tool_call["arguments"], ensure_ascii=False)
+ "}\n</tool_call>"
)
else:
raise ValueError(f"Unsupported content type: {content['type']}")
temp_str += "<|im_end|>\n"
temp_weight = message.get("loss_weight", 1.0)
elif message["role"] == "tool":
if turn_idx == 0 or messages[turn_idx - 1]["role"] != "tool":
temp_str += "<|im_start|>user"
temp_str += "\n<tool_response>\n" + _concat_text_content(message) + "\n</tool_response>"
if turn_idx == len(messages) - 1 or messages[turn_idx + 1]["role"] != "tool":
temp_str += "<|im_end|>\n"
temp_weight = message.get("loss_weight", 0.0)
temp_str = _update_model_input(processor, input_ids, labels, loss_weights, temp_str, temp_weight)
if is_generate:
temp_str += "<|im_start|>assistant\n"
temp_weight = 0.0
if enable_thinking:
raise ValueError("The qwen3_nothink template does not support thinking mode.")
temp_str = _update_model_input(processor, input_ids, labels, loss_weights, temp_str, temp_weight)
attention_mask = [1] * len(input_ids)
return ModelInput(
input_ids=input_ids,
attention_mask=attention_mask,
labels=labels,
loss_weights=loss_weights,
)
@RenderingPlugin("qwen3_nothink").register("parse_message")
def parse_qwen3_nothink_message(generated_text: str) -> Message:
"""Parse a message in the Qwen3 nothink template format. Supports interleaved reasoning and tool calls.
Args:
generated_text (str): The generated text in the Qwen3 nothink template format.
Returns:
Message: The parsed message.
"""
pattern = re.compile(r"<(thinking|tool_call)>\s*(.*?)\s*</\1>\s*", re.DOTALL)
content = []
last_end = 0
for match in pattern.finditer(generated_text):
start, end = match.span()
if start > last_end:
text = generated_text[last_end:start].strip()
if text:
content.append({"type": "text", "value": text})
tag_type = match.group(1)
tag_value = match.group(2).strip()
if tag_type == "thinking":
content.append({"type": "reasoning", "value": tag_value.strip()})
elif tag_type == "tool_call":
try:
json.loads(tag_value.strip())
except json.JSONDecodeError:
raise ValueError(f"Invalid tool call format: {tag_value.strip()}.")
content.append({"type": "tool_call", "value": tag_value.strip()})
last_end = end
if last_end < len(generated_text):
text = generated_text[last_end:].strip()
if text:
content.append({"type": "text", "value": text})
return Message(role="assistant", content=content)

View File

@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
import gc
import os
@@ -166,12 +167,11 @@ class FSDP2Engine:
offload_policy=CPUOffloadPolicy(pin_memory=self.pin_memory) if self.offload_params else None,
)
-use_gradient_checkpointing = True  # Could be configurable
-if use_gradient_checkpointing:
+# BaseTrainer is the single source of truth for gradient checkpointing.
+# FSDP2 only applies the input-grad compatibility hook when checkpointing is already enabled.
+if getattr(model, "is_gradient_checkpointing", False):
if self.rank == 0:
-logger.info("Enabling gradient checkpointing (transformers native)...")
-model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False})
+logger.info("Gradient checkpointing is enabled. Applying FSDP2 input grad preparation.")
if hasattr(model, "enable_input_require_grads"):
model.enable_input_require_grads()
@@ -213,10 +213,52 @@ class FSDP2Engine:
return model
def _save_non_persistent_buffers(self, model: HFModel) -> dict:
"""Save non-persistent buffers, such as inv_freq."""
saved = {}
for mod_name, module in model.named_modules():
for buf_name in module._non_persistent_buffers_set:
fqn = f"{mod_name}.{buf_name}" if mod_name else buf_name
buf = getattr(module, buf_name, None)
if buf is not None:
saved[fqn] = copy.deepcopy(buf)
if self.rank == 0 and saved:
logger.info(f"Saved {len(saved)} non-persistent buffers")
return saved
def _restore_non_persistent_buffers(self, model: HFModel, saved_buffers: dict):
"""Register saved non-persistent buffers to model."""
if not saved_buffers:
return
device = get_current_accelerator()
for fqn, buf in saved_buffers.items():
buf = buf.to(device)
if "." in fqn:
parent_fqn, buf_name = fqn.rsplit(".", 1)
parent_module = model.get_submodule(parent_fqn)
else:
buf_name = fqn
parent_module = model
parent_module.register_buffer(buf_name, buf, persistent=False)
if self.rank == 0:
logger.info(f"Restored {len(saved_buffers)} non-persistent buffers")
def shard_model(self, model: HFModel) -> HFModel: def shard_model(self, model: HFModel) -> HFModel:
if model.device.type == "meta": if model.device.type == "meta":
non_persistent_buffers = self._save_non_persistent_buffers(model)
if getattr(model.config, "tie_word_embeddings", None):
model.tie_weights()
model = self.prepare_model(model) model = self.prepare_model(model)
model = self.materialize_and_load(model, hf_model_path=model.config.name_or_path, dcp_path=self.dcp_path) model = self.materialize_and_load(model, hf_model_path=model.config.name_or_path, dcp_path=self.dcp_path)
# fix tied broken for no-fsdp-wrap case
if getattr(model.config, "tie_word_embeddings", None):
model.tie_weights()
self._restore_non_persistent_buffers(model, non_persistent_buffers)
else: else:
model = self.prepare_model(model) model = self.prepare_model(model)
return model return model
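The buffer save/restore pair above is needed because PyTorch excludes buffers registered with `persistent=False` (such as RoPE's `inv_freq`) from `state_dict()`, so materializing a meta-device model from a checkpoint cannot bring them back. A minimal sketch of that behavior with a toy module (the `Rotary` class is hypothetical, not LLaMA-Factory code):

```python
import copy

import torch
import torch.nn as nn


class Rotary(nn.Module):
    def __init__(self):
        super().__init__()
        # non-persistent buffers never enter state_dict, so a checkpoint
        # load after meta materialization cannot restore them
        self.register_buffer("inv_freq", torch.arange(4.0), persistent=False)


module = Rotary()

# save step: snapshot every buffer listed in _non_persistent_buffers_set
saved = {}
for name, buf in module.named_buffers():
    if name in module._non_persistent_buffers_set:
        saved[name] = copy.deepcopy(buf)

assert "inv_freq" not in module.state_dict()  # why saving is necessary

# restore step: re-register the snapshot, still as non-persistent
module.register_buffer("inv_freq", saved["inv_freq"], persistent=False)
assert torch.equal(module.inv_freq, torch.arange(4.0))
```

The real engine does the same walk over `model.named_modules()` so that the saved keys are fully qualified names like `model.layers.0.self_attn.rotary_emb.inv_freq`.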

View File

@@ -15,12 +15,22 @@
 import torch
 from transformers import PreTrainedTokenizer
+from transformers import set_seed as hf_set_seed

 from ..accelerator.interface import DistributedInterface
 from .constants import IGNORE_INDEX
 from .types import BatchInput, ModelInput, Processor, Tensor


+def set_seed(seed: int) -> None:
+    """Set seed for reproducibility.
+
+    Args:
+        seed: Random seed.
+    """
+    hf_set_seed(seed)
+
+
 def is_tokenizer(processor: Processor) -> bool:
     """Check if processor is tokenizer.
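The new `set_seed` is a thin pass-through to transformers' `set_seed`, which seeds Python's `random`, NumPy, and torch in one call. A stdlib-only sketch of the reproducibility contract it provides (seeding only Python's RNG here, to keep the example dependency-free):

```python
import random


def set_seed(seed: int) -> None:
    # stand-in for hf_set_seed: the real helper also seeds numpy and torch
    random.seed(seed)


set_seed(42)
first = [random.random() for _ in range(3)]
set_seed(42)
second = [random.random() for _ in range(3)]
assert first == second  # same seed, identical draws
```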

View File

@@ -85,7 +85,7 @@ class DistributedConfig(TypedDict, total=False):
 class Content(TypedDict):
-    type: Literal["text", "reasoning", "tool_call", "image_url"]
+    type: Literal["text", "reasoning", "tool_call", "image_url", "video_url", "audio_url"]
     """Type of the content."""

     value: str
     """Value of the content."""

View File

@@ -108,11 +108,26 @@ def create_train_tab(engine: "Engine") -> dict[str, "Component"]:
             with gr.Column():
                 enable_thinking = gr.Checkbox(value=True)
                 report_to = gr.Dropdown(
-                    choices=["none", "wandb", "mlflow", "neptune", "tensorboard", "all"],
+                    choices=["none", "wandb", "mlflow", "neptune", "tensorboard", "trackio", "all"],
                     value="none",
                     allow_custom_value=True,
                 )

+        with gr.Accordion("Trackio Settings", open=False):
+            project = gr.Textbox(
+                value="huggingface",
+                label="Project Name",
+                info="Project name for experiment tracking (used by Trackio, W&B, etc.)",
+            )
+            trackio_space_id = gr.Textbox(
+                value="trackio", label="Trackio Space ID", info="Hugging Face Space ID for Trackio deployment"
+            )
+            hub_private_repo = gr.Checkbox(
+                value=False, label="Private Repository", info="Make the Hugging Face repository private"
+            )
+
     input_elems.update(
         {
             logging_steps,
@@ -128,6 +143,9 @@ def create_train_tab(engine: "Engine") -> dict[str, "Component"]:
             use_llama_pro,
             enable_thinking,
             report_to,
+            project,
+            trackio_space_id,
+            hub_private_repo,
         }
     )
     elem_dict.update(
@@ -146,6 +164,9 @@ def create_train_tab(engine: "Engine") -> dict[str, "Component"]:
             use_llama_pro=use_llama_pro,
             enable_thinking=enable_thinking,
             report_to=report_to,
+            project=project,
+            trackio_space_id=trackio_space_id,
+            hub_private_repo=hub_private_repo,
         )
     )

View File

@@ -13,6 +13,7 @@
 # limitations under the License.
 import os
+from collections import Counter

 import pytest
 import torch
@@ -22,6 +23,7 @@ from transformers import AutoConfig, AutoModelForImageTextToText
 from llamafactory.data import get_template_and_fix_tokenizer
 from llamafactory.data.collator import MultiModalDataCollatorForSeq2Seq, prepare_4d_attention_mask
 from llamafactory.extras.constants import IGNORE_INDEX
+from llamafactory.extras.packages import is_transformers_version_greater_than
 from llamafactory.hparams import get_infer_args
 from llamafactory.model import load_tokenizer
@@ -116,19 +118,189 @@ def test_multimodal_collator():
         "labels": [
             [0, 1, 2, 3, q, q, q, q, q, q, q, q],
         ],
-        "position_ids": [
-            [[0, 1, 2, 3, 1, 1, 1, 1, 1, 1, 1, 1]],
-            [[0, 1, 2, 3, 1, 1, 1, 1, 1, 1, 1, 1]],
-            [[0, 1, 2, 3, 1, 1, 1, 1, 1, 1, 1, 1]],
-        ],
-        "rope_deltas": [[-8]],
+        "position_ids": [[[0, 1, 2, 3, 0, 0, 0, 0, 0, 0, 0, 0]]] * 3,
+        "rope_deltas": [[0]],
         **tokenizer_module["processor"].image_processor(fake_image),
     }
+    if not is_transformers_version_greater_than("5.0.0"):
+        # adapt position_ids and rope_deltas for transformers < 5.0.0
+        # https://github.com/huggingface/transformers/pull/43972
+        expected_input["position_ids"] = [[[0, 1, 2, 3, 1, 1, 1, 1, 1, 1, 1, 1]]] * 3
+        expected_input["rope_deltas"] = [[-8]]
+
     assert batch_input.keys() == expected_input.keys()
     for k in batch_input.keys():
+        if k == "position_ids" and batch_input[k].dim() == 3 and batch_input[k].shape[0] == 4:
+            batch_input[k] = batch_input[k][1:]
+
         assert batch_input[k].eq(torch.tensor(expected_input[k])).all()
+
+
+def _make_packed_feature(
+    *,
+    packing_params: dict,
+    pad_token_id: int,
+    label_ignore_id: int,
+    fake_image: Image.Image,
+    vision_start_id: int | None = None,
+    vision_end_id: int | None = None,
+    image_pad_id: int | None = None,
+) -> dict:
+    r"""Build one packed sample using the new PackingParams schema."""
+    sequence_boundaries = packing_params["sequence_boundaries"]
+    image_subseq_ids = packing_params["image_subseq_ids"]
+    video_subseq_ids = packing_params["video_subseq_ids"]
+    audio_subseq_ids = packing_params["audio_subseq_ids"]
+    unpadded_length = packing_params["unpadded_length"]
+    right_padding_length = packing_params["right_padding_length"]  # only preserved in tests
+    cutoff_plus_one = sequence_boundaries[-1]
+    content_len = unpadded_length
+    pad_len = right_padding_length
+    assert content_len + pad_len == cutoff_plus_one
+    assert sequence_boundaries[0] == 0
+    assert sequence_boundaries[-1] == cutoff_plus_one
+    content_ids = list(range(100, 100 + content_len))
+    if vision_start_id is not None and vision_end_id is not None and image_pad_id is not None:
+        image_counts_by_subseq = Counter(image_subseq_ids)
+        for subseq_idx, image_count in sorted(image_counts_by_subseq.items()):
+            if subseq_idx >= len(sequence_boundaries) - 1:
+                continue
+
+            subseq_start = sequence_boundaries[subseq_idx]
+            subseq_end = sequence_boundaries[subseq_idx + 1]
+            subseq_len = subseq_end - subseq_start
+            if subseq_len < 3:
+                continue
+
+            # Build repeated image groups while preserving at least 3 tokens for each remaining image.
+            injected_tokens: list[int] = []
+            remaining = subseq_len
+            for image_idx in range(image_count):
+                remaining_images = image_count - image_idx
+                min_reserved_for_rest = 3 * (remaining_images - 1)
+                current_group_len = min(6, remaining - min_reserved_for_rest)
+                if current_group_len < 3:
+                    break
+
+                group = [vision_start_id] + [image_pad_id] * max(1, current_group_len - 2) + [vision_end_id]
+                injected_tokens.extend(group[:current_group_len])
+                remaining -= current_group_len
+
+            if injected_tokens:
+                insert_end = subseq_start + len(injected_tokens)
+                content_ids[subseq_start:insert_end] = injected_tokens
+
+    input_ids = content_ids + [pad_token_id] * pad_len
+    attention_mask = [1] * content_len + [0] * pad_len
+    labels = [label_ignore_id] * cutoff_plus_one
+    return {
+        "input_ids": input_ids,
+        "attention_mask": attention_mask,
+        "labels": labels,
+        "images": [fake_image] * len(image_subseq_ids),
+        "videos": [None] * len(video_subseq_ids),
+        "audios": [None] * len(audio_subseq_ids),
+        "packing_params": packing_params,
+    }
+
+
+def _make_packed_features(
+    *,
+    packing_params: dict,
+    pad_token_id: int,
+    label_ignore_id: int,
+    fake_image: Image.Image,
+    vision_start_id: int,
+    vision_end_id: int,
+    image_pad_id: int,
+) -> list[dict]:
+    r"""Build packed features from caller-provided packing_params."""
+    return [
+        _make_packed_feature(
+            packing_params=packing_params,
+            pad_token_id=pad_token_id,
+            label_ignore_id=label_ignore_id,
+            fake_image=fake_image,
+            vision_start_id=vision_start_id,
+            vision_end_id=vision_end_id,
+            image_pad_id=image_pad_id,
+        )
+    ]
+
+
+def _get_expected_position_ids(packing_params, get_rope_func, input_ids, attention_mask) -> torch.Tensor:
+    bound_list = packing_params["sequence_boundaries"]
+    input_ids_slices = [input_ids[bound_list[i] : bound_list[i + 1]] for i in range(len(bound_list) - 1)]
+    attention_mask_slices = [attention_mask[bound_list[i] : bound_list[i + 1]] for i in range(len(bound_list) - 1)]
+    img_counts_by_subseq = Counter(packing_params["image_subseq_ids"])
+    all_position_ids = []
+    for i, input_ids_slice in enumerate(input_ids_slices):
+        img_cnt = img_counts_by_subseq[i]
+        if sum(attention_mask_slices[i]) == 0:
+            continue
+
+        rope_func_kwargs = {
+            "input_ids": torch.tensor(input_ids_slice).unsqueeze(0),
+            "attention_mask": torch.tensor(attention_mask_slices[i]).unsqueeze(0),
+            "image_grid_thw": [torch.tensor([1, 4, 4])] * img_cnt,
+        }
+        position_ids, _ = get_rope_func(**rope_func_kwargs)
+        all_position_ids.append(position_ids)
+
+    return torch.cat(all_position_ids, dim=-1)
+
+
+@pytest.mark.runs_on(["cpu", "mps"])
+def test_multimodal_collator_with_packing():
+    model_args, data_args, *_ = get_infer_args(
+        {"model_name_or_path": "Qwen/Qwen2-VL-2B-Instruct", "template": "qwen2_vl"}
+    )
+    tokenizer_module = load_tokenizer(model_args)
+    template = get_template_and_fix_tokenizer(tokenizer_module["tokenizer"], data_args)
+    tokenizer_module["tokenizer"].padding_side = "right"
+    config = AutoConfig.from_pretrained(model_args.model_name_or_path)
+    with torch.device("meta"):
+        model = AutoModelForImageTextToText.from_config(config)
+
+    data_collator = MultiModalDataCollatorForSeq2Seq(
+        template=template,
+        model=model,
+        pad_to_multiple_of=4,
+        label_pad_token_id=IGNORE_INDEX,
+        **tokenizer_module,
+    )
+    tokenizer = tokenizer_module["tokenizer"]
+    packing_params = {
+        "sequence_boundaries": [0, 2, 10, 18, 28, 32],
+        "image_subseq_ids": [1, 2, 3],
+        "video_subseq_ids": [],
+        "audio_subseq_ids": [],
+        "unpadded_length": 28,
+        "right_padding_length": 4,
+    }
+    fake_image = Image.new("RGB", (64, 64), (255, 255, 255))
+    features = _make_packed_features(
+        packing_params=packing_params,
+        pad_token_id=tokenizer.pad_token_id,
+        label_ignore_id=IGNORE_INDEX,
+        fake_image=fake_image,
+        vision_start_id=tokenizer.convert_tokens_to_ids("<|vision_start|>"),
+        vision_end_id=tokenizer.convert_tokens_to_ids("<|vision_end|>"),
+        image_pad_id=tokenizer.convert_tokens_to_ids("<|image_pad|>"),
+    )
+    expected_position_ids = _get_expected_position_ids(
+        packing_params,
+        data_collator.get_rope_func,
+        features[0]["input_ids"],
+        features[0]["attention_mask"],
+    )
+    batch_input = data_collator(features)  # [3, bsz, seq_len]
+    valid_len = expected_position_ids.shape[-1]
+    assert batch_input["position_ids"][1:, :, :valid_len].eq(expected_position_ids).all()
+
+
 @pytest.mark.runs_on(["cpu"])
 def test_4d_attention_mask():
     o = 0.0
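The `_get_expected_position_ids` helper in the test above works by cutting the packed sequence back into its sub-sequences at `sequence_boundaries` and recomputing RoPE per slice. The slicing step can be sketched in isolation (toy token ids; the boundary values are taken from the test's `packing_params`):

```python
sequence_boundaries = [0, 2, 10, 18, 28, 32]  # boundary values from the test above
input_ids = list(range(32))  # toy packed sequence, length == last boundary

# one [start, end) slice per packed sub-sequence
slices = [
    input_ids[sequence_boundaries[i] : sequence_boundaries[i + 1]]
    for i in range(len(sequence_boundaries) - 1)
]
lengths = [len(s) for s in slices]
assert lengths == [2, 8, 8, 10, 4]     # 5 sub-sequences
assert sum(lengths) == len(input_ids)  # boundaries tile the whole pack
```

Because the boundaries tile the pack exactly, concatenating the per-slice position ids along the sequence dimension reproduces positions for the full packed batch, which is what the test compares against the collator output.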

View File

@@ -1,2 +1,2 @@
 # change if test fails or cache is outdated
-0.9.5.106
+0.9.5.107

View File

@@ -0,0 +1,104 @@
+# Copyright 2025 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Unit tests: FSDP2 meta-device loading vs. normal loading consistency.
+
+Validates that the FSDP2 meta loading path behaves correctly for tied weights
+and non-persistent buffers by comparing it with the standard non-meta path.
+"""
+
+import torch
+from transformers import AutoConfig
+
+from llamafactory.v1.accelerator.interface import DistributedInterface
+from llamafactory.v1.config.arg_parser import get_args
+from llamafactory.v1.core.model_engine import ModelEngine
+from llamafactory.v1.plugins.trainer_plugins.distributed.fsdp2 import FSDP2Engine
+
+TINY_MODEL = "llamafactory/tiny-random-qwen3"
+
+
+def collect_non_persistent_buffers(model):
+    """Collect all non-persistent buffers from model."""
+    result = {}
+    for mod_name, module in model.named_modules():
+        for buf_name in getattr(module, "_non_persistent_buffers_set", set()):
+            fqn = f"{mod_name}.{buf_name}" if mod_name else buf_name
+            buf = getattr(module, buf_name, None)
+            if buf is not None:
+                result[fqn] = buf.detach().cpu().clone()
+    return result
+
+
+def test_fsdp2_meta_loading_buffers_and_tied_weights():
+    """Verify non-persistent buffers and tied weights consistency after meta load."""
+    # 1. Initialize DistributedInterface for a single process
+    DistributedInterface()
+
+    # 2. Build the FSDP2Engine config
+    engine = FSDP2Engine(
+        {
+            "name": "fsdp2",
+            "mixed_precision": "bf16",
+            "reshard_after_forward": True,
+            "offload_params": False,
+            "pin_memory": False,
+            "dcp_path": None,
+        }
+    )
+    config = AutoConfig.from_pretrained(TINY_MODEL)
+
+    # --- NORMAL PATH ---
+    normal_args, *_ = get_args(dict(model=TINY_MODEL, init_config=None))
+    normal_engine = ModelEngine(model_args=normal_args)
+    normal_model = normal_engine.model.to(torch.bfloat16)
+    normal_model = engine.shard_model(normal_model)
+    normal_non_persistent = collect_non_persistent_buffers(normal_model)
+    del normal_model
+
+    # --- META PATH ---
+    meta_args, *_ = get_args(dict(model=TINY_MODEL, init_config={"name": "init_on_meta"}))
+    meta_model_engine = ModelEngine(model_args=meta_args)
+    meta_model = meta_model_engine.model
+    assert meta_model.device.type == "meta", "Model should be on meta device"
+
+    # Process meta device: save buffers -> tie_weights -> load from checkpoint -> restore buffers
+    meta_model = engine.shard_model(meta_model)
+    meta_non_persistent = collect_non_persistent_buffers(meta_model)
+
+    # 3. Tied weights (embed_tokens.weight and lm_head.weight)
+    tie_word_embeddings = getattr(config, "tie_word_embeddings", False)
+    if tie_word_embeddings:
+        assert meta_model.lm_head.weight is meta_model.model.embed_tokens.weight, (
+            "Weights should be tied after loading"
+        )
+    del meta_model
+
+    # 4. Non-persistent buffers (e.g., inv_freq)
+    normal_buf_keys = set(normal_non_persistent.keys())
+    meta_buf_keys = set(meta_non_persistent.keys())
+    assert normal_buf_keys == meta_buf_keys, "Non-persistent buffer keys mismatch"
+    for key in sorted(normal_buf_keys & meta_buf_keys):
+        nb = normal_non_persistent[key]
+        mb = meta_non_persistent[key]
+        assert nb.shape == mb.shape, f"Buffer shape mismatch: {key}"
+        assert torch.allclose(nb.float(), mb.float(), atol=1e-5), f"Buffer value mismatch: {key}"