Source code for geoxarray.accessor

#!/usr/bin/env python
# Copyright geoxarray Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Xarray extensions via accessor objects.

The functionality in this module can be accessed via the ``.geo`` accessor on
any xarray DataArray or Dataset object.

Geolocation cases that these accessors are supposed to be able to handle:

1. CF compliant Dataset: A :class:`~xarray.Dataset` object with one or
   more data variables and one CRS specification variable. By default
   a 'grid_mapping' attribute is used to specify the name of the variable.

   https://cfconventions.org/Data/cf-conventions/cf-conventions-1.7/build/ch05s06.html
2. Geotiff DataArrays: A :class:`~xarray.DataArray` object returned by either
   the ``rioxarray`` library or by :func:`xarray.open_rasterio`.
3. Raw lon/lat coordinate arrays: A :class:`~xarray.DataArray` or :class:`~xarray.Dataset`
   object that contains 1D or 2D longitude and latitude arrays defining the
   coordinates of the data.

These accessors attempt to provide standard interfaces to the following information
from these different data cases:

1. Standard dimensions: X (columns), Y (rows), Vertical, and Time.
2. Coordinate Reference System (CRS)

Lastly, these accessors provide basic wrappers around other tools that are
typically used with geospatial data (resampling, plotting, etc.) or for converting
to other formats (CF compatible NetCDF file).

"""

from __future__ import annotations

import warnings
from typing import Any, Generic, Literal, TypeVar

import numpy as np
import xarray
import xarray as xr
from pyproj import CRS
from pyproj.exceptions import CRSError

from .coords import spatial_coords

try:
    from pyresample.geometry import AreaDefinition, SwathDefinition

    has_pyresample = True
except ImportError:
    AreaDefinition = SwathDefinition = None
    has_pyresample = False


try:
    from rasterio.crs import CRS as RioCRS

    has_rasterio = True
except ImportError:
    RioCRS = None
    has_rasterio = False


# Name of the grid mapping coordinate variable used when no grid mapping name
# is given and none is already defined on the object (used by write_crs,
# write_gcps, and the gcps property as the fallback name).
DEFAULT_GRID_MAPPING_VARIABLE_NAME = "spatial_ref"
# Type variable for the wrapped xarray object, constrained to DataArray or
# Dataset so accessor methods (e.g. _get_obj, write_dims) can be typed as
# returning the same kind of object they were called on.
XarrayObject = TypeVar("XarrayObject", xr.DataArray, xr.Dataset)


class _SharedGeoAccessor(Generic[XarrayObject]):
    """Accessor functionality shared between Dataset and DataArray objects.

    Subclasses are expected to implement :meth:`set_dims`; it is invoked
    lazily the first time :attr:`dim_map` is accessed.

    """

    def __init__(self, xarray_obj: XarrayObject) -> None:
        """Set handle for xarray object."""
        self._obj = xarray_obj
        # ``None``: CRS not searched for yet; ``False``: searched, none found.
        self._crs: CRS | Literal[False] | None = None
        self._x_dim = None
        self._y_dim = None
        self._vertical_dim = None
        self._time_dim = None
        # ``False``: ``set_dims`` has never run; ``None``: rebuild the mapping
        # from the ``_*_dim`` attributes on next access; otherwise a dict.
        self._dim_map = False

    def _get_obj(self, inplace: bool) -> XarrayObject:
        """Get the object to modify.

        Parameters
        ----------
        inplace
            If True, return the wrapped xarray object itself (no copy).
            Otherwise return a deep copy whose ``.geo`` accessor carries over
            the CRS and dimension names already known by this accessor.

        Returns
        -------
        :obj:`xarray.Dataset` | :obj:`xarray.DataArray`

        """
        if inplace:
            return self._obj
        obj_copy = self._obj.copy(deep=True)
        # preserve attribute information
        obj_copy.geo._crs = self._crs
        obj_copy.geo._x_dim = self._x_dim
        obj_copy.geo._y_dim = self._y_dim
        obj_copy.geo._vertical_dim = self._vertical_dim
        obj_copy.geo._time_dim = self._time_dim
        return obj_copy

    @property
    def dim_map(self) -> dict:
        """Map current data dimension to geoxarray preferred dim name.

        Keys are dimension names as they exist in the xarray object; values are
        one of ``"x"``, ``"y"``, ``"vertical"``, or ``"time"``. Only dimensions
        geoxarray has identified are included.

        """
        if self._dim_map is False:
            # we haven't determined dimensions yet
            self.set_dims(inplace=True)

        if self._dim_map is None:
            # (re)build the mapping from the individually stored dim names
            dim_map = {}
            for curr_dim, gx_dim in (
                (self._x_dim, "x"),
                (self._y_dim, "y"),
                (self._vertical_dim, "vertical"),
                (self._time_dim, "time"),
            ):
                if curr_dim is not None:
                    dim_map[curr_dim] = gx_dim
            self._dim_map = dim_map

        return self._dim_map

    @property
    def _geo_dim_map(self) -> dict:
        """Map geoxarray preferred dim name to current data dimension name."""
        return {gx_dim: curr_dim for curr_dim, gx_dim in self.dim_map.items()}

    def set_dims(self, *args, **kwargs) -> None:
        """Tell geoxarray the names of the provided dimensions in this Xarray object."""
        raise NotImplementedError()

    @property
    def dims(self) -> tuple:
        """Get preferred dimension names in order.

        Dimensions geoxarray has not identified keep their original name.

        """
        return tuple(self.dim_map.get(dname, dname) for dname in self._obj.dims)

    @property
    def sizes(self):
        """Get size map with preferred dimension names.

        The result uses the same mapping class as the xarray object's own
        ``.sizes``.

        """
        dim_map = self.dim_map
        sizes_dict = {dim_map.get(dname, dname): size for dname, size in self._obj.sizes.items()}
        # return the same type of object as xarray
        return self._obj.sizes.__class__(sizes_dict)

    def write_dims(self) -> XarrayObject:
        """Rename object's dimensions to match geoxarray's preferred dimension names.

        This is a simple wrapper around Xarray's :meth:`xarray.DataArray.rename`
        or :meth:`xarray.Dataset.rename` methods along with ``.geo.dim_map`` to
        rename the dimension names. These methods always produce copies of the
        original object. It is not possible to do this operation "inplace".

        """
        obj_copy = self._get_obj(inplace=False)
        return obj_copy.rename(self.dim_map)

    @property
    def crs(self) -> None | CRS:
        """Coordinate Reference System (CRS) of this object, or ``None`` if unknown.

        Sources are tried in order: the grid mapping coordinate variable (WKT
        attributes, then CF attributes), then an ``area`` attribute (as used
        by pyresample) that exposes a ``crs``. The result of the search,
        including a failed search, is cached on this accessor.

        """
        if self._crs is False:
            # we've tried to find the CRS, there isn't one
            return None
        if self._crs is not None:
            # we've already determined what the CRS is, return it
            return self._crs

        for crs_method in (self._get_crs_from_grid_mapping, self._get_crs_from_pyresample):
            crs = crs_method()
            if crs is not None:
                self._crs = crs
                return crs
        self._crs = False
        return None

    def _get_crs_from_grid_mapping(self):
        """Build a CRS from the grid mapping coordinate variable, if one exists.

        The WKT attributes ``spatial_ref`` (GDAL) and ``crs_wkt`` (CF) are
        checked first, in that order; otherwise the variable's CF grid mapping
        attributes are parsed. Returns ``None`` when there is no usable grid
        mapping coordinate variable.

        """
        gm_var = self._get_gm_var()
        if gm_var is None:
            return None
        for crs_attr in ("spatial_ref", "crs_wkt"):
            if crs_attr in gm_var.attrs:
                return CRS.from_wkt(gm_var.attrs[crs_attr])
        return self._get_crs_from_cf()

    def _get_gm_var(self):
        """Return the grid mapping coordinate variable, or ``None``.

        Warns when a grid mapping name is defined but no coordinate with that
        name exists.

        """
        gm_var_name = self.grid_mapping
        if gm_var_name is None:
            return None
        if gm_var_name not in self._obj.coords:
            warnings.warn(
                "'grid_mapping' attribute found, but the variable it refers to "
                f"{gm_var_name} is not a coordinate variable. "
                "Use 'data_arr.geo.set_cf_grid_mapping' to "
                "provide one. Will search other metadata for CRS information.",
                stacklevel=2,
            )
            return None
        return self._obj.coords[gm_var_name]

    def _get_crs_from_cf(self):
        """Parse CF grid mapping attributes with pyproj, or return ``None`` on failure."""
        try:
            return CRS.from_cf(self._obj.coords[self.grid_mapping or DEFAULT_GRID_MAPPING_VARIABLE_NAME].attrs)
        except (KeyError, CRSError):
            return None

    def _get_crs_from_pyresample(self):
        """Return ``attrs["area"].crs`` when present, otherwise ``None``."""
        area = self._obj.attrs.get("area")
        return getattr(area, "crs", None)

    def write_crs(
        self, new_crs_info: Any | None = None, grid_mapping_name: str | None = None, inplace: bool = False
    ) -> XarrayObject:
        """Write the CRS to the xarray object in a CF compliant manner.

        .. note::

            Much of this code is copied from the rioxarray project and is under the Apache 2.0 license.
            A copy of this license is available in the source file ``LICENSE_rioxarray``.

        Parameters
        ----------
        new_crs_info:
            Coordinate Reference System (CRS) information to write to the
            Xarray object. Can be a :class:`pyproj.CRS` object or anything
            understood by the :meth:`pyproj.CRS.from_user_input` method.
            If not provided, the ``.crs`` property will be used.
            If ``.crs`` returns ``None`` a ``RuntimeError`` is raised.
        grid_mapping_name:
            Name to use for the coordinate variable created and written by this
            method. The coordinate variable, also known as the grid mapping
            variable, will have this name when written to a NetCDF file.
            Defaults to the existing grid mapping name, or "spatial_ref".
        inplace:
            Whether to modify the current Xarray object inplace or to create
            a copy first. Default (``False``) is to make a copy.

        Returns
        -------
        The modified xarray object (the same object when ``inplace`` is True).
        The grid mapping coordinate variable is recreated, so any attributes
        it previously had are replaced by the new CRS attributes.

        """
        obj = self._get_obj(inplace)
        crs = self._optional_crs_from_input(new_crs_info, obj)
        grid_mapping_var_name = self.grid_mapping if grid_mapping_name is None else grid_mapping_name
        if grid_mapping_var_name is None:
            grid_mapping_var_name = DEFAULT_GRID_MAPPING_VARIABLE_NAME

        gm_attrs = crs.to_cf()
        crs_wkt = crs.to_wkt()
        gm_attrs["crs_wkt"] = crs_wkt  # CF compatibility
        gm_attrs["spatial_ref"] = crs_wkt  # GDAL support

        self._add_empty_grid_mapping(obj, grid_mapping_var_name)
        obj.coords[grid_mapping_var_name].attrs.update(gm_attrs)
        _assign_grid_mapping(obj, grid_mapping_var_name)
        return obj

    @staticmethod
    def _add_empty_grid_mapping(obj: XarrayObject, grid_mapping_var_name: str) -> None:
        """Create (or replace) a scalar int64 ``0`` coordinate named ``grid_mapping_var_name``.

        Any existing coordinate with that name, including its attributes, is replaced.

        """
        obj.coords[grid_mapping_var_name] = xr.Variable((), np.int64(0))

    def _optional_crs_from_input(self, new_crs_info: Any | None, obj: XarrayObject) -> CRS:
        """Resolve the CRS to write: user input if given, else this object's ``.crs``.

        When user input is given, the parsed CRS is also cached on ``obj.geo``.

        Raises
        ------
        RuntimeError
            If no input is given and no CRS can be found on this object.

        """
        if new_crs_info is None:
            crs = self.crs
            if crs is None:
                raise RuntimeError("No CRS information provided or found.")
            return crs
        crs = CRS.from_user_input(new_crs_info)
        obj.geo._crs = crs
        return crs

    @property
    def grid_mapping(self) -> str | None:
        """Name of a grid mapping variable associated with this xarray object.

        The object's own ``encoding`` is checked first, then its ``attrs``.
        For a Dataset with neither, the grid mapping names of its data
        variables are used if they all agree.

        .. note::

            Much of this code is copied from the rioxarray project and is under the Apache 2.0 license.
            A copy of this license is available in the source file ``LICENSE_rioxarray``.

        Returns
        -------
        Grid mapping variable name defined in the xarray object. If not found,
        None is returned.

        Raises
        ------
        RuntimeError
            If data variables of a Dataset reference different grid mapping variables.

        """
        gm_var_name = self._obj.encoding.get("grid_mapping") or self._obj.attrs.get("grid_mapping")
        if gm_var_name is not None:
            return gm_var_name
        if hasattr(self._obj, "data_vars"):
            var_grid_mappings = set(self._all_grid_mapping_names())
            if len(var_grid_mappings) > 1:
                raise RuntimeError("Multiple grid mapping variables exist.")
            if len(var_grid_mappings) == 1:
                return var_grid_mappings.pop()
        return None

    def _all_grid_mapping_names(self):
        """Yield the grid mapping name of every data variable that defines one."""
        for var_name in self._obj.data_vars:
            var_grid_mapping = _get_encoding_or_attr(self._obj[var_name], "grid_mapping")
            if var_grid_mapping is None:
                continue
            yield var_grid_mapping

    def write_spatial_coords(self) -> XarrayObject:
        """Write 'y' and 'x' coordinate arrays to ``.coords``.

        This operation always produces a new copy of the xarray object. See
        :meth:`xarray.DataArray.assign_coords` and
        :meth:`xarray.Dataset.assign_coords`.

        See :func:`geoxarray.coords.spatial_coords` for supported metadata
        structures.

        Coordinate arrays are currently always a single point per data pixel
        representing the center of the pixel.

        Raises
        ------
        KeyError
            If the x or y dimension of the object has not been identified.

        """
        # don't make an extra copy, assign_coords will do it for us
        obj = self._get_obj(inplace=True)
        geo_dim_map = obj.geo._geo_dim_map
        y_dim_name = geo_dim_map["y"]
        x_dim_name = geo_dim_map["x"]

        coords_dict = spatial_coords(obj)
        return obj.assign_coords(
            {
                y_dim_name: coords_dict["y"],
                x_dim_name: coords_dict["x"],
            }
        )

    @property
    def gcps(self) -> dict | str | None:
        """Get GeoJSON-formatted GCPs, if any, from the grid mapping coordinate variable."""
        grid_mapping_var_name = self.grid_mapping
        if grid_mapping_var_name is None:
            grid_mapping_var_name = DEFAULT_GRID_MAPPING_VARIABLE_NAME
        if grid_mapping_var_name not in self._obj.coords:
            return None
        return self._obj.coords[grid_mapping_var_name].attrs.get("gcps")

    def write_gcps(
        self, gcps: dict | str, grid_mapping_name: str | None = None, inplace: bool = False
    ) -> XarrayObject:
        """Write GeoJSON-formatted GCPs to the spatial ref.

        GCPs can be retrieved later from the ``obj.geo.gcps`` property.
        The GeoJSON will also be available from the grid mapping coordinate
        variable as an attribute named "gcps".

        More information on the GeoJSON format and examples can be found here:
        https://geojson.org/. GCP GeoJSON is almost always constructed from a
        series of "Point" features in a ``FeatureCollection``. A basic example::

            {'type': 'FeatureCollection', 'features': [
                {'type': 'Feature', 'properties': {'id': '1', 'info': '', 'row': 0.0, 'col': 0.0},
                 'geometry': {'type': 'Point', 'coordinates': [33.03, 61.80, 126.43]}},
                {'type': 'Feature', 'properties': {'id': '2', 'info': '', 'row': 0.0, 'col': 530.0},
                 'geometry': {'type': 'Point', 'coordinates': [32.64, 61.85, 126.43]}},
                {'type': 'Feature', 'properties': {'id': '3', 'info': '', 'row': 0.0, 'col': 1060.0},
                 'geometry': {'type': 'Point', 'coordinates': [32.25, 61.90, 126.43]}},
                {'type': 'Feature', 'properties': {'id': '4', 'info': '', 'row': 0.0, 'col': 1590.0},
                 'geometry': {'type': 'Point', 'coordinates': [31.86, 61.95, 126.43]}},
                {'type': 'Feature', 'properties': {'id': '5', 'info': '', 'row': 0.0, 'col': 2120.0},
                 'geometry': {'type': 'Point', 'coordinates': [31.47, 62.00, 126.43]}},
                {'type': 'Feature', 'properties': {'id': '6', 'info': '', 'row': 0.0, 'col': 2650.0},
                 'geometry': {'type': 'Point', 'coordinates': [31.08, 62.04, 126.43]}},
                {'type': 'Feature', 'properties': {'id': '7', 'info': '', 'row': 0.0, 'col': 3180.0},
                 'geometry': {'type': 'Point', 'coordinates': [30.68, 62.09, 126.43]}}]}

        Parameters
        ----------
        gcps:
            GeoJSON-formatted Ground Control Points (GCPs).
        grid_mapping_name:
            Name to use for the coordinate variable created and written by this
            method. The coordinate variable, also known as the grid mapping
            variable, will have this name when written to a NetCDF file.
            Defaults to the existing grid mapping name, or "spatial_ref".
        inplace:
            Whether to modify the current Xarray object inplace or to create
            a copy first. Default (``False``) is to make a copy.

        Returns
        -------
        The modified xarray object (the same object when ``inplace`` is True).

        """
        obj = self._get_obj(inplace)
        grid_mapping_var_name = self.grid_mapping if grid_mapping_name is None else grid_mapping_name
        if grid_mapping_var_name is None:
            grid_mapping_var_name = DEFAULT_GRID_MAPPING_VARIABLE_NAME
        if grid_mapping_var_name not in obj.coords:
            self._add_empty_grid_mapping(obj, grid_mapping_var_name)
        obj.coords[grid_mapping_var_name].attrs["gcps"] = gcps
        # Link the object to this grid mapping variable (as write_crs does) so
        # the ``gcps`` property can find it even when a non-default name is used.
        _assign_grid_mapping(obj, grid_mapping_var_name)
        return obj


def _get_encoding_or_attr(xr_obj: xr.Dataset | xr.DataArray, attr_name: str) -> Any:
    return xr_obj.encoding.get(attr_name, xr_obj.attrs.get(attr_name))


def _assign_grid_mapping(xr_obj: xr.DataArray | xr.Dataset, grid_mapping_var_name: str) -> None:
    xr_obj.attrs.pop("grid_mapping", None)
    xr_obj.encoding["grid_mapping"] = grid_mapping_var_name

    if hasattr(xr_obj, "data_vars"):
        for var_name in xr_obj.data_vars:
            data_arr = xr_obj[var_name]
            dims = data_arr.geo.dims
            if not dims or all(dim_name not in dims for dim_name in ("x", "y", "z")):
                # no spatial dimensions
                continue
            _assign_grid_mapping(data_arr, grid_mapping_var_name)


@xr.register_dataset_accessor("geo")
class GeoDatasetAccessor(_SharedGeoAccessor):
    """Provide Dataset geolocation helper functions from a `.geo` accessor."""

    def set_dims(
        self,
        x: str | None = None,
        y: str | None = None,
        vertical: str | None = None,
        time: str | None = None,
        inplace: bool = False,
    ):
        """Tell geoxarray the names of the provided dimensions in this Dataset.

        Each data variable that has at least one of the provided dimensions
        has its own ``.geo.set_dims`` called with the provided names it
        contains, and the resulting per-variable mappings are merged into this
        Dataset's mapping. Variables with none of the provided dimensions are
        skipped.

        Parameters
        ----------
        x:
            Name of the X dimension. This dimension usually exists with
            a corresponding coordinate variable in meters for
            gridded/projected data.
        y:
            Name of the Y dimension. Similar to the X dimension but on the Y
            axis.
        vertical:
            Name of the vertical or Z dimension. This dimension usually exists
            with a corresponding coordinate variable in meters for altitude
            or pressure level (ex. hPa, millibar, etc).
        time:
            Name of the time dimension. This dimension usually exists with a
            corresponding coordinate variable with time objects.
        inplace:
            If True, changes are made to the current xarray object. Otherwise,
            a copy of the object is made first. Default is False.

        Returns
        -------
        The Dataset whose accessor was updated (``self``'s object when
        ``inplace`` is True, otherwise a copy).

        """
        requested = {
            "x": x,
            "y": y,
            "vertical": vertical,
            "time": time,
        }
        target = self._get_obj(inplace)
        target_geo = target.geo
        # tell the dim_map property to produce the "as-is" dim map
        target_geo._dim_map = None
        merged_map = target_geo.dim_map

        for var_arr in target.data_vars.values():
            var_dims = {
                gx_name: dim_name
                for gx_name, dim_name in requested.items()
                if dim_name in var_arr.dims
            }
            if var_dims:
                configured = var_arr.geo.set_dims(**var_dims, inplace=True)
                self._update_gx_dim_dict(merged_map, configured.geo.dim_map)

        # copy the merged result into the per-dimension attributes
        for dim_name, gx_dim_name in merged_map.items():
            if gx_dim_name is not None:
                setattr(target_geo, f"_{gx_dim_name}_dim", dim_name)

        # force dim_map to be regenerated from the updated attributes
        target_geo._dim_map = None
        return target

    def _update_gx_dim_dict(self, old, new):
        """Copy into ``old`` every entry of ``new`` that maps to a geoxarray dim name.

        ``old`` is modified in place and also returned.

        """
        known_gx_dims = ("x", "y", "time", "vertical")
        old.update((dim_name, gx_name) for dim_name, gx_name in new.items() if gx_name in known_gx_dims)
        return old
@xr.register_dataarray_accessor("geo")
class GeoDataArrayAccessor(_SharedGeoAccessor):
    """Provide DataArray geolocation helper functions from a `.geo` accessor."""

    def __init__(self, data_arr_obj: xr.DataArray) -> None:
        """Initialize a 'best guess' dimension mapping to preferred dimension names."""
        self._is_gridded: bool | None = None
        super().__init__(data_arr_obj)

    def _get_obj(self, inplace):
        """Get the object to modify.

        Parameters
        ----------
        inplace: bool
            If True, return the wrapped DataArray itself (no copy). Otherwise
            return a deep copy whose ``.geo`` accessor carries over this
            accessor's CRS, dimension names, and ``_is_gridded`` state.

        Returns
        -------
        :obj:`xarray.Dataset` | :obj:`xarray.DataArray`

        """
        obj_copy = super()._get_obj(inplace)
        # preserve attribute information
        obj_copy.geo._is_gridded = self._is_gridded
        return obj_copy

    def set_dims(
        self,
        x: None | str = None,
        y: None | str = None,
        vertical: None | str = None,
        time: None | str = None,
        inplace: bool = True,
    ) -> xr.DataArray:
        """Set preferred dimension names inside the Geoxarray accessor.

        Geoxarray will use this information for future operations. If any of
        the dimensions are not provided they will be found by best guess.

        This information does not rename or modify the data of the Xarray
        object itself. To easily rename the dimensions in a Geoxarray-friendly
        manner, follow a call of this method with :meth:`write_dims`.

        Parameters
        ----------
        x:
            Name of the X dimension. This dimension usually exists with
            a corresponding coordinate variable in meters for
            gridded/projected data.
        y:
            Name of the Y dimension. Similar to the X dimension but on the Y
            axis.
        vertical:
            Name of the vertical or Z dimension. This dimension usually exists
            with a corresponding coordinate variable in meters for altitude
            or pressure level (ex. hPa, millibar, etc).
        time:
            Name of the time dimension. This dimension usually exists with a
            corresponding coordinate variable with time objects.
        inplace:
            If True, changes are made to the current xarray object. Otherwise,
            a copy of the object is made first. Default is True.

        Returns
        -------
        The DataArray whose ``.geo`` accessor holds the dimension information.

        Raises
        ------
        AssertionError
            If an explicitly provided dimension name is not a dimension of the
            DataArray.

        See Also
        --------
        GeoDataArrayAccessor.dims : Show the current dimensions as Geoxarray knows them
        GeoDataArrayAccessor.write_dims : Rename dimensions to match Geoxarray preferred dimension names

        """
        obj = self._get_obj(inplace)
        self._set_x_dim(obj, x)
        self._set_y_dim(obj, y)
        self._set_2d_dims(obj)
        self._set_vertical_dims(obj, vertical)
        self._set_temporal_dims(obj, time)
        obj.geo._dim_map = None
        return obj

    # The helpers below read AND write the state of ``obj.geo``. When
    # ``inplace=False`` that is a different accessor than ``self``, so reading
    # ``self`` after ``obj.geo`` was modified would see stale values.

    def _set_x_dim(self, obj, x):
        """Record the X dim on ``obj.geo``: explicit ``x``, else a dim named "x" if unset."""
        geo = obj.geo
        if x is not None:
            assert x in obj.dims, f"Dimension {x!r} not found in {obj.dims}"
            geo._x_dim = x
        elif geo._x_dim is None and "x" in obj.dims:
            geo._x_dim = "x"

    def _set_y_dim(self, obj, y):
        """Record the Y dim on ``obj.geo``: explicit ``y``, else a dim named "y" if unset."""
        geo = obj.geo
        if y is not None:
            assert y in obj.dims, f"Dimension {y!r} not found in {obj.dims}"
            geo._y_dim = y
        elif geo._y_dim is None and "y" in obj.dims:
            geo._y_dim = "y"

    def _set_2d_dims(self, obj):
        """Guess (rows, cols) = (y, x) for a 2D array when neither spatial dim is known."""
        geo = obj.geo
        dims = obj.dims
        if len(dims) == 2 and geo._x_dim is None and geo._y_dim is None:
            geo._y_dim = dims[0]
            geo._x_dim = dims[1]

    def _set_vertical_dims(self, obj, vertical):
        """Record the vertical dim: explicit ``vertical``, else the first known name present."""
        geo = obj.geo
        if vertical is not None:
            assert vertical in obj.dims, f"Dimension {vertical!r} not found in {obj.dims}"
            geo._vertical_dim = vertical
        elif geo._vertical_dim is None:
            for z_dim in ("z", "vertical", "pressure_level"):
                if z_dim in obj.dims:
                    geo._vertical_dim = z_dim
                    break

    def _set_temporal_dims(self, obj, time):
        """Record the time dim: explicit ``time``, else "time" or "t" if present."""
        geo = obj.geo
        if time is not None:
            assert time in obj.dims, f"Dimension {time!r} not found in {obj.dims}"
            geo._time_dim = time
        elif geo._time_dim is None:
            for t_dim in ("time", "t"):
                if t_dim in obj.dims:
                    geo._time_dim = t_dim
                    break

    def get_lonlats(self, chunks=None):
        """Return longitude and latitude arrays.

        Parameters
        ----------
        chunks : None or int
            Specify chunk size for dask arrays.

        Returns
        -------
        Longitude and latitude dask arrays. If `chunks` is None then a numpy
        array is returned.

        """
        raise NotImplementedError()

    def plot(self):
        """Plot data on a map."""
        # TODO: Support multiple backends (cartopy, geoviews, etc)?
        raise NotImplementedError()