35 changes: 8 additions & 27 deletions llama_stack/providers/remote/inference/databricks/databricks.py
@@ -4,16 +4,15 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from collections.abc import Iterable
from typing import Any

from databricks.sdk import WorkspaceClient

from llama_stack.apis.inference import (
Inference,
Model,
OpenAICompletion,
)
from llama_stack.apis.models import ModelType
from llama_stack.log import get_logger
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin

@@ -72,31 +71,13 @@ async def openai_completion(
) -> OpenAICompletion:
raise NotImplementedError()

async def list_models(self) -> list[Model] | None:
self._model_cache = {} # from OpenAIMixin
ws_client = WorkspaceClient(host=self.config.url, token=self.get_api_key()) # TODO: this is not async
endpoints = ws_client.serving_endpoints.list()
for endpoint in endpoints:
model = Model(
provider_id=self.__provider_id__,
provider_resource_id=endpoint.name,
identifier=endpoint.name,
)
if endpoint.task == "llm/v1/chat":
model.model_type = ModelType.llm # this is redundant, but informative
elif endpoint.task == "llm/v1/embeddings":
if endpoint.name not in self.embedding_model_metadata:
logger.warning(f"No metadata information available for embedding model {endpoint.name}, skipping.")
continue
model.model_type = ModelType.embedding
model.metadata = self.embedding_model_metadata[endpoint.name]
else:
logger.warning(f"Unknown model type, skipping: {endpoint}")
continue

self._model_cache[endpoint.name] = model

return list(self._model_cache.values())
async def list_provider_model_ids(self) -> Iterable[str]:
return [
endpoint.name
for endpoint in WorkspaceClient(
host=self.config.url, token=self.get_api_key()
).serving_endpoints.list() # TODO: this is not async
]

async def should_refresh_models(self) -> bool:
return False
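
The two "TODO: this is not async" comments flag that the Databricks SDK call blocks the event loop from inside a coroutine. A minimal sketch of one way to resolve that TODO, assuming the class shape above (this is not part of the PR), offloads the synchronous call to a worker thread with asyncio.to_thread:

import asyncio

async def list_provider_model_ids(self) -> Iterable[str]:
    def _list_endpoint_names() -> list[str]:
        # Synchronous Databricks SDK call, executed off the event loop.
        client = WorkspaceClient(host=self.config.url, token=self.get_api_key())
        return [endpoint.name for endpoint in client.serving_endpoints.list()]

    # asyncio.to_thread runs the blocking call in a worker thread so other
    # coroutines keep making progress while the SDK talks to Databricks.
    return await asyncio.to_thread(_list_endpoint_names)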
44 changes: 32 additions & 12 deletions llama_stack/providers/utils/inference/openai_mixin.py
@@ -7,7 +7,7 @@
import base64
import uuid
from abc import ABC, abstractmethod
from collections.abc import AsyncIterator
from collections.abc import AsyncIterator, Iterable
from typing import Any

from openai import NOT_GIVEN, AsyncOpenAI
@@ -111,6 +111,18 @@ def get_extra_client_params(self) -> dict[str, Any]:
"""
return {}

async def list_provider_model_ids(self) -> Iterable[str]:
"""
List available models from the provider.

Child classes can override this method to provide a custom implementation
for listing models. The default implementation uses the AsyncOpenAI client
to list models from the OpenAI-compatible endpoint.

:return: An iterable of provider model IDs
"""
return [m.id async for m in self.client.models.list()]

@property
def client(self) -> AsyncOpenAI:
"""
@@ -387,28 +399,36 @@ async def list_models(self) -> list[Model] | None:
"""
self._model_cache = {}

async for m in self.client.models.list():
if self.allowed_models and m.id not in self.allowed_models:
logger.info(f"Skipping model {m.id} as it is not in the allowed models list")
# give subclasses a chance to provide custom model listing
iterable = await self.list_provider_model_ids()
if not hasattr(iterable, "__iter__"):
raise TypeError(
f"Failed to list models: {self.__class__.__name__}.list_provider_model_ids() must return an iterable of "
f"strings, but returned {type(iterable).__name__}"
)
provider_model_ids = list(iterable)
logger.info(f"{self.__class__.__name__}.list_provider_model_ids() returned {len(provider_model_ids)} models")

for provider_model_id in provider_model_ids:
if self.allowed_models and provider_model_id not in self.allowed_models:
logger.info(f"Skipping model {provider_model_id} as it is not in the allowed models list")
continue
if metadata := self.embedding_model_metadata.get(m.id):
# This is an embedding model - augment with metadata
if metadata := self.embedding_model_metadata.get(provider_model_id):
model = Model(
provider_id=self.__provider_id__, # type: ignore[attr-defined]
provider_resource_id=m.id,
identifier=m.id,
provider_resource_id=provider_model_id,
identifier=provider_model_id,
model_type=ModelType.embedding,
metadata=metadata,
)
else:
# This is an LLM
model = Model(
provider_id=self.__provider_id__, # type: ignore[attr-defined]
provider_resource_id=m.id,
identifier=m.id,
provider_resource_id=provider_model_id,
identifier=provider_model_id,
model_type=ModelType.llm,
)
self._model_cache[m.id] = model
self._model_cache[provider_model_id] = model

return list(self._model_cache.values())
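
The hasattr check above makes a misbehaving override fail fast with a descriptive TypeError instead of crashing later in the loop. An illustrative sketch (hypothetical adapter, mirroring the unit tests below):

class BrokenAdapter(OpenAIMixin):  # hypothetical provider
    async def list_provider_model_ids(self) -> Iterable[str]:
        return 42  # not an iterable -- triggers the guard in list_models()

# await BrokenAdapter(...).list_models() raises:
# TypeError: Failed to list models: BrokenAdapter.list_provider_model_ids()
#   must return an iterable of strings, but returned int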

124 changes: 124 additions & 0 deletions tests/unit/providers/utils/inference/test_openai_mixin.py
@@ -5,6 +5,7 @@
# the root directory of this source tree.

import json
from collections.abc import Iterable
from unittest.mock import AsyncMock, MagicMock, Mock, PropertyMock, patch

import pytest
@@ -498,6 +499,129 @@ def get_base_url(self):
return "default-base-url"


class CustomListProviderModelIdsImplementation(OpenAIMixinImpl):
"""Test implementation with custom list_provider_model_ids override"""

def __init__(self, custom_model_ids):
self._custom_model_ids = custom_model_ids

async def list_provider_model_ids(self) -> Iterable[str]:
"""Return custom model IDs list"""
return self._custom_model_ids


class TestOpenAIMixinCustomListProviderModelIds:
"""Test cases for custom list_provider_model_ids() implementation functionality"""

@pytest.fixture
def custom_model_ids_list(self):
"""Create a list of custom model ID strings"""
return ["custom-model-1", "custom-model-2", "custom-embedding"]

@pytest.fixture
def adapter(self, custom_model_ids_list):
"""Create mixin instance with custom list_provider_model_ids implementation"""
mixin = CustomListProviderModelIdsImplementation(custom_model_ids=custom_model_ids_list)
mixin.embedding_model_metadata = {"custom-embedding": {"embedding_dimension": 768, "context_length": 512}}
return mixin

async def test_is_used(self, adapter, custom_model_ids_list):
"""Test that custom list_provider_model_ids() implementation is used instead of client.models.list()"""
result = await adapter.list_models()

assert result is not None
assert len(result) == 3

assert set(custom_model_ids_list) == {m.identifier for m in result}

async def test_populates_cache(self, adapter, custom_model_ids_list):
"""Test that custom list_provider_model_ids() results are cached"""
assert len(adapter._model_cache) == 0

await adapter.list_models()

assert set(custom_model_ids_list) == set(adapter._model_cache.keys())

async def test_respects_allowed_models(self):
"""Test that custom list_provider_model_ids() respects allowed_models filtering"""
mixin = CustomListProviderModelIdsImplementation(custom_model_ids=["model-1", "model-2", "model-3"])
mixin.allowed_models = ["model-1"]

result = await mixin.list_models()

assert result is not None
assert len(result) == 1
assert result[0].identifier == "model-1"

async def test_with_empty_list(self):
"""Test that custom list_provider_model_ids() handles empty list correctly"""
mixin = CustomListProviderModelIdsImplementation(custom_model_ids=[])

result = await mixin.list_models()

assert result is not None
assert len(result) == 0
assert len(mixin._model_cache) == 0

async def test_wrong_type_raises_error(self):
"""Test that list_provider_model_ids() returning unhashable items results in an error"""
mixin = CustomListProviderModelIdsImplementation(custom_model_ids=[["nested", "list"], {"key": "value"}])

with pytest.raises(TypeError, match="unhashable type"):
await mixin.list_models()

async def test_non_iterable_raises_error(self):
"""Test that list_provider_model_ids() returning non-iterable type raises error"""
mixin = CustomListProviderModelIdsImplementation(custom_model_ids=42)

with pytest.raises(
TypeError,
match=r"Failed to list models: CustomListProviderModelIdsImplementation\.list_provider_model_ids\(\) must return an iterable.*but returned int",
):
await mixin.list_models()

async def test_with_none_items_raises_error(self):
"""Test that list_provider_model_ids() returning list with None items causes error"""
mixin = CustomListProviderModelIdsImplementation(custom_model_ids=[None, "valid-model", None])

with pytest.raises(Exception, match="Input should be a valid string"):
await mixin.list_models()

async def test_accepts_various_iterables(self):
"""Test that list_provider_model_ids() accepts tuples, sets, generators, etc."""

class TupleAdapter(OpenAIMixinImpl):
async def list_provider_model_ids(self) -> Iterable[str] | None:
return ("model-1", "model-2", "model-3")

mixin = TupleAdapter()
result = await mixin.list_models()
assert result is not None
assert len(result) == 3

class GeneratorAdapter(OpenAIMixinImpl):
async def list_provider_model_ids(self) -> Iterable[str] | None:
def gen():
yield "gen-model-1"
yield "gen-model-2"

return gen()

mixin = GeneratorAdapter()
result = await mixin.list_models()
assert result is not None
assert len(result) == 2

class SetAdapter(OpenAIMixinImpl):
async def list_provider_model_ids(self) -> Iterable[str] | None:
return {"set-model-1", "set-model-2"}

mixin = SetAdapter()
result = await mixin.list_models()
assert result is not None
assert len(result) == 2


class TestOpenAIMixinProviderDataApiKey:
"""Test cases for provider_data_api_key_field functionality"""
