disentangle model from tuner and rename modules

Former-commit-id: 02cbf91e7e424f8379c1fed01b82a5f7a83b6947

src/llmtuner/model/__init__.py (new file, 3 lines)
@@ -0,0 +1,3 @@
from llmtuner.model.loader import load_model_and_tokenizer
from llmtuner.model.parser import get_train_args, get_infer_args, get_eval_args
from llmtuner.model.utils import dispatch_model, generate_model_card
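The package root re-exports the argument parsers, the loader, and the dispatch helper. For orientation, a minimal inference-time sketch of how these pieces are typically wired together (argument values are hypothetical; field names follow the hparams dataclasses used in the modules below):

from llmtuner.model import get_infer_args, load_model_and_tokenizer, dispatch_model

# Hypothetical arguments; any causal LM repo id and a registered template would do.
model_args, data_args, finetuning_args, generating_args = get_infer_args(dict(
    model_name_or_path="meta-llama/Llama-2-7b-hf",
    template="default",
))
model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args, is_trainable=False)
model = dispatch_model(model)  # balance the model across the available GPUs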
src/llmtuner/model/adapter.py (new file, 123 lines)
@@ -0,0 +1,123 @@
import torch
from typing import TYPE_CHECKING

from transformers.utils import cached_file
from transformers.trainer import WEIGHTS_NAME, SAFE_WEIGHTS_NAME
from peft import PeftModel, TaskType, LoraConfig, get_peft_model

from llmtuner.extras.logging import get_logger
from llmtuner.model.utils import find_all_linear_modules

if TYPE_CHECKING:
    from transformers.modeling_utils import PreTrainedModel
    from llmtuner.hparams import ModelArguments, FinetuningArguments


logger = get_logger(__name__)


def init_adapter(
    model: "PreTrainedModel",
    model_args: "ModelArguments",
    finetuning_args: "FinetuningArguments",
    is_trainable: bool
) -> "PreTrainedModel":
    r"""
    Initializes the adapters.

    Supports full-parameter, freeze and LoRA training.

    Note that the trainable parameters must be cast to float32.
    """

    if (not is_trainable) and model_args.checkpoint_dir is None:
        logger.info("No checkpoint found at evaluation, loading the original model.")
        return model

    if finetuning_args.finetuning_type == "full" and is_trainable:
        logger.info("Fine-tuning method: Full")
        model = model.float()

    if finetuning_args.finetuning_type == "freeze" and is_trainable:
        logger.info("Fine-tuning method: Freeze")
        num_layers = getattr(model.config, "num_layers")
        if finetuning_args.num_layer_trainable > 0:  # fine-tune the last n layers if num_layer_trainable > 0
            trainable_layer_ids = [num_layers - k - 1 for k in range(finetuning_args.num_layer_trainable)]
        else:  # fine-tune the first n layers if num_layer_trainable < 0
            trainable_layer_ids = [k for k in range(-finetuning_args.num_layer_trainable)]

        trainable_layers = ["{:d}.{}".format(idx, finetuning_args.name_module_trainable) for idx in trainable_layer_ids]
        for name, param in model.named_parameters():
            if not any(trainable_layer in name for trainable_layer in trainable_layers):
                param.requires_grad_(False)
            else:
                param.data = param.data.to(torch.float32)

    if finetuning_args.finetuning_type == "lora":
        logger.info("Fine-tuning method: LoRA")
        checkpoint_to_resume = None

        if model_args.checkpoint_dir is not None:
            if is_trainable and finetuning_args.resume_lora_training:
                checkpoints_to_merge, checkpoint_to_resume = model_args.checkpoint_dir[:-1], model_args.checkpoint_dir[-1]
            else:
                checkpoints_to_merge = model_args.checkpoint_dir

            for checkpoint in checkpoints_to_merge:
                model = PeftModel.from_pretrained(model, checkpoint)
                model = model.merge_and_unload()

            if len(checkpoints_to_merge) > 0:
                logger.info("Merged {} model checkpoint(s).".format(len(checkpoints_to_merge)))

            if checkpoint_to_resume is not None:  # resume lora training
                model = PeftModel.from_pretrained(model, checkpoint_to_resume, is_trainable=is_trainable)

        if is_trainable and checkpoint_to_resume is None:  # create new lora weights while training
            if len(finetuning_args.lora_target) == 1 and finetuning_args.lora_target[0] == "all":
                target_modules = find_all_linear_modules(model, model_args.quantization_bit)
            else:
                target_modules = finetuning_args.lora_target

            lora_config = LoraConfig(
                task_type=TaskType.CAUSAL_LM,
                inference_mode=False,
                r=finetuning_args.lora_rank,
                lora_alpha=finetuning_args.lora_alpha,
                lora_dropout=finetuning_args.lora_dropout,
                target_modules=target_modules,
                modules_to_save=finetuning_args.additional_target
            )
            model = get_peft_model(model, lora_config)

    if model_args.checkpoint_dir is not None:
        logger.info("Loaded fine-tuned model from checkpoint(s): {}".format(",".join(model_args.checkpoint_dir)))

    return model

def load_valuehead_params(
    model: "PreTrainedModel",
    model_args: "ModelArguments"
) -> bool:
    kwargs = {
        "path_or_repo_id": model_args.reward_model,
        "cache_dir": model_args.cache_dir,
        "token": model_args.hf_hub_token,
        "revision": model_args.model_revision
    }
    try:
        vhead_file = cached_file(filename=WEIGHTS_NAME, **kwargs)
    except Exception:
        try:
            vhead_file = cached_file(filename=SAFE_WEIGHTS_NAME, **kwargs)
        except Exception:
            logger.warning("Provided path ({}) does not contain valuehead weights.".format(model_args.reward_model))
            return False

    vhead_params = torch.load(vhead_file, map_location="cpu")
    model.register_buffer("reward_head_weight", vhead_params["v_head.summary.weight"], persistent=False)
    model.register_buffer("reward_head_bias", vhead_params["v_head.summary.bias"], persistent=False)
    model.register_buffer("default_head_weight", torch.zeros_like(vhead_params["v_head.summary.weight"]), persistent=False)
    model.register_buffer("default_head_bias", torch.zeros_like(vhead_params["v_head.summary.bias"]), persistent=False)
    return True
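To make the freeze-mode layer selection in init_adapter concrete: a positive num_layer_trainable unfreezes the last n decoder layers, a negative value unfreezes the first |n| layers. A standalone sketch of the same indexing, assuming a 32-layer model:

num_layers = 32

# num_layer_trainable = 3: fine-tune the last 3 layers
print([num_layers - k - 1 for k in range(3)])  # [31, 30, 29]

# num_layer_trainable = -3: fine-tune the first 3 layers
print([k for k in range(-(-3))])               # [0, 1, 2]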
src/llmtuner/model/loader.py (new file, 236 lines)
@@ -0,0 +1,236 @@
import os
import math
import torch
from types import MethodType
from typing import TYPE_CHECKING, Literal, Optional, Tuple

from transformers import (
    AutoConfig,
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    PretrainedConfig,
    PreTrainedModel,
    PreTrainedTokenizerBase
)
from transformers.models.llama import modeling_llama as LlamaModule
from transformers.utils.versions import require_version
from peft import PeftModel
from trl import AutoModelForCausalLMWithValueHead

try:
    from transformers.integrations import is_deepspeed_zero3_enabled
except ImportError:  # https://github.com/huggingface/transformers/releases/tag/v4.33.1
    from transformers.deepspeed import is_deepspeed_zero3_enabled

from llmtuner.extras.logging import reset_logging, get_logger
from llmtuner.extras.misc import count_parameters, infer_optim_dtype
from llmtuner.extras.packages import is_flash_attn2_available
from llmtuner.extras.patches import llama_patch as LlamaPatches
from llmtuner.hparams import FinetuningArguments
from llmtuner.model.adapter import init_adapter, load_valuehead_params
from llmtuner.model.utils import prepare_model_for_training

if TYPE_CHECKING:
    from transformers import PreTrainedTokenizer
    from llmtuner.hparams import ModelArguments


logger = get_logger(__name__)


require_version("transformers>=4.31.0,<4.35.0", "To fix: pip install \"transformers>=4.31.0,<4.35.0\"")
require_version("datasets>=2.14.0", "To fix: pip install datasets>=2.14.0")
require_version("accelerate>=0.21.0", "To fix: pip install accelerate>=0.21.0")
require_version("peft>=0.6.0", "To fix: pip install peft>=0.6.0")
require_version("trl>=0.7.4", "To fix: pip install trl>=0.7.4")

def load_model_and_tokenizer(
    model_args: "ModelArguments",
    finetuning_args: "FinetuningArguments",
    is_trainable: Optional[bool] = False,
    stage: Optional[Literal["pt", "sft", "rm", "ppo"]] = "sft"
) -> Tuple[PreTrainedModel, "PreTrainedTokenizer"]:
    r"""
    Loads a pretrained model and tokenizer.

    Supports both training and inference.
    """

    config_kwargs = {
        "trust_remote_code": True,
        "cache_dir": model_args.cache_dir,
        "revision": model_args.model_revision,
        "token": model_args.hf_hub_token
    }

    tokenizer = AutoTokenizer.from_pretrained(
        model_args.model_name_or_path,
        use_fast=model_args.use_fast_tokenizer,
        split_special_tokens=model_args.split_special_tokens,
        padding_side="right",  # training with left-padded tensors in fp16 precision may cause overflow
        **config_kwargs
    )

    if finetuning_args.finetuning_type != "lora" and model_args.checkpoint_dir is not None:
        model_to_load = model_args.checkpoint_dir[0]
    else:
        model_to_load = model_args.model_name_or_path

    config = AutoConfig.from_pretrained(model_to_load, **config_kwargs)

    # Fix tokenizer (for ChatGLM2 and ChatGLM3)
    if getattr(config, "model_type", None) == "chatglm":
        tokenizer._pad = MethodType(PreTrainedTokenizerBase._pad, tokenizer)

    # Set model dtype
    if model_args.compute_dtype is None:  # priority: bf16 > fp16 > fp32
        model_args.compute_dtype = infer_optim_dtype(model_dtype=getattr(config, "torch_dtype", None))
    setattr(config, "torch_dtype", model_args.compute_dtype)

    # Fix config (for Qwen)
    if getattr(config, "model_type", None) == "qwen":
        for dtype_name, dtype in [("fp16", torch.float16), ("bf16", torch.bfloat16), ("fp32", torch.float32)]:
            setattr(config, dtype_name, getattr(config, "torch_dtype", None) == dtype)

    # Set RoPE scaling
    if model_args.rope_scaling is not None:
        if not hasattr(config, "rope_scaling"):
            logger.warning("Current model does not support RoPE scaling.")
        else:
            if is_trainable:
                if model_args.rope_scaling == "dynamic":
                    logger.warning(
                        "Dynamic NTK may not work well with fine-tuning. "
                        "See: https://github.com/huggingface/transformers/pull/24653"
                    )

                current_max_length = getattr(config, "max_position_embeddings", None)
                if current_max_length and model_args.model_max_length > current_max_length:
                    scaling_factor = float(math.ceil(model_args.model_max_length / current_max_length))
                else:
                    logger.warning("Input length is shorter than the model's max length. Consider increasing the input length.")
                    scaling_factor = 1.0
            else:
                scaling_factor = 2.0

            setattr(config, "rope_scaling", {"type": model_args.rope_scaling, "factor": scaling_factor})
            logger.info("Using {} scaling strategy and setting scaling factor to {}".format(
                model_args.rope_scaling, scaling_factor
            ))

    # Set FlashAttention-2
    if model_args.flash_attn:
        if getattr(config, "model_type", None) == "llama":
            if is_flash_attn2_available():
                LlamaModule.LlamaAttention = LlamaPatches.LlamaFlashAttention2
                LlamaModule.LlamaModel._prepare_decoder_attention_mask = LlamaPatches._prepare_decoder_attention_mask
                logger.info("Using FlashAttention-2 for faster training and inference.")
            else:
                logger.warning("FlashAttention-2 is not installed.")
        elif getattr(config, "model_type", None) in ["qwen", "Yi"]:
            logger.info("Current model automatically enables FlashAttention if installed.")
        else:
            logger.warning("Current model does not support FlashAttention.")
    elif is_trainable and model_args.shift_attn and getattr(config, "model_type", None) == "llama":
        LlamaModule.LlamaAttention = LlamaPatches.LlamaShiftShortAttention
        logger.warning("Consider using `--flash_attn` for faster training with long context lengths.")

    # Set shift short attention (S^2-Attn)
    if is_trainable and model_args.shift_attn:
        if getattr(config, "model_type", None) == "llama":
            setattr(config, "group_size_ratio", 0.25)
            logger.info("Using shift short attention with group_size_ratio=1/4.")
        else:
            logger.warning("Current model does not support shift short attention.")

    # Quantization configurations (using the bitsandbytes library).
    if model_args.quantization_bit is not None:
        if is_deepspeed_zero3_enabled():
            raise ValueError("DeepSpeed ZeRO-3 is incompatible with quantization.")

        if model_args.quantization_bit == 8:
            require_version("bitsandbytes>=0.37.0", "To fix: pip install bitsandbytes>=0.37.0")
            config_kwargs["load_in_8bit"] = True
            config_kwargs["quantization_config"] = BitsAndBytesConfig(load_in_8bit=True)

        if model_args.quantization_bit == 4:
            require_version("bitsandbytes>=0.39.0", "To fix: pip install bitsandbytes>=0.39.0")
            config_kwargs["load_in_4bit"] = True
            config_kwargs["quantization_config"] = BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_compute_dtype=model_args.compute_dtype,
                bnb_4bit_use_double_quant=model_args.double_quantization,
                bnb_4bit_quant_type=model_args.quantization_type
            )

        config_kwargs["device_map"] = {"": int(os.environ.get("LOCAL_RANK", "0"))} if is_trainable else "auto"
        logger.info("Quantizing model to {} bit.".format(model_args.quantization_bit))

    # Load and prepare the pretrained model (without valuehead).
    model = AutoModelForCausalLM.from_pretrained(
        model_to_load,
        config=config,
        torch_dtype=model_args.compute_dtype,
        low_cpu_mem_usage=(not is_deepspeed_zero3_enabled()),
        **config_kwargs
    )

    # Disable custom generate method (for Qwen and Baichuan2)
    if isinstance(model, PreTrainedModel) and "GenerationMixin" not in str(model.generate.__func__):
        model.generate = MethodType(PreTrainedModel.generate, model)

    # Fix LM head (for ChatGLM2 and ChatGLM3)
    if getattr(config, "model_type", None) == "chatglm":
        setattr(model, "lm_head", model.transformer.output_layer)
        setattr(model, "_keys_to_ignore_on_save", ["lm_head.weight"])

    # Register auto classes to save the custom code files.
    if isinstance(config, PretrainedConfig) and "AutoConfig" in getattr(config, "auto_map", {}):
        config.__class__.register_for_auto_class()
    if isinstance(model, PreTrainedModel) and "AutoModelForCausalLM" in getattr(config, "auto_map", {}):
        model.__class__.register_for_auto_class()
    if isinstance(tokenizer, PreTrainedTokenizerBase) and "AutoTokenizer" in tokenizer.init_kwargs.get("auto_map", {}):
        tokenizer.__class__.register_for_auto_class()

    # Initialize adapters
    model = prepare_model_for_training(model=model, finetuning_args=finetuning_args) if is_trainable else model
    model = init_adapter(model, model_args, finetuning_args, is_trainable)
    model = model.train() if is_trainable else model.eval()

    # Prepare model with valuehead for RLHF
    if stage == "rm" or stage == "ppo":
        model: "AutoModelForCausalLMWithValueHead" = AutoModelForCausalLMWithValueHead.from_pretrained(model)
        reset_logging()
        if stage == "rm" and model_args.checkpoint_dir is not None:  # load valuehead weights to evaluate reward model
            logger.warning("Only the last checkpoint containing valuehead will be loaded.")
            if load_valuehead_params(model, model_args):
                model.v_head.load_state_dict({
                    "summary.weight": getattr(model, "reward_head_weight"),
                    "summary.bias": getattr(model, "reward_head_bias")
                })

        if stage == "ppo":  # load reward model
            logger.info("Loading reward model from {}".format(model_args.reward_model))
            if isinstance(model.pretrained_model, PeftModel):
                model.pretrained_model.load_adapter(model_args.reward_model, "reward")
                for name, param in model.named_parameters():  # https://github.com/huggingface/peft/issues/1090
                    if "default" in name:
                        param.data = param.data.to(torch.float32)  # trainable params should be in fp32
            assert load_valuehead_params(model, model_args), "Reward model is not correctly loaded."

    # Prepare model for inference
    if not is_trainable:
        model.requires_grad_(False)  # freeze all model params
        model = model.to(model_args.compute_dtype) if model_args.quantization_bit is None else model

    trainable_params, all_param = count_parameters(model)
    logger.info("trainable params: {:d} || all params: {:d} || trainable%: {:.4f}".format(
        trainable_params, all_param, 100 * trainable_params / all_param
    ))

    if not is_trainable:
        logger.info("It is expected that the trainable params is 0 if you are loading the model for inference only.")

    return model, tokenizer
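For reference, the linear RoPE scaling factor computed in load_model_and_tokenizer is just the ratio of the requested sequence length to the model's native context window, rounded up (a sketch with hypothetical lengths, covering the trainable branch):

import math

max_position_embeddings = 4096  # native context length from the model config
model_max_length = 10000        # requested cutoff_len (hypothetical)
scaling_factor = float(math.ceil(model_max_length / max_position_embeddings))  # 3.0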
src/llmtuner/model/parser.py (new file, 212 lines)
@@ -0,0 +1,212 @@
import os
import torch
import datasets
import transformers
from typing import Any, Dict, Optional, Tuple
from transformers import HfArgumentParser, Seq2SeqTrainingArguments
from transformers.trainer_utils import get_last_checkpoint

from llmtuner.extras.logging import get_logger
from llmtuner.extras.misc import parse_args
from llmtuner.hparams import (
    ModelArguments,
    DataArguments,
    EvaluationArguments,
    FinetuningArguments,
    GeneratingArguments
)


logger = get_logger(__name__)


_TRAIN_ARGS = [
    ModelArguments, DataArguments, Seq2SeqTrainingArguments, FinetuningArguments, GeneratingArguments
]
_TRAIN_CLS = Tuple[
    ModelArguments, DataArguments, Seq2SeqTrainingArguments, FinetuningArguments, GeneratingArguments
]
_INFER_ARGS = [
    ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments
]
_INFER_CLS = Tuple[
    ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments
]
_EVAL_ARGS = [
    ModelArguments, DataArguments, EvaluationArguments, FinetuningArguments
]
_EVAL_CLS = Tuple[
    ModelArguments, DataArguments, EvaluationArguments, FinetuningArguments
]

def parse_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
    parser = HfArgumentParser(_TRAIN_ARGS)
    return parse_args(parser, args)


def parse_infer_args(args: Optional[Dict[str, Any]] = None) -> _INFER_CLS:
    parser = HfArgumentParser(_INFER_ARGS)
    return parse_args(parser, args)


def parse_eval_args(args: Optional[Dict[str, Any]] = None) -> _EVAL_CLS:
    parser = HfArgumentParser(_EVAL_ARGS)
    return parse_args(parser, args)

def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
    model_args, data_args, training_args, finetuning_args, generating_args = parse_train_args(args)

    # Setup logging
    if training_args.should_log:
        # The default of training_args.log_level is passive, so we set log level at info here to have that default.
        transformers.utils.logging.set_verbosity_info()

    log_level = training_args.get_process_log_level()
    datasets.utils.logging.set_verbosity(log_level)
    transformers.utils.logging.set_verbosity(log_level)
    transformers.utils.logging.enable_default_handler()
    transformers.utils.logging.enable_explicit_format()

    # Check arguments
    data_args.init_for_training(training_args.seed)

    if finetuning_args.stage != "pt" and data_args.template is None:
        raise ValueError("Please specify which `template` to use.")

    if finetuning_args.stage != "sft" and training_args.predict_with_generate:
        raise ValueError("`predict_with_generate` can only be set to True in SFT.")

    if finetuning_args.stage == "sft" and training_args.do_predict and not training_args.predict_with_generate:
        raise ValueError("Please enable `predict_with_generate` to save model predictions.")

    if finetuning_args.stage in ["rm", "ppo"]:
        if finetuning_args.finetuning_type != "lora":
            raise ValueError("RM and PPO stages can only be performed with the LoRA method.")
        if training_args.resume_from_checkpoint is not None:
            raise ValueError("RM and PPO stages do not support `resume_from_checkpoint`.")
        if training_args.load_best_model_at_end:
            raise ValueError("RM and PPO stages do not support `load_best_model_at_end`.")

    if finetuning_args.stage == "ppo" and not training_args.do_train:
        raise ValueError("PPO training does not support evaluation.")

    if finetuning_args.stage in ["rm", "dpo"]:
        for dataset_attr in data_args.dataset_list:
            if not dataset_attr.ranking:
                raise ValueError("Please use ranked datasets for reward modeling or DPO training.")

    if finetuning_args.stage == "ppo" and model_args.reward_model is None:
        raise ValueError("A reward model is necessary for PPO training.")

    if finetuning_args.stage == "ppo" and model_args.shift_attn:
        raise ValueError("PPO training is incompatible with S^2-Attn.")

    if training_args.max_steps == -1 and data_args.streaming:
        raise ValueError("Please specify `max_steps` in streaming mode.")

    if training_args.do_train and training_args.predict_with_generate:
        raise ValueError("`predict_with_generate` cannot be set to True while training.")

    if training_args.do_train and finetuning_args.finetuning_type == "lora" and finetuning_args.lora_target is None:
        raise ValueError("Please specify `lora_target` in LoRA training.")

    if model_args.quantization_bit is not None and finetuning_args.finetuning_type != "lora":
        raise ValueError("Quantization is only compatible with the LoRA method.")

    if (
        model_args.checkpoint_dir is not None
        and len(model_args.checkpoint_dir) != 1
        and finetuning_args.finetuning_type != "lora"
    ):
        raise ValueError("Only LoRA tuning accepts multiple checkpoints.")

    if training_args.do_train and model_args.quantization_bit is not None and (not finetuning_args.upcast_layernorm):
        logger.warning("We recommend enabling `upcast_layernorm` in quantized training.")

    if training_args.do_train and (not training_args.fp16) and (not training_args.bf16):
        logger.warning("We recommend enabling mixed precision training.")

    if (not training_args.do_train) and model_args.quantization_bit is not None:
        logger.warning("Evaluating the model in 4/8-bit mode may cause lower scores.")

    # Postprocess training_args
    if (
        training_args.local_rank != -1
        and training_args.ddp_find_unused_parameters is None
        and finetuning_args.finetuning_type == "lora"
    ):
        logger.warning("`ddp_find_unused_parameters` needs to be set as False for LoRA in DDP training.")
        training_args_dict = training_args.to_dict()
        training_args_dict.update(dict(ddp_find_unused_parameters=False))
        training_args = Seq2SeqTrainingArguments(**training_args_dict)

    if (
        training_args.resume_from_checkpoint is None
        and training_args.do_train
        and os.path.isdir(training_args.output_dir)
        and not training_args.overwrite_output_dir
    ):
        last_checkpoint = get_last_checkpoint(training_args.output_dir)
        if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0:
            raise ValueError("Output directory already exists and is not empty. Please set `overwrite_output_dir`.")

        if last_checkpoint is not None:
            training_args_dict = training_args.to_dict()
            training_args_dict.update(dict(resume_from_checkpoint=last_checkpoint))
            training_args = Seq2SeqTrainingArguments(**training_args_dict)
            logger.info(
                "Resuming training from the last checkpoint. Change `output_dir` or use `overwrite_output_dir` to avoid this behavior."
            )

    # Postprocess model_args
    model_args.compute_dtype = (
        torch.bfloat16 if training_args.bf16 else (torch.float16 if training_args.fp16 else None)
    )
    model_args.model_max_length = data_args.cutoff_len

    # Log a short summary on each process:
    logger.info("Process rank: {}, device: {}, n_gpu: {}\n distributed training: {}, compute dtype: {}".format(
        training_args.local_rank, training_args.device, training_args.n_gpu,
        bool(training_args.local_rank != -1), str(model_args.compute_dtype)
    ))
    logger.info(f"Training/evaluation parameters {training_args}")

    # Set seed before initializing model.
    transformers.set_seed(training_args.seed)

    return model_args, data_args, training_args, finetuning_args, generating_args

def get_infer_args(args: Optional[Dict[str, Any]] = None) -> _INFER_CLS:
    model_args, data_args, finetuning_args, generating_args = parse_infer_args(args)

    if data_args.template is None:
        raise ValueError("Please specify which `template` to use.")

    if model_args.quantization_bit is not None and finetuning_args.finetuning_type != "lora":
        raise ValueError("Quantization is only compatible with the LoRA method.")

    if (
        model_args.checkpoint_dir is not None
        and len(model_args.checkpoint_dir) != 1
        and finetuning_args.finetuning_type != "lora"
    ):
        raise ValueError("Only LoRA tuning accepts multiple checkpoints.")

    return model_args, data_args, finetuning_args, generating_args


def get_eval_args(args: Optional[Dict[str, Any]] = None) -> _EVAL_CLS:
    model_args, data_args, eval_args, finetuning_args = parse_eval_args(args)

    if data_args.template is None:
        raise ValueError("Please specify which `template` to use.")

    if model_args.quantization_bit is not None and finetuning_args.finetuning_type != "lora":
        raise ValueError("Quantization is only compatible with the LoRA method.")

    transformers.set_seed(eval_args.seed)

    return model_args, data_args, eval_args, finetuning_args
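A minimal sketch of driving the training parser programmatically rather than from the CLI (the keys below are hypothetical example values; they mirror the dataclass fields validated in get_train_args above):

from llmtuner.model import get_train_args

model_args, data_args, training_args, finetuning_args, generating_args = get_train_args(dict(
    stage="sft",
    model_name_or_path="meta-llama/Llama-2-7b-hf",
    dataset="alpaca_gpt4_en",
    template="default",
    finetuning_type="lora",
    lora_target="q_proj,v_proj",
    output_dir="outputs/sft-lora",
    do_train=True,
    fp16=True,
))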
src/llmtuner/model/utils.py (new file, 132 lines)
@@ -0,0 +1,132 @@
import torch
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple

from llmtuner.extras.constants import LAYERNORM_NAMES
from llmtuner.extras.logging import get_logger

if TYPE_CHECKING:
    from transformers.modeling_utils import PreTrainedModel
    from llmtuner.hparams import ModelArguments, DataArguments, FinetuningArguments


logger = get_logger(__name__)


def dispatch_model(model: "PreTrainedModel") -> "PreTrainedModel":
    r"""
    Dispatches a pre-trained model to GPUs with balanced memory.
    Borrowed from: https://github.com/huggingface/transformers/blob/v4.31.0/src/transformers/modeling_utils.py#L2803
    """
    if getattr(model, "is_loaded_in_8bit", False) or getattr(model, "is_loaded_in_4bit", False):  # do nothing
        return model

    if torch.cuda.device_count() > 1:
        from accelerate import dispatch_model
        from accelerate.utils import infer_auto_device_map, get_balanced_memory

        if model._no_split_modules is None:
            raise ValueError("The model class needs to implement the `_no_split_modules` attribute.")

        kwargs = {"dtype": model.dtype, "no_split_module_classes": model._no_split_modules}
        max_memory = get_balanced_memory(model, **kwargs)
        # Make sure tied weights are tied before creating the device map.
        model.tie_weights()
        device_map = infer_auto_device_map(model, max_memory=max_memory, **kwargs)
        return dispatch_model(model, device_map)
    else:
        return model.cuda()

def find_all_linear_modules(
    model: "PreTrainedModel",
    quantization_bit: Optional[int] = None
) -> List[str]:
    if quantization_bit is not None:
        import bitsandbytes as bnb
        linear_cls = bnb.nn.Linear4bit if quantization_bit == 4 else bnb.nn.Linear8bitLt
    else:
        linear_cls = torch.nn.Linear

    output_layer_names = ["lm_head"]
    if model.config.model_type == "chatglm":
        output_layer_names.append("output_layer")

    module_names = set()
    for name, module in model.named_modules():
        if (
            isinstance(module, linear_cls)
            and not any(output_layer in name for output_layer in output_layer_names)
        ):
            module_names.add(name.split(".")[-1])

    logger.info("Found linear modules: {}".format(",".join(module_names)))
    return list(module_names)


def generate_model_card(
    model_args: "ModelArguments",
    data_args: "DataArguments",
    finetuning_args: "FinetuningArguments"
) -> Dict[str, Any]:
    return {
        "tasks": "text-generation",
        "finetuned_from": model_args.model_name_or_path,
        "dataset": [dataset.strip() for dataset in data_args.dataset.split(",")],
        "tags": ["llama-factory"] + (["lora"] if finetuning_args.finetuning_type == "lora" else [])
    }

def prepare_model_for_training(
    model: "PreTrainedModel",
    finetuning_args: "FinetuningArguments",
    output_layer_name: Optional[str] = "lm_head",
    use_gradient_checkpointing: Optional[bool] = True,
    layernorm_names: Optional[Set[str]] = LAYERNORM_NAMES
) -> "PreTrainedModel":
    r"""
    Includes:
    (1) casting the layernorm weights to fp32
    (2) making the output embedding layer require grads
    (3) upcasting the lm_head outputs to fp32
    Inspired by: https://github.com/huggingface/peft/blob/v0.2.0/src/peft/utils/other.py#L33
    """
    if finetuning_args.upcast_layernorm:
        for name, param in model.named_parameters():
            if param.ndim == 1 and any(ln_name in name for ln_name in layernorm_names):
                param.data = param.data.to(torch.float32)
        logger.info("Upcasting layernorm weights to float32.")

    if finetuning_args.neft_alpha > 1e-6:
        def neftune_forward_hook(module: torch.nn.Module, args: Tuple[torch.Tensor], output: torch.Tensor):
            if module.training:
                dims = torch.tensor(output.size(1) * output.size(2))
                mag_norm = finetuning_args.neft_alpha / torch.sqrt(dims)
                output = output + torch.zeros_like(output).uniform_(-mag_norm, mag_norm)
            return output

        model.get_input_embeddings().register_forward_hook(neftune_forward_hook)
        logger.info("Using noisy embedding with alpha={:.2f}".format(finetuning_args.neft_alpha))

    if use_gradient_checkpointing:
        if hasattr(model, "enable_input_require_grads"):
            model.enable_input_require_grads()
        else:
            def make_inputs_require_grad(module: torch.nn.Module, args: Tuple[torch.Tensor], output: torch.Tensor):
                output.requires_grad_(True)

            model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)

        model.gradient_checkpointing_enable()
        model.config.use_cache = False  # turn off when gradient checkpointing is enabled
        logger.info("Gradient checkpointing enabled.")

    if finetuning_args.finetuning_type != "full" and hasattr(model, output_layer_name):
        output_layer = getattr(model, output_layer_name)
        if isinstance(output_layer, torch.nn.Linear):
            def fp32_forward_pre_hook(module: torch.nn.Module, args: Tuple[torch.Tensor]):
                return args[0].to(output_layer.weight.dtype)

            def fp32_forward_post_hook(module: torch.nn.Module, args: Tuple[torch.Tensor], output: torch.Tensor):
                return output.to(torch.float32)

            output_layer.register_forward_pre_hook(fp32_forward_pre_hook)
            output_layer.register_forward_hook(fp32_forward_post_hook)

    return model
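The NEFTune hook above adds uniform noise to the embedding output whose magnitude shrinks with the sequence length and hidden size (alpha / sqrt(seq_len * hidden_dim)). A standalone sketch of the same arithmetic with hypothetical shapes:

import torch

neft_alpha = 5.0
embeds = torch.randn(2, 128, 4096)      # (batch, seq_len, hidden_dim)
dims = embeds.size(1) * embeds.size(2)  # seq_len * hidden_dim
mag_norm = neft_alpha / dims ** 0.5
noisy = embeds + torch.zeros_like(embeds).uniform_(-mag_norm, mag_norm)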