Skip to content

Commit 8d71bce

Browse files
feat: Added langflow components and prompt API
1 parent 038b7a7 commit 8d71bce

File tree

12 files changed

+220
-11
lines changed

12 files changed

+220
-11
lines changed

.vscode/settings.json

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
{
2+
"python.linting.enabled": true,
3+
"python.formatting.blackArgs": [
4+
"--line-length",
5+
"140"
6+
],
7+
"python.formatting.provider": "black",
8+
"editor.formatOnSave": true,
9+
"python.linting.mypyEnabled": false,
10+
"python.analysis.typeCheckingMode": "basic"
11+
}

requirements.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,3 +4,5 @@ openai
44
sqlalchemy
55
pyjwt[crypto]
66
ipython
7+
langflow
8+
black

server/api/chatbot.py

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,12 @@
88

99
chatbot_router = APIRouter(prefix="/chatbot", tags=["chatbot"])
1010

11+
1112
class ChatBotModel(BaseModel):
1213
name: str
1314
dag: dict
1415

16+
1517
@chatbot_router.post("/", status_code=200)
1618
def create_chatbot(inputs: ChatBotModel, token: Annotated[str, Header()], db: Session = Depends(database.db_session)):
1719
username = get_user_from_jwt(token)
@@ -26,18 +28,19 @@ def create_chatbot(inputs: ChatBotModel, token: Annotated[str, Header()], db: Se
2628
response = {"msg": "failed"}
2729
return response
2830

31+
2932
# @chatbot_router.get("/", status_code=200)
3033

3134
# @chatbot_router.put("/", status_code=200)
3235
# def update_chatbot(inputs: ChatBotModel, db: Session = Depends(database.db_session)):
33-
# user: User = db.query(User).filter((User.username == inputs.username) & (User.password == inputs.old_password)).first()
34-
# if user is not None:
35-
# user.password = inputs.new_password
36-
# db.commit()
37-
# response = {"msg": "success"}
38-
# else:
39-
# response = {"msg": "failed"}
40-
# return response
36+
# user: User = db.query(User).filter((User.username == inputs.username) & (User.password == inputs.old_password)).first()
37+
# if user is not None:
38+
# user.password = inputs.new_password
39+
# db.commit()
40+
# response = {"msg": "success"}
41+
# else:
42+
# response = {"msg": "failed"}
43+
# return response
4144

4245
# @chatbot_router.delete("/", status_code=200)
4346
# def update_chatbot(inputs: ChatBotModel, db: Session = Depends(database.db_session)):

server/api/langflow.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
import logging
2+
from typing import Any, Dict
3+
4+
from fastapi import APIRouter, HTTPException
5+
6+
from langflow.interface.run import process_graph
7+
from langflow.interface.types import build_langchain_types_dict
8+
9+
# build router
10+
router = APIRouter(prefix="/flow", tags=["flow"])
11+
# add docs to router
12+
router.__doc__ = """
13+
# Flow API
14+
"""
15+
16+
logger = logging.getLogger(__name__)
17+
18+
19+
@router.get("/components")
20+
def get_all():
21+
return build_langchain_types_dict()

server/api/prompts.py

Lines changed: 127 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,127 @@
1+
import logging
2+
from typing import Any, Dict
3+
from sqlalchemy.orm import Session
4+
from database import db_session
5+
import contextlib
6+
import io
7+
8+
from fastapi import APIRouter, HTTPException, Depends
9+
10+
from langflow.interface.run import load_langchain_object, save_cache, fix_memory_inputs
11+
12+
from database_utils.chatbot import get_chatbot
13+
from schemas.prompt_schema import PromptSchema
14+
15+
# build router
16+
router = APIRouter(tags=["prompts"])
17+
# add docs to router
18+
router.__doc__ = """
19+
# Prompts API
20+
"""
21+
22+
logger = logging.getLogger(__name__)
23+
24+
25+
def format_intermediate_steps(intermediate_steps):
26+
formatted_chain = []
27+
for step in intermediate_steps:
28+
action = step[0]
29+
observation = step[1]
30+
31+
formatted_chain.append(
32+
{
33+
"action": action.tool,
34+
"action_input": action.tool_input,
35+
"observation": observation,
36+
}
37+
)
38+
return formatted_chain
39+
40+
41+
def get_result_and_thought_using_graph(langchain_object, message: str):
42+
"""Get result and thought from extracted json"""
43+
try:
44+
if hasattr(langchain_object, "verbose"):
45+
langchain_object.verbose = True
46+
chat_input = None
47+
memory_key = ""
48+
if hasattr(langchain_object, "memory") and langchain_object.memory is not None:
49+
memory_key = langchain_object.memory.memory_key
50+
51+
for key in langchain_object.input_keys:
52+
if key not in [memory_key, "chat_history"]:
53+
chat_input = {key: message}
54+
55+
if hasattr(langchain_object, "return_intermediate_steps"):
56+
# https://github.com/hwchase17/langchain/issues/2068
57+
# Deactivating until we have a frontend solution
58+
# to display intermediate steps
59+
langchain_object.return_intermediate_steps = True
60+
61+
fix_memory_inputs(langchain_object)
62+
63+
with io.StringIO() as output_buffer, contextlib.redirect_stdout(output_buffer):
64+
try:
65+
output = langchain_object(chat_input)
66+
except ValueError as exc:
67+
# make the error message more informative
68+
logger.debug(f"Error: {str(exc)}")
69+
output = langchain_object.run(chat_input)
70+
71+
intermediate_steps = output.get("intermediate_steps", []) if isinstance(output, dict) else []
72+
73+
result = output.get(langchain_object.output_keys[0]) if isinstance(output, dict) else output
74+
if intermediate_steps:
75+
thought = format_intermediate_steps(intermediate_steps)
76+
else:
77+
thought = {"steps": output_buffer.getvalue()}
78+
79+
except Exception as exc:
80+
raise ValueError(f"Error: {str(exc)}") from exc
81+
return result, thought
82+
83+
84+
def process_graph(message, chat_history, data_graph):
85+
"""
86+
Process graph by extracting input variables and replacing ZeroShotPrompt
87+
with PromptTemplate,then run the graph and return the result and thought.
88+
"""
89+
# Load langchain object
90+
logger.debug("Loading langchain object")
91+
is_first_message = len(chat_history) == 0
92+
computed_hash, langchain_object = load_langchain_object(data_graph, is_first_message)
93+
logger.debug("Loaded langchain object")
94+
95+
if langchain_object is None:
96+
# Raise user facing error
97+
raise ValueError("There was an error loading the langchain_object. Please, check all the nodes and try again.")
98+
99+
# Generate result and thought
100+
logger.debug("Generating result and thought")
101+
result, thought = get_result_and_thought_using_graph(langchain_object, message)
102+
logger.debug("Generated result and thought")
103+
104+
# Save langchain_object to cache
105+
# We have to save it here because if the
106+
# memory is updated we need to keep the new values
107+
logger.debug("Saving langchain object to cache")
108+
save_cache(computed_hash, langchain_object, is_first_message)
109+
logger.debug("Saved langchain object to cache")
110+
return {"result": str(result), "thought": thought}
111+
112+
113+
@router.post("/chatbot/{chatbot_id}/prompt")
114+
def get_prompt(chatbot_id: int, prompt: PromptSchema, db: Session = Depends(db_session)):
115+
try:
116+
chatbot = get_chatbot(db, chatbot_id)
117+
118+
# Process graph
119+
logger.debug("Processing graph")
120+
result = process_graph(prompt.new_message, prompt.chat_history, chatbot.dag)
121+
122+
logger.debug("Processed graph")
123+
return result
124+
125+
except Exception as e:
126+
logger.exception(e)
127+
raise HTTPException(status_code=500, detail=str(e)) from e

server/app.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@
1414
from api.metrics import metrics_router
1515
from api.feedback import feedback_router
1616
from api.intermediate_steps import intermediate_steps_router
17+
from api.langflow import router as langflow_router
18+
from api.prompts import router as prompts_router
1719

1820
c.Base.metadata.create_all(bind=engine)
1921

@@ -34,7 +36,8 @@
3436
app.include_router(intermediate_steps_router)
3537
app.include_router(chatbot_router)
3638
app.include_router(auth_router)
37-
39+
app.include_router(langflow_router)
40+
app.include_router(prompts_router)
3841
####################################################
3942
################ APIs ##############################
4043
####################################################

server/database.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -60,8 +60,8 @@ class Prompt(Base):
6060
gpt_rating = Column(Text, nullable=True)
6161
user_rating = Column(Enum(PromptRating), nullable=True)
6262
chatbot_user_rating = input_prompt = Column(Enum(PromptRating), nullable=True)
63-
response = Column(Text, nullable=False)
64-
time_taken = Column(Float, nullable=False)
63+
response = Column(Text, nullable=True)
64+
time_taken = Column(Float, nullable=True)
6565
created_at = Column(DateTime, nullable=False)
6666
session_id = Column(String(80), nullable=False)
6767
meta = Column(JSON)

server/database_utils/__init__.py

Whitespace-only changes.

server/database_utils/chatbot.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
from sqlalchemy.orm import Session
2+
3+
from database import ChatBot
4+
5+
6+
def get_chatbot(db: Session, chatbot_id: int) -> ChatBot:
7+
row = db.query(ChatBot).filter(ChatBot.id == chatbot_id).first()
8+
9+
if row is None:
10+
raise ValueError(f"Chatbot with id {chatbot_id} does not exist")
11+
12+
return row

server/database_utils/prompt.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
from datetime import datetime
2+
from typing import List
3+
4+
from sqlalchemy.orm import Session
5+
6+
from database import ChatBot, Prompt, IntermediateStep
7+
8+
9+
def get_prompts(db: Session, chatbot_id: int) -> List[Prompt]:
10+
row = db.query(Prompt).filter(Prompt.chatbot_id == chatbot_id).all()
11+
return row
12+
13+
14+
def create_prompt(db: Session, input_prompt: str, session_id=str) -> Prompt:
15+
db_prompt = Prompt(
16+
input_prompt=input_prompt,
17+
created_at=datetime.now(),
18+
session_id=session_id,
19+
)
20+
db.add(db_prompt)
21+
db.commit()
22+
db.refresh(db_prompt)
23+
return db_prompt

0 commit comments

Comments
 (0)