1313# limitations under the License.
1414#
1515from __future__ import annotations
16- from typing import TYPE_CHECKING
16+ from typing import TYPE_CHECKING , Any
17+ from .row_response import row_key
18+ from dataclasses import dataclass
19+ from google .cloud .bigtable .row_filters import RowFilter
1720
1821if TYPE_CHECKING :
19- from google .cloud .bigtable .row_filters import RowFilter
2022 from google .cloud .bigtable import RowKeySamples
2123
2224
25+ @dataclass
26+ class _RangePoint :
27+ """Model class for a point in a row range"""
28+
29+ key : row_key
30+ is_inclusive : bool
31+
32+
33+ @dataclass
34+ class RowRange :
35+ start : _RangePoint | None
36+ end : _RangePoint | None
37+
38+ def __init__ (
39+ self ,
40+ start_key : str | bytes | None = None ,
41+ end_key : str | bytes | None = None ,
42+ start_is_inclusive : bool | None = None ,
43+ end_is_inclusive : bool | None = None ,
44+ ):
45+ # check for invalid combinations of arguments
46+ if start_is_inclusive is None :
47+ start_is_inclusive = True
48+ elif start_key is None :
49+ raise ValueError ("start_is_inclusive must be set with start_key" )
50+ if end_is_inclusive is None :
51+ end_is_inclusive = False
52+ elif end_key is None :
53+ raise ValueError ("end_is_inclusive must be set with end_key" )
54+ # ensure that start_key and end_key are bytes
55+ if isinstance (start_key , str ):
56+ start_key = start_key .encode ()
57+ elif start_key is not None and not isinstance (start_key , bytes ):
58+ raise ValueError ("start_key must be a string or bytes" )
59+ if isinstance (end_key , str ):
60+ end_key = end_key .encode ()
61+ elif end_key is not None and not isinstance (end_key , bytes ):
62+ raise ValueError ("end_key must be a string or bytes" )
63+
64+ self .start = (
65+ _RangePoint (start_key , start_is_inclusive )
66+ if start_key is not None
67+ else None
68+ )
69+ self .end = (
70+ _RangePoint (end_key , end_is_inclusive ) if end_key is not None else None
71+ )
72+
73+ def _to_dict (self ) -> dict [str , bytes ]:
74+ """Converts this object to a dictionary"""
75+ output = {}
76+ if self .start is not None :
77+ key = "start_key_closed" if self .start .is_inclusive else "start_key_open"
78+ output [key ] = self .start .key
79+ if self .end is not None :
80+ key = "end_key_closed" if self .end .is_inclusive else "end_key_open"
81+ output [key ] = self .end .key
82+ return output
83+
84+
2385class ReadRowsQuery :
2486 """
2587 Class to encapsulate details of a read row request
2688 """
2789
2890 def __init__ (
29- self , row_keys : list [str | bytes ] | str | bytes | None = None , limit = None
91+ self ,
92+ row_keys : list [str | bytes ] | str | bytes | None = None ,
93+ row_ranges : list [RowRange ] | RowRange | None = None ,
94+ limit : int | None = None ,
95+ row_filter : RowFilter | None = None ,
3096 ):
31- pass
97+ """
98+ Create a new ReadRowsQuery
3299
33- def set_limit (self , limit : int ) -> ReadRowsQuery :
34- raise NotImplementedError
100+ Args:
101+ - row_keys: row keys to include in the query
102+ a query can contain multiple keys, but ranges should be preferred
103+ - row_ranges: ranges of rows to include in the query
104+ - limit: the maximum number of rows to return. None or 0 means no limit
105+ default: None (no limit)
106+ - row_filter: a RowFilter to apply to the query
107+ """
108+ self .row_keys : set [bytes ] = set ()
109+ self .row_ranges : list [RowRange | dict [str , bytes ]] = []
110+ if row_ranges :
111+ if isinstance (row_ranges , RowRange ):
112+ row_ranges = [row_ranges ]
113+ for r in row_ranges :
114+ self .add_range (r )
115+ if row_keys :
116+ if not isinstance (row_keys , list ):
117+ row_keys = [row_keys ]
118+ for k in row_keys :
119+ self .add_key (k )
120+ self .limit : int | None = limit
121+ self .filter : RowFilter | dict [str , Any ] | None = row_filter
35122
36- def set_filter (self , filter : "RowFilter" ) -> ReadRowsQuery :
37- raise NotImplementedError
123+ @property
124+ def limit (self ) -> int | None :
125+ return self ._limit
38126
39- def add_rows (self , row_id_list : list [str ]) -> ReadRowsQuery :
40- raise NotImplementedError
127+ @limit .setter
128+ def limit (self , new_limit : int | None ):
129+ """
130+ Set the maximum number of rows to return by this query.
131+
132+ None or 0 means no limit
133+
134+ Args:
135+ - new_limit: the new limit to apply to this query
136+ Returns:
137+ - a reference to this query for chaining
138+ Raises:
139+ - ValueError if new_limit is < 0
140+ """
141+ if new_limit is not None and new_limit < 0 :
142+ raise ValueError ("limit must be >= 0" )
143+ self ._limit = new_limit
144+
145+ @property
146+ def filter (self ) -> RowFilter | dict [str , Any ] | None :
147+ return self ._filter
148+
149+ @filter .setter
150+ def filter (self , row_filter : RowFilter | dict [str , Any ] | None ):
151+ """
152+ Set a RowFilter to apply to this query
153+
154+ Args:
155+ - row_filter: a RowFilter to apply to this query
156+ Can be a RowFilter object or a dict representation
157+ Returns:
158+ - a reference to this query for chaining
159+ """
160+ if not (
161+ isinstance (row_filter , dict )
162+ or isinstance (row_filter , RowFilter )
163+ or row_filter is None
164+ ):
165+ raise ValueError ("row_filter must be a RowFilter or dict" )
166+ self ._filter = row_filter
167+
168+ def add_key (self , row_key : str | bytes ):
169+ """
170+ Add a row key to this query
171+
172+ A query can contain multiple keys, but ranges should be preferred
173+
174+ Args:
175+ - row_key: a key to add to this query
176+ Returns:
177+ - a reference to this query for chaining
178+ Raises:
179+ - ValueError if an input is not a string or bytes
180+ """
181+ if isinstance (row_key , str ):
182+ row_key = row_key .encode ()
183+ elif not isinstance (row_key , bytes ):
184+ raise ValueError ("row_key must be string or bytes" )
185+ self .row_keys .add (row_key )
41186
42187 def add_range (
43- self , start_key : str | bytes | None = None , end_key : str | bytes | None = None
44- ) -> ReadRowsQuery :
45- raise NotImplementedError
188+ self ,
189+ row_range : RowRange | dict [str , bytes ],
190+ ):
191+ """
192+ Add a range of row keys to this query.
193+
194+ Args:
195+ - row_range: a range of row keys to add to this query
196+ Can be a RowRange object or a dict representation in
197+ RowRange proto format
198+ """
199+ if not (isinstance (row_range , dict ) or isinstance (row_range , RowRange )):
200+ raise ValueError ("row_range must be a RowRange or dict" )
201+ self .row_ranges .append (row_range )
46202
47203 def shard (self , shard_keys : "RowKeySamples" | None = None ) -> list [ReadRowsQuery ]:
48204 """
@@ -54,3 +210,27 @@ def shard(self, shard_keys: "RowKeySamples" | None = None) -> list[ReadRowsQuery
54210 query (if possible)
55211 """
56212 raise NotImplementedError
213+
214+ def _to_dict (self ) -> dict [str , Any ]:
215+ """
216+ Convert this query into a dictionary that can be used to construct a
217+ ReadRowsRequest protobuf
218+ """
219+ row_ranges = []
220+ for r in self .row_ranges :
221+ dict_range = r ._to_dict () if isinstance (r , RowRange ) else r
222+ row_ranges .append (dict_range )
223+ row_keys = list (self .row_keys )
224+ row_keys .sort ()
225+ row_set = {"row_keys" : row_keys , "row_ranges" : row_ranges }
226+ final_dict : dict [str , Any ] = {
227+ "rows" : row_set ,
228+ }
229+ dict_filter = (
230+ self .filter .to_dict () if isinstance (self .filter , RowFilter ) else self .filter
231+ )
232+ if dict_filter :
233+ final_dict ["filter" ] = dict_filter
234+ if self .limit is not None :
235+ final_dict ["rows_limit" ] = self .limit
236+ return final_dict
0 commit comments