add docstrings, refactor logger
Former-commit-id: c34e489d71f8f539028543ccf8ee92cecedd6276
This commit is contained in:
@@ -49,6 +49,9 @@ class DatasetModule(TypedDict):
|
||||
def merge_dataset(
|
||||
all_datasets: List[Union["Dataset", "IterableDataset"]], data_args: "DataArguments", seed: int
|
||||
) -> Union["Dataset", "IterableDataset"]:
|
||||
r"""
|
||||
Merges multiple datasets to a unified dataset.
|
||||
"""
|
||||
if len(all_datasets) == 1:
|
||||
return all_datasets[0]
|
||||
elif data_args.mix_strategy == "concat":
|
||||
@@ -67,14 +70,16 @@ def merge_dataset(
|
||||
stopping_strategy="first_exhausted" if data_args.mix_strategy.endswith("under") else "all_exhausted",
|
||||
)
|
||||
else:
|
||||
raise ValueError("Unknown mixing strategy.")
|
||||
raise ValueError("Unknown mixing strategy: {}.".format(data_args.mix_strategy))
|
||||
|
||||
|
||||
def split_dataset(
|
||||
dataset: Union["Dataset", "IterableDataset"], data_args: "DataArguments", seed: int
|
||||
) -> "DatasetDict":
|
||||
r"""
|
||||
Splits the dataset and returns a dataset dict containing train set (required) and validation set (optional).
|
||||
Splits the dataset and returns a dataset dict containing train set and validation set.
|
||||
|
||||
Supports both map dataset and iterable dataset.
|
||||
"""
|
||||
if data_args.streaming:
|
||||
dataset = dataset.shuffle(buffer_size=data_args.buffer_size, seed=seed)
|
||||
|
||||
@@ -16,21 +16,36 @@ import json
|
||||
import re
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass, field
|
||||
from typing import List, Optional, Tuple, Union
|
||||
from typing import TYPE_CHECKING, List, Optional, Tuple, Union
|
||||
|
||||
from typing_extensions import override
|
||||
|
||||
from .data_utils import SLOTS
|
||||
from .tool_utils import get_tool_utils
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .tool_utils import FunctionCall
|
||||
|
||||
|
||||
@dataclass
|
||||
class Formatter(ABC):
|
||||
slots: SLOTS = field(default_factory=list)
|
||||
tool_format: Optional[str] = None
|
||||
|
||||
@abstractmethod
|
||||
def apply(self, **kwargs) -> SLOTS: ...
|
||||
def apply(self, **kwargs) -> SLOTS:
|
||||
r"""
|
||||
Forms a list of slots according to the inputs to encode.
|
||||
"""
|
||||
...
|
||||
|
||||
def extract(self, content: str) -> Union[str, List[Tuple[str, str]]]:
|
||||
def extract(self, content: str) -> Union[str, List["FunctionCall"]]:
|
||||
r"""
|
||||
Extracts a list of tuples from the response message if using tools.
|
||||
|
||||
Each tuple consists of function name and function arguments.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
@@ -45,6 +60,7 @@ class EmptyFormatter(Formatter):
|
||||
if has_placeholder:
|
||||
raise ValueError("Empty formatter should not contain any placeholder.")
|
||||
|
||||
@override
|
||||
def apply(self, **kwargs) -> SLOTS:
|
||||
return self.slots
|
||||
|
||||
@@ -60,6 +76,7 @@ class StringFormatter(Formatter):
|
||||
if not has_placeholder:
|
||||
raise ValueError("A placeholder is required in the string formatter.")
|
||||
|
||||
@override
|
||||
def apply(self, **kwargs) -> SLOTS:
|
||||
elements = []
|
||||
for slot in self.slots:
|
||||
@@ -83,6 +100,7 @@ class FunctionFormatter(Formatter):
|
||||
def __post_init__(self):
|
||||
self.slots = get_tool_utils(self.tool_format).get_function_slots() + self.slots
|
||||
|
||||
@override
|
||||
def apply(self, **kwargs) -> SLOTS:
|
||||
content = kwargs.pop("content")
|
||||
functions: List[Tuple[str, str]] = []
|
||||
@@ -116,6 +134,7 @@ class ToolFormatter(Formatter):
|
||||
def __post_init__(self):
|
||||
self.tool_utils = get_tool_utils(self.tool_format)
|
||||
|
||||
@override
|
||||
def apply(self, **kwargs) -> SLOTS:
|
||||
content = kwargs.pop("content")
|
||||
try:
|
||||
@@ -124,5 +143,6 @@ class ToolFormatter(Formatter):
|
||||
except json.JSONDecodeError:
|
||||
return [""]
|
||||
|
||||
def extract(self, content: str) -> Union[str, List[Tuple[str, str]]]:
|
||||
@override
|
||||
def extract(self, content: str) -> Union[str, List["FunctionCall"]]:
|
||||
return self.tool_utils.tool_extractor(content)
|
||||
|
||||
@@ -48,6 +48,9 @@ def _load_single_dataset(
|
||||
data_args: "DataArguments",
|
||||
training_args: "Seq2SeqTrainingArguments",
|
||||
) -> Union["Dataset", "IterableDataset"]:
|
||||
r"""
|
||||
Loads a single dataset and aligns it to the standard format.
|
||||
"""
|
||||
logger.info("Loading dataset {}...".format(dataset_attr))
|
||||
data_path, data_name, data_dir, data_files = None, None, None, None
|
||||
if dataset_attr.load_from in ["hf_hub", "ms_hub"]:
|
||||
@@ -117,7 +120,7 @@ def _load_single_dataset(
|
||||
|
||||
if dataset_attr.num_samples is not None and not data_args.streaming:
|
||||
target_num = dataset_attr.num_samples
|
||||
indexes = np.random.permutation(len(dataset))[:target_num]
|
||||
indexes = np.random.permutation(len(dataset))[:target_num] # all samples should be included
|
||||
target_num -= len(indexes)
|
||||
if target_num > 0:
|
||||
expand_indexes = np.random.choice(len(dataset), target_num)
|
||||
@@ -141,6 +144,9 @@ def _get_merged_dataset(
|
||||
training_args: "Seq2SeqTrainingArguments",
|
||||
stage: Literal["pt", "sft", "rm", "ppo", "kto"],
|
||||
) -> Optional[Union["Dataset", "IterableDataset"]]:
|
||||
r"""
|
||||
Gets the merged datasets in the standard format.
|
||||
"""
|
||||
if dataset_names is None:
|
||||
return None
|
||||
|
||||
@@ -164,6 +170,9 @@ def _get_preprocessed_dataset(
|
||||
processor: Optional["ProcessorMixin"] = None,
|
||||
is_eval: bool = False,
|
||||
) -> Optional[Union["Dataset", "IterableDataset"]]:
|
||||
r"""
|
||||
Preprocesses the dataset, including format checking and tokenization.
|
||||
"""
|
||||
if dataset is None:
|
||||
return None
|
||||
|
||||
@@ -209,6 +218,9 @@ def get_dataset(
|
||||
tokenizer: "PreTrainedTokenizer",
|
||||
processor: Optional["ProcessorMixin"] = None,
|
||||
) -> "DatasetModule":
|
||||
r"""
|
||||
Gets the train dataset and optionally gets the evaluation dataset.
|
||||
"""
|
||||
# Load tokenized dataset
|
||||
if data_args.tokenized_path is not None:
|
||||
if has_tokenized_data(data_args.tokenized_path):
|
||||
|
||||
@@ -3,6 +3,7 @@ from io import BytesIO
|
||||
from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Tuple, TypedDict, Union
|
||||
|
||||
import numpy as np
|
||||
from typing_extensions import override
|
||||
|
||||
from ..extras.constants import IGNORE_INDEX, IMAGE_PLACEHOLDER, VIDEO_PLACEHOLDER
|
||||
from ..extras.packages import is_pillow_available, is_pyav_available
|
||||
@@ -209,6 +210,7 @@ class BasePlugin:
|
||||
|
||||
|
||||
class LlavaPlugin(BasePlugin):
|
||||
@override
|
||||
def process_messages(
|
||||
self,
|
||||
messages: Sequence[Dict[str, str]],
|
||||
@@ -233,6 +235,7 @@ class LlavaPlugin(BasePlugin):
|
||||
|
||||
return messages
|
||||
|
||||
@override
|
||||
def get_mm_inputs(
|
||||
self,
|
||||
images: Sequence["ImageInput"],
|
||||
@@ -247,6 +250,7 @@ class LlavaPlugin(BasePlugin):
|
||||
|
||||
|
||||
class PaliGemmaPlugin(BasePlugin):
|
||||
@override
|
||||
def process_messages(
|
||||
self,
|
||||
messages: Sequence[Dict[str, str]],
|
||||
@@ -270,6 +274,7 @@ class PaliGemmaPlugin(BasePlugin):
|
||||
|
||||
return messages
|
||||
|
||||
@override
|
||||
def process_token_ids(
|
||||
self,
|
||||
input_ids: List[int],
|
||||
@@ -289,6 +294,7 @@ class PaliGemmaPlugin(BasePlugin):
|
||||
|
||||
return input_ids, labels
|
||||
|
||||
@override
|
||||
def get_mm_inputs(
|
||||
self,
|
||||
images: Sequence["ImageInput"],
|
||||
@@ -305,6 +311,7 @@ class PaliGemmaPlugin(BasePlugin):
|
||||
|
||||
|
||||
class Qwen2vlPlugin(BasePlugin):
|
||||
@override
|
||||
def process_messages(
|
||||
self,
|
||||
messages: Sequence[Dict[str, str]],
|
||||
@@ -359,6 +366,7 @@ class Qwen2vlPlugin(BasePlugin):
|
||||
|
||||
return messages
|
||||
|
||||
@override
|
||||
def get_mm_inputs(
|
||||
self,
|
||||
images: Sequence["ImageInput"],
|
||||
|
||||
@@ -16,6 +16,7 @@ from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Tuple, Union
|
||||
|
||||
from transformers.utils.versions import require_version
|
||||
from typing_extensions import override
|
||||
|
||||
from ..extras.logging import get_logger
|
||||
from .data_utils import Role
|
||||
@@ -152,6 +153,7 @@ class Template:
|
||||
|
||||
@dataclass
|
||||
class Llama2Template(Template):
|
||||
@override
|
||||
def _encode(
|
||||
self,
|
||||
tokenizer: "PreTrainedTokenizer",
|
||||
@@ -195,7 +197,7 @@ class Llama2Template(Template):
|
||||
return encoded_messages
|
||||
|
||||
|
||||
TEMPLATES: Dict[str, Template] = {}
|
||||
TEMPLATES: Dict[str, "Template"] = {}
|
||||
|
||||
|
||||
def _register_template(
|
||||
@@ -305,6 +307,9 @@ def _convert_slots_to_jinja(slots: "SLOTS", tokenizer: "PreTrainedTokenizer", pl
|
||||
|
||||
|
||||
def _get_jinja_template(template: "Template", tokenizer: "PreTrainedTokenizer") -> str:
|
||||
r"""
|
||||
Returns the jinja template.
|
||||
"""
|
||||
jinja_template = ""
|
||||
|
||||
prefix = _convert_slots_to_jinja(template.format_prefix.apply(), tokenizer)
|
||||
@@ -345,6 +350,9 @@ def _get_jinja_template(template: "Template", tokenizer: "PreTrainedTokenizer")
|
||||
|
||||
|
||||
def get_template_and_fix_tokenizer(tokenizer: "PreTrainedTokenizer", data_args: "DataArguments") -> "Template":
|
||||
r"""
|
||||
Gets chat template and fixes the tokenizer.
|
||||
"""
|
||||
if data_args.template in ["llava", "paligemma", "qwen2_vl"]:
|
||||
require_version(
|
||||
"transformers>=4.45.0.dev0", "To fix: pip install git+https://github.com/huggingface/transformers.git"
|
||||
|
||||
@@ -15,9 +15,12 @@
|
||||
import json
|
||||
import re
|
||||
from abc import ABC, abstractmethod
|
||||
from collections import namedtuple
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, List, Tuple, Union
|
||||
|
||||
from typing_extensions import override
|
||||
|
||||
from .data_utils import SLOTS
|
||||
|
||||
|
||||
@@ -38,26 +41,47 @@ GLM4_TOOL_PROMPT = (
|
||||
)
|
||||
|
||||
|
||||
FunctionCall = namedtuple("FunctionCall", ["name", "arguments"])
|
||||
|
||||
|
||||
@dataclass
|
||||
class ToolUtils(ABC):
|
||||
@staticmethod
|
||||
@abstractmethod
|
||||
def get_function_slots() -> SLOTS: ...
|
||||
"""
|
||||
Base class for tool utilities.
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
@abstractmethod
|
||||
def tool_formatter(tools: List[Dict[str, Any]]) -> str: ...
|
||||
def get_function_slots() -> SLOTS:
|
||||
r"""
|
||||
Gets a list of slots corresponding to a single function call.
|
||||
"""
|
||||
...
|
||||
|
||||
@staticmethod
|
||||
@abstractmethod
|
||||
def tool_extractor(content: str) -> Union[str, List[Tuple[str, str]]]: ...
|
||||
def tool_formatter(tools: List[Dict[str, Any]]) -> str:
|
||||
r"""
|
||||
Generates the system message describing all the available tools.
|
||||
"""
|
||||
...
|
||||
|
||||
@staticmethod
|
||||
@abstractmethod
|
||||
def tool_extractor(content: str) -> Union[str, List["FunctionCall"]]:
|
||||
r"""
|
||||
Extracts all the function calls from the response message.
|
||||
"""
|
||||
...
|
||||
|
||||
|
||||
class DefaultToolUtils(ToolUtils):
|
||||
@override
|
||||
@staticmethod
|
||||
def get_function_slots() -> SLOTS:
|
||||
return ["Action: {{name}}\nAction Input: {{arguments}}\n"]
|
||||
|
||||
@override
|
||||
@staticmethod
|
||||
def tool_formatter(tools: List[Dict[str, Any]]) -> str:
|
||||
tool_text = ""
|
||||
@@ -91,8 +115,9 @@ class DefaultToolUtils(ToolUtils):
|
||||
|
||||
return DEFAULT_TOOL_PROMPT.format(tool_text=tool_text, tool_names=", ".join(tool_names))
|
||||
|
||||
@override
|
||||
@staticmethod
|
||||
def tool_extractor(content: str) -> Union[str, List[Tuple[str, str]]]:
|
||||
def tool_extractor(content: str) -> Union[str, List["FunctionCall"]]:
|
||||
regex = re.compile(r"Action:\s*([a-zA-Z0-9_]+)\s*Action Input:\s*(.+?)(?=\s*Action:|\s*$)", re.DOTALL)
|
||||
action_match: List[Tuple[str, str]] = re.findall(regex, content)
|
||||
if not action_match:
|
||||
@@ -112,10 +137,12 @@ class DefaultToolUtils(ToolUtils):
|
||||
|
||||
|
||||
class GLM4ToolUtils(ToolUtils):
|
||||
@override
|
||||
@staticmethod
|
||||
def get_function_slots() -> SLOTS:
|
||||
return ["{{name}}\n{{arguments}}"]
|
||||
|
||||
@override
|
||||
@staticmethod
|
||||
def tool_formatter(tools: List[Dict[str, Any]]) -> str:
|
||||
tool_text = ""
|
||||
@@ -126,8 +153,9 @@ class GLM4ToolUtils(ToolUtils):
|
||||
|
||||
return GLM4_TOOL_PROMPT.format(tool_text=tool_text)
|
||||
|
||||
@override
|
||||
@staticmethod
|
||||
def tool_extractor(content: str) -> Union[str, List[Tuple[str, str]]]:
|
||||
def tool_extractor(content: str) -> Union[str, List["FunctionCall"]]:
|
||||
if "\n" not in content:
|
||||
return content
|
||||
|
||||
|
||||
Reference in New Issue
Block a user