add logits processor

Former-commit-id: f6f4b1554ae1e8849b437d705ffa34ce7ebd56bb
This commit is contained in:
hiyouga
2023-06-03 16:34:54 +08:00
parent ec48d06b9e
commit 9b8b6623ac
5 changed files with 22 additions and 16 deletions

View File

@@ -49,6 +49,14 @@ class ModelArguments:
default=None,
metadata={"help": "The number of bits to quantize the model."}
)
quantization_type: Optional[Literal["fp4", "nf4"]] = field(
default="nf4",
metadata={"help": "Quantization data type to use."}
)
double_quantization: Optional[bool] = field(
default=True,
metadata={"help": "Compress the quantization statistics through double quantization."}
)
checkpoint_dir: Optional[str] = field(
default=None,
metadata={"help": "Path to the directory containing the model checkpoints as well as the configurations."}
@@ -206,14 +214,14 @@ class FinetuningArguments:
assert self.finetuning_type in ["none", "freeze", "lora", "full"], "Invalid fine-tuning method."
def save_to_json(self, json_path: str):
"""Save the content of this instance in JSON format inside `json_path`."""
"""Saves the content of this instance in JSON format inside `json_path`."""
json_string = json.dumps(asdict(self), indent=2, sort_keys=True) + "\n"
with open(json_path, "w", encoding="utf-8") as f:
f.write(json_string)
@classmethod
def load_from_json(cls, json_path: str):
"""Create an instance from the content of `json_path`."""
"""Creates an instance from the content of `json_path`."""
with open(json_path, "r", encoding="utf-8") as f:
text = f.read()
return cls(**json.loads(text))