Skip to content

Commit

Permalink
feat: add bands
Browse files Browse the repository at this point in the history
  • Loading branch information
gadomski committed Sep 27, 2023
1 parent 9c323c4 commit 3d8e4cf
Show file tree
Hide file tree
Showing 7 changed files with 191 additions and 1 deletion.
2 changes: 2 additions & 0 deletions pystac/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
"STACObjectType",
"Link",
"HIERARCHICAL_LINKS",
"Band",
"Catalog",
"CatalogType",
"Collection",
Expand Down Expand Up @@ -75,6 +76,7 @@
SpatialExtent,
TemporalExtent,
)
from pystac.band import Band
from pystac.common_metadata import CommonMetadata
from pystac.summaries import RangeSummary, Summaries
from pystac.asset import Asset
Expand Down
22 changes: 22 additions & 0 deletions pystac/asset.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Type, TypeVar, Union

from pystac import common_metadata, utils
from pystac.band import Band
from pystac.html.jinja_env import get_jinja_env

if TYPE_CHECKING:
Expand Down Expand Up @@ -71,13 +72,15 @@ def __init__(
description: Optional[str] = None,
media_type: Optional[str] = None,
roles: Optional[List[str]] = None,
bands: Optional[List[Band]] = None,
extra_fields: Optional[Dict[str, Any]] = None,
) -> None:
self.href = utils.make_posix_style(href)
self.title = title
self.description = description
self.media_type = media_type
self.roles = roles
self._bands = bands
self.extra_fields = extra_fields or {}

# The Item which owns this Asset.
Expand Down Expand Up @@ -113,6 +116,16 @@ def get_absolute_href(self) -> Optional[str]:
return utils.make_absolute_href(self.href, item_self)
return None

@property
def bands(self) -> Optional[List[Band]]:
if self._bands is None and self.owner is not None:
return self.owner.bands
return self._bands

@bands.setter
def bands(self, bands: Optional[List[Band]]) -> None:
self._bands = bands

def to_dict(self) -> Dict[str, Any]:
"""Returns this Asset as a dictionary.
Expand All @@ -138,6 +151,9 @@ def to_dict(self) -> Dict[str, Any]:
if self.roles is not None:
d["roles"] = self.roles

if self.bands is not None:
d["bands"] = [band.to_dict() for band in self.bands]

return d

def clone(self) -> Asset:
Expand Down Expand Up @@ -201,6 +217,11 @@ def from_dict(cls: Type[A], d: Dict[str, Any]) -> A:
title = d.pop("title", None)
description = d.pop("description", None)
roles = d.pop("roles", None)
bands = d.pop("bands", None)
if bands is None:
deserialized_bands = None
else:
deserialized_bands = [Band.from_dict(band) for band in bands]
properties = None
if any(d):
properties = d
Expand All @@ -211,6 +232,7 @@ def from_dict(cls: Type[A], d: Dict[str, Any]) -> A:
title=title,
description=description,
roles=roles,
bands=deserialized_bands,
extra_fields=properties,
)

Expand Down
45 changes: 45 additions & 0 deletions pystac/band.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
from __future__ import annotations

from dataclasses import dataclass, field
from typing import Any, Dict, Optional


@dataclass
class Band:
"""A name and some properties that apply to a band (aka subasset)."""

name: str
"""The name of the band (e.g., "B01", "B8", "band2", "red").
This should be unique across all bands defined in the list of bands. This is
typically the name the data provider uses for the band.
"""

description: Optional[str] = None
"""Description to fully explain the band.
CommonMark 0.29 syntax MAY be used for rich text representation.
"""

properties: Dict[str, Any] = field(default_factory=dict)
"""Other properties on the band."""

@classmethod
def from_dict(cls, d: Dict[str, Any]) -> Band:
"""Creates a new band object from a dictionary."""
try:
name = d.pop("name")
except KeyError:
raise ValueError("missing required field on band: name")
description = d.pop("description", None)
return Band(name=name, description=description, properties=d)

def to_dict(self) -> Dict[str, Any]:
"""Creates a dictionary from this band object."""
d = {
"name": self.name,
}
if self.description is not None:
d["description"] = self.description
d.update(self.properties)
return d
26 changes: 26 additions & 0 deletions pystac/collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import pystac
from pystac import CatalogType, STACError, STACObjectType
from pystac.asset import Asset
from pystac.band import Band
from pystac.catalog import Catalog
from pystac.errors import DeprecatedWarning, ExtensionNotImplemented, STACTypeError
from pystac.layout import HrefLayoutStrategy
Expand Down Expand Up @@ -517,6 +518,8 @@ class Collection(Catalog):
"""Default file name that will be given to this STAC object
in a canonical format."""

_bands: Optional[List[Band]]

def __init__(
self,
id: str,
Expand All @@ -532,6 +535,7 @@ def __init__(
providers: Optional[List[Provider]] = None,
summaries: Optional[Summaries] = None,
assets: Optional[Dict[str, Asset]] = None,
bands: Optional[List[Band]] = None,
):
super().__init__(
id,
Expand All @@ -555,6 +559,8 @@ def __init__(
for k, asset in assets.items():
self.add_asset(k, asset)

self._bands = bands

def __repr__(self) -> str:
return "<Collection id={}>".format(self.id)

Expand Down Expand Up @@ -588,6 +594,9 @@ def to_dict(
if any(self.assets):
d["assets"] = {k: v.to_dict() for k, v in self.assets.items()}

if self.bands is not None:
d["bands"] = [band.to_dict() for band in self.bands]

return d

def clone(self) -> Collection:
Expand Down Expand Up @@ -664,6 +673,12 @@ def from_dict(
assets = {k: Asset.from_dict(v) for k, v in assets.items()}
links = d.pop("links")

bands = d.pop("bands", None)
if bands is not None:
deserialized_bands = [Band.from_dict(band) for band in bands]
else:
deserialized_bands = None

d.pop("stac_version")

collection = cls(
Expand All @@ -680,6 +695,7 @@ def from_dict(
href=href,
catalog_type=catalog_type,
assets=assets,
bands=deserialized_bands,
)

for link in links:
Expand Down Expand Up @@ -830,3 +846,13 @@ def full_copy(
@classmethod
def matches_object_type(cls, d: Dict[str, Any]) -> bool:
return identify_stac_object_type(d) == STACObjectType.COLLECTION

@property
def bands(self) -> Optional[List[Band]]:
"""Returns the bands set on this collection."""
return self._bands

@bands.setter
def bands(self, bands: Optional[List[Band]]) -> None:
"""Sets the bands on this collection."""
self._bands = bands
26 changes: 26 additions & 0 deletions pystac/item.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import pystac
from pystac import RelType, STACError, STACObjectType
from pystac.asset import Asset
from pystac.band import Band
from pystac.catalog import Catalog
from pystac.collection import Collection
from pystac.errors import DeprecatedWarning, ExtensionNotImplemented
Expand Down Expand Up @@ -106,6 +107,8 @@ class Item(STACObject):
stac_extensions: List[str]
"""List of extensions the Item implements."""

_bands: Optional[List[Band]]

STAC_OBJECT_TYPE = STACObjectType.ITEM

def __init__(
Expand All @@ -122,6 +125,7 @@ def __init__(
collection: Optional[Union[str, Collection]] = None,
extra_fields: Optional[Dict[str, Any]] = None,
assets: Optional[Dict[str, Asset]] = None,
bands: Optional[List[Band]] = None,
):
super().__init__(stac_extensions or [])

Expand Down Expand Up @@ -167,6 +171,8 @@ def __init__(
for k, asset in assets.items():
self.add_asset(k, asset)

self._bands = bands

def __repr__(self) -> str:
return "<Item id={}>".format(self.id)

Expand Down Expand Up @@ -406,6 +412,16 @@ def get_derived_from(self) -> List[Item]:
"Link failed to resolve. Use get_links instead."
) from e

@property
def bands(self) -> Optional[List[Band]]:
"""Returns the bands set on this item."""
return self._bands

@bands.setter
def bands(self, bands: Optional[List[Band]]) -> None:
"""Sets the bands on this item."""
self._bands = bands

def to_dict(
self, include_self_link: bool = True, transform_hrefs: bool = True
) -> Dict[str, Any]:
Expand Down Expand Up @@ -442,6 +458,9 @@ def to_dict(
for key in self.extra_fields:
d[key] = self.extra_fields[key]

if self.bands is not None:
d["properties"]["bands"] = [band.to_dict() for band in self.bands]

return d

def clone(self) -> Item:
Expand Down Expand Up @@ -516,13 +535,20 @@ def from_dict(
if k not in [*pass_through_fields, *parse_fields, *exclude_fields]
}

bands = properties.pop("bands", None)
if bands is not None:
deserialized_bands = [Band.from_dict(d) for d in bands]
else:
deserialized_bands = None

item = cls(
**{k: d.get(k) for k in pass_through_fields}, # type: ignore
datetime=datetime,
properties=properties,
extra_fields=extra_fields,
href=href,
assets={k: Asset.from_dict(v) for k, v in assets.items()},
bands=deserialized_bands,
)

for link in links:
Expand Down
30 changes: 30 additions & 0 deletions tests/test_collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import pystac
from pystac import (
Asset,
Band,
Catalog,
CatalogType,
Collection,
Expand Down Expand Up @@ -670,3 +671,32 @@ def test_permissive_temporal_extent_deserialization(collection: Collection) -> N
]["interval"][0]
with pytest.warns(UserWarning):
Collection.from_dict(collection_dict)


def test_set_bands_on_collection(collection: Collection) -> None:
collection.add_asset("data", Asset(href="example.tif"))
collection.bands = [Band(name="analytic")]
assert collection.assets["data"].bands
assert collection.assets["data"].bands[0].name == "analytic"


def test_bands_roundtrip_on_asset(collection: Collection) -> None:
collection.add_asset("data", Asset(href="example.tif"))
collection_dict = collection.to_dict(include_self_link=False, transform_hrefs=False)
collection_dict["assets"]["data"]["bands"] = [{"name": "data"}]
collection = Collection.from_dict(collection_dict)
assert collection.assets["data"].bands
assert collection.assets["data"].bands[0].name == "data"
collection_dict = collection.to_dict(include_self_link=False, transform_hrefs=False)
assert collection_dict["assets"]["data"]["bands"][0]["name"] == "data"


def test_bands_roundtrip_on_collection(collection: Collection) -> None:
collection.add_asset("data", Asset(href="example.tif"))
collection_dict = collection.to_dict(include_self_link=False, transform_hrefs=False)
collection_dict["bands"] = [{"name": "data"}]
collection = Collection.from_dict(collection_dict)
assert collection.assets["data"].bands
assert collection.assets["data"].bands[0].name == "data"
collection_dict = collection.to_dict(include_self_link=False, transform_hrefs=False)
assert collection_dict["bands"][0]["name"] == "data"
41 changes: 40 additions & 1 deletion tests/test_item.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@

import pystac
import pystac.serialization.common_properties
from pystac import Asset, Catalog, Collection, Item, Link
from pystac import Asset, Band, Catalog, Collection, Item, Link
from pystac.utils import (
datetime_to_str,
get_opt,
Expand Down Expand Up @@ -636,3 +636,42 @@ def test_pathlib() -> None:
# This works, but breaks mypy until we fix
# https://github.com/stac-utils/pystac/issues/1216
Item.from_file(Path(TestCases.get_path("data-files/item/sample-item.json")))


def test_bands_do_not_exist(sample_item: Item) -> None:
sample_item.assets["analytic"].bands is None


def test_set_bands(sample_item: Item) -> None:
sample_item.assets["analytic"].bands = [Band(name="analytic")]
assert sample_item.assets["analytic"].bands[0].name == "analytic"


def test_set_bands_on_item(sample_item: Item) -> None:
sample_item.bands = [Band(name="analytic")]
assert sample_item.assets["analytic"].bands
assert sample_item.assets["analytic"].bands[0].name == "analytic"


def test_bands_roundtrip_on_asset(sample_item: Item) -> None:
sample_item_dict = sample_item.to_dict(
include_self_link=False, transform_hrefs=False
)
sample_item_dict["assets"]["analytic"]["bands"] = [{"name": "analytic"}]
item = Item.from_dict(sample_item_dict)
assert item.assets["analytic"].bands
assert item.assets["analytic"].bands[0].name == "analytic"
item_dict = item.to_dict(include_self_link=False, transform_hrefs=False)
assert item_dict["assets"]["analytic"]["bands"][0]["name"] == "analytic"


def test_bands_roundtrip_on_item(sample_item: Item) -> None:
sample_item_dict = sample_item.to_dict(
include_self_link=False, transform_hrefs=False
)
sample_item_dict["properties"]["bands"] = [{"name": "analytic"}]
item = Item.from_dict(sample_item_dict)
assert item.assets["analytic"].bands
assert item.assets["analytic"].bands[0].name == "analytic"
item_dict = item.to_dict(include_self_link=False, transform_hrefs=False)
assert item_dict["properties"]["bands"][0]["name"] == "analytic"

0 comments on commit 3d8e4cf

Please sign in to comment.