Refactor PiSSA, improve LlamaBoard

Former-commit-id: 619556e46c19718f702c97df5d570a2a4c5fb13a
This commit is contained in:
hiyouga
2024-06-28 01:04:24 +08:00
parent edc7498111
commit 46f0189e88
16 changed files with 219 additions and 216 deletions

View File

@@ -20,10 +20,9 @@ from typing import TYPE_CHECKING, List, Optional
from transformers import DataCollatorWithPadding
from ...data import get_dataset
from ...extras.callbacks import FixValueHeadModelCallback
from ...extras.misc import fix_valuehead_checkpoint
from ...extras.ploting import plot_loss
from ...model import load_model, load_tokenizer
from ..callbacks import FixValueHeadModelCallback, fix_valuehead_checkpoint
from ..trainer_utils import create_ref_model, create_reward_model
from .trainer import CustomPPOTrainer
@@ -75,6 +74,7 @@ def run_ppo(
ppo_trainer.save_model()
if training_args.should_save:
fix_valuehead_checkpoint(model, training_args.output_dir, training_args.save_safetensors)
ppo_trainer.save_state() # must be called after save_model to have a folder
if ppo_trainer.is_world_process_zero() and finetuning_args.plot_loss:
plot_loss(training_args.output_dir, keys=["loss", "reward"])