modify style & little change

Former-commit-id: c988477d14dc656450d5fec31895781b7f9f7dce
This commit is contained in:
KUANGDD
2024-10-23 15:24:07 +08:00
parent 7d135bbdb8
commit d0889012c2
7 changed files with 45 additions and 25 deletions

View File

@@ -13,7 +13,7 @@
# limitations under the License.
import os
from typing import TYPE_CHECKING, Any, Dict, List, Sequence, Tuple, Union
from typing import TYPE_CHECKING, Any, Dict, List, Sequence, Tuple
import pytest
import torch
@@ -74,6 +74,10 @@ def _is_close(batch_a: Dict[str, Any], batch_b: Dict[str, Any]) -> None:
for key in batch_a.keys():
if isinstance(batch_a[key], torch.Tensor):
assert torch.allclose(batch_a[key], batch_b[key], rtol=1e-4, atol=1e-5)
elif isinstance(batch_a[key], list) and all(isinstance(item, torch.Tensor) for item in batch_a[key]):
assert len(batch_a[key]) == len(batch_b[key])
for tensor_a, tensor_b in zip(batch_a[key], batch_b[key]):
assert torch.allclose(tensor_a, tensor_b, rtol=1e-4, atol=1e-5)
else:
assert batch_a[key] == batch_b[key]
@@ -185,13 +189,19 @@ def test_pixtral_plugin():
image_slice_height, image_slice_width = 2, 2
check_inputs = {"plugin": pixtral_plugin, "tokenizer": tokenizer, "processor": processor}
check_inputs["expected_mm_messages"] = [
{key: value.replace("<image>", ("{}[IMG_BREAK]".format("[IMG]" * image_slice_width) * image_slice_height).rsplit("[IMG_BREAK]", 1)[0] + "[IMG_END]")
for key, value in message.items()} for message in MM_MESSAGES
{
key: value.replace(
"<image>",
("{}[IMG_BREAK]".format("[IMG]" * image_slice_width) * image_slice_height).rsplit("[IMG_BREAK]", 1)[0]
+ "[IMG_END]",
)
for key, value in message.items()
}
for message in MM_MESSAGES
]
check_inputs["expected_mm_inputs"] = _get_mm_inputs(processor)
# TODO works needed for pixtral plugin test & hack hf engine input below for now
check_inputs["expected_mm_inputs"].pop("image_sizes")
check_inputs["expected_mm_inputs"]["pixel_values"] = check_inputs["expected_mm_inputs"]["pixel_values"][0][0].unsqueeze(0)
check_inputs["expected_mm_inputs"]["pixel_values"] = check_inputs["expected_mm_inputs"]["pixel_values"][0]
_check_plugin(**check_inputs)