allow non-packing pretraining
Former-commit-id: 3fee5cc5a3db9ce874ad90f2500ec092d904bd4e
This commit is contained in:
@@ -22,7 +22,7 @@ class CustomDPOTrainer(DPOTrainer):
|
||||
ftx_gamma: float,
|
||||
model: Union["PreTrainedModel", torch.nn.Module],
|
||||
ref_model: Optional[Union["PreTrainedModel", torch.nn.Module]] = None,
|
||||
disable_dropout: Optional[bool] = True,
|
||||
disable_dropout: bool = True,
|
||||
**kwargs,
|
||||
):
|
||||
if disable_dropout:
|
||||
@@ -95,7 +95,7 @@ class CustomDPOTrainer(DPOTrainer):
|
||||
self,
|
||||
model: "PreTrainedModel",
|
||||
batch: Dict[str, torch.Tensor],
|
||||
train_eval: Optional[Literal["train", "eval"]] = "train",
|
||||
train_eval: Literal["train", "eval"] = "train",
|
||||
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
|
||||
r"""
|
||||
Computes the DPO loss and other metrics for the given batch of inputs for train or test.
|
||||
|
||||
@@ -292,7 +292,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
|
||||
queries: torch.Tensor,
|
||||
responses: torch.Tensor,
|
||||
model_inputs: dict,
|
||||
return_logits: Optional[bool] = False,
|
||||
return_logits: bool = False,
|
||||
response_masks: Optional[torch.Tensor] = None,
|
||||
):
|
||||
r"""
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import json
|
||||
import os
|
||||
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
|
||||
from typing import TYPE_CHECKING, Dict, List, Tuple, Union
|
||||
|
||||
import torch
|
||||
from transformers import Trainer
|
||||
@@ -26,7 +26,7 @@ class PairwiseTrainer(Trainer):
|
||||
self.can_return_loss = True # override property to return eval_loss
|
||||
|
||||
def compute_loss(
|
||||
self, model: "PreTrainedModel", inputs: Dict[str, torch.Tensor], return_outputs: Optional[bool] = False
|
||||
self, model: "PreTrainedModel", inputs: Dict[str, torch.Tensor], return_outputs: bool = False
|
||||
) -> Union[torch.Tensor, Tuple[torch.Tensor, List[torch.Tensor]]]:
|
||||
r"""
|
||||
Computes pairwise loss. The first n examples are chosen and the last n examples are rejected.
|
||||
|
||||
@@ -46,7 +46,7 @@ def create_modelcard_and_push(
|
||||
|
||||
|
||||
def create_ref_model(
|
||||
model_args: "ModelArguments", finetuning_args: "FinetuningArguments", add_valuehead: Optional[bool] = False
|
||||
model_args: "ModelArguments", finetuning_args: "FinetuningArguments", add_valuehead: bool = False
|
||||
) -> Union["PreTrainedModel", "AutoModelForCausalLMWithValueHead"]:
|
||||
r"""
|
||||
Creates reference model for PPO/DPO training. Evaluation mode is not supported.
|
||||
|
||||
Reference in New Issue
Block a user