refactor model_dtype, fix PPO trainer

Former-commit-id: 3e17ee5afbcb823a7c9a2f91864b3750cd79edb4
This commit is contained in:
hiyouga
2023-10-11 23:16:01 +08:00
parent a2d08ce961
commit 3198a7e5f4
10 changed files with 104 additions and 119 deletions

View File

@@ -3,6 +3,19 @@ import torch
from typing import TYPE_CHECKING, Tuple
from transformers import InfNanRemoveLogitsProcessor, LogitsProcessorList
try:
from transformers.utils import (
is_torch_bf16_cpu_available,
is_torch_bf16_gpu_available,
is_torch_cuda_available,
is_torch_npu_available
)
_is_fp16_available = is_torch_npu_available() or is_torch_cuda_available()
_is_bf16_available = is_torch_bf16_gpu_available() or is_torch_bf16_cpu_available
except ImportError:
_is_fp16_available = torch.cuda.is_available()
_is_bf16_available = torch.cuda.is_bf16_supported()
if TYPE_CHECKING:
from transformers.modeling_utils import PreTrainedModel
@@ -49,7 +62,22 @@ def count_parameters(model: torch.nn.Module) -> Tuple[int, int]:
return trainable_params, all_param
def infer_optim_dtype(model_dtype: torch.dtype) -> torch.dtype:
r"""
Infers the optimal dtype according to the model_dtype and device compatibility.
"""
if _is_bf16_available and model_dtype == torch.bfloat16:
return torch.bfloat16
elif _is_fp16_available:
return torch.float16
else:
return torch.float32
def get_logits_processor() -> LogitsProcessorList:
r"""
Gets logits processor that removes NaN and Inf logits.
"""
logits_processor = LogitsProcessorList()
logits_processor.append(InfNanRemoveLogitsProcessor())
return logits_processor