Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 15 additions & 2 deletions icat/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,8 @@ def _attrpath(self, attrname):
"""
rclass = self.entity
pattr = ""
if attrname.endswith(')'):
attrname = (attrname.split("("))[1].split(")")[0]
for attr in attrname.split('.'):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think this is the proper place to deal with that. _attrpath() is a low level helper with a well defined scope: to check an attribute name, iterating over the components of the dotted path to related objects as needed. It should not deal with higher level issues such as whether that attribute is used inside an SQL function. It is called from many places, not in all of them an SQL function would make sense.

Also, I'm not sure about the way how the attribute name and SQL function is parsed here.

if pattr:
pattr += ".%s" % attr
Expand Down Expand Up @@ -169,6 +171,8 @@ def _makesubst(self, objs):
i = obj.rfind('.')
if i < 0:
continue
if obj.endswith(')'):
obj = (obj.split("("))[1].split(")")[0]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Similar comment as above: _makesubst() should rather not need to care about SQL functions.

obj = obj[:i]
for (o, attrInfo, oclass) in self._attrpath(obj):
if o not in subst:
Expand Down Expand Up @@ -373,7 +377,10 @@ def addConditions(self, conditions):
"""
if conditions:
for a in conditions.keys():
for (pattr, attrInfo, rclass) in self._attrpath(a):
a_name = a
if a.endswith(')'):
a_name = (a.split("("))[1].split(")")[0]
for (pattr, attrInfo, rclass) in self._attrpath(a_name):
pass
if a in self.conditions:
conds = []
Expand Down Expand Up @@ -468,7 +475,13 @@ def __str__(self):
if self.conditions:
conds = []
for a in sorted(self.conditions.keys()):
attr = self._dosubst(a, subst, False)
a_name = a
if a.endswith(')'):
sql_function_name = a.split("(")[0]
a_name = (a.split("("))[1].split(")")[0]
attr = self._dosubst(a_name, subst, False)
if "sql_function_name" in locals():
attr = f"{sql_function_name}({attr})"
cond = self.conditions[a]
if isinstance(cond, str):
conds.append("%s %s" % (attr, cond))
Expand Down