[misc] upgrade format to py39 (#7256)
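The change is mechanical and falls into three buckets: container generics imported from `typing` (`Dict`, `List`, `Tuple`, `Set`, `Type`) become the PEP 585 built-ins available since Python 3.9, collection ABCs such as `Sequence` and `AsyncGenerator` move to `collections.abc`, and docstring summaries collapse onto the opening `r"""` line in imperative mood. A condensed sketch of the before/after pattern (signature adapted from the diff; the function body is illustrative, not from this commit):

```python
from collections.abc import Sequence  # py3.9+: collection ABCs live here, not in `typing`


# Before: `from typing import Dict, List` plus a detached docstring summary, e.g.
#     def chat(messages: Sequence[Dict[str, str]]) -> List[str]:
#         r"""
#         Gets a list of responses of the chat model.
#         """
# After: PEP 585 built-in generics and a one-line summary on the opening quotes.
def chat(messages: Sequence[dict[str, str]]) -> list[str]:
    r"""Get a list of responses of the chat model."""
    return [message["content"] for message in messages]
```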
@@ -12,8 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-r"""
-Efficient fine-tuning of large language models.
+r"""Efficient fine-tuning of large language models.
 
 Level:
     api, webui > chat, eval, train > data, model > hparams > extras
@@ -16,9 +16,7 @@ import asyncio
 import os
 from contextlib import asynccontextmanager
 from functools import partial
-from typing import Optional
-
-from typing_extensions import Annotated
+from typing import Annotated, Optional
 
 from ..chat import ChatModel
 from ..extras.constants import EngineName
@@ -18,7 +18,8 @@ import json
 import os
 import re
 import uuid
-from typing import TYPE_CHECKING, AsyncGenerator, Dict, List, Optional, Tuple
+from collections.abc import AsyncGenerator
+from typing import TYPE_CHECKING, Optional
 
 from ..data import Role as DataRole
 from ..extras import logging
@@ -71,7 +72,7 @@ ROLE_MAPPING = {
 
 def _process_request(
     request: "ChatCompletionRequest",
-) -> Tuple[List[Dict[str, str]], Optional[str], Optional[str], Optional[List["ImageInput"]]]:
+) -> tuple[list[dict[str, str]], Optional[str], Optional[str], Optional[list["ImageInput"]]]:
     if is_env_enabled("API_VERBOSE", "1"):
         logger.info_rank0(f"==== request ====\n{json.dumps(dictify(request), indent=2, ensure_ascii=False)}")
 
@@ -13,14 +13,14 @@
 # limitations under the License.
 
 import json
-from typing import TYPE_CHECKING, Any, Dict
+from typing import TYPE_CHECKING, Any
 
 
 if TYPE_CHECKING:
     from pydantic import BaseModel
 
 
-def dictify(data: "BaseModel") -> Dict[str, Any]:
+def dictify(data: "BaseModel") -> dict[str, Any]:
     try:  # pydantic v2
         return data.model_dump(exclude_unset=True)
     except AttributeError:  # pydantic v1
@@ -14,7 +14,7 @@
 
 import time
 from enum import Enum, unique
-from typing import Any, Dict, List, Optional, Union
+from typing import Any, Optional, Union
 
 from pydantic import BaseModel, Field
 from typing_extensions import Literal
@@ -45,7 +45,7 @@ class ModelCard(BaseModel):
 
 class ModelList(BaseModel):
     object: Literal["list"] = "list"
-    data: List[ModelCard] = []
+    data: list[ModelCard] = []
 
 
 class Function(BaseModel):
@@ -56,7 +56,7 @@ class Function(BaseModel):
 class FunctionDefinition(BaseModel):
     name: str
     description: str
-    parameters: Dict[str, Any]
+    parameters: dict[str, Any]
 
 
 class FunctionAvailable(BaseModel):
@@ -82,26 +82,26 @@ class MultimodalInputItem(BaseModel):
 
 class ChatMessage(BaseModel):
     role: Role
-    content: Optional[Union[str, List[MultimodalInputItem]]] = None
-    tool_calls: Optional[List[FunctionCall]] = None
+    content: Optional[Union[str, list[MultimodalInputItem]]] = None
+    tool_calls: Optional[list[FunctionCall]] = None
 
 
 class ChatCompletionMessage(BaseModel):
     role: Optional[Role] = None
     content: Optional[str] = None
-    tool_calls: Optional[List[FunctionCall]] = None
+    tool_calls: Optional[list[FunctionCall]] = None
 
 
 class ChatCompletionRequest(BaseModel):
     model: str
-    messages: List[ChatMessage]
-    tools: Optional[List[FunctionAvailable]] = None
+    messages: list[ChatMessage]
+    tools: Optional[list[FunctionAvailable]] = None
     do_sample: Optional[bool] = None
     temperature: Optional[float] = None
    top_p: Optional[float] = None
     n: int = 1
     max_tokens: Optional[int] = None
-    stop: Optional[Union[str, List[str]]] = None
+    stop: Optional[Union[str, list[str]]] = None
     stream: bool = False
 
 
@@ -128,7 +128,7 @@ class ChatCompletionResponse(BaseModel):
     object: Literal["chat.completion"] = "chat.completion"
     created: int = Field(default_factory=lambda: int(time.time()))
     model: str
-    choices: List[ChatCompletionResponseChoice]
+    choices: list[ChatCompletionResponseChoice]
     usage: ChatCompletionResponseUsage
 
 
@@ -137,12 +137,12 @@ class ChatCompletionStreamResponse(BaseModel):
     object: Literal["chat.completion.chunk"] = "chat.completion.chunk"
     created: int = Field(default_factory=lambda: int(time.time()))
     model: str
-    choices: List[ChatCompletionStreamResponseChoice]
+    choices: list[ChatCompletionStreamResponseChoice]
 
 
 class ScoreEvaluationRequest(BaseModel):
     model: str
-    messages: List[str]
+    messages: list[str]
     max_length: Optional[int] = None
 
 
@@ -150,4 +150,4 @@ class ScoreEvaluationResponse(BaseModel):
     id: str
     object: Literal["score.evaluation"] = "score.evaluation"
     model: str
-    scores: List[float]
+    scores: list[float]
@@ -13,8 +13,9 @@
 # limitations under the License.
 
 from abc import ABC, abstractmethod
+from collections.abc import AsyncGenerator, Sequence
 from dataclasses import dataclass
-from typing import TYPE_CHECKING, Any, AsyncGenerator, Dict, List, Literal, Optional, Sequence, Union
+from typing import TYPE_CHECKING, Any, Literal, Optional, Union
 
 
 if TYPE_CHECKING:
@@ -36,8 +37,7 @@ class Response:
 
 
 class BaseEngine(ABC):
-    r"""
-    Base class for inference engine of chat models.
+    r"""Base class for inference engine of chat models.
 
     Must implements async methods: chat(), stream_chat() and get_scores().
     """
@@ -47,7 +47,7 @@ class BaseEngine(ABC):
     tokenizer: "PreTrainedTokenizer"
     can_generate: bool
     template: "Template"
-    generating_args: Dict[str, Any]
+    generating_args: dict[str, Any]
 
     @abstractmethod
     def __init__(
@@ -57,31 +57,27 @@ class BaseEngine(ABC):
         finetuning_args: "FinetuningArguments",
         generating_args: "GeneratingArguments",
     ) -> None:
-        r"""
-        Initializes an inference engine.
-        """
+        r"""Initialize an inference engine."""
         ...
 
     @abstractmethod
     async def chat(
         self,
-        messages: Sequence[Dict[str, str]],
+        messages: Sequence[dict[str, str]],
         system: Optional[str] = None,
         tools: Optional[str] = None,
         images: Optional[Sequence["ImageInput"]] = None,
         videos: Optional[Sequence["VideoInput"]] = None,
         audios: Optional[Sequence["AudioInput"]] = None,
         **input_kwargs,
-    ) -> List["Response"]:
-        r"""
-        Gets a list of responses of the chat model.
-        """
+    ) -> list["Response"]:
+        r"""Get a list of responses of the chat model."""
         ...
 
     @abstractmethod
     async def stream_chat(
         self,
-        messages: Sequence[Dict[str, str]],
+        messages: Sequence[dict[str, str]],
         system: Optional[str] = None,
         tools: Optional[str] = None,
         images: Optional[Sequence["ImageInput"]] = None,
@@ -89,18 +85,14 @@ class BaseEngine(ABC):
         audios: Optional[Sequence["AudioInput"]] = None,
         **input_kwargs,
     ) -> AsyncGenerator[str, None]:
-        r"""
-        Gets the response token-by-token of the chat model.
-        """
+        r"""Get the response token-by-token of the chat model."""
         ...
 
     @abstractmethod
     async def get_scores(
         self,
-        batch_input: List[str],
+        batch_input: list[str],
         **input_kwargs,
-    ) -> List[float]:
-        r"""
-        Gets a list of scores of the reward model.
-        """
+    ) -> list[float]:
+        r"""Get a list of scores of the reward model."""
         ...
@@ -17,8 +17,9 @@
 
 import asyncio
 import os
+from collections.abc import AsyncGenerator, Generator, Sequence
 from threading import Thread
-from typing import TYPE_CHECKING, Any, AsyncGenerator, Dict, Generator, List, Optional, Sequence
+from typing import TYPE_CHECKING, Any, Optional
 
 from ..extras.constants import EngineName
 from ..extras.misc import torch_gc
@@ -38,20 +39,19 @@ def _start_background_loop(loop: "asyncio.AbstractEventLoop") -> None:
 
 
 class ChatModel:
-    r"""
-    General class for chat models. Backed by huggingface or vllm engines.
+    r"""General class for chat models. Backed by huggingface or vllm engines.
 
     Supports both sync and async methods.
     Sync methods: chat(), stream_chat() and get_scores().
     Async methods: achat(), astream_chat() and aget_scores().
     """
 
-    def __init__(self, args: Optional[Dict[str, Any]] = None) -> None:
+    def __init__(self, args: Optional[dict[str, Any]] = None) -> None:
         model_args, data_args, finetuning_args, generating_args = get_infer_args(args)
         if model_args.infer_backend == EngineName.HF:
-            self.engine: "BaseEngine" = HuggingfaceEngine(model_args, data_args, finetuning_args, generating_args)
+            self.engine: BaseEngine = HuggingfaceEngine(model_args, data_args, finetuning_args, generating_args)
         elif model_args.infer_backend == EngineName.VLLM:
-            self.engine: "BaseEngine" = VllmEngine(model_args, data_args, finetuning_args, generating_args)
+            self.engine: BaseEngine = VllmEngine(model_args, data_args, finetuning_args, generating_args)
         else:
             raise NotImplementedError(f"Unknown backend: {model_args.infer_backend}")
 
@@ -61,17 +61,15 @@ class ChatModel:
 
     def chat(
         self,
-        messages: Sequence[Dict[str, str]],
+        messages: Sequence[dict[str, str]],
         system: Optional[str] = None,
         tools: Optional[str] = None,
         images: Optional[Sequence["ImageInput"]] = None,
         videos: Optional[Sequence["VideoInput"]] = None,
         audios: Optional[Sequence["AudioInput"]] = None,
         **input_kwargs,
-    ) -> List["Response"]:
-        r"""
-        Gets a list of responses of the chat model.
-        """
+    ) -> list["Response"]:
+        r"""Get a list of responses of the chat model."""
         task = asyncio.run_coroutine_threadsafe(
             self.achat(messages, system, tools, images, videos, audios, **input_kwargs), self._loop
         )
@@ -79,22 +77,20 @@ class ChatModel:
 
     async def achat(
         self,
-        messages: Sequence[Dict[str, str]],
+        messages: Sequence[dict[str, str]],
         system: Optional[str] = None,
         tools: Optional[str] = None,
         images: Optional[Sequence["ImageInput"]] = None,
         videos: Optional[Sequence["VideoInput"]] = None,
         audios: Optional[Sequence["AudioInput"]] = None,
         **input_kwargs,
-    ) -> List["Response"]:
-        r"""
-        Asynchronously gets a list of responses of the chat model.
-        """
+    ) -> list["Response"]:
+        r"""Asynchronously get a list of responses of the chat model."""
         return await self.engine.chat(messages, system, tools, images, videos, audios, **input_kwargs)
 
     def stream_chat(
         self,
-        messages: Sequence[Dict[str, str]],
+        messages: Sequence[dict[str, str]],
         system: Optional[str] = None,
         tools: Optional[str] = None,
         images: Optional[Sequence["ImageInput"]] = None,
@@ -102,9 +98,7 @@ class ChatModel:
         audios: Optional[Sequence["AudioInput"]] = None,
         **input_kwargs,
     ) -> Generator[str, None, None]:
-        r"""
-        Gets the response token-by-token of the chat model.
-        """
+        r"""Get the response token-by-token of the chat model."""
         generator = self.astream_chat(messages, system, tools, images, videos, audios, **input_kwargs)
         while True:
             try:
@@ -115,7 +109,7 @@ class ChatModel:
 
     async def astream_chat(
         self,
-        messages: Sequence[Dict[str, str]],
+        messages: Sequence[dict[str, str]],
         system: Optional[str] = None,
         tools: Optional[str] = None,
         images: Optional[Sequence["ImageInput"]] = None,
@@ -123,9 +117,7 @@ class ChatModel:
         audios: Optional[Sequence["AudioInput"]] = None,
         **input_kwargs,
     ) -> AsyncGenerator[str, None]:
-        r"""
-        Asynchronously gets the response token-by-token of the chat model.
-        """
+        r"""Asynchronously get the response token-by-token of the chat model."""
         async for new_token in self.engine.stream_chat(
             messages, system, tools, images, videos, audios, **input_kwargs
         ):
@@ -133,23 +125,19 @@ class ChatModel:
 
     def get_scores(
         self,
-        batch_input: List[str],
+        batch_input: list[str],
         **input_kwargs,
-    ) -> List[float]:
-        r"""
-        Gets a list of scores of the reward model.
-        """
+    ) -> list[float]:
+        r"""Get a list of scores of the reward model."""
         task = asyncio.run_coroutine_threadsafe(self.aget_scores(batch_input, **input_kwargs), self._loop)
         return task.result()
 
     async def aget_scores(
         self,
-        batch_input: List[str],
+        batch_input: list[str],
         **input_kwargs,
-    ) -> List[float]:
-        r"""
-        Asynchronously gets a list of scores of the reward model.
-        """
+    ) -> list[float]:
+        r"""Asynchronously get a list of scores of the reward model."""
         return await self.engine.get_scores(batch_input, **input_kwargs)
@@ -15,8 +15,9 @@
 import asyncio
 import concurrent.futures
 import os
+from collections.abc import AsyncGenerator, Sequence
 from threading import Thread
-from typing import TYPE_CHECKING, Any, AsyncGenerator, Callable, Dict, List, Optional, Sequence, Tuple, Union
+from typing import TYPE_CHECKING, Any, Callable, Optional, Union
 
 import torch
 from transformers import GenerationConfig, TextIteratorStreamer
@@ -76,15 +77,15 @@ class HuggingfaceEngine(BaseEngine):
         tokenizer: "PreTrainedTokenizer",
         processor: Optional["ProcessorMixin"],
         template: "Template",
-        generating_args: Dict[str, Any],
-        messages: Sequence[Dict[str, str]],
+        generating_args: dict[str, Any],
+        messages: Sequence[dict[str, str]],
         system: Optional[str] = None,
         tools: Optional[str] = None,
         images: Optional[Sequence["ImageInput"]] = None,
         videos: Optional[Sequence["VideoInput"]] = None,
         audios: Optional[Sequence["AudioInput"]] = None,
-        input_kwargs: Optional[Dict[str, Any]] = {},
-    ) -> Tuple[Dict[str, Any], int]:
+        input_kwargs: Optional[dict[str, Any]] = {},
+    ) -> tuple[dict[str, Any], int]:
         mm_input_dict = {"images": [], "videos": [], "audios": [], "imglens": [0], "vidlens": [0], "audlens": [0]}
         if images is not None:
             mm_input_dict.update({"images": images, "imglens": [len(images)]})
@@ -130,7 +131,7 @@ class HuggingfaceEngine(BaseEngine):
         skip_special_tokens: Optional[bool] = input_kwargs.pop("skip_special_tokens", None)
         max_length: Optional[int] = input_kwargs.pop("max_length", None)
         max_new_tokens: Optional[int] = input_kwargs.pop("max_new_tokens", None)
-        stop: Optional[Union[str, List[str]]] = input_kwargs.pop("stop", None)
+        stop: Optional[Union[str, list[str]]] = input_kwargs.pop("stop", None)
 
         if stop is not None:
             logger.warning_rank0("Stop parameter is not supported by the huggingface engine yet.")
@@ -217,15 +218,15 @@ class HuggingfaceEngine(BaseEngine):
         tokenizer: "PreTrainedTokenizer",
         processor: Optional["ProcessorMixin"],
         template: "Template",
-        generating_args: Dict[str, Any],
-        messages: Sequence[Dict[str, str]],
+        generating_args: dict[str, Any],
+        messages: Sequence[dict[str, str]],
         system: Optional[str] = None,
         tools: Optional[str] = None,
         images: Optional[Sequence["ImageInput"]] = None,
         videos: Optional[Sequence["VideoInput"]] = None,
         audios: Optional[Sequence["AudioInput"]] = None,
-        input_kwargs: Optional[Dict[str, Any]] = {},
-    ) -> List["Response"]:
+        input_kwargs: Optional[dict[str, Any]] = {},
+    ) -> list["Response"]:
         gen_kwargs, prompt_length = HuggingfaceEngine._process_args(
             model,
             tokenizer,
@@ -272,14 +273,14 @@ class HuggingfaceEngine(BaseEngine):
         tokenizer: "PreTrainedTokenizer",
         processor: Optional["ProcessorMixin"],
         template: "Template",
-        generating_args: Dict[str, Any],
-        messages: Sequence[Dict[str, str]],
+        generating_args: dict[str, Any],
+        messages: Sequence[dict[str, str]],
         system: Optional[str] = None,
         tools: Optional[str] = None,
         images: Optional[Sequence["ImageInput"]] = None,
         videos: Optional[Sequence["VideoInput"]] = None,
         audios: Optional[Sequence["AudioInput"]] = None,
-        input_kwargs: Optional[Dict[str, Any]] = {},
+        input_kwargs: Optional[dict[str, Any]] = {},
     ) -> Callable[[], str]:
         gen_kwargs, _ = HuggingfaceEngine._process_args(
             model,
@@ -317,12 +318,12 @@ class HuggingfaceEngine(BaseEngine):
     def _get_scores(
         model: "PreTrainedModelWrapper",
         tokenizer: "PreTrainedTokenizer",
-        batch_input: List[str],
-        input_kwargs: Optional[Dict[str, Any]] = {},
-    ) -> List[float]:
+        batch_input: list[str],
+        input_kwargs: Optional[dict[str, Any]] = {},
+    ) -> list[float]:
         max_length: Optional[int] = input_kwargs.pop("max_length", None)
         device = getattr(model.pretrained_model, "device", "cuda")
-        inputs: Dict[str, "torch.Tensor"] = tokenizer(
+        inputs: dict[str, torch.Tensor] = tokenizer(
             batch_input,
             padding=True,
             truncation=True,
@@ -330,21 +331,21 @@ class HuggingfaceEngine(BaseEngine):
             return_tensors="pt",
             add_special_tokens=False,
         ).to(device)
-        values: "torch.Tensor" = model(**inputs, return_dict=True, use_cache=False)[-1]
+        values: torch.Tensor = model(**inputs, return_dict=True, use_cache=False)[-1]
         scores = values.gather(dim=-1, index=(inputs["attention_mask"].sum(dim=-1, keepdim=True) - 1))
         return scores
 
     @override
     async def chat(
         self,
-        messages: Sequence[Dict[str, str]],
+        messages: Sequence[dict[str, str]],
         system: Optional[str] = None,
         tools: Optional[str] = None,
         images: Optional[Sequence["ImageInput"]] = None,
         videos: Optional[Sequence["VideoInput"]] = None,
         audios: Optional[Sequence["AudioInput"]] = None,
         **input_kwargs,
-    ) -> List["Response"]:
+    ) -> list["Response"]:
         if not self.can_generate:
             raise ValueError("The current model does not support `chat`.")
 
@@ -370,7 +371,7 @@ class HuggingfaceEngine(BaseEngine):
     @override
     async def stream_chat(
         self,
-        messages: Sequence[Dict[str, str]],
+        messages: Sequence[dict[str, str]],
         system: Optional[str] = None,
         tools: Optional[str] = None,
         images: Optional[Sequence["ImageInput"]] = None,
@@ -408,9 +409,9 @@ class HuggingfaceEngine(BaseEngine):
     @override
     async def get_scores(
         self,
-        batch_input: List[str],
+        batch_input: list[str],
         **input_kwargs,
-    ) -> List[float]:
+    ) -> list[float]:
         if self.can_generate:
             raise ValueError("Cannot get scores using an auto-regressive model.")
@@ -13,7 +13,8 @@
 # limitations under the License.
 
 import uuid
-from typing import TYPE_CHECKING, Any, AsyncGenerator, AsyncIterator, Dict, List, Optional, Sequence, Union
+from collections.abc import AsyncGenerator, AsyncIterator, Sequence
+from typing import TYPE_CHECKING, Any, Optional, Union
 
 from typing_extensions import override
 
@@ -53,7 +54,7 @@ class VllmEngine(BaseEngine):
         self.model_args = model_args
         config = load_config(model_args)  # may download model from ms hub
         if getattr(config, "quantization_config", None):  # gptq models should use float16
-            quantization_config: Dict[str, Any] = getattr(config, "quantization_config", None)
+            quantization_config: dict[str, Any] = getattr(config, "quantization_config", None)
             quant_method = quantization_config.get("quant_method", "")
             if quant_method == QuantizationMethod.GPTQ and model_args.infer_dtype == "auto":
                 model_args.infer_dtype = "float16"
@@ -101,7 +102,7 @@ class VllmEngine(BaseEngine):
 
     async def _generate(
         self,
-        messages: Sequence[Dict[str, str]],
+        messages: Sequence[dict[str, str]],
         system: Optional[str] = None,
         tools: Optional[str] = None,
         images: Optional[Sequence["ImageInput"]] = None,
@@ -143,7 +144,7 @@ class VllmEngine(BaseEngine):
         skip_special_tokens: Optional[bool] = input_kwargs.pop("skip_special_tokens", None)
         max_length: Optional[int] = input_kwargs.pop("max_length", None)
         max_new_tokens: Optional[int] = input_kwargs.pop("max_new_tokens", None)
-        stop: Optional[Union[str, List[str]]] = input_kwargs.pop("stop", None)
+        stop: Optional[Union[str, list[str]]] = input_kwargs.pop("stop", None)
 
         if length_penalty is not None:
             logger.warning_rank0("Length penalty is not supported by the vllm engine yet.")
@@ -201,14 +202,14 @@ class VllmEngine(BaseEngine):
     @override
     async def chat(
         self,
-        messages: Sequence[Dict[str, str]],
+        messages: Sequence[dict[str, str]],
         system: Optional[str] = None,
         tools: Optional[str] = None,
         images: Optional[Sequence["ImageInput"]] = None,
         videos: Optional[Sequence["VideoInput"]] = None,
         audios: Optional[Sequence["AudioInput"]] = None,
         **input_kwargs,
-    ) -> List["Response"]:
+    ) -> list["Response"]:
         final_output = None
         generator = await self._generate(messages, system, tools, images, videos, audios, **input_kwargs)
         async for request_output in generator:
@@ -230,7 +231,7 @@ class VllmEngine(BaseEngine):
     @override
     async def stream_chat(
         self,
-        messages: Sequence[Dict[str, str]],
+        messages: Sequence[dict[str, str]],
         system: Optional[str] = None,
         tools: Optional[str] = None,
         images: Optional[Sequence["ImageInput"]] = None,
@@ -248,7 +249,7 @@ class VllmEngine(BaseEngine):
     @override
     async def get_scores(
         self,
-        batch_input: List[str],
+        batch_input: list[str],
         **input_kwargs,
-    ) -> List[float]:
+    ) -> list[float]:
         raise NotImplementedError("vLLM engine does not support get_scores.")
@@ -24,14 +24,14 @@ from .template import TEMPLATES, Template, get_template_and_fix_tokenizer
 
 
 __all__ = [
+    "TEMPLATES",
     "KTODataCollatorWithPadding",
     "MultiModalDataCollatorForSeq2Seq",
     "PairwiseDataCollatorWithPadding",
+    "SFTDataCollatorWith4DAttentionMask",
     "Role",
-    "split_dataset",
-    "get_dataset",
-    "TEMPLATES",
-    "SFTDataCollatorWith4DAttentionMask",
     "Template",
+    "get_dataset",
     "get_template_and_fix_tokenizer",
+    "split_dataset",
 ]
@@ -15,8 +15,9 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from collections.abc import Sequence
 from dataclasses import dataclass
-from typing import TYPE_CHECKING, Any, Dict, Literal, Optional, Sequence
+from typing import TYPE_CHECKING, Any, Literal, Optional
 
 import numpy as np
 import torch
@@ -38,9 +39,10 @@ if TYPE_CHECKING:
 
 
 def prepare_4d_attention_mask(attention_mask_with_indices: "torch.Tensor", dtype: "torch.dtype") -> "torch.Tensor":
-    r"""
-    Expands the attention mask with indices from (batch_size, seq_len) to (batch_size, 1, seq_len, seq_len),
-    while handles packed sequences and transforms the mask to lower triangular form to prevent future peeking.
+    r"""Expand 2d attention mask to 4d attention mask.
+
+    Expand the attention mask with indices from (batch_size, seq_len) to (batch_size, 1, seq_len, seq_len),
+    handle packed sequences and transforms the mask to lower triangular form to prevent future peeking.
 
     e.g.
     ```python
@@ -78,8 +80,7 @@ def prepare_4d_attention_mask(attention_mask_with_indices: "torch.Tensor", dtype
 
 @dataclass
 class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
-    r"""
-    Data collator that supports VLMs.
+    r"""Data collator that supports VLMs.
 
     Features should contain input_ids, attention_mask, labels, and optionally contain images, videos and audios.
     """
@@ -91,7 +92,7 @@ class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
         if self.template is None:
             raise ValueError("Template is required for MultiModalDataCollator.")
 
-    def __call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, "torch.Tensor"]:
+    def __call__(self, features: Sequence[dict[str, Any]]) -> dict[str, "torch.Tensor"]:
         batch_images, batch_videos, batch_audios = [], [], []
         batch_imglens, batch_vidlens, batch_audlens, batch_input_ids = [], [], [], []
         for feature in features:
@@ -166,7 +167,7 @@ class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
             for i, feature in enumerate(features):
                 feature["token_type_ids"] = token_type_ids[i]
 
-        features: Dict[str, "torch.Tensor"] = super().__call__(features)
+        features: dict[str, torch.Tensor] = super().__call__(features)
 
         if self.model is not None and hasattr(self.model, "get_rope_index"):  # for qwen2vl mrope
             rope_index_kwargs = {
@@ -198,15 +199,13 @@ class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
 
 @dataclass
 class SFTDataCollatorWith4DAttentionMask(MultiModalDataCollatorForSeq2Seq):
-    r"""
-    Data collator for 4d attention mask.
-    """
+    r"""Data collator for 4d attention mask."""
 
     block_diag_attn: bool = False
     attn_implementation: Literal["eager", "sdpa", "flash_attention_2"] = "eager"
     compute_dtype: "torch.dtype" = torch.float32
 
-    def __call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, "torch.Tensor"]:
+    def __call__(self, features: Sequence[dict[str, Any]]) -> dict[str, "torch.Tensor"]:
         features = super().__call__(features)
         if self.block_diag_attn and self.attn_implementation != "flash_attention_2":
             features["attention_mask"] = prepare_4d_attention_mask(features["attention_mask"], self.compute_dtype)
@@ -220,13 +219,10 @@ class SFTDataCollatorWith4DAttentionMask(MultiModalDataCollatorForSeq2Seq):
 
 @dataclass
 class PairwiseDataCollatorWithPadding(MultiModalDataCollatorForSeq2Seq):
-    r"""
-    Data collator for pairwise data.
-    """
+    r"""Data collator for pairwise data."""
 
-    def __call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, "torch.Tensor"]:
-        r"""
-        Pads batched data to the longest sequence in the batch.
+    def __call__(self, features: Sequence[dict[str, Any]]) -> dict[str, "torch.Tensor"]:
+        r"""Pad batched data to the longest sequence in the batch.
 
         We generate 2 * n examples where the first n examples represent chosen examples and
         the last n examples represent rejected examples.
@@ -249,11 +245,9 @@ class PairwiseDataCollatorWithPadding(MultiModalDataCollatorForSeq2Seq):
 
 @dataclass
 class KTODataCollatorWithPadding(MultiModalDataCollatorForSeq2Seq):
-    r"""
-    Data collator for KTO data.
-    """
+    r"""Data collator for KTO data."""
 
-    def __call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, "torch.Tensor"]:
+    def __call__(self, features: Sequence[dict[str, Any]]) -> dict[str, "torch.Tensor"]:
         target_features = []
         kl_features = []
         kto_tags = []
@@ -14,8 +14,9 @@
 
 import os
 from abc import abstractmethod
+from collections.abc import Sequence
 from dataclasses import dataclass
-from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Type, Union
+from typing import TYPE_CHECKING, Any, Optional, Union
 
 from ..extras import logging
 from .data_utils import Role
@@ -36,10 +37,8 @@ class DatasetConverter:
     dataset_attr: "DatasetAttr"
     data_args: "DataArguments"
 
-    def _find_medias(self, medias: Union[Any, Sequence[Any]]) -> Optional[List[Any]]:
-        r"""
-        Optionally concatenates media path to media dir when loading from local disk.
-        """
+    def _find_medias(self, medias: Union[Any, Sequence[Any]]) -> Optional[list[Any]]:
+        r"""Optionally concatenate media path to media dir when loading from local disk."""
         if not isinstance(medias, list):
             medias = [medias] if medias is not None else []
         elif len(medias) == 0:
@@ -57,16 +56,14 @@ class DatasetConverter:
         return medias
 
     @abstractmethod
-    def __call__(self, example: Dict[str, Any]) -> Dict[str, Any]:
-        r"""
-        Converts a single example in the dataset to the standard format.
-        """
+    def __call__(self, example: dict[str, Any]) -> dict[str, Any]:
+        r"""Convert a single example in the dataset to the standard format."""
        ...
 
 
 @dataclass
 class AlpacaDatasetConverter(DatasetConverter):
-    def __call__(self, example: Dict[str, Any]) -> Dict[str, Any]:
+    def __call__(self, example: dict[str, Any]) -> dict[str, Any]:
         prompt = []
         if self.dataset_attr.history and isinstance(example[self.dataset_attr.history], list):
             for old_prompt, old_response in example[self.dataset_attr.history]:
@@ -116,7 +113,7 @@ class AlpacaDatasetConverter(DatasetConverter):
 
 @dataclass
 class SharegptDatasetConverter(DatasetConverter):
-    def __call__(self, example: Dict[str, Any]) -> Dict[str, Any]:
+    def __call__(self, example: dict[str, Any]) -> dict[str, Any]:
         tag_mapping = {
             self.dataset_attr.user_tag: Role.USER.value,
             self.dataset_attr.assistant_tag: Role.ASSISTANT.value,
@@ -216,10 +213,8 @@ DATASET_CONVERTERS = {
 }
 
 
-def register_dataset_converter(name: str, dataset_converter: Type["DatasetConverter"]) -> None:
-    r"""
-    Register a new dataset converter.
-    """
+def register_dataset_converter(name: str, dataset_converter: type["DatasetConverter"]) -> None:
+    r"""Register a new dataset converter."""
     if name in DATASET_CONVERTERS:
         raise ValueError(f"Dataset converter {name} already exists.")
 
@@ -227,9 +222,7 @@ def register_dataset_converter(name: str, dataset_converter: Type["DatasetConver
 
 
 def get_dataset_converter(name: str, dataset_attr: "DatasetAttr", data_args: "DataArguments") -> "DatasetConverter":
-    r"""
-    Gets a dataset converter.
-    """
+    r"""Get a dataset converter."""
     if name not in DATASET_CONVERTERS:
         raise ValueError(f"Dataset converter {name} not found.")
 
@@ -242,17 +235,17 @@ def align_dataset(
     data_args: "DataArguments",
     training_args: "Seq2SeqTrainingArguments",
 ) -> Union["Dataset", "IterableDataset"]:
-    r"""
-    Aligned dataset:
-        _prompt: [{"role": "user", "content": "..."}] * (2T - 1)
-        _response: [{"role": "assistant", "content": "..."}] * N (N > 1 for ranking dataset)
-        _system: "..."
-        _tools: "...",
-        _images: [],
-        _videos: [],
-        _audios: [],
-    """
+    r"""Align the dataset to a specific format.
+
+    Aligned dataset:
+        _prompt: [{"role": "user", "content": "..."}] * (2T - 1)
+        _response: [{"role": "assistant", "content": "..."}] * N (N > 1 for ranking dataset)
+        _system: "..."
+        _tools: "..."
+        _images: []
+        _videos: []
+        _audios: []
+    """
     column_names = list(next(iter(dataset)).keys())
     kwargs = {}
     if not data_args.streaming:
@@ -12,8 +12,9 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from collections.abc import Sequence
 from enum import Enum, unique
-from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Set, TypedDict, Union
+from typing import TYPE_CHECKING, Optional, TypedDict, Union
 
 from datasets import DatasetDict, concatenate_datasets, interleave_datasets
 
@@ -29,7 +30,7 @@ if TYPE_CHECKING:
 logger = logging.get_logger(__name__)
 
 
-SLOTS = Sequence[Union[str, Set[str], Dict[str, str]]]
+SLOTS = Sequence[Union[str, set[str], dict[str, str]]]
 
 
 @unique
@@ -43,15 +44,13 @@ class Role(str, Enum):
 
 class DatasetModule(TypedDict):
     train_dataset: Optional[Union["Dataset", "IterableDataset"]]
-    eval_dataset: Optional[Union["Dataset", "IterableDataset", Dict[str, "Dataset"]]]
+    eval_dataset: Optional[Union["Dataset", "IterableDataset", dict[str, "Dataset"]]]
 
 
 def merge_dataset(
-    all_datasets: List[Union["Dataset", "IterableDataset"]], data_args: "DataArguments", seed: int
+    all_datasets: list[Union["Dataset", "IterableDataset"]], data_args: "DataArguments", seed: int
 ) -> Union["Dataset", "IterableDataset"]:
-    r"""
-    Merges multiple datasets to a unified dataset.
-    """
+    r"""Merge multiple datasets to a unified dataset."""
     if len(all_datasets) == 1:
         return all_datasets[0]
 
@@ -78,14 +77,13 @@
 
 def split_dataset(
     dataset: Optional[Union["Dataset", "IterableDataset"]],
-    eval_dataset: Optional[Union["Dataset", "IterableDataset", Dict[str, "Dataset"]]],
+    eval_dataset: Optional[Union["Dataset", "IterableDataset", dict[str, "Dataset"]]],
     data_args: "DataArguments",
     seed: int,
 ) -> "DatasetDict":
-    r"""
-    Splits the dataset and returns a dataset dict containing train set and validation set.
+    r"""Split the dataset and returns a dataset dict containing train set and validation set.
 
-    Supports both map dataset and iterable dataset.
+    Support both map dataset and iterable dataset.
     """
     if eval_dataset is not None and data_args.val_size > 1e-6:
         raise ValueError("Cannot specify `val_size` if `eval_dataset` is not None.")
@@ -120,10 +118,8 @@ def split_dataset(
 
 
 def get_dataset_module(dataset: Union["Dataset", "DatasetDict"]) -> "DatasetModule":
-    r"""
-    Converts dataset or dataset dict to dataset module.
-    """
-    dataset_module: "DatasetModule" = {}
+    r"""Convert dataset or dataset dict to dataset module."""
+    dataset_module: DatasetModule = {}
     if isinstance(dataset, DatasetDict):  # dataset dict
         if "train" in dataset:
             dataset_module["train_dataset"] = dataset["train"]
@@ -16,7 +16,7 @@ import json
 import re
 from abc import ABC, abstractmethod
 from dataclasses import dataclass, field
-from typing import List, Optional, Union
+from typing import Optional, Union
 
 from typing_extensions import override
 
@@ -31,14 +31,11 @@ class Formatter(ABC):
 
     @abstractmethod
     def apply(self, **kwargs) -> SLOTS:
-        r"""
-        Forms a list of slots according to the inputs to encode.
-        """
+        r"""Forms a list of slots according to the inputs to encode."""
         ...
 
-    def extract(self, content: str) -> Union[str, List["FunctionCall"]]:
-        r"""
-        Extract a list of tuples from the response message if using tools.
+    def extract(self, content: str) -> Union[str, list["FunctionCall"]]:
+        r"""Extract a list of tuples from the response message if using tools.
 
         Each tuple consists of function name and function arguments.
         """
@@ -105,7 +102,7 @@ class FunctionFormatter(StringFormatter):
         if thought:
             content = content.replace(thought.group(0), "")
 
-        functions: List["FunctionCall"] = []
+        functions: list[FunctionCall] = []
         try:
             tool_calls = json.loads(content)
             if not isinstance(tool_calls, list):  # parallel function call
@@ -141,5 +138,5 @@ class ToolFormatter(Formatter):
             raise RuntimeError(f"Invalid JSON format in tool description: {str([content])}.")  # flat string
 
     @override
-    def extract(self, content: str) -> Union[str, List["FunctionCall"]]:
+    def extract(self, content: str) -> Union[str, list["FunctionCall"]]:
         return self.tool_utils.tool_extractor(content)
@@ -13,7 +13,8 @@
 # limitations under the License.
 
 import os
-from typing import TYPE_CHECKING, Dict, Literal, Optional, Sequence, Union
+from collections.abc import Sequence
+from typing import TYPE_CHECKING, Literal, Optional, Union
 
 import numpy as np
 from datasets import load_dataset, load_from_disk
@@ -54,9 +55,7 @@ def _load_single_dataset(
     data_args: "DataArguments",
     training_args: "Seq2SeqTrainingArguments",
 ) -> Union["Dataset", "IterableDataset"]:
-    r"""
-    Loads a single dataset and aligns it to the standard format.
-    """
+    r"""Load a single dataset and aligns it to the standard format."""
     logger.info_rank0(f"Loading dataset {dataset_attr}...")
     data_path, data_name, data_dir, data_files = None, None, None, None
     if dataset_attr.load_from in ["hf_hub", "ms_hub", "om_hub"]:
@@ -164,10 +163,8 @@ def _get_merged_dataset(
     training_args: "Seq2SeqTrainingArguments",
     stage: Literal["pt", "sft", "rm", "ppo", "kto"],
     merge: bool = True,
-) -> Optional[Union["Dataset", "IterableDataset", Dict[str, "Dataset"]]]:
-    r"""
-    Returns the merged datasets in the standard format.
-    """
+) -> Optional[Union["Dataset", "IterableDataset", dict[str, "Dataset"]]]:
+    r"""Return the merged datasets in the standard format."""
     if dataset_names is None:
         return None
 
@@ -192,9 +189,7 @@ def _get_dataset_processor(
     processor: Optional["ProcessorMixin"],
     do_generate: bool = False,
 ) -> "DatasetProcessor":
-    r"""
-    Returns the corresponding dataset processor.
-    """
+    r"""Return the corresponding dataset processor."""
     if stage == "pt":
         dataset_processor_class = PretrainDatasetProcessor
     elif stage == "sft" and not do_generate:
@@ -236,9 +231,7 @@ def _get_preprocessed_dataset(
     processor: Optional["ProcessorMixin"] = None,
     is_eval: bool = False,
 ) -> Optional[Union["Dataset", "IterableDataset"]]:
-    r"""
-    Preprocesses the dataset, including format checking and tokenization.
-    """
+    r"""Preprocesses the dataset, including format checking and tokenization."""
     if dataset is None:
         return None
 
@@ -284,9 +277,7 @@ def get_dataset(
     tokenizer: "PreTrainedTokenizer",
     processor: Optional["ProcessorMixin"] = None,
 ) -> "DatasetModule":
-    r"""
-    Gets the train dataset and optionally gets the evaluation dataset.
-    """
+    r"""Get the train dataset and optionally gets the evaluation dataset."""
     # Load tokenized dataset if path exists
     if data_args.tokenized_path is not None:
         if has_tokenized_data(data_args.tokenized_path):
@@ -1,10 +1,11 @@
 import inspect
 import math
 import re
+from collections.abc import Sequence
 from copy import deepcopy
 from dataclasses import dataclass
 from io import BytesIO
-from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Tuple, Type, TypedDict, Union
+from typing import TYPE_CHECKING, Optional, TypedDict, Union
 
 import numpy as np
 import torch
@@ -58,12 +59,12 @@ if TYPE_CHECKING:
 
 def _get_paligemma_token_type_ids(
     imglens: Sequence[int], seqlens: Sequence[int], processor: "ProcessorMixin"
-) -> List[List[int]]:
-    r"""
-    Gets paligemma token type ids for computing loss.
+) -> list[list[int]]:
+    r"""Get paligemma token type ids for computing loss.
 
     Returns:
         batch_token_type_ids: shape (batch_size, sequence_length)
 
     """
    batch_token_type_ids = []
     for imglen, seqlen in zip(imglens, seqlens):
@@ -87,11 +88,9 @@ class MMPluginMixin:
         videos: Sequence["VideoInput"],
         audios: Sequence["AudioInput"],
     ) -> None:
-        r"""
-        Validates if this model accepts the input modalities.
-        """
-        image_processor: "BaseImageProcessor" = getattr(processor, "image_processor", None)
-        feature_extractor: "SequenceFeatureExtractor" = getattr(processor, "feature_extractor", None)
+        r"""Validate if this model accepts the input modalities."""
+        image_processor: BaseImageProcessor = getattr(processor, "image_processor", None)
+        feature_extractor: SequenceFeatureExtractor = getattr(processor, "feature_extractor", None)
         if len(images) != 0 and self.image_token is None:
             raise ValueError(
                 "This model does not support image input. Please check whether the correct `template` is used."
@@ -119,9 +118,7 @@ class MMPluginMixin:
     def _preprocess_image(
         self, image: "ImageObject", image_max_pixels: int, image_min_pixels: int, **kwargs
     ) -> "ImageObject":
-        r"""
-        Pre-processes a single image.
-        """
+        r"""Pre-process a single image."""
         if (image.width * image.height) > image_max_pixels:
             resize_factor = math.sqrt(image_max_pixels / (image.width * image.height))
             width, height = int(image.width * resize_factor), int(image.height * resize_factor)
@@ -139,10 +136,8 @@ class MMPluginMixin:
 
     def _get_video_sample_indices(
         self, video_stream: "Stream", video_fps: float, video_maxlen: int, **kwargs
-    ) -> List[int]:
-        r"""
-        Computes video sample indices according to fps.
-        """
+    ) -> list[int]:
+        r"""Compute video sample indices according to fps."""
         total_frames = video_stream.frames
         if total_frames == 0:  # infinite video
             return np.linspace(0, video_maxlen - 1, video_maxlen).astype(np.int32)
@@ -151,10 +146,8 @@ class MMPluginMixin:
         sample_frames = min(total_frames, video_maxlen, sample_frames)
         return np.linspace(0, total_frames - 1, sample_frames).astype(np.int32)
 
-    def _regularize_images(self, images: Sequence["ImageInput"], **kwargs) -> List["ImageObject"]:
-        r"""
-        Regularizes images to avoid error. Including reading and pre-processing.
-        """
+    def _regularize_images(self, images: Sequence["ImageInput"], **kwargs) -> list["ImageObject"]:
+        r"""Regularize images to avoid error. Including reading and pre-processing."""
         results = []
         for image in images:
             if isinstance(image, str):
@@ -174,16 +167,14 @@ class MMPluginMixin:
 
         return results
 
-    def _regularize_videos(self, videos: Sequence["VideoInput"], **kwargs) -> List[List["ImageObject"]]:
-        r"""
-        Regularizes videos to avoid error. Including reading, resizing and converting.
-        """
+    def _regularize_videos(self, videos: Sequence["VideoInput"], **kwargs) -> list[list["ImageObject"]]:
+        r"""Regularizes videos to avoid error. Including reading, resizing and converting."""
         results = []
         for video in videos:
             container = av.open(video, "r")
             video_stream = next(stream for stream in container.streams if stream.type == "video")
             sample_indices = self._get_video_sample_indices(video_stream, **kwargs)
-            frames: List["ImageObject"] = []
+            frames: list[ImageObject] = []
             container.seek(0)
             for frame_idx, frame in enumerate(container.decode(video_stream)):
                 if frame_idx in sample_indices:
@@ -194,10 +185,8 @@ class MMPluginMixin:
 
         return results
 
-    def _regularize_audios(self, audios: Sequence["AudioInput"], sampling_rate: float, **kwargs) -> List["NDArray"]:
-        r"""
-        Regularizes audios to avoid error. Including reading and resampling.
-        """
+    def _regularize_audios(self, audios: Sequence["AudioInput"], sampling_rate: float, **kwargs) -> list["NDArray"]:
+        r"""Regularizes audios to avoid error. Including reading and resampling."""
         results = []
         for audio in audios:
             if isinstance(audio, str):
@@ -216,9 +205,8 @@ class MMPluginMixin:
         videos: Sequence["VideoInput"],
         audios: Sequence["AudioInput"],
         processor: "ProcessorMixin",
-    ) -> Dict[str, "torch.Tensor"]:
-        r"""
-        Processes visual inputs.
+    ) -> dict[str, "torch.Tensor"]:
+        r"""Process visual inputs.
 
         Returns: (llava and paligemma)
             pixel_values: tensor with shape (B, C, H, W)
@@ -229,9 +217,9 @@ class MMPluginMixin:
 
         It holds num_patches == torch.prod(image_grid_thw)
         """
-        image_processor: "BaseImageProcessor" = getattr(processor, "image_processor", None)
-        video_processor: "BaseImageProcessor" = getattr(processor, "video_processor", image_processor)
-        feature_extractor: "SequenceFeatureExtractor" = getattr(processor, "feature_extractor", None)
+        image_processor: BaseImageProcessor = getattr(processor, "image_processor", None)
+        video_processor: BaseImageProcessor = getattr(processor, "video_processor", image_processor)
+        feature_extractor: SequenceFeatureExtractor = getattr(processor, "feature_extractor", None)
         mm_inputs = {}
 
         if len(images) != 0:
@@ -278,31 +266,27 @@ class MMPluginMixin:
 class BasePlugin(MMPluginMixin):
     def process_messages(
         self,
-        messages: Sequence[Dict[str, str]],
+        messages: Sequence[dict[str, str]],
         images: Sequence["ImageInput"],
         videos: Sequence["VideoInput"],
         audios: Sequence["AudioInput"],
         processor: Optional["ProcessorMixin"],
-    ) -> List[Dict[str, str]]:
-        r"""
-        Pre-processes input messages before tokenization for VLMs.
-        """
+    ) -> list[dict[str, str]]:
+        r"""Pre-processes input messages before tokenization for VLMs."""
         self._validate_input(processor, images, videos, audios)
         return messages
 
     def process_token_ids(
         self,
-        input_ids: List[int],
-        labels: Optional[List[int]],
+        input_ids: list[int],
+        labels: Optional[list[int]],
         images: Sequence["ImageInput"],
         videos: Sequence["VideoInput"],
         audios: Sequence["AudioInput"],
         tokenizer: "PreTrainedTokenizer",
         processor: Optional["ProcessorMixin"],
-    ) -> Tuple[List[int], Optional[List[int]]]:
-        r"""
-        Pre-processes token ids after tokenization for VLMs.
-        """
+    ) -> tuple[list[int], Optional[list[int]]]:
+        r"""Pre-processes token ids after tokenization for VLMs."""
         self._validate_input(processor, images, videos, audios)
         return input_ids, labels
 
@@ -314,20 +298,21 @@ class BasePlugin(MMPluginMixin):
         imglens: Sequence[int],
         vidlens: Sequence[int],
         audlens: Sequence[int],
-        batch_ids: Sequence[List[int]],
+        batch_ids: Sequence[list[int]],
         processor: Optional["ProcessorMixin"],
-    ) -> Dict[str, Union[List[int], "torch.Tensor"]]:
-        r"""
-        Builds batched multimodal inputs for VLMs.
+    ) -> dict[str, Union[list[int], "torch.Tensor"]]:
+        r"""Build batched multimodal inputs for VLMs.
+
         Arguments:
             images: a list of image inputs, shape (num_images,)
             videos: a list of video inputs, shape (num_videos,)
             audios: a list of audio inputs, shape (num_audios,)
             imglens: number of images in each sample, shape (batch_size,)
             vidlens: number of videos in each sample, shape (batch_size,)
             audlens: number of audios in each sample, shape (batch_size,)
            batch_ids: token ids of input samples, shape (batch_size, seq_len)
             processor: a processor for pre-processing images and videos
+
         """
         self._validate_input(processor, images, videos, audios)
         return {}
@@ -338,12 +323,12 @@ class LlavaPlugin(BasePlugin):
     @override
     def process_messages(
         self,
-        messages: Sequence[Dict[str, str]],
+        messages: Sequence[dict[str, str]],
         images: Sequence["ImageInput"],
         videos: Sequence["VideoInput"],
         audios: Sequence["AudioInput"],
         processor: Optional["ProcessorMixin"],
-    ) -> List[Dict[str, str]]:
+    ) -> list[dict[str, str]]:
         self._validate_input(processor, images, videos, audios)
         num_image_tokens = 0
         image_seqlen = getattr(processor, "image_seqlen") if self.expand_mm_tokens else 1
@@ -370,9 +355,9 @@ class LlavaPlugin(BasePlugin):
         imglens: Sequence[int],
         vidlens: Sequence[int],
         audlens: Sequence[int],
-        batch_ids: Sequence[List[int]],
+        batch_ids: Sequence[list[int]],
         processor: Optional["ProcessorMixin"],
-    ) -> Dict[str, Union[List[int], "torch.Tensor"]]:
+    ) -> dict[str, Union[list[int], "torch.Tensor"]]:
         self._validate_input(processor, images, videos, audios)
         return self._get_mm_inputs(images, videos, audios, processor)
 
@@ -382,12 +367,12 @@ class LlavaNextPlugin(BasePlugin):
     @override
     def process_messages(
         self,
-        messages: Sequence[Dict[str, str]],
+        messages: Sequence[dict[str, str]],
         images: Sequence["ImageInput"],
         videos: Sequence["VideoInput"],
         audios: Sequence["AudioInput"],
         processor: Optional["ProcessorMixin"],
-    ) -> List[Dict[str, str]]:
+    ) -> list[dict[str, str]]:
         self._validate_input(processor, images, videos, audios)
         num_image_tokens = 0
         messages = deepcopy(messages)
@@ -426,9 +411,9 @@ class LlavaNextPlugin(BasePlugin):
         imglens: Sequence[int],
         vidlens: Sequence[int],
         audlens: Sequence[int],
-        batch_ids: Sequence[List[int]],
+        batch_ids: Sequence[list[int]],
         processor: Optional["ProcessorMixin"],
-    ) -> Dict[str, Union[List[int], "torch.Tensor"]]:
+    ) -> dict[str, Union[list[int], "torch.Tensor"]]:
         self._validate_input(processor, images, videos, audios)
         return self._get_mm_inputs(images, videos, audios, processor)
 
@@ -438,12 +423,12 @@ class LlavaNextVideoPlugin(BasePlugin):
     @override
     def process_messages(
         self,
-        messages: Sequence[Dict[str, str]],
+        messages: Sequence[dict[str, str]],
         images: Sequence["ImageInput"],
         videos: Sequence["VideoInput"],
         audios: Sequence["AudioInput"],
         processor: Optional["ProcessorMixin"],
-    ) -> List[Dict[str, str]]:
+    ) -> list[dict[str, str]]:
         self._validate_input(processor, images, videos, audios)
         num_image_tokens, num_video_tokens = 0, 0
         messages = deepcopy(messages)
@@ -502,9 +487,9 @@ class LlavaNextVideoPlugin(BasePlugin):
         imglens: Sequence[int],
         vidlens: Sequence[int],
         audlens: Sequence[int],
-        batch_ids: Sequence[List[int]],
+        batch_ids: Sequence[list[int]],
         processor: Optional["ProcessorMixin"],
-    ) -> Dict[str, Union[List[int], "torch.Tensor"]]:
+    ) -> dict[str, Union[list[int], "torch.Tensor"]]:
         self._validate_input(processor, images, videos, audios)
         return self._get_mm_inputs(images, videos, audios, processor)
 
@@ -514,16 +499,16 @@ class MiniCPMVPlugin(BasePlugin):
|
||||
@override
|
||||
def process_messages(
|
||||
self,
|
||||
messages: Sequence[Dict[str, str]],
|
||||
messages: Sequence[dict[str, str]],
|
||||
images: Sequence["ImageInput"],
|
||||
videos: Sequence["VideoInput"],
|
||||
audios: Sequence["AudioInput"],
|
||||
processor: Optional["ProcessorMixin"],
|
||||
) -> List[Dict[str, str]]:
|
||||
) -> list[dict[str, str]]:
|
||||
self._validate_input(processor, images, videos, audios)
|
||||
num_image_tokens, num_video_tokens, num_audio_tokens = 0, 0, 0
|
||||
messages = deepcopy(messages)
|
||||
image_processor: "BaseImageProcessor" = getattr(processor, "image_processor")
|
||||
image_processor: BaseImageProcessor = getattr(processor, "image_processor")
|
||||
mm_inputs = {}
|
||||
audio_inputs = {}
|
||||
if len(images) != 0 and len(videos) != 0:
|
||||
@@ -619,9 +604,9 @@ class MiniCPMVPlugin(BasePlugin):
audios: Sequence["AudioInput"],
processor: "ProcessorMixin",
**kwargs,
) -> Dict[str, "torch.Tensor"]:
image_processor: "BaseImageProcessor" = getattr(processor, "image_processor")
feature_extractor: "SequenceFeatureExtractor" = getattr(processor, "feature_extractor", None)
) -> dict[str, "torch.Tensor"]:
image_processor: BaseImageProcessor = getattr(processor, "image_processor")
feature_extractor: SequenceFeatureExtractor = getattr(processor, "feature_extractor", None)
mm_inputs = {}
if len(images) != 0:
images = self._regularize_images(
@@ -691,9 +676,9 @@ class MiniCPMVPlugin(BasePlugin):
imglens: Sequence[int],
vidlens: Sequence[int],
audlens: Sequence[int],
batch_ids: Sequence[List[int]],
batch_ids: Sequence[list[int]],
processor: Optional["ProcessorMixin"],
) -> Dict[str, Union[List[int], "torch.Tensor"]]:
) -> dict[str, Union[list[int], "torch.Tensor"]]:
self._validate_input(processor, images, videos, audios)
# image bound
image_bounds_list = []
@@ -756,12 +741,12 @@ class MllamaPlugin(BasePlugin):
@override
def process_messages(
self,
messages: Sequence[Dict[str, str]],
messages: Sequence[dict[str, str]],
images: Sequence["ImageInput"],
videos: Sequence["VideoInput"],
audios: Sequence["AudioInput"],
processor: Optional["ProcessorMixin"],
) -> List[Dict[str, str]]:
) -> list[dict[str, str]]:
self._validate_input(processor, images, videos, audios)
num_image_tokens = 0
messages = deepcopy(messages)
@@ -782,10 +767,9 @@ class MllamaPlugin(BasePlugin):
videos: Sequence["VideoInput"],
audios: Sequence["AudioInput"],
processor: "ProcessorMixin",
imglens: List[int],
) -> Dict[str, "torch.Tensor"]:
r"""
Processes visual inputs for mllama because its image processor only accepts List[List[ImageInput]].
imglens: list[int],
) -> dict[str, "torch.Tensor"]:
r"""Process visual inputs for mllama because its image processor only accepts List[List[ImageInput]].

Returns:
pixel_values: tensor with shape
@@ -794,8 +778,9 @@ class MllamaPlugin(BasePlugin):
aspect_ratio_ids: tensor with shape (batch_size, max_num_images). For example, (2, 1).
aspect_ratio_mask: tensor with shape (batch_size, max_num_images, max_image_tiles). For example, (2, 1, 4).
num_tiles: List[List[int]] with shape (batch_size, num_images_in_batch). For example, (2, 1).

"""
image_processor: "BaseImageProcessor" = getattr(processor, "image_processor")
image_processor: BaseImageProcessor = getattr(processor, "image_processor")
mm_inputs = {}
if len(images) > 0:
images = self._regularize_images(
@@ -821,9 +806,9 @@ class MllamaPlugin(BasePlugin):
imglens: Sequence[int],
vidlens: Sequence[int],
audlens: Sequence[int],
batch_ids: Sequence[List[int]],
batch_ids: Sequence[list[int]],
processor: Optional["ProcessorMixin"],
) -> Dict[str, Union[List[int], "torch.Tensor"]]:
) -> dict[str, Union[list[int], "torch.Tensor"]]:
self._validate_input(processor, images, videos, audios)
mm_inputs = self._get_mm_inputs(images, videos, audios, processor, imglens)
if mm_inputs:
@@ -850,12 +835,12 @@ class PaliGemmaPlugin(BasePlugin):
@override
def process_messages(
self,
messages: Sequence[Dict[str, str]],
messages: Sequence[dict[str, str]],
images: Sequence["ImageInput"],
videos: Sequence["VideoInput"],
audios: Sequence["AudioInput"],
processor: Optional["ProcessorMixin"],
) -> List[Dict[str, str]]:
) -> list[dict[str, str]]:
self._validate_input(processor, images, videos, audios)
num_image_tokens = 0
messages = deepcopy(messages)
@@ -875,14 +860,14 @@ class PaliGemmaPlugin(BasePlugin):
@override
def process_token_ids(
self,
input_ids: List[int],
labels: Optional[List[int]],
input_ids: list[int],
labels: Optional[list[int]],
images: Sequence["ImageInput"],
videos: Sequence["VideoInput"],
audios: Sequence["AudioInput"],
tokenizer: "PreTrainedTokenizer",
processor: Optional["ProcessorMixin"],
) -> Tuple[List[int], Optional[List[int]]]:
) -> tuple[list[int], Optional[list[int]]]:
self._validate_input(processor, images, videos, audios)
num_images = len(images)
image_seqlen = num_images * getattr(processor, "image_seqlen") if self.expand_mm_tokens else 0 # skip mm token
@@ -902,9 +887,9 @@ class PaliGemmaPlugin(BasePlugin):
imglens: Sequence[int],
vidlens: Sequence[int],
audlens: Sequence[int],
batch_ids: Sequence[List[int]],
batch_ids: Sequence[list[int]],
processor: Optional["ProcessorMixin"],
) -> Dict[str, Union[List[int], "torch.Tensor"]]:
) -> dict[str, Union[list[int], "torch.Tensor"]]:
self._validate_input(processor, images, videos, audios)
seqlens = [len(input_ids) for input_ids in batch_ids]
mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
@@ -917,12 +902,12 @@ class PixtralPlugin(BasePlugin):
@override
def process_messages(
self,
messages: Sequence[Dict[str, str]],
messages: Sequence[dict[str, str]],
images: Sequence["ImageInput"],
videos: Sequence["VideoInput"],
audios: Sequence["AudioInput"],
processor: Optional["ProcessorMixin"],
) -> List[Dict[str, str]]:
) -> list[dict[str, str]]:
self._validate_input(processor, images, videos, audios)
patch_size = getattr(processor, "patch_size")
image_token = getattr(processor, "image_token")
@@ -968,9 +953,9 @@ class PixtralPlugin(BasePlugin):
imglens: Sequence[int],
vidlens: Sequence[int],
audlens: Sequence[int],
batch_ids: Sequence[List[int]],
batch_ids: Sequence[list[int]],
processor: Optional["ProcessorMixin"],
) -> Dict[str, Union[List[int], "torch.Tensor"]]:
) -> dict[str, Union[list[int], "torch.Tensor"]]:
self._validate_input(processor, images, videos, audios)
mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
mm_inputs.pop("image_sizes", None)
@@ -982,12 +967,12 @@ class Qwen2AudioPlugin(BasePlugin):
@override
def process_messages(
self,
messages: Sequence[Dict[str, str]],
messages: Sequence[dict[str, str]],
images: Sequence["ImageInput"],
videos: Sequence["VideoInput"],
audios: Sequence["AudioInput"],
processor: Optional["ProcessorMixin"],
) -> List[Dict[str, str]]:
) -> list[dict[str, str]]:
self._validate_input(processor, images, videos, audios)
bos_token: str = getattr(processor, "audio_bos_token")
eos_token: str = getattr(processor, "audio_eos_token")
@@ -1028,9 +1013,9 @@ class Qwen2AudioPlugin(BasePlugin):
imglens: Sequence[int],
vidlens: Sequence[int],
audlens: Sequence[int],
batch_ids: Sequence[List[int]],
batch_ids: Sequence[list[int]],
processor: Optional["ProcessorMixin"],
) -> Dict[str, Union[List[int], "torch.Tensor"]]:
) -> dict[str, Union[list[int], "torch.Tensor"]]:
self._validate_input(processor, images, videos, audios)
return self._get_mm_inputs(images, videos, audios, processor)

@@ -1057,13 +1042,13 @@ class Qwen2VLPlugin(BasePlugin):
@override
def _regularize_videos(
self, videos: Sequence["VideoInput"], **kwargs
) -> Tuple[List[List["ImageObject"]], List[float]]:
) -> tuple[list[list["ImageObject"]], list[float]]:
results, fps_per_video = [], []
for video in videos:
container = av.open(video, "r")
video_stream = next(stream for stream in container.streams if stream.type == "video")
sample_indices = self._get_video_sample_indices(video_stream, **kwargs)
frames: List["ImageObject"] = []
frames: list[ImageObject] = []
container.seek(0)
for frame_idx, frame in enumerate(container.decode(video_stream)):
if frame_idx in sample_indices:
@@ -1088,8 +1073,8 @@ class Qwen2VLPlugin(BasePlugin):
videos: Sequence["VideoInput"],
audios: Sequence["AudioInput"],
processor: "ProcessorMixin",
) -> Dict[str, "torch.Tensor"]:
image_processor: "BaseImageProcessor" = getattr(processor, "image_processor", None)
) -> dict[str, "torch.Tensor"]:
image_processor: BaseImageProcessor = getattr(processor, "image_processor", None)
mm_inputs = {}
if len(images) != 0:
images = self._regularize_images(
@@ -1115,16 +1100,16 @@ class Qwen2VLPlugin(BasePlugin):
@override
def process_messages(
self,
messages: Sequence[Dict[str, str]],
messages: Sequence[dict[str, str]],
images: Sequence["ImageInput"],
videos: Sequence["VideoInput"],
audios: Sequence["AudioInput"],
processor: Optional["ProcessorMixin"],
) -> List[Dict[str, str]]:
) -> list[dict[str, str]]:
self._validate_input(processor, images, videos, audios)
num_image_tokens, num_video_tokens = 0, 0
messages = deepcopy(messages)
image_processor: "BaseImageProcessor" = getattr(processor, "image_processor")
image_processor: BaseImageProcessor = getattr(processor, "image_processor")

merge_length: int = getattr(image_processor, "merge_size") ** 2
if self.expand_mm_tokens:
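For context on the merge_length computed above: Qwen2-VL's image processor groups visual patches into merge_size x merge_size blocks, so each image placeholder expands to the patch-grid size divided by merge_length when expand_mm_tokens is enabled. The numbers below are illustrative assumptions, not values taken from this diff:

```python
# Hypothetical 16x16 patch grid for a single frame of one image.
grid_t, grid_h, grid_w = 1, 16, 16
merge_size = 2                # common Qwen2-VL setting; assumed here
merge_length = merge_size**2  # the `merge_size ** 2` factor in process_messages
num_image_tokens = (grid_t * grid_h * grid_w) // merge_length
print(num_image_tokens)       # 64 pad tokens replace one image placeholder
```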
@@ -1176,13 +1161,13 @@ class Qwen2VLPlugin(BasePlugin):
imglens: Sequence[int],
vidlens: Sequence[int],
audlens: Sequence[int],
batch_ids: Sequence[List[int]],
batch_ids: Sequence[list[int]],
processor: Optional["ProcessorMixin"],
) -> Dict[str, Union[List[int], "torch.Tensor"]]:
) -> dict[str, Union[list[int], "torch.Tensor"]]:
self._validate_input(processor, images, videos, audios)
mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
fps_per_video = mm_inputs.pop("fps_per_video", [])
image_processor: "BaseImageProcessor" = getattr(processor, "image_processor")
image_processor: BaseImageProcessor = getattr(processor, "image_processor")
if "second_per_grid_ts" in processor.model_input_names and fps_per_video:
mm_inputs["second_per_grid_ts"] = [image_processor.temporal_patch_size / fps for fps in fps_per_video]

@@ -1194,12 +1179,12 @@ class VideoLlavaPlugin(BasePlugin):
@override
def process_messages(
self,
messages: Sequence[Dict[str, str]],
messages: Sequence[dict[str, str]],
images: Sequence["ImageInput"],
videos: Sequence["VideoInput"],
audios: Sequence["AudioInput"],
processor: Optional["ProcessorMixin"],
) -> List[Dict[str, str]]:
) -> list[dict[str, str]]:
self._validate_input(processor, images, videos, audios)
num_image_tokens, num_video_tokens = 0, 0
messages = deepcopy(messages)
@@ -1255,9 +1240,9 @@ class VideoLlavaPlugin(BasePlugin):
imglens: Sequence[int],
vidlens: Sequence[int],
audlens: Sequence[int],
batch_ids: Sequence[List[int]],
batch_ids: Sequence[list[int]],
processor: Optional["ProcessorMixin"],
) -> Dict[str, Union[List[int], "torch.Tensor"]]:
) -> dict[str, Union[list[int], "torch.Tensor"]]:
self._validate_input(processor, images, videos, audios)
return self._get_mm_inputs(images, videos, audios, processor)

@@ -1277,10 +1262,8 @@ PLUGINS = {
}


def register_mm_plugin(name: str, plugin_class: Type["BasePlugin"]) -> None:
r"""
Registers a multimodal plugin.
"""
def register_mm_plugin(name: str, plugin_class: type["BasePlugin"]) -> None:
r"""Register a multimodal plugin."""
if name in PLUGINS:
raise ValueError(f"Multimodal plugin {name} already exists.")

@@ -1293,9 +1276,7 @@ def get_mm_plugin(
video_token: Optional[str] = None,
audio_token: Optional[str] = None,
) -> "BasePlugin":
r"""
Gets plugin for multimodal inputs.
"""
r"""Get plugin for multimodal inputs."""
if name not in PLUGINS:
raise ValueError(f"Multimodal plugin `{name}` not found.")

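register_mm_plugin and get_mm_plugin above form a plain name-to-class registry. A self-contained sketch of the same pattern, with illustrative names rather than the real plugin API:

```python
from typing import Optional

REGISTRY: dict[str, type] = {}

def register(name: str, cls: type) -> None:
    # Fail loudly on duplicate registration.
    if name in REGISTRY:
        raise ValueError(f"Plugin {name} already exists.")
    REGISTRY[name] = cls

def get(name: str, **kwargs) -> object:
    # Fail loudly on unknown names at lookup time.
    if name not in REGISTRY:
        raise ValueError(f"Plugin `{name}` not found.")
    return REGISTRY[name](**kwargs)

class BasePlugin:
    def __init__(self, image_token: Optional[str] = None) -> None:
        self.image_token = image_token

register("base", BasePlugin)
plugin = get("base", image_token="<image>")  # instantiate by name
```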
@@ -14,8 +14,9 @@

import json
import os
from collections.abc import Sequence
from dataclasses import dataclass
from typing import Any, Dict, List, Literal, Optional, Sequence
from typing import Any, Literal, Optional

from transformers.utils import cached_file

@@ -25,9 +26,7 @@ from ..extras.misc import use_modelscope, use_openmind

@dataclass
class DatasetAttr:
r"""
Dataset attributes.
"""
r"""Dataset attributes."""

# basic configs
load_from: Literal["hf_hub", "ms_hub", "om_hub", "script", "file"]
@@ -68,10 +67,10 @@ class DatasetAttr:
def __repr__(self) -> str:
return self.dataset_name

def set_attr(self, key: str, obj: Dict[str, Any], default: Optional[Any] = None) -> None:
def set_attr(self, key: str, obj: dict[str, Any], default: Optional[Any] = None) -> None:
setattr(self, key, obj.get(key, default))

def join(self, attr: Dict[str, Any]) -> None:
def join(self, attr: dict[str, Any]) -> None:
self.set_attr("formatting", attr, default="alpaca")
self.set_attr("ranking", attr, default=False)
self.set_attr("subset", attr)
@@ -92,10 +91,8 @@ class DatasetAttr:
self.set_attr(tag, attr["tags"])


def get_dataset_list(dataset_names: Optional[Sequence[str]], dataset_dir: str) -> List["DatasetAttr"]:
r"""
Gets the attributes of the datasets.
"""
def get_dataset_list(dataset_names: Optional[Sequence[str]], dataset_dir: str) -> list["DatasetAttr"]:
r"""Get the attributes of the datasets."""
if dataset_names is None:
dataset_names = []

@@ -116,7 +113,7 @@ def get_dataset_list(dataset_names: Optional[Sequence[str]], dataset_dir: str) -

dataset_info = None

dataset_list: List["DatasetAttr"] = []
dataset_list: list[DatasetAttr] = []
for name in dataset_names:
if dataset_info is None: # dataset_dir is ONLINE
if use_modelscope():

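set_attr and join above copy optional keys from a dataset_info entry onto the attribute object, falling back to defaults. A stripped-down sketch; only the fields visible in this hunk are modeled, the real class has more:

```python
from dataclasses import dataclass
from typing import Any, Optional

@dataclass
class Attr:
    dataset_name: str
    formatting: str = "alpaca"
    ranking: bool = False
    subset: Optional[str] = None

    def set_attr(self, key: str, obj: dict[str, Any], default: Any = None) -> None:
        # Copy the key from the info dict if present, otherwise use the default.
        setattr(self, key, obj.get(key, default))

attr = Attr("demo")
info = {"formatting": "sharegpt", "ranking": True}  # illustrative dataset_info entry
attr.set_attr("formatting", info, default="alpaca")
attr.set_attr("ranking", info, default=False)
attr.set_attr("subset", info)
print(attr.formatting, attr.ranking, attr.subset)  # sharegpt True None
```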
@@ -9,9 +9,9 @@ from .unsupervised import UnsupervisedDatasetProcessor
__all__ = [
"DatasetProcessor",
"FeedbackDatasetProcessor",
"PackedSupervisedDatasetProcessor",
"PairwiseDatasetProcessor",
"PretrainDatasetProcessor",
"PackedSupervisedDatasetProcessor",
"SupervisedDatasetProcessor",
"UnsupervisedDatasetProcessor",
]

@@ -13,7 +13,8 @@
# limitations under the License.

from collections import defaultdict
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple
from collections.abc import Sequence
from typing import TYPE_CHECKING, Any, Optional

from ...extras import logging
from ...extras.constants import IGNORE_INDEX
@@ -30,15 +31,15 @@ logger = logging.get_logger(__name__)
class FeedbackDatasetProcessor(DatasetProcessor):
def _encode_data_example(
self,
prompt: Sequence[Dict[str, str]],
response: Sequence[Dict[str, str]],
kl_response: Sequence[Dict[str, str]],
prompt: Sequence[dict[str, str]],
response: Sequence[dict[str, str]],
kl_response: Sequence[dict[str, str]],
system: Optional[str],
tools: Optional[str],
images: Sequence["ImageInput"],
videos: Sequence["VideoInput"],
audios: Sequence["AudioInput"],
) -> Tuple[List[int], List[int], List[int], List[int], bool]:
) -> tuple[list[int], list[int], list[int], list[int], bool]:
if response[0]["content"]: # desired example
kto_tag = True
messages = prompt + [response[0]]
@@ -82,7 +83,7 @@ class FeedbackDatasetProcessor(DatasetProcessor):
kl_labels = [IGNORE_INDEX] * kl_source_len + kl_response_ids
return input_ids, labels, kl_input_ids, kl_labels, kto_tag

def preprocess_dataset(self, examples: Dict[str, List[Any]]) -> Dict[str, List[Any]]:
def preprocess_dataset(self, examples: dict[str, list[Any]]) -> dict[str, list[Any]]:
# create unrelated input-output pairs for estimating the KL term by flipping the matched pairs
kl_response = examples["_response"][::-1]
model_inputs = defaultdict(list)
@@ -121,7 +122,7 @@ class FeedbackDatasetProcessor(DatasetProcessor):

return model_inputs

def print_data_example(self, example: Dict[str, List[int]]) -> None:
def print_data_example(self, example: dict[str, list[int]]) -> None:
valid_labels = list(filter(lambda x: x != IGNORE_INDEX, example["labels"]))
print("input_ids:\n{}".format(example["input_ids"]))
print("inputs:\n{}".format(self.tokenizer.decode(example["input_ids"], skip_special_tokens=False)))

@@ -13,7 +13,8 @@
# limitations under the License.

from collections import defaultdict
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple
from collections.abc import Sequence
from typing import TYPE_CHECKING, Any, Optional

from ...extras import logging
from ...extras.constants import IGNORE_INDEX
@@ -30,14 +31,14 @@ logger = logging.get_logger(__name__)
class PairwiseDatasetProcessor(DatasetProcessor):
def _encode_data_example(
self,
prompt: Sequence[Dict[str, str]],
response: Sequence[Dict[str, str]],
prompt: Sequence[dict[str, str]],
response: Sequence[dict[str, str]],
system: Optional[str],
tools: Optional[str],
images: Sequence["ImageInput"],
videos: Sequence["VideoInput"],
audios: Sequence["AudioInput"],
) -> Tuple[List[int], List[int], List[int], List[int]]:
) -> tuple[list[int], list[int], list[int], list[int]]:
chosen_messages = self.template.mm_plugin.process_messages(
prompt + [response[0]], images, videos, audios, self.processor
)
@@ -68,7 +69,7 @@ class PairwiseDatasetProcessor(DatasetProcessor):
rejected_labels = [IGNORE_INDEX] * source_len + rejected_ids
return chosen_input_ids, chosen_labels, rejected_input_ids, rejected_labels

def preprocess_dataset(self, examples: Dict[str, List[Any]]) -> Dict[str, List[Any]]:
def preprocess_dataset(self, examples: dict[str, list[Any]]) -> dict[str, list[Any]]:
# build input pairs with format `<bos> X`, `Y1 <eos>` and `Y2 <eos>`
model_inputs = defaultdict(list)
for i in range(len(examples["_prompt"])):
@@ -99,7 +100,7 @@ class PairwiseDatasetProcessor(DatasetProcessor):

return model_inputs

def print_data_example(self, example: Dict[str, List[int]]) -> None:
def print_data_example(self, example: dict[str, list[int]]) -> None:
valid_chosen_labels = list(filter(lambda x: x != IGNORE_INDEX, example["chosen_labels"]))
valid_rejected_labels = list(filter(lambda x: x != IGNORE_INDEX, example["rejected_labels"]))
print("chosen_input_ids:\n{}".format(example["chosen_input_ids"]))

@@ -17,14 +17,14 @@

from dataclasses import dataclass
from itertools import chain
from typing import Any, Dict, List
from typing import Any

from .processor_utils import DatasetProcessor


@dataclass
class PretrainDatasetProcessor(DatasetProcessor):
def preprocess_dataset(self, examples: Dict[str, List[Any]]) -> Dict[str, List[Any]]:
def preprocess_dataset(self, examples: dict[str, list[Any]]) -> dict[str, list[Any]]:
# build grouped texts with format `X1 X2 X3 ...` if packing is enabled
eos_token = "<|end_of_text|>" if self.data_args.template == "llama3" else self.tokenizer.eos_token
text_examples = [messages[0]["content"] + eos_token for messages in examples["_prompt"]]
@@ -52,6 +52,6 @@ class PretrainDatasetProcessor(DatasetProcessor):

return result

def print_data_example(self, example: Dict[str, List[int]]) -> None:
def print_data_example(self, example: dict[str, list[int]]) -> None:
print("input_ids:\n{}".format(example["input_ids"]))
print("inputs:\n{}".format(self.tokenizer.decode(example["input_ids"], skip_special_tokens=False)))

@@ -14,8 +14,9 @@

import bisect
from abc import ABC, abstractmethod
from collections.abc import Sequence
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple
from typing import TYPE_CHECKING, Any, Optional


if TYPE_CHECKING:
@@ -27,9 +28,7 @@ if TYPE_CHECKING:

@dataclass
class DatasetProcessor(ABC):
r"""
A class for data processors.
"""
r"""A class for data processors."""

template: "Template"
tokenizer: "PreTrainedTokenizer"
@@ -37,32 +36,24 @@ class DatasetProcessor(ABC):
data_args: "DataArguments"

@abstractmethod
def preprocess_dataset(self, examples: Dict[str, List[Any]]) -> Dict[str, List[Any]]:
r"""
Builds model inputs from the examples.
"""
def preprocess_dataset(self, examples: dict[str, list[Any]]) -> dict[str, list[Any]]:
r"""Build model inputs from the examples."""
...

@abstractmethod
def print_data_example(self, example: Dict[str, List[int]]) -> None:
r"""
Print a data example to stdout.
"""
def print_data_example(self, example: dict[str, list[int]]) -> None:
r"""Print a data example to stdout."""
...


def search_for_fit(numbers: Sequence[int], capacity: int) -> int:
r"""
Finds the index of largest number that fits into the knapsack with the given capacity.
"""
r"""Find the index of the largest number that fits into the knapsack with the given capacity."""
index = bisect.bisect(numbers, capacity)
return -1 if index == 0 else (index - 1)


def greedy_knapsack(numbers: List[int], capacity: int) -> List[List[int]]:
r"""
An efficient greedy algorithm with binary search for the knapsack problem.
"""
def greedy_knapsack(numbers: list[int], capacity: int) -> list[list[int]]:
r"""Implement an efficient greedy algorithm with binary search for the knapsack problem."""
numbers.sort() # sort numbers in ascending order for binary search
knapsacks = []

@@ -83,10 +74,8 @@ def greedy_knapsack(numbers: List[int], capacity: int) -> List[List[int]]:
return knapsacks


def infer_seqlen(source_len: int, target_len: int, cutoff_len: int) -> Tuple[int, int]:
r"""
Computes the real sequence length after truncation by the cutoff_len.
"""
def infer_seqlen(source_len: int, target_len: int, cutoff_len: int) -> tuple[int, int]:
r"""Compute the real sequence length after truncation by the cutoff_len."""
if target_len * 2 < cutoff_len: # truncate source
max_target_len = cutoff_len
elif source_len * 2 < cutoff_len: # truncate target

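search_for_fit and greedy_knapsack above pack example lengths into fixed-capacity bins, using bisect to grab the largest remaining length that still fits. A runnable sketch of that loop follows; the bin-filling body is reconstructed from the visible pieces, so treat it as an approximation rather than a verbatim copy of the real function:

```python
import bisect

def search_for_fit(numbers: list[int], capacity: int) -> int:
    """Index of the largest number that still fits, or -1 if none fits."""
    index = bisect.bisect(numbers, capacity)
    return -1 if index == 0 else (index - 1)

def greedy_knapsack(numbers: list[int], capacity: int) -> list[list[int]]:
    """Greedily pack numbers into bins of `capacity` using binary search."""
    numbers.sort()  # ascending order, as required by bisect
    knapsacks = []
    while numbers:
        current_knapsack, remaining = [], capacity
        while True:
            index = search_for_fit(numbers, remaining)
            if index == -1:
                break  # nothing left that fits into this bin
            remaining -= numbers[index]
            current_knapsack.append(numbers.pop(index))
        knapsacks.append(current_knapsack)
    return knapsacks

print(greedy_knapsack([3, 7, 2, 5, 4], capacity=8))  # [[7], [5, 3], [4, 2]]
```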
@@ -13,8 +13,9 @@
# limitations under the License.

from collections import defaultdict
from collections.abc import Sequence
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple
from typing import TYPE_CHECKING, Any, Optional

from ...extras import logging
from ...extras.constants import IGNORE_INDEX
@@ -32,14 +33,14 @@ logger = logging.get_logger(__name__)
class SupervisedDatasetProcessor(DatasetProcessor):
def _encode_data_example(
self,
prompt: Sequence[Dict[str, str]],
response: Sequence[Dict[str, str]],
prompt: Sequence[dict[str, str]],
response: Sequence[dict[str, str]],
system: Optional[str],
tools: Optional[str],
images: Sequence["ImageInput"],
videos: Sequence["VideoInput"],
audios: Sequence["AudioInput"],
) -> Tuple[List[int], List[int]]:
) -> tuple[list[int], list[int]]:
messages = self.template.mm_plugin.process_messages(prompt + response, images, videos, audios, self.processor)
input_ids, labels = self.template.mm_plugin.process_token_ids(
[], [], images, videos, audios, self.tokenizer, self.processor
@@ -85,7 +86,7 @@ class SupervisedDatasetProcessor(DatasetProcessor):

return input_ids, labels

def preprocess_dataset(self, examples: Dict[str, List[Any]]) -> Dict[str, List[Any]]:
def preprocess_dataset(self, examples: dict[str, list[Any]]) -> dict[str, list[Any]]:
# build inputs with format `<bos> X Y <eos>` and labels with format `<ignore> ... <ignore> Y <eos>`
# for multiturn examples, we only mask the prompt part in each prompt-response pair.
model_inputs = defaultdict(list)
@@ -114,7 +115,7 @@ class SupervisedDatasetProcessor(DatasetProcessor):

return model_inputs

def print_data_example(self, example: Dict[str, List[int]]) -> None:
def print_data_example(self, example: dict[str, list[int]]) -> None:
valid_labels = list(filter(lambda x: x != IGNORE_INDEX, example["labels"]))
print("input_ids:\n{}".format(example["input_ids"]))
print("inputs:\n{}".format(self.tokenizer.decode(example["input_ids"], skip_special_tokens=False)))
@@ -124,7 +125,7 @@ class SupervisedDatasetProcessor(DatasetProcessor):

@dataclass
class PackedSupervisedDatasetProcessor(SupervisedDatasetProcessor):
def preprocess_dataset(self, examples: Dict[str, List[Any]]) -> Dict[str, List[Any]]:
def preprocess_dataset(self, examples: dict[str, list[Any]]) -> dict[str, list[Any]]:
# TODO: use `position_ids` to achieve packing
# build inputs with format `<bos> X1 Y1 <eos> <bos> X2 Y2 <eos>`
# and labels with format `<ignore> ... <ignore> Y1 <eos> <ignore> ... <ignore> Y2 <eos>`

@@ -13,7 +13,8 @@
# limitations under the License.

from collections import defaultdict
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple
from collections.abc import Sequence
from typing import TYPE_CHECKING, Any, Optional

from ...extras import logging
from ..data_utils import Role
@@ -30,14 +31,14 @@ logger = logging.get_logger(__name__)
class UnsupervisedDatasetProcessor(DatasetProcessor):
def _encode_data_example(
self,
prompt: Sequence[Dict[str, str]],
response: Sequence[Dict[str, str]],
prompt: Sequence[dict[str, str]],
response: Sequence[dict[str, str]],
system: Optional[str],
tools: Optional[str],
images: Sequence["ImageInput"],
videos: Sequence["VideoInput"],
audios: Sequence["AudioInput"],
) -> Tuple[List[int], List[int]]:
) -> tuple[list[int], list[int]]:
if len(response) == 1:
messages = prompt + response
else:
@@ -56,7 +57,7 @@ class UnsupervisedDatasetProcessor(DatasetProcessor):
labels = labels[:target_len]
return input_ids, labels

def preprocess_dataset(self, examples: Dict[str, List[Any]]) -> Dict[str, List[Any]]:
def preprocess_dataset(self, examples: dict[str, list[Any]]) -> dict[str, list[Any]]:
# build inputs with format `<bos> X` and labels with format `Y <eos>`
model_inputs = defaultdict(list)
for i in range(len(examples["_prompt"])):
@@ -84,7 +85,7 @@ class UnsupervisedDatasetProcessor(DatasetProcessor):

return model_inputs

def print_data_example(self, example: Dict[str, List[int]]) -> None:
def print_data_example(self, example: dict[str, list[int]]) -> None:
print("input_ids:\n{}".format(example["input_ids"]))
print("inputs:\n{}".format(self.tokenizer.decode(example["input_ids"], skip_special_tokens=False)))
print("label_ids:\n{}".format(example["labels"]))

@@ -12,8 +12,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from collections.abc import Sequence
from dataclasses import dataclass
from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Tuple, Type, Union
from typing import TYPE_CHECKING, Optional, Union

from typing_extensions import override

@@ -46,8 +47,8 @@ class Template:
format_tools: "Formatter"
format_prefix: "Formatter"
default_system: str
stop_words: List[str]
thought_words: Tuple[str, str]
stop_words: list[str]
thought_words: tuple[str, str]
efficient_eos: bool
replace_eos: bool
replace_jinja_template: bool
@@ -56,13 +57,11 @@ class Template:
def encode_oneturn(
self,
tokenizer: "PreTrainedTokenizer",
messages: Sequence[Dict[str, str]],
messages: Sequence[dict[str, str]],
system: Optional[str] = None,
tools: Optional[str] = None,
) -> Tuple[List[int], List[int]]:
r"""
Returns a single pair of token ids representing prompt and response respectively.
"""
) -> tuple[list[int], list[int]]:
r"""Return a single pair of token ids representing prompt and response respectively."""
encoded_messages = self._encode(tokenizer, messages, system, tools)
prompt_ids = []
for encoded_ids in encoded_messages[:-1]:
@@ -74,36 +73,28 @@ class Template:
def encode_multiturn(
self,
tokenizer: "PreTrainedTokenizer",
messages: Sequence[Dict[str, str]],
messages: Sequence[dict[str, str]],
system: Optional[str] = None,
tools: Optional[str] = None,
) -> List[Tuple[List[int], List[int]]]:
r"""
Returns multiple pairs of token ids representing prompts and responses respectively.
"""
) -> list[tuple[list[int], list[int]]]:
r"""Return multiple pairs of token ids representing prompts and responses respectively."""
encoded_messages = self._encode(tokenizer, messages, system, tools)
return [(encoded_messages[i], encoded_messages[i + 1]) for i in range(0, len(encoded_messages), 2)]

def extract_tool(self, content: str) -> Union[str, List["FunctionCall"]]:
r"""
Extracts tool message.
"""
def extract_tool(self, content: str) -> Union[str, list["FunctionCall"]]:
r"""Extract tool message."""
return self.format_tools.extract(content)

def get_stop_token_ids(self, tokenizer: "PreTrainedTokenizer") -> List[int]:
r"""
Returns stop token ids.
"""
def get_stop_token_ids(self, tokenizer: "PreTrainedTokenizer") -> list[int]:
r"""Return stop token ids."""
stop_token_ids = {tokenizer.eos_token_id}
for token in self.stop_words:
stop_token_ids.add(tokenizer.convert_tokens_to_ids(token))

return list(stop_token_ids)

def _convert_elements_to_ids(self, tokenizer: "PreTrainedTokenizer", elements: "SLOTS") -> List[int]:
r"""
Converts elements to token ids.
"""
def _convert_elements_to_ids(self, tokenizer: "PreTrainedTokenizer", elements: "SLOTS") -> list[int]:
r"""Convert elements to token ids."""
token_ids = []
for elem in elements:
if isinstance(elem, str):
@@ -124,14 +115,14 @@ class Template:
def _encode(
self,
tokenizer: "PreTrainedTokenizer",
messages: Sequence[Dict[str, str]],
messages: Sequence[dict[str, str]],
system: Optional[str],
tools: Optional[str],
) -> List[List[int]]:
r"""
Encodes formatted inputs to pairs of token ids.
) -> list[list[int]]:
r"""Encode formatted inputs to pairs of token ids.

Turn 0: prefix + system + query resp
Turn t: query resp
Turn t: query resp.
"""
system = system or self.default_system
encoded_messages = []
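encode_oneturn and encode_multiturn both reduce to a pairing step over the per-message token lists returned by _encode: even indices hold prompts, odd indices hold responses. A toy illustration of just that step, with made-up token ids:

```python
# _encode returns one token list per message, alternating user/assistant turns.
encoded_messages = [[1, 2], [3], [4, 5], [6, 7]]  # u0, a0, u1, a1 (dummy ids)

# encode_multiturn: one (prompt, response) pair per turn.
pairs = [(encoded_messages[i], encoded_messages[i + 1]) for i in range(0, len(encoded_messages), 2)]
print(pairs)  # [([1, 2], [3]), ([4, 5], [6, 7])]

# encode_oneturn: all but the last message are flattened into a single prompt.
prompt_ids = [tok for msg in encoded_messages[:-1] for tok in msg]
print(prompt_ids, encoded_messages[-1])  # [1, 2, 3, 4, 5] [6, 7]
```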
@@ -161,9 +152,7 @@ class Template:

@staticmethod
def _add_or_replace_eos_token(tokenizer: "PreTrainedTokenizer", eos_token: str) -> None:
r"""
Adds or replaces eos token to the tokenizer.
"""
r"""Add or replace eos token to the tokenizer."""
is_added = tokenizer.eos_token_id is None
num_added_tokens = tokenizer.add_special_tokens({"eos_token": eos_token})

@@ -176,9 +165,7 @@ class Template:
logger.warning_rank0("New tokens have been added, make sure `resize_vocab` is True.")

def fix_special_tokens(self, tokenizer: "PreTrainedTokenizer") -> None:
r"""
Adds eos token and pad token to the tokenizer.
"""
r"""Add eos token and pad token to the tokenizer."""
stop_words = self.stop_words
if self.replace_eos:
if not stop_words:
@@ -204,16 +191,12 @@ class Template:

@staticmethod
def _jinja_escape(content: str) -> str:
r"""
Escape single quotes in content.
"""
r"""Escape single quotes in content."""
return content.replace("'", r"\'")

@staticmethod
def _convert_slots_to_jinja(slots: "SLOTS", tokenizer: "PreTrainedTokenizer", placeholder: str = "content") -> str:
r"""
Converts slots to jinja template.
"""
r"""Convert slots to jinja template."""
slot_items = []
for slot in slots:
if isinstance(slot, str):
@@ -235,9 +218,7 @@ class Template:
return " + ".join(slot_items)

def _get_jinja_template(self, tokenizer: "PreTrainedTokenizer") -> str:
r"""
Returns the jinja template.
"""
r"""Return the jinja template."""
prefix = self._convert_slots_to_jinja(self.format_prefix.apply(), tokenizer)
system = self._convert_slots_to_jinja(self.format_system.apply(), tokenizer, placeholder="system_message")
user = self._convert_slots_to_jinja(self.format_user.apply(), tokenizer)
@@ -265,9 +246,7 @@ class Template:
return jinja_template

def fix_jinja_template(self, tokenizer: "PreTrainedTokenizer") -> None:
r"""
Replaces the jinja template in the tokenizer.
"""
r"""Replace the jinja template in the tokenizer."""
if tokenizer.chat_template is None or self.replace_jinja_template:
try:
tokenizer.chat_template = self._get_jinja_template(tokenizer)
@@ -278,9 +257,7 @@ class Template:
def _convert_slots_to_ollama(
slots: "SLOTS", tokenizer: "PreTrainedTokenizer", placeholder: str = "content"
) -> str:
r"""
Converts slots to ollama template.
"""
r"""Convert slots to ollama template."""
slot_items = []
for slot in slots:
if isinstance(slot, str):
@@ -302,9 +279,7 @@ class Template:
return "".join(slot_items)

def _get_ollama_template(self, tokenizer: "PreTrainedTokenizer") -> str:
r"""
Returns the ollama template.
"""
r"""Return the ollama template."""
prefix = self._convert_slots_to_ollama(self.format_prefix.apply(), tokenizer)
system = self._convert_slots_to_ollama(self.format_system.apply(), tokenizer, placeholder=".System")
user = self._convert_slots_to_ollama(self.format_user.apply(), tokenizer, placeholder=".Content")
@@ -316,8 +291,7 @@ class Template:
)

def get_ollama_modelfile(self, tokenizer: "PreTrainedTokenizer") -> str:
r"""
Returns the ollama modelfile.
r"""Return the ollama modelfile.

TODO: support function calling.
"""
@@ -340,10 +314,10 @@ class Llama2Template(Template):
def _encode(
self,
tokenizer: "PreTrainedTokenizer",
messages: Sequence[Dict[str, str]],
messages: Sequence[dict[str, str]],
system: str,
tools: str,
) -> List[List[int]]:
) -> list[list[int]]:
system = system or self.default_system
encoded_messages = []
for i, message in enumerate(messages):
@@ -402,7 +376,7 @@ class Llama2Template(Template):
return jinja_template


TEMPLATES: Dict[str, "Template"] = {}
TEMPLATES: dict[str, "Template"] = {}


def register_template(
@@ -416,15 +390,14 @@ def register_template(
format_prefix: Optional["Formatter"] = None,
default_system: str = "",
stop_words: Optional[Sequence[str]] = None,
thought_words: Optional[Tuple[str, str]] = None,
thought_words: Optional[tuple[str, str]] = None,
efficient_eos: bool = False,
replace_eos: bool = False,
replace_jinja_template: bool = False,
mm_plugin: "BasePlugin" = get_mm_plugin(name="base"),
template_class: Type["Template"] = Template,
template_class: type["Template"] = Template,
) -> None:
r"""
Registers a chat template.
r"""Register a chat template.

To add the following chat template:
```
@@ -472,9 +445,7 @@ def register_template(


def parse_template(tokenizer: "PreTrainedTokenizer") -> "Template":
r"""
Extracts a chat template from the tokenizer.
"""
r"""Extract a chat template from the tokenizer."""

def find_diff(short_str: str, long_str: str) -> str:
i, j = 0, 0
@@ -532,9 +503,7 @@ def parse_template(tokenizer: "PreTrainedTokenizer") -> "Template":


def get_template_and_fix_tokenizer(tokenizer: "PreTrainedTokenizer", data_args: "DataArguments") -> "Template":
r"""
Gets chat template and fixes the tokenizer.
"""
r"""Get chat template and fix the tokenizer."""
if data_args.template is None:
if isinstance(tokenizer.chat_template, str):
logger.warning_rank0("`template` was not specified, try parsing the chat template from the tokenizer.")
@@ -1149,7 +1118,8 @@ register_template(
format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
format_observation=StringFormatter(slots=["<|im_start|>tool\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
default_system=(
"你是一个经过良好训练的AI助手,你的名字是Marco-o1.由阿里国际数字商业集团的AI Business创造.\n## 重要!!!!!\n"
"你是一个经过良好训练的AI助手,你的名字是Marco-o1."
"由阿里国际数字商业集团的AI Business创造.\n## 重要!!!!!\n"
"当你回答问题时,你的思考应该在<Thought>内完成,<Output>内输出你的结果。\n"
"<Thought>应该尽可能是英文,但是有2个特例,一个是对原文中的引用,另一个是是数学应该使用markdown格式,<Output>内的输出需要遵循用户输入的语言。\n"
),

@@ -17,7 +17,7 @@ import re
from abc import ABC, abstractmethod
from dataclasses import dataclass
from datetime import datetime
from typing import Any, Dict, List, NamedTuple, Tuple, Union
from typing import Any, NamedTuple, Union

from typing_extensions import override

@@ -60,31 +60,24 @@ QWEN_TOOL_PROMPT = (

@dataclass
class ToolUtils(ABC):
"""
Base class for tool utilities.
"""
"""Base class for tool utilities."""

@staticmethod
@abstractmethod
def tool_formatter(tools: List[Dict[str, Any]]) -> str:
r"""
Generates the system message describing all the available tools.
"""
def tool_formatter(tools: list[dict[str, Any]]) -> str:
r"""Generate the system message describing all the available tools."""
...

@staticmethod
@abstractmethod
def function_formatter(functions: List["FunctionCall"]) -> str:
r"""
Generates the assistant message including all the tool calls.
"""
def function_formatter(functions: list["FunctionCall"]) -> str:
r"""Generate the assistant message including all the tool calls."""
...

@staticmethod
@abstractmethod
def tool_extractor(content: str) -> Union[str, List["FunctionCall"]]:
r"""
Extracts all the function calls from the assistant message.
def tool_extractor(content: str) -> Union[str, list["FunctionCall"]]:
r"""Extract all the function calls from the assistant message.

It should be an inverse function of `function_formatter`.
"""
@@ -92,13 +85,11 @@ class ToolUtils(ABC):


class DefaultToolUtils(ToolUtils):
r"""
Default tool using template.
"""
r"""Default tool using template."""

@override
@staticmethod
def tool_formatter(tools: List[Dict[str, Any]]) -> str:
def tool_formatter(tools: list[dict[str, Any]]) -> str:
tool_text = ""
tool_names = []
for tool in tools:
@@ -132,7 +123,7 @@ class DefaultToolUtils(ToolUtils):

@override
@staticmethod
def function_formatter(functions: List["FunctionCall"]) -> str:
def function_formatter(functions: list["FunctionCall"]) -> str:
function_text = ""
for name, arguments in functions:
function_text += f"Action: {name}\nAction Input: {arguments}\n"
@@ -141,9 +132,9 @@ class DefaultToolUtils(ToolUtils):

@override
@staticmethod
def tool_extractor(content: str) -> Union[str, List["FunctionCall"]]:
def tool_extractor(content: str) -> Union[str, list["FunctionCall"]]:
regex = re.compile(r"Action:\s*([a-zA-Z0-9_]+)\s*Action Input:\s*(.+?)(?=\s*Action:|\s*$)", re.DOTALL)
action_match: List[Tuple[str, str]] = re.findall(regex, content)
action_match: list[tuple[str, str]] = re.findall(regex, content)
if not action_match:
return content

|
||||
|
||||
|
||||
class GLM4ToolUtils(ToolUtils):
|
||||
r"""
|
||||
GLM-4 tool using template.
|
||||
"""
|
||||
r"""GLM-4 tool using template."""
|
||||
|
||||
@override
|
||||
@staticmethod
|
||||
def tool_formatter(tools: List[Dict[str, Any]]) -> str:
|
||||
def tool_formatter(tools: list[dict[str, Any]]) -> str:
|
||||
tool_text = ""
|
||||
for tool in tools:
|
||||
tool_text += "\n\n## {name}\n\n{body}\n在调用上述函数时,请使用 Json 格式表示调用的参数。".format(
|
||||
@@ -178,7 +167,7 @@ class GLM4ToolUtils(ToolUtils):
|
||||
|
||||
@override
|
||||
@staticmethod
|
||||
def function_formatter(functions: List["FunctionCall"]) -> str:
|
||||
def function_formatter(functions: list["FunctionCall"]) -> str:
|
||||
if len(functions) > 1:
|
||||
raise ValueError("GLM-4 does not support parallel functions.")
|
||||
|
||||
@@ -186,7 +175,7 @@ class GLM4ToolUtils(ToolUtils):
|
||||
|
||||
@override
|
||||
@staticmethod
|
||||
def tool_extractor(content: str) -> Union[str, List["FunctionCall"]]:
|
||||
def tool_extractor(content: str) -> Union[str, list["FunctionCall"]]:
|
||||
if "\n" not in content:
|
||||
return content
|
||||
|
||||
@@ -200,15 +189,14 @@ class GLM4ToolUtils(ToolUtils):
|
||||
|
||||
|
||||
class Llama3ToolUtils(ToolUtils):
|
||||
r"""
|
||||
Llama 3.x tool using template with `tools_in_user_message=False`.
|
||||
r"""Llama 3.x tool using template with `tools_in_user_message=False`.
|
||||
|
||||
Reference: https://www.llama.com/docs/model-cards-and-prompt-formats/llama3_1/#json-based-tool-calling
|
||||
"""
|
||||
|
||||
@override
|
||||
@staticmethod
|
||||
def tool_formatter(tools: List[Dict[str, Any]]) -> str:
|
||||
def tool_formatter(tools: list[dict[str, Any]]) -> str:
|
||||
date = datetime.now().strftime("%d %b %Y")
|
||||
tool_text = ""
|
||||
for tool in tools:
|
||||
@@ -219,7 +207,7 @@ class Llama3ToolUtils(ToolUtils):
|
||||
|
||||
@override
|
||||
@staticmethod
|
||||
def function_formatter(functions: List["FunctionCall"]) -> str:
|
||||
def function_formatter(functions: list["FunctionCall"]) -> str:
|
||||
if len(functions) > 1:
|
||||
raise ValueError("Llama-3 does not support parallel functions.")
|
||||
|
||||
@@ -227,7 +215,7 @@ class Llama3ToolUtils(ToolUtils):
|
||||
|
||||
@override
|
||||
@staticmethod
|
||||
def tool_extractor(content: str) -> Union[str, List["FunctionCall"]]:
|
||||
def tool_extractor(content: str) -> Union[str, list["FunctionCall"]]:
|
||||
try:
|
||||
tool = json.loads(content.strip())
|
||||
except json.JSONDecodeError:
|
||||
@@ -240,13 +228,11 @@ class Llama3ToolUtils(ToolUtils):
|
||||
|
||||
|
||||
class MistralToolUtils(ToolUtils):
|
||||
r"""
|
||||
Mistral v0.3 tool using template.
|
||||
"""
|
||||
r"""Mistral v0.3 tool using template."""
|
||||
|
||||
@override
|
||||
@staticmethod
|
||||
def tool_formatter(tools: List[Dict[str, Any]]) -> str:
|
||||
def tool_formatter(tools: list[dict[str, Any]]) -> str:
|
||||
wrapped_tools = []
|
||||
for tool in tools:
|
||||
wrapped_tools.append({"type": "function", "function": tool})
|
||||
@@ -255,7 +241,7 @@ class MistralToolUtils(ToolUtils):
|
||||
|
||||
@override
|
||||
@staticmethod
|
||||
def function_formatter(functions: List["FunctionCall"]) -> str:
|
||||
def function_formatter(functions: list["FunctionCall"]) -> str:
|
||||
function_texts = []
|
||||
for name, arguments in functions:
|
||||
function_texts.append(f'{{"name": "{name}", "arguments": {arguments}}}')
|
||||
@@ -264,7 +250,7 @@ class MistralToolUtils(ToolUtils):
|
||||
|
||||
@override
|
||||
@staticmethod
|
||||
def tool_extractor(content: str) -> Union[str, List["FunctionCall"]]:
|
||||
def tool_extractor(content: str) -> Union[str, list["FunctionCall"]]:
|
||||
try:
|
||||
tools = json.loads(content.strip())
|
||||
except json.JSONDecodeError:
|
||||
@@ -284,13 +270,11 @@ class MistralToolUtils(ToolUtils):
|
||||
|
||||
|
||||
class QwenToolUtils(ToolUtils):
|
||||
r"""
|
||||
Qwen 2.5 tool using template.
|
||||
"""
|
||||
r"""Qwen 2.5 tool using template."""
|
||||
|
||||
@override
|
||||
@staticmethod
|
||||
def tool_formatter(tools: List[Dict[str, Any]]) -> str:
|
||||
def tool_formatter(tools: list[dict[str, Any]]) -> str:
|
||||
tool_text = ""
|
||||
for tool in tools:
|
||||
wrapped_tool = {"type": "function", "function": tool}
|
||||
@@ -300,7 +284,7 @@ class QwenToolUtils(ToolUtils):
|
||||
|
||||
@override
|
||||
@staticmethod
|
||||
def function_formatter(functions: List["FunctionCall"]) -> str:
|
||||
def function_formatter(functions: list["FunctionCall"]) -> str:
|
||||
function_texts = []
|
||||
for name, arguments in functions:
|
||||
function_texts.append(
|
||||
@@ -311,9 +295,9 @@ class QwenToolUtils(ToolUtils):
|
||||
|
||||
@override
|
||||
@staticmethod
|
||||
def tool_extractor(content: str) -> Union[str, List["FunctionCall"]]:
|
||||
def tool_extractor(content: str) -> Union[str, list["FunctionCall"]]:
|
||||
regex = re.compile(r"<tool_call>(.+?)</tool_call>(?=\s*<tool_call>|\s*$)", re.DOTALL)
|
||||
tool_match: List[str] = re.findall(regex, content)
|
||||
tool_match: list[str] = re.findall(regex, content)
|
||||
if not tool_match:
|
||||
return content
|
||||
|
||||
|
||||
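QwenToolUtils wraps each call in <tool_call> tags containing JSON, and extraction finds the tagged spans before parsing each one. A sketch of the happy path only; the real extractor also returns the raw content when no tag is present or the JSON is invalid:

```python
import json
import re

regex = re.compile(r"<tool_call>(.+?)</tool_call>(?=\s*<tool_call>|\s*$)", re.DOTALL)
content = (
    '<tool_call>{"name": "get_weather", "arguments": {"city": "Berlin"}}</tool_call>\n'
    '<tool_call>{"name": "get_time", "arguments": {"tz": "UTC"}}</tool_call>'
)
for match in re.findall(regex, content):
    tool = json.loads(match.strip())
    print(tool["name"], json.dumps(tool["arguments"]))
# get_weather {"city": "Berlin"}
# get_time {"tz": "UTC"}
```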
@@ -39,7 +39,7 @@

import json
import os
from typing import TYPE_CHECKING, Any, Dict, List, Optional
from typing import TYPE_CHECKING, Any, Optional

import numpy as np
import torch
@@ -59,7 +59,7 @@ if TYPE_CHECKING:


class Evaluator:
def __init__(self, args: Optional[Dict[str, Any]] = None) -> None:
def __init__(self, args: Optional[dict[str, Any]] = None) -> None:
self.model_args, self.data_args, self.eval_args, finetuning_args = get_eval_args(args)
self.tokenizer = load_tokenizer(self.model_args)["tokenizer"]
self.tokenizer.padding_side = "right" # avoid overflow issue in batched inference for llama2
@@ -69,7 +69,7 @@ class Evaluator:
self.choice_inputs = [self.tokenizer.encode(ch, add_special_tokens=False)[-1] for ch in CHOICES]

@torch.inference_mode()
def batch_inference(self, batch_input: Dict[str, "torch.Tensor"]) -> List[str]:
def batch_inference(self, batch_input: dict[str, "torch.Tensor"]) -> list[str]:
logits = self.model(**batch_input).logits
lengths = torch.sum(batch_input["attention_mask"], dim=-1)
word_probs = torch.stack([logits[i, lengths[i] - 1] for i in range(len(lengths))], dim=0)
@@ -88,7 +88,7 @@ class Evaluator:
)

with open(mapping, encoding="utf-8") as f:
categorys: Dict[str, Dict[str, str]] = json.load(f)
categorys: dict[str, dict[str, str]] = json.load(f)

category_corrects = {subj: np.array([], dtype="bool") for subj in SUBJECTS}
pbar = tqdm(categorys.keys(), desc="Processing subjects", position=0)
@@ -136,7 +136,7 @@ class Evaluator:
pbar.close()
self._save_results(category_corrects, results)

def _save_results(self, category_corrects: Dict[str, "NDArray"], results: Dict[str, Dict[int, str]]) -> None:
def _save_results(self, category_corrects: dict[str, "NDArray"], results: dict[str, dict[int, str]]) -> None:
score_info = "\n".join(
[
f"{category_name:>15}: {100 * np.mean(category_correct):.2f}"

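batch_inference scores multiple-choice answers by reading the logits at each sequence's last real token and comparing only the candidate letter tokens. A standalone sketch of that selection step; shapes and token ids are invented for illustration:

```python
import torch

# Invented setup: batch of 2, vocab of 10; choice token ids for "A".."D".
logits = torch.randn(2, 7, 10)                  # (batch, seq_len, vocab)
attention_mask = torch.tensor([[1] * 5 + [0] * 2, [1] * 7])
choice_inputs = [3, 4, 5, 6]                    # token ids of the choice letters

lengths = attention_mask.sum(dim=-1)            # last non-pad position per row
word_probs = torch.stack([logits[i, lengths[i] - 1] for i in range(len(lengths))], dim=0)
choice_probs = torch.nn.functional.softmax(word_probs[:, choice_inputs], dim=-1)
answers = [chr(ord("A") + int(idx)) for idx in choice_probs.argmax(dim=-1)]
print(answers)                                  # e.g. ['C', 'A']
```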
@@ -12,8 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from collections.abc import Sequence
from dataclasses import dataclass
from typing import Dict, List, Sequence, Tuple

from ..data import Role
from ..extras.constants import CHOICES
@@ -25,20 +25,19 @@ class EvalTemplate:
choice: str
answer: str

def _parse_example(self, example: Dict[str, str]) -> Tuple[str, str]:
r"""
def _parse_example(self, example: dict[str, str]) -> tuple[str, str]:
r"""Parse eval example.

input: a dict with keys {"question", "A", "B", "C", "D", "answer"}
output: a tuple of (prompt, response)
output: a tuple of (prompt, response).
"""
candidates = [self.choice.format(choice=ch, content=example[ch]) for ch in CHOICES if ch in example]
return "".join([example["question"]] + candidates + [self.answer]), example["answer"]

def format_example(
self, target_data: Dict[str, str], support_set: Sequence[Dict[str, str]], subject_name: str
) -> List[Dict[str, str]]:
r"""
Converts dataset examples to messages.
"""
self, target_data: dict[str, str], support_set: Sequence[dict[str, str]], subject_name: str
) -> list[dict[str, str]]:
r"""Convert dataset examples to messages."""
messages = []
for k in range(len(support_set)):
prompt, response = self._parse_example(support_set[k])
@@ -52,7 +51,7 @@ class EvalTemplate:
return messages


eval_templates: Dict[str, "EvalTemplate"] = {}
eval_templates: dict[str, "EvalTemplate"] = {}


def _register_eval_template(name: str, system: str, choice: str, answer: str) -> None:

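_parse_example and format_example turn a row with keys question/A/B/C/D/answer into an exam-style prompt, optionally preceded by few-shot support rows. A compact sketch with simplified template strings; the real class formats choices through its choice and answer templates:

```python
CHOICES = ["A", "B", "C", "D"]

def parse_example(example: dict[str, str]) -> tuple[str, str]:
    candidates = [f"\n{ch}. {example[ch]}" for ch in CHOICES if ch in example]
    return "".join([example["question"], *candidates, "\nAnswer:"]), example["answer"]

support = {"question": "2 + 2 = ?", "A": "3", "B": "4", "C": "5", "D": "6", "answer": "B"}
target = {"question": "3 * 3 = ?", "A": "6", "B": "8", "C": "9", "D": "12", "answer": "C"}

messages = []
prompt, response = parse_example(support)  # support rows become full user/assistant turns
messages += [{"role": "user", "content": prompt}, {"role": "assistant", "content": response}]
prompt, answer = parse_example(target)     # the gold answer ("C") is kept aside for scoring
messages.append({"role": "user", "content": prompt})
print(messages[-1]["content"])
```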
@@ -15,7 +15,7 @@
import os
from collections import OrderedDict, defaultdict
from enum import Enum
from typing import Dict, Optional
from typing import Optional

from peft.utils import SAFETENSORS_WEIGHTS_NAME as SAFE_ADAPTER_WEIGHTS_NAME
from peft.utils import WEIGHTS_NAME as ADAPTER_WEIGHTS_NAME
@@ -122,7 +122,7 @@ class RopeScaling(str, Enum):


def register_model_group(
models: Dict[str, Dict[DownloadSource, str]],
models: dict[str, dict[DownloadSource, str]],
template: Optional[str] = None,
multimodal: bool = False,
) -> None:

@@ -32,9 +32,7 @@ _default_log_level: "logging._Level" = logging.INFO


class LoggerHandler(logging.Handler):
r"""
Redirects the logging output to the logging file for LLaMA Board.
"""
r"""Redirect the logging output to the logging file for LLaMA Board."""

def __init__(self, output_dir: str) -> None:
super().__init__()
@@ -67,9 +65,7 @@ class LoggerHandler(logging.Handler):


class _Logger(logging.Logger):
r"""
A logger that supports rank0 logging.
"""
r"""A logger that supports rank0 logging."""

def info_rank0(self, *args, **kwargs) -> None:
self.info(*args, **kwargs)
@@ -82,9 +78,7 @@ class _Logger(logging.Logger):


def _get_default_logging_level() -> "logging._Level":
r"""
Returns the default logging level.
"""
r"""Return the default logging level."""
env_level_str = os.environ.get("LLAMAFACTORY_VERBOSITY", None)
if env_level_str:
if env_level_str.upper() in logging._nameToLevel:
@@ -104,9 +98,7 @@ def _get_library_root_logger() -> "_Logger":


def _configure_library_root_logger() -> None:
r"""
Configures root logger using a stdout stream handler with an explicit format.
"""
r"""Configure root logger using a stdout stream handler with an explicit format."""
global _default_handler

with _thread_lock:
@@ -126,9 +118,7 @@ def _configure_library_root_logger() -> None:


def get_logger(name: Optional[str] = None) -> "_Logger":
r"""
Returns a logger with the specified name. It it not supposed to be accessed externally.
"""
r"""Return a logger with the specified name. It is not supposed to be accessed externally."""
if name is None:
name = _get_library_name()

@@ -137,17 +127,13 @@ def get_logger(name: Optional[str] = None) -> "_Logger":


def add_handler(handler: "logging.Handler") -> None:
r"""
Adds a handler to the root logger.
"""
r"""Add a handler to the root logger."""
_configure_library_root_logger()
_get_library_root_logger().addHandler(handler)


def remove_handler(handler: logging.Handler) -> None:
r"""
Removes a handler to the root logger.
"""
r"""Remove a handler from the root logger."""
_configure_library_root_logger()
_get_library_root_logger().removeHandler(handler)


@@ -17,7 +17,8 @@

import gc
import os
from typing import TYPE_CHECKING, Any, Dict, Literal, Sequence, Tuple, Union
from collections.abc import Sequence
from typing import TYPE_CHECKING, Any, Literal, Union

import torch
import torch.distributed as dist
@@ -54,9 +55,7 @@ logger = logging.get_logger(__name__)


class AverageMeter:
r"""
Computes and stores the average and current value.
"""
r"""Compute and store the average and current value."""

def __init__(self):
self.reset()
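AverageMeter keeps a running sum and count; reset zeroes the state and update folds in new values. A conventional implementation of this classic utility (the class body beyond __init__ is not shown in the hunk, so details here may differ from the real one):

```python
class AverageMeter:
    """Compute and store the average and current value."""

    def __init__(self):
        self.reset()

    def reset(self):
        self.val, self.sum, self.count, self.avg = 0.0, 0.0, 0, 0.0

    def update(self, val: float, n: int = 1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

meter = AverageMeter()
for loss in (0.9, 0.7, 0.5):
    meter.update(loss)
print(round(meter.avg, 3))  # 0.7
```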
@@ -75,9 +74,7 @@ class AverageMeter:
|
||||
|
||||
|
||||
def check_version(requirement: str, mandatory: bool = False) -> None:
|
||||
r"""
|
||||
Optionally checks the package version.
|
||||
"""
|
||||
r"""Optionally check the package version."""
|
||||
if is_env_enabled("DISABLE_VERSION_CHECK") and not mandatory:
|
||||
logger.warning_rank0_once("Version checking has been disabled, may lead to unexpected behaviors.")
|
||||
return
|
||||
@@ -91,9 +88,7 @@ def check_version(requirement: str, mandatory: bool = False) -> None:
|
||||
|
||||
|
||||
def check_dependencies() -> None:
|
||||
r"""
|
||||
Checks the version of the required packages.
|
||||
"""
|
||||
r"""Check the version of the required packages."""
|
||||
check_version("transformers>=4.41.2,<=4.49.0,!=4.46.0,!=4.46.1,!=4.46.2,!=4.46.3,!=4.47.0,!=4.47.1,!=4.48.0")
|
||||
check_version("datasets>=2.16.0,<=3.2.0")
|
||||
check_version("accelerate>=0.34.0,<=1.2.1")
|
||||
@@ -103,10 +98,8 @@ def check_dependencies() -> None:
|
||||
logger.warning_rank0_once("There are known bugs in transformers v4.46.0-v4.48.0, please use other versions.")
|
||||
|
||||
|
||||
def calculate_tps(dataset: Sequence[Dict[str, Any]], metrics: Dict[str, float], stage: Literal["sft", "rm"]) -> float:
|
||||
r"""
|
||||
Calculates effective tokens per second.
|
||||
"""
|
||||
def calculate_tps(dataset: Sequence[dict[str, Any]], metrics: dict[str, float], stage: Literal["sft", "rm"]) -> float:
|
||||
r"""Calculate effective tokens per second."""
|
||||
effective_token_num = 0
|
||||
for data in dataset:
|
||||
if stage == "sft":
|
||||
@@ -118,10 +111,8 @@ def calculate_tps(dataset: Sequence[Dict[str, Any]], metrics: Dict[str, float],
|
||||
return result / dist.get_world_size() if dist.is_initialized() else result


def count_parameters(model: "torch.nn.Module") -> Tuple[int, int]:
r"""
Returns the number of trainable parameters and number of all parameters in the model.
"""
def count_parameters(model: "torch.nn.Module") -> tuple[int, int]:
r"""Return the number of trainable parameters and number of all parameters in the model."""
trainable_params, all_param = 0, 0
for param in model.parameters():
num_params = param.numel()
@@ -148,9 +139,7 @@ def count_parameters(model: "torch.nn.Module") -> Tuple[int, int]:


def get_current_device() -> "torch.device":
r"""
Gets the current available device.
"""
r"""Get the current available device."""
if is_torch_xpu_available():
device = "xpu:{}".format(os.environ.get("LOCAL_RANK", "0"))
elif is_torch_npu_available():
@@ -166,9 +155,7 @@ def get_current_device() -> "torch.device":


def get_device_count() -> int:
r"""
Gets the number of available GPU or NPU devices.
"""
r"""Get the number of available GPU or NPU devices."""
if is_torch_xpu_available():
return torch.xpu.device_count()
elif is_torch_npu_available():
@@ -180,18 +167,14 @@ def get_device_count() -> int:


def get_logits_processor() -> "LogitsProcessorList":
r"""
Gets logits processor that removes NaN and Inf logits.
"""
r"""Get logits processor that removes NaN and Inf logits."""
logits_processor = LogitsProcessorList()
logits_processor.append(InfNanRemoveLogitsProcessor())
return logits_processor


def get_peak_memory() -> Tuple[int, int]:
r"""
Gets the peak memory usage for the current device (in Bytes).
"""
def get_peak_memory() -> tuple[int, int]:
r"""Get the peak memory usage for the current device (in Bytes)."""
if is_torch_npu_available():
return torch.npu.max_memory_allocated(), torch.npu.max_memory_reserved()
elif is_torch_cuda_available():
@@ -201,16 +184,12 @@ def get_peak_memory() -> Tuple[int, int]:


def has_tokenized_data(path: "os.PathLike") -> bool:
r"""
Checks if the path has a tokenized dataset.
"""
r"""Check if the path has a tokenized dataset."""
return os.path.isdir(path) and len(os.listdir(path)) > 0


def infer_optim_dtype(model_dtype: "torch.dtype") -> "torch.dtype":
r"""
Infers the optimal dtype according to the model_dtype and device compatibility.
"""
r"""Infer the optimal dtype according to the model_dtype and device compatibility."""
if _is_bf16_available and model_dtype == torch.bfloat16:
return torch.bfloat16
elif _is_fp16_available:
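The dtype fallback chain being edited here is easy to restate; a hedged sketch, with plain flags standing in for the module-level `_is_bf16_available`/`_is_fp16_available` probes:

```python
import torch

def pick_dtype(model_dtype: torch.dtype, bf16_ok: bool, fp16_ok: bool) -> torch.dtype:
    if bf16_ok and model_dtype == torch.bfloat16:
        return torch.bfloat16  # keep bf16 checkpoints in bf16 when supported
    if fp16_ok:
        return torch.float16  # otherwise prefer fp16 if the device allows it
    return torch.float32  # last resort: full precision
```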
@@ -220,23 +199,17 @@ def infer_optim_dtype(model_dtype: "torch.dtype") -> "torch.dtype":


def is_gpu_or_npu_available() -> bool:
r"""
Checks if the GPU or NPU is available.
"""
r"""Check if the GPU or NPU is available."""
return is_torch_npu_available() or is_torch_cuda_available()


def is_env_enabled(env_var: str, default: str = "0") -> bool:
r"""
Checks if the environment variable is enabled.
"""
r"""Check if the environment variable is enabled."""
return os.getenv(env_var, default).lower() in ["true", "y", "1"]
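`is_env_enabled` is small enough to demo inline; note the case-insensitive match and the string default:

```python
import os

def is_env_enabled(env_var: str, default: str = "0") -> bool:
    return os.getenv(env_var, default).lower() in ["true", "y", "1"]

os.environ["ALLOW_EXTRA_ARGS"] = "Y"
assert is_env_enabled("ALLOW_EXTRA_ARGS")  # "true", "y" and "1" all count
assert not is_env_enabled("UNSET_FLAG")    # unset -> default "0" -> disabled
```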


def numpify(inputs: Union["NDArray", "torch.Tensor"]) -> "NDArray":
r"""
Casts a torch tensor or a numpy array to a numpy array.
"""
r"""Cast a torch tensor or a numpy array to a numpy array."""
if isinstance(inputs, torch.Tensor):
inputs = inputs.cpu()
if inputs.dtype == torch.bfloat16: # numpy does not support bfloat16 until 1.21.4
@@ -248,17 +221,13 @@ def numpify(inputs: Union["NDArray", "torch.Tensor"]) -> "NDArray":


def skip_check_imports() -> None:
r"""
Avoids flash attention import error in custom model files.
"""
r"""Avoid flash attention import error in custom model files."""
if not is_env_enabled("FORCE_CHECK_IMPORTS"):
transformers.dynamic_module_utils.check_imports = get_relative_imports


def torch_gc() -> None:
r"""
Collects GPU or NPU memory.
"""
r"""Collect GPU or NPU memory."""
gc.collect()
if is_torch_xpu_available():
torch.xpu.empty_cache()

@@ -15,7 +15,7 @@
import json
import math
import os
from typing import Any, Dict, List
from typing import Any

from transformers.trainer import TRAINER_STATE_NAME

@@ -31,10 +31,8 @@ if is_matplotlib_available():
logger = logging.get_logger(__name__)


def smooth(scalars: List[float]) -> List[float]:
r"""
EMA implementation according to TensorBoard.
"""
def smooth(scalars: list[float]) -> list[float]:
r"""EMA implementation according to TensorBoard."""
if len(scalars) == 0:
return []

@@ -48,10 +46,8 @@ def smooth(scalars: List[float]) -> List[float]:
return smoothed
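`smooth` implements TensorBoard-style exponential smoothing; a minimal sketch of the recurrence (the repository derives the weight from the sequence length, here it is a fixed assumption):

```python
def ema(scalars: list[float], weight: float = 0.9) -> list[float]:
    if len(scalars) == 0:
        return []
    smoothed, last = [], scalars[0]
    for point in scalars:
        last = last * weight + (1 - weight) * point  # exponential moving average
        smoothed.append(last)
    return smoothed

assert ema([1.0, 1.0, 1.0]) == [1.0, 1.0, 1.0]  # constant input is a fixed point
```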


def gen_loss_plot(trainer_log: List[Dict[str, Any]]) -> "matplotlib.figure.Figure":
r"""
Plots loss curves in LlamaBoard.
"""
def gen_loss_plot(trainer_log: list[dict[str, Any]]) -> "matplotlib.figure.Figure":
r"""Plot loss curves in LlamaBoard."""
plt.close("all")
plt.switch_backend("agg")
fig = plt.figure()
@@ -70,10 +66,8 @@ def gen_loss_plot(trainer_log: List[Dict[str, Any]]) -> "matplotlib.figure.Figur
return fig


def plot_loss(save_dictionary: str, keys: List[str] = ["loss"]) -> None:
r"""
Plots loss curves and saves the image.
"""
def plot_loss(save_dictionary: str, keys: list[str] = ["loss"]) -> None:
r"""Plot loss curves and save the image."""
plt.switch_backend("agg")
with open(os.path.join(save_dictionary, TRAINER_STATE_NAME), encoding="utf-8") as f:
data = json.load(f)

@@ -16,14 +16,12 @@
# limitations under the License.

from dataclasses import asdict, dataclass, field
from typing import Any, Dict, Literal, Optional
from typing import Any, Literal, Optional


@dataclass
class DataArguments:
r"""
Arguments pertaining to what data we are going to input our model for training and evaluation.
"""
r"""Arguments pertaining to what data we are going to input our model for training and evaluation."""

template: Optional[str] = field(
default=None,
@@ -162,5 +160,5 @@ class DataArguments:
if self.mask_history and self.train_on_prompt:
raise ValueError("`mask_history` is incompatible with `train_on_prompt`.")

def to_dict(self) -> Dict[str, Any]:
def to_dict(self) -> dict[str, Any]:
return asdict(self)

@@ -21,9 +21,7 @@ from datasets import DownloadMode

@dataclass
class EvaluationArguments:
r"""
Arguments pertaining to specify the evaluation parameters.
"""
r"""Arguments pertaining to the evaluation parameters."""

task: str = field(
metadata={"help": "Name of the evaluation task."},

@@ -13,14 +13,12 @@
# limitations under the License.

from dataclasses import asdict, dataclass, field
from typing import Any, Dict, List, Literal, Optional
from typing import Any, Literal, Optional


@dataclass
class FreezeArguments:
r"""
Arguments pertaining to the freeze (partial-parameter) training.
"""
r"""Arguments pertaining to the freeze (partial-parameter) training."""

freeze_trainable_layers: int = field(
default=2,
@@ -56,9 +54,7 @@ class FreezeArguments:

@dataclass
class LoraArguments:
r"""
Arguments pertaining to the LoRA training.
"""
r"""Arguments pertaining to the LoRA training."""

additional_target: Optional[str] = field(
default=None,
@@ -128,9 +124,7 @@ class LoraArguments:

@dataclass
class RLHFArguments:
r"""
Arguments pertaining to the PPO, DPO and KTO training.
"""
r"""Arguments pertaining to the PPO, DPO and KTO training."""

pref_beta: float = field(
default=0.1,
@@ -212,9 +206,7 @@ class RLHFArguments:

@dataclass
class GaloreArguments:
r"""
Arguments pertaining to the GaLore algorithm.
"""
r"""Arguments pertaining to the GaLore algorithm."""

use_galore: bool = field(
default=False,
@@ -253,9 +245,7 @@ class GaloreArguments:

@dataclass
class ApolloArguments:
r"""
Arguments pertaining to the APOLLO algorithm.
"""
r"""Arguments pertaining to the APOLLO algorithm."""

use_apollo: bool = field(
default=False,
@@ -306,9 +296,7 @@ class ApolloArguments:

@dataclass
class BAdamArgument:
r"""
Arguments pertaining to the BAdam optimizer.
"""
r"""Arguments pertaining to the BAdam optimizer."""

use_badam: bool = field(
default=False,
@@ -393,9 +381,7 @@ class SwanLabArguments:
class FinetuningArguments(
SwanLabArguments, BAdamArgument, ApolloArguments, GaloreArguments, RLHFArguments, LoraArguments, FreezeArguments
):
r"""
Arguments pertaining to which techniques we are going to fine-tuning with.
"""
r"""Arguments pertaining to which techniques we are going to fine-tune with."""

pure_bf16: bool = field(
default=False,
@@ -452,13 +438,13 @@ class FinetuningArguments(
return [item.strip() for item in arg.split(",")]
return arg

self.freeze_trainable_modules: List[str] = split_arg(self.freeze_trainable_modules)
self.freeze_extra_modules: Optional[List[str]] = split_arg(self.freeze_extra_modules)
self.freeze_trainable_modules: list[str] = split_arg(self.freeze_trainable_modules)
self.freeze_extra_modules: Optional[list[str]] = split_arg(self.freeze_extra_modules)
self.lora_alpha: int = self.lora_alpha or self.lora_rank * 2
self.lora_target: List[str] = split_arg(self.lora_target)
self.additional_target: Optional[List[str]] = split_arg(self.additional_target)
self.galore_target: List[str] = split_arg(self.galore_target)
self.apollo_target: List[str] = split_arg(self.apollo_target)
self.lora_target: list[str] = split_arg(self.lora_target)
self.additional_target: Optional[list[str]] = split_arg(self.additional_target)
self.galore_target: list[str] = split_arg(self.galore_target)
self.apollo_target: list[str] = split_arg(self.apollo_target)
self.use_ref_model = self.stage == "dpo" and self.pref_loss not in ["orpo", "simpo"]

assert self.finetuning_type in ["lora", "freeze", "full"], "Invalid fine-tuning method."
@@ -499,7 +485,7 @@ class FinetuningArguments(
if self.pissa_init:
raise ValueError("`pissa_init` is only valid for LoRA training.")

def to_dict(self) -> Dict[str, Any]:
def to_dict(self) -> dict[str, Any]:
args = asdict(self)
args = {k: f"<{k.upper()}>" if k.endswith("api_key") else v for k, v in args.items()}
return args

@@ -13,16 +13,14 @@
# limitations under the License.

from dataclasses import asdict, dataclass, field
from typing import Any, Dict, Optional
from typing import Any, Optional

from transformers import GenerationConfig


@dataclass
class GeneratingArguments:
r"""
Arguments pertaining to specify the decoding parameters.
"""
r"""Arguments pertaining to the decoding parameters."""

do_sample: bool = field(
default=True,
@@ -35,7 +33,9 @@ class GeneratingArguments:
top_p: float = field(
default=0.7,
metadata={
"help": "The smallest set of most probable tokens with probabilities that add up to top_p or higher are kept."
"help": (
"The smallest set of most probable tokens with probabilities that add up to top_p or higher are kept."
)
},
)
top_k: int = field(
@@ -71,7 +71,7 @@ class GeneratingArguments:
metadata={"help": "Whether or not to remove special tokens in the decoding."},
)

def to_dict(self, obey_generation_config: bool = False) -> Dict[str, Any]:
def to_dict(self, obey_generation_config: bool = False) -> dict[str, Any]:
args = asdict(self)
if args.get("max_new_tokens", -1) > 0:
args.pop("max_length", None)

@@ -17,7 +17,7 @@

import json
from dataclasses import asdict, dataclass, field, fields
from typing import Any, Dict, Literal, Optional, Union
from typing import Any, Literal, Optional, Union

import torch
from transformers.training_args import _convert_str_dict
@@ -28,9 +28,7 @@ from ..extras.constants import AttentionFunction, EngineName, RopeScaling

@dataclass
class BaseModelArguments:
r"""
Arguments pertaining to the model.
"""
r"""Arguments pertaining to the model."""

model_name_or_path: Optional[str] = field(
default=None,
@@ -184,9 +182,7 @@ class BaseModelArguments:

@dataclass
class QuantizationArguments:
r"""
Arguments pertaining to the quantization method.
"""
r"""Arguments pertaining to the quantization method."""

quantization_method: Literal["bitsandbytes", "hqq", "eetq"] = field(
default="bitsandbytes",
@@ -212,9 +208,7 @@ class QuantizationArguments:

@dataclass
class ProcessorArguments:
r"""
Arguments pertaining to the image processor.
"""
r"""Arguments pertaining to the image processor."""

image_max_pixels: int = field(
default=768 * 768,
@@ -244,9 +238,7 @@ class ProcessorArguments:

@dataclass
class ExportArguments:
r"""
Arguments pertaining to the model export.
"""
r"""Arguments pertaining to the model export."""

export_dir: Optional[str] = field(
default=None,
@@ -292,9 +284,7 @@ class ExportArguments:

@dataclass
class VllmArguments:
r"""
Arguments pertaining to the vLLM worker.
"""
r"""Arguments pertaining to the vLLM worker."""

vllm_maxlen: int = field(
default=4096,
@@ -324,8 +314,7 @@ class VllmArguments:

@dataclass
class ModelArguments(VllmArguments, ExportArguments, ProcessorArguments, QuantizationArguments, BaseModelArguments):
r"""
Arguments pertaining to which model/config/tokenizer we are going to fine-tune or infer.
r"""Arguments pertaining to which model/config/tokenizer we are going to fine-tune or infer.

The rightmost class will be displayed first.
"""
@@ -335,7 +324,7 @@ class ModelArguments(VllmArguments, ExportArguments, ProcessorArguments, Quantiz
init=False,
metadata={"help": "Torch data type for computing model outputs, derived from `fp/bf16`. Do not specify it."},
)
device_map: Optional[Union[str, Dict[str, Any]]] = field(
device_map: Optional[Union[str, dict[str, Any]]] = field(
default=None,
init=False,
metadata={"help": "Device map for model placement, derived from training stage. Do not specify it."},
@@ -372,7 +361,7 @@ class ModelArguments(VllmArguments, ExportArguments, ProcessorArguments, Quantiz

return result

def to_dict(self) -> Dict[str, Any]:
def to_dict(self) -> dict[str, Any]:
args = asdict(self)
args = {k: f"<{k.upper()}>" if k.endswith("token") else v for k, v in args.items()}
return args

@@ -19,7 +19,7 @@ import json
import os
import sys
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple, Union
from typing import Any, Optional, Union

import torch
import transformers
@@ -47,17 +47,15 @@ check_dependencies()


_TRAIN_ARGS = [ModelArguments, DataArguments, TrainingArguments, FinetuningArguments, GeneratingArguments]
_TRAIN_CLS = Tuple[ModelArguments, DataArguments, TrainingArguments, FinetuningArguments, GeneratingArguments]
_TRAIN_CLS = tuple[ModelArguments, DataArguments, TrainingArguments, FinetuningArguments, GeneratingArguments]
_INFER_ARGS = [ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments]
_INFER_CLS = Tuple[ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments]
_INFER_CLS = tuple[ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments]
_EVAL_ARGS = [ModelArguments, DataArguments, EvaluationArguments, FinetuningArguments]
_EVAL_CLS = Tuple[ModelArguments, DataArguments, EvaluationArguments, FinetuningArguments]
_EVAL_CLS = tuple[ModelArguments, DataArguments, EvaluationArguments, FinetuningArguments]


def read_args(args: Optional[Union[Dict[str, Any], List[str]]] = None) -> Union[Dict[str, Any], List[str]]:
r"""
Gets arguments from the command line or a config file.
"""
def read_args(args: Optional[Union[dict[str, Any], list[str]]] = None) -> Union[dict[str, Any], list[str]]:
r"""Get arguments from the command line or a config file."""
if args is not None:
return args

@@ -70,8 +68,8 @@ def read_args(args: Optional[Union[Dict[str, Any], List[str]]] = None) -> Union[


def _parse_args(
parser: "HfArgumentParser", args: Optional[Union[Dict[str, Any], List[str]]] = None, allow_extra_keys: bool = False
) -> Tuple[Any]:
parser: "HfArgumentParser", args: Optional[Union[dict[str, Any], list[str]]] = None, allow_extra_keys: bool = False
) -> tuple[Any]:
args = read_args(args)
if isinstance(args, dict):
return parser.parse_dict(args, allow_extra_keys=allow_extra_keys)
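The dict-vs-argv dispatch above maps onto two real `HfArgumentParser` entry points; a hedged usage sketch with a hypothetical dataclass:

```python
from dataclasses import dataclass, field

from transformers import HfArgumentParser

@dataclass
class DemoArguments:  # hypothetical, for illustration only
    learning_rate: float = field(default=5e-5)

parser = HfArgumentParser(DemoArguments)
(from_dict,) = parser.parse_dict({"learning_rate": 1e-4})
(from_cli,) = parser.parse_args_into_dataclasses(["--learning_rate", "1e-4"])
assert from_dict.learning_rate == from_cli.learning_rate == 1e-4
```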
@@ -161,31 +159,31 @@ def _check_extra_dependencies(
check_version("rouge_chinese", mandatory=True)


def _parse_train_args(args: Optional[Union[Dict[str, Any], List[str]]] = None) -> _TRAIN_CLS:
def _parse_train_args(args: Optional[Union[dict[str, Any], list[str]]] = None) -> _TRAIN_CLS:
parser = HfArgumentParser(_TRAIN_ARGS)
allow_extra_keys = is_env_enabled("ALLOW_EXTRA_ARGS")
return _parse_args(parser, args, allow_extra_keys=allow_extra_keys)


def _parse_infer_args(args: Optional[Union[Dict[str, Any], List[str]]] = None) -> _INFER_CLS:
def _parse_infer_args(args: Optional[Union[dict[str, Any], list[str]]] = None) -> _INFER_CLS:
parser = HfArgumentParser(_INFER_ARGS)
allow_extra_keys = is_env_enabled("ALLOW_EXTRA_ARGS")
return _parse_args(parser, args, allow_extra_keys=allow_extra_keys)


def _parse_eval_args(args: Optional[Union[Dict[str, Any], List[str]]] = None) -> _EVAL_CLS:
def _parse_eval_args(args: Optional[Union[dict[str, Any], list[str]]] = None) -> _EVAL_CLS:
parser = HfArgumentParser(_EVAL_ARGS)
allow_extra_keys = is_env_enabled("ALLOW_EXTRA_ARGS")
return _parse_args(parser, args, allow_extra_keys=allow_extra_keys)


def get_ray_args(args: Optional[Union[Dict[str, Any], List[str]]] = None) -> RayArguments:
def get_ray_args(args: Optional[Union[dict[str, Any], list[str]]] = None) -> RayArguments:
parser = HfArgumentParser(RayArguments)
(ray_args,) = _parse_args(parser, args, allow_extra_keys=True)
return ray_args


def get_train_args(args: Optional[Union[Dict[str, Any], List[str]]] = None) -> _TRAIN_CLS:
def get_train_args(args: Optional[Union[dict[str, Any], list[str]]] = None) -> _TRAIN_CLS:
model_args, data_args, training_args, finetuning_args, generating_args = _parse_train_args(args)

# Setup logging
@@ -364,9 +362,7 @@ def get_train_args(args: Optional[Union[Dict[str, Any], List[str]]] = None) -> _
and training_args.resume_from_checkpoint is not None
):
logger.warning_rank0(
"Add {} to `adapter_name_or_path` to resume training from checkpoint.".format(
training_args.resume_from_checkpoint
)
f"Add {training_args.resume_from_checkpoint} to `adapter_name_or_path` to resume training from checkpoint."
)

# Post-process model arguments
@@ -382,20 +378,17 @@ def get_train_args(args: Optional[Union[Dict[str, Any], List[str]]] = None) -> _

# Log on each process the small summary
logger.info(
"Process rank: {}, world size: {}, device: {}, distributed training: {}, compute dtype: {}".format(
training_args.process_index,
training_args.world_size,
training_args.device,
training_args.parallel_mode == ParallelMode.DISTRIBUTED,
str(model_args.compute_dtype),
)
f"Process rank: {training_args.process_index}, "
f"world size: {training_args.world_size}, device: {training_args.device}, "
f"distributed training: {training_args.parallel_mode == ParallelMode.DISTRIBUTED}, "
f"compute dtype: {str(model_args.compute_dtype)}"
)
transformers.set_seed(training_args.seed)

return model_args, data_args, training_args, finetuning_args, generating_args


def get_infer_args(args: Optional[Union[Dict[str, Any], List[str]]] = None) -> _INFER_CLS:
def get_infer_args(args: Optional[Union[dict[str, Any], list[str]]] = None) -> _INFER_CLS:
model_args, data_args, finetuning_args, generating_args = _parse_infer_args(args)

_set_transformers_logging()
@@ -426,7 +419,7 @@ def get_infer_args(args: Optional[Union[Dict[str, Any], List[str]]] = None) -> _
return model_args, data_args, finetuning_args, generating_args


def get_eval_args(args: Optional[Union[Dict[str, Any], List[str]]] = None) -> _EVAL_CLS:
def get_eval_args(args: Optional[Union[dict[str, Any], list[str]]] = None) -> _EVAL_CLS:
model_args, data_args, eval_args, finetuning_args = _parse_eval_args(args)

_set_transformers_logging()

@@ -10,9 +10,7 @@ from ..extras.misc import use_ray

@dataclass
class RayArguments:
r"""
Arguments pertaining to the Ray training.
"""
r"""Arguments pertaining to the Ray training."""

ray_run_name: Optional[str] = field(
default=None,
@@ -43,9 +41,7 @@ class RayArguments:

@dataclass
class TrainingArguments(RayArguments, Seq2SeqTrainingArguments):
r"""
Arguments pertaining to the trainer.
"""
r"""Arguments pertaining to the trainer."""

def __post_init__(self):
Seq2SeqTrainingArguments.__post_init__(self)

@@ -20,9 +20,9 @@ from .model_utils.valuehead import load_valuehead_params

__all__ = [
"QuantizationMethod",
"find_all_linear_modules",
"load_config",
"load_model",
"load_tokenizer",
"find_all_linear_modules",
"load_valuehead_params",
]

@@ -81,9 +81,8 @@ def _setup_freeze_tuning(
if finetuning_args.use_llama_pro:
if num_layers % finetuning_args.freeze_trainable_layers != 0:
raise ValueError(
"`num_layers` {} should be divisible by `num_layer_trainable` {}.".format(
num_layers, finetuning_args.freeze_trainable_layers
)
f"`num_layers` {num_layers} should be "
f"divisible by `num_layer_trainable` {finetuning_args.freeze_trainable_layers}."
)

stride = num_layers // finetuning_args.freeze_trainable_layers
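Worked arithmetic for the divisibility check in this hunk (values are hypothetical):

```python
num_layers, freeze_trainable_layers = 32, 2  # hypothetical LLaMA-Pro-style setup
assert num_layers % freeze_trainable_layers == 0  # else the ValueError above fires
stride = num_layers // freeze_trainable_layers  # 16: one trainable layer per block of 16
```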
@@ -178,7 +177,7 @@ def _setup_lora_tuning(
}

for adapter in adapter_to_merge:
model: "LoraModel" = PeftModel.from_pretrained(model, adapter, **init_kwargs)
model: LoraModel = PeftModel.from_pretrained(model, adapter, **init_kwargs)
model = model.merge_and_unload()

if len(adapter_to_merge) > 0:
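The merge loop in this hunk folds each adapter back into the base weights before the next one is applied; a hedged standalone sketch (model and adapter ids are placeholders):

```python
from peft import PeftModel
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("path/to/base-model")  # placeholder
for adapter in ["path/to/adapter-a", "path/to/adapter-b"]:  # placeholders
    model = PeftModel.from_pretrained(model, adapter)
    model = model.merge_and_unload()  # bake the LoRA deltas into the base weights
```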
@@ -263,8 +262,7 @@ def init_adapter(
finetuning_args: "FinetuningArguments",
is_trainable: bool,
) -> "PreTrainedModel":
r"""
Initializes the adapters.
r"""Initialize the adapters.

Support full-parameter, freeze and LoRA training.


@@ -13,7 +13,7 @@
# limitations under the License.

import os
from typing import TYPE_CHECKING, Any, Dict, Optional, TypedDict
from typing import TYPE_CHECKING, Any, Optional, TypedDict

import torch
from transformers import (
@@ -51,9 +51,8 @@ class TokenizerModule(TypedDict):
processor: Optional["ProcessorMixin"]


def _get_init_kwargs(model_args: "ModelArguments") -> Dict[str, Any]:
r"""
Gets arguments to load config/tokenizer/model.
def _get_init_kwargs(model_args: "ModelArguments") -> dict[str, Any]:
r"""Get arguments to load config/tokenizer/model.

Note: including inplace operation of model_args.
"""
@@ -68,8 +67,7 @@ def _get_init_kwargs(model_args: "ModelArguments") -> Dict[str, Any]:


def load_tokenizer(model_args: "ModelArguments") -> "TokenizerModule":
r"""
Loads pretrained tokenizer and optionally loads processor.
r"""Load pretrained tokenizer and optionally load processor.

Note: including inplace operation of model_args.
"""
@@ -110,9 +108,7 @@ def load_tokenizer(model_args: "ModelArguments") -> "TokenizerModule":


def load_config(model_args: "ModelArguments") -> "PretrainedConfig":
r"""
Loads model config.
"""
r"""Load model config."""
init_kwargs = _get_init_kwargs(model_args)
return AutoConfig.from_pretrained(model_args.model_name_or_path, **init_kwargs)

@@ -124,9 +120,7 @@ def load_model(
is_trainable: bool = False,
add_valuehead: bool = False,
) -> "PreTrainedModel":
r"""
Loads pretrained model.
"""
r"""Load pretrained model."""
init_kwargs = _get_init_kwargs(model_args)
config = load_config(model_args)
patch_config(config, tokenizer, model_args, init_kwargs, is_trainable)
@@ -194,8 +188,9 @@ def load_model(

trainable_params, all_param = count_parameters(model)
if is_trainable:
param_stats = "trainable params: {:,} || all params: {:,} || trainable%: {:.4f}".format(
trainable_params, all_param, 100 * trainable_params / all_param
param_stats = (
f"trainable params: {trainable_params:,} || "
f"all params: {all_param:,} || trainable%: {100 * trainable_params / all_param:.4f}"
)
else:
param_stats = f"all params: {all_param:,}"

@@ -21,7 +21,7 @@
import inspect
from functools import WRAPPER_ASSIGNMENTS, partial, wraps
from types import MethodType
from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Tuple, Union
from typing import TYPE_CHECKING, Any, Callable, Optional, Union

import torch

@@ -40,9 +40,7 @@ logger = logging.get_logger(__name__)

def get_unsloth_gradient_checkpointing_func() -> Callable:
class UnslothGradientCheckpointing(torch.autograd.Function):
r"""
Saves VRAM by smartly offloading to RAM.
"""
r"""Saves VRAM by smartly offloading to RAM."""

@staticmethod
@torch.cuda.amp.custom_fwd
@@ -77,13 +75,11 @@ def get_unsloth_gradient_checkpointing_func() -> Callable:


def get_custom_gradient_checkpointing_func(gradient_checkpointing_func: Callable) -> Callable:
r"""
Only applies gradient checkpointing to trainable layers.
"""
r"""Only applies gradient checkpointing to trainable layers."""

@wraps(gradient_checkpointing_func, assigned=WRAPPER_ASSIGNMENTS + ("__self__",))
def custom_gradient_checkpointing_func(func: Callable, *args: Union["torch.Tensor", Any], **kwargs):
module: "torch.nn.Module" = func.__self__
module: torch.nn.Module = func.__self__

has_grad = False
if any(param.requires_grad for param in module.parameters()):
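The `requires_grad` probe above is the heart of the wrapper: recomputation is only worth paying for on modules with trainable parameters. A minimal sketch of that idea using the standard checkpoint API (not this repository's exact code):

```python
import torch
from torch.utils.checkpoint import checkpoint

def maybe_checkpoint(module: torch.nn.Module, *args: torch.Tensor) -> torch.Tensor:
    if any(param.requires_grad for param in module.parameters()):
        return checkpoint(module, *args, use_reentrant=False)  # recompute in backward
    return module(*args)  # frozen module: plain forward, no checkpointing overhead
```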
@@ -103,11 +99,10 @@ def get_custom_gradient_checkpointing_func(gradient_checkpointing_func: Callable

def _gradient_checkpointing_enable(
self: "PreTrainedModel",
gradient_checkpointing_kwargs: Optional[Dict[str, Any]] = None,
gradient_checkpointing_kwargs: Optional[dict[str, Any]] = None,
use_unsloth_gc: bool = False,
) -> None:
r"""
Activates gradient checkpointing for the current model.
r"""Activates gradient checkpointing for the current model.

Modification of the original method to enable gradient checkpointing for block-wise optimizer.
"""
@@ -134,17 +129,18 @@ def _gradient_checkpointing_enable(


def _fp32_forward_post_hook(
module: "torch.nn.Module", args: Tuple["torch.Tensor"], output: "torch.Tensor"
module: "torch.nn.Module", args: tuple["torch.Tensor"], output: "torch.Tensor"
) -> "torch.Tensor":
return output.to(torch.float32)
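A hedged usage sketch for the hook above: registering it on an output head upcasts logits to fp32, which helps loss numerics under bf16/fp16 training:

```python
import torch

lm_head = torch.nn.Linear(16, 32).bfloat16()  # stand-in for a model's lm_head
lm_head.register_forward_hook(lambda module, args, output: output.to(torch.float32))
logits = lm_head(torch.randn(2, 16).bfloat16())
assert logits.dtype == torch.float32  # the hook upcast the half-precision output
```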


def prepare_model_for_training(model: "PreTrainedModel", model_args: "ModelArguments") -> None:
r"""
Includes:
(1) cast the layernorm in fp32
(2) make output embedding layer require grads
(3) add the upcasting of the lm_head in fp32
r"""Prepare the model before training.

Include:
(1) cast the layernorm in fp32
(2) make output embedding layer require grads
(3) add the upcasting of the lm_head in fp32.
"""
if model_args.upcast_layernorm:
logger.info_rank0("Upcasting layernorm weights in float32.")

@@ -38,9 +38,7 @@ def _noisy_mean_initialization(embed_weight: "torch.Tensor", num_new_tokens: int


def resize_embedding_layer(model: "PreTrainedModel", tokenizer: "PreTrainedTokenizer") -> None:
r"""
Resize token embeddings.
"""
r"""Resize token embeddings."""
if is_deepspeed_zero3_enabled():
import deepspeed # type: ignore


@@ -18,7 +18,7 @@
# limitations under the License.

import math
from typing import TYPE_CHECKING, Optional, Tuple
from typing import TYPE_CHECKING, Optional

import torch
import torch.nn as nn
@@ -54,14 +54,14 @@ def llama_attention_forward(
past_key_value: Optional["Cache"] = None,
output_attentions: bool = False,
cache_position: Optional["torch.LongTensor"] = None,
position_embeddings: Optional[Tuple["torch.Tensor", "torch.Tensor"]] = None,
position_embeddings: Optional[tuple["torch.Tensor", "torch.Tensor"]] = None,
**kwargs,
) -> Tuple["torch.Tensor", Optional["torch.Tensor"], Optional[Tuple["torch.Tensor"]]]:
) -> tuple["torch.Tensor", Optional["torch.Tensor"], Optional[tuple["torch.Tensor"]]]:
bsz, q_len, _ = hidden_states.size()

query_states: "torch.Tensor" = self.q_proj(hidden_states)
key_states: "torch.Tensor" = self.k_proj(hidden_states)
value_states: "torch.Tensor" = self.v_proj(hidden_states)
query_states: torch.Tensor = self.q_proj(hidden_states)
key_states: torch.Tensor = self.k_proj(hidden_states)
value_states: torch.Tensor = self.v_proj(hidden_states)

query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
@@ -139,17 +139,17 @@ def llama_flash_attention_2_forward(
past_key_value: Optional["Cache"] = None,
output_attentions: bool = False,
cache_position: Optional["torch.LongTensor"] = None,
position_embeddings: Optional[Tuple["torch.Tensor", "torch.Tensor"]] = None,
position_embeddings: Optional[tuple["torch.Tensor", "torch.Tensor"]] = None,
**kwargs,
) -> Tuple["torch.Tensor", Optional["torch.Tensor"], Optional[Tuple["torch.Tensor"]]]:
) -> tuple["torch.Tensor", Optional["torch.Tensor"], Optional[tuple["torch.Tensor"]]]:
# LlamaFlashAttention2 attention does not support output_attentions
output_attentions = False

bsz, q_len, _ = hidden_states.size()

query_states: "torch.Tensor" = self.q_proj(hidden_states)
key_states: "torch.Tensor" = self.k_proj(hidden_states)
value_states: "torch.Tensor" = self.v_proj(hidden_states)
query_states: torch.Tensor = self.q_proj(hidden_states)
key_states: torch.Tensor = self.k_proj(hidden_states)
value_states: torch.Tensor = self.v_proj(hidden_states)

query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
@@ -209,7 +209,7 @@ def llama_flash_attention_2_forward(
if is_transformers_version_greater_than("4.43.0"):
from transformers.modeling_flash_attention_utils import _flash_attention_forward

attn_output: "torch.Tensor" = _flash_attention_forward(
attn_output: torch.Tensor = _flash_attention_forward(
query_states,
key_states,
value_states,
@@ -221,7 +221,7 @@ def llama_flash_attention_2_forward(
is_causal=self.is_causal,
)
else:
attn_output: "torch.Tensor" = self._flash_attention_forward(
attn_output: torch.Tensor = self._flash_attention_forward(
query_states, key_states, value_states, attention_mask, query_states.size(1), dropout=dropout_rate
)

@@ -254,9 +254,9 @@ def llama_sdpa_attention_forward(
past_key_value: Optional["Cache"] = None,
output_attentions: bool = False,
cache_position: Optional["torch.LongTensor"] = None,
position_embeddings: Optional[Tuple["torch.Tensor", "torch.Tensor"]] = None,
position_embeddings: Optional[tuple["torch.Tensor", "torch.Tensor"]] = None,
**kwargs,
) -> Tuple["torch.Tensor", Optional["torch.Tensor"], Optional[Tuple["torch.Tensor"]]]:
) -> tuple["torch.Tensor", Optional["torch.Tensor"], Optional[tuple["torch.Tensor"]]]:
if output_attentions:
transformers_logger.warning_once(
"SDPA does not support `output_attentions=True`. Falling back to the vanilla attention"
@@ -274,9 +274,9 @@ def llama_sdpa_attention_forward(

bsz, q_len, _ = hidden_states.size()

query_states: "torch.Tensor" = self.q_proj(hidden_states)
key_states: "torch.Tensor" = self.k_proj(hidden_states)
value_states: "torch.Tensor" = self.v_proj(hidden_states)
query_states: torch.Tensor = self.q_proj(hidden_states)
key_states: torch.Tensor = self.k_proj(hidden_states)
value_states: torch.Tensor = self.v_proj(hidden_states)

query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import TYPE_CHECKING, List
from typing import TYPE_CHECKING

from ...extras import logging
from .visual import COMPOSITE_MODELS
@@ -25,10 +25,8 @@ if TYPE_CHECKING:
logger = logging.get_logger(__name__)


def find_all_linear_modules(model: "PreTrainedModel", freeze_vision_tower: bool) -> List[str]:
r"""
Finds all available modules to apply LoRA, GaLore or APOLLO.
"""
def find_all_linear_modules(model: "PreTrainedModel", freeze_vision_tower: bool) -> list[str]:
r"""Find all available modules to apply LoRA, GaLore or APOLLO."""
model_type = getattr(model.config, "model_type", None)
forbidden_modules = {"lm_head"}
if model_type == "chatglm":
@@ -54,10 +52,8 @@ def find_all_linear_modules(model: "PreTrainedModel", freeze_vision_tower: bool)
return list(module_names)


def find_expanded_modules(model: "PreTrainedModel", target_modules: List[str], num_layer_trainable: int) -> List[str]:
r"""
Finds the modules in the expanded blocks to apply lora.
"""
def find_expanded_modules(model: "PreTrainedModel", target_modules: list[str], num_layer_trainable: int) -> list[str]:
r"""Find the modules in the expanded blocks to apply lora."""
num_layers = getattr(model.config, "num_hidden_layers", None)
if not num_layers:
raise ValueError("Model was not supported.")

@@ -12,7 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import TYPE_CHECKING, Sequence
from collections.abc import Sequence
from typing import TYPE_CHECKING

import torch
from transformers.integrations import is_deepspeed_zero3_enabled
@@ -34,9 +35,7 @@ def _set_z3_leaf_modules(model: "PreTrainedModel", leaf_modules: Sequence["torch


def add_z3_leaf_module(model: "PreTrainedModel") -> None:
r"""
Sets module as a leaf module to skip partitioning in deepspeed zero3.
"""
r"""Set module as a leaf module to skip partitioning in deepspeed zero3."""
if not is_deepspeed_zero3_enabled():
return


@@ -37,7 +37,7 @@
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

from typing import TYPE_CHECKING, Tuple
from typing import TYPE_CHECKING

import torch
import torch.nn.functional as F
@@ -59,8 +59,7 @@ logger = logging.get_logger(__name__)


def get_seqlens_in_batch(attention_mask: "torch.Tensor") -> "torch.Tensor":
r"""
Gets the sequence lengths in the current batch.
r"""Get the sequence lengths in the current batch.

e.g.
```python
@@ -76,7 +75,7 @@ def get_seqlens_in_batch(attention_mask: "torch.Tensor") -> "torch.Tensor":
bsz = attention_mask.size(0)
dtype, device = attention_mask.dtype, attention_mask.device
max_num = torch.max(attention_mask).item()
counts: "torch.Tensor" = torch.zeros((bsz, max_num), dtype=dtype, device=device)
counts: torch.Tensor = torch.zeros((bsz, max_num), dtype=dtype, device=device)
for i in range(max_num):
counts[:, i] = torch.sum(attention_mask == (i + 1), dim=-1)

@@ -85,9 +84,8 @@ def get_seqlens_in_batch(attention_mask: "torch.Tensor") -> "torch.Tensor":
return seqlens


def get_unpad_data(attention_mask: "torch.Tensor") -> Tuple["torch.Tensor", "torch.Tensor", int]:
r"""
Prepares the indices and seqlens for flash attn varlen function.
def get_unpad_data(attention_mask: "torch.Tensor") -> tuple["torch.Tensor", "torch.Tensor", int]:
r"""Prepare the indices and seqlens for flash attn varlen function.

Returns:
indices: indices of non-masked tokens from the flattened sequence.
@@ -106,6 +104,7 @@ def get_unpad_data(attention_mask: "torch.Tensor") -> Tuple["torch.Tensor", "tor
[0, 2, 5, 6, 8, 11]
3
```

"""
seqlens_in_batch = get_seqlens_in_batch(attention_mask)
indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
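Reproducing the docstring example from this hunk with concrete tensors (value `k` marks tokens of the `k`-th packed sequence; `0` is padding):

```python
import torch

attention_mask = torch.tensor([[1, 1, 2, 2, 2, 0], [1, 2, 2, 3, 3, 3]])
# per-sequence lengths in flattened order: [2, 3, 1, 2, 3]
# cu_seqlens (zero-padded cumulative sum): [0, 2, 5, 6, 8, 11]
# max_seqlen_in_batch:                     3
indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
assert indices.tolist() == [0, 1, 2, 3, 4, 6, 7, 8, 9, 10, 11]  # non-pad positions
```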

@@ -19,7 +19,7 @@
import os
import random
from enum import Enum, unique
from typing import TYPE_CHECKING, Any, Dict, List
from typing import TYPE_CHECKING, Any

import torch
from datasets import load_dataset
@@ -43,9 +43,7 @@ logger = logging.get_logger(__name__)

@unique
class QuantizationMethod(str, Enum):
r"""
Borrowed from `transformers.utils.quantization_config.QuantizationMethod`.
"""
r"""Borrowed from `transformers.utils.quantization_config.QuantizationMethod`."""

BITS_AND_BYTES = "bitsandbytes"
GPTQ = "gptq"
@@ -56,10 +54,8 @@ class QuantizationMethod(str, Enum):
HQQ = "hqq"


def _get_quantization_dataset(tokenizer: "PreTrainedTokenizer", model_args: "ModelArguments") -> List[Dict[str, Any]]:
r"""
Prepares the tokenized dataset to perform AutoGPTQ. Do not use tensor output for JSON serialization.
"""
def _get_quantization_dataset(tokenizer: "PreTrainedTokenizer", model_args: "ModelArguments") -> list[dict[str, Any]]:
r"""Prepare the tokenized dataset to perform AutoGPTQ. Do not use tensor output for JSON serialization."""
if os.path.isfile(model_args.export_quantization_dataset):
data_path = FILEEXT2TYPE.get(model_args.export_quantization_dataset.split(".")[-1], None)
data_files = model_args.export_quantization_dataset
@@ -84,7 +80,7 @@ def _get_quantization_dataset(tokenizer: "PreTrainedTokenizer", model_args: "Mod
raise ValueError("Cannot find satisfying example, considering decrease `export_quantization_maxlen`.")

sample_idx = random.randint(0, len(dataset) - 1)
sample: Dict[str, "torch.Tensor"] = tokenizer(dataset[sample_idx]["text"], return_tensors="pt")
sample: dict[str, torch.Tensor] = tokenizer(dataset[sample_idx]["text"], return_tensors="pt")
n_try += 1
if sample["input_ids"].size(1) > maxlen:
break # TODO: fix large maxlen
@@ -101,11 +97,9 @@ def configure_quantization(
config: "PretrainedConfig",
tokenizer: "PreTrainedTokenizer",
model_args: "ModelArguments",
init_kwargs: Dict[str, Any],
init_kwargs: dict[str, Any],
) -> None:
r"""
Priority: PTQ-quantized (train/infer) > AutoGPTQ (export) > On-the-fly quantization (train/infer)
"""
r"""Priority: PTQ-quantized (train/infer) > AutoGPTQ (export) > On-the-fly quantization (train/infer)."""
if getattr(config, "quantization_config", None): # ptq
if model_args.quantization_bit is not None:
logger.warning_rank0("`quantization_bit` will not affect on the PTQ-quantized models.")
@@ -113,7 +107,7 @@ def configure_quantization(
if is_deepspeed_zero3_enabled() or is_fsdp_enabled():
raise ValueError("DeepSpeed ZeRO-3 or FSDP is incompatible with PTQ-quantized models.")

quantization_config: Dict[str, Any] = getattr(config, "quantization_config", None)
quantization_config: dict[str, Any] = getattr(config, "quantization_config", None)
quant_method = quantization_config.get("quant_method", "")

if quant_method == QuantizationMethod.GPTQ:

@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import TYPE_CHECKING, Any, Dict, Optional
from typing import TYPE_CHECKING, Any, Optional

from ...extras import logging
from ...extras.misc import get_current_device
@@ -29,7 +29,7 @@ logger = logging.get_logger(__name__)

def _get_unsloth_kwargs(
config: "PretrainedConfig", model_name_or_path: str, model_args: "ModelArguments"
) -> Dict[str, Any]:
) -> dict[str, Any]:
return {
"model_name": model_name_or_path,
"max_seq_length": model_args.model_max_length or 4096,
@@ -47,10 +47,8 @@ def _get_unsloth_kwargs(
def load_unsloth_pretrained_model(
config: "PretrainedConfig", model_args: "ModelArguments"
) -> Optional["PreTrainedModel"]:
r"""
Optionally loads pretrained model with unsloth. Used in training.
"""
from unsloth import FastLanguageModel
r"""Optionally load pretrained model with unsloth. Used in training."""
from unsloth import FastLanguageModel # type: ignore

unsloth_kwargs = _get_unsloth_kwargs(config, model_args.model_name_or_path, model_args)
try:
@@ -64,12 +62,10 @@ def load_unsloth_pretrained_model(


def get_unsloth_peft_model(
model: "PreTrainedModel", model_args: "ModelArguments", peft_kwargs: Dict[str, Any]
model: "PreTrainedModel", model_args: "ModelArguments", peft_kwargs: dict[str, Any]
) -> "PreTrainedModel":
r"""
Gets the peft model for the pretrained model with unsloth. Used in training.
"""
from unsloth import FastLanguageModel
r"""Get the peft model for the pretrained model with unsloth. Used in training."""
from unsloth import FastLanguageModel # type: ignore

unsloth_peft_kwargs = {
"model": model,
@@ -82,10 +78,8 @@ def get_unsloth_peft_model(
def load_unsloth_peft_model(
config: "PretrainedConfig", model_args: "ModelArguments", is_trainable: bool
) -> "PreTrainedModel":
r"""
Loads peft model with unsloth. Used in both training and inference.
"""
from unsloth import FastLanguageModel
r"""Load peft model with unsloth. Used in both training and inference."""
from unsloth import FastLanguageModel # type: ignore

unsloth_kwargs = _get_unsloth_kwargs(config, model_args.adapter_name_or_path[0], model_args)
try:

@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import TYPE_CHECKING, Dict
from typing import TYPE_CHECKING

import torch
from transformers.utils import cached_file
@@ -30,9 +30,8 @@ if TYPE_CHECKING:
logger = logging.get_logger(__name__)


def load_valuehead_params(path_or_repo_id: str, model_args: "ModelArguments") -> Dict[str, torch.Tensor]:
r"""
Loads value head parameters from Hugging Face Hub or local disk.
def load_valuehead_params(path_or_repo_id: str, model_args: "ModelArguments") -> dict[str, torch.Tensor]:
r"""Load value head parameters from Hugging Face Hub or local disk.

Returns: dict with keys `v_head.summary.weight` and `v_head.summary.bias`.
"""

@@ -15,8 +15,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from collections.abc import Sequence
from dataclasses import dataclass
from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Set, Tuple
from typing import TYPE_CHECKING, Optional

import torch
import transformers
@@ -40,9 +41,9 @@ transformers_logger = transformers.utils.logging.get_logger(__name__)
class CompositeModel:
model_type: str
projector_key: str
vision_model_keys: List[str]
language_model_keys: List[str]
lora_conflict_keys: List[str]
vision_model_keys: list[str]
language_model_keys: list[str]
lora_conflict_keys: list[str]

def get_projector(self, module: "torch.nn.Module") -> "torch.nn.Module":
for key in self.projector_key.split("."):
@@ -51,15 +52,15 @@ class CompositeModel:
return module


COMPOSITE_MODELS: Dict[str, "CompositeModel"] = {}
COMPOSITE_MODELS: dict[str, "CompositeModel"] = {}


def _register_composite_model(
model_type: str,
projector_key: Optional[str] = None,
vision_model_keys: Optional[List[str]] = None,
language_model_keys: Optional[List[str]] = None,
lora_conflict_keys: Optional[List[str]] = None,
vision_model_keys: Optional[list[str]] = None,
language_model_keys: Optional[list[str]] = None,
lora_conflict_keys: Optional[list[str]] = None,
):
COMPOSITE_MODELS[model_type] = CompositeModel(
model_type=model_type,
@@ -116,12 +117,10 @@ class LlavaMultiModalProjectorForYiVLForVLLM(LlavaMultiModalProjectorForYiVL):


def autocast_projector_dtype(model: "PreTrainedModel", model_args: "ModelArguments") -> None:
r"""
Casts projector output to half precision for fine-tuning quantized VLMs.
"""
r"""Cast projector output to half precision for fine-tuning quantized VLMs."""

def _mm_projector_forward_post_hook(
module: "torch.nn.Module", args: Tuple["torch.Tensor"], output: "torch.Tensor"
module: "torch.nn.Module", args: tuple["torch.Tensor"], output: "torch.Tensor"
) -> "torch.Tensor":
return output.to(model_args.compute_dtype)

@@ -137,9 +136,7 @@ def autocast_projector_dtype(model: "PreTrainedModel", model_args: "ModelArgumen


def configure_visual_model(config: "PretrainedConfig") -> None:
r"""
Patches VLMs before loading them.
"""
r"""Patch VLMs before loading them."""
if getattr(config, "text_config", None) and not getattr(config, "hidden_size", None):
# required for ds zero3 and valuehead models
setattr(config, "hidden_size", getattr(config.text_config, "hidden_size", None))
@@ -149,10 +146,8 @@ def configure_visual_model(config: "PretrainedConfig") -> None:
transformers.models.llava.modeling_llava.LlavaMultiModalProjector = LlavaMultiModalProjectorForYiVL


def get_forbidden_modules(config: "PretrainedConfig", finetuning_args: "FinetuningArguments") -> Set[str]:
r"""
Freezes vision tower and language model for VLM full/freeze tuning.
"""
def get_forbidden_modules(config: "PretrainedConfig", finetuning_args: "FinetuningArguments") -> set[str]:
r"""Freeze vision tower and language model for VLM full/freeze tuning."""
model_type = getattr(config, "model_type", None)
forbidden_modules = set()
if model_type in COMPOSITE_MODELS:
@@ -175,9 +170,7 @@ def get_forbidden_modules(config: "PretrainedConfig", finetuning_args: "Finetuni


def get_image_seqlen(config: "PretrainedConfig") -> int:
r"""
Computes the number of special tokens per image.
"""
r"""Compute the number of special tokens per image."""
model_type = getattr(config, "model_type", None)
if model_type == "llava":
image_seqlen = (config.vision_config.image_size // config.vision_config.patch_size) ** 2
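Worked numbers for the llava branch above, using the geometry LLaVA commonly ships with (CLIP ViT-L/14 at 336 px; values assumed for illustration):

```python
image_size, patch_size = 336, 14  # assumed CLIP ViT-L/14-336 geometry
image_seqlen = (image_size // patch_size) ** 2
assert image_seqlen == 576  # 24 x 24 patch grid -> 576 image placeholder tokens
```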
|
||||
@@ -192,17 +185,13 @@ def get_image_seqlen(config: "PretrainedConfig") -> int:
|
||||
|
||||
|
||||
def get_patch_size(config: "PretrainedConfig", processor: "ProcessorMixin") -> int:
|
||||
r"""
|
||||
Computes the patch size of the vit.
|
||||
"""
|
||||
r"""Compute the patch size of the vit."""
|
||||
patch_size = getattr(config.vision_config, "patch_size", getattr(processor, "patch_size", -1))
|
||||
return patch_size
|
||||
|
||||
|
||||
def get_vision_feature_select_strategy(config: "PretrainedConfig", processor: "ProcessorMixin") -> int:
|
||||
r"""
|
||||
Get the vision_feature_select_strategy.
|
||||
"""
|
||||
r"""Get the vision_feature_select_strategy."""
|
||||
vision_feature_select_strategy = getattr(
|
||||
config, "vision_feature_select_strategy", getattr(processor, "vision_feature_select_strategy", "default")
|
||||
)
|
||||
@@ -211,10 +200,8 @@ def get_vision_feature_select_strategy(config: "PretrainedConfig", processor: "P
|
||||
|
||||
def patch_target_modules(
|
||||
model: "PreTrainedModel", finetuning_args: "FinetuningArguments", target_modules: Sequence[str]
|
||||
) -> List[str]:
|
||||
r"""
|
||||
Freezes vision tower for VLM LoRA tuning.
|
||||
"""
|
||||
) -> list[str]:
|
||||
r"""Freezes vision tower for VLM LoRA tuning."""
|
||||
model_type = getattr(model.config, "model_type", None)
|
||||
if model_type in COMPOSITE_MODELS:
|
||||
forbidden_modules = get_forbidden_modules(model.config, finetuning_args)
|
||||
|
||||
@@ -13,7 +13,7 @@
|
||||
# limitations under the License.
|
||||
|
||||
from types import MethodType
|
||||
from typing import TYPE_CHECKING, Any, Dict
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
import torch
|
||||
from peft import PeftModel
|
||||
@@ -93,7 +93,7 @@ def patch_config(
|
||||
config: "PretrainedConfig",
|
||||
tokenizer: "PreTrainedTokenizer",
|
||||
model_args: "ModelArguments",
|
||||
init_kwargs: Dict[str, Any],
|
||||
init_kwargs: dict[str, Any],
|
||||
is_trainable: bool,
|
||||
) -> None:
|
||||
if model_args.compute_dtype is None: # priority: bf16 > fp16 > fp32
|
||||
|
||||
@@ -19,7 +19,7 @@ import sys
|
||||
import time
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from datetime import timedelta
|
||||
from typing import TYPE_CHECKING, Any, Dict, Optional
|
||||
from typing import TYPE_CHECKING, Any, Optional
|
||||
|
||||
import torch
|
||||
import transformers
|
||||
@@ -56,7 +56,8 @@ logger = logging.get_logger(__name__)
|
||||
def fix_valuehead_checkpoint(
|
||||
model: "AutoModelForCausalLMWithValueHead", output_dir: str, safe_serialization: bool
|
||||
) -> None:
|
||||
r"""
|
||||
r"""Fix the valuehead checkpoint files.
|
||||
|
||||
The model is already unwrapped.
|
||||
|
||||
There are three cases:
|
||||
@@ -72,10 +73,10 @@ def fix_valuehead_checkpoint(
|
||||
if safe_serialization:
|
||||
path_to_checkpoint = os.path.join(output_dir, SAFE_WEIGHTS_NAME)
|
||||
with safe_open(path_to_checkpoint, framework="pt", device="cpu") as f:
|
||||
state_dict: Dict[str, torch.Tensor] = {key: f.get_tensor(key) for key in f.keys()}
|
||||
state_dict: dict[str, torch.Tensor] = {key: f.get_tensor(key) for key in f.keys()}
|
||||
else:
|
||||
path_to_checkpoint = os.path.join(output_dir, WEIGHTS_NAME)
|
||||
state_dict: Dict[str, torch.Tensor] = torch.load(path_to_checkpoint, map_location="cpu")
|
||||
state_dict: dict[str, torch.Tensor] = torch.load(path_to_checkpoint, map_location="cpu")
|
||||
|
||||
os.remove(path_to_checkpoint)
|
||||
decoder_state_dict, v_head_state_dict = {}, {}
|
||||
@@ -98,9 +99,7 @@ def fix_valuehead_checkpoint(
|
||||
|
||||
|
||||
class FixValueHeadModelCallback(TrainerCallback):
|
||||
r"""
|
||||
A callback for fixing the checkpoint for valuehead models.
|
||||
"""
|
||||
r"""A callback for fixing the checkpoint for valuehead models."""
|
||||
|
||||
@override
|
||||
def on_save(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
|
||||
@@ -112,9 +111,7 @@ class FixValueHeadModelCallback(TrainerCallback):
|
||||
|
||||
|
||||
class SaveProcessorCallback(TrainerCallback):
|
||||
r"""
|
||||
A callback for saving the processor.
|
||||
"""
|
||||
r"""A callback for saving the processor."""
|
||||
|
||||
def __init__(self, processor: "ProcessorMixin") -> None:
|
||||
self.processor = processor
|
||||
@@ -132,9 +129,7 @@ class SaveProcessorCallback(TrainerCallback):
|
||||
|
||||
|
||||
class PissaConvertCallback(TrainerCallback):
|
||||
r"""
|
||||
A callback for converting the PiSSA adapter to a normal one.
|
||||
"""
|
||||
r"""A callback for converting the PiSSA adapter to a normal one."""
|
||||
|
||||
@override
|
||||
def on_train_begin(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
|
||||
@@ -177,9 +172,7 @@ class PissaConvertCallback(TrainerCallback):
|
||||
|
||||
|
||||
class LogCallback(TrainerCallback):
|
||||
r"""
|
||||
A callback for logging training and evaluation status.
|
||||
"""
|
||||
r"""A callback for logging training and evaluation status."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
# Progress
|
||||
@@ -188,7 +181,7 @@ class LogCallback(TrainerCallback):
|
||||
self.max_steps = 0
|
||||
self.elapsed_time = ""
|
||||
self.remaining_time = ""
|
||||
self.thread_pool: Optional["ThreadPoolExecutor"] = None
|
||||
self.thread_pool: Optional[ThreadPoolExecutor] = None
|
||||
# Status
|
||||
self.aborted = False
|
||||
self.do_train = False
|
@@ -219,7 +212,7 @@ class LogCallback(TrainerCallback):
self.elapsed_time = str(timedelta(seconds=int(elapsed_time)))
self.remaining_time = str(timedelta(seconds=int(remaining_time)))

def _write_log(self, output_dir: str, logs: Dict[str, Any]) -> None:
def _write_log(self, output_dir: str, logs: dict[str, Any]) -> None:
with open(os.path.join(output_dir, TRAINER_LOG), "a", encoding="utf-8") as f:
f.write(json.dumps(logs) + "\n")

@@ -348,9 +341,7 @@ class LogCallback(TrainerCallback):


class ReporterCallback(TrainerCallback):
r"""
A callback for reporting training status to external logger.
"""
r"""A callback for reporting training status to external logger."""

def __init__(
self,

@@ -19,7 +19,7 @@ import warnings
from collections import defaultdict
from contextlib import nullcontext
from types import MethodType
from typing import TYPE_CHECKING, Dict, List, Literal, Optional, Tuple, Union
from typing import TYPE_CHECKING, Literal, Optional, Union

import torch
import torch.nn.functional as F
@@ -129,15 +129,11 @@ class CustomDPOTrainer(DPOTrainer):

@override
def get_batch_samples(self, epoch_iterator, num_batches):
r"""
Replaces the method of KTO Trainer with the one of the standard Trainer.
"""
r"""Replace the method of DPO Trainer with the one of the standard Trainer."""
return Trainer.get_batch_samples(self, epoch_iterator, num_batches)

def odds_ratio_loss(self, chosen_logps: "torch.Tensor", rejected_logps: "torch.Tensor") -> "torch.Tensor":
r"""
Computes ORPO's odds ratio (OR) loss for batched log probabilities of the policy model.
"""
r"""Compute ORPO's odds ratio (OR) loss for batched log probabilities of the policy model."""
log_odds = (chosen_logps - rejected_logps) - (
torch.log1p(-torch.exp(chosen_logps)) - torch.log1p(-torch.exp(rejected_logps))
)
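For reference, the expression above is the log odds ratio between the chosen and rejected sequences. Writing $p_c = \exp(\texttt{chosen\_logps})$ and $p_r = \exp(\texttt{rejected\_logps})$:

$$\log \mathrm{OR} = \log\frac{p_c}{1 - p_c} - \log\frac{p_r}{1 - p_r} = (\log p_c - \log p_r) - \big(\log(1 - p_c) - \log(1 - p_r)\big),$$

where each $\log(1 - p)$ term is evaluated stably as log1p(-exp(log p)), exactly as the code does.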
@@ -147,9 +143,7 @@ class CustomDPOTrainer(DPOTrainer):
return orpo_loss

def simpo_loss(self, chosen_logps: "torch.Tensor", rejected_logps: "torch.Tensor") -> "torch.Tensor":
r"""
Computes SimPO loss for batched log probabilities of the policy model.
"""
r"""Compute SimPO loss for batched log probabilities of the policy model."""
pi_logratios = chosen_logps - rejected_logps
gamma_logratios = self.simpo_gamma / self.beta
logits = pi_logratios - gamma_logratios
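The remainder of this method is elided by the hunk; presumably `logits` feeds the usual sigmoid loss downstream. The SimPO objective being computed is

$$\mathcal{L}_{\mathrm{SimPO}} = -\log \sigma\!\big(\beta\,(\log \pi_\theta(y_c \mid x) - \log \pi_\theta(y_r \mid x)) - \gamma\big),$$

with length-normalized log probabilities; the code factors $\beta$ out by comparing `pi_logratios` against `gamma_logratios = simpo_gamma / beta`.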
@@ -162,10 +156,8 @@ class CustomDPOTrainer(DPOTrainer):
policy_rejected_logps: "torch.Tensor",
reference_chosen_logps: Optional["torch.Tensor"],
reference_rejected_logps: Optional["torch.Tensor"],
) -> Tuple["torch.Tensor", "torch.Tensor", "torch.Tensor"]:
r"""
Computes loss for preference learning.
"""
) -> tuple["torch.Tensor", "torch.Tensor", "torch.Tensor"]:
r"""Compute loss for preference learning."""
if not self.finetuning_args.use_ref_model:
if self.loss_type == "orpo":
losses = self.odds_ratio_loss(policy_chosen_logps, policy_rejected_logps)
@@ -185,17 +177,16 @@ class CustomDPOTrainer(DPOTrainer):

@override
def concatenated_forward(
self, model: "PreTrainedModel", batch: Dict[str, "torch.Tensor"]
) -> Tuple["torch.Tensor", "torch.Tensor", "torch.Tensor", "torch.Tensor", "torch.Tensor"]:
r"""
Computes the sum log probabilities of the labels under given logits if loss_type is not IPO, ORPO or SimPO.
self, model: "PreTrainedModel", batch: dict[str, "torch.Tensor"]
) -> tuple["torch.Tensor", "torch.Tensor", "torch.Tensor", "torch.Tensor", "torch.Tensor"]:
r"""Compute the sum log probabilities of the labels under given logits if loss_type is not IPO, ORPO or SimPO.

Otherwise the average log probabilities.
"""
if self.finetuning_args.use_ref_model:
batch = nested_detach(batch, clone=True) # avoid error

all_logits: "torch.Tensor" = model(**batch, return_dict=True, use_cache=False).logits.to(torch.float32)
all_logits: torch.Tensor = model(**batch, return_dict=True, use_cache=False).logits.to(torch.float32)
all_logps, valid_length = get_batch_logps(logits=all_logits, labels=batch["labels"])
if self.loss_type in ["ipo", "orpo", "simpo"]:
all_logps = all_logps / valid_length
@@ -212,11 +203,9 @@ class CustomDPOTrainer(DPOTrainer):

@override
def compute_reference_log_probs(
self, model: "PreTrainedModel", batch: Dict[str, "torch.Tensor"]
) -> Tuple[Optional["torch.Tensor"], Optional["torch.Tensor"]]:
r"""
Computes log probabilities of the reference model.
"""
self, model: "PreTrainedModel", batch: dict[str, "torch.Tensor"]
) -> tuple[Optional["torch.Tensor"], Optional["torch.Tensor"]]:
r"""Compute log probabilities of the reference model."""
if not self.finetuning_args.use_ref_model:
return None, None

@@ -236,12 +225,10 @@ class CustomDPOTrainer(DPOTrainer):
def get_batch_loss_metrics(
self,
model: "PreTrainedModel",
batch: Dict[str, "torch.Tensor"],
batch: dict[str, "torch.Tensor"],
train_eval: Literal["train", "eval"] = "train",
) -> Tuple["torch.Tensor", Dict[str, "torch.Tensor"]]:
r"""
Computes the DPO loss and other metrics for the given batch of inputs for train or test.
"""
) -> tuple["torch.Tensor", dict[str, "torch.Tensor"]]:
r"""Compute the DPO loss and other metrics for the given batch of inputs for train or test."""
metrics = {}
(
policy_chosen_logps,
@@ -279,18 +266,14 @@ class CustomDPOTrainer(DPOTrainer):

@override
def compute_loss(
self, model: "PreTrainedModel", inputs: Dict[str, "torch.Tensor"], return_outputs: bool = False, **kwargs
) -> Union["torch.Tensor", Tuple["torch.Tensor", List["torch.Tensor"]]]:
r"""
Subclass and override to accept extra kwargs.
"""
self, model: "PreTrainedModel", inputs: dict[str, "torch.Tensor"], return_outputs: bool = False, **kwargs
) -> Union["torch.Tensor", tuple["torch.Tensor", list["torch.Tensor"]]]:
r"""Subclass and override to accept extra kwargs."""
return super().compute_loss(model, inputs, return_outputs)

@override
def log(self, logs: Dict[str, float], *args, **kwargs) -> None:
r"""
Log `logs` on the various objects watching training, including stored metrics.
"""
def log(self, logs: dict[str, float], *args, **kwargs) -> None:
r"""Log `logs` on the various objects watching training, including stored metrics."""
# logs either has "loss" or "eval_loss"
train_eval = "train" if "loss" in logs else "eval"
# Add averaged stored metrics to logs

@@ -15,7 +15,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import TYPE_CHECKING, List, Optional
from typing import TYPE_CHECKING, Optional

from ...data import PairwiseDataCollatorWithPadding, get_dataset, get_template_and_fix_tokenizer
from ...extras.constants import IGNORE_INDEX
@@ -38,7 +38,7 @@ def run_dpo(
data_args: "DataArguments",
training_args: "Seq2SeqTrainingArguments",
finetuning_args: "FinetuningArguments",
callbacks: Optional[List["TrainerCallback"]] = None,
callbacks: Optional[list["TrainerCallback"]] = None,
):
tokenizer_module = load_tokenizer(model_args)
tokenizer = tokenizer_module["tokenizer"]

@@ -19,7 +19,7 @@ import warnings
from collections import defaultdict
from contextlib import nullcontext
from types import MethodType
from typing import TYPE_CHECKING, Dict, List, Literal, Optional, Tuple, Union
from typing import TYPE_CHECKING, Literal, Optional, Union

import torch
from transformers import Trainer
@@ -120,9 +120,7 @@ class CustomKTOTrainer(KTOTrainer):

@override
def _get_train_sampler(self) -> Optional["torch.utils.data.Sampler"]:
r"""
Replaces the sequential sampler of KTO Trainer created by trl with the random sampler.
"""
r"""Replace the sequential sampler of KTO Trainer created by trl with the random sampler."""
if self.finetuning_args.disable_shuffling:
return torch.utils.data.SequentialSampler(self.train_dataset)

@@ -130,18 +128,14 @@ class CustomKTOTrainer(KTOTrainer):

@override
def get_batch_samples(self, epoch_iterator, num_batches):
r"""
Replaces the method of KTO Trainer with the one of the standard Trainer.
"""
r"""Replace the method of KTO Trainer with the one of the standard Trainer."""
return Trainer.get_batch_samples(self, epoch_iterator, num_batches)

@override
def forward(
self, model: "PreTrainedModel", batch: Dict[str, "torch.Tensor"], prefix: Literal["", "kl_"] = ""
) -> Tuple["torch.Tensor", "torch.Tensor", "torch.Tensor"]:
r"""
Runs forward pass and computes the log probabilities.
"""
self, model: "PreTrainedModel", batch: dict[str, "torch.Tensor"], prefix: Literal["", "kl_"] = ""
) -> tuple["torch.Tensor", "torch.Tensor", "torch.Tensor"]:
r"""Run forward pass and computes the log probabilities."""
batch = nested_detach(batch, clone=True) # avoid error
model_inputs = {
"input_ids": batch[f"{prefix}input_ids"],
@@ -171,8 +165,8 @@ class CustomKTOTrainer(KTOTrainer):

@override
def concatenated_forward(
self, model: "PreTrainedModel", batch: Dict[str, "torch.Tensor"]
) -> Tuple["torch.Tensor", "torch.Tensor", "torch.Tensor", "torch.Tensor", "torch.Tensor", "torch.Tensor"]:
self, model: "PreTrainedModel", batch: dict[str, "torch.Tensor"]
) -> tuple["torch.Tensor", "torch.Tensor", "torch.Tensor", "torch.Tensor", "torch.Tensor", "torch.Tensor"]:
target_logits, target_logps, target_logps_avg = self.forward(model, batch)
with torch.no_grad():
_, kl_logps, _ = self.forward(model, batch, prefix="kl_")
@@ -189,11 +183,9 @@ class CustomKTOTrainer(KTOTrainer):

@override
def compute_reference_log_probs(
self, model: "PreTrainedModel", batch: Dict[str, "torch.Tensor"]
) -> Tuple["torch.Tensor", "torch.Tensor", "torch.Tensor"]:
r"""
Computes log probabilities of the reference model.
"""
self, model: "PreTrainedModel", batch: dict[str, "torch.Tensor"]
) -> tuple["torch.Tensor", "torch.Tensor", "torch.Tensor"]:
r"""Compute log probabilities of the reference model."""
if self.ref_model is None:
ref_model = model
ref_context = self.accelerator.unwrap_model(model).disable_adapter()
@@ -212,11 +204,9 @@ class CustomKTOTrainer(KTOTrainer):
def get_batch_loss_metrics(
self,
model: "PreTrainedModel",
batch: Dict[str, "torch.Tensor"],
) -> Tuple["torch.Tensor", Dict[str, "torch.Tensor"]]:
r"""
Computes the DPO loss and other metrics for the given batch of inputs for train or test.
"""
batch: dict[str, "torch.Tensor"],
) -> tuple["torch.Tensor", dict[str, "torch.Tensor"]]:
r"""Compute the DPO loss and other metrics for the given batch of inputs for train or test."""
metrics = {}
(
policy_chosen_logps,
@@ -262,18 +252,14 @@ class CustomKTOTrainer(KTOTrainer):

@override
def compute_loss(
self, model: "PreTrainedModel", inputs: Dict[str, "torch.Tensor"], return_outputs: bool = False, **kwargs
) -> Union["torch.Tensor", Tuple["torch.Tensor", List["torch.Tensor"]]]:
r"""
Subclass and override to accept extra kwargs.
"""
self, model: "PreTrainedModel", inputs: dict[str, "torch.Tensor"], return_outputs: bool = False, **kwargs
) -> Union["torch.Tensor", tuple["torch.Tensor", list["torch.Tensor"]]]:
r"""Subclass and override to accept extra kwargs."""
return super().compute_loss(model, inputs, return_outputs)

@override
def log(self, logs: Dict[str, float], *args, **kwargs) -> None:
r"""
Log `logs` on the various objects watching training, including stored metrics.
"""
def log(self, logs: dict[str, float], *args, **kwargs) -> None:
r"""Log `logs` on the various objects watching training, including stored metrics."""
# logs either has "loss" or "eval_loss"
train_eval = "train" if "loss" in logs else "eval"
prefix = "eval_" if train_eval == "eval" else ""
@@ -291,7 +277,7 @@ class CustomKTOTrainer(KTOTrainer):

metric_list = torch.tensor(metric_list, dtype=torch.float).to(self.accelerator.device)
metric_list = self.accelerator.reduce(metric_list, "sum").tolist()
metric_dict: Dict[str, float] = dict(zip(key_list, metric_list))
metric_dict: dict[str, float] = dict(zip(key_list, metric_list))
for split in ["chosen", "rejected"]: # accumulate average metrics from sums and lengths
if f"count/{split}" in metric_dict:
for key in ("rewards", "logps", "logits"):

@@ -15,7 +15,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import TYPE_CHECKING, List, Optional
from typing import TYPE_CHECKING, Optional

from ...data import KTODataCollatorWithPadding, get_dataset, get_template_and_fix_tokenizer
from ...extras.constants import IGNORE_INDEX
@@ -37,7 +37,7 @@ def run_kto(
data_args: "DataArguments",
training_args: "Seq2SeqTrainingArguments",
finetuning_args: "FinetuningArguments",
callbacks: Optional[List["TrainerCallback"]] = None,
callbacks: Optional[list["TrainerCallback"]] = None,
):
tokenizer_module = load_tokenizer(model_args)
tokenizer = tokenizer_module["tokenizer"]

@@ -14,7 +14,7 @@

import json
from contextlib import nullcontext
from typing import TYPE_CHECKING, Dict, List, Literal, Optional
from typing import TYPE_CHECKING, Literal, Optional

import torch
from transformers.integrations import is_deepspeed_zero3_enabled
@@ -31,10 +31,8 @@ if TYPE_CHECKING:
from trl import AutoModelForCausalLMWithValueHead


def get_rewards_from_server(server_url: str, messages: List[str]) -> List["torch.Tensor"]:
r"""
Gets reward scores from the API server.
"""
def get_rewards_from_server(server_url: str, messages: list[str]) -> list["torch.Tensor"]:
r"""Get reward scores from the API server."""
headers = {"Content-Type": "application/json"}
payload = {"model": "model", "messages": messages}
response = requests.post(server_url, json=payload, headers=headers)
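A hedged usage sketch for the helper above (the URL is hypothetical, and the hunk ends before the response is decoded into tensors):

    # Score decoded prompt+response strings against a remotely served reward model.
    scores = get_rewards_from_server(
        "http://localhost:8000/v1/score",  # hypothetical server URL
        ["<prompt text><response text>"],
    )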
@@ -43,9 +41,7 @@ def get_rewards_from_server(server_url: str, messages: List[str]) -> List["torch.Tensor"]:


def replace_model(model: "AutoModelForCausalLMWithValueHead", target: Literal["default", "reward"]) -> None:
r"""
Replaces the default/reward modules in the model. The model is already unwrapped.
"""
r"""Replace the default/reward modules in the model. The model is already unwrapped."""
v_head_layer = model.v_head.summary
if is_deepspeed_zero3_enabled():
import deepspeed # type: ignore
@@ -66,10 +62,8 @@ def replace_model(model: "AutoModelForCausalLMWithValueHead", target: Literal["default", "reward"]) -> None:
v_head_layer.bias.data = model.get_buffer(f"{target}_head_bias").detach().clone().to(device)


def dump_layernorm(model: "PreTrainedModel") -> Dict[str, "torch.Tensor"]:
r"""
Dumps the layernorm parameters in the model. The model is already unwrapped (and gathered).
"""
def dump_layernorm(model: "PreTrainedModel") -> dict[str, "torch.Tensor"]:
r"""Dump the layernorm parameters in the model. The model is already unwrapped (and gathered)."""
layer_norm_params = {}
for name, param in model.named_parameters():
if param.data.dtype == torch.float32:
@@ -79,10 +73,8 @@ def dump_layernorm(model: "PreTrainedModel") -> Dict[str, "torch.Tensor"]:
return layer_norm_params


def restore_layernorm(model: "PreTrainedModel", layernorm_params: Optional[Dict[str, "torch.Tensor"]] = None) -> None:
r"""
Restores the layernorm parameters in the model. The model is already unwrapped (and gathered).
"""
def restore_layernorm(model: "PreTrainedModel", layernorm_params: Optional[dict[str, "torch.Tensor"]] = None) -> None:
r"""Restore the layernorm parameters in the model. The model is already unwrapped (and gathered)."""
for name, param in model.named_parameters():
if name in layernorm_params:
param.data = layernorm_params[name]
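Together these two helpers support a dump/restore round trip around generation, which `get_inputs` further down uses when `upcast_layernorm` is enabled. A sketch of the pattern (model creation elided; assumes an unwrapped HF model):

    layernorm_params = dump_layernorm(model)    # stash fp32 layernorm weights
    output = model.generate(**batch)            # generate with upcast weights
    restore_layernorm(model, layernorm_params)  # put the originals back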

@@ -20,7 +20,7 @@ import os
import sys
import warnings
from types import MethodType
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
from typing import TYPE_CHECKING, Any, Optional

import torch
from accelerate.utils import DistributedDataParallelKwargs
@@ -62,9 +62,7 @@ logger = logging.get_logger(__name__)


class CustomPPOTrainer(PPOTrainer, Trainer):
r"""
Inherits PPOTrainer.
"""
r"""Inherit PPOTrainer."""

def __init__(
self,
@@ -72,7 +70,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
training_args: "Seq2SeqTrainingArguments",
finetuning_args: "FinetuningArguments",
generating_args: "GeneratingArguments",
callbacks: Optional[List["TrainerCallback"]],
callbacks: Optional[list["TrainerCallback"]],
model: "AutoModelForCausalLMWithValueHead",
reward_model: Optional["AutoModelForCausalLMWithValueHead"],
ref_model: Optional["AutoModelForCausalLMWithValueHead"],
@@ -187,9 +185,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
self.add_callback(BAdamCallback)

def ppo_train(self, resume_from_checkpoint: Optional[str] = None) -> None:
r"""
Implements training loop for the PPO stage, like _inner_training_loop() in Huggingface's Trainer.
"""
r"""Implement training loop for the PPO stage, like _inner_training_loop() in Huggingface's Trainer."""
if resume_from_checkpoint is not None:
raise ValueError("`resume_from_checkpoint` will be supported in the future version.")

@@ -221,9 +217,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
logger.info_rank0(f" Num Epochs = {num_train_epochs:,}")
logger.info_rank0(f" Instantaneous batch size per device = {self.args.per_device_train_batch_size:,}")
logger.info_rank0(
" Total train batch size (w. parallel, buffer, distributed & accumulation) = {:,}".format(
total_train_batch_size
)
f" Total train batch size (w. parallel, buffer, distributed & accumulation) = {total_train_batch_size:,}"
)
logger.info_rank0(f" Gradient Accumulation steps = {self.args.gradient_accumulation_steps:,}")
logger.info_rank0(f" Num optimization epochs per batch = {self.finetuning_args.ppo_epochs:,}")
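Besides the typing changes, the hunk above rewrites a str.format call as an f-string; the `:,` format spec (thousands separators) carries over unchanged. A tiny self-contained check:

    total = 1234567
    old = "Total train batch size = {:,}".format(total)
    new = f"Total train batch size = {total:,}"
    assert old == new == "Total train batch size = 1,234,567"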
@@ -339,21 +333,19 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
return lr_scheduler

@torch.no_grad()
def get_inputs(self, batch: Dict[str, "torch.Tensor"]) -> Tuple[List["torch.Tensor"], List["torch.Tensor"]]:
r"""
Generates model's responses given queries.
"""
def get_inputs(self, batch: dict[str, "torch.Tensor"]) -> tuple[list["torch.Tensor"], list["torch.Tensor"]]:
r"""Generate model's responses given queries."""
if batch["input_ids"].size(0) == 1: # handle llama2 ppo with gradient accumulation > 1
start_index = (batch["input_ids"][0] != self.tokenizer.pad_token_id).nonzero()[0].item()
for k, v in batch.items():
batch[k] = v[:, start_index:]

with unwrap_model_for_generation(self.model, self.accelerator) as unwrapped_model:
unwrapped_model: "AutoModelForCausalLMWithValueHead" = self.accelerator.unwrap_model(self.model)
unwrapped_model: AutoModelForCausalLMWithValueHead = self.accelerator.unwrap_model(self.model)
if self.model_args.upcast_layernorm:
layernorm_params = dump_layernorm(unwrapped_model)

generate_output: "torch.Tensor" = unwrapped_model.generate(
generate_output: torch.Tensor = unwrapped_model.generate(
generation_config=self.generation_config, logits_processor=get_logits_processor(), **batch
)
if self.model_args.upcast_layernorm:
@@ -381,11 +373,10 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
@torch.no_grad()
def get_rewards(
self,
queries: List["torch.Tensor"],
responses: List["torch.Tensor"],
) -> List["torch.Tensor"]:
r"""
Computes scores using given reward model.
queries: list["torch.Tensor"],
responses: list["torch.Tensor"],
) -> list["torch.Tensor"]:
r"""Compute scores using given reward model.

Both inputs and outputs are put on CPU.
"""
@@ -394,8 +385,8 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
messages = self.tokenizer.batch_decode(token_ids, skip_special_tokens=False)
return get_rewards_from_server(self.reward_model, messages)

batch: Dict[str, "torch.Tensor"] = self.prepare_model_inputs(queries, responses)
unwrapped_model: "AutoModelForCausalLMWithValueHead" = self.accelerator.unwrap_model(self.model)
batch: dict[str, torch.Tensor] = self.prepare_model_inputs(queries, responses)
unwrapped_model: AutoModelForCausalLMWithValueHead = self.accelerator.unwrap_model(self.model)

if self.finetuning_args.reward_model_type == "lora":
replace_model(unwrapped_model, target="reward")
@@ -404,7 +395,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
reward_model = self.reward_model

with unwrap_model_for_generation(reward_model, self.accelerator), self.amp_context: # support bf16
values: "torch.Tensor" = reward_model(**batch, return_dict=True, use_cache=False)[-1]
values: torch.Tensor = reward_model(**batch, return_dict=True, use_cache=False)[-1]

if self.finetuning_args.reward_model_type == "lora":
replace_model(unwrapped_model, target="default")
@@ -419,12 +410,11 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
model: "AutoModelForCausalLMWithValueHead",
queries: "torch.Tensor",
responses: "torch.Tensor",
model_inputs: Dict[str, Any],
model_inputs: dict[str, Any],
return_logits: bool = False,
response_masks: Optional["torch.Tensor"] = None,
) -> Tuple["torch.Tensor", Optional["torch.Tensor"], "torch.Tensor", "torch.Tensor"]:
r"""
Calculates model outputs in multiple batches.
) -> tuple["torch.Tensor", Optional["torch.Tensor"], "torch.Tensor", "torch.Tensor"]:
r"""Calculate model outputs in multiple batches.

Subclass and override to inject custom behavior.
"""
@@ -483,8 +473,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):

@override
def save_model(self, output_dir: Optional[str] = None) -> None:
r"""
Saves model checkpoint.
r"""Save model checkpoint.

Subclass and override to inject custom behavior.
"""
@@ -508,5 +497,5 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
self.model.save_checkpoint(output_dir)

elif self.args.should_save:
unwrapped_model: "AutoModelForCausalLMWithValueHead" = self.accelerator.unwrap_model(self.model)
unwrapped_model: AutoModelForCausalLMWithValueHead = self.accelerator.unwrap_model(self.model)
self._save(output_dir, state_dict=unwrapped_model.state_dict())

@@ -15,7 +15,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import TYPE_CHECKING, List, Optional
from typing import TYPE_CHECKING, Optional

from ...data import MultiModalDataCollatorForSeq2Seq, get_dataset, get_template_and_fix_tokenizer
from ...extras.ploting import plot_loss
@@ -37,7 +37,7 @@ def run_ppo(
training_args: "Seq2SeqTrainingArguments",
finetuning_args: "FinetuningArguments",
generating_args: "GeneratingArguments",
callbacks: Optional[List["TrainerCallback"]] = None,
callbacks: Optional[list["TrainerCallback"]] = None,
):
tokenizer_module = load_tokenizer(model_args)
tokenizer = tokenizer_module["tokenizer"]
@@ -53,7 +53,7 @@ def run_ppo(
reward_model = create_reward_model(model, model_args, finetuning_args)

# Initialize our Trainer
ppo_trainer: "CustomPPOTrainer" = CustomPPOTrainer(
ppo_trainer: CustomPPOTrainer = CustomPPOTrainer(
model_args=model_args,
training_args=training_args,
finetuning_args=finetuning_args,

@@ -31,9 +31,7 @@ if TYPE_CHECKING:


class CustomTrainer(Trainer):
r"""
Inherits Trainer for custom optimizer.
"""
r"""Inherit Trainer for custom optimizer."""

def __init__(
self, finetuning_args: "FinetuningArguments", processor: Optional["ProcessorMixin"], **kwargs

@@ -16,7 +16,7 @@
# limitations under the License.

import math
from typing import TYPE_CHECKING, List, Optional
from typing import TYPE_CHECKING, Optional

from transformers import DataCollatorForLanguageModeling

@@ -38,7 +38,7 @@ def run_pt(
data_args: "DataArguments",
training_args: "Seq2SeqTrainingArguments",
finetuning_args: "FinetuningArguments",
callbacks: Optional[List["TrainerCallback"]] = None,
callbacks: Optional[list["TrainerCallback"]] = None,
):
tokenizer_module = load_tokenizer(model_args)
tokenizer = tokenizer_module["tokenizer"]

@@ -13,7 +13,7 @@
# limitations under the License.

from dataclasses import dataclass
from typing import TYPE_CHECKING, Dict, Optional
from typing import TYPE_CHECKING, Optional

import numpy as np

@@ -26,11 +26,9 @@ if TYPE_CHECKING:

@dataclass
class ComputeAccuracy:
r"""
Computes reward accuracy and supports `batch_eval_metrics`.
"""
r"""Compute reward accuracy and support `batch_eval_metrics`."""

def _dump(self) -> Optional[Dict[str, float]]:
def _dump(self) -> Optional[dict[str, float]]:
result = None
if hasattr(self, "score_dict"):
result = {k: float(np.mean(v)) for k, v in self.score_dict.items()}
@@ -41,7 +39,7 @@ class ComputeAccuracy:
def __post_init__(self):
self._dump()

def __call__(self, eval_preds: "EvalPrediction", compute_result: bool = True) -> Optional[Dict[str, float]]:
def __call__(self, eval_preds: "EvalPrediction", compute_result: bool = True) -> Optional[dict[str, float]]:
chosen_scores, rejected_scores = numpify(eval_preds.predictions[0]), numpify(eval_preds.predictions[1])
if not chosen_scores.shape:
self.score_dict["accuracy"].append(chosen_scores > rejected_scores)

@@ -18,7 +18,7 @@
import json
import os
from types import MethodType
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
from typing import TYPE_CHECKING, Optional, Union

import torch
from transformers import Trainer
@@ -41,9 +41,7 @@ logger = logging.get_logger(__name__)


class PairwiseTrainer(Trainer):
r"""
Inherits Trainer to compute pairwise loss.
"""
r"""Inherits Trainer to compute pairwise loss."""

def __init__(
self, finetuning_args: "FinetuningArguments", processor: Optional["ProcessorMixin"], **kwargs
@@ -88,10 +86,9 @@ class PairwiseTrainer(Trainer):

@override
def compute_loss(
self, model: "PreTrainedModel", inputs: Dict[str, "torch.Tensor"], return_outputs: bool = False, **kwargs
) -> Union["torch.Tensor", Tuple["torch.Tensor", List["torch.Tensor"]]]:
r"""
Computes pairwise loss. The first n examples are chosen and the last n examples are rejected.
self, model: "PreTrainedModel", inputs: dict[str, "torch.Tensor"], return_outputs: bool = False, **kwargs
) -> Union["torch.Tensor", tuple["torch.Tensor", list["torch.Tensor"]]]:
r"""Compute pairwise loss. The first n examples are chosen and the last n examples are rejected.

Subclass and override to inject custom behavior.

@@ -113,8 +110,7 @@ class PairwiseTrainer(Trainer):
return loss

def save_predictions(self, predict_results: "PredictionOutput") -> None:
r"""
Saves model predictions to `output_dir`.
r"""Save model predictions to `output_dir`.

A custom behavior that is not contained in Seq2SeqTrainer.
"""
@@ -126,7 +122,7 @@ class PairwiseTrainer(Trainer):
chosen_scores, rejected_scores = predict_results.predictions

with open(output_prediction_file, "w", encoding="utf-8") as writer:
res: List[str] = []
res: list[str] = []
for c_score, r_score in zip(chosen_scores, rejected_scores):
res.append(json.dumps({"chosen": round(float(c_score), 2), "rejected": round(float(r_score), 2)}))


@@ -15,7 +15,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import TYPE_CHECKING, List, Optional
from typing import TYPE_CHECKING, Optional

from ...data import PairwiseDataCollatorWithPadding, get_dataset, get_template_and_fix_tokenizer
from ...extras.ploting import plot_loss
@@ -37,7 +37,7 @@ def run_rm(
data_args: "DataArguments",
training_args: "Seq2SeqTrainingArguments",
finetuning_args: "FinetuningArguments",
callbacks: Optional[List["TrainerCallback"]] = None,
callbacks: Optional[list["TrainerCallback"]] = None,
):
tokenizer_module = load_tokenizer(model_args)
tokenizer = tokenizer_module["tokenizer"]

@@ -17,7 +17,7 @@
# limitations under the License.

from dataclasses import dataclass
from typing import TYPE_CHECKING, Dict, Optional
from typing import TYPE_CHECKING, Optional

import numpy as np
import torch
@@ -45,9 +45,7 @@ if is_rouge_available():


def eval_logit_processor(logits: "torch.Tensor", labels: "torch.Tensor") -> "torch.Tensor":
r"""
Computes the token with the largest likelihood to reduce memory footprint.
"""
r"""Compute the token with the largest likelihood to reduce memory footprint."""
if isinstance(logits, (list, tuple)):
if logits[0].dim() == 3: # (batch_size, seq_len, vocab_size)
logits = logits[0]
@@ -62,11 +60,9 @@ def eval_logit_processor(logits: "torch.Tensor", labels: "torch.Tensor") -> "torch.Tensor":

@dataclass
class ComputeAccuracy:
r"""
Computes accuracy and supports `batch_eval_metrics`.
"""
r"""Compute accuracy and support `batch_eval_metrics`."""

def _dump(self) -> Optional[Dict[str, float]]:
def _dump(self) -> Optional[dict[str, float]]:
result = None
if hasattr(self, "score_dict"):
result = {k: float(np.mean(v)) for k, v in self.score_dict.items()}
@@ -77,7 +73,7 @@ class ComputeAccuracy:
def __post_init__(self):
self._dump()

def __call__(self, eval_preds: "EvalPrediction", compute_result: bool = True) -> Optional[Dict[str, float]]:
def __call__(self, eval_preds: "EvalPrediction", compute_result: bool = True) -> Optional[dict[str, float]]:
preds, labels = numpify(eval_preds.predictions), numpify(eval_preds.label_ids)
for i in range(len(preds)):
pred, label = preds[i, :-1], labels[i, 1:]
@@ -90,15 +86,14 @@ class ComputeAccuracy:

@dataclass
class ComputeSimilarity:
r"""
Computes text similarity scores and supports `batch_eval_metrics`.
r"""Compute text similarity scores and support `batch_eval_metrics`.

Wraps the tokenizer into metric functions, used in CustomSeq2SeqTrainer.
"""

tokenizer: "PreTrainedTokenizer"

def _dump(self) -> Optional[Dict[str, float]]:
def _dump(self) -> Optional[dict[str, float]]:
result = None
if hasattr(self, "score_dict"):
result = {k: float(np.mean(v)) for k, v in self.score_dict.items()}
@@ -109,7 +104,7 @@ class ComputeSimilarity:
def __post_init__(self):
self._dump()

def __call__(self, eval_preds: "EvalPrediction", compute_result: bool = True) -> Optional[Dict[str, float]]:
def __call__(self, eval_preds: "EvalPrediction", compute_result: bool = True) -> Optional[dict[str, float]]:
preds, labels = numpify(eval_preds.predictions), numpify(eval_preds.label_ids)

preds = np.where(preds != IGNORE_INDEX, preds, self.tokenizer.pad_token_id)

@@ -18,7 +18,7 @@
import json
import os
from types import MethodType
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
from typing import TYPE_CHECKING, Any, Optional, Union

import numpy as np
import torch
@@ -44,21 +44,19 @@ logger = logging.get_logger(__name__)


class CustomSeq2SeqTrainer(Seq2SeqTrainer):
r"""
Inherits Seq2SeqTrainer to compute generative metrics such as BLEU and ROUGE.
"""
r"""Inherits Seq2SeqTrainer to compute generative metrics such as BLEU and ROUGE."""

def __init__(
self,
finetuning_args: "FinetuningArguments",
processor: Optional["ProcessorMixin"],
gen_kwargs: Optional[Dict[str, Any]] = None,
gen_kwargs: Optional[dict[str, Any]] = None,
**kwargs,
) -> None:
if is_transformers_version_greater_than("4.46"):
kwargs["processing_class"] = kwargs.pop("tokenizer")
else:
self.processing_class: "PreTrainedTokenizer" = kwargs.get("tokenizer")
self.processing_class: PreTrainedTokenizer = kwargs.get("tokenizer")

super().__init__(**kwargs)
self.finetuning_args = finetuning_args
@@ -99,13 +97,12 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer):
def prediction_step(
self,
model: "torch.nn.Module",
inputs: Dict[str, Union["torch.Tensor", Any]],
inputs: dict[str, Union["torch.Tensor", Any]],
prediction_loss_only: bool,
ignore_keys: Optional[List[str]] = None,
ignore_keys: Optional[list[str]] = None,
**gen_kwargs,
) -> Tuple[Optional[float], Optional["torch.Tensor"], Optional["torch.Tensor"]]:
r"""
Removes the prompt part in the generated tokens.
) -> tuple[Optional[float], Optional["torch.Tensor"], Optional["torch.Tensor"]]:
r"""Remove the prompt part in the generated tokens.

Subclass and override to inject custom behavior.
"""
@@ -126,8 +123,7 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer):
def save_predictions(
self, dataset: "Dataset", predict_results: "PredictionOutput", skip_special_tokens: bool = True
) -> None:
r"""
Saves model predictions to `output_dir`.
r"""Save model predictions to `output_dir`.

A custom behavior that is not contained in Seq2SeqTrainer.
"""

@@ -15,7 +15,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import TYPE_CHECKING, List, Optional
from typing import TYPE_CHECKING, Optional

from ...data import SFTDataCollatorWith4DAttentionMask, get_dataset, get_template_and_fix_tokenizer
from ...extras.constants import IGNORE_INDEX
@@ -43,7 +43,7 @@ def run_sft(
training_args: "Seq2SeqTrainingArguments",
finetuning_args: "FinetuningArguments",
generating_args: "GeneratingArguments",
callbacks: Optional[List["TrainerCallback"]] = None,
callbacks: Optional[list["TrainerCallback"]] = None,
):
tokenizer_module = load_tokenizer(model_args)
tokenizer = tokenizer_module["tokenizer"]

@@ -12,7 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import TYPE_CHECKING, Dict, Optional, Sequence, Set, Tuple, Union
from collections.abc import Sequence
from typing import TYPE_CHECKING, Optional, Union
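The same PEP 585 cleanup moves the abstract container types to collections.abc: since Python 3.9, typing.Sequence, typing.Generator and friends are deprecated aliases of their collections.abc counterparts, while Dict/List/Tuple/Set map to plain builtins. A small illustrative sketch:

    from collections.abc import Generator, Sequence

    def chunks(xs: Sequence[int], n: int) -> Generator[list[int], None, None]:
        for i in range(0, len(xs), n):
            yield list(xs[i : i + n])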

import torch
from peft import PeftModel
@@ -43,7 +44,7 @@ def compare_model(model_a: "torch.nn.Module", model_b: "torch.nn.Module", diff_k
assert torch.allclose(state_dict_a[name], state_dict_b[name], rtol=1e-4, atol=1e-5) is True


def check_lora_model(model: "LoraModel") -> Tuple[Set[str], Set[str]]:
def check_lora_model(model: "LoraModel") -> tuple[set[str], set[str]]:
linear_modules, extra_modules = set(), set()
for name, param in model.named_parameters():
if any(module in name for module in ["lora_A", "lora_B"]):
@@ -83,7 +84,7 @@ def load_reference_model(
) -> Union["PreTrainedModel", "LoraModel"]:
current_device = get_current_device()
if add_valuehead:
model: "AutoModelForCausalLMWithValueHead" = AutoModelForCausalLMWithValueHead.from_pretrained(
model: AutoModelForCausalLMWithValueHead = AutoModelForCausalLMWithValueHead.from_pretrained(
model_path, torch_dtype=torch.float16, device_map=current_device
)
if not is_trainable:
@@ -111,7 +112,7 @@ def load_dataset_module(**kwargs) -> "DatasetModule":


def patch_valuehead_model() -> None:
def post_init(self: "AutoModelForCausalLMWithValueHead", state_dict: Dict[str, "torch.Tensor"]) -> None:
def post_init(self: "AutoModelForCausalLMWithValueHead", state_dict: dict[str, "torch.Tensor"]) -> None:
state_dict = {k[7:]: state_dict[k] for k in state_dict.keys() if k.startswith("v_head.")}
self.v_head.load_state_dict(state_dict, strict=False)
del state_dict

@@ -21,7 +21,7 @@ import json
import os
from collections.abc import Mapping
from pathlib import Path
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union
from typing import TYPE_CHECKING, Any, Callable, Optional, Union

import torch
from transformers import Trainer
@@ -63,12 +63,10 @@ logger = logging.get_logger(__name__)


class DummyOptimizer(torch.optim.Optimizer):
r"""
A dummy optimizer used for the GaLore or APOLLO algorithm.
"""
r"""A dummy optimizer used for the GaLore or APOLLO algorithm."""

def __init__(
self, lr: float = 1e-3, optimizer_dict: Optional[Dict["torch.nn.Parameter", "torch.optim.Optimizer"]] = None
self, lr: float = 1e-3, optimizer_dict: Optional[dict["torch.nn.Parameter", "torch.optim.Optimizer"]] = None
) -> None:
dummy_tensor = torch.randn(1, 1)
self.optimizer_dict = optimizer_dict
@@ -112,8 +110,7 @@ def create_modelcard_and_push(
def create_ref_model(
model_args: "ModelArguments", finetuning_args: "FinetuningArguments", add_valuehead: bool = False
) -> Optional[Union["PreTrainedModel", "AutoModelForCausalLMWithValueHead"]]:
r"""
Creates reference model for PPO/DPO training. Evaluation mode is not supported.
r"""Create reference model for PPO/DPO training. Evaluation mode is not supported.

The valuehead parameter is randomly initialized since it is useless for PPO training.
"""
@@ -148,9 +145,7 @@ def create_ref_model(
def create_reward_model(
model: "AutoModelForCausalLMWithValueHead", model_args: "ModelArguments", finetuning_args: "FinetuningArguments"
) -> Optional["AutoModelForCausalLMWithValueHead"]:
r"""
Creates reward model for PPO training.
"""
r"""Create reward model for PPO training."""
if finetuning_args.reward_model_type == "api":
assert finetuning_args.reward_model.startswith("http"), "Please provide full url."
logger.info_rank0(f"Use reward server {finetuning_args.reward_model}")
@@ -189,10 +184,8 @@ def create_reward_model(
return reward_model


def _get_decay_parameter_names(model: "PreTrainedModel") -> List[str]:
r"""
Returns a list of names of parameters with weight decay. (weights in non-layernorm layers)
"""
def _get_decay_parameter_names(model: "PreTrainedModel") -> list[str]:
r"""Return a list of names of parameters with weight decay. (weights in non-layernorm layers)."""
decay_parameters = get_parameter_names(model, ALL_LAYERNORM_LAYERS)
decay_parameters = [name for name in decay_parameters if "bias" not in name]
return decay_parameters
@@ -208,7 +201,7 @@ def _create_galore_optimizer(
else:
galore_targets = finetuning_args.galore_target

galore_params: List["torch.nn.Parameter"] = []
galore_params: list[torch.nn.Parameter] = []
for name, module in model.named_modules():
if isinstance(module, torch.nn.Linear) and any(target in name for target in galore_targets):
for param in module.parameters():
@@ -224,7 +217,7 @@ def _create_galore_optimizer(

id_galore_params = {id(param) for param in galore_params}
decay_params, nodecay_params = [], [] # they are non-galore parameters
trainable_params: List["torch.nn.Parameter"] = [] # galore_params + decay_params + nodecay_params
trainable_params: list[torch.nn.Parameter] = [] # galore_params + decay_params + nodecay_params
decay_param_names = _get_decay_parameter_names(model)
for name, param in model.named_parameters():
if param.requires_grad:
@@ -251,7 +244,7 @@ def _create_galore_optimizer(
if training_args.gradient_accumulation_steps != 1:
raise ValueError("Per-layer GaLore does not support gradient accumulation.")

optimizer_dict: Dict["torch.Tensor", "torch.optim.Optimizer"] = {}
optimizer_dict: dict[torch.Tensor, torch.optim.Optimizer] = {}
for param in nodecay_params:
param_groups = [dict(params=[param], weight_decay=0.0)]
optimizer_dict[param] = optim_class(param_groups, **optim_kwargs)
@@ -296,7 +289,7 @@ def _create_apollo_optimizer(
else:
apollo_targets = finetuning_args.apollo_target

apollo_params: List["torch.nn.Parameter"] = []
apollo_params: list[torch.nn.Parameter] = []
for name, module in model.named_modules():
if isinstance(module, torch.nn.Linear) and any(target in name for target in apollo_targets):
for param in module.parameters():
@@ -315,7 +308,7 @@ def _create_apollo_optimizer(

id_apollo_params = {id(param) for param in apollo_params}
decay_params, nodecay_params = [], [] # they are non-apollo parameters
trainable_params: List["torch.nn.Parameter"] = [] # apollo_params + decay_params + nodecay_params
trainable_params: list[torch.nn.Parameter] = [] # apollo_params + decay_params + nodecay_params
decay_param_names = _get_decay_parameter_names(model)
for name, param in model.named_parameters():
if param.requires_grad:
@@ -338,7 +331,7 @@ def _create_apollo_optimizer(
if training_args.gradient_accumulation_steps != 1:
raise ValueError("Per-layer APOLLO does not support gradient accumulation.")

optimizer_dict: Dict["torch.Tensor", "torch.optim.Optimizer"] = {}
optimizer_dict: dict[torch.Tensor, torch.optim.Optimizer] = {}
for param in nodecay_params:
param_groups = [dict(params=[param], weight_decay=0.0)]
optimizer_dict[param] = optim_class(param_groups, **optim_kwargs)
@@ -380,7 +373,7 @@ def _create_loraplus_optimizer(
embedding_lr = finetuning_args.loraplus_lr_embedding

decay_param_names = _get_decay_parameter_names(model)
param_dict: Dict[str, List["torch.nn.Parameter"]] = {
param_dict: dict[str, list[torch.nn.Parameter]] = {
"lora_a": [],
"lora_b": [],
"lora_b_nodecay": [],
@@ -524,7 +517,7 @@ def create_custom_scheduler(
) -> None:
if optimizer is not None and isinstance(optimizer, DummyOptimizer):
optimizer_dict = optimizer.optimizer_dict
scheduler_dict: Dict["torch.nn.Parameter", "torch.optim.lr_scheduler.LRScheduler"] = {}
scheduler_dict: dict[torch.nn.Parameter, torch.optim.lr_scheduler.LRScheduler] = {}

for param in optimizer_dict.keys():
scheduler_dict[param] = get_scheduler(
@@ -544,13 +537,13 @@ def create_custom_scheduler(

def get_batch_logps(
logits: "torch.Tensor", labels: "torch.Tensor", label_pad_token_id: int = IGNORE_INDEX
) -> Tuple["torch.Tensor", "torch.Tensor"]:
r"""
Computes the log probabilities of the given labels under the given logits.
) -> tuple["torch.Tensor", "torch.Tensor"]:
r"""Compute the log probabilities of the given labels under the given logits.

Returns:
logps: A tensor of shape (batch_size,) containing the sum of log probabilities.
valid_length: A tensor of shape (batch_size,) containing the number of non-masked tokens.

"""
if logits.shape[:-1] != labels.shape:
raise ValueError("Logits (batchsize x seqlen) and labels must have the same shape.")
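The body of this function is elided by the hunk; a sketch of how such a helper is typically implemented (the common trl-style recipe, stated here as an assumption rather than this repository's exact code):

    # Labels are shifted against the logits (next-token prediction),
    # padding positions are masked out, then per-token log-probs are summed.
    labels = labels[:, 1:].clone()
    logits = logits[:, :-1, :]
    loss_mask = labels != label_pad_token_id
    labels[labels == label_pad_token_id] = 0  # any valid index; masked out anyway
    per_token_logps = torch.gather(
        logits.log_softmax(-1), dim=2, index=labels.unsqueeze(2)
    ).squeeze(2)
    return (per_token_logps * loss_mask).sum(-1), loss_mask.sum(-1)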
@@ -564,12 +557,10 @@ def get_batch_logps(


def nested_detach(
tensors: Union["torch.Tensor", List["torch.Tensor"], Tuple["torch.Tensor"], Dict[str, "torch.Tensor"]],
tensors: Union["torch.Tensor", list["torch.Tensor"], tuple["torch.Tensor"], dict[str, "torch.Tensor"]],
clone: bool = False,
):
r"""
Detach `tensors` (even if it's a nested list/tuple/dict of tensors).
"""
r"""Detach `tensors` (even if it's a nested list/tuple/dict of tensors)."""
if isinstance(tensors, (list, tuple)):
return type(tensors)(nested_detach(t, clone=clone) for t in tensors)
elif isinstance(tensors, Mapping):
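The recursion preserves the container shape: lists and tuples are rebuilt via type(tensors)(...) and mappings are handled by the Mapping branch. A small usage example:

    batch = {"input_ids": torch.ones(2, 4), "labels": [torch.zeros(2)]}
    detached = nested_detach(batch, clone=True)  # same structure, detached copies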
@@ -585,9 +576,7 @@ def nested_detach(


def get_swanlab_callback(finetuning_args: "FinetuningArguments") -> "TrainerCallback":
r"""
Gets the callback for logging to SwanLab.
"""
r"""Get the callback for logging to SwanLab."""
import swanlab # type: ignore
from swanlab.integration.transformers import SwanLabCallback # type: ignore

@@ -624,7 +613,7 @@ def get_swanlab_callback(finetuning_args: "FinetuningArguments") -> "TrainerCallback":

def get_ray_trainer(
training_function: Callable,
train_loop_config: Dict[str, Any],
train_loop_config: dict[str, Any],
ray_args: "RayArguments",
) -> "TorchTrainer":
if not ray_args.use_ray:

@@ -14,7 +14,7 @@

import os
import shutil
from typing import TYPE_CHECKING, Any, Dict, List, Optional
from typing import TYPE_CHECKING, Any, Optional

import torch
import torch.distributed as dist
@@ -48,9 +48,9 @@ if TYPE_CHECKING:
logger = logging.get_logger(__name__)


def _training_function(config: Dict[str, Any]) -> None:
def _training_function(config: dict[str, Any]) -> None:
args = config.get("args")
callbacks: List[Any] = config.get("callbacks")
callbacks: list[Any] = config.get("callbacks")
model_args, data_args, training_args, finetuning_args, generating_args = get_train_args(args)

callbacks.append(LogCallback())
@@ -84,7 +84,7 @@ def _training_function(config: Dict[str, Any]) -> None:
logger.warning(f"Failed to destroy process group: {e}.")


def run_exp(args: Optional[Dict[str, Any]] = None, callbacks: Optional[List["TrainerCallback"]] = None) -> None:
def run_exp(args: Optional[dict[str, Any]] = None, callbacks: Optional[list["TrainerCallback"]] = None) -> None:
args = read_args(args)
if "-h" in args or "--help" in args:
get_train_args(args)
@@ -103,7 +103,7 @@ def run_exp(args: Optional[Dict[str, Any]] = None, callbacks: Optional[List["TrainerCallback"]] = None) -> None:
_training_function(config={"args": args, "callbacks": callbacks})


def export_model(args: Optional[Dict[str, Any]] = None) -> None:
def export_model(args: Optional[dict[str, Any]] = None) -> None:
model_args, data_args, finetuning_args, _ = get_infer_args(args)

if model_args.export_dir is None:

@@ -14,7 +14,8 @@

import json
import os
from typing import TYPE_CHECKING, Any, Dict, Generator, List, Optional, Tuple
from collections.abc import Generator
from typing import TYPE_CHECKING, Any, Optional

from transformers.utils import is_torch_npu_available

@@ -37,15 +38,12 @@ if is_gradio_available():


def _escape_html(text: str) -> str:
r"""
Escapes HTML characters.
"""
r"""Escape HTML characters."""
return text.replace("<", "&lt;").replace(">", "&gt;")


def _format_response(text: str, lang: str, escape_html: bool, thought_words: Tuple[str, str]) -> str:
r"""
Post-processes the response text.
def _format_response(text: str, lang: str, escape_html: bool, thought_words: tuple[str, str]) -> str:
r"""Post-process the response text.

Based on: https://huggingface.co/spaces/Lyte/DeepSeek-R1-Distill-Qwen-1.5B-Demo-GGUF/blob/main/app.py
"""
@@ -74,7 +72,7 @@ class WebChatModel(ChatModel):
def __init__(self, manager: "Manager", demo_mode: bool = False, lazy_init: bool = True) -> None:
self.manager = manager
self.demo_mode = demo_mode
self.engine: Optional["BaseEngine"] = None
self.engine: Optional[BaseEngine] = None

if not lazy_init: # read arguments from command line
super().__init__()
@@ -160,14 +158,13 @@ class WebChatModel(ChatModel):

@staticmethod
def append(
chatbot: List[Dict[str, str]],
messages: List[Dict[str, str]],
chatbot: list[dict[str, str]],
messages: list[dict[str, str]],
role: str,
query: str,
escape_html: bool,
) -> Tuple[List[Dict[str, str]], List[Dict[str, str]], str]:
r"""
Adds the user input to chatbot.
) -> tuple[list[dict[str, str]], list[dict[str, str]], str]:
r"""Add the user input to chatbot.

Inputs: infer.chatbot, infer.messages, infer.role, infer.query, infer.escape_html
Output: infer.chatbot, infer.messages, infer.query
@@ -180,8 +177,8 @@ class WebChatModel(ChatModel):

def stream(
self,
chatbot: List[Dict[str, str]],
messages: List[Dict[str, str]],
chatbot: list[dict[str, str]],
messages: list[dict[str, str]],
lang: str,
system: str,
tools: str,
@@ -193,9 +190,8 @@ class WebChatModel(ChatModel):
temperature: float,
skip_special_tokens: bool,
escape_html: bool,
) -> Generator[Tuple[List[Dict[str, str]], List[Dict[str, str]]], None, None]:
r"""
Generates output text in stream.
) -> Generator[tuple[list[dict[str, str]], list[dict[str, str]]], None, None]:
r"""Generate output text in stream.

Inputs: infer.chatbot, infer.messages, infer.system, infer.tools, infer.image, infer.video, ...
Output: infer.chatbot, infer.messages

@@ -17,7 +17,7 @@ import os
import signal
from collections import defaultdict
from datetime import datetime
from typing import Any, Dict, Optional, Union
from typing import Any, Optional, Union

from psutil import Process
from yaml import safe_dump, safe_load
@@ -44,9 +44,7 @@ USER_CONFIG = "user_config.yaml"


def abort_process(pid: int) -> None:
r"""
Aborts the processes recursively in a bottom-up way.
"""
r"""Abort the processes recursively in a bottom-up way."""
try:
children = Process(pid).children()
if children:
@@ -59,9 +57,7 @@ def abort_process(pid: int) -> None:


def get_save_dir(*paths: str) -> os.PathLike:
r"""
Gets the path to saved model checkpoints.
"""
r"""Get the path to saved model checkpoints."""
if os.path.sep in paths[-1]:
logger.warning_rank0("Found complex path, some features may be not available.")
return paths[-1]
@@ -71,16 +67,12 @@ def get_save_dir(*paths: str) -> os.PathLike:


def _get_config_path() -> os.PathLike:
r"""
Gets the path to user config.
"""
r"""Get the path to user config."""
return os.path.join(DEFAULT_CACHE_DIR, USER_CONFIG)


def load_config() -> Dict[str, Union[str, Dict[str, Any]]]:
r"""
Loads user config if exists.
"""
def load_config() -> dict[str, Union[str, dict[str, Any]]]:
r"""Load user config if exists."""
try:
with open(_get_config_path(), encoding="utf-8") as f:
return safe_load(f)
@@ -89,9 +81,7 @@ def load_config() -> Dict[str, Union[str, Dict[str, Any]]]:


def save_config(lang: str, model_name: Optional[str] = None, model_path: Optional[str] = None) -> None:
r"""
Saves user config.
"""
r"""Save user config."""
os.makedirs(DEFAULT_CACHE_DIR, exist_ok=True)
user_config = load_config()
user_config["lang"] = lang or user_config["lang"]
@@ -106,11 +96,9 @@ def save_config(lang: str, model_name: Optional[str] = None, model_path: Optional[str] = None) -> None:


def get_model_path(model_name: str) -> str:
r"""
Gets the model path according to the model name.
"""
r"""Get the model path according to the model name."""
user_config = load_config()
path_dict: Dict["DownloadSource", str] = SUPPORTED_MODELS.get(model_name, defaultdict(str))
path_dict: dict[DownloadSource, str] = SUPPORTED_MODELS.get(model_name, defaultdict(str))
model_path = user_config["path_dict"].get(model_name, "") or path_dict.get(DownloadSource.DEFAULT, "")
if (
use_modelscope()
@@ -130,30 +118,22 @@ def get_model_path(model_name: str) -> str:


def get_template(model_name: str) -> str:
r"""
Gets the template name if the model is a chat/distill/instruct model.
"""
r"""Get the template name if the model is a chat/distill/instruct model."""
return DEFAULT_TEMPLATE.get(model_name, "default")


def get_time() -> str:
r"""
Gets current date and time.
"""
r"""Get current date and time."""
return datetime.now().strftime(r"%Y-%m-%d-%H-%M-%S")


def is_multimodal(model_name: str) -> bool:
r"""
Judges if the model is a vision language model.
"""
r"""Judge if the model is a vision language model."""
return model_name in MULTIMODAL_SUPPORTED_MODELS


def load_dataset_info(dataset_dir: str) -> Dict[str, Dict[str, Any]]:
r"""
Loads dataset_info.json.
"""
def load_dataset_info(dataset_dir: str) -> dict[str, dict[str, Any]]:
r"""Load dataset_info.json."""
if dataset_dir == "ONLINE" or dataset_dir.startswith("REMOTE:"):
logger.info_rank0(f"dataset_dir is {dataset_dir}, using online dataset.")
return {}
@@ -166,10 +146,8 @@ def load_dataset_info(dataset_dir: str) -> Dict[str, Dict[str, Any]]:
return {}


def load_args(config_path: str) -> Optional[Dict[str, Any]]:
r"""
Loads the training configuration from config path.
"""
def load_args(config_path: str) -> Optional[dict[str, Any]]:
r"""Load the training configuration from config path."""
try:
with open(config_path, encoding="utf-8") as f:
return safe_load(f)
@@ -177,26 +155,20 @@ def load_args(config_path: str) -> Optional[Dict[str, Any]]:
return None


def save_args(config_path: str, config_dict: Dict[str, Any]) -> None:
r"""
Saves the training configuration to config path.
"""
def save_args(config_path: str, config_dict: dict[str, Any]) -> None:
r"""Save the training configuration to config path."""
with open(config_path, "w", encoding="utf-8") as f:
safe_dump(config_dict, f)


def _clean_cmd(args: Dict[str, Any]) -> Dict[str, Any]:
r"""
Removes args with NoneType or False or empty string value.
"""
def _clean_cmd(args: dict[str, Any]) -> dict[str, Any]:
r"""Remove args with NoneType or False or empty string value."""
no_skip_keys = ["packing"]
return {k: v for k, v in args.items() if (k in no_skip_keys) or (v is not None and v is not False and v != "")}


def gen_cmd(args: Dict[str, Any]) -> str:
r"""
Generates CLI commands for previewing.
"""
def gen_cmd(args: dict[str, Any]) -> str:
r"""Generate CLI commands for previewing."""
cmd_lines = ["llamafactory-cli train "]
for k, v in _clean_cmd(args).items():
if isinstance(v, dict):
||||
@@ -215,10 +187,8 @@ def gen_cmd(args: Dict[str, Any]) -> str:
|
||||
return cmd_text
|
||||
|
||||
|
||||
def save_cmd(args: Dict[str, Any]) -> str:
|
||||
r"""
|
||||
Saves CLI commands to launch training.
|
||||
"""
|
||||
def save_cmd(args: dict[str, Any]) -> str:
|
||||
r"""Save CLI commands to launch training."""
|
||||
output_dir = args["output_dir"]
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
with open(os.path.join(output_dir, TRAINING_ARGS), "w", encoding="utf-8") as f:
|
||||
@@ -228,9 +198,7 @@ def save_cmd(args: Dict[str, Any]) -> str:
|
||||
|
||||
|
||||
def load_eval_results(path: os.PathLike) -> str:
|
||||
r"""
|
||||
Gets scores after evaluation.
|
||||
"""
|
||||
r"""Get scores after evaluation."""
|
||||
with open(path, encoding="utf-8") as f:
|
||||
result = json.dumps(json.load(f), indent=4)
|
||||
|
||||
@@ -238,9 +206,7 @@ def load_eval_results(path: os.PathLike) -> str:
|
||||
|
||||
|
||||
def create_ds_config() -> None:
|
||||
r"""
|
||||
Creates deepspeed config in the current directory.
|
||||
"""
|
||||
r"""Create deepspeed config in the current directory."""
|
||||
os.makedirs(DEFAULT_CACHE_DIR, exist_ok=True)
|
||||
ds_config = {
|
||||
"train_batch_size": "auto",
|
||||
|
||||
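
A note on the pattern applied throughout this file: from Python 3.9 (PEP 585), the builtin containers are subscriptable, so annotations can use dict/list/tuple directly and the typing.Dict/List/Tuple imports can be dropped; Optional and Union still come from typing until the X | Y syntax of Python 3.10. A minimal, self-contained sketch of the target style (illustrative only, not code from this commit):

from typing import Optional

def best_score(scores: dict[str, list[int]]) -> Optional[tuple[str, float]]:
    # Builtin generics replace typing.Dict/List/Tuple on Python 3.9+.
    best: Optional[tuple[str, float]] = None
    for name, values in scores.items():
        mean = sum(values) / len(values)
        if best is None or mean > best[1]:
            best = (name, mean)
    return best

print(best_score({"a": [1, 2, 3], "b": [4, 5]}))  # -> ('b', 4.5)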

@@ -13,7 +13,7 @@
# limitations under the License.

import json
from typing import TYPE_CHECKING, Dict, Tuple
from typing import TYPE_CHECKING

from ...data import Role
from ...extras.packages import is_gradio_available
@@ -31,9 +31,7 @@ if TYPE_CHECKING:


def check_json_schema(text: str, lang: str) -> None:
r"""
Checks if the json schema is valid.
"""
r"""Check if the json schema is valid."""
try:
tools = json.loads(text)
if tools:
@@ -49,7 +47,7 @@ def check_json_schema(text: str, lang: str) -> None:

def create_chat_box(
engine: "Engine", visible: bool = False
) -> Tuple["Component", "Component", Dict[str, "Component"]]:
) -> tuple["Component", "Component", dict[str, "Component"]]:
lang = engine.manager.get_elem_by_id("top.lang")
with gr.Column(visible=visible) as chat_box:
chatbot = gr.Chatbot(type="messages", show_copy_button=True)
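
Why "Component" stays quoted in the rewritten signature above: the name is imported only under typing.TYPE_CHECKING, so it does not exist at runtime, and signature annotations are evaluated when the def statement runs. A small sketch of the convention (the describe helper is invented for illustration):

from typing import TYPE_CHECKING

if TYPE_CHECKING:  # visible to the type checker, never imported at runtime
    from gradio.components import Component

def describe(elem: "Component") -> str:
    # The string annotation defers the name lookup, so this module imports
    # cleanly even though Component is undefined at runtime.
    return elem.__class__.__name__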

@@ -14,7 +14,7 @@

import json
import os
from typing import TYPE_CHECKING, Any, Dict, List, Tuple
from typing import TYPE_CHECKING, Any

from ...extras.constants import DATA_CONFIG
from ...extras.packages import is_gradio_available
@@ -40,9 +40,7 @@ def next_page(page_index: int, total_num: int) -> int:


def can_preview(dataset_dir: str, dataset: list) -> "gr.Button":
r"""
Checks if the dataset is a local dataset.
"""
r"""Check if the dataset is a local dataset."""
try:
with open(os.path.join(dataset_dir, DATA_CONFIG), encoding="utf-8") as f:
dataset_info = json.load(f)
@@ -59,9 +57,7 @@ def can_preview(dataset_dir: str, dataset: list) -> "gr.Button":
return gr.Button(interactive=False)


def _load_data_file(file_path: str) -> List[Any]:
def _load_data_file(file_path: str) -> list[Any]:
with open(file_path, encoding="utf-8") as f:
if file_path.endswith(".json"):
return json.load(f)
@@ -69,10 +67,8 @@ def _load_data_file(file_path: str) -> List[Any]:
return list(f)


def get_preview(dataset_dir: str, dataset: list, page_index: int) -> Tuple[int, list, "gr.Column"]:
r"""
Gets the preview samples from the dataset.
"""
def get_preview(dataset_dir: str, dataset: list, page_index: int) -> tuple[int, list, "gr.Column"]:
r"""Get the preview samples from the dataset."""
with open(os.path.join(dataset_dir, DATA_CONFIG), encoding="utf-8") as f:
dataset_info = json.load(f)

@@ -87,7 +83,7 @@ def get_preview(dataset_dir: str, dataset: list, page_index: int) -> Tuple[int,
return len(data), data[PAGE_SIZE * page_index : PAGE_SIZE * (page_index + 1)], gr.Column(visible=True)


def create_preview_box(dataset_dir: "gr.Textbox", dataset: "gr.Dropdown") -> Dict[str, "Component"]:
def create_preview_box(dataset_dir: "gr.Textbox", dataset: "gr.Dropdown") -> dict[str, "Component"]:
data_preview_btn = gr.Button(interactive=False, scale=1)
with gr.Column(visible=False, elem_classes="modal-box") as preview_box:
with gr.Row():
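
_load_data_file above dispatches on the file extension: a .json file is parsed as one JSON document, and anything else falls through to list(f), one string per line. A standalone rough equivalent (the .jsonl branch sits in the lines elided by the hunk and is an assumption reconstructed from context):

import json
from typing import Any

def load_data_file(file_path: str) -> list[Any]:
    with open(file_path, encoding="utf-8") as f:
        if file_path.endswith(".json"):
            return json.load(f)  # the whole file is a single JSON document
        elif file_path.endswith(".jsonl"):
            return [json.loads(line) for line in f]  # one JSON object per line
        else:
            return list(f)  # plain text: one raw line per sample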

@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import TYPE_CHECKING, Dict
from typing import TYPE_CHECKING

from ...extras.packages import is_gradio_available
from ..common import DEFAULT_DATA_DIR
@@ -30,7 +30,7 @@ if TYPE_CHECKING:
from ..engine import Engine


def create_eval_tab(engine: "Engine") -> Dict[str, "Component"]:
def create_eval_tab(engine: "Engine") -> dict[str, "Component"]:
input_elems = engine.manager.get_base_elems()
elem_dict = dict()

@@ -12,7 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import TYPE_CHECKING, Dict, Generator, List, Union
from collections.abc import Generator
from typing import TYPE_CHECKING, Union

from ...extras.constants import PEFT_METHODS
from ...extras.misc import torch_gc
@@ -35,7 +36,7 @@ if TYPE_CHECKING:
GPTQ_BITS = ["8", "4", "3", "2"]


def can_quantize(checkpoint_path: Union[str, List[str]]) -> "gr.Dropdown":
def can_quantize(checkpoint_path: Union[str, list[str]]) -> "gr.Dropdown":
if isinstance(checkpoint_path, list) and len(checkpoint_path) != 0:
return gr.Dropdown(value="none", interactive=False)
else:
@@ -47,7 +48,7 @@ def save_model(
model_name: str,
model_path: str,
finetuning_type: str,
checkpoint_path: Union[str, List[str]],
checkpoint_path: Union[str, list[str]],
template: str,
export_size: int,
export_quantization_bit: str,
@@ -106,7 +107,7 @@ def save_model(
yield ALERTS["info_exported"][lang]


def create_export_tab(engine: "Engine") -> Dict[str, "Component"]:
def create_export_tab(engine: "Engine") -> dict[str, "Component"]:
with gr.Row():
export_size = gr.Slider(minimum=1, maximum=100, value=5, step=1)
export_quantization_bit = gr.Dropdown(choices=["none"] + GPTQ_BITS, value="none")
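
This hunk also moves Generator from typing to collections.abc: the typing aliases for the container ABCs are deprecated, and the ABCs themselves are subscriptable since Python 3.9. A sketch of annotating a yield-based status function in the style of save_model above (the function and its messages are invented):

from collections.abc import Generator

def export_steps(n_shards: int) -> Generator[str, None, None]:
    # Generator[YieldType, SendType, ReturnType]
    for i in range(n_shards):
        yield f"exporting shard {i + 1}/{n_shards}"
    yield "done"

for msg in export_steps(2):
    print(msg)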

@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import TYPE_CHECKING, Dict
from typing import TYPE_CHECKING

from ...extras.packages import is_gradio_available
from ..common import is_multimodal
@@ -29,7 +29,7 @@ if TYPE_CHECKING:
from ..engine import Engine


def create_infer_tab(engine: "Engine") -> Dict[str, "Component"]:
def create_infer_tab(engine: "Engine") -> dict[str, "Component"]:
input_elems = engine.manager.get_base_elems()
elem_dict = dict()

@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import TYPE_CHECKING, Dict
from typing import TYPE_CHECKING

from ...data import TEMPLATES
from ...extras.constants import METHODS, SUPPORTED_MODELS
@@ -29,7 +29,7 @@ if TYPE_CHECKING:
from gradio.components import Component


def create_top() -> Dict[str, "Component"]:
def create_top() -> dict[str, "Component"]:
with gr.Row():
lang = gr.Dropdown(choices=["en", "ru", "zh", "ko", "ja"], value=None, scale=1)
available_models = list(SUPPORTED_MODELS.keys()) + ["Custom"]

@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import TYPE_CHECKING, Dict
from typing import TYPE_CHECKING

from transformers.trainer_utils import SchedulerType

@@ -34,7 +34,7 @@ if TYPE_CHECKING:
from ..engine import Engine


def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
def create_train_tab(engine: "Engine") -> dict[str, "Component"]:
input_elems = engine.manager.get_base_elems()
elem_dict = dict()

@@ -382,8 +382,8 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
resume_btn.change(engine.runner.monitor, outputs=output_elems, concurrency_limit=None)

lang = engine.manager.get_elem_by_id("top.lang")
model_name: "gr.Dropdown" = engine.manager.get_elem_by_id("top.model_name")
finetuning_type: "gr.Dropdown" = engine.manager.get_elem_by_id("top.finetuning_type")
model_name: gr.Dropdown = engine.manager.get_elem_by_id("top.model_name")
finetuning_type: gr.Dropdown = engine.manager.get_elem_by_id("top.finetuning_type")

arg_save_btn.click(engine.runner.save_args, input_elems, output_elems, concurrency_limit=None)
arg_load_btn.click(
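
The model_name/finetuning_type change above drops the quotes only on variable annotations. That is consistent with pyupgrade's UP037 rule (an assumption about the tooling, not stated in the commit): annotations on assignments are never evaluated at runtime, so the forward-reference quotes are dead weight there, while quotes in signatures must stay when the name is imported only for type checking. Sketch with illustrative names:

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    from gradio.components import Component

def names(elems: dict[str, "Component"]) -> list[str]:
    # Signature annotations are evaluated when `def` runs, so the quotes stay.
    seen: list[Component] = []  # assignment annotations are never evaluated: quotes can go
    seen.extend(elems.values())
    return [elem.__class__.__name__ for elem in seen]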

@@ -14,7 +14,7 @@

import json
import os
from typing import Any, Dict, List, Optional, Tuple
from typing import Any, Optional

from transformers.trainer_utils import get_last_checkpoint

@@ -39,8 +39,7 @@ if is_gradio_available():


def can_quantize(finetuning_type: str) -> "gr.Dropdown":
r"""
Judges if the quantization is available in this finetuning type.
r"""Judge if the quantization is available in this finetuning type.

Inputs: top.finetuning_type
Outputs: top.quantization_bit
@@ -52,8 +51,7 @@ def can_quantize(finetuning_type: str) -> "gr.Dropdown":


def can_quantize_to(quantization_method: str) -> "gr.Dropdown":
r"""
Gets the available quantization bits.
r"""Get the available quantization bits.

Inputs: top.quantization_method
Outputs: top.quantization_bit
@@ -68,9 +66,8 @@ def can_quantize_to(quantization_method: str) -> "gr.Dropdown":
return gr.Dropdown(choices=available_bits)


def change_stage(training_stage: str = list(TRAINING_STAGES.keys())[0]) -> Tuple[List[str], bool]:
r"""
Modifys states after changing the training stage.
def change_stage(training_stage: str = list(TRAINING_STAGES.keys())[0]) -> tuple[list[str], bool]:
r"""Modify states after changing the training stage.

Inputs: train.training_stage
Outputs: train.dataset, train.packing
@@ -78,9 +75,8 @@ def change_stage(training_stage: str = list(TRAINING_STAGES.keys())[0]) -> Tuple
return [], TRAINING_STAGES[training_stage] == "pt"


def get_model_info(model_name: str) -> Tuple[str, str]:
r"""
Gets the necessary information of this model.
def get_model_info(model_name: str) -> tuple[str, str]:
r"""Get the necessary information of this model.

Inputs: top.model_name
Outputs: top.model_path, top.template
@@ -88,9 +84,8 @@ def get_model_info(model_name: str) -> Tuple[str, str]:
return get_model_path(model_name), get_template(model_name)


def get_trainer_info(lang: str, output_path: os.PathLike, do_train: bool) -> Tuple[str, "gr.Slider", Dict[str, Any]]:
r"""
Gets training information for monitor.
def get_trainer_info(lang: str, output_path: os.PathLike, do_train: bool) -> tuple[str, "gr.Slider", dict[str, Any]]:
r"""Get training information for monitor.

If do_train is True:
Inputs: top.lang, train.output_path
@@ -110,7 +105,7 @@ def get_trainer_info(lang: str, output_path: os.PathLike, do_train: bool) -> Tup

trainer_log_path = os.path.join(output_path, TRAINER_LOG)
if os.path.isfile(trainer_log_path):
trainer_log: List[Dict[str, Any]] = []
trainer_log: list[dict[str, Any]] = []
with open(trainer_log_path, encoding="utf-8") as f:
for line in f:
trainer_log.append(json.loads(line))
@@ -143,8 +138,7 @@ def get_trainer_info(lang: str, output_path: os.PathLike, do_train: bool) -> Tup


def list_checkpoints(model_name: str, finetuning_type: str) -> "gr.Dropdown":
r"""
Lists all available checkpoints.
r"""List all available checkpoints.

Inputs: top.model_name, top.finetuning_type
Outputs: top.checkpoint_path
@@ -166,8 +160,7 @@ def list_checkpoints(model_name: str, finetuning_type: str) -> "gr.Dropdown":


def list_config_paths(current_time: str) -> "gr.Dropdown":
r"""
Lists all the saved configuration files.
r"""List all the saved configuration files.

Inputs: train.current_time
Outputs: train.config_path
@@ -182,8 +175,7 @@ def list_config_paths(current_time: str) -> "gr.Dropdown":


def list_datasets(dataset_dir: str = None, training_stage: str = list(TRAINING_STAGES.keys())[0]) -> "gr.Dropdown":
r"""
Lists all available datasets in the dataset dir for the training stage.
r"""List all available datasets in the dataset dir for the training stage.

Inputs: *.dataset_dir, *.training_stage
Outputs: *.dataset
@@ -195,8 +187,7 @@ def list_datasets(dataset_dir: str = None, training_stage: str = list(TRAINING_S


def list_output_dirs(model_name: Optional[str], finetuning_type: str, current_time: str) -> "gr.Dropdown":
r"""
Lists all the directories that can resume from.
r"""List all the directories that can resume from.

Inputs: top.model_name, top.finetuning_type, train.current_time
Outputs: train.output_dir
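
The docstring churn across these files follows a single convention: the summary starts on the same line as the opening quotes and uses the imperative mood ("Get", "List", "Modify"), with any detail, such as the Inputs/Outputs notes above, placed after a blank line. This matches PEP 257 and the pydocstyle D2xx checks (the enforcement tooling is inferred, not stated in the diff). Before and after, schematically:

def get_time_old() -> str:
    r"""
    Gets current date and time.
    """
    ...

def get_time_new() -> str:
    r"""Get current date and time."""
    ...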

@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import TYPE_CHECKING, Any, Dict
from typing import TYPE_CHECKING, Any

from .chatter import WebChatModel
from .common import create_ds_config, get_time, load_config
@@ -26,9 +26,7 @@ if TYPE_CHECKING:


class Engine:
r"""
A general engine to control the behaviors of Web UI.
"""
r"""A general engine to control the behaviors of Web UI."""

def __init__(self, demo_mode: bool = False, pure_chat: bool = False) -> None:
self.demo_mode = demo_mode
@@ -39,11 +37,9 @@ class Engine:
if not demo_mode:
create_ds_config()

def _update_component(self, input_dict: Dict[str, Dict[str, Any]]) -> Dict["Component", "Component"]:
r"""
Updates gradio components according to the (elem_id, properties) mapping.
"""
output_dict: Dict["Component", "Component"] = {}
def _update_component(self, input_dict: dict[str, dict[str, Any]]) -> dict["Component", "Component"]:
r"""Update gradio components according to the (elem_id, properties) mapping."""
output_dict: dict[Component, Component] = {}
for elem_id, elem_attr in input_dict.items():
elem = self.manager.get_elem_by_id(elem_id)
output_dict[elem] = elem.__class__(**elem_attr)
@@ -51,9 +47,7 @@ class Engine:
return output_dict

def resume(self):
r"""
Gets the initial value of gradio components and restores training status if necessary.
"""
r"""Get the initial value of gradio components and restore training status if necessary."""
user_config = load_config() if not self.demo_mode else {} # do not use config in demo mode
lang = user_config.get("lang", None) or "en"
init_dict = {"top.lang": {"value": lang}, "infer.chat_box": {"visible": self.chatter.loaded}}
@@ -79,9 +73,7 @@ class Engine:
yield self._update_component({"eval.resume_btn": {"value": True}})

def change_lang(self, lang: str):
r"""
Updates the displayed language of gradio components.
"""
r"""Update the displayed language of gradio components."""
return {
elem: elem.__class__(**LOCALES[elem_name][lang])
for elem_name, elem in self.manager.get_elem_iter()

@@ -48,7 +48,7 @@ def create_ui(demo_mode: bool = False) -> "gr.Blocks":
gr.DuplicateButton(value="Duplicate Space for private use", elem_classes="duplicate-button")

engine.manager.add_elems("top", create_top())
lang: "gr.Dropdown" = engine.manager.get_elem_by_id("top.lang")
lang: gr.Dropdown = engine.manager.get_elem_by_id("top.lang")

with gr.Tab("Train"):
engine.manager.add_elems("train", create_train_tab(engine))

@@ -12,7 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import TYPE_CHECKING, Dict, Generator, List, Set, Tuple
from collections.abc import Generator
from typing import TYPE_CHECKING


if TYPE_CHECKING:
@@ -20,54 +21,41 @@ if TYPE_CHECKING:


class Manager:
r"""
A class to manage all the gradio components in Web UI.
"""
r"""A class to manage all the gradio components in Web UI."""

def __init__(self) -> None:
self._id_to_elem: Dict[str, "Component"] = {}
self._elem_to_id: Dict["Component", str] = {}
self._id_to_elem: dict[str, Component] = {}
self._elem_to_id: dict[Component, str] = {}

def add_elems(self, tab_name: str, elem_dict: Dict[str, "Component"]) -> None:
r"""
Adds elements to manager.
"""
def add_elems(self, tab_name: str, elem_dict: dict[str, "Component"]) -> None:
r"""Add elements to manager."""
for elem_name, elem in elem_dict.items():
elem_id = f"{tab_name}.{elem_name}"
self._id_to_elem[elem_id] = elem
self._elem_to_id[elem] = elem_id

def get_elem_list(self) -> List["Component"]:
r"""
Returns the list of all elements.
"""
def get_elem_list(self) -> list["Component"]:
r"""Return the list of all elements."""
return list(self._id_to_elem.values())

def get_elem_iter(self) -> Generator[Tuple[str, "Component"], None, None]:
r"""
Returns an iterator over all elements with their names.
"""
def get_elem_iter(self) -> Generator[tuple[str, "Component"], None, None]:
r"""Return an iterator over all elements with their names."""
for elem_id, elem in self._id_to_elem.items():
yield elem_id.split(".")[-1], elem

def get_elem_by_id(self, elem_id: str) -> "Component":
r"""
Gets element by id.
r"""Get element by id.

Example: top.lang, train.dataset
"""
return self._id_to_elem[elem_id]

def get_id_by_elem(self, elem: "Component") -> str:
r"""
Gets id by element.
"""
r"""Get id by element."""
return self._elem_to_id[elem]

def get_base_elems(self) -> Set["Component"]:
r"""
Gets the base elements that are commonly used.
"""
def get_base_elems(self) -> set["Component"]:
r"""Get the base elements that are commonly used."""
return {
self._id_to_elem["top.lang"],
self._id_to_elem["top.model_name"],
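
Manager combines both halves of this commit: Generator now comes from collections.abc and is subscripted with a builtin tuple. A self-contained sketch of the same two-way registry pattern, simplified and with invented names:

from collections.abc import Generator

class Registry:
    def __init__(self) -> None:
        self._id_to_obj: dict[str, object] = {}

    def add(self, tab_name: str, name: str, obj: object) -> None:
        self._id_to_obj[f"{tab_name}.{name}"] = obj

    def items(self) -> Generator[tuple[str, object], None, None]:
        # Yield (short_name, obj) pairs, mirroring get_elem_iter's "tab.name" split.
        for full_key, obj in self._id_to_obj.items():
            yield full_key.split(".")[-1], obj

reg = Registry()
reg.add("top", "lang", object())
print([name for name, _ in reg.items()])  # -> ['lang']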

@@ -14,9 +14,10 @@

import json
import os
from collections.abc import Generator
from copy import deepcopy
from subprocess import Popen, TimeoutExpired
from typing import TYPE_CHECKING, Any, Dict, Generator, Optional
from typing import TYPE_CHECKING, Any, Optional

from transformers.trainer import TRAINING_ARGS_NAME
from transformers.utils import is_torch_npu_available
@@ -51,17 +52,16 @@ if TYPE_CHECKING:


class Runner:
r"""
A class to manage the running status of the trainers.
"""
r"""A class to manage the running status of the trainers."""

def __init__(self, manager: "Manager", demo_mode: bool = False) -> None:
r"""Init a runner."""
self.manager = manager
self.demo_mode = demo_mode
""" Resume """
self.trainer: Optional["Popen"] = None
self.trainer: Optional[Popen] = None
self.do_train = True
self.running_data: Dict["Component", Any] = None
self.running_data: dict[Component, Any] = None
""" State """
self.aborted = False
self.running = False
@@ -71,10 +71,8 @@ class Runner:
if self.trainer is not None:
abort_process(self.trainer.pid)

def _initialize(self, data: Dict["Component", Any], do_train: bool, from_preview: bool) -> str:
r"""
Validates the configuration.
"""
def _initialize(self, data: dict["Component", Any], do_train: bool, from_preview: bool) -> str:
r"""Validate the configuration."""
get = lambda elem_id: data[self.manager.get_elem_by_id(elem_id)]
lang, model_name, model_path = get("top.lang"), get("top.model_name"), get("top.model_path")
dataset = get("train.dataset") if do_train else get("eval.dataset")
@@ -116,9 +114,7 @@ class Runner:
return ""

def _finalize(self, lang: str, finish_info: str) -> str:
r"""
Cleans the cached memory and resets the runner.
"""
r"""Clean the cached memory and reset the runner."""
finish_info = ALERTS["info_aborted"][lang] if self.aborted else finish_info
gr.Info(finish_info)
self.trainer = None
@@ -128,10 +124,8 @@ class Runner:
torch_gc()
return finish_info

def _parse_train_args(self, data: Dict["Component", Any]) -> Dict[str, Any]:
r"""
Builds and validates the training arguments.
"""
def _parse_train_args(self, data: dict["Component", Any]) -> dict[str, Any]:
r"""Build and validate the training arguments."""
get = lambda elem_id: data[self.manager.get_elem_by_id(elem_id)]
model_name, finetuning_type = get("top.model_name"), get("top.finetuning_type")
user_config = load_config()
@@ -291,10 +285,8 @@ class Runner:

return args

def _parse_eval_args(self, data: Dict["Component", Any]) -> Dict[str, Any]:
r"""
Builds and validates the evaluation arguments.
"""
def _parse_eval_args(self, data: dict["Component", Any]) -> dict[str, Any]:
r"""Build and validate the evaluation arguments."""
get = lambda elem_id: data[self.manager.get_elem_by_id(elem_id)]
model_name, finetuning_type = get("top.model_name"), get("top.finetuning_type")
user_config = load_config()
@@ -345,10 +337,8 @@ class Runner:

return args

def _preview(self, data: Dict["Component", Any], do_train: bool) -> Generator[Dict["Component", str], None, None]:
r"""
Previews the training commands.
"""
def _preview(self, data: dict["Component", Any], do_train: bool) -> Generator[dict["Component", str], None, None]:
r"""Preview the training commands."""
output_box = self.manager.get_elem_by_id("{}.output_box".format("train" if do_train else "eval"))
error = self._initialize(data, do_train, from_preview=True)
if error:
@@ -358,10 +348,8 @@ class Runner:
args = self._parse_train_args(data) if do_train else self._parse_eval_args(data)
yield {output_box: gen_cmd(args)}

def _launch(self, data: Dict["Component", Any], do_train: bool) -> Generator[Dict["Component", Any], None, None]:
r"""
Starts the training process.
"""
def _launch(self, data: dict["Component", Any], do_train: bool) -> Generator[dict["Component", Any], None, None]:
r"""Start the training process."""
output_box = self.manager.get_elem_by_id("{}.output_box".format("train" if do_train else "eval"))
error = self._initialize(data, do_train, from_preview=False)
if error:
@@ -383,10 +371,8 @@ class Runner:
self.trainer = Popen(["llamafactory-cli", "train", save_cmd(args)], env=env)
yield from self.monitor()

def _build_config_dict(self, data: Dict["Component", Any]) -> Dict[str, Any]:
r"""
Builds a dictionary containing the current training configuration.
"""
def _build_config_dict(self, data: dict["Component", Any]) -> dict[str, Any]:
r"""Build a dictionary containing the current training configuration."""
config_dict = {}
skip_ids = ["top.lang", "top.model_path", "train.output_dir", "train.config_path"]
for elem, value in data.items():
@@ -409,9 +395,7 @@ class Runner:
yield from self._launch(data, do_train=False)

def monitor(self):
r"""
Monitors the training progress and logs.
"""
r"""Monitor the training progress and logs."""
self.aborted = False
self.running = True

@@ -469,9 +453,7 @@ class Runner:
yield return_dict

def save_args(self, data):
r"""
Saves the training configuration to config path.
"""
r"""Save the training configuration to config path."""
output_box = self.manager.get_elem_by_id("train.output_box")
error = self._initialize(data, do_train=True, from_preview=True)
if error:
@@ -487,27 +469,23 @@ class Runner:
return {output_box: ALERTS["info_config_saved"][lang] + save_path}

def load_args(self, lang: str, config_path: str):
r"""
Loads the training configuration from config path.
"""
r"""Load the training configuration from config path."""
output_box = self.manager.get_elem_by_id("train.output_box")
config_dict = load_args(os.path.join(DEFAULT_CONFIG_DIR, config_path))
if config_dict is None:
gr.Warning(ALERTS["err_config_not_found"][lang])
return {output_box: ALERTS["err_config_not_found"][lang]}

output_dict: Dict["Component", Any] = {output_box: ALERTS["info_config_loaded"][lang]}
output_dict: dict[Component, Any] = {output_box: ALERTS["info_config_loaded"][lang]}
for elem_id, value in config_dict.items():
output_dict[self.manager.get_elem_by_id(elem_id)] = value

return output_dict

def check_output_dir(self, lang: str, model_name: str, finetuning_type: str, output_dir: str):
r"""
Restore the training status if output_dir exists.
"""
r"""Restore the training status if output_dir exists."""
output_box = self.manager.get_elem_by_id("train.output_box")
output_dict: Dict["Component", Any] = {output_box: LOCALES["output_box"][lang]["value"]}
output_dict: dict[Component, Any] = {output_box: LOCALES["output_box"][lang]["value"]}
if model_name and output_dir and os.path.isdir(get_save_dir(model_name, finetuning_type, output_dir)):
gr.Warning(ALERTS["warn_output_dir_exists"][lang])
output_dict[output_box] = ALERTS["warn_output_dir_exists"][lang]
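
For orientation, the _launch/monitor pair above boils down to: spawn llamafactory-cli with Popen, then poll the child while surfacing its log file. A heavily simplified sketch of that control flow (the real monitor also handles aborts and progress reporting; the helper below is invented):

import subprocess
import time
from collections.abc import Generator

def launch_and_monitor(cmd: list[str], log_path: str) -> Generator[str, None, None]:
    proc = subprocess.Popen(cmd)
    while proc.poll() is None:  # None means the child is still running
        try:
            with open(log_path, encoding="utf-8") as f:
                yield f.read()  # surface the latest trainer output
        except FileNotFoundError:
            yield "waiting for the trainer to write logs..."
        time.sleep(2)
    yield f"finished with exit code {proc.returncode}"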