format style

Former-commit-id: 53b683531b83cd1d19de97c6565f16c1eca6f5e1
hiyouga committed on 2024-01-20 20:15:56 +08:00
parent 1750218057
commit 66e0e651b9
73 changed files with 1492 additions and 2325 deletions

View File

@@ -1,25 +1,25 @@
-import torch
import inspect
from typing import TYPE_CHECKING
+import torch
+from peft import LoraConfig, PeftModel, TaskType, get_peft_model
from transformers.integrations import is_deepspeed_zero3_enabled
-from peft import PeftModel, TaskType, LoraConfig, get_peft_model
from ..extras.logging import get_logger
from .utils import find_all_linear_modules
if TYPE_CHECKING:
from transformers.modeling_utils import PreTrainedModel
-from ..hparams import ModelArguments, FinetuningArguments
+from ..hparams import FinetuningArguments, ModelArguments
logger = get_logger(__name__)
def init_adapter(
model: "PreTrainedModel",
model_args: "ModelArguments",
finetuning_args: "FinetuningArguments",
is_trainable: bool
model: "PreTrainedModel", model_args: "ModelArguments", finetuning_args: "FinetuningArguments", is_trainable: bool
) -> "PreTrainedModel":
r"""
Initializes the adapters.
@@ -47,10 +47,10 @@ def init_adapter(
if not num_layers:
raise ValueError("Current model does not support freeze tuning.")
-if finetuning_args.num_layer_trainable > 0: # fine-tuning the last n layers if num_layer_trainable > 0
+if finetuning_args.num_layer_trainable > 0:  # fine-tuning the last n layers if num_layer_trainable > 0
trainable_layer_ids = [num_layers - k - 1 for k in range(finetuning_args.num_layer_trainable)]
-else: # fine-tuning the first n layers if num_layer_trainable < 0
-trainable_layer_ids = [k for k in range(-finetuning_args.num_layer_trainable)]
+else:  # fine-tuning the first n layers if num_layer_trainable < 0
+trainable_layer_ids = [k for k in range(-finetuning_args.num_layer_trainable)]  # noqa: C416
trainable_layers = []
for module_name in finetuning_args.name_module_trainable:
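Note: a minimal worked example of the layer-selection rule in the hunk above (the depth and num_layer_trainable values are illustrative, not from the commit):

# Illustrative sketch of freeze-tuning layer selection.
num_layers = 32            # hypothetical model depth
num_layer_trainable = 3    # positive: fine-tune the last n layers
print([num_layers - k - 1 for k in range(num_layer_trainable)])    # [31, 30, 29]
num_layer_trainable = -3   # negative: fine-tune the first n layers
print([k for k in range(-num_layer_trainable)])                    # [0, 1, 2]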
@@ -69,7 +69,7 @@ def init_adapter(
if model_args.adapter_name_or_path is not None:
is_mergeable = True
if getattr(model, "quantization_method", None): # merge lora in quantized model is unstable
if getattr(model, "quantization_method", None): # merge lora in quantized model is unstable
assert len(model_args.adapter_name_or_path) == 1, "Quantized model only accepts a single adapter."
is_mergeable = False
@@ -90,10 +90,10 @@ def init_adapter(
if len(adapter_to_merge) > 0:
logger.info("Merged {} adapter(s).".format(len(adapter_to_merge)))
-if adapter_to_resume is not None: # resume lora training
+if adapter_to_resume is not None:  # resume lora training
model = PeftModel.from_pretrained(model, adapter_to_resume, is_trainable=is_trainable)
-if is_trainable and adapter_to_resume is None: # create new lora weights while training
+if is_trainable and adapter_to_resume is None:  # create new lora weights while training
if len(finetuning_args.lora_target) == 1 and finetuning_args.lora_target[0] == "all":
target_modules = find_all_linear_modules(model)
else:
@@ -103,11 +103,12 @@ def init_adapter(
"r": finetuning_args.lora_rank,
"target_modules": target_modules,
"lora_alpha": finetuning_args.lora_alpha,
"lora_dropout": finetuning_args.lora_dropout
"lora_dropout": finetuning_args.lora_dropout,
}
if model_args.use_unsloth:
-from unsloth import FastLlamaModel, FastMistralModel # type: ignore
+from unsloth import FastLlamaModel, FastMistralModel  # type: ignore
unsloth_peft_kwargs = {"model": model, "max_seq_length": model_args.model_max_length}
if "loftq_config" in inspect.signature(FastLlamaModel.get_peft_model).parameters:
unsloth_peft_kwargs["loftq_config"] = {}
@@ -124,7 +125,7 @@ def init_adapter(
task_type=TaskType.CAUSAL_LM,
inference_mode=False,
modules_to_save=finetuning_args.additional_target,
-**peft_kwargs
+**peft_kwargs,
)
model = get_peft_model(model, lora_config)
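Note: the peft_kwargs collected above feed a standard peft LoraConfig; a minimal sketch of that flow with placeholder values (the real rank, alpha, dropout and target modules come from finetuning_args):

from peft import LoraConfig, TaskType, get_peft_model

peft_kwargs = {
    "r": 8,                                  # placeholder for finetuning_args.lora_rank
    "target_modules": ["q_proj", "v_proj"],  # placeholder; or find_all_linear_modules(model)
    "lora_alpha": 16,                        # placeholder for finetuning_args.lora_alpha
    "lora_dropout": 0.05,                    # placeholder for finetuning_args.lora_dropout
}
lora_config = LoraConfig(
    task_type=TaskType.CAUSAL_LM,
    inference_mode=False,
    modules_to_save=None,                    # placeholder for finetuning_args.additional_target
    **peft_kwargs,
)
# model = get_peft_model(model, lora_config)  # `model` is the loaded base PreTrainedModel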

View File

@@ -1,4 +1,5 @@
from typing import TYPE_CHECKING, Optional, Tuple
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
from transformers.integrations import is_deepspeed_zero3_enabled
from transformers.utils.versions import require_version
@@ -7,12 +8,14 @@ from trl import AutoModelForCausalLMWithValueHead
from ..extras.logging import get_logger
from ..extras.misc import count_parameters, get_current_device, try_download_model_from_ms
from .adapter import init_adapter
-from .patcher import patch_config, patch_tokenizer, patch_model, patch_valuehead_model
+from .patcher import patch_config, patch_model, patch_tokenizer, patch_valuehead_model
from .utils import load_valuehead_params, register_autoclass
if TYPE_CHECKING:
from transformers import PreTrainedModel, PreTrainedTokenizer
-from ..hparams import ModelArguments, FinetuningArguments
+from ..hparams import FinetuningArguments, ModelArguments
logger = get_logger(__name__)
@@ -29,7 +32,7 @@ def load_model_and_tokenizer(
model_args: "ModelArguments",
finetuning_args: "FinetuningArguments",
is_trainable: Optional[bool] = False,
-add_valuehead: Optional[bool] = False
+add_valuehead: Optional[bool] = False,
) -> Tuple["PreTrainedModel", "PreTrainedTokenizer"]:
r"""
Loads pretrained model and tokenizer.
@@ -43,7 +46,7 @@ def load_model_and_tokenizer(
"trust_remote_code": True,
"cache_dir": model_args.cache_dir,
"revision": model_args.model_revision,
"token": model_args.hf_hub_token
"token": model_args.hf_hub_token,
}
tokenizer = AutoTokenizer.from_pretrained(
@@ -51,7 +54,7 @@ def load_model_and_tokenizer(
use_fast=model_args.use_fast_tokenizer,
split_special_tokens=model_args.split_special_tokens,
padding_side="right",
-**config_kwargs
+**config_kwargs,
)
patch_tokenizer(tokenizer)
@@ -61,7 +64,8 @@ def load_model_and_tokenizer(
model = None
if is_trainable and model_args.use_unsloth:
require_version("unsloth", "Follow the instructions at: https://github.com/unslothai/unsloth")
-from unsloth import FastLlamaModel, FastMistralModel # type: ignore
+from unsloth import FastLlamaModel, FastMistralModel  # type: ignore
unsloth_kwargs = {
"model_name": model_args.model_name_or_path,
"max_seq_length": model_args.model_max_length,
@@ -69,7 +73,7 @@ def load_model_and_tokenizer(
"load_in_4bit": model_args.quantization_bit == 4,
"token": model_args.hf_hub_token,
"device_map": get_current_device(),
"rope_scaling": getattr(config, "rope_scaling", None)
"rope_scaling": getattr(config, "rope_scaling", None),
}
if getattr(config, "model_type", None) == "llama":
model, _ = FastLlamaModel.from_pretrained(**unsloth_kwargs)
@@ -89,7 +93,7 @@ def load_model_and_tokenizer(
config=config,
torch_dtype=model_args.compute_dtype,
low_cpu_mem_usage=(not is_deepspeed_zero3_enabled()),
-**config_kwargs
+**config_kwargs,
)
patch_model(model, tokenizer, model_args, is_trainable)
@@ -119,9 +123,11 @@ def load_model_and_tokenizer(
model.train()
trainable_params, all_param = count_parameters(model)
logger.info("trainable params: {:d} || all params: {:d} || trainable%: {:.4f}".format(
trainable_params, all_param, 100 * trainable_params / all_param
))
logger.info(
"trainable params: {:d} || all params: {:d} || trainable%: {:.4f}".format(
trainable_params, all_param, 100 * trainable_params / all_param
)
)
if not is_trainable:
logger.info("This IS expected that the trainable params is 0 if you are using model for inference only.")

View File

@@ -1,12 +1,12 @@
-import os
import math
-import torch
+import os
import random
+from contextlib import nullcontext
from types import MethodType
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
-from datasets import load_dataset
-from contextlib import nullcontext
+import torch
+from datasets import load_dataset
from transformers import BitsAndBytesConfig, GPTQConfig, PreTrainedModel, PreTrainedTokenizerBase
from transformers.integrations import is_deepspeed_zero3_enabled
from transformers.utils.versions import require_version
@@ -17,9 +17,11 @@ from ..extras.misc import get_current_device, infer_optim_dtype
from ..extras.packages import is_flash_attn2_available
from ..extras.patches.llama_patch import apply_llama_patch
if TYPE_CHECKING:
from transformers import PretrainedConfig, PreTrainedTokenizer
from trl import AutoModelForCausalLMWithValueHead
from ..hparams import ModelArguments
@@ -40,7 +42,8 @@ def _resize_embedding_layer(model: "PreTrainedModel", tokenizer: "PreTrainedToke
Resize token embeddings.
"""
if is_deepspeed_zero3_enabled():
-import deepspeed # type: ignore
+import deepspeed  # type: ignore
params = [model.get_input_embeddings().weight]
if model.get_output_embeddings() is not None and not model.config.tie_word_embeddings:
params.append(model.get_output_embeddings().weight)
@@ -88,7 +91,7 @@ def _get_quantization_dataset(tokenizer: "PreTrainedTokenizer", model_args: "Mod
sample_idx = random.randint(0, len(dataset) - 1)
sample: Dict[str, torch.Tensor] = tokenizer(dataset[sample_idx]["text"], return_tensors="pt")
if sample["input_ids"].size(1) >= maxlen:
-break # TODO: fix large maxlen
+break  # TODO: fix large maxlen
word_idx = random.randint(0, sample["input_ids"].size(1) - maxlen - 1)
input_ids = sample["input_ids"][:, word_idx : word_idx + maxlen]
@@ -119,9 +122,9 @@ def _configure_rope(config: "PretrainedConfig", model_args: "ModelArguments", is
scaling_factor = 2.0
setattr(config, "rope_scaling", {"type": model_args.rope_scaling, "factor": scaling_factor})
logger.info("Using {} scaling strategy and setting scaling factor to {}".format(
model_args.rope_scaling, scaling_factor
))
logger.info(
"Using {} scaling strategy and setting scaling factor to {}".format(model_args.rope_scaling, scaling_factor)
)
def _configure_flashattn(config_kwargs: Dict[str, Any]) -> None:
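Note: the rope_scaling attribute set in the hunk above follows the transformers convention of a dict with a "type" ("linear" or "dynamic") and a "factor"; a minimal offline sketch using a default LlamaConfig as a stand-in for the loaded config:

from transformers import LlamaConfig

config = LlamaConfig()  # stand-in for the config loaded from model_args.model_name_or_path
setattr(config, "rope_scaling", {"type": "linear", "factor": 2.0})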
@@ -146,22 +149,22 @@ def _configure_quantization(
config: "PretrainedConfig",
tokenizer: "PreTrainedTokenizer",
model_args: "ModelArguments",
-config_kwargs: Dict[str, Any]
+config_kwargs: Dict[str, Any],
) -> None:
r"""
Priority: GPTQ-quantized (training) > AutoGPTQ (export) > Bitsandbytes (training)
"""
if getattr(config, "quantization_config", None): # gptq
if getattr(config, "quantization_config", None): # gptq
if is_deepspeed_zero3_enabled():
raise ValueError("DeepSpeed ZeRO-3 is incompatible with quantization.")
config_kwargs["device_map"] = {"": get_current_device()}
quantization_config: Dict[str, Any] = getattr(config, "quantization_config", None)
if quantization_config.get("quant_method", None) == "gptq" and quantization_config.get("bits", -1) == 4:
quantization_config["use_exllama"] = False # disable exllama
quantization_config["use_exllama"] = False # disable exllama
logger.info("Loading {}-bit GPTQ-quantized model.".format(quantization_config.get("bits", -1)))
-elif model_args.export_quantization_bit is not None: # auto-gptq
+elif model_args.export_quantization_bit is not None:  # auto-gptq
require_version("optimum>=1.16.0", "To fix: pip install optimum>=1.16.0")
require_version("auto_gptq>=0.5.0", "To fix: pip install auto_gptq>=0.5.0")
from accelerate.utils import get_max_memory
@@ -172,13 +175,13 @@ def _configure_quantization(
config_kwargs["quantization_config"] = GPTQConfig(
bits=model_args.export_quantization_bit,
tokenizer=tokenizer,
-dataset=_get_quantization_dataset(tokenizer, model_args)
+dataset=_get_quantization_dataset(tokenizer, model_args),
)
config_kwargs["device_map"] = "auto"
config_kwargs["max_memory"] = get_max_memory()
logger.info("Quantizing model to {} bit.".format(model_args.export_quantization_bit))
-elif model_args.quantization_bit is not None: # bnb
+elif model_args.quantization_bit is not None:  # bnb
if is_deepspeed_zero3_enabled():
raise ValueError("DeepSpeed ZeRO-3 is incompatible with quantization.")
@@ -192,7 +195,7 @@ def _configure_quantization(
load_in_4bit=True,
bnb_4bit_compute_dtype=model_args.compute_dtype,
bnb_4bit_use_double_quant=model_args.double_quantization,
-bnb_4bit_quant_type=model_args.quantization_type
+bnb_4bit_quant_type=model_args.quantization_type,
)
config_kwargs["device_map"] = {"": get_current_device()}
@@ -200,9 +203,7 @@ def _configure_quantization(
def _prepare_model_for_training(
model: "PreTrainedModel",
model_args: "ModelArguments",
output_layer_name: Optional[str] = "lm_head"
model: "PreTrainedModel", model_args: "ModelArguments", output_layer_name: Optional[str] = "lm_head"
) -> None:
r"""
Includes:
@@ -222,10 +223,11 @@ def _prepare_model_for_training(
logger.warning("Current model does not support gradient checkpointing.")
else:
model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False})
-model.config.use_cache = False # turn off when gradient checkpointing is enabled
+model.config.use_cache = False  # turn off when gradient checkpointing is enabled
logger.info("Gradient checkpointing enabled.")
if hasattr(model, output_layer_name) and model_args.upcast_lmhead_output:
def fp32_forward_post_hook(module: torch.nn.Module, args: Tuple[torch.Tensor], output: torch.Tensor):
return output.to(torch.float32)
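Note: the fp32 hook defined above is presumably attached to the output layer via PyTorch's standard hook API (the registration itself is not shown in this hunk); a minimal self-contained sketch:

import torch

def fp32_forward_post_hook(module: torch.nn.Module, args, output: torch.Tensor):
    # Upcast the output-layer activations to float32 for numerical stability.
    return output.to(torch.float32)

output_layer = torch.nn.Linear(16, 32)  # stand-in for the model's lm_head
output_layer.register_forward_hook(fp32_forward_post_hook)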
@@ -244,9 +246,9 @@ def patch_config(
tokenizer: "PreTrainedTokenizer",
model_args: "ModelArguments",
config_kwargs: Dict[str, Any],
-is_trainable: bool
+is_trainable: bool,
) -> None:
-if model_args.compute_dtype is None: # priority: bf16 > fp16 > fp32
+if model_args.compute_dtype is None:  # priority: bf16 > fp16 > fp32
model_args.compute_dtype = infer_optim_dtype(model_dtype=getattr(config, "torch_dtype", None))
if getattr(config, "model_type", None) == "qwen":
@@ -266,10 +268,7 @@ def patch_config(
def patch_model(
model: "PreTrainedModel",
tokenizer: "PreTrainedTokenizer",
model_args: "ModelArguments",
is_trainable: bool
model: "PreTrainedModel", tokenizer: "PreTrainedTokenizer", model_args: "ModelArguments", is_trainable: bool
) -> None:
if "GenerationMixin" not in str(model.generate.__func__):
model.generate = MethodType(PreTrainedModel.generate, model)

View File

@@ -1,16 +1,19 @@
-import torch
import inspect
from typing import TYPE_CHECKING, Any, Dict, List
+import torch
from transformers import PreTrainedModel
from transformers.utils import cached_file
-from ..extras.constants import V_HEAD_WEIGHTS_NAME, V_HEAD_SAFE_WEIGHTS_NAME
+from ..extras.constants import V_HEAD_SAFE_WEIGHTS_NAME, V_HEAD_WEIGHTS_NAME
from ..extras.logging import get_logger
from ..extras.misc import get_current_device
if TYPE_CHECKING:
from transformers import PretrainedConfig, PreTrainedTokenizer
-from ..hparams import ModelArguments, DataArguments, FinetuningArguments
+from ..hparams import DataArguments, FinetuningArguments, ModelArguments
logger = get_logger(__name__)
@@ -21,7 +24,7 @@ def dispatch_model(model: "PreTrainedModel") -> "PreTrainedModel":
Dispatches a pre-trained model to GPUs with balanced memory when the GPU is available.
Borrowed from: https://github.com/huggingface/transformers/blob/v4.36.2/src/transformers/modeling_utils.py#L3570
"""
if getattr(model, "quantization_method", None): # already set on current device
if getattr(model, "quantization_method", None): # already set on current device
return model
if (
@@ -31,7 +34,7 @@ def dispatch_model(model: "PreTrainedModel") -> "PreTrainedModel":
and model.config.model_type != "chatglm"
):
from accelerate import dispatch_model
-from accelerate.utils import infer_auto_device_map, get_balanced_memory
+from accelerate.utils import get_balanced_memory, infer_auto_device_map
kwargs = {"dtype": model.dtype, "no_split_module_classes": model._get_no_split_modules("auto")}
max_memory = get_balanced_memory(model, **kwargs)
@@ -55,6 +58,7 @@ def find_all_linear_modules(model: "PreTrainedModel") -> List[str]:
linear_cls = torch.nn.Linear
elif quantization_method == "bitsandbytes":
import bitsandbytes as bnb
linear_cls = bnb.nn.Linear4bit if getattr(model, "is_loaded_in_4bit", False) else bnb.nn.Linear8bitLt
else:
raise ValueError("Finding linear modules for {} models is not supported.".format(quantization_method))
@@ -65,10 +69,7 @@ def find_all_linear_modules(model: "PreTrainedModel") -> List[str]:
module_names = set()
for name, module in model.named_modules():
-if (
-isinstance(module, linear_cls)
-and not any([output_layer in name for output_layer in output_layer_names])
-):
+if isinstance(module, linear_cls) and not any(output_layer in name for output_layer in output_layer_names):
module_names.add(name.split(".")[-1])
logger.info("Found linear modules: {}".format(",".join(module_names)))
@@ -76,16 +77,14 @@ def find_all_linear_modules(model: "PreTrainedModel") -> List[str]:
def get_modelcard_args(
model_args: "ModelArguments",
data_args: "DataArguments",
finetuning_args: "FinetuningArguments"
model_args: "ModelArguments", data_args: "DataArguments", finetuning_args: "FinetuningArguments"
) -> Dict[str, Any]:
return {
"tasks": "text-generation",
"license": "other",
"finetuned_from": model_args.model_name_or_path,
"dataset": [dataset.strip() for dataset in data_args.dataset.split(",")],
"tags": ["llama-factory"] + (["lora"] if finetuning_args.finetuning_type == "lora" else [])
"tags": ["llama-factory"] + (["lora"] if finetuning_args.finetuning_type == "lora" else []),
}
@@ -95,14 +94,11 @@ def load_valuehead_params(path_or_repo_id: str, model_args: "ModelArguments") ->
Returns: dict with keys `v_head.summary.weight` and `v_head.summary.bias`.
"""
-kwargs = {
-"path_or_repo_id": path_or_repo_id,
-"cache_dir": model_args.cache_dir,
-"token": model_args.hf_hub_token
-}
+kwargs = {"path_or_repo_id": path_or_repo_id, "cache_dir": model_args.cache_dir, "token": model_args.hf_hub_token}
try:
from safetensors import safe_open
vhead_file = cached_file(filename=V_HEAD_SAFE_WEIGHTS_NAME, **kwargs)
with safe_open(vhead_file, framework="pt", device="cpu") as f:
return {key: f.get_tensor(key) for key in f.keys()}
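Note: downstream, the returned dict (keys v_head.summary.weight and v_head.summary.bias) is presumably applied to a value-head model with a non-strict load; a minimal sketch with placeholder tensors (the real shapes depend on the model's hidden size):

import torch

vhead_params = {
    "v_head.summary.weight": torch.zeros(1, 4096),  # placeholder hidden size
    "v_head.summary.bias": torch.zeros(1),
}
# With `model` an AutoModelForCausalLMWithValueHead (trl), the weights would be loaded as:
# model.load_state_dict(vhead_params, strict=False)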