modify style
Former-commit-id: 54b713d0c4ffdfc6a7faeb14471b58bb1cd8acf5
This commit is contained in:
@@ -3,6 +3,7 @@ from .loader import get_dataset
|
||||
from .template import Template, get_template_and_fix_tokenizer, templates
|
||||
from .utils import Role, split_dataset
|
||||
|
||||
|
||||
__all__ = [
|
||||
"PairwiseDataCollatorWithPadding",
|
||||
"get_dataset",
|
||||
|
||||
@@ -13,9 +13,7 @@ if TYPE_CHECKING:
|
||||
from .parser import DatasetAttr
|
||||
|
||||
|
||||
def convert_alpaca(
|
||||
examples: Dict[str, List[Any]], dataset_attr: "DatasetAttr"
|
||||
) -> Dict[str, List[Any]]:
|
||||
def convert_alpaca(examples: Dict[str, List[Any]], dataset_attr: "DatasetAttr") -> Dict[str, List[Any]]:
|
||||
outputs = {"prompt": [], "response": [], "system": [], "tools": []}
|
||||
for i in range(len(examples[dataset_attr.prompt])):
|
||||
prompt = []
|
||||
@@ -33,16 +31,11 @@ def convert_alpaca(
|
||||
|
||||
prompt.append({"role": Role.USER.value, "content": "\n".join(content)})
|
||||
|
||||
if dataset_attr.response and isinstance(
|
||||
examples[dataset_attr.response][i], list
|
||||
):
|
||||
if dataset_attr.response and isinstance(examples[dataset_attr.response][i], list):
|
||||
response = [
|
||||
{"role": Role.ASSISTANT.value, "content": content}
|
||||
for content in examples[dataset_attr.response][i]
|
||||
{"role": Role.ASSISTANT.value, "content": content} for content in examples[dataset_attr.response][i]
|
||||
]
|
||||
elif dataset_attr.response and isinstance(
|
||||
examples[dataset_attr.response][i], str
|
||||
):
|
||||
elif dataset_attr.response and isinstance(examples[dataset_attr.response][i], str):
|
||||
response = [
|
||||
{
|
||||
"role": Role.ASSISTANT.value,
|
||||
@@ -54,17 +47,13 @@ def convert_alpaca(
|
||||
|
||||
outputs["prompt"].append(prompt)
|
||||
outputs["response"].append(response)
|
||||
outputs["system"].append(
|
||||
examples[dataset_attr.system][i] if dataset_attr.system else ""
|
||||
)
|
||||
outputs["system"].append(examples[dataset_attr.system][i] if dataset_attr.system else "")
|
||||
outputs["tools"].append("")
|
||||
outputs["images"].append([])
|
||||
return outputs
|
||||
|
||||
|
||||
def convert_sharegpt(
|
||||
examples: Dict[str, List[Any]], dataset_attr: "DatasetAttr"
|
||||
) -> Dict[str, List[Any]]:
|
||||
def convert_sharegpt(examples: Dict[str, List[Any]], dataset_attr: "DatasetAttr") -> Dict[str, List[Any]]:
|
||||
outputs = {"prompt": [], "response": [], "system": [], "tools": []}
|
||||
tag_mapping = {
|
||||
dataset_attr.user_tag: Role.USER.value,
|
||||
@@ -77,10 +66,7 @@ def convert_sharegpt(
|
||||
even_tags = (dataset_attr.assistant_tag, dataset_attr.function_tag)
|
||||
accept_tags = (odd_tags, even_tags)
|
||||
for i, messages in enumerate(examples[dataset_attr.messages]):
|
||||
if (
|
||||
dataset_attr.system_tag
|
||||
and messages[0][dataset_attr.role_tag] == dataset_attr.system_tag
|
||||
):
|
||||
if dataset_attr.system_tag and messages[0][dataset_attr.role_tag] == dataset_attr.system_tag:
|
||||
system = messages[0][dataset_attr.content_tag]
|
||||
messages = messages[1:]
|
||||
else:
|
||||
@@ -105,17 +91,13 @@ def convert_sharegpt(
|
||||
outputs["prompt"].append(aligned_messages[:-1])
|
||||
outputs["response"].append(aligned_messages[-1:])
|
||||
outputs["system"].append(system)
|
||||
outputs["tools"].append(
|
||||
examples[dataset_attr.tools][i] if dataset_attr.tools else ""
|
||||
)
|
||||
outputs["tools"].append(examples[dataset_attr.tools][i] if dataset_attr.tools else "")
|
||||
outputs["images"].append([])
|
||||
|
||||
return outputs
|
||||
|
||||
|
||||
def convert_llava(
|
||||
examples: Dict[str, List[Any]], dataset_attr: "DatasetAttr"
|
||||
) -> Dict[str, List[Any]]:
|
||||
def convert_llava(examples: Dict[str, List[Any]], dataset_attr: "DatasetAttr") -> Dict[str, List[Any]]:
|
||||
outputs = {"prompt": [], "response": [], "system": [], "tools": [], "images": []}
|
||||
tag_mapping = {
|
||||
dataset_attr.user_tag: Role.USER.value,
|
||||
@@ -128,10 +110,7 @@ def convert_llava(
|
||||
even_tags = (dataset_attr.assistant_tag, dataset_attr.function_tag)
|
||||
accept_tags = (odd_tags, even_tags)
|
||||
for i, messages in enumerate(examples[dataset_attr.messages]):
|
||||
if (
|
||||
dataset_attr.system_tag
|
||||
and messages[0][dataset_attr.role_tag] == dataset_attr.system_tag
|
||||
):
|
||||
if dataset_attr.system_tag and messages[0][dataset_attr.role_tag] == dataset_attr.system_tag:
|
||||
system = messages[0][dataset_attr.content_tag]
|
||||
messages = messages[1:]
|
||||
else:
|
||||
@@ -156,13 +135,9 @@ def convert_llava(
|
||||
outputs["prompt"].append(aligned_messages[:-1])
|
||||
outputs["response"].append(aligned_messages[-1:])
|
||||
outputs["system"].append(system)
|
||||
outputs["tools"].append(
|
||||
examples[dataset_attr.tools][i] if dataset_attr.tools else ""
|
||||
)
|
||||
outputs["tools"].append(examples[dataset_attr.tools][i] if dataset_attr.tools else "")
|
||||
print(examples[dataset_attr.images][i])
|
||||
outputs["images"].append(
|
||||
examples[dataset_attr.images][i] if dataset_attr.images else []
|
||||
)
|
||||
outputs["images"].append(examples[dataset_attr.images][i] if dataset_attr.images else [])
|
||||
|
||||
return outputs
|
||||
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import inspect
|
||||
import os
|
||||
from typing import TYPE_CHECKING, Literal, Union, Optional
|
||||
from typing import TYPE_CHECKING, Literal, Optional, Union
|
||||
|
||||
from datasets import load_dataset, load_from_disk
|
||||
|
||||
@@ -13,9 +13,10 @@ from .preprocess import get_preprocess_and_print_func
|
||||
from .template import get_template_and_fix_tokenizer
|
||||
from .utils import checksum, merge_dataset
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from datasets import Dataset, IterableDataset
|
||||
from transformers import Seq2SeqTrainingArguments, AutoProcessor
|
||||
from transformers import AutoProcessor, Seq2SeqTrainingArguments
|
||||
from transformers.tokenization_utils import PreTrainedTokenizer
|
||||
|
||||
from ..hparams import DataArguments, ModelArguments
|
||||
@@ -78,20 +79,14 @@ def load_single_dataset(
|
||||
split=data_args.split,
|
||||
cache_dir=cache_dir,
|
||||
token=model_args.ms_hub_token,
|
||||
use_streaming=(
|
||||
data_args.streaming and (dataset_attr.load_from != "file")
|
||||
),
|
||||
use_streaming=(data_args.streaming and (dataset_attr.load_from != "file")),
|
||||
)
|
||||
if isinstance(dataset, MsDataset):
|
||||
dataset = dataset.to_hf_dataset()
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"Please install modelscope via `pip install modelscope -U`"
|
||||
)
|
||||
raise ImportError("Please install modelscope via `pip install modelscope -U`")
|
||||
else:
|
||||
if (
|
||||
"trust_remote_code" in inspect.signature(load_dataset).parameters
|
||||
): # for datasets==2.16.0
|
||||
if "trust_remote_code" in inspect.signature(load_dataset).parameters: # for datasets==2.16.0
|
||||
kwargs = {"trust_remote_code": True}
|
||||
else:
|
||||
kwargs = {}
|
||||
@@ -108,9 +103,7 @@ def load_single_dataset(
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
if data_args.streaming and (
|
||||
dataset_attr.load_from == "file"
|
||||
): # faster than specifying streaming=True
|
||||
if data_args.streaming and (dataset_attr.load_from == "file"): # faster than specifying streaming=True
|
||||
dataset = dataset.to_iterable_dataset() # TODO: add num shards parameter
|
||||
|
||||
if data_args.max_samples is not None: # truncate dataset
|
||||
@@ -135,13 +128,9 @@ def get_dataset(
|
||||
# Load tokenized dataset
|
||||
if data_args.tokenized_path is not None:
|
||||
if has_tokenized_data(data_args.tokenized_path):
|
||||
logger.warning(
|
||||
"Loading dataset from disk will ignore other data arguments."
|
||||
)
|
||||
logger.warning("Loading dataset from disk will ignore other data arguments.")
|
||||
dataset = load_from_disk(data_args.tokenized_path)
|
||||
logger.info(
|
||||
"Loaded tokenized dataset from {}.".format(data_args.tokenized_path)
|
||||
)
|
||||
logger.info("Loaded tokenized dataset from {}.".format(data_args.tokenized_path))
|
||||
if data_args.streaming:
|
||||
dataset = dataset.to_iterable_dataset()
|
||||
return dataset
|
||||
@@ -152,16 +141,10 @@ def get_dataset(
|
||||
with training_args.main_process_first(desc="load dataset"):
|
||||
all_datasets = []
|
||||
for dataset_attr in get_dataset_list(data_args):
|
||||
if (stage == "rm" and dataset_attr.ranking is False) or (
|
||||
stage != "rm" and dataset_attr.ranking is True
|
||||
):
|
||||
raise ValueError(
|
||||
"The dataset is not applicable in the current training stage."
|
||||
)
|
||||
if (stage == "rm" and dataset_attr.ranking is False) or (stage != "rm" and dataset_attr.ranking is True):
|
||||
raise ValueError("The dataset is not applicable in the current training stage.")
|
||||
|
||||
all_datasets.append(
|
||||
load_single_dataset(dataset_attr, model_args, data_args)
|
||||
)
|
||||
all_datasets.append(load_single_dataset(dataset_attr, model_args, data_args))
|
||||
dataset = merge_dataset(all_datasets, data_args, training_args)
|
||||
|
||||
with training_args.main_process_first(desc="pre-process dataset"):
|
||||
@@ -177,21 +160,13 @@ def get_dataset(
|
||||
desc="Running tokenizer on dataset",
|
||||
)
|
||||
|
||||
dataset = dataset.map(
|
||||
preprocess_func, batched=True, remove_columns=column_names, **kwargs
|
||||
)
|
||||
dataset = dataset.map(preprocess_func, batched=True, remove_columns=column_names, **kwargs)
|
||||
|
||||
if data_args.tokenized_path is not None:
|
||||
if training_args.should_save:
|
||||
dataset.save_to_disk(data_args.tokenized_path)
|
||||
logger.info(
|
||||
"Tokenized dataset saved at {}.".format(data_args.tokenized_path)
|
||||
)
|
||||
logger.info(
|
||||
"Please restart the training with `--tokenized_path {}`.".format(
|
||||
data_args.tokenized_path
|
||||
)
|
||||
)
|
||||
logger.info("Tokenized dataset saved at {}.".format(data_args.tokenized_path))
|
||||
logger.info("Please restart the training with `--tokenized_path {}`.".format(data_args.tokenized_path))
|
||||
|
||||
exit(0)
|
||||
|
||||
@@ -199,8 +174,6 @@ def get_dataset(
|
||||
try:
|
||||
print_function(next(iter(dataset)))
|
||||
except StopIteration:
|
||||
raise RuntimeError(
|
||||
"Cannot find valid samples, check `data/README.md` for the data format."
|
||||
)
|
||||
raise RuntimeError("Cannot find valid samples, check `data/README.md` for the data format.")
|
||||
|
||||
return dataset
|
||||
|
||||
@@ -50,9 +50,7 @@ class DatasetAttr:
|
||||
def __repr__(self) -> str:
|
||||
return self.dataset_name
|
||||
|
||||
def set_attr(
|
||||
self, key: str, obj: Dict[str, Any], default: Optional[Any] = None
|
||||
) -> None:
|
||||
def set_attr(self, key: str, obj: Dict[str, Any], default: Optional[Any] = None) -> None:
|
||||
setattr(self, key, obj.get(key, default))
|
||||
|
||||
|
||||
@@ -71,16 +69,12 @@ def get_dataset_list(data_args: "DataArguments") -> List["DatasetAttr"]:
|
||||
except Exception as err:
|
||||
if len(dataset_names) != 0:
|
||||
raise ValueError(
|
||||
"Cannot open {} due to {}.".format(
|
||||
os.path.join(data_args.dataset_dir, DATA_CONFIG), str(err)
|
||||
)
|
||||
"Cannot open {} due to {}.".format(os.path.join(data_args.dataset_dir, DATA_CONFIG), str(err))
|
||||
)
|
||||
dataset_info = None
|
||||
|
||||
if data_args.interleave_probs is not None:
|
||||
data_args.interleave_probs = [
|
||||
float(prob.strip()) for prob in data_args.interleave_probs.split(",")
|
||||
]
|
||||
data_args.interleave_probs = [float(prob.strip()) for prob in data_args.interleave_probs.split(",")]
|
||||
|
||||
dataset_list: List[DatasetAttr] = []
|
||||
for name in dataset_names:
|
||||
@@ -98,21 +92,13 @@ def get_dataset_list(data_args: "DataArguments") -> List["DatasetAttr"]:
|
||||
|
||||
if has_hf_url or has_ms_url:
|
||||
if (use_modelscope() and has_ms_url) or (not has_hf_url):
|
||||
dataset_attr = DatasetAttr(
|
||||
"ms_hub", dataset_name=dataset_info[name]["ms_hub_url"]
|
||||
)
|
||||
dataset_attr = DatasetAttr("ms_hub", dataset_name=dataset_info[name]["ms_hub_url"])
|
||||
else:
|
||||
dataset_attr = DatasetAttr(
|
||||
"hf_hub", dataset_name=dataset_info[name]["hf_hub_url"]
|
||||
)
|
||||
dataset_attr = DatasetAttr("hf_hub", dataset_name=dataset_info[name]["hf_hub_url"])
|
||||
elif "script_url" in dataset_info[name]:
|
||||
dataset_attr = DatasetAttr(
|
||||
"script", dataset_name=dataset_info[name]["script_url"]
|
||||
)
|
||||
dataset_attr = DatasetAttr("script", dataset_name=dataset_info[name]["script_url"])
|
||||
else:
|
||||
dataset_attr = DatasetAttr(
|
||||
"file", dataset_name=dataset_info[name]["file_name"]
|
||||
)
|
||||
dataset_attr = DatasetAttr("file", dataset_name=dataset_info[name]["file_name"])
|
||||
|
||||
dataset_attr.set_attr("file_sha1", dataset_info[name])
|
||||
dataset_attr.set_attr("subset", dataset_info[name])
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
from functools import partial
|
||||
from itertools import chain
|
||||
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Literal, Tuple, Optional
|
||||
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Literal, Optional, Tuple
|
||||
|
||||
from ..extras.constants import IGNORE_INDEX
|
||||
from ..extras.logging import get_logger
|
||||
@@ -9,7 +9,7 @@ from .utils import Role
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from transformers import Seq2SeqTrainingArguments
|
||||
from transformers.tokenization_utils import PreTrainedTokenizer, AutoProcessor
|
||||
from transformers.tokenization_utils import AutoProcessor, PreTrainedTokenizer
|
||||
|
||||
from ..hparams import DataArguments
|
||||
from .template import Template
|
||||
@@ -24,22 +24,16 @@ def preprocess_pretrain_dataset(
|
||||
data_args: "DataArguments",
|
||||
) -> Dict[str, List[List[int]]]:
|
||||
# build grouped texts with format `X1 X2 X3 ...` if packing is enabled
|
||||
text_examples = [
|
||||
messages[0]["content"] + tokenizer.eos_token for messages in examples["prompt"]
|
||||
]
|
||||
text_examples = [messages[0]["content"] + tokenizer.eos_token for messages in examples["prompt"]]
|
||||
|
||||
if not data_args.packing:
|
||||
if data_args.template == "gemma":
|
||||
text_examples = [tokenizer.bos_token + example for example in text_examples]
|
||||
|
||||
result = tokenizer(
|
||||
text_examples, add_special_tokens=False, max_length=data_args.cutoff_len
|
||||
)
|
||||
result = tokenizer(text_examples, add_special_tokens=False, max_length=data_args.cutoff_len)
|
||||
else:
|
||||
tokenized_examples = tokenizer(text_examples, add_special_tokens=False)
|
||||
concatenated_examples = {
|
||||
k: list(chain(*tokenized_examples[k])) for k in tokenized_examples.keys()
|
||||
}
|
||||
concatenated_examples = {k: list(chain(*tokenized_examples[k])) for k in tokenized_examples.keys()}
|
||||
total_length = len(concatenated_examples[list(concatenated_examples.keys())[0]])
|
||||
block_size = data_args.cutoff_len
|
||||
total_length = (total_length // block_size) * block_size
|
||||
@@ -87,9 +81,7 @@ def preprocess_supervised_dataset(
|
||||
if data_args.train_on_prompt:
|
||||
source_mask = source_ids
|
||||
elif turn_idx != 0 and template.efficient_eos:
|
||||
source_mask = [tokenizer.eos_token_id] + [IGNORE_INDEX] * (
|
||||
len(source_ids) - 1
|
||||
)
|
||||
source_mask = [tokenizer.eos_token_id] + [IGNORE_INDEX] * (len(source_ids) - 1)
|
||||
else:
|
||||
source_mask = [IGNORE_INDEX] * len(source_ids)
|
||||
|
||||
@@ -128,9 +120,7 @@ def preprocess_packed_supervised_dataset(
|
||||
if data_args.train_on_prompt:
|
||||
source_mask = source_ids
|
||||
elif len(input_ids) != 0 and template.efficient_eos:
|
||||
source_mask = [tokenizer.eos_token_id] + [IGNORE_INDEX] * (
|
||||
len(source_ids) - 1
|
||||
)
|
||||
source_mask = [tokenizer.eos_token_id] + [IGNORE_INDEX] * (len(source_ids) - 1)
|
||||
else:
|
||||
source_mask = [IGNORE_INDEX] * len(source_ids)
|
||||
|
||||
@@ -190,9 +180,7 @@ def preprocess_multimodal_supervised_dataset(
|
||||
if data_args.train_on_prompt:
|
||||
source_mask = source_ids
|
||||
elif turn_idx != 0 and template.efficient_eos:
|
||||
source_mask = [tokenizer.eos_token_id] + [IGNORE_INDEX] * (
|
||||
len(source_ids) - 1
|
||||
)
|
||||
source_mask = [tokenizer.eos_token_id] + [IGNORE_INDEX] * (len(source_ids) - 1)
|
||||
else:
|
||||
source_mask = [IGNORE_INDEX] * len(source_ids)
|
||||
|
||||
@@ -206,9 +194,7 @@ def preprocess_multimodal_supervised_dataset(
|
||||
model_inputs["input_ids"].append(input_ids)
|
||||
model_inputs["attention_mask"].append([1] * len(input_ids))
|
||||
model_inputs["labels"].append(labels)
|
||||
pixel_values = processor.image_processor(
|
||||
examples["images"][0], return_tensors="pt"
|
||||
)["pixel_values"][0]
|
||||
pixel_values = processor.image_processor(examples["images"][0], return_tensors="pt")["pixel_values"][0]
|
||||
model_inputs["pixel_values"].append(pixel_values)
|
||||
return model_inputs
|
||||
|
||||
@@ -229,9 +215,7 @@ def preprocess_unsupervised_dataset(
|
||||
if len(examples["response"][i]) == 1:
|
||||
messages = examples["prompt"][i] + examples["response"][i]
|
||||
else:
|
||||
messages = examples["prompt"][i] + [
|
||||
{"role": Role.ASSISTANT.value, "content": ""}
|
||||
]
|
||||
messages = examples["prompt"][i] + [{"role": Role.ASSISTANT.value, "content": ""}]
|
||||
|
||||
input_ids, labels = template.encode_oneturn(
|
||||
tokenizer,
|
||||
@@ -294,15 +278,9 @@ def preprocess_pairwise_dataset(
|
||||
return model_inputs
|
||||
|
||||
|
||||
def print_supervised_dataset_example(
|
||||
example: Dict[str, List[int]], tokenizer: "PreTrainedTokenizer"
|
||||
) -> None:
|
||||
def print_supervised_dataset_example(example: Dict[str, List[int]], tokenizer: "PreTrainedTokenizer") -> None:
|
||||
print("input_ids:\n{}".format(example["input_ids"]))
|
||||
print(
|
||||
"inputs:\n{}".format(
|
||||
tokenizer.decode(example["input_ids"], skip_special_tokens=False)
|
||||
)
|
||||
)
|
||||
print("inputs:\n{}".format(tokenizer.decode(example["input_ids"], skip_special_tokens=False)))
|
||||
print("label_ids:\n{}".format(example["labels"]))
|
||||
print(
|
||||
"labels:\n{}".format(
|
||||
@@ -314,38 +292,18 @@ def print_supervised_dataset_example(
|
||||
)
|
||||
|
||||
|
||||
def print_pairwise_dataset_example(
|
||||
example: Dict[str, List[int]], tokenizer: "PreTrainedTokenizer"
|
||||
) -> None:
|
||||
def print_pairwise_dataset_example(example: Dict[str, List[int]], tokenizer: "PreTrainedTokenizer") -> None:
|
||||
print("prompt_ids:\n{}".format(example["prompt_ids"]))
|
||||
print(
|
||||
"prompt:\n{}".format(
|
||||
tokenizer.decode(example["prompt_ids"], skip_special_tokens=False)
|
||||
)
|
||||
)
|
||||
print("prompt:\n{}".format(tokenizer.decode(example["prompt_ids"], skip_special_tokens=False)))
|
||||
print("chosen_ids:\n{}".format(example["chosen_ids"]))
|
||||
print(
|
||||
"chosen:\n{}".format(
|
||||
tokenizer.decode(example["chosen_ids"], skip_special_tokens=False)
|
||||
)
|
||||
)
|
||||
print("chosen:\n{}".format(tokenizer.decode(example["chosen_ids"], skip_special_tokens=False)))
|
||||
print("rejected_ids:\n{}".format(example["rejected_ids"]))
|
||||
print(
|
||||
"rejected:\n{}".format(
|
||||
tokenizer.decode(example["rejected_ids"], skip_special_tokens=False)
|
||||
)
|
||||
)
|
||||
print("rejected:\n{}".format(tokenizer.decode(example["rejected_ids"], skip_special_tokens=False)))
|
||||
|
||||
|
||||
def print_unsupervised_dataset_example(
|
||||
example: Dict[str, List[int]], tokenizer: "PreTrainedTokenizer"
|
||||
) -> None:
|
||||
def print_unsupervised_dataset_example(example: Dict[str, List[int]], tokenizer: "PreTrainedTokenizer") -> None:
|
||||
print("input_ids:\n{}".format(example["input_ids"]))
|
||||
print(
|
||||
"inputs:\n{}".format(
|
||||
tokenizer.decode(example["input_ids"], skip_special_tokens=False)
|
||||
)
|
||||
)
|
||||
print("inputs:\n{}".format(tokenizer.decode(example["input_ids"], skip_special_tokens=False)))
|
||||
|
||||
|
||||
def get_preprocess_and_print_func(
|
||||
@@ -357,12 +315,8 @@ def get_preprocess_and_print_func(
|
||||
processor: Optional["AutoProcessor"] = None,
|
||||
) -> Tuple[Callable, Callable]:
|
||||
if stage == "pt":
|
||||
preprocess_func = partial(
|
||||
preprocess_pretrain_dataset, tokenizer=tokenizer, data_args=data_args
|
||||
)
|
||||
print_function = partial(
|
||||
print_unsupervised_dataset_example, tokenizer=tokenizer
|
||||
)
|
||||
preprocess_func = partial(preprocess_pretrain_dataset, tokenizer=tokenizer, data_args=data_args)
|
||||
print_function = partial(print_unsupervised_dataset_example, tokenizer=tokenizer)
|
||||
elif stage == "sft" and not training_args.predict_with_generate:
|
||||
if data_args.packing:
|
||||
preprocess_func = partial(
|
||||
@@ -402,8 +356,6 @@ def get_preprocess_and_print_func(
|
||||
template=template,
|
||||
data_args=data_args,
|
||||
)
|
||||
print_function = partial(
|
||||
print_unsupervised_dataset_example, tokenizer=tokenizer
|
||||
)
|
||||
print_function = partial(print_unsupervised_dataset_example, tokenizer=tokenizer)
|
||||
|
||||
return preprocess_func, print_function
|
||||
|
||||
@@ -42,9 +42,7 @@ class Template:
|
||||
r"""
|
||||
Returns a single pair of token ids representing prompt and response respectively.
|
||||
"""
|
||||
encoded_pairs = self._encode(
|
||||
tokenizer, messages, system, tools, cutoff_len, reserved_label_len
|
||||
)
|
||||
encoded_pairs = self._encode(tokenizer, messages, system, tools, cutoff_len, reserved_label_len)
|
||||
prompt_ids = []
|
||||
for query_ids, resp_ids in encoded_pairs[:-1]:
|
||||
prompt_ids += query_ids + resp_ids
|
||||
@@ -64,9 +62,7 @@ class Template:
|
||||
r"""
|
||||
Returns multiple pairs of token ids representing prompts and responses respectively.
|
||||
"""
|
||||
return self._encode(
|
||||
tokenizer, messages, system, tools, cutoff_len, reserved_label_len
|
||||
)
|
||||
return self._encode(tokenizer, messages, system, tools, cutoff_len, reserved_label_len)
|
||||
|
||||
def _encode(
|
||||
self,
|
||||
@@ -93,9 +89,7 @@ class Template:
|
||||
elements += self.format_separator.apply()
|
||||
|
||||
if message["role"] == Role.USER.value:
|
||||
elements += self.format_user.apply(
|
||||
content=message["content"], idx=str(i // 2)
|
||||
)
|
||||
elements += self.format_user.apply(content=message["content"], idx=str(i // 2))
|
||||
elif message["role"] == Role.ASSISTANT.value:
|
||||
elements += self.format_assistant.apply(content=message["content"])
|
||||
elif message["role"] == Role.OBSERVATION.value:
|
||||
@@ -130,11 +124,7 @@ class Template:
|
||||
elif "eos_token" in elem and tokenizer.eos_token_id is not None:
|
||||
token_ids += [tokenizer.eos_token_id]
|
||||
else:
|
||||
raise ValueError(
|
||||
"Input must be string, set[str] or dict[str, str], got {}".format(
|
||||
type(elem)
|
||||
)
|
||||
)
|
||||
raise ValueError("Input must be string, set[str] or dict[str, str], got {}".format(type(elem)))
|
||||
|
||||
return token_ids
|
||||
|
||||
@@ -192,9 +182,7 @@ class Llama2Template(Template):
|
||||
elements += self.format_separator.apply()
|
||||
|
||||
if message["role"] == Role.USER.value:
|
||||
elements += self.format_user.apply(
|
||||
content=system_text + message["content"]
|
||||
)
|
||||
elements += self.format_user.apply(content=system_text + message["content"])
|
||||
elif message["role"] == Role.ASSISTANT.value:
|
||||
elements += self.format_assistant.apply(content=message["content"])
|
||||
elif message["role"] == Role.OBSERVATION.value:
|
||||
@@ -257,9 +245,7 @@ def _register_template(
|
||||
template_class = Llama2Template if name.startswith("llama2") else Template
|
||||
default_user_formatter = StringFormatter(slots=["{{content}}"])
|
||||
default_assistant_formatter = StringFormatter(slots=["{{content}}"] + eos_slots)
|
||||
default_function_formatter = FunctionFormatter(
|
||||
slots=["Action: {{name}}\nAction Input: {{arguments}}"] + eos_slots
|
||||
)
|
||||
default_function_formatter = FunctionFormatter(slots=["Action: {{name}}\nAction Input: {{arguments}}"] + eos_slots)
|
||||
default_tool_formatter = ToolFormatter(tool_format="default")
|
||||
default_separator_formatter = EmptyFormatter()
|
||||
templates[name] = template_class(
|
||||
@@ -295,9 +281,7 @@ def _jinja_escape(content: str) -> str:
|
||||
return content.replace("\n", r"\n").replace("'", r"\'")
|
||||
|
||||
|
||||
def _convert_slots_to_jinja(
|
||||
slots: "SLOTS", tokenizer: "PreTrainedTokenizer", placeholder: str = "content"
|
||||
) -> str:
|
||||
def _convert_slots_to_jinja(slots: "SLOTS", tokenizer: "PreTrainedTokenizer", placeholder: str = "content") -> str:
|
||||
slot_items = []
|
||||
for slot in slots:
|
||||
if isinstance(slot, str):
|
||||
@@ -311,9 +295,7 @@ def _convert_slots_to_jinja(
|
||||
elif isinstance(slot, set):
|
||||
if "bos_token" in slot:
|
||||
slot_items.append("'" + tokenizer.bos_token + "'")
|
||||
elif (
|
||||
"eos_token" in slot
|
||||
): # do not use {{ eos_token }} since it may be replaced
|
||||
elif "eos_token" in slot: # do not use {{ eos_token }} since it may be replaced
|
||||
slot_items.append("'" + tokenizer.eos_token + "'")
|
||||
elif isinstance(slot, dict):
|
||||
raise ValueError("Dict is not supported.")
|
||||
@@ -325,37 +307,25 @@ def _get_jinja_template(template: "Template", tokenizer: "PreTrainedTokenizer")
|
||||
jinja_template = ""
|
||||
|
||||
if template.default_system:
|
||||
jinja_template += (
|
||||
"{% set system_message = '"
|
||||
+ _jinja_escape(template.default_system)
|
||||
+ "' %}"
|
||||
)
|
||||
jinja_template += "{% set system_message = '" + _jinja_escape(template.default_system) + "' %}"
|
||||
|
||||
jinja_template += (
|
||||
"{% if messages[0]['role'] == 'system' %}"
|
||||
"{% set system_message = messages[0]['content'] %}"
|
||||
"{% endif %}"
|
||||
"{% if messages[0]['role'] == 'system' %}" "{% set system_message = messages[0]['content'] %}" "{% endif %}"
|
||||
)
|
||||
|
||||
system_message = _convert_slots_to_jinja(
|
||||
template.format_system.apply(), tokenizer, placeholder="system_message"
|
||||
)
|
||||
system_message = _convert_slots_to_jinja(template.format_system.apply(), tokenizer, placeholder="system_message")
|
||||
if isinstance(template, Llama2Template):
|
||||
pass
|
||||
elif template.force_system:
|
||||
jinja_template += "{{ " + system_message + " }}"
|
||||
else:
|
||||
jinja_template += (
|
||||
"{% if system_message is defined %}{{ " + system_message + " }}{% endif %}"
|
||||
)
|
||||
jinja_template += "{% if system_message is defined %}{{ " + system_message + " }}{% endif %}"
|
||||
|
||||
jinja_template += "{% for message in messages %}"
|
||||
jinja_template += "{% set content = message['content'] %}"
|
||||
if isinstance(template, Llama2Template):
|
||||
jinja_template += "{% if loop.index0 == 0 and system_message is defined %}"
|
||||
jinja_template += (
|
||||
"{% set content = " + system_message + " + message['content'] %}"
|
||||
)
|
||||
jinja_template += "{% set content = " + system_message + " + message['content'] %}"
|
||||
jinja_template += "{% endif %}"
|
||||
jinja_template += "{% if message['role'] == 'user' %}"
|
||||
user_message = _convert_slots_to_jinja(template.format_user.apply(), tokenizer)
|
||||
@@ -403,9 +373,7 @@ def get_template_and_fix_tokenizer(
|
||||
)
|
||||
logger.info("Add {} to stop words.".format(",".join(stop_words)))
|
||||
if num_added_tokens > 0:
|
||||
logger.warning(
|
||||
"New tokens have been added, make sure `resize_vocab` is True."
|
||||
)
|
||||
logger.warning("New tokens have been added, make sure `resize_vocab` is True.")
|
||||
|
||||
try:
|
||||
tokenizer.chat_template = _get_jinja_template(template, tokenizer)
|
||||
@@ -417,9 +385,7 @@ def get_template_and_fix_tokenizer(
|
||||
|
||||
_register_template(
|
||||
name="alpaca",
|
||||
format_user=StringFormatter(
|
||||
slots=["### Instruction:\n{{content}}\n\n### Response:\n"]
|
||||
),
|
||||
format_user=StringFormatter(slots=["### Instruction:\n{{content}}\n\n### Response:\n"]),
|
||||
format_separator=EmptyFormatter(slots=["\n\n"]),
|
||||
default_system=(
|
||||
"Below is an instruction that describes a task. "
|
||||
@@ -458,9 +424,7 @@ _register_template(
|
||||
|
||||
_register_template(
|
||||
name="baichuan",
|
||||
format_user=StringFormatter(
|
||||
slots=[{"token": "<reserved_102>"}, "{{content}}", {"token": "<reserved_103>"}]
|
||||
),
|
||||
format_user=StringFormatter(slots=[{"token": "<reserved_102>"}, "{{content}}", {"token": "<reserved_103>"}]),
|
||||
efficient_eos=True,
|
||||
)
|
||||
|
||||
@@ -483,9 +447,7 @@ _register_template(
|
||||
|
||||
_register_template(
|
||||
name="bluelm",
|
||||
format_user=StringFormatter(
|
||||
slots=[{"token": "[|Human|]:"}, "{{content}}", {"token": "[|AI|]:"}]
|
||||
),
|
||||
format_user=StringFormatter(slots=[{"token": "[|Human|]:"}, "{{content}}", {"token": "[|AI|]:"}]),
|
||||
)
|
||||
|
||||
|
||||
@@ -504,9 +466,7 @@ _register_template(
|
||||
_register_template(
|
||||
name="chatglm2",
|
||||
format_user=StringFormatter(slots=["[Round {{idx}}]\n\n问:{{content}}\n\n答:"]),
|
||||
format_system=StringFormatter(
|
||||
slots=[{"token": "[gMASK]"}, {"token": "sop"}, "{{content}}"]
|
||||
),
|
||||
format_system=StringFormatter(slots=[{"token": "[gMASK]"}, {"token": "sop"}, "{{content}}"]),
|
||||
format_separator=EmptyFormatter(slots=["\n\n"]),
|
||||
efficient_eos=True,
|
||||
force_system=True,
|
||||
@@ -515,13 +475,9 @@ _register_template(
|
||||
|
||||
_register_template(
|
||||
name="chatglm3",
|
||||
format_user=StringFormatter(
|
||||
slots=[{"token": "<|user|>"}, "\n", "{{content}}", {"token": "<|assistant|>"}]
|
||||
),
|
||||
format_user=StringFormatter(slots=[{"token": "<|user|>"}, "\n", "{{content}}", {"token": "<|assistant|>"}]),
|
||||
format_assistant=StringFormatter(slots=["\n", "{{content}}"]),
|
||||
format_system=StringFormatter(
|
||||
slots=[{"token": "[gMASK]"}, {"token": "sop"}, "{{content}}"]
|
||||
),
|
||||
format_system=StringFormatter(slots=[{"token": "[gMASK]"}, {"token": "sop"}, "{{content}}"]),
|
||||
format_function=FunctionFormatter(slots=["{{name}}\n{{arguments}}"]),
|
||||
format_observation=StringFormatter(
|
||||
slots=[
|
||||
@@ -539,9 +495,7 @@ _register_template(
|
||||
|
||||
_register_template(
|
||||
name="chatglm3_system",
|
||||
format_user=StringFormatter(
|
||||
slots=[{"token": "<|user|>"}, "\n", "{{content}}", {"token": "<|assistant|>"}]
|
||||
),
|
||||
format_user=StringFormatter(slots=[{"token": "<|user|>"}, "\n", "{{content}}", {"token": "<|assistant|>"}]),
|
||||
format_assistant=StringFormatter(slots=["\n", "{{content}}"]),
|
||||
format_system=StringFormatter(
|
||||
slots=[
|
||||
@@ -572,15 +526,9 @@ _register_template(
|
||||
|
||||
_register_template(
|
||||
name="chatml",
|
||||
format_user=StringFormatter(
|
||||
slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]
|
||||
),
|
||||
format_system=StringFormatter(
|
||||
slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]
|
||||
),
|
||||
format_observation=StringFormatter(
|
||||
slots=["<|im_start|>tool\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]
|
||||
),
|
||||
format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
|
||||
format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
|
||||
format_observation=StringFormatter(slots=["<|im_start|>tool\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
|
||||
format_separator=EmptyFormatter(slots=["\n"]),
|
||||
stop_words=["<|im_end|>", "<|im_start|>"],
|
||||
replace_eos=True,
|
||||
@@ -589,15 +537,9 @@ _register_template(
|
||||
|
||||
_register_template(
|
||||
name="chatml_de",
|
||||
format_user=StringFormatter(
|
||||
slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]
|
||||
),
|
||||
format_system=StringFormatter(
|
||||
slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]
|
||||
),
|
||||
format_observation=StringFormatter(
|
||||
slots=["<|im_start|>tool\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]
|
||||
),
|
||||
format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
|
||||
format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
|
||||
format_observation=StringFormatter(slots=["<|im_start|>tool\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
|
||||
format_separator=EmptyFormatter(slots=["\n"]),
|
||||
default_system="Du bist ein freundlicher und hilfsbereiter KI-Assistent.",
|
||||
stop_words=["<|im_end|>", "<|im_start|>"],
|
||||
@@ -607,9 +549,7 @@ _register_template(
|
||||
|
||||
_register_template(
|
||||
name="codegeex2",
|
||||
format_system=StringFormatter(
|
||||
slots=[{"token": "[gMASK]"}, {"token": "sop"}, "{{content}}"]
|
||||
),
|
||||
format_system=StringFormatter(slots=[{"token": "[gMASK]"}, {"token": "sop"}, "{{content}}"]),
|
||||
force_system=True,
|
||||
)
|
||||
|
||||
@@ -639,15 +579,9 @@ _register_template(
|
||||
|
||||
_register_template(
|
||||
name="dbrx",
|
||||
format_user=StringFormatter(
|
||||
slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]
|
||||
),
|
||||
format_system=StringFormatter(
|
||||
slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]
|
||||
),
|
||||
format_observation=StringFormatter(
|
||||
slots=["<|im_start|>tool\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]
|
||||
),
|
||||
format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
|
||||
format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
|
||||
format_observation=StringFormatter(slots=["<|im_start|>tool\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
|
||||
format_separator=EmptyFormatter(slots=["\n"]),
|
||||
default_system=(
|
||||
"You are DBRX, created by Databricks. You were last updated in December 2023. "
|
||||
@@ -725,9 +659,7 @@ _register_template(
|
||||
|
||||
_register_template(
|
||||
name="gemma",
|
||||
format_user=StringFormatter(
|
||||
slots=["<start_of_turn>user\n{{content}}<end_of_turn>\n<start_of_turn>model\n"]
|
||||
),
|
||||
format_user=StringFormatter(slots=["<start_of_turn>user\n{{content}}<end_of_turn>\n<start_of_turn>model\n"]),
|
||||
format_system=StringFormatter(slots=[{"bos_token"}, "{{content}}"]),
|
||||
format_observation=StringFormatter(
|
||||
slots=["<start_of_turn>tool\n{{content}}<end_of_turn>\n<start_of_turn>model\n"]
|
||||
@@ -740,9 +672,7 @@ _register_template(
|
||||
|
||||
_register_template(
|
||||
name="intern",
|
||||
format_user=StringFormatter(
|
||||
slots=["<|User|>:{{content}}", {"token": "<eoh>"}, "\n<|Bot|>:"]
|
||||
),
|
||||
format_user=StringFormatter(slots=["<|User|>:{{content}}", {"token": "<eoh>"}, "\n<|Bot|>:"]),
|
||||
format_separator=EmptyFormatter(slots=[{"token": "<eoa>"}, "\n"]),
|
||||
stop_words=["<eoa>"],
|
||||
efficient_eos=True,
|
||||
@@ -751,12 +681,8 @@ _register_template(
|
||||
|
||||
_register_template(
|
||||
name="intern2",
|
||||
format_user=StringFormatter(
|
||||
slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]
|
||||
),
|
||||
format_system=StringFormatter(
|
||||
slots=[{"bos_token"}, "<|im_start|>system\n{{content}}<|im_end|>\n"]
|
||||
),
|
||||
format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
|
||||
format_system=StringFormatter(slots=[{"bos_token"}, "<|im_start|>system\n{{content}}<|im_end|>\n"]),
|
||||
format_separator=EmptyFormatter(slots=["\n"]),
|
||||
default_system=(
|
||||
"You are an AI assistant whose name is InternLM (书生·浦语).\n"
|
||||
@@ -859,9 +785,7 @@ _register_template(
|
||||
|
||||
_register_template(
|
||||
name="orion",
|
||||
format_user=StringFormatter(
|
||||
slots=["Human: {{content}}\n\nAssistant: ", {"eos_token"}]
|
||||
),
|
||||
format_user=StringFormatter(slots=["Human: {{content}}\n\nAssistant: ", {"eos_token"}]),
|
||||
format_system=StringFormatter(slots=[{"bos_token"}, "{{content}}"]),
|
||||
force_system=True,
|
||||
)
|
||||
@@ -869,15 +793,9 @@ _register_template(
|
||||
|
||||
_register_template(
|
||||
name="phi",
|
||||
format_user=StringFormatter(
|
||||
slots=["<|user|>\n{{content}}<|end|>\n<|assistant|>\n"]
|
||||
),
|
||||
format_system=StringFormatter(
|
||||
slots=[{"bos_token"}, "<|system|>\n{{content}}<|end|>\n"]
|
||||
),
|
||||
format_observation=StringFormatter(
|
||||
slots=["<|function_output|>\n{{content}}<|end|>\n<|assistant|>\n"]
|
||||
),
|
||||
format_user=StringFormatter(slots=["<|user|>\n{{content}}<|end|>\n<|assistant|>\n"]),
|
||||
format_system=StringFormatter(slots=[{"bos_token"}, "<|system|>\n{{content}}<|end|>\n"]),
|
||||
format_observation=StringFormatter(slots=["<|function_output|>\n{{content}}<|end|>\n<|assistant|>\n"]),
|
||||
format_separator=EmptyFormatter(slots=["\n"]),
|
||||
default_system="You are a helpful AI assistant.",
|
||||
stop_words=["<|end|>"],
|
||||
@@ -887,15 +805,9 @@ _register_template(
|
||||
|
||||
_register_template(
|
||||
name="qwen",
|
||||
format_user=StringFormatter(
|
||||
slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]
|
||||
),
|
||||
format_system=StringFormatter(
|
||||
slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]
|
||||
),
|
||||
format_observation=StringFormatter(
|
||||
slots=["<|im_start|>tool\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]
|
||||
),
|
||||
format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
|
||||
format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
|
||||
format_observation=StringFormatter(slots=["<|im_start|>tool\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
|
||||
format_separator=EmptyFormatter(slots=["\n"]),
|
||||
default_system="You are a helpful assistant.",
|
||||
stop_words=["<|im_end|>"],
|
||||
@@ -951,12 +863,8 @@ _register_template(
|
||||
|
||||
_register_template(
|
||||
name="yayi",
|
||||
format_user=StringFormatter(
|
||||
slots=[{"token": "<|Human|>"}, ":\n{{content}}\n\n", {"token": "<|YaYi|>"}, ":"]
|
||||
),
|
||||
format_system=StringFormatter(
|
||||
slots=[{"token": "<|System|>"}, ":\n{{content}}\n\n"]
|
||||
),
|
||||
format_user=StringFormatter(slots=[{"token": "<|Human|>"}, ":\n{{content}}\n\n", {"token": "<|YaYi|>"}, ":"]),
|
||||
format_system=StringFormatter(slots=[{"token": "<|System|>"}, ":\n{{content}}\n\n"]),
|
||||
format_separator=EmptyFormatter(slots=["\n\n"]),
|
||||
default_system=(
|
||||
"You are a helpful, respectful and honest assistant named YaYi "
|
||||
@@ -975,9 +883,7 @@ _register_template(
|
||||
|
||||
_register_template(
|
||||
name="yi",
|
||||
format_user=StringFormatter(
|
||||
slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]
|
||||
),
|
||||
format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
|
||||
format_separator=EmptyFormatter(slots=["\n"]),
|
||||
stop_words=["<|im_end|>"],
|
||||
replace_eos=True,
|
||||
@@ -995,9 +901,7 @@ _register_template(
|
||||
|
||||
_register_template(
|
||||
name="zephyr",
|
||||
format_user=StringFormatter(
|
||||
slots=["<|user|>\n{{content}}", {"eos_token"}, "<|assistant|>"]
|
||||
),
|
||||
format_user=StringFormatter(slots=["<|user|>\n{{content}}", {"eos_token"}, "<|assistant|>"]),
|
||||
format_assistant=StringFormatter(slots=["\n{{content}}", {"eos_token"}]),
|
||||
format_system=StringFormatter(slots=["<|system|>\n{{content}}", {"eos_token"}]),
|
||||
default_system="You are a friendly chatbot who always responds in the style of a pirate",
|
||||
|
||||
Reference in New Issue
Block a user