Skip to content

Commit b0e0409

Browse files
authored
Bugfix/image encoding (#327)
* Fix image serialization
1 parent c5c61a3 commit b0e0409

File tree

2 files changed

+57
-17
lines changed

2 files changed

+57
-17
lines changed

ollama/_types.py

Lines changed: 18 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import json
2-
from base64 import b64encode
2+
from base64 import b64decode, b64encode
33
from pathlib import Path
44
from datetime import datetime
55
from typing import Any, Mapping, Optional, Union, Sequence
@@ -11,8 +11,6 @@
1111
ByteSize,
1212
ConfigDict,
1313
Field,
14-
FilePath,
15-
Base64Str,
1614
model_serializer,
1715
)
1816

@@ -89,16 +87,26 @@ class BaseGenerateRequest(BaseStreamableRequest):
8987

9088

9189
class Image(BaseModel):
92-
value: Union[FilePath, Base64Str, bytes]
90+
value: Union[str, bytes, Path]
9391

94-
# This overloads the `model_dump` method and returns values depending on the type of the `value` field
9592
@model_serializer
9693
def serialize_model(self):
97-
if isinstance(self.value, Path):
98-
return b64encode(self.value.read_bytes()).decode()
99-
elif isinstance(self.value, bytes):
100-
return b64encode(self.value).decode()
101-
return self.value
94+
if isinstance(self.value, (Path, bytes)):
95+
return b64encode(self.value.read_bytes() if isinstance(self.value, Path) else self.value).decode()
96+
97+
if isinstance(self.value, str):
98+
if Path(self.value).exists():
99+
return b64encode(Path(self.value).read_bytes()).decode()
100+
101+
if self.value.split('.')[-1] in ('png', 'jpg', 'jpeg', 'webp'):
102+
raise ValueError(f'File {self.value} does not exist')
103+
104+
try:
105+
# Try to decode to check if it's already base64
106+
b64decode(self.value)
107+
return self.value
108+
except Exception:
109+
raise ValueError('Invalid image data, expected base64 string or path to image file') from Exception
102110

103111

104112
class GenerateRequest(BaseGenerateRequest):

tests/test_type_serialization.py

Lines changed: 39 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,48 @@
1-
from base64 import b64decode, b64encode
2-
1+
from base64 import b64encode
2+
from pathlib import Path
33

4+
import pytest
45
from ollama._types import Image
6+
import tempfile
57

68

7-
def test_image_serialization():
8-
# Test bytes serialization
9+
def test_image_serialization_bytes():
910
image_bytes = b'test image bytes'
11+
encoded_string = b64encode(image_bytes).decode()
1012
img = Image(value=image_bytes)
11-
assert img.model_dump() == b64encode(image_bytes).decode()
13+
assert img.model_dump() == encoded_string
14+
1215

13-
# Test base64 string serialization
16+
def test_image_serialization_base64_string():
1417
b64_str = 'dGVzdCBiYXNlNjQgc3RyaW5n'
1518
img = Image(value=b64_str)
16-
assert img.model_dump() == b64decode(b64_str).decode()
19+
assert img.model_dump() == b64_str # Should return as-is if valid base64
20+
21+
22+
def test_image_serialization_plain_string():
23+
img = Image(value='not a path or base64')
24+
assert img.model_dump() == 'not a path or base64' # Should return as-is
25+
26+
27+
def test_image_serialization_path():
28+
with tempfile.NamedTemporaryFile() as temp_file:
29+
temp_file.write(b'test file content')
30+
temp_file.flush()
31+
img = Image(value=Path(temp_file.name))
32+
assert img.model_dump() == b64encode(b'test file content').decode()
33+
34+
35+
def test_image_serialization_string_path():
36+
with tempfile.NamedTemporaryFile() as temp_file:
37+
temp_file.write(b'test file content')
38+
temp_file.flush()
39+
img = Image(value=temp_file.name)
40+
assert img.model_dump() == b64encode(b'test file content').decode()
41+
42+
with pytest.raises(ValueError):
43+
img = Image(value='some_path/that/does/not/exist.png')
44+
img.model_dump()
45+
46+
with pytest.raises(ValueError):
47+
img = Image(value='not an image')
48+
img.model_dump()

0 commit comments

Comments
 (0)