
Commit ff46299

Dspy fanout cache (#8062)
* init
* increment
* add test
* remove old cache code
* rename fanout_cache to disk_cache
* add testing
* fix tests
* add cache for colbert
* fix tests
* increase timeout to 10s
* change default value
* cache embedding
* fallback to litellm cache for embedding call
* fix test
1 parent bc3648a commit ff46299

13 files changed, +664 −180 lines changed


dspy/__init__.py (+3)

@@ -16,6 +16,7 @@
 from dspy.utils.usage_tracker import track_usage
 
 from dspy.dsp.utils.settings import settings
+from dspy.clients import DSPY_CACHE
 
 configure_dspy_loggers(__name__)
 
@@ -28,3 +29,5 @@
 BootstrapRS = BootstrapFewShotWithRandomSearch
 
 from .__metadata__ import __name__, __version__, __description__, __url__, __author__, __author_email__
+
+cache = DSPY_CACHE
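
The net effect of this change is a process-wide cache handle exposed at the top level. A minimal sketch of what that enables (the assertion and the reset call are illustrative, not part of the commit):

import dspy
from dspy.clients import DSPY_CACHE

# `dspy.cache` is the module-level alias for the DSPY_CACHE singleton created in
# dspy.clients at import time, so all callers in the process share one cache.
assert dspy.cache is DSPY_CACHE

# Drop the in-memory LRU layer, e.g. between experiments; the disk layer is untouched.
dspy.cache.reset_memory_cache()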

dspy/clients/__init__.py (+61 −12)

@@ -5,8 +5,10 @@
 import litellm
 import os
 from pathlib import Path
-from litellm.caching import Cache
+from dspy.clients.cache import Cache
 import logging
+from typing import Optional
+from litellm.caching.caching import Cache as LitellmCache
 
 logger = logging.getLogger(__name__)
 
@@ -21,20 +23,66 @@ def _litellm_track_cache_hit_callback(kwargs, completion_response, start_time, e
 
 litellm.success_callback = [_litellm_track_cache_hit_callback]
 
-try:
-    # TODO: There's probably value in getting litellm to support FanoutCache and to separate the limit for
-    # the LM cache from the embeddings cache. Then we can lower the default 30GB limit.
-    litellm.cache = Cache(disk_cache_dir=DISK_CACHE_DIR, type="disk")
 
-    if litellm.cache.cache.disk_cache.size_limit != DISK_CACHE_LIMIT:
-        litellm.cache.cache.disk_cache.reset("size_limit", DISK_CACHE_LIMIT)
-except Exception as e:
-    # It's possible that users don't have the write permissions to the cache directory.
-    # In that case, we'll just disable the cache.
-    logger.warning("Failed to initialize LiteLLM cache: %s", e)
-    litellm.cache = None
+def configure_cache(
+    enable_disk_cache: Optional[bool] = True,
+    enable_memory_cache: Optional[bool] = True,
+    disk_cache_dir: Optional[str] = DISK_CACHE_DIR,
+    disk_size_limit_bytes: Optional[int] = DISK_CACHE_LIMIT,
+    memory_max_entries: Optional[int] = 1000000,
+    enable_litellm_cache: bool = False,
+):
+    """Configure the cache for DSPy.
+
+    Args:
+        enable_disk_cache: Whether to enable on-disk cache.
+        enable_memory_cache: Whether to enable in-memory cache.
+        disk_cache_dir: The directory to store the on-disk cache.
+        disk_size_limit_bytes: The size limit of the on-disk cache.
+        memory_max_entries: The maximum number of entries in the in-memory cache.
+        enable_litellm_cache: Whether to enable LiteLLM cache.
+    """
+    if enable_disk_cache and enable_litellm_cache:
+        raise ValueError(
+            "Cannot enable both LiteLLM and DSPy on-disk cache, please set at most one of `enable_disk_cache` or "
+            "`enable_litellm_cache` to True."
+        )
+
+    if enable_litellm_cache:
+        try:
+            litellm.cache = LitellmCache(disk_cache_dir=DISK_CACHE_DIR, type="disk")
+
+            if litellm.cache.cache.disk_cache.size_limit != DISK_CACHE_LIMIT:
+                litellm.cache.cache.disk_cache.reset("size_limit", DISK_CACHE_LIMIT)
+        except Exception as e:
+            # It's possible that users don't have the write permissions to the cache directory.
+            # In that case, we'll just disable the cache.
+            logger.warning("Failed to initialize LiteLLM cache: %s", e)
+            litellm.cache = None
+    else:
+        litellm.cache = None
+
+    import dspy
+
+    dspy.cache = Cache(
+        enable_disk_cache,
+        enable_memory_cache,
+        disk_cache_dir,
+        disk_size_limit_bytes,
+        memory_max_entries,
+    )
 
 litellm.telemetry = False
+litellm.cache = None  # By default we disable litellm cache and use DSPy on-disk cache.
+
+DSPY_CACHE = Cache(
+    enable_disk_cache=True,
+    enable_memory_cache=True,
+    disk_cache_dir=DISK_CACHE_DIR,
+    disk_size_limit_bytes=DISK_CACHE_LIMIT,
+    memory_max_entries=1000000,
+)
 
 # Turn off by default to avoid LiteLLM logging during every LM call.
 litellm.suppress_debug_info = True
@@ -61,4 +109,5 @@ def disable_litellm_logging():
     "Embedder",
     "enable_litellm_logging",
     "disable_litellm_logging",
+    "configure_cache",
 ]
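
A hedged sketch of calling the new `configure_cache` entry point; the argument values below are illustrative choices, not defaults from the commit:

from dspy.clients import configure_cache

# Memory-only caching: skip disk writes entirely. This replaces `dspy.cache`
# with a fresh Cache instance built from these settings.
configure_cache(
    enable_disk_cache=False,
    enable_memory_cache=True,
    memory_max_entries=100_000,
)

# Alternatively, fall back to LiteLLM's own disk cache. `enable_disk_cache`
# must be False here, otherwise the ValueError above is raised.
configure_cache(enable_disk_cache=False, enable_litellm_cache=True)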

dspy/clients/cache.py (+215, new file)

@@ -0,0 +1,215 @@

import copy
import logging
import threading
from functools import wraps
from hashlib import sha256
from typing import Any, Dict, Optional

import cloudpickle
import pydantic
import ujson
from cachetools import LRUCache
from diskcache import FanoutCache

logger = logging.getLogger(__name__)


class Cache:
    """DSPy Cache

    `Cache` provides 2 levels of caching (in the given order):
        1. In-memory cache - implemented with cachetools.LRUCache
        2. On-disk cache - implemented with diskcache.FanoutCache
    """

    def __init__(
        self,
        enable_disk_cache: bool,
        enable_memory_cache: bool,
        disk_cache_dir: str,
        disk_size_limit_bytes: Optional[int] = 1024 * 1024 * 10,
        memory_max_entries: Optional[int] = 1000000,
        ignored_args_for_cache_key: Optional[list[str]] = None,
    ):
        """
        Args:
            enable_disk_cache: Whether to enable on-disk cache.
            enable_memory_cache: Whether to enable in-memory cache.
            disk_cache_dir: The directory where the disk cache is stored.
            disk_size_limit_bytes: The maximum size of the disk cache (in bytes).
            memory_max_entries: The maximum size of the in-memory cache (in number of items).
            ignored_args_for_cache_key: A list of arguments to ignore when computing the cache key from the request.
        """
        self.enable_disk_cache = enable_disk_cache
        self.enable_memory_cache = enable_memory_cache
        if self.enable_memory_cache:
            self.memory_cache = LRUCache(maxsize=memory_max_entries)
        else:
            self.memory_cache = {}
        if self.enable_disk_cache:
            self.disk_cache = FanoutCache(
                shards=16,
                timeout=10,
                directory=disk_cache_dir,
                size_limit=disk_size_limit_bytes,
            )
        else:
            self.disk_cache = {}

        self.ignored_args_for_cache_key = ignored_args_for_cache_key or []

        self._lock = threading.RLock()

    def __contains__(self, key: str) -> bool:
        """Check if a key is in the cache."""
        return key in self.memory_cache or key in self.disk_cache

    def cache_key(self, request: Dict[str, Any]) -> str:
        """
        Obtain a unique cache key for the given request dictionary by hashing its JSON
        representation. For request fields having types that are known to be JSON-incompatible,
        convert them to a JSON-serializable format before hashing.
        """

        def transform_value(value):
            if isinstance(value, type) and issubclass(value, pydantic.BaseModel):
                return value.model_json_schema()
            elif isinstance(value, pydantic.BaseModel):
                return value.model_dump()
            elif callable(value):
                # Try to get the source code of the callable if available
                import inspect

                try:
                    # For regular functions, we can get the source code
                    return f"<callable_source:{inspect.getsource(value)}>"
                except (TypeError, OSError, IOError):
                    # For lambda functions or other callables where source isn't available,
                    # use a string representation
                    return f"<callable:{value.__name__ if hasattr(value, '__name__') else 'lambda'}>"
            elif isinstance(value, dict):
                return {k: transform_value(v) for k, v in value.items()}
            else:
                return value

        params = {k: transform_value(v) for k, v in request.items() if k not in self.ignored_args_for_cache_key}
        return sha256(ujson.dumps(params, sort_keys=True).encode()).hexdigest()

    def get(self, request: Dict[str, Any]) -> Any:
        try:
            key = self.cache_key(request)
        except Exception:
            logger.debug(f"Failed to generate cache key for request: {request}")
            return None

        if self.enable_memory_cache and key in self.memory_cache:
            with self._lock:
                response = self.memory_cache[key]
        elif self.enable_disk_cache and key in self.disk_cache:
            # Found on disk but not in memory cache, add to memory cache
            response = self.disk_cache[key]
            if self.enable_memory_cache:
                with self._lock:
                    self.memory_cache[key] = response
        else:
            return None

        response = copy.deepcopy(response)
        if hasattr(response, "usage"):
            # Clear the usage data when cache is hit, because no LM call is made
            response.usage = {}
        return response

    def put(self, request: Dict[str, Any], value: Any) -> None:
        try:
            key = self.cache_key(request)
        except Exception:
            logger.debug(f"Failed to generate cache key for request: {request}")
            return

        if self.enable_memory_cache:
            with self._lock:
                self.memory_cache[key] = value

        if self.enable_disk_cache:
            try:
                self.disk_cache[key] = value
            except Exception as e:
                # Disk cache writing can fail for different reasons, e.g. disk full or the `value` is not picklable.
                logger.debug(f"Failed to put value in disk cache: {value}, {e}")

    def reset_memory_cache(self) -> None:
        if not self.enable_memory_cache:
            return

        with self._lock:
            self.memory_cache.clear()

    def save_memory_cache(self, filepath: str) -> None:
        if not self.enable_memory_cache:
            return

        with self._lock:
            with open(filepath, "wb") as f:
                cloudpickle.dump(self.memory_cache, f)

    def load_memory_cache(self, filepath: str) -> None:
        if not self.enable_memory_cache:
            return

        with self._lock:
            with open(filepath, "rb") as f:
                self.memory_cache = cloudpickle.load(f)
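
To make the two cache levels and the key normalization concrete, here is a minimal sketch; the `Judgment` model, the request payload, and the snapshot path are invented for the example:

import pydantic

from dspy.clients.cache import Cache

class Judgment(pydantic.BaseModel):
    verdict: str

# Memory-only instance so the example writes no files; `disk_cache_dir` is
# ignored when the disk layer is disabled.
cache = Cache(
    enable_disk_cache=False,
    enable_memory_cache=True,
    disk_cache_dir="",
    memory_max_entries=1000,
)

# Pydantic model instances are normalized via model_dump() before hashing,
# so two equal models produce the same sha256 cache key.
request = {"model": "example-model", "judgment": Judgment(verdict="ok")}
cache.put(request, {"text": "cached response"})
assert cache.get(request) == {"text": "cached response"}

# The in-memory LRU can be snapshotted with cloudpickle and restored later.
cache.save_memory_cache("/tmp/dspy_memory_cache.pkl")
cache.load_memory_cache("/tmp/dspy_memory_cache.pkl")

The same file also adds a `request_cache` decorator that routes arbitrary function calls through `dspy.cache`: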
def request_cache(cache_arg_name: Optional[str] = None, ignored_args_for_cache_key: Optional[list[str]] = None):
    """Decorator for applying caching to a function based on the request argument.

    Args:
        cache_arg_name: The name of the argument that contains the request. If not provided, the entire kwargs is used
            as the request.
        ignored_args_for_cache_key: A list of arguments to ignore when computing the cache key from the request.
    """

    def decorator(fn):
        @wraps(fn)
        def wrapper(*args, **kwargs):
            import dspy

            cache = dspy.cache
            original_ignored_args_for_cache_key = cache.ignored_args_for_cache_key
            cache.ignored_args_for_cache_key = ignored_args_for_cache_key or []

            # Use fully qualified function name for uniqueness
            fn_identifier = f"{fn.__module__}.{fn.__qualname__}"

            # Create a modified request that includes the function identifier so that it's incorporated into the cache
            # key. Deep copy is required because litellm sometimes modifies the kwargs in place.
            if cache_arg_name:
                # When `cache_arg_name` is provided, use the value of the argument with this name as the request for
                # caching.
                modified_request = copy.deepcopy(kwargs[cache_arg_name])
            else:
                # When `cache_arg_name` is not provided, use the entire kwargs as the request for caching.
                modified_request = copy.deepcopy(kwargs)
                for i, arg in enumerate(args):
                    modified_request[f"positional_arg_{i}"] = arg
            modified_request["_fn_identifier"] = fn_identifier

            # Retrieve from cache if available
            cached_result = cache.get(modified_request)

            if cached_result is not None:
                return cached_result

            # Otherwise, compute and store the result
            result = fn(*args, **kwargs)
            cache.put(modified_request, result)

            cache.ignored_args_for_cache_key = original_ignored_args_for_cache_key
            return result

        return wrapper

    return decorator
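
A minimal sketch of the decorator in use; `call_model` and its arguments are hypothetical stand-ins for an expensive call:

from dspy.clients.cache import request_cache

@request_cache(ignored_args_for_cache_key=["api_key"])
def call_model(prompt: str, api_key: str = "") -> str:
    # Stand-in for an expensive LM call.
    return f"echo: {prompt}"

call_model(prompt="hi", api_key="key-1")  # miss: computed, then stored in dspy.cache
call_model(prompt="hi", api_key="key-2")  # hit: api_key is excluded from the cache key

One detail worth noting: as written, the wrapper restores the shared cache's `ignored_args_for_cache_key` only on the miss path; a cache hit returns before the restore, so the override lingers until the next miss.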
