[misc] fix packing and eval plot (#7623)

This commit is contained in:
hoshi-hiyouga
2025-04-07 18:20:57 +08:00
committed by GitHub
parent 5115dc8c7f
commit c3c0efbaa0
70 changed files with 288 additions and 194 deletions

View File

@@ -24,7 +24,6 @@ import torch.nn.functional as F
from transformers import DataCollatorForSeq2Seq
from ..extras.constants import AUDIO_PLACEHOLDER, IGNORE_INDEX, IMAGE_PLACEHOLDER
from ..extras.misc import get_current_device
from ..extras.packages import is_pillow_available
@@ -65,30 +64,19 @@ def prepare_4d_attention_mask(attention_mask_with_indices: "torch.Tensor", dtype
where `o` equals to `0.0`, `x` equals to `min_dtype`.
"""
_, seq_len = attention_mask_with_indices.size()
# Move to compute device if the source is CPU.
source_device = attention_mask_with_indices.device
compute_device = get_current_device() if source_device.type == "cpu" else source_device
if compute_device != source_device:
attention_mask_with_indices = attention_mask_with_indices.to(compute_device)
min_dtype = torch.finfo(dtype).min
zero_tensor = torch.tensor(0, dtype=dtype, device=compute_device)
zero_tensor = torch.tensor(0, dtype=dtype)
# Create a non-padding mask.
non_padding = (attention_mask_with_indices != 0).unsqueeze(1).unsqueeze(2)
non_padding_mask = (attention_mask_with_indices != 0).unsqueeze(1).unsqueeze(2)
# Create indices for comparison.
indices = attention_mask_with_indices.unsqueeze(1).unsqueeze(2) # [bsz, 1, 1, seq_len]
indices_t = attention_mask_with_indices.unsqueeze(1).unsqueeze(3) # [bsz, 1, seq_len, 1]
# Create a lower triangular mask.
tril_mask = torch.tril(torch.ones((seq_len, seq_len), dtype=torch.bool, device=compute_device))
attention_mask_4d = (indices == indices_t) & non_padding & tril_mask
tril_mask = torch.tril(torch.ones((seq_len, seq_len), dtype=torch.bool))
attention_mask_4d = (indices == indices_t) & non_padding_mask & tril_mask
# Invert the attention mask.
attention_mask_4d = torch.where(attention_mask_4d, zero_tensor, min_dtype)
# Move back to original device if needed.
if compute_device != source_device:
attention_mask_4d = attention_mask_4d.to(source_device)
return attention_mask_4d

View File

@@ -493,8 +493,8 @@ class Llama4Plugin(BasePlugin):
messages = deepcopy(messages)
for message in messages:
content = message["content"]
placeholder_count = content.count(IMAGE_PLACEHOLDER)
if self.expand_mm_tokens:
placeholder_count = content.count(IMAGE_PLACEHOLDER)
prompt_splits = content.split(IMAGE_PLACEHOLDER)
new_content = []
for local_image_index, split_part in enumerate(prompt_splits):
@@ -507,6 +507,8 @@ class Llama4Plugin(BasePlugin):
new_content.append(tokens_for_this_image)
content = "".join(new_content)
else:
content = content.replace(IMAGE_PLACEHOLDER, self.image_token)
message["content"] = content

View File

@@ -164,28 +164,28 @@ class PackedSupervisedDatasetProcessor(SupervisedDatasetProcessor):
model_inputs = defaultdict(list)
knapsacks = greedy_knapsack(lengths, self.data_args.cutoff_len)
for knapsack in knapsacks:
packed_input_ids, packed_attention_masks, packed_labels = [], [], []
packed_images, packed_videos, packed_audios, packed_position_ids = [], [], [], []
packed_input_ids, packed_attention_masks, packed_position_ids, packed_labels = [], [], [], []
packed_images, packed_videos, packed_audios = [], [], []
for i, length in enumerate(knapsack):
index = length2indexes[length].pop()
packed_input_ids += batch_input_ids[index]
packed_position_ids += list(range(len(batch_input_ids[index]))) # NOTE: pad_to_multiple_of ignore this
packed_labels += batch_labels[index]
packed_images += batch_images[index]
packed_videos += batch_videos[index]
packed_audios += batch_audios[index]
if self.data_args.neat_packing:
packed_attention_masks += [i + 1] * len(batch_input_ids[index]) # start from 1
packed_position_ids += list(range(len(batch_input_ids[index])))
else:
packed_attention_masks += [1] * len(batch_input_ids[index])
if len(packed_input_ids) < self.data_args.cutoff_len + 1: # avoid flash_attn drops attn mask
pad_length = self.data_args.cutoff_len - len(packed_input_ids) + 1
packed_input_ids += [self.tokenizer.pad_token_id] * pad_length
packed_position_ids += [0] * pad_length
packed_labels += [IGNORE_INDEX] * pad_length
if self.data_args.neat_packing:
packed_attention_masks += [0] * pad_length
packed_position_ids += [0] * pad_length
else:
packed_attention_masks += [1] * pad_length # more efficient flash_attn
@@ -194,10 +194,10 @@ class PackedSupervisedDatasetProcessor(SupervisedDatasetProcessor):
model_inputs["input_ids"].append(packed_input_ids)
model_inputs["attention_mask"].append(packed_attention_masks)
model_inputs["position_ids"].append(packed_position_ids)
model_inputs["labels"].append(packed_labels)
model_inputs["images"].append(packed_images or None)
model_inputs["videos"].append(packed_videos or None)
model_inputs["audios"].append(packed_audios or None)
model_inputs["position_ids"].append(packed_position_ids or None)
return model_inputs

View File

@@ -1370,7 +1370,7 @@ register_template(
slots=["<|im_start|>user\n<tool_response>\n{{content}}\n</tool_response><|im_end|>\n<|im_start|>assistant\n"]
),
format_tools=ToolFormatter(tool_format="qwen"),
default_system="You are a helpful assistant.",
default_system="You are Qwen, created by Alibaba Cloud. You are a helpful assistant.",
stop_words=["<|im_end|>"],
)