[model] add llama4 (#7611)

This commit is contained in:
hoshi-hiyouga
2025-04-06 13:42:31 +08:00
committed by GitHub
parent d4cfa9507e
commit 831e7f1cfd
11 changed files with 167 additions and 8 deletions

View File

@@ -79,7 +79,10 @@ def get_custom_gradient_checkpointing_func(gradient_checkpointing_func: Callable
@wraps(gradient_checkpointing_func, assigned=WRAPPER_ASSIGNMENTS + ("__self__",))
def custom_gradient_checkpointing_func(func: Callable, *args: Union["torch.Tensor", Any], **kwargs):
module: torch.nn.Module = func.__self__
if isinstance(func, partial):
module: torch.nn.Module = func.func.__self__
else:
module: torch.nn.Module = func.__self__
has_grad = False
if any(param.requires_grad for param in module.parameters()):

View File

@@ -203,6 +203,12 @@ _register_composite_model(
)
_register_composite_model(
model_type="llama4",
vision_model_keys=["vision_model"],
)
_register_composite_model(
model_type="llava",
)