fix packing for eager/sdpa attn
Former-commit-id: 735a033ceb7f2da6da71d138ea091d8a665411a9
This commit is contained in:
@@ -115,7 +115,9 @@ def get_unpad_data(attention_mask: "torch.Tensor") -> Tuple["torch.Tensor", "tor
|
||||
|
||||
|
||||
def patch_for_block_diag_attn(model_type: str) -> None:
|
||||
if model_type == "falcon":
|
||||
if model_type == "cohere":
|
||||
transformers.models.cohere.modeling_cohere._get_unpad_data = get_unpad_data
|
||||
elif model_type == "falcon":
|
||||
transformers.models.falcon.modeling_falcon._get_unpad_data = get_unpad_data
|
||||
elif model_type == "gemma":
|
||||
transformers.models.gemma.modeling_gemma._get_unpad_data = get_unpad_data
|
||||
|
||||
Reference in New Issue
Block a user