format style
Former-commit-id: 53b683531b83cd1d19de97c6565f16c1eca6f5e1
This commit is contained in:
@@ -1,6 +1,6 @@
|
||||
from .loader import get_dataset
|
||||
from .template import get_template_and_fix_tokenizer, templates
|
||||
from .utils import split_dataset, Role
|
||||
from .utils import Role, split_dataset
|
||||
|
||||
|
||||
__all__ = ["get_dataset", "get_template_and_fix_tokenizer", "templates", "split_dataset", "Role"]
|
||||
|
||||
@@ -27,7 +27,9 @@ def convert_alpaca(examples: Dict[str, List[Any]], dataset_attr: "DatasetAttr")
|
||||
|
||||
if dataset_attr.response:
|
||||
if isinstance(examples[dataset_attr.response][i], list):
|
||||
response = [{"role": Role.ASSISTANT, "content": content} for content in examples[dataset_attr.response][i]]
|
||||
response = [
|
||||
{"role": Role.ASSISTANT, "content": content} for content in examples[dataset_attr.response][i]
|
||||
]
|
||||
else:
|
||||
response = [{"role": Role.ASSISTANT, "content": examples[dataset_attr.response][i]}]
|
||||
else:
|
||||
@@ -47,10 +49,10 @@ def convert_sharegpt(examples: Dict[str, List[Any]], dataset_attr: "DatasetAttr"
|
||||
dataset_attr.user_tag: Role.USER,
|
||||
dataset_attr.assistant_tag: Role.ASSISTANT,
|
||||
dataset_attr.observation_tag: Role.OBSERVATION,
|
||||
dataset_attr.function_tag: Role.FUNCTION
|
||||
dataset_attr.function_tag: Role.FUNCTION,
|
||||
}
|
||||
for i, messages in enumerate(examples[dataset_attr.messages]):
|
||||
messages = messages[:len(messages) // 2 * 2] # should be multiples of 2
|
||||
messages = messages[: len(messages) // 2 * 2] # should be multiples of 2
|
||||
if len(messages) == 0:
|
||||
continue
|
||||
|
||||
@@ -65,7 +67,9 @@ def convert_sharegpt(examples: Dict[str, List[Any]], dataset_attr: "DatasetAttr"
|
||||
if message[dataset_attr.role_tag] not in accept_tags:
|
||||
raise ValueError("Invalid role tag in {}.".format(messages))
|
||||
|
||||
prompt.append({"role": tag_mapping[message[dataset_attr.role_tag]], "content": message[dataset_attr.content_tag]})
|
||||
prompt.append(
|
||||
{"role": tag_mapping[message[dataset_attr.role_tag]], "content": message[dataset_attr.content_tag]}
|
||||
)
|
||||
|
||||
last_message = prompt.pop(-1)
|
||||
response.append(last_message)
|
||||
@@ -98,12 +102,7 @@ def align_dataset(
|
||||
kwargs = dict(
|
||||
num_proc=data_args.preprocessing_num_workers,
|
||||
load_from_cache_file=(not data_args.overwrite_cache),
|
||||
desc="Converting format of dataset"
|
||||
desc="Converting format of dataset",
|
||||
)
|
||||
|
||||
return dataset.map(
|
||||
convert_func,
|
||||
batched=True,
|
||||
remove_columns=column_names,
|
||||
**kwargs
|
||||
)
|
||||
return dataset.map(convert_func, batched=True, remove_columns=column_names, **kwargs)
|
||||
|
||||
@@ -76,7 +76,11 @@ class ToolFormatter:
|
||||
required = ", required" if name in tool["parameters"].get("required", []) else ""
|
||||
enum = ", should be one of [{}]".format(", ".join(param["enum"])) if param.get("enum", None) else ""
|
||||
param_text += " - {name} ({type}{required}): {desc}{enum}\n".format(
|
||||
name=name, type=param.get("type", ""), required=required, desc=param.get("description", ""), enum=enum
|
||||
name=name,
|
||||
type=param.get("type", ""),
|
||||
required=required,
|
||||
desc=param.get("description", ""),
|
||||
enum=enum,
|
||||
)
|
||||
|
||||
tool_text += "> Tool Name: {name}\nTool Description: {desc}\nTool Args:\n{args}\n".format(
|
||||
@@ -85,9 +89,7 @@ class ToolFormatter:
|
||||
tool_names.append(tool["name"])
|
||||
|
||||
return TOOL_SYSTEM_PROMPT.format(
|
||||
tool_text=tool_text,
|
||||
tool_names=", ".join(tool_names),
|
||||
format_prompt=JSON_FORMAT_PROMPT
|
||||
tool_text=tool_text, tool_names=", ".join(tool_names), format_prompt=JSON_FORMAT_PROMPT
|
||||
)
|
||||
|
||||
def __call__(self, content: str) -> List[Union[str, Dict[str, str]]]:
|
||||
|
||||
@@ -1,16 +1,16 @@
|
||||
import os
|
||||
import inspect
|
||||
import os
|
||||
from typing import TYPE_CHECKING, List, Literal, Union
|
||||
|
||||
from datasets import concatenate_datasets, interleave_datasets, load_dataset, load_from_disk
|
||||
|
||||
from ..extras.constants import FILEEXT2TYPE
|
||||
from ..extras.logging import get_logger
|
||||
from .utils import checksum
|
||||
from .parser import get_dataset_list
|
||||
from .aligner import align_dataset
|
||||
from .template import get_template_and_fix_tokenizer
|
||||
from .parser import get_dataset_list
|
||||
from .preprocess import get_preprocess_and_print_func
|
||||
from .template import get_template_and_fix_tokenizer
|
||||
from .utils import checksum
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -18,8 +18,8 @@ if TYPE_CHECKING:
|
||||
from transformers import Seq2SeqTrainingArguments
|
||||
from transformers.tokenization_utils import PreTrainedTokenizer
|
||||
|
||||
from ..hparams import DataArguments, ModelArguments
|
||||
from .parser import DatasetAttr
|
||||
from ..hparams import ModelArguments, DataArguments
|
||||
|
||||
|
||||
logger = get_logger(__name__)
|
||||
@@ -44,14 +44,14 @@ def load_single_dataset(
|
||||
elif dataset_attr.load_from == "file":
|
||||
data_files = []
|
||||
local_path: str = os.path.join(data_args.dataset_dir, dataset_attr.dataset_name)
|
||||
if os.path.isdir(local_path): # is directory
|
||||
if os.path.isdir(local_path): # is directory
|
||||
for file_name in os.listdir(local_path):
|
||||
data_files.append(os.path.join(local_path, file_name))
|
||||
if data_path is None:
|
||||
data_path = FILEEXT2TYPE.get(file_name.split(".")[-1], None)
|
||||
elif data_path != FILEEXT2TYPE.get(file_name.split(".")[-1], None):
|
||||
raise ValueError("File types should be identical.")
|
||||
elif os.path.isfile(local_path): # is file
|
||||
elif os.path.isfile(local_path): # is file
|
||||
data_files.append(local_path)
|
||||
data_path = FILEEXT2TYPE.get(local_path.split(".")[-1], None)
|
||||
else:
|
||||
@@ -78,12 +78,12 @@ def load_single_dataset(
|
||||
split=data_args.split,
|
||||
cache_dir=cache_dir,
|
||||
token=model_args.ms_hub_token,
|
||||
use_streaming=(data_args.streaming and (dataset_attr.load_from != "file"))
|
||||
use_streaming=(data_args.streaming and (dataset_attr.load_from != "file")),
|
||||
).to_hf_dataset()
|
||||
except ImportError:
|
||||
raise ImportError("Please install modelscope via `pip install modelscope -U`")
|
||||
else:
|
||||
if "trust_remote_code" in inspect.signature(load_dataset).parameters: # for datasets==2.16.0
|
||||
if "trust_remote_code" in inspect.signature(load_dataset).parameters: # for datasets==2.16.0
|
||||
kwargs = {"trust_remote_code": True}
|
||||
else:
|
||||
kwargs = {}
|
||||
@@ -97,13 +97,13 @@ def load_single_dataset(
|
||||
cache_dir=model_args.cache_dir,
|
||||
token=model_args.hf_hub_token,
|
||||
streaming=(data_args.streaming and (dataset_attr.load_from != "file")),
|
||||
**kwargs
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
if data_args.streaming and (dataset_attr.load_from == "file"): # faster than specifying streaming=True
|
||||
dataset = dataset.to_iterable_dataset() # TODO: add num shards parameter
|
||||
if data_args.streaming and (dataset_attr.load_from == "file"): # faster than specifying streaming=True
|
||||
dataset = dataset.to_iterable_dataset() # TODO: add num shards parameter
|
||||
|
||||
if data_args.max_samples is not None: # truncate dataset
|
||||
if data_args.max_samples is not None: # truncate dataset
|
||||
num_samples = min(data_args.max_samples, len(dataset))
|
||||
dataset = dataset.select(range(num_samples))
|
||||
|
||||
@@ -113,7 +113,7 @@ def load_single_dataset(
|
||||
def merge_dataset(
|
||||
all_datasets: List[Union["Dataset", "IterableDataset"]],
|
||||
data_args: "DataArguments",
|
||||
training_args: "Seq2SeqTrainingArguments"
|
||||
training_args: "Seq2SeqTrainingArguments",
|
||||
) -> Union["Dataset", "IterableDataset"]:
|
||||
if len(all_datasets) == 1:
|
||||
return all_datasets[0]
|
||||
@@ -128,7 +128,7 @@ def merge_dataset(
|
||||
datasets=all_datasets,
|
||||
probabilities=data_args.interleave_probs,
|
||||
seed=training_args.seed,
|
||||
stopping_strategy="first_exhausted" if data_args.mix_strategy.endswith("under") else "all_exhausted"
|
||||
stopping_strategy="first_exhausted" if data_args.mix_strategy.endswith("under") else "all_exhausted",
|
||||
)
|
||||
else:
|
||||
raise ValueError("Unknown mixing strategy.")
|
||||
@@ -160,7 +160,7 @@ def get_dataset(
|
||||
|
||||
with training_args.main_process_first(desc="load dataset"):
|
||||
all_datasets = []
|
||||
for dataset_attr in get_dataset_list(data_args): # TODO: add split
|
||||
for dataset_attr in get_dataset_list(data_args): # TODO: add split
|
||||
all_datasets.append(load_single_dataset(dataset_attr, model_args, data_args))
|
||||
dataset = merge_dataset(all_datasets, data_args, training_args)
|
||||
|
||||
@@ -174,15 +174,10 @@ def get_dataset(
|
||||
kwargs = dict(
|
||||
num_proc=data_args.preprocessing_num_workers,
|
||||
load_from_cache_file=(not data_args.overwrite_cache),
|
||||
desc="Running tokenizer on dataset"
|
||||
desc="Running tokenizer on dataset",
|
||||
)
|
||||
|
||||
dataset = dataset.map(
|
||||
preprocess_func,
|
||||
batched=True,
|
||||
remove_columns=column_names,
|
||||
**kwargs
|
||||
)
|
||||
dataset = dataset.map(preprocess_func, batched=True, remove_columns=column_names, **kwargs)
|
||||
|
||||
if data_args.cache_path is not None and not os.path.exists(data_args.cache_path):
|
||||
if training_args.should_save:
|
||||
|
||||
@@ -1,18 +1,18 @@
|
||||
import os
|
||||
import json
|
||||
from typing import TYPE_CHECKING, List, Literal, Optional
|
||||
import os
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, List, Literal, Optional
|
||||
|
||||
from ..extras.constants import DATA_CONFIG
|
||||
from ..extras.misc import use_modelscope
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..hparams import DataArguments
|
||||
|
||||
|
||||
@dataclass
|
||||
class DatasetAttr:
|
||||
|
||||
load_from: Literal["hf_hub", "ms_hub", "script", "file"]
|
||||
dataset_name: Optional[str] = None
|
||||
dataset_sha1: Optional[str] = None
|
||||
@@ -49,7 +49,9 @@ def get_dataset_list(data_args: "DataArguments") -> List["DatasetAttr"]:
|
||||
dataset_info = json.load(f)
|
||||
except Exception as err:
|
||||
if data_args.dataset is not None:
|
||||
raise ValueError("Cannot open {} due to {}.".format(os.path.join(data_args.dataset_dir, DATA_CONFIG), str(err)))
|
||||
raise ValueError(
|
||||
"Cannot open {} due to {}.".format(os.path.join(data_args.dataset_dir, DATA_CONFIG), str(err))
|
||||
)
|
||||
dataset_info = None
|
||||
|
||||
if data_args.interleave_probs is not None:
|
||||
@@ -74,7 +76,7 @@ def get_dataset_list(data_args: "DataArguments") -> List["DatasetAttr"]:
|
||||
dataset_attr = DatasetAttr(
|
||||
"file",
|
||||
dataset_name=dataset_info[name]["file_name"],
|
||||
dataset_sha1=dataset_info[name].get("file_sha1", None)
|
||||
dataset_sha1=dataset_info[name].get("file_sha1", None),
|
||||
)
|
||||
|
||||
dataset_attr.subset = dataset_info[name].get("subset", None)
|
||||
|
||||
@@ -5,6 +5,7 @@ from typing import TYPE_CHECKING, Any, Callable, Dict, List, Literal, Tuple
|
||||
from ..extras.constants import IGNORE_INDEX
|
||||
from ..extras.logging import get_logger
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from transformers import Seq2SeqTrainingArguments
|
||||
from transformers.tokenization_utils import PreTrainedTokenizer
|
||||
@@ -17,9 +18,7 @@ logger = get_logger(__name__)
|
||||
|
||||
|
||||
def preprocess_pretrain_dataset(
|
||||
examples: Dict[str, List[Any]],
|
||||
tokenizer: "PreTrainedTokenizer",
|
||||
data_args: "DataArguments"
|
||||
examples: Dict[str, List[Any]], tokenizer: "PreTrainedTokenizer", data_args: "DataArguments"
|
||||
) -> Dict[str, List[List[int]]]:
|
||||
# build grouped texts with format `X1 X2 X3 ...`
|
||||
text_examples = [examples["prompt"][i][0]["content"] for i in range(len(examples["prompt"]))]
|
||||
@@ -35,7 +34,7 @@ def preprocess_pretrain_dataset(
|
||||
total_length = (total_length // block_size) * block_size
|
||||
# split by chunks of cutoff_len
|
||||
result = {
|
||||
k: [t[i: i + block_size] for i in range(0, total_length, block_size)]
|
||||
k: [t[i : i + block_size] for i in range(0, total_length, block_size)]
|
||||
for k, t in concatenated_examples.items()
|
||||
}
|
||||
return result
|
||||
@@ -57,9 +56,11 @@ def preprocess_supervised_dataset(
|
||||
|
||||
messages = examples["prompt"][i] + examples["response"][i]
|
||||
input_ids, labels = [], []
|
||||
for turn_idx, (source_ids, target_ids) in enumerate(template.encode_multiturn(
|
||||
tokenizer, messages, examples["system"][i], examples["tools"][i], data_args.cutoff_len
|
||||
)):
|
||||
for turn_idx, (source_ids, target_ids) in enumerate(
|
||||
template.encode_multiturn(
|
||||
tokenizer, messages, examples["system"][i], examples["tools"][i], data_args.cutoff_len
|
||||
)
|
||||
):
|
||||
if data_args.train_on_prompt:
|
||||
source_mask = source_ids
|
||||
elif turn_idx != 0 and template.efficient_eos:
|
||||
@@ -96,9 +97,9 @@ def preprocess_packed_supervised_dataset(
|
||||
continue
|
||||
|
||||
messages = examples["prompt"][i] + examples["response"][i]
|
||||
for turn_idx, (source_ids, target_ids) in enumerate(template.encode_multiturn(
|
||||
tokenizer, messages, examples["system"][i], examples["tools"][i]
|
||||
)):
|
||||
for turn_idx, (source_ids, target_ids) in enumerate(
|
||||
template.encode_multiturn(tokenizer, messages, examples["system"][i], examples["tools"][i])
|
||||
):
|
||||
if data_args.train_on_prompt:
|
||||
source_mask = source_ids
|
||||
elif turn_idx != 0 and template.efficient_eos:
|
||||
@@ -119,9 +120,9 @@ def preprocess_packed_supervised_dataset(
|
||||
total_length = (total_length // block_size) * block_size
|
||||
# split by chunks of cutoff_len
|
||||
for i in range(0, total_length, block_size):
|
||||
model_inputs["input_ids"].append(input_ids[i: i + block_size])
|
||||
model_inputs["input_ids"].append(input_ids[i : i + block_size])
|
||||
model_inputs["attention_mask"].append([1] * block_size)
|
||||
model_inputs["labels"].append(labels[i: i + block_size])
|
||||
model_inputs["labels"].append(labels[i : i + block_size])
|
||||
|
||||
return model_inputs
|
||||
|
||||
@@ -191,9 +192,11 @@ def print_supervised_dataset_example(example: Dict[str, List[int]], tokenizer: "
|
||||
print("input_ids:\n{}".format(example["input_ids"]))
|
||||
print("inputs:\n{}".format(tokenizer.decode(example["input_ids"], skip_special_tokens=False)))
|
||||
print("label_ids:\n{}".format(example["labels"]))
|
||||
print("labels:\n{}".format(
|
||||
tokenizer.decode(list(filter(lambda x: x != IGNORE_INDEX, example["labels"])), skip_special_tokens=False)
|
||||
))
|
||||
print(
|
||||
"labels:\n{}".format(
|
||||
tokenizer.decode(list(filter(lambda x: x != IGNORE_INDEX, example["labels"])), skip_special_tokens=False)
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def print_pairwise_dataset_example(example: Dict[str, List[int]], tokenizer: "PreTrainedTokenizer") -> None:
|
||||
@@ -232,10 +235,14 @@ def get_preprocess_and_print_func(
|
||||
|
||||
print_function = partial(print_supervised_dataset_example, tokenizer=tokenizer)
|
||||
elif stage == "rm":
|
||||
preprocess_func = partial(preprocess_pairwise_dataset, tokenizer=tokenizer, template=template, data_args=data_args)
|
||||
preprocess_func = partial(
|
||||
preprocess_pairwise_dataset, tokenizer=tokenizer, template=template, data_args=data_args
|
||||
)
|
||||
print_function = partial(print_pairwise_dataset_example, tokenizer=tokenizer)
|
||||
else:
|
||||
preprocess_func = partial(preprocess_unsupervised_dataset, tokenizer=tokenizer, template=template, data_args=data_args)
|
||||
preprocess_func = partial(
|
||||
preprocess_unsupervised_dataset, tokenizer=tokenizer, template=template, data_args=data_args
|
||||
)
|
||||
print_function = partial(print_unsupervised_dataset_example, tokenizer=tokenizer)
|
||||
|
||||
return preprocess_func, print_function
|
||||
|
||||
@@ -2,8 +2,8 @@ from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Tuple, Union
|
||||
|
||||
from ..extras.logging import get_logger
|
||||
from .formatter import FunctionFormatter, StringFormatter, ToolFormatter
|
||||
from .utils import Role
|
||||
from .formatter import StringFormatter, FunctionFormatter, ToolFormatter
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -15,7 +15,6 @@ logger = get_logger(__name__)
|
||||
|
||||
@dataclass
|
||||
class Template:
|
||||
|
||||
format_user: Callable
|
||||
format_assistant: Callable
|
||||
format_system: Callable
|
||||
@@ -34,7 +33,7 @@ class Template:
|
||||
messages: List[Dict[str, str]],
|
||||
system: Optional[str] = None,
|
||||
tools: Optional[str] = None,
|
||||
cutoff_len: Optional[int] = 1_000_000
|
||||
cutoff_len: Optional[int] = 1_000_000,
|
||||
) -> Tuple[List[int], List[int]]:
|
||||
r"""
|
||||
Returns a single pair of token ids representing prompt and response respectively.
|
||||
@@ -53,7 +52,7 @@ class Template:
|
||||
messages: List[Dict[str, str]],
|
||||
system: str,
|
||||
tools: str,
|
||||
cutoff_len: Optional[int] = 1_000_000
|
||||
cutoff_len: Optional[int] = 1_000_000,
|
||||
) -> List[Tuple[List[int], List[int]]]:
|
||||
r"""
|
||||
Returns multiple pairs of token ids representing prompts and responses respectively.
|
||||
@@ -67,7 +66,7 @@ class Template:
|
||||
messages: List[Dict[str, str]],
|
||||
system: str,
|
||||
tools: str,
|
||||
cutoff_len: int
|
||||
cutoff_len: int,
|
||||
) -> List[Tuple[List[int], List[int]]]:
|
||||
r"""
|
||||
Encodes formatted inputs to pairs of token ids.
|
||||
@@ -102,19 +101,17 @@ class Template:
|
||||
if total_length >= cutoff_len:
|
||||
break
|
||||
|
||||
encoded_messages[i] = encoded_messages[i][:cutoff_len-total_length]
|
||||
encoded_messages[i] = encoded_messages[i][: cutoff_len - total_length]
|
||||
total_length += len(encoded_messages[i])
|
||||
|
||||
encoded_messages[i+1] = encoded_messages[i+1][:max(1, cutoff_len-total_length)]
|
||||
total_length += len(encoded_messages[i+1])
|
||||
encoded_pairs.append((encoded_messages[i], encoded_messages[i+1]))
|
||||
encoded_messages[i + 1] = encoded_messages[i + 1][: max(1, cutoff_len - total_length)]
|
||||
total_length += len(encoded_messages[i + 1])
|
||||
encoded_pairs.append((encoded_messages[i], encoded_messages[i + 1]))
|
||||
|
||||
return encoded_pairs
|
||||
|
||||
def _convert_elements_to_ids(
|
||||
self,
|
||||
tokenizer: "PreTrainedTokenizer",
|
||||
elements: List[Union[str, Dict[str, str]]]
|
||||
self, tokenizer: "PreTrainedTokenizer", elements: List[Union[str, Dict[str, str]]]
|
||||
) -> List[int]:
|
||||
r"""
|
||||
Converts elements to token ids.
|
||||
@@ -139,14 +136,13 @@ class Template:
|
||||
|
||||
@dataclass
|
||||
class Llama2Template(Template):
|
||||
|
||||
def _encode(
|
||||
self,
|
||||
tokenizer: "PreTrainedTokenizer",
|
||||
messages: List[Dict[str, str]],
|
||||
system: str,
|
||||
tools: str,
|
||||
cutoff_len: int
|
||||
cutoff_len: int,
|
||||
) -> List[Tuple[List[int], List[int]]]:
|
||||
r"""
|
||||
Encodes formatted inputs to pairs of token ids.
|
||||
@@ -182,12 +178,12 @@ class Llama2Template(Template):
|
||||
if total_length >= cutoff_len:
|
||||
break
|
||||
|
||||
encoded_messages[i] = encoded_messages[i][:cutoff_len-total_length]
|
||||
encoded_messages[i] = encoded_messages[i][: cutoff_len - total_length]
|
||||
total_length += len(encoded_messages[i])
|
||||
|
||||
encoded_messages[i+1] = encoded_messages[i+1][:max(1, cutoff_len-total_length)]
|
||||
total_length += len(encoded_messages[i+1])
|
||||
encoded_pairs.append((encoded_messages[i], encoded_messages[i+1]))
|
||||
encoded_messages[i + 1] = encoded_messages[i + 1][: max(1, cutoff_len - total_length)]
|
||||
total_length += len(encoded_messages[i + 1])
|
||||
encoded_pairs.append((encoded_messages[i], encoded_messages[i + 1]))
|
||||
|
||||
return encoded_pairs
|
||||
|
||||
@@ -207,32 +203,26 @@ def register_template(
|
||||
separator: Optional[List[Union[str, Dict[str, str]]]] = "",
|
||||
stop_words: Optional[List[str]] = [],
|
||||
efficient_eos: Optional[bool] = False,
|
||||
replace_eos: Optional[bool] = False
|
||||
replace_eos: Optional[bool] = False,
|
||||
) -> None:
|
||||
template_class = Llama2Template if name.startswith("llama2") else Template
|
||||
templates[name] = template_class(
|
||||
format_user=format_user or StringFormatter(container=["{{content}}"]),
|
||||
format_assistant=format_assistant or StringFormatter(container=[
|
||||
"{{content}}", {"eos_token"}
|
||||
]),
|
||||
format_assistant=format_assistant or StringFormatter(container=["{{content}}", {"eos_token"}]),
|
||||
format_system=format_system or StringFormatter(container=["{{content}}"]),
|
||||
format_tool=format_tool or ToolFormatter(type="default"),
|
||||
format_observation=format_observation or format_user,
|
||||
format_function=format_function or FunctionFormatter(container=[
|
||||
"Action: {{name}}\nAction Input: {{arguments}}", {"eos_token"}
|
||||
]),
|
||||
format_function=format_function
|
||||
or FunctionFormatter(container=["Action: {{name}}\nAction Input: {{arguments}}", {"eos_token"}]),
|
||||
system=system,
|
||||
separator=separator,
|
||||
stop_words=stop_words,
|
||||
efficient_eos=efficient_eos,
|
||||
replace_eos=replace_eos
|
||||
replace_eos=replace_eos,
|
||||
)
|
||||
|
||||
|
||||
def get_template_and_fix_tokenizer(
|
||||
name: str,
|
||||
tokenizer: "PreTrainedTokenizer"
|
||||
) -> Template:
|
||||
def get_template_and_fix_tokenizer(name: str, tokenizer: "PreTrainedTokenizer") -> Template:
|
||||
if tokenizer.eos_token_id is None:
|
||||
tokenizer.eos_token = "<|endoftext|>"
|
||||
logger.info("Add eos token: {}".format(tokenizer.eos_token))
|
||||
@@ -241,7 +231,7 @@ def get_template_and_fix_tokenizer(
|
||||
tokenizer.pad_token = tokenizer.eos_token
|
||||
logger.info("Add pad token: {}".format(tokenizer.pad_token))
|
||||
|
||||
if name is None: # for pre-training
|
||||
if name is None: # for pre-training
|
||||
return None
|
||||
|
||||
template = templates.get(name, None)
|
||||
@@ -258,8 +248,7 @@ def get_template_and_fix_tokenizer(
|
||||
|
||||
if stop_words:
|
||||
tokenizer.add_special_tokens(
|
||||
dict(additional_special_tokens=stop_words),
|
||||
replace_additional_special_tokens=False
|
||||
dict(additional_special_tokens=stop_words), replace_additional_special_tokens=False
|
||||
)
|
||||
logger.info("Add {} to stop words.".format(",".join(stop_words)))
|
||||
|
||||
@@ -268,263 +257,153 @@ def get_template_and_fix_tokenizer(
|
||||
|
||||
register_template(
|
||||
name="alpaca",
|
||||
format_user=StringFormatter(container=[
|
||||
"### Instruction:\n{{content}}\n\n### Response:\n"
|
||||
]),
|
||||
format_user=StringFormatter(container=["### Instruction:\n{{content}}\n\n### Response:\n"]),
|
||||
system=(
|
||||
"Below is an instruction that describes a task. "
|
||||
"Write a response that appropriately completes the request."
|
||||
"Below is an instruction that describes a task. " "Write a response that appropriately completes the request."
|
||||
),
|
||||
separator=[
|
||||
"\n\n"
|
||||
]
|
||||
separator=["\n\n"],
|
||||
)
|
||||
|
||||
|
||||
register_template(
|
||||
name="aquila",
|
||||
format_user=StringFormatter(container=[
|
||||
"Human: {{content}}###Assistant:"
|
||||
]),
|
||||
format_assistant=StringFormatter(container=[
|
||||
"{{content}}"
|
||||
]),
|
||||
format_user=StringFormatter(container=["Human: {{content}}###Assistant:"]),
|
||||
format_assistant=StringFormatter(container=["{{content}}"]),
|
||||
system=(
|
||||
"A chat between a curious human and an artificial intelligence assistant. "
|
||||
"The assistant gives helpful, detailed, and polite answers to the human's questions."
|
||||
),
|
||||
separator=[
|
||||
"###"
|
||||
],
|
||||
stop_words=[
|
||||
"</s>"
|
||||
],
|
||||
efficient_eos=True
|
||||
separator=["###"],
|
||||
stop_words=["</s>"],
|
||||
efficient_eos=True,
|
||||
)
|
||||
|
||||
|
||||
register_template(
|
||||
name="baichuan",
|
||||
format_user=StringFormatter(container=[
|
||||
{"token": "<reserved_102>"},
|
||||
"{{content}}",
|
||||
{"token": "<reserved_103>"}
|
||||
]),
|
||||
format_assistant=StringFormatter(container=[
|
||||
"{{content}}"
|
||||
]),
|
||||
efficient_eos=True
|
||||
format_user=StringFormatter(container=[{"token": "<reserved_102>"}, "{{content}}", {"token": "<reserved_103>"}]),
|
||||
format_assistant=StringFormatter(container=["{{content}}"]),
|
||||
efficient_eos=True,
|
||||
)
|
||||
|
||||
|
||||
register_template(
|
||||
name="baichuan2",
|
||||
format_user=StringFormatter(container=[
|
||||
{"token": "<reserved_106>"},
|
||||
"{{content}}",
|
||||
{"token": "<reserved_107>"}
|
||||
]),
|
||||
format_assistant=StringFormatter(container=[
|
||||
"{{content}}"
|
||||
]),
|
||||
efficient_eos=True
|
||||
format_user=StringFormatter(container=[{"token": "<reserved_106>"}, "{{content}}", {"token": "<reserved_107>"}]),
|
||||
format_assistant=StringFormatter(container=["{{content}}"]),
|
||||
efficient_eos=True,
|
||||
)
|
||||
|
||||
|
||||
register_template(
|
||||
name="belle",
|
||||
format_user=StringFormatter(container=[
|
||||
"Human: {{content}}\n\nBelle: "
|
||||
]),
|
||||
separator=[
|
||||
"\n\n"
|
||||
]
|
||||
name="belle", format_user=StringFormatter(container=["Human: {{content}}\n\nBelle: "]), separator=["\n\n"]
|
||||
)
|
||||
|
||||
|
||||
register_template(
|
||||
name="bluelm",
|
||||
format_user=StringFormatter(container=[
|
||||
{"token": "[|Human|]:"},
|
||||
"{{content}}",
|
||||
{"token": "[|AI|]:"}
|
||||
])
|
||||
format_user=StringFormatter(container=[{"token": "[|Human|]:"}, "{{content}}", {"token": "[|AI|]:"}]),
|
||||
)
|
||||
|
||||
|
||||
register_template(
|
||||
name="chatglm2",
|
||||
format_user=StringFormatter(container=[
|
||||
"[Round {{idx}}]\n\n问:{{content}}\n\n答:"
|
||||
]),
|
||||
format_assistant=StringFormatter(container=[
|
||||
"{{content}}"
|
||||
]),
|
||||
format_system=StringFormatter(container=[
|
||||
{"token": "[gMASK]"},
|
||||
{"token": "sop"},
|
||||
"{{content}}"
|
||||
]),
|
||||
separator=[
|
||||
"\n\n"
|
||||
],
|
||||
efficient_eos=True
|
||||
format_user=StringFormatter(container=["[Round {{idx}}]\n\n问:{{content}}\n\n答:"]),
|
||||
format_assistant=StringFormatter(container=["{{content}}"]),
|
||||
format_system=StringFormatter(container=[{"token": "[gMASK]"}, {"token": "sop"}, "{{content}}"]),
|
||||
separator=["\n\n"],
|
||||
efficient_eos=True,
|
||||
)
|
||||
|
||||
|
||||
register_template(
|
||||
name="chatglm3",
|
||||
format_user=StringFormatter(container=[
|
||||
{"token": "<|user|>"},
|
||||
"\n",
|
||||
"{{content}}",
|
||||
{"token": "<|assistant|>"}
|
||||
]),
|
||||
format_assistant=StringFormatter(container=[
|
||||
"\n"
|
||||
"{{content}}"
|
||||
]),
|
||||
format_system=StringFormatter(container=[
|
||||
{"token": "[gMASK]"},
|
||||
{"token": "sop"},
|
||||
{"token": "<|system|>"},
|
||||
"\n",
|
||||
"{{content}}"
|
||||
]),
|
||||
format_observation=StringFormatter(container=[
|
||||
{"token": "<|observation|>"},
|
||||
"\n",
|
||||
"{{content}}"
|
||||
]),
|
||||
format_function=FunctionFormatter(container=[
|
||||
"{{name}}\n{{arguments}}"
|
||||
]),
|
||||
format_user=StringFormatter(container=[{"token": "<|user|>"}, "\n", "{{content}}", {"token": "<|assistant|>"}]),
|
||||
format_assistant=StringFormatter(container=["\n" "{{content}}"]),
|
||||
format_system=StringFormatter(
|
||||
container=[{"token": "[gMASK]"}, {"token": "sop"}, {"token": "<|system|>"}, "\n", "{{content}}"]
|
||||
),
|
||||
format_observation=StringFormatter(container=[{"token": "<|observation|>"}, "\n", "{{content}}"]),
|
||||
format_function=FunctionFormatter(container=["{{name}}\n{{arguments}}"]),
|
||||
system=(
|
||||
"You are ChatGLM3, a large language model trained by Zhipu.AI. "
|
||||
"Follow the user's instructions carefully. Respond using markdown."
|
||||
),
|
||||
stop_words=[
|
||||
"<|user|>",
|
||||
"<|observation|>"
|
||||
],
|
||||
efficient_eos=True
|
||||
stop_words=["<|user|>", "<|observation|>"],
|
||||
efficient_eos=True,
|
||||
)
|
||||
|
||||
|
||||
register_template(
|
||||
name="codegeex2",
|
||||
format_system=StringFormatter(container=[
|
||||
{"token": "[gMASK]"},
|
||||
{"token": "sop"},
|
||||
"{{content}}"
|
||||
])
|
||||
name="codegeex2", format_system=StringFormatter(container=[{"token": "[gMASK]"}, {"token": "sop"}, "{{content}}"])
|
||||
)
|
||||
|
||||
|
||||
register_template(
|
||||
name="deepseek",
|
||||
format_user=StringFormatter(container=[
|
||||
"User: {{content}}\n\nAssistant:"
|
||||
])
|
||||
)
|
||||
register_template(name="deepseek", format_user=StringFormatter(container=["User: {{content}}\n\nAssistant:"]))
|
||||
|
||||
|
||||
register_template(
|
||||
name="deepseekcoder",
|
||||
format_user=StringFormatter(container=[
|
||||
"### Instruction:\n{{content}}\n### Response:\n"
|
||||
]),
|
||||
format_assistant=StringFormatter(container=[
|
||||
"{{content}}"
|
||||
]),
|
||||
format_user=StringFormatter(container=["### Instruction:\n{{content}}\n### Response:\n"]),
|
||||
format_assistant=StringFormatter(container=["{{content}}"]),
|
||||
system=(
|
||||
"You are an AI programming assistant, utilizing the Deepseek Coder model, "
|
||||
"developed by Deepseek Company, and you only answer questions related to computer science. "
|
||||
"For politically sensitive questions, security and privacy issues, "
|
||||
"and other non-computer science questions, you will refuse to answer\n"
|
||||
),
|
||||
separator=[
|
||||
"\n",
|
||||
{"token": "<|EOT|>"},
|
||||
"\n"
|
||||
],
|
||||
stop_words=[
|
||||
"<|EOT|>"
|
||||
],
|
||||
efficient_eos=True
|
||||
separator=["\n", {"token": "<|EOT|>"}, "\n"],
|
||||
stop_words=["<|EOT|>"],
|
||||
efficient_eos=True,
|
||||
)
|
||||
|
||||
|
||||
register_template(
|
||||
name="default",
|
||||
format_user=StringFormatter(container=[
|
||||
"Human: {{content}}\nAssistant: "
|
||||
]),
|
||||
format_user=StringFormatter(container=["Human: {{content}}\nAssistant: "]),
|
||||
system=(
|
||||
"A chat between a curious user and an artificial intelligence assistant. "
|
||||
"The assistant gives helpful, detailed, and polite answers to the user's questions.\n"
|
||||
),
|
||||
separator=[
|
||||
"\n"
|
||||
]
|
||||
separator=["\n"],
|
||||
)
|
||||
|
||||
|
||||
register_template(
|
||||
name="falcon",
|
||||
format_user=StringFormatter(container=[
|
||||
"User: {{content}}\nFalcon:"
|
||||
]),
|
||||
format_assistant=StringFormatter(container=[
|
||||
"{{content}}"
|
||||
]),
|
||||
separator=[
|
||||
"\n"
|
||||
],
|
||||
efficient_eos=True
|
||||
format_user=StringFormatter(container=["User: {{content}}\nFalcon:"]),
|
||||
format_assistant=StringFormatter(container=["{{content}}"]),
|
||||
separator=["\n"],
|
||||
efficient_eos=True,
|
||||
)
|
||||
|
||||
|
||||
register_template(
|
||||
name="intern",
|
||||
format_user=StringFormatter(container=[
|
||||
"<|User|>:{{content}}",
|
||||
{"token": "<eoh>"},
|
||||
"\n<|Bot|>:"
|
||||
]),
|
||||
format_assistant=StringFormatter(container=[
|
||||
"{{content}}"
|
||||
]),
|
||||
separator=[
|
||||
{"token": "<eoa>"},
|
||||
"\n"
|
||||
],
|
||||
stop_words=[
|
||||
"<eoa>"
|
||||
],
|
||||
efficient_eos=True
|
||||
format_user=StringFormatter(container=["<|User|>:{{content}}", {"token": "<eoh>"}, "\n<|Bot|>:"]),
|
||||
format_assistant=StringFormatter(container=["{{content}}"]),
|
||||
separator=[{"token": "<eoa>"}, "\n"],
|
||||
stop_words=["<eoa>"],
|
||||
efficient_eos=True,
|
||||
)
|
||||
|
||||
|
||||
register_template(
|
||||
name="intern2",
|
||||
format_user=StringFormatter(container=[
|
||||
{"token": "[UNUSED_TOKEN_146]"},
|
||||
"user\n{{content}}",
|
||||
{"token": "[UNUSED_TOKEN_145]"},
|
||||
"\n",
|
||||
{"token": "[UNUSED_TOKEN_146]"},
|
||||
"assistant\n"
|
||||
]),
|
||||
format_assistant=StringFormatter(container=[
|
||||
"{{content}}"
|
||||
]),
|
||||
format_system=StringFormatter(container=[
|
||||
{"token": "[UNUSED_TOKEN_146]"},
|
||||
"system\n{{content}}",
|
||||
{"token": "[UNUSED_TOKEN_145]"},
|
||||
"\n"
|
||||
]),
|
||||
format_user=StringFormatter(
|
||||
container=[
|
||||
{"token": "[UNUSED_TOKEN_146]"},
|
||||
"user\n{{content}}",
|
||||
{"token": "[UNUSED_TOKEN_145]"},
|
||||
"\n",
|
||||
{"token": "[UNUSED_TOKEN_146]"},
|
||||
"assistant\n",
|
||||
]
|
||||
),
|
||||
format_assistant=StringFormatter(container=["{{content}}"]),
|
||||
format_system=StringFormatter(
|
||||
container=[{"token": "[UNUSED_TOKEN_146]"}, "system\n{{content}}", {"token": "[UNUSED_TOKEN_145]"}, "\n"]
|
||||
),
|
||||
system=(
|
||||
"You are an AI assistant whose name is InternLM (书生·浦语).\n"
|
||||
"- InternLM (书生·浦语) is a conversational language model that is developed "
|
||||
@@ -532,14 +411,9 @@ register_template(
|
||||
"- InternLM (书生·浦语) can understand and communicate fluently in the language chosen "
|
||||
"by the user such as English and 中文."
|
||||
),
|
||||
separator=[
|
||||
{"token": "[UNUSED_TOKEN_145]"},
|
||||
"\n"
|
||||
],
|
||||
stop_words=[
|
||||
"[UNUSED_TOKEN_145]"
|
||||
],
|
||||
efficient_eos=True
|
||||
separator=[{"token": "[UNUSED_TOKEN_145]"}, "\n"],
|
||||
stop_words=["[UNUSED_TOKEN_145]"],
|
||||
efficient_eos=True,
|
||||
)
|
||||
|
||||
|
||||
@@ -556,7 +430,7 @@ register_template(
|
||||
"If a question does not make any sense, or is not factually coherent, "
|
||||
"explain why instead of answering something not correct. "
|
||||
"If you don't know the answer to a question, please don't share false information."
|
||||
)
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
@@ -564,142 +438,83 @@ register_template(
|
||||
name="llama2_zh",
|
||||
format_user=StringFormatter(container=["[INST] {{content}} [/INST]"]),
|
||||
format_system=StringFormatter(container=["<<SYS>>\n{{content}}\n<</SYS>>\n\n"]),
|
||||
system="You are a helpful assistant. 你是一个乐于助人的助手。"
|
||||
system="You are a helpful assistant. 你是一个乐于助人的助手。",
|
||||
)
|
||||
|
||||
|
||||
register_template(
|
||||
name="mistral",
|
||||
format_user=StringFormatter(container=["[INST] {{content}} [/INST]"])
|
||||
)
|
||||
register_template(name="mistral", format_user=StringFormatter(container=["[INST] {{content}} [/INST]"]))
|
||||
|
||||
|
||||
register_template(
|
||||
name="openchat",
|
||||
format_user=StringFormatter(container=[
|
||||
"GPT4 Correct User: {{content}}",
|
||||
{"token": "<|end_of_turn|>"},
|
||||
"GPT4 Correct Assistant:"
|
||||
]),
|
||||
format_assistant=StringFormatter(container=[
|
||||
"{{content}}"
|
||||
]),
|
||||
separator=[
|
||||
{"token": "<|end_of_turn|>"}
|
||||
],
|
||||
stop_words=[
|
||||
"<|end_of_turn|>"
|
||||
],
|
||||
efficient_eos=True
|
||||
format_user=StringFormatter(
|
||||
container=["GPT4 Correct User: {{content}}", {"token": "<|end_of_turn|>"}, "GPT4 Correct Assistant:"]
|
||||
),
|
||||
format_assistant=StringFormatter(container=["{{content}}"]),
|
||||
separator=[{"token": "<|end_of_turn|>"}],
|
||||
stop_words=["<|end_of_turn|>"],
|
||||
efficient_eos=True,
|
||||
)
|
||||
|
||||
|
||||
register_template(
|
||||
name="qwen",
|
||||
format_user=StringFormatter(container=[
|
||||
"<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"
|
||||
]),
|
||||
format_system=StringFormatter(container=[
|
||||
"<|im_start|>system\n{{content}}<|im_end|>\n"
|
||||
]),
|
||||
format_user=StringFormatter(container=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
|
||||
format_system=StringFormatter(container=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
|
||||
system="You are a helpful assistant.",
|
||||
separator=[
|
||||
"\n"
|
||||
],
|
||||
stop_words=[
|
||||
"<|im_end|>"
|
||||
],
|
||||
replace_eos=True
|
||||
separator=["\n"],
|
||||
stop_words=["<|im_end|>"],
|
||||
replace_eos=True,
|
||||
)
|
||||
|
||||
|
||||
register_template(
|
||||
name="solar",
|
||||
format_user=StringFormatter(container=[
|
||||
"### User:\n{{content}}\n\n### Assistant:\n"
|
||||
])
|
||||
)
|
||||
register_template(name="solar", format_user=StringFormatter(container=["### User:\n{{content}}\n\n### Assistant:\n"]))
|
||||
|
||||
|
||||
register_template(
|
||||
name="starchat",
|
||||
format_user=StringFormatter(container=[
|
||||
{"token": "<|user|>"},
|
||||
"\n{{content}}",
|
||||
{"token": "<|end|>"},
|
||||
"\n",
|
||||
{"token": "<|assistant|>"}
|
||||
]),
|
||||
format_assistant=StringFormatter(container=[
|
||||
"{{content}}"
|
||||
]),
|
||||
format_system=StringFormatter(container=[
|
||||
{"token": "<|system|>"},
|
||||
"\n{{content}}",
|
||||
{"token": "<|end|>"},
|
||||
"\n"
|
||||
]),
|
||||
separator=[
|
||||
{"token": "<|end|>"},
|
||||
"\n"
|
||||
],
|
||||
stop_words=[
|
||||
"<|end|>"
|
||||
],
|
||||
efficient_eos=True
|
||||
format_user=StringFormatter(
|
||||
container=[{"token": "<|user|>"}, "\n{{content}}", {"token": "<|end|>"}, "\n", {"token": "<|assistant|>"}]
|
||||
),
|
||||
format_assistant=StringFormatter(container=["{{content}}"]),
|
||||
format_system=StringFormatter(container=[{"token": "<|system|>"}, "\n{{content}}", {"token": "<|end|>"}, "\n"]),
|
||||
separator=[{"token": "<|end|>"}, "\n"],
|
||||
stop_words=["<|end|>"],
|
||||
efficient_eos=True,
|
||||
)
|
||||
|
||||
|
||||
register_template(
|
||||
name="vanilla"
|
||||
)
|
||||
register_template(name="vanilla")
|
||||
|
||||
|
||||
register_template(
|
||||
name="vicuna",
|
||||
format_user=StringFormatter(container=[
|
||||
"USER: {{content}} ASSISTANT:"
|
||||
]),
|
||||
format_user=StringFormatter(container=["USER: {{content}} ASSISTANT:"]),
|
||||
system=(
|
||||
"A chat between a curious user and an artificial intelligence assistant. "
|
||||
"The assistant gives helpful, detailed, and polite answers to the user's questions."
|
||||
)
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
register_template(
|
||||
name="xuanyuan",
|
||||
format_user=StringFormatter(container=[
|
||||
"Human: {{content}} Assistant:"
|
||||
]),
|
||||
format_user=StringFormatter(container=["Human: {{content}} Assistant:"]),
|
||||
system=(
|
||||
"以下是用户和人工智能助手之间的对话。用户以Human开头,人工智能助手以Assistant开头,"
|
||||
"会对人类提出的问题给出有帮助、高质量、详细和礼貌的回答,并且总是拒绝参与与不道德、"
|
||||
"不安全、有争议、政治敏感等相关的话题、问题和指示。\n"
|
||||
)
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
register_template(
|
||||
name="xverse",
|
||||
format_user=StringFormatter(container=[
|
||||
"Human: {{content}}\n\nAssistant: "
|
||||
])
|
||||
)
|
||||
register_template(name="xverse", format_user=StringFormatter(container=["Human: {{content}}\n\nAssistant: "]))
|
||||
|
||||
|
||||
register_template(
|
||||
name="yayi",
|
||||
format_user=StringFormatter(container=[
|
||||
{"token": "<|Human|>"},
|
||||
":\n{{content}}\n\n",
|
||||
{"token": "<|YaYi|>"},
|
||||
":"
|
||||
]),
|
||||
format_system=StringFormatter(container=[
|
||||
{"token": "<|System|>"},
|
||||
":\n{{content}}\n\n"
|
||||
]),
|
||||
format_user=StringFormatter(container=[{"token": "<|Human|>"}, ":\n{{content}}\n\n", {"token": "<|YaYi|>"}, ":"]),
|
||||
format_system=StringFormatter(container=[{"token": "<|System|>"}, ":\n{{content}}\n\n"]),
|
||||
system=(
|
||||
"You are a helpful, respectful and honest assistant named YaYi "
|
||||
"developed by Beijing Wenge Technology Co.,Ltd. "
|
||||
@@ -711,67 +526,43 @@ register_template(
|
||||
"explain why instead of answering something not correct. "
|
||||
"If you don't know the answer to a question, please don't share false information."
|
||||
),
|
||||
separator=[
|
||||
"\n\n"
|
||||
],
|
||||
stop_words=[
|
||||
"<|End|>"
|
||||
]
|
||||
separator=["\n\n"],
|
||||
stop_words=["<|End|>"],
|
||||
)
|
||||
|
||||
|
||||
register_template(
|
||||
name="yi",
|
||||
format_user=StringFormatter(container=[
|
||||
"<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"
|
||||
]),
|
||||
separator=[
|
||||
"\n"
|
||||
],
|
||||
stop_words=[
|
||||
"<|im_end|>"
|
||||
],
|
||||
replace_eos=True
|
||||
format_user=StringFormatter(container=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
|
||||
separator=["\n"],
|
||||
stop_words=["<|im_end|>"],
|
||||
replace_eos=True,
|
||||
)
|
||||
|
||||
|
||||
register_template(
|
||||
name="yuan",
|
||||
format_user=StringFormatter(container=[
|
||||
"{{content}}",
|
||||
{"token": "<sep>"}
|
||||
]),
|
||||
separator=[
|
||||
"\n"
|
||||
],
|
||||
stop_words=[
|
||||
"<eod>"
|
||||
],
|
||||
replace_eos=True
|
||||
format_user=StringFormatter(container=["{{content}}", {"token": "<sep>"}]),
|
||||
separator=["\n"],
|
||||
stop_words=["<eod>"],
|
||||
replace_eos=True,
|
||||
)
|
||||
|
||||
|
||||
register_template(
|
||||
name="zephyr",
|
||||
format_user=StringFormatter(container=[
|
||||
"<|user|>\n{{content}}</s><|assistant|>"
|
||||
]),
|
||||
format_system=StringFormatter(container=[
|
||||
"<|system|>\n{{content}}</s>",
|
||||
]),
|
||||
system="You are a friendly chatbot who always responds in the style of a pirate"
|
||||
format_user=StringFormatter(container=["<|user|>\n{{content}}</s><|assistant|>"]),
|
||||
format_system=StringFormatter(
|
||||
container=[
|
||||
"<|system|>\n{{content}}</s>",
|
||||
]
|
||||
),
|
||||
system="You are a friendly chatbot who always responds in the style of a pirate",
|
||||
)
|
||||
|
||||
|
||||
register_template(
|
||||
name="ziya",
|
||||
format_user=StringFormatter(container=[
|
||||
{"token": "<human>"},
|
||||
":{{content}}\n",
|
||||
{"token": "<bot>"},
|
||||
":"
|
||||
]),
|
||||
separator=[
|
||||
"\n"
|
||||
]
|
||||
format_user=StringFormatter(container=[{"token": "<human>"}, ":{{content}}\n", {"token": "<bot>"}, ":"]),
|
||||
separator=["\n"],
|
||||
)
|
||||
|
||||
@@ -4,9 +4,11 @@ from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
|
||||
|
||||
from ..extras.logging import get_logger
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from datasets import Dataset, IterableDataset
|
||||
from transformers import TrainingArguments
|
||||
|
||||
from llmtuner.hparams import DataArguments
|
||||
|
||||
|
||||
@@ -44,12 +46,10 @@ def infer_max_len(source_len: int, target_len: int, data_args: "DataArguments")
|
||||
|
||||
|
||||
def split_dataset(
|
||||
dataset: Union["Dataset", "IterableDataset"],
|
||||
data_args: "DataArguments",
|
||||
training_args: "TrainingArguments"
|
||||
dataset: Union["Dataset", "IterableDataset"], data_args: "DataArguments", training_args: "TrainingArguments"
|
||||
) -> Dict[str, "Dataset"]:
|
||||
if training_args.do_train:
|
||||
if data_args.val_size > 1e-6: # Split the dataset
|
||||
if data_args.val_size > 1e-6: # Split the dataset
|
||||
if data_args.streaming:
|
||||
val_set = dataset.take(int(data_args.val_size))
|
||||
train_set = dataset.skip(int(data_args.val_size))
|
||||
@@ -63,5 +63,5 @@ def split_dataset(
|
||||
if data_args.streaming:
|
||||
dataset = dataset.shuffle(buffer_size=data_args.buffer_size, seed=training_args.seed)
|
||||
return {"train_dataset": dataset}
|
||||
else: # do_eval or do_predict
|
||||
else: # do_eval or do_predict
|
||||
return {"eval_dataset": dataset}
|
||||
|
||||
Reference in New Issue
Block a user