diff --git a/docs/api.md b/docs/api.md index d37fed9..ba783ec 100644 --- a/docs/api.md +++ b/docs/api.md @@ -128,6 +128,26 @@ Downloads a specified encrypted file, decrypts it and then behaves identically t The request body for this route is the same as for `POST /_matrix/media_proxy/unstable/download_encrypted`. +### `POST /_matrix/media_proxy/unstable/scan_file` + +Performs a scan on a file body without uploading to Matrix. This request takes a multi-part / form data +body. + +Response format: + +| Parameter | Type | Description | +|-----------|------|--------------------------------------------------------------------| +| `body` | [Blob](https://developer.mozilla.org/en-US/docs/Web/API/Blob) | The file body. | +| `file` | EncryptedFile as JSON string | The metadata (decryption key) of an encrypted file. Follows the format of the `EncryptedFile` structure from the [Matrix specification](https://spec.matrix.org/v1.2/client-server-api/#extensions-to-mroommessage-msgtypes). Only required if the file is encrypted. | + +Example: + +```json +{ + "clean": false, + "info": "***VIRUS DETECTED***" +} +``` ### `GET /_matrix/media_proxy/unstable/public_key` diff --git a/src/matrix_content_scanner/httpserver.py b/src/matrix_content_scanner/httpserver.py index d1189d2..311dde8 100644 --- a/src/matrix_content_scanner/httpserver.py +++ b/src/matrix_content_scanner/httpserver.py @@ -109,6 +109,7 @@ def _build_app(self) -> web.Application: [ web.get("/scan" + _MEDIA_PATH_REGEXP, scan_handler.handle_plain), web.post("/scan_encrypted", scan_handler.handle_encrypted), + web.post("/scan_file", scan_handler.handle_file), web.get( "/download" + _MEDIA_PATH_REGEXP, download_handler.handle_plain ), diff --git a/src/matrix_content_scanner/scanner/scanner.py b/src/matrix_content_scanner/scanner/scanner.py index d8474b1..c6bd8ec 100644 --- a/src/matrix_content_scanner/scanner/scanner.py +++ b/src/matrix_content_scanner/scanner/scanner.py @@ -7,12 +7,14 @@ import logging import os import subprocess +import uuid from asyncio import Future from pathlib import Path from typing import TYPE_CHECKING, Dict, List, Optional, Tuple import attr import magic +from aiohttp import BodyPartReader from cachetools import TTLCache from canonicaljson import encode_canonical_json from humanfriendly import format_size @@ -309,6 +311,55 @@ async def _scan_file( return media + async def scan_file_on_disk( + self, file_path: str, metadata: Optional[JsonDict] = None + ) -> None: + """Scan a file that already exists on disk. The file will be deleted after scanning. + + This does not cache the result. + + Args: + file_path: The full file path to the source file. + metadata: The metadata attached to the file (e.g. decryption key), or None if + the file isn't encrypted. + + Raises: + FileDirtyError if the result of the scan said that the file is dirty. + """ + scan_filename = file_path + if metadata is not None: + with open(file_path, "rb") as f: + content = f.read() + # If the file is encrypted, we need to decrypt it before we can scan it. + media_content = self._decrypt_file(content, metadata) + scan_filename = self._write_file_to_disk( + str(uuid.uuid4()), media_content + ) + + # Remove source file now we've decrypted it. + removal_command_parts = self._removal_command.split() + removal_command_parts.append(file_path) + subprocess.run(removal_command_parts) + + try: + # Check the file's MIME type to see if it's allowed. + self._check_mimetype(scan_filename) + + # Scan the file and see if the result is positive or negative. + exit_code = await self._run_scan(scan_filename) + result = exit_code == 0 + + # Delete the file now that we've scanned it. + logger.info("Scan has finished, removing file") + finally: + removal_command_parts = self._removal_command.split() + removal_command_parts.append(scan_filename) + subprocess.run(removal_command_parts) + + # Raise an error if the result isn't clean. + if result is False: + raise FileDirtyError(cacheable=False) + async def _scan_media( self, media: MediaDescription, @@ -340,29 +391,32 @@ async def _scan_media( # If the file is encrypted, we need to decrypt it before we can scan it. media_content = self._decrypt_file(media_content, metadata) - # Check the file's MIME type to see if it's allowed. - self._check_mimetype(media_content) - # Write the file to disk. file_path = self._write_file_to_disk(media_path, media_content) - # Scan the file and see if the result is positive or negative. - exit_code = await self._run_scan(file_path) - result = exit_code == 0 + try: + # Check the file's MIME type to see if it's allowed. + self._check_mimetype(file_path) - # If the exit code isn't part of the ones we should ignore, cache the result. - cacheable = True - if exit_code in self._exit_codes_to_ignore: - logger.info( - "Scan returned exit code %d which must not be cached", exit_code - ) - cacheable = False + # Scan the file and see if the result is positive or negative. + exit_code = await self._run_scan(file_path) + result = exit_code == 0 + + # If the exit code isn't part of the ones we should ignore, cache the result. + cacheable = True + if exit_code in self._exit_codes_to_ignore: + logger.info( + "Scan returned exit code %d which must not be cached", exit_code + ) + cacheable = False - # Delete the file now that we've scanned it. - logger.info("Scan has finished, removing file") - removal_command_parts = self._removal_command.split() - removal_command_parts.append(file_path) - subprocess.run(removal_command_parts) + logger.info("Scan has finished") + finally: + # Delete the file now that we've scanned it. + logger.info("Removing file") + removal_command_parts = self._removal_command.split() + removal_command_parts.append(file_path) + subprocess.run(removal_command_parts) # Raise an error if the result isn't clean. if result is False: @@ -434,6 +488,30 @@ def _decrypt_file(self, body: bytes, metadata: JsonDict) -> bytes: info=str(e), ) + async def write_multipart_to_disk(self, multipart: BodyPartReader) -> str: + """ + Writes a multipart file body to the store directory. + + Returns: + The full file path to the file. + """ + filename = str(uuid.uuid4()) + # Figure out the full absolute path for this file. + full_path = self._store_directory.joinpath(filename).resolve() + logger.info("Writing multipart file to %s", full_path) + + # Create any directory we need. + os.makedirs(full_path.parent, exist_ok=True) + + with open(full_path, "wb") as fp: + while True: + chunk = await multipart.read_chunk() + if not chunk: + break + fp.write(chunk) + + return str(full_path) + def _write_file_to_disk(self, media_path: str, body: bytes) -> str: """Writes the given content to disk. The final file name will be a concatenation of `temp_directory` and the media's `server_name/media_id` path. @@ -495,16 +573,15 @@ async def _run_scan(self, file_name: str) -> int: return retcode - def _check_mimetype(self, media_content: bytes) -> None: - """Detects the MIME type of the provided bytes, and checks that this type is allowed + def _check_mimetype(self, filepath: str) -> None: + """Detects the MIME type of the provided file, and checks that this type is allowed (if an allow list is provided in the configuration) Args: - media_content: The file's content. If the file is encrypted, this is its - decrypted content. + filepath: The full file path. Raises: FileMimeTypeForbiddenError if one of the checks fail. """ - detected_mimetype = magic.from_buffer(media_content, mime=True) + detected_mimetype = magic.from_file(filepath, mime=True) logger.debug("Detected MIME type for file is %s", detected_mimetype) # If there's an allow list for MIME types, check that the MIME type that's been diff --git a/src/matrix_content_scanner/servlets/__init__.py b/src/matrix_content_scanner/servlets/__init__.py index 53e4ff7..37e070e 100644 --- a/src/matrix_content_scanner/servlets/__init__.py +++ b/src/matrix_content_scanner/servlets/__init__.py @@ -180,6 +180,31 @@ async def get_media_metadata_from_request( return media_path, metadata +async def get_media_metadata_from_filebody( + file_body: JsonDict, + crypto_handler: crypto.CryptoHandler, +) -> JsonDict: + """Extracts, optionally decrypts, and validates encrypted file metadata from a + request body. + + Args: + request: The request to extract the data from. + crypto_handler: The crypto handler to use if we need to decrypt an Olm-encrypted + body. + + Raises: + ContentScannerRestError(400) if the request's body is None or if the metadata + didn't pass schema validation. + """ + metadata = _metadata_from_body(file_body, crypto_handler) + + validate_encrypted_file_metadata(metadata) + + # URL parameter is ignored. + + return metadata + + def _metadata_from_body( body: JsonDict, crypto_handler: crypto.CryptoHandler ) -> JsonDict: diff --git a/src/matrix_content_scanner/servlets/scan.py b/src/matrix_content_scanner/servlets/scan.py index 461581e..833a1c0 100644 --- a/src/matrix_content_scanner/servlets/scan.py +++ b/src/matrix_content_scanner/servlets/scan.py @@ -2,12 +2,18 @@ # # SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial # Please see LICENSE files in the repository root for full details. +import json from typing import TYPE_CHECKING, Optional, Tuple -from aiohttp import web +from aiohttp import BodyPartReader, web -from matrix_content_scanner.servlets import get_media_metadata_from_request, web_handler -from matrix_content_scanner.utils.errors import FileDirtyError +from matrix_content_scanner.servlets import ( + get_media_metadata_from_filebody, + get_media_metadata_from_request, + web_handler, +) +from matrix_content_scanner.utils.constants import ErrCode +from matrix_content_scanner.utils.errors import ContentScannerRestError, FileDirtyError from matrix_content_scanner.utils.types import JsonDict if TYPE_CHECKING: @@ -51,3 +57,55 @@ async def handle_encrypted(self, request: web.Request) -> Tuple[int, JsonDict]: return await self._scan_and_format( media_path, metadata, auth_header=request.headers.get("Authorization") ) + + @web_handler + async def handle_file(self, request: web.Request) -> Tuple[int, JsonDict]: + """Handles GET requests to ../scan_file""" + try: + reader = await request.multipart() + except Exception: + raise ContentScannerRestError( + 400, + ErrCode.MALFORMED_MULTIPART, + "Request body was not a multipart body.", + ) + + body = None + metadata: Optional[JsonDict] = None + + # Iterate to find the fields. + while True: + field = await reader.next() + if (metadata and body) or field is None: + break + if not isinstance(field, BodyPartReader): + continue + if field.name == "file": + try: + file_json = await field.json() + if file_json is None: + raise Exception("'file' field is empty") + except json.decoder.JSONDecodeError as e: + raise ContentScannerRestError(400, ErrCode.MALFORMED_JSON, str(e)) + + metadata = await get_media_metadata_from_filebody( + file_json, self._crypto_handler + ) + elif field.name == "body": + body = await self._scanner.write_multipart_to_disk(field) + + if body is None: + raise ContentScannerRestError( + 400, ErrCode.MALFORMED_MULTIPART, "Missing 'body' field" + ) + + # 'metadata' is optional + + try: + await self._scanner.scan_file_on_disk(body, metadata) + except FileDirtyError as e: + res = {"clean": False, "info": e.info} + else: + res = {"clean": True, "info": "File is clean"} + + return 200, res diff --git a/src/matrix_content_scanner/utils/constants.py b/src/matrix_content_scanner/utils/constants.py index aac0b06..ab8931c 100644 --- a/src/matrix_content_scanner/utils/constants.py +++ b/src/matrix_content_scanner/utils/constants.py @@ -32,3 +32,5 @@ class ErrCode(str, Enum): MALFORMED_JSON = "MCS_MALFORMED_JSON" # The Mime type is not in the allowed list of Mime types. MIME_TYPE_FORBIDDEN = "MCS_MIME_TYPE_FORBIDDEN" + # The body was not a multipart. + MALFORMED_MULTIPART = "MCS_MALFORMED_MULTIPART"