[misc] upgrade format to py39 (#7256)

This commit is contained in:
hoshi-hiyouga
2025-03-12 00:08:41 +08:00
committed by GitHub
parent 5995800bce
commit 264538cb26
113 changed files with 984 additions and 1407 deletions

View File

@@ -13,7 +13,7 @@
# limitations under the License.
import os
from typing import TYPE_CHECKING, Any, Dict, Optional, TypedDict
from typing import TYPE_CHECKING, Any, Optional, TypedDict
import torch
from transformers import (
@@ -51,9 +51,8 @@ class TokenizerModule(TypedDict):
processor: Optional["ProcessorMixin"]
def _get_init_kwargs(model_args: "ModelArguments") -> Dict[str, Any]:
r"""
Gets arguments to load config/tokenizer/model.
def _get_init_kwargs(model_args: "ModelArguments") -> dict[str, Any]:
r"""Get arguments to load config/tokenizer/model.
Note: including inplace operation of model_args.
"""
@@ -68,8 +67,7 @@ def _get_init_kwargs(model_args: "ModelArguments") -> Dict[str, Any]:
def load_tokenizer(model_args: "ModelArguments") -> "TokenizerModule":
r"""
Loads pretrained tokenizer and optionally loads processor.
r"""Load pretrained tokenizer and optionally loads processor.
Note: including inplace operation of model_args.
"""
@@ -110,9 +108,7 @@ def load_tokenizer(model_args: "ModelArguments") -> "TokenizerModule":
def load_config(model_args: "ModelArguments") -> "PretrainedConfig":
r"""
Loads model config.
"""
r"""Load model config."""
init_kwargs = _get_init_kwargs(model_args)
return AutoConfig.from_pretrained(model_args.model_name_or_path, **init_kwargs)
@@ -124,9 +120,7 @@ def load_model(
is_trainable: bool = False,
add_valuehead: bool = False,
) -> "PreTrainedModel":
r"""
Loads pretrained model.
"""
r"""Load pretrained model."""
init_kwargs = _get_init_kwargs(model_args)
config = load_config(model_args)
patch_config(config, tokenizer, model_args, init_kwargs, is_trainable)
@@ -194,8 +188,9 @@ def load_model(
trainable_params, all_param = count_parameters(model)
if is_trainable:
param_stats = "trainable params: {:,} || all params: {:,} || trainable%: {:.4f}".format(
trainable_params, all_param, 100 * trainable_params / all_param
param_stats = (
f"trainable params: {trainable_params:,} || "
f"all params: {all_param:,} || trainable%: {100 * trainable_params / all_param:.4f}"
)
else:
param_stats = f"all params: {all_param:,}"