Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 18 additions & 10 deletions ollama/_types.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import json
from base64 import b64encode
from base64 import b64decode, b64encode
from pathlib import Path
from datetime import datetime
from typing import Any, Mapping, Optional, Union, Sequence
Expand All @@ -11,8 +11,6 @@
ByteSize,
ConfigDict,
Field,
FilePath,
Base64Str,
model_serializer,
)

Expand Down Expand Up @@ -89,16 +87,26 @@ class BaseGenerateRequest(BaseStreamableRequest):


class Image(BaseModel):
value: Union[FilePath, Base64Str, bytes]
value: Union[str, bytes, Path]

# This overloads the `model_dump` method and returns values depending on the type of the `value` field
@model_serializer
def serialize_model(self):
if isinstance(self.value, Path):
return b64encode(self.value.read_bytes()).decode()
elif isinstance(self.value, bytes):
return b64encode(self.value).decode()
return self.value
if isinstance(self.value, (Path, bytes)):
return b64encode(self.value.read_bytes() if isinstance(self.value, Path) else self.value).decode()

if isinstance(self.value, str):
if Path(self.value).exists():
return b64encode(Path(self.value).read_bytes()).decode()

if self.value.split('.')[-1] in ('png', 'jpg', 'jpeg', 'webp'):
raise ValueError(f'File {self.value} does not exist')
Comment on lines +101 to +102
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is to handle the edge case where a path does not exist but could be a valid b64 string - would lead to weird behavior


try:
# Try to decode to check if it's already base64
b64decode(self.value)
return self.value
except Exception:
raise ValueError('Invalid image data, expected base64 string or path to image file') from Exception


class GenerateRequest(BaseGenerateRequest):
Expand Down
46 changes: 39 additions & 7 deletions tests/test_type_serialization.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,48 @@
from base64 import b64decode, b64encode

from base64 import b64encode
from pathlib import Path

import pytest
from ollama._types import Image
import tempfile


def test_image_serialization():
# Test bytes serialization
def test_image_serialization_bytes():
image_bytes = b'test image bytes'
encoded_string = b64encode(image_bytes).decode()
img = Image(value=image_bytes)
assert img.model_dump() == b64encode(image_bytes).decode()
assert img.model_dump() == encoded_string


# Test base64 string serialization
def test_image_serialization_base64_string():
b64_str = 'dGVzdCBiYXNlNjQgc3RyaW5n'
img = Image(value=b64_str)
assert img.model_dump() == b64decode(b64_str).decode()
assert img.model_dump() == b64_str # Should return as-is if valid base64


def test_image_serialization_plain_string():
img = Image(value='not a path or base64')
assert img.model_dump() == 'not a path or base64' # Should return as-is
Comment on lines +23 to +24
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Currently just a passthrough. Since spaces are valid b64 I'm hesitant to split on it and check the length. although we could do a heuristic to catch it greedily



def test_image_serialization_path():
with tempfile.NamedTemporaryFile() as temp_file:
temp_file.write(b'test file content')
temp_file.flush()
img = Image(value=Path(temp_file.name))
assert img.model_dump() == b64encode(b'test file content').decode()


def test_image_serialization_string_path():
with tempfile.NamedTemporaryFile() as temp_file:
temp_file.write(b'test file content')
temp_file.flush()
img = Image(value=temp_file.name)
assert img.model_dump() == b64encode(b'test file content').decode()

with pytest.raises(ValueError):
img = Image(value='some_path/that/does/not/exist.png')
img.model_dump()

with pytest.raises(ValueError):
img = Image(value='not an image')
img.model_dump()