#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Classes for handling Sentinel-2 data.
@author: Olli Nevalainen (Finnish Meteorological Institute)
"""
import logging
from dataclasses import dataclass
from typing import Dict, List, Optional
try:
    # enum.StrEnum was added in Python 3.11; older interpreters use the
    # backport defined below.
    from enum import Enum, StrEnum
except ImportError:
    from enum import Enum

    class StrEnum(str, Enum):
        """Backport of ``enum.StrEnum`` for Python < 3.11.

        A bare ``(str, Enum)`` mixin returns ``"ClassName.MEMBER"`` from
        ``str()``, whereas the stdlib ``StrEnum`` returns the member value.
        Reusing ``str``'s own ``__str__``/``__format__`` makes both code
        paths produce the same strings.
        """

        __str__ = str.__str__
        __format__ = str.__format__
import numpy as np
import pandas as pd
import xarray as xr
from satellitetools.common.classes import AOI, DataSource
# Module-level logger.
logger = logging.getLogger(__name__)

# No-data value for spectral bands. NaN pixels are what data_to_xarray counts
# as outside the AOI / unavailable.
SPECTRAL_BAND_NO_DATA = np.nan
# Use GEE band naming
[docs]
class S2Band(StrEnum):
"""Sentinel-2 bands."""
[docs]
def to_aws(self) -> str:
"""Convert band name to AWS band name.
Returns:
----------------
str
Band name in AWS.
"""
band_name_in_aws = S2_BANDS_GEE_TO_AWS[self.value]
return band_name_in_aws
[docs]
def to_gee(self) -> str:
"""Convert band name to GEE band name.
Returns:
----------------
str
Band name in GEE.
"""
return self.value
@classmethod
[docs]
def get_10m_to_20m_bands(cls) -> List["S2Band"]:
"""Get 10-20 m bands for the band.
Returns:
----------------
List[S2Band]
List of 10-20 m bands.
"""
return [S2Band(b) for b in S2_BANDS_10_20_GEE]
@classmethod
[docs]
def get_all_bands(cls) -> List["S2Band"]:
"""Get all bands for the band.
Returns:
----------------
List[S2Band]
List of all bands.
"""
return [b for b in S2Band]
class SCLClass(Enum):
    """Sentinel-2 Scene Classification Layer (SCL) classes.

    Values are the integer class codes of the L2A SCL band.
    """

    # NOTE(review): codes follow the ESA L2A SCL table. Member names are also
    # used as quality-information column names (see S2_SCL_CLASSES), so
    # confirm they match whatever builds that dataframe.
    NODATA = 0
    SATURATED_DEFECTIVE = 1
    DARK_FEATURE_SHADOW = 2
    CLOUD_SHADOW = 3
    VEGETATION = 4
    NOT_VEGETATED = 5
    WATER = 6
    UNCLASSIFIED = 7
    CLOUD_MEDIUM_PROBA = 8
    CLOUD_HIGH_PROBA = 9
    THIN_CIRRUS = 10
    SNOW_ICE = 11
# All bands in GEE naming: spectral bands B1-B12 plus the AOT, WVP and SCL
# layers. Order must match S2_BANDS_COG (they are zipped into lookup tables).
S2_BANDS_GEE = [
    "B1",
    "B2",
    "B3",
    "B4",
    "B5",
    "B6",
    "B7",
    "B8",
    "B8A",
    "B9",
    "B11",
    "B12",
    "AOT",
    "WVP",
    "SCL",
]

# 10 m and 20 m spectral bands in GEE naming (B1, B9 and the auxiliary
# layers are excluded).
S2_BANDS_10_20_GEE = [
    "B2",
    "B3",
    "B4",
    "B5",
    "B6",
    "B7",
    "B8",
    "B8A",
    "B11",
    "B12",
]
# Asset names used by the old Earth Search v0 catalog (kept for reference)
# S2_BANDS_COG = [
# "B01",
# "B02",
# "B03",
# "B04",
# "B05",
# "B06",
# "B07",
# "B08",
# "B8A",
# "B09",
# "B11",
# "B12",
# ]
# # 10-20 m bands
# S2_BANDS_10_20_COG = ["B02", "B03", "B04", "B05", "B06", "B07", "B8A", "B11", "B12"]
# All bands in AWS (COG asset) naming, in the same order as S2_BANDS_GEE.
S2_BANDS_COG = [
    "coastal",  # B1
    "blue",  # B2
    "green",  # B3
    "red",  # B4
    "rededge1",  # B5
    "rededge2",  # B6
    "rededge3",  # B7
    "nir",  # B8
    "nir08",  # B8A
    "nir09",  # B9
    "swir16",  # B11
    "swir22",  # B12
    "aot",
    "wvp",
    "scl",
]

# 10 m and 20 m bands in AWS naming. Must list the same bands, in the same
# order, as S2_BANDS_10_20_GEE (B8 / "nir" was previously missing here, which
# shifted every band after B7 out of alignment with the GEE list).
S2_BANDS_10_20_COG = [
    "blue",  # B2
    "green",  # B3
    "red",  # B4
    "rededge1",  # B5
    "rededge2",  # B6
    "rededge3",  # B7
    "nir",  # B8
    "nir08",  # B8A
    "swir16",  # B11
    "swir22",  # B12
]

# Lookup tables between AWS and GEE band names. strict=True makes a length
# mismatch between the two lists fail at import time.
S2_BANDS_AWS_TO_GEE = dict(zip(S2_BANDS_COG, S2_BANDS_GEE, strict=True))
# S2_BANDS_10_20_AWS_TO_GEE = {
#     aws: gee for aws, gee in zip(S2_BANDS_10_20_COG, S2_BANDS_10_20_GEE)
# }
S2_BANDS_GEE_TO_AWS = dict(zip(S2_BANDS_GEE, S2_BANDS_COG, strict=True))
# S2_BANDS_10_20_GEE_TO_AWS = {
#     gee: aws for gee, aws in zip(S2_BANDS_10_20_GEE, S2_BANDS_10_20_COG)
# }
# SCL class names; filter_s2_qi_dataframe uses them as the quality-information
# column names.
S2_SCL_CLASSES = [c.name for c in SCLClass]

# Default quality filter: SCL classes whose values are summed and compared
# against the QI threshold in filter_s2_qi_dataframe.
S2_FILTER1 = [
    SCLClass.NODATA,
    SCLClass.SATURATED_DEFECTIVE,
    SCLClass.CLOUD_SHADOW,
    SCLClass.UNCLASSIFIED,
    SCLClass.CLOUD_MEDIUM_PROBA,
    SCLClass.CLOUD_HIGH_PROBA,
    SCLClass.THIN_CIRRUS,
    SCLClass.SNOW_ICE,
]

# Same as S2_FILTER1 except that UNCLASSIFIED pixels are not counted.
S2_FILTER2 = [c for c in S2_FILTER1 if c is not SCLClass.UNCLASSIFIED]
class Sentinel2RequestParams:
    """S2 data request parameters.

    Attributes:
    ----------------
    datestart : str
        Start date of the data request.
    dateend : str
        End date of the data request.
    datasource : DataSource
        Data source for the request.
    bands : List[S2Band]
        List of Sentinel-2 bands to request.
    target_gsd : float
        Target ground sampling distance (GSD) in meters.
    qi_evaluation_scale : float
        Quality indicator evaluation scale.
    """

    def __init__(
        self,
        datestart: str,
        dateend: str,
        datasource: DataSource,
        bands: Optional[List[S2Band]] = None,
        target_gsd: Optional[float] = 20,
        qi_evaluation_scale: Optional[float] = 20,
    ):
        """Initialize the Sentinel2RequestParams class.

        Parameters:
        ----------------
        datestart : str
            Start date of the data request.
        dateend : str
            End date of the data request.
        datasource : DataSource
            Data source for the request.
        bands : List[S2Band], optional
            Bands to request, as S2Band members or their string values.
            None or an empty list requests every S2Band.
        target_gsd : float, optional
            Target ground sampling distance (GSD) in meters; None means 20.
        qi_evaluation_scale : float, optional
            Quality indicator evaluation scale; None means 20.
        """
        self.datestart = datestart
        # Previously never assigned although __repr__ reads it.
        self.dateend = dateend
        self.datasource = datasource
        if bands:
            # Validate input and convert to S2Band if strings
            self.bands = [S2Band(band) for band in bands]
        else:
            self.bands = list(S2Band)
        # An explicit None falls back to the 20 m default.
        self.target_gsd = 20 if target_gsd is None else target_gsd
        self.qi_evaluation_scale = (
            20 if qi_evaluation_scale is None else qi_evaluation_scale
        )

    def __repr__(self) -> str:
        return (
            f"RequestParams(datestart={self.datestart}, dateend={self.dateend}, "
            f"datasource={self.datasource}, bands={self.bands}, "
            f"target_gsd={self.target_gsd}, qi_evaluation_scale={self.qi_evaluation_scale})"
        )
@dataclass
class Sentinel2ObservationGeometry:
    """Sentinel-2 observation geometry class.

    Attributes:
    ----------------
    sun_azimuth : float
        Sun azimuth angle.
    sun_zenith : float
        Sun zenith angle.
    view_azimuth : float
        View azimuth angle.
    view_zenith : float
        View zenith angle.
    """

    sun_azimuth: float
    sun_zenith: float
    view_azimuth: float
    view_zenith: float
@dataclass
class Coordinates:
    """Coordinate vectors along the x and y axes of a band's raster grid.

    Attributes:
    ----------------
    x : np.ndarray
        X coordinates.
    y : np.ndarray
        Y coordinates.
    """

    # Stored as given. data_to_xarray calls ``.astype`` on these, so numpy
    # arrays are expected in practice.
    x: np.ndarray
    y: np.ndarray
class Sentinel2Item:
    """Sentinel-2 data item class.

    Attributes:
    ----------------
    metadata : Sentinel2Metadata
        Metadata of the data item.
    data : Dict[S2Band, np.ndarray]
        Data of the data item (empty dict if none was given).
    coordinates : Dict[S2Band, Coordinates]
        Coordinates of the data item, per band.
    """

    def __init__(
        self,
        metadata: "Sentinel2Metadata",
        data: Optional[Dict[S2Band, np.ndarray]] = None,
    ):
        """Initialize the Sentinel2Item class.

        Parameters:
        ----------------
        metadata : Sentinel2Metadata
            Metadata of the data item.
        data : Optional[Dict[S2Band, np.ndarray]]
            Data of the data item.
        """
        # Previously never assigned although __repr__ and the collection read it.
        self.metadata = metadata
        self.data = data if data else {}
        self.coordinates: Dict[S2Band, Coordinates] = {}

    def __repr__(self):
        return (
            f"Sentinel2Item(productid={self.metadata.productid}, "
            f"time={self.metadata.acquisition_time},"
            f"data={self.data.keys()})"
        )

    def set_coordinates(self, band: S2Band, xs: list, ys: list):
        """Set coordinates for the data item.

        Parameters:
        ----------------
        band : S2Band
            Sentinel-2 band.
        xs : list
            X coordinates.
        ys : list
            Y coordinates.
        """
        self.coordinates[band] = Coordinates(x=xs, y=ys)
class Sentinel2DataCollection:
    """Sentinel-2 data collection class.

    Attributes:
    ----------------
    aoi : AOI
        Area of interest.
    req_params : Sentinel2RequestParams
        Request parameters.
    quality_information : pd.DataFrame
        Quality information (None until populated).
    xr_dataset : xr.Dataset
        Xarray dataset built by ``data_to_xarray`` (None until then).
    s2_items : List[Sentinel2Item]
        Sentinel-2 items. None means no search has been made yet; an empty
        list means nothing was found or everything was filtered out.
    """

    def __init__(
        self,
        aoi: AOI,
        req_params: Sentinel2RequestParams,
    ):
        """Initialize the Sentinel2DataCollection class.

        Parameters:
        ----------------
        aoi : AOI
            Area of interest.
        req_params : Sentinel2RequestParams
            Request parameters.
        """
        # Previously never assigned although __repr__ and the filters read it.
        self.aoi = aoi
        self.req_params = req_params
        # NOTE(review): nothing in this class assigns this; presumably the
        # quality-information step fills it in - confirm.
        self.quality_information: Optional[pd.DataFrame] = None
        self.xr_dataset: Optional[xr.Dataset] = None
        self.s2_items: Optional[List[Sentinel2Item]] = None

    # Print information about the data collection
    def __repr__(self):
        return (
            f"Sentinel2DataCollection(AOI.name={self.aoi.name}, "
            f"Sentinel2RequestParams={self.req_params})"
        )

    def check_s2_items_exist(self) -> bool:
        """Return True if at least one item is available.

        Logs the reason at INFO level and returns False when items have not
        been searched yet (None) or the list is empty.
        """
        if self.s2_items is None:
            logger.info("No Sentinel-2 items searched (s2_items=None).")
            return False
        elif len(self.s2_items) == 0:
            logger.info(
                "No Sentinel-2 items found for the time period or "
                "left after filtering (s2_items=[])."
            )
            return False
        return True

    def filter_s2_items_by_tile(self, tileid: Optional[str] = None):
        """Keep only the items from a single Sentinel-2 tile.

        Parameters:
        ----------------
        tileid : str, optional
            Tile ID to keep. If None, ``self.aoi.tile`` is used. When that is
            also None, it is set to the most common tile among the items.
        """
        if not self.s2_items:
            logger.debug("No Sentinel-2 items to filter by tile.")
            return None
        if tileid is None:
            if self.aoi.tile is None:
                all_tiles = [item.metadata.tileid for item in self.s2_items]
                self.aoi.tile = most_common(all_tiles)
            tileid = self.aoi.tile
        self.s2_items = [
            item for item in self.s2_items if item.metadata.tileid == tileid
        ]

    def filter_s2_items(
        self,
        qi_threshold: float = 0,
        qi_filter: Optional[List[SCLClass]] = None,
    ):
        """Filter S2 items based on quality information, then by tile.

        Parameters:
        ----------------
        qi_threshold : float
            Quality indicator threshold.
        qi_filter : List[SCLClass], optional
            SCL classes counted as bad pixels; defaults to S2_FILTER1.
        """
        if qi_filter is None:
            qi_filter = S2_FILTER1
        if not self.s2_items:
            logger.info("No Sentinel-2 items to filter.")
            return None
        # Filter by quality information.
        # NOTE(review): filter_s2_items_by_quality_information is not defined
        # in this class as shown; confirm where it is provided.
        self.filter_s2_items_by_quality_information(qi_threshold, qi_filter)
        if not self.s2_items:
            logger.info("No data passing the quality filter.")
            return None
        # Keep a single tile: aoi.tile if set, otherwise the most common one.
        self.filter_s2_items_by_tile()

    def sort_s2_items(self):
        """Sort S2 items by acquisition time (no-op when there are none)."""
        if self.s2_items:
            self.s2_items = sorted(
                self.s2_items, key=lambda item: item.metadata.acquisition_time
            )

    def data_to_xarray(self):
        """Convert the items' data to an xarray dataset in ``self.xr_dataset``.

        Spectral bands go into ``band_data`` with dims (time, band, y, x).
        The SCL band, if requested, goes into its own int16 ``SCL`` variable
        with dims (time, y, x). Per-item metadata becomes 1-D variables along
        ``time``. Returns None without doing anything when there are no items.
        """
        # Check that s2_items are available
        if not self.check_s2_items_exist():
            return None
        # Sort s2_items by acquisition time
        self.sort_s2_items()

        bands = self.req_params.bands
        # SCL is categorical and stored separately, so keep it out of the
        # spectral stack (it used to be stacked into band_data as well).
        spectral_bands = [b for b in bands if b != S2Band.SCL]
        include_scl = S2Band.SCL in bands
        geometry_vars = ("sun_azimuth", "sun_zenith", "view_azimuth", "view_zenith")

        multiband_arrays = []
        scl_arrays = []
        acquisition_times = []
        all_metadata = {"assetid": [], "productid": [], "datasource": []}
        geometry_metadata = {var: [] for var in geometry_vars}
        has_geometry = False

        for s2_item in self.s2_items:
            metadata = s2_item.metadata
            if spectral_bands:
                multiband_arrays.append(
                    np.stack([s2_item.data[band] for band in spectral_bands])
                )
            if include_scl:
                scl_arrays.append(s2_item.data[S2Band.SCL])
            acquisition_times.append(np.datetime64(metadata.acquisition_time))
            all_metadata["assetid"].append(metadata.assetid)
            all_metadata["productid"].append(metadata.productid)
            all_metadata["datasource"].append(metadata.datasource.value)
            geometry = metadata.observation_geometry
            if geometry:
                has_geometry = True
            # A NaN placeholder keeps every geometry variable the same length
            # as "time" even if only some items carry geometry.
            for var in geometry_vars:
                geometry_metadata[var].append(
                    getattr(geometry, var) if geometry else np.nan
                )
        if has_geometry:
            all_metadata.update(geometry_metadata)

        dataset_dict = {}
        aoi_pixels = None
        if multiband_arrays:
            multitemporal_array = np.stack(multiband_arrays)
            dataset_dict["band_data"] = (
                ["time", "band", "y", "x"],
                multitemporal_array,
            )
            # Pixels outside the AOI are NaN. Data may also be missing inside
            # the AOI for some bands, so take the most common per-band NaN
            # count of the first image as the outside-AOI pixel count.
            nans_per_band = np.isnan(multitemporal_array[0])
            nan_counts = np.sum(nans_per_band, axis=(1, 2))  # over y and x
            num_of_outside_aoi_pixels = np.argmax(np.bincount(nan_counts))
            total_num_of_pixels = np.size(multitemporal_array[0, 0])
            aoi_pixels = total_num_of_pixels - num_of_outside_aoi_pixels
        if scl_arrays:
            multitemporal_scl_array = np.stack(scl_arrays).astype(np.int16)
            dataset_dict[S2Band.SCL.value] = (
                ["time", "y", "x"],
                multitemporal_scl_array,
            )
            # Count AOI pixels from SCL only if spectral data did not do it
            # (0 is a valid count, so test for None rather than falsiness).
            if aoi_pixels is None:
                first_scl = multitemporal_scl_array[0]
                aoi_pixels = np.size(first_scl) - np.sum(
                    first_scl == SCLClass.NODATA.value
                )

        # Add metadata to dataset
        dataset_dict.update(
            {
                var: (["time"], values)
                for var, values in all_metadata.items()
                if values
            }
        )

        first_item = self.s2_items[0]
        first_coords = first_item.coordinates[bands[0]]
        coords = {
            "time": acquisition_times,
            "band": [b.value for b in spectral_bands],
            "y": first_coords.y.astype(np.float64),
            "x": first_coords.x.astype(np.float64),
        }
        self.xr_dataset = xr.Dataset(
            dataset_dict,
            coords=coords,
            attrs={
                "name": self.aoi.name,
                # crs from projection
                "crs": first_item.metadata.projection,
                "tile_id": first_item.metadata.tileid,
                "aoi_geometry": self.aoi.geometry.wkt,
                "aoi_pixels": aoi_pixels,
            },
        )
def filter_s2_qi_dataframe(
    s2_qi_dataframe: pd.DataFrame,
    qi_thresh: float,
    s2_filter: Optional[List[SCLClass]] = None,
) -> pd.DataFrame:
    """Filter qi dataframe.

    Rows missing any SCL class column value are dropped first. The remaining
    rows are kept when the sum of the ``s2_filter`` class columns is at most
    ``qi_thresh``.

    Parameters:
    ----------------
    s2_qi_dataframe : pd.DataFrame
        Quality information dataframe with one column per SCL class name.
    qi_thresh : float
        Quality indicator threshold.
    s2_filter : List[SCLClass], optional
        SCL classes counted as bad; defaults to S2_FILTER1.

    Returns:
    ----------------
    pd.DataFrame
        Filtered quality information dataframe.
    """
    if s2_filter is None:
        s2_filter = S2_FILTER1
    # Drop if SCL data is nan
    with_scl = s2_qi_dataframe.dropna(axis=0, subset=S2_SCL_CLASSES)
    filter_columns = [scl.name for scl in s2_filter]
    bad_totals = with_scl[filter_columns].sum(axis=1)
    return with_scl.loc[bad_totals <= qi_thresh]
def most_common(lst: list[str]) -> str:
    """Return the most frequent element of ``lst``.

    Ties go to the element that occurs first in ``lst``. The previous
    ``max(set(lst), ...)`` depended on set iteration order, which for strings
    varies between processes because of hash randomization. This version is
    also O(n) instead of O(n * unique).

    Raises:
    ----------------
    ValueError
        If ``lst`` is empty.
    """
    from collections import Counter

    if not lst:
        raise ValueError("most_common() arg is an empty sequence")
    # Counter.most_common orders equal counts by first occurrence.
    return Counter(lst).most_common(1)[0][0]