Source code for prose.blocks.centroids

import warnings
from os import path

import numpy as np
from astropy.utils.exceptions import AstropyUserWarning
from photutils.centroids import (
    centroid_2dg,
    centroid_com,
    centroid_quadratic,
    centroid_sources,
)

from prose import CONFIG, Block

from .geometry import Cutouts

# Public API of this module: the ready-to-use centroiding blocks.
__all__ = ["CentroidCOM", "CentroidGaussian2D", "CentroidQuadratic", "CentroidBallet"]

# NOTE(review): no function in this module rebinds this global (that would need
# a `global TF_LOADED` statement), so it stays False; presumably kept in case
# external code imports it — confirm before removing.
TF_LOADED = False


class _PhotutilsCentroid(Block):
    """Base block refining ``Image.sources`` coordinates with a photutils centroid function."""

    def __init__(self, centroid_func, limit=None, cutout=21, name=None):
        """Photutils centroiding

        Parameters
        ----------
        centroid_func : function
            photutils.centroids function (e.g. ``centroid_com``), passed to
            ``photutils.centroids.centroid_sources``
        limit : float, optional
            maximum deviation from initial coordinate; a refined position
            further away is discarded and the initial one is kept.
            By default `cutout/2`
        cutout : int, optional
            size of the cutout to be used for centroiding, by default 21
        name : str, optional
            name of the block, forwarded to ``Block``, by default None
        """
        super().__init__(name=name, read=["sources", "data"])
        self.cutout = cutout
        self.centroid_func = centroid_func
        self.limit = cutout / 2 if limit is None else limit

    @staticmethod
    def _in_image_mask(coords, shape):
        """Return a boolean mask of the (x, y) ``coords`` lying strictly inside ``shape``."""
        # photutils.centroid_sources rejects positions outside the data array
        # (see photutils.centroids.core), so those sources are filtered first.
        # np.asarray makes the arithmetic valid whether shape is a tuple or an array;
        # shape is reversed because coords are (x, y) while shape is (rows, cols).
        upper = np.asarray(shape)[::-1] - 1
        return np.all(coords < upper, axis=1) & np.all(coords > 0, axis=1)

    def _apply_limit(self, image, refined_coords):
        """Write ``refined_coords`` to ``image.sources``, keeping the initial
        coordinate of every source that moved by ``self.limit`` or more."""
        initial = image.sources.coords
        shift = np.linalg.norm(initial - refined_coords, axis=1)
        # NaN shifts (failed centroids) compare False, so those sources keep
        # their initial position.
        accepted = shift < self.limit
        final = initial.copy()
        final[accepted] = refined_coords[accepted]
        image.sources.coords = final

    def run(self, image):
        coords = image.sources.coords
        in_image = self._in_image_mask(coords, image.shape)
        if not np.any(in_image):
            # nothing can be centroided: leave the sources untouched
            return

        x, y = coords[in_image].T.copy()
        with warnings.catch_warnings():
            warnings.simplefilter("ignore", AstropyUserWarning)
            centroided = np.array(
                centroid_sources(
                    image.data,
                    x,
                    y,
                    box_size=self.cutout,
                    centroid_func=self.centroid_func,
                )
            ).T

        # sources outside the image keep their coordinates (zero shift)
        refined = coords.copy()
        refined[in_image] = centroided
        self._apply_limit(image, refined)

    @property
    def citations(self) -> list:
        return super().citations + ["photutils"]


class CentroidCOM(_PhotutilsCentroid):
    """Centroiding using ``photutils.centroids.centroid_com``

    |read| ``Image.sources``, ``Image.data``

    |write| ``Image.sources``

    Parameters
    ----------
    limit : float, optional
        maximum deviation from initial coordinate, by default `cutout/2`
    cutout : int, optional
        size of the cutout to be used for centroiding, by default 21
    name : str, optional
        name of the block, by default None
    """

    def __init__(self, limit=None, cutout=21, name=None):
        super().__init__(
            centroid_func=centroid_com, limit=limit, cutout=cutout, name=name
        )
class CentroidGaussian2D(_PhotutilsCentroid):
    """Centroiding using ``photutils.centroids.centroid_2dg``

    |read| ``Image.sources``, ``Image.data``

    |write| ``Image.sources``

    Parameters
    ----------
    limit : float, optional
        maximum deviation from initial coordinate, by default `cutout/2`
    cutout : int, optional
        size of the cutout to be used for centroiding, by default 21
    name : str, optional
        name of the block, by default None
    """

    def __init__(self, limit=None, cutout=21, name=None):
        super().__init__(
            centroid_func=centroid_2dg, limit=limit, cutout=cutout, name=name
        )
class CentroidQuadratic(_PhotutilsCentroid):
    """Centroiding using ``photutils.centroids.centroid_quadratic``

    |read| ``Image.sources``, ``Image.data``

    |write| ``Image.sources``

    Parameters
    ----------
    limit : float, optional
        maximum deviation from initial coordinate, by default `cutout/2`
    cutout : int, optional
        size of the cutout to be used for centroiding, by default 21
    name : str, optional
        name of the block, by default None
    """

    def __init__(self, limit=None, cutout=21, name=None):
        super().__init__(
            centroid_func=centroid_quadratic, limit=limit, cutout=cutout, name=name
        )
class _CNNCentroid(Block):
    """Base block refining ``Image.sources`` coordinates with a neural network.

    Subclasses implement :meth:`build_model`, which must set ``self.model``
    (weights are then loaded into it by :meth:`import_and_check_model`).
    """

    def __init__(self, cutout=15, filename=None, limit=None, **kwargs):
        """
        Parameters
        ----------
        cutout : int, optional
            side length of the square cutouts fed to the model, by default 15
        filename : str, optional
            name of the weights file, looked up in ``CONFIG.folder_path``
        limit : float, optional
            maximum deviation from initial coordinate, by default `cutout/2`
        **kwargs
            forwarded to ``Block``
        """
        super().__init__(**kwargs)
        self.filename = filename
        self.model = None
        self.cutout = cutout
        # pixel index grids of a cutout (not read by the methods in this module)
        self.x, self.y = np.indices((cutout, cutout))
        self.limit = cutout / 2 if limit is None else limit

    def import_and_check_model(self):
        """Build ``self.model`` and load its weights from ``CONFIG.folder_path``.

        Raises
        ------
        AssertionError
            if the weights file does not exist
        """
        model_file = path.join(CONFIG.folder_path, self.filename)
        if not path.exists(model_file):
            raise AssertionError("Still on dev, contact lgrcia")
        self.build_model()
        self.model.load_weights(model_file)

    def build_model(self):
        """Create ``self.model``; implemented by subclasses."""
        raise NotImplementedError()

    def run(self, image):
        coords = image.sources.coords
        # use the configured cutout size so the model receives the input
        # shape it was built for (previously hard-coded to 15)
        size = self.cutout

        # keep only sources strictly inside the image; np.asarray makes this
        # valid whether image.shape is a tuple or an array
        upper = np.asarray(image.shape)[::-1] - 1
        in_image = np.all(coords < upper, axis=1) & np.all(coords > 0, axis=1)
        if not np.any(in_image):
            # nothing to centroid: leave the sources untouched
            return

        in_image_coords = coords[in_image].copy()
        cutouts = image.data_cutouts(in_image_coords, (size, size))
        # normalise each cutout by its mean and add a trailing channel axis
        normalized = (cutouts / np.mean(cutouts, (1, 2))[:, None, None])[..., None]
        origins = in_image_coords - size / 2

        # apply model (its output is reversed along the last axis before
        # being added to the cutout origins)
        centroided = (
            origins + self.model(normalized, training=False).numpy()[:, ::-1]
        )

        # if coords is nan (any of x, y), keep old coord
        nan_mask = np.any(np.isnan(centroided), 1)
        centroided[nan_mask] = in_image_coords[nan_mask]

        # apply limit: sources that moved by `limit` or more keep their initial coords
        refined = coords.copy()
        refined[in_image] = centroided
        in_limit = np.linalg.norm(coords - refined, axis=1) < self.limit
        final = coords.copy()
        final[in_limit] = refined[in_limit]
        image.sources.coords = final

    @property
    def citations(self):
        return super().citations + ["tensorflow", "keras"]
class CentroidBallet(_CNNCentroid):
    """Centroiding with `ballet <https://github.com/lgrcia/ballet>`_.

    |write| ``Image.sources``
    """

    def __init__(self, **kwargs):
        super().__init__(cutout=15, filename="centroid.h5", **kwargs)
        self.import_and_check_model()

    def build_model(self):
        """Create the ballet convolutional network in ``self.model``.

        Raises
        ------
        ModuleNotFoundError
            if tensorflow is not installed
        """
        import os

        # lower TensorFlow's C++ log verbosity before importing it
        os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"
        try:
            from tensorflow.keras.layers import Conv2D, Dense, Flatten, MaxPooling2D
            from tensorflow.keras.models import Sequential
        except ModuleNotFoundError as err:
            # previously swallowed, which later failed with an unrelated
            # UnboundLocalError on `Sequential`
            raise ModuleNotFoundError(
                "CentroidBallet requires tensorflow (pip install tensorflow)"
            ) from err

        self.model = Sequential(
            [
                Conv2D(
                    64,
                    (3, 3),
                    activation="relu",
                    input_shape=(self.cutout, self.cutout, 1),
                    use_bias=True,
                    padding="same",
                ),
                MaxPooling2D((2, 2), padding="same"),
                Conv2D(128, (3, 3), activation="relu", use_bias=True, padding="same"),
                MaxPooling2D((2, 2), padding="same"),
                Conv2D(256, (3, 3), activation="relu", use_bias=True, padding="same"),
                Flatten(),
                Dense(2048, activation="sigmoid", use_bias=True),
                Dense(512, activation="sigmoid", use_bias=True),
                Dense(2),
            ]
        )
# For reference
class _OldNNCentroid(_CNNCentroid):
    """Previous network-based centroiding model (21-pixel cutouts,
    ``oldcentroid.h5`` weights), kept for reference."""

    def __init__(self, **kwargs):
        super().__init__(cutout=21, filename="oldcentroid.h5", **kwargs)
        self.import_and_check_model()

    def build_model(self):
        """Create the old convolutional network in ``self.model``.

        Raises
        ------
        ModuleNotFoundError
            if tensorflow is not installed
        """
        import os

        # lower TensorFlow's C++ log verbosity before importing it
        os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"
        # imported here: names imported in __init__ were local to __init__, and
        # the former `self.tf_models` / `self.tf_layers` attributes never existed
        try:
            from tensorflow.keras.layers import Conv2D, Dense, Flatten, MaxPooling2D
            from tensorflow.keras.models import Sequential
        except ModuleNotFoundError as err:
            raise ModuleNotFoundError(
                "_OldNNCentroid requires tensorflow (pip install tensorflow)"
            ) from err

        self.model = Sequential(
            [
                Conv2D(
                    self.cutout,
                    (3, 3),
                    activation="relu",
                    input_shape=(self.cutout, self.cutout, 1),
                ),
                MaxPooling2D((2, 2)),
                Conv2D(64, (3, 3), activation="relu"),
                MaxPooling2D((2, 2)),
                Conv2D(124, (3, 3), activation="relu"),
                Flatten(),
                Dense(2048, activation="relu"),
                Dense(2),
            ]
        )