From e944dc442cac475604f115a84f307ac1051cdc89 Mon Sep 17 00:00:00 2001
From: yanglele <111259718+ymxyll@users.noreply.github.com>
Date: Tue, 6 Jan 2026 23:07:12 +0800
Subject: [PATCH] [feature] add support for EAFT loss (#9720)

Co-authored-by: Yaowei Zheng
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
---
 .../extras/eaft/qwen25_05b_eaft_full.yaml   | 40 ++++++++++++++
 src/llamafactory/hparams/finetuning_args.py |  8 +++
 src/llamafactory/train/sft/trainer.py       |  9 +++
 src/llamafactory/train/trainer_utils.py     | 55 +++++++++++++++++++
 4 files changed, 112 insertions(+)
 create mode 100644 examples/extras/eaft/qwen25_05b_eaft_full.yaml

diff --git a/examples/extras/eaft/qwen25_05b_eaft_full.yaml b/examples/extras/eaft/qwen25_05b_eaft_full.yaml
new file mode 100644
index 000000000..904858f73
--- /dev/null
+++ b/examples/extras/eaft/qwen25_05b_eaft_full.yaml
@@ -0,0 +1,40 @@
+### model
+model_name_or_path: Qwen/Qwen2.5-0.5B-Instruct
+trust_remote_code: true
+
+### method
+stage: sft
+do_train: true
+finetuning_type: full
+use_eaft_loss: true
+
+### dataset
+dataset: identity,alpaca_en_demo
+template: qwen
+cutoff_len: 2048
+max_samples: 1000
+overwrite_cache: true
+preprocessing_num_workers: 16
+dataloader_num_workers: 4
+
+### output
+output_dir: qwen2.5-0_5b/full/sft_eaft
+logging_steps: 1
+save_steps: 500
+plot_loss: true
+overwrite_output_dir: true
+save_only_model: false
+report_to: none  # choices: [none, wandb, tensorboard, swanlab, mlflow]
+
+
+### train
+per_device_train_batch_size: 2
+gradient_accumulation_steps: 8
+learning_rate: 1.0e-5
+num_train_epochs: 3.0
+lr_scheduler_type: cosine
+warmup_ratio: 0.1
+bf16: true
+ddp_timeout: 180000000
+
+
diff --git a/src/llamafactory/hparams/finetuning_args.py b/src/llamafactory/hparams/finetuning_args.py
index 7ab2ce3bc..c089ca67c 100644
--- a/src/llamafactory/hparams/finetuning_args.py
+++ b/src/llamafactory/hparams/finetuning_args.py
@@ -490,6 +490,14 @@ class FinetuningArguments(
         default=False,
         metadata={"help": "Whether to use the DFT loss."},
     )
+    use_eaft_loss: bool = field(
+        default=False,
+        metadata={"help": "Whether to use the EAFT loss."},
+    )
+    eaft_alpha: float = field(
+        default=1.0,
+        metadata={"help": "The alpha parameter for EAFT loss to control the power of adaptive weight."},
+    )
     freeze_vision_tower: bool = field(
         default=True,
         metadata={"help": "Whether ot not to freeze the vision tower in MLLM training."},
diff --git a/src/llamafactory/train/sft/trainer.py b/src/llamafactory/train/sft/trainer.py
index 0ee389b3c..fff990666 100644
--- a/src/llamafactory/train/sft/trainer.py
+++ b/src/llamafactory/train/sft/trainer.py
@@ -87,6 +87,15 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer):
 
             self.compute_loss_func = dft_loss_func
 
+
+        elif finetuning_args.use_eaft_loss:
+            from ..trainer_utils import eaft_loss_func
+
+            self.compute_loss_func = lambda outputs, labels, num_items_in_batch=None: eaft_loss_func(
+                outputs, labels, num_items_in_batch, finetuning_args.eaft_alpha
+            )
+
+
         if training_args.fp8 and hasattr(self, "accelerator"):
             # verify FP8 status after trainer initialization
             verify_fp8_status(self.accelerator, training_args)
diff --git a/src/llamafactory/train/trainer_utils.py b/src/llamafactory/train/trainer_utils.py
index ec291e447..3967d2cc7 100644
--- a/src/llamafactory/train/trainer_utils.py
+++ b/src/llamafactory/train/trainer_utils.py
@@ -679,6 +679,61 @@ def _dft_cross_entropy(
     return loss
 
 
+def eaft_loss_func(outputs, labels, num_items_in_batch=None, alpha=1.0):
outputs.get("logits") + if logits is None: + return outputs.get("loss", torch.tensor(0.0)) + + logits = logits.float() + vocab_size = logits.size(-1) + labels = torch.nn.functional.pad(labels, (0, 1), value=-100) + shift_labels = labels[..., 1:].contiguous() + logits = logits.view(-1, vocab_size) + shift_labels = shift_labels.view(-1) + shift_labels = shift_labels.to(logits.device) + + loss = _eaft_cross_entropy(logits, shift_labels, num_items_in_batch, alpha) + return loss + + +def _eaft_cross_entropy( + source: torch.Tensor, + target: torch.Tensor, + num_items_in_batch: Optional[torch.Tensor] = None, + alpha: float = 1.0, + ignore_index: int = -100, +) -> torch.Tensor: + per_token_loss = torch.nn.functional.cross_entropy(source, target, ignore_index=ignore_index, reduction="none") + valid_mask = target != ignore_index + if not valid_mask.any(): + return torch.tensor(0.0, device=source.device, dtype=source.dtype) + + valid_losses = per_token_loss[valid_mask] + + with torch.no_grad(): + source_detached = source[valid_mask].detach() + + topk_val, _ = torch.topk(source_detached, k=20, dim=-1) + logsumexp_topk = torch.logsumexp(topk_val, dim=-1, keepdim=True) + log_probs_topk = topk_val - logsumexp_topk + probs_topk = torch.exp(log_probs_topk) + entropy_approx = -(probs_topk * log_probs_topk).sum(dim=-1) + + entropy_term = entropy_approx / 3.0 + adaptive_weight = torch.pow(entropy_term, alpha) + + weighted_losses = valid_losses * adaptive_weight + + if num_items_in_batch is not None: + total_loss = weighted_losses.sum() + if torch.is_tensor(num_items_in_batch): + num_items_in_batch = num_items_in_batch.to(total_loss.device) + loss = total_loss / num_items_in_batch + else: + loss = weighted_losses.mean() + return loss + + def nested_detach( tensors: Union["torch.Tensor", list["torch.Tensor"], tuple["torch.Tensor"], dict[str, "torch.Tensor"]], clone: bool = False,