refactor export, fix #1190

Former-commit-id: 30e60e37023a7c4a2db033ffec0542efa3d5cdfb
This commit is contained in:
hiyouga
2023-10-15 16:01:48 +08:00
parent 68330eab2a
commit c2e84d4558
9 changed files with 52 additions and 49 deletions

View File

@@ -12,7 +12,7 @@ def create_export_tab(engine: "Engine") -> Dict[str, "Component"]:
elem_dict = dict()
with gr.Row():
save_dir = gr.Textbox()
export_dir = gr.Textbox()
max_shard_size = gr.Slider(value=10, minimum=1, maximum=100)
export_btn = gr.Button()
@@ -28,13 +28,13 @@ def create_export_tab(engine: "Engine") -> Dict[str, "Component"]:
engine.manager.get_elem("top.finetuning_type"),
engine.manager.get_elem("top.template"),
max_shard_size,
save_dir
export_dir
],
[info_box]
)
elem_dict.update(dict(
save_dir=save_dir,
export_dir=export_dir,
max_shard_size=max_shard_size,
export_btn=export_btn,
info_box=info_box

View File

@@ -531,7 +531,7 @@ LOCALES = {
"label": "温度系数"
}
},
"save_dir": {
"export_dir": {
"en": {
"label": "Export dir",
"info": "Directory to save exported model."
@@ -587,7 +587,7 @@ ALERTS = {
"en": "Please select a checkpoint.",
"zh": "请选择断点。"
},
"err_no_save_dir": {
"err_no_export_dir": {
"en": "Please provide export dir.",
"zh": "请填写导出目录"
},

View File

@@ -124,7 +124,7 @@ def save_model(
finetuning_type: str,
template: str,
max_shard_size: int,
save_dir: str
export_dir: str
) -> Generator[str, None, None]:
if not model_name:
yield ALERTS["err_no_model"][lang]
@@ -138,8 +138,8 @@ def save_model(
yield ALERTS["err_no_checkpoint"][lang]
return
if not save_dir:
yield ALERTS["err_no_save_dir"][lang]
if not export_dir:
yield ALERTS["err_no_export_dir"][lang]
return
args = dict(
@@ -147,7 +147,7 @@ def save_model(
checkpoint_dir=",".join([get_save_dir(model_name, finetuning_type, ckpt) for ckpt in checkpoints]),
finetuning_type=finetuning_type,
template=template,
output_dir=save_dir
export_dir=export_dir
)
yield ALERTS["info_exporting"][lang]