mirror of
https://github.com/hiyouga/LlamaFactory.git
synced 2026-03-22 12:23:08 +00:00
[fix] fp8: add Transformer Engine backend support (#9705)
Co-authored-by: Yaowei Zheng <hiyouga@buaa.edu.cn>
This commit is contained in:
@@ -298,23 +298,6 @@ class QuantizationArguments:
|
||||
default=None,
|
||||
metadata={"help": "Device map used to infer the 4-bit quantized model, needs bitsandbytes>=0.43.0."},
|
||||
)
|
||||
fp8: bool = field(
|
||||
default=False,
|
||||
metadata={
|
||||
"help": "Enable FP8 mixed precision training via HuggingFace Accelerate. "
|
||||
"Requires PyTorch 2.7+ and Hopper architecture GPUs."
|
||||
},
|
||||
)
|
||||
fp8_backend: str = field(
|
||||
default="auto",
|
||||
metadata={
|
||||
"help": "FP8 backend to use ('auto', 'torchao', 'te', 'msamp'). 'auto' selects best available backend."
|
||||
},
|
||||
)
|
||||
fp8_enable_fsdp_float8_all_gather: bool = field(
|
||||
default=False,
|
||||
metadata={"help": "Enable FP8 optimizations for FSDP2 all-gather operations."},
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
|
||||
@@ -142,15 +142,6 @@ def _verify_model_args(
|
||||
logger.warning_rank0("We should use slow tokenizer for the Yi models. Change `use_fast_tokenizer` to False.")
|
||||
model_args.use_fast_tokenizer = False
|
||||
|
||||
# Validate advanced training features
|
||||
if model_args.fp8 and model_args.quantization_bit is not None:
|
||||
raise ValueError("FP8 training is not compatible with quantization. Please disable one of them.")
|
||||
|
||||
if model_args.fp8_enable_fsdp_float8_all_gather and not model_args.fp8:
|
||||
logger.warning_rank0("fp8_enable_fsdp_float8_all_gather requires fp8=True. Setting fp8=True.")
|
||||
model_args.fp8 = True
|
||||
|
||||
|
||||
def _check_extra_dependencies(
|
||||
model_args: "ModelArguments",
|
||||
finetuning_args: "FinetuningArguments",
|
||||
@@ -347,6 +338,9 @@ def get_train_args(args: dict[str, Any] | list[str] | None = None) -> _TRAIN_CLS
|
||||
if training_args.deepspeed is not None and (finetuning_args.use_galore or finetuning_args.use_apollo):
|
||||
raise ValueError("GaLore and APOLLO are incompatible with DeepSpeed yet.")
|
||||
|
||||
if training_args.fp8 and training_args.quantization_bit is not None:
|
||||
raise ValueError("FP8 training is not compatible with quantization. Please disable one of them.")
|
||||
|
||||
if model_args.infer_backend != EngineName.HF:
|
||||
raise ValueError("vLLM/SGLang backend is only available for API, CLI and Web.")
|
||||
|
||||
@@ -363,6 +357,10 @@ def get_train_args(args: dict[str, Any] | list[str] | None = None) -> _TRAIN_CLS
|
||||
_verify_model_args(model_args, data_args, finetuning_args)
|
||||
_check_extra_dependencies(model_args, finetuning_args, training_args)
|
||||
|
||||
if training_args.fp8_enable_fsdp_float8_all_gather and not training_args.fp8:
|
||||
logger.warning_rank0("fp8_enable_fsdp_float8_all_gather requires fp8=True. Setting fp8=True.")
|
||||
model_args.fp8 = True
|
||||
|
||||
if (
|
||||
training_args.do_train
|
||||
and finetuning_args.finetuning_type == "lora"
|
||||
|
||||
@@ -92,7 +92,29 @@ class RayArguments:
|
||||
|
||||
|
||||
@dataclass
|
||||
class TrainingArguments(RayArguments, BaseTrainingArguments):
|
||||
class Fp8Arguments:
|
||||
r"""Arguments pertaining to the FP8 training."""
|
||||
fp8: bool = field(
|
||||
default=False,
|
||||
metadata={
|
||||
"help": "Enable FP8 mixed precision training via HuggingFace Accelerate. "
|
||||
"Requires PyTorch 2.7+ and Hopper architecture GPUs."
|
||||
},
|
||||
)
|
||||
fp8_backend: str = field(
|
||||
default="auto",
|
||||
metadata={
|
||||
"help": "FP8 backend to use ('auto', 'torchao', 'te', 'msamp'). 'auto' selects best available backend."
|
||||
},
|
||||
)
|
||||
fp8_enable_fsdp_float8_all_gather: bool = field(
|
||||
default=False,
|
||||
metadata={"help": "Enable FP8 optimizations for FSDP2 all-gather operations."},
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class TrainingArguments(Fp8Arguments, RayArguments, BaseTrainingArguments):
|
||||
r"""Arguments pertaining to the trainer."""
|
||||
|
||||
overwrite_output_dir: bool = field(
|
||||
|
||||
Reference in New Issue
Block a user