[webui] improve webui & reasoning mode (#6778)

Former-commit-id: 3f17fc0d7163372e0446f1a38792ff761e99b739
Author: hoshi-hiyouga
Date: 2025-01-31 00:09:21 +08:00
Committed by: GitHub
Parent: 4f298894da
Commit: e71737351f
18 changed files with 570 additions and 409 deletions


@@ -14,34 +14,28 @@
import json
import os
import signal
-from collections import defaultdict
-from typing import Any, Dict, Optional, Tuple
+from datetime import datetime
+from typing import Any, Dict, Optional, Union

from psutil import Process
from yaml import safe_dump, safe_load
from ..extras import logging
from ..extras.constants import (
-    CHECKPOINT_NAMES,
    DATA_CONFIG,
    DEFAULT_TEMPLATE,
-    PEFT_METHODS,
-    STAGES_USE_PAIR_DATA,
    SUPPORTED_MODELS,
-    TRAINING_STAGES,
+    TRAINING_ARGS,
    VISION_MODELS,
    DownloadSource,
)
from ..extras.misc import use_modelscope, use_openmind
from ..extras.packages import is_gradio_available

if is_gradio_available():
    import gradio as gr

logger = logging.get_logger(__name__)

DEFAULT_CACHE_DIR = "cache"
DEFAULT_CONFIG_DIR = "config"
DEFAULT_DATA_DIR = "data"
@@ -49,6 +43,21 @@ DEFAULT_SAVE_DIR = "saves"
USER_CONFIG = "user_config.yaml"

+def abort_process(pid: int) -> None:
+    r"""
+    Aborts the process and all of its child processes recursively, bottom-up.
+    """
+    try:
+        children = Process(pid).children()
+        if children:
+            for child in children:
+                abort_process(child.pid)
+
+        os.kill(pid, signal.SIGABRT)
+    except Exception:
+        pass
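
A quick way to exercise `abort_process` is to spawn a short-lived child and abort its tree; a minimal sketch, assuming the function is importable as `llamafactory.webui.common.abort_process`:

```python
import subprocess
import sys
import time

from llamafactory.webui.common import abort_process  # assumed import path

# Start a child process (a sleeping interpreter stands in for a training run).
proc = subprocess.Popen([sys.executable, "-c", "import time; time.sleep(60)"])
time.sleep(1)  # give the child a moment to start

abort_process(proc.pid)  # descendants are killed first, then the child itself
proc.wait()
print("return code:", proc.returncode)  # non-zero: the process was aborted
```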

def get_save_dir(*paths: str) -> os.PathLike:
    r"""
    Gets the path to saved model checkpoints.
@@ -61,19 +70,19 @@ def get_save_dir(*paths: str) -> os.PathLike:
    return os.path.join(DEFAULT_SAVE_DIR, *paths)


-def get_config_path() -> os.PathLike:
+def _get_config_path() -> os.PathLike:
    r"""
    Gets the path to the user config.
    """
    return os.path.join(DEFAULT_CACHE_DIR, USER_CONFIG)


-def load_config() -> Dict[str, Any]:
+def load_config() -> Dict[str, Union[str, Dict[str, Any]]]:
    r"""
    Loads the user config if it exists.
    """
    try:
-        with open(get_config_path(), encoding="utf-8") as f:
+        with open(_get_config_path(), encoding="utf-8") as f:
            return safe_load(f)
    except Exception:
        return {"lang": None, "last_model": None, "path_dict": {}, "cache_dir": None}
@@ -92,7 +101,7 @@ def save_config(lang: str, model_name: Optional[str] = None, model_path: Optiona
    if model_name and model_path:
        user_config["path_dict"][model_name] = model_path

-    with open(get_config_path(), "w", encoding="utf-8") as f:
+    with open(_get_config_path(), "w", encoding="utf-8") as f:
        safe_dump(user_config, f)
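
The round trip through `load_config` and `save_config` is easiest to see end to end; a small sketch, assuming the module is importable and run from a scratch directory (the config lands in `cache/user_config.yaml`; the model name and path are made up):

```python
from llamafactory.webui.common import load_config, save_config  # assumed import path

print(load_config())  # no file yet -> {'lang': None, 'last_model': None, 'path_dict': {}, 'cache_dir': None}

save_config(lang="en", model_name="Demo-7B", model_path="org/Demo-7B")  # hypothetical model
print(load_config()["path_dict"])  # {'Demo-7B': 'org/Demo-7B'}
```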
@@ -120,20 +129,9 @@ def get_model_path(model_name: str) -> str:
    return model_path


-def get_model_info(model_name: str) -> Tuple[str, str]:
-    r"""
-    Gets the necessary information of this model.
-
-    Returns:
-        model_path (str)
-        template (str)
-    """
-    return get_model_path(model_name), get_template(model_name)


def get_template(model_name: str) -> str:
    r"""
-    Gets the template name if the model is a chat model.
+    Gets the template name if the model is a chat/distill/instruct model.
    """
    return DEFAULT_TEMPLATE.get(model_name, "default")
@@ -145,24 +143,11 @@ def get_visual(model_name: str) -> bool:
    return model_name in VISION_MODELS


-def list_checkpoints(model_name: str, finetuning_type: str) -> "gr.Dropdown":
+def get_time() -> str:
    r"""
-    Lists all available checkpoints.
+    Gets current date and time.
    """
-    checkpoints = []
-    if model_name:
-        save_dir = get_save_dir(model_name, finetuning_type)
-        if save_dir and os.path.isdir(save_dir):
-            for checkpoint in os.listdir(save_dir):
-                if os.path.isdir(os.path.join(save_dir, checkpoint)) and any(
-                    os.path.isfile(os.path.join(save_dir, checkpoint, name)) for name in CHECKPOINT_NAMES
-                ):
-                    checkpoints.append(checkpoint)
-
-    if finetuning_type in PEFT_METHODS:
-        return gr.Dropdown(value=[], choices=checkpoints, multiselect=True)
-    else:
-        return gr.Dropdown(value=None, choices=checkpoints, multiselect=False)
+    return datetime.now().strftime(r"%Y-%m-%d-%H-%M-%S")
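
The timestamp is filesystem-safe (dashes, no colons), which suits its likely use for naming default output directories; for example:

```python
print(get_time())  # e.g. '2025-01-31-00-09-21'
```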

def load_dataset_info(dataset_dir: str) -> Dict[str, Dict[str, Any]]:
@@ -181,11 +166,135 @@ def load_dataset_info(dataset_dir: str) -> Dict[str, Dict[str, Any]]:
    return {}


-def list_datasets(dataset_dir: str = None, training_stage: str = list(TRAINING_STAGES.keys())[0]) -> "gr.Dropdown":
+def load_args(config_path: str) -> Optional[Dict[str, Any]]:
    r"""
-    Lists all available datasets in the dataset dir for the training stage.
+    Loads the training configuration from the config path.
    """
-    dataset_info = load_dataset_info(dataset_dir if dataset_dir is not None else DEFAULT_DATA_DIR)
-    ranking = TRAINING_STAGES[training_stage] in STAGES_USE_PAIR_DATA
-    datasets = [k for k, v in dataset_info.items() if v.get("ranking", False) == ranking]
-    return gr.Dropdown(choices=datasets)
+    try:
+        with open(config_path, encoding="utf-8") as f:
+            return safe_load(f)
+    except Exception:
+        return None


+def save_args(config_path: str, config_dict: Dict[str, Any]) -> None:
+    r"""
+    Saves the training configuration to the config path.
+    """
+    with open(config_path, "w", encoding="utf-8") as f:
+        safe_dump(config_dict, f)
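
`load_args` and `save_args` are a plain YAML round trip behind the webui's save/load-arguments buttons; a sketch with a hypothetical path:

```python
from llamafactory.webui.common import load_args, save_args  # assumed import path

args = {"stage": "sft", "learning_rate": 5.0e-5, "lora_rank": 8}
save_args("demo_args.yaml", args)
assert load_args("demo_args.yaml") == args  # round-trips losslessly
print(load_args("no_such_file.yaml"))       # None: read errors are swallowed
```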

+def _clean_cmd(args: Dict[str, Any]) -> Dict[str, Any]:
+    r"""
+    Removes args whose value is None, False, or an empty string.
+    """
+    no_skip_keys = ["packing"]
+    return {k: v for k, v in args.items() if (k in no_skip_keys) or (v is not None and v is not False and v != "")}
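
Note the `packing` exemption: an explicit `packing: False` survives the cleanup, presumably so that a stage-dependent default cannot re-enable packing when the flag is omitted. A sketch of the filtering with hypothetical arguments:

```python
args = {
    "model_name_or_path": "org/Demo-7B",  # kept: non-empty string
    "packing": False,                     # kept: exempted via no_skip_keys
    "do_eval": False,                     # dropped: False
    "adapter_name_or_path": "",           # dropped: empty string
    "quantization_bit": None,             # dropped: None
}
print(_clean_cmd(args))  # {'model_name_or_path': 'org/Demo-7B', 'packing': False}
```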

+def gen_cmd(args: Dict[str, Any]) -> str:
+    r"""
+    Generates a CLI command string for previewing.
+    """
+    cmd_lines = ["llamafactory-cli train "]
+    for k, v in _clean_cmd(args).items():
+        if isinstance(v, dict):
+            cmd_lines.append(f" --{k} {json.dumps(v, ensure_ascii=False)} ")
+        elif isinstance(v, list):
+            cmd_lines.append(f" --{k} {' '.join(map(str, v))} ")
+        else:
+            cmd_lines.append(f" --{k} {str(v)} ")
+
+    if os.name == "nt":
+        cmd_text = "`\n".join(cmd_lines)
+    else:
+        cmd_text = "\\\n".join(cmd_lines)
+
+    cmd_text = f"```bash\n{cmd_text}\n```"
+    return cmd_text
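
On a POSIX shell the preview joins the arguments with backslash continuations (a backtick on Windows) inside a fenced `bash` block; roughly:

```python
print(gen_cmd({"stage": "sft", "model_name_or_path": "org/Demo-7B", "do_train": True}))
# ```bash
# llamafactory-cli train \
#  --stage sft \
#  --model_name_or_path org/Demo-7B \
#  --do_train True
# ```
```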

+def save_cmd(args: Dict[str, Any]) -> str:
+    r"""
+    Saves the training arguments used to launch training and returns the file path.
+    """
+    output_dir = args["output_dir"]
+    os.makedirs(output_dir, exist_ok=True)
+
+    with open(os.path.join(output_dir, TRAINING_ARGS), "w", encoding="utf-8") as f:
+        safe_dump(_clean_cmd(args), f)
+
+    return os.path.join(output_dir, TRAINING_ARGS)


+def load_eval_results(path: os.PathLike) -> str:
+    r"""
+    Gets scores after evaluation.
+    """
+    with open(path, encoding="utf-8") as f:
+        result = json.dumps(json.load(f), indent=4)
+
+    return f"```json\n{result}\n```\n"

+def create_ds_config() -> None:
+    r"""
+    Creates DeepSpeed config files under the cache directory.
+    """
+    os.makedirs(DEFAULT_CACHE_DIR, exist_ok=True)
+    ds_config = {
+        "train_batch_size": "auto",
+        "train_micro_batch_size_per_gpu": "auto",
+        "gradient_accumulation_steps": "auto",
+        "gradient_clipping": "auto",
+        "zero_allow_untested_optimizer": True,
+        "fp16": {
+            "enabled": "auto",
+            "loss_scale": 0,
+            "loss_scale_window": 1000,
+            "initial_scale_power": 16,
+            "hysteresis": 2,
+            "min_loss_scale": 1,
+        },
+        "bf16": {"enabled": "auto"},
+    }
+    offload_config = {
+        "device": "cpu",
+        "pin_memory": True,
+    }
+    ds_config["zero_optimization"] = {
+        "stage": 2,
+        "allgather_partitions": True,
+        "allgather_bucket_size": 5e8,
+        "overlap_comm": True,
+        "reduce_scatter": True,
+        "reduce_bucket_size": 5e8,
+        "contiguous_gradients": True,
+        "round_robin_gradients": True,
+    }
+    with open(os.path.join(DEFAULT_CACHE_DIR, "ds_z2_config.json"), "w", encoding="utf-8") as f:
+        json.dump(ds_config, f, indent=2)
+
+    ds_config["zero_optimization"]["offload_optimizer"] = offload_config
+    with open(os.path.join(DEFAULT_CACHE_DIR, "ds_z2_offload_config.json"), "w", encoding="utf-8") as f:
+        json.dump(ds_config, f, indent=2)
+
+    ds_config["zero_optimization"] = {
+        "stage": 3,
+        "overlap_comm": True,
+        "contiguous_gradients": True,
+        "sub_group_size": 1e9,
+        "reduce_bucket_size": "auto",
+        "stage3_prefetch_bucket_size": "auto",
+        "stage3_param_persistence_threshold": "auto",
+        "stage3_max_live_parameters": 1e9,
+        "stage3_max_reuse_distance": 1e9,
+        "stage3_gather_16bit_weights_on_model_save": True,
+    }
+    with open(os.path.join(DEFAULT_CACHE_DIR, "ds_z3_config.json"), "w", encoding="utf-8") as f:
+        json.dump(ds_config, f, indent=2)
+
+    ds_config["zero_optimization"]["offload_optimizer"] = offload_config
+    ds_config["zero_optimization"]["offload_param"] = offload_config
+    with open(os.path.join(DEFAULT_CACHE_DIR, "ds_z3_offload_config.json"), "w", encoding="utf-8") as f:
+        json.dump(ds_config, f, indent=2)
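
A quick check of what `create_ds_config()` leaves behind: four presets, ZeRO-2 and ZeRO-3 each with and without CPU offload. A sketch, run from the directory that holds `cache/`:

```python
import json
import os

create_ds_config()
for name in ("ds_z2_config.json", "ds_z2_offload_config.json",
             "ds_z3_config.json", "ds_z3_offload_config.json"):
    with open(os.path.join("cache", name), encoding="utf-8") as f:
        zero = json.load(f)["zero_optimization"]
    print(f"{name}: stage {zero['stage']}, offload={'offload_optimizer' in zero}")
# ds_z2_config.json: stage 2, offload=False
# ds_z2_offload_config.json: stage 2, offload=True
# ds_z3_config.json: stage 3, offload=False
# ds_z3_offload_config.json: stage 3, offload=True
```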