commit d0edcde4ea (parent 8c4c2e580c)
Author: hiyouga
Date: 2024-06-07 04:18:05 +08:00
Former-commit-id: 2a44da678a5e360a9c0f9056397ac9e801329321

7 changed files with 47 additions and 54 deletions


@@ -9,7 +9,7 @@ from trl import KTOTrainer
 from trl.trainer import disable_dropout_in_model

 from ...extras.constants import IGNORE_INDEX
-from ..trainer_utils import create_custom_optimzer, create_custom_scheduler, get_ref_context
+from ..trainer_utils import create_custom_optimzer, create_custom_scheduler, get_batch_logps, get_ref_context

 if TYPE_CHECKING:
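For reference, a minimal sketch of the `get_batch_logps` contract this hunk starts importing from `trainer_utils`; the masking details are assumptions inferred from the call sites below, not the verbatim implementation. The key point is that it returns both the per-sequence sum of token log-probabilities and the count of valid (non-ignored) label tokens, so callers can derive summed or length-averaged log-probs:

```python
# Sketch of the assumed get_batch_logps contract; details are assumptions.
from typing import Tuple

import torch

IGNORE_INDEX = -100  # label padding sentinel, as in extras.constants

def get_batch_logps(
    logits: "torch.Tensor", labels: "torch.Tensor"
) -> Tuple["torch.Tensor", "torch.Tensor"]:
    # Shift so that the logits at position t predict the label at position t + 1.
    labels = labels[:, 1:].clone()
    logits = logits[:, :-1, :]
    loss_mask = labels != IGNORE_INDEX
    labels[labels == IGNORE_INDEX] = 0  # dummy index so gather stays in range
    per_token_logps = torch.gather(
        logits.log_softmax(dim=-1), dim=2, index=labels.unsqueeze(2)
    ).squeeze(2)
    # (batch,) summed log-probs and (batch,) counts of valid label tokens
    return (per_token_logps * loss_mask).sum(-1), loss_mask.sum(-1)
```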
@@ -98,16 +98,6 @@ class CustomKTOTrainer(KTOTrainer):
         output_dir = output_dir if output_dir is not None else self.args.output_dir
         getattr(self.processor, "image_processor").save_pretrained(output_dir)

-    def sft_loss(self, chosen_logits: "torch.FloatTensor", chosen_labels: "torch.LongTensor") -> "torch.Tensor":
-        r"""
-        Computes supervised cross-entropy loss of given labels under the given logits.
-
-        Returns:
-            A tensor of shape (batch_size,) containing the cross-entropy loss of each samples.
-        """
-        all_logps = self.get_batch_logps(chosen_logits, chosen_labels, average_log_prob=True)
-        return -all_logps
-
     def forward(
         self, model: "PreTrainedModel", batch: Dict[str, "torch.Tensor"], prefix: Literal["", "kl_"] = ""
     ) -> Tuple["torch.Tensor", "torch.Tensor"]:
@@ -127,28 +117,23 @@ class CustomKTOTrainer(KTOTrainer):
         logits = model(**model_inputs, return_dict=True, use_cache=False).logits.to(torch.float32)
-        logps = self.get_batch_logps(
-            logits=logits,
-            labels=batch["{}labels".format(prefix)],
-            average_log_prob=False,
-            is_encoder_decoder=self.is_encoder_decoder,
-            label_pad_token_id=self.label_pad_token_id,
-        )
-        return logits, logps
+        logps, valid_length = get_batch_logps(logits=logits, labels=batch["{}labels".format(prefix)])
+        return logps, logps / valid_length

     def concatenated_forward(
         self, model: "PreTrainedModel", batch: Dict[str, "torch.Tensor"]
-    ) -> Tuple["torch.Tensor", "torch.Tensor", "torch.Tensor", "torch.Tensor", "torch.Tensor"]:
-        target_logits, target_logps = self.forward(model, batch)
+    ) -> Tuple["torch.Tensor", "torch.Tensor", "torch.Tensor", "torch.Tensor"]:
+        target_logps, target_logps_avg = self.forward(model, batch)
         with torch.no_grad():
-            _, kl_logps = self.forward(model, batch, prefix="kl_")
+            kl_logps, _ = self.forward(model, batch, prefix="kl_")

         if len(target_logps) != len(batch["kto_tags"]):
             raise ValueError("Mismatched shape of inputs and labels.")

-        chosen_logps, rejected_logps = target_logps[batch["kto_tags"]], target_logps[~batch["kto_tags"]]
-        chosen_logits, rejected_logits = target_logits[batch["kto_tags"]], target_logits[~batch["kto_tags"]]
-        return chosen_logps, rejected_logps, chosen_logits, rejected_logits, kl_logps
+        chosen_logps = target_logps[batch["kto_tags"]]
+        rejected_logps = target_logps[~batch["kto_tags"]]
+        chosen_logps_avg = target_logps_avg[batch["kto_tags"]]
+        return chosen_logps, rejected_logps, kl_logps, chosen_logps_avg

     def compute_reference_log_probs(
         self, model: "PreTrainedModel", batch: Dict[str, "torch.Tensor"]
@@ -164,13 +149,9 @@ class CustomKTOTrainer(KTOTrainer):
             ref_context = nullcontext()

         with torch.no_grad(), ref_context:
-            (
-                reference_chosen_logps,
-                reference_rejected_logps,
-                _,
-                _,
-                reference_kl_logps,
-            ) = self.concatenated_forward(ref_model, batch)
+            reference_chosen_logps, reference_rejected_logps, reference_kl_logps, _ = self.concatenated_forward(
+                ref_model, batch
+            )

         return reference_chosen_logps, reference_rejected_logps, reference_kl_logps
@@ -183,14 +164,9 @@ class CustomKTOTrainer(KTOTrainer):
         Computes the DPO loss and other metrics for the given batch of inputs for train or test.
         """
         metrics = {}
-        (
-            policy_chosen_logps,
-            policy_rejected_logps,
-            policy_chosen_logits,
-            _,
-            policy_kl_logps,
-        ) = self.concatenated_forward(model, batch)
+        policy_chosen_logps, policy_rejected_logps, policy_kl_logps, policy_chosen_logps_avg = (
+            self.concatenated_forward(model, batch)
+        )
         reference_chosen_logps, reference_rejected_logps, reference_kl_logps = self.compute_reference_log_probs(
             model, batch
         )
@@ -205,8 +181,8 @@
         losses = losses.nanmean()

         if self.ftx_gamma > 1e-6 and len(policy_chosen_logps) > 0:  # remember to rescale
-            sft_loss = self.sft_loss(policy_chosen_logits, batch["labels"][batch["kto_tags"]])
-            losses += self.ftx_gamma * sft_loss.nanmean() / len(policy_chosen_logits) * len(batch["labels"])
+            sft_loss = -policy_chosen_logps_avg
+            losses += self.ftx_gamma * sft_loss.nanmean() / len(policy_chosen_logps) * len(batch["labels"])

         num_chosen = torch.Tensor([len(chosen_rewards)]).to(self.accelerator.device)
         num_rejected = torch.Tensor([len(rejected_rewards)]).to(self.accelerator.device)
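On the rescaling in the last hunk: `sft_loss.nanmean()` averages over the chosen samples only, and the `/ len(policy_chosen_logps) * len(batch["labels"])` factor then scales that mean by the ratio of full batch size to chosen count, so the auxiliary SFT term keeps a consistent weight however the batch splits into chosen and rejected. A numeric sketch (the `ftx_gamma` value and tensors are invented):

```python
import torch

ftx_gamma = 0.1  # hypothetical mixing coefficient
policy_chosen_logps_avg = torch.tensor([-1.2, -0.8])  # 2 chosen samples
batch_size = 8  # len(batch["labels"]): chosen + rejected rows together

sft_loss = -policy_chosen_logps_avg  # per-sample average NLL, as in the new code
aux = ftx_gamma * sft_loss.nanmean() / len(sft_loss) * batch_size
# mean NLL = 1.0, rescaled by 8 / 2 = 4 -> aux = 0.1 * 1.0 * 4 = 0.4
```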