[misc] upgrade format to py39 (#7256)
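This commit moves the codebase to the generics syntax that Python 3.9 accepts natively (PEP 585): `Dict`, `List`, `Tuple`, and `Set` from `typing` become the builtin `dict`, `list`, `tuple`, and `set`, while `Optional` and `Union` stay in `typing`. A minimal sketch of the style change (illustrative only; the names below are not from this repository):

from typing import Optional

# pre-3.9 style: aliases imported from typing
# def split(pairs: Dict[str, int]) -> Tuple[List[str], List[int]]: ...

# py39 style: builtin types are subscriptable; Optional/Union still come from typing
def split(pairs: dict[str, int]) -> tuple[list[str], list[int]]:
    keys: list[str] = list(pairs.keys())
    values: list[int] = list(pairs.values())
    return keys, values

def first(items: list[str]) -> Optional[str]:
    return items[0] if items else None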
@@ -19,7 +19,7 @@ import sys
 import time
 from concurrent.futures import ThreadPoolExecutor
 from datetime import timedelta
-from typing import TYPE_CHECKING, Any, Dict, Optional
+from typing import TYPE_CHECKING, Any, Optional

 import torch
 import transformers
@@ -56,7 +56,8 @@ logger = logging.get_logger(__name__)
 def fix_valuehead_checkpoint(
     model: "AutoModelForCausalLMWithValueHead", output_dir: str, safe_serialization: bool
 ) -> None:
-    r"""
+    r"""Fix the valuehead checkpoint files.
+
     The model is already unwrapped.

     There are three cases:
@@ -72,10 +73,10 @@ def fix_valuehead_checkpoint(
     if safe_serialization:
         path_to_checkpoint = os.path.join(output_dir, SAFE_WEIGHTS_NAME)
         with safe_open(path_to_checkpoint, framework="pt", device="cpu") as f:
-            state_dict: Dict[str, torch.Tensor] = {key: f.get_tensor(key) for key in f.keys()}
+            state_dict: dict[str, torch.Tensor] = {key: f.get_tensor(key) for key in f.keys()}
     else:
         path_to_checkpoint = os.path.join(output_dir, WEIGHTS_NAME)
-        state_dict: Dict[str, torch.Tensor] = torch.load(path_to_checkpoint, map_location="cpu")
+        state_dict: dict[str, torch.Tensor] = torch.load(path_to_checkpoint, map_location="cpu")

     os.remove(path_to_checkpoint)
     decoder_state_dict, v_head_state_dict = {}, {}
@@ -98,9 +99,7 @@ def fix_valuehead_checkpoint(


 class FixValueHeadModelCallback(TrainerCallback):
-    r"""
-    A callback for fixing the checkpoint for valuehead models.
-    """
+    r"""A callback for fixing the checkpoint for valuehead models."""

     @override
     def on_save(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
@@ -112,9 +111,7 @@ class FixValueHeadModelCallback(TrainerCallback):


 class SaveProcessorCallback(TrainerCallback):
-    r"""
-    A callback for saving the processor.
-    """
+    r"""A callback for saving the processor."""

     def __init__(self, processor: "ProcessorMixin") -> None:
         self.processor = processor
@@ -132,9 +129,7 @@ class SaveProcessorCallback(TrainerCallback):


 class PissaConvertCallback(TrainerCallback):
-    r"""
-    A callback for converting the PiSSA adapter to a normal one.
-    """
+    r"""A callback for converting the PiSSA adapter to a normal one."""

     @override
     def on_train_begin(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
@@ -177,9 +172,7 @@ class PissaConvertCallback(TrainerCallback):


 class LogCallback(TrainerCallback):
-    r"""
-    A callback for logging training and evaluation status.
-    """
+    r"""A callback for logging training and evaluation status."""

     def __init__(self) -> None:
         # Progress
@@ -188,7 +181,7 @@ class LogCallback(TrainerCallback):
         self.max_steps = 0
         self.elapsed_time = ""
         self.remaining_time = ""
-        self.thread_pool: Optional["ThreadPoolExecutor"] = None
+        self.thread_pool: Optional[ThreadPoolExecutor] = None
         # Status
         self.aborted = False
         self.do_train = False
@@ -219,7 +212,7 @@ class LogCallback(TrainerCallback):
         self.elapsed_time = str(timedelta(seconds=int(elapsed_time)))
         self.remaining_time = str(timedelta(seconds=int(remaining_time)))

-    def _write_log(self, output_dir: str, logs: Dict[str, Any]) -> None:
+    def _write_log(self, output_dir: str, logs: dict[str, Any]) -> None:
         with open(os.path.join(output_dir, TRAINER_LOG), "a", encoding="utf-8") as f:
             f.write(json.dumps(logs) + "\n")

@@ -348,9 +341,7 @@ class LogCallback(TrainerCallback):


 class ReporterCallback(TrainerCallback):
-    r"""
-    A callback for reporting training status to external logger.
-    """
+    r"""A callback for reporting training status to external logger."""

     def __init__(
         self,
@@ -19,7 +19,7 @@ import warnings
 from collections import defaultdict
 from contextlib import nullcontext
 from types import MethodType
-from typing import TYPE_CHECKING, Dict, List, Literal, Optional, Tuple, Union
+from typing import TYPE_CHECKING, Literal, Optional, Union

 import torch
 import torch.nn.functional as F
@@ -129,15 +129,11 @@ class CustomDPOTrainer(DPOTrainer):

     @override
     def get_batch_samples(self, epoch_iterator, num_batches):
-        r"""
-        Replaces the method of KTO Trainer with the one of the standard Trainer.
-        """
+        r"""Replace the method of DPO Trainer with the one of the standard Trainer."""
         return Trainer.get_batch_samples(self, epoch_iterator, num_batches)

     def odds_ratio_loss(self, chosen_logps: "torch.Tensor", rejected_logps: "torch.Tensor") -> "torch.Tensor":
-        r"""
-        Computes ORPO's odds ratio (OR) loss for batched log probabilities of the policy model.
-        """
+        r"""Compute ORPO's odds ratio (OR) loss for batched log probabilities of the policy model."""
         log_odds = (chosen_logps - rejected_logps) - (
             torch.log1p(-torch.exp(chosen_logps)) - torch.log1p(-torch.exp(rejected_logps))
         )
@@ -147,9 +143,7 @@ class CustomDPOTrainer(DPOTrainer):
         return orpo_loss

     def simpo_loss(self, chosen_logps: "torch.Tensor", rejected_logps: "torch.Tensor") -> "torch.Tensor":
-        r"""
-        Computes SimPO loss for batched log probabilities of the policy model.
-        """
+        r"""Compute SimPO loss for batched log probabilities of the policy model."""
         pi_logratios = chosen_logps - rejected_logps
         gamma_logratios = self.simpo_gamma / self.beta
         logits = pi_logratios - gamma_logratios
@@ -162,10 +156,8 @@ class CustomDPOTrainer(DPOTrainer):
         policy_rejected_logps: "torch.Tensor",
         reference_chosen_logps: Optional["torch.Tensor"],
         reference_rejected_logps: Optional["torch.Tensor"],
-    ) -> Tuple["torch.Tensor", "torch.Tensor", "torch.Tensor"]:
-        r"""
-        Computes loss for preference learning.
-        """
+    ) -> tuple["torch.Tensor", "torch.Tensor", "torch.Tensor"]:
+        r"""Compute loss for preference learning."""
         if not self.finetuning_args.use_ref_model:
             if self.loss_type == "orpo":
                 losses = self.odds_ratio_loss(policy_chosen_logps, policy_rejected_logps)
@@ -185,17 +177,16 @@ class CustomDPOTrainer(DPOTrainer):

     @override
     def concatenated_forward(
-        self, model: "PreTrainedModel", batch: Dict[str, "torch.Tensor"]
-    ) -> Tuple["torch.Tensor", "torch.Tensor", "torch.Tensor", "torch.Tensor", "torch.Tensor"]:
-        r"""
-        Computes the sum log probabilities of the labels under given logits if loss_type is not IPO, ORPO or SimPO.
+        self, model: "PreTrainedModel", batch: dict[str, "torch.Tensor"]
+    ) -> tuple["torch.Tensor", "torch.Tensor", "torch.Tensor", "torch.Tensor", "torch.Tensor"]:
+        r"""Compute the sum log probabilities of the labels under given logits if loss_type is not IPO, ORPO or SimPO.

         Otherwise the average log probabilities.
         """
         if self.finetuning_args.use_ref_model:
             batch = nested_detach(batch, clone=True)  # avoid error

-        all_logits: "torch.Tensor" = model(**batch, return_dict=True, use_cache=False).logits.to(torch.float32)
+        all_logits: torch.Tensor = model(**batch, return_dict=True, use_cache=False).logits.to(torch.float32)
         all_logps, valid_length = get_batch_logps(logits=all_logits, labels=batch["labels"])
         if self.loss_type in ["ipo", "orpo", "simpo"]:
             all_logps = all_logps / valid_length
@@ -212,11 +203,9 @@ class CustomDPOTrainer(DPOTrainer):

     @override
     def compute_reference_log_probs(
-        self, model: "PreTrainedModel", batch: Dict[str, "torch.Tensor"]
-    ) -> Tuple[Optional["torch.Tensor"], Optional["torch.Tensor"]]:
-        r"""
-        Computes log probabilities of the reference model.
-        """
+        self, model: "PreTrainedModel", batch: dict[str, "torch.Tensor"]
+    ) -> tuple[Optional["torch.Tensor"], Optional["torch.Tensor"]]:
+        r"""Compute log probabilities of the reference model."""
         if not self.finetuning_args.use_ref_model:
             return None, None

@@ -236,12 +225,10 @@ class CustomDPOTrainer(DPOTrainer):
     def get_batch_loss_metrics(
         self,
         model: "PreTrainedModel",
-        batch: Dict[str, "torch.Tensor"],
+        batch: dict[str, "torch.Tensor"],
         train_eval: Literal["train", "eval"] = "train",
-    ) -> Tuple["torch.Tensor", Dict[str, "torch.Tensor"]]:
-        r"""
-        Computes the DPO loss and other metrics for the given batch of inputs for train or test.
-        """
+    ) -> tuple["torch.Tensor", dict[str, "torch.Tensor"]]:
+        r"""Compute the DPO loss and other metrics for the given batch of inputs for train or test."""
         metrics = {}
         (
             policy_chosen_logps,
@@ -279,18 +266,14 @@ class CustomDPOTrainer(DPOTrainer):

     @override
     def compute_loss(
-        self, model: "PreTrainedModel", inputs: Dict[str, "torch.Tensor"], return_outputs: bool = False, **kwargs
-    ) -> Union["torch.Tensor", Tuple["torch.Tensor", List["torch.Tensor"]]]:
-        r"""
-        Subclass and override to accept extra kwargs.
-        """
+        self, model: "PreTrainedModel", inputs: dict[str, "torch.Tensor"], return_outputs: bool = False, **kwargs
+    ) -> Union["torch.Tensor", tuple["torch.Tensor", list["torch.Tensor"]]]:
+        r"""Subclass and override to accept extra kwargs."""
         return super().compute_loss(model, inputs, return_outputs)

     @override
-    def log(self, logs: Dict[str, float], *args, **kwargs) -> None:
-        r"""
-        Log `logs` on the various objects watching training, including stored metrics.
-        """
+    def log(self, logs: dict[str, float], *args, **kwargs) -> None:
+        r"""Log `logs` on the various objects watching training, including stored metrics."""
         # logs either has "loss" or "eval_loss"
         train_eval = "train" if "loss" in logs else "eval"
         # Add averaged stored metrics to logs
@@ -15,7 +15,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from typing import TYPE_CHECKING, List, Optional
+from typing import TYPE_CHECKING, Optional

 from ...data import PairwiseDataCollatorWithPadding, get_dataset, get_template_and_fix_tokenizer
 from ...extras.constants import IGNORE_INDEX
@@ -38,7 +38,7 @@ def run_dpo(
     data_args: "DataArguments",
     training_args: "Seq2SeqTrainingArguments",
     finetuning_args: "FinetuningArguments",
-    callbacks: Optional[List["TrainerCallback"]] = None,
+    callbacks: Optional[list["TrainerCallback"]] = None,
 ):
     tokenizer_module = load_tokenizer(model_args)
     tokenizer = tokenizer_module["tokenizer"]
@@ -19,7 +19,7 @@ import warnings
 from collections import defaultdict
 from contextlib import nullcontext
 from types import MethodType
-from typing import TYPE_CHECKING, Dict, List, Literal, Optional, Tuple, Union
+from typing import TYPE_CHECKING, Literal, Optional, Union

 import torch
 from transformers import Trainer
@@ -120,9 +120,7 @@ class CustomKTOTrainer(KTOTrainer):

     @override
     def _get_train_sampler(self) -> Optional["torch.utils.data.Sampler"]:
-        r"""
-        Replaces the sequential sampler of KTO Trainer created by trl with the random sampler.
-        """
+        r"""Replace the sequential sampler of KTO Trainer created by trl with the random sampler."""
         if self.finetuning_args.disable_shuffling:
             return torch.utils.data.SequentialSampler(self.train_dataset)

@@ -130,18 +128,14 @@ class CustomKTOTrainer(KTOTrainer):

     @override
     def get_batch_samples(self, epoch_iterator, num_batches):
-        r"""
-        Replaces the method of KTO Trainer with the one of the standard Trainer.
-        """
+        r"""Replace the method of KTO Trainer with the one of the standard Trainer."""
         return Trainer.get_batch_samples(self, epoch_iterator, num_batches)

     @override
     def forward(
-        self, model: "PreTrainedModel", batch: Dict[str, "torch.Tensor"], prefix: Literal["", "kl_"] = ""
-    ) -> Tuple["torch.Tensor", "torch.Tensor", "torch.Tensor"]:
-        r"""
-        Runs forward pass and computes the log probabilities.
-        """
+        self, model: "PreTrainedModel", batch: dict[str, "torch.Tensor"], prefix: Literal["", "kl_"] = ""
+    ) -> tuple["torch.Tensor", "torch.Tensor", "torch.Tensor"]:
+        r"""Run forward pass and computes the log probabilities."""
         batch = nested_detach(batch, clone=True)  # avoid error
         model_inputs = {
             "input_ids": batch[f"{prefix}input_ids"],
@@ -171,8 +165,8 @@ class CustomKTOTrainer(KTOTrainer):

     @override
     def concatenated_forward(
-        self, model: "PreTrainedModel", batch: Dict[str, "torch.Tensor"]
-    ) -> Tuple["torch.Tensor", "torch.Tensor", "torch.Tensor", "torch.Tensor", "torch.Tensor", "torch.Tensor"]:
+        self, model: "PreTrainedModel", batch: dict[str, "torch.Tensor"]
+    ) -> tuple["torch.Tensor", "torch.Tensor", "torch.Tensor", "torch.Tensor", "torch.Tensor", "torch.Tensor"]:
         target_logits, target_logps, target_logps_avg = self.forward(model, batch)
         with torch.no_grad():
             _, kl_logps, _ = self.forward(model, batch, prefix="kl_")
@@ -189,11 +183,9 @@ class CustomKTOTrainer(KTOTrainer):

     @override
     def compute_reference_log_probs(
-        self, model: "PreTrainedModel", batch: Dict[str, "torch.Tensor"]
-    ) -> Tuple["torch.Tensor", "torch.Tensor", "torch.Tensor"]:
-        r"""
-        Computes log probabilities of the reference model.
-        """
+        self, model: "PreTrainedModel", batch: dict[str, "torch.Tensor"]
+    ) -> tuple["torch.Tensor", "torch.Tensor", "torch.Tensor"]:
+        r"""Compute log probabilities of the reference model."""
         if self.ref_model is None:
             ref_model = model
             ref_context = self.accelerator.unwrap_model(model).disable_adapter()
@@ -212,11 +204,9 @@ class CustomKTOTrainer(KTOTrainer):
     def get_batch_loss_metrics(
         self,
         model: "PreTrainedModel",
-        batch: Dict[str, "torch.Tensor"],
-    ) -> Tuple["torch.Tensor", Dict[str, "torch.Tensor"]]:
-        r"""
-        Computes the DPO loss and other metrics for the given batch of inputs for train or test.
-        """
+        batch: dict[str, "torch.Tensor"],
+    ) -> tuple["torch.Tensor", dict[str, "torch.Tensor"]]:
+        r"""Compute the DPO loss and other metrics for the given batch of inputs for train or test."""
         metrics = {}
         (
             policy_chosen_logps,
@@ -262,18 +252,14 @@ class CustomKTOTrainer(KTOTrainer):

     @override
     def compute_loss(
-        self, model: "PreTrainedModel", inputs: Dict[str, "torch.Tensor"], return_outputs: bool = False, **kwargs
-    ) -> Union["torch.Tensor", Tuple["torch.Tensor", List["torch.Tensor"]]]:
-        r"""
-        Subclass and override to accept extra kwargs.
-        """
+        self, model: "PreTrainedModel", inputs: dict[str, "torch.Tensor"], return_outputs: bool = False, **kwargs
+    ) -> Union["torch.Tensor", tuple["torch.Tensor", list["torch.Tensor"]]]:
+        r"""Subclass and override to accept extra kwargs."""
         return super().compute_loss(model, inputs, return_outputs)

     @override
-    def log(self, logs: Dict[str, float], *args, **kwargs) -> None:
-        r"""
-        Log `logs` on the various objects watching training, including stored metrics.
-        """
+    def log(self, logs: dict[str, float], *args, **kwargs) -> None:
+        r"""Log `logs` on the various objects watching training, including stored metrics."""
         # logs either has "loss" or "eval_loss"
         train_eval = "train" if "loss" in logs else "eval"
         prefix = "eval_" if train_eval == "eval" else ""
@@ -291,7 +277,7 @@ class CustomKTOTrainer(KTOTrainer):

         metric_list = torch.tensor(metric_list, dtype=torch.float).to(self.accelerator.device)
         metric_list = self.accelerator.reduce(metric_list, "sum").tolist()
-        metric_dict: Dict[str, float] = dict(zip(key_list, metric_list))
+        metric_dict: dict[str, float] = dict(zip(key_list, metric_list))
         for split in ["chosen", "rejected"]:  # accumulate average metrics from sums and lengths
             if f"count/{split}" in metric_dict:
                 for key in ("rewards", "logps", "logits"):
@@ -15,7 +15,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from typing import TYPE_CHECKING, List, Optional
+from typing import TYPE_CHECKING, Optional

 from ...data import KTODataCollatorWithPadding, get_dataset, get_template_and_fix_tokenizer
 from ...extras.constants import IGNORE_INDEX
@@ -37,7 +37,7 @@ def run_kto(
     data_args: "DataArguments",
     training_args: "Seq2SeqTrainingArguments",
     finetuning_args: "FinetuningArguments",
-    callbacks: Optional[List["TrainerCallback"]] = None,
+    callbacks: Optional[list["TrainerCallback"]] = None,
 ):
     tokenizer_module = load_tokenizer(model_args)
     tokenizer = tokenizer_module["tokenizer"]
@@ -14,7 +14,7 @@

 import json
 from contextlib import nullcontext
-from typing import TYPE_CHECKING, Dict, List, Literal, Optional
+from typing import TYPE_CHECKING, Literal, Optional

 import torch
 from transformers.integrations import is_deepspeed_zero3_enabled
@@ -31,10 +31,8 @@ if TYPE_CHECKING:
     from trl import AutoModelForCausalLMWithValueHead


-def get_rewards_from_server(server_url: str, messages: List[str]) -> List["torch.Tensor"]:
-    r"""
-    Gets reward scores from the API server.
-    """
+def get_rewards_from_server(server_url: str, messages: list[str]) -> list["torch.Tensor"]:
+    r"""Get reward scores from the API server."""
     headers = {"Content-Type": "application/json"}
     payload = {"model": "model", "messages": messages}
     response = requests.post(server_url, json=payload, headers=headers)
@@ -43,9 +41,7 @@ def get_rewards_from_server(server_url: str, messages: List[str]) -> List["torch.Tensor"]:


 def replace_model(model: "AutoModelForCausalLMWithValueHead", target: Literal["default", "reward"]) -> None:
-    r"""
-    Replaces the default/reward modules in the model. The model is already unwrapped.
-    """
+    r"""Replace the default/reward modules in the model. The model is already unwrapped."""
     v_head_layer = model.v_head.summary
     if is_deepspeed_zero3_enabled():
         import deepspeed  # type: ignore
@@ -66,10 +62,8 @@ def replace_model(model: "AutoModelForCausalLMWithValueHead", target: Literal["default", "reward"]) -> None:
     v_head_layer.bias.data = model.get_buffer(f"{target}_head_bias").detach().clone().to(device)


-def dump_layernorm(model: "PreTrainedModel") -> Dict[str, "torch.Tensor"]:
-    r"""
-    Dumps the layernorm parameters in the model. The model is already unwrapped (and gathered).
-    """
+def dump_layernorm(model: "PreTrainedModel") -> dict[str, "torch.Tensor"]:
+    r"""Dump the layernorm parameters in the model. The model is already unwrapped (and gathered)."""
     layer_norm_params = {}
     for name, param in model.named_parameters():
         if param.data.dtype == torch.float32:
@@ -79,10 +73,8 @@ def dump_layernorm(model: "PreTrainedModel") -> Dict[str, "torch.Tensor"]:
     return layer_norm_params


-def restore_layernorm(model: "PreTrainedModel", layernorm_params: Optional[Dict[str, "torch.Tensor"]] = None) -> None:
-    r"""
-    Restores the layernorm parameters in the model. The model is already unwrapped (and gathered).
-    """
+def restore_layernorm(model: "PreTrainedModel", layernorm_params: Optional[dict[str, "torch.Tensor"]] = None) -> None:
+    r"""Restore the layernorm parameters in the model. The model is already unwrapped (and gathered)."""
     for name, param in model.named_parameters():
         if name in layernorm_params:
             param.data = layernorm_params[name]
@@ -20,7 +20,7 @@ import os
 import sys
 import warnings
 from types import MethodType
-from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
+from typing import TYPE_CHECKING, Any, Optional

 import torch
 from accelerate.utils import DistributedDataParallelKwargs
@@ -62,9 +62,7 @@ logger = logging.get_logger(__name__)


 class CustomPPOTrainer(PPOTrainer, Trainer):
-    r"""
-    Inherits PPOTrainer.
-    """
+    r"""Inherit PPOTrainer."""

     def __init__(
         self,
@@ -72,7 +70,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
         training_args: "Seq2SeqTrainingArguments",
         finetuning_args: "FinetuningArguments",
         generating_args: "GeneratingArguments",
-        callbacks: Optional[List["TrainerCallback"]],
+        callbacks: Optional[list["TrainerCallback"]],
         model: "AutoModelForCausalLMWithValueHead",
         reward_model: Optional["AutoModelForCausalLMWithValueHead"],
         ref_model: Optional["AutoModelForCausalLMWithValueHead"],
@@ -187,9 +185,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
             self.add_callback(BAdamCallback)

     def ppo_train(self, resume_from_checkpoint: Optional[str] = None) -> None:
-        r"""
-        Implements training loop for the PPO stage, like _inner_training_loop() in Huggingface's Trainer.
-        """
+        r"""Implement training loop for the PPO stage, like _inner_training_loop() in Huggingface's Trainer."""
         if resume_from_checkpoint is not None:
             raise ValueError("`resume_from_checkpoint` will be supported in the future version.")

@@ -221,9 +217,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
         logger.info_rank0(f"  Num Epochs = {num_train_epochs:,}")
         logger.info_rank0(f"  Instantaneous batch size per device = {self.args.per_device_train_batch_size:,}")
         logger.info_rank0(
-            "  Total train batch size (w. parallel, buffer, distributed & accumulation) = {:,}".format(
-                total_train_batch_size
-            )
+            f"  Total train batch size (w. parallel, buffer, distributed & accumulation) = {total_train_batch_size:,}"
         )
         logger.info_rank0(f"  Gradient Accumulation steps = {self.args.gradient_accumulation_steps:,}")
         logger.info_rank0(f"  Num optimization epochs per batch = {self.finetuning_args.ppo_epochs:,}")
@@ -339,21 +333,19 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
         return lr_scheduler

     @torch.no_grad()
-    def get_inputs(self, batch: Dict[str, "torch.Tensor"]) -> Tuple[List["torch.Tensor"], List["torch.Tensor"]]:
-        r"""
-        Generates model's responses given queries.
-        """
+    def get_inputs(self, batch: dict[str, "torch.Tensor"]) -> tuple[list["torch.Tensor"], list["torch.Tensor"]]:
+        r"""Generate model's responses given queries."""
         if batch["input_ids"].size(0) == 1:  # handle llama2 ppo with gradient accumulation > 1
             start_index = (batch["input_ids"][0] != self.tokenizer.pad_token_id).nonzero()[0].item()
             for k, v in batch.items():
                 batch[k] = v[:, start_index:]

         with unwrap_model_for_generation(self.model, self.accelerator) as unwrapped_model:
-            unwrapped_model: "AutoModelForCausalLMWithValueHead" = self.accelerator.unwrap_model(self.model)
+            unwrapped_model: AutoModelForCausalLMWithValueHead = self.accelerator.unwrap_model(self.model)
             if self.model_args.upcast_layernorm:
                 layernorm_params = dump_layernorm(unwrapped_model)

-            generate_output: "torch.Tensor" = unwrapped_model.generate(
+            generate_output: torch.Tensor = unwrapped_model.generate(
                 generation_config=self.generation_config, logits_processor=get_logits_processor(), **batch
             )
             if self.model_args.upcast_layernorm:
@@ -381,11 +373,10 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
     @torch.no_grad()
     def get_rewards(
         self,
-        queries: List["torch.Tensor"],
-        responses: List["torch.Tensor"],
-    ) -> List["torch.Tensor"]:
-        r"""
-        Computes scores using given reward model.
+        queries: list["torch.Tensor"],
+        responses: list["torch.Tensor"],
+    ) -> list["torch.Tensor"]:
+        r"""Compute scores using given reward model.

         Both inputs and outputs are put on CPU.
         """
@@ -394,8 +385,8 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
             messages = self.tokenizer.batch_decode(token_ids, skip_special_tokens=False)
             return get_rewards_from_server(self.reward_model, messages)

-        batch: Dict[str, "torch.Tensor"] = self.prepare_model_inputs(queries, responses)
-        unwrapped_model: "AutoModelForCausalLMWithValueHead" = self.accelerator.unwrap_model(self.model)
+        batch: dict[str, torch.Tensor] = self.prepare_model_inputs(queries, responses)
+        unwrapped_model: AutoModelForCausalLMWithValueHead = self.accelerator.unwrap_model(self.model)

         if self.finetuning_args.reward_model_type == "lora":
             replace_model(unwrapped_model, target="reward")
@@ -404,7 +395,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
             reward_model = self.reward_model

         with unwrap_model_for_generation(reward_model, self.accelerator), self.amp_context:  # support bf16
-            values: "torch.Tensor" = reward_model(**batch, return_dict=True, use_cache=False)[-1]
+            values: torch.Tensor = reward_model(**batch, return_dict=True, use_cache=False)[-1]

         if self.finetuning_args.reward_model_type == "lora":
             replace_model(unwrapped_model, target="default")
@@ -419,12 +410,11 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
         model: "AutoModelForCausalLMWithValueHead",
         queries: "torch.Tensor",
         responses: "torch.Tensor",
-        model_inputs: Dict[str, Any],
+        model_inputs: dict[str, Any],
         return_logits: bool = False,
         response_masks: Optional["torch.Tensor"] = None,
-    ) -> Tuple["torch.Tensor", Optional["torch.Tensor"], "torch.Tensor", "torch.Tensor"]:
-        r"""
-        Calculates model outputs in multiple batches.
+    ) -> tuple["torch.Tensor", Optional["torch.Tensor"], "torch.Tensor", "torch.Tensor"]:
+        r"""Calculate model outputs in multiple batches.

         Subclass and override to inject custom behavior.
         """
@@ -483,8 +473,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):

     @override
     def save_model(self, output_dir: Optional[str] = None) -> None:
-        r"""
-        Saves model checkpoint.
+        r"""Save model checkpoint.

         Subclass and override to inject custom behavior.
         """
@@ -508,5 +497,5 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
             self.model.save_checkpoint(output_dir)

         elif self.args.should_save:
-            unwrapped_model: "AutoModelForCausalLMWithValueHead" = self.accelerator.unwrap_model(self.model)
+            unwrapped_model: AutoModelForCausalLMWithValueHead = self.accelerator.unwrap_model(self.model)
             self._save(output_dir, state_dict=unwrapped_model.state_dict())
@@ -15,7 +15,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from typing import TYPE_CHECKING, List, Optional
+from typing import TYPE_CHECKING, Optional

 from ...data import MultiModalDataCollatorForSeq2Seq, get_dataset, get_template_and_fix_tokenizer
 from ...extras.ploting import plot_loss
@@ -37,7 +37,7 @@ def run_ppo(
     training_args: "Seq2SeqTrainingArguments",
     finetuning_args: "FinetuningArguments",
     generating_args: "GeneratingArguments",
-    callbacks: Optional[List["TrainerCallback"]] = None,
+    callbacks: Optional[list["TrainerCallback"]] = None,
 ):
     tokenizer_module = load_tokenizer(model_args)
     tokenizer = tokenizer_module["tokenizer"]
@@ -53,7 +53,7 @@ def run_ppo(
     reward_model = create_reward_model(model, model_args, finetuning_args)

     # Initialize our Trainer
-    ppo_trainer: "CustomPPOTrainer" = CustomPPOTrainer(
+    ppo_trainer: CustomPPOTrainer = CustomPPOTrainer(
         model_args=model_args,
         training_args=training_args,
         finetuning_args=finetuning_args,
@@ -31,9 +31,7 @@ if TYPE_CHECKING:


 class CustomTrainer(Trainer):
-    r"""
-    Inherits Trainer for custom optimizer.
-    """
+    r"""Inherit Trainer for custom optimizer."""

     def __init__(
         self, finetuning_args: "FinetuningArguments", processor: Optional["ProcessorMixin"], **kwargs
@@ -16,7 +16,7 @@
 # limitations under the License.

 import math
-from typing import TYPE_CHECKING, List, Optional
+from typing import TYPE_CHECKING, Optional

 from transformers import DataCollatorForLanguageModeling

@@ -38,7 +38,7 @@ def run_pt(
     data_args: "DataArguments",
     training_args: "Seq2SeqTrainingArguments",
     finetuning_args: "FinetuningArguments",
-    callbacks: Optional[List["TrainerCallback"]] = None,
+    callbacks: Optional[list["TrainerCallback"]] = None,
 ):
     tokenizer_module = load_tokenizer(model_args)
     tokenizer = tokenizer_module["tokenizer"]
@@ -13,7 +13,7 @@
 # limitations under the License.

 from dataclasses import dataclass
-from typing import TYPE_CHECKING, Dict, Optional
+from typing import TYPE_CHECKING, Optional

 import numpy as np

@@ -26,11 +26,9 @@ if TYPE_CHECKING:

 @dataclass
 class ComputeAccuracy:
-    r"""
-    Computes reward accuracy and supports `batch_eval_metrics`.
-    """
+    r"""Compute reward accuracy and support `batch_eval_metrics`."""

-    def _dump(self) -> Optional[Dict[str, float]]:
+    def _dump(self) -> Optional[dict[str, float]]:
         result = None
         if hasattr(self, "score_dict"):
             result = {k: float(np.mean(v)) for k, v in self.score_dict.items()}
@@ -41,7 +39,7 @@ class ComputeAccuracy:
     def __post_init__(self):
         self._dump()

-    def __call__(self, eval_preds: "EvalPrediction", compute_result: bool = True) -> Optional[Dict[str, float]]:
+    def __call__(self, eval_preds: "EvalPrediction", compute_result: bool = True) -> Optional[dict[str, float]]:
         chosen_scores, rejected_scores = numpify(eval_preds.predictions[0]), numpify(eval_preds.predictions[1])
         if not chosen_scores.shape:
             self.score_dict["accuracy"].append(chosen_scores > rejected_scores)
@@ -18,7 +18,7 @@
 import json
 import os
 from types import MethodType
-from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
+from typing import TYPE_CHECKING, Optional, Union

 import torch
 from transformers import Trainer
@@ -41,9 +41,7 @@ logger = logging.get_logger(__name__)


 class PairwiseTrainer(Trainer):
-    r"""
-    Inherits Trainer to compute pairwise loss.
-    """
+    r"""Inherits Trainer to compute pairwise loss."""

     def __init__(
         self, finetuning_args: "FinetuningArguments", processor: Optional["ProcessorMixin"], **kwargs
@@ -88,10 +86,9 @@ class PairwiseTrainer(Trainer):

     @override
     def compute_loss(
-        self, model: "PreTrainedModel", inputs: Dict[str, "torch.Tensor"], return_outputs: bool = False, **kwargs
-    ) -> Union["torch.Tensor", Tuple["torch.Tensor", List["torch.Tensor"]]]:
-        r"""
-        Computes pairwise loss. The first n examples are chosen and the last n examples are rejected.
+        self, model: "PreTrainedModel", inputs: dict[str, "torch.Tensor"], return_outputs: bool = False, **kwargs
+    ) -> Union["torch.Tensor", tuple["torch.Tensor", list["torch.Tensor"]]]:
+        r"""Compute pairwise loss. The first n examples are chosen and the last n examples are rejected.

         Subclass and override to inject custom behavior.

@@ -113,8 +110,7 @@ class PairwiseTrainer(Trainer):
         return loss

     def save_predictions(self, predict_results: "PredictionOutput") -> None:
-        r"""
-        Saves model predictions to `output_dir`.
+        r"""Save model predictions to `output_dir`.

         A custom behavior that not contained in Seq2SeqTrainer.
         """
@@ -126,7 +122,7 @@ class PairwiseTrainer(Trainer):
         chosen_scores, rejected_scores = predict_results.predictions

         with open(output_prediction_file, "w", encoding="utf-8") as writer:
-            res: List[str] = []
+            res: list[str] = []
             for c_score, r_score in zip(chosen_scores, rejected_scores):
                 res.append(json.dumps({"chosen": round(float(c_score), 2), "rejected": round(float(r_score), 2)}))

@@ -15,7 +15,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from typing import TYPE_CHECKING, List, Optional
+from typing import TYPE_CHECKING, Optional

 from ...data import PairwiseDataCollatorWithPadding, get_dataset, get_template_and_fix_tokenizer
 from ...extras.ploting import plot_loss
@@ -37,7 +37,7 @@ def run_rm(
     data_args: "DataArguments",
     training_args: "Seq2SeqTrainingArguments",
     finetuning_args: "FinetuningArguments",
-    callbacks: Optional[List["TrainerCallback"]] = None,
+    callbacks: Optional[list["TrainerCallback"]] = None,
 ):
     tokenizer_module = load_tokenizer(model_args)
     tokenizer = tokenizer_module["tokenizer"]
@@ -17,7 +17,7 @@
 # limitations under the License.

 from dataclasses import dataclass
-from typing import TYPE_CHECKING, Dict, Optional
+from typing import TYPE_CHECKING, Optional

 import numpy as np
 import torch
@@ -45,9 +45,7 @@ if is_rouge_available():


 def eval_logit_processor(logits: "torch.Tensor", labels: "torch.Tensor") -> "torch.Tensor":
-    r"""
-    Computes the token with the largest likelihood to reduce memory footprint.
-    """
+    r"""Compute the token with the largest likelihood to reduce memory footprint."""
     if isinstance(logits, (list, tuple)):
         if logits[0].dim() == 3:  # (batch_size, seq_len, vocab_size)
             logits = logits[0]
@@ -62,11 +60,9 @@ def eval_logit_processor(logits: "torch.Tensor", labels: "torch.Tensor") -> "torch.Tensor":

 @dataclass
 class ComputeAccuracy:
-    r"""
-    Computes accuracy and supports `batch_eval_metrics`.
-    """
+    r"""Compute accuracy and support `batch_eval_metrics`."""

-    def _dump(self) -> Optional[Dict[str, float]]:
+    def _dump(self) -> Optional[dict[str, float]]:
         result = None
         if hasattr(self, "score_dict"):
             result = {k: float(np.mean(v)) for k, v in self.score_dict.items()}
@@ -77,7 +73,7 @@ class ComputeAccuracy:
     def __post_init__(self):
         self._dump()

-    def __call__(self, eval_preds: "EvalPrediction", compute_result: bool = True) -> Optional[Dict[str, float]]:
+    def __call__(self, eval_preds: "EvalPrediction", compute_result: bool = True) -> Optional[dict[str, float]]:
         preds, labels = numpify(eval_preds.predictions), numpify(eval_preds.label_ids)
         for i in range(len(preds)):
             pred, label = preds[i, :-1], labels[i, 1:]
@@ -90,15 +86,14 @@ class ComputeAccuracy:

 @dataclass
 class ComputeSimilarity:
-    r"""
-    Computes text similarity scores and supports `batch_eval_metrics`.
+    r"""Compute text similarity scores and support `batch_eval_metrics`.

     Wraps the tokenizer into metric functions, used in CustomSeq2SeqTrainer.
     """

     tokenizer: "PreTrainedTokenizer"

-    def _dump(self) -> Optional[Dict[str, float]]:
+    def _dump(self) -> Optional[dict[str, float]]:
         result = None
         if hasattr(self, "score_dict"):
             result = {k: float(np.mean(v)) for k, v in self.score_dict.items()}
@@ -109,7 +104,7 @@ class ComputeSimilarity:
     def __post_init__(self):
         self._dump()

-    def __call__(self, eval_preds: "EvalPrediction", compute_result: bool = True) -> Optional[Dict[str, float]]:
+    def __call__(self, eval_preds: "EvalPrediction", compute_result: bool = True) -> Optional[dict[str, float]]:
         preds, labels = numpify(eval_preds.predictions), numpify(eval_preds.label_ids)

         preds = np.where(preds != IGNORE_INDEX, preds, self.tokenizer.pad_token_id)
@@ -18,7 +18,7 @@
 import json
 import os
 from types import MethodType
-from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
+from typing import TYPE_CHECKING, Any, Optional, Union

 import numpy as np
 import torch
@@ -44,21 +44,19 @@ logger = logging.get_logger(__name__)


 class CustomSeq2SeqTrainer(Seq2SeqTrainer):
-    r"""
-    Inherits Seq2SeqTrainer to compute generative metrics such as BLEU and ROUGE.
-    """
+    r"""Inherits Seq2SeqTrainer to compute generative metrics such as BLEU and ROUGE."""

     def __init__(
         self,
         finetuning_args: "FinetuningArguments",
         processor: Optional["ProcessorMixin"],
-        gen_kwargs: Optional[Dict[str, Any]] = None,
+        gen_kwargs: Optional[dict[str, Any]] = None,
         **kwargs,
     ) -> None:
         if is_transformers_version_greater_than("4.46"):
             kwargs["processing_class"] = kwargs.pop("tokenizer")
         else:
-            self.processing_class: "PreTrainedTokenizer" = kwargs.get("tokenizer")
+            self.processing_class: PreTrainedTokenizer = kwargs.get("tokenizer")

         super().__init__(**kwargs)
         self.finetuning_args = finetuning_args
@@ -99,13 +97,12 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer):
     def prediction_step(
         self,
         model: "torch.nn.Module",
-        inputs: Dict[str, Union["torch.Tensor", Any]],
+        inputs: dict[str, Union["torch.Tensor", Any]],
         prediction_loss_only: bool,
-        ignore_keys: Optional[List[str]] = None,
+        ignore_keys: Optional[list[str]] = None,
         **gen_kwargs,
-    ) -> Tuple[Optional[float], Optional["torch.Tensor"], Optional["torch.Tensor"]]:
-        r"""
-        Removes the prompt part in the generated tokens.
+    ) -> tuple[Optional[float], Optional["torch.Tensor"], Optional["torch.Tensor"]]:
+        r"""Remove the prompt part in the generated tokens.

         Subclass and override to inject custom behavior.
         """
@@ -126,8 +123,7 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer):
     def save_predictions(
         self, dataset: "Dataset", predict_results: "PredictionOutput", skip_special_tokens: bool = True
     ) -> None:
-        r"""
-        Saves model predictions to `output_dir`.
+        r"""Save model predictions to `output_dir`.

         A custom behavior that not contained in Seq2SeqTrainer.
         """
@@ -15,7 +15,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from typing import TYPE_CHECKING, List, Optional
+from typing import TYPE_CHECKING, Optional

 from ...data import SFTDataCollatorWith4DAttentionMask, get_dataset, get_template_and_fix_tokenizer
 from ...extras.constants import IGNORE_INDEX
@@ -43,7 +43,7 @@ def run_sft(
     training_args: "Seq2SeqTrainingArguments",
     finetuning_args: "FinetuningArguments",
     generating_args: "GeneratingArguments",
-    callbacks: Optional[List["TrainerCallback"]] = None,
+    callbacks: Optional[list["TrainerCallback"]] = None,
 ):
     tokenizer_module = load_tokenizer(model_args)
     tokenizer = tokenizer_module["tokenizer"]
@@ -12,7 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from typing import TYPE_CHECKING, Dict, Optional, Sequence, Set, Tuple, Union
+from collections.abc import Sequence
+from typing import TYPE_CHECKING, Optional, Union

 import torch
 from peft import PeftModel
@@ -43,7 +44,7 @@ def compare_model(model_a: "torch.nn.Module", model_b: "torch.nn.Module", diff_k
             assert torch.allclose(state_dict_a[name], state_dict_b[name], rtol=1e-4, atol=1e-5) is True


-def check_lora_model(model: "LoraModel") -> Tuple[Set[str], Set[str]]:
+def check_lora_model(model: "LoraModel") -> tuple[set[str], set[str]]:
     linear_modules, extra_modules = set(), set()
     for name, param in model.named_parameters():
         if any(module in name for module in ["lora_A", "lora_B"]):
@@ -83,7 +84,7 @@ def load_reference_model(
 ) -> Union["PreTrainedModel", "LoraModel"]:
     current_device = get_current_device()
     if add_valuehead:
-        model: "AutoModelForCausalLMWithValueHead" = AutoModelForCausalLMWithValueHead.from_pretrained(
+        model: AutoModelForCausalLMWithValueHead = AutoModelForCausalLMWithValueHead.from_pretrained(
             model_path, torch_dtype=torch.float16, device_map=current_device
         )
         if not is_trainable:
@@ -111,7 +112,7 @@ def load_dataset_module(**kwargs) -> "DatasetModule":


 def patch_valuehead_model() -> None:
-    def post_init(self: "AutoModelForCausalLMWithValueHead", state_dict: Dict[str, "torch.Tensor"]) -> None:
+    def post_init(self: "AutoModelForCausalLMWithValueHead", state_dict: dict[str, "torch.Tensor"]) -> None:
         state_dict = {k[7:]: state_dict[k] for k in state_dict.keys() if k.startswith("v_head.")}
         self.v_head.load_state_dict(state_dict, strict=False)
         del state_dict
@@ -21,7 +21,7 @@ import json
 import os
 from collections.abc import Mapping
 from pathlib import Path
-from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union
+from typing import TYPE_CHECKING, Any, Callable, Optional, Union

 import torch
 from transformers import Trainer
@@ -63,12 +63,10 @@ logger = logging.get_logger(__name__)


 class DummyOptimizer(torch.optim.Optimizer):
-    r"""
-    A dummy optimizer used for the GaLore or APOLLO algorithm.
-    """
+    r"""A dummy optimizer used for the GaLore or APOLLO algorithm."""

     def __init__(
-        self, lr: float = 1e-3, optimizer_dict: Optional[Dict["torch.nn.Parameter", "torch.optim.Optimizer"]] = None
+        self, lr: float = 1e-3, optimizer_dict: Optional[dict["torch.nn.Parameter", "torch.optim.Optimizer"]] = None
     ) -> None:
         dummy_tensor = torch.randn(1, 1)
         self.optimizer_dict = optimizer_dict
@@ -112,8 +110,7 @@ def create_modelcard_and_push(
 def create_ref_model(
     model_args: "ModelArguments", finetuning_args: "FinetuningArguments", add_valuehead: bool = False
 ) -> Optional[Union["PreTrainedModel", "AutoModelForCausalLMWithValueHead"]]:
-    r"""
-    Creates reference model for PPO/DPO training. Evaluation mode is not supported.
+    r"""Create reference model for PPO/DPO training. Evaluation mode is not supported.

     The valuehead parameter is randomly initialized since it is useless for PPO training.
     """
@@ -148,9 +145,7 @@ def create_ref_model(
 def create_reward_model(
     model: "AutoModelForCausalLMWithValueHead", model_args: "ModelArguments", finetuning_args: "FinetuningArguments"
 ) -> Optional["AutoModelForCausalLMWithValueHead"]:
-    r"""
-    Creates reward model for PPO training.
-    """
+    r"""Create reward model for PPO training."""
     if finetuning_args.reward_model_type == "api":
         assert finetuning_args.reward_model.startswith("http"), "Please provide full url."
         logger.info_rank0(f"Use reward server {finetuning_args.reward_model}")
@@ -189,10 +184,8 @@ def create_reward_model(
         return reward_model


-def _get_decay_parameter_names(model: "PreTrainedModel") -> List[str]:
-    r"""
-    Returns a list of names of parameters with weight decay. (weights in non-layernorm layers)
-    """
+def _get_decay_parameter_names(model: "PreTrainedModel") -> list[str]:
+    r"""Return a list of names of parameters with weight decay. (weights in non-layernorm layers)."""
     decay_parameters = get_parameter_names(model, ALL_LAYERNORM_LAYERS)
     decay_parameters = [name for name in decay_parameters if "bias" not in name]
     return decay_parameters
@@ -208,7 +201,7 @@ def _create_galore_optimizer(
     else:
         galore_targets = finetuning_args.galore_target

-    galore_params: List["torch.nn.Parameter"] = []
+    galore_params: list[torch.nn.Parameter] = []
     for name, module in model.named_modules():
         if isinstance(module, torch.nn.Linear) and any(target in name for target in galore_targets):
             for param in module.parameters():
@@ -224,7 +217,7 @@ def _create_galore_optimizer(

     id_galore_params = {id(param) for param in galore_params}
     decay_params, nodecay_params = [], []  # they are non-galore parameters
-    trainable_params: List["torch.nn.Parameter"] = []  # galore_params + decay_params + nodecay_params
+    trainable_params: list[torch.nn.Parameter] = []  # galore_params + decay_params + nodecay_params
     decay_param_names = _get_decay_parameter_names(model)
     for name, param in model.named_parameters():
         if param.requires_grad:
@@ -251,7 +244,7 @@ def _create_galore_optimizer(
         if training_args.gradient_accumulation_steps != 1:
             raise ValueError("Per-layer GaLore does not support gradient accumulation.")

-        optimizer_dict: Dict["torch.Tensor", "torch.optim.Optimizer"] = {}
+        optimizer_dict: dict[torch.Tensor, torch.optim.Optimizer] = {}
         for param in nodecay_params:
             param_groups = [dict(params=[param], weight_decay=0.0)]
             optimizer_dict[param] = optim_class(param_groups, **optim_kwargs)
@@ -296,7 +289,7 @@ def _create_apollo_optimizer(
     else:
         apollo_targets = finetuning_args.apollo_target

-    apollo_params: List["torch.nn.Parameter"] = []
+    apollo_params: list[torch.nn.Parameter] = []
     for name, module in model.named_modules():
         if isinstance(module, torch.nn.Linear) and any(target in name for target in apollo_targets):
             for param in module.parameters():
@@ -315,7 +308,7 @@ def _create_apollo_optimizer(

     id_apollo_params = {id(param) for param in apollo_params}
     decay_params, nodecay_params = [], []  # they are non-apollo parameters
-    trainable_params: List["torch.nn.Parameter"] = []  # apollo_params + decay_params + nodecay_params
+    trainable_params: list[torch.nn.Parameter] = []  # apollo_params + decay_params + nodecay_params
     decay_param_names = _get_decay_parameter_names(model)
     for name, param in model.named_parameters():
         if param.requires_grad:
@@ -338,7 +331,7 @@ def _create_apollo_optimizer(
         if training_args.gradient_accumulation_steps != 1:
             raise ValueError("Per-layer APOLLO does not support gradient accumulation.")

-        optimizer_dict: Dict["torch.Tensor", "torch.optim.Optimizer"] = {}
+        optimizer_dict: dict[torch.Tensor, torch.optim.Optimizer] = {}
         for param in nodecay_params:
             param_groups = [dict(params=[param], weight_decay=0.0)]
             optimizer_dict[param] = optim_class(param_groups, **optim_kwargs)
@@ -380,7 +373,7 @@ def _create_loraplus_optimizer(
     embedding_lr = finetuning_args.loraplus_lr_embedding

     decay_param_names = _get_decay_parameter_names(model)
-    param_dict: Dict[str, List["torch.nn.Parameter"]] = {
+    param_dict: dict[str, list[torch.nn.Parameter]] = {
         "lora_a": [],
         "lora_b": [],
         "lora_b_nodecay": [],
@@ -524,7 +517,7 @@ def create_custom_scheduler(
 ) -> None:
     if optimizer is not None and isinstance(optimizer, DummyOptimizer):
         optimizer_dict = optimizer.optimizer_dict
-        scheduler_dict: Dict["torch.nn.Parameter", "torch.optim.lr_scheduler.LRScheduler"] = {}
+        scheduler_dict: dict[torch.nn.Parameter, torch.optim.lr_scheduler.LRScheduler] = {}

         for param in optimizer_dict.keys():
             scheduler_dict[param] = get_scheduler(
@@ -544,13 +537,13 @@ def create_custom_scheduler(

 def get_batch_logps(
     logits: "torch.Tensor", labels: "torch.Tensor", label_pad_token_id: int = IGNORE_INDEX
-) -> Tuple["torch.Tensor", "torch.Tensor"]:
-    r"""
-    Computes the log probabilities of the given labels under the given logits.
+) -> tuple["torch.Tensor", "torch.Tensor"]:
+    r"""Compute the log probabilities of the given labels under the given logits.

     Returns:
         logps: A tensor of shape (batch_size,) containing the sum of log probabilities.
         valid_length: A tensor of shape (batch_size,) containing the number of non-masked tokens.
+
     """
     if logits.shape[:-1] != labels.shape:
         raise ValueError("Logits (batchsize x seqlen) and labels must have the same shape.")
@@ -564,12 +557,10 @@ def get_batch_logps(


 def nested_detach(
-    tensors: Union["torch.Tensor", List["torch.Tensor"], Tuple["torch.Tensor"], Dict[str, "torch.Tensor"]],
+    tensors: Union["torch.Tensor", list["torch.Tensor"], tuple["torch.Tensor"], dict[str, "torch.Tensor"]],
     clone: bool = False,
 ):
-    r"""
-    Detach `tensors` (even if it's a nested list/tuple/dict of tensors).
-    """
+    r"""Detach `tensors` (even if it's a nested list/tuple/dict of tensors)."""
     if isinstance(tensors, (list, tuple)):
         return type(tensors)(nested_detach(t, clone=clone) for t in tensors)
     elif isinstance(tensors, Mapping):
@@ -585,9 +576,7 @@ def nested_detach(


 def get_swanlab_callback(finetuning_args: "FinetuningArguments") -> "TrainerCallback":
-    r"""
-    Gets the callback for logging to SwanLab.
-    """
+    r"""Get the callback for logging to SwanLab."""
     import swanlab  # type: ignore
     from swanlab.integration.transformers import SwanLabCallback  # type: ignore

@@ -624,7 +613,7 @@ def get_swanlab_callback(finetuning_args: "FinetuningArguments") -> "TrainerCallback":

 def get_ray_trainer(
     training_function: Callable,
-    train_loop_config: Dict[str, Any],
+    train_loop_config: dict[str, Any],
     ray_args: "RayArguments",
 ) -> "TorchTrainer":
     if not ray_args.use_ray:
@@ -14,7 +14,7 @@

 import os
 import shutil
-from typing import TYPE_CHECKING, Any, Dict, List, Optional
+from typing import TYPE_CHECKING, Any, Optional

 import torch
 import torch.distributed as dist
@@ -48,9 +48,9 @@ if TYPE_CHECKING:
 logger = logging.get_logger(__name__)


-def _training_function(config: Dict[str, Any]) -> None:
+def _training_function(config: dict[str, Any]) -> None:
     args = config.get("args")
-    callbacks: List[Any] = config.get("callbacks")
+    callbacks: list[Any] = config.get("callbacks")
     model_args, data_args, training_args, finetuning_args, generating_args = get_train_args(args)

     callbacks.append(LogCallback())
@@ -84,7 +84,7 @@ def _training_function(config: Dict[str, Any]) -> None:
             logger.warning(f"Failed to destroy process group: {e}.")


-def run_exp(args: Optional[Dict[str, Any]] = None, callbacks: Optional[List["TrainerCallback"]] = None) -> None:
+def run_exp(args: Optional[dict[str, Any]] = None, callbacks: Optional[list["TrainerCallback"]] = None) -> None:
     args = read_args(args)
     if "-h" in args or "--help" in args:
         get_train_args(args)
@@ -103,7 +97,7 @@ def run_exp(args: Optional[Dict[str, Any]] = None, callbacks: Optional[List["TrainerCallback"]] = None) -> None:
         _training_function(config={"args": args, "callbacks": callbacks})


-def export_model(args: Optional[Dict[str, Any]] = None) -> None:
+def export_model(args: Optional[dict[str, Any]] = None) -> None:
     model_args, data_args, finetuning_args, _ = get_infer_args(args)

     if model_args.export_dir is None:
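Besides the typing changes, the commit collapses docstrings into the one-line-summary form: the summary sits on the same line as the opening quotes and uses the imperative mood, with any further detail in a following paragraph. A hedged sketch of that convention (not code from the repository):

def scale(values: list[float], factor: float) -> list[float]:
    r"""Scale each value by the given factor.

    The summary stays on the first line in the imperative mood;
    extra detail goes in a separate paragraph, as in the diff above.
    """
    return [v * factor for v in values]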