support full-parameter PPO

Former-commit-id: 4af967d69475e1c9fdf1a7983cd6b83bd431abff
This commit is contained in:
hiyouga
2023-11-16 02:08:04 +08:00
parent 8263b2d32d
commit 7a3a0144a5
19 changed files with 280 additions and 140 deletions

View File

@@ -64,6 +64,16 @@ def count_parameters(model: torch.nn.Module) -> Tuple[int, int]:
return trainable_params, all_param
def get_current_device() -> str:
import accelerate
from accelerate import Accelerator
dummy_accelerator = Accelerator()
if accelerate.utils.is_xpu_available():
return "xpu:{}".format(dummy_accelerator.local_process_index)
else:
return dummy_accelerator.local_process_index if torch.cuda.is_available() else "cpu"
def get_logits_processor() -> "LogitsProcessorList":
r"""
Gets logits processor that removes NaN and Inf logits.