Source code for satellitetools.common.sentinel2

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Classes for handling Sentinel-2 data.

@author: Olli Nevalainen (Finnish Meteorological Institute)

"""

import logging
from dataclasses import dataclass
from typing import Dict, List, Optional

try:
    # enum.StrEnum was added in Python 3.11.
    from enum import Enum, StrEnum
except ImportError:
    from enum import Enum

    class StrEnum(str, Enum):
        """Minimal backport of ``enum.StrEnum`` for Python < 3.11.

        Members are also ``str`` instances, so they compare equal to their
        values. ``__str__`` is overridden to return the member value, as the
        standard-library ``StrEnum`` does; without it, a plain ``str, Enum``
        mixin returns ``"ClassName.MEMBER"`` from ``str(member)``.
        """

        def __str__(self) -> str:
            return str.__str__(self)
from collections import Counter

import numpy as np
import pandas as pd
import xarray as xr

from satellitetools.common.classes import AOI, DataSource
# Module-level logger.
logger = logging.getLogger(__name__)

# Reflectance scaling factor. Presumably reflectance = stored value /
# S2_REFL_TRANS; it is not applied within this module — confirm at call sites.
S2_REFL_TRANS = 10_000

# Value marking no-data pixels in SCL arrays; data_to_xarray counts pixels
# equal to it as lying outside the AOI.
SCL_NODATA = 99

# No-data value for spectral bands; data_to_xarray counts NaN pixels as
# lying outside the AOI.
SPECTRAL_BAND_NO_DATA = np.nan
# Use GEE band naming
[docs] class S2Band(StrEnum): """Sentinel-2 bands."""
[docs] B1 = "B1"
[docs] B2 = "B2"
[docs] B3 = "B3"
[docs] B4 = "B4"
[docs] B5 = "B5"
[docs] B6 = "B6"
[docs] B7 = "B7"
[docs] B8 = "B8"
[docs] B8A = "B8A"
[docs] B9 = "B9"
[docs] B11 = "B11"
[docs] B12 = "B12"
[docs] AOT = "AOT"
[docs] WVP = "WVP"
[docs] SCL = "SCL"
[docs] def to_aws(self) -> str: """Convert band name to AWS band name. Returns: ---------------- str Band name in AWS. """ band_name_in_aws = S2_BANDS_GEE_TO_AWS[self.value] return band_name_in_aws
[docs] def to_gee(self) -> str: """Convert band name to GEE band name. Returns: ---------------- str Band name in GEE. """ return self.value
@classmethod
[docs] def get_10m_to_20m_bands(cls) -> List["S2Band"]: """Get 10-20 m bands for the band. Returns: ---------------- List[S2Band] List of 10-20 m bands. """ return [S2Band(b) for b in S2_BANDS_10_20_GEE]
@classmethod
[docs] def get_all_bands(cls) -> List["S2Band"]: """Get all bands for the band. Returns: ---------------- List[S2Band] List of all bands. """ return [b for b in S2Band]
class SCLClass(Enum):
    """Sentinel-2 Scene Classification Layer (SCL) classes.

    Member values are the integer class codes. Member names are used as the
    SCL-percentage column names of the quality-information dataframe
    (see ``S2_SCL_CLASSES``).
    """

    # No data / defective pixels
    NODATA = 0
    SATURATED_DEFECTIVE = 1
    # Shadows
    DARK_FEATURE_SHADOW = 2
    CLOUD_SHADOW = 3
    # Surface types
    VEGETATION = 4
    NOT_VEGETATED = 5
    WATER = 6
    UNCLASSIFIED = 7
    # Clouds, cirrus and snow
    CLOUD_MEDIUM_PROBA = 8
    CLOUD_HIGH_PROBA = 9
    THIN_CIRRUS = 10
    SNOW_ICE = 11
[docs] S2_BANDS_GEE = [ "B1", "B2", "B3", "B4", "B5", "B6", "B7", "B8", "B8A", "B9", "B11", "B12", "AOT", "WVP", "SCL", ]
[docs] S2_BANDS_10_20_GEE = [ "B2", "B3", "B4", "B5", "B6", "B7", "B8", "B8A", "B11", "B12", ]
# Old v0 earth search # S2_BANDS_COG = [ # "B01", # "B02", # "B03", # "B04", # "B05", # "B06", # "B07", # "B08", # "B8A", # "B09", # "B11", # "B12", # ] # # 10-20 m bands # S2_BANDS_10_20_COG = ["B02", "B03", "B04", "B05", "B06", "B07", "B8A", "B11", "B12"]
[docs] S2_BANDS_COG = [ "coastal", # B1 "blue", # B2 "green", # B3 "red", # B4 "rededge1", # B5 "rededge2", # B6 "rededge3", # B7 "nir", # B8 "nir08", # B8A "nir09", # B9 "swir16", # B11 "swir22", # B12 "aot", "wvp", "scl", ]
[docs] S2_BANDS_10_20_COG = [ "blue", "green", "red", "rededge1", "rededge2", "rededge3", "nir08", "swir16", "swir22", ]
[docs] S2_BANDS_AWS_TO_GEE = { aws: gee for aws, gee in zip(S2_BANDS_COG, S2_BANDS_GEE, strict=True) }
# S2_BANDS_10_20_AWS_TO_GEE = { # aws: gee for aws, gee in zip(S2_BANDS_10_20_COG, S2_BANDS_10_20_GEE) # }
[docs] S2_BANDS_GEE_TO_AWS = { gee: aws for gee, aws in zip(S2_BANDS_GEE, S2_BANDS_COG, strict=True) }
# S2_BANDS_10_20_GEE_TO_AWS = { # gee: aws for gee, aws in zip(S2_BANDS_10_20_GEE, S2_BANDS_10_20_COG) # }
# Names of all SCL classes in enum order; filter_s2_qi_dataframe uses them as
# the SCL-percentage columns that must be present (non-NaN).
S2_SCL_CLASSES = [scl_class.name for scl_class in SCLClass]

# Default quality filter: SCL classes whose summed percentage is compared
# against the threshold in filter_s2_qi_dataframe.
S2_FILTER1 = [
    SCLClass.NODATA,
    SCLClass.SATURATED_DEFECTIVE,
    SCLClass.CLOUD_SHADOW,
    SCLClass.UNCLASSIFIED,
    SCLClass.CLOUD_MEDIUM_PROBA,
    SCLClass.CLOUD_HIGH_PROBA,
    SCLClass.THIN_CIRRUS,
    SCLClass.SNOW_ICE,
]

# Same classes as S2_FILTER1 (same order) but without UNCLASSIFIED.
S2_FILTER2 = [
    scl_class for scl_class in S2_FILTER1 if scl_class is not SCLClass.UNCLASSIFIED
]
class Sentinel2RequestParams:
    """S2 data request parameters.

    Attributes:
    ----------------
    datestart : str
        Start date of the data request.
    dateend : str
        End date of the data request.
    datasource : DataSource
        Data source for the request.
    bands : List[S2Band]
        Sentinel-2 bands to request (all bands if none were given).
    target_gsd : float
        Target ground sampling distance (GSD) in meters.
    qi_evaluation_scale : float
        Quality indicator evaluation scale.
    """

    def __init__(
        self,
        datestart: str,
        dateend: str,
        datasource: DataSource,
        bands: Optional[List[S2Band]] = None,
        target_gsd: Optional[float] = 20,
        qi_evaluation_scale: Optional[float] = 20,
    ):
        """Initialize the Sentinel2RequestParams class.

        Parameters:
        ----------------
        datestart : str
            Start date of the data request.
        dateend : str
            End date of the data request.
        datasource : DataSource
            Data source for the request.
        bands : Optional[List[S2Band]]
            Sentinel-2 bands to request; band-name strings are accepted and
            converted to S2Band. None or an empty list selects all bands.
        target_gsd : Optional[float]
            Target ground sampling distance (GSD) in meters; None means 20.
        qi_evaluation_scale : Optional[float]
            Quality indicator evaluation scale; None means 20.

        Raises:
        ----------------
        ValueError
            If a given band is not a valid S2Band value.
        """
        self.datestart = datestart
        self.dateend = dateend
        self.datasource = datasource
        # Validate input and convert to S2Band if strings; no selection
        # means every band.
        self.bands = [S2Band(band) for band in bands] if bands else list(S2Band)
        # An explicit None falls back to the default of 20.
        self.target_gsd = 20 if target_gsd is None else target_gsd
        self.qi_evaluation_scale = (
            20 if qi_evaluation_scale is None else qi_evaluation_scale
        )

    def __repr__(self) -> str:
        return (
            f"{type(self).__name__}(datestart={self.datestart}, "
            f"dateend={self.dateend}, "
            f"datasource={self.datasource}, bands={self.bands}, "
            f"target_gsd={self.target_gsd}, "
            f"qi_evaluation_scale={self.qi_evaluation_scale})"
        )
class Sentinel2Metadata:
    """Sentinel-2 metadata class.

    Attributes:
    ----------------
    acquisition_time : pd.Timestamp
        Acquisition time of the image.
    tileid : str
        Sentinel-2 tile ID.
    assetid : str
        Asset ID of the image.
    productid : str
        Product ID of the image.
    projection : str
        Projection of the image.
    datasource : DataSource
        Data source of the image.
    observation_geometry : Optional[Sentinel2ObservationGeometry]
        Observation geometry of the image.
    class_percentages : Optional[Dict[SCLClass, float]]
        Class percentages of the image.
    """

    def __init__(
        self,
        acquisition_time: pd.Timestamp,
        tileid: str,
        assetid: str,
        productid: str,
        projection: str,
        datasource: DataSource,
        observation_geometry: Optional["Sentinel2ObservationGeometry"] = None,
        class_percentages: Optional[Dict[SCLClass, float]] = None,
    ):
        """Initialize the Sentinel2Metadata class.

        Parameters:
        ----------------
        acquisition_time : pd.Timestamp
            Acquisition time of the image.
        tileid : str
            Sentinel-2 tile ID.
        assetid : str
            Asset ID of the image.
        productid : str
            Product ID of the image.
        projection : str
            Projection of the image.
        datasource : DataSource
            Data source of the image.
        observation_geometry : Optional[Sentinel2ObservationGeometry]
            Observation geometry of the image; stored as None if falsy.
        class_percentages : Optional[Dict[SCLClass, float]]
            Class percentages of the image; stored as None if empty.
        """
        self.acquisition_time = acquisition_time
        self.tileid = tileid
        self.assetid = assetid
        self.productid = productid
        self.projection = projection
        self.datasource = datasource
        # Optional parts; falsy values (e.g. an empty dict) are normalized
        # to None.
        self.observation_geometry = observation_geometry or None
        self.class_percentages = class_percentages or None

    def __repr__(self) -> str:
        return (
            f"Sentinel2Metadata(assetid={self.assetid}, "
            f"acquisition_time={self.acquisition_time}, "
            f"datasource={self.datasource})"
        )
@dataclass
class Sentinel2ObservationGeometry:
    """Sun and view angles of a Sentinel-2 observation.

    Attributes:
    ----------------
    sun_azimuth : float
        Sun azimuth angle.
    sun_zenith : float
        Sun zenith angle.
    view_azimuth : float
        View azimuth angle.
    view_zenith : float
        View zenith angle.
    """

    # Illumination angles
    sun_azimuth: float
    sun_zenith: float
    # Viewing angles
    view_azimuth: float
    view_zenith: float
@dataclass
class Coordinates:
    """Pixel-centre coordinates of a raster band.

    Attributes:
    ----------------
    x : list[float]
        x coordinates.
    y : list[float]
        y coordinates.
    """

    # NOTE(review): annotated as lists, but data_to_xarray calls
    # ``.astype(np.float64)`` on these, so callers presumably store numpy
    # arrays — confirm and consider widening the annotation.
    x: list[float]
    y: list[float]
class Sentinel2Item:
    """Sentinel-2 data item class.

    Attributes:
    ----------------
    metadata : Sentinel2Metadata
        Metadata of the data item.
    data : Dict[S2Band, np.ndarray]
        Data of the data item, keyed by band (empty dict if none given).
    coordinates : Dict[S2Band, Coordinates]
        Coordinates of the data item, keyed by band.
    """

    def __init__(
        self,
        metadata: Sentinel2Metadata,
        data: Optional[Dict[S2Band, np.ndarray]] = None,
    ):
        """Initialize the Sentinel2Item class.

        Parameters:
        ----------------
        metadata : Sentinel2Metadata
            Metadata of the data item.
        data : Optional[Dict[S2Band, np.ndarray]]
            Data of the data item; None or empty starts with a fresh dict.
        """
        self.metadata = metadata
        self.data = data or {}
        self.coordinates: Dict[S2Band, Coordinates] = {}

    def __repr__(self):
        return (
            f"Sentinel2Item(productid={self.metadata.productid}, "
            f"time={self.metadata.acquisition_time}, "
            f"data={self.data.keys()})"
        )

    def set_coordinates(self, band: S2Band, xs: list, ys: list):
        """Set coordinates for the data item.

        Parameters:
        ----------------
        band : S2Band
            Sentinel-2 band.
        xs : list
            List of x coordinates.
        ys : list
            List of y coordinates.
        """
        self.coordinates[band] = Coordinates(x=xs, y=ys)
class Sentinel2DataCollection:
    """Sentinel-2 data collection class.

    Attributes:
    ----------------
    aoi : AOI
        Area of interest.
    req_params : Sentinel2RequestParams
        Request parameters.
    quality_information : pd.DataFrame
        Quality information, one row per item indexed by UTC acquisition
        time; None until create_quality_information() has run.
    xr_dataset : xr.Dataset
        Xarray dataset; None until data_to_xarray() has run.
    s2_items : List[Sentinel2Item]
        List of Sentinel-2 items; None until items have been searched.
    """

    def __init__(
        self,
        aoi: AOI,
        req_params: Sentinel2RequestParams,
    ):
        """Initialize the Sentinel2DataCollection class.

        Parameters:
        ----------------
        aoi : AOI
            Area of interest.
        req_params : Sentinel2RequestParams
            Request parameters.
        """
        self.aoi = aoi
        self.req_params = req_params
        self.quality_information: pd.DataFrame = None
        self.xr_dataset: xr.Dataset = None
        self.s2_items: List[Sentinel2Item] = None

    # Print information about the data collection
    def __repr__(self):
        return (
            f"Sentinel2DataCollection(AOI.name={self.aoi.name}, "
            f"Sentinel2RequestParams={self.req_params})"
        )

    def check_s2_items_exist(self) -> bool:
        """Return True if there is at least one Sentinel-2 item.

        Logs (at info level) why not when s2_items is None or empty.
        """
        if self.s2_items is None:
            logger.info("No Sentinel-2 items searched (s2_items=None).")
            return False
        elif len(self.s2_items) == 0:
            logger.info(
                "No Sentinel-2 items found for the time period or "
                "left after filtering (s2_items=[])."
            )
            return False
        return True

    def create_quality_information(self):
        """Create quality information dataframe.

        Builds one row per item from its metadata attributes (without the
        observation geometry) plus one column per entry of its class
        percentages, and stores it in ``self.quality_information`` sorted and
        indexed by UTC acquisition time. Item metadata is left unmodified.
        """
        # Check that s2_items are available
        if not self.s2_items:
            logger.info("No Sentinel-2 items to create quality information.")
            return None

        qi_dicts = []
        for s2_item in self.s2_items:
            # Copy the attributes: editing ``__dict__`` in place would delete
            # attributes from the metadata object itself, breaking later
            # access to ``observation_geometry`` and repeated calls.
            qi_dict = dict(vars(s2_item.metadata))
            # Class percentages are flattened into separate columns, so the
            # nested dict itself is dropped; observation geometry is not part
            # of the quality information.
            class_percentages = qi_dict.pop("class_percentages", None)
            qi_dict.pop("observation_geometry", None)
            # Items without class percentages get NaN in those columns.
            qi_dict.update(class_percentages or {})
            qi_dicts.append(qi_dict)

        # Make qi dataframe
        df_qi = pd.DataFrame(qi_dicts)
        # Make acquisition time utc aware
        df_qi["acquisition_time"] = pd.to_datetime(df_qi.acquisition_time, utc=True)
        # Sort chronologically and set acquisition time as index
        self.quality_information = df_qi.sort_values("acquisition_time").set_index(
            "acquisition_time"
        )

    def filter_s2_items_by_tile(self, tileid: Optional[str] = None):
        """Filter S2 items by tile.

        Parameters:
        ----------------
        tileid : Optional[str]
            Tile ID to keep. If None, ``self.aoi.tile`` is used; when that is
            also None it is first set to the most common tile of the items.
        """
        if not self.s2_items:
            logger.debug("No Sentinel-2 items to filter by tile.")
            return None
        if tileid is None:
            if self.aoi.tile is None:
                all_tiles = [item.metadata.tileid for item in self.s2_items]
                self.aoi.tile = most_common(all_tiles)
            tileid = self.aoi.tile
        self.s2_items = [
            item for item in self.s2_items if item.metadata.tileid == tileid
        ]

    def filter_s2_items_by_quality_information(
        self, qi_threshold: float = 0, qi_filter: List[SCLClass] = S2_FILTER1
    ):
        """Filter S2 items based on quality information.

        Does nothing if quality information has not been created.

        Parameters:
        ----------------
        qi_threshold : float
            Quality indicator threshold.
        qi_filter : List[SCLClass]
            Quality indicator filter.
        """
        if self.quality_information is None:
            logger.debug("Quality information not available for filtering.")
            return None

        filtered_qi = filter_s2_qi_dataframe(
            self.quality_information, qi_threshold, qi_filter
        )
        # IDs for images passing the quality filter (set for O(1) lookups)
        passing_assetids = set(filtered_qi["assetid"])
        # Remove items that do not pass the quality filter
        self.s2_items = [
            item
            for item in self.s2_items
            if item.metadata.assetid in passing_assetids
        ]

    def filter_s2_items(
        self, qi_threshold: float = 0, qi_filter: List[SCLClass] = S2_FILTER1
    ):
        """Filter S2 items by quality information, then by tile.

        Parameters:
        ----------------
        qi_threshold : float
            Quality indicator threshold.
        qi_filter : List[SCLClass]
            Quality indicator filter.
        """
        if not self.s2_items:
            logger.info("No Sentinel-2 items to filter.")
            return None
        # Filter by quality information
        self.filter_s2_items_by_quality_information(qi_threshold, qi_filter)
        if not self.s2_items:
            logger.info("No data passing the quality filter.")
            return None
        # Filter to the AOI tile, or to the most common tile if none is set
        self.filter_s2_items_by_tile()

    def sort_s2_items(self):
        """Sort S2 items by acquisition time (no-op when there are none)."""
        if not self.s2_items:
            return None
        self.s2_items = sorted(self.s2_items, key=lambda x: x.metadata.acquisition_time)

    def data_to_xarray(self):
        """Convert data to xarray dataset.

        Sorts the items by acquisition time, stacks their arrays along a new
        ``time`` axis and stores an ``xr.Dataset`` in ``self.xr_dataset`` with:

        * ``band_data`` (time, band, y, x) if non-SCL bands were requested,
        * ``SCL`` (time, y, x), cast to int16, if SCL was requested,
        * per-time metadata variables (asset/product IDs, data source and,
          when any item has it, observation geometry; items without geometry
          get NaN).

        Dataset attributes hold the AOI name and geometry, CRS, tile ID and
        the estimated number of pixels inside the AOI.
        """
        # Check that s2_items are available
        if not self.check_s2_items_exist():
            return None

        # Sort s2_items by acquisition time
        self.sort_s2_items()

        # 2D data
        bands = self.req_params.bands
        bands_str = [b.value for b in bands]
        spectral_bands = [b for b in bands if b != S2Band.SCL]
        geometry_vars = ("sun_azimuth", "sun_zenith", "view_azimuth", "view_zenith")

        dataset_dict = {}
        multiband_arrays = []
        scl_arrays = []
        acquisition_times = []
        aoi_pixels = None
        all_metadata = {
            "assetid": [],
            "productid": [],
            "datasource": [],
            **{var: [] for var in geometry_vars},
        }

        for s2_item in self.s2_items:
            # Handle spectral data
            if spectral_bands:
                # NOTE(review): this stacks *all* requested bands (SCL
                # included), not only ``spectral_bands``; the ``band``
                # coordinate below is built from the same list, so the two
                # stay consistent. Confirm whether SCL belongs in band_data.
                multiband_arrays.append(
                    np.stack([s2_item.data[band] for band in bands])
                )
            if S2Band.SCL in bands:
                scl_arrays.append(s2_item.data[S2Band.SCL])

            metadata = s2_item.metadata
            acquisition_times.append(np.datetime64(metadata.acquisition_time))
            all_metadata["assetid"].append(metadata.assetid)
            all_metadata["productid"].append(metadata.productid)
            all_metadata["datasource"].append(metadata.datasource.value)
            # Pad with NaN when an item lacks observation geometry so every
            # metadata variable stays aligned with the time axis.
            geometry = metadata.observation_geometry
            for var in geometry_vars:
                all_metadata[var].append(
                    getattr(geometry, var) if geometry else np.nan
                )

        # Omit the geometry variables when no item has observation geometry.
        if not any(item.metadata.observation_geometry for item in self.s2_items):
            for var in geometry_vars:
                del all_metadata[var]

        if multiband_arrays:
            multitemporal_array = np.stack(multiband_arrays)
            dataset_dict["band_data"] = (
                ["time", "band", "y", "x"],
                multitemporal_array,
            )
            # There might be nans within the aoi if data was not available, so
            # take the most common number of NaN pixels per band (first
            # image) as the number of pixels outside the aoi.
            nans_per_band = np.isnan(multitemporal_array[0, :, :, :])
            sums = np.sum(nans_per_band, axis=(1, 2))  # sums over x and y
            num_of_outside_aoi_pixels = np.argmax(np.bincount(sums))
            total_num_of_pixels = np.size(multitemporal_array[0, 0, :, :])
            aoi_pixels = total_num_of_pixels - num_of_outside_aoi_pixels

        if scl_arrays:
            multitemporal_scl_array = np.stack(scl_arrays).astype(np.int16)
            dataset_dict[S2Band.SCL.value] = (
                ["time", "y", "x"],
                multitemporal_scl_array,
            )
            # count aoi pixels if not already counted with spectral data
            if not aoi_pixels:
                first_scl = multitemporal_scl_array[0, :, :]
                aoi_pixels = np.size(first_scl) - np.sum(first_scl == SCL_NODATA)

        # Add metadata to dataset (variables with no values are skipped)
        dataset_dict.update(
            {
                var: (["time"], values)
                for var, values in all_metadata.items()
                if values
            }
        )

        # crs from projection
        first_item = self.s2_items[0]
        crs = first_item.metadata.projection
        tileid = first_item.metadata.tileid
        first_band_coords = first_item.coordinates[bands[0]]

        coords = {
            "time": acquisition_times,
            "band": bands_str,
            "y": first_band_coords.y.astype(np.float64),
            "x": first_band_coords.x.astype(np.float64),
        }

        ds = xr.Dataset(
            dataset_dict,
            coords=coords,
            attrs={
                "name": self.aoi.name,
                "crs": crs,
                "tile_id": tileid,
                "aoi_geometry": self.aoi.geometry.wkt,
                "aoi_pixels": aoi_pixels,
            },
        )
        self.xr_dataset = ds
def filter_s2_qi_dataframe(
    s2_qi_dataframe: pd.DataFrame,
    qi_thresh: float,
    s2_filter: List[SCLClass] = S2_FILTER1,
) -> pd.DataFrame:
    """Filter qi dataframe by SCL class percentages.

    Rows missing a value in any SCL class column (``S2_SCL_CLASSES``) are
    dropped first. Of the remaining rows, those whose summed value over the
    ``s2_filter`` class columns is at most ``qi_thresh`` are kept.

    Parameters:
    ----------------
    s2_qi_dataframe : pd.DataFrame
        Quality information dataframe.
    qi_thresh : float
        Quality indicator threshold.
    s2_filter : List[SCLClass]
        Quality indicator filter.

    Returns:
    ----------------
    pd.DataFrame
        Filtered quality information dataframe.
    """
    complete_rows = s2_qi_dataframe.dropna(axis=0, subset=S2_SCL_CLASSES)
    filter_columns = [scl_class.name for scl_class in s2_filter]
    within_threshold = complete_rows[filter_columns].sum(axis=1) <= qi_thresh
    return complete_rows.loc[within_threshold]
def most_common(lst: list[str]) -> str:
    """Return the most frequent element of ``lst``.

    Ties are broken in favour of the element that appears first in ``lst``,
    so the result is deterministic (unlike ``max`` over a ``set``, whose
    iteration order for strings depends on hash randomisation). Runs in O(n).

    Raises:
    ----------------
    ValueError
        If ``lst`` is empty.
    """
    if not lst:
        raise ValueError("most_common() arg is an empty sequence")
    return Counter(lst).most_common(1)[0][0]