allow non-packing pretraining

Former-commit-id: 3fee5cc5a3db9ce874ad90f2500ec092d904bd4e

Author: hiyouga
Date:   2024-03-09 22:21:46 +08:00
Commit: 4881f4e631 (parent: c631799f5d)

22 changed files with 64 additions and 67 deletions

@@ -21,8 +21,11 @@ logger = get_logger(__name__)
 def preprocess_pretrain_dataset(
     examples: Dict[str, List[Any]], tokenizer: "PreTrainedTokenizer", data_args: "DataArguments"
 ) -> Dict[str, List[List[int]]]:
-    # build grouped texts with format `X1 X2 X3 ...`
+    # build grouped texts with format `X1 X2 X3 ...` if packing is enabled
     text_examples = [messages[0]["content"] + tokenizer.eos_token for messages in examples["prompt"]]
+    if not data_args.packing:
+        return tokenizer(text_examples, add_special_tokens=False, max_length=data_args.cutoff_len)
+
     tokenized_examples = tokenizer(text_examples, add_special_tokens=False)
     concatenated_examples = {k: list(chain(*tokenized_examples[k])) for k in tokenized_examples.keys()}
     total_length = len(concatenated_examples[list(concatenated_examples.keys())[0]])
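With this change, pretraining no longer forces every corpus through the concatenate-and-chunk path: when packing is off, each document is tokenized on its own and capped at cutoff_len, so short documents keep their natural length. A standalone sketch of both branches, assuming a Hugging Face tokenizer and hypothetical packing/cutoff_len values (not the repository's exact code; the sketch also passes truncation=True explicitly, since max_length alone does not truncate in recent transformers versions):

from itertools import chain

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # any tokenizer with an eos_token
packing, cutoff_len = False, 8  # hypothetical DataArguments values

texts = ["a short document", "a second, somewhat longer document"]
text_examples = [text + tokenizer.eos_token for text in texts]

if not packing:
    # non-packing: one example per document, truncated at cutoff_len
    result = tokenizer(text_examples, add_special_tokens=False, truncation=True, max_length=cutoff_len)
else:
    # packing: concatenate everything, then slice into fixed cutoff_len blocks
    tokenized = tokenizer(text_examples, add_special_tokens=False)
    concatenated = {k: list(chain(*tokenized[k])) for k in tokenized.keys()}
    total_length = (len(concatenated["input_ids"]) // cutoff_len) * cutoff_len
    result = {
        k: [v[i : i + cutoff_len] for i in range(0, total_length, cutoff_len)]
        for k, v in concatenated.items()
    }

print(result["input_ids"])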
@@ -245,7 +248,7 @@ def get_preprocess_and_print_func(
         preprocess_func = partial(preprocess_pretrain_dataset, tokenizer=tokenizer, data_args=data_args)
         print_function = partial(print_unsupervised_dataset_example, tokenizer=tokenizer)
     elif stage == "sft" and not training_args.predict_with_generate:
-        if data_args.sft_packing:
+        if data_args.packing:
             preprocess_func = partial(
                 preprocess_packed_supervised_dataset, tokenizer=tokenizer, template=template, data_args=data_args
             )

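The second hunk retires the SFT-only sft_packing flag in favor of the shared packing field, so one switch now governs both pretraining and supervised fine-tuning. A rough sketch of the resulting dispatch, with a stripped-down stand-in for DataArguments (the real dataclass has many more fields):

from dataclasses import dataclass

@dataclass
class DataArguments:  # minimal stand-in for illustration only
    packing: bool = False

def choose_preprocess(stage: str, data_args: DataArguments) -> str:
    # pretraining checks packing inside preprocess_pretrain_dataset itself;
    # SFT picks between the packed and per-example preprocessors up front
    if stage == "pt":
        return "preprocess_pretrain_dataset"
    if stage == "sft" and data_args.packing:
        return "preprocess_packed_supervised_dataset"
    return "preprocess_supervised_dataset"

assert choose_preprocess("sft", DataArguments(packing=True)) == "preprocess_packed_supervised_dataset"
assert choose_preprocess("sft", DataArguments()) == "preprocess_supervised_dataset"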
@@ -36,8 +36,8 @@ class Template:
         messages: List[Dict[str, str]],
         system: Optional[str] = None,
         tools: Optional[str] = None,
-        cutoff_len: Optional[int] = 1_000_000,
-        reserved_label_len: Optional[int] = 1,
+        cutoff_len: int = 1_000_000,
+        reserved_label_len: int = 1,
     ) -> Tuple[List[int], List[int]]:
         r"""
         Returns a single pair of token ids representing prompt and response respectively.
@@ -56,8 +56,8 @@ class Template:
         messages: List[Dict[str, str]],
         system: Optional[str] = None,
         tools: Optional[str] = None,
-        cutoff_len: Optional[int] = 1_000_000,
-        reserved_label_len: Optional[int] = 1,
+        cutoff_len: int = 1_000_000,
+        reserved_label_len: int = 1,
     ) -> Sequence[Tuple[List[int], List[int]]]:
         r"""
         Returns multiple pairs of token ids representing prompts and responses respectively.
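In both encode_oneturn and encode_multiturn the change from Optional[int] to int is a typing correction rather than a behavior change: Optional[X] means "X or None", not "has a default", and these parameters were never meant to accept None. A generic illustration of what the stricter annotation buys under a type checker:

from typing import Optional

def encode_old(cutoff_len: Optional[int] = 1_000_000) -> None:
    ...  # the annotation invites callers to pass None, which was never handled

def encode_new(cutoff_len: int = 1_000_000) -> None:
    ...

encode_old(None)  # accepted by mypy/pyright, despite being unsupported
encode_new(None)  # now flagged by type checkers: expected "int", got "None"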
@@ -207,11 +207,11 @@ def _register_template(
     format_observation: Optional["Formatter"] = None,
     format_tools: Optional["Formatter"] = None,
     format_separator: Optional["Formatter"] = None,
-    default_system: Optional[str] = "",
-    stop_words: Optional[List[str]] = [],
-    efficient_eos: Optional[bool] = False,
-    replace_eos: Optional[bool] = False,
-    force_system: Optional[bool] = False,
+    default_system: str = "",
+    stop_words: List[str] = [],
+    efficient_eos: bool = False,
+    replace_eos: bool = False,
+    force_system: bool = False,
 ) -> None:
     r"""
     Registers a chat template.
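One thing this hunk keeps is the mutable [] default for stop_words, which is shared by every call that omits the argument. That is harmless as long as _register_template only reads the list, but the classic pitfall is worth a reminder (generic example, not the repository's code):

from typing import List, Optional

def register(stop_words: List[str] = []) -> List[str]:
    stop_words.append("</s>")  # mutates the single shared default list
    return stop_words

print(register())  # ['</s>']
print(register())  # ['</s>', '</s>'] -- state leaked across calls

def register_safe(stop_words: Optional[List[str]] = None) -> List[str]:
    stop_words = [] if stop_words is None else stop_words
    stop_words.append("</s>")
    return stop_words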
@@ -279,9 +279,7 @@ def _jinja_escape(content: str) -> str:
     return content.replace("\n", r"\n").replace("'", r"\'")
 
 
-def _convert_slots_to_jinja(
-    slots: "SLOTS", tokenizer: "PreTrainedTokenizer", placeholder: Optional[str] = "content"
-) -> str:
+def _convert_slots_to_jinja(slots: "SLOTS", tokenizer: "PreTrainedTokenizer", placeholder: str = "content") -> str:
     slot_items = []
     for slot in slots:
         if isinstance(slot, str):
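The final hunk collapses the _convert_slots_to_jinja signature onto one line and ends mid-function, so most of the body is out of view. For context, the helper turns a template's slots into a Jinja expression string. A simplified sketch of how a string slot can be converted, reusing _jinja_escape from the hunk above (how the pieces are split and joined is an assumption about code not shown here):

from typing import List

def _jinja_escape(content: str) -> str:
    return content.replace("\n", r"\n").replace("'", r"\'")

def convert_string_slots(slots: List[str], placeholder: str = "content") -> str:
    # simplified stand-in for _convert_slots_to_jinja: string slots only,
    # with "{{content}}" spliced in as a Jinja variable reference
    marker = "{{" + placeholder + "}}"
    slot_items = []
    for slot in slots:
        if marker in slot:
            before, after = slot.split(marker, 1)
            if before:
                slot_items.append("'" + _jinja_escape(before) + "'")
            slot_items.append(placeholder)
            if after:
                slot_items.append("'" + _jinja_escape(after) + "'")
        else:
            slot_items.append("'" + _jinja_escape(slot) + "'")
    return " + ".join(slot_items)

print(convert_string_slots(["<|user|>\n{{content}}"]))
# prints: '<|user|>\n' + content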