support DPO training (2305.18290)

Former-commit-id: 6d98de148e4af63a7028dfaeb6cf86eb56a4488f
hiyouga
2023-08-11 03:02:53 +08:00
parent 72dfd74005
commit ca719a8697
33 changed files with 513 additions and 192 deletions


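For context: DPO (arXiv 2305.18290) optimizes the policy directly on preference pairs, replacing RLHF's separate reward model with a log-sigmoid loss over the difference between the policy's and a frozen reference model's log-probability ratios on chosen vs. rejected responses. A minimal sketch of the loss, assuming per-sequence log-probabilities already summed over tokens (function and argument names are illustrative, not this repo's API):

import torch.nn.functional as F

def dpo_loss(policy_chosen_logps, policy_rejected_logps,
             reference_chosen_logps, reference_rejected_logps, beta=0.1):
    # Log-ratios between the trained policy and the frozen reference model.
    chosen_logratios = policy_chosen_logps - reference_chosen_logps
    rejected_logratios = policy_rejected_logps - reference_rejected_logps
    # DPO objective: -log sigmoid(beta * (chosen logratio - rejected logratio)).
    # beta scales the implicit reward; 0.1 is a common setting from the paper.
    return -F.logsigmoid(beta * (chosen_logratios - rejected_logratios)).mean()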
@@ -1,12 +1,11 @@
 import torch
 from typing import Literal, Optional
 from dataclasses import dataclass, field
-from huggingface_hub.hf_api import HfFolder
 
 
 @dataclass
 class ModelArguments:
-    """
+    r"""
     Arguments pertaining to which model/config/tokenizer we are going to fine-tune.
     """
     model_name_or_path: str = field(
@@ -64,12 +63,11 @@ class ModelArguments:
         default=False,
         metadata={"help": "Whether to plot the training loss after fine-tuning or not."}
     )
-    hf_hub_token : Optional[str] = field(
+    hf_auth_token: Optional[str] = field(
         default=None,
-        metadata={"help": "Path to the directory containing the checkpoints of the reward model."}
+        metadata={"help": "Auth token to log in with Hugging Face Hub."}
     )
-
 
     def __post_init__(self):
         if self.checkpoint_dir is not None: # support merging multiple lora weights
             self.checkpoint_dir = [cd.strip() for cd in self.checkpoint_dir.split(",")]
@@ -77,5 +75,6 @@ class ModelArguments:
         if self.quantization_bit is not None:
             assert self.quantization_bit in [4, 8], "We only accept 4-bit or 8-bit quantization."
 
-        if self.use_auth_token == True and self.hf_hub_token != None:
-            HfFolder.save_token(self.hf_hub_token)
+        if self.use_auth_token == True and self.hf_auth_token is not None:
+            from huggingface_hub.hf_api import HfFolder # lazy load
+            HfFolder.save_token(self.hf_auth_token)
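
After this change, HfFolder is imported only when a token is actually supplied, and the renamed hf_auth_token field writes the token into the local Hugging Face credentials store during __post_init__. A rough sketch of the effect, assuming the launcher parses ModelArguments with transformers.HfArgumentParser (so the field surfaces as a --hf_auth_token flag; the token value below is a placeholder):

from huggingface_hub.hf_api import HfFolder

# Equivalent effect of --use_auth_token True --hf_auth_token hf_xxx:
# the token is persisted locally (e.g. to ~/.huggingface/token), so
# subsequent hub downloads of gated models authenticate automatically.
HfFolder.save_token("hf_xxx")  # placeholder token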