@@ -432,7 +432,7 @@ def templ_iclrs_nuvk(params, wls, filt_trans_arr, z_obs, av, ssp_data, id_imag):
432432
433433
434434@jit
435- def calc_nuvk (pars_arr , wls , z_obs , ssp_data ):
435+ def calc_nuvk (pars_arr , z_obs , ssp_data ):
436436 """calc_nuvk _summary_
437437
438438 :param pars_arr: _description_
@@ -446,138 +446,146 @@ def calc_nuvk(pars_arr, wls, z_obs, ssp_data):
446446 :return: _description_
447447 :rtype: _type_
448448 """
449- sed = mean_spectrum_nodust (wls , pars_arr , z_obs , ssp_data )
449+ #sed = mean_spectrum_nodust(wls, pars_arr, z_obs, ssp_data)
450+
451+ # get the restframe spectra without and with dust attenuation
452+ ssp_wave , rest_sed , _ = ssp_spectrum_fromparam (pars_arr , z_obs , ssp_data )
450453 _nuvk = jnp .array (
451454 [
452- calc_rest_mag (wls , sed , NUV_filt .wavelength , NUV_filt .transmission ),
453- calc_rest_mag (wls , sed , NIR_filt .wavelength , NIR_filt .transmission )
455+ calc_rest_mag (ssp_wave , rest_sed , NUV_filt .wavelength , NUV_filt .transmission ),
456+ calc_rest_mag (ssp_wave , rest_sed , NIR_filt .wavelength , NIR_filt .transmission )
454457 ]
455458 )
456459
457460 return _nuvk [0 ]- _nuvk [1 ]
458461
459462
460- v_nuvk_zo = vmap (calc_nuvk , in_axes = (None , None , 0 , None ))
461- v_nuvk = vmap (v_nuvk_zo , in_axes = (0 , None , None , None ))
463+ v_nuvk_zo = vmap (calc_nuvk , in_axes = (None , 0 , None ))
464+ v_nuvk = vmap (v_nuvk_zo , in_axes = (0 , None , None ))
462465
463466
464467@jit
465- def calc_nuvk_dusty (pars_arr , wls , z_obs , ssp_data ):
468+ def calc_nuvk_dusty (pars_arr , z_obs , ssp_data ):
466469 """calc_nuvk_dusty _summary_
467470
468471 :param pars_arr: _description_
469472 :type pars_arr: _type_
470- :param wls: _description_
471- :type wls: _type_
472473 :param z_obs: _description_
473474 :type z_obs: _type_
474475 :param ssp_data: _description_
475476 :type ssp_data: _type_
476477 :return: _description_
477478 :rtype: _type_
478479 """
479- sed = mean_spectrum (wls , pars_arr , z_obs , ssp_data )
480+ #sed = mean_spectrum(wls, pars_arr, z_obs, ssp_data)
481+
482+ # get the restframe spectra without and with dust attenuation
483+ ssp_wave , _ , sed_attenuated = ssp_spectrum_fromparam (pars_arr , z_obs , ssp_data )
480484 _nuvk = jnp .array (
481485 [
482- calc_rest_mag (wls , sed , NUV_filt .wavelength , NUV_filt .transmission ),
483- calc_rest_mag (wls , sed , NIR_filt .wavelength , NIR_filt .transmission )
486+ calc_rest_mag (ssp_wave , sed_attenuated , NUV_filt .wavelength , NUV_filt .transmission ),
487+ calc_rest_mag (ssp_wave , sed_attenuated , NIR_filt .wavelength , NIR_filt .transmission )
484488 ]
485489 )
486490
487491 return _nuvk [0 ]- _nuvk [1 ]
488492
489493
490- v_nuvk_zo_dusty = vmap (calc_nuvk_dusty , in_axes = (None , None , 0 , None ))
491- v_nuvk_dusty = vmap (v_nuvk_zo_dusty , in_axes = (0 , None , None , None ))
494+ v_nuvk_zo_dusty = vmap (calc_nuvk_dusty , in_axes = (None , 0 , None ))
495+ v_nuvk_dusty = vmap (v_nuvk_zo_dusty , in_axes = (0 , None , None ))
492496
493497
494498@jit
495- def calc_d4000n (pars_arr , wls , z_obs , ssp_data ):
499+ def calc_d4000n (pars_arr , z_obs , ssp_data ):
496500 """calc_d4000n _summary_
497501
498502 :param pars_arr: _description_
499503 :type pars_arr: _type_
500- :param wls: _description_
501- :type wls: _type_
502504 :param z_obs: _description_
503505 :type z_obs: _type_
504506 :param ssp_data: _description_
505507 :type ssp_data: _type_
506508 :return: _description_
507509 :rtype: _type_
508510 """
509- sed = mean_spectrum_nodust (wls , pars_arr , z_obs , ssp_data )
511+ #sed = mean_spectrum_nodust(wls, pars_arr, z_obs, ssp_data)
512+
513+ # get the restframe spectra without and with dust attenuation
514+ ssp_wave , rest_sed , _ = ssp_spectrum_fromparam (pars_arr , z_obs , ssp_data )
510515 d4000 = jnp .array (
511516 [
512- calc_rest_mag (wls , sed , D4000b_filt .wavelength , D4000b_filt .transmission ),
513- calc_rest_mag (wls , sed , D4000r_filt .wavelength , D4000r_filt .transmission )
517+ calc_rest_mag (ssp_wave , rest_sed , D4000b_filt .wavelength , D4000b_filt .transmission ),
518+ calc_rest_mag (ssp_wave , rest_sed , D4000r_filt .wavelength , D4000r_filt .transmission )
514519 ]
515520 )
516521
517522 return d4000 [0 ]- d4000 [1 ]
518523
519524
520- v_d4000n_zo = vmap (calc_d4000n , in_axes = (None , None , 0 , None ))
521- v_d4000n = vmap (v_d4000n_zo , in_axes = (0 , None , None , None ))
525+ v_d4000n_zo = vmap (calc_d4000n , in_axes = (None , 0 , None ))
526+ v_d4000n = vmap (v_d4000n_zo , in_axes = (0 , None , None ))
522527
523528
524529@jit
525- def calc_d4000n_dusty (pars_arr , wls , z_obs , ssp_data ):
530+ def calc_d4000n_dusty (pars_arr , z_obs , ssp_data ):
526531 """calc_d4000n_dusty _summary_
527532
528533 :param pars_arr: _description_
529534 :type pars_arr: _type_
530- :param wls: _description_
531- :type wls: _type_
532535 :param z_obs: _description_
533536 :type z_obs: _type_
534537 :param ssp_data: _description_
535538 :type ssp_data: _type_
536539 :return: _description_
537540 :rtype: _type_
538541 """
539- sed = mean_spectrum (wls , pars_arr , z_obs , ssp_data )
542+ #sed = mean_spectrum(wls, pars_arr, z_obs, ssp_data)
543+
544+ # get the restframe spectra without and with dust attenuation
545+ ssp_wave , _ , sed_attenuated = ssp_spectrum_fromparam (pars_arr , z_obs , ssp_data )
540546 d4000 = jnp .array (
541547 [
542- calc_rest_mag (wls , sed , D4000b_filt .wavelength , D4000b_filt .transmission ),
543- calc_rest_mag (wls , sed , D4000r_filt .wavelength , D4000r_filt .transmission )
548+ calc_rest_mag (ssp_wave , sed_attenuated , D4000b_filt .wavelength , D4000b_filt .transmission ),
549+ calc_rest_mag (ssp_wave , sed_attenuated , D4000r_filt .wavelength , D4000r_filt .transmission )
544550 ]
545551 )
546552
547553 return d4000 [0 ]- d4000 [1 ]
548554
549555
550- v_d4000n_zo_dusty = vmap (calc_d4000n_dusty , in_axes = (None , None , 0 , None ))
551- v_d4000n_dusty = vmap (v_d4000n_zo_dusty , in_axes = (0 , None , None , None ))
556+ v_d4000n_zo_dusty = vmap (calc_d4000n_dusty , in_axes = (None , 0 , None ))
557+ v_d4000n_dusty = vmap (v_d4000n_zo_dusty , in_axes = (0 , None , None ))
552558
553559
554560@jit
555- def d4000n (pars_arr , wls , z_obs , av , ssp_data ):
561+ def d4000n (pars_arr , z_obs , av , ssp_data ):
556562 _pars = pars_arr .at [13 ].set (av )
557- sed = mean_spectrum (wls , _pars , z_obs , ssp_data )
563+ #sed = mean_spectrum(wls, _pars, z_obs, ssp_data)
564+ # get the restframe spectra without and with dust attenuation
565+ ssp_wave , _ , sed_attenuated = ssp_spectrum_fromparam (_pars , z_obs , ssp_data )
558566 d4000 = jnp .array (
559567 [
560- calc_rest_mag (wls , sed , D4000b_filt .wavelength , D4000b_filt .transmission ),
561- calc_rest_mag (wls , sed , D4000r_filt .wavelength , D4000r_filt .transmission )
568+ calc_rest_mag (ssp_wave , sed_attenuated , D4000b_filt .wavelength , D4000b_filt .transmission ),
569+ calc_rest_mag (ssp_wave , sed_attenuated , D4000r_filt .wavelength , D4000r_filt .transmission )
562570 ]
563571 )
564572
565573 return d4000 [0 ]- d4000 [1 ]
566574
567- vmap_d4000n_av = vmap (d4000n , in_axes = (None , None , None , 0 , None ))
568- vmap_d4000n_zob = vmap (vmap_d4000n_av , in_axes = (None , None , 0 , None , None ))
569- vmap_d4000n_pars = vmap (vmap_d4000n_zob , in_axes = (0 , None , None , None , None ))
575+ vmap_d4000n_av = vmap (d4000n , in_axes = (None , None , 0 , None ))
576+ vmap_d4000n_zob = vmap (vmap_d4000n_av , in_axes = (None , 0 , None , None ))
577+ vmap_d4000n_pars = vmap (vmap_d4000n_zob , in_axes = (0 , None , None , None ))
570578
571- def treemap_d4000 (pars_arr , wls , z_obs , av , ssp_data ):
579+ def treemap_d4000 (pars_arr , z_obs , av , ssp_data ):
572580 templ_tupl = [tuple (_pars ) for _pars in pars_arr ]
573- reslist_of_tupl = tree_map (lambda partup : vmap_d4000n_zob (jnp .array (partup ), wls , z_obs , av , ssp_data ), templ_tupl , is_leaf = istuple )
581+ reslist_of_tupl = tree_map (lambda partup : vmap_d4000n_zob (jnp .array (partup ), z_obs , av , ssp_data ), templ_tupl , is_leaf = istuple )
574582 return reslist_of_tupl
575583
576- vmap_d4000n_pars_leg = vmap (vmap_d4000n_av , in_axes = (0 , None , 0 , None , None ))
584+ vmap_d4000n_pars_leg = vmap (vmap_d4000n_av , in_axes = (0 , 0 , None , None ))
577585
578- def treemap_d4000_leg (pars_arr , wls , zref , av , ssp_data ):
586+ def treemap_d4000_leg (pars_arr , zref , av , ssp_data ):
579587 templ_tupl = [tuple (_pars )+ tuple ([z ]) for _pars , z in zip (pars_arr , zref , strict = True )]
580- reslist_of_tupl = tree_map (lambda partup : vmap_d4000n_av (jnp .array (partup [:- 1 ]), wls , partup [- 1 ], av , ssp_data ), templ_tupl , is_leaf = istuple )
588+ reslist_of_tupl = tree_map (lambda partup : vmap_d4000n_av (jnp .array (partup [:- 1 ]), partup [- 1 ], av , ssp_data ), templ_tupl , is_leaf = istuple )
581589 return reslist_of_tupl
582590
583591def get_colors_templates (params , wls , z_obs , transm_arr , ssp_data ):
@@ -1106,8 +1114,8 @@ def bpt_rews_pars_dusty_leg(templ_pars, zref, ssp_data):
11061114def colrs_bptrews_templ_zo (templ_pars , wls , zobs , transm_arr , ssp_data ):
11071115 t_rews = bpt_rews_pars_zo (templ_pars , zobs , ssp_data )
11081116 t_colors = vmap_cols_zo_nodust (templ_pars , wls , zobs , transm_arr , ssp_data )
1109- t_nuvk = v_nuvk_zo (templ_pars , wls , zobs , ssp_data )
1110- t_d4000n = v_d4000n_zo (templ_pars , wls , zobs , ssp_data )
1117+ t_nuvk = v_nuvk_zo (templ_pars , zobs , ssp_data )
1118+ t_d4000n = v_d4000n_zo (templ_pars , zobs , ssp_data )
11111119 return jnp .column_stack ((t_colors , t_rews , t_nuvk , t_d4000n ))
11121120
11131121vmap_colrs_bptrews_templ_zo = vmap (colrs_bptrews_templ_zo , in_axes = (0 , None , None , None , None ))
@@ -1116,8 +1124,8 @@ def colrs_bptrews_templ_zo(templ_pars, wls, zobs, transm_arr, ssp_data):
11161124def colrs_bptrews_templ_zo_dusty (templ_pars , wls , zobs , transm_arr , ssp_data ):
11171125 t_rews = bpt_rews_pars_zo_dusty (templ_pars , zobs , ssp_data )
11181126 t_colors = vmap_cols_zo (templ_pars , wls , zobs , transm_arr , ssp_data )
1119- t_nuvk = v_nuvk_zo_dusty (templ_pars , wls , zobs , ssp_data ) #v_nuvk_zo(templ_pars, wls, zobs, ssp_data) -- NUV-K for prior shall perhaps include dust attenuation
1120- t_d4000n = v_d4000n_zo_dusty (templ_pars , wls , zobs , ssp_data )
1127+ t_nuvk = v_nuvk_zo_dusty (templ_pars , zobs , ssp_data ) #v_nuvk_zo(templ_pars, wls, zobs, ssp_data) -- NUV-K for prior shall perhaps include dust attenuation
1128+ t_d4000n = v_d4000n_zo_dusty (templ_pars , zobs , ssp_data )
11211129 return jnp .column_stack ((t_colors , t_rews , t_nuvk , t_d4000n ))
11221130
11231131vmap_colrs_bptrews_templ_zo_dusty = vmap (colrs_bptrews_templ_zo_dusty , in_axes = (0 , None , None , None , None ))
@@ -1127,8 +1135,8 @@ def colrs_bptrews_templ_zo_dusty(templ_pars, wls, zobs, transm_arr, ssp_data):
11271135def colrs_bptrews_templ_zo_leg (templ_pars , wls , zobs , zref , transm_arr , ssp_data ):
11281136 t_rews = bpt_rews_pars_leg (templ_pars , zref , ssp_data )
11291137 t_colors = vmap_cols_zo_nodust_leg (templ_pars , wls , zobs , zref , transm_arr , ssp_data )
1130- t_nuvk = calc_nuvk (templ_pars , wls , zref , ssp_data )
1131- t_d4000n = calc_d4000n (templ_pars , wls , zref , ssp_data )
1138+ t_nuvk = calc_nuvk (templ_pars , zref , ssp_data )
1139+ t_d4000n = calc_d4000n (templ_pars , zref , ssp_data )
11321140 return jnp .column_stack (
11331141 (
11341142 t_colors ,
@@ -1144,8 +1152,8 @@ def colrs_bptrews_templ_zo_leg(templ_pars, wls, zobs, zref, transm_arr, ssp_data
11441152def colrs_bptrews_templ_zo_dusty_leg (templ_pars , wls , zobs , zref , transm_arr , ssp_data ):
11451153 t_rews = bpt_rews_pars_dusty_leg (templ_pars , zref , ssp_data )
11461154 t_colors = vmap_cols_zo_leg (templ_pars , wls , zobs , zref , transm_arr , ssp_data )
1147- t_nuvk = calc_nuvk_dusty (templ_pars , wls , zref , ssp_data ) # calc_nuvk(templ_pars, wls, zref, ssp_data) # -- NUV-K for prior shall perhaps include dust attenuation
1148- t_d4000n = calc_d4000n_dusty (templ_pars , wls , zref , ssp_data )
1155+ t_nuvk = calc_nuvk_dusty (templ_pars , zref , ssp_data ) # calc_nuvk(templ_pars, wls, zref, ssp_data) # -- NUV-K for prior shall perhaps include dust attenuation
1156+ t_d4000n = calc_d4000n_dusty (templ_pars , zref , ssp_data )
11491157 return jnp .column_stack (
11501158 (
11511159 t_colors ,
0 commit comments