@@ -22,11 +22,17 @@ class MockFastAPI:
2222
2323 def __init__ (self ) -> None :
2424 """Initialize mock class."""
25- self .routers : list [Any ] = []
25+ self .routers : list [tuple [ Any , Optional [ str ]] ] = []
2626
2727 def include_router (self , router : Any , prefix : Optional [str ] = None ) -> None :
2828 """Register new router."""
29- self .routers .append (router )
29+ self .routers .append ((router , prefix ))
30+
31+ def get_routers (self ) -> list [Any ]:
32+ return [r [0 ] for r in self .routers ]
33+
34+ def get_router_prefix (self , router : Any ) -> Optional [str ]:
35+ return list (filter (lambda r : r [0 ] == router , self .routers ))[0 ][1 ]
3036
3137
3238def test_include_routers () -> None :
@@ -36,12 +42,30 @@ def test_include_routers() -> None:
3642
3743 # are all routers added?
3844 assert len (app .routers ) == 9
39- assert root .router in app .routers
40- assert info .router in app .routers
41- assert models .router in app .routers
42- assert query .router in app .routers
43- assert health .router in app .routers
44- assert config .router in app .routers
45- assert feedback .router in app .routers
46- assert streaming_query .router in app .routers
47- assert authorized .router in app .routers
45+ assert root .router in app .get_routers ()
46+ assert info .router in app .get_routers ()
47+ assert models .router in app .get_routers ()
48+ assert query .router in app .get_routers ()
49+ assert streaming_query .router in app .get_routers ()
50+ assert config .router in app .get_routers ()
51+ assert feedback .router in app .get_routers ()
52+ assert health .router in app .get_routers ()
53+ assert authorized .router in app .get_routers ()
54+
55+
56+ def test_check_prefixes () -> None :
57+ """Test the router prefixes."""
58+ app = MockFastAPI ()
59+ include_routers (app )
60+
61+ # are all routers added?
62+ assert len (app .routers ) == 9
63+ assert app .get_router_prefix (root .router ) is None
64+ assert app .get_router_prefix (info .router ) == "/v1"
65+ assert app .get_router_prefix (models .router ) == "/v1"
66+ assert app .get_router_prefix (query .router ) == "/v1"
67+ assert app .get_router_prefix (streaming_query .router ) == "/v1"
68+ assert app .get_router_prefix (config .router ) == "/v1"
69+ assert app .get_router_prefix (feedback .router ) == "/v1"
70+ assert app .get_router_prefix (health .router ) is None
71+ assert app .get_router_prefix (authorized .router ) is None
0 commit comments