[misc] upgrade format to py39 (#7256)

hoshi-hiyouga
2025-03-12 00:08:41 +08:00
committed by GitHub
parent 5995800bce
commit 264538cb26
113 changed files with 984 additions and 1407 deletions
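
Context for the change (not part of the commit message): Python 3.9 implements PEP 585, which makes the built-in dict, list, and tuple generic at annotation time, so the typing.Dict/List/Tuple aliases can be dropped. Optional and Union stay imported because the X | Y union syntax only arrives in Python 3.10 (PEP 604). A minimal sketch of the pattern applied across these files (the helper function is hypothetical):

    from typing import Optional  # still needed before Python 3.10 (PEP 604)

    # Hypothetical helper in the PEP 585 style used throughout this commit:
    # built-in dict/list are parameterized directly, no typing.Dict/List.
    def bucket(values: list[int], size: Optional[int] = None) -> dict[str, list[int]]:
        size = size or 2
        buckets: dict[str, list[int]] = {}
        for i, value in enumerate(values):
            buckets.setdefault(f"bucket_{i // size}", []).append(value)
        return buckets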


@@ -21,7 +21,7 @@ import json
 import os
 from collections.abc import Mapping
 from pathlib import Path
-from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union
+from typing import TYPE_CHECKING, Any, Callable, Optional, Union
 
 import torch
 from transformers import Trainer
@@ -63,12 +63,10 @@ logger = logging.get_logger(__name__)
 class DummyOptimizer(torch.optim.Optimizer):
-    r"""
-    A dummy optimizer used for the GaLore or APOLLO algorithm.
-    """
+    r"""A dummy optimizer used for the GaLore or APOLLO algorithm."""
 
     def __init__(
-        self, lr: float = 1e-3, optimizer_dict: Optional[Dict["torch.nn.Parameter", "torch.optim.Optimizer"]] = None
+        self, lr: float = 1e-3, optimizer_dict: Optional[dict["torch.nn.Parameter", "torch.optim.Optimizer"]] = None
     ) -> None:
         dummy_tensor = torch.randn(1, 1)
         self.optimizer_dict = optimizer_dict
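
Aside: the class above exists so the HF Trainer and LR schedulers have a real torch.optim.Optimizer object to hold, while the actual updates happen in the per-parameter optimizers kept in optimizer_dict. A minimal sketch of that pattern (the NoOpOptimizer name and simplified constructor are illustrative, not the repo's code):

    import torch

    class NoOpOptimizer(torch.optim.Optimizer):
        # One dummy tensor satisfies the base class; step() does nothing
        # because per-parameter optimizers update the weights elsewhere.
        def __init__(self, lr: float = 1e-3) -> None:
            dummy_tensor = torch.randn(1, 1)
            super().__init__([{"params": [dummy_tensor]}], {"lr": lr})

        def step(self, closure=None) -> None:
            pass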
@@ -112,8 +110,7 @@ def create_modelcard_and_push(
 def create_ref_model(
     model_args: "ModelArguments", finetuning_args: "FinetuningArguments", add_valuehead: bool = False
 ) -> Optional[Union["PreTrainedModel", "AutoModelForCausalLMWithValueHead"]]:
-    r"""
-    Creates reference model for PPO/DPO training. Evaluation mode is not supported.
+    r"""Create reference model for PPO/DPO training. Evaluation mode is not supported.
 
     The valuehead parameter is randomly initialized since it is useless for PPO training.
     """
@@ -148,9 +145,7 @@ def create_ref_model(
 def create_reward_model(
     model: "AutoModelForCausalLMWithValueHead", model_args: "ModelArguments", finetuning_args: "FinetuningArguments"
 ) -> Optional["AutoModelForCausalLMWithValueHead"]:
-    r"""
-    Creates reward model for PPO training.
-    """
+    r"""Create reward model for PPO training."""
     if finetuning_args.reward_model_type == "api":
         assert finetuning_args.reward_model.startswith("http"), "Please provide full url."
         logger.info_rank0(f"Use reward server {finetuning_args.reward_model}")
@@ -189,10 +184,8 @@ def create_reward_model(
         return reward_model
 
 
-def _get_decay_parameter_names(model: "PreTrainedModel") -> List[str]:
-    r"""
-    Returns a list of names of parameters with weight decay. (weights in non-layernorm layers)
-    """
+def _get_decay_parameter_names(model: "PreTrainedModel") -> list[str]:
+    r"""Return a list of names of parameters with weight decay. (weights in non-layernorm layers)."""
     decay_parameters = get_parameter_names(model, ALL_LAYERNORM_LAYERS)
     decay_parameters = [name for name in decay_parameters if "bias" not in name]
     return decay_parameters
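
For readers without the transformers internals at hand: get_parameter_names(model, ALL_LAYERNORM_LAYERS) returns the names of parameters that do not live inside LayerNorm-type modules, and the bias filter then removes bias terms by name. A hedged, self-contained approximation of the same logic:

    import torch

    def decay_parameter_names(model: torch.nn.Module) -> list[str]:
        # Approximation: exclude LayerNorm parameters, then exclude biases.
        no_decay: set[str] = set()
        for module_name, module in model.named_modules():
            if isinstance(module, torch.nn.LayerNorm):
                no_decay.update(f"{module_name}.{name}" for name, _ in module.named_parameters())
        return [name for name, _ in model.named_parameters() if name not in no_decay and "bias" not in name]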
@@ -208,7 +201,7 @@ def _create_galore_optimizer(
     else:
         galore_targets = finetuning_args.galore_target
 
-    galore_params: List["torch.nn.Parameter"] = []
+    galore_params: list[torch.nn.Parameter] = []
     for name, module in model.named_modules():
         if isinstance(module, torch.nn.Linear) and any(target in name for target in galore_targets):
             for param in module.parameters():
@@ -224,7 +217,7 @@ def _create_galore_optimizer(
     id_galore_params = {id(param) for param in galore_params}
     decay_params, nodecay_params = [], []  # they are non-galore parameters
-    trainable_params: List["torch.nn.Parameter"] = []  # galore_params + decay_params + nodecay_params
+    trainable_params: list[torch.nn.Parameter] = []  # galore_params + decay_params + nodecay_params
     decay_param_names = _get_decay_parameter_names(model)
     for name, param in model.named_parameters():
         if param.requires_grad:
@@ -251,7 +244,7 @@ def _create_galore_optimizer(
         if training_args.gradient_accumulation_steps != 1:
             raise ValueError("Per-layer GaLore does not support gradient accumulation.")
 
-        optimizer_dict: Dict["torch.Tensor", "torch.optim.Optimizer"] = {}
+        optimizer_dict: dict[torch.Tensor, torch.optim.Optimizer] = {}
         for param in nodecay_params:
             param_groups = [dict(params=[param], weight_decay=0.0)]
             optimizer_dict[param] = optim_class(param_groups, **optim_kwargs)
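
Per-layer GaLore (and APOLLO below) creates one optimizer per parameter and steps it as soon as that parameter's gradient is ready, which is why gradient accumulation is rejected above. A sketch of the usual wiring, assuming PyTorch >= 2.1 for register_post_accumulate_grad_hook (model and hyperparameters are stand-ins):

    import torch

    model = torch.nn.Linear(8, 8)  # stand-in model for illustration
    optimizer_dict: dict[torch.nn.Parameter, torch.optim.Optimizer] = {
        param: torch.optim.AdamW([param], lr=1e-3) for param in model.parameters()
    }

    def make_hook(param: torch.nn.Parameter):
        def hook(_: torch.Tensor) -> None:
            # step this parameter's optimizer right after its gradient is
            # accumulated, then drop the gradient; no global step() runs
            optimizer_dict[param].step()
            optimizer_dict[param].zero_grad()
        return hook

    for param in model.parameters():
        param.register_post_accumulate_grad_hook(make_hook(param))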
@@ -296,7 +289,7 @@ def _create_apollo_optimizer(
     else:
         apollo_targets = finetuning_args.apollo_target
 
-    apollo_params: List["torch.nn.Parameter"] = []
+    apollo_params: list[torch.nn.Parameter] = []
     for name, module in model.named_modules():
         if isinstance(module, torch.nn.Linear) and any(target in name for target in apollo_targets):
             for param in module.parameters():
@@ -315,7 +308,7 @@ def _create_apollo_optimizer(
     id_apollo_params = {id(param) for param in apollo_params}
     decay_params, nodecay_params = [], []  # they are non-apollo parameters
-    trainable_params: List["torch.nn.Parameter"] = []  # apollo_params + decay_params + nodecay_params
+    trainable_params: list[torch.nn.Parameter] = []  # apollo_params + decay_params + nodecay_params
     decay_param_names = _get_decay_parameter_names(model)
     for name, param in model.named_parameters():
         if param.requires_grad:
@@ -338,7 +331,7 @@ def _create_apollo_optimizer(
         if training_args.gradient_accumulation_steps != 1:
             raise ValueError("Per-layer APOLLO does not support gradient accumulation.")
 
-        optimizer_dict: Dict["torch.Tensor", "torch.optim.Optimizer"] = {}
+        optimizer_dict: dict[torch.Tensor, torch.optim.Optimizer] = {}
         for param in nodecay_params:
             param_groups = [dict(params=[param], weight_decay=0.0)]
             optimizer_dict[param] = optim_class(param_groups, **optim_kwargs)
@@ -380,7 +373,7 @@ def _create_loraplus_optimizer(
     embedding_lr = finetuning_args.loraplus_lr_embedding
 
     decay_param_names = _get_decay_parameter_names(model)
-    param_dict: Dict[str, List["torch.nn.Parameter"]] = {
+    param_dict: dict[str, list[torch.nn.Parameter]] = {
         "lora_a": [],
         "lora_b": [],
         "lora_b_nodecay": [],
@@ -524,7 +517,7 @@ def create_custom_scheduler(
 ) -> None:
     if optimizer is not None and isinstance(optimizer, DummyOptimizer):
         optimizer_dict = optimizer.optimizer_dict
-        scheduler_dict: Dict["torch.nn.Parameter", "torch.optim.lr_scheduler.LRScheduler"] = {}
+        scheduler_dict: dict[torch.nn.Parameter, torch.optim.lr_scheduler.LRScheduler] = {}
 
         for param in optimizer_dict.keys():
             scheduler_dict[param] = get_scheduler(
@@ -544,13 +537,13 @@ def create_custom_scheduler(
 def get_batch_logps(
     logits: "torch.Tensor", labels: "torch.Tensor", label_pad_token_id: int = IGNORE_INDEX
-) -> Tuple["torch.Tensor", "torch.Tensor"]:
-    r"""
-    Computes the log probabilities of the given labels under the given logits.
+) -> tuple["torch.Tensor", "torch.Tensor"]:
+    r"""Compute the log probabilities of the given labels under the given logits.
 
     Returns:
         logps: A tensor of shape (batch_size,) containing the sum of log probabilities.
         valid_length: A tensor of shape (batch_size,) containing the number of non-masked tokens.
     """
     if logits.shape[:-1] != labels.shape:
         raise ValueError("Logits (batchsize x seqlen) and labels must have the same shape.")
@@ -564,12 +557,10 @@ def get_batch_logps(
 def nested_detach(
-    tensors: Union["torch.Tensor", List["torch.Tensor"], Tuple["torch.Tensor"], Dict[str, "torch.Tensor"]],
+    tensors: Union["torch.Tensor", list["torch.Tensor"], tuple["torch.Tensor"], dict[str, "torch.Tensor"]],
     clone: bool = False,
 ):
-    r"""
-    Detach `tensors` (even if it's a nested list/tuple/dict of tensors).
-    """
+    r"""Detach `tensors` (even if it's a nested list/tuple/dict of tensors)."""
     if isinstance(tensors, (list, tuple)):
         return type(tensors)(nested_detach(t, clone=clone) for t in tensors)
     elif isinstance(tensors, Mapping):
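
Usage note: because the recursion rebuilds each container with type(tensors)(...), arbitrary nesting round-trips. A self-contained sketch of the same recursion plus an example; the _sketch suffix marks it as a stand-in, and namedtuples are out of scope:

    from collections.abc import Mapping
    from typing import Union

    import torch

    def nested_detach_sketch(tensors: Union[torch.Tensor, list, tuple, dict], clone: bool = False):
        if isinstance(tensors, (list, tuple)):
            return type(tensors)(nested_detach_sketch(t, clone=clone) for t in tensors)
        elif isinstance(tensors, Mapping):
            return type(tensors)({k: nested_detach_sketch(v, clone=clone) for k, v in tensors.items()})
        return tensors.detach().clone() if clone else tensors.detach()

    x = torch.ones(2, requires_grad=True)
    out = nested_detach_sketch({"a": [x, x * 2], "b": (x + 1,)}, clone=True)
    assert not out["a"][0].requires_grad  # leaves are detached copies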
@@ -585,9 +576,7 @@ def nested_detach(
 def get_swanlab_callback(finetuning_args: "FinetuningArguments") -> "TrainerCallback":
-    r"""
-    Gets the callback for logging to SwanLab.
-    """
+    r"""Get the callback for logging to SwanLab."""
     import swanlab  # type: ignore
     from swanlab.integration.transformers import SwanLabCallback  # type: ignore
@@ -624,7 +613,7 @@ def get_swanlab_callback(finetuning_args: "FinetuningArguments") -> "TrainerCall
 def get_ray_trainer(
     training_function: Callable,
-    train_loop_config: Dict[str, Any],
+    train_loop_config: dict[str, Any],
     ray_args: "RayArguments",
 ) -> "TorchTrainer":
     if not ray_args.use_ray:
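
For orientation, a hedged sketch of what constructing a Ray TorchTrainer generally looks like; the worker count and config values are illustrative and not taken from this diff:

    from typing import Any

    from ray.train import ScalingConfig
    from ray.train.torch import TorchTrainer

    def train_func(config: dict[str, Any]) -> None:
        ...  # per-worker loop; receives train_loop_config as `config`

    trainer = TorchTrainer(
        train_func,
        train_loop_config={"lr": 1e-4},
        scaling_config=ScalingConfig(num_workers=4, use_gpu=True),
    )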