10 changes: 5 additions & 5 deletions images/pyspark-notebook/Dockerfile
@@ -41,11 +41,11 @@ ENV SPARK_OPTS="--driver-java-options=-Xms1024M --driver-java-options=-Xmx4096M
COPY setup_spark.py /opt/setup-scripts/

# Setup Spark
-RUN SPARK_VERSION="${spark_version}" \
-    HADOOP_VERSION="${hadoop_version}" \
-    SCALA_VERSION="${scala_version}" \
-    SPARK_DOWNLOAD_URL="${spark_download_url}" \
-    /opt/setup-scripts/setup_spark.py
+RUN /opt/setup-scripts/setup_spark.py \
+    --spark-version="${spark_version}" \
+    --hadoop-version="${hadoop_version}" \
+    --scala-version="${scala_version}" \
+    --spark-download-url="${spark_download_url}"

# Configure IPython system-wide
COPY ipython_kernel_config.py "/etc/ipython/"
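For reference, the rewritten RUN instruction drives the setup script entirely through explicit CLI flags, so a missing or misspelled value now fails at argument parsing instead of silently propagating as an unset environment variable. Below is a minimal sketch, not part of the PR, of an equivalent standalone invocation; all concrete values are placeholders, the download URL is assumed to be the Apache archive, and the script still expects to run as root with SPARK_HOME set per its header comment.

# Hedged sketch: mirrors the new RUN line so the script can be exercised
# outside the image build. Every concrete value below is a placeholder.
import subprocess

subprocess.run(
    [
        "/opt/setup-scripts/setup_spark.py",
        "--spark-version=3.5.0",  # placeholder; an empty value falls back to the latest release
        "--hadoop-version=3",  # placeholder
        "--scala-version=",  # placeholder; may be left empty
        "--spark-download-url=https://archive.apache.org/dist/spark/",  # assumed mirror
    ],
    check=True,
)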
29 changes: 17 additions & 12 deletions images/pyspark-notebook/setup_spark.py
@@ -4,9 +4,9 @@

# Requirements:
# - Run as the root user
-# - Required env variables: SPARK_HOME, HADOOP_VERSION, SPARK_DOWNLOAD_URL
-# - Optional env variables: SPARK_VERSION, SCALA_VERSION
+# - Required env variable: SPARK_HOME

+import argparse
import logging
import os
import subprocess
@@ -27,13 +27,10 @@ def get_all_refs(url: str) -> list[str]:
return [a["href"] for a in soup.find_all("a", href=True)]


-def get_spark_version() -> str:
+def get_latest_spark_version() -> str:
    """
-    If ${SPARK_VERSION} env variable is non-empty, simply returns it
-    Otherwise, returns the last stable version of Spark using spark archive
+    Returns the last stable version of Spark using spark archive
    """
-    if (version := os.environ["SPARK_VERSION"]) != "":
-        return version
LOGGER.info("Downloading Spark versions information")
all_refs = get_all_refs("https://archive.apache.org/dist/spark/")
stable_versions = [
@@ -106,12 +103,20 @@ def configure_spark(spark_dir_name: str, spark_home: Path) -> None:
if __name__ == "__main__":
logging.basicConfig(level=logging.INFO)

-    spark_version = get_spark_version()
+    arg_parser = argparse.ArgumentParser()
+    arg_parser.add_argument("--spark-version", required=True)
+    arg_parser.add_argument("--hadoop-version", required=True)
+    arg_parser.add_argument("--scala-version", required=True)
+    arg_parser.add_argument("--spark-download-url", type=Path, required=True)
+    args = arg_parser.parse_args()
+
+    args.spark_version = args.spark_version or get_latest_spark_version()

spark_dir_name = download_spark(
-        spark_version=spark_version,
-        hadoop_version=os.environ["HADOOP_VERSION"],
-        scala_version=os.environ["SCALA_VERSION"],
-        spark_download_url=Path(os.environ["SPARK_DOWNLOAD_URL"]),
+        spark_version=args.spark_version,
+        hadoop_version=args.hadoop_version,
+        scala_version=args.scala_version,
+        spark_download_url=args.spark_download_url,
)
configure_spark(
spark_dir_name=spark_dir_name, spark_home=Path(os.environ["SPARK_HOME"])
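One subtlety this diff preserves: every flag is declared required=True, yet the build can still pass an empty --spark-version, because argparse accepts an explicit empty string and the later `or` expression falls back to get_latest_spark_version(), matching the old behaviour of an empty SPARK_VERSION env variable. A small, self-contained sketch of that pattern follows; latest_version() is a stand-in for the real archive lookup, and the value it returns is a placeholder.

# Sketch of the required-but-possibly-empty flag plus fallback used above;
# latest_version() stands in for get_latest_spark_version().
import argparse


def latest_version() -> str:
    return "3.5.0"  # placeholder result


parser = argparse.ArgumentParser()
parser.add_argument("--spark-version", required=True)

# An explicit empty value still satisfies required=True...
args = parser.parse_args(["--spark-version="])
# ...so the fallback expression decides the effective version.
args.spark_version = args.spark_version or latest_version()
print(args.spark_version)  # prints "3.5.0"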