refactor pissa, improve llamaboard

Former-commit-id: 619556e46c19718f702c97df5d570a2a4c5fb13a
Author: hiyouga
Date: 2024-06-28 01:04:24 +08:00
parent edc7498111
commit 46f0189e88
16 changed files with 219 additions and 216 deletions


@@ -1,4 +1,7 @@
-# Copyright 2024 the LlamaFactory team.
+# Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
 #
+# This code is inspired by the HuggingFace's PEFT library.
+# https://github.com/huggingface/peft/blob/v0.10.0/src/peft/peft_model.py
+#
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -14,15 +17,11 @@
 import gc
 import os
-from typing import TYPE_CHECKING, Dict, Tuple
+from typing import TYPE_CHECKING, Tuple
 import torch
-from peft import PeftModel
-from transformers import InfNanRemoveLogitsProcessor, LogitsProcessorList, PreTrainedModel
+from transformers import InfNanRemoveLogitsProcessor, LogitsProcessorList
 from transformers.utils import (
-    SAFE_WEIGHTS_NAME,
-    WEIGHTS_NAME,
-    is_safetensors_available,
     is_torch_bf16_gpu_available,
     is_torch_cuda_available,
     is_torch_mps_available,
@@ -31,15 +30,9 @@ from transformers.utils import (
 )
 from transformers.utils.versions import require_version
-from .constants import V_HEAD_SAFE_WEIGHTS_NAME, V_HEAD_WEIGHTS_NAME
 from .logging import get_logger
-if is_safetensors_available():
-    from safetensors import safe_open
-    from safetensors.torch import save_file
 _is_fp16_available = is_torch_npu_available() or is_torch_cuda_available()
 try:
     _is_bf16_available = is_torch_bf16_gpu_available()
@@ -48,8 +41,6 @@ except Exception:
 if TYPE_CHECKING:
-    from trl import AutoModelForCausalLMWithValueHead
     from ..hparams import ModelArguments
@@ -99,7 +90,7 @@ def count_parameters(model: torch.nn.Module) -> Tuple[int, int]:
         if num_params == 0 and hasattr(param, "ds_numel"):
             num_params = param.ds_numel
-        # Due to the design of 4bit linear layers from bitsandbytes, multiply the number of parameters by 2
+        # Due to the design of 4bit linear layers from bitsandbytes, multiply the number of parameters by itemsize
         if param.__class__.__name__ == "Params4bit":
             if hasattr(param, "quant_storage") and hasattr(param.quant_storage, "itemsize"):
                 num_bytes = param.quant_storage.itemsize
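The updated comment reflects how bitsandbytes 4-bit parameters are counted: two 4-bit values are packed into each element of the quant_storage dtype, so the stored element count is scaled by 2 * itemsize instead of a fixed 2. A minimal standalone sketch of that arithmetic (illustrative only; the fallback of 1 byte is an assumption not shown in this hunk):

    import torch

    def count_logical_params(param: "torch.Tensor") -> int:
        # Sketch: count logical parameters for one tensor, treating bitsandbytes
        # Params4bit specially, since two 4-bit weights pack into every stored byte.
        num_params = param.numel()  # elements as stored (quant_storage dtype for 4-bit weights)
        if param.__class__.__name__ == "Params4bit":
            if hasattr(param, "quant_storage") and hasattr(param.quant_storage, "itemsize"):
                num_bytes = param.quant_storage.itemsize  # e.g. 2 when weights are packed into bfloat16
            else:
                num_bytes = 1  # assumed fallback when no quant_storage dtype is exposed
            num_params = num_params * 2 * num_bytes  # 2 four-bit values per stored byte
        return num_params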
@@ -117,51 +108,6 @@ def count_parameters(model: torch.nn.Module) -> Tuple[int, int]:
     return trainable_params, all_param
-def fix_valuehead_checkpoint(
-    model: "AutoModelForCausalLMWithValueHead", output_dir: str, safe_serialization: bool
-) -> None:
-    r"""
-    The model is already unwrapped.
-    There are three cases:
-    1. full tuning without ds_zero3: state_dict = {"model.layers.*": ..., "v_head.summary.*": ...}
-    2. lora tuning without ds_zero3: state_dict = {"v_head.summary.*": ...}
-    3. under deepspeed zero3: state_dict = {"pretrained_model.model.layers.*": ..., "v_head.summary.*": ...}
-    We assume `stage3_gather_16bit_weights_on_model_save=true`.
-    """
-    if not isinstance(model.pretrained_model, (PreTrainedModel, PeftModel)):
-        return
-    if safe_serialization:
-        path_to_checkpoint = os.path.join(output_dir, SAFE_WEIGHTS_NAME)
-        with safe_open(path_to_checkpoint, framework="pt", device="cpu") as f:
-            state_dict: Dict[str, torch.Tensor] = {key: f.get_tensor(key) for key in f.keys()}
-    else:
-        path_to_checkpoint = os.path.join(output_dir, WEIGHTS_NAME)
-        state_dict: Dict[str, torch.Tensor] = torch.load(path_to_checkpoint, map_location="cpu")
-    decoder_state_dict = {}
-    v_head_state_dict = {}
-    for name, param in state_dict.items():
-        if name.startswith("v_head."):
-            v_head_state_dict[name] = param
-        else:
-            decoder_state_dict[name.replace("pretrained_model.", "")] = param
-    os.remove(path_to_checkpoint)
-    model.pretrained_model.save_pretrained(
-        output_dir, state_dict=decoder_state_dict or None, safe_serialization=safe_serialization
-    )
-    if safe_serialization:
-        save_file(v_head_state_dict, os.path.join(output_dir, V_HEAD_SAFE_WEIGHTS_NAME), metadata={"format": "pt"})
-    else:
-        torch.save(v_head_state_dict, os.path.join(output_dir, V_HEAD_WEIGHTS_NAME))
-    logger.info("Value head model saved at: {}".format(output_dir))
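To make the key handling in the removed function concrete, here is a small self-contained sketch of how a merged state dict is split into decoder and value-head parts (case 3 from its docstring, with placeholder tensors; illustrative only, not part of this commit):

    import torch

    # A zero3-style merged state dict: decoder keys carry a "pretrained_model." prefix.
    state_dict = {
        "pretrained_model.model.layers.0.self_attn.q_proj.weight": torch.zeros(1),
        "v_head.summary.weight": torch.zeros(1),
    }
    decoder_state_dict, v_head_state_dict = {}, {}
    for name, param in state_dict.items():
        if name.startswith("v_head."):
            v_head_state_dict[name] = param  # written separately as the value-head checkpoint
        else:
            decoder_state_dict[name.replace("pretrained_model.", "")] = param  # prefix stripped
    # decoder_state_dict: {"model.layers.0.self_attn.q_proj.weight": tensor([0.])}
    # v_head_state_dict:  {"v_head.summary.weight": tensor([0.])}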
 def get_current_device() -> torch.device:
     r"""
     Gets the current available device.
@@ -201,7 +147,7 @@ def get_logits_processor() -> "LogitsProcessorList":
     return logits_processor
-def infer_optim_dtype(model_dtype: torch.dtype) -> torch.dtype:
+def infer_optim_dtype(model_dtype: "torch.dtype") -> "torch.dtype":
     r"""
     Infers the optimal dtype according to the model_dtype and device compatibility.
     """
@@ -220,7 +166,7 @@ def is_gpu_or_npu_available() -> bool:
     return is_torch_npu_available() or is_torch_cuda_available()
-def has_tokenized_data(path: os.PathLike) -> bool:
+def has_tokenized_data(path: "os.PathLike") -> bool:
     r"""
     Checks if the path has a tokenized dataset.
     """