Former-commit-id: 07c8d734529f03e47ef638a1bda222e8824d3d38
This commit is contained in:
hiyouga
2023-11-14 18:07:20 +08:00
parent 87197ba91d
commit c9a4551012
2 changed files with 32 additions and 22 deletions

View File

@@ -52,6 +52,10 @@ class DataArguments:
default=1024,
metadata={"help": "The maximum length of the model inputs after tokenization."}
)
reserved_label_len: Optional[int] = field(
default=1,
metadata={"help": "The maximum length reserved for label after tokenization."}
)
train_on_prompt: Optional[bool] = field(
default=False,
metadata={"help": "Whether to disable the mask on the prompt or not."}
@@ -110,6 +114,9 @@ class DataArguments:
)
def __post_init__(self):
if self.reserved_label_len >= self.cutoff_len:
raise ValueError("`reserved_label_len` must be smaller than `cutoff_len`.")
if self.streaming and self.val_size > 1e-6 and self.val_size < 1:
raise ValueError("Streaming mode should have an integer val size.")