Support DPO (Direct Preference Optimization) training (arXiv:2305.18290)

Former-commit-id: 6d98de148e4af63a7028dfaeb6cf86eb56a4488f
This commit is contained in:
hiyouga
2023-08-11 03:02:53 +08:00
parent 72dfd74005
commit ca719a8697
33 changed files with 513 additions and 192 deletions

View File

@@ -13,26 +13,25 @@ from llmtuner.extras.logging import get_logger
from llmtuner.extras.save_and_load import get_state_dict, load_trainable_params
if TYPE_CHECKING:
from transformers import PreTrainedTokenizer, Seq2SeqTrainingArguments, TrainerState
from llmtuner.hparams import FinetuningArguments
logger = get_logger(__name__)
class PeftTrainer(Seq2SeqTrainer):
class PeftModelMixin:
r"""
Inherits Seq2SeqTrainer to support parameter-efficient checkpoints.
Patches the save and load methods in Hugging Face Trainer for PeftModel and ModelWithValueHead.
"""
def __init__(self, finetuning_args: "FinetuningArguments", **kwargs):
super().__init__(**kwargs)
self.finetuning_args = finetuning_args
self._remove_log()
def _remove_log(self):
if self.is_world_process_zero() and os.path.exists(os.path.join(self.args.output_dir, "trainer_log.jsonl")):
logger.warning("Previous log file in this folder will be deleted.")
os.remove(os.path.join(self.args.output_dir, "trainer_log.jsonl"))
def __init__(self) -> None: # for type checking
self.model: PreTrainedModel = None
self.tokenizer: "PreTrainedTokenizer" = None
self.args: "Seq2SeqTrainingArguments" = None
self.finetuning_args: "FinetuningArguments" = None
self.state: "TrainerState" = None
raise AssertionError("Mixin should not be initialized.")
def _save(self, output_dir: Optional[str] = None, state_dict: Optional[Dict[str, torch.Tensor]] = None) -> None:
r"""
@@ -96,3 +95,13 @@ class PeftTrainer(Seq2SeqTrainer):
model.load_adapter(self.state.best_model_checkpoint, model.active_adapter)
else: # freeze/full-tuning
load_trainable_params(model, self.state.best_model_checkpoint)
class PeftTrainer(PeftModelMixin, Seq2SeqTrainer):
r"""
Inherits Seq2SeqTrainer to support parameter-efficient checkpoints.

PeftModelMixin is listed first among the bases, so the methods it defines
take precedence over the ones inherited from Seq2SeqTrainer.

Args:
finetuning_args: stored on the instance as `self.finetuning_args`.
**kwargs: forwarded unchanged to `Seq2SeqTrainer.__init__`.
"""
def __init__(self, finetuning_args: "FinetuningArguments", **kwargs):
# Call Seq2SeqTrainer.__init__ by name rather than super().__init__():
# with this base order, super() would resolve to PeftModelMixin.__init__,
# which exists only for type checking and raises AssertionError.
Seq2SeqTrainer.__init__(self, **kwargs)
self.finetuning_args = finetuning_args