Top Species Filtering #92

Open · wants to merge 19 commits into master
1 change: 1 addition & 0 deletions .gitignore
@@ -0,0 +1 @@
*.pyc
Binary file added __pycache__/credentials.cpython-39.pyc
Binary file added __pycache__/datasets.cpython-39.pyc
Binary file added __pycache__/model.cpython-39.pyc
Binary file added __pycache__/utils.cpython-39.pyc
6 changes: 0 additions & 6 deletions create_data_lists.py

This file was deleted.

162 changes: 160 additions & 2 deletions datasets.py
@@ -1,11 +1,168 @@
import torch
from torch.utils.data import Dataset
import json
from torchvision.transforms import ToPILImage
from transformations import BBoxToBoundary
import os
import pandas as pd
from PIL import Image
from utils import transform
from ast import literal_eval
from collections import Counter
import ast
import matplotlib.pyplot as plt
import matplotlib.patches as patches
import numpy as np
import random

class SerengetiDataset(Dataset):
"""
A PyTorch Dataset class to be used in a PyTorch DataLoader to create batches.
"""
def __init__(self, image_folder, images_df, annotations_df, classes_df, night_images, split=None, transform=None):
self.image_folder = image_folder
self.images_df = images_df
self.annotations_df = annotations_df
self.classes_df = classes_df
self.transform = transform
self.night_images = night_images
self.split = split
if self.split:
self.split = self.split.upper()
assert self.split in {'DAY', 'NIGHT'}
if self.split == 'NIGHT':
self.images_df = self.images_df[self.images_df['image_path_rel'].isin(self.night_images)]
elif self.split == 'DAY':
self.images_df = self.images_df[~self.images_df['image_path_rel'].isin(self.night_images)]

self.bboxes = {row['id']: [] for _, row in self.images_df.iterrows()}
for i, row in self.annotations_df.iterrows():
if row['image_id'] in self.bboxes:
self.bboxes[row['image_id']].append(i)

self.annotations_df['bbox'] = self.annotations_df['bbox'].apply(literal_eval)

print(f'Initialized dataset [{self.split} split].')

def __getitem__(self, i):
image_info = self.images_df.iloc[i]

path = os.path.join(self.image_folder, image_info['image_path_rel'])
        image = Image.open(path).convert('RGB')  # some night captures are grayscale; force 3 channels so batches stack

box_idxs = self.bboxes[image_info['id']]
        boxes = torch.FloatTensor([self.annotations_df.loc[idx, 'bbox'] for idx in box_idxs])  # .loc: box_idxs holds index labels from iterrows

species = image_info['question__species'].lower()
        label = self.classes_df.loc[self.classes_df['name'] == species, 'id'].iloc[0]
        labels = torch.LongTensor([label] * len(boxes))  # integer class ids, one per box

if self.transform:
image, boxes, labels = self.transform(image, boxes, labels)

return image, boxes, labels

def __len__(self):
return len(self.images_df)

def collate_fn(self, batch):
images = list()
boxes = list()
labels = list()

for b in batch:
images.append(b[0])
boxes.append(b[1])
labels.append(b[2])

images = torch.stack(images, dim=0)

        return images, boxes, labels  # images: (N, 3, H, W) tensor; boxes and labels: two lists of N tensors each

def get_classes(self):
classes_present = set()
for i, row in self.images_df.iterrows():
species = row['question__species'].lower()
classes_present.add(species)

filtered_classes = self.classes_df[self.classes_df['name'].isin(classes_present)]
        return filtered_classes

def get_class_frequencies(self):
class_frequencies = {row['name']: 0 for i, row in self.get_classes().iterrows()}
for i, row in self.images_df.iterrows():
species = row['question__species'].lower()
box_idxs = self.bboxes[row['id']]
class_frequencies[species] += len(box_idxs)

return class_frequencies

def show_sample(sample, fractional=False):
if fractional:
sample = BBoxToBoundary()(sample)
image, bboxes, labels = sample

fig, ax = plt.subplots(1)
ax.imshow(image)
plt.title(labels)
    for bbox in bboxes:
        # bbox format: [x, y, width, height]
        x, y, w, h = bbox
        rect = patches.Rectangle((x, y), w, h, linewidth=3, edgecolor='black', facecolor='none')
        ax.add_patch(rect)
plt.show()


def get_dataset_params(use_tmp=False, top_species=False):
    '''
    Utility function holding the parameters used to initialize a dataset.
    :param use_tmp: set to True if the image data is stored in the GPU node's /tmp store
    :param top_species: set to True to use only images of the five most common day and night species
    '''
    if use_tmp:
        image_folder = os.path.expanduser('~/../../../tmp/snapshot-serengeti/')  # expanduser: open()/PIL do not resolve '~'
    else:
        image_folder = os.path.expanduser('~/scratch/snapshot-serengeti/')

if top_species:
images_df = pd.read_csv('./snapshot-serengeti/bbox_images_top_species.csv')
else:
images_df = pd.read_csv('./snapshot-serengeti/bbox_images_non_empty_downloaded.csv')

annotations_df = pd.read_csv('./snapshot-serengeti/bbox_annotations_downloaded.csv')
classes_df = pd.read_csv('./snapshot-serengeti/classes.csv')
with open('./snapshot-serengeti/grayscale_images.txt', 'r') as f:
night_images = set(ast.literal_eval(f.read()))

return image_folder, images_df, annotations_df, classes_df, night_images


def main():
# dataset = SerengetiDataset(*get_dataset_params())
day_dataset = SerengetiDataset(*get_dataset_params(), split='DAY')
night_dataset = SerengetiDataset(*get_dataset_params(), split='NIGHT')
day_freqs = day_dataset.get_class_frequencies()
night_freqs = night_dataset.get_class_frequencies()
    total_freqs = {k: (v, night_freqs[k]) for k, v in day_freqs.items() if k in night_freqs}
    viable_freqs = {k: v for k, v in total_freqs.items() if min(v) > 500}  # species with >500 boxes in both splits
    print(total_freqs)
    print(viable_freqs)



if __name__ == '__main__':
main()







# Reference Dataset
'''
class PascalVOCDataset(Dataset):
"""
A PyTorch Dataset class to be used in a PyTorch DataLoader to create batches.
@@ -83,3 +240,4 @@ def collate_fn(self, batch):
images = torch.stack(images, dim=0)

return images, boxes, labels, difficulties # tensor (N, 3, 300, 300), 3 lists of N tensors each
'''
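
For reviewers who want to smoke-test the new dataset class, a minimal driver sketch follows. It is not part of the PR: it assumes the snapshot-serengeti CSVs read by get_dataset_params are in place and that utils.transform accepts and returns the (image, boxes, labels) triple used in __getitem__; the batch size is illustrative.

# Minimal smoke-test sketch (not part of the PR); assumes utils.transform
# matches the (image, boxes, labels) call made in __getitem__.
from torch.utils.data import DataLoader
from datasets import SerengetiDataset, get_dataset_params
from utils import transform

day_dataset = SerengetiDataset(*get_dataset_params(top_species=True),
                               split='DAY', transform=transform)
loader = DataLoader(day_dataset, batch_size=8, shuffle=True,
                    collate_fn=day_dataset.collate_fn)

images, boxes, labels = next(iter(loader))  # images: (N, 3, H, W); boxes, labels: lists of N tensors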
51 changes: 51 additions & 0 deletions download_dataset.py
@@ -0,0 +1,51 @@
from google.cloud import storage
import pandas as pd
import os

from credentials import get_creds

def download_images(images):
    # Download via GCS blob storage
creds = get_creds()
print(creds)
gcs_client = storage.Client(credentials=creds)
bucket_name = 'public-datasets-lila'
bucket = gcs_client.bucket(bucket_name)

folder_name = 'SSDataset'
if not os.path.exists(folder_name):
os.mkdir(folder_name)

downloaded_files = set()
downloaded_files_path = os.path.join(folder_name, 'downloaded_files.txt')
if os.path.exists(downloaded_files_path):
with open(downloaded_files_path, 'r') as f:
downloaded_files = set(f.read().splitlines())

# Download each file that hasn't already been downloaded
    for i, path in images['image_path_rel'].items():  # i is the DataFrame index, needed for the capture_id lookup below
if path in downloaded_files:
continue

# Download the file from GCS
blob_name = f'snapshotserengeti-unzipped/{path}'
blob = bucket.blob(blob_name)
file_path = os.path.join(folder_name, images.loc[i, 'capture_id'] + '.jpg')
blob.download_to_filename(file_path)

# Record that the file has been downloaded
downloaded_files.add(path)
with open(downloaded_files_path, 'a') as f:
f.write(f'{path}\n')
print(f'Downloaded file {path}')


def main():
print()
print('Starting download')
images = pd.read_csv('./SSDataset/images.csv')
download_images(images)

if __name__ == '__main__':
main()
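
To sanity-check an interrupted run, the resume log can be compared against the manifest. A hypothetical snippet, reusing the same file locations as the script above:

# Hypothetical resume check (not part of the PR); uses the same
# manifest and resume-log paths as download_dataset.py.
import pandas as pd

images = pd.read_csv('./SSDataset/images.csv')
with open('./SSDataset/downloaded_files.txt') as f:
    done = set(f.read().splitlines())
remaining = images[~images['image_path_rel'].isin(done)]
print(f'{len(remaining)} of {len(images)} images still to download')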
25 changes: 25 additions & 0 deletions gcreds/credentials.py
@@ -0,0 +1,25 @@
import os
from google.oauth2.credentials import Credentials
from google.auth.transport.requests import Request  # used by creds.refresh() below
from google_auth_oauthlib.flow import InstalledAppFlow

def get_creds():
creds = None
token_path = 'token.json'
creds_path = 'creds.json'
SCOPES = ['https://www.googleapis.com/auth/devstorage.read_only']

if os.path.exists(token_path):
creds = Credentials.from_authorized_user_file(token_path, SCOPES)

if not creds or not creds.valid:
if creds and creds.expired and creds.refresh_token:
creds.refresh(Request())
else:
flow = InstalledAppFlow.from_client_secrets_file(
creds_path, SCOPES)
creds = flow.run_local_server(port=8091)

with open(token_path, 'w') as token:
token.write(creds.to_json())

return creds
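
For context, a minimal sketch of how this helper is consumed, mirroring download_dataset.py and assuming the module is importable as credentials:

# Consumption sketch (mirrors download_dataset.py).
from google.cloud import storage
from credentials import get_creds

creds = get_creds()  # runs the local OAuth flow once, then reuses token.json
client = storage.Client(credentials=creds)
bucket = client.bucket('public-datasets-lila')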
1 change: 1 addition & 0 deletions gcreds/creds.json
@@ -0,0 +1 @@
{"installed":{"client_id":"667058587023-dm4qo00lk7k60mgnjo9g6tqedrtcu63b.apps.googleusercontent.com","project_id":"prjx-385715","auth_uri":"https://accounts.google.com/o/oauth2/auth","token_uri":"https://oauth2.googleapis.com/token","auth_provider_x509_cert_url":"https://www.googleapis.com/oauth2/v1/certs","client_secret":"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","redirect_uris":["http://localhost"]}}