4 changes: 0 additions & 4 deletions llama_stack/core/datatypes.py
@@ -121,10 +121,6 @@ class AutoRoutedProviderSpec(ProviderSpec):
default=None,
)

@property
def pip_packages(self) -> list[str]:
raise AssertionError("Should not be called on AutoRoutedProviderSpec")


# Example: /models, /shields
class RoutingTableProviderSpec(ProviderSpec):
22 changes: 3 additions & 19 deletions llama_stack/core/distribution.py
@@ -16,11 +16,10 @@
from llama_stack.core.external import load_external_apis
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import (
AdapterSpec,
Api,
InlineProviderSpec,
ProviderSpec,
remote_provider_spec,
RemoteProviderSpec,
)

logger = get_logger(name=__name__, category="core")
@@ -77,27 +76,12 @@ def providable_apis() -> list[Api]:


def _load_remote_provider_spec(spec_data: dict[str, Any], api: Api) -> ProviderSpec:
adapter = AdapterSpec(**spec_data["adapter"])
spec = remote_provider_spec(
api=api,
adapter=adapter,
api_dependencies=[Api(dep) for dep in spec_data.get("api_dependencies", [])],
)
spec = RemoteProviderSpec(api=api, provider_type=f"remote::{spec_data['adapter_type']}", **spec_data)
return spec


def _load_inline_provider_spec(spec_data: dict[str, Any], api: Api, provider_name: str) -> ProviderSpec:
spec = InlineProviderSpec(
api=api,
provider_type=f"inline::{provider_name}",
pip_packages=spec_data.get("pip_packages", []),
module=spec_data["module"],
config_class=spec_data["config_class"],
api_dependencies=[Api(dep) for dep in spec_data.get("api_dependencies", [])],
optional_api_dependencies=[Api(dep) for dep in spec_data.get("optional_api_dependencies", [])],
provider_data_validator=spec_data.get("provider_data_validator"),
container_image=spec_data.get("container_image"),
)
spec = InlineProviderSpec(api=api, provider_type=f"inline::{provider_name}", **spec_data)
return spec


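For reference on the simplified loader above: a minimal sketch of the flattened `spec_data` an external remote provider might ship, and how it now maps straight onto `RemoteProviderSpec` (all concrete names here are hypothetical placeholders, not providers that exist in the tree).

```python
# Hedged sketch of the flattened external provider data that
# _load_remote_provider_spec now forwards directly into RemoteProviderSpec.
# Every concrete name below (adapter id, module, config class, package) is hypothetical.
from llama_stack.providers.datatypes import Api, RemoteProviderSpec

spec_data = {
    "adapter_type": "acme",                                        # hypothetical adapter id
    "module": "acme_stack.providers.remote.inference",             # hypothetical module path
    "config_class": "acme_stack.providers.remote.inference.AcmeConfig",
    "pip_packages": ["acme-client"],                               # hypothetical dependency
    "description": "Illustrative external remote inference provider.",
}

# Mirrors the new loader: no nested AdapterSpec, just keyword fields.
spec = RemoteProviderSpec(
    api=Api.inference,
    provider_type=f"remote::{spec_data['adapter_type']}",
    **spec_data,
)
```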
4 changes: 2 additions & 2 deletions llama_stack/distributions/starter/starter.py
@@ -78,12 +78,12 @@ def get_remote_inference_providers() -> list[Provider]:
remote_providers = [
provider
for provider in available_providers()
if isinstance(provider, RemoteProviderSpec) and provider.adapter.adapter_type in ENABLED_INFERENCE_PROVIDERS
if isinstance(provider, RemoteProviderSpec) and provider.adapter_type in ENABLED_INFERENCE_PROVIDERS
]

inference_providers = []
for provider_spec in remote_providers:
provider_type = provider_spec.adapter.adapter_type
provider_type = provider_spec.adapter_type

if provider_type in INFERENCE_PROVIDER_IDS:
provider_id = INFERENCE_PROVIDER_IDS[provider_type]
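The same field move shows up in any code that filters remote specs: `provider.adapter.adapter_type` becomes `provider.adapter_type`. A minimal sketch of that filtering pattern, with an illustrative enabled-set (the values are placeholders, not the distribution's actual list):

```python
from llama_stack.providers.datatypes import ProviderSpec, RemoteProviderSpec

ENABLED_INFERENCE_PROVIDERS = {"ollama", "vllm"}  # illustrative values only


def enabled_remote_specs(specs: list[ProviderSpec]) -> list[RemoteProviderSpec]:
    # adapter_type is now a top-level field on RemoteProviderSpec,
    # so no nested .adapter access is needed.
    return [
        spec
        for spec in specs
        if isinstance(spec, RemoteProviderSpec) and spec.adapter_type in ENABLED_INFERENCE_PROVIDERS
    ]
```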
88 changes: 17 additions & 71 deletions llama_stack/providers/datatypes.py
@@ -131,6 +131,15 @@ class ProviderSpec(BaseModel):
""",
)

pip_packages: list[str] = Field(
default_factory=list,
description="The pip dependencies needed for this implementation",
)

provider_data_validator: str | None = Field(
default=None,
)

is_external: bool = Field(default=False, description="Notes whether this provider is an external provider.")

# used internally by the resolver; this is a hack for now
@@ -145,56 +154,15 @@ class RoutingTable(Protocol):
async def get_provider_impl(self, routing_key: str) -> Any: ...


# TODO: this can now be inlined into RemoteProviderSpec
@json_schema_type
class AdapterSpec(BaseModel):
adapter_type: str = Field(
...,
description="Unique identifier for this adapter",
)
module: str = Field(
default_factory=str,
description="""
Fully-qualified name of the module to import. The module is expected to have:

- `get_adapter_impl(config, deps)`: returns the adapter implementation
""",
)
pip_packages: list[str] = Field(
default_factory=list,
description="The pip dependencies needed for this implementation",
)
config_class: str = Field(
description="Fully-qualified classname of the config for this provider",
)
provider_data_validator: str | None = Field(
default=None,
)
description: str | None = Field(
default=None,
description="""
A description of the provider. This is used to display in the documentation.
""",
)


@json_schema_type
class InlineProviderSpec(ProviderSpec):
pip_packages: list[str] = Field(
default_factory=list,
description="The pip dependencies needed for this implementation",
)
container_image: str | None = Field(
default=None,
description="""
The container image to use for this implementation. If one is provided, pip_packages will be ignored.
If a provider depends on other providers, the dependencies MUST NOT specify a container image.
""",
)
# module field is inherited from ProviderSpec
provider_data_validator: str | None = Field(
default=None,
)
description: str | None = Field(
default=None,
description="""
@@ -223,44 +191,22 @@ def from_url(cls, url: str) -> "RemoteProviderConfig":

@json_schema_type
class RemoteProviderSpec(ProviderSpec):
adapter: AdapterSpec = Field(
adapter_type: str = Field(
...,
description="Unique identifier for this adapter",
)

description: str | None = Field(
default=None,
description="""
If some code is needed to convert the remote responses into Llama Stack compatible
API responses, specify the adapter here.
A description of the provider. This is used to display in the documentation.
""",
)

@property
def container_image(self) -> str | None:
return None

# module field is inherited from ProviderSpec

@property
def pip_packages(self) -> list[str]:
return self.adapter.pip_packages

@property
def provider_data_validator(self) -> str | None:
return self.adapter.provider_data_validator


def remote_provider_spec(
api: Api,
adapter: AdapterSpec,
api_dependencies: list[Api] | None = None,
optional_api_dependencies: list[Api] | None = None,
) -> RemoteProviderSpec:
return RemoteProviderSpec(
api=api,
provider_type=f"remote::{adapter.adapter_type}",
config_class=adapter.config_class,
module=adapter.module,
adapter=adapter,
api_dependencies=api_dependencies or [],
optional_api_dependencies=optional_api_dependencies or [],
)


class HealthStatus(StrEnum):
OK = "OK"
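To summarize the datatypes change: `pip_packages` and `provider_data_validator` now live on the `ProviderSpec` base model, so inline and remote specs take them directly and the old property indirection through `AdapterSpec` disappears. A minimal sketch under that assumption (provider names, modules, and packages below are hypothetical illustrations):

```python
from llama_stack.providers.datatypes import Api, InlineProviderSpec, RemoteProviderSpec

# Both spec kinds now accept the shared base-class fields directly.
# All concrete names here are hypothetical.
inline = InlineProviderSpec(
    api=Api.files,
    provider_type="inline::example-localdisk",
    module="example_pkg.inline.files",
    config_class="example_pkg.inline.files.ExampleFilesConfig",
    pip_packages=["example-storage-lib"],
)

remote = RemoteProviderSpec(
    api=Api.files,
    adapter_type="example-cloud",
    provider_type="remote::example-cloud",
    module="example_pkg.remote.files",
    config_class="example_pkg.remote.files.ExampleCloudFilesConfig",
    pip_packages=["example-cloud-client"],
    provider_data_validator="example_pkg.remote.files.ExampleCloudValidator",
)
```

The registry updates below show the same pattern applied to the in-tree providers.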
41 changes: 19 additions & 22 deletions llama_stack/providers/registry/datasetio.py
@@ -6,11 +6,10 @@


from llama_stack.providers.datatypes import (
AdapterSpec,
Api,
InlineProviderSpec,
ProviderSpec,
remote_provider_spec,
RemoteProviderSpec,
)


@@ -25,28 +24,26 @@ def available_providers() -> list[ProviderSpec]:
api_dependencies=[],
description="Local filesystem-based dataset I/O provider for reading and writing datasets to local storage.",
),
remote_provider_spec(
RemoteProviderSpec(
api=Api.datasetio,
adapter=AdapterSpec(
adapter_type="huggingface",
pip_packages=[
"datasets>=4.0.0",
],
module="llama_stack.providers.remote.datasetio.huggingface",
config_class="llama_stack.providers.remote.datasetio.huggingface.HuggingfaceDatasetIOConfig",
description="HuggingFace datasets provider for accessing and managing datasets from the HuggingFace Hub.",
),
adapter_type="huggingface",
provider_type="remote::huggingface",
pip_packages=[
"datasets>=4.0.0",
],
module="llama_stack.providers.remote.datasetio.huggingface",
config_class="llama_stack.providers.remote.datasetio.huggingface.HuggingfaceDatasetIOConfig",
description="HuggingFace datasets provider for accessing and managing datasets from the HuggingFace Hub.",
),
remote_provider_spec(
RemoteProviderSpec(
api=Api.datasetio,
adapter=AdapterSpec(
adapter_type="nvidia",
pip_packages=[
"datasets>=4.0.0",
],
module="llama_stack.providers.remote.datasetio.nvidia",
config_class="llama_stack.providers.remote.datasetio.nvidia.NvidiaDatasetIOConfig",
description="NVIDIA's dataset I/O provider for accessing datasets from NVIDIA's data platform.",
),
adapter_type="nvidia",
provider_type="remote::nvidia",
module="llama_stack.providers.remote.datasetio.nvidia",
config_class="llama_stack.providers.remote.datasetio.nvidia.NvidiaDatasetIOConfig",
pip_packages=[
"datasets>=4.0.0",
],
description="NVIDIA's dataset I/O provider for accessing datasets from NVIDIA's data platform.",
),
]
21 changes: 10 additions & 11 deletions llama_stack/providers/registry/eval.py
@@ -5,7 +5,7 @@
# the root directory of this source tree.


from llama_stack.providers.datatypes import AdapterSpec, Api, InlineProviderSpec, ProviderSpec, remote_provider_spec
from llama_stack.providers.datatypes import Api, InlineProviderSpec, ProviderSpec, RemoteProviderSpec


def available_providers() -> list[ProviderSpec]:
@@ -25,17 +25,16 @@ def available_providers() -> list[ProviderSpec]:
],
description="Meta's reference implementation of evaluation tasks with support for multiple languages and evaluation metrics.",
),
remote_provider_spec(
RemoteProviderSpec(
api=Api.eval,
adapter=AdapterSpec(
adapter_type="nvidia",
pip_packages=[
"requests",
],
module="llama_stack.providers.remote.eval.nvidia",
config_class="llama_stack.providers.remote.eval.nvidia.NVIDIAEvalConfig",
description="NVIDIA's evaluation provider for running evaluation tasks on NVIDIA's platform.",
),
adapter_type="nvidia",
pip_packages=[
"requests",
],
provider_type="remote::nvidia",
module="llama_stack.providers.remote.eval.nvidia",
config_class="llama_stack.providers.remote.eval.nvidia.NVIDIAEvalConfig",
description="NVIDIA's evaluation provider for running evaluation tasks on NVIDIA's platform.",
api_dependencies=[
Api.datasetio,
Api.datasets,
23 changes: 8 additions & 15 deletions llama_stack/providers/registry/files.py
@@ -4,13 +4,7 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from llama_stack.providers.datatypes import (
AdapterSpec,
Api,
InlineProviderSpec,
ProviderSpec,
remote_provider_spec,
)
from llama_stack.providers.datatypes import Api, InlineProviderSpec, ProviderSpec, RemoteProviderSpec
from llama_stack.providers.utils.sqlstore.sqlstore import sql_store_pip_packages


@@ -25,14 +19,13 @@ def available_providers() -> list[ProviderSpec]:
config_class="llama_stack.providers.inline.files.localfs.config.LocalfsFilesImplConfig",
description="Local filesystem-based file storage provider for managing files and documents locally.",
),
remote_provider_spec(
RemoteProviderSpec(
api=Api.files,
adapter=AdapterSpec(
adapter_type="s3",
pip_packages=["boto3"] + sql_store_pip_packages,
module="llama_stack.providers.remote.files.s3",
config_class="llama_stack.providers.remote.files.s3.config.S3FilesImplConfig",
description="AWS S3-based file storage provider for scalable cloud file management with metadata persistence.",
),
provider_type="remote::s3",
adapter_type="s3",
pip_packages=["boto3"] + sql_store_pip_packages,
module="llama_stack.providers.remote.files.s3",
config_class="llama_stack.providers.remote.files.s3.config.S3FilesImplConfig",
description="AWS S3-based file storage provider for scalable cloud file management with metadata persistence.",
),
]