mirror of
https://github.com/hiyouga/LlamaFactory.git
synced 2026-02-05 09:33:09 +00:00
[trainer] add dpo/kto fsdp fsdp2 support (#10127)
This commit is contained in:
@@ -25,8 +25,8 @@ import torch
|
|||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
from transformers import Trainer
|
from transformers import Trainer
|
||||||
from trl import DPOTrainer
|
from trl import DPOTrainer
|
||||||
|
from trl.models.utils import prepare_deepspeed, prepare_fsdp
|
||||||
from trl.trainer import disable_dropout_in_model
|
from trl.trainer import disable_dropout_in_model
|
||||||
from trl.trainer.utils import prepare_deepspeed
|
|
||||||
from typing_extensions import override
|
from typing_extensions import override
|
||||||
|
|
||||||
from ...extras.constants import IGNORE_INDEX
|
from ...extras.constants import IGNORE_INDEX
|
||||||
@@ -97,6 +97,13 @@ class CustomDPOTrainer(DPOTrainer):
|
|||||||
getattr(ref_model, "is_loaded_in_8bit", False) or getattr(ref_model, "is_loaded_in_4bit", False)
|
getattr(ref_model, "is_loaded_in_8bit", False) or getattr(ref_model, "is_loaded_in_4bit", False)
|
||||||
): # quantized models are already set on the correct device
|
): # quantized models are already set on the correct device
|
||||||
self.ref_model = prepare_deepspeed(self.ref_model, self.accelerator)
|
self.ref_model = prepare_deepspeed(self.ref_model, self.accelerator)
|
||||||
|
elif self.is_fsdp_enabled:
|
||||||
|
if self.accelerator.is_fsdp2:
|
||||||
|
from accelerate.utils.fsdp_utils import fsdp2_prepare_model
|
||||||
|
|
||||||
|
self.ref_model = fsdp2_prepare_model(self.accelerator, self.ref_model)
|
||||||
|
else:
|
||||||
|
self.ref_model = prepare_fsdp(self.ref_model, self.accelerator)
|
||||||
else:
|
else:
|
||||||
self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True)
|
self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True)
|
||||||
self.ref_model.eval()
|
self.ref_model.eval()
|
||||||
|
|||||||
@@ -24,8 +24,8 @@ from typing import TYPE_CHECKING, Literal, Optional, Union
|
|||||||
import torch
|
import torch
|
||||||
from transformers import Trainer
|
from transformers import Trainer
|
||||||
from trl import KTOTrainer
|
from trl import KTOTrainer
|
||||||
|
from trl.models.utils import prepare_deepspeed, prepare_fsdp
|
||||||
from trl.trainer import disable_dropout_in_model
|
from trl.trainer import disable_dropout_in_model
|
||||||
from trl.trainer.utils import prepare_deepspeed
|
|
||||||
from typing_extensions import override
|
from typing_extensions import override
|
||||||
|
|
||||||
from ...extras.constants import IGNORE_INDEX
|
from ...extras.constants import IGNORE_INDEX
|
||||||
@@ -99,6 +99,13 @@ class CustomKTOTrainer(KTOTrainer):
|
|||||||
getattr(ref_model, "is_loaded_in_8bit", False) or getattr(ref_model, "is_loaded_in_4bit", False)
|
getattr(ref_model, "is_loaded_in_8bit", False) or getattr(ref_model, "is_loaded_in_4bit", False)
|
||||||
): # quantized models are already set on the correct device
|
): # quantized models are already set on the correct device
|
||||||
self.ref_model = prepare_deepspeed(self.ref_model, self.accelerator)
|
self.ref_model = prepare_deepspeed(self.ref_model, self.accelerator)
|
||||||
|
elif self.is_fsdp_enabled:
|
||||||
|
if self.accelerator.is_fsdp2:
|
||||||
|
from accelerate.utils.fsdp_utils import fsdp2_prepare_model
|
||||||
|
|
||||||
|
self.ref_model = fsdp2_prepare_model(self.accelerator, self.ref_model)
|
||||||
|
else:
|
||||||
|
self.ref_model = prepare_fsdp(self.ref_model, self.accelerator)
|
||||||
else:
|
else:
|
||||||
self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True)
|
self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True)
|
||||||
self.ref_model.eval()
|
self.ref_model.eval()
|
||||||
|
|||||||
Reference in New Issue
Block a user