disentangle model from tuner and rename modules

Former-commit-id: 02cbf91e7e424f8379c1fed01b82a5f7a83b6947
hiyouga
2023-11-15 16:29:09 +08:00
parent 81530133ff
commit 09a4474e7f
57 changed files with 324 additions and 263 deletions

View File

@@ -2,12 +2,24 @@ from collections import defaultdict, OrderedDict
from typing import Dict, Optional
CHOICES = ["A", "B", "C", "D"]
DEFAULT_MODULE = defaultdict(str)
DEFAULT_TEMPLATE = defaultdict(str)
IGNORE_INDEX = -100
LAYERNORM_NAMES = {"norm", "ln"}
LOG_FILE_NAME = "trainer_log.jsonl"
METHODS = ["full", "freeze", "lora"]
SUBJECTS = ["Average", "STEM", "Social Sciences", "Humanities", "Other"]
SUPPORTED_MODELS = OrderedDict()
TRAINING_STAGES = {
"Supervised Fine-Tuning": "sft",
"Reward Modeling": "rm",
@@ -16,14 +28,6 @@ TRAINING_STAGES = {
"Pre-Training": "pt"
}
LAYERNORM_NAMES = {"norm", "ln"}
SUPPORTED_MODELS = OrderedDict()
DEFAULT_MODULE = defaultdict(str)
DEFAULT_TEMPLATE = defaultdict(str)
def register_model_group(
models: Dict[str, str],

View File

@@ -13,14 +13,13 @@ try:
is_torch_npu_available
)
_is_fp16_available = is_torch_npu_available() or is_torch_cuda_available()
_is_bf16_available = is_torch_bf16_gpu_available() or is_torch_bf16_cpu_available
_is_bf16_available = is_torch_bf16_gpu_available() or is_torch_bf16_cpu_available()
except ImportError:
_is_fp16_available = torch.cuda.is_available()
_is_bf16_available = torch.cuda.is_bf16_supported()
if TYPE_CHECKING:
from transformers import HfArgumentParser
from transformers.modeling_utils import PreTrainedModel
class AverageMeter:
@@ -65,6 +64,15 @@ def count_parameters(model: torch.nn.Module) -> Tuple[int, int]:
return trainable_params, all_param
def get_logits_processor() -> "LogitsProcessorList":
r"""
Gets logits processor that removes NaN and Inf logits.
"""
logits_processor = LogitsProcessorList()
logits_processor.append(InfNanRemoveLogitsProcessor())
return logits_processor
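For context, a minimal usage sketch (the checkpoint name is illustrative, not part of this diff): the processor list returned here is passed straight to `generate` so decoding is guarded against NaN/Inf logits.

from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")          # illustrative checkpoint
model = AutoModelForCausalLM.from_pretrained("gpt2")

inputs = tokenizer("The quick brown fox", return_tensors="pt")
outputs = model.generate(
    **inputs,
    max_new_tokens=16,
    logits_processor=get_logits_processor(),  # replaces NaN/Inf logits before sampling
)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))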
def infer_optim_dtype(model_dtype: torch.dtype) -> torch.dtype:
r"""
Infers the optimal dtype according to the model_dtype and device compatibility.
@@ -77,25 +85,6 @@ def infer_optim_dtype(model_dtype: torch.dtype) -> torch.dtype:
return torch.float32
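Only the float32 fallback of infer_optim_dtype is visible in this hunk; a sketch of the selection logic it implies, using the _is_bf16_available / _is_fp16_available flags defined above (my reconstruction, not part of the diff):

def infer_optim_dtype_sketch(model_dtype: torch.dtype) -> torch.dtype:
    # prefer bf16 when the checkpoint is stored in bf16 and the device supports it
    if _is_bf16_available and model_dtype == torch.bfloat16:
        return torch.bfloat16
    # otherwise fall back to fp16 when half precision is available
    elif _is_fp16_available:
        return torch.float16
    # last resort: full precision
    else:
        return torch.float32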
def get_logits_processor() -> "LogitsProcessorList":
r"""
Gets logits processor that removes NaN and Inf logits.
"""
logits_processor = LogitsProcessorList()
logits_processor.append(InfNanRemoveLogitsProcessor())
return logits_processor
def torch_gc() -> None:
r"""
Collects GPU memory.
"""
gc.collect()
if torch.cuda.is_available():
torch.cuda.empty_cache()
torch.cuda.ipc_collect()
def parse_args(parser: "HfArgumentParser", args: Optional[Dict[str, Any]] = None) -> Tuple[Any]:
if args is not None:
return parser.parse_dict(args)
@@ -107,26 +96,11 @@ def parse_args(parser: "HfArgumentParser", args: Optional[Dict[str, Any]] = None
return parser.parse_args_into_dataclasses()
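A usage sketch for parse_args (the dataclass is hypothetical, defined only for this example): a dict is routed through parse_dict, while args=None falls through to the command line.

from dataclasses import dataclass
from transformers import HfArgumentParser

@dataclass
class ExampleArguments:                     # hypothetical, for illustration only
    model_name_or_path: str = "gpt2"
    learning_rate: float = 5e-5

parser = HfArgumentParser((ExampleArguments,))
(example_args,) = parse_args(parser, {"learning_rate": 1e-4})   # programmatic call
# example_args.learning_rate == 1e-4; model_name_or_path keeps its default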
def dispatch_model(model: "PreTrainedModel") -> "PreTrainedModel":
def torch_gc() -> None:
r"""
Dispatches a pre-trained model to GPUs with balanced memory.
Borrowed from: https://github.com/huggingface/transformers/blob/v4.31.0/src/transformers/modeling_utils.py#L2803
Collects GPU memory.
"""
if getattr(model, "is_loaded_in_8bit", False) or getattr(model, "is_loaded_in_4bit", False): # do nothing
return model
if torch.cuda.device_count() > 1:
from accelerate import dispatch_model
from accelerate.utils import infer_auto_device_map, get_balanced_memory
if model._no_split_modules is None:
raise ValueError("The model class needs to implement the `_no_split_modules` attribute.")
kwargs = {"dtype": model.dtype, "no_split_module_classes": model._no_split_modules}
max_memory = get_balanced_memory(model, **kwargs)
# Make sure tied weights are tied before creating the device map.
model.tie_weights()
device_map = infer_auto_device_map(model, max_memory=max_memory, **kwargs)
return dispatch_model(model, device_map)
else:
return model.cuda()
gc.collect()
if torch.cuda.is_available():
torch.cuda.empty_cache()
torch.cuda.ipc_collect()
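A usage sketch for dispatch_model (checkpoint name illustrative): quantized models pass through untouched, multi-GPU hosts get a balanced device map via accelerate, and a single GPU simply receives .cuda().

from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("gpt2", torch_dtype=torch.float16)
model = dispatch_model(model)   # no-op for 8-bit/4-bit models; balanced map on multi-GPU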

View File

@@ -0,0 +1,55 @@
import importlib.metadata
import importlib.util
def is_package_available(name: str) -> bool:
return importlib.util.find_spec(name) is not None
def get_package_version(name: str) -> str:
try:
return importlib.metadata.version(name)
except Exception:
return "0.0.0"
_fastapi_available = is_package_available("fastapi")
_flash_attn2_available = is_package_available("flash_attn") and get_package_version("flash_attn").startswith("2")
_jieba_available = is_package_available("jieba")
_matplotlib_available = is_package_available("matplotlib")
_nltk_available = is_package_available("nltk")
_rouge_available = is_package_available("rouge_chinese")
_starlette_available = is_package_available("sse_starlette")
_uvicorn_available = is_package_available("uvicorn")
def is_fastapi_availble():
return _fastapi_available
def is_flash_attn2_available():
return _flash_attn2_available
def is_jieba_available():
return _jieba_available
def is_matplotlib_available():
return _matplotlib_available
def is_nltk_available():
return _nltk_available
def is_rouge_available():
return _rouge_available
def is_starlette_available():
return _starlette_available
def is_uvicorn_available():
return _uvicorn_available
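These flags are evaluated once at import time; callers are expected to guard optional imports with them. A sketch of the intended pattern (the helper name is hypothetical; the plotting module further down uses exactly this shape):

from llmtuner.extras.packages import is_matplotlib_available

if is_matplotlib_available():
    import matplotlib.pyplot as plt

def save_loss_curve(values, path="loss.png"):
    if not is_matplotlib_available():
        return                      # plotting is optional; skip silently when absent
    plt.figure()
    plt.plot(values)
    plt.savefig(path)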

View File

@@ -3,16 +3,19 @@ import torch
import torch.nn as nn
from typing import Optional, Tuple
from transformers.utils import logging
from transformers.models.llama.modeling_llama import LlamaAttention, apply_rotary_pos_emb, repeat_kv
is_flash_attn_2_available = False
from transformers.models.llama.modeling_llama import LlamaAttention, apply_rotary_pos_emb
try:
from transformers.models.llama.modeling_llama import repeat_kv
except ImportError:
print("Please upgrade `transformers`.")
from llmtuner.extras.packages import is_flash_attn2_available
if is_flash_attn2_available():
from flash_attn import flash_attn_func, flash_attn_varlen_func # type: ignore
from flash_attn.bert_padding import pad_input, unpad_input # type: ignore
is_flash_attn_2_available = True
except ImportError:
is_flash_attn_2_available = False
logger = logging.get_logger(__name__)

View File

@@ -1,11 +1,14 @@
import os
import math
import json
import matplotlib.pyplot as plt
from typing import List, Optional
from transformers.trainer import TRAINER_STATE_NAME
from llmtuner.extras.logging import get_logger
from llmtuner.extras.packages import is_matplotlib_available
if is_matplotlib_available():
import matplotlib.pyplot as plt
logger = get_logger(__name__)

View File

@@ -1,769 +0,0 @@
import tiktoken
from dataclasses import dataclass
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
from llmtuner.extras.logging import get_logger
if TYPE_CHECKING:
from transformers import PreTrainedTokenizer
logger = get_logger(__name__)
@dataclass
class Template:
prefix: List[Union[str, Dict[str, str]]]
prompt: List[Union[str, Dict[str, str]]]
system: str
sep: List[Union[str, Dict[str, str]]]
stop_words: List[str]
use_history: bool
efficient_eos: bool
def encode_oneturn(
self,
tokenizer: "PreTrainedTokenizer",
query: str,
resp: str,
history: Optional[List[Tuple[str, str]]] = None,
system: Optional[str] = None
) -> Tuple[List[int], List[int]]:
r"""
Returns a single pair of token ids representing prompt and response respectively.
"""
system, history = self._format(query, resp, history, system)
encoded_pairs = self._encode(tokenizer, system, history)
prompt_ids = []
for query_ids, resp_ids in encoded_pairs[:-1]:
prompt_ids = prompt_ids + query_ids + resp_ids
prompt_ids, answer_ids = prompt_ids + encoded_pairs[-1][0], encoded_pairs[-1][1]
return prompt_ids, answer_ids
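A usage sketch for encode_oneturn (the tokenizer checkpoint is a placeholder): the first element covers everything up to and including the last query, the second covers only the final response.

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")         # placeholder checkpoint
template = templates["default"]
prompt_ids, answer_ids = template.encode_oneturn(
    tokenizer=tokenizer,
    query="What is the capital of France?",
    resp="Paris.",
    history=[("Hi!", "Hello! How can I help you?")],      # earlier turns fold into prompt_ids
)
# prompt_ids: system prompt + previous turn + current query
# answer_ids: "Paris." followed by EOS (unless efficient_eos)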
def encode_multiturn(
self,
tokenizer: "PreTrainedTokenizer",
query: str,
resp: str,
history: Optional[List[Tuple[str, str]]] = None,
system: Optional[str] = None
) -> List[Tuple[List[int], List[int]]]:
r"""
Returns multiple pairs of token ids representing prompts and responses respectively.
"""
system, history = self._format(query, resp, history, system)
encoded_pairs = self._encode(tokenizer, system, history)
return encoded_pairs
def _format(
self,
query: str,
resp: str,
history: Optional[List[Tuple[str, str]]] = None,
system: Optional[str] = None
) -> Tuple[str, List[Tuple[str, str]]]:
r"""
Aligns inputs to the standard format.
"""
system = system or self.system # use system if provided
history = history if (history and self.use_history) else []
history = history + [(query, resp)]
return system, history
def _get_special_ids(
self,
tokenizer: "PreTrainedTokenizer"
) -> Tuple[List[int], List[int]]:
if tokenizer.bos_token_id is not None and getattr(tokenizer, "add_bos_token", True):
bos_ids = [tokenizer.bos_token_id]
else: # baichuan, qwen and gpt2 models have no bos token
bos_ids = []
if tokenizer.eos_token_id is None:
raise ValueError("EOS token is required.")
if self.efficient_eos: # used in baichuan, qwen, chatglm, etc.
eos_ids = []
else:
eos_ids = [tokenizer.eos_token_id]
return bos_ids, eos_ids
def _encode(
self,
tokenizer: "PreTrainedTokenizer",
system: str,
history: List[Tuple[str, str]]
) -> List[Tuple[List[int], List[int]]]:
r"""
Encodes formatted inputs to pairs of token ids.
Turn 0: bos + prefix + sep + query resp + eos
Turn t: sep + bos + query resp + eos
"""
bos_ids, eos_ids = self._get_special_ids(tokenizer)
sep_ids = self._convert_inputs_to_ids(tokenizer, context=self.sep)
encoded_pairs = []
for turn_idx, (query, resp) in enumerate(history):
if turn_idx == 0:
prefix_ids = self._convert_inputs_to_ids(tokenizer, context=self.prefix, system=system)
if len(prefix_ids) != 0: # has prefix
prefix_ids = bos_ids + prefix_ids + sep_ids
else:
prefix_ids = bos_ids
else:
prefix_ids = sep_ids + bos_ids
query_ids = self._convert_inputs_to_ids(tokenizer, context=self.prompt, query=query, idx=str(turn_idx))
resp_ids = self._convert_inputs_to_ids(tokenizer, context=[resp])
encoded_pairs.append((prefix_ids + query_ids, resp_ids + eos_ids))
return encoded_pairs
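A worked illustration of the layout described in the docstring, reusing the tokenizer and template from the sketch above: each pair is (prompt side, response side), and only turn 0 carries the prefix.

pairs = template._encode(
    tokenizer=tokenizer,
    system=template.system,
    history=[("Hi", "Hello!"), ("How are you?", "Fine, thanks.")],
)
# pairs[0][0] = bos + prefix(+system) + sep + "Human: Hi\nAssistant:"
# pairs[0][1] = "Hello!" + eos
# pairs[1][0] = sep + bos + "Human: How are you?\nAssistant:"
# pairs[1][1] = "Fine, thanks." + eos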
def _convert_inputs_to_ids(
self,
tokenizer: "PreTrainedTokenizer",
context: List[Union[str, Dict[str, str]]],
system: Optional[str] = None,
query: Optional[str] = None,
idx: Optional[str] = None
) -> List[int]:
r"""
Converts context to token ids.
"""
if isinstance(getattr(tokenizer, "tokenizer", None), tiktoken.Encoding): # for tiktoken tokenizer (Qwen)
kwargs = dict(allowed_special="all")
else:
kwargs = dict(add_special_tokens=False)
token_ids = []
for elem in context:
if isinstance(elem, str):
elem = elem.replace("{{system}}", system, 1) if system is not None else elem
elem = elem.replace("{{query}}", query, 1) if query is not None else elem
elem = elem.replace("{{idx}}", idx, 1) if idx is not None else elem
if len(elem) != 0:
token_ids = token_ids + tokenizer.encode(elem, **kwargs)
elif isinstance(elem, dict):
token_ids = token_ids + [tokenizer.convert_tokens_to_ids(elem.get("token"))]
else:
raise ValueError("Input must be string or dict[str, str], got {}".format(type(elem)))
return token_ids
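A small illustration of _convert_inputs_to_ids, reusing names from the sketches above: string elements get their placeholders filled and are tokenized without special tokens, while dict elements name a single special token.

ids = template._convert_inputs_to_ids(
    tokenizer,
    context=["Human: {{query}}\nAssistant:"],   # string: placeholder filled, then encoded
    query="Hello",
)
# equals tokenizer.encode("Human: Hello\nAssistant:", add_special_tokens=False);
# a dict element such as {"token": "<|user|>"} would instead map through convert_tokens_to_ids.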
@dataclass
class Llama2Template(Template):
def _encode(
self,
tokenizer: "PreTrainedTokenizer",
system: str,
history: List[Tuple[str, str]]
) -> List[Tuple[List[int], List[int]]]:
r"""
Encodes formatted inputs to pairs of token ids.
Turn 0: bos + prefix + query resp + eos
Turn t: bos + query resp + eos
"""
bos_ids, eos_ids = self._get_special_ids(tokenizer)
encoded_pairs = []
for turn_idx, (query, resp) in enumerate(history):
if turn_idx == 0: # llama2 template has no sep_ids
query = self.prefix[0].replace("{{system}}", system) + query
query_ids = self._convert_inputs_to_ids(tokenizer, context=self.prompt, query=query)
resp_ids = self._convert_inputs_to_ids(tokenizer, context=[resp])
encoded_pairs.append((bos_ids + query_ids, resp_ids + eos_ids))
return encoded_pairs
templates: Dict[str, Template] = {}
def register_template(
name: str,
prefix: List[Union[str, Dict[str, str]]],
prompt: List[Union[str, Dict[str, str]]],
system: str,
sep: List[Union[str, Dict[str, str]]],
stop_words: Optional[List[str]] = [],
use_history: Optional[bool] = True,
efficient_eos: Optional[bool] = False
) -> None:
template_class = Llama2Template if "llama2" in name else Template
templates[name] = template_class(
prefix=prefix,
prompt=prompt,
system=system,
sep=sep,
stop_words=stop_words,
use_history=use_history,
efficient_eos=efficient_eos
)
def get_template_and_fix_tokenizer(
name: str,
tokenizer: "PreTrainedTokenizer"
) -> Template:
if tokenizer.eos_token_id is None:
tokenizer.eos_token = "<|endoftext|>"
logger.info("Add eos token: {}".format(tokenizer.eos_token))
if tokenizer.pad_token_id is None:
tokenizer.pad_token = tokenizer.eos_token
logger.info("Add pad token: {}".format(tokenizer.pad_token))
if name is None:
return None
template = templates.get(name, None)
assert template is not None, "Template {} does not exist.".format(name)
tokenizer.add_special_tokens(
dict(additional_special_tokens=template.stop_words),
replace_additional_special_tokens=False
)
return template
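The typical entry point combines both helpers (checkpoint name illustrative): fetch the template by name, let it patch the tokenizer, then encode.

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen-7B-Chat", trust_remote_code=True)  # illustrative
template = get_template_and_fix_tokenizer("qwen", tokenizer)
# side effects: adds an EOS token if missing, reuses it as PAD if needed,
# and registers the template's stop words as additional special tokens
prompt_ids, answer_ids = template.encode_oneturn(
    tokenizer, query="Hello", resp="Hi! How can I help you today?"
)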
r"""
Supports: https://huggingface.co/tatsu-lab/alpaca-7b-wdiff
"""
register_template(
name="alpaca",
prefix=[
"{{system}}"
],
prompt=[
"### Instruction:\n{{query}}\n\n### Response:\n"
],
system=(
"Below is an instruction that describes a task. "
"Write a response that appropriately completes the request."
),
sep=[
"\n\n"
]
)
r"""
Supports: https://huggingface.co/BAAI/AquilaChat-7B
https://huggingface.co/BAAI/AquilaChat2-7B
https://huggingface.co/BAAI/AquilaChat2-34B
"""
register_template(
name="aquila",
prefix=[
"{{system}}"
],
prompt=[
"Human: {{query}}###Assistant:"
],
system=(
"A chat between a curious human and an artificial intelligence assistant. "
"The assistant gives helpful, detailed, and polite answers to the human's questions."
),
sep=[
"###"
],
stop_words=[
"</s>"
],
efficient_eos=True
)
r"""
Supports: https://huggingface.co/baichuan-inc/Baichuan-13B-Chat
"""
register_template(
name="baichuan",
prefix=[
"{{system}}"
],
prompt=[
{"token": "<reserved_102>"}, # user token
"{{query}}",
{"token": "<reserved_103>"} # assistant token
],
system="",
sep=[],
efficient_eos=True
)
r"""
Supports: https://huggingface.co/baichuan-inc/Baichuan2-7B-Chat
https://huggingface.co/baichuan-inc/Baichuan2-13B-Chat
"""
register_template(
name="baichuan2",
prefix=[
"{{system}}"
],
prompt=[
{"token": "<reserved_106>"}, # user token
"{{query}}",
{"token": "<reserved_107>"} # assistant token
],
system="",
sep=[],
efficient_eos=True
)
r"""
Supports: https://huggingface.co/BelleGroup/BELLE-LLaMA-EXT-13B
"""
register_template(
name="belle",
prefix=[
"{{system}}"
],
prompt=[
"Human: {{query}}\n\nBelle: "
],
system="",
sep=[
"\n\n"
]
)
r"""
Supports: https://huggingface.co/vivo-ai/BlueLM-7B-Chat
"""
register_template(
name="bluelm",
prefix=[
"{{system}}"
],
prompt=[
{"token": "[|Human|]:"},
"{{query}}",
{"token": "[|AI|]:"}
],
system="",
sep=[]
)
r"""
Supports: https://huggingface.co/THUDM/chatglm2-6b
"""
register_template(
name="chatglm2",
prefix=[
{"token": "[gMASK]"},
{"token": "sop"},
"{{system}}"
],
prompt=[
"[Round {{idx}}]\n\n问:{{query}}\n\n答:"
],
system="",
sep=[
"\n\n"
],
efficient_eos=True
)
r"""
Supports: https://huggingface.co/THUDM/chatglm3-6b
"""
register_template(
name="chatglm3",
prefix=[
{"token": "[gMASK]"},
{"token": "sop"},
"{{system}}"
],
prompt=[
{"token": "<|user|>"},
"\n",
"{{query}}",
{"token": "<|assistant|>"}
],
system="",
sep=[],
stop_words=[
"<|user|>",
"<|observation|>"
],
efficient_eos=True
)
r"""
Supports: https://huggingface.co/deepseek-ai/deepseek-coder-1.3b-instruct
https://huggingface.co/deepseek-ai/deepseek-coder-6.7b-instruct
https://huggingface.co/deepseek-ai/deepseek-coder-33b-instruct
"""
register_template(
name="deepseek",
prefix=[
"{{system}}"
],
prompt=[
"### Instruction:\n{{query}}\n\n### Response:\n"
],
system=(
"You are an AI programming assistant, utilizing the Deepseek Coder model, "
"developed by Deepseek Company, and you only answer questions related to computer science. "
"For politically sensitive questions, security and privacy issues, "
"and other non-computer science questions, you will refuse to answer."
),
sep=[
"\n",
{"token": "<|EOT|>"},
"\n\n"
],
stop_words=[
"<|EOT|>"
],
efficient_eos=True
)
r"""
Default template.
"""
register_template(
name="default",
prefix=[
"{{system}}"
],
prompt=[
"Human: {{query}}\nAssistant:"
],
system=(
"A chat between a curious user and an artificial intelligence assistant. "
"The assistant gives helpful, detailed, and polite answers to the user's questions."
),
sep=[
"\n"
]
)
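For orientation, a sketch of the plain-text form of one turn under the "default" template, assembled by hand from the fields above (before tokenization and special tokens):

# Single turn: "{system}" + "\n" + "Human: {query}\nAssistant:" + "{response}" + EOS
# Later turns repeat the separator "\n" before the next "Human: ...\nAssistant:" block,
# since use_history defaults to True for this template.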
r"""
Supports: https://huggingface.co/tiiuae/falcon-180B-chat
"""
register_template(
name="falcon",
prefix=[
"{{system}}"
],
prompt=[
"User: {{query}}\nFalcon:"
],
system="",
sep=[
"\n"
],
efficient_eos=True
)
r"""
Supports: https://huggingface.co/internlm/internlm-chat-7b
https://huggingface.co/internlm/internlm-chat-20b
"""
register_template(
name="intern",
prefix=[
"{{system}}"
],
prompt=[
"<|User|>:{{query}}",
{"token": "<eoh>"},
"\n<|Bot|>:"
],
system="",
sep=[
{"token": "<eoa>"},
"\n"
],
stop_words=[
"<eoa>"
],
efficient_eos=True
)
r"""
Supports: https://huggingface.co/meta-llama/Llama-2-7b-chat-hf
https://huggingface.co/meta-llama/Llama-2-13b-chat-hf
https://huggingface.co/meta-llama/Llama-2-70b-chat-hf
"""
register_template(
name="llama2",
prefix=[
"<<SYS>>\n{{system}}\n<</SYS>>\n\n"
],
prompt=[
"[INST] {{query}} [/INST]"
],
system=(
"You are a helpful, respectful and honest assistant. "
"Always answer as helpfully as possible, while being safe. "
"Your answers should not include any harmful, unethical, "
"racist, sexist, toxic, dangerous, or illegal content. "
"Please ensure that your responses are socially unbiased and positive in nature.\n\n"
"If a question does not make any sense, or is not factually coherent, "
"explain why instead of answering something not correct. "
"If you don't know the answer to a question, please don't share false information."
),
sep=[]
)
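The llama2 template is handled by Llama2Template above: there is no separator, and the system prompt is folded into the first user message. Schematically (a sketch of the rendered text per turn):

# turn 0: BOS + "[INST] <<SYS>>\n{system}\n<</SYS>>\n\n{query} [/INST]" + "{response}" + EOS
# turn t: BOS + "[INST] {query} [/INST]" + "{response}" + EOS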
r"""
Supports: https://huggingface.co/ziqingyang/chinese-alpaca-2-7b
https://huggingface.co/ziqingyang/chinese-alpaca-2-13b
"""
register_template(
name="llama2_zh",
prefix=[
"<<SYS>>\n{{system}}\n<</SYS>>\n\n"
],
prompt=[
"[INST] {{query}} [/INST]"
],
system="You are a helpful assistant. 你是一个乐于助人的助手。",
sep=[]
)
r"""
Supports: https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.1
"""
register_template(
name="mistral",
prefix=[
"{{system}}"
],
prompt=[
"[INST] {{query}} [/INST]"
],
system="",
sep=[]
)
r"""
Supports: https://huggingface.co/openchat/openchat_3.5
"""
register_template(
name="openchat",
prefix=[
"{{system}}"
],
prompt=[
"GPT4 Correct User: {{query}}",
{"token": "<|end_of_turn|>"},
"GPT4 Correct Assistant:"
],
system="",
sep=[
{"token": "<|end_of_turn|>"}
],
stop_words=[
"<|end_of_turn|>"
],
efficient_eos=True
)
r"""
Supports: https://huggingface.co/Qwen/Qwen-7B-Chat
https://huggingface.co/Qwen/Qwen-14B-Chat
"""
register_template(
name="qwen",
prefix=[
{"token": "<|im_start|>"},
"system\n{{system}}"
],
prompt=[
{"token": "<|im_start|>"},
"user\n{{query}}",
{"token": "<|im_end|>"},
"\n",
{"token": "<|im_start|>"},
"assistant\n"
],
system="You are a helpful assistant.",
sep=[
{"token": "<|im_end|>"},
"\n"
],
stop_words=[
"<|im_end|>"
],
efficient_eos=True
)
r"""
Supports: https://huggingface.co/HuggingFaceH4/starchat-alpha
https://huggingface.co/HuggingFaceH4/starchat-beta
"""
register_template(
name="starchat",
prefix=[
{"token": "<|system|>"},
"\n{{system}}",
],
prompt=[
{"token": "<|user|>"},
"\n{{query}}",
{"token": "<|end|>"},
"\n",
{"token": "<|assistant|>"}
],
system="",
sep=[
{"token": "<|end|>"},
"\n"
],
stop_words=[
"<|end|>"
],
efficient_eos=True
)
r"""
Supports language model inference without histories.
"""
register_template(
name="vanilla",
prefix=[],
prompt=[
"{{query}}"
],
system="",
sep=[],
use_history=False
)
r"""
Supports: https://huggingface.co/lmsys/vicuna-7b-v1.5
https://huggingface.co/lmsys/vicuna-13b-v1.5
"""
register_template(
name="vicuna",
prefix=[
"{{system}}"
],
prompt=[
"USER: {{query}} ASSISTANT:"
],
system=(
"A chat between a curious user and an artificial intelligence assistant. "
"The assistant gives helpful, detailed, and polite answers to the user's questions."
),
sep=[]
)
r"""
Supports: https://huggingface.co/xverse/XVERSE-7B-Chat
https://huggingface.co/xverse/XVERSE-13B-Chat
"""
register_template(
name="xverse",
prefix=[
"{{system}}"
],
prompt=[
"Human: {{query}}\n\nAssistant: "
],
system="",
sep=[]
)
r"""
Supports: https://huggingface.co/wenge-research/yayi-7b
https://huggingface.co/wenge-research/yayi-7b-llama2
https://huggingface.co/wenge-research/yayi-13b-llama2
"""
register_template(
name="yayi",
prefix=[
{"token": "<|System|>"},
":\n{{system}}"
],
prompt=[
{"token": "<|Human|>"},
":\n{{query}}\n\n",
{"token": "<|YaYi|>"},
":"
],
system=(
"You are a helpful, respectful and honest assistant named YaYi "
"developed by Beijing Wenge Technology Co.,Ltd. "
"Always answer as helpfully as possible, while being safe. "
"Your answers should not include any harmful, unethical, "
"racist, sexist, toxic, dangerous, or illegal content. "
"Please ensure that your responses are socially unbiased and positive in nature.\n\n"
"If a question does not make any sense, or is not factually coherent, "
"explain why instead of answering something not correct. "
"If you don't know the answer to a question, please don't share false information."
),
sep=[
"\n\n"
],
stop_words=[
"<|End|>"
]
)
r"""
Supports: https://huggingface.co/HuggingFaceH4/zephyr-7b-alpha
https://huggingface.co/HuggingFaceH4/zephyr-7b-beta
"""
register_template(
name="zephyr",
prefix=[
{"token": "<|system|>"},
"\n{{system}}",
{"token": "</s>"}
],
prompt=[
{"token": "<|user|>"},
"\n{{query}}",
{"token": "</s>"},
{"token": "<|assistant|>"}
],
system="You are a friendly chatbot who always responds in the style of a pirate",
sep=[]
)
r"""
Supports: https://huggingface.co/IDEA-CCNL/Ziya-LLaMA-13B-v1
https://huggingface.co/IDEA-CCNL/Ziya-LLaMA-13B-v1.1
https://huggingface.co/IDEA-CCNL/Ziya2-13B-Chat
"""
register_template(
name="ziya",
prefix=[
"{{system}}"
],
prompt=[
{"token": "<human>"},
":{{query}}\n",
{"token": "<bot>"},
":"
],
system="",
sep=[
"\n"
]
)