refactor export, fix #1190
Former-commit-id: 30e60e37023a7c4a2db033ffec0542efa3d5cdfb
This commit is contained in:
@@ -12,7 +12,7 @@ def create_export_tab(engine: "Engine") -> Dict[str, "Component"]:
|
||||
elem_dict = dict()
|
||||
|
||||
with gr.Row():
|
||||
save_dir = gr.Textbox()
|
||||
export_dir = gr.Textbox()
|
||||
max_shard_size = gr.Slider(value=10, minimum=1, maximum=100)
|
||||
|
||||
export_btn = gr.Button()
|
||||
@@ -28,13 +28,13 @@ def create_export_tab(engine: "Engine") -> Dict[str, "Component"]:
|
||||
engine.manager.get_elem("top.finetuning_type"),
|
||||
engine.manager.get_elem("top.template"),
|
||||
max_shard_size,
|
||||
save_dir
|
||||
export_dir
|
||||
],
|
||||
[info_box]
|
||||
)
|
||||
|
||||
elem_dict.update(dict(
|
||||
save_dir=save_dir,
|
||||
export_dir=export_dir,
|
||||
max_shard_size=max_shard_size,
|
||||
export_btn=export_btn,
|
||||
info_box=info_box
|
||||
|
||||
@@ -531,7 +531,7 @@ LOCALES = {
|
||||
"label": "温度系数"
|
||||
}
|
||||
},
|
||||
"save_dir": {
|
||||
"export_dir": {
|
||||
"en": {
|
||||
"label": "Export dir",
|
||||
"info": "Directory to save exported model."
|
||||
@@ -587,7 +587,7 @@ ALERTS = {
|
||||
"en": "Please select a checkpoint.",
|
||||
"zh": "请选择断点。"
|
||||
},
|
||||
"err_no_save_dir": {
|
||||
"err_no_export_dir": {
|
||||
"en": "Please provide export dir.",
|
||||
"zh": "请填写导出目录"
|
||||
},
|
||||
|
||||
@@ -124,7 +124,7 @@ def save_model(
|
||||
finetuning_type: str,
|
||||
template: str,
|
||||
max_shard_size: int,
|
||||
save_dir: str
|
||||
export_dir: str
|
||||
) -> Generator[str, None, None]:
|
||||
if not model_name:
|
||||
yield ALERTS["err_no_model"][lang]
|
||||
@@ -138,8 +138,8 @@ def save_model(
|
||||
yield ALERTS["err_no_checkpoint"][lang]
|
||||
return
|
||||
|
||||
if not save_dir:
|
||||
yield ALERTS["err_no_save_dir"][lang]
|
||||
if not export_dir:
|
||||
yield ALERTS["err_no_export_dir"][lang]
|
||||
return
|
||||
|
||||
args = dict(
|
||||
@@ -147,7 +147,7 @@ def save_model(
|
||||
checkpoint_dir=",".join([get_save_dir(model_name, finetuning_type, ckpt) for ckpt in checkpoints]),
|
||||
finetuning_type=finetuning_type,
|
||||
template=template,
|
||||
output_dir=save_dir
|
||||
export_dir=export_dir
|
||||
)
|
||||
|
||||
yield ALERTS["info_exporting"][lang]
|
||||
|
||||
Reference in New Issue
Block a user