use fp16 model, add logcallback

Former-commit-id: bea275d51338b49ce855eec0178e759607265e3d
This commit is contained in:
hiyouga
2023-05-28 21:30:28 +08:00
parent 17024ebc1a
commit 1fc551e1be
7 changed files with 112 additions and 10 deletions

View File

@@ -17,6 +17,7 @@ from utils import (
preprocess_data,
DataCollatorForLLaMA,
PPOTrainerForLLaMA,
LogCallback,
plot_loss
)
@@ -54,6 +55,7 @@ def main():
ppo_trainer = PPOTrainerForLLaMA(
training_args=training_args,
finetuning_args=finetuning_args,
callbacks=[LogCallback()],
config=ppo_config,
model=model,
ref_model=None,