diff --git a/src/executorlib/standalone/serialize.py b/src/executorlib/standalone/serialize.py index 125cf97f..012508be 100644 --- a/src/executorlib/standalone/serialize.py +++ b/src/executorlib/standalone/serialize.py @@ -75,7 +75,7 @@ def serialize_funct( "kwargs": fn_kwargs, } ) - task_key = fn.__name__ + _get_hash(binary=binary_all) + task_key = _get_function_name(fn=fn) + _get_hash(binary=binary_all) data = { "fn": fn, "args": fn_args, @@ -99,3 +99,10 @@ def _get_hash(binary: bytes) -> str: # Remove specification of jupyter kernel from hash to be deterministic binary_no_ipykernel = re.sub(b"(?<=/ipykernel_)(.*)(?=/)", b"", binary) return str(hashlib.md5(binary_no_ipykernel).hexdigest()) + + +def _get_function_name(fn: Callable) -> str: + if hasattr(fn, "__name__"): + return fn.__name__ + else: + return str(fn).split()[0].split(".")[-1] diff --git a/tests/test_singlenodeexecutor_cache.py b/tests/test_singlenodeexecutor_cache.py index 650c0044..f04bd4dc 100644 --- a/tests/test_singlenodeexecutor_cache.py +++ b/tests/test_singlenodeexecutor_cache.py @@ -17,6 +17,11 @@ def get_error(a): raise ValueError(a) +class AddClass: + def __call__(self, a, b): + return a+b + + @unittest.skipIf( skip_h5py_test, "h5py is not installed, so the h5io tests are skipped." ) @@ -34,6 +39,21 @@ def test_cache_data(self): sum([sum(c["input_args"][0]) for c in cache_lst]), sum(result_lst) ) + def test_cache_data_class(self): + cache_directory = os.path.abspath("executorlib_cache") + with SingleNodeExecutor(cache_directory=cache_directory) as exe: + self.assertTrue(exe) + cloudpickle_register(ind=1) + add_instance = AddClass() + future_lst = [exe.submit(add_instance, a=i, b=i) for i in range(1, 4)] + result_lst = [f.result() for f in future_lst] + + cache_lst = get_cache_data(cache_directory=cache_directory) + self.assertEqual(sum([c["output"] for c in cache_lst]), sum(result_lst)) + self.assertEqual( + sum([sum([c["input_kwargs"]["a"], c["input_kwargs"]["b"]]) for c in cache_lst]), sum(result_lst) + ) + def test_cache_key(self): cache_directory = os.path.abspath("executorlib_cache") with SingleNodeExecutor(cache_directory=cache_directory) as exe: diff --git a/tests/test_standalone_serialize.py b/tests/test_standalone_serialize.py new file mode 100644 index 00000000..8e7104a1 --- /dev/null +++ b/tests/test_standalone_serialize.py @@ -0,0 +1,21 @@ +import unittest +from executorlib.standalone.serialize import _get_function_name + + +def my_function(a: int, b: int) -> int: + return a + b + + +class MyClass: + def __call__(self, a: int, b: int) -> int: + return a + b + + +class TestSerialization(unittest.TestCase): + def test_serialization(self): + fn = _get_function_name(fn=my_function) + self.assertEqual(fn, "my_function") + fn = _get_function_name(fn=MyClass()) + self.assertEqual(fn, "MyClass") + fn = _get_function_name(fn=None) + self.assertEqual(fn, "None")