fix aqlm version
Former-commit-id: 05673f81f0295c76957f3247c62f95fda322a63e
This commit is contained in:
@@ -1,4 +1,3 @@
|
||||
from contextlib import nullcontext
|
||||
from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple
|
||||
|
||||
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
|
||||
@@ -87,17 +86,7 @@ def load_model(
|
||||
logger.warning("Unsloth does not support loading adapters.")
|
||||
|
||||
if model is None:
|
||||
model_init_context = nullcontext()
|
||||
if model_args.aqlm_optimization and getattr(config, "quantization_config", None):
|
||||
quantization_config: Dict[str, Any] = getattr(config, "quantization_config", None)
|
||||
if quantization_config.get("quant_method", None) == "aqlm":
|
||||
import aqlm # type: ignore
|
||||
|
||||
model_init_context = aqlm.optimize_for_training()
|
||||
logger.info("Optimize for AQLM training.") # https://github.com/Vahe1994/AQLM/issues/38
|
||||
|
||||
with model_init_context:
|
||||
model = AutoModelForCausalLM.from_pretrained(model_args.model_name_or_path, config=config, **init_kwargs)
|
||||
model = AutoModelForCausalLM.from_pretrained(model_args.model_name_or_path, config=config, **init_kwargs)
|
||||
|
||||
patch_model(model, tokenizer, model_args, is_trainable)
|
||||
register_autoclass(config, model, tokenizer)
|
||||
|
||||
Reference in New Issue
Block a user