[misc] upgrade format to py39 (#7256)
This commit is contained in:
@@ -12,8 +12,9 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from collections.abc import Sequence
|
||||
from enum import Enum, unique
|
||||
from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Set, TypedDict, Union
|
||||
from typing import TYPE_CHECKING, Optional, TypedDict, Union
|
||||
|
||||
from datasets import DatasetDict, concatenate_datasets, interleave_datasets
|
||||
|
||||
@@ -29,7 +30,7 @@ if TYPE_CHECKING:
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
SLOTS = Sequence[Union[str, Set[str], Dict[str, str]]]
|
||||
SLOTS = Sequence[Union[str, set[str], dict[str, str]]]
|
||||
|
||||
|
||||
@unique
|
||||
@@ -43,15 +44,13 @@ class Role(str, Enum):
|
||||
|
||||
class DatasetModule(TypedDict):
|
||||
train_dataset: Optional[Union["Dataset", "IterableDataset"]]
|
||||
eval_dataset: Optional[Union["Dataset", "IterableDataset", Dict[str, "Dataset"]]]
|
||||
eval_dataset: Optional[Union["Dataset", "IterableDataset", dict[str, "Dataset"]]]
|
||||
|
||||
|
||||
def merge_dataset(
|
||||
all_datasets: List[Union["Dataset", "IterableDataset"]], data_args: "DataArguments", seed: int
|
||||
all_datasets: list[Union["Dataset", "IterableDataset"]], data_args: "DataArguments", seed: int
|
||||
) -> Union["Dataset", "IterableDataset"]:
|
||||
r"""
|
||||
Merges multiple datasets to a unified dataset.
|
||||
"""
|
||||
r"""Merge multiple datasets to a unified dataset."""
|
||||
if len(all_datasets) == 1:
|
||||
return all_datasets[0]
|
||||
|
||||
@@ -78,14 +77,13 @@ def merge_dataset(
|
||||
|
||||
def split_dataset(
|
||||
dataset: Optional[Union["Dataset", "IterableDataset"]],
|
||||
eval_dataset: Optional[Union["Dataset", "IterableDataset", Dict[str, "Dataset"]]],
|
||||
eval_dataset: Optional[Union["Dataset", "IterableDataset", dict[str, "Dataset"]]],
|
||||
data_args: "DataArguments",
|
||||
seed: int,
|
||||
) -> "DatasetDict":
|
||||
r"""
|
||||
Splits the dataset and returns a dataset dict containing train set and validation set.
|
||||
r"""Split the dataset and returns a dataset dict containing train set and validation set.
|
||||
|
||||
Supports both map dataset and iterable dataset.
|
||||
Support both map dataset and iterable dataset.
|
||||
"""
|
||||
if eval_dataset is not None and data_args.val_size > 1e-6:
|
||||
raise ValueError("Cannot specify `val_size` if `eval_dataset` is not None.")
|
||||
@@ -120,10 +118,8 @@ def split_dataset(
|
||||
|
||||
|
||||
def get_dataset_module(dataset: Union["Dataset", "DatasetDict"]) -> "DatasetModule":
|
||||
r"""
|
||||
Converts dataset or dataset dict to dataset module.
|
||||
"""
|
||||
dataset_module: "DatasetModule" = {}
|
||||
r"""Convert dataset or dataset dict to dataset module."""
|
||||
dataset_module: DatasetModule = {}
|
||||
if isinstance(dataset, DatasetDict): # dataset dict
|
||||
if "train" in dataset:
|
||||
dataset_module["train_dataset"] = dataset["train"]
|
||||
|
||||
Reference in New Issue
Block a user