Skip to content

Commit 4d537c2

Browse files
authored
Merge pull request #95 from Roman-Supernova-PIT/91-free-mem
free memory. update phrosty_config.yaml to match
2 parents 6f5aecb + d0846d4 commit 4d537c2

File tree

3 files changed

+36
-1
lines changed

3 files changed

+36
-1
lines changed

changes/95.bugfix.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Add function to free memory after processing images. Also, add option to trace CPU memory. Updates example phrosty_config.yaml file to match new mem_trace option.

examples/perlmutter/phrosty_config.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ photometry:
77
image_type: ou2024fits
88
force_sky_subtract: true
99
keep_intermediate: true
10+
mem_trace: false
1011

1112
paths:
1213
scratch_dir: /scratch

phrosty/pipeline.py

Lines changed: 34 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
import nvtx
1010
import pathlib
1111
import re
12+
import tracemalloc
1213

1314
# Imports ASTRO
1415
from astropy.coordinates import SkyCoord
@@ -172,6 +173,10 @@ def get_psf( self, ra, dec ):
172173
def keep_psf_data( self, psf_data ):
173174
self.psf_data = psf_data
174175

176+
def free( self ):
177+
"""Try to free memory. More might be done here."""
178+
self.image.free()
179+
175180

176181
class Pipeline:
177182
def __init__( self, object_id, ra, dec, band, science_images, template_images, nprocs=1, nwrite=5,
@@ -227,6 +232,7 @@ def __init__( self, object_id, ra, dec, band, science_images, template_images, n
227232
SNLogger.warning( "nuke_temp_dir not implemented" )
228233

229234
self.keep_intermediate = self.config.value( 'photometry.phrosty.keep_intermediate' )
235+
self.mem_trace = self.config.value( 'photometry.phrosty.mem_trace' )
230236

231237

232238
def sky_sub_all_images( self ):
@@ -588,6 +594,11 @@ def write_fits_file( self, data, header, savepath ):
588594
fits.writeto( savepath, data, header=header, overwrite=True )
589595

590596
def __call__( self, through_step=None ):
597+
if self.mem_trace:
598+
tracemalloc.start()
599+
tracemalloc.reset_peak()
600+
601+
591602
if through_step is None:
592603
through_step = 'make_lightcurve'
593604

@@ -603,11 +614,17 @@ def __call__( self, through_step=None ):
603614
with nvtx.annotate( "skysub", color=0xff8888 ):
604615
self.sky_sub_all_images()
605616

617+
if self.mem_trace:
618+
SNLogger.info( f"After sky_subtract, memory usage = {tracemalloc.get_traced_memory()[1]/(1024**2):.2f} MB" )
619+
606620
if 'get_psfs' in steps:
607621
SNLogger.info( "Getting PSFs" )
608622
with nvtx.annotate( "getpsfs", color=0xff8888 ):
609623
self.get_psfs()
610624

625+
if self.mem_trace:
626+
SNLogger.info( f"After get_psfs, memory usage = {tracemalloc.get_traced_memory()[1]/(1024**2):.2f} MB" )
627+
611628
# Create a process pool to write fits files
612629
with Pool( self.nwrite ) as fits_writer_pool:
613630

@@ -741,7 +758,16 @@ def log_fits_write_error( savepath, x ):
741758
savepath = self.scratch_dir / f'{key}_{imgtype}_{name}'
742759
self.write_fits_file( data, header, savepath=savepath )
743760

744-
SNLogger.info( f"DONE processing {sci_image.image.name} minus {templ_image.image.name}" )
761+
SNLogger.info( f"DONE processing {sci_image.image.name} minus {templ_image.image.name}" )
762+
if self.mem_trace:
763+
SNLogger.info( f"After preprocessing, subtracting, and postprocessing \
764+
a science image, memory usage = \
765+
{tracemalloc.get_traced_memory()[1]/(1024**2):.2f} MB" )
766+
767+
sci_image.free()
768+
769+
SNLogger.info( f"DONE with all science images for template {templ_image.image.name}" )
770+
templ_image.free()
745771

746772
SNLogger.info( "Waiting for FITS writer processes to finish" )
747773
with nvtx.annotate( "fits_write_wait", color=0xff8888 ):
@@ -788,14 +814,21 @@ def log_fits_write_error( savepath, x ):
788814

789815
SNLogger.info('...finished making stamps.')
790816

817+
if self.mem_trace:
818+
SNLogger.info( f"After make_stamps, memory usage = {tracemalloc.get_traced_memory()[1]/(1024**2):.2f} MB" )
819+
791820
if 'make_lightcurve' in steps:
792821
SNLogger.info( "Making lightcurve" )
793822
with nvtx.annotate( "make_lightcurve", color=0xff8888 ):
794823
self.make_lightcurve()
795824

825+
if self.mem_trace:
826+
SNLogger.info( f"After make_lightcurve, memory usage = \
827+
{tracemalloc.get_traced_memory()[1]/(1024**2):.2f} MB" )
796828

797829
# ======================================================================
798830

831+
799832
def main():
800833
# Run one arg pass just to get the config file, so we can augment
801834
# the full arg parser later with config options

0 commit comments

Comments
 (0)