Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 23 additions & 6 deletions dnastack/cli/commands/explorer/questions/commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,13 @@ def describe_question(question_id: str, output: str, context: Optional[str], end
arg_names=['--output-file'],
help='Output file path for results'
),
ArgumentSpec(
name='local_federated',
arg_names=['--local-federated'],
help='Query collections directly via local federation instead of using server-side federation',
type=bool,
default=False
),
DATA_OUTPUT_ARG,
CONTEXT_ARG,
SINGLE_ENDPOINT_ID_ARG,
Expand All @@ -112,13 +119,15 @@ def ask_question(
args: tuple,
collections: Optional[JsonLike],
output_file: Optional[str],
local_federated: bool,
output: str,
context: Optional[str],
endpoint_id: Optional[str]
):
"""Ask a federated question with the provided parameters"""
trace = Span()
client = get_explorer_client(context=context, endpoint_id=endpoint_id, trace=trace)


# Parse collections if provided
if collections:
Expand Down Expand Up @@ -162,12 +171,20 @@ def ask_question(
collection_ids = [col.id for col in question.collections]

# Execute the question
results_iter = client.ask_federated_question(
question_id=question_name,
inputs=inputs,
collections=collection_ids,
trace=trace
)
if local_federated:
results_iter = client.ask_question_local_federated(
federated_question_id=question_name,
inputs=inputs,
collections=collection_ids,
trace=trace
)
else:
results_iter = client.ask_federated_question(
question_id=question_name,
inputs=inputs,
collections=collection_ids,
trace=trace
)

# Collect results
results = list(results_iter)
Expand Down
240 changes: 238 additions & 2 deletions dnastack/client/explorer/client.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
from typing import List, Optional, Dict, Any, TYPE_CHECKING
from concurrent.futures import ThreadPoolExecutor, as_completed
import time

if TYPE_CHECKING:
from dnastack.client.explorer.models import FederatedQuestion
Expand All @@ -10,7 +12,8 @@
from dnastack.client.explorer.models import (
FederatedQuestion,
FederatedQuestionListResponse,
FederatedQuestionQueryRequest
FederatedQuestionQueryRequest,
QuestionCollection
)
from dnastack.client.result_iterator import ResultLoader, InactiveLoaderError, ResultIterator
from dnastack.client.service_registry.models import ServiceType
Expand Down Expand Up @@ -136,6 +139,55 @@ def ask_federated_question(
)
)

def ask_question_local_federated(
self,
federated_question_id: str,
inputs: Dict[str, str],
collections: Optional[List[str]] = None,
trace: Optional[Span] = None
) -> 'ResultIterator[Dict[str, Any]]':
"""
Query collections directly via local federation instead of server-side federation.

Args:
federated_question_id: The ID of the federated question to ask
inputs: Dictionary of parameter name -> value mappings
collections: Optional list of collection IDs to query. If None, all collections are used.
trace: Optional tracing span

Returns:
ResultIterator[Dict[str, Any]]: Iterator over aggregated query results in federated format
"""
# Get federated question metadata to obtain per-collection question IDs
question = self.describe_federated_question(federated_question_id, trace=trace)

# Filter collections if specified
if collections is not None:
# Create a map of collection ID to QuestionCollection for filtering
collection_map = {col.id: col for col in question.collections}
target_collections = [collection_map[cid] for cid in collections if cid in collection_map]

# Check for invalid collection IDs
invalid_ids = [cid for cid in collections if cid not in collection_map]
if invalid_ids:
raise ClientError(
response=None,
trace_context=trace,
message=f"Invalid collection IDs for question '{federated_question_id}': {', '.join(invalid_ids)}"
)
else:
target_collections = question.collections

# Create the result loader for local federation
return ResultIterator(
LocalFederatedQuestionQueryResultLoader(
explorer_client=self,
collections=target_collections,
inputs=inputs,
trace=trace
)
)


class FederatedQuestionListResultLoader(ResultLoader):
"""
Expand Down Expand Up @@ -248,4 +300,188 @@ def load(self) -> List[Dict[str, Any]]:
raise ClientError(e.response, e.trace, "Invalid question parameters")
else:

raise ClientError(e.response, e.trace, "Failed to execute federated question")
raise ClientError(e.response, e.trace, "Failed to execute federated question")


class LocalFederatedQuestionQueryResultLoader(ResultLoader):
"""
Result loader for local federation queries that queries each collection directly.
"""

def __init__(
self,
explorer_client: 'ExplorerClient',
collections: List[QuestionCollection],
inputs: Dict[str, str],
trace: Optional[Span] = None
):
self.__explorer_client = explorer_client
self.__collections = collections
self.__inputs = inputs
self.__trace = trace
self.__loaded = False

def has_more(self) -> bool:
return not self.__loaded

def load(self) -> List[Dict[str, Any]]:
if self.__loaded:
raise InactiveLoaderError("LocalFederatedQuestionQueryResultLoader")

# Execute parallel queries to each collection
with ThreadPoolExecutor() as executor:
# Submit all queries
future_to_collection = {
executor.submit(
self._query_single_collection,
collection
): collection
for collection in self.__collections
}

# Collect results
results = []
for future in as_completed(future_to_collection):
result = future.result()
results.append(result)

# Return results directly as a list to match federated format
self.__loaded = True
return results # Return as list to match federated endpoint format

def _query_single_collection(self, collection: QuestionCollection) -> Dict[str, Any]:
"""
Query a single collection and return the result in federated format.
Handles Data Connect pagination by following next_page_url links.
"""
start_time = time.time()

# Build the collection-specific endpoint URL
# Note: explorer URL already ends with /api/, so we don't need to add it again
initial_url = urljoin(
self.__explorer_client.url,
f"collections/{collection.slug}/questions/{collection.question_id}/query"
)

try:
# Collect all data across all pages
all_data = []
data_model = None
current_url = None
visited_urls = []

with self.__explorer_client._session as session:
# First request - POST with params to initiate query
response = session.post(
initial_url,
json={"params": self.__inputs},
trace_context=self.__trace
)
visited_urls.append(initial_url)

while True:
# Parse the Data Connect response
table_data = response.json()

# Capture data model from first response
if data_model is None and 'data_model' in table_data:
data_model = table_data['data_model']

# Add data from this page
if 'data' in table_data and isinstance(table_data['data'], list):
# Add collection_name to each item
for item in table_data['data']:
item['collection_name'] = collection.name
all_data.extend(table_data['data'])

# Check for next page
pagination = table_data.get('pagination')
if pagination and pagination.get('next_page_url'):
current_url = pagination['next_page_url']
# Handle relative URLs
if current_url and not current_url.startswith(('http://', 'https://')):
current_url = urljoin(visited_urls[-1], current_url)

# Prevent infinite loops
if current_url in visited_urls:
break

# Follow pagination with GET request
response = session.get(
current_url,
trace_context=self.__trace
)
visited_urls.append(current_url)
else:
# No more pages
break

# Build final aggregated response
aggregated_table_data = {
"data": all_data,
"data_model": data_model,
"pagination": None # No pagination in aggregated result
}

# Return in federated format
return {
"collectionId": collection.id,
"collectionSlug": collection.slug,
"results": aggregated_table_data,
"error": None,
"failureInfo": None
}

except HttpError as e:
# Calculate response time
response_time_ms = int((time.time() - start_time) * 1000)

# Determine failure reason
status_code = e.response.status_code if e.response else None
if status_code == 401:
reason = "UNAUTHORIZED"
message = f"Authentication required for collection {collection.name}"
elif status_code == 403:
reason = "FORBIDDEN"
message = f"Access denied to collection {collection.name}"
elif status_code == 404:
reason = "NOT_FOUND"
message = f"Question not found in collection {collection.name}"
elif status_code == 400:
reason = "BAD_REQUEST"
message = f"Invalid parameters for collection {collection.name}"
elif status_code and status_code >= 500:
reason = "SERVER_ERROR"
message = f"Server error for collection {collection.name}"
else:
reason = "UNKNOWN"
message = str(e)

# Return error in federated format
return {
"collectionId": collection.id,
"collectionSlug": collection.slug,
"results": None,
"error": message,
"failureInfo": {
"reason": reason,
"message": message,
"responseTimeMs": response_time_ms
}
}

except Exception as e:
# Handle non-HTTP errors
response_time_ms = int((time.time() - start_time) * 1000)

return {
"collectionId": collection.id,
"collectionSlug": collection.slug,
"results": None,
"error": str(e),
"failureInfo": {
"reason": "CLIENT_ERROR",
"message": str(e),
"responseTimeMs": response_time_ms
}
}
Loading
Loading