update ppo trainer

Former-commit-id: caa525a5c6f228b9ad71387d1fe4f1c2ffa2479e
This commit is contained in:
hiyouga
2023-11-20 21:39:15 +08:00
parent e585950c54
commit 28258aecd2
7 changed files with 68 additions and 41 deletions

View File

@@ -16,7 +16,10 @@ try:
_is_bf16_available = is_torch_bf16_gpu_available() or is_torch_bf16_cpu_available()
except ImportError:
_is_fp16_available = torch.cuda.is_available()
_is_bf16_available = torch.cuda.is_bf16_supported()
try:
_is_bf16_available = torch.cuda.is_bf16_supported()
except:
_is_bf16_available = False
if TYPE_CHECKING:
from transformers import HfArgumentParser