Skip to content

Commit

Permalink
new scipy interpolation methods
Browse files Browse the repository at this point in the history
  • Loading branch information
djones1040 committed Jul 29, 2024
1 parent 055259b commit f0bccb4
Showing 1 changed file with 35 additions and 34 deletions.
69 changes: 35 additions & 34 deletions saltshaker/validation/figs/plotSALTModel.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import numpy as np
import pylab as plt
import sys
from scipy.interpolate import interp1d, interp2d
from scipy.interpolate import interp1d, interp2d, RectBivariateSpline
from sncosmo.salt2utils import SALT2ColorLaw
from saltshaker.training import colorlaw
from saltshaker.initfiles import init_rootdir
Expand Down Expand Up @@ -122,15 +122,15 @@ def mkModelPlot(
salt3m1flux*=sgn
spacing = 0.5
for plotphase,i,plotphasestr in zip([-5,0,10],range(3),['-5','+0','+10']):
int_salt2m0 = interp2d(salt2m0wave,salt2m0phase,salt2m0flux)
int_salt2m0err = interp2d(salt2m0errwave,salt2m0errphase,salt2m0fluxerr)
salt2m0flux_0 = int_salt2m0(salt2m0wave,plotphase)
salt2m0fluxerr_0 = int_salt2m0err(salt2m0wave,plotphase)
int_salt2m0 = RectBivariateSpline(salt2m0phase,salt2m0wave,salt2m0flux)
int_salt2m0err = RectBivariateSpline(salt2m0errphase,salt2m0errwave,salt2m0fluxerr)
salt2m0flux_0 = int_salt2m0(plotphase,salt2m0wave)[0]
salt2m0fluxerr_0 = int_salt2m0err(plotphase,salt2m0wave)[0]

int_salt3m0 = interp2d(salt3m0wave,salt3m0phase,salt3m0flux)
int_salt3m0err = interp2d(salt3m0errwave,salt3m0errphase,salt3m0fluxerr)
salt3m0flux_0 = int_salt3m0(salt3m0wave,plotphase)
salt3m0fluxerr_0 = int_salt3m0err(salt3m0wave,plotphase)
int_salt3m0 = RectBivariateSpline(salt3m0phase,salt3m0wave,salt3m0flux)
int_salt3m0err = RectBivariateSpline(salt3m0errphase,salt3m0errwave,salt3m0fluxerr)
salt3m0flux_0 = int_salt3m0(plotphase,salt3m0wave)[0]
salt3m0fluxerr_0 = int_salt3m0err(plotphase,salt3m0wave)[0]

ax1.plot(salt2m0wave,salt2m0flux_0+spacing*i,color='b',label='SALT2')
ax1.fill_between(salt2m0wave,
Expand All @@ -149,15 +149,15 @@ def mkModelPlot(

spacing = 0.15
for plotphase,i,plotphasestr in zip([-5,0,10],range(3),['-5','+0','+10']):
int_salt2m1 = interp2d(salt2m1wave,salt2m1phase,salt2m1flux)
int_salt2m1err = interp2d(salt2m1errwave,salt2m1errphase,salt2m1fluxerr)
salt2m1flux_0 = int_salt2m1(salt2m1wave,plotphase)
salt2m1fluxerr_0 = int_salt2m1err(salt2m1wave,plotphase)

int_salt3m1 = interp2d(salt3m1wave,salt3m1phase,salt3m1flux)
int_salt3m1err = interp2d(salt3m1errwave,salt3m1errphase,salt3m1fluxerr)
salt3m1flux_0 = int_salt3m1(salt3m1wave,plotphase)
salt3m1fluxerr_0 = int_salt3m1err(salt3m1wave,plotphase)
int_salt2m1 = RectBivariateSpline(salt2m1phase,salt2m1wave,salt2m1flux)
int_salt2m1err = RectBivariateSpline(salt2m1errphase,salt2m1errwave,salt2m1fluxerr)
salt2m1flux_0 = int_salt2m1(plotphase,salt2m1wave)[0]
salt2m1fluxerr_0 = int_salt2m1err(plotphase,salt2m1wave)[0]

int_salt3m1 = RectBivariateSpline(salt3m1phase,salt3m1wave,salt3m1flux)
int_salt3m1err = RectBivariateSpline(salt3m1errphase,salt3m1errwave,salt3m1fluxerr)
salt3m1flux_0 = int_salt3m1(plotphase,salt3m1wave)[0]
salt3m1fluxerr_0 = int_salt3m1err(plotphase,salt3m1wave)[0]
ax2.plot(salt2m1wave,salt2m1flux_0+spacing*i,color='b',label='SALT2')
m1scale = np.mean(np.abs(salt2m1flux_0[(salt2m1wave > 4000) & (salt2m1wave < 7000)]))/np.mean(np.abs(salt3m1flux_0[(salt3m1wave > 4000) & (salt3m1wave < 7000)]))
ax2.plot(salt3m1wave,salt3m1flux_0+spacing*i,color='r',label='SALT3')
Expand Down Expand Up @@ -207,10 +207,10 @@ def mkModelPlot(
spacing = 0.15
for plotphase,i,plotphasestr in zip([-5,0,10],range(3),['-5','+0','+10']):

int_salt3mhost = interp2d(salt3mhostwave,salt3mhostphase,salt3mhostflux)
int_salt3mhosterr = interp2d(salt3mhosterrwave,salt3mhosterrphase,salt3mhostfluxerr)
salt3mhostflux_0 = int_salt3mhost(salt3mhostwave,plotphase)
salt3mhostfluxerr_0 = int_salt3mhosterr(salt3mhostwave,plotphase)
int_salt3mhost = RectBivariateSpline(salt3mhostphase,salt3mhostwave,salt3mhostflux)
int_salt3mhosterr = RectBivariateSpline(salt3mhosterrphase,salt3mhosterrwave,salt3mhostfluxerr)
salt3mhostflux_0 = int_salt3mhost(plotphase,salt3mhostwave)
salt3mhostfluxerr_0 = int_salt3mhosterr(plotphase,salt3mhostwave)

ax4.plot(salt3mhostwave,salt3mhostflux_0+spacing*i,color='r',label='SALT3')
if plotErr:
Expand Down Expand Up @@ -311,11 +311,12 @@ def mkModelErrPlot(salt3dir='modelfiles/salt3',outfile=None,xlimits=[2000,9200])
plotwave=np.linspace(2000,np.max(salt3m0errwave),720)
scale=2.5
for plotphase,i,plotphasestr in zip([-5,0,10],range(3),['-5','+0','+10']):
int_salt2m0err = interp2d(salt2m0errwave,salt2m0errphase,salt2m0fluxerr)
salt2m0fluxerr_0 = np.sqrt(int_salt2m0err(plotwave,plotphase))
#import pdb; pdb.set_trace()
int_salt2m0err = RectBivariateSpline(salt2m0errphase,salt2m0errwave,salt2m0fluxerr)
salt2m0fluxerr_0 = np.sqrt(int_salt2m0err(plotphase,plotwave)[0])

int_salt3m0err = interp2d(salt3m0errwave,salt3m0errphase,salt3m0fluxerr)
salt3m0fluxerr_0 = np.sqrt(int_salt3m0err(plotwave,plotphase))
int_salt3m0err = RectBivariateSpline(salt3m0errphase,salt3m0errwave,salt3m0fluxerr)
salt3m0fluxerr_0 = np.sqrt(int_salt3m0err(plotphase,plotwave)[0])

ax1.plot(plotwave,salt2m0fluxerr_0*scale+spacing*i,color='b',label='SALT2')
ax1.plot(plotwave,salt3m0fluxerr_0*scale+spacing*i,color='r',label='SALT3')
Expand All @@ -327,11 +328,11 @@ def mkModelErrPlot(salt3dir='modelfiles/salt3',outfile=None,xlimits=[2000,9200])
scale=3
spacing = 0.15
for plotphase,i,plotphasestr in zip([-5,0,10],range(3),['-5','+0','+10']):
int_salt2m1err = interp2d(salt2m1errwave,salt2m1errphase,salt2m1fluxerr)
salt2m1fluxerr_0 = np.sqrt(int_salt2m1err(plotwave,plotphase))
int_salt2m1err = RectBivariateSpline(salt2m1errphase,salt2m1errwave,salt2m1fluxerr)
salt2m1fluxerr_0 = np.sqrt(int_salt2m1err(plotphase,plotwave)[0])

int_salt3m1err = interp2d(salt3m1errwave,salt3m1errphase,salt3m1fluxerr)
salt3m1fluxerr_0 = np.sqrt(int_salt3m1err(plotwave,plotphase))
int_salt3m1err = RectBivariateSpline(salt3m1errphase,salt3m1errwave,salt3m1fluxerr)
salt3m1fluxerr_0 = np.sqrt(int_salt3m1err(plotphase,plotwave)[0])
ax2.plot(plotwave,salt2m1fluxerr_0*scale+spacing*i,color='b',label='SALT2')
ax2.plot(plotwave,salt3m1fluxerr_0*scale+spacing*i,color='r',label='SALT3')
ax2.plot(xlimits,[spacing*i,spacing*i],'k--')
Expand All @@ -341,11 +342,11 @@ def mkModelErrPlot(salt3dir='modelfiles/salt3',outfile=None,xlimits=[2000,9200])
ax2.text(xlimits[1]-100,spacing*(i+0.2),'%s'%plotphasestr,ha='right')
scale=.2
for plotphase,i,plotphasestr in zip([-5,0,10],range(3),['-5','+0','+10']):
int_salt2m0m1corr = interp2d(salt2m0m1errwave,salt2m0m1errphase,salt2m0m1corr)
salt2m0m1fluxerr_0 = int_salt2m0m1corr(plotwave,plotphase)
int_salt2m0m1corr = RectBivariateSpline(salt2m0m1errphase,salt2m0m1errwave,salt2m0m1corr)
salt2m0m1fluxerr_0 = int_salt2m0m1corr(plotphase,plotwave)[0]

int_salt3m0m1corr = interp2d(salt3m0m1errwave,salt3m0m1errphase,salt3m0m1corr)
salt3m0m1fluxerr_0 = int_salt3m0m1corr(plotwave,plotphase)
int_salt3m0m1corr = RectBivariateSpline(salt3m0m1errphase,salt3m0m1errwave,salt3m0m1corr)
salt3m0m1fluxerr_0 = int_salt3m0m1corr(plotphase,plotwave)[0]
ax3.plot(plotwave,salt2m0m1fluxerr_0*scale+spacing*i,color='b',label='SALT2')
ax3.plot(plotwave,salt3m0m1fluxerr_0*scale+spacing*i,color='r',label='SALT3')
ax3.plot(xlimits,[spacing*i,spacing*i],'k--')
Expand Down

0 comments on commit f0bccb4

Please sign in to comment.