@@ -1,17 +1,19 @@
 import os
 import json
 import time
-from typing import TYPE_CHECKING
+import torch
+from typing import TYPE_CHECKING, Dict
 from datetime import timedelta
 
 from transformers import PreTrainedModel, TrainerCallback
-from transformers.modeling_utils import custom_object_save, unwrap_model
+from transformers.utils import WEIGHTS_NAME, SAFE_WEIGHTS_NAME
 from transformers.trainer_utils import has_length, PREFIX_CHECKPOINT_DIR
 from peft import PeftModel
 
-from llmtuner.extras.constants import LOG_FILE_NAME
+from llmtuner.extras.constants import LOG_FILE_NAME, V_HEAD_WEIGHTS_NAME, V_HEAD_SAFE_WEIGHTS_NAME
 from llmtuner.extras.logging import get_logger
 
+
 if TYPE_CHECKING:
     from transformers import TrainingArguments, TrainerState, TrainerControl
     from trl import AutoModelForCausalLMWithValueHead
@@ -20,31 +22,66 @@ if TYPE_CHECKING:
 logger = get_logger(__name__)
 
 
-def _save_model_with_valuehead(
+def _fix_valuehead_checkpoint(
     model: "AutoModelForCausalLMWithValueHead",
     output_dir: str,
     safe_serialization: bool
 ) -> None:
-    if isinstance(model.pretrained_model, (PreTrainedModel, PeftModel)):
-        model.pretrained_model.config.save_pretrained(output_dir)
-        if model.pretrained_model.can_generate():
-            model.pretrained_model.generation_config.save_pretrained(output_dir)
-    if getattr(model, "is_peft_model", False):
-        model.pretrained_model.save_pretrained(output_dir, safe_serialization=safe_serialization)
-    elif getattr(model.pretrained_model, "_auto_class", None): # must not a peft model
-        custom_object_save(model.pretrained_model, output_dir, config=model.pretrained_model.config)
+    r"""
+    The model is already unwrapped.
+
+    There are three cases:
+    1. full tuning without ds_zero3: state_dict = {"model.layers.*": ..., "v_head.summary.*": ...}
+    2. lora tuning without ds_zero3: state_dict = {"v_head.summary.*": ...}
+    3. under deepspeed zero3: state_dict = {"pretrained_model.model.layers.*": ..., "v_head.summary.*": ...}
+
+    We assume `stage3_gather_16bit_weights_on_model_save=true`.
+    """
+    if not isinstance(model.pretrained_model, (PreTrainedModel, PeftModel)):
+        return
+
+    if safe_serialization:
+        from safetensors import safe_open
+        from safetensors.torch import save_file
+        path_to_checkpoint = os.path.join(output_dir, SAFE_WEIGHTS_NAME)
+        with safe_open(path_to_checkpoint, framework="pt", device="cpu") as f:
+            state_dict: Dict[str, torch.Tensor] = {key: f.get_tensor(key) for key in f.keys()}
+    else:
+        path_to_checkpoint = os.path.join(output_dir, WEIGHTS_NAME)
+        state_dict: Dict[str, torch.Tensor] = torch.load(path_to_checkpoint, map_location="cpu")
+
+    decoder_state_dict = {}
+    v_head_state_dict = {}
+    for name, param in state_dict.items():
+        if name.startswith("v_head."):
+            v_head_state_dict[name] = param
+        else:
+            decoder_state_dict[name.replace("pretrained_model.", "")] = param
+
+    os.remove(path_to_checkpoint)
+    model.pretrained_model.save_pretrained(
+        output_dir,
+        state_dict=decoder_state_dict or None,
+        safe_serialization=safe_serialization
+    )
+
+    if safe_serialization:
+        save_file(v_head_state_dict, os.path.join(output_dir, V_HEAD_SAFE_WEIGHTS_NAME), metadata={"format": "pt"})
+    else:
+        torch.save(v_head_state_dict, os.path.join(output_dir, V_HEAD_WEIGHTS_NAME))
+
     logger.info("Value head model saved at: {}".format(output_dir))
 
 
-class SavePeftModelCallback(TrainerCallback):
+class FixValueHeadModelCallback(TrainerCallback):
 
     def on_save(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
         r"""
         Event called after a checkpoint save.
         """
         if args.should_save:
-            _save_model_with_valuehead(
-                model=unwrap_model(kwargs.pop("model")),
+            _fix_valuehead_checkpoint(
+                model=kwargs.pop("model"),
                 output_dir=os.path.join(args.output_dir, "{}-{}".format(PREFIX_CHECKPOINT_DIR, state.global_step)),
                 safe_serialization=args.save_safetensors
             )
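The key-splitting loop added above is the core of the fix: it separates the v_head.* tensors from the decoder tensors and strips the "pretrained_model." prefix that DeepSpeed ZeRO-3 leaves on gathered weights (case 3 in the docstring). A toy run of the same loop, with placeholder tensors standing in for real weights:

import torch

# Placeholder state_dict shaped like case 3 (DeepSpeed ZeRO-3);
# tensor shapes are arbitrary, not real model dimensions.
state_dict = {
    "pretrained_model.model.layers.0.self_attn.q_proj.weight": torch.zeros(4, 4),
    "v_head.summary.weight": torch.zeros(1, 4),
}

decoder_state_dict, v_head_state_dict = {}, {}
for name, param in state_dict.items():
    if name.startswith("v_head."):
        v_head_state_dict[name] = param
    else:
        decoder_state_dict[name.replace("pretrained_model.", "")] = param

print(list(decoder_state_dict))  # ['model.layers.0.self_attn.q_proj.weight']
print(list(v_head_state_dict))   # ['v_head.summary.weight']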
@@ -54,10 +91,8 @@ class SavePeftModelCallback(TrainerCallback):
     def on_train_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
         r"""
         Event called at the end of training.
         """
         if args.should_save:
-            _save_model_with_valuehead(
-                model=unwrap_model(kwargs.pop("model")),
-                output_dir=args.output_dir,
-                safe_serialization=args.save_safetensors
+            _fix_valuehead_checkpoint(
+                model=kwargs.pop("model"), output_dir=args.output_dir, safe_serialization=args.save_safetensors
             )
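The renamed callback plugs into a trainer the same way the old one did. A minimal sketch, assuming the file above is importable as llmtuner.extras.callbacks and that model (an AutoModelForCausalLMWithValueHead), training_args, and train_dataset are constructed elsewhere:

from transformers import Trainer

from llmtuner.extras.callbacks import FixValueHeadModelCallback

# Every checkpoint the Trainer writes is re-split by on_save right after
# the save; args.save_safetensors decides which weights file gets fixed.
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    callbacks=[FixValueHeadModelCallback()],
)
trainer.train()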
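Reading a fixed checkpoint back is then straightforward, since the decoder and the value head live in separate files. A hypothetical sketch for the safetensors path, reusing the constant this diff introduces:

import os

from safetensors.torch import load_file
from trl import AutoModelForCausalLMWithValueHead

from llmtuner.extras.constants import V_HEAD_SAFE_WEIGHTS_NAME

checkpoint_dir = "output/checkpoint-1000"  # placeholder path

# The decoder loads through the ordinary from_pretrained path, because
# the saved weights no longer carry the "pretrained_model." prefix.
model = AutoModelForCausalLMWithValueHead.from_pretrained(checkpoint_dir)

# The value head is restored separately from its own weights file;
# strict=False because only v_head.* keys are present in this file.
v_head_state_dict = load_file(os.path.join(checkpoint_dir, V_HEAD_SAFE_WEIGHTS_NAME))
model.load_state_dict(v_head_state_dict, strict=False)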