fix ppo train and dpo eval

Former-commit-id: ced863031836632cb5920e22ae6991f251372118
This commit is contained in:
hiyouga
2023-11-07 22:48:51 +08:00
parent 14a38b5069
commit f5ba2190fb
5 changed files with 56 additions and 21 deletions

View File

@@ -75,6 +75,14 @@ class FinetuningArguments:
default=0.1,
metadata={"help": "The beta parameter for the DPO loss."}
)
dpo_ref_model: Optional[str] = field(
default=None,
metadata={"help": "Path to the reference model used for the DPO training."}
)
dpo_ref_model_checkpoint: Optional[str] = field(
default=None,
metadata={"help": "Path to the directory(s) containing the model checkpoints of the reference model."}
)
upcast_layernorm: Optional[bool] = field(
default=False,
metadata={"help": "Whether to upcast the layernorm weights in fp32."}
@@ -91,7 +99,7 @@ class FinetuningArguments:
if isinstance(self.additional_target, str):
self.additional_target = [target.strip() for target in self.additional_target.split(",")]
assert self.finetuning_type in ["lora", "freeze", "full", "none"], "Invalid fine-tuning method."
assert self.finetuning_type in ["lora", "freeze", "full"], "Invalid fine-tuning method."
def save_to_json(self, json_path: str):
r"""Saves the content of this instance in JSON format inside `json_path`."""

View File

@@ -1,5 +1,5 @@
from typing import Literal, Optional
from dataclasses import dataclass, field
from typing import Any, Dict, Literal, Optional
from dataclasses import asdict, dataclass, field
@dataclass
@@ -44,7 +44,7 @@ class ModelArguments:
)
checkpoint_dir: Optional[str] = field(
default=None,
metadata={"help": "Path to the directory(s) containing the delta model checkpoints as well as the configurations."}
metadata={"help": "Path to the directory(s) containing the model checkpoints as well as the configurations."}
)
flash_attn: Optional[bool] = field(
default=False,
@@ -83,3 +83,6 @@ class ModelArguments:
if self.quantization_bit is not None:
assert self.quantization_bit in [4, 8], "We only accept 4-bit or 8-bit quantization."
def to_dict(self) -> Dict[str, Any]:
return asdict(self)