diff --git a/infera/Cargo.toml b/infera/Cargo.toml
index d7d89f1..21026be 100644
--- a/infera/Cargo.toml
+++ b/infera/Cargo.toml
@@ -25,7 +25,14 @@
 sha2 = "0.10"
 hex = "0.4"
 filetime = "0.2"
 
 [dev-dependencies]
+# The ETag test servers and the async test harness are only used in tests,
+# so these belong under [dev-dependencies] rather than [dependencies].
+tokio = { version = "1", features = ["full"] }
+tokio-util = "0.7" # for ReaderStream
+actix-middleware-etag = "0.4.6"
+actix-web = "4.12.0"
+bytes = "1.11.0"
 tempfile = "3.10"
 mockito = "1.7.0"
diff --git a/infera/src/http.rs b/infera/src/http.rs
index d508551..cb847d5 100644
--- a/infera/src/http.rs
+++ b/infera/src/http.rs
@@ -1,5 +1,3 @@
-// Handles downloading and caching of remote models.
-
 use crate::config::{LogLevel, CONFIG};
 use crate::error::InferaError;
 use crate::log;
@@ -10,6 +8,8 @@ use std::path::{Path, PathBuf};
 use std::thread;
 use std::time::{Duration, SystemTime};
 
+use reqwest::header::{ETAG, IF_NONE_MATCH};
+
 /// A guard that guarantees a temporary file is deleted when it goes out of scope.
 /// This is used to implement a panic-safe cleanup of partial downloads.
 struct TempFileGuard<'a> {
@@ -140,60 +140,56 @@ pub(crate) fn clear_cache() -> Result<(), InferaError> {
     Ok(())
 }
 
-/// Handles the download and caching of a remote model from a URL.
-///
-/// If the model for the given URL is already present in the local cache, this
-/// function updates its access time and returns the path. Otherwise, it downloads
-/// the file, evicts old cache entries if needed, stores it in the cache directory,
-/// and then returns the path.
-///
-/// The cache uses an LRU (Least Recently Used) eviction policy with a configurable
-/// size limit (default 1GB, configurable via INFERA_CACHE_SIZE_LIMIT env var).
-///
-/// Downloads support automatic retries with exponential backoff.
-///
-/// # Arguments
-///
-/// * `url` - The HTTP/HTTPS URL of the ONNX model to be downloaded.
-///
-/// # Returns
-///
-/// A `Result` which is:
-/// * `Ok(PathBuf)`: The local file path of the cached model.
-/// * `Err(InferaError)`: An error indicating failure in creating the cache directory,
-///   making the HTTP request, or writing the file to disk.
+/// Handles downloading and caching a remote model from a URL.
+///
+/// If the model for the given URL is already present in the local cache and the
+/// server reports (via ETag) that the object has not changed since it was last
+/// fetched, this function updates the file's access time and returns the cached
+/// path. Otherwise, it downloads the file, evicts old cache entries if needed,
+/// stores it in the cache directory, and then returns the path.
+///
+/// The cache uses an LRU (Least Recently Used) eviction policy with a configurable
+/// size limit (default 1GB, configurable via INFERA_CACHE_SIZE_LIMIT env var).
+///
+/// Downloads support automatic retries with exponential backoff.
+///
+/// # Arguments
+///
+/// * `url` - The HTTP/HTTPS URL of the ONNX model to be downloaded.
+///
+/// # Returns
+///
+/// A `Result` which is:
+/// * `Ok(PathBuf)`: The local file path of the cached model.
+/// * `Err(InferaError)`: An error indicating failure in creating the cache directory,
+///   making the HTTP request, or writing the file to disk.
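+///
+/// # Examples
+///
+/// A minimal usage sketch (the URL is a placeholder; marked `ignore` because
+/// the function is crate-private and the example is illustrative only):
+///
+/// ```ignore
+/// let model_path = handle_remote_model("https://example.com/model.onnx")?;
+/// let model_bytes = std::fs::read(&model_path)?;
+/// assert!(!model_bytes.is_empty());
+/// ```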
 pub(crate) fn handle_remote_model(url: &str) -> Result<PathBuf, InferaError> {
+    let max_attempts = CONFIG.http_retry_attempts;
+    let retry_delay_ms = CONFIG.http_retry_delay_ms;
+    let timeout_secs = CONFIG.http_timeout_secs;
+
     let cache_dir = cache_dir();
     if !cache_dir.exists() {
         log!(LogLevel::Info, "Creating cache directory: {:?}", cache_dir);
         fs::create_dir_all(&cache_dir).map_err(|e| InferaError::CacheDirError(e.to_string()))?;
     }
+
+    // Compute the cache key from a hash of the URL
     let mut hasher = Sha256::new();
     hasher.update(url.as_bytes());
     let hash_hex = hex::encode(hasher.finalize());
     let cached_path = cache_dir.join(format!("{}.onnx", hash_hex));
+    let etag_path = cache_dir.join(format!("{}.etag", hash_hex));
 
-    if cached_path.exists() {
-        log!(LogLevel::Info, "Cache hit for URL: {}", url);
-        // Update access time for LRU tracking
-        touch_cache_file(&cached_path)?;
-        return Ok(cached_path);
-    }
-
-    log!(
-        LogLevel::Info,
-        "Cache miss for URL: {}, downloading...",
-        url
-    );
+    // Load the stored ETag, but only if the cached model itself still exists;
+    // otherwise a 304 reply would point at a file that is no longer on disk.
+    let etag_trimmed = if cached_path.exists() {
+        fs::read_to_string(&etag_path)
+            .map(|etag_value| etag_value.trim().to_string())
+            .unwrap_or_default()
+    } else {
+        String::new()
+    };
 
     let temp_path = cached_path.with_extension("onnx.part");
     let mut guard = TempFileGuard::new(&temp_path);
 
-    // Download with retry logic
-    let max_attempts = CONFIG.http_retry_attempts;
-    let retry_delay_ms = CONFIG.http_retry_delay_ms;
-    let timeout_secs = CONFIG.http_timeout_secs;
-
     let mut last_error = None;
 
     for attempt in 1..=max_attempts {
@@ -205,10 +201,17 @@ pub(crate) fn handle_remote_model(url: &str) -> Result<PathBuf, InferaError> {
             url
         );
 
-        match download_file(url, &temp_path, timeout_secs) {
-            Ok(_) => {
-                log!(LogLevel::Info, "Successfully downloaded: {}", url);
+        // Perform the conditional download via the helper below
+        match download_file_with_etag(url, &temp_path, timeout_secs, &etag_trimmed) {
+            Ok(Some((false, etag_str))) => {
+                // `false`: the remote object is new or has changed and its body
+                // was downloaded to the temp file.
+                log!(LogLevel::Info, "Cache miss for URL: {}", url);
+                log!(LogLevel::Info, "Successfully downloaded: {}", url);
                 // Check file size and evict cache if needed
                 let file_size = fs::metadata(&temp_path)
                     .map_err(|e| InferaError::IoError(e.to_string()))?
@@ -216,10 +219,23 @@ pub(crate) fn handle_remote_model(url: &str) -> Result<PathBuf, InferaError> {
                 log!(LogLevel::Debug, "Downloaded file size: {} bytes", file_size);
                 evict_cache_if_needed(file_size)?;
-
                 fs::rename(&temp_path, &cached_path)
                     .map_err(|e| InferaError::IoError(e.to_string()))?;
+                // Persist the new ETag next to the cached model
+                fs::write(&etag_path, &etag_str)
+                    .map_err(|e| InferaError::IoError(e.to_string()))?;
+
+                guard.commit();
+                return Ok(cached_path);
+            }
+            Ok(Some((true, _))) => {
+                // `true`: the server answered 304 Not Modified; the cached file
+                // is still current.
+                log!(LogLevel::Info, "Cache hit for URL: {}", url);
+
+                // Update access time for LRU tracking
+                touch_cache_file(&cached_path)?;
+
                 guard.commit();
                 return Ok(cached_path);
             }
@@ -232,7 +248,6 @@ pub(crate) fn handle_remote_model(url: &str) -> Result<PathBuf, InferaError> {
                     e
                 );
                 last_error = Some(e);
-
                 // Don't sleep after the last attempt
                 if attempt < max_attempts {
                     let delay = Duration::from_millis(retry_delay_ms * attempt as u64);
@@ -240,6 +255,11 @@ pub(crate) fn handle_remote_model(url: &str) -> Result<PathBuf, InferaError> {
                     thread::sleep(delay);
                 }
             }
+            Ok(None) => {
+                // Unreachable in practice: download_file_with_etag never returns
+                // Ok(None); this arm only satisfies match exhaustiveness.
+                log!(LogLevel::Error, "download_file_with_etag returned Ok(None) unexpectedly");
+            }
         }
     }
 
@@ -249,37 +269,86 @@ pub(crate) fn handle_remote_model(url: &str) -> Result<PathBuf, InferaError> {
         max_attempts,
         url
     );
-    Err(last_error.unwrap_or_else(|| InferaError::HttpRequestError("Unknown error".to_string())))
+
+    Err(last_error
+        .unwrap_or_else(|| InferaError::HttpRequestError("Unknown download error".to_string())))
 }
 
-/// Download a file from a URL to a local path with timeout
-fn download_file(url: &str, dest: &Path, timeout_secs: u64) -> Result<(), InferaError> {
+/// Downloads a file from a URL with ETag support for caching.
+///
+/// If `etag` is non-empty and the server responds with `304 Not Modified`,
+/// this function returns early with `(true, etag)` and performs no download.
+/// Otherwise it downloads the file to `dest` and returns `(false, new_etag)`,
+/// where `new_etag` is the `ETag` response header (empty if the server sent none).
+///
+/// # Arguments
+///
+/// * `url` - The URL to download from.
+/// * `dest` - The destination path to save the file.
+/// * `timeout_secs` - HTTP request timeout in seconds.
+/// * `etag` - Cached ETag string for conditional requests.
+///
+/// # Returns
+///
+/// A `Result<Option<(bool, String)>, InferaError>` where the boolean indicates
+/// whether the file was not modified (`true`) and the string is the ETag.
+fn download_file_with_etag(
+    url: &str,
+    dest: &Path,
+    timeout_secs: u64,
+    etag: &str,
+) -> Result<Option<(bool, String)>, InferaError> {
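+    // Conditional GET: send the cached validator via If-None-Match. If the
+    // server still holds the same object it replies `304 Not Modified` with an
+    // empty body; otherwise it replies `200 OK` with the new body and, usually,
+    // a fresh `ETag` header.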
     let client = reqwest::blocking::Client::builder()
         .timeout(Duration::from_secs(timeout_secs))
         .build()
         .map_err(|e| InferaError::HttpRequestError(e.to_string()))?;
 
-    let mut response = client
-        .get(url)
+    // Only attach If-None-Match when a validator exists; an empty entity-tag
+    // is not a valid header value.
+    let mut request = client.get(url);
+    if !etag.is_empty() {
+        request = request.header(IF_NONE_MATCH, etag);
+    }
+
+    let mut response = request
         .send()
         .map_err(|e| InferaError::HttpRequestError(e.to_string()))?
         .error_for_status()
         .map_err(|e| InferaError::HttpRequestError(e.to_string()))?;
 
+    if !etag.is_empty() && response.status() == reqwest::StatusCode::NOT_MODIFIED {
+        // Not modified: nothing to write; the caller keeps using the cached file
+        return Ok(Some((true, etag.to_string())));
+    }
+
     let mut file = File::create(dest).map_err(|e| InferaError::IoError(e.to_string()))?;
     io::copy(&mut response, &mut file).map_err(|e| InferaError::IoError(e.to_string()))?;
 
-    Ok(())
+    // Extract the ETag response header, if present, so the caller can store it
+    let etag_str = response
+        .headers()
+        .get(ETAG)
+        .and_then(|v| v.to_str().ok())
+        .map(|s| s.to_owned())
+        .unwrap_or_default();
+
+    Ok(Some((false, etag_str)))
 }
 
 #[cfg(test)]
 mod tests {
     use super::*;
+    use actix_middleware_etag::Etag;
+    use actix_web::{web, App, Error, HttpResponse, HttpServer};
+    use bytes::Bytes;
     use mockito::Server;
     use std::env; // moved here: used in tests only
+    use std::sync::Arc;
     use std::thread;
     use tiny_http::{Header, Response, Server as TinyServer};
+    use tokio::sync::RwLock;
+
+    /// Returns the modification time of `path`; the ETag tests use it to
+    /// detect whether a call re-downloaded the cached file.
+    fn get_file_modification_time(path: &std::path::Path) -> std::io::Result<SystemTime> {
+        let metadata = fs::metadata(path)?;
+        metadata.modified()
+    }
 
     #[test]
     fn test_handle_remote_model_cleanup_on_incomplete_download() {
         let server = TinyServer::http("127.0.0.1:0").unwrap();
@@ -386,33 +455,181 @@ mod tests {
             !temp_path.exists(),
             "Partial file should be cleaned up after a connection drop"
         );
-
         server_handle.join().unwrap();
     }
 
-    #[test]
-    fn test_handle_remote_model_success_and_cache() {
-        // Serve a small body with an accurate Content-Length
-        let mut server = Server::new();
-        let body = b"onnxdata".to_vec();
-        let _m = server
-            .mock("GET", "/ok_model.onnx")
-            .with_status(200)
-            .with_header("Content-Length", &body.len().to_string())
-            .with_body(body.clone())
-            .create();
-        let url = format!("{}/ok_model.onnx", server.url());
-
-        let path1 = handle_remote_model(&url).expect("download should succeed");
-        assert!(path1.exists(), "cached file must exist");
-        let content1 = fs::read(&path1).expect("read cached file");
-        assert_eq!(content1, body);
-
-        // Second call should hit cache and return same path without network
-        let path2 = handle_remote_model(&url).expect("cache should hit");
-        assert_eq!(path1, path2);
-        let temp_path = path1.with_extension("onnx.part");
-        assert!(!temp_path.exists(), "no partial file should remain");
-    }
+    #[actix_web::test]
+    async fn test_remote_model_download_and_cache_with_etag_enabled() {
+        use std::net::TcpListener;
+
+        let listener = TcpListener::bind("127.0.0.1:0").expect("Failed to bind random port");
+        let addr = listener.local_addr().unwrap();
+
+        let file_content = Arc::new(RwLock::new(b"initial content".to_vec()));
+        let file_content_server = file_content.clone();
+
+        // Serve the shared buffer behind the Etag middleware so responses carry
+        // an ETag header and conditional requests can receive 304 Not Modified.
+        let server = HttpServer::new(move || {
+            let content = file_content_server.clone();
+            App::new().wrap(Etag::default()).route(
+                "/model.onnx",
+                web::get().to(move || {
+                    let content = content.clone();
+                    async move {
+                        let data = content.read().await;
+                        let bytes = Bytes::copy_from_slice(&*data);
+                        Ok::<_, Error>(
+                            HttpResponse::Ok()
+                                .content_type("application/octet-stream")
+                                .body(bytes),
+                        )
+                    }
+                }),
+            )
+        })
+        .listen(listener)
+        .expect("Failed to bind server")
+        .run();
+
+        // Spawn the server in the background
+        let srv_handle = actix_web::rt::spawn(server);
+
+        let url = format!("http://{}:{}/model.onnx", addr.ip(), addr.port());
+        let second_call_url = url.clone();
+        let third_call_url = url.clone();
+
+        // Run the blocking cache-and-download function on a blocking task
+        let path1 = tokio::task::spawn_blocking(move || handle_remote_model(&url))
+            .await
+            .expect("Task panicked")
+            .expect("handle_remote_model failed");
+
+        assert!(path1.exists());
+        let content1 = fs::read(&path1).expect("read cached file");
+        let path1_modification_time = get_file_modification_time(&path1).unwrap();
+
+        // Sleep first so that a re-download (if any) would produce a newer mtime
+        tokio::time::sleep(Duration::from_secs(1)).await;
+
+        // Call again: the ETag still matches, so the cached file must be reused
+        let path2 = tokio::task::spawn_blocking(move || handle_remote_model(&second_call_url))
+            .await
+            .expect("Task panicked")
+            .expect("handle_remote_model failed");
+
+        let content2 = fs::read(&path2).expect("read cached file");
+        let path2_modification_time = get_file_modification_time(&path2).unwrap();
+
+        assert_eq!(path1, path2);
+        assert_eq!(content1, content2);
+        assert_eq!(path1_modification_time, path2_modification_time);
+
+        // Modify the content to simulate an upstream update
+        {
+            let mut content_write = file_content.write().await;
+            *content_write = b"updated content".to_vec();
+        }
+
+        tokio::time::sleep(Duration::from_secs(1)).await;
+
+        // Call again: the ETag no longer matches, so the cache must refresh
+        let path3 = tokio::task::spawn_blocking(move || handle_remote_model(&third_call_url))
+            .await
+            .expect("Task panicked")
+            .expect("handle_remote_model failed");
+
+        let content3 = fs::read(&path3).expect("read cached file");
+        let path3_modification_time = get_file_modification_time(&path3).unwrap();
+        assert_eq!(path1, path3);
+        assert_ne!(content1, content3);
+        assert_ne!(path1_modification_time, path3_modification_time);
+
+        // Clean up the server
+        srv_handle.abort();
+    }
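+
+    // Mirrors the test above, but without the Etag middleware: every call is a
+    // full re-download, so the cached file's modification time advances each time.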
+    #[actix_web::test]
+    async fn test_remote_model_download_and_cache_with_etag_disabled() {
+        use std::net::TcpListener;
+
+        let listener = TcpListener::bind("127.0.0.1:0").expect("Failed to bind random port");
+        let addr = listener.local_addr().unwrap();
+
+        let file_content = Arc::new(RwLock::new(b"initial content".to_vec()));
+        let file_content_server = file_content.clone();
+
+        let server = HttpServer::new(move || {
+            let content = file_content_server.clone();
+            App::new().route(
+                "/model.onnx",
+                web::get().to(move || {
+                    let content = content.clone();
+                    async move {
+                        let data = content.read().await;
+                        let bytes = Bytes::copy_from_slice(&*data);
+                        Ok::<_, Error>(
+                            HttpResponse::Ok()
+                                .content_type("application/octet-stream")
+                                .body(bytes),
+                        )
+                    }
+                }),
+            )
+        })
+        .listen(listener)
+        .expect("Failed to bind server")
+        .run();
+
+        // Spawn the server in the background
+        let srv_handle = actix_web::rt::spawn(server);
+
+        let url = format!("http://{}:{}/model.onnx", addr.ip(), addr.port());
+        let second_call_url = url.clone();
+        let third_call_url = url.clone();
+
+        // Run the blocking cache-and-download function on a blocking task
+        let path1 = tokio::task::spawn_blocking(move || handle_remote_model(&url))
+            .await
+            .expect("Task panicked")
+            .expect("handle_remote_model failed");
+
+        assert!(path1.exists());
+        let content1 = fs::read(&path1).expect("read cached file");
+        let path1_modification_time = get_file_modification_time(&path1).unwrap();
+
+        // Sleep first so the re-download below produces a measurably newer mtime
+        tokio::time::sleep(Duration::from_secs(1)).await;
+
+        // Call again: without an ETag the file is downloaded again
+        let path2 = tokio::task::spawn_blocking(move || handle_remote_model(&second_call_url))
+            .await
+            .expect("Task panicked")
+            .expect("handle_remote_model failed");
+
+        let content2 = fs::read(&path2).expect("read cached file");
+        let path2_modification_time = get_file_modification_time(&path2).unwrap();
+
+        assert_eq!(path1, path2);
+        assert_eq!(content1, content2);
+        assert_ne!(path1_modification_time, path2_modification_time);
+
+        // Modify the content to simulate an upstream update
+        {
+            let mut content_write = file_content.write().await;
+            *content_write = b"updated content".to_vec();
+        }
+
+        tokio::time::sleep(Duration::from_secs(1)).await;
+
+        // Call again: the cache must pick up the new content
+        let path3 = tokio::task::spawn_blocking(move || handle_remote_model(&third_call_url))
+            .await
+            .expect("Task panicked")
+            .expect("handle_remote_model failed");
+
+        let content3 = fs::read(&path3).expect("read cached file");
+        let path3_modification_time = get_file_modification_time(&path3).unwrap();
+        assert_eq!(path1, path3);
+        assert_ne!(content1, content3);
+        assert_ne!(path1_modification_time, path3_modification_time);
+
+        // Clean up the server
+        srv_handle.abort();
+    }
 
     #[test]