support RM metrics, add generating Args

Former-commit-id: c461c6190bc124e98dde7f3cf96a59ce40b26fb0
This commit is contained in:
hiyouga
2023-06-12 15:48:48 +08:00
parent 4c5cad9722
commit 4724ae3492
16 changed files with 177 additions and 163 deletions

View File

@@ -6,18 +6,17 @@
import math
from torch.optim import AdamW
from transformers.optimization import get_scheduler
from trl import PPOConfig
from utils import (
-    prepare_args,
-    prepare_data,
-    load_pretrained,
-    preprocess_data,
     DynamicDataCollatorWithPadding,
     PPOPeftTrainer,
     LogCallback,
+    load_pretrained,
+    prepare_args,
+    prepare_data,
+    preprocess_data,
plot_loss
)
@@ -29,7 +28,7 @@ def main():
dataset = prepare_data(model_args, data_args)
model, tokenizer = load_pretrained(model_args, finetuning_args, training_args.do_train, stage="ppo")
dataset = preprocess_data(dataset, tokenizer, data_args, training_args, stage="ppo")
-    data_collator = DynamicDataCollatorWithPadding(tokenizer, model.pretrained_model)
+    data_collator = DynamicDataCollatorWithPadding(tokenizer)
ppo_config = PPOConfig(
model_name=model_args.model_name_or_path,