[misc] upgrade format to py39 (#7256)

hoshi-hiyouga
2025-03-12 00:08:41 +08:00
committed by GitHub
parent 5995800bce
commit 264538cb26
113 changed files with 984 additions and 1407 deletions
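
The pattern applied across the diff below is twofold: typing.Dict-style generics are swapped for the built-in dict generics available since Python 3.9 (PEP 585), and single-sentence docstrings are collapsed onto one line in the imperative mood. A minimal sketch of the convention, using a hypothetical helper (load_example is illustrative, not a function touched by this commit):

from typing import Any, Optional


# Old py38 style removed by this commit:
#     from typing import Dict
#     def load_example(path: str) -> Optional[Dict[str, Any]]:
#         r"""
#         Loads an example config.
#         """
# New py39 style: built-in dict generic (PEP 585) plus a one-line, imperative docstring.
def load_example(path: str) -> Optional[dict[str, Any]]:
    r"""Load an example config."""
    return {"path": path}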


@@ -17,7 +17,7 @@ import os
 import signal
 from collections import defaultdict
 from datetime import datetime
-from typing import Any, Dict, Optional, Union
+from typing import Any, Optional, Union

 from psutil import Process
 from yaml import safe_dump, safe_load
@@ -44,9 +44,7 @@ USER_CONFIG = "user_config.yaml"


 def abort_process(pid: int) -> None:
-    r"""
-    Aborts the processes recursively in a bottom-up way.
-    """
+    r"""Abort the processes recursively in a bottom-up way."""
     try:
         children = Process(pid).children()
         if children:
@@ -59,9 +57,7 @@ def abort_process(pid: int) -> None:


 def get_save_dir(*paths: str) -> os.PathLike:
-    r"""
-    Gets the path to saved model checkpoints.
-    """
+    r"""Get the path to saved model checkpoints."""
     if os.path.sep in paths[-1]:
         logger.warning_rank0("Found complex path, some features may be not available.")
         return paths[-1]
@@ -71,16 +67,12 @@ def get_save_dir(*paths: str) -> os.PathLike:


 def _get_config_path() -> os.PathLike:
-    r"""
-    Gets the path to user config.
-    """
+    r"""Get the path to user config."""
     return os.path.join(DEFAULT_CACHE_DIR, USER_CONFIG)


-def load_config() -> Dict[str, Union[str, Dict[str, Any]]]:
-    r"""
-    Loads user config if exists.
-    """
+def load_config() -> dict[str, Union[str, dict[str, Any]]]:
+    r"""Load user config if exists."""
     try:
         with open(_get_config_path(), encoding="utf-8") as f:
             return safe_load(f)
@@ -89,9 +81,7 @@ def load_config() -> Dict[str, Union[str, Dict[str, Any]]]:


 def save_config(lang: str, model_name: Optional[str] = None, model_path: Optional[str] = None) -> None:
-    r"""
-    Saves user config.
-    """
+    r"""Save user config."""
     os.makedirs(DEFAULT_CACHE_DIR, exist_ok=True)
     user_config = load_config()
     user_config["lang"] = lang or user_config["lang"]
@@ -106,11 +96,9 @@ def save_config(lang: str, model_name: Optional[str] = None, model_path: Optional[str] = None) -> None:


 def get_model_path(model_name: str) -> str:
-    r"""
-    Gets the model path according to the model name.
-    """
+    r"""Get the model path according to the model name."""
     user_config = load_config()
-    path_dict: Dict["DownloadSource", str] = SUPPORTED_MODELS.get(model_name, defaultdict(str))
+    path_dict: dict[DownloadSource, str] = SUPPORTED_MODELS.get(model_name, defaultdict(str))
     model_path = user_config["path_dict"].get(model_name, "") or path_dict.get(DownloadSource.DEFAULT, "")
     if (
         use_modelscope()
def get_template(model_name: str) -> str:
r"""
Gets the template name if the model is a chat/distill/instruct model.
"""
r"""Get the template name if the model is a chat/distill/instruct model."""
return DEFAULT_TEMPLATE.get(model_name, "default")
def get_time() -> str:
r"""
Gets current date and time.
"""
r"""Get current date and time."""
return datetime.now().strftime(r"%Y-%m-%d-%H-%M-%S")
def is_multimodal(model_name: str) -> bool:
r"""
Judges if the model is a vision language model.
"""
r"""Judge if the model is a vision language model."""
return model_name in MULTIMODAL_SUPPORTED_MODELS
def load_dataset_info(dataset_dir: str) -> Dict[str, Dict[str, Any]]:
r"""
Loads dataset_info.json.
"""
def load_dataset_info(dataset_dir: str) -> dict[str, dict[str, Any]]:
r"""Load dataset_info.json."""
if dataset_dir == "ONLINE" or dataset_dir.startswith("REMOTE:"):
logger.info_rank0(f"dataset_dir is {dataset_dir}, using online dataset.")
return {}
@@ -166,10 +146,8 @@ def load_dataset_info(dataset_dir: str) -> Dict[str, Dict[str, Any]]:
         return {}


-def load_args(config_path: str) -> Optional[Dict[str, Any]]:
-    r"""
-    Loads the training configuration from config path.
-    """
+def load_args(config_path: str) -> Optional[dict[str, Any]]:
+    r"""Load the training configuration from config path."""
     try:
         with open(config_path, encoding="utf-8") as f:
             return safe_load(f)
@@ -177,26 +155,20 @@ def load_args(config_path: str) -> Optional[Dict[str, Any]]:
         return None


-def save_args(config_path: str, config_dict: Dict[str, Any]) -> None:
-    r"""
-    Saves the training configuration to config path.
-    """
+def save_args(config_path: str, config_dict: dict[str, Any]) -> None:
+    r"""Save the training configuration to config path."""
     with open(config_path, "w", encoding="utf-8") as f:
         safe_dump(config_dict, f)


-def _clean_cmd(args: Dict[str, Any]) -> Dict[str, Any]:
-    r"""
-    Removes args with NoneType or False or empty string value.
-    """
+def _clean_cmd(args: dict[str, Any]) -> dict[str, Any]:
+    r"""Remove args with NoneType or False or empty string value."""
     no_skip_keys = ["packing"]
     return {k: v for k, v in args.items() if (k in no_skip_keys) or (v is not None and v is not False and v != "")}


-def gen_cmd(args: Dict[str, Any]) -> str:
-    r"""
-    Generates CLI commands for previewing.
-    """
+def gen_cmd(args: dict[str, Any]) -> str:
+    r"""Generate CLI commands for previewing."""
     cmd_lines = ["llamafactory-cli train "]
     for k, v in _clean_cmd(args).items():
         if isinstance(v, dict):
@@ -215,10 +187,8 @@ def gen_cmd(args: Dict[str, Any]) -> str:
     return cmd_text


-def save_cmd(args: Dict[str, Any]) -> str:
-    r"""
-    Saves CLI commands to launch training.
-    """
+def save_cmd(args: dict[str, Any]) -> str:
+    r"""Save CLI commands to launch training."""
     output_dir = args["output_dir"]
     os.makedirs(output_dir, exist_ok=True)
     with open(os.path.join(output_dir, TRAINING_ARGS), "w", encoding="utf-8") as f:
@@ -228,9 +198,7 @@ def save_cmd(args: Dict[str, Any]) -> str:


 def load_eval_results(path: os.PathLike) -> str:
-    r"""
-    Gets scores after evaluation.
-    """
+    r"""Get scores after evaluation."""
     with open(path, encoding="utf-8") as f:
         result = json.dumps(json.load(f), indent=4)
@@ -238,9 +206,7 @@ def load_eval_results(path: os.PathLike) -> str:


 def create_ds_config() -> None:
-    r"""
-    Creates deepspeed config in the current directory.
-    """
+    r"""Create deepspeed config in the current directory."""
     os.makedirs(DEFAULT_CACHE_DIR, exist_ok=True)
     ds_config = {
         "train_batch_size": "auto",