Merge pull request #3338 from astramind-ai/main
Adding Mixture of Depth Former-commit-id: 4da2ece53353b63e672ff529d6beba41ff710c14
This commit is contained in:
@@ -69,6 +69,10 @@ class ModelArguments:
|
||||
default=False,
|
||||
metadata={"help": "Enable shift short attention (S^2-Attn) proposed by LongLoRA."},
|
||||
)
|
||||
mixture_of_depths: Optional[Literal["convert", "continue"]] = field(
|
||||
default=None,
|
||||
metadata={"help": "Whether or not to use MoD in the model."},
|
||||
)
|
||||
use_unsloth: bool = field(
|
||||
default=False,
|
||||
metadata={"help": "Whether or not to use unsloth's optimization for the LoRA training."},
|
||||
|
||||
@@ -82,6 +82,9 @@ def _check_extra_dependencies(
|
||||
if model_args.use_unsloth:
|
||||
require_version("unsloth", "Please install unsloth: https://github.com/unslothai/unsloth")
|
||||
|
||||
if model_args.mixture_of_depths == 'convert' or model_args.mixture_of_depths == 'continue':
|
||||
require_version("mixture-of-depth", "To fix: pip install mixture-of-depth")
|
||||
|
||||
if model_args.infer_backend == "vllm":
|
||||
require_version("vllm>=0.3.3", "To fix: pip install vllm>=0.3.3")
|
||||
|
||||
|
||||
@@ -69,6 +69,8 @@ def init_adapter(
|
||||
for name, _ in model.named_modules():
|
||||
if ".0." in name:
|
||||
freeze_modules.add(name.split(".0.")[-1].split(".")[0])
|
||||
elif ".1." in name: # here since MoD starts from layer 1
|
||||
freeze_modules.add(name.split(".1.")[-1].split(".")[0])
|
||||
|
||||
trainable_layers = []
|
||||
for module_name in finetuning_args.name_module_trainable:
|
||||
|
||||
@@ -71,6 +71,12 @@ def load_model(
|
||||
patch_config(config, tokenizer, model_args, init_kwargs, is_trainable)
|
||||
|
||||
model = None
|
||||
if model_args.mixture_of_depths == 'continue':
|
||||
from MoD import AutoMoDModelForCausalLM
|
||||
model = AutoMoDModelForCausalLM.from_pretrained(model_args.model_name_or_path, config=config)
|
||||
if model.config.model_type == 'qwen2':
|
||||
raise RuntimeError("Qwen models are not supported for MoD training.")
|
||||
|
||||
if is_trainable and model_args.use_unsloth:
|
||||
from unsloth import FastLanguageModel # type: ignore
|
||||
|
||||
@@ -100,6 +106,13 @@ def load_model(
|
||||
init_kwargs["pretrained_model_name_or_path"] = model_args.model_name_or_path
|
||||
model: "PreTrainedModel" = AutoModelForCausalLM.from_pretrained(**init_kwargs)
|
||||
|
||||
if model_args.mixture_of_depths == 'convert':
|
||||
from MoD import convert_hf_model
|
||||
if model.config.model_type == 'qwen2':
|
||||
raise RuntimeError("Qwen models are not supported for MoD training.")
|
||||
model = convert_hf_model(model)
|
||||
|
||||
|
||||
patch_model(model, tokenizer, model_args, is_trainable)
|
||||
register_autoclass(config, model, tokenizer)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user