implement rm server #1543

Former-commit-id: 2e5bb6888c86079493456c2ddd525f8c52b9963e
Author: hiyouga
Date:   2023-12-03 20:52:54 +08:00
Parent: 4a14099cfd
Commit: 29545d0e5e
11 changed files with 104 additions and 24 deletions


@@ -1,5 +1,5 @@
 import torch
-from typing import TYPE_CHECKING, Literal, Union
+from typing import TYPE_CHECKING, Optional, Union
 from llmtuner.extras.logging import get_logger
 from llmtuner.hparams import ModelArguments, FinetuningArguments
@@ -35,7 +35,7 @@ def create_modelcard_and_push(
 def create_ref_model(
     model_args: "ModelArguments",
     finetuning_args: "FinetuningArguments",
-    stage: Literal["ppo", "dpo"]
+    add_valuehead: Optional[bool] = False
 ) -> Union["PreTrainedModel", "AutoModelForCausalLMWithValueHead"]:
     r"""
     Creates reference model for PPO/DPO training. Evaluation mode is not supported.
@@ -51,13 +51,17 @@ def create_ref_model(
         ))
         ref_model_args = ModelArguments(**ref_model_args_dict)
         ref_finetuning_args = FinetuningArguments(finetuning_type="lora")
-        ref_model, _ = load_model_and_tokenizer(ref_model_args, ref_finetuning_args, is_trainable=False, stage=stage)
+        ref_model, _ = load_model_and_tokenizer(
+            ref_model_args, ref_finetuning_args, is_trainable=False, add_valuehead=add_valuehead
+        )
         logger.info("Created reference model from {}".format(finetuning_args.ref_model))
     else:
         if finetuning_args.finetuning_type == "lora":
             ref_model = None
         else:
-            ref_model, _ = load_model_and_tokenizer(model_args, finetuning_args, is_trainable=False, stage=stage)
+            ref_model, _ = load_model_and_tokenizer(
+                model_args, finetuning_args, is_trainable=False, add_valuehead=add_valuehead
+            )
             logger.info("Created reference model from the model itself.")
     return ref_model
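
For orientation, here is a minimal usage sketch of the updated signature: instead of passing a stage literal, callers now state explicitly whether the reference model needs a value head. These call sites are illustrative assumptions, not part of this diff.

    # hypothetical call sites for the new create_ref_model signature
    ref_model = create_ref_model(model_args, finetuning_args, add_valuehead=True)  # PPO: value head required
    ref_model = create_ref_model(model_args, finetuning_args)  # DPO: plain causal LM (default False)
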
@@ -71,7 +75,9 @@ def create_reward_model(
r"""
Creates reward model for PPO training.
"""
if finetuning_args.reward_model_type == "lora":
if finetuning_args.reward_model_type == "api":
raise NotImplementedError
elif finetuning_args.reward_model_type == "lora":
model.pretrained_model.load_adapter(finetuning_args.reward_model, "reward")
for name, param in model.named_parameters(): # https://github.com/huggingface/peft/issues/1090
if "default" in name:
@@ -93,7 +99,9 @@ def create_reward_model(
         ))
         reward_model_args = ModelArguments(**reward_model_args_dict)
         reward_finetuning_args = FinetuningArguments(finetuning_type="lora")
-        reward_model, _ = load_model_and_tokenizer(reward_model_args, reward_finetuning_args, is_trainable=False, stage="ppo")
+        reward_model, _ = load_model_and_tokenizer(
+            reward_model_args, reward_finetuning_args, is_trainable=False, add_valuehead=True
+        )
         logger.info("Load full weights of reward model from {}".format(finetuning_args.reward_model))
         logger.warning("Please ensure the ppo model and reward model share SAME tokenizer and vocabulary.")
         return reward_model
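
For context, a hedged sketch of how the PPO workflow might consume this. The full signature of create_reward_model is not shown in this diff; the three-argument form below is assumed from the parameters used above.

    # hypothetical PPO setup; signature assumed, not confirmed by this diff
    reward_model = create_reward_model(model, model_args, finetuning_args)
    # "api"  -> NotImplementedError for now (the rm server path this commit scaffolds)
    # "lora" -> the reward adapter is loaded into the PPO model itself; no separate model
    # otherwise -> a separate full-weight model is loaded with add_valuehead=True
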