add kto to webui

Former-commit-id: 6c866f4dbd45e868860be8351d1a65c4e1a4e02b
This commit is contained in:
hiyouga
2024-05-20 21:20:25 +08:00
parent ab48653e63
commit e30975e9a2
3 changed files with 91 additions and 38 deletions

View File

@@ -145,11 +145,14 @@ class Runner:
plot_loss=True,
)
# freeze config
if args["finetuning_type"] == "freeze":
args["freeze_trainable_layers"] = get("train.freeze_trainable_layers")
args["freeze_trainable_modules"] = get("train.freeze_trainable_modules")
args["freeze_extra_modules"] = get("train.freeze_extra_modules") or None
elif args["finetuning_type"] == "lora":
# lora config
if args["finetuning_type"] == "lora":
args["lora_rank"] = get("train.lora_rank")
args["lora_alpha"] = get("train.lora_alpha")
args["lora_dropout"] = get("train.lora_dropout")
@@ -163,6 +166,7 @@ class Runner:
if args["use_llama_pro"]:
args["num_layer_trainable"] = get("train.num_layer_trainable")
# rlhf config
if args["stage"] == "ppo":
args["reward_model"] = ",".join(
[
@@ -171,31 +175,41 @@ class Runner:
]
)
args["reward_model_type"] = "lora" if args["finetuning_type"] == "lora" else "full"
args["ppo_score_norm"] = get("train.ppo_score_norm")
args["ppo_whiten_rewards"] = get("train.ppo_whiten_rewards")
args["top_k"] = 0
args["top_p"] = 0.9
elif args["stage"] == "dpo":
args["dpo_beta"] = get("train.dpo_beta")
args["dpo_ftx"] = get("train.dpo_ftx")
args["dpo_beta"] = get("train.pref_beta")
args["dpo_ftx"] = get("train.pref_ftx")
args["dpo_loss"] = get("train.pref_loss")
elif args["stage"] == "kto":
args["kto_beta"] = get("train.pref_beta")
args["kto_ftx"] = get("train.pref_ftx")
elif args["stage"] == "orpo":
args["orpo_beta"] = get("train.orpo_beta")
if get("train.val_size") > 1e-6 and args["stage"] != "ppo":
args["val_size"] = get("train.val_size")
args["evaluation_strategy"] = "steps"
args["eval_steps"] = args["save_steps"]
args["per_device_eval_batch_size"] = args["per_device_train_batch_size"]
args["load_best_model_at_end"] = args["stage"] not in ["rm", "ppo"]
args["orpo_beta"] = get("train.pref_beta")
# galore config
if args["use_galore"]:
args["galore_rank"] = get("train.galore_rank")
args["galore_update_interval"] = get("train.galore_update_interval")
args["galore_scale"] = get("train.galore_scale")
args["galore_target"] = get("train.galore_target")
# badam config
if args["use_badam"]:
args["badam_mode"] = get("train.badam_mode")
args["badam_switch_mode"] = get("train.badam_switch_mode")
args["badam_switch_interval"] = get("train.badam_switch_interval")
args["badam_update_ratio"] = get("train.badam_update_ratio")
# eval config
if get("train.val_size") > 1e-6 and args["stage"] != "ppo":
args["val_size"] = get("train.val_size")
args["evaluation_strategy"] = "steps"
args["eval_steps"] = args["save_steps"]
args["per_device_eval_batch_size"] = args["per_device_train_batch_size"]
return args
def _parse_eval_args(self, data: Dict["Component", Any]) -> Dict[str, Any]: