Skip to content

Commit 8156eb4

Browse files
authored
Merge pull request #1712 from mgxd/fix/join
fix: joinnode connection bugfix
2 parents f00ab33 + 1c25969 commit 8156eb4

File tree

3 files changed

+49
-2
lines changed

3 files changed

+49
-2
lines changed

nipype/pipeline/engine/base.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,13 @@ def fullname(self):
6969
fullname = self._hierarchy + '.' + self.name
7070
return fullname
7171

72+
@property
73+
def itername(self):
74+
itername = self._id
75+
if self._hierarchy:
76+
itername = self._hierarchy + '.' + self._id
77+
return itername
78+
7279
def clone(self, name):
7380
"""Clone an EngineBase object
7481

nipype/pipeline/engine/tests/test_join.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -602,6 +602,46 @@ def test_set_join_node_file_input():
602602
os.chdir(cwd)
603603
rmtree(wd)
604604

605+
def test_nested_workflow_join():
606+
"""Test collecting join inputs within a nested workflow"""
607+
cwd = os.getcwd()
608+
wd = mkdtemp()
609+
os.chdir(wd)
610+
611+
# Make the nested workflow
612+
def nested_wf(i, name='smallwf'):
613+
#iterables with list of nums
614+
inputspec = pe.Node(IdentityInterface(fields=['n']), name='inputspec')
615+
inputspec.iterables = [('n', i)]
616+
# increment each iterable before joining
617+
pre_join = pe.Node(IncrementInterface(),
618+
name='pre_join')
619+
# rejoin nums into list
620+
join = pe.JoinNode(IdentityInterface(fields=['n']),
621+
joinsource='inputspec',
622+
joinfield='n',
623+
name='join')
624+
#define and connect nested workflow
625+
wf = pe.Workflow(name='wf_%d'%i[0])
626+
wf.connect(inputspec, 'n', pre_join, 'input1')
627+
wf.connect(pre_join, 'output1', join, 'n')
628+
return wf
629+
# master wf
630+
meta_wf = pe.Workflow(name='meta', base_dir='.')
631+
# add each mini-workflow to master
632+
for i in [[1,3],[2,4]]:
633+
mini_wf = nested_wf(i)
634+
meta_wf.add_nodes([mini_wf])
635+
636+
result = meta_wf.run()
637+
638+
# there should be six nodes in total
639+
assert_equal(len(result.nodes()), 6,
640+
"The number of expanded nodes is incorrect.")
641+
642+
os.chdir(cwd)
643+
rmtree(wd)
644+
605645

606646
if __name__ == "__main__":
607647
import nose

nipype/pipeline/engine/utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -711,7 +711,7 @@ def generate_expanded_graph(graph_in):
711711
in_edges = jedge_dict[jnode] = {}
712712
edges2remove = []
713713
for src, dest, data in graph_in.in_edges_iter(jnode, True):
714-
in_edges[src._id] = data
714+
in_edges[src.itername] = data
715715
edges2remove.append((src, dest))
716716

717717
for src, dest in edges2remove:
@@ -796,7 +796,7 @@ def make_field_func(*pair):
796796
expansions = defaultdict(list)
797797
for node in graph_in.nodes_iter():
798798
for src_id, edge_data in list(old_edge_dict.items()):
799-
if node._id.startswith(src_id):
799+
if node.itername.startswith(src_id):
800800
expansions[src_id].append(node)
801801
for in_id, in_nodes in list(expansions.items()):
802802
logger.debug("The join node %s input %s was expanded"

0 commit comments

Comments
 (0)