Skip to content

Commit 4cc39aa

Browse files
authored
Merge pull request #4 from hwchase17/harrison/add_llms
add llm objects
2 parents 97ba020 + f1d60b9 commit 4cc39aa

File tree

14 files changed

+199
-2
lines changed

14 files changed

+199
-2
lines changed

β€ŽMakefileβ€Ž

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
.PHONY: format lint tests
1+
.PHONY: format lint tests integration_tests
22

33
format:
44
black .
@@ -11,4 +11,7 @@ lint:
1111
mypy .
1212

1313
tests:
14-
pytest tests
14+
pytest tests/unit_tests
15+
16+
integration_tests:
17+
pytest tests/integration_tests

β€Žlangchain/llms/__init__.pyβ€Ž

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
"""Wrappers on top of large language models."""

β€Žlangchain/llms/base.pyβ€Ž

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
"""Base interface for large language models to expose."""
2+
from abc import ABC, abstractmethod
3+
from typing import List, Optional
4+
5+
6+
class LLM(ABC):
7+
"""LLM wrapper should take in a prompt and return a string."""
8+
9+
@abstractmethod
10+
def __call__(self, prompt: str, stop: Optional[List[str]] = None) -> str:
11+
"""Run the LLM on the given prompt and input."""

β€Žlangchain/llms/cohere.pyβ€Ž

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
"""Wrapper around Cohere APIs."""
2+
import os
3+
from typing import Any, Dict, List, Optional
4+
5+
from pydantic import BaseModel, Extra, root_validator
6+
7+
from langchain.llms.base import LLM
8+
9+
10+
def remove_stop_tokens(text: str, stop: List[str]) -> str:
11+
"""Remove stop tokens, should they occur at end."""
12+
for s in stop:
13+
if text.endswith(s):
14+
return text[: -len(s)]
15+
return text
16+
17+
18+
class Cohere(BaseModel, LLM):
19+
"""Wrapper around Cohere large language models."""
20+
21+
client: Any
22+
model: str = "gptd-instruct-tft"
23+
max_tokens: int = 256
24+
temperature: float = 0.6
25+
k: int = 0
26+
p: int = 1
27+
frequency_penalty: int = 0
28+
presence_penalty: int = 0
29+
30+
class Config:
31+
"""Configuration for this pydantic object."""
32+
33+
extra = Extra.forbid
34+
35+
@root_validator()
36+
def template_is_valid(cls, values: Dict) -> Dict:
37+
"""Validate that api key python package exists in environment."""
38+
if "COHERE_API_KEY" not in os.environ:
39+
raise ValueError(
40+
"Did not find Cohere API key, please add an environment variable"
41+
" `COHERE_API_KEY` which contains it."
42+
)
43+
try:
44+
import cohere
45+
46+
values["client"] = cohere.Client(os.environ["COHERE_API_KEY"])
47+
except ImportError:
48+
raise ValueError(
49+
"Could not import cohere python package. "
50+
"Please it install it with `pip install cohere`."
51+
)
52+
return values
53+
54+
def __call__(self, prompt: str, stop: Optional[List[str]] = None) -> str:
55+
"""Call out to Cohere's generate endpoint."""
56+
response = self.client.generate(
57+
model=self.model,
58+
prompt=prompt,
59+
max_tokens=self.max_tokens,
60+
temperature=self.temperature,
61+
k=self.k,
62+
p=self.p,
63+
frequency_penalty=self.frequency_penalty,
64+
presence_penalty=self.presence_penalty,
65+
stop_sequences=stop,
66+
)
67+
text = response.generations[0].text
68+
# If stop tokens are provided, Cohere's endpoint returns them.
69+
# In order to make this consistent with other endpoints, we strip them.
70+
if stop is not None:
71+
text = remove_stop_tokens(text, stop)
72+
return text

β€Žlangchain/llms/openai.pyβ€Ž

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
"""Wrapper around OpenAI APIs."""
2+
import os
3+
from typing import Any, Dict, List, Mapping, Optional
4+
5+
from pydantic import BaseModel, Extra, root_validator
6+
7+
from langchain.llms.base import LLM
8+
9+
10+
class OpenAI(BaseModel, LLM):
11+
"""Wrapper around OpenAI large language models."""
12+
13+
client: Any
14+
model_name: str = "text-davinci-002"
15+
temperature: float = 0.7
16+
max_tokens: int = 256
17+
top_p: int = 1
18+
frequency_penalty: int = 0
19+
presence_penalty: int = 0
20+
n: int = 1
21+
best_of: int = 1
22+
23+
class Config:
24+
"""Configuration for this pydantic object."""
25+
26+
extra = Extra.forbid
27+
28+
@root_validator()
29+
def validate_environment(cls, values: Dict) -> Dict:
30+
"""Validate that api key python package exists in environment."""
31+
if "OPENAI_API_KEY" not in os.environ:
32+
raise ValueError(
33+
"Did not find OpenAI API key, please add an environment variable"
34+
" `OPENAI_API_KEY` which contains it."
35+
)
36+
try:
37+
import openai
38+
39+
values["client"] = openai.Completion
40+
except ImportError:
41+
raise ValueError(
42+
"Could not import openai python package. "
43+
"Please it install it with `pip install openai`."
44+
)
45+
return values
46+
47+
@property
48+
def default_params(self) -> Mapping[str, Any]:
49+
"""Get the default parameters for calling OpenAI API."""
50+
return {
51+
"temperature": self.temperature,
52+
"max_tokens": self.max_tokens,
53+
"top_p": self.top_p,
54+
"frequency_penalty": self.frequency_penalty,
55+
"presence_penalty": self.presence_penalty,
56+
"n": self.n,
57+
"best_of": self.best_of,
58+
}
59+
60+
def __call__(self, prompt: str, stop: Optional[List[str]] = None) -> str:
61+
"""Call out to OpenAI's create endpoint."""
62+
response = self.client.create(
63+
model=self.model_name, prompt=prompt, stop=stop, **self.default_params
64+
)
65+
return response["choices"][0]["text"]

β€Žrequirements.txtβ€Ž

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,10 @@
11
-e .
22
pytest
3+
pytest-dotenv
34
black
45
isort
56
mypy
67
flake8
78
flake8-docstrings
9+
cohere
10+
openai

β€Žtests/__init__.pyβ€Ž

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
"""All tests for this package."""
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
"""All integration tests (tests that call out to an external API)."""
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
"""All integration tests for LLM objects."""
Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
"""Test Cohere API wrapper."""
2+
3+
from langchain.llms.cohere import Cohere
4+
5+
6+
def test_cohere_call() -> None:
7+
"""Test valid call to cohere."""
8+
llm = Cohere(max_tokens=10)
9+
output = llm("Say foo:")
10+
assert isinstance(output, str)

0 commit comments

Comments
Β (0)