move configure_packing to llamafactory.model.patcher and fix constants
Former-commit-id: 9c5e972c9c81957f2e9e30bf284ef1c076de9fd0
@@ -19,7 +19,7 @@ from transformers.modeling_attn_mask_utils import (
 from transformers.utils import is_torch_bf16_gpu_available
 
 from ...extras.logging import get_logger
-from ...extras.constants import SUPPORTED_CLASS_FOR_MULTIPACK
+from ...extras.constants import SUPPORTED_CLASS_EFFECIENT_PACKING
 
 if TYPE_CHECKING:
     from transformers import PretrainedConfig
@@ -303,7 +303,7 @@ def configure_packing(config: "PretrainedConfig", model_args: "ModelArguments")
     else:
         attn_implementation = getattr(config, "_attn_implementation", "")
 
-    if getattr(config, "model_type", None) in SUPPORTED_CLASS_FOR_MULTIPACK:
+    if getattr(config, "model_type", None) in SUPPORTED_CLASS_EFFECIENT_PACKING:
         patch_for_multipack(getattr(config, "model_type", None), model_args.model_name_or_path, attn_implementation)
         logger.info("Using packing sequences without cross-contamination attention for efficient training.")
     else:
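For context, a minimal, self-contained Python sketch of the dispatch that the renamed constant drives in the second hunk. The set membership, the body of patch_for_multipack, and the SimpleNamespace stand-ins for config and model_args are illustrative assumptions, not taken from this commit.

# Hedged sketch of the dispatch visible in the hunk above.
# The membership of SUPPORTED_CLASS_EFFECIENT_PACKING here is hypothetical; the real
# set is imported from llamafactory's extras.constants module.
from types import SimpleNamespace

SUPPORTED_CLASS_EFFECIENT_PACKING = {"llama", "mistral", "falcon"}  # hypothetical entries


def patch_for_multipack(model_type, model_name_or_path, attn_implementation):
    # Stand-in: the real function patches the model's attention so that packed
    # sequences cannot attend across example boundaries (no cross-contamination).
    print(f"patched {model_type} at {model_name_or_path} using {attn_implementation}")


def configure_packing(config, model_args):
    # Mirrors the logic in the hunk above: read the configured attention backend,
    # then patch only model types known to support contamination-free packing.
    attn_implementation = getattr(config, "_attn_implementation", "")
    if getattr(config, "model_type", None) in SUPPORTED_CLASS_EFFECIENT_PACKING:
        patch_for_multipack(getattr(config, "model_type", None), model_args.model_name_or_path, attn_implementation)
        print("Using packing sequences without cross-contamination attention for efficient training.")


# Usage with stand-in config/model_args objects:
configure_packing(
    SimpleNamespace(model_type="llama", _attn_implementation="flash_attention_2"),
    SimpleNamespace(model_name_or_path="meta-llama/Llama-2-7b-hf"),
)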