Skip to content

Commit

Permalink
Visual-Text Co-Embedding image retrieval (#329)
Browse files Browse the repository at this point in the history
  • Loading branch information
Spiess authored Jul 26, 2022
1 parent 644ea14 commit 1e7aea5
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 7 deletions.
2 changes: 1 addition & 1 deletion build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ allprojects {
group = 'org.vitrivr'

/* Our current version; on the dev branch this should always be release+1-SNAPSHOT. */
version = '3.12.1'
version = '3.12.2'

apply plugin: 'java-library'
apply plugin: 'maven-publish'
Expand Down
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
package org.vitrivr.cineast.core.features;

import java.awt.image.BufferedImage;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.tensorflow.SavedModelBundle;
import org.tensorflow.Tensor;
import org.tensorflow.ndarray.NdArrays;
Expand Down Expand Up @@ -32,6 +35,8 @@ public class VisualTextCoEmbedding extends AbstractFeatureModule {
private static final String TABLE_NAME = "features_visualtextcoembedding";
private static final Distance DISTANCE = ReadableQueryConfig.Distance.euclidean;

private static final Logger LOGGER = LogManager.getLogger();

/**
* Resource paths.
*/
Expand All @@ -53,7 +58,7 @@ public class VisualTextCoEmbedding extends AbstractFeatureModule {
/**
* Embedding network from text to intermediary embedding.
* <p>
* Currently using <a href="https://tfhub.dev/google/universal-sentence-encoder/4">UniversalSentenceEncoderV4</a>.
*/
private static SavedModelBundle textEmbedding;
/**
Expand All @@ -64,7 +69,7 @@ public class VisualTextCoEmbedding extends AbstractFeatureModule {
/**
* Embedding network from image to intermediary embedding.
* <p>
* Currently using <a href="https://storage.googleapis.com/tensorflow/keras-applications/inception_resnet_v2/inception_resnet_v2_weights_tf_dim_ordering_tf_kernels_notop.h5">InceptionResNetV2 pretrained on ImageNet</a>.
*/
private static SavedModelBundle visualEmbedding;
/**
Expand Down Expand Up @@ -110,15 +115,34 @@ public void processSegment(SegmentContainer sc) {

/**
 * Retrieves elements similar to the given query segment using the co-embedding.
 * <p>
 * The query modality is chosen in this order:
 * <ol>
 *   <li>If the segment carries non-empty text, the text is embedded via {@code embedText}.</li>
 *   <li>Otherwise, if the segment's most representative frame is not
 *   {@code VideoFrame.EMPTY_VIDEO_FRAME} and its image can be obtained as a
 *   {@link BufferedImage}, the image is embedded via {@code embedImage}.</li>
 * </ol>
 * If neither modality is usable, an error is logged and an empty list is returned.
 *
 * @param sc the query segment providing either text or a representative frame
 * @param qc the query configuration; it is cloned and the clone's distance is set to
 *           {@code DISTANCE} before the lookup, so the passed instance is not modified here
 * @return the result of the embedding-based lookup, or an empty mutable list if no acceptable
 *         modality was provided
 */
@Override
public List<ScoreElement> getSimilar(SegmentContainer sc, ReadableQueryConfig qc) {
  // Ensure the correct distance function is used
  QueryConfig queryConfig = QueryConfig.clone(qc);
  queryConfig.setDistance(DISTANCE);

  // Case: segment contains text
  // Read the text once and guard against null so an image-only query cannot fail here.
  String text = sc.getText();
  if (text != null && !text.isEmpty()) {
    LOGGER.debug("Retrieving with TEXT: {}", text);
    return getSimilar(embedText(text), queryConfig);
  }

  // Case: segment contains image
  // Fetch the frame once; the null guard is defensive in case an implementation returns null
  // instead of the EMPTY_VIDEO_FRAME sentinel.
  VideoFrame frame = sc.getMostRepresentativeFrame();
  if (frame != null && frame != VideoFrame.EMPTY_VIDEO_FRAME) {
    LOGGER.debug("Retrieving with IMAGE.");
    BufferedImage image = frame.getImage().getBufferedImage();

    if (image != null) {
      return getSimilar(embedImage(image), queryConfig);
    }

    LOGGER.error("Image was provided, but could not be decoded!");
  }

  LOGGER.error("Could not get similar because no acceptable modality was provided.");
  return new ArrayList<>();
}

@Override
Expand Down

0 comments on commit 1e7aea5

Please sign in to comment.