@@ -20,19 +20,17 @@ Level:
 
 Dependency graph:
   main:
-    transformers>=4.41.2
-    datasets>=2.16.0
-    accelerate>=0.30.1
-    peft>=0.11.1
-    trl>=0.8.6
+    transformers>=4.41.2,<=4.43.4
+    datasets>=2.16.0,<=2.20.0
+    accelerate>=0.30.1,<=0.32.0
+    peft>=0.11.1,<=0.12.0
+    trl>=0.8.6,<=0.9.6
   attention:
     transformers>=4.42.4 (gemma+fa2)
   longlora:
-    transformers>=4.41.2,<=4.42.4
+    transformers>=4.41.2,<=4.43.4
   packing:
-    transformers>=4.41.2,<=4.42.4
-  patcher:
-    transformers==4.41.2 (chatglm)
+    transformers>=4.41.2,<=4.43.4
 """
 
 from .cli import VERSION

@@ -535,10 +535,6 @@ register_model_group(
             DownloadSource.DEFAULT: "google/gemma-2-2b",
             DownloadSource.MODELSCOPE: "LLM-Research/gemma-2-2b",
         },
-        "Gemma-2-2B-Chat": {
-            DownloadSource.DEFAULT: "google/gemma-2-2b-it",
-            DownloadSource.MODELSCOPE: "LLM-Research/gemma-2-2b-it",
-        },
         "Gemma-2-9B": {
             DownloadSource.DEFAULT: "google/gemma-2-9b",
             DownloadSource.MODELSCOPE: "LLM-Research/gemma-2-9b",
@@ -547,6 +543,10 @@ register_model_group(
             DownloadSource.DEFAULT: "google/gemma-2-27b",
             DownloadSource.MODELSCOPE: "LLM-Research/gemma-2-27b",
         },
+        "Gemma-2-2B-Chat": {
+            DownloadSource.DEFAULT: "google/gemma-2-2b-it",
+            DownloadSource.MODELSCOPE: "LLM-Research/gemma-2-2b-it",
+        },
         "Gemma-2-9B-Chat": {
             DownloadSource.DEFAULT: "google/gemma-2-9b-it",
             DownloadSource.MODELSCOPE: "LLM-Research/gemma-2-9b-it",

@@ -79,11 +79,11 @@ def check_dependencies() -> None:
     if os.environ.get("DISABLE_VERSION_CHECK", "0").lower() in ["true", "1"]:
         logger.warning("Version checking has been disabled, may lead to unexpected behaviors.")
     else:
-        require_version("transformers>=4.41.2", "To fix: pip install transformers>=4.41.2")
-        require_version("datasets>=2.16.0", "To fix: pip install datasets>=2.16.0")
-        require_version("accelerate>=0.30.1", "To fix: pip install accelerate>=0.30.1")
-        require_version("peft>=0.11.1", "To fix: pip install peft>=0.11.1")
-        require_version("trl>=0.8.6", "To fix: pip install trl>=0.8.6")
+        require_version("transformers>=4.41.2,<=4.43.4", "To fix: pip install transformers>=4.41.2,<=4.43.4")
+        require_version("datasets>=2.16.0,<=2.20.0", "To fix: pip install datasets>=2.16.0,<=2.20.0")
+        require_version("accelerate>=0.30.1,<=0.32.0", "To fix: pip install accelerate>=0.30.1,<=0.32.0")
+        require_version("peft>=0.11.1,<=0.12.0", "To fix: pip install peft>=0.11.1,<=0.12.0")
+        require_version("trl>=0.8.6,<=0.9.6", "To fix: pip install trl>=0.8.6,<=0.9.6")
 
 
 def count_parameters(model: "torch.nn.Module") -> Tuple[int, int]:
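
Note on the hunk above: `require_version` (from `transformers.utils.versions`) raises an `ImportError` carrying the "To fix" hint whenever the installed version falls outside the stated range, so the new upper bounds turn an incompatible upgrade into an actionable startup error instead of a crash deep in training. A minimal sketch of the behavior (not part of the patch):

```python
# Sketch: how the pinned check fails on an out-of-range install.
import os

from transformers.utils.versions import require_version

if os.environ.get("DISABLE_VERSION_CHECK", "0").lower() not in ["true", "1"]:
    # With e.g. transformers 4.44.x installed, this raises ImportError and the
    # message ends with the hint "To fix: pip install transformers>=4.41.2,<=4.43.4".
    require_version("transformers>=4.41.2,<=4.43.4", "To fix: pip install transformers>=4.41.2,<=4.43.4")
```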

@@ -70,6 +70,11 @@ def is_starlette_available():
     return _is_package_available("sse_starlette")
 
 
+@lru_cache
+def is_transformers_version_greater_than_4_43():
+    return _get_package_version("transformers") >= version.parse("4.43.0")
+
+
 def is_uvicorn_available():
     return _is_package_available("uvicorn")
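
The new `is_transformers_version_greater_than_4_43` gate is what the longlora and packing hunks below branch on. A standalone sketch of what it evaluates, assuming this module's existing `_get_package_version` helper reads installed versions via `importlib.metadata` (as the rest of the file does):

```python
# Standalone sketch of the version gate; names mirror extras/packages.py.
import importlib.metadata
from functools import lru_cache

from packaging import version


def _get_package_version(name: str) -> "version.Version":
    # assumption: mirrors the module's helper, treating a missing package as 0.0.0
    try:
        return version.parse(importlib.metadata.version(name))
    except Exception:
        return version.parse("0.0.0")


@lru_cache
def is_transformers_version_greater_than_4_43():
    # cached: the installed version cannot change within one process
    return _get_package_version("transformers") >= version.parse("4.43.0")
```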

@@ -36,7 +36,7 @@ def configure_attn_implementation(
     if model_args.flash_attn == "auto" or model_args.flash_attn == "fa2":
         if is_flash_attn_2_available():
             require_version("transformers>=4.42.4", "To fix: pip install transformers>=4.42.4")
-            require_version("flash_attn>=2.6.0", "To fix: pip install flash_attn>=2.6.0")
+            require_version("flash_attn>=2.6.3", "To fix: pip install flash_attn>=2.6.3")
             logger.warning("Gemma-2 should use flash attention 2, change `flash_attn` to fa2.")
             model_args.flash_attn = "fa2"
         else:
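
Background on the `flash_attn` floor moving from 2.6.0 to 2.6.3: Gemma-2 soft-caps its attention logits, which FlashAttention-2 only supports from the 2.6 line onward, and 2.6.3 appears to be the first release this patch trusts for it. A toy sketch of the soft-capping itself (cap value as in Gemma-2 configs; shapes invented):

```python
# Toy sketch of attention logit soft-capping as used by Gemma-2.
import torch

softcap = 50.0                                   # attn_logit_softcapping in Gemma-2 configs
scores = torch.randn(2, 4, 8, 8) * 100.0         # raw attention logits
capped = softcap * torch.tanh(scores / softcap)  # squashed smoothly into (-50, 50)
print(capped.abs().max() < softcap)              # tensor(True)
```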

@@ -35,6 +35,7 @@ from transformers.utils.versions import require_version
 
 from ...extras.constants import SUPPORTED_CLASS_FOR_S2ATTN
 from ...extras.logging import get_logger
+from ...extras.packages import is_transformers_version_greater_than_4_43
 
 
 if TYPE_CHECKING:

@@ -50,14 +51,15 @@ transformers_logger = logging.get_logger(__name__)
 # https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/models/llama/modeling_llama.py
 def llama_attention_forward(
     self: "LlamaAttention",
-    hidden_states: torch.Tensor,
-    attention_mask: Optional[torch.Tensor] = None,
-    position_ids: Optional[torch.LongTensor] = None,
+    hidden_states: "torch.Tensor",
+    attention_mask: Optional["torch.Tensor"] = None,
+    position_ids: Optional["torch.LongTensor"] = None,
     past_key_value: Optional["Cache"] = None,
     output_attentions: bool = False,
-    cache_position: Optional[torch.LongTensor] = None,
+    cache_position: Optional["torch.LongTensor"] = None,
+    position_embeddings: Optional[Tuple["torch.Tensor", "torch.Tensor"]] = None,
     **kwargs,
-) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+) -> Tuple["torch.Tensor", Optional["torch.Tensor"], Optional[Tuple["torch.Tensor"]]]:
     bsz, q_len, _ = hidden_states.size()
 
     query_states: "torch.Tensor" = self.q_proj(hidden_states)

@@ -68,7 +70,11 @@ def llama_attention_forward(
     key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
     value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
 
-    cos, sin = self.rotary_emb(value_states, position_ids)
+    if position_embeddings is None:
+        cos, sin = self.rotary_emb(value_states, position_ids)
+    else:
+        cos, sin = position_embeddings
+
     query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
 
     if past_key_value is not None:
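
Background for the `position_embeddings` fallback above (the same branch is added to the FA2 and SDPA variants below): from transformers v4.43 on, `LlamaModel` computes the rotary `(cos, sin)` pair once per forward pass and hands it to every decoder layer, whereas older versions let each attention module derive it from `position_ids`. Supporting both keeps the patched forward drop-in across the pinned range. For reference, the math that consumes `cos`/`sin` is a pure-torch two-liner (a simplified sketch of `apply_rotary_pos_emb`, head-dim broadcasting omitted):

```python
import torch


def rotate_half(x: "torch.Tensor") -> "torch.Tensor":
    # split the last dim in half and rotate: (x1, x2) -> (-x2, x1)
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)


def apply_rotary(q, k, cos, sin):
    # same math as transformers' apply_rotary_pos_emb, shapes simplified
    return (q * cos) + (rotate_half(q) * sin), (k * cos) + (rotate_half(k) * sin)
```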

@@ -130,14 +136,15 @@ def llama_attention_forward(
 # https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/models/llama/modeling_llama.py
 def llama_flash_attention_2_forward(
     self: "LlamaFlashAttention2",
-    hidden_states: torch.Tensor,
-    attention_mask: Optional[torch.Tensor] = None,
-    position_ids: Optional[torch.LongTensor] = None,
+    hidden_states: "torch.Tensor",
+    attention_mask: Optional["torch.Tensor"] = None,
+    position_ids: Optional["torch.LongTensor"] = None,
     past_key_value: Optional["Cache"] = None,
     output_attentions: bool = False,
-    cache_position: Optional[torch.LongTensor] = None,
+    cache_position: Optional["torch.LongTensor"] = None,
+    position_embeddings: Optional[Tuple["torch.Tensor", "torch.Tensor"]] = None,
     **kwargs,
-) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+) -> Tuple["torch.Tensor", Optional["torch.Tensor"], Optional[Tuple["torch.Tensor"]]]:
     # LlamaFlashAttention2 attention does not support output_attentions
     output_attentions = False

@@ -151,7 +158,11 @@ def llama_flash_attention_2_forward(
     key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
     value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
 
-    cos, sin = self.rotary_emb(value_states, position_ids)
+    if position_embeddings is None:
+        cos, sin = self.rotary_emb(value_states, position_ids)
+    else:
+        cos, sin = position_embeddings
+
     query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
 
     if past_key_value is not None:

@@ -198,9 +209,24 @@ def llama_flash_attention_2_forward(
     if attention_mask is not None:
         attention_mask = attention_mask[:, :groupsz].repeat(num_groups, 1)
 
-    attn_output: "torch.Tensor" = self._flash_attention_forward(
-        query_states, key_states, value_states, attention_mask, query_states.size(1), dropout=dropout_rate
-    )
+    if is_transformers_version_greater_than_4_43():
+        from transformers.modeling_flash_attention_utils import _flash_attention_forward
+
+        attn_output: "torch.Tensor" = _flash_attention_forward(
+            query_states,
+            key_states,
+            value_states,
+            attention_mask,
+            query_states.size(1),
+            dropout=dropout_rate,
+            sliding_window=getattr(self, "sliding_window", None),
+            use_top_left_mask=self._flash_attn_uses_top_left_mask,
+            is_causal=self.is_causal,
+        )
+    else:
+        attn_output: "torch.Tensor" = self._flash_attention_forward(
+            query_states, key_states, value_states, attention_mask, query_states.size(1), dropout=dropout_rate
+        )
 
     if getattr(self.config, "group_size_ratio", None) and self.training:  # shift back
         attn_output.reshape(bsz, q_len, self.num_heads, self.head_dim)
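
Context for the branch above: transformers v4.43 moved flash-attention dispatch out of the attention classes into the module-level `_flash_attention_forward` of `transformers.modeling_flash_attention_utils`, so flags the old bound method read from `self` (the top-left mask quirk, causality, sliding window) must now be passed explicitly. The surrounding `groupsz`/`num_groups` code is LongLoRA's shifted short attention; a toy illustration of that regrouping (dimensions invented):

```python
# Toy sketch of LongLoRA shift-short-attention regrouping (not library code).
import torch

bsz, q_len, n_heads, head_dim = 1, 8, 4, 2
group_size_ratio = 0.25
groupsz = int(q_len * group_size_ratio)  # tokens per local group: 2
num_groups = q_len // groupsz            # 4 groups

states = torch.randn(bsz, q_len, n_heads, head_dim)
# roll half the heads by half a group so neighboring groups exchange information
shifted = states[:, :, n_heads // 2 :].roll(-groupsz // 2, dims=1)
states = torch.cat((states[:, :, : n_heads // 2], shifted), dim=2)
# fold groups into the batch dim; attention then runs within each group only
states = states.reshape(bsz * num_groups, groupsz, n_heads, head_dim)
print(states.shape)  # torch.Size([4, 2, 4, 2])
```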

@@ -225,14 +251,15 @@ def llama_flash_attention_2_forward(
 # https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/models/llama/modeling_llama.py
 def llama_sdpa_attention_forward(
     self: "LlamaSdpaAttention",
-    hidden_states: torch.Tensor,
-    attention_mask: Optional[torch.Tensor] = None,
-    position_ids: Optional[torch.LongTensor] = None,
+    hidden_states: "torch.Tensor",
+    attention_mask: Optional["torch.Tensor"] = None,
+    position_ids: Optional["torch.LongTensor"] = None,
     past_key_value: Optional["Cache"] = None,
     output_attentions: bool = False,
-    cache_position: Optional[torch.LongTensor] = None,
+    cache_position: Optional["torch.LongTensor"] = None,
+    position_embeddings: Optional[Tuple["torch.Tensor", "torch.Tensor"]] = None,
     **kwargs,
-) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+) -> Tuple["torch.Tensor", Optional["torch.Tensor"], Optional[Tuple["torch.Tensor"]]]:
     if output_attentions:
         transformers_logger.warning_once(
             "SDPA does not support `output_attentions=True`. Falling back to the vanilla attention"

@@ -258,7 +285,11 @@ def llama_sdpa_attention_forward(
     key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
     value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
 
-    cos, sin = self.rotary_emb(value_states, position_ids)
+    if position_embeddings is None:
+        cos, sin = self.rotary_emb(value_states, position_ids)
+    else:
+        cos, sin = position_embeddings
+
     query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
 
     if past_key_value is not None:

@@ -322,7 +353,7 @@ def llama_sdpa_attention_forward(
 
 
 def _apply_llama_patch() -> None:
-    require_version("transformers>=4.41.2,<=4.42.4", "To fix: pip install transformers>=4.41.2,<=4.42.4")
+    require_version("transformers>=4.41.2,<=4.43.4", "To fix: pip install transformers>=4.41.2,<=4.43.4")
     LlamaAttention.forward = llama_attention_forward
     LlamaFlashAttention2.forward = llama_flash_attention_2_forward
     LlamaSdpaAttention.forward = llama_sdpa_attention_forward
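
The patch is installed by rebinding `forward` on the three attention classes. Because Python resolves methods on the class at call time, instances created before the patch pick it up too; a minimal sketch of the mechanism with a toy class:

```python
# Minimal sketch of class-level monkey patching (toy class, not transformers).
class Attention:
    def forward(self, x):
        return x


def patched_forward(self, x):  # stand-in for llama_attention_forward
    return x + 1


attn = Attention()               # created before the patch
Attention.forward = patched_forward
print(attn.forward(1))           # 2 -- existing instances are patched as well
```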

@@ -41,11 +41,11 @@ from typing import TYPE_CHECKING, Tuple
 
 import torch
 import torch.nn.functional as F
-import transformers.models
 from transformers.utils.versions import require_version
 
 from ...extras.constants import SUPPORTED_CLASS_FOR_BLOCK_DIAG_ATTN
 from ...extras.logging import get_logger
+from ...extras.packages import is_transformers_version_greater_than_4_43
 
 
 if TYPE_CHECKING:
@@ -114,7 +114,15 @@ def get_unpad_data(attention_mask: "torch.Tensor") -> Tuple["torch.Tensor", "torch.Tensor", int]:
 
 
 def _patch_for_block_diag_attn(model_type: str) -> None:
-    require_version("transformers>=4.41.2,<=4.42.4", "To fix: pip install transformers>=4.41.2,<=4.42.4")
+    require_version("transformers>=4.41.2,<=4.43.4", "To fix: pip install transformers>=4.41.2,<=4.43.4")
+    if is_transformers_version_greater_than_4_43():
+        import transformers.modeling_flash_attention_utils
+
+        transformers.modeling_flash_attention_utils._get_unpad_data = get_unpad_data
+        return
+
+    import transformers.models
+
     if model_type == "cohere":
         transformers.models.cohere.modeling_cohere._get_unpad_data = get_unpad_data
     elif model_type == "falcon":
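
The patched-in `get_unpad_data` is what makes packed (block-diagonal) attention work: it reports per-sequence lengths inside each packed row rather than one length per padded row, and transformers >= 4.43 reads this helper from a single place, `transformers.modeling_flash_attention_utils`, instead of from every model module, which is why the per-model branches above become dead code there. A toy illustration of the quantities involved, assuming this file's convention that a packed mask labels sequences 1, 2, 3, ... and uses 0 for padding:

```python
# Toy sketch of the unpad data consumed by flash-attn's varlen kernels.
import torch
import torch.nn.functional as F

attention_mask = torch.tensor([[1, 1, 2, 2, 2, 0]])  # two sequences packed into one row

# per-sequence lengths across the pack (the helper's core idea)
seqlens = torch.tensor(
    [(attention_mask == i).sum() for i in range(1, int(attention_mask.max()) + 1)],
    dtype=torch.int32,
)
indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
cu_seqlens = F.pad(torch.cumsum(seqlens, dim=0, dtype=torch.int32), (1, 0))

print(seqlens.tolist())     # [2, 3]
print(indices.tolist())     # [0, 1, 2, 3, 4] -> non-pad token positions
print(cu_seqlens.tolist())  # [0, 2, 5] -> cumulative sequence boundaries
```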

@@ -162,11 +162,12 @@ class PissaConvertCallback(TrainerCallback):
             setattr(model.peft_config["default"], "init_lora_weights", init_lora_weights)
             model.save_pretrained(
                 pissa_convert_dir, safe_serialization=args.save_safetensors, convert_pissa_to_lora=pissa_init_dir
-            )
+            )  # TODO: use `path_initial_model_for_weight_conversion` (peft>=0.12.0)
             model.load_adapter(pissa_backup_dir, "default", is_trainable=True)
             model.set_adapter("default")
-            if "pissa_init" in model.peft_config.keys():
+            if "pissa_init" in model.peft_config.keys():  # backward compatibility (peft<0.12.0)
                 model.delete_adapter("pissa_init")
 
             setattr(model.peft_config["default"], "init_lora_weights", init_lora_weights)
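
On the TODO above: peft 0.12.0 deprecates the `convert_pissa_to_lora` argument of `save_pretrained` in favor of the more general `path_initial_model_for_weight_conversion`. A sketch of the future spelling the comment points at (wrapped as a function for self-containment; not exercised by this patch):

```python
# Sketch only: the peft>=0.12.0 spelling of the PiSSA-to-LoRA conversion.
def convert_pissa_checkpoint(model, pissa_convert_dir: str, pissa_init_dir: str, safe_serialization: bool = True):
    model.save_pretrained(
        pissa_convert_dir,
        safe_serialization=safe_serialization,
        path_initial_model_for_weight_conversion=pissa_init_dir,
    )
```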

@@ -71,7 +71,7 @@ def create_web_demo() -> "gr.Blocks":
     engine = Engine(pure_chat=True)
 
     with gr.Blocks(title="Web Demo", css=CSS) as demo:
-        lang = gr.Dropdown(choices=["en", "zh"])
+        lang = gr.Dropdown(choices=["en", "ru", "zh", "ko"], scale=1)
         engine.manager.add_elems("top", dict(lang=lang))
 
         _, _, chat_elems = create_chat_box(engine, visible=True)

@@ -362,7 +362,6 @@ LOCALES = {
             "label": "학습률",
             "info": "AdamW의 초기 학습률.",
         },
-
     },
     "num_train_epochs": {
        "en": {