Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Cleanup resample in preparation for refactoring #1601

Draft
wants to merge 12 commits into
base: main
Choose a base branch
from
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,8 @@ dependencies = [
#"roman_datamodels @ git+https://github.com/spacetelescope/roman_datamodels.git",
"scipy >=1.14.1",
# "stcal>=1.10.0,<1.11.0",
"stcal @ git+https://github.com/spacetelescope/stcal.git@main",
# "stcal @ git+https://github.com/spacetelescope/stcal.git@main",
"stcal @ git+https://github.com/mcara/stcal.git@resample-common-code2",
# "stpipe >=0.7.0,<0.8.0",
"stpipe @ git+https://github.com/spacetelescope/stpipe.git@main",
"tweakwcs >=0.8.8",
Expand Down
1 change: 0 additions & 1 deletion romancal/outlier_detection/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,7 +256,6 @@ def detect_outliers(
resamp = ResampleData(
library,
single=True,
blendheaders=False,
# FIXME prior code provided weight_type when only wht_type is understood
# both default to 'ivm' but tests that set this to something else did
# not change the resampling weight type. For now, disabling it to match
Expand Down
4 changes: 2 additions & 2 deletions romancal/resample/gwcs_drizzle.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import logging

import numpy as np
from drizzle import cdrizzle, util
from drizzle import cdrizzle

from . import resample_utils

Expand Down Expand Up @@ -375,7 +375,7 @@ def dodrizzle(
"""

# Insure that the fillval parameter gets properly interpreted for use with tdriz
if util.is_blank(str(fillval)):
if fillval.strip() == "":
fillval = "INDEF"
else:
fillval = str(fillval)
Expand Down
194 changes: 34 additions & 160 deletions romancal/resample/resample.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import numpy as np
from astropy import units as u
from astropy.coordinates import SkyCoord
from drizzle import cdrizzle, util
from drizzle import cdrizzle
from roman_datamodels import datamodels, maker_utils, stnode
from stcal.alignment.util import compute_s_region_keyword, compute_scale

Expand All @@ -18,11 +18,7 @@
log = logging.getLogger(__name__)
log.setLevel(logging.DEBUG)

__all__ = ["OutputTooLargeError", "ResampleData"]


class OutputTooLargeError(MemoryError):
"""Raised when the output is too large for in-memory instantiation"""
__all__ = ["ResampleData"]


class ResampleData:
Expand Down Expand Up @@ -55,7 +51,12 @@ def __init__(
good_bits="0",
pscale_ratio=1.0,
pscale=None,
**kwargs,
in_memory=True,
output_wcs=None,
output_shape=None,
crpix=None,
crval=None,
rotation=None,
):
"""
Parameters
Expand All @@ -66,19 +67,6 @@ def __init__(

output : str
filename for output

kwargs : dict
Other parameters.

.. note::
``output_shape`` is in the ``x, y`` order.

.. note::
``in_memory`` controls whether or not the resampled
array from ``resample_many_to_many()``
should be kept in memory or written out to disk and
deleted from memory. Default value is `True` to keep
all products in memory.
"""
if (input_models is None) or (len(input_models) == 0):
raise ValueError(
Expand All @@ -94,7 +82,7 @@ def __init__(
self.fillval = fillval
self.weight_type = wht_type
self.good_bits = good_bits
self.in_memory = kwargs.get("in_memory", True)
self.in_memory = in_memory
if "target" in input_models.asn:
self.location_name = input_models.asn["target"]
else:
Expand All @@ -105,12 +93,6 @@ def __init__(
log.info(f"Driz parameter fillval: {self.fillval}")
log.info(f"Driz parameter weight_type: {self.weight_type}")

output_wcs = kwargs.get("output_wcs", None)
output_shape = kwargs.get("output_shape", None)
crpix = kwargs.get("crpix", None)
crval = kwargs.get("crval", None)
rotation = kwargs.get("rotation", None)

if pscale is not None:
log.info(f"Output pixel scale: {pscale} arcsec.")
pscale /= 3600.0
Expand Down Expand Up @@ -151,7 +133,7 @@ def __init__(
models = list(self.input_models)

# update meta.basic
populate_mosaic_basic(self.blank_output, models)
self.blank_output.meta.basic.product_type = "TBD"

# update meta.cal_step
self.blank_output.meta.cal_step = maker_utils.mk_l3_cal_step(
Expand All @@ -164,13 +146,26 @@ def __init__(
cal_logs = model.meta.cal_logs
# removing meta.cal_logs
del model.meta["cal_logs"]

# Update the output with all the component metas
populate_mosaic_individual(self.blank_output, [model])
self.blank_output.append_individual_image_meta(model.meta)

# re-attaching cal_logs to meta
model.meta.cal_logs = cal_logs

# update meta data and wcs
l2_into_l3_meta(self.blank_output.meta, models[0].meta)
l2_meta = models[0].meta
self.blank_output.meta.basic.visit = l2_meta.observation.visit
self.blank_output.meta.basic.segment = l2_meta.observation.segment
self.blank_output.meta.basic["pass"] = l2_meta.observation["pass"]
self.blank_output.meta.basic.program = l2_meta.observation.program
self.blank_output.meta.basic.optical_element = (
l2_meta.instrument.optical_element
)
self.blank_output.meta.basic.instrument = l2_meta.instrument.name
self.blank_output.meta.coordinates = l2_meta.coordinates
self.blank_output.meta.program = l2_meta.program

self.blank_output.meta.wcs = self.output_wcs
gwcs_into_l3(self.blank_output, self.output_wcs)

Expand Down Expand Up @@ -206,7 +201,11 @@ def resample_group(self, input_models, indices):
output_model.meta["resample"] = maker_utils.mk_resample()
output_model.meta.basic.location_name = self.location_name

copy_asn_info_from_library(input_models, output_model)
# copy over asn information
if (asn_pool := input_models.asn.get("asn_pool", None)) is not None:
output_model.meta.asn.pool_name = asn_pool
if (asn_table_name := input_models.asn.get("table_name", None)) is not None:
output_model.meta.asn.table_name = asn_table_name

with input_models:
example_image = input_models.borrow(indices[0])
Expand Down Expand Up @@ -539,18 +538,20 @@ def update_exposure_times(self, output_model, exptime_tot):
f"Mean, max exposure times: {total_exposure_time:.1f}, "
f"{max_exposure_time:.1f}"
)
exposure_times = {"start": [], "end": []}
exposure_times = {"start": [], "end": [], "mid": []}
with self.input_models:
for indices in self.input_models.group_indices.values():
index = indices[0]
model = self.input_models.borrow(index)
exposure_times["start"].append(model.meta.exposure.start_time)
exposure_times["end"].append(model.meta.exposure.end_time)
exposure_times["mid"].append(model.meta.exposure.mid_time.mjd)
self.input_models.shelve(model, index, modify=False)

# Update some basic exposure time values based on output_model
output_model.meta.basic.mean_exposure_time = total_exposure_time
output_model.meta.basic.time_first_mjd = min(exposure_times["start"]).mjd
output_model.meta.basic.time_mean_mjd = np.mean(exposure_times["mid"])
output_model.meta.basic.time_last_mjd = max(exposure_times["end"]).mjd
output_model.meta.basic.max_exposure_time = max_exposure_time
output_model.meta.resample.product_exposure_time = max_exposure_time
Expand Down Expand Up @@ -673,7 +674,7 @@ def drizzle_arrays(
"""

# Insure that the fillval parameter gets properly interpreted for use with tdriz
fillval = "INDEF" if util.is_blank(str(fillval)) else str(fillval)
fillval = "INDEF" if str(fillval).strip() == "" else str(fillval)
if insci.dtype > np.float32:
insci = insci.astype(np.float32)

Expand Down Expand Up @@ -731,39 +732,6 @@ def drizzle_arrays(
)


def l2_into_l3_meta(l3_meta, l2_meta):
"""Update the level 3 meta with info from the level 2 meta

Parameters
----------
l3_meta : dict
The meta to update. This is updated in-place

l2_meta : stnode
The Level 2-like meta to pull from

Notes
-----
The list of meta that is pulled from the Level 2 meta into the Level 3 meta is as follows:
basic.visit: observation.visit
basic.segment: observation.segment
basic.pass: observation.pass
basic.program: observation.program
basic.optical_element: optical_element
basic.instrument: instrument.name
basic.telescope: telescope
program: program
"""
l3_meta.basic.visit = l2_meta.observation.visit
l3_meta.basic.segment = l2_meta.observation.segment
l3_meta.basic["pass"] = l2_meta.observation["pass"]
l3_meta.basic.program = l2_meta.observation.program
l3_meta.basic.optical_element = l2_meta.instrument.optical_element
l3_meta.basic.instrument = l2_meta.instrument.name
l3_meta.coordinates = l2_meta.coordinates
l3_meta.program = l2_meta.program


def gwcs_into_l3(model, wcs):
"""Update the Level 3 wcsinfo block from a GWCS object

Expand Down Expand Up @@ -870,97 +838,3 @@ def calc_pa(wcs, ra, dec):
coord = SkyCoord(ra, dec, frame="icrs", unit="deg")

return coord.position_angle(delta_coord).degree


def populate_mosaic_basic(
output_model: datamodels.MosaicModel, input_models: list | ModelLibrary
):
"""
Populate basic metadata fields in the output mosaic model based on input models.

Parameters
----------
output_model : MosaicModel
Object to populate with basic metadata.
input_models : [List, ModelLibrary]
List of input data models from which to extract the metadata.
ModelLibrary is also supported.

Returns
-------
None
"""

input_meta = [datamodel.meta for datamodel in input_models]

# time data
output_model.meta.basic.time_first_mjd = np.min(
[x.exposure.start_time.mjd for x in input_meta]
)
output_model.meta.basic.time_last_mjd = np.max(
[x.exposure.end_time.mjd for x in input_meta]
)
output_model.meta.basic.time_mean_mjd = np.mean(
[x.exposure.mid_time.mjd for x in input_meta]
)

# observation data
output_model.meta.basic.visit = (
input_meta[0].observation.visit
if len({x.observation.visit for x in input_meta}) == 1
else -1
)
output_model.meta.basic.segment = (
input_meta[0].observation.segment
if len({x.observation.segment for x in input_meta}) == 1
else -1
)
output_model.meta.basic["pass"] = (
input_meta[0].observation["pass"]
if len({x.observation["pass"] for x in input_meta}) == 1
else -1
)
output_model.meta.basic.program = (
input_meta[0].observation.program
if len({x.observation.program for x in input_meta}) == 1
else -1
)

# instrument data
output_model.meta.basic.optical_element = input_meta[0].instrument.optical_element
output_model.meta.basic.instrument = input_meta[0].instrument.name

# association product type
output_model.meta.basic.product_type = "TBD"


def populate_mosaic_individual(
output_model: datamodels.MosaicModel, input_models: [list, ModelLibrary]
):
"""
Populate individual meta fields in the output mosaic model based on input models.

Parameters
----------
output_model : MosaicModel
Object to populate with basic metadata.
input_models : [List, ModelLibrary]
List of input data models from which to extract the metadata.
ModelLibrary is also supported.

Returns
-------
None
"""

input_metas = [datamodel.meta for datamodel in input_models]
for input_meta in input_metas:
output_model.append_individual_image_meta(input_meta)


def copy_asn_info_from_library(input_models, output_model):
# copy over asn information
if (asn_pool := input_models.asn.get("asn_pool", None)) is not None:
output_model.meta.asn.pool_name = asn_pool
if (asn_table_name := input_models.asn.get("table_name", None)) is not None:
output_model.meta.asn.table_name = asn_table_name
Loading
Loading