support full-parameter PPO
Former-commit-id: 4af967d69475e1c9fdf1a7983cd6b83bd431abff
This commit is contained in:
@@ -54,22 +54,10 @@ class ModelArguments:
|
||||
default=False,
|
||||
metadata={"help": "Enable shift short attention (S^2-Attn) proposed by LongLoRA."}
|
||||
)
|
||||
reward_model: Optional[str] = field( # TODO: move it to FinetuningArguments
|
||||
default=None,
|
||||
metadata={"help": "Path to the directory containing the checkpoints of the reward model."}
|
||||
)
|
||||
plot_loss: Optional[bool] = field( # TODO: move it to FinetuningArguments
|
||||
default=False,
|
||||
metadata={"help": "Whether to plot the training loss after fine-tuning or not."}
|
||||
)
|
||||
hf_hub_token: Optional[str] = field(
|
||||
default=None,
|
||||
metadata={"help": "Auth token to log in with Hugging Face Hub."}
|
||||
)
|
||||
export_dir: Optional[str] = field(
|
||||
default=None,
|
||||
metadata={"help": "Path to the directory to save the exported model."}
|
||||
)
|
||||
|
||||
def __post_init__(self):
|
||||
self.compute_dtype = None
|
||||
@@ -81,8 +69,7 @@ class ModelArguments:
|
||||
if self.checkpoint_dir is not None: # support merging multiple lora weights
|
||||
self.checkpoint_dir = [cd.strip() for cd in self.checkpoint_dir.split(",")]
|
||||
|
||||
if self.quantization_bit is not None:
|
||||
assert self.quantization_bit in [4, 8], "We only accept 4-bit or 8-bit quantization."
|
||||
assert self.quantization_bit in [None, 8, 4], "We only accept 4-bit or 8-bit quantization."
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
return asdict(self)
|
||||
|
||||
Reference in New Issue
Block a user