Former-commit-id: 819cc1353599e5fa45658bc56dd0dbe4b258b197
This commit is contained in:
@@ -1,16 +1,19 @@
|
||||
import os
|
||||
import torch
|
||||
from typing import Dict, Optional
|
||||
from typing import TYPE_CHECKING, Dict, Optional
|
||||
|
||||
from transformers import Seq2SeqTrainer
|
||||
from transformers.trainer import TRAINING_ARGS_NAME
|
||||
from transformers.trainer import TRAINING_ARGS_NAME, WEIGHTS_NAME
|
||||
from transformers.modeling_utils import PreTrainedModel, unwrap_model
|
||||
from peft import PeftModel
|
||||
from trl import PreTrainedModelWrapper
|
||||
|
||||
from llmtuner.extras.constants import FINETUNING_ARGS_NAME, VALUE_HEAD_FILE_NAME
|
||||
from llmtuner.extras.logging import get_logger
|
||||
from llmtuner.extras.save_and_load import get_state_dict, load_trainable_params, load_valuehead_params
|
||||
from llmtuner.hparams import FinetuningArguments
|
||||
from llmtuner.extras.save_and_load import get_state_dict, load_trainable_params
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from llmtuner.hparams import FinetuningArguments
|
||||
|
||||
|
||||
logger = get_logger(__name__)
|
||||
@@ -21,7 +24,7 @@ class PeftTrainer(Seq2SeqTrainer):
|
||||
Inherits Seq2SeqTrainer to support parameter-efficient checkpoints.
|
||||
"""
|
||||
|
||||
def __init__(self, finetuning_args: FinetuningArguments, **kwargs):
|
||||
def __init__(self, finetuning_args: "FinetuningArguments", **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self.finetuning_args = finetuning_args
|
||||
self._remove_log()
|
||||
@@ -42,31 +45,35 @@ class PeftTrainer(Seq2SeqTrainer):
|
||||
output_dir = output_dir if output_dir is not None else self.args.output_dir
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
logger.info(f"Saving model checkpoint to {output_dir}")
|
||||
|
||||
model = unwrap_model(self.model)
|
||||
state_dict = state_dict or get_state_dict(model)
|
||||
|
||||
if hasattr(model, "pretrained_model"): # for models with valuehead (currently using LoRA only)
|
||||
backbone_model = getattr(model, "pretrained_model")
|
||||
torch.save(get_state_dict(getattr(model, "v_head")), os.path.join(output_dir, VALUE_HEAD_FILE_NAME))
|
||||
else:
|
||||
backbone_model = model
|
||||
if isinstance(model, PreTrainedModelWrapper):
|
||||
model_params, v_head_params = {}, {}
|
||||
for name in state_dict.keys():
|
||||
if name.startswith("pretrained_model."):
|
||||
model_params[name.replace("pretrained_model.", "")] = state_dict[name]
|
||||
elif name.startswith("v_head."):
|
||||
v_head_params[name.replace("v_head.", "")] = state_dict[name]
|
||||
|
||||
if isinstance(backbone_model, PeftModel): # LoRA tuning
|
||||
backbone_model.save_pretrained(output_dir, state_dict=get_state_dict(backbone_model))
|
||||
elif isinstance(backbone_model, PreTrainedModel): # freeze/full tuning
|
||||
backbone_model.config.use_cache = True
|
||||
backbone_model.save_pretrained(
|
||||
output_dir,
|
||||
state_dict=get_state_dict(backbone_model, trainable_only=(self.finetuning_args.finetuning_type != "full")),
|
||||
safe_serialization=self.args.save_safetensors
|
||||
)
|
||||
backbone_model.config.use_cache = False
|
||||
if self.tokenizer is not None:
|
||||
self.tokenizer.save_pretrained(output_dir)
|
||||
torch.save(v_head_params, os.path.join(output_dir, VALUE_HEAD_FILE_NAME))
|
||||
state_dict = model_params
|
||||
model = model.pretrained_model
|
||||
|
||||
if isinstance(model, (PeftModel, PreTrainedModel)):
|
||||
model.config.use_cache = True
|
||||
model.save_pretrained(output_dir, state_dict=state_dict, safe_serialization=self.args.save_safetensors)
|
||||
model.config.use_cache = False
|
||||
else:
|
||||
logger.warning("No model to save.")
|
||||
torch.save(state_dict, os.path.join(output_dir, WEIGHTS_NAME))
|
||||
|
||||
if self.tokenizer is not None:
|
||||
self.tokenizer.save_pretrained(output_dir)
|
||||
|
||||
with open(os.path.join(output_dir, TRAINING_ARGS_NAME), "w", encoding="utf-8") as f:
|
||||
f.write(self.args.to_json_string() + "\n")
|
||||
|
||||
self.finetuning_args.save_to_json(os.path.join(output_dir, FINETUNING_ARGS_NAME))
|
||||
|
||||
def _load_best_model(self):
|
||||
@@ -76,16 +83,15 @@ class PeftTrainer(Seq2SeqTrainer):
|
||||
Subclass and override to inject custom behavior. It should not be directly used by external scripts.
|
||||
"""
|
||||
logger.info(f"Loading best model from {self.state.best_model_checkpoint} (score: {self.state.best_metric}).")
|
||||
|
||||
model = unwrap_model(self.model)
|
||||
backbone_model = getattr(model, "pretrained_model") if hasattr(model, "pretrained_model") else model
|
||||
|
||||
if isinstance(backbone_model, PeftModel):
|
||||
backbone_model.load_adapter(self.state.best_model_checkpoint, backbone_model.active_adapter)
|
||||
if hasattr(model, "v_head") and load_valuehead_params(model, self.state.best_model_checkpoint):
|
||||
model.v_head.load_state_dict({
|
||||
"summary.weight": getattr(model, "reward_head_weight"),
|
||||
"summary.bias": getattr(model, "reward_head_bias")
|
||||
})
|
||||
if isinstance(model, PreTrainedModelWrapper):
|
||||
model.v_head.load_state_dict(torch.load(
|
||||
os.path.join(self.state.best_model_checkpoint, VALUE_HEAD_FILE_NAME), map_location="cpu"
|
||||
))
|
||||
model = model.pretrained_model
|
||||
|
||||
if isinstance(model, PeftModel):
|
||||
model.load_adapter(self.state.best_model_checkpoint, model.active_adapter)
|
||||
else: # freeze/full-tuning
|
||||
load_trainable_params(backbone_model, self.state.best_model_checkpoint)
|
||||
load_trainable_params(model, self.state.best_model_checkpoint)
|
||||
|
||||
Reference in New Issue
Block a user