[model] add gpt oss (#8826)

This commit is contained in:
Yaowei Zheng
2025-08-06 05:56:46 +08:00
committed by GitHub
parent c709c0378d
commit 4dfad24902
10 changed files with 97 additions and 16 deletions

View File

@@ -18,7 +18,7 @@
import gc
import os
import socket
from typing import TYPE_CHECKING, Any, Literal, Union
from typing import TYPE_CHECKING, Any, Literal, Optional, Union
import torch
import torch.distributed as dist
@@ -94,7 +94,7 @@ def check_version(requirement: str, mandatory: bool = False) -> None:
def check_dependencies() -> None:
r"""Check the version of the required packages."""
check_version("transformers>=4.49.0,<=4.52.4,!=4.52.0")
check_version("transformers>=4.49.0,<=4.55.0")
check_version("datasets>=2.16.0,<=3.6.0")
check_version("accelerate>=1.3.0,<=1.7.0")
check_version("peft>=0.14.0,<=0.15.2")
@@ -211,9 +211,9 @@ def has_tokenized_data(path: "os.PathLike") -> bool:
return os.path.isdir(path) and len(os.listdir(path)) > 0
def infer_optim_dtype(model_dtype: "torch.dtype") -> "torch.dtype":
def infer_optim_dtype(model_dtype: Optional["torch.dtype"]) -> "torch.dtype":
r"""Infer the optimal dtype according to the model_dtype and device compatibility."""
if _is_bf16_available and model_dtype == torch.bfloat16:
if _is_bf16_available and (model_dtype == torch.bfloat16 or model_dtype is None):
return torch.bfloat16
elif _is_fp16_available:
return torch.float16