allow non-packing pretraining
Former-commit-id: 3fee5cc5a3db9ce874ad90f2500ec092d904bd4e
This commit is contained in:
@@ -21,8 +21,11 @@ logger = get_logger(__name__)
|
||||
def preprocess_pretrain_dataset(
|
||||
examples: Dict[str, List[Any]], tokenizer: "PreTrainedTokenizer", data_args: "DataArguments"
|
||||
) -> Dict[str, List[List[int]]]:
|
||||
# build grouped texts with format `X1 X2 X3 ...`
|
||||
# build grouped texts with format `X1 X2 X3 ...` if packing is enabled
|
||||
text_examples = [messages[0]["content"] + tokenizer.eos_token for messages in examples["prompt"]]
|
||||
if not data_args.packing:
|
||||
return tokenizer(text_examples, add_special_tokens=False, max_length=data_args.cutoff_len)
|
||||
|
||||
tokenized_examples = tokenizer(text_examples, add_special_tokens=False)
|
||||
concatenated_examples = {k: list(chain(*tokenized_examples[k])) for k in tokenized_examples.keys()}
|
||||
total_length = len(concatenated_examples[list(concatenated_examples.keys())[0]])
|
||||
@@ -245,7 +248,7 @@ def get_preprocess_and_print_func(
|
||||
preprocess_func = partial(preprocess_pretrain_dataset, tokenizer=tokenizer, data_args=data_args)
|
||||
print_function = partial(print_unsupervised_dataset_example, tokenizer=tokenizer)
|
||||
elif stage == "sft" and not training_args.predict_with_generate:
|
||||
if data_args.sft_packing:
|
||||
if data_args.packing:
|
||||
preprocess_func = partial(
|
||||
preprocess_packed_supervised_dataset, tokenizer=tokenizer, template=template, data_args=data_args
|
||||
)
|
||||
|
||||
@@ -36,8 +36,8 @@ class Template:
|
||||
messages: List[Dict[str, str]],
|
||||
system: Optional[str] = None,
|
||||
tools: Optional[str] = None,
|
||||
cutoff_len: Optional[int] = 1_000_000,
|
||||
reserved_label_len: Optional[int] = 1,
|
||||
cutoff_len: int = 1_000_000,
|
||||
reserved_label_len: int = 1,
|
||||
) -> Tuple[List[int], List[int]]:
|
||||
r"""
|
||||
Returns a single pair of token ids representing prompt and response respectively.
|
||||
@@ -56,8 +56,8 @@ class Template:
|
||||
messages: List[Dict[str, str]],
|
||||
system: Optional[str] = None,
|
||||
tools: Optional[str] = None,
|
||||
cutoff_len: Optional[int] = 1_000_000,
|
||||
reserved_label_len: Optional[int] = 1,
|
||||
cutoff_len: int = 1_000_000,
|
||||
reserved_label_len: int = 1,
|
||||
) -> Sequence[Tuple[List[int], List[int]]]:
|
||||
r"""
|
||||
Returns multiple pairs of token ids representing prompts and responses respectively.
|
||||
@@ -207,11 +207,11 @@ def _register_template(
|
||||
format_observation: Optional["Formatter"] = None,
|
||||
format_tools: Optional["Formatter"] = None,
|
||||
format_separator: Optional["Formatter"] = None,
|
||||
default_system: Optional[str] = "",
|
||||
stop_words: Optional[List[str]] = [],
|
||||
efficient_eos: Optional[bool] = False,
|
||||
replace_eos: Optional[bool] = False,
|
||||
force_system: Optional[bool] = False,
|
||||
default_system: str = "",
|
||||
stop_words: List[str] = [],
|
||||
efficient_eos: bool = False,
|
||||
replace_eos: bool = False,
|
||||
force_system: bool = False,
|
||||
) -> None:
|
||||
r"""
|
||||
Registers a chat template.
|
||||
@@ -279,9 +279,7 @@ def _jinja_escape(content: str) -> str:
|
||||
return content.replace("\n", r"\n").replace("'", r"\'")
|
||||
|
||||
|
||||
def _convert_slots_to_jinja(
|
||||
slots: "SLOTS", tokenizer: "PreTrainedTokenizer", placeholder: Optional[str] = "content"
|
||||
) -> str:
|
||||
def _convert_slots_to_jinja(slots: "SLOTS", tokenizer: "PreTrainedTokenizer", placeholder: str = "content") -> str:
|
||||
slot_items = []
|
||||
for slot in slots:
|
||||
if isinstance(slot, str):
|
||||
|
||||
Reference in New Issue
Block a user