|
3 | 3 | from time import sleep |
4 | 4 | import typer |
5 | 5 | from pydantic import BaseModel, Json, TypeAdapter |
| 6 | +from pydantic_core import SchemaValidator, core_schema |
6 | 7 | from typing import Annotated, Any, Callable, Dict, List, Union, Optional, Type |
7 | 8 | import json, requests |
8 | 9 |
|
|
13 | 14 | from examples.openai.prompting import ToolsPromptStyle |
14 | 15 | from examples.openai.subprocesses import spawn_subprocess |
15 | 16 |
|
16 | | -def _get_params_schema(fn: Callable[[Any], Any], verbose): |
17 | | - if isinstance(fn, OpenAPIMethod): |
18 | | - return fn.parameters_schema |
19 | | - |
20 | | - # converter = SchemaConverter(prop_order={}, allow_fetch=False, dotall=False, raw_pattern=False) |
21 | | - schema = TypeAdapter(fn).json_schema() |
22 | | - # Do NOT call converter.resolve_refs(schema) here. Let the server resolve local refs. |
23 | | - if verbose: |
24 | | - sys.stderr.write(f'# PARAMS SCHEMA: {json.dumps(schema, indent=2)}\n') |
25 | | - return schema |
| 17 | +def make_call_adapter(ta: TypeAdapter, fn: Callable[..., Any]): |
| 18 | + args_validator = SchemaValidator(core_schema.call_schema( |
| 19 | + arguments=ta.core_schema['arguments_schema'], |
| 20 | + function=fn, |
| 21 | + )) |
| 22 | + return lambda **kwargs: args_validator.validate_python(kwargs) |
26 | 23 |
|
27 | 24 | def completion_with_tool_usage( |
28 | 25 | *, |
@@ -50,18 +47,28 @@ def completion_with_tool_usage( |
50 | 47 | schema = type_adapter.json_schema() |
51 | 48 | response_format=ResponseFormat(type="json_object", schema=schema) |
52 | 49 |
|
53 | | - tool_map = {fn.__name__: fn for fn in tools} |
54 | | - tools_schemas = [ |
55 | | - Tool( |
56 | | - type="function", |
57 | | - function=ToolFunction( |
58 | | - name=fn.__name__, |
59 | | - description=fn.__doc__ or '', |
60 | | - parameters=_get_params_schema(fn, verbose=verbose) |
| 50 | + tool_map = {} |
| 51 | + tools_schemas = [] |
| 52 | + for fn in tools: |
| 53 | + if isinstance(fn, OpenAPIMethod): |
| 54 | + tool_map[fn.__name__] = fn |
| 55 | + parameters_schema = fn.parameters_schema |
| 56 | + else: |
| 57 | + ta = TypeAdapter(fn) |
| 58 | + tool_map[fn.__name__] = make_call_adapter(ta, fn) |
| 59 | + parameters_schema = ta.json_schema() |
| 60 | + if verbose: |
| 61 | + sys.stderr.write(f'# PARAMS SCHEMA ({fn.__name__}): {json.dumps(parameters_schema, indent=2)}\n') |
| 62 | + tools_schemas.append( |
| 63 | + Tool( |
| 64 | + type="function", |
| 65 | + function=ToolFunction( |
| 66 | + name=fn.__name__, |
| 67 | + description=fn.__doc__ or '', |
| 68 | + parameters=parameters_schema, |
| 69 | + ) |
61 | 70 | ) |
62 | 71 | ) |
63 | | - for fn in tools |
64 | | - ] |
65 | 72 |
|
66 | 73 | i = 0 |
67 | 74 | while (max_iterations is None or i < max_iterations): |
@@ -106,7 +113,7 @@ def completion_with_tool_usage( |
106 | 113 | sys.stdout.write(f'⚙️ {pretty_call}') |
107 | 114 | sys.stdout.flush() |
108 | 115 | tool_result = tool_map[tool_call.function.name](**tool_call.function.arguments) |
109 | | - sys.stdout.write(f" -> {tool_result}\n") |
| 116 | + sys.stdout.write(f" → {tool_result}\n") |
110 | 117 | messages.append(Message( |
111 | 118 | tool_call_id=tool_call.id, |
112 | 119 | role="tool", |
@@ -203,6 +210,8 @@ def main( |
203 | 210 | if std_tools: |
204 | 211 | tool_functions.extend(collect_functions(StandardTools)) |
205 | 212 |
|
| 213 | + sys.stdout.write(f'🛠️ {", ".join(fn.__name__ for fn in tool_functions)}\n') |
| 214 | + |
206 | 215 | response_model: Union[type, Json[Any]] = None #str |
207 | 216 | if format: |
208 | 217 | if format in types: |
|
0 commit comments