Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
20 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions pandasai/core/code_execution/environment.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@ def get_environment() -> dict:
"pd": import_dependency("pandas"),
"plt": import_dependency("matplotlib.pyplot"),
"np": import_dependency("numpy"),
"px": import_dependency("plotly.express"),
"go": import_dependency("plotly.graph_objects"),
}

return env
Expand Down
37 changes: 33 additions & 4 deletions pandasai/core/code_generation/code_cleaning.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,11 @@
import os.path
import re
import uuid
from pathlib import Path

import astor

from pandasai.agent.state import AgentState
from pandasai.constants import DEFAULT_CHART_DIRECTORY
from pandasai.core.code_execution.code_executor import CodeExecutor
from pandasai.query_builders.sql_parser import SQLParser

from ...exceptions import MaliciousQueryError
Expand Down Expand Up @@ -146,9 +144,12 @@ def clean_code(self, code: str) -> str:
tuple: Cleaned code as a string and a list of additional dependencies.
"""
code = self._replace_output_filenames_with_temp_chart(code)
code = self._replace_output_filenames_with_temp_json_chart(code)

# If plt.show is in the code, remove that line
code = re.sub(r"plt.show\(\)", "", code)
code = self._remove_make_dirs(code)

# If plt.show or fig.show is in the code, remove that line
code = re.sub(r"\b(?:plt|fig)\.show\(\)", "", code)

tree = ast.parse(code)
new_body = []
Expand Down Expand Up @@ -180,3 +181,31 @@ def _replace_output_filenames_with_temp_chart(self, code: str) -> str:
lambda m: f"{m.group(1)}{chart_path}{m.group(1)}",
code,
)

def _replace_output_filenames_with_temp_json_chart(self, code: str) -> str:
    """
    Point every quoted "*.json" file name in the code at a temporary chart file.

    One path of the form "temp_chart_<uuid4>.json" under DEFAULT_CHART_DIRECTORY
    is generated per call and substituted for each single- or double-quoted
    string ending in ".json" (in case of usage of plotly). The quote character
    of each original literal is kept.

    Args:
        code (str): Source code to rewrite.

    Returns:
        str: The code with all matching string literals replaced.
    """
    unique_name = f"temp_chart_{uuid.uuid4()}.json"
    # Double any backslashes in the path before it is embedded in the
    # code's string literal.
    target = os.path.join(DEFAULT_CHART_DIRECTORY, unique_name).replace(
        "\\", "\\\\"
    )

    def _swap(match):
        quote = match.group(1)
        return f"{quote}{target}{quote}"

    json_literal = re.compile(r"""(['"])([^'"]*\.json)\1""")
    return json_literal.sub(_swap, code)

def _remove_make_dirs(self, code: str) -> str:
    """
    Neutralize directory creation commands in the code.

    Every line containing "os.makedirs(" or "os.mkdir(" is replaced with a
    ``pass`` statement at the same indentation, unless the line mentions
    DEFAULT_CHART_DIRECTORY, in which case it is kept unchanged. Only the
    line holding the call is touched; calls spread over several lines are
    not specially handled.

    Args:
        code (str): Source code to clean.

    Returns:
        str: The cleaned code, with lines re-joined by newline characters
        (a trailing newline in the input is not preserved).
    """
    cleaned_lines = []
    for line in code.splitlines():
        creates_dir = "os.makedirs(" in line or "os.mkdir(" in line
        if creates_dir and DEFAULT_CHART_DIRECTORY not in line:
            # Substitute instead of dropping the line: when the call is the
            # only statement of a suite (e.g. under "if not os.path.exists(d):")
            # deleting it would leave an empty block and make the code
            # unparsable.
            indent = line[: len(line) - len(line.lstrip())]
            cleaned_lines.append(f"{indent}pass")
            continue
        cleaned_lines.append(line)
    return "\n".join(cleaned_lines)
Original file line number Diff line number Diff line change
Expand Up @@ -29,4 +29,4 @@ At the end, declare "result" variable as a dictionary of type and value in the f

Generate python code and return full updated code:

### Note: Use only relevant table for query and do aggregation, sorting, joins and grouby through sql query
### Note: Use only relevant table for query and do aggregation, sorting, joins and group by through sql query
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
{% if not output_type %}
type (possible values "string", "number", "dataframe", "plot"). Examples: { "type": "string", "value": f"The highest salary is {highest_salary}." } or { "type": "number", "value": 125 } or { "type": "dataframe", "value": pd.DataFrame({...}) } or { "type": "plot", "value": "temp_chart.png" }
type (possible values "string", "number", "dataframe", "plot", "iplot"). No other type available. "plot" is when "matplotlib" is used; "iplot" when "plotly" is used. Examples: { "type": "string", "value": f"The highest salary is {highest_salary}." } or { "type": "number", "value": 125 } or { "type": "dataframe", "value": pd.DataFrame({...}) } or { "type": "plot", "value": "temp_chart.png" } or { "type": "iplot", "value": "temp_chart.json" }
{% elif output_type == "number" %}
type (must be "number"), value must be int. Example: { "type": "number", "value": 125 }
{% elif output_type == "string" %}
Expand All @@ -8,4 +8,6 @@ type (must be "string"), value must be string. Example: { "type": "string", "val
type (must be "dataframe"), value must be pd.DataFrame or pd.Series. Example: { "type": "dataframe", "value": pd.DataFrame({...}) }
{% elif output_type == "plot" %}
type (must be "plot"), value must be string. Example: { "type": "plot", "value": "temp_chart.png" }
{% elif output_type == "iplot" %}
type (must be "iplot"), value must be string. Example: { "type": "iplot", "value": "temp_chart.json" }
{% endif %}
2 changes: 2 additions & 0 deletions pandasai/core/response/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from .chart import ChartResponse
from .dataframe import DataFrameResponse
from .error import ErrorResponse
from .interactive_chart import InteractiveChartResponse
from .number import NumberResponse
from .parser import ResponseParser
from .string import StringResponse
Expand All @@ -10,6 +11,7 @@
"ResponseParser",
"BaseResponse",
"ChartResponse",
"InteractiveChartResponse",
"DataFrameResponse",
"NumberResponse",
"StringResponse",
Expand Down
55 changes: 55 additions & 0 deletions pandasai/core/response/interactive_chart.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
import json
import os
from typing import Any

from .base import BaseResponse


class InteractiveChartResponse(BaseResponse):
    """
    Response holding an interactive chart.

    The value may be a dict, a JSON string, or a path to an existing file
    containing JSON. It is validated on construction.
    """

    def __init__(self, value: Any, last_code_executed: str):
        super().__init__(value, "ichart", last_code_executed)

        self._validate()

    def _get_chart(self) -> dict:
        """
        Return the chart content.

        A dict value is returned as is. A str value is read as a JSON file
        when it names an existing path, otherwise it is parsed as JSON text.

        Raises:
            ValueError: If the value is neither a dict nor a str.
        """
        payload = self.value
        if isinstance(payload, dict):
            return payload

        if not isinstance(payload, str):
            raise ValueError(
                "Invalid value type for InteractiveChartResponse. Expected dict or str."
            )

        # An existing path takes precedence over interpreting the text as JSON
        if os.path.exists(payload):
            with open(payload, "rb") as handle:
                return json.load(handle)

        return json.loads(payload)

    def save(self, path: str):
        """Write the chart content as JSON to the file at ``path``."""
        # Resolve the chart before opening the target file
        chart = self._get_chart()
        with open(path, "w") as out:
            json.dump(chart, out)

    def __str__(self) -> str:
        if isinstance(self.value, str):
            return self.value
        return json.dumps(self.value)

    def get_dict_image(self) -> dict:
        """Return the chart content (same result as ``_get_chart``)."""
        return self._get_chart()

    def _validate(self):
        """
        Check that the value is a dict, a JSON string, or an existing path.

        Raises:
            ValueError: If the value has another type, or is a str that is
                neither parseable JSON nor an existing path.
        """
        candidate = self.value
        if isinstance(candidate, dict):
            return

        if not isinstance(candidate, str):
            raise ValueError(
                "InteractiveChartResponse value must be a dict or a str representing a file path."
            )

        # A string is accepted when it parses as JSON; otherwise it must
        # name an existing path.
        try:
            json.loads(candidate)
        except json.JSONDecodeError:
            if not os.path.exists(candidate):
                raise ValueError(
                    "InteractiveChartResponse value must be a valid file path or a JSON string."
                )
18 changes: 18 additions & 0 deletions pandasai/core/response/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from .base import BaseResponse
from .chart import ChartResponse
from .dataframe import DataFrameResponse
from .interactive_chart import InteractiveChartResponse
from .number import NumberResponse
from .string import StringResponse

Expand All @@ -26,6 +27,8 @@ def _generate_response(self, result: dict, last_code_executed: str = None):
return DataFrameResponse(result["value"], last_code_executed)
elif result["type"] == "plot":
return ChartResponse(result["value"], last_code_executed)
elif result["type"] == "iplot":
return InteractiveChartResponse(result["value"], last_code_executed)
else:
raise InvalidOutputValueMismatch(f"Invalid output type: {result['type']}")

Expand Down Expand Up @@ -72,4 +75,19 @@ def _validate_response(self, result: dict):
"Invalid output: Expected a plot save path str but received an incompatible type."
)

elif result["type"] == "iplot":
if not isinstance(result["value"], (str, dict)):
raise InvalidOutputValueMismatch(
"Invalid output: Expected a plot save path str but received an incompatible type."
)

if isinstance(result["value"], dict):
return True

path_to_plot_pattern = r"^(\/[\w.-]+)+(/[\w.-]+)*$|^[^\s/]+(/[\w.-]+)*$"
if not bool(re.match(path_to_plot_pattern, result["value"])):
raise InvalidOutputValueMismatch(
"Invalid output: Expected a plot save path str but received an incompatible type."
)

return True
Loading
Loading