import pickle
from copy import deepcopy
from dataclasses import asdict, dataclass
from datetime import timedelta
from pathlib import Path
from typing import Optional, Union
import astropy.units as u
import matplotlib.pyplot as plt
import numpy as np
from astropy.coordinates import Angle, SkyCoord
from astropy.io import fits
from astropy.io.fits.hdu.base import _BaseHDU
from astropy.io.fits.header import Header
from astropy.nddata import Cutout2D as astopy_Cutout2D
from astropy.nddata import overlap_slices
from astropy.time import Time
from astropy.wcs import WCS
from astropy.wcs.wcs import WCS
from dateutil import parser as dparser
from matplotlib import gridspec
from prose import utils, viz
from prose.core.source import Sources
from prose.telescope import Telescope
@dataclass
class Image:
    """
    Image object containing image data and metadata.

    This is a Python Data Class, so that most attributes described below can be
    used as keyword-arguments when instantiated.

    Attributes that are neither declared on the class nor already present on
    the instance are stored in (and read back from) the ``computed``
    dictionary, see ``__setattr__`` and ``__getattr__``.
    """

    data: Optional[np.ndarray] = None
    """Image data"""
    metadata: Optional[dict] = None
    """Image metadata"""
    catalogs: Optional[dict] = None
    """Catalogs associated with the image contained in a dictionary of
    pandas dataframes"""
    _sources: Optional[Union[Sources, dict]] = None
    origin: tuple = (0, 0)
    """Image origin"""
    discard: bool = False
    """Whether image has been discarded by a block"""
    computed: Optional[dict] = None
    """A dictionary containing any user and block-defined attributes"""
    header: Optional[Header] = None
    """FITS header associated with the image (optional)"""
    # Cached WCS object. Un-annotated, hence not a dataclass field.
    _wcs = None

    def __post_init__(self):
        assert (
            isinstance(self.data, np.ndarray) or self.data is None
        ), f"data must be a np.ndarray, not {type(self.data)}"

        if self.metadata is None:
            self.metadata = {}
        if self.catalogs is None:
            self.catalogs = {}
        if self.computed is None:
            self.computed = {}

        # backward compatibility: a header stored under computed["fits_header"]
        # takes precedence; otherwise the header given to the constructor is
        # kept (it used to be discarded unconditionally).
        legacy_header = self.computed.get("fits_header", None)
        if legacy_header is not None:
            self.header = legacy_header
        if self.header is None:
            self.header = Header()

        if isinstance(self._sources, dict):
            self._sources = Sources(**self._sources)
        if self._sources is None:
            self._sources = Sources([])

    def __setattr__(self, name, value):
        """Set a declared attribute, or store ``value`` in ``computed``.

        Names found on the instance ``__dict__`` or on the class (dataclass
        fields, properties, ``_wcs``) are set normally. Any other name goes to
        the ``computed`` dictionary once it exists.
        """
        computed = self.__dict__.get("computed")
        # hasattr(self, name) is deliberately avoided: it goes through
        # __getattr__ and reports names held in ``computed`` as existing,
        # which made re-assignments bypass the dictionary.
        declared = name in self.__dict__ or hasattr(type(self), name)
        if declared or computed is None:
            super().__setattr__(name, value)
        else:
            computed[name] = value

    def __getattr__(self, name):
        """Fallback lookup of ``name`` in the ``computed`` dictionary.

        Raises
        ------
        AttributeError
            if ``name`` is not a computed value (or ``computed`` is not set yet)
        """
        computed = self.__dict__.get("computed")
        if computed is not None and name in computed:
            return computed[name]
        raise AttributeError(f"Image has no '{name}'")

    def copy(self, data=True):
        """Copy of image object

        Parameters
        ----------
        data : bool, optional
            currently ignored (a deep copy including the data is always
            made), by default True

        Returns
        -------
        Image
            copied object
        """
        return deepcopy(self)

    def __copy__(self):
        return self.copy()

    def show(
        self,
        cmap="Greys_r",
        ax=None,
        figsize=8,
        zscale=True,
        frame=False,
        contrast=0.1,
        sources=True,
        **kwargs,
    ):
        """Show image data

        Parameters
        ----------
        cmap : str, optional
            matplotlib colormap, by default "Greys_r"
        ax : subplot, optional
            matplotlib Axes in which to plot, by default None
        figsize : tuple, list, float or int, optional
            matplotlib figure size if ax not specified; a single number ``s``
            is expanded to ``(s, s)``, by default 8
        zscale : bool, optional
            whether to apply a z scale to plotted image data, by default True
        frame : bool, optional
            whether to show astronomical coordinates axes, by default False
        contrast : float, optional
            image contrast used in image scaling, by default 0.1
        sources : bool, optional
            whether to plot ``Image.sources``, by default True
        **kwargs
            passed to ``ax.imshow``

        Raises
        ------
        TypeError
            if ``figsize`` is not a tuple, list, float or int

        See also
        --------
        plot_catalog :
            Plot catalog stars on an image
        """
        if ax is None:
            if not isinstance(figsize, (list, tuple)):
                if isinstance(figsize, (float, int)):
                    figsize = (figsize, figsize)
                else:
                    raise TypeError("figsize must be tuple or list or float or int")
            fig = plt.figure(figsize=figsize)
            if frame:
                ax = fig.add_subplot(111, projection=self.wcs)
            else:
                ax = fig.add_subplot(111)

        if zscale is False:
            # linear scaling anchored on the median of the data
            vmin = np.nanmedian(self.data)
            vmax = vmin * (1 + contrast) / (1 - contrast)
            _ = ax.imshow(
                self.data, cmap=cmap, origin="lower", vmin=vmin, vmax=vmax, **kwargs
            )
        else:
            _ = ax.imshow(
                utils.z_scale(self.data, contrast), cmap=cmap, origin="lower", **kwargs
            )

        if frame:
            overlay = ax.get_coords_overlay(self.wcs)
            overlay.grid(color="white", ls="dotted")
            overlay[0].set_axislabel("Right Ascension (J2000)")
            overlay[1].set_axislabel("Declination (J2000)")

        if sources and self.sources is not None:
            self.sources.plot()
            # NOTE(review): the original nesting of these two calls was lost;
            # they are assumed to restore the image extent after the sources
            # have been plotted - confirm.
            ax.set_xlim(0, self.shape[1] - 1)
            ax.set_ylim(0, self.shape[0] - 1)

        # drop the cached WCS; it is rebuilt from metadata on next access
        self._wcs = None

    def _from_metadata_with_unit(self, name):
        """Return ``metadata[name]`` combined with the unit ``metadata[f"{name}_unit"]``.

        "ra" and "dec" are returned as an ``Angle`` converted to degrees, other
        values as ``value * unit``. ``None`` values are returned as ``None``.
        """
        value = self.metadata[name]
        # looked up even when value is None, so a missing unit always raises
        unit = str_to_astropy_unit(self.metadata[f"{name}_unit"])
        if value is None:
            return None
        if name in ["ra", "dec"]:
            return Angle(value, unit).to(u.deg)
        return value * unit

    @property
    def shape(self):
        """Image.data shape"""
        return np.array(self.data.shape)

    @property
    def ra(self):
        """Right-Ascension as an astropy Quantity"""
        return self._from_metadata_with_unit("ra")

    @property
    def dec(self):
        """Declination as an astropy Quantity"""
        return self._from_metadata_with_unit("dec")

    @property
    def exposure(self):
        """Exposure time as an astropy Quantity (None if not in metadata)"""
        if "exposure" in self.metadata:
            return self._from_metadata_with_unit("exposure")
        return None

    @property
    def jd(self):
        """Julian Date of the observation"""
        return self.metadata["jd"]

    @property
    def pixel_scale(self):
        """Pixel scale (or plate scale) as an astropy Quantity"""
        return self._from_metadata_with_unit("pixel_scale")

    @property
    def filter(self):
        """Filter name"""
        return self.metadata["filter"]

    @property
    def gain(self):
        """Gain of the camera (1.0 if not in metadata)"""
        return self.metadata.get("gain", 1.0)

    @property
    def read_noise(self):
        """Read noise of the camera (0.0 if not in metadata)"""
        return self.metadata.get("read_noise", 0.0)

    @property
    def fov(self):
        """RA-DEC field of view of the image in degrees

        Returns
        -------
        astropy.units Quantity
        """
        return np.array(self.shape)[::-1] * self.pixel_scale.to(u.deg)

    @property
    def fits_header(self):
        """Same as :code:`header` (backward compatibility)"""
        return self.header

    @property
    def date(self):
        """datetime of the observation

        Returns
        -------
        datetime.datetime
        """
        return dparser.parse(self.metadata["date"])

    @property
    def night_date(self):
        """date of the night when night started.

        Returns
        -------
        datetime.date
        """
        # TODO: do according to last astronomical twilight?
        return (self.date - timedelta(hours=15)).date()

    def set(self, name: str, value):
        """Set a computed value

        Parameters
        ----------
        name : str
            name of the computed value
        value : any
            value to set
        """
        self.computed[name] = value

    def get(self, name):
        """Get computed value

        Parameters
        ----------
        name : str
            name of the computed value

        Returns
        -------
        any
            computed value

        Raises
        ------
        KeyError
            if ``name`` is not a computed value
        """
        return self.computed[name]

    @property
    def sources(self) -> Sources:
        """Image sources.

        Returns
        -------
        prose.core.source.Sources
        """
        return self._sources if self._sources is not None else Sources()

    @sources.setter
    def sources(self, new_sources):
        if new_sources is None or isinstance(new_sources, Sources):
            self._sources = new_sources
        else:
            self._sources = Sources(np.array(new_sources))

    def cutout(
        self,
        coords: Union[list, tuple, np.ndarray],
        shape: Union[int, tuple],
        wcs: bool = True,
        sources: bool = True,
        reset_index: bool = True,
    ):
        """Return a cutout Image instance.

        Parameters
        ----------
        coords : list, tuple, np.ndarray or int
            cutout center coordinates (x, y). If int, the coordinates of the
            source with that index are used
        shape : tuple or int
            The shape of the cutouts to extract. If int, shape is (shape, shape)
        wcs : bool, optional
            whether to compute and include cutouts WCS (takes more time), by default True
        sources : bool, optional
            whether to compute and include cutouts sources, by default True
        reset_index: bool,
            whether to reset the sources indexes, by default True

        Returns
        -------
        Image
            Image instance with data and catalogs containing cutout data and sources
        """
        if isinstance(shape, (int, float)):
            shape = (shape, shape)
        if isinstance(coords, int):
            coords = self.sources.coords[coords]

        # shape is (ny, nx) whereas coordinates are (x, y)
        half_size_xy = np.array(shape)[::-1] / 2

        cut = astopy_Cutout2D(
            self.data,
            coords,
            shape,
            wcs=self.wcs if wcs else None,
            fill_value=np.nan,
            mode="partial",
        )
        (y0, _), (x0, _) = cut.bbox_original

        # get sources
        new_sources = []
        if sources and len(self.sources) > 0:
            sources_in = np.all(np.abs(self.sources.coords - coords) < half_size_xy, 1)
            for s in self.sources[sources_in]:
                _s = s.copy()
                _s.coords = _s.coords - [x0, y0]
                new_sources.append(_s)

        # keyword arguments: the third positional field of Image is
        # ``catalogs``, so computed values used to be lost here
        image = Image(
            data=cut.data,
            metadata=deepcopy(self.metadata),
            computed=deepcopy(self.computed),
        )
        image._sources = Sources(new_sources)
        image.wcs = cut.wcs
        image.origin = tuple(np.array(cut.bbox_original).T[0][::-1])
        image.catalogs = deepcopy(self.catalogs)

        # shift catalogs to the cutout frame, (x, y) order like the
        # source selection above
        catalog_shift = np.asarray(coords) - half_size_xy
        for catalog in image.catalogs.values():
            catalog[["x", "y"]] -= catalog_shift

        if reset_index:
            for i, s in enumerate(image.sources):
                s.i = i

        return image

    @property
    def wcs(self):
        """astropy.wcs.WCS object built from ``Image.metadata["wcs"]`` (cached)."""
        if self._wcs is None:
            self._wcs = WCS(self.metadata.get("wcs", None))
        return self._wcs

    @wcs.setter
    def wcs(self, new_wcs):
        # NOTE(review): the original nesting was lost; only genuine WCS
        # objects are assumed to be stored and cached - confirm.
        if new_wcs is not None and isinstance(new_wcs, WCS):
            self.metadata["wcs"] = new_wcs.to_header().tostring()
            self._wcs = new_wcs

    @property
    def plate_solved(self):
        """Return whether the image is plate solved."""
        return self.wcs.has_celestial

    def writeto(self, destination: Union[str, Path]):
        """Write image to FITS file (overwriting any existing file)

        Parameters
        ----------
        destination : Union[str, Path]
            destination path
        """
        hdu = fits.PrimaryHDU(
            data=self.data, header=fits.Header(utils.clean_header(self.header))
        )
        hdu.writeto(destination, overwrite=True)

    @property
    def skycoord(self):
        """Astropy SkyCoord object (ICRS frame) based on ``Image.ra`` and ``Image.dec``."""
        return SkyCoord(self.ra, self.dec, frame="icrs")

    def plot_catalog(self, name, color="y", label=False, n=100000):
        """Plot catalog stars

        must be over :py:class:`Image.show` plot

        Parameters
        ----------
        name : str
            catalog name as stored in :py:class:`Image.catalogs`
        color : str, optional
            color of stars markers, by default "y"
        label : bool, optional
            whether to show stars catalogs ids, by default False
        n : int, optional
            currently unused, by default 100000
        """
        assert (
            name in self.catalogs
        ), f"Catalog '{name}' not present, consider using ..."
        catalog = self.catalogs[name]
        x, y = catalog[["x", "y"]].values.T
        labels = catalog["id"].values if label else None
        viz.plot_marks(x, y, labels, color=color)

    def plot_model(self, data, figsize=(5, 5), cmap=None, c="C0", contour=False):
        """
        Plot the image data with the mean profiles of data and model on the
        top and right margins.

        Parameters
        ----------
        data : numpy.ndarray
            The 2D model data to plot.
        figsize : tuple, optional
            The size of the figure, by default (5, 5).
        cmap : str or matplotlib.colors.Colormap, optional
            The colormap to use for the data, by default None.
        c : str, optional
            The color to use for the data profiles, by default "C0".
        contour : bool, optional
            Whether to plot the contours of the model, by default False.

        Returns
        -------
        None
        """
        plt.figure(figsize=figsize)
        grid = gridspec.GridSpec(2, 2, width_ratios=[9, 2], height_ratios=[2, 9])
        grid.update(wspace=0, hspace=0)

        ax = plt.subplot(grid[1, 0])
        axr = plt.subplot(grid[1, 1], sharey=ax)
        axt = plt.subplot(grid[0, 0], sharex=ax)

        ax.imshow(self.data, alpha=1, cmap=cmap, origin="lower")
        if contour:
            ax.contour(data, colors="w", alpha=0.7)

        n_rows, n_cols = data.shape
        cols = np.arange(n_cols)
        # row means are plotted against row indices (column indices were used
        # before, which only worked for square arrays)
        rows = np.arange(n_rows)

        axt.plot(cols, np.mean(self.data, axis=0), c=c, label="data")
        axt.plot(cols, np.mean(data, axis=0), "--", c="k", label="model")
        axt.axis("off")
        axt.legend()

        axr.plot(np.mean(self.data, axis=1), rows, c=c)
        axr.plot(np.mean(data, axis=1), rows, "--", c="k")
        axr.axis("off")

    def asdict(self, image_dtype="int16", low_data=True):
        """Return the dataclass fields of the image as a dictionary.

        Parameters
        ----------
        image_dtype : str, optional
            dtype the data is cast to, ignored if ``low_data`` is True, by
            default "int16"
        low_data : bool, optional
            whether to z-scale the data, multiply it by 127 and cast it to
            int8, by default True

        Returns
        -------
        dict
        """
        # dataclasses.asdict already deep-copies the field values, copying
        # the image beforehand is not needed
        im_dict = asdict(self)
        if im_dict["data"] is not None:
            if low_data:
                im_dict["data"] = utils.z_scale(im_dict["data"]) * (2**7 - 1)
                image_dtype = "int8"
            im_dict["data"] = im_dict["data"].astype(image_dtype)
        return im_dict

    def save(self, filepath, image_dtype="float64", low_data=False):
        """
        Save the image to a file using pickle.

        Note that the pickle will hold the Image dataclass dict attributes.

        Parameters
        ----------
        filepath : str
            The path to the file to save.
        image_dtype : str, optional
            The data type to use for the image data, by default "float64".
        low_data : bool, optional
            Whether to scale the data to a lower range, by default False.

        Returns
        -------
        None
        """
        with open(filepath, "wb") as f:
            pickle.dump(self.asdict(image_dtype=image_dtype, low_data=low_data), f)

    @classmethod
    def load(cls, filepath):
        """
        Load an image from a file written by :code:`Image.save`.

        Parameters
        ----------
        filepath : str
            The path to the file to load.

        Returns
        -------
        Image
            The loaded image.
        """
        # SECURITY: unpickling executes arbitrary code, only load trusted files
        with open(filepath, "rb") as f:
            return cls(**pickle.load(f))

    def _symetric_profile(self, source, binn=1.0):
        """Binned mean radial profile of the data around ``source.coords``.

        Returns the mean distances and mean pixel values of each bin, the
        bins being provided by ``utils.index_binning``.
        """
        x, y = source.coords
        Y, X = np.indices(self.shape)
        radii = (np.sqrt((X - x) ** 2 + (Y - y) ** 2)).flatten()
        d, values = self._profile(radii)
        idxs = utils.index_binning(d, binn)
        return (
            np.array([np.mean(d[i]) for i in idxs]),
            np.array([np.mean(values[i]) for i in idxs]),
        )

    def _profile(self, d):
        """Return ``d`` sorted in ascending order and the flattened data in the same order."""
        order = np.argsort(d)
        return d[order], self.data.flatten()[order]

    def data_cutouts(
        self, sources: Union[np.ndarray, Sources], shape: Union[int, tuple]
    ) -> np.ndarray:
        """Extract cutouts from image.

        Pixels of a cutout falling outside of the image are set to zero.

        Parameters
        ----------
        sources : Union[np.ndarray, Sources]
            Coordinates (or Sources) of the cutouts centers
        shape : Union[int, tuple]
            Shape of the cutouts

        Returns
        -------
        np.ndarray
            cutouts
        """
        if isinstance(sources, Sources):
            sources = sources.coords
        if isinstance(shape, (int, np.integer)):
            shape = (shape, shape)

        cutouts = []
        for x, y in sources:
            c = np.zeros(shape)
            large, small = overlap_slices(self.shape, shape, (y, x))
            c[small] = self.data[large]
            cutouts.append(c)

        return np.array(cutouts)

    def _major_profile(self, source, binn=1.0, debug=False):
        """Binned profile of the data along the line joining ``source.coords``
        and ``source.vertexes[0]``.

        Every pixel is projected on that line; for each bin of projected
        distance (from ``utils.index_binning``) the mean distance and the
        maximum (NaN ignored) pixel value are returned.
        """
        p1 = source.coords[:, None, None]
        p2 = (source.vertexes[0])[:, None, None]
        Y, X = np.indices(self.data.shape)
        p3 = np.array([X, Y])

        # projection
        # https://stackoverflow.com/questions/61341712/calculate-projected-point-location-x-y-on-given-line-startx-y-endx-y
        l2 = np.sum((p1 - p2) ** 2)
        assert l2 != 0, "p1 and p2 are the same points"
        distances = np.sum((p3 - p1) * (p2 - p1), 0) / np.sqrt(l2)

        flat_distance = distances.flatten()
        # flattened once instead of once per bin
        flat_data = self.data.flatten()
        idxs = utils.index_binning(flat_distance, binn)
        distance = np.array([flat_distance[i].mean() for i in idxs])
        values = np.array([np.nanmax(flat_data[i]) for i in idxs])

        if debug:
            # show which bin each pixel belongs to
            D = np.zeros(flat_data.shape)
            for i, j in enumerate(idxs):
                D[j] = i
            plt.figure()
            plt.imshow(np.reshape(D, self.shape), origin="lower")

        return distance, values

    @property
    def label(self):
        """A convenient {Telescope}_{Date}_{Object}_{Filter} string

        Returns
        -------
        str
        """
        return "_".join(
            [
                self.metadata["telescope"],
                self.night_date.strftime("%Y%m%d"),
                self.metadata["object"],
                self.filter,
            ]
        )
def str_to_astropy_unit(unit_string):
    """Return the object named ``unit_string`` in the ``astropy.units`` namespace.

    Raises
    ------
    KeyError
        if ``astropy.units`` has no such name
    """
    units_namespace = vars(u)
    return units_namespace[unit_string]
def FITSImage(
    filepath_or_hdu: Union[str, Path, _BaseHDU],
    verbose: bool = False,
    load_units: bool = True,
    load_data: bool = True,
    telescope: Optional[Telescope] = None,
    skip_wcs: bool = False,
) -> Image:
    """Create an image from a FITS file

    Parameters
    ----------
    filepath_or_hdu : str, Path or HDU
        path of fits file or HDU object
    verbose : bool, optional
        whether to be verbose, by default False
    load_units : bool, optional
        whether to load metadata units, by default True
    load_data : bool, optional
        whether to load image data, by default True
    telescope : Telescope, optional
        telescope used to interpret the header, by default None, in which case
        it is obtained from the "INSTRUME" and "TELESCOP" header keywords
    skip_wcs : bool, optional
        whether to skip WCS loading, by default False

    Returns
    -------
    :py:class:`~prose.Image`

    Raises
    ------
    ValueError
        if ``filepath_or_hdu`` is neither a path nor an HDU
    """
    if isinstance(filepath_or_hdu, (str, Path)):
        values = fits.getdata(filepath_or_hdu).astype(float) if load_data else None
        header = fits.getheader(filepath_or_hdu)
        path = filepath_or_hdu
    elif isinstance(filepath_or_hdu, _BaseHDU):
        # load_data is honoured for HDU inputs too
        values = filepath_or_hdu.data if load_data else None
        header = filepath_or_hdu.header
        path = None
    else:
        raise ValueError(
            "filepath_or_hdu must be a str, a Path or an HDU, "
            f"not {type(filepath_or_hdu)}"
        )

    if telescope is None:
        telescope = Telescope.from_names(
            header.get("INSTRUME", ""), header.get("TELESCOP", ""), verbose=verbose
        )

    metadata = {
        "telescope": telescope.name,
        "exposure": header.get(telescope.keyword_exposure_time, None),
        "ra": header.get(telescope.keyword_ra, None),
        "dec": header.get(telescope.keyword_dec, None),
        "filter": header.get(telescope.keyword_filter, None),
        "date": telescope.date(header).isoformat(),
        "jd": header.get(telescope.keyword_jd, None),
        "object": header.get(telescope.keyword_object, None),
        "pixel_scale": telescope.pixel_scale,
        "gain": telescope.gain,
        "read_noise": telescope.read_noise,
        "overscan": telescope.trimming[::-1],
        "path": path,
        "dimensions": (header.get("NAXIS1", 1), header.get("NAXIS2", 1)),
        "type": telescope.image_type(header),
    }

    if load_units:
        metadata.update(
            {
                "exposure_unit": "s",
                "ra_unit": telescope.ra_unit,
                "dec_unit": telescope.dec_unit,
                "jd_scale": telescope.jd_scale,
                "pixel_scale_unit": "arcsec",
            }
        )

    image = Image(data=values, metadata=metadata, catalogs={})

    # fall back on the observation date when the header holds no JD
    if image.metadata["jd"] is None:
        image.metadata["jd"] = Time(image.date).jd

    image.header = header
    if not skip_wcs:
        image.wcs = WCS(header)
    # not a declared attribute of Image, hence stored in image.computed
    image.telescope = telescope

    return image
class Buffer:
    """Sliding window over a list of items, loaded on the fly."""

    def __init__(self, size: int, loader: callable = None):
        """Object to load and access adjacent items in a list

        Parameters
        ----------
        size : int
            number of items accessible, must be odd
        loader : callable, optional
            a function that load an item in the buffer, by default None corresponding
            to lambda x: x. It is never called on ``None`` items (such as the
            padding added at the end of the items).

        Example
        -------

        .. code-block:: python

            from prose.core.image import Buffer
            import numpy as np

            # items to be loaded in the buffer
            init = np.arange(0, 10)

            # create and initialize
            buffer = Buffer(size=3)
            buffer.init(init)

            for b in buffer:
                print(b.previous, b.current, b.next)

        .. code-block:: text

            None 0 1
            0 1 2
            1 2 3
            2 3 4
            3 4 5
            4 5 6
            5 6 7
            6 7 8
            7 8 9
            8 9 None
        """
        assert size % 2 == 1, "size must be odd"
        # position of the current item in self.items
        self.mid_index = int((size - 1) // 2)
        self.items = [None] * max(size, 1)
        if loader is None:
            loader = lambda item: item
        self.loader = loader
        self.queue = None  # items to be loaded

    def __len__(self):
        return len(self.items)

    def __getitem__(self, i: int):
        """Get item by index relative to current

        Parameters
        ----------
        i : int
            index

        Returns
        -------
        Image or None
            images[current + i]
        """
        return self.items[self.mid_index + i]

    def __setitem__(self, i: int, item: "Image"):
        self.items[self.mid_index + i] = item

    def _load(self, item):
        """Apply the loader to ``item``, leaving ``None`` (padding) untouched."""
        return None if item is None else self.loader(item)

    def append(self, item):
        """Add an item to the buffer (and delete the oldest one)

        Parameters
        ----------
        item : any
            item to be loaded
        """
        self.items.pop(0)
        self.items.append(item)

    def init(self, items):
        """Prepare items to be loaded in the buffer.

        The first items are loaded with the :code:`Buffer.loader` function

        Parameters
        ----------
        items : list
            items to be loaded in the buffer
        """
        # start from an empty window, so that a buffer can be re-initialized
        self.items = [None] * len(self.items)
        for item in items[: self.mid_index]:
            self.append(self._load(item))
        # None padding lets the last items reach the current position
        self.queue = [*items[self.mid_index :], *[None] * self.mid_index]

    def __iter__(self):
        """Load the queued items one by one, yielding the buffer itself after each."""
        for item in self.queue:
            self.append(self._load(item))
            yield self

    def sub(self, size, offset):
        """Not implemented (placeholder)."""
        pass

    @property
    def previous(self):
        """Item just before the current one"""
        return self[-1]

    @property
    def current(self):
        """Item at the center of the buffer"""
        return self[0]

    @property
    def next(self):
        """Item just after the current one"""
        return self[1]