allow non-packing pretraining
Former-commit-id: 3fee5cc5a3db9ce874ad90f2500ec092d904bd4e
This commit is contained in:
@@ -22,7 +22,7 @@ class CustomDPOTrainer(DPOTrainer):
|
||||
ftx_gamma: float,
|
||||
model: Union["PreTrainedModel", torch.nn.Module],
|
||||
ref_model: Optional[Union["PreTrainedModel", torch.nn.Module]] = None,
|
||||
disable_dropout: Optional[bool] = True,
|
||||
disable_dropout: bool = True,
|
||||
**kwargs,
|
||||
):
|
||||
if disable_dropout:
|
||||
@@ -95,7 +95,7 @@ class CustomDPOTrainer(DPOTrainer):
|
||||
self,
|
||||
model: "PreTrainedModel",
|
||||
batch: Dict[str, torch.Tensor],
|
||||
train_eval: Optional[Literal["train", "eval"]] = "train",
|
||||
train_eval: Literal["train", "eval"] = "train",
|
||||
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
|
||||
r"""
|
||||
Computes the DPO loss and other metrics for the given batch of inputs for train or test.
|
||||
|
||||
Reference in New Issue
Block a user