update webUI, fix #179

Former-commit-id: f9074fed5e22585679661588befcf266a79009f2
hiyouga
2023-07-18 15:35:17 +08:00
parent fd8c2d4aac
commit a864a7b395
9 changed files with 247 additions and 154 deletions


@@ -3,7 +3,7 @@ import os
 import threading
 import time
 import transformers
-from typing import Optional, Tuple
+from typing import List, Optional, Tuple

 from llmtuner.extras.callbacks import LogCallback
 from llmtuner.extras.constants import DEFAULT_MODULE # will be deprecated
@@ -59,10 +59,26 @@ class Runner:
         return finish_info if finish_info is not None else ALERTS["info_finished"][lang]

     def run_train(
-        self, lang, model_name, checkpoints, finetuning_type, template,
-        dataset, dataset_dir, learning_rate, num_train_epochs, max_samples,
-        fp16, quantization_bit, batch_size, gradient_accumulation_steps,
-        lr_scheduler_type, logging_steps, save_steps, output_dir
+        self,
+        lang: str,
+        model_name: str,
+        checkpoints: List[str],
+        finetuning_type: str,
+        quantization_bit: str,
+        template: str,
+        source_prefix: str,
+        dataset_dir: str,
+        dataset: List[str],
+        learning_rate: str,
+        num_train_epochs: str,
+        max_samples: str,
+        batch_size: int,
+        gradient_accumulation_steps: int,
+        lr_scheduler_type: str,
+        fp16: bool,
+        logging_steps: int,
+        save_steps: int,
+        output_dir: str
     ):
         model_name_or_path, error, logger_handler, trainer_callback = self.initialize(lang, model_name, dataset)
         if error:
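
Note: the rewritten signature expands the old packed parameter list into one
annotated parameter per web-UI component (hence the new List import above). A
minimal call sketch with hypothetical values, assuming Runner() takes no
constructor arguments; run_train is a generator that yields progress strings,
as the later yield self.finalize(lang) suggests:

    runner = Runner()
    for info in runner.run_train(
        lang="en",
        model_name="LLaMA-7B",            # hypothetical model name
        checkpoints=[],                   # no checkpoint selected
        finetuning_type="lora",
        quantization_bit="",              # empty string -> no quantization
        template="default",
        source_prefix="",
        dataset_dir="data",
        dataset=["alpaca_en"],            # hypothetical dataset name
        learning_rate="5e-5",             # numeric fields arrive as strings
        num_train_epochs="3.0",
        max_samples="100000",
        batch_size=4,
        gradient_accumulation_steps=4,
        lr_scheduler_type="cosine",
        fp16=True,
        logging_steps=5,
        save_steps=100,
        output_dir="test_run"
    ):
        print(info)                       # progress/alert messages
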
@@ -79,24 +95,25 @@ class Runner:
         args = dict(
             model_name_or_path=model_name_or_path,
             do_train=True,
-            finetuning_type=finetuning_type,
-            lora_target=DEFAULT_MODULE.get(model_name.split("-")[0], None) or "q_proj,v_proj",
-            prompt_template=template,
-            dataset=",".join(dataset),
-            dataset_dir=dataset_dir,
-            max_samples=int(max_samples),
-            output_dir=os.path.join(get_save_dir(model_name), finetuning_type, output_dir),
-            checkpoint_dir=checkpoint_dir,
             overwrite_cache=True,
+            lora_target=DEFAULT_MODULE.get(model_name.split("-")[0], None) or "q_proj,v_proj",
+            checkpoint_dir=checkpoint_dir,
+            finetuning_type=finetuning_type,
+            quantization_bit=int(quantization_bit) if quantization_bit else None,
+            prompt_template=template,
+            source_prefix=source_prefix,
+            dataset_dir=dataset_dir,
+            dataset=",".join(dataset),
+            learning_rate=float(learning_rate),
+            num_train_epochs=float(num_train_epochs),
+            max_samples=int(max_samples),
             per_device_train_batch_size=batch_size,
             gradient_accumulation_steps=gradient_accumulation_steps,
             lr_scheduler_type=lr_scheduler_type,
+            fp16=fp16,
             logging_steps=logging_steps,
             save_steps=save_steps,
-            learning_rate=float(learning_rate),
-            num_train_epochs=float(num_train_epochs),
-            fp16=fp16,
-            quantization_bit=int(quantization_bit) if quantization_bit else None
+            output_dir=os.path.join(get_save_dir(model_name), finetuning_type, output_dir)
         )

         model_args, data_args, training_args, finetuning_args, _ = get_train_args(args)
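
The reordered dict also gathers in one place the string-to-number coercions
applied to the raw web-UI values. A quick sketch of what each expression
evaluates to (hypothetical inputs):

    quantization_bit = "4"                                        # textbox value
    print(int(quantization_bit) if quantization_bit else None)    # 4
    quantization_bit = ""                                         # left empty
    print(int(quantization_bit) if quantization_bit else None)    # None
    print(float("5e-5"))                                          # 5e-05
    print(",".join(["alpaca_en", "alpaca_zh"]))                   # "alpaca_en,alpaca_zh"
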
@@ -120,8 +137,19 @@ class Runner:
         yield self.finalize(lang)

     def run_eval(
-        self, lang, model_name, checkpoints, finetuning_type, template,
-        dataset, dataset_dir, max_samples, batch_size, quantization_bit, predict
+        self,
+        lang: str,
+        model_name: str,
+        checkpoints: List[str],
+        finetuning_type: str,
+        quantization_bit: str,
+        template: str,
+        source_prefix: str,
+        dataset_dir: str,
+        dataset: List[str],
+        max_samples: str,
+        batch_size: int,
+        predict: bool
     ):
         model_name_or_path, error, logger_handler, trainer_callback = self.initialize(lang, model_name, dataset)
         if error:
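
A parallel call sketch for evaluation (hypothetical values again): run_eval
shares the model and dataset parameters with run_train but drops the
optimizer-related ones and adds the predict flag:

    for info in runner.run_eval(
        lang="en",
        model_name="LLaMA-7B",
        checkpoints=["checkpoint-100"],   # hypothetical checkpoint folder
        finetuning_type="lora",
        quantization_bit="8",             # evaluate the model in 8-bit
        template="default",
        source_prefix="",
        dataset_dir="data",
        dataset=["alpaca_en"],
        max_samples="1000",
        batch_size=8,
        predict=False                     # True switches to prediction mode
    ):
        print(info)
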
@@ -140,17 +168,18 @@ class Runner:
         args = dict(
             model_name_or_path=model_name_or_path,
             do_eval=True,
-            finetuning_type=finetuning_type,
-            prompt_template=template,
-            dataset=",".join(dataset),
-            dataset_dir=dataset_dir,
-            max_samples=int(max_samples),
-            output_dir=output_dir,
-            checkpoint_dir=checkpoint_dir,
             overwrite_cache=True,
             predict_with_generate=True,
+            checkpoint_dir=checkpoint_dir,
+            finetuning_type=finetuning_type,
+            quantization_bit=int(quantization_bit) if quantization_bit else None,
+            prompt_template=template,
+            source_prefix=source_prefix,
+            dataset_dir=dataset_dir,
+            dataset=",".join(dataset),
+            max_samples=int(max_samples),
             per_device_eval_batch_size=batch_size,
-            quantization_bit=int(quantization_bit) if quantization_bit else None
+            output_dir=output_dir
         )

         if predict:
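
The predict branch is cut off by the hunk boundary; a plausible reconstruction
(an assumption, not the verbatim source) is that it swaps the evaluation flag
for generation-based prediction:

    if predict:                          # hypothetical reconstruction
        args.pop("do_eval", None)        # drop the do_eval flag set above
        args["do_predict"] = True        # predict with generate instead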