11 Commits

Author SHA1 Message Date
Kingsley
833f6027b1 [fix] fit neat_packing & mrope model packing (#10283)
Co-authored-by: Yaowei Zheng <hiyouga@buaa.edu.cn>
2026-03-20 16:50:11 +08:00
robertglools
d91d8af89e [data] add SGSC zero-hallucination B2B dataset (NOO-Protocol) (#10284)
Co-authored-by: GloolsGuan <GloolsGuan@gmail.com>
2026-03-20 15:49:03 +08:00
xxddccaa
e67ab9e2f2 fix:MiniCPMVPlugin IndexError in process_messages when training with video (#10276)
Co-authored-by: xxddccaa <xxddccaa@users.noreply.github.com>
2026-03-18 19:18:06 +08:00
LincolnBurrows2017
2c4f121817 [fix] handle empty content list in system message (#10291)
Co-authored-by: AI Assistant <assistant@example.com>
2026-03-18 12:05:49 +08:00
xvxuopop
487f8b8191 [v1] add qwen3 templates and fix rendering plugin. (#10212)
Co-authored-by: Yaowei Zheng <hiyouga@buaa.edu.cn>
2026-03-18 11:30:50 +08:00
SnowCharm
78cad1e332 [fix] unused keys in ray example (#10290) 2026-03-18 00:23:53 +08:00
LincolnBurrows2017
70653026f5 [fix] make position_id_per_seconds configurable for Qwen2OmniPlugin (#10281)
Co-authored-by: LincolnBurrows2017 <lincoln@example.com>
2026-03-16 19:42:38 +08:00
Ruijie Hou
246192abd2 [data] correct gpt_oss template format_assistant (#10269) 2026-03-10 21:36:38 +08:00
浮梦
0258dc14d0 [docker] update npu docker (#10268)
Co-authored-by: frozenleaves <frozen@Mac.local>
2026-03-10 19:37:27 +08:00
xxddccaa
3045adf0ba [fix] fallback to audio_processor when feature_extractor is missing (#10267)
Co-authored-by: kevin <742971636@qq.com>
2026-03-10 19:36:41 +08:00
Kingsley
a3d44e3152 [mca] support qwen3.5 (#10265)
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
2026-03-10 10:55:16 +08:00
31 changed files with 1096 additions and 329 deletions

View File

@@ -35,15 +35,12 @@ jobs:
transformers: transformers:
- "" - ""
include: # test backward compatibility include: # test backward compatibility
- python: "3.11"
os: "ubuntu-latest"
transformers: "4.51.0"
- python: "3.11"
os: "ubuntu-latest"
transformers: "4.53.0"
- python: "3.11" - python: "3.11"
os: "ubuntu-latest" os: "ubuntu-latest"
transformers: "4.55.0" transformers: "4.55.0"
- python: "3.11"
os: "ubuntu-latest"
transformers: "4.57.1"
runs-on: ${{ matrix.os }} runs-on: ${{ matrix.os }}

View File

@@ -236,6 +236,13 @@
"ms_hub_url": "AI-ModelScope/sharegpt_gpt4", "ms_hub_url": "AI-ModelScope/sharegpt_gpt4",
"formatting": "sharegpt" "formatting": "sharegpt"
}, },
"sgsc_b2b_entities": {
"hf_hub_url": "Nooxus-AI/NOO-Verified-Global-Entities",
"formatting": "sharegpt",
"columns": {
"messages": "messages"
}
},
"ultrachat_200k": { "ultrachat_200k": {
"hf_hub_url": "HuggingFaceH4/ultrachat_200k", "hf_hub_url": "HuggingFaceH4/ultrachat_200k",
"ms_hub_url": "AI-ModelScope/ultrachat_200k", "ms_hub_url": "AI-ModelScope/ultrachat_200k",

View File

@@ -1,6 +1,6 @@
# https://hub.docker.com/r/ascendai/cann/tags # https://hub.docker.com/r/ascendai/cann/tags
ARG BASE_IMAGE=quay.io/ascend/cann:8.3.rc2-910b-ubuntu22.04-py3.11 ARG BASE_IMAGE=quay.io/ascend/cann:8.5.1-910b-ubuntu22.04-py3.11
FROM ${BASE_IMAGE} FROM ${BASE_IMAGE}
# Installation arguments # Installation arguments
@@ -33,9 +33,11 @@ RUN pip config set global.index-url "${PIP_INDEX}" && \
COPY . /app COPY . /app
# Install torch-npu # Install torch-npu
RUN pip uninstall -y torch torchvision torchaudio && \ RUN source /usr/local/Ascend/ascend-toolkit/set_env.sh
pip install --no-cache-dir "torch==2.7.1" "torch-npu==2.7.1" "torchvision==0.22.1" "torchaudio==2.7.1" --index-url "${PYTORCH_INDEX}" && \ RUN pip uninstall -y torch torchvision torchaudio
pip install --no-cache-dir -e . --no-build-isolation && \ RUN pip install --no-cache-dir -r requirements/npu.txt --index-url "${PYTORCH_INDEX}"
RUN pip install --no-cache-dir -r requirements/deepspeed.txt
RUN pip install --no-cache-dir -e . --no-build-isolation && \
pip install --no-cache-dir -r requirements/metrics.txt --no-build-isolation pip install --no-cache-dir -r requirements/metrics.txt --no-build-isolation
# Set up volumes # Set up volumes

View File

@@ -33,7 +33,7 @@ services:
dockerfile: ./docker/docker-npu/Dockerfile dockerfile: ./docker/docker-npu/Dockerfile
context: ../.. context: ../..
args: args:
BASE_IMAGE: quay.io/ascend/cann:8.3.rc2-a3-ubuntu22.04-py3.11 BASE_IMAGE: quay.io/ascend/cann:8.5.1-a3-ubuntu22.04-py3.11
PIP_INDEX: https://pypi.org/simple PIP_INDEX: https://pypi.org/simple
container_name: llamafactory-a3 container_name: llamafactory-a3
image: llamafactory:npu-a3 image: llamafactory:npu-a3

View File

@@ -28,12 +28,7 @@ save_only_model: false
report_to: none # choices: [none, wandb, tensorboard, swanlab, mlflow] report_to: none # choices: [none, wandb, tensorboard, swanlab, mlflow]
### ray ### ray
ray_run_name: qwen3_4b_sft_lora
ray_storage_path: ./saves
ray_num_workers: 4 # Number of GPUs to use. ray_num_workers: 4 # Number of GPUs to use.
placement_strategy: PACK
resources_per_worker:
GPU: 1
# ray_init_kwargs: # ray_init_kwargs:
# runtime_env: # runtime_env:
# env_vars: # env_vars:

View File

@@ -40,7 +40,7 @@ dependencies = [
"torch>=2.4.0", "torch>=2.4.0",
"torchvision>=0.19.0", "torchvision>=0.19.0",
"torchaudio>=2.4.0", "torchaudio>=2.4.0",
"transformers>=4.51.0,<=5.2.0,!=4.52.0,!=4.57.0", "transformers>=4.55.0,<=5.2.0,!=4.52.0,!=4.57.0",
"datasets>=2.16.0,<=4.0.0", "datasets>=2.16.0,<=4.0.0",
"accelerate>=1.3.0,<=1.11.0", "accelerate>=1.3.0,<=1.11.0",
"peft>=0.18.0,<=0.18.1", "peft>=0.18.0,<=0.18.1",

View File

@@ -1,4 +1,4 @@
torch==2.7.1 torch==2.7.1
torch-npu==2.7.1 torch-npu==2.7.1.post2
torchvision==0.22.1 torchvision==0.22.1
torchaudio==2.7.1 torchaudio==2.7.1

View File

@@ -71,6 +71,7 @@ def convert(
pipeline_model_parallel_size: int = 1, pipeline_model_parallel_size: int = 1,
expert_model_parallel_size: int = 1, expert_model_parallel_size: int = 1,
virtual_pipeline_model_parallel_size: int | None = None, virtual_pipeline_model_parallel_size: int | None = None,
moe_grouped_gemm: bool | None = None,
): ):
"""Convert checkpoint between MCA and HuggingFace formats. """Convert checkpoint between MCA and HuggingFace formats.
@@ -84,6 +85,10 @@ def convert(
pipeline_model_parallel_size: Pipeline model parallel size pipeline_model_parallel_size: Pipeline model parallel size
expert_model_parallel_size: Expert model parallel size expert_model_parallel_size: Expert model parallel size
virtual_pipeline_model_parallel_size: Virtual pipeline model parallel size virtual_pipeline_model_parallel_size: Virtual pipeline model parallel size
moe_grouped_gemm: Use grouped gemm for MoE experts. When enabled, expert
weights are stored in a flattened format (linear_fc1.weight0, weight1, ...)
rather than per-expert format (local_experts.0.linear_fc1.weight, ...).
Must match the format used when saving the checkpoint.
""" """
if bf16 and fp16: if bf16 and fp16:
raise ValueError("bf16 and fp16 cannot be both True.") raise ValueError("bf16 and fp16 cannot be both True.")
@@ -97,8 +102,9 @@ def convert(
pipeline_model_parallel_size=pipeline_model_parallel_size, pipeline_model_parallel_size=pipeline_model_parallel_size,
expert_model_parallel_size=expert_model_parallel_size, expert_model_parallel_size=expert_model_parallel_size,
virtual_pipeline_model_parallel_size=virtual_pipeline_model_parallel_size, virtual_pipeline_model_parallel_size=virtual_pipeline_model_parallel_size,
moe_grouped_gemm=moe_grouped_gemm,
transformer_impl="transformer_engine", # hard code here since we default using te for training
) )
convert_checkpoint_to_mca( convert_checkpoint_to_mca(
checkpoint_path, checkpoint_path,
output_path, output_path,

View File

@@ -88,7 +88,10 @@ def _process_request(
if request.messages[0].role == Role.SYSTEM: if request.messages[0].role == Role.SYSTEM:
content = request.messages.pop(0).content content = request.messages.pop(0).content
system = content[0].text if isinstance(content, list) else content if isinstance(content, list):
system = content[0].text if content else ""
else:
system = content
else: else:
system = None system = None

View File

@@ -15,6 +15,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import copy
import inspect import inspect
from dataclasses import dataclass from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Literal, Optional from typing import TYPE_CHECKING, Any, Literal, Optional
@@ -25,7 +26,7 @@ import torch.nn.functional as F
from peft import PeftModel from peft import PeftModel
from transformers import DataCollatorForSeq2Seq from transformers import DataCollatorForSeq2Seq
from ..extras.constants import AUDIO_PLACEHOLDER, IGNORE_INDEX, IMAGE_PLACEHOLDER from ..extras.constants import AUDIO_PLACEHOLDER, IGNORE_INDEX, IMAGE_PLACEHOLDER, MROPE_MODELS
from ..extras.packages import is_pillow_available from ..extras.packages import is_pillow_available
@@ -39,6 +40,56 @@ if TYPE_CHECKING:
from .template import Template from .template import Template
def _slice_mm_inputs_for_sample(
mm_inputs: dict[str, Any],
batch_imglens: list[int],
batch_vidlens: list[int],
batch_idx: int,
images_per_subseq: Optional[list[int]] = None,
videos_per_subseq: Optional[list[int]] = None,
subseq_idx: Optional[int] = None,
) -> dict[str, Any]:
r"""Slice mm_inputs for one batch sample, optionally for a single sub-sequence when packing.
image_grid_thw / video_grid_thw have shape [num_items, 3]. Indices for sample batch_idx
are batch_imglens[batch_idx] images and batch_vidlens[batch_idx] videos. When subseq_idx
is given, further restrict to that sub-seq's counts via packed_*_counts.
has_dummy_image=True means only batch[0] will be concated with fake image and no multimodal data.
"""
image_start_idx = sum(batch_imglens[:batch_idx])
image_end_idx = sum(batch_imglens[: batch_idx + 1])
video_start_idx = sum(batch_vidlens[:batch_idx])
video_end_idx = sum(batch_vidlens[: batch_idx + 1])
if subseq_idx is not None and images_per_subseq is not None:
image_start_idx += sum(images_per_subseq[:subseq_idx])
image_end_idx = image_start_idx + images_per_subseq[subseq_idx]
if subseq_idx is not None and videos_per_subseq is not None:
video_start_idx += sum(videos_per_subseq[:subseq_idx])
video_end_idx = video_start_idx + videos_per_subseq[subseq_idx]
sliced_mm_inputs: dict[str, Any] = {}
key_to_slice_meta = {
"image_grid_thw": (image_start_idx, image_end_idx, True),
"video_grid_thw": (video_start_idx, video_end_idx, True),
"second_per_grid_ts": (video_start_idx, video_end_idx, False), # qwen2.5vl
"video_second_per_grid": (video_start_idx, video_end_idx, False), # qwen omni
}
for key, (start_idx, end_idx, assign_none_when_empty) in key_to_slice_meta.items():
if key not in mm_inputs:
continue
mm_value = mm_inputs[key]
if mm_value is not None and end_idx > start_idx:
sliced_mm_inputs[key] = mm_value[start_idx:end_idx]
elif assign_none_when_empty:
sliced_mm_inputs[key] = None
return sliced_mm_inputs
def prepare_4d_attention_mask(attention_mask_with_indices: "torch.Tensor", dtype: "torch.dtype") -> "torch.Tensor": def prepare_4d_attention_mask(attention_mask_with_indices: "torch.Tensor", dtype: "torch.dtype") -> "torch.Tensor":
r"""Expand 2d attention mask to 4d attention mask. r"""Expand 2d attention mask to 4d attention mask.
@@ -106,9 +157,154 @@ class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
else: else:
self.get_rope_func = None self.get_rope_func = None
def _compute_rope_position_ids(
self, features: dict[str, "torch.Tensor"], mm_inputs: dict[str, Any]
) -> None:
r"""Compute position_ids and rope_deltas via get_rope_func for VLMs."""
rope_index_kwargs = {
"input_ids": features["input_ids"],
"image_grid_thw": mm_inputs.get("image_grid_thw"),
"video_grid_thw": mm_inputs.get("video_grid_thw"),
"attention_mask": (features["attention_mask"] >= 1).float(),
}
if features["attention_mask"].sum() == 0:
features["position_ids"] = torch.zeros((3, *features["input_ids"].shape))
features["rope_deltas"] = torch.zeros(features["input_ids"].shape[0])
return
if "mm_token_type_ids" in inspect.signature(self.get_rope_func).parameters:
image_token_id = getattr(self.model.config, "image_token_id", None)
video_token_id = getattr(self.model.config, "video_token_id", None)
if image_token_id is not None or video_token_id is not None:
mm_token_type_ids = torch.zeros_like(features["input_ids"])
if image_token_id is not None:
mm_token_type_ids[features["input_ids"] == image_token_id] = 1
if video_token_id is not None:
mm_token_type_ids[features["input_ids"] == video_token_id] = 2
rope_index_kwargs["mm_token_type_ids"] = mm_token_type_ids
if "second_per_grid_ts" in mm_inputs: # for qwen2vl
rope_index_kwargs["second_per_grid_ts"] = mm_inputs.get("second_per_grid_ts")
elif "video_second_per_grid" in mm_inputs: # for qwen2.5 omni
rope_index_kwargs["second_per_grids"] = mm_inputs.get("video_second_per_grid")
if getattr(self.model.config, "model_type", None) in ["qwen2_5_omni_thinker", "qwen3_omni_moe_thinker"]:
rope_index_kwargs["use_audio_in_video"] = getattr(self.processor, "use_audio_in_video", False)
feature_attention_mask = mm_inputs.get("feature_attention_mask", None)
if feature_attention_mask is not None: # FIXME: need to get video image lengths
audio_feature_lengths = torch.sum(feature_attention_mask, dim=1)
rope_index_kwargs["audio_seqlens"] = audio_feature_lengths # prepare for input
features["position_ids"], rope_deltas = self.get_rope_func(**rope_index_kwargs)
features["rope_deltas"] = rope_deltas - (1 - rope_index_kwargs["attention_mask"]).sum(
dim=-1
).unsqueeze(-1)
else: # for qwen vl
features["position_ids"], features["rope_deltas"] = self.get_rope_func(**rope_index_kwargs)
def _compute_rope_position_ids_with_packing(
self,
features: dict[str, "torch.Tensor"],
mm_inputs: dict[str, Any],
packing_params_list: list[dict[str, Any] | None],
batch_imglens: list[int],
batch_vidlens: list[int],
batch_audlens: list[int],
has_dummy_image: bool,
) -> None:
r"""Compute position_ids and rope_deltas per sample (or per sub-sequence when packed), then merge and validate."""
bsz = features["input_ids"].size(0)
seq_len = features["input_ids"].size(1)
all_position_ids: list[torch.Tensor] = []
all_rope_deltas: list[torch.Tensor] = []
if has_dummy_image:
# for [0, seq_len] = [0, unpadded_length + right_padding_length + fake_input_ids_len + collator_padding_length]
# FIXME: maybe right_padding_length is large, with improper max_cutoff_len
unpadded_length = int(features["attention_mask"][0].bool().sum().item())
right_padding_length = int((packing_params_list[0] or {}).get("right_padding_length") or 0)
fake_input_padding_length = max(0, seq_len - unpadded_length - right_padding_length)
dummy_image_right_padding_mrope = torch.zeros((3, bsz, fake_input_padding_length))
dummy_image_right_padding_attention_mask = torch.zeros((bsz, fake_input_padding_length))
assert self.tokenizer.padding_side == "right", "padding_side should be right when fake image is injected"
dummy_mm_inputs = copy.deepcopy(mm_inputs)
for sample_idx in range(bsz):
sample_packing = (packing_params_list[sample_idx] or {}) if sample_idx < len(packing_params_list) else {}
sequence_boundaries = sample_packing.get("sequence_boundaries")
num_sub_seqs = (len(sequence_boundaries) - 1) if sequence_boundaries and len(sequence_boundaries) > 1 else 1
image_subseq_ids = sample_packing.get("image_subseq_ids") or []
video_subseq_ids = sample_packing.get("video_subseq_ids") or []
images_per_subseq = (
[image_subseq_ids.count(i) for i in range(num_sub_seqs)] if image_subseq_ids and num_sub_seqs > 1 else None
)
videos_per_subseq = (
[video_subseq_ids.count(i) for i in range(num_sub_seqs)] if video_subseq_ids and num_sub_seqs > 1 else None
)
if has_dummy_image:
mm_inputs = {}
if num_sub_seqs <= 1:
sample_features = {
"input_ids": features["input_ids"],
"attention_mask": features["attention_mask"][sample_idx : sample_idx + 1],
}
mm_inputs_for_sample = _slice_mm_inputs_for_sample(
mm_inputs, batch_imglens, batch_vidlens, sample_idx=sample_idx
)
self._compute_rope_position_ids(sample_features, mm_inputs_for_sample)
all_position_ids.append(sample_features["position_ids"])
all_rope_deltas.append(sample_features["rope_deltas"])
else:
# when we do packing, don't need rope_deltas when training.
sample_position_ids: list[torch.Tensor] = []
for subseq_idx in range(num_sub_seqs):
subseq_start = sequence_boundaries[subseq_idx]
subseq_end = sequence_boundaries[subseq_idx + 1]
subseq_features = {
"input_ids": features["input_ids"][sample_idx : sample_idx + 1, subseq_start:subseq_end],
"attention_mask": features["attention_mask"][sample_idx : sample_idx + 1, subseq_start:subseq_end],
}
mm_inputs_for_subseq = _slice_mm_inputs_for_sample(
mm_inputs,
batch_imglens,
batch_vidlens,
sample_idx,
images_per_subseq,
videos_per_subseq,
subseq_idx
)
self._compute_rope_position_ids(subseq_features, mm_inputs_for_subseq)
sample_position_ids.append(subseq_features["position_ids"])
all_position_ids.append(torch.cat(sample_position_ids, dim=-1))
batch_dim_for_position_ids = 1 if all_position_ids[0].dim() == 3 else 0
features["position_ids"] = torch.cat(all_position_ids, dim=batch_dim_for_position_ids)
if has_dummy_image:
mm_inputs = dummy_mm_inputs
expected_position_ids_shape = (bsz, seq_len) if all_position_ids[0].dim() == 2 else (
all_position_ids[0].size(0),
bsz,
seq_len,
)
# Check if position_ids shape matches expected shape.
# for further usage, we should padding to the right when some padding token on the right.
if has_dummy_image:
features["position_ids"] = torch.cat([features["position_ids"], dummy_image_right_padding_mrope], dim=-1)
features["attention_mask"] = torch.cat([features["attention_mask"], dummy_image_right_padding_attention_mask], dim=-1)
if features["position_ids"].shape != expected_position_ids_shape:
raise ValueError(
"Merged position_ids shape mismatch: "
f"got {features['position_ids'].shape}, expected {expected_position_ids_shape}."
)
def __call__(self, features: list[dict[str, Any]]) -> dict[str, "torch.Tensor"]: def __call__(self, features: list[dict[str, Any]]) -> dict[str, "torch.Tensor"]:
batch_images, batch_videos, batch_audios = [], [], [] batch_images, batch_videos, batch_audios = [], [], []
batch_imglens, batch_vidlens, batch_audlens, batch_input_ids = [], [], [], [] batch_imglens, batch_vidlens, batch_audlens, batch_input_ids = [], [], [], []
packing_params_list: list[dict[str, Any] | None] = []
for feature in features: for feature in features:
images = feature.pop("images", None) or [] images = feature.pop("images", None) or []
videos = feature.pop("videos", None) or [] videos = feature.pop("videos", None) or []
@@ -120,8 +316,10 @@ class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
batch_vidlens.append(len(videos)) batch_vidlens.append(len(videos))
batch_audlens.append(len(audios)) batch_audlens.append(len(audios))
batch_input_ids.append(feature["input_ids"]) batch_input_ids.append(feature["input_ids"])
packing_params_list.append(feature.pop("packing_params", None))
fake_input_ids = [] fake_input_ids = []
has_dummy_image = False
if ( if (
self.template.mm_plugin.image_token is not None and sum(batch_imglens) == 0 and sum(batch_vidlens) == 0 self.template.mm_plugin.image_token is not None and sum(batch_imglens) == 0 and sum(batch_vidlens) == 0
): # avoid process hanging in zero3/fsdp case ): # avoid process hanging in zero3/fsdp case
@@ -137,6 +335,7 @@ class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
fake_input_ids.extend(_fake_input_ids) fake_input_ids.extend(_fake_input_ids)
batch_images = fake_images batch_images = fake_images
batch_imglens[0] = 1 batch_imglens[0] = 1
has_dummy_image = True
if ( if (
self.template.mm_plugin.audio_token is not None and sum(batch_audlens) == 0 self.template.mm_plugin.audio_token is not None and sum(batch_audlens) == 0
@@ -183,57 +382,50 @@ class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
features: dict[str, torch.Tensor] = super().__call__(features) features: dict[str, torch.Tensor] = super().__call__(features)
bsz, seq_len = features["input_ids"].shape[:2]
model_type = getattr(self.model.config, "model_type", None) if self.model is not None else None
is_omni = model_type in [
"qwen2_5_omni_thinker",
"qwen3_omni_moe_thinker",
]
if self.get_rope_func is not None: if self.get_rope_func is not None:
rope_index_kwargs = { # for mmrope situation, we should calculate position_ids and rope_deltas per sample.
"input_ids": features["input_ids"], # When neat_packing is on, each sample has packing_params; None means no packing for that sample.
"image_grid_thw": mm_inputs.get("image_grid_thw"), boundaries_list = [
"video_grid_thw": mm_inputs.get("video_grid_thw"), p.get("sequence_boundaries") if p is not None else None for p in packing_params_list
"attention_mask": (features["attention_mask"] >= 1).float(), ]
} has_packing = any(b is not None and len(b) > 2 for b in boundaries_list)
if "mm_token_type_ids" in inspect.signature(self.get_rope_func).parameters: if has_dummy_image and has_packing:
image_token_id = getattr(self.model.config, "image_token_id", None) # FIXME: too tricky, need to be refactored
video_token_id = getattr(self.model.config, "video_token_id", None) features["has_dummy_image"] = True
if image_token_id is not None or video_token_id is not None:
mm_token_type_ids = torch.zeros_like(features["input_ids"])
if image_token_id is not None:
mm_token_type_ids[features["input_ids"] == image_token_id] = 1
if video_token_id is not None:
mm_token_type_ids[features["input_ids"] == video_token_id] = 2
rope_index_kwargs["mm_token_type_ids"] = mm_token_type_ids
if "second_per_grid_ts" in mm_inputs: # for qwen2vl
rope_index_kwargs["second_per_grid_ts"] = mm_inputs.get("second_per_grid_ts")
elif "video_second_per_grid" in mm_inputs: # for qwen2.5 omni
rope_index_kwargs["second_per_grids"] = mm_inputs.get("video_second_per_grid")
if getattr(self.model.config, "model_type", None) in ["qwen2_5_omni_thinker", "qwen3_omni_moe_thinker"]: # When fake image/audio was injected, sequence_boundaries no longer match the tensor; use non-packing path.
rope_index_kwargs["use_audio_in_video"] = getattr(self.processor, "use_audio_in_video", False) if not has_packing:
feature_attention_mask = mm_inputs.get("feature_attention_mask", None) self._compute_rope_position_ids(features, mm_inputs)
if feature_attention_mask is not None: # FIXME: need to get video image lengths else:
audio_feature_lengths = torch.sum(feature_attention_mask, dim=1) if is_omni:
rope_index_kwargs["audio_seqlens"] = audio_feature_lengths # prepare for input raise RuntimeError("Omni models are not supported for packed sequences for now.")
features["position_ids"], rope_deltas = self.get_rope_func(**rope_index_kwargs) self._compute_rope_position_ids_with_packing(
features["rope_deltas"] = rope_deltas - (1 - rope_index_kwargs["attention_mask"]).sum( features,
dim=-1 mm_inputs,
).unsqueeze(-1) packing_params_list,
else: # for qwen vl batch_imglens,
features["position_ids"], features["rope_deltas"] = self.get_rope_func(**rope_index_kwargs) batch_vidlens,
batch_audlens,
has_dummy_image,
)
# For transformers compatibility, after https://github.com/huggingface/transformers/issues/39400
if features["position_ids"].dim() == 3:
features["position_ids"] = torch.cat(
[features["position_ids"][0].unsqueeze(0), features["position_ids"]], dim=0
)
if ( if (
self.model is not None self.model is not None
and getattr(self.model.config, "model_type", None) and getattr(self.model.config, "model_type", None) in MROPE_MODELS
in [
"glm4v",
"glm_ocr",
"Keye",
"qwen2_vl",
"qwen2_5_vl",
"qwen2_5_omni_thinker",
"qwen3_omni_moe_thinker",
"qwen3_5",
"qwen3_vl",
"qwen3_vl_moe",
]
and ("position_ids" not in features or features["position_ids"].dim() != 3) and ("position_ids" not in features or features["position_ids"].dim() != 3)
): ):
raise ValueError(f"{self.model.config.model_type} requires 3D position ids for mrope.") raise ValueError(f"{self.model.config.model_type} requires 3D position ids for mrope.")
@@ -261,12 +453,51 @@ class SFTDataCollatorWith4DAttentionMask(MultiModalDataCollatorForSeq2Seq):
block_diag_attn: bool = False block_diag_attn: bool = False
attn_implementation: Literal["eager", "sdpa", "flash_attention_2"] = "eager" attn_implementation: Literal["eager", "sdpa", "flash_attention_2"] = "eager"
compute_dtype: "torch.dtype" = torch.float32 compute_dtype: "torch.dtype" = torch.float32
neat_packing: bool = False
def __post_init__(self):
super().__post_init__()
if self.neat_packing and self.attn_implementation == "flash_attention_2":
if self.model is not None and getattr(self.model.config, "model_type", None) in ["qwen3_5", "qwen3_5_moe", "gpt_oss"]:
raise ValueError("Neat packing is not supported for qwen3_5, qwen3_5_moe, gpt_oss models for now.")
@staticmethod
def _unpad_packed_features(features: dict[str, Any]) -> None:
r"""Trim padded positions for packed FA2 batches."""
attention_mask = features.get("attention_mask")
if not torch.is_tensor(attention_mask) or attention_mask.dim() != 2 or attention_mask.size(0) != 1:
return
seq_len = attention_mask.size(1)
non_padding_indices = torch.nonzero(attention_mask[0] != 0, as_tuple=False).flatten()
if non_padding_indices.numel() == seq_len:
return
keys_on_seq_dim_1 = {"input_ids", "labels", "attention_mask", "token_type_ids"}
for key, value in list(features.items()):
if not torch.is_tensor(value):
continue
if key == "position_ids" and value.size(-1) == seq_len:
features[key] = value.index_select(-1, non_padding_indices)
elif key == "cross_attention_mask" and value.dim() >= 2 and value.size(0) == 1 and value.size(1) == seq_len:
features[key] = value.index_select(1, non_padding_indices)
elif key in keys_on_seq_dim_1 and value.dim() == 2 and value.size(0) == 1 and value.size(1) == seq_len:
features[key] = value.index_select(1, non_padding_indices)
def __call__(self, features: list[dict[str, Any]]) -> dict[str, "torch.Tensor"]: def __call__(self, features: list[dict[str, Any]]) -> dict[str, "torch.Tensor"]:
features = super().__call__(features) features = super().__call__(features)
has_dummy_image = features.pop("has_dummy_image", False)
if self.block_diag_attn and self.attn_implementation != "flash_attention_2": if self.block_diag_attn and self.attn_implementation != "flash_attention_2":
features["attention_mask"] = prepare_4d_attention_mask(features["attention_mask"], self.compute_dtype) features["attention_mask"] = prepare_4d_attention_mask(features["attention_mask"], self.compute_dtype)
if self.neat_packing and self.attn_implementation == "flash_attention_2": # FIXME compatibility fa3/fa4
assert features["input_ids"].shape[0] == 1, "bsz should be 1 for neat packing"
if not has_dummy_image:
self._unpad_packed_features(features)
features["attention_mask"] = None # let transformers handle causal packed mask.
for key, value in features.items(): # cast data dtype for paligemma for key, value in features.items(): # cast data dtype for paligemma
if torch.is_tensor(value) and torch.is_floating_point(value): if torch.is_tensor(value) and torch.is_floating_point(value):
features[key] = value.to(self.compute_dtype) features[key] = value.to(self.compute_dtype)

View File

@@ -27,11 +27,12 @@ from typing import TYPE_CHECKING, BinaryIO, Literal, NotRequired, Optional, Type
import numpy as np import numpy as np
import torch import torch
import torchaudio import torchaudio
from transformers.image_utils import get_image_size, is_valid_image, to_numpy_array from transformers.image_utils import get_image_size, is_valid_image, make_flat_list_of_images, to_numpy_array
from transformers.models.mllama.processing_mllama import ( from transformers.models.mllama.processing_mllama import (
convert_sparse_cross_attention_mask_to_dense, convert_sparse_cross_attention_mask_to_dense,
get_cross_attention_token_mask, get_cross_attention_token_mask,
) )
from transformers.video_utils import make_batched_videos
from typing_extensions import override from typing_extensions import override
from ..extras.constants import AUDIO_PLACEHOLDER, IGNORE_INDEX, IMAGE_PLACEHOLDER, VIDEO_PLACEHOLDER from ..extras.constants import AUDIO_PLACEHOLDER, IGNORE_INDEX, IMAGE_PLACEHOLDER, VIDEO_PLACEHOLDER
@@ -47,13 +48,6 @@ if is_pyav_available():
import av import av
if is_transformers_version_greater_than("4.52.0"):
from transformers.image_utils import make_flat_list_of_images
from transformers.video_utils import make_batched_videos
else:
from transformers.image_utils import make_batched_videos, make_flat_list_of_images
if TYPE_CHECKING: if TYPE_CHECKING:
from av.stream import Stream from av.stream import Stream
from numpy.typing import NDArray from numpy.typing import NDArray
@@ -161,7 +155,9 @@ class MMPluginMixin:
video_processor: BaseImageProcessor = getattr( video_processor: BaseImageProcessor = getattr(
processor, "video_processor", getattr(processor, "image_processor", None) processor, "video_processor", getattr(processor, "image_processor", None)
) )
feature_extractor: SequenceFeatureExtractor = getattr(processor, "feature_extractor", None) feature_extractor: SequenceFeatureExtractor = getattr(processor, "feature_extractor", None) or getattr(
processor, "audio_processor", None
)
if len(images) != 0 and self.image_token is None: if len(images) != 0 and self.image_token is None:
raise ValueError( raise ValueError(
"This model does not support image input. Please check whether the correct `template` is used." "This model does not support image input. Please check whether the correct `template` is used."
@@ -390,7 +386,9 @@ class MMPluginMixin:
mm_inputs.update(video_processor(videos, return_tensors="pt")) mm_inputs.update(video_processor(videos, return_tensors="pt"))
if len(audios) != 0: if len(audios) != 0:
feature_extractor: SequenceFeatureExtractor = getattr(processor, "feature_extractor", None) feature_extractor: SequenceFeatureExtractor = getattr(processor, "feature_extractor", None) or getattr(
processor, "audio_processor", None
)
audios = self._regularize_audios( audios = self._regularize_audios(
audios, audios,
sampling_rate=getattr(processor, "audio_sampling_rate", 16000), sampling_rate=getattr(processor, "audio_sampling_rate", 16000),
@@ -1054,7 +1052,9 @@ class MiniCPMVPlugin(BasePlugin):
chunk_input=True, chunk_input=True,
sampling_rate=getattr(processor, "audio_sampling_rate", 16000), sampling_rate=getattr(processor, "audio_sampling_rate", 16000),
) )
audio_feature_lens = [torch.tensor(audio_feature_len) for audio_feature_len in audio_feature_lens] audio_feature_lens = [
x.clone().detach() if isinstance(x, torch.Tensor) else torch.tensor(x) for x in audio_feature_lens
]
mm_inputs.update({"audio_features": audio_features, "audio_feature_lens": audio_feature_lens}) mm_inputs.update({"audio_features": audio_features, "audio_feature_lens": audio_feature_lens})
if kwargs.get("ret_phs", False): if kwargs.get("ret_phs", False):
mm_inputs.update({"audio_phs": audio_phs}) mm_inputs.update({"audio_phs": audio_phs})
@@ -1094,7 +1094,7 @@ class MiniCPMVPlugin(BasePlugin):
num_image_tokens += 1 num_image_tokens += 1
while VIDEO_PLACEHOLDER in content: while VIDEO_PLACEHOLDER in content:
video_seqlen = len(mm_inputs["pixel_values"][num_video_tokens]) if self.expand_mm_tokens else 1 video_seqlen = len(mm_inputs["image_sizes"][num_video_tokens]) if self.expand_mm_tokens else 1
content = content.replace(VIDEO_PLACEHOLDER, "{{image}}" * video_seqlen, 1) content = content.replace(VIDEO_PLACEHOLDER, "{{image}}" * video_seqlen, 1)
num_video_tokens += 1 num_video_tokens += 1
@@ -1876,7 +1876,9 @@ class Qwen2OmniPlugin(Qwen2VLPlugin):
) -> dict[str, "torch.Tensor"]: ) -> dict[str, "torch.Tensor"]:
image_processor: BaseImageProcessor = getattr(processor, "image_processor", None) image_processor: BaseImageProcessor = getattr(processor, "image_processor", None)
video_processor: BaseVideoProcessor = getattr(processor, "video_processor", None) video_processor: BaseVideoProcessor = getattr(processor, "video_processor", None)
feature_extractor: SequenceFeatureExtractor = getattr(processor, "feature_extractor", None) feature_extractor: SequenceFeatureExtractor = getattr(processor, "feature_extractor", None) or getattr(
processor, "audio_processor", None
)
mm_inputs = {} mm_inputs = {}
if len(images) != 0: if len(images) != 0:
images = self._regularize_images( images = self._regularize_images(
@@ -1981,6 +1983,7 @@ class Qwen2OmniPlugin(Qwen2VLPlugin):
f"Each {VIDEO_PLACEHOLDER} must be followed by an {AUDIO_PLACEHOLDER} when using audio in video." f"Each {VIDEO_PLACEHOLDER} must be followed by an {AUDIO_PLACEHOLDER} when using audio in video."
) )
position_id_per_seconds: int = getattr(processor, "position_id_per_seconds", 25)
audio_t_index = torch.arange(audio_lengths[num_audio_tokens]) audio_t_index = torch.arange(audio_lengths[num_audio_tokens])
video_t_index = ( video_t_index = (
torch.arange(video_grid_thw[num_video_tokens][0]) torch.arange(video_grid_thw[num_video_tokens][0])
@@ -1992,9 +1995,9 @@ class Qwen2OmniPlugin(Qwen2VLPlugin):
) )
.flatten() .flatten()
* mm_inputs["video_second_per_grid"][num_video_tokens] * mm_inputs["video_second_per_grid"][num_video_tokens]
* 25 # FIXME hardcode of position_id_per_seconds=25 * position_id_per_seconds
).long() ).long()
t_ntoken_per_chunk = 50 # FIXME hardcode: [25 * 2] t_ntoken_per_chunk = position_id_per_seconds * 2
video_chunk_indices = processor.get_chunked_index(video_t_index, t_ntoken_per_chunk) video_chunk_indices = processor.get_chunked_index(video_t_index, t_ntoken_per_chunk)
audio_chunk_indices = processor.get_chunked_index(audio_t_index, t_ntoken_per_chunk) audio_chunk_indices = processor.get_chunked_index(audio_t_index, t_ntoken_per_chunk)
placeholder_string = "" placeholder_string = ""

View File

@@ -13,7 +13,7 @@
# limitations under the License. # limitations under the License.
from collections import defaultdict from collections import defaultdict
from dataclasses import dataclass from dataclasses import asdict, dataclass
from typing import TYPE_CHECKING, Any, Optional from typing import TYPE_CHECKING, Any, Optional
from ...extras import logging from ...extras import logging
@@ -27,6 +27,23 @@ if TYPE_CHECKING:
logger = logging.get_logger(__name__) logger = logging.get_logger(__name__)
MAX_SU_SEQ_IDX = 2**32 # maximum sub-sequence index
@dataclass
class PackingParams:
r"""Metadata for a packed sequence: sub-sequence boundaries and multimodal data indices.
- sequence_boundaries: cumulative token positions, e.g. [0, 100, 250, 512] means 3 sub-seqs
with token ranges [0,100), [100,250), [250,512). Length = num_sub_seqs + 1.
- image_subseq_ids / video_subseq_ids / audio_subseq_ids: for each mm item, the 0-based
sub-sequence index it belongs to. Length = total number of that mm type in the packed sample.
"""
sequence_boundaries: list[int]
image_subseq_ids: list[int]
video_subseq_ids: list[int]
audio_subseq_ids: list[int]
right_padding_length: int
@dataclass @dataclass
class SupervisedDatasetProcessor(DatasetProcessor): class SupervisedDatasetProcessor(DatasetProcessor):
@@ -162,10 +179,17 @@ class PackedSupervisedDatasetProcessor(SupervisedDatasetProcessor):
valid_num += 1 valid_num += 1
model_inputs = defaultdict(list) model_inputs = defaultdict(list)
requires_packing_params = self.data_args.neat_packing
knapsacks = greedy_knapsack(lengths, self.data_args.cutoff_len) knapsacks = greedy_knapsack(lengths, self.data_args.cutoff_len)
for knapsack in knapsacks: for knapsack in knapsacks:
packed_input_ids, packed_attention_masks, packed_position_ids, packed_labels = [], [], [], [] packed_input_ids, packed_attention_masks, packed_position_ids, packed_labels = [], [], [], []
packed_images, packed_videos, packed_audios = [], [], [] packed_images, packed_videos, packed_audios = [], [], []
if requires_packing_params:
sequence_boundaries = [0]
image_subseq_ids: list[int] = []
video_subseq_ids: list[int] = []
audio_subseq_ids: list[int] = []
for i, length in enumerate(knapsack): for i, length in enumerate(knapsack):
index = length2indexes[length].pop() index = length2indexes[length].pop()
packed_input_ids += batch_input_ids[index] packed_input_ids += batch_input_ids[index]
@@ -174,6 +198,15 @@ class PackedSupervisedDatasetProcessor(SupervisedDatasetProcessor):
packed_images += batch_images[index] packed_images += batch_images[index]
packed_videos += batch_videos[index] packed_videos += batch_videos[index]
packed_audios += batch_audios[index] packed_audios += batch_audios[index]
if requires_packing_params:
n_img = len(batch_images[index])
n_vid = len(batch_videos[index])
n_aud = len(batch_audios[index])
sequence_boundaries.append(sequence_boundaries[-1] + len(batch_input_ids[index]))
image_subseq_ids.extend([i] * n_img)
video_subseq_ids.extend([i] * n_vid)
audio_subseq_ids.extend([i] * n_aud)
if self.data_args.neat_packing: if self.data_args.neat_packing:
packed_attention_masks += [i + 1] * len(batch_input_ids[index]) # start from 1 packed_attention_masks += [i + 1] * len(batch_input_ids[index]) # start from 1
else: else:
@@ -189,10 +222,23 @@ class PackedSupervisedDatasetProcessor(SupervisedDatasetProcessor):
else: else:
packed_attention_masks += [1] * pad_length # more efficient flash_attn packed_attention_masks += [1] * pad_length # more efficient flash_attn
if requires_packing_params:
sequence_boundaries.append(sequence_boundaries[-1] + pad_length)
if len(packed_input_ids) != self.data_args.cutoff_len + 1: if len(packed_input_ids) != self.data_args.cutoff_len + 1:
raise ValueError("The length of packed example should be identical to the cutoff length.") raise ValueError("The length of packed example should be identical to the cutoff length.")
model_inputs["input_ids"].append(packed_input_ids) model_inputs["input_ids"].append(packed_input_ids)
if requires_packing_params:
packing_params = PackingParams(
sequence_boundaries=sequence_boundaries,
image_subseq_ids=image_subseq_ids or [MAX_SU_SEQ_IDX], # avoid dataset concat error
video_subseq_ids=video_subseq_ids or [MAX_SU_SEQ_IDX],
audio_subseq_ids=audio_subseq_ids or [MAX_SU_SEQ_IDX],
right_padding_length=pad_length,
)
model_inputs["packing_params"].append(asdict(packing_params))
model_inputs["attention_mask"].append(packed_attention_masks) model_inputs["attention_mask"].append(packed_attention_masks)
model_inputs["position_ids"].append(packed_position_ids) model_inputs["position_ids"].append(packed_position_ids)
model_inputs["labels"].append(packed_labels) model_inputs["labels"].append(packed_labels)

View File

@@ -1113,7 +1113,7 @@ register_template(
register_template( register_template(
name="gpt_oss", name="gpt_oss",
format_user=StringFormatter(slots=["<|start|>user<|message|>{{content}}<|end|><|start|>assistant"]), format_user=StringFormatter(slots=["<|start|>user<|message|>{{content}}<|end|><|start|>assistant"]),
format_assistant=StringFormatter(slots=["{{content}}<|end|>"]), format_assistant=StringFormatter(slots=["{{content}}"]),
format_system=StringFormatter(slots=["<|start|>system<|message|>{{content}}<|end|>"]), format_system=StringFormatter(slots=["<|start|>system<|message|>{{content}}<|end|>"]),
default_system="You are ChatGPT, a large language model trained by OpenAI.", default_system="You are ChatGPT, a large language model trained by OpenAI.",
thought_words=("<|channel|>analysis<|message|>", "<|end|><|start|>assistant<|channel|>final<|message|>"), thought_words=("<|channel|>analysis<|message|>", "<|end|><|start|>assistant<|channel|>final<|message|>"),

View File

@@ -69,12 +69,28 @@ MCA_SUPPORTED_MODELS = {
"qwen3", "qwen3",
"qwen3_moe", "qwen3_moe",
"qwen3_next", "qwen3_next",
"qwen3_5",
"qwen3_5_moe",
} }
METHODS = ["full", "freeze", "lora", "oft"] METHODS = ["full", "freeze", "lora", "oft"]
MOD_SUPPORTED_MODELS = {"bloom", "falcon", "gemma", "llama", "mistral", "mixtral", "phi", "starcoder2"} MOD_SUPPORTED_MODELS = {"bloom", "falcon", "gemma", "llama", "mistral", "mixtral", "phi", "starcoder2"}
MROPE_MODELS = {
"glm4v",
"glm_ocr",
"Keye",
"qwen2_vl",
"qwen2_5_vl",
"qwen2_5_omni_thinker",
"qwen3_omni_moe_thinker",
"qwen3_vl",
"qwen3_vl_moe",
"qwen3_5",
"qwen3_5_moe",
}
MULTIMODAL_SUPPORTED_MODELS = set() MULTIMODAL_SUPPORTED_MODELS = set()
PEFT_METHODS = {"lora", "oft"} PEFT_METHODS = {"lora", "oft"}

View File

@@ -94,7 +94,7 @@ def check_version(requirement: str, mandatory: bool = False) -> None:
def check_dependencies() -> None: def check_dependencies() -> None:
r"""Check the version of the required packages.""" r"""Check the version of the required packages."""
check_version("transformers>=4.51.0,<=5.2.0") check_version("transformers>=4.55.0,<=5.2.0")
check_version("datasets>=2.16.0,<=4.0.0") check_version("datasets>=2.16.0,<=4.0.0")
check_version("accelerate>=1.3.0,<=1.11.0") check_version("accelerate>=1.3.0,<=1.11.0")
check_version("peft>=0.18.0,<=0.18.1") check_version("peft>=0.18.0,<=0.18.1")

View File

@@ -33,7 +33,7 @@ from transformers.utils import is_torch_bf16_gpu_available, is_torch_npu_availab
from ..extras import logging from ..extras import logging
from ..extras.constants import CHECKPOINT_NAMES, EngineName from ..extras.constants import CHECKPOINT_NAMES, EngineName
from ..extras.misc import check_dependencies, check_version, get_current_device, is_env_enabled from ..extras.misc import check_dependencies, check_version, get_current_device, is_env_enabled
from ..extras.packages import is_mcore_adapter_available, is_transformers_version_greater_than from ..extras.packages import is_mcore_adapter_available
from .data_args import DataArguments from .data_args import DataArguments
from .evaluation_args import EvaluationArguments from .evaluation_args import EvaluationArguments
from .finetuning_args import FinetuningArguments from .finetuning_args import FinetuningArguments
@@ -394,9 +394,6 @@ def get_train_args(args: dict[str, Any] | list[str] | None = None) -> _TRAIN_CLS
if model_args.use_kt and is_deepspeed_zero3_enabled(): if model_args.use_kt and is_deepspeed_zero3_enabled():
raise ValueError("KTransformers is incompatible with DeepSpeed ZeRO-3.") raise ValueError("KTransformers is incompatible with DeepSpeed ZeRO-3.")
if data_args.neat_packing and is_transformers_version_greater_than("4.53.0"):
raise ValueError("Neat packing is incompatible with transformers>=4.53.0.")
_set_env_vars() _set_env_vars()
_verify_model_args(model_args, data_args, finetuning_args) _verify_model_args(model_args, data_args, finetuning_args)
_check_extra_dependencies(model_args, finetuning_args, training_args) _check_extra_dependencies(model_args, finetuning_args, training_args)
@@ -470,7 +467,7 @@ def get_train_args(args: dict[str, Any] | list[str] | None = None) -> _TRAIN_CLS
training_args.resume_from_checkpoint is None training_args.resume_from_checkpoint is None
and training_args.do_train and training_args.do_train
and os.path.isdir(training_args.output_dir) and os.path.isdir(training_args.output_dir)
and not training_args.overwrite_output_dir and not getattr(training_args, "overwrite_output_dir", False) # for mca training args and transformers >= 5.0
and can_resume_from_checkpoint and can_resume_from_checkpoint
): ):
last_checkpoint = get_last_checkpoint(training_args.output_dir) last_checkpoint = get_last_checkpoint(training_args.output_dir)

View File

@@ -37,7 +37,6 @@
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE. # SOFTWARE.
from typing import TYPE_CHECKING
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
@@ -45,10 +44,6 @@ import torch.nn.functional as F
from ...extras import logging from ...extras import logging
if TYPE_CHECKING:
from ...hparams import ModelArguments
logger = logging.get_logger(__name__) logger = logging.get_logger(__name__)
@@ -105,13 +100,3 @@ def get_unpad_data(attention_mask: "torch.Tensor") -> tuple["torch.Tensor", "tor
max_seqlen_in_batch = seqlens_in_batch.max().item() max_seqlen_in_batch = seqlens_in_batch.max().item()
cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)) cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
return indices, cu_seqlens, max_seqlen_in_batch return indices, cu_seqlens, max_seqlen_in_batch
def configure_packing(model_args: "ModelArguments", is_trainable: bool) -> None:
if not is_trainable or not model_args.block_diag_attn:
return
import transformers.modeling_flash_attention_utils
transformers.modeling_flash_attention_utils._get_unpad_data = get_unpad_data
logger.info_rank0("Using block diagonal attention for sequence packing without cross-attention.")

View File

@@ -24,7 +24,6 @@ import transformers.models
from transformers.activations import ACT2FN from transformers.activations import ACT2FN
from ...extras import logging from ...extras import logging
from ...extras.packages import is_transformers_version_greater_than
if TYPE_CHECKING: if TYPE_CHECKING:
@@ -344,9 +343,7 @@ _register_composite_model(
model_type="qwen2_vl", model_type="qwen2_vl",
projector_key="visual.merger", projector_key="visual.merger",
vision_model_keys=["visual.patch_embed", "visual.blocks"], vision_model_keys=["visual.patch_embed", "visual.blocks"],
language_model_keys=["language_model", "lm_head"] language_model_keys=["language_model", "lm_head"],
if is_transformers_version_greater_than("4.52.0")
else ["model", "lm_head"],
lora_conflict_keys=["patch_embed"], lora_conflict_keys=["patch_embed"],
) )
@@ -355,9 +352,7 @@ _register_composite_model(
model_type="qwen2_5_vl", model_type="qwen2_5_vl",
projector_key="visual.merger", projector_key="visual.merger",
vision_model_keys=["visual.patch_embed", "visual.blocks"], vision_model_keys=["visual.patch_embed", "visual.blocks"],
language_model_keys=["language_model", "lm_head"] language_model_keys=["language_model", "lm_head"],
if is_transformers_version_greater_than("4.52.0")
else ["model", "lm_head"],
lora_conflict_keys=["patch_embed"], lora_conflict_keys=["patch_embed"],
) )

View File

@@ -30,7 +30,6 @@ from .model_utils.embedding import resize_embedding_layer
from .model_utils.kv_cache import configure_kv_cache from .model_utils.kv_cache import configure_kv_cache
from .model_utils.longlora import configure_longlora from .model_utils.longlora import configure_longlora
from .model_utils.moe import add_z3_leaf_module, configure_moe from .model_utils.moe import add_z3_leaf_module, configure_moe
from .model_utils.packing import configure_packing
from .model_utils.quantization import configure_quantization from .model_utils.quantization import configure_quantization
from .model_utils.rope import configure_rope from .model_utils.rope import configure_rope
from .model_utils.valuehead import prepare_valuehead_model from .model_utils.valuehead import prepare_valuehead_model
@@ -142,7 +141,6 @@ def patch_config(
configure_quantization(config, tokenizer, model_args, is_trainable, init_kwargs) configure_quantization(config, tokenizer, model_args, is_trainable, init_kwargs)
configure_moe(config, model_args, is_trainable) configure_moe(config, model_args, is_trainable)
configure_visual_model(config) configure_visual_model(config)
configure_packing(model_args, is_trainable)
configure_kv_cache(config, model_args, is_trainable) configure_kv_cache(config, model_args, is_trainable)
if getattr(config, "model_type", None) == "qwen": if getattr(config, "model_type", None) == "qwen":

View File

@@ -228,7 +228,7 @@ class LogCallback(TrainerCallback):
if ( if (
args.should_save args.should_save
and os.path.exists(os.path.join(args.output_dir, TRAINER_LOG)) and os.path.exists(os.path.join(args.output_dir, TRAINER_LOG))
and args.overwrite_output_dir and getattr(args, "overwrite_output_dir", False)
): ):
logger.warning_rank0_once("Previous trainer log in this folder will be deleted.") logger.warning_rank0_once("Previous trainer log in this folder will be deleted.")
os.remove(os.path.join(args.output_dir, TRAINER_LOG)) os.remove(os.path.join(args.output_dir, TRAINER_LOG))

View File

@@ -13,6 +13,8 @@
# limitations under the License. # limitations under the License.
import functools import functools
import json
import os
from collections.abc import Sequence from collections.abc import Sequence
from copy import deepcopy from copy import deepcopy
from typing import TYPE_CHECKING, Any, Optional from typing import TYPE_CHECKING, Any, Optional
@@ -77,20 +79,25 @@ def _data_collator_wrapper(data_collator: Any):
def _check_model_support(model_args: "ModelArguments"): def _check_model_support(model_args: "ModelArguments"):
from transformers import AutoConfig as HfAutoConfig from transformers import AutoConfig as HfAutoConfig
if os.path.exists(os.path.join(model_args.model_name_or_path, "mca_config.json")): # load from mcore ckpt
mca_config = json.load(open(os.path.join(model_args.model_name_or_path, "mca_config.json")))
model_type = mca_config.get("hf_model_type", None)
else:
config = HfAutoConfig.from_pretrained(
model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code
)
model_type = config.model_type
config = HfAutoConfig.from_pretrained( if model_type not in MCA_SUPPORTED_MODELS:
model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code
)
if config.model_type not in MCA_SUPPORTED_MODELS:
raise ValueError( raise ValueError(
f"Model {config.model_type} is not supported by mcore_adapter." f"Model {model_type} is not supported by mcore_adapter."
"You can try to upgrade mcore_adapter to the latest version for more supported models." "You can try to upgrade mcore_adapter to the latest version for more supported models."
) )
def _freeze_model_parameters(model: Any, finetuning_args: "FinetuningArguments"): def _freeze_model_parameters(model: Any, finetuning_args: "FinetuningArguments"):
"""Freeze model parameters for qwen_vl series models based on finetuning arguments.""" """Freeze model parameters for qwen_vl series models based on finetuning arguments."""
if getattr(model.config, "hf_model_type", None) not in ["qwen2_vl", "qwen2_5_vl", "qwen3_vl", "qwen3_vl_moe"]: if getattr(model.config, "hf_model_type", None) not in ["qwen2_vl", "qwen2_5_vl", "qwen3_vl", "qwen3_vl_moe", "qwen3_5", "qwen3_5_moe"]:
return return
params_to_freeze = [] params_to_freeze = []

View File

@@ -65,6 +65,7 @@ def run_sft(
pad_to_multiple_of=8 if training_args.do_train else None, # for shift short attention pad_to_multiple_of=8 if training_args.do_train else None, # for shift short attention
label_pad_token_id=IGNORE_INDEX if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id, label_pad_token_id=IGNORE_INDEX if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id,
block_diag_attn=model_args.block_diag_attn, block_diag_attn=model_args.block_diag_attn,
neat_packing=data_args.neat_packing,
attn_implementation=getattr(model.config, "_attn_implementation", None), attn_implementation=getattr(model.config, "_attn_implementation", None),
compute_dtype=model_args.compute_dtype, compute_dtype=model_args.compute_dtype,
**tokenizer_module, **tokenizer_module,

View File

@@ -50,9 +50,9 @@ if is_apollo_available():
if is_ray_available(): if is_ray_available():
import ray import ray
from ray.util.state import list_nodes
from ray.util.placement_group import PlacementGroup, placement_group from ray.util.placement_group import PlacementGroup, placement_group
from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy
from ray.util.state import list_nodes
if TYPE_CHECKING: if TYPE_CHECKING:

View File

@@ -91,7 +91,11 @@ class Renderer:
self.processor = processor self.processor = processor
def render_messages( def render_messages(
self, messages: list[Message], tools: str | None = None, is_generate: bool = False self,
messages: list[Message],
tools: str | None = None,
is_generate: bool = False,
enable_thinking: bool = False,
) -> ModelInput: ) -> ModelInput:
"""Apply template to messages and convert them to model input. """Apply template to messages and convert them to model input.
@@ -99,6 +103,7 @@ class Renderer:
messages (list[Message]): The messages to render. messages (list[Message]): The messages to render.
tools (str | None, optional): The tools to use. Defaults to None. tools (str | None, optional): The tools to use. Defaults to None.
is_generate (bool, optional): Whether to render for generation. Defaults to False. is_generate (bool, optional): Whether to render for generation. Defaults to False.
enable_thinking (bool, optional): Whether to enable thinking mode for generation. Defaults to False.
Returns: Returns:
ModelInput: The rendered model input. ModelInput: The rendered model input.
@@ -108,7 +113,9 @@ class Renderer:
else: else:
from ...plugins.model_plugins.rendering import RenderingPlugin from ...plugins.model_plugins.rendering import RenderingPlugin
return RenderingPlugin(self.template).render_messages(self.processor, messages, tools, is_generate) return RenderingPlugin(self.template).render_messages(
self.processor, messages, tools, is_generate, enable_thinking
)
def parse_message(self, generated_text: str) -> Message: def parse_message(self, generated_text: str) -> Message:
"""Parse a message in the template format. """Parse a message in the template format.

View File

@@ -12,224 +12,45 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import json import importlib
import re
from ...utils.constants import IGNORE_INDEX from ...utils import logging
from ...utils.helper import get_tokenizer
from ...utils.plugin import BasePlugin from ...utils.plugin import BasePlugin
from ...utils.types import Message, ModelInput, Processor, ToolCall from ...utils.types import Message, ModelInput, Processor
logger = logging.get_logger(__name__)
class RenderingPlugin(BasePlugin): class RenderingPlugin(BasePlugin):
_attempted_template_imports: set[str] = set()
def _ensure_template_imported(self) -> None:
if self.name is None or self.name in self._attempted_template_imports:
return
full_module_name = f"{__package__}.templates.{self.name}"
self._attempted_template_imports.add(self.name)
try:
importlib.import_module(full_module_name)
except Exception as exc:
logger.warning(f"[Template Registry] Failed to import {full_module_name}: {exc}")
def __getitem__(self, method_name: str):
self._ensure_template_imported()
return super().__getitem__(method_name)
def render_messages( def render_messages(
self, self,
processor: Processor, processor: Processor,
messages: list[Message], messages: list[Message],
tools: str | None = None, tools: str | None = None,
is_generate: bool = False, is_generate: bool = False,
enable_thinking: bool = False,
) -> ModelInput: ) -> ModelInput:
"""Render messages in the template format.""" """Render messages in the template format."""
return self["render_messages"](processor, messages, tools, is_generate) return self["render_messages"](processor, messages, tools, is_generate, enable_thinking)
def parse_messages(self, generated_text: str) -> Message: def parse_messages(self, generated_text: str) -> Message:
"""Parse messages in the template format.""" """Parse messages in the template format."""
return self["parse_messages"](generated_text) return self["parse_messages"](generated_text)
def _update_model_input(
processor: Processor,
input_ids: list[int],
labels: list[int],
loss_weights: list[int],
temp_str: str,
temp_weight: float,
) -> str:
"""Update model input with temporary string."""
if not temp_str:
return ""
tokenizer = get_tokenizer(processor)
temp_ids = tokenizer.encode(temp_str, add_special_tokens=False)
input_ids.extend(temp_ids)
loss_weights.extend([temp_weight] * len(temp_ids))
if temp_weight > 1e-6:
labels.extend(temp_ids)
else:
labels.extend([IGNORE_INDEX] * len(temp_ids))
return ""
@RenderingPlugin("qwen3_nothink").register("render_messages")
def render_qwen3_nothink_messages(
processor: Processor,
messages: list[Message],
tools: str | None = None,
is_generate: bool = False,
) -> ModelInput:
"""Render messages in the Qwen3 nothink template format.
See https://huggingface.co/spaces/huggingfacejs/chat-template-playground?modelId=Qwen/Qwen3-4B-Instruct-2507
"""
input_ids, labels, loss_weights = [], [], []
temp_str, temp_weight = "", 0.0
if tools:
temp_str += "<|im_start|>system\n"
if messages[0]["role"] == "system":
for content in messages[0]["content"]:
if content["type"] == "text":
temp_str += content["value"]
else:
raise ValueError(f"Unsupported content type: {content['type']}")
temp_str += "\n\n"
temp_weight = messages[0].get("loss_weight", 0.0)
temp_str += (
"# Tools\n\nYou may call one or more functions to assist with the user query.\n\n"
"You are provided with function signatures within <tools></tools> XML tags:\n<tools>"
)
try:
tools = json.loads(tools)
except json.JSONDecodeError:
raise ValueError(f"Invalid tools format: {str(tools)}.")
if not isinstance(tools, list):
tools = [tools]
for tool in tools:
temp_str += "\n" + json.dumps(tool, ensure_ascii=False)
temp_str += (
"\n</tools>\n\nFor each function call, return a json object with function name "
'and arguments within <tool_call></tool_call> XML tags:\n<tool_call>\n{"name": '
'<function-name>, "arguments": <args-json-object>}\n</tool_call><|im_end|>\n'
)
elif messages[0]["role"] == "system":
temp_str += "<|im_start|>system\n"
for content in messages[0]["content"]:
if content["type"] == "text":
temp_str += content["value"]
else:
raise ValueError(f"Unsupported content type: {content['type']}")
temp_str += "<|im_end|>\n"
temp_weight = messages[0].get("loss_weight", 0.0)
temp_str = _update_model_input(processor, input_ids, labels, loss_weights, temp_str, temp_weight)
for turn_idx, message in enumerate(messages):
if message["role"] == "user" or (message["role"] == "system" and turn_idx != 0):
temp_str += "<|im_start|>" + message["role"] + "\n"
for content in message["content"]:
if content["type"] == "text":
temp_str += content["value"]
else:
raise ValueError(f"Unsupported content type: {content['type']}")
temp_str += "<|im_end|>\n"
temp_weight = message.get("loss_weight", 0.0)
elif message["role"] == "assistant":
temp_str += "<|im_start|>" + message["role"] + "\n"
for val_idx, content in enumerate(message["content"]):
if content["type"] == "text":
temp_str += content["value"]
elif content["type"] == "reasoning":
temp_str += "<thinking>\n" + content["value"] + "\n</thinking>\n\n" # avoid using special tokens
elif content["type"] == "tool_call":
if val_idx != 0 and message["content"][val_idx - 1]["type"] in ["text", "tool_call"]:
temp_str += "\n"
try:
tool_call: ToolCall = json.loads(content["value"])
except json.JSONDecodeError:
raise ValueError(f"Invalid tool call format: {content['value']}.")
temp_str += (
'<tool_call>\n{"name": "'
+ tool_call["name"]
+ '", "arguments": '
+ json.dumps(tool_call["arguments"], ensure_ascii=False)
+ "}\n</tool_call>"
)
else:
raise ValueError(f"Unsupported content type: {content['type']}")
temp_str += "<|im_end|>\n"
temp_weight = message.get("loss_weight", 1.0)
elif message["role"] == "tool":
if turn_idx == 0 or messages[turn_idx - 1]["role"] != "tool":
temp_str += "<|im_start|>user"
temp_str += "\n<tool_response>\n"
for content in message["content"]:
if content["type"] == "text":
temp_str += content["value"]
else:
raise ValueError(f"Unsupported content type: {content['type']}")
temp_str += "\n</tool_response>"
if turn_idx == len(messages) - 1 or messages[turn_idx + 1]["role"] != "tool":
temp_str += "<|im_end|>\n"
temp_weight = message.get("loss_weight", 0.0)
temp_str = _update_model_input(processor, input_ids, labels, loss_weights, temp_str, temp_weight)
if is_generate:
temp_str += "<|im_start|>assistant\n"
temp_weight = 0.0
temp_str = _update_model_input(processor, input_ids, labels, loss_weights, temp_str, temp_weight)
attention_mask = [1] * len(input_ids)
return ModelInput(
input_ids=input_ids,
attention_mask=attention_mask,
labels=labels,
loss_weights=loss_weights,
)
@RenderingPlugin("qwen3_nothink").register("parse_message")
def parse_qwen3_nothink_message(generated_text: str) -> Message:
"""Parse a message in the Qwen3 nothink template format. Supports interleaved reasoning and tool calls.
Args:
generated_text (str): The generated text in the Qwen3 nothink template format.
Returns:
Message: The parsed message.
"""
pattern = re.compile(r"<(thinking|tool_call)>\s*(.*?)\s*</\1>\s*", re.DOTALL)
content = []
last_end = 0
for match in pattern.finditer(generated_text):
start, end = match.span()
if start > last_end:
text = generated_text[last_end:start].strip()
if text:
content.append({"type": "text", "value": text})
tag_type = match.group(1)
tag_value = match.group(2).strip()
if tag_type == "thinking":
content.append({"type": "reasoning", "value": tag_value.strip()})
elif tag_type == "tool_call":
try:
json.loads(tag_value.strip())
except json.JSONDecodeError:
raise ValueError(f"Invalid tool call format: {tag_value.strip()}.")
content.append({"type": "tool_call", "value": tag_value.strip()})
last_end = end
if last_end < len(generated_text):
text = generated_text[last_end:].strip()
if text:
content.append({"type": "text", "value": text})
return Message(role="assistant", content=content)

View File

@@ -0,0 +1,13 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

View File

@@ -0,0 +1,259 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import re
from ....utils.constants import IGNORE_INDEX
from ....utils.helper import get_tokenizer
from ....utils.types import Message, ModelInput, Processor, ToolCall
from ..rendering import RenderingPlugin
def _update_model_input(
processor: Processor,
input_ids: list[int],
labels: list[int],
loss_weights: list[int],
temp_str: str,
temp_weight: float,
) -> str:
"""Update model input with temporary string."""
if not temp_str:
return ""
tokenizer = get_tokenizer(processor)
temp_ids = tokenizer.encode(temp_str, add_special_tokens=False)
input_ids.extend(temp_ids)
loss_weights.extend([temp_weight] * len(temp_ids))
if temp_weight > 1e-6:
labels.extend(temp_ids)
else:
labels.extend([IGNORE_INDEX] * len(temp_ids))
return ""
def _concat_text_content(message: Message) -> str:
"""Concatenate text fields in a message."""
message_text = ""
for content in message["content"]:
if content["type"] == "text":
message_text += content["value"]
else:
raise ValueError(f"Unsupported content type: {content['type']}")
return message_text
def _get_last_query_index(messages: list[Message]) -> int:
"""Find the last user query index, excluding wrapped tool responses."""
last_query_index = len(messages) - 1
for idx in range(len(messages) - 1, -1, -1):
message = messages[idx]
if message["role"] != "user":
continue
user_text = ""
is_plain_text = True
for content in message["content"]:
if content["type"] != "text":
is_plain_text = False
break
user_text += content["value"]
if not is_plain_text:
continue
if not (user_text.startswith("<tool_response>") and user_text.endswith("</tool_response>")):
last_query_index = idx
break
return last_query_index
def _split_assistant_content(message: Message) -> tuple[str, str, list[ToolCall]]:
"""Split assistant message into text, reasoning and tool calls."""
text_content = ""
reasoning_content = ""
tool_calls: list[ToolCall] = []
for content in message["content"]:
if content["type"] == "text":
text_content += content["value"]
elif content["type"] == "reasoning":
reasoning_content += content["value"]
elif content["type"] == "tool_call":
try:
tool_call: ToolCall = json.loads(content["value"])
except json.JSONDecodeError:
raise ValueError(f"Invalid tool call format: {content['value']}.")
tool_calls.append(tool_call)
else:
raise ValueError(f"Unsupported content type: {content['type']}")
return text_content, reasoning_content, tool_calls
@RenderingPlugin("qwen3").register("render_messages")
def render_qwen3_messages(
processor: Processor,
messages: list[Message],
tools: str | None = None,
is_generate: bool = False,
enable_thinking: bool = False,
) -> ModelInput:
"""Render messages in the Qwen3 template format.
See https://huggingface.co/spaces/huggingfacejs/chat-template-playground?modelId=Qwen/Qwen3-8B
"""
input_ids, labels, loss_weights = [], [], []
temp_str, temp_weight = "", 0.0
if tools:
temp_str += "<|im_start|>system\n"
if messages[0]["role"] == "system":
temp_str += _concat_text_content(messages[0]) + "\n\n"
temp_weight = messages[0].get("loss_weight", 0.0)
temp_str += (
"# Tools\n\nYou may call one or more functions to assist with the user query.\n\n"
"You are provided with function signatures within <tools></tools> XML tags:\n<tools>"
)
try:
tools = json.loads(tools)
except json.JSONDecodeError:
raise ValueError(f"Invalid tools format: {str(tools)}.")
if not isinstance(tools, list):
tools = [tools]
for tool in tools:
temp_str += "\n" + json.dumps(tool, ensure_ascii=False)
temp_str += (
"\n</tools>\n\nFor each function call, return a json object with function name "
'and arguments within <tool_call></tool_call> XML tags:\n<tool_call>\n{"name": '
'<function-name>, "arguments": <args-json-object>}\n</tool_call><|im_end|>\n'
)
elif messages[0]["role"] == "system":
temp_str += "<|im_start|>system\n" + _concat_text_content(messages[0]) + "<|im_end|>\n"
temp_weight = messages[0].get("loss_weight", 0.0)
temp_str = _update_model_input(processor, input_ids, labels, loss_weights, temp_str, temp_weight)
last_query_index = _get_last_query_index(messages)
for turn_idx, message in enumerate(messages):
if message["role"] == "user" or (message["role"] == "system" and turn_idx != 0):
temp_str += "<|im_start|>" + message["role"] + "\n" + _concat_text_content(message) + "<|im_end|>\n"
temp_weight = message.get("loss_weight", 0.0)
elif message["role"] == "assistant":
temp_str += "<|im_start|>" + message["role"] + "\n"
text_content, reasoning_content, tool_calls = _split_assistant_content(message)
if turn_idx > last_query_index and (turn_idx == len(messages) - 1 or reasoning_content):
temp_str += "<think>\n" + reasoning_content.strip("\n") + "\n</think>\n\n" + text_content.lstrip("\n")
else:
temp_str += text_content
for tool_call_idx, tool_call in enumerate(tool_calls):
if (tool_call_idx == 0 and text_content) or tool_call_idx > 0:
temp_str += "\n"
arguments = tool_call.get("arguments")
if isinstance(arguments, str):
arguments_str = arguments
else:
arguments_str = json.dumps(arguments, ensure_ascii=False)
temp_str += (
'<tool_call>\n{"name": "'
+ tool_call["name"]
+ '", "arguments": '
+ arguments_str
+ "}\n</tool_call>"
)
temp_str += "<|im_end|>\n"
temp_weight = message.get("loss_weight", 1.0)
elif message["role"] == "tool":
if turn_idx == 0 or messages[turn_idx - 1]["role"] != "tool":
temp_str += "<|im_start|>user"
temp_str += "\n<tool_response>\n" + _concat_text_content(message) + "\n</tool_response>"
if turn_idx == len(messages) - 1 or messages[turn_idx + 1]["role"] != "tool":
temp_str += "<|im_end|>\n"
temp_weight = message.get("loss_weight", 0.0)
temp_str = _update_model_input(processor, input_ids, labels, loss_weights, temp_str, temp_weight)
if is_generate:
temp_str += "<|im_start|>assistant\n"
temp_weight = 0.0
if enable_thinking is False:
temp_str += "<think>\n\n</think>\n\n"
temp_str = _update_model_input(processor, input_ids, labels, loss_weights, temp_str, temp_weight)
attention_mask = [1] * len(input_ids)
return ModelInput(
input_ids=input_ids,
attention_mask=attention_mask,
labels=labels,
loss_weights=loss_weights,
)
@RenderingPlugin("qwen3").register("parse_message")
def parse_qwen3_message(generated_text: str) -> Message:
"""Parse a message in the Qwen3 template format. Supports interleaved reasoning and tool calls.
Args:
generated_text (str): The generated text in the Qwen3 template format.
Returns:
Message: The parsed message.
"""
pattern = re.compile(r"<(think|tool_call)>\s*(.*?)\s*</\1>\s*", re.DOTALL)
content = []
last_end = 0
for match in pattern.finditer(generated_text):
start, end = match.span()
if start > last_end:
text = generated_text[last_end:start].strip()
if text:
content.append({"type": "text", "value": text})
tag_type = match.group(1)
tag_value = match.group(2).strip()
if tag_type == "think":
content.append({"type": "reasoning", "value": tag_value.strip()})
elif tag_type == "tool_call":
try:
json.loads(tag_value.strip())
except json.JSONDecodeError:
raise ValueError(f"Invalid tool call format: {tag_value.strip()}.")
content.append({"type": "tool_call", "value": tag_value.strip()})
last_end = end
if last_end < len(generated_text):
text = generated_text[last_end:].strip()
if text:
content.append({"type": "text", "value": text})
return Message(role="assistant", content=content)

View File

@@ -0,0 +1,209 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import re
from ....utils.constants import IGNORE_INDEX
from ....utils.helper import get_tokenizer
from ....utils.types import Message, ModelInput, Processor, ToolCall
from ..rendering import RenderingPlugin
def _update_model_input(
processor: Processor,
input_ids: list[int],
labels: list[int],
loss_weights: list[int],
temp_str: str,
temp_weight: float,
) -> str:
"""Update model input with temporary string."""
if not temp_str:
return ""
tokenizer = get_tokenizer(processor)
temp_ids = tokenizer.encode(temp_str, add_special_tokens=False)
input_ids.extend(temp_ids)
loss_weights.extend([temp_weight] * len(temp_ids))
if temp_weight > 1e-6:
labels.extend(temp_ids)
else:
labels.extend([IGNORE_INDEX] * len(temp_ids))
return ""
def _concat_text_content(message: Message) -> str:
"""Concatenate text fields in a message."""
message_text = ""
for content in message["content"]:
if content["type"] == "text":
message_text += content["value"]
else:
raise ValueError(f"Unsupported content type: {content['type']}")
return message_text
@RenderingPlugin("qwen3_nothink").register("render_messages")
def render_qwen3_nothink_messages(
processor: Processor,
messages: list[Message],
tools: str | None = None,
is_generate: bool = False,
enable_thinking: bool = False,
) -> ModelInput:
"""Render messages in the Qwen3 nothink template format.
See https://huggingface.co/spaces/huggingfacejs/chat-template-playground?modelId=Qwen/Qwen3-4B-Instruct-2507
"""
input_ids, labels, loss_weights = [], [], []
temp_str, temp_weight = "", 0.0
if tools:
temp_str += "<|im_start|>system\n"
if messages[0]["role"] == "system":
temp_str += _concat_text_content(messages[0]) + "\n\n"
temp_weight = messages[0].get("loss_weight", 0.0)
temp_str += (
"# Tools\n\nYou may call one or more functions to assist with the user query.\n\n"
"You are provided with function signatures within <tools></tools> XML tags:\n<tools>"
)
try:
tools = json.loads(tools)
except json.JSONDecodeError:
raise ValueError(f"Invalid tools format: {str(tools)}.")
if not isinstance(tools, list):
tools = [tools]
for tool in tools:
temp_str += "\n" + json.dumps(tool, ensure_ascii=False)
temp_str += (
"\n</tools>\n\nFor each function call, return a json object with function name "
'and arguments within <tool_call></tool_call> XML tags:\n<tool_call>\n{"name": '
'<function-name>, "arguments": <args-json-object>}\n</tool_call><|im_end|>\n'
)
elif messages[0]["role"] == "system":
temp_str += "<|im_start|>system\n" + _concat_text_content(messages[0]) + "<|im_end|>\n"
temp_weight = messages[0].get("loss_weight", 0.0)
temp_str = _update_model_input(processor, input_ids, labels, loss_weights, temp_str, temp_weight)
for turn_idx, message in enumerate(messages):
if message["role"] == "user" or (message["role"] == "system" and turn_idx != 0):
temp_str += "<|im_start|>" + message["role"] + "\n" + _concat_text_content(message) + "<|im_end|>\n"
temp_weight = message.get("loss_weight", 0.0)
elif message["role"] == "assistant":
temp_str += "<|im_start|>" + message["role"] + "\n"
for val_idx, content in enumerate(message["content"]):
if content["type"] == "text":
temp_str += content["value"]
elif content["type"] == "reasoning":
temp_str += "<thinking>\n" + content["value"] + "\n</thinking>\n\n" # avoid using special tokens
elif content["type"] == "tool_call":
if val_idx != 0 and message["content"][val_idx - 1]["type"] in ["text", "tool_call"]:
temp_str += "\n"
try:
tool_call: ToolCall = json.loads(content["value"])
except json.JSONDecodeError:
raise ValueError(f"Invalid tool call format: {content['value']}.")
temp_str += (
'<tool_call>\n{"name": "'
+ tool_call["name"]
+ '", "arguments": '
+ json.dumps(tool_call["arguments"], ensure_ascii=False)
+ "}\n</tool_call>"
)
else:
raise ValueError(f"Unsupported content type: {content['type']}")
temp_str += "<|im_end|>\n"
temp_weight = message.get("loss_weight", 1.0)
elif message["role"] == "tool":
if turn_idx == 0 or messages[turn_idx - 1]["role"] != "tool":
temp_str += "<|im_start|>user"
temp_str += "\n<tool_response>\n" + _concat_text_content(message) + "\n</tool_response>"
if turn_idx == len(messages) - 1 or messages[turn_idx + 1]["role"] != "tool":
temp_str += "<|im_end|>\n"
temp_weight = message.get("loss_weight", 0.0)
temp_str = _update_model_input(processor, input_ids, labels, loss_weights, temp_str, temp_weight)
if is_generate:
temp_str += "<|im_start|>assistant\n"
temp_weight = 0.0
if enable_thinking:
raise ValueError("The qwen3_nothink template does not support thinking mode.")
temp_str = _update_model_input(processor, input_ids, labels, loss_weights, temp_str, temp_weight)
attention_mask = [1] * len(input_ids)
return ModelInput(
input_ids=input_ids,
attention_mask=attention_mask,
labels=labels,
loss_weights=loss_weights,
)
@RenderingPlugin("qwen3_nothink").register("parse_message")
def parse_qwen3_nothink_message(generated_text: str) -> Message:
"""Parse a message in the Qwen3 nothink template format. Supports interleaved reasoning and tool calls.
Args:
generated_text (str): The generated text in the Qwen3 nothink template format.
Returns:
Message: The parsed message.
"""
pattern = re.compile(r"<(thinking|tool_call)>\s*(.*?)\s*</\1>\s*", re.DOTALL)
content = []
last_end = 0
for match in pattern.finditer(generated_text):
start, end = match.span()
if start > last_end:
text = generated_text[last_end:start].strip()
if text:
content.append({"type": "text", "value": text})
tag_type = match.group(1)
tag_value = match.group(2).strip()
if tag_type == "thinking":
content.append({"type": "reasoning", "value": tag_value.strip()})
elif tag_type == "tool_call":
try:
json.loads(tag_value.strip())
except json.JSONDecodeError:
raise ValueError(f"Invalid tool call format: {tag_value.strip()}.")
content.append({"type": "tool_call", "value": tag_value.strip()})
last_end = end
if last_end < len(generated_text):
text = generated_text[last_end:].strip()
if text:
content.append({"type": "text", "value": text})
return Message(role="assistant", content=content)

View File

@@ -85,7 +85,7 @@ class DistributedConfig(TypedDict, total=False):
class Content(TypedDict): class Content(TypedDict):
type: Literal["text", "reasoning", "tool_call", "image_url"] type: Literal["text", "reasoning", "tool_call", "image_url", "video_url", "audio_url"]
"""Type of the content.""" """Type of the content."""
value: str value: str
"""Value of the content.""" """Value of the content."""

View File

@@ -13,6 +13,7 @@
# limitations under the License. # limitations under the License.
import os import os
from collections import Counter
import pytest import pytest
import torch import torch
@@ -129,9 +130,177 @@ def test_multimodal_collator():
assert batch_input.keys() == expected_input.keys() assert batch_input.keys() == expected_input.keys()
for k in batch_input.keys(): for k in batch_input.keys():
if k == "position_ids" and batch_input[k].dim() == 3 and batch_input[k].shape[0] == 4:
batch_input[k] = batch_input[k][1:]
assert batch_input[k].eq(torch.tensor(expected_input[k])).all() assert batch_input[k].eq(torch.tensor(expected_input[k])).all()
def _make_packed_feature(
*,
packing_params: dict,
pad_token_id: int,
label_ignore_id: int,
fake_image: Image.Image,
vision_start_id: int | None = None,
vision_end_id: int | None = None,
image_pad_id: int | None = None,
) -> dict:
r"""Build one packed sample using the new PackingParams schema."""
sequence_boundaries = packing_params["sequence_boundaries"]
image_subseq_ids = packing_params["image_subseq_ids"]
video_subseq_ids = packing_params["video_subseq_ids"]
audio_subseq_ids = packing_params["audio_subseq_ids"]
unpadded_length = packing_params["unpadded_length"]
right_padding_length = packing_params["right_padding_length"] # which only preserved in tests
cutoff_plus_one = sequence_boundaries[-1]
content_len = unpadded_length
pad_len = right_padding_length
assert content_len + pad_len == cutoff_plus_one
assert sequence_boundaries[0] == 0
assert sequence_boundaries[-1] == cutoff_plus_one
content_ids = list(range(100, 100 + content_len))
if vision_start_id is not None and vision_end_id is not None and image_pad_id is not None:
image_counts_by_subseq = Counter(image_subseq_ids)
for subseq_idx, image_count in sorted(image_counts_by_subseq.items()):
if subseq_idx >= len(sequence_boundaries) - 1:
continue
subseq_start = sequence_boundaries[subseq_idx]
subseq_end = sequence_boundaries[subseq_idx + 1]
subseq_len = subseq_end - subseq_start
if subseq_len < 3:
continue
# Build repeated image groups while preserving at least 3 tokens for each remaining image.
injected_tokens: list[int] = []
remaining = subseq_len
for image_idx in range(image_count):
remaining_images = image_count - image_idx
min_reserved_for_rest = 3 * (remaining_images - 1)
current_group_len = min(6, remaining - min_reserved_for_rest)
if current_group_len < 3:
break
group = [vision_start_id] + [image_pad_id] * max(1, current_group_len - 2) + [vision_end_id]
injected_tokens.extend(group[:current_group_len])
remaining -= current_group_len
if injected_tokens:
insert_end = subseq_start + len(injected_tokens)
content_ids[subseq_start:insert_end] = injected_tokens
input_ids = content_ids + [pad_token_id] * pad_len
attention_mask = [1] * content_len + [0] * pad_len
labels = [label_ignore_id] * cutoff_plus_one
return {
"input_ids": input_ids,
"attention_mask": attention_mask,
"labels": labels,
"images": [fake_image] * len(image_subseq_ids),
"videos": [None] * len(video_subseq_ids),
"audios": [None] * len(audio_subseq_ids),
"packing_params": packing_params,
}
def _make_packed_features(
*,
packing_params: dict,
pad_token_id: int,
label_ignore_id: int,
fake_image: Image.Image,
vision_start_id: int,
vision_end_id: int,
image_pad_id: int,
) -> list[dict]:
r"""Build packed features from caller-provided packing_params."""
return [
_make_packed_feature(
packing_params=packing_params,
pad_token_id=pad_token_id,
label_ignore_id=label_ignore_id,
fake_image=fake_image,
vision_start_id=vision_start_id,
vision_end_id=vision_end_id,
image_pad_id=image_pad_id,
)
]
def _get_expected_position_ids(packing_params, get_rope_func, input_ids, attention_mask) -> torch.Tensor:
bound_list = packing_params["sequence_boundaries"]
input_ids_slices = [input_ids[bound_list[i]:bound_list[i+1]] for i in range(len(bound_list) - 1)]
attention_mask_slices = [attention_mask[bound_list[i]:bound_list[i+1]] for i in range(len(bound_list) - 1)]
img_counts_by_subseq = Counter(packing_params["image_subseq_ids"])
all_position_ids = []
for i, input_ids_slice in enumerate(input_ids_slices):
img_cnt = img_counts_by_subseq[i]
if sum(attention_mask_slices[i]) == 0:
continue
rope_func_kwargs = {
"input_ids": torch.tensor(input_ids_slice).unsqueeze(0),
"attention_mask": torch.tensor(attention_mask_slices[i]).unsqueeze(0),
"image_grid_thw": [torch.tensor([1, 4, 4])] * img_cnt,
}
position_ids, _ = get_rope_func(**rope_func_kwargs)
all_position_ids.append(position_ids)
return torch.cat(all_position_ids, dim=-1)
@pytest.mark.runs_on(["cpu", "mps"])
def test_multimodal_collator_with_packing():
model_args, data_args, *_ = get_infer_args(
{"model_name_or_path": "Qwen/Qwen2-VL-2B-Instruct", "template": "qwen2_vl"}
)
tokenizer_module = load_tokenizer(model_args)
template = get_template_and_fix_tokenizer(tokenizer_module["tokenizer"], data_args)
tokenizer_module["tokenizer"].padding_side = "right"
config = AutoConfig.from_pretrained(model_args.model_name_or_path)
with torch.device("meta"):
model = AutoModelForImageTextToText.from_config(config)
data_collator = MultiModalDataCollatorForSeq2Seq(
template=template,
model=model,
pad_to_multiple_of=4,
label_pad_token_id=IGNORE_INDEX,
**tokenizer_module,
)
tokenizer = tokenizer_module["tokenizer"]
packing_params = {
"sequence_boundaries": [0, 2, 10, 18, 28, 32],
"image_subseq_ids": [1, 2, 3],
"video_subseq_ids": [],
"audio_subseq_ids": [],
"unpadded_length": 28,
"right_padding_length": 4,
}
fake_image = Image.new("RGB", (64, 64), (255, 255, 255))
features = _make_packed_features(
packing_params=packing_params,
pad_token_id=tokenizer.pad_token_id,
label_ignore_id=IGNORE_INDEX,
fake_image=fake_image,
vision_start_id=tokenizer.convert_tokens_to_ids("<|vision_start|>"),
vision_end_id=tokenizer.convert_tokens_to_ids("<|vision_end|>"),
image_pad_id=tokenizer.convert_tokens_to_ids("<|image_pad|>"),
)
expected_position_ids = _get_expected_position_ids(
packing_params,
data_collator.get_rope_func,
features[0]["input_ids"],
features[0]["attention_mask"],
)
batch_input = data_collator(features) # [3, bsz, seq_len]
valid_len = expected_position_ids.shape[-1]
assert batch_input["position_ids"][1:, :, :valid_len].eq(expected_position_ids).all()
@pytest.mark.runs_on(["cpu"]) @pytest.mark.runs_on(["cpu"])
def test_4d_attention_mask(): def test_4d_attention_mask():
o = 0.0 o = 0.0