Skip to content

Commit f9d8cda

Browse files
committed
feat: combine ProviderSpec datatypes
currently `RemoteProviderSpec` has an `AdapterSpec` embedded in it. Remove `AdapterSpec`, and put its leftover fields into `RemoteProviderSpec`. Additionally, many of the fields were duplicated between `InlineProviderSpec` and `RemoteProviderSpec`. Move these to `ProviderSpec` so they are shared. Fixup the distro codegen to use `RemoteProviderSpec` directly rather than `remote_provider_spec` which took an AdapterSpec and returned a full provider spec Signed-off-by: Charlie Doern <[email protected]>
1 parent d4e45cd commit f9d8cda

File tree

15 files changed

+369
-496
lines changed

15 files changed

+369
-496
lines changed

llama_stack/core/datatypes.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -120,10 +120,9 @@ class AutoRoutedProviderSpec(ProviderSpec):
120120
provider_data_validator: str | None = Field(
121121
default=None,
122122
)
123-
124-
@property
125-
def pip_packages(self) -> list[str]:
126-
raise AssertionError("Should not be called on AutoRoutedProviderSpec")
123+
pip_packages: list[str] = Field(
124+
default_factory=list, description="This field should not be accessed for AutoRoutedProviderSpec"
125+
)
127126

128127

129128
# Example: /models, /shields

llama_stack/core/distribution.py

Lines changed: 3 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -16,11 +16,10 @@
1616
from llama_stack.core.external import load_external_apis
1717
from llama_stack.log import get_logger
1818
from llama_stack.providers.datatypes import (
19-
AdapterSpec,
2019
Api,
2120
InlineProviderSpec,
2221
ProviderSpec,
23-
remote_provider_spec,
22+
RemoteProviderSpec,
2423
)
2524

2625
logger = get_logger(name=__name__, category="core")
@@ -74,27 +73,12 @@ def providable_apis() -> list[Api]:
7473

7574

7675
def _load_remote_provider_spec(spec_data: dict[str, Any], api: Api) -> ProviderSpec:
77-
adapter = AdapterSpec(**spec_data["adapter"])
78-
spec = remote_provider_spec(
79-
api=api,
80-
adapter=adapter,
81-
api_dependencies=[Api(dep) for dep in spec_data.get("api_dependencies", [])],
82-
)
76+
spec = RemoteProviderSpec(api=api, provider_type=f"remote::{spec_data['adapter_type']}", **spec_data)
8377
return spec
8478

8579

8680
def _load_inline_provider_spec(spec_data: dict[str, Any], api: Api, provider_name: str) -> ProviderSpec:
87-
spec = InlineProviderSpec(
88-
api=api,
89-
provider_type=f"inline::{provider_name}",
90-
pip_packages=spec_data.get("pip_packages", []),
91-
module=spec_data["module"],
92-
config_class=spec_data["config_class"],
93-
api_dependencies=[Api(dep) for dep in spec_data.get("api_dependencies", [])],
94-
optional_api_dependencies=[Api(dep) for dep in spec_data.get("optional_api_dependencies", [])],
95-
provider_data_validator=spec_data.get("provider_data_validator"),
96-
container_image=spec_data.get("container_image"),
97-
)
81+
spec = InlineProviderSpec(api=api, provider_type=f"inline::{provider_name}", **spec_data)
9882
return spec
9983

10084

llama_stack/distributions/starter/starter.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -76,12 +76,12 @@ def get_remote_inference_providers() -> list[Provider]:
7676
remote_providers = [
7777
provider
7878
for provider in available_providers()
79-
if isinstance(provider, RemoteProviderSpec) and provider.adapter.adapter_type in ENABLED_INFERENCE_PROVIDERS
79+
if isinstance(provider, RemoteProviderSpec) and provider.adapter_type in ENABLED_INFERENCE_PROVIDERS
8080
]
8181

8282
inference_providers = []
8383
for provider_spec in remote_providers:
84-
provider_type = provider_spec.adapter.adapter_type
84+
provider_type = provider_spec.adapter_type
8585

8686
if provider_type in INFERENCE_PROVIDER_IDS:
8787
provider_id = INFERENCE_PROVIDER_IDS[provider_type]

llama_stack/providers/datatypes.py

Lines changed: 17 additions & 71 deletions
Original file line numberDiff line numberDiff line change
@@ -131,6 +131,15 @@ class ProviderSpec(BaseModel):
131131
""",
132132
)
133133

134+
pip_packages: list[str] = Field(
135+
default_factory=list,
136+
description="The pip dependencies needed for this implementation",
137+
)
138+
139+
provider_data_validator: str | None = Field(
140+
default=None,
141+
)
142+
134143
is_external: bool = Field(default=False, description="Notes whether this provider is an external provider.")
135144

136145
# used internally by the resolver; this is a hack for now
@@ -145,56 +154,15 @@ class RoutingTable(Protocol):
145154
async def get_provider_impl(self, routing_key: str) -> Any: ...
146155

147156

148-
# TODO: this can now be inlined into RemoteProviderSpec
149-
@json_schema_type
150-
class AdapterSpec(BaseModel):
151-
adapter_type: str = Field(
152-
...,
153-
description="Unique identifier for this adapter",
154-
)
155-
module: str = Field(
156-
default_factory=str,
157-
description="""
158-
Fully-qualified name of the module to import. The module is expected to have:
159-
160-
- `get_adapter_impl(config, deps)`: returns the adapter implementation
161-
""",
162-
)
163-
pip_packages: list[str] = Field(
164-
default_factory=list,
165-
description="The pip dependencies needed for this implementation",
166-
)
167-
config_class: str = Field(
168-
description="Fully-qualified classname of the config for this provider",
169-
)
170-
provider_data_validator: str | None = Field(
171-
default=None,
172-
)
173-
description: str | None = Field(
174-
default=None,
175-
description="""
176-
A description of the provider. This is used to display in the documentation.
177-
""",
178-
)
179-
180-
181157
@json_schema_type
182158
class InlineProviderSpec(ProviderSpec):
183-
pip_packages: list[str] = Field(
184-
default_factory=list,
185-
description="The pip dependencies needed for this implementation",
186-
)
187159
container_image: str | None = Field(
188160
default=None,
189161
description="""
190162
The container image to use for this implementation. If one is provided, pip_packages will be ignored.
191163
If a provider depends on other providers, the dependencies MUST NOT specify a container image.
192164
""",
193165
)
194-
# module field is inherited from ProviderSpec
195-
provider_data_validator: str | None = Field(
196-
default=None,
197-
)
198166
description: str | None = Field(
199167
default=None,
200168
description="""
@@ -223,44 +191,22 @@ def from_url(cls, url: str) -> "RemoteProviderConfig":
223191

224192
@json_schema_type
225193
class RemoteProviderSpec(ProviderSpec):
226-
adapter: AdapterSpec = Field(
194+
adapter_type: str = Field(
195+
...,
196+
description="Unique identifier for this adapter",
197+
)
198+
199+
description: str | None = Field(
200+
default=None,
227201
description="""
228-
If some code is needed to convert the remote responses into Llama Stack compatible
229-
API responses, specify the adapter here.
202+
A description of the provider. This is used to display in the documentation.
230203
""",
231204
)
232205

233206
@property
234207
def container_image(self) -> str | None:
235208
return None
236209

237-
# module field is inherited from ProviderSpec
238-
239-
@property
240-
def pip_packages(self) -> list[str]:
241-
return self.adapter.pip_packages
242-
243-
@property
244-
def provider_data_validator(self) -> str | None:
245-
return self.adapter.provider_data_validator
246-
247-
248-
def remote_provider_spec(
249-
api: Api,
250-
adapter: AdapterSpec,
251-
api_dependencies: list[Api] | None = None,
252-
optional_api_dependencies: list[Api] | None = None,
253-
) -> RemoteProviderSpec:
254-
return RemoteProviderSpec(
255-
api=api,
256-
provider_type=f"remote::{adapter.adapter_type}",
257-
config_class=adapter.config_class,
258-
module=adapter.module,
259-
adapter=adapter,
260-
api_dependencies=api_dependencies or [],
261-
optional_api_dependencies=optional_api_dependencies or [],
262-
)
263-
264210

265211
class HealthStatus(StrEnum):
266212
OK = "OK"

llama_stack/providers/registry/datasetio.py

Lines changed: 19 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -6,11 +6,10 @@
66

77

88
from llama_stack.providers.datatypes import (
9-
AdapterSpec,
109
Api,
1110
InlineProviderSpec,
1211
ProviderSpec,
13-
remote_provider_spec,
12+
RemoteProviderSpec,
1413
)
1514

1615

@@ -25,28 +24,26 @@ def available_providers() -> list[ProviderSpec]:
2524
api_dependencies=[],
2625
description="Local filesystem-based dataset I/O provider for reading and writing datasets to local storage.",
2726
),
28-
remote_provider_spec(
27+
RemoteProviderSpec(
2928
api=Api.datasetio,
30-
adapter=AdapterSpec(
31-
adapter_type="huggingface",
32-
pip_packages=[
33-
"datasets>=4.0.0",
34-
],
35-
module="llama_stack.providers.remote.datasetio.huggingface",
36-
config_class="llama_stack.providers.remote.datasetio.huggingface.HuggingfaceDatasetIOConfig",
37-
description="HuggingFace datasets provider for accessing and managing datasets from the HuggingFace Hub.",
38-
),
29+
adapter_type="huggingface",
30+
provider_type="remote::huggingface",
31+
pip_packages=[
32+
"datasets>=4.0.0",
33+
],
34+
module="llama_stack.providers.remote.datasetio.huggingface",
35+
config_class="llama_stack.providers.remote.datasetio.huggingface.HuggingfaceDatasetIOConfig",
36+
description="HuggingFace datasets provider for accessing and managing datasets from the HuggingFace Hub.",
3937
),
40-
remote_provider_spec(
38+
RemoteProviderSpec(
4139
api=Api.datasetio,
42-
adapter=AdapterSpec(
43-
adapter_type="nvidia",
44-
pip_packages=[
45-
"datasets>=4.0.0",
46-
],
47-
module="llama_stack.providers.remote.datasetio.nvidia",
48-
config_class="llama_stack.providers.remote.datasetio.nvidia.NvidiaDatasetIOConfig",
49-
description="NVIDIA's dataset I/O provider for accessing datasets from NVIDIA's data platform.",
50-
),
40+
adapter_type="nvidia",
41+
provider_type="remote::nvidia",
42+
module="llama_stack.providers.remote.datasetio.nvidia",
43+
config_class="llama_stack.providers.remote.datasetio.nvidia.NvidiaDatasetIOConfig",
44+
pip_packages=[
45+
"datasets>=4.0.0",
46+
],
47+
description="NVIDIA's dataset I/O provider for accessing datasets from NVIDIA's data platform.",
5148
),
5249
]

llama_stack/providers/registry/eval.py

Lines changed: 10 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
# the root directory of this source tree.
66

77

8-
from llama_stack.providers.datatypes import AdapterSpec, Api, InlineProviderSpec, ProviderSpec, remote_provider_spec
8+
from llama_stack.providers.datatypes import Api, InlineProviderSpec, ProviderSpec, RemoteProviderSpec
99

1010

1111
def available_providers() -> list[ProviderSpec]:
@@ -25,17 +25,16 @@ def available_providers() -> list[ProviderSpec]:
2525
],
2626
description="Meta's reference implementation of evaluation tasks with support for multiple languages and evaluation metrics.",
2727
),
28-
remote_provider_spec(
28+
RemoteProviderSpec(
2929
api=Api.eval,
30-
adapter=AdapterSpec(
31-
adapter_type="nvidia",
32-
pip_packages=[
33-
"requests",
34-
],
35-
module="llama_stack.providers.remote.eval.nvidia",
36-
config_class="llama_stack.providers.remote.eval.nvidia.NVIDIAEvalConfig",
37-
description="NVIDIA's evaluation provider for running evaluation tasks on NVIDIA's platform.",
38-
),
30+
adapter_type="nvidia",
31+
pip_packages=[
32+
"requests",
33+
],
34+
provider_type="remote::nvidia",
35+
module="llama_stack.providers.remote.eval.nvidia",
36+
config_class="llama_stack.providers.remote.eval.nvidia.NVIDIAEvalConfig",
37+
description="NVIDIA's evaluation provider for running evaluation tasks on NVIDIA's platform.",
3938
api_dependencies=[
4039
Api.datasetio,
4140
Api.datasets,

llama_stack/providers/registry/files.py

Lines changed: 8 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -4,13 +4,7 @@
44
# This source code is licensed under the terms described in the LICENSE file in
55
# the root directory of this source tree.
66

7-
from llama_stack.providers.datatypes import (
8-
AdapterSpec,
9-
Api,
10-
InlineProviderSpec,
11-
ProviderSpec,
12-
remote_provider_spec,
13-
)
7+
from llama_stack.providers.datatypes import Api, InlineProviderSpec, ProviderSpec, RemoteProviderSpec
148
from llama_stack.providers.utils.sqlstore.sqlstore import sql_store_pip_packages
159

1610

@@ -25,14 +19,13 @@ def available_providers() -> list[ProviderSpec]:
2519
config_class="llama_stack.providers.inline.files.localfs.config.LocalfsFilesImplConfig",
2620
description="Local filesystem-based file storage provider for managing files and documents locally.",
2721
),
28-
remote_provider_spec(
22+
RemoteProviderSpec(
2923
api=Api.files,
30-
adapter=AdapterSpec(
31-
adapter_type="s3",
32-
pip_packages=["boto3"] + sql_store_pip_packages,
33-
module="llama_stack.providers.remote.files.s3",
34-
config_class="llama_stack.providers.remote.files.s3.config.S3FilesImplConfig",
35-
description="AWS S3-based file storage provider for scalable cloud file management with metadata persistence.",
36-
),
24+
provider_type="remote::s3",
25+
adapter_type="s3",
26+
pip_packages=["boto3"] + sql_store_pip_packages,
27+
module="llama_stack.providers.remote.files.s3",
28+
config_class="llama_stack.providers.remote.files.s3.config.S3FilesImplConfig",
29+
description="AWS S3-based file storage provider for scalable cloud file management with metadata persistence.",
3730
),
3831
]

0 commit comments

Comments
 (0)