Skip to content

clean_docstring #48

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Feb 10, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 23 additions & 1 deletion workflowai/core/client/_fn_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,28 @@ async def __call__(self, input: AgentInput, **kwargs: Unpack[RunParams[AgentOutp
yield chunk.output


def clean_docstring(docstring: Optional[str]) -> str:
"""Clean a docstring by removing empty lines at start/end and normalizing indentation."""
if not docstring:
return ""

# Split into lines and remove empty lines at start/end
lines = [line.rstrip() for line in docstring.split("\n")]
while lines and not lines[0].strip():
lines.pop(0)
while lines and not lines[-1].strip():
lines.pop()

if not lines:
return ""

# Find and remove common indentation
indent = min(len(line) - len(line.lstrip()) for line in lines if line.strip())
lines = [line[indent:] if line.strip() else "" for line in lines]

return "\n".join(lines)


def wrap_run_template(
client: Callable[[], APIClient],
agent_id: str,
Expand All @@ -165,7 +187,7 @@ def wrap_run_template(

if not version and (fn.__doc__ or model):
version = VersionProperties(
instructions=fn.__doc__,
instructions=clean_docstring(fn.__doc__),
model=model,
)

Expand Down
62 changes: 61 additions & 1 deletion workflowai/core/client/_fn_utils_test.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import AsyncIterator
from typing import AsyncIterator, Union
from unittest.mock import Mock

import pytest
Expand All @@ -13,6 +13,7 @@
_RunnableStreamAgent, # pyright: ignore [reportPrivateUsage]
_RunnableStreamOutputOnlyAgent, # pyright: ignore [reportPrivateUsage]
agent_wrapper,
clean_docstring,
extract_fn_spec,
get_generic_args,
is_async_iterator,
Expand Down Expand Up @@ -113,3 +114,62 @@ async def test_fn_stream_output_only(self, mock_api_client: Mock):
assert len(chunks) == 1
assert isinstance(chunks[0], HelloTaskOutput)
assert chunks[0] == HelloTaskOutput(message="Hello, World!")


@pytest.mark.parametrize(
("value", "expected"),
[
# Empty docstrings
("", ""),
(None, ""),

# Single line docstrings
("Hello world", "Hello world"),
(" Hello world ", "Hello world"),

# Docstring with empty lines at start/end
("""

Hello world

""", "Hello world"),

# Multi-line docstring with indentation
("""
First line
Second line
Indented line
Last line
""", "First line\nSecond line\n Indented line\nLast line"),

# Docstring with empty lines in between
("""
First line

Second line

Third line
""", "First line\n\nSecond line\n\nThird line"),

# Real-world example
("""
Find the capital city of the country where the input city is located.

Guidelines:
1. First identify the country where the input city is located
2. Then provide the capital city of that country
3. Include an interesting historical or cultural fact about the capital
4. Be accurate and precise with geographical information
5. If the input city is itself the capital, still provide the information
""",
"Find the capital city of the country where the input city is located.\n\n"
"Guidelines:\n"
"1. First identify the country where the input city is located\n"
"2. Then provide the capital city of that country\n"
"3. Include an interesting historical or cultural fact about the capital\n"
"4. Be accurate and precise with geographical information\n"
"5. If the input city is itself the capital, still provide the information"),
],
)
def test_clean_docstring(value: Union[str, None], expected: str):
assert clean_docstring(value) == expected