@@ -24,7 +24,7 @@ class CogStack():
2424
2525 def __init__ (self , hosts : List [str ]):
2626 self .hosts = hosts
27- self .elastic = None
27+ self .elastic : elasticsearch . Elasticsearch
2828
2929 @classmethod
3030 def with_basic_auth (cls ,
@@ -138,11 +138,13 @@ def use_api_key_auth(self, api_key: Optional[Dict] = None) -> 'CogStack':
138138 -------
139139 CogStack: An instance of the CogStack class.
140140 """
141+ has_encoded_value = False
142+ api_id_value :str
143+ api_key_value :str
144+
141145 if not api_key :
142146 api_key = {"encoded" : input ("Encoded API key: " )}
143- has_encoded_value = False
144- api_id_value , api_key_value = None , None
145- if api_key is not None :
147+ else :
146148 if isinstance (api_key , str ):
147149 # If api_key is a string, it is assumed to be the encoded API key
148150 encoded = api_key
@@ -161,18 +163,18 @@ def use_api_key_auth(self, api_key: Optional[Dict] = None) -> 'CogStack':
161163 has_encoded_value = encoded is not None and encoded != ''
162164
163165 if (not has_encoded_value ):
164- api_id_value = api_key ["id" ] \
166+ api_id_value = str ( api_key ["id" ] \
165167 if "id" in api_key .keys () and api_key ["id" ] != '' \
166- else input ("API Id: " )
167- api_key_value = api_key ["api_key" ] \
168+ else input ("API Id: " ))
169+ api_key_value = str ( api_key ["api_key" ] \
168170 if "api_key" in api_key .keys () and api_key ["api_key" ] != '' \
169- else getpass .getpass ("API Key: " )
171+ else getpass .getpass ("API Key: " ))
170172
171173 return self .__connect (api_key = encoded if has_encoded_value else (api_id_value , api_key_value ))
172174
173175 def __connect (self ,
174176 basic_auth : Optional [tuple [str ,str ]] = None ,
175- api_key : Optional [Union [str , tuple [str , str ], None ]] = None ) -> 'CogStack' :
177+ api_key : Optional [Union [str , tuple [str , str ]]] = None ) -> 'CogStack' :
176178 """ Connect to Elasticsearch using the provided credentials.
177179 Parameters
178180 ----------
@@ -189,10 +191,10 @@ def __connect(self,
189191 Exception: If the connection to Elasticsearch fails.
190192 """
191193 self .elastic = elasticsearch .Elasticsearch (hosts = self .hosts ,
192- api_key = api_key ,
193- basic_auth = basic_auth ,
194- verify_certs = False ,
195- request_timeout = self .ES_TIMEOUT )
194+ api_key = api_key ,
195+ basic_auth = basic_auth ,
196+ verify_certs = False ,
197+ request_timeout = self .ES_TIMEOUT )
196198 if not self .elastic .ping ():
197199 raise ConnectionError ("CogStack connection failed. " \
198200 "Please check your host list and credentials and try again." )
@@ -240,6 +242,8 @@ def get_index_fields(self, index: Union[str, Sequence[str]]):
240242 If the operation fails for any reason.
241243 """
242244 try :
245+ if len (index ) == 0 :
246+ raise ValueError ('Provide at least one index or index alias name' )
243247 all_mappings = self .elastic .indices \
244248 .get_mapping (index = index , allow_no_indices = False ).body
245249 columns = ['Field' , 'Type' ]
@@ -282,8 +286,10 @@ def count_search_results(self, index: Union[str, Sequence[str]], query: dict):
282286 .. code-block:: json
283287 {"match": {"title": "python"}}}
284288 """
289+ if len (index ) == 0 :
290+ raise ValueError ('Provide at least one index or index alias name' )
285291 query = self .__extract_query (query = query )
286- count = self .elastic .count (index = index , query = query )['count' ]
292+ count = self .elastic .count (index = index , query = query , allow_no_indices = False )['count' ]
287293 return f"Number of documents: { format (count , ',' )} "
288294
289295 def read_data_with_scan (self ,
@@ -340,20 +346,23 @@ def read_data_with_scan(self,
340346 If the search fails or cancelled by the user.
341347 """
342348 try :
349+ if len (index ) == 0 :
350+ raise ValueError ('Provide at least one index or index alias name' )
343351 self .__validate_size (size = size )
344352 if "query" not in query .keys ():
345353 temp_query = query .copy ()
346354 query .clear ()
347355 query ["query" ] = temp_query
348- pr_bar = None
356+ pr_bar : tqdm . tqdm = None
349357
350358 scan_results = es_helpers .scan (self .elastic ,
351359 index = index ,
352360 query = query ,
353361 size = size ,
354362 request_timeout = request_timeout ,
355363 source = False ,
356- fields = include_fields )
364+ fields = include_fields ,
365+ allow_no_indices = False ,)
357366 all_mapped_results = []
358367 results = self .elastic .count (index = index , query = query ["query" ])
359368 pr_bar = tqdm .tqdm (scan_results , total = results ["count" ],
@@ -443,13 +452,15 @@ def read_data_with_scroll(self,
443452 value of `size` parameter.
444453 """
445454 try :
455+ if len (index ) == 0 :
456+ raise ValueError ('Provide at least one index or index alias name' )
446457 self .__validate_size (size = size )
447458 query = self .__extract_query (query = query )
448459 result_count = size
449460 all_mapped_results = []
450461 search_result = None
451- include_fields_map : Sequence [Mapping [str , Any ]] = include_fields \
452- if include_fields is not None else None
462+ include_fields_map : Union [ Sequence [Mapping [str , Any ]], None ] = \
463+ [{ "field" : field } for field in include_fields ] if include_fields is not None else None
453464
454465 pr_bar = tqdm .tqdm (desc = "CogStack retrieved..." ,
455466 disable = not show_progress , colour = 'green' )
@@ -462,6 +473,7 @@ def read_data_with_scroll(self,
462473 source = False ,
463474 scroll = "10m" ,
464475 timeout = f"{ request_timeout } s" ,
476+ allow_no_indices = False ,
465477 rest_total_hits_as_int = True )
466478
467479 pr_bar .total = search_result .body ['hits' ]['total' ]
@@ -470,6 +482,8 @@ def read_data_with_scroll(self,
470482 search_scroll_id = search_result .body ['_scroll_id' ]
471483 all_mapped_results .extend (self .__map_search_results (hits = hits ))
472484 pr_bar .update (len (hits ))
485+ if search_result ["_shards" ]["failed" ] > 0 :
486+ raise LookupError (search_result ["_shards" ]["failures" ])
473487
474488 while search_scroll_id and result_count == size :
475489 # Perform ES scroll request
@@ -559,14 +573,15 @@ def read_data_with_sorting(self,
559573 which can be used as a function parameter to continue the search.
560574 """
561575 try :
576+ if len (index ) == 0 :
577+ raise ValueError ('Provide at least one index or index alias name' )
562578 result_count = size
563579 all_mapped_results = []
564580 if sort is None :
565581 sort = {'id' : 'asc' }
566582 search_after_value = search_after
567- include_fields_map : Sequence [Mapping [str , Any ]] = include_fields \
568- if include_fields is not None \
569- else None
583+ include_fields_map : Union [Sequence [Mapping [str , Any ]], None ] = \
584+ [{"field" : field } for field in include_fields ] if include_fields is not None else None
570585
571586 self .__validate_size (size = size )
572587 query = self .__extract_query (query = query )
@@ -591,6 +606,7 @@ def read_data_with_sorting(self,
591606 search_after = search_after_value ,
592607 timeout = f"{ request_timeout } s" ,
593608 track_scores = True ,
609+ track_total_hits = True ,
594610 allow_no_indices = False ,
595611 rest_total_hits_as_int = True )
596612 hits = search_result ['hits' ]['hits' ]
@@ -599,6 +615,8 @@ def read_data_with_sorting(self,
599615 pr_bar .update (result_count )
600616 search_after_value = hits [- 1 ]['sort' ]
601617 pr_bar .total = pr_bar .total if pr_bar .total else search_result .body ['hits' ]['total' ]
618+ if search_result ["_shards" ]["failed" ] > 0 :
619+ raise LookupError (search_result ["_shards" ]["failures" ])
602620 except BaseException as err :
603621 if isinstance (err , KeyboardInterrupt ):
604622 pr_bar .bar_format = "%s{l_bar}%s{bar}%s{r_bar}" % ("\033 [0;33m" ,
@@ -619,12 +637,12 @@ def read_data_with_sorting(self,
619637
620638 def __extract_query (self , query : dict ):
621639 if "query" in query .keys ():
622- query = query ['query' ]
640+ return query ['query' ]
623641 return query
624642
625643 def __validate_size (self , size ):
626644 if size > 10000 :
627- raise ValueError ('Size must not be greater then 10000' )
645+ raise ValueError ('Size must not be greater than 10000' )
628646
629647 def __map_search_results (self , hits : Iterable ):
630648 hit : dict
0 commit comments