mirror of
https://github.com/hiyouga/LlamaFactory.git
synced 2026-03-29 14:43:08 +00:00
[v1] support ulysses cp for fsdp2 (#10262)
This commit is contained in:
23
examples/v1/train_full/train_full_ulysses_cp.yaml
Normal file
23
examples/v1/train_full/train_full_ulysses_cp.yaml
Normal file
@@ -0,0 +1,23 @@
|
|||||||
|
model: Qwen/Qwen3-0.6B
|
||||||
|
trust_remote_code: true
|
||||||
|
model_class: llm
|
||||||
|
|
||||||
|
template: qwen3_nothink
|
||||||
|
|
||||||
|
# FSDP Config
|
||||||
|
dist_config:
|
||||||
|
name: fsdp2
|
||||||
|
dcp_path: null
|
||||||
|
cp_mode: ulysses
|
||||||
|
cp_size: 2
|
||||||
|
|
||||||
|
### data
|
||||||
|
train_dataset: data/v1_sft_demo.yaml
|
||||||
|
|
||||||
|
### training
|
||||||
|
output_dir: outputs/test_ulysses_cp
|
||||||
|
micro_batch_size: 1
|
||||||
|
cutoff_len: 2048
|
||||||
|
learning_rate: 1.0e-4
|
||||||
|
bf16: false
|
||||||
|
max_steps: 10
|
||||||
@@ -71,6 +71,7 @@ class BaseTrainer:
|
|||||||
# cached variables
|
# cached variables
|
||||||
self.device = DistributedInterface().current_device
|
self.device = DistributedInterface().current_device
|
||||||
self.dp_size = DistributedInterface().get_world_size(Dim.DP)
|
self.dp_size = DistributedInterface().get_world_size(Dim.DP)
|
||||||
|
self.cp_size = DistributedInterface().get_world_size(Dim.CP)
|
||||||
self.model_input_names = self.renderer.processor.model_input_names
|
self.model_input_names = self.renderer.processor.model_input_names
|
||||||
|
|
||||||
self._create_batch_generator()
|
self._create_batch_generator()
|
||||||
@@ -114,6 +115,21 @@ class BaseTrainer:
|
|||||||
# Callbacks: TrainerState tracks progress across the full run.
|
# Callbacks: TrainerState tracks progress across the full run.
|
||||||
self.state = TrainerState(num_training_steps=self.num_training_steps)
|
self.state = TrainerState(num_training_steps=self.num_training_steps)
|
||||||
|
|
||||||
|
if self.args.dist_config is not None and self.args.dist_config.get("cp_size", 1) > 1:
|
||||||
|
# qwen3.5 is not supported because of the different attention implementation, which will be supported in the future.
|
||||||
|
if model.config.model_type == "qwen3_5":
|
||||||
|
raise RuntimeError(
|
||||||
|
"Sequence parallel is not supported for qwen3.5 model due to its different attention implementation, which will be supported in the future."
|
||||||
|
)
|
||||||
|
from ..plugins.model_plugins.parallelization.sequence_parallel import SequenceParallelModelPlugin
|
||||||
|
|
||||||
|
if model.config._attn_implementation != "flash_attention_2":
|
||||||
|
logger.warning_rank0(
|
||||||
|
"Sequence parallelism is optimized for flash attention only. Replace the attention implementation to flash_attention_2."
|
||||||
|
)
|
||||||
|
model.config._attn_implementation = "flash_attention_2"
|
||||||
|
SequenceParallelModelPlugin(self.args.dist_config.get("cp_mode", "ulysses"))(model, self.args.dist_config)
|
||||||
|
|
||||||
def _create_batch_generator(self) -> None:
|
def _create_batch_generator(self) -> None:
|
||||||
self.train_batch_generator = BatchGenerator(
|
self.train_batch_generator = BatchGenerator(
|
||||||
dataset=self.train_dataset,
|
dataset=self.train_dataset,
|
||||||
@@ -172,7 +188,7 @@ class BaseTrainer:
|
|||||||
"""
|
"""
|
||||||
batch_size, _ = batch["labels"].shape
|
batch_size, _ = batch["labels"].shape
|
||||||
model_inputs = {
|
model_inputs = {
|
||||||
k: v.to(self.device, non_blocking=True) for k, v in batch.items() if k in self.model_input_names
|
k: v.to(self.device, non_blocking=True) for k, v in batch.items() if isinstance(v, torch.Tensor)
|
||||||
}
|
}
|
||||||
labels = batch["labels"].to(self.device, non_blocking=True)
|
labels = batch["labels"].to(self.device, non_blocking=True)
|
||||||
outputs: ModelOutput = model(**model_inputs)
|
outputs: ModelOutput = model(**model_inputs)
|
||||||
@@ -206,7 +222,14 @@ class BaseTrainer:
|
|||||||
step_valid_tokens = DistributedInterface().all_reduce(step_valid_tokens, op=ReduceOp.SUM)
|
step_valid_tokens = DistributedInterface().all_reduce(step_valid_tokens, op=ReduceOp.SUM)
|
||||||
num_micro = len(micro_batches)
|
num_micro = len(micro_batches)
|
||||||
for i, micro_batch in enumerate(micro_batches):
|
for i, micro_batch in enumerate(micro_batches):
|
||||||
loss = self.compute_loss(micro_batch)
|
if self.args.dist_config and self.args.dist_config.get("cp_size", 1) > 1:
|
||||||
|
from ..plugins.model_plugins.parallelization.sequence_parallel import (
|
||||||
|
SequenceParallelLossPlugin,
|
||||||
|
)
|
||||||
|
|
||||||
|
loss = SequenceParallelLossPlugin("sequence_parallel_loss")(self.model, micro_batch)
|
||||||
|
else:
|
||||||
|
loss = self.compute_loss(micro_batch)
|
||||||
mini_step_valid_tokens = compute_valid_tokens([micro_batch])
|
mini_step_valid_tokens = compute_valid_tokens([micro_batch])
|
||||||
# fsdp uses mean reduction so we need to scale the loss by dp_size
|
# fsdp uses mean reduction so we need to scale the loss by dp_size
|
||||||
loss = loss * mini_step_valid_tokens * self.dp_size / (step_valid_tokens + 1e-6)
|
loss = loss * mini_step_valid_tokens * self.dp_size / (step_valid_tokens + 1e-6)
|
||||||
@@ -223,7 +246,24 @@ class BaseTrainer:
|
|||||||
# deepspeed: engine.step() already ran inside backward at the sync boundary
|
# deepspeed: engine.step() already ran inside backward at the sync boundary
|
||||||
grad_norm = self._deepspeed_engine.get_grad_norm()
|
grad_norm = self._deepspeed_engine.get_grad_norm()
|
||||||
else:
|
else:
|
||||||
grad_norm = torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.args.max_grad_norm).item()
|
if self.args.dist_config and self.args.dist_config.get("cp_size", 1) > 1:
|
||||||
|
from torch.nn.utils.clip_grad import _clip_grads_with_norm_, _get_total_norm
|
||||||
|
|
||||||
|
parameters = self.model.parameters()
|
||||||
|
if isinstance(parameters, torch.Tensor):
|
||||||
|
parameters = [parameters]
|
||||||
|
else:
|
||||||
|
parameters = list(parameters)
|
||||||
|
grads = [p.grad for p in parameters if p.grad is not None]
|
||||||
|
grad_norm = _get_total_norm(grads)
|
||||||
|
grad_norm = grad_norm.to(self.device)
|
||||||
|
_clip_grads_with_norm_(parameters, self.args.max_grad_norm, grad_norm)
|
||||||
|
if isinstance(grad_norm, torch.distributed._tensor.DTensor):
|
||||||
|
grad_norm = grad_norm.full_tensor().item()
|
||||||
|
else:
|
||||||
|
grad_norm = torch.nn.utils.clip_grad_norm_(
|
||||||
|
self.model.parameters(), self.args.max_grad_norm
|
||||||
|
).item()
|
||||||
|
|
||||||
# isfinite(): argument 'input' (position 1) must be Tensor, not float
|
# isfinite(): argument 'input' (position 1) must be Tensor, not float
|
||||||
if not torch.isfinite(torch.tensor(grad_norm)): # type: ignore # pyright: ignore [reportUnknownReturnType]
|
if not torch.isfinite(torch.tensor(grad_norm)): # type: ignore # pyright: ignore [reportUnknownReturnType]
|
||||||
|
|||||||
@@ -146,6 +146,8 @@ class Renderer:
|
|||||||
for sample in samples:
|
for sample in samples:
|
||||||
if "messages" in sample:
|
if "messages" in sample:
|
||||||
model_input = self.render_messages(sample["messages"], sample.get("tools"))
|
model_input = self.render_messages(sample["messages"], sample.get("tools"))
|
||||||
|
if "position_ids" not in model_input:
|
||||||
|
model_input["position_ids"] = list(range(1, len(model_input["input_ids"]) + 1))
|
||||||
elif "chosen_messages" in sample and "rejected_messages" in sample:
|
elif "chosen_messages" in sample and "rejected_messages" in sample:
|
||||||
chosen_input = self.render_messages(sample["chosen_messages"], sample.get("tools"))
|
chosen_input = self.render_messages(sample["chosen_messages"], sample.get("tools"))
|
||||||
rejected_input = self.render_messages(sample["rejected_messages"], sample.get("tools"))
|
rejected_input = self.render_messages(sample["rejected_messages"], sample.get("tools"))
|
||||||
|
|||||||
@@ -0,0 +1,59 @@
|
|||||||
|
# Copyright 2025 Bytedance Ltd. and/or its affiliates. and the LlamaFactory team.
|
||||||
|
#
|
||||||
|
# This code is inspired by the Bytedance's verl library.
|
||||||
|
# https://github.com/verl-project/verl/blob/77476af84cc074edf5a6437f8d5ea418d7a54916/verl/utils/ulysses.py
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
from typing import Any, Optional
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.distributed as dist
|
||||||
|
from torch import Tensor
|
||||||
|
|
||||||
|
|
||||||
|
def all_to_all_tensor(
    local_input: Tensor,
    scatter_dim: int,
    gather_dim: int,
    group: Optional[dist.ProcessGroup] = None,
):
    """Perform an all-to-all exchange of ``local_input`` across ``group``.

    The local tensor is split into ``world_size`` contiguous chunks along
    ``scatter_dim``; each rank sends one chunk to every peer and concatenates
    the chunks it receives along ``gather_dim``.

    NOTE: this is a collective call — every rank in ``group`` must enter it,
    or the process will hang.
    """
    seq_world_size = dist.get_world_size(group)
    # tensor_split tolerates uneven sizes; .contiguous() is required because
    # the collective needs contiguous buffers.
    input_list = [t.contiguous() for t in torch.tensor_split(local_input, seq_world_size, scatter_dim)]
    # assumes every chunk has input_list[0]'s shape, i.e. the scatter dim is
    # divisible by the group world size — TODO confirm at call sites
    output_list = [torch.empty_like(input_list[0]) for _ in range(seq_world_size)]
    dist.all_to_all(output_list, input_list, group=group)
    return torch.cat(output_list, dim=gather_dim).contiguous()
|
||||||
|
|
||||||
|
|
||||||
|
class SeqAllToAll4D(torch.autograd.Function):
    """Autograd-aware all-to-all used to redistribute 4-D attention tensors.

    Forward scatters along ``scatter_dim`` and gathers along ``gather_dim``;
    backward performs the inverse exchange (dims swapped) so gradients return
    to their originating ranks.
    """

    @staticmethod
    def forward(
        ctx: Any,
        group: dist.ProcessGroup,
        local_input: Tensor,
        scatter_dim: int,
        gather_dim: int,
    ) -> Tensor:
        # Stash the communication parameters for the backward pass.
        ctx.group = group
        ctx.scatter_dim = scatter_dim
        ctx.gather_dim = gather_dim
        return all_to_all_tensor(local_input, scatter_dim, gather_dim, group)

    @staticmethod
    def backward(ctx: Any, *grad_output: Tensor) -> tuple[None, Tensor, None, None]:
        # Inverse exchange: scatter/gather dims are swapped relative to forward.
        # The three ``None``s correspond to the non-tensor forward inputs
        # (group, scatter_dim, gather_dim).
        return (
            None,
            all_to_all_tensor(grad_output[0], ctx.gather_dim, ctx.scatter_dim, ctx.group),
            None,
            None,
        )
|
||||||
@@ -0,0 +1,199 @@
|
|||||||
|
# Copyright 2025 the LlamaFactory team.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
import sys
|
||||||
|
from functools import partial
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.distributed as dist
|
||||||
|
import torch.nn.functional as F
|
||||||
|
import transformers
|
||||||
|
|
||||||
|
from ....accelerator.interface import Dim, DistributedInterface
|
||||||
|
from ....utils import logging
|
||||||
|
from ....utils.plugin import BasePlugin
|
||||||
|
from ....utils.types import ModelOutput
|
||||||
|
from .ulysses import (
|
||||||
|
UlyssesAttention,
|
||||||
|
get_ulysses_sequence_parallel_group,
|
||||||
|
get_ulysses_sequence_parallel_rank,
|
||||||
|
get_ulysses_sequence_parallel_world_size,
|
||||||
|
set_ulysses_sequence_parallel_group,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class SequenceParallelModelPlugin(BasePlugin):
    """Plugin entry point that patches a model for sequence parallelism.

    Dispatches through ``BasePlugin.__call__`` to the implementation
    registered under the plugin name (e.g. ``apply_sequence_parallel`` is
    registered for ``"ulysses"``).
    """

    def __call__(self, model, model_args):
        # presumably model_args is the dist_config dict (cp_size/cp_mode) —
        # verify against callers.
        return super().__call__(model, model_args)
|
||||||
|
|
||||||
|
|
||||||
|
class SequenceParallelLossPlugin(BasePlugin):
    """Plugin entry point computing the loss under sequence parallelism.

    Dispatches through ``BasePlugin.__call__`` to the implementation
    registered under the plugin name (``sequence_parallel_loss`` is
    registered for ``"sequence_parallel_loss"``).
    """

    def __call__(self, model, inputs, *args, **kwargs):
        return super().__call__(model, inputs, *args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
def new_flash_attn_forward(
    query_states,
    key_states,
    value_states,
    attention_mask,
    sequence_parallel_size=1,
    dropout=0,
    deterministic=False,
    is_causal=True,
    group=None,
    mode="ulysses",
    attn_fn=None,
    target_dtype=None,
    **kwargs,
):
    """Sequence-parallel replacement for transformers' ``_flash_attention_forward``.

    Wraps the original flash-attention function (``attn_fn``) in a
    ``UlyssesAttention`` module that redistributes heads/sequence via
    all-to-all before and after the local attention call.

    Args:
        query_states/key_states/value_states: local attention inputs
            (assumes dim 1 is the sequence dim — TODO confirm).
        attention_mask: local attention mask; may be ``None``.
        sequence_parallel_size: cp group size, used to recover the global
            sequence length from the local shard.
        group: sequence-parallel process group.
        mode: only ``"ulysses"`` is implemented.
        attn_fn: the original (unpatched) flash-attention callable.
    """
    if mode == "ulysses":
        dist_attn = UlyssesAttention(sequence_process_group=group, attn_fn=attn_fn)
        attn_output = dist_attn(
            query_states,
            key_states,
            value_states,
            attention_mask,
            # Each rank holds seq_len / cp_size tokens; restore the global length.
            query_length=query_states.shape[1] * sequence_parallel_size,
            deterministic=deterministic,
            dropout_p=dropout,
            causal=is_causal,
            position_ids=kwargs.get("position_ids", None),
            target_dtype=target_dtype,
        )
    else:
        raise NotImplementedError("Other sequence parallel modes are to be implemented.")

    return attn_output
|
||||||
|
|
||||||
|
|
||||||
|
@SequenceParallelModelPlugin("ulysses").register()
def apply_sequence_parallel(model, model_args):
    """Patch ``model`` so its flash-attention runs under Ulysses sequence parallelism.

    Registers the CP process group, validates head-count divisibility, and
    monkey-patches ``_flash_attention_forward`` in every already-imported
    transformers module with a partial of :func:`new_flash_attn_forward`.

    Args:
        model: a transformers model whose attention uses ``_flash_attention_forward``.
        model_args: dist-config mapping; reads ``cp_size`` (default 1).
    """
    cp_size = model_args.get("cp_size", 1)

    set_ulysses_sequence_parallel_group(DistributedInterface().get_group(Dim.CP))

    # Text models keep head counts directly on config; multimodal models nest
    # them under config.text_config.
    try:
        num_attention_heads = model.config.num_attention_heads
        # BUG FIX: this previously read ``num_attention_heads`` twice, so the
        # GQA (key/value head) divisibility check below never saw the real
        # num_key_value_heads for text-only models.
        num_key_value_heads = model.config.num_key_value_heads
    except AttributeError:
        num_attention_heads, num_key_value_heads = (
            model.config.text_config.num_attention_heads,
            model.config.text_config.num_key_value_heads,
        )

    assert num_attention_heads % cp_size == 0, "num_attention_heads must be divisible by cp_size"
    assert num_key_value_heads % cp_size == 0 or cp_size % num_key_value_heads == 0, (
        "num_key_value_heads must be divisible by cp_size"
    )

    origin_attn = transformers.modeling_flash_attention_utils._flash_attention_forward
    new_flash_attention_forward = partial(
        new_flash_attn_forward,
        group=get_ulysses_sequence_parallel_group(),
        mode="ulysses",
        attn_fn=origin_attn,
        sequence_parallel_size=cp_size,
    )

    # Patch every transformers module that re-exported _flash_attention_forward
    # so model code that imported it by name also picks up the wrapper.
    for module_name, mod in list(sys.modules.items()):
        try:
            if (
                hasattr(mod, "__file__")
                # mod.__file__ may be None (namespace pkgs) -> TypeError, caught below.
                and "transformers" in mod.__file__
                and getattr(mod._flash_attention_forward, "__name__", "") == "_flash_attention_forward"
            ):
                mod._flash_attention_forward = new_flash_attention_forward
                logger.info_rank0(
                    f"Replaced _flash_attention_forward in module {module_name} with new_flash_attn_forward for sequence parallel."
                )
        except (AttributeError, TypeError):
            continue
|
||||||
|
|
||||||
|
|
||||||
|
def padding_and_split_data(data, device_mesh=None):
    """Pad each multi-dim tensor in ``data`` to a shared, cp-divisible length
    and keep only this rank's chunk along the last dimension.

    With ``device_mesh=None`` the batch is returned unchanged. Mutates and
    returns the same ``data`` mapping.
    """
    if device_mesh is not None:
        cp_size = device_mesh["cp"].size()
        cp_rank = device_mesh["cp"].get_local_rank()
        cp_group = device_mesh["cp"].get_group()
        for k, v in data.items():
            # Only split sequence-shaped tensors; assumes the last dim is the
            # sequence dim for every model input — TODO confirm.
            if isinstance(v, torch.Tensor) and v.ndim > 1:
                # Agree on the longest local sequence across cp ranks.
                data_len = torch.tensor(v.shape[-1], device=v.device, dtype=torch.int64)
                global_data_len = [torch.empty_like(data_len) for _ in range(cp_size)]
                dist.all_gather(global_data_len, data_len, group=cp_group)
                max_data_len = max(global_data_len)
                # Round the padded length up to a multiple of cp_size so every
                # rank's chunk has identical size.
                pad_size = max_data_len - v.shape[-1] + (cp_size - max_data_len % cp_size) % cp_size
                # Padding must be loss-neutral: ignore_index for labels, zero
                # weight for loss_weights, plain zeros otherwise.
                if k == "labels":
                    pad_value = -100
                elif k == "loss_weights":
                    pad_value = 0.0
                else:
                    pad_value = 0
                pad_data = F.pad(v, (0, pad_size), value=pad_value)
                data[k] = torch.chunk(pad_data, chunks=cp_size, dim=-1)[cp_rank].contiguous()
    return data
|
||||||
|
|
||||||
|
|
||||||
|
@SequenceParallelLossPlugin("sequence_parallel_loss").register()
def sequence_parallel_loss(model, model_inputs):
    """Compute the weighted cross-entropy loss for one micro-batch under
    Ulysses sequence parallelism.

    Each cp rank holds a contiguous shard of the sequence. Labels and loss
    weights are all-gathered so the causal shift (predict token t+1) can
    cross shard boundaries; per-token log-probs are gathered with the
    differentiable ``dist.nn.all_gather`` so gradients flow back.
    """
    device_mesh = DistributedInterface().get_device_mesh(Dim.CP)

    # NOTE(review): moves tensors to device index == global rank — assumes one
    # GPU per process with matching ordinal; verify on multi-node setups.
    model_inputs = {
        k: v.to(dist.get_rank(), non_blocking=True) for k, v in model_inputs.items() if isinstance(v, torch.Tensor)
    }

    # Pad to a cp-divisible length and keep this rank's sequence shard.
    model_inputs = padding_and_split_data(model_inputs, device_mesh)

    batch_size, _ = model_inputs["labels"].shape

    outputs: ModelOutput = model(**model_inputs)

    # float() for numerically stable cross-entropy regardless of model dtype.
    logits = outputs.logits.float()

    labels = model_inputs["labels"]

    cp_group = get_ulysses_sequence_parallel_group()
    cp_world_size = get_ulysses_sequence_parallel_world_size(cp_group)
    cp_rank = get_ulysses_sequence_parallel_rank(cp_group)

    # use all_gather to collect labels from all sequence parallel processes
    global_labels = [torch.empty_like(labels) for _ in range(cp_world_size)]
    dist.all_gather(global_labels, labels, group=cp_group)
    labels = torch.cat(global_labels, dim=1).contiguous()
    # Shift left by one for next-token prediction, pad the tail with the
    # ignore_index, then take this rank's chunk of the shifted labels.
    shift_labels = labels[..., 1:].view(-1).contiguous()
    shift_labels = F.pad(shift_labels, (0, 1), value=-100)
    shift_labels = torch.chunk(shift_labels, chunks=cp_world_size, dim=-1)[cp_rank].contiguous()

    # use all_gather to collect loss_weights from all sequence parallel processes
    loss_weights = model_inputs["loss_weights"]
    global_loss_weights = [torch.empty_like(loss_weights) for _ in range(cp_world_size)]
    dist.all_gather(global_loss_weights, loss_weights, group=cp_group)
    shift_loss_weights = torch.cat(global_loss_weights, dim=1).contiguous()
    shift_loss_weights = shift_loss_weights[..., 1:].contiguous()

    # Flatten local logits to (local_tokens, vocab) to match shift_labels.
    shift_logits = logits.view(shift_labels.size(0), -1).contiguous()

    # use all_gather to collect log_probs from all sequence parallel processes
    # dist.nn.all_gather (unlike dist.all_gather) is autograd-aware.
    log_probs = -F.cross_entropy(shift_logits, shift_labels, reduction="none").view(batch_size, -1)
    global_log_probs = dist.nn.all_gather(log_probs, group=cp_group)
    global_log_probs = torch.cat(global_log_probs, dim=1).contiguous()
    # Drop the last position (its shifted label is the -100 pad above).
    log_probs = global_log_probs[..., :-1].contiguous()

    # Weighted token-mean NLL; epsilon guards against an all-masked batch.
    loss = (-log_probs * shift_loss_weights).sum() / (shift_loss_weights.sum() + 1e-6)

    return loss
|
||||||
@@ -0,0 +1,163 @@
|
|||||||
|
# Copyright 2025 Bytedance Ltd. and/or its affiliates. and the LlamaFactory team.
|
||||||
|
#
|
||||||
|
# This code is inspired by the Bytedance's verl library.
|
||||||
|
# https://github.com/verl-project/verl/blob/77476af84cc074edf5a6437f8d5ea418d7a54916/verl/utils/ulysses.py
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
from typing import Any, Optional
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.distributed as dist
|
||||||
|
from torch import Tensor
|
||||||
|
from torch.distributed import ProcessGroup
|
||||||
|
|
||||||
|
from .seq_comm import SeqAllToAll4D
|
||||||
|
|
||||||
|
|
||||||
|
_ULYSSES_SEQUENCE_PARALLEL_GROUP = None
|
||||||
|
|
||||||
|
|
||||||
|
def set_ulysses_sequence_parallel_group(group: dist.ProcessGroup):
    """Set ulysses sequence parallel process group (module-level registry)."""
    global _ULYSSES_SEQUENCE_PARALLEL_GROUP
    _ULYSSES_SEQUENCE_PARALLEL_GROUP = group
|
||||||
|
|
||||||
|
|
||||||
|
def get_ulysses_sequence_parallel_group() -> Optional[dist.ProcessGroup]:
    """Return the registered ulysses sequence parallel group, or ``None``."""
    # Reading a module-level name needs no ``global`` declaration.
    return _ULYSSES_SEQUENCE_PARALLEL_GROUP
|
||||||
|
|
||||||
|
|
||||||
|
def get_ulysses_sequence_parallel_world_size(group: Optional[ProcessGroup] = None) -> int:
    """Get ulysses sequence parallel world size (1 when no group is available)."""
    # Fall back to the module-level registered group.
    group = get_ulysses_sequence_parallel_group() if group is None else group
    return dist.get_world_size(group) if group else 1
|
||||||
|
|
||||||
|
|
||||||
|
def get_ulysses_sequence_parallel_rank(group: Optional[ProcessGroup] = None) -> int:
    """Get this process's rank within the ulysses group (0 when no group)."""
    # Fall back to the module-level registered group.
    group = get_ulysses_sequence_parallel_group() if group is None else group
    return dist.get_rank(group) if group else 0
|
||||||
|
|
||||||
|
|
||||||
|
class UlyssesAttention(torch.nn.Module):
    """Ulysses (DeepSpeed-style) sequence-parallel attention wrapper.

    Redistributes q/k/v with all-to-all so each rank attends over the full
    sequence for a subset of heads, runs the wrapped attention, then
    redistributes the output back to sequence shards.

    Arguments:
        sequence_process_group (ProcessGroup): sequence parallel process group
        scatter_idx (int): scatter dim for the q/k/v all-to-all (head dim)
        gather_idx (int): gather dim for the q/k/v all-to-all (sequence dim)
        attn_fn (callable): the underlying local attention function
    """

    def __init__(
        self,
        sequence_process_group: dist.ProcessGroup = None,
        scatter_idx: int = 2,
        gather_idx: int = 1,
        attn_fn: Optional[callable] = None,
    ) -> None:
        super().__init__()
        self.spg = sequence_process_group
        self.scatter_idx = scatter_idx
        self.gather_idx = gather_idx
        self.attn_fn = attn_fn

    def forward(
        self,
        query: Tensor,
        key: Tensor,
        value: Tensor,
        attention_mask: torch.Tensor,
        query_length: int,
        dropout_p=0.0,
        softmax_scale=None,
        position_ids: Optional[torch.Tensor] = None,
        causal=True,
        deterministic=False,
        target_dtype=None,
        *args: Any,
    ) -> Tensor:
        """Forward.

        Arguments:
            query (Tensor): query input to the layer
            key (Tensor): key input to the layer
            value (Tensor): value input to the layer
            attention_mask (Tensor): attention mask for the layer
            query_length (int): the length of the (global) query sequence
            dropout_p (float, optional): dropout probability. Defaults to 0.0.
            softmax_scale (float, optional): scale factor for softmax. Defaults to None.
            position_ids (torch.Tensor, optional): position ids for the attention. Defaults to None.
            causal (bool, optional): whether to apply causal mask. Defaults to True.
            deterministic (bool, optional): whether to apply dropout in deterministic way. Defaults to False.
            target_dtype (torch.dtype, optional): target dtype for attention output. Defaults to None.
            args: other args

        Returns:
            * output (Tensor): context output
        """
        # TODO Merge three alltoall calls into one
        # TODO (Reza): change the api on the megatron-deepspeed side so that we only receive all data (q,k, and v) together!
        # in shape : e.g., [s/p:h:]
        # (bs, seq_len/N, head_cnt, head_size) -> (bs, seq_len, head_cnt/N, head_size)

        # scatter 2 (heads), gather 1 (sequence)
        q = SeqAllToAll4D.apply(self.spg, query, self.scatter_idx, self.gather_idx)
        k = SeqAllToAll4D.apply(self.spg, key, self.scatter_idx, self.gather_idx)
        v = SeqAllToAll4D.apply(self.spg, value, self.scatter_idx, self.gather_idx)

        # Standard 1/sqrt(head_dim) scaling when the caller did not supply one.
        if softmax_scale is None:
            softmax_scale = q.shape[-1] ** -0.5

        # Synthesize an all-ones mask if none was given; flash attention
        # expects an int mask.
        if attention_mask is None:
            if position_ids is not None:
                attention_mask = torch.ones_like(position_ids).to(torch.int64)
            else:
                attention_mask = torch.ones(q.shape[0], q.shape[1], dtype=torch.int64, device=q.device)
        else:
            attention_mask = attention_mask.to(torch.int64)

        # The mask covers only the local shard; gather it so it matches the
        # full-sequence q/k/v produced by the all-to-all above.
        global_attention_mask = [
            torch.empty_like(attention_mask) for _ in range(get_ulysses_sequence_parallel_world_size(self.spg))
        ]
        dist.all_gather(global_attention_mask, attention_mask, group=self.spg)
        attention_mask = torch.cat(global_attention_mask, dim=1)

        context_layer = self.attn_fn(
            q,
            k,
            v,
            attention_mask,
            query_length=query_length,
            is_causal=causal,
            dropout=dropout_p,
            position_ids=position_ids,
            softmax_scale=softmax_scale,
            deterministic=deterministic,
            target_dtype=target_dtype,
        )

        # Some attention impls return (output, extras); keep the output only.
        if isinstance(context_layer, tuple):
            context_layer = context_layer[0]

        # (bs, seq_len, head_cnt/N, head_size) -> (bs, seq_len/N, head_cnt, head_size)
        # scatter 1 (sequence), gather 2 (heads) — inverse of the first exchange
        output = SeqAllToAll4D.apply(self.spg, context_layer, self.gather_idx, self.scatter_idx)

        # out e.g., [s/p::h]
        return output
|
||||||
@@ -85,10 +85,7 @@ class FSDP2Engine:
|
|||||||
)
|
)
|
||||||
|
|
||||||
if self.device_mesh is not None:
|
if self.device_mesh is not None:
|
||||||
try:
|
self.fsdp_mesh = self.device_mesh
|
||||||
self.fsdp_mesh = self.device_mesh["dp"]
|
|
||||||
except Exception:
|
|
||||||
self.fsdp_mesh = self.device_mesh
|
|
||||||
|
|
||||||
logger.info(f"Using Device Mesh: {self.fsdp_mesh}")
|
logger.info(f"Using Device Mesh: {self.fsdp_mesh}")
|
||||||
else:
|
else:
|
||||||
|
|||||||
62
tests_v1/plugins/model_plugins/test_ulysses_cp.py
Normal file
62
tests_v1/plugins/model_plugins/test_ulysses_cp.py
Normal file
@@ -0,0 +1,62 @@
|
|||||||
|
# Copyright 2025 the LlamaFactory team.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import torch
|
||||||
|
import torch.multiprocessing as mp
|
||||||
|
|
||||||
|
from llamafactory.v1.accelerator.interface import DistributedInterface
|
||||||
|
from llamafactory.v1.config.model_args import ModelArguments
|
||||||
|
from llamafactory.v1.core.model_engine import ModelEngine
|
||||||
|
from llamafactory.v1.plugins.model_plugins.parallelization.sequence_parallel import (
|
||||||
|
SequenceParallelModelPlugin,
|
||||||
|
sequence_parallel_loss,
|
||||||
|
)
|
||||||
|
from llamafactory.v1.utils.env import find_available_port
|
||||||
|
from llamafactory.v1.utils.pytest import dist_env
|
||||||
|
|
||||||
|
|
||||||
|
def _test_sequence_parallel_loss(local_rank: int, world_size: int, master_port: int, cp_size: int, dp_size: int):
    """Per-process worker spawned by the test: builds a tiny model, applies the
    Ulysses sequence-parallel plugin, and checks the sp loss runs end to end."""
    with dist_env(local_rank, world_size, master_port):
        model_args = ModelArguments(model="llamafactory/tiny-random-qwen3")

        # Initialize distributed interface with config
        dist_config = {"cp_mode": "ulysses", "cp_size": cp_size, "dp_size": dp_size}
        DistributedInterface(dist_config)

        # Now create model engine
        model_engine = ModelEngine(model_args=model_args)

        # Apply sequence parallel plugin
        SequenceParallelModelPlugin(dist_config.get("cp_mode", "ulysses"))(model_engine.model, dist_config)

        # Minimal 5-token batch; loss_weights all-ones so every token counts.
        model_inputs = {
            "input_ids": torch.tensor([[1, 2, 3, 4, 5]]),
            "labels": torch.tensor([[1, 2, 3, 4, 5]]),
            "attention_mask": torch.tensor([[1, 1, 1, 1, 1]]),
            "position_ids": torch.tensor([[1, 2, 3, 4, 5]]),
            "loss_weights": torch.tensor([[1.0, 1.0, 1.0, 1.0, 1.0]]),
        }

        # Smoke test only: asserts the collective loss path completes,
        # not a specific loss value.
        loss = sequence_parallel_loss(model_engine.model, model_inputs)
        assert loss is not None
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.runs_on(["cuda", "npu"])
@pytest.mark.require_distributed(2)
@pytest.mark.parametrize("cp_size, dp_size", [(2, 1)])
def test_sequence_parallel_loss(cp_size, dp_size):
    """Spawn cp_size * dp_size processes and run the sp-loss smoke test.

    mp.spawn prepends the rank as the first argument to each worker, so
    ``args`` supplies the remaining parameters of ``_test_sequence_parallel_loss``.
    """
    master_port = find_available_port()
    world_size = cp_size * dp_size
    mp.spawn(_test_sequence_parallel_loss, args=(world_size, master_port, cp_size, dp_size), nprocs=world_size)
|
||||||
Reference in New Issue
Block a user