web UI integrating RLHF

Former-commit-id: 137fd146b90f89a1164b56e6d507b30b1f5c2437
This commit is contained in:
hiyouga
2023-08-14 10:48:47 +08:00
parent 4933ab5956
commit 688e8601ab
11 changed files with 128 additions and 32 deletions

View File

@@ -37,7 +37,9 @@ def run_ppo(
batch_size=training_args.per_device_train_batch_size * training_args.gradient_accumulation_steps,
gradient_accumulation_steps=training_args.gradient_accumulation_steps,
ppo_epochs=1,
max_grad_norm=training_args.max_grad_norm
max_grad_norm=training_args.max_grad_norm,
seed=training_args.seed,
optimize_cuda_cache=True
)
optimizer = AdamW(filter(lambda p: p.requires_grad, model.parameters()), lr=training_args.learning_rate)