fix generation bug #532

Former-commit-id: c071121e67374e5f09798db57cfc8668617a36ae
This commit is contained in:
hiyouga
2023-08-17 22:21:34 +08:00
parent e993e717a5
commit fa1893b59c
5 changed files with 15 additions and 46 deletions

View File

@@ -2,9 +2,8 @@ import torch
from typing import TYPE_CHECKING, List, Optional, Tuple
from transformers import (
LogitsProcessor,
LogitsProcessorList,
StoppingCriteria,
StoppingCriteriaList
InfNanRemoveLogitsProcessor,
LogitsProcessorList
)
from llmtuner.extras.constants import LAYERNORM_NAMES
@@ -33,37 +32,12 @@ class AverageMeter:
self.avg = self.sum / self.count
class InvalidScoreLogitsProcessor(LogitsProcessor):
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
if torch.isnan(scores).any() or torch.isinf(scores).any():
scores.zero_()
scores[..., 0] = 1.0
return scores
def get_logits_processor() -> LogitsProcessorList:
logits_processor = LogitsProcessorList()
logits_processor.append(InvalidScoreLogitsProcessor())
logits_processor.append(InfNanRemoveLogitsProcessor())
return logits_processor
class StopWordsCriteria(StoppingCriteria):
def __init__(self, stop_ids: List[int]) -> None:
super().__init__()
self.stop_ids = stop_ids
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
return any([stop_id in input_ids[:, -1] for stop_id in self.stop_ids])
def get_stopping_criteria(stop_ids: List[int]) -> StoppingCriteriaList:
stopping_criteria = StoppingCriteriaList()
stopping_criteria.append(StopWordsCriteria(stop_ids))
return stopping_criteria
def count_parameters(model: torch.nn.Module) -> Tuple[int, int]:
r"""
Returns the number of trainable parameters and number of all parameters in the model.