PPO: support reward model (RM) server

Former-commit-id: 20b0edf16f5b42cb2c4a795674647afb68cb3a4a
This commit is contained in:
hiyouga
2023-12-03 21:38:51 +08:00
parent 29545d0e5e
commit 60aea7521b
5 changed files with 47 additions and 15 deletions

View File

@@ -167,7 +167,8 @@ class ChatModel:
 scores = []
 for i in range(input_ids.size(0)):
-    length = (input_ids[i] != self.tokenizer.pad_token_id).nonzero()[-1] + 1
-    scores.append(values[i, length-1].nan_to_num().item())
+    end_indexes = (input_ids[i] != self.tokenizer.pad_token_id).nonzero()
+    end_index = end_indexes[-1].item() if len(end_indexes) else 0
+    scores.append(values[i, end_index].nan_to_num().item())
 return scores