Former-commit-id: 7d917e03e2df570139bae18227d9c7303a12de2a
This commit is contained in:
hiyouga
2024-08-09 18:03:00 +08:00
parent 4e8861e653
commit 5af32ce705
6 changed files with 35 additions and 34 deletions

View File

@@ -206,8 +206,6 @@ def get_dataset(
template = get_template_and_fix_tokenizer(tokenizer, data_args.template, data_args.tool_format)
if data_args.train_on_prompt and template.efficient_eos:
raise ValueError("Current template does not support `train_on_prompt`.")
if stage!="sft" and data_args.mask_history:
raise ValueError("`Train on the last turn only` is only valid for sft training.")
# Load tokenized dataset
if data_args.tokenized_path is not None:

View File

@@ -53,8 +53,11 @@ def _encode_supervised_example(
input_ids += [image_token_id] * getattr(processor, "image_seq_length")
labels += [IGNORE_INDEX] * getattr(processor, "image_seq_length")
encoded_pairs = template.encode_multiturn(tokenizer, messages, system, tools, mask_history)
encoded_pairs = template.encode_multiturn(tokenizer, messages, system, tools)
total_length = 1 if template.efficient_eos else 0
if mask_history:
encoded_pairs = encoded_pairs[::-1] # high priority for last turns
for turn_idx, (source_ids, target_ids) in enumerate(encoded_pairs):
if total_length >= cutoff_len:
break
@@ -66,20 +69,23 @@ def _encode_supervised_example(
if train_on_prompt:
source_label = source_ids
elif turn_idx != 0 and template.efficient_eos:
elif template.efficient_eos:
source_label = [tokenizer.eos_token_id] + [IGNORE_INDEX] * (source_len - 1)
else:
source_label = [IGNORE_INDEX] * source_len
if mask_history:
target_label = target_ids if turn_idx==0 else [IGNORE_INDEX] * target_len
if mask_history and turn_idx != 0: # train on the last turn only
target_label = [IGNORE_INDEX] * target_len
else:
target_label = target_ids
if mask_history: # reversed sequences
input_ids = source_ids + target_ids + input_ids
labels = source_label + target_label + labels
else:
target_label = target_ids
input_ids += source_ids + target_ids
labels += source_label + target_label
if template.efficient_eos:
input_ids += [tokenizer.eos_token_id]
labels += [tokenizer.eos_token_id]

View File

@@ -69,16 +69,12 @@ class Template:
messages: Sequence[Dict[str, str]],
system: Optional[str] = None,
tools: Optional[str] = None,
mask_history: bool = False,
) -> List[Tuple[List[int], List[int]]]:
r"""
Returns multiple pairs of token ids representing prompts and responses respectively.
"""
encoded_messages = self._encode(tokenizer, messages, system, tools)
if not mask_history:
return [(encoded_messages[i], encoded_messages[i + 1]) for i in range(0, len(encoded_messages), 2)]
else:
return [(encoded_messages[i], encoded_messages[i + 1]) for i in range(len(encoded_messages)-2, -1, -2)]
return [(encoded_messages[i], encoded_messages[i + 1]) for i in range(0, len(encoded_messages), 2)]
def extract_tool(self, content: str) -> Union[str, List[Tuple[str, str]]]:
r"""
@@ -594,10 +590,10 @@ _register_template(
format_separator=EmptyFormatter(slots=["\n"]),
format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
default_system=(
"You are an AI programming assistant, utilizing the Deepseek Coder model, "
"developed by Deepseek Company, and you only answer questions related to computer science. "
"You are an AI programming assistant, utilizing the DeepSeek Coder model, "
"developed by DeepSeek Company, and you only answer questions related to computer science. "
"For politically sensitive questions, security and privacy issues, "
"and other non-computer science questions, you will refuse to answer\n"
"and other non-computer science questions, you will refuse to answer.\n"
),
)