[misc] upgrade format to py39 (#7256)

hoshi-hiyouga
2025-03-12 00:08:41 +08:00
committed by GitHub
parent 5995800bce
commit 264538cb26
113 changed files with 984 additions and 1407 deletions

View File

@@ -14,7 +14,6 @@
 import os
 import random
-from typing import Dict, List
 
 import pytest
 from datasets import load_dataset
@@ -43,7 +42,7 @@ TRAIN_ARGS = {
 }
 
-def _convert_sharegpt_to_openai(messages: List[Dict[str, str]]) -> List[Dict[str, str]]:
+def _convert_sharegpt_to_openai(messages: list[dict[str, str]]) -> list[dict[str, str]]:
     role_mapping = {"human": "user", "gpt": "assistant", "system": "system"}
     new_messages = []
     for message in messages:
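The change to this file is PEP 585: since Python 3.9 the builtin list and dict are subscriptable directly, so the deprecated typing.List / typing.Dict aliases and their import can be dropped (on older interpreters the same spelling still works under `from __future__ import annotations`, since annotations are then never evaluated). A minimal runnable sketch of the converter in this hunk; the loop body is an assumption based on the standard ShareGPT from/value schema, since the diff cuts off after the for line:

# Requires Python >= 3.9 (PEP 585 builtin generics).
def _convert_sharegpt_to_openai(messages: list[dict[str, str]]) -> list[dict[str, str]]:
    # ShareGPT records use "from"/"value"; OpenAI-style messages use "role"/"content".
    role_mapping = {"human": "user", "gpt": "assistant", "system": "system"}
    new_messages = []
    for message in messages:
        new_messages.append({"role": role_mapping[message["from"]], "content": message["value"]})
    return new_messages


print(_convert_sharegpt_to_openai([{"from": "human", "value": "hi"}]))
# [{'role': 'user', 'content': 'hi'}]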

View File

@@ -12,7 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import Tuple
 
 import pytest
@@ -31,5 +30,5 @@ from llamafactory.data.processor.processor_utils import infer_seqlen
         ((10, 10, 1000), (10, 10)),
     ],
 )
-def test_infer_seqlen(test_input: Tuple[int, int, int], test_output: Tuple[int, int]):
+def test_infer_seqlen(test_input: tuple[int, int, int], test_output: tuple[int, int]):
     assert test_output == infer_seqlen(*test_input)
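typing.Tuple gets the same treatment, including the fixed-arity form tuple[int, int, int]. Below is a self-contained version of the parametrized test; note the infer_seqlen body is a simplified stand-in so the snippet runs on its own, not the actual implementation from llamafactory.data.processor.processor_utils:

import pytest


def infer_seqlen(source_len: int, target_len: int, cutoff_len: int) -> tuple[int, int]:
    # Simplified stand-in: pass lengths through when they already fit the cutoff,
    # otherwise shrink both proportionally. The real helper is more careful.
    total = source_len + target_len
    if total <= cutoff_len:
        return source_len, target_len
    new_source_len = max(1, cutoff_len * source_len // total)
    return new_source_len, max(1, cutoff_len - new_source_len)


@pytest.mark.parametrize("test_input,test_output", [((10, 10, 1000), (10, 10))])
def test_infer_seqlen(test_input: tuple[int, int, int], test_output: tuple[int, int]):
    assert test_output == infer_seqlen(*test_input)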

View File

@@ -112,7 +112,8 @@ def test_glm4_tool_formatter():
     assert formatter.apply(content=json.dumps(TOOLS)) == [
         "你是一个名为 ChatGLM 的人工智能助手。你是基于智谱AI训练的语言模型 GLM-4 模型开发的,"
         "你的任务是针对用户的问题和要求提供适当的答复和支持。# 可用工具\n\n"
-        f"## test_tool\n\n{json.dumps(TOOLS[0], indent=4, ensure_ascii=False)}\n在调用上述函数时,请使用 Json 格式表示调用的参数。"
+        f"## test_tool\n\n{json.dumps(TOOLS[0], indent=4, ensure_ascii=False)}\n"
+        "在调用上述函数时,请使用 Json 格式表示调用的参数。"
     ]
@@ -136,7 +137,8 @@ def test_llama3_tool_formatter():
     wrapped_tool = {"type": "function", "function": TOOLS[0]}
     assert formatter.apply(content=json.dumps(TOOLS)) == [
         f"Cutting Knowledge Date: December 2023\nToday Date: {date}\n\n"
-        "You have access to the following functions. To call a function, please respond with JSON for a function call. "
+        "You have access to the following functions. "
+        "To call a function, please respond with JSON for a function call. "
         """Respond in the format {"name": function name, "parameters": dictionary of argument name and its value}. """
         f"Do not use variables.\n\n{json.dumps(wrapped_tool, indent=4, ensure_ascii=False)}\n\n"
     ]
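Both hunks in this file only re-wrap overlong expected strings; the parser concatenates adjacent string literals at compile time, so the split form is identical to the one-line original and the assertions are unchanged. A quick demonstration:

# Implicit concatenation: adjacent literals merge into a single string object
# at compile time, so re-wrapping long test expectations is purely cosmetic.
single = "You have access to the following functions. To call a function, please respond with JSON for a function call. "
split = (
    "You have access to the following functions. "
    "To call a function, please respond with JSON for a function call. "
)
assert single == split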

View File

@@ -13,7 +13,8 @@
 # limitations under the License.
 
 import os
-from typing import TYPE_CHECKING, Any, Dict, List, Sequence
+from collections.abc import Sequence
+from typing import TYPE_CHECKING, Any
 
 import pytest
 import torch
@@ -69,12 +70,12 @@ LABELS = [0, 1, 2, 3, 4]
 BATCH_IDS = [[1] * 1024]
 
 
-def _get_mm_inputs(processor: "ProcessorMixin") -> Dict[str, "torch.Tensor"]:
-    image_processor: "BaseImageProcessor" = getattr(processor, "image_processor")
+def _get_mm_inputs(processor: "ProcessorMixin") -> dict[str, "torch.Tensor"]:
+    image_processor: BaseImageProcessor = getattr(processor, "image_processor")
     return image_processor(images=IMAGES, return_tensors="pt")
 
 
-def _is_close(batch_a: Dict[str, Any], batch_b: Dict[str, Any]) -> None:
+def _is_close(batch_a: dict[str, Any], batch_b: dict[str, Any]) -> None:
     assert batch_a.keys() == batch_b.keys()
     for key in batch_a.keys():
         if isinstance(batch_a[key], torch.Tensor):
@@ -96,11 +97,11 @@ def _check_plugin(
     plugin: "BasePlugin",
     tokenizer: "PreTrainedTokenizer",
     processor: "ProcessorMixin",
-    expected_mm_messages: Sequence[Dict[str, str]] = MM_MESSAGES,
-    expected_input_ids: List[int] = INPUT_IDS,
-    expected_labels: List[int] = LABELS,
-    expected_mm_inputs: Dict[str, Any] = {},
-    expected_no_mm_inputs: Dict[str, Any] = {},
+    expected_mm_messages: Sequence[dict[str, str]] = MM_MESSAGES,
+    expected_input_ids: list[int] = INPUT_IDS,
+    expected_labels: list[int] = LABELS,
+    expected_mm_inputs: dict[str, Any] = {},
+    expected_no_mm_inputs: dict[str, Any] = {},
 ) -> None:
     # test mm_messages
     if plugin.__class__.__name__ != "BasePlugin":
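PEP 585 also deprecated the typing aliases for abstract containers, which is why Sequence now comes from collections.abc (subscriptable at runtime since 3.9) while typing keeps only names with no runtime home, such as TYPE_CHECKING and Any. A reduced, dependency-free sketch of the resulting style; both helpers are cut-down hypothetical analogues of the ones in this file:

from collections.abc import Sequence
from typing import Any


def batch_keys_match(batch_a: dict[str, Any], batch_b: dict[str, Any]) -> bool:
    # Cut-down analogue of _is_close above: do both batches expose the same keys?
    return batch_a.keys() == batch_b.keys()


def total_tokens(batch_ids: Sequence[Sequence[int]]) -> int:
    # Sequence accepts lists, tuples, and ranges alike without overcommitting.
    return sum(len(ids) for ids in batch_ids)


assert batch_keys_match({"input_ids": [1]}, {"input_ids": [2]})
assert total_tokens([[1] * 1024]) == 1024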

View File

@@ -13,7 +13,8 @@
 # limitations under the License.
 
 import os
-from typing import TYPE_CHECKING, Sequence
+from collections.abc import Sequence
+from typing import TYPE_CHECKING
 
 import pytest
 from transformers import AutoTokenizer
@@ -42,8 +43,7 @@ MESSAGES = [
 def _check_tokenization(
     tokenizer: "PreTrainedTokenizer", batch_input_ids: Sequence[Sequence[int]], batch_text: Sequence[str]
 ) -> None:
-    r"""
-    Checks token ids and texts.
+    r"""Check token ids and texts.
 
     encode(text) == token_ids
     decode(token_ids) == text
@@ -54,8 +54,7 @@
 def _check_template(model_id: str, template_name: str, prompt_str: str, answer_str: str, use_fast: bool) -> None:
-    r"""
-    Checks template.
+    r"""Check template.
 
     Args:
         model_id: the model id on hugging face hub.
@@ -63,6 +62,7 @@ def _check_template(model_id: str, template_name: str, prompt_str: str, answer_s
         prompt_str: the string corresponding to the prompt part.
         answer_str: the string corresponding to the answer part.
         use_fast: whether to use fast tokenizer.
+
     """
     tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=use_fast, token=HF_TOKEN)
     content_str = tokenizer.apply_chat_template(MESSAGES, tokenize=False)
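The docstring hunks line up with ruff's pydocstyle rules under the Google convention (an inference from the shape of the edits, not stated in the commit): the summary joins the opening quotes on one line (D212), switches to imperative mood, Check rather than Checks (D401), and the final Args: section gains a trailing blank line before the closing quotes (D413). The target shape, sketched on a stripped-down hypothetical helper:

def check_template(model_id: str, template_name: str) -> None:
    r"""Check template.

    Args:
        model_id: the model id on hugging face hub.
        template_name: the name of template.

    """
    # Body elided; the commit only reflows the docstring here.
    ...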