support vllm

Former-commit-id: 889f6e910e654d8ec3922c2185042d737ffbf1c3
Author: hiyouga
Date: 2024-03-07 20:26:31 +08:00
parent 9a69cadab3
commit 056d2d956a
32 changed files with 752 additions and 316 deletions
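
The headline change: inference can now go through vLLM, which shards and places the model itself, making the hand-rolled multi-GPU dispatch removed below redundant. For orientation, a minimal sketch of vLLM's offline-generation API that such support builds on (not code from this commit; the checkpoint name is illustrative):

# Illustrative vLLM usage, not taken from this diff.
# vLLM handles device placement and tensor parallelism internally.
from vllm import LLM, SamplingParams

llm = LLM(model="meta-llama/Llama-2-7b-hf", tensor_parallel_size=1)  # assumed checkpoint
sampling_params = SamplingParams(temperature=0.7, top_p=0.9, max_tokens=64)

outputs = llm.generate(["Hello, my name is"], sampling_params)
print(outputs[0].outputs[0].text)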

View File

@@ -1,11 +1,10 @@
 from .loader import load_model, load_model_and_tokenizer, load_tokenizer
-from .utils import dispatch_model, load_valuehead_params
+from .utils import load_valuehead_params
 
 
 __all__ = [
     "load_model",
     "load_model_and_tokenizer",
     "load_tokenizer",
-    "dispatch_model",
     "load_valuehead_params",
 ]
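
After this hunk, dispatch_model is no longer exported and the public surface shrinks to the loader helpers. A quick check of the surviving imports (assuming this file is the __init__.py of the llmtuner.model package, as its relative imports suggest):

# Names still importable from the package after this change.
from llmtuner.model import (
    load_model,
    load_model_and_tokenizer,
    load_tokenizer,
    load_valuehead_params,
)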

View File

@@ -1,4 +1,3 @@
-import inspect
 from typing import TYPE_CHECKING, Dict, List
 
 import torch
@@ -7,7 +6,6 @@ from transformers.utils import cached_file
 from ..extras.constants import V_HEAD_SAFE_WEIGHTS_NAME, V_HEAD_WEIGHTS_NAME
 from ..extras.logging import get_logger
-from ..extras.misc import get_current_device
 
 if TYPE_CHECKING:
@@ -19,36 +17,6 @@ if TYPE_CHECKING:
 logger = get_logger(__name__)
 
 
-def dispatch_model(model: "PreTrainedModel") -> "PreTrainedModel":
-    r"""
-    Dispatches a pre-trained model to GPUs with balanced memory when the GPU is available.
-    Borrowed from: https://github.com/huggingface/transformers/blob/v4.36.2/src/transformers/modeling_utils.py#L3570
-    """
-    if getattr(model, "quantization_method", None):  # already set on current device
-        return model
-
-    if (
-        torch.cuda.device_count() > 1
-        and isinstance(model, PreTrainedModel)
-        and model._no_split_modules is not None
-        and model.config.model_type != "chatglm"
-    ):
-        from accelerate import dispatch_model
-        from accelerate.utils import get_balanced_memory, infer_auto_device_map
-
-        kwargs = {"dtype": model.dtype, "no_split_module_classes": model._get_no_split_modules("auto")}
-        max_memory = get_balanced_memory(model, **kwargs)
-        # Make sure tied weights are tied before creating the device map.
-        model.tie_weights()
-        device_map = infer_auto_device_map(model, max_memory=max_memory, **kwargs)
-        device_map_kwargs = {"device_map": device_map, "offload_dir": "offload"}
-        if "skip_keys" in inspect.signature(dispatch_model).parameters:
-            device_map_kwargs["skip_keys"] = model._skip_keys_device_placement
-        return dispatch_model(model, **device_map_kwargs)
-    else:
-        return model.to(device=get_current_device())
-
-
 def find_all_linear_modules(model: "PreTrainedModel") -> List[str]:
     r"""
     Finds all available modules to apply lora.
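
With dispatch_model deleted, its balanced-memory placement is presumably left to load time rather than patched on afterwards. A hedged sketch of the standard Transformers/accelerate equivalent (checkpoint name illustrative, not from this commit):

# device_map="auto" asks accelerate to infer a balanced device map while
# loading, which covers what the removed dispatch_model did post-hoc.
import torch
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf",  # assumed checkpoint
    torch_dtype=torch.float16,
    device_map="auto",
)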