support function calling

Former-commit-id: 66533b3f65babf2429c92c0f8fafe4eff5e0ff63
This commit is contained in:
hiyouga
2024-01-18 09:54:23 +08:00
parent f7329b1a0e
commit a423274fd9
67 changed files with 1239 additions and 1079 deletions

View File

@@ -1,7 +1,8 @@
import hashlib
from typing import TYPE_CHECKING, Dict, List, Optional, Union
from enum import Enum, unique
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
from llmtuner.extras.logging import get_logger
from ..extras.logging import get_logger
if TYPE_CHECKING:
from datasets import Dataset, IterableDataset
@@ -12,6 +13,14 @@ if TYPE_CHECKING:
logger = get_logger(__name__)
@unique
class Role(str, Enum):
USER = "user"
ASSISTANT = "assistant"
OBSERVATION = "observation"
FUNCTION = "function"
def checksum(data_files: List[str], file_sha1: Optional[str] = None) -> None:
if file_sha1 is None:
logger.warning("Checksum failed: missing SHA-1 hash value in dataset_info.json.")
@@ -27,6 +36,13 @@ def checksum(data_files: List[str], file_sha1: Optional[str] = None) -> None:
logger.warning("Checksum failed: mismatched SHA-1 hash value at {}.".format(data_files[0]))
def infer_max_len(source_len: int, target_len: int, data_args: "DataArguments") -> Tuple[int, int]:
max_target_len = int(data_args.cutoff_len * (target_len / (source_len + target_len)))
max_target_len = max(max_target_len, data_args.reserved_label_len)
max_source_len = data_args.cutoff_len - max_target_len
return max_source_len, max_target_len
def split_dataset(
dataset: Union["Dataset", "IterableDataset"],
data_args: "DataArguments",