[misc] update format (#7277)

hoshi-hiyouga authored 2025-03-13 02:53:08 +08:00, committed by GitHub
parent 4b9d8da5a4
commit 650a9a9057
62 changed files with 384 additions and 288 deletions

tests/check_license.py (new file, 38 lines)

@@ -0,0 +1,38 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import sys
from pathlib import Path


KEYWORDS = ("Copyright", "2025", "LlamaFactory")


def main():
    path_list = []
    for check_dir in sys.argv[1:]:
        path_list.extend(Path(check_dir).glob("**/*.py"))

    for path in path_list:
        with open(path.absolute(), encoding="utf-8") as f:
            file_content = f.read().strip().split("\n")
            if not file_content[0]:
                continue

            print(f"Check license: {path}")
            assert all(keyword in file_content[0] for keyword in KEYWORDS), f"File {path} does not contain license."


if __name__ == "__main__":
    main()
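
The script reads the directories to scan from sys.argv[1:], so it is presumably invoked as "python tests/check_license.py <dir> [<dir> ...]". A minimal sketch of the check it applies to each file's first line; the sample strings below are illustrative, not taken from the repository:

    KEYWORDS = ("Copyright", "2025", "LlamaFactory")

    # First line of a properly licensed file: every keyword is present,
    # so the all(...) expression is True and the assert in main() passes.
    licensed = "# Copyright 2025 the LlamaFactory team."
    assert all(keyword in licensed for keyword in KEYWORDS)

    # A file that starts straight with code lacks the keywords, so the
    # same expression is False and main() would raise an AssertionError.
    unlicensed = "import os"
    assert not all(keyword in unlicensed for keyword in KEYWORDS)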

View File

@@ -13,7 +13,6 @@
 # limitations under the License.

 import os
-from collections.abc import Sequence
 from typing import TYPE_CHECKING, Any

 import pytest
@@ -97,7 +96,7 @@ def _check_plugin(
     plugin: "BasePlugin",
     tokenizer: "PreTrainedTokenizer",
     processor: "ProcessorMixin",
-    expected_mm_messages: Sequence[dict[str, str]] = MM_MESSAGES,
+    expected_mm_messages: list[dict[str, str]] = MM_MESSAGES,
     expected_input_ids: list[int] = INPUT_IDS,
     expected_labels: list[int] = LABELS,
     expected_mm_inputs: dict[str, Any] = {},

View File

@@ -13,7 +13,6 @@
 # limitations under the License.

 import os
-from collections.abc import Sequence
 from typing import TYPE_CHECKING

 import pytest
@@ -41,7 +40,7 @@ MESSAGES = [
 def _check_tokenization(
-    tokenizer: "PreTrainedTokenizer", batch_input_ids: Sequence[Sequence[int]], batch_text: Sequence[str]
+    tokenizer: "PreTrainedTokenizer", batch_input_ids: list[list[int]], batch_text: list[str]
 ) -> None:
     r"""Check token ids and texts.