Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 8 additions & 1 deletion src/executorlib/standalone/serialize.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ def serialize_funct(
"kwargs": fn_kwargs,
}
)
task_key = fn.__name__ + _get_hash(binary=binary_all)
task_key = _get_function_name(fn=fn) + _get_hash(binary=binary_all)
data = {
"fn": fn,
"args": fn_args,
Expand All @@ -99,3 +99,10 @@ def _get_hash(binary: bytes) -> str:
# Remove specification of jupyter kernel from hash to be deterministic
binary_no_ipykernel = re.sub(b"(?<=/ipykernel_)(.*)(?=/)", b"", binary)
return str(hashlib.md5(binary_no_ipykernel).hexdigest())


def _get_function_name(fn: Callable) -> str:
if hasattr(fn, "__name__"):
return fn.__name__
else:
return str(fn).split()[0].split(".")[-1]
20 changes: 20 additions & 0 deletions tests/test_singlenodeexecutor_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,11 @@ def get_error(a):
raise ValueError(a)


class AddClass:
def __call__(self, a, b):
return a+b


@unittest.skipIf(
skip_h5py_test, "h5py is not installed, so the h5io tests are skipped."
)
Expand All @@ -34,6 +39,21 @@ def test_cache_data(self):
sum([sum(c["input_args"][0]) for c in cache_lst]), sum(result_lst)
)

def test_cache_data_class(self):
cache_directory = os.path.abspath("executorlib_cache")
with SingleNodeExecutor(cache_directory=cache_directory) as exe:
self.assertTrue(exe)
cloudpickle_register(ind=1)
add_instance = AddClass()
future_lst = [exe.submit(add_instance, a=i, b=i) for i in range(1, 4)]
result_lst = [f.result() for f in future_lst]

cache_lst = get_cache_data(cache_directory=cache_directory)
self.assertEqual(sum([c["output"] for c in cache_lst]), sum(result_lst))
self.assertEqual(
sum([sum([c["input_kwargs"]["a"], c["input_kwargs"]["b"]]) for c in cache_lst]), sum(result_lst)
)

def test_cache_key(self):
cache_directory = os.path.abspath("executorlib_cache")
with SingleNodeExecutor(cache_directory=cache_directory) as exe:
Expand Down
21 changes: 21 additions & 0 deletions tests/test_standalone_serialize.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
import unittest
from executorlib.standalone.serialize import _get_function_name


def my_function(a: int, b: int) -> int:
return a + b


class MyClass:
def __call__(self, a: int, b: int) -> int:
return a + b


class TestSerialization(unittest.TestCase):
def test_serialization(self):
fn = _get_function_name(fn=my_function)
self.assertEqual(fn, "my_function")
fn = _get_function_name(fn=MyClass())
self.assertEqual(fn, "MyClass")
fn = _get_function_name(fn=None)
self.assertEqual(fn, "None")
Loading