import copy
from dataclasses import dataclass
from typing import Literal, Union
import matplotlib.pyplot as plt
import numpy as np
from astropy.utils import lazyproperty
from matplotlib.patches import Circle, Ellipse
from photutils.aperture import *
from photutils.isophote import Ellipse as IsoEllipse
from photutils.isophote import EllipseGeometry
# Default RGB color (components in [0, 1]) used by the plotting helpers
color = [0.51, 0.86, 1.0]

# Public API of this module
__all__ = [
    "Source", "PointSource", "ExtendedSource", "TraceSource",
    "auto_source", "Sources",
]
def distance(p1, p2):
    """Euclidean distance between two points, using their first two coordinates.

    Parameters
    ----------
    p1, p2 : sequence
        points given as (x, y, ...) sequences; only indices 0 and 1 are used

    Returns
    -------
    float
    """
    dx = p1[0] - p2[0]
    dy = p1[1] - p2[1]
    return np.sqrt(np.square(dx) + np.square(dy))
def clean_stars_positions(positions, tolerance=50):
    """Indices of the positions that have no other position closer than ``tolerance``.

    Parameters
    ----------
    positions : array-like
        (n, 2) array-like of (x, y) positions; only the first two columns
        are used
    tolerance : float, optional
        a position is dropped when another position lies strictly closer
        than this distance (same units as ``positions``), by default 50

    Returns
    -------
    np.ndarray
        sorted integer indices of the isolated positions
    """
    xy = np.asarray(positions, dtype=float)
    if xy.size == 0:
        # No positions: return an empty *integer* index array
        return np.array([], dtype=int)
    xy = xy[:, :2]

    # Pairwise Euclidean distances in one vectorized step (n x n)
    delta = xy[:, None, :] - xy[None, :, :]
    pairwise = np.sqrt(np.sum(delta**2, axis=-1))
    # A position is never its own neighbour
    np.fill_diagonal(pairwise, np.inf)

    has_neighbour = np.any(pairwise < tolerance, axis=1)
    return np.flatnonzero(~has_neighbour)
# Note: Why not use photutils.segmentation.SourceCatalog?
# source: https://photutils.readthedocs.io/en/stable/api/photutils.segmentation.SourceCatalog.html#photutils.segmentation.SourceCatalog
#
# The main reason is full control, without needing to subclass SourceCatalog. Reasons:
# - Ability to Source.plot and Source.aperture differently depending on the type of source
# - Ability to easily instantiate a fake/incomplete source only defined by its coords (output of many detection algorithms like DAOPHOT)
# - We will use it as a region, so that users have access to it if needed
# - I don't like the as_scalar behavior; I prefer separate Source and Sources classes
@dataclass
class Source:
    """An object containing a source's information.

    This is a Python data class, so most attributes described below can be
    used as keyword arguments when instantiating it.
    """

    a: float = 1.0
    """Semi-major axis of the source"""
    b: float = 1.0
    """Semi-minor axis of the source"""
    orientation: float = 0.0
    """Orientation of the source in radians"""
    coords: np.ndarray = None
    """(x,y) pixel coordinates of the source"""
    peak: float = 0.0
    """Peak ADU value of the source"""
    i: int = None
    """Index of the source"""
    discarded: bool = False
    """Whether source is discarded"""

    @classmethod
    def from_region(cls, region, keep_region: bool = False, **kwargs):
        """Source from region

        Parameters
        ----------
        region : skimage.measure.RegionProperties
            An skimage RegionProperties containing the source
        keep_region : bool, optional
            whether to keep the region object in the source (as
            ``Source._region``). This is required by :meth:`fit_isophotes`,
            :meth:`centroid_isophote` and :meth:`centroid_max`.
        **kwargs:
            other source attributes to set

        Returns
        -------
        Source
        """
        source = cls(
            a=region.axis_major_length / 2,
            b=region.axis_minor_length / 2,
            orientation=np.pi / 2 - region.orientation,
            # region centroid is (row, col); reversed to (x, y)
            coords=np.array(region.centroid_weighted[::-1]),
            peak=region.intensity_max,
            **kwargs,
        )
        if keep_region:
            source._region = region
        return source

    def _get_region(self):
        """Return the region kept by ``from_region(..., keep_region=True)``.

        Raises
        ------
        AttributeError
            if no region is attached to the source
        """
        region = getattr(self, "_region", None)
        if region is None:
            raise AttributeError(
                "this source has no region attached; create it with "
                "Source.from_region(region, keep_region=True)"
            )
        return region

    @property
    def vertexes(self):
        """Coordinates of the Ellipse vertexes, endpoints of the major axis

        Returns
        -------
        np.array
            (2, 2) array, one (x, y) row per vertex
        """
        theta = self.orientation
        shifts = np.array([np.cos(theta), np.sin(theta)]) * self.a
        return self.coords + (shifts[:, None] * [-1, 1]).T

    @property
    def co_vertexes(self):
        """Coordinates of the Ellipse co-vertexes, endpoints of the minor axis

        Returns
        -------
        np.array
            (2, 2) array, one (x, y) row per co-vertex
        """
        theta = self.orientation + np.pi / 2
        shifts = np.array([np.cos(theta), np.sin(theta)]) * self.b
        return self.coords + (shifts[:, None] * [-1, 1]).T

    @lazyproperty
    def eccentricity(self):
        """Axis ratio ``b/a`` of the source.

        Despite its name, this is the minor-to-major axis ratio (1 for a
        circle), not the geometric eccentricity. Being a ``lazyproperty``, it
        is computed on first access and then cached, so it is not refreshed
        if ``a`` or ``b`` are changed afterwards.

        Returns
        -------
        float
        """
        return self.b / self.a

    def copy(self):
        """Return a copy of the Source

        Every instance attribute is carried over, including ``discarded``
        and a kept ``_region``. ``coords`` is copied, so moving the copy
        does not move the original.

        Returns
        -------
        Source
            copy
        """
        # Bypass __init__ so subclasses with custom constructors are supported
        new = self.__class__.__new__(self.__class__)
        new.__dict__.update(self.__dict__)
        if new.coords is not None:
            # ``copy`` here is the module imported at file level
            new.coords = copy.copy(new.coords)
        return new

    def __copy__(self):
        return self.copy()

    def plot_circle(self, radius, c=color, ax=None, label=True, fontsize=12, **kwargs):
        """Plot a circle centered on source

        Parameters
        ----------
        radius : float
            radius of the circle in pixels
        c : str, optional
            color of the circle, by default color
        ax : Axe, optional
            pyplot axe in which to plot the circle, by default None (current axe)
        label : bool, optional
            whether to display the Source.i index, by default True
        fontsize : int, optional
            Font size for the source index, by default 12
        **kwargs:
            passed to matplotlib.patches.Circle
        """
        if ax is None:
            ax = plt.gca()
        circle = Circle(self.coords, radius, fill=None, ec=c, **kwargs)
        ax.add_artist(circle)
        if label and self.i is not None:
            # Label drawn below the circle, on the same axe
            ax.text(
                *(np.array(self.coords) - [0, 1.5 * radius]),
                self.i,
                c=c,
                ha="center",
                va="top",
                fontsize=fontsize,
            )

    def plot_ellipse(self, a=None, c=color, ax=None, label=True, fontsize=12, **kwargs):
        """Plot an ellipse centered on source, shaped after the source itself

        Parameters
        ----------
        a : float, optional
            full width (major-axis length) of the plotted ellipse in pixels;
            its height is ``a * Source.eccentricity``. By default
            ``2.2 * Source.a``.
        c : str, optional
            color of the ellipse, by default color
        ax : Axe, optional
            pyplot axe in which to plot the ellipse, by default None (current axe)
        label : bool, optional
            whether to display the Source.i index, by default True
        fontsize : int, optional
            Font size for the source index, by default 12
        **kwargs:
            passed to matplotlib.patches.Ellipse
        """
        if ax is None:
            ax = plt.gca()
        if a is None:
            a = 2 * self.a * 1.1
        e = Ellipse(
            xy=self.coords,
            width=a,
            height=a * self.eccentricity,
            angle=np.rad2deg(self.orientation),
            **kwargs,
        )
        e.set_facecolor("none")
        e.set_edgecolor(c)
        ax.add_artist(e)
        if label and self.i is not None:
            rad = self.orientation
            label_coord = self.coords + [0, -(np.abs(self.a * rad) + self.b)]
            ax.text(
                *label_coord, self.i, c=c, ha="center", va="top", fontsize=fontsize
            )

    def circular_aperture(self, r, scale=True):
        """`photutils.aperture.CircularAperture` centered on the source

        Parameters
        ----------
        r : float
            radius
        scale : bool, optional
            whether to scale r to Source.a, by default True

        Returns
        -------
        photutils.aperture.CircularAperture
        """
        radius = r * self.a if scale else r
        return CircularAperture(self.coords, float(np.abs(radius)))

    def elliptical_aperture(self, r, scale=True):
        """`photutils.aperture.EllipticalAperture` centered on the source

        Parameters
        ----------
        r : float
            semi-major axis of the aperture. Semi minor will be `r*Source.b/Source.a`
        scale : bool, optional
            whether to scale r to Source.a, by default True

        Returns
        -------
        photutils.aperture.EllipticalAperture
        """
        if scale:
            a, b = r * self.a, r * self.b
        else:
            a, b = r, r * self.eccentricity
        return EllipticalAperture(self.coords, a, b, self.orientation)

    def rectangular_aperture(self, r, scale=True):
        """`photutils.aperture.RectangularAperture` centered on the source

        Parameters
        ----------
        r : float
            half-width of the aperture along the major axis (the rectangle
            is ``2*r`` wide, times Source.a if ``scale``). The height follows
            the source axis ratio. Sizes are clamped to at least 0.01.
        scale : bool, optional
            whether to scale r to Source.a, by default True

        Returns
        -------
        photutils.aperture.RectangularAperture
        """
        if scale:
            a, b = 2 * r * self.a, 2 * r * self.b
        else:
            a, b = 2 * r, 2 * r * self.eccentricity
        a = np.max([0.01, a])
        b = np.max([0.01, b])
        return RectangularAperture(
            self.coords, float(np.abs(a)), float(np.abs(b)), self.orientation
        )

    def circular_annulus(self, r0, r1, scale=False):
        """`photutils.aperture.CircularAnnulus` centered on the source

        Parameters
        ----------
        r0, r1 : float
            inner and outer radii
        scale : bool, optional
            whether to scale r0 and r1 to Source.a, by default False

        Returns
        -------
        photutils.aperture.CircularAnnulus
        """
        if scale:
            r0, r1 = r0 * self.a, r1 * self.a
        return CircularAnnulus(self.coords, r0, r1)

    def elliptical_annulus(self, r0, r1, scale=False):
        """`photutils.aperture.EllipticalAnnulus` centered on the source

        Parameters
        ----------
        r0, r1 : float
            inner and outer semi-major axes; the outer semi-minor axis
            follows the source axis ratio
        scale : bool, optional
            whether to scale r0 and r1 to Source.a, by default False

        Returns
        -------
        photutils.aperture.EllipticalAnnulus
        """
        if scale:
            a0 = r0 * self.a
            a1, b1 = r1 * self.a, r1 * self.b
        else:
            a0 = r0
            a1, b1 = r1, r1 * self.eccentricity
        return EllipticalAnnulus(self.coords, a0, a1, b1, theta=self.orientation)

    def rectangular_annulus(self, r0, r1, scale=False):
        """`photutils.aperture.RectangularAnnulus` centered on the source

        Widths are ``2*r0`` (inner) and ``2*r1`` (outer), times Source.a if
        ``scale``. This matches :meth:`rectangular_aperture`, so an annulus
        with ``r0 >= r`` lies outside the aperture of half-width ``r``. The
        outer height follows the source axis ratio.

        Parameters
        ----------
        r0, r1 : float
            inner and outer half-widths along the major axis
        scale : bool, optional
            whether to scale r0 and r1 to Source.a, by default False

        Returns
        -------
        photutils.aperture.RectangularAnnulus
        """
        if scale:
            a0 = 2 * r0 * self.a
            a1, b1 = 2 * r1 * self.a, 2 * r1 * self.b
        else:
            a0 = 2 * r0
            a1, b1 = 2 * r1, 2 * r1 * self.eccentricity
        # Keep sizes strictly positive and the outer width above the inner one
        a0 = np.max([0.01, a0])
        a1 = np.max([a0 + 0.001, a1])
        b1 = np.max([0.01, b1])
        return RectangularAnnulus(self.coords, a0, a1, b1, theta=self.orientation)

    def fit_isophotes(self, debug=False):
        """Fit a photutils.isophote.Ellipse to the source.

        Requires the source to be created with
        ``Source.from_region(region, keep_region=True)``.

        Parameters
        ----------
        debug : bool, optional
            whether to plot the result for debugging, by default False

        Returns
        -------
        output of photutils.isophote.Ellipse.fit_image

        Raises
        ------
        AttributeError
            if no region is attached to the source
        """
        region = self._get_region()
        data = region.image_intensity
        # Initial center guess: brightest pixel of the region cutout
        y0, x0 = np.unravel_index(np.argmax(data), data.shape)
        geometry = EllipseGeometry(
            x0,
            y0,
            sma=self.a / 2,
            # photutils expects the ellipticity 1 - b/a, while
            # ``eccentricity`` holds the axis ratio b/a
            eps=1 - self.eccentricity,
            pa=self.orientation,
        )
        ellipse = IsoEllipse(data - np.median(data), geometry)
        isolist = ellipse.fit_image()
        if debug:
            plt.imshow(data)
            for sma in np.linspace(3, 20, 15):
                iso = isolist.get_closest(sma)
                x, y = iso.sampled_coordinates()
                plt.plot(x, y, color="white")
        return isolist

    @property
    def _symbol(self):
        return "?"

    @property
    def _desc(self):
        # "<symbol> <ClassName>", followed by the index when there is one
        desc = f"{self._symbol} {self.__class__.__name__}"
        if self.i is not None:
            desc += f" {self.i}"
        return desc

    def _repr_dict(self, n=8):
        return {
            "coords": f"{self.coords[0]:.2f}".rjust(n)
            + f"{self.coords[1]:.2f}".rjust(n),
            "a, b": f"{self.a:.2f}".rjust(n) + f"{self.b:.2f}".rjust(n),
            "e": f"{self.b/self.a:.2f}".rjust(n),
        }

    def __str__(self):
        table = "\n".join(
            [f"  {n}".ljust(8) + f"{v}" for n, v in self._repr_dict().items()]
        )
        return f"{self._desc}\n  {'-'*(len(self._desc)-2)}\n{table}"

    def centroid_isophote(self):
        """Centroid of the first isophote of :meth:`fit_isophotes`, as image (x, y).

        Raises
        ------
        AttributeError
            if no region is attached to the source
        """
        isolist = self.fit_isophotes()
        # bbox starts with (min_row, min_col); reversed into an (x, y) offset
        origin = np.array(self._get_region().bbox)[0:2][::-1]
        return np.array([isolist[0].x0, isolist[0].y0]) + origin

    def centroid_max(self):
        """Position of the brightest pixel of the kept region, as image (x, y).

        Raises
        ------
        AttributeError
            if no region is attached to the source
        """
        region = self._get_region()
        y0, x0 = np.unravel_index(
            np.argmax(region.image_intensity), region.image.shape
        )
        dy, dx, _, _ = region.bbox
        return np.array([x0 + dx, y0 + dy])

    @property
    def area(self):
        """Area of the source as :code:`a*b`

        Returns
        -------
        float
        """
        return self.a * self.b
def auto_source(region, i=None, trace=0.3, extended=0.9, discard=False):
    """Build the Source subclass matching the shape of a region.

    The shape is classified with the axis ratio ``b/a`` of the region
    (minor over major axis length):

    - ``b/a <= extended`` and ``b/a <= trace``: :class:`TraceSource`
    - ``b/a <= extended`` otherwise: :class:`ExtendedSource`
    - ``b/a > extended``: :class:`PointSource`

    Parameters
    ----------
    region : skimage.measure.RegionProperties or None
        region of the source; ``None`` yields a :class:`DiscardedSource`
    i : int, optional
        index of the source
    trace : float, optional
        maximum axis ratio of a trace, by default 0.3
    extended : float, optional
        maximum axis ratio of an extended source, by default 0.9
    discard : bool, optional
        whether a region with a null major axis becomes a DiscardedSource
        (otherwise a PointSource), by default False

    Returns
    -------
    Source
    """
    if region is None:
        # Nothing to measure: ``from_region`` cannot be used on a None region
        return DiscardedSource(region, i=i)

    a = region.axis_major_length
    b = region.axis_minor_length

    if a == 0.0:
        # Degenerate region: the axis ratio is undefined
        if discard:
            return DiscardedSource.from_region(region, i=i)
        return PointSource.from_region(region, i=i)

    ratio = b / a
    if ratio <= extended:
        if ratio <= trace:
            return TraceSource.from_region(region, i=i)
        return ExtendedSource.from_region(region, i=i)
    return PointSource.from_region(region, i=i)
class DiscardedSource(Source):
    """Source flagged as discarded (``discarded`` is always set to True).

    It can be built like any other :class:`Source` (keyword attributes, or
    :meth:`Source.from_region`), or from a positional ``region`` that may be
    ``None`` when no region is available.
    """

    def __init__(self, region=None, i=None, **kwargs):
        """
        Parameters
        ----------
        region : skimage.measure.RegionProperties or None, optional
            region to take the source geometry from; ignored when None
        i : int, optional
            index of the source
        **kwargs:
            other source attributes to set; they take precedence over the
            values derived from ``region``
        """
        if region is not None:
            # Same geometry extraction as Source.from_region
            kwargs.setdefault("a", region.axis_major_length / 2)
            kwargs.setdefault("b", region.axis_minor_length / 2)
            kwargs.setdefault("orientation", np.pi / 2 - region.orientation)
            kwargs.setdefault("coords", np.array(region.centroid_weighted[::-1]))
            kwargs.setdefault("peak", region.intensity_max)
        super().__init__(i=i, **kwargs)
        self.discarded = True

    def plot(self, ms=15, c="C0", ax=None, **kwargs):
        """Mark the source position with a cross.

        Nothing is drawn when the source has no coordinates.

        Parameters
        ----------
        ms : int, optional
            marker size, by default 15
        c : str, optional
            marker color, by default "C0"
        ax : Axe, optional
            pyplot axe in which to plot, by default None (current axe)
        **kwargs:
            passed to ``Axes.plot``
        """
        if self.coords is None:
            return
        if ax is None:
            ax = plt.gca()
        ax.plot(*self.coords, "x", c=c, ms=ms, **kwargs)
class PointSource(Source):
    """Point source (star)"""

    @property
    def _symbol(self):
        # Bullet character (U+2022)
        return chr(8226)

    def plot(self, radius=15, **kwargs):
        """Plot circle centered on source

        Parameters
        ----------
        radius : int, optional
            radius in pixels, by default 15
        **kwargs:
            passed to :meth:`Source.plot_circle`
        """
        self.plot_circle(radius, **kwargs)

    def aperture(self, r=1, scale=True):
        """Circular aperture centered on the source (see :meth:`Source.circular_aperture`)."""
        return self.circular_aperture(r, scale=scale)

    def annulus(self, r0=1.05, r1=1.4, scale=True):
        """Circular annulus centered on the source (see :meth:`Source.circular_annulus`)."""
        return self.circular_annulus(r0, r1, scale=scale)
class ExtendedSource(Source):
    """Extended source (comet, galaxy or lensed source)"""

    @property
    def _symbol(self):
        # Ellipse character (U+2B2D)
        return chr(11053)

    def plot(self, radius=None, **kwargs):
        """Plot Ellipse on source

        Parameters
        ----------
        radius : float, optional
            full width of the plotted ellipse in pixels (passed as ``a`` to
            :meth:`Source.plot_ellipse`), by default None, i.e. 2.2 times
            the source semi-major axis
        **kwargs:
            passed to :meth:`Source.plot_ellipse`
        """
        self.plot_ellipse(radius, **kwargs)

    def aperture(self, r=1, scale=True):
        """Elliptical aperture centered on the source (see :meth:`Source.elliptical_aperture`)."""
        return self.elliptical_aperture(r, scale=scale)

    def annulus(self, r0=1.05, r1=1.4, scale=True):
        """Elliptical annulus centered on the source (see :meth:`Source.elliptical_annulus`)."""
        return self.elliptical_annulus(r0, r1, scale=scale)
class TraceSource(Source):
    """Trace source (diffracted spectrum, satellite streak or cosmic ray)"""

    def plot(self, offset=10, ax=None, c=color, label=True, fontsize=12):
        """Plot the major axis of the source as a segment between its vertexes

        Parameters
        ----------
        offset : float, optional
            vertical distance in pixels between the source center and its
            label, by default 10
        ax : Axe, optional
            pyplot axe in which to plot, by default None (current axe)
        c : str, optional
            color of the segment and label, by default color
        label : bool, optional
            whether to display the Source.i index, by default True
        fontsize : int, optional
            Font size for the source index, by default 12
        """
        if ax is None:
            ax = plt.gca()
        ax.plot(*self.vertexes.T, c=c)
        if label and self.i is not None:
            label_coords = self.coords + [0, -offset]
            # Drawn on the same axe as the segment (not the current axe)
            ax.text(
                *label_coords, self.i, c=c, ha="center", va="top", fontsize=fontsize
            )

    def aperture(self, r=1, scale=True):
        """Rectangular aperture centered on the source (see :meth:`Source.rectangular_aperture`)."""
        return self.rectangular_aperture(r, scale=scale)

    def annulus(self, r0=1.05, r1=1.4, scale=True):
        """Rectangular annulus centered on the source (see :meth:`Source.rectangular_annulus`)."""
        return self.rectangular_annulus(r0, r1, scale=scale)
@dataclass
class Sources:
    """Collection of sources, stored as a numpy object array.

    ``sources`` may be a list (or object array) of source objects, or a
    numeric array of (x, y) coordinates, each row of which is turned into a
    ``PointSource``.
    """

    sources: list = None
    """List of sources"""
    type: Literal["PointSource", None] = None
    """Source type"""
    source_type: Literal["PointSource", None] = None
    """Legacy source type"""

    def __post_init__(self):
        if self.sources is None:
            self.sources = []

        # Honour the legacy ``source_type`` without overriding an explicit ``type``
        if self.type is None:
            self.type = self.source_type

        if isinstance(self.sources, np.ndarray) and self.sources.dtype != object:
            # A numeric array is read as one (x, y) row per point source
            self.sources = [
                PointSource(coords=s, i=i) for i, s in enumerate(self.sources)
            ]
            self.type = "PointSource"

        if self.type is not None:
            for s in self.sources:
                # Explicit raise (not ``assert``) so the check survives
                # ``python -O``; AssertionError kept for compatibility
                if s.__class__.__name__ != self.type:
                    raise AssertionError(f"list can only contain {self.type}")

        # dtype=object keeps an empty collection an object array too
        self.sources = np.array(self.sources, dtype=object)

    def __getitem__(self, i):
        """A single source for a scalar index, otherwise a new Sources of the same ``type``."""
        if np.isscalar(i):
            return self.sources[int(i)]
        return self.__class__(self.sources[i], type=self.type)

    def __len__(self):
        return len(self.sources)

    def __str__(self):
        return str(self.sources)

    def __repr__(self):
        return self.sources.__repr__()

    def copy(self):
        """Deep copy of the collection (the sources themselves are copied)."""
        # ``copy`` here is the module imported at file level
        return copy.deepcopy(self)

    def __copy__(self):
        return self.copy()

    @property
    def coords(self):
        """Array of the sources' coordinates, one row per source."""
        return np.array([source.coords for source in self.sources])

    @coords.setter
    def coords(self, new_coords):
        for source, new_coord in zip(self.sources, new_coords):
            source.coords = new_coord

    def apertures(self, r, scale=False):
        """Apertures of radius/size ``r`` around every source.

        For an unscaled PointSource collection, all sources share the same
        radius, so a single multi-position ``CircularAperture`` is returned.
        Otherwise, a list of per-source apertures is returned.
        """
        if self.type == "PointSource" and not scale:
            return CircularAperture(self.coords, r)
        return [source.aperture(r, scale=scale) for source in self.sources]

    def annulus(self, rin, rout, scale=False):
        """Annuli between ``rin`` and ``rout`` around every source.

        For an unscaled PointSource collection, a single multi-position
        ``CircularAnnulus`` is returned; otherwise, a list of per-source annuli.
        """
        if self.type == "PointSource" and not scale:
            return CircularAnnulus(self.coords, rin, rout)
        return [source.annulus(rin, rout, scale=scale) for source in self.sources]

    def plot(self, *args, **kwargs):
        """Call ``plot`` on every source with the given arguments."""
        for s in self.sources:
            s.plot(*args, **kwargs)