Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix marine infrastructure filter unit test #84

Merged
merged 3 commits into from
Dec 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 7 additions & 11 deletions rslp/utils/filter.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
"""Filters for vessel detection projects."""

import functools
import json

import numpy as np
import requests
from upath import UPath


class Filter:
Expand All @@ -30,22 +31,17 @@ def should_filter(self, lat: float, lon: float) -> bool:


@functools.cache
def get_infra_latlons(infra_url: str) -> tuple[np.ndarray, np.ndarray]:
def get_infra_latlons(infra_path: UPath) -> tuple[np.ndarray, np.ndarray]:
"""Fetch and cache the infrastructure latitudes and longitudes.
Args:
infra_url: URL to the marine infrastructure GeoJSON file.
infra_path: path to the marine infrastructure GeoJSON file.
Returns:
A tuple of arrays: (latitudes, longitudes).
"""
try:
# Read the geojson data from the URL.
response = requests.get(infra_url, timeout=10)
response.raise_for_status() # Raise an error for bad responses
geojson_data = response.json()
except requests.RequestException as e:
raise RuntimeError(f"Failed to fetch infrastructure data: {e}")
with infra_path.open("r") as f:
geojson_data = json.load(f)

lats = np.array(
[feature["geometry"]["coordinates"][1] for feature in geojson_data["features"]]
Expand All @@ -72,7 +68,7 @@ def __init__(
infra_distance_threshold: distance threshold for marine infrastructure.
"""
self.infra_url = infra_url
self.infra_latlons = get_infra_latlons(self.infra_url)
self.infra_latlons = get_infra_latlons(UPath(self.infra_url))
self.infra_distance_threshold = infra_distance_threshold

def _get_haversine_distances(
Expand Down
65 changes: 47 additions & 18 deletions tests/unit/test_filter.py
Original file line number Diff line number Diff line change
@@ -1,25 +1,54 @@
import json
import pathlib

import pytest

from rslp.utils.filter import NearInfraFilter

TEST_INFRA_LON = 1.234
favyen2 marked this conversation as resolved.
Show resolved Hide resolved
TEST_INFRA_LAT = 5.678


def test_near_infra_filter() -> None:
# Test case 1: Detection is exactly on infrastructure.
# The coordinates are directly extracted from the geojson file.
infra_lat = 16.613
infra_lon = 103.381
class TestNearInfraFilter:
@pytest.fixture
def single_point_infra_filter(self, tmp_path: pathlib.Path) -> NearInfraFilter:
geojson_data = {
"type": "FeatureCollection",
"properties": {},
"features": [
{
"type": "Feature",
"properties": {},
"geometry": {
"type": "Point",
"coordinates": [TEST_INFRA_LON, TEST_INFRA_LAT],
},
}
],
}
fname = tmp_path / "data.geojson"
with fname.open("w") as f:
json.dump(geojson_data, f)

filter = NearInfraFilter()
return NearInfraFilter(infra_url=str(fname))

# Since this point is exactly an infrastructure point, the filter should discard it (return True)
assert filter.should_filter(
infra_lat, infra_lon
), "Detection should be filtered out as it is located on infrastructure."
def test_exactly_on_infra(self, single_point_infra_filter: NearInfraFilter) -> None:
# Test when detection is exactly on infrastructure.
# The coordinates are directly extracted from the geojson file.
# Since this point is exactly an infrastructure point, the filter should discard it (return True)
assert single_point_infra_filter.should_filter(
TEST_INFRA_LAT,
TEST_INFRA_LON,
), "Detection should be filtered out as it is located on infrastructure."

# Test case 2: Detection is close to infrastructure.
assert filter.should_filter(
infra_lat + 0.0001, infra_lon + 0.0001
), "Detection should be filtered out as it is too close to infrastructure."
def test_close_to_infra(self, single_point_infra_filter: NearInfraFilter) -> None:
# Test when detection is close to infrastructure.
assert single_point_infra_filter.should_filter(
TEST_INFRA_LAT + 0.0001, TEST_INFRA_LON + 0.0001
), "Detection should be filtered out as it is too close to infrastructure."

# Test case 3: Detection is far from infrastructure.
assert not filter.should_filter(
infra_lat + 0.5, infra_lon + 0.5
), "Detection should be kept as it is far from infrastructure."
def test_far_from_infra(self, single_point_infra_filter: NearInfraFilter) -> None:
# Test when detection is far from infrastructure.
assert not single_point_infra_filter.should_filter(
TEST_INFRA_LAT + 0.5, TEST_INFRA_LON + 0.5
), "Detection should be kept as it is far from infrastructure."
Loading