improve log

Former-commit-id: a6abf375975ffea3d51e1b944c9855b5f62ffac8
This commit is contained in:
hiyouga
2025-01-08 09:56:10 +00:00
parent 3b843ac9d4
commit 647c51a772
16 changed files with 78 additions and 67 deletions

View File

@@ -41,9 +41,9 @@ from typing import TYPE_CHECKING, Tuple
import torch
import torch.nn.functional as F
from transformers.utils.versions import require_version
from ...extras import logging
from ...extras.misc import check_version
from ...extras.packages import is_transformers_version_greater_than
@@ -118,6 +118,6 @@ def configure_packing(model_args: "ModelArguments", is_trainable: bool) -> None:
if not is_trainable or not model_args.block_diag_attn:
return
require_version("transformers>=4.43.0,<=4.46.1", "To fix: pip install transformers>=4.43.0,<=4.46.1")
check_version("transformers>=4.43.0,<=4.46.1")
transformers.modeling_flash_attention_utils._get_unpad_data = get_unpad_data
logger.info_rank0("Using block diagonal attention for sequence packing without cross-attention.")