better llamaboard

* easily resume from checkpoint
* support full and freeze checkpoints (see the sketch after this list)
* faster ui
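
The bullets above map onto the WebUI helpers changed in the diff below. As a rough sketch of the second bullet (the helper name and argument keys here are illustrative assumptions, not the exact wiring of this commit): adapter-style checkpoints are stacked on top of the base model, while full and freeze checkpoints replace the model path entirely.

# Illustrative sketch only: how full/freeze checkpoints could be mapped to launch
# arguments differently from LoRA adapters (keys and paths below are assumptions).
from typing import Dict

PEFT_METHODS = {"lora"}  # assumed to mirror the constant imported in the diff below

def checkpoint_args(finetuning_type: str, base_model: str, checkpoint_dir: str) -> Dict[str, str]:
    if finetuning_type in PEFT_METHODS:
        # resume by attaching the adapter checkpoint to the base model
        return {"model_name_or_path": base_model, "adapter_name_or_path": checkpoint_dir}
    # full or freeze: the checkpoint directory already contains the whole model
    return {"model_name_or_path": checkpoint_dir}

print(checkpoint_args("lora", "meta-llama/Llama-2-7b-hf", "saves/llama2-7b/lora/run1"))
print(checkpoint_args("full", "meta-llama/Llama-2-7b-hf", "saves/llama2-7b/full/run1"))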


Former-commit-id: 84cfb2452cc86b037ccddee6e833f8eb7c129fa4
Author: hiyouga
Date: 2024-05-29 23:55:38 +08:00
Parent: f90c4ca672
Commit: 87aa332583
14 changed files with 303 additions and 193 deletions


@@ -3,12 +3,13 @@ import os
 from datetime import datetime
 from typing import Any, Dict, List, Optional, Tuple
-from yaml import safe_dump
+from transformers.trainer_utils import get_last_checkpoint
+from yaml import safe_dump, safe_load
-from ..extras.constants import RUNNING_LOG, TRAINER_CONFIG, TRAINER_LOG
+from ..extras.constants import PEFT_METHODS, RUNNING_LOG, TRAINER_CONFIG, TRAINER_LOG, TRAINING_STAGES
 from ..extras.packages import is_gradio_available, is_matplotlib_available
 from ..extras.ploting import gen_loss_plot
-from .common import DEFAULT_CACHE_DIR
+from .common import DEFAULT_CACHE_DIR, DEFAULT_CONFIG_DIR, get_arg_save_path, get_save_dir
 from .locales import ALERTS
@@ -17,13 +18,26 @@ if is_gradio_available():
 def can_quantize(finetuning_type: str) -> "gr.Dropdown":
-    if finetuning_type != "lora":
+    r"""
+    Judges if the quantization is available in this finetuning type.
+    """
+    if finetuning_type not in PEFT_METHODS:
         return gr.Dropdown(value="none", interactive=False)
     else:
         return gr.Dropdown(interactive=True)
+def change_stage(training_stage: str = list(TRAINING_STAGES.keys())[0]) -> Tuple[List[str], bool]:
+    r"""
+    Modifies states after changing the training stage.
+    """
+    return [], TRAINING_STAGES[training_stage] == "pt"
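
A standalone sketch of the change_stage contract above, using an assumed two-entry TRAINING_STAGES mapping (the real table lives in extras/constants.py); can_quantize behaves analogously, enabling the quantization dropdown only for adapter methods listed in PEFT_METHODS.

# Minimal, gradio-free sketch (the TRAINING_STAGES values here are assumptions).
from typing import List, Tuple

TRAINING_STAGES = {"Supervised Fine-Tuning": "sft", "Pre-Training": "pt"}  # assumed subset

def change_stage(training_stage: str = list(TRAINING_STAGES.keys())[0]) -> Tuple[List[str], bool]:
    # reset the selected datasets and report whether packing should be pre-enabled (pretraining)
    return [], TRAINING_STAGES[training_stage] == "pt"

print(change_stage("Pre-Training"))            # ([], True)
print(change_stage("Supervised Fine-Tuning"))  # ([], False)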
 def check_json_schema(text: str, lang: str) -> None:
+    r"""
+    Checks if the json schema is valid.
+    """
     try:
         tools = json.loads(text)
         if tools:
@@ -38,11 +52,17 @@ def check_json_schema(text: str, lang: str) -> None:
 def clean_cmd(args: Dict[str, Any]) -> Dict[str, Any]:
+    r"""
+    Removes args with NoneType or False or empty string value.
+    """
     no_skip_keys = ["packing"]
     return {k: v for k, v in args.items() if (k in no_skip_keys) or (v is not None and v is not False and v != "")}
 def gen_cmd(args: Dict[str, Any]) -> str:
+    r"""
+    Generates arguments for previewing.
+    """
     cmd_lines = ["llamafactory-cli train "]
     for k, v in clean_cmd(args).items():
         cmd_lines.append(" --{} {} ".format(k, str(v)))
@@ -52,17 +72,39 @@ def gen_cmd(args: Dict[str, Any]) -> str:
     return cmd_text
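
A self-contained sketch of the cleaning rule implemented by clean_cmd above: None, False and empty-string values are dropped, except for "packing", which is kept even when False; gen_cmd then joins the surviving pairs into a llamafactory-cli preview string. The argument values below are examples only.

# Standalone demonstration of the filtering rule (mirrors clean_cmd above).
from typing import Any, Dict

def clean_cmd(args: Dict[str, Any]) -> Dict[str, Any]:
    no_skip_keys = ["packing"]
    return {k: v for k, v in args.items() if (k in no_skip_keys) or (v is not None and v is not False and v != "")}

args = {"model_name_or_path": "meta-llama/Llama-2-7b-hf", "packing": False, "quantization_bit": None, "lora_target": ""}
print(clean_cmd(args))  # {'model_name_or_path': 'meta-llama/Llama-2-7b-hf', 'packing': False}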
+def save_cmd(args: Dict[str, Any]) -> str:
+    r"""
+    Saves arguments to launch training.
+    """
+    output_dir = args["output_dir"]
+    os.makedirs(output_dir, exist_ok=True)
+    with open(os.path.join(output_dir, TRAINER_CONFIG), "w", encoding="utf-8") as f:
+        safe_dump(clean_cmd(args), f)
+    return os.path.join(output_dir, TRAINER_CONFIG)
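
save_cmd writes the cleaned arguments as YAML (TRAINER_CONFIG) into the run's output directory and returns that path. One plausible way the WebUI runner then consumes it, sketched here with an assumed path and file name rather than the commit's exact wiring:

# Hedged sketch: handing the saved YAML to the CLI launcher (path and filename are assumptions).
import subprocess

config_path = "saves/llama2-7b/lora/run1/training_args.yaml"
subprocess.run(["llamafactory-cli", "train", config_path], check=True)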
 def get_eval_results(path: os.PathLike) -> str:
+    r"""
+    Gets scores after evaluation.
+    """
     with open(path, "r", encoding="utf-8") as f:
         result = json.dumps(json.load(f), indent=4)
     return "```json\n{}\n```\n".format(result)
 def get_time() -> str:
+    r"""
+    Gets current date and time.
+    """
     return datetime.now().strftime(r"%Y-%m-%d-%H-%M-%S")
 def get_trainer_info(output_path: os.PathLike, do_train: bool) -> Tuple[str, "gr.Slider", Optional["gr.Plot"]]:
+    r"""
+    Gets training information for monitor.
+    """
     running_log = ""
     running_progress = gr.Slider(visible=False)
     running_loss = None
@@ -96,17 +138,56 @@ def get_trainer_info(output_path: os.PathLike, do_train: bool) -> Tuple[str, "gr
     return running_log, running_progress, running_loss
-def save_cmd(args: Dict[str, Any]) -> str:
-    output_dir = args["output_dir"]
-    os.makedirs(output_dir, exist_ok=True)
-    with open(os.path.join(output_dir, TRAINER_CONFIG), "w", encoding="utf-8") as f:
-        safe_dump(clean_cmd(args), f)
-    return os.path.join(output_dir, TRAINER_CONFIG)
+def load_args(config_path: str) -> Optional[Dict[str, Any]]:
+    r"""
+    Loads saved arguments.
+    """
+    try:
+        with open(get_arg_save_path(config_path), "r", encoding="utf-8") as f:
+            return safe_load(f)
+    except Exception:
+        return None
-def save_ds_config() -> None:
+def save_args(config_path: str, config_dict: Dict[str, Any]) -> str:
+    r"""
+    Saves arguments.
+    """
+    os.makedirs(DEFAULT_CONFIG_DIR, exist_ok=True)
+    with open(get_arg_save_path(config_path), "w", encoding="utf-8") as f:
+        safe_dump(config_dict, f)
+    return str(get_arg_save_path(config_path))
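
save_args and load_args give the board persistent, named argument presets under DEFAULT_CONFIG_DIR. A roundtrip sketch with get_arg_save_path approximated as a simple join; the real helper and directory constant live in webui/common.py, so both are assumptions here:

# Self-contained roundtrip sketch; DEFAULT_CONFIG_DIR and get_arg_save_path are approximations.
import os
from typing import Any, Dict, Optional
from yaml import safe_dump, safe_load

DEFAULT_CONFIG_DIR = "config"  # assumption

def get_arg_save_path(config_path: str) -> str:
    return os.path.join(DEFAULT_CONFIG_DIR, config_path)

def save_args(config_path: str, config_dict: Dict[str, Any]) -> str:
    os.makedirs(DEFAULT_CONFIG_DIR, exist_ok=True)
    with open(get_arg_save_path(config_path), "w", encoding="utf-8") as f:
        safe_dump(config_dict, f)
    return get_arg_save_path(config_path)

def load_args(config_path: str) -> Optional[Dict[str, Any]]:
    try:
        with open(get_arg_save_path(config_path), "r", encoding="utf-8") as f:
            return safe_load(f)
    except Exception:
        return None

save_args("sft-demo.yaml", {"stage": "sft", "learning_rate": 5e-5})
assert load_args("sft-demo.yaml") == {"stage": "sft", "learning_rate": 5e-5}
assert load_args("missing.yaml") is None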
+def list_output_dirs(model_name: str, finetuning_type: str, initial_dir: str) -> "gr.Dropdown":
+    r"""
+    Lists all the directories that training can resume from.
+    """
+    output_dirs = [initial_dir]
+    if model_name:
+        save_dir = get_save_dir(model_name, finetuning_type)
+        if save_dir and os.path.isdir(save_dir):
+            for folder in os.listdir(save_dir):
+                output_dir = os.path.join(save_dir, folder)
+                if os.path.isdir(output_dir) and get_last_checkpoint(output_dir) is not None:
+                    output_dirs.append(folder)
+    return gr.Dropdown(choices=output_dirs)
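
list_output_dirs treats a previous run as resumable only if transformers' get_last_checkpoint finds a checkpoint-<step> subfolder inside it. A small sketch of that test (the directory names are examples):

# Sketch of the resumability check used above.
import os, tempfile
from transformers.trainer_utils import get_last_checkpoint

run_dir = tempfile.mkdtemp()
print(get_last_checkpoint(run_dir))  # None -> the run is not offered for resuming
os.makedirs(os.path.join(run_dir, "checkpoint-100"))
os.makedirs(os.path.join(run_dir, "checkpoint-200"))
print(get_last_checkpoint(run_dir))  # .../checkpoint-200 -> shown in the output dir dropdown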
+def check_output_dir(lang: str, model_name: str, finetuning_type: str, output_dir: str) -> None:
+    r"""
+    Checks if the output dir exists.
+    """
+    if os.path.isdir(get_save_dir(model_name, finetuning_type, output_dir)):
+        gr.Warning(ALERTS["warn_output_dir_exists"][lang])
+def create_ds_config() -> None:
+    r"""
+    Creates deepspeed config.
+    """
     os.makedirs(DEFAULT_CACHE_DIR, exist_ok=True)
     ds_config = {
         "train_batch_size": "auto",