@@ -17,16 +17,8 @@ r"""Efficient fine-tuning of large language models.
Level:
    api, webui > chat, eval, train > data, model > hparams > extras

Dependency graph:
    transformers>=4.41.2,<=4.43.0,!=4.46.*,!=4.47.*,!=4.48.0
    datasets>=2.16.0,<=3.5.0
    accelerate>=0.34.0,<=1.6.0
    peft>=0.14.0,<=0.15.1
    trl>=0.8.6,<=0.9.6

Disable version checking: DISABLE_VERSION_CHECK=1
Enable VRAM recording: RECORD_VRAM=1
Force check imports: FORCE_CHECK_IMPORTS=1
Force using torchrun: FORCE_TORCHRUN=1
Set logging verbosity: LLAMAFACTORY_VERBOSITY=WARN
Use modelscope: USE_MODELSCOPE_HUB=1
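The switches listed in the docstring are plain environment variables read at import time. A minimal usage sketch, assuming the standard llamafactory-cli entry point and an example config path from the repository (only the variable names come from the docstring above):

import os
import subprocess

# opt out of the dependency version checks and reduce log noise
os.environ["DISABLE_VERSION_CHECK"] = "1"
os.environ["LLAMAFACTORY_VERBOSITY"] = "WARN"
# pull models from ModelScope instead of the Hugging Face Hub
os.environ["USE_MODELSCOPE_HUB"] = "1"

# launch training in a subprocess so the variables are inherited
subprocess.run(
    ["llamafactory-cli", "train", "examples/train_lora/llama3_lora_sft.yaml"],
    check=True,
)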
@@ -21,7 +21,7 @@ import re
from copy import deepcopy
from dataclasses import dataclass
from io import BytesIO
from typing import TYPE_CHECKING, BinaryIO, Literal, Optional, TypedDict, Union
from typing import TYPE_CHECKING, Any, BinaryIO, Literal, Optional, TypedDict, Union

import numpy as np
import torch
@@ -86,7 +86,7 @@ if TYPE_CHECKING:
    pass


def _concatenate_list(input_list):
def _concatenate_list(input_list: list[Any]) -> Union[list[Any], "NDArray", "torch.Tensor"]:
    r"""Concatenate a list of lists, numpy arrays or torch tensors.

    Returns:
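For context, a minimal sketch of what a helper with this annotated signature could do, dispatching on the element type. This is an illustration of the signature change, not necessarily the project's exact implementation:

from typing import Any, Union

import numpy as np
import torch


def _concatenate_list(input_list: list[Any]) -> Union[list[Any], np.ndarray, torch.Tensor]:
    r"""Concatenate a list of lists, numpy arrays or torch tensors."""
    if isinstance(input_list[0], list):
        return [item for sublist in input_list for item in sublist]  # flatten one level
    elif isinstance(input_list[0], np.ndarray):
        return np.concatenate(input_list, axis=0)
    elif isinstance(input_list[0], torch.Tensor):
        return torch.cat(input_list, dim=0)
    else:
        raise ValueError(f"Unsupported element type: {type(input_list[0])}.")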
@@ -89,7 +89,7 @@ def check_version(requirement: str, mandatory: bool = False) -> None:

def check_dependencies() -> None:
    r"""Check the version of the required packages."""
    check_version("transformers>=4.43.0,<=4.51.3,!=4.46.0,!=4.46.1,!=4.46.2,!=4.46.3,!=4.47.0,!=4.47.1,!=4.48.0")
    check_version("transformers>=4.45.0,<=4.51.3,!=4.46.0,!=4.46.1,!=4.46.2,!=4.46.3,!=4.47.0,!=4.47.1,!=4.48.0")
    check_version("datasets>=2.16.0,<=3.5.0")
    check_version("accelerate>=0.34.0,<=1.6.0")
    check_version("peft>=0.14.0,<=0.15.1")
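The strings passed to check_version are ordinary PEP 508 requirement specifiers. A hedged sketch of how such a helper might enforce them while honoring DISABLE_VERSION_CHECK, assuming the packaging library; the project's real helper may differ in details:

import importlib.metadata
import os

from packaging.requirements import Requirement


def check_version(requirement: str, mandatory: bool = False) -> None:
    # non-mandatory checks can be skipped via the DISABLE_VERSION_CHECK switch
    if os.getenv("DISABLE_VERSION_CHECK", "0").lower() in ("1", "true") and not mandatory:
        return

    req = Requirement(requirement)
    installed = importlib.metadata.version(req.name)
    if not req.specifier.contains(installed, prereleases=True):
        raise RuntimeError(f"{req.name}=={installed} does not satisfy '{requirement}'.")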
@@ -23,7 +23,6 @@ from typing import TYPE_CHECKING, Optional
import torch
import torch.nn as nn
import transformers
from transformers.models.llama.modeling_llama import Cache, apply_rotary_pos_emb, repeat_kv

from ...extras import logging
from ...extras.constants import SUPPORTED_CLASS_FOR_S2ATTN
@@ -32,7 +31,15 @@ from ...extras.packages import is_transformers_version_greater_than


if not is_transformers_version_greater_than("4.48.0"):
    from transformers.models.llama.modeling_llama import LlamaAttention, LlamaFlashAttention2, LlamaSdpaAttention
from transformers.modeling_flash_attention_utils import _flash_attention_forward
from transformers.models.llama.modeling_llama import (
    Cache,
    LlamaAttention,
    LlamaFlashAttention2,
    LlamaSdpaAttention,
    apply_rotary_pos_emb,
    repeat_kv,
)


if TYPE_CHECKING:
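The guard being dropped here follows the usual version-gating pattern. A minimal stand-in for the helper (the real one lives in extras.packages) and the old conditional import look roughly like this; treat the helper body as an assumption:

import transformers
from packaging import version


def is_transformers_version_greater_than(content: str) -> bool:
    # stand-in: True when the installed transformers version is at least `content`
    return version.parse(transformers.__version__) >= version.parse(content)


if not is_transformers_version_greater_than("4.48.0"):
    # transformers < 4.48.0 still ships separate eager/FA2/SDPA attention classes,
    # which is what the attention patch below monkey-patches
    from transformers.models.llama.modeling_llama import LlamaFlashAttention2  # noqa: F401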
@@ -206,9 +213,6 @@ def llama_flash_attention_2_forward(
        if attention_mask is not None:
            attention_mask = attention_mask[:, :groupsz].repeat(num_groups, 1)

    if is_transformers_version_greater_than("4.43.0"):
        from transformers.modeling_flash_attention_utils import _flash_attention_forward

    attn_output: torch.Tensor = _flash_attention_forward(
        query_states,
        key_states,
@@ -220,10 +224,6 @@ def llama_flash_attention_2_forward(
        use_top_left_mask=self._flash_attn_uses_top_left_mask,
        is_causal=self.is_causal,
    )
    else:
        attn_output: torch.Tensor = self._flash_attention_forward(
            query_states, key_states, value_states, attention_mask, query_states.size(1), dropout=dropout_rate
        )

    if getattr(self.config, "group_size_ratio", None) and self.training:  # shift back
        attn_output.reshape(bsz, q_len, self.num_heads, self.head_dim)
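The grouping visible above implements shift short attention: tokens are attended within fixed-size groups, and half of the heads are shifted so information still flows across group boundaries, then shifted back afterwards. A standalone sketch of that bookkeeping, with illustrative shapes and names:

import torch

bsz, q_len, num_heads, head_dim = 2, 8, 4, 16
group_size_ratio = 0.25
groupsz = int(q_len * group_size_ratio)   # tokens attended per group
num_groups = q_len // groupsz

states = torch.randn(bsz, q_len, num_heads, head_dim)

# shift: roll the second half of the heads forward by half a group
states[:, :, num_heads // 2 :] = states[:, :, num_heads // 2 :].roll(-groupsz // 2, dims=1)

# split the sequence into groups, folding the groups into the batch dimension
grouped = states.reshape(bsz * num_groups, groupsz, num_heads, head_dim)

# ... attention runs independently inside each group here ...

# shift back: undo the grouping, then undo the roll on the second half of the heads
output = grouped.reshape(bsz, q_len, num_heads, head_dim)
output[:, :, num_heads // 2 :] = output[:, :, num_heads // 2 :].roll(groupsz // 2, dims=1)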
@@ -350,7 +350,7 @@ def llama_sdpa_attention_forward(


def _apply_llama_patch() -> None:
    check_version("transformers>=4.43.0,<4.48.0", mandatory=True)
    check_version("transformers>=4.45.0,<4.48.0", mandatory=True)
    LlamaAttention.forward = llama_attention_forward
    LlamaFlashAttention2.forward = llama_flash_attention_2_forward
    LlamaSdpaAttention.forward = llama_sdpa_attention_forward
@@ -43,11 +43,6 @@ import torch
import torch.nn.functional as F

from ...extras import logging
from ...extras.packages import is_transformers_version_greater_than


if is_transformers_version_greater_than("4.43.0"):
    import transformers.modeling_flash_attention_utils


if TYPE_CHECKING:
@@ -116,5 +111,7 @@ def configure_packing(model_args: "ModelArguments", is_trainable: bool) -> None:
    if not is_trainable or not model_args.block_diag_attn:
        return

    import transformers.modeling_flash_attention_utils

    transformers.modeling_flash_attention_utils._get_unpad_data = get_unpad_data
    logger.info_rank0("Using block diagonal attention for sequence packing without cross-attention.")
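The patched function matters because, for packed batches, the attention mask tags each packed sub-sequence with 1, 2, 3, ... and the unpad data must report per-sub-sequence lengths so FlashAttention's varlen path only attends within each sub-sequence (block-diagonal attention). A sketch of a replacement with that behavior; details may differ from the project's actual get_unpad_data:

import torch
import torch.nn.functional as F


def get_unpad_data(attention_mask: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, int]:
    # lengths of every packed sub-sequence, in order of appearance
    seqlens = []
    for row in attention_mask:
        for seq_id in range(1, int(row.max().item()) + 1):
            length = int((row == seq_id).sum().item())
            if length > 0:
                seqlens.append(length)

    seqlens_in_batch = torch.tensor(seqlens, dtype=torch.int32)
    indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
    cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
    return indices, cu_seqlens, int(seqlens_in_batch.max().item())


# e.g. two sequences of lengths 3 and 2 packed into one row, plus one padding token
mask = torch.tensor([[1, 1, 1, 2, 2, 0]])
print(get_unpad_data(mask))  # cu_seqlens -> tensor([0, 3, 5], dtype=torch.int32)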
@@ -40,6 +40,11 @@ class CustomTrainer(Trainer):
            kwargs["processing_class"] = kwargs.pop("tokenizer")

        super().__init__(**kwargs)
        if processor is not None:
            # avoid wrong loss under gradient accumulation
            # https://github.com/huggingface/transformers/pull/36044#issuecomment-2746657112
            self.model_accepts_loss_kwargs = False

        self.finetuning_args = finetuning_args

        if processor is not None:
@@ -60,6 +60,8 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer):

        super().__init__(**kwargs)
        if processor is not None:
            # avoid wrong loss under gradient accumulation
            # https://github.com/huggingface/transformers/pull/36044#issuecomment-2746657112
            self.model_accepts_loss_kwargs = False

        self.finetuning_args = finetuning_args
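A self-contained sketch combining the pieces shown in the two trainer hunks above: newer transformers expects "processing_class" instead of the deprecated "tokenizer" keyword, and when a multimodal processor is in use the Trainer's loss-kwargs path is disabled so the loss is not mis-scaled under gradient accumulation (see the transformers PR discussion linked in the comments). Class and argument names here are illustrative:

from typing import Optional

from transformers import ProcessorMixin, Trainer


class SketchTrainer(Trainer):
    def __init__(self, processor: Optional[ProcessorMixin] = None, **kwargs):
        if "tokenizer" in kwargs:  # rename the deprecated keyword for newer transformers
            kwargs["processing_class"] = kwargs.pop("tokenizer")

        super().__init__(**kwargs)
        if processor is not None:
            # avoid wrong loss under gradient accumulation
            self.model_accepts_loss_kwargs = False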