Skip to content

Commit 0c53106

Browse files
committed
Fix passing inline images to vision models
- Fix regression: Inline images were not getting passed to the AI models since #992 - Format inline images passed to Gemini models correctly - Format inline images passed to Anthropic models correctly Verified vision working with inline and url images for OpenAI, Anthropic and Gemini models. Resolves #1112
1 parent 1ce1d2f commit 0c53106

File tree

4 files changed

+34
-6
lines changed

4 files changed

+34
-6
lines changed

src/khoj/processor/conversation/anthropic/utils.py

+6-1
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
from khoj.processor.conversation.utils import (
1616
ThreadedGenerator,
1717
commit_conversation_trace,
18+
get_image_from_base64,
1819
get_image_from_url,
1920
)
2021
from khoj.utils.helpers import (
@@ -232,7 +233,11 @@ def format_messages_for_anthropic(messages: list[ChatMessage], system_prompt=Non
232233
if part["type"] == "text":
233234
content.append({"type": "text", "text": part["text"]})
234235
elif part["type"] == "image_url":
235-
image = get_image_from_url(part["image_url"]["url"], type="b64")
236+
image_data = part["image_url"]["url"]
237+
if image_data.startswith("http"):
238+
image = get_image_from_url(image_data, type="b64")
239+
else:
240+
image = get_image_from_base64(image_data, type="b64")
236241
# Prefix each image with text block enumerating the image number
237242
# This helps the model reference the image in its response. Recommended by Anthropic
238243
content.extend(

src/khoj/processor/conversation/google/utils.py

+6-1
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
from khoj.processor.conversation.utils import (
1919
ThreadedGenerator,
2020
commit_conversation_trace,
21+
get_image_from_base64,
2122
get_image_from_url,
2223
)
2324
from khoj.utils.helpers import (
@@ -245,7 +246,11 @@ def format_messages_for_gemini(
245246
message_content = []
246247
for item in sorted(message.content, key=lambda x: 0 if x["type"] == "image_url" else 1):
247248
if item["type"] == "image_url":
248-
image = get_image_from_url(item["image_url"]["url"], type="bytes")
249+
image_data = item["image_url"]["url"]
250+
if image_data.startswith("http"):
251+
image = get_image_from_url(image_data, type="bytes")
252+
else:
253+
image = get_image_from_base64(image_data, type="bytes")
249254
message_content += [gtypes.Part.from_bytes(data=image.content, mime_type=image.type)]
250255
else:
251256
message_content += [gtypes.Part.from_text(text=item.get("text", ""))]

src/khoj/processor/conversation/utils.py

+18-2
Original file line numberDiff line numberDiff line change
@@ -345,8 +345,7 @@ def construct_structured_message(
345345
constructed_messages.append({"type": "text", "text": attached_file_context})
346346
if vision_enabled and images:
347347
for image in images:
348-
if image.startswith("https://"):
349-
constructed_messages.append({"type": "image_url", "image_url": {"url": image}})
348+
constructed_messages.append({"type": "image_url", "image_url": {"url": image}})
350349
return constructed_messages
351350

352351
if not is_none_or_empty(attached_file_context):
@@ -664,6 +663,23 @@ class ImageWithType:
664663
type: str
665664

666665

666+
def get_image_from_base64(image: str, type="b64"):
667+
# Extract image type and base64 data from inline image data
668+
image_base64 = image.split(",", 1)[1]
669+
image_type = image.split(";", 1)[0].split(":", 1)[1]
670+
671+
# Convert image to desired format
672+
if type == "b64":
673+
return ImageWithType(content=image_base64, type=image_type)
674+
elif type == "pil":
675+
image_data = base64.b64decode(image_base64)
676+
image_pil = PIL.Image.open(BytesIO(image_data))
677+
return ImageWithType(content=image_pil, type=image_type)
678+
elif type == "bytes":
679+
image_data = base64.b64decode(image_base64)
680+
return ImageWithType(content=image_data, type=image_type)
681+
682+
667683
def get_image_from_url(image_url: str, type="pil"):
668684
try:
669685
response = requests.get(image_url)

src/khoj/routers/api_chat.py

+4-2
Original file line numberDiff line numberDiff line change
@@ -675,8 +675,10 @@ async def event_generator(q: str, images: list[str]):
675675
image_bytes = base64.b64decode(base64_data)
676676
webp_image_bytes = convert_image_to_webp(image_bytes)
677677
uploaded_image = upload_user_image_to_bucket(webp_image_bytes, request.user.object.id)
678-
if uploaded_image:
679-
uploaded_images.append(uploaded_image)
678+
if not uploaded_image:
679+
base64_webp_image = base64.b64encode(webp_image_bytes).decode("utf-8")
680+
uploaded_image = f"data:image/webp;base64,{base64_webp_image}"
681+
uploaded_images.append(uploaded_image)
680682

681683
query_files: Dict[str, str] = {}
682684
if raw_query_files:

0 commit comments

Comments
 (0)