remove PeftTrainer

Former-commit-id: cc0cff3e991f194732d278e627648e528118a719
This commit is contained in:
hiyouga
2023-09-10 22:23:23 +08:00
parent 332d7bbd56
commit a09a7b650d
17 changed files with 75 additions and 259 deletions

View File

@@ -6,18 +6,16 @@ from trl import DPOTrainer
from trl.trainer.utils import disable_dropout_in_model
from llmtuner.extras.constants import IGNORE_INDEX
from llmtuner.tuner.core.trainer import PeftModelMixin
if TYPE_CHECKING:
from transformers import PreTrainedModel
from llmtuner.hparams import FinetuningArguments
class DPOPeftTrainer(PeftModelMixin, DPOTrainer):
class CustomDPOTrainer(DPOTrainer):
def __init__(
self,
finetuning_args: "FinetuningArguments",
beta: float,
model: Union["PreTrainedModel", torch.nn.Module],
ref_model: Optional[Union["PreTrainedModel", torch.nn.Module]] = None,
disable_dropout: Optional[bool] = True,
@@ -28,12 +26,11 @@ class DPOPeftTrainer(PeftModelMixin, DPOTrainer):
if ref_model is not None:
disable_dropout_in_model(ref_model)
self.finetuning_args = finetuning_args
self.ref_model = ref_model
self.use_dpo_data_collator = True # hack to avoid warning
self.label_pad_token_id = IGNORE_INDEX
self.padding_value = 0
self.beta = finetuning_args.dpo_beta
self.beta = beta
self._stored_metrics = defaultdict(lambda: defaultdict(list))
Trainer.__init__(self, model=model, **kwargs)

View File

@@ -10,7 +10,7 @@ from llmtuner.extras.constants import IGNORE_INDEX
from llmtuner.extras.ploting import plot_loss
from llmtuner.tuner.core import load_model_and_tokenizer
from llmtuner.tuner.dpo.collator import DPODataCollatorWithPadding
from llmtuner.tuner.dpo.trainer import DPOPeftTrainer
from llmtuner.tuner.dpo.trainer import CustomDPOTrainer
if TYPE_CHECKING:
from transformers import TrainerCallback
@@ -37,10 +37,10 @@ def run_dpo(
training_args = Seq2SeqTrainingArguments(**training_args_dict)
# Initialize our Trainer
trainer = DPOPeftTrainer(
finetuning_args=finetuning_args,
ref_model=deepcopy(model) if not isinstance(model, PeftModel) else None,
trainer = CustomDPOTrainer(
beta=finetuning_args.dpo_beta,
model=model,
ref_model=deepcopy(model) if not isinstance(model, PeftModel) else None,
args=training_args,
tokenizer=tokenizer,
data_collator=data_collator,