Integrate RLHF into the web UI
Former-commit-id: 137fd146b90f89a1164b56e6d507b30b1f5c2437
@@ -91,6 +91,9 @@ class Runner:
         lora_dropout: float,
         lora_target: str,
         resume_lora_training: bool,
+        rlhf_method: str,
+        dpo_beta: float,
+        reward_model: str,
         output_dir: str
     ) -> Tuple[str, str, List[str], str, Dict[str, Any]]:
         if checkpoints:
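Note on this hunk: the three new parameters thread the web UI's RLHF controls into the training-argument builder. As a point of reference, here is a hypothetical Gradio sketch of the components that could feed them (the component names, defaults, and slider range are assumptions, not part of this diff; only the dropdown labels are taken from the Runner branch further down):

import gradio as gr

# Hypothetical sketch only: names and defaults are assumptions.
with gr.Blocks() as demo:
    rlhf_method = gr.Dropdown(
        choices=["Reward Modeling", "PPO", "DPO"],  # labels matched by the Runner branch below
        value="Reward Modeling",
        label="RLHF method",
    )
    dpo_beta = gr.Slider(minimum=0.0, maximum=1.0, value=0.1, label="DPO beta")
    reward_model = gr.Textbox(label="Reward model checkpoint")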
@@ -109,7 +112,7 @@ class Runner:
             overwrite_cache=True,
             checkpoint_dir=checkpoint_dir,
             finetuning_type=finetuning_type,
-            quantization_bit=int(quantization_bit) if quantization_bit else None,
+            quantization_bit=int(quantization_bit) if quantization_bit != "None" else None,
             template=template,
             source_prefix=source_prefix,
             dataset_dir=dataset_dir,
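Note on the quantization_bit change: the comparison against the literal "None" suggests the web UI delivers the dropdown selection as a string, so the unselected state arrives as the truthy string "None" and the old check fell through into int("None"). A minimal standalone repro of the difference (not the Runner code itself):

def parse_quant_old(quantization_bit: str):
    # "None" is a truthy string, so this branch calls int("None") -> ValueError.
    return int(quantization_bit) if quantization_bit else None

def parse_quant_new(quantization_bit: str):
    # Compare against the literal string the dropdown sends instead.
    return int(quantization_bit) if quantization_bit != "None" else None

print(parse_quant_new("None"))  # -> None (quantization disabled)
print(parse_quant_new("4"))     # -> 4
# parse_quant_old("None") raises ValueError: invalid literal for int() with base 10: 'None'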
@@ -134,6 +137,21 @@ class Runner:
             output_dir=output_dir
         )
         args[compute_type] = True
+
+        if rlhf_method == "Reward Modeling":
+            args["stage"] = "rm"
+            args["resume_lora_training"] = False
+        elif rlhf_method == "PPO":
+            args["stage"] = "ppo"
+            args["resume_lora_training"] = False
+            args["reward_model"] = reward_model
+            args["padding_side"] = "left"
+            val_size = 0
+        elif rlhf_method == "DPO":
+            args["stage"] = "dpo"
+            args["resume_lora_training"] = False
+            args["dpo_beta"] = dpo_beta
+
         if val_size > 1e-6:
             args["val_size"] = val_size
             args["evaluation_strategy"] = "steps"
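Note on the dispatch: every RLHF stage restarts LoRA training from scratch (resume_lora_training = False); PPO additionally wires in the reward model, pads on the left for its generation rollouts, and zeroes val_size so the evaluation setup guarded by `if val_size > 1e-6` is skipped. A condensed restatement of the mapping (the checkpoint path and beta are placeholder values, not from this diff):

def rlhf_args(rlhf_method: str) -> dict:
    # Condensed restatement of the branch above; placeholder values are marked.
    stage = {"Reward Modeling": "rm", "PPO": "ppo", "DPO": "dpo"}.get(rlhf_method)
    if stage is None:
        return {}  # no RLHF method selected: args are left untouched
    args = {"stage": stage, "resume_lora_training": False}
    if stage == "ppo":
        args["reward_model"] = "path/to/reward_model"  # placeholder checkpoint path
        args["padding_side"] = "left"  # decoder-only models generate with left padding
        # the hunk also sets val_size = 0 here, disabling the eval block below
    elif stage == "dpo":
        args["dpo_beta"] = 0.1  # placeholder; supplied via the new dpo_beta parameter
    return args

assert rlhf_args("PPO")["padding_side"] == "left"
assert "dpo_beta" in rlhf_args("DPO")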
@@ -176,7 +194,7 @@ class Runner:
             predict_with_generate=True,
             checkpoint_dir=checkpoint_dir,
             finetuning_type=finetuning_type,
-            quantization_bit=int(quantization_bit) if quantization_bit else None,
+            quantization_bit=int(quantization_bit) if quantization_bit != "None" else None,
             template=template,
             source_prefix=source_prefix,
             dataset_dir=dataset_dir,