Mirror of https://github.com/hiyouga/LlamaFactory.git (synced 2026-02-01 20:23:37 +00:00)
[misc] fix packing and eval plot (#7623)
@@ -24,7 +24,6 @@ import torch.nn.functional as F
 from transformers import DataCollatorForSeq2Seq
 
 from ..extras.constants import AUDIO_PLACEHOLDER, IGNORE_INDEX, IMAGE_PLACEHOLDER
-from ..extras.misc import get_current_device
 from ..extras.packages import is_pillow_available
@@ -65,30 +64,19 @@ def prepare_4d_attention_mask(attention_mask_with_indices: "torch.Tensor", dtype
     where `o` equals to `0.0`, `x` equals to `min_dtype`.
     """
     _, seq_len = attention_mask_with_indices.size()
-
-    # Move to compute device if the source is CPU.
-    source_device = attention_mask_with_indices.device
-    compute_device = get_current_device() if source_device.type == "cpu" else source_device
-    if compute_device != source_device:
-        attention_mask_with_indices = attention_mask_with_indices.to(compute_device)
-
     min_dtype = torch.finfo(dtype).min
-    zero_tensor = torch.tensor(0, dtype=dtype, device=compute_device)
+    zero_tensor = torch.tensor(0, dtype=dtype)
 
     # Create a non-padding mask.
-    non_padding = (attention_mask_with_indices != 0).unsqueeze(1).unsqueeze(2)
+    non_padding_mask = (attention_mask_with_indices != 0).unsqueeze(1).unsqueeze(2)
     # Create indices for comparison.
     indices = attention_mask_with_indices.unsqueeze(1).unsqueeze(2)  # [bsz, 1, 1, seq_len]
     indices_t = attention_mask_with_indices.unsqueeze(1).unsqueeze(3)  # [bsz, 1, seq_len, 1]
     # Create a lower triangular mask.
-    tril_mask = torch.tril(torch.ones((seq_len, seq_len), dtype=torch.bool, device=compute_device))
-    attention_mask_4d = (indices == indices_t) & non_padding & tril_mask
+    tril_mask = torch.tril(torch.ones((seq_len, seq_len), dtype=torch.bool))
+    attention_mask_4d = (indices == indices_t) & non_padding_mask & tril_mask
     # Invert the attention mask.
     attention_mask_4d = torch.where(attention_mask_4d, zero_tensor, min_dtype)
-
-    # Move back to original device if needed.
-    if compute_device != source_device:
-        attention_mask_4d = attention_mask_4d.to(source_device)
     return attention_mask_4d
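
Editor's note (a sketch, not part of the commit): prepare_4d_attention_mask turns a packed-sequence index mask of shape [bsz, seq_len] (0 marks padding, k marks the k-th packed sequence) into a 4D additive attention bias in which a query may attend only to earlier positions of the same packed sequence; allowed pairs get 0.0, everything else gets the dtype minimum. The sketch below reproduces the post-change body and runs it on a hypothetical toy batch; the example input and the printed expectations are assumptions for illustration only.

import torch


def prepare_4d_attention_mask(attention_mask_with_indices: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
    # attention_mask_with_indices: [bsz, seq_len]; 0 marks padding, k marks the k-th packed sequence.
    _, seq_len = attention_mask_with_indices.size()
    min_dtype = torch.finfo(dtype).min
    zero_tensor = torch.tensor(0, dtype=dtype)

    # Mask out padding key positions.
    non_padding_mask = (attention_mask_with_indices != 0).unsqueeze(1).unsqueeze(2)
    # A query may attend to a key only if both carry the same sequence index.
    indices = attention_mask_with_indices.unsqueeze(1).unsqueeze(2)    # [bsz, 1, 1, seq_len]
    indices_t = attention_mask_with_indices.unsqueeze(1).unsqueeze(3)  # [bsz, 1, seq_len, 1]
    # Enforce causality within each packed sequence.
    tril_mask = torch.tril(torch.ones((seq_len, seq_len), dtype=torch.bool))
    attention_mask_4d = (indices == indices_t) & non_padding_mask & tril_mask
    # Allowed positions become 0.0, masked positions become the dtype minimum.
    return torch.where(attention_mask_4d, zero_tensor, min_dtype)


# Hypothetical packed batch: sequences of length 3 and 2, followed by one pad slot.
packed = torch.tensor([[1, 1, 1, 2, 2, 0]])
bias = prepare_4d_attention_mask(packed, torch.float32)
print(bias.shape)               # torch.Size([1, 1, 6, 6])
print((bias == 0).int()[0, 0])  # block-diagonal causal pattern: 3x3 and 2x2 lower triangles; pad row and column fully masked

The removed branch had copied CPU inputs to the current accelerator before building the mask and copied the result back afterwards; after this change the mask is constructed on the input's original (collator-side) device, which is why the get_current_device import in the first hunk becomes unused.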