add docstrings, refactor logger

Former-commit-id: c34e489d71f8f539028543ccf8ee92cecedd6276
This commit is contained in:
hiyouga
2024-09-08 00:56:56 +08:00
parent 93d4570a59
commit 7f71276ad8
30 changed files with 334 additions and 57 deletions

View File

@@ -16,6 +16,7 @@ from dataclasses import dataclass
from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Tuple, Union
from transformers.utils.versions import require_version
from typing_extensions import override
from ..extras.logging import get_logger
from .data_utils import Role
@@ -152,6 +153,7 @@ class Template:
@dataclass
class Llama2Template(Template):
@override
def _encode(
self,
tokenizer: "PreTrainedTokenizer",
@@ -195,7 +197,7 @@ class Llama2Template(Template):
return encoded_messages
-TEMPLATES: Dict[str, Template] = {}
+TEMPLATES: Dict[str, "Template"] = {}
def _register_template(
@@ -305,6 +307,9 @@ def _convert_slots_to_jinja(slots: "SLOTS", tokenizer: "PreTrainedTokenizer", pl
def _get_jinja_template(template: "Template", tokenizer: "PreTrainedTokenizer") -> str:
r"""
Returns the jinja template.
"""
jinja_template = ""
prefix = _convert_slots_to_jinja(template.format_prefix.apply(), tokenizer)
@@ -345,6 +350,9 @@ def _get_jinja_template(template: "Template", tokenizer: "PreTrainedTokenizer")
def get_template_and_fix_tokenizer(tokenizer: "PreTrainedTokenizer", data_args: "DataArguments") -> "Template":
r"""
Gets chat template and fixes the tokenizer.
"""
if data_args.template in ["llava", "paligemma", "qwen2_vl"]:
require_version(
"transformers>=4.45.0.dev0", "To fix: pip install git+https://github.com/huggingface/transformers.git"