update hparams

Former-commit-id: 1c4feac44192b1f540208837f5a530b0d3f5fb37
This commit is contained in:
hiyouga
2024-07-03 23:18:58 +08:00
parent 8ac4f87c91
commit 5acaa476d6
8 changed files with 72 additions and 28 deletions

View File

@@ -1,4 +1,7 @@
# Copyright 2024 the LlamaFactory team.
# Copyright 2024 OpenAccess AI Collective and the LlamaFactory team.
#
# This code is inspired by the OpenAccess AI Collective's axolotl library.
# https://github.com/OpenAccess-AI-Collective/axolotl/blob/main/src/axolotl/monkeypatch/utils.py
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -19,6 +22,44 @@ import torch
from transformers import DataCollatorForSeq2Seq
def prepare_4d_attention_mask(attention_mask_with_indices: "torch.Tensor", dtype: "torch.dtype") -> "torch.Tensor":
r"""
Expands the attention mask with indices from (batch_size, seq_len) to (batch_size, 1, seq_len, seq_len),
while handles packed sequences and transforms the mask to lower triangular form to prevent future peeking.
e.g.
```
[1, 1, 2, 2, 2, 0]
```
->
```
[[
[
[o, x, x, x, x, x],
[o, o, x, x, x, x],
[x, x, o, x, x, x],
[x, x, o, o, x, x],
[x, x, o, o, o, x],
[x, x, o, x, x, x],
]
]]
```
where `o` equals to `0.0`, `x` equals to `min_dtype`.
"""
bsz, seq_len = attention_mask_with_indices.size()
min_dtype = torch.finfo(dtype).min
expanded_mask = attention_mask_with_indices[:, None, None, :].expand(bsz, 1, seq_len, seq_len)
# Create a binary mask from the original mask where zeros remain zeros and all other values are set to one
padding_mask = torch.where(expanded_mask != 0, 1, 0)
# Create a block-diagonal mask.
attention_mask_4d = torch.eq(expanded_mask, expanded_mask.transpose(-1, -2)).int() * padding_mask
# Use the lower triangular mask to zero out the upper triangular part
attention_mask_4d *= torch.tril(torch.ones((seq_len, seq_len), dtype=torch.long))
# Invert the attention mask.
attention_mask_4d = torch.where(attention_mask_4d != 0, torch.tensor(0, dtype=dtype), min_dtype)
return attention_mask_4d
@dataclass
class PairwiseDataCollatorWithPadding(DataCollatorForSeq2Seq):
r"""