diff --git a/flask_pymongo/__init__.py b/flask_pymongo/__init__.py index a27abb5..1737505 100644 --- a/flask_pymongo/__init__.py +++ b/flask_pymongo/__init__.py @@ -27,6 +27,7 @@ __all__ = ("PyMongo", "ASCENDING", "DESCENDING", "BSONObjectIdConverter", "BSONProvider") import hashlib +import warnings from mimetypes import guess_type from typing import Any @@ -183,12 +184,24 @@ def get_upload(filename): response.headers["Content-Disposition"] = f"attachment; filename={filename}" response.content_length = fileobj.length response.last_modified = fileobj.upload_date - # Compute the sha1 sum of the file for the etag. - pos = fileobj.tell() - raw_data = fileobj.read() - fileobj.seek(pos) - digest = hashlib.sha1(raw_data).hexdigest() - response.set_etag(digest) + + # GridFS does not manage its own checksum. + # Try to use a sha1 sum that we have added during a save_file. + # Fall back to a legacy md5 sum if it exists. + # Otherwise, compute the sha1 sum directly. + try: + etag = fileobj.sha1 + except AttributeError: + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + etag = fileobj.md5 + if etag is None: + pos = fileobj.tell() + raw_data = fileobj.read() + fileobj.seek(pos) + etag = hashlib.sha1(raw_data).hexdigest() + response.set_etag(etag) + response.cache_control.max_age = cache_for response.cache_control.public = True response.make_conditional(request) @@ -237,5 +250,23 @@ def save_upload(filename): db_obj = self.db assert db_obj is not None, "Please initialize the app before calling save_file!" storage = GridFS(db_obj, base) - id = storage.put(fileobj, filename=filename, content_type=content_type, **kwargs) - return id + + # GridFS does not manage its own checksum, so we attach a sha1 to the file + # for use as an etag. 
+ hashingfile = _Wrapper(fileobj) + with storage.new_file(filename=filename, content_type=content_type, **kwargs) as grid_file: + grid_file.write(hashingfile) + grid_file.sha1 = hashingfile.hash.hexdigest() + return grid_file._id + + +class _Wrapper: + def __init__(self, file): + self.file = file + self.hash = hashlib.sha1() + + def read(self, n=-1): + data = self.file.read(n) + if data: + self.hash.update(data) + return data diff --git a/tests/test_gridfs.py b/tests/test_gridfs.py index ee6b6e3..8e8bf10 100644 --- a/tests/test_gridfs.py +++ b/tests/test_gridfs.py @@ -1,6 +1,7 @@ from __future__ import annotations -from hashlib import sha1 +import warnings +from hashlib import md5, sha1 from io import BytesIO import pytest @@ -95,6 +96,26 @@ def test_it_sets_supports_conditional_gets(self): resp = self.mongo.send_file("myfile.txt") assert resp.status_code == 304 + def test_it_sets_supports_conditional_gets_md5(self): + # a basic conditional GET + md5_hash = md5(self.myfile.getvalue()).hexdigest() + environ_args = { + "method": "GET", + "headers": { + "If-None-Match": md5_hash, + }, + } + storage = GridFS(self.mongo.db) + with storage.new_file(filename="myfile.txt") as grid_file: + grid_file.write(self.myfile.getvalue()) + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + grid_file.set("md5", md5_hash) + + with self.app.test_request_context(**environ_args): + resp = self.mongo.send_file("myfile.txt") + assert resp.status_code == 304 + def test_it_sets_cache_headers(self): resp = self.mongo.send_file("myfile.txt", cache_for=60) assert resp.cache_control.max_age == 60 diff --git a/tests/util.py b/tests/util.py index 9f2ab83..892b5f3 100644 --- a/tests/util.py +++ b/tests/util.py @@ -33,5 +33,5 @@ def setUp(self): def tearDown(self): assert self.mongo.cx is not None self.mongo.cx.drop_database(self.dbname) - + self.mongo.cx.close() super().tearDown()