[trainer] fix pt loss (#7748)

* fix pt loss

* robust

* fix

* test
Author: hoshi-hiyouga
Date: 2025-04-17 03:15:35 +08:00
Committed by: GitHub
Parent: 86ebb219d6
Commit: 39169986ef

10 changed files with 34 additions and 34 deletions

@@ -43,11 +43,6 @@ import torch
 import torch.nn.functional as F
 
 from ...extras import logging
-from ...extras.packages import is_transformers_version_greater_than
-
-
-if is_transformers_version_greater_than("4.43.0"):
-    import transformers.modeling_flash_attention_utils
 
 
 if TYPE_CHECKING:
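
Taken together with the hunk below, this change replaces a module-level, version-gated import with a lazy import inside configure_packing. A minimal sketch of the deferred-import pattern, with an illustrative function name not taken from this diff:

    # Sketch: the transformers submodule is imported only when the feature is
    # actually enabled, so merely importing this module can no longer fail on
    # (or pay for) the submodule import.
    def enable_flash_attn_patch(block_diag_attn: bool) -> None:
        if not block_diag_attn:
            return  # feature off: the submodule import never runs

        # Resolved lazily, at call time (assumes transformers is installed).
        import transformers.modeling_flash_attention_utils  # noqa: F401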
@@ -116,5 +111,7 @@ def configure_packing(model_args: "ModelArguments", is_trainable: bool) -> None:
     if not is_trainable or not model_args.block_diag_attn:
         return
 
+    import transformers.modeling_flash_attention_utils
+
     transformers.modeling_flash_attention_utils._get_unpad_data = get_unpad_data
     logger.info_rank0("Using block diagonal attention for sequence packing without cross-attention.")
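
For context, _get_unpad_data is the helper that transformers' FlashAttention-2 path uses to unpad a batch before calling the varlen attention kernels; the assignment above swaps it for a packing-aware version so each packed segment attends only to itself (block-diagonal attention). Below is a minimal sketch of such a replacement, not necessarily the repository's exact implementation: it assumes the collator labels tokens of the i-th packed segment in a row with the integer i (1, 2, 3, ...) and uses 0 for padding.

    import torch
    import torch.nn.functional as F


    def get_unpad_data(attention_mask: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, int]:
        # Sketch of a packing-aware replacement for
        # transformers.modeling_flash_attention_utils._get_unpad_data.
        # attention_mask: (batch, seq_len) ints; segment ids 1..k per row, 0 = padding.
        max_num = int(attention_mask.max().item())  # most segments in any row
        # Length of each segment, ordered as segments appear in the flattened batch.
        counts = torch.stack(
            [(attention_mask == i).sum(dim=-1) for i in range(1, max_num + 1)], dim=-1
        ).flatten()
        seqlens_in_batch = counts[counts > 0].to(torch.int32)

        # Flat indices of all non-padding tokens, consumed by the varlen API.
        indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
        max_seqlen_in_batch = int(seqlens_in_batch.max().item())
        # Cumulative segment lengths with a leading zero: [0, l1, l1+l2, ...].
        cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
        return indices, cu_seqlens, max_seqlen_in_batch

Example: for mask = torch.tensor([[1, 1, 2, 2, 2, 0], [1, 1, 1, 0, 0, 0]]) (two rows, the first packing segments of lengths 2 and 3), the sketch returns the flat indices of the 8 real tokens, cu_seqlens [0, 2, 5, 8], and max length 3, so attention never crosses segment boundaries.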