Skip to content

Commit

Permalink
Merge pull request #84 from allenai/favyen/fix-infra-unit-test
Browse files (browse the repository at this point in the history)
Fix marine infrastructure filter unit test
  • Loading branch information
favyen2 authored Dec 20, 2024
2 parents 6723324 + ab1b146 commit 5071092
Show file tree
Hide file tree
Showing 2 changed files with 54 additions and 29 deletions.
18 changes: 7 additions & 11 deletions rslp/utils/filter.py
Original file line number | Diff line number | Diff line change
@@ -1,9 +1,10 @@
"""Filters for vessel detection projects."""

import functools
import json

import numpy as np
import requests
from upath import UPath


class Filter:
Expand All @@ -30,22 +31,17 @@ def should_filter(self, lat: float, lon: float) -> bool:


@functools.cache
def get_infra_latlons(infra_url: str) -> tuple[np.ndarray, np.ndarray]:
def get_infra_latlons(infra_path: UPath) -> tuple[np.ndarray, np.ndarray]:
"""Fetch and cache the infrastructure latitudes and longitudes.
Args:
infra_url: URL to the marine infrastructure GeoJSON file.
infra_path: path to the marine infrastructure GeoJSON file.
Returns:
A tuple of arrays: (latitudes, longitudes).
"""
try:
# Read the geojson data from the URL.
response = requests.get(infra_url, timeout=10)
response.raise_for_status() # Raise an error for bad responses
geojson_data = response.json()
except requests.RequestException as e:
raise RuntimeError(f"Failed to fetch infrastructure data: {e}")
with infra_path.open("r") as f:
geojson_data = json.load(f)

lats = np.array(
[feature["geometry"]["coordinates"][1] for feature in geojson_data["features"]]
Expand All @@ -72,7 +68,7 @@ def __init__(
infra_distance_threshold: distance threshold for marine infrastructure.
"""
self.infra_url = infra_url
self.infra_latlons = get_infra_latlons(self.infra_url)
self.infra_latlons = get_infra_latlons(UPath(self.infra_url))
self.infra_distance_threshold = infra_distance_threshold

def _get_haversine_distances(
Expand Down
65 changes: 47 additions & 18 deletions tests/unit/test_filter.py
Original file line number | Diff line number | Diff line change
@@ -1,25 +1,54 @@
import json
import pathlib

import pytest

from rslp.utils.filter import NearInfraFilter

TEST_INFRA_LON = 1.234
TEST_INFRA_LAT = 5.678


def test_near_infra_filter() -> None:
# Test case 1: Detection is exactly on infrastructure.
# The coordinates are directly extracted from the geojson file.
infra_lat = 16.613
infra_lon = 103.381
class TestNearInfraFilter:
@pytest.fixture
def single_point_infra_filter(self, tmp_path: pathlib.Path) -> NearInfraFilter:
geojson_data = {
"type": "FeatureCollection",
"properties": {},
"features": [
{
"type": "Feature",
"properties": {},
"geometry": {
"type": "Point",
"coordinates": [TEST_INFRA_LON, TEST_INFRA_LAT],
},
}
],
}
fname = tmp_path / "data.geojson"
with fname.open("w") as f:
json.dump(geojson_data, f)

filter = NearInfraFilter()
return NearInfraFilter(infra_url=str(fname))

# Since this point is exactly an infrastructure point, the filter should discard it (return True)
assert filter.should_filter(
infra_lat, infra_lon
), "Detection should be filtered out as it is located on infrastructure."
def test_exactly_on_infra(self, single_point_infra_filter: NearInfraFilter) -> None:
# Test when detection is exactly on infrastructure.
# The coordinates are directly extracted from the geojson file.
# Since this point is exactly an infrastructure point, the filter should discard it (return True)
assert single_point_infra_filter.should_filter(
TEST_INFRA_LAT,
TEST_INFRA_LON,
), "Detection should be filtered out as it is located on infrastructure."

# Test case 2: Detection is close to infrastructure.
assert filter.should_filter(
infra_lat + 0.0001, infra_lon + 0.0001
), "Detection should be filtered out as it is too close to infrastructure."
def test_close_to_infra(self, single_point_infra_filter: NearInfraFilter) -> None:
# Test when detection is close to infrastructure.
assert single_point_infra_filter.should_filter(
TEST_INFRA_LAT + 0.0001, TEST_INFRA_LON + 0.0001
), "Detection should be filtered out as it is too close to infrastructure."

# Test case 3: Detection is far from infrastructure.
assert not filter.should_filter(
infra_lat + 0.5, infra_lon + 0.5
), "Detection should be kept as it is far from infrastructure."
def test_far_from_infra(self, single_point_infra_filter: NearInfraFilter) -> None:
# Test when detection is far from infrastructure.
assert not single_point_infra_filter.should_filter(
TEST_INFRA_LAT + 0.5, TEST_INFRA_LON + 0.5
), "Detection should be kept as it is far from infrastructure."

0 comments on commit 5071092

Please sign in to comment.