@@ -9,7 +9,7 @@ from trl import KTOTrainer
 from trl.trainer import disable_dropout_in_model

 from ...extras.constants import IGNORE_INDEX
-from ..trainer_utils import create_custom_optimzer, create_custom_scheduler, get_ref_context
+from ..trainer_utils import create_custom_optimzer, create_custom_scheduler, get_batch_logps, get_ref_context


 if TYPE_CHECKING:
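The import swap above replaces the trainer-level `self.get_batch_logps` with a shared `get_batch_logps` helper from `..trainer_utils`. Its body is not part of this diff; judging from the call sites below, it returns the per-sequence sum of token log-probabilities together with the count of non-padded tokens. A minimal sketch, assuming a decoder-only model and `IGNORE_INDEX` (-100) label padding; the shift logic and the default argument here are assumptions, not quoted from the repo:

```python
from typing import Tuple

import torch

IGNORE_INDEX = -100  # assumed label padding value, mirroring extras.constants


def get_batch_logps(
    logits: "torch.Tensor", labels: "torch.Tensor", label_pad_token_id: int = IGNORE_INDEX
) -> Tuple["torch.Tensor", "torch.Tensor"]:
    """Return (sum of per-token log-probs, number of scored tokens) per sequence.

    Assumes a decoder-only model: labels are shifted one step left relative
    to logits, and padded positions carry label_pad_token_id.
    """
    if logits.shape[:-1] != labels.shape:
        raise ValueError("Logits (batch x seqlen x vocab) and labels must have matching shapes.")

    labels = labels[:, 1:].clone()  # token t is predicted from logits at t-1
    logits = logits[:, :-1, :]
    loss_mask = labels != label_pad_token_id
    labels[labels == label_pad_token_id] = 0  # any valid index; masked out below
    per_token_logps = torch.gather(logits.log_softmax(-1), dim=2, index=labels.unsqueeze(2)).squeeze(2)
    return (per_token_logps * loss_mask).sum(-1), loss_mask.sum(-1)
```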
@@ -98,16 +98,6 @@ class CustomKTOTrainer(KTOTrainer):
         output_dir = output_dir if output_dir is not None else self.args.output_dir
         getattr(self.processor, "image_processor").save_pretrained(output_dir)

-    def sft_loss(self, chosen_logits: "torch.FloatTensor", chosen_labels: "torch.LongTensor") -> "torch.Tensor":
-        r"""
-        Computes supervised cross-entropy loss of given labels under the given logits.
-
-        Returns:
-            A tensor of shape (batch_size,) containing the cross-entropy loss of each samples.
-        """
-        all_logps = self.get_batch_logps(chosen_logits, chosen_labels, average_log_prob=True)
-        return -all_logps
-
     def forward(
         self, model: "PreTrainedModel", batch: Dict[str, "torch.Tensor"], prefix: Literal["", "kl_"] = ""
     ) -> Tuple["torch.Tensor", "torch.Tensor"]:
@@ -127,28 +117,23 @@ class CustomKTOTrainer(KTOTrainer):

         logits = model(**model_inputs, return_dict=True, use_cache=False).logits.to(torch.float32)

-        logps = self.get_batch_logps(
-            logits=logits,
-            labels=batch["{}labels".format(prefix)],
-            average_log_prob=False,
-            is_encoder_decoder=self.is_encoder_decoder,
-            label_pad_token_id=self.label_pad_token_id,
-        )
-        return logits, logps
+        logps, valid_length = get_batch_logps(logits=logits, labels=batch["{}labels".format(prefix)])
+        return logps, logps / valid_length

     def concatenated_forward(
         self, model: "PreTrainedModel", batch: Dict[str, "torch.Tensor"]
-    ) -> Tuple["torch.Tensor", "torch.Tensor", "torch.Tensor", "torch.Tensor", "torch.Tensor"]:
-        target_logits, target_logps = self.forward(model, batch)
+    ) -> Tuple["torch.Tensor", "torch.Tensor", "torch.Tensor", "torch.Tensor"]:
+        target_logps, target_logps_avg = self.forward(model, batch)
         with torch.no_grad():
-            _, kl_logps = self.forward(model, batch, prefix="kl_")
+            kl_logps, _ = self.forward(model, batch, prefix="kl_")

         if len(target_logps) != len(batch["kto_tags"]):
             raise ValueError("Mismatched shape of inputs and labels.")

-        chosen_logps, rejected_logps = target_logps[batch["kto_tags"]], target_logps[~batch["kto_tags"]]
-        chosen_logits, rejected_logits = target_logits[batch["kto_tags"]], target_logits[~batch["kto_tags"]]
-        return chosen_logps, rejected_logps, chosen_logits, rejected_logits, kl_logps
+        chosen_logps = target_logps[batch["kto_tags"]]
+        rejected_logps = target_logps[~batch["kto_tags"]]
+        chosen_logps_avg = target_logps_avg[batch["kto_tags"]]
+        return chosen_logps, rejected_logps, kl_logps, chosen_logps_avg

     def compute_reference_log_probs(
         self, model: "PreTrainedModel", batch: Dict[str, "torch.Tensor"]
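With that helper, `forward` now returns the pair (summed log-probs, length-normalized log-probs) instead of (logits, logps), and `concatenated_forward` partitions both tensors by the boolean `kto_tags` mask, where True marks a desirable example. A toy illustration of that split, with invented values:

```python
import torch

# what forward() returns per sequence under the new contract
target_logps = torch.tensor([-12.0, -30.0, -8.0])    # sum over valid tokens
target_logps_avg = torch.tensor([-1.2, -2.5, -0.8])  # sum / valid_length
kto_tags = torch.tensor([True, False, True])         # True = desirable example

chosen_logps = target_logps[kto_tags]          # tensor([-12., -8.])
rejected_logps = target_logps[~kto_tags]       # tensor([-30.])
chosen_logps_avg = target_logps_avg[kto_tags]  # tensor([-1.2000, -0.8000])
```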
@@ -164,13 +149,9 @@ class CustomKTOTrainer(KTOTrainer):
             ref_context = nullcontext()

         with torch.no_grad(), ref_context:
-            (
-                reference_chosen_logps,
-                reference_rejected_logps,
-                _,
-                _,
-                reference_kl_logps,
-            ) = self.concatenated_forward(ref_model, batch)
+            reference_chosen_logps, reference_rejected_logps, reference_kl_logps, _ = self.concatenated_forward(
+                ref_model, batch
+            )

         return reference_chosen_logps, reference_rejected_logps, reference_kl_logps

@@ -183,14 +164,9 @@ class CustomKTOTrainer(KTOTrainer):
         Computes the DPO loss and other metrics for the given batch of inputs for train or test.
         """
         metrics = {}
-        (
-            policy_chosen_logps,
-            policy_rejected_logps,
-            policy_chosen_logits,
-            _,
-            policy_kl_logps,
-        ) = self.concatenated_forward(model, batch)
-
+        policy_chosen_logps, policy_rejected_logps, policy_kl_logps, policy_chosen_logps_avg = (
+            self.concatenated_forward(model, batch)
+        )
         reference_chosen_logps, reference_rejected_logps, reference_kl_logps = self.compute_reference_log_probs(
             model, batch
         )
@@ -205,8 +181,8 @@ class CustomKTOTrainer(KTOTrainer):
         losses = losses.nanmean()

         if self.ftx_gamma > 1e-6 and len(policy_chosen_logps) > 0:  # remember to rescale
-            sft_loss = self.sft_loss(policy_chosen_logits, batch["labels"][batch["kto_tags"]])
-            losses += self.ftx_gamma * sft_loss.nanmean() / len(policy_chosen_logits) * len(batch["labels"])
+            sft_loss = -policy_chosen_logps_avg
+            losses += self.ftx_gamma * sft_loss.nanmean() / len(policy_chosen_logps) * len(batch["labels"])

         num_chosen = torch.Tensor([len(chosen_rewards)]).to(self.accelerator.device)
         num_rejected = torch.Tensor([len(rejected_rewards)]).to(self.accelerator.device)
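Since `concatenated_forward` already carries the length-normalized log-probabilities, the auxiliary SFT term no longer needs the logits: `-policy_chosen_logps_avg` is the same token-averaged cross-entropy that the deleted `sft_loss` method computed via `average_log_prob=True`. A quick sanity check of that identity, with made-up numbers:

```python
import torch

# per-token log-probs of a chosen sequence's labels (example values)
per_token_logps = torch.tensor([-0.5, -1.0, -1.5])
valid_length = per_token_logps.numel()

# old path: average the per-token log-probs (average_log_prob=True), then negate
old_sft_loss = -per_token_logps.mean()

# new path: negate the precomputed sum / valid_length from forward()
logps_avg = per_token_logps.sum() / valid_length
new_sft_loss = -logps_avg

assert torch.allclose(old_sft_loss, new_sft_loss)  # both equal 1.0
```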