[trainer] fix pt loss (#7748)

* fix pt loss

* robust

* fix

* test
Author: hoshi-hiyouga
Date: 2025-04-17 03:15:35 +08:00
Committed by: GitHub
Parent: 86ebb219d6
Commit: 39169986ef

10 changed files with 34 additions and 34 deletions

src/llamafactory/model/model_utils/longlora.py

@@ -23,7 +23,6 @@ from typing import TYPE_CHECKING, Optional
 import torch
 import torch.nn as nn
 import transformers
-from transformers.models.llama.modeling_llama import Cache, apply_rotary_pos_emb, repeat_kv

 from ...extras import logging
 from ...extras.constants import SUPPORTED_CLASS_FOR_S2ATTN
@@ -32,7 +31,15 @@ from ...extras.packages import is_transformers_version_greater_than


 if not is_transformers_version_greater_than("4.48.0"):
-    from transformers.models.llama.modeling_llama import LlamaAttention, LlamaFlashAttention2, LlamaSdpaAttention
+    from transformers.modeling_flash_attention_utils import _flash_attention_forward
+    from transformers.models.llama.modeling_llama import (
+        Cache,
+        LlamaAttention,
+        LlamaFlashAttention2,
+        LlamaSdpaAttention,
+        apply_rotary_pos_emb,
+        repeat_kv,
+    )


 if TYPE_CHECKING:
@@ -206,9 +213,6 @@ def llama_flash_attention_2_forward(
         if attention_mask is not None:
             attention_mask = attention_mask[:, :groupsz].repeat(num_groups, 1)

-    if is_transformers_version_greater_than("4.43.0"):
-        from transformers.modeling_flash_attention_utils import _flash_attention_forward
-
     attn_output: torch.Tensor = _flash_attention_forward(
         query_states,
         key_states,
@@ -220,10 +224,6 @@ def llama_flash_attention_2_forward(
         use_top_left_mask=self._flash_attn_uses_top_left_mask,
         is_causal=self.is_causal,
     )
-    else:
-        attn_output: torch.Tensor = self._flash_attention_forward(
-            query_states, key_states, value_states, attention_mask, query_states.size(1), dropout=dropout_rate
-        )

     if getattr(self.config, "group_size_ratio", None) and self.training:  # shift back
         attn_output.reshape(bsz, q_len, self.num_heads, self.head_dim)
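Note on the surrounding context (unchanged by this commit): the attention_mask slicing above belongs to the shifted sparse attention (S2-Attn) grouping, where a (bsz, q_len) mask is cut down to one group window and tiled across groups. A toy, self-contained illustration with made-up sizes; group_size_ratio here is a stand-in for the value read from the model config:

import torch

# Toy sizes for illustration only.
bsz, q_len = 2, 8
group_size_ratio = 0.25           # hypothetical; the real value comes from self.config
groupsz = int(q_len * group_size_ratio)   # 2
num_groups = q_len // groupsz             # 4

attention_mask = torch.ones(bsz, q_len, dtype=torch.long)
# Same reshaping as the context line above: (bsz, q_len) -> (bsz * num_groups, groupsz).
grouped_mask = attention_mask[:, :groupsz].repeat(num_groups, 1)
print(grouped_mask.shape)  # torch.Size([8, 2])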
@@ -350,7 +350,7 @@ def llama_sdpa_attention_forward(


 def _apply_llama_patch() -> None:
-    check_version("transformers>=4.43.0,<4.48.0", mandatory=True)
+    check_version("transformers>=4.45.0,<4.48.0", mandatory=True)
     LlamaAttention.forward = llama_attention_forward
     LlamaFlashAttention2.forward = llama_flash_attention_2_forward
     LlamaSdpaAttention.forward = llama_sdpa_attention_forward
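The last hunk tightens the pinned transformers range before monkey-patching the Llama attention forwards. As a rough sketch of what a gate like check_version("transformers>=4.45.0,<4.48.0", mandatory=True) enforces; this is an illustrative stand-in built on packaging and importlib.metadata, not LLaMA-Factory's actual helper:

from importlib.metadata import version as installed_version

from packaging.requirements import Requirement


def check_version(requirement: str, mandatory: bool = False) -> None:
    # Illustrative stand-in for the project's check_version(): raise (or warn)
    # when the installed package falls outside the pinned range.
    req = Requirement(requirement)  # e.g. "transformers>=4.45.0,<4.48.0"
    found = installed_version(req.name)
    if found not in req.specifier:
        message = f"{req.name}{req.specifier} is required for this patch, found {found}."
        if mandatory:
            raise RuntimeError(message)
        print(f"Warning: {message}")


# Usage mirroring _apply_llama_patch() above:
# check_version("transformers>=4.45.0,<4.48.0", mandatory=True)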

src/llamafactory/model/model_utils/packing.py

@@ -43,11 +43,6 @@ import torch
 import torch.nn.functional as F

 from ...extras import logging
-from ...extras.packages import is_transformers_version_greater_than
-
-
-if is_transformers_version_greater_than("4.43.0"):
-    import transformers.modeling_flash_attention_utils


 if TYPE_CHECKING:
@@ -116,5 +111,7 @@ def configure_packing(model_args: "ModelArguments", is_trainable: bool) -> None:
     if not is_trainable or not model_args.block_diag_attn:
         return

+    import transformers.modeling_flash_attention_utils
+
     transformers.modeling_flash_attention_utils._get_unpad_data = get_unpad_data
     logger.info_rank0("Using block diagonal attention for sequence packing without cross-attention.")
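The final hunk defers the transformers import to call time and then replaces transformers.modeling_flash_attention_utils._get_unpad_data with the project's get_unpad_data (defined earlier in the same file, not shown in this excerpt). A minimal sketch of the contract such a replacement must satisfy, returning (indices, cu_seqlens, max_seqlen_in_batch) for flash-attn's variable-length kernels; it assumes a packed attention mask that labels sub-sequences 1, 2, 3, ... and padding 0, and it is an illustration rather than the repository's implementation:

import torch
import torch.nn.functional as F


def get_unpad_data(attention_mask: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, int]:
    # Per-sub-sequence lengths: a mask row [1, 1, 2, 2, 2, 0] yields lengths [2, 3].
    lengths = []
    for row in attention_mask:
        counts = torch.bincount(row)[1:]       # bucket 0 is padding
        lengths.append(counts[counts > 0])
    seqlens_in_batch = torch.cat(lengths).to(torch.int32)

    # Flat indices of every non-padding token in the (batch, seq_len) mask.
    indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
    # Cumulative sequence lengths in the layout flash-attn varlen kernels expect.
    cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
    max_seqlen_in_batch = int(seqlens_in_batch.max().item())
    return indices, cu_seqlens, max_seqlen_in_batch


# Example: two packed sequences of lengths 2 and 3, plus one padding slot.
mask = torch.tensor([[1, 1, 2, 2, 2, 0]])
print(get_unpad_data(mask))  # (tensor([0, 1, 2, 3, 4]), tensor([0, 2, 5], dtype=torch.int32), 3)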