support vllm
Former-commit-id: 889f6e910e654d8ec3922c2185042d737ffbf1c3
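The commit removes the accelerate-based dispatch_model helper from the model utilities and drops it from the package exports; when vLLM serves as the inference backend, the engine handles device placement and parallelism on its own. As a minimal sketch of the replacement path (not part of this commit; the checkpoint name and tensor_parallel_size value are placeholders), generation through vLLM's public API looks like this:

    from vllm import LLM, SamplingParams

    # vLLM shards and places the model itself; no manual dispatch step is needed.
    llm = LLM(model="meta-llama/Llama-2-7b-hf", tensor_parallel_size=1)  # placeholder checkpoint / TP size
    params = SamplingParams(temperature=0.7, max_tokens=128)

    outputs = llm.generate(["What is vLLM?"], params)
    print(outputs[0].outputs[0].text)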
@@ -1,11 +1,10 @@
 from .loader import load_model, load_model_and_tokenizer, load_tokenizer
-from .utils import dispatch_model, load_valuehead_params
+from .utils import load_valuehead_params


 __all__ = [
     "load_model",
     "load_model_and_tokenizer",
     "load_tokenizer",
-    "dispatch_model",
     "load_valuehead_params",
 ]

@@ -1,4 +1,3 @@
-import inspect
 from typing import TYPE_CHECKING, Dict, List

 import torch
@@ -7,7 +6,6 @@ from transformers.utils import cached_file
 
 from ..extras.constants import V_HEAD_SAFE_WEIGHTS_NAME, V_HEAD_WEIGHTS_NAME
 from ..extras.logging import get_logger
-from ..extras.misc import get_current_device


 if TYPE_CHECKING:
@@ -19,36 +17,6 @@ if TYPE_CHECKING:
 logger = get_logger(__name__)


-def dispatch_model(model: "PreTrainedModel") -> "PreTrainedModel":
-    r"""
-    Dispatches a pre-trained model to GPUs with balanced memory when the GPU is available.
-    Borrowed from: https://github.com/huggingface/transformers/blob/v4.36.2/src/transformers/modeling_utils.py#L3570
-    """
-    if getattr(model, "quantization_method", None):  # already set on current device
-        return model
-
-    if (
-        torch.cuda.device_count() > 1
-        and isinstance(model, PreTrainedModel)
-        and model._no_split_modules is not None
-        and model.config.model_type != "chatglm"
-    ):
-        from accelerate import dispatch_model
-        from accelerate.utils import get_balanced_memory, infer_auto_device_map
-
-        kwargs = {"dtype": model.dtype, "no_split_module_classes": model._get_no_split_modules("auto")}
-        max_memory = get_balanced_memory(model, **kwargs)
-        # Make sure tied weights are tied before creating the device map.
-        model.tie_weights()
-        device_map = infer_auto_device_map(model, max_memory=max_memory, **kwargs)
-        device_map_kwargs = {"device_map": device_map, "offload_dir": "offload"}
-        if "skip_keys" in inspect.signature(dispatch_model).parameters:
-            device_map_kwargs["skip_keys"] = model._skip_keys_device_placement
-        return dispatch_model(model, **device_map_kwargs)
-    else:
-        return model.to(device=get_current_device())
-
-
 def find_all_linear_modules(model: "PreTrainedModel") -> List[str]:
     r"""
     Finds all available modules to apply lora.
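For anyone who still needs the removed behavior outside this codebase, the deleted helper wrapped accelerate's balanced-memory dispatch. A standalone, hedged sketch of the same pattern (placeholder checkpoint; _get_no_split_modules is the private transformers helper the removed code also relied on):

    import torch
    from accelerate import dispatch_model
    from accelerate.utils import get_balanced_memory, infer_auto_device_map
    from transformers import AutoModelForCausalLM

    # Placeholder checkpoint; any multi-GPU-capable causal LM follows the same steps.
    model = AutoModelForCausalLM.from_pretrained("gpt2", torch_dtype=torch.float16)

    no_split = model._get_no_split_modules("auto")  # private transformers helper
    max_memory = get_balanced_memory(model, dtype=model.dtype, no_split_module_classes=no_split)

    model.tie_weights()  # tie weights before computing the device map
    device_map = infer_auto_device_map(
        model, max_memory=max_memory, dtype=model.dtype, no_split_module_classes=no_split
    )
    model = dispatch_model(model, device_map=device_map, offload_dir="offload")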