update webui #1086

Former-commit-id: 65a48bc398f18f71f5f2659b2070e3b9593af243
This commit is contained in:
hiyouga
2023-10-09 14:50:14 +08:00
parent f22886e2b6
commit 5c4248a29c
10 changed files with 105 additions and 56 deletions

View File

@@ -70,6 +70,9 @@ class Runner:
quantization_bit: str,
template: str,
system_prompt: str,
flash_attn: bool,
shift_attn: bool,
rope_scaling: str,
training_stage: str,
dataset_dir: str,
dataset: List[str],
@@ -86,8 +89,6 @@ class Runner:
logging_steps: int,
save_steps: int,
warmup_steps: int,
flash_attn: bool,
rope_scaling: bool,
lora_rank: int,
lora_dropout: float,
lora_target: str,
@@ -97,9 +98,7 @@ class Runner:
output_dir: str
) -> Tuple[str, str, List[str], str, Dict[str, Any]]:
if checkpoints:
checkpoint_dir = ",".join(
[get_save_dir(model_name, finetuning_type, ckpt) for ckpt in checkpoints]
)
checkpoint_dir = ",".join([get_save_dir(model_name, finetuning_type, ckpt) for ckpt in checkpoints])
else:
checkpoint_dir = None
@@ -119,6 +118,9 @@ class Runner:
quantization_bit=int(quantization_bit) if quantization_bit in ["8", "4"] else None,
template=template,
system_prompt=system_prompt,
flash_attn=flash_attn,
shift_attn=shift_attn,
rope_scaling=rope_scaling if rope_scaling in ["linear", "dynamic"] else None,
dataset_dir=dataset_dir,
dataset=",".join(dataset),
cutoff_len=cutoff_len,
@@ -132,8 +134,6 @@ class Runner:
logging_steps=logging_steps,
save_steps=save_steps,
warmup_steps=warmup_steps,
flash_attn=flash_attn,
rope_scaling="linear" if rope_scaling else None,
lora_rank=lora_rank,
lora_dropout=lora_dropout,
lora_target=lora_target or DEFAULT_MODULE.get(model_name.split("-")[0], "q_proj,v_proj"),
@@ -168,6 +168,9 @@ class Runner:
quantization_bit: str,
template: str,
system_prompt: str,
flash_attn: bool,
shift_attn: bool,
rope_scaling: str,
dataset_dir: str,
dataset: List[str],
cutoff_len: int,
@@ -179,9 +182,7 @@ class Runner:
temperature: float
) -> Tuple[str, str, List[str], str, Dict[str, Any]]:
if checkpoints:
checkpoint_dir = ",".join(
[get_save_dir(model_name, finetuning_type, ckpt) for ckpt in checkpoints]
)
checkpoint_dir = ",".join([get_save_dir(model_name, finetuning_type, ckpt) for ckpt in checkpoints])
output_dir = get_save_dir(model_name, finetuning_type, "eval_" + "_".join(checkpoints))
else:
checkpoint_dir = None
@@ -202,6 +203,9 @@ class Runner:
quantization_bit=int(quantization_bit) if quantization_bit in ["8", "4"] else None,
template=template,
system_prompt=system_prompt,
flash_attn=flash_attn,
shift_attn=shift_attn,
rope_scaling=rope_scaling if rope_scaling in ["linear", "dynamic"] else None,
dataset_dir=dataset_dir,
dataset=",".join(dataset),
cutoff_len=cutoff_len,