optimize aqlm training
Former-commit-id: 8b42660e4039b3d6475f502f397686ba6b140627
This commit is contained in:
@@ -1,3 +1,4 @@
|
||||
from contextlib import nullcontext
|
||||
from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple
|
||||
|
||||
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
|
||||
@@ -86,7 +87,17 @@ def load_model(
|
||||
logger.warning("Unsloth does not support loading adapters.")
|
||||
|
||||
if model is None:
|
||||
model = AutoModelForCausalLM.from_pretrained(model_args.model_name_or_path, config=config, **init_kwargs)
|
||||
model_init_context = nullcontext()
|
||||
if is_trainable and getattr(config, "quantization_config", None):
|
||||
quantization_config: Dict[str, Any] = getattr(config, "quantization_config", None)
|
||||
if quantization_config.get("quant_method", None) == "aqlm":
|
||||
import aqlm # type: ignore
|
||||
|
||||
model_init_context = aqlm.optimize_for_training()
|
||||
logger.info("Optimize for AQLM training.") # https://github.com/Vahe1994/AQLM/issues/38
|
||||
|
||||
with model_init_context:
|
||||
model = AutoModelForCausalLM.from_pretrained(model_args.model_name_or_path, config=config, **init_kwargs)
|
||||
|
||||
patch_model(model, tokenizer, model_args, is_trainable)
|
||||
register_autoclass(config, model, tokenizer)
|
||||
|
||||
Reference in New Issue
Block a user