[fix] fp8: add Transformer Engine backend support (#9705)

Author: Santosh Bhavani
Co-authored-by: Yaowei Zheng <hiyouga@buaa.edu.cn>
Committed by: GitHub
Date: 2025-12-31 18:18:02 -08:00
parent 6fe6bd290b
commit 355d5c5e5a
6 changed files with 128 additions and 70 deletions
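For context, a minimal sketch of how the helpers changed below are meant to fit together. This is an assumed flow: the actual call sites live in the other changed files, which are not shown in this view.

    # Assumed call order, based only on the functions in this file:
    # 1) configure_fp8_environment() exports the Accelerate env vars,
    # 2) patch_accelerator_for_fp8() injects the TE recipe when backend == "te",
    # 3) verify_fp8_status() later logs whether Accelerate actually enabled FP8.
    configure_fp8_environment(training_args)
    if getattr(training_args, "fp8_backend", "auto") == "te":
        patch_accelerator_for_fp8()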


@@ -12,35 +12,45 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import os
+import types
 from typing import TYPE_CHECKING, Any, Optional
 
 from ..extras import logging
 
 
 if TYPE_CHECKING:
-    from ..hparams import ModelArguments
+    from ..hparams import TrainingArguments
 
 
 logger = logging.get_logger(__name__)
 
 
-def create_fp8_kwargs(model_args: "ModelArguments") -> list[Any]:
+def create_fp8_kwargs(training_args: "TrainingArguments") -> list[Any]:
     """Create AORecipeKwargs for FP8 training with HuggingFace Accelerate.
 
     Args:
-        model_args: Model arguments containing FP8 configuration
+        training_args: Training arguments containing FP8 configuration
 
     Returns:
         List containing AORecipeKwargs if FP8 is enabled and supported, empty list otherwise
     """
-    if not model_args.fp8:
+    if not training_args.fp8:
         return []
 
-    try:
-        # Check if AORecipeKwargs is available (Accelerate 1.8.0+)
-        from accelerate.utils import AORecipeKwargs
+    backend = getattr(training_args, "fp8_backend", "auto")
+    logger.info_rank0(f"Creating FP8 configuration with backend: {backend}")
 
-        backend = getattr(model_args, "fp8_backend", "auto")
-        logger.info_rank0(f"Creating FP8 configuration with backend: {backend}")
+    try:
+        # Use Transformer Engine backend (optimal for Hopper GPUs)
+        if backend == "te":
+            from accelerate.utils import FP8RecipeKwargs
+
+            logger.info_rank0("Using Transformer Engine FP8 backend")
+            return [FP8RecipeKwargs(backend="TE", fp8_format="HYBRID", amax_history_len=16, amax_compute_algo="max")]
+
+        # Use TorchAO backend (default)
+        from accelerate.utils import AORecipeKwargs
+
         # Create Float8LinearConfig if torchao backend is used
         config = None
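For reference, the handler list built above is meant to be consumed by Accelerate. A minimal sketch of that hand-off using the standard Accelerate API (assuming accelerate and Transformer Engine are installed; this snippet is not part of the patch):

    from accelerate import Accelerator
    from accelerate.utils import FP8RecipeKwargs

    # Accelerator applies the FP8 recipe when preparing the model.
    handlers = [FP8RecipeKwargs(backend="TE", fp8_format="HYBRID", amax_history_len=16, amax_compute_algo="max")]
    accelerator = Accelerator(mixed_precision="fp8", kwargs_handlers=handlers)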
@@ -83,7 +93,7 @@ def create_fp8_kwargs(model_args: "ModelArguments") -> list[Any]:
             return True
 
         # Map FSDP all-gather setting if available (this affects the underlying implementation)
-        if hasattr(model_args, "fp8_enable_fsdp_float8_all_gather") and model_args.fp8_enable_fsdp_float8_all_gather:
+        if hasattr(training_args, "fp8_enable_fsdp_float8_all_gather") and training_args.fp8_enable_fsdp_float8_all_gather:
             logger.info_rank0("FSDP float8 all-gather optimization requested")
 
         return [AORecipeKwargs(config=config, module_filter_func=module_filter_func)]
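The hunk above only shows the tail of `module_filter_func` (its `return True`). A sketch of the typical shape of such a filter for TorchAO float8, whose kernels require dimensions divisible by 16; the exact checks in this file fall outside the hunk, so the body below is an assumption:

    import torch.nn as nn

    def module_filter_func(module: nn.Module, fqn: str) -> bool:
        # Convert only Linear layers with float8-compatible shapes; skip the LM head.
        if not isinstance(module, nn.Linear) or "lm_head" in fqn:
            return False
        return module.in_features % 16 == 0 and module.out_features % 16 == 0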
@@ -92,19 +102,19 @@ def create_fp8_kwargs(model_args: "ModelArguments") -> list[Any]:
         return []
 
 
-def get_fp8_mixed_precision(model_args: "ModelArguments") -> Optional[str]:
+def get_fp8_mixed_precision(training_args: "TrainingArguments") -> Optional[str]:
     """Get the mixed precision setting for Accelerate when using FP8.
 
     Args:
-        model_args: Model arguments containing FP8 configuration
+        training_args: Training arguments containing FP8 configuration
 
     Returns:
         "fp8" if FP8 is enabled, None otherwise
     """
-    return "fp8" if model_args.fp8 else None
+    return "fp8" if training_args.fp8 else None
 
 
-def configure_fp8_environment(model_args: "ModelArguments") -> None:
+def configure_fp8_environment(training_args: "TrainingArguments") -> None:
     """Configure FP8 environment for HuggingFace Accelerate.
 
     FP8 training is handled entirely through HuggingFace Accelerate, regardless of whether
@@ -112,11 +122,9 @@ def configure_fp8_environment(model_args: "ModelArguments") -> None:
     variables and validates the FP8 configuration.
 
     Args:
-        model_args: Model arguments containing FP8 configuration
+        training_args: Training arguments containing FP8 configuration
     """
-    import os
-
-    if not model_args.fp8:
+    if not training_args.fp8:
         return
 
     # Set mixed precision to fp8 for HuggingFace Accelerate
@@ -124,38 +132,38 @@ def configure_fp8_environment(model_args: "ModelArguments") -> None:
     logger.info_rank0("Set ACCELERATE_MIXED_PRECISION=fp8")
 
     # Configure FP8 backend and options
-    backend = getattr(model_args, "fp8_backend", "auto")
+    backend = getattr(training_args, "fp8_backend", "auto")
     if backend != "auto":
         os.environ["FP8_BACKEND"] = backend
         logger.info_rank0(f"Set FP8_BACKEND={backend}")
 
     # Create and validate FP8 recipe kwargs (for logging/debugging)
-    fp8_kwargs = create_fp8_kwargs(model_args)
+    fp8_kwargs = create_fp8_kwargs(training_args)
     logger.info_rank0(f"FP8 AORecipeKwargs created: {len(fp8_kwargs)} items")
 
     # Enable FSDP float8 all-gather optimization if requested
-    if hasattr(model_args, "fp8_enable_fsdp_float8_all_gather") and model_args.fp8_enable_fsdp_float8_all_gather:
+    if hasattr(training_args, "fp8_enable_fsdp_float8_all_gather") and training_args.fp8_enable_fsdp_float8_all_gather:
         os.environ["FP8_ENABLE_FSDP_FLOAT8_ALL_GATHER"] = "true"
         logger.info_rank0("Set FP8_ENABLE_FSDP_FLOAT8_ALL_GATHER=true")
 
     logger.info_rank0("FP8 environment configured - all FP8 training handled by HuggingFace Accelerate")
 
 
-def verify_fp8_status(accelerator, model_args: "ModelArguments") -> None:
+def verify_fp8_status(accelerator, training_args: "TrainingArguments") -> None:
     """Verify that FP8 training is actually working after model preparation.
 
     Args:
         accelerator: The HuggingFace Accelerator instance
-        model_args: Model arguments containing FP8 configuration
+        training_args: Training arguments containing FP8 configuration
     """
-    if not model_args.fp8:
+    if not training_args.fp8:
         return
 
     # Check Accelerate's FP8 status
     fp8_enabled = getattr(accelerator, "fp8_enabled", False)
     fp8_backend_type = getattr(accelerator, "fp8_backend", "UNKNOWN")
 
-    backend = getattr(model_args, "fp8_backend", "auto")
+    backend = getattr(training_args, "fp8_backend", "auto")
     if backend == "torchao" or backend == "auto":
         logger.info_rank0(
             "FP8 training enabled with TorchAO backend. For optimal performance, "
@@ -169,3 +177,50 @@ def verify_fp8_status(accelerator, model_args: "ModelArguments") -> None:
 
     if not fp8_enabled:
         logger.info_rank0("WARNING: FP8 was requested but Accelerate shows fp8_enabled=False. FP8 may not be working.")
+
+
+def patch_accelerator_for_fp8() -> None:
+    """Patch Accelerator to inject FP8 recipe kwargs.
+
+    This is needed because HuggingFace Trainer doesn't pass kwargs_handlers to Accelerator.
+    We monkey-patch Accelerator.__init__ to inject the FP8 recipe and force mixed_precision='fp8'.
+    """
+    import transformer_engine.pytorch as te
+    from accelerate import Accelerator
+
+    # Guard against multiple patches
+    if getattr(Accelerator, "_te_fp8_patched", False):
+        return
+
+    # Stub for Accelerate 1.12+ compatibility (te.fp8.check_mxfp8_support doesn't exist yet)
+    if not hasattr(te, "fp8"):
+        te.fp8 = types.ModuleType("fp8")
+        te.fp8.check_mxfp8_support = lambda: (False, "MXFP8 not supported")
+
+    try:
+        from accelerate.utils import TERecipeKwargs as FP8Recipe
+
+        use_te_recipe = True
+    except ImportError:
+        from accelerate.utils import FP8RecipeKwargs as FP8Recipe
+
+        use_te_recipe = False
+
+    original_init = Accelerator.__init__
+
+    def patched_init(self, *args, **kwargs):
+        if "kwargs_handlers" not in kwargs or not kwargs["kwargs_handlers"]:
+            if use_te_recipe:
+                kwargs["kwargs_handlers"] = [
+                    FP8Recipe(fp8_format="HYBRID", amax_history_len=16, amax_compute_algo="max")
+                ]
+            else:
+                kwargs["kwargs_handlers"] = [
+                    FP8Recipe(backend="TE", fp8_format="HYBRID", amax_history_len=16, amax_compute_algo="max")
+                ]
+            # Only force mixed_precision when we inject handlers
+            kwargs["mixed_precision"] = "fp8"
+
+        return original_init(self, *args, **kwargs)
+
+    Accelerator.__init__ = patched_init
+    Accelerator._te_fp8_patched = True
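Finally, a minimal sketch of how the new monkey-patch is intended to take effect. HuggingFace Trainer builds its own Accelerator internally, so the patch must run before the Trainer is constructed; the real call site is in one of the other changed files, so this ordering is an assumption:

    from transformers import Trainer

    # Patch first, so Trainer's internal Accelerator(...) call goes through
    # patched_init and receives the TE recipe kwargs.
    if getattr(training_args, "fp8", False) and getattr(training_args, "fp8_backend", "auto") == "te":
        patch_accelerator_for_fp8()

    trainer = Trainer(model=model, args=training_args, train_dataset=train_dataset)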