From 92fa3df4c4e4e338ecc1cbdc6530b118337286e3 Mon Sep 17 00:00:00 2001
From: Username_Full
Date: Wed, 4 Feb 2026 23:27:12 +0800
Subject: [PATCH] [trainer] add dpo/kto fsdp fsdp2 support (#10127)

---
 src/llamafactory/train/dpo/trainer.py | 9 ++++++++-
 src/llamafactory/train/kto/trainer.py | 9 ++++++++-
 2 files changed, 16 insertions(+), 2 deletions(-)

diff --git a/src/llamafactory/train/dpo/trainer.py b/src/llamafactory/train/dpo/trainer.py
index 7780e20ee..f0ecdba2c 100644
--- a/src/llamafactory/train/dpo/trainer.py
+++ b/src/llamafactory/train/dpo/trainer.py
@@ -25,8 +25,8 @@ import torch
 import torch.nn.functional as F
 from transformers import Trainer
 from trl import DPOTrainer
+from trl.models.utils import prepare_deepspeed, prepare_fsdp
 from trl.trainer import disable_dropout_in_model
-from trl.trainer.utils import prepare_deepspeed
 from typing_extensions import override
 
 from ...extras.constants import IGNORE_INDEX
@@ -97,6 +97,13 @@ class CustomDPOTrainer(DPOTrainer):
                     getattr(ref_model, "is_loaded_in_8bit", False) or getattr(ref_model, "is_loaded_in_4bit", False)
                 ):  # quantized models are already set on the correct device
                     self.ref_model = prepare_deepspeed(self.ref_model, self.accelerator)
+            elif self.is_fsdp_enabled:
+                if self.accelerator.is_fsdp2:
+                    from accelerate.utils.fsdp_utils import fsdp2_prepare_model
+
+                    self.ref_model = fsdp2_prepare_model(self.accelerator, self.ref_model)
+                else:
+                    self.ref_model = prepare_fsdp(self.ref_model, self.accelerator)
             else:
                 self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True)
                 self.ref_model.eval()
diff --git a/src/llamafactory/train/kto/trainer.py b/src/llamafactory/train/kto/trainer.py
index eea92f3be..1d679821f 100644
--- a/src/llamafactory/train/kto/trainer.py
+++ b/src/llamafactory/train/kto/trainer.py
@@ -24,8 +24,8 @@ from typing import TYPE_CHECKING, Literal, Optional, Union
 import torch
 from transformers import Trainer
 from trl import KTOTrainer
+from trl.models.utils import prepare_deepspeed, prepare_fsdp
 from trl.trainer import disable_dropout_in_model
-from trl.trainer.utils import prepare_deepspeed
 from typing_extensions import override
 
 from ...extras.constants import IGNORE_INDEX
@@ -99,6 +99,13 @@ class CustomKTOTrainer(KTOTrainer):
                     getattr(ref_model, "is_loaded_in_8bit", False) or getattr(ref_model, "is_loaded_in_4bit", False)
                 ):  # quantized models are already set on the correct device
                     self.ref_model = prepare_deepspeed(self.ref_model, self.accelerator)
+            elif self.is_fsdp_enabled:
+                if self.accelerator.is_fsdp2:
+                    from accelerate.utils.fsdp_utils import fsdp2_prepare_model
+
+                    self.ref_model = fsdp2_prepare_model(self.accelerator, self.ref_model)
+                else:
+                    self.ref_model = prepare_fsdp(self.ref_model, self.accelerator)
             else:
                 self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True)
                 self.ref_model.eval()
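
Reviewer note (not part of the patch): the new `elif self.is_fsdp_enabled` branch
dispatches on `accelerator.is_fsdp2`, which the patch itself relies on. A minimal
sketch of driver setup that would exercise the FSDP2 path, assuming an accelerate
release whose FullyShardedDataParallelPlugin accepts `fsdp_version` (as in recent
releases with FSDP2 support; treat the exact flag as an assumption):

    from accelerate import Accelerator, FullyShardedDataParallelPlugin

    # Must run under `accelerate launch` with FSDP enabled; outside a
    # distributed launch the plugin is inert and is_fsdp2 stays False.
    fsdp_plugin = FullyShardedDataParallelPlugin(fsdp_version=2)
    accelerator = Accelerator(fsdp_plugin=fsdp_plugin)
    print(accelerator.is_fsdp2)  # True -> the fsdp2_prepare_model branch runs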
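
Usage sketch (hypothetical, for context only): once wrapped by prepare_fsdp or
fsdp2_prepare_model, the frozen reference model is only queried for log-probs
during DPO/KTO loss computation, which is why the trainers above put it in eval
mode and never give it gradients. The helper name below is illustrative:

    import torch

    @torch.no_grad()
    def reference_logits(ref_model, batch):
        # ref_model is the FSDP/FSDP2-wrapped reference model; it stays
        # frozen, so its forward pass runs without autograd bookkeeping.
        ref_model.eval()
        return ref_model(**batch).logits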