[feat] fp8 training (#8960)

Co-authored-by: Benjamin Feuer <penfever@gmail.com>
Co-authored-by: Yaowei Zheng <hiyouga@buaa.edu.cn>
This commit is contained in:
Ben Feuer
2025-09-30 23:32:53 -07:00
committed by GitHub
parent e2b1594d31
commit 1c44b60e3e
8 changed files with 322 additions and 3 deletions

View File

@@ -213,6 +213,23 @@ class QuantizationArguments:
default=None,
metadata={"help": "Device map used to infer the 4-bit quantized model, needs bitsandbytes>=0.43.0."},
)
fp8: bool = field(
default=False,
metadata={
"help": "Enable FP8 mixed precision training via HuggingFace Accelerate. "
"Requires PyTorch 2.7+ and Hopper architecture GPUs."
},
)
fp8_backend: str = field(
default="auto",
metadata={
"help": "FP8 backend to use ('auto', 'torchao', 'te', 'msamp'). 'auto' selects best available backend."
},
)
fp8_enable_fsdp_float8_all_gather: bool = field(
default=False,
metadata={"help": "Enable FP8 optimizations for FSDP2 all-gather operations."},
)
@dataclass

View File

@@ -131,6 +131,14 @@ def _verify_model_args(
logger.warning_rank0("We should use slow tokenizer for the Yi models. Change `use_fast_tokenizer` to False.")
model_args.use_fast_tokenizer = False
# Validate advanced training features
if model_args.fp8 and model_args.quantization_bit is not None:
raise ValueError("FP8 training is not compatible with quantization. Please disable one of them.")
if model_args.fp8_enable_fsdp_float8_all_gather and not model_args.fp8:
logger.warning_rank0("fp8_enable_fsdp_float8_all_gather requires fp8=True. Setting fp8=True.")
model_args.fp8 = True
def _check_extra_dependencies(
model_args: "ModelArguments",