add docstrings, refactor logger

Former-commit-id: c34e489d71f8f539028543ccf8ee92cecedd6276
This commit is contained in:
hiyouga
2024-09-08 00:56:56 +08:00
parent 93d4570a59
commit 7f71276ad8
30 changed files with 334 additions and 57 deletions

View File

@@ -49,6 +49,9 @@ class DatasetModule(TypedDict):
def merge_dataset(
all_datasets: List[Union["Dataset", "IterableDataset"]], data_args: "DataArguments", seed: int
) -> Union["Dataset", "IterableDataset"]:
r"""
Merges multiple datasets to a unified dataset.
"""
if len(all_datasets) == 1:
return all_datasets[0]
elif data_args.mix_strategy == "concat":
@@ -67,14 +70,16 @@ def merge_dataset(
stopping_strategy="first_exhausted" if data_args.mix_strategy.endswith("under") else "all_exhausted",
)
else:
raise ValueError("Unknown mixing strategy.")
raise ValueError("Unknown mixing strategy: {}.".format(data_args.mix_strategy))
def split_dataset(
dataset: Union["Dataset", "IterableDataset"], data_args: "DataArguments", seed: int
) -> "DatasetDict":
r"""
Splits the dataset and returns a dataset dict containing train set (required) and validation set (optional).
Splits the dataset and returns a dataset dict containing train set and validation set.
Supports both map dataset and iterable dataset.
"""
if data_args.streaming:
dataset = dataset.shuffle(buffer_size=data_args.buffer_size, seed=seed)

View File

@@ -16,21 +16,36 @@ import json
import re
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from typing import List, Optional, Tuple, Union
from typing import TYPE_CHECKING, List, Optional, Tuple, Union
from typing_extensions import override
from .data_utils import SLOTS
from .tool_utils import get_tool_utils
if TYPE_CHECKING:
from .tool_utils import FunctionCall
@dataclass
class Formatter(ABC):
slots: SLOTS = field(default_factory=list)
tool_format: Optional[str] = None
@abstractmethod
def apply(self, **kwargs) -> SLOTS: ...
def apply(self, **kwargs) -> SLOTS:
r"""
Forms a list of slots according to the inputs to encode.
"""
...
def extract(self, content: str) -> Union[str, List[Tuple[str, str]]]:
def extract(self, content: str) -> Union[str, List["FunctionCall"]]:
r"""
Extracts a list of function calls from the response message if using tools.
Each function call is a named tuple consisting of the function name and the function arguments.
"""
raise NotImplementedError
@@ -45,6 +60,7 @@ class EmptyFormatter(Formatter):
if has_placeholder:
raise ValueError("Empty formatter should not contain any placeholder.")
@override
def apply(self, **kwargs) -> SLOTS:
return self.slots
@@ -60,6 +76,7 @@ class StringFormatter(Formatter):
if not has_placeholder:
raise ValueError("A placeholder is required in the string formatter.")
@override
def apply(self, **kwargs) -> SLOTS:
elements = []
for slot in self.slots:
@@ -83,6 +100,7 @@ class FunctionFormatter(Formatter):
def __post_init__(self):
self.slots = get_tool_utils(self.tool_format).get_function_slots() + self.slots
@override
def apply(self, **kwargs) -> SLOTS:
content = kwargs.pop("content")
functions: List[Tuple[str, str]] = []
@@ -116,6 +134,7 @@ class ToolFormatter(Formatter):
def __post_init__(self):
self.tool_utils = get_tool_utils(self.tool_format)
@override
def apply(self, **kwargs) -> SLOTS:
content = kwargs.pop("content")
try:
@@ -124,5 +143,6 @@ class ToolFormatter(Formatter):
except json.JSONDecodeError:
return [""]
def extract(self, content: str) -> Union[str, List[Tuple[str, str]]]:
@override
def extract(self, content: str) -> Union[str, List["FunctionCall"]]:
return self.tool_utils.tool_extractor(content)

View File

@@ -48,6 +48,9 @@ def _load_single_dataset(
data_args: "DataArguments",
training_args: "Seq2SeqTrainingArguments",
) -> Union["Dataset", "IterableDataset"]:
r"""
Loads a single dataset and aligns it to the standard format.
"""
logger.info("Loading dataset {}...".format(dataset_attr))
data_path, data_name, data_dir, data_files = None, None, None, None
if dataset_attr.load_from in ["hf_hub", "ms_hub"]:
@@ -117,7 +120,7 @@ def _load_single_dataset(
if dataset_attr.num_samples is not None and not data_args.streaming:
target_num = dataset_attr.num_samples
indexes = np.random.permutation(len(dataset))[:target_num]
indexes = np.random.permutation(len(dataset))[:target_num] # all samples should be included
target_num -= len(indexes)
if target_num > 0:
expand_indexes = np.random.choice(len(dataset), target_num)
@@ -141,6 +144,9 @@ def _get_merged_dataset(
training_args: "Seq2SeqTrainingArguments",
stage: Literal["pt", "sft", "rm", "ppo", "kto"],
) -> Optional[Union["Dataset", "IterableDataset"]]:
r"""
Gets the merged datasets in the standard format.
"""
if dataset_names is None:
return None
@@ -164,6 +170,9 @@ def _get_preprocessed_dataset(
processor: Optional["ProcessorMixin"] = None,
is_eval: bool = False,
) -> Optional[Union["Dataset", "IterableDataset"]]:
r"""
Preprocesses the dataset, including format checking and tokenization.
"""
if dataset is None:
return None
@@ -209,6 +218,9 @@ def get_dataset(
tokenizer: "PreTrainedTokenizer",
processor: Optional["ProcessorMixin"] = None,
) -> "DatasetModule":
r"""
Gets the train dataset and optionally gets the evaluation dataset.
"""
# Load tokenized dataset
if data_args.tokenized_path is not None:
if has_tokenized_data(data_args.tokenized_path):

View File

@@ -3,6 +3,7 @@ from io import BytesIO
from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Tuple, TypedDict, Union
import numpy as np
from typing_extensions import override
from ..extras.constants import IGNORE_INDEX, IMAGE_PLACEHOLDER, VIDEO_PLACEHOLDER
from ..extras.packages import is_pillow_available, is_pyav_available
@@ -209,6 +210,7 @@ class BasePlugin:
class LlavaPlugin(BasePlugin):
@override
def process_messages(
self,
messages: Sequence[Dict[str, str]],
@@ -233,6 +235,7 @@ class LlavaPlugin(BasePlugin):
return messages
@override
def get_mm_inputs(
self,
images: Sequence["ImageInput"],
@@ -247,6 +250,7 @@ class LlavaPlugin(BasePlugin):
class PaliGemmaPlugin(BasePlugin):
@override
def process_messages(
self,
messages: Sequence[Dict[str, str]],
@@ -270,6 +274,7 @@ class PaliGemmaPlugin(BasePlugin):
return messages
@override
def process_token_ids(
self,
input_ids: List[int],
@@ -289,6 +294,7 @@ class PaliGemmaPlugin(BasePlugin):
return input_ids, labels
@override
def get_mm_inputs(
self,
images: Sequence["ImageInput"],
@@ -305,6 +311,7 @@ class PaliGemmaPlugin(BasePlugin):
class Qwen2vlPlugin(BasePlugin):
@override
def process_messages(
self,
messages: Sequence[Dict[str, str]],
@@ -359,6 +366,7 @@ class Qwen2vlPlugin(BasePlugin):
return messages
@override
def get_mm_inputs(
self,
images: Sequence["ImageInput"],

View File

@@ -16,6 +16,7 @@ from dataclasses import dataclass
from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Tuple, Union
from transformers.utils.versions import require_version
from typing_extensions import override
from ..extras.logging import get_logger
from .data_utils import Role
@@ -152,6 +153,7 @@ class Template:
@dataclass
class Llama2Template(Template):
@override
def _encode(
self,
tokenizer: "PreTrainedTokenizer",
@@ -195,7 +197,7 @@ class Llama2Template(Template):
return encoded_messages
TEMPLATES: Dict[str, Template] = {}
TEMPLATES: Dict[str, "Template"] = {}
def _register_template(
@@ -305,6 +307,9 @@ def _convert_slots_to_jinja(slots: "SLOTS", tokenizer: "PreTrainedTokenizer", pl
def _get_jinja_template(template: "Template", tokenizer: "PreTrainedTokenizer") -> str:
r"""
Returns the jinja template.
"""
jinja_template = ""
prefix = _convert_slots_to_jinja(template.format_prefix.apply(), tokenizer)
@@ -345,6 +350,9 @@ def _get_jinja_template(template: "Template", tokenizer: "PreTrainedTokenizer")
def get_template_and_fix_tokenizer(tokenizer: "PreTrainedTokenizer", data_args: "DataArguments") -> "Template":
r"""
Gets chat template and fixes the tokenizer.
"""
if data_args.template in ["llava", "paligemma", "qwen2_vl"]:
require_version(
"transformers>=4.45.0.dev0", "To fix: pip install git+https://github.com/huggingface/transformers.git"

View File

@@ -15,9 +15,12 @@
import json
import re
from abc import ABC, abstractmethod
from collections import namedtuple
from dataclasses import dataclass
from typing import Any, Dict, List, Tuple, Union
from typing_extensions import override
from .data_utils import SLOTS
@@ -38,26 +41,47 @@ GLM4_TOOL_PROMPT = (
)
FunctionCall = namedtuple("FunctionCall", ["name", "arguments"])
@dataclass
class ToolUtils(ABC):
@staticmethod
@abstractmethod
def get_function_slots() -> SLOTS: ...
"""
Base class for tool utilities.
"""
@staticmethod
@abstractmethod
def tool_formatter(tools: List[Dict[str, Any]]) -> str: ...
def get_function_slots() -> SLOTS:
r"""
Gets a list of slots corresponding to a single function call.
"""
...
@staticmethod
@abstractmethod
def tool_extractor(content: str) -> Union[str, List[Tuple[str, str]]]: ...
def tool_formatter(tools: List[Dict[str, Any]]) -> str:
r"""
Generates the system message describing all the available tools.
"""
...
@staticmethod
@abstractmethod
def tool_extractor(content: str) -> Union[str, List["FunctionCall"]]:
r"""
Extracts all the function calls from the response message.
"""
...
class DefaultToolUtils(ToolUtils):
@override
@staticmethod
def get_function_slots() -> SLOTS:
return ["Action: {{name}}\nAction Input: {{arguments}}\n"]
@override
@staticmethod
def tool_formatter(tools: List[Dict[str, Any]]) -> str:
tool_text = ""
@@ -91,8 +115,9 @@ class DefaultToolUtils(ToolUtils):
return DEFAULT_TOOL_PROMPT.format(tool_text=tool_text, tool_names=", ".join(tool_names))
@override
@staticmethod
def tool_extractor(content: str) -> Union[str, List[Tuple[str, str]]]:
def tool_extractor(content: str) -> Union[str, List["FunctionCall"]]:
regex = re.compile(r"Action:\s*([a-zA-Z0-9_]+)\s*Action Input:\s*(.+?)(?=\s*Action:|\s*$)", re.DOTALL)
action_match: List[Tuple[str, str]] = re.findall(regex, content)
if not action_match:
@@ -112,10 +137,12 @@ class DefaultToolUtils(ToolUtils):
class GLM4ToolUtils(ToolUtils):
@override
@staticmethod
def get_function_slots() -> SLOTS:
return ["{{name}}\n{{arguments}}"]
@override
@staticmethod
def tool_formatter(tools: List[Dict[str, Any]]) -> str:
tool_text = ""
@@ -126,8 +153,9 @@ class GLM4ToolUtils(ToolUtils):
return GLM4_TOOL_PROMPT.format(tool_text=tool_text)
@override
@staticmethod
def tool_extractor(content: str) -> Union[str, List[Tuple[str, str]]]:
def tool_extractor(content: str) -> Union[str, List["FunctionCall"]]:
if "\n" not in content:
return content