diff --git a/src/__fixtures__/agent-helpers.ts b/src/__fixtures__/agent-helpers.ts new file mode 100644 index 00000000..bc4415d1 --- /dev/null +++ b/src/__fixtures__/agent-helpers.ts @@ -0,0 +1,37 @@ +/** + * Test fixtures and helpers for Agent testing. + * This module provides utilities for testing Agent-related implementations. + */ + +import type { Agent } from '../agent/agent.js' +import type { Message } from '../types/messages.js' +import { AgentState } from '../agent/state.js' +import type { JSONValue } from '../types/json.js' + +/** + * Data for creating a mock Agent. + */ +export interface MockAgentData { + /** + * Messages for the agent. + */ + messages?: Message[] + /** + * Initial state for the agent. + */ + state?: Record +} + +/** + * Helper to create a mock Agent for testing. + * Provides minimal Agent interface with messages and state. + * + * @param data - Optional mock agent data + * @returns Mock Agent object + */ +export function createMockAgent(data?: MockAgentData): Agent { + return { + messages: data?.messages ?? [], + state: new AgentState(data?.state ?? {}), + } as unknown as Agent +} diff --git a/src/__fixtures__/tool-helpers.ts b/src/__fixtures__/tool-helpers.ts index 0fd903f0..085c1e5a 100644 --- a/src/__fixtures__/tool-helpers.ts +++ b/src/__fixtures__/tool-helpers.ts @@ -23,6 +23,7 @@ export function createMockContext( toolUse, agent: { state: new AgentState(agentState), + messages: [], }, } } diff --git a/src/agent/__tests__/agent.hook.test.ts b/src/agent/__tests__/agent.hook.test.ts index 52a5f154..eb77a825 100644 --- a/src/agent/__tests__/agent.hook.test.ts +++ b/src/agent/__tests__/agent.hook.test.ts @@ -9,6 +9,7 @@ import { BeforeToolCallEvent, MessageAddedEvent, ModelStreamEventHook, + type HookRegistry, } from '../../hooks/index.js' import { MockMessageModel } from '../../__fixtures__/mock-message-model.js' import { MockHookProvider } from '../../__fixtures__/mock-hook-provider.js' @@ -301,4 +302,37 @@ describe('Agent Hooks Integration', () => { ) }) }) + + describe('AfterModelCallEvent retryModelCall', () => { + it('retries model call when hook sets retryModelCall', async () => { + let callCount = 0 + const retryHook = { + registerCallbacks: (registry: HookRegistry) => { + registry.addCallback(AfterModelCallEvent, (event: AfterModelCallEvent) => { + callCount++ + if (callCount === 1 && event.error) { + event.retryModelCall = true + } + }) + }, + } + + const model = new MockMessageModel() + .addTurn(new Error('First attempt failed')) + .addTurn({ type: 'textBlock', text: 'Success after retry' }) + + const agent = new Agent({ model, hooks: [retryHook] }) + const result = await agent.invoke('Test') + + expect(result.lastMessage.content[0]).toEqual({ type: 'textBlock', text: 'Success after retry' }) + expect(callCount).toBe(2) + }) + + it('does not retry when retryModelCall is not set', async () => { + const model = new MockMessageModel().addTurn(new Error('Failure')) + const agent = new Agent({ model }) + + await expect(agent.invoke('Test')).rejects.toThrow('Failure') + }) + }) }) diff --git a/src/agent/agent.ts b/src/agent/agent.ts index 1af6e055..5a7b4252 100644 --- a/src/agent/agent.ts +++ b/src/agent/agent.ts @@ -12,16 +12,15 @@ import { ToolResultBlock, type ToolUseBlock, } from '../index.js' -import { normalizeError, ConcurrentInvocationError, MaxTokensError, ContextWindowOverflowError } from '../errors.js' +import { normalizeError, ConcurrentInvocationError, MaxTokensError } from '../errors.js' import type { BaseModelConfig, Model, StreamOptions } from '../models/model.js' import { ToolRegistry } from '../registry/tool-registry.js' import { AgentState } from './state.js' import type { AgentData } from '../types/agent.js' import { AgentPrinter, getDefaultAppender, type Printer } from './printer.js' -import type { ConversationManager } from '../conversation-manager/conversation-manager.js' +import type { HookProvider } from '../hooks/types.js' import { SlidingWindowConversationManager } from '../conversation-manager/sliding-window-conversation-manager.js' import { HookRegistryImplementation } from '../hooks/registry.js' -import type { HookProvider } from '../hooks/types.js' import { AfterInvocationEvent, AfterModelCallEvent, @@ -74,7 +73,7 @@ export type AgentConfig = { * Conversation manager for handling message history and context overflow. * Defaults to SlidingWindowConversationManager with windowSize of 40. */ - conversationManager?: ConversationManager + conversationManager?: HookProvider /** * Hook providers to register with the agent. * Hooks enable observing and extending agent behavior. @@ -107,7 +106,7 @@ export class Agent implements AgentData { /** * Conversation manager for handling message history and context overflow. */ - public readonly conversationManager: ConversationManager + public readonly conversationManager: HookProvider private _isInvoking: boolean = false private _printer?: Printer @@ -140,10 +139,12 @@ export class Agent implements AgentData { this.state = new AgentState(config?.state) + // Initialize conversation manager this.conversationManager = config?.conversationManager ?? new SlidingWindowConversationManager({ windowSize: 40 }) - // Initialize hooks + // Initialize hooks and register conversation manager hooks this.hooks = new HookRegistryImplementation() + this.hooks.addHook(this.conversationManager) this.hooks.addAllHooks(config?.hooks ?? []) // Create printer if printer is enabled (default: true) @@ -283,8 +284,6 @@ export class Agent implements AgentData { // Continue loop } } finally { - this.conversationManager.applyManagement(this) - // Invoke AfterInvocationEvent hook await this.hooks.invokeCallbacks(new AfterInvocationEvent({ agent: this })) @@ -362,14 +361,14 @@ export class Agent implements AgentData { const modelError = normalizeError(error) // Invoke AfterModelCallEvent hook even on error - await this.hooks.invokeCallbacks(new AfterModelCallEvent({ agent: this, error: modelError })) + const event = await this.hooks.invokeCallbacks(new AfterModelCallEvent({ agent: this, error: modelError })) - if (error instanceof ContextWindowOverflowError) { - // Reduce context and retry - this.conversationManager.reduceContext(this, error) + // Check if hooks request a retry (e.g., after reducing context) + if (event.retryModelCall) { return yield* this.invokeModel(args) } - // Re-throw other errors + + // Re-throw error throw error } } diff --git a/src/conversation-manager/__tests__/conversation-manager.test.ts b/src/conversation-manager/__tests__/conversation-manager.test.ts deleted file mode 100644 index a43d904e..00000000 --- a/src/conversation-manager/__tests__/conversation-manager.test.ts +++ /dev/null @@ -1,11 +0,0 @@ -import { describe, it, expect } from 'vitest' -import { ConversationManager } from '../conversation-manager.js' - -describe('ConversationManager', () => { - // ConversationManager is an abstract base class - // Specific implementations are tested in their own test files - - it('is an abstract class', () => { - expect(ConversationManager).toBeDefined() - }) -}) diff --git a/src/conversation-manager/__tests__/null-conversation-manager.test.ts b/src/conversation-manager/__tests__/null-conversation-manager.test.ts index 6681c599..a79a1ce6 100644 --- a/src/conversation-manager/__tests__/null-conversation-manager.test.ts +++ b/src/conversation-manager/__tests__/null-conversation-manager.test.ts @@ -1,53 +1,42 @@ import { describe, it, expect } from 'vitest' import { NullConversationManager } from '../null-conversation-manager.js' -import { ContextWindowOverflowError, Message, TextBlock } from '../../index.js' -import type { Agent } from '../../agent/agent.js' +import { Message, TextBlock } from '../../index.js' +import { HookRegistryImplementation } from '../../hooks/registry.js' +import { AfterInvocationEvent, AfterModelCallEvent } from '../../hooks/events.js' +import { ContextWindowOverflowError } from '../../errors.js' +import { createMockAgent } from '../../__fixtures__/agent-helpers.js' describe('NullConversationManager', () => { - describe('applyManagement', () => { - it('does not modify messages array', () => { + describe('behavior', () => { + it('does not modify conversation history', async () => { const manager = new NullConversationManager() const messages = [ new Message({ role: 'user', content: [new TextBlock('Hello')] }), new Message({ role: 'assistant', content: [new TextBlock('Hi there')] }), ] - const mockAgent = { messages } as unknown as Agent + const mockAgent = createMockAgent({ messages }) - manager.applyManagement(mockAgent) + const registry = new HookRegistryImplementation() + manager.registerCallbacks(registry) + + await registry.invokeCallbacks(new AfterInvocationEvent({ agent: mockAgent })) expect(mockAgent.messages).toHaveLength(2) expect(mockAgent.messages[0]!.content[0]).toEqual({ type: 'textBlock', text: 'Hello' }) expect(mockAgent.messages[1]!.content[0]).toEqual({ type: 'textBlock', text: 'Hi there' }) }) - }) - - describe('reduceContext', () => { - it('re-throws provided error', () => { - const manager = new NullConversationManager() - const mockAgent = { messages: [] } as unknown as Agent - const testError = new Error('Test error') - - expect(() => { - manager.reduceContext(mockAgent, testError) - }).toThrow(testError) - }) - it('throws ContextWindowOverflowError when no error provided', () => { + it('does not set retryModelCall on context overflow', async () => { const manager = new NullConversationManager() - const mockAgent = { messages: [] } as unknown as Agent + const mockAgent = createMockAgent() + const error = new ContextWindowOverflowError('Context overflow') - expect(() => { - manager.reduceContext(mockAgent) - }).toThrow(ContextWindowOverflowError) - }) + const registry = new HookRegistryImplementation() + manager.registerCallbacks(registry) - it('throws ContextWindowOverflowError with correct message when no error provided', () => { - const manager = new NullConversationManager() - const mockAgent = { messages: [] } as unknown as Agent + const event = await registry.invokeCallbacks(new AfterModelCallEvent({ agent: mockAgent, error })) - expect(() => { - manager.reduceContext(mockAgent) - }).toThrow('Context window overflowed!') + expect(event.retryModelCall).toBeUndefined() }) }) }) diff --git a/src/conversation-manager/__tests__/sliding-window-conversation-manager.test.ts b/src/conversation-manager/__tests__/sliding-window-conversation-manager.test.ts index 5ab0baae..ee4a2e9f 100644 --- a/src/conversation-manager/__tests__/sliding-window-conversation-manager.test.ts +++ b/src/conversation-manager/__tests__/sliding-window-conversation-manager.test.ts @@ -1,8 +1,29 @@ import { describe, it, expect } from 'vitest' import { SlidingWindowConversationManager } from '../sliding-window-conversation-manager.js' import { ContextWindowOverflowError, Message, TextBlock, ToolUseBlock, ToolResultBlock } from '../../index.js' +import { HookRegistryImplementation } from '../../hooks/registry.js' +import { AfterInvocationEvent, AfterModelCallEvent } from '../../hooks/events.js' +import { createMockAgent } from '../../__fixtures__/agent-helpers.js' import type { Agent } from '../../agent/agent.js' +// Helper to trigger sliding window management through hooks +async function triggerSlidingWindow(manager: SlidingWindowConversationManager, agent: Agent): Promise { + const registry = new HookRegistryImplementation() + registry.addHook(manager) + await registry.invokeCallbacks(new AfterInvocationEvent({ agent })) +} + +// Helper to trigger context overflow handling through hooks +async function triggerContextOverflow( + manager: SlidingWindowConversationManager, + agent: Agent, + error: Error +): Promise<{ retryModelCall?: boolean }> { + const registry = new HookRegistryImplementation() + registry.addHook(manager) + return await registry.invokeCallbacks(new AfterModelCallEvent({ agent, error })) +} + describe('SlidingWindowConversationManager', () => { describe('constructor', () => { it('sets default windowSize to 40', () => { @@ -28,42 +49,42 @@ describe('SlidingWindowConversationManager', () => { }) describe('applyManagement', () => { - it('skips reduction when message count is less than window size', () => { + it('skips reduction when message count is less than window size', async () => { const manager = new SlidingWindowConversationManager({ windowSize: 10 }) const messages = [ new Message({ role: 'user', content: [new TextBlock('Message 1')] }), new Message({ role: 'assistant', content: [new TextBlock('Response 1')] }), ] - const mockAgent = { messages } as Agent + const mockAgent = createMockAgent({ messages }) - manager.applyManagement(mockAgent) + await triggerSlidingWindow(manager, mockAgent) expect(mockAgent.messages).toHaveLength(2) }) - it('skips reduction when message count equals window size', () => { + it('skips reduction when message count equals window size', async () => { const manager = new SlidingWindowConversationManager({ windowSize: 2 }) const messages = [ new Message({ role: 'user', content: [new TextBlock('Message 1')] }), new Message({ role: 'assistant', content: [new TextBlock('Response 1')] }), ] - const mockAgent = { messages } as Agent + const mockAgent = createMockAgent({ messages }) - manager.applyManagement(mockAgent) + await triggerSlidingWindow(manager, mockAgent) expect(mockAgent.messages).toHaveLength(2) }) - it('calls reduceContext when message count exceeds window size', () => { + it('calls reduceContext when message count exceeds window size', async () => { const manager = new SlidingWindowConversationManager({ windowSize: 2 }) const messages = [ new Message({ role: 'user', content: [new TextBlock('Message 1')] }), new Message({ role: 'assistant', content: [new TextBlock('Response 1')] }), new Message({ role: 'user', content: [new TextBlock('Message 2')] }), ] - const mockAgent = { messages } as Agent + const mockAgent = createMockAgent({ messages }) - manager.applyManagement(mockAgent) + await triggerSlidingWindow(manager, mockAgent) // Should have trimmed to window size expect(mockAgent.messages).toHaveLength(2) @@ -71,7 +92,7 @@ describe('SlidingWindowConversationManager', () => { }) describe('reduceContext - tool result truncation', () => { - it('truncates tool results when shouldTruncateResults is true', () => { + it('truncates tool results when shouldTruncateResults is true', async () => { const manager = new SlidingWindowConversationManager({ shouldTruncateResults: true }) const messages = [ new Message({ @@ -85,16 +106,16 @@ describe('SlidingWindowConversationManager', () => { ], }), ] - const mockAgent = { messages } as Agent + const mockAgent = createMockAgent({ messages }) - manager.reduceContext(mockAgent) + await triggerContextOverflow(manager, mockAgent, new ContextWindowOverflowError('Context overflow')) const toolResult = messages[0]!.content[0]! as ToolResultBlock expect(toolResult.status).toBe('error') expect(toolResult.content[0]).toEqual({ type: 'textBlock', text: 'The tool result was too large!' }) }) - it('finds last message with tool results', () => { + it('finds last message with tool results', async () => { const manager = new SlidingWindowConversationManager({ shouldTruncateResults: true }) const messages = [ new Message({ role: 'user', content: [new TextBlock('Message 1')] }), @@ -120,9 +141,9 @@ describe('SlidingWindowConversationManager', () => { ], }), ] - const mockAgent = { messages } as Agent + const mockAgent = createMockAgent({ messages }) - manager.reduceContext(mockAgent) + await triggerContextOverflow(manager, mockAgent, new ContextWindowOverflowError('Context overflow')) // Should truncate the last message with tool results (index 3) const lastToolResult = messages[3]!.content[0]! as ToolResultBlock @@ -135,7 +156,7 @@ describe('SlidingWindowConversationManager', () => { expect(firstToolResult.content[0]).toEqual({ type: 'textBlock', text: 'First result' }) }) - it('returns after successful truncation without trimming messages', () => { + it('returns after successful truncation without trimming messages', async () => { const manager = new SlidingWindowConversationManager({ windowSize: 2, shouldTruncateResults: true }) const messages = [ new Message({ role: 'user', content: [new TextBlock('Message 1')] }), @@ -151,15 +172,15 @@ describe('SlidingWindowConversationManager', () => { ], }), ] - const mockAgent = { messages } as Agent + const mockAgent = createMockAgent({ messages }) - manager.reduceContext(mockAgent) + await triggerContextOverflow(manager, mockAgent, new ContextWindowOverflowError('Context overflow')) // Should not have removed any messages, only truncated tool result expect(mockAgent.messages).toHaveLength(3) }) - it('skips truncation when shouldTruncateResults is false', () => { + it('skips truncation when shouldTruncateResults is false', async () => { const manager = new SlidingWindowConversationManager({ windowSize: 2, shouldTruncateResults: false }) const messages = [ new Message({ role: 'user', content: [new TextBlock('Message 1')] }), @@ -175,9 +196,9 @@ describe('SlidingWindowConversationManager', () => { ], }), ] - const mockAgent = { messages } as Agent + const mockAgent = createMockAgent({ messages }) - manager.reduceContext(mockAgent) + await triggerContextOverflow(manager, mockAgent, new ContextWindowOverflowError('Context overflow')) // Should have trimmed messages instead of truncating tool result expect(mockAgent.messages).toHaveLength(2) @@ -187,7 +208,7 @@ describe('SlidingWindowConversationManager', () => { expect(toolResult.status).toBe('success') }) - it('does not truncate already-truncated results', () => { + it('does not truncate already-truncated results', async () => { const manager = new SlidingWindowConversationManager({ shouldTruncateResults: true }) const messages = [ new Message({ @@ -223,7 +244,7 @@ describe('SlidingWindowConversationManager', () => { ] const mockAgent = { messages: messages2 } as unknown as Agent - manager.reduceContext(mockAgent) + await triggerContextOverflow(manager, mockAgent, new ContextWindowOverflowError('Context overflow')) // Should have trimmed messages since truncation was skipped expect(mockAgent.messages.length).toBeLessThan(3) @@ -231,7 +252,7 @@ describe('SlidingWindowConversationManager', () => { }) describe('reduceContext - message trimming', () => { - it('trims oldest messages when tool results cannot be truncated', () => { + it('trims oldest messages when tool results cannot be truncated', async () => { const manager = new SlidingWindowConversationManager({ windowSize: 3, shouldTruncateResults: false }) const messages = [ new Message({ role: 'user', content: [new TextBlock('Message 1')] }), @@ -240,15 +261,15 @@ describe('SlidingWindowConversationManager', () => { new Message({ role: 'assistant', content: [new TextBlock('Response 2')] }), new Message({ role: 'user', content: [new TextBlock('Message 3')] }), ] - const mockAgent = { messages } as Agent + const mockAgent = createMockAgent({ messages }) - manager.reduceContext(mockAgent) + await triggerContextOverflow(manager, mockAgent, new ContextWindowOverflowError('Context overflow')) expect(mockAgent.messages).toHaveLength(3) expect(mockAgent.messages[0]!.content[0]!).toEqual({ type: 'textBlock', text: 'Message 2' }) }) - it('calculates correct trim index (messages.length - windowSize)', () => { + it('calculates correct trim index (messages.length - windowSize)', async () => { const manager = new SlidingWindowConversationManager({ windowSize: 2 }) const messages = [ new Message({ role: 'user', content: [new TextBlock('Message 1')] }), @@ -256,30 +277,30 @@ describe('SlidingWindowConversationManager', () => { new Message({ role: 'user', content: [new TextBlock('Message 2')] }), new Message({ role: 'assistant', content: [new TextBlock('Response 2')] }), ] - const mockAgent = { messages } as Agent + const mockAgent = createMockAgent({ messages }) - manager.reduceContext(mockAgent) + await triggerContextOverflow(manager, mockAgent, new ContextWindowOverflowError('Context overflow')) // Should remove 2 messages (4 - 2 = 2) expect(mockAgent.messages).toHaveLength(2) }) - it('uses default trim index of 2 when messages <= windowSize', () => { + it('uses default trim index of 2 when messages <= windowSize', async () => { const manager = new SlidingWindowConversationManager({ windowSize: 5 }) const messages = [ new Message({ role: 'user', content: [new TextBlock('Message 1')] }), new Message({ role: 'assistant', content: [new TextBlock('Response 1')] }), new Message({ role: 'user', content: [new TextBlock('Message 2')] }), ] - const mockAgent = { messages } as Agent + const mockAgent = createMockAgent({ messages }) - manager.reduceContext(mockAgent) + await triggerContextOverflow(manager, mockAgent, new ContextWindowOverflowError('Context overflow')) // Should remove 2 messages (default when count <= windowSize) expect(mockAgent.messages).toHaveLength(1) }) - it('removes messages from start of array using splice', () => { + it('removes messages from start of array using splice', async () => { const manager = new SlidingWindowConversationManager({ windowSize: 2 }) const messages = [ new Message({ role: 'user', content: [new TextBlock('Message 1')] }), @@ -287,9 +308,9 @@ describe('SlidingWindowConversationManager', () => { new Message({ role: 'user', content: [new TextBlock('Message 2')] }), new Message({ role: 'assistant', content: [new TextBlock('Response 2')] }), ] - const mockAgent = { messages } as Agent + const mockAgent = createMockAgent({ messages }) - manager.reduceContext(mockAgent) + await triggerContextOverflow(manager, mockAgent, new ContextWindowOverflowError('Context overflow')) // Should keep last 2 messages expect(mockAgent.messages).toHaveLength(2) @@ -299,7 +320,7 @@ describe('SlidingWindowConversationManager', () => { }) describe('reduceContext - tool pair validation', () => { - it('does not trim at index where oldest message is toolResult', () => { + it('does not trim at index where oldest message is toolResult', async () => { const manager = new SlidingWindowConversationManager({ windowSize: 2, shouldTruncateResults: false }) const messages = [ new Message({ @@ -319,9 +340,9 @@ describe('SlidingWindowConversationManager', () => { new Message({ role: 'assistant', content: [new TextBlock('Response')] }), new Message({ role: 'user', content: [new TextBlock('Message')] }), ] - const mockAgent = { messages } as Agent + const mockAgent = createMockAgent({ messages }) - manager.reduceContext(mockAgent) + await triggerContextOverflow(manager, mockAgent, new ContextWindowOverflowError('Context overflow')) // Should not trim at index 1 (toolResult), should trim at index 2 instead // This means keeping last 2 messages @@ -329,7 +350,7 @@ describe('SlidingWindowConversationManager', () => { expect(mockAgent.messages[0]!.content[0]!).toEqual({ type: 'textBlock', text: 'Response' }) }) - it('does not trim at index where oldest message is toolUse without following toolResult', () => { + it('does not trim at index where oldest message is toolUse without following toolResult', async () => { const manager = new SlidingWindowConversationManager({ windowSize: 2, shouldTruncateResults: false }) const messages = [ new Message({ role: 'user', content: [new TextBlock('Message 1')] }), @@ -340,16 +361,16 @@ describe('SlidingWindowConversationManager', () => { new Message({ role: 'assistant', content: [new TextBlock('Response')] }), // Not a toolResult new Message({ role: 'user', content: [new TextBlock('Message 2')] }), ] - const mockAgent = { messages } as Agent + const mockAgent = createMockAgent({ messages }) - manager.reduceContext(mockAgent) + await triggerContextOverflow(manager, mockAgent, new ContextWindowOverflowError('Context overflow')) // Should skip index 1 (toolUse without following toolResult), trim at index 2 expect(mockAgent.messages).toHaveLength(2) expect(mockAgent.messages[0]!.content[0]!).toEqual({ type: 'textBlock', text: 'Response' }) }) - it('allows trim when oldest message is toolUse with following toolResult', () => { + it('allows trim when oldest message is toolUse with following toolResult', async () => { const manager = new SlidingWindowConversationManager({ windowSize: 2, shouldTruncateResults: false }) const messages = [ new Message({ role: 'user', content: [new TextBlock('Message 1')] }), @@ -370,9 +391,9 @@ describe('SlidingWindowConversationManager', () => { new Message({ role: 'assistant', content: [new TextBlock('Response')] }), new Message({ role: 'user', content: [new TextBlock('Message 2')] }), ] - const mockAgent = { messages } as Agent + const mockAgent = createMockAgent({ messages }) - manager.reduceContext(mockAgent) + await triggerContextOverflow(manager, mockAgent, new ContextWindowOverflowError('Context overflow')) // Should trim at index 3 (5 - 2 = 3) // Index 1 would be toolUse (valid start since toolResult follows) @@ -384,7 +405,7 @@ describe('SlidingWindowConversationManager', () => { expect(mockAgent.messages[1]!.content[0]!).toEqual({ type: 'textBlock', text: 'Message 2' }) }) - it('allows trim at toolUse when toolResult immediately follows', () => { + it('allows trim at toolUse when toolResult immediately follows', async () => { const manager = new SlidingWindowConversationManager({ windowSize: 3, shouldTruncateResults: false }) const messages = [ new Message({ role: 'user', content: [new TextBlock('Message 1')] }), @@ -405,9 +426,9 @@ describe('SlidingWindowConversationManager', () => { }), new Message({ role: 'assistant', content: [new TextBlock('Response 2')] }), ] - const mockAgent = { messages } as Agent + const mockAgent = createMockAgent({ messages }) - manager.reduceContext(mockAgent) + await triggerContextOverflow(manager, mockAgent, new ContextWindowOverflowError('Context overflow')) // Should trim at index 2 (5 - 3 = 2) // Index 2 is toolUse with toolResult at index 3 - this is valid @@ -426,23 +447,23 @@ describe('SlidingWindowConversationManager', () => { }) }) - it('allows trim when oldest message is text or other non-tool content', () => { + it('allows trim when oldest message is text or other non-tool content', async () => { const manager = new SlidingWindowConversationManager({ windowSize: 2 }) const messages = [ new Message({ role: 'user', content: [new TextBlock('Message 1')] }), new Message({ role: 'assistant', content: [new TextBlock('Response 1')] }), new Message({ role: 'user', content: [new TextBlock('Message 2')] }), ] - const mockAgent = { messages } as Agent + const mockAgent = createMockAgent({ messages }) - manager.reduceContext(mockAgent) + await triggerContextOverflow(manager, mockAgent, new ContextWindowOverflowError('Context overflow')) // Should trim at index 1 (3 - 2 = 1) expect(mockAgent.messages).toHaveLength(2) expect(mockAgent.messages[0]!.content[0]).toEqual({ type: 'textBlock', text: 'Response 1' }) }) - it('throws ContextWindowOverflowError when no valid trim point exists', () => { + it('throws ContextWindowOverflowError when no valid trim point exists', async () => { const manager = new SlidingWindowConversationManager({ windowSize: 0, shouldTruncateResults: false }) const messages = [ new Message({ @@ -456,11 +477,11 @@ describe('SlidingWindowConversationManager', () => { ], }), ] - const mockAgent = { messages } as Agent + const mockAgent = createMockAgent({ messages }) - expect(() => { - manager.reduceContext(mockAgent) - }).toThrow(ContextWindowOverflowError) + await expect( + triggerContextOverflow(manager, mockAgent, new ContextWindowOverflowError('Context overflow')) + ).rejects.toThrow(ContextWindowOverflowError) }) }) diff --git a/src/conversation-manager/conversation-manager.ts b/src/conversation-manager/conversation-manager.ts deleted file mode 100644 index beba1667..00000000 --- a/src/conversation-manager/conversation-manager.ts +++ /dev/null @@ -1,77 +0,0 @@ -/** - * Abstract interface for conversation history management. - * - * This module provides the base class for implementing conversation management strategies - * to control the size of message arrays, helping to manage memory usage, control context - * length, and maintain relevant conversation state. - */ - -import type { Message } from '../types/messages.js' - -/** - * Interface for conversation context that can be managed. - * - * This interface defines the minimal set of properties required by conversation managers - * to perform their operations. Using an interface allows for backwards-compatible - * API evolution and better decoupling from specific implementations. - */ -export interface ConversationContext { - /** - * The conversation history of messages that will be managed. - * This array is modified in-place by conversation management operations. - */ - messages: Message[] -} - -/** - * Abstract base class for managing conversation history. - * - * This class provides an interface for implementing conversation management strategies - * to control the size of message arrays/conversation histories, helping to: - * - * - Manage memory usage - * - Control context length - * - Maintain relevant conversation state - */ -export abstract class ConversationManager { - /** - * Creates a new ConversationManager instance. - */ - constructor() {} - - /** - * Applies management strategy to the provided conversation context. - * - * Processes the conversation history to maintain appropriate size by modifying - * the messages list in-place. Implementations should handle message pruning, - * summarization, or other size management techniques to keep the conversation - * context within desired bounds. - * - * @param context - The conversation context whose message history will be managed. - * The messages array is modified in-place. - */ - public abstract applyManagement(context: ConversationContext): void - - /** - * Called when the model's context window is exceeded. - * - * This method should implement the specific strategy for reducing the window size - * when a context overflow occurs. It is typically called after a ContextWindowOverflowError - * is caught during model invocation. - * - * Implementations might use strategies such as: - * - Removing the N oldest messages - * - Summarizing older context - * - Applying importance-based filtering - * - Maintaining critical conversation markers - * - * @param context - The conversation context whose message history will be reduced. - * The messages array is modified in-place. - * @param error - The error that triggered the context reduction, if any. - * - * @throws ContextWindowOverflowError If the context cannot be reduced further, - * such as when the conversation is already minimal or when tool result - * messages cannot be properly converted. - */ - public abstract reduceContext(context: ConversationContext, error?: Error): void -} diff --git a/src/conversation-manager/index.ts b/src/conversation-manager/index.ts index d0ff0586..b702c02f 100644 --- a/src/conversation-manager/index.ts +++ b/src/conversation-manager/index.ts @@ -1,10 +1,9 @@ /** * Conversation Manager exports. * - * This module exports all conversation manager implementations and types. + * This module exports conversation manager implementations. */ -export { ConversationManager, type ConversationContext } from './conversation-manager.js' export { NullConversationManager } from './null-conversation-manager.js' export { SlidingWindowConversationManager, diff --git a/src/conversation-manager/null-conversation-manager.ts b/src/conversation-manager/null-conversation-manager.ts index 12f65dff..a853b1a7 100644 --- a/src/conversation-manager/null-conversation-manager.ts +++ b/src/conversation-manager/null-conversation-manager.ts @@ -2,45 +2,25 @@ * Null implementation of conversation management. * * This module provides a no-op conversation manager that does not modify - * the conversation history, useful for testing and scenarios where conversation + * the conversation history. Useful for testing and scenarios where conversation * management is handled externally. */ -import { ContextWindowOverflowError } from '../errors.js' -import { ConversationManager, type ConversationContext } from './conversation-manager.js' +import type { HookProvider } from '../hooks/types.js' +import type { HookRegistry } from '../hooks/registry.js' /** * A no-op conversation manager that does not modify the conversation history. - * + * Implements HookProvider but registers zero hooks. */ -export class NullConversationManager extends ConversationManager { - /** - * Does nothing to the conversation history. - * - * @param _context - The conversation context whose message history will remain unmodified. - */ - public applyManagement(_context: ConversationContext): void { - // No-op - } - +export class NullConversationManager implements HookProvider { /** - * Does not reduce context and raises an exception. - * - * If an error is provided, re-throws it. Otherwise, throws a new - * ContextWindowOverflowError indicating that the context window has - * overflowed and cannot be reduced. - * - * @param _context - The conversation context whose message history will remain unmodified. - * @param error - The error that triggered the context reduction, if any. + * Registers callbacks with the hook registry. + * This implementation registers no hooks, providing a complete no-op behavior. * - * @throws Error The provided error if one was given. - * @throws ContextWindowOverflowError If no error was provided. + * @param _registry - The hook registry to register callbacks with (unused) */ - public reduceContext(_context: ConversationContext, error?: Error): void { - if (error) { - throw error - } else { - throw new ContextWindowOverflowError('Context window overflowed!') - } + public registerCallbacks(_registry: HookRegistry): void { + // No-op - register zero hooks } } diff --git a/src/conversation-manager/sliding-window-conversation-manager.ts b/src/conversation-manager/sliding-window-conversation-manager.ts index 0f4a519f..081900f7 100644 --- a/src/conversation-manager/sliding-window-conversation-manager.ts +++ b/src/conversation-manager/sliding-window-conversation-manager.ts @@ -7,7 +7,9 @@ import { ContextWindowOverflowError } from '../errors.js' import { Message, TextBlock, ToolResultBlock } from '../types/messages.js' -import { ConversationManager, type ConversationContext } from './conversation-manager.js' +import type { HookProvider } from '../hooks/types.js' +import type { HookRegistry } from '../hooks/registry.js' +import { AfterInvocationEvent, AfterModelCallEvent } from '../hooks/events.js' /** * Configuration for the sliding window conversation manager. @@ -33,8 +35,12 @@ export type SlidingWindowConversationManagerConfig = { * tool usage pairs and avoids invalid window states. When the message count exceeds * the window size, it will either truncate large tool results or remove the oldest * messages while ensuring tool use/result pairs remain valid. + * + * As a HookProvider, it registers callbacks for: + * - AfterInvocationEvent: Applies sliding window management after each invocation + * - AfterModelCallEvent: Reduces context on overflow errors and requests retry */ -export class SlidingWindowConversationManager extends ConversationManager { +export class SlidingWindowConversationManager implements HookProvider { private readonly _windowSize: number private readonly _shouldTruncateResults: boolean @@ -44,28 +50,49 @@ export class SlidingWindowConversationManager extends ConversationManager { * @param config - Configuration options for the sliding window manager. */ constructor(config?: SlidingWindowConversationManagerConfig) { - super() this._windowSize = config?.windowSize ?? 40 this._shouldTruncateResults = config?.shouldTruncateResults ?? true } /** - * Apply the sliding window to the conversation context's messages array to maintain a manageable history size. + * Registers callbacks with the hook registry. + * + * Registers: + * - AfterInvocationEvent callback to apply sliding window management + * - AfterModelCallEvent callback to handle context overflow and request retry + * + * @param registry - The hook registry to register callbacks with + */ + public registerCallbacks(registry: HookRegistry): void { + // Apply sliding window management after each invocation + registry.addCallback(AfterInvocationEvent, (event) => { + this.applyManagement(event.agent.messages) + }) + + // Handle context overflow errors + registry.addCallback(AfterModelCallEvent, (event) => { + if (event.error instanceof ContextWindowOverflowError) { + this.reduceContext(event.agent.messages, event.error) + event.retryModelCall = true + } + }) + } + + /** + * Apply the sliding window to the messages array to maintain a manageable history size. * * This method is called after every event loop cycle to apply a sliding window if the message * count exceeds the window size. If the number of messages is within the window size, no action * is taken. * - * @param context - The conversation context whose messages will be managed. The messages array is modified in-place. + * @param messages - The message array to manage. Modified in-place. */ - public applyManagement(context: ConversationContext): void { - const messages = context.messages - + private applyManagement(messages: Message[]): void { if (messages.length <= this._windowSize) { return } - this.reduceContext(context) + this.reduceContext(messages) } /** @@ -80,15 +107,13 @@ export class SlidingWindowConversationManager extends ConversationManager { * 2. If truncation is not possible or doesn't help, trim oldest messages * 3. When trimming, skip invalid trim points (toolResult at start, or toolUse without following toolResult) * - * @param context - The conversation context whose messages will be reduced. The messages array is modified in-place. + * @param messages - The message array to reduce. Modified in-place. * @param _error - The error that triggered the context reduction, if any. * * @throws ContextWindowOverflowError If the context cannot be reduced further, * such as when the conversation is already minimal or when no valid trim point exists. */ - public reduceContext(context: ConversationContext, _error?: Error): void { - const messages = context.messages - + private reduceContext(messages: Message[], _error?: Error): void { // Try to truncate the tool result first const lastMessageIdxWithToolResults = this.findLastMessageWithToolResults(messages) if (lastMessageIdxWithToolResults !== undefined && this._shouldTruncateResults) { diff --git a/src/hooks/__tests__/events.test.ts b/src/hooks/__tests__/events.test.ts index acc64fd7..132456d2 100644 --- a/src/hooks/__tests__/events.test.ts +++ b/src/hooks/__tests__/events.test.ts @@ -268,6 +268,31 @@ describe('AfterModelCallEvent', () => { const event = new AfterModelCallEvent({ agent, stopData: response }) expect(event._shouldReverseCallbacks()).toBe(true) }) + + it('allows retryModelCall to be set when error is present', () => { + const agent = new Agent() + const error = new Error('Model failed') + const event = new AfterModelCallEvent({ agent, error }) + + // Initially undefined + expect(event.retryModelCall).toBeUndefined() + + // Can be set to true + event.retryModelCall = true + expect(event.retryModelCall).toBe(true) + + // Can be set to false + event.retryModelCall = false + expect(event.retryModelCall).toBe(false) + }) + + it('retryModelCall is optional and defaults to undefined', () => { + const agent = new Agent() + const error = new Error('Model failed') + const event = new AfterModelCallEvent({ agent, error }) + + expect(event.retryModelCall).toBeUndefined() + }) }) describe('ModelStreamEventHook', () => { diff --git a/src/hooks/events.ts b/src/hooks/events.ts index 9aef9e56..f482cc59 100644 --- a/src/hooks/events.ts +++ b/src/hooks/events.ts @@ -175,6 +175,13 @@ export class AfterModelCallEvent extends HookEvent { readonly stopData?: ModelStopData readonly error?: Error + /** + * Optional flag that can be set by hook callbacks to request a retry of the model call. + * Only valid when an error is present. When set to true, the agent will retry the model invocation. + * Typically used after reducing context size in response to a ContextWindowOverflowError. + */ + retryModelCall?: boolean + constructor(data: { agent: AgentData; stopData?: ModelStopData; error?: Error }) { super() this.agent = data.agent diff --git a/src/index.ts b/src/index.ts index fff11141..e364f63c 100644 --- a/src/index.ts +++ b/src/index.ts @@ -119,7 +119,6 @@ export { export type { HookCallback, HookProvider, HookEventConstructor, ModelStopResponse } from './hooks/index.js' // Conversation Manager -export { ConversationManager } from './conversation-manager/conversation-manager.js' export { NullConversationManager } from './conversation-manager/null-conversation-manager.js' export { SlidingWindowConversationManager, diff --git a/src/types/agent.ts b/src/types/agent.ts index 9c3e99a6..147a53a4 100644 --- a/src/types/agent.ts +++ b/src/types/agent.ts @@ -10,6 +10,11 @@ export interface AgentData { * Agent state storage accessible to tools and application logic. */ state: AgentState + + /** + * The conversation history of messages between user and assistant. + */ + messages: Message[] } /** diff --git a/vended_tools/bash/__tests__/bash.test.ts b/vended_tools/bash/__tests__/bash.test.ts index ffd6a3ef..e8f40c04 100644 --- a/vended_tools/bash/__tests__/bash.test.ts +++ b/vended_tools/bash/__tests__/bash.test.ts @@ -16,7 +16,7 @@ describe.skipIf(!isNode || process.platform === 'win32')('bash tool', () => { toolUseId: 'test-id', input: {}, }, - agent: { state }, + agent: { state, messages: [] }, } return { state, context } } diff --git a/vended_tools/file_editor/__tests__/file-editor.test.ts b/vended_tools/file_editor/__tests__/file-editor.test.ts index f2443d59..a2b1f50c 100644 --- a/vended_tools/file_editor/__tests__/file-editor.test.ts +++ b/vended_tools/file_editor/__tests__/file-editor.test.ts @@ -19,7 +19,7 @@ describe('fileEditor tool', () => { toolUseId: 'test-id', input: {}, }, - agent: { state: agentState }, + agent: { state: agentState, messages: [] }, } return { state: agentState, context: toolContext } } diff --git a/vended_tools/notebook/__tests__/notebook.test.ts b/vended_tools/notebook/__tests__/notebook.test.ts index 813bdd27..a5273c69 100644 --- a/vended_tools/notebook/__tests__/notebook.test.ts +++ b/vended_tools/notebook/__tests__/notebook.test.ts @@ -14,7 +14,7 @@ describe('notebook tool', () => { toolUseId: 'test-id', input: {}, }, - agent: { state }, + agent: { state, messages: [] }, } return { state, context } }