modify style & little change
Former-commit-id: c988477d14dc656450d5fec31895781b7f9f7dce
This commit is contained in:
@@ -13,7 +13,7 @@
|
||||
# limitations under the License.
|
||||
|
||||
import os
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Sequence, Tuple, Union
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Sequence, Tuple
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
@@ -74,6 +74,10 @@ def _is_close(batch_a: Dict[str, Any], batch_b: Dict[str, Any]) -> None:
|
||||
for key in batch_a.keys():
|
||||
if isinstance(batch_a[key], torch.Tensor):
|
||||
assert torch.allclose(batch_a[key], batch_b[key], rtol=1e-4, atol=1e-5)
|
||||
elif isinstance(batch_a[key], list) and all(isinstance(item, torch.Tensor) for item in batch_a[key]):
|
||||
assert len(batch_a[key]) == len(batch_b[key])
|
||||
for tensor_a, tensor_b in zip(batch_a[key], batch_b[key]):
|
||||
assert torch.allclose(tensor_a, tensor_b, rtol=1e-4, atol=1e-5)
|
||||
else:
|
||||
assert batch_a[key] == batch_b[key]
|
||||
|
||||
@@ -185,13 +189,19 @@ def test_pixtral_plugin():
|
||||
image_slice_height, image_slice_width = 2, 2
|
||||
check_inputs = {"plugin": pixtral_plugin, "tokenizer": tokenizer, "processor": processor}
|
||||
check_inputs["expected_mm_messages"] = [
|
||||
{key: value.replace("<image>", ("{}[IMG_BREAK]".format("[IMG]" * image_slice_width) * image_slice_height).rsplit("[IMG_BREAK]", 1)[0] + "[IMG_END]")
|
||||
for key, value in message.items()} for message in MM_MESSAGES
|
||||
{
|
||||
key: value.replace(
|
||||
"<image>",
|
||||
("{}[IMG_BREAK]".format("[IMG]" * image_slice_width) * image_slice_height).rsplit("[IMG_BREAK]", 1)[0]
|
||||
+ "[IMG_END]",
|
||||
)
|
||||
for key, value in message.items()
|
||||
}
|
||||
for message in MM_MESSAGES
|
||||
]
|
||||
check_inputs["expected_mm_inputs"] = _get_mm_inputs(processor)
|
||||
# TODO works needed for pixtral plugin test & hack hf engine input below for now
|
||||
check_inputs["expected_mm_inputs"].pop("image_sizes")
|
||||
check_inputs["expected_mm_inputs"]["pixel_values"] = check_inputs["expected_mm_inputs"]["pixel_values"][0][0].unsqueeze(0)
|
||||
check_inputs["expected_mm_inputs"]["pixel_values"] = check_inputs["expected_mm_inputs"]["pixel_values"][0]
|
||||
_check_plugin(**check_inputs)
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user