move efficient_packing from data_args to model_args
Former-commit-id: 7b61659c707480bcf8c802c73e10d12ad5b9b965
This commit is contained in:
@@ -177,7 +177,7 @@ def get_dataset(
|
||||
|
||||
with training_args.main_process_first(desc="pre-process dataset"):
|
||||
preprocess_func, print_function = get_preprocess_and_print_func(
|
||||
data_args, training_args, stage, template, tokenizer, processor
|
||||
data_args, model_args, training_args, stage, template, tokenizer, processor
|
||||
)
|
||||
column_names = list(next(iter(dataset)).keys())
|
||||
kwargs = {}
|
||||
|
||||
@@ -29,12 +29,13 @@ from .processors.unsupervised import preprocess_unsupervised_dataset, print_unsu
|
||||
if TYPE_CHECKING:
|
||||
from transformers import PreTrainedTokenizer, ProcessorMixin, Seq2SeqTrainingArguments
|
||||
|
||||
from ..hparams import DataArguments
|
||||
from ..hparams import DataArguments, ModelArguments
|
||||
from .template import Template
|
||||
|
||||
|
||||
def get_preprocess_and_print_func(
|
||||
data_args: "DataArguments",
|
||||
model_args: "ModelArguments",
|
||||
training_args: "Seq2SeqTrainingArguments",
|
||||
stage: Literal["pt", "sft", "rm", "ppo", "kto"],
|
||||
template: "Template",
|
||||
@@ -49,7 +50,7 @@ def get_preprocess_and_print_func(
|
||||
)
|
||||
print_function = partial(print_unsupervised_dataset_example, tokenizer=tokenizer)
|
||||
elif stage == "sft" and not training_args.predict_with_generate:
|
||||
if data_args.packing or data_args.efficient_packing:
|
||||
if data_args.packing or model_args.efficient_packing:
|
||||
preprocess_func = partial(
|
||||
preprocess_packed_supervised_dataset,
|
||||
template=template,
|
||||
|
||||
@@ -23,7 +23,7 @@ from .processor_utils import get_paligemma_token_type_ids, get_pixel_values, gre
|
||||
if TYPE_CHECKING:
|
||||
from transformers import PreTrainedTokenizer, ProcessorMixin
|
||||
|
||||
from ...hparams import DataArguments
|
||||
from ...hparams import DataArguments, ModelArguments
|
||||
from ..template import Template
|
||||
|
||||
|
||||
@@ -125,6 +125,7 @@ def preprocess_packed_supervised_dataset(
|
||||
template: "Template",
|
||||
tokenizer: "PreTrainedTokenizer",
|
||||
data_args: "DataArguments",
|
||||
model_args: "ModelArguments"
|
||||
) -> Dict[str, List[List[int]]]:
|
||||
# build inputs with format `<bos> X1 Y1 <eos> <bos> X2 Y2 <eos>`
|
||||
# and labels with format `<ignore> ... <ignore> Y1 <eos> <ignore> ... <ignore> Y2 <eos>`
|
||||
@@ -176,7 +177,7 @@ def preprocess_packed_supervised_dataset(
|
||||
raise ValueError("The length of packed example should be identical to the cutoff length.")
|
||||
|
||||
model_inputs["input_ids"].append(packed_input_ids)
|
||||
if data_args.efficient_packing:
|
||||
if model_args.efficient_packing:
|
||||
model_inputs["attention_mask"].append(packed_attention_mask)
|
||||
else:
|
||||
model_inputs["attention_mask"].append([1] * data_args.cutoff_len)
|
||||
|
||||
Reference in New Issue
Block a user