refactor adapter hparam

Former-commit-id: f82aece9ebd6df83a7a005cc7cbbcec07fa6e14d
Author: hiyouga
Date: 2023-12-15 20:53:11 +08:00
parent 27ef5b1aa7
commit f902b0d420
21 changed files with 302 additions and 311 deletions

View File: llmtuner/model/adapter.py

@@ -27,8 +27,8 @@ def init_adapter(
     Note that the trainable parameters must be cast to float32.
     """
-    if (not is_trainable) and model_args.checkpoint_dir is None:
-        logger.info("Checkpoint is not found at evaluation, load the original model.")
+    if (not is_trainable) and model_args.adapter_name_or_path is None:
+        logger.info("Adapter is not found at evaluation, load the base model.")
         return model

     if finetuning_args.finetuning_type == "full" and is_trainable:
@@ -44,6 +44,7 @@ def init_adapter(
         )
         if not num_layers:
             raise ValueError("Current model does not support freeze tuning.")
+
         if finetuning_args.num_layer_trainable > 0:  # fine-tuning the last n layers if num_layer_trainable > 0
             trainable_layer_ids = [num_layers - k - 1 for k in range(finetuning_args.num_layer_trainable)]
         else:  # fine-tuning the first n layers if num_layer_trainable < 0
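
Note: the sign of `num_layer_trainable` selects which end of the network is unfrozen — positive values fine-tune the last n layers, negative values the first n. A minimal standalone sketch with assumed example values (not part of the commit):

    num_layers = 32           # assumed model depth
    num_layer_trainable = 3   # positive: fine-tune the last 3 layers
    if num_layer_trainable > 0:
        trainable_layer_ids = [num_layers - k - 1 for k in range(num_layer_trainable)]  # [31, 30, 29]
    else:
        trainable_layer_ids = [k for k in range(-num_layer_trainable)]  # -3 would select [0, 1, 2]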
@@ -62,30 +63,31 @@ def init_adapter(
     if finetuning_args.finetuning_type == "lora":
         logger.info("Fine-tuning method: LoRA")
-        checkpoint_to_resume = None
+        adapter_to_resume = None

-        if model_args.checkpoint_dir is not None:
+        if model_args.adapter_name_or_path is not None:
             is_mergeable = True
             if getattr(model, "quantization_method", None):  # merge lora in quantized model is unstable
-                assert len(model_args.checkpoint_dir) == 1, "Quantized model only accepts a single checkpoint."
+                assert len(model_args.adapter_name_or_path) == 1, "Quantized model only accepts a single adapter."
                 is_mergeable = False

-            if (is_trainable and finetuning_args.resume_lora_training) or (not is_mergeable):
-                checkpoints_to_merge, checkpoint_to_resume = model_args.checkpoint_dir[:-1], model_args.checkpoint_dir[-1]
+            if (is_trainable and not finetuning_args.create_new_adapter) or (not is_mergeable):
+                adapter_to_merge = model_args.adapter_name_or_path[:-1]
+                adapter_to_resume = model_args.adapter_name_or_path[-1]
             else:
-                checkpoints_to_merge = model_args.checkpoint_dir
+                adapter_to_merge = model_args.adapter_name_or_path

-            for checkpoint in checkpoints_to_merge:
-                model = PeftModel.from_pretrained(model, checkpoint)
+            for adapter in adapter_to_merge:
+                model = PeftModel.from_pretrained(model, adapter)
                 model = model.merge_and_unload()

-            if len(checkpoints_to_merge) > 0:
-                logger.info("Merged {} model checkpoint(s).".format(len(checkpoints_to_merge)))
+            if len(adapter_to_merge) > 0:
+                logger.info("Merged {} adapter(s).".format(len(adapter_to_merge)))

-            if checkpoint_to_resume is not None:  # resume lora training
-                model = PeftModel.from_pretrained(model, checkpoint_to_resume, is_trainable=is_trainable)
+            if adapter_to_resume is not None:  # resume lora training
+                model = PeftModel.from_pretrained(model, adapter_to_resume, is_trainable=is_trainable)

-        if is_trainable and checkpoint_to_resume is None:  # create new lora weights while training
+        if is_trainable and adapter_to_resume is None:  # create new lora weights while training
             if len(finetuning_args.lora_target) == 1 and finetuning_args.lora_target[0] == "all":
                 target_modules = find_all_linear_modules(model)
             else:
@@ -105,7 +107,7 @@ def init_adapter(
         for param in filter(lambda p: p.requires_grad, model.parameters()):
             param.data = param.data.to(torch.float32)

-    if model_args.checkpoint_dir is not None:
-        logger.info("Loaded fine-tuned model from checkpoint(s): {}".format(",".join(model_args.checkpoint_dir)))
+    if model_args.adapter_name_or_path is not None:
+        logger.info("Loaded adapter(s): {}".format(",".join(model_args.adapter_name_or_path)))

     return model
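
Note: under the new `adapter_name_or_path` semantics, when training a mergeable (non-quantized) model every adapter except the last is merged into the base weights, and the last one is loaded as the trainable adapter unless `create_new_adapter` is set. A sketch with hypothetical adapter paths (not taken from this commit):

    adapter_name_or_path = ["saves/lora_stage1", "saves/lora_stage2", "saves/lora_stage3"]  # hypothetical
    create_new_adapter = False  # default: resume the last adapter

    if create_new_adapter:
        adapter_to_merge, adapter_to_resume = adapter_name_or_path, None  # merge all, then init a fresh adapter
    else:
        adapter_to_merge = adapter_name_or_path[:-1]  # the first two are merged into the base weights
        adapter_to_resume = adapter_name_or_path[-1]  # "saves/lora_stage3" continues training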

View File: llmtuner/model/loader.py

@@ -1,41 +1,26 @@
-import math
-import torch
-from types import MethodType
 from typing import TYPE_CHECKING, Optional, Tuple

-from transformers import (
-    AutoConfig,
-    AutoModelForCausalLM,
-    AutoTokenizer,
-    BitsAndBytesConfig,
-    PretrainedConfig,
-    PreTrainedModel,
-    PreTrainedTokenizerBase
-)
+from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
+from transformers.integrations import is_deepspeed_zero3_enabled
 from transformers.utils.versions import require_version
 from trl import AutoModelForCausalLMWithValueHead

-try:
-    from transformers.integrations import is_deepspeed_zero3_enabled
-except ImportError:  # https://github.com/huggingface/transformers/releases/tag/v4.33.1
-    from transformers.deepspeed import is_deepspeed_zero3_enabled
-
 from llmtuner.extras.logging import get_logger
-from llmtuner.extras.misc import count_parameters, get_current_device, infer_optim_dtype, try_download_model_from_ms
+from llmtuner.extras.misc import count_parameters, get_current_device, try_download_model_from_ms
 from llmtuner.extras.packages import is_flash_attn2_available
 from llmtuner.hparams import FinetuningArguments
 from llmtuner.model.adapter import init_adapter
+from llmtuner.model.patches import patch_config, patch_model, patch_valuehead_model, patch_tokenizer, register_autoclass
 from llmtuner.model.utils import load_valuehead_params, prepare_model_for_training, resize_embedding_layer

 if TYPE_CHECKING:
-    from transformers import PreTrainedTokenizer
+    from transformers import PreTrainedModel, PreTrainedTokenizer
     from llmtuner.hparams import ModelArguments

 logger = get_logger(__name__)

-require_version("transformers>=4.36.0", "To fix: pip install transformers>=4.36.0")
+require_version("transformers>=4.36.1", "To fix: pip install transformers>=4.36.1")
 require_version("datasets>=2.14.3", "To fix: pip install datasets>=2.14.3")
 require_version("accelerate>=0.21.0", "To fix: pip install accelerate>=0.21.0")
 require_version("peft>=0.7.0", "To fix: pip install peft>=0.7.0")
@@ -47,7 +32,7 @@ def load_model_and_tokenizer(
     finetuning_args: "FinetuningArguments",
     is_trainable: Optional[bool] = False,
     add_valuehead: Optional[bool] = False
-) -> Tuple[PreTrainedModel, "PreTrainedTokenizer"]:
+) -> Tuple["PreTrainedModel", "PreTrainedTokenizer"]:
     r"""
     Loads pretrained model and tokenizer.
@@ -70,73 +55,15 @@ def load_model_and_tokenizer(
         padding_side="right",  # training with left-padded tensors in fp16 precision may cause overflow
         **config_kwargs
     )
+    patch_tokenizer(tokenizer)

-    if finetuning_args.finetuning_type != "lora" and model_args.checkpoint_dir is not None:
-        logger.info("Use `model_name_or_path` to specify the model trained with full/freeze method.")
-        model_to_load = model_args.checkpoint_dir[0]
-    else:
-        model_to_load = model_args.model_name_or_path
-
-    config = AutoConfig.from_pretrained(model_to_load, **config_kwargs)
-
-    # Fix tokenizer (for ChatGLM2 and ChatGLM3)
-    if getattr(config, "model_type", None) == "chatglm":
-        tokenizer._pad = MethodType(PreTrainedTokenizerBase._pad, tokenizer)
-
-    # Set model dtype
-    if model_args.compute_dtype is None:  # priority: bf16 > fp16 > fp32
-        model_args.compute_dtype = infer_optim_dtype(model_dtype=getattr(config, "torch_dtype", None))
-    setattr(config, "torch_dtype", model_args.compute_dtype)
-
-    # Fix config (for Qwen)
-    if getattr(config, "model_type", None) == "qwen":
-        for dtype_name, dtype in [("fp16", torch.float16), ("bf16", torch.bfloat16), ("fp32", torch.float32)]:
-            setattr(config, dtype_name, getattr(config, "torch_dtype", None) == dtype)
-
-    # Set RoPE scaling
-    if model_args.rope_scaling is not None:
-        if not hasattr(config, "rope_scaling"):
-            logger.warning("Current model does not support RoPE scaling.")
-        else:
-            if is_trainable:
-                if model_args.rope_scaling == "dynamic":
-                    logger.warning(
-                        "Dynamic NTK may not work well with fine-tuning. "
-                        "See: https://github.com/huggingface/transformers/pull/24653"
-                    )
-                current_max_length = getattr(config, "max_position_embeddings", None)
-                if current_max_length and model_args.model_max_length > current_max_length:
-                    scaling_factor = float(math.ceil(model_args.model_max_length / current_max_length))
-                else:
-                    logger.warning("Input length is smaller than max length. Consider increase input length.")
-                    scaling_factor = 1.0
-            else:
-                scaling_factor = 2.0
-
-            setattr(config, "rope_scaling", {"type": model_args.rope_scaling, "factor": scaling_factor})
-            logger.info("Using {} scaling strategy and setting scaling factor to {}".format(
-                model_args.rope_scaling, scaling_factor
-            ))
-
-    # Set shift short attention (S^2-Attn)
-    if is_trainable and model_args.shift_attn:
-        logger.warning("Shift short attention is temporarily invalid due to breaking changes.")
-        # if getattr(config, "model_type", None) == "llama":
-        #     setattr(config, "group_size_ratio", 0.25)
-        #     logger.info("Using shift short attention with group_size_ratio=1/4.")
-        # else:
-        #     logger.warning("Current model does not support shift short attention.")
+    config = AutoConfig.from_pretrained(model_args.model_name_or_path, **config_kwargs)
+    patch_config(config, model_args, is_trainable)

     # Set FlashAttention-2
-    if model_args.flash_attn:
-        if not is_flash_attn2_available():
-            logger.warning("FlashAttention-2 is not installed.")
-        elif getattr(config, "model_type", None) == "qwen":
-            logger.info("Current model automatically enables FlashAttention if installed.")
-        else:
-            config_kwargs["use_flash_attention_2"] = True
-            logger.info("Using FlashAttention-2 for faster training and inference.")
+    if model_args.flash_attn and is_flash_attn2_available():
+        config_kwargs["use_flash_attention_2"] = True
+        logger.info("Using FlashAttention-2 for faster training and inference.")

     # Quantization configurations (using gptq or awq)
     if getattr(config, "quantization_config", None):
@@ -168,33 +95,16 @@ def load_model_and_tokenizer(
     # Load pre-trained models (without valuehead)
     model = AutoModelForCausalLM.from_pretrained(
-        model_to_load,
+        model_args.model_name_or_path,
         config=config,
         torch_dtype=model_args.compute_dtype,
         low_cpu_mem_usage=(not is_deepspeed_zero3_enabled()),
         **config_kwargs
     )
-
-    # Resize token embeddings
+    patch_model(model)
+    register_autoclass(config, model, tokenizer)
     resize_embedding_layer(model, tokenizer)
-
-    # Disable custom generate method (for Qwen and Baichuan2)
-    if isinstance(model, PreTrainedModel) and "GenerationMixin" not in str(model.generate.__func__):
-        model.generate = MethodType(PreTrainedModel.generate, model)
-
-    # Fix LM head (for ChatGLM2 and ChatGLM3)
-    if getattr(config, "model_type", None) == "chatglm":
-        setattr(model, "lm_head", model.transformer.output_layer)
-        setattr(model, "_keys_to_ignore_on_save", ["lm_head.weight"])
-
-    # Register auto class to save the custom code files
-    if isinstance(config, PretrainedConfig) and "AutoConfig" in getattr(config, "auto_map", {}):
-        config.__class__.register_for_auto_class()
-    if isinstance(model, PreTrainedModel) and "AutoModelForCausalLM" in getattr(config, "auto_map", {}):
-        model.__class__.register_for_auto_class()
-    if isinstance(tokenizer, PreTrainedTokenizerBase) and "AutoTokenizer" in tokenizer.init_kwargs.get("auto_map", {}):
-        tokenizer.__class__.register_for_auto_class()

     # Initialize adapters
     model = prepare_model_for_training(model=model, finetuning_args=finetuning_args) if is_trainable else model
     model = init_adapter(model, model_args, finetuning_args, is_trainable)
@@ -202,19 +112,10 @@ def load_model_and_tokenizer(
     # Prepare model with valuehead for RLHF
     if add_valuehead:
         model: "AutoModelForCausalLMWithValueHead" = AutoModelForCausalLMWithValueHead.from_pretrained(model)
-
-        def get_input_embeddings(self: "AutoModelForCausalLMWithValueHead") -> torch.nn.Module:
-            return self.pretrained_model.get_input_embeddings()
-
-        setattr(model, "get_input_embeddings", MethodType(get_input_embeddings, model))
-        ignore_modules = [name for name, _ in model.named_parameters() if "pretrained_model" in name]
-        setattr(model, "_keys_to_ignore_on_save", ignore_modules)
-        setattr(model, "tie_weights", MethodType(lambda _: None, model))  # use empty method
-        vhead_path = (
-            model_args.checkpoint_dir[-1] if model_args.checkpoint_dir is not None else model_args.model_name_or_path
-        )
-        vhead_params = load_valuehead_params(vhead_path, model_args)
+        patch_valuehead_model(model)
+        vhead_params = load_valuehead_params(model_args)
         if vhead_params is not None:
             model.load_state_dict(vhead_params, strict=False)
-            logger.info("Loaded valuehead from checkpoint: {}".format(vhead_path))

     # Prepare model for inference
     if not is_trainable:
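
Note: condensed, the refactored loader now reads as a short pipeline. A sketch assembled from the hunks above (argument lists abbreviated; not the verbatim file):

    config = AutoConfig.from_pretrained(model_args.model_name_or_path, **config_kwargs)
    patch_config(config, model_args, is_trainable)  # dtype, Qwen, RoPE and S^2-Attn fixes
    model = AutoModelForCausalLM.from_pretrained(model_args.model_name_or_path, config=config, **config_kwargs)
    patch_model(model)  # generate() and ChatGLM lm_head fixes
    register_autoclass(config, model, tokenizer)
    resize_embedding_layer(model, tokenizer)
    model = init_adapter(model, model_args, finetuning_args, is_trainable)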

View File: llmtuner/model/parser.py

@@ -41,15 +41,19 @@ _EVAL_CLS = Tuple[
 def _verify_model_args(model_args: "ModelArguments", finetuning_args: "FinetuningArguments") -> None:
-    if model_args.quantization_bit is not None and finetuning_args.finetuning_type != "lora":
-        raise ValueError("Quantization is only compatible with the LoRA method.")
-
-    if model_args.checkpoint_dir is not None and len(model_args.checkpoint_dir) != 1:
+    if model_args.quantization_bit is not None:
         if finetuning_args.finetuning_type != "lora":
-            raise ValueError("Multiple checkpoints are only available for LoRA tuning.")
+            raise ValueError("Quantization is only compatible with the LoRA method.")
+
+        if finetuning_args.create_new_adapter:
+            raise ValueError("Cannot create new adapter upon a quantized model.")
+
+    if model_args.adapter_name_or_path is not None and len(model_args.adapter_name_or_path) != 1:
+        if finetuning_args.finetuning_type != "lora":
+            raise ValueError("Multiple adapters are only available for LoRA tuning.")
+
         if model_args.quantization_bit is not None:
-            raise ValueError("Quantized model only accepts a single checkpoint. Merge them first.")
+            raise ValueError("Quantized model only accepts a single adapter. Merge them first.")


 def parse_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
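
Note: the restructured checks make the quantization constraints explicit — a quantized model must use LoRA, cannot create a new adapter, and accepts only a single adapter. A quick illustration with hypothetical argument values:

    model_args.quantization_bit = 4  # hypothetical: QLoRA-style 4-bit quantization
    finetuning_args.finetuning_type = "lora"
    finetuning_args.create_new_adapter = True
    _verify_model_args(model_args, finetuning_args)
    # raises ValueError: "Cannot create new adapter upon a quantized model."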
@@ -139,11 +143,17 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
         training_args_dict.update(dict(ddp_find_unused_parameters=False))
         training_args = Seq2SeqTrainingArguments(**training_args_dict)

+    if finetuning_args.stage in ["rm", "ppo"] and finetuning_args.finetuning_type in ["full", "freeze"]:
+        can_resume_from_checkpoint = False
+    else:
+        can_resume_from_checkpoint = True
+
     if (
         training_args.resume_from_checkpoint is None
         and training_args.do_train
         and os.path.isdir(training_args.output_dir)
         and not training_args.overwrite_output_dir
+        and can_resume_from_checkpoint
     ):
         last_checkpoint = get_last_checkpoint(training_args.output_dir)
         if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0:
@@ -158,7 +168,7 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
         ))

     if finetuning_args.stage in ["rm", "ppo"] and training_args.resume_from_checkpoint is not None:
-        logger.warning("Add {} to `checkpoint_dir` to resume training from checkpoint.".format(
+        logger.warning("Add {} to `adapter_name_or_path` to resume training from checkpoint.".format(
             training_args.resume_from_checkpoint
         ))
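
Note: per the adjusted warning, rm/ppo runs are resumed by passing the saved adapter explicitly instead of `resume_from_checkpoint`. A hypothetical `get_train_args` input (model and paths assumed; tuple layout per `_TRAIN_CLS` above):

    args = {
        "stage": "ppo",
        "model_name_or_path": "meta-llama/Llama-2-7b-hf",     # assumed base model
        "finetuning_type": "lora",
        "adapter_name_or_path": "saves/ppo/checkpoint-1000",  # instead of resume_from_checkpoint
        "output_dir": "saves/ppo_resumed",
        "do_train": True,
    }
    model_args, data_args, training_args, finetuning_args, generating_args = get_train_args(args)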

View File: llmtuner/model/patches.py (new file)

@@ -0,0 +1,94 @@
+import math
+import torch
+from types import MethodType
+from typing import TYPE_CHECKING
+
+from transformers import PreTrainedModel, PreTrainedTokenizerBase
+
+from llmtuner.extras.logging import get_logger
+from llmtuner.extras.misc import infer_optim_dtype
+
+if TYPE_CHECKING:
+    from transformers import PretrainedConfig, PreTrainedTokenizer
+    from trl import AutoModelForCausalLMWithValueHead
+    from llmtuner.hparams import ModelArguments
+
+
+logger = get_logger(__name__)
+
+
+def patch_config(config: "PretrainedConfig", model_args: "ModelArguments", is_trainable: bool):
+    if model_args.compute_dtype is None:  # priority: bf16 > fp16 > fp32
+        model_args.compute_dtype = infer_optim_dtype(model_dtype=getattr(config, "torch_dtype", None))
+    setattr(config, "torch_dtype", model_args.compute_dtype)
+
+    if getattr(config, "model_type", None) == "qwen":
+        for dtype_name, dtype in [("fp16", torch.float16), ("bf16", torch.bfloat16), ("fp32", torch.float32)]:
+            setattr(config, dtype_name, getattr(config, "torch_dtype", None) == dtype)
+
+    if model_args.rope_scaling is not None:
+        if not hasattr(config, "rope_scaling"):
+            logger.warning("Current model does not support RoPE scaling.")
+        else:
+            if is_trainable:
+                if model_args.rope_scaling == "dynamic":
+                    logger.warning(
+                        "Dynamic NTK may not work well with fine-tuning. "
+                        "See: https://github.com/huggingface/transformers/pull/24653"
+                    )
+
+                current_max_length = getattr(config, "max_position_embeddings", None)
+                if current_max_length and model_args.model_max_length > current_max_length:
+                    scaling_factor = float(math.ceil(model_args.model_max_length / current_max_length))
+                else:
+                    logger.warning("Input length is smaller than max length. Consider increase input length.")
+                    scaling_factor = 1.0
+            else:
+                scaling_factor = 2.0
+
+            setattr(config, "rope_scaling", {"type": model_args.rope_scaling, "factor": scaling_factor})
+            logger.info("Using {} scaling strategy and setting scaling factor to {}".format(
+                model_args.rope_scaling, scaling_factor
+            ))
+
+    # Set shift short attention (S^2-Attn)
+    if is_trainable and model_args.shift_attn:
+        logger.warning("Shift short attention is temporarily invalid due to breaking changes.")
+        # if getattr(config, "model_type", None) == "llama":
+        #     setattr(config, "group_size_ratio", 0.25)
+        #     logger.info("Using shift short attention with group_size_ratio=1/4.")
+        # else:
+        #     logger.warning("Current model does not support shift short attention.")
+
+
+def patch_model(model: "PreTrainedModel"):
+    if "GenerationMixin" not in str(model.generate.__func__):
+        model.generate = MethodType(PreTrainedModel.generate, model)
+
+    if getattr(model.config, "model_type", None) == "chatglm":
+        setattr(model, "lm_head", model.transformer.output_layer)
+        setattr(model, "_keys_to_ignore_on_save", ["lm_head.weight"])
+
+
+def patch_valuehead_model(model: "AutoModelForCausalLMWithValueHead"):
+    def get_input_embeddings(self: "AutoModelForCausalLMWithValueHead") -> torch.nn.Module:
+        return self.pretrained_model.get_input_embeddings()
+
+    setattr(model, "get_input_embeddings", MethodType(get_input_embeddings, model))
+    ignore_modules = [name for name, _ in model.named_parameters() if "pretrained_model" in name]
+    setattr(model, "_keys_to_ignore_on_save", ignore_modules)
+    setattr(model, "tie_weights", MethodType(lambda _: None, model))  # use empty method
+
+
+def patch_tokenizer(tokenizer: "PreTrainedTokenizer"):
+    if "PreTrainedTokenizerBase" not in str(tokenizer._pad.__func__):
+        tokenizer._pad = MethodType(PreTrainedTokenizerBase._pad, tokenizer)
+
+
+def register_autoclass(config: "PretrainedConfig", model: "PreTrainedModel", tokenizer: "PreTrainedTokenizerBase"):
+    if "AutoConfig" in getattr(config, "auto_map", {}):
+        config.__class__.register_for_auto_class()
+    if "AutoModelForCausalLM" in getattr(config, "auto_map", {}):
+        model.__class__.register_for_auto_class()
+    if "AutoTokenizer" in tokenizer.init_kwargs.get("auto_map", {}):
+        tokenizer.__class__.register_for_auto_class()
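
Note: the linear RoPE branch in `patch_config` rounds the scaling factor up so the scaled context covers the requested length. A worked example with assumed numbers:

    import math

    model_max_length = 8192          # assumed training cutoff
    max_position_embeddings = 4096   # assumed pretrained context window
    scaling_factor = float(math.ceil(model_max_length / max_position_embeddings))  # 2.0
    # patch_config would then set: config.rope_scaling = {"type": "linear", "factor": 2.0}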

View File: llmtuner/model/utils.py

@@ -1,5 +1,4 @@
 import torch
-import inspect
 from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple

 from transformers.utils import cached_file
@@ -86,29 +85,26 @@ def get_modelcard_args(
     }


-def load_valuehead_params(
-    path_or_repo_id: str,
-    model_args: "ModelArguments"
-) -> Dict[str, torch.Tensor]:
+def load_valuehead_params(model_args: "ModelArguments") -> Dict[str, torch.Tensor]:
     r"""
     Loads value head parameters from Hugging Face Hub or local disk.

     Returns: dict with keys `v_head.summary.weight` and `v_head.summary.bias`.
     """
+    if model_args.adapter_name_or_path is not None:
+        path_or_repo_id = model_args.adapter_name_or_path[-1]
+    else:
+        path_or_repo_id = model_args.model_name_or_path
+
     kwargs = {
         "path_or_repo_id": path_or_repo_id,
-        "cache_dir": model_args.cache_dir
+        "cache_dir": model_args.cache_dir,
+        "token": model_args.hf_hub_token
     }

-    if "token" in inspect.signature(cached_file).parameters:
-        kwargs["token"] = model_args.hf_hub_token
-    elif "use_auth_token" in inspect.signature(cached_file).parameters:  # for transformers==4.31.0
-        kwargs["use_auth_token"] = model_args.hf_hub_token
-    else:
-        logger.warning("Ignore `hf_hub_token` since matched parameter is not found.")
-
     try:
         vhead_file = cached_file(filename=WEIGHTS_NAME, **kwargs)
+        logger.info("Loaded valuehead from {}".format(path_or_repo_id))
         return torch.load(vhead_file, map_location="cpu")
     except Exception as err:
         logger.info("Failed to load {}: {}".format(WEIGHTS_NAME, str(err)))
@@ -116,6 +112,7 @@ def load_valuehead_params(
     try:
         from safetensors import safe_open
         vhead_file = cached_file(filename=SAFE_WEIGHTS_NAME, **kwargs)
+        logger.info("Loaded valuehead from {}".format(path_or_repo_id))
         with safe_open(vhead_file, framework="pt", device="cpu") as f:
             return {
                 "v_head.summary.weight": f.get_tensor("v_head.summary.weight"),