Skip to content

Commit 18a03d5

Browse files
committed
placeholder while I wait for Rob's snappl PR. also, some ruff
1 parent d8dc15f commit 18a03d5

File tree

3 files changed

+49
-37
lines changed

3 files changed

+49
-37
lines changed

phrosty/pipeline.py

Lines changed: 21 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -94,8 +94,7 @@ def __init__( self, image, pipeline ):
9494
self.psf_data = None
9595

9696
def run_sky_subtract( self, mp=True ):
97-
"""
98-
Run sky subtraction using Source Extractor.
97+
"""Run sky subtraction using Source Extractor.
9998
10099
Parameters
101100
----------
@@ -108,24 +107,23 @@ def run_sky_subtract( self, mp=True ):
108107
Tuple containing sky subtracted image, detection
109108
mask array, and sky RMS value.
110109
Output of phrosty.imagesubtraction.sky_subtract().
111-
"""
110+
"""
112111
try:
113112
return sky_subtract( self.image )
114113
except Exception as ex:
115114
SNLogger.exception( ex )
116115
raise
117116

118117
def save_sky_subtract_info( self, info ):
119-
"""
120-
Saves the sky-subtracted image, detection mask array,
118+
"""Saves the sky-subtracted image, detection mask array,
121119
and sky RMS values to attributes.
122120
123121
Parameters
124122
----------
125123
info : tuple
126124
Output of self.run_sky_subtract(). See documentation
127125
for phrosty.imagesubtraction.sky_subtract().
128-
"""
126+
"""
129127
try:
130128
SNLogger.debug( f"Saving sky_subtract info for path {info[0]}" )
131129
self.skysub_img = info[0]
@@ -177,8 +175,7 @@ def get_psf( self, ra, dec ):
177175
return None
178176

179177
def keep_psf_data( self, psf_data ):
180-
"""
181-
Save PSF data to attribute. If self.get_psf() failed,
178+
"""Save PSF data to attribute. If self.get_psf() failed,
182179
then the image is recorded as a failure.
183180
184181
Parameters
@@ -298,8 +295,7 @@ def __init__( self, diaobj, imgcol, band,
298295

299296

300297
def _read_csv( self, csvfile ):
301-
"""
302-
Reads input csv files with columns:
298+
"""Reads input csv files with columns:
303299
'path pointing sca mjd band'.
304300
305301
Parameters
@@ -332,21 +328,20 @@ def _read_csv( self, csvfile ):
332328

333329

334330
def sky_sub_all_images( self ):
335-
"""
336-
Sky subtracts all snappl.image.Image objects in
331+
"""Sky subtracts all snappl.image.Image objects in
337332
self.science_images and self.template_images using
338333
Source Extractor.
339334
340335
Contains its own error logging function, log_error().
341336
342-
"""
337+
"""
343338

344339
# Currently, this writes out a bunch of FITS files. Further refactoring needed
345340
# to support more general image types.
346341
all_imgs = self.science_images.copy() # shallow copy
347342
all_imgs.extend( self.template_images )
348343

349-
def log_error( img, x ):
344+
def log_error( img, x ):
350345

351346
SNLogger.error( f"Sky subtraction failure on {img.image.path}: {x}" )
352347
self.failures['skysub'].append( f'{img.image.band} {img.image.pointing} {img.image.sca}' )
@@ -364,8 +359,7 @@ def log_error( img, x ):
364359
img.save_sky_subtract_info( img.run_sky_subtract( mp=False ) )
365360

366361
def get_psfs( self ):
367-
"""
368-
Retrieve PSFs for all snappl.image.Image objects in
362+
"""Retrieve PSFs for all snappl.image.Image objects in
369363
self.science_images and self.template_images.
370364
371365
Contains its own error logging function, log_error().
@@ -385,7 +379,7 @@ def log_error( img, x ):
385379
# callback_partial = partial( img.save_psf_path, all_imgs )
386380
pool.apply_async( img.get_psf, (self.diaobj.ra, self.diaobj.dec), {},
387381
callback=img.keep_psf_data,
388-
error_callback=partial(log_error,img) )
382+
error_callback=partial(log_error, img) )
389383
pool.close()
390384
pool.join()
391385
else:
@@ -598,8 +592,7 @@ def make_phot_info_dict( self, sci_image, templ_image, ap_r=4 ):
598592
return results_dict
599593

600594
def add_to_results_dict( self, one_pair ):
601-
"""
602-
Record results from self.make_phot_info_dict() to the
595+
"""Record results from self.make_phot_info_dict() to the
603596
aggregate dictionary for the entire light curve.
604597
605598
Parameters
@@ -619,9 +612,7 @@ def add_to_results_dict( self, one_pair ):
619612
SNLogger.debug( "Done adding to results dict" )
620613

621614
def save_stamp_paths( self, sci_image, templ_image, paths ):
622-
"""
623-
624-
Helper function for recording the stamp paths returned in
615+
"""Helper function for recording the stamp paths returned in
625616
self.do_stamps.
626617
627618
Parameters
@@ -641,8 +632,7 @@ def save_stamp_paths( self, sci_image, templ_image, paths ):
641632
sci_image.diff_var_stamp_path[ templ_image.image.name ] = paths[2]
642633

643634
def do_stamps( self, sci_image, templ_image ):
644-
"""
645-
Make stamps from the zero point image, decorrelated
635+
"""Make stamps from the zero point image, decorrelated
646636
difference image, and variance image centered at the
647637
location of the supernova.
648638
@@ -659,7 +649,7 @@ def do_stamps( self, sci_image, templ_image ):
659649
Paths to the stamps corresponding to the zero point image,
660650
decorrelated difference image, and variance image centered
661651
at the location of the supernova.
662-
"""
652+
"""
663653

664654
try:
665655
zptim = FITSImageOnDisk( sci_image.decorr_zptimg_path[ templ_image.image.name ] )
@@ -694,9 +684,9 @@ def do_stamps( self, sci_image, templ_image ):
694684
self.failures['make_stamps'].append({'science': f'{sci_image.image.band} {sci_image.image.pointing} {sci_image.image.sca}',
695685
'template': f'{templ_image.image.band} {templ_image.image.pointing} {templ_image.image.sca}'
696686
})
687+
697688
def make_lightcurve( self ):
698-
"""
699-
Collect all results from photometry in one dictionary.
689+
"""Collect all results from photometry in one dictionary.
700690
Write the output to a csv as a table.
701691
702692
Contains its own error logging function, log_error().
@@ -705,7 +695,7 @@ def make_lightcurve( self ):
705695
-------
706696
pathlib.Path
707697
Path to output csv file that contains a light curve.
708-
"""
698+
"""
709699
SNLogger.info( "Making lightcurve." )
710700

711701
self.results_dict = {
@@ -760,8 +750,7 @@ def log_error( sci_image, templ_image, x ):
760750
return results_savepath
761751

762752
def write_fits_file( self, data, header, savepath ):
763-
"""
764-
Helper function for writing fits files.
753+
"""Helper function for writing fits files.
765754
766755
Parameters
767756
----------
@@ -784,15 +773,14 @@ def write_fits_file( self, data, header, savepath ):
784773
raise
785774

786775
def clear_contents( self, directory ):
787-
"""
788-
Delete contents of a directory. Used to clear temporary
776+
"""Delete contents of a directory. Used to clear temporary
789777
files.
790778
791779
Parameters
792780
----------
793781
directory : pathlib.Path
794782
Path to directory to empty.
795-
"""
783+
"""
796784
for f in directory.iterdir():
797785
try:
798786
if f.is_dir():

phrosty/tests/conftest.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,12 @@
1+
import numpy as np
12
import pytest # noqa: F401
23
import pathlib
34

45
import tox # noqa: F401
56
from tox.pytest import init_fixture # noqa: F401
67

78
from snpit_utils.config import Config
8-
from snappl.image import FITSImageOnDisk
9+
from snappl.image import FITSImageOnDisk, ManualFITSImage
910
from snappl.diaobject import DiaObject
1011
from snappl.imagecollection import ImageCollection
1112

@@ -174,3 +175,14 @@ def two_ou2024_science_images( ou2024_image_collection ):
174175
img1 = ou2024_image_collection.get_image( pointing=35198, sca=2, band='Y106' )
175176
img2 = ou2024_image_collection.get_image( pointing=39790, sca=15, band='Y106' )
176177
return img1, img2
178+
179+
@pytest.fixture
180+
def nan_image( ou2024_image_collection ):
181+
ou_img = ou2024_image_collection.get_image( pointing=35198, sca=2, band='Y106' )
182+
ou_header = ou_img.get_fits_header()
183+
184+
nan_arr = np.empty((4088,4088))
185+
nan_arr[:] = np.nan
186+
187+
nan_img = ManualFITSImage(header=ou_header, data=nan_arr, pointing=35198, sca=2)
188+
return nan_img

phrosty/tests/test_pipeline.py

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import os
2+
import numpy as np
23
import pytest
34
from phrosty.pipeline import Pipeline
45

@@ -70,16 +71,27 @@ def test_pipeline_run( object_for_tests, ou2024_image_collection,
7071
# directories for tests and for running... so don't do that... but the
7172
# way we're set up right now, you probably are.
7273

74+
7375
@pytest.mark.skipif( os.getenv("SKIP_GPU_TESTS", 0 ), reason="SKIP_GPU_TESTS is set" )
7476
def test_pipeline_failures( object_for_tests, ou2024_image_collection,
75-
one_ou2024_template_image, two_ou2024_science_images ):
77+
one_ou2024_template_image, two_ou2024_science_images,
78+
nan_image ):
7679
pip = Pipeline( object_for_tests, ou2024_image_collection, 'Y106',
7780
science_images=two_ou2024_science_images,
7881
template_images=[one_ou2024_template_image],
7982
nprocs=2, nwrite=3 )
80-
83+
8184
# First, check the images as-is. Make sure there are no failures.
8285
for key in pip.failures:
8386
assert len(pip.failures[key]) == 0
87+
88+
new_test_imgs = [nan_image, two_ou2024_science_images[1]]
8489

85-
90+
pip = Pipeline( object_for_tests, ou2024_image_collection, 'Y106',
91+
science_images=new_test_imgs,
92+
template_images=[one_ou2024_template_image],
93+
nprocs=2, nwrite=3 )
94+
95+
for key in pip.failures:
96+
print(key)
97+
print(len(pip.failures[key]))

0 commit comments

Comments
 (0)