fix generating args

Former-commit-id: 52805a8441bd7b324bd89489de60f18f103c8e4c
Author: hiyouga
Date:   2023-06-13 01:33:56 +08:00
Parent: 4724ae3492
Commit: 6828f07d54
5 changed files with 20 additions and 16 deletions


@@ -87,6 +87,8 @@ class ModelArguments:
         if self.checkpoint_dir is not None: # support merging multiple lora weights
             self.checkpoint_dir = [cd.strip() for cd in self.checkpoint_dir.split(",")]
+        if self.quantization_bit is not None:
+            assert self.quantization_bit in [4, 8], "We only accept 4-bit or 8-bit quantization."
 
 @dataclass
 class DataTrainingArguments:
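
The added check surfaces an unsupported `quantization_bit` at argument-parsing time instead of failing later during model loading. The surrounding lines suggest the check lives in `__post_init__`; a minimal sketch of the pattern, using a trimmed stand-in dataclass rather than the project's full `ModelArguments`:

```python
from dataclasses import dataclass, field
from typing import Optional

@dataclass
class QuantArgs:  # hypothetical stand-in for ModelArguments
    quantization_bit: Optional[int] = field(default=None)

    def __post_init__(self):
        # Fail fast: only 4-bit and 8-bit quantization are supported.
        if self.quantization_bit is not None:
            assert self.quantization_bit in [4, 8], "We only accept 4-bit or 8-bit quantization."

QuantArgs(quantization_bit=4)    # passes
# QuantArgs(quantization_bit=3)  # would raise AssertionError
```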
@@ -125,7 +127,7 @@ class DataTrainingArguments:
         default=None,
         metadata={"help": "For debugging purposes, truncate the number of examples for each dataset."}
     )
-    num_beams: Optional[int] = field(
+    eval_num_beams: Optional[int] = field(
         default=None,
         metadata={"help": "Number of beams to use for evaluation. This argument will be passed to `model.generate`"}
     )
@@ -164,7 +166,7 @@ class DataTrainingArguments:
dataset_attr = DatasetAttr(
"file",
file_name=dataset_info[name]["file_name"],
file_sha1=dataset_info[name]["file_sha1"] if "file_sha1" in dataset_info[name] else None
file_sha1=dataset_info[name].get("file_sha1", None)
)
if "columns" in dataset_info[name]:
@@ -262,7 +264,7 @@ class GeneratingArguments:
         default=50,
         metadata={"help": "The number of highest probability vocabulary tokens to keep for top-k filtering."}
     )
-    infer_num_beams: Optional[int] = field(
+    num_beams: Optional[int] = field(
         default=1,
         metadata={"help": "Number of beams for beam search. 1 means no beam search."}
     )
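
Before this commit, the plain name `num_beams` belonged to `DataTrainingArguments`, which is presumably why `GeneratingArguments` carried the prefixed `infer_num_beams`: when several dataclasses are fed to a single `HfArgumentParser`, duplicate field names map onto the same `--num_beams` option and parser construction fails. A sketch of that collision, assuming both dataclasses are parsed together (the class names here are illustrative):

```python
from dataclasses import dataclass, field
from typing import Optional

from transformers import HfArgumentParser

@dataclass
class DataArgs:
    num_beams: Optional[int] = field(default=None)

@dataclass
class GenArgs:
    num_beams: Optional[int] = field(default=1)

# Both fields map onto the same --num_beams option string, so building
# the parser raises argparse.ArgumentError.
try:
    HfArgumentParser((DataArgs, GenArgs))
except Exception as err:
    print(type(err).__name__, ":", err)
```

Renaming the evaluation-side field to `eval_num_beams` resolves the clash and frees `GeneratingArguments` to use the name `model.generate` actually expects.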
@@ -276,7 +278,4 @@ class GeneratingArguments:
     )
 
     def to_dict(self) -> Dict[str, Any]:
-        data_dict = asdict(self)
-        num_beams = data_dict.pop("infer_num_beams")
-        data_dict["num_beams"] = num_beams
-        return data_dict
+        return asdict(self)
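
With the field named `num_beams` directly, `to_dict()` no longer needs the pop/rename step: every key already matches a `model.generate` keyword argument. A trimmed, runnable sketch of the simplified flow (the field set here is illustrative, not the full class):

```python
from dataclasses import asdict, dataclass, field
from typing import Any, Dict, Optional

@dataclass
class GenArgs:  # trimmed stand-in for GeneratingArguments
    temperature: Optional[float] = field(default=0.95)
    top_k: Optional[int] = field(default=50)
    num_beams: Optional[int] = field(default=1)

    def to_dict(self) -> Dict[str, Any]:
        # Keys now line up with model.generate's parameters, so a plain
        # asdict() suffices -- no pop("infer_num_beams") remapping.
        return asdict(self)

gen_kwargs = GenArgs(num_beams=4).to_dict()
# Downstream: model.generate(input_ids=..., **gen_kwargs)
print(gen_kwargs)  # {'temperature': 0.95, 'top_k': 50, 'num_beams': 4}
```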