add ds config to webui
Former-commit-id: 66d72b263d36dc81de9f6152077663b613035977
This commit is contained in:
@@ -10,7 +10,7 @@ from transformers.trainer import TRAINING_ARGS_NAME
|
||||
from ..extras.constants import TRAINING_STAGES
|
||||
from ..extras.misc import is_gpu_or_npu_available, torch_gc
|
||||
from ..extras.packages import is_gradio_available
|
||||
from .common import get_module, get_save_dir, load_args, load_config, save_args
|
||||
from .common import DEFAULT_CACHE_DIR, get_module, get_save_dir, load_args, load_config, save_args
|
||||
from .locales import ALERTS
|
||||
from .utils import gen_cmd, get_eval_results, get_trainer_info, save_cmd
|
||||
|
||||
@@ -201,6 +201,12 @@ class Runner:
|
||||
args["eval_steps"] = args["save_steps"]
|
||||
args["per_device_eval_batch_size"] = args["per_device_train_batch_size"]
|
||||
|
||||
# ds config
|
||||
if get("train.ds_stage") != "none":
|
||||
ds_stage = get("train.ds_stage")
|
||||
ds_offload = "offload_" if get("train.ds_offload") else ""
|
||||
args["deepspeed"] = os.path.join(DEFAULT_CACHE_DIR, "ds_z{}_{}config.json".format(ds_stage, ds_offload))
|
||||
|
||||
return args
|
||||
|
||||
def _parse_eval_args(self, data: Dict["Component", Any]) -> Dict[str, Any]:
|
||||
|
||||
Reference in New Issue
Block a user