update UI, fix #212

Former-commit-id: ac92c2bd7c47353759474fad9412f21b38c65501
This commit is contained in:
hiyouga
2023-07-20 22:09:06 +08:00
parent 64db4abc68
commit 9b3304b054
12 changed files with 155 additions and 45 deletions

View File

@@ -6,7 +6,7 @@ import transformers
from typing import List, Optional, Tuple
from llmtuner.extras.callbacks import LogCallback
from llmtuner.extras.constants import DEFAULT_MODULE # will be deprecated
from llmtuner.extras.constants import DEFAULT_MODULE
from llmtuner.extras.logging import LoggerHandler
from llmtuner.extras.misc import torch_gc
from llmtuner.tuner import get_train_args, run_sft
@@ -77,10 +77,15 @@ class Runner:
batch_size: int,
gradient_accumulation_steps: int,
lr_scheduler_type: str,
max_grad_norm: str,
dev_ratio: float,
fp16: bool,
logging_steps: int,
save_steps: int,
warmup_steps: int,
compute_type: str,
lora_rank: int,
lora_dropout: float,
lora_target: str,
output_dir: str
):
model_name_or_path, error, logger_handler, trainer_callback = self.initialize(lang, model_name, dataset)
@@ -99,7 +104,6 @@ class Runner:
model_name_or_path=model_name_or_path,
do_train=True,
overwrite_cache=True,
lora_target=DEFAULT_MODULE.get(model_name.split("-")[0], None) or "q_proj,v_proj",
checkpoint_dir=checkpoint_dir,
finetuning_type=finetuning_type,
quantization_bit=int(quantization_bit) if quantization_bit else None,
@@ -115,9 +119,15 @@ class Runner:
per_device_train_batch_size=batch_size,
gradient_accumulation_steps=gradient_accumulation_steps,
lr_scheduler_type=lr_scheduler_type,
fp16=fp16,
max_grad_norm=float(max_grad_norm),
logging_steps=logging_steps,
save_steps=save_steps,
warmup_steps=warmup_steps,
fp16=(compute_type == "fp16"),
bf16=(compute_type == "bf16"),
lora_rank=lora_rank,
lora_dropout=lora_dropout,
lora_target=lora_target or DEFAULT_MODULE.get(model_name.split("-")[0], "q_proj,v_proj"),
output_dir=os.path.join(get_save_dir(model_name), finetuning_type, output_dir)
)