Source code for prose.blocks.geometry

from typing import Optional, Union

import numpy as np
from scipy.spatial import cKDTree
from skimage.transform import AffineTransform
from twirl import find_transform, quads
from twirl.geometry import get_transform_matrix, pad
from twirl.match import count_cross_match

from prose.core import Block, Image
from prose.utils import cross_match

__all__ = [
    "Trim",
    "Cutouts",
    "Drizzle",
    "ComputeTransformXYShift",
    "ComputeTransformTwirl",
]


class Trim(Block):
    """Image trimming

    If trim is not specified, the trimming is taken from the "overscan" entry in
    the image metadata.

    |write| ``Image.header``

    |modify|

    Parameters
    ----------
    skip_wcs : bool, optional
        whether to skip applying the trim to the WCS (if None, the WCS is skipped
        only if the image is not plate solved), by default None
    trim : tuple, int or float, optional
        (x, y) trim values, by default None, which uses the ``trim`` value from the
        image telescope definition. If an int or a float is provided, the same trim
        is applied to both axes.
    """

    def __init__(self, trim=None, skip_wcs=None, name=None, verbose=False):
        super().__init__(name, verbose)
        assert (skip_wcs is None) or isinstance(
            skip_wcs, bool
        ), "skip_wcs must be None or bool"
        self.skip_wcs = skip_wcs
        if isinstance(trim, (int, float)):
            trim = (trim, trim)
        self.trim = trim
        self._parallel_friendly = True

    def run(self, image):
        trim = self.trim if self.trim is not None else image.metadata["overscan"]
        center = image.shape[::-1] / 2
        shape = image.shape - 2 * np.array(trim)
        if self.skip_wcs is None:
            skip_wcs = not image.plate_solved
        else:
            skip_wcs = self.skip_wcs
        # compute the cutout WCS only when trimming the WCS is not skipped
        cutout = image.cutout(center, shape, wcs=not skip_wcs)
        image.data = cutout.data
        image.sources = cutout.sources
        image.wcs = cutout.wcs
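A minimal usage sketch for ``Trim`` (not part of the original module), assuming ``image`` is a ``prose.Image`` whose metadata carries an ``"overscan"`` entry::

    trim_block = Trim(trim=16)  # same trim applied to both axes
    trim_block.run(image)       # image.data, image.sources and image.wcs now hold the trimmed cutout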
class Cutouts(Block):
    def __init__(
        self,
        shape: Union[int, tuple] = 50,
        wcs: bool = False,
        name: Optional[str] = None,
        sources: bool = False,
    ):
        """Create cutouts around all sources

        |read| :code:`Image.sources`

        |write| :code:`Image.cutouts`

        Parameters
        ----------
        shape : int or tuple, optional
            cutout shape, by default 50
        wcs : bool, optional
            whether to compute cutouts WCS, by default False
        name : str, optional
            name of the block, by default None
        sources : bool, optional
            whether to keep sources in cutouts, by default False
        """
        super().__init__(name=name, read=["sources"])
        if isinstance(shape, int):
            shape = (shape, shape)
        self.shape = shape
        self.wcs = wcs
        self.sources = sources
        self._parallel_friendly = True

    def run(self, image: Image):
        image.cutouts = [
            image.cutout(coords, self.shape, wcs=self.wcs, sources=self.sources)
            for coords in image.sources.coords
        ]
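``Cutouts`` only reads ``Image.sources``, so it is typically run after a source-detection block. A sketch (illustrative names), assuming ``image.sources`` is already populated::

    cutouts_block = Cutouts(shape=21)   # 21x21 pixel cutouts, no WCS
    cutouts_block.run(image)
    small_images = image.cutouts        # one cutout per detected source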
# TODO: delete?
class _SetAffineTransform(Block):
    def __init__(self, name=None, verbose=False):
        super().__init__(name, verbose)
        self._parallel_friendly = True

    def run(self, image):
        rotation = image.__dict__.get("rotation", 0)
        translation = image.__dict__.get("translation", (0, 0))
        scale = image.__dict__.get("scale", 0)
        image.transform = AffineTransform(
            rotation=rotation, translation=translation, scale=scale
        )

    @property
    def citations(self):
        return super().citations + ["scikit-image"]
class Drizzle(Block):
    def __init__(self, reference, pixfrac=1.0, **kwargs):
        """Produce a drizzled (stacked) image from dithered exposures. Requires the
        :code:`drizzle` package.

        All images (including the reference) must be plate-solved. After
        :code:`terminate` is called (e.g. once a sequence has been fully run), the
        drizzled image can be found in :code:`self.image`.

        Parameters
        ----------
        reference : prose.Image
            Reference image on which the stacking is based
        pixfrac : float, optional
            fraction of the pixel used in drizzling, by default 1.
        """
        from drizzle import drizzle

        super().__init__(**kwargs)
        self.pixfrac = pixfrac
        reference.wcs.pixel_shape = reference.shape
        self.drizzle = drizzle.Drizzle(outwcs=reference.wcs, pixfrac=pixfrac)
        self.image = reference.copy()

    def run(self, image):
        self.drizzle.add_image(image.data, image.wcs)

    def terminate(self):
        self.image.data = self.drizzle.outsci
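``Drizzle`` follows an accumulate-then-terminate pattern. A sketch, assuming ``reference`` and ``images`` are plate-solved ``prose.Image`` objects obtained elsewhere::

    drizzle_block = Drizzle(reference, pixfrac=0.8)
    for im in images:
        drizzle_block.run(im)      # each exposure is added onto the reference grid
    drizzle_block.terminate()      # copies the drizzled array into drizzle_block.image.data
    stacked = drizzle_block.image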
class ComputeTransformTwirl(Block):
    """
    Compute the transformation of an image to a reference image

    |read| :code:`Image.sources` on both reference and input image

    |write| :code:`Image.transform`

    Parameters
    ----------
    reference_image : Image
        Image containing detected sources
    n : int, optional
        Number of stars to consider to compute the transformation, by default 10
    rtol : float, optional
        Radius used to match quad hashes between the image and the reference,
        by default 0.02

    Raises
    ------
    SingularMatrix
        Transformation matrix could not be computed. Check the sources in both the
        reference and input image.
    """

    def __init__(self, reference_image: Image, n=10, rtol=0.02, **kwargs):
        super().__init__(**kwargs)
        ref_coords = reference_image.sources.coords
        self.ref = ref_coords[0:n].copy()
        self.n = n
        self._parallel_friendly = True

        # twirl
        self.quads_ref, self.asterisms_ref = quads.hashes(ref_coords[0:n])
        self.tree_ref = cKDTree(self.quads_ref)
        self.rtol = rtol

    def run(self, image):
        if len(image.sources.coords) >= 5:
            result = self.solve(image.sources.coords)
            if result is not None:
                image.transform = AffineTransform(result).inverse
            else:
                image.discard = True
        else:
            image.discard = True

    def solve(self, coords, tolerance=2, refine=True):
        quads_image, asterisms_image = quads.hashes(coords)
        tree_image = cKDTree(quads_image)
        min_match = 0.7

        ball_query = tree_image.query_ball_tree(self.tree_ref, r=self.rtol)
        pairs = []
        for i, j in enumerate(ball_query):
            if len(j) > 0:
                pairs += [[i, k] for k in j]

        matches = []
        for i, j in pairs:
            M = get_transform_matrix(self.asterisms_ref[j], asterisms_image[i])
            test = (M @ pad(self.ref).T)[0:2].T
            match = count_cross_match(coords, test, tolerance)
            matches.append(match)
            if min_match is not None:
                if isinstance(min_match, float):
                    if match >= min_match * len(coords):
                        break

        if len(matches) == 0:
            return None
        else:
            i, j = pairs[np.argmax(matches)]
            if refine:
                M = get_transform_matrix(self.asterisms_ref[j], asterisms_image[i])
                test = (M @ pad(self.ref).T)[0:2].T
                s1, s2 = cross_match(
                    coords, test, tolerance=tolerance, return_idxs=True
                ).T
                M = get_transform_matrix(self.ref[s2], coords[s1])
            else:
                M = get_transform_matrix(self.asterisms_ref[j], asterisms_image[i])

        return M
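A sketch of aligning an image against a reference with ``ComputeTransformTwirl``; both images are assumed to have their sources already detected::

    align = ComputeTransformTwirl(reference, n=12, rtol=0.02)
    align.run(image)
    if not getattr(image, "discard", False):
        transform = image.transform  # skimage AffineTransform relating image and reference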
# backward compatibility
ComputeTransform = ComputeTransformTwirl
class ComputeTransformXYShift(Block):
    """
    Compute the translational transform of a target image to a reference image

    |read| ``Image.sources`` on both reference and input image

    |write| ``Image.transform``

    Parameters
    ----------
    reference_image : Image
        image containing detected sources
    n : int, optional
        number of stars to consider to compute the transformation, by default 10
    """

    def __init__(self, reference_image: Image, n=10, discard=True, **kwargs):
        super().__init__(**kwargs)
        ref_coords = reference_image.sources.coords
        self.n = n
        self.ref_coords = ref_coords[0:n].copy()
        self.discard = discard
        self._parallel_friendly = True

    def run(self, image):
        if len(image.sources.coords) >= 5:
            if len(image.sources.coords) <= 2:
                shift = self.ref_coords[0] - image.sources.coords[0]
            else:
                shift = self.xyshift(image.sources.coords, self.ref_coords)
            if shift is not None:
                image.transform = AffineTransform(translation=shift)
            else:
                image.discard = True
        else:
            image.discard = True

    def xyshift(self, im_stars_pos, ref_stars_pos, tolerance=1.5):
        """
        Compute the shift between two sets of coordinates (e.g. stars)

        Parameters
        ----------
        im_stars_pos : list or ndarray
            (x, y) coordinates of n points (shape should be (n, 2))
        ref_stars_pos : list or ndarray
            (x, y) coordinates of n points (shape should be (n, 2)). Reference set
        tolerance : float, optional
            matching tolerance in pixels, by default 1.5

        Returns
        -------
        ndarray
            (dx, dy) shift
        """
        assert (
            len(im_stars_pos) > 2
        ), f"{len(im_stars_pos)} star coordinates provided (should be > 2)"

        clean_ref = ref_stars_pos
        clean_im = im_stars_pos

        delta_x = np.array([clean_ref[:, 0] - v for v in clean_im[:, 0]]).flatten()
        delta_y = np.array([clean_ref[:, 1] - v for v in clean_im[:, 1]]).flatten()

        delta_x_compare = []
        for i, dxi in enumerate(delta_x):
            dcxi = dxi - delta_x
            dcxi[i] = np.inf
            delta_x_compare.append(dcxi)

        delta_y_compare = []
        for i, dyi in enumerate(delta_y):
            dcyi = dyi - delta_y
            dcyi[i] = np.inf
            delta_y_compare.append(dcyi)

        tests = [
            np.logical_and(np.abs(dxc) < tolerance, np.abs(dyc) < tolerance)
            for dxc, dyc in zip(delta_x_compare, delta_y_compare)
        ]

        num = np.array([np.count_nonzero(test) for test in tests])
        max_count_num_i = int(np.argmax(num))
        max_nums_ids = np.argwhere(num == num[max_count_num_i]).flatten()

        dxs = np.array([delta_x[np.where(tests[i])] for i in max_nums_ids])
        dys = np.array([delta_y[np.where(tests[i])] for i in max_nums_ids])

        return np.nan_to_num(np.array([np.mean(dxs), np.mean(dys)]))
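The translation-only variant is used the same way. Internally, ``xyshift`` builds all pairwise differences between reference and image coordinates, counts how many agree within ``tolerance``, and returns the mean of the most-voted offsets. A sketch, under the same assumptions as above::

    shift_block = ComputeTransformXYShift(reference, n=10)
    shift_block.run(image)   # sets image.transform (pure translation) or image.discard = True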