support lora target auto find

Former-commit-id: bce9984733d88bf013847eed523d1c75fdf0995e
This commit is contained in:
hiyouga
2023-09-09 15:38:37 +08:00
parent 50e93392dd
commit 7143c551ab
11 changed files with 117 additions and 72 deletions

View File

@@ -16,8 +16,8 @@ USER_CONFIG = "user.config"
DATA_CONFIG = "dataset_info.json"
def get_save_dir(model_name: str) -> str:
return os.path.join(DEFAULT_SAVE_DIR, os.path.split(model_name)[-1])
def get_save_dir(*args) -> os.PathLike:
return os.path.join(DEFAULT_SAVE_DIR, *args)
def get_config_path() -> os.PathLike:
@@ -29,7 +29,7 @@ def load_config() -> Dict[str, Any]:
with open(get_config_path(), "r", encoding="utf-8") as f:
return json.load(f)
except:
return {"lang": "", "last_model": "", "path_dict": {}}
return {"lang": None, "last_model": None, "path_dict": {}, "cache_dir": None}
def save_config(lang: str, model_name: str, model_path: str) -> None:
@@ -56,7 +56,7 @@ def get_template(model_name: str) -> str:
def list_checkpoint(model_name: str, finetuning_type: str) -> Dict[str, Any]:
checkpoints = []
save_dir = os.path.join(get_save_dir(model_name), finetuning_type)
save_dir = get_save_dir(model_name, finetuning_type)
if save_dir and os.path.isdir(save_dir):
for checkpoint in os.listdir(save_dir):
if (

View File

@@ -16,7 +16,7 @@ def create_infer_tab(top_elems: Dict[str, "Component"]) -> Dict[str, "Component"
info_box = gr.Textbox(show_label=False, interactive=False)
chat_model = WebChatModel()
chat_model = WebChatModel(lazy_init=True)
chat_box, chatbot, history, chat_elems = create_chat_box(chat_model)
load_btn.click(

View File

@@ -12,7 +12,7 @@ from llmtuner.extras.constants import DEFAULT_MODULE, TRAINING_STAGES
from llmtuner.extras.logging import LoggerHandler
from llmtuner.extras.misc import torch_gc
from llmtuner.tuner import run_exp
from llmtuner.webui.common import get_model_path, get_save_dir
from llmtuner.webui.common import get_model_path, get_save_dir, load_config
from llmtuner.webui.locales import ALERTS
from llmtuner.webui.utils import gen_cmd, get_eval_results, update_process_bar
@@ -97,21 +97,25 @@ class Runner:
) -> Tuple[str, str, List[str], str, Dict[str, Any]]:
if checkpoints:
checkpoint_dir = ",".join(
[os.path.join(get_save_dir(model_name), finetuning_type, ckpt) for ckpt in checkpoints]
[get_save_dir(model_name, finetuning_type, ckpt) for ckpt in checkpoints]
)
else:
checkpoint_dir = None
output_dir = os.path.join(get_save_dir(model_name), finetuning_type, output_dir)
output_dir = get_save_dir(model_name, finetuning_type, output_dir)
user_config = load_config()
cache_dir = user_config.get("cache_dir", None)
args = dict(
stage=TRAINING_STAGES[training_stage],
model_name_or_path=get_model_path(model_name),
do_train=True,
overwrite_cache=True,
overwrite_cache=False,
cache_dir=cache_dir,
checkpoint_dir=checkpoint_dir,
finetuning_type=finetuning_type,
quantization_bit=int(quantization_bit) if quantization_bit and quantization_bit != "None" else None,
quantization_bit=int(quantization_bit) if quantization_bit in ["8", "4"] else None,
template=template,
system_prompt=system_prompt,
dataset_dir=dataset_dir,
@@ -172,22 +176,26 @@ class Runner:
) -> Tuple[str, str, List[str], str, Dict[str, Any]]:
if checkpoints:
checkpoint_dir = ",".join(
[os.path.join(get_save_dir(model_name), finetuning_type, checkpoint) for checkpoint in checkpoints]
[get_save_dir(model_name, finetuning_type, ckpt) for ckpt in checkpoints]
)
output_dir = os.path.join(get_save_dir(model_name), finetuning_type, "eval_" + "_".join(checkpoints))
output_dir = get_save_dir(model_name, finetuning_type, "eval_" + "_".join(checkpoints))
else:
checkpoint_dir = None
output_dir = os.path.join(get_save_dir(model_name), finetuning_type, "eval_base")
output_dir = get_save_dir(model_name, finetuning_type, "eval_base")
user_config = load_config()
cache_dir = user_config.get("cache_dir", None)
args = dict(
stage="sft",
model_name_or_path=get_model_path(model_name),
do_eval=True,
overwrite_cache=True,
overwrite_cache=False,
predict_with_generate=True,
cache_dir=cache_dir,
checkpoint_dir=checkpoint_dir,
finetuning_type=finetuning_type,
quantization_bit=int(quantization_bit) if quantization_bit and quantization_bit != "None" else None,
quantization_bit=int(quantization_bit) if quantization_bit in ["8", "4"] else None,
template=template,
system_prompt=system_prompt,
dataset_dir=dataset_dir,

View File

@@ -90,7 +90,7 @@ def get_eval_results(path: os.PathLike) -> str:
def gen_plot(base_model: str, finetuning_type: str, output_dir: str) -> matplotlib.figure.Figure:
log_file = os.path.join(get_save_dir(base_model), finetuning_type, output_dir, "trainer_log.jsonl")
log_file = get_save_dir(base_model, finetuning_type, output_dir, "trainer_log.jsonl")
if not os.path.isfile(log_file):
return None
@@ -139,7 +139,7 @@ def save_model(
return
checkpoint_dir = ",".join(
[os.path.join(get_save_dir(model_name), finetuning_type, checkpoint) for checkpoint in checkpoints]
[get_save_dir(model_name, finetuning_type, ckpt) for ckpt in checkpoints]
)
if not save_dir: