support quantization in export model

Former-commit-id: f32500ae6edccab7d14df4c92467e15986866def
This commit is contained in:
hiyouga
2023-12-15 23:44:50 +08:00
parent 9121722999
commit 296711d502
9 changed files with 120 additions and 32 deletions

View File

@@ -125,7 +125,38 @@ class RLHFArguments:
@dataclass
class FinetuningArguments(FreezeArguments, LoraArguments, RLHFArguments):
class ExportArguments:
r"""
Arguments pertaining to model exporting.
"""
export_dir: Optional[str] = field(
default=None,
metadata={"help": "Path to the directory to save the exported model."}
)
export_size: Optional[int] = field(
default=1,
metadata={"help": "The file shard size (in GB) of the exported model."}
)
export_quantization_bit: Optional[int] = field(
default=None,
metadata={"help": "The number of bits to quantize the exported model."}
)
export_quantization_dataset: Optional[str] = field(
default=None,
metadata={"help": "Path to the dataset or dataset name to use in quantizing the exported model."}
)
export_quantization_nsamples: Optional[int] = field(
default=128,
metadata={"help": "The number of samples used for quantization."}
)
export_quantization_maxlen: Optional[str] = field(
default=1024,
metadata={"help": "The maximum length of the model inputs used for quantization."}
)
@dataclass
class FinetuningArguments(FreezeArguments, LoraArguments, RLHFArguments, ExportArguments):
r"""
Arguments pertaining to which techniques we are going to fine-tuning with.
"""
@@ -141,14 +172,6 @@ class FinetuningArguments(FreezeArguments, LoraArguments, RLHFArguments):
default=False,
metadata={"help": "Whether to upcast the layernorm weights in fp32."}
)
export_dir: Optional[str] = field(
default=None,
metadata={"help": "Path to the directory to save the exported model."}
)
export_size: Optional[int] = field(
default=1,
metadata={"help": "The file shard size (in GB) of the exported model."}
)
plot_loss: Optional[bool] = field(
default=False,
metadata={"help": "Whether to plot the training loss after fine-tuning or not."}
@@ -170,6 +193,7 @@ class FinetuningArguments(FreezeArguments, LoraArguments, RLHFArguments):
assert self.finetuning_type in ["lora", "freeze", "full"], "Invalid fine-tuning method."
assert self.ref_model_quantization_bit in [None, 8, 4], "We only accept 4-bit or 8-bit quantization."
assert self.reward_model_quantization_bit in [None, 8, 4], "We only accept 4-bit or 8-bit quantization."
assert self.export_quantization_bit in [None, 8, 4, 3, 2], "We only accept 2/3/4/8-bit quantization."
if self.stage == "ppo" and self.reward_model is None:
raise ValueError("Reward model is necessary for PPO training.")
@@ -177,6 +201,9 @@ class FinetuningArguments(FreezeArguments, LoraArguments, RLHFArguments):
if self.stage == "ppo" and self.reward_model_type == "lora" and self.finetuning_type != "lora":
raise ValueError("Freeze/Full PPO training needs `reward_model_type=full`.")
if self.export_quantization_bit is not None and self.export_quantization_dataset is None:
raise ValueError("Quantization dataset is necessary for exporting.")
def save_to_json(self, json_path: str):
r"""Saves the content of this instance in JSON format inside `json_path`."""
json_string = json.dumps(asdict(self), indent=2, sort_keys=True) + "\n"