Former-commit-id: 9a30ee5009040afbc524dbac0dad99904b2adf5f
hiyouga committed 2023-09-12 16:10:10 +08:00
parent 8b0e6b9d1b
commit e19a44c12b
8 changed files with 26 additions and 23 deletions

View File

@@ -92,6 +92,8 @@ def init_adapter(
                 target_modules=target_modules
             )
             model = get_peft_model(model, lora_config)
+            if id(model.peft_config) != id(model.base_model.peft_config): # https://github.com/huggingface/peft/issues/923
+                model.base_model.peft_config = model.peft_config

     if model_args.checkpoint_dir is not None:
         logger.info("Loaded fine-tuned model from checkpoint(s): {}".format(",".join(model_args.checkpoint_dir)))

View File

@@ -4,7 +4,6 @@ import torch
 from types import MethodType
 from typing import TYPE_CHECKING, Literal, Optional, Tuple
-import transformers

 from transformers import (
     AutoConfig,
     AutoModelForCausalLM,
@@ -41,7 +40,7 @@ logger = get_logger(__name__)
 check_min_version("4.30.0")
 require_version("datasets>=2.12.0", "To fix: pip install datasets>=2.12.0")
 require_version("accelerate>=0.21.0", "To fix: pip install accelerate>=0.21.0")
-require_version("peft==0.4.0", "To fix: pip install peft==0.4.0")
+require_version("peft>=0.4.0", "To fix: pip install peft>=0.4.0")
 require_version("trl>=0.7.1", "To fix: pip install trl>=0.7.1")
@@ -133,11 +132,11 @@ def load_model_and_tokenizer(
     # Set flash attention
     if model_args.flash_attn and getattr(config, "model_type", None) == "llama":
         import transformers.models.llama.modeling_llama as LlamaModule
-        from llmtuner.extras.models.flash_llama import LlamaRMSNorm, LlamaAttention, _prepare_decoder_attention_mask
-        LlamaModule.LlamaRMSNorm = LlamaRMSNorm
-        LlamaModule.LlamaAttention = LlamaAttention
-        LlamaModule.LlamaModel._prepare_decoder_attention_mask = _prepare_decoder_attention_mask
-        if not hasattr(config, "num_key_value_heads"):
+        import llmtuner.extras.patches.flash_llama as FlashLlama
+        LlamaModule.LlamaRMSNorm = FlashLlama.LlamaRMSNorm
+        LlamaModule.LlamaAttention = FlashLlama.LlamaAttention
+        LlamaModule.LlamaModel._prepare_decoder_attention_mask = FlashLlama._prepare_decoder_attention_mask
+        if not hasattr(config, "num_key_value_heads"): # for LLaMA-1 models
             setattr(config, "num_key_value_heads", getattr(config, "num_attention_heads"))
         if getattr(config, "pretraining_tp", 1) != 1:
             setattr(config, "pretraining_tp", 1)
@@ -199,11 +198,11 @@ def load_model_and_tokenizer(
     # Prepare model with valuehead for RLHF
     if stage == "rm" or stage == "ppo":
-        model: AutoModelForCausalLMWithValueHead = AutoModelForCausalLMWithValueHead.from_pretrained(model)
+        model = AutoModelForCausalLMWithValueHead.from_pretrained(model)
         model._keys_to_ignore_on_save = None
         reset_logging()
         if stage == "rm" and model_args.checkpoint_dir is not None: # load valuehead weights to evaluate reward model
-            logger.warning("Only the last checkpoint containing valuehead will be loaded as the valuehead.")
+            logger.warning("Only the last checkpoint containing valuehead will be loaded.")
             if load_valuehead_params(model, model_args.checkpoint_dir[-1]):
                 model.v_head.load_state_dict({
                     "summary.weight": getattr(model, "reward_head_weight"),
@@ -212,7 +211,8 @@ def load_model_and_tokenizer(
if stage == "ppo": # load reward model
logger.info("Load reward model from {}".format(model_args.reward_model))
model.pretrained_model.load_adapter(model_args.reward_model, "reward", is_trainable=False)
if getattr(model, "is_peft_model", False):
model.pretrained_model.load_adapter(model_args.reward_model, "reward")
assert load_valuehead_params(model, model_args.reward_model), "Reward model is not correctly loaded."
# Prepare model for inference
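The PPO branch now attaches the reward adapter only when the policy is a PeftModel, and the assert still confirms the value-head weights were loaded. A hedged sketch of the multi-adapter pattern this relies on (paths and model name are placeholders):

from transformers import AutoModelForCausalLM
from peft import PeftModel

base = AutoModelForCausalLM.from_pretrained("facebook/opt-125m")    # placeholder model
policy = PeftModel.from_pretrained(base, "path/to/policy_lora")     # "default" adapter
policy.load_adapter("path/to/reward_lora", adapter_name="reward")   # second adapter, frozen by default
policy.set_adapter("reward")   # e.g. when scoring responses during PPO
policy.set_adapter("default")  # switch back to the trainable policy adapter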