Merge branch 'main' into main
Former-commit-id: 7be442f37d53a0c6324728fa1fa8e2c84d7f0fa5
@@ -1,7 +1,22 @@
+# Copyright 2024 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 from typing import TYPE_CHECKING
 
+from transformers.utils import is_flash_attn_2_available, is_torch_sdpa_available
+
 from ...extras.logging import get_logger
-from ...extras.packages import is_flash_attn2_available, is_sdpa_available
 
 
 if TYPE_CHECKING:
@@ -13,21 +28,33 @@ if TYPE_CHECKING:
 logger = get_logger(__name__)
 
 
-def configure_attn_implementation(config: "PretrainedConfig", model_args: "ModelArguments") -> None:
+def configure_attn_implementation(
+    config: "PretrainedConfig", model_args: "ModelArguments", is_trainable: bool
+) -> None:
+    if getattr(config, "model_type", None) == "gemma2" and is_trainable:  # gemma2 adopts soft-cap attention
+        if model_args.flash_attn == "auto":
+            logger.warning("Gemma-2 models should use eager attention in training, change `flash_attn` to disabled.")
+            model_args.flash_attn = "disabled"
+        elif model_args.flash_attn != "disabled":
+            logger.warning(
+                "Gemma-2 models should use eager attention in training, but you set `flash_attn: {}`. "
+                "Will proceed at your own risk.".format(model_args.flash_attn)
+            )
+
     if model_args.flash_attn == "auto":
         return
 
-    elif model_args.flash_attn == "off":
+    elif model_args.flash_attn == "disabled":
         requested_attn_implementation = "eager"
 
     elif model_args.flash_attn == "sdpa":
-        if not is_sdpa_available():
+        if not is_torch_sdpa_available():
             logger.warning("torch>=2.1.1 is required for SDPA attention.")
             return
 
         requested_attn_implementation = "sdpa"
     elif model_args.flash_attn == "fa2":
-        if not is_flash_attn2_available():
+        if not is_flash_attn_2_available():
             logger.warning("FlashAttention-2 is not installed.")
             return
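
Note: for context on how the resolved value is consumed, here is a minimal sketch (not part of this diff) of the way Transformers accepts the requested implementation. It assumes transformers>=4.36, which introduced the `attn_implementation` keyword; the model id is a placeholder.

```python
# Hedged sketch: configure_attn_implementation() above resolves `flash_attn` into one of
# "eager", "sdpa", or "flash_attention_2"; Transformers accepts that string directly.
from transformers import AutoModelForCausalLM

requested_attn_implementation = "sdpa"  # e.g. resolved from `flash_attn: sdpa`
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf",  # placeholder model id, any causal LM works
    attn_implementation=requested_attn_implementation,
)
```
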
@@ -1,3 +1,21 @@
+# Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
+#
+# This code is inspired by the HuggingFace's Transformers and PEFT library.
+# https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/modeling_utils.py
+# https://github.com/huggingface/peft/blob/v0.10.0/src/peft/utils/other.py
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 import inspect
 from functools import partial
 from types import MethodType
@@ -60,15 +78,12 @@ def _fp32_forward_post_hook(
     return output.to(torch.float32)
 
 
-def prepare_model_for_training(
-    model: "PreTrainedModel", model_args: "ModelArguments", output_layer_name: str = "lm_head"
-) -> None:
+def prepare_model_for_training(model: "PreTrainedModel", model_args: "ModelArguments") -> None:
     r"""
     Includes:
         (1) cast the layernorm in fp32
         (2) make output embedding layer require grads
         (3) add the upcasting of the lm_head in fp32
     Inspired by: https://github.com/huggingface/peft/blob/v0.7.1/src/peft/utils/other.py#L72
     """
     if model_args.upcast_layernorm:
         logger.info("Upcasting layernorm weights in float32.")
@@ -87,8 +102,8 @@ def prepare_model_for_training(
             setattr(model.config, "use_cache", False)  # turn off when gradient checkpointing is enabled
             logger.info("Gradient checkpointing enabled.")
 
-    if hasattr(model, output_layer_name) and model_args.upcast_lmhead_output:
-        logger.info("Upcasting lm_head outputs in float32.")
-        output_layer = getattr(model, output_layer_name)
+    if model_args.upcast_lmhead_output:
+        output_layer = model.get_output_embeddings()
         if isinstance(output_layer, torch.nn.Linear) and output_layer.weight.dtype != torch.float32:
+            logger.info("Upcasting lm_head outputs in float32.")
             output_layer.register_forward_hook(_fp32_forward_post_hook)
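
Note: a self-contained sketch of the upcast-hook mechanism relied on above, assuming the hook body mirrors `_fp32_forward_post_hook` (which casts the module output to float32):

```python
import torch


def fp32_post_hook(module: "torch.nn.Module", args: tuple, output: "torch.Tensor") -> "torch.Tensor":
    # same idea as _fp32_forward_post_hook: upcast only the output, keep weights in low precision
    return output.to(torch.float32)


lm_head = torch.nn.Linear(16, 32, dtype=torch.bfloat16)
lm_head.register_forward_hook(fp32_post_hook)
logits = lm_head(torch.randn(2, 16, dtype=torch.bfloat16))
print(logits.dtype)  # torch.float32: outputs are upcast while weights stay bf16
```
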
@@ -1,3 +1,17 @@
+# Copyright 2024 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 import math
 from contextlib import nullcontext
 from typing import TYPE_CHECKING
@@ -1,3 +1,22 @@
+# Copyright 2024 EleutherAI, HuggingFace Inc., Yukang Chen, and the LlamaFactory team.
+#
+# This code is based on the EleutherAI's GPT-NeoX and the HuggingFace's Transformers libraries.
+# https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/models/llama/modeling_llama.py
+# This code is also inspired by the original LongLoRA implementation.
+# https://github.com/dvlab-research/LongLoRA/blob/main/llama_attn_replace.py
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 import math
 from typing import TYPE_CHECKING, Optional, Tuple
@@ -96,7 +115,8 @@ def llama_attention_forward(
             (
                 attn_output[:, :, : self.num_heads // 2],
                 attn_output[:, :, self.num_heads // 2 :].roll(groupsz // 2, dims=1),
-            )
+            ),
+            dim=2,
         )
 
     attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
@@ -181,11 +201,9 @@ def llama_flash_attention_2_forward(
         query_states, key_states, value_states = shift(query_states), shift(key_states), shift(value_states)
         if attention_mask is not None:
             attention_mask = attention_mask[:, :groupsz].repeat(num_groups, 1)
-    else:
-        groupsz = q_len
 
     attn_output: torch.Tensor = self._flash_attention_forward(
-        query_states, key_states, value_states, attention_mask, groupsz, dropout=dropout_rate
+        query_states, key_states, value_states, attention_mask, query_states.size(1), dropout=dropout_rate
     )
 
     if getattr(self.config, "group_size_ratio", None) and self.training:  # shift back
@@ -194,7 +212,8 @@ def llama_flash_attention_2_forward(
             (
                 attn_output[:, :, : self.num_heads // 2],
                 attn_output[:, :, self.num_heads // 2 :].roll(groupsz // 2, dims=1),
-            )
+            ),
+            dim=2,
         )
 
     attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
@@ -293,7 +312,8 @@ def llama_sdpa_attention_forward(
             (
                 attn_output[:, :, : self.num_heads // 2],
                 attn_output[:, :, self.num_heads // 2 :].roll(groupsz // 2, dims=1),
-            )
+            ),
+            dim=2,
         )
 
     attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
@@ -303,7 +323,7 @@ def llama_sdpa_attention_forward(
 
 
 def _apply_llama_patch() -> None:
-    require_version("transformers==4.40.2", "To fix: pip install transformers==4.40.2")
+    require_version("transformers==4.41.2", "To fix: pip install transformers==4.41.2")
     LlamaAttention.forward = llama_attention_forward
     LlamaFlashAttention2.forward = llama_flash_attention_2_forward
     LlamaSdpaAttention.forward = llama_sdpa_attention_forward
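
Note: the `dim=2` additions above fix the shift-back concatenation in LongLoRA's shifted sparse attention. A standalone sketch with toy shapes, assuming `attn_output` is laid out as (bsz, q_len, num_heads, head_dim) as in the patched forwards:

```python
import torch

bsz, q_len, num_heads, head_dim, groupsz = 1, 8, 4, 2, 4
attn_output = torch.randn(bsz, q_len, num_heads, head_dim)

# Re-join the two half-head groups along the head axis (dim=2) after rolling the
# shifted half back along the sequence axis; omitting dim=2 would concatenate
# along the batch axis and produce a wrongly shaped tensor.
restored = torch.cat(
    (
        attn_output[:, :, : num_heads // 2],
        attn_output[:, :, num_heads // 2 :].roll(groupsz // 2, dims=1),
    ),
    dim=2,
)
print(restored.shape)  # torch.Size([1, 8, 4, 2])
```
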
@@ -1,3 +1,17 @@
+# Copyright 2024 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 from typing import TYPE_CHECKING, List
 
 from ...extras.logging import get_logger
@@ -1,3 +1,17 @@
+# Copyright 2024 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 from typing import TYPE_CHECKING
 
 from ...extras.constants import MOD_SUPPORTED_MODELS
@@ -1,5 +1,20 @@
-from typing import TYPE_CHECKING
+# Copyright 2024 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import TYPE_CHECKING, Sequence
 
 import torch
 from transformers.integrations import is_deepspeed_zero3_enabled
 from transformers.utils.versions import require_version
@@ -10,6 +25,13 @@ if TYPE_CHECKING:
     from ...hparams import ModelArguments
 
 
+def _set_z3_leaf_modules(model: "PreTrainedModel", leaf_modules: Sequence["torch.nn.Module"]) -> None:
+    require_version("deepspeed>=0.13.0", "To fix: pip install deepspeed>=0.13.0")
+    from deepspeed.utils import set_z3_leaf_modules  # type: ignore
+
+    set_z3_leaf_modules(model, leaf_modules)
+
+
 def add_z3_leaf_module(model: "PreTrainedModel") -> None:
     r"""
     Sets module as a leaf module to skip partitioning in deepspeed zero3.
@@ -17,33 +39,30 @@ def add_z3_leaf_module(model: "PreTrainedModel") -> None:
     if not is_deepspeed_zero3_enabled():
         return
 
-    require_version("deepspeed>=0.13.0", "To fix: pip install deepspeed>=0.13.0")
-    from deepspeed.utils import set_z3_leaf_modules  # type: ignore
-
     if getattr(model.config, "model_type", None) == "dbrx":
         from transformers.models.dbrx.modeling_dbrx import DbrxFFN
 
-        set_z3_leaf_modules(model, [DbrxFFN])
+        _set_z3_leaf_modules(model, [DbrxFFN])
 
     if getattr(model.config, "model_type", None) == "jamba":
         from transformers.models.jamba.modeling_jamba import JambaSparseMoeBlock
 
-        set_z3_leaf_modules(model, [JambaSparseMoeBlock])
+        _set_z3_leaf_modules(model, [JambaSparseMoeBlock])
 
     if getattr(model.config, "model_type", None) == "jetmoe":
         from transformers.models.jetmoe.modeling_jetmoe import JetMoeMoA, JetMoeMoE
 
-        set_z3_leaf_modules(model, [JetMoeMoA, JetMoeMoE])
+        _set_z3_leaf_modules(model, [JetMoeMoA, JetMoeMoE])
 
     if getattr(model.config, "model_type", None) == "mixtral":
         from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock
 
-        set_z3_leaf_modules(model, [MixtralSparseMoeBlock])
+        _set_z3_leaf_modules(model, [MixtralSparseMoeBlock])
 
     if getattr(model.config, "model_type", None) == "qwen2moe":
         from transformers.models.qwen2_moe.modeling_qwen2_moe import Qwen2MoeSparseMoeBlock
 
-        set_z3_leaf_modules(model, [Qwen2MoeSparseMoeBlock])
+        _set_z3_leaf_modules(model, [Qwen2MoeSparseMoeBlock])
 
 
 def configure_moe(config: "PretrainedConfig", model_args: "ModelArguments", is_trainable: bool) -> None:
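
Note: the refactor funnels every architecture branch through `_set_z3_leaf_modules`, which owns the version check and the deepspeed import. A hedged sketch of the underlying mechanism, assuming deepspeed>=0.13.0 is installed as required above; the toy module stands in for a real MoE block such as MixtralSparseMoeBlock:

```python
import torch
from deepspeed.utils import set_z3_leaf_modules  # type: ignore


class TinyMoeBlock(torch.nn.Module):  # stand-in for a real sparse MoE block
    def __init__(self) -> None:
        super().__init__()
        self.experts = torch.nn.ModuleList(torch.nn.Linear(8, 8) for _ in range(2))

    def forward(self, x: "torch.Tensor") -> "torch.Tensor":
        return self.experts[0](x) + self.experts[1](x)


model = torch.nn.Sequential(TinyMoeBlock())
# Marking the block as a ZeRO-3 leaf keeps its parameters gathered as one unit,
# which the data-dependent expert dispatch requires.
set_z3_leaf_modules(model, [TinyMoeBlock])
```
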
@@ -1,3 +1,21 @@
+# Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
+#
+# This code is inspired by the HuggingFace's Transformers and Optimum library.
+# https://github.com/huggingface/transformers/blob/v4.41.0/src/transformers/utils/quantization_config.py
+# https://github.com/huggingface/optimum/blob/v1.20.0/optimum/gptq/data.py
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 import os
 import random
 from enum import Enum, unique
@@ -5,7 +23,7 @@ from typing import TYPE_CHECKING, Any, Dict, List
 
 import torch
 from datasets import load_dataset
-from transformers import BitsAndBytesConfig, GPTQConfig
+from transformers import BitsAndBytesConfig, EetqConfig, GPTQConfig, HqqConfig
 from transformers.integrations import is_deepspeed_zero3_enabled
 from transformers.modeling_utils import is_fsdp_enabled
 from transformers.utils.versions import require_version
@@ -39,10 +57,9 @@ class QuantizationMethod(str, Enum):
     HQQ = "hqq"
 
 
-def _get_quantization_dataset(tokenizer: "PreTrainedTokenizer", model_args: "ModelArguments") -> List[str]:
+def _get_quantization_dataset(tokenizer: "PreTrainedTokenizer", model_args: "ModelArguments") -> List[Dict[str, Any]]:
     r"""
-    Inspired by: https://github.com/huggingface/optimum/blob/v1.16.0/optimum/gptq/data.py#L133
-    TODO: remove tokenizer.decode() https://github.com/huggingface/optimum/pull/1600
+    Prepares the tokenized dataset to perform AutoGPTQ. Do not use tensor output for JSON serialization.
     """
     if os.path.isfile(model_args.export_quantization_dataset):
         data_path = FILEEXT2TYPE.get(model_args.export_quantization_dataset.split(".")[-1], None)
@@ -51,20 +68,32 @@ def _get_quantization_dataset(tokenizer: "PreTrainedTokenizer", model_args: "ModelArguments"
         data_path = model_args.export_quantization_dataset
         data_files = None
 
-    dataset = load_dataset(path=data_path, data_files=data_files, split="train", cache_dir=model_args.cache_dir)
-    maxlen = model_args.export_quantization_maxlen
+    dataset = load_dataset(
+        path=data_path,
+        data_files=data_files,
+        split="train",
+        cache_dir=model_args.cache_dir,
+        token=model_args.hf_hub_token,
+    )
 
     samples = []
+    maxlen = model_args.export_quantization_maxlen
     for _ in range(model_args.export_quantization_nsamples):
+        n_try = 0
         while True:
+            if n_try > 100:
+                raise ValueError("Cannot find satisfying example, considering decrease `export_quantization_maxlen`.")
+
             sample_idx = random.randint(0, len(dataset) - 1)
-            sample: Dict[str, torch.Tensor] = tokenizer(dataset[sample_idx]["text"], return_tensors="pt")
-            if sample["input_ids"].size(1) >= maxlen:
+            sample: Dict[str, "torch.Tensor"] = tokenizer(dataset[sample_idx]["text"], return_tensors="pt")
+            n_try += 1
+            if sample["input_ids"].size(1) > maxlen:
                 break  # TODO: fix large maxlen
 
         word_idx = random.randint(0, sample["input_ids"].size(1) - maxlen - 1)
         input_ids = sample["input_ids"][:, word_idx : word_idx + maxlen]
-        samples.append(tokenizer.decode(input_ids[0].tolist(), skip_special_tokens=True))
+        attention_mask = sample["attention_mask"][:, word_idx : word_idx + maxlen]
+        samples.append({"input_ids": input_ids.tolist(), "attention_mask": attention_mask.tolist()})
 
     return samples
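
Note: `_get_quantization_dataset` now returns pre-tokenized, JSON-serializable samples instead of decoded strings. A sketch of the new per-sample shape, using a placeholder tokenizer for illustration:

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # placeholder tokenizer
enc = tokenizer("an example calibration sentence", return_tensors="pt")

# .tolist() turns the 1 x seq_len tensors into nested Python lists, so the sample
# survives JSON serialization when handed to GPTQConfig as a calibration dataset.
sample = {
    "input_ids": enc["input_ids"].tolist(),
    "attention_mask": enc["attention_mask"].tolist(),
}
print(type(sample["input_ids"][0][0]))  # <class 'int'>
```
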
@@ -76,14 +105,14 @@ def configure_quantization(
     init_kwargs: Dict[str, Any],
 ) -> None:
     r"""
-    Priority: PTQ-quantized (training) > AutoGPTQ (export) > Bitsandbytes (training)
+    Priority: PTQ-quantized (train/infer) > AutoGPTQ (export) > On-the-fly quantization (train/infer)
     """
     if getattr(config, "quantization_config", None):  # ptq
-        if is_deepspeed_zero3_enabled():
-            raise ValueError("DeepSpeed ZeRO-3 is incompatible with quantized models.")
-
         if model_args.quantization_bit is not None:
             logger.warning("`quantization_bit` will not affect on the PTQ-quantized models.")
 
-        if model_args.quantization_device_map != "auto":
-            init_kwargs["device_map"] = {"": get_current_device()}
+        if is_deepspeed_zero3_enabled() or is_fsdp_enabled():
+            raise ValueError("DeepSpeed ZeRO-3 or FSDP is incompatible with PTQ-quantized models.")
 
         quantization_config: Dict[str, Any] = getattr(config, "quantization_config", None)
         quant_method = quantization_config.get("quant_method", "")
@@ -105,46 +134,72 @@ def configure_quantization(
             logger.info("Loading {}-bit {}-quantized model.".format(quant_bits, quant_method.upper()))
 
     elif model_args.export_quantization_bit is not None:  # auto-gptq
-        require_version("optimum>=1.16.0", "To fix: pip install optimum>=1.16.0")
         if model_args.export_quantization_bit not in [8, 4, 3, 2]:
             raise ValueError("AutoGPTQ only accepts 2/3/4/8-bit quantization.")
 
+        require_version("optimum>=1.17.0", "To fix: pip install optimum>=1.17.0")
+        require_version("auto_gptq>=0.5.0", "To fix: pip install auto_gptq>=0.5.0")
+        from accelerate.utils import get_max_memory
+
         if getattr(config, "model_type", None) == "chatglm":
-            raise ValueError("ChatGLM model is not supported.")
+            raise ValueError("ChatGLM model is not supported yet.")
 
         init_kwargs["quantization_config"] = GPTQConfig(
             bits=model_args.export_quantization_bit,
             tokenizer=tokenizer,
             dataset=_get_quantization_dataset(tokenizer, model_args),
         )
         init_kwargs["device_map"] = "auto"
         init_kwargs["max_memory"] = get_max_memory()
-        logger.info("Quantizing model to {} bit.".format(model_args.export_quantization_bit))
+        logger.info("Quantizing model to {} bit with AutoGPTQ.".format(model_args.export_quantization_bit))
 
-    elif model_args.quantization_bit is not None:  # bnb
-        if model_args.quantization_bit == 8:
-            require_version("bitsandbytes>=0.37.0", "To fix: pip install bitsandbytes>=0.37.0")
-            init_kwargs["quantization_config"] = BitsAndBytesConfig(load_in_8bit=True)
-
-        elif model_args.quantization_bit == 4:
-            require_version("bitsandbytes>=0.39.0", "To fix: pip install bitsandbytes>=0.39.0")
-            init_kwargs["quantization_config"] = BitsAndBytesConfig(
-                load_in_4bit=True,
-                bnb_4bit_compute_dtype=model_args.compute_dtype,
-                bnb_4bit_use_double_quant=model_args.double_quantization,
-                bnb_4bit_quant_type=model_args.quantization_type,
-                bnb_4bit_quant_storage=model_args.compute_dtype,  # crucial for fsdp+qlora
-            )
-
-        if is_deepspeed_zero3_enabled() or is_fsdp_enabled() or model_args.quantization_device_map == "auto":
-            if model_args.quantization_bit != 4:
-                raise ValueError("Only 4-bit quantized model can use auto device map.")
-
-            require_version("transformers>=4.39.0", "To fix: pip install transformers>=4.39.0")
-            require_version("accelerate>=0.28.0", "To fix: pip install accelerate>=0.28.0")
-            require_version("bitsandbytes>=0.43.0", "To fix: pip install bitsandbytes>=0.43.0")
-            init_kwargs["torch_dtype"] = model_args.compute_dtype  # fsdp+qlora requires same dtype
-        else:
-            init_kwargs["device_map"] = {"": get_current_device()}
-
-        logger.info("Quantizing model to {} bit.".format(model_args.quantization_bit))
+    elif model_args.quantization_bit is not None:  # on-the-fly
+        if model_args.quantization_method == QuantizationMethod.BITS_AND_BYTES.value:
+            if model_args.quantization_bit == 8:
+                require_version("bitsandbytes>=0.37.0", "To fix: pip install bitsandbytes>=0.37.0")
+                init_kwargs["quantization_config"] = BitsAndBytesConfig(load_in_8bit=True)
+            elif model_args.quantization_bit == 4:
+                require_version("bitsandbytes>=0.39.0", "To fix: pip install bitsandbytes>=0.39.0")
+                init_kwargs["quantization_config"] = BitsAndBytesConfig(
+                    load_in_4bit=True,
+                    bnb_4bit_compute_dtype=model_args.compute_dtype,
+                    bnb_4bit_use_double_quant=model_args.double_quantization,
+                    bnb_4bit_quant_type=model_args.quantization_type,
+                    bnb_4bit_quant_storage=model_args.compute_dtype,  # crucial for fsdp+qlora
+                )
+            else:
+                raise ValueError("Bitsandbytes only accepts 4-bit or 8-bit quantization.")
+
+            # Do not assign device map if:
+            # 1. deepspeed zero3 or fsdp (train)
+            # 2. auto quantization device map (inference)
+            if is_deepspeed_zero3_enabled() or is_fsdp_enabled() or model_args.quantization_device_map == "auto":
+                if model_args.quantization_bit != 4:
+                    raise ValueError("Only 4-bit quantized model can use fsdp+qlora or auto device map.")
+
+                require_version("bitsandbytes>=0.43.0", "To fix: pip install bitsandbytes>=0.43.0")
+            else:
+                init_kwargs["device_map"] = {"": get_current_device()}  # change auto device map for inference
+
+            logger.info("Quantizing model to {} bit with bitsandbytes.".format(model_args.quantization_bit))
+        elif model_args.quantization_method == QuantizationMethod.HQQ.value:
+            if model_args.quantization_bit not in [8, 6, 5, 4, 3, 2, 1]:
+                raise ValueError("HQQ only accepts 1/2/3/4/5/6/8-bit quantization.")
+
+            if is_deepspeed_zero3_enabled() or is_fsdp_enabled():
+                raise ValueError("HQQ quantization is incompatible with DeepSpeed ZeRO-3 or FSDP.")
+
+            require_version("hqq", "To fix: pip install hqq")
+            init_kwargs["quantization_config"] = HqqConfig(
+                nbits=model_args.quantization_bit, quant_zero=False, quant_scale=False, axis=0
+            )  # use ATEN kernel (axis=0) for performance
+            logger.info("Quantizing model to {} bit with HQQ.".format(model_args.quantization_bit))
+        elif model_args.quantization_method == QuantizationMethod.EETQ.value:
+            if model_args.quantization_bit != 8:
+                raise ValueError("EETQ only accepts 8-bit quantization.")
+
+            if is_deepspeed_zero3_enabled() or is_fsdp_enabled():
+                raise ValueError("EETQ quantization is incompatible with DeepSpeed ZeRO-3 or FSDP.")
+
+            require_version("eetq", "To fix: pip install eetq")
+            init_kwargs["quantization_config"] = EetqConfig()
+            logger.info("Quantizing model to {} bit with EETQ.".format(model_args.quantization_bit))
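
Note: a condensed sketch of the new on-the-fly dispatch, with made-up values standing in for the ModelArguments fields; the config classes and their keyword arguments are the same ones used in this diff, so their exact signatures depend on the pinned transformers version:

```python
import torch
from transformers import BitsAndBytesConfig, EetqConfig, HqqConfig

quantization_method, quantization_bit = "bitsandbytes", 4  # e.g. a QLoRA-style run

if quantization_method == "bitsandbytes":
    quant_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.bfloat16)
elif quantization_method == "hqq":
    quant_config = HqqConfig(nbits=quantization_bit, quant_zero=False, quant_scale=False, axis=0)
elif quantization_method == "eetq":
    quant_config = EetqConfig()  # EETQ is 8-bit only, as validated above
```
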
@@ -1,3 +1,21 @@
+# Copyright 2024 LMSYS and the LlamaFactory team.
+# Copyright 2023 Rohan Taori, Ishaan Gulrajani, Tianyi Zhang, Yann Dubois, Xuechen Li
+#
+# This code is inspired by the LMSYS's FastChat library.
+# https://github.com/lm-sys/FastChat/blob/v0.2.30/fastchat/train/train.py
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 import math
 from typing import TYPE_CHECKING
@@ -21,8 +39,8 @@ def configure_rope(config: "PretrainedConfig", model_args: "ModelArguments", is_trainable: bool) -> None:
         logger.warning("Current model does not support RoPE scaling.")
         return
 
-    if is_trainable:
-        if model_args.rope_scaling == "dynamic":
+    if model_args.model_max_length is not None:
+        if is_trainable and model_args.rope_scaling == "dynamic":
             logger.warning(
                 "Dynamic NTK scaling may not work well with fine-tuning. "
                 "See: https://github.com/huggingface/transformers/pull/24653"
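
Note: a worked sketch of why the guard now keys on `model_max_length`; static "linear" scaling derives its factor from that field, so the numbers below are illustrative only:

```python
import math

max_position_embeddings = 4096  # e.g. the pretraining context of a Llama-2 model
model_max_length = 16384        # hypothetical user-requested context

# linear RoPE scaling stretches position ids by a constant factor
scaling_factor = float(math.ceil(model_max_length / max_position_embeddings))
rope_scaling = {"type": "linear", "factor": scaling_factor}
print(rope_scaling)  # {'type': 'linear', 'factor': 4.0}
```
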
@@ -1,3 +1,17 @@
+# Copyright 2024 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 from typing import TYPE_CHECKING, Any, Dict, Optional
 
 from ...extras.logging import get_logger
@@ -1,3 +1,17 @@
+# Copyright 2024 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 from typing import TYPE_CHECKING, Dict
 
 import torch
@@ -1,3 +1,20 @@
+# Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
+#
+# This code is inspired by the HuggingFace's Transformers library.
+# https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/models/llava/modeling_llava.py
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 from typing import TYPE_CHECKING, Tuple
 
 import torch