fix badam configs

Former-commit-id: 8a4e6a4c65a9a42e6501b0d3ce81d6220c287454
This commit is contained in:
hiyouga
2024-05-02 02:47:04 +08:00
parent cd4dad846b
commit dd0b85580e
5 changed files with 44 additions and 69 deletions

View File

@@ -147,11 +147,11 @@ class Runner:
shift_attn=get("train.shift_attn"),
report_to="all" if get("train.report_to") else "none",
use_galore=get("train.use_galore"),
use_badam=get("train.use_badam"),
output_dir=get_save_dir(get("top.model_name"), get("top.finetuning_type"), get("train.output_dir")),
fp16=(get("train.compute_type") == "fp16"),
bf16=(get("train.compute_type") == "bf16"),
pure_bf16=(get("train.compute_type") == "pure_bf16"),
use_badam=get("train.use_badam"),
)
args["disable_tqdm"] = True
@@ -201,11 +201,9 @@ class Runner:
if args["use_badam"]:
args["badam_mode"] = get("train.badam_mode")
args["badam_switch_block_every"] = get("train.badam_switch_block_every")
args["badam_switch_mode"] = get("train.badam_switch_mode")
args["badam_switch_interval"] = get("train.badam_switch_interval")
args["badam_update_ratio"] = get("train.badam_update_ratio")
args["badam_mask_mode"] = get("train.badam_mask_mode")
args["badam_verbose"] = get("train.badam_verbose")
return args