mirror of
https://github.com/hiyouga/LlamaFactory.git
synced 2026-03-27 14:03:09 +00:00
Compare commits
30 Commits
d3bf882e87
...
main
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
b5afabe3d2 | ||
|
|
df2e6edb7e | ||
|
|
d02fcd3588 | ||
|
|
c340aa2a33 | ||
|
|
1e536733c6 | ||
|
|
97d479fa92 | ||
|
|
ffbff33af3 | ||
|
|
833f6027b1 | ||
|
|
d91d8af89e | ||
|
|
e67ab9e2f2 | ||
|
|
2c4f121817 | ||
|
|
487f8b8191 | ||
|
|
78cad1e332 | ||
|
|
70653026f5 | ||
|
|
246192abd2 | ||
|
|
0258dc14d0 | ||
|
|
3045adf0ba | ||
|
|
a3d44e3152 | ||
|
|
edeb953bc7 | ||
|
|
d045794387 | ||
|
|
9501c3308a | ||
|
|
0ee1c42c2b | ||
|
|
3061f48d55 | ||
|
|
2d9bd2aa14 | ||
|
|
c0245c43fc | ||
|
|
eb976d75a2 | ||
|
|
b5cb7cb0e6 | ||
|
|
0779846513 | ||
|
|
45d335c709 | ||
|
|
816480012f |
9
.github/workflows/tests.yml
vendored
9
.github/workflows/tests.yml
vendored
@@ -35,15 +35,12 @@ jobs:
|
||||
transformers:
|
||||
- ""
|
||||
include: # test backward compatibility
|
||||
- python: "3.11"
|
||||
os: "ubuntu-latest"
|
||||
transformers: "4.51.0"
|
||||
- python: "3.11"
|
||||
os: "ubuntu-latest"
|
||||
transformers: "4.53.0"
|
||||
- python: "3.11"
|
||||
os: "ubuntu-latest"
|
||||
transformers: "4.55.0"
|
||||
- python: "3.11"
|
||||
os: "ubuntu-latest"
|
||||
transformers: "4.57.1"
|
||||
|
||||
runs-on: ${{ matrix.os }}
|
||||
|
||||
|
||||
6
.github/workflows/tests_npu.yml
vendored
6
.github/workflows/tests_npu.yml
vendored
@@ -49,6 +49,12 @@ jobs:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v6
|
||||
|
||||
- name: Set nginx-cache for Ascend CI
|
||||
run: |
|
||||
sed -Ei 's@(ports|archive).ubuntu.com@cache-service.nginx-pypi-cache.svc.cluster.local:8081@g' /etc/apt/sources.list
|
||||
pip config set global.index-url http://cache-service.nginx-pypi-cache.svc.cluster.local/pypi/simple
|
||||
pip config set global.trusted-host cache-service.nginx-pypi-cache.svc.cluster.local
|
||||
|
||||
- name: Install uv
|
||||
uses: astral-sh/setup-uv@v7
|
||||
with:
|
||||
|
||||
@@ -319,7 +319,7 @@ Read technical notes:
|
||||
| [Pixtral](https://huggingface.co/mistralai) | 12B | pixtral |
|
||||
| [Qwen2 (Code/Math/MoE/QwQ)](https://huggingface.co/Qwen) | 0.5B/1.5B/3B/7B/14B/32B/72B/110B | qwen |
|
||||
| [Qwen3 (MoE/Instruct/Thinking/Next)](https://huggingface.co/Qwen) | 0.6B/1.7B/4B/8B/14B/32B/80B/235B | qwen3/qwen3_nothink |
|
||||
| [Qwen3.5](https://huggingface.co/Qwen) | 27B/35B/122B/397B | qwen3_5 |
|
||||
| [Qwen3.5](https://huggingface.co/Qwen) | 0.8B/2B/4B/9B/27B/35B/122B/397B | qwen3_5 |
|
||||
| [Qwen2-Audio](https://huggingface.co/Qwen) | 7B | qwen2_audio |
|
||||
| [Qwen2.5-Omni](https://huggingface.co/Qwen) | 3B/7B | qwen2_omni |
|
||||
| [Qwen3-Omni](https://huggingface.co/Qwen) | 30B | qwen3_omni |
|
||||
@@ -473,7 +473,7 @@ huggingface-cli login
|
||||
|
||||
| Mandatory | Minimum | Recommend |
|
||||
| ------------ | ------- | --------- |
|
||||
| python | 3.9 | 3.10 |
|
||||
| python | 3.11 | >=3.11 |
|
||||
| torch | 2.0.0 | 2.6.0 |
|
||||
| torchvision | 0.15.0 | 0.21.0 |
|
||||
| transformers | 4.49.0 | 4.50.0 |
|
||||
|
||||
@@ -321,7 +321,7 @@ https://github.com/user-attachments/assets/43b700c6-a178-41db-b1f8-8190a5d3fcfc
|
||||
| [Pixtral](https://huggingface.co/mistralai) | 12B | pixtral |
|
||||
| [Qwen2 (Code/Math/MoE/QwQ)](https://huggingface.co/Qwen) | 0.5B/1.5B/3B/7B/14B/32B/72B/110B | qwen |
|
||||
| [Qwen3 (MoE/Instruct/Thinking/Next)](https://huggingface.co/Qwen) | 0.6B/1.7B/4B/8B/14B/32B/80B/235B | qwen3/qwen3_nothink |
|
||||
| [Qwen3.5](https://huggingface.co/Qwen) | 27B/35B/122B/397B | qwen3_5 |
|
||||
| [Qwen3.5](https://huggingface.co/Qwen) | 0.8B/2B/4B/9B/27B/35B/122B/397B | qwen3_5 |
|
||||
| [Qwen2-Audio](https://huggingface.co/Qwen) | 7B | qwen2_audio |
|
||||
| [Qwen2.5-Omni](https://huggingface.co/Qwen) | 3B/7B | qwen2_omni |
|
||||
| [Qwen3-Omni](https://huggingface.co/Qwen) | 30B | qwen3_omni |
|
||||
@@ -475,7 +475,7 @@ huggingface-cli login
|
||||
|
||||
| 必需项 | 至少 | 推荐 |
|
||||
| ------------ | ------- | --------- |
|
||||
| python | 3.9 | 3.10 |
|
||||
| python | 3.11 | >=3.11 |
|
||||
| torch | 2.0.0 | 2.6.0 |
|
||||
| torchvision | 0.15.0 | 0.21.0 |
|
||||
| transformers | 4.49.0 | 4.50.0 |
|
||||
|
||||
@@ -236,6 +236,13 @@
|
||||
"ms_hub_url": "AI-ModelScope/sharegpt_gpt4",
|
||||
"formatting": "sharegpt"
|
||||
},
|
||||
"sgsc_b2b_entities": {
|
||||
"hf_hub_url": "Nooxus-AI/NOO-Verified-Global-Entities",
|
||||
"formatting": "sharegpt",
|
||||
"columns": {
|
||||
"messages": "messages"
|
||||
}
|
||||
},
|
||||
"ultrachat_200k": {
|
||||
"hf_hub_url": "HuggingFaceH4/ultrachat_200k",
|
||||
"ms_hub_url": "AI-ModelScope/ultrachat_200k",
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
# https://hub.docker.com/r/ascendai/cann/tags
|
||||
|
||||
ARG BASE_IMAGE=quay.io/ascend/cann:8.3.rc2-910b-ubuntu22.04-py3.11
|
||||
ARG BASE_IMAGE=quay.io/ascend/cann:8.5.1-910b-ubuntu22.04-py3.11
|
||||
FROM ${BASE_IMAGE}
|
||||
|
||||
# Installation arguments
|
||||
@@ -33,9 +33,11 @@ RUN pip config set global.index-url "${PIP_INDEX}" && \
|
||||
COPY . /app
|
||||
|
||||
# Install torch-npu
|
||||
RUN pip uninstall -y torch torchvision torchaudio && \
|
||||
pip install --no-cache-dir "torch==2.7.1" "torch-npu==2.7.1" "torchvision==0.22.1" "torchaudio==2.7.1" --index-url "${PYTORCH_INDEX}" && \
|
||||
pip install --no-cache-dir -e . --no-build-isolation && \
|
||||
RUN source /usr/local/Ascend/ascend-toolkit/set_env.sh
|
||||
RUN pip uninstall -y torch torchvision torchaudio
|
||||
RUN pip install --no-cache-dir -r requirements/npu.txt --index-url "${PYTORCH_INDEX}"
|
||||
RUN pip install --no-cache-dir -r requirements/deepspeed.txt
|
||||
RUN pip install --no-cache-dir -e . --no-build-isolation && \
|
||||
pip install --no-cache-dir -r requirements/metrics.txt --no-build-isolation
|
||||
|
||||
# Set up volumes
|
||||
|
||||
@@ -33,7 +33,7 @@ services:
|
||||
dockerfile: ./docker/docker-npu/Dockerfile
|
||||
context: ../..
|
||||
args:
|
||||
BASE_IMAGE: quay.io/ascend/cann:8.3.rc2-a3-ubuntu22.04-py3.11
|
||||
BASE_IMAGE: quay.io/ascend/cann:8.5.1-a3-ubuntu22.04-py3.11
|
||||
PIP_INDEX: https://pypi.org/simple
|
||||
container_name: llamafactory-a3
|
||||
image: llamafactory:npu-a3
|
||||
|
||||
@@ -28,12 +28,7 @@ save_only_model: false
|
||||
report_to: none # choices: [none, wandb, tensorboard, swanlab, mlflow]
|
||||
|
||||
### ray
|
||||
ray_run_name: qwen3_4b_sft_lora
|
||||
ray_storage_path: ./saves
|
||||
ray_num_workers: 4 # Number of GPUs to use.
|
||||
placement_strategy: PACK
|
||||
resources_per_worker:
|
||||
GPU: 1
|
||||
# ray_init_kwargs:
|
||||
# runtime_env:
|
||||
# env_vars:
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
model: Qwen/Qwen3-4B
|
||||
trust_remote_code: true
|
||||
model_class: llm
|
||||
|
||||
template: qwen3_nothink
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
model: Qwen/Qwen3-0.6B
|
||||
|
||||
model_class: llm
|
||||
|
||||
template: qwen3_nothink
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
model: Qwen/Qwen3-0.6B
|
||||
trust_remote_code: true
|
||||
model_class: llm
|
||||
|
||||
template: qwen3_nothink
|
||||
@@ -14,16 +13,12 @@ dist_config:
|
||||
name: fsdp2
|
||||
dcp_path: null # /mnt/f/pretrain_models/Qwen3-0.6B-dcp
|
||||
|
||||
init_config:
|
||||
name: init_on_meta
|
||||
|
||||
### data
|
||||
train_dataset: data/v1_sft_demo.yaml
|
||||
|
||||
### training
|
||||
output_dir: outputs/test_fsdp2
|
||||
micro_batch_size: 1
|
||||
global_batch_size: 1
|
||||
cutoff_len: 2048
|
||||
learning_rate: 1.0e-4
|
||||
bf16: false
|
||||
|
||||
23
examples/v1/train_full/train_full_ulysses_cp.yaml
Normal file
23
examples/v1/train_full/train_full_ulysses_cp.yaml
Normal file
@@ -0,0 +1,23 @@
|
||||
model: Qwen/Qwen3-0.6B
|
||||
trust_remote_code: true
|
||||
model_class: llm
|
||||
|
||||
template: qwen3_nothink
|
||||
|
||||
# FSDP Config
|
||||
dist_config:
|
||||
name: fsdp2
|
||||
dcp_path: null
|
||||
cp_mode: ulysses
|
||||
cp_size: 2
|
||||
|
||||
### data
|
||||
train_dataset: data/v1_sft_demo.yaml
|
||||
|
||||
### training
|
||||
output_dir: outputs/test_ulysses_cp
|
||||
micro_batch_size: 1
|
||||
cutoff_len: 2048
|
||||
learning_rate: 1.0e-4
|
||||
bf16: false
|
||||
max_steps: 10
|
||||
@@ -1,5 +1,4 @@
|
||||
model: Qwen/Qwen3-4B
|
||||
trust_remote_code: true
|
||||
model_class: llm
|
||||
|
||||
template: qwen3_nothink
|
||||
@@ -28,7 +27,6 @@ train_dataset: data/v1_sft_demo.yaml
|
||||
### training
|
||||
output_dir: ./outputs/test_lora
|
||||
micro_batch_size: 1
|
||||
global_batch_size: 4
|
||||
cutoff_len: 2048
|
||||
learning_rate: 1.0e-4
|
||||
bf16: true
|
||||
|
||||
40
examples/v1/train_lora/train_lora_sft_rank0.yaml
Normal file
40
examples/v1/train_lora/train_lora_sft_rank0.yaml
Normal file
@@ -0,0 +1,40 @@
|
||||
model: Qwen/Qwen3-4B
|
||||
model_class: llm
|
||||
|
||||
template: qwen3_nothink
|
||||
|
||||
# PEFT Configuration
|
||||
peft_config:
|
||||
name: lora
|
||||
r: 16
|
||||
lora_alpha: 32
|
||||
lora_dropout: 0.05
|
||||
target_modules: all
|
||||
|
||||
# Kernel Config
|
||||
kernel_config:
|
||||
name: auto
|
||||
include_kernels: auto
|
||||
|
||||
# FSDP Config
|
||||
dist_config:
|
||||
name: fsdp2
|
||||
dcp_path: null
|
||||
|
||||
init_config:
|
||||
name: init_on_rank0
|
||||
|
||||
### data
|
||||
train_dataset: data/v1_sft_demo.yaml
|
||||
|
||||
### training
|
||||
output_dir: ./outputs/test_lora
|
||||
micro_batch_size: 1
|
||||
cutoff_len: 2048
|
||||
learning_rate: 1.0e-4
|
||||
bf16: true
|
||||
max_steps: 10
|
||||
|
||||
### sample
|
||||
sample_backend: hf
|
||||
max_new_tokens: 128
|
||||
@@ -1,5 +1,4 @@
|
||||
model: Qwen/Qwen3-0.6B
|
||||
trust_remote_code: true
|
||||
model_class: llm
|
||||
|
||||
template: qwen3_nothink
|
||||
|
||||
@@ -40,7 +40,7 @@ dependencies = [
|
||||
"torch>=2.4.0",
|
||||
"torchvision>=0.19.0",
|
||||
"torchaudio>=2.4.0",
|
||||
"transformers>=4.51.0,<=5.2.0,!=4.52.0,!=4.57.0",
|
||||
"transformers>=4.55.0,<=5.2.0,!=4.52.0,!=4.57.0",
|
||||
"datasets>=2.16.0,<=4.0.0",
|
||||
"accelerate>=1.3.0,<=1.11.0",
|
||||
"peft>=0.18.0,<=0.18.1",
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
torch==2.7.1
|
||||
torch-npu==2.7.1
|
||||
torch-npu==2.7.1.post2
|
||||
torchvision==0.22.1
|
||||
torchaudio==2.7.1
|
||||
|
||||
@@ -71,6 +71,7 @@ def convert(
|
||||
pipeline_model_parallel_size: int = 1,
|
||||
expert_model_parallel_size: int = 1,
|
||||
virtual_pipeline_model_parallel_size: int | None = None,
|
||||
moe_grouped_gemm: bool | None = None,
|
||||
):
|
||||
"""Convert checkpoint between MCA and HuggingFace formats.
|
||||
|
||||
@@ -84,6 +85,10 @@ def convert(
|
||||
pipeline_model_parallel_size: Pipeline model parallel size
|
||||
expert_model_parallel_size: Expert model parallel size
|
||||
virtual_pipeline_model_parallel_size: Virtual pipeline model parallel size
|
||||
moe_grouped_gemm: Use grouped gemm for MoE experts. When enabled, expert
|
||||
weights are stored in a flattened format (linear_fc1.weight0, weight1, ...)
|
||||
rather than per-expert format (local_experts.0.linear_fc1.weight, ...).
|
||||
Must match the format used when saving the checkpoint.
|
||||
"""
|
||||
if bf16 and fp16:
|
||||
raise ValueError("bf16 and fp16 cannot be both True.")
|
||||
@@ -97,8 +102,9 @@ def convert(
|
||||
pipeline_model_parallel_size=pipeline_model_parallel_size,
|
||||
expert_model_parallel_size=expert_model_parallel_size,
|
||||
virtual_pipeline_model_parallel_size=virtual_pipeline_model_parallel_size,
|
||||
moe_grouped_gemm=moe_grouped_gemm,
|
||||
transformer_impl="transformer_engine", # hard code here since we default using te for training
|
||||
)
|
||||
|
||||
convert_checkpoint_to_mca(
|
||||
checkpoint_path,
|
||||
output_path,
|
||||
|
||||
@@ -154,25 +154,24 @@ def vllm_infer(
|
||||
batch = train_dataset[i : min(i + batch_size, len(train_dataset))]
|
||||
|
||||
for j in range(len(batch["input_ids"])):
|
||||
multi_modal_data = {}
|
||||
video_metadata_kwargs = None
|
||||
|
||||
if batch["images"][j] is not None:
|
||||
image = batch["images"][j]
|
||||
multi_modal_data = {
|
||||
"image": template_obj.mm_plugin._regularize_images(
|
||||
image, image_max_pixels=image_max_pixels, image_min_pixels=image_min_pixels
|
||||
)["images"]
|
||||
}
|
||||
elif batch["videos"][j] is not None:
|
||||
video_metadata, video_metadata_kwargs = None, None
|
||||
multi_modal_data["image"] = template_obj.mm_plugin._regularize_images(
|
||||
image, image_max_pixels=image_max_pixels, image_min_pixels=image_min_pixels
|
||||
)["images"]
|
||||
|
||||
if batch["videos"][j] is not None:
|
||||
video = batch["videos"][j]
|
||||
multi_modal_data = {
|
||||
"video": template_obj.mm_plugin._regularize_videos(
|
||||
video,
|
||||
image_max_pixels=image_max_pixels,
|
||||
image_min_pixels=image_min_pixels,
|
||||
video_fps=video_fps,
|
||||
video_maxlen=video_maxlen,
|
||||
)["videos"]
|
||||
}
|
||||
multi_modal_data["video"] = template_obj.mm_plugin._regularize_videos(
|
||||
video,
|
||||
image_max_pixels=image_max_pixels,
|
||||
image_min_pixels=image_min_pixels,
|
||||
video_fps=video_fps,
|
||||
video_maxlen=video_maxlen,
|
||||
)["videos"]
|
||||
if need_video_kwargs:
|
||||
container = av.open(video[0], "r")
|
||||
video_stream = next(stream for stream in container.streams if stream.type == "video")
|
||||
@@ -192,18 +191,17 @@ def vllm_infer(
|
||||
video_backend="opencv",
|
||||
)
|
||||
multi_modal_data["video"] = (multi_modal_data["video"], video_metadata)
|
||||
elif batch["audios"][j] is not None:
|
||||
|
||||
if batch["audios"][j] is not None:
|
||||
audio = batch["audios"][j]
|
||||
audio_data = template_obj.mm_plugin._regularize_audios(
|
||||
audio,
|
||||
sampling_rate=16000,
|
||||
)
|
||||
multi_modal_data = {"audio": zip(audio_data["audios"], audio_data["sampling_rates"])}
|
||||
else:
|
||||
multi_modal_data = None
|
||||
multi_modal_data["audio"] = zip(audio_data["audios"], audio_data["sampling_rates"])
|
||||
|
||||
vllm_input_data = {"prompt_token_ids": batch["input_ids"][j], "multi_modal_data": multi_modal_data}
|
||||
if "video_metadata_kwargs" in locals() and video_metadata_kwargs is not None:
|
||||
vllm_input_data = {"prompt_token_ids": batch["input_ids"][j], "multi_modal_data": multi_modal_data or None}
|
||||
if video_metadata_kwargs is not None:
|
||||
vllm_input_data["mm_processor_kwargs"] = video_metadata_kwargs
|
||||
|
||||
vllm_inputs.append(vllm_input_data)
|
||||
|
||||
@@ -88,7 +88,10 @@ def _process_request(
|
||||
|
||||
if request.messages[0].role == Role.SYSTEM:
|
||||
content = request.messages.pop(0).content
|
||||
system = content[0].text if isinstance(content, list) else content
|
||||
if isinstance(content, list):
|
||||
system = content[0].text if content else ""
|
||||
else:
|
||||
system = content
|
||||
else:
|
||||
system = None
|
||||
|
||||
|
||||
@@ -180,35 +180,32 @@ class VllmEngine(BaseEngine):
|
||||
else self.generating_args["skip_special_tokens"],
|
||||
)
|
||||
|
||||
multi_modal_data = {}
|
||||
if images is not None: # add image features
|
||||
multi_modal_data = {
|
||||
"image": self.template.mm_plugin._regularize_images(
|
||||
images,
|
||||
image_max_pixels=self.model_args.image_max_pixels,
|
||||
image_min_pixels=self.model_args.image_min_pixels,
|
||||
)["images"]
|
||||
}
|
||||
elif videos is not None:
|
||||
multi_modal_data = {
|
||||
"video": self.template.mm_plugin._regularize_videos(
|
||||
videos,
|
||||
image_max_pixels=self.model_args.video_max_pixels,
|
||||
image_min_pixels=self.model_args.video_min_pixels,
|
||||
video_fps=self.model_args.video_fps,
|
||||
video_maxlen=self.model_args.video_maxlen,
|
||||
)["videos"]
|
||||
}
|
||||
elif audios is not None:
|
||||
multi_modal_data["image"] = self.template.mm_plugin._regularize_images(
|
||||
images,
|
||||
image_max_pixels=self.model_args.image_max_pixels,
|
||||
image_min_pixels=self.model_args.image_min_pixels,
|
||||
)["images"]
|
||||
|
||||
if videos is not None:
|
||||
multi_modal_data["video"] = self.template.mm_plugin._regularize_videos(
|
||||
videos,
|
||||
image_max_pixels=self.model_args.video_max_pixels,
|
||||
image_min_pixels=self.model_args.video_min_pixels,
|
||||
video_fps=self.model_args.video_fps,
|
||||
video_maxlen=self.model_args.video_maxlen,
|
||||
)["videos"]
|
||||
|
||||
if audios is not None:
|
||||
audio_data = self.template.mm_plugin._regularize_audios(
|
||||
audios,
|
||||
sampling_rate=self.model_args.audio_sampling_rate,
|
||||
)
|
||||
multi_modal_data = {"audio": zip(audio_data["audios"], audio_data["sampling_rates"])}
|
||||
else:
|
||||
multi_modal_data = None
|
||||
multi_modal_data["audio"] = zip(audio_data["audios"], audio_data["sampling_rates"])
|
||||
|
||||
result_generator = self.model.generate(
|
||||
{"prompt_token_ids": prompt_ids, "multi_modal_data": multi_modal_data},
|
||||
{"prompt_token_ids": prompt_ids, "multi_modal_data": multi_modal_data or None},
|
||||
sampling_params=sampling_params,
|
||||
request_id=request_id,
|
||||
lora_request=self.lora_request,
|
||||
|
||||
@@ -15,6 +15,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import copy
|
||||
import inspect
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, Any, Literal, Optional
|
||||
@@ -25,7 +26,7 @@ import torch.nn.functional as F
|
||||
from peft import PeftModel
|
||||
from transformers import DataCollatorForSeq2Seq
|
||||
|
||||
from ..extras.constants import AUDIO_PLACEHOLDER, IGNORE_INDEX, IMAGE_PLACEHOLDER
|
||||
from ..extras.constants import AUDIO_PLACEHOLDER, IGNORE_INDEX, IMAGE_PLACEHOLDER, MROPE_MODELS
|
||||
from ..extras.packages import is_pillow_available
|
||||
|
||||
|
||||
@@ -39,6 +40,56 @@ if TYPE_CHECKING:
|
||||
from .template import Template
|
||||
|
||||
|
||||
def _slice_mm_inputs_for_sample(
|
||||
mm_inputs: dict[str, Any],
|
||||
batch_imglens: list[int],
|
||||
batch_vidlens: list[int],
|
||||
batch_idx: int,
|
||||
images_per_subseq: Optional[list[int]] = None,
|
||||
videos_per_subseq: Optional[list[int]] = None,
|
||||
subseq_idx: Optional[int] = None,
|
||||
) -> dict[str, Any]:
|
||||
r"""Slice mm_inputs for one batch sample, optionally for a single sub-sequence when packing.
|
||||
|
||||
image_grid_thw / video_grid_thw have shape [num_items, 3]. Indices for sample batch_idx
|
||||
are batch_imglens[batch_idx] images and batch_vidlens[batch_idx] videos. When subseq_idx
|
||||
is given, further restrict to that sub-seq's counts via packed_*_counts.
|
||||
has_dummy_image=True means only batch[0] will be concated with fake image and no multimodal data.
|
||||
"""
|
||||
image_start_idx = sum(batch_imglens[:batch_idx])
|
||||
image_end_idx = sum(batch_imglens[: batch_idx + 1])
|
||||
video_start_idx = sum(batch_vidlens[:batch_idx])
|
||||
video_end_idx = sum(batch_vidlens[: batch_idx + 1])
|
||||
|
||||
if subseq_idx is not None and images_per_subseq is not None:
|
||||
image_start_idx += sum(images_per_subseq[:subseq_idx])
|
||||
image_end_idx = image_start_idx + images_per_subseq[subseq_idx]
|
||||
|
||||
if subseq_idx is not None and videos_per_subseq is not None:
|
||||
video_start_idx += sum(videos_per_subseq[:subseq_idx])
|
||||
video_end_idx = video_start_idx + videos_per_subseq[subseq_idx]
|
||||
|
||||
sliced_mm_inputs: dict[str, Any] = {}
|
||||
key_to_slice_meta = {
|
||||
"image_grid_thw": (image_start_idx, image_end_idx, True),
|
||||
"video_grid_thw": (video_start_idx, video_end_idx, True),
|
||||
"second_per_grid_ts": (video_start_idx, video_end_idx, False), # qwen2.5vl
|
||||
"video_second_per_grid": (video_start_idx, video_end_idx, False), # qwen omni
|
||||
}
|
||||
|
||||
for key, (start_idx, end_idx, assign_none_when_empty) in key_to_slice_meta.items():
|
||||
if key not in mm_inputs:
|
||||
continue
|
||||
|
||||
mm_value = mm_inputs[key]
|
||||
if mm_value is not None and end_idx > start_idx:
|
||||
sliced_mm_inputs[key] = mm_value[start_idx:end_idx]
|
||||
elif assign_none_when_empty:
|
||||
sliced_mm_inputs[key] = None
|
||||
|
||||
return sliced_mm_inputs
|
||||
|
||||
|
||||
def prepare_4d_attention_mask(attention_mask_with_indices: "torch.Tensor", dtype: "torch.dtype") -> "torch.Tensor":
|
||||
r"""Expand 2d attention mask to 4d attention mask.
|
||||
|
||||
@@ -106,9 +157,154 @@ class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
|
||||
else:
|
||||
self.get_rope_func = None
|
||||
|
||||
def _compute_rope_position_ids(
|
||||
self, features: dict[str, "torch.Tensor"], mm_inputs: dict[str, Any]
|
||||
) -> None:
|
||||
r"""Compute position_ids and rope_deltas via get_rope_func for VLMs."""
|
||||
rope_index_kwargs = {
|
||||
"input_ids": features["input_ids"],
|
||||
"image_grid_thw": mm_inputs.get("image_grid_thw"),
|
||||
"video_grid_thw": mm_inputs.get("video_grid_thw"),
|
||||
"attention_mask": (features["attention_mask"] >= 1).float(),
|
||||
}
|
||||
if features["attention_mask"].sum() == 0:
|
||||
features["position_ids"] = torch.zeros((3, *features["input_ids"].shape))
|
||||
features["rope_deltas"] = torch.zeros(features["input_ids"].shape[0])
|
||||
return
|
||||
|
||||
if "mm_token_type_ids" in inspect.signature(self.get_rope_func).parameters:
|
||||
image_token_id = getattr(self.model.config, "image_token_id", None)
|
||||
video_token_id = getattr(self.model.config, "video_token_id", None)
|
||||
if image_token_id is not None or video_token_id is not None:
|
||||
mm_token_type_ids = torch.zeros_like(features["input_ids"])
|
||||
if image_token_id is not None:
|
||||
mm_token_type_ids[features["input_ids"] == image_token_id] = 1
|
||||
if video_token_id is not None:
|
||||
mm_token_type_ids[features["input_ids"] == video_token_id] = 2
|
||||
rope_index_kwargs["mm_token_type_ids"] = mm_token_type_ids
|
||||
|
||||
if "second_per_grid_ts" in mm_inputs: # for qwen2vl
|
||||
rope_index_kwargs["second_per_grid_ts"] = mm_inputs.get("second_per_grid_ts")
|
||||
elif "video_second_per_grid" in mm_inputs: # for qwen2.5 omni
|
||||
rope_index_kwargs["second_per_grids"] = mm_inputs.get("video_second_per_grid")
|
||||
|
||||
if getattr(self.model.config, "model_type", None) in ["qwen2_5_omni_thinker", "qwen3_omni_moe_thinker"]:
|
||||
rope_index_kwargs["use_audio_in_video"] = getattr(self.processor, "use_audio_in_video", False)
|
||||
feature_attention_mask = mm_inputs.get("feature_attention_mask", None)
|
||||
if feature_attention_mask is not None: # FIXME: need to get video image lengths
|
||||
audio_feature_lengths = torch.sum(feature_attention_mask, dim=1)
|
||||
rope_index_kwargs["audio_seqlens"] = audio_feature_lengths # prepare for input
|
||||
|
||||
features["position_ids"], rope_deltas = self.get_rope_func(**rope_index_kwargs)
|
||||
features["rope_deltas"] = rope_deltas - (1 - rope_index_kwargs["attention_mask"]).sum(
|
||||
dim=-1
|
||||
).unsqueeze(-1)
|
||||
else: # for qwen vl
|
||||
features["position_ids"], features["rope_deltas"] = self.get_rope_func(**rope_index_kwargs)
|
||||
|
||||
def _compute_rope_position_ids_with_packing(
|
||||
self,
|
||||
features: dict[str, "torch.Tensor"],
|
||||
mm_inputs: dict[str, Any],
|
||||
packing_params_list: list[dict[str, Any] | None],
|
||||
batch_imglens: list[int],
|
||||
batch_vidlens: list[int],
|
||||
batch_audlens: list[int],
|
||||
has_dummy_image: bool,
|
||||
) -> None:
|
||||
r"""Compute position_ids and rope_deltas per sample (or per sub-sequence when packed), then merge and validate."""
|
||||
bsz = features["input_ids"].size(0)
|
||||
seq_len = features["input_ids"].size(1)
|
||||
all_position_ids: list[torch.Tensor] = []
|
||||
all_rope_deltas: list[torch.Tensor] = []
|
||||
|
||||
if has_dummy_image:
|
||||
# for [0, seq_len] = [0, unpadded_length + right_padding_length + fake_input_ids_len + collator_padding_length]
|
||||
# FIXME: maybe right_padding_length is large, with improper max_cutoff_len
|
||||
unpadded_length = int(features["attention_mask"][0].bool().sum().item())
|
||||
right_padding_length = int((packing_params_list[0] or {}).get("right_padding_length") or 0)
|
||||
fake_input_padding_length = max(0, seq_len - unpadded_length - right_padding_length)
|
||||
dummy_image_right_padding_mrope = torch.zeros((3, bsz, fake_input_padding_length))
|
||||
dummy_image_right_padding_attention_mask = torch.zeros((bsz, fake_input_padding_length))
|
||||
assert self.tokenizer.padding_side == "right", "padding_side should be right when fake image is injected"
|
||||
dummy_mm_inputs = copy.deepcopy(mm_inputs)
|
||||
|
||||
for sample_idx in range(bsz):
|
||||
sample_packing = (packing_params_list[sample_idx] or {}) if sample_idx < len(packing_params_list) else {}
|
||||
sequence_boundaries = sample_packing.get("sequence_boundaries")
|
||||
num_sub_seqs = (len(sequence_boundaries) - 1) if sequence_boundaries and len(sequence_boundaries) > 1 else 1
|
||||
image_subseq_ids = sample_packing.get("image_subseq_ids") or []
|
||||
video_subseq_ids = sample_packing.get("video_subseq_ids") or []
|
||||
images_per_subseq = (
|
||||
[image_subseq_ids.count(i) for i in range(num_sub_seqs)] if image_subseq_ids and num_sub_seqs > 1 else None
|
||||
)
|
||||
videos_per_subseq = (
|
||||
[video_subseq_ids.count(i) for i in range(num_sub_seqs)] if video_subseq_ids and num_sub_seqs > 1 else None
|
||||
)
|
||||
if has_dummy_image:
|
||||
mm_inputs = {}
|
||||
|
||||
if num_sub_seqs <= 1:
|
||||
sample_features = {
|
||||
"input_ids": features["input_ids"],
|
||||
"attention_mask": features["attention_mask"][sample_idx : sample_idx + 1],
|
||||
}
|
||||
mm_inputs_for_sample = _slice_mm_inputs_for_sample(
|
||||
mm_inputs, batch_imglens, batch_vidlens, sample_idx=sample_idx
|
||||
)
|
||||
self._compute_rope_position_ids(sample_features, mm_inputs_for_sample)
|
||||
all_position_ids.append(sample_features["position_ids"])
|
||||
all_rope_deltas.append(sample_features["rope_deltas"])
|
||||
else:
|
||||
# when we do packing, don't need rope_deltas when training.
|
||||
sample_position_ids: list[torch.Tensor] = []
|
||||
for subseq_idx in range(num_sub_seqs):
|
||||
subseq_start = sequence_boundaries[subseq_idx]
|
||||
subseq_end = sequence_boundaries[subseq_idx + 1]
|
||||
subseq_features = {
|
||||
"input_ids": features["input_ids"][sample_idx : sample_idx + 1, subseq_start:subseq_end],
|
||||
"attention_mask": features["attention_mask"][sample_idx : sample_idx + 1, subseq_start:subseq_end],
|
||||
}
|
||||
mm_inputs_for_subseq = _slice_mm_inputs_for_sample(
|
||||
mm_inputs,
|
||||
batch_imglens,
|
||||
batch_vidlens,
|
||||
sample_idx,
|
||||
images_per_subseq,
|
||||
videos_per_subseq,
|
||||
subseq_idx
|
||||
)
|
||||
self._compute_rope_position_ids(subseq_features, mm_inputs_for_subseq)
|
||||
sample_position_ids.append(subseq_features["position_ids"])
|
||||
all_position_ids.append(torch.cat(sample_position_ids, dim=-1))
|
||||
|
||||
batch_dim_for_position_ids = 1 if all_position_ids[0].dim() == 3 else 0
|
||||
|
||||
features["position_ids"] = torch.cat(all_position_ids, dim=batch_dim_for_position_ids)
|
||||
if has_dummy_image:
|
||||
mm_inputs = dummy_mm_inputs
|
||||
|
||||
expected_position_ids_shape = (bsz, seq_len) if all_position_ids[0].dim() == 2 else (
|
||||
all_position_ids[0].size(0),
|
||||
bsz,
|
||||
seq_len,
|
||||
)
|
||||
# Check if position_ids shape matches expected shape.
|
||||
# for further usage, we should padding to the right when some padding token on the right.
|
||||
if has_dummy_image:
|
||||
features["position_ids"] = torch.cat([features["position_ids"], dummy_image_right_padding_mrope], dim=-1)
|
||||
features["attention_mask"] = torch.cat([features["attention_mask"], dummy_image_right_padding_attention_mask], dim=-1)
|
||||
|
||||
if features["position_ids"].shape != expected_position_ids_shape:
|
||||
raise ValueError(
|
||||
"Merged position_ids shape mismatch: "
|
||||
f"got {features['position_ids'].shape}, expected {expected_position_ids_shape}."
|
||||
)
|
||||
|
||||
def __call__(self, features: list[dict[str, Any]]) -> dict[str, "torch.Tensor"]:
|
||||
batch_images, batch_videos, batch_audios = [], [], []
|
||||
batch_imglens, batch_vidlens, batch_audlens, batch_input_ids = [], [], [], []
|
||||
packing_params_list: list[dict[str, Any] | None] = []
|
||||
for feature in features:
|
||||
images = feature.pop("images", None) or []
|
||||
videos = feature.pop("videos", None) or []
|
||||
@@ -120,8 +316,10 @@ class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
|
||||
batch_vidlens.append(len(videos))
|
||||
batch_audlens.append(len(audios))
|
||||
batch_input_ids.append(feature["input_ids"])
|
||||
packing_params_list.append(feature.pop("packing_params", None))
|
||||
|
||||
fake_input_ids = []
|
||||
has_dummy_image = False
|
||||
if (
|
||||
self.template.mm_plugin.image_token is not None and sum(batch_imglens) == 0 and sum(batch_vidlens) == 0
|
||||
): # avoid process hanging in zero3/fsdp case
|
||||
@@ -137,6 +335,7 @@ class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
|
||||
fake_input_ids.extend(_fake_input_ids)
|
||||
batch_images = fake_images
|
||||
batch_imglens[0] = 1
|
||||
has_dummy_image = True
|
||||
|
||||
if (
|
||||
self.template.mm_plugin.audio_token is not None and sum(batch_audlens) == 0
|
||||
@@ -183,57 +382,50 @@ class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
|
||||
|
||||
features: dict[str, torch.Tensor] = super().__call__(features)
|
||||
|
||||
bsz, seq_len = features["input_ids"].shape[:2]
|
||||
model_type = getattr(self.model.config, "model_type", None) if self.model is not None else None
|
||||
is_omni = model_type in [
|
||||
"qwen2_5_omni_thinker",
|
||||
"qwen3_omni_moe_thinker",
|
||||
]
|
||||
|
||||
if self.get_rope_func is not None:
|
||||
rope_index_kwargs = {
|
||||
"input_ids": features["input_ids"],
|
||||
"image_grid_thw": mm_inputs.get("image_grid_thw"),
|
||||
"video_grid_thw": mm_inputs.get("video_grid_thw"),
|
||||
"attention_mask": (features["attention_mask"] >= 1).float(),
|
||||
}
|
||||
if "mm_token_type_ids" in inspect.signature(self.get_rope_func).parameters:
|
||||
image_token_id = getattr(self.model.config, "image_token_id", None)
|
||||
video_token_id = getattr(self.model.config, "video_token_id", None)
|
||||
if image_token_id is not None or video_token_id is not None:
|
||||
mm_token_type_ids = torch.zeros_like(features["input_ids"])
|
||||
if image_token_id is not None:
|
||||
mm_token_type_ids[features["input_ids"] == image_token_id] = 1
|
||||
if video_token_id is not None:
|
||||
mm_token_type_ids[features["input_ids"] == video_token_id] = 2
|
||||
rope_index_kwargs["mm_token_type_ids"] = mm_token_type_ids
|
||||
if "second_per_grid_ts" in mm_inputs: # for qwen2vl
|
||||
rope_index_kwargs["second_per_grid_ts"] = mm_inputs.get("second_per_grid_ts")
|
||||
elif "video_second_per_grid" in mm_inputs: # for qwen2.5 omni
|
||||
rope_index_kwargs["second_per_grids"] = mm_inputs.get("video_second_per_grid")
|
||||
# for mmrope situation, we should calculate position_ids and rope_deltas per sample.
|
||||
# When neat_packing is on, each sample has packing_params; None means no packing for that sample.
|
||||
boundaries_list = [
|
||||
p.get("sequence_boundaries") if p is not None else None for p in packing_params_list
|
||||
]
|
||||
has_packing = any(b is not None and len(b) > 2 for b in boundaries_list)
|
||||
if has_dummy_image and has_packing:
|
||||
# FIXME: too tricky, need to be refactored
|
||||
features["has_dummy_image"] = True
|
||||
|
||||
if getattr(self.model.config, "model_type", None) in ["qwen2_5_omni_thinker", "qwen3_omni_moe_thinker"]:
|
||||
rope_index_kwargs["use_audio_in_video"] = getattr(self.processor, "use_audio_in_video", False)
|
||||
feature_attention_mask = mm_inputs.get("feature_attention_mask", None)
|
||||
if feature_attention_mask is not None: # FIXME: need to get video image lengths
|
||||
audio_feature_lengths = torch.sum(feature_attention_mask, dim=1)
|
||||
rope_index_kwargs["audio_seqlens"] = audio_feature_lengths # prepare for input
|
||||
# When fake image/audio was injected, sequence_boundaries no longer match the tensor; use non-packing path.
|
||||
if not has_packing:
|
||||
self._compute_rope_position_ids(features, mm_inputs)
|
||||
else:
|
||||
if is_omni:
|
||||
raise RuntimeError("Omni models are not supported for packed sequences for now.")
|
||||
|
||||
features["position_ids"], rope_deltas = self.get_rope_func(**rope_index_kwargs)
|
||||
features["rope_deltas"] = rope_deltas - (1 - rope_index_kwargs["attention_mask"]).sum(
|
||||
dim=-1
|
||||
).unsqueeze(-1)
|
||||
else: # for qwen vl
|
||||
features["position_ids"], features["rope_deltas"] = self.get_rope_func(**rope_index_kwargs)
|
||||
self._compute_rope_position_ids_with_packing(
|
||||
features,
|
||||
mm_inputs,
|
||||
packing_params_list,
|
||||
batch_imglens,
|
||||
batch_vidlens,
|
||||
batch_audlens,
|
||||
has_dummy_image,
|
||||
)
|
||||
|
||||
# For transformers compatibility, after https://github.com/huggingface/transformers/issues/39400
|
||||
if features["position_ids"].dim() == 3:
|
||||
features["position_ids"] = torch.cat(
|
||||
[features["position_ids"][0].unsqueeze(0), features["position_ids"]], dim=0
|
||||
)
|
||||
|
||||
if (
|
||||
self.model is not None
|
||||
and getattr(self.model.config, "model_type", None)
|
||||
in [
|
||||
"glm4v",
|
||||
"glm_ocr",
|
||||
"Keye",
|
||||
"qwen2_vl",
|
||||
"qwen2_5_vl",
|
||||
"qwen2_5_omni_thinker",
|
||||
"qwen3_omni_moe_thinker",
|
||||
"qwen3_5",
|
||||
"qwen3_vl",
|
||||
"qwen3_vl_moe",
|
||||
]
|
||||
and getattr(self.model.config, "model_type", None) in MROPE_MODELS
|
||||
and ("position_ids" not in features or features["position_ids"].dim() != 3)
|
||||
):
|
||||
raise ValueError(f"{self.model.config.model_type} requires 3D position ids for mrope.")
|
||||
@@ -261,12 +453,51 @@ class SFTDataCollatorWith4DAttentionMask(MultiModalDataCollatorForSeq2Seq):
|
||||
block_diag_attn: bool = False
|
||||
attn_implementation: Literal["eager", "sdpa", "flash_attention_2"] = "eager"
|
||||
compute_dtype: "torch.dtype" = torch.float32
|
||||
neat_packing: bool = False
|
||||
|
||||
def __post_init__(self):
|
||||
super().__post_init__()
|
||||
if self.neat_packing and self.attn_implementation == "flash_attention_2":
|
||||
if self.model is not None and getattr(self.model.config, "model_type", None) in ["qwen3_5", "qwen3_5_moe", "gpt_oss"]:
|
||||
raise ValueError("Neat packing is not supported for qwen3_5, qwen3_5_moe, gpt_oss models for now.")
|
||||
|
||||
@staticmethod
|
||||
def _unpad_packed_features(features: dict[str, Any]) -> None:
|
||||
r"""Trim padded positions for packed FA2 batches."""
|
||||
attention_mask = features.get("attention_mask")
|
||||
if not torch.is_tensor(attention_mask) or attention_mask.dim() != 2 or attention_mask.size(0) != 1:
|
||||
return
|
||||
|
||||
seq_len = attention_mask.size(1)
|
||||
non_padding_indices = torch.nonzero(attention_mask[0] != 0, as_tuple=False).flatten()
|
||||
if non_padding_indices.numel() == seq_len:
|
||||
return
|
||||
|
||||
keys_on_seq_dim_1 = {"input_ids", "labels", "attention_mask", "token_type_ids"}
|
||||
for key, value in list(features.items()):
|
||||
if not torch.is_tensor(value):
|
||||
continue
|
||||
|
||||
if key == "position_ids" and value.size(-1) == seq_len:
|
||||
features[key] = value.index_select(-1, non_padding_indices)
|
||||
elif key == "cross_attention_mask" and value.dim() >= 2 and value.size(0) == 1 and value.size(1) == seq_len:
|
||||
features[key] = value.index_select(1, non_padding_indices)
|
||||
elif key in keys_on_seq_dim_1 and value.dim() == 2 and value.size(0) == 1 and value.size(1) == seq_len:
|
||||
features[key] = value.index_select(1, non_padding_indices)
|
||||
|
||||
def __call__(self, features: list[dict[str, Any]]) -> dict[str, "torch.Tensor"]:
|
||||
features = super().__call__(features)
|
||||
has_dummy_image = features.pop("has_dummy_image", False)
|
||||
if self.block_diag_attn and self.attn_implementation != "flash_attention_2":
|
||||
features["attention_mask"] = prepare_4d_attention_mask(features["attention_mask"], self.compute_dtype)
|
||||
|
||||
if self.neat_packing and self.attn_implementation == "flash_attention_2": # FIXME compatibility fa3/fa4
|
||||
assert features["input_ids"].shape[0] == 1, "bsz should be 1 for neat packing"
|
||||
if not has_dummy_image:
|
||||
self._unpad_packed_features(features)
|
||||
|
||||
features["attention_mask"] = None # let transformers handle causal packed mask.
|
||||
|
||||
for key, value in features.items(): # cast data dtype for paligemma
|
||||
if torch.is_tensor(value) and torch.is_floating_point(value):
|
||||
features[key] = value.to(self.compute_dtype)
|
||||
|
||||
@@ -196,7 +196,7 @@ def read_cloud_json(cloud_path: str) -> list[Any]:
|
||||
|
||||
# filter out non-JSON files
|
||||
files = [x["Key"] for x in fs.listdir(cloud_path)] if fs.isdir(cloud_path) else [cloud_path]
|
||||
files = filter(lambda file: file.endswith(".json") or file.endswith(".jsonl"), files)
|
||||
files = list(filter(lambda file: file.endswith(".json") or file.endswith(".jsonl"), files))
|
||||
if not files:
|
||||
raise ValueError(f"No JSON/JSONL files found in the specified path: {cloud_path}.")
|
||||
|
||||
|
||||
@@ -27,11 +27,12 @@ from typing import TYPE_CHECKING, BinaryIO, Literal, NotRequired, Optional, Type
|
||||
import numpy as np
|
||||
import torch
|
||||
import torchaudio
|
||||
from transformers.image_utils import get_image_size, is_valid_image, to_numpy_array
|
||||
from transformers.image_utils import get_image_size, is_valid_image, make_flat_list_of_images, to_numpy_array
|
||||
from transformers.models.mllama.processing_mllama import (
|
||||
convert_sparse_cross_attention_mask_to_dense,
|
||||
get_cross_attention_token_mask,
|
||||
)
|
||||
from transformers.video_utils import make_batched_videos
|
||||
from typing_extensions import override
|
||||
|
||||
from ..extras.constants import AUDIO_PLACEHOLDER, IGNORE_INDEX, IMAGE_PLACEHOLDER, VIDEO_PLACEHOLDER
|
||||
@@ -47,13 +48,6 @@ if is_pyav_available():
|
||||
import av
|
||||
|
||||
|
||||
if is_transformers_version_greater_than("4.52.0"):
|
||||
from transformers.image_utils import make_flat_list_of_images
|
||||
from transformers.video_utils import make_batched_videos
|
||||
else:
|
||||
from transformers.image_utils import make_batched_videos, make_flat_list_of_images
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from av.stream import Stream
|
||||
from numpy.typing import NDArray
|
||||
@@ -161,7 +155,9 @@ class MMPluginMixin:
|
||||
video_processor: BaseImageProcessor = getattr(
|
||||
processor, "video_processor", getattr(processor, "image_processor", None)
|
||||
)
|
||||
feature_extractor: SequenceFeatureExtractor = getattr(processor, "feature_extractor", None)
|
||||
feature_extractor: SequenceFeatureExtractor = getattr(processor, "feature_extractor", None) or getattr(
|
||||
processor, "audio_processor", None
|
||||
)
|
||||
if len(images) != 0 and self.image_token is None:
|
||||
raise ValueError(
|
||||
"This model does not support image input. Please check whether the correct `template` is used."
|
||||
@@ -390,7 +386,9 @@ class MMPluginMixin:
|
||||
mm_inputs.update(video_processor(videos, return_tensors="pt"))
|
||||
|
||||
if len(audios) != 0:
|
||||
feature_extractor: SequenceFeatureExtractor = getattr(processor, "feature_extractor", None)
|
||||
feature_extractor: SequenceFeatureExtractor = getattr(processor, "feature_extractor", None) or getattr(
|
||||
processor, "audio_processor", None
|
||||
)
|
||||
audios = self._regularize_audios(
|
||||
audios,
|
||||
sampling_rate=getattr(processor, "audio_sampling_rate", 16000),
|
||||
@@ -1054,7 +1052,9 @@ class MiniCPMVPlugin(BasePlugin):
|
||||
chunk_input=True,
|
||||
sampling_rate=getattr(processor, "audio_sampling_rate", 16000),
|
||||
)
|
||||
audio_feature_lens = [torch.tensor(audio_feature_len) for audio_feature_len in audio_feature_lens]
|
||||
audio_feature_lens = [
|
||||
x.clone().detach() if isinstance(x, torch.Tensor) else torch.tensor(x) for x in audio_feature_lens
|
||||
]
|
||||
mm_inputs.update({"audio_features": audio_features, "audio_feature_lens": audio_feature_lens})
|
||||
if kwargs.get("ret_phs", False):
|
||||
mm_inputs.update({"audio_phs": audio_phs})
|
||||
@@ -1094,7 +1094,7 @@ class MiniCPMVPlugin(BasePlugin):
|
||||
num_image_tokens += 1
|
||||
|
||||
while VIDEO_PLACEHOLDER in content:
|
||||
video_seqlen = len(mm_inputs["pixel_values"][num_video_tokens]) if self.expand_mm_tokens else 1
|
||||
video_seqlen = len(mm_inputs["image_sizes"][num_video_tokens]) if self.expand_mm_tokens else 1
|
||||
content = content.replace(VIDEO_PLACEHOLDER, "{{image}}" * video_seqlen, 1)
|
||||
num_video_tokens += 1
|
||||
|
||||
@@ -1876,7 +1876,9 @@ class Qwen2OmniPlugin(Qwen2VLPlugin):
|
||||
) -> dict[str, "torch.Tensor"]:
|
||||
image_processor: BaseImageProcessor = getattr(processor, "image_processor", None)
|
||||
video_processor: BaseVideoProcessor = getattr(processor, "video_processor", None)
|
||||
feature_extractor: SequenceFeatureExtractor = getattr(processor, "feature_extractor", None)
|
||||
feature_extractor: SequenceFeatureExtractor = getattr(processor, "feature_extractor", None) or getattr(
|
||||
processor, "audio_processor", None
|
||||
)
|
||||
mm_inputs = {}
|
||||
if len(images) != 0:
|
||||
images = self._regularize_images(
|
||||
@@ -1981,6 +1983,7 @@ class Qwen2OmniPlugin(Qwen2VLPlugin):
|
||||
f"Each {VIDEO_PLACEHOLDER} must be followed by an {AUDIO_PLACEHOLDER} when using audio in video."
|
||||
)
|
||||
|
||||
position_id_per_seconds: int = getattr(processor, "position_id_per_seconds", 25)
|
||||
audio_t_index = torch.arange(audio_lengths[num_audio_tokens])
|
||||
video_t_index = (
|
||||
torch.arange(video_grid_thw[num_video_tokens][0])
|
||||
@@ -1992,9 +1995,9 @@ class Qwen2OmniPlugin(Qwen2VLPlugin):
|
||||
)
|
||||
.flatten()
|
||||
* mm_inputs["video_second_per_grid"][num_video_tokens]
|
||||
* 25 # FIXME hardcode of position_id_per_seconds=25
|
||||
* position_id_per_seconds
|
||||
).long()
|
||||
t_ntoken_per_chunk = 50 # FIXME hardcode: [25 * 2]
|
||||
t_ntoken_per_chunk = position_id_per_seconds * 2
|
||||
video_chunk_indices = processor.get_chunked_index(video_t_index, t_ntoken_per_chunk)
|
||||
audio_chunk_indices = processor.get_chunked_index(audio_t_index, t_ntoken_per_chunk)
|
||||
placeholder_string = ""
|
||||
|
||||
@@ -13,7 +13,7 @@
|
||||
# limitations under the License.
|
||||
|
||||
from collections import defaultdict
|
||||
from dataclasses import dataclass
|
||||
from dataclasses import asdict, dataclass
|
||||
from typing import TYPE_CHECKING, Any, Optional
|
||||
|
||||
from ...extras import logging
|
||||
@@ -27,6 +27,23 @@ if TYPE_CHECKING:
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
MAX_SU_SEQ_IDX = 2**32 # maximum sub-sequence index
|
||||
|
||||
@dataclass
|
||||
class PackingParams:
|
||||
r"""Metadata for a packed sequence: sub-sequence boundaries and multimodal data indices.
|
||||
|
||||
- sequence_boundaries: cumulative token positions, e.g. [0, 100, 250, 512] means 3 sub-seqs
|
||||
with token ranges [0,100), [100,250), [250,512). Length = num_sub_seqs + 1.
|
||||
- image_subseq_ids / video_subseq_ids / audio_subseq_ids: for each mm item, the 0-based
|
||||
sub-sequence index it belongs to. Length = total number of that mm type in the packed sample.
|
||||
"""
|
||||
|
||||
sequence_boundaries: list[int]
|
||||
image_subseq_ids: list[int]
|
||||
video_subseq_ids: list[int]
|
||||
audio_subseq_ids: list[int]
|
||||
right_padding_length: int
|
||||
|
||||
@dataclass
|
||||
class SupervisedDatasetProcessor(DatasetProcessor):
|
||||
@@ -162,10 +179,17 @@ class PackedSupervisedDatasetProcessor(SupervisedDatasetProcessor):
|
||||
valid_num += 1
|
||||
|
||||
model_inputs = defaultdict(list)
|
||||
requires_packing_params = self.data_args.neat_packing
|
||||
knapsacks = greedy_knapsack(lengths, self.data_args.cutoff_len)
|
||||
for knapsack in knapsacks:
|
||||
packed_input_ids, packed_attention_masks, packed_position_ids, packed_labels = [], [], [], []
|
||||
packed_images, packed_videos, packed_audios = [], [], []
|
||||
if requires_packing_params:
|
||||
sequence_boundaries = [0]
|
||||
image_subseq_ids: list[int] = []
|
||||
video_subseq_ids: list[int] = []
|
||||
audio_subseq_ids: list[int] = []
|
||||
|
||||
for i, length in enumerate(knapsack):
|
||||
index = length2indexes[length].pop()
|
||||
packed_input_ids += batch_input_ids[index]
|
||||
@@ -174,6 +198,15 @@ class PackedSupervisedDatasetProcessor(SupervisedDatasetProcessor):
|
||||
packed_images += batch_images[index]
|
||||
packed_videos += batch_videos[index]
|
||||
packed_audios += batch_audios[index]
|
||||
if requires_packing_params:
|
||||
n_img = len(batch_images[index])
|
||||
n_vid = len(batch_videos[index])
|
||||
n_aud = len(batch_audios[index])
|
||||
sequence_boundaries.append(sequence_boundaries[-1] + len(batch_input_ids[index]))
|
||||
image_subseq_ids.extend([i] * n_img)
|
||||
video_subseq_ids.extend([i] * n_vid)
|
||||
audio_subseq_ids.extend([i] * n_aud)
|
||||
|
||||
if self.data_args.neat_packing:
|
||||
packed_attention_masks += [i + 1] * len(batch_input_ids[index]) # start from 1
|
||||
else:
|
||||
@@ -189,10 +222,23 @@ class PackedSupervisedDatasetProcessor(SupervisedDatasetProcessor):
|
||||
else:
|
||||
packed_attention_masks += [1] * pad_length # more efficient flash_attn
|
||||
|
||||
if requires_packing_params:
|
||||
sequence_boundaries.append(sequence_boundaries[-1] + pad_length)
|
||||
|
||||
if len(packed_input_ids) != self.data_args.cutoff_len + 1:
|
||||
raise ValueError("The length of packed example should be identical to the cutoff length.")
|
||||
|
||||
model_inputs["input_ids"].append(packed_input_ids)
|
||||
if requires_packing_params:
|
||||
packing_params = PackingParams(
|
||||
sequence_boundaries=sequence_boundaries,
|
||||
image_subseq_ids=image_subseq_ids or [MAX_SU_SEQ_IDX], # avoid dataset concat error
|
||||
video_subseq_ids=video_subseq_ids or [MAX_SU_SEQ_IDX],
|
||||
audio_subseq_ids=audio_subseq_ids or [MAX_SU_SEQ_IDX],
|
||||
right_padding_length=pad_length,
|
||||
)
|
||||
model_inputs["packing_params"].append(asdict(packing_params))
|
||||
|
||||
model_inputs["attention_mask"].append(packed_attention_masks)
|
||||
model_inputs["position_ids"].append(packed_position_ids)
|
||||
model_inputs["labels"].append(packed_labels)
|
||||
|
||||
@@ -1113,7 +1113,7 @@ register_template(
|
||||
register_template(
|
||||
name="gpt_oss",
|
||||
format_user=StringFormatter(slots=["<|start|>user<|message|>{{content}}<|end|><|start|>assistant"]),
|
||||
format_assistant=StringFormatter(slots=["{{content}}<|end|>"]),
|
||||
format_assistant=StringFormatter(slots=["{{content}}"]),
|
||||
format_system=StringFormatter(slots=["<|start|>system<|message|>{{content}}<|end|>"]),
|
||||
default_system="You are ChatGPT, a large language model trained by OpenAI.",
|
||||
thought_words=("<|channel|>analysis<|message|>", "<|end|><|start|>assistant<|channel|>final<|message|>"),
|
||||
|
||||
@@ -361,6 +361,8 @@ class MiniMaxM2ToolUtils(ToolUtils):
|
||||
prompt += "\n</invoke>"
|
||||
function_texts.append(prompt)
|
||||
|
||||
return "\n".join(function_texts)
|
||||
|
||||
@override
|
||||
@staticmethod
|
||||
def tool_extractor(content: str) -> Union[str, list["FunctionCall"]]:
|
||||
|
||||
@@ -69,12 +69,28 @@ MCA_SUPPORTED_MODELS = {
|
||||
"qwen3",
|
||||
"qwen3_moe",
|
||||
"qwen3_next",
|
||||
"qwen3_5",
|
||||
"qwen3_5_moe",
|
||||
}
|
||||
|
||||
METHODS = ["full", "freeze", "lora", "oft"]
|
||||
|
||||
MOD_SUPPORTED_MODELS = {"bloom", "falcon", "gemma", "llama", "mistral", "mixtral", "phi", "starcoder2"}
|
||||
|
||||
MROPE_MODELS = {
|
||||
"glm4v",
|
||||
"glm_ocr",
|
||||
"Keye",
|
||||
"qwen2_vl",
|
||||
"qwen2_5_vl",
|
||||
"qwen2_5_omni_thinker",
|
||||
"qwen3_omni_moe_thinker",
|
||||
"qwen3_vl",
|
||||
"qwen3_vl_moe",
|
||||
"qwen3_5",
|
||||
"qwen3_5_moe",
|
||||
}
|
||||
|
||||
MULTIMODAL_SUPPORTED_MODELS = set()
|
||||
|
||||
PEFT_METHODS = {"lora", "oft"}
|
||||
@@ -2812,24 +2828,61 @@ register_model_group(
|
||||
|
||||
register_model_group(
|
||||
models={
|
||||
"Qwen3.5-27B": {
|
||||
"Qwen3.5-0.8B-Base": {
|
||||
DownloadSource.DEFAULT: "Qwen/Qwen3.5-0.8B-Base",
|
||||
DownloadSource.MODELSCOPE: "Qwen/Qwen3.5-0.8B-Base",
|
||||
},
|
||||
"Qwen3.5-2B-Base": {
|
||||
DownloadSource.DEFAULT: "Qwen/Qwen3.5-2B-Base",
|
||||
DownloadSource.MODELSCOPE: "Qwen/Qwen3.5-2B-Base",
|
||||
},
|
||||
"Qwen3.5-4B-Base": {
|
||||
DownloadSource.DEFAULT: "Qwen/Qwen3.5-4B-Base",
|
||||
DownloadSource.MODELSCOPE: "Qwen/Qwen3.5-4B-Base",
|
||||
},
|
||||
"Qwen3.5-9B-Base": {
|
||||
DownloadSource.DEFAULT: "Qwen/Qwen3.5-9B-Base",
|
||||
DownloadSource.MODELSCOPE: "Qwen/Qwen3.5-9B-Base",
|
||||
},
|
||||
"Qwen3.5-35B-A3B-Base": {
|
||||
DownloadSource.DEFAULT: "Qwen/Qwen3.5-35B-A3B-Base",
|
||||
DownloadSource.MODELSCOPE: "Qwen/Qwen3.5-35B-A3B-Base",
|
||||
},
|
||||
"Qwen3.5-0.8B-Thinking": {
|
||||
DownloadSource.DEFAULT: "Qwen/Qwen3.5-0.8B",
|
||||
DownloadSource.MODELSCOPE: "Qwen/Qwen3.5-0.8B",
|
||||
},
|
||||
"Qwen3.5-2B-Thinking": {
|
||||
DownloadSource.DEFAULT: "Qwen/Qwen3.5-2B",
|
||||
DownloadSource.MODELSCOPE: "Qwen/Qwen3.5-2B",
|
||||
},
|
||||
"Qwen3.5-4B-Thinking": {
|
||||
DownloadSource.DEFAULT: "Qwen/Qwen3.5-4B",
|
||||
DownloadSource.MODELSCOPE: "Qwen/Qwen3.5-4B",
|
||||
},
|
||||
"Qwen3.5-9B-Thinking": {
|
||||
DownloadSource.DEFAULT: "Qwen/Qwen3.5-9B",
|
||||
DownloadSource.MODELSCOPE: "Qwen/Qwen3.5-9B",
|
||||
},
|
||||
"Qwen3.5-27B-Thinking": {
|
||||
DownloadSource.DEFAULT: "Qwen/Qwen3.5-27B",
|
||||
DownloadSource.MODELSCOPE: "Qwen/Qwen3.5-27B",
|
||||
},
|
||||
"Qwen3.5-35B-A3B": {
|
||||
"Qwen3.5-35B-A3B-Thinking": {
|
||||
DownloadSource.DEFAULT: "Qwen/Qwen3.5-35B-A3B",
|
||||
DownloadSource.MODELSCOPE: "Qwen/Qwen3.5-35B-A3B",
|
||||
},
|
||||
"Qwen3.5-122B-A10B": {
|
||||
"Qwen3.5-122B-A10B-Thinking": {
|
||||
DownloadSource.DEFAULT: "Qwen/Qwen3.5-122B-A10B",
|
||||
DownloadSource.MODELSCOPE: "Qwen/Qwen3.5-122B-A10B",
|
||||
},
|
||||
"Qwen3.5-397B-A17B": {
|
||||
"Qwen3.5-397B-A17B-Thinking": {
|
||||
DownloadSource.DEFAULT: "Qwen/Qwen3.5-397B-A17B",
|
||||
DownloadSource.MODELSCOPE: "Qwen/Qwen3.5-397B-A17B",
|
||||
},
|
||||
},
|
||||
template="qwen3_5",
|
||||
multimodal=True,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -94,7 +94,7 @@ def check_version(requirement: str, mandatory: bool = False) -> None:
|
||||
|
||||
def check_dependencies() -> None:
|
||||
r"""Check the version of the required packages."""
|
||||
check_version("transformers>=4.51.0,<=5.2.0")
|
||||
check_version("transformers>=4.55.0,<=5.2.0")
|
||||
check_version("datasets>=2.16.0,<=4.0.0")
|
||||
check_version("accelerate>=1.3.0,<=1.11.0")
|
||||
check_version("peft>=0.18.0,<=0.18.1")
|
||||
|
||||
@@ -33,7 +33,7 @@ from transformers.utils import is_torch_bf16_gpu_available, is_torch_npu_availab
|
||||
from ..extras import logging
|
||||
from ..extras.constants import CHECKPOINT_NAMES, EngineName
|
||||
from ..extras.misc import check_dependencies, check_version, get_current_device, is_env_enabled
|
||||
from ..extras.packages import is_mcore_adapter_available, is_transformers_version_greater_than
|
||||
from ..extras.packages import is_mcore_adapter_available
|
||||
from .data_args import DataArguments
|
||||
from .evaluation_args import EvaluationArguments
|
||||
from .finetuning_args import FinetuningArguments
|
||||
@@ -100,6 +100,52 @@ def _parse_args(
|
||||
return tuple(parsed_args)
|
||||
|
||||
|
||||
def _verify_trackio_args(training_args: "TrainingArguments") -> None:
|
||||
"""Validates Trackio-specific arguments.
|
||||
|
||||
Args:
|
||||
training_args: TrainingArguments instance (not a dictionary)
|
||||
"""
|
||||
report_to = training_args.report_to
|
||||
if not report_to:
|
||||
return
|
||||
|
||||
if isinstance(report_to, str):
|
||||
report_to = [report_to]
|
||||
|
||||
if "trackio" not in report_to:
|
||||
return
|
||||
|
||||
# --- Enforce project (required by Trackio) ---
|
||||
if not training_args.project:
|
||||
raise ValueError("`--project` must be specified when using Trackio.")
|
||||
|
||||
# --- Validate trackio_space_id format ---
|
||||
space_id = training_args.trackio_space_id
|
||||
if space_id:
|
||||
if space_id != "trackio" and "/" not in space_id:
|
||||
logger.warning(
|
||||
f"trackio_space_id '{space_id}' should typically be in format "
|
||||
"'org/space' for Hugging Face Spaces deployment."
|
||||
)
|
||||
|
||||
# --- Inform about default project usage ---
|
||||
if training_args.project == "huggingface":
|
||||
logger.info(
|
||||
"Using default project name 'huggingface'. "
|
||||
"Consider setting a custom project name with --project "
|
||||
"for better organization."
|
||||
)
|
||||
|
||||
# --- Validate hub repo privacy flag ---
|
||||
if training_args.hub_private_repo:
|
||||
logger.info("Repository will be created as private on Hugging Face Hub.")
|
||||
|
||||
# --- Recommend run_name for experiment clarity ---
|
||||
if not training_args.run_name:
|
||||
logger.warning("Consider setting --run_name for better experiment tracking clarity.")
|
||||
|
||||
|
||||
def _set_transformers_logging() -> None:
|
||||
if os.getenv("LLAMAFACTORY_VERBOSITY", "INFO") in ["DEBUG", "INFO"]:
|
||||
transformers.utils.logging.set_verbosity_info()
|
||||
@@ -278,8 +324,10 @@ def get_train_args(args: dict[str, Any] | list[str] | None = None) -> _TRAIN_CLS
|
||||
if finetuning_args.reward_model_type == "lora" and model_args.use_unsloth:
|
||||
raise ValueError("Unsloth does not support lora reward model.")
|
||||
|
||||
if training_args.report_to and training_args.report_to[0] not in ["wandb", "tensorboard"]:
|
||||
raise ValueError("PPO only accepts wandb or tensorboard logger.")
|
||||
if training_args.report_to and any(
|
||||
logger not in ("wandb", "tensorboard", "trackio", "none") for logger in training_args.report_to
|
||||
):
|
||||
raise ValueError("PPO only accepts wandb, tensorboard, or trackio logger.")
|
||||
|
||||
if not model_args.use_kt and training_args.parallel_mode == ParallelMode.NOT_DISTRIBUTED:
|
||||
raise ValueError("Please launch distributed training with `llamafactory-cli` or `torchrun`.")
|
||||
@@ -346,12 +394,10 @@ def get_train_args(args: dict[str, Any] | list[str] | None = None) -> _TRAIN_CLS
|
||||
if model_args.use_kt and is_deepspeed_zero3_enabled():
|
||||
raise ValueError("KTransformers is incompatible with DeepSpeed ZeRO-3.")
|
||||
|
||||
if data_args.neat_packing and is_transformers_version_greater_than("4.53.0"):
|
||||
raise ValueError("Neat packing is incompatible with transformers>=4.53.0.")
|
||||
|
||||
_set_env_vars()
|
||||
_verify_model_args(model_args, data_args, finetuning_args)
|
||||
_check_extra_dependencies(model_args, finetuning_args, training_args)
|
||||
_verify_trackio_args(training_args)
|
||||
|
||||
if not finetuning_args.use_mca and training_args.fp8_enable_fsdp_float8_all_gather and not training_args.fp8:
|
||||
logger.warning_rank0("fp8_enable_fsdp_float8_all_gather requires fp8=True. Setting fp8=True.")
|
||||
@@ -421,7 +467,7 @@ def get_train_args(args: dict[str, Any] | list[str] | None = None) -> _TRAIN_CLS
|
||||
training_args.resume_from_checkpoint is None
|
||||
and training_args.do_train
|
||||
and os.path.isdir(training_args.output_dir)
|
||||
and not training_args.overwrite_output_dir
|
||||
and not getattr(training_args, "overwrite_output_dir", False) # for mca training args and transformers >= 5.0
|
||||
and can_resume_from_checkpoint
|
||||
):
|
||||
last_checkpoint = get_last_checkpoint(training_args.output_dir)
|
||||
|
||||
@@ -79,6 +79,8 @@ def apply_liger_kernel(
|
||||
from liger_kernel.transformers import apply_liger_kernel_to_qwen3_moe as apply_liger_kernel
|
||||
elif model_type == "qwen3_next":
|
||||
from liger_kernel.transformers import apply_liger_kernel_to_qwen3_next as apply_liger_kernel
|
||||
elif model_type == "qwen3_5":
|
||||
from liger_kernel.transformers import apply_liger_kernel_to_qwen3_5 as apply_liger_kernel
|
||||
elif model_type == "gpt_oss":
|
||||
try:
|
||||
from liger_kernel.transformers import apply_liger_kernel_to_gpt_oss as apply_liger_kernel
|
||||
|
||||
@@ -37,7 +37,6 @@
|
||||
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
# SOFTWARE.
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
@@ -45,10 +44,6 @@ import torch.nn.functional as F
|
||||
from ...extras import logging
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ...hparams import ModelArguments
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
@@ -105,13 +100,3 @@ def get_unpad_data(attention_mask: "torch.Tensor") -> tuple["torch.Tensor", "tor
|
||||
max_seqlen_in_batch = seqlens_in_batch.max().item()
|
||||
cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
|
||||
return indices, cu_seqlens, max_seqlen_in_batch
|
||||
|
||||
|
||||
def configure_packing(model_args: "ModelArguments", is_trainable: bool) -> None:
|
||||
if not is_trainable or not model_args.block_diag_attn:
|
||||
return
|
||||
|
||||
import transformers.modeling_flash_attention_utils
|
||||
|
||||
transformers.modeling_flash_attention_utils._get_unpad_data = get_unpad_data
|
||||
logger.info_rank0("Using block diagonal attention for sequence packing without cross-attention.")
|
||||
|
||||
@@ -24,7 +24,6 @@ import transformers.models
|
||||
from transformers.activations import ACT2FN
|
||||
|
||||
from ...extras import logging
|
||||
from ...extras.packages import is_transformers_version_greater_than
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -344,9 +343,7 @@ _register_composite_model(
|
||||
model_type="qwen2_vl",
|
||||
projector_key="visual.merger",
|
||||
vision_model_keys=["visual.patch_embed", "visual.blocks"],
|
||||
language_model_keys=["language_model", "lm_head"]
|
||||
if is_transformers_version_greater_than("4.52.0")
|
||||
else ["model", "lm_head"],
|
||||
language_model_keys=["language_model", "lm_head"],
|
||||
lora_conflict_keys=["patch_embed"],
|
||||
)
|
||||
|
||||
@@ -355,9 +352,7 @@ _register_composite_model(
|
||||
model_type="qwen2_5_vl",
|
||||
projector_key="visual.merger",
|
||||
vision_model_keys=["visual.patch_embed", "visual.blocks"],
|
||||
language_model_keys=["language_model", "lm_head"]
|
||||
if is_transformers_version_greater_than("4.52.0")
|
||||
else ["model", "lm_head"],
|
||||
language_model_keys=["language_model", "lm_head"],
|
||||
lora_conflict_keys=["patch_embed"],
|
||||
)
|
||||
|
||||
@@ -390,7 +385,25 @@ _register_composite_model(
|
||||
"visual.deepstack_merger_list",
|
||||
"audio_tower",
|
||||
],
|
||||
language_model_keys=["model", "lm_head"],
|
||||
language_model_keys=["language_model", "lm_head"],
|
||||
lora_conflict_keys=["patch_embed"],
|
||||
)
|
||||
|
||||
|
||||
_register_composite_model(
|
||||
model_type="qwen3_5",
|
||||
projector_key="model.visual.merger",
|
||||
vision_model_keys=["visual.pos_embed", "visual.patch_embed", "visual.blocks"],
|
||||
language_model_keys=["language_model", "lm_head"],
|
||||
lora_conflict_keys=["patch_embed"],
|
||||
)
|
||||
|
||||
|
||||
_register_composite_model(
|
||||
model_type="qwen3_5_moe",
|
||||
projector_key="model.visual.merger",
|
||||
vision_model_keys=["visual.pos_embed", "visual.patch_embed", "visual.blocks"],
|
||||
language_model_keys=["language_model", "lm_head"],
|
||||
lora_conflict_keys=["patch_embed"],
|
||||
)
|
||||
|
||||
|
||||
@@ -30,7 +30,6 @@ from .model_utils.embedding import resize_embedding_layer
|
||||
from .model_utils.kv_cache import configure_kv_cache
|
||||
from .model_utils.longlora import configure_longlora
|
||||
from .model_utils.moe import add_z3_leaf_module, configure_moe
|
||||
from .model_utils.packing import configure_packing
|
||||
from .model_utils.quantization import configure_quantization
|
||||
from .model_utils.rope import configure_rope
|
||||
from .model_utils.valuehead import prepare_valuehead_model
|
||||
@@ -142,7 +141,6 @@ def patch_config(
|
||||
configure_quantization(config, tokenizer, model_args, is_trainable, init_kwargs)
|
||||
configure_moe(config, model_args, is_trainable)
|
||||
configure_visual_model(config)
|
||||
configure_packing(model_args, is_trainable)
|
||||
configure_kv_cache(config, model_args, is_trainable)
|
||||
|
||||
if getattr(config, "model_type", None) == "qwen":
|
||||
|
||||
@@ -228,7 +228,7 @@ class LogCallback(TrainerCallback):
|
||||
if (
|
||||
args.should_save
|
||||
and os.path.exists(os.path.join(args.output_dir, TRAINER_LOG))
|
||||
and args.overwrite_output_dir
|
||||
and getattr(args, "overwrite_output_dir", False)
|
||||
):
|
||||
logger.warning_rank0_once("Previous trainer log in this folder will be deleted.")
|
||||
os.remove(os.path.join(args.output_dir, TRAINER_LOG))
|
||||
@@ -371,6 +371,18 @@ class ReporterCallback(TrainerCallback):
|
||||
}
|
||||
)
|
||||
|
||||
if "trackio" in args.report_to:
|
||||
import trackio
|
||||
|
||||
trackio.config.update(
|
||||
{
|
||||
"model_args": self.model_args.to_dict(),
|
||||
"data_args": self.data_args.to_dict(),
|
||||
"finetuning_args": self.finetuning_args.to_dict(),
|
||||
"generating_args": self.generating_args.to_dict(),
|
||||
}
|
||||
)
|
||||
|
||||
if self.finetuning_args.use_swanlab:
|
||||
import swanlab # type: ignore
|
||||
|
||||
|
||||
@@ -12,4 +12,62 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# TODO override the original trainer
|
||||
from typing import Any
|
||||
|
||||
import torch.nn.functional as F
|
||||
from mcore_adapter.trainer import McaTrainer
|
||||
from torch import Tensor
|
||||
from transformers import PreTrainedTokenizerBase
|
||||
from typing_extensions import override
|
||||
|
||||
from ...extras.constants import IGNORE_INDEX
|
||||
|
||||
|
||||
class CustomMcaTrainer(McaTrainer):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
@override
|
||||
def _pad_batched_inputs(self, inputs: dict[str, Tensor | Any], seq_length: int):
|
||||
r"""Override to avoid padding error when handling 3d posids."""
|
||||
padding_inputs = {
|
||||
k: v.tolist() if v is not None and isinstance(v, Tensor) else v
|
||||
for k, v in inputs.items()
|
||||
if k in self._language_input_names
|
||||
}
|
||||
|
||||
position_ids_3d = None
|
||||
if isinstance(inputs.get("position_ids"), Tensor) and inputs["position_ids"].dim() == 3:
|
||||
position_ids_3d = inputs["position_ids"]
|
||||
padding_inputs.pop("position_ids", None)
|
||||
|
||||
if "labels" in padding_inputs:
|
||||
padding_inputs["labels"] = [
|
||||
labels + [IGNORE_INDEX] * (seq_length - len(labels)) for labels in padding_inputs["labels"]
|
||||
]
|
||||
tokenizer = (
|
||||
self.processing_class
|
||||
if isinstance(self.processing_class, PreTrainedTokenizerBase)
|
||||
else getattr(self.processing_class, "tokenizer", self.processing_class)
|
||||
)
|
||||
padding_side = getattr(tokenizer, "padding_side", "right")
|
||||
padding_inputs = tokenizer.pad(
|
||||
padding_inputs,
|
||||
padding="max_length",
|
||||
max_length=seq_length,
|
||||
return_tensors="pt",
|
||||
).to(self.args.device)
|
||||
inputs.update(padding_inputs)
|
||||
|
||||
if position_ids_3d is not None:
|
||||
current_seq_len = position_ids_3d.size(-1)
|
||||
if current_seq_len < seq_length:
|
||||
pad_len = seq_length - current_seq_len
|
||||
if padding_side == "left":
|
||||
position_ids_3d = F.pad(position_ids_3d, (pad_len, 0), value=0)
|
||||
else:
|
||||
position_ids_3d = F.pad(position_ids_3d, (0, pad_len), value=0)
|
||||
|
||||
inputs["position_ids"] = position_ids_3d.to(self.args.device)
|
||||
|
||||
return inputs
|
||||
|
||||
@@ -13,10 +13,13 @@
|
||||
# limitations under the License.
|
||||
|
||||
import functools
|
||||
import json
|
||||
import os
|
||||
from collections.abc import Sequence
|
||||
from copy import deepcopy
|
||||
from typing import TYPE_CHECKING, Any, Optional
|
||||
|
||||
import torch
|
||||
from transformers import DataCollatorForSeq2Seq
|
||||
|
||||
from ...data import (
|
||||
@@ -41,9 +44,10 @@ if not is_mcore_adapter_available():
|
||||
|
||||
from mcore_adapter.models import AutoConfig, AutoModel
|
||||
from mcore_adapter.trainer import DPOTrainer as McaDPOTrainer
|
||||
from mcore_adapter.trainer import McaTrainer
|
||||
from mcore_adapter.trainer.dpo_config import DPOConfig
|
||||
|
||||
from .trainer import CustomMcaTrainer
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from mcore_adapter.training_args import Seq2SeqTrainingArguments as McaSeq2SeqTrainingArguments
|
||||
@@ -70,37 +74,53 @@ def _data_collator_wrapper(data_collator: Any):
|
||||
for k in ["attention_mask", "position_ids"]:
|
||||
if k in feature:
|
||||
feature[k] = feature[k][:-1]
|
||||
return data_collator(features)
|
||||
|
||||
# for qwen vl series model
|
||||
tmp_features = data_collator(features)
|
||||
tmp_features.pop("rope_deltas", None)
|
||||
position_ids = tmp_features.get("position_ids", None)
|
||||
|
||||
if position_ids is not None and position_ids.dim() == 3:
|
||||
if position_ids.shape[0] == 4:
|
||||
position_ids = position_ids[1:]
|
||||
tmp_features["position_ids"] = position_ids
|
||||
|
||||
return tmp_features
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
def _check_model_support(model_args: "ModelArguments"):
|
||||
from transformers import AutoConfig as HfAutoConfig
|
||||
if os.path.exists(os.path.join(model_args.model_name_or_path, "mca_config.json")): # load from mcore ckpt
|
||||
mca_config = json.load(open(os.path.join(model_args.model_name_or_path, "mca_config.json")))
|
||||
model_type = mca_config.get("hf_model_type", None)
|
||||
else:
|
||||
config = HfAutoConfig.from_pretrained(
|
||||
model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code
|
||||
)
|
||||
model_type = config.model_type
|
||||
|
||||
config = HfAutoConfig.from_pretrained(
|
||||
model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code
|
||||
)
|
||||
if config.model_type not in MCA_SUPPORTED_MODELS:
|
||||
if model_type not in MCA_SUPPORTED_MODELS:
|
||||
raise ValueError(
|
||||
f"Model {config.model_type} is not supported by mcore_adapter."
|
||||
f"Model {model_type} is not supported by mcore_adapter."
|
||||
"You can try to upgrade mcore_adapter to the latest version for more supported models."
|
||||
)
|
||||
|
||||
|
||||
def _freeze_model_parameters(model: Any, finetuning_args: "FinetuningArguments"):
|
||||
"""Freeze model parameters for qwen_vl series models based on finetuning arguments."""
|
||||
if getattr(model.config, "hf_model_type", None) not in ["qwen2_vl", "qwen2_5_vl", "qwen3_vl", "qwen3_vl_moe"]:
|
||||
if getattr(model.config, "hf_model_type", None) not in ["qwen2_vl", "qwen2_5_vl", "qwen3_vl", "qwen3_vl_moe", "qwen3_5", "qwen3_5_moe"]:
|
||||
return
|
||||
|
||||
params_to_freeze = []
|
||||
if finetuning_args.freeze_vision_tower:
|
||||
params_to_freeze.extend(["vision_model.blocks", "vision_model.patch_embed"])
|
||||
if getattr(model.config, "hf_model_type", None) in ["qwen3_vl", "qwen3_vl_moe"]:
|
||||
if getattr(model.config, "hf_model_type", None) in ["qwen3_vl", "qwen3_vl_moe", "qwen3_5", "qwen3_5_moe"]:
|
||||
params_to_freeze.extend(["vision_model.pos_embed"])
|
||||
|
||||
if finetuning_args.freeze_multi_modal_projector:
|
||||
params_to_freeze.extend(["multi_modal_projector"])
|
||||
params_to_freeze.extend(["vision_model.merger"])
|
||||
|
||||
if finetuning_args.freeze_language_model:
|
||||
params_to_freeze.extend(["embedding", "decoder", "output_layer"])
|
||||
@@ -111,6 +131,27 @@ def _freeze_model_parameters(model: Any, finetuning_args: "FinetuningArguments")
|
||||
p.requires_grad_(False)
|
||||
|
||||
|
||||
def _build_meta_hf_model_for_collator(model_args: "ModelArguments") -> Any | None:
|
||||
r"""Build a lightweight HF model on meta device for compatibility with collator."""
|
||||
from transformers import AutoConfig as HfAutoConfig
|
||||
from transformers import AutoModel as HfAutoModel
|
||||
from transformers import AutoModelForImageTextToText
|
||||
|
||||
try:
|
||||
config = HfAutoConfig.from_pretrained(
|
||||
model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code
|
||||
)
|
||||
with torch.device("meta"):
|
||||
try:
|
||||
# Prefer multimodal auto class for VLMs (e.g. qwen2-vl), so get_rope_index is available.
|
||||
return AutoModelForImageTextToText.from_config(config)
|
||||
except Exception:
|
||||
return HfAutoModel.from_config(config)
|
||||
except Exception as exc:
|
||||
logger.warning("Failed to build meta HF model for collator, fallback to no model. Error: %s", exc)
|
||||
return None
|
||||
|
||||
|
||||
def run_pt(
|
||||
model_args: "ModelArguments",
|
||||
data_args: "DataArguments",
|
||||
@@ -136,7 +177,7 @@ def run_pt(
|
||||
)
|
||||
data_collator = _data_collator_wrapper(data_collator)
|
||||
|
||||
trainer = McaTrainer(
|
||||
trainer = CustomMcaTrainer(
|
||||
model=model,
|
||||
args=training_args,
|
||||
tokenizer=tokenizer,
|
||||
@@ -186,6 +227,7 @@ def run_sft(
|
||||
|
||||
_check_model_support(model_args)
|
||||
model = AutoModel.from_pretrained(model_args.model_name_or_path, training_args)
|
||||
collator_model = _build_meta_hf_model_for_collator(model_args)
|
||||
|
||||
# optional freezing for qwen_vl series
|
||||
_freeze_model_parameters(model, finetuning_args)
|
||||
@@ -193,6 +235,7 @@ def run_sft(
|
||||
pad_to_max = training_args.expert_model_parallel_size is not None and training_args.expert_model_parallel_size > 1
|
||||
data_collator = SFTDataCollatorWith4DAttentionMask(
|
||||
template=template,
|
||||
model=collator_model,
|
||||
padding="max_length" if pad_to_max else "longest",
|
||||
max_length=data_args.cutoff_len if pad_to_max else None,
|
||||
pad_to_multiple_of=64,
|
||||
@@ -201,7 +244,7 @@ def run_sft(
|
||||
)
|
||||
data_collator = _data_collator_wrapper(data_collator)
|
||||
|
||||
trainer = McaTrainer(
|
||||
trainer = CustomMcaTrainer(
|
||||
model=model,
|
||||
args=training_args,
|
||||
tokenizer=tokenizer,
|
||||
@@ -240,6 +283,7 @@ def run_dpo(
|
||||
|
||||
_check_model_support(model_args)
|
||||
model = AutoModel.from_pretrained(model_args.model_name_or_path, training_args)
|
||||
collator_model = _build_meta_hf_model_for_collator(model_args)
|
||||
|
||||
_freeze_model_parameters(model, finetuning_args)
|
||||
|
||||
@@ -263,6 +307,7 @@ def run_dpo(
|
||||
)
|
||||
data_collator = PairwiseDataCollatorWithPadding(
|
||||
template=template,
|
||||
model=collator_model,
|
||||
pad_to_multiple_of=64,
|
||||
padding="max_length" if pad_to_max else "longest",
|
||||
max_length=data_args.cutoff_len if pad_to_max else None,
|
||||
|
||||
@@ -215,7 +215,13 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer):
|
||||
if len(pad_len): # move pad token to last
|
||||
preds[i] = np.concatenate((preds[i][pad_len[0] :], preds[i][: pad_len[0]]), axis=-1)
|
||||
|
||||
decoded_inputs = self.processing_class.batch_decode(dataset["input_ids"], skip_special_tokens=False)
|
||||
input_ids_column = dataset["input_ids"]
|
||||
try:
|
||||
input_ids_list = input_ids_column.to_pylist()
|
||||
except AttributeError:
|
||||
input_ids_list = list(input_ids_column)
|
||||
|
||||
decoded_inputs = self.processing_class.batch_decode(input_ids_list, skip_special_tokens=False)
|
||||
decoded_preds = self.processing_class.batch_decode(preds, skip_special_tokens=skip_special_tokens)
|
||||
decoded_labels = self.processing_class.batch_decode(labels, skip_special_tokens=skip_special_tokens)
|
||||
|
||||
|
||||
@@ -65,6 +65,7 @@ def run_sft(
|
||||
pad_to_multiple_of=8 if training_args.do_train else None, # for shift short attention
|
||||
label_pad_token_id=IGNORE_INDEX if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id,
|
||||
block_diag_attn=model_args.block_diag_attn,
|
||||
neat_packing=data_args.neat_packing,
|
||||
attn_implementation=getattr(model.config, "_attn_implementation", None),
|
||||
compute_dtype=model_args.compute_dtype,
|
||||
**tokenizer_module,
|
||||
|
||||
@@ -52,6 +52,7 @@ if is_ray_available():
|
||||
import ray
|
||||
from ray.util.placement_group import PlacementGroup, placement_group
|
||||
from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy
|
||||
from ray.util.state import list_nodes
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -941,7 +942,7 @@ def get_ray_remote_config_for_worker(
|
||||
|
||||
def get_ray_head_node_ip() -> str:
|
||||
r"""Get the IP address of the Ray head node."""
|
||||
head_ip = next(node["NodeManagerAddress"] for node in ray.nodes() if node.get("IsHead", False))
|
||||
head_ip = next(node["node_ip"] for node in list_nodes() if node.get("is_head_node", False))
|
||||
return head_ip
|
||||
|
||||
|
||||
|
||||
@@ -21,6 +21,7 @@ from omegaconf import OmegaConf
|
||||
from transformers import HfArgumentParser
|
||||
|
||||
from ..utils.env import is_env_enabled
|
||||
from ..utils.helper import set_seed
|
||||
from .data_args import DataArguments
|
||||
from .model_args import ModelArguments
|
||||
from .sample_args import SampleArguments
|
||||
@@ -56,6 +57,14 @@ def get_args(args: InputArgument = None) -> tuple[ModelArguments, DataArguments,
|
||||
print(f"Got unknown args, potentially deprecated arguments: {unknown_args}")
|
||||
raise ValueError(f"Some specified arguments are not used by the HfArgumentParser: {unknown_args}")
|
||||
|
||||
# Seed as early as possible after argument parsing so all downstream
|
||||
# components (dist init, dataloader, model init in run_* entrypoints) share the same RNG state.
|
||||
for arg in parsed_args:
|
||||
seed = getattr(arg, "seed", None)
|
||||
if seed is not None:
|
||||
set_seed(seed)
|
||||
break
|
||||
|
||||
return tuple(parsed_args)
|
||||
|
||||
|
||||
|
||||
@@ -66,7 +66,7 @@ class TrainingArguments:
|
||||
metadata={"help": "Number of workers for batching."},
|
||||
)
|
||||
enable_activation_checkpointing: bool = field(
|
||||
default=True,
|
||||
default=False,
|
||||
metadata={"help": "Enable activation checkpointing for training."},
|
||||
)
|
||||
dist_config: PluginConfig | None = field(
|
||||
@@ -81,6 +81,14 @@ class TrainingArguments:
|
||||
default=None,
|
||||
metadata={"help": "Learning rate scheduler configuration for training."},
|
||||
)
|
||||
seed: int = field(
|
||||
default=42,
|
||||
metadata={"help": "Random seed that will be set at the beginning of training."},
|
||||
)
|
||||
logging_steps: int = field(
|
||||
default=1,
|
||||
metadata={"help": "Log metrics every N optimizer steps."},
|
||||
)
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
self.dist_config = get_plugin_config(self.dist_config)
|
||||
|
||||
@@ -36,6 +36,12 @@ from ..accelerator.helper import ReduceOp
|
||||
from ..accelerator.interface import Dim, DistributedInterface
|
||||
from ..config import TrainingArguments
|
||||
from ..utils import logging
|
||||
from ..utils.callbacks import (
|
||||
CallbackHandler,
|
||||
LoggingCallback,
|
||||
TrainerCallback,
|
||||
TrainerState,
|
||||
)
|
||||
from ..utils.helper import compute_valid_tokens
|
||||
from ..utils.types import BatchInput, HFModel, ModelOutput, Tensor, TorchDataset
|
||||
from .utils.batching import BatchGenerator
|
||||
@@ -52,6 +58,7 @@ class BaseTrainer:
|
||||
model: HFModel,
|
||||
renderer: Renderer,
|
||||
train_dataset: TorchDataset,
|
||||
callbacks: list[TrainerCallback] | None = None,
|
||||
) -> None:
|
||||
self.args = args
|
||||
self.model = model
|
||||
@@ -64,6 +71,7 @@ class BaseTrainer:
|
||||
# cached variables
|
||||
self.device = DistributedInterface().current_device
|
||||
self.dp_size = DistributedInterface().get_world_size(Dim.DP)
|
||||
self.cp_size = DistributedInterface().get_world_size(Dim.CP)
|
||||
self.model_input_names = self.renderer.processor.model_input_names
|
||||
|
||||
self._create_batch_generator()
|
||||
@@ -76,7 +84,7 @@ class BaseTrainer:
|
||||
if self.args.enable_activation_checkpointing:
|
||||
self.model.gradient_checkpointing_enable({"use_reentrant": False})
|
||||
|
||||
self._accelerate_engine = None
|
||||
self._deepspeed_engine = None
|
||||
dist_name = self.args.dist_config.name if self.args.dist_config is not None else None
|
||||
|
||||
if dist_name == "deepspeed":
|
||||
@@ -99,6 +107,29 @@ class BaseTrainer:
|
||||
self._init_optimizer()
|
||||
self._init_lr_scheduler()
|
||||
|
||||
# Callbacks
|
||||
self.callback_handler = CallbackHandler([LoggingCallback()], trainer=self)
|
||||
for cb in callbacks or []:
|
||||
self.callback_handler.add_callback(cb)
|
||||
|
||||
# Callbacks: TrainerState tracks progress across the full run.
|
||||
self.state = TrainerState(num_training_steps=self.num_training_steps)
|
||||
|
||||
if self.args.dist_config is not None and self.args.dist_config.get("cp_size", 1) > 1:
|
||||
# qwen3.5 is not supported because of the different attention implementation, which will be supported in the future.
|
||||
if model.config.model_type == "qwen3_5":
|
||||
raise RuntimeError(
|
||||
"Sequence parallel is not supported for qwen3.5 model due to its different attention implementation, which will be supported in the future."
|
||||
)
|
||||
from ..plugins.model_plugins.parallelization.sequence_parallel import SequenceParallelModelPlugin
|
||||
|
||||
if model.config._attn_implementation != "flash_attention_2":
|
||||
logger.warning_rank0(
|
||||
"Sequence parallelism is optimized for flash attention only. Replace the attention implementation to flash_attention_2."
|
||||
)
|
||||
model.config._attn_implementation = "flash_attention_2"
|
||||
SequenceParallelModelPlugin(self.args.dist_config.get("cp_mode", "ulysses"))(model, self.args.dist_config)
|
||||
|
||||
def _create_batch_generator(self) -> None:
|
||||
self.train_batch_generator = BatchGenerator(
|
||||
dataset=self.train_dataset,
|
||||
@@ -108,6 +139,7 @@ class BaseTrainer:
|
||||
cutoff_len=self.args.cutoff_len,
|
||||
batching_workers=self.args.batching_workers,
|
||||
batching_strategy=self.args.batching_strategy,
|
||||
seed=self.args.seed,
|
||||
)
|
||||
|
||||
def _shard_model(self) -> None:
|
||||
@@ -156,7 +188,7 @@ class BaseTrainer:
|
||||
"""
|
||||
batch_size, _ = batch["labels"].shape
|
||||
model_inputs = {
|
||||
k: v.to(self.device, non_blocking=True) for k, v in batch.items() if k in self.model_input_names
|
||||
k: v.to(self.device, non_blocking=True) for k, v in batch.items() if isinstance(v, torch.Tensor)
|
||||
}
|
||||
labels = batch["labels"].to(self.device, non_blocking=True)
|
||||
outputs: ModelOutput = model(**model_inputs)
|
||||
@@ -173,16 +205,31 @@ class BaseTrainer:
|
||||
def fit(self) -> None:
|
||||
"""Train the model."""
|
||||
self.model.train()
|
||||
self.callback_handler.on_train_begin(self.args, self.state)
|
||||
for epoch in range(self.args.num_train_epochs):
|
||||
self.state.epoch = epoch
|
||||
self.train_batch_generator.set_epoch(epoch)
|
||||
self.callback_handler.on_epoch_begin(self.args, self.state)
|
||||
|
||||
for micro_batches in self.train_batch_generator:
|
||||
self.global_step += 1
|
||||
|
||||
self.state.global_step = self.global_step
|
||||
self.callback_handler.on_step_begin(self.args, self.state)
|
||||
|
||||
step_loss = 0
|
||||
step_valid_tokens = compute_valid_tokens(micro_batches)
|
||||
step_valid_tokens = DistributedInterface().all_reduce(step_valid_tokens, op=ReduceOp.SUM)
|
||||
num_micro = len(micro_batches)
|
||||
for i, micro_batch in enumerate(micro_batches):
|
||||
loss = self.compute_loss(micro_batch)
|
||||
if self.args.dist_config and self.args.dist_config.get("cp_size", 1) > 1:
|
||||
from ..plugins.model_plugins.parallelization.sequence_parallel import (
|
||||
SequenceParallelLossPlugin,
|
||||
)
|
||||
|
||||
loss = SequenceParallelLossPlugin("sequence_parallel_loss")(self.model, micro_batch)
|
||||
else:
|
||||
loss = self.compute_loss(micro_batch)
|
||||
mini_step_valid_tokens = compute_valid_tokens([micro_batch])
|
||||
# fsdp uses mean reduction so we need to scale the loss by dp_size
|
||||
loss = loss * mini_step_valid_tokens * self.dp_size / (step_valid_tokens + 1e-6)
|
||||
@@ -199,7 +246,24 @@ class BaseTrainer:
|
||||
# deepspeed: engine.step() already ran inside backward at the sync boundary
|
||||
grad_norm = self._deepspeed_engine.get_grad_norm()
|
||||
else:
|
||||
grad_norm = torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.args.max_grad_norm).item()
|
||||
if self.args.dist_config and self.args.dist_config.get("cp_size", 1) > 1:
|
||||
from torch.nn.utils.clip_grad import _clip_grads_with_norm_, _get_total_norm
|
||||
|
||||
parameters = self.model.parameters()
|
||||
if isinstance(parameters, torch.Tensor):
|
||||
parameters = [parameters]
|
||||
else:
|
||||
parameters = list(parameters)
|
||||
grads = [p.grad for p in parameters if p.grad is not None]
|
||||
grad_norm = _get_total_norm(grads)
|
||||
grad_norm = grad_norm.to(self.device)
|
||||
_clip_grads_with_norm_(parameters, self.args.max_grad_norm, grad_norm)
|
||||
if isinstance(grad_norm, torch.distributed._tensor.DTensor):
|
||||
grad_norm = grad_norm.full_tensor().item()
|
||||
else:
|
||||
grad_norm = torch.nn.utils.clip_grad_norm_(
|
||||
self.model.parameters(), self.args.max_grad_norm
|
||||
).item()
|
||||
|
||||
# isfinite(): argument 'input' (position 1) must be Tensor, not float
|
||||
if not torch.isfinite(torch.tensor(grad_norm)): # type: ignore # pyright: ignore [reportUnknownReturnType]
|
||||
@@ -212,14 +276,41 @@ class BaseTrainer:
|
||||
|
||||
step_loss, grad_norm = DistributedInterface().all_reduce([step_loss, grad_norm])
|
||||
DistributedInterface().sync()
|
||||
if DistributedInterface().get_rank() == 0:
|
||||
print(f"Epoch {epoch}, Step {self.global_step}, Loss: {step_loss:.4f}, Grad Norm: {grad_norm:.4f}")
|
||||
|
||||
# Update state with step metrics
|
||||
current_lr = (
|
||||
self.lr_scheduler.get_last_lr()[0]
|
||||
if hasattr(self.lr_scheduler, "get_last_lr")
|
||||
else self.args.learning_rate
|
||||
)
|
||||
self.state.loss = step_loss
|
||||
self.state.grad_norm = grad_norm
|
||||
self.state.learning_rate = current_lr
|
||||
|
||||
self.callback_handler.on_step_end(self.args, self.state)
|
||||
|
||||
# Logging: trainer decides when to log
|
||||
if self.global_step % self.args.logging_steps == 0:
|
||||
logs = {
|
||||
"epoch": epoch,
|
||||
"step": self.global_step,
|
||||
"loss": step_loss,
|
||||
"grad_norm": grad_norm,
|
||||
"learning_rate": current_lr,
|
||||
}
|
||||
self.callback_handler.on_log(self.args, self.state, logs)
|
||||
|
||||
# Check if max_steps is reached
|
||||
if self.global_step >= self.num_training_steps:
|
||||
logger.info_rank0(f"Reached max_steps ({self.num_training_steps}), stopping training.")
|
||||
self.callback_handler.on_epoch_end(self.args, self.state)
|
||||
self.callback_handler.on_train_end(self.args, self.state)
|
||||
return
|
||||
|
||||
self.callback_handler.on_epoch_end(self.args, self.state)
|
||||
|
||||
self.callback_handler.on_train_end(self.args, self.state)
|
||||
|
||||
def save_model(self) -> None:
|
||||
"""Save the model."""
|
||||
if self.args.dist_config is not None and self.args.dist_config.name in ("deepspeed", "fsdp2"):
|
||||
@@ -233,3 +324,5 @@ class BaseTrainer:
|
||||
model_to_save.save_pretrained(self.args.output_dir, max_shard_size="4GB")
|
||||
self.renderer.processor.save_pretrained(self.args.output_dir, max_shard_size="4GB")
|
||||
logger.info_rank0(f"Model saved to {self.args.output_dir}")
|
||||
|
||||
self.callback_handler.on_save(self.args, self.state)
|
||||
|
||||
@@ -140,6 +140,9 @@ class ModelEngine:
|
||||
**init_kwargs,
|
||||
)
|
||||
|
||||
init_mode = self.args.init_config.name if self.args.init_config is not None else "init_on_default"
|
||||
model._init_mode = init_mode
|
||||
|
||||
if self.args.peft_config is None:
|
||||
if self.is_train:
|
||||
logger.info_rank0("Fine-tuning mode: full tuning")
|
||||
@@ -147,6 +150,9 @@ class ModelEngine:
|
||||
else:
|
||||
logger.info_rank0("Inference the original model")
|
||||
else:
|
||||
if self.args.peft_config.name == "lora" and init_mode == "init_on_meta":
|
||||
raise ValueError("Currently lora stage does not support loading model by meta.")
|
||||
|
||||
from ..plugins.model_plugins.peft import PeftPlugin
|
||||
|
||||
model = PeftPlugin(self.args.peft_config.name)(model, self.args.peft_config, self.is_train)
|
||||
|
||||
@@ -26,6 +26,7 @@
|
||||
from collections.abc import Iterator
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
from torch.utils.data import default_collate
|
||||
from torchdata.stateful_dataloader import StatefulDataLoader
|
||||
from torchdata.stateful_dataloader.sampler import StatefulDistributedSampler
|
||||
@@ -71,6 +72,7 @@ class BatchGenerator(Iterator):
|
||||
batching_strategy: BatchingStrategy = BatchingStrategy.NORMAL,
|
||||
pin_memory: bool = True,
|
||||
drop_last: bool = True,
|
||||
seed: int = 42,
|
||||
) -> None:
|
||||
self.dataset = dataset
|
||||
self.renderer = renderer
|
||||
@@ -82,6 +84,7 @@ class BatchGenerator(Iterator):
|
||||
self.batching_strategy = batching_strategy
|
||||
self.pin_memory = pin_memory
|
||||
self.drop_last = drop_last
|
||||
self.seed = seed
|
||||
# TODO: support length and infinity
|
||||
dp_size = DistributedInterface().get_world_size(Dim.DP)
|
||||
|
||||
@@ -128,12 +131,15 @@ class BatchGenerator(Iterator):
|
||||
num_replicas=DistributedInterface().get_world_size(Dim.DP),
|
||||
rank=DistributedInterface().get_rank(Dim.DP),
|
||||
shuffle=True,
|
||||
seed=0,
|
||||
seed=self.seed,
|
||||
drop_last=self.drop_last,
|
||||
)
|
||||
else:
|
||||
raise NotImplementedError("Iterable dataset is not supported yet.")
|
||||
|
||||
generato_seed = torch.Generator()
|
||||
generato_seed.manual_seed(self.seed)
|
||||
|
||||
self._data_provider = StatefulDataLoader(
|
||||
self.dataset,
|
||||
batch_size=self.micro_batch_size * self.num_micro_batch,
|
||||
@@ -143,6 +149,7 @@ class BatchGenerator(Iterator):
|
||||
pin_memory=self.pin_memory,
|
||||
pin_memory_device=DistributedInterface().current_device.type,
|
||||
drop_last=self.drop_last,
|
||||
generator=generato_seed,
|
||||
)
|
||||
if self.batching_strategy == BatchingStrategy.NORMAL:
|
||||
self._length = len(self._data_provider)
|
||||
|
||||
@@ -91,7 +91,11 @@ class Renderer:
|
||||
self.processor = processor
|
||||
|
||||
def render_messages(
|
||||
self, messages: list[Message], tools: str | None = None, is_generate: bool = False
|
||||
self,
|
||||
messages: list[Message],
|
||||
tools: str | None = None,
|
||||
is_generate: bool = False,
|
||||
enable_thinking: bool = False,
|
||||
) -> ModelInput:
|
||||
"""Apply template to messages and convert them to model input.
|
||||
|
||||
@@ -99,6 +103,7 @@ class Renderer:
|
||||
messages (list[Message]): The messages to render.
|
||||
tools (str | None, optional): The tools to use. Defaults to None.
|
||||
is_generate (bool, optional): Whether to render for generation. Defaults to False.
|
||||
enable_thinking (bool, optional): Whether to enable thinking mode for generation. Defaults to False.
|
||||
|
||||
Returns:
|
||||
ModelInput: The rendered model input.
|
||||
@@ -108,7 +113,9 @@ class Renderer:
|
||||
else:
|
||||
from ...plugins.model_plugins.rendering import RenderingPlugin
|
||||
|
||||
return RenderingPlugin(self.template).render_messages(self.processor, messages, tools, is_generate)
|
||||
return RenderingPlugin(self.template).render_messages(
|
||||
self.processor, messages, tools, is_generate, enable_thinking
|
||||
)
|
||||
|
||||
def parse_message(self, generated_text: str) -> Message:
|
||||
"""Parse a message in the template format.
|
||||
@@ -139,6 +146,8 @@ class Renderer:
|
||||
for sample in samples:
|
||||
if "messages" in sample:
|
||||
model_input = self.render_messages(sample["messages"], sample.get("tools"))
|
||||
if "position_ids" not in model_input:
|
||||
model_input["position_ids"] = list(range(1, len(model_input["input_ids"]) + 1))
|
||||
elif "chosen_messages" in sample and "rejected_messages" in sample:
|
||||
chosen_input = self.render_messages(sample["chosen_messages"], sample.get("tools"))
|
||||
rejected_input = self.render_messages(sample["rejected_messages"], sample.get("tools"))
|
||||
|
||||
@@ -0,0 +1,59 @@
|
||||
# Copyright 2025 Bytedance Ltd. and/or its affiliates. and the LlamaFactory team.
|
||||
#
|
||||
# This code is inspired by the Bytedance's verl library.
|
||||
# https://github.com/verl-project/verl/blob/77476af84cc074edf5a6437f8d5ea418d7a54916/verl/utils/ulysses.py
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Any, Optional
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from torch import Tensor
|
||||
|
||||
|
||||
def all_to_all_tensor(
|
||||
local_input: Tensor,
|
||||
scatter_dim: int,
|
||||
gather_dim: int,
|
||||
group: Optional[dist.ProcessGroup] = None,
|
||||
):
|
||||
seq_world_size = dist.get_world_size(group)
|
||||
input_list = [t.contiguous() for t in torch.tensor_split(local_input, seq_world_size, scatter_dim)]
|
||||
output_list = [torch.empty_like(input_list[0]) for _ in range(seq_world_size)]
|
||||
dist.all_to_all(output_list, input_list, group=group)
|
||||
return torch.cat(output_list, dim=gather_dim).contiguous()
|
||||
|
||||
|
||||
class SeqAllToAll4D(torch.autograd.Function):
|
||||
@staticmethod
|
||||
def forward(
|
||||
ctx: Any,
|
||||
group: dist.ProcessGroup,
|
||||
local_input: Tensor,
|
||||
scatter_dim: int,
|
||||
gather_dim: int,
|
||||
) -> Tensor:
|
||||
ctx.group = group
|
||||
ctx.scatter_dim = scatter_dim
|
||||
ctx.gather_dim = gather_dim
|
||||
return all_to_all_tensor(local_input, scatter_dim, gather_dim, group)
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx: Any, *grad_output: Tensor) -> tuple[None, Tensor, None, None]:
|
||||
return (
|
||||
None,
|
||||
all_to_all_tensor(grad_output[0], ctx.gather_dim, ctx.scatter_dim, ctx.group),
|
||||
None,
|
||||
None,
|
||||
)
|
||||
@@ -0,0 +1,199 @@
|
||||
# Copyright 2025 the LlamaFactory team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import sys
|
||||
from functools import partial
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import torch.nn.functional as F
|
||||
import transformers
|
||||
|
||||
from ....accelerator.interface import Dim, DistributedInterface
|
||||
from ....utils import logging
|
||||
from ....utils.plugin import BasePlugin
|
||||
from ....utils.types import ModelOutput
|
||||
from .ulysses import (
|
||||
UlyssesAttention,
|
||||
get_ulysses_sequence_parallel_group,
|
||||
get_ulysses_sequence_parallel_rank,
|
||||
get_ulysses_sequence_parallel_world_size,
|
||||
set_ulysses_sequence_parallel_group,
|
||||
)
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
class SequenceParallelModelPlugin(BasePlugin):
|
||||
def __call__(self, model, model_args):
|
||||
return super().__call__(model, model_args)
|
||||
|
||||
|
||||
class SequenceParallelLossPlugin(BasePlugin):
|
||||
def __call__(self, model, inputs, *args, **kwargs):
|
||||
return super().__call__(model, inputs, *args, **kwargs)
|
||||
|
||||
|
||||
def new_flash_attn_forward(
|
||||
query_states,
|
||||
key_states,
|
||||
value_states,
|
||||
attention_mask,
|
||||
sequence_parallel_size=1,
|
||||
dropout=0,
|
||||
deterministic=False,
|
||||
is_causal=True,
|
||||
group=None,
|
||||
mode="ulysses",
|
||||
attn_fn=None,
|
||||
target_dtype=None,
|
||||
**kwargs,
|
||||
):
|
||||
if mode == "ulysses":
|
||||
dist_attn = UlyssesAttention(sequence_process_group=group, attn_fn=attn_fn)
|
||||
attn_output = dist_attn(
|
||||
query_states,
|
||||
key_states,
|
||||
value_states,
|
||||
attention_mask,
|
||||
query_length=query_states.shape[1] * sequence_parallel_size,
|
||||
deterministic=deterministic,
|
||||
dropout_p=dropout,
|
||||
causal=is_causal,
|
||||
position_ids=kwargs.get("position_ids", None),
|
||||
target_dtype=target_dtype,
|
||||
)
|
||||
else:
|
||||
raise NotImplementedError("Other sequence parallel modes are to be implemented.")
|
||||
|
||||
return attn_output
|
||||
|
||||
|
||||
@SequenceParallelModelPlugin("ulysses").register()
|
||||
def apply_sequence_parallel(model, model_args):
|
||||
# Replace _flash_attention_forward with new_flash_attn_forward
|
||||
module = sys.modules[model.__module__]
|
||||
cp_size = model_args.get("cp_size", 1)
|
||||
|
||||
set_ulysses_sequence_parallel_group(DistributedInterface().get_group(Dim.CP))
|
||||
|
||||
try:
|
||||
num_attention_heads, num_key_value_heads = model.config.num_attention_heads, model.config.num_attention_heads
|
||||
except AttributeError:
|
||||
num_attention_heads, num_key_value_heads = (
|
||||
model.config.text_config.num_attention_heads,
|
||||
model.config.text_config.num_key_value_heads,
|
||||
)
|
||||
|
||||
assert num_attention_heads % cp_size == 0, "num_attention_heads must be divisible by cp_size"
|
||||
assert num_key_value_heads % cp_size == 0 or cp_size % num_key_value_heads == 0, (
|
||||
"num_key_value_heads must be divisible by cp_size"
|
||||
)
|
||||
|
||||
origin_attn = transformers.modeling_flash_attention_utils._flash_attention_forward
|
||||
new_flash_attention_forward = partial(
|
||||
new_flash_attn_forward,
|
||||
group=get_ulysses_sequence_parallel_group(),
|
||||
mode="ulysses",
|
||||
attn_fn=origin_attn,
|
||||
sequence_parallel_size=cp_size,
|
||||
)
|
||||
|
||||
for module_name, module in list(sys.modules.items()):
|
||||
try:
|
||||
if (
|
||||
hasattr(module, "__file__")
|
||||
and "transformers" in module.__file__
|
||||
and getattr(module._flash_attention_forward, "__name__", "") == "_flash_attention_forward"
|
||||
):
|
||||
module._flash_attention_forward = new_flash_attention_forward
|
||||
logger.info_rank0(
|
||||
f"Replaced _flash_attention_forward in module {module_name} with new_flash_attn_forward for sequence parallel."
|
||||
)
|
||||
except (AttributeError, TypeError):
|
||||
continue
|
||||
|
||||
|
||||
def padding_and_split_data(data, device_mesh=None):
|
||||
if device_mesh is not None:
|
||||
cp_size = device_mesh["cp"].size()
|
||||
cp_rank = device_mesh["cp"].get_local_rank()
|
||||
cp_group = device_mesh["cp"].get_group()
|
||||
for k, v in data.items():
|
||||
if isinstance(v, torch.Tensor) and v.ndim > 1:
|
||||
data_len = torch.tensor(v.shape[-1], device=v.device, dtype=torch.int64)
|
||||
global_data_len = [torch.empty_like(data_len) for _ in range(cp_size)]
|
||||
dist.all_gather(global_data_len, data_len, group=cp_group)
|
||||
max_data_len = max(global_data_len)
|
||||
pad_size = max_data_len - v.shape[-1] + (cp_size - max_data_len % cp_size) % cp_size
|
||||
if k == "labels":
|
||||
pad_value = -100
|
||||
elif k == "loss_weights":
|
||||
pad_value = 0.0
|
||||
else:
|
||||
pad_value = 0
|
||||
pad_data = F.pad(v, (0, pad_size), value=pad_value)
|
||||
data[k] = torch.chunk(pad_data, chunks=cp_size, dim=-1)[cp_rank].contiguous()
|
||||
return data
|
||||
|
||||
|
||||
@SequenceParallelLossPlugin("sequence_parallel_loss").register()
|
||||
def sequence_parallel_loss(model, model_inputs):
|
||||
device_mesh = DistributedInterface().get_device_mesh(Dim.CP)
|
||||
|
||||
model_inputs = {
|
||||
k: v.to(dist.get_rank(), non_blocking=True) for k, v in model_inputs.items() if isinstance(v, torch.Tensor)
|
||||
}
|
||||
|
||||
model_inputs = padding_and_split_data(model_inputs, device_mesh)
|
||||
|
||||
batch_size, _ = model_inputs["labels"].shape
|
||||
|
||||
outputs: ModelOutput = model(**model_inputs)
|
||||
|
||||
logits = outputs.logits.float()
|
||||
|
||||
labels = model_inputs["labels"]
|
||||
|
||||
cp_group = get_ulysses_sequence_parallel_group()
|
||||
cp_world_size = get_ulysses_sequence_parallel_world_size(cp_group)
|
||||
cp_rank = get_ulysses_sequence_parallel_rank(cp_group)
|
||||
|
||||
# use all_gather to collect labels from all sequence parallel processes
|
||||
global_labels = [torch.empty_like(labels) for _ in range(cp_world_size)]
|
||||
dist.all_gather(global_labels, labels, group=cp_group)
|
||||
labels = torch.cat(global_labels, dim=1).contiguous()
|
||||
shift_labels = labels[..., 1:].view(-1).contiguous()
|
||||
shift_labels = F.pad(shift_labels, (0, 1), value=-100)
|
||||
shift_labels = torch.chunk(shift_labels, chunks=cp_world_size, dim=-1)[cp_rank].contiguous()
|
||||
|
||||
# use all_gather to collect loss_weights from all sequence parallel processes
|
||||
loss_weights = model_inputs["loss_weights"]
|
||||
global_loss_weights = [torch.empty_like(loss_weights) for _ in range(cp_world_size)]
|
||||
dist.all_gather(global_loss_weights, loss_weights, group=cp_group)
|
||||
shift_loss_weights = torch.cat(global_loss_weights, dim=1).contiguous()
|
||||
shift_loss_weights = shift_loss_weights[..., 1:].contiguous()
|
||||
|
||||
shift_logits = logits.view(shift_labels.size(0), -1).contiguous()
|
||||
|
||||
# use all_gather to collect log_probs from all sequence parallel processes
|
||||
log_probs = -F.cross_entropy(shift_logits, shift_labels, reduction="none").view(batch_size, -1)
|
||||
global_log_probs = dist.nn.all_gather(log_probs, group=cp_group)
|
||||
global_log_probs = torch.cat(global_log_probs, dim=1).contiguous()
|
||||
log_probs = global_log_probs[..., :-1].contiguous()
|
||||
|
||||
loss = (-log_probs * shift_loss_weights).sum() / (shift_loss_weights.sum() + 1e-6)
|
||||
|
||||
return loss
|
||||
@@ -0,0 +1,163 @@
|
||||
# Copyright 2025 Bytedance Ltd. and/or its affiliates. and the LlamaFactory team.
|
||||
#
|
||||
# This code is inspired by the Bytedance's verl library.
|
||||
# https://github.com/verl-project/verl/blob/77476af84cc074edf5a6437f8d5ea418d7a54916/verl/utils/ulysses.py
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Any, Optional
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from torch import Tensor
|
||||
from torch.distributed import ProcessGroup
|
||||
|
||||
from .seq_comm import SeqAllToAll4D
|
||||
|
||||
|
||||
_ULYSSES_SEQUENCE_PARALLEL_GROUP = None
|
||||
|
||||
|
||||
def set_ulysses_sequence_parallel_group(group: dist.ProcessGroup):
|
||||
"""Set ulysses sequence parallel process group."""
|
||||
global _ULYSSES_SEQUENCE_PARALLEL_GROUP
|
||||
_ULYSSES_SEQUENCE_PARALLEL_GROUP = group
|
||||
|
||||
|
||||
def get_ulysses_sequence_parallel_group() -> Optional[dist.ProcessGroup]:
|
||||
"""Get ulysses sequence parallel process group."""
|
||||
global _ULYSSES_SEQUENCE_PARALLEL_GROUP
|
||||
return _ULYSSES_SEQUENCE_PARALLEL_GROUP
|
||||
|
||||
|
||||
def get_ulysses_sequence_parallel_world_size(group: ProcessGroup = None) -> int:
|
||||
"""Get ulysses sequence parallel world size."""
|
||||
group = get_ulysses_sequence_parallel_group() if group is None else group
|
||||
return dist.get_world_size(group) if group else 1
|
||||
|
||||
|
||||
def get_ulysses_sequence_parallel_rank(group: ProcessGroup = None) -> int:
|
||||
"""Get ulysses sequence parallel rank."""
|
||||
group = get_ulysses_sequence_parallel_group() if group is None else group
|
||||
return dist.get_rank(group) if group else 0
|
||||
|
||||
|
||||
class UlyssesAttention(torch.nn.Module):
|
||||
"""Initialization.
|
||||
|
||||
Arguments:
|
||||
local_attention (Module): local attention with q,k,v
|
||||
sequence_process_group (ProcessGroup): sequence parallel process group
|
||||
scatter_idx (int): scatter_idx for all2all comm
|
||||
gather_idx (int): gather_idx for all2all comm
|
||||
attn_type (AttnType): attention type enum
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
sequence_process_group: dist.ProcessGroup = None,
|
||||
scatter_idx: int = 2,
|
||||
gather_idx: int = 1,
|
||||
attn_fn: Optional[callable] = None,
|
||||
) -> None:
|
||||
|
||||
super().__init__()
|
||||
self.spg = sequence_process_group
|
||||
self.scatter_idx = scatter_idx
|
||||
self.gather_idx = gather_idx
|
||||
self.attn_fn = attn_fn
|
||||
|
||||
def forward(
|
||||
self,
|
||||
query: Tensor,
|
||||
key: Tensor,
|
||||
value: Tensor,
|
||||
attention_mask: torch.Tensor,
|
||||
query_length: int,
|
||||
dropout_p=0.0,
|
||||
softmax_scale=None,
|
||||
position_ids: Optional[torch.Tensor] = None,
|
||||
causal=True,
|
||||
deterministic=False,
|
||||
target_dtype=None,
|
||||
*args: Any,
|
||||
) -> Tensor:
|
||||
"""Forward.
|
||||
|
||||
Arguments:
|
||||
query (Tensor): query input to the layer
|
||||
key (Tensor): key input to the layer
|
||||
value (Tensor): value input to the layer
|
||||
attention_mask (Tensor): attention mask for the layer
|
||||
query_length (int): the length of the query sequence
|
||||
dropout_p (float, optional): dropout probability. Defaults to 0.0.
|
||||
softmax_scale (float, optional): scale factor for softmax. Defaults to None,
|
||||
position_ids (torch.Tensor, optional): position ids for the attention. Defaults to None.
|
||||
causal (bool, optional): whether to apply causal mask. Defaults to True.
|
||||
deterministic (bool, optional): whether to apply dropout in deterministic way. Defaults to False.
|
||||
target_dtype (torch.dtype, optional): target dtype for attention output. Defaults to None.
|
||||
args: other args
|
||||
|
||||
Returns:
|
||||
* output (Tensor): context output
|
||||
"""
|
||||
# TODO Merge three alltoall calls into one
|
||||
# TODO (Reza): change the api on the megatron-deepspeed side so that we only receive all data (q,k, and v) together!
|
||||
# in shape : e.g., [s/p:h:]
|
||||
# (bs, seq_len/N, head_cnt, head_size) -> (bs, seq_len, head_cnt/N, head_size)
|
||||
|
||||
# scatter 2, gather 1
|
||||
q = SeqAllToAll4D.apply(self.spg, query, self.scatter_idx, self.gather_idx)
|
||||
k = SeqAllToAll4D.apply(self.spg, key, self.scatter_idx, self.gather_idx)
|
||||
v = SeqAllToAll4D.apply(self.spg, value, self.scatter_idx, self.gather_idx)
|
||||
|
||||
if softmax_scale is None:
|
||||
softmax_scale = q.shape[-1] ** -0.5
|
||||
|
||||
if attention_mask is None:
|
||||
if position_ids is not None:
|
||||
attention_mask = torch.ones_like(position_ids).to(torch.int64)
|
||||
else:
|
||||
attention_mask = torch.ones(q.shape[0], q.shape[1], dtype=torch.int64, device=q.device)
|
||||
else:
|
||||
attention_mask = attention_mask.to(torch.int64)
|
||||
|
||||
global_attention_mask = [
|
||||
torch.empty_like(attention_mask) for _ in range(get_ulysses_sequence_parallel_world_size(self.spg))
|
||||
]
|
||||
dist.all_gather(global_attention_mask, attention_mask, group=self.spg)
|
||||
attention_mask = torch.cat(global_attention_mask, dim=1)
|
||||
|
||||
context_layer = self.attn_fn(
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
attention_mask,
|
||||
query_length=query_length,
|
||||
is_causal=causal,
|
||||
dropout=dropout_p,
|
||||
position_ids=position_ids,
|
||||
softmax_scale=softmax_scale,
|
||||
deterministic=deterministic,
|
||||
target_dtype=target_dtype,
|
||||
)
|
||||
|
||||
if isinstance(context_layer, tuple):
|
||||
context_layer = context_layer[0]
|
||||
|
||||
# (bs, seq_len, head_cnt/N, head_size) -> (bs, seq_len/N, head_cnt, head_size)
|
||||
# scatter 1, gather 2
|
||||
output = SeqAllToAll4D.apply(self.spg, context_layer, self.gather_idx, self.scatter_idx)
|
||||
|
||||
# out e.g., [s/p::h]
|
||||
return output
|
||||
@@ -12,224 +12,45 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import json
|
||||
import re
|
||||
import importlib
|
||||
|
||||
from ...utils.constants import IGNORE_INDEX
|
||||
from ...utils.helper import get_tokenizer
|
||||
from ...utils import logging
|
||||
from ...utils.plugin import BasePlugin
|
||||
from ...utils.types import Message, ModelInput, Processor, ToolCall
|
||||
from ...utils.types import Message, ModelInput, Processor
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
class RenderingPlugin(BasePlugin):
|
||||
_attempted_template_imports: set[str] = set()
|
||||
|
||||
def _ensure_template_imported(self) -> None:
|
||||
if self.name is None or self.name in self._attempted_template_imports:
|
||||
return
|
||||
|
||||
full_module_name = f"{__package__}.templates.{self.name}"
|
||||
self._attempted_template_imports.add(self.name)
|
||||
try:
|
||||
importlib.import_module(full_module_name)
|
||||
except Exception as exc:
|
||||
logger.warning(f"[Template Registry] Failed to import {full_module_name}: {exc}")
|
||||
|
||||
def __getitem__(self, method_name: str):
|
||||
self._ensure_template_imported()
|
||||
return super().__getitem__(method_name)
|
||||
|
||||
def render_messages(
|
||||
self,
|
||||
processor: Processor,
|
||||
messages: list[Message],
|
||||
tools: str | None = None,
|
||||
is_generate: bool = False,
|
||||
enable_thinking: bool = False,
|
||||
) -> ModelInput:
|
||||
"""Render messages in the template format."""
|
||||
return self["render_messages"](processor, messages, tools, is_generate)
|
||||
return self["render_messages"](processor, messages, tools, is_generate, enable_thinking)
|
||||
|
||||
def parse_messages(self, generated_text: str) -> Message:
|
||||
"""Parse messages in the template format."""
|
||||
return self["parse_messages"](generated_text)
|
||||
|
||||
|
||||
def _update_model_input(
|
||||
processor: Processor,
|
||||
input_ids: list[int],
|
||||
labels: list[int],
|
||||
loss_weights: list[int],
|
||||
temp_str: str,
|
||||
temp_weight: float,
|
||||
) -> str:
|
||||
"""Update model input with temporary string."""
|
||||
if not temp_str:
|
||||
return ""
|
||||
|
||||
tokenizer = get_tokenizer(processor)
|
||||
temp_ids = tokenizer.encode(temp_str, add_special_tokens=False)
|
||||
input_ids.extend(temp_ids)
|
||||
loss_weights.extend([temp_weight] * len(temp_ids))
|
||||
if temp_weight > 1e-6:
|
||||
labels.extend(temp_ids)
|
||||
else:
|
||||
labels.extend([IGNORE_INDEX] * len(temp_ids))
|
||||
|
||||
return ""
|
||||
|
||||
|
||||
@RenderingPlugin("qwen3_nothink").register("render_messages")
|
||||
def render_qwen3_nothink_messages(
|
||||
processor: Processor,
|
||||
messages: list[Message],
|
||||
tools: str | None = None,
|
||||
is_generate: bool = False,
|
||||
) -> ModelInput:
|
||||
"""Render messages in the Qwen3 nothink template format.
|
||||
|
||||
See https://huggingface.co/spaces/huggingfacejs/chat-template-playground?modelId=Qwen/Qwen3-4B-Instruct-2507
|
||||
"""
|
||||
input_ids, labels, loss_weights = [], [], []
|
||||
temp_str, temp_weight = "", 0.0
|
||||
if tools:
|
||||
temp_str += "<|im_start|>system\n"
|
||||
if messages[0]["role"] == "system":
|
||||
for content in messages[0]["content"]:
|
||||
if content["type"] == "text":
|
||||
temp_str += content["value"]
|
||||
else:
|
||||
raise ValueError(f"Unsupported content type: {content['type']}")
|
||||
|
||||
temp_str += "\n\n"
|
||||
temp_weight = messages[0].get("loss_weight", 0.0)
|
||||
|
||||
temp_str += (
|
||||
"# Tools\n\nYou may call one or more functions to assist with the user query.\n\n"
|
||||
"You are provided with function signatures within <tools></tools> XML tags:\n<tools>"
|
||||
)
|
||||
try:
|
||||
tools = json.loads(tools)
|
||||
except json.JSONDecodeError:
|
||||
raise ValueError(f"Invalid tools format: {str(tools)}.")
|
||||
|
||||
if not isinstance(tools, list):
|
||||
tools = [tools]
|
||||
|
||||
for tool in tools:
|
||||
temp_str += "\n" + json.dumps(tool, ensure_ascii=False)
|
||||
|
||||
temp_str += (
|
||||
"\n</tools>\n\nFor each function call, return a json object with function name "
|
||||
'and arguments within <tool_call></tool_call> XML tags:\n<tool_call>\n{"name": '
|
||||
'<function-name>, "arguments": <args-json-object>}\n</tool_call><|im_end|>\n'
|
||||
)
|
||||
elif messages[0]["role"] == "system":
|
||||
temp_str += "<|im_start|>system\n"
|
||||
for content in messages[0]["content"]:
|
||||
if content["type"] == "text":
|
||||
temp_str += content["value"]
|
||||
else:
|
||||
raise ValueError(f"Unsupported content type: {content['type']}")
|
||||
|
||||
temp_str += "<|im_end|>\n"
|
||||
temp_weight = messages[0].get("loss_weight", 0.0)
|
||||
|
||||
temp_str = _update_model_input(processor, input_ids, labels, loss_weights, temp_str, temp_weight)
|
||||
|
||||
for turn_idx, message in enumerate(messages):
|
||||
if message["role"] == "user" or (message["role"] == "system" and turn_idx != 0):
|
||||
temp_str += "<|im_start|>" + message["role"] + "\n"
|
||||
for content in message["content"]:
|
||||
if content["type"] == "text":
|
||||
temp_str += content["value"]
|
||||
else:
|
||||
raise ValueError(f"Unsupported content type: {content['type']}")
|
||||
|
||||
temp_str += "<|im_end|>\n"
|
||||
temp_weight = message.get("loss_weight", 0.0)
|
||||
elif message["role"] == "assistant":
|
||||
temp_str += "<|im_start|>" + message["role"] + "\n"
|
||||
for val_idx, content in enumerate(message["content"]):
|
||||
if content["type"] == "text":
|
||||
temp_str += content["value"]
|
||||
elif content["type"] == "reasoning":
|
||||
temp_str += "<thinking>\n" + content["value"] + "\n</thinking>\n\n" # avoid using special tokens
|
||||
elif content["type"] == "tool_call":
|
||||
if val_idx != 0 and message["content"][val_idx - 1]["type"] in ["text", "tool_call"]:
|
||||
temp_str += "\n"
|
||||
|
||||
try:
|
||||
tool_call: ToolCall = json.loads(content["value"])
|
||||
except json.JSONDecodeError:
|
||||
raise ValueError(f"Invalid tool call format: {content['value']}.")
|
||||
|
||||
temp_str += (
|
||||
'<tool_call>\n{"name": "'
|
||||
+ tool_call["name"]
|
||||
+ '", "arguments": '
|
||||
+ json.dumps(tool_call["arguments"], ensure_ascii=False)
|
||||
+ "}\n</tool_call>"
|
||||
)
|
||||
|
||||
else:
|
||||
raise ValueError(f"Unsupported content type: {content['type']}")
|
||||
|
||||
temp_str += "<|im_end|>\n"
|
||||
temp_weight = message.get("loss_weight", 1.0)
|
||||
elif message["role"] == "tool":
|
||||
if turn_idx == 0 or messages[turn_idx - 1]["role"] != "tool":
|
||||
temp_str += "<|im_start|>user"
|
||||
|
||||
temp_str += "\n<tool_response>\n"
|
||||
for content in message["content"]:
|
||||
if content["type"] == "text":
|
||||
temp_str += content["value"]
|
||||
else:
|
||||
raise ValueError(f"Unsupported content type: {content['type']}")
|
||||
|
||||
temp_str += "\n</tool_response>"
|
||||
if turn_idx == len(messages) - 1 or messages[turn_idx + 1]["role"] != "tool":
|
||||
temp_str += "<|im_end|>\n"
|
||||
|
||||
temp_weight = message.get("loss_weight", 0.0)
|
||||
|
||||
temp_str = _update_model_input(processor, input_ids, labels, loss_weights, temp_str, temp_weight)
|
||||
|
||||
if is_generate:
|
||||
temp_str += "<|im_start|>assistant\n"
|
||||
temp_weight = 0.0
|
||||
|
||||
temp_str = _update_model_input(processor, input_ids, labels, loss_weights, temp_str, temp_weight)
|
||||
|
||||
attention_mask = [1] * len(input_ids)
|
||||
return ModelInput(
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
labels=labels,
|
||||
loss_weights=loss_weights,
|
||||
)
|
||||
|
||||
|
||||
@RenderingPlugin("qwen3_nothink").register("parse_message")
|
||||
def parse_qwen3_nothink_message(generated_text: str) -> Message:
|
||||
"""Parse a message in the Qwen3 nothink template format. Supports interleaved reasoning and tool calls.
|
||||
|
||||
Args:
|
||||
generated_text (str): The generated text in the Qwen3 nothink template format.
|
||||
|
||||
Returns:
|
||||
Message: The parsed message.
|
||||
"""
|
||||
pattern = re.compile(r"<(thinking|tool_call)>\s*(.*?)\s*</\1>\s*", re.DOTALL)
|
||||
content = []
|
||||
last_end = 0
|
||||
for match in pattern.finditer(generated_text):
|
||||
start, end = match.span()
|
||||
if start > last_end:
|
||||
text = generated_text[last_end:start].strip()
|
||||
if text:
|
||||
content.append({"type": "text", "value": text})
|
||||
|
||||
tag_type = match.group(1)
|
||||
tag_value = match.group(2).strip()
|
||||
if tag_type == "thinking":
|
||||
content.append({"type": "reasoning", "value": tag_value.strip()})
|
||||
elif tag_type == "tool_call":
|
||||
try:
|
||||
json.loads(tag_value.strip())
|
||||
except json.JSONDecodeError:
|
||||
raise ValueError(f"Invalid tool call format: {tag_value.strip()}.")
|
||||
|
||||
content.append({"type": "tool_call", "value": tag_value.strip()})
|
||||
|
||||
last_end = end
|
||||
|
||||
if last_end < len(generated_text):
|
||||
text = generated_text[last_end:].strip()
|
||||
if text:
|
||||
content.append({"type": "text", "value": text})
|
||||
|
||||
return Message(role="assistant", content=content)
|
||||
|
||||
@@ -0,0 +1,13 @@
|
||||
# Copyright 2025 the LlamaFactory team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
259
src/llamafactory/v1/plugins/model_plugins/templates/qwen3.py
Normal file
259
src/llamafactory/v1/plugins/model_plugins/templates/qwen3.py
Normal file
@@ -0,0 +1,259 @@
|
||||
# Copyright 2025 the LlamaFactory team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import json
|
||||
import re
|
||||
|
||||
from ....utils.constants import IGNORE_INDEX
|
||||
from ....utils.helper import get_tokenizer
|
||||
from ....utils.types import Message, ModelInput, Processor, ToolCall
|
||||
from ..rendering import RenderingPlugin
|
||||
|
||||
|
||||
def _update_model_input(
|
||||
processor: Processor,
|
||||
input_ids: list[int],
|
||||
labels: list[int],
|
||||
loss_weights: list[int],
|
||||
temp_str: str,
|
||||
temp_weight: float,
|
||||
) -> str:
|
||||
"""Update model input with temporary string."""
|
||||
if not temp_str:
|
||||
return ""
|
||||
|
||||
tokenizer = get_tokenizer(processor)
|
||||
temp_ids = tokenizer.encode(temp_str, add_special_tokens=False)
|
||||
input_ids.extend(temp_ids)
|
||||
loss_weights.extend([temp_weight] * len(temp_ids))
|
||||
if temp_weight > 1e-6:
|
||||
labels.extend(temp_ids)
|
||||
else:
|
||||
labels.extend([IGNORE_INDEX] * len(temp_ids))
|
||||
|
||||
return ""
|
||||
|
||||
|
||||
def _concat_text_content(message: Message) -> str:
|
||||
"""Concatenate text fields in a message."""
|
||||
message_text = ""
|
||||
for content in message["content"]:
|
||||
if content["type"] == "text":
|
||||
message_text += content["value"]
|
||||
else:
|
||||
raise ValueError(f"Unsupported content type: {content['type']}")
|
||||
|
||||
return message_text
|
||||
|
||||
|
||||
def _get_last_query_index(messages: list[Message]) -> int:
|
||||
"""Find the last user query index, excluding wrapped tool responses."""
|
||||
last_query_index = len(messages) - 1
|
||||
for idx in range(len(messages) - 1, -1, -1):
|
||||
message = messages[idx]
|
||||
if message["role"] != "user":
|
||||
continue
|
||||
|
||||
user_text = ""
|
||||
is_plain_text = True
|
||||
for content in message["content"]:
|
||||
if content["type"] != "text":
|
||||
is_plain_text = False
|
||||
break
|
||||
user_text += content["value"]
|
||||
|
||||
if not is_plain_text:
|
||||
continue
|
||||
|
||||
if not (user_text.startswith("<tool_response>") and user_text.endswith("</tool_response>")):
|
||||
last_query_index = idx
|
||||
break
|
||||
|
||||
return last_query_index
|
||||
|
||||
|
||||
def _split_assistant_content(message: Message) -> tuple[str, str, list[ToolCall]]:
|
||||
"""Split assistant message into text, reasoning and tool calls."""
|
||||
text_content = ""
|
||||
reasoning_content = ""
|
||||
tool_calls: list[ToolCall] = []
|
||||
|
||||
for content in message["content"]:
|
||||
if content["type"] == "text":
|
||||
text_content += content["value"]
|
||||
elif content["type"] == "reasoning":
|
||||
reasoning_content += content["value"]
|
||||
elif content["type"] == "tool_call":
|
||||
try:
|
||||
tool_call: ToolCall = json.loads(content["value"])
|
||||
except json.JSONDecodeError:
|
||||
raise ValueError(f"Invalid tool call format: {content['value']}.")
|
||||
|
||||
tool_calls.append(tool_call)
|
||||
else:
|
||||
raise ValueError(f"Unsupported content type: {content['type']}")
|
||||
|
||||
return text_content, reasoning_content, tool_calls
|
||||
|
||||
|
||||
@RenderingPlugin("qwen3").register("render_messages")
|
||||
def render_qwen3_messages(
|
||||
processor: Processor,
|
||||
messages: list[Message],
|
||||
tools: str | None = None,
|
||||
is_generate: bool = False,
|
||||
enable_thinking: bool = False,
|
||||
) -> ModelInput:
|
||||
"""Render messages in the Qwen3 template format.
|
||||
|
||||
See https://huggingface.co/spaces/huggingfacejs/chat-template-playground?modelId=Qwen/Qwen3-8B
|
||||
"""
|
||||
input_ids, labels, loss_weights = [], [], []
|
||||
temp_str, temp_weight = "", 0.0
|
||||
if tools:
|
||||
temp_str += "<|im_start|>system\n"
|
||||
if messages[0]["role"] == "system":
|
||||
temp_str += _concat_text_content(messages[0]) + "\n\n"
|
||||
temp_weight = messages[0].get("loss_weight", 0.0)
|
||||
|
||||
temp_str += (
|
||||
"# Tools\n\nYou may call one or more functions to assist with the user query.\n\n"
|
||||
"You are provided with function signatures within <tools></tools> XML tags:\n<tools>"
|
||||
)
|
||||
try:
|
||||
tools = json.loads(tools)
|
||||
except json.JSONDecodeError:
|
||||
raise ValueError(f"Invalid tools format: {str(tools)}.")
|
||||
|
||||
if not isinstance(tools, list):
|
||||
tools = [tools]
|
||||
|
||||
for tool in tools:
|
||||
temp_str += "\n" + json.dumps(tool, ensure_ascii=False)
|
||||
|
||||
temp_str += (
|
||||
"\n</tools>\n\nFor each function call, return a json object with function name "
|
||||
'and arguments within <tool_call></tool_call> XML tags:\n<tool_call>\n{"name": '
|
||||
'<function-name>, "arguments": <args-json-object>}\n</tool_call><|im_end|>\n'
|
||||
)
|
||||
elif messages[0]["role"] == "system":
|
||||
temp_str += "<|im_start|>system\n" + _concat_text_content(messages[0]) + "<|im_end|>\n"
|
||||
temp_weight = messages[0].get("loss_weight", 0.0)
|
||||
|
||||
temp_str = _update_model_input(processor, input_ids, labels, loss_weights, temp_str, temp_weight)
|
||||
last_query_index = _get_last_query_index(messages)
|
||||
|
||||
for turn_idx, message in enumerate(messages):
|
||||
if message["role"] == "user" or (message["role"] == "system" and turn_idx != 0):
|
||||
temp_str += "<|im_start|>" + message["role"] + "\n" + _concat_text_content(message) + "<|im_end|>\n"
|
||||
temp_weight = message.get("loss_weight", 0.0)
|
||||
elif message["role"] == "assistant":
|
||||
temp_str += "<|im_start|>" + message["role"] + "\n"
|
||||
|
||||
text_content, reasoning_content, tool_calls = _split_assistant_content(message)
|
||||
if turn_idx > last_query_index and (turn_idx == len(messages) - 1 or reasoning_content):
|
||||
temp_str += "<think>\n" + reasoning_content.strip("\n") + "\n</think>\n\n" + text_content.lstrip("\n")
|
||||
else:
|
||||
temp_str += text_content
|
||||
|
||||
for tool_call_idx, tool_call in enumerate(tool_calls):
|
||||
if (tool_call_idx == 0 and text_content) or tool_call_idx > 0:
|
||||
temp_str += "\n"
|
||||
|
||||
arguments = tool_call.get("arguments")
|
||||
if isinstance(arguments, str):
|
||||
arguments_str = arguments
|
||||
else:
|
||||
arguments_str = json.dumps(arguments, ensure_ascii=False)
|
||||
|
||||
temp_str += (
|
||||
'<tool_call>\n{"name": "'
|
||||
+ tool_call["name"]
|
||||
+ '", "arguments": '
|
||||
+ arguments_str
|
||||
+ "}\n</tool_call>"
|
||||
)
|
||||
|
||||
temp_str += "<|im_end|>\n"
|
||||
temp_weight = message.get("loss_weight", 1.0)
|
||||
elif message["role"] == "tool":
|
||||
if turn_idx == 0 or messages[turn_idx - 1]["role"] != "tool":
|
||||
temp_str += "<|im_start|>user"
|
||||
|
||||
temp_str += "\n<tool_response>\n" + _concat_text_content(message) + "\n</tool_response>"
|
||||
if turn_idx == len(messages) - 1 or messages[turn_idx + 1]["role"] != "tool":
|
||||
temp_str += "<|im_end|>\n"
|
||||
|
||||
temp_weight = message.get("loss_weight", 0.0)
|
||||
|
||||
temp_str = _update_model_input(processor, input_ids, labels, loss_weights, temp_str, temp_weight)
|
||||
|
||||
if is_generate:
|
||||
temp_str += "<|im_start|>assistant\n"
|
||||
temp_weight = 0.0
|
||||
if enable_thinking is False:
|
||||
temp_str += "<think>\n\n</think>\n\n"
|
||||
|
||||
temp_str = _update_model_input(processor, input_ids, labels, loss_weights, temp_str, temp_weight)
|
||||
|
||||
attention_mask = [1] * len(input_ids)
|
||||
return ModelInput(
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
labels=labels,
|
||||
loss_weights=loss_weights,
|
||||
)
|
||||
|
||||
|
||||
@RenderingPlugin("qwen3").register("parse_message")
|
||||
def parse_qwen3_message(generated_text: str) -> Message:
|
||||
"""Parse a message in the Qwen3 template format. Supports interleaved reasoning and tool calls.
|
||||
|
||||
Args:
|
||||
generated_text (str): The generated text in the Qwen3 template format.
|
||||
|
||||
Returns:
|
||||
Message: The parsed message.
|
||||
"""
|
||||
pattern = re.compile(r"<(think|tool_call)>\s*(.*?)\s*</\1>\s*", re.DOTALL)
|
||||
content = []
|
||||
last_end = 0
|
||||
|
||||
for match in pattern.finditer(generated_text):
|
||||
start, end = match.span()
|
||||
if start > last_end:
|
||||
text = generated_text[last_end:start].strip()
|
||||
if text:
|
||||
content.append({"type": "text", "value": text})
|
||||
|
||||
tag_type = match.group(1)
|
||||
tag_value = match.group(2).strip()
|
||||
if tag_type == "think":
|
||||
content.append({"type": "reasoning", "value": tag_value.strip()})
|
||||
elif tag_type == "tool_call":
|
||||
try:
|
||||
json.loads(tag_value.strip())
|
||||
except json.JSONDecodeError:
|
||||
raise ValueError(f"Invalid tool call format: {tag_value.strip()}.")
|
||||
|
||||
content.append({"type": "tool_call", "value": tag_value.strip()})
|
||||
|
||||
last_end = end
|
||||
|
||||
if last_end < len(generated_text):
|
||||
text = generated_text[last_end:].strip()
|
||||
if text:
|
||||
content.append({"type": "text", "value": text})
|
||||
|
||||
return Message(role="assistant", content=content)
|
||||
@@ -0,0 +1,209 @@
|
||||
# Copyright 2025 the LlamaFactory team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import json
|
||||
import re
|
||||
|
||||
from ....utils.constants import IGNORE_INDEX
|
||||
from ....utils.helper import get_tokenizer
|
||||
from ....utils.types import Message, ModelInput, Processor, ToolCall
|
||||
from ..rendering import RenderingPlugin
|
||||
|
||||
|
||||
def _update_model_input(
|
||||
processor: Processor,
|
||||
input_ids: list[int],
|
||||
labels: list[int],
|
||||
loss_weights: list[int],
|
||||
temp_str: str,
|
||||
temp_weight: float,
|
||||
) -> str:
|
||||
"""Update model input with temporary string."""
|
||||
if not temp_str:
|
||||
return ""
|
||||
|
||||
tokenizer = get_tokenizer(processor)
|
||||
temp_ids = tokenizer.encode(temp_str, add_special_tokens=False)
|
||||
input_ids.extend(temp_ids)
|
||||
loss_weights.extend([temp_weight] * len(temp_ids))
|
||||
if temp_weight > 1e-6:
|
||||
labels.extend(temp_ids)
|
||||
else:
|
||||
labels.extend([IGNORE_INDEX] * len(temp_ids))
|
||||
|
||||
return ""
|
||||
|
||||
|
||||
def _concat_text_content(message: Message) -> str:
|
||||
"""Concatenate text fields in a message."""
|
||||
message_text = ""
|
||||
for content in message["content"]:
|
||||
if content["type"] == "text":
|
||||
message_text += content["value"]
|
||||
else:
|
||||
raise ValueError(f"Unsupported content type: {content['type']}")
|
||||
|
||||
return message_text
|
||||
|
||||
|
||||
@RenderingPlugin("qwen3_nothink").register("render_messages")
|
||||
def render_qwen3_nothink_messages(
|
||||
processor: Processor,
|
||||
messages: list[Message],
|
||||
tools: str | None = None,
|
||||
is_generate: bool = False,
|
||||
enable_thinking: bool = False,
|
||||
) -> ModelInput:
|
||||
"""Render messages in the Qwen3 nothink template format.
|
||||
|
||||
See https://huggingface.co/spaces/huggingfacejs/chat-template-playground?modelId=Qwen/Qwen3-4B-Instruct-2507
|
||||
"""
|
||||
input_ids, labels, loss_weights = [], [], []
|
||||
temp_str, temp_weight = "", 0.0
|
||||
if tools:
|
||||
temp_str += "<|im_start|>system\n"
|
||||
if messages[0]["role"] == "system":
|
||||
temp_str += _concat_text_content(messages[0]) + "\n\n"
|
||||
temp_weight = messages[0].get("loss_weight", 0.0)
|
||||
|
||||
temp_str += (
|
||||
"# Tools\n\nYou may call one or more functions to assist with the user query.\n\n"
|
||||
"You are provided with function signatures within <tools></tools> XML tags:\n<tools>"
|
||||
)
|
||||
|
||||
try:
|
||||
tools = json.loads(tools)
|
||||
except json.JSONDecodeError:
|
||||
raise ValueError(f"Invalid tools format: {str(tools)}.")
|
||||
|
||||
if not isinstance(tools, list):
|
||||
tools = [tools]
|
||||
|
||||
for tool in tools:
|
||||
temp_str += "\n" + json.dumps(tool, ensure_ascii=False)
|
||||
|
||||
temp_str += (
|
||||
"\n</tools>\n\nFor each function call, return a json object with function name "
|
||||
'and arguments within <tool_call></tool_call> XML tags:\n<tool_call>\n{"name": '
|
||||
'<function-name>, "arguments": <args-json-object>}\n</tool_call><|im_end|>\n'
|
||||
)
|
||||
elif messages[0]["role"] == "system":
|
||||
temp_str += "<|im_start|>system\n" + _concat_text_content(messages[0]) + "<|im_end|>\n"
|
||||
temp_weight = messages[0].get("loss_weight", 0.0)
|
||||
|
||||
temp_str = _update_model_input(processor, input_ids, labels, loss_weights, temp_str, temp_weight)
|
||||
|
||||
for turn_idx, message in enumerate(messages):
|
||||
if message["role"] == "user" or (message["role"] == "system" and turn_idx != 0):
|
||||
temp_str += "<|im_start|>" + message["role"] + "\n" + _concat_text_content(message) + "<|im_end|>\n"
|
||||
temp_weight = message.get("loss_weight", 0.0)
|
||||
elif message["role"] == "assistant":
|
||||
temp_str += "<|im_start|>" + message["role"] + "\n"
|
||||
for val_idx, content in enumerate(message["content"]):
|
||||
if content["type"] == "text":
|
||||
temp_str += content["value"]
|
||||
elif content["type"] == "reasoning":
|
||||
temp_str += "<thinking>\n" + content["value"] + "\n</thinking>\n\n" # avoid using special tokens
|
||||
elif content["type"] == "tool_call":
|
||||
if val_idx != 0 and message["content"][val_idx - 1]["type"] in ["text", "tool_call"]:
|
||||
temp_str += "\n"
|
||||
|
||||
try:
|
||||
tool_call: ToolCall = json.loads(content["value"])
|
||||
except json.JSONDecodeError:
|
||||
raise ValueError(f"Invalid tool call format: {content['value']}.")
|
||||
|
||||
temp_str += (
|
||||
'<tool_call>\n{"name": "'
|
||||
+ tool_call["name"]
|
||||
+ '", "arguments": '
|
||||
+ json.dumps(tool_call["arguments"], ensure_ascii=False)
|
||||
+ "}\n</tool_call>"
|
||||
)
|
||||
|
||||
else:
|
||||
raise ValueError(f"Unsupported content type: {content['type']}")
|
||||
|
||||
temp_str += "<|im_end|>\n"
|
||||
temp_weight = message.get("loss_weight", 1.0)
|
||||
elif message["role"] == "tool":
|
||||
if turn_idx == 0 or messages[turn_idx - 1]["role"] != "tool":
|
||||
temp_str += "<|im_start|>user"
|
||||
|
||||
temp_str += "\n<tool_response>\n" + _concat_text_content(message) + "\n</tool_response>"
|
||||
if turn_idx == len(messages) - 1 or messages[turn_idx + 1]["role"] != "tool":
|
||||
temp_str += "<|im_end|>\n"
|
||||
|
||||
temp_weight = message.get("loss_weight", 0.0)
|
||||
|
||||
temp_str = _update_model_input(processor, input_ids, labels, loss_weights, temp_str, temp_weight)
|
||||
|
||||
if is_generate:
|
||||
temp_str += "<|im_start|>assistant\n"
|
||||
temp_weight = 0.0
|
||||
if enable_thinking:
|
||||
raise ValueError("The qwen3_nothink template does not support thinking mode.")
|
||||
|
||||
temp_str = _update_model_input(processor, input_ids, labels, loss_weights, temp_str, temp_weight)
|
||||
|
||||
attention_mask = [1] * len(input_ids)
|
||||
return ModelInput(
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
labels=labels,
|
||||
loss_weights=loss_weights,
|
||||
)
|
||||
|
||||
|
||||
@RenderingPlugin("qwen3_nothink").register("parse_message")
|
||||
def parse_qwen3_nothink_message(generated_text: str) -> Message:
|
||||
"""Parse a message in the Qwen3 nothink template format. Supports interleaved reasoning and tool calls.
|
||||
|
||||
Args:
|
||||
generated_text (str): The generated text in the Qwen3 nothink template format.
|
||||
|
||||
Returns:
|
||||
Message: The parsed message.
|
||||
"""
|
||||
pattern = re.compile(r"<(thinking|tool_call)>\s*(.*?)\s*</\1>\s*", re.DOTALL)
|
||||
content = []
|
||||
last_end = 0
|
||||
|
||||
for match in pattern.finditer(generated_text):
|
||||
start, end = match.span()
|
||||
if start > last_end:
|
||||
text = generated_text[last_end:start].strip()
|
||||
if text:
|
||||
content.append({"type": "text", "value": text})
|
||||
|
||||
tag_type = match.group(1)
|
||||
tag_value = match.group(2).strip()
|
||||
if tag_type == "thinking":
|
||||
content.append({"type": "reasoning", "value": tag_value.strip()})
|
||||
elif tag_type == "tool_call":
|
||||
try:
|
||||
json.loads(tag_value.strip())
|
||||
except json.JSONDecodeError:
|
||||
raise ValueError(f"Invalid tool call format: {tag_value.strip()}.")
|
||||
|
||||
content.append({"type": "tool_call", "value": tag_value.strip()})
|
||||
|
||||
last_end = end
|
||||
|
||||
if last_end < len(generated_text):
|
||||
text = generated_text[last_end:].strip()
|
||||
if text:
|
||||
content.append({"type": "text", "value": text})
|
||||
|
||||
return Message(role="assistant", content=content)
|
||||
@@ -12,10 +12,12 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import copy
|
||||
import gc
|
||||
import os
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import torch.nn as nn
|
||||
from peft.tuners.lora import LoraLayer
|
||||
from torch.distributed.checkpoint.state_dict import StateDictOptions, get_model_state_dict, set_model_state_dict
|
||||
@@ -83,10 +85,7 @@ class FSDP2Engine:
|
||||
)
|
||||
|
||||
if self.device_mesh is not None:
|
||||
try:
|
||||
self.fsdp_mesh = self.device_mesh["dp"]
|
||||
except Exception:
|
||||
self.fsdp_mesh = self.device_mesh
|
||||
self.fsdp_mesh = self.device_mesh
|
||||
|
||||
logger.info(f"Using Device Mesh: {self.fsdp_mesh}")
|
||||
else:
|
||||
@@ -166,12 +165,11 @@ class FSDP2Engine:
|
||||
offload_policy=CPUOffloadPolicy(pin_memory=self.pin_memory) if self.offload_params else None,
|
||||
)
|
||||
|
||||
use_gradient_checkpointing = True # Could be configurable
|
||||
if use_gradient_checkpointing:
|
||||
# BaseTrainer is the single source of truth for gradient checkpointing.
|
||||
# FSDP2 only applies the input-grad compatibility hook when checkpointing is already enabled.
|
||||
if getattr(model, "is_gradient_checkpointing", False):
|
||||
if self.rank == 0:
|
||||
logger.info("Enabling gradient checkpointing (transformers native)...")
|
||||
|
||||
model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False})
|
||||
logger.info("Gradient checkpointing is enabled. Applying FSDP2 input grad preparation.")
|
||||
|
||||
if hasattr(model, "enable_input_require_grads"):
|
||||
model.enable_input_require_grads()
|
||||
@@ -213,12 +211,88 @@ class FSDP2Engine:
|
||||
|
||||
return model
|
||||
|
||||
def _save_non_persistent_buffers(self, model: HFModel) -> dict:
|
||||
"""Save non-persistent buffers, such as inv_freq."""
|
||||
saved = {}
|
||||
for mod_name, module in model.named_modules():
|
||||
for buf_name in module._non_persistent_buffers_set:
|
||||
fqn = f"{mod_name}.{buf_name}" if mod_name else buf_name
|
||||
buf = getattr(module, buf_name, None)
|
||||
if buf is not None:
|
||||
saved[fqn] = copy.deepcopy(buf)
|
||||
if self.rank == 0 and saved:
|
||||
logger.info(f"Saved {len(saved)} non-persistent buffers")
|
||||
return saved
|
||||
|
||||
def _restore_non_persistent_buffers(self, model: HFModel, saved_buffers: dict):
|
||||
"""Register saved non-persistent buffers to model."""
|
||||
if not saved_buffers:
|
||||
return
|
||||
device = get_current_accelerator()
|
||||
for fqn, buf in saved_buffers.items():
|
||||
buf = buf.to(device)
|
||||
if "." in fqn:
|
||||
parent_fqn, buf_name = fqn.rsplit(".", 1)
|
||||
parent_module = model.get_submodule(parent_fqn)
|
||||
else:
|
||||
buf_name = fqn
|
||||
parent_module = model
|
||||
parent_module.register_buffer(buf_name, buf, persistent=False)
|
||||
if self.rank == 0:
|
||||
logger.info(f"Restored {len(saved_buffers)} non-persistent buffers")
|
||||
|
||||
def shard_model(self, model: HFModel) -> HFModel:
|
||||
if model.device.type == "meta":
|
||||
init_mode = getattr(model, "_init_mode", "init_on_default")
|
||||
|
||||
if init_mode == "init_on_rank0":
|
||||
if getattr(model.config, "tie_word_embeddings", False):
|
||||
model.tie_weights()
|
||||
|
||||
if self.rank == 0:
|
||||
logger.info("init_on_rank0 detected: sharding then scattering Rank 0 CPU weights.")
|
||||
full_sd = {k: v.clone() for k, v in model.state_dict().items()}
|
||||
else:
|
||||
full_sd = {}
|
||||
|
||||
# Reuse existing helper to save persistent=False buffers (e.g. inv_freq) before shard
|
||||
saved_buffers = self._save_non_persistent_buffers(model) if self.rank == 0 else {}
|
||||
|
||||
model = self.prepare_model(model)
|
||||
|
||||
device = get_current_accelerator()
|
||||
model.to_empty(device=device)
|
||||
|
||||
# Scatter params from Rank 0 into all DTensor shards
|
||||
# Broadcast the full state dict from the global rank-0 process to all ranks in this group.
|
||||
options = StateDictOptions(full_state_dict=True, cpu_offload=True, broadcast_from_rank0=True)
|
||||
set_model_state_dict(model, full_sd, options=options)
|
||||
|
||||
# Broadcast and restore non-persistent buffers
|
||||
buffers_to_sync = [saved_buffers]
|
||||
dist.broadcast_object_list(buffers_to_sync, src=0, group=self.fsdp_mesh.get_group())
|
||||
self._restore_non_persistent_buffers(model, buffers_to_sync[0])
|
||||
|
||||
if self.rank == 0:
|
||||
logger.info("init_on_rank0 sync complete.")
|
||||
|
||||
elif init_mode == "init_on_meta":
|
||||
non_persistent_buffers = self._save_non_persistent_buffers(model)
|
||||
|
||||
if getattr(model.config, "tie_word_embeddings", False):
|
||||
model.tie_weights()
|
||||
|
||||
model = self.prepare_model(model)
|
||||
model = self.materialize_and_load(model, hf_model_path=model.config.name_or_path, dcp_path=self.dcp_path)
|
||||
|
||||
# fix tied broken for no-fsdp-wrap case
|
||||
if getattr(model.config, "tie_word_embeddings", False):
|
||||
model.tie_weights()
|
||||
|
||||
self._restore_non_persistent_buffers(model, non_persistent_buffers)
|
||||
|
||||
else:
|
||||
model = self.prepare_model(model)
|
||||
|
||||
return model
|
||||
|
||||
def _load_from_dcp(self, model: HFModel, dcp_path: str):
|
||||
|
||||
24
src/llamafactory/v1/utils/callbacks/__init__.py
Normal file
24
src/llamafactory/v1/utils/callbacks/__init__.py
Normal file
@@ -0,0 +1,24 @@
|
||||
# Copyright 2025 the LlamaFactory team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from .logging_callback import LoggingCallback
|
||||
from .trainer_callback import CallbackHandler, TrainerCallback, TrainerState
|
||||
|
||||
|
||||
__all__ = [
|
||||
"CallbackHandler",
|
||||
"LoggingCallback",
|
||||
"TrainerCallback",
|
||||
"TrainerState",
|
||||
]
|
||||
64
src/llamafactory/v1/utils/callbacks/logging_callback.py
Normal file
64
src/llamafactory/v1/utils/callbacks/logging_callback.py
Normal file
@@ -0,0 +1,64 @@
|
||||
# Copyright 2025 the LlamaFactory team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import os
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from .. import logging
|
||||
from .trainer_callback import TrainerCallback, TrainerState
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ...config import TrainingArguments
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
class LoggingCallback(TrainerCallback):
|
||||
"""Logs training metrics to stdout on rank-0 and appends to ``state.log_history``.
|
||||
|
||||
On each logging step the entry is also persisted as a JSON line in
|
||||
``<output_dir>/trainer_log.jsonl`` so that training history survives crashes.
|
||||
"""
|
||||
|
||||
def on_log(
|
||||
self,
|
||||
args: TrainingArguments,
|
||||
state: TrainerState,
|
||||
logs: dict[str, Any],
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
# Persist in history regardless of rank
|
||||
state.log_history.append(dict(logs))
|
||||
|
||||
# Everything below is rank-0 only
|
||||
from ...accelerator.interface import DistributedInterface # lazy import
|
||||
|
||||
if DistributedInterface().get_rank() != 0:
|
||||
return
|
||||
|
||||
# Human-readable output to stdout
|
||||
display_logs = {**logs, "total_steps": state.num_training_steps}
|
||||
parts = ", ".join(f"{k}: {v:.4f}" if isinstance(v, float) else f"{k}: {v}" for k, v in display_logs.items())
|
||||
logger.info_rank0(parts)
|
||||
|
||||
# Append to JSONL log file in output_dir
|
||||
os.makedirs(args.output_dir, exist_ok=True)
|
||||
log_file = os.path.join(args.output_dir, "trainer_log.jsonl")
|
||||
with open(log_file, "a", encoding="utf-8") as f:
|
||||
f.write(json.dumps(display_logs, ensure_ascii=False) + "\n")
|
||||
147
src/llamafactory/v1/utils/callbacks/trainer_callback.py
Normal file
147
src/llamafactory/v1/utils/callbacks/trainer_callback.py
Normal file
@@ -0,0 +1,147 @@
|
||||
# Copyright 2025 the LlamaFactory team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ...config import TrainingArguments
|
||||
|
||||
|
||||
@dataclass
|
||||
class TrainerState:
|
||||
"""A read-only snapshot of training progress passed to every callback hook.
|
||||
|
||||
Attributes:
|
||||
epoch: Current epoch (0-indexed).
|
||||
global_step: Number of optimizer steps completed so far.
|
||||
num_training_steps: Total number of optimizer steps planned.
|
||||
loss: Scalar loss value of the most recent step.
|
||||
grad_norm: Gradient-norm value of the most recent step.
|
||||
learning_rate: Current learning rate seen by the optimizer.
|
||||
log_history: List of per-step log dicts emitted by ``LoggingCallback``.
|
||||
"""
|
||||
|
||||
epoch: int = 0
|
||||
global_step: int = 0
|
||||
num_training_steps: int = 0
|
||||
loss: float = 0.0
|
||||
grad_norm: float = 0.0
|
||||
learning_rate: float = 0.0
|
||||
log_history: list[dict[str, Any]] = field(default_factory=list)
|
||||
|
||||
|
||||
class TrainerCallback:
|
||||
"""Abstract base class for training callbacks.
|
||||
|
||||
Subclass and override whichever hooks you need. All hooks receive:
|
||||
|
||||
- ``args`` – the :class:`~llamafactory.v1.config.TrainingArguments`.
|
||||
- ``state`` – a :class:`TrainerState` snapshot (read-only).
|
||||
- ``**kwargs`` – extra keyword arguments (model, optimizer, …).
|
||||
|
||||
Callbacks are *observers*: they should NOT mutate training flow.
|
||||
|
||||
Hook call order::
|
||||
|
||||
on_train_begin
|
||||
for each epoch:
|
||||
on_epoch_begin
|
||||
for each step:
|
||||
on_step_begin
|
||||
(forward / backward / optimizer.step)
|
||||
on_step_end
|
||||
[on_log] ← if this step is a logging step
|
||||
on_epoch_end
|
||||
on_train_end
|
||||
"""
|
||||
|
||||
def on_train_begin(self, args: TrainingArguments, state: TrainerState, **kwargs: Any) -> None:
|
||||
"""Called once before the first training step."""
|
||||
|
||||
def on_train_end(self, args: TrainingArguments, state: TrainerState, **kwargs: Any) -> None:
|
||||
"""Called once after the last training step."""
|
||||
|
||||
def on_epoch_begin(self, args: TrainingArguments, state: TrainerState, **kwargs: Any) -> None:
|
||||
"""Called at the beginning of each epoch."""
|
||||
|
||||
def on_epoch_end(self, args: TrainingArguments, state: TrainerState, **kwargs: Any) -> None:
|
||||
"""Called at the end of each epoch."""
|
||||
|
||||
def on_step_begin(self, args: TrainingArguments, state: TrainerState, **kwargs: Any) -> None:
|
||||
"""Called before the forward/backward pass of each optimizer step."""
|
||||
|
||||
def on_step_end(self, args: TrainingArguments, state: TrainerState, **kwargs: Any) -> None:
|
||||
"""Called after the optimizer step."""
|
||||
|
||||
def on_log(self, args: TrainingArguments, state: TrainerState, logs: dict[str, Any], **kwargs: Any) -> None:
|
||||
"""Called when the trainer emits a log entry."""
|
||||
|
||||
def on_save(self, args: TrainingArguments, state: TrainerState, **kwargs: Any) -> None:
|
||||
"""Called after the model checkpoint has been written to disk."""
|
||||
|
||||
|
||||
class CallbackHandler:
|
||||
"""Owns a list of :class:`TrainerCallback` instances and fans out hook calls.
|
||||
|
||||
Usage::
|
||||
|
||||
handler = CallbackHandler([LoggingCallback(), MyWandbCallback()], trainer=trainer)
|
||||
handler.on_train_begin(args, state)
|
||||
"""
|
||||
|
||||
def __init__(self, callbacks: list[TrainerCallback] | None = None, trainer: Any = None) -> None:
|
||||
self.callbacks: list[TrainerCallback] = list(callbacks or [])
|
||||
self.trainer = trainer
|
||||
|
||||
def add_callback(self, callback: TrainerCallback) -> None:
|
||||
"""Append a callback to the handler."""
|
||||
self.callbacks.append(callback)
|
||||
|
||||
def _call(self, event: str, args: TrainingArguments, state: TrainerState, **kwargs: Any) -> None:
|
||||
if self.trainer is not None:
|
||||
kwargs.setdefault("model", getattr(self.trainer, "model", None))
|
||||
kwargs.setdefault("optimizer", getattr(self.trainer, "optimizer", None))
|
||||
kwargs.setdefault("lr_scheduler", getattr(self.trainer, "lr_scheduler", None))
|
||||
kwargs.setdefault("train_dataloader", getattr(self.trainer, "train_batch_generator", None))
|
||||
|
||||
for cb in self.callbacks:
|
||||
getattr(cb, event)(args, state, **kwargs)
|
||||
|
||||
def on_train_begin(self, args: TrainingArguments, state: TrainerState) -> None:
|
||||
self._call("on_train_begin", args, state)
|
||||
|
||||
def on_train_end(self, args: TrainingArguments, state: TrainerState) -> None:
|
||||
self._call("on_train_end", args, state)
|
||||
|
||||
def on_epoch_begin(self, args: TrainingArguments, state: TrainerState) -> None:
|
||||
self._call("on_epoch_begin", args, state)
|
||||
|
||||
def on_epoch_end(self, args: TrainingArguments, state: TrainerState) -> None:
|
||||
self._call("on_epoch_end", args, state)
|
||||
|
||||
def on_step_begin(self, args: TrainingArguments, state: TrainerState) -> None:
|
||||
self._call("on_step_begin", args, state)
|
||||
|
||||
def on_step_end(self, args: TrainingArguments, state: TrainerState) -> None:
|
||||
self._call("on_step_end", args, state)
|
||||
|
||||
def on_log(self, args: TrainingArguments, state: TrainerState, logs: dict[str, Any]) -> None:
|
||||
self._call("on_log", args, state, logs=logs)
|
||||
|
||||
def on_save(self, args: TrainingArguments, state: TrainerState) -> None:
|
||||
self._call("on_save", args, state)
|
||||
@@ -15,12 +15,22 @@
|
||||
|
||||
import torch
|
||||
from transformers import PreTrainedTokenizer
|
||||
from transformers import set_seed as hf_set_seed
|
||||
|
||||
from ..accelerator.interface import DistributedInterface
|
||||
from .constants import IGNORE_INDEX
|
||||
from .types import BatchInput, ModelInput, Processor, Tensor
|
||||
|
||||
|
||||
def set_seed(seed: int) -> None:
|
||||
"""Set seed for reproducibility.
|
||||
|
||||
Args:
|
||||
seed: Random seed.
|
||||
"""
|
||||
hf_set_seed(seed)
|
||||
|
||||
|
||||
def is_tokenizer(processor: Processor) -> bool:
|
||||
"""Check if processor is tokenizer.
|
||||
|
||||
|
||||
@@ -85,7 +85,7 @@ class DistributedConfig(TypedDict, total=False):
|
||||
|
||||
|
||||
class Content(TypedDict):
|
||||
type: Literal["text", "reasoning", "tool_call", "image_url"]
|
||||
type: Literal["text", "reasoning", "tool_call", "image_url", "video_url", "audio_url"]
|
||||
"""Type of the content."""
|
||||
value: str
|
||||
"""Value of the content."""
|
||||
|
||||
@@ -108,11 +108,26 @@ def create_train_tab(engine: "Engine") -> dict[str, "Component"]:
|
||||
with gr.Column():
|
||||
enable_thinking = gr.Checkbox(value=True)
|
||||
report_to = gr.Dropdown(
|
||||
choices=["none", "wandb", "mlflow", "neptune", "tensorboard", "all"],
|
||||
choices=["none", "wandb", "mlflow", "neptune", "tensorboard", "trackio", "all"],
|
||||
value="none",
|
||||
allow_custom_value=True,
|
||||
)
|
||||
|
||||
with gr.Accordion("Trackio Settings", open=False):
|
||||
project = gr.Textbox(
|
||||
value="huggingface",
|
||||
label="Project Name",
|
||||
info="Project name for experiment tracking (used by Trackio, W&B, etc.)",
|
||||
)
|
||||
|
||||
trackio_space_id = gr.Textbox(
|
||||
value="trackio", label="Trackio Space ID", info="Hugging Face Space ID for Trackio deployment"
|
||||
)
|
||||
|
||||
hub_private_repo = gr.Checkbox(
|
||||
value=False, label="Private Repository", info="Make the Hugging Face repository private"
|
||||
)
|
||||
|
||||
input_elems.update(
|
||||
{
|
||||
logging_steps,
|
||||
@@ -128,6 +143,9 @@ def create_train_tab(engine: "Engine") -> dict[str, "Component"]:
|
||||
use_llama_pro,
|
||||
enable_thinking,
|
||||
report_to,
|
||||
project,
|
||||
trackio_space_id,
|
||||
hub_private_repo,
|
||||
}
|
||||
)
|
||||
elem_dict.update(
|
||||
@@ -146,6 +164,9 @@ def create_train_tab(engine: "Engine") -> dict[str, "Component"]:
|
||||
use_llama_pro=use_llama_pro,
|
||||
enable_thinking=enable_thinking,
|
||||
report_to=report_to,
|
||||
project=project,
|
||||
trackio_space_id=trackio_space_id,
|
||||
hub_private_repo=hub_private_repo,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@@ -13,6 +13,7 @@
|
||||
# limitations under the License.
|
||||
|
||||
import os
|
||||
from collections import Counter
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
@@ -129,9 +130,177 @@ def test_multimodal_collator():
|
||||
|
||||
assert batch_input.keys() == expected_input.keys()
|
||||
for k in batch_input.keys():
|
||||
if k == "position_ids" and batch_input[k].dim() == 3 and batch_input[k].shape[0] == 4:
|
||||
batch_input[k] = batch_input[k][1:]
|
||||
|
||||
assert batch_input[k].eq(torch.tensor(expected_input[k])).all()
|
||||
|
||||
|
||||
def _make_packed_feature(
|
||||
*,
|
||||
packing_params: dict,
|
||||
pad_token_id: int,
|
||||
label_ignore_id: int,
|
||||
fake_image: Image.Image,
|
||||
vision_start_id: int | None = None,
|
||||
vision_end_id: int | None = None,
|
||||
image_pad_id: int | None = None,
|
||||
) -> dict:
|
||||
r"""Build one packed sample using the new PackingParams schema."""
|
||||
sequence_boundaries = packing_params["sequence_boundaries"]
|
||||
image_subseq_ids = packing_params["image_subseq_ids"]
|
||||
video_subseq_ids = packing_params["video_subseq_ids"]
|
||||
audio_subseq_ids = packing_params["audio_subseq_ids"]
|
||||
unpadded_length = packing_params["unpadded_length"]
|
||||
right_padding_length = packing_params["right_padding_length"] # which only preserved in tests
|
||||
cutoff_plus_one = sequence_boundaries[-1]
|
||||
content_len = unpadded_length
|
||||
pad_len = right_padding_length
|
||||
assert content_len + pad_len == cutoff_plus_one
|
||||
assert sequence_boundaries[0] == 0
|
||||
assert sequence_boundaries[-1] == cutoff_plus_one
|
||||
|
||||
content_ids = list(range(100, 100 + content_len))
|
||||
if vision_start_id is not None and vision_end_id is not None and image_pad_id is not None:
|
||||
image_counts_by_subseq = Counter(image_subseq_ids)
|
||||
for subseq_idx, image_count in sorted(image_counts_by_subseq.items()):
|
||||
if subseq_idx >= len(sequence_boundaries) - 1:
|
||||
continue
|
||||
|
||||
subseq_start = sequence_boundaries[subseq_idx]
|
||||
subseq_end = sequence_boundaries[subseq_idx + 1]
|
||||
subseq_len = subseq_end - subseq_start
|
||||
if subseq_len < 3:
|
||||
continue
|
||||
|
||||
# Build repeated image groups while preserving at least 3 tokens for each remaining image.
|
||||
injected_tokens: list[int] = []
|
||||
remaining = subseq_len
|
||||
for image_idx in range(image_count):
|
||||
remaining_images = image_count - image_idx
|
||||
min_reserved_for_rest = 3 * (remaining_images - 1)
|
||||
current_group_len = min(6, remaining - min_reserved_for_rest)
|
||||
if current_group_len < 3:
|
||||
break
|
||||
|
||||
group = [vision_start_id] + [image_pad_id] * max(1, current_group_len - 2) + [vision_end_id]
|
||||
injected_tokens.extend(group[:current_group_len])
|
||||
remaining -= current_group_len
|
||||
|
||||
if injected_tokens:
|
||||
insert_end = subseq_start + len(injected_tokens)
|
||||
content_ids[subseq_start:insert_end] = injected_tokens
|
||||
|
||||
input_ids = content_ids + [pad_token_id] * pad_len
|
||||
attention_mask = [1] * content_len + [0] * pad_len
|
||||
labels = [label_ignore_id] * cutoff_plus_one
|
||||
|
||||
return {
|
||||
"input_ids": input_ids,
|
||||
"attention_mask": attention_mask,
|
||||
"labels": labels,
|
||||
"images": [fake_image] * len(image_subseq_ids),
|
||||
"videos": [None] * len(video_subseq_ids),
|
||||
"audios": [None] * len(audio_subseq_ids),
|
||||
"packing_params": packing_params,
|
||||
}
|
||||
|
||||
|
||||
def _make_packed_features(
|
||||
*,
|
||||
packing_params: dict,
|
||||
pad_token_id: int,
|
||||
label_ignore_id: int,
|
||||
fake_image: Image.Image,
|
||||
vision_start_id: int,
|
||||
vision_end_id: int,
|
||||
image_pad_id: int,
|
||||
) -> list[dict]:
|
||||
r"""Build packed features from caller-provided packing_params."""
|
||||
return [
|
||||
_make_packed_feature(
|
||||
packing_params=packing_params,
|
||||
pad_token_id=pad_token_id,
|
||||
label_ignore_id=label_ignore_id,
|
||||
fake_image=fake_image,
|
||||
vision_start_id=vision_start_id,
|
||||
vision_end_id=vision_end_id,
|
||||
image_pad_id=image_pad_id,
|
||||
)
|
||||
]
|
||||
|
||||
def _get_expected_position_ids(packing_params, get_rope_func, input_ids, attention_mask) -> torch.Tensor:
|
||||
bound_list = packing_params["sequence_boundaries"]
|
||||
input_ids_slices = [input_ids[bound_list[i]:bound_list[i+1]] for i in range(len(bound_list) - 1)]
|
||||
attention_mask_slices = [attention_mask[bound_list[i]:bound_list[i+1]] for i in range(len(bound_list) - 1)]
|
||||
img_counts_by_subseq = Counter(packing_params["image_subseq_ids"])
|
||||
all_position_ids = []
|
||||
for i, input_ids_slice in enumerate(input_ids_slices):
|
||||
img_cnt = img_counts_by_subseq[i]
|
||||
if sum(attention_mask_slices[i]) == 0:
|
||||
continue
|
||||
|
||||
rope_func_kwargs = {
|
||||
"input_ids": torch.tensor(input_ids_slice).unsqueeze(0),
|
||||
"attention_mask": torch.tensor(attention_mask_slices[i]).unsqueeze(0),
|
||||
"image_grid_thw": [torch.tensor([1, 4, 4])] * img_cnt,
|
||||
}
|
||||
position_ids, _ = get_rope_func(**rope_func_kwargs)
|
||||
all_position_ids.append(position_ids)
|
||||
|
||||
return torch.cat(all_position_ids, dim=-1)
|
||||
|
||||
|
||||
@pytest.mark.runs_on(["cpu", "mps"])
|
||||
def test_multimodal_collator_with_packing():
|
||||
model_args, data_args, *_ = get_infer_args(
|
||||
{"model_name_or_path": "Qwen/Qwen2-VL-2B-Instruct", "template": "qwen2_vl"}
|
||||
)
|
||||
tokenizer_module = load_tokenizer(model_args)
|
||||
template = get_template_and_fix_tokenizer(tokenizer_module["tokenizer"], data_args)
|
||||
tokenizer_module["tokenizer"].padding_side = "right"
|
||||
config = AutoConfig.from_pretrained(model_args.model_name_or_path)
|
||||
with torch.device("meta"):
|
||||
model = AutoModelForImageTextToText.from_config(config)
|
||||
|
||||
data_collator = MultiModalDataCollatorForSeq2Seq(
|
||||
template=template,
|
||||
model=model,
|
||||
pad_to_multiple_of=4,
|
||||
label_pad_token_id=IGNORE_INDEX,
|
||||
**tokenizer_module,
|
||||
)
|
||||
|
||||
tokenizer = tokenizer_module["tokenizer"]
|
||||
packing_params = {
|
||||
"sequence_boundaries": [0, 2, 10, 18, 28, 32],
|
||||
"image_subseq_ids": [1, 2, 3],
|
||||
"video_subseq_ids": [],
|
||||
"audio_subseq_ids": [],
|
||||
"unpadded_length": 28,
|
||||
"right_padding_length": 4,
|
||||
}
|
||||
fake_image = Image.new("RGB", (64, 64), (255, 255, 255))
|
||||
features = _make_packed_features(
|
||||
packing_params=packing_params,
|
||||
pad_token_id=tokenizer.pad_token_id,
|
||||
label_ignore_id=IGNORE_INDEX,
|
||||
fake_image=fake_image,
|
||||
vision_start_id=tokenizer.convert_tokens_to_ids("<|vision_start|>"),
|
||||
vision_end_id=tokenizer.convert_tokens_to_ids("<|vision_end|>"),
|
||||
image_pad_id=tokenizer.convert_tokens_to_ids("<|image_pad|>"),
|
||||
)
|
||||
expected_position_ids = _get_expected_position_ids(
|
||||
packing_params,
|
||||
data_collator.get_rope_func,
|
||||
features[0]["input_ids"],
|
||||
features[0]["attention_mask"],
|
||||
)
|
||||
batch_input = data_collator(features) # [3, bsz, seq_len]
|
||||
valid_len = expected_position_ids.shape[-1]
|
||||
assert batch_input["position_ids"][1:, :, :valid_len].eq(expected_position_ids).all()
|
||||
|
||||
|
||||
@pytest.mark.runs_on(["cpu"])
|
||||
def test_4d_attention_mask():
|
||||
o = 0.0
|
||||
|
||||
62
tests_v1/plugins/model_plugins/test_ulysses_cp.py
Normal file
62
tests_v1/plugins/model_plugins/test_ulysses_cp.py
Normal file
@@ -0,0 +1,62 @@
|
||||
# Copyright 2025 the LlamaFactory team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
import torch.multiprocessing as mp
|
||||
|
||||
from llamafactory.v1.accelerator.interface import DistributedInterface
|
||||
from llamafactory.v1.config.model_args import ModelArguments
|
||||
from llamafactory.v1.core.model_engine import ModelEngine
|
||||
from llamafactory.v1.plugins.model_plugins.parallelization.sequence_parallel import (
|
||||
SequenceParallelModelPlugin,
|
||||
sequence_parallel_loss,
|
||||
)
|
||||
from llamafactory.v1.utils.env import find_available_port
|
||||
from llamafactory.v1.utils.pytest import dist_env
|
||||
|
||||
|
||||
def _test_sequence_parallel_loss(local_rank: int, world_size: int, master_port: int, cp_size: int, dp_size: int):
|
||||
with dist_env(local_rank, world_size, master_port):
|
||||
model_args = ModelArguments(model="llamafactory/tiny-random-qwen3")
|
||||
|
||||
# Initialize distributed interface with config
|
||||
dist_config = {"cp_mode": "ulysses", "cp_size": cp_size, "dp_size": dp_size}
|
||||
DistributedInterface(dist_config)
|
||||
|
||||
# Now create model engine
|
||||
model_engine = ModelEngine(model_args=model_args)
|
||||
|
||||
# Apply sequence parallel plugin
|
||||
SequenceParallelModelPlugin(dist_config.get("cp_mode", "ulysses"))(model_engine.model, dist_config)
|
||||
|
||||
model_inputs = {
|
||||
"input_ids": torch.tensor([[1, 2, 3, 4, 5]]),
|
||||
"labels": torch.tensor([[1, 2, 3, 4, 5]]),
|
||||
"attention_mask": torch.tensor([[1, 1, 1, 1, 1]]),
|
||||
"position_ids": torch.tensor([[1, 2, 3, 4, 5]]),
|
||||
"loss_weights": torch.tensor([[1.0, 1.0, 1.0, 1.0, 1.0]]),
|
||||
}
|
||||
|
||||
loss = sequence_parallel_loss(model_engine.model, model_inputs)
|
||||
assert loss is not None
|
||||
|
||||
|
||||
@pytest.mark.runs_on(["cuda", "npu"])
|
||||
@pytest.mark.require_distributed(2)
|
||||
@pytest.mark.parametrize("cp_size, dp_size", [(2, 1)])
|
||||
def test_sequence_parallel_loss(cp_size, dp_size):
|
||||
master_port = find_available_port()
|
||||
world_size = cp_size * dp_size
|
||||
mp.spawn(_test_sequence_parallel_loss, args=(world_size, master_port, cp_size, dp_size), nprocs=world_size)
|
||||
104
tests_v1/plugins/trainer_plugins/distributed/test_fsdp2.py
Normal file
104
tests_v1/plugins/trainer_plugins/distributed/test_fsdp2.py
Normal file
@@ -0,0 +1,104 @@
|
||||
# Copyright 2025 the LlamaFactory team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Unit tests: FSDP2 meta-device loading vs normal loading consistency.
|
||||
|
||||
Validates that the FSDP2 meta loading path behaves correctly for tied weights
|
||||
and non-persistent buffers by comparing it with the standard non-meta path.
|
||||
"""
|
||||
|
||||
import torch
|
||||
from transformers import AutoConfig
|
||||
|
||||
from llamafactory.v1.accelerator.interface import DistributedInterface
|
||||
from llamafactory.v1.config.arg_parser import get_args
|
||||
from llamafactory.v1.core.model_engine import ModelEngine
|
||||
from llamafactory.v1.plugins.trainer_plugins.distributed.fsdp2 import FSDP2Engine
|
||||
|
||||
|
||||
TINY_MODEL = "llamafactory/tiny-random-qwen3"
|
||||
|
||||
|
||||
def collect_non_persistent_buffers(model):
|
||||
"""Collect all non-persistent buffers from model."""
|
||||
result = {}
|
||||
for mod_name, module in model.named_modules():
|
||||
for buf_name in getattr(module, "_non_persistent_buffers_set", set()):
|
||||
fqn = f"{mod_name}.{buf_name}" if mod_name else buf_name
|
||||
buf = getattr(module, buf_name, None)
|
||||
if buf is not None:
|
||||
result[fqn] = buf.detach().cpu().clone()
|
||||
return result
|
||||
|
||||
|
||||
def test_fsdp2_meta_loading_buffers_and_tied_weights():
|
||||
"""Verify non-persistent buffers and tied weights consistency after meta load."""
|
||||
# 1. Initialize DistributedInterface for single process
|
||||
DistributedInterface()
|
||||
|
||||
# 2. Build FSDP2Engine config
|
||||
engine = FSDP2Engine(
|
||||
{
|
||||
"name": "fsdp2",
|
||||
"mixed_precision": "bf16",
|
||||
"reshard_after_forward": True,
|
||||
"offload_params": False,
|
||||
"pin_memory": False,
|
||||
"dcp_path": None,
|
||||
}
|
||||
)
|
||||
|
||||
config = AutoConfig.from_pretrained(TINY_MODEL)
|
||||
|
||||
# --- NORMAL PATH ---
|
||||
normal_args, *_ = get_args(dict(model=TINY_MODEL, init_config=None))
|
||||
normal_engine = ModelEngine(model_args=normal_args)
|
||||
normal_model = normal_engine.model.to(torch.bfloat16)
|
||||
|
||||
normal_model = engine.shard_model(normal_model)
|
||||
normal_non_persistent = collect_non_persistent_buffers(normal_model)
|
||||
|
||||
del normal_model
|
||||
|
||||
# --- META PATH ---
|
||||
meta_args, *_ = get_args(dict(model=TINY_MODEL, init_config={"name": "init_on_meta"}))
|
||||
meta_model_engine = ModelEngine(model_args=meta_args)
|
||||
meta_model = meta_model_engine.model
|
||||
|
||||
assert meta_model.device.type == "meta", "Model should be on meta device"
|
||||
|
||||
# Process meta device: save buffers -> tie_weights -> load from checkpoint -> restore buffers
|
||||
meta_model = engine.shard_model(meta_model)
|
||||
meta_non_persistent = collect_non_persistent_buffers(meta_model)
|
||||
|
||||
# 3. Tied weights (embed_tokens.weight and lm_head.weight)
|
||||
|
||||
tie_word_embeddings = getattr(config, "tie_word_embeddings", False)
|
||||
if tie_word_embeddings:
|
||||
assert meta_model.lm_head.weight is meta_model.model.embed_tokens.weight, (
|
||||
"Weights should be tied after loading"
|
||||
)
|
||||
|
||||
del meta_model
|
||||
|
||||
# 4. Non-persistent buffers (e.g., inv_freq)
|
||||
normal_buf_keys = set(normal_non_persistent.keys())
|
||||
meta_buf_keys = set(meta_non_persistent.keys())
|
||||
assert normal_buf_keys == meta_buf_keys, "Non-persistent buffer keys mismatch"
|
||||
|
||||
for key in sorted(normal_buf_keys & meta_buf_keys):
|
||||
nb = normal_non_persistent[key]
|
||||
mb = meta_non_persistent[key]
|
||||
assert nb.shape == mb.shape, f"Buffer shape mismatch: {key}"
|
||||
assert torch.allclose(nb.float(), mb.float(), atol=1e-5), f"Buffer value mismatch: {key}"
|
||||
Reference in New Issue
Block a user