fix aqlm version

Former-commit-id: 05673f81f0295c76957f3247c62f95fda322a63e
Author: hiyouga
Date: 2024-03-09 00:09:09 +08:00
parent 53ab28533e
commit 9b97b23ce7
6 changed files with 8 additions and 18 deletions


@@ -1,4 +1,3 @@
-from contextlib import nullcontext
 from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple
 from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
@@ -87,17 +86,7 @@ def load_model(
             logger.warning("Unsloth does not support loading adapters.")

     if model is None:
-        model_init_context = nullcontext()
-        if model_args.aqlm_optimization and getattr(config, "quantization_config", None):
-            quantization_config: Dict[str, Any] = getattr(config, "quantization_config", None)
-            if quantization_config.get("quant_method", None) == "aqlm":
-                import aqlm  # type: ignore
-
-                model_init_context = aqlm.optimize_for_training()
-                logger.info("Optimize for AQLM training.")  # https://github.com/Vahe1994/AQLM/issues/38
-
-        with model_init_context:
-            model = AutoModelForCausalLM.from_pretrained(model_args.model_name_or_path, config=config, **init_kwargs)
+        model = AutoModelForCausalLM.from_pretrained(model_args.model_name_or_path, config=config, **init_kwargs)

     patch_model(model, tokenizer, model_args, is_trainable)
     register_autoclass(config, model, tokenizer)
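
For context, the block removed above implemented the workaround from https://github.com/Vahe1994/AQLM/issues/38: wrapping model initialization in aqlm.optimize_for_training() so AQLM-quantized layers could be trained. A minimal standalone sketch of that old pattern, assuming an aqlm release old enough to still expose this context manager; the model id is a hypothetical placeholder:

    from contextlib import nullcontext

    from transformers import AutoConfig, AutoModelForCausalLM

    MODEL_ID = "some/aqlm-quantized-model"  # hypothetical placeholder

    config = AutoConfig.from_pretrained(MODEL_ID)

    # Older aqlm releases required entering optimize_for_training() before
    # loading so the quantized layers were set up for backprop
    # (https://github.com/Vahe1994/AQLM/issues/38).
    quantization_config = getattr(config, "quantization_config", None) or {}
    if quantization_config.get("quant_method") == "aqlm":
        import aqlm  # type: ignore

        init_context = aqlm.optimize_for_training()
    else:
        init_context = nullcontext()

    with init_context:
        model = AutoModelForCausalLM.from_pretrained(MODEL_ID, config=config)

With the aqlm version this commit pins, the wrapper is presumably no longer needed, which is why load_model now calls from_pretrained() directly.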