implement rm server #1543

Former-commit-id: 2e5bb6888c86079493456c2ddd525f8c52b9963e
This commit is contained in:
hiyouga
2023-12-03 20:52:54 +08:00
parent 4a14099cfd
commit 29545d0e5e
11 changed files with 104 additions and 24 deletions

View File

@@ -1,4 +1,5 @@
import torch
import tiktoken
from dataclasses import dataclass
from typing import Any, Dict, Generator, List, Literal, Optional, Tuple
from threading import Thread
@@ -22,8 +23,11 @@ class ChatModel:
def __init__(self, args: Optional[Dict[str, Any]] = None) -> None:
model_args, data_args, finetuning_args, self.generating_args = get_infer_args(args)
self.model, self.tokenizer = load_model_and_tokenizer(model_args, finetuning_args)
self.tokenizer.padding_side = "left"
self.can_generate = (finetuning_args.stage == "sft")
self.model, self.tokenizer = load_model_and_tokenizer(
model_args, finetuning_args, is_trainable=False, add_valuehead=(not self.can_generate)
)
self.tokenizer.padding_side = "left" if self.can_generate else "right"
self.model = dispatch_model(self.model)
self.template = get_template_and_fix_tokenizer(data_args.template, self.tokenizer)
self.system_prompt = data_args.system_prompt
@@ -130,3 +134,40 @@ class ChatModel:
thread.start()
yield from streamer
@torch.inference_mode()
def get_scores(
self,
batch_input: List[str],
**input_kwargs
) -> List[float]:
if isinstance(getattr(self.tokenizer, "tokenizer", None), tiktoken.Encoding): # for tiktoken tokenizer (Qwen)
kwargs = dict(allowed_special="all")
else:
kwargs = dict(add_special_tokens=True)
max_length = input_kwargs.pop("max_length", None)
device = getattr(self.model.pretrained_model, "device", "cuda")
inputs = self.tokenizer(
batch_input,
padding=True,
truncation=True,
max_length=max_length or getattr(self.model.config, "max_position_embeddings", 1024),
pad_to_multiple_of=8,
return_tensors="pt",
**kwargs
).to(device)
input_ids: torch.Tensor = inputs["input_ids"]
_, _, values = self.model(**inputs, output_hidden_states=True, return_dict=True)
if getattr(self.model.config, "model_type", None) == "chatglm":
values = torch.transpose(values, 0, 1)
scores = []
for i in range(input_ids.size(0)):
length = (input_ids[i] != self.tokenizer.pad_token_id).nonzero()[-1] + 1
scores.append(values[i, length-1].nan_to_num().item())
return scores