Skip to content

Faster R-CNN (WIP) #21

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 18 commits into from
Closed

Faster R-CNN (WIP) #21

wants to merge 18 commits into from

Conversation

fmassa
Copy link
Member

@fmassa fmassa commented Dec 29, 2016

This WIP PR implements Faster R-CNN based on https://github.com/rbgirshick/py-faster-rcnn .
There are lots of copy-paste from Girshick's repo, and a few files were copied as is.
I decided to fuse a number of operations together, but I'm not sure anymore it's the best way to go.

A few comments:

  • There are two main classes: FasterRCNN and RPN. RPN implements the region proposal network, and faster r-cnn wraps almost all the detection logic (which might hurt flexibility at some point).
  • Both classes behave differently if a ground-truth is provided, avoiding to have to write different network definitions for train/test.
  • I tried to avoid a global cfg dict containing all the parameters, and instead pass the parameters as the constructors of the classes. But different train/test parameters still need to be defined in the model.
  • There is currently no handling on the image scaling (originally present in im_info[2]), necessary for properly pruning small boxes in the RPN.
  • I'll add optimized implementations for ROIPooling/nms/etc later on using cffi.

I'm opening this PR to get some feed-back on the general structure of the code.
Overall, I'm starting to think that it might be better to use the same structure as the one in py-faster-rcnn from Girshick repo.

Code cleanup and organization is required.

@colesbury
Copy link
Member

This looks good. I think the most important thing is to get training (and evaluation) working and then worry more about design and cleaning up the code.

Some thoughts:

  • I think copying from Ross's Faster-RCNN repo makes sense where appropriate. Keeping them in separate files, unchanged if possible, would be best
  • There's a trade-off between making everything part of the model (nn.Container) and keeping it as part of the training function. For example, FasterRCNN.forward is a bit complicated because it combines training and evaluation logic: it has two possible return types and two possible input types.
  • The forward() method in modules should take in Variables where appropriate; the wrapping of tensors in Variables should happen outside the module
  • Try to match PyTorch style (PEP 8) for stuff that's not from RBG's repo. (4 space indent, etc.)
  • Use Python argparse to configure options. I like how you're not passing a global cfg around. If you want to support a config file as well, I think you can do so with ConfigParser

Copy link
Contributor

@apaszke apaszke left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Don't we have to include the license file if we're including the original code?

boxes[:, 2] = width - oldx1 - 1
return boxes

class TransformVOCDetectionAnnotation(object):

This comment was marked as off-topic.

This comment was marked as off-topic.

if isinstance(x, np.ndarray):
return Variable(torch.from_numpy(x), requires_grad=False)
elif torch.is_tensor(x):
return Variable(x, requires_grad=True)

This comment was marked as off-topic.

cls_crit = nn.CrossEntropyLoss()
cls_loss = cls_crit(scores, labels)

reg_crit = nn.SmoothL1Loss()

This comment was marked as off-topic.

from torch.autograd.function import Function
from torch._thnn import type2backend

class AdaptiveMaxPool2d(Function):

This comment was marked as off-topic.

This comment was marked as off-topic.


# I need to know the original image size (or have the scaling factor)
def get_roi_boxes(self, anchors, rpn_map, rpn_bbox_deltas, im):
# TODO fix this!!!

This comment was marked as off-topic.

This comment was marked as off-topic.

class_to_ind = dict(zip(cls, range(len(cls))))


train = VOCDetection('/home/francisco/work/datasets/VOCdevkit/', 'train',

This comment was marked as off-topic.

@apaszke
Copy link
Contributor

apaszke commented Jan 6, 2017

Also, there are a lot of commented statements that should be removed before merging

@fmassa
Copy link
Member Author

fmassa commented Jan 6, 2017

@colesbury @apaszke thanks for your comments. I think the ConfigParser might be a nice way of addressing tons of arguments.

I initially wanted to keep the number of files small (as it is an example code), so I fused a number of things together, but that's probably a poor design choice.

I will validate that the basic code is working as expected by performing a training/evaluation and then I'll focus on getting a refactoring of this PR.

I'll get back to it on Monday, I've a trip to do in Rio tomorrow :)

@apaszke
Copy link
Contributor

apaszke commented Jan 6, 2017

Ugh nvm, for some reason I haven't noticed that it's WIP...

Have a nice trip! 😃

@bhack
Copy link

bhack commented Mar 23, 2017

Have you tested https://github.com/longcw/faster_rcnn_pytorch?
/cc @longcw

@fmassa
Copy link
Member Author

fmassa commented Mar 23, 2017

@bhack I haven't, and I didn't have the time to finish this properly.
Given that there are already a number of pytorch implementations of object detection algorithms in pytorch, I'll close this one for the time being.
If I find some time to finish this up with a simple interface, I'll send a new PR.

@fmassa fmassa closed this Mar 23, 2017
@bhack
Copy link

bhack commented Mar 25, 2017

@KaimingHe In the plan of releasing Mask r-cnn there will be also a faster-rcnn pytorch implementation merged in this repository?

@bhack
Copy link

bhack commented Apr 1, 2017

A TF WIP Mask R-CNN effort was started at https://github.com/CharlesShang/FastMaskRCNN. Actually there is still no public reference implementation of the paper so we will see what kind of accurancy can be reproduced.

@bhack
Copy link

bhack commented Apr 1, 2017

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants