add vllm config
Former-commit-id: 95365f0ce4f362bde7de8b679b54b548d7055bfb
@@ -83,6 +83,7 @@ class VllmEngine(BaseEngine):
             "enable_lora": model_args.adapter_name_or_path is not None,
             "max_lora_rank": model_args.vllm_max_lora_rank,
         }
+        engine_args.update(model_args.vllm_config)
 
         if getattr(config, "is_yi_vl_derived_model", None):
             import vllm.model_executor.models.llava
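The added line lets user-supplied vLLM settings override the defaults assembled just above. A minimal, self-contained sketch of that merge, assuming `vllm_config` has already been parsed into a dict (the key names below are illustrative, not prescribed by this commit):

```python
# Hypothetical defaults standing in for the engine_args dict built by VllmEngine.
engine_args = {
    "model": "path/to/model",
    "enable_lora": False,
    "max_lora_rank": 32,
}

# User-provided overrides, e.g. the parsed value of the new `vllm_config` option.
vllm_config = {"gpu_memory_utilization": 0.95, "enforce_eager": True}

engine_args.update(vllm_config)  # keys from vllm_config win over the defaults
print(engine_args["gpu_memory_utilization"])  # 0.95
```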
@@ -173,7 +174,7 @@ class VllmEngine(BaseEngine):
             multi_modal_data = None
 
         result_generator = self.model.generate(
-            inputs={"prompt_token_ids": prompt_ids, "multi_modal_data": multi_modal_data},
+            {"prompt_token_ids": prompt_ids, "multi_modal_data": multi_modal_data},
             sampling_params=sampling_params,
             request_id=request_id,
             lora_request=self.lora_request,
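The prompt payload is now passed positionally rather than through the `inputs=` keyword, apparently to match the signature expected by the newer vLLM versions allowed by the pin change further down. A hedged sketch of how such a call is consumed, assuming `engine` is a vLLM `AsyncLLMEngine` that accepts a token-ids dict as its first positional argument:

```python
# Sketch only: streams partial outputs from the async generator returned by generate().
async def stream_text(engine, prompt_ids, sampling_params, request_id, lora_request=None):
    result_generator = engine.generate(
        {"prompt_token_ids": prompt_ids},  # positional, no `inputs=` keyword
        sampling_params=sampling_params,
        request_id=request_id,
        lora_request=lora_request,
    )
    async for request_output in result_generator:
        # each item carries the text generated so far for the first sequence
        yield request_output.outputs[0].text
```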
@@ -46,7 +46,7 @@ class DataArguments:
         metadata={"help": "Path to the folder containing the images or videos. Defaults to `dataset_dir`."},
     )
     cutoff_len: int = field(
-        default=1024,
+        default=2048,
         metadata={"help": "The cutoff length of the tokenized inputs in the dataset."},
     )
     train_on_prompt: bool = field(
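With this default bump, runs that do not set `cutoff_len` explicitly now truncate tokenized inputs at 2048 tokens instead of 1024. A small sketch, assuming the usual `HfArgumentParser` flow and using a trimmed-down stand-in for `DataArguments`:

```python
from dataclasses import dataclass, field
from transformers import HfArgumentParser

@dataclass
class DataArguments:  # trimmed stand-in showing only the field changed here
    cutoff_len: int = field(
        default=2048,
        metadata={"help": "The cutoff length of the tokenized inputs in the dataset."},
    )

parser = HfArgumentParser(DataArguments)
(data_args,) = parser.parse_args_into_dataclasses(args=[])  # no --cutoff_len passed
assert data_args.cutoff_len == 2048  # pass --cutoff_len 1024 to keep the old value
```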
@@ -15,10 +15,12 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import json
 from dataclasses import dataclass, field, fields
 from typing import Any, Dict, Literal, Optional, Union
 
 import torch
+from transformers.training_args import _convert_str_dict
 from typing_extensions import Self
 
 
@@ -125,7 +127,7 @@ class VllmArguments:
     """
 
     vllm_maxlen: int = field(
-        default=2048,
+        default=4096,
         metadata={"help": "Maximum sequence (prompt + response) length of the vLLM engine."},
     )
     vllm_gpu_util: float = field(
@@ -140,6 +142,10 @@ class VllmArguments:
         default=32,
         metadata={"help": "Maximum rank of all LoRAs in the vLLM engine."},
     )
+    vllm_config: Optional[Union[dict, str]] = field(
+        default=None,
+        metadata={"help": "Config to initialize the vllm engine. Please use JSON strings."},
+    )
 
 
 @dataclass
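Per the type annotation and help text, the new `vllm_config` option accepts either a dict or a JSON string. A hedged usage sketch; the engine-argument names are illustrative assumptions, not mandated by this commit:

```python
# Programmatic form: pass a dict directly.
vllm_config = {"max_num_seqs": 256, "gpu_memory_utilization": 0.9}

# Command-line / config-file form: pass a JSON string, as the help text asks;
# this is the form handled by the __post_init__ parsing added below.
vllm_config_str = '{"max_num_seqs": 256, "gpu_memory_utilization": 0.9}'
```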
@@ -312,6 +318,9 @@ class ModelArguments(QuantizationArguments, ProcessorArguments, ExportArguments,
         if self.export_quantization_bit is not None and self.export_quantization_dataset is None:
             raise ValueError("Quantization dataset is necessary for exporting.")
 
+        if isinstance(self.vllm_config, str) and self.vllm_config.startswith("{"):
+            self.vllm_config = _convert_str_dict(json.loads(self.vllm_config))
+
     @classmethod
     def copyfrom(cls, source: "Self", **kwargs) -> "Self":
         init_args, lazy_args = {}, {}
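A standalone sketch of what the new `__post_init__` branch does to a JSON-string value: `json.loads` builds the dict, and `_convert_str_dict` (a transformers helper) is assumed to coerce string values such as `"0.9"` or `"true"` into floats and booleans:

```python
import json
from transformers.training_args import _convert_str_dict

raw = '{"gpu_memory_utilization": "0.9", "enforce_eager": "true"}'  # example value
if isinstance(raw, str) and raw.startswith("{"):
    vllm_config = _convert_str_dict(json.loads(raw))

print(vllm_config)  # expected: {'gpu_memory_utilization': 0.9, 'enforce_eager': True}
```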
@@ -122,7 +122,7 @@ def _check_extra_dependencies(
         require_version("mixture-of-depth>=1.1.6", "To fix: pip install mixture-of-depth>=1.1.6")
 
     if model_args.infer_backend == "vllm":
-        require_version("vllm>=0.4.3,<=0.6.3", "To fix: pip install vllm>=0.4.3,<=0.6.3")
+        require_version("vllm>=0.4.3,<0.6.4", "To fix: pip install vllm>=0.4.3,<0.6.4")
 
     if finetuning_args.use_galore:
         require_version("galore_torch", "To fix: pip install galore_torch")
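`require_version` here comes from `transformers.utils.versions`; it is expected to raise an `ImportError`-family exception carrying the hint message when the installed package falls outside the requested range, so the adjusted upper bound changes what passes this check. A small sketch:

```python
from transformers.utils.versions import require_version

try:
    require_version("vllm>=0.4.3,<0.6.4", "To fix: pip install vllm>=0.4.3,<0.6.4")
except ImportError as err:  # raised when vLLM is missing or out of range
    print(f"vLLM version check failed: {err}")
```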
@@ -68,7 +68,7 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
         )
 
     with gr.Row():
-        cutoff_len = gr.Slider(minimum=4, maximum=131072, value=1024, step=1)
+        cutoff_len = gr.Slider(minimum=4, maximum=131072, value=2048, step=1)
         batch_size = gr.Slider(minimum=1, maximum=1024, value=2, step=1)
         gradient_accumulation_steps = gr.Slider(minimum=1, maximum=1024, value=8, step=1)
         val_size = gr.Slider(minimum=0, maximum=1, value=0, step=0.001)