Former-commit-id: 392bdaf1ea9e5baf6289f2d4415a175dd55a479d
This commit is contained in:
hiyouga
2024-09-11 17:36:42 +08:00
parent 588ea95732
commit 7fd0d2fc2f
4 changed files with 12 additions and 22 deletions

View File

@@ -246,29 +246,18 @@ class HuggingfaceEngine(BaseEngine):
batch_input: List[str],
input_kwargs: Optional[Dict[str, Any]] = {},
) -> List[float]:
max_length = input_kwargs.pop("max_length", None)
max_length: Optional[int] = input_kwargs.pop("max_length", None)
device = getattr(model.pretrained_model, "device", "cuda")
inputs = tokenizer(
inputs: Dict[str, "torch.Tensor"] = tokenizer(
batch_input,
padding=True,
truncation=True,
max_length=max_length or getattr(model.config, "max_position_embeddings", 1024),
return_tensors="pt",
add_special_tokens=True,
add_special_tokens=False,
).to(device)
input_ids: torch.Tensor = inputs["input_ids"]
_, _, values = model(**inputs, output_hidden_states=True, return_dict=True)
if getattr(model.config, "model_type", None) == "chatglm":
values = torch.transpose(values, 0, 1)
scores = []
for i in range(input_ids.size(0)):
end_indexes = (input_ids[i] != tokenizer.pad_token_id).nonzero()
end_index = end_indexes[-1].item() if len(end_indexes) else 0
scores.append(values[i, end_index].nan_to_num().item())
values: "torch.Tensor" = model(**inputs, return_dict=True, use_cache=False)[-1]
scores = values.gather(dim=-1, index=(inputs["attention_mask"].sum(dim=-1, keepdim=True) - 1))
return scores
@override