refactor dataset_attr, add eos in pt, fix #757

Former-commit-id: 0feec9a830b917b36686b61938a66e842eccf930
This commit is contained in:
hiyouga
2023-09-01 19:00:45 +08:00
parent 93be211f80
commit e5b72c6a77
19 changed files with 108 additions and 126 deletions

View File

@@ -202,7 +202,8 @@ class PPOPeftTrainer(PPOTrainer, PeftTrainer):
queries: torch.Tensor,
responses: torch.Tensor,
model_inputs: dict,
return_logits: Optional[bool] = False
return_logits: Optional[bool] = False,
response_masks: Optional[torch.Tensor] = None
):
r"""
Calculates model outputs in multiple batches.
@@ -220,6 +221,8 @@ class PPOPeftTrainer(PPOTrainer, PeftTrainer):
input_kwargs = {key: value[i * fbs : (i + 1) * fbs] for key, value in model_inputs.items()}
query_batch = queries[i * fbs : (i + 1) * fbs]
response_batch = responses[i * fbs : (i + 1) * fbs]
if response_masks is not None:
response_masks_batch = response_masks[i * fbs : (i + 1) * fbs]
input_ids = input_kwargs["input_ids"]
attention_mask = input_kwargs["attention_mask"]
@@ -239,8 +242,15 @@ class PPOPeftTrainer(PPOTrainer, PeftTrainer):
start += attention_mask[j, :].nonzero()[0]
end = start + len(response_batch[j])
if response_masks is not None:
response_masks_batch = torch.cat(
(torch.zeros_like(query_batch[j]), response_masks_batch[j])
)[1:]
masks[j, :start] = 0
masks[j, end:] = 0
if response_masks is not None:
masks[j, start:end] = masks[j, start:end] * response_masks_batch[j][start:end]
if return_logits:
all_logits.append(logits)