support rank0 logger
Former-commit-id: 84528eabe560091bfd866b6a0ca864085af7529b
This commit is contained in:
@@ -23,8 +23,8 @@ from transformers import GenerationConfig, TextIteratorStreamer
|
||||
from typing_extensions import override
|
||||
|
||||
from ..data import get_template_and_fix_tokenizer
|
||||
from ..extras import logging
|
||||
from ..extras.constants import IMAGE_PLACEHOLDER, VIDEO_PLACEHOLDER
|
||||
from ..extras.logging import get_logger
|
||||
from ..extras.misc import get_logits_processor
|
||||
from ..model import load_model, load_tokenizer
|
||||
from .base_engine import BaseEngine, Response
|
||||
@@ -39,7 +39,7 @@ if TYPE_CHECKING:
|
||||
from ..hparams import DataArguments, FinetuningArguments, GeneratingArguments, ModelArguments
|
||||
|
||||
|
||||
logger = get_logger(__name__)
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
class HuggingfaceEngine(BaseEngine):
|
||||
@@ -63,11 +63,11 @@ class HuggingfaceEngine(BaseEngine):
|
||||
try:
|
||||
asyncio.get_event_loop()
|
||||
except RuntimeError:
|
||||
logger.warning("There is no current event loop, creating a new one.")
|
||||
logger.warning_once("There is no current event loop, creating a new one.")
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
|
||||
self.semaphore = asyncio.Semaphore(int(os.environ.get("MAX_CONCURRENT", "1")))
|
||||
self.semaphore = asyncio.Semaphore(int(os.getenv("MAX_CONCURRENT", "1")))
|
||||
|
||||
@staticmethod
|
||||
def _process_args(
|
||||
@@ -119,7 +119,7 @@ class HuggingfaceEngine(BaseEngine):
|
||||
stop: Optional[Union[str, List[str]]] = input_kwargs.pop("stop", None)
|
||||
|
||||
if stop is not None:
|
||||
logger.warning("Stop parameter is not supported by the huggingface engine yet.")
|
||||
logger.warning_rank0("Stop parameter is not supported by the huggingface engine yet.")
|
||||
|
||||
generating_args = generating_args.copy()
|
||||
generating_args.update(
|
||||
|
||||
@@ -18,8 +18,8 @@ from typing import TYPE_CHECKING, Any, AsyncGenerator, AsyncIterator, Dict, List
|
||||
from typing_extensions import override
|
||||
|
||||
from ..data import get_template_and_fix_tokenizer
|
||||
from ..extras import logging
|
||||
from ..extras.constants import IMAGE_PLACEHOLDER
|
||||
from ..extras.logging import get_logger
|
||||
from ..extras.misc import get_device_count
|
||||
from ..extras.packages import is_pillow_available, is_vllm_available
|
||||
from ..model import load_config, load_tokenizer
|
||||
@@ -43,7 +43,7 @@ if TYPE_CHECKING:
|
||||
from ..hparams import DataArguments, FinetuningArguments, GeneratingArguments, ModelArguments
|
||||
|
||||
|
||||
logger = get_logger(__name__)
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
class VllmEngine(BaseEngine):
|
||||
@@ -87,7 +87,7 @@ class VllmEngine(BaseEngine):
|
||||
if getattr(config, "is_yi_vl_derived_model", None):
|
||||
import vllm.model_executor.models.llava
|
||||
|
||||
logger.info("Detected Yi-VL model, applying projector patch.")
|
||||
logger.info_rank0("Detected Yi-VL model, applying projector patch.")
|
||||
vllm.model_executor.models.llava.LlavaMultiModalProjector = LlavaMultiModalProjectorForYiVLForVLLM
|
||||
|
||||
self.model = AsyncLLMEngine.from_engine_args(AsyncEngineArgs(**engine_args))
|
||||
|
||||
Reference in New Issue
Block a user