update readme

Former-commit-id: b8d0170fe0d094acce85dcb5f91775e4685ee055
This commit is contained in:
hiyouga
2024-05-27 18:14:02 +08:00
parent b0d9966663
commit 97a23e1cbe
10 changed files with 71 additions and 62 deletions

View File

@@ -1,6 +1,6 @@
from typing import TYPE_CHECKING, Any, Dict, List, Optional
from ...extras.constants import IGNORE_INDEX, IMAGE_TOKEN
from ...extras.constants import IGNORE_INDEX
from ...extras.logging import get_logger
from .mm_utils import get_paligemma_token_type_ids, get_pixel_values
@@ -46,7 +46,7 @@ def preprocess_feedback_dataset(
continue
if processor is not None and not hasattr(processor, "image_seq_length"): # llava-like models
examples["prompt"][i][0]["content"] = IMAGE_TOKEN + examples["prompt"][i][0]["content"]
examples["prompt"][i][0]["content"] = template.image_token + examples["prompt"][i][0]["content"]
if examples["response"][i][0]["content"]: # desired example
kto_tag = True
@@ -82,7 +82,7 @@ def preprocess_feedback_dataset(
kl_response_ids += [tokenizer.eos_token_id]
if processor is not None and hasattr(processor, "image_seq_length"): # paligemma models
image_token_id = tokenizer.convert_tokens_to_ids(IMAGE_TOKEN)
image_token_id = tokenizer.convert_tokens_to_ids(template.image_token)
prompt_ids = [image_token_id] * getattr(processor, "image_seq_length") + prompt_ids
input_ids = prompt_ids + response_ids

View File

@@ -1,6 +1,6 @@
from typing import TYPE_CHECKING, Any, Dict, List, Optional
from ...extras.constants import IGNORE_INDEX, IMAGE_TOKEN
from ...extras.constants import IGNORE_INDEX
from ...extras.logging import get_logger
from .mm_utils import get_paligemma_token_type_ids, get_pixel_values
@@ -44,7 +44,7 @@ def preprocess_pairwise_dataset(
continue
if processor is not None and not hasattr(processor, "image_seq_length"): # llava-like models
examples["prompt"][i][0]["content"] = IMAGE_TOKEN + examples["prompt"][i][0]["content"]
examples["prompt"][i][0]["content"] = template.image_token + examples["prompt"][i][0]["content"]
chosen_messages = examples["prompt"][i] + [examples["response"][i][0]]
rejected_messages = examples["prompt"][i] + [examples["response"][i][1]]
@@ -70,7 +70,7 @@ def preprocess_pairwise_dataset(
rejected_ids += [tokenizer.eos_token_id]
if processor is not None and hasattr(processor, "image_seq_length"): # paligemma models
image_token_id = tokenizer.convert_tokens_to_ids(IMAGE_TOKEN)
image_token_id = tokenizer.convert_tokens_to_ids(template.image_token)
prompt_ids = [image_token_id] * getattr(processor, "image_seq_length") + prompt_ids
chosen_input_ids = prompt_ids + chosen_ids

View File

@@ -1,6 +1,6 @@
from typing import TYPE_CHECKING, Any, Dict, List, Optional
from ...extras.constants import IGNORE_INDEX, IMAGE_TOKEN
from ...extras.constants import IGNORE_INDEX
from ...extras.logging import get_logger
from .mm_utils import get_paligemma_token_type_ids, get_pixel_values
@@ -37,13 +37,13 @@ def preprocess_supervised_dataset(
continue
if processor is not None and not hasattr(processor, "image_seq_length"): # llava-like models
examples["prompt"][i][0]["content"] = IMAGE_TOKEN + examples["prompt"][i][0]["content"]
examples["prompt"][i][0]["content"] = template.image_token + examples["prompt"][i][0]["content"]
messages = examples["prompt"][i] + examples["response"][i]
input_ids, labels = [], []
if processor is not None and hasattr(processor, "image_seq_length"): # paligemma models
image_token_id = tokenizer.convert_tokens_to_ids(IMAGE_TOKEN)
image_token_id = tokenizer.convert_tokens_to_ids(template.image_token)
input_ids += [image_token_id] * getattr(processor, "image_seq_length")
labels += [IGNORE_INDEX] * getattr(processor, "image_seq_length")

View File

@@ -1,6 +1,5 @@
from typing import TYPE_CHECKING, Any, Dict, List, Optional
from ...extras.constants import IMAGE_TOKEN
from ...extras.logging import get_logger
from ..utils import Role
from .mm_utils import get_paligemma_token_type_ids, get_pixel_values
@@ -37,7 +36,7 @@ def preprocess_unsupervised_dataset(
continue
if processor is not None and not hasattr(processor, "image_seq_length"): # llava-like models
examples["prompt"][i][0]["content"] = IMAGE_TOKEN + examples["prompt"][i][0]["content"]
examples["prompt"][i][0]["content"] = template.image_token + examples["prompt"][i][0]["content"]
if len(examples["response"][i]) == 1:
messages = examples["prompt"][i] + examples["response"][i]
@@ -57,7 +56,7 @@ def preprocess_unsupervised_dataset(
labels += [tokenizer.eos_token_id]
if processor is not None and hasattr(processor, "image_seq_length"): # paligemma models
image_token_id = tokenizer.convert_tokens_to_ids(IMAGE_TOKEN)
image_token_id = tokenizer.convert_tokens_to_ids(template.image_token)
input_ids = [image_token_id] * getattr(processor, "image_seq_length") + input_ids
model_inputs["input_ids"].append(input_ids)

View File

@@ -26,6 +26,7 @@ class Template:
format_separator: "Formatter"
default_system: str
stop_words: List[str]
image_token: str
efficient_eos: bool
replace_eos: bool
force_system: bool
@@ -209,6 +210,7 @@ def _register_template(
format_separator: Optional["Formatter"] = None,
default_system: str = "",
stop_words: List[str] = [],
image_token: str = "<image>",
efficient_eos: bool = False,
replace_eos: bool = False,
force_system: bool = False,
@@ -256,6 +258,7 @@ def _register_template(
format_separator=format_separator or default_separator_formatter,
default_system=default_system,
stop_words=stop_words,
image_token=image_token,
efficient_eos=efficient_eos,
replace_eos=replace_eos,
force_system=force_system,
@@ -730,7 +733,7 @@ _register_template(
_register_template(
name="mistral",
format_user=StringFormatter(slots=[" [INST] {{content}} [/INST]"]),
format_user=StringFormatter(slots=["[INST] {{content}} [/INST]"]),
format_system=StringFormatter(slots=[{"bos_token"}, "{{content}}"]),
force_system=True,
)
@@ -738,7 +741,7 @@ _register_template(
_register_template(
name="olmo",
format_user=StringFormatter(slots=["<|user|>\n{{content}}<|assistant|>"]),
format_user=StringFormatter(slots=["<|user|>\n{{content}}<|assistant|>\n"]),
format_assistant=StringFormatter(slots=["{{content}}", {"eos_token"}]),
format_system=StringFormatter(slots=[{"eos_token"}, "{{content}}"]),
force_system=True,
@@ -766,7 +769,6 @@ _register_template(
name="phi",
format_user=StringFormatter(slots=["<|user|>\n{{content}}<|end|>\n<|assistant|>\n"]),
format_system=StringFormatter(slots=[{"bos_token"}, "<|system|>\n{{content}}<|end|>\n"]),
format_observation=StringFormatter(slots=["<|function_output|>\n{{content}}<|end|>\n<|assistant|>\n"]),
format_separator=EmptyFormatter(slots=["\n"]),
default_system="You are a helpful AI assistant.",
stop_words=["<|end|>"],