Merge branch 'hiyouga:main' into main

Former-commit-id: 0395d0aafb69e86645e6b0a36b8f8dadb82219e0
This commit is contained in:
Johann-Peter Hartmann
2024-02-04 12:51:25 +00:00
committed by GitHub
13 changed files with 83 additions and 43 deletions

View File

@@ -155,9 +155,6 @@ def get_dataset(
dataset = dataset.to_iterable_dataset()
return dataset
if data_args.streaming:
raise ValueError("Turn off dataset streaming to save cache files.")
with training_args.main_process_first(desc="load dataset"):
all_datasets = []
for dataset_attr in get_dataset_list(data_args): # TODO: add split

View File

@@ -22,12 +22,8 @@ def preprocess_pretrain_dataset(
examples: Dict[str, List[Any]], tokenizer: "PreTrainedTokenizer", data_args: "DataArguments"
) -> Dict[str, List[List[int]]]:
# build grouped texts with format `X1 X2 X3 ...`
text_examples = [examples["prompt"][i][0]["content"] for i in range(len(examples["prompt"]))]
text_examples = [messages[0]["content"] + tokenizer.eos_token for messages in examples["prompt"]]
tokenized_examples = tokenizer(text_examples, add_special_tokens=False)
for i in range(len(tokenized_examples["input_ids"])):
tokenized_examples["input_ids"][i] += [tokenizer.eos_token_id]
tokenized_examples["attention_mask"][i] += [1]
concatenated_examples = {k: list(chain(*tokenized_examples[k])) for k in tokenized_examples.keys()}
total_length = len(concatenated_examples[list(concatenated_examples.keys())[0]])
block_size = data_args.cutoff_len
@@ -59,7 +55,12 @@ def preprocess_supervised_dataset(
input_ids, labels = [], []
for turn_idx, (source_ids, target_ids) in enumerate(
template.encode_multiturn(
tokenizer, messages, examples["system"][i], examples["tools"][i], data_args.cutoff_len
tokenizer,
messages,
examples["system"][i],
examples["tools"][i],
data_args.cutoff_len,
data_args.reserved_label_len,
)
):
if data_args.train_on_prompt:
@@ -147,7 +148,12 @@ def preprocess_unsupervised_dataset(
messages = examples["prompt"][i] + [{"role": Role.ASSISTANT, "content": ""}]
input_ids, labels = template.encode_oneturn(
tokenizer, messages, examples["system"][i], examples["tools"][i], data_args.cutoff_len
tokenizer,
messages,
examples["system"][i],
examples["tools"][i],
data_args.cutoff_len,
data_args.reserved_label_len,
)
if template.efficient_eos:
@@ -176,10 +182,20 @@ def preprocess_pairwise_dataset(
rejected_messages = examples["prompt"][i] + [examples["response"][i][1]]
prompt_ids, chosen_ids = template.encode_oneturn(
tokenizer, chosen_messages, examples["system"][i], examples["tools"][i], data_args.cutoff_len
tokenizer,
chosen_messages,
examples["system"][i],
examples["tools"][i],
data_args.cutoff_len,
data_args.reserved_label_len,
)
_, rejected_ids = template.encode_oneturn(
tokenizer, rejected_messages, examples["system"][i], examples["tools"][i], data_args.cutoff_len
tokenizer,
rejected_messages,
examples["system"][i],
examples["tools"][i],
data_args.cutoff_len,
data_args.reserved_label_len,
)
if template.efficient_eos:

View File

@@ -37,7 +37,7 @@ class Template:
system: Optional[str] = None,
tools: Optional[str] = None,
cutoff_len: Optional[int] = 1_000_000,
reserved_label_len: Optional[int] = 16,
reserved_label_len: Optional[int] = 1,
) -> Tuple[List[int], List[int]]:
r"""
Returns a single pair of token ids representing prompt and response respectively.
@@ -57,7 +57,7 @@ class Template:
system: Optional[str] = None,
tools: Optional[str] = None,
cutoff_len: Optional[int] = 1_000_000,
reserved_label_len: Optional[int] = 16,
reserved_label_len: Optional[int] = 1,
) -> Sequence[Tuple[List[int], List[int]]]:
r"""
Returns multiple pairs of token ids representing prompts and responses respectively.
@@ -144,10 +144,10 @@ class Template:
max_len=(cutoff_len - total_length),
reserved_label_len=reserved_label_len,
)
encoded_messages[i] = encoded_messages[i][:max_source_len]
encoded_messages[i + 1] = encoded_messages[i + 1][:max_target_len]
total_length += len(encoded_messages[i]) + len(encoded_messages[i + 1])
encoded_pairs.append((encoded_messages[i], encoded_messages[i + 1]))
source_ids = encoded_messages[i][:max_source_len]
target_ids = encoded_messages[i + 1][:max_target_len]
total_length += len(source_ids) + len(target_ids)
encoded_pairs.append((source_ids, target_ids))
return encoded_pairs
@@ -218,7 +218,7 @@ def register_template(
default_user_formatter = StringFormatter(slots=["{{content}}"])
default_assistant_formatter = StringFormatter(slots=["{{content}}"] + eos_slots)
default_function_formatter = FunctionFormatter(slots=["Action: {{name}}\nAction Input: {{arguments}}"] + eos_slots)
default_tool_formatter = ToolFormatter(slots="default")
default_tool_formatter = ToolFormatter(tool_format="default")
default_separator_formatter = EmptyFormatter()
templates[name] = template_class(
format_user=format_user or default_user_formatter,
@@ -356,6 +356,14 @@ register_template(
)
register_template(
name="cpm",
format_user=StringFormatter(slots=["<用户>{{content}}<AI>"]),
format_system=StringFormatter(slots=[{"bos_token"}, "{{content}}"]),
force_system=True,
)
register_template(
name="deepseek",
format_user=StringFormatter(slots=["User: {{content}}\n\nAssistant:"]),
@@ -464,7 +472,7 @@ register_template(
register_template(
name="orion",
format_user=StringFormatter(slots=["Human: {{content}}\n\nAssistant: </s>"]),
format_user=StringFormatter(slots=["Human: {{content}}\n\nAssistant: ", {"eos_token"}]),
format_system=StringFormatter(slots=[{"bos_token"}, "{{content}}"]),
force_system=True,
)