[model] add llama4 (#7611)
@@ -79,7 +79,10 @@ def get_custom_gradient_checkpointing_func(gradient_checkpointing_func: Callable
 @wraps(gradient_checkpointing_func, assigned=WRAPPER_ASSIGNMENTS + ("__self__",))
 def custom_gradient_checkpointing_func(func: Callable, *args: Union["torch.Tensor", Any], **kwargs):
-    module: torch.nn.Module = func.__self__
+    if isinstance(func, partial):
+        module: torch.nn.Module = func.func.__self__
+    else:
+        module: torch.nn.Module = func.__self__
 
     has_grad = False
     if any(param.requires_grad for param in module.parameters()):
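
Note: the hunk above handles the case where the checkpointed callable is a functools.partial rather than a plain bound method, since a partial object does not expose __self__ directly. A minimal standalone sketch of the attribute access involved (not part of the patch; the Block class is invented for demonstration):

# Sketch: a bound method exposes its owning module via __self__, but once
# wrapped in functools.partial, the bound method lives on partial.func instead.
from functools import partial

import torch


class Block(torch.nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x


block = Block()
bound = block.forward             # bound method: __self__ is `block`
wrapped = partial(block.forward)  # partial object: no __self__ attribute

assert bound.__self__ is block
assert not hasattr(wrapped, "__self__")
assert wrapped.func.__self__ is block  # unwrap the partial first
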
@@ -203,6 +203,12 @@ _register_composite_model(
 )
 
 
+_register_composite_model(
+    model_type="llama4",
+    vision_model_keys=["vision_model"],
+)
+
+
 _register_composite_model(
     model_type="llava",
 )
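
Note: this hunk registers llama4 so the framework knows its vision tower lives under the vision_model attribute. Below is a hypothetical sketch of the registry pattern such a helper typically implements; the CompositeModel dataclass, the COMPOSITE_MODELS dict, and the "vision_tower" default are illustrative assumptions, not necessarily LLaMA-Factory's actual internals.

# Hypothetical sketch of a composite-model registry; only the llama4
# registration call at the bottom mirrors the diff above.
from dataclasses import dataclass, field
from typing import Optional


@dataclass
class CompositeModel:
    model_type: str
    vision_model_keys: list[str] = field(default_factory=lambda: ["vision_tower"])


COMPOSITE_MODELS: dict[str, CompositeModel] = {}


def _register_composite_model(
    model_type: str, vision_model_keys: Optional[list[str]] = None
) -> None:
    # Record which submodule names hold the vision tower for this model type,
    # so later code can freeze or cast them independently of the language model.
    COMPOSITE_MODELS[model_type] = CompositeModel(
        model_type=model_type,
        vision_model_keys=vision_model_keys or ["vision_tower"],
    )


_register_composite_model(model_type="llama4", vision_model_keys=["vision_model"])
print(COMPOSITE_MODELS["llama4"].vision_model_keys)  # ['vision_model']
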