Skip to content

Commit 97ba020

Browse files
authored
Merge pull request #3 from hwchase17/harrison/enforce_mypy
enforce mypy and fix errors
2 parents c48e6aa + 0b9fa63 commit 97ba020

File tree

4 files changed

+19
-10
lines changed

4 files changed

+19
-10
lines changed

langchain/formatting.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,25 @@
11
"""Utilities for formatting strings."""
22
from string import Formatter
3+
from typing import Any, Mapping, Sequence, Union
34

45

56
class StrictFormatter(Formatter):
67
"""A subclass of formatter that checks for extra keys."""
78

8-
def check_unused_args(self, used_args, args, kwargs):
9+
def check_unused_args(
10+
self,
11+
used_args: Sequence[Union[int, str]],
12+
args: Sequence,
13+
kwargs: Mapping[str, Any],
14+
) -> None:
915
"""Check to see if extra parameters are passed."""
1016
extra = set(kwargs).difference(used_args)
1117
if extra:
1218
raise KeyError(extra)
1319

14-
def vformat(self, format_string, args, kwargs):
20+
def vformat(
21+
self, format_string: str, args: Sequence, kwargs: Mapping[str, Any]
22+
) -> str:
1523
"""Check that no arguments are provided."""
1624
if len(args) > 0:
1725
raise ValueError(

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,4 +3,5 @@ profile = "black"
33

44
[tool.mypy]
55
ignore_missing_imports = "True"
6+
disallow_untyped_defs = "True"
67
exclude = ["notebooks"]

tests/unit_tests/test_formatting.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,22 +4,22 @@
44
from langchain.formatting import formatter
55

66

7-
def test_valid_formatting():
7+
def test_valid_formatting() -> None:
88
"""Test formatting works as expected."""
99
template = "This is a {foo} test."
1010
output = formatter.format(template, foo="good")
1111
expected_output = "This is a good test."
1212
assert output == expected_output
1313

1414

15-
def test_does_not_allow_args():
15+
def test_does_not_allow_args() -> None:
1616
"""Test formatting raises error when args are provided."""
1717
template = "This is a {} test."
1818
with pytest.raises(ValueError):
1919
formatter.format(template, "good")
2020

2121

22-
def test_does_not_allow_extra_kwargs():
22+
def test_does_not_allow_extra_kwargs() -> None:
2323
"""Test formatting does not allow extra key word arguments."""
2424
template = "This is a {foo} test."
2525
with pytest.raises(KeyError):

tests/unit_tests/test_prompt.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from langchain.prompt import Prompt
55

66

7-
def test_prompt_valid():
7+
def test_prompt_valid() -> None:
88
"""Test prompts can be constructed."""
99
template = "This is a {foo} test."
1010
input_variables = ["foo"]
@@ -13,23 +13,23 @@ def test_prompt_valid():
1313
assert prompt.input_variables == input_variables
1414

1515

16-
def test_prompt_missing_input_variables():
16+
def test_prompt_missing_input_variables() -> None:
1717
"""Test error is raised when input variables are not provided."""
1818
template = "This is a {foo} test."
19-
input_variables = []
19+
input_variables: list = []
2020
with pytest.raises(ValueError):
2121
Prompt(input_variables=input_variables, template=template)
2222

2323

24-
def test_prompt_extra_input_variables():
24+
def test_prompt_extra_input_variables() -> None:
2525
"""Test error is raised when there are too many input variables."""
2626
template = "This is a {foo} test."
2727
input_variables = ["foo", "bar"]
2828
with pytest.raises(ValueError):
2929
Prompt(input_variables=input_variables, template=template)
3030

3131

32-
def test_prompt_wrong_input_variables():
32+
def test_prompt_wrong_input_variables() -> None:
3333
"""Test error is raised when name of input variable is wrong."""
3434
template = "This is a {foo} test."
3535
input_variables = ["bar"]

0 commit comments

Comments
 (0)