support full-parameter PPO

Former-commit-id: 4af967d69475e1c9fdf1a7983cd6b83bd431abff
Author: hiyouga
Date: 2023-11-16 02:08:04 +08:00
Parent: 8263b2d32d
Commit: 7a3a0144a5
19 changed files with 280 additions and 140 deletions


@@ -54,22 +54,10 @@ class ModelArguments:
         default=False,
         metadata={"help": "Enable shift short attention (S^2-Attn) proposed by LongLoRA."}
     )
-    reward_model: Optional[str] = field( # TODO: move it to FinetuningArguments
-        default=None,
-        metadata={"help": "Path to the directory containing the checkpoints of the reward model."}
-    )
-    plot_loss: Optional[bool] = field( # TODO: move it to FinetuningArguments
-        default=False,
-        metadata={"help": "Whether to plot the training loss after fine-tuning or not."}
-    )
     hf_hub_token: Optional[str] = field(
         default=None,
         metadata={"help": "Auth token to log in with Hugging Face Hub."}
     )
     export_dir: Optional[str] = field(
         default=None,
         metadata={"help": "Path to the directory to save the exported model."}
     )

     def __post_init__(self):
         self.compute_dtype = None
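
The two fields deleted above are not dropped: their TODO comments mark them for relocation into FinetuningArguments, where the PPO trainer can read the reward model path. A minimal sketch of the destination, assuming the field definitions are carried over verbatim (the rest of the FinetuningArguments class body is omitted, not taken from this commit):

from dataclasses import dataclass, field
from typing import Optional

@dataclass
class FinetuningArguments:
    # Only the two relocated fields are shown here.
    reward_model: Optional[str] = field(
        default=None,
        metadata={"help": "Path to the directory containing the checkpoints of the reward model."}
    )
    plot_loss: Optional[bool] = field(
        default=False,
        metadata={"help": "Whether to plot the training loss after fine-tuning or not."}
    )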
@@ -81,8 +69,7 @@ class ModelArguments:
         if self.checkpoint_dir is not None: # support merging multiple lora weights
             self.checkpoint_dir = [cd.strip() for cd in self.checkpoint_dir.split(",")]
-        if self.quantization_bit is not None:
-            assert self.quantization_bit in [4, 8], "We only accept 4-bit or 8-bit quantization."
+        assert self.quantization_bit in [None, 8, 4], "We only accept 4-bit or 8-bit quantization."

     def to_dict(self) -> Dict[str, Any]:
         return asdict(self)
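
The second hunk folds the quantization check into a single assert: because None is now in the accepted list, the "quantization disabled" case passes the same check as 4-bit and 8-bit, and the enclosing `if ... is not None` guard becomes unnecessary. A minimal sketch of the resulting behavior, using a hypothetical trimmed-down stand-in for ModelArguments:

from dataclasses import dataclass, field
from typing import Optional

@dataclass
class QuantArgs:  # hypothetical stand-in, not the class from this commit
    quantization_bit: Optional[int] = field(default=None)

    def __post_init__(self):
        # None covers the "no quantization" case, so no guard is needed.
        assert self.quantization_bit in [None, 8, 4], "We only accept 4-bit or 8-bit quantization."

QuantArgs()                       # ok: quantization disabled
QuantArgs(quantization_bit=4)     # ok: 4-bit
# QuantArgs(quantization_bit=16)  # raises AssertionError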