[trainer] add dpo/kto fsdp fsdp2 support (#10127)

This commit is contained in:
Username_Full
2026-02-04 23:27:12 +08:00
committed by GitHub
parent 8bedfafa4e
commit 92fa3df4c4
2 changed files with 16 additions and 2 deletions

View File

@@ -25,8 +25,8 @@ import torch
import torch.nn.functional as F
from transformers import Trainer
from trl import DPOTrainer
from trl.models.utils import prepare_deepspeed, prepare_fsdp
from trl.trainer import disable_dropout_in_model
from trl.trainer.utils import prepare_deepspeed
from typing_extensions import override
from ...extras.constants import IGNORE_INDEX
@@ -97,6 +97,13 @@ class CustomDPOTrainer(DPOTrainer):
getattr(ref_model, "is_loaded_in_8bit", False) or getattr(ref_model, "is_loaded_in_4bit", False)
): # quantized models are already set on the correct device
self.ref_model = prepare_deepspeed(self.ref_model, self.accelerator)
elif self.is_fsdp_enabled:
if self.accelerator.is_fsdp2:
from accelerate.utils.fsdp_utils import fsdp2_prepare_model
self.ref_model = fsdp2_prepare_model(self.accelerator, self.ref_model)
else:
self.ref_model = prepare_fsdp(self.ref_model, self.accelerator)
else:
self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True)
self.ref_model.eval()

View File

@@ -24,8 +24,8 @@ from typing import TYPE_CHECKING, Literal, Optional, Union
import torch
from transformers import Trainer
from trl import KTOTrainer
from trl.models.utils import prepare_deepspeed, prepare_fsdp
from trl.trainer import disable_dropout_in_model
from trl.trainer.utils import prepare_deepspeed
from typing_extensions import override
from ...extras.constants import IGNORE_INDEX
@@ -99,6 +99,13 @@ class CustomKTOTrainer(KTOTrainer):
getattr(ref_model, "is_loaded_in_8bit", False) or getattr(ref_model, "is_loaded_in_4bit", False)
): # quantized models are already set on the correct device
self.ref_model = prepare_deepspeed(self.ref_model, self.accelerator)
elif self.is_fsdp_enabled:
if self.accelerator.is_fsdp2:
from accelerate.utils.fsdp_utils import fsdp2_prepare_model
self.ref_model = fsdp2_prepare_model(self.accelerator, self.ref_model)
else:
self.ref_model = prepare_fsdp(self.ref_model, self.accelerator)
else:
self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True)
self.ref_model.eval()