support rope scaling, fix #475 #476 #478

Former-commit-id: 337d5f68b72230e545e7a94ca789187c7a2b7187
hiyouga
2023-08-12 20:46:27 +08:00
parent cde9f3db57
commit fdfb644f0a
12 changed files with 267 additions and 277 deletions
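The RoPE scaling support named in the commit title is wired up in the other changed files; the excerpt below covers only the web UI runner refactor. For orientation, a minimal sketch of how RoPE scaling is typically switched on through the transformers config, assuming transformers >= 4.31 and a LLaMA-style checkpoint; the names and the scaling policy here are illustrative, not necessarily the ones this commit adds:

    # Illustrative only: enable RoPE scaling when loading a LLaMA-style model.
    # Requires transformers >= 4.31; the option names exposed by this commit may differ.
    from transformers import AutoConfig, AutoModelForCausalLM

    model_name_or_path = "meta-llama/Llama-2-7b-hf"   # placeholder checkpoint
    target_length = 8192                              # desired context window

    config = AutoConfig.from_pretrained(model_name_or_path)
    orig_length = getattr(config, "max_position_embeddings", 2048)

    # "linear" interpolates positions down by `factor`; "dynamic" (NTK-aware)
    # rescales the rotary base at inference time instead.
    config.rope_scaling = {"type": "linear", "factor": max(1.0, target_length / orig_length)}

    model = AutoModelForCausalLM.from_pretrained(model_name_or_path, config=config)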


@@ -5,16 +5,16 @@ import threading
 import time
 import transformers
 from transformers.trainer import TRAINING_ARGS_NAME
-from typing import Generator, List, Tuple
+from typing import Any, Dict, Generator, List, Tuple
 
 from llmtuner.extras.callbacks import LogCallback
-from llmtuner.extras.constants import DEFAULT_MODULE, SFT_SCRIPT_PREFIX
+from llmtuner.extras.constants import DEFAULT_MODULE
 from llmtuner.extras.logging import LoggerHandler
 from llmtuner.extras.misc import torch_gc
 from llmtuner.tuner import run_exp
 from llmtuner.webui.common import get_model_path, get_save_dir
 from llmtuner.webui.locales import ALERTS
-from llmtuner.webui.utils import get_eval_results, update_process_bar
+from llmtuner.webui.utils import gen_cmd, get_eval_results, update_process_bar
 
 
 class Runner:
@@ -22,39 +22,36 @@ class Runner:
     def __init__(self):
         self.aborted = False
         self.running = False
+        self.logger_handler = LoggerHandler()
+        self.logger_handler.setLevel(logging.INFO)
+        logging.root.addHandler(self.logger_handler)
+        transformers.logging.add_handler(self.logger_handler)
 
     def set_abort(self):
         self.aborted = True
         self.running = False
 
-    def initialize(
+    def _initialize(
         self, lang: str, model_name: str, dataset: List[str]
-    ) -> Tuple[str, str, LoggerHandler, LogCallback]:
+    ) -> str:
         if self.running:
-            return None, ALERTS["err_conflict"][lang], None, None
+            return ALERTS["err_conflict"][lang]
 
         if not model_name:
-            return None, ALERTS["err_no_model"][lang], None, None
+            return ALERTS["err_no_model"][lang]
 
-        model_name_or_path = get_model_path(model_name)
-        if not model_name_or_path:
-            return None, ALERTS["err_no_path"][lang], None, None
+        if not get_model_path(model_name):
+            return ALERTS["err_no_path"][lang]
 
         if len(dataset) == 0:
-            return None, ALERTS["err_no_dataset"][lang], None, None
+            return ALERTS["err_no_dataset"][lang]
 
         self.aborted = False
-        self.running = True
+        self.logger_handler.reset()
+        self.trainer_callback = LogCallback(self)
+        return ""
 
-        logger_handler = LoggerHandler()
-        logger_handler.setLevel(logging.INFO)
-        logging.root.addHandler(logger_handler)
-        transformers.logging.add_handler(logger_handler)
-        trainer_callback = LogCallback(self)
-
-        return model_name_or_path, "", logger_handler, trainer_callback
-
-    def finalize(
+    def _finalize(
         self, lang: str, finish_info: str
     ) -> str:
         self.running = False
@@ -64,7 +61,7 @@ class Runner:
         else:
             return finish_info
 
-    def run_train(
+    def _parse_train_args(
         self,
         lang: str,
         model_name: str,
@@ -95,52 +92,19 @@ class Runner:
         lora_target: str,
         resume_lora_training: bool,
         output_dir: str
-    ) -> Generator[str, None, None]:
-        model_name_or_path, error, logger_handler, trainer_callback = self.initialize(lang, model_name, dataset)
-        if error:
-            yield error, gr.update(visible=False)
-            return
-
-        output_dir = os.path.join(get_save_dir(model_name), finetuning_type, output_dir)
-        args = self._build_args(batch_size, checkpoints, compute_type, dataset, dataset_dir, finetuning_type,
-                                gradient_accumulation_steps, learning_rate, logging_steps, lora_dropout, lora_rank,
-                                lora_target, lr_scheduler_type, max_grad_norm, max_samples, max_source_length,
-                                max_target_length, model_name, model_name_or_path, num_train_epochs, output_dir,
-                                padding_side, quantization_bit, resume_lora_training, save_steps, source_prefix,
-                                template, val_size, warmup_steps)
-
-        run_kwargs = dict(args=args, callbacks=[trainer_callback])
-        thread = threading.Thread(target=run_exp, kwargs=run_kwargs)
-        thread.start()
-
-        while thread.is_alive():
-            time.sleep(2)
-            if self.aborted:
-                yield ALERTS["info_aborting"][lang], gr.update(visible=False)
-            else:
-                yield logger_handler.log, update_process_bar(trainer_callback)
-
-        if os.path.exists(os.path.join(output_dir, TRAINING_ARGS_NAME)):
-            finish_info = ALERTS["info_finished"][lang]
-        else:
-            finish_info = ALERTS["err_failed"][lang]
-
-        yield self.finalize(lang, finish_info), gr.update(visible=False)
-
-    def _build_args(self, batch_size, checkpoints, compute_type, dataset, dataset_dir, finetuning_type,
-                    gradient_accumulation_steps, learning_rate, logging_steps, lora_dropout, lora_rank, lora_target,
-                    lr_scheduler_type, max_grad_norm, max_samples, max_source_length, max_target_length, model_name,
-                    model_name_or_path, num_train_epochs, output_dir, padding_side, quantization_bit,
-                    resume_lora_training, save_steps, source_prefix, template, val_size, warmup_steps):
+    ) -> Tuple[str, str, List[str], str, Dict[str, Any]]:
         if checkpoints:
             checkpoint_dir = ",".join(
-                [os.path.join(get_save_dir(model_name), finetuning_type, checkpoint) for checkpoint in checkpoints]
+                [os.path.join(get_save_dir(model_name), finetuning_type, ckpt) for ckpt in checkpoints]
             )
         else:
             checkpoint_dir = None
+
+        output_dir = os.path.join(get_save_dir(model_name), finetuning_type, output_dir)
 
         args = dict(
             stage="sft",
-            model_name_or_path=model_name_or_path,
+            model_name_or_path=get_model_path(model_name),
             do_train=True,
             overwrite_cache=True,
             checkpoint_dir=checkpoint_dir,
@@ -171,14 +135,16 @@ class Runner:
             resume_lora_training=resume_lora_training,
             output_dir=output_dir
         )
 
         if val_size > 1e-6:
             args["val_size"] = val_size
             args["evaluation_strategy"] = "steps"
             args["eval_steps"] = save_steps
             args["load_best_model_at_end"] = True
 
-        return args
-
-    def run_eval(
+        return lang, model_name, dataset, output_dir, args
+
+    def _parse_eval_args(
         self,
         lang: str,
         model_name: str,
@@ -194,12 +160,7 @@ class Runner:
         max_samples: str,
         batch_size: int,
         predict: bool
-    ) -> Generator[str, None, None]:
-        model_name_or_path, error, logger_handler, trainer_callback = self.initialize(lang, model_name, dataset)
-        if error:
-            yield error, gr.update(visible=False)
-            return
-
+    ) -> Tuple[str, str, List[str], str, Dict[str, Any]]:
         if checkpoints:
             checkpoint_dir = ",".join(
                 [os.path.join(get_save_dir(model_name), finetuning_type, checkpoint) for checkpoint in checkpoints]
@@ -211,7 +172,7 @@ class Runner:
         args = dict(
             stage="sft",
-            model_name_or_path=model_name_or_path,
+            model_name_or_path=get_model_path(model_name),
             do_eval=True,
             overwrite_cache=True,
             predict_with_generate=True,
@@ -233,7 +194,33 @@ class Runner:
             args.pop("do_eval", None)
             args["do_predict"] = True
 
-        run_kwargs = dict(args=args, callbacks=[trainer_callback])
+        return lang, model_name, dataset, output_dir, args
+
+    def preview_train(self, *args) -> Generator[Tuple[str, Dict[str, Any]], None, None]:
+        lang, model_name, dataset, _, args = self._parse_train_args(*args)
+        error = self._initialize(lang, model_name, dataset)
+        if error:
+            yield error, gr.update(visible=False)
+        else:
+            yield gen_cmd(args), gr.update(visible=False)
+
+    def preview_eval(self, *args) -> Generator[Tuple[str, Dict[str, Any]], None, None]:
+        lang, model_name, dataset, _, args = self._parse_eval_args(*args)
+        error = self._initialize(lang, model_name, dataset)
+        if error:
+            yield error, gr.update(visible=False)
+        else:
+            yield gen_cmd(args), gr.update(visible=False)
+
+    def run_train(self, *args) -> Generator[Tuple[str, Dict[str, Any]], None, None]:
+        lang, model_name, dataset, output_dir, args = self._parse_train_args(*args)
+        error = self._initialize(lang, model_name, dataset)
+        if error:
+            yield error, gr.update(visible=False)
+            return
+
+        self.running = True
+        run_kwargs = dict(args=args, callbacks=[self.trainer_callback])
         thread = threading.Thread(target=run_exp, kwargs=run_kwargs)
         thread.start()
@@ -242,60 +229,37 @@ class Runner:
             if self.aborted:
                 yield ALERTS["info_aborting"][lang], gr.update(visible=False)
             else:
-                yield logger_handler.log, update_process_bar(trainer_callback)
+                yield self.logger_handler.log, update_process_bar(self.trainer_callback)
+
+        if os.path.exists(os.path.join(output_dir, TRAINING_ARGS_NAME)):
+            finish_info = ALERTS["info_finished"][lang]
+        else:
+            finish_info = ALERTS["err_failed"][lang]
+
+        yield self._finalize(lang, finish_info), gr.update(visible=False)
+
+    def run_eval(self, *args) -> Generator[str, None, None]:
+        lang, model_name, dataset, output_dir, args = self._parse_eval_args(*args)
+        error = self._initialize(lang, model_name, dataset)
+        if error:
+            yield error, gr.update(visible=False)
+            return
+
+        self.running = True
+        run_kwargs = dict(args=args, callbacks=[self.trainer_callback])
+        thread = threading.Thread(target=run_exp, kwargs=run_kwargs)
+        thread.start()
+
+        while thread.is_alive():
+            time.sleep(2)
+            if self.aborted:
+                yield ALERTS["info_aborting"][lang], gr.update(visible=False)
+            else:
+                yield self.logger_handler.log, update_process_bar(self.trainer_callback)
 
         if os.path.exists(os.path.join(output_dir, "all_results.json")):
             finish_info = get_eval_results(os.path.join(output_dir, "all_results.json"))
         else:
             finish_info = ALERTS["err_failed"][lang]
 
-        yield self.finalize(lang, finish_info), gr.update(visible=False)
-
-    def preview_sft_script(
-        self,
-        lang: str,
-        model_name: str,
-        checkpoints: List[str],
-        finetuning_type: str,
-        quantization_bit: str,
-        template: str,
-        source_prefix: str,
-        dataset_dir: str,
-        dataset: List[str],
-        max_source_length: int,
-        max_target_length: int,
-        learning_rate: str,
-        num_train_epochs: str,
-        max_samples: str,
-        batch_size: int,
-        gradient_accumulation_steps: int,
-        lr_scheduler_type: str,
-        max_grad_norm: str,
-        val_size: float,
-        logging_steps: int,
-        save_steps: int,
-        warmup_steps: int,
-        compute_type: str,
-        padding_side: str,
-        lora_rank: int,
-        lora_dropout: float,
-        lora_target: str,
-        resume_lora_training: bool,
-        output_dir: str
-    ):
-        model_name_or_path, error, logger_handler, trainer_callback = self.initialize(lang, model_name, dataset)
-        output_dir = os.path.join(get_save_dir(model_name), finetuning_type, output_dir)
-        args = self._build_args(batch_size, checkpoints, compute_type, dataset, dataset_dir, finetuning_type,
-                                gradient_accumulation_steps, learning_rate, logging_steps, lora_dropout, lora_rank,
-                                lora_target, lr_scheduler_type, max_grad_norm, max_samples, max_source_length,
-                                max_target_length, model_name, model_name_or_path, num_train_epochs, output_dir,
-                                padding_side, quantization_bit, resume_lora_training, save_steps, source_prefix,
-                                template, val_size, warmup_steps)
-
-        script_lines = [SFT_SCRIPT_PREFIX]
-        for param_key, param_value in args.items():
-            # filter None
-            if param_value:
-                script_lines.append(" --" + param_key + " " + str(param_value) + " ")
-
-        script_str = "\\\n".join(script_lines)
-        return gr.update(value=script_str)
+        yield self._finalize(lang, finish_info), gr.update(visible=False)
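
The new preview_train and preview_eval methods render the parsed arguments through gen_cmd, now imported from llmtuner.webui.utils. Its implementation is not part of this excerpt; a plausible minimal sketch, mirroring the string building of the removed preview_sft_script (a fixed launcher prefix plus one "--key value" pair per non-empty argument, joined with shell line continuations), with the prefix value assumed:

    # Hypothetical sketch of gen_cmd; the real helper lives in
    # llmtuner/webui/utils.py and is not shown in this diff.
    from typing import Any, Dict

    def gen_cmd(args: Dict[str, Any]) -> str:
        # Start from an assumed launcher prefix and append one "--key value"
        # pair per truthy argument, as the removed preview_sft_script did.
        cmd_lines = ["CUDA_VISIBLE_DEVICES=0 python src/train_bash.py "]
        for param_key, param_value in args.items():
            if param_value:  # drop None / empty values
                cmd_lines.append("    --{} {} ".format(param_key, str(param_value)))
        return "\\\n".join(cmd_lines)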