diff --git a/saltshaker/validation/figs/plotSALTModel.py b/saltshaker/validation/figs/plotSALTModel.py index 4fbbf297..0f365da6 100755 --- a/saltshaker/validation/figs/plotSALTModel.py +++ b/saltshaker/validation/figs/plotSALTModel.py @@ -4,7 +4,7 @@ import numpy as np import pylab as plt import sys -from scipy.interpolate import interp1d, interp2d +from scipy.interpolate import interp1d, interp2d, RectBivariateSpline from sncosmo.salt2utils import SALT2ColorLaw from saltshaker.training import colorlaw from saltshaker.initfiles import init_rootdir @@ -122,15 +122,15 @@ def mkModelPlot( salt3m1flux*=sgn spacing = 0.5 for plotphase,i,plotphasestr in zip([-5,0,10],range(3),['-5','+0','+10']): - int_salt2m0 = interp2d(salt2m0wave,salt2m0phase,salt2m0flux) - int_salt2m0err = interp2d(salt2m0errwave,salt2m0errphase,salt2m0fluxerr) - salt2m0flux_0 = int_salt2m0(salt2m0wave,plotphase) - salt2m0fluxerr_0 = int_salt2m0err(salt2m0wave,plotphase) + int_salt2m0 = RectBivariateSpline(salt2m0phase,salt2m0wave,salt2m0flux) + int_salt2m0err = RectBivariateSpline(salt2m0errphase,salt2m0errwave,salt2m0fluxerr) + salt2m0flux_0 = int_salt2m0(plotphase,salt2m0wave)[0] + salt2m0fluxerr_0 = int_salt2m0err(plotphase,salt2m0wave)[0] - int_salt3m0 = interp2d(salt3m0wave,salt3m0phase,salt3m0flux) - int_salt3m0err = interp2d(salt3m0errwave,salt3m0errphase,salt3m0fluxerr) - salt3m0flux_0 = int_salt3m0(salt3m0wave,plotphase) - salt3m0fluxerr_0 = int_salt3m0err(salt3m0wave,plotphase) + int_salt3m0 = RectBivariateSpline(salt3m0phase,salt3m0wave,salt3m0flux) + int_salt3m0err = RectBivariateSpline(salt3m0errphase,salt3m0errwave,salt3m0fluxerr) + salt3m0flux_0 = int_salt3m0(plotphase,salt3m0wave)[0] + salt3m0fluxerr_0 = int_salt3m0err(plotphase,salt3m0wave)[0] ax1.plot(salt2m0wave,salt2m0flux_0+spacing*i,color='b',label='SALT2') ax1.fill_between(salt2m0wave, @@ -149,15 +149,15 @@ def mkModelPlot( spacing = 0.15 for plotphase,i,plotphasestr in zip([-5,0,10],range(3),['-5','+0','+10']): - int_salt2m1 = interp2d(salt2m1wave,salt2m1phase,salt2m1flux) - int_salt2m1err = interp2d(salt2m1errwave,salt2m1errphase,salt2m1fluxerr) - salt2m1flux_0 = int_salt2m1(salt2m1wave,plotphase) - salt2m1fluxerr_0 = int_salt2m1err(salt2m1wave,plotphase) - - int_salt3m1 = interp2d(salt3m1wave,salt3m1phase,salt3m1flux) - int_salt3m1err = interp2d(salt3m1errwave,salt3m1errphase,salt3m1fluxerr) - salt3m1flux_0 = int_salt3m1(salt3m1wave,plotphase) - salt3m1fluxerr_0 = int_salt3m1err(salt3m1wave,plotphase) + int_salt2m1 = RectBivariateSpline(salt2m1phase,salt2m1wave,salt2m1flux) + int_salt2m1err = RectBivariateSpline(salt2m1errphase,salt2m1errwave,salt2m1fluxerr) + salt2m1flux_0 = int_salt2m1(plotphase,salt2m1wave)[0] + salt2m1fluxerr_0 = int_salt2m1err(plotphase,salt2m1wave)[0] + + int_salt3m1 = RectBivariateSpline(salt3m1phase,salt3m1wave,salt3m1flux) + int_salt3m1err = RectBivariateSpline(salt3m1errphase,salt3m1errwave,salt3m1fluxerr) + salt3m1flux_0 = int_salt3m1(plotphase,salt3m1wave)[0] + salt3m1fluxerr_0 = int_salt3m1err(plotphase,salt3m1wave)[0] ax2.plot(salt2m1wave,salt2m1flux_0+spacing*i,color='b',label='SALT2') m1scale = np.mean(np.abs(salt2m1flux_0[(salt2m1wave > 4000) & (salt2m1wave < 7000)]))/np.mean(np.abs(salt3m1flux_0[(salt3m1wave > 4000) & (salt3m1wave < 7000)])) ax2.plot(salt3m1wave,salt3m1flux_0+spacing*i,color='r',label='SALT3') @@ -207,10 +207,10 @@ def mkModelPlot( spacing = 0.15 for plotphase,i,plotphasestr in zip([-5,0,10],range(3),['-5','+0','+10']): - int_salt3mhost = interp2d(salt3mhostwave,salt3mhostphase,salt3mhostflux) - int_salt3mhosterr = interp2d(salt3mhosterrwave,salt3mhosterrphase,salt3mhostfluxerr) - salt3mhostflux_0 = int_salt3mhost(salt3mhostwave,plotphase) - salt3mhostfluxerr_0 = int_salt3mhosterr(salt3mhostwave,plotphase) + int_salt3mhost = RectBivariateSpline(salt3mhostphase,salt3mhostwave,salt3mhostflux) + int_salt3mhosterr = RectBivariateSpline(salt3mhosterrphase,salt3mhosterrwave,salt3mhostfluxerr) + salt3mhostflux_0 = int_salt3mhost(plotphase,salt3mhostwave) + salt3mhostfluxerr_0 = int_salt3mhosterr(plotphase,salt3mhostwave) ax4.plot(salt3mhostwave,salt3mhostflux_0+spacing*i,color='r',label='SALT3') if plotErr: @@ -311,11 +311,12 @@ def mkModelErrPlot(salt3dir='modelfiles/salt3',outfile=None,xlimits=[2000,9200]) plotwave=np.linspace(2000,np.max(salt3m0errwave),720) scale=2.5 for plotphase,i,plotphasestr in zip([-5,0,10],range(3),['-5','+0','+10']): - int_salt2m0err = interp2d(salt2m0errwave,salt2m0errphase,salt2m0fluxerr) - salt2m0fluxerr_0 = np.sqrt(int_salt2m0err(plotwave,plotphase)) + #import pdb; pdb.set_trace() + int_salt2m0err = RectBivariateSpline(salt2m0errphase,salt2m0errwave,salt2m0fluxerr) + salt2m0fluxerr_0 = np.sqrt(int_salt2m0err(plotphase,plotwave)[0]) - int_salt3m0err = interp2d(salt3m0errwave,salt3m0errphase,salt3m0fluxerr) - salt3m0fluxerr_0 = np.sqrt(int_salt3m0err(plotwave,plotphase)) + int_salt3m0err = RectBivariateSpline(salt3m0errphase,salt3m0errwave,salt3m0fluxerr) + salt3m0fluxerr_0 = np.sqrt(int_salt3m0err(plotphase,plotwave)[0]) ax1.plot(plotwave,salt2m0fluxerr_0*scale+spacing*i,color='b',label='SALT2') ax1.plot(plotwave,salt3m0fluxerr_0*scale+spacing*i,color='r',label='SALT3') @@ -327,11 +328,11 @@ def mkModelErrPlot(salt3dir='modelfiles/salt3',outfile=None,xlimits=[2000,9200]) scale=3 spacing = 0.15 for plotphase,i,plotphasestr in zip([-5,0,10],range(3),['-5','+0','+10']): - int_salt2m1err = interp2d(salt2m1errwave,salt2m1errphase,salt2m1fluxerr) - salt2m1fluxerr_0 = np.sqrt(int_salt2m1err(plotwave,plotphase)) + int_salt2m1err = RectBivariateSpline(salt2m1errphase,salt2m1errwave,salt2m1fluxerr) + salt2m1fluxerr_0 = np.sqrt(int_salt2m1err(plotphase,plotwave)[0]) - int_salt3m1err = interp2d(salt3m1errwave,salt3m1errphase,salt3m1fluxerr) - salt3m1fluxerr_0 = np.sqrt(int_salt3m1err(plotwave,plotphase)) + int_salt3m1err = RectBivariateSpline(salt3m1errphase,salt3m1errwave,salt3m1fluxerr) + salt3m1fluxerr_0 = np.sqrt(int_salt3m1err(plotphase,plotwave)[0]) ax2.plot(plotwave,salt2m1fluxerr_0*scale+spacing*i,color='b',label='SALT2') ax2.plot(plotwave,salt3m1fluxerr_0*scale+spacing*i,color='r',label='SALT3') ax2.plot(xlimits,[spacing*i,spacing*i],'k--') @@ -341,11 +342,11 @@ def mkModelErrPlot(salt3dir='modelfiles/salt3',outfile=None,xlimits=[2000,9200]) ax2.text(xlimits[1]-100,spacing*(i+0.2),'%s'%plotphasestr,ha='right') scale=.2 for plotphase,i,plotphasestr in zip([-5,0,10],range(3),['-5','+0','+10']): - int_salt2m0m1corr = interp2d(salt2m0m1errwave,salt2m0m1errphase,salt2m0m1corr) - salt2m0m1fluxerr_0 = int_salt2m0m1corr(plotwave,plotphase) + int_salt2m0m1corr = RectBivariateSpline(salt2m0m1errphase,salt2m0m1errwave,salt2m0m1corr) + salt2m0m1fluxerr_0 = int_salt2m0m1corr(plotphase,plotwave)[0] - int_salt3m0m1corr = interp2d(salt3m0m1errwave,salt3m0m1errphase,salt3m0m1corr) - salt3m0m1fluxerr_0 = int_salt3m0m1corr(plotwave,plotphase) + int_salt3m0m1corr = RectBivariateSpline(salt3m0m1errphase,salt3m0m1errwave,salt3m0m1corr) + salt3m0m1fluxerr_0 = int_salt3m0m1corr(plotphase,plotwave)[0] ax3.plot(plotwave,salt2m0m1fluxerr_0*scale+spacing*i,color='b',label='SALT2') ax3.plot(plotwave,salt3m0m1fluxerr_0*scale+spacing*i,color='r',label='SALT3') ax3.plot(xlimits,[spacing*i,spacing*i],'k--')