Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 19 additions & 9 deletions ricecooker/classes/files.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,22 +142,27 @@ def truncate_fields(self):
)
self.source_url = self.source_url[: config.MAX_SOURCE_URL_LENGTH]

def file_dict(self, filename=None):
if not filename:
filename = self.get_filename()
return {
"size": self.size,
"preset": self.get_preset(),
"filename": filename,
"original_filename": self.original_filename,
"language": self.language,
"source_url": self.source_url,
"duration": self.duration,
}

def to_dict(self):
filename = self.get_filename()

# If file was successfully downloaded, return dict
# Otherwise return None
if filename:
if os.path.isfile(config.get_storage_path(filename)):
return {
"size": self.size,
"preset": self.get_preset(),
"filename": filename,
"original_filename": self.original_filename,
"language": self.language,
"source_url": self.source_url,
"duration": self.duration,
}
return self.file_dict(filename=filename)
else:
config.LOGGER.warning(
"File not found: {}".format(config.get_storage_path(filename))
Expand Down Expand Up @@ -609,6 +614,7 @@ class StudioFile(File):
"""

skip_upload = True
size = None

def __init__(self, checksum, ext, preset, is_primary=False, **kwargs):
kwargs["preset"] = preset
Expand All @@ -629,8 +635,12 @@ def validate(self):
self.filename, e
)
)
self.size = int(response.headers.get("Content-Length", 0))
self._validated = True

def to_dict(self):
return self.file_dict()

def __str__(self):
return self.filename

Expand Down
15 changes: 15 additions & 0 deletions ricecooker/utils/pipeline/file_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
"""
import os
import tempfile
import threading
from abc import ABC
from abc import abstractmethod
from contextlib import contextmanager
Expand Down Expand Up @@ -104,6 +105,20 @@ class FileHandler(Handler):
# Subclasses can define this list to specify which exceptions should be caught and reported
HANDLED_EXCEPTIONS = []

def __init__(self):
super().__init__()
self._thread_local = threading.local()

@property
def _output_path(self):
"""Thread-safe output path property."""
return getattr(self._thread_local, "output_path", None)

@_output_path.setter
def _output_path(self, value):
"""Thread-safe output path setter."""
self._thread_local.output_path = value

def _get_context(self, context: Optional[Dict] = None):
fields = set(get_type_hints(self.CONTEXT_CLASS).keys())
context = {k: v for k, v in (context or {}).items() if k in fields}
Expand Down
21 changes: 17 additions & 4 deletions ricecooker/utils/pipeline/transfer.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from requests.exceptions import HTTPError
from requests.exceptions import InvalidSchema
from requests.exceptions import InvalidURL
from requests.exceptions import Timeout

from .context import ContextMetadata
from .context import FileMetadata
Expand Down Expand Up @@ -139,10 +140,19 @@ class CatchAllWebResourceDownloadHandler(WebResourceHandler):

PATTERNS = [""]

HANDLED_EXCEPTIONS = [HTTPError, ConnectionError, InvalidURL, InvalidSchema]
HANDLED_EXCEPTIONS = [
HTTPError,
ConnectionError,
InvalidURL,
InvalidSchema,
Timeout,
]

def handle_file(self, path, default_ext=None):
r = config.DOWNLOAD_SESSION.get(path, stream=True)
# Use explicit timeout to prevent hanging downloads
# (connection_timeout, read_timeout) - connection timeout for establishing connection,
# read timeout for time between receiving data chunks (prevents stuck downloads)
r = config.DOWNLOAD_SESSION.get(path, stream=True, timeout=(30, 60))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

mostly a curiosity question -- how did you decide what timeout values to use here?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I didn't - Claude chose them, and they seemed fine to me!

original_filename = extract_filename_from_request(path, r)
default_ext = extract_path_ext(original_filename, default_ext=default_ext)
r.raise_for_status()
Expand Down Expand Up @@ -434,7 +444,10 @@ def execute(
# The download stage is special, as we expect it to always return a file
# if it does not, we raise an exception to prevent further processing
raise InvalidFileException(f"No file could be downloaded from {path}")

# Ensure all downloaded files are actually in storage
for metadata in metadata_list:
if metadata.path == path:
raise InvalidFileException(f"{path} failed to transfer")
if not metadata.path.startswith(os.path.abspath(config.STORAGE_DIRECTORY)):
raise InvalidFileException(f"{path} failed to transfer to storage")

return metadata_list
1,657 changes: 833 additions & 824 deletions tests/cassettes/test_youtubevideo_process_file.yaml

Large diffs are not rendered by default.

56 changes: 56 additions & 0 deletions tests/pipeline/test_file_handler.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
import threading

import pytest

from ricecooker.utils.pipeline.context import FileMetadata
from ricecooker.utils.pipeline.exceptions import InvalidFileException
from ricecooker.utils.pipeline.file_handler import FileHandler

Expand Down Expand Up @@ -31,3 +34,56 @@ def test_write_file_with_exception_still_checks_file_not_empty():
# Don't write anything to file (will make it empty)
# Then raise an exception that would normally prevent cleanup
raise RuntimeError("This exception should be caught by try/finally")


class ThreadRaceTestHandler(FileHandler):
def __init__(self):
super().__init__()
self.barrier = threading.Barrier(2) # Synchronize two threads

def should_handle(self, path: str) -> bool:
return path.startswith("race-test://")

def handle_file(self, path, **kwargs):
file_id = path.split("/")[-1]

# Set _output_path for this file
self._output_path = f"/storage/{file_id}.txt"

# Wait for both threads to reach this point
self.barrier.wait()

# Now both threads continue - without thread-local storage,
# the second one would overwrite _output_path
return FileMetadata(original_filename=f"{file_id}.txt")


def test_output_path_thread_safety():
"""Test that _output_path is thread-safe and doesn't have race conditions."""
handler = ThreadRaceTestHandler()
results = {}

def thread_a():
handler._output_path = None
path = "race-test://file_A"
output = handler.execute(path)
results["A"] = output[0].path

def thread_b():
handler._output_path = None
path = "race-test://file_B"
output = handler.execute(path)
results["B"] = output[0].path

thread1 = threading.Thread(target=thread_a)
thread2 = threading.Thread(target=thread_b)

thread1.start()
thread2.start()

thread1.join()
thread2.join()

# Each thread should get its own correct path, not interfere with each other
assert results["A"] == "/storage/file_A.txt"
assert results["B"] == "/storage/file_B.txt"
37 changes: 37 additions & 0 deletions tests/pipeline/test_transfer.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,11 @@
import pytest
from vcr_config import my_vcr

from ricecooker.utils.pipeline.context import FileMetadata
from ricecooker.utils.pipeline.exceptions import InvalidFileException
from ricecooker.utils.pipeline.file_handler import FileHandler
from ricecooker.utils.pipeline.transfer import DiskResourceHandler
from ricecooker.utils.pipeline.transfer import DownloadStageHandler
from ricecooker.utils.pipeline.transfer import (
get_filename_from_content_disposition_header,
)
Expand Down Expand Up @@ -291,3 +295,36 @@ def test_disk_transfer_non_file_protocol():
"os.path.exists", return_value=False
): # Ensure it doesn't try to check a web URL
assert not handler.should_handle(path), "Handler should not handle HTTP URLs"


class DummyPassthroughHandler(FileHandler):
"""A dummy handler that passes through the original path without transferring to storage.

This simulates the bug where a download handler fails to actually download/transfer
the file but returns the original URL as the path.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

but returns the original URL as the path

I'm not sure what this means. Previously, the bug was what the file wasn't getting transferred/downloaded actually, and the bug was that it seemed like it was, because there was a return value that was a path? but that the path was not the local path download location, but rather the original URL?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think that was what was actually happening for me - but because of the way that the handlers work, it is possible that this would cause an issue - basically, this just makes sure that download handlers always return a local file path in storage after they have finished, because otherwise they could cause issues for every other handler.

"""

def should_handle(self, path: str) -> bool:
return path.startswith("http://dummy-test-url.com")

def handle_file(self, path, **kwargs):
# Intentionally don't use write_file context manager
# This simulates a handler that fails to transfer the file to storage
return FileMetadata(original_filename="test.txt")


def test_download_stage_handler_catches_failed_transfer():
"""Test that DownloadStageHandler catches when files aren't transferred to storage.

This is a regression test for the issue where download handlers would sometimes
log "saved to [original URL]" instead of the actual storage path, indicating
that the file wasn't actually transferred to storage.
"""
# Create a DownloadStageHandler with our dummy passthrough handler
download_handler = DownloadStageHandler(children=[DummyPassthroughHandler()])

dummy_url = "http://dummy-test-url.com/test.txt"

# The handler should raise an InvalidFileException when the file isn't transferred to storage
with pytest.raises(InvalidFileException, match="failed to transfer to storage"):
download_handler.execute(dummy_url)
Loading