fix config, #1191
Former-commit-id: 5dbc9b355e85b203cb43ff72589374f0e04be391
This commit is contained in:
@@ -14,7 +14,7 @@ from llmtuner.extras.constants import TRAINING_STAGES
|
||||
from llmtuner.extras.logging import LoggerHandler
|
||||
from llmtuner.extras.misc import torch_gc
|
||||
from llmtuner.tuner import run_exp
|
||||
from llmtuner.webui.common import get_module, get_save_dir
|
||||
from llmtuner.webui.common import get_module, get_save_dir, load_config
|
||||
from llmtuner.webui.locales import ALERTS
|
||||
from llmtuner.webui.utils import gen_cmd, get_eval_results, update_process_bar
|
||||
|
||||
@@ -74,6 +74,7 @@ class Runner:
|
||||
|
||||
def _parse_train_args(self, data: Dict[Component, Any]) -> Tuple[str, str, str, List[str], str, Dict[str, Any]]:
|
||||
get = lambda name: data[self.manager.get_elem(name)]
|
||||
user_config = load_config()
|
||||
|
||||
if get("top.checkpoints"):
|
||||
checkpoint_dir = ",".join([
|
||||
@@ -89,7 +90,7 @@ class Runner:
|
||||
model_name_or_path=get("top.model_path"),
|
||||
do_train=True,
|
||||
overwrite_cache=False,
|
||||
cache_dir=get("top.config").get("cache_dir", None),
|
||||
cache_dir=user_config.get("cache_dir", None),
|
||||
checkpoint_dir=checkpoint_dir,
|
||||
finetuning_type=get("top.finetuning_type"),
|
||||
quantization_bit=int(get("top.quantization_bit")) if get("top.quantization_bit") in ["8", "4"] else None,
|
||||
@@ -142,6 +143,7 @@ class Runner:
|
||||
|
||||
def _parse_eval_args(self, data: Dict[Component, Any]) -> Tuple[str, str, str, List[str], str, Dict[str, Any]]:
|
||||
get = lambda name: data[self.manager.get_elem(name)]
|
||||
user_config = load_config()
|
||||
|
||||
if get("top.checkpoints"):
|
||||
checkpoint_dir = ",".join([
|
||||
@@ -160,7 +162,7 @@ class Runner:
|
||||
do_eval=True,
|
||||
overwrite_cache=False,
|
||||
predict_with_generate=True,
|
||||
cache_dir=get("top.config").get("cache_dir", None),
|
||||
cache_dir=user_config.get("cache_dir", None),
|
||||
checkpoint_dir=checkpoint_dir,
|
||||
finetuning_type=get("top.finetuning_type"),
|
||||
quantization_bit=int(get("top.quantization_bit")) if get("top.quantization_bit") in ["8", "4"] else None,
|
||||
|
||||
Reference in New Issue
Block a user