Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 20 additions & 0 deletions docs/api.md
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,26 @@ Downloads a specified encrypted file, decrypts it and then behaves identically t
The request body for this route is the same as for
`POST /_matrix/media_proxy/unstable/download_encrypted`.

### `POST /_matrix/media_proxy/unstable/scan_file`

Performs a scan on a file body without uploading to Matrix. This request takes a multi-part / form data
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

just thinking this might be more easily recognisable as a literal content encoding name:

Suggested change
Performs a scan on a file body without uploading to Matrix. This request takes a multi-part / form data
Performs a scan on a file body without uploading to Matrix. This request takes a `multipart/form-data`

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

https://developer.mozilla.org/en-US/docs/Web/HTTP/Reference/Headers/Content-Disposition#as_a_header_for_a_multipart_body is probably the best link I can find on MDN that describes how this is encoded. A bit weak but might be better than nothing

body.

Response format:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this looks like the request format; meanwhile we don't specify the response format except through example (which is obvious enough that it's probably fine, don't get me wrong)


| Parameter | Type | Description |
|-----------|------|--------------------------------------------------------------------|
| `body` | [Blob](https://developer.mozilla.org/en-US/docs/Web/API/Blob) | The file body. |
| `file` | EncryptedFile as JSON string | The metadata (decryption key) of an encrypted file. Follows the format of the `EncryptedFile` structure from the [Matrix specification](https://spec.matrix.org/v1.2/client-server-api/#extensions-to-mroommessage-msgtypes). Only required if the file is encrypted. |
Comment on lines +140 to +141
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

unimportant: I kinda find these param names a smidge confusing; why is the file not the file? ;p


Example:

```json
{
"clean": false,
"info": "***VIRUS DETECTED***"
}
```

### `GET /_matrix/media_proxy/unstable/public_key`

Expand Down
1 change: 1 addition & 0 deletions src/matrix_content_scanner/httpserver.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,7 @@ def _build_app(self) -> web.Application:
[
web.get("/scan" + _MEDIA_PATH_REGEXP, scan_handler.handle_plain),
web.post("/scan_encrypted", scan_handler.handle_encrypted),
web.post("/scan_file", scan_handler.handle_file),
web.get(
"/download" + _MEDIA_PATH_REGEXP, download_handler.handle_plain
),
Expand Down
123 changes: 100 additions & 23 deletions src/matrix_content_scanner/scanner/scanner.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,14 @@
import logging
import os
import subprocess
import uuid
from asyncio import Future
from pathlib import Path
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple

import attr
import magic
from aiohttp import BodyPartReader
from cachetools import TTLCache
from canonicaljson import encode_canonical_json
from humanfriendly import format_size
Expand Down Expand Up @@ -309,6 +311,55 @@ async def _scan_file(

return media

async def scan_file_on_disk(
self, file_path: str, metadata: Optional[JsonDict] = None
) -> None:
"""Scan a file that already exists on disk. The file will be deleted after scanning.

This does not cache the result.

Args:
file_path: The full file path to the source file.
metadata: The metadata attached to the file (e.g. decryption key), or None if
the file isn't encrypted.

Raises:
FileDirtyError if the result of the scan said that the file is dirty.
"""
scan_filename = file_path
if metadata is not None:
with open(file_path, "rb") as f:
content = f.read()
Comment on lines +330 to +332
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

hmm.. is there any point in writing the encrypted file to disk, if we're just going to read it all back in to memory again?

If we can have clients send the encryption metadata as the first body part, that would mean we could do the right thing a little easier.

Although maybe it would be better if we could decrypt files without having them all in memory

That said: if this is all stuff that was a problem before you, it's also fine to leave it

# If the file is encrypted, we need to decrypt it before we can scan it.
media_content = self._decrypt_file(content, metadata)
scan_filename = self._write_file_to_disk(
str(uuid.uuid4()), media_content
)

# Remove source file now we've decrypted it.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

should this also be done in a finally block if the file fails to decrypt?

removal_command_parts = self._removal_command.split()
removal_command_parts.append(file_path)
subprocess.run(removal_command_parts)
Comment on lines +340 to +342
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

would be quite tempted to pull this out into its own function, to make it easier to modify the removal logic in future

(and if we're being pedantic, really this should be non-blocking/async, but I appreciate this is probably what was already here)


try:
# Check the file's MIME type to see if it's allowed.
self._check_mimetype(scan_filename)

# Scan the file and see if the result is positive or negative.
exit_code = await self._run_scan(scan_filename)
result = exit_code == 0

# Delete the file now that we've scanned it.
logger.info("Scan has finished, removing file")
finally:
removal_command_parts = self._removal_command.split()
removal_command_parts.append(scan_filename)
subprocess.run(removal_command_parts)

# Raise an error if the result isn't clean.
if result is False:
raise FileDirtyError(cacheable=False)

async def _scan_media(
self,
media: MediaDescription,
Expand Down Expand Up @@ -340,29 +391,32 @@ async def _scan_media(
# If the file is encrypted, we need to decrypt it before we can scan it.
media_content = self._decrypt_file(media_content, metadata)

# Check the file's MIME type to see if it's allowed.
self._check_mimetype(media_content)

# Write the file to disk.
file_path = self._write_file_to_disk(media_path, media_content)

# Scan the file and see if the result is positive or negative.
exit_code = await self._run_scan(file_path)
result = exit_code == 0
try:
# Check the file's MIME type to see if it's allowed.
self._check_mimetype(file_path)

# If the exit code isn't part of the ones we should ignore, cache the result.
cacheable = True
if exit_code in self._exit_codes_to_ignore:
logger.info(
"Scan returned exit code %d which must not be cached", exit_code
)
cacheable = False
# Scan the file and see if the result is positive or negative.
exit_code = await self._run_scan(file_path)
result = exit_code == 0

# If the exit code isn't part of the ones we should ignore, cache the result.
cacheable = True
if exit_code in self._exit_codes_to_ignore:
logger.info(
"Scan returned exit code %d which must not be cached", exit_code
)
cacheable = False

# Delete the file now that we've scanned it.
logger.info("Scan has finished, removing file")
removal_command_parts = self._removal_command.split()
removal_command_parts.append(file_path)
subprocess.run(removal_command_parts)
logger.info("Scan has finished")
finally:
# Delete the file now that we've scanned it.
logger.info("Removing file")
removal_command_parts = self._removal_command.split()
removal_command_parts.append(file_path)
subprocess.run(removal_command_parts)

# Raise an error if the result isn't clean.
if result is False:
Expand Down Expand Up @@ -434,6 +488,30 @@ def _decrypt_file(self, body: bytes, metadata: JsonDict) -> bytes:
info=str(e),
)

async def write_multipart_to_disk(self, multipart: BodyPartReader) -> str:
"""
Writes a multipart file body to the store directory.

Returns:
The full file path to the file.
"""
filename = str(uuid.uuid4())
# Figure out the full absolute path for this file.
full_path = self._store_directory.joinpath(filename).resolve()
logger.info("Writing multipart file to %s", full_path)

# Create any directory we need.
os.makedirs(full_path.parent, exist_ok=True)

with open(full_path, "wb") as fp:
while True:
chunk = await multipart.read_chunk()
if not chunk:
break
fp.write(chunk)
Comment on lines +506 to +511
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this would probably be better if it was async, but seems that we might have to pull in a library for this. I also guess you aren't the first to do it this way.
If you were interested though, https://pypi.org/project/aiofile/ looks reasonable


Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

if this fails to write the file, should it delete the partial file?

return str(full_path)

def _write_file_to_disk(self, media_path: str, body: bytes) -> str:
"""Writes the given content to disk. The final file name will be a concatenation
of `temp_directory` and the media's `server_name/media_id` path.
Expand Down Expand Up @@ -495,16 +573,15 @@ async def _run_scan(self, file_name: str) -> int:

return retcode

def _check_mimetype(self, media_content: bytes) -> None:
"""Detects the MIME type of the provided bytes, and checks that this type is allowed
def _check_mimetype(self, filepath: str) -> None:
"""Detects the MIME type of the provided file, and checks that this type is allowed
(if an allow list is provided in the configuration)
Args:
media_content: The file's content. If the file is encrypted, this is its
decrypted content.
filepath: The full file path.
Raises:
FileMimeTypeForbiddenError if one of the checks fail.
"""
detected_mimetype = magic.from_buffer(media_content, mime=True)
detected_mimetype = magic.from_file(filepath, mime=True)
logger.debug("Detected MIME type for file is %s", detected_mimetype)

# If there's an allow list for MIME types, check that the MIME type that's been
Expand Down
25 changes: 25 additions & 0 deletions src/matrix_content_scanner/servlets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,31 @@ async def get_media_metadata_from_request(
return media_path, metadata


async def get_media_metadata_from_filebody(
file_body: JsonDict,
crypto_handler: crypto.CryptoHandler,
) -> JsonDict:
"""Extracts, optionally decrypts, and validates encrypted file metadata from a
request body.

Args:
request: The request to extract the data from.
crypto_handler: The crypto handler to use if we need to decrypt an Olm-encrypted
body.

Raises:
ContentScannerRestError(400) if the request's body is None or if the metadata
didn't pass schema validation.
"""
metadata = _metadata_from_body(file_body, crypto_handler)

validate_encrypted_file_metadata(metadata)

# URL parameter is ignored.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

not really sure what this comment means


return metadata


def _metadata_from_body(
body: JsonDict, crypto_handler: crypto.CryptoHandler
) -> JsonDict:
Expand Down
64 changes: 61 additions & 3 deletions src/matrix_content_scanner/servlets/scan.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,18 @@
#
# SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial
# Please see LICENSE files in the repository root for full details.
import json
from typing import TYPE_CHECKING, Optional, Tuple

from aiohttp import web
from aiohttp import BodyPartReader, web

from matrix_content_scanner.servlets import get_media_metadata_from_request, web_handler
from matrix_content_scanner.utils.errors import FileDirtyError
from matrix_content_scanner.servlets import (
get_media_metadata_from_filebody,
get_media_metadata_from_request,
web_handler,
)
from matrix_content_scanner.utils.constants import ErrCode
from matrix_content_scanner.utils.errors import ContentScannerRestError, FileDirtyError
from matrix_content_scanner.utils.types import JsonDict

if TYPE_CHECKING:
Expand Down Expand Up @@ -51,3 +57,55 @@ async def handle_encrypted(self, request: web.Request) -> Tuple[int, JsonDict]:
return await self._scan_and_format(
media_path, metadata, auth_header=request.headers.get("Authorization")
)

@web_handler
async def handle_file(self, request: web.Request) -> Tuple[int, JsonDict]:
"""Handles GET requests to ../scan_file"""
try:
reader = await request.multipart()
except Exception:
raise ContentScannerRestError(
400,
ErrCode.MALFORMED_MULTIPART,
"Request body was not a multipart body.",
)

body = None
metadata: Optional[JsonDict] = None

# Iterate to find the fields.
while True:
field = await reader.next()
if (metadata and body) or field is None:
break
if not isinstance(field, BodyPartReader):
continue
if field.name == "file":
try:
file_json = await field.json()
if file_json is None:
raise Exception("'file' field is empty")
except json.decoder.JSONDecodeError as e:
raise ContentScannerRestError(400, ErrCode.MALFORMED_JSON, str(e))

metadata = await get_media_metadata_from_filebody(
file_json, self._crypto_handler
)
elif field.name == "body":
body = await self._scanner.write_multipart_to_disk(field)

if body is None:
raise ContentScannerRestError(
400, ErrCode.MALFORMED_MULTIPART, "Missing 'body' field"
)

# 'metadata' is optional

try:
await self._scanner.scan_file_on_disk(body, metadata)
except FileDirtyError as e:
res = {"clean": False, "info": e.info}
else:
res = {"clean": True, "info": "File is clean"}

return 200, res
2 changes: 2 additions & 0 deletions src/matrix_content_scanner/utils/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,3 +32,5 @@ class ErrCode(str, Enum):
MALFORMED_JSON = "MCS_MALFORMED_JSON"
# The Mime type is not in the allowed list of Mime types.
MIME_TYPE_FORBIDDEN = "MCS_MIME_TYPE_FORBIDDEN"
# The body was not a multipart.
MALFORMED_MULTIPART = "MCS_MALFORMED_MULTIPART"