11"""Test AI21LLM llm."""
22
3- import pytest
4- from ai21 .models import Penalty
53
6- from langchain_ai21 .llms import AI21
4+ from langchain_ai21 .llms import AI21LLM
75
86
def _generate_llm() -> AI21LLM:
    """
    Build an AI21LLM configured with non-default parameters for testing.

    Uses a small ``max_tokens`` for a faster response and ``temperature=0``
    for a deterministic, consistent response.
    """
    return AI21LLM(
        model="j2-ultra",
        max_tokens=2,  # Use fewer tokens for a faster response
        temperature=0,  # for a consistent response
        epoch=1,
    )
4217
4318
def test_stream() -> None:
    """Test streaming tokens from AI21."""
    llm = AI21LLM(
        model="j2-ultra",
    )

    # Every streamed chunk from an LLM (not chat model) should be a plain str.
    for token in llm.stream("I'm Pickle Rick"):
        assert isinstance(token, str)
5127
5228
async def test_abatch() -> None:
    """Test async batch generation from AI21LLM."""
    llm = AI21LLM(
        model="j2-ultra",
    )

    result = await llm.abatch(["I'm Pickle Rick", "I'm not Pickle Rick"])
    # abatch returns one string completion per input prompt.
    for token in result:
        assert isinstance(token, str)
6138
6239
async def test_abatch_tags() -> None:
    """Test async batch generation from AI21LLM with run tags attached."""
    llm = AI21LLM(
        model="j2-ultra",
    )

    # NOTE(review): two lines here were elided by a diff hunk header in the
    # source; the closing paren and loop are reconstructed from context.
    result = await llm.abatch(
        ["I'm Pickle Rick", "I'm not Pickle Rick"], config={"tags": ["foo"]}
    )
    for token in result:
        assert isinstance(token, str)
7351
7452
def test_batch() -> None:
    """Test synchronous batch generation from AI21LLM."""
    llm = AI21LLM(
        model="j2-ultra",
    )

    result = llm.batch(["I'm Pickle Rick", "I'm not Pickle Rick"])
    # batch returns one string completion per input prompt.
    for token in result:
        assert isinstance(token, str)
8362
8463
async def test_ainvoke() -> None:
    """Test async single-prompt invocation of AI21LLM."""
    llm = AI21LLM(
        model="j2-ultra",
    )

    result = await llm.ainvoke("I'm Pickle Rick", config={"tags": ["foo"]})
    assert isinstance(result, str)
9272
9373
def test_invoke() -> None:
    """Test synchronous single-prompt invocation of AI21LLM."""
    llm = AI21LLM(
        model="j2-ultra",
    )

    # Use a dict literal for config, matching test_ainvoke for consistency.
    result = llm.invoke("I'm Pickle Rick", config={"tags": ["foo"]})
    assert isinstance(result, str)
10182
10283
def test__generate() -> None:
    """Test generate() with non-default parameters and a stop sequence."""
    llm = _generate_llm()
    llm_result = llm.generate(
        prompts=["Hey there, my name is Pickle Rick. What is your name?"],
        stop=["##"],
    )

    # NOTE(review): lines between the call and this assert were elided by a
    # diff hunk header in the source; reconstructed from context.
    assert len(llm_result.generations) > 0
    assert llm_result.llm_output["token_count"] != 0  # type: ignore
11393
11494
115- @pytest .mark .requires ("ai21" )
11695async def test__agenerate () -> None :
117- llm = _generate_llm_client_parameters ()
96+ llm = _generate_llm ()
11897 llm_result = await llm .agenerate (
11998 prompts = ["Hey there, my name is Pickle Rick. What is your name?" ],
12099 stop = ["##" ],
0 commit comments