Merge branch 'hiyouga:main' into minicpmv

Former-commit-id: 873b2d5888038e2328a12a6eb7c84099ba7ca1f3
This commit is contained in:
Zhangchi Feng
2025-01-04 11:20:33 +08:00
committed by GitHub
16 changed files with 155 additions and 54 deletions

View File

@@ -17,7 +17,7 @@ from typing import TYPE_CHECKING, Callable, Literal, Optional, Tuple
from .processors.feedback import preprocess_feedback_dataset
from .processors.pairwise import preprocess_pairwise_dataset, print_pairwise_dataset_example
from .processors.pretrain import preprocess_pretrain_dataset
from .processors.pretrain import preprocess_pretrain_dataset, print_pretrain_dataset_example
from .processors.supervised import (
preprocess_packed_supervised_dataset,
preprocess_supervised_dataset,
@@ -47,7 +47,7 @@ def get_preprocess_and_print_func(
tokenizer=tokenizer,
data_args=data_args,
)
print_function = partial(print_unsupervised_dataset_example, tokenizer=tokenizer)
print_function = partial(print_pretrain_dataset_example, tokenizer=tokenizer)
elif stage == "sft" and not do_generate:
if data_args.packing:
if data_args.neat_packing: # hack datasets to have int32 attention mask

View File

@@ -52,3 +52,8 @@ def preprocess_pretrain_dataset(
result["input_ids"][i][0] = tokenizer.bos_token_id
return result
def print_pretrain_dataset_example(example: Dict[str, List[int]], tokenizer: "PreTrainedTokenizer") -> None:
print("input_ids:\n{}".format(example["input_ids"]))
print("inputs:\n{}".format(tokenizer.decode(example["input_ids"], skip_special_tokens=False)))

View File

@@ -100,3 +100,5 @@ def preprocess_unsupervised_dataset(
def print_unsupervised_dataset_example(example: Dict[str, List[int]], tokenizer: "PreTrainedTokenizer") -> None:
print("input_ids:\n{}".format(example["input_ids"]))
print("inputs:\n{}".format(tokenizer.decode(example["input_ids"], skip_special_tokens=False)))
print("label_ids:\n{}".format(example["labels"]))
print("labels:\n{}".format(tokenizer.decode(example["labels"], skip_special_tokens=False)))