Former-commit-id: 032245647848aaa4167086636b6c985268c5fee3
This commit is contained in:
hiyouga
2023-09-21 19:51:02 +08:00
parent 95c0d9ab24
commit dc68c313ee
11 changed files with 116 additions and 101 deletions

View File

@@ -73,11 +73,11 @@ class Runner:
training_stage: str,
dataset_dir: str,
dataset: List[str],
max_source_length: int,
max_target_length: int,
cutoff_len: int,
learning_rate: str,
num_train_epochs: str,
max_samples: str,
compute_type: str,
batch_size: int,
gradient_accumulation_steps: int,
lr_scheduler_type: str,
@@ -86,7 +86,8 @@ class Runner:
logging_steps: int,
save_steps: int,
warmup_steps: int,
compute_type: str,
flash_attn: bool,
rope_scaling: bool,
lora_rank: int,
lora_dropout: float,
lora_target: str,
@@ -120,8 +121,7 @@ class Runner:
system_prompt=system_prompt,
dataset_dir=dataset_dir,
dataset=",".join(dataset),
max_source_length=max_source_length,
max_target_length=max_target_length,
cutoff_len=cutoff_len,
learning_rate=float(learning_rate),
num_train_epochs=float(num_train_epochs),
max_samples=int(max_samples),
@@ -132,6 +132,8 @@ class Runner:
logging_steps=logging_steps,
save_steps=save_steps,
warmup_steps=warmup_steps,
flash_attn=flash_attn,
rope_scaling="linear" if rope_scaling else None,
lora_rank=lora_rank,
lora_dropout=lora_dropout,
lora_target=lora_target or DEFAULT_MODULE.get(model_name.split("-")[0], "q_proj,v_proj"),
@@ -168,11 +170,13 @@ class Runner:
system_prompt: str,
dataset_dir: str,
dataset: List[str],
max_source_length: int,
max_target_length: int,
cutoff_len: int,
max_samples: str,
batch_size: int,
predict: bool
predict: bool,
max_new_tokens: int,
top_p: float,
temperature: float
) -> Tuple[str, str, List[str], str, Dict[str, Any]]:
if checkpoints:
checkpoint_dir = ",".join(
@@ -200,10 +204,12 @@ class Runner:
system_prompt=system_prompt,
dataset_dir=dataset_dir,
dataset=",".join(dataset),
max_source_length=max_source_length,
max_target_length=max_target_length,
cutoff_len=cutoff_len,
max_samples=int(max_samples),
per_device_eval_batch_size=batch_size,
max_new_tokens=max_new_tokens,
top_p=top_p,
temperature=temperature,
output_dir=output_dir
)