rename files
Former-commit-id: e1a8431770fc36c0c9ee7fed4abbc3d7fdcc5efd
28
src/llamafactory/model/model_utils/mod.py
Normal file
@@ -0,0 +1,28 @@
from typing import TYPE_CHECKING

from ...extras.constants import MOD_SUPPORTED_MODELS


if TYPE_CHECKING:
    from transformers import PretrainedConfig, PreTrainedModel

    from ...hparams import ModelArguments


def load_mod_pretrained_model(**init_kwargs) -> "PreTrainedModel":
    from MoD import AutoMoDModelForCausalLM

    return AutoMoDModelForCausalLM.from_pretrained(**init_kwargs)


def convert_pretrained_model_to_mod(
    model: "PreTrainedModel", config: "PretrainedConfig", model_args: "ModelArguments"
) -> "PreTrainedModel":
    from MoD import apply_mod_to_hf

    if getattr(config, "model_type", None) not in MOD_SUPPORTED_MODELS:
        raise ValueError("Current model is not supported by mixture-of-depth.")

    model = apply_mod_to_hf(model)
    model = model.to(model_args.compute_dtype)
    return model
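For context, a minimal usage sketch of the two helpers added above (not part of this commit). The checkpoint name, the AutoConfig / AutoModelForCausalLM calls, and the manually assigned torch.bfloat16 compute dtype are illustrative assumptions; in practice these values would come from the parsed ModelArguments rather than being set by hand.

import torch
from transformers import AutoConfig, AutoModelForCausalLM

from llamafactory.hparams import ModelArguments
from llamafactory.model.model_utils.mod import (
    convert_pretrained_model_to_mod,
    load_mod_pretrained_model,
)

# Illustrative arguments; checkpoint and dtype are assumptions, not repo defaults.
model_args = ModelArguments(model_name_or_path="meta-llama/Llama-2-7b-hf")
model_args.compute_dtype = torch.bfloat16  # normally filled in during argument parsing
config = AutoConfig.from_pretrained(model_args.model_name_or_path)

# Option 1: load a checkpoint that was already saved as a mixture-of-depths model;
# the kwargs are forwarded to AutoMoDModelForCausalLM.from_pretrained.
mod_model = load_mod_pretrained_model(
    pretrained_model_name_or_path=model_args.model_name_or_path
)

# Option 2: load a vanilla causal LM and convert it; this checks the model type
# against MOD_SUPPORTED_MODELS and casts the result to model_args.compute_dtype.
base_model = AutoModelForCausalLM.from_pretrained(model_args.model_name_or_path)
mod_model = convert_pretrained_model_to_mod(base_model, config, model_args)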