@@ -58,12 +58,12 @@ class Foo(BaseModel):
5858 assert isinstance (validated_output ["bez" ][0 ], str )
5959
6060
61- @pytest . mark . skip ( reason = "Random model infinitely recurses on complex struct. Use GPT2" )
61+ @if_transformers_installed
6262def test_hugging_face_pipeline_complex_schema ():
6363 # NOTE: This is the real GPT-2 model.
6464 from transformers import pipeline
6565
66- model = pipeline ("text-generation" , "gpt2 " )
66+ model = pipeline ("text-generation" , "distilgpt2 " )
6767
6868 class MultiNum (BaseModel ):
6969 whole : int
@@ -73,10 +73,12 @@ class Tricky(BaseModel):
7373 foo : MultiNum
7474
7575 g = Guard .for_pydantic (Tricky , output_formatter = "jsonformer" )
76- response = g (model , prompt = " Sample:" )
76+ response = g (model , messages = [{ "content" : " Sample:", "role" : "user" }] )
7777 out = response .validated_output
7878 assert isinstance (out , dict )
7979 assert "foo" in out
8080 assert isinstance (out ["foo" ], dict )
81- assert isinstance (out ["foo" ]["whole" ], int | float )
81+ assert isinstance (out ["foo" ]["whole" ], int ) or isinstance (
82+ out ["foo" ]["whole" ], float
83+ )
8284 assert isinstance (out ["foo" ]["frac" ], float )
0 commit comments