2121from contextlib import contextmanager
2222from typing import Any
2323from unittest import TestCase
24- from unittest .mock import MagicMock , patch
24+ from unittest .mock import ANY , MagicMock , patch
2525
2626import fsspec
2727
@@ -117,8 +117,9 @@ def setUp(self):
117117 self ._fsspec_patcher = patch (
118118 "opentelemetry.util.genai._upload.completion_hook.fsspec"
119119 )
120- self .mock_fsspec = self ._fsspec_patcher .start ()
121- self .mock_fsspec .open = ThreadSafeMagicMock ()
120+ mock_fsspec = self ._fsspec_patcher .start ()
121+ self .mock_fs = ThreadSafeMagicMock ()
122+ mock_fsspec .url_to_fs .return_value = self .mock_fs , ""
122123
123124 self .hook = UploadCompletionHook (
124125 base_path = BASE_PATH ,
@@ -133,12 +134,12 @@ def tearDown(self) -> None:
133134 def block_upload (self ):
134135 unblock_upload = threading .Event ()
135136
136- def blocked_upload (* args : Any ):
137+ def blocked_upload (* args : Any , ** kwargs : Any ):
137138 unblock_upload .wait ()
138139 return MagicMock ()
139140
140141 try :
141- self .mock_fsspec .open .side_effect = blocked_upload
142+ self .mock_fs .open .side_effect = blocked_upload
142143 yield
143144 finally :
144145 unblock_upload .set ()
@@ -156,7 +157,7 @@ def test_upload_then_shutdown(self):
156157 self .hook .shutdown ()
157158
158159 self .assertEqual (
159- self .mock_fsspec .open .call_count ,
160+ self .mock_fs .open .call_count ,
160161 3 ,
161162 "should have uploaded 3 files" ,
162163 )
@@ -172,7 +173,7 @@ def test_upload_blocked(self):
172173 )
173174
174175 self .assertLessEqual (
175- self .mock_fsspec .open .call_count ,
176+ self .mock_fs .open .call_count ,
176177 MAXSIZE ,
177178 f"uploader should only be called { MAXSIZE = } times" ,
178179 )
@@ -200,7 +201,7 @@ def test_shutdown_timeout(self):
200201 self .hook .shutdown (timeout_sec = 0.01 )
201202
202203 def test_failed_upload_logs (self ):
203- self .mock_fsspec .open .side_effect = RuntimeError ("failed to upload" )
204+ self .mock_fs .open .side_effect = RuntimeError ("failed to upload" )
204205
205206 with self .assertLogs (level = logging .ERROR ) as logs :
206207 self .hook .on_completion (
@@ -216,6 +217,27 @@ def test_invalid_upload_format(self):
216217 with self .assertRaisesRegex (ValueError , "Invalid upload_format" ):
217218 UploadCompletionHook (base_path = BASE_PATH , upload_format = "invalid" )
218219
220+ def test_upload_format_sets_content_type (self ):
221+ for upload_format , expect_content_type in (
222+ ("json" , "application/json" ),
223+ ("jsonl" , "application/jsonl" ),
224+ ):
225+ hook = UploadCompletionHook (
226+ base_path = BASE_PATH , upload_format = upload_format
227+ )
228+ self .addCleanup (hook .shutdown )
229+
230+ hook .on_completion (
231+ inputs = FAKE_INPUTS ,
232+ outputs = FAKE_OUTPUTS ,
233+ system_instruction = FAKE_SYSTEM_INSTRUCTION ,
234+ )
235+ hook .shutdown ()
236+
237+ self .mock_fs .open .assert_called_with (
238+ ANY , "w" , content_type = expect_content_type
239+ )
240+
219241 def test_parse_upload_format_envvar (self ):
220242 for envvar_value , expect in (
221243 ("" , "json" ),
@@ -246,7 +268,11 @@ def test_parse_upload_format_envvar(self):
246268 base_path = BASE_PATH , upload_format = "jsonl"
247269 )
248270 self .addCleanup (hook .shutdown )
249- self .assertEqual (hook ._format , "jsonl" )
271+ self .assertEqual (
272+ hook ._format ,
273+ "jsonl" ,
274+ "upload_format kwarg should take precedence" ,
275+ )
250276
251277 def test_upload_after_shutdown_logs (self ):
252278 self .hook .shutdown ()
@@ -409,3 +435,26 @@ def test_upload_jsonlines(self) -> None:
409435{"role":"user","parts":[{"response":{"capital":"Paris"},"id":"get_capital_0","type":"tool_call_response"}],"index":2}
410436""" ,
411437 )
438+
439+ def test_upload_chained_filesystem_ref (self ) -> None :
440+ """Using a chained filesystem like simplecache should refer to the final remote destination"""
441+ hook = UploadCompletionHook (
442+ base_path = "simplecache::memory" ,
443+ upload_format = "jsonl" ,
444+ )
445+ self .addCleanup (hook .shutdown )
446+ log_record = LogRecord ()
447+
448+ hook .on_completion (
449+ inputs = FAKE_INPUTS ,
450+ outputs = FAKE_OUTPUTS ,
451+ system_instruction = FAKE_SYSTEM_INSTRUCTION ,
452+ log_record = log_record ,
453+ )
454+ hook .shutdown ()
455+
456+ ref_uri : str = log_record .attributes ["gen_ai.input.messages_ref" ]
457+ self .assertTrue (
458+ ref_uri .startswith ("memory://" ),
459+ f"{ ref_uri = } does not start with final destination uri memory://" ,
460+ )
0 commit comments