Skip to content

Commit a452b24

Browse files
committed
fix(storage): return models instead of contained fields
this way, if one day relevant fields are added to the top objects it isn't a breaking change
1 parent 845b9f3 commit a452b24

File tree

1 file changed

+8
-13
lines changed

1 file changed

+8
-13
lines changed

src/storage/src/storage3/_async/vectors.py

Lines changed: 8 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -17,11 +17,8 @@
1717
ListVectorsResponse,
1818
MetadataConfiguration,
1919
QueryVectorsResponse,
20-
VectorBucket,
2120
VectorData,
2221
VectorFilter,
23-
VectorIndex,
24-
VectorMatch,
2522
VectorObject,
2623
)
2724
from .request import AsyncRequestBuilder
@@ -60,13 +57,13 @@ async def create_index(
6057
)
6158
await self._request.send(http_method="POST", path=["CreateIndex"], body=body)
6259

63-
async def get_index(self, index_name: str) -> Optional[VectorIndex]:
60+
async def get_index(self, index_name: str) -> Optional[GetVectorIndexResponse]:
6461
body = self.with_metadata(indexName=index_name)
6562
try:
6663
data = await self._request.send(
6764
http_method="POST", path=["GetIndex"], body=body
6865
)
69-
return GetVectorIndexResponse.model_validate_json(data.content).index
66+
return GetVectorIndexResponse.model_validate_json(data.content)
7067
except StorageApiError:
7168
return None
7269

@@ -115,14 +112,14 @@ async def put(self, vectors: List[VectorObject]) -> None:
115112

116113
async def get(
117114
self, *keys: str, return_data: bool = True, return_metadata: bool = True
118-
) -> List[VectorMatch]:
115+
) -> GetVectorsResponse:
119116
body = self.with_metadata(
120117
keys=keys, returnData=return_data, returnMetadata=return_metadata
121118
)
122119
data = await self._request.send(
123120
http_method="POST", path=["GetVectors"], body=body
124121
)
125-
return GetVectorsResponse.model_validate_json(data.content).vectors
122+
return GetVectorsResponse.model_validate_json(data.content)
126123

127124
async def list(
128125
self,
@@ -153,7 +150,7 @@ async def query(
153150
filter: Optional[VectorFilter] = None,
154151
return_distance: bool = True,
155152
return_metadata: bool = True,
156-
) -> List[VectorMatch]:
153+
) -> QueryVectorsResponse:
157154
body = self.with_metadata(
158155
queryVector=dict(query_vector),
159156
topK=topK,
@@ -164,7 +161,7 @@ async def query(
164161
data = await self._request.send(
165162
http_method="POST", path=["QueryVectors"], body=body
166163
)
167-
return QueryVectorsResponse.model_validate_json(data.content).vectors
164+
return QueryVectorsResponse.model_validate_json(data.content)
168165

169166
async def delete(self, keys: List[str]) -> None:
170167
if len(keys) < 1 or len(keys) > 500:
@@ -186,15 +183,13 @@ async def create_bucket(self, bucket_name: str) -> None:
186183
http_method="POST", path=["CreateVectorBucket"], body=body
187184
)
188185

189-
async def get_bucket(self, bucket_name: str) -> Optional[VectorBucket]:
186+
async def get_bucket(self, bucket_name: str) -> Optional[GetVectorBucketResponse]:
190187
body = {"vectorBucketName": bucket_name}
191188
try:
192189
data = await self._request.send(
193190
http_method="POST", path=["GetVectorBucket"], body=body
194191
)
195-
return GetVectorBucketResponse.model_validate_json(
196-
data.content
197-
).vectorBucket
192+
return GetVectorBucketResponse.model_validate_json(data.content)
198193
except StorageApiError:
199194
return None
200195

0 commit comments

Comments
 (0)