[misc] upgrade format to py39 (#7256)

commit 264538cb26
parent 5995800bce
Author: hoshi-hiyouga
Date:   2025-03-12 00:08:41 +08:00
Committed-by: GitHub

113 changed files with 984 additions and 1407 deletions
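For context: the commit applies PEP 585 (Python 3.9) across the codebase, so
parameterized annotations use the builtin container types directly instead of
the aliases in typing. A minimal sketch of the pattern (names are illustrative,
not taken from this commit):

    # Before (Python 3.8 and earlier): aliases imported from typing
    from typing import Dict, List

    def summarize(scores: Dict[str, float]) -> List[str]:
        return sorted(scores)

    # After (Python 3.9+, PEP 585): the builtins are subscriptable.
    # Optional still comes from typing; "X | None" only arrives with
    # PEP 604 in Python 3.10.
    def summarize(scores: dict[str, float]) -> list[str]:
        return sorted(scores)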


@@ -19,7 +19,7 @@ import sys
 import time
 from concurrent.futures import ThreadPoolExecutor
 from datetime import timedelta
-from typing import TYPE_CHECKING, Any, Dict, Optional
+from typing import TYPE_CHECKING, Any, Optional

 import torch
 import transformers
@@ -56,7 +56,8 @@ logger = logging.get_logger(__name__)
 def fix_valuehead_checkpoint(
     model: "AutoModelForCausalLMWithValueHead", output_dir: str, safe_serialization: bool
 ) -> None:
-    r"""
+    r"""Fix the valuehead checkpoint files.
+
     The model is already unwrapped.
     There are three cases:
@@ -72,10 +73,10 @@ def fix_valuehead_checkpoint(
     if safe_serialization:
         path_to_checkpoint = os.path.join(output_dir, SAFE_WEIGHTS_NAME)
         with safe_open(path_to_checkpoint, framework="pt", device="cpu") as f:
-            state_dict: Dict[str, torch.Tensor] = {key: f.get_tensor(key) for key in f.keys()}
+            state_dict: dict[str, torch.Tensor] = {key: f.get_tensor(key) for key in f.keys()}
     else:
         path_to_checkpoint = os.path.join(output_dir, WEIGHTS_NAME)
-        state_dict: Dict[str, torch.Tensor] = torch.load(path_to_checkpoint, map_location="cpu")
+        state_dict: dict[str, torch.Tensor] = torch.load(path_to_checkpoint, map_location="cpu")
     os.remove(path_to_checkpoint)

     decoder_state_dict, v_head_state_dict = {}, {}
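Aside: the hunk above only changes the annotation, but it also documents the
two on-disk formats this function handles. A self-contained sketch of the same
read path, assuming the safetensors package (the helper name and paths are
illustrative):

    import torch
    from safetensors import safe_open

    def load_state_dict(path: str, safe_serialization: bool) -> dict[str, torch.Tensor]:
        if safe_serialization:
            # safetensors files are read lazily, one named tensor at a time
            with safe_open(path, framework="pt", device="cpu") as f:
                return {key: f.get_tensor(key) for key in f.keys()}
        # legacy PyTorch pickle checkpoints load in a single call
        return torch.load(path, map_location="cpu")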
@@ -98,9 +99,7 @@ def fix_valuehead_checkpoint(
 class FixValueHeadModelCallback(TrainerCallback):
-    r"""
-    A callback for fixing the checkpoint for valuehead models.
-    """
+    r"""A callback for fixing the checkpoint for valuehead models."""

     @override
     def on_save(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
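For readers unfamiliar with the API being touched: on_save fires after each
checkpoint is written, which is the hook this callback uses to rewrite the
value-head weights. A minimal skeleton of such a callback (the body is a
placeholder, not this project's logic):

    from transformers import TrainerCallback

    class CheckpointFixupCallback(TrainerCallback):
        r"""Post-process each checkpoint right after it is saved."""

        def on_save(self, args, state, control, **kwargs):
            # args.output_dir and state.global_step locate the newest
            # checkpoint folder; a real implementation would rewrite files there
            print(f"saved step {state.global_step} under {args.output_dir}")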
@@ -112,9 +111,7 @@ class FixValueHeadModelCallback(TrainerCallback):
 class SaveProcessorCallback(TrainerCallback):
-    r"""
-    A callback for saving the processor.
-    """
+    r"""A callback for saving the processor."""

     def __init__(self, processor: "ProcessorMixin") -> None:
         self.processor = processor
@@ -132,9 +129,7 @@ class SaveProcessorCallback(TrainerCallback):
 class PissaConvertCallback(TrainerCallback):
-    r"""
-    A callback for converting the PiSSA adapter to a normal one.
-    """
+    r"""A callback for converting the PiSSA adapter to a normal one."""

     @override
     def on_train_begin(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
@@ -177,9 +172,7 @@ class PissaConvertCallback(TrainerCallback):
 class LogCallback(TrainerCallback):
-    r"""
-    A callback for logging training and evaluation status.
-    """
+    r"""A callback for logging training and evaluation status."""

     def __init__(self) -> None:
         # Progress
@@ -188,7 +181,7 @@ class LogCallback(TrainerCallback):
self.max_steps = 0
self.elapsed_time = ""
self.remaining_time = ""
self.thread_pool: Optional["ThreadPoolExecutor"] = None
self.thread_pool: Optional[ThreadPoolExecutor] = None
# Status
self.aborted = False
self.do_train = False
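The quotes come off here because ThreadPoolExecutor is a real runtime import
(see the first hunk), so no forward reference is needed; string annotations are
only required for names bound under TYPE_CHECKING. A small illustration of the
distinction (the TrainerState import is an assumed example):

    from concurrent.futures import ThreadPoolExecutor
    from typing import TYPE_CHECKING, Optional

    if TYPE_CHECKING:
        from transformers import TrainerState  # not imported at runtime

    pool: Optional[ThreadPoolExecutor] = None  # runtime name: no quotes needed
    state: Optional["TrainerState"] = None     # type-only name: keep the string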
@@ -219,7 +212,7 @@ class LogCallback(TrainerCallback):
         self.elapsed_time = str(timedelta(seconds=int(elapsed_time)))
         self.remaining_time = str(timedelta(seconds=int(remaining_time)))

-    def _write_log(self, output_dir: str, logs: Dict[str, Any]) -> None:
+    def _write_log(self, output_dir: str, logs: dict[str, Any]) -> None:
         with open(os.path.join(output_dir, TRAINER_LOG), "a", encoding="utf-8") as f:
             f.write(json.dumps(logs) + "\n")
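The writer above appends one JSON object per line, so earlier entries survive
restarts. The same idea in isolation (TRAINER_LOG is a project constant; the
filename below is an assumption):

    import json
    import os
    from typing import Any

    def write_log(output_dir: str, logs: dict[str, Any]) -> None:
        # append-only JSONL: one record per line
        with open(os.path.join(output_dir, "trainer_log.jsonl"), "a", encoding="utf-8") as f:
            f.write(json.dumps(logs) + "\n")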
@@ -348,9 +341,7 @@ class LogCallback(TrainerCallback):
 class ReporterCallback(TrainerCallback):
-    r"""
-    A callback for reporting training status to external logger.
-    """
+    r"""A callback for reporting training status to an external logger."""

     def __init__(
         self,