diff --git a/Lib/multiprocessing/shared_memory.py b/Lib/multiprocessing/shared_memory.py index 99a8ce3320ad4e..80b4cf4c9038ef 100644 --- a/Lib/multiprocessing/shared_memory.py +++ b/Lib/multiprocessing/shared_memory.py @@ -74,12 +74,10 @@ class SharedMemory: _track = True def __init__(self, name=None, create=False, size=0, *, track=True): - if not size >= 0: - raise ValueError("'size' must be a positive integer") + if size < 0: + raise ValueError("'size' must be a non-negative integer") if create: self._flags = _O_CREX | os.O_RDWR - if size == 0: - raise ValueError("'size' must be a positive number different from zero") if name is None and not self._flags & os.O_EXCL: raise ValueError("'name' can only be None if create=True") @@ -114,9 +112,11 @@ def __init__(self, name=None, create=False, size=0, *, track=True): os.ftruncate(self._fd, size) stats = os.fstat(self._fd) size = stats.st_size - self._mmap = mmap.mmap(self._fd, size) - except OSError: - self.unlink() + if size > 0: + self._mmap = mmap.mmap(self._fd, size) + except: + if create: + _posixshmem.shm_unlink(self._name) raise if self._track: resource_tracker.register(self._name, "shared_memory") @@ -125,7 +125,7 @@ def __init__(self, name=None, create=False, size=0, *, track=True): # Windows Named Shared Memory - if create: + if create and size > 0: while True: temp_name = _make_filename() if name is None else name # Create and reserve shared memory block with this name @@ -155,7 +155,9 @@ def __init__(self, name=None, create=False, size=0, *, track=True): _winapi.CloseHandle(h_map) self._name = temp_name break - + elif create and size == 0: + # TODO: Leave as None? + self._name = _make_filename() if name is None else name else: self._name = name # Dynamically determine the existing named shared memory @@ -179,10 +181,14 @@ def __init__(self, name=None, create=False, size=0, *, track=True): size = _winapi.VirtualQuerySize(p_buf) finally: _winapi.UnmapViewOfFile(p_buf) - self._mmap = mmap.mmap(-1, size, tagname=name) + if size > 0: + self._mmap = mmap.mmap(-1, size, tagname=name) self._size = size - self._buf = memoryview(self._mmap) + if size > 0: + self._buf = memoryview(self._mmap) + else: + self._buf = memoryview(b'') def __del__(self): try: diff --git a/Lib/test/_test_multiprocessing.py b/Lib/test/_test_multiprocessing.py index 4dc9a31d22f771..6514d0f786f6ea 100644 --- a/Lib/test/_test_multiprocessing.py +++ b/Lib/test/_test_multiprocessing.py @@ -4447,14 +4447,6 @@ def test_invalid_shared_memory_creation(self): with self.assertRaises(ValueError): sms_invalid = shared_memory.SharedMemory(create=True, size=-1) - # Test creating a shared memory segment with size 0 - with self.assertRaises(ValueError): - sms_invalid = shared_memory.SharedMemory(create=True, size=0) - - # Test creating a shared memory segment without size argument - with self.assertRaises(ValueError): - sms_invalid = shared_memory.SharedMemory(create=True) - def test_shared_memory_pickle_unpickle(self): for proto in range(pickle.HIGHEST_PROTOCOL + 1): with self.subTest(proto=proto): @@ -4851,6 +4843,33 @@ def test_shared_memory_tracking(self): resource_tracker.unregister(mem._name, "shared_memory") mem.close() + @unittest.skipIf(os.name != "posix", "windows automatically unlinks") + def test_creating_shm_unlinks_on_error(self): + name = self._new_shm_name("test_csuoe") + with unittest.mock.patch('mmap.mmap') as mock_mmap: + mock_mmap.side_effect = OSError("filesystems are evil") + with self.assertRaises(OSError): + shared_memory.SharedMemory(name, create=True, size=1) + with self.assertRaises(FileNotFoundError): + import _posixshmem + _posixshmem.shm_unlink(name) + + def test_existing_shm_not_unlinked_on_error(self): + name = self._new_shm_name("test_esnuoe") + mem = shared_memory.SharedMemory(name, create=True, size=1) + self.addCleanup(mem.unlink) + with unittest.mock.patch('mmap.mmap') as mock_mmap: + mock_mmap.side_effect = OSError("filesystems are evil") + with self.assertRaises(OSError): + shared_memory.SharedMemory(name, create=False) + + def test_zero_length_shared_memory(self): + name = self._new_shm_name("test_zlsm") + mem = shared_memory.SharedMemory(name, create=True, size=0) + self.addCleanup(mem.unlink) + self.assertEqual(mem.size, 0) + self.assertEqual(len(mem.buf), 0) + # # Test to verify that `Finalize` works. #