support vllm

Former-commit-id: 889f6e910e654d8ec3922c2185042d737ffbf1c3
This commit is contained in:
hiyouga
2024-03-07 20:26:31 +08:00
parent 9a69cadab3
commit 056d2d956a
32 changed files with 752 additions and 316 deletions

View File

@@ -1,6 +1,6 @@
from .loader import get_dataset
from .template import get_template_and_fix_tokenizer, templates
from .template import Template, get_template_and_fix_tokenizer, templates
from .utils import Role, split_dataset
__all__ = ["get_dataset", "get_template_and_fix_tokenizer", "templates", "Role", "split_dataset"]
__all__ = ["get_dataset", "Template", "get_template_and_fix_tokenizer", "templates", "Role", "split_dataset"]

View File

@@ -2,7 +2,7 @@ import json
import re
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from typing import Any, Dict, List, Literal, Sequence, Set, Tuple, Union
from typing import Any, Dict, List, Literal, Optional, Sequence, Set, Tuple, Union
SLOTS = Sequence[Union[str, Set[str], Dict[str, str]]]
@@ -72,7 +72,7 @@ def default_tool_extractor(content: str) -> Union[str, Tuple[str, str]]:
@dataclass
class Formatter(ABC):
slots: SLOTS = field(default_factory=list)
tool_format: Literal["default"] = "default"
tool_format: Optional[Literal["default"]] = None
@abstractmethod
def apply(self, **kwargs) -> SLOTS: ...
@@ -83,12 +83,30 @@ class Formatter(ABC):
@dataclass
class EmptyFormatter(Formatter):
def __post_init__(self):
has_placeholder = False
for slot in filter(lambda s: isinstance(s, str), self.slots):
if re.search(r"\{\{[a-zA-Z_][a-zA-Z0-9_]*\}\}", slot):
has_placeholder = True
if has_placeholder:
raise ValueError("Empty formatter should not contain any placeholder.")
def apply(self, **kwargs) -> SLOTS:
return self.slots
@dataclass
class StringFormatter(Formatter):
def __post_init__(self):
has_placeholder = False
for slot in filter(lambda s: isinstance(s, str), self.slots):
if re.search(r"\{\{[a-zA-Z_][a-zA-Z0-9_]*\}\}", slot):
has_placeholder = True
if not has_placeholder:
raise ValueError("A placeholder is required in the string formatter.")
def apply(self, **kwargs) -> SLOTS:
elements = []
for slot in self.slots:
@@ -109,6 +127,17 @@ class StringFormatter(Formatter):
@dataclass
class FunctionFormatter(Formatter):
def __post_init__(self):
has_name, has_args = False, False
for slot in filter(lambda s: isinstance(s, str), self.slots):
if "{{name}}" in slot:
has_name = True
if "{{arguments}}" in slot:
has_args = True
if not has_name or not has_args:
raise ValueError("Name and arguments placeholders are required in the function formatter.")
def apply(self, **kwargs) -> SLOTS:
content = kwargs.pop("content")
try:
@@ -133,6 +162,10 @@ class FunctionFormatter(Formatter):
@dataclass
class ToolFormatter(Formatter):
def __post_init__(self):
if self.tool_format is None:
raise ValueError("Tool format was not found.")
def apply(self, **kwargs) -> SLOTS:
content = kwargs.pop("content")
try:

View File

@@ -44,7 +44,7 @@ def load_single_dataset(
elif dataset_attr.load_from == "file":
data_files = []
local_path: str = os.path.join(data_args.dataset_dir, dataset_attr.dataset_name)
local_path = os.path.join(data_args.dataset_dir, dataset_attr.dataset_name)
if os.path.isdir(local_path): # is directory
for file_name in os.listdir(local_path):
data_files.append(os.path.join(local_path, file_name))

View File

@@ -19,13 +19,13 @@ class DatasetAttr:
""" basic configs """
load_from: Literal["hf_hub", "ms_hub", "script", "file"]
dataset_name: Optional[str] = None
dataset_name: str
""" extra configs """
file_sha1: Optional[str] = None
subset: Optional[str] = None
folder: Optional[str] = None
ranking: Optional[bool] = False
formatting: Optional[Literal["alpaca", "sharegpt"]] = "alpaca"
ranking: bool = False
formatting: Literal["alpaca", "sharegpt"] = "alpaca"
""" columns """
system: Optional[str] = None
""" columns for the alpaca format """