I have created my own custom layer/module and found that its parameters are not counted properly in the summary.
Looking at the code, I found what could be the reason:
```python
if hasattr(module, "weight") and hasattr(module.weight, "size"):
    params += torch.prod(torch.LongTensor(list(module.weight.size())))
    summary[m_key]["trainable"] = module.weight.requires_grad
if hasattr(module, "bias") and hasattr(module.bias, "size"):
    params += torch.prod(torch.LongTensor(list(module.bias.size())))
```
So basically it only looks for parameters named `weight` or `bias`. I think it would be great if the code were updated to handle trainable parameters with arbitrary names.
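For illustration, here is a minimal sketch of how the count could be generalized by iterating over all parameters registered directly on the module instead of hard-coding `weight`/`bias`. The function name `count_module_params` and the `MyLayer` example are hypothetical and not part of the library:

```python
import torch
import torch.nn as nn


def count_module_params(module: nn.Module):
    """Count parameters registered directly on `module`, whatever their names.

    Returns (total_params, any_trainable). A sketch of a possible generalization,
    not the library's current implementation.
    """
    total = 0
    trainable = False
    for p in module.parameters(recurse=False):  # only this module's own parameters
        total += int(torch.prod(torch.LongTensor(list(p.size()))))
        trainable = trainable or p.requires_grad
    return total, trainable


# Example: a custom layer whose parameter is named neither "weight" nor "bias"
class MyLayer(nn.Module):
    def __init__(self):
        super().__init__()
        self.scale = nn.Parameter(torch.ones(16, 8))


print(count_module_params(MyLayer()))  # (128, True)
```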