Former-commit-id: 032245647848aaa4167086636b6c985268c5fee3
This commit is contained in:
hiyouga
2023-09-21 19:51:02 +08:00
parent 95c0d9ab24
commit dc68c313ee
11 changed files with 116 additions and 101 deletions

View File

@@ -42,12 +42,16 @@ class DataArguments:
default="train",
metadata={"help": "Which dataset split to use for training and evaluation."}
)
cutoff_len: Optional[int] = field(
default=1024,
metadata={"help": "The maximum length of the model inputs after tokenization."}
)
streaming: Optional[bool] = field(
default=False,
metadata={"help": "Enable streaming mode."}
)
buffer_size: Optional[int] = field(
default=1024,
default=16384,
metadata={"help": "Size of the buffer to randomly sample examples from in streaming mode."}
)
mix_strategy: Optional[Literal["concat", "interleave_under", "interleave_over"]] = field(
@@ -66,14 +70,6 @@ class DataArguments:
default=None,
metadata={"help": "The number of processes to use for the preprocessing."}
)
max_source_length: Optional[int] = field(
default=512,
metadata={"help": "The maximum total input sequence length after tokenization."}
)
max_target_length: Optional[int] = field(
default=512,
metadata={"help": "The maximum total output sequence length after tokenization."}
)
max_samples: Optional[int] = field(
default=None,
metadata={"help": "For debugging purposes, truncate the number of examples for each dataset."}

View File

@@ -63,18 +63,10 @@ class ModelArguments:
default=None,
metadata={"help": "Auth token to log in with Hugging Face Hub."}
)
compute_dtype: Optional[torch.dtype] = field(
default=None,
metadata={"help": "Used in quantization configs. Do not specify this argument manually."}
)
model_max_length: Optional[int] = field(
default=None,
metadata={"help": "Used in rope scaling. Do not specify this argument manually."}
)
def __post_init__(self):
if self.compute_dtype is not None or self.model_max_length is not None:
raise ValueError("These arguments cannot be specified.")
self.compute_dtype = None
self.model_max_length = None
if self.checkpoint_dir is not None: # support merging multiple lora weights
self.checkpoint_dir = [cd.strip() for cd in self.checkpoint_dir.split(",")]