refactor dataset_attr, add eos in pt, fix #757
Former-commit-id: 0feec9a830b917b36686b61938a66e842eccf930
This commit is contained in:
@@ -202,7 +202,8 @@ class PPOPeftTrainer(PPOTrainer, PeftTrainer):
|
||||
queries: torch.Tensor,
|
||||
responses: torch.Tensor,
|
||||
model_inputs: dict,
|
||||
return_logits: Optional[bool] = False
|
||||
return_logits: Optional[bool] = False,
|
||||
response_masks: Optional[torch.Tensor] = None
|
||||
):
|
||||
r"""
|
||||
Calculates model outputs in multiple batches.
|
||||
@@ -220,6 +221,8 @@ class PPOPeftTrainer(PPOTrainer, PeftTrainer):
|
||||
input_kwargs = {key: value[i * fbs : (i + 1) * fbs] for key, value in model_inputs.items()}
|
||||
query_batch = queries[i * fbs : (i + 1) * fbs]
|
||||
response_batch = responses[i * fbs : (i + 1) * fbs]
|
||||
if response_masks is not None:
|
||||
response_masks_batch = response_masks[i * fbs : (i + 1) * fbs]
|
||||
input_ids = input_kwargs["input_ids"]
|
||||
attention_mask = input_kwargs["attention_mask"]
|
||||
|
||||
@@ -239,8 +242,15 @@ class PPOPeftTrainer(PPOTrainer, PeftTrainer):
|
||||
start += attention_mask[j, :].nonzero()[0]
|
||||
end = start + len(response_batch[j])
|
||||
|
||||
if response_masks is not None:
|
||||
response_masks_batch = torch.cat(
|
||||
(torch.zeros_like(query_batch[j]), response_masks_batch[j])
|
||||
)[1:]
|
||||
|
||||
masks[j, :start] = 0
|
||||
masks[j, end:] = 0
|
||||
if response_masks is not None:
|
||||
masks[j, start:end] = masks[j, start:end] * response_masks_batch[j][start:end]
|
||||
|
||||
if return_logits:
|
||||
all_logits.append(logits)
|
||||
|
||||
Reference in New Issue
Block a user