fix yi vl vllm infer
Former-commit-id: de54e5d7ec06dd7c20ec82c9ff032fc16cd50244
This commit is contained in:
@@ -16,25 +16,51 @@ if TYPE_CHECKING:
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class LlavaMultiModalProjector(torch.nn.Module):
|
||||
def __init__(self, config: "LlavaConfig"):
|
||||
class LlavaMultiModalProjectorForYiVL(torch.nn.Module):
|
||||
def __init__(self, config: "LlavaConfig") -> None:
|
||||
super().__init__()
|
||||
|
||||
self.config = config
|
||||
if config is None:
|
||||
return
|
||||
|
||||
self.linear_1 = torch.nn.Linear(config.vision_config.hidden_size, config.text_config.hidden_size, bias=True)
|
||||
self.linear_2 = torch.nn.LayerNorm(config.text_config.hidden_size, bias=True)
|
||||
self.linear_3 = torch.nn.Linear(config.text_config.hidden_size, config.text_config.hidden_size, bias=True)
|
||||
self.linear_4 = torch.nn.LayerNorm(config.text_config.hidden_size, bias=True)
|
||||
self.act = ACT2FN[config.projector_hidden_act]
|
||||
|
||||
def forward(self, image_features):
|
||||
def forward(self, image_features: "torch.Tensor") -> "torch.Tensor":
|
||||
hidden_states = self.linear_1(image_features)
|
||||
hidden_states = self.linear_2(hidden_states)
|
||||
hidden_states = self.act(hidden_states)
|
||||
hidden_states = self.linear_3(hidden_states)
|
||||
hidden_states = self.linear_4(hidden_states)
|
||||
if hidden_states.dtype == torch.float32:
|
||||
if torch.is_autocast_enabled():
|
||||
target_dtype = torch.get_autocast_gpu_dtype()
|
||||
elif hasattr(self.config, "_pre_quantization_dtype"):
|
||||
target_dtype = self.config._pre_quantization_dtype
|
||||
else:
|
||||
target_dtype = self.linear_1.weight.dtype
|
||||
|
||||
logger.warning_once("The hidden states seems to be silently casted in float32.")
|
||||
hidden_states = hidden_states.to(target_dtype)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class LlavaMultiModalProjectorForYiVLForVLLM(LlavaMultiModalProjectorForYiVL):
|
||||
def __init__(self, vision_hidden_size: int, text_hidden_size: int, projector_hidden_act: str) -> None:
|
||||
super().__init__(config=None)
|
||||
|
||||
self.linear_1 = torch.nn.Linear(vision_hidden_size, text_hidden_size, bias=True)
|
||||
self.linear_2 = torch.nn.LayerNorm(text_hidden_size, bias=True)
|
||||
self.linear_3 = torch.nn.Linear(text_hidden_size, text_hidden_size, bias=True)
|
||||
self.linear_4 = torch.nn.LayerNorm(text_hidden_size, bias=True)
|
||||
self.act = torch.nn.GELU()
|
||||
|
||||
|
||||
def autocast_projector_dtype(
|
||||
model: "PreTrainedModel", model_args: "ModelArguments", mm_projector_name: str = "multi_modal_projector"
|
||||
) -> None:
|
||||
@@ -53,5 +79,6 @@ def configure_visual_model(config: "PretrainedConfig") -> None:
|
||||
if getattr(config, "model_type", None) == "llava":
|
||||
setattr(config, "hidden_size", getattr(config.text_config, "hidden_size", None))
|
||||
|
||||
if getattr(config, "is_yi_vl_derived_model", None):
|
||||
transformers.models.llava.modeling_llava.LlavaMultiModalProjector = LlavaMultiModalProjector
|
||||
if getattr(config, "is_yi_vl_derived_model", None):
|
||||
logger.info("Detected Yi-VL model, applying projector patch.")
|
||||
transformers.models.llava.modeling_llava.LlavaMultiModalProjector = LlavaMultiModalProjectorForYiVL
|
||||
|
||||
Reference in New Issue
Block a user