support vllm

Former-commit-id: 889f6e910e654d8ec3922c2185042d737ffbf1c3
This commit is contained in:
hiyouga
2024-03-07 20:26:31 +08:00
parent 9a69cadab3
commit 056d2d956a
32 changed files with 752 additions and 316 deletions

View File

@@ -2,7 +2,7 @@ import json
import re
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from typing import Any, Dict, List, Literal, Sequence, Set, Tuple, Union
from typing import Any, Dict, List, Literal, Optional, Sequence, Set, Tuple, Union
SLOTS = Sequence[Union[str, Set[str], Dict[str, str]]]
@@ -72,7 +72,7 @@ def default_tool_extractor(content: str) -> Union[str, Tuple[str, str]]:
@dataclass
class Formatter(ABC):
slots: SLOTS = field(default_factory=list)
tool_format: Literal["default"] = "default"
tool_format: Optional[Literal["default"]] = None
@abstractmethod
def apply(self, **kwargs) -> SLOTS: ...
@@ -83,12 +83,30 @@ class Formatter(ABC):
@dataclass
class EmptyFormatter(Formatter):
def __post_init__(self):
has_placeholder = False
for slot in filter(lambda s: isinstance(s, str), self.slots):
if re.search(r"\{\{[a-zA-Z_][a-zA-Z0-9_]*\}\}", slot):
has_placeholder = True
if has_placeholder:
raise ValueError("Empty formatter should not contain any placeholder.")
def apply(self, **kwargs) -> SLOTS:
return self.slots
@dataclass
class StringFormatter(Formatter):
def __post_init__(self):
has_placeholder = False
for slot in filter(lambda s: isinstance(s, str), self.slots):
if re.search(r"\{\{[a-zA-Z_][a-zA-Z0-9_]*\}\}", slot):
has_placeholder = True
if not has_placeholder:
raise ValueError("A placeholder is required in the string formatter.")
def apply(self, **kwargs) -> SLOTS:
elements = []
for slot in self.slots:
@@ -109,6 +127,17 @@ class StringFormatter(Formatter):
@dataclass
class FunctionFormatter(Formatter):
def __post_init__(self):
has_name, has_args = False, False
for slot in filter(lambda s: isinstance(s, str), self.slots):
if "{{name}}" in slot:
has_name = True
if "{{arguments}}" in slot:
has_args = True
if not has_name or not has_args:
raise ValueError("Name and arguments placeholders are required in the function formatter.")
def apply(self, **kwargs) -> SLOTS:
content = kwargs.pop("content")
try:
@@ -133,6 +162,10 @@ class FunctionFormatter(Formatter):
@dataclass
class ToolFormatter(Formatter):
def __post_init__(self):
if self.tool_format is None:
raise ValueError("Tool format was not found.")
def apply(self, **kwargs) -> SLOTS:
content = kwargs.pop("content")
try: