add max_memory for gptq #1923
Former-commit-id: 9afc42c8b999fbbc206d9a467ca5795b27a10096
@@ -8,6 +8,7 @@ from transformers.trainer import WEIGHTS_NAME, SAFE_WEIGHTS_NAME
 
 from llmtuner.extras.constants import LAYERNORM_NAMES
 from llmtuner.extras.logging import get_logger
+from llmtuner.extras.misc import get_current_device
 from llmtuner.hparams import ModelArguments, FinetuningArguments
 
 if TYPE_CHECKING:
@@ -20,7 +21,7 @@ logger = get_logger(__name__)
 
 def dispatch_model(model: "PreTrainedModel") -> "PreTrainedModel":
     r"""
-    Dispatches a pre-trained model to GPUs with balanced memory.
+    Dispatches a pre-trained model to GPUs with balanced memory when the GPU is available.
     Borrowed from: https://github.com/huggingface/transformers/blob/v4.36.2/src/transformers/modeling_utils.py#L3570
     """
     if getattr(model, "quantization_method", None): # already set on current device
@@ -43,7 +44,7 @@ def dispatch_model(model: "PreTrainedModel") -> "PreTrainedModel":
         device_map_kwargs["skip_keys"] = model._skip_keys_device_placement
         return dispatch_model(model, **device_map_kwargs)
     else:
-        return model.cuda()
+        return model.to(device=get_current_device())
 
 
 def find_all_linear_modules(model: "PreTrainedModel") -> List[str]:
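
For context, the newly imported get_current_device is what lets the fallback branch respect the process-local device instead of hard-coding CUDA. Below is a minimal sketch of such a helper, assuming it derives the device from the LOCAL_RANK environment variable; the actual body of llmtuner.extras.misc.get_current_device is not shown in this diff.

    import os
    import torch

    def get_current_device() -> "torch.device":
        # In distributed runs each worker owns one GPU, identified by
        # LOCAL_RANK; fall back to CPU when CUDA is unavailable.
        if torch.cuda.is_available():
            device = "cuda:{}".format(os.environ.get("LOCAL_RANK", "0"))
        else:
            device = "cpu"
        return torch.device(device)

Compared with the old model.cuda(), model.to(device=get_current_device()) keeps each torchrun worker on its own GPU instead of always moving the model to the default one.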
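The multi-GPU branch above ends by forwarding device_map_kwargs into accelerate's dispatch_model. A rough sketch of the balanced-memory dispatch the docstring describes, built from accelerate's public helpers (the wrapper name dispatch_balanced and its exact kwargs are illustrative, not this commit's code):

    import torch
    from accelerate import dispatch_model as hf_dispatch_model
    from accelerate.utils import get_balanced_memory, infer_auto_device_map

    def dispatch_balanced(model):
        # With several GPUs, compute a per-device memory budget that
        # spreads the layers evenly, then derive a device map from it.
        if torch.cuda.device_count() > 1:
            kwargs = {"no_split_module_classes": model._no_split_modules}
            max_memory = get_balanced_memory(model, **kwargs)
            device_map = infer_auto_device_map(model, max_memory=max_memory, **kwargs)
            return hf_dispatch_model(model, device_map=device_map)
        # A single device needs no sharding.
        return model

The commit title suggests a max_memory budget was also wired into the GPTQ loading path; that hunk is not shown here, and GPTQ-quantized models skip dispatching entirely via the quantization_method early return above.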