disentangle model from tuner and rename modules

Former-commit-id: 02cbf91e7e424f8379c1fed01b82a5f7a83b6947
This commit is contained in:
hiyouga
2023-11-15 16:29:09 +08:00
parent 81530133ff
commit 09a4474e7f
57 changed files with 324 additions and 263 deletions

View File

@@ -0,0 +1,3 @@
from llmtuner.model.loader import load_model_and_tokenizer
from llmtuner.model.parser import get_train_args, get_infer_args, get_eval_args
from llmtuner.model.utils import dispatch_model, generate_model_card

View File

@@ -0,0 +1,123 @@
import torch
from typing import TYPE_CHECKING
from transformers.utils import cached_file
from transformers.trainer import WEIGHTS_NAME, SAFE_WEIGHTS_NAME
from peft import PeftModel, TaskType, LoraConfig, get_peft_model
from llmtuner.extras.logging import get_logger
from llmtuner.model.utils import find_all_linear_modules
if TYPE_CHECKING:
from transformers.modeling_utils import PreTrainedModel
from llmtuner.hparams import ModelArguments, FinetuningArguments
logger = get_logger(__name__)
def init_adapter(
model: "PreTrainedModel",
model_args: "ModelArguments",
finetuning_args: "FinetuningArguments",
is_trainable: bool
) -> "PreTrainedModel":
r"""
Initializes the adapters.
Support full-parameter, freeze and LoRA training.
Note that the trainable parameters must be cast to float32.
"""
if (not is_trainable) and model_args.checkpoint_dir is None:
logger.info("Checkpoint is not found at evaluation, load the original model.")
return model
if finetuning_args.finetuning_type == "full" and is_trainable:
logger.info("Fine-tuning method: Full")
model = model.float()
if finetuning_args.finetuning_type == "freeze" and is_trainable:
logger.info("Fine-tuning method: Freeze")
num_layers = getattr(model.config, "num_layers")
if finetuning_args.num_layer_trainable > 0: # fine-tuning the last n layers if num_layer_trainable > 0
trainable_layer_ids = [num_layers - k - 1 for k in range(finetuning_args.num_layer_trainable)]
else: # fine-tuning the first n layers if num_layer_trainable < 0
trainable_layer_ids = [k for k in range(-finetuning_args.num_layer_trainable)]
trainable_layers = ["{:d}.{}".format(idx, finetuning_args.name_module_trainable) for idx in trainable_layer_ids]
for name, param in model.named_parameters():
if not any(trainable_layer in name for trainable_layer in trainable_layers):
param.requires_grad_(False)
else:
param.data = param.data.to(torch.float32)
if finetuning_args.finetuning_type == "lora":
logger.info("Fine-tuning method: LoRA")
checkpoint_to_resume = None
if model_args.checkpoint_dir is not None:
if is_trainable and finetuning_args.resume_lora_training:
checkpoints_to_merge, checkpoint_to_resume = model_args.checkpoint_dir[:-1], model_args.checkpoint_dir[-1]
else:
checkpoints_to_merge = model_args.checkpoint_dir
for checkpoint in checkpoints_to_merge:
model = PeftModel.from_pretrained(model, checkpoint)
model = model.merge_and_unload()
if len(checkpoints_to_merge) > 0:
logger.info("Merged {} model checkpoint(s).".format(len(checkpoints_to_merge)))
if checkpoint_to_resume is not None: # resume lora training
model = PeftModel.from_pretrained(model, checkpoint_to_resume, is_trainable=is_trainable)
if is_trainable and checkpoint_to_resume is None: # create new lora weights while training
if len(finetuning_args.lora_target) == 1 and finetuning_args.lora_target[0] == "all":
target_modules = find_all_linear_modules(model, model_args.quantization_bit)
else:
target_modules = finetuning_args.lora_target
lora_config = LoraConfig(
task_type=TaskType.CAUSAL_LM,
inference_mode=False,
r=finetuning_args.lora_rank,
lora_alpha=finetuning_args.lora_alpha,
lora_dropout=finetuning_args.lora_dropout,
target_modules=target_modules,
modules_to_save=finetuning_args.additional_target
)
model = get_peft_model(model, lora_config)
if model_args.checkpoint_dir is not None:
logger.info("Loaded fine-tuned model from checkpoint(s): {}".format(",".join(model_args.checkpoint_dir)))
return model
def load_valuehead_params(
model: "PreTrainedModel",
model_args: "ModelArguments"
) -> bool:
kwargs = {
"path_or_repo_id": model_args.reward_model,
"cache_dir": model_args.cache_dir,
"token": model_args.hf_hub_token,
"revision": model_args.model_revision
}
try:
vhead_file = cached_file(filename=WEIGHTS_NAME, **kwargs)
except:
try:
vhead_file = cached_file(filename=SAFE_WEIGHTS_NAME, **kwargs)
except:
logger.warning("Provided path ({}) does not contain valuehead weights.".format(model_args.reward_model))
return False
vhead_params = torch.load(vhead_file, map_location="cpu")
model.register_buffer("reward_head_weight", vhead_params["v_head.summary.weight"], persistent=False)
model.register_buffer("reward_head_bias", vhead_params["v_head.summary.bias"], persistent=False)
model.register_buffer("default_head_weight", torch.zeros_like(vhead_params["v_head.summary.weight"]), persistent=False)
model.register_buffer("default_head_bias", torch.zeros_like(vhead_params["v_head.summary.bias"]), persistent=False)
return True

View File

@@ -0,0 +1,236 @@
import os
import math
import torch
from types import MethodType
from typing import TYPE_CHECKING, Literal, Optional, Tuple
from transformers import (
AutoConfig,
AutoModelForCausalLM,
AutoTokenizer,
BitsAndBytesConfig,
PretrainedConfig,
PreTrainedModel,
PreTrainedTokenizerBase
)
from transformers.models.llama import modeling_llama as LlamaModule
from transformers.utils.versions import require_version
from peft import PeftModel
from trl import AutoModelForCausalLMWithValueHead
try:
from transformers.integrations import is_deepspeed_zero3_enabled
except ImportError: # https://github.com/huggingface/transformers/releases/tag/v4.33.1
from transformers.deepspeed import is_deepspeed_zero3_enabled
from llmtuner.extras.logging import reset_logging, get_logger
from llmtuner.extras.misc import count_parameters, infer_optim_dtype
from llmtuner.extras.packages import is_flash_attn2_available
from llmtuner.extras.patches import llama_patch as LlamaPatches
from llmtuner.hparams import FinetuningArguments
from llmtuner.model.adapter import init_adapter, load_valuehead_params
from llmtuner.model.utils import prepare_model_for_training
if TYPE_CHECKING:
from transformers import PreTrainedTokenizer
from llmtuner.hparams import ModelArguments
logger = get_logger(__name__)
require_version("transformers>=4.31.0,<4.35.0", "To fix: pip install \"transformers>=4.31.0,<4.35.0\"")
require_version("datasets>=2.14.0", "To fix: pip install datasets>=2.14.0")
require_version("accelerate>=0.21.0", "To fix: pip install accelerate>=0.21.0")
require_version("peft>=0.6.0", "To fix: pip install peft>=0.6.0")
require_version("trl>=0.7.4", "To fix: pip install trl>=0.7.4")
def load_model_and_tokenizer(
model_args: "ModelArguments",
finetuning_args: "FinetuningArguments",
is_trainable: Optional[bool] = False,
stage: Optional[Literal["pt", "sft", "rm", "ppo"]] = "sft"
) -> Tuple[PreTrainedModel, "PreTrainedTokenizer"]:
r"""
Loads pretrained model and tokenizer.
Support both training and inference.
"""
config_kwargs = {
"trust_remote_code": True,
"cache_dir": model_args.cache_dir,
"revision": model_args.model_revision,
"token": model_args.hf_hub_token
}
tokenizer = AutoTokenizer.from_pretrained(
model_args.model_name_or_path,
use_fast=model_args.use_fast_tokenizer,
split_special_tokens=model_args.split_special_tokens,
padding_side="right", # training with left-padded tensors in fp16 precision may cause overflow
**config_kwargs
)
if finetuning_args.finetuning_type != "lora" and model_args.checkpoint_dir is not None:
model_to_load = model_args.checkpoint_dir[0]
else:
model_to_load = model_args.model_name_or_path
config = AutoConfig.from_pretrained(model_to_load, **config_kwargs)
# Fix tokenizer (for ChatGLM2 and ChatGLM3)
if getattr(config, "model_type", None) == "chatglm":
tokenizer._pad = MethodType(PreTrainedTokenizerBase._pad, tokenizer)
# Set model dtype
if model_args.compute_dtype is None: # priority: bf16 > fp16 > fp32
model_args.compute_dtype = infer_optim_dtype(model_dtype=getattr(config, "torch_dtype", None))
setattr(config, "torch_dtype", model_args.compute_dtype)
# Fix config (for Qwen)
if getattr(config, "model_type", None) == "qwen":
for dtype_name, dtype in [("fp16", torch.float16), ("bf16", torch.bfloat16), ("fp32", torch.float32)]:
setattr(config, dtype_name, getattr(config, "torch_dtype", None) == dtype)
# Set RoPE scaling
if model_args.rope_scaling is not None:
if not hasattr(config, "rope_scaling"):
logger.warning("Current model does not support RoPE scaling.")
else:
if is_trainable:
if model_args.rope_scaling == "dynamic":
logger.warning(
"Dynamic NTK may not work well with fine-tuning. "
"See: https://github.com/huggingface/transformers/pull/24653"
)
current_max_length = getattr(config, "max_position_embeddings", None)
if current_max_length and model_args.model_max_length > current_max_length:
scaling_factor = float(math.ceil(model_args.model_max_length / current_max_length))
else:
logger.warning("Input length is smaller than max length. Consider increase input length.")
scaling_factor = 1.0
else:
scaling_factor = 2.0
setattr(config, "rope_scaling", {"type": model_args.rope_scaling, "factor": scaling_factor})
logger.info("Using {} scaling strategy and setting scaling factor to {}".format(
model_args.rope_scaling, scaling_factor
))
# Set FlashAttention-2
if model_args.flash_attn:
if getattr(config, "model_type", None) == "llama":
if is_flash_attn2_available():
LlamaModule.LlamaAttention = LlamaPatches.LlamaFlashAttention2
LlamaModule.LlamaModel._prepare_decoder_attention_mask = LlamaPatches._prepare_decoder_attention_mask
logger.info("Using FlashAttention-2 for faster training and inference.")
else:
logger.warning("FlashAttention-2 is not installed.")
elif getattr(config, "model_type", None) in ["qwen", "Yi"]:
logger.info("Current model automatically enables FlashAttention if installed.")
else:
logger.warning("Current model does not support FlashAttention.")
elif is_trainable and model_args.shift_attn and getattr(config, "model_type", None) == "llama":
LlamaModule.LlamaAttention = LlamaPatches.LlamaShiftShortAttention
logger.warning("Using `--flash_attn` for faster training in large context length.")
# Set shift short attention (S^2-Attn)
if is_trainable and model_args.shift_attn:
if getattr(config, "model_type", None) == "llama":
setattr(config, "group_size_ratio", 0.25)
logger.info("Using shift short attention with group_size_ratio=1/4.")
else:
logger.warning("Current model does not support shift short attention.")
# Quantization configurations (using bitsandbytes library).
if model_args.quantization_bit is not None:
if is_deepspeed_zero3_enabled():
raise ValueError("DeepSpeed ZeRO-3 is incompatible with quantization.")
if model_args.quantization_bit == 8:
require_version("bitsandbytes>=0.37.0", "To fix: pip install bitsandbytes>=0.37.0")
config_kwargs["load_in_8bit"] = True
config_kwargs["quantization_config"] = BitsAndBytesConfig(load_in_8bit=True)
if model_args.quantization_bit == 4:
require_version("bitsandbytes>=0.39.0", "To fix: pip install bitsandbytes>=0.39.0")
config_kwargs["load_in_4bit"] = True
config_kwargs["quantization_config"] = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_compute_dtype=model_args.compute_dtype,
bnb_4bit_use_double_quant=model_args.double_quantization,
bnb_4bit_quant_type=model_args.quantization_type
)
config_kwargs["device_map"] = {"": int(os.environ.get("LOCAL_RANK", "0"))} if is_trainable else "auto"
logger.info("Quantizing model to {} bit.".format(model_args.quantization_bit))
# Load and prepare pre-trained models (without valuehead).
model = AutoModelForCausalLM.from_pretrained(
model_to_load,
config=config,
torch_dtype=model_args.compute_dtype,
low_cpu_mem_usage=(not is_deepspeed_zero3_enabled()),
**config_kwargs
)
# Disable custom generate method (for Qwen and Baichuan2)
if isinstance(model, PreTrainedModel) and "GenerationMixin" not in str(model.generate.__func__):
model.generate = MethodType(PreTrainedModel.generate, model)
# Fix LM head (for ChatGLM2 and ChatGLM3)
if getattr(config, "model_type", None) == "chatglm":
setattr(model, "lm_head", model.transformer.output_layer)
setattr(model, "_keys_to_ignore_on_save", ["lm_head.weight"])
# Register auto class to save the custom code files.
if isinstance(config, PretrainedConfig) and "AutoConfig" in getattr(config, "auto_map", {}):
config.__class__.register_for_auto_class()
if isinstance(model, PreTrainedModel) and "AutoModelForCausalLM" in getattr(config, "auto_map", {}):
model.__class__.register_for_auto_class()
if isinstance(tokenizer, PreTrainedTokenizerBase) and "AutoTokenizer" in tokenizer.init_kwargs.get("auto_map", {}):
tokenizer.__class__.register_for_auto_class()
# Initialize adapters
model = prepare_model_for_training(model=model, finetuning_args=finetuning_args) if is_trainable else model
model = init_adapter(model, model_args, finetuning_args, is_trainable)
model = model.train() if is_trainable else model.eval()
# Prepare model with valuehead for RLHF
if stage == "rm" or stage == "ppo":
model: "AutoModelForCausalLMWithValueHead" = AutoModelForCausalLMWithValueHead.from_pretrained(model)
reset_logging()
if stage == "rm" and model_args.checkpoint_dir is not None: # load valuehead weights to evaluate reward model
logger.warning("Only the last checkpoint containing valuehead will be loaded.")
if load_valuehead_params(model, model_args):
model.v_head.load_state_dict({
"summary.weight": getattr(model, "reward_head_weight"),
"summary.bias": getattr(model, "reward_head_bias")
})
if stage == "ppo": # load reward model
logger.info("Load reward model from {}".format(model_args.reward_model))
if isinstance(model.pretrained_model, PeftModel):
model.pretrained_model.load_adapter(model_args.reward_model, "reward")
for name, param in model.named_parameters(): # https://github.com/huggingface/peft/issues/1090
if "default" in name:
param.data = param.data.to(torch.float32) # trainable params should in fp32
assert load_valuehead_params(model, model_args), "Reward model is not correctly loaded."
# Prepare model for inference
if not is_trainable:
model.requires_grad_(False) # fix all model params
model = model.to(model_args.compute_dtype) if model_args.quantization_bit is None else model
trainable_params, all_param = count_parameters(model)
logger.info("trainable params: {:d} || all params: {:d} || trainable%: {:.4f}".format(
trainable_params, all_param, 100 * trainable_params / all_param
))
if not is_trainable:
logger.info("This IS expected that the trainable params is 0 if you are using model for inference only.")
return model, tokenizer

View File

@@ -0,0 +1,212 @@
import os
import torch
import datasets
import transformers
from typing import Any, Dict, Optional, Tuple
from transformers import HfArgumentParser, Seq2SeqTrainingArguments
from transformers.trainer_utils import get_last_checkpoint
from llmtuner.extras.logging import get_logger
from llmtuner.extras.misc import parse_args
from llmtuner.hparams import (
ModelArguments,
DataArguments,
EvaluationArguments,
FinetuningArguments,
GeneratingArguments
)
logger = get_logger(__name__)
_TRAIN_ARGS = [
ModelArguments, DataArguments, Seq2SeqTrainingArguments, FinetuningArguments, GeneratingArguments
]
_TRAIN_CLS = Tuple[
ModelArguments, DataArguments, Seq2SeqTrainingArguments, FinetuningArguments, GeneratingArguments
]
_INFER_ARGS = [
ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments
]
_INFER_CLS = Tuple[
ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments
]
_EVAL_ARGS = [
ModelArguments, DataArguments, EvaluationArguments, FinetuningArguments
]
_EVAL_CLS = Tuple[
ModelArguments, DataArguments, EvaluationArguments, FinetuningArguments
]
def parse_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
parser = HfArgumentParser(_TRAIN_ARGS)
return parse_args(parser, args)
def parse_infer_args(args: Optional[Dict[str, Any]] = None) -> _INFER_CLS:
parser = HfArgumentParser(_INFER_ARGS)
return parse_args(parser, args)
def parse_eval_args(args: Optional[Dict[str, Any]] = None) -> _EVAL_CLS:
parser = HfArgumentParser(_EVAL_ARGS)
return parse_args(parser, args)
def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
model_args, data_args, training_args, finetuning_args, generating_args = parse_train_args(args)
# Setup logging
if training_args.should_log:
# The default of training_args.log_level is passive, so we set log level at info here to have that default.
transformers.utils.logging.set_verbosity_info()
log_level = training_args.get_process_log_level()
datasets.utils.logging.set_verbosity(log_level)
transformers.utils.logging.set_verbosity(log_level)
transformers.utils.logging.enable_default_handler()
transformers.utils.logging.enable_explicit_format()
# Check arguments
data_args.init_for_training(training_args.seed)
if finetuning_args.stage != "pt" and data_args.template is None:
raise ValueError("Please specify which `template` to use.")
if finetuning_args.stage != "sft" and training_args.predict_with_generate:
raise ValueError("`predict_with_generate` cannot be set as True except SFT.")
if finetuning_args.stage == "sft" and training_args.do_predict and not training_args.predict_with_generate:
raise ValueError("Please enable `predict_with_generate` to save model predictions.")
if finetuning_args.stage in ["rm", "ppo"]:
if finetuning_args.finetuning_type != "lora":
raise ValueError("RM and PPO stages can only be performed with the LoRA method.")
if training_args.resume_from_checkpoint is not None:
raise ValueError("RM and PPO stages do not support `resume_from_checkpoint`.")
if training_args.load_best_model_at_end:
raise ValueError("RM and PPO stages do not support `load_best_model_at_end`.")
if finetuning_args.stage == "ppo" and not training_args.do_train:
raise ValueError("PPO training does not support evaluation.")
if finetuning_args.stage in ["rm", "dpo"]:
for dataset_attr in data_args.dataset_list:
if not dataset_attr.ranking:
raise ValueError("Please use ranked datasets for reward modeling or DPO training.")
if finetuning_args.stage == "ppo" and model_args.reward_model is None:
raise ValueError("Reward model is necessary for PPO training.")
if finetuning_args.stage == "ppo" and model_args.shift_attn:
raise ValueError("PPO training is incompatible with S^2-Attn.")
if training_args.max_steps == -1 and data_args.streaming:
raise ValueError("Please specify `max_steps` in streaming mode.")
if training_args.do_train and training_args.predict_with_generate:
raise ValueError("`predict_with_generate` cannot be set as True while training.")
if training_args.do_train and finetuning_args.finetuning_type == "lora" and finetuning_args.lora_target is None:
raise ValueError("Please specify `lora_target` in LoRA training.")
if model_args.quantization_bit is not None and finetuning_args.finetuning_type != "lora":
raise ValueError("Quantization is only compatible with the LoRA method.")
if (
model_args.checkpoint_dir is not None
and len(model_args.checkpoint_dir) != 1
and finetuning_args.finetuning_type != "lora"
):
raise ValueError("Only LoRA tuning accepts multiple checkpoints.")
if training_args.do_train and model_args.quantization_bit is not None and (not finetuning_args.upcast_layernorm):
logger.warning("We recommend enable `upcast_layernorm` in quantized training.")
if training_args.do_train and (not training_args.fp16) and (not training_args.bf16):
logger.warning("We recommend enable mixed precision training.")
if (not training_args.do_train) and model_args.quantization_bit is not None:
logger.warning("Evaluating model in 4/8-bit mode may cause lower scores.")
# postprocess training_args
if (
training_args.local_rank != -1
and training_args.ddp_find_unused_parameters is None
and finetuning_args.finetuning_type == "lora"
):
logger.warning("`ddp_find_unused_parameters` needs to be set as False for LoRA in DDP training.")
training_args_dict = training_args.to_dict()
training_args_dict.update(dict(ddp_find_unused_parameters=False))
training_args = Seq2SeqTrainingArguments(**training_args_dict)
if (
training_args.resume_from_checkpoint is None
and training_args.do_train
and os.path.isdir(training_args.output_dir)
and not training_args.overwrite_output_dir
):
last_checkpoint = get_last_checkpoint(training_args.output_dir)
if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0:
raise ValueError("Output directory already exists and is not empty. Please set `overwrite_output_dir`.")
if last_checkpoint is not None:
training_args_dict = training_args.to_dict()
training_args_dict.update(dict(resume_from_checkpoint=last_checkpoint))
training_args = Seq2SeqTrainingArguments(**training_args_dict)
logger.info(
"Resuming from checkpoint. Change `output_dir` or use `overwrite_output_dir` to avoid."
)
# postprocess model_args
model_args.compute_dtype = (
torch.bfloat16 if training_args.bf16 else (torch.float16 if training_args.fp16 else None)
)
model_args.model_max_length = data_args.cutoff_len
# Log on each process the small summary:
logger.info("Process rank: {}, device: {}, n_gpu: {}\n distributed training: {}, compute dtype: {}".format(
training_args.local_rank, training_args.device, training_args.n_gpu,
bool(training_args.local_rank != -1), str(model_args.compute_dtype)
))
logger.info(f"Training/evaluation parameters {training_args}")
# Set seed before initializing model.
transformers.set_seed(training_args.seed)
return model_args, data_args, training_args, finetuning_args, generating_args
def get_infer_args(args: Optional[Dict[str, Any]] = None) -> _INFER_CLS:
model_args, data_args, finetuning_args, generating_args = parse_infer_args(args)
if data_args.template is None:
raise ValueError("Please specify which `template` to use.")
if model_args.quantization_bit is not None and finetuning_args.finetuning_type != "lora":
raise ValueError("Quantization is only compatible with the LoRA method.")
if (
model_args.checkpoint_dir is not None
and len(model_args.checkpoint_dir) != 1
and finetuning_args.finetuning_type != "lora"
):
raise ValueError("Only LoRA tuning accepts multiple checkpoints.")
return model_args, data_args, finetuning_args, generating_args
def get_eval_args(args: Optional[Dict[str, Any]] = None) -> _EVAL_CLS:
model_args, data_args, eval_args, finetuning_args = parse_eval_args(args)
if data_args.template is None:
raise ValueError("Please specify which `template` to use.")
if model_args.quantization_bit is not None and finetuning_args.finetuning_type != "lora":
raise ValueError("Quantization is only compatible with the LoRA method.")
transformers.set_seed(eval_args.seed)
return model_args, data_args, eval_args, finetuning_args

132
src/llmtuner/model/utils.py Normal file
View File

@@ -0,0 +1,132 @@
import torch
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple
from llmtuner.extras.constants import LAYERNORM_NAMES
from llmtuner.extras.logging import get_logger
if TYPE_CHECKING:
from transformers.modeling_utils import PreTrainedModel
from llmtuner.hparams import ModelArguments, DataArguments, FinetuningArguments
logger = get_logger(__name__)
def dispatch_model(model: "PreTrainedModel") -> "PreTrainedModel":
r"""
Dispatches a pre-trained model to GPUs with balanced memory.
Borrowed from: https://github.com/huggingface/transformers/blob/v4.31.0/src/transformers/modeling_utils.py#L2803
"""
if getattr(model, "is_loaded_in_8bit", False) or getattr(model, "is_loaded_in_4bit", False): # do nothing
return model
if torch.cuda.device_count() > 1:
from accelerate import dispatch_model
from accelerate.utils import infer_auto_device_map, get_balanced_memory
if model._no_split_modules is None:
raise ValueError("The model class needs to implement the `_no_split_modules` attribute.")
kwargs = {"dtype": model.dtype, "no_split_module_classes": model._no_split_modules}
max_memory = get_balanced_memory(model, **kwargs)
# Make sure tied weights are tied before creating the device map.
model.tie_weights()
device_map = infer_auto_device_map(model, max_memory=max_memory, **kwargs)
return dispatch_model(model, device_map)
else:
return model.cuda()
def find_all_linear_modules(
model: "PreTrainedModel",
quantization_bit: Optional[int] = None
) -> List[str]:
if quantization_bit is not None:
import bitsandbytes as bnb
linear_cls = bnb.nn.Linear4bit if quantization_bit == 4 else bnb.nn.Linear8bitLt
else:
linear_cls = torch.nn.Linear
output_layer_names = ["lm_head"]
if model.config.model_type == "chatglm":
output_layer_names.append("output_layer")
module_names = set()
for name, module in model.named_modules():
if (
isinstance(module, linear_cls)
and not any([output_layer in name for output_layer in output_layer_names])
):
module_names.add(name.split(".")[-1])
logger.info("Found linear modules: {}".format(",".join(module_names)))
return list(module_names)
def generate_model_card(
model_args: "ModelArguments",
data_args: "DataArguments",
finetuning_args: "FinetuningArguments"
) -> Dict[str, Any]:
return {
"tasks": "text-generation",
"finetuned_from": model_args.model_name_or_path,
"dataset": [dataset.strip() for dataset in data_args.dataset.split(",")],
"tags": ["llama-factory"] + (["lora"] if finetuning_args.finetuning_type == "lora" else [])
}
def prepare_model_for_training(
model: "PreTrainedModel",
finetuning_args: "FinetuningArguments",
output_layer_name: Optional[str] = "lm_head",
use_gradient_checkpointing: Optional[bool] = True,
layernorm_names: Optional[Set[str]] = LAYERNORM_NAMES
) -> "PreTrainedModel":
r"""
Includes:
(1) cast the layernorm in fp32
(2) make output embedding layer require grads
(3) upcast the lm_head to fp32
Inspired by: https://github.com/huggingface/peft/blob/v0.2.0/src/peft/utils/other.py#L33
"""
if finetuning_args.upcast_layernorm:
for name, param in model.named_parameters():
if param.ndim == 1 and any(ln_name in name for ln_name in layernorm_names):
param.data = param.data.to(torch.float32)
logger.info("Upcasting weights in layernorm in float32.")
if finetuning_args.neft_alpha > 1e-6:
def neftune_forward_hook(module: torch.nn.Module, args: Tuple[torch.Tensor], output: torch.Tensor):
if module.training:
dims = torch.tensor(output.size(1) * output.size(2))
mag_norm = finetuning_args.neft_alpha / torch.sqrt(dims)
output = output + torch.zeros_like(output).uniform_(-mag_norm, mag_norm)
return output
model.get_input_embeddings().register_forward_hook(neftune_forward_hook)
logger.info("Using noisy embedding with alpha={:.2f}".format(finetuning_args.neft_alpha))
if use_gradient_checkpointing:
if hasattr(model, "enable_input_require_grads"):
model.enable_input_require_grads()
else:
def make_inputs_require_grad(module: torch.nn.Module, args: Tuple[torch.Tensor], output: torch.Tensor):
output.requires_grad_(True)
model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)
model.gradient_checkpointing_enable()
model.config.use_cache = False # turn off when gradient checkpointing is enabled
logger.info("Gradient checkpointing enabled.")
if finetuning_args.finetuning_type != "full" and hasattr(model, output_layer_name):
output_layer = getattr(model, output_layer_name)
if isinstance(output_layer, torch.nn.Linear):
def fp32_forward_pre_hook(module: torch.nn.Module, args: Tuple[torch.Tensor]):
return args[0].to(output_layer.weight.dtype)
def fp32_forward_post_hook(module: torch.nn.Module, args: Tuple[torch.Tensor], output: torch.Tensor):
return output.to(torch.float32)
output_layer.register_forward_pre_hook(fp32_forward_pre_hook)
output_layer.register_forward_hook(fp32_forward_post_hook)
return model