add resume args in webui

Former-commit-id: 1d86ad768b1f36e54b4c2a9f18f6ea5a7df04c90
This commit is contained in:
hiyouga
2024-06-08 00:22:16 +08:00
parent f45e81e186
commit 3f6b3eed98
7 changed files with 68 additions and 49 deletions

View File

@@ -8,10 +8,10 @@ import psutil
from transformers.trainer_utils import get_last_checkpoint
from yaml import safe_dump, safe_load
from ..extras.constants import PEFT_METHODS, RUNNING_LOG, TRAINER_CONFIG, TRAINER_LOG, TRAINING_STAGES
from ..extras.constants import PEFT_METHODS, RUNNING_LOG, TRAINER_LOG, TRAINING_ARGS, TRAINING_STAGES
from ..extras.packages import is_gradio_available, is_matplotlib_available
from ..extras.ploting import gen_loss_plot
from .common import DEFAULT_CACHE_DIR, DEFAULT_CONFIG_DIR, get_arg_save_path, get_save_dir
from .common import DEFAULT_CACHE_DIR, DEFAULT_CONFIG_DIR, get_save_dir
from .locales import ALERTS
@@ -93,10 +93,10 @@ def save_cmd(args: Dict[str, Any]) -> str:
output_dir = args["output_dir"]
os.makedirs(output_dir, exist_ok=True)
with open(os.path.join(output_dir, TRAINER_CONFIG), "w", encoding="utf-8") as f:
with open(os.path.join(output_dir, TRAINING_ARGS), "w", encoding="utf-8") as f:
safe_dump(clean_cmd(args), f)
return os.path.join(output_dir, TRAINER_CONFIG)
return os.path.join(output_dir, TRAINING_ARGS)
def get_eval_results(path: os.PathLike) -> str:
@@ -157,22 +157,19 @@ def load_args(config_path: str) -> Optional[Dict[str, Any]]:
Loads saved arguments.
"""
try:
with open(get_arg_save_path(config_path), "r", encoding="utf-8") as f:
with open(config_path, "r", encoding="utf-8") as f:
return safe_load(f)
except Exception:
return None
def save_args(config_path: str, config_dict: Dict[str, Any]) -> str:
def save_args(config_path: str, config_dict: Dict[str, Any]):
r"""
Saves arguments.
"""
os.makedirs(DEFAULT_CONFIG_DIR, exist_ok=True)
with open(get_arg_save_path(config_path), "w", encoding="utf-8") as f:
with open(config_path, "w", encoding="utf-8") as f:
safe_dump(config_dict, f)
return str(get_arg_save_path(config_path))
def list_config_paths(current_time: str) -> "gr.Dropdown":
r"""
@@ -181,13 +178,13 @@ def list_config_paths(current_time: str) -> "gr.Dropdown":
config_files = ["{}.yaml".format(current_time)]
if os.path.isdir(DEFAULT_CONFIG_DIR):
for file_name in os.listdir(DEFAULT_CONFIG_DIR):
if file_name.endswith(".yaml"):
if file_name.endswith(".yaml") and file_name not in config_files:
config_files.append(file_name)
return gr.Dropdown(choices=config_files)
def list_output_dirs(model_name: str, finetuning_type: str, current_time: str) -> "gr.Dropdown":
def list_output_dirs(model_name: Optional[str], finetuning_type: str, current_time: str) -> "gr.Dropdown":
r"""
Lists all the directories that can resume from.
"""
@@ -203,14 +200,6 @@ def list_output_dirs(model_name: str, finetuning_type: str, current_time: str) -
return gr.Dropdown(choices=output_dirs)
def check_output_dir(lang: str, model_name: str, finetuning_type: str, output_dir: str) -> None:
r"""
Check if output dir exists.
"""
if model_name and output_dir and os.path.isdir(get_save_dir(model_name, finetuning_type, output_dir)):
gr.Warning(ALERTS["warn_output_dir_exists"][lang])
def create_ds_config() -> None:
r"""
Creates deepspeed config.