"""Common utility functions for datamodel alignment."""
from __future__ import annotations
import functools
import logging
from typing import TYPE_CHECKING, Protocol
import gwcs
import numpy as np
from astropy import units as u
from astropy import wcs as fitswcs
from astropy.coordinates import SkyCoord
from astropy.modeling import models as astmodels
from astropy.utils.misc import isiterable
from gwcs.wcstools import wcs_from_fiducial
if TYPE_CHECKING:
from asdf import AsdfFile
# Module-level logger used by the functions in this file.
log = logging.getLogger(__name__)
# NOTE(review): this sets the logger's level to DEBUG at import time, which
# applies regardless of how the host application configures logging --
# confirm this is intended for a library module.
log.setLevel(logging.DEBUG)
# Public names exported by ``from <this module> import *``.
__all__ = [
    "compute_scale",
    "compute_fiducial",
    "calc_rotation_matrix",
    "wcs_from_footprints",
    "reproject",
]
class SupportsDataWithWcs(Protocol):
    """
    Structural type for the datamodels accepted by the helpers in this module.

    Any object that exposes the members declared below satisfies the protocol;
    no inheritance is required.
    """

    # ASDF file object attached to the datamodel.
    _asdf: AsdfFile

    def to_flat_dict(self):
        """
        Instance method every conforming datamodel must provide.

        The return value is not constrained by this protocol; presumably it is
        a flattened dictionary of the model contents -- confirm against the
        concrete datamodel classes.
        """
        ...
def _calculate_fiducial_from_spatial_footprint(
    spatial_footprint: np.ndarray,
) -> np.ndarray:
    """
    Find the sky position in the middle of a spatial footprint.

    The footprint vertices are projected onto the unit sphere, the midpoint of
    the Cartesian extent (half-way between minimum and maximum) is taken along
    each of the three axes, and that midpoint is converted back to a
    longitude/latitude pair.

    Parameters
    ----------
    spatial_footprint : numpy.ndarray
        A 2xN array; the first row holds the longitudes and the second row the
        latitudes (both in degrees) of the N footprint positions.

    Returns
    -------
    lon_fiducial, lat_fiducial : float, float
        Longitude and latitude of the fiducial point in degrees, returned as a
        two-element tuple. The longitude is wrapped with a modulo of 360.
    """
    lon_deg, lat_deg = spatial_footprint
    lon_rad = np.deg2rad(lon_deg)
    lat_rad = np.deg2rad(lat_deg)

    # Cartesian coordinates of every vertex on the unit sphere.
    cos_lat = np.cos(lat_rad)
    cartesian = (
        cos_lat * np.cos(lon_rad),
        cos_lat * np.sin(lon_rad),
        np.sin(lat_rad),
    )

    # Centre of the axis-aligned box enclosing the vertices.
    x_mid, y_mid, z_mid = ((np.max(comp) + np.min(comp)) / 2.0 for comp in cartesian)

    # Back to spherical coordinates, in degrees.
    lon_fiducial = np.rad2deg(np.arctan2(y_mid, x_mid)) % 360.0
    lat_fiducial = np.rad2deg(np.arctan2(z_mid, np.sqrt(x_mid**2 + y_mid**2)))
    return lon_fiducial, lat_fiducial
def _generate_tranform(
    refmodel: SupportsDataWithWcs,
    ref_fiducial: np.array,
    pscale_ratio: int | None = None,
    pscale: float | None = None,
    rotation: float | None = None,
    transform: astmodels.Model | None = None,
) -> astmodels.Model:
    """
    Build the rotation (and, when applicable, scaling) transform for a new WCS
    from a reference datamodel.

    Parameters
    ----------
    refmodel :
        Reference datamodel. Its ``meta.wcs`` and ``meta.wcsinfo`` (``v3yangle``,
        ``vparity`` and, when ``rotation`` is `None`, ``roll_ref``) are read to
        derive the transform parameters.
    ref_fiducial : numpy.array
        World coordinates of the fiducial point; forwarded to
        ``compute_scale`` when the plate scale has to be computed.
    pscale_ratio : int, None
        Ratio of input to output pixel scale. Only used when ``pscale`` is not
        given, in which case it is forwarded to ``compute_scale``.
    pscale : float, None
        The plate scale. When it is `None` (or otherwise falsy) the plate scale
        is computed from the reference datamodel's WCS.
    rotation : float, None
        Position angle, in degrees, of the output image's Y-axis relative to
        North; 0.0 orients the output North up. When `None`, the roll angle is
        taken from the reference model's ``meta.wcsinfo.roll_ref`` instead, so
        the output keeps the reference orientation.
    transform : ~astropy.modeling.Model
        A ready-made transform between frames. When provided it is returned
        unchanged and every other argument is ignored.

    Returns
    -------
    transform : ~astropy.modeling.Model
        An :py:mod:`~astropy` model containing the transform between frames.
    """
    # A caller-supplied transform takes precedence over everything else.
    if transform is not None:
        return transform

    sky_axes = refmodel.meta.wcs._get_axes_indices().tolist()  # noqa: SLF001
    wcsinfo = refmodel.meta.wcsinfo
    v3yangle = np.deg2rad(wcsinfo.v3yangle)
    vparity = wcsinfo.vparity
    roll_ref = (
        np.deg2rad(wcsinfo.roll_ref)
        if rotation is None
        else np.deg2rad(rotation) + (vparity * v3yangle)
    )

    # calc_rotation_matrix returns four elements in a flat list; fold them
    # into the 2x2 matrix expected by the affine transformation.
    pc_matrix = np.reshape(calc_rotation_matrix(roll_ref, v3yangle, vparity=vparity), (2, 2))
    combined = astmodels.AffineTransformation2D(pc_matrix, name="pc_rotation_matrix")

    if sky_axes:
        if not pscale:
            pscale = compute_scale(refmodel.meta.wcs, ref_fiducial, pscale_ratio=pscale_ratio)
        # Same scale on both axes, applied after the rotation.
        scaling = astmodels.Scale(pscale, name="cdelt1") & astmodels.Scale(pscale, name="cdelt2")
        combined = combined | scaling

    return combined
def _get_axis_min_and_bounding_box(ref_model, wcs_list, ref_wcs):
    """
    Determine the per-axis minimum pixel values and the output bounding box.

    The footprint of every WCS in ``wcs_list`` is mapped through
    ``ref_wcs.backward_transform``; the resulting coordinates are shifted so
    that the smallest value along each axis becomes zero.

    Parameters
    ----------
    ref_model :
        The reference datamodel; ``meta.wcs.output_frame.axes_order`` selects
        which axes appear in the bounding box, and in which order.
    wcs_list : list
        The list of WCS objects whose footprints are combined.
    ref_wcs : ~gwcs.wcs.WCS
        The reference WCS object used for the backward transform.

    Returns
    -------
    tuple
        A tuple containing two elements:
        1 - a :py:class:`numpy.ndarray` with the minimum value in each axis
        (before shifting);
        2 - a tuple containing the bounding box region in the format
        ((x0_lower, x0_upper), (x1_lower, x1_upper)).
    """
    # One column per footprint vertex, one row per axis.
    pixel_corners = np.hstack(
        [ref_wcs.backward_transform(*w.footprint().T) for w in wcs_list]
    )
    axis_min_values = pixel_corners.min(axis=1)

    # Move the origin to the smallest coordinate along each axis.
    shifted = pixel_corners - axis_min_values[:, np.newaxis]

    output_bounding_box = tuple(
        (shifted[axis].min(), shifted[axis].max())
        for axis in ref_model.meta.wcs.output_frame.axes_order
    )
    return axis_min_values, output_bounding_box
def _calculate_fiducial(wcs_list, bounding_box, crval=None):
    """
    Compute the fiducial point of a set of WCS objects, optionally pinning its
    spatial components to user-provided CRVAL values.

    Parameters
    ----------
    wcs_list : list
        A list of WCS objects.
    bounding_box : tuple, or list, optional
        The bounding box over which the WCS is valid. It can be either a tuple
        of tuples or a list of lists of size 2 where each element represents a
        range of (low, high) values. The bounding_box is in the order of the
        axes, axes_order. For two inputs and axes_order(0, 1) the bounding box
        can be either ((xlow, xhigh), (ylow, yhigh)) or
        [[xlow, xhigh], [ylow, yhigh]].
    crval : list, optional
        A reference world coordinate associated with the reference pixel. If
        not `None`, its values replace, in order, the fiducial coordinates of
        the spatial axes; non-spatial axes are left untouched.

    Returns
    -------
    fiducial : numpy.ndarray
        The world coordinates of the fiducial point, one entry per axis of the
        output frame.
    """
    fiducial = compute_fiducial(wcs_list, bounding_box=bounding_box)
    if crval is None:
        return fiducial

    # Positions of the spatial axes in the output frame of the first WCS.
    spatial_positions = [
        position
        for position, axis_type in enumerate(wcs_list[0].output_frame.axes_type)
        if axis_type == "SPATIAL"
    ]
    # The n-th spatial axis takes the n-th CRVAL entry.
    for crval_index, position in enumerate(spatial_positions):
        fiducial[position] = crval[crval_index]
    return fiducial
def _calculate_offsets(fiducial, wcs, axis_min_values, crpix):
    """
    Build the pair of shifts that places the reference pixel in a new WCS.

    Parameters
    ----------
    fiducial : numpy.ndarray
        World coordinates of the fiducial point.
    wcs : ~gwcs.wcs.WCS
        A WCS object whose backward transform converts the fiducial point to
        pixel coordinates.
    axis_min_values : numpy.ndarray
        The minimum pixel value for each axis; the first two entries are used.
    crpix : list or tuple
        Pixel coordinates of the reference pixel.

    Returns
    -------
    ~astropy.modeling.Model
        A model with the offsets to be added to the WCS's transform: two
        shifts named ``crpix1`` and ``crpix2``, each by the negated offset.

    Notes
    -----
    When ``crpix`` is `None` and ``fiducial``, ``wcs`` and ``axis_min_values``
    are all provided, the offsets are the pixel coordinates of the fiducial
    point minus the minimum pixel value of the matching axis. In every other
    case the offsets are unpacked directly from ``crpix``.
    """
    derive_from_fiducial = crpix is None and not any(
        arg is None for arg in (fiducial, wcs, axis_min_values)
    )

    if derive_from_fiducial:
        x_offset, y_offset = wcs.backward_transform(*fiducial)
        x_offset -= axis_min_values[0]
        y_offset -= axis_min_values[1]
    else:
        x_offset, y_offset = crpix

    shift_x = astmodels.Shift(-x_offset, name="crpix1")
    shift_y = astmodels.Shift(-y_offset, name="crpix2")
    return shift_x & shift_y
def _calculate_new_wcs(ref_model, shape, wcs_list, fiducial, crpix=None, transform=None):
    """
    Calculates a new WCS object based on the combined WCS objects provided.

    The new WCS is created around ``fiducial`` with a TAN projection and the
    input/output frames of the reference model's WCS. Reference-pixel offsets
    are then inserted, and the bounding box, ``pixel_shape`` and
    ``array_shape`` attributes are set on the result.

    Parameters
    ----------
    ref_model :
        The reference model to be used when extracting metadata
        (``meta.wcs.input_frame`` and ``meta.wcs.output_frame``).
    shape : list
        The shape of the new WCS's pixel grid. If `None`, then the output bounding box
        will be used to determine it.
    wcs_list : list
        A list containing WCS objects.
    fiducial : numpy.ndarray
        A two-elements array containing the location on the sky in some standard
        coordinate system.
    crpix : tuple, optional
        The coordinates of the reference pixel. Forwarded to
        ``_calculate_offsets``.
    transform : ~astropy.modeling.Model
        An optional transform to be prepended to the transform constructed by the
        fiducial point. The number of outputs of this transform must equal the number
        of axes in the coordinate frame.

    Returns
    -------
    wcs_new : ~gwcs.wcs.WCS
        The new WCS object that corresponds to the combined WCS objects in `wcs_list`.
    """
    # Start from a TAN-projected WCS centred on the fiducial point.
    wcs_new = wcs_from_fiducial(
        fiducial,
        coordinate_frame=ref_model.meta.wcs.output_frame,
        projection=astmodels.Pix2Sky_TAN(),
        transform=transform,
        input_frame=ref_model.meta.wcs.input_frame,
    )
    # The footprints of all inputs, seen through the new WCS, give the pixel
    # extent of the combined image. This must run before the offsets are
    # inserted because the offsets are derived from these minimum values.
    axis_min_values, output_bounding_box = _get_axis_min_and_bounding_box(ref_model, wcs_list, wcs_new)
    offsets = _calculate_offsets(
        fiducial=fiducial,
        wcs=wcs_new,
        axis_min_values=axis_min_values,
        crpix=crpix,
    )
    # Attach the reference-pixel shifts relative to the frame named "detector".
    wcs_new.insert_transform("detector", offsets, after=True)
    wcs_new.bounding_box = output_bounding_box
    if shape is None:
        # Extent of each bounding-box axis rounded to the nearest integer;
        # the bounding box is reversed so the resulting shape is in the
        # opposite axis order.
        shape = [int(axs[1] - axs[0] + 0.5) for axs in output_bounding_box[::-1]]
    # pixel_shape uses the reverse ordering of array_shape.
    wcs_new.pixel_shape = shape[::-1]
    wcs_new.array_shape = shape
    return wcs_new
def _validate_wcs_list(wcs_list):
    """
    Check that ``wcs_list`` is a non-empty iterable of gwcs WCS objects.

    Parameters
    ----------
    wcs_list : list
        A list of WCS objects.

    Returns
    -------
    bool
        Always `True`; invalid input is reported by raising instead.

    Raises
    ------
    ValueError
        Raised whenever wcs_list is not an iterable.
    TypeError
        Raised whenever wcs_list is empty or any of its content is not an
        instance of WCS.
    """
    if not isiterable(wcs_list):
        msg = "Expected 'wcs_list' to be an iterable of WCS objects."
        raise ValueError(msg)

    if not len(wcs_list):
        msg = "'wcs_list' should not be empty."
        raise TypeError(msg)

    for item in wcs_list:
        if not isinstance(item, gwcs.WCS):
            msg = "All items in 'wcs_list' are to be instances of gwcs.wcs.WCS."
            raise TypeError(msg)

    return True
def wcsinfo_from_model(input_model: SupportsDataWithWcs) -> dict[str, np.ndarray | str | bool]:
    """
    Creates a dict {wcs_keyword: array_of_values} pairs from a datamodel.

    Parameters
    ----------
    input_model :
        The input datamodel. ``meta.wcsinfo`` supplies ``wcsaxes`` and the
        per-axis keywords; ``meta.coordinates.reference_frame`` supplies
        ``RADESYS``.

    Returns
    -------
    wcsinfo : dict
        A dict containing the WCS FITS keywords and corresponding values:
        ``WCSAXES``; one array (one entry per axis) for each of ``CRPIX``,
        ``CRVAL``, ``CDELT``, ``CTYPE`` and ``CUNIT``; the ``PC`` matrix;
        ``RADESYS``; and ``has_cd`` (always `False`).

    Notes
    -----
    Keywords missing from ``meta.wcsinfo`` fall back to defaults. A missing
    ``pc{i}_{j}`` element defaults to the identity matrix (1 on the diagonal,
    0 elsewhere), so a model without PC keywords yields an invertible matrix.
    """
    # Per-axis fallback used when the model lacks a given keyword.
    defaults = {
        "CRPIX": 0,
        "CRVAL": 0,
        "CDELT": 1.0,
        "CTYPE": "",
        "CUNIT": u.Unit(""),
    }
    model_wcsinfo = input_model.meta.wcsinfo
    wcsaxes = model_wcsinfo.wcsaxes
    wcsinfo = {"WCSAXES": wcsaxes}

    # FITS axis numbers are 1-based.
    axes = range(1, wcsaxes + 1)
    for key, default in defaults.items():
        wcsinfo[key] = np.array(
            [getattr(model_wcsinfo, f"{key}{ax}".lower(), default) for ax in axes]
        )

    # NOTE(review): float32 limits the precision of the stored PC elements;
    # kept as in the original implementation -- confirm it is intentional.
    pc = np.zeros((wcsaxes, wcsaxes), dtype=np.float32)
    for i in axes:
        for j in axes:
            identity_element = 1.0 if i == j else 0.0
            pc[i - 1, j - 1] = getattr(model_wcsinfo, f"pc{i}_{j}", identity_element)
    wcsinfo["PC"] = pc

    wcsinfo["RADESYS"] = input_model.meta.coordinates.reference_frame
    wcsinfo["has_cd"] = False
    return wcsinfo
def compute_scale(
    wcs: gwcs.WCS,
    fiducial: tuple | np.ndarray,
    disp_axis: int | None = None,
    pscale_ratio: float | None = None,
) -> float:
    """Compute the scale at the fiducial point on the detector.

    The fiducial point is converted to pixel coordinates, two neighbouring
    pixel positions (one unit step each) are mapped back to the sky, and the
    angular separations between the fiducial point and those neighbours give
    the scale along the two directions.

    Parameters
    ----------
    wcs : ~gwcs.wcs.WCS
        Reference WCS object from which to compute a scaling factor.
    fiducial : tuple
        Input fiducial of (RA, DEC) or (RA, DEC, Wavelength) used in calculating
        reference points.
    disp_axis : int
        Dispersion axis integer. Assumes the same convention as
        ``wcsinfo.dispersion_direction``. Required when the WCS has a
        spectral axis.
    pscale_ratio : float
        Ratio of input to output pixel scale. When given, both scales are
        multiplied by it.

    Returns
    -------
    scale : float
        Scaling factor for x and y or cross-dispersion direction. For a purely
        spatial WCS this is the geometric mean of the two scales; for a
        spectral WCS it is the y scale when ``disp_axis == 1`` and the x scale
        otherwise.

    Raises
    ------
    ValueError
        If the WCS has a spectral axis and ``disp_axis`` is `None`.
    """
    spectral = "SPECTRAL" in wcs.output_frame.axes_type
    if spectral and disp_axis is None:
        msg = "If input WCS is spectral, a disp_axis must be given"
        raise ValueError(msg)

    # Pixel coordinates of the fiducial point.
    crpix = np.array(wcs.invert(*fiducial))

    # Unit step at the index of the first spatial axis; rolling it by one
    # position yields the unit step along the following axis.
    delta = np.zeros_like(crpix)
    spatial_idx = np.where(np.array(wcs.output_frame.axes_type) == "SPATIAL")[0]
    delta[spatial_idx[0]] = 1

    # Columns: fiducial pixel, +1 step along one axis, +1 step along the next.
    # After the transpose each row holds one pixel axis for all three points.
    crpix_with_offsets = np.vstack((crpix, crpix + delta, crpix + np.roll(delta, 1))).T
    crval_with_offsets = wcs(*crpix_with_offsets, with_bounding_box=False)

    # World coordinates are interpreted as degrees.
    coords = SkyCoord(
        ra=crval_with_offsets[spatial_idx[0]],
        dec=crval_with_offsets[spatial_idx[1]],
        unit="deg",
    )
    xscale = np.abs(coords[0].separation(coords[1]).value)
    yscale = np.abs(coords[0].separation(coords[2]).value)

    if pscale_ratio is not None:
        xscale *= pscale_ratio
        yscale *= pscale_ratio

    if spectral:
        # Assuming scale doesn't change with wavelength
        # Assuming disp_axis is consistent with DataModel.meta.wcsinfo.dispersion.direction
        return yscale if disp_axis == 1 else xscale

    return np.sqrt(xscale * yscale)
[docs]def compute_fiducial(wcslist: list, bounding_box: tuple | list | None = None) -> np.ndarray:
"""
Calculates the world coordinates of the fiducial point of a list of WCS objects.
For a celestial footprint this is the center. For a spectral footprint, it is the
beginning of its range.
Parameters
----------
wcslist : list
A list containing all the WCS objects for which the fiducial is to be
calculated.
bounding_box : tuple, list, None
The bounding box over which the WCS is valid. It can be a either tuple of tuples
or a list of lists of size 2 where each element represents a range of
(low, high) values. The bounding_box is in the order of the axes, axes_order.
For two inputs and axes_order(0, 1) the bounding box can be either
((xlow, xhigh), (ylow, yhigh)) or [[xlow, xhigh], [ylow, yhigh]].
Returns
-------
fiducial : numpy.ndarray
A two-elements array containing the world coordinates of the fiducial point
in the combined output coordinate frame.
Notes
-----
This function assumes all WCSs have the same output coordinate frame.
"""
axes_types = wcslist[0].output_frame.axes_type
spatial_axes = np.array(axes_types) == "SPATIAL"
spectral_axes = np.array(axes_types) == "SPECTRAL"
footprints = np.hstack([w.footprint(bounding_box=bounding_box).T for w in wcslist])
spatial_footprint = footprints[spatial_axes]
spectral_footprint = footprints[spectral_axes]
fiducial = np.empty(len(axes_types))
if spatial_footprint.any():
fiducial[spatial_axes] = _calculate_fiducial_from_spatial_footprint(spatial_footprint)
if spectral_footprint.any():
fiducial[spectral_axes] = spectral_footprint.min()
return fiducial
def calc_rotation_matrix(roll_ref: float, v3i_yangle: float, vparity: int = 1) -> list[float]:
    r"""Calculate the rotation matrix.

    Parameters
    ----------
    roll_ref : float
        Telescope roll angle of V3 North over East at the ref. point in radians
    v3i_yangle : float
        The angle between ideal Y-axis and V3 in radians.
    vparity : int
        The x-axis parity, usually taken from the JWST SIAF parameter VIdlParity.
        Value should be "1" or "-1".

    Returns
    -------
    matrix: list
        A list containing the rotation matrix elements in column order, i.e.
        ``[pc1_1, pc1_2, pc2_1, pc2_2]``.

    Raises
    ------
    ValueError
        If ``vparity`` is neither 1 nor -1.

    Notes
    -----
    The rotation matrix is

    .. math::
        PC = \begin{bmatrix}
        pc_{1,1} & pc_{2,1} \\
        pc_{1,2} & pc_{2,2}
        \end{bmatrix}
    """
    if vparity not in (1, -1):
        msg = f"vparity should be 1 or -1. Input was: {vparity}"
        raise ValueError(msg)

    # Angle of the ideal Y-axis once the V3 offset is removed from the roll.
    rel_angle = roll_ref - (vparity * v3i_yangle)

    # The parity flips the sign of the first-column (x-axis) terms.
    pc1_1 = vparity * np.cos(rel_angle)
    pc1_2 = np.sin(rel_angle)
    pc2_1 = vparity * -np.sin(rel_angle)
    pc2_2 = np.cos(rel_angle)

    return [pc1_1, pc1_2, pc2_1, pc2_2]
def update_s_region_imaging(model, center=True):
    """
    Recompute the ``S_REGION`` keyword of an imaging model from its WCS footprint.

    Parameters
    ----------
    model :
        The input datamodel. Its ``meta.wcs`` provides the footprint; when the
        WCS has no bounding box, one is built from ``model.data.shape``.
    center : bool, optional
        Whether or not to use the center of the pixel as reference for the
        coordinates, by default True
    """
    wcs = model.meta.wcs
    bbox = wcs.bounding_box
    if bbox is None:
        bbox = wcs_bbox_from_shape(model.data.shape)

    # TODO: center=True should not be used here. The footprint is meant to
    # follow the *bounding box*, i.e. the outer vertices of the pixels rather
    # than their centers; with center=True a 0.5 pixel difference has to be
    # accounted for when comparing the bounding box and footprint coordinates.
    #
    # Only the sky footprint matters: after the transpose the first two rows
    # hold its coordinates, one column per vertex.
    sky_corners = wcs.footprint(bbox, center=center, axis_type="spatial").T[:2, :]

    # Wrap negative RA values so that they are all positive.
    ra = sky_corners[0]
    is_negative = ra < 0
    if is_negative.any():
        ra[is_negative] = 360 + ra[is_negative]

    update_s_region_keyword(model, sky_corners.T)
def wcs_bbox_from_shape(shape):
    """Build a bounding box covering a data array of the given shape.

    The box runs from -0.5 to (size - 0.5) along each of the last two
    dimensions, and is appropriate to attach to a wcs object.

    Parameters
    ----------
    shape : tuple
        The shape attribute from a `numpy.ndarray` array

    Returns
    -------
    bbox : tuple
        Bounding box in x, y order, i.e. the extent of the last dimension
        first, followed by the extent of the second-to-last dimension.
    """
    x_extent = (-0.5, shape[-1] - 0.5)
    y_extent = (-0.5, shape[-2] - 0.5)
    return x_extent, y_extent
def update_s_region_keyword(model, footprint):
    """Write the S_REGION polygon string into the model's metadata.

    Parameters
    ----------
    model :
        The input model; ``meta.wcsinfo.s_region`` is assigned in place.
    footprint : numpy.array
        A 4x2 numpy array containing the coordinates of the vertices of the footprint.

    Returns
    -------
    None
        The model is updated in place. When the formatted string contains
        NaNs the keyword is left unchanged and a message is logged instead.
    """
    vertices = footprint.flatten()
    s_region = "POLYGON ICRS {:.9f} {:.9f} {:.9f} {:.9f} {:.9f} {:.9f} {:.9f} {:.9f}".format(
        *vertices
    )

    if "nan" in s_region:
        # A polygon with undefined vertices is useless; keep the old value.
        log.info("There are NaNs in s_region, S_REGION not updated.")
        return

    model.meta.wcsinfo.s_region = s_region
    log.info("Update S_REGION to %s", model.meta.wcsinfo.s_region)
[docs]def reproject(wcs1, wcs2):
"""
Given two WCSs or transforms return a function which takes pixel
coordinates in the first WCS or transform and computes them in pixel coordinates
in the second one. It performs the forward transformation of ``wcs1`` followed by the
inverse of ``wcs2``.
Parameters
----------
wcs1 : astropy.wcs.WCS or gwcs.wcs.WCS
Input WCS objects or transforms.
wcs2 : astropy.wcs.WCS or gwcs.wcs.WCS
Output WCS objects or transforms.
Returns
-------
Function to compute the transformations. It takes x, y
positions in ``wcs1`` and returns x, y positions in ``wcs2``.
"""
def _get_forward_transform_func(wcs1):
"""Get the forward transform function from the input WCS. If the wcs is a
fitswcs.WCS object all_pix2world requires three inputs, the x (str, ndarrray),
y (str, ndarray), and origin (int). The origin should be between 0, and 1
https://docs.astropy.org/en/latest/wcs/index.html#loading-wcs-information-from-a-fits-file
).
"""
if isinstance(wcs1, fitswcs.WCS):
forward_transform = wcs1.all_pix2world
elif isinstance(wcs1, gwcs.WCS):
forward_transform = wcs1.forward_transform
else:
msg = "Expected input to be astropy.wcs.WCS or gwcs.WCS object"
raise TypeError(msg)
return forward_transform
def _get_backward_transform_func(wcs2):
if isinstance(wcs2, fitswcs.WCS):
backward_transform = wcs2.all_world2pix
elif isinstance(wcs2, gwcs.WCS):
backward_transform = wcs2.backward_transform
else:
msg = "Expected input to be astropy.wcs.WCS or gwcs.WCS object"
raise TypeError(msg)
return backward_transform
def _reproject(x: float | np.ndarray, y: float | np.ndarray) -> tuple:
"""
Reprojects the input coordinates from one WCS to another.
Parameters
----------
x : float or np.ndarray
x-coordinate(s) to be reprojected.
y : float or np.ndarray
y-coordinate(s) to be reprojected.
Returns
-------
tuple
Tuple of np.ndarrays including reprojected x and y coordinates.
"""
# example inputs to resulting function (12, 13, 0) # third number is origin
# uses np.arrays for shape functionality
if not isinstance(x, (np.ndarray)):
x = np.array(x)
if not isinstance(y, (np.ndarray)):
y = np.array(y)
if x.shape != y.shape:
msg = "x and y must be the same length"
raise ValueError(msg)
sky = _get_forward_transform_func(wcs1)(x, y, 0)
# rearrange into array including flattened x and y values
flat_sky = [axis.flatten() for axis in sky]
det = np.array(_get_backward_transform_func(wcs2)(flat_sky[0], flat_sky[1], 0))
det_reshaped = [axis.reshape(x.shape) for axis in det]
return tuple(det_reshaped)
return _reproject