diff --git a/torchsummary/torchsummary.py b/torchsummary/torchsummary.py index 1ed065f..4cd42d3 100644 --- a/torchsummary/torchsummary.py +++ b/torchsummary/torchsummary.py @@ -30,9 +30,15 @@ def hook(module, input, output): summary[m_key]["input_shape"] = list(input[0].size()) summary[m_key]["input_shape"][0] = batch_size if isinstance(output, (list, tuple)): - summary[m_key]["output_shape"] = [ - [-1] + list(o.size())[1:] for o in output - ] + output_shape = [] + for o in output: + if isinstance(o, (list, tuple)): + for item in o: + output_shape.append([-1] + list(item.size())[1:]) + else: + output_shape.append([-1] + list(o.size())[1:]) + + summary[m_key]["output_shape"] = output_shape else: summary[m_key]["output_shape"] = list(output.size()) summary[m_key]["output_shape"][0] = batch_size