[misc] upgrade format to py39 (#7256)
@@ -13,7 +13,7 @@
 # limitations under the License.
 
 import os
-from typing import TYPE_CHECKING, Any, Dict, Optional, TypedDict
+from typing import TYPE_CHECKING, Any, Optional, TypedDict
 
 import torch
 from transformers import (
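
A note on the change above: PEP 585 (Python 3.9+) makes builtin containers subscriptable, so `typing.Dict[str, Any]` can become plain `dict[str, Any]` with one less import. A minimal sketch; the function name is illustrative, not from this file:

    from typing import Any

    def build_init_kwargs(token: str) -> dict[str, Any]:
        # PEP 585: builtin `dict` accepts type parameters on Python 3.9+,
        # so `typing.Dict` is no longer needed.
        return {"token": token, "trust_remote_code": True}
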
@@ -51,9 +51,8 @@ class TokenizerModule(TypedDict):
     processor: Optional["ProcessorMixin"]
 
 
-def _get_init_kwargs(model_args: "ModelArguments") -> Dict[str, Any]:
-    r"""
-    Gets arguments to load config/tokenizer/model.
+def _get_init_kwargs(model_args: "ModelArguments") -> dict[str, Any]:
+    r"""Get arguments to load config/tokenizer/model.
 
     Note: including inplace operation of model_args.
     """
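
The docstring edits in this commit all follow the same pattern: the summary moves onto the opening `r"""` line and switches to imperative mood ("Get", not "Gets"), per PEP 257's one-line-summary convention (the style pydocstyle/ruff's D-rules encourage). A before/after sketch on a hypothetical function:

    def fetch_before() -> None:
        r"""
        Gets something.
        """

    def fetch_after() -> None:
        r"""Get something.

        Note: the summary sits on the opening line, in imperative mood.
        """
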
@@ -68,8 +67,7 @@ def _get_init_kwargs(model_args: "ModelArguments") -> Dict[str, Any]:
 
 
 def load_tokenizer(model_args: "ModelArguments") -> "TokenizerModule":
-    r"""
-    Loads pretrained tokenizer and optionally loads processor.
+    r"""Load pretrained tokenizer and optionally loads processor.
 
     Note: including inplace operation of model_args.
     """
@@ -110,9 +108,7 @@ def load_tokenizer(model_args: "ModelArguments") -> "TokenizerModule":
 
 
 def load_config(model_args: "ModelArguments") -> "PretrainedConfig":
-    r"""
-    Loads model config.
-    """
+    r"""Load model config."""
     init_kwargs = _get_init_kwargs(model_args)
     return AutoConfig.from_pretrained(model_args.model_name_or_path, **init_kwargs)
 
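
For context, `load_config` stays a thin wrapper around the standard transformers entry point; a standalone usage sketch (the model id is just an example):

    from transformers import AutoConfig

    # Works with any Hugging Face model id or a local checkpoint path.
    config = AutoConfig.from_pretrained("gpt2")
    print(config.model_type)  # "gpt2"
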
@@ -124,9 +120,7 @@ def load_model(
     is_trainable: bool = False,
     add_valuehead: bool = False,
 ) -> "PreTrainedModel":
-    r"""
-    Loads pretrained model.
-    """
+    r"""Load pretrained model."""
     init_kwargs = _get_init_kwargs(model_args)
     config = load_config(model_args)
     patch_config(config, tokenizer, model_args, init_kwargs, is_trainable)
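
`count_parameters` used below comes from the project's own utilities and returns `(trainable_params, all_param)`. An equivalent plain-PyTorch sketch, ignoring any quantized-weight adjustments the real helper may apply:

    import torch.nn as nn

    def count_parameters_sketch(model: nn.Module) -> tuple[int, int]:
        # Trainable = parameters with requires_grad set; total = every parameter.
        trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
        total = sum(p.numel() for p in model.parameters())
        return trainable, total
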
@@ -194,8 +188,9 @@ def load_model(
 
     trainable_params, all_param = count_parameters(model)
     if is_trainable:
-        param_stats = "trainable params: {:,} || all params: {:,} || trainable%: {:.4f}".format(
-            trainable_params, all_param, 100 * trainable_params / all_param
+        param_stats = (
+            f"trainable params: {trainable_params:,} || "
+            f"all params: {all_param:,} || trainable%: {100 * trainable_params / all_param:.4f}"
         )
     else:
         param_stats = f"all params: {all_param:,}"
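
The rewritten message relies on f-string format specs: `:,` adds thousands separators and `:.4f` fixes four decimal places, replacing the older `str.format` call with no behavior change. A quick standalone check, with example counts:

    trainable_params, all_param = 4_194_304, 6_738_415_616
    print(
        f"trainable params: {trainable_params:,} || "
        f"all params: {all_param:,} || trainable%: {100 * trainable_params / all_param:.4f}"
    )
    # trainable params: 4,194,304 || all params: 6,738,415,616 || trainable%: 0.0622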