support activation offloading via unsloth gc

Former-commit-id: d3d0dd0feba3ca6f0ae970d5856bec989d26ef67
This commit is contained in:
hiyouga
2024-09-08 01:22:19 +08:00
parent 7f71276ad8
commit 294a103ead
3 changed files with 58 additions and 7 deletions

View File

@@ -109,6 +109,7 @@ def calculate_mfu(
deepspeed_stage: int = 0,
disable_gc: bool = False,
liger_kernel: bool = False,
unsloth_gc: bool = False,
) -> float:
r"""
Calculates MFU for given model and hyper-params.
@@ -119,6 +120,7 @@ def calculate_mfu(
"flash_attn": flash_attn,
"disable_gradient_checkpointing": disable_gc,
"enable_liger_kernel": liger_kernel,
"use_unsloth_gc": unsloth_gc,
"stage": "pt",
"do_train": True,
"finetuning_type": finetuning_type,