support full-parameter PPO

Former-commit-id: 4af967d69475e1c9fdf1a7983cd6b83bd431abff
This commit is contained in:
hiyouga
2023-11-16 02:08:04 +08:00
parent 8263b2d32d
commit 7a3a0144a5
19 changed files with 280 additions and 140 deletions

View File

@@ -1,3 +1,3 @@
from llmtuner.model.loader import load_model_and_tokenizer
from llmtuner.model.parser import get_train_args, get_infer_args, get_eval_args
from llmtuner.model.utils import dispatch_model, generate_model_card
from llmtuner.model.utils import create_ref_model, create_reward_model, dispatch_model, generate_model_card

View File

@@ -1,8 +1,5 @@
import torch
from typing import TYPE_CHECKING
from transformers.utils import cached_file
from transformers.trainer import WEIGHTS_NAME, SAFE_WEIGHTS_NAME
from peft import PeftModel, TaskType, LoraConfig, get_peft_model
from llmtuner.extras.logging import get_logger
@@ -98,30 +95,3 @@ def init_adapter(
logger.info("Loaded fine-tuned model from checkpoint(s): {}".format(",".join(model_args.checkpoint_dir)))
return model
def load_valuehead_params(
model: "PreTrainedModel",
model_args: "ModelArguments"
) -> bool:
kwargs = {
"path_or_repo_id": model_args.reward_model,
"cache_dir": model_args.cache_dir,
"token": model_args.hf_hub_token,
"revision": model_args.model_revision
}
try:
vhead_file = cached_file(filename=WEIGHTS_NAME, **kwargs)
except:
try:
vhead_file = cached_file(filename=SAFE_WEIGHTS_NAME, **kwargs)
except:
logger.warning("Provided path ({}) does not contain valuehead weights.".format(model_args.reward_model))
return False
vhead_params = torch.load(vhead_file, map_location="cpu")
model.register_buffer("reward_head_weight", vhead_params["v_head.summary.weight"], persistent=False)
model.register_buffer("reward_head_bias", vhead_params["v_head.summary.bias"], persistent=False)
model.register_buffer("default_head_weight", torch.zeros_like(vhead_params["v_head.summary.weight"]), persistent=False)
model.register_buffer("default_head_bias", torch.zeros_like(vhead_params["v_head.summary.bias"]), persistent=False)
return True

View File

@@ -15,7 +15,6 @@ from transformers import (
)
from transformers.models.llama import modeling_llama as LlamaModule
from transformers.utils.versions import require_version
from peft import PeftModel
from trl import AutoModelForCausalLMWithValueHead
try:
@@ -24,12 +23,12 @@ except ImportError: # https://github.com/huggingface/transformers/releases/tag/v
from transformers.deepspeed import is_deepspeed_zero3_enabled
from llmtuner.extras.logging import reset_logging, get_logger
from llmtuner.extras.misc import count_parameters, infer_optim_dtype
from llmtuner.extras.misc import count_parameters, get_current_device, infer_optim_dtype
from llmtuner.extras.packages import is_flash_attn2_available
from llmtuner.extras.patches import llama_patch as LlamaPatches
from llmtuner.hparams import FinetuningArguments
from llmtuner.model.adapter import init_adapter, load_valuehead_params
from llmtuner.model.utils import prepare_model_for_training
from llmtuner.model.adapter import init_adapter
from llmtuner.model.utils import load_valuehead_params, prepare_model_for_training
if TYPE_CHECKING:
from transformers import PreTrainedTokenizer
@@ -74,6 +73,7 @@ def load_model_and_tokenizer(
)
if finetuning_args.finetuning_type != "lora" and model_args.checkpoint_dir is not None:
logger.info("Use `model_name_or_path` to specify the model trained with full/freeze method.")
model_to_load = model_args.checkpoint_dir[0]
else:
model_to_load = model_args.model_name_or_path
@@ -145,7 +145,7 @@ def load_model_and_tokenizer(
else:
logger.warning("Current model does not support shift short attention.")
# Quantization configurations (using bitsandbytes library).
# Quantization configurations (using bitsandbytes library)
if model_args.quantization_bit is not None:
if is_deepspeed_zero3_enabled():
raise ValueError("DeepSpeed ZeRO-3 is incompatible with quantization.")
@@ -165,10 +165,10 @@ def load_model_and_tokenizer(
bnb_4bit_quant_type=model_args.quantization_type
)
config_kwargs["device_map"] = {"": int(os.environ.get("LOCAL_RANK", "0"))} if is_trainable else "auto"
config_kwargs["device_map"] = {"": get_current_device()}
logger.info("Quantizing model to {} bit.".format(model_args.quantization_bit))
# Load and prepare pre-trained models (without valuehead).
# Load pre-trained models (without valuehead)
model = AutoModelForCausalLM.from_pretrained(
model_to_load,
config=config,
@@ -186,7 +186,7 @@ def load_model_and_tokenizer(
setattr(model, "lm_head", model.transformer.output_layer)
setattr(model, "_keys_to_ignore_on_save", ["lm_head.weight"])
# Register auto class to save the custom code files.
# Register auto class to save the custom code files
if isinstance(config, PretrainedConfig) and "AutoConfig" in getattr(config, "auto_map", {}):
config.__class__.register_for_auto_class()
if isinstance(model, PreTrainedModel) and "AutoModelForCausalLM" in getattr(config, "auto_map", {}):
@@ -200,25 +200,15 @@ def load_model_and_tokenizer(
model = model.train() if is_trainable else model.eval()
# Prepare model with valuehead for RLHF
if stage == "rm" or stage == "ppo":
if stage in ["rm", "ppo"]:
model: "AutoModelForCausalLMWithValueHead" = AutoModelForCausalLMWithValueHead.from_pretrained(model)
reset_logging()
if stage == "rm" and model_args.checkpoint_dir is not None: # load valuehead weights to evaluate reward model
if model_args.checkpoint_dir is not None: # load valuehead weights if exists
logger.warning("Only the last checkpoint containing valuehead will be loaded.")
if load_valuehead_params(model, model_args):
model.v_head.load_state_dict({
"summary.weight": getattr(model, "reward_head_weight"),
"summary.bias": getattr(model, "reward_head_bias")
})
if stage == "ppo": # load reward model
logger.info("Load reward model from {}".format(model_args.reward_model))
if isinstance(model.pretrained_model, PeftModel):
model.pretrained_model.load_adapter(model_args.reward_model, "reward")
for name, param in model.named_parameters(): # https://github.com/huggingface/peft/issues/1090
if "default" in name:
param.data = param.data.to(torch.float32) # trainable params should in fp32
assert load_valuehead_params(model, model_args), "Reward model is not correctly loaded."
vhead_params = load_valuehead_params(model_args.checkpoint_dir[-1], model_args)
if vhead_params is not None:
model.load_state_dict(vhead_params, strict=False)
logger.info("Loaded valuehead from checkpoint: {}".format(model_args.checkpoint_dir[-1]))
# Prepare model for inference
if not is_trainable:

View File

@@ -1,17 +1,89 @@
import torch
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple
from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, Set, Tuple, Union
from transformers.utils import cached_file
from transformers.trainer import WEIGHTS_NAME, SAFE_WEIGHTS_NAME
from llmtuner.extras.constants import LAYERNORM_NAMES
from llmtuner.extras.logging import get_logger
from llmtuner.hparams import ModelArguments, FinetuningArguments
from llmtuner.model import load_model_and_tokenizer
if TYPE_CHECKING:
from transformers.modeling_utils import PreTrainedModel
from llmtuner.hparams import ModelArguments, DataArguments, FinetuningArguments
from trl import AutoModelForCausalLMWithValueHead
from llmtuner.hparams import DataArguments
logger = get_logger(__name__)
def create_ref_model(
model_args: "ModelArguments",
finetuning_args: "FinetuningArguments",
stage: Literal["ppo", "dpo"]
) -> Union["PreTrainedModel", "AutoModelForCausalLMWithValueHead"]:
r"""
Creates reference model for PPO/DPO training. Evaluation mode is not supported.
The valuehead parameter is randomly initialized since it is useless for PPO training.
"""
if finetuning_args.ref_model is not None:
ref_model_args_dict = model_args.to_dict()
ref_model_args_dict.update(dict(
model_name_or_path=finetuning_args.ref_model,
checkpoint_dir=finetuning_args.ref_model_checkpoint,
quantization_bit=finetuning_args.ref_model_quantization_bit
))
ref_model_args = ModelArguments(**ref_model_args_dict)
ref_finetuning_args = FinetuningArguments(finetuning_type="lora")
ref_model, _ = load_model_and_tokenizer(ref_model_args, ref_finetuning_args, is_trainable=False, stage=stage)
logger.info("Created reference model from {}".format(finetuning_args.ref_model))
else:
if finetuning_args.finetuning_type == "lora":
ref_model = None
else:
ref_model, _ = load_model_and_tokenizer(model_args, finetuning_args, is_trainable=False, stage=stage)
logger.info("Created reference model from the model itself.")
return ref_model
def create_reward_model(
model: "AutoModelForCausalLMWithValueHead",
model_args: "ModelArguments",
finetuning_args: "FinetuningArguments"
) -> "AutoModelForCausalLMWithValueHead":
r"""
Creates reward model for PPO training.
"""
if finetuning_args.reward_model_type == "lora":
model.pretrained_model.load_adapter(finetuning_args.reward_model, "reward")
for name, param in model.named_parameters(): # https://github.com/huggingface/peft/issues/1090
if "default" in name:
param.data = param.data.to(torch.float32) # trainable params should in fp32
vhead_params = load_valuehead_params(model_args.checkpoint_dir[-1], model_args)
assert vhead_params is not None, "Reward model is not correctly loaded."
model.register_buffer("reward_head_weight", vhead_params["v_head.summary.weight"], persistent=False)
model.register_buffer("reward_head_bias", vhead_params["v_head.summary.bias"], persistent=False)
model.register_buffer("default_head_weight", torch.zeros_like(vhead_params["v_head.summary.weight"]), persistent=False)
model.register_buffer("default_head_bias", torch.zeros_like(vhead_params["v_head.summary.bias"]), persistent=False)
logger.info("Loaded adapter weights of reward model from {}".format(finetuning_args.reward_model))
return None
else:
reward_model_args_dict = model_args.to_dict()
reward_model_args_dict.update(dict(
model_name_or_path=finetuning_args.reward_model,
checkpoint_dir=finetuning_args.reward_model_checkpoint,
quantization_bit=finetuning_args.reward_model_quantization_bit
))
reward_model_args = ModelArguments(**reward_model_args_dict)
reward_finetuning_args = FinetuningArguments(finetuning_type="lora")
reward_model, _ = load_model_and_tokenizer(reward_model_args, reward_finetuning_args, is_trainable=False, stage="ppo")
logger.info("Load full weights of reward model from {}".format(finetuning_args.reward_model))
return reward_model
def dispatch_model(model: "PreTrainedModel") -> "PreTrainedModel":
r"""
Dispatches a pre-trained model to GPUs with balanced memory.
@@ -41,6 +113,9 @@ def find_all_linear_modules(
model: "PreTrainedModel",
quantization_bit: Optional[int] = None
) -> List[str]:
r"""
Finds all available modules to apply lora.
"""
if quantization_bit is not None:
import bitsandbytes as bnb
linear_cls = bnb.nn.Linear4bit if quantization_bit == 4 else bnb.nn.Linear8bitLt
@@ -76,6 +151,32 @@ def generate_model_card(
}
def load_valuehead_params(
path_or_repo_id: str,
model_args: "ModelArguments"
) -> Dict[str, torch.Tensor]:
r"""
Loads value head parameters from Hugging Face Hub or local disk.
Returns: dict with keys `v_head.summary.weight` and `v_head.summary.bias`.
"""
kwargs = {
"path_or_repo_id": path_or_repo_id,
"cache_dir": model_args.cache_dir,
"token": model_args.hf_hub_token
}
try:
vhead_file = cached_file(filename=WEIGHTS_NAME, **kwargs)
except:
try:
vhead_file = cached_file(filename=SAFE_WEIGHTS_NAME, **kwargs)
except:
logger.warning("Provided path ({}) does not contain valuehead weights.".format(path_or_repo_id))
return None
return torch.load(vhead_file, map_location="cpu")
def prepare_model_for_training(
model: "PreTrainedModel",
finetuning_args: "FinetuningArguments",