refactor mllm param logic

Former-commit-id: b895c190945cf5d991cb4e4dea2ae73cc9c8d246
Author: hiyouga
Date: 2025-01-10 15:41:54 +00:00
Parent: 1675712a4c
Commit: dc65ecdf09
10 changed files with 198 additions and 62 deletions


@@ -14,6 +14,7 @@
 import os
 import pytest
 import torch
+from llamafactory.extras.misc import get_current_device
@@ -39,16 +40,11 @@ TRAIN_ARGS = {
 }
 
 
-def test_checkpointing_enable():
-    model = load_train_model(disable_gradient_checkpointing=False, **TRAIN_ARGS)
+@pytest.mark.parametrize("disable_gradient_checkpointing", [False, True])
+def test_vanilla_checkpointing(disable_gradient_checkpointing: bool):
+    model = load_train_model(disable_gradient_checkpointing=disable_gradient_checkpointing, **TRAIN_ARGS)
     for module in filter(lambda m: hasattr(m, "gradient_checkpointing"), model.modules()):
-        assert getattr(module, "gradient_checkpointing") is True
-
-
-def test_checkpointing_disable():
-    model = load_train_model(disable_gradient_checkpointing=True, **TRAIN_ARGS)
-    for module in filter(lambda m: hasattr(m, "gradient_checkpointing"), model.modules()):
-        assert getattr(module, "gradient_checkpointing") is False
+        assert getattr(module, "gradient_checkpointing") != disable_gradient_checkpointing
 
 
 def test_unsloth_gradient_checkpointing():
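
Note: the hunk above collapses two near-duplicate tests into one parametrized test, asserting that each module's `gradient_checkpointing` flag is the inverse of the `disable_gradient_checkpointing` switch. A minimal, self-contained sketch of the same pytest.mark.parametrize pattern follows; it uses a hypothetical stub model instead of the repo's `load_train_model` helper, so names like `_load_stub_model` are illustrative only.

import pytest
import torch


class _TinyBlock(torch.nn.Module):
    """Stand-in for a transformer block exposing a `gradient_checkpointing` flag."""

    def __init__(self, enabled: bool) -> None:
        super().__init__()
        self.gradient_checkpointing = enabled


def _load_stub_model(disable_gradient_checkpointing: bool) -> torch.nn.Module:
    # Hypothetical replacement for load_train_model: builds a module whose block
    # records whether gradient checkpointing was left enabled.
    return torch.nn.Sequential(_TinyBlock(not disable_gradient_checkpointing))


@pytest.mark.parametrize("disable_gradient_checkpointing", [False, True])
def test_vanilla_checkpointing_pattern(disable_gradient_checkpointing: bool):
    model = _load_stub_model(disable_gradient_checkpointing)
    for module in filter(lambda m: hasattr(m, "gradient_checkpointing"), model.modules()):
        # One assertion covers both cases: the flag must be the inverse of the disable switch.
        assert module.gradient_checkpointing != disable_gradient_checkpointing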