Skip to content

Commit f3c5228

Browse files
authored
fix: Merge pull request #3 from AI21Labs/pr_fixes_1
fix: Fix LC CR
2 parents 7bd791e + 88e79a9 commit f3c5228

File tree

15 files changed

+166
-105
lines changed

15 files changed

+166
-105
lines changed

.github/workflows/_release.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -166,6 +166,7 @@ jobs:
166166
- name: Run integration tests
167167
if: ${{ startsWith(inputs.working-directory, 'libs/partners/') }}
168168
env:
169+
AI21_API_KEY: ${{ secrets.AI21_API_KEY }}
169170
GOOGLE_API_KEY: ${{ secrets.GOOGLE_API_KEY }}
170171
ANTHROPIC_API_KEY: ${{ secrets.ANTHROPIC_API_KEY }}
171172
MISTRAL_API_KEY: ${{ secrets.MISTRAL_API_KEY }}

libs/partners/ai21/Makefile

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,8 @@ all: help
55

66
# Define a variable for the test file path.
77
TEST_FILE ?= tests/unit_tests/
8-
9-
test:
10-
poetry run pytest $(TEST_FILE)
11-
12-
tests:
8+
integration_test integration_tests: TEST_FILE = tests/integration_tests/
9+
test tests integration_test integration_tests:
1310
poetry run pytest $(TEST_FILE)
1411

1512

libs/partners/ai21/README.md

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1 +1,75 @@
11
# langchain-ai21
2+
3+
This package contains the LangChain integrations for [AI21](https://docs.ai21.com/) through their [AI21](https://pypi.org/project/ai21/) SDK.
4+
5+
## Installation and Setup
6+
7+
- Install the AI21 partner package
8+
```bash
9+
pip install langchain-ai21
10+
```
11+
- Get an AI21 api key and set it as an environment variable (`AI21_API_KEY`)
12+
13+
14+
## Chat Models
15+
16+
This package contains the `ChatAI21` class, which is the recommended way to interface with AI21 Chat models.
17+
18+
To use, install the requirements, and configure your environment.
19+
20+
```bash
21+
export AI21_API_KEY=your-api-key
22+
```
23+
24+
Then initialize
25+
26+
```python
27+
from langchain_core.messages import HumanMessage
28+
from langchain_ai21.chat_models import ChatAI21
29+
30+
chat = ChatAI21(model="j2-ultra")
31+
messages = [HumanMessage(content="Hello from AI21")]
32+
chat.invoke(messages)
33+
```
34+
35+
## LLMs
36+
You can use AI21's generative AI models as Langchain LLMs:
37+
38+
```python
39+
from langchain.prompts import PromptTemplate
40+
from langchain_ai21 import AI21LLM
41+
42+
llm = AI21LLM(model="j2-ultra")
43+
44+
template = """Question: {question}
45+
46+
Answer: Let's think step by step."""
47+
prompt = PromptTemplate.from_template(template)
48+
49+
chain = prompt | llm
50+
51+
question = "Which scientist discovered relativity?"
52+
print(chain.invoke({"question": question}))
53+
```
54+
55+
## Embeddings
56+
57+
You can use AI21's embeddings models as:
58+
59+
### Query
60+
61+
```python
62+
from langchain_ai21 import AI21Embeddings
63+
64+
embeddings = AI21Embeddings()
65+
embeddings.embed_query("Hello! This is some query")
66+
```
67+
68+
### Document
69+
70+
```python
71+
from langchain_ai21 import AI21Embeddings
72+
73+
embeddings = AI21Embeddings()
74+
embeddings.embed_documents(["Hello! This is document 1", "And this is document 2!"])
75+
```
Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
11
from langchain_ai21.chat_models import ChatAI21
22
from langchain_ai21.embeddings import AI21Embeddings
3-
from langchain_ai21.llms import AI21
3+
from langchain_ai21.llms import AI21LLM
44

55
__all__ = [
6-
"AI21",
6+
"AI21LLM",
77
"ChatAI21",
88
"AI21Embeddings",
99
]

libs/partners/ai21/langchain_ai21/chat_models.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,9 @@ class ChatAI21(BaseChatModel, AI21Base):
8888
model = ChatAI21()
8989
"""
9090

91-
model: str = "j2-ultra"
91+
model: str
92+
"""Model type you wish to interact with.
93+
You can view the options at https://github.com/AI21Labs/ai21-python?tab=readme-ov-file#model-types"""
9294
num_results: int = 1
9395
"""The number of responses to generate for a given prompt."""
9496

libs/partners/ai21/langchain_ai21/llms.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
from langchain_ai21.ai21_base import AI21Base
1818

1919

20-
class AI21(BaseLLM, AI21Base):
20+
class AI21LLM(BaseLLM, AI21Base):
2121
"""AI21LLM large language models.
2222
2323
Example:
@@ -28,8 +28,9 @@ class AI21(BaseLLM, AI21Base):
2828
model = AI21LLM()
2929
"""
3030

31-
model: str = "j2-ultra"
32-
"""Model type you wish to interact with."""
31+
model: str
32+
"""Model type you wish to interact with.
33+
You can view the options at https://github.com/AI21Labs/ai21-python?tab=readme-ov-file#model-types"""
3334

3435
num_results: int = 1
3536
"""The number of responses to generate for a given prompt."""

libs/partners/ai21/poetry.lock

Lines changed: 6 additions & 6 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

libs/partners/ai21/pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ readme = "README.md"
77

88
[tool.poetry.dependencies]
99
python = ">=3.8.1,<4.0"
10-
langchain-core = ">=0.0.12"
10+
langchain-core = "^0.1.22"
1111
ai21 = "^2.0.0"
1212

1313
[tool.poetry.group.test]

libs/partners/ai21/tests/integration_tests/test_chat_models.py

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,24 +1,21 @@
11
"""Test ChatAI21 chat model."""
2-
import pytest
32
from langchain_core.messages import HumanMessage
43
from langchain_core.outputs import ChatGeneration
54

65
from langchain_ai21.chat_models import ChatAI21
76

87

9-
@pytest.mark.requires("ai21")
108
def test_invoke() -> None:
119
"""Test invoke tokens from AI21."""
12-
llm = ChatAI21()
10+
llm = ChatAI21(model="j2-ultra")
1311

1412
result = llm.invoke("I'm Pickle Rick", config=dict(tags=["foo"]))
1513
assert isinstance(result.content, str)
1614

1715

18-
@pytest.mark.requires("ai21")
1916
def test_generation() -> None:
2017
"""Test invoke tokens from AI21."""
21-
llm = ChatAI21()
18+
llm = ChatAI21(model="j2-ultra")
2219
message = HumanMessage(content="Hello")
2320

2421
result = llm.generate([[message], [message]], config=dict(tags=["foo"]))
@@ -31,10 +28,9 @@ def test_generation() -> None:
3128
assert generation.text == generation.message.content
3229

3330

34-
@pytest.mark.requires("ai21")
3531
async def test_ageneration() -> None:
3632
"""Test invoke tokens from AI21."""
37-
llm = ChatAI21()
33+
llm = ChatAI21(model="j2-ultra")
3834
message = HumanMessage(content="Hello")
3935

4036
result = await llm.agenerate([[message], [message]], config=dict(tags=["foo"]))

libs/partners/ai21/tests/integration_tests/test_llms.py

Lines changed: 29 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -1,69 +1,47 @@
11
"""Test AI21LLM llm."""
22

3-
import pytest
4-
from ai21.models import Penalty
53

6-
from langchain_ai21.llms import AI21
4+
from langchain_ai21.llms import AI21LLM
75

86

9-
def _generate_llm_client_parameters() -> AI21:
10-
return AI21(
11-
max_tokens=2,
12-
temperature=0,
13-
top_p=1,
14-
top_k_return=0,
15-
num_results=1,
7+
def _generate_llm() -> AI21LLM:
8+
"""
9+
Testing AI21LLm using non default parameters with the following parameters
10+
"""
11+
return AI21LLM(
12+
model="j2-ultra",
13+
max_tokens=2, # Use less tokens for a faster response
14+
temperature=0, # for a consistent response
1615
epoch=1,
17-
count_penalty=Penalty(
18-
scale=0,
19-
apply_to_emojis=False,
20-
apply_to_numbers=False,
21-
apply_to_stopwords=False,
22-
apply_to_punctuation=False,
23-
apply_to_whitespaces=False,
24-
),
25-
frequency_penalty=Penalty(
26-
scale=0,
27-
apply_to_emojis=False,
28-
apply_to_numbers=False,
29-
apply_to_stopwords=False,
30-
apply_to_punctuation=False,
31-
apply_to_whitespaces=False,
32-
),
33-
presence_penalty=Penalty(
34-
scale=0,
35-
apply_to_emojis=False,
36-
apply_to_numbers=False,
37-
apply_to_stopwords=False,
38-
apply_to_punctuation=False,
39-
apply_to_whitespaces=False,
40-
),
4116
)
4217

4318

44-
@pytest.mark.requires("ai21")
4519
def test_stream() -> None:
4620
"""Test streaming tokens from AI21."""
47-
llm = AI21()
21+
llm = AI21LLM(
22+
model="j2-ultra",
23+
)
4824

4925
for token in llm.stream("I'm Pickle Rick"):
5026
assert isinstance(token, str)
5127

5228

53-
@pytest.mark.requires("ai21")
5429
async def test_abatch() -> None:
5530
"""Test streaming tokens from AI21LLM."""
56-
llm = AI21()
31+
llm = AI21LLM(
32+
model="j2-ultra",
33+
)
5734

5835
result = await llm.abatch(["I'm Pickle Rick", "I'm not Pickle Rick"])
5936
for token in result:
6037
assert isinstance(token, str)
6138

6239

63-
@pytest.mark.requires("ai21")
6440
async def test_abatch_tags() -> None:
6541
"""Test batch tokens from AI21LLM."""
66-
llm = AI21()
42+
llm = AI21LLM(
43+
model="j2-ultra",
44+
)
6745

6846
result = await llm.abatch(
6947
["I'm Pickle Rick", "I'm not Pickle Rick"], config={"tags": ["foo"]}
@@ -72,37 +50,39 @@ async def test_abatch_tags() -> None:
7250
assert isinstance(token, str)
7351

7452

75-
@pytest.mark.requires("ai21")
7653
def test_batch() -> None:
7754
"""Test batch tokens from AI21LLM."""
78-
llm = AI21()
55+
llm = AI21LLM(
56+
model="j2-ultra",
57+
)
7958

8059
result = llm.batch(["I'm Pickle Rick", "I'm not Pickle Rick"])
8160
for token in result:
8261
assert isinstance(token, str)
8362

8463

85-
@pytest.mark.requires("ai21")
8664
async def test_ainvoke() -> None:
8765
"""Test invoke tokens from AI21LLM."""
88-
llm = AI21()
66+
llm = AI21LLM(
67+
model="j2-ultra",
68+
)
8969

9070
result = await llm.ainvoke("I'm Pickle Rick", config={"tags": ["foo"]})
9171
assert isinstance(result, str)
9272

9373

94-
@pytest.mark.requires("ai21")
9574
def test_invoke() -> None:
9675
"""Test invoke tokens from AI21LLM."""
97-
llm = AI21()
76+
llm = AI21LLM(
77+
model="j2-ultra",
78+
)
9879

9980
result = llm.invoke("I'm Pickle Rick", config=dict(tags=["foo"]))
10081
assert isinstance(result, str)
10182

10283

103-
@pytest.mark.requires("ai21")
10484
def test__generate() -> None:
105-
llm = _generate_llm_client_parameters()
85+
llm = _generate_llm()
10686
llm_result = llm.generate(
10787
prompts=["Hey there, my name is Pickle Rick. What is your name?"],
10888
stop=["##"],
@@ -112,9 +92,8 @@ def test__generate() -> None:
11292
assert llm_result.llm_output["token_count"] != 0 # type: ignore
11393

11494

115-
@pytest.mark.requires("ai21")
11695
async def test__agenerate() -> None:
117-
llm = _generate_llm_client_parameters()
96+
llm = _generate_llm()
11897
llm_result = await llm.agenerate(
11998
prompts=["Hey there, my name is Pickle Rick. What is your name?"],
12099
stop=["##"],

0 commit comments

Comments
 (0)