update hparams

Former-commit-id: 1c4feac44192b1f540208837f5a530b0d3f5fb37
This commit is contained in:
hiyouga
2024-07-03 23:18:58 +08:00
parent 8ac4f87c91
commit 5acaa476d6
8 changed files with 72 additions and 28 deletions

View File

@@ -1,4 +1,7 @@
# Copyright 2024 the LlamaFactory team.
# Copyright 2024 OpenAccess AI Collective and the LlamaFactory team.
#
# This code is inspired by the OpenAccess AI Collective's axolotl library.
# https://github.com/OpenAccess-AI-Collective/axolotl/blob/main/src/axolotl/monkeypatch/utils.py
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -19,6 +22,44 @@ import torch
from transformers import DataCollatorForSeq2Seq
def prepare_4d_attention_mask(attention_mask_with_indices: "torch.Tensor", dtype: "torch.dtype") -> "torch.Tensor":
r"""
Expands the attention mask with indices from (batch_size, seq_len) to (batch_size, 1, seq_len, seq_len),
while handles packed sequences and transforms the mask to lower triangular form to prevent future peeking.
e.g.
```
[1, 1, 2, 2, 2, 0]
```
->
```
[[
[
[o, x, x, x, x, x],
[o, o, x, x, x, x],
[x, x, o, x, x, x],
[x, x, o, o, x, x],
[x, x, o, o, o, x],
[x, x, o, x, x, x],
]
]]
```
where `o` equals to `0.0`, `x` equals to `min_dtype`.
"""
bsz, seq_len = attention_mask_with_indices.size()
min_dtype = torch.finfo(dtype).min
expanded_mask = attention_mask_with_indices[:, None, None, :].expand(bsz, 1, seq_len, seq_len)
# Create a binary mask from the original mask where zeros remain zeros and all other values are set to one
padding_mask = torch.where(expanded_mask != 0, 1, 0)
# Create a block-diagonal mask.
attention_mask_4d = torch.eq(expanded_mask, expanded_mask.transpose(-1, -2)).int() * padding_mask
# Use the lower triangular mask to zero out the upper triangular part
attention_mask_4d *= torch.tril(torch.ones((seq_len, seq_len), dtype=torch.long))
# Invert the attention mask.
attention_mask_4d = torch.where(attention_mask_4d != 0, torch.tensor(0, dtype=dtype), min_dtype)
return attention_mask_4d
@dataclass
class PairwiseDataCollatorWithPadding(DataCollatorForSeq2Seq):
r"""

View File

@@ -177,7 +177,7 @@ def get_dataset(
with training_args.main_process_first(desc="pre-process dataset"):
preprocess_func, print_function = get_preprocess_and_print_func(
data_args, model_args, training_args, stage, template, tokenizer, processor
data_args, training_args, stage, template, tokenizer, processor
)
column_names = list(next(iter(dataset)).keys())
kwargs = {}

View File

@@ -29,13 +29,12 @@ from .processors.unsupervised import preprocess_unsupervised_dataset, print_unsu
if TYPE_CHECKING:
from transformers import PreTrainedTokenizer, ProcessorMixin, Seq2SeqTrainingArguments
from ..hparams import DataArguments, ModelArguments
from ..hparams import DataArguments
from .template import Template
def get_preprocess_and_print_func(
data_args: "DataArguments",
model_args: "ModelArguments",
training_args: "Seq2SeqTrainingArguments",
stage: Literal["pt", "sft", "rm", "ppo", "kto"],
template: "Template",
@@ -50,7 +49,7 @@ def get_preprocess_and_print_func(
)
print_function = partial(print_unsupervised_dataset_example, tokenizer=tokenizer)
elif stage == "sft" and not training_args.predict_with_generate:
if data_args.packing or model_args.efficient_packing:
if data_args.packing:
preprocess_func = partial(
preprocess_packed_supervised_dataset,
template=template,

View File

@@ -23,7 +23,7 @@ from .processor_utils import get_paligemma_token_type_ids, get_pixel_values, gre
if TYPE_CHECKING:
from transformers import PreTrainedTokenizer, ProcessorMixin
from ...hparams import DataArguments, ModelArguments
from ...hparams import DataArguments
from ..template import Template
@@ -125,7 +125,6 @@ def preprocess_packed_supervised_dataset(
template: "Template",
tokenizer: "PreTrainedTokenizer",
data_args: "DataArguments",
model_args: "ModelArguments"
) -> Dict[str, List[List[int]]]:
# build inputs with format `<bos> X1 Y1 <eos> <bos> X2 Y2 <eos>`
# and labels with format `<ignore> ... <ignore> Y1 <eos> <ignore> ... <ignore> Y2 <eos>`
@@ -161,26 +160,30 @@ def preprocess_packed_supervised_dataset(
model_inputs = {"input_ids": [], "attention_mask": [], "labels": []}
knapsacks = greedy_knapsack(lengths, data_args.cutoff_len)
for knapsack in knapsacks:
packed_input_ids, packed_attention_mask, packed_labels = [], [], []
packed_input_ids, packed_attention_masks, packed_labels = [], [], []
for i, length in enumerate(knapsack):
index = length2indexes[length].pop()
packed_input_ids += batch_input_ids[index]
packed_labels += batch_labels[index]
packed_attention_mask += [i+1]*len(batch_input_ids[index])
if data_args.neat_packing:
packed_attention_masks += [i + 1] * len(batch_input_ids[index]) # start from 1
else:
packed_attention_masks += [1] * len(batch_input_ids[index])
if len(packed_input_ids) < data_args.cutoff_len:
pad_length = data_args.cutoff_len - len(packed_input_ids)
packed_input_ids += [tokenizer.pad_token_id] * pad_length
packed_labels += [IGNORE_INDEX] * pad_length
if data_args.neat_packing:
packed_attention_masks += [0] * pad_length
else:
packed_attention_masks += [1] * pad_length # more efficient flash_attn
if len(packed_input_ids) != data_args.cutoff_len:
raise ValueError("The length of packed example should be identical to the cutoff length.")
model_inputs["input_ids"].append(packed_input_ids)
if model_args.efficient_packing:
model_inputs["attention_mask"].append(packed_attention_mask)
else:
model_inputs["attention_mask"].append([1] * data_args.cutoff_len)
model_inputs["attention_mask"].append(packed_attention_masks)
model_inputs["labels"].append(packed_labels)
return model_inputs