Integrate RLHF (reward modeling, PPO, DPO) into the web UI

Former-commit-id: 137fd146b90f89a1164b56e6d507b30b1f5c2437
This commit is contained in:
hiyouga
2023-08-14 10:48:47 +08:00
parent 4933ab5956
commit 688e8601ab
11 changed files with 128 additions and 32 deletions

View File

@@ -91,6 +91,9 @@ class Runner:
lora_dropout: float,
lora_target: str,
resume_lora_training: bool,
rlhf_method: str,
dpo_beta: float,
reward_model: str,
output_dir: str
) -> Tuple[str, str, List[str], str, Dict[str, Any]]:
if checkpoints:
@@ -109,7 +112,7 @@ class Runner:
overwrite_cache=True,
checkpoint_dir=checkpoint_dir,
finetuning_type=finetuning_type,
quantization_bit=int(quantization_bit) if quantization_bit else None,
quantization_bit=int(quantization_bit) if quantization_bit != "None" else None,
template=template,
source_prefix=source_prefix,
dataset_dir=dataset_dir,
@@ -134,6 +137,21 @@ class Runner:
output_dir=output_dir
)
args[compute_type] = True
if rlhf_method == "Reward Modeling":
args["stage"] = "rm"
args["resume_lora_training"] = False
elif rlhf_method == "PPO":
args["stage"] = "ppo"
args["resume_lora_training"] = False
args["reward_model"] = reward_model
args["padding_side"] = "left"
val_size = 0
elif rlhf_method == "DPO":
args["stage"] = "dpo"
args["resume_lora_training"] = False
args["dpo_beta"] = dpo_beta
if val_size > 1e-6:
args["val_size"] = val_size
args["evaluation_strategy"] = "steps"
@@ -176,7 +194,7 @@ class Runner:
predict_with_generate=True,
checkpoint_dir=checkpoint_dir,
finetuning_type=finetuning_type,
quantization_bit=int(quantization_bit) if quantization_bit else None,
quantization_bit=int(quantization_bit) if quantization_bit != "None" else None,
template=template,
source_prefix=source_prefix,
dataset_dir=dataset_dir,