diff --git a/redisvl/index/storage.py b/redisvl/index/storage.py index b92b8b3a..12ef2052 100644 --- a/redisvl/index/storage.py +++ b/redisvl/index/storage.py @@ -210,27 +210,28 @@ def write( keys_iterator = iter(keys) if keys else None added_keys: List[str] = [] - with redis_client.pipeline(transaction=False) as pipe: - for i, obj in enumerate(objects, start=1): - # Construct key, validate, and write - key = ( - next(keys_iterator) - if keys_iterator - else self._create_key(obj, id_field) - ) - obj = self._preprocess(obj, preprocess) - self._validate(obj) - self._set(pipe, key, obj) - # Set TTL if provided - if ttl: - pipe.expire(key, ttl) - # Execute mini batch - if i % batch_size == 0: + if objects: + with redis_client.pipeline(transaction=False) as pipe: + for i, obj in enumerate(objects, start=1): + # Construct key, validate, and write + key = ( + next(keys_iterator) + if keys_iterator + else self._create_key(obj, id_field) + ) + obj = self._preprocess(obj, preprocess) + self._validate(obj) + self._set(pipe, key, obj) + # Set TTL if provided + if ttl: + pipe.expire(key, ttl) + # Execute mini batch + if i % batch_size == 0: + pipe.execute() + added_keys.append(key) + # Clean up batches if needed + if i % batch_size != 0: pipe.execute() - added_keys.append(key) - # Clean up batches if needed - if i % batch_size != 0: - pipe.execute() return added_keys diff --git a/tests/unit/test_async_search_index.py b/tests/unit/test_async_search_index.py index ff1c7d9f..3e7f6793 100644 --- a/tests/unit/test_async_search_index.py +++ b/tests/unit/test_async_search_index.py @@ -129,6 +129,13 @@ async def bad_preprocess(record): await async_index.load(data, id_field="id", preprocess=bad_preprocess) +@pytest.mark.asyncio +async def test_search_index_load_empty(async_client, async_index): + async_index.set_client(async_client) + await async_index.create(overwrite=True, drop=True) + await async_index.load([]) + + @pytest.mark.asyncio async def test_no_id_field(async_client, async_index): async_index.set_client(async_client) diff --git a/tests/unit/test_search_index.py b/tests/unit/test_search_index.py index 8a27e5b3..503153a2 100644 --- a/tests/unit/test_search_index.py +++ b/tests/unit/test_search_index.py @@ -125,6 +125,12 @@ def bad_preprocess(record): index.load(data, id_field="id", preprocess=bad_preprocess) +def test_search_index_load_empty(client, index): + index.set_client(client) + index.create(overwrite=True, drop=True) + index.load([]) + + def test_no_id_field(client, index): index.set_client(client) index.create(overwrite=True, drop=True)