allow non-packing pretraining

Former-commit-id: 3fee5cc5a3db9ce874ad90f2500ec092d904bd4e
Author: hiyouga
Date: 2024-03-09 22:21:46 +08:00
parent c631799f5d
commit 4881f4e631
22 changed files with 64 additions and 67 deletions
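For context on the feature named in the commit title: "packing" during pretraining usually means concatenating many tokenized samples into one token stream and slicing it into fixed-length blocks, while non-packing keeps each sample as its own truncated (and later padded) sequence. The sketch below only illustrates that general difference; it is not this repository's implementation, and every name in it (preprocess_packed, preprocess_unpacked, cutoff_len, eos_id) is hypothetical.

    # Illustrative sketch only -- not LLaMA-Factory code; all names are made up.
    from typing import List


    def preprocess_packed(token_lists: List[List[int]], cutoff_len: int, eos_id: int) -> List[List[int]]:
        """Concatenate all samples (EOS between them) and split into fixed-size blocks."""
        stream: List[int] = []
        for tokens in token_lists:
            stream.extend(tokens + [eos_id])
        # Drop the trailing remainder so every block has exactly cutoff_len tokens.
        total = (len(stream) // cutoff_len) * cutoff_len
        return [stream[i : i + cutoff_len] for i in range(0, total, cutoff_len)]


    def preprocess_unpacked(token_lists: List[List[int]], cutoff_len: int, eos_id: int) -> List[List[int]]:
        """Keep each sample separate, only truncating; padding is left to the collator."""
        return [(tokens + [eos_id])[:cutoff_len] for tokens in token_lists]


    if __name__ == "__main__":
        samples = [[1, 2, 3], [4, 5, 6, 7, 8], [9, 10]]
        print(preprocess_packed(samples, cutoff_len=4, eos_id=0))    # [[1, 2, 3, 0], [4, 5, 6, 7], [8, 0, 9, 10]]
        print(preprocess_unpacked(samples, cutoff_len=4, eos_id=0))  # [[1, 2, 3, 0], [4, 5, 6, 7], [9, 10, 0]]

As a general trade-off, packing wastes no compute on padding but can merge unrelated documents into one block; disabling it preserves sample boundaries at the cost of padding.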

@@ -22,7 +22,7 @@ class CustomDPOTrainer(DPOTrainer):
         ftx_gamma: float,
         model: Union["PreTrainedModel", torch.nn.Module],
         ref_model: Optional[Union["PreTrainedModel", torch.nn.Module]] = None,
-        disable_dropout: Optional[bool] = True,
+        disable_dropout: bool = True,
         **kwargs,
     ):
         if disable_dropout:
@@ -95,7 +95,7 @@ class CustomDPOTrainer(DPOTrainer):
         self,
         model: "PreTrainedModel",
         batch: Dict[str, torch.Tensor],
-        train_eval: Optional[Literal["train", "eval"]] = "train",
+        train_eval: Literal["train", "eval"] = "train",
     ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
         r"""
         Computes the DPO loss and other metrics for the given batch of inputs for train or test.