
Commit d236074

mattf and raghotham authored
chore: update the groq inference impl to use openai-python for openai-compat functions (#3348)
# What does this PR do?

Update the Groq inference provider to use OpenAIMixin for openai-compat endpoints.

Changes on api.groq.com:

- json_schema is now supported for specific models, see https://console.groq.com/docs/structured-outputs#supported-models
- response_format with streaming is now supported for models that support response_format
- Groq no longer returns a 400 error if tools are provided and tool_choice is not "required"

## Test Plan

```
$ GROQ_API_KEY=... uv run llama stack build --image-type venv --providers inference=remote::groq --run
...
$ LLAMA_STACK_CONFIG=http://localhost:8321 uv run --group test pytest -v -ra --text-model groq/llama-3.3-70b-versatile tests/integration/inference/test_openai_completion.py -k 'not store'
...
SKIPPED [3] tests/integration/inference/test_openai_completion.py:44: Model groq/llama-3.3-70b-versatile hosted by remote::groq doesn't support OpenAI completions.
SKIPPED [3] tests/integration/inference/test_openai_completion.py:94: Model groq/llama-3.3-70b-versatile hosted by remote::groq doesn't support vllm extra_body parameters.
SKIPPED [4] tests/integration/inference/test_openai_completion.py:73: Model groq/llama-3.3-70b-versatile hosted by remote::groq doesn't support n param.
SKIPPED [1] tests/integration/inference/test_openai_completion.py:100: Model groq/llama-3.3-70b-versatile hosted by remote::groq doesn't support chat completion calls with base64 encoded files.
======================= 8 passed, 11 skipped, 8 deselected, 2 warnings in 5.13s ========================
```

---------

Co-authored-by: raghotham <[email protected]>
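For context on the first bullet: with json_schema now accepted by Groq's OpenAI-compat endpoint, a structured-output request through the plain `openai` client might look like the sketch below. The prompt and schema are illustrative (not from this PR), and whether a given model accepts json_schema depends on the Groq docs linked above.

```python
# Hedged sketch: a json_schema structured-output request against Groq's
# OpenAI-compat endpoint. Prompt and schema are made up for illustration.
from openai import OpenAI

client = OpenAI(base_url="https://api.groq.com/openai/v1", api_key="...")

response = client.chat.completions.create(
    model="llama-3.3-70b-versatile",  # the model used in the test plan above
    messages=[{"role": "user", "content": "Name one planet as JSON."}],
    response_format={
        "type": "json_schema",
        "json_schema": {
            "name": "planet",
            "schema": {
                "type": "object",
                "properties": {"name": {"type": "string"}},
                "required": ["name"],
            },
        },
    },
)
print(response.choices[0].message.content)  # e.g. {"name": "Mars"}
```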
1 parent ecd9d8d commit d236074

File tree

3 files changed (+10, -134 lines)

llama_stack/providers/registry/inference.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -248,7 +248,7 @@ def available_providers() -> list[ProviderSpec]:
         api=Api.inference,
         adapter=AdapterSpec(
             adapter_type="groq",
-            pip_packages=["litellm"],
+            pip_packages=["litellm", "openai"],
             module="llama_stack.providers.remote.inference.groq",
             config_class="llama_stack.providers.remote.inference.groq.GroqConfig",
             provider_data_validator="llama_stack.providers.remote.inference.groq.config.GroqProviderDataValidator",
```
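The new `openai` dependency follows from the mixin swap: OpenAIMixin drives requests through the `openai` Python client. A simplified sketch of the shape involved is below; this is an illustration of the idea, not the actual source of `llama_stack.providers.utils.inference.openai_mixin`, which also handles client caching and the full endpoint surface.

```python
# Simplified sketch (not the real OpenAIMixin): the mixin builds an
# AsyncOpenAI client from two hooks that the concrete adapter provides,
# which is why "openai" joins the provider's pip_packages.
from openai import AsyncOpenAI


class OpenAIMixinSketch:
    def get_base_url(self) -> str:
        raise NotImplementedError  # adapter supplies the endpoint URL

    def get_api_key(self) -> str:
        raise NotImplementedError  # adapter supplies the credential

    @property
    def client(self) -> AsyncOpenAI:
        # The real mixin caches/reuses clients; a fresh one suffices here.
        return AsyncOpenAI(base_url=self.get_base_url(), api_key=self.get_api_key())
```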

llama_stack/providers/remote/inference/groq/groq.py

Lines changed: 8 additions & 131 deletions
```diff
@@ -4,30 +4,15 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
-from collections.abc import AsyncIterator
-from typing import Any
 
-from openai import AsyncOpenAI
-
-from llama_stack.apis.inference import (
-    OpenAIChatCompletion,
-    OpenAIChatCompletionChunk,
-    OpenAIChoiceDelta,
-    OpenAIChunkChoice,
-    OpenAIMessageParam,
-    OpenAIResponseFormatParam,
-    OpenAISystemMessageParam,
-)
 from llama_stack.providers.remote.inference.groq.config import GroqConfig
 from llama_stack.providers.utils.inference.litellm_openai_mixin import LiteLLMOpenAIMixin
-from llama_stack.providers.utils.inference.openai_compat import (
-    prepare_openai_completion_params,
-)
+from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
 
 from .models import MODEL_ENTRIES
 
 
-class GroqInferenceAdapter(LiteLLMOpenAIMixin):
+class GroqInferenceAdapter(OpenAIMixin, LiteLLMOpenAIMixin):
     _config: GroqConfig
 
     def __init__(self, config: GroqConfig):
@@ -40,122 +25,14 @@ def __init__(self, config: GroqConfig):
         )
         self.config = config
 
+    # Delegate the client data handling get_api_key method to LiteLLMOpenAIMixin
+    get_api_key = LiteLLMOpenAIMixin.get_api_key
+
+    def get_base_url(self) -> str:
+        return f"{self.config.url}/openai/v1"
+
     async def initialize(self):
         await super().initialize()
 
     async def shutdown(self):
         await super().shutdown()
-
-    def _get_openai_client(self) -> AsyncOpenAI:
-        return AsyncOpenAI(
-            base_url=f"{self.config.url}/openai/v1",
-            api_key=self.get_api_key(),
-        )
-
-    async def openai_chat_completion(
-        self,
-        model: str,
-        messages: list[OpenAIMessageParam],
-        frequency_penalty: float | None = None,
-        function_call: str | dict[str, Any] | None = None,
-        functions: list[dict[str, Any]] | None = None,
-        logit_bias: dict[str, float] | None = None,
-        logprobs: bool | None = None,
-        max_completion_tokens: int | None = None,
-        max_tokens: int | None = None,
-        n: int | None = None,
-        parallel_tool_calls: bool | None = None,
-        presence_penalty: float | None = None,
-        response_format: OpenAIResponseFormatParam | None = None,
-        seed: int | None = None,
-        stop: str | list[str] | None = None,
-        stream: bool | None = None,
-        stream_options: dict[str, Any] | None = None,
-        temperature: float | None = None,
-        tool_choice: str | dict[str, Any] | None = None,
-        tools: list[dict[str, Any]] | None = None,
-        top_logprobs: int | None = None,
-        top_p: float | None = None,
-        user: str | None = None,
-    ) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
-        model_obj = await self.model_store.get_model(model)
-
-        # Groq does not support json_schema response format, so we need to convert it to json_object
-        if response_format and response_format.type == "json_schema":
-            response_format.type = "json_object"
-            schema = response_format.json_schema.get("schema", {})
-            response_format.json_schema = None
-            json_instructions = f"\nYour response should be a JSON object that matches the following schema: {schema}"
-            if messages and messages[0].role == "system":
-                messages[0].content = messages[0].content + json_instructions
-            else:
-                messages.insert(0, OpenAISystemMessageParam(content=json_instructions))
-
-        # Groq returns a 400 error if tools are provided but none are called
-        # So, set tool_choice to "required" to attempt to force a call
-        if tools and (not tool_choice or tool_choice == "auto"):
-            tool_choice = "required"
-
-        params = await prepare_openai_completion_params(
-            model=model_obj.provider_resource_id,
-            messages=messages,
-            frequency_penalty=frequency_penalty,
-            function_call=function_call,
-            functions=functions,
-            logit_bias=logit_bias,
-            logprobs=logprobs,
-            max_completion_tokens=max_completion_tokens,
-            max_tokens=max_tokens,
-            n=n,
-            parallel_tool_calls=parallel_tool_calls,
-            presence_penalty=presence_penalty,
-            response_format=response_format,
-            seed=seed,
-            stop=stop,
-            stream=stream,
-            stream_options=stream_options,
-            temperature=temperature,
-            tool_choice=tool_choice,
-            tools=tools,
-            top_logprobs=top_logprobs,
-            top_p=top_p,
-            user=user,
-        )
-
-        # Groq does not support streaming requests that set response_format
-        fake_stream = False
-        if stream and response_format:
-            params["stream"] = False
-            fake_stream = True
-
-        response = await self._get_openai_client().chat.completions.create(**params)
-
-        if fake_stream:
-            chunk_choices = []
-            for choice in response.choices:
-                delta = OpenAIChoiceDelta(
-                    content=choice.message.content,
-                    role=choice.message.role,
-                    tool_calls=choice.message.tool_calls,
-                )
-                chunk_choice = OpenAIChunkChoice(
-                    delta=delta,
-                    finish_reason=choice.finish_reason,
-                    index=choice.index,
-                    logprobs=None,
-                )
-                chunk_choices.append(chunk_choice)
-            chunk = OpenAIChatCompletionChunk(
-                id=response.id,
-                choices=chunk_choices,
-                object="chat.completion.chunk",
-                created=response.created,
-                model=response.model,
-            )
-
-            async def _fake_stream_generator():
-                yield chunk
-
-            return _fake_stream_generator()
-        else:
-            return response
```
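The base-class order is the crux of this diff: with `OpenAIMixin` listed first, Python's MRO resolves the openai-compat methods to OpenAIMixin, which is why the hand-rolled `openai_chat_completion` above could be deleted, while the explicit `get_api_key = LiteLLMOpenAIMixin.get_api_key` assignment keeps provider-data key handling with LiteLLM. A toy illustration of that resolution is below; the class names and return strings are stand-ins, not llama-stack code.

```python
# Toy MRO demonstration; OpenAICompat/LiteLLMCompat stand in for the real
# mixins and everything here is hypothetical illustration.
class OpenAICompat:
    def openai_chat_completion(self) -> str:
        return "openai-python path"

    def get_api_key(self) -> str:
        raise NotImplementedError  # expects the subclass to pin a real impl


class LiteLLMCompat:
    def openai_chat_completion(self) -> str:
        return "litellm path"

    def get_api_key(self) -> str:
        return "key-from-provider-data"


class Adapter(OpenAICompat, LiteLLMCompat):
    # Without this re-pin, MRO would reach OpenAICompat.get_api_key first,
    # mirroring `get_api_key = LiteLLMOpenAIMixin.get_api_key` in the diff.
    get_api_key = LiteLLMCompat.get_api_key


adapter = Adapter()
assert adapter.openai_chat_completion() == "openai-python path"
assert adapter.get_api_key() == "key-from-provider-data"
```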

tests/unit/providers/inference/test_inference_client_caching.py

Lines changed: 1 addition & 2 deletions
```diff
@@ -33,8 +33,7 @@ def test_groq_provider_openai_client_caching():
     with request_provider_data_context(
         {"x-llamastack-provider-data": json.dumps({inference_adapter.provider_data_api_key_field: api_key})}
     ):
-        openai_client = inference_adapter._get_openai_client()
-        assert openai_client.api_key == api_key
+        assert inference_adapter.client.api_key == api_key
 
 
 def test_openai_provider_openai_client_caching():
```
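The updated assertion reflects the mixin's public surface: the per-request key now arrives through the `client` attribute rather than a private `_get_openai_client()` helper. Assuming `client` is a property (as in the sketch earlier), credentials are resolved at access time, so whatever provider-data context is active when the test touches `inference_adapter.client` determines the key it sees. A toy illustration with a hypothetical `FakeAdapter`:

```python
# Hypothetical FakeAdapter (not llama-stack code) showing why a
# property-based client picks up the credentials in scope at access time.
from types import SimpleNamespace


class FakeAdapter:
    current_key = "default-key"

    @property
    def client(self) -> SimpleNamespace:
        # Rebuilt on every access, the way a client property would be
        return SimpleNamespace(api_key=self.current_key)


adapter = FakeAdapter()
adapter.current_key = "caller-key"  # stands in for the provider-data context
assert adapter.client.api_key == "caller-key"
```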
