
Commit 0c89cac

fix: actually callable_chain does not work for langchain so we have to make runnables without decorators
Signed-off-by: thiswillbeyourgithub <[email protected]>
1 parent: 6b3ddab · commit: 0c89cac
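The change is mechanical: instead of stacking the custom @callable_chain decorator on top of LangChain's @chain, each function is now defined plainly and rebound afterwards with chain(...). A minimal sketch of the resulting pattern, assuming chain is langchain_core.runnables.chain; the names add_one and pipeline are illustrative, not part of wdoc:

from langchain_core.runnables import chain


def add_one(inputs: dict) -> dict:
    # a plain function, no decorators on the def
    return {**inputs, "n": inputs["n"] + 1}


# wrap after definition instead of stacking decorators on the def
add_one = chain(add_one)

# the wrapped object is a Runnable, so it composes with | like any other step
pipeline = add_one | add_one
print(pipeline.invoke({"n": 0}))  # -> {'n': 2}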

4 files changed: 20 additions, 63 deletions


wdoc/utils/customs/callable_runnable.py

Lines changed: 0 additions & 42 deletions
This file was deleted.

wdoc/utils/misc.py

Lines changed: 3 additions & 3 deletions
@@ -42,7 +42,6 @@
 from platformdirs import user_cache_dir
 from loguru import logger

-from wdoc.utils.customs.callable_runnable import callable_chain
 from wdoc.utils.env import env, is_input_piped, pytest_ongoing
 from wdoc.utils.errors import UnexpectedDocDictArgument

@@ -425,8 +424,6 @@ def html_to_text(html: str, remove_image: bool = False) -> str:
     return text


-@callable_chain
-@chain
 def debug_chain(inputs: Union[dict, List]) -> Union[dict, List]:
     "use it between | pipes | in a chain to open the debugger"
     if hasattr(inputs, "keys"):
@@ -435,6 +432,9 @@ def debug_chain(inputs: Union[dict, List]) -> Union[dict, List]:
     return inputs


+debug_chain = chain(debug_chain)
+
+
 def wrapped_model_name_matcher(model: str) -> str:
     "find the best match for a modelname (wrapped to make some check)"
     # find the currently set api keys to avoid matching models from
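With the decorators gone, debug_chain is a plain function that is rebound to chain(debug_chain) right after its definition. A hedged sketch of how such a step slots between pipes; the simplified body (the breakpoint() call) and the neighbouring some_step are assumptions for illustration, not the wdoc code:

from langchain_core.runnables import chain


def debug_step(inputs):
    "use it between | pipes | in a chain to open the debugger"
    if hasattr(inputs, "keys"):
        print(list(inputs.keys()))  # assumed body: show what is flowing through
    breakpoint()                    # assumed body: drop into pdb
    return inputs


debug_step = chain(debug_step)  # same wrap-after-definition pattern as the diff


def some_step(inputs: dict) -> dict:
    return {**inputs, "touched": True}


some_step = chain(some_step)

pipeline = some_step | debug_step | some_step
# pipeline.invoke({"question": "why?"})  # pauses in the debugger mid-chain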

wdoc/utils/tasks/query.py

Lines changed: 9 additions & 9 deletions
@@ -22,7 +22,6 @@
 from tqdm import tqdm
 from loguru import logger

-from wdoc.utils.customs.callable_runnable import callable_chain
 from wdoc.utils.env import env
 from wdoc.utils.errors import (
     InvalidDocEvaluationByLLMEval,
@@ -52,8 +51,6 @@ def sieve_documents(instance) -> RunnableLambda:
     we can end up with a lot more document!
     """

-    @callable_chain
-    @chain
     def _sieve(inputs: dict) -> dict:
         assert "question_to_answer" in inputs, inputs.keys()
         assert "unfiltered_docs" in inputs, inputs.keys()
@@ -71,11 +68,11 @@ def sieve_documents(instance) -> RunnableLambda:
         inputs["unfiltered_docs"] = inputs["unfiltered_docs"][: instance.top_k]
         return inputs

+    _sieve = chain(_sieve)
+
     return _sieve


-@callable_chain
-@chain
 @log_and_time_fn
 def refilter_docs(inputs: dict) -> List[Document]:
     "filter documents fond via RAG based on the digit answered by the eval llm"
@@ -118,6 +115,9 @@ def refilter_docs(inputs: dict) -> List[Document]:
     return filtered_docs


+refilter_docs = chain(refilter_docs)
+
+
 @log_and_time_fn
 def parse_eval_output(output: str) -> str:
     """
@@ -536,8 +536,6 @@ def pbar_chain(
 ) -> RunnableLambda:
     "create a chain that just sets a tqdm progress bar"

-    @callable_chain
-    @chain
     def actual_pbar_chain(
         inputs: Union[dict, List],
         llm: Union[ChatLiteLLM, FakeListChatModel] = llm,
@@ -554,6 +552,8 @@ def actual_pbar_chain(

         return inputs

+    actual_pbar_chain = chain(actual_pbar_chain)
+
     return actual_pbar_chain


@@ -562,8 +562,6 @@ def pbar_closer(
 ) -> RunnableLambda:
     "close a pbar created by pbar_chain"

-    @callable_chain
-    @chain
     def actual_pbar_closer(
         inputs: Union[dict, List],
         llm: Union[ChatLiteLLM, FakeListChatModel] = llm,
@@ -575,4 +573,6 @@ def actual_pbar_closer(

         return inputs

+    actual_pbar_closer = chain(actual_pbar_closer)
+
     return actual_pbar_closer
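The nested cases (sieve_documents, pbar_chain, pbar_closer) follow the same pattern inside a factory: define the inner function, rebind it with chain(...), and return the Runnable. A small sketch under the assumption that chain is langchain_core.runnables.chain; pbar_factory and set_pbar are illustrative names, not the actual wdoc functions:

from langchain_core.runnables import RunnableLambda, chain
from tqdm import tqdm


def pbar_factory(total: int) -> RunnableLambda:
    "create a chain step that just sets up a tqdm progress bar"

    def set_pbar(inputs: dict, pbar_total: int = total) -> dict:
        # closes over the factory argument, much like actual_pbar_chain
        # closes over llm in the real code
        inputs["pbar"] = tqdm(total=pbar_total)
        return inputs

    set_pbar = chain(set_pbar)  # wrap after definition, no decorator stacking
    return set_pbar


# the factory returns a Runnable that can be piped into a larger chain
step = pbar_factory(10)
out = step.invoke({"docs": []})
out["pbar"].close()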

wdoc/wdoc.py

Lines changed: 8 additions & 9 deletions
@@ -38,7 +38,6 @@
 from tqdm import tqdm
 from loguru import logger as logger

-from wdoc.utils.customs.callable_runnable import callable_chain
 from wdoc.utils.batch_file_loader import batch_load_doc
 from wdoc.utils.customs.fix_llm_caching import SQLiteCacheFixed
 from wdoc.utils.embeddings import create_embeddings, load_embeddings_engine
@@ -1415,8 +1414,6 @@ def eval_cache_wrapper(func: Callable) -> Callable:
                 f"Related github issue: 'https://github.com/langchain-ai/langchain/issues/23257'"
             )

-        @callable_chain
-        @chain
         def autoincrease_top_k(filtered_docs: List[Document]) -> List[Document]:
             if not self.max_top_k:
                 return filtered_docs
@@ -1439,8 +1436,8 @@ def autoincrease_top_k(filtered_docs: List[Document]) -> List[Document]:
                 )
                 return filtered_docs

-        @callable_chain
-        @chain
+        autoincrease_top_k = chain(autoincrease_top_k)
+
         @eval_cache_wrapper
         def evaluate_doc_chain(
             inputs: dict,
@@ -1564,6 +1561,8 @@ async def do_eval(subinputs):
                 self.eval_llm.callbacks[0].pbar[-1].update(1)
             return outputs

+        evaluate_doc_chain = chain(evaluate_doc_chain)
+
         # uses in most places to increase concurrency limit
         multi = {
             "max_concurrency": env.WDOC_LLM_MAX_CONCURRENCY if not self.debug else 1
@@ -1572,8 +1571,6 @@ async def do_eval(subinputs):
         if self.task == "search":
             if self.query_eval_model is not None:
                 # for some reason I needed to have at least one chain object otherwise rag_chain is a dict
-                @callable_chain
-                @chain
                 def retrieve_documents(inputs):
                     return {
                         "unfiltered_docs": retriever.invoke(
@@ -1583,6 +1580,8 @@ def retrieve_documents(inputs):
                     }
                     return inputs

+                retrieve_documents = chain(retrieve_documents)
+
                 meta_refilter_docs = {
                     "filtered_docs": (
                         RunnablePassthrough.assign(
@@ -1774,8 +1773,6 @@ def retrieve_documents(inputs):

             else:
                 # for some reason I needed to have at least one chain object otherwise rag_chain is a dict
-                @callable_chain
-                @chain
                 def retrieve_documents(inputs):
                     return {
                         "unfiltered_docs": retriever.invoke(
@@ -1785,6 +1782,8 @@ def retrieve_documents(inputs):
                     }
                     return inputs

+                retrieve_documents = chain(retrieve_documents)
+
                 meta_refilter_docs = {
                     "filtered_docs": (
                         RunnablePassthrough.assign(
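The inline comment ("I needed to have at least one chain object otherwise rag_chain is a dict") points at LangChain's composition rule: a plain dict is only coerced into a RunnableParallel when it is piped together with a Runnable, so retrieve_documents is wrapped with chain(...) to guarantee the composition yields a Runnable. A hedged illustration; the toy retriever body and the names below are assumptions, not the wdoc code:

from langchain_core.runnables import RunnablePassthrough, chain


def retrieve_documents(inputs: dict) -> dict:
    # stand-in for the real retriever.invoke(...) call
    return {"unfiltered_docs": ["doc1", "doc2"], **inputs}


retrieve_documents = chain(retrieve_documents)  # now a Runnable, as in the diff

# because the left operand is a Runnable, | coerces the dict on the right into
# a RunnableParallel and the whole expression becomes a RunnableSequence
rag_chain = retrieve_documents | {"filtered_docs": RunnablePassthrough()}
print(type(rag_chain).__name__)  # RunnableSequence, not dict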
