Former-commit-id: cbff0ea0db6971f8ced503a2f0cb6bc43e7037ac
This commit is contained in:
hiyouga
2024-07-05 00:58:05 +08:00
parent ee1c786a12
commit edc8aefa59
2 changed files with 8 additions and 1 deletions

View File

@@ -79,6 +79,9 @@ def get_paligemma_token_type_ids(input_len: int, processor: "ProcessorMixin") ->
def infer_seqlen(source_len: int, target_len: int, cutoff_len: int) -> Tuple[int, int]:
r"""
Computes the real sequence length after truncation by the cutoff_len.
"""
if target_len * 2 < cutoff_len: # truncate source
max_target_len = cutoff_len
elif source_len * 2 < cutoff_len: # truncate target
@@ -87,5 +90,6 @@ def infer_seqlen(source_len: int, target_len: int, cutoff_len: int) -> Tuple[int
max_target_len = int(cutoff_len * (target_len / (source_len + target_len)))
new_target_len = min(max_target_len, target_len)
new_source_len = max(cutoff_len - new_target_len, 0)
max_source_len = max(cutoff_len - new_target_len, 0)
new_source_len = min(max_source_len, source_len)
return new_source_len, new_target_len