|
1 | 1 | import argparse
|
| 2 | +import logging |
2 | 3 | import os
|
3 | 4 | import pathlib
|
4 | 5 | import shutil
|
|
10 | 11 |
|
11 | 12 | from envyaml import EnvYAML
|
12 | 13 |
|
| 14 | +# Configure logging |
| 15 | +logging.basicConfig( |
| 16 | + format="%(asctime)s - %(levelname)s - %(message)s", |
| 17 | + level=logging.INFO, |
| 18 | + datefmt="%Y-%m-%d %H:%M:%S", |
| 19 | +) |
| 20 | +logging.Formatter.converter = time.gmtime |
13 | 21 |
|
14 | 22 | parser = argparse.ArgumentParser()
|
15 | 23 | subparsers = parser.add_subparsers(help="Command to run.", dest="command")
|
|
20 | 28 | )
|
21 | 29 |
|
22 | 30 | # -------------Launch------------------
|
23 |
| -parser_launch = subparsers.add_parser( |
24 |
| - "launch", help="Launch Azure resources" |
25 |
| -) |
| 31 | +parser_launch = subparsers.add_parser("launch", help="Launch Azure resources") |
26 | 32 |
|
27 | 33 | # -------------Start-------------------
|
28 | 34 | parser_start = subparsers.add_parser(
|
29 | 35 | "start", help="Start services using specificed start up commands"
|
30 | 36 | )
|
31 | 37 |
|
32 | 38 | # -------------Upload----------------
|
33 |
| -parser_upload = subparsers.add_parser( |
34 |
| - "upload", help="Encrypt and upload data." |
35 |
| -) |
| 39 | +parser_upload = subparsers.add_parser("upload", help="Encrypt and upload data.") |
36 | 40 |
|
37 | 41 | # -------------Run--------------------
|
38 | 42 | parser_run = subparsers.add_parser(
|
|
45 | 49 | )
|
46 | 50 |
|
47 | 51 | # -------------Stop-------------------
|
48 |
| -parser_stop = subparsers.add_parser( |
49 |
| - "stop", help="Stop previously started service" |
50 |
| -) |
| 52 | +parser_stop = subparsers.add_parser("stop", help="Stop previously started service") |
51 | 53 |
|
52 | 54 | # -------------Teardown---------------
|
53 |
| -parser_teardown = subparsers.add_parser( |
54 |
| - "teardown", help="Teardown Azure resources" |
55 |
| -) |
| 55 | +parser_teardown = subparsers.add_parser("teardown", help="Teardown Azure resources") |
56 | 56 |
|
57 | 57 | if __name__ == "__main__":
|
58 | 58 | oc_config = os.environ.get("MC2_CONFIG")
|
59 | 59 | if not oc_config:
|
60 |
| - raise Exception("Please set the environment variable `MC2_CONFIG` to the path of your config file") |
| 60 | + raise Exception( |
| 61 | + "Please set the environment variable `MC2_CONFIG` to the path of your config file" |
| 62 | + ) |
61 | 63 |
|
62 | 64 | mc2.set_config(oc_config)
|
63 | 65 | args = parser.parse_args()
|
|
75 | 77 |
|
76 | 78 | # If the nodes have been manually specified, don't do anything
|
77 | 79 | if config_launch.get("head") or config_launch.get("workers"):
|
78 |
| - print("Node addresses have been manually specified in the config "\ |
79 |
| - "... doing nothing") |
| 80 | + logging.warning( |
| 81 | + "Node addresses have been manually specified in the config " |
| 82 | + "... doing nothing" |
| 83 | + ) |
80 | 84 | quit()
|
81 | 85 |
|
82 | 86 | # Create resource group
|
|
123 | 127 |
|
124 | 128 | encrypted_data = [d + ".enc" for d in data]
|
125 | 129 |
|
126 |
| - print("Encrypting and uploading data...") |
127 |
| - |
128 | 130 | dst_dir = config_upload.get("dst", "")
|
129 | 131 | for i in range(len(data)):
|
130 | 132 | # Encrypt data
|
131 | 133 | if enc_format == "xgb":
|
132 | 134 | mc2.encrypt_data(data[i], encrypted_data[i], None, "xgb")
|
133 | 135 | elif enc_format == "sql":
|
134 | 136 | if schemas is None:
|
135 |
| - raise Exception("Please specify a schema when uploading data for Opaque SQL") |
| 137 | + raise Exception( |
| 138 | + "Please specify a schema when uploading data for Opaque SQL" |
| 139 | + ) |
136 | 140 | # Remove temporary files from a previous run
|
137 | 141 | if os.path.exists(encrypted_data[i]):
|
138 | 142 | if os.path.isdir(encrypted_data[i]):
|
|
150 | 154 | if dst_dir:
|
151 | 155 | remote_path = os.path.join(dst_dir, filename)
|
152 | 156 | mc2.upload_file(encrypted_data[i], remote_path, use_azure)
|
153 |
| - print("Uploaded data to {}".format(remote_path)) |
154 | 157 |
|
155 | 158 | # Remove temporary directory
|
156 | 159 | if os.path.isdir(encrypted_data[i]):
|
|
163 | 166 | script = config_run["script"]
|
164 | 167 |
|
165 | 168 | if config_run["compute"] == "xgb":
|
166 |
| - print("run() unimplemented for secure-xgboost") |
| 169 | + logging.error("run() unimplemented for secure-xgboost") |
167 | 170 | quit()
|
168 | 171 | elif config_run["compute"] == "sql":
|
169 | 172 | mc2.configure_job(config)
|
|
183 | 186 | remote_results = config_download.get("src", [])
|
184 | 187 | local_results_dir = config_download["dst"]
|
185 | 188 |
|
186 |
| - print("Downloading and decrypting data") |
187 |
| - |
188 | 189 | # Create the local results directory if it doesn't exist
|
189 | 190 | if not os.path.exists(local_results_dir):
|
190 | 191 | pathlib.Path(local_results_dir).mkdir(parents=True, exist_ok=True)
|
|
195 | 196 |
|
196 | 197 | # Fetch file
|
197 | 198 | mc2.download_file(remote_result, local_result, use_azure)
|
198 |
| - print("Downloaded result to ", local_result) |
199 | 199 |
|
200 | 200 | # Decrypt data
|
201 | 201 | if enc_format == "xgb":
|
202 | 202 | mc2.decrypt_data(local_result, local_result + ".dec", "xgb")
|
203 |
| - print("Decrypted result saved to ", local_result + ".dec") |
204 | 203 | elif enc_format == "sql":
|
205 | 204 | mc2.decrypt_data(local_result, local_result + ".dec", "sql")
|
206 |
| - print("Decrypted result saved to ", local_result + ".dec") |
207 | 205 | else:
|
208 | 206 | raise Exception("Specified format {} not supported".format(enc_format))
|
209 | 207 |
|
|
213 | 211 | os.remove(local_result)
|
214 | 212 |
|
215 | 213 | elif args.command == "stop":
|
216 |
| - print("Currently unsupported") |
| 214 | + logging.error("`opaque stop` is currently unsupported") |
217 | 215 | pass
|
218 | 216 |
|
219 | 217 | elif args.command == "teardown":
|
220 | 218 | config_teardown = config["teardown"]
|
221 | 219 |
|
222 | 220 | # If the nodes have been manually specified, don't do anything
|
223 | 221 | if config["launch"].get("head") or config["launch"].get("workers"):
|
224 |
| - print("Node addresses have been manually specified in the config "\ |
225 |
| - "... doing nothing") |
| 222 | + logging.warning( |
| 223 | + "Node addresses have been manually specified in the config " |
| 224 | + "... doing nothing" |
| 225 | + ) |
226 | 226 | quit()
|
227 | 227 |
|
228 | 228 | delete_container = config_teardown.get("container")
|
|
240 | 240 | delete_resource_group = config_teardown.get("resource_group")
|
241 | 241 | if delete_resource_group:
|
242 | 242 | mc2.delete_resource_group()
|
| 243 | + |
| 244 | + else: |
| 245 | + logging.error( |
| 246 | + "Unsupported command specified. Please type `opaque -h` for a list of supported commands." |
| 247 | + ) |
0 commit comments