add test case

Former-commit-id: c452d65e1551074dddd1d87517c0d44dc014c6aa
This commit is contained in:
hiyouga
2024-09-08 01:40:49 +08:00
parent 294a103ead
commit 158e0e1f63
2 changed files with 28 additions and 12 deletions

View File

@@ -51,6 +51,12 @@ def test_checkpointing_disable():
assert getattr(module, "gradient_checkpointing") is False
def test_unsloth_gradient_checkpointing():
model = load_train_model(use_unsloth_gc=True, **TRAIN_ARGS)
for module in filter(lambda m: hasattr(m, "gradient_checkpointing"), model.modules()):
assert module._gradient_checkpointing_func.__self__.__name__ == "UnslothGradientCheckpointing"
def test_upcast_layernorm():
model = load_train_model(upcast_layernorm=True, **TRAIN_ARGS)
for name, param in model.named_parameters():