remove PeftTrainer
Former-commit-id: cc0cff3e991f194732d278e627648e528118a719
@@ -6,18 +6,16 @@ from trl import DPOTrainer
 from trl.trainer.utils import disable_dropout_in_model
 
 from llmtuner.extras.constants import IGNORE_INDEX
-from llmtuner.tuner.core.trainer import PeftModelMixin
 
 if TYPE_CHECKING:
     from transformers import PreTrainedModel
-    from llmtuner.hparams import FinetuningArguments
 
 
-class DPOPeftTrainer(PeftModelMixin, DPOTrainer):
+class CustomDPOTrainer(DPOTrainer):
 
     def __init__(
         self,
-        finetuning_args: "FinetuningArguments",
+        beta: float,
         model: Union["PreTrainedModel", torch.nn.Module],
         ref_model: Optional[Union["PreTrainedModel", torch.nn.Module]] = None,
         disable_dropout: Optional[bool] = True,
@@ -28,12 +26,11 @@ class DPOPeftTrainer(PeftModelMixin, DPOTrainer):
             if ref_model is not None:
                 disable_dropout_in_model(ref_model)
 
-        self.finetuning_args = finetuning_args
         self.ref_model = ref_model
        self.use_dpo_data_collator = True # hack to avoid warning
         self.label_pad_token_id = IGNORE_INDEX
         self.padding_value = 0
-        self.beta = finetuning_args.dpo_beta
+        self.beta = beta
         self._stored_metrics = defaultdict(lambda: defaultdict(list))
 
         Trainer.__init__(self, model=model, **kwargs)
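For context, beta is the DPO temperature: it scales the chosen-versus-rejected log-probability ratios between the policy and the reference model inside the DPO loss, so a higher beta keeps the policy closer to the reference. After this change the trainer takes beta directly instead of reading finetuning_args.dpo_beta itself, which removes its dependency on FinetuningArguments. A minimal wiring sketch of a call site after this commit, assuming the file lives at llmtuner/tuner/dpo/trainer.py (inferred from the llmtuner.tuner.core import removed above); every other name below is a hypothetical placeholder, not part of the commit:

from llmtuner.tuner.dpo.trainer import CustomDPOTrainer  # module path assumed

def create_dpo_trainer(model, ref_model, finetuning_args, training_args,
                       train_dataset, data_collator):
    # The call site now resolves dpo_beta from FinetuningArguments and passes
    # it explicitly; the trainer no longer imports FinetuningArguments.
    return CustomDPOTrainer(
        beta=finetuning_args.dpo_beta,
        model=model,
        ref_model=ref_model,          # optional reference model, defaults to None
        args=training_args,           # forwarded to transformers.Trainer via **kwargs
        train_dataset=train_dataset,
        data_collator=data_collator,
    )

Dropping the PeftModelMixin base also leaves CustomDPOTrainer inheriting directly from trl's DPOTrainer, which is itself a transformers Trainer subclass, so it no longer carries PEFT-specific behavior of its own.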