[misc] upgrade format to py39 (#7256)

This commit is contained in:
hoshi-hiyouga
2025-03-12 00:08:41 +08:00
committed by GitHub
parent 5995800bce
commit 264538cb26
113 changed files with 984 additions and 1407 deletions

View File

@@ -13,14 +13,12 @@
# limitations under the License.
from dataclasses import asdict, dataclass, field
from typing import Any, Dict, List, Literal, Optional
from typing import Any, Literal, Optional
@dataclass
class FreezeArguments:
r"""
Arguments pertaining to the freeze (partial-parameter) training.
"""
r"""Arguments pertaining to the freeze (partial-parameter) training."""
freeze_trainable_layers: int = field(
default=2,
@@ -56,9 +54,7 @@ class FreezeArguments:
@dataclass
class LoraArguments:
r"""
Arguments pertaining to the LoRA training.
"""
r"""Arguments pertaining to the LoRA training."""
additional_target: Optional[str] = field(
default=None,
@@ -128,9 +124,7 @@ class LoraArguments:
@dataclass
class RLHFArguments:
r"""
Arguments pertaining to the PPO, DPO and KTO training.
"""
r"""Arguments pertaining to the PPO, DPO and KTO training."""
pref_beta: float = field(
default=0.1,
@@ -212,9 +206,7 @@ class RLHFArguments:
@dataclass
class GaloreArguments:
r"""
Arguments pertaining to the GaLore algorithm.
"""
r"""Arguments pertaining to the GaLore algorithm."""
use_galore: bool = field(
default=False,
@@ -253,9 +245,7 @@ class GaloreArguments:
@dataclass
class ApolloArguments:
r"""
Arguments pertaining to the APOLLO algorithm.
"""
r"""Arguments pertaining to the APOLLO algorithm."""
use_apollo: bool = field(
default=False,
@@ -306,9 +296,7 @@ class ApolloArguments:
@dataclass
class BAdamArgument:
r"""
Arguments pertaining to the BAdam optimizer.
"""
r"""Arguments pertaining to the BAdam optimizer."""
use_badam: bool = field(
default=False,
@@ -393,9 +381,7 @@ class SwanLabArguments:
class FinetuningArguments(
SwanLabArguments, BAdamArgument, ApolloArguments, GaloreArguments, RLHFArguments, LoraArguments, FreezeArguments
):
r"""
Arguments pertaining to which techniques we are going to fine-tuning with.
"""
r"""Arguments pertaining to which techniques we are going to fine-tune with."""
pure_bf16: bool = field(
default=False,
@@ -452,13 +438,13 @@ class FinetuningArguments(
return [item.strip() for item in arg.split(",")]
return arg
self.freeze_trainable_modules: List[str] = split_arg(self.freeze_trainable_modules)
self.freeze_extra_modules: Optional[List[str]] = split_arg(self.freeze_extra_modules)
self.freeze_trainable_modules: list[str] = split_arg(self.freeze_trainable_modules)
self.freeze_extra_modules: Optional[list[str]] = split_arg(self.freeze_extra_modules)
self.lora_alpha: int = self.lora_alpha or self.lora_rank * 2
self.lora_target: List[str] = split_arg(self.lora_target)
self.additional_target: Optional[List[str]] = split_arg(self.additional_target)
self.galore_target: List[str] = split_arg(self.galore_target)
self.apollo_target: List[str] = split_arg(self.apollo_target)
self.lora_target: list[str] = split_arg(self.lora_target)
self.additional_target: Optional[list[str]] = split_arg(self.additional_target)
self.galore_target: list[str] = split_arg(self.galore_target)
self.apollo_target: list[str] = split_arg(self.apollo_target)
self.use_ref_model = self.stage == "dpo" and self.pref_loss not in ["orpo", "simpo"]
assert self.finetuning_type in ["lora", "freeze", "full"], "Invalid fine-tuning method."
@@ -499,7 +485,7 @@ class FinetuningArguments(
if self.pissa_init:
raise ValueError("`pissa_init` is only valid for LoRA training.")
def to_dict(self) -> Dict[str, Any]:
def to_dict(self) -> dict[str, Any]:
args = asdict(self)
args = {k: f"<{k.upper()}>" if k.endswith("api_key") else v for k, v in args.items()}
return args