from copy import copy
import warnings
import numpy as np
from scipy.interpolate import interp1d
from scipy.special import gammaincinv
from astropy.utils.console import ProgressBar
from astropy.modeling import models
from astropy.table import Table
from ..modeling.models import PSFConvolvedModel2D, sersic_enclosed, sersic_enclosed_inv
from ..modeling.fitting import model_to_image
from ..photometry import radial_photometry
from ..utils import mpl_tick_frame
from .core import Petrosian, calculate_petrosian_r, calculate_petrosian
from matplotlib import pyplot as plt
__all__ = ["generate_petrosian_sersic_correction", "PetrosianCorrection"]
def _generate_petrosian_correction(args):
"""
Helper function to compute corrections for a single pair of `r_eff` and `n`.
`args` should be a list `[r_eff, n, psf, oversample, plot]`. See
`generate_petrosian_sersic_correction` doctring for more information.
"""
# Unpack params
r_eff, n, psf, oversample, psf_oversample, plot = args
amplitude = 100 / np.exp(gammaincinv(2.0 * n, 0.5))
# Total flux
L_total = sersic_enclosed(np.inf, amplitude=amplitude, r_eff=r_eff, n=n)
total_flux = L_total * 0.99
# Calculate radii
r_20, r_80, r_total_flux = [
sersic_enclosed_inv(
total_flux * fraction, amplitude=amplitude, r_eff=r_eff, n=n
)
for fraction in [0.2, 0.8, 1.0]
]
# Make r_list
max_r = r_total_flux * 3 if n < 2 else r_total_flux * 1.3
if r_eff < 7 and n < 1:
max_r = r_total_flux * 50
if max_r >= 200:
r_list = [x for x in range(1, 201, 2)]
r_list += [x for x in range(300, int(max_r) + 100, 100)]
else:
r_list = [x for x in range(1, int(max_r) + 2, 2)]
r_list = np.array(r_list)
image_size = max(r_list) * 2
x_0 = image_size // 2
y_0 = image_size // 2
# Make Model Image
# ----------------
# Define model
galaxy_model = models.Sersic2D(
amplitude=amplitude,
r_eff=r_eff,
n=n,
x_0=x_0,
y_0=y_0,
ellip=0.0,
theta=0.0,
)
# Wrap model with PSFConvolvedModel2D
galaxy_model = PSFConvolvedModel2D(
galaxy_model, psf=psf, oversample=oversample, psf_oversample=psf_oversample
)
# Make galaxy image from PSFConvolvedModel2D
galaxy_image = model_to_image(galaxy_model, image_size, center=(x_0, y_0))
# Do photometry on model galaxy image
flux_list, area_list, err = radial_photometry(
galaxy_image, (x_0, y_0), r_list, plot=plot, vmax=amplitude / 100
)
if plot:
plt.show()
# Calculate Photometry and petrosian
# ----------------------------------
# Petrosian from Photometry
p = Petrosian(r_list, area_list, flux_list)
rc1, rc2, c_index = p.concentration_index()
if np.any(np.isnan(np.array([rc1, rc2, c_index]))):
raise Exception(
"concentration_index cannot be computed (n={}, r_e={})".format(n, r_eff)
)
# Compute new r_total_flux
_, indices = np.unique(flux_list, return_index=True)
indices = np.array(indices)
f = interp1d(flux_list[indices], r_list[indices], kind="linear")
model_r_total_flux = f(total_flux)
# Compute new r_80
model_r_80 = f(total_flux * 0.8)
# Compute corrections
corrected_epsilon = model_r_total_flux / p.r_petrosian
corrected_epsilon_80 = model_r_80 / p.r_petrosian
corrected_p = copy(p)
corrected_p.epsilon = corrected_epsilon
# Make output list
# ----------------
# Petrosian indices
petrosian_list = calculate_petrosian(p.area_list, p.flux_list)[0]
p02, p03, p04, p05 = [
calculate_petrosian_r(p.r_list, petrosian_list, eta=i)[0]
for i in (0.2, 0.3, 0.4, 0.5)
]
assert np.round(p.r_petrosian, 6) == np.round(p02, 6)
u_r_eff = p.fraction_flux_to_r(0.5)
u_r_20 = p.fraction_flux_to_r(0.2)
u_r_50 = p.fraction_flux_to_r(0.5)
u_r_80 = p.fraction_flux_to_r(0.8)
c_r_eff = corrected_p.fraction_flux_to_r(0.5)
c_r_20 = corrected_p.fraction_flux_to_r(0.2)
c_r_50 = corrected_p.fraction_flux_to_r(0.5)
c_r_80 = corrected_p.fraction_flux_to_r(0.8)
row = [
n,
r_eff,
r_20,
r_80,
r_total_flux,
L_total,
p02,
p03,
p04,
p05,
5 * np.log10(p02 / p05),
5 * np.log10(p02 / p03),
p.epsilon,
u_r_50 / p.r_petrosian,
u_r_80 / p.r_petrosian,
u_r_eff,
p.r_total_flux,
u_r_20,
u_r_80,
p.c2080,
p.c5090,
corrected_epsilon,
c_r_50 / p.r_petrosian,
corrected_epsilon_80,
c_r_eff,
corrected_p.r_total_flux,
c_r_20,
c_r_80,
corrected_p.c2080,
corrected_p.c5090,
]
if plot:
fig, axs = plt.subplots(1, 2, figsize=[12, 6])
plt.sca(axs[0])
corrected_p.plot()
plt.sca(axs[1])
corrected_p.plot_cog()
plt.show()
print(corrected_epsilon)
print(r_eff, p.r_half_light, corrected_p.r_half_light)
print(" ")
del galaxy_model, galaxy_image
del flux_list, area_list, err
del corrected_p, p
return row
[docs]
def generate_petrosian_sersic_correction(
output_file_name,
psf=None,
r_eff_list=None,
n_list=None,
oversample=("x_0", "y_0", 10, 50),
psf_oversample=None,
out_format=None,
overwrite=False,
ipython_widget=False,
n_cpu=None,
plot=False,
):
"""
Generate corrections for Petrosian profiles by simulating a galaxy image (single component sersic) and measuring its
properties. This is done to identify the correct `epsilon` value that, when multiplied with `r_petrosian`, gives
`r_total_flux`. To achieve this, an image is created from a Sersic model and convolved with a PSF (if provided).
The Petrosian radii and concentrations are computed using the default `epsilon` = 2. Since the real `r_total_flux`
of the simulated galaxy is known, the correct `epsilon` can be determined by
`epsilon = r_petrosian / corrceted_r_total_flux`. The resulting grid is used to map measured properties to the
correct `epsilon` value. If `output_file_name` is provided, the grid is saved to using an astropy table file which
is readable by `petrofit.petrosian.PetrosianCorrection`.
Parameters
----------
output_file_name : str
Name of output file, must have .yaml or .yml extension.
psf : numpy.array or None
2D PSF image to pass to `petrofit.fitting.models.PSFConvolvedModel2D`.
r_eff_list : list, (optional)
List of `r_eff` (half light radii) in pixels to evaluate.
n_list : list, (optional)
List of Sersic indices to evaluate.
oversample : int or tuple
oversampling to pass to `petrofit.fitting.models.PSFConvolvedModel2D`.
psf_oversample : None or int
Oversampling factor of the PSF relative to data. The `oversample` factor should be an integer multiple
of the PSF oversampling factor (i.e `oversample > psf_oversample`).
out_format : str, optional
Format passed to the resulting astropy table when writing to file.
overwrite : bool, optional
Overwrite if file exists.
ipython_widget : bool, optional
If True, the progress bar will display as an IPython notebook widget.
n_cpu : bool, int, optional
If True, use the multiprocessing module to distribute each task to a different
processor core. If a number greater than 1, then use that number of cores. This
should be selected taking ram in consideration (since high n and large r_eff
create large images).
plot : bool
Shows plot of photometry and Petrosian. Not available if n_cpu > 1.
Returns
-------
petrosian_grid : Table
Astropy Table that is readable by `petrofit.petrosian.PetrosianCorrection`
"""
if r_eff_list is None:
r_eff_list = np.arange(10, 100 + 5, 5)
if n_list is None:
n_list = np.arange(0.5, 4.5 + 0.5, 0.5)
if psf is not None and psf.sum() != 1:
warnings.warn(
"Input PSF not normalized to 1, current sum = {}. This may cause major errors".format(
psf.sum()
)
)
r_eff_list = np.array(r_eff_list)
n_list = np.round(np.array(n_list), 6)
# Make list of args for _generate_petrosian_correction
args = []
for n_idx, n in enumerate(n_list):
for r_eff_idx, r_eff in enumerate(r_eff_list):
args.append([r_eff, n, psf, oversample, psf_oversample, plot])
# Call _generate_petrosian_correction
# either on one thread on using multiprocessing
if n_cpu is None or n_cpu == 1:
with ProgressBar(len(args), ipython_widget=ipython_widget) as bar:
rows = []
for arg in args:
row = _generate_petrosian_correction(arg)
rows.append(row)
bar.update()
else:
assert plot == False, "Plotting not available for ncpu > 1"
step = 50 if len(r_eff_list) * len(n_list) > 500 else 2
rows = ProgressBar.map(
_generate_petrosian_correction,
args,
multiprocess=n_cpu,
ipython_widget=ipython_widget,
step=step,
)
names = [
"n",
"r_eff",
"sersic_r_20",
"sersic_r_80",
"sersic_r_99",
"sersic_L_inf",
"p02",
"p03",
"p04",
"p05",
"p0502",
"p0302",
"u_epsilon",
"u_epsilon_50",
"u_epsilon_80",
"u_r_50",
"u_r_99",
"u_r_20",
"u_r_80",
"u_c2080",
"u_c5090",
"c_epsilon",
"c_epsilon_50",
"c_epsilon_80",
"c_r_50",
"c_r_99",
"c_r_20",
"c_r_80",
"c_c2080",
"c_c5090",
]
petrosian_grid = Table(rows=rows, names=names)
if output_file_name is not None:
try:
petrosian_grid.write(
output_file_name, format=out_format, overwrite=overwrite
)
except Exception as e:
print("Could not save to file: {}".format(e))
print("You can save the returned table using `petrosian_grid.write`")
return petrosian_grid
[docs]
class PetrosianCorrection:
"""
This class computes corrections for Petrosian given default Petrosian measurements.
"""
def __init__(self, grid, enforce_range=True):
"""
Parameters
----------
grid : str
Correction grid generated by `petrofit.correction.generate_petrosian_sersic_correction`.
Use `PetrosianCorrection.read(file_path)` to read grid from file.
enforce_range : bool
If true, the nearest approximation is returned. If false, an assertion will be applied that makes sure
that the profiles to be corrected are covered by the correction grid.
"""
self.enforce_range = enforce_range
self.eta_keys = {0.2: "p02", 0.3: "p03", 0.4: "p04", 0.5: "p05"}
if isinstance(grid, Table):
self.grid = grid
elif isinstance(grid, str):
raise TypeError(
"Input grid should be an astropy Table use `PetrosianCorrection.read(file_path)`"
)
else:
raise TypeError("Input grid should be an astropy Table")
self.x = self.grid["p02"].value
self.y = self.grid["u_r_50"].value
self.z = self.grid["u_c2080"].value
self.r = [self.x, self.y, self.z]
self.weights = np.array([100, 100, 100])
def _get_xyz_from_p(self, p):
px_list = (0.2, 0.3, 0.4, 0.5)
p02, p03, p04, p05 = [
calculate_petrosian_r(
p.r_list, p.petrosian_list, petrosian_err=None, eta=i
)[0]
for i in px_list
]
x0 = p02
y0 = p.r_50
z0 = p.c2080
return [x0, y0, z0]
@staticmethod
def _read_grid_file(grid_file, file_format=None):
return Table.read(grid_file, format=file_format)
[docs]
@classmethod
def read(cls, grid_file, file_format=None):
"""Read grid from file."""
grid = cls._read_grid_file(grid_file, file_format)
return cls(grid)
[docs]
def write(self, grid_file, file_format=None):
"""Write grid to file."""
self.grid.write(grid_file, format=file_format)
@property
def grid_keys(self):
"""Return dictionary keys of the grid."""
return self.grid.colnames
[docs]
def unique_grid_values(self, key):
"""Return unique values of a key in the grid."""
return np.array(np.unique(self.grid[key]))
[docs]
def filter_grid(self, key, value):
"""Return a filtered grid based on key and value."""
idx = np.where(self.grid[key] == value)
return self.grid[idx]
def _dr(self, x0, y0, z0):
"""Given x0, y0, z0 nearest grid row distance. Called by `_closest_row`"""
wx, wy, wz = self.weights
# Standardize the data
std_x = np.std(self.x)
std_y = np.std(self.y)
std_z = np.std(self.z)
dx = wx * (self.x - x0) / std_x
dy = wy * (self.y - y0) / std_y
dz = wz * (self.z - z0) / std_z
dr = np.sqrt(dx**2 + dy**2 + dz**2)
return dr
def _validate_input(self, x0, y0, z0):
assert (
np.min(self.x) <= x0 <= np.max(self.x)
), "r_petro(eta=0.2) is outside of the range of the grid"
assert (
np.min(self.y) <= y0 <= np.max(self.y)
), "r_50 is outside of the range of the grid"
assert (
np.min(self.z) <= z0 <= np.max(self.z)
), "C2080 is outside of the range of the grid"
def _closest_row(self, x0, y0, z0):
"""Given x0, y0, z0 nearest grid row"""
if self.enforce_range:
self._validate_input(x0, y0, z0)
dr = self._dr(x0, y0, z0)
return self.grid[dr.argmin()]
def _get_corrected_row(self, p):
"""Given a Petrosian object, return the correction row"""
x0, y0, z0 = self._get_xyz_from_p(p)
return self._closest_row(x0, y0, z0)
[docs]
def correct(self, p):
"""Given a Petrosian object, return a corrected Petrosian object"""
corrected_p = copy(p)
corrected_p.epsilon = self.estimate_epsilon(p)
return corrected_p
[docs]
def estimate_n(self, p):
"""
Given the half light radius and c2080 computed using the default epsilon value,
return an estimated sersic index n.
"""
row = self._get_corrected_row(p)
return row["n"]
[docs]
def estimate_epsilon(self, p):
"""
Given the half light radius and c2080 computed using the default epsilon value,
return a corrected epsilon value.
"""
row = self._get_corrected_row(p)
epsilon_fraction = p.epsilon_fraction
eta = p.eta
if epsilon_fraction == 0.5:
r_ep = row["c_r_eff"]
elif epsilon_fraction == 0.8:
r_ep = row["c_r_80"]
elif epsilon_fraction == 0.99:
r_ep = row["c_r_99"]
else:
raise ValueError(
"Input epsilon_fraction={} is not supported, choose from [0.5, 0.8, 0.99]"
)
if eta not in self.eta_keys.keys():
raise ValueError(
"Input eta={} is not supported, choose from eta={}".format(
eta, list(self.eta_keys.keys())
)
)
r_p = row[self.eta_keys[eta]]
epsilon = r_ep / r_p
return epsilon
def _plot_grid(
self,
x0=None,
y0=None,
z0=None,
cmap="hot",
target_c="blue",
cmap_key="n",
colorbar_label=None,
suptitle=None,
axs=None,
minorticks=False,
):
"""
Plots a grid of scatter plots for the given data.
Parameters
----------
x0, y0, z0 : float, optional
cmap : str, optional
Colormap to use for the scatter plots (default "hot").
target_c : str, optional
Color to use for highlighting the target point (default is "blue").
cmap_key : str, optional
Key to use for colormap data from the grid (default is "n").
colorbar_label : str, optional
Label for the colorbar (default is None).
suptitle : str, optional
Suptitle for the figure (default is None).
axs : list of matplotlib.axes.Axes, optional
List of 3 axes to plot on (default is None, which creates new axes).
minorticks : bool, optional
Whether to show minor ticks on the axes (default is False).
Returns
-------
fig : matplotlib.figure.Figure
The figure object containing the plots.
axs : list of matplotlib.axes.Axes
The list of axes containing the scatter plots.
Notes
-----
This function creates a grid of 3 scatter plots showing the relationships
between x, y, and z coordinates with color mapping based on `cmap_key`.
If `x0`, `y0`, and `z0` are provided, the closest point in the grid is
highlighted and connected to the target point.
"""
if axs is None:
fig, axs = plt.subplots(1, 3, figsize=[6 * 3, 6])
else:
assert len(axs) == 3, "axs should be a list of 3 axis"
fig = axs[0].figure
cm = plt.cm.get_cmap(cmap)
sim_n_list = self.grid[cmap_key]
ax = axs[0]
sc = ax.scatter(
self.x,
self.y,
c=sim_n_list,
vmin=0,
vmax=max(sim_n_list) + 1,
s=35,
cmap=cm,
)
ax.set_xlabel(r"$r_{{p}}(\eta=0.2)$")
ax.set_ylabel(r"$r_{{50}}$")
mpl_tick_frame(ax=ax, minorticks=minorticks)
ax = axs[1]
sc = ax.scatter(
self.x,
self.z,
c=sim_n_list,
vmin=0,
vmax=max(sim_n_list) + 1,
s=35,
cmap=cm,
)
ax.set_xlabel(r"$r_{{p}}(\eta=0.2)$")
ax.set_ylabel(r"$C_{2080}$")
mpl_tick_frame(ax=ax, minorticks=minorticks)
ax = axs[2]
sc = ax.scatter(
self.y,
self.z,
c=sim_n_list,
vmin=0,
vmax=max(sim_n_list) + 1,
s=35,
cmap=cm,
)
ax.set_xlabel(r"$r_{{50}}$")
ax.set_ylabel(r"$C_{2080}$")
mpl_tick_frame(ax=ax, minorticks=minorticks)
if None not in [x0, y0, z0]:
idx = self._dr(x0, y0, z0).argmin()
cx, cy, cz = self.x[idx], self.y[idx], self.z[idx]
axs[0].scatter(x0, y0, marker="o", s=200, ec=target_c, fc="None", lw=5)
axs[1].scatter(x0, z0, marker="o", s=200, ec=target_c, fc="None", lw=5)
axs[2].scatter(y0, z0, marker="o", s=200, ec=target_c, fc="None", lw=5)
axs[0].plot([x0, cx], [y0, cy], marker="o", lw=5)
axs[1].plot([x0, cx], [z0, cz], marker="o", lw=5)
axs[2].plot([y0, cy], [z0, cz], marker="o", lw=5)
fig.colorbar(
sc,
ax=axs,
location="bottom",
aspect=50,
label=colorbar_label if colorbar_label else cmap_key,
)
fig.suptitle(suptitle if suptitle else "Petrosian Correction Grid")
return fig, axs
[docs]
def plot_correction(
self,
p,
cmap="hot",
target_c="blue",
cmap_key="n",
colorbar_label=None,
suptitle=None,
axs=None,
):
"""
Plots the correction grid for the given a Petrosian object.
Parameters:
-----------
p : array-like
The Petrosian object for which the correction grid is to be plotted.
cmap : str, optional
The colormap to be used for the plot. Default is "hot".
target_c : str, optional
The color to be used for the target. Default is "blue".
cmap_key : str, optional
The key to be used for the colormap. Default is "n".
colorbar_label : str, optional
The label for the colorbar. Default is None.
suptitle : str, optional
The super title for the plot. Default is None.
axs : matplotlib.axes.Axes, optional
The axes on which to plot. If None, a new figure and axes are created. Default is None.
Returns:
--------
fig : matplotlib.figure.Figure
The figure object containing the plot.
axs : matplotlib.axes.Axes
The axes object containing the plot.
"""
x0, y0, z0 = self._get_xyz_from_p(p)
fig, axs = self._plot_grid(
x0=x0,
y0=y0,
z0=z0,
cmap=cmap,
cmap_key=cmap_key,
colorbar_label=colorbar_label,
target_c=target_c,
suptitle=suptitle,
axs=axs,
)
return fig, axs