[misc] upgrade format to py39 (#7256)
This commit is contained in:
@@ -13,14 +13,12 @@
|
||||
# limitations under the License.
|
||||
|
||||
from dataclasses import asdict, dataclass, field
|
||||
from typing import Any, Dict, List, Literal, Optional
|
||||
from typing import Any, Literal, Optional
|
||||
|
||||
|
||||
@dataclass
|
||||
class FreezeArguments:
|
||||
r"""
|
||||
Arguments pertaining to the freeze (partial-parameter) training.
|
||||
"""
|
||||
r"""Arguments pertaining to the freeze (partial-parameter) training."""
|
||||
|
||||
freeze_trainable_layers: int = field(
|
||||
default=2,
|
||||
@@ -56,9 +54,7 @@ class FreezeArguments:
|
||||
|
||||
@dataclass
|
||||
class LoraArguments:
|
||||
r"""
|
||||
Arguments pertaining to the LoRA training.
|
||||
"""
|
||||
r"""Arguments pertaining to the LoRA training."""
|
||||
|
||||
additional_target: Optional[str] = field(
|
||||
default=None,
|
||||
@@ -128,9 +124,7 @@ class LoraArguments:
|
||||
|
||||
@dataclass
|
||||
class RLHFArguments:
|
||||
r"""
|
||||
Arguments pertaining to the PPO, DPO and KTO training.
|
||||
"""
|
||||
r"""Arguments pertaining to the PPO, DPO and KTO training."""
|
||||
|
||||
pref_beta: float = field(
|
||||
default=0.1,
|
||||
@@ -212,9 +206,7 @@ class RLHFArguments:
|
||||
|
||||
@dataclass
|
||||
class GaloreArguments:
|
||||
r"""
|
||||
Arguments pertaining to the GaLore algorithm.
|
||||
"""
|
||||
r"""Arguments pertaining to the GaLore algorithm."""
|
||||
|
||||
use_galore: bool = field(
|
||||
default=False,
|
||||
@@ -253,9 +245,7 @@ class GaloreArguments:
|
||||
|
||||
@dataclass
|
||||
class ApolloArguments:
|
||||
r"""
|
||||
Arguments pertaining to the APOLLO algorithm.
|
||||
"""
|
||||
r"""Arguments pertaining to the APOLLO algorithm."""
|
||||
|
||||
use_apollo: bool = field(
|
||||
default=False,
|
||||
@@ -306,9 +296,7 @@ class ApolloArguments:
|
||||
|
||||
@dataclass
|
||||
class BAdamArgument:
|
||||
r"""
|
||||
Arguments pertaining to the BAdam optimizer.
|
||||
"""
|
||||
r"""Arguments pertaining to the BAdam optimizer."""
|
||||
|
||||
use_badam: bool = field(
|
||||
default=False,
|
||||
@@ -393,9 +381,7 @@ class SwanLabArguments:
|
||||
class FinetuningArguments(
|
||||
SwanLabArguments, BAdamArgument, ApolloArguments, GaloreArguments, RLHFArguments, LoraArguments, FreezeArguments
|
||||
):
|
||||
r"""
|
||||
Arguments pertaining to which techniques we are going to fine-tuning with.
|
||||
"""
|
||||
r"""Arguments pertaining to which techniques we are going to fine-tuning with."""
|
||||
|
||||
pure_bf16: bool = field(
|
||||
default=False,
|
||||
@@ -452,13 +438,13 @@ class FinetuningArguments(
|
||||
return [item.strip() for item in arg.split(",")]
|
||||
return arg
|
||||
|
||||
self.freeze_trainable_modules: List[str] = split_arg(self.freeze_trainable_modules)
|
||||
self.freeze_extra_modules: Optional[List[str]] = split_arg(self.freeze_extra_modules)
|
||||
self.freeze_trainable_modules: list[str] = split_arg(self.freeze_trainable_modules)
|
||||
self.freeze_extra_modules: Optional[list[str]] = split_arg(self.freeze_extra_modules)
|
||||
self.lora_alpha: int = self.lora_alpha or self.lora_rank * 2
|
||||
self.lora_target: List[str] = split_arg(self.lora_target)
|
||||
self.additional_target: Optional[List[str]] = split_arg(self.additional_target)
|
||||
self.galore_target: List[str] = split_arg(self.galore_target)
|
||||
self.apollo_target: List[str] = split_arg(self.apollo_target)
|
||||
self.lora_target: list[str] = split_arg(self.lora_target)
|
||||
self.additional_target: Optional[list[str]] = split_arg(self.additional_target)
|
||||
self.galore_target: list[str] = split_arg(self.galore_target)
|
||||
self.apollo_target: list[str] = split_arg(self.apollo_target)
|
||||
self.use_ref_model = self.stage == "dpo" and self.pref_loss not in ["orpo", "simpo"]
|
||||
|
||||
assert self.finetuning_type in ["lora", "freeze", "full"], "Invalid fine-tuning method."
|
||||
@@ -499,7 +485,7 @@ class FinetuningArguments(
|
||||
if self.pissa_init:
|
||||
raise ValueError("`pissa_init` is only valid for LoRA training.")
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
args = asdict(self)
|
||||
args = {k: f"<{k.upper()}>" if k.endswith("api_key") else v for k, v in args.items()}
|
||||
return args
|
||||
|
||||
Reference in New Issue
Block a user