support full-parameter PPO
Former-commit-id: 4af967d69475e1c9fdf1a7983cd6b83bd431abff
This commit is contained in:
@@ -1,17 +1,89 @@
|
||||
import torch
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, Set, Tuple, Union
|
||||
|
||||
from transformers.utils import cached_file
|
||||
from transformers.trainer import WEIGHTS_NAME, SAFE_WEIGHTS_NAME
|
||||
|
||||
from llmtuner.extras.constants import LAYERNORM_NAMES
|
||||
from llmtuner.extras.logging import get_logger
|
||||
from llmtuner.hparams import ModelArguments, FinetuningArguments
|
||||
from llmtuner.model import load_model_and_tokenizer
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from transformers.modeling_utils import PreTrainedModel
|
||||
from llmtuner.hparams import ModelArguments, DataArguments, FinetuningArguments
|
||||
from trl import AutoModelForCausalLMWithValueHead
|
||||
from llmtuner.hparams import DataArguments
|
||||
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
def create_ref_model(
|
||||
model_args: "ModelArguments",
|
||||
finetuning_args: "FinetuningArguments",
|
||||
stage: Literal["ppo", "dpo"]
|
||||
) -> Union["PreTrainedModel", "AutoModelForCausalLMWithValueHead"]:
|
||||
r"""
|
||||
Creates reference model for PPO/DPO training. Evaluation mode is not supported.
|
||||
|
||||
The valuehead parameter is randomly initialized since it is useless for PPO training.
|
||||
"""
|
||||
if finetuning_args.ref_model is not None:
|
||||
ref_model_args_dict = model_args.to_dict()
|
||||
ref_model_args_dict.update(dict(
|
||||
model_name_or_path=finetuning_args.ref_model,
|
||||
checkpoint_dir=finetuning_args.ref_model_checkpoint,
|
||||
quantization_bit=finetuning_args.ref_model_quantization_bit
|
||||
))
|
||||
ref_model_args = ModelArguments(**ref_model_args_dict)
|
||||
ref_finetuning_args = FinetuningArguments(finetuning_type="lora")
|
||||
ref_model, _ = load_model_and_tokenizer(ref_model_args, ref_finetuning_args, is_trainable=False, stage=stage)
|
||||
logger.info("Created reference model from {}".format(finetuning_args.ref_model))
|
||||
else:
|
||||
if finetuning_args.finetuning_type == "lora":
|
||||
ref_model = None
|
||||
else:
|
||||
ref_model, _ = load_model_and_tokenizer(model_args, finetuning_args, is_trainable=False, stage=stage)
|
||||
logger.info("Created reference model from the model itself.")
|
||||
|
||||
return ref_model
|
||||
|
||||
|
||||
def create_reward_model(
|
||||
model: "AutoModelForCausalLMWithValueHead",
|
||||
model_args: "ModelArguments",
|
||||
finetuning_args: "FinetuningArguments"
|
||||
) -> "AutoModelForCausalLMWithValueHead":
|
||||
r"""
|
||||
Creates reward model for PPO training.
|
||||
"""
|
||||
if finetuning_args.reward_model_type == "lora":
|
||||
model.pretrained_model.load_adapter(finetuning_args.reward_model, "reward")
|
||||
for name, param in model.named_parameters(): # https://github.com/huggingface/peft/issues/1090
|
||||
if "default" in name:
|
||||
param.data = param.data.to(torch.float32) # trainable params should in fp32
|
||||
vhead_params = load_valuehead_params(model_args.checkpoint_dir[-1], model_args)
|
||||
assert vhead_params is not None, "Reward model is not correctly loaded."
|
||||
model.register_buffer("reward_head_weight", vhead_params["v_head.summary.weight"], persistent=False)
|
||||
model.register_buffer("reward_head_bias", vhead_params["v_head.summary.bias"], persistent=False)
|
||||
model.register_buffer("default_head_weight", torch.zeros_like(vhead_params["v_head.summary.weight"]), persistent=False)
|
||||
model.register_buffer("default_head_bias", torch.zeros_like(vhead_params["v_head.summary.bias"]), persistent=False)
|
||||
logger.info("Loaded adapter weights of reward model from {}".format(finetuning_args.reward_model))
|
||||
return None
|
||||
else:
|
||||
reward_model_args_dict = model_args.to_dict()
|
||||
reward_model_args_dict.update(dict(
|
||||
model_name_or_path=finetuning_args.reward_model,
|
||||
checkpoint_dir=finetuning_args.reward_model_checkpoint,
|
||||
quantization_bit=finetuning_args.reward_model_quantization_bit
|
||||
))
|
||||
reward_model_args = ModelArguments(**reward_model_args_dict)
|
||||
reward_finetuning_args = FinetuningArguments(finetuning_type="lora")
|
||||
reward_model, _ = load_model_and_tokenizer(reward_model_args, reward_finetuning_args, is_trainable=False, stage="ppo")
|
||||
logger.info("Load full weights of reward model from {}".format(finetuning_args.reward_model))
|
||||
return reward_model
|
||||
|
||||
|
||||
def dispatch_model(model: "PreTrainedModel") -> "PreTrainedModel":
|
||||
r"""
|
||||
Dispatches a pre-trained model to GPUs with balanced memory.
|
||||
@@ -41,6 +113,9 @@ def find_all_linear_modules(
|
||||
model: "PreTrainedModel",
|
||||
quantization_bit: Optional[int] = None
|
||||
) -> List[str]:
|
||||
r"""
|
||||
Finds all available modules to apply lora.
|
||||
"""
|
||||
if quantization_bit is not None:
|
||||
import bitsandbytes as bnb
|
||||
linear_cls = bnb.nn.Linear4bit if quantization_bit == 4 else bnb.nn.Linear8bitLt
|
||||
@@ -76,6 +151,32 @@ def generate_model_card(
|
||||
}
|
||||
|
||||
|
||||
def load_valuehead_params(
|
||||
path_or_repo_id: str,
|
||||
model_args: "ModelArguments"
|
||||
) -> Dict[str, torch.Tensor]:
|
||||
r"""
|
||||
Loads value head parameters from Hugging Face Hub or local disk.
|
||||
|
||||
Returns: dict with keys `v_head.summary.weight` and `v_head.summary.bias`.
|
||||
"""
|
||||
kwargs = {
|
||||
"path_or_repo_id": path_or_repo_id,
|
||||
"cache_dir": model_args.cache_dir,
|
||||
"token": model_args.hf_hub_token
|
||||
}
|
||||
try:
|
||||
vhead_file = cached_file(filename=WEIGHTS_NAME, **kwargs)
|
||||
except:
|
||||
try:
|
||||
vhead_file = cached_file(filename=SAFE_WEIGHTS_NAME, **kwargs)
|
||||
except:
|
||||
logger.warning("Provided path ({}) does not contain valuehead weights.".format(path_or_repo_id))
|
||||
return None
|
||||
|
||||
return torch.load(vhead_file, map_location="cpu")
|
||||
|
||||
|
||||
def prepare_model_for_training(
|
||||
model: "PreTrainedModel",
|
||||
finetuning_args: "FinetuningArguments",
|
||||
|
||||
Reference in New Issue
Block a user