refactor mllm param logic
Former-commit-id: b895c190945cf5d991cb4e4dea2ae73cc9c8d246
This commit is contained in:
@@ -14,6 +14,7 @@
|
||||
|
||||
import os
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from llamafactory.extras.misc import get_current_device
|
||||
@@ -39,16 +40,11 @@ TRAIN_ARGS = {
|
||||
}
|
||||
|
||||
|
||||
def test_checkpointing_enable():
|
||||
model = load_train_model(disable_gradient_checkpointing=False, **TRAIN_ARGS)
|
||||
@pytest.mark.parametrize("disable_gradient_checkpointing", [False, True])
|
||||
def test_vanilla_checkpointing(disable_gradient_checkpointing: bool):
|
||||
model = load_train_model(disable_gradient_checkpointing=disable_gradient_checkpointing, **TRAIN_ARGS)
|
||||
for module in filter(lambda m: hasattr(m, "gradient_checkpointing"), model.modules()):
|
||||
assert getattr(module, "gradient_checkpointing") is True
|
||||
|
||||
|
||||
def test_checkpointing_disable():
|
||||
model = load_train_model(disable_gradient_checkpointing=True, **TRAIN_ARGS)
|
||||
for module in filter(lambda m: hasattr(m, "gradient_checkpointing"), model.modules()):
|
||||
assert getattr(module, "gradient_checkpointing") is False
|
||||
assert getattr(module, "gradient_checkpointing") != disable_gradient_checkpointing
|
||||
|
||||
|
||||
def test_unsloth_gradient_checkpointing():
|
||||
|
||||
Reference in New Issue
Block a user