Skip to content

Commit 54d0e09

Browse files
authored
Merge pull request #20 from LSSTDESC/dev/joseph
Dev/joseph
2 parents 67df9f7 + ebd710e commit 54d0e09

File tree

3 files changed

+79
-71
lines changed

3 files changed

+79
-71
lines changed

examples/SHIRE_demo_LSSTsim_mini.ipynb

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -93,8 +93,8 @@
9393
" ssp_file=\"ssp_data_fsps_v3.2_lgmet_age.h5\",\n",
9494
" filter_dict=lsst_filts_dict,\n",
9595
" wlmin=900.,\n",
96-
" wlmax=25000.,\n",
97-
" dwl=100.,\n",
96+
" wlmax=12000.,\n",
97+
" dwl=20.,\n",
9898
" zmin=0.01,\n",
9999
" zmax=3.0,\n",
100100
" nzbins=50,\n",
@@ -244,14 +244,14 @@
244244
" err_bands=_errbands,\n",
245245
" zmin=0.01,\n",
246246
" zmax=3.1,\n",
247-
" nzbins=310,\n",
247+
" nzbins=150,\n",
248248
" ssp_file=\"ssp_data_fsps_v3.2_lgmet_age.h5\",\n",
249249
" filter_dict=lsst_filts_dict,\n",
250-
" wlmin=500.,\n",
251-
" wlmax=25000.,\n",
252-
" dwl=5.,\n",
250+
" wlmin=900.,\n",
251+
" wlmax=12000.,\n",
252+
" dwl=20.,\n",
253253
" no_prior=not(use_prior),\n",
254-
" chunk_size=5000\n",
254+
" chunk_size=250\n",
255255
")\n",
256256
"\n",
257257
"run_shire_estimate = ShireEstimator.make_stage(\n",
@@ -355,7 +355,7 @@
355355
"stages = [run_shire_inform, run_shire_estimate]\n",
356356
"for stage in stages:\n",
357357
" pipe.add_stage(stage)\n",
358-
"pipe.stage_execution_config[f'shire_estimate_lsstSimhp10552_demo{_suffix}'].nprocess=1"
358+
"pipe.stage_execution_config[f'shireSPS_estimate_lsstSimhp10552_demo{_suffix}'].nprocess=1"
359359
]
360360
},
361361
{
@@ -426,7 +426,7 @@
426426
"name": "python",
427427
"nbconvert_exporter": "python",
428428
"pygments_lexer": "ipython3",
429-
"version": "3.11.8"
429+
"version": "3.13.5"
430430
}
431431
},
432432
"nbformat": 4,

src/rail/estimation/algos/inform.py

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -422,7 +422,7 @@ def _load_templates(self):
422422
self.avs,
423423
sspdata
424424
)
425-
templ_d4k = treemap_d4000(templ_pars_arr, fwls, pzs, self.avs, sspdata)
425+
templ_d4k = treemap_d4000(templ_pars_arr, pzs, self.avs, sspdata)
426426
else:
427427
templ_tupl_sps = make_legacy_templates(
428428
templ_pars_arr,
@@ -439,7 +439,7 @@ def _load_templates(self):
439439
self.avs,
440440
sspdata
441441
)
442-
templ_d4k = treemap_d4000_leg(templ_pars_arr, fwls, templ_zref, self.avs, sspdata)
442+
templ_d4k = treemap_d4000_leg(templ_pars_arr, templ_zref, self.avs, sspdata)
443443

444444
filters_names = [_fnam for _fnam, _fdir in self.config.filter_dict.items()]
445445
color_names = [f"{n1}-{n2}" for n1,n2 in zip(filters_names[:-1], filters_names[1:])]
@@ -1553,19 +1553,19 @@ def plot_templ_seds(self, redshifts=None, wlmin=None, wlmax=None, ymin=None, yma
15531553
vmap_mean_spectrum(wls, templ_pars, redshifts, sspdata),
15541554
0.001
15551555
)
1556-
nuvk = v_nuvk_dusty(templ_pars, wls, redshifts, sspdata)
1557-
_selnorm = jnp.logical_and(wls>3950, wls<4000)
1556+
nuvk = v_nuvk_dusty(templ_pars, redshifts, sspdata)
1557+
_selnorm = jnp.logical_and(wls>=3950, wls<=4000)
15581558
norms = jnp.nanmean(restframe_fnus[:, :, _selnorm], axis=2)
15591559
restframe_fnus = restframe_fnus/jnp.expand_dims(jnp.squeeze(norms), 2)
15601560
else:
15611561
_vspec = vmap(mean_spectrum, in_axes=(None, 0, 0, None))
1562-
_vnuvk = vmap(calc_nuvk_dusty, in_axes=(0, None, 0, None))
1562+
_vnuvk = vmap(calc_nuvk_dusty, in_axes=(0, 0, None))
15631563
restframe_fnus = lsunPerHz_to_fnu_noU(
15641564
_vspec(wls, templ_pars, templ_zref, sspdata),
15651565
0.001
15661566
)
1567-
nuvk = _vnuvk(templ_pars, wls, templ_zref, sspdata)
1568-
_selnorm = jnp.logical_and(wls>3950, wls<4000)
1567+
nuvk = _vnuvk(templ_pars, templ_zref, sspdata)
1568+
_selnorm = jnp.logical_and(wls>=3950, wls<=4000)
15691569
norms = jnp.nanmean(restframe_fnus[:, _selnorm], axis=1)
15701570
restframe_fnus = restframe_fnus/jnp.expand_dims(jnp.squeeze(norms), 1)
15711571

@@ -1643,19 +1643,19 @@ def plot_templ_seds_d4000(self, redshifts=None, wlmin=None, wlmax=None, ymin=Non
16431643
vmap_mean_spectrum_nodust(wls, templ_pars, redshifts, sspdata),
16441644
0.001
16451645
)
1646-
d4000n = v_d4000n(templ_pars, wls, redshifts, sspdata)
1647-
_selnorm = jnp.logical_and(wls>3950, wls<4000)
1646+
d4000n = v_d4000n(templ_pars, redshifts, sspdata)
1647+
_selnorm = jnp.logical_and(wls>=3950, wls<=4000)
16481648
norms = jnp.nanmean(restframe_fnus[:, :, _selnorm], axis=2)
16491649
restframe_fnus = restframe_fnus/jnp.expand_dims(jnp.squeeze(norms), 2)
16501650
else:
16511651
_vspec = vmap(mean_spectrum_nodust, in_axes=(None, 0, 0, None))
1652-
_vd4k = vmap(calc_d4000n, in_axes=(0, None, 0, None))
1652+
_vd4k = vmap(calc_d4000n, in_axes=(0, 0, None))
16531653
restframe_fnus = lsunPerHz_to_fnu_noU(
16541654
_vspec(wls, templ_pars, templ_zref, sspdata),
16551655
0.001
16561656
)
1657-
d4000n = _vd4k(templ_pars, wls, templ_zref, sspdata)
1658-
_selnorm = jnp.logical_and(wls>3950, wls<4000)
1657+
d4000n = _vd4k(templ_pars, templ_zref, sspdata)
1658+
_selnorm = jnp.logical_and(wls>=3950, wls<=4000)
16591659
norms = jnp.nanmean(restframe_fnus[:, _selnorm], axis=1)
16601660
restframe_fnus = restframe_fnus/jnp.expand_dims(jnp.squeeze(norms), 1)
16611661
rbmap = mpl.colormaps['coolwarm']

src/rail/estimation/algos/template.py

Lines changed: 58 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -432,7 +432,7 @@ def templ_iclrs_nuvk(params, wls, filt_trans_arr, z_obs, av, ssp_data, id_imag):
432432

433433

434434
@jit
435-
def calc_nuvk(pars_arr, wls, z_obs, ssp_data):
435+
def calc_nuvk(pars_arr, z_obs, ssp_data):
436436
"""calc_nuvk _summary_
437437
438438
:param pars_arr: _description_
@@ -446,138 +446,146 @@ def calc_nuvk(pars_arr, wls, z_obs, ssp_data):
446446
:return: _description_
447447
:rtype: _type_
448448
"""
449-
sed = mean_spectrum_nodust(wls, pars_arr, z_obs, ssp_data)
449+
#sed = mean_spectrum_nodust(wls, pars_arr, z_obs, ssp_data)
450+
451+
# get the restframe spectra without and with dust attenuation
452+
ssp_wave, rest_sed, _ = ssp_spectrum_fromparam(pars_arr, z_obs, ssp_data)
450453
_nuvk = jnp.array(
451454
[
452-
calc_rest_mag(wls, sed, NUV_filt.wavelength, NUV_filt.transmission),
453-
calc_rest_mag(wls, sed, NIR_filt.wavelength, NIR_filt.transmission)
455+
calc_rest_mag(ssp_wave, rest_sed, NUV_filt.wavelength, NUV_filt.transmission),
456+
calc_rest_mag(ssp_wave, rest_sed, NIR_filt.wavelength, NIR_filt.transmission)
454457
]
455458
)
456459

457460
return _nuvk[0]-_nuvk[1]
458461

459462

460-
v_nuvk_zo = vmap(calc_nuvk, in_axes=(None, None, 0, None))
461-
v_nuvk = vmap(v_nuvk_zo, in_axes=(0, None, None, None))
463+
v_nuvk_zo = vmap(calc_nuvk, in_axes=(None, 0, None))
464+
v_nuvk = vmap(v_nuvk_zo, in_axes=(0, None, None))
462465

463466

464467
@jit
465-
def calc_nuvk_dusty(pars_arr, wls, z_obs, ssp_data):
468+
def calc_nuvk_dusty(pars_arr, z_obs, ssp_data):
466469
"""calc_nuvk_dusty _summary_
467470
468471
:param pars_arr: _description_
469472
:type pars_arr: _type_
470-
:param wls: _description_
471-
:type wls: _type_
472473
:param z_obs: _description_
473474
:type z_obs: _type_
474475
:param ssp_data: _description_
475476
:type ssp_data: _type_
476477
:return: _description_
477478
:rtype: _type_
478479
"""
479-
sed = mean_spectrum(wls, pars_arr, z_obs, ssp_data)
480+
#sed = mean_spectrum(wls, pars_arr, z_obs, ssp_data)
481+
482+
# get the restframe spectra without and with dust attenuation
483+
ssp_wave, _, sed_attenuated = ssp_spectrum_fromparam(pars_arr, z_obs, ssp_data)
480484
_nuvk = jnp.array(
481485
[
482-
calc_rest_mag(wls, sed, NUV_filt.wavelength, NUV_filt.transmission),
483-
calc_rest_mag(wls, sed, NIR_filt.wavelength, NIR_filt.transmission)
486+
calc_rest_mag(ssp_wave, sed_attenuated, NUV_filt.wavelength, NUV_filt.transmission),
487+
calc_rest_mag(ssp_wave, sed_attenuated, NIR_filt.wavelength, NIR_filt.transmission)
484488
]
485489
)
486490

487491
return _nuvk[0]-_nuvk[1]
488492

489493

490-
v_nuvk_zo_dusty = vmap(calc_nuvk_dusty, in_axes=(None, None, 0, None))
491-
v_nuvk_dusty = vmap(v_nuvk_zo_dusty, in_axes=(0, None, None, None))
494+
v_nuvk_zo_dusty = vmap(calc_nuvk_dusty, in_axes=(None, 0, None))
495+
v_nuvk_dusty = vmap(v_nuvk_zo_dusty, in_axes=(0, None, None))
492496

493497

494498
@jit
495-
def calc_d4000n(pars_arr, wls, z_obs, ssp_data):
499+
def calc_d4000n(pars_arr, z_obs, ssp_data):
496500
"""calc_d4000n _summary_
497501
498502
:param pars_arr: _description_
499503
:type pars_arr: _type_
500-
:param wls: _description_
501-
:type wls: _type_
502504
:param z_obs: _description_
503505
:type z_obs: _type_
504506
:param ssp_data: _description_
505507
:type ssp_data: _type_
506508
:return: _description_
507509
:rtype: _type_
508510
"""
509-
sed = mean_spectrum_nodust(wls, pars_arr, z_obs, ssp_data)
511+
#sed = mean_spectrum_nodust(wls, pars_arr, z_obs, ssp_data)
512+
513+
# get the restframe spectra without and with dust attenuation
514+
ssp_wave, rest_sed, _ = ssp_spectrum_fromparam(pars_arr, z_obs, ssp_data)
510515
d4000 = jnp.array(
511516
[
512-
calc_rest_mag(wls, sed, D4000b_filt.wavelength, D4000b_filt.transmission),
513-
calc_rest_mag(wls, sed, D4000r_filt.wavelength, D4000r_filt.transmission)
517+
calc_rest_mag(ssp_wave, rest_sed, D4000b_filt.wavelength, D4000b_filt.transmission),
518+
calc_rest_mag(ssp_wave, rest_sed, D4000r_filt.wavelength, D4000r_filt.transmission)
514519
]
515520
)
516521

517522
return d4000[0]-d4000[1]
518523

519524

520-
v_d4000n_zo = vmap(calc_d4000n, in_axes=(None, None, 0, None))
521-
v_d4000n = vmap(v_d4000n_zo, in_axes=(0, None, None, None))
525+
v_d4000n_zo = vmap(calc_d4000n, in_axes=(None, 0, None))
526+
v_d4000n = vmap(v_d4000n_zo, in_axes=(0, None, None))
522527

523528

524529
@jit
525-
def calc_d4000n_dusty(pars_arr, wls, z_obs, ssp_data):
530+
def calc_d4000n_dusty(pars_arr, z_obs, ssp_data):
526531
"""calc_d4000n_dusty _summary_
527532
528533
:param pars_arr: _description_
529534
:type pars_arr: _type_
530-
:param wls: _description_
531-
:type wls: _type_
532535
:param z_obs: _description_
533536
:type z_obs: _type_
534537
:param ssp_data: _description_
535538
:type ssp_data: _type_
536539
:return: _description_
537540
:rtype: _type_
538541
"""
539-
sed = mean_spectrum(wls, pars_arr, z_obs, ssp_data)
542+
#sed = mean_spectrum(wls, pars_arr, z_obs, ssp_data)
543+
544+
# get the restframe spectra without and with dust attenuation
545+
ssp_wave, _, sed_attenuated = ssp_spectrum_fromparam(pars_arr, z_obs, ssp_data)
540546
d4000 = jnp.array(
541547
[
542-
calc_rest_mag(wls, sed, D4000b_filt.wavelength, D4000b_filt.transmission),
543-
calc_rest_mag(wls, sed, D4000r_filt.wavelength, D4000r_filt.transmission)
548+
calc_rest_mag(ssp_wave, sed_attenuated, D4000b_filt.wavelength, D4000b_filt.transmission),
549+
calc_rest_mag(ssp_wave, sed_attenuated, D4000r_filt.wavelength, D4000r_filt.transmission)
544550
]
545551
)
546552

547553
return d4000[0]-d4000[1]
548554

549555

550-
v_d4000n_zo_dusty = vmap(calc_d4000n_dusty, in_axes=(None, None, 0, None))
551-
v_d4000n_dusty = vmap(v_d4000n_zo_dusty, in_axes=(0, None, None, None))
556+
v_d4000n_zo_dusty = vmap(calc_d4000n_dusty, in_axes=(None, 0, None))
557+
v_d4000n_dusty = vmap(v_d4000n_zo_dusty, in_axes=(0, None, None))
552558

553559

554560
@jit
555-
def d4000n(pars_arr, wls, z_obs, av, ssp_data):
561+
def d4000n(pars_arr, z_obs, av, ssp_data):
556562
_pars = pars_arr.at[13].set(av)
557-
sed = mean_spectrum(wls, _pars, z_obs, ssp_data)
563+
#sed = mean_spectrum(wls, _pars, z_obs, ssp_data)
564+
# get the restframe spectra without and with dust attenuation
565+
ssp_wave, _, sed_attenuated = ssp_spectrum_fromparam(_pars, z_obs, ssp_data)
558566
d4000 = jnp.array(
559567
[
560-
calc_rest_mag(wls, sed, D4000b_filt.wavelength, D4000b_filt.transmission),
561-
calc_rest_mag(wls, sed, D4000r_filt.wavelength, D4000r_filt.transmission)
568+
calc_rest_mag(ssp_wave, sed_attenuated, D4000b_filt.wavelength, D4000b_filt.transmission),
569+
calc_rest_mag(ssp_wave, sed_attenuated, D4000r_filt.wavelength, D4000r_filt.transmission)
562570
]
563571
)
564572

565573
return d4000[0]-d4000[1]
566574

567-
vmap_d4000n_av = vmap(d4000n, in_axes=(None, None, None, 0, None))
568-
vmap_d4000n_zob = vmap(vmap_d4000n_av, in_axes=(None, None, 0, None, None))
569-
vmap_d4000n_pars = vmap(vmap_d4000n_zob, in_axes=(0, None, None, None, None))
575+
vmap_d4000n_av = vmap(d4000n, in_axes=(None, None, 0, None))
576+
vmap_d4000n_zob = vmap(vmap_d4000n_av, in_axes=(None, 0, None, None))
577+
vmap_d4000n_pars = vmap(vmap_d4000n_zob, in_axes=(0, None, None, None))
570578

571-
def treemap_d4000(pars_arr, wls, z_obs, av, ssp_data):
579+
def treemap_d4000(pars_arr, z_obs, av, ssp_data):
572580
templ_tupl = [tuple(_pars) for _pars in pars_arr]
573-
reslist_of_tupl = tree_map(lambda partup: vmap_d4000n_zob(jnp.array(partup), wls, z_obs, av, ssp_data), templ_tupl, is_leaf=istuple)
581+
reslist_of_tupl = tree_map(lambda partup: vmap_d4000n_zob(jnp.array(partup), z_obs, av, ssp_data), templ_tupl, is_leaf=istuple)
574582
return reslist_of_tupl
575583

576-
vmap_d4000n_pars_leg = vmap(vmap_d4000n_av, in_axes=(0, None, 0, None, None))
584+
vmap_d4000n_pars_leg = vmap(vmap_d4000n_av, in_axes=(0, 0, None, None))
577585

578-
def treemap_d4000_leg(pars_arr, wls, zref, av, ssp_data):
586+
def treemap_d4000_leg(pars_arr, zref, av, ssp_data):
579587
templ_tupl = [tuple(_pars)+tuple([z]) for _pars, z in zip(pars_arr, zref, strict=True)]
580-
reslist_of_tupl = tree_map(lambda partup: vmap_d4000n_av(jnp.array(partup[:-1]), wls, partup[-1], av, ssp_data), templ_tupl, is_leaf=istuple)
588+
reslist_of_tupl = tree_map(lambda partup: vmap_d4000n_av(jnp.array(partup[:-1]), partup[-1], av, ssp_data), templ_tupl, is_leaf=istuple)
581589
return reslist_of_tupl
582590

583591
def get_colors_templates(params, wls, z_obs, transm_arr, ssp_data):
@@ -1106,8 +1114,8 @@ def bpt_rews_pars_dusty_leg(templ_pars, zref, ssp_data):
11061114
def colrs_bptrews_templ_zo(templ_pars, wls, zobs, transm_arr, ssp_data):
11071115
t_rews = bpt_rews_pars_zo(templ_pars, zobs, ssp_data)
11081116
t_colors = vmap_cols_zo_nodust(templ_pars, wls, zobs, transm_arr, ssp_data)
1109-
t_nuvk = v_nuvk_zo(templ_pars, wls, zobs, ssp_data)
1110-
t_d4000n = v_d4000n_zo(templ_pars, wls, zobs, ssp_data)
1117+
t_nuvk = v_nuvk_zo(templ_pars, zobs, ssp_data)
1118+
t_d4000n = v_d4000n_zo(templ_pars, zobs, ssp_data)
11111119
return jnp.column_stack((t_colors, t_rews, t_nuvk, t_d4000n))
11121120

11131121
vmap_colrs_bptrews_templ_zo = vmap(colrs_bptrews_templ_zo, in_axes=(0, None, None, None, None))
@@ -1116,8 +1124,8 @@ def colrs_bptrews_templ_zo(templ_pars, wls, zobs, transm_arr, ssp_data):
11161124
def colrs_bptrews_templ_zo_dusty(templ_pars, wls, zobs, transm_arr, ssp_data):
11171125
t_rews = bpt_rews_pars_zo_dusty(templ_pars, zobs, ssp_data)
11181126
t_colors = vmap_cols_zo(templ_pars, wls, zobs, transm_arr, ssp_data)
1119-
t_nuvk = v_nuvk_zo_dusty(templ_pars, wls, zobs, ssp_data) #v_nuvk_zo(templ_pars, wls, zobs, ssp_data) -- NUV-K for prior shall perhaps include dust attenuation
1120-
t_d4000n = v_d4000n_zo_dusty(templ_pars, wls, zobs, ssp_data)
1127+
t_nuvk = v_nuvk_zo_dusty(templ_pars, zobs, ssp_data) #v_nuvk_zo(templ_pars, wls, zobs, ssp_data) -- NUV-K for prior shall perhaps include dust attenuation
1128+
t_d4000n = v_d4000n_zo_dusty(templ_pars, zobs, ssp_data)
11211129
return jnp.column_stack((t_colors, t_rews, t_nuvk, t_d4000n))
11221130

11231131
vmap_colrs_bptrews_templ_zo_dusty = vmap(colrs_bptrews_templ_zo_dusty, in_axes=(0, None, None, None, None))
@@ -1127,8 +1135,8 @@ def colrs_bptrews_templ_zo_dusty(templ_pars, wls, zobs, transm_arr, ssp_data):
11271135
def colrs_bptrews_templ_zo_leg(templ_pars, wls, zobs, zref, transm_arr, ssp_data):
11281136
t_rews = bpt_rews_pars_leg(templ_pars, zref, ssp_data)
11291137
t_colors = vmap_cols_zo_nodust_leg(templ_pars, wls, zobs, zref, transm_arr, ssp_data)
1130-
t_nuvk = calc_nuvk(templ_pars, wls, zref, ssp_data)
1131-
t_d4000n = calc_d4000n(templ_pars, wls, zref, ssp_data)
1138+
t_nuvk = calc_nuvk(templ_pars, zref, ssp_data)
1139+
t_d4000n = calc_d4000n(templ_pars, zref, ssp_data)
11321140
return jnp.column_stack(
11331141
(
11341142
t_colors,
@@ -1144,8 +1152,8 @@ def colrs_bptrews_templ_zo_leg(templ_pars, wls, zobs, zref, transm_arr, ssp_data
11441152
def colrs_bptrews_templ_zo_dusty_leg(templ_pars, wls, zobs, zref, transm_arr, ssp_data):
11451153
t_rews = bpt_rews_pars_dusty_leg(templ_pars, zref, ssp_data)
11461154
t_colors = vmap_cols_zo_leg(templ_pars, wls, zobs, zref, transm_arr, ssp_data)
1147-
t_nuvk = calc_nuvk_dusty(templ_pars, wls, zref, ssp_data) # calc_nuvk(templ_pars, wls, zref, ssp_data) # -- NUV-K for prior shall perhaps include dust attenuation
1148-
t_d4000n = calc_d4000n_dusty(templ_pars, wls, zref, ssp_data)
1155+
t_nuvk = calc_nuvk_dusty(templ_pars, zref, ssp_data) # calc_nuvk(templ_pars, wls, zref, ssp_data) # -- NUV-K for prior shall perhaps include dust attenuation
1156+
t_d4000n = calc_d4000n_dusty(templ_pars, zref, ssp_data)
11491157
return jnp.column_stack(
11501158
(
11511159
t_colors,

0 commit comments

Comments
 (0)