add kto to webui
Former-commit-id: 6c866f4dbd45e868860be8351d1a65c4e1a4e02b
This commit is contained in:
@@ -145,11 +145,14 @@ class Runner:
|
||||
plot_loss=True,
|
||||
)
|
||||
|
||||
# freeze config
|
||||
if args["finetuning_type"] == "freeze":
|
||||
args["freeze_trainable_layers"] = get("train.freeze_trainable_layers")
|
||||
args["freeze_trainable_modules"] = get("train.freeze_trainable_modules")
|
||||
args["freeze_extra_modules"] = get("train.freeze_extra_modules") or None
|
||||
elif args["finetuning_type"] == "lora":
|
||||
|
||||
# lora config
|
||||
if args["finetuning_type"] == "lora":
|
||||
args["lora_rank"] = get("train.lora_rank")
|
||||
args["lora_alpha"] = get("train.lora_alpha")
|
||||
args["lora_dropout"] = get("train.lora_dropout")
|
||||
@@ -163,6 +166,7 @@ class Runner:
|
||||
if args["use_llama_pro"]:
|
||||
args["num_layer_trainable"] = get("train.num_layer_trainable")
|
||||
|
||||
# rlhf config
|
||||
if args["stage"] == "ppo":
|
||||
args["reward_model"] = ",".join(
|
||||
[
|
||||
@@ -171,31 +175,41 @@ class Runner:
|
||||
]
|
||||
)
|
||||
args["reward_model_type"] = "lora" if args["finetuning_type"] == "lora" else "full"
|
||||
args["ppo_score_norm"] = get("train.ppo_score_norm")
|
||||
args["ppo_whiten_rewards"] = get("train.ppo_whiten_rewards")
|
||||
args["top_k"] = 0
|
||||
args["top_p"] = 0.9
|
||||
elif args["stage"] == "dpo":
|
||||
args["dpo_beta"] = get("train.dpo_beta")
|
||||
args["dpo_ftx"] = get("train.dpo_ftx")
|
||||
args["dpo_beta"] = get("train.pref_beta")
|
||||
args["dpo_ftx"] = get("train.pref_ftx")
|
||||
args["dpo_loss"] = get("train.pref_loss")
|
||||
elif args["stage"] == "kto":
|
||||
args["kto_beta"] = get("train.pref_beta")
|
||||
args["kto_ftx"] = get("train.pref_ftx")
|
||||
elif args["stage"] == "orpo":
|
||||
args["orpo_beta"] = get("train.orpo_beta")
|
||||
|
||||
if get("train.val_size") > 1e-6 and args["stage"] != "ppo":
|
||||
args["val_size"] = get("train.val_size")
|
||||
args["evaluation_strategy"] = "steps"
|
||||
args["eval_steps"] = args["save_steps"]
|
||||
args["per_device_eval_batch_size"] = args["per_device_train_batch_size"]
|
||||
args["load_best_model_at_end"] = args["stage"] not in ["rm", "ppo"]
|
||||
args["orpo_beta"] = get("train.pref_beta")
|
||||
|
||||
# galore config
|
||||
if args["use_galore"]:
|
||||
args["galore_rank"] = get("train.galore_rank")
|
||||
args["galore_update_interval"] = get("train.galore_update_interval")
|
||||
args["galore_scale"] = get("train.galore_scale")
|
||||
args["galore_target"] = get("train.galore_target")
|
||||
|
||||
# badam config
|
||||
if args["use_badam"]:
|
||||
args["badam_mode"] = get("train.badam_mode")
|
||||
args["badam_switch_mode"] = get("train.badam_switch_mode")
|
||||
args["badam_switch_interval"] = get("train.badam_switch_interval")
|
||||
args["badam_update_ratio"] = get("train.badam_update_ratio")
|
||||
|
||||
# eval config
|
||||
if get("train.val_size") > 1e-6 and args["stage"] != "ppo":
|
||||
args["val_size"] = get("train.val_size")
|
||||
args["evaluation_strategy"] = "steps"
|
||||
args["eval_steps"] = args["save_steps"]
|
||||
args["per_device_eval_batch_size"] = args["per_device_train_batch_size"]
|
||||
|
||||
return args
|
||||
|
||||
def _parse_eval_args(self, data: Dict["Component", Any]) -> Dict[str, Any]:
|
||||
|
||||
Reference in New Issue
Block a user