[misc] fix packing and eval plot (#7623)
@@ -24,7 +24,6 @@ import torch.nn.functional as F
 from transformers import DataCollatorForSeq2Seq
 
 from ..extras.constants import AUDIO_PLACEHOLDER, IGNORE_INDEX, IMAGE_PLACEHOLDER
-from ..extras.misc import get_current_device
 from ..extras.packages import is_pillow_available
 
 
@@ -65,30 +64,19 @@ def prepare_4d_attention_mask(attention_mask_with_indices: "torch.Tensor", dtype
     where `o` equals to `0.0`, `x` equals to `min_dtype`.
     """
     _, seq_len = attention_mask_with_indices.size()
-
-    # Move to compute device if the source is CPU.
-    source_device = attention_mask_with_indices.device
-    compute_device = get_current_device() if source_device.type == "cpu" else source_device
-    if compute_device != source_device:
-        attention_mask_with_indices = attention_mask_with_indices.to(compute_device)
-
     min_dtype = torch.finfo(dtype).min
-    zero_tensor = torch.tensor(0, dtype=dtype, device=compute_device)
+    zero_tensor = torch.tensor(0, dtype=dtype)
 
     # Create a non-padding mask.
-    non_padding = (attention_mask_with_indices != 0).unsqueeze(1).unsqueeze(2)
+    non_padding_mask = (attention_mask_with_indices != 0).unsqueeze(1).unsqueeze(2)
     # Create indices for comparison.
     indices = attention_mask_with_indices.unsqueeze(1).unsqueeze(2)  # [bsz, 1, 1, seq_len]
     indices_t = attention_mask_with_indices.unsqueeze(1).unsqueeze(3)  # [bsz, 1, seq_len, 1]
     # Create a lower triangular mask.
-    tril_mask = torch.tril(torch.ones((seq_len, seq_len), dtype=torch.bool, device=compute_device))
-    attention_mask_4d = (indices == indices_t) & non_padding & tril_mask
+    tril_mask = torch.tril(torch.ones((seq_len, seq_len), dtype=torch.bool))
+    attention_mask_4d = (indices == indices_t) & non_padding_mask & tril_mask
     # Invert the attention mask.
     attention_mask_4d = torch.where(attention_mask_4d, zero_tensor, min_dtype)
-
-    # Move back to original device if needed.
-    if compute_device != source_device:
-        attention_mask_4d = attention_mask_4d.to(source_device)
     return attention_mask_4d
 
 
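For context only, a minimal standalone sketch of the simplified mask construction above; the helper name and example values are invented and not part of this change.

```python
# Sketch of the block-diagonal causal mask built by the hunk above (illustrative only).
import torch


def expand_to_4d(mask_with_indices: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
    _, seq_len = mask_with_indices.size()
    min_dtype = torch.finfo(dtype).min
    zero_tensor = torch.tensor(0, dtype=dtype)
    non_padding_mask = (mask_with_indices != 0).unsqueeze(1).unsqueeze(2)  # [bsz, 1, 1, seq_len]
    indices = mask_with_indices.unsqueeze(1).unsqueeze(2)                  # [bsz, 1, 1, seq_len]
    indices_t = mask_with_indices.unsqueeze(1).unsqueeze(3)                # [bsz, 1, seq_len, 1]
    tril_mask = torch.tril(torch.ones((seq_len, seq_len), dtype=torch.bool))
    allowed = (indices == indices_t) & non_padding_mask & tril_mask        # same sequence, causal, non-pad
    return torch.where(allowed, zero_tensor, min_dtype)


# Two packed sequences (indices 1 and 2) plus one pad position (0).
mask = torch.tensor([[1, 1, 1, 2, 2, 0]])
out = expand_to_4d(mask, torch.float32)
print(out.shape)  # torch.Size([1, 1, 6, 6]); zeros form a block-diagonal causal pattern
```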
@@ -493,8 +493,8 @@ class Llama4Plugin(BasePlugin):
         messages = deepcopy(messages)
         for message in messages:
             content = message["content"]
-            placeholder_count = content.count(IMAGE_PLACEHOLDER)
             if self.expand_mm_tokens:
+                placeholder_count = content.count(IMAGE_PLACEHOLDER)
                 prompt_splits = content.split(IMAGE_PLACEHOLDER)
                 new_content = []
                 for local_image_index, split_part in enumerate(prompt_splits):
@@ -507,6 +507,8 @@ class Llama4Plugin(BasePlugin):
                         new_content.append(tokens_for_this_image)
 
                 content = "".join(new_content)
+            else:
+                content = content.replace(IMAGE_PLACEHOLDER, self.image_token)
 
             message["content"] = content
 
@@ -164,28 +164,28 @@ class PackedSupervisedDatasetProcessor(SupervisedDatasetProcessor):
         model_inputs = defaultdict(list)
         knapsacks = greedy_knapsack(lengths, self.data_args.cutoff_len)
         for knapsack in knapsacks:
-            packed_input_ids, packed_attention_masks, packed_labels = [], [], []
-            packed_images, packed_videos, packed_audios, packed_position_ids = [], [], [], []
+            packed_input_ids, packed_attention_masks, packed_position_ids, packed_labels = [], [], [], []
+            packed_images, packed_videos, packed_audios = [], [], []
             for i, length in enumerate(knapsack):
                 index = length2indexes[length].pop()
                 packed_input_ids += batch_input_ids[index]
+                packed_position_ids += list(range(len(batch_input_ids[index])))  # NOTE: pad_to_multiple_of ignore this
                 packed_labels += batch_labels[index]
                 packed_images += batch_images[index]
                 packed_videos += batch_videos[index]
                 packed_audios += batch_audios[index]
                 if self.data_args.neat_packing:
                     packed_attention_masks += [i + 1] * len(batch_input_ids[index])  # start from 1
-                    packed_position_ids += list(range(len(batch_input_ids[index])))
                 else:
                     packed_attention_masks += [1] * len(batch_input_ids[index])
 
             if len(packed_input_ids) < self.data_args.cutoff_len + 1:  # avoid flash_attn drops attn mask
                 pad_length = self.data_args.cutoff_len - len(packed_input_ids) + 1
                 packed_input_ids += [self.tokenizer.pad_token_id] * pad_length
+                packed_position_ids += [0] * pad_length
                 packed_labels += [IGNORE_INDEX] * pad_length
                 if self.data_args.neat_packing:
                     packed_attention_masks += [0] * pad_length
-                    packed_position_ids += [0] * pad_length
                 else:
                     packed_attention_masks += [1] * pad_length  # more efficient flash_attn
 
@@ -194,10 +194,10 @@ class PackedSupervisedDatasetProcessor(SupervisedDatasetProcessor):
 
             model_inputs["input_ids"].append(packed_input_ids)
             model_inputs["attention_mask"].append(packed_attention_masks)
+            model_inputs["position_ids"].append(packed_position_ids)
             model_inputs["labels"].append(packed_labels)
             model_inputs["images"].append(packed_images or None)
             model_inputs["videos"].append(packed_videos or None)
             model_inputs["audios"].append(packed_audios or None)
-            model_inputs["position_ids"].append(packed_position_ids or None)
 
         return model_inputs
 
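Illustration only (invented token ids, not project code): how one knapsack's bookkeeping looks under the revised loop. With `neat_packing` each sample keeps its own attention-mask index and position ids restart at 0 per sample; an attention mask of this shape is what the simplified `prepare_4d_attention_mask` above expands into a block-diagonal 4D mask.

```python
# Toy walk-through of the packing bookkeeping in the hunk above (sample data invented).
IGNORE_INDEX = -100
PAD_TOKEN_ID = 0
CUTOFF_LEN = 8  # assume cutoff_len was already reduced by 1 (see the DataArguments hunk below)
neat_packing = True

samples = [[11, 12, 13], [21, 22], [31, 32, 33]]  # token ids of three short samples in one knapsack

packed_input_ids, packed_attention_masks, packed_position_ids, packed_labels = [], [], [], []
for i, ids in enumerate(samples):
    packed_input_ids += ids
    packed_position_ids += list(range(len(ids)))  # positions restart for every sample
    packed_labels += ids
    packed_attention_masks += [i + 1 if neat_packing else 1] * len(ids)

if len(packed_input_ids) < CUTOFF_LEN + 1:  # pad up front to a fixed length
    pad_length = CUTOFF_LEN - len(packed_input_ids) + 1
    packed_input_ids += [PAD_TOKEN_ID] * pad_length
    packed_position_ids += [0] * pad_length
    packed_labels += [IGNORE_INDEX] * pad_length
    packed_attention_masks += [0 if neat_packing else 1] * pad_length

print(packed_attention_masks)  # [1, 1, 1, 2, 2, 3, 3, 3, 0]
print(packed_position_ids)     # [0, 1, 2, 0, 1, 0, 1, 2, 0]
```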
@@ -1370,7 +1370,7 @@ register_template(
         slots=["<|im_start|>user\n<tool_response>\n{{content}}\n</tool_response><|im_end|>\n<|im_start|>assistant\n"]
     ),
     format_tools=ToolFormatter(tool_format="qwen"),
-    default_system="You are a helpful assistant.",
+    default_system="You are Qwen, created by Alibaba Cloud. You are a helpful assistant.",
     stop_words=["<|im_end|>"],
 )
 
@@ -14,7 +14,7 @@
 
 import os
 from collections import OrderedDict, defaultdict
-from enum import Enum
+from enum import Enum, unique
 from typing import Optional
 
 from peft.utils import SAFETENSORS_WEIGHTS_NAME as SAFE_ADAPTER_WEIGHTS_NAME
@@ -115,6 +115,19 @@ class DownloadSource(str, Enum):
     OPENMIND = "om"
 
 
+@unique
+class QuantizationMethod(str, Enum):
+    r"""Borrowed from `transformers.utils.quantization_config.QuantizationMethod`."""
+
+    BNB = "bnb"
+    GPTQ = "gptq"
+    AWQ = "awq"
+    AQLM = "aqlm"
+    QUANTO = "quanto"
+    EETQ = "eetq"
+    HQQ = "hqq"
+
+
 class RopeScaling(str, Enum):
     LINEAR = "linear"
     DYNAMIC = "dynamic"
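A quick note on the new enum, shown here as a trimmed sketch of standard `(str, Enum)` behavior rather than additional project code: because the members are strings, configs that pass plain values such as "bnb" or "hqq" still compare and parse cleanly.

```python
from enum import Enum, unique


@unique
class QuantizationMethod(str, Enum):
    # Trimmed to three members for the example.
    BNB = "bnb"
    HQQ = "hqq"
    EETQ = "eetq"


assert QuantizationMethod.BNB == "bnb"                      # str mixin: member equals its value
assert QuantizationMethod("hqq") is QuantizationMethod.HQQ  # parse a plain config string
print(QuantizationMethod.BNB.value)                         # bnb
```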
@@ -160,5 +160,11 @@ class DataArguments:
         if self.mask_history and self.train_on_prompt:
             raise ValueError("`mask_history` is incompatible with `train_on_prompt`.")
 
+        if self.neat_packing:
+            self.packing = True
+
+        if self.packing:
+            self.cutoff_len -= 1  # avoid pad_to_multiple_of, needs improve
+
     def to_dict(self) -> dict[str, Any]:
         return asdict(self)
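One way to read the `cutoff_len -= 1` adjustment (an assumption, not stated in the diff): the packed processor above pads every pack to `cutoff_len + 1` tokens, so lowering `cutoff_len` by one keeps the final packed length equal to the value the user configured, avoiding an extra `pad_to_multiple_of` pass that would ignore `position_ids`.

```python
# Back-of-the-envelope check of that reading (values are illustrative).
requested_cutoff_len = 2048                        # value passed by the user
effective_cutoff_len = requested_cutoff_len - 1    # after __post_init__ when packing is enabled
packed_length = effective_cutoff_len + 1           # each packed example is padded to cutoff_len + 1
assert packed_length == requested_cutoff_len
```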
@@ -23,7 +23,7 @@ import torch
 from transformers.training_args import _convert_str_dict
 from typing_extensions import Self
 
-from ..extras.constants import AttentionFunction, EngineName, RopeScaling
+from ..extras.constants import AttentionFunction, EngineName, QuantizationMethod, RopeScaling
 
 
 @dataclass
@@ -184,8 +184,8 @@ class BaseModelArguments:
 class QuantizationArguments:
     r"""Arguments pertaining to the quantization method."""
 
-    quantization_method: Literal["bitsandbytes", "hqq", "eetq"] = field(
-        default="bitsandbytes",
+    quantization_method: QuantizationMethod = field(
+        default=QuantizationMethod.BNB,
         metadata={"help": "Quantization method to use for on-the-fly quantization."},
     )
     quantization_bit: Optional[int] = field(
@@ -135,7 +135,7 @@ def _check_extra_dependencies(
         check_version("mixture-of-depth>=1.1.6", mandatory=True)
 
     if model_args.infer_backend == EngineName.VLLM:
-        check_version("vllm>=0.4.3,<=0.8.2")
+        check_version("vllm>=0.4.3,<=0.8.3")
         check_version("vllm", mandatory=True)
     elif model_args.infer_backend == EngineName.SGLANG:
         check_version("sglang>=0.4.4")
@@ -285,10 +285,6 @@ def get_train_args(args: Optional[Union[dict[str, Any], list[str]]] = None) -> _
     if model_args.use_unsloth and is_deepspeed_zero3_enabled():
         raise ValueError("Unsloth is incompatible with DeepSpeed ZeRO-3.")
 
-    if data_args.neat_packing and not data_args.packing:
-        logger.warning_rank0("`neat_packing` requires `packing` is True. Change `packing` to True.")
-        data_args.packing = True
-
     _verify_model_args(model_args, data_args, finetuning_args)
     _check_extra_dependencies(model_args, finetuning_args, training_args)
 
@@ -97,12 +97,13 @@ def load_tokenizer(model_args: "ModelArguments") -> "TokenizerModule":
         processor = AutoProcessor.from_pretrained(model_args.model_name_or_path, **init_kwargs)
         patch_processor(processor, tokenizer, model_args)
     except Exception as e:
-        logger.debug(f"Processor was not found: {e}.")
+        logger.debug(f"Failed to load processor: {e}.")
         processor = None
 
     # Avoid load tokenizer, see:
     # https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/models/auto/processing_auto.py#L324
     if processor is not None and "Processor" not in processor.__class__.__name__:
+        logger.debug("The loaded processor is not an instance of Processor. Dropping it.")
         processor = None
 
     return {"tokenizer": tokenizer, "processor": processor}
@@ -18,7 +18,6 @@
 
 import os
 import random
-from enum import Enum, unique
 from typing import TYPE_CHECKING, Any
 
 import torch
@@ -28,7 +27,7 @@ from transformers.integrations import is_deepspeed_zero3_enabled
 from transformers.modeling_utils import is_fsdp_enabled
 
 from ...extras import logging
-from ...extras.constants import FILEEXT2TYPE
+from ...extras.constants import FILEEXT2TYPE, QuantizationMethod
 from ...extras.misc import check_version, get_current_device
 
 
@@ -41,19 +40,6 @@ if TYPE_CHECKING:
 logger = logging.get_logger(__name__)
 
 
-@unique
-class QuantizationMethod(str, Enum):
-    r"""Borrowed from `transformers.utils.quantization_config.QuantizationMethod`."""
-
-    BITS_AND_BYTES = "bitsandbytes"
-    GPTQ = "gptq"
-    AWQ = "awq"
-    AQLM = "aqlm"
-    QUANTO = "quanto"
-    EETQ = "eetq"
-    HQQ = "hqq"
-
-
 def _get_quantization_dataset(tokenizer: "PreTrainedTokenizer", model_args: "ModelArguments") -> list[dict[str, Any]]:
     r"""Prepare the tokenized dataset to perform AutoGPTQ. Do not use tensor output for JSON serialization."""
     if os.path.isfile(model_args.export_quantization_dataset):
@@ -145,7 +131,7 @@ def configure_quantization(
         logger.info_rank0(f"Quantizing model to {model_args.export_quantization_bit} bit with AutoGPTQ.")
 
     elif model_args.quantization_bit is not None:  # on-the-fly
-        if model_args.quantization_method == QuantizationMethod.BITS_AND_BYTES.value:
+        if model_args.quantization_method == QuantizationMethod.BNB:
             if model_args.quantization_bit == 8:
                 check_version("bitsandbytes>=0.37.0", mandatory=True)
                 init_kwargs["quantization_config"] = BitsAndBytesConfig(load_in_8bit=True)
@@ -173,7 +159,7 @@ def configure_quantization(
             init_kwargs["device_map"] = {"": get_current_device()}  # change auto device map for inference
 
             logger.info_rank0(f"Quantizing model to {model_args.quantization_bit} bit with bitsandbytes.")
-        elif model_args.quantization_method == QuantizationMethod.HQQ.value:
+        elif model_args.quantization_method == QuantizationMethod.HQQ:
             if model_args.quantization_bit not in [8, 6, 5, 4, 3, 2, 1]:
                 raise ValueError("HQQ only accepts 1/2/3/4/5/6/8-bit quantization.")
 
@@ -185,7 +171,7 @@ def configure_quantization(
                 nbits=model_args.quantization_bit, quant_zero=False, quant_scale=False, axis=0
             )  # use ATEN kernel (axis=0) for performance
             logger.info_rank0(f"Quantizing model to {model_args.quantization_bit} bit with HQQ.")
-        elif model_args.quantization_method == QuantizationMethod.EETQ.value:
+        elif model_args.quantization_method == QuantizationMethod.EETQ:
             if model_args.quantization_bit != 8:
                 raise ValueError("EETQ only accepts 8-bit quantization.")
 
@@ -91,7 +91,13 @@ def run_dpo(
         trainer.save_metrics("train", train_result.metrics)
         trainer.save_state()
         if trainer.is_world_process_zero() and finetuning_args.plot_loss:
-            plot_loss(training_args.output_dir, keys=["loss", "eval_loss", "rewards/accuracies"])
+            keys = ["loss", "rewards/accuracies"]
+            if isinstance(dataset_module["eval_dataset"], dict):
+                keys += [f"eval_{key}_loss" for key in dataset_module["eval_dataset"].keys()]
+            else:
+                keys += ["eval_loss"]
+
+            plot_loss(training_args.output_dir, keys=keys)
 
     # Evaluation
     if training_args.do_eval:
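For illustration (invented dataset names, not project code): how the new key selection behaves for a single eval set versus a dict of named eval sets. The same pattern repeats in the KTO, PT, RM, and SFT workflows below.

```python
def build_plot_keys(eval_dataset, base_keys):
    # Mirror of the key-selection logic above, as a standalone helper.
    keys = list(base_keys)
    if isinstance(eval_dataset, dict):
        keys += [f"eval_{name}_loss" for name in eval_dataset.keys()]
    else:
        keys += ["eval_loss"]
    return keys


print(build_plot_keys(object(), ["loss", "rewards/accuracies"]))
# ['loss', 'rewards/accuracies', 'eval_loss']
print(build_plot_keys({"alpaca": None, "sharegpt": None}, ["loss", "rewards/accuracies"]))
# ['loss', 'rewards/accuracies', 'eval_alpaca_loss', 'eval_sharegpt_loss']
```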
@@ -82,7 +82,13 @@ def run_kto(
         trainer.save_metrics("train", train_result.metrics)
         trainer.save_state()
         if trainer.is_world_process_zero() and finetuning_args.plot_loss:
-            plot_loss(training_args.output_dir, keys=["loss", "eval_loss", "rewards/chosen"])
+            keys = ["loss", "rewards/chosen"]
+            if isinstance(dataset_module["eval_dataset"], dict):
+                keys += [f"eval_{key}_loss" for key in dataset_module["eval_dataset"].keys()]
+            else:
+                keys += ["eval_loss"]
+
+            plot_loss(training_args.output_dir, keys=keys)
 
     # Evaluation
     if training_args.do_eval:
@@ -66,7 +66,13 @@ def run_pt(
         trainer.save_metrics("train", train_result.metrics)
         trainer.save_state()
         if trainer.is_world_process_zero() and finetuning_args.plot_loss:
-            plot_loss(training_args.output_dir, keys=["loss", "eval_loss"])
+            keys = ["loss"]
+            if isinstance(dataset_module["eval_dataset"], dict):
+                keys += [f"eval_{key}_loss" for key in dataset_module["eval_dataset"].keys()]
+            else:
+                keys += ["eval_loss"]
+
+            plot_loss(training_args.output_dir, keys=keys)
 
     # Evaluation
     if training_args.do_eval:
@@ -74,7 +74,15 @@ def run_rm(
         trainer.save_metrics("train", train_result.metrics)
         trainer.save_state()
         if trainer.is_world_process_zero() and finetuning_args.plot_loss:
-            plot_loss(training_args.output_dir, keys=["loss", "eval_loss", "eval_accuracy"])
+            keys = ["loss"]
+            if isinstance(dataset_module["eval_dataset"], dict):
+                keys += sum(
+                    [[f"eval_{key}_loss", f"eval_{key}_accuracy"] for key in dataset_module["eval_dataset"].keys()], []
+                )
+            else:
+                keys += ["eval_loss", "eval_accuracy"]
+
+            plot_loss(training_args.output_dir, keys=keys)
 
     # Evaluation
     if training_args.do_eval:
@@ -110,7 +110,15 @@ def run_sft(
         trainer.save_metrics("train", train_result.metrics)
         trainer.save_state()
         if trainer.is_world_process_zero() and finetuning_args.plot_loss:
-            plot_loss(training_args.output_dir, keys=["loss", "eval_loss", "eval_accuracy"])
+            keys = ["loss"]
+            if isinstance(dataset_module["eval_dataset"], dict):
+                keys += sum(
+                    [[f"eval_{key}_loss", f"eval_{key}_accuracy"] for key in dataset_module["eval_dataset"].keys()], []
+                )
+            else:
+                keys += ["eval_loss", "eval_accuracy"]
+
+            plot_loss(training_args.output_dir, keys=keys)
 
     if training_args.predict_with_generate:
         tokenizer.padding_side = "left"  # use left-padding in generation
@@ -42,7 +42,7 @@ def create_top() -> dict[str, "Component"]:
 
     with gr.Row():
         quantization_bit = gr.Dropdown(choices=["none", "8", "4"], value="none", allow_custom_value=True)
-        quantization_method = gr.Dropdown(choices=["bitsandbytes", "hqq", "eetq"], value="bitsandbytes")
+        quantization_method = gr.Dropdown(choices=["bnb", "hqq", "eetq"], value="bnb")
         template = gr.Dropdown(choices=list(TEMPLATES.keys()), value="default")
         rope_scaling = gr.Dropdown(choices=["none", "linear", "dynamic", "yarn", "llama3"], value="none")
         booster = gr.Dropdown(choices=["auto", "flashattn2", "unsloth", "liger_kernel"], value="auto")