Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
123 changes: 61 additions & 62 deletions detectree2/models/evaluation.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,27 +191,24 @@ def tree_height(self):
# Regular functions now
def get_tile_width(file):
"""Split up the file name to get width and buffer then adding to get overall width."""
filename = file.replace(".geojson", "")
filename_split = filename.split("_")

tile_width = (2 * int(filename_split[-2]) + int(filename_split[-3]))

file = Path(file)
filename_split = file.stem.split("_")
tile_width = 2 * int(filename_split[-2]) + int(filename_split[-3])
return tile_width


def get_epsg(file):
"""Splitting up the file name to get EPSG"""
filename = file.replace(".geojson", "")
filename_split = filename.split("_")

file = Path(file)
filename_split = file.stem.split("_")
epsg = filename_split[-1]
return epsg


def get_tile_origin(file):
"""Splitting up the file name to get tile origin"""
filename = file.replace(".geojson", "")
filename_split = filename.split("_")
file = Path(file)
filename_split = file.stem.split("_")

buffer = int(filename_split[-2])
# center = int(filename_split[-3])
Expand Down Expand Up @@ -343,7 +340,7 @@ def initialise_feats(

def initialise_feats2(
directory,
file,
filename,
lidar_img,
area_threshold,
conf_threshold,
Expand All @@ -353,21 +350,21 @@ def initialise_feats2(
epsg
):
"""Creates a list of all the features as objects of the class."""
with open(directory + "/" + file) as feat_file:
feat_path = Path(directory) / filename
with feat_path.open() as feat_file:
feat_json = json.load(feat_file)
feats = feat_json["features"]

feats = feat_json["features"]
all_feats = []
count = 0

for feat in feats:
feat_obj = GeoFeature(file, directory, count, feat, lidar_img, epsg)
feat_obj = GeoFeature(filename, directory, count, feat, lidar_img, epsg)

if feat_threshold_tests2(feat_obj, conf_threshold, area_threshold,
border_filter, tile_width, tile_origin):
all_feats.append(feat_obj)
count += 1
else:
continue

return all_feats

Expand Down Expand Up @@ -634,10 +631,10 @@ def site_f1_score2(
area of the corresponding polygons.

Args:
tile_directory: path to the folder containing all of the tiles
test_directory: path to the folder containing just the test files
pred_directory: path to the folder containing the predictions and the reprojections
lidar_img: path to the lidar image of an entire region
tile_directory (str | Path): path to the folder containing all of the tiles
test_directory (str | Path): path to the folder containing just the test files
pred_directory (str | Path): path to the folder containing the predictions and the reprojections
lidar_img (str | Path): path to the lidar image of an entire region
IoU_threshold: minimum value of IoU such that the intersection can be considered a true positive
min_height: minimum height of the features to be considered
max_height: minimum height of the features to be considered
Expand All @@ -649,69 +646,71 @@ def site_f1_score2(
save: bool to tell program whether the filtered crowns should be saved
"""

test_entries = os.listdir(test_directory)
total_tps = 0
total_fps = 0
total_fns = 0
tile_directory = Path(tile_directory)
test_directory = Path(test_directory)
pred_directory = Path(pred_directory)

total_tps = total_fps = total_fns = 0
heights = []
# total_tests = 0

for file in test_entries:
if ".geojson" in file:
print(file)
for file in test_directory.iterdir():
if file.suffix != ".geojson":
continue

# work out the area threshold to ignore these crowns in the tiles
# tile_width = get_tile_width(file) * scaling[0]
# area_threshold = ((tile_width)**2) * area_fraction_limit
print(file.name)

tile_width = get_tile_width(file)
tile_origin = get_tile_origin(file)
epsg = get_epsg(file)
# work out the area threshold to ignore these crowns in the tiles
# tile_width = get_tile_width(file) * scaling[0]
# area_threshold = ((tile_width)**2) * area_fraction_limit

test_file = file #.replace(".geojson", "_geo.geojson")
all_test_feats = initialise_feats2(tile_directory, test_file,
lidar_img, area_threshold,
conf_threshold, border_filter,
tile_width, tile_origin, epsg)
tile_width = get_tile_width(file)
tile_origin = get_tile_origin(file)
epsg = get_epsg(file)

new_heights = get_heights(all_test_feats, min_height, max_height)
heights.extend(new_heights)
all_test_feats = initialise_feats2(tile_directory, file.name,
lidar_img, area_threshold,
conf_threshold, border_filter,
tile_width, tile_origin, epsg)

pred_file = "Prediction_" + file.replace(".geojson", "_eval.geojson")
all_pred_feats = initialise_feats2(pred_directory, pred_file,
lidar_img, area_threshold,
conf_threshold, border_filter,
tile_width, tile_origin, epsg)
new_heights = get_heights(all_test_feats, min_height, max_height)
heights.extend(new_heights)

if save:
save_feats(tile_directory, all_test_feats)
save_feats(tile_directory, all_pred_feats)
pred_file = "Prediction_" + file.stem + "_eval.geojson"
all_pred_feats = initialise_feats2(pred_directory, pred_file,
lidar_img, area_threshold,
conf_threshold, border_filter,
tile_width, tile_origin, epsg)

find_intersections(all_test_feats, all_pred_feats)
tps, fps, fns = positives_test(all_test_feats, all_pred_feats,
IoU_threshold, min_height, max_height)
if save:
save_feats(tile_directory, all_test_feats)
save_feats(tile_directory, all_pred_feats)

print("tps:", tps)
print("fps:", fps)
print("fns:", fns)
print("")
find_intersections(all_test_feats, all_pred_feats)
tps, fps, fns = positives_test(all_test_feats, all_pred_feats,
IoU_threshold, min_height, max_height)

total_tps += tps
total_fps += fps
total_fns += fns
print(f"tps: {tps}")
print(f"fps: {fps}")
print(f"fns: {fns}\n")

total_tps += tps
total_fps += fps
total_fns += fns

try:
prec, rec = prec_recall(total_tps, total_fps, total_fns)
# not used!
f1_score = f1_cal(prec, rec) # noqa: F841
med_height = median(heights)
print("Precision ", "Recall ", "F1")
print(prec, rec, f1_score)
print(" ")
print("Total_trees=", len(heights))
print("med_height=", med_height)
print(prec, rec, f1_score, "\n")
print(f"Total_trees = {len(heights)}")
print(f"med_height = {med_height}")

except ZeroDivisionError:
print("ZeroDivisionError: Height threshold is too large.")

return prec, rec, f1_score


Expand Down
Loading
Loading