[misc] upgrade format to py39 (#7256)
This commit is contained in:
@@ -13,7 +13,8 @@
|
||||
# limitations under the License.
|
||||
|
||||
from collections import defaultdict
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple
|
||||
from collections.abc import Sequence
|
||||
from typing import TYPE_CHECKING, Any, Optional
|
||||
|
||||
from ...extras import logging
|
||||
from ..data_utils import Role
|
||||
@@ -30,14 +31,14 @@ logger = logging.get_logger(__name__)
|
||||
class UnsupervisedDatasetProcessor(DatasetProcessor):
|
||||
def _encode_data_example(
|
||||
self,
|
||||
prompt: Sequence[Dict[str, str]],
|
||||
response: Sequence[Dict[str, str]],
|
||||
prompt: Sequence[dict[str, str]],
|
||||
response: Sequence[dict[str, str]],
|
||||
system: Optional[str],
|
||||
tools: Optional[str],
|
||||
images: Sequence["ImageInput"],
|
||||
videos: Sequence["VideoInput"],
|
||||
audios: Sequence["AudioInput"],
|
||||
) -> Tuple[List[int], List[int]]:
|
||||
) -> tuple[list[int], list[int]]:
|
||||
if len(response) == 1:
|
||||
messages = prompt + response
|
||||
else:
|
||||
@@ -56,7 +57,7 @@ class UnsupervisedDatasetProcessor(DatasetProcessor):
|
||||
labels = labels[:target_len]
|
||||
return input_ids, labels
|
||||
|
||||
def preprocess_dataset(self, examples: Dict[str, List[Any]]) -> Dict[str, List[Any]]:
|
||||
def preprocess_dataset(self, examples: dict[str, list[Any]]) -> dict[str, list[Any]]:
|
||||
# build inputs with format `<bos> X` and labels with format `Y <eos>`
|
||||
model_inputs = defaultdict(list)
|
||||
for i in range(len(examples["_prompt"])):
|
||||
@@ -84,7 +85,7 @@ class UnsupervisedDatasetProcessor(DatasetProcessor):
|
||||
|
||||
return model_inputs
|
||||
|
||||
def print_data_example(self, example: Dict[str, List[int]]) -> None:
|
||||
def print_data_example(self, example: dict[str, list[int]]) -> None:
|
||||
print("input_ids:\n{}".format(example["input_ids"]))
|
||||
print("inputs:\n{}".format(self.tokenizer.decode(example["input_ids"], skip_special_tokens=False)))
|
||||
print("label_ids:\n{}".format(example["labels"]))
|
||||
|
||||
Reference in New Issue
Block a user