refactor export, fix #1190
Former-commit-id: 30e60e37023a7c4a2db033ffec0542efa3d5cdfb
This commit is contained in:
@@ -124,7 +124,7 @@ def save_model(
|
||||
finetuning_type: str,
|
||||
template: str,
|
||||
max_shard_size: int,
|
||||
save_dir: str
|
||||
export_dir: str
|
||||
) -> Generator[str, None, None]:
|
||||
if not model_name:
|
||||
yield ALERTS["err_no_model"][lang]
|
||||
@@ -138,8 +138,8 @@ def save_model(
|
||||
yield ALERTS["err_no_checkpoint"][lang]
|
||||
return
|
||||
|
||||
if not save_dir:
|
||||
yield ALERTS["err_no_save_dir"][lang]
|
||||
if not export_dir:
|
||||
yield ALERTS["err_no_export_dir"][lang]
|
||||
return
|
||||
|
||||
args = dict(
|
||||
@@ -147,7 +147,7 @@ def save_model(
|
||||
checkpoint_dir=",".join([get_save_dir(model_name, finetuning_type, ckpt) for ckpt in checkpoints]),
|
||||
finetuning_type=finetuning_type,
|
||||
template=template,
|
||||
output_dir=save_dir
|
||||
export_dir=export_dir
|
||||
)
|
||||
|
||||
yield ALERTS["info_exporting"][lang]
|
||||
|
||||
Reference in New Issue
Block a user