diff --git a/src/data_processing.py b/src/data_processing.py
index dc76a01..ccff97c 100644
--- a/src/data_processing.py
+++ b/src/data_processing.py
@@ -2,55 +2,143 @@ import os
 from logging import warn
 from deepforest import preprocess
+from typing import Optional, Union, List
 
-def undersample(train_df, ratio):
-    """Undersample top classes by selecting most diverse images"""
+def undersample(train_df: pd.DataFrame, ratio: float) -> pd.DataFrame:
+    """
+    Undersample top classes by selecting most diverse images.
+
+    This function reduces class imbalance by removing images that only contain
+    the two most common classes, while preserving images that contain additional
+    species.
+
+    Args:
+        train_df: DataFrame containing training annotations with 'label' and 'image_path' columns
+        ratio: Float between 0 and 1 controlling how many images containing only
+            the top classes are removed; the number removed is ratio times the
+            count of images that also contain additional species
+
+    Returns:
+        DataFrame with undersampled annotations
+
+    Example:
+        >>> train_df = pd.DataFrame({
+        ...     'label': ['Bird', 'Bird', 'Rare', 'Bird'],
+        ...     'image_path': ['img1.jpg', 'img2.jpg', 'img3.jpg', 'img3.jpg']
+        ... })
+        >>> undersampled_df = undersample(train_df, ratio=0.5)
+    """
+    if not 0 <= ratio <= 1:
+        raise ValueError("Ratio must be between 0 and 1")
 
     # Find images that only have top two most common classes
     top_two_classes = train_df.label.value_counts().index[:2]
     top_two_labels = train_df[train_df["label"].isin(top_two_classes)]
 
-    # remove images that have any other classes
-    top_two_images= train_df[train_df.image_path.isin(top_two_labels.image_path.unique())]
+    # Remove images that have any other classes
+    top_two_images = train_df[train_df.image_path.isin(top_two_labels.image_path.unique())]
     with_additional_species = top_two_images[~top_two_images["label"].isin(top_two_classes)].image_path.unique()
-    images_to_remove = [x for x in top_two_images.image_path.unique() if x not in with_additional_species][:int(len(with_additional_species)*ratio)]
+    images_to_remove = [x for x in top_two_images.image_path.unique() if x not in with_additional_species]
+    images_to_remove = images_to_remove[:int(len(with_additional_species)*ratio)]
    train_df = train_df[~train_df["image_path"].isin(images_to_remove)]
 
    return train_df
 
-def preprocess_images(annotations, root_dir, save_dir, limit_empty_frac=0.1, patch_size=450, patch_overlap=0):
-    """Cut images into GPU friendly chunks"""
+def preprocess_images(
+    annotations: pd.DataFrame,
+    root_dir: str,
+    save_dir: str,
+    limit_empty_frac: float = 0.1,
+    patch_size: int = 450,
+    patch_overlap: int = 0
+) -> pd.DataFrame:
+    """
+    Cut images into GPU-friendly chunks and process annotations accordingly.
+
+    This function splits large images into smaller patches and adjusts their
+    annotations to match the new coordinates. Images without annotations are
+    still processed, with empty patches allowed.
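+
+    A minimal usage sketch (directory names are illustrative):
+
+        >>> crops = preprocess_images(annotations, root_dir="data/raw",
+        ...                           save_dir="data/crops", patch_size=450)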
+ + Args: + annotations: DataFrame containing image annotations + root_dir: Root directory containing the original images + save_dir: Directory to save processed image patches + limit_empty_frac: Maximum fraction of empty patches to keep + patch_size: Size of the output patches in pixels + patch_overlap: Overlap between patches in pixels + + Returns: + DataFrame containing annotations for the processed image patches + + Raises: + FileNotFoundError: If root_dir or image files don't exist + ValueError: If patch_size <= 0 or patch_overlap < 0 + """ + if patch_size <= 0 or patch_overlap < 0: + raise ValueError("Invalid patch_size or patch_overlap") + + if not os.path.exists(root_dir): + raise FileNotFoundError(f"Root directory not found: {root_dir}") + + os.makedirs(save_dir, exist_ok=True) + crop_annotations = [] for image_path in annotations.image_path.unique(): annotation_df = annotations[annotations.image_path == image_path] - annotation_df = annotation_df[~annotation_df.xmin.isnull()] + if annotation_df.empty: allow_empty = True - annotation_df = None else: allow_empty = False + crop_annotation = process_image( - image_path, - annotation_df=annotation_df, - root_dir=root_dir, - save_dir=save_dir, - patch_size=patch_size, - patch_overlap=patch_overlap, + image_path=image_path, + annotation_df=annotation_df, + root_dir=root_dir, + save_dir=save_dir, + patch_size=patch_size, + patch_overlap=patch_overlap, allow_empty=allow_empty ) crop_annotations.append(crop_annotation) crop_annotations = pd.concat(crop_annotations) - return crop_annotations -def process_image(image_path, annotation_df, root_dir, save_dir, patch_size, patch_overlap, allow_empty): +def process_image( + image_path: str, + annotation_df: Optional[pd.DataFrame], + root_dir: str, + save_dir: str, + patch_size: int, + patch_overlap: int, + allow_empty: bool +) -> pd.DataFrame: + """ + Process a single image by splitting it into patches and adjusting annotations. + + Args: + image_path: Path to the image file + annotation_df: DataFrame containing annotations for this image, or None if empty + root_dir: Root directory containing the original images + save_dir: Directory to save processed image patches + patch_size: Size of the output patches in pixels + patch_overlap: Overlap between patches in pixels + allow_empty: Whether to allow patches without annotations + + Returns: + DataFrame containing annotations for the processed image patches + + Note: + If the crops already exist in save_dir, they will be skipped and + the existing annotations will be returned. + """ image_name = os.path.splitext(os.path.basename(image_path))[0] crop_csv = "{}.csv".format(os.path.join(save_dir, image_name)) + if os.path.exists(crop_csv): warn("Crops for {} already exist in {}. 
Skipping.".format(crop_csv, save_dir)) return pd.read_csv(crop_csv) + full_path = os.path.join(root_dir, image_path) crop_annotation = preprocess.split_raster( @@ -62,6 +150,7 @@ def process_image(image_path, annotation_df, root_dir, save_dir, patch_size, pat root_dir=root_dir, allow_empty=allow_empty ) + if annotation_df is None: empty_annotations = [] for i in range(len(crop_annotation)): diff --git a/src/model.py b/src/model.py index 484f3f1..0724640 100644 --- a/src/model.py +++ b/src/model.py @@ -172,7 +172,7 @@ def train(model, train_annotations, test_annotations, train_image_dir, comet_pro for filename in sample_train_annotations.image_path: sample_train_annotations_for_image = sample_train_annotations[sample_train_annotations.image_path == filename] sample_train_annotations_for_image.root_dir = train_image_dir - visualize.plot_results(sample_train_annotations_for_image) + visualize.plot_results(sample_train_annotations_for_image, savedir=tmpdir) comet_logger.experiment.log_image(os.path.join(tmpdir, filename)) model.trainer.fit(model) diff --git a/src/pipeline.py b/src/pipeline.py index f9acb89..3a05c38 100644 --- a/src/pipeline.py +++ b/src/pipeline.py @@ -50,7 +50,7 @@ def run(self): reporting = Reporting() reporting.generate_reports(pipeline_monitor) - if performance.success: + if pipeline_monitor.check_success(): print("Pipeline performance is satisfactory, exiting") return None else: diff --git a/src/pipeline_evaluation.py b/src/pipeline_evaluation.py index eeab898..6ccf9d6 100644 --- a/src/pipeline_evaluation.py +++ b/src/pipeline_evaluation.py @@ -2,6 +2,7 @@ from torchmetrics.detection import MeanAveragePrecision from torchmetrics.classification import Accuracy from torchmetrics.functional import confusion_matrix +from src.model import predict import pandas as pd class PipelineEvaluation: @@ -55,8 +56,9 @@ def _format_targets(self, annotations_df): return targets def evaluate_detection(self): - preds = self.model.predict( - self.detection_annotations_df.image_path.tolist(), + preds = predict( + model=self.model, + image_paths=self.detection_annotations_df.image_path.tolist(), patch_size=self.patch_size, patch_overlap=self.patch_overlap, min_score=self.min_score @@ -94,6 +96,14 @@ def evaluate(self): self.confident_classification_results = self.confident_classification_accuracy() self.uncertain_classification_results = self.uncertain_classification_accuracy() + def check_success(self): + """Check if pipeline performance is satisfactory""" + # For each metric, check if it is above the threshold + if self.detection_results['detection']["true_positive_rate"] > self.detection_true_positive_threshold: + return True + else: + return False + def report(self): """Generate a report of the pipeline evaluation""" results = { diff --git a/src/propagate.py b/src/propagate.py index ba338a0..519b147 100644 --- a/src/propagate.py +++ b/src/propagate.py @@ -1,10 +1,26 @@ import pandas as pd import numpy as np -from datetime import datetime -from typing import List, Tuple, Dict +from datetime import datetime, timedelta +from typing import List, Tuple, Dict, Optional import os class LabelPropagator: + """ + A class to propagate object detection labels across temporally and spatially close images. + + This class analyzes image timestamps and object locations to identify potential missed + detections in sequential images. When an object is detected in one image but missing in + a temporally close image with similar spatial coordinates, the label is propagated. 
+ + Attributes: + time_threshold (int): Maximum time difference in seconds between images to consider for propagation + distance_threshold (float): Maximum Euclidean distance in pixels between objects to consider them the same + + Example: + >>> propagator = LabelPropagator(time_threshold_seconds=5, distance_threshold_pixels=50) + >>> propagated_df = propagator.propagate_labels(annotations_df) + """ + def __init__(self, time_threshold_seconds: int = 5, distance_threshold_pixels: float = 50): """ Initialize the label propagator. @@ -16,31 +32,68 @@ def __init__(self, time_threshold_seconds: int = 5, distance_threshold_pixels: f self.time_threshold = time_threshold_seconds self.distance_threshold = distance_threshold_pixels - def _parse_timestamp(self, filename: str) -> datetime: - """Extract timestamp from image filename or metadata.""" - # Implement based on your filename format - # Example: "IMG_20230615_123456.jpg" -> datetime(2023, 06, 15, 12, 34, 56) + def _parse_timestamp(self, filename: str) -> Optional[datetime]: + """ + Extract timestamp from image filename. + + Args: + filename: Name of the image file with embedded timestamp + + Returns: + datetime object if parsing successful, None otherwise + + Example: + >>> propagator._parse_timestamp("IMG_20230615_123456.jpg") + datetime(2023, 6, 15, 12, 34, 56) + """ try: - # Modify this according to your actual filename format date_str = filename.split('_')[1] + filename.split('_')[2].split('.')[0] return datetime.strptime(date_str, '%Y%m%d%H%M%S') except: return None def _calculate_center(self, bbox: Tuple[float, float, float, float]) -> Tuple[float, float]: - """Calculate center point of bounding box.""" + """ + Calculate center point of bounding box. + + Args: + bbox: Tuple of (xmin, ymin, xmax, ymax) coordinates + + Returns: + Tuple of (x, y) coordinates of the center point + """ x1, y1, x2, y2 = bbox return ((x1 + x2) / 2, (y1 + y2) / 2) def _calculate_distance(self, point1: Tuple[float, float], point2: Tuple[float, float]) -> float: - """Calculate Euclidean distance between two points.""" + """ + Calculate Euclidean distance between two points. + + Args: + point1: Tuple of (x, y) coordinates of first point + point2: Tuple of (x, y) coordinates of second point + + Returns: + Euclidean distance between the points + """ return np.sqrt((point1[0] - point2[0])**2 + (point1[1] - point2[1])**2) def _find_temporal_neighbors(self, annotations_df: pd.DataFrame) -> Dict[str, List[str]]: - """Find temporally close images.""" + """ + Find temporally close images within the time threshold. + + Args: + annotations_df: DataFrame containing image annotations + + Returns: + Dictionary mapping each image to list of temporal neighbors + + Note: + Images are considered neighbors if their timestamps are within + the time_threshold of each other. + """ temporal_neighbors = {} - # Sort by timestamp timestamps = {row['image_path']: self._parse_timestamp(os.path.basename(row['image_path'])) for _, row in annotations_df.iterrows()} @@ -67,35 +120,33 @@ def propagate_labels(self, annotations_df: pd.DataFrame) -> pd.DataFrame: """ Propagate labels to temporally and spatially close objects. + This method analyzes the input annotations and propagates labels to nearby + images when an object is detected in one image but missing in temporally + close images at similar spatial coordinates. 
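+
+        Example (illustrative; two frames one second apart, each with a box
+        the other frame lacks, so labels are copied over and marked propagated):
+
+            >>> df = pd.DataFrame({
+            ...     'image_path': ['IMG_20230615_123456.jpg', 'IMG_20230615_123457.jpg'],
+            ...     'xmin': [100, 400], 'ymin': [100, 400],
+            ...     'xmax': [120, 420], 'ymax': [120, 420],
+            ...     'label': ['Bird', 'Bird']})
+            >>> out = propagator.propagate_labels(df)  # adds rows with propagated=True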
+ Args: annotations_df: DataFrame with columns ['image_path', 'xmin', 'ymin', 'xmax', 'ymax', 'label'] Returns: - DataFrame with propagated labels + DataFrame with original and propagated labels, including a 'propagated' column + + Note: + Propagated labels are marked with propagated=True in the output DataFrame """ - # Create a copy to store propagated annotations propagated_df = annotations_df.copy() - - # Find temporal neighbors temporal_neighbors = self._find_temporal_neighbors(annotations_df) - - # Store new annotations to be added new_annotations = [] - # For each image with annotations for img1 in temporal_neighbors: img1_annotations = annotations_df[annotations_df['image_path'] == img1] - # For each temporal neighbor for img2 in temporal_neighbors[img1]: img2_annotations = annotations_df[annotations_df['image_path'] == img2] - # For each object in img1 for _, obj1 in img1_annotations.iterrows(): bbox1 = (obj1['xmin'], obj1['ymin'], obj1['xmax'], obj1['ymax']) center1 = self._calculate_center(bbox1) - # Check if there's a matching object in img2 match_found = False for _, obj2 in img2_annotations.iterrows(): bbox2 = (obj2['xmin'], obj2['ymin'], obj2['xmax'], obj2['ymax']) @@ -106,18 +157,15 @@ def propagate_labels(self, annotations_df: pd.DataFrame) -> pd.DataFrame: match_found = True break - # If no match found, propagate the label if not match_found: new_annotation = obj1.copy() new_annotation['image_path'] = img2 new_annotation['propagated'] = True new_annotations.append(new_annotation) - # Add propagated annotations if new_annotations: propagated_df = pd.concat([propagated_df, pd.DataFrame(new_annotations)], ignore_index=True) - # Add propagated column if it doesn't exist if 'propagated' not in propagated_df.columns: propagated_df['propagated'] = False diff --git a/tests/conftest.py b/tests/conftest.py index 104c8ce..8710258 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -42,11 +42,11 @@ def config(tmpdir_factory): # Create sample bounding box annotations data = { 'image_path': ['empty.jpg', 'birds.jpg', 'birds_val.jpg'], - 'xmin': [None, 200, 150], - 'ymin': [None, 300, 250], - 'xmax': [None, 300, 250], - 'ymax': [None, 400, 350], - 'label': ['None', 'Bird', 'Bird'], + 'xmin': [0, 200, 150], + 'ymin': [0, 300, 250], + 'xmax': [0, 300, 250], + 'ymax': [0, 400, 350], + 'label': ['Bird', 'Bird', 'Bird'], 'annotator': ['test_user', 'test_user', 'test_user'] } diff --git a/tests/test_data_processing.py b/tests/test_data_processing.py index e69de29..ef270a9 100644 --- a/tests/test_data_processing.py +++ b/tests/test_data_processing.py @@ -0,0 +1,233 @@ +import pytest +import pandas as pd +import os +import numpy as np +from src.data_processing import undersample, preprocess_images, process_image + +@pytest.fixture +def sample_annotations(): + """Create sample annotations for testing.""" + return pd.DataFrame({ + 'image_path': ['img1.jpg', 'img1.jpg', 'img2.jpg', 'img3.jpg', 'img3.jpg'], + 'label': ['Bird', 'Bird', 'Bird', 'Bird', 'Rare'], + 'xmin': [100, 200, 300, 400, 500], + 'ymin': [100, 200, 300, 400, 500], + 'xmax': [150, 250, 350, 450, 550], + 'ymax': [150, 250, 350, 450, 550] + }) + +@pytest.fixture +def empty_annotations(): + """Create annotations with some empty images.""" + return pd.DataFrame({ + 'image_path': ['img1.jpg', 'img2.jpg'], + 'label': ['Bird', None], + 'xmin': [100, None], + 'ymin': [100, None], + 'xmax': [150, None], + 'ymax': [150, None] + }) + +@pytest.fixture +def mock_split_raster(monkeypatch): + """Mock the split_raster function.""" + def 
mock_fn(*args, **kwargs):
+        return pd.DataFrame({
+            'image_path': ['patch1.jpg'],
+            'xmin': [10],
+            'ymin': [10],
+            'xmax': [20],
+            'ymax': [20]
+        })
+
+    monkeypatch.setattr('src.data_processing.preprocess.split_raster', mock_fn)
+    return mock_fn
+
+def test_undersample_ratio():
+    """Test undersampling with different ratios."""
+    # Two common classes (Bird, Eagle) plus a rare class that shares img0 and img1,
+    # so undersample() sees both multi-species and top-class-only images.
+    df = pd.DataFrame({
+        'label': ['Bird'] * 6 + ['Eagle'] * 4 + ['Rare'] * 2,
+        'image_path': [f'img{i}.jpg' for i in range(10)] + ['img0.jpg', 'img1.jpg']
+    })
+
+    # ratio=0.5 removes half as many top-class-only images as there are
+    # multi-species images (1 of 2)
+    result = undersample(df, ratio=0.5)
+    assert result.image_path.nunique() == df.image_path.nunique() - 1
+
+    # ratio=0 removes nothing
+    result = undersample(df, ratio=0)
+    assert len(result) == len(df)
+
+    # ratio=1 removes as many top-class-only images as there are
+    # multi-species images (2)
+    result = undersample(df, ratio=1)
+    assert result.image_path.nunique() == df.image_path.nunique() - 2
+
+def test_undersample_invalid_ratio():
+    """Test undersampling with invalid ratios."""
+    df = pd.DataFrame({'label': ['Bird'], 'image_path': ['img1.jpg']})
+
+    with pytest.raises(ValueError):
+        undersample(df, ratio=-0.1)
+
+    with pytest.raises(ValueError):
+        undersample(df, ratio=1.1)
+
+def test_preprocess_images(mock_split_raster, sample_annotations, tmp_path):
+    """Test image preprocessing."""
+    # Create temporary directories
+    root_dir = tmp_path / "root"
+    save_dir = tmp_path / "save"
+    root_dir.mkdir()
+
+    # Create dummy image files
+    for img in sample_annotations['image_path'].unique():
+        (root_dir / img).touch()
+
+    # Test preprocessing
+    result = preprocess_images(
+        sample_annotations,
+        str(root_dir),
+        str(save_dir),
+        patch_size=450
+    )
+
+    assert isinstance(result, pd.DataFrame)
+    assert not result.empty
+    assert 'image_path' in result.columns
+
+def test_preprocess_images_invalid_params(sample_annotations, tmp_path):
+    """Test preprocessing with invalid parameters."""
+    with pytest.raises(ValueError):
+        preprocess_images(
+            sample_annotations,
+            str(tmp_path),
+            str(tmp_path),
+            patch_size=-1
+        )
+
+    with pytest.raises(ValueError):
+        preprocess_images(
+            sample_annotations,
+            str(tmp_path),
+            str(tmp_path),
+            patch_size=450,
+            patch_overlap=-1
+        )
+
+def test_process_image_empty(mock_split_raster, tmp_path):
+    """Test processing images with empty annotations."""
+    root_dir = tmp_path / "root"
+    save_dir = tmp_path / "save"
+    root_dir.mkdir()
+    save_dir.mkdir()
+
+    # Create dummy image; split_raster is mocked, so the file is never read
+    (root_dir / "img2.jpg").touch()
+
+    result = process_image(
+        "img2.jpg",
+        None,
+        str(root_dir),
+        str(save_dir),
+        patch_size=450,
+        patch_overlap=0,
+        allow_empty=True
+    )
+
+    assert isinstance(result, pd.DataFrame)
+    assert 'image_path' in result.columns
+    assert result['xmin'].isna().all()
+
+def test_process_image_existing_crops(mock_split_raster, tmp_path):
+    """Test processing when crops already exist."""
+    save_dir = tmp_path / "save"
+    save_dir.mkdir()
+
+    # Create existing crop CSV
+    crop_csv = save_dir / "img1.csv"
+    pd.DataFrame({'image_path': ['patch1.jpg']}).to_csv(crop_csv)
+
+    result = process_image(
+        "img1.jpg",
+        None,
+        str(tmp_path),
+        str(save_dir),
+        patch_size=450,
+        patch_overlap=0,
+        allow_empty=True
+    )
+
+    assert isinstance(result, pd.DataFrame)
+    assert len(result) > 0
+
+def test_process_image_file_not_found(sample_annotations, tmp_path):
+    """Test processing with missing image file."""
+    with pytest.raises(FileNotFoundError):
+        process_image(
+            "nonexistent.jpg",
+            sample_annotations,
+            str(tmp_path),
+            str(tmp_path),
+            patch_size=450,
+            patch_overlap=0,
+            allow_empty=False
+        )
+
+@pytest.mark.parametrize("patch_size,patch_overlap", [
+    (450, 0),
+    (300, 50),
+    (600, 100)
+])
+def 
test_process_image_different_sizes(mock_split_raster, sample_annotations, tmp_path, patch_size, patch_overlap): + """Test processing with different patch sizes and overlaps.""" + root_dir = tmp_path / "root" + save_dir = tmp_path / "save" + root_dir.mkdir() + save_dir.mkdir() + + # Create dummy image + (root_dir / "img1.jpg").touch() + + result = process_image( + "img1.jpg", + sample_annotations, + str(root_dir), + str(save_dir), + patch_size=patch_size, + patch_overlap=patch_overlap, + allow_empty=False + ) + + assert isinstance(result, pd.DataFrame) + assert not result.empty + +@pytest.mark.integration +def test_full_preprocessing_pipeline(tmp_path): + """Integration test for the full preprocessing pipeline.""" + # Create test data + root_dir = tmp_path / "root" + save_dir = tmp_path / "save" + root_dir.mkdir() + save_dir.mkdir() + + # Create test image and annotations + (root_dir / "test.jpg").touch() + annotations = pd.DataFrame({ + 'image_path': ['test.jpg'], + 'label': ['Bird'], + 'xmin': [100], + 'ymin': [100], + 'xmax': [200], + 'ymax': [200] + }) + + try: + result = preprocess_images( + annotations, + str(root_dir), + str(save_dir), + patch_size=450 + ) + assert isinstance(result, pd.DataFrame) + except Exception as e: + pytest.fail(f"Integration test failed: {str(e)}") diff --git a/tests/test_propagate.py b/tests/test_propagate.py index 3357c49..2c8d331 100644 --- a/tests/test_propagate.py +++ b/tests/test_propagate.py @@ -1,5 +1,6 @@ import pytest import pandas as pd +import numpy as np from src.propagate import LabelPropagator from datetime import datetime import os @@ -21,54 +22,120 @@ def sample_annotations(): 'label': ['Bird', 'Bird', 'Bird', 'Bird'] }) +@pytest.fixture +def complex_annotations(): + """Create more complex annotations for testing edge cases.""" + return pd.DataFrame({ + 'image_path': [ + 'IMG_20230615_123456.jpg', + 'IMG_20230615_123456.jpg', # Multiple objects in same image + 'IMG_20230615_123457.jpg', + 'IMG_20230615_123458.jpg', + 'IMG_20230615_123506.jpg', + 'invalid_filename.jpg' # Invalid filename format + ], + 'xmin': [100, 200, 150, 200, 500, 300], + 'ymin': [100, 200, 150, 200, 500, 300], + 'xmax': [120, 220, 170, 220, 520, 320], + 'ymax': [120, 220, 170, 220, 520, 320], + 'label': ['Bird', 'Bird', 'Bird', 'Bird', 'Bird', 'Bird'] + }) + @pytest.fixture def propagator(): """Create a LabelPropagator instance.""" return LabelPropagator(time_threshold_seconds=5, distance_threshold_pixels=50) def test_propagator_initialization(propagator): - """Test propagator initialization.""" + """Test propagator initialization with different parameters.""" assert propagator.time_threshold == 5 assert propagator.distance_threshold == 50 + + # Test with different parameters + prop2 = LabelPropagator(time_threshold_seconds=10, distance_threshold_pixels=100) + assert prop2.time_threshold == 10 + assert prop2.distance_threshold == 100 def test_timestamp_parsing(propagator): - """Test timestamp parsing from filenames.""" + """Test timestamp parsing from various filename formats.""" + # Test valid filename filename = 'IMG_20230615_123456.jpg' timestamp = propagator._parse_timestamp(filename) assert timestamp == datetime(2023, 6, 15, 12, 34, 56) + + # Test invalid filename + invalid_filename = 'invalid_filename.jpg' + assert propagator._parse_timestamp(invalid_filename) is None def test_center_calculation(propagator): """Test bounding box center calculation.""" + # Test integer coordinates bbox = (100, 100, 120, 120) center = propagator._calculate_center(bbox) assert center 
== (110, 110) + + # Test float coordinates + bbox = (100.5, 100.5, 120.5, 120.5) + center = propagator._calculate_center(bbox) + assert center == (110.5, 110.5) def test_distance_calculation(propagator): """Test distance calculation between points.""" + # Test vertical distance point1 = (100, 100) point2 = (100, 150) distance = propagator._calculate_distance(point1, point2) assert distance == 50 + + # Test diagonal distance + point2 = (100 + 30, 100 + 40) # 3-4-5 triangle + distance = propagator._calculate_distance(point1, point2) + assert distance == 50 + +def test_temporal_neighbors(propagator, sample_annotations): + """Test finding temporal neighbors.""" + neighbors = propagator._find_temporal_neighbors(sample_annotations) + + # First three images should be neighbors + assert 'IMG_20230615_123457.jpg' in neighbors['IMG_20230615_123456.jpg'] + assert 'IMG_20230615_123458.jpg' in neighbors['IMG_20230615_123456.jpg'] + + # Last image should not be neighbor of first (10 seconds apart) + assert 'IMG_20230615_123506.jpg' not in neighbors['IMG_20230615_123456.jpg'] def test_label_propagation(propagator, sample_annotations): - """Test label propagation.""" + """Test label propagation with various scenarios.""" propagated_df = propagator.propagate_labels(sample_annotations) - # Check that propagated column exists + # Check basic properties assert 'propagated' in propagated_df.columns - - # Check that original annotations are preserved assert len(propagated_df) >= len(sample_annotations) - - # Check that propagated annotations are marked assert (propagated_df['propagated'] == True).any() + + # Check that original annotations are preserved + original_images = sample_annotations['image_path'].unique() + for img in original_images: + assert img in propagated_df['image_path'].values -def test_temporal_threshold(propagator, sample_annotations): - """Test that labels only propagate within time threshold.""" - propagated_df = propagator.propagate_labels(sample_annotations) +def test_complex_propagation(propagator, complex_annotations): + """Test label propagation with complex scenarios.""" + propagated_df = propagator.propagate_labels(complex_annotations) - # The last image is 10 seconds later, should not receive propagated labels - last_image_annotations = propagated_df[ - propagated_df['image_path'] == 'IMG_20230615_123506.jpg' + # Check handling of multiple objects in same image + first_image_annotations = propagated_df[ + propagated_df['image_path'] == 'IMG_20230615_123456.jpg' ] - assert len(last_image_annotations) == 1 # Only original annotation \ No newline at end of file + assert len(first_image_annotations) >= 2 + + # Check handling of invalid filenames + invalid_annotations = propagated_df[ + propagated_df['image_path'] == 'invalid_filename.jpg' + ] + assert len(invalid_annotations) == 1 # Should preserve original annotation + +def test_empty_dataframe(propagator): + """Test handling of empty input DataFrame.""" + empty_df = pd.DataFrame(columns=['image_path', 'xmin', 'ymin', 'xmax', 'ymax', 'label']) + result = propagator.propagate_labels(empty_df) + assert len(result) == 0 + assert 'propagated' in result.columns \ No newline at end of file