mirror of
https://github.com/hiyouga/LlamaFactory.git
synced 2026-02-01 20:23:37 +00:00
add docstrings, refactor logger
Former-commit-id: c34e489d71f8f539028543ccf8ee92cecedd6276
This commit is contained in:
@@ -35,6 +35,12 @@ class Response:
|
||||
|
||||
|
||||
class BaseEngine(ABC):
|
||||
r"""
|
||||
Base class for inference engine of chat models.
|
||||
|
||||
Must implements async methods: chat(), stream_chat() and get_scores().
|
||||
"""
|
||||
|
||||
model: Union["PreTrainedModel", "AsyncLLMEngine"]
|
||||
tokenizer: "PreTrainedTokenizer"
|
||||
can_generate: bool
|
||||
@@ -48,7 +54,11 @@ class BaseEngine(ABC):
|
||||
data_args: "DataArguments",
|
||||
finetuning_args: "FinetuningArguments",
|
||||
generating_args: "GeneratingArguments",
|
||||
) -> None: ...
|
||||
) -> None:
|
||||
r"""
|
||||
Initializes an inference engine.
|
||||
"""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
async def chat(
|
||||
@@ -59,7 +69,11 @@ class BaseEngine(ABC):
|
||||
image: Optional["ImageInput"] = None,
|
||||
video: Optional["VideoInput"] = None,
|
||||
**input_kwargs,
|
||||
) -> List["Response"]: ...
|
||||
) -> List["Response"]:
|
||||
r"""
|
||||
Gets a list of responses of the chat model.
|
||||
"""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
async def stream_chat(
|
||||
@@ -70,11 +84,19 @@ class BaseEngine(ABC):
|
||||
image: Optional["ImageInput"] = None,
|
||||
video: Optional["VideoInput"] = None,
|
||||
**input_kwargs,
|
||||
) -> AsyncGenerator[str, None]: ...
|
||||
) -> AsyncGenerator[str, None]:
|
||||
r"""
|
||||
Gets the response token-by-token of the chat model.
|
||||
"""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
async def get_scores(
|
||||
self,
|
||||
batch_input: List[str],
|
||||
**input_kwargs,
|
||||
) -> List[float]: ...
|
||||
) -> List[float]:
|
||||
r"""
|
||||
Gets a list of scores of the reward model.
|
||||
"""
|
||||
...
|
||||
|
||||
Reference in New Issue
Block a user