update gradio, support multiple resp in api
Former-commit-id: a34263e7c0e07a080276d164cdab9f12f1d767d2
This commit is contained in:
@@ -180,15 +180,15 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
|
||||
query, response = batch["input_ids"].detach().cpu(), response[:, batch["input_ids"].size(-1):].detach().cpu()
|
||||
queries, responses = [], []
|
||||
for i in range(len(query)):
|
||||
query_length = (query[i] != self.tokenizer.pad_token_id).nonzero()[0]
|
||||
query_length = (query[i] != self.tokenizer.pad_token_id).nonzero()[0].item()
|
||||
response_index = (response[i] != self.tokenizer.pad_token_id).nonzero()
|
||||
|
||||
if len(response_index) == 0:
|
||||
response_length = 1 # allow empty response
|
||||
elif self.tokenizer.pad_token_id == self.tokenizer.eos_token_id:
|
||||
response_length = response_index[-1] + 2 # save the EOS token
|
||||
response_length = response_index[-1].item() + 2 # save the EOS token
|
||||
else:
|
||||
response_length = response_index[-1] + 1
|
||||
response_length = response_index[-1].item() + 1
|
||||
|
||||
queries.append(query[i, query_length:]) # remove padding from left
|
||||
responses.append(response[i, :response_length]) # remove padding from right
|
||||
@@ -216,7 +216,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
|
||||
|
||||
rewards = []
|
||||
for i in range(values.size(0)):
|
||||
end_index = batch["attention_mask"][i].nonzero()[-1] # use the score on the EOS token
|
||||
end_index = batch["attention_mask"][i].nonzero()[-1].item() # use the score on the EOS token
|
||||
rewards.append(values[i, end_index].float().detach().cpu()) # use fp32 type
|
||||
|
||||
replace_model(unwrapped_model, target="default")
|
||||
@@ -266,7 +266,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
|
||||
for j in range(len(query_batch)):
|
||||
start = len(query_batch[j]) - 1
|
||||
if attention_mask[j, 0] == 0: # offset left padding
|
||||
start += attention_mask[j, :].nonzero()[0]
|
||||
start += attention_mask[j, :].nonzero()[0].item()
|
||||
end = start + len(response_batch[j])
|
||||
|
||||
if response_masks is not None:
|
||||
|
||||
Reference in New Issue
Block a user