add vllm config

Former-commit-id: 95365f0ce4f362bde7de8b679b54b548d7055bfb
hiyouga
2024-11-10 21:28:18 +08:00
parent fcb6283a72
commit 1e6f96508a
34 changed files with 44 additions and 34 deletions


@@ -15,10 +15,12 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import json
 from dataclasses import dataclass, field, fields
 from typing import Any, Dict, Literal, Optional, Union
 import torch
+from transformers.training_args import _convert_str_dict
 from typing_extensions import Self
@@ -125,7 +127,7 @@ class VllmArguments:
     """
     vllm_maxlen: int = field(
-        default=2048,
+        default=4096,
         metadata={"help": "Maximum sequence (prompt + response) length of the vLLM engine."},
     )
     vllm_gpu_util: float = field(
@@ -140,6 +142,10 @@ class VllmArguments:
         default=32,
         metadata={"help": "Maximum rank of all LoRAs in the vLLM engine."},
     )
+    vllm_config: Optional[Union[dict, str]] = field(
+        default=None,
+        metadata={"help": "Config to initialize the vllm engine. Please use JSON strings."},
+    )
 @dataclass
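
For context on the new option: `vllm_config` accepts either a ready-made dict or a JSON object string. A minimal, self-contained sketch of the string form (the engine keys shown are illustrative examples, not prescribed by this commit):

```python
import json

# vllm_config may be given as a JSON object string; the keys below are
# example vLLM engine kwargs used purely for illustration.
vllm_config = '{"enforce_eager": true, "max_num_seqs": 256}'

# Mirror the check added to ModelArguments.__post_init__ in this commit:
# only strings that look like a JSON object are parsed.
if isinstance(vllm_config, str) and vllm_config.startswith("{"):
    parsed = json.loads(vllm_config)
    print(parsed)  # {'enforce_eager': True, 'max_num_seqs': 256}
```
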
@@ -312,6 +318,9 @@ class ModelArguments(QuantizationArguments, ProcessorArguments, ExportArguments,
         if self.export_quantization_bit is not None and self.export_quantization_dataset is None:
             raise ValueError("Quantization dataset is necessary for exporting.")
+        if isinstance(self.vllm_config, str) and self.vllm_config.startswith("{"):
+            self.vllm_config = _convert_str_dict(json.loads(self.vllm_config))
 @classmethod
 def copyfrom(cls, source: "Self", **kwargs) -> "Self":
     init_args, lazy_args = {}, {}
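
A rough sketch of what the added `__post_init__` logic does with a string value, assuming `_convert_str_dict` behaves like the helper of the same name in `transformers.training_args` (it walks the parsed dict and coerces string values such as "true" or "256" into bools, ints, and floats); the keys and output below are illustrative only:

```python
import json

from transformers.training_args import _convert_str_dict

# JSON whose values are all quoted strings, e.g. from cautious shell quoting.
raw = '{"enforce_eager": "true", "max_num_seqs": "256", "gpu_memory_utilization": "0.9"}'

config = _convert_str_dict(json.loads(raw))
print(config)
# Illustrative expected result:
# {'enforce_eager': True, 'max_num_seqs': 256, 'gpu_memory_utilization': 0.9}
```
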