From 34268f014a314f6660154b3169c18438e9eba925 Mon Sep 17 00:00:00 2001 From: Kasper Marstal Date: Mon, 29 Mar 2021 12:47:46 +0200 Subject: [PATCH 1/2] ENH: Add Active Registration Model metrics Kasper Marstal and Stefan Klein "Active registration models", Proc. SPIE 10133, Medical Imaging 2017: Image Processing, 101330Y (24 February 2017); https://doi.org/10.1117/12.2254356) --- .../ActiveRegistrationModel/CMakeLists.txt | 29 + .../Statismo/CMakeLists.txt | 2 + .../Statismo/ITK/CMakeLists.txt | 7 + .../ITK/include/itkConditionalModelBuilder.h | 116 +++ .../Statismo/ITK/include/itkDataManager.h | 146 ++++ .../include/itkDataManagerWithSurrogates.h | 132 +++ ...tingStatisticalDeformationModelTransform.h | 187 +++++ .../ITK/include/itkLowRankGPModelBuilder.h | 150 ++++ .../Statismo/ITK/include/itkPCAModelBuilder.h | 114 +++ .../ITK/include/itkPixelConversionTraits.h | 373 +++++++++ .../ITK/include/itkPosteriorModelBuilder.h | 148 ++++ .../include/itkReducedVarianceModelBuilder.h | 128 +++ .../ITK/include/itkStandardImageRepresenter.h | 166 ++++ .../include/itkStandardImageRepresenter.hxx | 432 ++++++++++ .../itkStandardImageRepresenterTraits.h | 262 ++++++ .../ITK/include/itkStandardMeshRepresenter.h | 250 ++++++ .../include/itkStandardMeshRepresenter.hxx | 421 ++++++++++ .../Statismo/ITK/include/itkStatismoIO.h | 100 +++ .../itkStatisticalDeformationModelTransform.h | 128 +++ .../ITK/include/itkStatisticalModel.h | 287 +++++++ .../itkStatisticalModelTransformBase.h | 229 ++++++ .../itkStatisticalModelTransformBase.hxx | 186 +++++ .../itkStatisticalShapeModelTransform.h | 123 +++ .../Statismo/ITK/include/statismoITKConfig.h | 49 ++ .../Statismo/core/CMakeLists.txt | 1 + .../Statismo/core/include/CommonTypes.h | 146 ++++ .../core/include/ConditionalModelBuilder.h | 138 ++++ .../core/include/ConditionalModelBuilder.hxx | 267 ++++++ .../Statismo/core/include/Config.h | 53 ++ .../Statismo/core/include/DataItem.h | 223 +++++ .../Statismo/core/include/DataItem.hxx | 76 ++ .../Statismo/core/include/DataManager.h | 218 +++++ .../Statismo/core/include/DataManager.hxx | 303 +++++++ .../core/include/DataManagerWithSurrogates.h | 143 ++++ .../include/DataManagerWithSurrogates.hxx | 106 +++ .../Statismo/core/include/Domain.h | 83 ++ .../Statismo/core/include/Exceptions.h | 82 ++ .../Statismo/core/include/HDF5Utils.h | 273 +++++++ .../Statismo/core/include/HDF5Utils.hxx | 556 +++++++++++++ .../Statismo/core/include/KernelCombinators.h | 310 +++++++ .../Statismo/core/include/Kernels.h | 131 +++ .../core/include/LowRankGPModelBuilder.h | 289 +++++++ .../Statismo/core/include/ModelBuilder.h | 111 +++ .../Statismo/core/include/ModelInfo.h | 192 +++++ .../Statismo/core/include/Nystrom.h | 179 ++++ .../Statismo/core/include/PCAModelBuilder.h | 127 +++ .../Statismo/core/include/PCAModelBuilder.hxx | 251 ++++++ .../core/include/PosteriorModelBuilder.h | 209 +++++ .../core/include/PosteriorModelBuilder.hxx | 269 ++++++ .../Statismo/core/include/RandSVD.h | 109 +++ .../include/ReducedVarianceModelBuilder.h | 132 +++ .../include/ReducedVarianceModelBuilder.hxx | 128 +++ .../Statismo/core/include/Representer.h | 318 ++++++++ .../Statismo/core/include/StatismoIO.h | 247 ++++++ .../Statismo/core/include/StatismoUtils.h | 144 ++++ .../Statismo/core/include/StatisticalModel.h | 569 +++++++++++++ .../core/include/StatisticalModel.hxx | 519 ++++++++++++ .../include/TrivialVectorialRepresenter.h | 202 +++++ .../core/include/genericRepresenterTest.hxx | 381 +++++++++ .../Statismo/core/src/CMakeLists.txt | 14 + .../Statismo/core/src/ModelInfo.cxx | 307 +++++++ ...ActiveRegistrationModelIntensityMetric.cxx | 21 + ...lxActiveRegistrationModelIntensityMetric.h | 242 ++++++ ...ActiveRegistrationModelIntensityMetric.hxx | 712 ++++++++++++++++ .../elxActiveRegistrationModelShapeMetric.cxx | 17 + .../elxActiveRegistrationModelShapeMetric.h | 226 +++++ .../elxActiveRegistrationModelShapeMetric.hxx | 769 ++++++++++++++++++ ...tkActiveRegistrationModelIntensityMetric.h | 235 ++++++ ...ActiveRegistrationModelIntensityMetric.hxx | 382 +++++++++ .../itkActiveRegistrationModelShapeMetric.h | 207 +++++ .../itkActiveRegistrationModelShapeMetric.hxx | 329 ++++++++ 71 files changed, 15111 insertions(+) create mode 100644 Components/Metrics/ActiveRegistrationModel/CMakeLists.txt create mode 100644 Components/Metrics/ActiveRegistrationModel/Statismo/CMakeLists.txt create mode 100644 Components/Metrics/ActiveRegistrationModel/Statismo/ITK/CMakeLists.txt create mode 100644 Components/Metrics/ActiveRegistrationModel/Statismo/ITK/include/itkConditionalModelBuilder.h create mode 100644 Components/Metrics/ActiveRegistrationModel/Statismo/ITK/include/itkDataManager.h create mode 100644 Components/Metrics/ActiveRegistrationModel/Statismo/ITK/include/itkDataManagerWithSurrogates.h create mode 100644 Components/Metrics/ActiveRegistrationModel/Statismo/ITK/include/itkInterpolatingStatisticalDeformationModelTransform.h create mode 100644 Components/Metrics/ActiveRegistrationModel/Statismo/ITK/include/itkLowRankGPModelBuilder.h create mode 100644 Components/Metrics/ActiveRegistrationModel/Statismo/ITK/include/itkPCAModelBuilder.h create mode 100644 Components/Metrics/ActiveRegistrationModel/Statismo/ITK/include/itkPixelConversionTraits.h create mode 100644 Components/Metrics/ActiveRegistrationModel/Statismo/ITK/include/itkPosteriorModelBuilder.h create mode 100644 Components/Metrics/ActiveRegistrationModel/Statismo/ITK/include/itkReducedVarianceModelBuilder.h create mode 100644 Components/Metrics/ActiveRegistrationModel/Statismo/ITK/include/itkStandardImageRepresenter.h create mode 100644 Components/Metrics/ActiveRegistrationModel/Statismo/ITK/include/itkStandardImageRepresenter.hxx create mode 100644 Components/Metrics/ActiveRegistrationModel/Statismo/ITK/include/itkStandardImageRepresenterTraits.h create mode 100644 Components/Metrics/ActiveRegistrationModel/Statismo/ITK/include/itkStandardMeshRepresenter.h create mode 100644 Components/Metrics/ActiveRegistrationModel/Statismo/ITK/include/itkStandardMeshRepresenter.hxx create mode 100644 Components/Metrics/ActiveRegistrationModel/Statismo/ITK/include/itkStatismoIO.h create mode 100644 Components/Metrics/ActiveRegistrationModel/Statismo/ITK/include/itkStatisticalDeformationModelTransform.h create mode 100644 Components/Metrics/ActiveRegistrationModel/Statismo/ITK/include/itkStatisticalModel.h create mode 100644 Components/Metrics/ActiveRegistrationModel/Statismo/ITK/include/itkStatisticalModelTransformBase.h create mode 100644 Components/Metrics/ActiveRegistrationModel/Statismo/ITK/include/itkStatisticalModelTransformBase.hxx create mode 100644 Components/Metrics/ActiveRegistrationModel/Statismo/ITK/include/itkStatisticalShapeModelTransform.h create mode 100644 Components/Metrics/ActiveRegistrationModel/Statismo/ITK/include/statismoITKConfig.h create mode 100644 Components/Metrics/ActiveRegistrationModel/Statismo/core/CMakeLists.txt create mode 100644 Components/Metrics/ActiveRegistrationModel/Statismo/core/include/CommonTypes.h create mode 100644 Components/Metrics/ActiveRegistrationModel/Statismo/core/include/ConditionalModelBuilder.h create mode 100644 Components/Metrics/ActiveRegistrationModel/Statismo/core/include/ConditionalModelBuilder.hxx create mode 100644 Components/Metrics/ActiveRegistrationModel/Statismo/core/include/Config.h create mode 100644 Components/Metrics/ActiveRegistrationModel/Statismo/core/include/DataItem.h create mode 100644 Components/Metrics/ActiveRegistrationModel/Statismo/core/include/DataItem.hxx create mode 100644 Components/Metrics/ActiveRegistrationModel/Statismo/core/include/DataManager.h create mode 100644 Components/Metrics/ActiveRegistrationModel/Statismo/core/include/DataManager.hxx create mode 100644 Components/Metrics/ActiveRegistrationModel/Statismo/core/include/DataManagerWithSurrogates.h create mode 100644 Components/Metrics/ActiveRegistrationModel/Statismo/core/include/DataManagerWithSurrogates.hxx create mode 100644 Components/Metrics/ActiveRegistrationModel/Statismo/core/include/Domain.h create mode 100644 Components/Metrics/ActiveRegistrationModel/Statismo/core/include/Exceptions.h create mode 100644 Components/Metrics/ActiveRegistrationModel/Statismo/core/include/HDF5Utils.h create mode 100644 Components/Metrics/ActiveRegistrationModel/Statismo/core/include/HDF5Utils.hxx create mode 100644 Components/Metrics/ActiveRegistrationModel/Statismo/core/include/KernelCombinators.h create mode 100644 Components/Metrics/ActiveRegistrationModel/Statismo/core/include/Kernels.h create mode 100644 Components/Metrics/ActiveRegistrationModel/Statismo/core/include/LowRankGPModelBuilder.h create mode 100644 Components/Metrics/ActiveRegistrationModel/Statismo/core/include/ModelBuilder.h create mode 100644 Components/Metrics/ActiveRegistrationModel/Statismo/core/include/ModelInfo.h create mode 100644 Components/Metrics/ActiveRegistrationModel/Statismo/core/include/Nystrom.h create mode 100644 Components/Metrics/ActiveRegistrationModel/Statismo/core/include/PCAModelBuilder.h create mode 100644 Components/Metrics/ActiveRegistrationModel/Statismo/core/include/PCAModelBuilder.hxx create mode 100644 Components/Metrics/ActiveRegistrationModel/Statismo/core/include/PosteriorModelBuilder.h create mode 100644 Components/Metrics/ActiveRegistrationModel/Statismo/core/include/PosteriorModelBuilder.hxx create mode 100644 Components/Metrics/ActiveRegistrationModel/Statismo/core/include/RandSVD.h create mode 100644 Components/Metrics/ActiveRegistrationModel/Statismo/core/include/ReducedVarianceModelBuilder.h create mode 100644 Components/Metrics/ActiveRegistrationModel/Statismo/core/include/ReducedVarianceModelBuilder.hxx create mode 100644 Components/Metrics/ActiveRegistrationModel/Statismo/core/include/Representer.h create mode 100644 Components/Metrics/ActiveRegistrationModel/Statismo/core/include/StatismoIO.h create mode 100644 Components/Metrics/ActiveRegistrationModel/Statismo/core/include/StatismoUtils.h create mode 100644 Components/Metrics/ActiveRegistrationModel/Statismo/core/include/StatisticalModel.h create mode 100644 Components/Metrics/ActiveRegistrationModel/Statismo/core/include/StatisticalModel.hxx create mode 100644 Components/Metrics/ActiveRegistrationModel/Statismo/core/include/TrivialVectorialRepresenter.h create mode 100644 Components/Metrics/ActiveRegistrationModel/Statismo/core/include/genericRepresenterTest.hxx create mode 100644 Components/Metrics/ActiveRegistrationModel/Statismo/core/src/CMakeLists.txt create mode 100644 Components/Metrics/ActiveRegistrationModel/Statismo/core/src/ModelInfo.cxx create mode 100644 Components/Metrics/ActiveRegistrationModel/elxActiveRegistrationModelIntensityMetric.cxx create mode 100644 Components/Metrics/ActiveRegistrationModel/elxActiveRegistrationModelIntensityMetric.h create mode 100644 Components/Metrics/ActiveRegistrationModel/elxActiveRegistrationModelIntensityMetric.hxx create mode 100644 Components/Metrics/ActiveRegistrationModel/elxActiveRegistrationModelShapeMetric.cxx create mode 100644 Components/Metrics/ActiveRegistrationModel/elxActiveRegistrationModelShapeMetric.h create mode 100644 Components/Metrics/ActiveRegistrationModel/elxActiveRegistrationModelShapeMetric.hxx create mode 100644 Components/Metrics/ActiveRegistrationModel/itkActiveRegistrationModelIntensityMetric.h create mode 100644 Components/Metrics/ActiveRegistrationModel/itkActiveRegistrationModelIntensityMetric.hxx create mode 100644 Components/Metrics/ActiveRegistrationModel/itkActiveRegistrationModelShapeMetric.h create mode 100644 Components/Metrics/ActiveRegistrationModel/itkActiveRegistrationModelShapeMetric.hxx diff --git a/Components/Metrics/ActiveRegistrationModel/CMakeLists.txt b/Components/Metrics/ActiveRegistrationModel/CMakeLists.txt new file mode 100644 index 000000000..3933593fd --- /dev/null +++ b/Components/Metrics/ActiveRegistrationModel/CMakeLists.txt @@ -0,0 +1,29 @@ +ADD_ELXCOMPONENT(ActiveRegistrationModelShapeMetric OFF + elxActiveRegistrationModelShapeMetric.cxx + elxActiveRegistrationModelShapeMetric.h + elxActiveRegistrationModelShapeMetric.hxx + itkActiveRegistrationModelShapeMetric.h + itkActiveRegistrationModelShapeMetric.hxx +) + +ADD_ELXCOMPONENT(ActiveRegistrationModelIntensityMetric OFF + elxActiveRegistrationModelIntensityMetric.h + elxActiveRegistrationModelIntensityMetric.hxx + elxActiveRegistrationModelIntensityMetric.cxx + itkActiveRegistrationModelIntensityMetric.h + itkActiveRegistrationModelIntensityMetric.hxx +) + +# Eigen3, statismo and hdf5 include directories and Eigen3 and hdf5 libraries are transitively included via +# the statismo_core target +if(${USE_ActiveRegistrationModelShapeMetric} OR ${USE_ActiveRegistrationModelIntensityMetric}) + add_subdirectory(Statismo) +endif() + +if(USE_ActiveRegistrationModelShapeMetric) + target_link_libraries(ActiveRegistrationModelShapeMetric statismo_core ${Boost_LIBRARIES} ITKInternalEigen3::Eigen) +endif() + +if(USE_ActiveRegistrationModelIntensityMetric) + target_link_libraries(ActiveRegistrationModelIntensityMetric statismo_core ${Boost_LIBRARIES} ITKInternalEigen3::Eigen) +endif() diff --git a/Components/Metrics/ActiveRegistrationModel/Statismo/CMakeLists.txt b/Components/Metrics/ActiveRegistrationModel/Statismo/CMakeLists.txt new file mode 100644 index 000000000..170a9e032 --- /dev/null +++ b/Components/Metrics/ActiveRegistrationModel/Statismo/CMakeLists.txt @@ -0,0 +1,2 @@ +add_subdirectory(core) +add_subdirectory(ITK) diff --git a/Components/Metrics/ActiveRegistrationModel/Statismo/ITK/CMakeLists.txt b/Components/Metrics/ActiveRegistrationModel/Statismo/ITK/CMakeLists.txt new file mode 100644 index 000000000..94bc6bf77 --- /dev/null +++ b/Components/Metrics/ActiveRegistrationModel/Statismo/ITK/CMakeLists.txt @@ -0,0 +1,7 @@ +if(MSVC11) #i.e. Visual Studio 2012 + # Fix for VS2012 that has _VARIADIC_MAX set to 5. Don't set too high because it increases compiler memory usage / compile-time. + add_definitions(-D_VARIADIC_MAX=10 ) + # Fix for another VS2012 problem: not all TR1 options are automatically detected, therefore we force them here. + add_definitions(-D BOOST_HAS_TR1) + add_definitions(-D BOOST_NO_0X_HDR_INITIALIZER_LIST) +endif() diff --git a/Components/Metrics/ActiveRegistrationModel/Statismo/ITK/include/itkConditionalModelBuilder.h b/Components/Metrics/ActiveRegistrationModel/Statismo/ITK/include/itkConditionalModelBuilder.h new file mode 100644 index 000000000..23b96dd2b --- /dev/null +++ b/Components/Metrics/ActiveRegistrationModel/Statismo/ITK/include/itkConditionalModelBuilder.h @@ -0,0 +1,116 @@ +/* + * This file is part of the statismo library. + * + * Author: Marcel Luethi (marcel.luethi@unibas.ch) + * + * Copyright (c) 2011 University of Basel + * All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions + * are met: + * + * Redistributions of source code must retain the above copyright notice, + * this list of conditions and the following disclaimer. + * + * Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * + * Neither the name of the project's author nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS + * FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + * HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED + * TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR + * PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF + * LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING + * NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + */ + + +#ifndef ITKMODELBUILDER_H_ +#define ITKMODELBUILDER_H_ + +#include +#include + +#include "itkDataManager.h" +#include "itkStatisticalModel.h" +#include "ConditionalModelBuilder.h" +#include "statismoITKConfig.h" + +namespace itk { + +/** + * \brief ITK Wrapper for the statismo::PCAModelBuilder class. + * \see statismo::PCAModelBuilder for detailed documentation. + */ +template +class ConditionalModelBuilder : public Object { + public: + + typedef ConditionalModelBuilder Self; + typedef Object Superclass; + typedef SmartPointer Pointer; + typedef SmartPointer ConstPointer; + + itkNewMacro( Self ); + itkTypeMacro( ConditionalModelBuilder, Object ); + + + typedef statismo::ConditionalModelBuilder ImplType; + typedef statismo::DataManager DataManagerType; + typedef typename DataManagerType::SampleDataStructureListType SampleDataStructureListType; + + ConditionalModelBuilder() : m_impl(ImplType::Create()) {} + + virtual ~ConditionalModelBuilder() { + if (m_impl) { + delete m_impl; + m_impl = 0; + } + } + + template + typename boost::result_of::type callstatismoImpl(F f) const { + try { + return f(); + } catch (statismo::StatisticalModelException& s) { + itkExceptionMacro(<< s.what()); + } + } + + + typename StatisticalModel::Pointer + BuildNewModel(SampleDataStructureListType SampleDataStructureList, + const typename statismo::ConditionalModelBuilder::SurrogateTypeVectorType& surrogateTypes, + const typename statismo::ConditionalModelBuilder::CondVariableValueVectorType& conditioningInfo, + float noiseVariance, + double modelVarianceRetained + ) { + statismo::StatisticalModel* model_statismo = callstatismoImpl(boost::bind(&ImplType::BuildNewModel, this->m_impl, SampleDataStructureList, surrogateTypes, conditioningInfo, noiseVariance, modelVarianceRetained)); + typename StatisticalModel::Pointer model_itk = StatisticalModel::New(); + model_itk->SetstatismoImplObj(model_statismo); + return model_itk; + } + + + private: + ConditionalModelBuilder(const ConditionalModelBuilder& orig); + ConditionalModelBuilder& operator=(const ConditionalModelBuilder& rhs); + + ImplType* m_impl; +}; + + +} + +#endif /* ITKMODELBUILDER_H_ */ diff --git a/Components/Metrics/ActiveRegistrationModel/Statismo/ITK/include/itkDataManager.h b/Components/Metrics/ActiveRegistrationModel/Statismo/ITK/include/itkDataManager.h new file mode 100644 index 000000000..2f9f61bea --- /dev/null +++ b/Components/Metrics/ActiveRegistrationModel/Statismo/ITK/include/itkDataManager.h @@ -0,0 +1,146 @@ +/* + * This file is part of the statismo library. + * + * Author: Marcel Luethi (marcel.luethi@unibas.ch) + * + * Copyright (c) 2011 University of Basel + * All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions + * are met: + * + * Redistributions of source code must retain the above copyright notice, + * this list of conditions and the following disclaimer. + * + * Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * + * Neither the name of the project's author nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS + * FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + * HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED + * TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR + * PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF + * LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING + * NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + */ + + +#ifndef ITK_DATAMANAGER_H_ +#define ITK_DATAMANAGER_H_ + +#include +#include + +#include +#include + +#include "DataManager.h" +#include "statismoITKConfig.h" + +namespace itk { + + +/** + * \brief ITK Wrapper for the statismo::DataManager class. + * \see statismo::DataManager for detailed documentation. + */ +template +class DataManager : public Object { + public: + + + typedef DataManager Self; + typedef Object Superclass; + typedef SmartPointer Pointer; + typedef SmartPointer ConstPointer; + + itkNewMacro( Self ); + itkTypeMacro( DataManager, Object ); + + + typedef statismo::DataManager ImplType; + typedef typename statismo::DataManager::DataItemType DataItemType; + typedef typename statismo::DataManager::DataItemListType DataItemListType; + typedef statismo::Representer RepresenterType; + + template + typename boost::result_of::type callstatismoImpl(F f) const { + if (m_impl == 0) { + itkExceptionMacro(<< "Model not properly initialized. Maybe you forgot to call SetRepresenter"); + } + try { + return f(); + } catch (statismo::StatisticalModelException& s) { + itkExceptionMacro(<< s.what()); + } + } + + + DataManager() : m_impl(0) {} + + virtual ~DataManager() { + if (m_impl) { + delete m_impl; + m_impl = 0; + } + } + + ImplType* GetstatismoImplObj() const { + return m_impl; + } + + + void SetstatismoImplObj(ImplType* impl) { + if (m_impl) { + delete m_impl; + } + m_impl = impl; + } + + void SetRepresenter(const RepresenterType* representer) { + SetstatismoImplObj(ImplType::Create(representer)); + } + + void AddDataset(typename RepresenterType::DatasetType* ds, const char* filename) { + callstatismoImpl(boost::bind(&ImplType::AddDataset, this->m_impl, ds, filename)); + } + + void Load(const char* filename) { + try { + SetstatismoImplObj(ImplType::Load(filename)); + } catch (statismo::StatisticalModelException& s) { + itkExceptionMacro(<< s.what()); + } + } + + void Save(const char* filename) { + callstatismoImpl(boost::bind(&ImplType::Save, this->m_impl, filename)); + } + + typename statismo::DataManager::DataItemListType GetData() const { + return callstatismoImpl(boost::bind(&ImplType::GetData, this->m_impl)); + } + + + private: + DataManager(const DataManager& orig); + DataManager& operator=(const DataManager& rhs); + + ImplType* m_impl; +}; + + +} + +#endif /* ITK_DATAMANAGER_H_ */ diff --git a/Components/Metrics/ActiveRegistrationModel/Statismo/ITK/include/itkDataManagerWithSurrogates.h b/Components/Metrics/ActiveRegistrationModel/Statismo/ITK/include/itkDataManagerWithSurrogates.h new file mode 100644 index 000000000..382c0ebc1 --- /dev/null +++ b/Components/Metrics/ActiveRegistrationModel/Statismo/ITK/include/itkDataManagerWithSurrogates.h @@ -0,0 +1,132 @@ +/* + * This file is part of the statismo library. + * + * Author: Marcel Luethi (marcel.luethi@unibas.ch) + * + * Copyright (c) 2011 University of Basel + * All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions + * are met: + * + * Redistributions of source code must retain the above copyright notice, + * this list of conditions and the following disclaimer. + * + * Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * + * Neither the name of the project's author nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS + * FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + * HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED + * TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR + * PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF + * LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING + * NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + */ + + +#ifndef ITK_DATAMANAGER_WITH_SURROGATES_H_ +#define ITK_DATAMANAGER_WITH_SURROGATES_H_ + +#include +#include + +#include +#include + +#include "DataManagerWithSurrogates.h" +#include "statismoITKConfig.h" + +namespace itk { + + +/** + * \brief ITK Wrapper for the statismo::DataManager class. + * \see statismo::DataManager for detailed documentation. + */ +template +class DataManagerWithSurrogates : public statismo::DataManager { + public: + + + typedef DataManagerWithSurrogates Self; + typedef statismo::DataManager Superclass; + typedef SmartPointer Pointer; + typedef SmartPointer ConstPointer; + + itkNewMacro( Self ); + itkTypeMacro( DataManagerWithSurrogates, Object ); + + + typedef statismo::DataManagerWithSurrogates ImplType; + + template + typename boost::result_of::type callstatismoImpl(F f) const { + if (m_impl == 0) { + itkExceptionMacro(<< "Model not properly initialized. Maybe you forgot to call SetParameters"); + } + try { + return f(); + } catch (statismo::StatisticalModelException& s) { + itkExceptionMacro(<< s.what()); + } + } + + + DataManagerWithSurrogates() : m_impl(0) {} + + virtual ~DataManagerWithSurrogates() { + if (m_impl) { + delete m_impl; + m_impl = 0; + } + } + + + void SetstatismoImplObj(ImplType* impl) { + if (m_impl) { + delete m_impl; + } + m_impl = impl; + } + + void SetRepresenterAndSurrogateFilename(const Representer* representer, const char* surrogTypeFilename) { + SetstatismoImplObj(ImplType::Create(representer, surrogTypeFilename)); + } + + void SetRepresenter(const Representer* representer) { + itkExceptionMacro(<< "Please call SetRepresenterAndSurrogateFilename to initialize the object"); + } + + + + void AddDatasetWithSurrogates(typename Representer::DatasetConstPointerType ds, + const char* datasetURI, + const char* surrogateFilename) { + callstatismoImpl(boost::bind(&ImplType::AddDatasetWithSurrogates, this->m_impl, ds, datasetURI, surrogateFilename)); + } + + + private: + + DataManagerWithSurrogates(const DataManagerWithSurrogates& orig); + DataManagerWithSurrogates& operator=(const DataManagerWithSurrogates& rhs); + + ImplType* m_impl; +}; + + +} + +#endif /* ITK_DATAMANAGER_WITH_SURROGATES_H_ */ diff --git a/Components/Metrics/ActiveRegistrationModel/Statismo/ITK/include/itkInterpolatingStatisticalDeformationModelTransform.h b/Components/Metrics/ActiveRegistrationModel/Statismo/ITK/include/itkInterpolatingStatisticalDeformationModelTransform.h new file mode 100644 index 000000000..0d536c551 --- /dev/null +++ b/Components/Metrics/ActiveRegistrationModel/Statismo/ITK/include/itkInterpolatingStatisticalDeformationModelTransform.h @@ -0,0 +1,187 @@ +/* + * This file is part of the statismo library. + * + * Author: Marcel Luethi (marcel.luethi@unibas.ch) + * + * Copyright (c) 2011 University of Basel + * All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions + * are met: + * + * Redistributions of source code must retain the above copyright notice, + * this list of conditions and the following disclaimer. + * + * Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * + * Neither the name of the project's author nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS + * FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + * HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED + * TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR + * PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF + * LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING + * NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + */ + +#ifndef __ItkInterpolatingStatisticalDeformationModelTransform +#define __ItkInterpolatingStatisticalDeformationModelTransform + +#include + +#include +#include +#include + +#include "itkStandardImageRepresenter.h" +#include "itkStatisticalModel.h" +#include "itkStatisticalModelTransformBase.h" + +#include "Representer.h" + +namespace itk { + +/** + * + * \brief An itk transform that allows for deformations defined by a given Statistical Deformation Model. + * + * In contrast to the standard StatisticalDeformationModelTransform, this transform performs a linear interpolation of the + * PCABasis. This has the advantage that a model can be fitted which has a much lower resolution that the image, that needs to + * be explained. + * + * \ingroup Transforms + */ +template +class ITK_EXPORT InterpolatingStatisticalDeformationModelTransform : + public itk::StatisticalModelTransformBase< TDataset, TScalarType , TDimension> { + public: + + /* Standard class typedefs. */ + typedef InterpolatingStatisticalDeformationModelTransform Self; + typedef itk::StatisticalModelTransformBase< TDataset, TScalarType , TDimension> Superclass; + typedef SmartPointer Pointer; + typedef SmartPointer ConstPointer; + + + itkSimpleNewMacro( Self ); + + + /** Run-time type information (and related methods). */ + itkTypeMacro(InterpolatingStatisticalDeformationModelTransform, Superclass); + + + typedef typename Superclass::InputPointType InputPointType; + typedef typename Superclass::OutputPointType OutputPointType; + typedef typename Superclass::RepresenterType RepresenterType; + typedef typename Superclass::StatisticalModelType StatisticalModelType; + typedef typename Superclass::JacobianType JacobianType; + + + typedef typename RepresenterType::DatasetType DeformationFieldType; + typedef VectorLinearInterpolateImageFunction InterpolatorType; + + /** + * Clone the current transform + */ + virtual ::itk::LightObject::Pointer CreateAnother() const { + ::itk::LightObject::Pointer smartPtr; + Pointer another = Self::New().GetPointer(); + this->CopyBaseMembers(another); + another->m_meanDeformation = this->m_meanDeformation; + another->m_PCABasisDeformations = this->m_PCABasisDeformations; + smartPtr = static_cast(another); + return smartPtr; + } + + virtual void SetStatisticalModel(const StatisticalModelType* model) { + this->Superclass::SetStatisticalModel(model); + + m_meanDeformation = InterpolatorType::New(); + typename DeformationFieldType::Pointer meanDf = model->DrawMean(); + m_meanDeformation->SetInputImage(meanDf); + for (unsigned i = 0; i < model->GetNumberOfPrincipalComponents(); i++) { + typename DeformationFieldType::Pointer deformationField = model->DrawPCABasisSample(i); + typename InterpolatorType::Pointer basisI = InterpolatorType::New(); + basisI->SetInputImage(deformationField); + m_PCABasisDeformations.push_back(basisI); + } + } + + + + void ComputeJacobianWithRespectToParameters(const InputPointType &pt, JacobianType &jacobian) const { + jacobian.SetSize(TDimension, m_PCABasisDeformations.size()); + jacobian.Fill(0); + if (m_meanDeformation->IsInsideBuffer(pt) == false) + return; + + for(unsigned j = 0; j < m_PCABasisDeformations.size(); j++) { + typename RepresenterType::ValueType d = m_PCABasisDeformations[j]->Evaluate(pt); + for(unsigned i = 0; i < TDimension; i++) { + jacobian(i,j) += d[i] ; + } + } + + itkDebugMacro( << "Jacobian with MM:\n" << jacobian); + itkDebugMacro( << "After GetMorphableModelJacobian:" + << "\nJacobian = \n" << jacobian); + } + + + /** + * Transform a given point according to the deformation induced by the StatisticalModel, + * given the current parameters. + * + * \param pt The point to tranform + * \return The transformed point + */ + virtual OutputPointType TransformPoint(const InputPointType &pt) const { + if (m_meanDeformation->IsInsideBuffer(pt) == false) { + return pt; + } + assert(this->m_coeff_vector.size() == m_PCABasisDeformations.size()); + typename RepresenterType::ValueType def = m_meanDeformation->Evaluate(pt); + + for (unsigned i = 0; i < m_PCABasisDeformations.size(); i++) { + typename RepresenterType::ValueType defBasisI = m_PCABasisDeformations[i]->Evaluate(pt); + def += (defBasisI * this->m_coeff_vector[i]); + } + + OutputPointType transformedPoint; + for (unsigned i = 0; i < pt.GetPointDimension(); i++) { + transformedPoint[i] = pt[i] + def[i]; + } + + return transformedPoint; + } + + virtual ~InterpolatingStatisticalDeformationModelTransform() {} + + InterpolatingStatisticalDeformationModelTransform() {} + + private: + + + InterpolatingStatisticalDeformationModelTransform(const InterpolatingStatisticalDeformationModelTransform& orig); // purposely not implemented + InterpolatingStatisticalDeformationModelTransform& operator=(const InterpolatingStatisticalDeformationModelTransform& rhs); //purposely not implemented + + + typename InterpolatorType::Pointer m_meanDeformation; + std::vector m_PCABasisDeformations; +}; + + +} // namespace itk + +#endif // __ItkInterpolatingStatisticalDeformationModelTransform diff --git a/Components/Metrics/ActiveRegistrationModel/Statismo/ITK/include/itkLowRankGPModelBuilder.h b/Components/Metrics/ActiveRegistrationModel/Statismo/ITK/include/itkLowRankGPModelBuilder.h new file mode 100644 index 000000000..436d159a8 --- /dev/null +++ b/Components/Metrics/ActiveRegistrationModel/Statismo/ITK/include/itkLowRankGPModelBuilder.h @@ -0,0 +1,150 @@ +/* + * This file is part of the statismo library. + * + * Author: Marcel Luethi (marcel.luethi@unibas.ch) + * + * Copyright (c) 2011 University of Basel + * All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions + * are met: + * + * Redistributions of source code must retain the above copyright notice, + * this list of conditions and the following disclaimer. + * + * Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * + * Neither the name of the project's author nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS + * FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + * HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED + * TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR + * PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF + * LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING + * NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + */ + +#ifndef ITKLOWRANKMODELBUILDER_H_ +#define ITKLOWRANKMODELBUILDER_H_ + +#include +#include + +#include "itkStatisticalModel.h" + +#include "Kernels.h" +#include "LowRankGPModelBuilder.h" +#include "Representer.h" +#include "statismoITKConfig.h" + +namespace itk { + +/** + * \brief ITK Wrapper for the statismo::LowRankGPModelBuilder class. + * \see statismo::LowRankGPModelBuilder for detailed documentation. + */ + +template +class LowRankGPModelBuilder: public Object { + public: + + typedef LowRankGPModelBuilder Self; + typedef statismo::Representer RepresenterType; + typedef Object Superclass; + typedef SmartPointer Pointer; + typedef SmartPointer ConstPointer; + + itkNewMacro (Self); + itkTypeMacro( LowRankGPModelBuilder, Object ); + + typedef statismo::LowRankGPModelBuilder ImplType; + typedef itk::StatisticalModel StatisticalModelType; + typedef statismo::MatrixValuedKernel MatrixValuedKernelType; + + LowRankGPModelBuilder() : + m_impl(0) { + } + + + void SetstatismoImplObj(ImplType* impl) { + if (m_impl) { + delete m_impl; + } + m_impl = impl; + } + + + void SetRepresenter(const RepresenterType* representer) { + SetstatismoImplObj(ImplType::Create(representer)); + } + + virtual ~LowRankGPModelBuilder() { + if (m_impl) { + delete m_impl; + m_impl = 0; + } + } + + typename StatisticalModelType::Pointer BuildNewZeroMeanModel( + const MatrixValuedKernelType& kernel, unsigned numComponents, + unsigned numPointsForNystrom = 500) const { + if (m_impl == 0) { + itkExceptionMacro(<< "Model not properly initialized. Maybe you forgot to call SetRepresenter"); + } + + + statismo::StatisticalModel* model_statismo = 0; + try { + model_statismo = this->m_impl->BuildNewZeroMeanModel(kernel, numComponents, numPointsForNystrom); + + } catch (statismo::StatisticalModelException& s) { + itkExceptionMacro(<< s.what()); + } + + typename StatisticalModel::Pointer model_itk = StatisticalModel::New(); + model_itk->SetstatismoImplObj(model_statismo); + return model_itk; + + } + + typename StatisticalModelType::Pointer BuildNewModel(typename RepresenterType::DatasetType* mean, const MatrixValuedKernelType& kernel, unsigned numComponents, unsigned numPointsForNystrom = 500) { + if (m_impl == 0) { + itkExceptionMacro(<< "Model not properly initialized. Maybe you forgot to call SetRepresenter"); + } + + statismo::StatisticalModel* model_statismo = 0; + try { + model_statismo = this->m_impl->BuildNewModel(mean, kernel, numComponents, numPointsForNystrom); + + } catch (statismo::StatisticalModelException& s) { + itkExceptionMacro(<< s.what()); + } + + + typename StatisticalModel::Pointer model_itk = StatisticalModel::New(); + model_itk->SetstatismoImplObj(model_statismo); + return model_itk; + + } + + private: + LowRankGPModelBuilder(const LowRankGPModelBuilder& orig); + LowRankGPModelBuilder& operator=(const LowRankGPModelBuilder& rhs); + + ImplType* m_impl; +}; + +} + +#endif /* ITKLOWRANKMODELBUILDER_H_ */ diff --git a/Components/Metrics/ActiveRegistrationModel/Statismo/ITK/include/itkPCAModelBuilder.h b/Components/Metrics/ActiveRegistrationModel/Statismo/ITK/include/itkPCAModelBuilder.h new file mode 100644 index 000000000..3256ff31b --- /dev/null +++ b/Components/Metrics/ActiveRegistrationModel/Statismo/ITK/include/itkPCAModelBuilder.h @@ -0,0 +1,114 @@ +/* + * This file is part of the statismo library. + * + * Author: Marcel Luethi (marcel.luethi@unibas.ch) + * + * Copyright (c) 2011 University of Basel + * All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions + * are met: + * + * Redistributions of source code must retain the above copyright notice, + * this list of conditions and the following disclaimer. + * + * Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * + * Neither the name of the project's author nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS + * FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + * HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED + * TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR + * PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF + * LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING + * NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + */ + + +#ifndef ITKMODELBUILDER_H_ +#define ITKMODELBUILDER_H_ + +#include +#include + +#include "itkDataManager.h" +#include "itkStatisticalModel.h" + +#include "PCAModelBuilder.h" +#include "statismoITKConfig.h" + +namespace itk { + +/** + * \brief ITK Wrapper for the statismo::PCAModelBuilder class. + * \see statismo::PCAModelBuilder for detailed documentation. + */ +template +class PCAModelBuilder : public Object { + public: + + typedef PCAModelBuilder Self; + typedef Object Superclass; + typedef SmartPointer Pointer; + typedef SmartPointer ConstPointer; + + itkNewMacro( Self ); + itkTypeMacro( PCAModelBuilder, Object ); + + + typedef statismo::PCAModelBuilder ImplType; + typedef statismo::DataManager DataManagerType; + typedef typename DataManagerType::DataItemListType DataItemListType; + + typedef typename ImplType::EigenValueMethod EigenValueMethod; + + PCAModelBuilder() : m_impl(ImplType::Create()) {} + + virtual ~PCAModelBuilder() { + if (m_impl) { + delete m_impl; + m_impl = 0; + } + } + + template + typename boost::result_of::type callstatismoImpl(F f) const { + try { + return f(); + } catch (statismo::StatisticalModelException& s) { + itkExceptionMacro(<< s.what()); + } + } + + + + typename StatisticalModel::Pointer BuildNewModel(DataItemListType DataItemList, float noiseVariance, bool computeScores = true, EigenValueMethod method = ImplType::JacobiSVD) { + statismo::StatisticalModel* model_statismo = callstatismoImpl(boost::bind(&ImplType::BuildNewModel, this->m_impl, DataItemList, noiseVariance, computeScores, method)); + typename StatisticalModel::Pointer model_itk = StatisticalModel::New(); + model_itk->SetstatismoImplObj(model_statismo); + return model_itk; + } + + + private: + PCAModelBuilder(const PCAModelBuilder& orig); + PCAModelBuilder& operator=(const PCAModelBuilder& rhs); + + ImplType* m_impl; +}; + + +} + +#endif /* ITKMODELBUILDER_H_ */ diff --git a/Components/Metrics/ActiveRegistrationModel/Statismo/ITK/include/itkPixelConversionTraits.h b/Components/Metrics/ActiveRegistrationModel/Statismo/ITK/include/itkPixelConversionTraits.h new file mode 100644 index 000000000..048d8c67f --- /dev/null +++ b/Components/Metrics/ActiveRegistrationModel/Statismo/ITK/include/itkPixelConversionTraits.h @@ -0,0 +1,373 @@ +/* + * This file is part of the statismo library. + * + * Author: Marcel Luethi (marcel.luethi@unibas.ch) + * + * Copyright (c) 2011 University of Basel + * All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions + * are met: + * + * Redistributions of source code must retain the above copyright notice, + * this list of conditions and the following disclaimer. + * + * Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * + * Neither the name of the project's author nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS + * FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + * HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED + * TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR + * PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF + * LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING + * NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + */ + + +#ifndef __ITK_TYPE_CONVERSION_TRAIT +#define __ITK_TYPE_CONVERSION_TRAIT + +#include + +#include "Exceptions.h" +#include "CommonTypes.h" + +namespace itk { + +// these traits are used to allow a conversion from the generic pixel type to a statismo vector. +// Currently only scalar types are supported. + +template struct PixelConversionTrait { + static statismo::VectorType ToVector(const T& pixel) { + throw statismo::StatisticalModelException("Unsupported PixelType (PixelTraits::ToVector not implemented)"); + } + static T FromVector(const statismo::VectorType& v) { + throw statismo::StatisticalModelException("Unsupported PixelType (PixelTraits::ToVector not implemented)"); + } + static unsigned GetDataType() { + throw statismo::StatisticalModelException("Unsupported PixelType (PixelTraits::ToVector not implemented)"); + } + static unsigned GetPixelDimension() { + throw statismo::StatisticalModelException("Unsupported PixelType (PixelTraits::ToVector not implemented)"); + } +}; + +template <> struct PixelConversionTrait { + static statismo::VectorType ToVector(const double& pixel) { + statismo::VectorType v(1); + v << pixel; + return v; + } + static double FromVector(const statismo::VectorType& v) { + assert(v.size() == 1); + return v(0); + } + static unsigned GetDataType() { + return statismo::DOUBLE; + } + static unsigned GetPixelDimension() { + return 1; + } +}; +template <> struct PixelConversionTrait { + static statismo::VectorType ToVector(const float& pixel) { + statismo::VectorType v(1); + v << pixel; + return v; + } + static float FromVector(const statismo::VectorType& v) { + assert(v.size() == 1); + return v(0); + } + static unsigned GetDataType() { + return statismo::FLOAT; + } + static unsigned GetPixelDimension() { + return 1; + } +}; +template <> struct PixelConversionTrait { + static statismo::VectorType ToVector(const short& pixel) { + statismo::VectorType v(1); + v << pixel; + return v; + } + static short FromVector(const statismo::VectorType& v) { + assert(v.size() == 1); + return v(0); + } + static unsigned GetDataType() { + return statismo::SIGNED_SHORT; + } + static unsigned GetPixelDimension() { + return 1; + } +}; +template <> struct PixelConversionTrait { + static statismo::VectorType ToVector(const unsigned short& pixel) { + statismo::VectorType v(1); + v << pixel; + return v; + } + static unsigned short FromVector(const statismo::VectorType& v) { + assert(v.size() == 1); + return v(0); + } + static unsigned GetDataType() { + return statismo::UNSIGNED_SHORT; + } + static unsigned GetPixelDimension() { + return 1; + } +}; +template <> struct PixelConversionTrait { + static statismo::VectorType ToVector(const int& pixel) { + statismo::VectorType v(1); + v << pixel; + return v; + } + static int FromVector(const statismo::VectorType& v) { + assert(v.size() == 1); + return v(0); + } + static unsigned GetDataType() { + return statismo::SIGNED_INT; + } + static unsigned GetPixelDimension() { + return 1; + } +}; +template <> struct PixelConversionTrait { + static statismo::VectorType ToVector(const unsigned int& pixel) { + statismo::VectorType v(1); + v << pixel; + return v; + } + static unsigned int FromVector(const statismo::VectorType& v) { + assert(v.size() == 1); + return v(0); + } + static unsigned GetDataType() { + return statismo::UNSIGNED_SHORT; + } + static unsigned GetPixelDimension() { + return 1; + } +}; +template <> struct PixelConversionTrait { + static statismo::VectorType ToVector(const char& pixel) { + statismo::VectorType v(1); + v << pixel; + return v; + } + static char FromVector(const statismo::VectorType& v) { + assert(v.size() == 1); + return v(0); + } + static unsigned GetDataType() { + return statismo::SIGNED_CHAR; + } + static unsigned GetPixelDimension() { + return 1; + } +}; +template <> struct PixelConversionTrait { + static statismo::VectorType ToVector(const unsigned char& pixel) { + statismo::VectorType v(1); + v << pixel; + return v; + } + static unsigned char FromVector(const statismo::VectorType& v) { + assert(v.size() == 1); + return v(0); + } + static unsigned GetDataType() { + return statismo::UNSIGNED_CHAR; + } + static unsigned GetPixelDimension() { + return 1; + } +}; + +template <> struct PixelConversionTrait { + static statismo::VectorType ToVector(const long& pixel) { + statismo::VectorType v(1); + v << pixel; + return v; + } + static long FromVector(const statismo::VectorType& v) { + assert(v.size() == 1); + return v(0); + } + static unsigned GetDataType() { + return statismo::SIGNED_LONG; + } + static unsigned GetPixelDimension() { + return 1; + } +}; +template <> struct PixelConversionTrait { + static statismo::VectorType ToVector(const unsigned long& pixel) { + statismo::VectorType v(1); + v << pixel; + return v; + } + static unsigned long FromVector(const statismo::VectorType& v) { + assert(v.size() == 1); + return v(0); + } + static unsigned GetDataType() { + return statismo::UNSIGNED_LONG; + } + static unsigned GetPixelDimension() { + return 1; + } +}; + +template <> struct PixelConversionTrait > { + static statismo::VectorType ToVector(const itk::Vector& pixel) { + statismo::VectorType v(2); + v << pixel[0] , pixel[1]; + return v; + } + static itk::Vector FromVector(const statismo::VectorType& v) { + assert(v.size() == 2); + itk::Vector itkVec; + itkVec[0] = v(0); + itkVec[1] = v(1); + return itkVec; + } + static unsigned GetDataType() { + return statismo::FLOAT; + } + static unsigned GetPixelDimension() { + return 2; + } +}; + +template <> struct PixelConversionTrait > { + static statismo::VectorType ToVector(const itk::Vector& pixel) { + statismo::VectorType v(3); + v << pixel[0] , pixel[1], pixel[2]; + return v; + } + static itk::Vector FromVector(const statismo::VectorType& v) { + assert(v.size() == 3); + itk::Vector itkVec; + itkVec[0] = v(0); + itkVec[1] = v(1); + itkVec[2] = v(2); + return itkVec; + } + static unsigned GetDataType() { + return statismo::FLOAT; + } + static unsigned GetPixelDimension() { + return 3; + } +}; + +template <> struct PixelConversionTrait > { + static statismo::VectorType ToVector(const itk::Vector& pixel) { + statismo::VectorType v(4); + v << pixel[0] , pixel[1], pixel[2], pixel[3]; + return v; + } + static itk::Vector FromVector(const statismo::VectorType& v) { + assert(v.size() == 4); + itk::Vector itkVec; + itkVec[0] = v(0); + itkVec[1] = v(1); + itkVec[2] = v(2); + itkVec[3] = v(3); + return itkVec; + } + static unsigned GetDataType() { + return statismo::FLOAT; + } + static unsigned GetPixelDimension() { + return 4; + } +}; + +template <> struct PixelConversionTrait > { + static statismo::VectorType ToVector(const itk::Vector& pixel) { + statismo::VectorType v(2); + v << pixel[0] , pixel[1]; + return v; + } + static itk::Vector FromVector(const statismo::VectorType& v) { + assert(v.size() == 2); + itk::Vector itkVec; + itkVec[0] = v(0); + itkVec[1] = v(1); + return itkVec; + } + static unsigned GetDataType() { + return statismo::DOUBLE; + } + static unsigned GetPixelDimension() { + return 2; + } +}; + +template <> struct PixelConversionTrait > { + static statismo::VectorType ToVector(const itk::Vector& pixel) { + statismo::VectorType v(3); + v << pixel[0] , pixel[1], pixel[2]; + return v; + } + static itk::Vector FromVector(const statismo::VectorType& v) { + assert(v.size() == 3); + itk::Vector itkVec; + itkVec[0] = v(0); + itkVec[1] = v(1); + itkVec[2] = v(2); + return itkVec; + } + static unsigned GetDataType() { + return statismo::DOUBLE; + } + static unsigned GetPixelDimension() { + return 3; + } +}; + +template <> struct PixelConversionTrait > { + static statismo::VectorType ToVector(const itk::Vector& pixel) { + statismo::VectorType v(4); + v << pixel[0] , pixel[1], pixel[2], pixel[3]; + return v; + } + static itk::Vector FromVector(const statismo::VectorType& v) { + assert(v.size() == 4); + itk::Vector itkVec; + itkVec[0] = v(0); + itkVec[1] = v(1); + itkVec[2] = v(2); + itkVec[3] = v(3); + return itkVec; + } + static unsigned GetDataType() { + return statismo::DOUBLE; + } + static unsigned GetPixelDimension() { + return 4; + } +}; + +} // namespace itk + +#endif diff --git a/Components/Metrics/ActiveRegistrationModel/Statismo/ITK/include/itkPosteriorModelBuilder.h b/Components/Metrics/ActiveRegistrationModel/Statismo/ITK/include/itkPosteriorModelBuilder.h new file mode 100644 index 000000000..f610f5d50 --- /dev/null +++ b/Components/Metrics/ActiveRegistrationModel/Statismo/ITK/include/itkPosteriorModelBuilder.h @@ -0,0 +1,148 @@ +/* + * This file is part of the statismo library. + * + * Author: Marcel Luethi (marcel.luethi@unibas.ch) + * + * Copyright (c) 2011 University of Basel + * All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions + * are met: + * + * Redistributions of source code must retain the above copyright notice, + * this list of conditions and the following disclaimer. + * + * Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * + * Neither the name of the project's author nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS + * FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + * HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED + * TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR + * PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF + * LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING + * NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + */ + + +#ifndef ITK_POSTERIOR_MODELBUILDER_H_ +#define ITK_POSTERIOR_MODELBUILDER_H_ + +#include + +#include "itkDataManager.h" +#include "itkStatisticalModel.h" + +#include "PosteriorModelBuilder.h" +#include "statismoITKConfig.h" + +namespace itk { + +/** + * \brief ITK Wrapper for the statismo::PosteriorModelBuilder class. + * \see statismo::PosteriorModelBuilder for detailed documentation. + */ +template +class PosteriorModelBuilder : public Object { + public: + + typedef PosteriorModelBuilder Self; + typedef Object Superclass; + typedef SmartPointer Pointer; + typedef SmartPointer ConstPointer; + + itkNewMacro( Self ); + itkTypeMacro( PosteriorModelBuilder, Object ); + + typedef statismo::PosteriorModelBuilder ImplType; + typedef statismo::DataManager DataManagerType; + typedef typename DataManagerType::DataItemListType DataItemListType; + + + + template + typename boost::result_of::type callstatismoImpl(F f) const { + try { + return f(); + } catch (statismo::StatisticalModelException& s) { + itkExceptionMacro(<< s.what()); + } + } + + + PosteriorModelBuilder() : m_impl(ImplType::Create()) {} + + virtual ~PosteriorModelBuilder() { + if (m_impl) { + delete m_impl; + m_impl = 0; + } + } + + + // create statismo stuff + typedef statismo::Representer RepresenterType; + typedef typename RepresenterType::ValueType ValueType; + typedef typename RepresenterType::PointType PointType; + typedef typename statismo::PosteriorModelBuilder::PointValueListType PointValueListType; + typedef typename statismo::PosteriorModelBuilder::PointValueWithCovariancePairType PointValueWithCovariancePairType; + typedef typename statismo::PosteriorModelBuilder::PointValueWithCovarianceListType PointValueWithCovarianceListType; + typedef itk::StatisticalModel StatisticalModelType; + typedef statismo::StatisticalModel StatismoStatisticalModelType; + + typename StatisticalModelType::Pointer BuildNewModelFromModel(const StatisticalModelType* model, const PointValueListType& pointValues, double pointValuesNoiseVariance, bool computeScores=true) { + StatismoStatisticalModelType* model_statismo = model->GetstatismoImplObj(); + StatismoStatisticalModelType* new_model_statismo = callstatismoImpl(boost::bind( + static_cast (&ImplType::BuildNewModelFromModel), + this->m_impl, model_statismo, pointValues, pointValuesNoiseVariance, computeScores)); + typename StatisticalModelType::Pointer model_itk = StatisticalModelType::New(); + model_itk->SetstatismoImplObj(new_model_statismo); + return model_itk; + } + + typename StatisticalModelType::Pointer BuildNewModel(DataItemListType DataItemList, const PointValueListType& pointValues, double pointValuesNoiseVariance, double noiseVariance) { + StatismoStatisticalModelType* model_statismo = callstatismoImpl(boost::bind(&ImplType::BuildNewModel, this->m_impl, DataItemList ,pointValues, pointValuesNoiseVariance, noiseVariance)); + typename StatisticalModelType::Pointer model_itk = StatisticalModelType::New(); + model_itk->SetstatismoImplObj(model_statismo); + return model_itk; + } + + typename StatisticalModelType::Pointer BuildNewModelFromModel(const StatisticalModelType* model, const PointValueWithCovarianceListType& pointValuesWithCovariance, bool computeScores=true) { + StatismoStatisticalModelType* model_statismo = model->GetstatismoImplObj(); + StatismoStatisticalModelType* new_model_statismo = callstatismoImpl(boost::bind( + static_cast (&ImplType::BuildNewModelFromModel), + this->m_impl, model_statismo, pointValuesWithCovariance, computeScores)); + typename StatisticalModelType::Pointer model_itk = StatisticalModelType::New(); + model_itk->SetstatismoImplObj(new_model_statismo); + return model_itk; + } + + typename StatisticalModelType::Pointer BuildNewModel(const DataItemListType& DataItemList, const PointValueWithCovarianceListType& pointValuesWithCovariance, double noiseVariance) { + StatismoStatisticalModelType* model_statismo = callstatismoImpl(boost::bind(&ImplType::BuildNewModel, this->m_impl, DataItemList, pointValuesWithCovariance, noiseVariance)); + typename StatisticalModelType::Pointer model_itk = StatisticalModelType::New(); + model_itk->SetstatismoImplObj(model_statismo); + return model_itk; + } + + private: + PosteriorModelBuilder(const PosteriorModelBuilder& orig); + PosteriorModelBuilder& operator=(const PosteriorModelBuilder& rhs); + + ImplType* m_impl; +}; + + +} + +#endif /* ITK_POSTERIOR_MODEL_BUILDER */ diff --git a/Components/Metrics/ActiveRegistrationModel/Statismo/ITK/include/itkReducedVarianceModelBuilder.h b/Components/Metrics/ActiveRegistrationModel/Statismo/ITK/include/itkReducedVarianceModelBuilder.h new file mode 100644 index 000000000..c64b94765 --- /dev/null +++ b/Components/Metrics/ActiveRegistrationModel/Statismo/ITK/include/itkReducedVarianceModelBuilder.h @@ -0,0 +1,128 @@ +/* + * This file is part of the statismo library. + * + * Author: Marcel Luethi (marcel.luethi@unibas.ch) + * + * Copyright (c) 2011 University of Basel + * All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions + * are met: + * + * Redistributions of source code must retain the above copyright notice, + * this list of conditions and the following disclaimer. + * + * Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * + * Neither the name of the project's author nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS + * FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + * HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED + * TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR + * PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF + * LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING + * NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + */ + + +#ifndef ITK_PARTIALLY_FIXED_MODELBUILDER_H_ +#define ITK_PARTIALLY_FIXED_MODELBUILDER_H_ + +#include + +#include "itkDataManager.h" +#include "itkStatisticalModel.h" + +#include "ReducedVarianceModelBuilder.h" +#include "statismoITKConfig.h" +#include "StatismoUtils.h" + +namespace itk { + +/** + * \brief ITK Wrapper for the statismo::ReducedVarianceModelBuilder class. + * \see statismo::ReducedVariance for detailed documentation. + */ +template +class ReducedVarianceModelBuilder : public Object { + public: + + typedef ReducedVarianceModelBuilder Self; + typedef Object Superclass; + typedef SmartPointer Pointer; + typedef SmartPointer ConstPointer; + + itkNewMacro( Self ); + itkTypeMacro( ReducedVarianceModelBuilder, Object ); + + typedef statismo::ReducedVarianceModelBuilder ImplType; + + + template + typename boost::result_of::type callstatismoImpl(F f) const { + try { + return f(); + } catch (statismo::StatisticalModelException& s) { + itkExceptionMacro(<< s.what()); + } + } + + + ReducedVarianceModelBuilder() : m_impl(ImplType::Create()) {} + + virtual ~ReducedVarianceModelBuilder() { + if (m_impl) { + delete m_impl; + m_impl = 0; + } + } + + + + typename StatisticalModel::Pointer BuildNewModelWithLeadingComponents(const StatisticalModel* model, unsigned numberOfPrincipalComponents) { + statismo::StatisticalModel* model_statismo = model->GetstatismoImplObj(); + statismo::StatisticalModel* new_model_statismo = callstatismoImpl(boost::bind(&ImplType::BuildNewModelWithLeadingComponents, this->m_impl, model_statismo, numberOfPrincipalComponents)); + typename StatisticalModel::Pointer model_itk = StatisticalModel::New(); + model_itk->SetstatismoImplObj(new_model_statismo); + return model_itk; + } + + typename StatisticalModel::Pointer BuildNewModelWithVariance(const StatisticalModel* model, double totalVariance) { + statismo::StatisticalModel* model_statismo = model->GetstatismoImplObj(); + statismo::StatisticalModel* new_model_statismo = callstatismoImpl(boost::bind(&ImplType::BuildNewModelWithVariance, this->m_impl, model_statismo, totalVariance)); + typename StatisticalModel::Pointer model_itk = StatisticalModel::New(); + model_itk->SetstatismoImplObj(new_model_statismo); + return model_itk; + } + + is_deprecated typename StatisticalModel::Pointer BuildNewModelFromModel(const StatisticalModel* model, double totalVariance) { + statismo::StatisticalModel* model_statismo = model->GetstatismoImplObj(); + statismo::StatisticalModel* new_model_statismo = callstatismoImpl(boost::bind(&ImplType::BuildNewModelFromModel, this->m_impl, model_statismo, totalVariance)); + typename StatisticalModel::Pointer model_itk = StatisticalModel::New(); + model_itk->SetstatismoImplObj(new_model_statismo); + return model_itk; + } + + + private: + ReducedVarianceModelBuilder(const ReducedVarianceModelBuilder& orig); + ReducedVarianceModelBuilder& operator=(const ReducedVarianceModelBuilder& rhs); + + ImplType* m_impl; +}; + + +} + +#endif /* ITK_PARTIALLY_FIXED_MODEL_BUILDER */ diff --git a/Components/Metrics/ActiveRegistrationModel/Statismo/ITK/include/itkStandardImageRepresenter.h b/Components/Metrics/ActiveRegistrationModel/Statismo/ITK/include/itkStandardImageRepresenter.h new file mode 100644 index 000000000..7306584c2 --- /dev/null +++ b/Components/Metrics/ActiveRegistrationModel/Statismo/ITK/include/itkStandardImageRepresenter.h @@ -0,0 +1,166 @@ +/* + * This file is part of the statismo library. + * + * Author: Marcel Luethi (marcel.luethi@unibas.ch) + * + * Copyright (c) 2011 University of Basel + * All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions + * are met: + * + * Redistributions of source code must retain the above copyright notice, + * this list of conditions and the following disclaimer. + * + * Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * + * Neither the name of the project's author nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS + * FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + * HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED + * TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR + * PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF + * LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING + * NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + */ + +#ifndef ITK_STANDARDIMAGE_REPRESENTER_H_ +#define ITK_STANDARDIMAGE_REPRESENTER_H_ + +#include "statismoITKConfig.h" // this needs to be the first include + +#include + +#include +#include + +#include "itkPixelConversionTraits.h" +#include "itkStandardImageRepresenterTraits.h" + +#include "CommonTypes.h" +#include "Representer.h" + +namespace itk { + +/** + * \ingroup Representers + * \brief A representer for scalar and vector valued images + * \sa Representer + */ + +template +class StandardImageRepresenter: public Object, public statismo::Representer< + itk::Image > { + public: + + /* Standard class typedefs. */ + typedef StandardImageRepresenter Self; + typedef Object Superclass; + typedef SmartPointer Pointer; + typedef SmartPointer ConstPointer; + + /** New macro for creation of through a Smart Pointer. */ + itkSimpleNewMacro (Self); + + /** Run-time type information (and related methods). */ + itkTypeMacro( StandardImageRepresenter, Object ); + + typedef itk::Image ImageType; + + typedef typename statismo::Representer RepresenterBaseType; + typedef typename RepresenterBaseType::DomainType DomainType; + typedef typename RepresenterBaseType::PointType PointType; + typedef typename RepresenterBaseType::ValueType ValueType; + typedef typename RepresenterBaseType::DatasetType DatasetType; + typedef typename RepresenterBaseType::DatasetPointerType DatasetPointerType; + typedef typename RepresenterBaseType::DatasetConstPointerType DatasetConstPointerType; + + static StandardImageRepresenter* Create() { + return new StandardImageRepresenter(); + } + void Load(const H5::Group& fg) override; + StandardImageRepresenter* Clone() const override; + + StandardImageRepresenter(); + virtual ~StandardImageRepresenter(); + + unsigned GetDimensions() const override { + return PixelConversionTrait::GetPixelDimension(); + } + std::string GetName() const override { + return "itkStandardImageRepresenter"; + } + typename RepresenterBaseType::RepresenterDataType GetType() const override { + return RepresenterBaseType::IMAGE; + } + + const DomainType& GetDomain() const override { + return m_domain; + } + std::string GetVersion() const override { + return "0.1"; + } + + /// return the reference used in the representer + DatasetConstPointerType GetReference() const override { + return m_reference; + } + + + /** Set the reference that is used to build the model */ + void SetReference(ImageType* ds); + + /** + * Creates a sample by first aligning the dataset ds to the reference using Procrustes + * Alignment. + */ + statismo::VectorType PointToVector(const PointType& pt) const override; + statismo::VectorType SampleToSampleVector(DatasetConstPointerType sample) const override; + DatasetPointerType SampleVectorToSample( + const statismo::VectorType& sample) const override; + + ValueType PointSampleFromSample(DatasetConstPointerType sample, + unsigned ptid) const override; + ValueType PointSampleVectorToPointSample( + const statismo::VectorType& pointSample) const override; + statismo::VectorType PointSampleToPointSampleVector( + const ValueType& v) const override; + + void Save(const H5::Group& fg) const override; + virtual unsigned GetPointIdForPoint(const PointType& point) const override; + + unsigned GetNumberOfPoints() const; + + void Delete() const override{ + this->UnRegister(); + } + + + void DeleteDataset(DatasetConstPointerType d) const override {} + DatasetPointerType CloneDataset(DatasetConstPointerType d) const override; + + private: + + typename ImageType::Pointer LoadRef(const H5::Group& fg) const; + typename ImageType::Pointer LoadRefLegacy(const H5::Group& fg) const; + + DatasetConstPointerType m_reference; + DomainType m_domain; +}; + +} // namespace itk + +#include "itkStandardImageRepresenter.hxx" + +#endif /* itkStandardImageREPRESENTER_H_ */ diff --git a/Components/Metrics/ActiveRegistrationModel/Statismo/ITK/include/itkStandardImageRepresenter.hxx b/Components/Metrics/ActiveRegistrationModel/Statismo/ITK/include/itkStandardImageRepresenter.hxx new file mode 100644 index 000000000..da887e712 --- /dev/null +++ b/Components/Metrics/ActiveRegistrationModel/Statismo/ITK/include/itkStandardImageRepresenter.hxx @@ -0,0 +1,432 @@ +/* +* This file is part of the statismo library. +* +* Author: Marcel Luethi (marcel.luethi@unibas.ch) +* +* Copyright (c) 2011 University of Basel +* All rights reserved. +* +* Redistribution and use in source and binary forms, with or without +* modification, are permitted provided that the following conditions +* are met: +* +* Redistributions of source code must retain the above copyright notice, +* this list of conditions and the following disclaimer. +* +* Redistributions in binary form must reproduce the above copyright +* notice, this list of conditions and the following disclaimer in the +* documentation and/or other materials provided with the distribution. +* +* Neither the name of the project's author nor the names of its +* contributors may be used to endorse or promote products derived from +* this software without specific prior written permission. +* +* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +* "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +* LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS +* FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +* HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +* SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED +* TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +* PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF +* LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING +* NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS +* SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +* +*/ + +#ifndef __itkStandardImageRepresenter_hxx +#define __itkStandardImageRepresenter_hxx + + +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include + +#include "HDF5Utils.h" +#include "StatismoUtils.h" + +#include "itkStandardImageRepresenter.h" + + +namespace itk { + +template +StandardImageRepresenter::StandardImageRepresenter() + : m_reference() { +} +template +StandardImageRepresenter::~StandardImageRepresenter() { +} + +template +StandardImageRepresenter* +StandardImageRepresenter::Clone() const { + + StandardImageRepresenter* clone = new StandardImageRepresenter(); + clone->Register(); + + DatasetPointerType clonedReference = this->CloneDataset(m_reference); + clone->SetReference(clonedReference); + return clone; +} + + + +template +void +StandardImageRepresenter::Load(const H5::Group& fg) { + + std::string repName = statismo::HDF5Utils::readStringAttribute(fg, "name"); + if (repName == "vtkStructuredPointsRepresenter" || repName == "itkImageRepresenter" || repName == "itkVectorImageRepresenter") { + this->SetReference(LoadRefLegacy(fg)); + } else { + this->SetReference(LoadRef(fg)); + } + +} + + +template +typename StandardImageRepresenter::ImageType::Pointer +StandardImageRepresenter::LoadRef(const H5::Group& fg) const { + + + int readImageDimension = statismo::HDF5Utils::readInt(fg, "imageDimension"); + if (readImageDimension != ImageDimension) { + throw statismo::StatisticalModelException("the image dimension specified in the statismo file does not match the one specified as template parameter"); + } + + + statismo::VectorType originVec; + statismo::HDF5Utils::readVector(fg, "origin", originVec); + typename ImageType::PointType origin; + for (unsigned i = 0; i < ImageDimension; i++) { + origin[i] = originVec[i]; + } + + statismo::VectorType spacingVec; + statismo::HDF5Utils::readVector(fg, "spacing", spacingVec); + typename ImageType::SpacingType spacing; + for (unsigned i = 0; i < ImageDimension; i++) { + spacing[i] = spacingVec[i]; + } + + typename statismo::GenericEigenType::VectorType sizeVec; + statismo::HDF5Utils::readVectorOfType(fg, "size", sizeVec); + typename ImageType::SizeType size; + for (unsigned i = 0; i < ImageDimension; i++) { + size[i] = sizeVec[i]; + } + + statismo::MatrixType directionMat; + statismo::HDF5Utils::readMatrix(fg, "direction", directionMat); + typename ImageType::DirectionType direction; + for (unsigned i = 0; i < directionMat.rows(); i++) { + for (unsigned j = 0; j < directionMat.rows(); j++) { + direction[i][j] = directionMat(i,j); + } + } + + H5::Group pdGroup = fg.openGroup("./pointData"); + unsigned readPixelDimension = static_cast(statismo::HDF5Utils::readInt(pdGroup, "pixelDimension")); + if (readPixelDimension != GetDimensions()) { + throw statismo::StatisticalModelException("the pixel dimension specified in the statismo file does not match the one specified as template parameter"); + } + + typename statismo::GenericEigenType::MatrixType pixelMatDouble; + statismo::HDF5Utils::readMatrixOfType(pdGroup, "pixelValues", pixelMatDouble); + statismo::MatrixType pixelMat = pixelMatDouble.cast(); + typename ImageType::Pointer newImage = ImageType::New(); + typename ImageType::IndexType start; + start.Fill(0); + + + H5::DataSet ds = pdGroup.openDataSet("pixelValues"); + unsigned int type = static_cast(statismo::HDF5Utils::readIntAttribute(ds, "datatype")); + if (type != PixelConversionTrait::GetDataType()) { + std::cout << "Warning: The datatype specified for the scalars does not match the TPixel template argument used in this representer." << std::endl; + } + pdGroup.close(); + typename ImageType::RegionType region(start, size); + newImage->SetRegions(region); + newImage->Allocate(); + newImage->SetOrigin(origin); + newImage->SetSpacing(spacing); + newImage->SetDirection(direction); + + + itk::ImageRegionIterator it(newImage, newImage->GetLargestPossibleRegion()); + it.GoToBegin(); + for (unsigned i = 0; !it.IsAtEnd(); ++it, i++) { + TPixel v = PixelConversionTrait::FromVector(pixelMat.col(i)); + it.Set(v); + } + + return newImage; +} + +template +typename StandardImageRepresenter::ImageType::Pointer +StandardImageRepresenter::LoadRefLegacy(const H5::Group& fg) const { + + std::string tmpfilename; + tmpfilename = statismo::Utils::CreateTmpName(".vtk"); + statismo::HDF5Utils::getFileFromHDF5(fg, "./reference", tmpfilename.c_str()); + + typename itk::ImageFileReader::Pointer reader = itk::ImageFileReader::New(); + reader->SetFileName(tmpfilename); + try { + reader->Update(); + } catch (itk::ImageFileReaderException& e) { + boost::filesystem::remove(tmpfilename); + throw statismo::StatisticalModelException((std::string("Could not read file ") + tmpfilename).c_str()); + } + typename DatasetType::Pointer img = reader->GetOutput(); + img->Register(); + boost::filesystem::remove(tmpfilename); + return img; + +} + + +template +void +StandardImageRepresenter::SetReference(ImageType* reference) { + m_reference = reference; + + typename DomainType::DomainPointsListType domainPoints; + itk::ImageRegionConstIterator it(reference, reference->GetLargestPossibleRegion()); + it.GoToBegin(); + for (; + it.IsAtEnd() == false + ;) { + PointType pt; + reference->TransformIndexToPhysicalPoint(it.GetIndex(), pt); + domainPoints.push_back(pt); + ++it; + } + m_domain = DomainType(domainPoints); +} + +template +statismo::VectorType +StandardImageRepresenter::PointToVector(const PointType& pt) const { + statismo::VectorType v(PointType::GetPointDimension()); + for (unsigned i = 0; i < PointType::GetPointDimension(); i++) { + v(i) = pt[i]; + } + return v; + +} + + + + +template +statismo::VectorType +StandardImageRepresenter::SampleToSampleVector(DatasetConstPointerType image) const { + statismo::VectorType sample(this->GetNumberOfPoints() * GetDimensions()); + itk::ImageRegionConstIterator it(image, image->GetLargestPossibleRegion()); + + it.GoToBegin(); + for (unsigned i = 0; + it.IsAtEnd() == false; + ++i) { + + statismo::VectorType sampleAtPt = PixelConversionTrait::ToVector(it.Value()); + for (unsigned j = 0; j < GetDimensions(); j++) { + unsigned idx = this->MapPointIdToInternalIdx(i, j); + sample[idx] = sampleAtPt[j]; + } + ++it; + } + return sample; +} + + +template +typename StandardImageRepresenter::DatasetPointerType +StandardImageRepresenter::SampleVectorToSample(const statismo::VectorType& sample) const { + + typedef itk::ImageDuplicator< DatasetType > DuplicatorType; + typename DuplicatorType::Pointer duplicator = DuplicatorType::New(); + duplicator->SetInputImage(this->m_reference); + duplicator->Update(); + DatasetPointerType clonedImage = duplicator->GetOutput(); + + itk::ImageRegionIterator it(clonedImage, clonedImage->GetLargestPossibleRegion()); + it.GoToBegin(); + for (unsigned i = 0; !it.IsAtEnd(); ++it, i++) { + + statismo::VectorType valAtPoint(GetDimensions()); + for (unsigned d = 0; d < GetDimensions(); d++) { + unsigned idx = this->MapPointIdToInternalIdx(i, d); + valAtPoint[d] = sample[idx]; + } + ValueType v = PixelConversionTrait::FromVector(valAtPoint); + it.Set(v); + } + return clonedImage; + +} + +template +typename StandardImageRepresenter::ValueType +StandardImageRepresenter::PointSampleFromSample(DatasetConstPointerType sample, unsigned ptid) const { + if (ptid >= GetDomain().GetNumberOfPoints()) { + throw statismo::StatisticalModelException("invalid ptid provided to PointSampleFromSample"); + } + + // we get the point with the id from the domain, as itk does not allow us get a point via its index. + PointType pt = GetDomain().GetDomainPoints()[ptid]; + typename ImageType::IndexType idx; + sample->TransformPhysicalPointToIndex(pt, idx); + ValueType value = sample->GetPixel(idx); + return value; + +} + +template +typename StandardImageRepresenter::ValueType +StandardImageRepresenter::PointSampleVectorToPointSample(const statismo::VectorType& pointSample) const { + return PixelConversionTrait::FromVector(pointSample); +} + +template +statismo::VectorType +StandardImageRepresenter::PointSampleToPointSampleVector(const ValueType& v) const { + return PixelConversionTrait::ToVector(v); +} + + +template +void +StandardImageRepresenter::Save(const H5::Group& fg) const { + + typename ImageType::PointType origin = m_reference->GetOrigin(); + statismo::VectorType originVec(ImageDimension); + for (unsigned i = 0; i < ImageDimension; i++) { + originVec(i) = origin[i]; + } + statismo::HDF5Utils::writeVector(fg, "origin", originVec); + + typename ImageType::SpacingType spacing = m_reference->GetSpacing(); + statismo::VectorType spacingVec(ImageDimension); + for (unsigned i = 0; i < ImageDimension; i++) { + spacingVec(i) = spacing[i]; + } + statismo::HDF5Utils::writeVector(fg, "spacing", spacingVec); + + + statismo::GenericEigenType::VectorType sizeVec(ImageDimension); + for (unsigned i = 0; i < ImageDimension; i++) { + sizeVec(i) = m_reference->GetLargestPossibleRegion().GetSize()[i]; + } + statismo::HDF5Utils::writeVectorOfType(fg, "size", sizeVec); + + typename ImageType::DirectionType direction = m_reference->GetDirection(); + statismo::MatrixType directionMat(ImageDimension, ImageDimension); + for (unsigned i = 0; i < ImageDimension; i++) { + for (unsigned j = 0; j < ImageDimension; j++) { + directionMat(i,j) = direction[i][j]; + } + } + statismo::HDF5Utils::writeMatrix(fg, "direction", directionMat); + + statismo::HDF5Utils::writeInt(fg, "imageDimension", ImageDimension); + + H5::Group pdGroup = fg.createGroup("pointData"); + statismo::HDF5Utils::writeInt(pdGroup, "pixelDimension", GetDimensions()); + + + typedef statismo::GenericEigenType::MatrixType DoubleMatrixType; + statismo::MatrixType pixelMat(GetDimensions(), GetNumberOfPoints()); + + itk::ImageRegionIterator it(m_reference, m_reference->GetLargestPossibleRegion()); + it.GoToBegin(); + for (unsigned i = 0; + it.IsAtEnd() == false; + ++i) { + pixelMat.col(i) = PixelConversionTrait::ToVector(it.Get()); + ++it; + } + DoubleMatrixType pixelMatDouble = pixelMat.cast(); + H5::DataSet ds = statismo::HDF5Utils::writeMatrixOfType(pdGroup, "pixelValues", pixelMatDouble); + statismo::HDF5Utils::writeIntAttribute(ds, "datatype", PixelConversionTrait::GetDataType()); + pdGroup.close(); +} + + +template +unsigned +StandardImageRepresenter::GetNumberOfPoints() const { + return m_reference->GetLargestPossibleRegion().GetNumberOfPixels(); +} + + +template +unsigned +StandardImageRepresenter::GetPointIdForPoint(const PointType& pt) const { + // itks organization is slice row col + typename DatasetType::IndexType idx; + bool ptInImage = this->m_reference->TransformPhysicalPointToIndex(pt, idx); + + typename DatasetType::SizeType size = this->m_reference->GetLargestPossibleRegion().GetSize(); + + // It does not make sense to allow points outside the image, because only the inside is modeled. + // However, some discretization artifacts of image and surface operations may produce points that + // are just on the boundary of the image, but mathematically outside. We accept these points and + // return the iD of the closest image point. + // Any points further out will trigger an exception. + if(!ptInImage) { + for (unsigned int i=0; i size[i]) { + throw statismo::StatisticalModelException("GetPointIdForPoint computed invalid ptId. Make sure that the point is within the reference you chose "); + } + // If it is on the boundary, we set it to the nearest boundary coordinate. + if(idx[i] == -1) idx[i] = 0; + if(idx[i] == size[i]) idx[i] = size[i] - 1; + } + } + + + // in itk, idx 0 is by convention the fastest moving index + unsigned int index=0; + for (unsigned int i=0; i=0; --d) { + multiplier*=size[d]; + } + index+=multiplier*idx[i]; + } + + return index; +} + +template +typename StandardImageRepresenter::DatasetPointerType +StandardImageRepresenter::CloneDataset(DatasetConstPointerType d) const { + typedef itk::ImageDuplicator< DatasetType > DuplicatorType; + typename DuplicatorType::Pointer duplicator = DuplicatorType::New(); + duplicator->SetInputImage(d); + duplicator->Update(); + DatasetPointerType clone = duplicator->GetOutput(); + clone->DisconnectPipeline(); + return clone; +} + +} // namespace itk + +#endif diff --git a/Components/Metrics/ActiveRegistrationModel/Statismo/ITK/include/itkStandardImageRepresenterTraits.h b/Components/Metrics/ActiveRegistrationModel/Statismo/ITK/include/itkStandardImageRepresenterTraits.h new file mode 100644 index 000000000..750f524d4 --- /dev/null +++ b/Components/Metrics/ActiveRegistrationModel/Statismo/ITK/include/itkStandardImageRepresenterTraits.h @@ -0,0 +1,262 @@ +/* + * This file is part of the statismo library. + * + * Author: Marcel Luethi (marcel.luethi@unibas.ch) + * + * Copyright (c) 2011 University of Basel + * All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions + * are met: + * + * Redistributions of source code must retain the above copyright notice, + * this list of conditions and the following disclaimer. + * + * Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * + * Neither the name of the project's author nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS + * FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + * HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED + * TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR + * PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF + * LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING + * NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + */ + +#ifndef __itkStandardImageRepresenterTraits_h +#define __itkStandardImageRepresenterTraits_h + +#include "itkImage.h" +#include "itkVector.h" +#include "Representer.h" + +namespace statismo { + +template<> +struct RepresenterTraits, 4u> > { + + typedef itk::Image, 4u> VectorImageType; + typedef VectorImageType::Pointer DatasetPointerType; + typedef VectorImageType::Pointer DatasetConstPointerType; + typedef VectorImageType::PointType PointType; + typedef VectorImageType::PixelType ValueType; +}; + +template<> +struct RepresenterTraits, 3u> > { + + typedef itk::Image, 3u> VectorImageType; + typedef VectorImageType::Pointer DatasetPointerType; + typedef VectorImageType::Pointer DatasetConstPointerType; + typedef VectorImageType::PointType PointType; + typedef VectorImageType::PixelType ValueType; +}; + +template<> +struct RepresenterTraits, 2u> > { + + typedef itk::Image, 2u> VectorImageType; + typedef VectorImageType::Pointer DatasetPointerType; + typedef VectorImageType::Pointer DatasetConstPointerType; + typedef VectorImageType::PointType PointType; + typedef VectorImageType::PixelType ValueType; +}; + + + +template<> +struct RepresenterTraits, 4u> > { + + typedef itk::Image, 4u> VectorImageType; + typedef VectorImageType::Pointer DatasetPointerType; + typedef VectorImageType::Pointer DatasetConstPointerType; + typedef VectorImageType::PointType PointType; + typedef VectorImageType::PixelType ValueType; +}; + +template<> +struct RepresenterTraits, 3u> > { + + typedef itk::Image, 3u> VectorImageType; + typedef VectorImageType::Pointer DatasetPointerType; + typedef VectorImageType::Pointer DatasetConstPointerType; + typedef VectorImageType::PointType PointType; + typedef VectorImageType::PixelType ValueType; +}; + +template<> +struct RepresenterTraits, 2u> > { + + typedef itk::Image, 2u> VectorImageType; + typedef VectorImageType::Pointer DatasetPointerType; + typedef VectorImageType::Pointer DatasetConstPointerType; + typedef VectorImageType::PointType PointType; + typedef VectorImageType::PixelType ValueType; +}; + +template<> +struct RepresenterTraits > { + + typedef itk::Image ImageType; + typedef ImageType::Pointer DatasetPointerType; + typedef ImageType::Pointer DatasetConstPointerType; + typedef ImageType::PointType PointType; + typedef ImageType::PixelType ValueType; +}; + +template<> +struct RepresenterTraits > { + + typedef itk::Image ImageType; + typedef ImageType::Pointer DatasetPointerType; + typedef ImageType::Pointer DatasetConstPointerType; + typedef ImageType::PointType PointType; + typedef ImageType::PixelType ValueType; +}; + +template<> +struct RepresenterTraits > { + + typedef itk::Image ImageType; + typedef ImageType::Pointer DatasetPointerType; + typedef ImageType::Pointer DatasetConstPointerType; + typedef ImageType::PointType PointType; + typedef ImageType::PixelType ValueType; +}; + +template<> +struct RepresenterTraits > { + + typedef itk::Image ImageType; + typedef ImageType::Pointer DatasetPointerType; + typedef ImageType::Pointer DatasetConstPointerType; + typedef ImageType::PointType PointType; + typedef ImageType::PixelType ValueType; +}; + +template<> +struct RepresenterTraits > { + + typedef itk::Image ImageType; + typedef ImageType::Pointer DatasetPointerType; + typedef ImageType::Pointer DatasetConstPointerType; + typedef ImageType::PointType PointType; + typedef ImageType::PixelType ValueType; +}; + +template<> +struct RepresenterTraits > { + + typedef itk::Image ImageType; + typedef ImageType::Pointer DatasetPointerType; + typedef ImageType::Pointer DatasetConstPointerType; + typedef ImageType::PointType PointType; + typedef ImageType::PixelType ValueType; +}; + +template<> +struct RepresenterTraits > { + + typedef itk::Image ImageType; + typedef ImageType::Pointer DatasetPointerType; + typedef ImageType::Pointer DatasetConstPointerType; + typedef ImageType::PointType PointType; + typedef ImageType::PixelType ValueType; +}; + +template<> +struct RepresenterTraits > { + + typedef itk::Image ImageType; + typedef ImageType::Pointer DatasetPointerType; + typedef ImageType::Pointer DatasetConstPointerType; + typedef ImageType::PointType PointType; + typedef ImageType::PixelType ValueType; +}; + +template<> +struct RepresenterTraits > { + + typedef itk::Image ImageType; + typedef ImageType::Pointer DatasetPointerType; + typedef ImageType::Pointer DatasetConstPointerType; + typedef ImageType::PointType PointType; + typedef ImageType::PixelType ValueType; +}; + +template<> +struct RepresenterTraits > { + + typedef itk::Image ImageType; + typedef ImageType::Pointer DatasetPointerType; + typedef ImageType::Pointer DatasetConstPointerType; + typedef ImageType::PointType PointType; + typedef ImageType::PixelType ValueType; +}; + +template<> +struct RepresenterTraits > { + + typedef itk::Image ImageType; + typedef ImageType::Pointer DatasetPointerType; + typedef ImageType::Pointer DatasetConstPointerType; + typedef ImageType::PointType PointType; + typedef ImageType::PixelType ValueType; +}; + +template<> +struct RepresenterTraits > { + + typedef itk::Image ImageType; + typedef ImageType::Pointer DatasetPointerType; + typedef ImageType::Pointer DatasetConstPointerType; + typedef ImageType::PointType PointType; + typedef ImageType::PixelType ValueType; +}; + +template<> +struct RepresenterTraits > { + + typedef itk::Image ImageType; + typedef ImageType::Pointer DatasetPointerType; + typedef ImageType::Pointer DatasetConstPointerType; + typedef ImageType::PointType PointType; + typedef ImageType::PixelType ValueType; +}; + +template<> +struct RepresenterTraits > { + + typedef itk::Image ImageType; + typedef ImageType::Pointer DatasetPointerType; + typedef ImageType::Pointer DatasetConstPointerType; + typedef ImageType::PointType PointType; + typedef ImageType::PixelType ValueType; +}; + +template<> +struct RepresenterTraits > { + + typedef itk::Image ImageType; + typedef ImageType::Pointer DatasetPointerType; + typedef ImageType::Pointer DatasetConstPointerType; + typedef ImageType::PointType PointType; + typedef ImageType::PixelType ValueType; +}; + + +} // namespace statismo + +#endif diff --git a/Components/Metrics/ActiveRegistrationModel/Statismo/ITK/include/itkStandardMeshRepresenter.h b/Components/Metrics/ActiveRegistrationModel/Statismo/ITK/include/itkStandardMeshRepresenter.h new file mode 100644 index 000000000..60dbc1518 --- /dev/null +++ b/Components/Metrics/ActiveRegistrationModel/Statismo/ITK/include/itkStandardMeshRepresenter.h @@ -0,0 +1,250 @@ +/* + * This file is part of the statismo library. + * + * Author: Marcel Luethi (marcel.luethi@unibas.ch) + * + * Copyright (c) 2011 University of Basel + * All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions + * are met: + * + * Redistributions of source code must retain the above copyright notice, + * this list of conditions and the following disclaimer. + * + * Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * + * Neither the name of the project's author nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS + * FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + * HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED + * TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR + * PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF + * LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING + * NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + */ + + + +#ifndef ITK_STANDARD_MESH_REPRESENTER_H +#define ITK_STANDARD_MESH_REPRESENTER_H + +#include + +#include +#include + +#include "statismoITKConfig.h" // this needs to be the first include file + +#include "CommonTypes.h" +#include "Exceptions.h" +#include "Representer.h" + +#include "itkPixelConversionTraits.h" + +namespace statismo { + +template <> +struct RepresenterTraits > > { + + typedef itk::Mesh> MeshType; + + typedef MeshType::Pointer DatasetPointerType; + typedef MeshType::ConstPointer DatasetConstPointerType; + + typedef MeshType::PointType PointType; + typedef MeshType::PointType ValueType; +}; + +template <> +struct RepresenterTraits > > { + + typedef itk::Mesh> MeshType; + + typedef MeshType::Pointer DatasetPointerType; + typedef MeshType::ConstPointer DatasetConstPointerType; + + typedef MeshType::PointType PointType; + typedef MeshType::PointType ValueType; +}; + +template <> +struct RepresenterTraits > > { + + typedef itk::Mesh> MeshType; + + typedef MeshType::Pointer DatasetPointerType; + typedef MeshType::ConstPointer DatasetConstPointerType; + + typedef MeshType::PointType PointType; + typedef MeshType::PointType ValueType; +}; + +} + +namespace itk { + +// helper function to compute the hash value of an itk point (needed by unorderd_map) +template +size_t hash_value(const PointType& pt) { + size_t hash_val = 0; + for (unsigned i = 0; i < pt.GetPointDimension(); i++) { + boost::hash_combine( hash_val, pt[i] ); + } + return hash_val; +} + + +/** + * \ingroup Representers + * \brief A representer for scalar valued itk Meshs + * \sa Representer + */ +template +class StandardMeshRepresenter : public statismo::Representer > >, public Object { + public: + + /* Standard class typedefs. */ + typedef StandardMeshRepresenter Self; + typedef Object Superclass; + typedef SmartPointer Pointer; + typedef SmartPointer ConstPointer; + + + typedef itk::Mesh> MeshType; + typedef typename statismo::Representer RepresenterBaseType; + typedef typename RepresenterBaseType::DomainType DomainType; + typedef typename RepresenterBaseType::PointType PointType; + typedef typename RepresenterBaseType::ValueType ValueType; + typedef typename RepresenterBaseType::DatasetPointerType DatasetPointerType; + typedef typename RepresenterBaseType::DatasetConstPointerType DatasetConstPointerType; + typedef typename MeshType::PointsContainer PointsContainerType; + + /** New macro for creation of through a Smart Pointer. */ + itkSimpleNewMacro( Self ); + + /** Run-time type information (and related methods). */ + itkTypeMacro( StandardMeshRepresenter, Object ); + + + static StandardMeshRepresenter* Create() { + return new StandardMeshRepresenter(); + } + + void Load(const H5::Group& fg) override; + StandardMeshRepresenter* Clone() const override; + + /// The type of the data set to be used + typedef MeshType DatasetType; + + // An unordered map is used to cache pointid for corresonding points + typedef boost::unordered_map PointCacheType; + + StandardMeshRepresenter(); + virtual ~StandardMeshRepresenter(); + + unsigned GetDimensions() const override { + return MeshDimension; + } + std::string GetName() const override { + return "itkStandardMeshRepresenter"; + } + typename RepresenterBaseType::RepresenterDataType GetType() const override { + return RepresenterBaseType::POLYGON_MESH; + } + std::string GetVersion() const override { + return "0.1"; + } + + const DomainType& GetDomain() const override { + return m_domain; + } + + /** Set the reference that is used to build the model */ + void SetReference(DatasetPointerType ds); + + statismo::VectorType PointToVector(const PointType& pt) const override; + + + /** + * Converts a sample to its vectorial representation + */ + statismo::VectorType SampleToSampleVector(DatasetConstPointerType sample) const override; + + /** + * Converts the given sample Vector to a Sample (an itk::Mesh) + */ + DatasetPointerType SampleVectorToSample(const statismo::VectorType& sample) const override; + + /** + * Returns the value of the sample at the point with the given id. + */ + ValueType PointSampleFromSample(DatasetConstPointerType sample, unsigned ptid) const override; + + /** + * Given a vector, represening a points convert it to an itkPoint + */ + ValueType PointSampleVectorToPointSample(const statismo::VectorType& pointSample) const override; + + /** + * Given an itkPoint, convert it to a sample vector + */ + statismo::VectorType PointSampleToPointSampleVector(const ValueType& v) const override; + + /** + * Save the state of the representer (this simply saves the reference) + */ + void Save(const H5::Group& fg) const override; + + /// return the number of points of the reference + virtual unsigned GetNumberOfPoints() const; + + /// return the point id associated with the given point + /// \warning This works currently only for points that are defined on the reference + virtual unsigned GetPointIdForPoint(const PointType& point) const override; + + /// return the reference used in the representer + DatasetConstPointerType GetReference() const override { + return m_reference; + } + + void Delete() const override { + this->UnRegister(); + } + + + void DeleteDataset(DatasetPointerType d) const override { }; + DatasetPointerType CloneDataset(DatasetConstPointerType mesh) const override; + + private: + + typename MeshType::Pointer LoadRef(const H5::Group& fg) const; + typename MeshType::Pointer LoadRefLegacy(const H5::Group& fg) const; + + // returns the closest point for the given mesh + unsigned FindClosestPoint(const MeshType* mesh, const PointType pt) const ; + + DatasetConstPointerType m_reference; + DomainType m_domain; + mutable PointCacheType m_pointCache; +}; + + +} // namespace itk + + + +#include "itkStandardMeshRepresenter.hxx" + +#endif /* ITK_STANDARD_MESH_REPRESENTER */ diff --git a/Components/Metrics/ActiveRegistrationModel/Statismo/ITK/include/itkStandardMeshRepresenter.hxx b/Components/Metrics/ActiveRegistrationModel/Statismo/ITK/include/itkStandardMeshRepresenter.hxx new file mode 100644 index 000000000..79a3337cd --- /dev/null +++ b/Components/Metrics/ActiveRegistrationModel/Statismo/ITK/include/itkStandardMeshRepresenter.hxx @@ -0,0 +1,421 @@ +/* + * This file is part of the statismo library. + * + * Author: Marcel Luethi (marcel.luethi@unibas.ch) + * + * Copyright (c) 2011 University of Basel + * All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions + * are met: + * + * Redistributions of source code must retain the above copyright notice, + * this list of conditions and the following disclaimer. + * + * Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * + * Neither the name of the project's author nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS + * FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + * HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED + * TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR + * PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF + * LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING + * NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + */ + +#ifndef __itkStandardMeshRepresenter_hxx +#define __itkStandardMeshRepresenter_hxx + +#include "itkStandardMeshRepresenter.h" + +#include + +#include +#include +#include +#include +#include +#include + +#include + +#include "HDF5Utils.h" +#include "StatismoUtils.h" + +namespace itk { + +template +StandardMeshRepresenter::StandardMeshRepresenter() + : m_reference(DatasetType::New()) { +} +template +StandardMeshRepresenter::~StandardMeshRepresenter() { +} + +template +StandardMeshRepresenter* +StandardMeshRepresenter::Clone() const { + + StandardMeshRepresenter* clone = new StandardMeshRepresenter(); + clone->Register(); + + typename MeshType::Pointer clonedReference = this->CloneDataset(m_reference); + clone->SetReference(clonedReference); + return clone; +} + + +template +void +StandardMeshRepresenter::Load(const H5::Group& fg) { + + std::string repName = statismo::HDF5Utils::readStringAttribute(fg, "name"); + if (repName == "vtkPolyDataRepresenter" || repName == "itkMeshRepresenter") { + this->SetReference(LoadRefLegacy(fg)); + } else { + this->SetReference(LoadRef(fg)); + } +} + +template +typename StandardMeshRepresenter::MeshType::Pointer +StandardMeshRepresenter::LoadRef(const H5::Group& fg) const { + + statismo::MatrixType vertexMat; + statismo::HDF5Utils::readMatrix(fg, "./points", vertexMat); + + typedef typename statismo::GenericEigenType::MatrixType UIntMatrixType; + UIntMatrixType cellsMat; + + unsigned nVertices = vertexMat.cols(); + unsigned nCells = cellsMat.cols(); + unsigned cellDim = cellsMat.rows(); + + + typename MeshType::Pointer mesh = MeshType::New(); + + // add points + for (unsigned i = 0; i < nVertices; i++) { + typename MeshType::PointType p; + for(unsigned int j = 0; j < MeshDimension; j++) { + p[j] = vertexMat(j, i); + } + mesh->SetPoint(i, p); + } + + // add cells + if(statismo::HDF5Utils::existsObjectWithName(fg, "cells")) { + statismo::HDF5Utils::readMatrixOfType(fg, "./cells", cellsMat); + + typedef typename MeshType::CellType::CellAutoPointer CellAutoPointer; + typedef itk::LineCell< typename MeshType::CellType > LineType; + typedef itk::TriangleCell < typename MeshType::CellType > TriangleCellType; + + CellAutoPointer cell; + + for (unsigned i = 0; i < nCells; i++) { + if (cellDim == 2) { + cell.TakeOwnership( new LineType ); + } else if (cellDim == 3) { + cell.TakeOwnership( new TriangleCellType); + } else { + throw statismo::StatisticalModelException("This representer currently supports only line and triangle cells"); + } + + for (unsigned d = 0; d < cellDim; d++) { + cell->SetPointId(d, cellsMat(d, i)); + } + mesh->SetCell( i, cell ); + } + } + + + // currently this representer supports only pointdata of type scalar + if (statismo::HDF5Utils::existsObjectWithName(fg, "pointData")) { + H5::Group pdGroup = fg.openGroup("./pointData"); + + if (statismo::HDF5Utils::existsObjectWithName(pdGroup, "scalars")) { + H5::DataSet ds = pdGroup.openDataSet("scalars"); + unsigned type = static_cast(statismo::HDF5Utils::readIntAttribute(ds, "datatype")); + if (type != PixelConversionTrait::GetDataType()) { + std::cout << "Warning: The datatype specified for the scalars does not match the TPixel template argument used in this representer." << std::endl; + } + statismo::MatrixTypeDoublePrecision scalarMatDouble; + statismo::HDF5Utils::readMatrixOfType(pdGroup, "scalars", scalarMatDouble); + statismo::MatrixType scalarMat = scalarMatDouble.cast(); + assert(static_cast(scalarMatDouble.cols()) == mesh->GetNumberOfPoints()); + typename MeshType::PointDataContainerPointer pd = MeshType::PointDataContainer::New(); + + for (unsigned i = 0; i < scalarMatDouble.cols(); i++) { + TPixel v = PixelConversionTrait::FromVector(scalarMat.col(i)); + pd->InsertElement(i, v); + } + mesh->SetPointData(pd); + } + + pdGroup.close(); + } + + return mesh; +} + + +template +typename StandardMeshRepresenter::MeshType::Pointer +StandardMeshRepresenter::LoadRefLegacy(const H5::Group& fg) const { + + std::string tmpfilename = statismo::Utils::CreateTmpName(".vtk"); + statismo::HDF5Utils::getFileFromHDF5(fg, "./reference", tmpfilename.c_str()); + + + typename itk::MeshFileReader::Pointer reader = itk::MeshFileReader::New(); + reader->SetFileName(tmpfilename); + try { + reader->Update(); + } catch (itk::MeshFileReaderException& e) { + boost::filesystem::remove(tmpfilename); + throw statismo::StatisticalModelException((std::string("Could not read file ") + tmpfilename).c_str()); + } + + typename MeshType::Pointer mesh = reader->GetOutput(); + boost::filesystem::remove(tmpfilename); + return mesh; + +} + + +template +void +StandardMeshRepresenter::SetReference(DatasetPointerType reference) { + m_reference = reference; + + // We create a list of poitns for the domain. + // Furthermore, we cache for all the points of the reference, as these are the most likely ones + // we have to look up later. + typename DomainType::DomainPointsListType domainPointList; + + typename PointsContainerType::ConstPointer points = m_reference->GetPoints(); + typename PointsContainerType::ConstIterator pointIterator = points->Begin(); + unsigned id = 0; + while( pointIterator != points->End() ) { + domainPointList.push_back(pointIterator.Value()); + m_pointCache.insert(std::pair(pointIterator.Value(), id)); + ++pointIterator; + ++id; + } + m_domain = DomainType(domainPointList); + +} + +template +statismo::VectorType +StandardMeshRepresenter::PointToVector(const PointType& pt) const { + statismo::VectorType v(PointType::GetPointDimension()); + for (unsigned i = 0; i < PointType::GetPointDimension(); i++) { + v(i) = pt[i]; + } + return v; + +} + +template +statismo::VectorType +StandardMeshRepresenter::SampleToSampleVector(DatasetConstPointerType mesh) const { + statismo::VectorType sample(GetNumberOfPoints() * GetDimensions()); + + typename PointsContainerType::ConstPointer points = mesh->GetPoints(); + + typename PointsContainerType::ConstIterator pointIterator = points->Begin(); + unsigned id = 0; + while( pointIterator != points->End() ) { + for (unsigned d = 0; d < GetDimensions(); d++) { + unsigned idx = this->MapPointIdToInternalIdx(id, d); + sample[idx] = pointIterator.Value()[d]; + } + ++pointIterator; + ++id; + } + return sample; +} + + + +template +typename StandardMeshRepresenter::DatasetPointerType +StandardMeshRepresenter::SampleVectorToSample(const statismo::VectorType& sample) const { + typename MeshType::Pointer mesh = this->CloneDataset(m_reference); + typename PointsContainerType::Pointer points = mesh->GetPoints(); + typename PointsContainerType::Iterator pointsIterator = points->Begin(); + + unsigned ptId = 0; + while( pointsIterator != points->End() ) { + ValueType v; + for (unsigned d = 0; d < GetDimensions(); d++) { + unsigned idx = this->MapPointIdToInternalIdx(ptId, d); + v[d] = sample[idx]; + } + mesh->SetPoint(ptId, v); + + ++ptId; + ++pointsIterator; + } + return mesh; +} + +template +typename StandardMeshRepresenter::ValueType +StandardMeshRepresenter::PointSampleFromSample(DatasetConstPointerType sample, unsigned ptid) const { + if (ptid >= sample->GetNumberOfPoints()) { + throw statismo::StatisticalModelException("invalid ptid provided to PointSampleFromSample"); + } + + return sample->GetPoint(ptid); +} + + +template +typename StandardMeshRepresenter::ValueType +StandardMeshRepresenter::PointSampleVectorToPointSample(const statismo::VectorType& pointSample) const { + ValueType value; + for (unsigned d = 0; d < GetDimensions(); d++) { + value[d] = pointSample[d]; + } + return value; +} +template +statismo::VectorType +StandardMeshRepresenter::PointSampleToPointSampleVector(const ValueType& v) const { + statismo::VectorType vec(GetDimensions()); + for (unsigned d = 0; d < GetDimensions(); d++) { + vec[d] = v[d]; + } + return vec; +} + + +template +void +StandardMeshRepresenter::Save(const H5::Group& fg) const { + using namespace H5; + + statismo::MatrixType vertexMat = statismo::MatrixType::Zero(3, m_reference->GetNumberOfPoints()); + + for (unsigned i = 0; i < m_reference->GetNumberOfPoints(); i++) { + typename MeshType::PointType pt = m_reference->GetPoint(i); + for (unsigned d = 0; d < 3; d++) { + vertexMat(d, i) = pt[d]; + } + } + statismo::HDF5Utils::writeMatrix(fg, "./points", vertexMat); + + H5::Group pdGroup = fg.createGroup("pointData"); + + typename MeshType::PointDataContainerConstPointer pd = m_reference->GetPointData(); + if (pd.IsNotNull() && pd->Size() == m_reference->GetNumberOfPoints()) { + unsigned numComponents = PixelConversionTrait::ToVector(pd->GetElement(0)).rows(); + + statismo::MatrixType scalarsMat = statismo::MatrixType::Zero(numComponents, m_reference->GetNumberOfPoints()); + for (unsigned i = 0; i < m_reference->GetNumberOfPoints(); i++) { + scalarsMat.col(i) = PixelConversionTrait::ToVector(pd->GetElement(i)); + } + statismo::MatrixTypeDoublePrecision scalarsMatDouble = scalarsMat.cast(); + H5::DataSet ds = statismo::HDF5Utils::writeMatrixOfType(pdGroup, "scalars", scalarsMatDouble); + statismo::HDF5Utils::writeIntAttribute(ds, "datatype", PixelConversionTrait::GetDataType()); + } + + if(this->m_reference->GetNumberOfCells()) { + // check the dimensionality of a face (i.e. the number of points it has). We assume that + // all the cells are the same. + unsigned numPointsPerCell = 0; + if (m_reference->GetNumberOfCells() > 0) { + typename MeshType::CellAutoPointer cellPtr; + m_reference->GetCell(0, cellPtr); + numPointsPerCell = cellPtr->GetNumberOfPoints(); + } + + typedef typename statismo::GenericEigenType::MatrixType UIntMatrixType; + UIntMatrixType facesMat = UIntMatrixType::Zero(numPointsPerCell, m_reference->GetNumberOfCells()); + + + for (unsigned i = 0; i < m_reference->GetNumberOfCells(); i++) { + typename MeshType::CellAutoPointer cellPtr; + m_reference->GetCell(i, cellPtr); + assert(numPointsPerCell == cellPtr->GetNumberOfPoints()); + for (unsigned d = 0; d < numPointsPerCell; d++) { + facesMat(d, i) = cellPtr->GetPointIds()[d]; + } + } + + statismo::HDF5Utils::writeMatrixOfType(fg, "./cells", facesMat); + } +} + + +template +unsigned +StandardMeshRepresenter::GetNumberOfPoints() const { + return this->m_reference->GetNumberOfPoints(); +} + + +template +unsigned +StandardMeshRepresenter::GetPointIdForPoint(const PointType& pt) const { + int ptId = -1; + + // check whether the point is cached, otherwise look for it + typename PointCacheType::const_iterator got = m_pointCache.find (pt); + if (got == m_pointCache.end()) { + ptId = FindClosestPoint(m_reference, pt); + m_pointCache.insert(std::pair(pt, ptId)); + } else { + ptId = got->second; + } + assert(ptId != -1); + return static_cast(ptId); +} + + + +template +typename StandardMeshRepresenter::DatasetPointerType +StandardMeshRepresenter::CloneDataset(DatasetConstPointerType mesh) const { + + // cloning is cumbersome - therefore we let itk do the job for, and use perform a + // Mesh transform using the identity transform. This should result in a perfect clone. + + typedef itk::IdentityTransform IdentityTransformType; + typedef itk::TransformMeshFilter TransformMeshFilterType; + + typename TransformMeshFilterType::Pointer tf = TransformMeshFilterType::New(); + tf->SetInput(mesh); + typename IdentityTransformType::Pointer idTrans = IdentityTransformType::New(); + tf->SetTransform(idTrans); + tf->Update(); + + typename MeshType::Pointer clone = tf->GetOutput(); + clone->DisconnectPipeline(); + return clone; +} + +template +unsigned +StandardMeshRepresenter::FindClosestPoint(const MeshType* mesh, const PointType pt) const { + throw statismo::StatisticalModelException("Not implemented. Currently only points of the reference can be used."); +} + +} // namespace itk + +#endif diff --git a/Components/Metrics/ActiveRegistrationModel/Statismo/ITK/include/itkStatismoIO.h b/Components/Metrics/ActiveRegistrationModel/Statismo/ITK/include/itkStatismoIO.h new file mode 100644 index 000000000..285213b29 --- /dev/null +++ b/Components/Metrics/ActiveRegistrationModel/Statismo/ITK/include/itkStatismoIO.h @@ -0,0 +1,100 @@ +/* + * This file is part of the statismo library. + * + * Author: Marcel Luethi (marcel.luethi@unibas.ch) + * + * Copyright (c) 2011 University of Basel + * All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions + * are met: + * + * Redistributions of source code must retain the above copyright notice, + * this list of conditions and the following disclaimer. + * + * Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * + * Neither the name of the project's author nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS + * FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + * HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED + * TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR + * PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF + * LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING + * NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + */ + + +#ifndef ITKSTATISMOIO_H_ +#define ITKSTATISMOIO_H_ + +#include "itkStatisticalModel.h" +#include "StatismoIO.h" + + +namespace itk { + +template +class StatismoIO { + private: + //These tyedefs are only used internally and as such are marked as private + typedef statismo::StatisticalModel StatisticalModelType; + typedef itk::StatisticalModel ITKStatisticalModelType; + typedef typename ITKStatisticalModelType::Pointer ITKStatisticalModelTypePointer; + + public: + static ITKStatisticalModelTypePointer LoadStatisticalModel( + typename StatisticalModelType::RepresenterType *representer, const std::string &filename, + unsigned maxNumberOfPCAComponents = std::numeric_limits::max()) { + try { + ITKStatisticalModelTypePointer pModel = ITKStatisticalModelType::New(); + pModel->SetstatismoImplObj(statismo::IO::LoadStatisticalModel(representer, filename, maxNumberOfPCAComponents)); + return pModel; + } catch (const statismo::StatisticalModelException& e) { + itkGenericExceptionMacro(<< e.what()); + } + } + + static ITKStatisticalModelTypePointer LoadStatisticalModel( + typename ITKStatisticalModelType::RepresenterType *representer, const H5::Group &modelRoot, + unsigned maxNumberOfPCAComponents = std::numeric_limits::max()) { + try { + ITKStatisticalModelTypePointer pModel = ITKStatisticalModelType::New(); + pModel->SetstatismoImplObj(statismo::IO::LoadStatisticalModel(representer, modelRoot, maxNumberOfPCAComponents)); + return pModel; + } catch (const statismo::StatisticalModelException& e) { + itkGenericExceptionMacro(<< e.what()); + } + } + + static void SaveStatisticalModel(const ITKStatisticalModelType *const model, const std::string &filename) { + try { + statismo::IO::SaveStatisticalModel(model->GetstatismoImplObj(), filename); + } catch (const statismo::StatisticalModelException& e) { + itkGenericExceptionMacro(<< e.what()); + } + } + + static void SaveStatisticalModel(const ITKStatisticalModelType *model, const H5::Group &modelRoot) { + try { + statismo::IO::SaveStatisticalModel(model->GetstatismoImplObj(), modelRoot); + } catch (const statismo::StatisticalModelException& e) { + itkGenericExceptionMacro(<< e.what()); + } + } +}; + +} // namespace itk + +#endif /* ITKSTATISMOIO_H_ */ diff --git a/Components/Metrics/ActiveRegistrationModel/Statismo/ITK/include/itkStatisticalDeformationModelTransform.h b/Components/Metrics/ActiveRegistrationModel/Statismo/ITK/include/itkStatisticalDeformationModelTransform.h new file mode 100644 index 000000000..3145d23ea --- /dev/null +++ b/Components/Metrics/ActiveRegistrationModel/Statismo/ITK/include/itkStatisticalDeformationModelTransform.h @@ -0,0 +1,128 @@ +/* + * This file is part of the statismo library. + * + * Author: Marcel Luethi (marcel.luethi@unibas.ch) + * + * Copyright (c) 2011 University of Basel + * All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions + * are met: + * + * Redistributions of source code must retain the above copyright notice, + * this list of conditions and the following disclaimer. + * + * Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * + * Neither the name of the project's author nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS + * FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + * HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED + * TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR + * PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF + * LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING + * NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + */ + +#ifndef __ItkStatisticalDeformationModelTransform +#define __ItkStatisticalDeformationModelTransform + +#include + +#include +#include + +#include "itkStandardImageRepresenter.h" +#include "itkStatisticalModel.h" +#include "itkStatisticalModelTransformBase.h" + +namespace itk { + +/** + * + * \brief An itk transform that allows for deformations defined by a given Statistical Deformation Model. + * +* + * \ingroup Transforms + */ +template +class ITK_EXPORT StatisticalDeformationModelTransform : + public itk::StatisticalModelTransformBase< TDataSet, TScalarType , TDimension> { + + public: + /* Standard class typedefs. */ + typedef StatisticalDeformationModelTransform Self; + typedef itk::StatisticalModelTransformBase< TDataSet, TScalarType , TDimension> Superclass; + typedef SmartPointer Pointer; + typedef SmartPointer ConstPointer; + + itkSimpleNewMacro( Self ); + + /** Run-time type information (and related methods). */ + itkTypeMacro(StatisticalDeformationModelTransform, Superclass); + + typedef typename Superclass::InputPointType InputPointType; + typedef typename Superclass::OutputPointType OutputPointType; + typedef typename Superclass::RepresenterType RepresenterType; + + /** + * Clone the current transform + */ + virtual ::itk::LightObject::Pointer CreateAnother() const { + ::itk::LightObject::Pointer smartPtr; + Pointer another = Self::New().GetPointer(); + this->CopyBaseMembers(another); + + smartPtr = static_cast(another); + return smartPtr; + } + + + /** + * Transform a given point according to the deformation induced by the StatisticalModel, + * given the current parameters. + * + * \param pt The point to tranform + * \return The transformed point + */ + virtual OutputPointType TransformPoint(const InputPointType &pt) const { + typename RepresenterType::ValueType d; + try { + d = this->m_StatisticalModel->DrawSampleAtPoint(this->m_coeff_vector, pt); + } catch (ExceptionObject &e) { + std::cout << "exception occured at point " << pt << std::endl; + std::cout << "message " << e.what() << std::endl; + } + OutputPointType transformedPoint; + for (unsigned i = 0; i < pt.GetPointDimension(); i++) { + transformedPoint[i] = pt[i] + d[i]; + } + return transformedPoint; + } + + virtual ~StatisticalDeformationModelTransform() {} + + StatisticalDeformationModelTransform() {} + + private: + + + StatisticalDeformationModelTransform(const StatisticalDeformationModelTransform& orig); // purposely not implemented + StatisticalDeformationModelTransform& operator=(const StatisticalDeformationModelTransform& rhs); //purposely not implemented +}; + + +} // namespace itk + +#endif // __ItkStatisticalDeformationModelTransform diff --git a/Components/Metrics/ActiveRegistrationModel/Statismo/ITK/include/itkStatisticalModel.h b/Components/Metrics/ActiveRegistrationModel/Statismo/ITK/include/itkStatisticalModel.h new file mode 100644 index 000000000..398df6033 --- /dev/null +++ b/Components/Metrics/ActiveRegistrationModel/Statismo/ITK/include/itkStatisticalModel.h @@ -0,0 +1,287 @@ +/* + * This file is part of the statismo library. + * + * Author: Marcel Luethi (marcel.luethi@unibas.ch) + * + * Copyright (c) 2011 University of Basel + * All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions + * are met: + * + * Redistributions of source code must retain the above copyright notice, + * this list of conditions and the following disclaimer. + * + * Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * + * Neither the name of the project's author nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS + * FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + * HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED + * TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR + * PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF + * LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING + * NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + */ + +#ifndef ITKSTATISTICALMODEL_H_ +#define ITKSTATISTICALMODEL_H_ + +#include +#include + +#include +#include + +#include +#include + +#include "ModelInfo.h" +#include "Representer.h" +#include "StatisticalModel.h" +#include "statismoITKConfig.h" + +namespace itk { + +/** + * \brief ITK Wrapper for the statismo::StatisticalModel class. + * \see statismo::StatisticalModel for detailed documentation. + */ +template +class StatisticalModel : public Object { + public: + + typedef StatisticalModel Self; + typedef Object Superclass; + typedef SmartPointer Pointer; + typedef SmartPointer ConstPointer; + + itkNewMacro( Self ); + itkTypeMacro( StatisticalModel, Object ); + + typedef statismo::Representer RepresenterType; + + // statismo stuff + typedef statismo::StatisticalModel ImplType; + + typedef typename statismo::DataManager::DataItemType DataItemType; + + typedef vnl_matrix MatrixType; + typedef vnl_vector VectorType; + + + template + typename boost::result_of::type callstatismoImpl(F f) const { + try { + return f(); + } catch (statismo::StatisticalModelException& s) { + itkExceptionMacro(<< s.what()); + } + } + + void SetstatismoImplObj(ImplType* impl) { + if (m_impl) { + delete m_impl; + } + m_impl = impl; + } + + ImplType* GetstatismoImplObj() const { + return m_impl; + } + + StatisticalModel() : m_impl(0) {} + + virtual ~StatisticalModel() { + if (m_impl) { + delete m_impl; + } + } + + + typedef typename RepresenterType::DatasetPointerType DatasetPointerType; + typedef typename RepresenterType::DatasetConstPointerType DatasetConstPointerType; + + typedef typename RepresenterType::ValueType ValueType; + typedef typename RepresenterType::PointType PointType; + + typedef typename statismo::StatisticalModel::PointValuePairType PointValuePairType; + typedef typename statismo::StatisticalModel::PointValueListType PointValueListType; + + typedef typename statismo::StatisticalModel::PointCovarianceMatrixType PointCovarianceMatrixType; + typedef typename statismo::StatisticalModel::PointValueWithCovariancePairType PointValueWithCovariancePairType; + typedef typename statismo::StatisticalModel::PointValueWithCovarianceListType PointValueWithCovarianceListType; + + typedef typename statismo::StatisticalModel::DomainType DomainType; + + const RepresenterType* GetRepresenter() const { + return callstatismoImpl(boost::bind(&ImplType::GetRepresenter, this->m_impl)); + } + + const DomainType& GetDomain() const { + return callstatismoImpl(boost::bind(&ImplType::GetDomain, this->m_impl)); + } + + DatasetPointerType DrawMean() const { + return callstatismoImpl(boost::bind(&ImplType::DrawMean, this->m_impl)); + } + + ValueType DrawMeanAtPoint(const PointType& pt) const { + typedef ValueType (ImplType::*functype)(const PointType&) const; + return callstatismoImpl(boost::bind(static_cast(&ImplType::DrawMeanAtPoint), this->m_impl, pt)); + } + + ValueType DrawMeanAtPoint(unsigned ptid) const { + typedef ValueType (ImplType::*functype)(unsigned) const; + return callstatismoImpl(boost::bind(static_cast(&ImplType::DrawMeanAtPoint), this->m_impl, ptid)); + } + + DatasetPointerType DrawSample(const VectorType& coeffs, bool addNoise = false) const { + typedef DatasetPointerType (ImplType::*functype)(const statismo::VectorType&, bool) const; + return callstatismoImpl(boost::bind(static_cast(&ImplType::DrawSample), this->m_impl, fromVnlVector(coeffs), addNoise)); + } + + DatasetPointerType DrawSample(bool addNoise = false) const { + typedef DatasetPointerType (ImplType::*functype)(bool) const; + return callstatismoImpl(boost::bind(static_cast(&ImplType::DrawSample), this->m_impl, addNoise)); + } + + DatasetPointerType DrawPCABasisSample(unsigned componentNumber) const { + typedef DatasetPointerType (ImplType::*functype)(unsigned) const; + return callstatismoImpl(boost::bind(static_cast(&ImplType::DrawPCABasisSample), this->m_impl, componentNumber)); + } + + ValueType DrawSampleAtPoint(const VectorType& coeffs, const PointType& pt, bool addNoise = false) const { + typedef ValueType (ImplType::*functype)(const statismo::VectorType&, const PointType&, bool) const; + return callstatismoImpl(boost::bind(static_cast(&ImplType::DrawSampleAtPoint), this->m_impl, fromVnlVector(coeffs), pt, addNoise)); + } + + ValueType DrawSampleAtPoint(const VectorType& coeffs, unsigned ptid, bool addNoise = false) const { + typedef ValueType (ImplType::*functype)(const statismo::VectorType&, unsigned, bool) const; + return callstatismoImpl(boost::bind(static_cast(&ImplType::DrawSampleAtPoint), this->m_impl, fromVnlVector(coeffs), ptid, addNoise)); + } + + VectorType ComputeCoefficients(DatasetConstPointerType ds) const { + return toVnlVector(callstatismoImpl(boost::bind(&ImplType::ComputeCoefficients, this->m_impl, ds))); + } + + double ComputeLogProbability(DatasetConstPointerType ds) const { + return callstatismoImpl(boost::bind(&ImplType::ComputeLogProbability, this->m_impl, ds)); + } + + double ComputeProbability(DatasetConstPointerType ds) const { + return callstatismoImpl(boost::bind(&ImplType::ComputeProbability, this->m_impl, ds)); + } + + double ComputeLogProbabilityOfCoefficients(const VectorType& coeffs) const { + return callstatismoImpl(boost::bind(&ImplType::ComputeLogProbabilityOfCoefficients, this->m_impl, fromVnlVector(coeffs))); + } + + double ComputeProbabilityOfCoefficients(const VectorType& coeffs) const { + return callstatismoImpl(boost::bind(&ImplType::ComputeProbabilityOfCoefficients, this->m_impl, fromVnlVector(coeffs))); + } + + double ComputeMahalanobisDistance(DatasetConstPointerType ds) const { + return callstatismoImpl(boost::bind(&ImplType::ComputeMahalanobisDistance, this->m_impl, ds)); + } + + VectorType ComputeCoefficientsForPointValues(const PointValueListType& pvlist, double variance) const { + typedef statismo::VectorType (ImplType::*functype)(const PointValueListType&, double) const; + return toVnlVector(callstatismoImpl(boost::bind(static_cast(&ImplType::ComputeCoefficientsForPointValues), this->m_impl, pvlist, variance))); + } + + VectorType ComputeCoefficientsForPointValuesWithCovariance(const PointValueWithCovarianceListType& pvclist) const { + typedef statismo::VectorType(ImplType::*functype)(const PointValueWithCovarianceListType&) const; + return toVnlVector(callstatismoImpl(boost::bind(static_cast(&ImplType::ComputeCoefficientsForPointValuesWithCovariance), this->m_impl, pvclist))); + } + + unsigned GetNumberOfPrincipalComponents() const { + return callstatismoImpl(boost::bind(&ImplType::GetNumberOfPrincipalComponents, this->m_impl)); + } + + float GetNoiseVariance() const { + return callstatismoImpl(boost::bind(&ImplType::GetNoiseVariance, this->m_impl)); + } + + MatrixType GetCovarianceAtPoint(const PointType& pt1, const PointType& pt2) const { + typedef statismo::MatrixType (ImplType::*functype)(const PointType&, const PointType&) const; + return toVnlMatrix(callstatismoImpl(boost::bind(static_cast(&ImplType::GetCovarianceAtPoint), this->m_impl, pt1, pt2))); + } + + MatrixType GetCovarianceAtPoint(unsigned ptid1, unsigned ptid2) const { + typedef statismo::MatrixType (ImplType::*functype)(unsigned, unsigned ) const; + return toVnlMatrix(callstatismoImpl(boost::bind(static_cast(&ImplType::GetCovarianceAtPoint),this->m_impl, ptid1, ptid2))); + } + + MatrixType GetJacobian(const PointType& pt) const { + typedef statismo::MatrixType (ImplType::*functype)(const PointType&) const; + return toVnlMatrix(callstatismoImpl(boost::bind(static_cast(&ImplType::GetJacobian), this->m_impl, pt))); + } + + MatrixType GetJacobian(unsigned ptId) const { + typedef statismo::MatrixType (ImplType::*functype)(unsigned) const; + return toVnlMatrix(callstatismoImpl(boost::bind(static_cast(&ImplType::GetJacobian), this->m_impl, ptId))); + } + + MatrixType GetPCABasisMatrix() const { + return toVnlMatrix(callstatismoImpl(boost::bind(&ImplType::GetPCABasisMatrix, this->m_impl))); + } + + MatrixType GetInverseCovarianceMatrix() const { + return toVnlMatrix(callstatismoImpl(boost::bind(&ImplType::GetInverseCovarianceMatrix, this->m_impl))); + } + + MatrixType GetOrthonormalPCABasisMatrix() const { + return toVnlMatrix(callstatismoImpl(boost::bind(&ImplType::GetOrthonormalPCABasisMatrix, this->m_impl))); + } + + VectorType GetPCAVarianceVector() const { + return toVnlVector(callstatismoImpl(boost::bind(&ImplType::GetPCAVarianceVector, this->m_impl))); + } + + VectorType GetMeanVector() const { + return toVnlVector(callstatismoImpl(boost::bind(&ImplType::GetMeanVector, this->m_impl))); + } + + const statismo::ModelInfo& GetModelInfo() const { + return callstatismoImpl(boost::bind(&ImplType::GetModelInfo, this->m_impl)); + } + + private: + + static MatrixType toVnlMatrix(const statismo::MatrixType& M) { + return MatrixType(M.data(), M.rows(), M.cols()); + + } + + static VectorType toVnlVector(const statismo::VectorType& v) { + return VectorType(v.data(), v.rows()); + + } + + static statismo::VectorType fromVnlVector(const VectorType& v) { + return Eigen::Map(v.data_block(), v.size()); + + } + + StatisticalModel(const StatisticalModel& orig); + StatisticalModel& operator=(const StatisticalModel& rhs); + + ImplType* m_impl; +}; + + +} + +#endif /* ITKSTATISTICALMODEL_H_ */ diff --git a/Components/Metrics/ActiveRegistrationModel/Statismo/ITK/include/itkStatisticalModelTransformBase.h b/Components/Metrics/ActiveRegistrationModel/Statismo/ITK/include/itkStatisticalModelTransformBase.h new file mode 100644 index 000000000..660c47d1a --- /dev/null +++ b/Components/Metrics/ActiveRegistrationModel/Statismo/ITK/include/itkStatisticalModelTransformBase.h @@ -0,0 +1,229 @@ +/* + * This file is part of the statismo library. + * + * Author: Marcel Luethi (marcel.luethi@unibas.ch) + * + * Copyright (c) 2011 University of Basel + * All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions + * are met: + * + * Redistributions of source code must retain the above copyright notice, + * this list of conditions and the following disclaimer. + * + * Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * + * Neither the name of the project's author nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS + * FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + * HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED + * TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR + * PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF + * LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING + * NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + */ + +#ifndef __itkStatisticalModelTransform_h +#define __itkStatisticalModelTransform_h + +#include + +#include +#include +#include + +#include "Representer.h" + +#include "itkStatisticalModel.h" + +namespace itk { + +/** + * + * \brief Base class that implements an itk transform interface for statistical models. + * + * Statistical models (shape or deformation models) are often used to model the typical variations within + * an object class. The StatisticalModelTransformBase implements the standard Transform interface, and thus allows + * for the use of statistical models within the ITK registration framework. + * Subclasses will need to implement the TransformPoint method, as its semantics depends on the type of statistical model. + * + * \ingroup Transforms + */ + +template +class ITK_EXPORT StatisticalModelTransformBase : + public itk::Transform< TScalarType , TInputDimension, TOutputDimension> { + public: + /* Standard class typedefs. */ + typedef StatisticalModelTransformBase Self; + typedef itk::Transform< TScalarType , TInputDimension, TOutputDimension> Superclass; + typedef SmartPointer Pointer; + typedef SmartPointer ConstPointer; + + typedef vnl_vector VectorType; + typedef vnl_matrix MatrixType; + + + /** + * Copy the members of the current transform. To be used by subclasses in the CreateAnother method. + */ + virtual void CopyBaseMembers(StatisticalModelTransformBase* another) const { + another->m_StatisticalModel = m_StatisticalModel; + another->m_coeff_vector = m_coeff_vector; + another->m_usedNumberCoefficients = m_usedNumberCoefficients; + another->m_FixedParameters = m_FixedParameters; + another->m_Parameters = this->m_Parameters; + } + + + /** Run-time type information (and related methods). */ + itkTypeMacro( StatisticalModelTransformBase, Superclass ); + + /* Dimension of parameters. */ + itkStaticConstMacro(SpaceDimension, unsigned int, TInputDimension); + itkStaticConstMacro(InputSpaceDimension, unsigned int, TInputDimension); + itkStaticConstMacro(OutputSpaceDimension, unsigned int, TOutputDimension); + + + /* Parameters Type */ + typedef typename Superclass::ParametersType ParametersType; + typedef typename Superclass::JacobianType JacobianType; + typedef typename Superclass::ScalarType ScalarType; + typedef typename Superclass::InputPointType InputPointType; + typedef typename Superclass::OutputPointType OutputPointType; + typedef typename Superclass::InputVectorType InputVectorType; + typedef typename Superclass::OutputVectorType OutputVectorType; + typedef typename Superclass::InputVnlVectorType InputVnlVectorType; + typedef typename Superclass::OutputVnlVectorType OutputVnlVectorType; + typedef typename Superclass::InputCovariantVectorType + InputCovariantVectorType; + typedef typename Superclass::OutputCovariantVectorType + OutputCovariantVectorType; + + typedef statismo::Representer RepresenterType; + typedef itk::StatisticalModel StatisticalModelType; + + + virtual void ComputeJacobianWithRespectToParameters(const InputPointType &pt, JacobianType & jacobian) const; + + /** + * Transform a given point according to the deformation induced by the StatisticalModel, + * given the current parameters. + * + * \param pt The point to tranform + * \return The transformed point + */ + virtual OutputPointType TransformPoint(const InputPointType &pt) const = 0; + + /** + * Set the parameters to the IdentityTransform + * */ + virtual void SetIdentity(void); + + /** + * Set the parameters of the transform + */ + virtual void SetParameters( const ParametersType & ); + + /** + * Get the parameters of the transform + */ + virtual const ParametersType& GetParameters(void) const; + + /** + * Does nothing - as the transform does not have any fixed parameters + */ + virtual void SetFixedParameters( const ParametersType &p ) { + // there no fixed parameters + + } + + /** + * returns an empty Parameter vector, as the tranform does not have any fixed parameters + */ + virtual const ParametersType& GetFixedParameters(void) const { + return this->m_FixedParameters; + }; + + /** + * Convenience method to obtain the current coefficients of the StatisticalModel as a statismo::VectorType. + * The resulting vector is the same as it would be obtained from GetParameters. + */ + virtual VectorType GetCoefficients() const { + return m_coeff_vector; + } + + /** + * Convenicne method to set the coefficients of the underlying StatisticalModel from a statismo::VectorType. + * This has the same effect as calling SetParameters. + */ + virtual void SetCoefficients( VectorType& coefficients) { + m_coeff_vector = coefficients; + } + + /** + * Set the statistical model that defines the valid transformations. + */ + void SetStatisticalModel(const StatisticalModelType* model); + + /** + * Returns the statistical model used. + */ + typename StatisticalModelType::ConstPointer GetStatisticalModel() const; + + /** + * Set the number of PCA Coefficients used by the model. This parameters has a + * regularization effect. Setting it to a small value will restrict the possible tranformations + * to the main modes of variations. + */ + void SetUsedNumberOfCoefficients(unsigned n) { + m_usedNumberCoefficients = n; + } + + /** + * returns the number of used model coefficients. + */ + unsigned GetUsedNumberOfCoefficients() { + return m_usedNumberCoefficients; + } + + protected: + + StatisticalModelTransformBase(); + virtual ~StatisticalModelTransformBase() {}; + + void PrintSelf(std::ostream &os, Indent indent) const; + + typename StatisticalModelType::ConstPointer m_StatisticalModel; + VectorType m_coeff_vector; + unsigned m_usedNumberCoefficients; + ParametersType m_FixedParameters; + + StatisticalModelTransformBase(const Self& obj);// : Superclass(obj) {} //purposely not implemented + void operator=(const Self& rhs);// { return Superclass::operator=(rhs); } //purposely not implemented + + + +}; + + +} // namespace itk + + +#ifndef ITK_MANUAL_INSTANTIATION +# include "itkStatisticalModelTransformBase.hxx" +#endif + +#endif /* __itkStatisticalModelTransform_h */ diff --git a/Components/Metrics/ActiveRegistrationModel/Statismo/ITK/include/itkStatisticalModelTransformBase.hxx b/Components/Metrics/ActiveRegistrationModel/Statismo/ITK/include/itkStatisticalModelTransformBase.hxx new file mode 100644 index 000000000..5a1fc7082 --- /dev/null +++ b/Components/Metrics/ActiveRegistrationModel/Statismo/ITK/include/itkStatisticalModelTransformBase.hxx @@ -0,0 +1,186 @@ +/* + * This file is part of the statismo library. + * + * Author: Marcel Luethi (marcel.luethi@unibas.ch) + * + * Copyright (c) 2011 University of Basel + * All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions + * are met: + * + * Redistributions of source code must retain the above copyright notice, + * this list of conditions and the following disclaimer. + * + * Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * + * Neither the name of the project's author nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS + * FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + * HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED + * TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR + * PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF + * LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING + * NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + */ + +#ifndef __itkStatisticalModelTransformBase_hxx +#define __itkStatisticalModelTransformBase_hxx + +#include "itkStatisticalModelTransformBase.h" + + +namespace itk { + +/*! + * Constructor with default arguments. + */ +template +StatisticalModelTransformBase +::StatisticalModelTransformBase() : + Superclass(0), // we don't know the number of parameters at this point. + m_StatisticalModel(0), + m_coeff_vector(0), + m_usedNumberCoefficients(10000) { // something large + itkDebugMacro( << "Constructor MorphableModelTransform()"); + + this->m_FixedParameters.SetSize(0); +} + + + +/*! + * Set the morphable model and ajust the parameters dimension. + */ +template +void +StatisticalModelTransformBase +::SetStatisticalModel(const StatisticalModelType* model) { + itkDebugMacro( << "Setting statistical model "); + m_StatisticalModel = model; + + + this->m_Parameters.SetSize(model->GetNumberOfPrincipalComponents()); + this->m_Parameters.Fill(0.0); + + this->m_coeff_vector.set_size(model->GetNumberOfPrincipalComponents()); + +} + +template +typename StatisticalModelTransformBase::StatisticalModelType::ConstPointer +StatisticalModelTransformBase +::GetStatisticalModel() const { + itkDebugMacro( << "Getting statistical model "); + return m_StatisticalModel; +} + + +/*! + * Set the parameters to the IdentityTransform. + */ +template +void +StatisticalModelTransformBase +::SetIdentity( ) { + itkDebugMacro( << "Setting Identity"); + + for (unsigned i = 0; i < this->GetNumberOfParameters(); i++) + this->m_coeff_vector[i] = 0; + + + this->Modified(); +} + +template +void +StatisticalModelTransformBase +::SetParameters( const ParametersType & parameters ) { + itkDebugMacro( << "Setting parameters " << parameters ); + + // Set angle + for(unsigned int i=0; i < std::min(m_usedNumberCoefficients, (unsigned) this->GetNumberOfParameters()); i++) { + m_coeff_vector[i] = parameters[i]; + } + for (unsigned int i = std::min(m_usedNumberCoefficients, (unsigned) this->GetNumberOfParameters()); i < this->GetNumberOfParameters(); i++) { + m_coeff_vector[i] = 0; + } + + // Modified is always called since we just have a pointer to the + // parameters and cannot know if the parameters have changed. + this->Modified(); + + itkDebugMacro(<<"After setting parameters "); +} + + + + + +// Get Parameters +template +const typename StatisticalModelTransformBase::ParametersType & +StatisticalModelTransformBase +::GetParameters( void ) const { + itkDebugMacro( << "Getting parameters "); + + + // Get the translation + for(unsigned int i=0; i < this->GetNumberOfParameters(); i++) { + this->m_Parameters[i] = this->m_coeff_vector[i]; + } + itkDebugMacro(<<"After getting parameters " << this->m_Parameters ); + + return this->m_Parameters; +} + + + + + + +template +void +StatisticalModelTransformBase +::ComputeJacobianWithRespectToParameters(const InputPointType &pt, JacobianType &jacobian) const { + jacobian.SetSize(OutputSpaceDimension, m_StatisticalModel->GetNumberOfPrincipalComponents()); + jacobian.Fill(0); + + const MatrixType& statModelJacobian = m_StatisticalModel->GetJacobian(pt); + + for (unsigned i = 0; i < statModelJacobian.rows(); i++) { + for (unsigned j = 0; j < std::min(m_usedNumberCoefficients, (unsigned) this->GetNumberOfParameters()); j++) { + jacobian[i][j] = statModelJacobian[i][j]; + } + } + + + itkDebugMacro( << "Jacobian with MM:\n" << jacobian); + itkDebugMacro( << "After GetMorphableModelJacobian:" + << "\nJacobian = \n" << jacobian); +} + + + +// Print self +template +void +StatisticalModelTransformBase:: +PrintSelf(std::ostream &os, Indent indent) const { + Superclass::PrintSelf(os,indent); +} + +} // namespace + +#endif diff --git a/Components/Metrics/ActiveRegistrationModel/Statismo/ITK/include/itkStatisticalShapeModelTransform.h b/Components/Metrics/ActiveRegistrationModel/Statismo/ITK/include/itkStatisticalShapeModelTransform.h new file mode 100644 index 000000000..899fd1a55 --- /dev/null +++ b/Components/Metrics/ActiveRegistrationModel/Statismo/ITK/include/itkStatisticalShapeModelTransform.h @@ -0,0 +1,123 @@ +/* + * This file is part of the statismo library. + * + * Author: Marcel Luethi (marcel.luethi@unibas.ch) + * + * Copyright (c) 2011 University of Basel + * All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions + * are met: + * + * Redistributions of source code must retain the above copyright notice, + * this list of conditions and the following disclaimer. + * + * Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * + * Neither the name of the project's author nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS + * FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + * HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED + * TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR + * PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF + * LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING + * NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + */ + +#ifndef __ItkStatisticalShapeModelTransform +#define __ItkStatisticalShapeModelTransform + +#include + +#include +#include + +#include "itkStandardImageRepresenter.h" +#include "itkStatisticalModel.h" +#include "itkStatisticalModelTransformBase.h" + +namespace itk { + +/** + * + * \brief An itk transform that allows for deformations defined by a given Statistical Shape Model. + * +* + * \ingroup Transforms + */ +template +class ITK_EXPORT StatisticalShapeModelTransform : + public itk::StatisticalModelTransformBase< TRepresenter, TScalarType , TDimension> { + public: + /* Standard class typedefs. */ + typedef StatisticalShapeModelTransform Self; + typedef itk::StatisticalModelTransformBase< TRepresenter, TScalarType , TDimension> Superclass; + typedef SmartPointer Pointer; + typedef SmartPointer ConstPointer; + + + itkSimpleNewMacro( Self ); + + + /** Run-time type information (and related methods). */ + itkTypeMacro(StatisticalShapeModelTransform, Superclass); + + + typedef typename Superclass::InputPointType InputPointType; + typedef typename Superclass::OutputPointType OutputPointType; + typedef typename Superclass::RepresenterType RepresenterType; + + /** + * Clone the current transform + */ + virtual ::itk::LightObject::Pointer CreateAnother() const { + ::itk::LightObject::Pointer smartPtr; + Pointer another = Self::New().GetPointer(); + this->CopyBaseMembers(another); + + smartPtr = static_cast(another); + return smartPtr; + } + + + /** + * Transform a given point according to the deformation induced by the StatisticalModel, + * given the current parameters. + * + * \param pt The point to tranform + * \return The transformed point + */ + virtual OutputPointType TransformPoint(const InputPointType &pt) const { + typename RepresenterType::ValueType d; + try { + d = this->m_StatisticalModel->DrawSampleAtPoint(this->m_coeff_vector, pt); + } catch (ExceptionObject &e) { + std::cout << "exception occured at point " << pt << std::endl; + std::cout << "message " << e.what() << std::endl; + } + return d; + } + + StatisticalShapeModelTransform() {} + + private: + + StatisticalShapeModelTransform(const StatisticalShapeModelTransform& orig); // purposely not implemented + StatisticalShapeModelTransform& operator=(const StatisticalShapeModelTransform& rhs); //purposely not implemented +}; + + +} // namespace itk + +#endif // __ItkStatisticalShapeModelTransform diff --git a/Components/Metrics/ActiveRegistrationModel/Statismo/ITK/include/statismoITKConfig.h b/Components/Metrics/ActiveRegistrationModel/Statismo/ITK/include/statismoITKConfig.h new file mode 100644 index 000000000..0a719625b --- /dev/null +++ b/Components/Metrics/ActiveRegistrationModel/Statismo/ITK/include/statismoITKConfig.h @@ -0,0 +1,49 @@ +/* + * This file is part of the statismo library. + * + * Author: Marcel Luethi (marcel.luethi@unibas.ch) + * + * Copyright (c) 2011 University of Basel + * All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions + * are met: + * + * Redistributions of source code must retain the above copyright notice, + * this list of conditions and the following disclaimer. + * + * Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * + * Neither the name of the project's author nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS + * FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + * HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED + * TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR + * PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF + * LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING + * NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + */ + +#ifndef __STATISMO_ITK_CONFIG_H +#define __STATISMO_ITK_CONFIG_H + +#include "Config.h" + +// in case we are using itk, we are using the HDF5 version that comes with ITK +#include "itk_H5Cpp.h" + +// prevent standard HDF5 header from being included +#define _H5CPP_H +#define __H5Cpp_H +#endif // __STATISMO_ITK_CONFIG_H diff --git a/Components/Metrics/ActiveRegistrationModel/Statismo/core/CMakeLists.txt b/Components/Metrics/ActiveRegistrationModel/Statismo/core/CMakeLists.txt new file mode 100644 index 000000000..febd4f0ab --- /dev/null +++ b/Components/Metrics/ActiveRegistrationModel/Statismo/core/CMakeLists.txt @@ -0,0 +1 @@ +add_subdirectory(src) diff --git a/Components/Metrics/ActiveRegistrationModel/Statismo/core/include/CommonTypes.h b/Components/Metrics/ActiveRegistrationModel/Statismo/core/include/CommonTypes.h new file mode 100644 index 000000000..18f9fa9eb --- /dev/null +++ b/Components/Metrics/ActiveRegistrationModel/Statismo/core/include/CommonTypes.h @@ -0,0 +1,146 @@ +/* + * This file is part of the statismo library. + * + * Author: Marcel Luethi (marcel.luethi@unibas.ch) + * + * Copyright (c) 2011 University of Basel + * All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions + * are met: + * + * Redistributions of source code must retain the above copyright notice, + * this list of conditions and the following disclaimer. + * + * Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * + * Neither the name of the project's author nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS + * FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + * HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED + * TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR + * PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF + * LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING + * NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + */ + +#ifndef __COMMON_TYPES_H +#define __COMMON_TYPES_H + +#include +#include +#include +#include +#include + +#include + +#include "itk_eigen.h" +#include ITK_EIGEN(Dense) + +#include "Config.h" +#include "Domain.h" +#include "Exceptions.h" + +namespace statismo { + +const double PI = 3.14159265358979323846; + +/// the type that is used for all vector and matrices throughout the library. +typedef double ScalarType; + +// wrapper struct that allows us to easily select matrix and vectors of an arbitrary +// type, wich has the same traits as the standard matrix / vector traits +template struct GenericEigenType { + typedef Eigen::Matrix MatrixType; + typedef Eigen::DiagonalMatrix DiagMatrixType; + typedef Eigen::Matrix VectorType; + typedef Eigen::Matrix RowVectorType; + +}; +typedef GenericEigenType::MatrixType MatrixType; +typedef GenericEigenType::MatrixType MatrixTypeDoublePrecision; +typedef GenericEigenType::DiagMatrixType DiagMatrixType; +typedef GenericEigenType::VectorType VectorType; +typedef GenericEigenType::VectorType VectorTypeDoublePrecision; +typedef GenericEigenType::RowVectorType RowVectorType; + +// type definitions used in the standard file format. +// Note that these are the same as used by VTK +const static unsigned Void = 0; // not capitalized, as windows defines: #define VOID void, which causes trouble +const static unsigned SIGNED_CHAR = 2; +const static unsigned UNSIGNED_CHAR = 3; +const static unsigned SIGNED_SHORT = 4; +const static unsigned UNSIGNED_SHORT = 5; +const static unsigned SIGNED_INT = 6; +const static unsigned UNSIGNED_INT = 7; +const static unsigned SIGNED_LONG = 8; +const static unsigned UNSIGNED_LONG = 9; +const static unsigned FLOAT = 10; +const static unsigned DOUBLE = 11; + +template unsigned GetDataTypeId() { + throw StatisticalModelException("The datatype that was provided is not a valid statismo data type "); +} +template <> inline unsigned GetDataTypeId() { + return SIGNED_CHAR; +} +template <> inline unsigned GetDataTypeId() { + return UNSIGNED_CHAR; +} +template <> inline unsigned GetDataTypeId() { + return SIGNED_SHORT; +} +template <> inline unsigned GetDataTypeId() { + return UNSIGNED_SHORT; +} +template <> inline unsigned GetDataTypeId() { + return SIGNED_INT; +} +template <> inline unsigned GetDataTypeId() { + return UNSIGNED_INT; +} +template <> inline unsigned GetDataTypeId() { + return SIGNED_LONG; +} +template <> inline unsigned GetDataTypeId() { + return UNSIGNED_LONG; +} +template <> inline unsigned GetDataTypeId() { + return FLOAT; +} +template <> inline unsigned GetDataTypeId() { + return DOUBLE; +} + + + +} //namespace statismo + +// If we want to store a vector in a boost map, boost requires this function to be present. +// We define it here once and for all. +// Because of the way boost looksup the values, it needs to be defined in the namespace Eigen +namespace Eigen { +inline size_t hash_value(const statismo::VectorType& v) { + + size_t value = 0; + for (unsigned i = 0; i < v.size(); i++) { + boost::hash_combine(value, v(i)); + } + return value; +} +} + +#endif + diff --git a/Components/Metrics/ActiveRegistrationModel/Statismo/core/include/ConditionalModelBuilder.h b/Components/Metrics/ActiveRegistrationModel/Statismo/core/include/ConditionalModelBuilder.h new file mode 100644 index 000000000..3c0f0e9d0 --- /dev/null +++ b/Components/Metrics/ActiveRegistrationModel/Statismo/core/include/ConditionalModelBuilder.h @@ -0,0 +1,138 @@ +/* + * ConditionalModelBuilder.h + * + * Created by Remi Blanc, + * + * Copyright (c) 2011 ETH Zurich + * All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions + * are met: + * + * Redistributions of source code must retain the above copyright notice, + * this list of conditions and the following disclaimer. + * + * Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * + * Neither the name of the project's author nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS + * FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + * HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED + * TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR + * PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF + * LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING + * NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + */ + +#ifndef __CONDITIONALMODELBUILDER_H_ +#define __CONDITIONALMODELBUILDER_H_ + +#include +#include + +#include "CommonTypes.h" +#include "Config.h" +#include "DataManagerWithSurrogates.h" +#include "ModelBuilder.h" +#include "StatisticalModel.h" + +namespace statismo { + +/** + * \brief Creates a StatisticalModel conditioned on some external data + * + * The principle of this class is to exploit supplementary information (surrogate variables) describing + * the samples (e.g. the age and gender of the subject) to generate a conditional statistical model. + * This class assumes a joint multivariate gaussian distribution of the sample vector and the continuous surrogates + * Categorical surrogates are taken into account by selecting the subset of samples that fit in the requested categories. + * + * For mathematical details and illustrations, see the paper + * Conditional Variability of Statistical Shape Models Based on Surrogate Variables + * R. Blanc, M. Reyes, C. Seiler and G. Szekely, In Proc. MICCAI 2009 + * + * CAVEATS: + * - conditioning on too many categories may lead to small or empty training sets + * - using more surrogate variables than training samples may cause instabilities + * + * The class does not implement missing data functionalities. + * + * \sa DataManagerWithSurrogates + */ +template +class ConditionalModelBuilder : public ModelBuilder { + public: + + typedef ModelBuilder Superclass; + typedef typename Superclass::StatisticalModelType StatisticalModelType; + + typedef std::pair CondVariableValuePair; //replace the first element by a bool (indicates whether the variable is in use) + typedef std::vector CondVariableValueVectorType; //replace list by vector, to gain direct access + + typedef DataManagerWithSurrogates DataManagerType; + typedef typename DataManagerType::DataItemListType DataItemListType; + typedef typename DataManagerType::DataItemWithSurrogatesType DataItemWithSurrogatesType; + typedef typename DataManagerType::SurrogateTypeInfoType SurrogateTypeInfoType; + + /** + * Factory method to create a new ConditionalModelBuilder + * \param representer The representer + */ + static ConditionalModelBuilder* Create() { + return new ConditionalModelBuilder(); + } + + /** + * Destroy the object. + * The same effect can be achieved by deleting the object in the usual + * way using the c++ delete keyword. + */ + void Delete() { + delete this; + } + + /** + * Builds a new model from the provided data and the requested constraints. + * \param sampleSet A list training samples with associated surrogate data - typically obtained from a DataManagerWithSurrogates. + * \param surrogateTypes A vector with length corresponding to the number of surrogate variables, indicating whether a variable is continuous or categorical - typically obtained from a DataManagerWithSurrogates. + * \param conditioningInfo A vector (length = number of surrogates) indicating which surrogates are used for conditioning, and the conditioning value. + * \param noiseVariance The variance of the noise assumed on our data + * \return a new statistical model + * + * \warning The returned model needs to be explicitly deleted by the user of this method. + */ + StatisticalModelType* BuildNewModel(const DataItemListType& sampleSet, + const SurrogateTypeInfoType& surrogateTypesInfo, + const CondVariableValueVectorType& conditioningInfo, + float noiseVariance, + double modelVarianceRetained = 1) const; + + private: + + unsigned PrepareData(const DataItemListType& DataItemList, + const SurrogateTypeInfoType& surrogateTypesInfo, + const CondVariableValueVectorType& conditioningInfo, + DataItemListType* acceptedSamples, + MatrixType* surrogateMatrix, + VectorType* conditions) const; + + CondVariableValueVectorType m_conditioningInfo; //keep in storage +}; + + + +} // namespace statismo + +#include "ConditionalModelBuilder.hxx" + +#endif /* __PCAMODELBUILDER_H_ */ diff --git a/Components/Metrics/ActiveRegistrationModel/Statismo/core/include/ConditionalModelBuilder.hxx b/Components/Metrics/ActiveRegistrationModel/Statismo/core/include/ConditionalModelBuilder.hxx new file mode 100644 index 000000000..a5dad2bb6 --- /dev/null +++ b/Components/Metrics/ActiveRegistrationModel/Statismo/core/include/ConditionalModelBuilder.hxx @@ -0,0 +1,267 @@ +/* + * Representer.hxx + * + * Created by Remi Blanc, Marcel Luethi + * + * Copyright (c) 2011 ETH Zurich + * All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions + * are met: + * + * Redistributions of source code must retain the above copyright notice, + * this list of conditions and the following disclaimer. + * + * Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * + * Neither the name of the project's author nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS + * FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + * HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED + * TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR + * PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF + * LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING + * NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + */ + +#ifndef __ConditionalModelBuilder_hxx +#define __ConditionalModelBuilder_hxx + +#include "ConditionalModelBuilder.h" + +#include + +#include + +#include "Exceptions.h" +#include "PCAModelBuilder.h" + +namespace statismo { + +// +// ConditionalModelBuilder +// +// + + +template +unsigned +ConditionalModelBuilder::PrepareData(const DataItemListType& sampleDataList, + const SurrogateTypeInfoType& surrogateTypesInfo, + const CondVariableValueVectorType& conditioningInfo, + DataItemListType *acceptedSamples, + MatrixType *surrogateMatrix, + VectorType *conditions) const { + bool acceptSample; + unsigned nbAcceptedSamples = 0; + unsigned nbContinuousSurrogatesInUse = 0, nbCategoricalSurrogatesInUse = 0; + std::vector indicesContinuousSurrogatesInUse; + std::vector indicesCategoricalSurrogatesInUse; + + //first: identify the continuous and categorical variables, which are used for conditioning and which are not + for (unsigned i=0 ; iresize(nbContinuousSurrogatesInUse); + for (unsigned i=0 ; iresize(nbContinuousSurrogatesInUse, sampleDataList.size()); //number of variables is now known: nbContinuousSurrogatesInUse ; the number of samples is yet unknown... + + //now, browse all samples to select the ones which fall into the requested categories + for (typename DataItemListType::const_iterator it = sampleDataList.begin(); it != sampleDataList.end(); ++it) { + const DataItemWithSurrogatesType* sampleData = dynamic_cast(*it); + if (sampleData == 0) { + // this is a normal sample without surrogate information. + // we simply discard it + std::cout<<"WARNING: ConditionalModelBuilder, sample data "<< (*it)->GetDatasetURI()<<" has no surrogate data associated, and is ignored"<GetSurrogateVector(); + acceptSample = true; + for (unsigned i=0 ; ipush_back(*it); + //and fill in the matrix of continuous variables + for (unsigned j=0 ; jconservativeResize(Eigen::NoChange_t(), nbAcceptedSamples); + + return nbAcceptedSamples; +} + +template +typename ConditionalModelBuilder::StatisticalModelType* +ConditionalModelBuilder::BuildNewModel(const DataItemListType& sampleDataList, + const SurrogateTypeInfoType& surrogateTypesInfo, + const CondVariableValueVectorType& conditioningInfo, + float noiseVariance, + double modelVarianceRetained) const { + DataItemListType acceptedSamples; + MatrixType X; + VectorType x0; + unsigned nSamples = PrepareData(sampleDataList, surrogateTypesInfo, conditioningInfo, &acceptedSamples, &X, &x0); + assert(nSamples == acceptedSamples.size()); + + unsigned nCondVariables = X.rows(); + + // build a normal PCA model + typedef PCAModelBuilder PCAModelBuilderType; + PCAModelBuilderType* modelBuilder = PCAModelBuilderType::Create(); + StatisticalModelType* pcaModel = modelBuilder->BuildNewModel(acceptedSamples, noiseVariance); + + unsigned nPCAComponents = pcaModel->GetNumberOfPrincipalComponents(); + + if ( X.cols() == 0 || X.rows() == 0) { + return pcaModel; + } else { + // the scores in the pca model correspond to the parameters of each sample in the model. + MatrixType B = pcaModel->GetModelInfo().GetScoresMatrix().transpose(); + assert(B.rows() == nSamples); + assert(B.cols() == nPCAComponents); + + // A is the joint data matrix B, X, where X contains the conditional information for each sample + // Thus the i-th row of A contains the PCA parameters b of the i-th sample, + // together with the conditional information for each sample + MatrixType A(nSamples, nPCAComponents+nCondVariables); + A << B,X.transpose(); + + // Compute the mean and the covariance of the joint data matrix + VectorType mu = A.colwise().mean().transpose(); // colwise returns a row vector + assert(mu.rows() == nPCAComponents + nCondVariables); + + MatrixType A0 = A.rowwise() - mu.transpose(); // + MatrixType cov = 1.0 / (nSamples-1) * A0.transpose() * A0; + + assert(cov.rows() == cov.cols()); + assert(cov.rows() == pcaModel->GetNumberOfPrincipalComponents() + nCondVariables); + + // extract the submatrices involving the conditionals x + // note that since the matrix is symmetric, Sbx = Sxb.transpose(), hence we only store one + MatrixType Sbx = cov.topRightCorner(nPCAComponents, nCondVariables); + MatrixType Sxx = cov.bottomRightCorner(nCondVariables, nCondVariables); + MatrixType Sbb = cov.topLeftCorner(nPCAComponents, nPCAComponents); + + // compute the conditional mean + VectorType condMean = mu.topRows(nPCAComponents) + Sbx * Sxx.inverse() * (x0 - mu.bottomRows(nCondVariables)); + + // compute the conditional covariance + MatrixType condCov = Sbb - Sbx * Sxx.inverse() * Sbx.transpose(); + + // get the sample mean corresponding the the conditional given mean of the parameter vectors + VectorType condMeanSample = pcaModel->GetRepresenter()->SampleToSampleVector(pcaModel->DrawSample(condMean)); + + + // so far all the computation have been done in parameter (latent) space. Go back to sample space. + // (see PartiallyFixedModelBuilder for a detailed documentation) + // TODO we should factor this out into the base class, as it is the same code as it is used in + // the partially fixed model builder + const VectorType& pcaVariance = pcaModel->GetPCAVarianceVector(); + VectorTypeDoublePrecision pcaSdev = pcaVariance.cast().array().sqrt(); + + typedef Eigen::JacobiSVD SVDType; + MatrixTypeDoublePrecision innerMatrix = pcaSdev.asDiagonal() * condCov.cast() * pcaSdev.asDiagonal(); + SVDType svd(innerMatrix, Eigen::ComputeThinU); + VectorType singularValues = svd.singularValues().cast(); + + // keep only the necessary number of modes, wrt modelVarianceRetained... + double totalRemainingVariance = singularValues.sum(); // + //and count the number of modes required for the model + double cumulatedVariance = singularValues(0); + unsigned numComponentsToReachPrescribedVariance = 1; + while ( cumulatedVariance/totalRemainingVariance < modelVarianceRetained ) { + numComponentsToReachPrescribedVariance++; + if (numComponentsToReachPrescribedVariance==singularValues.size()) break; + cumulatedVariance += singularValues(numComponentsToReachPrescribedVariance-1); + } + + unsigned numComponentsToKeep = std::min( numComponentsToReachPrescribedVariance, singularValues.size() ); + + VectorType newPCAVariance = singularValues.topRows(numComponentsToKeep); + MatrixType newPCABasisMatrix = (pcaModel->GetOrthonormalPCABasisMatrix() * svd.matrixU().cast()).leftCols(numComponentsToKeep); + + StatisticalModelType* model = StatisticalModelType::Create(pcaModel->GetRepresenter(), condMeanSample, newPCABasisMatrix, newPCAVariance, noiseVariance); + + // add builder info and data info to the info list + MatrixType scores(0,0); + BuilderInfo::ParameterInfoList bi; + + bi.push_back(BuilderInfo::KeyValuePair("NoiseVariance ", Utils::toString(noiseVariance))); + + //generate a matrix ; first column = boolean (yes/no, this variable is used) ; second: conditioning value. + MatrixType conditioningInfoMatrix(conditioningInfo.size(), 2); + for (unsigned i=0 ; i(*it); + std::ostringstream os; + os << "URI_" << i; + di.push_back(BuilderInfo::KeyValuePair(os.str().c_str(),sampleData->GetDatasetURI())); + + os << "_surrogates"; + di.push_back(BuilderInfo::KeyValuePair(os.str().c_str(),sampleData->GetSurrogateFilename())); + } + + std::ostringstream os; + os << "surrogates_types"; + di.push_back(BuilderInfo::KeyValuePair(os.str().c_str(),surrogateTypesInfo.typeFilename)); + + + BuilderInfo builderInfo("ConditionalModelBuilder", di, bi); + + ModelInfo::BuilderInfoList biList; + biList.push_back(builderInfo); + + ModelInfo info(scores, biList); + model->SetModelInfo(info); + + delete pcaModel; + + return model; + } + +} + +} // namespace statismo + +#endif diff --git a/Components/Metrics/ActiveRegistrationModel/Statismo/core/include/Config.h b/Components/Metrics/ActiveRegistrationModel/Statismo/core/include/Config.h new file mode 100644 index 000000000..84ce5e432 --- /dev/null +++ b/Components/Metrics/ActiveRegistrationModel/Statismo/core/include/Config.h @@ -0,0 +1,53 @@ +/* + * This file is part of the statismo library. + * + * Author: Marcel Luethi (marcel.luethi@unibas.ch) + * + * Copyright (c) 2011 University of Basel + * All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions + * are met: + * + * Redistributions of source code must retain the above copyright notice, + * this list of conditions and the following disclaimer. + * + * Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * + * Neither the name of the project's author nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS + * FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + * HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED + * TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR + * PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF + * LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING + * NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + */ + +#ifndef __STATISMO_CONFIG_H +#define __STATISMO_CONFIG_H + +#include + +namespace statismo { +const std::string STATISMO_VERSION = "0.11.0"; +} + +// gccxml (as used by e.g. wrapitk) does not compile with vectorization enabled. +#if defined(__GCCXML__) +#define EIGEN_DONT_VECTORIZE 1 +#define EIGEN_DISABLE_UNALIGNED_ARRAY_ASSERT 1 +#endif + +#endif diff --git a/Components/Metrics/ActiveRegistrationModel/Statismo/core/include/DataItem.h b/Components/Metrics/ActiveRegistrationModel/Statismo/core/include/DataItem.h new file mode 100644 index 000000000..74f0bb186 --- /dev/null +++ b/Components/Metrics/ActiveRegistrationModel/Statismo/core/include/DataItem.h @@ -0,0 +1,223 @@ +/* + * DataItem.h + * + * Created by Marcel Luethi + * + * Copyright (c) 2011 University of Basel + * All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions + * are met: + * + * Redistributions of source code must retain the above copyright notice, + * this list of conditions and the following disclaimer. + * + * Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * + * Neither the name of the project's author nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS + * FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + * HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED + * TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR + * PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF + * LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING + * NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + */ + +#ifndef __SAMPLE_DATA_H +#define __SAMPLE_DATA_H + +#include "CommonTypes.h" +#include "HDF5Utils.h" +#include "Representer.h" + +namespace statismo { +/* \class DataItem + * \brief Holds all the information for a given sample. + * Use GetSample() to recover a Sample + * \warning This method generates a new object containing the sample. If the Representer does not provide a smart pointer, the user is responsible for releasing memory. + */ +template +class DataItem { + public: + typedef Representer RepresenterType; + typedef typename RepresenterType::DatasetPointerType DatasetPointerType; + + /** + * Ctor. Usually not called from outside of the library + */ + static DataItem* Create(const RepresenterType* representer, const std::string& URI, const VectorType& sampleVector) { + return new DataItem(representer, URI, sampleVector); + } + + /** + * Dtor + */ + virtual ~DataItem() {} + + /** Create a new DataItem object, using the data from the group in the HDF5 file + * \param dsGroup. The group in the hdf5 file for this dataset + */ + static DataItem* Load(const RepresenterType* representer, const H5::Group& dsGroup); + /** + * Save the sample data to the hdf5 group dsGroup. + */ + virtual void Save(const H5::Group& dsGroup) const; + + /** + * Get the URI of the original dataset + */ + std::string GetDatasetURI() const { + return m_URI; + } + + /** + * Get the representer used to create this sample + */ + const RepresenterType* GetRepresenter() const { + return m_representer; + } + + /** + * Get the vectorial representation of this sample + */ + const VectorType& GetSampleVector() const { + return m_sampleVector; + } + + /** + * Returns the sample in the representation given by the representer + * \warning This method generates a new object containing the sample. If the Representer does not provide a smart pointer, the user is responsible for releasing memory. + */ + const DatasetPointerType GetSample() const { + return m_representer->SampleVectorToSample(m_sampleVector); + } + + protected: + + DataItem(const RepresenterType* representer, const std::string& URI, const VectorType& sampleVector) + : m_representer(representer), m_URI(URI), m_sampleVector(sampleVector) { + } + + DataItem(const RepresenterType* representer) : m_representer(representer) { + } + + // loads the internal state from the hdf5 file + virtual void LoadInternal(const H5::Group& dsGroup) { + VectorType v; + HDF5Utils::readVector(dsGroup, "./samplevector", m_sampleVector); + m_URI = HDF5Utils::readString(dsGroup, "./URI"); + } + + virtual void SaveInternal(const H5::Group& dsGroup) const { + HDF5Utils::writeVector(dsGroup, "./samplevector", m_sampleVector); + HDF5Utils::writeString(dsGroup, "./URI", m_URI); + } + + + const RepresenterType* m_representer; + std::string m_URI; + VectorType m_sampleVector; +}; + + + + +/* \class DataItemWithSurrogates + * \brief Holds all the information for a given sample. + * Use GetSample() to recover a Sample + * \warning This method generates a new object containing the sample. If the Representer does not provide a smart pointer, the user is responsible for releasing memory. + * In particular, it enables to associate categorical or continuous variables with a sample, in a vectorial representation. + * The vector is provided by a file providing the values in ascii format (empty space or EOL separating the values) + * \sa DataItem + * \sa DataManagerWithSurrogates + */ + +template +class DataItemWithSurrogates : public DataItem { + friend class DataItem; + typedef Representer RepresenterType; + + public: + + enum SurrogateType { + Categorical = 0, + Continuous = 1 + }; + + + typedef std::vector SurrogateTypeVectorType; + + + + + static DataItemWithSurrogates* Create(const RepresenterType* representer, + const std::string& datasetURI, + const VectorType& sampleVector, + const std::string& surrogateFilename, + const VectorType& surrogateVector) { + return new DataItemWithSurrogates(representer, datasetURI, sampleVector, surrogateFilename, surrogateVector); + } + + + + + virtual ~DataItemWithSurrogates() {} + + const VectorType& GetSurrogateVector() const { + return m_surrogateVector; + } + const std::string& GetSurrogateFilename() const { + return m_surrogateFilename; + } + + private: + + DataItemWithSurrogates(const RepresenterType* representer, + const std::string& datasetURI, + const VectorType& sampleVector, + const std::string& surrogateFilename, + const VectorType& surrogateVector) + : DataItem(representer, datasetURI, sampleVector), + m_surrogateFilename(surrogateFilename), + m_surrogateVector(surrogateVector) { + } + + DataItemWithSurrogates(const RepresenterType* r) : DataItem(r) {} + + // loads the internal state from the hdf5 file + virtual void LoadInternal(const H5::Group& dsGroup) { + DataItem::LoadInternal(dsGroup); + VectorType v; + HDF5Utils::readVector(dsGroup, "./surrogateVector", this->m_surrogateVector); + m_surrogateFilename = HDF5Utils::readString(dsGroup, "./surrogateFilename"); + } + + virtual void SaveInternal(const H5::Group& dsGroup) const { + DataItem::SaveInternal(dsGroup); + HDF5Utils::writeVector(dsGroup, "./surrogateVector", this->m_surrogateVector); + HDF5Utils::writeString(dsGroup, "./surrogateFilename", this->m_surrogateFilename); + } + + std::string m_surrogateFilename; + VectorType m_surrogateVector; +}; + + +} // namespace statismo + +#include "DataItem.hxx" + +#endif // __SAMPLE_DATA_H + diff --git a/Components/Metrics/ActiveRegistrationModel/Statismo/core/include/DataItem.hxx b/Components/Metrics/ActiveRegistrationModel/Statismo/core/include/DataItem.hxx new file mode 100644 index 000000000..a547f5928 --- /dev/null +++ b/Components/Metrics/ActiveRegistrationModel/Statismo/core/include/DataItem.hxx @@ -0,0 +1,76 @@ +/* + * DataItem.h + * + * Created by Marcel Luethi + * + * Copyright (c) 2011 University of Basel + * All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions + * are met: + * + * Redistributions of source code must retain the above copyright notice, + * this list of conditions and the following disclaimer. + * + * Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * + * Neither the name of the project's author nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS + * FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + * HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED + * TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR + * PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF + * LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING + * NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + */ + +#ifndef __SAMPLE_DATA_TXX +#define __SAMPLE_DATA_TXX + +#include "DataItem.h" + +namespace statismo { + + +template +DataItem* +DataItem::Load(const RepresenterType* representer, const H5::Group& dsGroup) { + VectorType dsVector; + std::string sampleType = HDF5Utils::readString(dsGroup, "./sampletype"); + DataItem* newSample = 0; + if (sampleType == "DataItem") { + newSample = new DataItem(representer); + } else if (sampleType == "DataItemWithSurrogates") { + newSample = new DataItemWithSurrogates(representer); + } else { + throw StatisticalModelException((std::string("Unknown sampletype in hdf5 group: ") +sampleType).c_str()); + } + newSample->LoadInternal(dsGroup); + return newSample; +} + +template +void +DataItem::Save(const H5::Group& dsGroup) const { + if (dynamic_cast* >(this) != 0) { + HDF5Utils::writeString(dsGroup, "./sampletype", "DataItemWithSurrogates"); + } else { + HDF5Utils::writeString(dsGroup, "./sampletype", "DataItem"); + } + SaveInternal(dsGroup); +} + +} + +#endif // __SAMPLE_DATA_TXX diff --git a/Components/Metrics/ActiveRegistrationModel/Statismo/core/include/DataManager.h b/Components/Metrics/ActiveRegistrationModel/Statismo/core/include/DataManager.h new file mode 100644 index 000000000..6f817a5f6 --- /dev/null +++ b/Components/Metrics/ActiveRegistrationModel/Statismo/core/include/DataManager.h @@ -0,0 +1,218 @@ +/* + * This file is part of the statismo library. + * + * Author: Marcel Luethi (marcel.luethi@unibas.ch) + * + * Copyright (c) 2011 University of Basel + * All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions + * are met: + * + * Redistributions of source code must retain the above copyright notice, + * this list of conditions and the following disclaimer. + * + * Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * + * Neither the name of the project's author nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS + * FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + * HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED + * TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR + * PROFITS; OR BUSINESS addINTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF + * LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING + * NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + */ + +#ifndef __DATAMANAGER_H_ +#define __DATAMANAGER_H_ + +#include + +#include "Config.h" +#include "CommonTypes.h" +#include "DataItem.h" +#include "Exceptions.h" +#include "HDF5Utils.h" +#include "ModelInfo.h" +#include "Representer.h" +#include "StatismoUtils.h" + +namespace statismo { + +/** + * \brief Holds training and test data used for Crossvalidation + */ +template +class CrossValidationFold { + public: + typedef DataItem DataItemType; + typedef std::list DataItemListType; + + /*** + * Create an empty fold + */ + CrossValidationFold() { + } + ; + + /** + * Create a fold with the given trainingData and testingData + */ + CrossValidationFold(const DataItemListType& trainingData, + const DataItemListType& testingData) : + m_trainingData(trainingData), m_testingData(testingData) { + } + + /** + * Get a list holding the training data + */ + DataItemListType GetTrainingData() const { + return m_trainingData; + } + + /** + * Get a list holding the testing data + */ + DataItemListType GetTestingData() const { + return m_testingData; + } + + private: + DataItemListType m_trainingData; + DataItemListType m_testingData; +}; + +/** + * \brief Manages Training and Test Data for building Statistical Models and provides functionality for Crossvalidation. + * + * The DataManager class provides functionality for loading and managing data sets to be used in the + * statistical model. The datasets are loaded either by using DataManager::AddDataset or directly from a hdf5 File using + * the Load function. Per default all the datasets are marked as training data. It is, however, often useful + * to leave a few datasets out to validate the model. For this purpose, the DataManager class implements basic + * crossvalidation functionality. + * + * For efficiency purposes, the data is internally stored as a large matrix, using the internal SampleVector representation. + * Furthermore, Statismo emphasizes on traceability, and ties information with the datasets, such as the original filename. + * This means that when accessing the data stored in the DataManager, one gets a DataItem structure + * \sa Representer + * \sa DataItem + */ +template +class DataManager { + + public: + + typedef Representer RepresenterType; + typedef typename RepresenterType::DatasetPointerType DatasetPointerType; + typedef typename RepresenterType::DatasetConstPointerType DatasetConstPointerType; + + typedef DataItem DataItemType; + typedef DataItemWithSurrogates DataItemWithSurrogatesType; + typedef std::list DataItemListType; + typedef CrossValidationFold CrossValidationFoldType; + typedef std::list CrossValidationFoldListType; + + /** + * Factory method that creates a new instance of a DataManager class + * + */ + static DataManager* Create(const RepresenterType* representer) { + return new DataManager(representer); + } + + /** + * Create a new dataManager, with the data stored in the given hdf5 file + */ + static DataManager* Load(Representer* representer, + const std::string& filename); + + + /** + * Destroy the object. + * The same effect can be achieved by deleting the object in the usual + * way using the c++ delete keyword. + */ + void Delete() { + delete this; + } + + /** + * Destructor + */ + virtual ~DataManager(); + + /** + * Add a dataset to the data manager. + * \param dataset the dataset to be added + * \param URI A string containing the URI of the given dataset. This is only added as an info to the metadata. + * + * While it is not strictly necessary, and sometimes not even possible, to specify a URI for the given dataset, + * it is strongly encouraged to add a description. The string will be added to the metadata and stored with the model. + * Having this information stored with the model may prove valuable at a later point in time. + */ + virtual void AddDataset(DatasetConstPointerType dataset, + const std::string& URI); + + /** + * Saves the data matrix and all URIs into an HDF5 file. + * \param filename + */ + virtual void Save(const std::string& filename) const; + + /** + * return a list with all the sample data objects managed by the data manager + * \sa DataItem + */ + DataItemListType GetData() const; + + /** + * returns the number of samples managed by the datamanager + */ + unsigned GetNumberOfSamples() const { + return m_DataItemList.size(); + } + + /** + * Assigns the data to one of n Folds to be used for cross validation. + * This method has to be called before cross validation can be started. + * + * \param nFolds The number of folds used in the crossvalidation + * \param randomize If true, the data will be randomly assigned to the nfolds, otherwise the order with which it was added is preserved + */ + CrossValidationFoldListType GetCrossValidationFolds(unsigned nFolds, + bool randomize = true) const; + + /** + * Generates Leave-one-out cross validation folds + */ + CrossValidationFoldListType GetLeaveOneOutCrossValidationFolds() const; + + protected: + DataManager(const RepresenterType* representer); + + DataManager(const DataManager& orig); + DataManager& operator=(const DataManager& rhs); + + RepresenterType* m_representer; + + // members + DataItemListType m_DataItemList; +}; + +} + +#include "DataManager.hxx" + +#endif /* __DATAMANAGER_H_ */ diff --git a/Components/Metrics/ActiveRegistrationModel/Statismo/core/include/DataManager.hxx b/Components/Metrics/ActiveRegistrationModel/Statismo/core/include/DataManager.hxx new file mode 100644 index 000000000..45981b737 --- /dev/null +++ b/Components/Metrics/ActiveRegistrationModel/Statismo/core/include/DataManager.hxx @@ -0,0 +1,303 @@ +/* + * This file is part of the statismo library. + * + * Author: Marcel Luethi (marcel.luethi@unibas.ch) + * + * Copyright (c) 2011 University of Basel + * All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions + * are met: + * + * Redistributions of source code must retain the above copyright notice, + * this list of conditions and the following disclaimer. + * + * Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * + * Neither the name of the project's author nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS + * FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + * HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED + * TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR + * PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF + * LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING + * NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + */ + +#ifndef __DataManager_hxx +#define __DataManager_hxx + +#include + +#include "DataManager.h" +#include "HDF5Utils.h" + +namespace statismo { + +//////////////////////////////////////////////// +// Data manager +//////////////////////////////////////////////// + +template +DataManager::DataManager(const RepresenterType* representer) + : m_representer(representer->Clone()) { +} + +template +DataManager::~DataManager() { + for (typename DataItemListType::iterator it = + m_DataItemList.begin(); + it != m_DataItemList.end(); ++it) { + delete (*it); + } + m_DataItemList.clear(); + if (m_representer) { + m_representer->Delete(); + } + +} + + + +template +DataManager* +DataManager::Load(Representer* representer, + const std::string& filename) { + using namespace H5; + + DataManager* newDataManager = 0; + + H5File file; + try { + file = H5File(filename.c_str(), H5F_ACC_RDONLY); + } catch (H5::Exception& e) { + std::string msg( + std::string("could not open HDF5 file \n") + e.getCDetailMsg()); + throw StatisticalModelException(msg.c_str()); + } + + try { + // loading representer + + Group representerGroup = file.openGroup("./representer"); + std::string rep_name = HDF5Utils::readStringAttribute(representerGroup, "name"); + std::string repTypeStr = HDF5Utils::readStringAttribute(representerGroup, "datasetType"); + std::string versionStr = HDF5Utils::readStringAttribute(representerGroup, "version"); + typename RepresenterType::RepresenterDataType type = RepresenterType::TypeFromString(repTypeStr); + if (type == RepresenterType::CUSTOM || type == RepresenterType::UNKNOWN) { + if (rep_name != representer->GetName()) { + std::ostringstream os; + os << "A different representer was used to create the file and the representer is not of a standard type "; + os << ("(RepresenterName = ") << rep_name << " does not match required name = " << representer->GetName() << ")"; + os << "Cannot load hdf5 file"; + throw StatisticalModelException(os.str().c_str()); + } + if (versionStr != representer->GetVersion()) { + std::ostringstream os; + os << "The version of the representers do not match "; + os << ("(Version = ") << versionStr << " != = " << representer->GetVersion() << ")"; + os << "Cannot load hdf5 file"; + + } + + } + if (type != representer->GetType()) { + std::ostringstream os; + os << "The representer that was provided cannot be used to load the dataset "; + os << "(" << type << " != " << representer->GetType() << ")."; + os << "Cannot load hdf5 file."; + throw StatisticalModelException(os.str().c_str()); + } + + representer->Load(representerGroup); + representerGroup.close(); + newDataManager = new DataManager(representer); + + + Group publicGroup = file.openGroup("/data"); + unsigned numds = HDF5Utils::readInt(publicGroup, "./NumberOfDatasets"); + + for (unsigned num = 0; num < numds; num++) { + std::ostringstream ss; + ss << "./dataset-" << num; + + Group dsGroup = file.openGroup(ss.str().c_str()); + newDataManager->m_DataItemList.push_back( + DataItemType::Load(representer, dsGroup)); + + } + + } catch (H5::Exception& e) { + std::string msg( + std::string( + "an exception occurred while reading data matrix to HDF5 file \n") + + e.getCDetailMsg()); + throw StatisticalModelException(msg.c_str()); + } + + file.close(); + + assert(newDataManager != 0); + return newDataManager; +} + +template +void DataManager::Save(const std::string& filename) const { + using namespace H5; + + assert(m_representer != 0); + + H5File file; + + try { + file = H5File(filename.c_str(), H5F_ACC_TRUNC); + } catch (H5::Exception& e) { + std::string msg( + std::string("Could not open HDF5 file for writing \n") + + e.getCDetailMsg()); + throw StatisticalModelException(msg.c_str()); + } + + try { + + Group representerGroup = file.createGroup("./representer"); + std::string dataTypeStr = RepresenterType::TypeToString(m_representer->GetType()); + + HDF5Utils::writeStringAttribute(representerGroup, "name", m_representer->GetName()); + HDF5Utils::writeStringAttribute(representerGroup, "version", m_representer->GetVersion()); + HDF5Utils::writeStringAttribute(representerGroup, "datasetType", dataTypeStr); + + this->m_representer->Save(representerGroup); + representerGroup.close(); + + + Group publicGroup = file.createGroup("./data"); + HDF5Utils::writeInt(publicGroup, "./NumberOfDatasets", + this->m_DataItemList.size()); + + unsigned num = 0; + for (typename DataItemListType::const_iterator it = + this->m_DataItemList.begin(); + it != this->m_DataItemList.end(); ++it) { + std::ostringstream ss; + ss << "./dataset-" << num; + + Group dsGroup = file.createGroup(ss.str().c_str()); + + (*it)->Save(dsGroup); + + dsGroup.close(); + num++; + } + } catch (H5::Exception& e) { + std::string msg( + std::string( + "an exception occurred while writing data matrix to HDF5 file \n") + + e.getCDetailMsg()); + throw StatisticalModelException(msg.c_str()); + } + file.close(); +} + +template +void DataManager::AddDataset(DatasetConstPointerType dataset, + const std::string& URI) { + + DatasetPointerType sample; + sample = m_representer->CloneDataset(dataset); + + m_DataItemList.push_back( + DataItemType::Create(m_representer, URI, + m_representer->SampleToSampleVector(sample))); + m_representer->DeleteDataset(sample); +} + +template +typename DataManager::DataItemListType DataManager::GetData() const { + return m_DataItemList; +} + +template +typename DataManager::CrossValidationFoldListType DataManager::GetCrossValidationFolds( + unsigned nFolds, bool randomize) const { + if (nFolds <= 1 || nFolds > GetNumberOfSamples()) { + throw StatisticalModelException( + "Invalid number of folds specified in GetCrossValidationFolds"); + } + unsigned nElemsPerFold = GetNumberOfSamples() / nFolds; + + // we create a vector with as many entries as datasets. Each entry contains the + // fold the entry belongs to + std::vector batchAssignment(GetNumberOfSamples()); + + for (unsigned i = 0; i < GetNumberOfSamples(); i++) { + batchAssignment[i] = std::min(i / nElemsPerFold, nFolds); + } + + // randomly shuffle the vector + srand(time(0)); + if (randomize) { + std::random_shuffle(batchAssignment.begin(), batchAssignment.end()); + } + + // now we create the folds + CrossValidationFoldListType foldList; + for (unsigned currentFold = 0; currentFold < nFolds; currentFold++) { + DataItemListType trainingData; + DataItemListType testingData; + + unsigned sampleNum = 0; + for (typename DataItemListType::const_iterator it = + m_DataItemList.begin(); + it != m_DataItemList.end(); ++it) { + if (batchAssignment[sampleNum] != currentFold) { + trainingData.push_back(*it); + } else { + testingData.push_back(*it); + } + ++sampleNum; + } + CrossValidationFoldType fold(trainingData, testingData); + foldList.push_back(fold); + } + return foldList; +} + +template +typename DataManager::CrossValidationFoldListType DataManager::GetLeaveOneOutCrossValidationFolds() const { + CrossValidationFoldListType foldList; + for (unsigned currentFold = 0; currentFold < GetNumberOfSamples(); + currentFold++) { + DataItemListType trainingData; + DataItemListType testingData; + + unsigned sampleNum = 0; + for (typename DataItemListType::const_iterator it = + m_DataItemList.begin(); + it != m_DataItemList.end(); ++it, ++sampleNum) { + if (sampleNum == currentFold) { + testingData.push_back(*it); + } else { + trainingData.push_back(*it); + } + } + CrossValidationFoldType fold(trainingData, testingData); + foldList.push_back(fold); + } + return foldList; +} + +} // Namespace statismo + +#endif diff --git a/Components/Metrics/ActiveRegistrationModel/Statismo/core/include/DataManagerWithSurrogates.h b/Components/Metrics/ActiveRegistrationModel/Statismo/core/include/DataManagerWithSurrogates.h new file mode 100644 index 000000000..ca0386bb6 --- /dev/null +++ b/Components/Metrics/ActiveRegistrationModel/Statismo/core/include/DataManagerWithSurrogates.h @@ -0,0 +1,143 @@ +/* + * DataManagerWithSurrogates.h + * + * Created by Marcel Luethi and Remi Blanc + * + * Copyright (c) 2011 University of Basel + * All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions + * are met: + * + * Redistributions of source code must retain the above copyright notice, + * this list of conditions and the following disclaimer. + * + * Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * + * Neither the name of the project's author nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS + * FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + * HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED + * TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR + * PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF + * LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING + * NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + */ + +#ifndef __DATAMANAGERWITHSURROGATES_H_ +#define __DATAMANAGERWITHSURROGATES_H_ + +#include "DataManager.h" + +namespace statismo { + + +/** + * \brief Manages Training and Test Data for building Statistical Models and provides functionality for Crossvalidation. + * Manages data together with surrogate information. + * The surrogate variables are provided through a vector (see DataManager), and can contain both continuous or categorical data. + * The surrogate data is provided through files. One file for each dataset, and one file describing the types of surrogates. This file is also an ascii file + * with space or EOL separated values. Those values are either 0 or 1, standing for respectively categorical or continuous variable. + * This class does not support any missing data, so each dataset must come with a surrogate data file, all of which must contain the same number of entries as the type-file. + * \sa DataManager + */ +template +class DataManagerWithSurrogates : public DataManager { + + public: + + typedef Representer RepresenterType; + + typedef typename RepresenterType::DatasetPointerType DatasetPointerType; + typedef typename RepresenterType::DatasetConstPointerType DatasetConstPointerType; + + + typedef DataItemWithSurrogates DataItemWithSurrogatesType; + + typedef typename DataItemWithSurrogatesType::SurrogateTypeVectorType SurrogateTypeVectorType; + + struct SurrogateTypeInfoType { + SurrogateTypeVectorType types; + std::string typeFilename; + }; + + + /** + * Destructor + */ + virtual ~DataManagerWithSurrogates() {} + + + /** + * Factory method that creates a new instance of a DataManager class + * + */ + static DataManagerWithSurrogates* Create(const RepresenterType* representer, const std::string& surrogTypeFilename) { + return new DataManagerWithSurrogates(representer, surrogTypeFilename); + } + + + + + /** + * Add a dataset, together with surrogate information + * \param datasetFilename + * \param datasetURI (An URI for the dataset. This info is only added to the metadata). + * \param surrogateFilename + */ + void AddDatasetWithSurrogates(DatasetConstPointerType ds, + const std::string& datasetURI, + const std::string& surrogateFilename); + + /** + * Get a vector indicating the types of surrogates variables (Categorical vs Continuous) + */ + SurrogateTypeVectorType GetSurrogateTypes() const { + return m_typeInfo.types; + } + + /** Returns the source filename defining the surrogate types */ + std::string GetSurrogateTypeFilename() const { + return m_typeInfo.typeFilename; + } + + /** Get a structure containing the type info: vector of types, and source filename */ + SurrogateTypeInfoType GetSurrogateTypeInfo() const { + return m_typeInfo; + } + + protected: + + /** + * Loads the information concerning the types of the surrogates variables (categorical=0, continuous=1) + * => it is assumed to be in a text file with the entries separated by spaces or EOL character + */ + void LoadSurrogateTypes(const std::string& filename); + + + + // private - to prevent use + DataManagerWithSurrogates(const RepresenterType* r, const std::string& filename); + + DataManagerWithSurrogates(const DataManagerWithSurrogates& orig); + DataManagerWithSurrogates& operator=(const DataManagerWithSurrogates& rhs); + + SurrogateTypeInfoType m_typeInfo; +}; + +} + +#include "DataManagerWithSurrogates.hxx" + +#endif /* __DATAMANAGERWITHSURROGATES_H_ */ diff --git a/Components/Metrics/ActiveRegistrationModel/Statismo/core/include/DataManagerWithSurrogates.hxx b/Components/Metrics/ActiveRegistrationModel/Statismo/core/include/DataManagerWithSurrogates.hxx new file mode 100644 index 000000000..d9b04dfd7 --- /dev/null +++ b/Components/Metrics/ActiveRegistrationModel/Statismo/core/include/DataManagerWithSurrogates.hxx @@ -0,0 +1,106 @@ +/* + * DataManagerWithSurrogates.hxx + * + * Created by: Marcel Luethi and Remi Blanc + * + * Copyright (c) 2011 University of Basel + * All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions + * are met: + * + * Redistributions of source code must retain the above copyright notice, + * this list of conditions and the following disclaimer. + * + * Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * + * Neither the name of the project's author nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS + * FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + * HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED + * TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR + * PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF + * LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING + * NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + */ + +#ifndef __DataManagerWithSurrogates_hxx +#define __DataManagerWithSurrogates_hxx + +#include "DataManagerWithSurrogates.h" + +#include + +#include "HDF5Utils.h" + +namespace statismo { + + +//////////////////////////////////////////////// +// Data manager With Surrogates +//////////////////////////////////////////////// + + +template +DataManagerWithSurrogates::DataManagerWithSurrogates(const RepresenterType* representer, const std::string& filename) + : DataManager(representer) { + LoadSurrogateTypes(filename); +} + + +template +void +DataManagerWithSurrogates::LoadSurrogateTypes(const std::string& filename) { + VectorType tmpVector; + tmpVector = Utils::ReadVectorFromTxtFile(filename.c_str()); + m_typeInfo.typeFilename = filename; + m_typeInfo.types.clear(); + for (unsigned i=0 ; i +void +DataManagerWithSurrogates::AddDatasetWithSurrogates(DatasetConstPointerType ds, + const std::string& datasetURI, + const std::string& surrogateFilename) { + + + //assert(this->m_representer != 0); + //assert(this->m_surrogateTypes.size() > 0); + assert(this->m_representer != 0); + + const VectorType& surrogateVector = Utils::ReadVectorFromTxtFile(surrogateFilename.c_str()); + + if (static_cast(surrogateVector.size()) != m_typeInfo.types.size() ) throw StatisticalModelException("Trying to loading a dataset with unexpected number of surrogates"); + + DatasetPointerType sample; + sample = this->m_representer->CloneDataset(ds); + + this->m_DataItemList.push_back(DataItemWithSurrogatesType::Create(this->m_representer, + datasetURI, + this->m_representer->SampleToSampleVector(sample), + surrogateFilename, + surrogateVector)); + this->m_representer->DeleteDataset(sample); +} + + +} // Namespace statismo + +#endif diff --git a/Components/Metrics/ActiveRegistrationModel/Statismo/core/include/Domain.h b/Components/Metrics/ActiveRegistrationModel/Statismo/core/include/Domain.h new file mode 100644 index 000000000..fa630e707 --- /dev/null +++ b/Components/Metrics/ActiveRegistrationModel/Statismo/core/include/Domain.h @@ -0,0 +1,83 @@ +/* + * This file is part of the statismo library. + * + * Author: Marcel Luethi (marcel.luethi@unibas.ch) + * + * Copyright (c) 2011 University of Basel + * All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions + * are met: + * + * Redistributions of source code must retain the above copyright notice, + * this list of conditions and the following disclaimer. + * + * Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * + * Neither the name of the project's author nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS + * FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + * HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED + * TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR + * PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF + * LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING + * NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + */ + +#ifndef __DOMAIN_H +#define __DOMAIN_H + +#include + +namespace statismo { + +/** + * This class represents the domain on which a statistical model is defined. + * A domain is simply a list of points. + */ +//RB: enable adding / removing elements to avoid copying data? +template +class Domain { + public: + typedef std::vector DomainPointsListType; + + /** + * Create an empty domain + */ + Domain() {} + + /** + * Create a new domain from the given list of points + */ + Domain(const DomainPointsListType& domainPoints) + : m_domainPoints(domainPoints) {} + + /** Returns a list of points that define the domain */ + const DomainPointsListType& GetDomainPoints() const { + return m_domainPoints; + } + + /** Returns the number of poitns of the domain */ + const unsigned GetNumberOfPoints() const { + return m_domainPoints.size(); + } + + private: + + DomainPointsListType m_domainPoints; +}; + +} // namespace statismo + +#endif diff --git a/Components/Metrics/ActiveRegistrationModel/Statismo/core/include/Exceptions.h b/Components/Metrics/ActiveRegistrationModel/Statismo/core/include/Exceptions.h new file mode 100644 index 000000000..36b68db96 --- /dev/null +++ b/Components/Metrics/ActiveRegistrationModel/Statismo/core/include/Exceptions.h @@ -0,0 +1,82 @@ +/* + * This file is part of the statismo library. + * + * Author: Marcel Luethi (marcel.luethi@unibas.ch) + * + * Copyright (c) 2011 University of Basel + * All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions + * are met: + * + * Redistributions of source code must retain the above copyright notice, + * this list of conditions and the following disclaimer. + * + * Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * + * Neither the name of the project's author nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS + * FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + * HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED + * TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR + * PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF + * LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING + * NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + */ + + +#ifndef __COME_SSM_EXCEPTIONS__ +#define __COME_SSM_EXCEPTIONS__ + +#include +#include + +namespace statismo { + +/** + * \brief Used to indicate that a method has not yet been implemented + */ +class NotImplementedException : public std::exception { + public: + NotImplementedException(const char* classname, const char* methodname) + :m_classname(classname), m_methodname(methodname) { + } + virtual ~NotImplementedException() throw() {} + + const char* what() const throw() { + return (m_classname + "::" +m_methodname).c_str(); + } + private: + std::string m_classname; + std::string m_methodname; +}; + +/** + * \brief Generic Exception class for the statismo Library. + */ +class StatisticalModelException : public std::exception { + public: + StatisticalModelException(const char* message) : m_message(message) {} + virtual ~StatisticalModelException() throw() {} + const char* what() const throw() { + return m_message.c_str(); + } + + private: + std::string m_message; +}; + +} + +#endif diff --git a/Components/Metrics/ActiveRegistrationModel/Statismo/core/include/HDF5Utils.h b/Components/Metrics/ActiveRegistrationModel/Statismo/core/include/HDF5Utils.h new file mode 100644 index 000000000..fdeb87fff --- /dev/null +++ b/Components/Metrics/ActiveRegistrationModel/Statismo/core/include/HDF5Utils.h @@ -0,0 +1,273 @@ +/* + * This file is part of the statismo library. + * + * Author: Marcel Luethi (marcel.luethi@unibas.ch) + * + * Copyright (c) 2011 University of Basel + * All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions + * are met: + * + * Redistributions of source code must retain the above copyright notice, + * this list of conditions and the following disclaimer. + * + * Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * + * Neither the name of the project's author nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS + * FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + * HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED + * TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR + * PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF + * LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING + * NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + */ + + +#ifndef HDF5UTILS_H_ +#define HDF5UTILS_H_ + +#include "CommonTypes.h" +#include "itk_H5Cpp.h" + +/*namespace H5 { +class H5Location; +class H5Location; +class Group; +class H5File; +class H5Object; +class DataSet; +}*/ + +namespace statismo { + +/** + * \brief Utility methods to read and store common types to a HDF5 File. + */ +class HDF5Utils { + public: + + + + /** + * Opens the hdf5 file with the given name, or creates it if the file does not exist + */ + static H5::H5File openOrCreateFile(const std::string filename); + + + /** + * Opens the hdf5 group or creates it if it doesn't exist. + * @param a file object + * @param path An absolute path that defines a group + * @param createPath if true, creates the path if it does not exist + * + * @return the group object representing the path in the hdf5 file + */ + static H5::Group openPath(H5::H5File& fg, const std::string& path, bool createPath=false); + + /** + * Read a Matrix from a HDF5 File + * @param fg The group + * @param name the name of the entry + * @param the output matrix + */ + static void readMatrix(const H5::H5Location& fg, const char* name, MatrixType& matrix); + + /** + * Read a submatrix from the file, with the given number of Columns + * @param fg The group + * @param name the name of the entry + * @param nCols the number of columns to be read + * @param the output matrix + */ + static void readMatrix(const H5::H5Location& fg, const char* name, unsigned nCols, MatrixType& matrix); + + /** + * Read a Matrix of a given type from a HDF5 File + * @param fg The group + * @param name the name of the entry + * @param the output matrix + */ + template + static void readMatrixOfType(const H5::H5Location& fg, const char* name, typename GenericEigenType::MatrixType& matrix); + + /** + * Write a Matrix to the HDF5 File + * @param fg The group + * @param name the name of the entry + * @param the matrix to be written + */ + static H5::DataSet writeMatrix(const H5::H5Location& fg, const char* name, const MatrixType& matrix); + + /** + * Write a Matrix of the given type to the HDF5 File + * @param fg The group + * @param name the name of the entry + * @param the matrix to be written + */ + template + static H5::DataSet writeMatrixOfType(const H5::H5Location& fg, const char* name, const typename GenericEigenType::MatrixType& matrix); + + + /** + * Read a Vector from a HDF5 File with the given number of elements + * @param fg The group + * @param name the name of the entry + * @param numElements The number of elements to be read from the file + * @param the output vector + */ + static void readVector(const H5::H5Location& fg, const char* name, unsigned nElements, VectorType& vector); + + /** + * Read a Vector from a HDF5 File + * @param fg The group + * @param name the name of the entry + * @param numElements The number of elements to be read from the file + * @param the output vector + */ + static void readVector(const H5::H5Location& fg, const char* name, VectorType& vector); + + template + static void readVectorOfType(const H5::H5Location& fg, const char* name, typename GenericEigenType::VectorType& vector); + + /** + * Write a vector to the HDF5 File + * @param fg The hdf5 group + * @param name the name of the entry + * @param the vector to be written + */ + static H5::DataSet writeVector(const H5::H5Location& fg, const char* name, const VectorType& vector); + + template + static H5::DataSet writeVectorOfType(const H5::H5Location& fg, const char* name, const typename GenericEigenType::VectorType& vector); + + + /** + * Reads a file (in binary mode) and saves it as a byte array in the hdf5 file. + * @param filename The filename of the file to be stored + * @param fg The hdf5 group + * @param name The name of the entry + */ + static void dumpFileToHDF5( const char* filename, const H5::H5Location& fg, const char* name); + + /** + * Reads an entry from an HDF5 byte array and writes it to a file + * @param fg The hdf5 group + * @param name the name of the entry + * @param filename The filename where the data from the HDF5 file is stored. + */ + static void getFileFromHDF5(const H5::H5Location& fg, const char* name, const char* filename); + + /** Writes a string to the hdf5 file + * @param fg The hdf5 group + * @param name The name of the entry in the group + * @param s The string to be written + */ + static H5::DataSet writeString(const H5::H5Location& fg, const char* name, const std::string& s); + + /** Reads a string from the given group + * @param group the hdf5 group + * @param name the name of the entry in the group + * @return the string + */ + static std::string readString(const H5::H5Location& fg, const char* name); + + /** Writes a string attribute for the given group + * @param fg The hdf5 group + * @param name The name of the entry in the group + * @param s The string to be written + */ + static void writeStringAttribute(const H5::H5Object& group, const char* name, const std::string& s); + + /** Writes an int attribute for the given group + * @param fg The hdf5 group + * @param name The name of the entry in the group + * @param value the int value to be written + */ + static void writeIntAttribute(const H5::H5Object& fg, const char* name, int value); + + + + /** Reads a string attribute from the given group + * @param group the hdf5 group + * @param name the name of the entry in the group + * @return the value + */ + static std::string readStringAttribute(const H5::H5Object& group, const char* name); + + /** Reads a int attribute from the given group + * @param group the hdf5 group + * @param name the name of the entry in the group + * @return the value + */ + static int readIntAttribute(const H5::H5Object& group, const char* name); + + + /** Reads an integer from the hdf5 file + * @param fg The hdf5 group + * @param name The name + * @returns the integeter + */ + static int readInt(const H5::H5Location& fg, const char* name); + + /** Writes an integer to the hdf5 file + * @param fg The hdf5 group + * @param name The name + * @param value The value to be written + */ + static H5::DataSet writeInt(const H5::H5Location& fg, const char* name, int value); + + /** Reads an dobule from the hdf5 file + * @param fg The hdf5 group + * @param name The name + * @returns the read number + */ + static float readFloat(const H5::H5Location& fg, const char* name); + + /** Writes an double to the hdf5 file + * @param fg The hdf5 group + * @param name The name + * @param value The value to be written + */ + static H5::DataSet writeFloat(const H5::H5Location& fg, const char* name, float value); + + /** Reads an array from the hdf5 group + * @param fg The hdf5 group + * @param name The name + * @param array The array (type std::vector) to be read, contents will be lost + */ + template + static void readArray(const H5::H5Location& fg, const char* name, std::vector & array); + + /** Writes an array to the hdf5 group + * @param fg The hdf5 group + * @param name The name + * @param array The array (type std::vector) to be written + */ + template + static H5::DataSet writeArray(const H5::H5Location& fg, const char* name, std::vector const& array ); + + + /** Check whether an object (direct child) of fg with the given name exists + */ + static bool existsObjectWithName(const H5::H5Location& fg, const std::string& name); + +}; + +} // namespace statismo + +#include "HDF5Utils.hxx" + +#endif /* HDF5UTILS_H_ */ diff --git a/Components/Metrics/ActiveRegistrationModel/Statismo/core/include/HDF5Utils.hxx b/Components/Metrics/ActiveRegistrationModel/Statismo/core/include/HDF5Utils.hxx new file mode 100644 index 000000000..b79a39b90 --- /dev/null +++ b/Components/Metrics/ActiveRegistrationModel/Statismo/core/include/HDF5Utils.hxx @@ -0,0 +1,556 @@ +/* + * This file is part of the statismo library. + * + * Author: Marcel Luethi (marcel.luethi@unibas.ch) + * + * Copyright (c) 2011 University of Basel + * All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions + * are met: + * + * Redistributions of source code must retain the above copyright notice, + * this list of conditions and the following disclaimer. + * + * Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * + * Neither the name of the project's author nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS + * FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + * HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED + * TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR + * PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF + * LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING + * NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + */ + +#ifndef __HDF5_UTILS_CXX +#define __HDF5_UTILS_CXX + +#include "HDF5Utils.h" + +#include +#include +#include +#include + +#include "CommonTypes.h" +#include "Exceptions.h" +#include "itk_H5Cpp.h" + +namespace statismo { + +inline +H5::H5File +HDF5Utils::openOrCreateFile(const std::string filename) { + + // check if file exists + std::ifstream ifile(filename.c_str()); + H5::H5File file; + + if (!ifile) { + // create it + file = H5::H5File( filename.c_str(), H5F_ACC_EXCL); + } else { + // open it + file = H5::H5File( filename.c_str(), H5F_ACC_RDWR); + } + return file; +} + + + +inline +H5::Group +HDF5Utils::openPath(H5::H5File& file, const std::string& path, bool createPath) { + H5::Group group; + + // take the first part of the path + size_t curpos = 1; + size_t nextpos = path.find_first_of("/", curpos); + H5::Group g = file.openGroup("/"); + + std::string name = path.substr(curpos, nextpos-1); + + while (curpos != std::string::npos && name != "") { + + if (existsObjectWithName(g, name)) { + g = g.openGroup(name); + } else { + if (createPath) { + g = g.createGroup(name); + } else { + std::string msg = std::string("the path ") +path +" does not exist"; + throw StatisticalModelException(msg.c_str()); + } + } + + curpos = nextpos+1; + nextpos = path.find_first_of("/", curpos); + if ( nextpos != std::string::npos ) + name = path.substr(curpos, nextpos-curpos); + else + name = path.substr(curpos); + } + + return g; +} + +template +inline +void HDF5Utils::readMatrixOfType(const H5::H5Location& fg, const char* name, typename GenericEigenType::MatrixType& matrix) { + throw StatisticalModelException("Invalid type proided for writeMatrixOfType"); +} + +template <> +inline +void HDF5Utils::readMatrixOfType(const H5::H5Location& fg, const char* name, GenericEigenType::MatrixType& matrix) { + H5::DataSet ds = fg.openDataSet( name ); + hsize_t dims[2]; + ds.getSpace().getSimpleExtentDims(dims, NULL); + + // simply read the whole dataspace + matrix.resize(dims[0], dims[1]); + ds.read(matrix.data(), H5::PredType::NATIVE_UINT); +} + +template <> +inline +void HDF5Utils::readMatrixOfType(const H5::H5Location& fg, const char* name, GenericEigenType::MatrixType& matrix) { + H5::DataSet ds = fg.openDataSet( name ); + hsize_t dims[2]; + ds.getSpace().getSimpleExtentDims(dims, NULL); + + // simply read the whole dataspace + matrix.resize(dims[0], dims[1]); + ds.read(matrix.data(), H5::PredType::NATIVE_FLOAT); +} + +template <> +inline +void HDF5Utils::readMatrixOfType(const H5::H5Location& fg, const char* name, GenericEigenType::MatrixType& matrix) { + H5::DataSet ds = fg.openDataSet( name ); + hsize_t dims[2]; + ds.getSpace().getSimpleExtentDims(dims, NULL); + + // simply read the whole dataspace + matrix.resize(dims[0], dims[1]); + ds.read(matrix.data(), H5::PredType::NATIVE_DOUBLE); +} + + +inline +void HDF5Utils::readMatrix(const H5::H5Location& fg, const char* name, MatrixType& matrix) { + readMatrixOfType(fg, name, matrix); +} + + +inline +void HDF5Utils::readMatrix(const H5::H5Location& fg, const char* name, unsigned maxNumColumns, MatrixType& matrix) { + H5::DataSet ds = fg.openDataSet( name ); + hsize_t dims[2]; + ds.getSpace().getSimpleExtentDims(dims, NULL); + + hsize_t nRows = dims[0]; // take the number of rows defined in the hdf5 file + hsize_t nCols = std::min(dims[1], static_cast(maxNumColumns)); // take the number of cols provided by the user + + hsize_t offset[2] = {0,0}; // hyperslab offset in the file + hsize_t count[2]; + count[0] = nRows; + count[1] = nCols; + + H5::DataSpace dataspace = ds.getSpace(); + dataspace.selectHyperslab( H5S_SELECT_SET, count, offset ); + + /* Define the memory dataspace. */ + hsize_t dimsm[2]; + dimsm[0] = nRows; + dimsm[1] = nCols; + H5::DataSpace memspace( 2, dimsm ); + + /* Define memory hyperslab. */ + hsize_t offset_out[2] = {0, 0}; // hyperslab offset in memory + hsize_t count_out[2]; // size of the hyperslab in memory + + count_out[0] = nRows; + count_out[1] = nCols; + memspace.selectHyperslab( H5S_SELECT_SET, count_out, offset_out ); + + matrix.resize(nRows, nCols); + // ds.read(matrix.data(), H5::PredType::NATIVE_FLOAT, memspace, dataspace); + ds.read(matrix.data(), H5::PredType::NATIVE_DOUBLE, memspace, dataspace); + +} + +template +inline +H5::DataSet HDF5Utils::writeMatrixOfType(const H5::H5Location& fg, const char* name, const typename GenericEigenType::MatrixType& matrix) { + throw StatisticalModelException("Invalid type proided for writeMatrixOfType"); +} + +template <> +inline +H5::DataSet HDF5Utils::writeMatrixOfType(const H5::H5Location& fg, const char* name, const GenericEigenType::MatrixType& matrix) { + // HDF5 does not like empty matrices. + // + if (matrix.rows() == 0 || matrix.cols() == 0) { + throw StatisticalModelException("Empty matrix provided to writeMatrix"); + } + + hsize_t dims[2] = {static_cast(matrix.rows()), static_cast(matrix.cols())}; + H5::DataSet ds = fg.createDataSet( name, H5::PredType::NATIVE_UINT, H5::DataSpace(2, dims)); + ds.write( matrix.data(), H5::PredType::NATIVE_UINT ); + return ds; +} + +template <> +inline +H5::DataSet HDF5Utils::writeMatrixOfType(const H5::H5Location& fg, const char* name, const GenericEigenType::MatrixType& matrix) { + // HDF5 does not like empty matrices. + // + if (matrix.rows() == 0 || matrix.cols() == 0) { + throw StatisticalModelException("Empty matrix provided to writeMatrix"); + } + + hsize_t dims[2] = {static_cast(matrix.rows()), static_cast(matrix.cols())}; + H5::DataSet ds = fg.createDataSet( name, H5::PredType::NATIVE_FLOAT, H5::DataSpace(2, dims)); + ds.write( matrix.data(), H5::PredType::NATIVE_FLOAT ); + return ds; +} + +template <> +inline +H5::DataSet HDF5Utils::writeMatrixOfType(const H5::H5Location& fg, const char* name, const GenericEigenType::MatrixType& matrix) { + // HDF5 does not like empty matrices. + // + if (matrix.rows() == 0 || matrix.cols() == 0) { + throw StatisticalModelException("Empty matrix provided to writeMatrix"); + } + + hsize_t dims[2] = {static_cast(matrix.rows()), static_cast(matrix.cols())}; + H5::DataSet ds = fg.createDataSet( name, H5::PredType::NATIVE_DOUBLE, H5::DataSpace(2, dims)); + ds.write( matrix.data(), H5::PredType::NATIVE_DOUBLE ); + return ds; +} + + +inline +H5::DataSet HDF5Utils::writeMatrix(const H5::H5Location& fg, const char* name, const MatrixType& matrix) { + return writeMatrixOfType(fg, name, matrix); +} + + +template +inline +void HDF5Utils::readVectorOfType(const H5::H5Location& fg, const char* name, typename GenericEigenType::VectorType& vector) { + throw StatisticalModelException("Invalid type proided for readVectorOfType"); +} + +template <> +inline +void HDF5Utils::readVectorOfType(const H5::H5Location& fg, const char* name, GenericEigenType::VectorType& vector) { + H5::DataSet ds = fg.openDataSet( name ); + hsize_t dims[1]; + ds.getSpace().getSimpleExtentDims(dims, NULL); + vector.resize(dims[0], 1); + ds.read(vector.data(), H5::PredType::NATIVE_DOUBLE); +} + +template <> +inline +void HDF5Utils::readVectorOfType(const H5::H5Location& fg, const char* name, GenericEigenType::VectorType& vector) { + H5::DataSet ds = fg.openDataSet( name ); + hsize_t dims[1]; + ds.getSpace().getSimpleExtentDims(dims, NULL); + vector.resize(dims[0], 1); + ds.read(vector.data(), H5::PredType::NATIVE_FLOAT); +} + +template <> +inline +void HDF5Utils::readVectorOfType(const H5::H5Location& fg, const char* name, GenericEigenType::VectorType& vector) { + H5::DataSet ds = fg.openDataSet( name ); + hsize_t dims[1]; + ds.getSpace().getSimpleExtentDims(dims, NULL); + vector.resize(dims[0], 1); + ds.read(vector.data(), H5::PredType::NATIVE_INT); +} + +inline +void HDF5Utils::readVector(const H5::H5Location& fg, const char* name, VectorType& vector) { + readVectorOfType(fg, name, vector); +} + + +inline +void HDF5Utils::readVector(const H5::H5Location& fg, const char* name, unsigned maxNumElements, VectorType& vector) { + H5::DataSet ds = fg.openDataSet( name ); + hsize_t dims[1]; + ds.getSpace().getSimpleExtentDims(dims, NULL); + + hsize_t nElements = std::min(dims[0], static_cast(maxNumElements)); // take the number of rows defined in the hdf5 file + + hsize_t offset[1] = {0}; // hyperslab offset in the file + hsize_t count[1]; + count[0] = nElements; + + H5::DataSpace dataspace = ds.getSpace(); + dataspace.selectHyperslab( H5S_SELECT_SET, count, offset ); + + /* Define the memory dataspace. */ + hsize_t dimsm[1]; + dimsm[0] = nElements; + H5::DataSpace memspace( 1, dimsm ); + + /* Define memory hyperslab. */ + hsize_t offset_out[1] = {0}; // hyperslab offset in memory + hsize_t count_out[1]; // size of the hyperslab in memory + + count_out[0] = nElements; + memspace.selectHyperslab( H5S_SELECT_SET, count_out, offset_out ); + + vector.resize(nElements); + HDF5Utils::readVector(fg, name, vector); +} + + + + +template +inline +H5::DataSet HDF5Utils::writeVectorOfType(const H5::H5Location& fg, const char* name, const typename GenericEigenType::VectorType& vector) { + throw StatisticalModelException("Invalid type provided for writeVectorOfType"); +} + +template <> +inline +H5::DataSet HDF5Utils::writeVectorOfType(const H5::H5Location& fg, const char* name, const GenericEigenType::VectorType& vector) { + hsize_t dims[1] = {static_cast(vector.size())}; + H5::DataSet ds = fg.createDataSet( name, H5::PredType::NATIVE_DOUBLE, H5::DataSpace(1, dims)); + ds.write( vector.data(), H5::PredType::NATIVE_DOUBLE ); + return ds; +} + +template <> +inline +H5::DataSet HDF5Utils::writeVectorOfType(const H5::H5Location& fg, const char* name, const GenericEigenType::VectorType& vector) { + hsize_t dims[1] = {static_cast(vector.size())}; + H5::DataSet ds = fg.createDataSet( name, H5::PredType::NATIVE_FLOAT, H5::DataSpace(1, dims)); + ds.write( vector.data(), H5::PredType::NATIVE_FLOAT ); + return ds; +} + +template <> +inline +H5::DataSet HDF5Utils::writeVectorOfType(const H5::H5Location& fg, const char* name, const GenericEigenType::VectorType& vector) { + hsize_t dims[1] = {static_cast(vector.size())}; + H5::DataSet ds = fg.createDataSet( name, H5::PredType::NATIVE_INT, H5::DataSpace(1, dims)); + ds.write( vector.data(), H5::PredType::NATIVE_INT ); + return ds; +} + +inline +H5::DataSet HDF5Utils::writeVector(const H5::H5Location& fg, const char* name, const VectorType& vector) { + return writeVectorOfType(fg, name, vector); +} + + +inline +H5::DataSet HDF5Utils::writeString(const H5::H5Location& fg, const char* name, const std::string& s) { + H5::StrType fls_type(H5::PredType::C_S1, s.length() + 1); // + 1 for trailing zero + H5::DataSet ds = fg.createDataSet(name, fls_type, H5::DataSpace(H5S_SCALAR)); + ds.write(s, fls_type); + return ds; +} + + +inline +std::string +HDF5Utils::readString(const H5::H5Location& fg, const char* name) { + H5std_string outputString; + H5::DataSet ds = fg.openDataSet(name); + ds.read(outputString, ds.getStrType()); + return outputString; +} + +inline +void HDF5Utils::writeStringAttribute(const H5::H5Object& fg, const char* name, const std::string& s) { + H5::StrType strdatatype(H5::PredType::C_S1, s.length() + 1 ); // + 1 for trailing 0 + H5::Attribute att = fg.createAttribute(name, strdatatype, H5::DataSpace(H5S_SCALAR)); + att.write(strdatatype, s); + att.close(); +} + + +inline +std::string +HDF5Utils::readStringAttribute(const H5::H5Object& fg, const char* name) { + H5std_string outputString; + + H5::Attribute myatt_out = fg.openAttribute(name); + myatt_out.read(myatt_out.getStrType(), outputString); + return outputString; +} + +inline +void HDF5Utils::writeIntAttribute(const H5::H5Object& fg, const char* name, int value) { + H5::IntType int_type(H5::PredType::NATIVE_INT32); + H5::DataSpace att_space(H5S_SCALAR); + H5::Attribute att = fg.createAttribute(name, int_type, att_space ); + att.write( int_type, &value); + att.close(); +} + +inline +int +HDF5Utils::readIntAttribute(const H5::H5Object& fg, const char* name) { + H5::IntType fls_type(H5::PredType::NATIVE_INT32); + int value = 0; + H5::Attribute myatt_out = fg.openAttribute(name); + myatt_out.read(fls_type, &value); + return value; +} + + +inline +H5::DataSet HDF5Utils::writeInt(const H5::H5Location& fg, const char* name, int value) { + H5::IntType fls_type(H5::PredType::NATIVE_INT32); // 0 is a dummy argument + H5::DataSet ds = fg.createDataSet(name, fls_type, H5::DataSpace(H5S_SCALAR)); + ds.write(&value, fls_type); + return ds; +} + +inline +int HDF5Utils::readInt(const H5::H5Location& fg, const char* name) { + H5::IntType fls_type(H5::PredType::NATIVE_INT32); + H5::DataSet ds = fg.openDataSet( name ); + + int value = 0; + ds.read(&value, fls_type); + return value; +} + +inline +H5::DataSet HDF5Utils::writeFloat(const H5::H5Location& fg, const char* name, float value) { + H5::FloatType fls_type(H5::PredType::NATIVE_FLOAT); // 0 is a dummy argument + H5::DataSet ds = fg.createDataSet(name, fls_type, H5::DataSpace(H5S_SCALAR)); + ds.write(&value, fls_type); + return ds; +} + +inline +float HDF5Utils::readFloat(const H5::H5Location& fg, const char* name) { + H5::FloatType fls_type(H5::PredType::NATIVE_FLOAT); + H5::DataSet ds = fg.openDataSet( name ); + + float value = 0; + ds.read(&value, fls_type); + return value; +} + +inline +void HDF5Utils::getFileFromHDF5(const H5::H5Location& fg, const char* name, const char* filename) { + H5::DataSet ds = fg.openDataSet( name ); + hsize_t dims[1]; + ds.getSpace().getSimpleExtentDims(dims, NULL); + std::vector buffer(dims[0]); + if(!buffer.empty()) ds.read(&buffer[0], H5::PredType::NATIVE_CHAR); + + typedef std::ostream_iterator ostream_iterator; + std::ofstream ofile(filename, std::ios::binary); + if (!ofile) { + std::string s= std::string("could not open file ") +filename; + throw StatisticalModelException(s.c_str()); + } + + std::copy(buffer.begin(), buffer.end(), ostream_iterator(ofile)); + ofile.close(); +} + +inline +void +HDF5Utils::dumpFileToHDF5( const char* filename, const H5::H5Location& fg, const char* name) { + + typedef std::istream_iterator istream_iterator; + + std::ifstream ifile(filename, std::ios::binary); + if (!ifile) { + std::string s= std::string("could not open file ") +filename; + throw StatisticalModelException(s.c_str()); + } + + std::vector buffer; + ifile >> std::noskipws; + std::copy(istream_iterator(ifile), istream_iterator(), std::back_inserter(buffer)); + + ifile.close(); + + hsize_t dims[] = {buffer.size()}; + H5::DataSet ds = fg.createDataSet( name, H5::PredType::NATIVE_CHAR, H5::DataSpace(1, dims)); + ds.write( &buffer[0], H5::PredType::NATIVE_CHAR ); + +} + +template +inline +void +HDF5Utils::readArray(const H5::H5Location& fg, const char* name, std::vector & array) { + throw StatisticalModelException( "not implemented" ); +} + + +template +inline +H5::DataSet +HDF5Utils::writeArray(const H5::H5Location& fg, const char* name, std::vector const& array) { + throw StatisticalModelException( "not implemented" ); +} + +template<> +inline +void +HDF5Utils::readArray(const H5::H5Location& fg, const char* name, std::vector & array) { + H5::DataSet ds = fg.openDataSet( name ); + hsize_t dims[1]; + ds.getSpace().getSimpleExtentDims(dims, NULL); + array.resize(dims[0]); + ds.read( &array[0], H5::PredType::NATIVE_INT32); +} + +template<> +inline +H5::DataSet +HDF5Utils::writeArray(const H5::H5Location& fg, const char* name, std::vector const& array) { + hsize_t dims[1] = {array.size()}; + H5::DataSet ds = fg.createDataSet( name, H5::PredType::NATIVE_INT32, H5::DataSpace(1, dims)); + ds.write( &array[0], H5::PredType::NATIVE_INT32 ); + return ds; +} + +inline +bool +HDF5Utils::existsObjectWithName(const H5::H5Location& fg, const std::string& name) { + for (hsize_t i = 0; i < fg.getNumObjs(); ++i) { + std::string objname= fg.getObjnameByIdx(i); + if (objname == name) { + return true; + } + } + return false; +} + +} //namespace statismo + +#endif + diff --git a/Components/Metrics/ActiveRegistrationModel/Statismo/core/include/KernelCombinators.h b/Components/Metrics/ActiveRegistrationModel/Statismo/core/include/KernelCombinators.h new file mode 100644 index 000000000..18f69c0c7 --- /dev/null +++ b/Components/Metrics/ActiveRegistrationModel/Statismo/core/include/KernelCombinators.h @@ -0,0 +1,310 @@ +/** + * This file is part of the statismo library. + * + * Author: Marcel Luethi (marcel.luethi@unibas.ch) + * Thomas Gerig (thomas.gerig@unibas.ch) + * + * Copyright (c) 2011 University of Basel + * All rights reserved. + * + * Statismo is licensed under the BSD licence (3 clause) license + */ + +#ifndef KERNELCOMBINATORS_H +#define KERNELCOMBINATORS_H + +#include +#include +#include + +#include "CommonTypes.h" +#include "Kernels.h" +#include "Nystrom.h" +#include "Representer.h" + +namespace statismo { + +/** + * A (matrix valued) kernel, which represents the sum of two matrix valued kernels. + */ +template +class SumKernel: public MatrixValuedKernel { + public: + + typedef MatrixValuedKernel MatrixValuedKernelType; + + + SumKernel(const MatrixValuedKernelType* lhs, + const MatrixValuedKernelType* rhs) : + MatrixValuedKernelType(lhs->GetDimension()), + m_lhs(lhs), + m_rhs(rhs) { + if (lhs->GetDimension() != rhs->GetDimension()) { + throw StatisticalModelException( + "Kernels in SumKernel must have the same dimensionality"); + } + } + + MatrixType operator()(const TPoint& x, const TPoint& y) const { + return (*m_lhs)(x, y) + (*m_rhs)(x, y); + } + + std::string GetKernelInfo() const { + std::ostringstream os; + os << m_lhs->GetKernelInfo() << " + " << m_rhs->GetKernelInfo(); + return os.str(); + } + + private: + const MatrixValuedKernelType* m_lhs; + const MatrixValuedKernelType* m_rhs; +}; + + + +/** + * A (matrix valued) kernel, which represents the product of two matrix valued kernels. + */ + +template +class ProductKernel: public MatrixValuedKernel { + + public: + + typedef MatrixValuedKernel MatrixValuedKernelType; + + ProductKernel(const MatrixValuedKernelType* lhs, + const MatrixValuedKernelType* rhs) : + MatrixValuedKernelType(lhs->GetDimension()), m_lhs(lhs), m_rhs( + rhs) { + if (lhs->GetDimension() != rhs->GetDimension()) { + throw StatisticalModelException( + "Kernels in SumKernel must have the same dimensionality"); + } + + } + + MatrixType operator()(const TPoint& x, const TPoint& y) const { + return (*m_lhs)(x, y) * (*m_rhs)(x, y); + } + + std::string GetKernelInfo() const { + std::ostringstream os; + os << m_lhs->GetKernelInfo() << " * " << m_rhs->GetKernelInfo(); + return os.str(); + } + + private: + const MatrixValuedKernelType* m_lhs; + const MatrixValuedKernelType* m_rhs; +}; + + +/** + * A (matrix valued) kernel, which represents a scalar multiple of a matrix valued kernel. + */ + +template +class ScaledKernel: public MatrixValuedKernel { + public: + + + typedef MatrixValuedKernel MatrixValuedKernelType; + + + ScaledKernel(const MatrixValuedKernelType* kernel, + double scalingFactor) : + MatrixValuedKernelType(kernel->GetDimension()), m_kernel(kernel), m_scalingFactor(scalingFactor) { + } + + MatrixType operator()(const TPoint& x, const TPoint& y) const { + return (*m_kernel)(x, y) * m_scalingFactor; + } + std::string GetKernelInfo() const { + std::ostringstream os; + os << (*m_kernel).GetKernelInfo() << " * " << m_scalingFactor; + return os.str(); + } + + private: + const MatrixValuedKernelType* m_kernel; + double m_scalingFactor; +}; + + +/** + * Takes a scalar valued kernel and creates a matrix valued kernel of the given dimension. + * The new kernel models the output components as independent, i.e. if K(x,y) is a scalar valued Kernel, + * the matrix valued kernel becomes Id*K(x,y), where Id is an identity matrix of dimensionality d. + */ +template +class UncorrelatedMatrixValuedKernel: public MatrixValuedKernel { + public: + + typedef MatrixValuedKernel MatrixValuedKernelType; + + UncorrelatedMatrixValuedKernel( + const ScalarValuedKernel* scalarKernel, + unsigned dimension) : + MatrixValuedKernelType( dimension), m_kernel(scalarKernel), + m_ident(MatrixType::Identity(dimension, dimension)) { + } + + MatrixType operator()(const TPoint& x, const TPoint& y) const { + + return m_ident * (*m_kernel)(x, y); + } + + virtual ~UncorrelatedMatrixValuedKernel() { + } + + std::string GetKernelInfo() const { + std::ostringstream os; + os << "UncorrelatedMatrixValuedKernel(" << (*m_kernel).GetKernelInfo() + << ", " << this->m_dimension << ")"; + return os.str(); + } + + private: + + const ScalarValuedKernel* m_kernel; + MatrixType m_ident; + +}; + + +/** + * Base class for defining a tempering function for the SpatiallyVaryingKernel + */ +template +class TemperingFunction { + public: + virtual double operator()(const TPoint& pt) const = 0; + virtual ~TemperingFunction() {} +}; + +/** + * spatially-varing kernel, as described in the paper: + * + * T. Gerig, K. Shahim, M. Reyes, T. Vetter, M. Luethi + * Spatially varying registration using gaussian processes + * Miccai 2014 + */ +template +class SpatiallyVaryingKernel : public MatrixValuedKernel::PointType> { + + typedef boost::unordered_map CacheType; + + public: + + typedef Representer RepresenterType; + typedef typename RepresenterType::PointType PointType; + + + /** + * @brief Make a given kernel spatially varying according to the given tempering function + * @param representer, A representer which defines the domain over which the approximation is done + * @param kernel The kernel that is made spatially adaptive + * @param eta The tempering function that defines the amount of tempering for each point in the domain + * @param numEigenfunctions The number of eigenfunctions to be used for the approximation + * @param numberOfPointsForApproximation The number of points used for the nystrom approximation + * @param cacheValues Cache result of eigenfunction computations. Greatly speeds up the computation. + */ + SpatiallyVaryingKernel(const RepresenterType* representer, const MatrixValuedKernel& kernel, const TemperingFunction& eta, unsigned numEigenfunctions, unsigned numberOfPointsForApproximation = 0, bool cacheValues = true) + : m_representer(representer), + m_eta(eta), + m_nystrom(Nystrom::Create(representer, kernel, numEigenfunctions, numberOfPointsForApproximation == 0 ? numEigenfunctions * 2 : numberOfPointsForApproximation)), + m_eigenvalues(m_nystrom->getEigenvalues()), + m_cacheValues(cacheValues), + MatrixValuedKernel(kernel.GetDimension()) { + } + + inline MatrixType operator()(const PointType& x, const PointType& y) const { + + MatrixType sum = MatrixType::Zero(this->m_dimension, this->m_dimension); + + float eta_x = m_eta(x); + float eta_y = m_eta(y); + + + statismo::MatrixType phisAtX = phiAtPoint(x); + statismo::MatrixType phisAtY = phiAtPoint(y); + + double largestTemperedEigenvalue = std::pow(m_eigenvalues(0), (eta_x + eta_y)/2); + + for (unsigned i = 0; i < m_eigenvalues.size(); ++i) { + + float temperedEigenvalue = std::pow(m_eigenvalues(i), (eta_x + eta_y)/2); + + // ignore too small eigenvalues, as they don't contribute much. + // (the eigenvalues are ordered, all the following are smaller and can also be ignored) + if (temperedEigenvalue / largestTemperedEigenvalue < 1e-6) { + break; + } else { + sum += phisAtX.col(i) * phisAtY.col(i).transpose() * temperedEigenvalue; + } + } + // normalize such that the largest eigenvalue is unaffected by the tempering + float normalizationFactor = largestTemperedEigenvalue / m_eigenvalues(0); + sum *= 1.0 / normalizationFactor; + return sum; + } + + + virtual ~SpatiallyVaryingKernel() { + } + + std::string GetKernelInfo() const { + std::ostringstream os; + os << "SpatiallyVaryingKernel"; + return os.str(); + } + + + + + private: + + // returns a d x n matrix holding the value of all n eigenfunctions evaluated at the given point. + const statismo::MatrixType phiAtPoint(const PointType& pt) const { + + statismo::MatrixType v; + if (m_cacheValues) { + // we need to convert the point to a vector, as the function hash_value (required by boost) + // is not defined for an arbitrary point. + const VectorType ptAsVec = this->m_representer->PointToVector(pt); + _phiCacheLock.lock(); + typename CacheType::const_iterator got = m_phiCache.find (ptAsVec); + _phiCacheLock.unlock(); + if (got == m_phiCache.end()) { + v = m_nystrom->computeEigenfunctionsAtPoint(pt); + _phiCacheLock.lock(); + m_phiCache.insert(std::make_pair(ptAsVec, v)); + _phiCacheLock.unlock(); + } else { + v = got->second; + } + } else { + v = m_nystrom->computeEigenfunctionsAtPoint(pt); + } + return v; + } + + + // + // members + + const RepresenterType* m_representer; + boost::scoped_ptr > m_nystrom; + statismo::VectorType m_eigenvalues; + const TemperingFunction& m_eta; + bool m_cacheValues; + mutable CacheType m_phiCache; + mutable boost::mutex _phiCacheLock; +}; + + + +} + +#endif // KERNELCOMBINATORS_H diff --git a/Components/Metrics/ActiveRegistrationModel/Statismo/core/include/Kernels.h b/Components/Metrics/ActiveRegistrationModel/Statismo/core/include/Kernels.h new file mode 100644 index 000000000..b9f54f832 --- /dev/null +++ b/Components/Metrics/ActiveRegistrationModel/Statismo/core/include/Kernels.h @@ -0,0 +1,131 @@ +/** + * This file is part of the statismo library. + * + * Author: Marcel Luethi (marcel.luethi@unibas.ch) + * + * Copyright (c) 2011 University of Basel + * All rights reserved. + * + * Statismo is licensed under the BSD licence (3 clause) license + */ + + +#ifndef __KERNELS_H +#define __KERNELS_H + +#include + +#include +#include +#include + +#include "CommonTypes.h" +#include "Config.h" +#include "ModelInfo.h" +#include "Representer.h" +#include "StatisticalModel.h" + + +namespace statismo { + +/** + * Base class from which all ScalarValuedKernels derive. + */ +template +class ScalarValuedKernel { + public: + + /** + * Create a new scalar valued kernel. + */ + ScalarValuedKernel() { } + + virtual ~ScalarValuedKernel() { + } + + /** + * Evaluate the kernel function at the points x and y + */ + virtual double operator()(const TPoint& x, const TPoint& y) const = 0; + + /** + * Return a description of this kernel + */ + virtual std::string GetKernelInfo() const = 0; + +}; + + +/** + * Base class for all matrix valued kernels + */ +template +class MatrixValuedKernel { + public: + + /** + * Create a new MatrixValuedKernel + */ + MatrixValuedKernel(unsigned dim) : + m_dimension(dim) { + } + + /** + * Evaluate the kernel at the points x and y + */ + virtual MatrixType operator()(const TPoint& x, + const TPoint& y) const = 0; + + /** + * Return the dimensionality of the kernel (i.e. the size of the matrix) + */ + virtual unsigned GetDimension() const { + return m_dimension; + } + ; + virtual ~MatrixValuedKernel() { + } + + /** + * Return a description of this kernel. + */ + virtual std::string GetKernelInfo() const = 0; + + protected: + unsigned m_dimension; + +}; + +template +class StatisticalModelKernel: public MatrixValuedKernel::PointType > { + public: + + typedef Representer RepresenterType; + typedef typename RepresenterType::PointType PointType; + typedef StatisticalModel StatisticalModelType; + + StatisticalModelKernel(const StatisticalModelType* model) : + MatrixValuedKernel(model->GetRepresenter()->GetDimensions()), m_statisticalModel(model) { + } + + virtual ~StatisticalModelKernel() { + } + + inline MatrixType operator()(const PointType& x, const PointType& y) const { + MatrixType m = m_statisticalModel->GetCovarianceAtPoint(x, y); + return m; + } + + std::string GetKernelInfo() const { + return "StatisticalModelKernel"; + } + + private: + const StatisticalModelType* m_statisticalModel; +}; + + + +} // namespace statismo + +#endif // __KERNELS_H diff --git a/Components/Metrics/ActiveRegistrationModel/Statismo/core/include/LowRankGPModelBuilder.h b/Components/Metrics/ActiveRegistrationModel/Statismo/core/include/LowRankGPModelBuilder.h new file mode 100644 index 000000000..d3280f8ef --- /dev/null +++ b/Components/Metrics/ActiveRegistrationModel/Statismo/core/include/LowRankGPModelBuilder.h @@ -0,0 +1,289 @@ +/** + * This file is part of the statismo library. + * + * Author: Marcel Luethi (marcel.luethi@unibas.ch) + * + * Copyright (c) 2011 University of Basel + * All rights reserved. + * + * Statismo is licensed under the BSD licence (3 clause) license + */ + + +#ifndef __LOW_RANK_GP_MODEL_BUILDER_H +#define __LOW_RANK_GP_MODEL_BUILDER_H + +#include + +#include + +#include +#include +#include + +#include "CommonTypes.h" +#include "Config.h" +#include "DataManager.h" +#include "Kernels.h" +#include "ModelInfo.h" +#include "ModelBuilder.h" +#include "Nystrom.h" +#include "Representer.h" +#include "StatisticalModel.h" + +namespace statismo { + + +/** + * This class holds the result of the eigenfunction computation for + * the points with index entries (lowerInd to upperInd) + */ +struct EigenfunctionComputationResult { + + + EigenfunctionComputationResult(unsigned _lowerInd, unsigned _upperInd, + const MatrixType& _resMat) : + lowerInd(_lowerInd), upperInd(_upperInd), resultForPoints(_resMat) { + } + + unsigned lowerInd; + unsigned upperInd; + MatrixType resultForPoints; + + // emulate move semantics, as boost::async seems to depend on it. + EigenfunctionComputationResult& operator=(BOOST_COPY_ASSIGN_REF(EigenfunctionComputationResult) rhs) { // Copy assignment + if (&rhs != this) { + copyMembers(rhs); + } + return *this; + } + + EigenfunctionComputationResult(BOOST_RV_REF(EigenfunctionComputationResult) that) { //Move constructor + copyMembers(that); + } + EigenfunctionComputationResult& operator=(BOOST_RV_REF(EigenfunctionComputationResult) rhs) { //Move assignment + if (&rhs != this) { + copyMembers(rhs); + } + return *this; + } + private: + BOOST_COPYABLE_AND_MOVABLE(EigenfunctionComputationResult) + void copyMembers(const EigenfunctionComputationResult& that) { + lowerInd = that.lowerInd; + upperInd = that.upperInd; + resultForPoints = that.resultForPoints; + } +}; + + +/** + * A model builder for building statistical models that are specified by an arbitrary Gaussian Process. + * For details on the theoretical basis for this type of model builder, see the paper + * + * A unified approach to shape model fitting and non-rigid registration + * Marcel Lüthi, Christoph Jud and Thomas Vetter + * IN: Proceedings of the 4th International Workshop on Machine Learning in Medical Imaging, + * LNCS 8184, pp.66-73 Nagoya, Japan, September 2013 + * + */ + +template +class LowRankGPModelBuilder: public ModelBuilder { + + public: + + typedef Representer RepresenterType; + typedef typename RepresenterType::PointType PointType; + + typedef ModelBuilder Superclass; + typedef typename Superclass::StatisticalModelType StatisticalModelType; + + typedef Domain DomainType; + typedef typename DomainType::DomainPointsListType DomainPointsListType; + + typedef MatrixValuedKernel MatrixValuedKernelType; + + /** + * Factory method to create a new ModelBuilder + */ + static LowRankGPModelBuilder* Create(const RepresenterType* representer) { + return new LowRankGPModelBuilder(representer); + } + + /** + * Destroy the object. + * The same effect can be achieved by deleting the object in the usual + * way using the c++ delete keyword. + */ + void Delete() { + delete this; + } + + /** + * The desctructor + */ + virtual ~LowRankGPModelBuilder() { + } + + + /** + * Build a new model using a zero-mean Gaussian process with given kernel. + * \param kernel: A kernel (or covariance) function + * \param numComponents The number of components used for the low rank approximation. + * \param numPointsForNystrom The number of points used for the Nystrom approximation + * + * \return a new statistical model representing the given Gaussian process + */ + StatisticalModelType* BuildNewZeroMeanModel( + const MatrixValuedKernelType& kernel, unsigned numComponents, + unsigned numPointsForNystrom = 500) const { + + return BuildNewModel(m_representer->IdentitySample(), kernel, numComponents, + numPointsForNystrom); + } + + /** + * Build a new model using a Gaussian process with given mean and kernel. + * \param mean: A dataset that represents the mean (shape or deformation) + * \param kernel: A kernel (or covariance) function + * \param numComponents The number of components used for the low rank approximation. + * \param numPointsForNystrom The number of points used for the Nystrom approximation + * + * \return a new statistical model representing the given Gaussian process + */ + StatisticalModelType* BuildNewModel( + typename RepresenterType::DatasetConstPointerType mean, + const MatrixValuedKernelType& kernel, + unsigned numComponents, + unsigned numPointsForNystrom = 500) const { + + + std::vector domainPoints = m_representer->GetDomain().GetDomainPoints(); + unsigned numDomainPoints = m_representer->GetDomain().GetNumberOfPoints(); + unsigned kernelDim = kernel.GetDimension(); + + + boost::scoped_ptr > nystrom(Nystrom::Create(m_representer, kernel, numComponents, numPointsForNystrom)); + + // we precompute the value of the eigenfunction for each domain point + // and store it later in the pcaBasis matrix. In this way we obtain + // a standard statismo model. + // To save time, we parallelize over the rows + std::vector* > futvec; + + + unsigned numChunks = boost::thread::hardware_concurrency() + 1; + + for (unsigned i = 0; i <= numChunks; i++) { + + unsigned chunkSize = static_cast< unsigned >( ceil( static_cast< float >( numDomainPoints ) / static_cast< float >( numChunks ) ) ); + unsigned lowerInd = i * chunkSize; + unsigned upperInd = + std::min( static_cast< unsigned >(numDomainPoints), + (i + 1) * chunkSize); + + if (lowerInd >= upperInd) { + break; + } + + boost::future* fut = new boost::future( + boost::async(boost::launch::async, boost::bind(&LowRankGPModelBuilder::computeEigenfunctionsForPoints, + this, nystrom.get(), &kernel, numComponents, domainPoints, lowerInd, upperInd))); + futvec.push_back(fut); + } + + MatrixType pcaBasis = MatrixType::Zero(numDomainPoints * kernelDim, numComponents); + + // collect the result + for (unsigned i = 0; i < futvec.size(); i++) { + EigenfunctionComputationResult res = futvec[i]->get(); + pcaBasis.block(res.lowerInd * kernelDim, 0, + (res.upperInd - res.lowerInd) * kernelDim, pcaBasis.cols()) = + res.resultForPoints; + delete futvec[i]; + } + + + VectorType pcaVariance = nystrom->getEigenvalues(); + + RowVectorType mu = m_representer->SampleToSampleVector(mean); + + StatisticalModelType* model = StatisticalModelType::Create( + m_representer, mu, pcaBasis, pcaVariance, 0); + + // the model builder does not use any data. Hence the scores and the datainfo is emtpy + MatrixType scores; // no scores + typename BuilderInfo::DataInfoList dataInfo; + + + typename BuilderInfo::ParameterInfoList bi; + bi.push_back(BuilderInfo::KeyValuePair("NoiseVariance", Utils::toString(0))); + bi.push_back(BuilderInfo::KeyValuePair("KernelInfo", kernel.GetKernelInfo())); + + // finally add meta data to the model info + BuilderInfo builderInfo("LowRankGPModelBuilder", dataInfo, bi); + + ModelInfo::BuilderInfoList biList( 1, builderInfo );; + + ModelInfo info(scores, biList); + model->SetModelInfo(info); + + return model; + } + + + private: + + + + /* + * Compute the eigenfunction value at the poitns with index lowerInd - upperInd. + * Return a result object with the given values. + * This method is used to be able to parallelize the computations. + */ + EigenfunctionComputationResult computeEigenfunctionsForPoints( + const Nystrom* nystrom, + const MatrixValuedKernelType* kernel, unsigned numEigenfunctions, + const std::vector & domainPts, + unsigned lowerInd, unsigned upperInd) const { + + unsigned kernelDim = kernel->GetDimension(); + + assert(upperInd <= domainPts.size()); + + // holds the results of the computation + MatrixType resMat = MatrixType::Zero((upperInd - lowerInd) * kernelDim, + numEigenfunctions); + + // compute the nystrom extension for each point i in domainPts, for which + // i is in the right range + for (unsigned i = lowerInd; i < upperInd; i++) { + + PointType pti = domainPts[i]; + resMat.block((i - lowerInd) * kernelDim, 0, kernelDim, resMat.cols()) = nystrom->computeEigenfunctionsAtPoint(pti); + + } + return EigenfunctionComputationResult(lowerInd, upperInd, resMat); + } + + + + /** + * constructor - only used internally + */ + LowRankGPModelBuilder(const RepresenterType* representer) : + m_representer(representer) { + } + + // purposely not implemented + LowRankGPModelBuilder(const LowRankGPModelBuilder& orig); + LowRankGPModelBuilder& operator=(const LowRankGPModelBuilder& rhs); + + const RepresenterType* m_representer; + +}; + +} // namespace statismo + +#endif // __LOW_RANK_GP_MODEL_BUILDER_H diff --git a/Components/Metrics/ActiveRegistrationModel/Statismo/core/include/ModelBuilder.h b/Components/Metrics/ActiveRegistrationModel/Statismo/core/include/ModelBuilder.h new file mode 100644 index 000000000..b52bccac0 --- /dev/null +++ b/Components/Metrics/ActiveRegistrationModel/Statismo/core/include/ModelBuilder.h @@ -0,0 +1,111 @@ +/* + * This file is part of the statismo library. + * + * Author: Marcel Luethi (marcel.luethi@unibas.ch) + * + * Copyright (c) 2011 University of Basel + * All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions + * are met: + * + * Redistributions of source code must retain the above copyright notice, + * this list of conditions and the following disclaimer. + * + * Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * + * Neither the name of the project's author nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS + * FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + * HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED + * TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR + * PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF + * LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING + * NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + */ + +#ifndef __MODELBUILDER_H_ +#define __MODELBUILDER_H_ + +#include +#include + +#include "CommonTypes.h" +#include "DataManager.h" +#include "StatisticalModel.h" + +namespace statismo { + +/** + * \brief Common base class for all the model builder classes + */ +template +class ModelBuilder { + + public: + typedef Representer RepresenterType; + typedef StatisticalModel StatisticalModelType; + typedef DataManager DataManagerType; + typedef typename DataManagerType::DataItemListType DataItemListType; + + // Values below this tolerance are treated as 0. + static const double TOLERANCE; + + + protected: + + MatrixType ComputeScores(const MatrixType& X, const StatisticalModelType* model) const { + + MatrixType scores(model->GetNumberOfPrincipalComponents(), X.rows()); + for (unsigned i = 0; i < scores.cols(); i++) { + scores.col(i) = model->ComputeCoefficientsForSampleVector(X.row(i)); + } + return scores; + } + + + MatrixType ComputeScores(const DataItemListType& sampleDataList, const StatisticalModelType* model) const { + + unsigned n = sampleDataList.size(); + MatrixType scores(model->GetNumberOfPrincipalComponents(), n); + + unsigned i = 0; + for (typename DataItemListType::const_iterator it = sampleDataList.begin(); + it != sampleDataList.end(); ++it) { + // Todo: for sample or for dataset?? + scores.col(i++) = model->ComputeCoefficientsForSampleVector((*it)->GetSampleVector()); + } + return scores; + } + + + + ModelBuilder() {} + + ModelInfo CollectModelInfo() const; + + private: + // private - to prevent use + ModelBuilder(const ModelBuilder& orig); + ModelBuilder& operator=(const ModelBuilder& rhs); + +}; + +template +const double ModelBuilder::TOLERANCE = 1e-5; + +} // namespace statismo + + +#endif /* __MODELBUILDER_H_ */ diff --git a/Components/Metrics/ActiveRegistrationModel/Statismo/core/include/ModelInfo.h b/Components/Metrics/ActiveRegistrationModel/Statismo/core/include/ModelInfo.h new file mode 100644 index 000000000..de542f2d6 --- /dev/null +++ b/Components/Metrics/ActiveRegistrationModel/Statismo/core/include/ModelInfo.h @@ -0,0 +1,192 @@ +/* + * This file is part of the statismo library. + * + * Author: Marcel Luethi (marcel.luethi@unibas.ch) + * + * Copyright (c) 2011 University of Basel + * All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions + * are met: + * + * Redistributions of source code must retain the above copyright notice, + * this list of conditions and the following disclaimer. + * + * Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * + * Neither the name of the project's author nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS + * FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + * HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED + * TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR + * PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF + * LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING + * NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + */ + + + +#ifndef MODELINFO_H_ +#define MODELINFO_H_ + +#include + +#include + +#include "CommonTypes.h" +#include "itk_H5Cpp.h" + +namespace statismo { + +class BuilderInfo; // forward declaration + +/** + * \brief stores meta information about the model, such as e.g. the name (uri) of the datasets used to build the models, or specific parameters of the modelBuilders. + * + * The ModelInfo object stores the scores and a list of BuilderInfo objects. Each BuilderInfo contains the information (datasets, parameters) that were used to build the model. + * If n model builders had been used in succession to create a model, there will be a list of n builder objects. + * + */ +class ModelInfo { + public: + + typedef std::vector BuilderInfoList; + + /// create an new, empty model info object + ModelInfo(); + + /** + * Creates a new ModelInfo object with the given information + * \param scores A matrix holding the scores + * \param builderInfos A list of BuilderInfo objects + */ + ModelInfo(const MatrixType& scores, const BuilderInfoList& builderInfos); + + /** + * Create a ModelInfo object without specifying any BuilderInfos + * \param scores A matrix holding the scores + */ + ModelInfo(const MatrixType& scores); + + + /// destructor + virtual ~ModelInfo(); + + ModelInfo& operator=(const ModelInfo& rhs); + + ModelInfo(const ModelInfo& orig) { + operator=(orig); + } + + /** + * Returns a list with BuilderInfos + */ + BuilderInfoList GetBuilderInfoList() const; + + /** + * Returns the scores matrix. That is, a matrix where the i-th column corresponds to the + * coefficients of the i-th dataset in the model + */ + const MatrixType& GetScoresMatrix() const; + + /** + * Saves the model info to the given group in the HDF5 file + */ + virtual void Save(const H5::H5Location& publicFg) const; + + /** + * Loads the model info from the given group in the HDF5 file. + */ + virtual void Load(const H5::H5Location& publicFg); + + + private: + + BuilderInfo LoadDataInfoOldStatismoFormat(const H5::H5Location& publicFg) const; + + MatrixType m_scores; + BuilderInfoList m_builderInfo; +}; + +/** + * \brief Holds information about the data and the parameters used by a specific modelbuilder + */ +class BuilderInfo { + + friend class ModelInfo; + + public: + typedef std::pair KeyValuePair; + typedef std::list KeyValueList; + + // Currently all the info entries are just simple list of string pairs. + // We don't want to use maps, as this would sort the items according to the key. + typedef KeyValueList DataInfoList; + typedef KeyValueList ParameterInfoList; + + /** + * Creates a new BuilderInfo object with the given information + */ + BuilderInfo(const std::string& modelBuilderName, const std::string& buildTime, const DataInfoList& di, const ParameterInfoList& pi); + + BuilderInfo(const std::string& modelBuilderName, const DataInfoList& di, const ParameterInfoList& pi); + + /** + * Create a new, empty BilderInfo object + */ + BuilderInfo(); + + /// destructor + virtual ~BuilderInfo(); + + BuilderInfo& operator=(const BuilderInfo& rhs); + + BuilderInfo(const BuilderInfo& orig); + + + /** + * Saves the builder info to the given group in the HDF5 file + */ + virtual void Save(const H5::H5Location& publicFg) const; + + /** + * Loads the builder info from the given group in the HDF5 file. + */ + virtual void Load(const H5::H5Location& publicFg); + + /** + * Returns the data info + */ + const DataInfoList& GetDataInfo() const; + + /** + * Returns the parameter info + */ + const ParameterInfoList& GetParameterInfo() const; + + private: + + + + static void FillKeyValueListFromInfoGroup(const H5::H5Location& group, KeyValueList& keyValueList); + + std::string m_modelBuilderName; + std::string m_buildtime; + DataInfoList m_dataInfo; + ParameterInfoList m_parameterInfo; +}; + +} // namespace statismo + +#endif /* MODELINFO_H_ */ diff --git a/Components/Metrics/ActiveRegistrationModel/Statismo/core/include/Nystrom.h b/Components/Metrics/ActiveRegistrationModel/Statismo/core/include/Nystrom.h new file mode 100644 index 000000000..46277c4cc --- /dev/null +++ b/Components/Metrics/ActiveRegistrationModel/Statismo/core/include/Nystrom.h @@ -0,0 +1,179 @@ +#ifndef NYSTROM_H +#define NYSTROM_H + +/* + * This file is part of the statismo library. + * + * Author: Marcel Luethi (marcel.luethi@unibas.ch) + * + * Copyright (c) 2011 University of Basel + * All rights reserved. + * + * Statismo is licensed under the BSD licence (3 clause) license + */ + +#include + +#include "CommonTypes.h" +#include "Config.h" +#include "Kernels.h" +#include "RandSVD.h" +#include "Representer.h" + +namespace statismo { + +/** + * Computes the Nystrom approximation of a given kernel + * The type parameter T is the type of the dataset (e.g. Mesh, Image) for which the nystom approximation is computed + */ +template +class Nystrom { + public: + + + typedef typename Representer::PointType PointType; + typedef statismo::Domain::PointType> DomainType; + typedef typename DomainType::DomainPointsListType DomainPointsListType; + + static Nystrom* Create(const Representer* representer, const MatrixValuedKernel& kernel, unsigned numEigenfunctions, unsigned numberOfPointsForApproximation) { + return new Nystrom(representer, kernel, numEigenfunctions, numberOfPointsForApproximation); + } + + + /** + * Returns a d x n matrix, which holds the d-dimension value of all the n eigenfunctiosn at the given point + */ + MatrixType computeEigenfunctionsAtPoint(const PointType& pt) const { + + unsigned kernelDim = m_kernel.GetDimension(); + + // for every domain point x in the list, we compute the kernel vector + // kx = (k(x, x1), ... k(x, xm)) + // since the kernel is matrix valued, kx is actually a matrix + MatrixType kxi = MatrixType::Zero(kernelDim, m_nystromPoints.size() * kernelDim); + + for (unsigned j = 0; j < m_nystromPoints.size(); j++) { + kxi.block(0, j * kernelDim, kernelDim, kernelDim) = m_kernel(pt, m_nystromPoints[j]); + } + + + MatrixType resMat = MatrixType::Zero(kernelDim, m_numEigenfunctions); + for (unsigned j = 0; j < m_numEigenfunctions; j++) { + MatrixType x = (kxi * m_nystromMatrix.col(j)); + resMat.block(0, j, kernelDim, 1) = x; + } + return resMat; + } + + + /** + * Returns a vector of size n, where n is the number of eigenfunctions/eigenvalues that were approximated + */ + const VectorType& getEigenvalues() const { + return m_eigenvalues; + } + + + private: + + + Nystrom(const Representer* representer, const MatrixValuedKernel& kernel, unsigned numEigenfunctions, unsigned numberOfPointsForApproximation) + : m_representer(representer), m_kernel(kernel), m_numEigenfunctions(numEigenfunctions) { + + DomainType domain = m_representer->GetDomain(); + m_nystromPoints = getNystromPoints(domain, numberOfPointsForApproximation); + unsigned numDomainPoints = domain.GetNumberOfPoints(); + + // compute a eigenvalue decomposition of the kernel matrix, evaluated at the points used for the + // nystrom approximation + + MatrixType U; // will hold the eigenvectors (principal components) + VectorType D; // will hold the eigenvalues (variance) + computeKernelMatrixDecomposition(&kernel, m_nystromPoints, numEigenfunctions, U, D); + + + // precompute the part of the nystrom approximation, which is independent of the domain point + float normFactor = static_cast(m_nystromPoints.size()) / static_cast(numDomainPoints); + m_nystromMatrix = std::sqrt(normFactor) * (U.leftCols(numEigenfunctions) + * D.topRows(numEigenfunctions).asDiagonal().inverse()); + + m_eigenvalues = (1.0f / normFactor) * D.topRows(numEigenfunctions); + + } + + + /* + * Returns a random set of points from the domain. + * + * @param domain the domain to sample from + * @param numberOfPoints the size of the sample + */ + std::vector getNystromPoints(DomainType& domain, unsigned numberOfPoints) const { + + numberOfPoints = std::min(numberOfPoints, domain.GetNumberOfPoints()); + + std::vector shuffledDomainPoints = domain.GetDomainPoints(); + std::random_shuffle ( shuffledDomainPoints.begin(), shuffledDomainPoints.end() ); + + return std::vector(shuffledDomainPoints.begin(), shuffledDomainPoints.begin() + numberOfPoints); + } + + + + + /** + * Compute the kernel matrix for all points given in xs and + * return a matrix U with the first numComponents eigenvectors and a vector D with + * the corresponding eigenvalues of this kernel matrix + */ + void computeKernelMatrixDecomposition(const MatrixValuedKernel* kernel, + const std::vector& xs, unsigned numComponents, + MatrixType& U, VectorType& D) const { + unsigned kernelDim = kernel->GetDimension(); + + unsigned n = xs.size(); + MatrixTypeDoublePrecision K = MatrixTypeDoublePrecision::Zero( + n * kernelDim, n * kernelDim); + for (unsigned i = 0; i < n; ++i) { + for (unsigned j = i; j < n; ++j) { + MatrixType k_xixj = (*kernel)(xs[i], xs[j]); + for (unsigned d1 = 0; d1 < kernelDim; d1++) { + for (unsigned d2 = 0; d2 < kernelDim; d2++) { + double elem_d1d2 = k_xixj(d1, d2); + K(i * kernelDim + d1, j * kernelDim + d2) = elem_d1d2; + K(j * kernelDim + d2, i * kernelDim + d1) = elem_d1d2; + } + } + } + } + + typedef RandSVD SVDType; + SVDType svd(K, numComponents * kernelDim); + U = svd.matrixU().cast(); + D = svd.singularValues().cast(); + } + + + // private, to prevent use + Nystrom(); + Nystrom& operator=(const Nystrom& rhs); + Nystrom(const Nystrom& orig); + + + // + // members + // + + const Representer* m_representer; + MatrixType m_nystromMatrix; + VectorType m_eigenvalues; + std::vector m_nystromPoints; + const MatrixValuedKernel& m_kernel; + unsigned m_numEigenfunctions; + + +}; + + +} // namespace statismo +#endif // NYSTROM_H diff --git a/Components/Metrics/ActiveRegistrationModel/Statismo/core/include/PCAModelBuilder.h b/Components/Metrics/ActiveRegistrationModel/Statismo/core/include/PCAModelBuilder.h new file mode 100644 index 000000000..1ea84d14e --- /dev/null +++ b/Components/Metrics/ActiveRegistrationModel/Statismo/core/include/PCAModelBuilder.h @@ -0,0 +1,127 @@ +/* + * This file is part of the statismo library. + * + * Author: Marcel Luethi (marcel.luethi@unibas.ch) + * + * Copyright (c) 2011 University of Basel + * All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions + * are met: + * + * Redistributions of source code must retain the above copyright notice, + * this list of conditions and the following disclaimer. + * + * Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * + * Neither the name of the project's author nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS + * FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + * HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED + * TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR + * PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF + * LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING + * NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + */ + +#ifndef __PCAMODELBUILDER_H_ +#define __PCAMODELBUILDER_H_ + +#include +#include + +#include "CommonTypes.h" +#include "Config.h" +#include "DataManager.h" +#include "ModelBuilder.h" +#include "ModelInfo.h" +#include "StatisticalModel.h" + +namespace statismo { + + +/** + * \brief Creates StatisticalModel using Principal Component Analysis. + * + * This class implements the classical PCA based approach to Statistical Models. + */ +template +class PCAModelBuilder : public ModelBuilder { + + + public: + + typedef ModelBuilder Superclass; + typedef typename Superclass::DataManagerType DataManagerType; + typedef typename Superclass::StatisticalModelType StatisticalModelType; + typedef typename DataManagerType::DataItemListType DataItemListType; + + /** + * @brief The EigenValueMethod enum This type is used to specify which decomposition method resp. eigenvalue solver sould be used. Default is JacobiSVD which is the most accurate but for larger systems quite slow. In this case the SelfAdjointEigensolver is more appropriate (especially, if there are more examples than variables). + */ + typedef enum { JacobiSVD, SelfAdjointEigenSolver } EigenValueMethod; + + /** + * Factory method to create a new PCAModelBuilder + */ + static PCAModelBuilder* Create() { + return new PCAModelBuilder(); + } + + /** + * Destroy the object. + * The same effect can be achieved by deleting the object in the usual + * way using the c++ delete keyword. + */ + void Delete() { + delete this; + } + + + /** + * The desctructor + */ + virtual ~PCAModelBuilder() {} + + /** + * Build a new model from the training data provided in the dataManager. + * \param samples A sampleSet holding the data + * \param noiseVariance The variance of N(0, noiseVariance) distributed noise on the points. + * If this parameter is set to 0, we have a standard PCA model. For values > 0 we have a PPCA model. + * \param computeScores Determines whether the scores (the pca coefficients of the examples) are computed and stored as model info + * (computing the scores may take a long time for large models). + * \param method Specifies the method which is used for the decomposition resp. eigenvalue solver. + * + * \return A new Statistical model + * \warning The method allocates a new Statistical Model object, that needs to be deleted by the user. + */ + StatisticalModelType* BuildNewModel(const DataItemListType& samples, double noiseVariance, bool computeScores = true, EigenValueMethod method = JacobiSVD) const; + + + private: + // to prevent use + PCAModelBuilder(); + PCAModelBuilder(const PCAModelBuilder& orig); + PCAModelBuilder& operator=(const PCAModelBuilder& rhs); + + StatisticalModelType* BuildNewModelInternal(const Representer* representer, const MatrixType& X, const VectorType& mu, double noiseVariance, EigenValueMethod method = JacobiSVD) const; + + +}; + +} // namespace statismo + +#include "PCAModelBuilder.hxx" + +#endif /* __PCAMODELBUILDER_H_ */ diff --git a/Components/Metrics/ActiveRegistrationModel/Statismo/core/include/PCAModelBuilder.hxx b/Components/Metrics/ActiveRegistrationModel/Statismo/core/include/PCAModelBuilder.hxx new file mode 100644 index 000000000..95094c331 --- /dev/null +++ b/Components/Metrics/ActiveRegistrationModel/Statismo/core/include/PCAModelBuilder.hxx @@ -0,0 +1,251 @@ +/* + * This file is part of the statismo library. + * + * Author: Marcel Luethi (marcel.luethi@unibas.ch) + * + * Copyright (c) 2011 University of Basel + * All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions + * are met: + * + * Redistributions of source code must retain the above copyright notice, + * this list of conditions and the following disclaimer. + * + * Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * + * Neither the name of the project's author nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS + * FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + * HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED + * TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR + * PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF + * LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING + * NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + */ + +#ifndef __PCAModelBuilder_TXX +#define __PCAModelBuilder_TXX + +#include "PCAModelBuilder.h" + +#include + +#include ITK_EIGEN(SVD) +#include ITK_EIGEN(Eigenvalues) + +#include "CommonTypes.h" +#include "Exceptions.h" + +namespace statismo { + +template +PCAModelBuilder::PCAModelBuilder() + : Superclass() { +} + + +template +typename PCAModelBuilder::StatisticalModelType* +PCAModelBuilder::BuildNewModel(const DataItemListType& sampleDataList, double noiseVariance, bool computeScores, EigenValueMethod method) const { + + unsigned n = sampleDataList.size(); + if (n <= 0) { + throw StatisticalModelException("Provided empty sample set. Cannot build the sample matrix"); + } + + unsigned p = sampleDataList.front()->GetSampleVector().rows(); + const Representer* representer = sampleDataList.front()->GetRepresenter(); + + + // Compute the mean vector mu + VectorType mu = VectorType::Zero(p); + + for (typename DataItemListType::const_iterator it = sampleDataList.begin(); + it != sampleDataList.end(); ++it) { + assert((*it)->GetSampleVector().rows() == p); // all samples must have same number of rows + assert((*it)->GetRepresenter() == representer); // all samples have the same representer + mu += (*it)->GetSampleVector(); + } + mu /= n; + + // Build the mean free sample matrix X0 + MatrixType X0(n, p); + unsigned i = 0; + for (typename DataItemListType::const_iterator it = sampleDataList.begin(); + it != sampleDataList.end(); ++it) { + X0.row(i++) = (*it)->GetSampleVector() - mu; + } + + + + + + // build the model + StatisticalModelType* model = BuildNewModelInternal(representer, X0, mu, noiseVariance, method); + + // compute the scores if requested + MatrixType scores; + if (computeScores) { + scores = this->ComputeScores(sampleDataList, model); + } + + + typename BuilderInfo::ParameterInfoList bi; + bi.push_back(BuilderInfo::KeyValuePair("NoiseVariance ", Utils::toString(noiseVariance))); + + typename BuilderInfo::DataInfoList dataInfo; + i = 0; + for (typename DataItemListType::const_iterator it = sampleDataList.begin(); + it != sampleDataList.end(); + ++it, i++) { + std::ostringstream os; + os << "URI_" << i; + dataInfo.push_back(BuilderInfo::KeyValuePair(os.str().c_str(),(*it)->GetDatasetURI())); + } + + + // finally add meta data to the model info + BuilderInfo builderInfo("PCAModelBuilder", dataInfo, bi); + + ModelInfo::BuilderInfoList biList; + biList.push_back(builderInfo); + + ModelInfo info(scores, biList); + model->SetModelInfo(info); + + return model; +} + + +template +typename PCAModelBuilder::StatisticalModelType* +PCAModelBuilder::BuildNewModelInternal(const Representer* representer, const MatrixType& X0, const VectorType& mu, + double noiseVariance, EigenValueMethod method) const { + + unsigned n = X0.rows(); + unsigned p = X0.cols(); + + switch(method) { + case JacobiSVD: + + typedef Eigen::JacobiSVD SVDType; + typedef Eigen::JacobiSVD SVDDoublePrecisionType; + + // We destinguish the case where we have more variables than samples and + // the case where we have more samples than variable. + // In the first case we compute the (smaller) inner product matrix instead of the full covariance matrix. + // It is known that this has the same non-zero singular values as the covariance matrix. + // Furthermore, it is possible to compute the corresponding eigenvectors of the covariance matrix from the + // decomposition. + + if (n < p) { + // we compute the eigenvectors of the covariance matrix by computing an SVD of the + // n x n inner product matrix 1/(n-1) X0X0^T + MatrixType Cov = X0 * X0.transpose() * 1.0/(n-1); + SVDDoublePrecisionType SVD(Cov.cast(), Eigen::ComputeThinV); + VectorType singularValues = SVD.singularValues().cast(); + MatrixType V = SVD.matrixV().cast(); + + unsigned numComponentsAboveTolerance = ((singularValues.array() - noiseVariance - Superclass::TOLERANCE) > 0).count(); + + // there can be at most n-1 nonzero singular values in this case. Everything else must be due to numerical inaccuracies + unsigned numComponentsToKeep = std::min(numComponentsAboveTolerance, n - 1); + // compute the pseudo inverse of the square root of the singular values + // which is then needed to recompute the PCA basis + VectorType singSqrt = singularValues.array().sqrt(); + VectorType singSqrtInv = VectorType::Zero(singSqrt.rows()); + for (unsigned i = 0; i < numComponentsToKeep; i++) { + assert(singSqrt(i) > Superclass::TOLERANCE); + singSqrtInv(i) = 1.0 / singSqrt(i); + } + + if (numComponentsToKeep == 0) { + throw StatisticalModelException("All the eigenvalues are below the given tolerance. Model cannot be built."); + } + + // we recover the eigenvectors U of the full covariance matrix from the eigenvectors V of the inner product matrix. + // We use the fact that if we decompose X as X=UDV^T, then we get X^TX = UD^2U^T and XX^T = VD^2V^T (exploiting the orthogonormality + // of the matrix U and V from the SVD). The additional factor sqrt(n-1) is to compensate for the 1/sqrt(n-1) in the formula + // for the covariance matrix. + + MatrixType pcaBasis = X0.transpose() * V * singSqrtInv.asDiagonal(); + pcaBasis /= sqrt(n - 1.0); + pcaBasis.conservativeResize(Eigen::NoChange, numComponentsToKeep); + + + VectorType sampleVarianceVector = singularValues.topRows(numComponentsToKeep); + VectorType pcaVariance = (sampleVarianceVector - VectorType::Ones(numComponentsToKeep) * noiseVariance); + + StatisticalModelType* model = StatisticalModelType::Create(representer, mu, pcaBasis, pcaVariance, noiseVariance); + + return model; + } else { + // we compute an SVD of the full p x p covariance matrix 1/(n-1) X0^TX0 directly + SVDType SVD(X0.transpose() * X0, Eigen::ComputeThinU); + VectorType singularValues = SVD.singularValues(); + singularValues /= (n - 1.0); + unsigned numComponentsToKeep = ((singularValues.array() - noiseVariance - Superclass::TOLERANCE) > 0).count(); + MatrixType pcaBasis = SVD.matrixU(); + + pcaBasis.conservativeResize(Eigen::NoChange, numComponentsToKeep); + + if (numComponentsToKeep == 0) { + throw StatisticalModelException("All the eigenvalues are below the given tolerance. Model cannot be built."); + } + + VectorType sampleVarianceVector = singularValues.topRows(numComponentsToKeep); + VectorType pcaVariance = (sampleVarianceVector - VectorType::Ones(numComponentsToKeep) * noiseVariance); + StatisticalModelType* model = StatisticalModelType::Create(representer, mu, pcaBasis, pcaVariance, noiseVariance); + return model; + } + break; + + case SelfAdjointEigenSolver: { + // we compute the eigenvalues/eigenvectors of the full p x p covariance matrix 1/(n-1) X0^TX0 directly + + typedef Eigen::SelfAdjointEigenSolver SelfAdjointEigenSolver; + SelfAdjointEigenSolver es; + es.compute(X0.transpose() * X0); + VectorType eigenValues = es.eigenvalues().reverse(); // SelfAdjointEigenSolver orders the eigenvalues in increasing order + eigenValues /= (n -1.0); + + + unsigned numComponentsToKeep = ((eigenValues.array() - noiseVariance - Superclass::TOLERANCE) > 0).count(); + MatrixType pcaBasis = es.eigenvectors().rowwise().reverse(); + pcaBasis.conservativeResize(Eigen::NoChange, numComponentsToKeep); + + + if (numComponentsToKeep == 0) { + throw StatisticalModelException("All the eigenvalues are below the given tolerance. Model cannot be built."); + } + + VectorType sampleVarianceVector = eigenValues.topRows(numComponentsToKeep); + VectorType pcaVariance = (sampleVarianceVector - VectorType::Ones(numComponentsToKeep) * noiseVariance); + StatisticalModelType* model = StatisticalModelType::Create(representer, mu, pcaBasis, pcaVariance, noiseVariance); + return model; + } + break; + + default: + throw StatisticalModelException("Unrecognized decomposition/eigenvalue solver method."); + return 0; + break; + } +} + + +} // namespace statismo + +#endif diff --git a/Components/Metrics/ActiveRegistrationModel/Statismo/core/include/PosteriorModelBuilder.h b/Components/Metrics/ActiveRegistrationModel/Statismo/core/include/PosteriorModelBuilder.h new file mode 100644 index 000000000..737858e71 --- /dev/null +++ b/Components/Metrics/ActiveRegistrationModel/Statismo/core/include/PosteriorModelBuilder.h @@ -0,0 +1,209 @@ +/* + * This file is part of the statismo library. + * + * Author: Marcel Luethi (marcel.luethi@unibas.ch) + * + * Copyright (c) 2011 University of Basel + * All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions + * are met: + * + * Redistributions of source code must retain the above copyright notice, + * this list of conditions and the following disclaimer. + * + * Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * + * Neither the name of the project's author nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS + * FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + * HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED + * TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR + * PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF + * LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING + * NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + */ + + +#ifndef __POSTERIORMODELBUILDER_H_ +#define __POSTERIORMODELBUILDER_H_ + +#include +#include + +#include "CommonTypes.h" +#include "Config.h" +#include "DataManager.h" +#include "ModelBuilder.h" +#include "Representer.h" +#include "StatisticalModel.h" + +namespace statismo { + + +/** + * \brief Given a statistical model (prior) and a set of point constraints (likelihood), generate a new PCA model (posterior). + * + * This class builds a StatisticalModel, just as PCAModelBuilder. However, in addition to the data, + * this model builder also takes as input a set of point constraints, i.e. known values for points. + * The resulting model will satisfy these constraints, and thus has a much lower variability than an + * unconstrained model would have. + * + * For mathematical detailes see the paper + * Posterior Shape Models + * Thomas Albrecht, Marcel Luethi, Thomas Gerig, Thomas Vetter + * Medical Image Analysis 2013 + * + * Add method that allows for the use of the pointId in the constraint. + */ +template +class PosteriorModelBuilder : public ModelBuilder { + public: + + typedef Representer RepresenterType; + typedef ModelBuilder Superclass; + typedef typename Superclass::DataManagerType DataManagerType; + typedef typename Superclass::StatisticalModelType StatisticalModelType; + typedef typename RepresenterType::ValueType ValueType; + typedef typename RepresenterType::PointType PointType; + typedef typename StatisticalModelType::PointValueListType PointValueListType; + typedef typename DataManagerType::DataItemListType DataItemListType; + + + typedef typename StatisticalModelType::PointValuePairType PointValuePairType; + typedef typename StatisticalModelType::PointCovarianceMatrixType PointCovarianceMatrixType; + typedef typename StatisticalModelType::PointValueWithCovariancePairType PointValueWithCovariancePairType; + typedef typename StatisticalModelType::PointValueWithCovarianceListType PointValueWithCovarianceListType; + + /** + * Factory method to create a new PosteriorModelBuilder + * \param representer The representer + */ + static PosteriorModelBuilder* Create() { + return new PosteriorModelBuilder(); + } + + /** + * Destroy the object. + * The same effect can be achieved by deleting the object in the usual + * way using the c++ delete keyword. + */ + void Delete() { + delete this; + } + + /** + * destructor + */ + virtual ~PosteriorModelBuilder() {} + + /** + * Builds a new model from the data provided in the dataManager, and the given constraints. + * This version of the function assumes a noise with a uniform uncorrelated variance + * of the form pointValueNoiseVariance * identityMatrix at every given point. + * \param DataItemList The list holding the data the model is built from + * \param pointValues A list of (point, value) pairs with the known values. + * \param pointValueNoiseVariance The variance of the estimated error at the known points (the pointValues) + * \param noiseVariance The variance of the noise assumed on our data + * \return a new statistical model + * + * \warning The returned model needs to be explicitly deleted by the user of this method. + */ + StatisticalModelType* BuildNewModel(const DataItemListType& dataItemList, + const PointValueListType& pointValues, + double pointValueNoiseVariance, + double noiseVariance) const; + + + /** + * Builds a new model from the data provided in the dataManager, and the given constraints. + * For this version of the function, the covariance matrix of the noise needs to be specified for + * every point. These covariance matrices are passed in the pointValuesWithCovariance list. + * + * \param DataItemList The list holding the data the model is built from + * \param pointValuesWithCovariance A list of ((point,value), covarianceMatrix) for each known value. + * \param noiseVariance The variance of the noise assumed on our data + * \return a new statistical model + * + * \warning The returned model needs to be explicitly deleted by the user of this method. + */ + StatisticalModelType* BuildNewModel(const DataItemListType& DataItemList, + const PointValueWithCovarianceListType& pointValuesWithCovariance, + double noiseVariance) const; + + + + /** + * Builds a new StatisticalModel given a StatisticalModel and the given constraints. + * If we interpret the given model as a prior distribution over the modeled objects, + * the resulting model can (loosely) be interpreted as the posterior distribution, + * after having observed the data given in the PointValues. + * This version of the function assumes a noise with a uniform uncorrelated variance + * of the form pointValueNoiseVariance * identityMatrix at every given point. + * + * \param model A statistical model. + * \param pointValues A list of (point, value) pairs with the known values. + * \param pointValueNoiseVariance The variance of the estimated error at the known points (the pointValues) + * \param computeScores Determines whether the scores are computed and stored in the model. + * \return a new statistical model + * + * \warning The returned model needs to be explicitly deleted by the user of this method. + */ + StatisticalModelType* BuildNewModelFromModel(const StatisticalModelType* model, const PointValueListType& pointValues, double pointValueNoiseVariance, bool computeScores=true) const; + + + /** + * Builds a new StatisticalModel given a StatisticalModel and the given constraints. + * If we interpret the given model as a prior distribution over the modeled objects, + * the resulting model can (loosely) be interpreted as the posterior distribution, + * after having observed the data given in the PointValues. + * For this version of the function, the covariance matrix of the noise needs to be specified for + * every point. These covariance matrices are passed in the pointValuesWithCovariance list. + * + * \param model A statistical model. + * \param pointValuesWithCovariance A list of ((point,value), covarianceMatrix) for each known value. + * \param computeScores Determines whether the scores are computed and stored in the model. + * \return a new statistical model + * + * \warning The returned model needs to be explicitly deleted by the user of this method. + */ + StatisticalModelType* BuildNewModelFromModel(const StatisticalModelType* model, + const PointValueWithCovarianceListType& pointValuesWithCovariance, + bool computeScores=true) const; + + /** + * A convenience function to create a PointValueWithCovarianceList with uniform variance + * + * \param pointValues A list of (point, value) pairs with the known values. + * \param pointValueNoiseVariance The variance of the estimated error at the known points (the pointValues) + * \return a PointValueWithCovarianceListType with the given uniform variance + * + * \warning The returned model needs to be explicitly deleted by the user of this method. + */ + PointValueWithCovarianceListType TrivialPointValueWithCovarianceListWithUniformNoise(const PointValueListType& pointValues, + double pointValueNoiseVariance) const; + + private: + PosteriorModelBuilder(); + PosteriorModelBuilder(const PosteriorModelBuilder& orig); + PosteriorModelBuilder& operator=(const PosteriorModelBuilder& rhs); + + +}; + +} // namespace statismo + +#include "PosteriorModelBuilder.hxx" + +#endif /* __POSTERIORMODELBUILDER_H_ */ diff --git a/Components/Metrics/ActiveRegistrationModel/Statismo/core/include/PosteriorModelBuilder.hxx b/Components/Metrics/ActiveRegistrationModel/Statismo/core/include/PosteriorModelBuilder.hxx new file mode 100644 index 000000000..d4c86a5ca --- /dev/null +++ b/Components/Metrics/ActiveRegistrationModel/Statismo/core/include/PosteriorModelBuilder.hxx @@ -0,0 +1,269 @@ +/* + * This file is part of the statismo library. + * + * Author: Marcel Luethi (marcel.luethi@unibas.ch) + * + * Copyright (c) 2011 University of Basel + * All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions + * are met: + * + * Redistributions of source code must retain the above copyright notice, + * this list of conditions and the following disclaimer. + * + * Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * + * Neither the name of the project's author nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS + * FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + * HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED + * TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR + * PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF + * LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING + * NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + */ + +#ifndef __PosteriorModelBuilder_hxx +#define __PosteriorModelBuilder_hxx + +#include "PosteriorModelBuilder.h" + +#include + +#include + +#include "CommonTypes.h" +#include "PCAModelBuilder.h" + +namespace statismo { + +// +// PosteriorModelBuilder +// +// + +template +PosteriorModelBuilder::PosteriorModelBuilder() + : Superclass() { +} + + +template +typename PosteriorModelBuilder::StatisticalModelType* +PosteriorModelBuilder::BuildNewModel( + const DataItemListType& sampleDataList, + const PointValueListType& pointValues, + double pointValuesNoiseVariance, + double noiseVariance) const { + return BuildNewModel(sampleDataList, TrivialPointValueWithCovarianceListWithUniformNoise(pointValues, pointValuesNoiseVariance), noiseVariance); +} + + +template +typename PosteriorModelBuilder::StatisticalModelType* +PosteriorModelBuilder::BuildNewModelFromModel( + const StatisticalModelType* inputModel, + const PointValueListType& pointValues, + double pointValuesNoiseVariance, + bool computeScores) const { + + return BuildNewModelFromModel(inputModel, TrivialPointValueWithCovarianceListWithUniformNoise(pointValues,pointValuesNoiseVariance), computeScores); + +} + +template +typename PosteriorModelBuilder::PointValueWithCovarianceListType +PosteriorModelBuilder::TrivialPointValueWithCovarianceListWithUniformNoise( + const PointValueListType& pointValues, double pointValueNoiseVariance) const { + + const MatrixType pointCovarianceMatrix = pointValueNoiseVariance * MatrixType::Identity(3,3); + PointValueWithCovarianceListType pvcList;//(pointValues.size()); + + + for (typename PointValueListType::const_iterator it = pointValues.begin(); it != pointValues.end(); ++it) { + pvcList.push_back(PointValueWithCovariancePairType(*it,pointCovarianceMatrix)); + } + + return pvcList; + +} + + +template +typename PosteriorModelBuilder::StatisticalModelType* +PosteriorModelBuilder::BuildNewModel( + const DataItemListType& sampleDataList, + const PointValueWithCovarianceListType& pointValuesWithCovariance, + double noiseVariance) const { + typedef PCAModelBuilder PCAModelBuilderType; + PCAModelBuilderType* modelBuilder = PCAModelBuilderType::Create(); + StatisticalModelType* model = modelBuilder->BuildNewModel(sampleDataList, noiseVariance); + StatisticalModelType* PosteriorModel = BuildNewModelFromModel(model, pointValuesWithCovariance, noiseVariance); + delete modelBuilder; + delete model; + return PosteriorModel; +} + + +template +typename PosteriorModelBuilder::StatisticalModelType* +PosteriorModelBuilder::BuildNewModelFromModel( + const StatisticalModelType* inputModel, + const PointValueWithCovarianceListType& pointValuesWithCovariance, + bool computeScores) const { + + typedef statismo::Representer RepresenterType; + + const RepresenterType* representer = inputModel->GetRepresenter(); + + + // The naming of the variables correspond to those used in the paper + // Posterior Shape Models, + // Thomas Albrecht, Marcel Luethi, Thomas Gerig, Thomas Vetter + // + const MatrixType& Q = inputModel->GetPCABasisMatrix(); + const VectorType& mu = inputModel->GetMeanVector(); + + // this method only makes sense for a proper PPCA model (e.g. the noise term is properly defined) + // if the model has zero noise, we assume a small amount of noise + double rho2 = std::max((double) inputModel->GetNoiseVariance(), (double) Superclass::TOLERANCE); + + unsigned dim = representer->GetDimensions(); + + + // build the part matrices with , considering only the points that are fixed + // + unsigned numPrincipalComponents = inputModel->GetNumberOfPrincipalComponents(); + MatrixType Q_g(pointValuesWithCovariance.size()* dim, numPrincipalComponents); + VectorType mu_g(pointValuesWithCovariance.size() * dim); + VectorType s_g(pointValuesWithCovariance.size() * dim); + + MatrixType LQ_g(pointValuesWithCovariance.size()* dim, numPrincipalComponents); + + unsigned i = 0; + for (typename PointValueWithCovarianceListType::const_iterator it = pointValuesWithCovariance.begin(); it != pointValuesWithCovariance.end(); ++it) { + VectorType val = representer->PointSampleToPointSampleVector(it->first.second); + unsigned pt_id = representer->GetPointIdForPoint(it->first.first); + + // In the formulas, we actually need the precision matrix, which is the inverse of the covariance. + const MatrixType pointPrecisionMatrix = it->second.inverse(); + + // Get the three rows pertaining to this point: + const MatrixType Qrows_for_pt_id = Q.block(pt_id * dim, 0, dim, numPrincipalComponents); + + Q_g.block(i * dim, 0, dim, numPrincipalComponents) = Qrows_for_pt_id; + mu_g.block(i * dim, 0, dim, 1) = mu.block(pt_id * dim, 0, dim, 1); + s_g.block(i * dim, 0, dim, 1) = val; + + LQ_g.block(i * dim, 0, dim, numPrincipalComponents) = pointPrecisionMatrix * Qrows_for_pt_id; + i++; + } + + VectorType D2 = inputModel->GetPCAVarianceVector().array(); + + const MatrixType& Q_gT = Q_g.transpose(); + + MatrixType M = Q_gT * LQ_g; + M.diagonal() += VectorType::Ones(Q_g.cols()); + + MatrixTypeDoublePrecision Minv = M.cast().inverse(); + + // the MAP solution for the latent variables (coefficients) + VectorType coeffs = Minv.cast() * LQ_g.transpose() * (s_g - mu_g); + + // the MAP solution in the sample space + VectorType mu_c = inputModel->GetRepresenter()->SampleToSampleVector(inputModel->DrawSample(coeffs)); + + const VectorType& pcaVariance = inputModel->GetPCAVarianceVector(); + VectorTypeDoublePrecision pcaSdev = pcaVariance.cast().array().sqrt(); + + VectorType D2MinusRho = D2 - VectorType::Ones(D2.rows()) * rho2; + // the values of D2 can be negative. We need to be careful when taking the root + for (unsigned i = 0; i < D2MinusRho.rows(); i++) { + D2MinusRho(i) = std::max((ScalarType) 0, D2(i)); + } + VectorType D2MinusRhoSqrt = D2MinusRho.array().sqrt(); + + + typedef Eigen::JacobiSVD SVDType; + MatrixTypeDoublePrecision innerMatrix = D2MinusRhoSqrt.cast().asDiagonal() * Minv * D2MinusRhoSqrt.cast().asDiagonal(); + SVDType svd(innerMatrix, Eigen::ComputeThinU); + + + // SVD of the inner matrix + VectorType D_c = svd.singularValues().cast(); + + // Todo: Maybe it is possible to do this with Q, so that we don"t need to get U as well. + MatrixType U_c = inputModel->GetOrthonormalPCABasisMatrix() * svd.matrixU().cast(); + + StatisticalModelType* PosteriorModel = StatisticalModelType::Create(representer , mu_c, U_c, D_c, rho2); + + // Write the parameters used to build the models into the builderInfo + + typename ModelInfo::BuilderInfoList builderInfoList = inputModel->GetModelInfo().GetBuilderInfoList(); + + BuilderInfo::ParameterInfoList bi; + bi.push_back(BuilderInfo::KeyValuePair("NoiseVariance ", Utils::toString(rho2))); + bi.push_back(BuilderInfo::KeyValuePair("FixedPointsVariance ", Utils::toString(0.2))); +// + BuilderInfo::DataInfoList di; + + unsigned pt_no = 0; + for (typename PointValueWithCovarianceListType::const_iterator it = pointValuesWithCovariance.begin(); it != pointValuesWithCovariance.end(); ++it) { + VectorType val = representer->PointSampleToPointSampleVector(it->first.second); + + // TODO we looked up the PointId for the same point before. Having it here again is inefficient. + unsigned pt_id = representer->GetPointIdForPoint(it->first.first); + std::ostringstream keySStream; + keySStream << "Point constraint " << pt_no; + std::ostringstream valueSStream; + valueSStream << "(" << pt_id << ", ("; + + for (unsigned d = 0; d < dim - 1; d++) { + valueSStream << val[d] << ","; + } + valueSStream << val[dim -1]; + valueSStream << "))"; + di.push_back(BuilderInfo::KeyValuePair(keySStream.str(), valueSStream.str())); + pt_no++; + } + + + BuilderInfo builderInfo("PosteriorModelBuilder", di, bi); + builderInfoList.push_back(builderInfo); + + MatrixType inputScores = inputModel->GetModelInfo().GetScoresMatrix(); + MatrixType scores = MatrixType::Zero(inputScores.rows(), inputScores.cols()); + + if (computeScores == true) { + + // get the scores from the input model + for (unsigned i = 0; i < inputScores.cols(); i++) { + // reconstruct the sample from the input model and project it back into the model + typename RepresenterType::DatasetPointerType ds = inputModel->DrawSample(inputScores.col(i)); + scores.col(i) = PosteriorModel->ComputeCoefficients(ds); + representer->DeleteDataset(ds); + } + } + ModelInfo info(scores, builderInfoList); + PosteriorModel->SetModelInfo(info); + + return PosteriorModel; + +} + +} // namespace statismo + +#endif diff --git a/Components/Metrics/ActiveRegistrationModel/Statismo/core/include/RandSVD.h b/Components/Metrics/ActiveRegistrationModel/Statismo/core/include/RandSVD.h new file mode 100644 index 000000000..70174bfa0 --- /dev/null +++ b/Components/Metrics/ActiveRegistrationModel/Statismo/core/include/RandSVD.h @@ -0,0 +1,109 @@ +/* + * This file is part of the statismo library. + * + * Author: Marcel Luethi (marcel.luethi@unibas.ch) + * + * Copyright (c) 2011 University of Basel + * All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions + * are met: + * + * Redistributions of source code must retain the above copyright notice, + * this list of conditions and the following disclaimer. + * + * Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * + * Neither the name of the project's author nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS + * FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + * HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED + * TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR + * PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF + * LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING + * NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + */ + +#ifndef __RANDSVD_H +#define __RANDSVD_H + +#include + +#include +#include + +#include + +#include + +namespace statismo { +/** + * TODO comment and add reference to paper + */ +template +class RandSVD { + public: + + typedef Eigen::Matrix VectorType; + typedef Eigen::Matrix MatrixType; + + RandSVD(const MatrixType& A, unsigned k) { + + unsigned n = A.rows(); + + + static boost::minstd_rand randgen(static_cast(time(0))); + static boost::normal_distribution<> dist(0, 1); + static boost::variate_generator > r(randgen, dist); + + // create gaussian random amtrix + MatrixType Omega(n, k); + for (unsigned i =0; i < n ; i++) { + for (unsigned j = 0; j < k ; j++) { + Omega(i,j) = r(); + } + } + + + MatrixType Y = A * A.transpose() * A * Omega; + Eigen::FullPivHouseholderQR qr(Y); + MatrixType Q = qr.matrixQ().leftCols(k + k); + + MatrixType B = Q.transpose() * A; + + typedef Eigen::JacobiSVD SVDType; + SVDType SVD(B, Eigen::ComputeThinU); + MatrixType Uhat = SVD.matrixU(); + m_D = SVD.singularValues(); + m_U = (Q * Uhat).leftCols(k); + } + + MatrixType matrixU() const { + return m_U; + } + + VectorType singularValues() const { + return m_D; + } + + + private: + VectorType m_D; + MatrixType m_U; +}; + + +} // namespace statismo; +#endif // __LANCZOS_H diff --git a/Components/Metrics/ActiveRegistrationModel/Statismo/core/include/ReducedVarianceModelBuilder.h b/Components/Metrics/ActiveRegistrationModel/Statismo/core/include/ReducedVarianceModelBuilder.h new file mode 100644 index 000000000..48df50558 --- /dev/null +++ b/Components/Metrics/ActiveRegistrationModel/Statismo/core/include/ReducedVarianceModelBuilder.h @@ -0,0 +1,132 @@ +/* + * This file is part of the statismo library. + * + * Author: Marcel Luethi (marcel.luethi@unibas.ch) + * + * Copyright (c) 2011 University of Basel + * All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions + * are met: + * + * Redistributions of source code must retain the above copyright notice, + * this list of conditions and the following disclaimer. + * + * Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * + * Neither the name of the project's author nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS + * FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + * HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED + * TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR + * PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF + * LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING + * NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + */ + +#ifndef __ReducedVarianceModelBuilder_H_ +#define __ReducedVarianceModelBuilder_H_ + +#include +#include + +#include "Config.h" +#include "CommonTypes.h" +#include "DataManager.h" +#include "ModelBuilder.h" +#include "ModelInfo.h" +#include "StatismoUtils.h" +#include "StatisticalModel.h" + +namespace statismo { + + +/** + * \brief Builds a new model which retains only the specified total variance + * + */ +template +class ReducedVarianceModelBuilder : public ModelBuilder { + + + public: + + typedef ModelBuilder Superclass; + typedef typename Superclass::StatisticalModelType StatisticalModelType; + + /** + * Factory method to create a new ReducedVarianceModelBuilder + */ + static ReducedVarianceModelBuilder* Create() { + return new ReducedVarianceModelBuilder(); + } + + /** + * Destroy the object. + * The same effect can be achieved by deleting the object in the usual + * way using the c++ delete keyword. + */ + void Delete() { + delete this; + } + + + /** + * The desctructor + */ + virtual ~ReducedVarianceModelBuilder() {} + + /** + * Build a new model from the given model, which retains only the leading principal components + * + * \param model A statistical model. + * \param numberOfPrincipalComponents, + * \return a new statistical model + * + * \warning The returned model needs to be explicitly deleted by the user of this method. + */ + StatisticalModelType* BuildNewModelWithLeadingComponents(const StatisticalModelType* model, unsigned numberOfPrincipalComponents) const; + + + /** + * Build a new model from the given model, which retains only the specified variance + * + * \param model A statistical model. + * \param totalVariance, The fraction of the variance to be retained + * \return a new statistical model + * + * \warning The returned model needs to be explicitly deleted by the user of this method. + */ + StatisticalModelType* BuildNewModelWithVariance(const StatisticalModelType* model, double totalVariance) const; + + + is_deprecated StatisticalModelType* BuildNewModelFromModel(const StatisticalModelType* model, double totalVariance) const ; + + + private: + // to prevent use + ReducedVarianceModelBuilder(); + ReducedVarianceModelBuilder(const ReducedVarianceModelBuilder& orig); + ReducedVarianceModelBuilder& operator=(const ReducedVarianceModelBuilder& rhs); + + +}; + + + +} // namespace statismo + +#include "ReducedVarianceModelBuilder.hxx" + +#endif /* __ReducedVarianceModelBuilder_H_ */ diff --git a/Components/Metrics/ActiveRegistrationModel/Statismo/core/include/ReducedVarianceModelBuilder.hxx b/Components/Metrics/ActiveRegistrationModel/Statismo/core/include/ReducedVarianceModelBuilder.hxx new file mode 100644 index 000000000..959ae8d64 --- /dev/null +++ b/Components/Metrics/ActiveRegistrationModel/Statismo/core/include/ReducedVarianceModelBuilder.hxx @@ -0,0 +1,128 @@ +/* + * This file is part of the statismo library. + * + * Author: Marcel Luethi (marcel.luethi@unibas.ch) + * + * Copyright (c) 2011 University of Basel + * All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions + * are met: + * + * Redistributions of source code must retain the above copyright notice, + * this list of conditions and the following disclaimer. + * + * Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * + * Neither the name of the project's author nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS + * FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + * HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED + * TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR + * PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF + * LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING + * NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + */ + +#ifndef __ReducedVarianceModelBuilder_hxx +#define __ReducedVarianceModelBuilder_hxx + +#include "ReducedVarianceModelBuilder.h" + +#include + +#include ITK_EIGEN(SVD) +#include ITK_EIGEN(Eigenvalues) + +#include "CommonTypes.h" +#include "Exceptions.h" + +namespace statismo { + +template +ReducedVarianceModelBuilder::ReducedVarianceModelBuilder() + : Superclass() { +} + +template +typename ReducedVarianceModelBuilder::StatisticalModelType* +ReducedVarianceModelBuilder::BuildNewModelWithLeadingComponents( + const StatisticalModelType* inputModel, + unsigned numberOfPrincipalComponents) const + +{ + StatisticalModelType* reducedModel = StatisticalModelType::Create( + inputModel->GetRepresenter(), + inputModel->GetMeanVector(), + inputModel->GetOrthonormalPCABasisMatrix().leftCols(numberOfPrincipalComponents), + inputModel->GetPCAVarianceVector().topRows(numberOfPrincipalComponents), + inputModel->GetNoiseVariance()); + + // Write the parameters used to build the models into the builderInfo + typename ModelInfo::BuilderInfoList builderInfoList = inputModel->GetModelInfo().GetBuilderInfoList(); + + BuilderInfo::ParameterInfoList bi; + bi.push_back(BuilderInfo::KeyValuePair("NumberOfPincipalComponents ", Utils::toString(numberOfPrincipalComponents))); + + BuilderInfo::DataInfoList di; + + BuilderInfo builderInfo("ReducedVarianceModelBuilder", di, bi); + builderInfoList.push_back(builderInfo); + + // If the scores matrix is not set, or if we have for some reasons not as many score entries as the number of principal components, + // we simply work with what is there. + unsigned numComponentsForScores = std::min(static_cast(inputModel->GetModelInfo().GetScoresMatrix().rows()), numberOfPrincipalComponents); + + ModelInfo info(inputModel->GetModelInfo().GetScoresMatrix().topRows(numComponentsForScores), builderInfoList); + reducedModel->SetModelInfo(info); + + return reducedModel; + +} + + + +template +typename ReducedVarianceModelBuilder::StatisticalModelType* +ReducedVarianceModelBuilder::BuildNewModelWithVariance( + const StatisticalModelType* inputModel, + double totalVariance) const { + + VectorType pcaVariance = inputModel->GetPCAVarianceVector(); + double modelVariance = pcaVariance.sum(); + + //count the number of modes required for the model + double cumulatedVariance = 0; + unsigned numComponentsToReachPrescribedVariance = 0; + for (unsigned i = 0; i < pcaVariance.size(); i++) { + cumulatedVariance += pcaVariance(i); + numComponentsToReachPrescribedVariance++; + if (cumulatedVariance / modelVariance >= totalVariance) + break; + } + return BuildNewModelWithLeadingComponents(inputModel, numComponentsToReachPrescribedVariance); +} + +template +typename ReducedVarianceModelBuilder::StatisticalModelType* +ReducedVarianceModelBuilder::BuildNewModelFromModel( + const StatisticalModelType* inputModel, + double totalVariance) const { + + return BuildNewModelWithVariance(inputModel, totalVariance); +} + +} // namespace statismo + +#endif diff --git a/Components/Metrics/ActiveRegistrationModel/Statismo/core/include/Representer.h b/Components/Metrics/ActiveRegistrationModel/Statismo/core/include/Representer.h new file mode 100644 index 000000000..dc02cccca --- /dev/null +++ b/Components/Metrics/ActiveRegistrationModel/Statismo/core/include/Representer.h @@ -0,0 +1,318 @@ +/* + * This file is part of the statismo library. + * + * Author: Marcel Luethi (marcel.luethi@unibas.ch) + * + * Copyright (c) 2011 University of Basel + * All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions + * are met: + * + * Redistributions of source code must retain the above copyright notice, + * this list of conditions and the following disclaimer. + * + * Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * + * Neither the name of the project's author nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS + * FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + * HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED + * TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR + * PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF + * LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING + * NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + */ +#ifndef REPRESENTER_H_ +#define REPRESENTER_H_ + +#include +#include + +#include "CommonTypes.h" +#include "Domain.h" + +/** + * \brief Provides the interface between statismo and the dataset type the application uses. + * + * A Representer is a type that provides the connection between the statismo library + * and the application. It distinguishes three different representations of the data, and provides methods for conversion between those representations: + * - a Dataset, typically as read from a file on the disk + * - a Sample, which is a geometric (generally a rigid or affine) transform of the dataset + * - a SampleVector, which is an internal representation (vector) useful from the statistical analysis. + * + * In the following the methods and types that have to be implemented to write a new + * Representer for your application are given. + * + * \warning This class is never actually used, but serves only for documentation purposes. + */ +//RB: would it be possible to make all representers inherit from it, so as to strictly enforce the interface? +namespace statismo { + +template +class RepresenterTraits { +}; + +template +class Representer { + + public: + + enum RepresenterDataType { + UNKNOWN = 0, + POINT_SET = 1, + POLYGON_MESH = 2, + VOLUME_MESH = 3, + IMAGE = 4, + VECTOR = 5, + CUSTOM = 99 + }; + + static RepresenterDataType TypeFromString(const std::string& s) { + if (s == "POINT_SET") + return POINT_SET; + else if (s == "POLYGON_MESH") + return POLYGON_MESH; + else if (s == "VOLUME_MESH") + return VOLUME_MESH; + else if (s == "IMAGE") + return IMAGE; + else if (s == "VECTOR") + return VECTOR; + else if (s == "CUSTOM") + return CUSTOM; + else + return UNKNOWN; + } + + static std::string TypeToString(const RepresenterDataType& type) { + switch (type) { + case POINT_SET: { + return "POINT_SET"; + break; + } + case POLYGON_MESH: { + return "POLYGON_MESH"; + break; + } + case VOLUME_MESH: { + return "VOLUME_MESH"; + break; + } + case IMAGE: { + return "IMAGE"; + break; + } + case VECTOR: { + return "VECTOR"; + break; + } + case CUSTOM: { + return "CUSTOM"; + break; + } + default: { + return "UNKNOWN"; + } + } + } + + /** + * \name Type definitions + */ + ///@{ + /// Defines (a pointer to) the type of the dataset that is represented. + /// This could either be a naked pointer or a smart pointer. + typedef typename RepresenterTraits::DatasetPointerType DatasetPointerType; + + /// Defines the const pointer type o fthe datset that is represented + typedef typename RepresenterTraits::DatasetConstPointerType DatasetConstPointerType; + + /// Defines the pointtype of the dataset + typedef typename RepresenterTraits::PointType PointType; + + /// Defines the type of the value when the dataset is evaluated at a given point + /// (for a image, this could for example be a scalar value or an RGB value) + typedef typename RepresenterTraits::ValueType ValueType; + + typedef T DatasetType; + + typedef Domain DomainType; + + virtual ~Representer() { + } + + /// Returns a name that identifies the representer + virtual std::string GetName() const = 0; + + virtual RepresenterDataType GetType() const = 0; + + /// Returns the dimensionality of the dataset (for a mesh this is 3, for a scalar image + /// this would be 1) + virtual unsigned GetDimensions() const = 0; + ///@} + + virtual std::string GetVersion() const = 0; + + /** + * \name Object creation and destruction + */ + ///@{ + /** Creates a new representer object, with the + * the information defined inthe given hdf5 group + * \sa Save + */ + virtual void Load(const H5::Group& fg) = 0; + + /** Clone the representer */ + virtual Representer* Clone() const = 0; + + /** Delete the representer object */ + virtual void Delete() const = 0; + + ///@} + + /** + * \name Adapter methods + */ + virtual void DeleteDataset(DatasetPointerType d) const = 0; + virtual DatasetPointerType CloneDataset(DatasetConstPointerType d) const = 0; + ///@} + + + + + + /** + * \name Conversion from the dataset to a vector representation and back + */ + ///@{ + + /** + * Returns the Domain for this representers. The domain is essentially a list of all the points on which the model is defined. + * \sa statismo::Domain + */ + virtual const statismo::Domain& GetDomain() const = 0; + + + virtual DatasetConstPointerType GetReference() const = 0; + + /** + * Converts a Dataset::PointType to a vector in statismo::Vector + */ + virtual VectorType PointToVector(const PointType& pt) const = 0; + + /** + * Returns a vectorial representation of the given sample. + */ + virtual VectorType SampleToSampleVector( + DatasetConstPointerType sample) const = 0; + + /** + * Takes a vector of nd elements and converts it to a sample. The sample is a type + * that is represnter (e.g. an image, a mesh, etc). + */ + virtual DatasetPointerType SampleVectorToSample( + const VectorType& sample) const = 0; + + /** + * Returns the value of the sample at the point with the given id. + */ + virtual ValueType PointSampleFromSample(DatasetConstPointerType sample, + unsigned ptid) const = 0; + + /** + * Take a point sample (i.e. the value of a sample at a given point) and converts it + * to its vector representation. + * The type of the point sample is a ValueType, that depends on the type of the dataset. + * For a mesh this would for example be a 3D point, + * while for a scalar image this would be a scalar value representing the intensity. + */ + virtual ValueType PointSampleVectorToPointSample( + const VectorType& v) const = 0; + + /** + * Convert the given vector represenation of a pointSample back to its ValueType + * \sa PointSampleVectorToPointSample + */ + virtual VectorType PointSampleToPointSampleVector( + const ValueType& pointSample) const = 0; + + /** + * Defines the mapping between the point ids and the position in the vector. + * Assume for example that a 3D mesh type is representerd. + * A conversion strategy used in DatasetToSampleVector could be to return + * a vector \f$(pt1_x, pt1_y, pt1_z, ..., ptn_x, ptn_y, ptn_z\f$. + * In this case, this method would return for inputs ptId, componentId + * the value ptId * 3 + componentId + */ + virtual unsigned MapPointIdToInternalIdx(unsigned ptId, + unsigned componentInd) const { + return ptId * GetDimensions() + componentInd; + } + + /** + * Given a point (the coordinates) return the pointId of this point. + */ + virtual unsigned GetPointIdForPoint(const PointType& point) const = 0; + + ///@} + + /** + * \name Persistence + */ + ///@{ + /** + * Save the informatino that define this representer to the group + * in the HDF5 file given by fg. + */ + virtual void Save(const H5::Group& fg) const = 0; + + ///@} + + /** + * \name Utiities + */ + /* + * Returns a new dataset that corresponds to the zero element of the underlying vectorspace + * obtained when vectorizing a dataset. + * + */ + virtual DatasetPointerType IdentitySample() const { + + switch (this->GetType()) { + case POINT_SET: + case POLYGON_MESH: + case VOLUME_MESH: { + return CloneDataset(this->GetReference()); + break; + } + case IMAGE: + case VECTOR: { + VectorType zeroVec = VectorType::Zero(GetDomain().GetNumberOfPoints() * GetDimensions()); + return SampleVectorToSample(zeroVec); + break; + } + default: { + throw statismo::StatisticalModelException( + "No cannonical identityDataset method is defined for custom Representers."); + } + } + } +}; + +} + +#endif /* REPRESENTER_H_ */ + diff --git a/Components/Metrics/ActiveRegistrationModel/Statismo/core/include/StatismoIO.h b/Components/Metrics/ActiveRegistrationModel/Statismo/core/include/StatismoIO.h new file mode 100644 index 000000000..8596135ff --- /dev/null +++ b/Components/Metrics/ActiveRegistrationModel/Statismo/core/include/StatismoIO.h @@ -0,0 +1,247 @@ +/* + * This file is part of the statismo library. + * + * Author: Marcel Luethi (marcel.luethi@unibas.ch) + * + * Copyright (c) 2011 University of Basel + * All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions + * are met: + * + * Redistributions of source code must retain the above copyright notice, + * this list of conditions and the following disclaimer. + * + * Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * + * Neither the name of the project's author nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS + * FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + * HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED + * TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR + * PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF + * LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING + * NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + */ + + +#ifndef STATISMOIO_H_ +#define STATISMOIO_H_ + +#include "StatisticalModel.h" + +namespace H5 { +class Group; +} + +namespace statismo { +/** + * \brief The IO class is used to Load() and or Save() a StatisticalModel. The Load and Save functions are static and as such there's no need to create an instance of this class. + * + * The Template parameter is the same as the one of the StatisticalModel class. + * + */ +template +class IO { + private: + //This class is made up of static methods only and as such the Constructor is private to prevent misunderstandings. + IO() {} + + public: + typedef StatisticalModel StatisticalModelType; + + /** + * Returns a new statistical model, which is loaded from the given HDF5 file + * \param filename The filename + * \param maxNumberOfPCAComponents The maximal number of pca components that are loaded + * to create the model. + */ + static StatisticalModelType* LoadStatisticalModel(typename StatisticalModelType::RepresenterType *representer, + const std::string &filename, + unsigned maxNumberOfPCAComponents = std::numeric_limits::max()) { + + StatisticalModelType* newModel = 0; + + H5::H5File file; + try { + file = H5::H5File(filename.c_str(), H5F_ACC_RDONLY); + } catch (H5::Exception& e) { + std::string msg(std::string("could not open HDF5 file \n") + e.getCDetailMsg()); + throw StatisticalModelException(msg.c_str()); + } + + H5::Group modelRoot = file.openGroup("/"); + + newModel = LoadStatisticalModel(representer, modelRoot, maxNumberOfPCAComponents); + + modelRoot.close(); + file.close(); + return newModel; + } + + /** + * Returns a new statistical model, which is stored in the given HDF5 Group + * + * \param modelroot A h5 group where the model is saved + * \param maxNumberOfPCAComponents The maximal number of pca components that are loaded + * to create the model. + */ + static StatisticalModelType* LoadStatisticalModel(typename StatisticalModelType::RepresenterType *representer, + const H5::Group &modelRoot, + unsigned maxNumberOfPCAComponents = std::numeric_limits::max()) { + + StatisticalModelType* newModel; + ModelInfo modelInfo; + + try { + H5::Group representerGroup = modelRoot.openGroup("./representer"); + + representer->Load(representerGroup); + representerGroup.close(); + + int minorVersion = 0; + int majorVersion = 0; + + if (HDF5Utils::existsObjectWithName(modelRoot, "version") == false) { + // this is an old statismo format, that had not been versioned. We set the version to 0.8 as this is the last version + // that stores the old format + std::cout << "Warning: version attribute does not exist in hdf5 file. Assuming version 0.8" <SetModelInfo(modelInfo); + return newModel; + } + + + /** + * Saves the statistical model to a HDF5 file + * \param model A pointer to the model you'd like to save. + * \param filename The filename (preferred extension is .h5) + * */ + static void SaveStatisticalModel(const StatisticalModelType *const model, const std::string &filename) { + if(model == NULL) { + throw new StatisticalModelException("Passing on a NULL_Pointer when trying to save a model is not possible."); + } + SaveStatisticalModel(*model, filename); + } + + /** + * Saves the statistical model to a HDF5 file + * \param model The model you'd like to save + * \param filename The filename (preferred extension is .h5) + * */ + static void SaveStatisticalModel(const StatisticalModelType &model, const std::string &filename) { + using namespace H5; + + H5File file; + std::ifstream ifile(filename.c_str()); + + try { + file = H5::H5File( filename.c_str(), H5F_ACC_TRUNC); + } catch (H5::FileIException& e) { + std::string msg(std::string("Could not open HDF5 file for writing \n") + e.getCDetailMsg()); + throw StatisticalModelException(msg.c_str()); + } + + + H5::Group modelRoot = file.openGroup("/"); + + H5::Group versionGroup = modelRoot.createGroup("version"); + HDF5Utils::writeInt(versionGroup, "majorVersion", 0); + HDF5Utils::writeInt(versionGroup, "minorVersion", 9); + versionGroup.close(); + + SaveStatisticalModel(model, modelRoot); + modelRoot.close(); + file.close(); + }; + + /** + * Saves the statistical model to the given HDF5 group. + * \param model the model you'd like to save + * \param modelRoot the group where to store the model + * */ + static void SaveStatisticalModel(const StatisticalModelType &model, const H5::Group &modelRoot) { + try { + // create the group structure + + std::string dataTypeStr = StatisticalModelType::RepresenterType::TypeToString(model.GetRepresenter()->GetType()); + + H5::Group representerGroup = modelRoot.createGroup("./representer"); + HDF5Utils::writeStringAttribute(representerGroup, "name", model.GetRepresenter()->GetName()); + HDF5Utils::writeStringAttribute(representerGroup, "version", model.GetRepresenter()->GetVersion()); + HDF5Utils::writeStringAttribute(representerGroup, "datasetType", dataTypeStr); + + model.GetRepresenter()->Save(representerGroup); + representerGroup.close(); + + H5::Group modelGroup = modelRoot.createGroup( "./model" ); + HDF5Utils::writeMatrix(modelGroup, "./pcaBasis", model.GetOrthonormalPCABasisMatrix()); + HDF5Utils::writeVector(modelGroup, "./pcaVariance", model.GetPCAVarianceVector()); + HDF5Utils::writeVector(modelGroup, "./mean", model.GetMeanVector()); + HDF5Utils::writeFloat(modelGroup, "./noiseVariance", model.GetNoiseVariance()); + modelGroup.close(); + + model.GetModelInfo().Save(modelRoot); + + + } catch (H5::Exception& e) { + std::string msg(std::string("an exception occurred while writing HDF5 file \n") + e.getCDetailMsg()); + throw StatisticalModelException(msg.c_str()); + } + } +}; + +} // namespace statismo + +#endif /* STATISMOIO_H_ */ diff --git a/Components/Metrics/ActiveRegistrationModel/Statismo/core/include/StatismoUtils.h b/Components/Metrics/ActiveRegistrationModel/Statismo/core/include/StatismoUtils.h new file mode 100644 index 000000000..fe9702577 --- /dev/null +++ b/Components/Metrics/ActiveRegistrationModel/Statismo/core/include/StatismoUtils.h @@ -0,0 +1,144 @@ +/* + * This file is part of the statismo library. + * + * Author: Marcel Luethi (marcel.luethi@unibas.ch) + * + * Copyright (c) 2011 University of Basel + * All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions + * are met: + * + * Redistributions of source code must retain the above copyright notice, + * this list of conditions and the following disclaimer. + * + * Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * + * Neither the name of the project's author nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS + * FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + * HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED + * TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR + * PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF + * LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING + * NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + */ + + +#ifndef __UTILS_H_ +#define __UTILS_H_ + +#include +#include + +#include +#include +#include +#include +#include + +#ifdef _WIN32 +#define NOMINMAX // avoid including the min and max macro +#include +#include +#endif + + +#include + +#include "CommonTypes.h" +#include "Exceptions.h" + +namespace statismo { + +/** + * \brief A number of small utility functions - internal use only. + */ + + +#ifdef _MSC_VER +#define is_deprecated __declspec(deprecated) +#elif defined(__GNUC__) +#define is_deprecated __attribute__((deprecated)) +#else +#define is_deprecated //uncommon compiler, don't bother +#endif + + +class Utils { + public: + + + + /** + * return string representation of t + */ + template + static std::string toString(T t) { + std::ostringstream os; + os << t; + return os.str(); + } + + + /** return a N(0,1) vector of size n */ + static VectorType generateNormalVector(unsigned n) { + // we would like to use tr1 here as well, but on some versions of visual studio it hangs forever. + // therefore we use the functionality from boost + + // we make the random generate static, to ensure that the seed is only set once, and not with + // every call + static boost::minstd_rand randgen(static_cast(time(0))); + static boost::normal_distribution<> dist(0, 1); + static boost::variate_generator > r(randgen, dist); + + VectorType v = VectorType::Zero(n); + for (unsigned i=0; i < n; i++) { + v(i) = r(); + } + return v; + } + + + static VectorType ReadVectorFromTxtFile(const char *name) { + typedef std::list ListType; + std::list values; + std::ifstream inFile(name, std::ios::in); + if (inFile.good()) { + std::copy(std::istream_iterator(inFile), std::istream_iterator(), std::back_insert_iterator(values)); + inFile.close(); + } else { + throw StatisticalModelException((std::string("Could not read text file ") + name).c_str()); + } + + VectorType v = VectorType::Zero(values.size()); + unsigned i = 0; + for (ListType::const_iterator it = values.begin(); it != values.end(); ++it) { + v(i) = *it; + i++; + } + return v; + } + + + static std::string CreateTmpName(const std::string& extension) { + boost::filesystem::path uniquePath = boost::filesystem::unique_path(); + return uniquePath.replace_extension(extension).string(); + } + +}; + +} // namespace statismo + +#endif /* __UTILS_H_ */ diff --git a/Components/Metrics/ActiveRegistrationModel/Statismo/core/include/StatisticalModel.h b/Components/Metrics/ActiveRegistrationModel/Statismo/core/include/StatisticalModel.h new file mode 100644 index 000000000..2cfd039f3 --- /dev/null +++ b/Components/Metrics/ActiveRegistrationModel/Statismo/core/include/StatisticalModel.h @@ -0,0 +1,569 @@ +/* + * This file is part of the statismo library. + * + * Author: Marcel Luethi (marcel.luethi@unibas.ch) + * + * Copyright (c) 2011 University of Basel + * All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions + * are met: + * + * Redistributions of source code must retain the above copyright notice, + * this list of conditions and the following disclaimer. + * + * Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * + * Neither the name of the project's author nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS + * FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + * HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED + * TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR + * PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF + * LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING + * NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + */ + + +#ifndef STATISTICALMODEL_H_ +#define STATISTICALMODEL_H_ + +#include +#include + +#include "CommonTypes.h" +#include "Config.h" +#include "DataManager.h" +#include "ModelInfo.h" +#include "Representer.h" + + +namespace statismo { + +/** + * \brief A Point/Value pair that is used to specify a value at a given point. + */ +/* +template +class PointValuePair { +public: + typedef typename Representer::PointType PointType; + typedef typename Representer::ValueType ValueType; + + PointValuePair(const PointType& pt, const ValueType& val) : point(pt), value(val) {} + + PointType point; + ValueType value; +}; +*/ + +/** + * \brief Representation of a linear statistical model (PCA Model). + * + * The statistical model class represents a statistical (PCA based) model. + * The implementation is based on the Probabilistic PCA, which includes the Standard PCA as a special case. + * + * Mathematically, the statistical model is a generative model, where a sample is given as a linear combination + * of a \f$n\f$ dimensional latent variable \f$ \alpha \f$ plus additive gaussian noise. + * \f[ S = \mu + W \alpha + \epsilon. \f] + * Here, \f$ \mu \in \mathbf{R}^p \f$ is the mean, \f$W \in \mathbf{R}^{p \times n}\f$ is a linear mapping, + * \f$\alpha \in \mathbf{R}^n \f$ is a vector of latent variables (later also referred to as coefficients) + * and \f$\epsilon \sim \mathcal{N}(0, \sigma^2)\f$ is a noise component. + * The linear mapping \f$ W \f$ is referred to as the PCABasis. It is usually given as \f$W = U D \f$ where \f$U\f$ where + * U is an orthonormal matrix and \f$D\f$ is a scaling matrix, referred to as PCAvariance. + * Usually, \f$U \in \mathbf{R}^{p \times n}\f$ is the matrix of eigenvectors of the data covariance matrix and + * \f$D\f$ the corresponding eigenvalues. + * + * While all the matrices and vectors defined above could be obtained directly, the main goal of this class + * is to abstract from these technicalities, by providing a high level interface to shape models. + * In this high level view, the model represents a multivariate normal distribution over the types defined by the representer + * (which are typically either, surface meshes, point clouds, deformation fields or images). + * This class provides the method to sample from this probability distribution, and to compute the probability + * of given samples directly. + * + */ +template +class StatisticalModel { + public: + typedef Representer RepresenterType ; + typedef typename RepresenterType::DatasetPointerType DatasetPointerType; + typedef typename RepresenterType::DatasetConstPointerType DatasetConstPointerType; + typedef typename RepresenterType::ValueType ValueType; + typedef typename RepresenterType::PointType PointType; + + + typedef Domain DomainType; + + typedef unsigned PointIdType; + + + //typedef PointValuePair PointValuePairType; + typedef std::pair PointValuePairType; + typedef std::pair PointIdValuePairType; + typedef std::list PointValueListType; + typedef std::list PointIdValueListType; + + // Maybe at some point, we can statically define a 3x3 resp. 2x3 matrix type. + typedef MatrixType PointCovarianceMatrixType; + typedef std::pair PointValueWithCovariancePairType; + typedef std::list PointValueWithCovarianceListType; + + + + + /** + * Destructor + */ + virtual ~StatisticalModel(); + + + /** + * @name Creating models + */ + ///@{ + + /** + * Factory method that creates a new Model. + * + * \warning The use of this constructor is discouraged. If possible, use a ModelBuilder to create + * a new model or call Load to load an existing model + * + * \param representer the represener + * \param m the mean + * \param orthonormalPCABasis An orthonormal matrix with the principal Axes. + * \param pcaVariance The Variance for each principal Axis + * \param noiseVariance The variance of the (N(0,noiseVariance)) noise on each point + */ + static StatisticalModel* Create(const RepresenterType* representer, + const VectorType& m, + const MatrixType& orthonormalPCABasis, + const VectorType& pcaVariance, + double noiseVariance) { + return new StatisticalModel(representer, m, orthonormalPCABasis, pcaVariance, noiseVariance); + } + + + + /** + * Destroy the object. + * The same effect can be achieved by deleting the object in the usual + * way using the c++ delete keyword. + */ + void Delete() const { + delete this; + } + + ///@} + + + /** + * @name General Info + */ + ///@{ + /** + * \return The number of PCA components in the model + */ + unsigned int GetNumberOfPrincipalComponents() const; + + /** + * \return A model info object \sa ModelInfo + */ + const ModelInfo& GetModelInfo() const; + ///@} + + /** + * @name Sample from the model + * + * \warning Note that these methods return a new Sample. If the representer used returns naked pointers (i.e. not smart pointers), + * the sample needs to be deleted manually. + */ + ///@{ + + + + /** + * Returns the value of the given sample at the point specified with the ptId + * + * \param sample A sample + * \param ptId the point id where to evaluate the sample + * + * \returns The value of the sample, at the specified point + */ + ValueType EvaluateSampleAtPoint(DatasetConstPointerType sample, unsigned ptId) const ; + + + /** + * Returns the value of the given sample corresponding to the given domain point + * + * \param sample A sample + * \param point the (domain) point on which the sample should be evaluated. + * + * \returns The value of the sample, at the specified point + */ + ValueType EvaluateSampleAtPoint(DatasetConstPointerType sample, const PointType& pt) const; + + + /** + * \return A new sample representing the mean of the model + */ + DatasetPointerType DrawMean() const; + + /** + * Draws the sample with the given coefficients + * + * \param coefficients A coefficient vector. The size of the coefficient vector should be smaller + * than number of factors in the model. Otherwise an exception is thrown. + * \param addNoise If true, the Gaussian noise assumed in the model is added to the sample + * + * \return A new sample + * */ + DatasetPointerType DrawSample(const VectorType& coefficients, bool addNoise = false) const ; + + /** + * As StatisticalModel::DrawSample, but where the coefficients are chosen at random according to a standard normal distribution + * + * \param addNoise If true, the Gaussian noise assumed in the model is added to the sample + * + * \return A new sample + * \sa DrawSample + */ + DatasetPointerType DrawSample(bool addNoise = false) const ; + + + + /** + * Draws the sample corresponding to the ith pca matrix + * + * \param componentNumber The number of the PCA Basis to be retrieved + * + * \return A new sample + * */ + DatasetPointerType DrawPCABasisSample(unsigned componentNumber) const; + + + /** + * @name Point sampling and point information + */ + ///@{ + + /** + * Returns the mean of the model, evaluated at the given point. + * + * \param point A point on the domain the model is defined + * \returns The mean Sample evaluated at the point point + */ + ValueType DrawMeanAtPoint( const PointType& point) const; + + /** + * Returns the mean of the model, evaluated at the given pointid. + * + * \param pointid The pointId of the point where it should be evaluated (as defined by the representer) + * \returns The mean sample evaluated at the given pointId \see DrawMeanAtPoint + */ + ValueType DrawMeanAtPoint( unsigned pointId) const; + + /** + * Returns the value of the sample defined by coefficients at the specified point. + * This method computes the value of the sample only for the given point, and is thus much more + * efficient that calling DrawSample, if only a few points are of interest. + * + * \param coefficients the coefficients of the sample + * \param the point of the sample where it is evaluated + * \param addNoise If true, the Gaussian noise assumed in the model is added to the sample + */ + ValueType DrawSampleAtPoint(const VectorType& coefficients, const PointType& point, bool addNoise = false) const; + + /** + * Returns the value of the sample defined by coefficients at the specified pointID. + * This method computes the value of the sample only for the given point, and is thus much more + * efficient that calling DrawSample, if only a few points are of interest. + * + * \param coefficients the coefficients of the sample + * \param the point of the sample where it is evaluated + * \param addNoise If true, the Gaussian noise assumed in the model is added to the sample + */ + ValueType DrawSampleAtPoint(const VectorType& coefficients, unsigned pointId, bool addNoise = false) const; + + + /** + * Computes the jacobian of the Statistical model at a given point + * \param pt The point where the Jacobian is computed + * \param jacobian Output parameter where the jacobian is stored. + */ + MatrixType GetJacobian(const PointType& pt) const; + + /** + * Computes the jacobian of the Statistical model at a specified pointID + * \param ptId The pointID where the Jacobian is computed + * \param jacobian Output parameter where the jacobian is stored. + */ + MatrixType GetJacobian(unsigned ptId) const; + + /** + * Returns the variance in the model for point pt + * @param pt The point + * @returns a d x d covariance matrix + */ + MatrixType GetCovarianceAtPoint(const PointType& pt1, const PointType& pt2) const; + + /** + * Returns the variance in the model for point pt + * @param pt The point + * @returns a d x d covariance matrix + */ + MatrixType GetCovarianceAtPoint(unsigned ptId1, unsigned ptId2) const; + ///@} + + + /** + * @name Statistical Information from Dataset + */ + ///@{ + + /** + * Returns the covariance matrix for the model. If the model is defined on + * n points, in d dimensions, then this is a \f$nd \times nd\f$ matrix of + * n \f$d \times d \f$ block matrices corresponding to the covariance at each point. + * \warning This method is only useful when $n$ is small, since otherwise the matrix + * becomes huge. + */ + MatrixType GetCovarianceMatrix() const; + + /** + * Returns the inverse covariance matrix. Computes if necessary + */ + MatrixType GetInverseCovarianceMatrix() const; + + /** + * Returns the probability of observing the given dataset under this model. + * If the coefficients \f$\alpha \in \mathbf{R}^n\f$ define the dataset, the probability is + * \f$ + * (2 \pi)^{- \frac{n}{2}} \exp(||\alpha||) + * \f$ + * + * + * \param dataset The dataset + * \return The probability + */ + double ComputeProbability(DatasetConstPointerType dataset) const ; + + /** + * Returns the log probability of observing a given dataset. + * + * \param dataset The dataset + * \return The log probability + * + */ + double ComputeLogProbability(DatasetConstPointerType dataset) const ; + + + /** + * Returns the probability of observing the given coefficients under this model. + * If the coefficients \f$\alpha \in \mathbf{R}^n\f$ define the dataset, the probability is + * \f$ + * (2 \pi)^{- \frac{n}{2}} \exp(||\alpha||) + * \f$ + * + * + * \param coefficients The coefficients \f$\alpha \in \mathbf{R}^n\f$ + * \return The probability + */ + double ComputeProbabilityOfCoefficients(const VectorType& coefficients) const ; + + /** + * Returns the log probability of observing given coefficients. + * + * \param dataset The coefficients \f$\alpha \in \mathbf{R}^n\f$ + * \return The log probability + * + */ + double ComputeLogProbabilityOfCoefficients(const VectorType& coefficients) const ; + + + /** + * Returns the mahalonoibs distance for the given dataset. + */ + double ComputeMahalanobisDistance(DatasetConstPointerType dataset) const; + + + /** + * Returns the coefficients of the latent variables for the given dataset, i.e. + * the vectors of numbers \f$\alpha \f$, such that for the dataset \f$S\f$ it holds that + * \f$ S = \mu + U \alpha\f$ + * + * @returns The coefficient vector \f$\alpha\f$ + */ + VectorType ComputeCoefficients(DatasetConstPointerType dataset) const; + + + /** + * Returns the coefficients of the latent variables for the given values provided in the PointValueList. + * This is useful, when only a part of the dataset is given. + * The method is described in the paper + * + * Probabilistic Modeling and Visualization of the Flexibility in Morphable Models, + * M. Luethi, T. Albrecht and T. Vetter, Mathematics of Surfaces, 2009 + * + * \param pointValues A list with PointValuePairs . + * \param pointValueNoiseVariance The variance of estimated (gaussian) noise at the known points + * + */ + VectorType ComputeCoefficientsForPointValues(const PointValueListType& pointValues, double pointValueNoiseVariance=0.0) const; + + /** + * Similar to ComputeCoefficientsForPointValues, only here there is no global pointValueNoiseVariance. + * Instead, a covariance matrix with noise values is specified for each point. + * The returned coefficients are the mean of the posterior model described in + * + * Posterior Shape Models + * Thomas Albrecht, Marcel Luethi, Thomas Gerig, Thomas Vetter + * Medical Image Analysis 2013 + * + * To get the full posterior model, use the PosteriorModelBuilder + * + * \param pointValuesWithCovariance A list with PointValuePairs and PointCovarianceMatrices. + * + */ + VectorType ComputeCoefficientsForPointValuesWithCovariance(const PointValueWithCovarianceListType& pointValuesWithCovariance) const; + + + /** + * Same as ComputeCoefficientsForPointValues(const PointValueListType& pointValues), but used when the + * point ids, rather than the points are known. + * + * \param pointValues A list with (Point,Value) pairs, a list of (PointId, Value) is provided. + * \param pointValueNoiseVariance The variance of estimated (gaussian) noise at the known points + */ + //RB: I had to modify the method name, to avoid prototype collisions when the PointType corresponds to unsigned (= type of the point id) + VectorType ComputeCoefficientsForPointIDValues(const PointIdValueListType& pointValues, double pointValueNoiseVariance=0.0) const; + + + /** + * @name Low level access + * These methods provide a low level interface to the model content. They are of only limited use for + * an application. Prefer whenever possible the high level functions. + */ + ///@{ + + /** + * Returns the variance of the noise of the error term, that was set when the model was built. + */ + float GetNoiseVariance() const; + + /** + * Returns a vector where each element holds the variance of the corresponding principal component in data space + * */ + const VectorType& GetPCAVarianceVector() const; + + /** + * Returns a vector holding the mean. Assume the mean \f$\mu \subset \mathbf{R}^d\f$ is defined on + * \f$p\f$ points, the returned mean vector \f$m\f$ has dimensionality \f$m \in \mathbf{R}^{dp} \f$, i.e. + * the \f$d\f$ components are stacked into the vector. The order of the components in the vector is + * undefined and depends on the representer. + * */ + const VectorType& GetMeanVector() const; + + /** + * Returns a matrix with the PCA Basis as its columns. + * Assume the shapes \f$s \subset \mathbf{R}^d\f$ are defined on + * \f$n\f$ points, the returned matrix \f$W\f$ has dimensionality \f$W \in \mathbf{R}^{dp \times n} \f$, i.e. + * the \f$d\f$ components are stacked into the matrix. The order of the components in the matrix is + * undefined and depends on the representer. + * + */ + const MatrixType& GetPCABasisMatrix() const; + + /** + * Returns the PCA Matrix, but with its principal axis normalized to unit length. + * \warning This is more expensive than GetPCABasisMatrix as the matrix has to be computed + * and a copy is returned + */ + MatrixType GetOrthonormalPCABasisMatrix() const; + + /** + * Returns an instance for the given coefficients as a vector. + * \param addNoise If true, the Gaussian noise assumed in the model is added to the sample + */ + VectorType DrawSampleVector(const VectorType& coefficients, bool addNoise = false) const ; + ///@} + + + ///@{ + /** + * Sets the model information. This is for library internal use only. + */ + void SetModelInfo(const ModelInfo& modelInfo); + + /** + * Computes the coefficients for the given sample vector. + * This is for library internal use only. + */ + VectorType ComputeCoefficientsForSampleVector(const VectorType& sample) const; + + + /** + * Return an instance of the representer + */ + const RepresenterType* GetRepresenter() const { + return m_representer; + } + + + /** + * Return the domain of the statistical model + */ + const DomainType& GetDomain() const { + return m_representer->GetDomain(); + } + + ///@} + + private: + // computes the M Matrix for the PPCA Method (see Bishop, PRML, Chapter 12) + void CheckAndUpdateCachedParameters() const; + + + + /** + * Create an instance of the StatisticalModel + * @param representer An instance of the representer, used to convert the samples to dataset of the represented type. + */ + StatisticalModel(const RepresenterType* representer, const VectorType& m, const MatrixType& orthonormalPCABasis, const VectorType& pcaVariance, double noiseVariance); + + // to prevent use + StatisticalModel(const StatisticalModel& rhs); + StatisticalModel& operator=(const StatisticalModel& rhs); + + const RepresenterType* m_representer; + + VectorType m_mean; + MatrixType m_pcaBasisMatrix; + VectorType m_pcaVariance; + float m_noiseVariance; + + + // caching + mutable bool m_cachedValuesValid; + + //the matrix M^{-1} in Bishops PRML book. This is roughly the Latent Covariance matrix (but not exactly) + mutable MatrixType m_MInverseMatrix; + + ModelInfo m_modelInfo; +}; + +} // namespace statismo + +#include "StatisticalModel.hxx" + +#endif /* STATISTICALMODEL_H_ */ diff --git a/Components/Metrics/ActiveRegistrationModel/Statismo/core/include/StatisticalModel.hxx b/Components/Metrics/ActiveRegistrationModel/Statismo/core/include/StatisticalModel.hxx new file mode 100644 index 000000000..b9abdd11d --- /dev/null +++ b/Components/Metrics/ActiveRegistrationModel/Statismo/core/include/StatisticalModel.hxx @@ -0,0 +1,519 @@ +/* + * This file is part of the statismo library. + * + * Author: Marcel Luethi (marcel.luethi@unibas.ch) + * + * Copyright (c) 2011 University of Basel + * All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions + * are met: + * + * Redistributions of source code must retain the above copyright notice, + * this list of conditions and the following disclaimer. + * + * Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * + * Neither the name of the project's author nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS + * FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + * HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED + * TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR + * PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF + * LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING + * NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + */ + +#ifndef __StatisticalModel_hxx +#define __StatisticalModel_hxx + +#include + +#include +#include +#include + +#include "Exceptions.h" +#include "HDF5Utils.h" +#include "ModelBuilder.h" +#include "StatisticalModel.h" + +namespace statismo { + +template +StatisticalModel::StatisticalModel(const RepresenterType* representer, const VectorType& m, const MatrixType& orthonormalPCABasis, const VectorType& pcaVariance, double noiseVariance) + : m_representer(representer->Clone()), + m_mean(m), + m_pcaVariance(pcaVariance), + m_noiseVariance(noiseVariance), + m_cachedValuesValid(false) { + VectorType D = pcaVariance.array().sqrt(); + m_pcaBasisMatrix = orthonormalPCABasis * DiagMatrixType(D); +} + + +template +StatisticalModel::~StatisticalModel() { + + if (m_representer != 0) { +// not all representers can implement a const correct version of delete. +// We therefore simply const cast it. This is save here. + const_cast(m_representer)->Delete(); + m_representer = 0; + } + +} + + +template +typename StatisticalModel::ValueType +StatisticalModel::EvaluateSampleAtPoint(const DatasetConstPointerType sample, const PointType& point) const { + unsigned ptid = this->m_representer->GetPointIdForPoint(point); + return EvaluateSampleAtPoint(sample, ptid); +} + + +template +typename StatisticalModel::ValueType +StatisticalModel::EvaluateSampleAtPoint(const DatasetConstPointerType sample, unsigned ptid) const { + return this->m_representer->PointSampleFromSample(sample, ptid); +} + + +template +typename StatisticalModel::DatasetPointerType +StatisticalModel::DrawMean() const { + VectorType coeffs = VectorType::Zero(this->GetNumberOfPrincipalComponents()); + return DrawSample(coeffs, false); +} + + +template +typename StatisticalModel::ValueType +StatisticalModel::DrawMeanAtPoint(const PointType& point) const { + VectorType coeffs = VectorType::Zero(this->GetNumberOfPrincipalComponents()); + return DrawSampleAtPoint(coeffs, point); + +} + +template +typename StatisticalModel::ValueType +StatisticalModel::DrawMeanAtPoint(unsigned pointId) const { + VectorType coeffs = VectorType::Zero(this->GetNumberOfPrincipalComponents()); + return DrawSampleAtPoint(coeffs, pointId, false); + +} + + + + +template +typename StatisticalModel::DatasetPointerType +StatisticalModel::DrawSample(bool addNoise) const { + + // we create random coefficients and draw a random sample from the model + VectorType coeffs = Utils::generateNormalVector(GetNumberOfPrincipalComponents()); + + return DrawSample(coeffs, addNoise); +} + + +template +typename StatisticalModel::DatasetPointerType +StatisticalModel::DrawSample(const VectorType& coefficients, bool addNoise) const { + return m_representer->SampleVectorToSample(DrawSampleVector(coefficients, addNoise)); +} + + +template +typename StatisticalModel::DatasetPointerType +StatisticalModel::DrawPCABasisSample(const unsigned pcaComponent) const { + if (pcaComponent >= this->GetNumberOfPrincipalComponents()) { + throw StatisticalModelException("Wrong pcaComponent index provided to DrawPCABasisSample!"); + } + + + return m_representer->SampleVectorToSample( m_pcaBasisMatrix.col(pcaComponent)); +} + + + +template +VectorType +StatisticalModel::DrawSampleVector(const VectorType& coefficients, bool addNoise) const { + + if (coefficients.size() != this->GetNumberOfPrincipalComponents()) { + throw StatisticalModelException("Incorrect number of coefficients provided !"); + } + + unsigned vectorSize = this->m_mean.size(); + assert (vectorSize != 0); + + VectorType epsilon = VectorType::Zero(vectorSize); + if (addNoise) { + epsilon = Utils::generateNormalVector(vectorSize) * sqrt(m_noiseVariance); + } + + + return m_mean+ m_pcaBasisMatrix * coefficients + epsilon; +} + + +template +typename StatisticalModel::ValueType +StatisticalModel::DrawSampleAtPoint(const VectorType& coefficients, const PointType& point, bool addNoise) const { + + unsigned ptId = this->m_representer->GetPointIdForPoint(point); + + return DrawSampleAtPoint(coefficients, ptId, addNoise); + +} + +template +typename StatisticalModel::ValueType +StatisticalModel::DrawSampleAtPoint(const VectorType& coefficients, const unsigned ptId, bool addNoise) const { + + unsigned dim = m_representer->GetDimensions(); + + VectorType v(dim); + VectorType epsilon = VectorType::Zero(dim); + if (addNoise) { + epsilon = Utils::generateNormalVector(dim) * sqrt(m_noiseVariance); + } + for (unsigned d = 0; d < dim; d++) { + unsigned idx =m_representer->MapPointIdToInternalIdx(ptId, d); + + if (idx >= m_mean.rows()) { + std::ostringstream os; + os << "Invalid idx computed in DrawSampleAtPoint. "; + os << " The most likely cause of this error is that you provided an invalid point id (" << ptId <<")"; + throw StatisticalModelException(os.str().c_str()); + } + + v[d] = m_mean[idx] + m_pcaBasisMatrix.row(idx).dot(coefficients) + epsilon[d]; + } + + return this->m_representer->PointSampleVectorToPointSample(v); +} + + + +template +MatrixType +StatisticalModel::GetCovarianceAtPoint(const PointType& pt1, const PointType& pt2) const { + unsigned ptId1 = this->m_representer->GetPointIdForPoint(pt1); + unsigned ptId2 = this->m_representer->GetPointIdForPoint(pt2); + + return GetCovarianceAtPoint(ptId1, ptId2); +} + +template +MatrixType +StatisticalModel::GetCovarianceAtPoint(unsigned ptId1, unsigned ptId2) const { + unsigned dim = m_representer->GetDimensions(); + MatrixType cov(dim, dim); + + for (unsigned i = 0; i < dim; i++) { + unsigned idxi = m_representer->MapPointIdToInternalIdx(ptId1, i); + VectorType vi = m_pcaBasisMatrix.row(idxi); + for (unsigned j = 0; j < dim; j++) { + unsigned idxj = m_representer->MapPointIdToInternalIdx(ptId2, j); + VectorType vj = m_pcaBasisMatrix.row(idxj); + cov(i,j) = vi.dot(vj); + if (i == j) cov(i,j) += m_noiseVariance; + } + } + return cov; +} + +template +MatrixType +StatisticalModel::GetCovarianceMatrix() const { + MatrixType M = m_pcaBasisMatrix * m_pcaBasisMatrix.transpose(); + M.diagonal() += m_noiseVariance * VectorType::Ones(m_pcaBasisMatrix.rows()); + return M; +} + +template +MatrixType +StatisticalModel::GetInverseCovarianceMatrix() const { + CheckAndUpdateCachedParameters(); + return this->m_MInverseMatrix; +} + + +template +VectorType +StatisticalModel::ComputeCoefficients(DatasetConstPointerType ds) const { + return ComputeCoefficientsForSampleVector(m_representer->SampleToSampleVector(ds)); +} + +template +VectorType +StatisticalModel::ComputeCoefficientsForSampleVector(const VectorType& sample) const { + + CheckAndUpdateCachedParameters(); + + const MatrixType& WT = m_pcaBasisMatrix.transpose(); + + VectorType coeffs = m_MInverseMatrix * (WT * (sample - m_mean)); + return coeffs; +} + + + +template +VectorType +StatisticalModel::ComputeCoefficientsForPointValues(const PointValueListType& pointValueList, double pointValueNoiseVariance) const { + PointIdValueListType ptIdValueList; + + for (typename PointValueListType::const_iterator it = pointValueList.begin(); + it != pointValueList.end(); + ++it) { + ptIdValueList.push_back(PointIdValuePairType(m_representer->GetPointIdForPoint(it->first), it->second)); + } + return ComputeCoefficientsForPointIDValues(ptIdValueList, pointValueNoiseVariance); +} + +template +VectorType +StatisticalModel::ComputeCoefficientsForPointIDValues(const PointIdValueListType& pointIdValueList, double pointValueNoiseVariance) const { + + unsigned dim = m_representer->GetDimensions(); + + double noiseVariance = std::max(pointValueNoiseVariance, (double) m_noiseVariance); + + // build the part matrices with , considering only the points that are fixed + MatrixType PCABasisPart(pointIdValueList.size()* dim, this->GetNumberOfPrincipalComponents()); + VectorType muPart(pointIdValueList.size() * dim); + VectorType sample(pointIdValueList.size() * dim); + + unsigned i = 0; + for (typename PointIdValueListType::const_iterator it = pointIdValueList.begin(); it != pointIdValueList.end(); ++it) { + VectorType val = this->m_representer->PointSampleToPointSampleVector(it->second); + unsigned pt_id = it->first; + for (unsigned d = 0; d < dim; d++) { + PCABasisPart.row(i * dim + d) = this->GetPCABasisMatrix().row(m_representer->MapPointIdToInternalIdx(pt_id, d)); + muPart[i * dim + d] = this->GetMeanVector()[m_representer->MapPointIdToInternalIdx(pt_id, d)]; + sample[i * dim + d] = val[d]; + } + i++; + } + + MatrixType M = PCABasisPart.transpose() * PCABasisPart; + M.diagonal() += noiseVariance * VectorType::Ones(PCABasisPart.cols()); + VectorType coeffs = M.inverse() * PCABasisPart.transpose() * (sample - muPart); + + return coeffs; +} + +template +VectorType +StatisticalModel::ComputeCoefficientsForPointValuesWithCovariance(const PointValueWithCovarianceListType& pointValuesWithCovariance) const { + + // The naming of the variables correspond to those used in the paper + // Posterior Shape Models, + // Thomas Albrecht, Marcel Luethi, Thomas Gerig, Thomas Vetter + // + const MatrixType& Q = m_pcaBasisMatrix; + const VectorType& mu = m_mean; + + unsigned dim = m_representer->GetDimensions(); + + // build the part matrices with , considering only the points that are fixed + // + unsigned numPrincipalComponents = this->GetNumberOfPrincipalComponents(); + MatrixType Q_g(pointValuesWithCovariance.size()* dim, numPrincipalComponents); + VectorType mu_g(pointValuesWithCovariance.size() * dim); + VectorType s_g(pointValuesWithCovariance.size() * dim); + + MatrixType LQ_g(pointValuesWithCovariance.size()* dim, numPrincipalComponents); + + unsigned i = 0; + for (typename PointValueWithCovarianceListType::const_iterator it = pointValuesWithCovariance.begin(); it != pointValuesWithCovariance.end(); ++it) { + VectorType val = m_representer->PointSampleToPointSampleVector(it->first.second); + unsigned pt_id = m_representer->GetPointIdForPoint(it->first.first); + + // In the formulas, we actually need the precision matrix, which is the inverse of the covariance. + const MatrixType pointPrecisionMatrix = it->second.inverse(); + + // Get the three rows pertaining to this point: + const MatrixType Qrows_for_pt_id = Q.block(pt_id * dim, 0, dim, numPrincipalComponents); + + Q_g.block(i * dim, 0, dim, numPrincipalComponents) = Qrows_for_pt_id; + mu_g.block(i * dim, 0, dim, 1) = mu.block(pt_id * dim, 0, dim, 1); + s_g.block(i * dim, 0, dim, 1) = val; + + LQ_g.block(i * dim, 0, dim, numPrincipalComponents) = pointPrecisionMatrix * Qrows_for_pt_id; + i++; + } + + VectorType D2 = m_pcaVariance.array(); + + const MatrixType& Q_gT = Q_g.transpose(); + + MatrixType M = Q_gT * LQ_g; + M.diagonal() += VectorType::Ones(Q_g.cols()); + + MatrixTypeDoublePrecision Minv = M.cast().inverse(); + + // the MAP solution for the latent variables (coefficients) + VectorType coeffs = Minv.cast() * LQ_g.transpose() * (s_g - mu_g); + + return coeffs; + +} + + + +template +double +StatisticalModel::ComputeLogProbability(DatasetConstPointerType ds) const { + VectorType alpha = ComputeCoefficients(ds); + return ComputeLogProbabilityOfCoefficients(alpha); +} + +template +double +StatisticalModel::ComputeProbability(DatasetConstPointerType ds) const { + VectorType alpha = ComputeCoefficients(ds); + return ComputeProbabilityOfCoefficients(alpha); +} + + +template +double +StatisticalModel::ComputeLogProbabilityOfCoefficients(const VectorType& coefficents) const { + return log(pow(2 * PI, -0.5 * this->GetNumberOfPrincipalComponents())) - 0.5 * coefficents.squaredNorm(); +} + +template +double +StatisticalModel::ComputeProbabilityOfCoefficients(const VectorType& coefficients) const { + return pow(2 * PI, - 0.5 * this->GetNumberOfPrincipalComponents()) * exp(- 0.5 * coefficients.squaredNorm()); +} + + +template +double +StatisticalModel::ComputeMahalanobisDistance(DatasetConstPointerType ds) const { + VectorType alpha = ComputeCoefficients(ds); + return std::sqrt(alpha.squaredNorm()); +} + + + +template +float +StatisticalModel::GetNoiseVariance() const { + return m_noiseVariance; +} + + +template +const VectorType& +StatisticalModel::GetMeanVector() const { + return m_mean; +} + +template +const VectorType& +StatisticalModel::GetPCAVarianceVector() const { + return m_pcaVariance; +} + + +template +const MatrixType& +StatisticalModel::GetPCABasisMatrix() const { + return m_pcaBasisMatrix; +} + +template +MatrixType +StatisticalModel::GetOrthonormalPCABasisMatrix() const { + // we can recover the orthonormal matrix by undoing the scaling with the pcaVariance + // (c.f. the method SetParameters) + + assert(m_pcaVariance.maxCoeff() > 1e-8); + VectorType D = m_pcaVariance.array().sqrt(); + return m_pcaBasisMatrix * DiagMatrixType(D).inverse(); +} + + + +template +void +StatisticalModel::SetModelInfo(const ModelInfo& modelInfo) { + m_modelInfo = modelInfo; +} + + +template +const ModelInfo& +StatisticalModel::GetModelInfo() const { + return m_modelInfo; +} + + + +template +unsigned int +StatisticalModel::GetNumberOfPrincipalComponents() const { + return m_pcaBasisMatrix.cols(); +} + +template +MatrixType +StatisticalModel::GetJacobian(const PointType& pt) const { + + unsigned ptId = m_representer->GetPointIdForPoint(pt); + + return GetJacobian(ptId); +} + +template +MatrixType +StatisticalModel::GetJacobian(unsigned ptId) const { + + unsigned Dimensions = m_representer->GetDimensions(); + MatrixType J = MatrixType::Zero(Dimensions, GetNumberOfPrincipalComponents()); + + for(unsigned i = 0; i < Dimensions; i++) { + unsigned idx = m_representer->MapPointIdToInternalIdx(ptId, i); + for(unsigned j = 0; j < GetNumberOfPrincipalComponents(); j++) { + J(i,j) += m_pcaBasisMatrix(idx,j) ; + } + } + return J; +} + +template +void +StatisticalModel::CheckAndUpdateCachedParameters() const { + + if (m_cachedValuesValid == false) { + VectorType I = VectorType::Ones(m_pcaBasisMatrix.cols()); + MatrixType Mmatrix = m_pcaBasisMatrix.transpose() * m_pcaBasisMatrix; + Mmatrix.diagonal() += m_noiseVariance * I; + + m_MInverseMatrix = Mmatrix.inverse(); + + } + m_cachedValuesValid = true; +} + +} // namespace statismo + +#endif diff --git a/Components/Metrics/ActiveRegistrationModel/Statismo/core/include/TrivialVectorialRepresenter.h b/Components/Metrics/ActiveRegistrationModel/Statismo/core/include/TrivialVectorialRepresenter.h new file mode 100644 index 000000000..2a6ed2443 --- /dev/null +++ b/Components/Metrics/ActiveRegistrationModel/Statismo/core/include/TrivialVectorialRepresenter.h @@ -0,0 +1,202 @@ +/* + * This file is part of the statismo library. + * + * Author: Marcel Luethi (marcel.luethi@unibas.ch) + * + * Copyright (c) 2011 University of Basel + * All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions + * are met: + * + * Redistributions of source code must retain the above copyright notice, + * this list of conditions and the following disclaimer. + * + * Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * + * Neither the name of the project's author nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS + * FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + * HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED + * TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR + * PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF + * LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING + * NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + */ + + +#ifndef TRIVIALVECTORIALREPRESENTER_H +#define TRIVIALVECTORIALREPRESENTER_H + +#include +#include + +#include "CommonTypes.h" +#include "Domain.h" +#include "HDF5Utils.h" +#include "Representer.h" + +namespace statismo { + +// A pointId is actually just an unsigned. However, we need to create a distinct type, to disambiguate some of +// the methods. +struct PointIdType { + PointIdType(unsigned ptId_) : ptId(ptId_) {} + PointIdType() : ptId(0) {} + + unsigned ptId; + +}; + + +template <> +struct RepresenterTraits { + typedef statismo::VectorType DatasetPointerType; + typedef statismo::VectorType DatasetConstPointerType; + + typedef PointIdType PointType; + typedef statismo::ScalarType ValueType; + ///@} + + +}; + + + + +/** + * \brief A trivial representer, that does no representation at all, but works directly with vectorial data + * + * \warning This representer is mainly for debugging purposes and not intended to be used for real projets + */ +class TrivialVectorialRepresenter : public Representer { + public: + + typedef statismo::ScalarType ValueType; + typedef statismo::Domain DomainType; + typedef Representer RepresenterBaseType; + + static TrivialVectorialRepresenter* Create() { + return new TrivialVectorialRepresenter(); + } + + static TrivialVectorialRepresenter* Create(unsigned numberOfPoints) { + return new TrivialVectorialRepresenter(numberOfPoints); + } + + void Load(const H5::Group& fg) { + unsigned numPoints = static_cast(statismo::HDF5Utils::readInt(fg, "numberOfPoints")); + initializeObject(numPoints); + } + + TrivialVectorialRepresenter* Clone() const { + return TrivialVectorialRepresenter::Create(m_domain.GetNumberOfPoints()); + } + void Delete() const { + delete this; + } + + virtual ~TrivialVectorialRepresenter() {} + + + std::string GetName() const { + return "TrivialVectorialRepresenter"; + } + unsigned GetDimensions() const { + return 1; + } + std::string GetVersion() const { + return "0.1"; + } + RepresenterBaseType::RepresenterDataType GetType() const { + return RepresenterBaseType::VECTOR; + } + + + void DeleteDataset(DatasetPointerType d) const { }; + DatasetPointerType CloneDataset(DatasetConstPointerType d) const { + return d; + } + + + const DomainType& GetDomain() const { + return m_domain; + } + DatasetConstPointerType GetReference() const { + return VectorType::Zero(m_domain.GetNumberOfPoints()); + } + + VectorType PointToVector(const PointType& pt) const { + // here, the pt type is simply an id (the index into the vector). + VectorType v(1); + v(0) = pt.ptId; + return v; + } + VectorType SampleToSampleVector(DatasetConstPointerType sample) const { + return sample; + } + DatasetPointerType SampleVectorToSample(const statismo::VectorType& sample) const { + return sample; + } + + VectorType PointSampleToPointSampleVector(const ValueType& v) const { + VectorType vec = VectorType::Zero(1); + vec(0) = v; + return vec; + + } + + ValueType PointSampleFromSample(DatasetConstPointerType sample, unsigned ptid) const { + return sample[ptid]; + } + ValueType PointSampleVectorToPointSample(const VectorType& pointSample) const { + return pointSample(0); + } + + + void Save(const H5::Group& fg) const { + HDF5Utils::writeInt(fg, "numberOfPoints", static_cast(m_domain.GetNumberOfPoints())); + } + + unsigned GetPointIdForPoint(const PointType& point) const { + return point.ptId; + } + + + private: + TrivialVectorialRepresenter() {} + + TrivialVectorialRepresenter(unsigned numberOfPoints) { + initializeObject(numberOfPoints); + } + + void initializeObject(unsigned numberOfPoints) { + + // the domain for vectors correspond to the valid indices. + DomainType::DomainPointsListType domainPoints; + for (unsigned i = 0; i < numberOfPoints; i++) { + domainPoints.push_back(PointIdType(i)); + } + m_domain = DomainType(domainPoints); + } + + DomainType m_domain; + + TrivialVectorialRepresenter(const TrivialVectorialRepresenter& orig); + TrivialVectorialRepresenter& operator=(const TrivialVectorialRepresenter& rhs); +}; + +} + +#endif diff --git a/Components/Metrics/ActiveRegistrationModel/Statismo/core/include/genericRepresenterTest.hxx b/Components/Metrics/ActiveRegistrationModel/Statismo/core/include/genericRepresenterTest.hxx new file mode 100644 index 000000000..72226cd80 --- /dev/null +++ b/Components/Metrics/ActiveRegistrationModel/Statismo/core/include/genericRepresenterTest.hxx @@ -0,0 +1,381 @@ +/* + * This file is part of the statismo library. + * + * Author: Marcel Luethi (marcel.luethi@unibas.ch) + * + * Copyright (c) 2011 University of Basel + * All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions + * are met: + * + * Redistributions of source code must retain the above copyright notice, + * this list of conditions and the following disclaimer. + * + * Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * + * Neither the name of the project's author nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS + * FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + * HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED + * TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR + * PROFITS; OR BUSINESS addINTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF + * LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING + * NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + */ + +#include + +#include "CommonTypes.h" +#include "HDF5Utils.h" +#include "StatismoUtils.h" + +/** + * This class provides generic tests for representer. The tests need to hold for all representers. + */ +template +class GenericRepresenterTest { + + + // we define typedefs for all required typenames, to force a compilation error if one of them + // is not defined + typedef typename Representer::DatasetConstPointerType DatasetConstPointerType; + typedef typename Representer::DatasetPointerType DatasetPointerType; + typedef typename Representer::PointType PointType; + typedef typename Representer::ValueType ValueType; + + typedef typename Representer::DomainType DomainType; + + public: + + /// Create new test with the given representer. + /// Tests are performed using the given testDataset and the pointValuePair. + /// It is assumed that the PointValuePair is taken from the testDataset (otherwise some tests will fail). + GenericRepresenterTest(const Representer* representer, DatasetConstPointerType testDataset, std::pair pointValuePair) + : m_representer(representer), + m_testDataset(testDataset), + m_testPoint(pointValuePair.first), + m_testValue(pointValuePair.second) { + } + + + bool testSamplePointEvaluation() const { + DatasetConstPointerType sample = m_testDataset; + unsigned id = m_representer->GetPointIdForPoint(m_testPoint); + ValueType val = m_representer->PointSampleFromSample(sample, id); + statismo::VectorType valVec = m_representer->PointSampleToPointSampleVector(val); + + // the obtained value should correspond to the value that is obtained by obtaining the sample vector, and evaluating it at the given position + statismo::VectorType sampleVector = m_representer->SampleToSampleVector(sample); + for (unsigned i = 0; i < m_representer->GetDimensions(); ++i) { + unsigned idx = m_representer->MapPointIdToInternalIdx(id, i); + if (sampleVector(i) != valVec(i)) { + return false; + } + } + return true; + } + + bool testDomainValid() const { + std::cout << "testDomainValid" << std::endl; + + const DomainType domain = m_representer->GetDomain(); + typename DomainType::DomainPointsListType domPoints = domain.GetDomainPoints(); + + if (domPoints.size() == 0) { + std::cout << "representer defined empty domain" << std::endl; + return false; + } + if (domPoints.size() != domain.GetNumberOfPoints()) { + std::cout << "domPoints.size() != domain.GetNumberOfPoints() (" << domPoints.size() << " != " << domain.GetNumberOfPoints() << std::endl; + return false; + } + // if we convert a dataset to a samplevector, the resulting vector needs to have + // as many entries as there are points * dimensions + DatasetConstPointerType sample = m_testDataset; + statismo::VectorType sampleVector = m_representer->SampleToSampleVector(sample); + if (sampleVector.rows() != m_representer->GetDimensions() * domain.GetNumberOfPoints()) { + std::cout << "the dimension of the sampleVector does not agree with the number of points in the domain (#points * dimensionality)" << std::endl; + return false; + } + + + unsigned ptNo = 0; + for (typename DomainType::DomainPointsListType::const_iterator it = domPoints.begin(); + it != domPoints.end(); + ++it) { + // since this can take long, we only do it for every 10th point + if (ptNo % 10 != 0) + break; + + if (m_representer->GetPointIdForPoint(*it) >= domain.GetNumberOfPoints()) { + std::cout << "a point in the domain did not evaluate to a valid point it" << std::endl; + return false; + } + ptNo++; + } + + + return true; + + } + + + + /// test whether converting a sample to a vector and back to a sample yields the original sample + bool testSampleToVectorAndBack() const { + std::cout << "testSampleToVectorToSample" << std::endl; + + statismo::VectorType sampleVec = getSampleVectorFromTestDataset(); + + DatasetConstPointerType reconstructedSample = m_representer->SampleVectorToSample(sampleVec); + + // as we don't know anything about how to compare samples, we compare their vectorial representation + statismo::VectorType reconstructedSampleAsVec = m_representer->SampleToSampleVector(reconstructedSample); + bool isOkay = assertSampleVectorsEqual(sampleVec, reconstructedSampleAsVec); + if (isOkay == false) { + std::cout << "Error: the sample has changed by converting between the representations " << std::endl; + } + return true; + } + + /// test if the pointSamples have the correct dimensionality + bool testPointSampleDimension() const { + std::cout << "testPointSampleDimension" << std::endl; + + statismo::VectorType valVec = m_representer->PointSampleToPointSampleVector(m_testValue); + + if (valVec.rows() != m_representer->GetDimensions()) { + std::cout << "Error: The dimensionality of the pointSampleVector is not the same as the Dimensionality of the representer" << std::endl; + return false; + } + return true; + } + + + /// tests if the conversion from a pointSample and the pointSampleVector, and back to a pointSample + /// yields the original sample. + bool testPointSampleToPointSampleVectorAndBack() const { + std::cout << "testPointSampleToPointSampleVectorAndBack" << std::endl; + + statismo::VectorType valVec = m_representer->PointSampleToPointSampleVector(m_testValue); + ValueType recVal = m_representer->PointSampleVectorToPointSample(valVec); + + // we compare the vectors and not the points, as we don't know how to compare poitns. + statismo::VectorType recValVec = m_representer->PointSampleToPointSampleVector(recVal); + bool ok = assertSampleVectorsEqual(valVec, recValVec); + if (!ok) { + std::cout << "Error: the point sample has changed by converting between the representations" << std::endl; + } + return ok; + } + + /// test if the testSample contains the same entries in the vector as those obtained by taking the + /// pointSample at the corresponding position. + bool testSampleVectorHasCorrectValueAtPoint() const { + std::cout << "testSampleVectorHasCorrectValueAtPoint" << std::endl; + + unsigned ptId = m_representer->GetPointIdForPoint(m_testPoint); + if (ptId < 0 || ptId >= m_representer->GetNumberOfPoints()) { + std::cout << "Error: invalid point id for test point " << ptId << std::endl; + return false; + } + + // the value of the point in the sample vector needs to correspond the the value that was provided + statismo::VectorType sampleVec = getSampleVectorFromTestDataset(); + statismo::VectorType pointSampleVec = m_representer->PointSampleToPointSampleVector(m_testValue); + + for (unsigned d = 0; d < m_representer->GetDimensions(); ++d) { + unsigned idx = m_representer->MapPointIdToInternalIdx(ptId, d); + if (sampleVec[idx] != pointSampleVec[d]) { + std::cout << "Error: the sample vector does not contain the correct value of the pointSample " << std::endl; + return false; + } + } + return true; + + } + + + /// test whether the representer is correctly restored + bool testSaveLoad() const { + std::cout << "testSaveLoad" << std::endl; + + using namespace H5; + + std::string filename = statismo::Utils::CreateTmpName(".rep"); + H5File file; + try { + file = H5File( filename, H5F_ACC_TRUNC ); + } catch (Exception& e) { + std::string msg(std::string("Error: Could not open HDF5 file for writing \n") + e.getCDetailMsg()); + std::cout << msg << std::endl; + return false; + } + H5::Group representerGroup = file.createGroup("/representer"); + + m_representer->Save(representerGroup); + + // We add the required attributes, which are usually written by the StatisticalModel class. + // This is needed, as some representers check on these values. + + statismo::HDF5Utils::writeStringAttribute(representerGroup, "name", m_representer->GetName()); + std::string dataTypeStr = Representer::TypeToString(m_representer->GetType()); + statismo::HDF5Utils::writeStringAttribute(representerGroup, "datasetType", dataTypeStr); + + file.close(); + try { + file = H5File(filename.c_str(), H5F_ACC_RDONLY); + } catch (Exception& e) { + std::string msg(std::string("Error: could not open HDF5 file \n") + e.getCDetailMsg()); + std::cout << msg << std::endl; + return false; + } + + representerGroup.close(); + representerGroup = file.openGroup("/representer"); + + Representer* newRep = Representer::Create(); + newRep->Load(representerGroup); + + bool isOkay = assertRepresenterEqual(newRep, m_representer); + newRep->Delete(); + + return isOkay; + + } + + /// test whether cloning a representer results in a representer with the same behaviour + bool testClone() const { + std::cout << "testClone" << std::endl; + bool isOkay = true; + + Representer* rep = m_representer->Clone(); + if (assertRepresenterEqual(rep, m_representer) == false) { + std::cout << "Error: the clone of the representer is not the same as the representer " << std::endl; + isOkay = false; + } + rep->Delete(); + return isOkay; + } + + /// test if the sample vector dimensions are correct + bool testSampleVectorDimensions() const { + std::cout << "testSampleVectorDimensions()" << std::endl; + statismo::VectorType testSampleVec = getSampleVectorFromTestDataset(); + + bool isOk = m_representer->GetDimensions() * m_representer->GetNumberOfPoints() == testSampleVec.rows(); + if (!isOk) { + std::cout << "Error: Dimensionality of the sample vector does not agree with the representer parameters " + << "dimension and numberOfPoints" << std::endl; + std::cout << testSampleVec.rows() << " != " << m_representer->GetDimensions() << " * " << m_representer->GetNumberOfPoints() << std::endl; + } + return isOk; + } + + /// test whether the name is defined + bool testGetName() const { + std::cout << "testGetName" << std::endl; + + if (m_representer->GetName() == "") { + std::cout << "Error: representer name has to be non empty" << std::endl; + return false; + } + return true; + } + + /// test if the dimensionality is nonnegative + bool testDimensions() const { + std::cout << "testDimensions " << std::endl; + + if (m_representer->GetDimensions() <= 0) { + std::cout << "Error: Dimensionality of representer has to be > 0" << std::endl; + return false; + } + return true; + } + + /// run all the tests + bool runAllTests() { + bool ok = true; + ok = testPointSampleDimension() && ok; + ok = testSamplePointEvaluation() && ok; + ok = testDomainValid() && ok; + ok = testPointSampleToPointSampleVectorAndBack() && ok; + ok = testSampleVectorHasCorrectValueAtPoint() && ok; + ok = testSampleToVectorAndBack() && ok; + ok = testSaveLoad() && ok; + ok = testClone() && ok; + ok = testSampleVectorDimensions() && ok; + ok = testGetName() && ok; + ok = testDimensions() && ok; + return ok; + } + + + private: + + + bool assertRepresenterEqual(const Representer* representer1, const Representer* representer2) const { + if (representer1->GetNumberOfPoints() != representer2->GetNumberOfPoints()) { + std::cout << "the representers do not have the same nubmer of points " <SampleToSampleVector(sample); + return sampleVec; + } + + + const Representer* m_representer; + DatasetConstPointerType m_testDataset; + PointType m_testPoint; + ValueType m_testValue; + +}; + + diff --git a/Components/Metrics/ActiveRegistrationModel/Statismo/core/src/CMakeLists.txt b/Components/Metrics/ActiveRegistrationModel/Statismo/core/src/CMakeLists.txt new file mode 100644 index 000000000..2eb19af75 --- /dev/null +++ b/Components/Metrics/ActiveRegistrationModel/Statismo/core/src/CMakeLists.txt @@ -0,0 +1,14 @@ +find_package(Boost REQUIRED COMPONENTS system filesystem) +find_package(ITKInternalEigen3 REQUIRED) + +add_library(statismo_core STATIC ModelInfo.cxx) + +target_include_directories(statismo_core PUBLIC + ${Boost_INCLUDE_DIRS} + ${HDF5_INCLUDE_DIR} + ${CMAKE_CURRENT_SOURCE_DIR}/../include + ${CMAKE_CURRENT_SOURCE_DIR}/../../ITK/include) + +target_link_libraries(statismo_core ${Boost_LIBRARIES} ITKInternalEigen3::Eigen) + +elastix_export_target(statismo_core) diff --git a/Components/Metrics/ActiveRegistrationModel/Statismo/core/src/ModelInfo.cxx b/Components/Metrics/ActiveRegistrationModel/Statismo/core/src/ModelInfo.cxx new file mode 100644 index 000000000..48da39ca7 --- /dev/null +++ b/Components/Metrics/ActiveRegistrationModel/Statismo/core/src/ModelInfo.cxx @@ -0,0 +1,307 @@ +/* + * This file is part of the statismo library. + * + * Author: Marcel Luethi (marcel.luethi@unibas.ch) + * + * Copyright (c) 2011 University of Basel + * All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions + * are met: + * + * Redistributions of source code must retain the above copyright notice, + * this list of conditions and the following disclaimer. + * + * Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * + * Neither the name of the project's author nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS + * FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + * HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED + * TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR + * PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF + * LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING + * NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + */ + +#include +#include + +#include "DataManager.h" +#include "Exceptions.h" +#include "HDF5Utils.h" +#include "ModelInfo.h" + +namespace statismo { + +ModelInfo::ModelInfo() { +} + +ModelInfo::ModelInfo(const MatrixType &scores, const ModelInfo::BuilderInfoList &builderInfos) + : m_scores(scores), m_builderInfo(builderInfos) { +} + +ModelInfo::ModelInfo(const MatrixType &scores) + : m_scores(scores) { +} + +ModelInfo::~ModelInfo() {} + +ModelInfo &ModelInfo::operator=(const ModelInfo &rhs) { + if (this == &rhs) { + return *this; + } + this->m_builderInfo = rhs.m_builderInfo; + this->m_scores = rhs.m_scores; + return *this; +} + +ModelInfo::BuilderInfoList ModelInfo::GetBuilderInfoList() const { + return m_builderInfo; +} + +const MatrixType &ModelInfo::GetScoresMatrix() const { + return m_scores; +} + +void +ModelInfo::Save(const H5::H5Location& publicFg) const { + using namespace H5; + + // get time and date + time_t rawtime; + struct tm * timeinfo; + std::time ( &rawtime ); + timeinfo = std::localtime ( &rawtime ); + + + try { + Group publicInfo = publicFg.createGroup("./modelinfo"); + HDF5Utils::writeString(publicInfo, "./build-time", std::asctime (timeinfo)); + if (m_scores.rows() != 0 && m_scores.cols() != 0) { + HDF5Utils::writeMatrix(publicInfo, "./scores", m_scores); + } else { + // HDF5 does not allow us to write empty matrices. Therefore, we write a dummy matrix with 1 element + HDF5Utils::writeMatrix(publicInfo, "./scores", MatrixType::Zero(1,1)); + } + + + for (unsigned i =0; i < m_builderInfo.size(); i++) { + std::ostringstream ss; + ss << "./modelBuilder-" << i; + Group modelBuilderGroup = publicInfo.createGroup(ss.str().c_str()); + m_builderInfo[i].Save(modelBuilderGroup); + modelBuilderGroup.close(); + } + + publicInfo.close(); + + } catch (H5::Exception& e) { + std::string msg(std::string("an exception occurred while writing model info HDF5 file \n") + e.getCDetailMsg()); + throw StatisticalModelException(msg.c_str()); + } + +} + +void +ModelInfo::Load(const H5::H5Location& publicFg) { + using namespace H5; + Group publicModelGroup = publicFg.openGroup("./modelinfo"); + try { + HDF5Utils::readMatrix(publicModelGroup, "./scores", m_scores); + } catch (H5::Exception& e) { + // the likely cause is that there are no scores. so we set them as empty + m_scores.resize(0,0); + } + + if (m_scores.cols() == 1 && m_scores.rows() == 1 && m_scores(0,0) == 0.0) { + // we observed a dummy matrix, that was created when saving the model info. + // This means that no scores have been saved. + m_scores.resize(0,0); + } + + m_builderInfo.clear(); + unsigned numEntries = publicModelGroup.getNumObjs(); + + for (unsigned i = 0; i < numEntries; i++) { + H5std_string key = publicModelGroup.getObjnameByIdx(i); + + // Compatibility to older statismo file-format. + // if we find at this level a dataInfo object, then it needs to be an old statismo file. + if (key.find("dataInfo") != std::string::npos || key.find("builderInfo") != std::string::npos) { + BuilderInfo bi = LoadDataInfoOldStatismoFormat(publicModelGroup); + m_builderInfo.push_back(bi); + // we have all the information that is stored in the info block of an old statismo file. + // hence we can leave + break; + + } + + // check for all modelBuilder objects and compile them into a list + if (key.find("modelBuilder") != std::string::npos) { + + Group modelBuilderGroup = publicModelGroup.openGroup(key.c_str()); + BuilderInfo bi; + bi.Load(modelBuilderGroup); + m_builderInfo.push_back(bi); + } + + } + publicModelGroup.close(); +} + +inline +BuilderInfo +ModelInfo::LoadDataInfoOldStatismoFormat(const H5::H5Location& publicModelGroup) const { + using namespace H5; + + Group dataInfoGroup = publicModelGroup.openGroup("./dataInfo"); + BuilderInfo::KeyValueList dataInfo; + BuilderInfo::FillKeyValueListFromInfoGroup(dataInfoGroup, dataInfo); + dataInfoGroup.close(); + + Group builderInfoGroup = publicModelGroup.openGroup("./builderInfo"); + BuilderInfo::KeyValueList paramInfo; + BuilderInfo::FillKeyValueListFromInfoGroup(builderInfoGroup, paramInfo); + + std::string buildTime= HDF5Utils::readString(publicModelGroup,"build-time"); + + // add the information to a new BuilderInfo object + // as a first step we need to find the builderName from the parameter list + std::string builderName = ""; + for (BuilderInfo::KeyValueList::iterator it = paramInfo.begin(); it != paramInfo.end(); it++) { + if (it->first.find("BuilderName") != std::string::npos) { + builderName = it->second; + paramInfo.erase(it); + break; + } + } + + return BuilderInfo(builderName, buildTime, dataInfo, paramInfo); + +} + +///////////////////////////////////////////////////////////////////////////////////////////////////////////// +// BuilderInfo +///////////////////////////////////////////////////////////////////////////////////////////////////////////// + +BuilderInfo::BuilderInfo(const std::string &modelBuilderName, const std::string &buildTime, const BuilderInfo::DataInfoList &di, const BuilderInfo::ParameterInfoList &pi) + : m_modelBuilderName(modelBuilderName), m_buildtime(buildTime), m_dataInfo(di), m_parameterInfo(pi) { +} + +BuilderInfo::BuilderInfo(const std::string &modelBuilderName, const BuilderInfo::DataInfoList &di, const BuilderInfo::ParameterInfoList &pi) + : m_modelBuilderName(modelBuilderName), m_dataInfo(di), m_parameterInfo(pi) { + + // get time and date + time_t rawtime; + struct tm * timeinfo; + + std::time ( &rawtime ); + timeinfo = std::localtime ( &rawtime ); + m_buildtime = std::asctime (timeinfo); + +} + +BuilderInfo::BuilderInfo() {} + +BuilderInfo::~BuilderInfo() {} + +BuilderInfo &BuilderInfo::operator=(const BuilderInfo &rhs) { + if (this == &rhs) { + return *this; + } + this->m_modelBuilderName =rhs.m_modelBuilderName; + this->m_buildtime = rhs.m_buildtime; + this->m_dataInfo = rhs.m_dataInfo; + this->m_parameterInfo = rhs.m_parameterInfo; + return *this; +} + +BuilderInfo::BuilderInfo(const BuilderInfo &orig) { + operator=(orig); +} + +void +BuilderInfo::Save(const H5::H5Location& modelBuilderGroup) const { + using namespace H5; + + try { + HDF5Utils::writeString(modelBuilderGroup, "./builderName", m_modelBuilderName); + HDF5Utils::writeString(modelBuilderGroup, "./buildTime", m_buildtime); + + Group dataInfoGroup = modelBuilderGroup.createGroup("./dataInfo"); + for (DataInfoList::const_iterator it = m_dataInfo.begin(); it != m_dataInfo.end(); ++it) { + HDF5Utils::writeString(dataInfoGroup, it->first.c_str(), it->second.c_str()); + } + + + dataInfoGroup.close(); + + Group parameterGroup = modelBuilderGroup.createGroup("./parameters"); + for (ParameterInfoList::const_iterator it = m_parameterInfo.begin(); it != m_parameterInfo.end(); ++it) { + HDF5Utils::writeString(parameterGroup, it->first.c_str(), it->second.c_str()); + } + + parameterGroup.close(); + + } catch (H5::Exception& e) { + std::string msg(std::string("an exception occurred while writing model info HDF5 file \n") + e.getCDetailMsg()); + throw StatisticalModelException(msg.c_str()); + } + +} + +void +BuilderInfo::Load(const H5::H5Location& modelBuilderGroup) { + + using namespace H5; + + + m_modelBuilderName = HDF5Utils::readString(modelBuilderGroup, "./builderName"); + m_buildtime = HDF5Utils::readString(modelBuilderGroup, "./buildTime"); + + Group dataInfoGroup = modelBuilderGroup.openGroup("./dataInfo"); + FillKeyValueListFromInfoGroup(dataInfoGroup, m_dataInfo); + dataInfoGroup.close(); + + Group parameterGroup = modelBuilderGroup.openGroup("./parameters"); + FillKeyValueListFromInfoGroup(parameterGroup, m_parameterInfo); + parameterGroup.close(); + + + +} + +const BuilderInfo::DataInfoList &BuilderInfo::GetDataInfo() const { + return m_dataInfo; +} + +const BuilderInfo::ParameterInfoList &BuilderInfo::GetParameterInfo() const { + return m_parameterInfo; +} + +inline +void +BuilderInfo::FillKeyValueListFromInfoGroup(const H5::H5Location& group, KeyValueList& keyValueList) { + keyValueList.clear(); + unsigned numEntries = group.getNumObjs(); + for (unsigned i = 0; i < numEntries; i++) { + H5std_string key = group.getObjnameByIdx(i); + std::string value = HDF5Utils::readString(group, key.c_str()); + keyValueList.push_back(std::make_pair(key, value)); + } +} + + +} // end namespace diff --git a/Components/Metrics/ActiveRegistrationModel/elxActiveRegistrationModelIntensityMetric.cxx b/Components/Metrics/ActiveRegistrationModel/elxActiveRegistrationModelIntensityMetric.cxx new file mode 100644 index 000000000..67863d5f4 --- /dev/null +++ b/Components/Metrics/ActiveRegistrationModel/elxActiveRegistrationModelIntensityMetric.cxx @@ -0,0 +1,21 @@ +/*========================================================================= + * + * Copyright UMC Utrecht and contributors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0.txt + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + *=========================================================================*/ + +#include "elxActiveRegistrationModelIntensityMetric.h" + +elxInstallMacro( ActiveRegistrationModelIntensityMetric ); \ No newline at end of file diff --git a/Components/Metrics/ActiveRegistrationModel/elxActiveRegistrationModelIntensityMetric.h b/Components/Metrics/ActiveRegistrationModel/elxActiveRegistrationModelIntensityMetric.h new file mode 100644 index 000000000..630e3599a --- /dev/null +++ b/Components/Metrics/ActiveRegistrationModel/elxActiveRegistrationModelIntensityMetric.h @@ -0,0 +1,242 @@ +/*========================================================================= + * + * Copyright UMC Utrecht and contributors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0.txt + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + *=========================================================================*/ +#ifndef __elxActiveRegistrationModelIntensityMetric_h__ +#define __elxActiveRegistrationModelIntensityMetric_h__ + +#include "elxIncludes.h" // include first to avoid MSVS warning +#include "itkActiveRegistrationModelIntensityMetric.h" + +#include "itkDirectory.h" +#include "itkImageFileReader.h" +#include "itkImageFileWriter.h" + +#include "itkStatisticalModel.h" +#include "itkStandardImageRepresenter.h" +#include "itkPCAModelBuilder.h" +#include "itkReducedVarianceModelBuilder.h" +#include "itkStatismoIO.h" + +namespace elastix +{ + +/** + * \class AdvancedMeanSquaresMetric + * \brief An metric based on the itk::AdvancedMeanSquaresImageToImageMetric. + * + * The parameters used in this class are: + * \parameter Metric: Select this metric as follows:\n + * (Metric "AdvancedMeanSquares") + * \parameter UseNormalization: Bool to use normalization or not.\n + * If true, the MeanSquares is divided by a factor (range/10)^2, + * where range represents the maximum gray value range of the images.\n + * (UseNormalization "true")\n + * The default value is false. + * + * \ingroup Metrics + * + */ + +template< class TElastix > +class ActiveRegistrationModelIntensityMetric : + public + itk::ActiveRegistrationModelIntensityMetric< + typename MetricBase< TElastix >::FixedImageType, + typename MetricBase< TElastix >::MovingImageType >, + public MetricBase< TElastix > +{ +public: + + /** Standard ITK-stuff. */ + typedef ActiveRegistrationModelIntensityMetric Self; + typedef itk::ActiveRegistrationModelIntensityMetric< + typename MetricBase< TElastix >::FixedImageType, + typename MetricBase< TElastix >::MovingImageType > Superclass1; + typedef MetricBase< TElastix > Superclass2; + typedef itk::SmartPointer< Self > Pointer; + typedef itk::SmartPointer< const Self > ConstPointer; + + /** Method for creation through the object factory. */ + itkNewMacro( Self ); + + /** Run-time type information (and related methods). */ + itkTypeMacro(ActiveRegistrationModelIntensityMetric, itk::ImageIntensityMetric ); + + /** Name of this class. + * Use this name in the parameter file to select this specific metric. \n + * example: (Metric "AdvancedMeanSquares")\n + */ + elxClassNameMacro( "ActiveRegistrationModelIntensityMetric" ); + + /** Typedefs from the superclass. */ + typedef typename + Superclass1::CoordinateRepresentationType CoordinateRepresentationType; + typedef typename Superclass1::MovingImageType MovingImageType; + typedef typename Superclass1::MovingImagePixelType MovingImagePixelType; + typedef typename Superclass1::MovingImageConstPointer MovingImageConstPointer; + typedef typename Superclass1::FixedImageType FixedImageType; + typedef typename Superclass1::FixedImageConstPointer FixedImageConstPointer; + typedef typename Superclass1::FixedImageRegionType FixedImageRegionType; + typedef typename Superclass1::TransformType TransformType; + typedef typename Superclass1::TransformPointer TransformPointer; + typedef typename Superclass1::InputPointType InputPointType; + typedef typename Superclass1::OutputPointType OutputPointType; + typedef typename Superclass1::TransformParametersType TransformParametersType; + typedef typename Superclass1::TransformJacobianType TransformJacobianType; + typedef typename Superclass1::InterpolatorType InterpolatorType; + typedef typename Superclass1::InterpolatorPointer InterpolatorPointer; + typedef typename Superclass1::RealType RealType; + typedef typename Superclass1::GradientPixelType GradientPixelType; + typedef typename Superclass1::GradientImageType GradientImageType; + typedef typename Superclass1::GradientImagePointer GradientImagePointer; + typedef typename Superclass1::GradientImageFilterType GradientImageFilterType; + typedef typename Superclass1::GradientImageFilterPointer GradientImageFilterPointer; + typedef typename Superclass1::FixedImageMaskType FixedImageMaskType; + typedef typename Superclass1::FixedImageMaskPointer FixedImageMaskPointer; + typedef typename Superclass1::MovingImageMaskType MovingImageMaskType; + typedef typename Superclass1::MovingImageMaskPointer MovingImageMaskPointer; + typedef typename Superclass1::MeasureType MeasureType; + typedef typename Superclass1::DerivativeType DerivativeType; + typedef typename Superclass1::ParametersType ParametersType; + typedef typename Superclass1::FixedImagePixelType FixedImagePixelType; + typedef typename Superclass1::MovingImageRegionType MovingImageRegionType; + typedef typename Superclass1::ImageSamplerType ImageSamplerType; + typedef typename Superclass1::ImageSamplerPointer ImageSamplerPointer; + typedef typename Superclass1::ImageSampleContainerType ImageSampleContainerType; + typedef typename + Superclass1::ImageSampleContainerPointer ImageSampleContainerPointer; + typedef typename Superclass1::FixedImageLimiterType FixedImageLimiterType; + typedef typename Superclass1::MovingImageLimiterType MovingImageLimiterType; + typedef typename + Superclass1::FixedImageLimiterOutputType FixedImageLimiterOutputType; + typedef typename + Superclass1::MovingImageLimiterOutputType MovingImageLimiterOutputType; + typedef typename + Superclass1::MovingImageDerivativeScalesType MovingImageDerivativeScalesType; + + /** The fixed image dimension. */ + itkStaticConstMacro( FixedImageDimension, unsigned int, + FixedImageType::ImageDimension ); + + /** The moving image dimension. */ + itkStaticConstMacro( MovingImageDimension, unsigned int, + MovingImageType::ImageDimension ); + + /** Typedef's inherited from Elastix. */ + typedef typename Superclass2::ElastixType ElastixType; + typedef typename Superclass2::ElastixPointer ElastixPointer; + typedef typename Superclass2::ConfigurationType ConfigurationType; + typedef typename Superclass2::ConfigurationPointer ConfigurationPointer; + typedef typename Superclass2::RegistrationType RegistrationType; + typedef typename Superclass2::RegistrationPointer RegistrationPointer; + typedef typename Superclass2::ITKBaseType ITKBaseType; + + typedef typename Superclass1::FixedImagePointType FixedImagePointType; + typedef typename Superclass1::MovingImagePointType MovingImagePointType; + + typedef typename Superclass1::StatisticalModelImageType StatisticalModelImageType; + typedef typename StatisticalModelImageType::Pointer StatisticalModelImagePointer; + + typedef typename itk::ImageFileReader< StatisticalModelImageType > ImageReaderType; + typedef typename ImageReaderType::Pointer ImageReaderPointer; + + typedef typename Superclass1::StatisticalModelIdType StatisticalModelIdType; + typedef typename Superclass1::StatisticalModelPointer StatisticalModelPointer; + + typedef typename Superclass1::StatisticalModelVectorType StatisticalModelVectorType; + typedef std::vector< std::string > StatisticalModelPathVectorType; + + typedef typename Superclass1::MovingImagePointer MovingImagePointer; + + typedef typename Superclass1::StatisticalModelRepresenterType StatisticalModelRepresenterType; + typedef typename Superclass1::StatisticalModelRepresenterPointer StatisticalModelRepresenterPointer; + + typedef typename Superclass1::StatisticalModelModelBuilderType StatisticalModelBuilderType; + typedef typename Superclass1::StatisticalModelBuilderPointer StatisticalModelBuilderPointer; + + typedef typename Superclass1::StatisticalModelReducedVarianceBuilderType StatisticalModelReducedVarianceBuilderType; + typedef typename Superclass1::StatisticalModelReducedVarianceBuilderPointer StatisticalModelReducedVarianceBuilderPointer; + + typedef typename Superclass1::StatisticalModelDataManagerType StatisticalModelDataManagerType; + typedef typename Superclass1::StatisticalModelDataManagerPointer StatisticalModelDataManagerPointer; + + typedef typename Superclass1::StatisticalModelContainerType StatisticalModelContainerType; + typedef typename Superclass1::StatisticalModelContainerPointer StatisticalModelContainerPointer; + + typedef itk::ImageFileWriter< FixedImageType > FixedImageFileWriterType; + typedef typename FixedImageFileWriterType::Pointer FixedImageFileWriterPointer; + typedef itk::ImageFileWriter< MovingImageType > MovingImageFileWriterType; + typedef typename MovingImageFileWriterType::Pointer MovingImageFileWriterPointer; + + itkSetMacro( MetricNumber, std::string ); + itkGetMacro( MetricNumber, std::string ); + + StatisticalModelDataManagerPointer ReadImagesFromDirectory( std::string imageDataDirectory, + std::string referenceFilename ); + + bool ReadImage( const std::string & imageFilename, StatisticalModelImagePointer & image ); + + StatisticalModelPathVectorType ReadPath( std::string parameter ); + + StatisticalModelVectorType ReadNoiseVariance(); + + StatisticalModelVectorType ReadTotalVariance(); + + /** Sets up a timer to measure the initialization time and + * calls the Superclass' implementation. + */ + virtual void Initialize( void ) override; + + virtual int BeforeAllBase( void ) override; + + virtual void BeforeRegistration( void ) override; + + virtual void AfterEachIteration( void ) override; + + virtual void AfterEachResolution( void ) override; + + virtual void AfterRegistration( void ) override; + +protected: + /** The constructor. */ + ActiveRegistrationModelIntensityMetric() = default; + /** The destructor. */ + ~ActiveRegistrationModelIntensityMetric() override = default; + +private: + elxOverrideGetSelfMacro; + + /** The deleted copy constructor. */ + ActiveRegistrationModelIntensityMetric(const Self &) = delete; + /** The deleted assignment operator. */ + void operator=(const Self &) = delete; + + std::string m_MetricNumber; + StatisticalModelPathVectorType m_LoadIntensityModelFileNames; + StatisticalModelPathVectorType m_SaveIntensityModelFileNames; + StatisticalModelPathVectorType m_ImageDirectories; + StatisticalModelPathVectorType m_ReferenceFilenames; + +}; + +} // end namespace elastix + +#ifndef ITK_MANUAL_INSTANTIATION +#include "elxActiveRegistrationModelIntensityMetric.hxx" +#endif + +#endif // end #ifndef __elxActiveRegistrationModelIntensityMetric_h__ diff --git a/Components/Metrics/ActiveRegistrationModel/elxActiveRegistrationModelIntensityMetric.hxx b/Components/Metrics/ActiveRegistrationModel/elxActiveRegistrationModelIntensityMetric.hxx new file mode 100644 index 000000000..b16f0c730 --- /dev/null +++ b/Components/Metrics/ActiveRegistrationModel/elxActiveRegistrationModelIntensityMetric.hxx @@ -0,0 +1,712 @@ +/*========================================================================= + * + * Copyright UMC Utrecht and contributors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0.txt + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + *=========================================================================*/ +#ifndef __elxActiveRegistrationModelIntensityMetric_hxx__ +#define __elxActiveRegistrationModelIntensityMetric_hxx__ + +#include "elxActiveRegistrationModelIntensityMetric.h" +#include "itkTimeProbe.h" +#include "itkCastImageFilter.h" + +namespace elastix +{ + +/** + * ******************* Initialize *********************** + */ + + template< class TElastix > + void + ActiveRegistrationModelIntensityMetric< TElastix > + ::Initialize( void ) + { + itk::TimeProbe timer; + timer.Start(); + this->Superclass1::Initialize(); + timer.Stop(); + elxout << "Initialization of ActiveRegistrationModelIntensityMetric metric took: " + << static_cast< long >( timer.GetMean() * 1000 ) << " ms." << std::endl; + + } // end Initialize() + + + +/** + * ***************** BeforeAllBase *********************** + */ + +template< class TElastix > +int +ActiveRegistrationModelIntensityMetric< TElastix > +::BeforeAllBase( void ) +{ + + this->Superclass2::BeforeAllBase(); + + std::string componentLabel( this->GetComponentLabel() ); + std::string metricNumber = componentLabel.substr( 6, 2 ); // strip "Metric" keep number + this->SetMetricNumber( metricNumber ); + + // Paths to shape models for loading + this->m_LoadIntensityModelFileNames = ReadPath( std::string("LoadIntensityModel") ); + + // Paths to shape models for saving + this->m_SaveIntensityModelFileNames = ReadPath( std::string( "SaveIntensityModel" ) ); + + // Paths to directories with images for model building + this->m_ImageDirectories = ReadPath( std::string("BuildIntensityModel") ); + + if( this->m_SaveIntensityModelFileNames.size() > 0 ) + { + if( this->m_SaveIntensityModelFileNames.size() != this->m_ImageDirectories.size() ) + { + itkExceptionMacro( "The number of destinations for saving intensity models must match the number of directories." ); + } + } + + if( this->m_ImageDirectories.size() > 0 ) + { + // Reference images for model building + this->m_ReferenceFilenames = ReadPath( "ReferenceImage" ); + + if( this->m_ReferenceFilenames.size() != this->m_ImageDirectories.size() ) + { + itkExceptionMacro( << "The number of reference images does not match the number of directories given." ); + } + } + + // Write reconstructed image each iteration + std::string value = ""; + this->m_Configuration->ReadParameter( value, "WriteReconstructedImageEachIteration", 0 ); + if( value == "true" ) + { + // this->WriteReconstructedImageEachIterationOn(); + } + + // At least one model must be specified + if( 0 == ( this->m_LoadIntensityModelFileNames.size() + this->m_ImageDirectories.size() ) ) + { + itkExceptionMacro( << "No statistical image model specified for " << this->GetComponentLabel() << "." << std::endl + << " Specify previously built models with (LoadIntensityModel" << this->GetMetricNumber() + << " \"path/to/hdf5/file1\" \"path/to/hdf5/file2\" ) or " << std::endl + << " specify directories with shapes using (BuildIntensityModel" << this->GetMetricNumber() + << " \"path/to/directory1\" \"path/to/directory2\") and " << std::endl + << " corresponding reference shapes using \"(ReferenceImage" << this->GetMetricNumber() + << " \"path/to/reference1\" \"path/to/reference2\")." << std::endl + ); + } + + return 0; + +} // end BeforeAllBase() + + + +/** + * ***************** BeforeRegistration *********************** + */ + +template< class TElastix > +void +ActiveRegistrationModelIntensityMetric< TElastix > +::BeforeRegistration( void ) +{ + StatisticalModelContainerPointer statisticalModelContainer = StatisticalModelContainerType::New(); + statisticalModelContainer->Reserve( this->m_LoadIntensityModelFileNames.size() + this->m_ImageDirectories.size() ); + + // Load models + if( this->m_LoadIntensityModelFileNames.size() > 0 ) + { + elxout << std::endl << "Loading models for " << this->GetComponentLabel() << ":" << this->elxGetClassName() << " ... " << std::endl; + + for( StatisticalModelIdType statisticalModelId = 0; statisticalModelId < this->m_LoadIntensityModelFileNames.size(); ++statisticalModelId ) + { + // Load model + StatisticalModelPointer statisticalModel; + try + { + StatisticalModelRepresenterPointer representer = StatisticalModelRepresenterType::New(); + statisticalModel = itk::StatismoIO< StatisticalModelImageType > ::LoadStatisticalModel( representer.GetPointer(), this->m_LoadIntensityModelFileNames[ statisticalModelId ] ); + statisticalModelContainer->SetElement( statisticalModelId, statisticalModel ); + } + catch( statismo::StatisticalModelException &e ) + { + itkExceptionMacro( "Error loading statistical shape model: " << e.what() ); + } + + elxout << " Loaded model " << this->m_LoadIntensityModelFileNames[ statisticalModelId ].c_str() << "." << std::endl + << " Number modes: " << statisticalModel->GetNumberOfPrincipalComponents() << "." << std::endl + << " Variance: " << statisticalModel->GetPCAVarianceVector() << "." << std::endl + << " Noise variance: " << statisticalModel->GetNoiseVariance() << "." << std::endl; + } + } + + // Build models + if( this->m_ImageDirectories.size() ) + { + elxout << std::endl << "Building models for " << this->GetComponentLabel() << ":" << this->elxGetClassName() << " ... " << std::endl; + + // Noise parameter for probabilistic pca model + StatisticalModelVectorType noiseVariance = this->ReadNoiseVariance(); + + // Number of principal components to keep by variance + StatisticalModelVectorType totalVariance = this->ReadTotalVariance(); + + // Loop over all data directories + for( StatisticalModelIdType statisticalModelId = 0; statisticalModelId < this->m_ImageDirectories.size(); ++statisticalModelId ) + { + // Load data + StatisticalModelDataManagerPointer dataManager; + try + { + dataManager = this->ReadImagesFromDirectory( this->m_ImageDirectories[ statisticalModelId ], this->m_ReferenceFilenames[ statisticalModelId ] ); + } + catch( statismo::StatisticalModelException &e ) + { + itkExceptionMacro( "Error loading samples in " << this->m_ImageDirectories[ statisticalModelId ] << ": " << e.what() ); + } + + // Build model + elxout << " Building statistical intensity model for metric " << this->GetMetricNumber() << " ... "; + StatisticalModelPointer statisticalModel; + try + { + StatisticalModelBuilderPointer pcaModelBuilder = StatisticalModelBuilderType::New(); + statisticalModel = pcaModelBuilder->BuildNewModel( dataManager->GetData(), noiseVariance[ statisticalModelId ] ); + elxout << " Done." << std::endl + << " Number of modes: " << statisticalModel->GetNumberOfPrincipalComponents() << "." << std::endl + << " Variance: " << statisticalModel->GetPCAVarianceVector() + << " Noise variance: " << statisticalModel->GetNoiseVariance() + << "." << std::endl; + + // Pick out first principal components + if( totalVariance[ statisticalModelId ] < 1.0 ) + { + elxout << " Reducing model to " << totalVariance[ statisticalModelId ] * 100.0 << "% variance ... "; + StatisticalModelReducedVarianceBuilderPointer reducedVarianceModelBuilder = StatisticalModelReducedVarianceBuilderType::New(); + statisticalModel = reducedVarianceModelBuilder->BuildNewModelWithVariance( statisticalModel, totalVariance[ statisticalModelId ] ); + elxout << " Done." << std::endl + << " Number of modes retained: " << statisticalModel->GetNumberOfPrincipalComponents() << "." << std::endl; + } + } + catch( statismo::StatisticalModelException &e ) + { + itkExceptionMacro( << "Error building statistical shape model: " << e.what() ); + } + + if( this->m_SaveIntensityModelFileNames.size() > 0 ) + { + elxout << " Saving intensity model " << statisticalModelId << " to " << this->m_SaveIntensityModelFileNames[ statisticalModelId ] << ". " << std::endl; + try + { + itk::StatismoIO< StatisticalModelImageType >::SaveStatisticalModel(statisticalModel, this->m_SaveIntensityModelFileNames[ statisticalModelId ]); + } + catch( statismo::StatisticalModelException& e ) + { + itkExceptionMacro( "Could not save shape model to " << this->m_SaveIntensityModelFileNames[ statisticalModelId ] << "."); + } + } + + statisticalModelContainer->SetElement( statisticalModelId, statisticalModel ); + } + } + + this->SetStatisticalModelContainer( statisticalModelContainer ); + + std::cout << std::endl; +} // end BeforeRegistration() + + + +/** + * ***************** ReadPath *********************** + */ + +template< class TElastix > +typename ActiveRegistrationModelIntensityMetric< TElastix >::StatisticalModelPathVectorType +ActiveRegistrationModelIntensityMetric< TElastix > +::ReadPath( std::string path ) +{ + std::ostringstream key; + key << path << this->GetMetricNumber(); + + StatisticalModelPathVectorType pathVector; + for( unsigned int i = 0; i < this->GetConfiguration()->CountNumberOfParameterEntries( key.str() ); ++i ) + { + std::string value = ""; + this->m_Configuration->ReadParameter( value, key.str(), i ); + pathVector.push_back( value ); + } + + return pathVector; +} + + + +/** + * ***************** ReadNoiseVariance *********************** + */ + +template< class TElastix > +typename ActiveRegistrationModelIntensityMetric< TElastix >::StatisticalModelVectorType +ActiveRegistrationModelIntensityMetric< TElastix > +::ReadNoiseVariance() +{ + std::ostringstream key( "NoiseVariance", std::ios_base::ate ); + key << this->GetMetricNumber(); + + StatisticalModelVectorType noiseVarianceVector = StatisticalModelVectorType( this->m_ImageDirectories.size(), 0.0 ); + unsigned int n = this->GetConfiguration()->CountNumberOfParameterEntries( key.str() ); + + if( n == 0 ) + { + elxout << "WARNING: NoiseVariance not specified for " << this->GetComponentLabel() << ":" << this->elxGetClassName() << "." << std::endl + << " A default value of " << noiseVarianceVector[ 0 ] << " will be used (non-probabilistic PCA)." << std::endl; + + return noiseVarianceVector; + } + + for(unsigned int i = 0; i < this->GetConfiguration()->CountNumberOfParameterEntries( key.str() ); ++i) + { + std::string value = ""; + this->m_Configuration->ReadParameter( value, key.str(), i ); + + char *e; + errno = 0; + double noiseVariance = std::strtod( value.c_str(), &e ); + + if ( *e != '\0' || // error, we didn't consume the entire string + errno != 0 ) // error, overflow or underflow + { + itkExceptionMacro( << "Invalid number format for NoiseVariance entry " << i << "." ); + } + + if( noiseVariance < 0 ) + { + itkExceptionMacro( << "NoiseVariance entry number " << i << " is negative (" << noiseVariance << "). Variance must be positive by definition. Please correct your parameter file." ); + } + + elxout << " " << key.str() << ": " << noiseVariance << std::endl; + + noiseVarianceVector[ i ] = noiseVariance; + } + + if( n == 1 && noiseVarianceVector.size() > 1 ) + { + // Fill the rest of the elements + noiseVarianceVector.fill( noiseVarianceVector[ 0 ] ); + } + + return noiseVarianceVector; +} + + + +/** + * ***************** ReadTotalVariance *********************** + */ + +template< class TElastix > +typename ActiveRegistrationModelIntensityMetric< TElastix >::StatisticalModelVectorType +ActiveRegistrationModelIntensityMetric< TElastix > +::ReadTotalVariance() +{ + std::ostringstream key( "TotalVariance", std::ios_base::ate ); + key << this->GetMetricNumber(); + + StatisticalModelVectorType totalVarianceVector = StatisticalModelVectorType( this->m_ImageDirectories.size(), 1.0 ); + unsigned int n = this->GetConfiguration()->CountNumberOfParameterEntries( key.str() ); + + if( n == 0 ) + { + elxout << "WARNING: TotalVariance not specified for " << this->GetComponentLabel() << ":" << this->elxGetClassName() << "." << std::endl + << " A default value of 1.0 will be used (all principal componontents) for metric " << this->GetMetricNumber() << "." << std::endl; + + return totalVarianceVector; + } + + for(unsigned int i = 0; i < this->GetConfiguration()->CountNumberOfParameterEntries( key.str() ); ++i) + { + std::string value = ""; + this->m_Configuration->ReadParameter( value, key.str(), i ); + + char *e; + errno = 0; + double totalVariance = std::strtod( value.c_str(), &e ); + + if ( *e != '\0' || // error, we didn't consume the entire string + errno != 0 ) // error, overflow or underflow + { + itkExceptionMacro( << "Invalid number format for NoiseVariance entry " << i << "." ); + } + + if( totalVariance < 0.0 || totalVariance > 1.0 ) + { + itkExceptionMacro( << "TotalVariance entries must lie in [0.0; 1.0] but entry number " << i << " is " << totalVariance << ". Please correct your parameter file." ); + } + + totalVarianceVector[ i ] = totalVariance; + } + + if( n == 1 && totalVarianceVector.size() > 1 ) + { + // Need to fill the rest of the elements + totalVarianceVector.fill( totalVarianceVector[ 0 ] ); + } + + return totalVarianceVector; +} + + + +/** + * ***************** LoadImagesFromDirectory *********************** + */ + +template< class TElastix > +typename ActiveRegistrationModelIntensityMetric< TElastix >::StatisticalModelDataManagerPointer +ActiveRegistrationModelIntensityMetric< TElastix > +::ReadImagesFromDirectory( + std::string imageDataDirectory, + std::string referenceFilename ) +{ + + itk::Directory::Pointer directory = itk::Directory::New(); + if( !directory->Load( imageDataDirectory.c_str() ) ) + { + itkExceptionMacro( "No files found in " << imageDataDirectory << "."); + } + + // Read reference image + StatisticalModelImagePointer reference = StatisticalModelImageType::New(); + if( !ReadImage( referenceFilename, reference ) ) + { + itkExceptionMacro( "Failed to read reference file " << referenceFilename << "."); + } + + StatisticalModelRepresenterPointer representer = StatisticalModelRepresenterType::New(); + representer->SetReference( reference ); + + StatisticalModelDataManagerPointer dataManager = StatisticalModelDataManagerType::New(); + dataManager->SetRepresenter( representer.GetPointer() ); + + for( unsigned int i = 0; i < directory->GetNumberOfFiles(); ++i ) + { + const char* filename = directory->GetFile( i ); + if( std::strcmp( filename, referenceFilename.c_str() ) == 0 || std::strcmp( filename, "." ) == 0 || std::strcmp( filename, ".." ) == 0 ) + { + continue; + } + + std::string fullpath = imageDataDirectory + "/" + filename; + StatisticalModelImagePointer image = StatisticalModelImageType::New(); + + if( this->ReadImage( fullpath.c_str(), image ) ) + { + dataManager->AddDataset( image, fullpath.c_str() ); + } + } + + return dataManager; +} + + + +/** + * ************** ReadImage ********************* + */ + +template< class TElastix > +bool +ActiveRegistrationModelIntensityMetric< TElastix > +::ReadImage( + const std::string& imageFilename, + StatisticalModelImagePointer& image ) +{ + // Read the input mesh. */ + ImageReaderPointer imageReader = ImageReaderType::New(); + imageReader->SetFileName( imageFilename.c_str() ); + + elxout << " Reading input image: " << imageFilename << " ... "; + try + { + imageReader->UpdateLargestPossibleRegion(); + elxout << "done." << std::endl; + } + catch( itk::ExceptionObject & err ) + { + elxout << " skipping " << imageFilename << "(not a valid image file or file does not exist)." << std::endl; + return false; + } + + image = imageReader->GetOutput(); + return true; + +} // end ReadImage() + + + +/** + * ***************** AfterEachIteration *********************** + */ + +template< class TElastix > +void +ActiveRegistrationModelIntensityMetric< TElastix > +::AfterEachIteration( void ) +{ + const unsigned int iter = this->m_Elastix->GetIterationCounter(); + + /** Decide whether or not to write final model image */ + bool writeIntensityModelReconstructionAfterEachIteration = false; + this->m_Configuration->ReadParameter( writeIntensityModelReconstructionAfterEachIteration, + "WriteIntensityModelReconstructionAfterEachIteration", 0, false ); + + + if( writeIntensityModelReconstructionAfterEachIteration ) { + this->GetElastix()->GetElxResamplerBase()->GetAsITKBaseType()->Update(); + for( unsigned int statisticalModelId = 0; statisticalModelId < this->GetStatisticalModelContainer()->Size(); statisticalModelId++ ) { + StatisticalModelVectorType coeffs = this->GetStatisticalModelContainer()->GetElement( statisticalModelId ) + ->ComputeCoefficients(this->GetElastix()->GetElxResamplerBase()->GetAsITKBaseType()->GetOutput()); + + + std::string imageFormat = "nii.gz"; + this->m_Configuration->ReadParameter(imageFormat, "ResultImageFormat", 0, false); + + std::ostringstream makeFileName(""); + makeFileName + << this->m_Configuration->GetCommandLineArgument("-out") + << "Metric" << this->GetMetricNumber() + << "IntensityModel" << statisticalModelId + << "Iteration" << iter + << "Image." << imageFormat; + + elxout << " Writing intensity model " << statisticalModelId << " image for " + << this->GetComponentLabel() << " after iteration " << iter << " to " << makeFileName.str() << ". "; + elxout << " Coefficents are [" << coeffs << "]." << std::endl; + + MovingImageFileWriterPointer imageWriter = MovingImageFileWriterType::New(); + imageWriter->SetInput(this->GetStatisticalModelContainer()->GetElement(statisticalModelId)->DrawSample(coeffs)); + imageWriter->SetFileName(makeFileName.str()); + imageWriter->Update(); + } + } +} + + + +/** + * ***************** AfterEachResolution *********************** + */ + +template< class TElastix > +void +ActiveRegistrationModelIntensityMetric< TElastix > +::AfterEachResolution( void ) +{ + const unsigned int level = this->m_Registration->GetAsITKBaseType()->GetCurrentLevel(); + + /** Decide whether or not to write model image after each resolution */ + bool writeIntensityModelReconstructionAfterEachResolution = false; + this->m_Configuration->ReadParameter( writeIntensityModelReconstructionAfterEachResolution, + "WriteIntensityModelReconstructionAfterEachResolution", 0, false ); + + if( writeIntensityModelReconstructionAfterEachResolution ) { + this->GetElastix()->GetElxResamplerBase()->GetAsITKBaseType()->Update(); + + for( unsigned int statisticalModelId = 0; statisticalModelId < this->GetStatisticalModelContainer()->Size(); statisticalModelId++ ) { + StatisticalModelVectorType coeffs = this->GetStatisticalModelContainer()->GetElement( statisticalModelId ) + ->ComputeCoefficients(this->GetElastix()->GetElxResamplerBase()->GetAsITKBaseType()->GetOutput()); + + + std::string imageFormat = "nii.gz"; + this->m_Configuration->ReadParameter(imageFormat, "ResultImageFormat", 0, false); + + std::ostringstream makeFileName(""); + makeFileName + << this->m_Configuration->GetCommandLineArgument("-out") + << "Metric" << this->GetMetricNumber() + << "IntensityModel" << statisticalModelId + << "Resolution" << level + << "Image." << imageFormat; + + elxout << " Writing intensity model " << statisticalModelId << " image " << " for " + << this->GetComponentLabel() << " after resolution " << level << " to " << makeFileName.str() << ". "; + elxout << " Coefficents are [" << coeffs << "]." << std::endl; + + MovingImageFileWriterPointer imageWriter = MovingImageFileWriterType::New(); + imageWriter->SetInput(this->GetStatisticalModelContainer()->ElementAt( statisticalModelId )->DrawSample(coeffs)); + imageWriter->SetFileName(makeFileName.str()); + imageWriter->Update(); + } + } +} // end AfterEachResolution() + + + +/** + * ***************** AfterRegistration *********************** + */ + +template< class TElastix > +void +ActiveRegistrationModelIntensityMetric< TElastix > +::AfterRegistration( void ) +{ + /** Decide whether or not to write the mean images */ + bool writeIntensityModelMeanImage = false; + this->m_Configuration->ReadParameter( writeIntensityModelMeanImage, + "WriteIntensityModelMeanImageAfterRegistration", 0, false ); + + if( writeIntensityModelMeanImage ) + { + for( unsigned int statisticalModelId = 0; statisticalModelId < this->GetStatisticalModelContainer()->Size(); statisticalModelId++ ) + { + std::string meanImageFormat = "nii.gz"; + this->m_Configuration->ReadParameter( meanImageFormat, "ResultImageFormat", 0, false ); + + std::ostringstream makeFileName( "" ); + makeFileName + << this->m_Configuration->GetCommandLineArgument( "-out" ) + << "Metric" << this->GetMetricNumber() + << "IntensityModel" << statisticalModelId + << "MeanImage." << meanImageFormat; + + elxout << " Writing intensity model " << statisticalModelId << " mean image for " << this->GetComponentLabel() << " to " + << makeFileName.str() << std::endl; + + FixedImageFileWriterPointer imageWriter = FixedImageFileWriterType::New(); + imageWriter->SetInput( this->GetStatisticalModelContainer()->ElementAt( statisticalModelId )->DrawMean() ); + imageWriter->SetFileName( makeFileName.str() ); + imageWriter->Update(); + } + } + + /** Decide whether or not to write final model image */ + bool writeIntensityModelFinalReconstruction = false; + this->m_Configuration->ReadParameter( writeIntensityModelFinalReconstruction, + "WriteIntensityModelFinalReconstructionAfterRegistration", 0, false ); + + /** Decide whether or not to write sample probability */ + bool writeIntensityModelFinalReconstructionProbability = false; + this->m_Configuration->ReadParameter( writeIntensityModelFinalReconstructionProbability, + "WriteIntensityModelFinalReconstructionProbabilityAfterRegistration", 0, false ); + + if( writeIntensityModelFinalReconstruction || writeIntensityModelFinalReconstructionProbability ) + { + this->GetElastix()->GetElxResamplerBase()->GetAsITKBaseType()->Update(); + for( unsigned int statisticalModelId = 0; statisticalModelId < this->GetStatisticalModelContainer()->Size(); statisticalModelId++ ) { + StatisticalModelVectorType coeffs = this->GetStatisticalModelContainer()->GetElement( + statisticalModelId )->ComputeCoefficients( this->GetElastix()->GetElxResamplerBase()->GetAsITKBaseType()->GetOutput()); + + if( writeIntensityModelFinalReconstruction ) { + std::string imageFormat = "nii.gz"; + this->m_Configuration->ReadParameter(imageFormat, "ResultImageFormat", 0, false); + + std::ostringstream makeFileName(""); + makeFileName + << this->m_Configuration->GetCommandLineArgument("-out") + << "Metric" << this->GetMetricNumber() + << "IntensityModel" << statisticalModelId + << "FinalImage." << imageFormat; + + elxout << " Writing intensity model " << statisticalModelId << " final image for " << this->GetComponentLabel() << " to " << + makeFileName.str() << "."; + elxout << " Coefficents are [" << coeffs << "]." << std::endl; + + MovingImageFileWriterPointer imageWriter = MovingImageFileWriterType::New(); + imageWriter->SetInput(this->GetStatisticalModelContainer()->ElementAt(statisticalModelId)->DrawSample(coeffs)); + imageWriter->SetFileName(makeFileName.str()); + imageWriter->Update(); + } + + if( writeIntensityModelFinalReconstructionProbability ) { + std::ostringstream makeProbFileName; + makeProbFileName + << this->m_Configuration->GetCommandLineArgument("-out") + << "Metric" << this->GetMetricNumber() + << "IntensityModel" << statisticalModelId + << "Probability.txt"; + + elxout << " Writing intensity model " << statisticalModelId << " final image probablity for " << this->GetComponentLabel() + << " to " << makeProbFileName.str() << ". "; + elxout << " Coefficents are [" << coeffs << "]." << std::endl; + std::ofstream probabilityFile; + probabilityFile.open(makeProbFileName.str()); + probabilityFile << + this->GetStatisticalModelContainer()->ElementAt( statisticalModelId )->ComputeLogProbabilityOfCoefficients(coeffs); + probabilityFile.close(); + } + } + } + + bool writeIntensityModelPrincipalComponents = false; + this->m_Configuration->ReadParameter( writeIntensityModelPrincipalComponents, + "WriteIntensityModelPrincipalComponentsAfterRegistration", 0, false ); + + if( writeIntensityModelPrincipalComponents ) + { + for( unsigned int statisticalModelId = 0; statisticalModelId < this->GetStatisticalModelContainer()->Size(); statisticalModelId++ ) + { + std::string imageFormat = "nii.gz"; + this->m_Configuration->ReadParameter( imageFormat, "ResultImageFormat", 0, false ); + + MovingImageFileWriterPointer imageWriter = MovingImageFileWriterType::New(); + + for( unsigned int j = 0; j < this->GetStatisticalModelContainer()->GetElement( statisticalModelId )->GetNumberOfPrincipalComponents(); j++ ) { + StatisticalModelVectorType plus3std = StatisticalModelVectorType( + this->GetStatisticalModelContainer()->GetElement( statisticalModelId )->GetNumberOfPrincipalComponents(), 0.0 ); + plus3std[ j ] = 3.0; + + std::ostringstream makeFileNameP3STD( "" ); + makeFileNameP3STD + << this->m_Configuration->GetCommandLineArgument("-out") + << "Metric" << this->GetMetricNumber() + << "IntensityModel" << statisticalModelId + << "PC" << j << "plus3std." << imageFormat; + + elxout << " Writing intensity model " << statisticalModelId << " principal component " << j << " plus 3 standard deviations" + << " for " << this->GetComponentLabel() << " to " << makeFileNameP3STD.str() << std::endl; + imageWriter->SetInput(this->GetStatisticalModelContainer()->GetElement( statisticalModelId )->DrawSample( plus3std ) ) ; + imageWriter->SetFileName( makeFileNameP3STD.str() ); + imageWriter->Update(); + + StatisticalModelVectorType minus3std = StatisticalModelVectorType( + this->GetStatisticalModelContainer()->GetElement( statisticalModelId )->GetNumberOfPrincipalComponents(), 0.0 ); + minus3std[ j ] = -3.0; + + std::ostringstream makeFileNamePCM3STD(""); + makeFileNamePCM3STD + << this->m_Configuration->GetCommandLineArgument("-out") + << "Metric" << this->GetMetricNumber() + << "IntensityModel" << statisticalModelId + << "PC" << j << "minus3std." << imageFormat; + + elxout << " Writing intensity model " << statisticalModelId << " principal component " << j << " minus 3 standard deviations" + << " for " << this->GetComponentLabel() << " to " << makeFileNamePCM3STD.str() << std::endl; + imageWriter->SetInput(this->GetStatisticalModelContainer()->GetElement(statisticalModelId)->DrawSample( minus3std ) ); + imageWriter->SetFileName( makeFileNamePCM3STD.str() ); + imageWriter->Update(); + } + } + } +} // end AfterRegistration() + + +} // end namespace elastix + +#endif // end #ifndef __elxActiveRegistrationModelIntensityMetric_hxx__ diff --git a/Components/Metrics/ActiveRegistrationModel/elxActiveRegistrationModelShapeMetric.cxx b/Components/Metrics/ActiveRegistrationModel/elxActiveRegistrationModelShapeMetric.cxx new file mode 100644 index 000000000..024644800 --- /dev/null +++ b/Components/Metrics/ActiveRegistrationModel/elxActiveRegistrationModelShapeMetric.cxx @@ -0,0 +1,17 @@ +/*====================================================================== + + This file is part of the elastix software. + + Copyright (c) University Medical Center Utrecht. All rights reserved. + See src/CopyrightElastix.txt or http://elastix.isi.uu.nl/legal.php for + details. + + This software is distributed WITHOUT ANY WARRANTY; without even + the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR + PURPOSE. See the above copyright notices for more information. + +======================================================================*/ + +#include "elxActiveRegistrationModelShapeMetric.h" + +elxInstallMacro( ActiveRegistrationModelShapeMetric ); diff --git a/Components/Metrics/ActiveRegistrationModel/elxActiveRegistrationModelShapeMetric.h b/Components/Metrics/ActiveRegistrationModel/elxActiveRegistrationModelShapeMetric.h new file mode 100644 index 000000000..059c4c640 --- /dev/null +++ b/Components/Metrics/ActiveRegistrationModel/elxActiveRegistrationModelShapeMetric.h @@ -0,0 +1,226 @@ +/*====================================================================== + +This file is part of the elastix software. + +Copyright (c) University Medical Center Utrecht. All rights reserved. +See src/CopyrightElastix.txt or http://elastix.isi.uu.nl/legal.php for +details. + +This software is distributed WITHOUT ANY WARRANTY; without even +the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR +PURPOSE. See the above copyright notices for more information. + +======================================================================*/ +#ifndef __elxActiveRegistrationModelShapeMetric_H__ +#define __elxActiveRegistrationModelShapeMetric_H__ + +#include "elxIncludes.h" +#include "itkActiveRegistrationModelShapeMetric.h" + +#include "itkDirectory.h" +#include "itkMeshFileReader.h" +#include "itkMeshFileWriter.h" + +#include "itkStatisticalModel.h" +#include "itkStandardMeshRepresenter.h" +#include "itkPCAModelBuilder.h" +#include "itkReducedVarianceModelBuilder.h" +#include "itkStatismoIO.h" + +namespace elastix +{ +/** + * \class ActiveRegistrationModelShapeMetric + * \brief A dummy metric to generate transformed meshes at each iteration. + * This metric does not contribute to the cost function, but provides the + * options to read vtk polydata meshes from the command-line and write the + * transformed meshes to disk each iteration or resolution level. + * The command-line options for input meshes is: -fmesh<[A-Z]>. + * This metric can be used as a base for other mesh-based penalties. + * + * The parameters used in this class are: + * \parameter Metric: Select this metric as follows:\n + * (Metric "ActiveRegistrationModelShapeMetric") + * \parameter + * (WriteResultMeshAfterEachIteration "True") + * \parameter + * (WriteResultMeshAfterEachResolution "True") + * \ingroup Metrics + * + */ + +//TODO: define a base class templated on meshes in stead of 2 pointsets. +template< class TElastix > +class ActiveRegistrationModelShapeMetric : + public + itk::ActiveRegistrationModelShapeMetric< + typename MetricBase< TElastix >::FixedPointSetType, + typename MetricBase< TElastix >::MovingPointSetType >, + public MetricBase< TElastix > +{ +public: + + /** Standard ITK-stuff. */ + typedef ActiveRegistrationModelShapeMetric Self; + typedef itk::ActiveRegistrationModelShapeMetric< + typename MetricBase< TElastix >::FixedPointSetType, + typename MetricBase< TElastix >::MovingPointSetType > Superclass1; + typedef MetricBase< TElastix > Superclass2; + typedef itk::SmartPointer< Self > Pointer; + typedef itk::SmartPointer< const Self > ConstPointer; + + /** Method for creation through the object factory. */ + itkNewMacro( Self ); + + /** Run-time type information (and related methods). */ + itkTypeMacro(ActiveRegistrationModelShapeMetric, itk::ActiveRegistrationModelShapeMetric ); + + /** Name of this class. + * Use this name in the parameter file to select this specific metric. \n + * example: (Metric "ActiveRegistrationModelShapeMetric")\n + */ + elxClassNameMacro( "ActiveRegistrationModelShapeMetric" ); + + /** Typedefs from the superclass. */ + + typedef typename Superclass1::CoordinateRepresentationType CoordinateRepresentationType; + typedef typename Superclass1::FixedPointSetType FixedPointSetType; + typedef typename Superclass1::MovingPointSetType MovingPointSetType; + + typedef typename Superclass1::TransformType TransformType; + typedef typename Superclass1::TransformPointer TransformPointer; + typedef typename Superclass1::InputPointType InputPointType; + typedef typename Superclass1::OutputPointType OutputPointType; + typedef typename Superclass1::TransformParametersType TransformParametersType; + typedef typename Superclass1::TransformJacobianType TransformJacobianType; + typedef typename Superclass1::FixedImageMaskType FixedImageMaskType; + typedef typename Superclass1::FixedImageMaskPointer FixedImageMaskPointer; + typedef typename Superclass1::MovingImageMaskType MovingImageMaskType; + typedef typename Superclass1::MovingImageMaskPointer MovingImageMaskPointer; + + typedef typename Superclass1::MeasureType MeasureType; + typedef typename Superclass1::DerivativeType DerivativeType; + typedef typename Superclass1::ParametersType ParametersType; + + typedef typename OutputPointType::CoordRepType CoordRepType; + + /** Other typedef's. */ + typedef itk::Object ObjectType; + + typedef itk::AdvancedCombinationTransform< CoordRepType, + itkGetStaticConstMacro( FixedImageDimension ) > CombinationTransformType; + typedef typename + CombinationTransformType::InitialTransformType InitialTransformType; + + /** Typedefs inherited from elastix. */ + typedef typename Superclass2::ElastixType ElastixType; + typedef typename Superclass2::ElastixPointer ElastixPointer; + typedef typename Superclass2::ConfigurationType ConfigurationType; + typedef typename Superclass2::ConfigurationPointer ConfigurationPointer; + typedef typename Superclass2::RegistrationType RegistrationType; + typedef typename Superclass2::RegistrationPointer RegistrationPointer; + typedef typename Superclass2::ITKBaseType ITKBaseType; + typedef typename Superclass2::FixedImageType FixedImageType; + typedef typename Superclass2::MovingImageType MovingImageType; + + /** The fixed image dimension. */ + itkStaticConstMacro( FixedImageDimension, unsigned int, + FixedImageType::ImageDimension ); + itkStaticConstMacro( MovingImageDimension, unsigned int, + MovingImageType::ImageDimension ); + + typedef FixedImageType ImageType; + + typedef typename Superclass1::StatisticalModelVectorType StatisticalModelVectorType; + typedef std::vector< std::string > StatisticalModelPathVectorType; + + /** ActiveRegistrationModel types */ + typedef typename Superclass1::StatisticalModelMeshType StatisticalModelMeshType; + typedef typename Superclass1::StatisticalModelMeshPointer StatisticalModelMeshPointer; + + typedef typename Superclass1::MeshReaderType MeshReaderType; + typedef typename Superclass1::MeshReaderPointer MeshReaderPointer; + + typedef typename Superclass1::StatisticalModelRepresenterType StatisticalModelRepresenterType; + typedef typename Superclass1::StatisticalModelRepresenterPointer StatisticalModelRepresenterPointer; + + typedef typename Superclass1::ModelBuilderType StatisticalModelBuilderType; + typedef typename Superclass1::ModelBuilderPointer StatisticalModelBuilderPointer; + + typedef typename Superclass1::StatisticalModelReducedVarianceBuilderType StatisticalModelReducedVarianceBuilderType; + typedef typename Superclass1::StatisticalModelReducedVarianceBuilderPointer StatisticalModelReducedVarianceBuilderPointer; + + typedef typename Superclass1::StatisticalModelIdType StatisticalModelIdType; + typedef typename Superclass1::StatisticalModelPointer StatisticalModelPointer; + + typedef typename Superclass1::StatisticalModelDataManagerType StatisticalModelDataManagerType; + typedef typename Superclass1::StatisticalModelDataManagerPointer StatisticalModelDataManagerPointer; + + typedef typename Superclass1::StatisticalModelContainerType StatisticalModelContainerType; + typedef typename Superclass1::StatisticalModelContainerPointer StatisticalModelContainerPointer; + + itkSetMacro( MetricNumber, unsigned long ); + itkGetMacro( MetricNumber, unsigned long ); + + StatisticalModelDataManagerPointer ReadMeshesFromDirectory( std::string shapeDataDirectory, + std::string fixedPointSetFilename ); + + unsigned long ReadMesh( const std::string & meshFilename, StatisticalModelMeshPointer& mesh ); + + typedef itk::MeshFileWriter< StatisticalModelMeshType > MeshFileWriterType; + typedef typename MeshFileWriterType::Pointer MeshFileWriterPointer; + void WriteMesh( const char * filename, StatisticalModelMeshType mesh ); + + StatisticalModelPathVectorType ReadPath( std::string parameter ); + + StatisticalModelVectorType ReadNoiseVariance(); + + StatisticalModelVectorType ReadTotalVariance(); + + /** Sets up a timer to measure the initialization time and + * calls the Superclass' implementation. + */ + virtual void Initialize( void ) override; + + virtual int BeforeAllBase( void ) override; + + virtual void BeforeRegistration( void ) override; + + virtual void AfterEachIteration( void ) override; + + virtual void AfterEachResolution( void ) override; + + virtual void AfterRegistration( void ) override; + + +protected: + /** The constructor. */ + ActiveRegistrationModelShapeMetric() = default; + /** The destructor. */ + ~ActiveRegistrationModelShapeMetric() override = default; + +private: + elxOverrideGetSelfMacro; + + /** The deleted copy constructor. */ + ActiveRegistrationModelShapeMetric(const Self &) = delete; + /** The deleted assignment operator. */ + void operator=(const Self &) = delete; + + unsigned long m_MetricNumber; + + StatisticalModelPathVectorType m_LoadShapeModelFileNames; + StatisticalModelPathVectorType m_SaveShapeModelFileNames; + StatisticalModelPathVectorType m_ShapeDirectories; + StatisticalModelPathVectorType m_ReferenceFilenames; + +}; // end class ActiveRegistrationModel + +} // end namespace elastix + +#ifndef ITK_MANUAL_INSTANTIATION +#include "elxActiveRegistrationModelShapeMetric.hxx" +#endif + +#endif // end #ifndef __elxActiveRegistrationModelShapeMetric_h__ + diff --git a/Components/Metrics/ActiveRegistrationModel/elxActiveRegistrationModelShapeMetric.hxx b/Components/Metrics/ActiveRegistrationModel/elxActiveRegistrationModelShapeMetric.hxx new file mode 100644 index 000000000..81471aa6d --- /dev/null +++ b/Components/Metrics/ActiveRegistrationModel/elxActiveRegistrationModelShapeMetric.hxx @@ -0,0 +1,769 @@ +/*====================================================================== + + This file is part of the elastix software. + + Copyright (c) University Medical Center Utrecht. All rights reserved. + See src/CopyrightElastix.txt or http://elastix.isi.uu.nl/legal.php for + details. + + This software is distributed WITHOUT ANY WARRANTY; without even + the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR + PURPOSE. See the above copyright notices for more information. + +======================================================================*/ +#ifndef __elxActiveRegistrationModelShapeMetric_hxx__ +#define __elxActiveRegistrationModelShapeMetric_hxx__ + +#include +#include + +namespace elastix +{ + +/** + * ******************* Initialize *********************** + */ + +template< class TElastix > +void +ActiveRegistrationModelShapeMetric< TElastix > +::Initialize( void ) +{ + itk::TimeProbe timer; + timer.Start(); + this->Superclass1::Initialize(); + timer.Stop(); + elxout << "Initialization of ActiveRegistrationModel metric took: " + << static_cast< long >( timer.GetMean() * 1000 ) << " ms." << std::endl; + +} // end Initialize() + + +/** + * ***************** BeforeAllBase *********************** + */ + +template< class TElastix > +int +ActiveRegistrationModelShapeMetric< TElastix > +::BeforeAllBase( void ) +{ + + this->Superclass2::BeforeAllBase(); + + std::string componentLabel( this->GetComponentLabel() ); + std::string metricNumber = componentLabel.substr( 6, 2 ); // strip "Metric" keep number + this->SetMetricNumber( std::stoul( metricNumber ) ); + + // Paths to shape models for loading + this->m_LoadShapeModelFileNames = ReadPath( std::string("LoadShapeModel") ); + + // Paths to shape models for loading + this->m_SaveShapeModelFileNames = ReadPath( std::string("SaveShapeModel") ); + + // Paths to directories with shapes for model building + this->m_ShapeDirectories = ReadPath( std::string("BuildShapeModel") ); + + if( this->m_SaveShapeModelFileNames.size() > 0 ) + { + if( this->m_SaveShapeModelFileNames.size() != this->m_ShapeDirectories.size() ) + { + itkExceptionMacro( "The number of destinations for saving shape models must match the number of directories." ) + } + } + + if( this->m_ShapeDirectories.size() > 0 ) + { + // Reference shapes for model building + this->m_ReferenceFilenames = ReadPath("ReferenceShape"); + + if (this->m_ReferenceFilenames.size() != this->m_ShapeDirectories.size()) + { + itkExceptionMacro(<< "The number of reference shapes does not match the number of directories given."); + } + } + + // At least one model must be specified + if( 0 == ( this->m_LoadShapeModelFileNames.size() + this->m_ShapeDirectories.size() ) ) + { + itkExceptionMacro( << "No statistical shape model specified for " << this->GetComponentLabel() << "." << std::endl + << " Specify previously built models with (LoadShapeModel" << this->GetMetricNumber() + << " \"path/to/hdf5/file1\" \"path/to/hdf5/file2\" ) or " << std::endl + << " specify directories with shapes using (BuildShapeModel" << this->GetMetricNumber() + << " \"path/to/directory1\" \"path/to/directory2\") and " << std::endl + << " corresponding reference shapes using \"(ReferenceShape" << this->GetMetricNumber() + << " \"path/to/reference1\" \"path/to/reference2\")." << std::endl + ); + } + + return 0; + +} // end BeforeAllBase() + + + +/** + * ***************** BeforeRegistration *********************** + */ + +template< class TElastix > +void +ActiveRegistrationModelShapeMetric< TElastix > +::BeforeRegistration( void ) +{ + StatisticalModelContainerPointer statisticalModelContainer = StatisticalModelContainerType::New(); + statisticalModelContainer->Reserve( this->m_LoadShapeModelFileNames.size() + this->m_ShapeDirectories.size() ); + + // Load models + if( this->m_LoadShapeModelFileNames.size() > 0 ) + { + elxout << std::endl << "Loading models for " << this->GetComponentLabel() << ":" << this->elxGetClassName() << " ... " << std::endl; + + for( StatisticalModelIdType statisticalModelId = 0; statisticalModelId < this->m_LoadShapeModelFileNames.size(); ++statisticalModelId ) + { + // Load model + StatisticalModelPointer statisticalModel; + try + { + StatisticalModelRepresenterPointer representer = StatisticalModelRepresenterType::New(); + statisticalModel = itk::StatismoIO< StatisticalModelMeshType >::LoadStatisticalModel( representer, this->m_LoadShapeModelFileNames[ statisticalModelId ] ); + statisticalModelContainer->SetElement( statisticalModelId, statisticalModel ); + } + catch( statismo::StatisticalModelException &e ) + { + itkExceptionMacro( "Error loading statistical shape model: " << e.what() ); + } + + elxout << " Loaded model " << this->m_LoadShapeModelFileNames[ statisticalModelId ].c_str() << "." << std::endl; + elxout << " Number of principal components: " << statisticalModel->GetNumberOfPrincipalComponents() << "." << std::endl; + elxout << " Eigenvalues: " << statisticalModel->GetPCAVarianceVector().apply(std::sqrt) << "." << std::endl; + elxout << " Noise variance: " << statisticalModel->GetNoiseVariance() << "." << std::endl; + } + } + + // Build models + if( this->m_ShapeDirectories.size() > 0 ) + { + elxout << std::endl << "Building models for " << this->GetComponentLabel() << ":" << this->elxGetClassName() << " ... " << std::endl; + + // Noise parameter for probabilistic pca model + StatisticalModelVectorType noiseVariance = this->ReadNoiseVariance(); + + // Number of principal components to keep by variance + StatisticalModelVectorType totalVariance = this->ReadTotalVariance(); + + // Loop over all data directories + for( StatisticalModelIdType statisticalModelId = 0; statisticalModelId < this->m_ShapeDirectories.size(); ++statisticalModelId ) + { + // Load data + StatisticalModelDataManagerPointer dataManager; + try + { + dataManager = this->ReadMeshesFromDirectory(this->m_ShapeDirectories[ statisticalModelId ], + this->m_ReferenceFilenames[ statisticalModelId ]); + } + catch( statismo::StatisticalModelException &e ) + { + itkExceptionMacro( "Error loading samples in " << this->m_ShapeDirectories[ statisticalModelId ] <<": " << e.what() ); + } + + // Build model + elxout << " Building statistical shape model for metric " << this->GetMetricNumber() << "... "; + StatisticalModelPointer statisticalModel; + try + { + StatisticalModelBuilderPointer pcaModelBuilder = StatisticalModelBuilderType::New(); + statisticalModel = pcaModelBuilder->BuildNewModel( dataManager->GetData(), noiseVariance[ statisticalModelId ] ); + elxout << " Done." << std::endl + << " Number of modes: " << statisticalModel->GetNumberOfPrincipalComponents() << "." << std::endl + << " Eigenvalues: " << statisticalModel->GetPCAVarianceVector().apply(std::sqrt) << "." << std::endl + << " Noise variance: " << statisticalModel->GetNoiseVariance() + << "." << std::endl; + + // Pick out first principal components + if( totalVariance[ statisticalModelId ] < 1.0 ) + { + elxout << " Reducing model to " << totalVariance[ statisticalModelId ] * 100.0 << "% variance ... "; + StatisticalModelReducedVarianceBuilderPointer reducedVarianceModelBuilder = StatisticalModelReducedVarianceBuilderType::New(); + statisticalModel = reducedVarianceModelBuilder->BuildNewModelWithVariance( statisticalModel, totalVariance[ statisticalModelId ] ); + elxout << " Done." << std::endl + << " Number of modes retained: " << statisticalModel->GetNumberOfPrincipalComponents() << "." << std::endl; + } + } + catch( statismo::StatisticalModelException& e ) + { + itkExceptionMacro( << "Error building statistical shape model: " << e.what() ); + } + + if( this->m_SaveShapeModelFileNames.size() > 0 ) + { + elxout << " Saving shape model " << statisticalModelId << " to " << this->m_SaveShapeModelFileNames[ statisticalModelId ] << ". " << std::endl; + try + { + itk::StatismoIO< StatisticalModelMeshType >::SaveStatisticalModel(statisticalModel, this->m_SaveShapeModelFileNames[ statisticalModelId ]); + } + catch( statismo::StatisticalModelException& e ) + { + itkExceptionMacro( "Could not save shape model to " << this->m_SaveShapeModelFileNames[ statisticalModelId ] << "."); + } + } + + statisticalModelContainer->SetElement( statisticalModelId, statisticalModel ); + } + } + + this->SetStatisticalModelContainer( statisticalModelContainer ); + + // SingleValuedPointSetToPointSetMetric (from which this class is derived) needs a fixed and moving point set + typename FixedPointSetType::Pointer fixedDummyPointSet = FixedPointSetType::New(); + typename MovingPointSetType::Pointer movingDummyPointSet = MovingPointSetType::New(); + this->SetFixedPointSet( fixedDummyPointSet ); // FB: TODO solve hack + this->SetMovingPointSet( movingDummyPointSet ); // FB: TODO solve hack + + std::cout << std::endl; +} // end BeforeRegistration() + + + +/** + * ***************** loadShapesFromDirectory *********************** + */ + +template< class TElastix > +typename ActiveRegistrationModelShapeMetric< TElastix >::StatisticalModelDataManagerPointer +ActiveRegistrationModelShapeMetric< TElastix > +::ReadMeshesFromDirectory( + std::string shapeDataDirectory, + std::string referenceFilename) +{ + + itk::Directory::Pointer directory = itk::Directory::New(); + if( !directory->Load( shapeDataDirectory.c_str() ) ) + { + itkExceptionMacro( "No files found in " << shapeDataDirectory << "."); + } + + // Read reference shape + StatisticalModelMeshPointer reference = StatisticalModelMeshType::New(); + if( this->ReadMesh( referenceFilename, reference ) == 0 ) + { + itkExceptionMacro( "Failed to read reference file " << referenceFilename << "."); + } + + StatisticalModelRepresenterPointer representer = StatisticalModelRepresenterType::New(); + representer->SetReference( reference ); + + StatisticalModelDataManagerPointer dataManager = StatisticalModelDataManagerType::New(); + dataManager->SetRepresenter( representer.GetPointer() ); + + for( int i = 0; i < directory->GetNumberOfFiles(); ++i ) + { + const char * filename = directory->GetFile( i ); + if( std::strcmp( filename, referenceFilename.c_str() ) == 0 || std::strcmp( filename, "." ) == 0 || std::strcmp( filename, ".." ) == 0 ) + { + continue; + } + + std::string fullpath = shapeDataDirectory + "/" + filename; + StatisticalModelMeshPointer mesh = StatisticalModelMeshType::New(); + + unsigned long numberOfMeshPoints = this->ReadMesh( fullpath.c_str(), mesh ); + if( numberOfMeshPoints > 0 ) + { + dataManager->AddDataset( mesh, fullpath.c_str() ); + } + } + + return dataManager; +} + + + +/** + * ************** ReadShape ********************* + */ + +template< class TElastix > +unsigned long +ActiveRegistrationModelShapeMetric< TElastix > +::ReadMesh( + const std::string& meshFilename, + StatisticalModelMeshPointer& mesh ) +{ + // Read the input mesh. */ + MeshReaderPointer meshReader = MeshReaderType::New(); + meshReader->SetFileName( meshFilename.c_str() ); + + elxout << " Reading input mesh file: " << meshFilename << " ... "; + try + { + meshReader->UpdateLargestPossibleRegion(); + } + catch( itk::ExceptionObject & err ) + { + elxout << "skipping " << meshFilename << " (not a valid mesh file or file does not exist)." << std::endl; + return 0; + } + + // Some user-feedback. + mesh = meshReader->GetOutput(); + unsigned long numberOfPoints = mesh->GetNumberOfPoints(); + if( numberOfPoints > 0 ) + { + elxout << "read " << numberOfPoints << " points." << std::endl; + } + else + { + elxout << "skipping " << meshFilename << " (no points in mesh file)." << std::endl; + } + + return numberOfPoints; +} // end ReadMesh() + + + +/** + * ******************* WriteMesh ******************** + */ + +template< class TElastix > +void +ActiveRegistrationModelShapeMetric< TElastix > +::WriteMesh( const char * filename, StatisticalModelMeshType mesh ) +{ + // Create writer. + MeshFileWriterPointer meshWriter = MeshFileWriterType::New(); + + meshWriter->SetInput( mesh ); + meshWriter->SetFileName( filename ); + + try + { + meshWriter->Update(); + } + catch( itk::ExceptionObject & excp ) + { + // Add information to the exception. + excp.SetLocation( "ActiveRegistrationModel - WriteMesh()" ); + std::string err_str = excp.GetDescription(); + err_str += "\nError occurred while writing mesh.\n"; + excp.SetDescription( err_str ); + + // Pass the exception to an higher level. + throw excp; + } +} // end WriteMesh() + + + +/** + * ***************** ReadPath *********************** + */ + +template< class TElastix > +typename ActiveRegistrationModelShapeMetric< TElastix>::StatisticalModelPathVectorType +ActiveRegistrationModelShapeMetric< TElastix > +::ReadPath( std::string path ) +{ + std::ostringstream key; + key << path << this->GetMetricNumber(); + + StatisticalModelPathVectorType pathVector; + for( unsigned int i = 0; i < this->GetConfiguration()->CountNumberOfParameterEntries( key.str() ); ++i ) + { + std::string value = ""; + this->m_Configuration->ReadParameter( value, key.str(), i ); + pathVector.push_back( value ); + } + + return pathVector; +} + + + +/** + * ***************** ReadNoiseVariance *********************** + */ + +template< class TElastix > +typename ActiveRegistrationModelShapeMetric< TElastix >::StatisticalModelVectorType +ActiveRegistrationModelShapeMetric< TElastix > +::ReadNoiseVariance() +{ + std::ostringstream key( "NoiseVariance", std::ios_base::ate ); + key << this->GetMetricNumber(); + + StatisticalModelVectorType noiseVarianceVector = StatisticalModelVectorType( this->m_ShapeDirectories.size(), 0.0 ); + unsigned int n = this->GetConfiguration()->CountNumberOfParameterEntries( key.str() ); + + if( n == 0 ) + { + elxout << "WARNING: NoiseVariance not specified for " << this->GetComponentLabel() << ":" << this->elxGetClassName() << "." << std::endl + << " A default value of " << noiseVarianceVector[ 0 ] << " will be used (non-probabilistic PCA) for metric " << this->GetMetricNumber() << "." << std::endl; + + return noiseVarianceVector; + } + + for(unsigned int i = 0; i < this->GetConfiguration()->CountNumberOfParameterEntries( key.str() ); ++i) + { + std::string value = ""; + this->m_Configuration->ReadParameter( value, key.str(), i ); + + char *e; + errno = 0; + double noiseVariance = std::strtod( value.c_str(), &e ); + + if ( *e != '\0' || // error, we didn't consume the entire string + errno != 0 ) // error, overflow or underflow + { + itkExceptionMacro( << "Invalid number format for NoiseVariance entry " << i << "." ); + } + + if( noiseVariance < 0 ) + { + itkExceptionMacro( << "NoiseVariance entry number " << i << " is negative (" << noiseVariance << "). Variance must be positive by definition. Please correct your parameter file." ); + } + + noiseVarianceVector[ i ] = noiseVariance; + } + + if( n == 1 && noiseVarianceVector.size() > 1 ) + { + // Fill the rest of the elements + noiseVarianceVector.fill( noiseVarianceVector[ 0 ] ); + } + + return noiseVarianceVector; +} + + + +/** + * ***************** ReadTotalVariance *********************** + */ + +template< class TElastix > +typename ActiveRegistrationModelShapeMetric< TElastix >::StatisticalModelVectorType +ActiveRegistrationModelShapeMetric< TElastix > +::ReadTotalVariance() +{ + std::ostringstream key( "TotalVariance", std::ios_base::ate ); + key << this->GetMetricNumber(); + + StatisticalModelVectorType totalVarianceVector = StatisticalModelVectorType( this->m_ShapeDirectories.size(), 1.0 ); + unsigned int n = this->GetConfiguration()->CountNumberOfParameterEntries( key.str() ); + + if( n == 0 ) + { + elxout << "WARNING: TotalVariance not specified for " << this->GetComponentLabel() << ":" << this->elxGetClassName() << "." << std::endl + << " A default value of 1.0 will be used (all principal componontents) for metric " << this->GetMetricNumber() << "." << std::endl; + + return totalVarianceVector; + } + + for(unsigned int i = 0; i < this->GetConfiguration()->CountNumberOfParameterEntries( key.str() ); ++i) + { + std::string value = ""; + this->m_Configuration->ReadParameter( value, key.str(), i ); + + char *e; + errno = 0; + double totalVariance = std::strtod( value.c_str(), &e ); + + if ( *e != '\0' || // error, we didn't consume the entire string + errno != 0 ) // error, overflow or underflow + { + itkExceptionMacro( << "Invalid number format for NoiseVariance entry " << i << "." ); + } + + if( totalVariance < 0.0 || totalVariance > 1.0 ) + { + itkExceptionMacro( << "TotalVariance entries must lie in [0.0; 1.0] but entry number " << i << " is " << totalVariance << ". Please correct your parameter file." ); + } + + totalVarianceVector[ i ] = totalVariance; + } + + if( n == 1 && totalVarianceVector.size() > 1 ) + { + // Need to fill the rest of the elements + totalVarianceVector.fill( totalVarianceVector[ 0 ] ); + } + + return totalVarianceVector; +} + + + + +/** + * ***************** AfterEachIteration *********************** + */ + +template< class TElastix > +void +ActiveRegistrationModelShapeMetric< TElastix > +::AfterEachIteration( void ) +{ + const unsigned int iter = this->m_Elastix->GetIterationCounter(); + + /** Decide whether or not to write final model image */ + bool writeShapeModelReconstructionAfterEachIteration = false; + this->m_Configuration->ReadParameter( writeShapeModelReconstructionAfterEachIteration, + "WriteShapeModelReconstructionAfterEachIteration", 0, false ); + + if( writeShapeModelReconstructionAfterEachIteration ) { + + for( unsigned int statisticalModelId = 0; statisticalModelId < this->GetStatisticalModelContainer()->Size(); statisticalModelId++ ) { + StatisticalModelVectorType coeffs = this->GetStatisticalModelContainer()->ElementAt( statisticalModelId )->ComputeCoefficients( + this->TransformMesh( this->GetStatisticalModelContainer()->ElementAt( statisticalModelId)->DrawMean() ) ); + std::string shapeFormat = "vtk"; + this->m_Configuration->ReadParameter( shapeFormat, "ResultShapeFormat", 0, false ); + + std::ostringstream makeFileName(""); + makeFileName + << this->m_Configuration->GetCommandLineArgument("-out") + << "Metric" << this->GetMetricNumber() + << "ShapeModel" << statisticalModelId + << "Iteration" << iter + << "Shape." << shapeFormat; + + elxout << " Writing shape model " << statisticalModelId << " shape for " + << this->GetComponentLabel() << " after iteration " << iter << " to " << makeFileName.str() << ". "; + elxout << " Coefficents are [" << coeffs << "]." << std::endl; + + MeshFileWriterPointer meshWriter = MeshFileWriterType::New(); + meshWriter->SetInput( this->GetStatisticalModelContainer()->ElementAt(statisticalModelId)->DrawSample( coeffs ) ); + meshWriter->SetFileName(makeFileName.str()); + meshWriter->Update(); + } + } +} + + + +/** + * ***************** AfterEachResolution *********************** + */ + +template< class TElastix > +void +ActiveRegistrationModelShapeMetric< TElastix > +::AfterEachResolution( void ) +{ + const unsigned int level = this->m_Registration->GetAsITKBaseType()->GetCurrentLevel(); + + /** Decide whether or not to write model image after each resolution */ + bool writeShapeModelReconstructionAfterEachResolution = false; + this->m_Configuration->ReadParameter( writeShapeModelReconstructionAfterEachResolution, + "WriteShapeModelReconstructionAfterEachResolution", 0, false ); + + if( writeShapeModelReconstructionAfterEachResolution ) { + for( unsigned int statisticalModelId = 0; statisticalModelId < this->GetStatisticalModelContainer()->Size(); statisticalModelId++ ) { + StatisticalModelVectorType coeffs = this->GetStatisticalModelContainer()->ElementAt( statisticalModelId )->ComputeCoefficients( + this->TransformMesh( this->GetStatisticalModelContainer()->ElementAt( statisticalModelId)->DrawMean() ) ); + + + std::string shapeFormat = "vtk"; + this->m_Configuration->ReadParameter( shapeFormat, "ResultShapeFormat", 0, false ); + + std::ostringstream makeFileName(""); + makeFileName + << this->m_Configuration->GetCommandLineArgument("-out") + << "Metric" << this->GetMetricNumber() + << "IntensityModel" << statisticalModelId + << "Resolution" << level + << "Image." << shapeFormat; + + elxout << " Writing intensity model " << statisticalModelId << " image " << " for " + << this->GetComponentLabel() << " after resolution " << level << " to " << makeFileName.str() << ". "; + elxout << " Coefficents are [" << coeffs << "]." << std::endl; + + MeshFileWriterPointer meshWriter = MeshFileWriterType::New(); + meshWriter->SetInput( this->GetStatisticalModelContainer()->ElementAt(statisticalModelId)->DrawSample( coeffs ) ); + meshWriter->SetFileName(makeFileName.str()); + meshWriter->Update(); + } + } +} // end AfterEachResolution() + + +/** + * ***************** AfterRegistration *********************** + */ + +template< class TElastix > +void +ActiveRegistrationModelShapeMetric< TElastix > +::AfterRegistration( void ) +{ + /** Decide whether or not to write the mean images */ + bool writeShapeModelMeanShape = false; + this->m_Configuration->ReadParameter( writeShapeModelMeanShape, + "WriteShapeModelMeanShapeAfterRegistration", 0, false ); + + std::string shapeFormat = "vtk"; + this->m_Configuration->ReadParameter( shapeFormat, "ResultShapeFormat", 0, false ); + + if( writeShapeModelMeanShape ) + { + for( unsigned int statisticalModelId = 0; statisticalModelId < this->GetStatisticalModelContainer()->Size(); statisticalModelId++ ) + { + std::ostringstream makeFileName( "" ); + makeFileName + << this->m_Configuration->GetCommandLineArgument( "-out" ) + << "Metric" << this->GetMetricNumber() + << "ShapeModel" << statisticalModelId + << "MeanShape." << shapeFormat; + + elxout << " Writing statistical model " << statisticalModelId << " mean shape for " << this->GetComponentLabel() << " to " << makeFileName.str() << std::endl; + + MeshFileWriterPointer meshWriter = MeshFileWriterType::New(); + meshWriter->SetInput( this->GetStatisticalModelContainer()->ElementAt( statisticalModelId )->DrawMean() ); + meshWriter->SetFileName( makeFileName.str() ); + meshWriter->Update(); + } + } + + /** Decide whether or not to write final model image */ + bool writeShapeModelFinalReconstruction = false; + this->m_Configuration->ReadParameter( writeShapeModelFinalReconstruction, + "WriteShapeModelFinalReconstructionAfterRegistration", 0, false ); + + /** Decide whether or not to write sample probability */ + bool writeShapeModelFinalReconstructionProbability = false; + this->m_Configuration->ReadParameter( writeShapeModelFinalReconstructionProbability, + "WriteShapeModelFinalShapeProbabilityAfterRegistration", 0, false ); + + if( writeShapeModelFinalReconstruction || writeShapeModelFinalReconstructionProbability ) + { + for( unsigned int statisticalModelId = 0; statisticalModelId < this->GetStatisticalModelContainer()->Size(); statisticalModelId++ ) + { + StatisticalModelVectorType coeffs = this->GetStatisticalModelContainer()->ElementAt( statisticalModelId )->ComputeCoefficients( + this->TransformMesh( this->GetStatisticalModelContainer()->ElementAt( statisticalModelId)->DrawMean() ) ); + + if( writeShapeModelFinalReconstruction ) + { + std::ostringstream makeFileName( "" ); + makeFileName + << this->m_Configuration->GetCommandLineArgument( "-out" ) + << "Metric" << this->GetMetricNumber() + << "ShapeModel" << statisticalModelId + << "FinalShape." << shapeFormat; + + elxout << " Writing statistical model final image " << statisticalModelId << " for " << this->GetComponentLabel() << " to " << makeFileName.str() << std::endl; + + MeshFileWriterPointer meshWriter = MeshFileWriterType::New(); + meshWriter->SetInput( this->GetStatisticalModelContainer()->ElementAt( statisticalModelId )->DrawSample( coeffs ) ); + meshWriter->SetFileName( makeFileName.str() ); + meshWriter->Update(); + } + + if( writeShapeModelFinalReconstructionProbability ) { + std::ostringstream makeProbFileName; + makeProbFileName + << this->m_Configuration->GetCommandLineArgument("-out") + << "Metric" << this->GetMetricNumber() + << "ShapeModel" << statisticalModelId + << "Probability.txt"; + + elxout << " Writing shape model " << statisticalModelId << " final shape probablity for " << this->GetComponentLabel() + << " to " << makeProbFileName.str() << ". "; + elxout << " Coefficents are [" << coeffs << "]." << std::endl; + std::ofstream probabilityFile; + probabilityFile.open(makeProbFileName.str()); + probabilityFile << this->GetStatisticalModelContainer()->GetElement( statisticalModelId )->ComputeLogProbabilityOfCoefficients( coeffs ); + probabilityFile.close(); + } + } + } + + bool writeShapeModelPrincipalComponents = false; + this->m_Configuration->ReadParameter( writeShapeModelPrincipalComponents, + "WriteShapeModelPrincipalComponentsAfterRegistration", 0, false ); + + if( writeShapeModelPrincipalComponents ) + { + for( unsigned int i = 0; i < this->GetStatisticalModelContainer()->Size(); i++ ) + { + std::string shapeFormat = "vtk"; + this->m_Configuration->ReadParameter( shapeFormat, "ResultShapeFormat", 0, false ); + + MeshFileWriterPointer meshWriter = MeshFileWriterType::New(); + + for( unsigned int j = 0; j < this->GetStatisticalModelContainer()->GetElement( i )->GetNumberOfPrincipalComponents(); j++ ) { + StatisticalModelVectorType plus3std = StatisticalModelVectorType( + this->GetStatisticalModelContainer()->GetElement( i )->GetNumberOfPrincipalComponents(), 0.0 ); + plus3std[ j ] = 3.0; + + std::ostringstream makeFileNamePC(""); + makeFileNamePC + << this->m_Configuration->GetCommandLineArgument("-out") + << "Metric" << this->GetMetricNumber() + << "ShapeModel" << i + << "PC" << j << "." << shapeFormat; + + elxout << " Writing shape model " << i << " principal component " << j + << " for " << this->GetComponentLabel() << " to " << makeFileNamePC.str() << std::endl; + meshWriter->SetInput(this->GetStatisticalModelContainer()->GetElement( i )->DrawPCABasisSample( j )); + meshWriter->SetFileName( makeFileNamePC.str() ); + meshWriter->Update(); + + std::ostringstream makeFileNameP3STD( "" ); + makeFileNameP3STD + << this->m_Configuration->GetCommandLineArgument("-out") + << "Metric" << this->GetMetricNumber() + << "ShapeModel" << i + << "PC" << j << "plus3std." << shapeFormat; + + elxout << " Writing shape model " << i << " principal component " << j << " plus 3 standard deviations" + << " for " << this->GetComponentLabel() << " to " << makeFileNameP3STD.str() << std::endl; + meshWriter->SetInput(this->GetStatisticalModelContainer()->GetElement( i )->DrawSample( plus3std )) ; + meshWriter->SetFileName( makeFileNameP3STD.str() ); + meshWriter->Update(); + + StatisticalModelVectorType minus3std = StatisticalModelVectorType( + this->GetStatisticalModelContainer()->GetElement( i )->GetNumberOfPrincipalComponents(), 0.0 ); + minus3std[ j ] = -3.0; + + std::ostringstream makeFileNamePCM3STD(""); + makeFileNamePCM3STD + << this->m_Configuration->GetCommandLineArgument("-out") + << "Metric" << this->GetMetricNumber() + << "ShapeModel" << i + << "PC" << j << "minus3std." << shapeFormat; + + elxout << " Writing shape model " << i << " principal component " << j << " minus 3 standard deviations" + << " for " << this->GetComponentLabel() << " to " << makeFileNamePCM3STD.str() << std::endl; + meshWriter->SetInput(this->GetStatisticalModelContainer()->GetElement(i)->DrawSample( minus3std )); + meshWriter->SetFileName( makeFileNamePCM3STD.str() ); + meshWriter->Update(); + } + } + } + + bool writeShapeModelEigenValues = false; + this->m_Configuration->ReadParameter( writeShapeModelEigenValues, + "WriteShapeModelEigenValuesAfterRegistration", 0, false ); + if( writeShapeModelEigenValues ) { + for( unsigned int i = 0; i < this->GetStatisticalModelContainer()->Size(); i++ ) { + std::ostringstream makeFileNameEigVal( "" ); + makeFileNameEigVal + << this->m_Configuration->GetCommandLineArgument("-out") + << "Metric" << this->GetMetricNumber() + << "ShapeModel" << i + << "EigenValues.txt"; + + std::ofstream f; + f.open(makeFileNameEigVal.str()); + f << this->GetStatisticalModelContainer()->GetElement(i)->GetPCAVarianceVector().apply(std::sqrt); + f.close(); + } + } + +} // end AfterRegistration() + +} // end namespace elastix + +#endif // end #ifndef __elxActiveRegistrationModelShapeMetric_hxx__ + diff --git a/Components/Metrics/ActiveRegistrationModel/itkActiveRegistrationModelIntensityMetric.h b/Components/Metrics/ActiveRegistrationModel/itkActiveRegistrationModelIntensityMetric.h new file mode 100644 index 000000000..492d37552 --- /dev/null +++ b/Components/Metrics/ActiveRegistrationModel/itkActiveRegistrationModelIntensityMetric.h @@ -0,0 +1,235 @@ +/*========================================================================= + * + * Copyright UMC Utrecht and contributors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0.txt + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + *=========================================================================*/ +#ifndef __itkActiveRegistrationModelImageIntensityMetric_h +#define __itkActiveRegistrationModelImageIntensityMetric_h + +#include "itkAdvancedImageToImageMetric.h" + +// Statismo includes +#include "itkDataManager.h" +#include "itkStatisticalModel.h" +#include "itkPCAModelBuilder.h" +#include "itkReducedVarianceModelBuilder.h" +#include "itkStandardImageRepresenter.h" + +namespace itk +{ + +/** \class AdvancedMeanSquaresImageToImageMetric + * \brief Compute Mean square difference between two images, based on AdvancedImageToImageMetric... + * + * This Class is templated over the type of the fixed and moving + * images to be compared. + * + * This metric computes the sum of squared differenced between pixels in + * the moving image and pixels in the fixed image. The spatial correspondance + * between both images is established through a Transform. Pixel values are + * taken from the Moving image. Their positions are mapped to the Fixed image + * and result in general in non-grid position on it. Values at these non-grid + * position of the Fixed image are interpolated using a user-selected Interpolator. + * + * This implementation of the MeanSquareDifference is based on the + * AdvancedImageToImageMetric, which means that: + * \li It uses the ImageSampler-framework + * \li It makes use of the compact support of B-splines, in case of B-spline transforms. + * \li Image derivatives are computed using either the B-spline interpolator's implementation + * or by nearest neighbor interpolation of a precomputed central difference image. + * \li A minimum number of samples that should map within the moving image (mask) can be specified. + * + * \ingroup RegistrationMetrics + * \ingroup Metrics + */ + +template< class TFixedImage, class TMovingImage > +class ActiveRegistrationModelIntensityMetric : + public AdvancedImageToImageMetric< TFixedImage, TMovingImage > +{ +public: + + /** Standard class typedefs. */ + typedef ActiveRegistrationModelIntensityMetric Self; + typedef AdvancedImageToImageMetric< + TFixedImage, TMovingImage > Superclass; + typedef SmartPointer< Self > Pointer; + typedef SmartPointer< const Self > ConstPointer; + + /** Method for creation through the object factory. */ + itkNewMacro( Self ); + + /** Run-time type information (and related methods). */ + itkTypeMacro( ImageIntensityMetric, AdvancedImageToImageMetric ); + + /** Typedefs from the superclass. */ + typedef typename + Superclass::CoordinateRepresentationType CoordinateRepresentationType; + typedef typename Superclass::MovingImageType MovingImageType; + typedef typename Superclass::MovingImagePixelType MovingImagePixelType; + typedef typename Superclass::MovingImageConstPointer MovingImageConstPointer; + typedef typename Superclass::FixedImageType FixedImageType; + typedef typename Superclass::FixedImageConstPointer FixedImageConstPointer; + typedef typename Superclass::FixedImageRegionType FixedImageRegionType; + typedef typename Superclass::TransformType TransformType; + typedef typename Superclass::TransformPointer TransformPointer; + typedef typename Superclass::InputPointType InputPointType; + typedef typename Superclass::OutputPointType OutputPointType; + typedef typename Superclass::TransformParametersType TransformParametersType; + typedef typename Superclass::TransformJacobianType TransformJacobianType; + typedef typename Superclass::NumberOfParametersType NumberOfParametersType; + typedef typename Superclass::InterpolatorType InterpolatorType; + typedef typename Superclass::InterpolatorPointer InterpolatorPointer; + typedef typename Superclass::RealType RealType; + typedef typename Superclass::GradientPixelType GradientPixelType; + typedef typename Superclass::GradientImageType GradientImageType; + typedef typename Superclass::GradientImagePointer GradientImagePointer; + typedef typename Superclass::GradientImageFilterType GradientImageFilterType; + typedef typename Superclass::GradientImageFilterPointer GradientImageFilterPointer; + typedef typename Superclass::FixedImageMaskType FixedImageMaskType; + typedef typename Superclass::FixedImageMaskPointer FixedImageMaskPointer; + typedef typename Superclass::MovingImageMaskType MovingImageMaskType; + typedef typename Superclass::MovingImageMaskPointer MovingImageMaskPointer; + typedef typename Superclass::MeasureType MeasureType; + typedef typename Superclass::DerivativeType DerivativeType; + typedef typename Superclass::DerivativeValueType DerivativeValueType; + typedef typename Superclass::ParametersType ParametersType; + typedef typename Superclass::FixedImagePixelType FixedImagePixelType; + typedef typename Superclass::MovingImageRegionType MovingImageRegionType; + typedef typename Superclass::ImageSamplerType ImageSamplerType; + typedef typename Superclass::ImageSamplerPointer ImageSamplerPointer; + typedef typename Superclass::ImageSampleContainerType ImageSampleContainerType; + typedef typename + Superclass::ImageSampleContainerPointer ImageSampleContainerPointer; + typedef typename Superclass::FixedImageLimiterType FixedImageLimiterType; + typedef typename Superclass::MovingImageLimiterType MovingImageLimiterType; + typedef typename + Superclass::FixedImageLimiterOutputType FixedImageLimiterOutputType; + typedef typename + Superclass::MovingImageLimiterOutputType MovingImageLimiterOutputType; + typedef typename + Superclass::MovingImageDerivativeScalesType MovingImageDerivativeScalesType; + typedef typename Superclass::HessianValueType HessianValueType; + typedef typename Superclass::HessianType HessianType; + typedef typename Superclass::ThreaderType ThreaderType; + typedef typename Superclass::ThreadInfoType ThreadInfoType; + + /** The fixed image dimension. */ + itkStaticConstMacro( FixedImageDimension, unsigned int, + FixedImageType::ImageDimension ); + + /** The moving image dimension. */ + itkStaticConstMacro( MovingImageDimension, unsigned int, + MovingImageType::ImageDimension ); + + /** Get the value for single valued optimizers. */ + virtual MeasureType GetValue( const TransformParametersType & parameters ) const override; + + // ActiveRegistrationModel typedefs + typedef double StatisticalModelScalarType; + typedef vnl_vector< double > StatisticalModelVectorType; + typedef vnl_matrix< double > StatisticalModelMatrixType; + + typedef FixedImageType StatisticalModelImageType; + typedef typename StatisticalModelImageType::Pointer StatisticalModelImagePointer; + + typedef StatisticalModel< StatisticalModelImageType > StatisticalModelType; + typedef typename StatisticalModelType::Pointer StatisticalModelPointer; + + typedef itk::StandardImageRepresenter< + typename StatisticalModelImageType::PixelType, + StatisticalModelImageType::ImageDimension > StatisticalModelRepresenterType; + typedef typename StatisticalModelRepresenterType::Pointer StatisticalModelRepresenterPointer; + + typedef DataManager< StatisticalModelImageType > StatisticalModelDataManagerType; + typedef typename StatisticalModelDataManagerType::Pointer StatisticalModelDataManagerPointer; + + typedef PCAModelBuilder< StatisticalModelImageType > StatisticalModelModelBuilderType; + typedef typename StatisticalModelModelBuilderType::Pointer StatisticalModelBuilderPointer; + + typedef ReducedVarianceModelBuilder< StatisticalModelImageType > StatisticalModelReducedVarianceBuilderType; + typedef typename StatisticalModelReducedVarianceBuilderType::Pointer StatisticalModelReducedVarianceBuilderPointer; + + typedef unsigned int StatisticalModelIdType; + + typedef VectorContainer< StatisticalModelIdType, StatisticalModelPointer > StatisticalModelContainerType; + typedef typename StatisticalModelContainerType::Pointer StatisticalModelContainerPointer; + typedef typename StatisticalModelContainerType::ConstPointer StatisticalModelContainerConstPointer; + + itkSetConstObjectMacro( StatisticalModelContainer, StatisticalModelContainerType ); + itkGetConstObjectMacro( StatisticalModelContainer, StatisticalModelContainerType ); + + /** Initialize the Metric by making sure that all the components are + * present and plugged together correctly. + */ + virtual void Initialize( void ) override; + + /** Get the derivatives of the match measure. */ + void GetDerivative( const TransformParametersType & parameters, + DerivativeType & Derivative ) const override; + + /** Get value and derivatives for multiple valued optimizers. */ + void GetValueAndDerivative( const TransformParametersType & parameters, + MeasureType & Value, DerivativeType & Derivative ) const override; + + void GetValueAndFiniteDifferenceDerivative( const TransformParametersType & parameters, + MeasureType& value, + DerivativeType& derivative ) const; + + void GetModelValue( const TransformParametersType& parameters, + const StatisticalModelPointer statisticalModel, + MeasureType& modelValue ) const; + + + void GetModelFiniteDifferenceDerivative( const TransformParametersType & parameters, + const StatisticalModelPointer statisticalModel, + DerivativeType& modelDerivative ) const; + +protected: + + ActiveRegistrationModelIntensityMetric(); + virtual ~ActiveRegistrationModelIntensityMetric(){} + + void PrintSelf( std::ostream & os, Indent indent ) const override; + + /** Protected Typedefs ******************/ + + /** Typedefs inherited from superclass */ + typedef typename Superclass::FixedImageIndexType FixedImageIndexType; + typedef typename Superclass::FixedImageIndexValueType FixedImageIndexValueType; + typedef typename Superclass::MovingImageIndexType MovingImageIndexType; + typedef typename Superclass::FixedImagePointType FixedImagePointType; + typedef typename Superclass::MovingImagePointType MovingImagePointType; + typedef typename Superclass::MovingImageContinuousIndexType MovingImageContinuousIndexType; + typedef typename Superclass::BSplineInterpolatorType BSplineInterpolatorType; + typedef typename Superclass::CentralDifferenceGradientFilterType CentralDifferenceGradientFilterType; + typedef typename Superclass::MovingImageDerivativeType MovingImageDerivativeType; + typedef typename Superclass::NonZeroJacobianIndicesType NonZeroJacobianIndicesType; + +private: + + ActiveRegistrationModelIntensityMetric( const Self & ); // purposely not implemented + void operator=( const Self & ); // purposely not implemented + + StatisticalModelContainerConstPointer m_StatisticalModelContainer; +}; + +} // end namespace itk + +#ifndef ITK_MANUAL_INSTANTIATION +#include "itkActiveRegistrationModelIntensityMetric.hxx" +#endif + +#endif // end #ifndef __itkActiveRegistrationModelImageIntensityMetric_h diff --git a/Components/Metrics/ActiveRegistrationModel/itkActiveRegistrationModelIntensityMetric.hxx b/Components/Metrics/ActiveRegistrationModel/itkActiveRegistrationModelIntensityMetric.hxx new file mode 100644 index 000000000..0f7a7756b --- /dev/null +++ b/Components/Metrics/ActiveRegistrationModel/itkActiveRegistrationModelIntensityMetric.hxx @@ -0,0 +1,382 @@ +/*========================================================================= + * + * Copyright UMC Utrecht and contributors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0.txt + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + *=========================================================================*/ +#ifndef _itkActiveRegistrationModelImageIntensityMetric_hxx +#define _itkActiveRegistrationModelImageIntensityMetric_hxx + +#include "itkActiveRegistrationModelIntensityMetric.h" + +namespace itk { + +/** + * ******************* Constructor ******************* + */ + +template +ActiveRegistrationModelIntensityMetric +::ActiveRegistrationModelIntensityMetric() { + this->SetUseImageSampler(true); + this->SetUseFixedImageLimiter(false); + this->SetUseMovingImageLimiter(false); + +} // end Constructor + +/** + * ********************* Initialize **************************** + */ + +template< class TFixedImage, class TMovingImage > +void +ActiveRegistrationModelIntensityMetric< TFixedImage, TMovingImage > +::Initialize() { + /** Initialize transform, interpolator, etc. */ + Superclass::Initialize(); +} // end Initialize() + + +/** + * ******************* PrintSelf ******************* + */ + +template< class TFixedImage, class TMovingImage > +void +ActiveRegistrationModelIntensityMetric< TFixedImage, TMovingImage > +::PrintSelf( std::ostream &os, Indent indent ) const { + Superclass::PrintSelf( os, indent ); +} // end PrintSelf() + + +/** + * ******************* GetValueAndFiniteDifferenceDerivative ******************* + */ + +template< class TFixedImage, class TMovingImage > +void +ActiveRegistrationModelIntensityMetric< TFixedImage, TMovingImage > +::GetValueAndFiniteDifferenceDerivative( const TransformParametersType & parameters, + MeasureType& value, + DerivativeType& derivative ) const +{ + value = NumericTraits< MeasureType >::ZeroValue(); + derivative.Fill( NumericTraits< DerivativeValueType >::ZeroValue() ); + + // Loop over models + for( const auto& statisticalModel : this->GetStatisticalModelContainer()->CastToSTLConstContainer() ) + { + + // Initialize value container + MeasureType modelValue = NumericTraits< MeasureType >::ZeroValue(); + DerivativeType modelDerivative = DerivativeType( this->GetNumberOfParameters() ); + modelDerivative.Fill( NumericTraits< DerivativeValueType >::ZeroValue() ); + + this->GetModelValue( parameters, statisticalModel, modelValue ); + this->GetModelFiniteDifferenceDerivative( parameters, statisticalModel, modelDerivative ); + + value += modelValue; + derivative += modelDerivative; + } + + value /= this->GetStatisticalModelContainer()->Size(); + derivative /= this->GetStatisticalModelContainer()->Size(); + + elxout << "FiniteDiff: " << value << ", " << derivative << std::endl; +} + + + + + + +/** + * ******************* GetValue ******************* + */ + +template< class TFixedPointSet, class TMovingPointSet > +typename ActiveRegistrationModelIntensityMetric< TFixedPointSet, TMovingPointSet >::MeasureType +ActiveRegistrationModelIntensityMetric< TFixedPointSet, TMovingPointSet > +::GetValue( const TransformParametersType& parameters ) const +{ + MeasureType value = NumericTraits< MeasureType >::ZeroValue(); + + for( const auto& statisticalModel : this->GetStatisticalModelContainer()->CastToSTLConstContainer() ) + { + this->GetModelValue( parameters, statisticalModel, value ); + } + + value /= this->GetStatisticalModelContainer()->Size(); + + return value; +} // end GetValue() + + + +/** + * ******************* GetModelValue ******************* + */ + +template +void +ActiveRegistrationModelIntensityMetric +::GetModelValue( const TransformParametersType& parameters, + const StatisticalModelPointer statisticalModel, + MeasureType& modelValue ) const +{ + + // Make sure transform parameters are up-to-date + this->SetTransformParameters( parameters ); + + ImageSampleContainerPointer sampleContainer = this->GetImageSampler()->GetOutput(); + typename StatisticalModelType::PointValueListType fixedPointMovingImageValues; + + FixedImagePointType fixedPoint; + MovingImagePointType movingPoint; + RealType movingImageValue; + + for( const auto& sample : sampleContainer->CastToSTLConstContainer() ) + { + // Transform point + fixedPoint = sample.m_ImageCoordinates; + bool sampleOk = this->TransformPoint( fixedPoint, movingPoint ); + + // Check if movingPoint is inside moving image + if( sampleOk ) { + sampleOk = this->m_Interpolator->IsInsideBuffer( movingPoint ); + } else { + continue; + } + + // Check if movingPoint is inside moving mask if moving mask is used + if( sampleOk ) { + sampleOk = this->IsInsideMovingMask(movingPoint); + } else { + continue; + } + + // Sample moving image + if( sampleOk ) { + sampleOk = this->EvaluateMovingImageValueAndDerivative( movingPoint, movingImageValue, nullptr ); + } else { + continue; + } + + if( sampleOk ) + { + fixedPointMovingImageValues.emplace_back( fixedPoint, movingImageValue ); + } + } + + this->CheckNumberOfSamples( sampleContainer->Size(), fixedPointMovingImageValues.size() ); + + const auto coeffs = statisticalModel->ComputeCoefficientsForPointValues( fixedPointMovingImageValues, statisticalModel->GetNoiseVariance() ); + + // tmp = sum_J (M_j - mu_j) * (I - V_j V_j^T) * (M_j - mu_j) + RealType tmp = 0; + for( const auto& fixedPointMovingImageValue : fixedPointMovingImageValues ) { + const auto& fixedPoint = fixedPointMovingImageValue.first; + const auto& movingImageValue = fixedPointMovingImageValue.second; + + tmp += ( movingImageValue - statisticalModel->DrawMeanAtPoint( fixedPoint ) ) * + ( movingImageValue - statisticalModel->DrawSampleAtPoint( coeffs, fixedPoint, true ) ); + } + + if( fixedPointMovingImageValues.size() > 0 ) + { + modelValue += tmp / fixedPointMovingImageValues.size(); + } + +} // end GetModelValue() + + +/** + * ******************* GetModelFiniteDifferenceDerivative ******************* + */ + +template< class TFixedImage, class TMovingImage > +void +ActiveRegistrationModelIntensityMetric< TFixedImage, TMovingImage > +::GetModelFiniteDifferenceDerivative( const TransformParametersType & parameters, + const StatisticalModelPointer statisticalModel, + DerivativeType& modelDerivative ) const +{ + const double h = 0.01; + + // Get derivative (J(X)-W*(inv(C)*(W^T*J(X))))^T*f(X) + unsigned int siz = parameters.size(); + for( unsigned int i = 0; i < parameters.size(); ++i ) + { + MeasureType plusModelValue = NumericTraits< MeasureType >::ZeroValue(); + MeasureType minusModelValue = NumericTraits< MeasureType >::ZeroValue(); + + TransformParametersType plusParameters = parameters; + TransformParametersType minusParameters = parameters; + + plusParameters[ i ] += h; + minusParameters[ i ] -= h; + + this->GetModelValue( plusParameters, statisticalModel, plusModelValue ); + this->GetModelValue( minusParameters, statisticalModel, minusModelValue ); + + modelDerivative[ i ] += ( plusModelValue - minusModelValue ) / ( 2 * h ); + } + + this->SetTransformParameters( parameters ); +} + + + + + + + +/** + * ******************* GetDerivative ******************* + */ + +template< class TFixedImage, class TMovingImage > +void +ActiveRegistrationModelIntensityMetric< TFixedImage, TMovingImage > +::GetDerivative( + const TransformParametersType & parameters, + DerivativeType & derivative ) const +{ + /** When the derivative is calculated, all information for calculating + * the metric value is available. It does not cost anything to calculate + * the metric value now. Therefore, we have chosen to only implement the + * GetValueAndDerivative(), supplying it with a dummy value variable. + */ + MeasureType dummyvalue = NumericTraits< MeasureType >::Zero; + this->GetValueAndDerivative( parameters, dummyvalue, derivative ); + +} // end GetDerivative() + + +/** + * ******************* GetValueAndDerivative ******************* + */ + +template< class TFixedImage, class TMovingImage > +void +ActiveRegistrationModelIntensityMetric< TFixedImage, TMovingImage > +::GetValueAndDerivative( + const TransformParametersType & parameters, + MeasureType & value, DerivativeType & derivative ) const { + + this->SetTransformParameters( parameters ); + + value = NumericTraits< MeasureType >::ZeroValue(); + derivative.Fill( NumericTraits< DerivativeValueType >::ZeroValue() ); + + DerivativeType Jacobian( this->GetTransform()->GetNumberOfNonZeroJacobianIndices() ); + NonZeroJacobianIndicesType nzji( this->GetTransform()->GetNumberOfNonZeroJacobianIndices() ); + + ImageSampleContainerPointer sampleContainer = this->GetImageSampler()->GetOutput(); + + // Loop over models + for( const auto& statisticalModel : this->GetStatisticalModelContainer()->CastToSTLConstContainer() ) + { + MeasureType modelValue = NumericTraits< MeasureType >::ZeroValue(); + DerivativeType modelDerivative = DerivativeType( this->GetNumberOfParameters() ); + modelDerivative.Fill( NumericTraits< DerivativeValueType >::ZeroValue() ); + + typename StatisticalModelType::PointValueListType fixedPointMovingImageValues; + typename std::vector< MovingImageDerivativeType > movingImageDerivatives; + + for( const auto& sample : sampleContainer->CastToSTLConstContainer() ) + { + MovingImagePointType movingPoint; + RealType movingImageValue; + MovingImageDerivativeType movingImageDerivative; + + // Transform point + const FixedImagePointType& fixedPoint = sample.m_ImageCoordinates; + bool sampleOk = this->TransformPoint( fixedPoint, movingPoint ); + + // Check if movingPoint is inside moving image + if( sampleOk ) { + sampleOk = this->m_Interpolator->IsInsideBuffer( movingPoint ); + } else { + continue; + } + + // Check if movingPoint is inside moving mask if moving mask is used + if( sampleOk ) { + sampleOk = this->IsInsideMovingMask( movingPoint ); + } else { + continue; + } + + // Sample moving image + if( sampleOk ) { + sampleOk = this->EvaluateMovingImageValueAndDerivative( movingPoint, movingImageValue, &movingImageDerivative ); + } else { + continue; + } + + if( sampleOk ) + { + fixedPointMovingImageValues.emplace_back( fixedPoint, movingImageValue ); + movingImageDerivatives.emplace_back( movingImageDerivative ); + } + } + + this->CheckNumberOfSamples( sampleContainer->Size(), fixedPointMovingImageValues.size() ); + + const auto coeffs = statisticalModel->ComputeCoefficientsForPointValues( fixedPointMovingImageValues, statisticalModel->GetNoiseVariance() ); + + for( auto it = std::make_pair( fixedPointMovingImageValues.begin(), movingImageDerivatives.begin() ); + it.first != fixedPointMovingImageValues.end(); + it.first++, it.second++) { + + const FixedImagePointType& fixedPoint = it.first->first; + const RealType& movingImageValue = it.first->second; + const MovingImageDerivativeType& movingImageDerivative = *it.second; + + // tmp = (M_j - mu_j) * (I - V_j V_j^T) + RealType tmp = movingImageValue - statisticalModel->DrawSampleAtPoint( coeffs, fixedPoint, true ); + modelValue += ( movingImageValue - statisticalModel->DrawMeanAtPoint( fixedPoint ) ) * tmp; + + // (dM/d{x,y,z})(dT/du) + this->m_AdvancedTransform->EvaluateJacobianWithImageGradientProduct( fixedPoint, movingImageDerivative, Jacobian, nzji ); + + // Loop over Jacobian + for( unsigned int i = 0; i < nzji.size(); ++i ) + { + const unsigned int& mu = nzji[ i ]; + modelDerivative[ mu ] += tmp * Jacobian[ i ]; + } + } + + value += modelValue / fixedPointMovingImageValues.size(); + derivative += 2.0 * modelDerivative / fixedPointMovingImageValues.size(); + } + + value /= this->GetStatisticalModelContainer()->Size(); + derivative /= this->GetStatisticalModelContainer()->Size(); + + const bool useFiniteDifferenceDerivative = false; + if (useFiniteDifferenceDerivative) + { + elxout << "Analytical: " << value << ", " << derivative << std::endl; + this->GetValueAndFiniteDifferenceDerivative( parameters, value, derivative ); + } + + return; +} // end GetValueAndDerivative() + +} // end namespace itk + +#endif // end #ifndef _itkAdvancedMeanSquaresImageToImageMetric_hxx diff --git a/Components/Metrics/ActiveRegistrationModel/itkActiveRegistrationModelShapeMetric.h b/Components/Metrics/ActiveRegistrationModel/itkActiveRegistrationModelShapeMetric.h new file mode 100644 index 000000000..ab86392b4 --- /dev/null +++ b/Components/Metrics/ActiveRegistrationModel/itkActiveRegistrationModelShapeMetric.h @@ -0,0 +1,207 @@ +/*====================================================================== + +This file is part of the elastix software. + +Copyright (c) University Medical Center Utrecht. All rights reserved. +See src/CopyrightElastix.txt or http://elastix.isi.uu.nl/legal.php for +details. + +This software is distributed WITHOUT ANY WARRANTY; without even +the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR +PURPOSE. See the above copyright notices for more information. + +======================================================================*/ +#ifndef __itkActiveRegistrationModelShapeMetric_h__ +#define __itkActiveRegistrationModelShapeMetric_h__ + +#include "itkSingleValuedPointSetToPointSetMetric.h" +#include "itkPoint.h" +#include "itkPointSet.h" +#include "itkImage.h" +#include "itkMesh.h" +#include +#include + +#include "itkDataManager.h" +#include "itkStatisticalModel.h" +#include "itkPCAModelBuilder.h" +#include "itkReducedVarianceModelBuilder.h" +#include "itkStandardMeshRepresenter.h" + +namespace itk +{ + +/** \class PointSetPenalty + * \brief A dummy metric to generate transformed meshes each iteration. + * + * + * + * \ingroup RegistrationMetrics + */ + +template< class TFixedPointSet, class TMovingPointSet > +class ITK_EXPORT ActiveRegistrationModelShapeMetric : + public SingleValuedPointSetToPointSetMetric< TFixedPointSet, TMovingPointSet > +{ +public: + + /** Standard class typedefs. */ + typedef ActiveRegistrationModelShapeMetric Self; + typedef SingleValuedPointSetToPointSetMetric< + TFixedPointSet, TMovingPointSet > Superclass; + typedef SmartPointer< Self > Pointer; + typedef SmartPointer< const Self > ConstPointer; + + /** Type used for representing point components */ + + /** Method for creation through the object factory. */ + itkNewMacro( Self ); + + /** Run-time type information (and related methods). */ + itkTypeMacro( PointDistributionShapeMetric, SingleValuedPointSetToPointSetMetric ); + + /** Types transferred from the base class */ + typedef typename Superclass::TransformType TransformType; + typedef typename Superclass::TransformPointer TransformPointer; + typedef typename Superclass::TransformParametersType TransformParametersType; + typedef typename Superclass::TransformJacobianType TransformJacobianType; + + typedef typename Superclass::MeasureType MeasureType; + typedef typename Superclass::DerivativeType DerivativeType; + typedef typename Superclass::DerivativeValueType DerivativeValueType; + + /** Typedefs. */ + typedef typename Superclass::InputPointType InputPointType; + typedef typename Superclass::OutputPointType OutputPointType; + typedef typename InputPointType::CoordRepType CoordRepType; + typedef vnl_vector VnlVectorType; + typedef typename TransformType::InputPointType FixedImagePointType; + typedef typename TransformType::OutputPointType MovingImagePointType; + typedef typename TransformType::SpatialJacobianType SpatialJacobianType; + + typedef typename Superclass::NonZeroJacobianIndicesType NonZeroJacobianIndicesType; + + /** Constants for the pointset dimensions. */ + itkStaticConstMacro( FixedPointSetDimension, unsigned int, + Superclass::FixedPointSetDimension ); + + typedef Vector< typename TransformType::ScalarType, + FixedPointSetDimension > PointNormalType; + typedef unsigned char DummyMeshPixelType; + typedef DefaultStaticMeshTraits< PointNormalType, + FixedPointSetDimension, FixedPointSetDimension, CoordRepType > MeshTraitsType; + typedef Mesh< PointNormalType, FixedPointSetDimension, + MeshTraitsType > FixedMeshType; + + typedef typename FixedMeshType::ConstPointer FixedMeshConstPointer; + typedef typename FixedMeshType::Pointer FixedMeshPointer; + typedef typename MeshTraitsType::CellType::CellInterface CellInterfaceType; + + // ActiveRegistrationModel typedefs + typedef double StatisticalModelScalarType; + typedef vnl_vector< double > StatisticalModelVectorType; + typedef vnl_matrix< double > StatisticalModelMatrixType; + + itkStaticConstMacro( StatisticalModelMeshDimension, unsigned int, Superclass::FixedPointSetDimension ); + + typedef DefaultStaticMeshTraits< + StatisticalModelScalarType, + FixedPointSetDimension, + FixedPointSetDimension, + StatisticalModelScalarType, + StatisticalModelScalarType > StatisticalModelMeshTraitsType; + + typedef Mesh< + StatisticalModelScalarType, + StatisticalModelMeshDimension, + StatisticalModelMeshTraitsType > StatisticalModelMeshType; + typedef typename StatisticalModelMeshType::Pointer StatisticalModelMeshPointer; + typedef typename StatisticalModelMeshType::ConstPointer StatisticalModelMeshConstPointer; + typedef typename StatisticalModelMeshType::PointsContainerIterator StatisticalModelMeshIteratorType; + typedef typename StatisticalModelMeshType::PointsContainerConstIterator StatisticalModelMeshConstIteratorType; + + typedef MeshFileReader< StatisticalModelMeshType > MeshReaderType; + typedef typename MeshReaderType::Pointer MeshReaderPointer; + + typedef DataManager< StatisticalModelMeshType > StatisticalModelDataManagerType; + typedef typename StatisticalModelDataManagerType::Pointer StatisticalModelDataManagerPointer; + + typedef StatisticalModel< StatisticalModelMeshType > StatisticalModelType; + typedef typename StatisticalModelType::Pointer StatisticalModelPointer; + + typedef StandardMeshRepresenter< + StatisticalModelScalarType, + StatisticalModelMeshDimension > StatisticalModelRepresenterType; + typedef typename StatisticalModelRepresenterType::Pointer StatisticalModelRepresenterPointer; + + typedef PCAModelBuilder< StatisticalModelMeshType > ModelBuilderType; + typedef typename ModelBuilderType::Pointer ModelBuilderPointer; + + typedef ReducedVarianceModelBuilder< StatisticalModelMeshType > StatisticalModelReducedVarianceBuilderType; + typedef typename StatisticalModelReducedVarianceBuilderType::Pointer StatisticalModelReducedVarianceBuilderPointer; + + typedef unsigned int StatisticalModelIdType; + + typedef VectorContainer< StatisticalModelIdType, StatisticalModelPointer > StatisticalModelContainerType; + typedef typename StatisticalModelContainerType::Pointer StatisticalModelContainerPointer; + typedef typename StatisticalModelContainerType::ConstPointer StatisticalModelContainerConstPointer; + + itkSetConstObjectMacro( StatisticalModelContainer, StatisticalModelContainerType ); + itkGetConstObjectMacro( StatisticalModelContainer, StatisticalModelContainerType ); + + /** Initialize the Metric by making sure that all the components are + * present and plugged together correctly. + */ + virtual void Initialize( void ) override; + + /** Get the value for single valued optimizers. */ + MeasureType GetValue( const TransformParametersType & parameters ) const override; + + /** Get the derivatives of the match measure. */ + void GetDerivative( const TransformParametersType & parameters, + DerivativeType& Derivative ) const override; + + /** Get value and derivatives for multiple valued optimizers. */ + void GetValueAndDerivative( const TransformParametersType& parameters, + MeasureType& Value, DerivativeType& Derivative ) const override; + + void GetValueAndFiniteDifferenceDerivative( const TransformParametersType& parameters, + MeasureType& value, + DerivativeType& derivative ) const; + + void GetModelValue( const TransformParametersType& parameters, + const StatisticalModelPointer statisticalModel, + MeasureType& modelValue ) const; + + + void GetModelFiniteDifferenceDerivative( const TransformParametersType & parameters, + const StatisticalModelPointer statisticalModel, + DerivativeType& modelDerivative ) const; + + StatisticalModelMeshPointer TransformMesh( StatisticalModelMeshPointer fixedMesh ) const; + +protected: + + ActiveRegistrationModelShapeMetric(); + virtual ~ActiveRegistrationModelShapeMetric(); + + /** PrintSelf. */ + void PrintSelf( std::ostream & os, Indent indent ) const override; + +private: + + ActiveRegistrationModelShapeMetric( const Self & ); // purposely not implemented + void operator=( const Self & ); // purposely not implemented + + StatisticalModelContainerConstPointer m_StatisticalModelContainer; + +}; // end class PointSetPenalty + +} // end namespace itk + +#ifndef ITK_MANUAL_INSTANTIATION +#include "itkActiveRegistrationModelShapeMetric.hxx" +#endif + +#endif // end #ifndef __itkActiveRegistrationModelPointDistributionShapeMetric_h__ + diff --git a/Components/Metrics/ActiveRegistrationModel/itkActiveRegistrationModelShapeMetric.hxx b/Components/Metrics/ActiveRegistrationModel/itkActiveRegistrationModelShapeMetric.hxx new file mode 100644 index 000000000..ff0491125 --- /dev/null +++ b/Components/Metrics/ActiveRegistrationModel/itkActiveRegistrationModelShapeMetric.hxx @@ -0,0 +1,329 @@ +/*====================================================================== + +This file is part of the elastix software. + +Copyright (c) University Medical Center Utrecht. All rights reserved. +See src/CopyrightElastix.txt or http://elastix.isi.uu.nl/legal.php for +details. + +This software is distributed WITHOUT ANY WARRANTY; without even +the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR +PURPOSE. See the above copyright notices for more information. + +======================================================================*/ +#ifndef __itkActiveRegistrationModelShapeMetric_hxx__ +#define __itkActiveRegistrationModelShapeMetric_hxx__ + +#include "itkActiveRegistrationModelShapeMetric.h" + +namespace itk +{ + +/** + * ******************* Constructor ******************* + */ + +template< class TFixedPointSet, class TMovingPointSet > +ActiveRegistrationModelShapeMetric< TFixedPointSet, TMovingPointSet > +::ActiveRegistrationModelShapeMetric() +{ +} // end Constructor + + +/** + * ******************* Destructor ******************* + */ + +template< class TFixedPointSet, class TMovingPointSet > +ActiveRegistrationModelShapeMetric< TFixedPointSet, TMovingPointSet > +::~ActiveRegistrationModelShapeMetric() +{} // end Destructor + +/** + * *********************** Initialize ***************************** + */ + +template< class TFixedPointSet, class TMovingPointSet > +void +ActiveRegistrationModelShapeMetric< TFixedPointSet, TMovingPointSet > +::Initialize( void ) +{ + if( !this->GetTransform() ) + { + itkExceptionMacro( << "Transform is not present" ); + } + + if( this->GetStatisticalModelContainer()->Size() == 0 ) + { + itkExceptionMacro( << "StatisticalModelContainer is empty." ); + } +} // end Initialize() + + + +/** + * ******************* GetValue ******************* + */ + +template< class TFixedPointSet, class TMovingPointSet > +typename ActiveRegistrationModelShapeMetric< TFixedPointSet, TMovingPointSet >::MeasureType +ActiveRegistrationModelShapeMetric< TFixedPointSet, TMovingPointSet > +::GetValue( const TransformParametersType& parameters ) const +{ + MeasureType value = NumericTraits< MeasureType >::Zero; + + // Loop over models + for( const auto& statisticalModel : this->GetStatisticalModelContainer()->CastToSTLConstContainer() ) + { + this->GetModelValue( parameters, statisticalModel, value ); + } + + value /= this->GetStatisticalModelContainer()->Size(); + + return value; +} // end GetValue() + + + +/** + * ******************* GetModelValue ******************* + */ + +template< class TFixedPointSet, class TMovingPointSet > +void +ActiveRegistrationModelShapeMetric< TFixedPointSet, TMovingPointSet > +::GetModelValue( const TransformParametersType& parameters, + const StatisticalModelPointer statisticalModel, + MeasureType& modelValue ) const +{ + // Make sure transform parameters are up-to-date + this->SetTransformParameters( parameters ); + + const auto& fixedMesh = statisticalModel->DrawMean(); + const auto& fixedVector = statisticalModel->GetRepresenter()->SampleToSampleVector( fixedMesh ); + + const auto& movingMesh = this->TransformMesh( fixedMesh ); + const auto& movingVector = statisticalModel->GetRepresenter()->SampleToSampleVector( movingMesh ); + + const auto& coeffs = statisticalModel->ComputeCoefficients( movingMesh ); + const auto& reconstructedMovingVector = statisticalModel->GetRepresenter()->SampleToSampleVector( statisticalModel->DrawSample( coeffs, false ) ); + + modelValue += ( movingVector - fixedVector ).dot( movingVector - reconstructedMovingVector ) / fixedMesh->GetNumberOfPoints(); +} + + +/** + * ******************* GetValueAndFiniteDifferenceDerivative ******************* + */ + +template< class TFixedPointSet, class TMovingPointSet > +void +ActiveRegistrationModelShapeMetric< TFixedPointSet, TMovingPointSet > +::GetValueAndFiniteDifferenceDerivative( const TransformParametersType& parameters, + MeasureType& value, + DerivativeType& derivative ) const +{ + value = NumericTraits< MeasureType >::ZeroValue(); + derivative.Fill( NumericTraits< DerivativeValueType >::ZeroValue() ); + + for( const auto& statisticalModel : this->GetStatisticalModelContainer()->CastToSTLConstContainer() ) + { + MeasureType modelValue = NumericTraits< MeasureType >::ZeroValue(); + DerivativeType modelDerivative = DerivativeType( this->GetNumberOfParameters() ); + modelDerivative.Fill( NumericTraits< DerivativeValueType >::ZeroValue() ); + + this->GetModelValue( parameters, statisticalModel, value ); + this->GetModelFiniteDifferenceDerivative( parameters, statisticalModel, modelDerivative ); + + value += modelValue; + derivative += modelDerivative; + } + + value /= this->GetStatisticalModelContainer()->Size(); + derivative /= this->GetStatisticalModelContainer()->Size(); + + elxout << "FiniteDiff: " << value << ", " << derivative << std::endl; +} + + + +/** + * ******************* GetModelFiniteDifferenceDerivative ******************* + */ + +template< class TFixedPointSet, class TMovingPointSet > +void +ActiveRegistrationModelShapeMetric< TFixedPointSet, TMovingPointSet > +::GetModelFiniteDifferenceDerivative( const TransformParametersType & parameters, + const StatisticalModelPointer statisticalModel, + DerivativeType& modelDerivative ) const +{ + const DerivativeValueType h = 0.01; + + for( unsigned int i = 0; i < parameters.size(); ++i )\ + { + MeasureType plusModelValue = NumericTraits< DerivativeValueType >::ZeroValue(); + MeasureType minusModelValue = NumericTraits< DerivativeValueType >::ZeroValue(); + + TransformParametersType plusParameters = parameters; + TransformParametersType minusParameters = parameters; + + plusParameters[ i ] += h; + minusParameters[ i ] -= h; + + this->GetModelValue( plusParameters, statisticalModel, plusModelValue ); + this->GetModelValue( minusParameters, statisticalModel, minusModelValue ); + + modelDerivative[ i ] += ( plusModelValue - minusModelValue ) / ( 2 * h ); + } + + // Reset transform parameters + this->SetTransformParameters( parameters ); +} + + + +/** + * ******************* GetDerivative ******************* + */ + +template< class TFixedPointSet, class TMovingPointSet > +void +ActiveRegistrationModelShapeMetric< TFixedPointSet, TMovingPointSet > +::GetDerivative( const TransformParametersType & parameters, + DerivativeType & derivative ) const +{ + /** When the derivative is calculated, all information for calculating + * the metric value is available. It does not cost anything to calculate + * the metric value now. Therefore, we have chosen to only implement the + * GetValueAndDerivative(), supplying it with a dummy value variable. + */ + MeasureType dummyvalue = NumericTraits< MeasureType >::Zero; + this->GetValueAndDerivative( parameters, dummyvalue, derivative ); + +} // end GetDerivative() + + + +/** + * ******************* GetValueAndDerivative ******************* + */ + +template< class TFixedPointSet, class TMovingPointSet > +void +ActiveRegistrationModelShapeMetric< TFixedPointSet, TMovingPointSet > +::GetValueAndDerivative( const TransformParametersType& parameters, + MeasureType& value, + DerivativeType& derivative ) const +{ + this->SetTransformParameters( parameters ); + + value = NumericTraits< MeasureType >::ZeroValue(); + derivative.Fill( NumericTraits< DerivativeValueType >::ZeroValue() ); + + TransformJacobianType Jacobian; + NonZeroJacobianIndicesType nzji( this->GetTransform()->GetNumberOfNonZeroJacobianIndices() ); + + // Loop over models + for( const auto& statisticalModel : this->GetStatisticalModelContainer()->CastToSTLConstContainer() ) + { + DerivativeType modelDerivative = DerivativeType( this->GetNumberOfParameters() ); + modelDerivative.Fill( NumericTraits< DerivativeValueType >::ZeroValue() ); + + const auto& fixedMesh = statisticalModel->DrawMean(); + const auto& fixedVector = statisticalModel->GetRepresenter()->SampleToSampleVector( fixedMesh ); + + const auto& movingMesh = this->TransformMesh( fixedMesh ); + const auto& movingVector = statisticalModel->GetRepresenter()->SampleToSampleVector( movingMesh ); + + const auto& coeffs = statisticalModel->ComputeCoefficients( movingMesh ); + const auto& reconstructedMovingVector = statisticalModel->GetRepresenter()->SampleToSampleVector( statisticalModel->DrawSample( coeffs, false ) ); + + const statismo::VectorType tmp = movingVector - reconstructedMovingVector; + MeasureType modelValue = ( movingVector - fixedVector ).dot( tmp ); + + for( unsigned int i = 0; i < fixedMesh->GetNumberOfPoints(); i++ ) + { + const auto& tmp_i = StatisticalModelVectorType( tmp.data() + i * StatisticalModelMeshDimension, + StatisticalModelMeshDimension ); + + this->GetTransform()->GetJacobian( fixedMesh->GetPoints()->ElementAt( i ), Jacobian, nzji ); + + for( unsigned int j = 0; j < nzji.size(); j++ ) { + const auto& mu = nzji[ j ]; + modelDerivative[ mu ] += dot_product( tmp_i, Jacobian.get_column( j ) ); + } + } + + if( std::isnan( modelValue ) ) + { + itkExceptionMacro( "Model value is NaN.") + } + + if( fixedMesh->GetNumberOfPoints() > 0 ) + { + value += modelValue / fixedMesh->GetNumberOfPoints(); + derivative += 2.0 * modelDerivative / fixedMesh->GetNumberOfPoints(); + } + } + + value /= this->GetStatisticalModelContainer()->Size(); + derivative /= this->GetStatisticalModelContainer()->Size(); + + const bool useFiniteDifferenceDerivative = false; + if( useFiniteDifferenceDerivative ) + { + elxout << "Analytical: " << value << ", " << derivative << std::endl; + this->GetValueAndFiniteDifferenceDerivative( parameters, value, derivative ); + } +} + + + +/** + * ******************* TransformMesh ******************* + */ + +template< class TFixedPointSet, class TMovingPointSet > +typename ActiveRegistrationModelShapeMetric< TFixedPointSet, TMovingPointSet >::StatisticalModelMeshPointer +ActiveRegistrationModelShapeMetric< TFixedPointSet, TMovingPointSet > +::TransformMesh( StatisticalModelMeshPointer fixedMesh ) const +{ + StatisticalModelMeshPointer movingMesh = StatisticalModelMeshType::New(); + movingMesh->GetPoints()->Reserve( fixedMesh->GetNumberOfPoints() ); + + // Transform mesh + StatisticalModelMeshConstIteratorType fixedMeshIterator = fixedMesh->GetPoints()->Begin(); + StatisticalModelMeshConstIteratorType fixedMeshIteratorEnd = fixedMesh->GetPoints()->End(); + StatisticalModelMeshIteratorType movingMeshIterator = movingMesh->GetPoints()->Begin(); + while( fixedMeshIterator != fixedMeshIteratorEnd ) + { + movingMeshIterator->Value() = this->GetTransform()->TransformPoint( fixedMeshIterator->Value() ); + ++fixedMeshIterator; + ++movingMeshIterator; + } + + return movingMesh; +} + + + +/** + * ******************* PrintSelf ******************* + */ + +template< class TFixedPointSet, class TMovingPointSet > +void +ActiveRegistrationModelShapeMetric< TFixedPointSet, TMovingPointSet > +::PrintSelf( std::ostream & os, Indent indent ) const +{ + Superclass::PrintSelf(os, indent); + // TODO +} + + + +} // end namespace itk + +#endif // end #ifndef __itkActiveRegistrationModelShapeMetric_hxx__ + From 85d2f40d33d9b6bc383804fb15ccfb963e781337 Mon Sep 17 00:00:00 2001 From: Kasper Marstal Date: Mon, 29 Mar 2021 13:02:36 +0200 Subject: [PATCH 2/2] BUG: Fix CMake variable expansion --- Components/Metrics/ActiveRegistrationModel/CMakeLists.txt | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/Components/Metrics/ActiveRegistrationModel/CMakeLists.txt b/Components/Metrics/ActiveRegistrationModel/CMakeLists.txt index 3933593fd..83ce24fb3 100644 --- a/Components/Metrics/ActiveRegistrationModel/CMakeLists.txt +++ b/Components/Metrics/ActiveRegistrationModel/CMakeLists.txt @@ -20,10 +20,10 @@ if(${USE_ActiveRegistrationModelShapeMetric} OR ${USE_ActiveRegistrationModelInt add_subdirectory(Statismo) endif() -if(USE_ActiveRegistrationModelShapeMetric) +if(${USE_ActiveRegistrationModelShapeMetric}) target_link_libraries(ActiveRegistrationModelShapeMetric statismo_core ${Boost_LIBRARIES} ITKInternalEigen3::Eigen) endif() -if(USE_ActiveRegistrationModelIntensityMetric) +if(${USE_ActiveRegistrationModelIntensityMetric}) target_link_libraries(ActiveRegistrationModelIntensityMetric statismo_core ${Boost_LIBRARIES} ITKInternalEigen3::Eigen) endif()