diff --git a/src/test/java/org/mastodon/mamut/feature/dimensionalityreduction/pca/PCATest.java b/src/test/java/org/mastodon/mamut/feature/dimensionalityreduction/pca/PCATest.java new file mode 100644 index 000000000..3c8d15b5a --- /dev/null +++ b/src/test/java/org/mastodon/mamut/feature/dimensionalityreduction/pca/PCATest.java @@ -0,0 +1,43 @@ +package org.mastodon.mamut.feature.dimensionalityreduction.pca; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import java.lang.invoke.MethodHandles; + +import org.junit.jupiter.api.Test; +import org.mastodon.mamut.feature.dimensionalityreduction.RandomDataTools; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import smile.data.DataFrame; +import smile.feature.extraction.PCA; + +class PCATest +{ + private static final Logger logger = LoggerFactory.getLogger( MethodHandles.lookup().lookupClass() ); + + @Test + void test() + { + int numCluster1 = 50; + int numCluster2 = 100; + double[][] inputData = RandomDataTools.generateSampleData( numCluster1, numCluster2 ); + logger.debug( "dimensions rows: {}, columns:{}", inputData.length, inputData[ 0 ].length ); + + int targetDimensions = 2; + + DataFrame dataFrame = DataFrame.of( inputData ); + PCA pca = PCA.fit( dataFrame ).getProjection( targetDimensions ); + double[][] pcaResult = pca.apply( inputData ); + + assertEquals( pcaResult.length, inputData.length ); + assertEquals( targetDimensions, pcaResult[ 0 ].length ); + + for ( int i = 0; i < numCluster1; i++ ) + assertTrue( pcaResult[ i ][ 0 ] < 0 ); + for ( int i = numCluster1; i < numCluster1 + numCluster2; i++ ) + assertTrue( pcaResult[ i ][ 0 ] > 0 ); + + } +}