Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 21 additions & 1 deletion components/renku_data_services/notebooks/core_sessions.py
Original file line number Diff line number Diff line change
Expand Up @@ -260,6 +260,7 @@ async def get_data_sources(
nb_config: NotebooksConfig,
user: AnonymousAPIUser | AuthenticatedAPIUser,
server_name: str,
session_location: SessionLocation,
data_connectors_stream: AsyncIterator[DataConnectorWithSecrets],
work_dir: PurePosixPath,
data_connectors_overrides: list[SessionDataConnectorOverride],
Expand All @@ -286,8 +287,23 @@ async def get_data_sources(
secrets={str(secret.secret_id): secret.name for secret in dc.secrets},
storage_class=nb_config.cloud_storage.storage_class,
)
if len(dc.secrets) > 0:
if len(dc.secrets) > 0 and session_location == SessionLocation.local:
dcs_secrets[str(dc.data_connector.id)] = dc.secrets
elif len(dc.secrets) > 0 and session_location == SessionLocation.remote:
# NOTE: special handling for remote sessions; collect all secrets in the "all" key
dcs_secrets_all = dcs_secrets.get("all", [])
dcs_secrets_all.extend(
[
DataConnectorSecret(
name=f"{str(dc.data_connector.id)}_{dcs.name}",
user_id=dcs.user_id,
data_connector_id=dcs.data_connector_id,
secret_id=dcs.secret_id,
)
for dcs in dc.secrets
]
)
dcs_secrets["all"] = dcs_secrets_all
if isinstance(user, AuthenticatedAPIUser) and len(dcs_secrets) > 0:
secret_key = await user_repo.get_or_create_user_secret_key(user)
user_secret_key = get_encryption_key(secret_key.encode(), user.id.encode()).decode("utf-8")
Expand All @@ -305,6 +321,7 @@ async def get_data_sources(
# NOTE: if 'skip' is true, we do not mount that data connector
if dco.skip:
del dcs[dc_id]
dcs_secrets[dc_id] = [] # Unset any data connector secret
continue
if dco.target_path is not None and not PurePosixPath(dco.target_path).is_absolute():
dco.target_path = (work_dir / dco.target_path).as_posix()
Expand Down Expand Up @@ -686,6 +703,8 @@ def get_remote_env(
SessionEnvItem(name="RSC_REMOTE_KIND", value=remote.kind.value),
SessionEnvItem(name="RSC_FIRECREST_API_URL", value=remote.api_url),
SessionEnvItem(name="RSC_FIRECREST_SYSTEM_NAME", value=remote.system_name),
# TODO: remove fake start
SessionEnvItem(name="RSC_FAKE_START", value="true"),
]
if remote.partition:
env.append(SessionEnvItem(name="RSC_FIRECREST_PARTITION", value=remote.partition))
Expand Down Expand Up @@ -787,6 +806,7 @@ async def start_session(
nb_config=nb_config,
server_name=server_name,
user=user,
session_location=session_location,
data_connectors_stream=data_connectors_stream,
work_dir=work_dir,
data_connectors_overrides=launch_request.data_connectors_overrides or [],
Expand Down