[inference] fix stop token for object detection (#6624)
* fix stop token * update minicpm data pipeline * fix npu qlora examples Former-commit-id: 844919fadaa8a61dfae47020971ea80730b2346f
This commit is contained in:
@@ -133,7 +133,7 @@ class HuggingfaceEngine(BaseEngine):
|
||||
if repetition_penalty is not None
|
||||
else generating_args["repetition_penalty"],
|
||||
length_penalty=length_penalty if length_penalty is not None else generating_args["length_penalty"],
|
||||
eos_token_id=[tokenizer.eos_token_id] + tokenizer.additional_special_tokens_ids,
|
||||
eos_token_id=template.get_stop_token_ids(tokenizer),
|
||||
pad_token_id=tokenizer.pad_token_id,
|
||||
)
|
||||
)
|
||||
|
||||
@@ -168,7 +168,7 @@ class VllmEngine(BaseEngine):
|
||||
top_p=(top_p if top_p is not None else self.generating_args["top_p"]) or 1.0, # top_p must > 0
|
||||
top_k=top_k if top_k is not None else self.generating_args["top_k"],
|
||||
stop=stop,
|
||||
stop_token_ids=[self.tokenizer.eos_token_id] + self.tokenizer.additional_special_tokens_ids,
|
||||
stop_token_ids=self.template.get_stop_token_ids(self.tokenizer),
|
||||
max_tokens=max_tokens,
|
||||
skip_special_tokens=self.generating_args["skip_special_tokens"],
|
||||
)
|
||||
|
||||
@@ -20,7 +20,6 @@ from typing import TYPE_CHECKING, Any, Dict, Literal, Optional, Sequence
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch.nn.utils.rnn import pad_sequence
|
||||
from transformers import DataCollatorForSeq2Seq
|
||||
|
||||
from ..extras.constants import IGNORE_INDEX, IMAGE_PLACEHOLDER
|
||||
@@ -154,11 +153,10 @@ class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
|
||||
features = features.data # use default_collate() instead of BatchEncoding.to()
|
||||
|
||||
if "image_bound" in features: # for minicpmv inputs
|
||||
features["position_ids"] = [torch.arange(input_ids.size(0)).long() for input_ids in features["input_ids"]]
|
||||
features["position_ids"] = pad_sequence(features["position_ids"], batch_first=True, padding_value=0)
|
||||
new_features = {"data": features}
|
||||
new_features.update({"labels": features["labels"]})
|
||||
features = new_features
|
||||
features["position_ids"] = (
|
||||
torch.arange(features["input_ids"].size(1)).long().unsqueeze(0).expand_as(features["input_ids"])
|
||||
)
|
||||
return {"data": features, "labels": features["labels"]}
|
||||
|
||||
return features
|
||||
|
||||
|
||||
@@ -269,9 +269,10 @@ class CpmVPlugin(BasePlugin):
|
||||
messages = deepcopy(messages)
|
||||
image_processor: "BaseImageProcessor" = getattr(processor, "image_processor")
|
||||
mm_inputs = {}
|
||||
if len(images) != 0 and len(videos) != 0:
|
||||
raise ValueError("MiniCPM-V model does not support input images and videos at the same time.")
|
||||
|
||||
if len(videos) != 0:
|
||||
assert len(images) == 0, "Only support video and image sft seperately"
|
||||
max_slice_nums = 2
|
||||
use_image_id = False
|
||||
mm_inputs = self._get_mm_inputs([], videos, processor)
|
||||
@@ -286,10 +287,9 @@ class CpmVPlugin(BasePlugin):
|
||||
content = content.replace(IMAGE_PLACEHOLDER, "{{image}}", 1)
|
||||
|
||||
while VIDEO_PLACEHOLDER in content:
|
||||
video_seqlen = len(mm_inputs["pixel_values"][num_video_tokens]) if self.expand_mm_tokens else 1
|
||||
content = content.replace(VIDEO_PLACEHOLDER, "{{image}}" * video_seqlen, 1)
|
||||
num_video_tokens += 1
|
||||
content = content.replace(
|
||||
VIDEO_PLACEHOLDER, "{{image}}" * len(mm_inputs["pixel_values"][num_video_tokens - 1]), 1
|
||||
)
|
||||
|
||||
message["content"] = content.replace("{{image}}", "(<image>./</image>)")
|
||||
|
||||
@@ -310,10 +310,7 @@ class CpmVPlugin(BasePlugin):
|
||||
final_text
|
||||
+ text_chunks[i]
|
||||
+ image_processor.get_slice_image_placeholder(
|
||||
image_sizes[0][i],
|
||||
i,
|
||||
max_slice_nums,
|
||||
use_image_id,
|
||||
image_sizes[0][i], i, max_slice_nums, use_image_id
|
||||
)
|
||||
)
|
||||
final_text += text_chunks[-1]
|
||||
@@ -338,7 +335,6 @@ class CpmVPlugin(BasePlugin):
|
||||
image_processor: "BaseImageProcessor" = getattr(processor, "image_processor")
|
||||
|
||||
mm_inputs = {}
|
||||
|
||||
if len(images) != 0:
|
||||
images = self._regularize_images(
|
||||
images,
|
||||
@@ -351,6 +347,7 @@ class CpmVPlugin(BasePlugin):
|
||||
for valid_image_nums in valid_image_nums_ls:
|
||||
new_images.append(images[idx : idx + valid_image_nums])
|
||||
idx += valid_image_nums
|
||||
|
||||
images = new_images
|
||||
|
||||
image_inputs = image_processor(
|
||||
@@ -383,7 +380,6 @@ class CpmVPlugin(BasePlugin):
|
||||
self._validate_input(images, videos)
|
||||
image_bounds_list = []
|
||||
valid_image_nums_ls = []
|
||||
|
||||
for input_ids in batch_ids:
|
||||
input_ids_ = torch.tensor(input_ids)
|
||||
start_cond = (input_ids_ == processor.tokenizer.im_start_id) | (
|
||||
@@ -424,8 +420,8 @@ class LlavaPlugin(BasePlugin):
|
||||
for message in messages:
|
||||
content = message["content"]
|
||||
while IMAGE_PLACEHOLDER in content:
|
||||
num_image_tokens += 1
|
||||
content = content.replace(IMAGE_PLACEHOLDER, "{{image}}" * image_seqlen, 1)
|
||||
num_image_tokens += 1
|
||||
|
||||
message["content"] = content.replace("{{image}}", self.image_token)
|
||||
|
||||
@@ -478,8 +474,8 @@ class LlavaNextPlugin(BasePlugin):
|
||||
else:
|
||||
image_seqlen = 1
|
||||
|
||||
num_image_tokens += 1
|
||||
content = content.replace(IMAGE_PLACEHOLDER, "{{image}}" * image_seqlen, 1)
|
||||
num_image_tokens += 1
|
||||
|
||||
message["content"] = content.replace("{{image}}", self.image_token)
|
||||
|
||||
@@ -529,8 +525,8 @@ class LlavaNextVideoPlugin(BasePlugin):
|
||||
else:
|
||||
image_seqlen = 1
|
||||
|
||||
num_image_tokens += 1
|
||||
content = content.replace(IMAGE_PLACEHOLDER, "{{image}}" * image_seqlen, 1)
|
||||
num_image_tokens += 1
|
||||
|
||||
message["content"] = content.replace("{{image}}", self.image_token)
|
||||
|
||||
@@ -586,8 +582,8 @@ class PaliGemmaPlugin(BasePlugin):
|
||||
for message in messages:
|
||||
content = message["content"]
|
||||
while IMAGE_PLACEHOLDER in content:
|
||||
num_image_tokens += 1
|
||||
content = content.replace(IMAGE_PLACEHOLDER, "{{image}}", 1)
|
||||
num_image_tokens += 1
|
||||
|
||||
message["content"] = content.replace("{{image}}", "")
|
||||
|
||||
@@ -840,12 +836,12 @@ class VideoLlavaPlugin(BasePlugin):
|
||||
for message in messages:
|
||||
content = message["content"]
|
||||
while IMAGE_PLACEHOLDER in content:
|
||||
num_image_tokens += 1
|
||||
content = content.replace(IMAGE_PLACEHOLDER, "{{image}}" * image_seqlen, 1)
|
||||
num_image_tokens += 1
|
||||
|
||||
while VIDEO_PLACEHOLDER in content:
|
||||
num_video_tokens += 1
|
||||
content = content.replace(VIDEO_PLACEHOLDER, "{{video}}" * video_seqlen, 1)
|
||||
num_video_tokens += 1
|
||||
|
||||
content = content.replace("{{image}}", self.image_token)
|
||||
message["content"] = content.replace("{{video}}", self.video_token)
|
||||
|
||||
@@ -89,6 +89,16 @@ class Template:
|
||||
"""
|
||||
return self.format_tools.extract(content)
|
||||
|
||||
def get_stop_token_ids(self, tokenizer: "PreTrainedTokenizer") -> List[int]:
|
||||
r"""
|
||||
Returns stop token ids.
|
||||
"""
|
||||
stop_token_ids = {tokenizer.eos_token_id}
|
||||
for token in self.stop_words:
|
||||
stop_token_ids.add(tokenizer.convert_tokens_to_ids(token))
|
||||
|
||||
return list(stop_token_ids)
|
||||
|
||||
def _encode(
|
||||
self,
|
||||
tokenizer: "PreTrainedTokenizer",
|
||||
@@ -205,7 +215,7 @@ def _register_template(
|
||||
format_tools: Optional["Formatter"] = None,
|
||||
format_prefix: Optional["Formatter"] = None,
|
||||
default_system: str = "",
|
||||
stop_words: Sequence[str] = [],
|
||||
stop_words: Optional[Sequence[str]] = None,
|
||||
efficient_eos: bool = False,
|
||||
replace_eos: bool = False,
|
||||
replace_jinja_template: bool = False,
|
||||
@@ -248,7 +258,7 @@ def _register_template(
|
||||
format_tools=format_tools or default_tool_formatter,
|
||||
format_prefix=format_prefix or default_prefix_formatter,
|
||||
default_system=default_system,
|
||||
stop_words=stop_words,
|
||||
stop_words=stop_words or [],
|
||||
efficient_eos=efficient_eos,
|
||||
replace_eos=replace_eos,
|
||||
replace_jinja_template=replace_jinja_template,
|
||||
@@ -566,6 +576,7 @@ _register_template(
|
||||
)
|
||||
|
||||
|
||||
# copied from chatml template
|
||||
_register_template(
|
||||
name="cpm_v",
|
||||
format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
|
||||
|
||||
@@ -79,6 +79,8 @@ class CustomTrainer(Trainer):
|
||||
) -> Union["torch.Tensor", Tuple["torch.Tensor", List["torch.Tensor"]]]:
|
||||
r"""
|
||||
Fixes the loss value. See https://github.com/huggingface/transformers/pull/35438 for details.
|
||||
|
||||
It should be removed after https://github.com/huggingface/transformers/pull/35651 is merged.
|
||||
"""
|
||||
loss = super().compute_loss(model, inputs, return_outputs, **kwargs)
|
||||
if kwargs.get("num_items_in_batch") and not getattr(self, "model_accepts_loss_kwargs", False):
|
||||
|
||||
@@ -94,6 +94,8 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer):
|
||||
) -> Union["torch.Tensor", Tuple["torch.Tensor", List["torch.Tensor"]]]:
|
||||
r"""
|
||||
Fixes the loss value. See https://github.com/huggingface/transformers/pull/35438 for details.
|
||||
|
||||
It should be removed after https://github.com/huggingface/transformers/pull/35651 is merged.
|
||||
"""
|
||||
loss = super().compute_loss(model, inputs, return_outputs, **kwargs)
|
||||
if kwargs.get("num_items_in_batch") and not getattr(self, "model_accepts_loss_kwargs", False):
|
||||
|
||||
@@ -19,6 +19,7 @@ from subprocess import Popen, TimeoutExpired
|
||||
from typing import TYPE_CHECKING, Any, Dict, Generator, Optional
|
||||
|
||||
from transformers.trainer import TRAINING_ARGS_NAME
|
||||
from transformers.utils import is_torch_npu_available
|
||||
|
||||
from ..extras.constants import LLAMABOARD_CONFIG, PEFT_METHODS, TRAINING_STAGES
|
||||
from ..extras.misc import is_gpu_or_npu_available, torch_gc, use_ray
|
||||
@@ -172,6 +173,7 @@ class Runner:
|
||||
if get("top.quantization_bit") in QUANTIZATION_BITS:
|
||||
args["quantization_bit"] = int(get("top.quantization_bit"))
|
||||
args["quantization_method"] = get("top.quantization_method")
|
||||
args["double_quantization"] = not is_torch_npu_available()
|
||||
|
||||
# freeze config
|
||||
if args["finetuning_type"] == "freeze":
|
||||
|
||||
Reference in New Issue
Block a user