add docstrings, refactor logger
Former-commit-id: c34e489d71f8f539028543ccf8ee92cecedd6276
This commit is contained in:
@@ -16,6 +16,7 @@ from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Tuple, Union
|
||||
|
||||
from transformers.utils.versions import require_version
|
||||
from typing_extensions import override
|
||||
|
||||
from ..extras.logging import get_logger
|
||||
from .data_utils import Role
|
||||
@@ -152,6 +153,7 @@ class Template:
|
||||
|
||||
@dataclass
|
||||
class Llama2Template(Template):
|
||||
@override
|
||||
def _encode(
|
||||
self,
|
||||
tokenizer: "PreTrainedTokenizer",
|
||||
@@ -195,7 +197,7 @@ class Llama2Template(Template):
|
||||
return encoded_messages
|
||||
|
||||
|
||||
TEMPLATES: Dict[str, Template] = {}
|
||||
TEMPLATES: Dict[str, "Template"] = {}
|
||||
|
||||
|
||||
def _register_template(
|
||||
@@ -305,6 +307,9 @@ def _convert_slots_to_jinja(slots: "SLOTS", tokenizer: "PreTrainedTokenizer", pl
|
||||
|
||||
|
||||
def _get_jinja_template(template: "Template", tokenizer: "PreTrainedTokenizer") -> str:
|
||||
r"""
|
||||
Returns the jinja template.
|
||||
"""
|
||||
jinja_template = ""
|
||||
|
||||
prefix = _convert_slots_to_jinja(template.format_prefix.apply(), tokenizer)
|
||||
@@ -345,6 +350,9 @@ def _get_jinja_template(template: "Template", tokenizer: "PreTrainedTokenizer")
|
||||
|
||||
|
||||
def get_template_and_fix_tokenizer(tokenizer: "PreTrainedTokenizer", data_args: "DataArguments") -> "Template":
|
||||
r"""
|
||||
Gets chat template and fixes the tokenizer.
|
||||
"""
|
||||
if data_args.template in ["llava", "paligemma", "qwen2_vl"]:
|
||||
require_version(
|
||||
"transformers>=4.45.0.dev0", "To fix: pip install git+https://github.com/huggingface/transformers.git"
|
||||
|
||||
Reference in New Issue
Block a user