Skip to content

Commit b618a35

Browse files
committed
models - move abstract class
1 parent 98c5a37 commit b618a35

File tree

21 files changed

+587
-676
lines changed

21 files changed

+587
-676
lines changed

src/strands/agent/agent.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,13 +29,13 @@
2929
)
3030
from ..handlers.callback_handler import PrintingCallbackHandler, null_callback_handler
3131
from ..models.bedrock import BedrockModel
32+
from ..models.model import Model
3233
from ..telemetry.metrics import EventLoopMetrics
3334
from ..telemetry.tracer import get_tracer
3435
from ..tools.registry import ToolRegistry
3536
from ..tools.watcher import ToolWatcher
3637
from ..types.content import ContentBlock, Message, Messages
3738
from ..types.exceptions import ContextWindowOverflowException
38-
from ..types.models import Model
3939
from ..types.tools import ToolResult, ToolUse
4040
from ..types.traces import AttributeValue
4141
from .agent_result import AgentResult

src/strands/event_loop/streaming.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,8 @@
44
import logging
55
from typing import Any, AsyncGenerator, AsyncIterable, Optional
66

7+
from ..models.model import Model
78
from ..types.content import ContentBlock, Message, Messages
8-
from ..types.models import Model
99
from ..types.streaming import (
1010
ContentBlockDeltaEvent,
1111
ContentBlockStart,

src/strands/models/__init__.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,8 @@
33
This package includes an abstract base Model class along with concrete implementations for specific providers.
44
"""
55

6-
from . import bedrock
6+
from . import bedrock, model
77
from .bedrock import BedrockModel
8+
from .model import Model
89

9-
__all__ = ["bedrock", "BedrockModel"]
10+
__all__ = ["bedrock", "model", "BedrockModel", "Model"]

src/strands/models/anthropic.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,9 @@
1717
from ..tools import convert_pydantic_to_tool_spec
1818
from ..types.content import ContentBlock, Messages
1919
from ..types.exceptions import ContextWindowOverflowException, ModelThrottledException
20-
from ..types.models import Model
2120
from ..types.streaming import StreamEvent
2221
from ..types.tools import ToolSpec
22+
from .model import Model
2323

2424
logger = logging.getLogger(__name__)
2525

@@ -361,7 +361,7 @@ async def stream(
361361
"""
362362
logger.debug("formatting request")
363363
request = self.format_request(messages, tool_specs, system_prompt)
364-
logger.debug("formatted request=<%s>", request)
364+
logger.debug("request=<%s>", request)
365365

366366
logger.debug("invoking model")
367367
try:

src/strands/models/bedrock.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,13 +14,13 @@
1414
from pydantic import BaseModel
1515
from typing_extensions import TypedDict, Unpack, override
1616

17-
from ..event_loop.streaming import process_stream
17+
from ..event_loop import streaming
1818
from ..tools import convert_pydantic_to_tool_spec
1919
from ..types.content import Messages
2020
from ..types.exceptions import ContextWindowOverflowException, ModelThrottledException
21-
from ..types.models import Model
2221
from ..types.streaming import StreamEvent
2322
from ..types.tools import ToolSpec
23+
from .model import Model
2424

2525
logger = logging.getLogger(__name__)
2626

@@ -335,7 +335,7 @@ async def stream(
335335
"""
336336
logger.debug("formatting request")
337337
request = self.format_request(messages, tool_specs, system_prompt)
338-
logger.debug("formatted request=<%s>", request)
338+
logger.debug("request=<%s>", request)
339339

340340
logger.debug("invoking model")
341341
streaming = self.config.get("streaming", True)
@@ -542,7 +542,7 @@ async def structured_output(
542542
tool_spec = convert_pydantic_to_tool_spec(output_model)
543543

544544
response = self.stream(messages=prompt, tool_specs=[tool_spec])
545-
async for event in process_stream(response, prompt):
545+
async for event in streaming.process_stream(response, prompt):
546546
yield event
547547

548548
stop_reason, messages, _, _ = event["stop"]

src/strands/models/litellm.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,9 +13,9 @@
1313
from typing_extensions import Unpack, override
1414

1515
from ..types.content import ContentBlock, Messages
16-
from ..types.models.openai import OpenAIModel
1716
from ..types.streaming import StreamEvent
1817
from ..types.tools import ToolSpec
18+
from .openai import OpenAIModel
1919

2020
logger = logging.getLogger(__name__)
2121

@@ -121,7 +121,7 @@ async def stream(
121121
"""
122122
logger.debug("formatting request")
123123
request = self.format_request(messages, tool_specs, system_prompt)
124-
logger.debug("formatted request=<%s>", request)
124+
logger.debug("request=<%s>", request)
125125

126126
logger.debug("invoking model")
127127
response = self.client.chat.completions.create(**request)

src/strands/models/llamaapi.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,9 @@
1717

1818
from ..types.content import ContentBlock, Messages
1919
from ..types.exceptions import ModelThrottledException
20-
from ..types.models import Model
2120
from ..types.streaming import StreamEvent, Usage
2221
from ..types.tools import ToolResult, ToolSpec, ToolUse
22+
from .model import Model
2323

2424
logger = logging.getLogger(__name__)
2525

@@ -340,7 +340,7 @@ async def stream(
340340
"""
341341
logger.debug("formatting request")
342342
request = self.format_request(messages, tool_specs, system_prompt)
343-
logger.debug("formatted request=<%s>", request)
343+
logger.debug("request=<%s>", request)
344344

345345
logger.debug("invoking model")
346346
try:

src/strands/models/mistral.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,9 +14,9 @@
1414

1515
from ..types.content import ContentBlock, Messages
1616
from ..types.exceptions import ModelThrottledException
17-
from ..types.models import Model
1817
from ..types.streaming import StopReason, StreamEvent
1918
from ..types.tools import ToolResult, ToolSpec, ToolUse
19+
from .model import Model
2020

2121
logger = logging.getLogger(__name__)
2222

@@ -409,7 +409,7 @@ async def stream(
409409
"""
410410
logger.debug("formatting request")
411411
request = self.format_request(messages, tool_specs, system_prompt)
412-
logger.debug("formatted request=<%s>", request)
412+
logger.debug("request=<%s>", request)
413413

414414
logger.debug("invoking model")
415415
try:
Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,22 @@
1-
"""Model-related type definitions for the SDK."""
1+
"""Abstract base class for Agent model providers."""
22

33
import abc
44
import logging
55
from typing import Any, AsyncGenerator, AsyncIterable, Optional, Type, TypeVar, Union
66

77
from pydantic import BaseModel
88

9-
from ..content import Messages
10-
from ..streaming import StreamEvent
11-
from ..tools import ToolSpec
9+
from ..types.content import Messages
10+
from ..types.streaming import StreamEvent
11+
from ..types.tools import ToolSpec
1212

1313
logger = logging.getLogger(__name__)
1414

1515
T = TypeVar("T", bound=BaseModel)
1616

1717

1818
class Model(abc.ABC):
19-
"""Abstract base class for AI model implementations.
19+
"""Abstract base class for Agent model providers.
2020
2121
This class defines the interface for all model implementations in the Strands Agents SDK. It provides a
2222
standardized way to configure and process requests for different AI model providers.

src/strands/models/ollama.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,9 @@
1212
from typing_extensions import TypedDict, Unpack, override
1313

1414
from ..types.content import ContentBlock, Messages
15-
from ..types.models import Model
1615
from ..types.streaming import StopReason, StreamEvent
1716
from ..types.tools import ToolSpec
17+
from .model import Model
1818

1919
logger = logging.getLogger(__name__)
2020

@@ -296,7 +296,7 @@ async def stream(
296296
"""
297297
logger.debug("formatting request")
298298
request = self.format_request(messages, tool_specs, system_prompt)
299-
logger.debug("formatted request=<%s>", request)
299+
logger.debug("request=<%s>", request)
300300

301301
logger.debug("invoking model")
302302
tool_requested = False

0 commit comments

Comments
 (0)