[misc] upgrade format to py39 (#7256)
This commit is contained in:
@@ -19,7 +19,7 @@ import sys
|
||||
import time
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from datetime import timedelta
|
||||
from typing import TYPE_CHECKING, Any, Dict, Optional
|
||||
from typing import TYPE_CHECKING, Any, Optional
|
||||
|
||||
import torch
|
||||
import transformers
|
||||
@@ -56,7 +56,8 @@ logger = logging.get_logger(__name__)
|
||||
def fix_valuehead_checkpoint(
|
||||
model: "AutoModelForCausalLMWithValueHead", output_dir: str, safe_serialization: bool
|
||||
) -> None:
|
||||
r"""
|
||||
r"""Fix the valuehead checkpoint files.
|
||||
|
||||
The model is already unwrapped.
|
||||
|
||||
There are three cases:
|
||||
@@ -72,10 +73,10 @@ def fix_valuehead_checkpoint(
|
||||
if safe_serialization:
|
||||
path_to_checkpoint = os.path.join(output_dir, SAFE_WEIGHTS_NAME)
|
||||
with safe_open(path_to_checkpoint, framework="pt", device="cpu") as f:
|
||||
state_dict: Dict[str, torch.Tensor] = {key: f.get_tensor(key) for key in f.keys()}
|
||||
state_dict: dict[str, torch.Tensor] = {key: f.get_tensor(key) for key in f.keys()}
|
||||
else:
|
||||
path_to_checkpoint = os.path.join(output_dir, WEIGHTS_NAME)
|
||||
state_dict: Dict[str, torch.Tensor] = torch.load(path_to_checkpoint, map_location="cpu")
|
||||
state_dict: dict[str, torch.Tensor] = torch.load(path_to_checkpoint, map_location="cpu")
|
||||
|
||||
os.remove(path_to_checkpoint)
|
||||
decoder_state_dict, v_head_state_dict = {}, {}
|
||||
@@ -98,9 +99,7 @@ def fix_valuehead_checkpoint(
|
||||
|
||||
|
||||
class FixValueHeadModelCallback(TrainerCallback):
|
||||
r"""
|
||||
A callback for fixing the checkpoint for valuehead models.
|
||||
"""
|
||||
r"""A callback for fixing the checkpoint for valuehead models."""
|
||||
|
||||
@override
|
||||
def on_save(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
|
||||
@@ -112,9 +111,7 @@ class FixValueHeadModelCallback(TrainerCallback):
|
||||
|
||||
|
||||
class SaveProcessorCallback(TrainerCallback):
|
||||
r"""
|
||||
A callback for saving the processor.
|
||||
"""
|
||||
r"""A callback for saving the processor."""
|
||||
|
||||
def __init__(self, processor: "ProcessorMixin") -> None:
|
||||
self.processor = processor
|
||||
@@ -132,9 +129,7 @@ class SaveProcessorCallback(TrainerCallback):
|
||||
|
||||
|
||||
class PissaConvertCallback(TrainerCallback):
|
||||
r"""
|
||||
A callback for converting the PiSSA adapter to a normal one.
|
||||
"""
|
||||
r"""A callback for converting the PiSSA adapter to a normal one."""
|
||||
|
||||
@override
|
||||
def on_train_begin(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
|
||||
@@ -177,9 +172,7 @@ class PissaConvertCallback(TrainerCallback):
|
||||
|
||||
|
||||
class LogCallback(TrainerCallback):
|
||||
r"""
|
||||
A callback for logging training and evaluation status.
|
||||
"""
|
||||
r"""A callback for logging training and evaluation status."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
# Progress
|
||||
@@ -188,7 +181,7 @@ class LogCallback(TrainerCallback):
|
||||
self.max_steps = 0
|
||||
self.elapsed_time = ""
|
||||
self.remaining_time = ""
|
||||
self.thread_pool: Optional["ThreadPoolExecutor"] = None
|
||||
self.thread_pool: Optional[ThreadPoolExecutor] = None
|
||||
# Status
|
||||
self.aborted = False
|
||||
self.do_train = False
|
||||
@@ -219,7 +212,7 @@ class LogCallback(TrainerCallback):
|
||||
self.elapsed_time = str(timedelta(seconds=int(elapsed_time)))
|
||||
self.remaining_time = str(timedelta(seconds=int(remaining_time)))
|
||||
|
||||
def _write_log(self, output_dir: str, logs: Dict[str, Any]) -> None:
|
||||
def _write_log(self, output_dir: str, logs: dict[str, Any]) -> None:
|
||||
with open(os.path.join(output_dir, TRAINER_LOG), "a", encoding="utf-8") as f:
|
||||
f.write(json.dumps(logs) + "\n")
|
||||
|
||||
@@ -348,9 +341,7 @@ class LogCallback(TrainerCallback):
|
||||
|
||||
|
||||
class ReporterCallback(TrainerCallback):
|
||||
r"""
|
||||
A callback for reporting training status to external logger.
|
||||
"""
|
||||
r"""A callback for reporting training status to external logger."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
|
||||
Reference in New Issue
Block a user