import torch

from vllm.multimodal.base import MultiModalInputs, NestedTensors

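# Recursively check that two NestedTensors values (a tensor, or a nested
# list of tensors) have the same structure and identical elements.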
def assert_nested_tensors_equal(expected: NestedTensors,
                                actual: NestedTensors):
    assert type(expected) == type(actual)
    if isinstance(expected, torch.Tensor):
        assert torch.equal(expected, actual)
    else:
        for expected_item, actual_item in zip(expected, actual):
            assert_nested_tensors_equal(expected_item, actual_item)

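# Check that two batched MultiModalInputs mappings share the same keys and
# hold equal NestedTensors for each key.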
def assert_multimodal_inputs_equal(expected: MultiModalInputs,
                                   actual: MultiModalInputs):
    assert set(expected.keys()) == set(actual.keys())
    for key in expected:
        assert_nested_tensors_equal(expected[key], actual[key])

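# Batching a single item should add a leading batch dimension to its tensor.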
def test_multimodal_input_batch_single_tensor():
    t = torch.rand([1, 2])
    result = MultiModalInputs.batch([{"image": t}])
    assert_multimodal_inputs_equal(result, {"image": t.unsqueeze(0)})

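# Tensors with identical shapes should be stacked into a single batched tensor.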
def test_multimodal_input_batch_multiple_tensors():
    a = torch.rand([1, 1, 2])
    b = torch.rand([1, 1, 2])
    c = torch.rand([1, 1, 2])
    result = MultiModalInputs.batch([{"image": a}, {"image": b}, {"image": c}])
    assert_multimodal_inputs_equal(result, {"image": torch.stack([a, b, c])})

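# Tensors with mismatched shapes cannot be stacked, so batching should fall
# back to returning a plain list of the original tensors.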
def test_multimodal_input_batch_multiple_heterogeneous_tensors():
    a = torch.rand([1, 2, 2])
    b = torch.rand([1, 3, 2])
    c = torch.rand([1, 4, 2])
    result = MultiModalInputs.batch([{"image": a}, {"image": b}, {"image": c}])
    assert_multimodal_inputs_equal(result, {"image": [a, b, c]})

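# Per-item lists of same-shape tensors are stacked within each item and then
# stacked again across the batch, yielding one tensor with an extra list dim.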
def test_multimodal_input_batch_nested_tensors():
    a = torch.rand([2, 3])
    b = torch.rand([2, 3])
    c = torch.rand([2, 3])
    result = MultiModalInputs.batch([{
        "image": [a]
    }, {
        "image": [b]
    }, {
        "image": [c]
    }])
    assert_multimodal_inputs_equal(result, {
        "image":
        torch.stack([a.unsqueeze(0),
                     b.unsqueeze(0),
                     c.unsqueeze(0)])
    })

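# Per-item lists of different lengths cannot be merged into one tensor, so
# the result stays a list with one (stacked) entry per input item.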
def test_multimodal_input_batch_heterogeneous_lists():
    a = torch.rand([1, 2, 3])
    b = torch.rand([1, 2, 3])
    c = torch.rand([1, 2, 3])
    result = MultiModalInputs.batch([{"image": [a, b]}, {"image": [c]}])
    assert_multimodal_inputs_equal(
        result,
        {"image": [torch.stack([a, b]), c.unsqueeze(0)]})

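# Per-item lists of equal length and shape can be batched all the way down
# into a single stacked tensor.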
def test_multimodal_input_batch_multiple_batchable_lists():
    a = torch.rand([1, 2, 3])
    b = torch.rand([1, 2, 3])
    c = torch.rand([1, 2, 3])
    d = torch.rand([1, 2, 3])
    result = MultiModalInputs.batch([{"image": [a, b]}, {"image": [c, d]}])
    assert_multimodal_inputs_equal(
        result,
        {"image": torch.stack([torch.stack([a, b]),
                               torch.stack([c, d])])})