From b5afabe3d22590ba7093a38aa453edcccbccbbf1 Mon Sep 17 00:00:00 2001 From: sunyi0505 <1659275352@qq.com> Date: Fri, 27 Mar 2026 16:22:48 +0800 Subject: [PATCH] [v1] support ulysses cp for fsdp2 (#10262) --- .../v1/train_full/train_full_ulysses_cp.yaml | 23 ++ src/llamafactory/v1/core/base_trainer.py | 46 +++- src/llamafactory/v1/core/utils/rendering.py | 2 + .../model_plugins/parallelization/seq_comm.py | 59 ++++++ .../parallelization/sequence_parallel.py | 199 ++++++++++++++++++ .../model_plugins/parallelization/ulysses.py | 163 ++++++++++++++ .../trainer_plugins/distributed/fsdp2.py | 5 +- .../plugins/model_plugins/test_ulysses_cp.py | 62 ++++++ 8 files changed, 552 insertions(+), 7 deletions(-) create mode 100644 examples/v1/train_full/train_full_ulysses_cp.yaml create mode 100644 src/llamafactory/v1/plugins/model_plugins/parallelization/seq_comm.py create mode 100644 src/llamafactory/v1/plugins/model_plugins/parallelization/sequence_parallel.py create mode 100644 src/llamafactory/v1/plugins/model_plugins/parallelization/ulysses.py create mode 100644 tests_v1/plugins/model_plugins/test_ulysses_cp.py diff --git a/examples/v1/train_full/train_full_ulysses_cp.yaml b/examples/v1/train_full/train_full_ulysses_cp.yaml new file mode 100644 index 000000000..2b7ba713a --- /dev/null +++ b/examples/v1/train_full/train_full_ulysses_cp.yaml @@ -0,0 +1,23 @@ +model: Qwen/Qwen3-0.6B +trust_remote_code: true +model_class: llm + +template: qwen3_nothink + +# FSDP Config +dist_config: + name: fsdp2 + dcp_path: null + cp_mode: ulysses + cp_size: 2 + +### data +train_dataset: data/v1_sft_demo.yaml + +### training +output_dir: outputs/test_ulysses_cp +micro_batch_size: 1 +cutoff_len: 2048 +learning_rate: 1.0e-4 +bf16: false +max_steps: 10 diff --git a/src/llamafactory/v1/core/base_trainer.py b/src/llamafactory/v1/core/base_trainer.py index cee4abb81..a8afffec1 100644 --- a/src/llamafactory/v1/core/base_trainer.py +++ b/src/llamafactory/v1/core/base_trainer.py @@ -71,6 +71,7 @@ 
class BaseTrainer: # cached variables self.device = DistributedInterface().current_device self.dp_size = DistributedInterface().get_world_size(Dim.DP) + self.cp_size = DistributedInterface().get_world_size(Dim.CP) self.model_input_names = self.renderer.processor.model_input_names self._create_batch_generator() @@ -114,6 +115,21 @@ class BaseTrainer: # Callbacks: TrainerState tracks progress across the full run. self.state = TrainerState(num_training_steps=self.num_training_steps) + if self.args.dist_config is not None and self.args.dist_config.get("cp_size", 1) > 1: + # qwen3.5 is not supported because of the different attention implementation, which will be supported in the future. + if model.config.model_type == "qwen3_5": + raise RuntimeError( + "Sequence parallel is not supported for qwen3.5 model due to its different attention implementation, which will be supported in the future." + ) + from ..plugins.model_plugins.parallelization.sequence_parallel import SequenceParallelModelPlugin + + if model.config._attn_implementation != "flash_attention_2": + logger.warning_rank0( + "Sequence parallelism is optimized for flash attention only. Replace the attention implementation to flash_attention_2." 
+ ) + model.config._attn_implementation = "flash_attention_2" + SequenceParallelModelPlugin(self.args.dist_config.get("cp_mode", "ulysses"))(model, self.args.dist_config) + def _create_batch_generator(self) -> None: self.train_batch_generator = BatchGenerator( dataset=self.train_dataset, @@ -172,7 +188,7 @@ class BaseTrainer: """ batch_size, _ = batch["labels"].shape model_inputs = { - k: v.to(self.device, non_blocking=True) for k, v in batch.items() if k in self.model_input_names + k: v.to(self.device, non_blocking=True) for k, v in batch.items() if isinstance(v, torch.Tensor) } labels = batch["labels"].to(self.device, non_blocking=True) outputs: ModelOutput = model(**model_inputs) @@ -206,7 +222,14 @@ class BaseTrainer: step_valid_tokens = DistributedInterface().all_reduce(step_valid_tokens, op=ReduceOp.SUM) num_micro = len(micro_batches) for i, micro_batch in enumerate(micro_batches): - loss = self.compute_loss(micro_batch) + if self.args.dist_config and self.args.dist_config.get("cp_size", 1) > 1: + from ..plugins.model_plugins.parallelization.sequence_parallel import ( + SequenceParallelLossPlugin, + ) + + loss = SequenceParallelLossPlugin("sequence_parallel_loss")(self.model, micro_batch) + else: + loss = self.compute_loss(micro_batch) mini_step_valid_tokens = compute_valid_tokens([micro_batch]) # fsdp uses mean reduction so we need to scale the loss by dp_size loss = loss * mini_step_valid_tokens * self.dp_size / (step_valid_tokens + 1e-6) @@ -223,7 +246,24 @@ class BaseTrainer: # deepspeed: engine.step() already ran inside backward at the sync boundary grad_norm = self._deepspeed_engine.get_grad_norm() else: - grad_norm = torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.args.max_grad_norm).item() + if self.args.dist_config and self.args.dist_config.get("cp_size", 1) > 1: + from torch.nn.utils.clip_grad import _clip_grads_with_norm_, _get_total_norm + + parameters = self.model.parameters() + if isinstance(parameters, torch.Tensor): + parameters = 
[parameters] + else: + parameters = list(parameters) + grads = [p.grad for p in parameters if p.grad is not None] + grad_norm = _get_total_norm(grads) + grad_norm = grad_norm.to(self.device) + _clip_grads_with_norm_(parameters, self.args.max_grad_norm, grad_norm) + if isinstance(grad_norm, torch.distributed._tensor.DTensor): + grad_norm = grad_norm.full_tensor().item() + else: + grad_norm = torch.nn.utils.clip_grad_norm_( + self.model.parameters(), self.args.max_grad_norm + ).item() # isfinite(): argument 'input' (position 1) must be Tensor, not float if not torch.isfinite(torch.tensor(grad_norm)): # type: ignore # pyright: ignore [reportUnknownReturnType] diff --git a/src/llamafactory/v1/core/utils/rendering.py b/src/llamafactory/v1/core/utils/rendering.py index 2cb22bfd6..b49d1bf99 100644 --- a/src/llamafactory/v1/core/utils/rendering.py +++ b/src/llamafactory/v1/core/utils/rendering.py @@ -146,6 +146,8 @@ class Renderer: for sample in samples: if "messages" in sample: model_input = self.render_messages(sample["messages"], sample.get("tools")) + if "position_ids" not in model_input: + model_input["position_ids"] = list(range(1, len(model_input["input_ids"]) + 1)) elif "chosen_messages" in sample and "rejected_messages" in sample: chosen_input = self.render_messages(sample["chosen_messages"], sample.get("tools")) rejected_input = self.render_messages(sample["rejected_messages"], sample.get("tools")) diff --git a/src/llamafactory/v1/plugins/model_plugins/parallelization/seq_comm.py b/src/llamafactory/v1/plugins/model_plugins/parallelization/seq_comm.py new file mode 100644 index 000000000..3460c1394 --- /dev/null +++ b/src/llamafactory/v1/plugins/model_plugins/parallelization/seq_comm.py @@ -0,0 +1,59 @@ +# Copyright 2025 Bytedance Ltd. and/or its affiliates. and the LlamaFactory team. +# +# This code is inspired by the Bytedance's verl library. 
+# https://github.com/verl-project/verl/blob/77476af84cc074edf5a6437f8d5ea418d7a54916/verl/utils/ulysses.py +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Any, Optional + +import torch +import torch.distributed as dist +from torch import Tensor + + +def all_to_all_tensor( + local_input: Tensor, + scatter_dim: int, + gather_dim: int, + group: Optional[dist.ProcessGroup] = None, +): + seq_world_size = dist.get_world_size(group) + input_list = [t.contiguous() for t in torch.tensor_split(local_input, seq_world_size, scatter_dim)] + output_list = [torch.empty_like(input_list[0]) for _ in range(seq_world_size)] + dist.all_to_all(output_list, input_list, group=group) + return torch.cat(output_list, dim=gather_dim).contiguous() + + +class SeqAllToAll4D(torch.autograd.Function): + @staticmethod + def forward( + ctx: Any, + group: dist.ProcessGroup, + local_input: Tensor, + scatter_dim: int, + gather_dim: int, + ) -> Tensor: + ctx.group = group + ctx.scatter_dim = scatter_dim + ctx.gather_dim = gather_dim + return all_to_all_tensor(local_input, scatter_dim, gather_dim, group) + + @staticmethod + def backward(ctx: Any, *grad_output: Tensor) -> tuple[None, Tensor, None, None]: + return ( + None, + all_to_all_tensor(grad_output[0], ctx.gather_dim, ctx.scatter_dim, ctx.group), + None, + None, + ) diff --git a/src/llamafactory/v1/plugins/model_plugins/parallelization/sequence_parallel.py 
b/src/llamafactory/v1/plugins/model_plugins/parallelization/sequence_parallel.py new file mode 100644 index 000000000..35d7b0323 --- /dev/null +++ b/src/llamafactory/v1/plugins/model_plugins/parallelization/sequence_parallel.py @@ -0,0 +1,199 @@ +# Copyright 2025 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import sys +from functools import partial + +import torch +import torch.distributed as dist +import torch.nn.functional as F +import transformers + +from ....accelerator.interface import Dim, DistributedInterface +from ....utils import logging +from ....utils.plugin import BasePlugin +from ....utils.types import ModelOutput +from .ulysses import ( + UlyssesAttention, + get_ulysses_sequence_parallel_group, + get_ulysses_sequence_parallel_rank, + get_ulysses_sequence_parallel_world_size, + set_ulysses_sequence_parallel_group, +) + + +logger = logging.get_logger(__name__) + + +class SequenceParallelModelPlugin(BasePlugin): + def __call__(self, model, model_args): + return super().__call__(model, model_args) + + +class SequenceParallelLossPlugin(BasePlugin): + def __call__(self, model, inputs, *args, **kwargs): + return super().__call__(model, inputs, *args, **kwargs) + + +def new_flash_attn_forward( + query_states, + key_states, + value_states, + attention_mask, + sequence_parallel_size=1, + dropout=0, + deterministic=False, + is_causal=True, + group=None, + mode="ulysses", + attn_fn=None, + target_dtype=None, + **kwargs, +): + if mode ==
"ulysses": + dist_attn = UlyssesAttention(sequence_process_group=group, attn_fn=attn_fn) + attn_output = dist_attn( + query_states, + key_states, + value_states, + attention_mask, + query_length=query_states.shape[1] * sequence_parallel_size, + deterministic=deterministic, + dropout_p=dropout, + causal=is_causal, + position_ids=kwargs.get("position_ids", None), + target_dtype=target_dtype, + ) + else: + raise NotImplementedError("Other sequence parallel modes are to be implemented.") + + return attn_output + + +@SequenceParallelModelPlugin("ulysses").register() +def apply_sequence_parallel(model, model_args): + # Replace _flash_attention_forward with new_flash_attn_forward + module = sys.modules[model.__module__] + cp_size = model_args.get("cp_size", 1) + + set_ulysses_sequence_parallel_group(DistributedInterface().get_group(Dim.CP)) + + try: + num_attention_heads, num_key_value_heads = model.config.num_attention_heads, model.config.num_key_value_heads + except AttributeError: + num_attention_heads, num_key_value_heads = ( + model.config.text_config.num_attention_heads, + model.config.text_config.num_key_value_heads, + ) + + assert num_attention_heads % cp_size == 0, "num_attention_heads must be divisible by cp_size" + assert num_key_value_heads % cp_size == 0 or cp_size % num_key_value_heads == 0, ( + "num_key_value_heads must be divisible by cp_size" + ) + + origin_attn = transformers.modeling_flash_attention_utils._flash_attention_forward + new_flash_attention_forward = partial( + new_flash_attn_forward, + group=get_ulysses_sequence_parallel_group(), + mode="ulysses", + attn_fn=origin_attn, + sequence_parallel_size=cp_size, + ) + + for module_name, module in list(sys.modules.items()): + try: + if ( + hasattr(module, "__file__") + and "transformers" in module.__file__ + and getattr(module._flash_attention_forward, "__name__", "") == "_flash_attention_forward" + ): + module._flash_attention_forward = new_flash_attention_forward + logger.info_rank0( + f"Replaced 
_flash_attention_forward in module {module_name} with new_flash_attn_forward for sequence parallel." + ) + except (AttributeError, TypeError): + continue + + +def padding_and_split_data(data, device_mesh=None): + if device_mesh is not None: + cp_size = device_mesh["cp"].size() + cp_rank = device_mesh["cp"].get_local_rank() + cp_group = device_mesh["cp"].get_group() + for k, v in data.items(): + if isinstance(v, torch.Tensor) and v.ndim > 1: + data_len = torch.tensor(v.shape[-1], device=v.device, dtype=torch.int64) + global_data_len = [torch.empty_like(data_len) for _ in range(cp_size)] + dist.all_gather(global_data_len, data_len, group=cp_group) + max_data_len = max(global_data_len) + pad_size = max_data_len - v.shape[-1] + (cp_size - max_data_len % cp_size) % cp_size + if k == "labels": + pad_value = -100 + elif k == "loss_weights": + pad_value = 0.0 + else: + pad_value = 0 + pad_data = F.pad(v, (0, pad_size), value=pad_value) + data[k] = torch.chunk(pad_data, chunks=cp_size, dim=-1)[cp_rank].contiguous() + return data + + +@SequenceParallelLossPlugin("sequence_parallel_loss").register() +def sequence_parallel_loss(model, model_inputs): + device_mesh = DistributedInterface().get_device_mesh(Dim.CP) + device = DistributedInterface().current_device + model_inputs = { + k: v.to(device, non_blocking=True) for k, v in model_inputs.items() if isinstance(v, torch.Tensor) + } + + model_inputs = padding_and_split_data(model_inputs, device_mesh) + + batch_size, _ = model_inputs["labels"].shape + + outputs: ModelOutput = model(**model_inputs) + + logits = outputs.logits.float() + + labels = model_inputs["labels"] + + cp_group = get_ulysses_sequence_parallel_group() + cp_world_size = get_ulysses_sequence_parallel_world_size(cp_group) + cp_rank = get_ulysses_sequence_parallel_rank(cp_group) + + # use all_gather to collect labels from all sequence parallel processes + global_labels = [torch.empty_like(labels) for _ in range(cp_world_size)] + dist.all_gather(global_labels, labels, group=cp_group) + labels = 
torch.cat(global_labels, dim=1).contiguous() + shift_labels = labels[..., 1:].contiguous() + shift_labels = F.pad(shift_labels, (0, 1), value=-100) + shift_labels = torch.chunk(shift_labels, chunks=cp_world_size, dim=-1)[cp_rank].reshape(-1) + + # use all_gather to collect loss_weights from all sequence parallel processes + loss_weights = model_inputs["loss_weights"] + global_loss_weights = [torch.empty_like(loss_weights) for _ in range(cp_world_size)] + dist.all_gather(global_loss_weights, loss_weights, group=cp_group) + shift_loss_weights = torch.cat(global_loss_weights, dim=1).contiguous() + shift_loss_weights = shift_loss_weights[..., 1:].contiguous() + + shift_logits = logits.view(shift_labels.size(0), -1).contiguous() + + # use all_gather to collect log_probs from all sequence parallel processes + log_probs = -F.cross_entropy(shift_logits, shift_labels, reduction="none").view(batch_size, -1) + global_log_probs = dist.nn.all_gather(log_probs, group=cp_group) + global_log_probs = torch.cat(global_log_probs, dim=1).contiguous() + log_probs = global_log_probs[..., :-1].contiguous() + + loss = (-log_probs * shift_loss_weights).sum() / (shift_loss_weights.sum() + 1e-6) + + return loss diff --git a/src/llamafactory/v1/plugins/model_plugins/parallelization/ulysses.py b/src/llamafactory/v1/plugins/model_plugins/parallelization/ulysses.py new file mode 100644 index 000000000..1dcd9be05 --- /dev/null +++ b/src/llamafactory/v1/plugins/model_plugins/parallelization/ulysses.py @@ -0,0 +1,163 @@ +# Copyright 2025 Bytedance Ltd. and/or its affiliates. and the LlamaFactory team. +# +# This code is inspired by the Bytedance's verl library. +# https://github.com/verl-project/verl/blob/77476af84cc074edf5a6437f8d5ea418d7a54916/verl/utils/ulysses.py +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Any, Optional + +import torch +import torch.distributed as dist +from torch import Tensor +from torch.distributed import ProcessGroup + +from .seq_comm import SeqAllToAll4D + + +_ULYSSES_SEQUENCE_PARALLEL_GROUP = None + + +def set_ulysses_sequence_parallel_group(group: dist.ProcessGroup): + """Set ulysses sequence parallel process group.""" + global _ULYSSES_SEQUENCE_PARALLEL_GROUP + _ULYSSES_SEQUENCE_PARALLEL_GROUP = group + + +def get_ulysses_sequence_parallel_group() -> Optional[dist.ProcessGroup]: + """Get ulysses sequence parallel process group.""" + global _ULYSSES_SEQUENCE_PARALLEL_GROUP + return _ULYSSES_SEQUENCE_PARALLEL_GROUP + + +def get_ulysses_sequence_parallel_world_size(group: ProcessGroup = None) -> int: + """Get ulysses sequence parallel world size.""" + group = get_ulysses_sequence_parallel_group() if group is None else group + return dist.get_world_size(group) if group else 1 + + +def get_ulysses_sequence_parallel_rank(group: ProcessGroup = None) -> int: + """Get ulysses sequence parallel rank.""" + group = get_ulysses_sequence_parallel_group() if group is None else group + return dist.get_rank(group) if group else 0 + + +class UlyssesAttention(torch.nn.Module): + """Initialization. 
+ + Arguments: + local_attention (Module): local attention with q,k,v + sequence_process_group (ProcessGroup): sequence parallel process group + scatter_idx (int): scatter_idx for all2all comm + gather_idx (int): gather_idx for all2all comm + attn_type (AttnType): attention type enum + """ + + def __init__( + self, + sequence_process_group: dist.ProcessGroup = None, + scatter_idx: int = 2, + gather_idx: int = 1, + attn_fn: Optional[callable] = None, + ) -> None: + + super().__init__() + self.spg = sequence_process_group + self.scatter_idx = scatter_idx + self.gather_idx = gather_idx + self.attn_fn = attn_fn + + def forward( + self, + query: Tensor, + key: Tensor, + value: Tensor, + attention_mask: torch.Tensor, + query_length: int, + dropout_p=0.0, + softmax_scale=None, + position_ids: Optional[torch.Tensor] = None, + causal=True, + deterministic=False, + target_dtype=None, + *args: Any, + ) -> Tensor: + """Forward. + + Arguments: + query (Tensor): query input to the layer + key (Tensor): key input to the layer + value (Tensor): value input to the layer + attention_mask (Tensor): attention mask for the layer + query_length (int): the length of the query sequence + dropout_p (float, optional): dropout probability. Defaults to 0.0. + softmax_scale (float, optional): scale factor for softmax. Defaults to None, + position_ids (torch.Tensor, optional): position ids for the attention. Defaults to None. + causal (bool, optional): whether to apply causal mask. Defaults to True. + deterministic (bool, optional): whether to apply dropout in deterministic way. Defaults to False. + target_dtype (torch.dtype, optional): target dtype for attention output. Defaults to None. + args: other args + + Returns: + * output (Tensor): context output + """ + # TODO Merge three alltoall calls into one + # TODO (Reza): change the api on the megatron-deepspeed side so that we only receive all data (q,k, and v) together! 
+ # in shape : e.g., [s/p:h:] + # (bs, seq_len/N, head_cnt, head_size) -> (bs, seq_len, head_cnt/N, head_size) + + # scatter 2, gather 1 + q = SeqAllToAll4D.apply(self.spg, query, self.scatter_idx, self.gather_idx) + k = SeqAllToAll4D.apply(self.spg, key, self.scatter_idx, self.gather_idx) + v = SeqAllToAll4D.apply(self.spg, value, self.scatter_idx, self.gather_idx) + + if softmax_scale is None: + softmax_scale = q.shape[-1] ** -0.5 + + if attention_mask is None: + if position_ids is not None: + attention_mask = torch.ones_like(position_ids).to(torch.int64) + else: + attention_mask = torch.ones(query.shape[0], query.shape[1], dtype=torch.int64, device=query.device) + else: + attention_mask = attention_mask.to(torch.int64) + + global_attention_mask = [ + torch.empty_like(attention_mask) for _ in range(get_ulysses_sequence_parallel_world_size(self.spg)) + ] + dist.all_gather(global_attention_mask, attention_mask, group=self.spg) + attention_mask = torch.cat(global_attention_mask, dim=1) + + context_layer = self.attn_fn( + q, + k, + v, + attention_mask, + query_length=query_length, + is_causal=causal, + dropout=dropout_p, + position_ids=position_ids, + softmax_scale=softmax_scale, + deterministic=deterministic, + target_dtype=target_dtype, + ) + + if isinstance(context_layer, tuple): + context_layer = context_layer[0] + + # (bs, seq_len, head_cnt/N, head_size) -> (bs, seq_len/N, head_cnt, head_size) + # scatter 1, gather 2 + output = SeqAllToAll4D.apply(self.spg, context_layer, self.gather_idx, self.scatter_idx) + + # out e.g., [s/p::h] + return output diff --git a/src/llamafactory/v1/plugins/trainer_plugins/distributed/fsdp2.py b/src/llamafactory/v1/plugins/trainer_plugins/distributed/fsdp2.py index bf6b09b87..7ba7130dc 100644 --- a/src/llamafactory/v1/plugins/trainer_plugins/distributed/fsdp2.py +++ b/src/llamafactory/v1/plugins/trainer_plugins/distributed/fsdp2.py @@ -85,10 +85,7 @@ class FSDP2Engine: ) if self.device_mesh is not None: - try: - self.fsdp_mesh = 
self.device_mesh["dp"] - except Exception: - self.fsdp_mesh = self.device_mesh + self.fsdp_mesh = self.device_mesh logger.info(f"Using Device Mesh: {self.fsdp_mesh}") else: diff --git a/tests_v1/plugins/model_plugins/test_ulysses_cp.py b/tests_v1/plugins/model_plugins/test_ulysses_cp.py new file mode 100644 index 000000000..8de66c027 --- /dev/null +++ b/tests_v1/plugins/model_plugins/test_ulysses_cp.py @@ -0,0 +1,62 @@ +# Copyright 2025 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import pytest +import torch +import torch.multiprocessing as mp + +from llamafactory.v1.accelerator.interface import DistributedInterface +from llamafactory.v1.config.model_args import ModelArguments +from llamafactory.v1.core.model_engine import ModelEngine +from llamafactory.v1.plugins.model_plugins.parallelization.sequence_parallel import ( + SequenceParallelModelPlugin, + sequence_parallel_loss, +) +from llamafactory.v1.utils.env import find_available_port +from llamafactory.v1.utils.pytest import dist_env + + +def _test_sequence_parallel_loss(local_rank: int, world_size: int, master_port: int, cp_size: int, dp_size: int): + with dist_env(local_rank, world_size, master_port): + model_args = ModelArguments(model="llamafactory/tiny-random-qwen3") + + # Initialize distributed interface with config + dist_config = {"cp_mode": "ulysses", "cp_size": cp_size, "dp_size": dp_size} + DistributedInterface(dist_config) + + # Now create model engine + model_engine = ModelEngine(model_args=model_args) + + # Apply sequence parallel plugin + SequenceParallelModelPlugin(dist_config.get("cp_mode", "ulysses"))(model_engine.model, dist_config) + + model_inputs = { + "input_ids": torch.tensor([[1, 2, 3, 4, 5]]), + "labels": torch.tensor([[1, 2, 3, 4, 5]]), + "attention_mask": torch.tensor([[1, 1, 1, 1, 1]]), + "position_ids": torch.tensor([[1, 2, 3, 4, 5]]), + "loss_weights": torch.tensor([[1.0, 1.0, 1.0, 1.0, 1.0]]), + } + + loss = sequence_parallel_loss(model_engine.model, model_inputs) + assert loss is not None + + +@pytest.mark.runs_on(["cuda", "npu"]) +@pytest.mark.require_distributed(2) +@pytest.mark.parametrize("cp_size, dp_size", [(2, 1)]) +def test_sequence_parallel_loss(cp_size, dp_size): + master_port = find_available_port() + world_size = cp_size * dp_size + mp.spawn(_test_sequence_parallel_loss, args=(world_size, master_port, cp_size, dp_size), nprocs=world_size)