[model] add gpt oss (#8826)
@@ -1063,6 +1063,16 @@ register_template(
 )
 
 
+register_template(
+    name="gpt",
+    format_user=StringFormatter(slots=["<|start|>user<|message|>{{content}}<|end|><|start|>assistant"]),
+    format_assistant=StringFormatter(slots=["{{content}}<|end|>"]),
+    format_system=StringFormatter(slots=["<|start|>system<|message|>{{content}}<|end|>"]),
+    default_system="You are ChatGPT, a large language model trained by OpenAI.",
+    efficient_eos=True,
+)
+
+
 register_template(
     name="granite3",
     format_user=StringFormatter(
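For context, this is roughly what the new "gpt" template renders for a one-turn exchange. The slot strings come straight from the hunk above; the string assembly below is a simplified stand-in for the framework's formatter machinery.

    # Sketch of the prompt produced by the "gpt" template (slots as registered above).
    system = "You are ChatGPT, a large language model trained by OpenAI."
    user = "Hello!"
    prompt = (
        f"<|start|>system<|message|>{system}<|end|>"
        f"<|start|>user<|message|>{user}<|end|>"
        "<|start|>assistant"
    )
    # The assistant side renders as "{content}<|end|>"; with efficient_eos=True
    # that trailing <|end|> serves as the EOS token during training.
    print(prompt)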
@@ -945,6 +945,21 @@ register_model_group(
 )
 
 
+register_model_group(
+    models={
+        "GPT-OSS-20B-Thinking": {
+            DownloadSource.DEFAULT: "openai/gpt-oss-20b",
+            DownloadSource.MODELSCOPE: "openai/gpt-oss-20b",
+        },
+        "GPT-OSS-120B-Thinking": {
+            DownloadSource.DEFAULT: "openai/gpt-oss-120b",
+            DownloadSource.MODELSCOPE: "openai/gpt-oss-120b",
+        },
+    },
+    template="gpt",
+)
+
+
 register_model_group(
     models={
         "Granite-3.0-1B-A400M-Base": {
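Both entries map a display name to the same repo id on both hubs, and both pick up the "gpt" template registered in the first hunk. A minimal sketch of the lookup this enables (DownloadSource and the resolver below are simplified stand-ins for the framework's own types):

    from enum import Enum

    class DownloadSource(str, Enum):
        DEFAULT = "hf"      # Hugging Face Hub
        MODELSCOPE = "ms"   # ModelScope mirror

    GPT_OSS_MODELS = {
        "GPT-OSS-20B-Thinking": {
            DownloadSource.DEFAULT: "openai/gpt-oss-20b",
            DownloadSource.MODELSCOPE: "openai/gpt-oss-20b",
        },
    }

    def resolve_repo(name: str, source: DownloadSource) -> str:
        """Return the repo id used to download weights for a display name."""
        return GPT_OSS_MODELS[name][source]

    assert resolve_repo("GPT-OSS-20B-Thinking", DownloadSource.DEFAULT) == "openai/gpt-oss-20b"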
@@ -18,7 +18,7 @@
 import gc
 import os
 import socket
-from typing import TYPE_CHECKING, Any, Literal, Union
+from typing import TYPE_CHECKING, Any, Literal, Optional, Union
 
 import torch
 import torch.distributed as dist
@@ -94,7 +94,7 @@ def check_version(requirement: str, mandatory: bool = False) -> None:
 
 def check_dependencies() -> None:
     r"""Check the version of the required packages."""
-    check_version("transformers>=4.49.0,<=4.52.4,!=4.52.0")
+    check_version("transformers>=4.49.0,<=4.55.0")
     check_version("datasets>=2.16.0,<=3.6.0")
     check_version("accelerate>=1.3.0,<=1.7.0")
     check_version("peft>=0.14.0,<=0.15.2")
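The ceiling moves from 4.52.4 (with 4.52.0 excluded) to 4.55.0, the transformers release that adds gpt-oss support. A rough sketch of how such a requirement string can be checked, assuming check_version behaves like a PEP 508 specifier test against the installed package (its real body may differ):

    from importlib.metadata import version
    from packaging.requirements import Requirement

    def satisfies(requirement: str) -> bool:
        """True if the installed distribution matches the version specifier."""
        req = Requirement(requirement)
        # Raises PackageNotFoundError if the package is not installed.
        return req.specifier.contains(version(req.name), prereleases=True)

    print(satisfies("transformers>=4.49.0,<=4.55.0"))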
@@ -211,9 +211,9 @@ def has_tokenized_data(path: "os.PathLike") -> bool:
     return os.path.isdir(path) and len(os.listdir(path)) > 0
 
 
-def infer_optim_dtype(model_dtype: "torch.dtype") -> "torch.dtype":
+def infer_optim_dtype(model_dtype: Optional["torch.dtype"]) -> "torch.dtype":
     r"""Infer the optimal dtype according to the model_dtype and device compatibility."""
-    if _is_bf16_available and model_dtype == torch.bfloat16:
+    if _is_bf16_available and (model_dtype == torch.bfloat16 or model_dtype is None):
         return torch.bfloat16
     elif _is_fp16_available:
         return torch.float16
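Accepting None lets callers fall back to bf16 when a checkpoint's config does not pin a dtype. A self-contained sketch of the resulting selection logic, with public torch probes standing in for the module's _is_bf16_available/_is_fp16_available flags and an assumed float32 fallback for the tail the hunk truncates:

    from typing import Optional

    import torch

    def pick_compute_dtype(model_dtype: Optional["torch.dtype"]) -> "torch.dtype":
        """Prefer bf16 when the checkpoint uses it (or leaves dtype unset)."""
        bf16_ok = torch.cuda.is_available() and torch.cuda.is_bf16_supported()
        fp16_ok = torch.cuda.is_available()
        if bf16_ok and (model_dtype == torch.bfloat16 or model_dtype is None):
            return torch.bfloat16
        elif fp16_ok:
            return torch.float16
        else:
            return torch.float32

    print(pick_compute_dtype(None))  # bfloat16 on bf16-capable GPUs, else fp16/fp32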
@@ -156,10 +156,10 @@ def load_model(
         if model_args.mixture_of_depths == "load":
             model = load_mod_pretrained_model(**init_kwargs)
         else:
-            if type(config) in AutoModelForVision2Seq._model_mapping.keys():  # image-text
-                load_class = AutoModelForVision2Seq
-            elif type(config) in AutoModelForImageTextToText._model_mapping.keys():  # image-text
+            if type(config) in AutoModelForImageTextToText._model_mapping.keys():  # image-text
                 load_class = AutoModelForImageTextToText
+            elif type(config) in AutoModelForVision2Seq._model_mapping.keys():  # image-text
+                load_class = AutoModelForVision2Seq
             elif type(config) in AutoModelForSeq2SeqLM._model_mapping.keys():  # audio-text
                 load_class = AutoModelForSeq2SeqLM
             elif type(config) in AutoModelForTextToWaveform._model_mapping.keys():  # audio hack for qwen2_5_omni
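The reorder checks AutoModelForImageTextToText before AutoModelForVision2Seq, presumably because a config type registered in both auto-class mappings should resolve to the newer ImageTextToText class in recent transformers releases. A hypothetical probe of that overlap, reusing the same private _model_mapping attribute the diff itself relies on (the checkpoint name is only an example and the call fetches its config from the hub):

    from transformers import (
        AutoConfig,
        AutoModelForImageTextToText,
        AutoModelForVision2Seq,
    )

    # If a config type appears in both mappings, whichever branch is checked
    # first wins, so the reorder gives ImageTextToText precedence.
    config = AutoConfig.from_pretrained("llava-hf/llava-1.5-7b-hf")
    in_itt = type(config) in AutoModelForImageTextToText._model_mapping.keys()
    in_v2s = type(config) in AutoModelForVision2Seq._model_mapping.keys()
    print(in_itt, in_v2s)  # may print True True for such overlapping registrations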