optimize aqlm training

Former-commit-id: 8b42660e4039b3d6475f502f397686ba6b140627
hiyouga
2024-03-05 18:35:41 +08:00
parent 3553e301dd
commit a10bead9b5
2 changed files with 28 additions and 13 deletions


@@ -1,3 +1,4 @@
+from contextlib import nullcontext
 from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple

 from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
@@ -86,7 +87,17 @@ def load_model(
             logger.warning("Unsloth does not support loading adapters.")

     if model is None:
-        model = AutoModelForCausalLM.from_pretrained(model_args.model_name_or_path, config=config, **init_kwargs)
+        model_init_context = nullcontext()
+        if is_trainable and getattr(config, "quantization_config", None):
+            quantization_config: Dict[str, Any] = getattr(config, "quantization_config", None)
+            if quantization_config.get("quant_method", None) == "aqlm":
+                import aqlm  # type: ignore
+
+                model_init_context = aqlm.optimize_for_training()
+                logger.info("Optimize for AQLM training.")  # https://github.com/Vahe1994/AQLM/issues/38
+
+        with model_init_context:
+            model = AutoModelForCausalLM.from_pretrained(model_args.model_name_or_path, config=config, **init_kwargs)

     patch_model(model, tokenizer, model_args, is_trainable)
     register_autoclass(config, model, tokenizer)
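
For context, the same pattern can be sketched outside the trainer code. The snippet below is a minimal, standalone illustration rather than the project's API: it assumes the aqlm package is installed and uses a placeholder checkpoint name; only aqlm.optimize_for_training() and the quantization_config check are taken from the diff above.

# Minimal sketch of the AQLM training optimization, independent of LLaMA-Factory.
# Assumption: `model_name` is a placeholder for any AQLM-quantized checkpoint.
from contextlib import nullcontext

from transformers import AutoConfig, AutoModelForCausalLM

model_name = "your-org/your-aqlm-quantized-model"  # hypothetical checkpoint
config = AutoConfig.from_pretrained(model_name)

quantization_config = getattr(config, "quantization_config", None) or {}
if quantization_config.get("quant_method") == "aqlm":
    import aqlm  # type: ignore

    # Enable training-friendly AQLM kernels, see https://github.com/Vahe1994/AQLM/issues/38
    init_context = aqlm.optimize_for_training()
else:
    init_context = nullcontext()

with init_context:
    model = AutoModelForCausalLM.from_pretrained(model_name, config=config)

Using nullcontext() as the default keeps the loading path identical for non-AQLM models, which is the same design choice the patch makes.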