Ignoring Temporal Overlap in Spatial Intersection for Multiple Datasets #2571

Open
tpet93 opened this issue Feb 10, 2025 · 5 comments
Labels: datasets (Geospatial or benchmark datasets), documentation (Improvements or additions to documentation)

Comments

@tpet93

tpet93 commented Feb 10, 2025

Issue

When one dataset provides ['image'] and another dataset provides ['mask'], it is difficult to get every combination of intersections using an IntersectionDataset.

Apologies if this has been covered somewhere, but I haven't been able to find it.

Use Case

  • Dataset 1: ['image'] — multiple overlapping images with independent timestamps (e.g., Sentinel-2 from Jan-2023 to Mar-2023). (A,B,C,D)

  • Dataset 2: ['mask'] — multiple overlapping images with independent timestamps (e.g., SAR from Apr-2023 to May-2023). (1,2)

[Image: example layout of the overlapping footprints of Dataset 1 (A, B, C, D) and Dataset 2 (1, 2)]

For the sake of this example, no temporally overlapping data exists. Each dataset has many overlapping regions, and not every image overlaps with every other image within the same dataset.

Problem Description

  1. If both datasets are imported with timestamps, there is no spatiotemporal overlap (due to distinct time ranges).
  2. If timestamps in Dataset 2 are zeroed out, each pair of potential spatial overlaps is recognized as expected.
  3. However, the "hit" or query generated in the R-tree index cannot distinguish between overlapping images in Dataset 2, because the timestamp in the ROI refers only to Dataset 1.

Goal

  • Obtain all possible combinations of Dataset 1 and Dataset 2 images where a spatial overlap exists.
    In the example image, this would be: [A1, A2, B1, B2, C1, C2, D1]

Questions

  1. Is there a recommended process to accomplish a purely spatial intersection while preserving original timestamps?
  2. It almost seems we need two ROIs with matching x, y but different t values to identify one image from each dataset. Has anyone tackled a similar scenario?

Any insights or references are greatly appreciated!

Fix

No response

tpet93 added the documentation label on Feb 10, 2025
@adamjstewart
Collaborator

> Ignoring Temporal Overlap in Spatial Intersection for Multiple Datasets

We're actively working on adding spatial-only intersection capabilities for #2382. Will keep you updated when that work is complete.

In the meantime, the easiest workaround is to disable time information. For example, if you are using Sentinel-2:

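from torchgeo.datasets import Sentinel2


class Sentinel2SpatialOnly(Sentinel2):
    # Same pattern as the base class, except the date is matched by an unnamed
    # group instead of (?P<date>...), so no timestamp is parsed from filenames.
    # Braces are doubled because the base class fills in {} with str.format().
    filename_regex = r"""
        ^T(?P<tile>\d{{2}}[A-Z]{{3}})
        _(\d{{8}}T\d{{6}})
        _(?P<band>B[018][\dA])
        (?:_(?P<resolution>{}m))?
        \..*$
    """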

This takes the same regex used by the base class but removes the named <date> group, so that TorchGeo treats the filenames as containing no date information. Something similar could be done for the mask dataset, although disabling dates in a single dataset is sufficient.
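Why disabling the date in one dataset is enough follows from the time-range check alone. The sketch below assumes that a dataset whose filename_regex has no `date` group is given an effectively unbounded time range, modelled here as [0, sys.maxsize]; check the exact behaviour of your TorchGeo version. The helper `time_ranges_overlap` and the timestamps are hypothetical.

```python
import sys


def time_ranges_overlap(mint1: float, maxt1: float, mint2: float, maxt2: float) -> bool:
    """Return True if the closed intervals [mint1, maxt1] and [mint2, maxt2] intersect."""
    return mint1 <= maxt2 and mint2 <= maxt1


# Hypothetical POSIX timestamps (seconds): a Sentinel-2 scene and a later SAR scene
s2_scene = (1672531200.0, 1672531200.999999)   # 2023-01-01 UTC
sar_scene = (1680307200.0, 1680307200.999999)  # 2023-04-01 UTC

# Assumed time range of a dataset whose filenames carry no date information
no_date = (0.0, float(sys.maxsize))

print(time_ranges_overlap(*s2_scene, *sar_scene))  # False: distinct time ranges
print(time_ranges_overlap(*s2_scene, *no_date))    # True: the dateless range covers everything
```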

adamjstewart added the datasets label on Feb 10, 2025
@tpet93
Author

tpet93 commented Feb 10, 2025

I am already doing as suggested above. However, if the dataset with the modified date regex contains multiple overlapping images, there is no way to query those images individually by time; it always returns the same overlap.
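The ambiguity can be reproduced with a brute-force stand-in for the R-tree. All names and coordinates are hypothetical, and the tuples use the (minx, maxx, miny, maxy, mint, maxt) order of TorchGeo's index. Once Dataset 2's dates are disabled, each of its images spans the whole time axis, so a query carrying image A's timestamp matches mask images 1 and 2 alike.

```python
import sys

# Assumed time range of an image whose date has been disabled
FOREVER = (0.0, float(sys.maxsize))

# Dataset 2 with dates disabled: (name, (minx, maxx, miny, maxy, mint, maxt))
dataset2 = [
    ("1", (0.0, 10.0, 0.0, 10.0, *FOREVER)),
    ("2", (5.0, 15.0, 0.0, 10.0, *FOREVER)),
]


def intersects(a, b):
    """True if two 6-tuples overlap in x, y and t (closed intervals)."""
    return all(a[k] <= b[k + 1] and b[k] <= a[k + 1] for k in (0, 2, 4))


def query(index, box):
    """Return the names of all entries whose bounds intersect the query box."""
    return [name for name, bounds in index if intersects(bounds, box)]


# Query derived from image A: a region covered by both masks, at A's timestamp
query_from_a = (6.0, 9.0, 2.0, 8.0, 1672531200.0, 1672531200.999999)
print(query(dataset2, query_from_a))  # ['1', '2']: the timestamp cannot pick one mask
```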

#2408 looks relevant.
I'm thinking I need to make a TimeIntersectionDataset class. It would reuse the parts of the standard IntersectionDataset class that calculate the possible intersections, and then override the queries for each sub-dataset.

I'll join the Slack #time-series channel and see where I get to.

@sabman

sabman commented Feb 10, 2025

I suspect you could also write a custom bounding box and intersection function that doesn't use temporal comparison. If you just need it to filter datasets, it should be relatively easy. Some pseudocode:

nonTemporalBoundsABCD = NonTemporalBounds(datasetABCD.index)
nonTemporalBounds12 = NonTemporalBounds(dataset12.index)
pairs = nonTemporalBoundsABCD.intersection(nonTemporalBounds12)
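Below is a minimal, dependency-free version of the pseudocode above. `non_temporal_bounds` and `spatial_pairs` are hypothetical helpers standing in for `NonTemporalBounds`; they take 6-tuples in the (minx, maxx, miny, maxy, mint, maxt) order used by TorchGeo's index, drop the time dimension, and compare every pair directly. The O(n·m) loop only suits small collections; for many images, a 2-D R-tree built from the same 4-tuples would avoid checking every pair.

```python
def non_temporal_bounds(bounds):
    """Drop mint/maxt from a (minx, maxx, miny, maxy, mint, maxt) tuple."""
    return tuple(bounds[:4])


def spatial_pairs(index_a, index_b):
    """Return (name_a, name_b) for every pair whose footprints overlap with positive area.

    index_a and index_b map an image name to its 6-tuple of bounds.
    """
    pairs = []
    for name_a, bounds_a in index_a.items():
        ax0, ax1, ay0, ay1 = non_temporal_bounds(bounds_a)
        for name_b, bounds_b in index_b.items():
            bx0, bx1, by0, by1 = non_temporal_bounds(bounds_b)
            # Width and height of the overlap; both must be positive
            width = min(ax1, bx1) - max(ax0, bx0)
            height = min(ay1, by1) - max(ay0, by0)
            if width > 0 and height > 0:
                pairs.append((name_a, name_b))
    return pairs


# Toy footprints; the timestamps never overlap, but they are ignored anyway
abcd = {
    "A": (0, 10, 0, 10, 100, 101),
    "B": (8, 18, 0, 10, 200, 201),
    "C": (30, 40, 0, 10, 300, 301),
}
masks = {
    "1": (5, 12, 2, 8, 900, 901),
    "2": (35, 50, 0, 5, 950, 951),
}
print(spatial_pairs(abcd, masks))  # [('A', '1'), ('B', '1'), ('C', '2')]
```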

Let me know if you end up doing this. I'd be interested.

@adamjstewart
Collaborator

Yes, that would also work.

Also note that you don't have to use IntersectionDataset; you could pass a single dataset to your GeoSampler. This works well when both datasets have the same spatial bounds but non-overlapping times. However, in your situation, your images and masks would need to be preprocessed so that they have the same spatial bounds.

@tpet93
Author

tpet93 commented Feb 11, 2025

For those following along: I've made some progress on a proof of concept, but it does introduce some breaking changes relative to the standard IntersectionDataset.

Summary of code:

  • a subclass that overrides IntersectionDataset
  • one internal index, built with a time_error margin, used to calculate the intersections
  • one internal index for each of the 2 datasets, holding exact timestamps, used to retrieve a single image from each
  • intersections are currently read from the class's index attribute, but could be delivered in a PreChippedGeoSampler fashion
  • a unique query is a list of 2 BoundingBoxes: x and y are equal, and each timestamp is that of the desired input image in the respective dataset, e.g.:
query = [
    BoundingBox(minx=294202.6321444819, maxx=294254.4917857733, miny=6650867.357695428, maxy=6680903.472785138, mint=1710289199.0, maxt=1710289199.999999),
    BoundingBox(minx=294202.6321444819, maxx=294254.4917857733, miny=6650867.357695428, maxy=6680903.472785138, mint=1722990059.0, maxt=1722990059.999999),
]
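One way to build such a query pair from a matched image pair is sketched below with a plain `NamedTuple` standing in for TorchGeo's `BoundingBox`. The `Box` type, `make_query_pair`, and the footprints are hypothetical; only the two timestamps are taken from the example query. The x/y bounds are the overlap of the two footprints, and each box keeps its own image's time range.

```python
from typing import NamedTuple


class Box(NamedTuple):
    """Stand-in for a spatiotemporal bounding box (same field order as TorchGeo's)."""

    minx: float
    maxx: float
    miny: float
    maxy: float
    mint: float
    maxt: float


def make_query_pair(img1: Box, img2: Box) -> list[Box]:
    """Return [query for dataset 1, query for dataset 2] sharing the x/y overlap."""
    minx, maxx = max(img1.minx, img2.minx), min(img1.maxx, img2.maxx)
    miny, maxy = max(img1.miny, img2.miny), min(img1.maxy, img2.maxy)
    if minx >= maxx or miny >= maxy:
        raise ValueError("images do not overlap spatially")
    # Identical x/y, but each query keeps the exact time range of its own image
    return [
        Box(minx, maxx, miny, maxy, img1.mint, img1.maxt),
        Box(minx, maxx, miny, maxy, img2.mint, img2.maxt),
    ]


image_a = Box(0, 10, 0, 10, 1710289199.0, 1710289199.999999)
mask_1 = Box(5, 12, 2, 8, 1722990059.0, 1722990059.999999)
q1, q2 = make_query_pair(image_a, mask_1)
print(q1[:4] == q2[:4], q1.mint, q2.mint)  # True 1710289199.0 1722990059.0
```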

This seems to have some factors in common with #2048, but also some differences.

I intend to use this code like a PreChippedGeoSampler and then run GridGeoSampler on each hit.
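The "grid sampler on each hit" step can be sketched in plain Python: tile the shared x/y extent of one query pair with fixed-size chips and give each chip both datasets' time ranges, so every chip is again a valid two-box query. `grid_chip_pairs`, `size`, and `stride` (in CRS units) are hypothetical. This is not GridGeoSampler's implementation; in particular, chips that would extend past the extent are simply dropped rather than handled the way GridGeoSampler handles edges.

```python
def grid_chip_pairs(query_pair, size, stride):
    """Yield [chip for dataset 1, chip for dataset 2] tiles covering one query pair.

    query_pair: two (minx, maxx, miny, maxy, mint, maxt) tuples with equal x/y bounds.
    Chips that would extend past the extent are dropped.
    """
    if size <= 0 or stride <= 0:
        raise ValueError("size and stride must be positive")
    q1, q2 = query_pair
    minx, maxx, miny, maxy = q1[:4]
    y = miny
    while y + size <= maxy:
        x = minx
        while x + size <= maxx:
            # Same chip footprint; each half keeps its own dataset's time range
            yield [
                (x, x + size, y, y + size, q1[4], q1[5]),
                (x, x + size, y, y + size, q2[4], q2[5]),
            ]
            x += stride
        y += stride


pair = [(0, 10, 0, 5, 100.0, 100.5), (0, 10, 0, 5, 900.0, 900.5)]
chips = list(grid_chip_pairs(pair, size=5, stride=5))
print(len(chips))  # 2
```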

I feel that mixing multiple multi-temporal datasets, without being interested in "true" time-series data, is a common use case.

Questions:

  • I'm open to input on how to better structure the implementation (sampler vs. dataset).

  • Currently, any class that "wraps" over the top of this will need to accept a list of queries to function.

  • How should this functionality be implemented long term?

  • Is there a preferred dataset to test this code on, to confirm it gives the expected answers?

import sys
from collections.abc import Callable, Sequence
from typing import Any

from rtree.index import Index, Property
from torchgeo.datasets import BoundingBox, GeoDataset, IntersectionDataset
from torchgeo.datasets.utils import concat_samples


class TimeIntersectionDataset(IntersectionDataset):
    """Spatial intersection of two GeoDatasets that keeps each dataset's timestamps.

    Every pair of images (one from each dataset) that overlaps spatially, and whose
    time ranges are at most ``time_error`` seconds apart, becomes one entry. Entry
    ``i`` is stored under the same id in three R-trees:

    * ``index``: the spatial overlap, with a time range padded by ``time_error``
    * ``st_index1``: the same x/y bounds with dataset1's exact time range
    * ``st_index2``: the same x/y bounds with dataset2's exact time range

    A sample is retrieved with a list of two BoundingBoxes, one per dataset.
    """

    # Maximum time gap (seconds) between paired images; sys.maxsize ignores time
    time_error = sys.maxsize

    def __init__(
        self,
        dataset1: GeoDataset,
        dataset2: GeoDataset,
        collate_fn: Callable[
            [Sequence[dict[str, Any]]], dict[str, Any]
        ] = concat_samples,
        transforms: Callable[[dict[str, Any]], dict[str, Any]] | None = None,
    ) -> None:
        """Initialize a new TimeIntersectionDataset instance.

        Args:
            dataset1: the first dataset
            dataset2: the second dataset
            collate_fn: function used to collate samples
            transforms: a function/transform that takes input sample and its target as
                entry and returns a transformed version

        Raises:
            RuntimeError: if datasets have no spatial intersection within time_error
            ValueError: if either dataset is not a :class:`GeoDataset`
        """
        # IntersectionDataset.__init__ checks both datasets, sets crs/res and calls
        # the overridden _merge_dataset_indices() below exactly once. Building the
        # indices a second time here would insert duplicates into st_index1/2.
        super().__init__(
            dataset1, dataset2, collate_fn=collate_fn, transforms=transforms
        )

    def _merge_dataset_indices(self) -> None:
        """Build the spatial index and the two per-dataset spatiotemporal indices."""
        # (Re)create all three indices together so their entry ids stay in sync
        self.index = Index(interleaved=False, properties=Property(dimension=3))
        self.st_index1 = Index(interleaved=False, properties=Property(dimension=3))
        self.st_index2 = Index(interleaved=False, properties=Property(dimension=3))

        i = 0
        ds1, ds2 = self.datasets
        te = self.time_error

        def adjust_time_bounds(bounds, time_error):
            """Return a BoundingBox whose time range is widened by time_error on both sides."""
            return BoundingBox(
                bounds[0], bounds[1], bounds[2], bounds[3],
                bounds[4] - time_error, bounds[5] + time_error,
            )

        for hit1 in ds1.index.intersection(tuple(ds1.index.bounds), objects=True):
            spatial_bbox1 = adjust_time_bounds(hit1.bounds, te)

            # Every ds2 image that overlaps hit1 spatially and lies within te in time
            for hit2 in ds2.index.intersection(tuple(spatial_bbox1), objects=True):
                # Padding hit2 as well does not change which pairs match (the query
                # above decided that); it only widens the time span kept in self.index
                spatial_bbox2 = adjust_time_bounds(hit2.bounds, te)
                intersection_box = spatial_bbox1 & spatial_bbox2

                # Skip 0 area overlap (unless 0 area dataset)
                if intersection_box.area > 0 or spatial_bbox1.area == 0 or spatial_bbox2.area == 0:
                    # Same x/y bounds, but each dataset's exact time range
                    time_hit1 = BoundingBox(
                        intersection_box[0], intersection_box[1],
                        intersection_box[2], intersection_box[3],
                        hit1.bounds[4], hit1.bounds[5],
                    )
                    time_hit2 = BoundingBox(
                        intersection_box[0], intersection_box[1],
                        intersection_box[2], intersection_box[3],
                        hit2.bounds[4], hit2.bounds[5],
                    )

                    self.index.insert(i, tuple(intersection_box))
                    self.st_index1.insert(i, tuple(time_hit1))
                    self.st_index2.insert(i, tuple(time_hit2))
                    i += 1

        if i == 0:
            raise RuntimeError('Datasets have no spatiotemporal intersection')

    def __getitem__(self, query: BoundingBox | list[BoundingBox]) -> dict[str, Any]:
        """Retrieve image and metadata indexed by one query per dataset.

        Args:
            query: a list with one BoundingBox per dataset (same x/y, each with that
                dataset's time range), or a single BoundingBox used for both datasets

        Returns:
            sample collated from both datasets

        Raises:
            IndexError: if a query is not within bounds of its dataset
            ValueError: if the number of queries does not match the number of datasets
        """
        if isinstance(query, BoundingBox):
            # A single box is sent to every dataset, as in IntersectionDataset
            query = [query] * len(self.datasets)

        if len(query) != len(self.datasets):
            raise ValueError(
                f'expected {len(self.datasets)} queries, got {len(query)}'
            )

        # Query each dataset with its own bounding box
        samples = []
        for ds, q in zip(self.datasets, query):
            if not q.intersects(ds.bounds):
                raise IndexError(
                    f'query: {q} not found in index with bounds: {ds.bounds}'
                )
            samples.append(ds[q])

        sample = self.collate_fn(samples)

        if self.transforms is not None:
            sample = self.transforms(sample)

        return sample

Example usage:

import datetime

# Only pair images whose time ranges are at most 180 days apart
TimeIntersectionDataset.time_error = datetime.timedelta(days=180).total_seconds()
# TimeIntersectionDataset.time_error = sys.maxsize  # ignore time entirely

tids = TimeIntersectionDataset(ds1, ds2)

# If time_error is very large, these boxes match what ds1 & ds2 would give with
# timestamps disabled. Duplicate boxes are expected: each one refers to a
# different pair of images (i.e. different timestamps).
hits = [
    BoundingBox(*hit.bounds)
    for hit in tids.index.intersection(tuple(tids.bounds), objects=True)
]

# TODO: have the class deliver the triplet of indices
# Pair entries by id: rtree does not guarantee the same result order across indices
hits1 = {
    hit.id: BoundingBox(*hit.bounds)
    for hit in tids.st_index1.intersection(tuple(tids.st_index1.bounds), objects=True)
}
hits2 = {
    hit.id: BoundingBox(*hit.bounds)
    for hit in tids.st_index2.intersection(tuple(tids.st_index2.bounds), objects=True)
}

for i in hits1:
    query = [hits1[i], hits2[i]]
    sample = tids[query]  # returns a unique sample each time
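The effect of `time_error` on which pairs are kept can be checked without TorchGeo. In `_merge_dataset_indices`, hit1's time range is widened by `time_error` on both sides before dataset 2 is queried, so a pair survives when the gap between the two time ranges is at most `time_error` (touching ranges count as overlapping in this sketch). The helper `kept_by_time_error` and the timestamps are hypothetical.

```python
import datetime
import sys


def kept_by_time_error(range1, range2, time_error):
    """True if range2 intersects range1 widened by time_error on both sides.

    range1, range2: (mint, maxt) POSIX timestamps in seconds.
    """
    mint1, maxt1 = range1[0] - time_error, range1[1] + time_error
    return range2[0] <= maxt1 and mint1 <= range2[1]


day = 86400.0
half_year = datetime.timedelta(days=180).total_seconds()  # 15552000.0
s2 = (1672531200.0, 1672531200.999999)                    # 2023-01-01 UTC
sar_april = (s2[0] + 90 * day, s2[0] + 90 * day + 1)      # 90 days later
sar_next_year = (s2[0] + 400 * day, s2[0] + 400 * day + 1)  # 400 days later

print(kept_by_time_error(s2, sar_april, half_year))               # True
print(kept_by_time_error(s2, sar_next_year, half_year))           # False
print(kept_by_time_error(s2, sar_next_year, float(sys.maxsize)))  # True
```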


3 participants