allow non-packing pretraining

Former-commit-id: 3fee5cc5a3db9ce874ad90f2500ec092d904bd4e
This commit is contained in:
hiyouga
2024-03-09 22:21:46 +08:00
parent c631799f5d
commit 4881f4e631
22 changed files with 64 additions and 67 deletions

View File

@@ -22,7 +22,7 @@ class CustomDPOTrainer(DPOTrainer):
ftx_gamma: float,
model: Union["PreTrainedModel", torch.nn.Module],
ref_model: Optional[Union["PreTrainedModel", torch.nn.Module]] = None,
disable_dropout: Optional[bool] = True,
disable_dropout: bool = True,
**kwargs,
):
if disable_dropout:
@@ -95,7 +95,7 @@ class CustomDPOTrainer(DPOTrainer):
self,
model: "PreTrainedModel",
batch: Dict[str, torch.Tensor],
train_eval: Optional[Literal["train", "eval"]] = "train",
train_eval: Literal["train", "eval"] = "train",
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
r"""
Computes the DPO loss and other metrics for the given batch of inputs for train or test.

View File

@@ -292,7 +292,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
queries: torch.Tensor,
responses: torch.Tensor,
model_inputs: dict,
return_logits: Optional[bool] = False,
return_logits: bool = False,
response_masks: Optional[torch.Tensor] = None,
):
r"""

View File

@@ -1,6 +1,6 @@
import json
import os
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
from typing import TYPE_CHECKING, Dict, List, Tuple, Union
import torch
from transformers import Trainer
@@ -26,7 +26,7 @@ class PairwiseTrainer(Trainer):
self.can_return_loss = True # override property to return eval_loss
def compute_loss(
self, model: "PreTrainedModel", inputs: Dict[str, torch.Tensor], return_outputs: Optional[bool] = False
self, model: "PreTrainedModel", inputs: Dict[str, torch.Tensor], return_outputs: bool = False
) -> Union[torch.Tensor, Tuple[torch.Tensor, List[torch.Tensor]]]:
r"""
Computes pairwise loss. The first n examples are chosen and the last n examples are rejected.

View File

@@ -46,7 +46,7 @@ def create_modelcard_and_push(
def create_ref_model(
model_args: "ModelArguments", finetuning_args: "FinetuningArguments", add_valuehead: Optional[bool] = False
model_args: "ModelArguments", finetuning_args: "FinetuningArguments", add_valuehead: bool = False
) -> Union["PreTrainedModel", "AutoModelForCausalLMWithValueHead"]:
r"""
Creates reference model for PPO/DPO training. Evaluation mode is not supported.