Former-commit-id: 337d5f68b72230e545e7a94ca789187c7a2b7187
This commit is contained in:
@@ -1,5 +1,7 @@
|
||||
import os
|
||||
import math
|
||||
import torch
|
||||
from types import MethodType
|
||||
from typing import TYPE_CHECKING, Literal, Optional, Tuple
|
||||
|
||||
from transformers import (
|
||||
@@ -66,15 +68,56 @@ def load_model_and_tokenizer(
|
||||
**config_kwargs
|
||||
)
|
||||
|
||||
if model_args.checkpoint_dir is not None and finetuning_args.finetuning_type == "full":
|
||||
if finetuning_args.finetuning_type == "full" and model_args.checkpoint_dir is not None:
|
||||
model_to_load = model_args.checkpoint_dir[0]
|
||||
else:
|
||||
model_to_load = model_args.model_name_or_path
|
||||
|
||||
config = AutoConfig.from_pretrained(model_to_load, **config_kwargs)
|
||||
is_mergeable = True
|
||||
|
||||
if hasattr(config, "fp16") and hasattr(config, "bf16"): # fix Qwen config
|
||||
if model_args.compute_dtype == torch.bfloat16:
|
||||
setattr(config, "bf16", True)
|
||||
else:
|
||||
setattr(config, "fp16", True)
|
||||
|
||||
# Set RoPE scaling
|
||||
if model_args.rope_scaling is not None:
|
||||
if hasattr(config, "use_dynamic_ntk"): # for Qwen models
|
||||
if is_trainable:
|
||||
logger.warning("Qwen model does not support rope scaling in training.")
|
||||
else:
|
||||
setattr(config, "use_dynamic_ntk", True)
|
||||
setattr(config, "use_logn_attn", True)
|
||||
logger.info("Using dynamic NTK scaling.")
|
||||
|
||||
elif hasattr(config, "rope_scaling"): # for LLaMA models
|
||||
if is_trainable:
|
||||
if model_args.rope_scaling == "dynamic":
|
||||
logger.warning(
|
||||
"Dynamic NTK may not work well with fine-tuning. "
|
||||
"See: https://github.com/huggingface/transformers/pull/24653"
|
||||
)
|
||||
|
||||
current_max_length = getattr(config, "max_position_embeddings", None)
|
||||
if current_max_length and model_args.model_max_length <= current_max_length:
|
||||
logger.warning("Input length is smaller than max length. Consider increase input length.")
|
||||
scaling_factor = 1.0
|
||||
else:
|
||||
scaling_factor = float(math.ceil(model_args.model_max_length / current_max_length))
|
||||
else:
|
||||
scaling_factor = 2.0
|
||||
|
||||
setattr(config, "rope_scaling", {"type": model_args.rope_scaling, "factor": scaling_factor})
|
||||
logger.info("Using {} scaling strategy and setting scaling factor to {}".format(
|
||||
model_args.rope_scaling, scaling_factor
|
||||
))
|
||||
|
||||
else:
|
||||
logger.warning("Current model does not support RoPE scaling.")
|
||||
|
||||
# Quantization configurations (using bitsandbytes library).
|
||||
is_mergeable = True
|
||||
if model_args.quantization_bit is not None:
|
||||
if model_args.quantization_bit == 8:
|
||||
require_version("bitsandbytes>=0.37.0", "To fix: pip install bitsandbytes>=0.37.0")
|
||||
@@ -95,7 +138,7 @@ def load_model_and_tokenizer(
|
||||
config_kwargs["device_map"] = {"": int(os.environ.get("LOCAL_RANK", "0"))} if is_trainable else "auto"
|
||||
logger.info("Quantizing model to {} bit.".format(model_args.quantization_bit))
|
||||
|
||||
# Load and prepare pretrained models (without valuehead).
|
||||
# Load and prepare pre-trained models (without valuehead).
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
model_to_load,
|
||||
config=config,
|
||||
@@ -104,6 +147,10 @@ def load_model_and_tokenizer(
|
||||
**config_kwargs
|
||||
)
|
||||
|
||||
# Disable custom generate method (for Qwen)
|
||||
if "GenerationMixin" not in str(model.generate.__func__):
|
||||
model.generate = MethodType(PreTrainedModel.generate, model)
|
||||
|
||||
# Register auto class to save the custom code files.
|
||||
if isinstance(config, PretrainedConfig) and "AutoConfig" in getattr(config, "auto_map", {}):
|
||||
config.__class__.register_for_auto_class()
|
||||
@@ -116,10 +163,10 @@ def load_model_and_tokenizer(
|
||||
model = prepare_model_for_training(model, finetuning_args.finetuning_type) if is_trainable else model
|
||||
model = init_adapter(model, model_args, finetuning_args, is_trainable, is_mergeable)
|
||||
|
||||
if stage == "rm" or stage == "ppo": # add value head
|
||||
model: "AutoModelForCausalLMWithValueHead" = AutoModelForCausalLMWithValueHead.from_pretrained(model)
|
||||
# Prepare model with valuehead for RLHF
|
||||
if stage == "rm" or stage == "ppo":
|
||||
model: AutoModelForCausalLMWithValueHead = AutoModelForCausalLMWithValueHead.from_pretrained(model)
|
||||
reset_logging()
|
||||
|
||||
if stage == "rm" and model_args.checkpoint_dir is not None: # load valuehead weights to evaluate reward model
|
||||
logger.warning("Only the last checkpoint containing valuehead will be loaded as the valuehead.")
|
||||
if load_valuehead_params(model, model_args.checkpoint_dir[-1]):
|
||||
@@ -133,9 +180,11 @@ def load_model_and_tokenizer(
|
||||
model.pretrained_model.load_adapter(model_args.reward_model, "reward", is_trainable=False)
|
||||
assert load_valuehead_params(model, model_args.reward_model), "Reward model is not correctly loaded."
|
||||
|
||||
# Prepare model for inference
|
||||
if not is_trainable:
|
||||
model.requires_grad_(False) # fix all model params
|
||||
model = model.half() if model_args.quantization_bit is None else model # cast from fp32 to fp16
|
||||
infer_dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16 # detect cuda capability
|
||||
model = model.to(infer_dtype) if model_args.quantization_bit is None else model
|
||||
|
||||
trainable_params, all_param = count_parameters(model)
|
||||
logger.info("trainable params: {:d} || all params: {:d} || trainable%: {:.4f}".format(
|
||||
|
||||
@@ -151,13 +151,16 @@ def get_train_args(
|
||||
|
||||
training_args.optim = "adamw_torch" if training_args.optim == "adamw_hf" else training_args.optim # suppress warning
|
||||
|
||||
if model_args.quantization_bit is not None:
|
||||
if training_args.fp16:
|
||||
model_args.compute_dtype = torch.float16
|
||||
elif training_args.bf16:
|
||||
model_args.compute_dtype = torch.bfloat16
|
||||
else:
|
||||
model_args.compute_dtype = torch.float32
|
||||
if training_args.fp16:
|
||||
model_args.compute_dtype = torch.float16
|
||||
elif training_args.bf16:
|
||||
if not torch.cuda.is_bf16_supported():
|
||||
raise ValueError("Current device does not support bf16 training.")
|
||||
model_args.compute_dtype = torch.bfloat16
|
||||
else:
|
||||
model_args.compute_dtype = torch.float32
|
||||
|
||||
model_args.model_max_length = data_args.max_source_length + data_args.max_target_length
|
||||
|
||||
# Log on each process the small summary:
|
||||
logger.info("Process rank: {}, device: {}, n_gpu: {}\n distributed training: {}, 16-bits training: {}".format(
|
||||
|
||||
@@ -35,7 +35,7 @@ def run_exp(args: Optional[Dict[str, Any]] = None, callbacks: Optional[List["Tra
|
||||
|
||||
|
||||
def export_model(args: Optional[Dict[str, Any]] = None, max_shard_size: Optional[str] = "10GB"):
|
||||
model_args, _, training_args, finetuning_args, _ = get_train_args(args)
|
||||
model_args, _, training_args, finetuning_args, _, _ = get_train_args(args)
|
||||
model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args)
|
||||
model.save_pretrained(training_args.output_dir, max_shard_size=max_shard_size)
|
||||
try:
|
||||
|
||||
Reference in New Issue
Block a user