-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy patheverglades_species.py
200 lines (164 loc) · 8.42 KB
/
everglades_species.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
#DeepForest bird detection from extracted Zooniverse predictions
import comet_ml
from pytorch_lightning.loggers import CometLogger
from deepforest import main
from deepforest import dataset
from deepforest import utilities
import create_species_model
from empty_frames_utilities import *
from evaluate import evaluate_model
import pandas as pd
import os
import numpy as np
import traceback
import torch
import tempfile
from matplotlib import pyplot as plt
from PIL import Image, ImageDraw, ImageFont
from pathlib import Path, PurePath
from pytorch_lightning import Trainer
from datetime import datetime
def get_species_abbrev_lookup(species_lookup):
    """Build a lookup from label number to species initials.

    Each species name is split on whitespace and the first character of every
    word is concatenated, e.g. ``"Great Egret"`` -> ``"GE"``.

    Args:
        species_lookup: Mapping of label number -> species name (str).

    Returns:
        dict: Mapping of the same keys -> abbreviation string. A name with no
        words (empty or whitespace-only) maps to ``""``.
    """
    return {
        number: "".join(word[0] for word in species.split())
        for number, species in species_lookup.items()
    }
def index_to_example(index, results, test_path, comet_experiment):
    """Make an example image for one entry of the confusion matrix.

    Looks up row ``index`` of ``results``, opens the image it references from
    the directory that contains ``test_path``, draws the bounding box plus a
    caption, crops to a 200 px margin around the box, saves the crop into the
    system temp directory and logs it to ``comet_experiment``.

    Args:
        index: Positional row number into ``results`` (used with ``.iloc``).
        results: Table with ``xmin``, ``xmax``, ``ymin``, ``ymax`` and
            ``image_path`` columns.
        test_path: Path of the test csv; images are resolved relative to its
            parent directory.
        comet_experiment: Object exposing ``log_image(path, name=...)`` whose
            return value can be indexed with ``"imageId"``.

    Returns:
        dict: ``{"sample": <path of the saved crop>, "assetId": <image id
        returned by log_image>}``.
    """
    row = results.iloc[index]
    xmin, xmax = row['xmin'], row['xmax']
    ymin, ymax = row['ymin'], row['ymax']
    image_name = row['image_path']

    image_path = PurePath(Path(test_path).parent, Path(image_name))
    print(image_path)

    tmp_image_name = os.path.join(tempfile.gettempdir(), f"confusion-matrix-{index}.png")

    # The caption font is cosmetic: fall back to Pillow's built-in font rather
    # than aborting when the TTF file is not installed on this machine.
    try:
        font = ImageFont.truetype("Gidole-Regular.ttf", 20)
    except OSError:
        font = ImageFont.load_default()

    # Context manager guarantees the source image file handle is released even
    # if drawing or saving fails.
    with Image.open(str(image_path)) as image:
        draw = ImageDraw.Draw(image, "RGB")
        draw.rectangle((xmin, ymin, xmax, ymax), outline=(255, 255, 255), width=2)
        draw.text(
            (xmin - 150, ymin - 150),
            f"image={image_name}\nxmin={xmin}, xmax={xmax}, ymin={ymin}, ymax={ymax}",
            fill=(255, 255, 255),
            font=font,
        )
        cropped = image.crop((xmin - 200, ymin - 200, xmax + 200, ymax + 200))
        cropped.save(tmp_image_name)

    logged = comet_experiment.log_image(
        tmp_image_name, name=image_name,
    )
    # No figure is created in this function; the call is kept so any
    # matplotlib figures left open elsewhere are still released here.
    plt.close("all")

    # Return sample, assetId (index is added automatically)
    return {"sample": tmp_image_name, "assetId": logged["imageId"]}
def train_model(train_path, test_path, empty_images_path=None, save_dir=".",
                gbd_pretrain=True,
                experiment_name="ev-species",
                debug=False):
    """Train a DeepForest species model, checkpoint it and evaluate it.

    Steps, in order:
      1. Read the train/test csv files, append the weak "photoshop"
         annotations to train, and keep only the seven target species.
      2. Append up to 1000 randomly sampled inferred empty frames to train.
      3. Write timestamped copies of the train/test tables next to the input
         files and log them to Comet.
      4. Build the model (optionally seeded from the global bird detector),
         fit it, save ``species_model.pl`` under ``<save_dir>/<timestamp>``
         and call ``evaluate_model`` on the test table.

    Args:
        train_path: Path to the training annotations csv.
        test_path: Path to the test annotations csv.
        empty_images_path: Accepted for interface compatibility but not read;
            empty frames come from a fixed path inside this function.
            NOTE(review): confirm whether callers expect this to be honoured.
        save_dir: Directory under which the timestamped run folder is created.
        gbd_pretrain: If true, copy the backbone and regression-head weights
            from the DeepForest bird release into the new model.
        experiment_name: Name given to the Comet experiment.
        debug: If true, use one worker, batch size 1 and a fast dev run on CPU.

    Returns:
        The fitted ``main.deepforest`` model.

    Raises:
        IOError: If a sampled empty-frame image name also appears in the
            training annotations.
        OSError: If the run directory cannot be created.
    """
    comet_logger = CometLogger(project_name="everglades-species", workspace="weecology", experiment_name=experiment_name)

    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    model_savedir = os.path.join(save_dir, timestamp)
    # Fail now rather than after training if the checkpoint directory cannot
    # be created; an already-existing directory is fine.
    os.makedirs(model_savedir, exist_ok=True)

    comet_logger.experiment.log_parameter("timestamp", timestamp)
    comet_logger.experiment.add_tag("species")

    train = pd.read_csv(train_path)
    test = pd.read_csv(test_path)

    # Add weak annotations from photoshop to train
    weak_train = pd.read_csv("/blue/ewhite/everglades/photoshop_annotations/split_annotations.csv")
    train = pd.concat([train, weak_train])

    target_species = ['Great Egret', 'Roseate Spoonbill', 'White Ibis',
                      'Great Blue Heron', 'Wood Stork', 'Snowy Egret', 'Anhinga']
    train = train[train["label"].isin(target_species)]
    # Only evaluate on species that actually occur in the training data
    test = test[test["label"].isin(train["label"])]

    # Add in weak annotations for empty frames
    empty_frames = pd.read_csv("/blue/ewhite/everglades/photoshop_annotations/inferred_empty_annotations.csv")
    # sample() raises if asked for more rows than exist, so cap at the table size
    empty_frames = empty_frames.sample(n=min(1000, len(empty_frames)))
    empty_frames["image_path"] = empty_frames["image_path"].apply(os.path.basename)

    # Confirm no name overlaps
    overlapping_images = train[train["image_path"].isin(empty_frames["image_path"].unique())]
    if not overlapping_images.empty:
        raise IOError("Overlapping images: {}".format(overlapping_images))
    train = pd.concat([train, empty_frames])

    # Store test train split for run to allow multiple simultaneous run starts
    train_path = str(Path(train_path).parent / f'species_train_{timestamp}.csv')
    test_path = str(Path(test_path).parent / f'species_test_{timestamp}.csv')
    train.to_csv(train_path)
    test.to_csv(test_path)
    comet_logger.experiment.log_table("train.csv", train)
    comet_logger.experiment.log_table("test.csv", test)

    # One integer id per distinct training label, in order of first appearance
    labels = train["label"].unique()
    label_dict = {label: number for number, label in enumerate(labels)}
    model = main.deepforest(num_classes=len(labels), label_dict=label_dict)

    if gbd_pretrain:
        # Use the backbone and regression head from the global bird detector to transfer
        # learning about bird detection and bird related features
        global_bird_detector = main.deepforest()
        global_bird_detector.use_bird_release()
        model.model.backbone.load_state_dict(global_bird_detector.model.backbone.state_dict())
        model.model.head.regression_head.load_state_dict(global_bird_detector.model.head.regression_head.state_dict())

    # Images are expected to sit next to the csv files
    model.config["train"]["csv_file"] = train_path
    model.config["train"]["root_dir"] = os.path.dirname(train_path)
    model.config["validation"]["csv_file"] = test_path
    model.config["validation"]["root_dir"] = os.path.dirname(test_path)

    if debug:
        model.config["train"]["fast_dev_run"] = True
        model.config["gpus"] = None
        model.config["workers"] = 1
        model.config["batch_size"] = 1

    comet_logger.experiment.log_parameters(model.config)
    comet_logger.experiment.log_parameter("Training_Annotations", train.shape[0])
    comet_logger.experiment.log_parameter("Testing_Annotations", test.shape[0])
    comet_logger.experiment.log_parameter("model_savedir", model_savedir)

    trainer_kwargs = {
        "accelerator": "gpu",
        "strategy": "ddp",
        "devices": model.config["gpus"],
        "enable_checkpointing": False,
        "max_epochs": model.config["train"]["epochs"],
        "logger": comet_logger,
    }
    if debug:
        # The debug config requests no GPUs and a fast dev run; the Trainer is
        # built here rather than from the config, so apply both explicitly.
        del trainer_kwargs["strategy"]
        trainer_kwargs.update(accelerator="cpu", devices=1, fast_dev_run=True)
    trainer = Trainer(**trainer_kwargs)

    ds = dataset.TreeDataset(csv_file=model.config["train"]["csv_file"],
                             root_dir=model.config["train"]["root_dir"],
                             transforms=dataset.get_transform(augment=True),
                             label_dict=model.label_dict)

    dataloader = torch.utils.data.DataLoader(ds,
                                             batch_size=model.config["batch_size"],
                                             collate_fn=utilities.collate_fn,
                                             num_workers=model.config["workers"])

    trainer.fit(model, dataloader)

    checkpoint_path = os.path.join(model_savedir, "species_model.pl")
    trainer.save_checkpoint(checkpoint_path)

    evaluate_model(test_path=test_path,
                   model_path=checkpoint_path,
                   save_dir=save_dir,
                   comet_logger=comet_logger,
                   timestamp=timestamp)

    return model
if __name__ == "__main__":
    # Script entry point: optionally rebuild the dataset, then train.
    # When regenerate is true, create_species_model.generate runs first with
    # max_empty_frames as its cap; training runs either way.
    regenerate = False
    max_empty_frames = 0

    parsed_images_dir = "/blue/ewhite/everglades/Zooniverse/parsed_images/"
    empty_frames_csv = "/blue/ewhite/everglades/Zooniverse/parsed_images/empty_frames.csv"

    if regenerate:
        print(f"[INFO] Regenerating dataset with up to {max_empty_frames} empty frames")
        generate_kwargs = {
            "shp_dir": parsed_images_dir,
            "empty_frames_path": empty_frames_csv,
            "save_dir": "/blue/ewhite/everglades/Zooniverse/predictions/",
            "max_empty_frames": max_empty_frames,
            "buffer": 25,
        }
        create_species_model.generate(**generate_kwargs)

    training_kwargs = {
        "train_path": "/blue/ewhite/everglades/Zooniverse/parsed_images/species_train_resized.csv",
        "test_path": "/blue/ewhite/everglades/Zooniverse/cleaned_test/test_resized.csv",
        "save_dir": "/blue/ewhite/everglades/Zooniverse/",
        "gbd_pretrain": True,
        "empty_images_path": empty_frames_csv,
        "experiment_name": "main",
    }
    train_model(**training_kwargs)