support function calling

Former-commit-id: 66533b3f65babf2429c92c0f8fafe4eff5e0ff63
Author: hiyouga
Date: 2024-01-18 09:54:23 +08:00
Parent: f7329b1a0e
Commit: a423274fd9
67 changed files with 1239 additions and 1079 deletions

src/llmtuner/data/__init__.py

@@ -1,4 +1,6 @@
from llmtuner.data.loader import get_dataset
from llmtuner.data.preprocess import preprocess_dataset
from llmtuner.data.template import get_template_and_fix_tokenizer
from llmtuner.data.utils import split_dataset
from .loader import get_dataset
from .template import get_template_and_fix_tokenizer, templates
from .utils import split_dataset
__all__ = ["get_dataset", "get_template_and_fix_tokenizer", "templates", "split_dataset"]
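With these package-level re-exports in place, downstream code can pull the public API straight from the package root; a minimal sketch:

# everything listed in __all__ is now importable from the package itself
from llmtuner.data import get_dataset, get_template_and_fix_tokenizer, templates, split_dataset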

src/llmtuner/data/aligner.py Normal file

@@ -0,0 +1,106 @@
from functools import partial
from typing import TYPE_CHECKING, Any, Dict, List, Union
from .utils import Role
if TYPE_CHECKING:
from datasets import Dataset, IterableDataset
from ..hparams import DataArguments
from .parser import DatasetAttr
def convert_alpaca(examples: Dict[str, List[Any]], dataset_attr: "DatasetAttr") -> Dict[str, List[Any]]:
outputs = {"prompt": [], "response": [], "system": [], "tool": []}
for i in range(len(examples[dataset_attr.prompt])):
prompt = []
if dataset_attr.history:
for old_prompt, old_response in examples[dataset_attr.history][i]:
prompt.append({"role": Role.USER, "content": old_prompt})
prompt.append({"role": Role.ASSISTANT, "content": old_response})
instruction = examples[dataset_attr.prompt][i]
if dataset_attr.query and examples[dataset_attr.query][i]:
instruction += "\n" + examples[dataset_attr.query][i]
prompt.append({"role": Role.USER, "content": instruction})
if isinstance(examples[dataset_attr.response][i], list):
response = [{"role": Role.ASSISTANT, "content": content} for content in examples[dataset_attr.response][i]]
else:
response = [{"role": Role.ASSISTANT, "content": examples[dataset_attr.response][i]}]
outputs["prompt"].append(prompt)
outputs["response"].append(response)
outputs["system"].append(examples[dataset_attr.system][i] if dataset_attr.system else "")
outputs["tool"].append("")
return outputs
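A minimal sketch of the alpaca conversion, using the default column names and a SimpleNamespace standing in for the DatasetAttr object defined in parser.py below:

from types import SimpleNamespace

attr = SimpleNamespace(prompt="instruction", query="input", response="output",
                       history=None, system=None)
examples = {
    "instruction": ["Translate to French"],
    "input": ["hello world"],
    "output": ["bonjour le monde"],
}
aligned = convert_alpaca(examples, attr)
# aligned["prompt"]   == [[{"role": Role.USER, "content": "Translate to French\nhello world"}]]
# aligned["response"] == [[{"role": Role.ASSISTANT, "content": "bonjour le monde"}]]
# aligned["system"]   == [""], aligned["tool"] == [""]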
def convert_sharegpt(examples: Dict[str, List[Any]], dataset_attr: "DatasetAttr") -> Dict[str, List[Any]]:
outputs = {"prompt": [], "response": [], "system": [], "tool": []}
tag_mapping = {
dataset_attr.user_tag: Role.USER,
dataset_attr.assistant_tag: Role.ASSISTANT,
dataset_attr.observation_tag: Role.OBSERVATION,
dataset_attr.function_tag: Role.FUNCTION
}
for i, messages in enumerate(examples[dataset_attr.messages]):
messages = messages[:len(messages) // 2 * 2] # drop a trailing unpaired message; turns must come in user/assistant pairs
if len(messages) == 0:
continue
prompt = []
response = []
for turn_idx, message in enumerate(messages):
if turn_idx % 2 == 0:
accept_tags = [dataset_attr.user_tag, dataset_attr.observation_tag]
else:
accept_tags = [dataset_attr.assistant_tag, dataset_attr.function_tag]
if message[dataset_attr.role_tag] not in accept_tags:
raise ValueError("Invalid role tag.")
prompt.append({"role": tag_mapping[message[dataset_attr.role_tag]], "content": message[dataset_attr.content_tag]})
last_message = prompt.pop(-1)
response.append(last_message)
outputs["prompt"].append(prompt)
outputs["response"].append(response)
outputs["system"].append(examples[dataset_attr.system][i] if dataset_attr.system else "")
outputs["tool"].append(examples[dataset_attr.tool][i] if dataset_attr.tool else "")
return outputs
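The sharegpt path is where function calling enters the data pipeline: observation turns are accepted in the user slot and function_call turns in the assistant slot. A sketch with the default tags (the tool-call payload is made up for illustration):

from types import SimpleNamespace

attr = SimpleNamespace(messages="conversations", role_tag="from", content_tag="value",
                       user_tag="human", assistant_tag="gpt",
                       observation_tag="observation", function_tag="function_call",
                       system=None, tool=None)
examples = {"conversations": [[
    {"from": "human", "value": "What is 2 + 2?"},
    {"from": "function_call", "value": '{"name": "calc", "arguments": {"expr": "2+2"}}'},
    {"from": "observation", "value": "4"},
    {"from": "gpt", "value": "2 + 2 equals 4."},
]]}
aligned = convert_sharegpt(examples, attr)
# The first three turns land in "prompt"; the final assistant turn becomes "response".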
def align_dataset(
dataset: Union["Dataset", "IterableDataset"], dataset_attr: "DatasetAttr", data_args: "DataArguments"
) -> Union["Dataset", "IterableDataset"]:
r"""
Aligned dataset:
prompt: [{"role": "user", "content": "..."}]
response: [{"role": "assistant", "content": "..."}]
system: "..."
tool: "..."
"""
if dataset_attr.formatting == "alpaca":
convert_func = partial(convert_alpaca, dataset_attr=dataset_attr)
else:
convert_func = partial(convert_sharegpt, dataset_attr=dataset_attr)
column_names = list(next(iter(dataset)).keys())
kwargs = {}
if not data_args.streaming:
kwargs = dict(
num_proc=data_args.preprocessing_num_workers,
load_from_cache_file=(not data_args.overwrite_cache),
desc="Converting format of dataset"
)
return dataset.map(
convert_func,
batched=True,
remove_columns=column_names,
**kwargs
)

src/llmtuner/data/formatter.py Normal file

@@ -0,0 +1,99 @@
import json
from dataclasses import dataclass
from typing import Any, Dict, List, Literal, Union
JSON_FORMAT_PROMPT = (
""", in a JSON format representing the kwargs (e.g. ```{"input": "hello world", "num_beams": 5}```)"""
)
TOOL_SYSTEM_PROMPT = (
"You have access to the following tools:\n{tool_text}"
"Use the following format to answer the question:\n"
"```\n"
"Action: the action to take, should be one of [{tool_names}] if using a tool.\n"
"Action Input: the input to the action{format_prompt}.\n"
"```"
)
@dataclass
class StringFormatter:
container: List[Union[str, Dict[str, str]]]
def __call__(self, **kwargs) -> List[Union[str, Dict[str, str]]]:
elements = []
for elem in self.container:
if isinstance(elem, str):
for name, value in kwargs.items():
elem = elem.replace("{{" + name + "}}", value)
elements.append(elem)
elif isinstance(elem, (dict, set)):
elements.append(elem)
else:
raise ValueError("Input must be string, set[str] or dict[str, str], got {}".format(type(elem)))
return elements
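A quick check of the slot substitution, assuming a container like the one the "default" template registers further down:

fmt = StringFormatter(container=["Human: {{content}}\nAssistant: "])
assert fmt(content="Hello") == ["Human: Hello\nAssistant: "]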
@dataclass
class FunctionFormatter:
container: List[Union[str, Dict[str, str]]]
def __call__(self, content: str) -> List[Union[str, Dict[str, str]]]:
try:
function = json.loads(content)
name = function["name"] # use the tool name verbatim; only the arguments are JSON-encoded
arguments = json.dumps(function["arguments"], ensure_ascii=False)
except json.JSONDecodeError:
name, arguments = "", ""
elements = []
for elem in self.container:
if isinstance(elem, str):
elem = elem.replace("{{name}}", name)
elem = elem.replace("{{arguments}}", arguments)
elements.append(elem)
elif isinstance(elem, (dict, set)):
elements.append(elem)
else:
raise ValueError("Input must be string, set[str] or dict[str, str], got {}".format(type(elem)))
return elements
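With the container that register_template installs by default (see template.py below), a function-call record renders like this; the weather tool is made up for illustration:

fmt = FunctionFormatter(container=["Action: {{name}}\nAction Input: {{arguments}}", {"eos_token"}])
out = fmt(content='{"name": "get_weather", "arguments": {"city": "Paris"}}')
# out[0] == 'Action: get_weather\nAction Input: {"city": "Paris"}'
# out[1] == {"eos_token"}  -- resolved to the real EOS token id at tokenization time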
@dataclass
class ToolFormatter:
type: Literal["default"]
def _default(self, tools: List[Dict[str, Any]]) -> str:
tool_text = ""
tool_names = []
for tool in tools:
param_text = ""
for name, param in tool["parameters"]["properties"].items():
required = ", required" if name in tool["parameters"].get("required", []) else ""
enum = ", should be one of [{}]".format(", ".join(param["enum"])) if param.get("enum", None) else ""
param_text += " - {name} ({type}{required}): {desc}{enum}\n".format(
name=name, type=param.get("type", ""), required=required, desc=param.get("description", ""), enum=enum
)
tool_text += "> Tool Name: {name}\nTool Description: {desc}\nTool Args:\n{args}\n".format(
name=tool["name"], desc=tool.get("description", ""), args=param_text
)
tool_names.append(tool["name"])
return TOOL_SYSTEM_PROMPT.format(
tool_text=tool_text,
tool_names=", ".join(tool_names),
format_prompt=JSON_FORMAT_PROMPT
)
def __call__(self, content: str) -> List[Union[str, Dict[str, str]]]:
try:
tools = json.loads(content)
except json.JSONDecodeError:
return [""]
if self.type == "default":
return [self._default(tools)]
raise NotImplementedError("Unsupported tool format: {}".format(self.type))
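End to end, the formatter turns a JSON tool list (again, a made-up example) into the TOOL_SYSTEM_PROMPT above:

import json

tools = [{
    "name": "get_weather",
    "description": "Query the current weather",
    "parameters": {
        "type": "object",
        "properties": {"city": {"type": "string", "description": "city name"}},
        "required": ["city"],
    },
}]
system_prompt = ToolFormatter(type="default")(content=json.dumps(tools))[0]
# Renders roughly as:
# "You have access to the following tools:
#  > Tool Name: get_weather
#  Tool Description: Query the current weather
#  Tool Args:
#   - city (string, required): city name
#  Use the following format to answer the question: ..."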

src/llmtuner/data/loader.py

@@ -1,160 +1,114 @@
import os
from typing import TYPE_CHECKING, Any, Dict, List, Union
from typing import TYPE_CHECKING, List, Literal, Union
from datasets import concatenate_datasets, interleave_datasets, load_dataset, load_from_disk
from llmtuner.data.utils import checksum
from llmtuner.extras.constants import FILEEXT2TYPE
from llmtuner.extras.logging import get_logger
from ..extras.constants import FILEEXT2TYPE
from ..extras.logging import get_logger
from .utils import checksum
from .parser import get_dataset_list
from .aligner import align_dataset
from .template import get_template_and_fix_tokenizer
from .preprocess import get_preprocess_and_print_func
if TYPE_CHECKING:
from datasets import Dataset, IterableDataset
from llmtuner.hparams import ModelArguments, DataArguments
from transformers import Seq2SeqTrainingArguments
from transformers.tokenization_utils import PreTrainedTokenizer
from .parser import DatasetAttr
from ..hparams import ModelArguments, DataArguments
logger = get_logger(__name__)
def get_dataset(
def load_single_dataset(
dataset_attr: "DatasetAttr",
model_args: "ModelArguments",
data_args: "DataArguments"
) -> Union["Dataset", "IterableDataset"]:
max_samples = data_args.max_samples
all_datasets: List[Union["Dataset", "IterableDataset"]] = [] # support multiple datasets
data_args: "DataArguments",
):
data_path, data_name, data_dir, data_files = None, None, None, None
if dataset_attr.load_from in ["hf_hub", "ms_hub"]:
data_path = dataset_attr.dataset_name
data_name = dataset_attr.subset
data_dir = dataset_attr.folder
if data_args.cache_path is not None:
if os.path.exists(data_args.cache_path):
logger.warning("Loading dataset from disk will ignore other data arguments.")
dataset = load_from_disk(data_args.cache_path)
if data_args.streaming:
dataset = dataset.to_iterable_dataset()
return dataset
elif data_args.streaming:
raise ValueError("Turn off dataset streaming to save cache files.")
elif dataset_attr.load_from == "script":
data_path = os.path.join(data_args.dataset_dir, dataset_attr.dataset_name)
data_name = dataset_attr.subset
data_dir = dataset_attr.folder
for dataset_attr in data_args.dataset_list:
logger.info("Loading dataset {}...".format(dataset_attr))
data_path, data_name, data_dir, data_files = None, None, None, None
if dataset_attr.load_from in ["hf_hub", "ms_hub"]:
data_path = dataset_attr.dataset_name
data_name = dataset_attr.subset
data_dir = dataset_attr.folder
elif dataset_attr.load_from == "script":
data_path = os.path.join(data_args.dataset_dir, dataset_attr.dataset_name)
data_name = dataset_attr.subset
elif dataset_attr.load_from == "file":
data_files = []
local_path: str = os.path.join(data_args.dataset_dir, dataset_attr.dataset_name)
if os.path.isdir(local_path): # is directory
for file_name in os.listdir(local_path):
data_files.append(os.path.join(local_path, file_name))
if data_path is None:
data_path = FILEEXT2TYPE.get(file_name.split(".")[-1], None)
else:
assert data_path == FILEEXT2TYPE.get(file_name.split(".")[-1], None), "file types are not identical."
elif os.path.isfile(local_path): # is file
data_files.append(local_path)
data_path = FILEEXT2TYPE.get(local_path.split(".")[-1], None)
else:
raise ValueError("File not found.")
assert data_path, "File extension must be txt, csv, json or jsonl."
checksum(data_files, dataset_attr.dataset_sha1)
elif dataset_attr.load_from == "file":
data_files = []
local_path: str = os.path.join(data_args.dataset_dir, dataset_attr.dataset_name)
if os.path.isdir(local_path): # is directory
for file_name in os.listdir(local_path):
data_files.append(os.path.join(local_path, file_name))
if data_path is None:
data_path = FILEEXT2TYPE.get(file_name.split(".")[-1], None)
elif data_path != FILEEXT2TYPE.get(file_name.split(".")[-1], None):
raise ValueError("File types should be identical.")
elif os.path.isfile(local_path): # is file
data_files.append(local_path)
data_path = FILEEXT2TYPE.get(local_path.split(".")[-1], None)
else:
raise NotImplementedError
raise ValueError("File not found.")
if dataset_attr.load_from == "ms_hub":
try:
from modelscope import MsDataset
from modelscope.utils.config_ds import MS_DATASETS_CACHE
if data_path is None:
raise ValueError("File extension must be txt, csv, json or jsonl.")
cache_dir = model_args.cache_dir or MS_DATASETS_CACHE
dataset = MsDataset.load(
dataset_name=data_path,
subset_name=data_name,
data_dir=data_dir,
data_files=data_files,
split=data_args.split,
cache_dir=cache_dir,
token=model_args.ms_hub_token,
use_streaming=(data_args.streaming and (dataset_attr.load_from != "file"))
).to_hf_dataset()
except ImportError:
raise ImportError("Please install modelscope via `pip install modelscope -U`")
else:
dataset = load_dataset(
path=data_path,
name=data_name,
checksum(data_files, dataset_attr.dataset_sha1)
else:
raise NotImplementedError
if dataset_attr.load_from == "ms_hub":
try:
from modelscope import MsDataset
from modelscope.utils.config_ds import MS_DATASETS_CACHE
cache_dir = model_args.cache_dir or MS_DATASETS_CACHE
dataset = MsDataset.load(
dataset_name=data_path,
subset_name=data_name,
data_dir=data_dir,
data_files=data_files,
split=data_args.split,
cache_dir=model_args.cache_dir,
token=model_args.hf_hub_token,
streaming=(data_args.streaming and (dataset_attr.load_from != "file"))
)
cache_dir=cache_dir,
token=model_args.ms_hub_token,
use_streaming=(data_args.streaming and (dataset_attr.load_from != "file"))
).to_hf_dataset()
except ImportError:
raise ImportError("Please install modelscope via `pip install modelscope -U`")
else:
dataset = load_dataset(
path=data_path,
name=data_name,
data_dir=data_dir,
data_files=data_files,
split=data_args.split,
cache_dir=model_args.cache_dir,
token=model_args.hf_hub_token,
streaming=(data_args.streaming and (dataset_attr.load_from != "file"))
)
if data_args.streaming and (dataset_attr.load_from == "file"): # faster than specifying streaming=True
dataset = dataset.to_iterable_dataset() # TODO: add num shards parameter
if data_args.streaming and (dataset_attr.load_from == "file"): # faster than specifying streaming=True
dataset = dataset.to_iterable_dataset() # TODO: add num shards parameter
if max_samples is not None: # truncate dataset
dataset = dataset.select(range(min(len(dataset), max_samples)))
if data_args.max_samples is not None: # truncate dataset
num_samples = min(data_args.max_samples, len(dataset))
dataset = dataset.select(range(num_samples))
def convert_format(examples: Dict[str, List[Any]]) -> Dict[str, List[Any]]:
# convert dataset from sharegpt format to alpaca format
outputs = {"prompt": [], "query": [], "response": [], "history": [], "system": []}
for i, msg_list in enumerate(examples[dataset_attr.messages]):
msg_list = msg_list[:len(msg_list) // 2 * 2] # should be multiples of 2
if len(msg_list) == 0:
continue
return align_dataset(dataset, dataset_attr, data_args)
msg_pairs = []
user_role, assistant_role = None, None
for idx in range(0, len(msg_list), 2):
if user_role is None and assistant_role is None:
user_role = msg_list[idx][dataset_attr.role]
assistant_role = msg_list[idx + 1][dataset_attr.role]
else:
if (
msg_list[idx][dataset_attr.role] != user_role
or msg_list[idx+1][dataset_attr.role] != assistant_role
):
raise ValueError("Only accepts conversation in u/a/u/a/u/a order.")
msg_pairs.append((msg_list[idx][dataset_attr.content], msg_list[idx + 1][dataset_attr.content]))
if len(msg_pairs) != 0:
outputs["prompt"].append(msg_pairs[-1][0])
outputs["query"].append("")
outputs["response"].append(msg_pairs[-1][1])
outputs["history"].append(msg_pairs[:-1] if len(msg_pairs) > 1 else None)
outputs["system"].append(examples[dataset_attr.system][i] if dataset_attr.system else "")
return outputs
if dataset_attr.formatting == "sharegpt": # convert format
column_names = list(next(iter(dataset)).keys())
kwargs = {}
if not data_args.streaming:
kwargs = dict(
num_proc=data_args.preprocessing_num_workers,
load_from_cache_file=(not data_args.overwrite_cache),
desc="Converting format of dataset"
)
dataset = dataset.map(
convert_format,
batched=True,
remove_columns=column_names,
**kwargs
)
else:
for column_name in ["prompt", "query", "response", "history", "system"]: # align dataset
if getattr(dataset_attr, column_name) and getattr(dataset_attr, column_name) != column_name:
dataset = dataset.rename_column(getattr(dataset_attr, column_name), column_name)
all_datasets.append(dataset)
if len(data_args.dataset_list) == 1:
def merge_dataset(
all_datasets: List[Union["Dataset", "IterableDataset"]],
data_args: "DataArguments",
training_args: "Seq2SeqTrainingArguments"
) -> Union["Dataset", "IterableDataset"]:
if len(all_datasets) == 1:
return all_datasets[0]
elif data_args.mix_strategy == "concat":
if data_args.streaming:
@@ -166,8 +120,72 @@ def get_dataset(
return interleave_datasets(
datasets=all_datasets,
probabilities=data_args.interleave_probs,
seed=data_args.seed,
seed=training_args.seed,
stopping_strategy="first_exhausted" if data_args.mix_strategy.endswith("under") else "all_exhausted"
)
else:
raise ValueError("Unknown mixing strategy.")
def get_dataset(
model_args: "ModelArguments",
data_args: "DataArguments",
tokenizer: "PreTrainedTokenizer",
training_args: "Seq2SeqTrainingArguments",
stage: Literal["pt", "sft", "rm", "ppo"],
# split: Optional[str] = "train", # TODO: add split
) -> Union["Dataset", "IterableDataset"]:
template = get_template_and_fix_tokenizer(data_args.template, tokenizer)
if data_args.train_on_prompt and template.efficient_eos:
raise ValueError("Current template does not support `train_on_prompt`.")
# Load from cache
if data_args.cache_path is not None:
if os.path.exists(data_args.cache_path):
logger.warning("Loading dataset from disk will ignore other data arguments.")
dataset = load_from_disk(data_args.cache_path)
if data_args.streaming:
dataset = dataset.to_iterable_dataset()
return dataset
if data_args.streaming:
raise ValueError("Turn off dataset streaming to save cache files.")
with training_args.main_process_first(desc="load dataset"):
all_datasets = []
for dataset_attr in get_dataset_list(data_args): # TODO: add split
all_datasets.append(load_single_dataset(dataset_attr, model_args, data_args))
dataset = merge_dataset(all_datasets, data_args, training_args)
with training_args.main_process_first(desc="pre-process dataset"):
preprocess_func, print_function = get_preprocess_and_print_func(
tokenizer, template, data_args, training_args, stage
)
column_names = list(next(iter(dataset)).keys())
kwargs = {}
if not data_args.streaming:
kwargs = dict(
num_proc=data_args.preprocessing_num_workers,
load_from_cache_file=(not data_args.overwrite_cache),
desc="Running tokenizer on dataset"
)
dataset = dataset.map(
preprocess_func,
batched=True,
remove_columns=column_names,
**kwargs
)
if data_args.cache_path is not None and not os.path.exists(data_args.cache_path):
if training_args.should_save:
dataset.save_to_disk(data_args.cache_path)
logger.info("Dataset cache saved at {}.".format(data_args.cache_path))
if training_args.should_log:
try:
print_function(next(iter(dataset)))
except StopIteration:
raise RuntimeError("Empty dataset!")
return dataset
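Hypothetical wiring of the refactored entry point; the argument objects would come from llmtuner's hparams parsers, and the tokenizer from transformers:

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained(model_args.model_name_or_path)
dataset = get_dataset(model_args, data_args, tokenizer, training_args, stage="sft")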

src/llmtuner/data/parser.py Normal file

@@ -0,0 +1,101 @@
import os
import json
from typing import TYPE_CHECKING, List, Literal, Optional
from dataclasses import dataclass
from ..extras.constants import DATA_CONFIG
from ..extras.misc import use_modelscope
if TYPE_CHECKING:
from ..hparams import DataArguments
@dataclass
class DatasetAttr:
load_from: Literal["hf_hub", "ms_hub", "script", "file"]
dataset_name: Optional[str] = None
dataset_sha1: Optional[str] = None
subset: Optional[str] = None
folder: Optional[str] = None
ranking: Optional[bool] = False
formatting: Optional[Literal["alpaca", "sharegpt"]] = "alpaca"
system: Optional[str] = None
prompt: Optional[str] = "instruction"
query: Optional[str] = "input"
response: Optional[str] = "output"
history: Optional[str] = None
messages: Optional[str] = "conversations"
tool: Optional[str] = None
role_tag: Optional[str] = "from"
content_tag: Optional[str] = "value"
user_tag: Optional[str] = "human"
assistant_tag: Optional[str] = "gpt"
observation_tag: Optional[str] = "observation"
function_tag: Optional[str] = "function_call"
def __repr__(self) -> str:
return self.dataset_name
def get_dataset_list(data_args: "DataArguments") -> List["DatasetAttr"]:
dataset_names = [ds.strip() for ds in data_args.dataset.split(",")] if data_args.dataset is not None else []
try:
with open(os.path.join(data_args.dataset_dir, DATA_CONFIG), "r") as f:
dataset_info = json.load(f)
except Exception as err:
if data_args.dataset is not None:
raise ValueError("Cannot open {} due to {}.".format(os.path.join(data_args.dataset_dir, DATA_CONFIG), str(err)))
dataset_info = None
if data_args.interleave_probs is not None:
data_args.interleave_probs = [float(prob.strip()) for prob in data_args.interleave_probs.split(",")]
dataset_list: List[DatasetAttr] = []
for name in dataset_names:
if name not in dataset_info:
raise ValueError("Undefined dataset {} in {}.".format(name, DATA_CONFIG))
has_hf_url = "hf_hub_url" in dataset_info[name]
has_ms_url = "ms_hub_url" in dataset_info[name]
if has_hf_url or has_ms_url:
if (use_modelscope() and has_ms_url) or (not has_hf_url):
dataset_attr = DatasetAttr("ms_hub", dataset_name=dataset_info[name]["ms_hub_url"])
else:
dataset_attr = DatasetAttr("hf_hub", dataset_name=dataset_info[name]["hf_hub_url"])
elif "script_url" in dataset_info[name]:
dataset_attr = DatasetAttr("script", dataset_name=dataset_info[name]["script_url"])
else:
dataset_attr = DatasetAttr(
"file",
dataset_name=dataset_info[name]["file_name"],
dataset_sha1=dataset_info[name].get("file_sha1", None)
)
dataset_attr.subset = dataset_info[name].get("subset", None)
dataset_attr.folder = dataset_info[name].get("folder", None)
dataset_attr.ranking = dataset_info[name].get("ranking", False)
dataset_attr.formatting = dataset_info[name].get("formatting", "alpaca")
if "columns" in dataset_info[name]:
if dataset_attr.formatting == "alpaca":
column_names = ["prompt", "query", "response", "history"]
else:
column_names = ["messages", "tool"]
column_names += ["system"]
for column_name in column_names:
setattr(dataset_attr, column_name, dataset_info[name]["columns"].get(column_name, None))
if dataset_attr.formatting == "sharegpt" and "tags" in dataset_info[name]:
for tag in ["role_tag", "content_tag", "user_tag", "assistant_tag", "observation_tag", "function_tag"]:
setattr(dataset_attr, tag, dataset_info[name]["tags"].get(tag, None))
dataset_list.append(dataset_attr)
return dataset_list
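For reference, a hypothetical dataset_info.json entry (shown as a Python literal) exercising the new sharegpt/function-calling fields parsed above; the dataset name and file are illustrative:

dataset_info = {
    "my_toolcall_data": {  # hypothetical dataset name
        "file_name": "my_toolcall_data.json",
        "formatting": "sharegpt",
        "columns": {"messages": "conversations", "tool": "tools", "system": "system"},
        "tags": {
            "role_tag": "from",
            "content_tag": "value",
            "user_tag": "human",
            "assistant_tag": "gpt",
            "observation_tag": "observation",
            "function_tag": "function_call",
        },
    },
}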

src/llmtuner/data/preprocess.py

@@ -1,272 +1,241 @@
import os
import tiktoken
from functools import partial
from itertools import chain
from typing import TYPE_CHECKING, Any, Dict, Generator, List, Literal, Tuple, Union
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Literal, Tuple
from llmtuner.data.template import get_template_and_fix_tokenizer
from llmtuner.extras.constants import IGNORE_INDEX
from llmtuner.extras.logging import get_logger
from ..extras.constants import IGNORE_INDEX
from ..extras.logging import get_logger
if TYPE_CHECKING:
from datasets import Dataset, IterableDataset
from transformers import Seq2SeqTrainingArguments
from transformers.tokenization_utils import PreTrainedTokenizer
from llmtuner.hparams import DataArguments
from ..hparams import DataArguments
from .template import Template
logger = get_logger(__name__)
def construct_example(examples: Dict[str, List[Any]]) -> Generator[Any, None, None]:
for i in range(len(examples["prompt"])):
query, response = examples["prompt"][i], examples["response"][i]
query = query + "\n" + examples["query"][i] if "query" in examples and examples["query"][i] else query
history = examples["history"][i] if "history" in examples else None
system = examples["system"][i] if "system" in examples else None
yield query, response, history, system
def infer_max_len(source_len: int, target_len: int, data_args: "DataArguments") -> Tuple[int, int]:
max_target_len = int(data_args.cutoff_len * (target_len / (source_len + target_len)))
max_target_len = max(max_target_len, data_args.reserved_label_len)
max_source_len = data_args.cutoff_len - max_target_len
return max_source_len, max_target_len
def preprocess_dataset(
dataset: Union["Dataset", "IterableDataset"],
def preprocess_pretrain_dataset(
examples: Dict[str, List[Any]],
tokenizer: "PreTrainedTokenizer",
data_args: "DataArguments"
) -> Dict[str, List[List[int]]]:
# build grouped texts with format `X1 X2 X3 ...`
text_examples = [examples["prompt"][i][0]["content"] for i in range(len(examples["prompt"]))]
tokenized_examples = tokenizer(text_examples, add_special_tokens=False)
for i in range(len(tokenized_examples["input_ids"])):
tokenized_examples["input_ids"][i] += [tokenizer.eos_token_id]
tokenized_examples["attention_mask"][i] += [1]
concatenated_examples = {k: list(chain(*tokenized_examples[k])) for k in tokenized_examples.keys()}
total_length = len(concatenated_examples[list(concatenated_examples.keys())[0]])
block_size = data_args.cutoff_len
# we drop the small remainder, and if the total_length < block_size, we exclude this batch
total_length = (total_length // block_size) * block_size
# split by chunks of cutoff_len
result = {
k: [t[i: i + block_size] for i in range(0, total_length, block_size)]
for k, t in concatenated_examples.items()
}
return result
def preprocess_supervised_dataset(
examples: Dict[str, List[Any]],
tokenizer: "PreTrainedTokenizer",
template: "Template",
data_args: "DataArguments",
training_args: "Seq2SeqTrainingArguments",
stage: Literal["pt", "sft", "rm", "ppo"]
) -> Union["Dataset", "IterableDataset"]:
template = get_template_and_fix_tokenizer(data_args.template, tokenizer)
) -> Dict[str, List[List[int]]]:
# build inputs with format `<bos> X Y <eos>` and labels with format `<ignore> ... <ignore> Y <eos>`
# for multiturn examples, we only mask the prompt part in each prompt-response pair.
model_inputs = {"input_ids": [], "attention_mask": [], "labels": []}
if data_args.cache_path is not None and os.path.exists(data_args.cache_path):
return dataset # already preprocessed
for i in range(len(examples["prompt"])):
if len(examples["prompt"][i]) == 0 or len(examples["response"][i]) != 1:
continue
if data_args.train_on_prompt and template.efficient_eos:
raise ValueError("Current template does not support `train_on_prompt`.")
def preprocess_pretrain_dataset(examples: Dict[str, List[Any]]) -> Dict[str, List[List[int]]]:
# build grouped texts with format `X1 X2 X3 ...`
if isinstance(getattr(tokenizer, "tokenizer", None), tiktoken.Encoding): # for tiktoken tokenizer (Qwen)
kwargs = dict(allowed_special="all")
else:
kwargs = dict(add_special_tokens=True)
if hasattr(tokenizer, "add_eos_token"): # for LLaMA tokenizer
add_eos_token_flag = getattr(tokenizer, "add_eos_token")
setattr(tokenizer, "add_eos_token", True)
tokenized_examples = tokenizer(examples["prompt"], **kwargs)
concatenated_examples = {k: list(chain(*tokenized_examples[k])) for k in tokenized_examples.keys()}
total_length = len(concatenated_examples[list(concatenated_examples.keys())[0]])
block_size = data_args.cutoff_len
# we drop the small remainder, and if the total_length < block_size, we exclude this batch
total_length = (total_length // block_size) * block_size
# split by chunks of cutoff_len
result = {
k: [t[i: i + block_size] for i in range(0, total_length, block_size)]
for k, t in concatenated_examples.items()
}
# make sure the saved tokenizer is the same as the original one
if hasattr(tokenizer, "add_eos_token"):
setattr(tokenizer, "add_eos_token", add_eos_token_flag)
return result
def preprocess_supervised_dataset(examples: Dict[str, List[Any]]) -> Dict[str, List[List[int]]]:
# build inputs with format `<bos> X Y <eos>` and labels with format `<ignore> ... <ignore> Y <eos>`
# for multiturn examples, we only mask the prompt part in each prompt-response pair.
model_inputs = {"input_ids": [], "attention_mask": [], "labels": []}
for query, response, history, system in construct_example(examples):
if not (isinstance(query, str) and isinstance(response, str) and query != "" and response != ""):
continue
input_ids, labels = [], []
for turn_idx, (source_ids, target_ids) in enumerate(template.encode_multiturn(
tokenizer, query, response, history, system
)):
source_len, target_len = len(source_ids), len(target_ids)
max_source_len, max_target_len = infer_max_len(source_len, target_len, data_args)
if source_len > max_source_len:
source_ids = source_ids[:max_source_len]
if target_len > max_target_len:
target_ids = target_ids[:max_target_len]
if data_args.train_on_prompt:
source_mask = source_ids
elif turn_idx != 0 and template.efficient_eos:
source_mask = [tokenizer.eos_token_id] + [IGNORE_INDEX] * (len(source_ids) - 1)
else:
source_mask = [IGNORE_INDEX] * len(source_ids)
input_ids += source_ids + target_ids
labels += source_mask + target_ids
if template.efficient_eos:
input_ids += [tokenizer.eos_token_id]
labels += [tokenizer.eos_token_id]
if len(input_ids) > data_args.cutoff_len:
input_ids = input_ids[:data_args.cutoff_len]
labels = labels[:data_args.cutoff_len]
model_inputs["input_ids"].append(input_ids)
model_inputs["attention_mask"].append([1] * len(input_ids))
model_inputs["labels"].append(labels)
return model_inputs
def preprocess_packed_supervised_dataset(examples: Dict[str, List[Any]]) -> Dict[str, List[List[int]]]:
# build inputs with format `<bos> X1 Y1 <eos> <bos> X2 Y2 <eos>`
# and labels with format `<ignore> ... <ignore> Y1 <eos> <ignore> ... <ignore> Y2 <eos>`
model_inputs = {"input_ids": [], "attention_mask": [], "labels": []}
messages = examples["prompt"][i] + examples["response"][i]
input_ids, labels = [], []
for query, response, history, system in construct_example(examples):
if not (isinstance(query, str) and isinstance(response, str) and query != "" and response != ""):
continue
for turn_idx, (source_ids, target_ids) in enumerate(template.encode_multiturn(
tokenizer, messages, examples["system"][i], examples["tool"][i], data_args.cutoff_len
)):
if data_args.train_on_prompt:
source_mask = source_ids
elif turn_idx != 0 and template.efficient_eos:
source_mask = [tokenizer.eos_token_id] + [IGNORE_INDEX] * (len(source_ids) - 1)
else:
source_mask = [IGNORE_INDEX] * len(source_ids)
for turn_idx, (source_ids, target_ids) in enumerate(template.encode_multiturn(
tokenizer, query, response, history, system
)):
if data_args.train_on_prompt:
source_mask = source_ids
elif turn_idx != 0 and template.efficient_eos:
source_mask = [tokenizer.eos_token_id] + [IGNORE_INDEX] * (len(source_ids) - 1)
else:
source_mask = [IGNORE_INDEX] * len(source_ids)
input_ids += source_ids + target_ids
labels += source_mask + target_ids
input_ids += source_ids + target_ids
labels += source_mask + target_ids
if template.efficient_eos:
input_ids += [tokenizer.eos_token_id]
labels += [tokenizer.eos_token_id]
total_length = len(input_ids)
block_size = data_args.cutoff_len
# we drop the small remainder, and if the total_length < block_size, we exclude this batch
total_length = (total_length // block_size) * block_size
# split by chunks of cutoff_len
for i in range(0, total_length, block_size):
model_inputs["input_ids"].append(input_ids[i: i + block_size])
model_inputs["attention_mask"].append([1] * block_size)
model_inputs["labels"].append(labels[i: i + block_size])
model_inputs["input_ids"].append(input_ids)
model_inputs["attention_mask"].append([1] * len(input_ids))
model_inputs["labels"].append(labels)
return model_inputs
return model_inputs
def preprocess_unsupervised_dataset(examples: Dict[str, List[Any]]) -> Dict[str, List[List[int]]]:
# build inputs with format `<bos> X` and labels with format `Y <eos>`
model_inputs = {"input_ids": [], "attention_mask": [], "labels": []}
for query, response, history, system in construct_example(examples):
if not (isinstance(query, str) and query != ""):
continue
def preprocess_packed_supervised_dataset(
examples: Dict[str, List[Any]],
tokenizer: "PreTrainedTokenizer",
template: "Template",
data_args: "DataArguments",
) -> Dict[str, List[List[int]]]:
# build inputs with format `<bos> X1 Y1 <eos> <bos> X2 Y2 <eos>`
# and labels with format `<ignore> ... <ignore> Y1 <eos> <ignore> ... <ignore> Y2 <eos>`
model_inputs = {"input_ids": [], "attention_mask": [], "labels": []}
input_ids, labels = [], []
for i in range(len(examples["prompt"])):
if len(examples["prompt"][i]) == 0 or len(examples["response"][i]) != 1:
continue
input_ids, labels = template.encode_oneturn(tokenizer, query, response, history, system)
messages = examples["prompt"][i] + examples["response"][i]
for turn_idx, (source_ids, target_ids) in enumerate(template.encode_multiturn(
tokenizer, messages, examples["system"][i], examples["tool"][i], 1_000_000
)):
if data_args.train_on_prompt:
source_mask = source_ids
elif turn_idx != 0 and template.efficient_eos:
source_mask = [tokenizer.eos_token_id] + [IGNORE_INDEX] * (len(source_ids) - 1)
else:
source_mask = [IGNORE_INDEX] * len(source_ids)
if template.efficient_eos:
labels += [tokenizer.eos_token_id]
input_ids += source_ids + target_ids
labels += source_mask + target_ids
if len(input_ids) > data_args.cutoff_len:
input_ids = input_ids[:data_args.cutoff_len]
if len(labels) > data_args.cutoff_len:
labels = labels[:data_args.cutoff_len]
if template.efficient_eos:
input_ids += [tokenizer.eos_token_id]
labels += [tokenizer.eos_token_id]
model_inputs["input_ids"].append(input_ids)
model_inputs["attention_mask"].append([1] * len(input_ids))
model_inputs["labels"].append(labels)
total_length = len(input_ids)
block_size = data_args.cutoff_len
# we drop the small remainder, and if the total_length < block_size, we exclude this batch
total_length = (total_length // block_size) * block_size
# split by chunks of cutoff_len
for i in range(0, total_length, block_size):
model_inputs["input_ids"].append(input_ids[i: i + block_size])
model_inputs["attention_mask"].append([1] * block_size)
model_inputs["labels"].append(labels[i: i + block_size])
return model_inputs
return model_inputs
def preprocess_pairwise_dataset(examples: Dict[str, List[Any]]) -> Dict[str, List[List[int]]]:
# build input pairs with format `<bos> X`, `Y1 <eos>` and `Y2 <eos>`
model_inputs = {"prompt_ids": [], "chosen_ids": [], "rejected_ids": []}
for query, response, history, system in construct_example(examples):
if not (isinstance(query, str) and isinstance(response, list) and query != "" and len(response) > 1):
continue
prompt_ids, chosen_ids = template.encode_oneturn(tokenizer, query, response[0], history, system)
_, rejected_ids = template.encode_oneturn(tokenizer, query, response[1], history, system)
def preprocess_unsupervised_dataset(
examples: Dict[str, List[Any]],
tokenizer: "PreTrainedTokenizer",
template: "Template",
data_args: "DataArguments",
) -> Dict[str, List[List[int]]]:
# build inputs with format `<bos> X` and labels with format `Y <eos>`
model_inputs = {"input_ids": [], "attention_mask": [], "labels": []}
if template.efficient_eos:
chosen_ids += [tokenizer.eos_token_id]
rejected_ids += [tokenizer.eos_token_id]
for i in range(len(examples["prompt"])):
if len(examples["prompt"][i]) == 0 or len(examples["response"][i]) != 1:
continue
source_len, target_len = len(prompt_ids), max(len(chosen_ids), len(rejected_ids))
max_source_len, max_target_len = infer_max_len(source_len, target_len, data_args)
if source_len > max_source_len:
prompt_ids = prompt_ids[:max_source_len]
if target_len > max_target_len:
chosen_ids = chosen_ids[:max_target_len]
rejected_ids = rejected_ids[:max_target_len]
model_inputs["prompt_ids"].append(prompt_ids)
model_inputs["chosen_ids"].append(chosen_ids)
model_inputs["rejected_ids"].append(rejected_ids)
return model_inputs
def print_supervised_dataset_example(example: Dict[str, List[int]]) -> None:
print("input_ids:\n{}".format(example["input_ids"]))
print("inputs:\n{}".format(tokenizer.decode(example["input_ids"], skip_special_tokens=False)))
print("label_ids:\n{}".format(example["labels"]))
print("labels:\n{}".format(
tokenizer.decode(list(filter(lambda x: x != IGNORE_INDEX, example["labels"])), skip_special_tokens=False)
))
def print_pairwise_dataset_example(example: Dict[str, List[int]]) -> None:
print("prompt_ids:\n{}".format(example["prompt_ids"]))
print("prompt:\n{}".format(tokenizer.decode(example["prompt_ids"], skip_special_tokens=False)))
print("chosen_ids:\n{}".format(example["chosen_ids"]))
print("chosen:\n{}".format(tokenizer.decode(example["chosen_ids"], skip_special_tokens=False)))
print("rejected_ids:\n{}".format(example["rejected_ids"]))
print("rejected:\n{}".format(tokenizer.decode(example["rejected_ids"], skip_special_tokens=False)))
def print_unsupervised_dataset_example(example: Dict[str, List[int]]) -> None:
print("input_ids:\n{}".format(example["input_ids"]))
print("inputs:\n{}".format(tokenizer.decode(example["input_ids"], skip_special_tokens=False)))
if stage == "pt":
preprocess_func = preprocess_pretrain_dataset
print_function = print_unsupervised_dataset_example
elif stage == "sft" and not training_args.predict_with_generate:
preprocess_func = preprocess_packed_supervised_dataset if data_args.sft_packing else preprocess_supervised_dataset
print_function = print_supervised_dataset_example
elif stage == "rm":
preprocess_func = preprocess_pairwise_dataset
print_function = print_pairwise_dataset_example
else:
preprocess_func = preprocess_unsupervised_dataset
print_function = print_unsupervised_dataset_example
with training_args.main_process_first(desc="dataset map pre-processing"):
column_names = list(next(iter(dataset)).keys())
kwargs = {}
if not data_args.streaming:
kwargs = dict(
num_proc=data_args.preprocessing_num_workers,
load_from_cache_file=(not data_args.overwrite_cache),
desc="Running tokenizer on dataset"
)
dataset = dataset.map(
preprocess_func,
batched=True,
remove_columns=column_names,
**kwargs
messages = examples["prompt"][i] + examples["response"][i]
input_ids, labels = template.encode_oneturn(
tokenizer, messages, examples["system"][i], examples["tool"][i], data_args.cutoff_len
)
if data_args.cache_path is not None and not os.path.exists(data_args.cache_path):
if training_args.should_save:
dataset.save_to_disk(data_args.cache_path)
logger.info("Dataset cache saved at {}.".format(data_args.cache_path))
if template.efficient_eos:
labels += [tokenizer.eos_token_id]
if training_args.should_log:
try:
print_function(next(iter(dataset)))
except StopIteration:
raise RuntimeError("Empty dataset!")
model_inputs["input_ids"].append(input_ids)
model_inputs["attention_mask"].append([1] * len(input_ids))
model_inputs["labels"].append(labels)
return dataset
return model_inputs
def preprocess_pairwise_dataset(
examples: Dict[str, List[Any]],
tokenizer: "PreTrainedTokenizer",
template: "Template",
data_args: "DataArguments",
) -> Dict[str, List[List[int]]]:
# build input pairs with format `<bos> X`, `Y1 <eos>` and `Y2 <eos>`
model_inputs = {"prompt_ids": [], "chosen_ids": [], "rejected_ids": []}
for i in range(len(examples["prompt"])):
if len(examples["prompt"][i]) == 0 or len(examples["response"][i]) < 2:
continue
chosen_messages = examples["prompt"][i] + [examples["response"][i][0]]
rejected_messages = examples["prompt"][i] + [examples["response"][i][1]]
prompt_ids, chosen_ids = template.encode_oneturn(
tokenizer, chosen_messages, examples["system"][i], examples["tool"][i], data_args.cutoff_len
)
_, rejected_ids = template.encode_oneturn(
tokenizer, rejected_messages, examples["system"][i], examples["tool"][i], data_args.cutoff_len
)
if template.efficient_eos:
chosen_ids += [tokenizer.eos_token_id]
rejected_ids += [tokenizer.eos_token_id]
model_inputs["prompt_ids"].append(prompt_ids)
model_inputs["chosen_ids"].append(chosen_ids)
model_inputs["rejected_ids"].append(rejected_ids)
return model_inputs
def print_supervised_dataset_example(example: Dict[str, List[int]], tokenizer: "PreTrainedTokenizer") -> None:
print("input_ids:\n{}".format(example["input_ids"]))
print("inputs:\n{}".format(tokenizer.decode(example["input_ids"], skip_special_tokens=False)))
print("label_ids:\n{}".format(example["labels"]))
print("labels:\n{}".format(
tokenizer.decode(list(filter(lambda x: x != IGNORE_INDEX, example["labels"])), skip_special_tokens=False)
))
def print_pairwise_dataset_example(example: Dict[str, List[int]], tokenizer: "PreTrainedTokenizer") -> None:
print("prompt_ids:\n{}".format(example["prompt_ids"]))
print("prompt:\n{}".format(tokenizer.decode(example["prompt_ids"], skip_special_tokens=False)))
print("chosen_ids:\n{}".format(example["chosen_ids"]))
print("chosen:\n{}".format(tokenizer.decode(example["chosen_ids"], skip_special_tokens=False)))
print("rejected_ids:\n{}".format(example["rejected_ids"]))
print("rejected:\n{}".format(tokenizer.decode(example["rejected_ids"], skip_special_tokens=False)))
def print_unsupervised_dataset_example(example: Dict[str, List[int]], tokenizer: "PreTrainedTokenizer") -> None:
print("input_ids:\n{}".format(example["input_ids"]))
print("inputs:\n{}".format(tokenizer.decode(example["input_ids"], skip_special_tokens=False)))
def get_preprocess_and_print_func(
tokenizer: "PreTrainedTokenizer",
template: "Template",
data_args: "DataArguments",
training_args: "Seq2SeqTrainingArguments",
stage: Literal["pt", "sft", "rm", "ppo"],
) -> Tuple[Callable, Callable]:
if stage == "pt":
preprocess_func = partial(preprocess_pretrain_dataset, tokenizer=tokenizer, data_args=data_args)
print_function = partial(print_unsupervised_dataset_example, tokenizer=tokenizer)
elif stage == "sft" and not training_args.predict_with_generate:
if data_args.sft_packing:
preprocess_func = partial(
preprocess_packed_supervised_dataset, tokenizer=tokenizer, template=template, data_args=data_args
)
else:
preprocess_func = partial(
preprocess_supervised_dataset, tokenizer=tokenizer, template=template, data_args=data_args
)
print_function = partial(print_supervised_dataset_example, tokenizer=tokenizer)
elif stage == "rm":
preprocess_func = partial(preprocess_pairwise_dataset, tokenizer=tokenizer, template=template, data_args=data_args)
print_function = partial(print_pairwise_dataset_example, tokenizer=tokenizer)
else:
preprocess_func = partial(preprocess_unsupervised_dataset, tokenizer=tokenizer, template=template, data_args=data_args)
print_function = partial(print_unsupervised_dataset_example, tokenizer=tokenizer)
return preprocess_func, print_function
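This mirrors how get_dataset (in loader.py above) consumes the helper; a sketch for the SFT stage:

preprocess_func, print_function = get_preprocess_and_print_func(
    tokenizer, template, data_args, training_args, stage="sft"
)
column_names = list(next(iter(dataset)).keys())
dataset = dataset.map(preprocess_func, batched=True, remove_columns=column_names)
print_function(next(iter(dataset)))  # eyeball one tokenized example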

src/llmtuner/data/template.py

@@ -1,8 +1,10 @@
import tiktoken
from dataclasses import dataclass
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Tuple, Union
from ..extras.logging import get_logger
from .utils import Role
from .formatter import StringFormatter, FunctionFormatter, ToolFormatter
from llmtuner.extras.logging import get_logger
if TYPE_CHECKING:
from transformers import PreTrainedTokenizer
@@ -14,28 +16,30 @@ logger = get_logger(__name__)
@dataclass
class Template:
prefix: List[Union[str, Dict[str, str]]]
prompt: List[Union[str, Dict[str, str]]]
format_user: Callable
format_assistant: Callable
format_system: Callable
format_tool: Callable
format_observation: Callable
format_function: Callable
system: str
sep: List[Union[str, Dict[str, str]]]
separator: List[Union[str, Dict[str, str]]]
stop_words: List[str]
use_history: bool
efficient_eos: bool
replace_eos: bool
def encode_oneturn(
self,
tokenizer: "PreTrainedTokenizer",
query: str,
resp: str,
history: Optional[List[Tuple[str, str]]] = None,
system: Optional[str] = None
messages: List[Dict[str, str]],
system: str,
tool: str,
cutoff_len: int
) -> Tuple[List[int], List[int]]:
r"""
Returns a single pair of token ids representing prompt and response respectively.
"""
system, history = self._format(query, resp, history, system)
encoded_pairs = self._encode(tokenizer, system, history)
encoded_pairs = self._encode(tokenizer, messages, system, tool, cutoff_len)
prompt_ids = []
for query_ids, resp_ids in encoded_pairs[:-1]:
prompt_ids = prompt_ids + query_ids + resp_ids
@@ -46,109 +50,75 @@ class Template:
def encode_multiturn(
self,
tokenizer: "PreTrainedTokenizer",
query: str,
resp: str,
history: Optional[List[Tuple[str, str]]] = None,
system: Optional[str] = None
messages: List[Dict[str, str]],
system: str,
tool: str,
cutoff_len: int
) -> List[Tuple[List[int], List[int]]]:
r"""
Returns multiple pairs of token ids representing prompts and responses respectively.
"""
system, history = self._format(query, resp, history, system)
encoded_pairs = self._encode(tokenizer, system, history)
encoded_pairs = self._encode(tokenizer, messages, system, tool, cutoff_len)
return encoded_pairs
def _format(
self,
query: str,
resp: str,
history: Optional[List[Tuple[str, str]]] = None,
system: Optional[str] = None
) -> Tuple[str, List[Tuple[str, str]]]:
r"""
Aligns inputs to the standard format.
"""
system = system or self.system # use system if provided
history = history if (history and self.use_history) else []
history = history + [(query, resp)]
return system, history
def _get_special_ids(
self,
tokenizer: "PreTrainedTokenizer"
) -> Tuple[List[int], List[int]]:
if tokenizer.bos_token_id is not None and getattr(tokenizer, "add_bos_token", True):
bos_ids = [tokenizer.bos_token_id]
else: # baichuan, gpt2, qwen, yi models have no bos token
bos_ids = []
if tokenizer.eos_token_id is None:
raise ValueError("EOS token is required.")
if self.efficient_eos:
eos_ids = []
else:
eos_ids = [tokenizer.eos_token_id]
return bos_ids, eos_ids
def _encode(
self,
tokenizer: "PreTrainedTokenizer",
messages: List[Dict[str, str]],
system: str,
history: List[Tuple[str, str]]
tool: str,
cutoff_len: int
) -> List[Tuple[List[int], List[int]]]:
r"""
Encodes formatted inputs to pairs of token ids.
Turn 0: bos + prefix + sep + query resp + eos
Turn t: sep + bos + query resp + eos
Turn 0: system + query resp + eos
Turn t: sep + query resp + eos
"""
bos_ids, eos_ids = self._get_special_ids(tokenizer)
sep_ids = self._convert_inputs_to_ids(tokenizer, context=self.sep)
encoded_pairs = []
for turn_idx, (query, resp) in enumerate(history):
if turn_idx == 0:
prefix_ids = self._convert_inputs_to_ids(tokenizer, context=self.prefix, system=system)
if len(prefix_ids) != 0: # has prefix
prefix_ids = bos_ids + prefix_ids + sep_ids
else:
prefix_ids = bos_ids
else:
prefix_ids = sep_ids + bos_ids
system = system or self.system
encoded_messages = []
for i, message in enumerate(messages):
elements = []
if i == 0 and (system or tool):
tool_text = self.format_tool(content=tool)[0] if tool else ""
elements += self.format_system(content=(system + tool_text))
elif i > 0 and i % 2 == 0:
elements += self.separator
query_ids = self._convert_inputs_to_ids(tokenizer, context=self.prompt, query=query, idx=str(turn_idx+1))
resp_ids = self._convert_inputs_to_ids(tokenizer, context=[resp])
encoded_pairs.append((prefix_ids + query_ids, resp_ids + eos_ids))
return encoded_pairs
if message["role"] == Role.USER:
elements += self.format_user(content=message["content"], idx=str(i // 2))
elif message["role"] == Role.ASSISTANT:
elements += self.format_assistant(content=message["content"])
elif message["role"] == Role.OBSERVATION:
elements += self.format_observation(content=message["content"])
elif message["role"] == Role.FUNCTION:
elements += self.format_function(content=message["content"])
def _convert_inputs_to_ids(
encoded_messages.append(self._convert_elements_to_ids(tokenizer, elements))
return [(encoded_messages[i], encoded_messages[i+1]) for i in range(0, len(encoded_messages), 2)]
def _convert_elements_to_ids(
self,
tokenizer: "PreTrainedTokenizer",
context: List[Union[str, Dict[str, str]]],
system: Optional[str] = None,
query: Optional[str] = None,
idx: Optional[str] = None
elements: List[Union[str, Dict[str, str]]]
) -> List[int]:
r"""
Converts context to token ids.
Converts elements to token ids.
"""
if isinstance(getattr(tokenizer, "tokenizer", None), tiktoken.Encoding): # for tiktoken tokenizer (Qwen)
kwargs = dict(allowed_special="all")
else:
kwargs = dict(add_special_tokens=False)
token_ids = []
for elem in context:
for elem in elements:
if isinstance(elem, str):
elem = elem.replace("{{system}}", system, 1) if system is not None else elem
elem = elem.replace("{{query}}", query, 1) if query is not None else elem
elem = elem.replace("{{idx}}", idx, 1) if idx is not None else elem
if len(elem) != 0:
token_ids = token_ids + tokenizer.encode(elem, **kwargs)
token_ids = token_ids + tokenizer.encode(elem, add_special_tokens=False)
elif isinstance(elem, dict):
token_ids = token_ids + [tokenizer.convert_tokens_to_ids(elem.get("token"))]
elif isinstance(elem, set):
if "bos_token" in elem and tokenizer.bos_token_id:
token_ids = token_ids + [tokenizer.bos_token_id]
elif "eos_token" in elem and tokenizer.eos_token_id:
token_ids = token_ids + [tokenizer.eos_token_id]
else:
raise ValueError("Input must be string or dict[str, str], got {}".format(type(elem)))
raise ValueError("Input must be string, set[str] or dict[str, str], got {}".format(type(elem)))
return token_ids
@@ -159,23 +129,39 @@ class Llama2Template(Template):
def _encode(
self,
tokenizer: "PreTrainedTokenizer",
messages: List[Dict[str, str]],
system: str,
history: List[Tuple[str, str]]
tool: str,
cutoff_len: int
) -> List[Tuple[List[int], List[int]]]:
r"""
Encodes formatted inputs to pairs of token ids.
Turn 0: bos + prefix + query resp + eos
Turn t: bos + query resp + eos
Turn 0: system + query resp + eos
Turn t: sep + query resp + eos
"""
bos_ids, eos_ids = self._get_special_ids(tokenizer)
encoded_pairs = []
for turn_idx, (query, resp) in enumerate(history):
if turn_idx == 0: # llama2 template has no sep_ids
query = self.prefix[0].replace("{{system}}", system) + query
query_ids = self._convert_inputs_to_ids(tokenizer, context=self.prompt, query=query)
resp_ids = self._convert_inputs_to_ids(tokenizer, context=[resp])
encoded_pairs.append((bos_ids + query_ids, resp_ids + eos_ids))
return encoded_pairs
system = system or self.system
encoded_messages = []
for i, message in enumerate(messages):
elements = []
system_text = ""
if i == 0 and (system or tool):
tool_text = self.format_tool(content=tool)[0] if tool else ""
system_text = self.format_system(content=(system + tool_text))[0]
elif i > 0 and i % 2 == 0:
elements += self.separator
if message["role"] == Role.USER:
elements += self.format_user(content=system_text + message["content"], idx=str(i // 2))
elif message["role"] == Role.ASSISTANT:
elements += self.format_assistant(content=message["content"])
elif message["role"] == Role.OBSERVATION:
elements += self.format_observation(content=message["content"])
elif message["role"] == Role.FUNCTION:
elements += self.format_function(content=message["content"])
encoded_messages.append(self._convert_elements_to_ids(tokenizer, elements))
return [(encoded_messages[i], encoded_messages[i+1]) for i in range(0, len(encoded_messages), 2)]
templates: Dict[str, Template] = {}
@@ -183,23 +169,33 @@ templates: Dict[str, Template] = {}
def register_template(
name: str,
prefix: List[Union[str, Dict[str, str]]],
prompt: List[Union[str, Dict[str, str]]],
system: str,
sep: List[Union[str, Dict[str, str]]],
format_user: Optional[Callable] = None,
format_assistant: Optional[Callable] = None,
format_system: Optional[Callable] = None,
format_tool: Optional[Callable] = None,
format_observation: Optional[Callable] = None,
format_function: Optional[Callable] = None,
system: Optional[str] = "",
separator: Optional[List[Union[str, Dict[str, str]]]] = [],
stop_words: Optional[List[str]] = [],
use_history: Optional[bool] = True,
efficient_eos: Optional[bool] = False,
replace_eos: Optional[bool] = False
) -> None:
template_class = Llama2Template if name.startswith("llama2") else Template
templates[name] = template_class(
prefix=prefix,
prompt=prompt,
format_user=format_user or StringFormatter(container=["{{content}}"]),
format_assistant=format_assistant or StringFormatter(container=[
"{{content}}", {"eos_token"}
]),
format_system=format_system or StringFormatter(container=["{{content}}"]),
format_tool=format_tool or ToolFormatter(type="default"),
format_observation=format_observation or format_user or StringFormatter(container=["{{content}}"]),
format_function=format_function or FunctionFormatter(container=[
"Action: {{name}}\nAction Input: {{arguments}}", {"eos_token"}
]),
system=system,
sep=sep,
separator=separator,
stop_words=stop_words,
use_history=use_history,
efficient_eos=efficient_eos,
replace_eos=replace_eos
)
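Under the new interface a template is a bundle of formatters rather than raw prefix/prompt strings. A hypothetical registration (the name and special tokens are illustrative):

register_template(
    name="my_chat",  # hypothetical
    format_user=StringFormatter(container=["<|user|>\n{{content}}\n<|assistant|>\n"]),
    format_system=StringFormatter(container=["<|system|>\n{{content}}\n"]),
    system="You are a helpful assistant.",
    separator=["\n"],
)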
@@ -244,17 +240,14 @@ def get_template_and_fix_tokenizer(
register_template(
name="alpaca",
prefix=[
"{{system}}"
],
prompt=[
"### Instruction:\n{{query}}\n\n### Response:\n"
],
format_user=StringFormatter(container=[
"### Instruction:\n{{content}}\n\n### Response:\n"
]),
system=(
"Below is an instruction that describes a task. "
"Write a response that appropriately completes the request."
),
sep=[
separator=[
"\n\n"
]
)
@@ -262,17 +255,14 @@ register_template(
register_template(
name="aquila",
prefix=[
"{{system}}"
],
prompt=[
"Human: {{query}}###Assistant:"
],
format_user=StringFormatter(container=[
"Human: {{content}}###Assistant:"
]),
system=(
"A chat between a curious human and an artificial intelligence assistant. "
"The assistant gives helpful, detailed, and polite answers to the human's questions."
),
sep=[
separator=[
"###"
],
stop_words=[
@@ -284,46 +274,32 @@ register_template(
register_template(
name="baichuan",
prefix=[
"{{system}}"
],
prompt=[
{"token": "<reserved_102>"}, # user token
"{{query}}",
{"token": "<reserved_103>"} # assistant token
],
system="",
sep=[],
format_user=StringFormatter(container=[
{"token": "<reserved_102>"},
"{{content}}",
{"token": "<reserved_103>"}
]),
efficient_eos=True
)
register_template(
name="baichuan2",
prefix=[
"{{system}}"
],
prompt=[
{"token": "<reserved_106>"}, # user token
"{{query}}",
{"token": "<reserved_107>"} # assistant token
],
system="",
sep=[],
format_user=StringFormatter(container=[
{"token": "<reserved_106>"},
"{{content}}",
{"token": "<reserved_107>"}
]),
efficient_eos=True
)
register_template(
name="belle",
prefix=[
"{{system}}"
],
prompt=[
"Human: {{query}}\n\nBelle: "
],
system="",
sep=[
format_user=StringFormatter(container=[
"Human: {{content}}\n\nBelle: "
]),
separator=[
"\n\n"
]
)
@@ -331,31 +307,25 @@ register_template(
register_template(
name="bluelm",
prefix=[
"{{system}}"
],
prompt=[
format_user=StringFormatter(container=[
{"token": "[|Human|]:"},
"{{query}}",
"{{content}}",
{"token": "[|AI|]:"}
],
system="",
sep=[]
])
)
register_template(
name="chatglm2",
prefix=[
format_user=StringFormatter(container=[
"[Round {{idx}}]\n\n问:{{content}}\n\n答:"
]),
format_system=StringFormatter(container=[
{"token": "[gMASK]"},
{"token": "sop"},
"{{system}}"
],
prompt=[
"[Round {{idx}}]\n\n问:{{query}}\n\n答:"
],
system="",
sep=[
"{{content}}"
]),
separator=[
"\n\n"
],
efficient_eos=True
@@ -364,53 +334,35 @@ register_template(
register_template(
name="chatglm3",
prefix=[
{"token": "[gMASK]"},
{"token": "sop"},
{"token": "<|system|>"},
"\n",
"{{system}}"
],
prompt=[
format_user=StringFormatter(container=[
{"token": "<|user|>"},
"\n",
"{{query}}",
{"token": "<|assistant|>"},
"\n" # add an extra newline to avoid error in ChatGLM's process_response method
],
system=(
"You are ChatGLM3, a large language model trained by Zhipu.AI. "
"Follow the user's instructions carefully. Respond using markdown."
),
sep=[],
stop_words=[
"<|user|>",
"<|observation|>"
],
efficient_eos=True
)
register_template(
name="chatglm3_raw", # the raw template for tool tuning
prefix=[
{"token": "[gMASK]"},
{"token": "sop"},
{"token": "<|system|>"},
"\n",
"{{system}}"
],
prompt=[
{"token": "<|user|>"},
"\n",
"{{query}}",
"{{content}}",
{"token": "<|assistant|>"}
],
]),
format_assistant=StringFormatter(container=[
"\n"
"{{content}}"
]),
format_system=StringFormatter(container=[
{"token": "[gMASK]"},
{"token": "sop"},
{"token": "<|system|>"},
"\n",
"{{content}}"
]),
format_observation=StringFormatter(container=[
{"token": "<|observation|>"},
"\n",
"{{content}}"
]),
format_function=FunctionFormatter(container=[
"{{name}}\n{{arguments}}"
]),
system=(
"You are ChatGLM3, a large language model trained by Zhipu.AI. "
"Follow the user's instructions carefully. Respond using markdown."
),
sep=[],
stop_words=[
"<|user|>",
"<|observation|>"
@@ -421,47 +373,34 @@ register_template(
register_template(
name="codegeex2",
prefix=[
format_system=StringFormatter(container=[
{"token": "[gMASK]"},
{"token": "sop"},
"{{system}}"
],
prompt=[
"{{query}}"
],
system="",
sep=[]
"{{content}}"
])
)
register_template(
name="deepseek",
prefix=[
"{{system}}"
],
prompt=[
"User: {{query}}\n\nAssistant:"
],
system="",
sep=[]
format_user=StringFormatter(container=[
"User: {{content}}\n\nAssistant:"
])
)
register_template(
name="deepseekcoder",
prefix=[
"{{system}}"
],
prompt=[
"### Instruction:\n{{query}}\n### Response:\n"
],
format_user=StringFormatter(container=[
"### Instruction:\n{{content}}\n### Response:\n"
]),
system=(
"You are an AI programming assistant, utilizing the Deepseek Coder model, "
"developed by Deepseek Company, and you only answer questions related to computer science. "
"For politically sensitive questions, security and privacy issues, "
"and other non-computer science questions, you will refuse to answer\n"
),
sep=[
separator=[
"\n",
{"token": "<|EOT|>"},
"\n"
@@ -475,17 +414,14 @@ register_template(
register_template(
name="default",
prefix=[
"{{system}}"
],
prompt=[
"Human: {{query}}\nAssistant:"
],
format_user=StringFormatter(container=[
"Human: {{content}}\nAssistant: "
]),
system=(
"A chat between a curious user and an artificial intelligence assistant. "
"The assistant gives helpful, detailed, and polite answers to the user's questions."
"The assistant gives helpful, detailed, and polite answers to the user's questions.\n"
),
sep=[
separator=[
"\n"
]
)
@@ -493,14 +429,10 @@ register_template(
register_template(
    name="falcon",
    format_user=StringFormatter(container=[
        "User: {{content}}\nFalcon:"
    ]),
    separator=[
        "\n"
    ],
    efficient_eos=True
@@ -509,16 +441,12 @@ register_template(
register_template(
    name="intern",
    format_user=StringFormatter(container=[
        "<|User|>:{{content}}",
        {"token": "<eoh>"},
        "\n<|Bot|>:"
    ]),
    separator=[
        {"token": "<eoa>"},
        "\n"
    ],
@@ -529,14 +457,44 @@ register_template(
)
register_template(
    name="intern2",
    format_user=StringFormatter(container=[
        {"token": "[UNUSED_TOKEN_146]"},
        "user\n{{content}}",
        {"token": "[UNUSED_TOKEN_145]"},
        "\n",
        {"token": "[UNUSED_TOKEN_146]"},
        "assistant\n"
    ]),
    format_system=StringFormatter(container=[
        {"token": "[UNUSED_TOKEN_146]"},
        "system\n{{content}}",
        {"token": "[UNUSED_TOKEN_145]"},
        "\n"
    ]),
    system=(
        "You are an AI assistant whose name is InternLM (书生·浦语).\n"
        "- InternLM (书生·浦语) is a conversational language model that is developed "
        "by Shanghai AI Laboratory (上海人工智能实验室). It is designed to be helpful, honest, and harmless.\n"
        "- InternLM (书生·浦语) can understand and communicate fluently in the language chosen "
        "by the user such as English and 中文."
    ),
    separator=[
        {"token": "[UNUSED_TOKEN_145]"},
        "\n"
    ],
    stop_words=[
        "[UNUSED_TOKEN_145]"
    ],
    efficient_eos=True
)
register_template(
    name="llama2",
    format_user=StringFormatter(container=["[INST] {{content}} [/INST]"]),
    format_system=StringFormatter(container=["<<SYS>>\n{{content}}\n<</SYS>>\n\n"]),
    system=(
        "You are a helpful, respectful and honest assistant. "
        "Always answer as helpfully as possible, while being safe. "
@@ -546,49 +504,32 @@ register_template(
"If a question does not make any sense, or is not factually coherent, "
"explain why instead of answering something not correct. "
"If you don't know the answer to a question, please don't share false information."
),
sep=[]
)
)
register_template(
    name="llama2_zh",
    format_user=StringFormatter(container=["[INST] {{content}} [/INST]"]),
    format_system=StringFormatter(container=["<<SYS>>\n{{content}}\n<</SYS>>\n\n"]),
    system="You are a helpful assistant. 你是一个乐于助人的助手。"
)
register_template(
    name="mistral",
    format_user=StringFormatter(container=["[INST] {{content}} [/INST]"])
)
register_template(
    name="openchat",
    format_user=StringFormatter(container=[
        "GPT4 Correct User: {{content}}",
        {"token": "<|end_of_turn|>"},
        "GPT4 Correct Assistant:"
    ]),
    separator=[
        {"token": "<|end_of_turn|>"}
    ],
    stop_words=[
@@ -600,14 +541,14 @@ register_template(
register_template(
    name="qwen",
    format_user=StringFormatter(container=[
        "<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"
    ]),
    format_system=StringFormatter(container=[
        "<|im_start|>system\n{{content}}<|im_end|>\n"
    ]),
    system="You are a helpful assistant.",
    separator=[
        "\n"
    ],
    stop_words=[
@@ -619,32 +560,28 @@ register_template(
register_template(
    name="solar",
    format_user=StringFormatter(container=[
        "### User:\n{{content}}\n\n### Assistant:\n"
    ])
)
register_template(
    name="starchat",
    format_user=StringFormatter(container=[
        {"token": "<|user|>"},
        "\n{{content}}",
        {"token": "<|end|>"},
        "\n",
        {"token": "<|assistant|>"}
    ]),
    format_system=StringFormatter(container=[
        {"token": "<|system|>"},
        "\n{{content}}",
        {"token": "<|end|>"},
        "\n"
    ]),
    separator=[
        {"token": "<|end|>"},
        "\n"
    ],
@@ -656,75 +593,55 @@ register_template(
register_template(
    name="vanilla"
)
register_template(
    name="vicuna",
    format_user=StringFormatter(container=[
        "USER: {{content}} ASSISTANT:"
    ]),
    system=(
        "A chat between a curious user and an artificial intelligence assistant. "
        "The assistant gives helpful, detailed, and polite answers to the user's questions."
    )
)
register_template(
    name="xuanyuan",
    format_user=StringFormatter(container=[
        "Human: {{content}} Assistant:"
    ]),
    system=(
        "以下是用户和人工智能助手之间的对话。用户以Human开头,人工智能助手以Assistant开头,"
        "会对人类提出的问题给出有帮助、高质量、详细和礼貌的回答,并且总是拒绝参与与不道德、"
        "不安全、有争议、政治敏感等相关的话题、问题和指示。\n"
    )
)
register_template(
    name="xverse",
    format_user=StringFormatter(container=[
        "Human: {{content}}\n\nAssistant: "
    ])
)
register_template(
    name="yayi",
    format_user=StringFormatter(container=[
        {"token": "<|Human|>"},
        ":\n{{content}}\n\n",
        {"token": "<|YaYi|>"},
        ":"
    ]),
    format_system=StringFormatter(container=[
        {"token": "<|System|>"},
        ":\n{{content}}\n\n"
    ]),
    system=(
        "You are a helpful, respectful and honest assistant named YaYi "
        "developed by Beijing Wenge Technology Co.,Ltd. "
@@ -736,7 +653,7 @@ register_template(
"explain why instead of answering something not correct. "
"If you don't know the answer to a question, please don't share false information."
),
sep=[
separator=[
"\n\n"
],
stop_words=[
@@ -747,14 +664,10 @@ register_template(
register_template(
    name="yi",
    format_user=StringFormatter(container=[
        "<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"
    ]),
    separator=[
        "\n"
    ],
    stop_words=[
@@ -766,15 +679,11 @@ register_template(
register_template(
    name="yuan",
    format_user=StringFormatter(container=[
        "{{content}}",
        {"token": "<sep>"}
    ]),
    separator=[
        "\n"
    ],
    stop_words=[
@@ -786,30 +695,25 @@ register_template(
register_template(
    name="zephyr",
    format_user=StringFormatter(container=[
        "<|user|>\n{{content}}</s><|assistant|>"
    ]),
    format_system=StringFormatter(container=[
        "<|system|>\n{{content}}</s>"
    ]),
    system="You are a friendly chatbot who always responds in the style of a pirate"
)
register_template(
    name="ziya",
    format_user=StringFormatter(container=[
        {"token": "<human>"},
        ":{{content}}\n",
        {"token": "<bot>"},
        ":"
    ]),
    separator=[
        "\n"
    ]
)

View File

@@ -1,7 +1,8 @@
import hashlib
from enum import Enum, unique
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union

from ..extras.logging import get_logger
if TYPE_CHECKING:
    from datasets import Dataset, IterableDataset
@@ -12,6 +13,14 @@ if TYPE_CHECKING:
logger = get_logger(__name__)
@unique
class Role(str, Enum):
    USER = "user"
    ASSISTANT = "assistant"
    OBSERVATION = "observation"
    FUNCTION = "function"
def checksum(data_files: List[str], file_sha1: Optional[str] = None) -> None:
    if file_sha1 is None:
        logger.warning("Checksum failed: missing SHA-1 hash value in dataset_info.json.")
@@ -27,6 +36,13 @@ def checksum(data_files: List[str], file_sha1: Optional[str] = None) -> None:
logger.warning("Checksum failed: mismatched SHA-1 hash value at {}.".format(data_files[0]))
def infer_max_len(source_len: int, target_len: int, data_args: "DataArguments") -> Tuple[int, int]:
    max_target_len = int(data_args.cutoff_len * (target_len / (source_len + target_len)))
    max_target_len = max(max_target_len, data_args.reserved_label_len)
    max_source_len = data_args.cutoff_len - max_target_len
    return max_source_len, max_target_len
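# Worked example: with cutoff_len=1024 and reserved_label_len=1, a 300-token
# source and a 100-token target split the length budget proportionally.
# SimpleNamespace stands in for DataArguments here, for illustration only.
from types import SimpleNamespace

args = SimpleNamespace(cutoff_len=1024, reserved_label_len=1)
print(infer_max_len(source_len=300, target_len=100, data_args=args))
# (768, 256): max_target_len = int(1024 * 100 / 400) = 256; the source keeps the rest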
def split_dataset(
    dataset: Union["Dataset", "IterableDataset"],
    data_args: "DataArguments",