fix mod stuff

Former-commit-id: cf3988226e6398c67bb2955578e436fc505aa5c5
This commit is contained in:
hiyouga
2024-04-21 18:11:10 +08:00
parent 3365cc8cf0
commit f8e219dc81
16 changed files with 63 additions and 88 deletions

View File

@@ -343,7 +343,7 @@ def get_template_and_fix_tokenizer(
name: Optional[str] = None,
) -> Template:
if name is None:
template = templates["vanilla"] # placeholder
template = templates["empty"] # placeholder
else:
template = templates.get(name, None)
if template is None:
@@ -385,7 +385,8 @@ _register_template(
format_user=StringFormatter(slots=["### Instruction:\n{{content}}\n\n### Response:\n"]),
format_separator=EmptyFormatter(slots=["\n\n"]),
default_system=(
"Below is an instruction that describes a task. " "Write a response that appropriately completes the request."
"Below is an instruction that describes a task. "
"Write a response that appropriately completes the request.\n\n"
),
)
@@ -596,6 +597,13 @@ _register_template(
)
_register_template(
name="fewshot",
format_separator=EmptyFormatter(slots=["\n\n"]),
efficient_eos=True,
)
_register_template(
name="gemma",
format_user=StringFormatter(slots=["<start_of_turn>user\n{{content}}<end_of_turn>\n<start_of_turn>model\n"]),
@@ -740,13 +748,6 @@ _register_template(
)
_register_template(
name="vanilla",
format_separator=EmptyFormatter(slots=["\n"]),
efficient_eos=True,
)
_register_template(
name="vicuna",
format_user=StringFormatter(slots=["USER: {{content}} ASSISTANT:"]),

View File

@@ -28,6 +28,8 @@ LOG_FILE_NAME = "trainer_log.jsonl"
METHODS = ["full", "freeze", "lora"]
MOD_SUPPORTED_MODELS = ["bloom", "falcon", "gemma", "llama", "mistral", "mixtral", "phi", "starcoder2"]
PEFT_METHODS = ["lora"]
SUBJECTS = ["Average", "STEM", "Social Sciences", "Humanities", "Other"]

View File

@@ -83,6 +83,8 @@ def count_parameters(model: torch.nn.Module) -> Tuple[int, int]:
if param.__class__.__name__ == "Params4bit":
if hasattr(param, "quant_storage") and hasattr(param.quant_storage, "itemsize"):
num_bytes = param.quant_storage.itemsize
elif hasattr(param, "element_size"): # for older pytorch version
num_bytes = param.element_size()
else:
num_bytes = 1

View File

@@ -63,15 +63,15 @@ class ModelArguments:
)
flash_attn: bool = field(
default=False,
metadata={"help": "Enable FlashAttention-2 for faster training."},
metadata={"help": "Enable FlashAttention for faster training."},
)
shift_attn: bool = field(
default=False,
metadata={"help": "Enable shift short attention (S^2-Attn) proposed by LongLoRA."},
)
mixture_of_depths: Optional[Literal["convert", "continue"]] = field(
mixture_of_depths: Optional[Literal["convert", "load"]] = field(
default=None,
metadata={"help": "Whether or not to use MoD in the model."},
metadata={"help": "Convert the model to mixture-of-depths (MoD) or load the MoD model."},
)
use_unsloth: bool = field(
default=False,

View File

@@ -82,8 +82,8 @@ def _check_extra_dependencies(
if model_args.use_unsloth:
require_version("unsloth", "Please install unsloth: https://github.com/unslothai/unsloth")
if model_args.mixture_of_depths == 'convert' or model_args.mixture_of_depths == 'continue':
require_version("mixture-of-depth", "To fix: pip install mixture-of-depth")
if model_args.mixture_of_depths is not None:
require_version("mixture-of-depth>=1.1.6", "To fix: pip install mixture-of-depth>=1.1.6")
if model_args.infer_backend == "vllm":
require_version("vllm>=0.3.3", "To fix: pip install vllm>=0.3.3")

View File

@@ -69,7 +69,7 @@ def init_adapter(
for name, _ in model.named_modules():
if ".0." in name:
freeze_modules.add(name.split(".0.")[-1].split(".")[0])
elif ".1." in name: # here since MoD starts from layer 1
elif ".1." in name: # MoD starts from layer 1
freeze_modules.add(name.split(".1.")[-1].split(".")[0])
trainable_layers = []

View File

@@ -3,6 +3,7 @@ from typing import TYPE_CHECKING, Any, Dict
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
from trl import AutoModelForCausalLMWithValueHead
from ..extras.constants import MOD_SUPPORTED_MODELS
from ..extras.logging import get_logger
from ..extras.misc import count_parameters, get_current_device, try_download_model_from_ms
from .adapter import init_adapter
@@ -44,7 +45,7 @@ def load_tokenizer(model_args: "ModelArguments") -> "PreTrainedTokenizer":
padding_side="right",
**init_kwargs,
)
except Exception: # try the fast one
except ValueError: # try the fast one
tokenizer = AutoTokenizer.from_pretrained(
model_args.model_name_or_path,
use_fast=True,
@@ -71,12 +72,6 @@ def load_model(
patch_config(config, tokenizer, model_args, init_kwargs, is_trainable)
model = None
if model_args.mixture_of_depths == 'continue':
from MoD import AutoMoDModelForCausalLM
model = AutoMoDModelForCausalLM.from_pretrained(model_args.model_name_or_path, config=config)
if model.config.model_type == 'qwen2':
RuntimeError("Qwen models are not supported for MoD training.")
if is_trainable and model_args.use_unsloth:
from unsloth import FastLanguageModel # type: ignore
@@ -104,14 +99,22 @@ def load_model(
if model is None:
init_kwargs["config"] = config
init_kwargs["pretrained_model_name_or_path"] = model_args.model_name_or_path
model: "PreTrainedModel" = AutoModelForCausalLM.from_pretrained(**init_kwargs)
if model_args.mixture_of_depths == 'convert':
from MoD import convert_hf_model
if model.config.model_type == 'qwen2':
RuntimeError("Qwen models are not supported for MoD training.")
model = convert_hf_model(model)
if model_args.mixture_of_depths == "load":
from MoD import AutoMoDModelForCausalLM
model = AutoMoDModelForCausalLM.from_pretrained(**init_kwargs)
else:
model = AutoModelForCausalLM.from_pretrained(**init_kwargs)
if model_args.mixture_of_depths == "convert":
from MoD import apply_mod_to_hf
if getattr(config, "model_type", None) not in MOD_SUPPORTED_MODELS:
raise ValueError("Current model is not supported by mixture-of-depth.")
model = apply_mod_to_hf(model)
model = model.to(model_args.compute_dtype)
patch_model(model, tokenizer, model_args, is_trainable)
register_autoclass(config, model, tokenizer)
@@ -119,7 +122,7 @@ def load_model(
model = init_adapter(model, model_args, finetuning_args, is_trainable)
if add_valuehead:
model: "AutoModelForCausalLMWithValueHead" = AutoModelForCausalLMWithValueHead.from_pretrained(model)
model = AutoModelForCausalLMWithValueHead.from_pretrained(model)
patch_valuehead_model(model)
if model_args.adapter_name_or_path is not None:

View File

@@ -61,9 +61,7 @@ def _get_quantization_dataset(tokenizer: "PreTrainedTokenizer", model_args: "Mod
return samples
def _configure_attn_implementation(
config: "PretrainedConfig", model_args: "ModelArguments", init_kwargs: Dict[str, Any]
) -> None:
def _configure_attn_implementation(config: "PretrainedConfig", model_args: "ModelArguments") -> None:
if model_args.flash_attn:
if not is_flash_attn2_available():
logger.warning("FlashAttention2 is not installed.")
@@ -73,9 +71,9 @@ def _configure_attn_implementation(
if getattr(config, "model_type", None) == "internlm2": # special case for custom models
setattr(config, "attn_implementation", "flash_attention_2")
else:
init_kwargs["attn_implementation"] = "flash_attention_2"
setattr(config, "_attn_implementation", "flash_attention_2")
else:
init_kwargs["attn_implementation"] = "eager"
setattr(config, "_attn_implementation", "eager")
def _configure_rope(config: "PretrainedConfig", model_args: "ModelArguments", is_trainable: bool) -> None:
@@ -295,7 +293,7 @@ def patch_config(
if model_args.compute_dtype is None: # priority: bf16 > fp16 > fp32
model_args.compute_dtype = infer_optim_dtype(model_dtype=getattr(config, "torch_dtype", None))
_configure_attn_implementation(config, model_args, init_kwargs)
_configure_attn_implementation(config, model_args)
_configure_rope(config, model_args, is_trainable)
_configure_longlora(config, model_args, is_trainable)
_configure_quantization(config, tokenizer, model_args, init_kwargs)