diff --git a/s3file/forms.py b/s3file/forms.py index 8285ffb..f7ac608 100644 --- a/s3file/forms.py +++ b/s3file/forms.py @@ -5,6 +5,7 @@ from django.conf import settings from django.utils.functional import cached_property +from storages.utils import safe_join from s3file.storages import storage @@ -18,6 +19,7 @@ class S3FileInputMixin: upload_path = getattr( settings, 'S3FILE_UPLOAD_PATH', pathlib.PurePosixPath('tmp', 's3file') ) + upload_path = safe_join(storage.location, upload_path) expires = settings.SESSION_COOKIE_AGE @property diff --git a/s3file/middleware.py b/s3file/middleware.py index 41159e0..b50cb86 100644 --- a/s3file/middleware.py +++ b/s3file/middleware.py @@ -29,6 +29,7 @@ def __call__(self, request): def get_files_from_storage(paths): """Return S3 file where the name does not include the path.""" for path in paths: + path = path.replace(os.path.dirname(storage.location) + '/', '', 1) try: f = storage.open(path) f.name = os.path.basename(path) diff --git a/s3file/storages.py b/s3file/storages.py index 73f5f82..d6b7ee5 100644 --- a/s3file/storages.py +++ b/s3file/storages.py @@ -2,12 +2,21 @@ import datetime import hmac import json +import os from django.conf import settings from django.core.files.storage import FileSystemStorage, default_storage +from django.utils._os import safe_join class S3MockStorage(FileSystemStorage): + @property + def location(self): + return settings.AWS_LOCATION + + def path(self, name): + return safe_join(os.path.abspath(self.base_location), self.location, name) + class connection: class meta: class client: diff --git a/tests/test_forms.py b/tests/test_forms.py index 040dcee..8185d14 100644 --- a/tests/test_forms.py +++ b/tests/test_forms.py @@ -3,12 +3,12 @@ from contextlib import contextmanager import pytest -from django.core.files.storage import default_storage from django.forms import ClearableFileInput from selenium.common.exceptions import NoSuchElementException from selenium.webdriver.support.expected_conditions import staleness_of from selenium.webdriver.support.wait import WebDriverWait +from s3file.storages import storage from tests.testapp.forms import UploadForm try: @@ -35,11 +35,14 @@ def url(self): @pytest.fixture def freeze(self, monkeypatch): """Freeze datetime and UUID.""" - monkeypatch.setattr('s3file.forms.S3FileInputMixin.upload_folder', 'tmp') + monkeypatch.setattr( + 's3file.forms.S3FileInputMixin.upload_folder', + os.path.join(storage.location, 'tmp'), + ) def test_value_from_datadict(self, client, upload_file): with open(upload_file) as f: - uploaded_file = default_storage.save('test.jpg', f) + uploaded_file = storage.save('test.jpg', f) response = client.post(reverse('upload'), { 'file': json.dumps([uploaded_file]), 's3file': '["file"]', @@ -96,7 +99,7 @@ def test_get_conditions(self, freeze): assert all(condition in conditions for condition in [ {"bucket": 'test-bucket'}, {"success_action_status": "201"}, - ['starts-with', '$key', 'tmp'], + ['starts-with', '$key', 'custom/location/tmp'], ["starts-with", "$Content-Type", ""] ]), conditions @@ -139,7 +142,7 @@ def test_file_insert(self, request, driver, live_server, upload_file, freeze): assert file_input.get_attribute('name') == 'file' with wait_for_page_load(driver, timeout=10): file_input.submit() - assert default_storage.exists('tmp/%s.txt' % request.node.name) + assert storage.exists('tmp/%s.txt' % request.node.name) with pytest.raises(NoSuchElementException): error = driver.find_element_by_xpath('//body[@JSError]') @@ -208,5 +211,7 @@ def test_media(self): assert ClearableFileInput().media._js == ['s3file/js/s3file.js'] def test_upload_folder(self): - assert ClearableFileInput().upload_folder.startswith('tmp/s3file/') - assert len(ClearableFileInput().upload_folder) == 33 + assert ClearableFileInput().upload_folder.startswith( + 'custom/location/tmp/s3file/' + ) + assert len(ClearableFileInput().upload_folder) == 49 diff --git a/tests/test_middleware.py b/tests/test_middleware.py index c51064e..578b0f2 100644 --- a/tests/test_middleware.py +++ b/tests/test_middleware.py @@ -1,16 +1,16 @@ from django.core.files.base import ContentFile -from django.core.files.storage import default_storage from django.core.files.uploadedfile import SimpleUploadedFile from s3file.middleware import S3FileMiddleware +from s3file.storages import storage class TestS3FileMiddleware: def test_get_files_from_storage(self): content = b'test_get_files_from_storage' - default_storage.save('test_get_files_from_storage', ContentFile(content)) - files = S3FileMiddleware.get_files_from_storage(['test_get_files_from_storage']) + name = storage.save('test_get_files_from_storage', ContentFile(content)) + files = S3FileMiddleware.get_files_from_storage([name]) file = next(files) assert file.read() == content @@ -21,8 +21,10 @@ def test_process_request(self, rf): assert request.FILES.getlist('file') assert request.FILES.get('file').read() == b'uploaded' - default_storage.save('s3_file.txt', ContentFile(b's3file')) - request = rf.post('/', data={'file': '["s3_file.txt"]', 's3file': '["file"]'}) + storage.save('s3_file.txt', ContentFile(b's3file')) + request = rf.post('/', data={ + 'file': '["custom/location/s3_file.txt"]', 's3file': '["file"]' + }) S3FileMiddleware(lambda x: None)(request) assert request.FILES.getlist('file') assert request.FILES.get('file').read() == b's3file' diff --git a/tests/testapp/settings.py b/tests/testapp/settings.py index e5bca0b..8a17174 100644 --- a/tests/testapp/settings.py +++ b/tests/testapp/settings.py @@ -53,3 +53,4 @@ AWS_S3_REGION_NAME = 'eu-central-1' AWS_S3_SIGNATURE_VERSION = 's3v4' AWS_DEFAULT_ACL = None +AWS_LOCATION = 'custom/location/'