Merge branch 'main' into main
Former-commit-id: 7be442f37d53a0c6324728fa1fa8e2c84d7f0fa5
@@ -1,9 +1,25 @@
# Copyright 2024 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from .loader import load_config, load_model, load_tokenizer
from .model_utils.misc import find_all_linear_modules
from .model_utils.quantization import QuantizationMethod
from .model_utils.valuehead import load_valuehead_params


__all__ = [
"QuantizationMethod",
"load_config",
"load_model",
"load_tokenizer",
@@ -1,3 +1,17 @@
# Copyright 2024 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import re
from typing import TYPE_CHECKING
@@ -25,8 +39,12 @@ def _setup_full_tuning(
model: "PreTrainedModel",
model_args: "ModelArguments",
finetuning_args: "FinetuningArguments",
is_trainable: bool,
cast_trainable_params_to_fp32: bool,
) -> None:
if not is_trainable:
return

logger.info("Fine-tuning method: Full")
forbidden_modules = set()
if model_args.visual_inputs and finetuning_args.freeze_vision_tower:
@@ -47,8 +65,12 @@ def _setup_freeze_tuning(
model: "PreTrainedModel",
model_args: "ModelArguments",
finetuning_args: "FinetuningArguments",
is_trainable: bool,
cast_trainable_params_to_fp32: bool,
) -> None:
if not is_trainable:
return

logger.info("Fine-tuning method: Freeze")
if model_args.visual_inputs:
config = model.config.text_config
@@ -132,7 +154,9 @@ def _setup_lora_tuning(
is_trainable: bool,
cast_trainable_params_to_fp32: bool,
) -> "PeftModel":
logger.info("Fine-tuning method: {}".format("DoRA" if finetuning_args.use_dora else "LoRA"))
if is_trainable:
logger.info("Fine-tuning method: {}".format("DoRA" if finetuning_args.use_dora else "LoRA"))

adapter_to_resume = None

if model_args.adapter_name_or_path is not None:
@@ -155,8 +179,16 @@ def _setup_lora_tuning(
else:
adapter_to_merge = model_args.adapter_name_or_path

init_kwargs = {
"subfolder": model_args.adapter_folder,
"offload_folder": model_args.offload_folder,
"cache_dir": model_args.cache_dir,
"revision": model_args.model_revision,
"token": model_args.hf_hub_token,
}

for adapter in adapter_to_merge:
model: "LoraModel" = PeftModel.from_pretrained(model, adapter, offload_folder=model_args.offload_folder)
model: "LoraModel" = PeftModel.from_pretrained(model, adapter, **init_kwargs)
model = model.merge_and_unload()

if len(adapter_to_merge) > 0:
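The hunk above collects every adapter-loading option (subfolder, offload folder, cache dir, revision, token) into one init_kwargs dict and reuses it for each adapter that gets merged. A minimal sketch of the same pattern with PEFT, assuming a loaded base model and a list of adapter paths (the helper name is illustrative, not part of the project):

from peft import PeftModel

def merge_adapters(model, adapter_paths, init_kwargs):
    # Load each adapter with the shared kwargs and fold its weights into the base model.
    for adapter in adapter_paths:
        model = PeftModel.from_pretrained(model, adapter, **init_kwargs)
        model = model.merge_and_unload()
    return model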
@@ -166,12 +198,9 @@ def _setup_lora_tuning(
if model_args.use_unsloth:
model = load_unsloth_peft_model(config, model_args, is_trainable=is_trainable)
else:
model = PeftModel.from_pretrained(
model,
adapter_to_resume,
is_trainable=is_trainable,
offload_folder=model_args.offload_folder,
)
model = PeftModel.from_pretrained(model, adapter_to_resume, is_trainable=is_trainable, **init_kwargs)

logger.info("Loaded adapter(s): {}".format(",".join(model_args.adapter_name_or_path)))

if is_trainable and adapter_to_resume is None: # create new lora weights while training
if len(finetuning_args.lora_target) == 1 and finetuning_args.lora_target[0] == "all":
@@ -209,16 +238,24 @@ def _setup_lora_tuning(
"lora_alpha": finetuning_args.lora_alpha,
"lora_dropout": finetuning_args.lora_dropout,
"use_rslora": finetuning_args.use_rslora,
"use_dora": finetuning_args.use_dora,
"modules_to_save": finetuning_args.additional_target,
}

if model_args.use_unsloth:
model = get_unsloth_peft_model(model, model_args, peft_kwargs)
else:
if finetuning_args.pissa_init:
if finetuning_args.pissa_iter == -1:
logger.info("Using PiSSA initialization.")
peft_kwargs["init_lora_weights"] = "pissa"
else:
logger.info("Using PiSSA initialization with FSVD steps {}.".format(finetuning_args.pissa_iter))
peft_kwargs["init_lora_weights"] = "pissa_niter_{}".format(finetuning_args.pissa_iter)

lora_config = LoraConfig(
task_type=TaskType.CAUSAL_LM,
inference_mode=False,
use_dora=finetuning_args.use_dora,
**peft_kwargs,
)
model = get_peft_model(model, lora_config)
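Per the hunk above, PiSSA is requested through PEFT's init_lora_weights field: plain "pissa" runs a full SVD, while "pissa_niter_{n}" uses n fast-SVD iterations. A hedged sketch of assembling such a config (assumes a PEFT release that supports PiSSA initialization; the rank and alpha values are placeholders):

from peft import LoraConfig, TaskType

def build_lora_config(rank, alpha, pissa_iter=None):
    peft_kwargs = {"r": rank, "lora_alpha": alpha}
    if pissa_iter is not None:
        # -1 means an exact SVD; otherwise run the given number of FSVD steps.
        peft_kwargs["init_lora_weights"] = "pissa" if pissa_iter == -1 else "pissa_niter_{}".format(pissa_iter)
    return LoraConfig(task_type=TaskType.CAUSAL_LM, inference_mode=False, **peft_kwargs)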
@@ -227,9 +264,6 @@ def _setup_lora_tuning(
for param in filter(lambda p: p.requires_grad, model.parameters()):
param.data = param.data.to(torch.float32)

if model_args.adapter_name_or_path is not None:
logger.info("Loaded adapter(s): {}".format(",".join(model_args.adapter_name_or_path)))

return model
@@ -247,29 +281,36 @@ def init_adapter(

Note that the trainable parameters must be cast to float32.
"""
if (not is_trainable) and model_args.adapter_name_or_path is None:
logger.info("Adapter is not found at evaluation, load the base model.")
return model
if is_trainable and getattr(model, "quantization_method", None) is not None:
if finetuning_args.finetuning_type != "lora":
raise ValueError("Quantized models can only be used for the LoRA tuning.")

if finetuning_args.finetuning_type != "lora" and getattr(model, "quantization_method", None):
raise ValueError("You can only use lora for quantized models.")
if finetuning_args.pissa_init:
raise ValueError("Cannot initialize PiSSA adapter on quantized models.")

if is_deepspeed_zero3_enabled() or is_fsdp_enabled() or finetuning_args.pure_bf16 or finetuning_args.use_badam:
logger.info("ZeRO3/FSDP/PureBF16/BAdam detected, remaining trainable params as their original precision.")
cast_trainable_params_to_fp32 = False
# cast trainable parameters to float32 if:
# 1. is_trainable and not pure_bf16 and not badam and quantization_bit is not None (qlora)
# 2. is_trainable and not pure_bf16 and not badam and not zero3 and not fsdp (zero3 or fsdp already in fp32)
cast_trainable_params_to_fp32 = False
if not is_trainable:
pass
elif finetuning_args.pure_bf16 or finetuning_args.use_badam:
logger.info("Pure bf16 / BAdam detected, remaining trainable params in half precision.")
elif model_args.quantization_bit is None and (is_deepspeed_zero3_enabled() or is_fsdp_enabled()):
logger.info("ZeRO3 / FSDP detected, remaining trainable params in float32.")
else:
logger.info("Upcasting trainable params to float32.")
cast_trainable_params_to_fp32 = True

if is_trainable and finetuning_args.finetuning_type == "full":
_setup_full_tuning(model, model_args, finetuning_args, cast_trainable_params_to_fp32)

if is_trainable and finetuning_args.finetuning_type == "freeze":
_setup_freeze_tuning(model, model_args, finetuning_args, cast_trainable_params_to_fp32)

if finetuning_args.finetuning_type == "lora":
if finetuning_args.finetuning_type == "full":
_setup_full_tuning(model, model_args, finetuning_args, is_trainable, cast_trainable_params_to_fp32)
elif finetuning_args.finetuning_type == "freeze":
_setup_freeze_tuning(model, model_args, finetuning_args, is_trainable, cast_trainable_params_to_fp32)
elif finetuning_args.finetuning_type == "lora":
model = _setup_lora_tuning(
config, model, model_args, finetuning_args, is_trainable, cast_trainable_params_to_fp32
)
else:
raise NotImplementedError("Unknown finetuning type: {}.".format(finetuning_args.finetuning_type))

return model
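The rewritten init_adapter makes the float32-upcasting decision once and then dispatches on finetuning_type. A small sketch of that decision, assuming boolean flags with the same meaning as the arguments above (the helper itself is illustrative):

def should_cast_to_fp32(is_trainable, pure_bf16, use_badam, quantization_bit, zero3_or_fsdp):
    # Mirrors the branch order above: no casting at inference, under pure bf16/BAdam,
    # or under ZeRO-3/FSDP without quantization (those already keep params in float32).
    if not is_trainable:
        return False
    if pure_bf16 or use_badam:
        return False
    if quantization_bit is None and zero3_or_fsdp:
        return False
    return True  # e.g. plain LoRA, or QLoRA where the base model is quantized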
@@ -1,10 +1,25 @@
# Copyright 2024 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import TYPE_CHECKING, Any, Dict, Optional, TypedDict

import torch
from transformers import AutoConfig, AutoModelForCausalLM, AutoModelForVision2Seq, AutoProcessor, AutoTokenizer
from trl import AutoModelForCausalLMWithValueHead

from ..extras.logging import get_logger
from ..extras.misc import count_parameters, try_download_model_from_ms
from ..extras.misc import count_parameters, skip_check_imports, try_download_model_from_ms
from .adapter import init_adapter
from .model_utils.misc import register_autoclass
from .model_utils.mod import convert_pretrained_model_to_mod, load_mod_pretrained_model
@@ -33,6 +48,7 @@ def _get_init_kwargs(model_args: "ModelArguments") -> Dict[str, Any]:

Note: including inplace operation of model_args.
"""
skip_check_imports()
model_args.model_name_or_path = try_download_model_from_ms(model_args)
return {
"trust_remote_code": True,
@@ -162,17 +178,21 @@ def load_model(

if not is_trainable:
model.requires_grad_(False)
for param in model.parameters():
if param.data.dtype == torch.float32 and model_args.compute_dtype != torch.float32:
param.data = param.data.to(model_args.compute_dtype)

model.eval()
else:
model.train()

trainable_params, all_param = count_parameters(model)
if is_trainable:
param_stats = "trainable params: {:d} || all params: {:d} || trainable%: {:.4f}".format(
param_stats = "trainable params: {:,} || all params: {:,} || trainable%: {:.4f}".format(
trainable_params, all_param, 100 * trainable_params / all_param
)
else:
param_stats = "all params: {:d}".format(all_param)
param_stats = "all params: {:,}".format(all_param)

logger.info(param_stats)
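The format strings switch from {:d} to {:,}, so the parameter counts are printed with thousands separators. A quick illustration with made-up counts:

trainable_params, all_param = 4_194_304, 6_738_415_616  # illustrative numbers only
print("trainable params: {:,} || all params: {:,} || trainable%: {:.4f}".format(
    trainable_params, all_param, 100 * trainable_params / all_param
))
# prints: trainable params: 4,194,304 || all params: 6,738,415,616 || trainable%: 0.0622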
@@ -1,7 +1,22 @@
# Copyright 2024 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import TYPE_CHECKING

from transformers.utils import is_flash_attn_2_available, is_torch_sdpa_available

from ...extras.logging import get_logger
from ...extras.packages import is_flash_attn2_available, is_sdpa_available


if TYPE_CHECKING:
@@ -13,21 +28,33 @@ if TYPE_CHECKING:
logger = get_logger(__name__)


def configure_attn_implementation(config: "PretrainedConfig", model_args: "ModelArguments") -> None:
def configure_attn_implementation(
config: "PretrainedConfig", model_args: "ModelArguments", is_trainable: bool
) -> None:
if getattr(config, "model_type", None) == "gemma2" and is_trainable: # gemma2 adopts soft-cap attention
if model_args.flash_attn == "auto":
logger.warning("Gemma-2 models should use eager attention in training, change `flash_attn` to disabled.")
model_args.flash_attn = "disabled"
elif model_args.flash_attn != "disabled":
logger.warning(
"Gemma-2 models should use eager attention in training, but you set `flash_attn: {}`. "
"Will proceed at your own risk.".format(model_args.flash_attn)
)

if model_args.flash_attn == "auto":
return

elif model_args.flash_attn == "off":
elif model_args.flash_attn == "disabled":
requested_attn_implementation = "eager"

elif model_args.flash_attn == "sdpa":
if not is_sdpa_available():
if not is_torch_sdpa_available():
logger.warning("torch>=2.1.1 is required for SDPA attention.")
return

requested_attn_implementation = "sdpa"
elif model_args.flash_attn == "fa2":
if not is_flash_attn2_available():
if not is_flash_attn_2_available():
logger.warning("FlashAttention-2 is not installed.")
return
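The hunk threads is_trainable into configure_attn_implementation, forces eager attention when training Gemma-2 (whose soft-capped attention logits are not handled by the faster kernels here), and switches the availability checks to the transformers helpers imported above. A compact sketch of the selection logic under those assumptions (the function name is illustrative):

from transformers.utils import is_flash_attn_2_available, is_torch_sdpa_available

def pick_attn_implementation(flash_attn: str, model_type: str, is_trainable: bool):
    # Gemma-2 uses attention logit soft-capping, so training falls back to eager attention.
    if model_type == "gemma2" and is_trainable:
        return "eager"
    if flash_attn == "disabled":
        return "eager"
    if flash_attn == "sdpa" and is_torch_sdpa_available():
        return "sdpa"
    if flash_attn == "fa2" and is_flash_attn_2_available():
        return "flash_attention_2"
    return None  # "auto" or unavailable backend: let transformers decide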
@@ -1,3 +1,21 @@
# Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
#
# This code is inspired by the HuggingFace's Transformers and PEFT library.
# https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/modeling_utils.py
# https://github.com/huggingface/peft/blob/v0.10.0/src/peft/utils/other.py
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import inspect
from functools import partial
from types import MethodType
@@ -60,15 +78,12 @@ def _fp32_forward_post_hook(
return output.to(torch.float32)


def prepare_model_for_training(
model: "PreTrainedModel", model_args: "ModelArguments", output_layer_name: str = "lm_head"
) -> None:
def prepare_model_for_training(model: "PreTrainedModel", model_args: "ModelArguments") -> None:
r"""
Includes:
(1) cast the layernorm in fp32
(2) make output embedding layer require grads
(3) add the upcasting of the lm_head in fp32
Inspired by: https://github.com/huggingface/peft/blob/v0.7.1/src/peft/utils/other.py#L72
"""
if model_args.upcast_layernorm:
logger.info("Upcasting layernorm weights in float32.")
@@ -87,8 +102,8 @@ def prepare_model_for_training(
setattr(model.config, "use_cache", False) # turn off when gradient checkpointing is enabled
logger.info("Gradient checkpointing enabled.")

if hasattr(model, output_layer_name) and model_args.upcast_lmhead_output:
logger.info("Upcasting lm_head outputs in float32.")
output_layer = getattr(model, output_layer_name)
if model_args.upcast_lmhead_output:
output_layer = model.get_output_embeddings()
if isinstance(output_layer, torch.nn.Linear) and output_layer.weight.dtype != torch.float32:
logger.info("Upcasting lm_head outputs in float32.")
output_layer.register_forward_hook(_fp32_forward_post_hook)
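prepare_model_for_training now locates the output layer via get_output_embeddings() instead of a hard-coded lm_head attribute, and only hooks it when it is a half-precision Linear. A self-contained sketch of that hook (assumes a model exposing get_output_embeddings; the wrapper function name is illustrative):

import torch

def _fp32_forward_post_hook(module, args, output):
    # Cast the logits back to float32 after the half-precision matmul.
    return output.to(torch.float32)

def upcast_lmhead_output(model):
    output_layer = model.get_output_embeddings()
    if isinstance(output_layer, torch.nn.Linear) and output_layer.weight.dtype != torch.float32:
        output_layer.register_forward_hook(_fp32_forward_post_hook)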
@@ -1,3 +1,17 @@
# Copyright 2024 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math
from contextlib import nullcontext
from typing import TYPE_CHECKING
@@ -1,3 +1,22 @@
# Copyright 2024 EleutherAI, HuggingFace Inc., Yukang Chen, and the LlamaFactory team.
#
# This code is based on the EleutherAI's GPT-NeoX and the HuggingFace's Transformers libraries.
# https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/models/llama/modeling_llama.py
# This code is also inspired by the original LongLoRA implementation.
# https://github.com/dvlab-research/LongLoRA/blob/main/llama_attn_replace.py
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math
from typing import TYPE_CHECKING, Optional, Tuple
@@ -96,7 +115,8 @@ def llama_attention_forward(
(
attn_output[:, :, : self.num_heads // 2],
attn_output[:, :, self.num_heads // 2 :].roll(groupsz // 2, dims=1),
)
),
dim=2,
)

attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
@@ -181,11 +201,9 @@ def llama_flash_attention_2_forward(
query_states, key_states, value_states = shift(query_states), shift(key_states), shift(value_states)
if attention_mask is not None:
attention_mask = attention_mask[:, :groupsz].repeat(num_groups, 1)
else:
groupsz = q_len

attn_output: torch.Tensor = self._flash_attention_forward(
query_states, key_states, value_states, attention_mask, groupsz, dropout=dropout_rate
query_states, key_states, value_states, attention_mask, query_states.size(1), dropout=dropout_rate
)

if getattr(self.config, "group_size_ratio", None) and self.training: # shift back
@@ -194,7 +212,8 @@ def llama_flash_attention_2_forward(
(
attn_output[:, :, : self.num_heads // 2],
attn_output[:, :, self.num_heads // 2 :].roll(groupsz // 2, dims=1),
)
),
dim=2,
)

attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
@@ -293,7 +312,8 @@ def llama_sdpa_attention_forward(
(
attn_output[:, :, : self.num_heads // 2],
attn_output[:, :, self.num_heads // 2 :].roll(groupsz // 2, dims=1),
)
),
dim=2,
)

attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
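All three attention paths share the same shift-back step: the first half of the heads is kept in place and the second half is rolled forward by half a group along the sequence axis, undoing LongLoRA's shifted grouping before the final reshape. A standalone sketch of that operation (the (batch, seq_len, num_heads, head_dim) layout is assumed from the surrounding code):

import torch

def shift_back(attn_output: torch.Tensor, num_heads: int, groupsz: int) -> torch.Tensor:
    # Keep the first half of the heads; roll the second half by half a group over the sequence.
    return torch.cat(
        (
            attn_output[:, :, : num_heads // 2],
            attn_output[:, :, num_heads // 2 :].roll(groupsz // 2, dims=1),
        ),
        dim=2,
    )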
@@ -303,7 +323,7 @@ def llama_sdpa_attention_forward(


def _apply_llama_patch() -> None:
require_version("transformers==4.40.2", "To fix: pip install transformers==4.40.2")
require_version("transformers==4.41.2", "To fix: pip install transformers==4.41.2")
LlamaAttention.forward = llama_attention_forward
LlamaFlashAttention2.forward = llama_flash_attention_2_forward
LlamaSdpaAttention.forward = llama_sdpa_attention_forward
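Because the patch replaces private forward methods of the transformers Llama attention classes, the required version is pinned to the exact release the patch targets (bumped here from 4.40.2 to 4.41.2). A hedged sketch of the same guard-then-patch pattern:

from transformers.models.llama.modeling_llama import LlamaAttention
from transformers.utils.versions import require_version

def apply_llama_patch(patched_forward) -> None:
    # Fail fast if the installed transformers does not match the patched internals.
    require_version("transformers==4.41.2", "To fix: pip install transformers==4.41.2")
    LlamaAttention.forward = patched_forward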
@@ -1,3 +1,17 @@
# Copyright 2024 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import TYPE_CHECKING, List

from ...extras.logging import get_logger
@@ -1,3 +1,17 @@
# Copyright 2024 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import TYPE_CHECKING

from ...extras.constants import MOD_SUPPORTED_MODELS
@@ -1,5 +1,20 @@
from typing import TYPE_CHECKING
# Copyright 2024 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import TYPE_CHECKING, Sequence

import torch
from transformers.integrations import is_deepspeed_zero3_enabled
from transformers.utils.versions import require_version
@@ -10,6 +25,13 @@ if TYPE_CHECKING:
from ...hparams import ModelArguments


def _set_z3_leaf_modules(model: "PreTrainedModel", leaf_modules: Sequence["torch.nn.Module"]) -> None:
require_version("deepspeed>=0.13.0", "To fix: pip install deepspeed>=0.13.0")
from deepspeed.utils import set_z3_leaf_modules # type: ignore

set_z3_leaf_modules(model, leaf_modules)


def add_z3_leaf_module(model: "PreTrainedModel") -> None:
r"""
Sets module as a leaf module to skip partitioning in deepspeed zero3.
@@ -17,33 +39,30 @@ def add_z3_leaf_module(model: "PreTrainedModel") -> None:
if not is_deepspeed_zero3_enabled():
return

require_version("deepspeed>=0.13.0", "To fix: pip install deepspeed>=0.13.0")
from deepspeed.utils import set_z3_leaf_modules # type: ignore

if getattr(model.config, "model_type", None) == "dbrx":
from transformers.models.dbrx.modeling_dbrx import DbrxFFN

set_z3_leaf_modules(model, [DbrxFFN])
_set_z3_leaf_modules(model, [DbrxFFN])

if getattr(model.config, "model_type", None) == "jamba":
from transformers.models.jamba.modeling_jamba import JambaSparseMoeBlock

set_z3_leaf_modules(model, [JambaSparseMoeBlock])
_set_z3_leaf_modules(model, [JambaSparseMoeBlock])

if getattr(model.config, "model_type", None) == "jetmoe":
from transformers.models.jetmoe.modeling_jetmoe import JetMoeMoA, JetMoeMoE

set_z3_leaf_modules(model, [JetMoeMoA, JetMoeMoE])
_set_z3_leaf_modules(model, [JetMoeMoA, JetMoeMoE])

if getattr(model.config, "model_type", None) == "mixtral":
from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock

set_z3_leaf_modules(model, [MixtralSparseMoeBlock])
_set_z3_leaf_modules(model, [MixtralSparseMoeBlock])

if getattr(model.config, "model_type", None) == "qwen2moe":
from transformers.models.qwen2_moe.modeling_qwen2_moe import Qwen2MoeSparseMoeBlock

set_z3_leaf_modules(model, [Qwen2MoeSparseMoeBlock])
_set_z3_leaf_modules(model, [Qwen2MoeSparseMoeBlock])


def configure_moe(config: "PretrainedConfig", model_args: "ModelArguments", is_trainable: bool) -> None:
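The repeated require_version/import pair is factored into _set_z3_leaf_modules, so each MoE branch only names the block class to keep whole (a leaf module is not partitioned by ZeRO-3, which avoids hangs in expert routing). A self-contained sketch of the helper, mirroring the hunk (deepspeed must be installed for the inner import to succeed):

from typing import Sequence
from transformers.utils.versions import require_version

def _set_z3_leaf_modules(model, leaf_modules: Sequence) -> None:
    # Centralize the dependency check instead of repeating it per model type.
    require_version("deepspeed>=0.13.0", "To fix: pip install deepspeed>=0.13.0")
    from deepspeed.utils import set_z3_leaf_modules  # type: ignore

    # Mark each listed block (e.g. MixtralSparseMoeBlock) as a ZeRO-3 leaf.
    set_z3_leaf_modules(model, leaf_modules)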
@@ -1,3 +1,21 @@
# Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
#
# This code is inspired by the HuggingFace's Transformers and Optimum library.
# https://github.com/huggingface/transformers/blob/v4.41.0/src/transformers/utils/quantization_config.py
# https://github.com/huggingface/optimum/blob/v1.20.0/optimum/gptq/data.py
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import random
from enum import Enum, unique
@@ -5,7 +23,7 @@ from typing import TYPE_CHECKING, Any, Dict, List

import torch
from datasets import load_dataset
from transformers import BitsAndBytesConfig, GPTQConfig
from transformers import BitsAndBytesConfig, EetqConfig, GPTQConfig, HqqConfig
from transformers.integrations import is_deepspeed_zero3_enabled
from transformers.modeling_utils import is_fsdp_enabled
from transformers.utils.versions import require_version
@@ -39,10 +57,9 @@ class QuantizationMethod(str, Enum):
HQQ = "hqq"


def _get_quantization_dataset(tokenizer: "PreTrainedTokenizer", model_args: "ModelArguments") -> List[str]:
def _get_quantization_dataset(tokenizer: "PreTrainedTokenizer", model_args: "ModelArguments") -> List[Dict[str, Any]]:
r"""
Inspired by: https://github.com/huggingface/optimum/blob/v1.16.0/optimum/gptq/data.py#L133
TODO: remove tokenizer.decode() https://github.com/huggingface/optimum/pull/1600
Prepares the tokenized dataset to perform AutoGPTQ. Do not use tensor output for JSON serialization.
"""
if os.path.isfile(model_args.export_quantization_dataset):
data_path = FILEEXT2TYPE.get(model_args.export_quantization_dataset.split(".")[-1], None)
@@ -51,20 +68,32 @@ def _get_quantization_dataset(tokenizer: "PreTrainedTokenizer", model_args: "Mod
data_path = model_args.export_quantization_dataset
data_files = None

dataset = load_dataset(path=data_path, data_files=data_files, split="train", cache_dir=model_args.cache_dir)
maxlen = model_args.export_quantization_maxlen
dataset = load_dataset(
path=data_path,
data_files=data_files,
split="train",
cache_dir=model_args.cache_dir,
token=model_args.hf_hub_token,
)

samples = []
maxlen = model_args.export_quantization_maxlen
for _ in range(model_args.export_quantization_nsamples):
n_try = 0
while True:
if n_try > 100:
raise ValueError("Cannot find satisfying example, considering decrease `export_quantization_maxlen`.")

sample_idx = random.randint(0, len(dataset) - 1)
sample: Dict[str, torch.Tensor] = tokenizer(dataset[sample_idx]["text"], return_tensors="pt")
if sample["input_ids"].size(1) >= maxlen:
sample: Dict[str, "torch.Tensor"] = tokenizer(dataset[sample_idx]["text"], return_tensors="pt")
n_try += 1
if sample["input_ids"].size(1) > maxlen:
break # TODO: fix large maxlen

word_idx = random.randint(0, sample["input_ids"].size(1) - maxlen - 1)
input_ids = sample["input_ids"][:, word_idx : word_idx + maxlen]
samples.append(tokenizer.decode(input_ids[0].tolist(), skip_special_tokens=True))
attention_mask = sample["attention_mask"][:, word_idx : word_idx + maxlen]
samples.append({"input_ids": input_ids.tolist(), "attention_mask": attention_mask.tolist()})

return samples
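_get_quantization_dataset now returns the GPTQ calibration samples as plain Python lists rather than decoded strings, which keeps them JSON-serializable, and bounds the search for a long enough example with a retry counter. A hedged, self-contained sketch of the same sampling idea over an in-memory list of texts (the function and argument names are illustrative; tokenizer is any Hugging Face tokenizer):

import random

def sample_calibration_windows(tokenizer, texts, nsamples, maxlen):
    samples = []
    for _ in range(nsamples):
        for _ in range(100):  # bounded retries, like the n_try guard above
            sample = tokenizer(random.choice(texts), return_tensors="pt")
            if sample["input_ids"].size(1) > maxlen:
                break
        else:
            raise ValueError("Cannot find a long enough example; decrease maxlen.")
        # Take a random fixed-length window and store plain lists, not tensors.
        word_idx = random.randint(0, sample["input_ids"].size(1) - maxlen - 1)
        samples.append({
            "input_ids": sample["input_ids"][:, word_idx : word_idx + maxlen].tolist(),
            "attention_mask": sample["attention_mask"][:, word_idx : word_idx + maxlen].tolist(),
        })
    return samples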
@@ -76,14 +105,14 @@ def configure_quantization(
init_kwargs: Dict[str, Any],
) -> None:
r"""
Priority: PTQ-quantized (training) > AutoGPTQ (export) > Bitsandbytes (training)
Priority: PTQ-quantized (train/infer) > AutoGPTQ (export) > On-the-fly quantization (train/infer)
"""
if getattr(config, "quantization_config", None): # ptq
if is_deepspeed_zero3_enabled():
raise ValueError("DeepSpeed ZeRO-3 is incompatible with quantized models.")
if model_args.quantization_bit is not None:
logger.warning("`quantization_bit` will not affect on the PTQ-quantized models.")

if model_args.quantization_device_map != "auto":
init_kwargs["device_map"] = {"": get_current_device()}
if is_deepspeed_zero3_enabled() or is_fsdp_enabled():
raise ValueError("DeepSpeed ZeRO-3 or FSDP is incompatible with PTQ-quantized models.")

quantization_config: Dict[str, Any] = getattr(config, "quantization_config", None)
quant_method = quantization_config.get("quant_method", "")
@@ -105,46 +134,72 @@ def configure_quantization(
logger.info("Loading {}-bit {}-quantized model.".format(quant_bits, quant_method.upper()))

elif model_args.export_quantization_bit is not None: # auto-gptq
require_version("optimum>=1.16.0", "To fix: pip install optimum>=1.16.0")
if model_args.export_quantization_bit not in [8, 4, 3, 2]:
raise ValueError("AutoGPTQ only accepts 2/3/4/8-bit quantization.")

require_version("optimum>=1.17.0", "To fix: pip install optimum>=1.17.0")
require_version("auto_gptq>=0.5.0", "To fix: pip install auto_gptq>=0.5.0")
from accelerate.utils import get_max_memory

if getattr(config, "model_type", None) == "chatglm":
raise ValueError("ChatGLM model is not supported.")
raise ValueError("ChatGLM model is not supported yet.")

init_kwargs["quantization_config"] = GPTQConfig(
bits=model_args.export_quantization_bit,
tokenizer=tokenizer,
dataset=_get_quantization_dataset(tokenizer, model_args),
)
init_kwargs["device_map"] = "auto"
init_kwargs["max_memory"] = get_max_memory()
logger.info("Quantizing model to {} bit.".format(model_args.export_quantization_bit))
logger.info("Quantizing model to {} bit with AutoGPTQ.".format(model_args.export_quantization_bit))

elif model_args.quantization_bit is not None: # bnb
if model_args.quantization_bit == 8:
require_version("bitsandbytes>=0.37.0", "To fix: pip install bitsandbytes>=0.37.0")
init_kwargs["quantization_config"] = BitsAndBytesConfig(load_in_8bit=True)
elif model_args.quantization_bit is not None: # on-the-fly
if model_args.quantization_method == QuantizationMethod.BITS_AND_BYTES.value:
if model_args.quantization_bit == 8:
require_version("bitsandbytes>=0.37.0", "To fix: pip install bitsandbytes>=0.37.0")
init_kwargs["quantization_config"] = BitsAndBytesConfig(load_in_8bit=True)
elif model_args.quantization_bit == 4:
require_version("bitsandbytes>=0.39.0", "To fix: pip install bitsandbytes>=0.39.0")
init_kwargs["quantization_config"] = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_compute_dtype=model_args.compute_dtype,
bnb_4bit_use_double_quant=model_args.double_quantization,
bnb_4bit_quant_type=model_args.quantization_type,
bnb_4bit_quant_storage=model_args.compute_dtype, # crucial for fsdp+qlora
)
else:
raise ValueError("Bitsandbytes only accepts 4-bit or 8-bit quantization.")

elif model_args.quantization_bit == 4:
require_version("bitsandbytes>=0.39.0", "To fix: pip install bitsandbytes>=0.39.0")
init_kwargs["quantization_config"] = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_compute_dtype=model_args.compute_dtype,
bnb_4bit_use_double_quant=model_args.double_quantization,
bnb_4bit_quant_type=model_args.quantization_type,
bnb_4bit_quant_storage=model_args.compute_dtype, # crucial for fsdp+qlora
)
# Do not assign device map if:
# 1. deepspeed zero3 or fsdp (train)
# 2. auto quantization device map (inference)
if is_deepspeed_zero3_enabled() or is_fsdp_enabled() or model_args.quantization_device_map == "auto":
if model_args.quantization_bit != 4:
raise ValueError("Only 4-bit quantized model can use fsdp+qlora or auto device map.")

if is_deepspeed_zero3_enabled() or is_fsdp_enabled() or model_args.quantization_device_map == "auto":
if model_args.quantization_bit != 4:
raise ValueError("Only 4-bit quantized model can use auto device map.")
require_version("bitsandbytes>=0.43.0", "To fix: pip install bitsandbytes>=0.43.0")
else:
init_kwargs["device_map"] = {"": get_current_device()} # change auto device map for inference

require_version("transformers>=4.39.0", "To fix: pip install transformers>=4.39.0")
require_version("accelerate>=0.28.0", "To fix: pip install accelerate>=0.28.0")
require_version("bitsandbytes>=0.43.0", "To fix: pip install bitsandbytes>=0.43.0")
init_kwargs["torch_dtype"] = model_args.compute_dtype # fsdp+qlora requires same dtype
else:
init_kwargs["device_map"] = {"": get_current_device()}
logger.info("Quantizing model to {} bit with bitsandbytes.".format(model_args.quantization_bit))
elif model_args.quantization_method == QuantizationMethod.HQQ.value:
if model_args.quantization_bit not in [8, 6, 5, 4, 3, 2, 1]:
raise ValueError("HQQ only accepts 1/2/3/4/5/6/8-bit quantization.")

logger.info("Quantizing model to {} bit.".format(model_args.quantization_bit))
if is_deepspeed_zero3_enabled() or is_fsdp_enabled():
raise ValueError("HQQ quantization is incompatible with DeepSpeed ZeRO-3 or FSDP.")

require_version("hqq", "To fix: pip install hqq")
init_kwargs["quantization_config"] = HqqConfig(
nbits=model_args.quantization_bit, quant_zero=False, quant_scale=False, axis=0
) # use ATEN kernel (axis=0) for performance
logger.info("Quantizing model to {} bit with HQQ.".format(model_args.quantization_bit))
elif model_args.quantization_method == QuantizationMethod.EETQ.value:
if model_args.quantization_bit != 8:
raise ValueError("EETQ only accepts 8-bit quantization.")

if is_deepspeed_zero3_enabled() or is_fsdp_enabled():
raise ValueError("EETQ quantization is incompatible with DeepSpeed ZeRO-3 or FSDP.")

require_version("eetq", "To fix: pip install eetq")
init_kwargs["quantization_config"] = EetqConfig()
logger.info("Quantizing model to {} bit with EETQ.".format(model_args.quantization_bit))
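The on-the-fly branch now dispatches on quantization_method, building a BitsAndBytesConfig, HqqConfig, or EetqConfig while leaving PTQ checkpoints and AutoGPTQ export untouched. A reduced sketch of that dispatch (assumes a transformers release that ships HqqConfig and EetqConfig, as imported above; the device-map, version-pin, and FSDP handling from the hunk are omitted):

import torch
from transformers import BitsAndBytesConfig, EetqConfig, HqqConfig

def make_quant_config(method: str, bits: int, compute_dtype: torch.dtype):
    if method == "bitsandbytes":
        if bits == 8:
            return BitsAndBytesConfig(load_in_8bit=True)
        if bits == 4:
            return BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=compute_dtype)
        raise ValueError("Bitsandbytes only accepts 4-bit or 8-bit quantization.")
    if method == "hqq":
        # axis=0 enables the ATEN kernel, as noted in the hunk.
        return HqqConfig(nbits=bits, quant_zero=False, quant_scale=False, axis=0)
    if method == "eetq":
        return EetqConfig()  # EETQ is 8-bit only
    raise ValueError("Unknown quantization method: {}".format(method))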
@@ -1,3 +1,21 @@
# Copyright 2024 LMSYS and the LlamaFactory team.
# Copyright 2023 Rohan Taori, Ishaan Gulrajani, Tianyi Zhang, Yann Dubois, Xuechen Li
#
# This code is inspired by the LMSYS's FastChat library.
# https://github.com/lm-sys/FastChat/blob/v0.2.30/fastchat/train/train.py
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math
from typing import TYPE_CHECKING
@@ -21,8 +39,8 @@ def configure_rope(config: "PretrainedConfig", model_args: "ModelArguments", is_
logger.warning("Current model does not support RoPE scaling.")
return

if is_trainable:
if model_args.rope_scaling == "dynamic":
if model_args.model_max_length is not None:
if is_trainable and model_args.rope_scaling == "dynamic":
logger.warning(
"Dynamic NTK scaling may not work well with fine-tuning. "
"See: https://github.com/huggingface/transformers/pull/24653"
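The hunk above only flattens the nested guards and keeps the warning that dynamic NTK scaling may not help fine-tuning. For context, a rope_scaling entry is ultimately a small dict attached to the model config; a purely illustrative sketch follows (the factor computation here is an assumption for illustration, not taken from this diff):

import math

def set_linear_rope_scaling(config, model_max_length: int) -> None:
    current_max = getattr(config, "max_position_embeddings", None)
    if current_max and model_max_length > current_max:
        # Illustrative: stretch positions so model_max_length fits the pretrained window.
        factor = float(math.ceil(model_max_length / current_max))
        setattr(config, "rope_scaling", {"type": "linear", "factor": factor})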
@@ -1,3 +1,17 @@
# Copyright 2024 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import TYPE_CHECKING, Any, Dict, Optional

from ...extras.logging import get_logger
@@ -1,3 +1,17 @@
# Copyright 2024 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import TYPE_CHECKING, Dict

import torch
@@ -1,3 +1,20 @@
# Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
#
# This code is inspired by the HuggingFace's Transformers library.
# https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/models/llava/modeling_llava.py
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import TYPE_CHECKING, Tuple

import torch
@@ -1,3 +1,17 @@
# Copyright 2024 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
from types import MethodType
from typing import TYPE_CHECKING, Any, Dict
@@ -46,13 +60,16 @@ def patch_config(
is_trainable: bool,
) -> None:
if model_args.compute_dtype is None: # priority: bf16 > fp16 > fp32
model_args.compute_dtype = infer_optim_dtype(model_dtype=getattr(config, "torch_dtype", None))
if model_args.infer_dtype != "auto" and not is_trainable:
model_args.compute_dtype = getattr(torch, model_args.infer_dtype)
else:
model_args.compute_dtype = infer_optim_dtype(model_dtype=getattr(config, "torch_dtype", None))

if is_torch_npu_available():
use_jit_compile = os.environ.get("JIT_COMPILE", "0").lower() in ["true", "1"]
torch.npu.set_compile_mode(jit_compile=use_jit_compile)

configure_attn_implementation(config, model_args)
configure_attn_implementation(config, model_args, is_trainable)
configure_rope(config, model_args, is_trainable)
configure_longlora(config, model_args, is_trainable)
configure_quantization(config, tokenizer, model_args, init_kwargs)
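patch_config now lets an explicit infer_dtype override the computed dtype at inference time, while training still derives it from the checkpoint config. A small sketch of that precedence (the bf16 fallback here stands in for the project's infer_optim_dtype helper and is an assumption):

import torch

def resolve_compute_dtype(config_torch_dtype, infer_dtype: str, is_trainable: bool) -> torch.dtype:
    if infer_dtype != "auto" and not is_trainable:
        return getattr(torch, infer_dtype)  # e.g. "float16" or "bfloat16"
    if config_torch_dtype is not None:
        return config_torch_dtype
    return torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16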
@@ -74,14 +91,17 @@ def patch_config(
# deepspeed zero3 is not compatible with low_cpu_mem_usage
init_kwargs["low_cpu_mem_usage"] = model_args.low_cpu_mem_usage and (not is_deepspeed_zero3_enabled())

if not is_deepspeed_zero3_enabled() and not is_fsdp_enabled(): # cast dtype and device if not use zero3 or fsdp
# cast data type of the model if:
# 1. not deepspeed zero3 and not fsdp (keep zero3 or fsdp in float32)
# 2. quantization_bit is not None (qlora)
if (not is_deepspeed_zero3_enabled() and not is_fsdp_enabled()) or model_args.quantization_bit is not None:
init_kwargs["torch_dtype"] = model_args.compute_dtype

if init_kwargs["low_cpu_mem_usage"]: # device map requires low_cpu_mem_usage=True
if "device_map" not in init_kwargs and model_args.device_map:
init_kwargs["device_map"] = model_args.device_map

if init_kwargs["device_map"] == "auto":
if init_kwargs.get("device_map", None) == "auto":
init_kwargs["offload_folder"] = model_args.offload_folder

if finetune_args.stage == "sft" and data_args.efficient_packing:
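Two small robustness fixes sit in this hunk: torch_dtype is now also set for QLoRA runs under ZeRO-3/FSDP, and the device-map check uses .get so a missing key no longer raises. Illustration of the latter (values are placeholders):

init_kwargs = {"low_cpu_mem_usage": True}  # no "device_map" key was set
# init_kwargs["device_map"] == "auto" would raise KeyError here;
# the .get form simply evaluates to False and skips the offload folder.
if init_kwargs.get("device_map", None) == "auto":
    init_kwargs["offload_folder"] = "offload"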
@@ -137,6 +157,10 @@ def patch_valuehead_model(model: "AutoModelForCausalLMWithValueHead") -> None:
if isinstance(self.pretrained_model, PreTrainedModel):
return self.pretrained_model.get_input_embeddings()

def get_output_embeddings(self: "AutoModelForCausalLMWithValueHead") -> torch.nn.Module:
if isinstance(self.pretrained_model, PreTrainedModel):
return self.pretrained_model.get_output_embeddings()

def create_or_update_model_card(self: "AutoModelForCausalLMWithValueHead", output_dir: str) -> None:
if isinstance(self.pretrained_model, PeftModel):
self.pretrained_model.create_or_update_model_card(output_dir)
@@ -145,4 +169,5 @@ def patch_valuehead_model(model: "AutoModelForCausalLMWithValueHead") -> None:
setattr(model, "_keys_to_ignore_on_save", ignore_modules)
setattr(model, "tie_weights", MethodType(tie_weights, model))
setattr(model, "get_input_embeddings", MethodType(get_input_embeddings, model))
setattr(model, "get_output_embeddings", MethodType(get_output_embeddings, model))
setattr(model, "create_or_update_model_card", MethodType(create_or_update_model_card, model))
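The value-head wrapper gains a get_output_embeddings delegate, bound with MethodType alongside the existing get_input_embeddings patch, so utilities such as the lm_head upcasting hook can treat it like a plain PreTrainedModel. A minimal, generic sketch of binding a delegating method onto an instance (the Wrapper class is illustrative, not the project's type):

from types import MethodType

class Wrapper:
    def __init__(self, inner):
        self.pretrained_model = inner

def get_output_embeddings(self):
    # Delegate to the wrapped model, mirroring the patch above.
    return self.pretrained_model.get_output_embeddings()

# wrapper = Wrapper(model)
# setattr(wrapper, "get_output_embeddings", MethodType(get_output_embeddings, wrapper))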