Source code for exosim.tasks.subexposures.instantaneousReadOut

import copy
from copy import deepcopy

import astropy.units as u
import numpy as np
from numba import jit, prange
from scipy.interpolate import RectBivariateSpline

from exosim.models.signal import Counts
from exosim.tasks.task import Task
from exosim.utils.iterators import iterate_over_chunks
from exosim.utils.operations import operate_over_axis


[docs]class InstantaneousReadOut(Task): """ This task implements the instantaneous read out. It loads the readout configuration from a dictionary that is produced by :class:`~exosim.tasks.subexposures.prepareInstantaneousReadOut.PrepareInstantaneousReadOut`. Then it creates the sub-exposure datacube, with sub-exposure for each NDRs in the ramp sampling scheme. Each of these sub-exposures collects more simulation steps, which have their own jitter offset. This Task iterates over the time steps to jitter the focal plane and pile it to the appropriate sub-exposure. The jittering is based on the focal plane oversampling factor. The contributions to each sub-exposures are averaged and the final product is multiplied by its integration time. Returns -------- :class:`~exosim.models.signal.Counts` sub-exposure cached signal class """ def __init__(self): """ Parameters ---------- readout_parameters: dict readout_parameters dict focal_plane: :class:`~exosim.models.signal.CountsPerSecond` channel focal plane pointing_jitter: (:class:`~astropy.units.Quantity`, :class:`~astropy.units.Quantity`, :class:`~astropy.units.Quantity`) Tuple containing the pointing jitter in the spatial and spectral direction expressed in units of deg, adn jitter time expressed as sec. parameters: dict dictionary containing the channel parameters. This is usually parsed from :class:`~exosim.tasks.load.loadOptions.LoadOptions` output: str or :class:`~exosim.output.hdf5.hdf5.HDF5Output` or :class:`~exosim.output.hdf5.hdf5.HDF5OutputGroup` output file dataset_name: str (optional) dataset name. Default is "SubExposures". """ self.add_task_param("parameters", "channel parameters dict") self.add_task_param("readout_parameters", "channel parameters dict") self.add_task_param("focal_plane", "loaded focal plane") self.add_task_param("pointing_jitter", "") self.add_task_param("output_file", "output file") self.add_task_param("dataset_name", "dataset name", "SubExposures") self.add_task_param( "slicing", "jittering by slice, avoid to create a large cube in memory", False, ) self.store_dict = {}
[docs] def execute(self): parameters = self.get_task_param("parameters") readout_parameters = self.get_task_param("readout_parameters") pointing_jitter = self.get_task_param("pointing_jitter") focal_plane = self.get_task_param("focal_plane") output_file = self.get_task_param("output_file") dataset_name = self.get_task_param("dataset_name") focal = copy.deepcopy(focal_plane.data.astype(np.float64)) base_osf = focal_plane.metadata["oversampling"] ndr_integration_times = readout_parameters["ndr_integration_times"] clock = readout_parameters["simulation_clock"] fp_time = readout_parameters["fp_time"] start_index = readout_parameters[ "ndr_start_cumulative_sequence" ].astype(int) end_index = readout_parameters["ndr_end_cumulative_sequence"].astype( int ) saveMemory = self.get_task_param("slicing") # saveMemory=True out = Counts( spectral=focal_plane.spectral[int(base_osf // 2) :: base_osf] * focal_plane.spectral_units, time=(start_index * clock).to(u.hr), data=None, spatial=focal_plane.spatial[int(base_osf // 2) :: base_osf] * focal_plane.spatial_units, shape=( end_index.shape[0], focal.shape[1] // base_osf, focal.shape[2] // base_osf, ), cached=True, output=output_file, dataset_name=dataset_name, output_path=None, metadata={"integration_times": ndr_integration_times}, dtype=np.float64, ) out.metadata["focal_plane_time_indexes"] = fp_time if pointing_jitter != (None, None, None): self.debug("Pointing jitter found") # if jitter is enabled, the following key are available mag = readout_parameters["mag"] osf = base_osf * mag y_jit = (readout_parameters["y_jit"] * osf / u.pix).value x_jit = (readout_parameters["x_jit"] * osf / u.pix).value y_jit = y_jit.astype(int) x_jit = x_jit.astype(int) if saveMemory: """ debug=False if debug: new_focal = [self.oversample(fp, mag) for fp in focal] """ if mag != 1: xin, yin, xout, yout = self.getOversampleFactors( focal[0, ...], mag ) yshape = int(yout.shape[0] // osf) xshape = int(xout.shape[0] // osf) # time_line = np.zeros((start_index.shape[0], int(yout.shape[0] // osf), # int( xout.shape[0]// osf)),dtype=np.float64,) else: yshape = int(focal.shape[1] // osf) xshape = int(focal.shape[2] // osf) self.info( "jittering {} for {}".format( dataset_name, parameters["value"] ) ) for chunk in iterate_over_chunks( out.dataset, desc="jittering {}".format(parameters["value"]), ): time_line = np.zeros( (start_index[chunk[0]].shape[0], yshape, xshape), dtype=np.float64, ) # focal: (3420, 192, 1068) # chunk 274, 64, 356 # time_line (7856, 64, 356 ) # iterate over the timeline sub-exposureA t_cache = -1 fp_cache = None time_line_slice = None """ ndrs=prange(start_index[chunk[0]].shape[0]) message="chunk iteration: ndr %i-%i , t= %i -%i "%(ndrs.start,ndrs.stop-1,fp_time[chunk[0]][ndrs.start], fp_time[chunk[0]][ndrs.stop-1] ) self.info(message) """ for ndr in prange(start_index[chunk[0]].shape[0]): # select the focal plane at the right time t = fp_time[chunk[0]][ndr] # print (" ndr=",ndr, "t=",t, end='\r') fp_slice = focal[t, ...] if t == t_cache: fp = fp_cache else: if mag != 1: fp = self.oversample(fp_slice, mag) else: fp = fp_slice t_cache = t fp_cache = fp """ if debug: fp_slice_test= new_focal[t] message="Total difference sliced focal plane ndr=%i t=%i difference =%g"%(ndr,t,np.nansum(fp-fp_slice_test)) self.info(message) print (" fp.shape=", fp.shape) """ time_line_slice = self.jittering_the_focalplane_by_slice( fp, osf, start_index[chunk[0]], end_index[chunk[0]], x_jit, # TODO: check if this is correct: is it correct to use chunks? y_jit, fp_time[chunk[0]], time_line[ndr, ...], ndr, ) time_line[ndr, ...] = time_line_slice # dset= time_line[chunk[0]] out.dataset[chunk] = time_line out.output.flush() else: # apply jitter magnification if mag != 1: self.info( "resampling the focal plane: magnification factor {}".format( mag ) ) # resampling the focal plane and replace it with a new array new_focal = [self.oversample(fp, mag) for fp in focal] # focal = np.array(new_focal) focal = copy.deepcopy(np.array(new_focal)) self.debug( "focal plane resampled: new shape {}".format( focal.shape ) ) self.info( "jittering {} for {}".format( dataset_name, parameters["value"] ) ) for chunk in iterate_over_chunks( out.dataset, desc="jittering {}".format(parameters["value"]), ): dset = self.jittering_the_focalplane( focal, osf, start_index[chunk[0]], end_index[chunk[0]], x_jit, # TODO: check if this is correct: is it correct to use chunks? y_jit, fp_time[chunk[0]], ) out.dataset[chunk] = dset out.output.flush() # Here we force the power conservation, if the user enabled the option if "force_power_conservation" in parameters.keys(): if parameters["force_power_conservation"]: self.warning("forcing power conservation") self.force_power_conservation( out, parameters, focal_plane, fp_time, osf ) # applying integtation time to the jittered focal planes for chunk in iterate_over_chunks( out.dataset, desc="applying integration time {}".format( parameters["value"] ), ): dset = out.dataset[chunk] dset = operate_over_axis( dset, ndr_integration_times[chunk[0]].value, 0, "*" ) out.dataset[chunk] = dset out.output.flush() else: self.info( "no jitter in {} for {}".format( dataset_name, parameters["value"] ) ) focal = deepcopy(focal_plane.data[:, 0::base_osf, 0::base_osf]) for chunk in out.dataset.iter_chunks(): dset = self.replicating_the_focalplane( focal, start_index[chunk[0]], fp_time[chunk[0]] ) self.debug("focal plane replicated") dset = operate_over_axis( dset, ndr_integration_times[chunk[0]].value, 0, "*" ) out.dataset[chunk] = dset out.output.flush() self.set_output(out)
[docs] def force_power_conservation( self, out, parameters, focal_plane, fp_time, osf ): # to compute the total power on the focal plane I use the undersampled focal plane total_power = np.empty(out.dataset.shape[0]) desired_power = np.empty(out.dataset.shape[0]) for chunk in iterate_over_chunks( out.dataset, desc="computing median incoming power {}".format( parameters["value"] ), ): # computing the total power in the jittered focal plane dset = out.dataset[chunk] total_power[chunk[0]] = dset.sum(axis=-1).sum(axis=-1) # computing the desired power from the original focal planes fp_time_ = fp_time[chunk[0]] fp_times = list(set(fp_time_)) for time_id in fp_times: mask = np.where(fp_time_ == time_id)[0] # I estimated the expected power from the oversampled focal plane desired_power[chunk[0]][mask] = ( np.sum(focal_plane.data[time_id]) / osf**2 ) # applying integration time to the jittered focal planes for chunk in iterate_over_chunks( out.dataset, desc="forcing conservation of power {}".format( parameters["value"] ), ): dset = out.dataset[chunk] dset = operate_over_axis( dset, desired_power[chunk[0]] / total_power[chunk[0]], 0, "*" ) out.dataset[chunk] = dset out.output.flush() self.store_dict.update( {"median_power": desired_power, "total_power": total_power} )
@staticmethod
[docs] def getOversampleFactors(fp, ad_osf): """ Used in oversample method to determine the shape of the arrays . Parameters ---------- fp: :class:`~numpy.ndarray` 2D focal plane ad_osf: int magnification factor Returns ------- :class:`~numpy.ndarray` the x and y grids used in the input and output of the oversampling """ xin = np.linspace(0, fp.shape[1] - 1, fp.shape[1]) yin = np.linspace(0, fp.shape[0] - 1, fp.shape[0]) x_step = abs(xin[1]) - abs(xin[0]) y_step = abs(yin[1]) - abs(yin[0]) # calculates the new step sizes for new grid x_step_new = np.float64(x_step / ad_osf) y_step_new = np.float64(y_step / ad_osf) # new grid must start with an exact offset to produce correct number of new points x_start = -x_step_new * np.float64((ad_osf - 1) / 2) y_start = -y_step_new * np.float64((ad_osf - 1) / 2) # new grid points- with correct start, end and spacing xout = np.arange( x_start, x_start + x_step_new * fp.shape[1] * ad_osf, x_step_new ) yout = np.arange( y_start, y_start + y_step_new * fp.shape[0] * ad_osf, y_step_new ) return xin, yin, xout, yout
@staticmethod
[docs] def oversample(fp: np.array, ad_osf: int) -> np.array: """ It increases the oversampling factor of the focal plane. Parameters ---------- fp: :class:`~numpy.ndarray` 2D focal plane ad_osf: int magnification factor Returns ------- :class:`~numpy.ndarray` 2D focal plane sampled with the new oversampling factor """ xin, yin, xout, yout = InstantaneousReadOut.getOversampleFactors( fp, ad_osf ) # interpolate fp onto new grid fn = RectBivariateSpline(yin, xin, fp) new_fp = fn(yout, xout) return new_fp
@staticmethod @jit(nopython=True, parallel=True)
[docs] def jittering_the_focalplane_by_slice( fp: np.array, osf: int, start_index: np.array, end_index: np.array, x_jit: np.array, y_jit: np.array, fp_time: np.array, time_line_slice: np.array, ndr: int, ) -> np.array: """ Same as jittering_the_focalplane but to operate only on a single slice fp """ # iterate over the timeline sub-exposures j = int(osf // 2) # starting index for spatial direction # iterate over the spatial direction for y in range(time_line_slice.shape[0]): i = int(osf // 2) # starting index for spectral direction # iterate over the spectral direction for x in range(time_line_slice.shape[1]): # iterate over the jitter indices for idx in range(start_index[ndr], end_index[ndr]): # selecting the jitter offset indices j_jit = y_jit[idx] + j i_jit = x_jit[idx] + i # if negative index, then roll the array if i_jit < 0: i_jit = i_jit + fp.shape[1] elif i_jit >= fp.shape[1]: i_jit = i_jit - fp.shape[1] if j_jit < 0: j_jit = j_jit + fp.shape[0] elif j_jit >= fp.shape[0]: j_jit = j_jit - fp.shape[0] # if negative index, then roll the array time_line_slice[y, x] = ( time_line_slice[y, x] + fp[j_jit, i_jit] ) # move to the next pixel in the spectral direction i = i + osf # move to the next pixel in the spatial direction j = j + osf # divide by the number of jitter positions added time_line_slice = time_line_slice / (end_index[ndr] - start_index[ndr]) return time_line_slice
@staticmethod @jit(nopython=True, parallel=True)
[docs] def jittering_the_focalplane( fp: np.array, osf: int, start_index: np.array, end_index: np.array, x_jit: np.array, y_jit: np.array, fp_time: np.array, ) -> np.array: # create an empty array to store the jittered focal plane time_line = np.zeros( ( start_index.shape[0], int(fp.shape[1] // osf), int(fp.shape[2] // osf), ), dtype=np.float64, ) # iterate over the timeline sub-exposures for ndr in prange(start_index.shape[0]): # select the focal plane at the right time t = fp_time[ndr] j = int(osf // 2) # starting index for spatial direction # iterate over the spatial direction for y in range(time_line.shape[1]): i = int(osf // 2) # starting index for spectral direction # iterate over the spectral direction for x in range(time_line.shape[2]): # iterate over the jitter indices for idx in range(start_index[ndr], end_index[ndr]): # selecting the jitter offset indices j_jit = y_jit[idx] + j i_jit = x_jit[idx] + i # if negative index, then roll the array if i_jit < 0: i_jit = i_jit + fp.shape[2] elif i_jit >= fp.shape[2]: i_jit = i_jit - fp.shape[2] if j_jit < 0: j_jit = j_jit + fp.shape[1] elif j_jit >= fp.shape[1]: j_jit = j_jit - fp.shape[1] # if negative index, then roll the array time_line[ndr, y, x] = ( time_line[ndr, y, x] + fp[t, j_jit, i_jit] ) # move to the next pixel in the spectral direction i = i + osf # move to the next pixel in the spatial direction j = j + osf # divide by the number of jitter positions added time_line[ndr] = time_line[ndr] / ( end_index[ndr] - start_index[ndr] ) return time_line
@staticmethod @jit(nopython=True, parallel=True)
[docs] def replicating_the_focalplane( fp: np.ndarray, index: np.ndarray, fp_time: np.ndarray ) -> np.ndarray: fp = fp.astype(np.float64) index = index.astype(np.int64) fp_time = fp_time.astype(np.int64) n_nDR, n_y, n_x = index.shape[0], fp.shape[1], fp.shape[2] time_line = np.empty((n_nDR, n_y, n_x), dtype=np.float64) for ndr in prange(n_nDR): t = fp_time[ndr] time_line[ndr] = fp[t] return time_line