Remove gradient checkpointing (gc) warnings in DPO & KTO

Former-commit-id: b649bdcbafb464a638387429b770fe258b41f8af
This commit is contained in:
hiyouga
2024-06-03 22:53:54 +08:00
parent 0655a183d3
commit 4c1f015eca
3 changed files with 20 additions and 6 deletions

View File

@@ -9,7 +9,7 @@ from trl import KTOTrainer
from trl.trainer import disable_dropout_in_model
from ...extras.constants import IGNORE_INDEX
from ..utils import create_custom_optimzer, create_custom_scheduler
from ..utils import create_custom_optimzer, create_custom_scheduler, get_ref_context
if TYPE_CHECKING:
@@ -68,6 +68,7 @@ class CustomKTOTrainer(KTOTrainer):
self.ref_model = self._prepare_deepspeed(self.ref_model)
else:
self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True)
self.ref_model.eval()
if finetuning_args.use_badam:
from badam import clip_grad_norm_for_sparse_tensor
@@ -164,7 +165,7 @@ class CustomKTOTrainer(KTOTrainer):
"""
if self.ref_model is None:
ref_model = model
ref_context = self.accelerator.unwrap_model(model).disable_adapter()
ref_context = get_ref_context(self.accelerator, model)
else:
ref_model = self.ref_model
ref_context = nullcontext()