diff --git a/torchsummary/torchsummary.py b/torchsummary/torchsummary.py index 1ed065f..e9b89bf 100644 --- a/torchsummary/torchsummary.py +++ b/torchsummary/torchsummary.py @@ -5,8 +5,9 @@ from collections import OrderedDict import numpy as np +device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') -def summary(model, input_size, batch_size=-1, device=torch.device('cuda:0'), dtypes=None): +def summary(model, input_size, batch_size=-1, device=device, dtypes=None): result, params_info = summary_string( model, input_size, batch_size, device, dtypes) print(result) @@ -14,7 +15,7 @@ def summary(model, input_size, batch_size=-1, device=torch.device('cuda:0'), dty return params_info -def summary_string(model, input_size, batch_size=-1, device=torch.device('cuda:0'), dtypes=None): +def summary_string(model, input_size, batch_size=-1, device=device, dtypes=None): if dtypes == None: dtypes = [torch.FloatTensor]*len(input_size)