implement rm server #1543

Former-commit-id: 2e5bb6888c86079493456c2ddd525f8c52b9963e
hiyouga
2023-12-03 20:52:54 +08:00
parent 4a14099cfd
commit 29545d0e5e
11 changed files with 104 additions and 24 deletions

@@ -49,7 +49,7 @@ def load_model_and_tokenizer(
     model_args: "ModelArguments",
     finetuning_args: "FinetuningArguments",
     is_trainable: Optional[bool] = False,
-    stage: Optional[Literal["pt", "sft", "rm", "ppo"]] = "sft"
+    add_valuehead: Optional[bool] = False
 ) -> Tuple[PreTrainedModel, "PreTrainedTokenizer"]:
     r"""
     Loads pretrained model and tokenizer.
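The hunk above replaces the training-stage literal with an explicit add_valuehead flag, so callers request the value head directly instead of implying it via stage="rm" or stage="ppo". A hypothetical caller sketch for a reward-model server (model_args and finetuning_args are placeholders, not part of this diff):

```python
# Hypothetical usage sketch: the rm server asks for the value head explicitly
# instead of passing stage="rm"; argument objects below are placeholders.
model, tokenizer = load_model_and_tokenizer(
    model_args,
    finetuning_args,
    is_trainable=False,   # serving only; parameters stay frozen
    add_valuehead=True,   # wrap the model with AutoModelForCausalLMWithValueHead
)
```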
@@ -205,10 +205,9 @@ def load_model_and_tokenizer(
     # Initialize adapters
     model = prepare_model_for_training(model=model, finetuning_args=finetuning_args) if is_trainable else model
     model = init_adapter(model, model_args, finetuning_args, is_trainable)
-    model = model.train() if is_trainable else model.eval()
 
     # Prepare model with valuehead for RLHF
-    if stage in ["rm", "ppo"]:
+    if add_valuehead:
         model: "AutoModelForCausalLMWithValueHead" = AutoModelForCausalLMWithValueHead.from_pretrained(model)
         setattr(model, "_keys_to_ignore_on_save", [name for name, _ in model.named_parameters() if "pretrained_model" in name])
         setattr(model, "tie_weights", MethodType(lambda _: None, model)) # use empty method
@@ -224,6 +223,9 @@ def load_model_and_tokenizer(
     if not is_trainable:
         model.requires_grad_(False) # fix all model params
         model = model.to(model_args.compute_dtype) if model_args.quantization_bit is None else model
+        model.eval()
+    else:
+        model.train()
 
     trainable_params, all_param = count_parameters(model)
     logger.info("trainable params: {:d} || all params: {:d} || trainable%: {:.4f}".format(
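count_parameters supplies the trainable and total counts for the log line above. A minimal sketch of such a helper, assuming a plain parameter walk (the project's own utility may handle quantized or sharded weights differently):

```python
from typing import Tuple
import torch

def count_parameters(model: torch.nn.Module) -> Tuple[int, int]:
    # Sketch: tally trainable and total parameter counts.
    trainable_params, all_param = 0, 0
    for param in model.parameters():
        num_params = param.numel()
        all_param += num_params
        if param.requires_grad:
            trainable_params += num_params
    return trainable_params, all_param
```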