fix rlhf callback

Former-commit-id: f5485452d660caef56474cb7dc37abbe4f34599e
hiyouga
2023-11-16 03:26:19 +08:00
parent e017266b98
commit de3a84ac59
4 changed files with 19 additions and 12 deletions


@@ -12,6 +12,7 @@ from llmtuner.extras.logging import get_logger
 if TYPE_CHECKING:
     from transformers import TrainingArguments, TrainerState, TrainerControl
+    from trl import AutoModelForCausalLMWithValueHead


 logger = get_logger(__name__)
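Note on the hunk above: the new trl import sits under TYPE_CHECKING, so it is evaluated only by static type checkers, which is why the annotations added below are string literals. A minimal standalone sketch of the pattern (the function name save_model is illustrative, not from this commit):

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Evaluated only during static type checking; not imported at runtime.
    from trl import AutoModelForCausalLMWithValueHead

def save_model(model: "AutoModelForCausalLMWithValueHead") -> None:
    # The quoted annotation is resolved lazily, so this module still loads
    # in environments where trl is absent.
    ...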
@@ -25,16 +26,22 @@ class SavePeftModelCallback(TrainerCallback):
         """
         if args.should_save:
             output_dir = os.path.join(args.output_dir, "{}-{}".format(PREFIX_CHECKPOINT_DIR, state.global_step))
-            model = kwargs.pop("model")
+            model: "AutoModelForCausalLMWithValueHead" = kwargs.pop("model")
+            model.pretrained_model.config.save_pretrained(output_dir)
+            if model.pretrained_model.can_generate():
+                model.pretrained_model.generation_config.save_pretrained(output_dir)
             if getattr(model, "is_peft_model", False):
-                getattr(model, "pretrained_model").save_pretrained(output_dir)
+                model.pretrained_model.save_pretrained(output_dir)

     def on_train_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
         r"""
         Event called at the end of training.
         """
         if args.should_save:
-            model = kwargs.pop("model")
+            model: "AutoModelForCausalLMWithValueHead" = kwargs.pop("model")
+            model.pretrained_model.config.save_pretrained(args.output_dir)
+            if model.pretrained_model.can_generate():
+                model.pretrained_model.generation_config.save_pretrained(args.output_dir)
             if getattr(model, "is_peft_model", False):
-                getattr(model, "pretrained_model").save_pretrained(args.output_dir)
+                model.pretrained_model.save_pretrained(args.output_dir)
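For context on why the callback reaches through model.pretrained_model: trl's AutoModelForCausalLMWithValueHead wraps a transformers model and adds a value head, so the config, generation config, and (for PEFT models) the adapter weights live on the wrapped model, not on the wrapper. A sketch of the saving logic in isolation, assuming a placeholder base model and output directory (neither is from this commit):

from trl import AutoModelForCausalLMWithValueHead

# Placeholder base model and output directory, for illustration only.
model = AutoModelForCausalLMWithValueHead.from_pretrained("gpt2")
output_dir = "dummy-checkpoint"

# Persistence is delegated to the wrapped transformers model:
model.pretrained_model.config.save_pretrained(output_dir)
if model.pretrained_model.can_generate():
    model.pretrained_model.generation_config.save_pretrained(output_dir)
if getattr(model, "is_peft_model", False):
    # For a PEFT-wrapped model, save_pretrained writes only the adapter weights.
    model.pretrained_model.save_pretrained(output_dir)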