diff --git a/README.md b/README.md index 3d49b3c..16fd725 100644 --- a/README.md +++ b/README.md @@ -1,7 +1,19 @@ # tflite-react-native -A React Native library for accessing TensorFlow Lite API. Supports Classification and Object Detection on both iOS and Android. +A React Native library for accessing TensorFlow Lite API. Supports Classification, Object Detection, Deeplab and PoseNet on both iOS and Android. + +### Table of Contents + +- [Installation](#Installation) +- [Usage](#Usage) + - [Image Classification](#Image-Classification) + - [Object Detection](#Object-Detection) + - [SSD MobileNet](#SSD-MobileNet) + - [YOLO](#Tiny-YOLOv2) + - [Deeplab](#Deeplab) + - [PoseNet](#PoseNet) +- [Example](#Example) ## Installation @@ -119,9 +131,19 @@ tflite.runModelOnImage({ }); ``` +- Output format: +``` +{ + index: 0, + label: "person", + confidence: 0.629 +} +``` + ### Object detection: -- SSD MobileNet +#### SSD MobileNet + ```javascript tflite.detectObjectOnImage({ path: imagePath, @@ -139,7 +161,8 @@ tflite.detectObjectOnImage({ }); ``` -- Tiny YOLOv2 +#### Tiny YOLOv2 + ```javascript tflite.detectObjectOnImage({ path: imagePath, @@ -159,12 +182,123 @@ tflite.detectObjectOnImage({ }); ``` +- Output format: + +`x, y, w, h` are between [0, 1]. You can scale `x, w` by the width and `y, h` by the height of the image. + +``` +{ + detectedClass: "hot dog", + confidenceInClass: 0.123, + rect: { + x: 0.15, + y: 0.33, + w: 0.80, + h: 0.27 + } + } +``` + +### Deeplab + +```javascript +tflite.runSegmentationOnImage({ + path: imagePath, + imageMean: 127.5, // defaults to 127.5 + imageStd: 127.5, // defaults to 127.5 + labelColors: [...], // defaults to https://github.com/shaqian/tflite-react-native/blob/master/index.js#L59 + outputType: "png", // defaults to "png" +}, +(err, res) => { + if(err) + console.log(err); + else + console.log(res); +}); ``` + +- Output format: + + The output of Deeplab inference is a base64-encoded byte array. 
Depending on the `outputType` used, the output is: + + - (if outputType is png) byte array of a png image + + - (otherwise) byte array of r, g, b, a values of the pixels + + +### PoseNet + +> Model is from [StackOverflow thread](https://stackoverflow.com/a/55288616). + +```javascript +tflite.runPoseNetOnImage({ + path: imagePath, + imageMean: 127.5, // defaults to 127.5 + imageStd: 127.5, // defaults to 127.5 + numResults: 3, // defaults to 5 + threshold: 0.8, // defaults to 0.5 + nmsRadius: 20, // defaults to 20 +}, +(err, res) => { + if(err) + console.log(err); + else + console.log(res); +}); +``` + +- Output format: + +`x, y` are between [0, 1]. You can scale `x` by the width and `y` by the height of the image. + +``` +[ // array of poses/persons + { // pose #1 + score: 0.6324902, + keypoints: { + 0: { + x: 0.250, + y: 0.125, + part: nose, + score: 0.9971070 + }, + 1: { + x: 0.230, + y: 0.105, + part: leftEye, + score: 0.9978438 + } + ...... + } + }, + { // pose #2 + score: 0.32534285, + keypoints: { + 0: { + x: 0.402, + y: 0.538, + part: nose, + score: 0.8798978 + }, + 1: { + x: 0.380, + y: 0.513, + part: leftEye, + score: 0.7090239 + } + ...... + } + }, + ...... +] +``` + ### Release resources: ``` tflite.close(); ``` -# Demo +## Example Refer to the [example](https://github.com/shaqian/tflite-react-native/tree/master/example). 
diff --git a/android/src/main/java/com/reactlibrary/TfliteReactNativeModule.java b/android/src/main/java/com/reactlibrary/TfliteReactNativeModule.java index f8d2122..db4006f 100644 --- a/android/src/main/java/com/reactlibrary/TfliteReactNativeModule.java +++ b/android/src/main/java/com/reactlibrary/TfliteReactNativeModule.java @@ -5,8 +5,10 @@ import android.content.res.AssetManager; import android.graphics.Bitmap; import android.graphics.BitmapFactory; +import android.graphics.Color; import android.graphics.Matrix; import android.graphics.Canvas; +import android.util.Base64; import com.facebook.react.bridge.Arguments; import com.facebook.react.bridge.ReactApplicationContext; @@ -22,6 +24,7 @@ import org.tensorflow.lite.Tensor; import java.io.BufferedReader; +import java.io.ByteArrayOutputStream; import java.io.FileInputStream; import java.io.IOException; import java.io.InputStream; @@ -73,7 +76,9 @@ private void loadModel(final String modelPath, final String labelsPath, final in tfliteOptions.setNumThreads(numThreads); tfLite = new Interpreter(buffer, tfliteOptions); - loadLabels(assetManager, labelsPath); + if (labelsPath.length() > 0) { + loadLabels(assetManager, labelsPath); + } callback.invoke(null, "success"); } @@ -90,7 +95,7 @@ private void loadLabels(AssetManager assetManager, String path) { labelProb = new float[1][labels.size()]; br.close(); } catch (IOException e) { - throw new RuntimeException("Failed to read label file" , e); + throw new RuntimeException("Failed to read label file", e); } } @@ -130,7 +135,7 @@ ByteBuffer feedInputTensorImage(String path, float mean, float std) throws IOExc inputSize = tensor.shape()[1]; int inputChannels = tensor.shape()[3]; - InputStream inputStream = new FileInputStream(path.replace("file://","")); + InputStream inputStream = new FileInputStream(path.replace("file://", "")); Bitmap bitmapRaw = BitmapFactory.decodeStream(inputStream); Matrix matrix = getTransformationMatrix(bitmapRaw.getWidth(), 
bitmapRaw.getHeight(), @@ -155,9 +160,9 @@ ByteBuffer feedInputTensorImage(String path, float mean, float std) throws IOExc imgData.putFloat((((pixelValue >> 8) & 0xFF) - mean) / std); imgData.putFloat(((pixelValue & 0xFF) - mean) / std); } else { - imgData.put((byte)((pixelValue >> 16) & 0xFF)); - imgData.put((byte)((pixelValue >> 8) & 0xFF)); - imgData.put((byte)(pixelValue & 0xFF)); + imgData.put((byte) ((pixelValue >> 16) & 0xFF)); + imgData.put((byte) ((pixelValue >> 8) & 0xFF)); + imgData.put((byte) (pixelValue & 0xFF)); } } } @@ -269,7 +274,7 @@ public int compare(WritableMap lhs, WritableMap rhs) { for (int b = 0; b < numBoxesPerBlock; ++b) { final int offset = (numClasses + 5) * b; - final float confidence = expit(output[0][y][x][offset + 4]); + final float confidence = sigmoid(output[0][y][x][offset + 4]); final float[] classes = new float[numClasses]; for (int c = 0; c < numClasses; ++c) { @@ -288,8 +293,8 @@ public int compare(WritableMap lhs, WritableMap rhs) { final float confidenceInClass = maxClass * confidence; if (confidenceInClass > threshold) { - final float xPos = (x + expit(output[0][y][x][offset + 0])) * blockSize; - final float yPos = (y + expit(output[0][y][x][offset + 1])) * blockSize; + final float xPos = (x + sigmoid(output[0][y][x][offset + 0])) * blockSize; + final float yPos = (y + sigmoid(output[0][y][x][offset + 1])) * blockSize; final float w = (float) (Math.exp(output[0][y][x][offset + 2]) * anchors.getDouble(2 * b + 0)) * blockSize; final float h = (float) (Math.exp(output[0][y][x][offset + 3]) * anchors.getDouble(2 * b + 1)) * blockSize; @@ -337,7 +342,425 @@ public int compare(WritableMap lhs, WritableMap rhs) { return results; } - private float expit(final float x) { + byte[] fetchArgmax(ByteBuffer output, ReadableArray labelColors, String outputType) { + Tensor outputTensor = tfLite.getOutputTensor(0); + int outputBatchSize = outputTensor.shape()[0]; + assert outputBatchSize == 1; + int outputHeight = outputTensor.shape()[1]; 
+ int outputWidth = outputTensor.shape()[2]; + int outputChannels = outputTensor.shape()[3]; + + Bitmap outputArgmax = null; + byte[] outputBytes = new byte[outputWidth * outputHeight * 4]; + if (outputType.equals("png")) { + outputArgmax = Bitmap.createBitmap(outputWidth, outputHeight, Bitmap.Config.ARGB_8888); + } + + if (outputTensor.dataType() == DataType.FLOAT32) { + for (int i = 0; i < outputHeight; ++i) { + for (int j = 0; j < outputWidth; ++j) { + int maxIndex = 0; + float maxValue = 0.0f; + for (int c = 0; c < outputChannels; ++c) { + float outputValue = output.getFloat(); + if (outputValue > maxValue) { + maxIndex = c; + maxValue = outputValue; + } + } + int labelColor = labelColors.getInt(maxIndex); + if (outputType.equals("png")) { + outputArgmax.setPixel(j, i, Color.rgb((labelColor >> 16) & 0xFF, (labelColor >> 8) & 0xFF, labelColor & 0xFF)); + } else { + setPixel(outputBytes, i * outputWidth + j, labelColor); + } + } + } + } else { + for (int i = 0; i < outputHeight; ++i) { + for (int j = 0; j < outputWidth; ++j) { + int maxIndex = 0; + int maxValue = 0; + for (int c = 0; c < outputChannels; ++c) { + int outputValue = output.get(); + if (outputValue > maxValue) { + maxIndex = c; + maxValue = outputValue; + } + } + int labelColor = labelColors.getInt(maxIndex); + if (outputType.equals("png")) { + outputArgmax.setPixel(j, i, Color.rgb((labelColor >> 16) & 0xFF, (labelColor >> 8) & 0xFF, labelColor & 0xFF)); + } else { + setPixel(outputBytes, i * outputWidth + j, labelColor); + } + } + } + } + if (outputType.equals("png")) { + return compressPNG(outputArgmax); + } else { + return outputBytes; + } + } + + void setPixel(byte[] rgba, int index, long color) { + rgba[index * 4] = (byte) ((color >> 16) & 0xFF); + rgba[index * 4 + 1] = (byte) ((color >> 8) & 0xFF); + rgba[index * 4 + 2] = (byte) (color & 0xFF); + rgba[index * 4 + 3] = (byte) ((color >> 24) & 0xFF); + } + + byte[] compressPNG(Bitmap bitmap) { + // 
https://stackoverflow.com/questions/4989182/converting-java-bitmap-to-byte-array#4989543 + ByteArrayOutputStream stream = new ByteArrayOutputStream(); + bitmap.compress(Bitmap.CompressFormat.PNG, 100, stream); + byte[] byteArray = stream.toByteArray(); + // bitmap.recycle(); + return byteArray; + } + + @ReactMethod + private void runSegmentationOnImage(final String path, final float mean, final float std, final ReadableArray labelColors, + final String outputType, final Callback callback) throws IOException { + int i = tfLite.getOutputTensor(0).numBytes(); + ByteBuffer output = ByteBuffer.allocateDirect(tfLite.getOutputTensor(0).numBytes()); + output.order(ByteOrder.nativeOrder()); + tfLite.run(feedInputTensorImage(path, mean, std), output); + + if (output.position() != output.limit()) { + callback.invoke("Unexpected output position", null); + return; + } + output.flip(); + + byte[] res = fetchArgmax(output, labelColors, outputType); + String base64String = Base64.encodeToString(res, Base64.NO_WRAP); + + callback.invoke(null, base64String); + } + + String[] partNames = { + "nose", "leftEye", "rightEye", "leftEar", "rightEar", "leftShoulder", + "rightShoulder", "leftElbow", "rightElbow", "leftWrist", "rightWrist", + "leftHip", "rightHip", "leftKnee", "rightKnee", "leftAnkle", "rightAnkle" + }; + + String[][] poseChain = { + {"nose", "leftEye"}, {"leftEye", "leftEar"}, {"nose", "rightEye"}, + {"rightEye", "rightEar"}, {"nose", "leftShoulder"}, + {"leftShoulder", "leftElbow"}, {"leftElbow", "leftWrist"}, + {"leftShoulder", "leftHip"}, {"leftHip", "leftKnee"}, + {"leftKnee", "leftAnkle"}, {"nose", "rightShoulder"}, + {"rightShoulder", "rightElbow"}, {"rightElbow", "rightWrist"}, + {"rightShoulder", "rightHip"}, {"rightHip", "rightKnee"}, + {"rightKnee", "rightAnkle"} + }; + + Map partsIds = new HashMap<>(); + List parentToChildEdges = new ArrayList<>(); + List childToParentEdges = new ArrayList<>(); + + + void initPoseNet(Map outputMap) { + if (partsIds.size() == 0) { 
+ for (int i = 0; i < partNames.length; ++i) + partsIds.put(partNames[i], i); + + for (int i = 0; i < poseChain.length; ++i) { + parentToChildEdges.add(partsIds.get(poseChain[i][1])); + childToParentEdges.add(partsIds.get(poseChain[i][0])); + } + } + + for (int i = 0; i < tfLite.getOutputTensorCount(); i++) { + int[] shape = tfLite.getOutputTensor(i).shape(); + float[][][][] output = new float[shape[0]][shape[1]][shape[2]][shape[3]]; + outputMap.put(i, output); + } + } + + @ReactMethod + private void runPoseNetOnImage(final String path, final float mean, final float std, final int numResults, + final float threshold, final int nmsRadius, final Callback callback) throws IOException { + int localMaximumRadius = 1; + int outputStride = 16; + + ByteBuffer imgData = feedInputTensorImage(path, mean, std); + Object[] input = new Object[]{imgData}; + + Map outputMap = new HashMap<>(); + initPoseNet(outputMap); + + tfLite.runForMultipleInputsOutputs(input, outputMap); + + float[][][] scores = ((float[][][][]) outputMap.get(0))[0]; + float[][][] offsets = ((float[][][][]) outputMap.get(1))[0]; + float[][][] displacementsFwd = ((float[][][][]) outputMap.get(2))[0]; + float[][][] displacementsBwd = ((float[][][][]) outputMap.get(3))[0]; + + PriorityQueue> pq = buildPartWithScoreQueue(scores, threshold, localMaximumRadius); + + int numParts = scores[0][0].length; + int numEdges = parentToChildEdges.size(); + int sqaredNmsRadius = nmsRadius * nmsRadius; + + List> results = new ArrayList<>(); + + while (results.size() < numResults && pq.size() > 0) { + Map root = pq.poll(); + float[] rootPoint = getImageCoords(root, outputStride, numParts, offsets); + + if (withinNmsRadiusOfCorrespondingPoint( + results, sqaredNmsRadius, rootPoint[0], rootPoint[1], (int) root.get("partId"))) + continue; + + Map keypoint = new HashMap<>(); + keypoint.put("score", root.get("score")); + keypoint.put("part", partNames[(int) root.get("partId")]); + keypoint.put("y", rootPoint[0] / inputSize); + 
keypoint.put("x", rootPoint[1] / inputSize); + + Map> keypoints = new HashMap<>(); + keypoints.put((int) root.get("partId"), keypoint); + + for (int edge = numEdges - 1; edge >= 0; --edge) { + int sourceKeypointId = parentToChildEdges.get(edge); + int targetKeypointId = childToParentEdges.get(edge); + if (keypoints.containsKey(sourceKeypointId) && !keypoints.containsKey(targetKeypointId)) { + keypoint = traverseToTargetKeypoint(edge, keypoints.get(sourceKeypointId), + targetKeypointId, scores, offsets, outputStride, displacementsBwd); + keypoints.put(targetKeypointId, keypoint); + } + } + + for (int edge = 0; edge < numEdges; ++edge) { + int sourceKeypointId = childToParentEdges.get(edge); + int targetKeypointId = parentToChildEdges.get(edge); + if (keypoints.containsKey(sourceKeypointId) && !keypoints.containsKey(targetKeypointId)) { + keypoint = traverseToTargetKeypoint(edge, keypoints.get(sourceKeypointId), + targetKeypointId, scores, offsets, outputStride, displacementsFwd); + keypoints.put(targetKeypointId, keypoint); + } + } + + Map result = new HashMap<>(); + result.put("keypoints", keypoints); + result.put("score", getInstanceScore(keypoints, numParts)); + results.add(result); + } + + WritableArray outputs = Arguments.createArray(); + for (Map result : results) { + Map> keypoints = (Map>) result.get("keypoints"); + + WritableMap _keypoints = Arguments.createMap(); + for (Map.Entry> keypoint : keypoints.entrySet()) { + Map keypoint_ = keypoint.getValue(); + WritableMap _keypoint = Arguments.createMap(); + _keypoint.putDouble("score", Double.valueOf(keypoint_.get("score").toString())); + _keypoint.putString("part", keypoint_.get("part").toString()); + _keypoint.putDouble("y", Double.valueOf(keypoint_.get("y").toString())); + _keypoint.putDouble("x", Double.valueOf(keypoint_.get("x").toString())); + _keypoints.putMap(keypoint.getKey().toString(), _keypoint); + } + + WritableMap output = Arguments.createMap(); + output.putMap("keypoints", _keypoints); + 
output.putDouble("score", Double.valueOf(result.get("score").toString())); + + outputs.pushMap(output); + } + + callback.invoke(null, outputs); + } + + + PriorityQueue> buildPartWithScoreQueue(float[][][] scores, + double threshold, + int localMaximumRadius) { + PriorityQueue> pq = + new PriorityQueue<>( + 1, + new Comparator>() { + @Override + public int compare(Map lhs, Map rhs) { + return Float.compare((float) rhs.get("score"), (float) lhs.get("score")); + } + }); + + for (int heatmapY = 0; heatmapY < scores.length; ++heatmapY) { + for (int heatmapX = 0; heatmapX < scores[0].length; ++heatmapX) { + for (int keypointId = 0; keypointId < scores[0][0].length; ++keypointId) { + float score = sigmoid(scores[heatmapY][heatmapX][keypointId]); + if (score < threshold) continue; + + if (scoreIsMaximumInLocalWindow( + keypointId, score, heatmapY, heatmapX, localMaximumRadius, scores)) { + Map res = new HashMap<>(); + res.put("score", score); + res.put("y", heatmapY); + res.put("x", heatmapX); + res.put("partId", keypointId); + pq.add(res); + } + } + } + } + + return pq; + } + + boolean scoreIsMaximumInLocalWindow(int keypointId, + float score, + int heatmapY, + int heatmapX, + int localMaximumRadius, + float[][][] scores) { + boolean localMaximum = true; + int height = scores.length; + int width = scores[0].length; + + int yStart = Math.max(heatmapY - localMaximumRadius, 0); + int yEnd = Math.min(heatmapY + localMaximumRadius + 1, height); + for (int yCurrent = yStart; yCurrent < yEnd; ++yCurrent) { + int xStart = Math.max(heatmapX - localMaximumRadius, 0); + int xEnd = Math.min(heatmapX + localMaximumRadius + 1, width); + for (int xCurrent = xStart; xCurrent < xEnd; ++xCurrent) { + if (sigmoid(scores[yCurrent][xCurrent][keypointId]) > score) { + localMaximum = false; + break; + } + } + if (!localMaximum) { + break; + } + } + + return localMaximum; + } + + float[] getImageCoords(Map keypoint, + int outputStride, + int numParts, + float[][][] offsets) { + int heatmapY = 
(int) keypoint.get("y"); + int heatmapX = (int) keypoint.get("x"); + int keypointId = (int) keypoint.get("partId"); + float offsetY = offsets[heatmapY][heatmapX][keypointId]; + float offsetX = offsets[heatmapY][heatmapX][keypointId + numParts]; + + float y = heatmapY * outputStride + offsetY; + float x = heatmapX * outputStride + offsetX; + + return new float[]{y, x}; + } + + boolean withinNmsRadiusOfCorrespondingPoint(List> poses, + float squaredNmsRadius, + float y, + float x, + int keypointId) { + for (Map pose : poses) { + Map keypoints = (Map) pose.get("keypoints"); + Map correspondingKeypoint = (Map) keypoints.get(keypointId); + float _x = (float) correspondingKeypoint.get("x") * inputSize - x; + float _y = (float) correspondingKeypoint.get("y") * inputSize - y; + float squaredDistance = _x * _x + _y * _y; + if (squaredDistance <= squaredNmsRadius) + return true; + } + + return false; + } + + Map traverseToTargetKeypoint(int edgeId, + Map sourceKeypoint, + int targetKeypointId, + float[][][] scores, + float[][][] offsets, + int outputStride, + float[][][] displacements) { + int height = scores.length; + int width = scores[0].length; + int numKeypoints = scores[0][0].length; + float sourceKeypointY = (float) sourceKeypoint.get("y") * inputSize; + float sourceKeypointX = (float) sourceKeypoint.get("x") * inputSize; + + int[] sourceKeypointIndices = getStridedIndexNearPoint(sourceKeypointY, sourceKeypointX, + outputStride, height, width); + + float[] displacement = getDisplacement(edgeId, sourceKeypointIndices, displacements); + + float[] displacedPoint = new float[]{ + sourceKeypointY + displacement[0], + sourceKeypointX + displacement[1] + }; + + float[] targetKeypoint = displacedPoint; + + final int offsetRefineStep = 2; + for (int i = 0; i < offsetRefineStep; i++) { + int[] targetKeypointIndices = getStridedIndexNearPoint(targetKeypoint[0], targetKeypoint[1], + outputStride, height, width); + + int targetKeypointY = targetKeypointIndices[0]; + int 
targetKeypointX = targetKeypointIndices[1]; + + float offsetY = offsets[targetKeypointY][targetKeypointX][targetKeypointId]; + float offsetX = offsets[targetKeypointY][targetKeypointX][targetKeypointId + numKeypoints]; + + targetKeypoint = new float[]{ + targetKeypointY * outputStride + offsetY, + targetKeypointX * outputStride + offsetX + }; + } + + int[] targetKeypointIndices = getStridedIndexNearPoint(targetKeypoint[0], targetKeypoint[1], + outputStride, height, width); + + float score = sigmoid(scores[targetKeypointIndices[0]][targetKeypointIndices[1]][targetKeypointId]); + + Map keypoint = new HashMap<>(); + keypoint.put("score", score); + keypoint.put("part", partNames[targetKeypointId]); + keypoint.put("y", targetKeypoint[0] / inputSize); + keypoint.put("x", targetKeypoint[1] / inputSize); + + return keypoint; + } + + int[] getStridedIndexNearPoint(float _y, float _x, int outputStride, int height, int width) { + int y_ = Math.round(_y / outputStride); + int x_ = Math.round(_x / outputStride); + int y = y_ < 0 ? 0 : y_ > height - 1 ? height - 1 : y_; + int x = x_ < 0 ? 0 : x_ > width - 1 ? width - 1 : x_; + return new int[]{y, x}; + } + + float[] getDisplacement(int edgeId, int[] keypoint, float[][][] displacements) { + int numEdges = displacements[0][0].length / 2; + int y = keypoint[0]; + int x = keypoint[1]; + return new float[]{displacements[y][x][edgeId], displacements[y][x][edgeId + numEdges]}; + } + + float getInstanceScore(Map> keypoints, int numKeypoints) { + float scores = 0; + for (Map.Entry> keypoint : keypoints.entrySet()) + scores += (float) keypoint.getValue().get("score"); + return scores / numKeypoints; + } + + @ReactMethod + private void close() { + tfLite.close(); + labels = null; + labelProb = null; + } + + + private float sigmoid(final float x) { return (float) (1. / (1. 
+ Math.exp(-x))); } @@ -360,8 +783,7 @@ private static Matrix getTransformationMatrix(final int srcWidth, final int srcHeight, final int dstWidth, final int dstHeight, - final boolean maintainAspectRatio) - { + final boolean maintainAspectRatio) { final Matrix matrix = new Matrix(); if (srcWidth != dstWidth || srcHeight != dstHeight) { @@ -380,10 +802,4 @@ private static Matrix getTransformationMatrix(final int srcWidth, return matrix; } - @ReactMethod - private void close() { - tfLite.close(); - labels = null; - labelProb = null; - } } \ No newline at end of file diff --git a/example/App.js b/example/App.js index 0741d6f..674cecb 100644 --- a/example/App.js +++ b/example/App.js @@ -11,6 +11,8 @@ const blue = "#25d5fd"; const mobile = "MobileNet"; const ssd = "SSD MobileNet"; const yolo = "Tiny YOLOv2"; +const deeplab = "Deeplab"; +const posenet = "PoseNet"; type Props = {}; export default class App extends Component { @@ -36,6 +38,14 @@ export default class App extends Component { var modelFile = 'models/yolov2_tiny.tflite'; var labelsFile = 'models/yolov2_tiny.txt'; break; + case deeplab: + var modelFile = 'models/deeplabv3_257_mv_gpu.tflite'; + var labelsFile = 'models/deeplabv3_257_mv_gpu.txt'; + break; + case posenet: + var modelFile = 'models/posenet_mv1_075_float_from_checkpoints.tflite'; + var labelsFile = ''; + break; default: var modelFile = 'models/mobilenet_v1_1.0_224.tflite'; var labelsFile = 'models/mobilenet_v1_1.0_224.txt'; @@ -92,6 +102,7 @@ export default class App extends Component { this.setState({ recognitions: res }); }); break; + case yolo: tflite.detectObjectOnImage({ path, @@ -108,6 +119,32 @@ export default class App extends Component { this.setState({ recognitions: res }); }); break; + + case deeplab: + tflite.runSegmentationOnImage({ + path + }, + (err, res) => { + if (err) + console.log(err); + else + this.setState({ recognitions: res }); + }); + break; + + case posenet: + tflite.runPoseNetOnImage({ + path, + threshold: 0.8 + }, + (err, 
res) => { + if (err) + console.log(err); + else + this.setState({ recognitions: res }); + }); + break; + default: tflite.runModelOnImage({ path, @@ -127,30 +164,64 @@ export default class App extends Component { }); } - renderBoxes() { + renderResults() { const { model, recognitions, imageHeight, imageWidth } = this.state; - if (model == mobile) - return recognitions.map((res, id) => { - return ( - - {res["label"] + "-" + (res["confidence"] * 100).toFixed(0) + "%"} - - ) - }); - else - return recognitions.map((res, id) => { - var left = res["rect"]["x"] * imageWidth; - var top = res["rect"]["y"] * imageHeight; - var width = res["rect"]["w"] * imageWidth; - var height = res["rect"]["h"] * imageHeight; + switch (model) { + case ssd: + case yolo: + return recognitions.map((res, id) => { + var left = res["rect"]["x"] * imageWidth; + var top = res["rect"]["y"] * imageHeight; + var width = res["rect"]["w"] * imageWidth; + var height = res["rect"]["h"] * imageHeight; + return ( + + + {res["detectedClass"] + " " + (res["confidenceInClass"] * 100).toFixed(0) + "%"} + + + ) + }); + break; + + case deeplab: return ( - - - {res["detectedClass"] + " " + (res["confidenceInClass"] * 100).toFixed(0) + "%"} + recognitions.length > 0 ? 
+ : undefined + ); + break; + + case posenet: + return recognitions.map((res, id) => { + return Object.values(res["keypoints"]).map((k, id) => { + var left = k["x"] * imageWidth - 6; + var top = k["y"] * imageHeight - 6; + var width = imageWidth; + var height = imageHeight; + return ( + + + {"● " + k["part"]} + + + ) + }); + }); + break; + + default: + return recognitions.map((res, id) => { + return ( + + {res["label"] + "-" + (res["confidence"] * 100).toFixed(0) + "%"} - - ) - }); + ) + }); + } } render() { @@ -179,7 +250,7 @@ export default class App extends Component { Select Picture } - {this.renderBoxes()} + {this.renderResults()} : @@ -187,6 +258,8 @@ export default class App extends Component { {renderButton(mobile)} {renderButton(ssd)} {renderButton(yolo)} + {renderButton(deeplab)} + {renderButton(posenet)} } diff --git a/example/models/deeplabv3_257_mv_gpu.tflite b/example/models/deeplabv3_257_mv_gpu.tflite new file mode 100644 index 0000000..d2d9a9b Binary files /dev/null and b/example/models/deeplabv3_257_mv_gpu.tflite differ diff --git a/example/models/deeplabv3_257_mv_gpu.txt b/example/models/deeplabv3_257_mv_gpu.txt new file mode 100644 index 0000000..ecfffa3 --- /dev/null +++ b/example/models/deeplabv3_257_mv_gpu.txt @@ -0,0 +1,21 @@ +background +aeroplane +biyclce +bird +boat +bottle +bus +car +cat +chair +cow +diningtable +dog +horse +motorbike +person +potted plant +sheep +sofa +train +tv-monitor diff --git a/example/models/posenet_mv1_075_float_from_checkpoints.tflite b/example/models/posenet_mv1_075_float_from_checkpoints.tflite new file mode 100644 index 0000000..4ccbcc5 Binary files /dev/null and b/example/models/posenet_mv1_075_float_from_checkpoints.tflite differ diff --git a/example/package.json b/example/package.json index cfc7b1f..ae6b681 100644 --- a/example/package.json +++ b/example/package.json @@ -10,7 +10,7 @@ "react": "16.6.3", "react-native": "0.58.4", "react-native-image-picker": "0.28.0", - "tflite-react-native": "0.0.4" + 
"tflite-react-native": "0.0.5" }, "devDependencies": { "babel-core": "7.0.0-bridge.0", diff --git a/index.js b/index.js index 92a69a7..8995f79 100644 --- a/index.js +++ b/index.js @@ -5,8 +5,8 @@ const { TfliteReactNative } = NativeModules; class Tflite { loadModel(args, callback) { TfliteReactNative.loadModel( - args['model'], - args['labels'], + args['model'], + args['labels'] || '', args['numThreads'] || 1, (error, response) => { callback && callback(error, response); @@ -15,9 +15,9 @@ class Tflite { runModelOnImage(args, callback) { TfliteReactNative.runModelOnImage( - args['path'], - args['imageMean'] != null ? args['imageMean'] : 127.5, - args['imageStd'] != null ? args['imageStd'] : 127.5, + args['path'], + args['imageMean'] != null ? args['imageMean'] : 127.5, + args['imageStd'] != null ? args['imageStd'] : 127.5, args['numResults'] || 5, args['threshold'] != null ? args['threshold'] : 0.1, (error, response) => { @@ -27,9 +27,9 @@ class Tflite { detectObjectOnImage(args, callback) { TfliteReactNative.detectObjectOnImage( - args['path'], - args['model'] || "SSDMobileNet", - args['imageMean'] != null ? args['imageMean'] : 127.5, + args['path'], + args['model'] || "SSDMobileNet", + args['imageMean'] != null ? args['imageMean'] : 127.5, args['imageStd'] != null ? args['imageStd'] : 127.5, args['threshold'] != null ? args['threshold'] : 0.1, args['numResultsPerClass'] || 5, @@ -51,6 +51,53 @@ class Tflite { }); } + runSegmentationOnImage(args, callback) { + TfliteReactNative.runSegmentationOnImage( + args['path'], + args['imageMean'] != null ? args['imageMean'] : 127.5, + args['imageStd'] != null ? 
args['imageStd'] : 127.5, + args['labelColors'] || [ + 0x000000, // background + 0x800000, // aeroplane + 0x008000, // biyclce + 0x808000, // bird + 0x000080, // boat + 0x800080, // bottle + 0x008080, // bus + 0x808080, // car + 0x400000, // cat + 0xc00000, // chair + 0x408000, // cow + 0xc08000, // diningtable + 0x400080, // dog + 0xc00080, // horse + 0x408080, // motorbike + 0xc08080, // person + 0x004000, // potted plant + 0x804000, // sheep + 0x00c000, // sofa + 0x80c000, // train + 0x004080, // tv-monitor + ], + args['outputType'] || "png", + (error, response) => { + callback && callback(error, response); + }); + } + + runPoseNetOnImage(args, callback) { + TfliteReactNative.runPoseNetOnImage( + args['path'], + args['imageMean'] != null ? args['imageMean'] : 127.5, + args['imageStd'] != null ? args['imageStd'] : 127.5, + args['numResults'] || 5, + args['threshold'] != null ? args['threshold'] : 0.5, + args['nmsRadius'] || 20, + (error, response) => { + callback && callback(error, response); + }); + } + close() { TfliteReactNative.close(); } diff --git a/ios/TfliteReactNative.mm b/ios/TfliteReactNative.mm index bd22bc9..b53d771 100644 --- a/ios/TfliteReactNative.mm +++ b/ios/TfliteReactNative.mm @@ -1,3 +1,4 @@ +// #define CONTRIB_PATH #import "TfliteReactNative.h" @@ -8,11 +9,19 @@ #include #include #include +#import +#ifdef CONTRIB_PATH #include "tensorflow/contrib/lite/kernels/register.h" #include "tensorflow/contrib/lite/model.h" #include "tensorflow/contrib/lite/string_util.h" #include "tensorflow/contrib/lite/op_resolver.h" +#else +#include "tensorflow/lite/kernels/register.h" +#include "tensorflow/lite/model.h" +#include "tensorflow/lite/string_util.h" +#include "tensorflow/lite/op_resolver.h" +#endif #include "ios_image_load.h" @@ -38,9 +47,8 @@ static void LoadLabels(NSString* labels_path, } std::ifstream t; t.open([labels_path UTF8String]); - std::string line; - while (t) { - std::getline(t, line); + label_strings->clear(); + for (std::string line; 
std::getline(t, line); ) { label_strings->push_back(line); } t.close(); @@ -62,7 +70,9 @@ static void LoadLabels(NSString* labels_path, } NSString* labels_path = [[NSBundle mainBundle] pathForResource:labels_file ofType:nil]; - LoadLabels(labels_path, &labels); + if ([labels_path length] > 0) { + LoadLabels(labels_path, &labels); + } tflite::ops::builtin::BuiltinOpResolver resolver; tflite::InterpreterBuilder(*model, resolver)(&interpreter); @@ -81,13 +91,7 @@ static void LoadLabels(NSString* labels_path, callback(@[[NSNull null], @"sucess"]); } -void feedInputTensorImage(const NSString* image_path, float input_mean, float input_std, int* input_size) { - int image_channels; - int image_height; - int image_width; - std::vector image_data = LoadImageFromFile([image_path UTF8String], &image_width, &image_height, &image_channels); - uint8_t* in = image_data.data(); - +void feedInputTensor(uint8_t* in, int* input_size, int image_height, int image_width, int image_channels, float input_mean, float input_std) { assert(interpreter->inputs().size() == 1); int input = interpreter->inputs()[0]; TfLiteTensor* input_tensor = interpreter->tensor(input); @@ -129,6 +133,15 @@ void feedInputTensorImage(const NSString* image_path, float input_mean, float in } } +void feedInputTensorImage(const NSString* image_path, float input_mean, float input_std, int* input_size) { + int image_channels; + int image_height; + int image_width; + std::vector image_data = LoadImageFromFile([image_path UTF8String], &image_width, &image_height, &image_channels); + uint8_t* in = image_data.data(); + feedInputTensor(in, input_size, image_height, image_width, image_channels, input_mean, input_std); +} + NSMutableArray* GetTopN(const float* prediction, const unsigned long prediction_size, const int num_results, const float threshold) { std::priority_queue, std::vector>, @@ -223,6 +236,7 @@ void feedInputTensorImage(const NSString* image_path, float input_mean, float in NSMutableDictionary* res = 
[NSMutableDictionary dictionary]; NSString* class_name = [NSString stringWithUTF8String:labels[detected_class + 1].c_str()]; NSObject* counter = [counters objectForKey:class_name]; + if (counter) { int countValue = [(NSNumber*)counter intValue] + 1; if (countValue > num_results_per_class) { @@ -250,7 +264,6 @@ void feedInputTensorImage(const NSString* image_path, float input_mean, float in [res setObject:rect forKey:@"rect"]; [results addObject:res]; } - return results; } @@ -387,12 +400,417 @@ void softmax(float vals[], int count) { if ([model isEqual: @"SSDMobileNet"]) results = parseSSDMobileNet(threshold, num_results_per_class); else - results = parseYOLO((int)(labels.size() - 1), anchors, block_size, 5, num_results_per_class, - threshold, input_size); + results = parseYOLO((int)labels.size(), anchors, block_size, 5, num_results_per_class, + threshold, input_size); callback(@[[NSNull null], results]); } +void setPixel(char* rgba, int index, long color) { + rgba[index * 4] = (color >> 16) & 0xFF; + rgba[index * 4 + 1] = (color >> 8) & 0xFF; + rgba[index * 4 + 2] = color & 0xFF; + rgba[index * 4 + 3] = (color >> 24) & 0xFF; +} + +NSData* fetchArgmax(const NSArray* labelColors, const NSString* outputType) { + int output = interpreter->outputs()[0]; + TfLiteTensor* output_tensor = interpreter->tensor(output); + const int height = output_tensor->dims->data[1]; + const int width = output_tensor->dims->data[2]; + const int channels = output_tensor->dims->data[3]; + + NSMutableData *data = nil; + int size = height * width * 4; + data = [[NSMutableData dataWithCapacity: size] initWithLength: size]; + char* out = (char*)[data bytes]; + if (output_tensor->type == kTfLiteUInt8) { + const uint8_t* bytes = interpreter->typed_tensor(output); + for (int i = 0; i < height; ++i) { + for (int j = 0; j < width; ++j) { + int index = i * width + j; + int maxIndex = 0; + int maxValue = 0; + for (int c = 0; c < channels; ++c) { + int outputValue = bytes[index* channels + c]; + if 
(outputValue > maxValue) { + maxIndex = c; + maxValue = outputValue; + } + } + long labelColor = [[labelColors objectAtIndex:maxIndex] longValue]; + setPixel(out, index, labelColor); + } + } + } else { // kTfLiteFloat32 + const float* bytes = interpreter->typed_tensor(output); + for (int i = 0; i < height; ++i) { + for (int j = 0; j < width; ++j) { + int index = i * width + j; + int maxIndex = 0; + float maxValue = .0f; + for (int c = 0; c < channels; ++c) { + float outputValue = bytes[index * channels + c]; + if (outputValue > maxValue) { + maxIndex = c; + maxValue = outputValue; + } + } + long labelColor = [[labelColors objectAtIndex:maxIndex] longValue]; + setPixel(out, index, labelColor); + } + } + } + + if ([outputType isEqual: @"png"]) { + CGColorSpaceRef colorSpace = CGColorSpaceCreateDeviceRGB(); + CGContextRef bitmapContext = CGBitmapContextCreate(out, + width, + height, + 8, // bitsPerComponent + 4 * width, // bytesPerRow + colorSpace, + kCGImageAlphaNoneSkipLast); + + CFRelease(colorSpace); + CGImageRef cgImage = CGBitmapContextCreateImage(bitmapContext); + NSData* image = UIImagePNGRepresentation([[UIImage alloc] initWithCGImage:cgImage]); + CFRelease(cgImage); + CFRelease(bitmapContext); + return image; + } else { + return data; + } +} + + +RCT_EXPORT_METHOD(runSegmentationOnImage:(NSString*)image_path + mean:(float)input_mean + std:(float)input_std + labelColors:(NSArray*)label_colors + outputType:(NSString*)output_type + callback:(RCTResponseSenderBlock)callback) { + + if (!interpreter) { + NSLog(@"Failed to construct interpreter."); + callback(@[@"Failed to construct interpreter."]); + } + + image_path = [image_path stringByReplacingOccurrencesOfString:@"file://" withString:@""]; + int input_size; + feedInputTensorImage(image_path, input_mean, input_std, &input_size); + + if (interpreter->Invoke() != kTfLiteOk) { + NSLog(@"Failed to invoke!"); + callback(@[@"Failed to invoke!"]); + } + + NSData* output = fetchArgmax(label_colors, output_type); + 
NSString* base64String = [output base64EncodedStringWithOptions:0]; + callback(@[[NSNull null], base64String]); +} + +NSArray* part_names = @[ + @"nose", @"leftEye", @"rightEye", @"leftEar", @"rightEar", @"leftShoulder", + @"rightShoulder", @"leftElbow", @"rightElbow", @"leftWrist", @"rightWrist", + @"leftHip", @"rightHip", @"leftKnee", @"rightKnee", @"leftAnkle", @"rightAnkle" + ]; + +NSArray* pose_chain = @[ + @[@"nose", @"leftEye"], @[@"leftEye", @"leftEar"], @[@"nose", @"rightEye"], + @[@"rightEye", @"rightEar"], @[@"nose", @"leftShoulder"], + @[@"leftShoulder", @"leftElbow"], @[@"leftElbow", @"leftWrist"], + @[@"leftShoulder", @"leftHip"], @[@"leftHip", @"leftKnee"], + @[@"leftKnee", @"leftAnkle"], @[@"nose", @"rightShoulder"], + @[@"rightShoulder", @"rightElbow"], @[@"rightElbow", @"rightWrist"], + @[@"rightShoulder", @"rightHip"], @[@"rightHip", @"rightKnee"], + @[@"rightKnee", @"rightAnkle"] + ]; + +NSMutableDictionary* parts_ids = [NSMutableDictionary dictionary]; +NSMutableArray* parent_to_child_edges = [NSMutableArray array]; +NSMutableArray* child_to_parent_edges = [NSMutableArray array]; +int local_maximum_radius = 1; +int output_stride = 16; +int height; +int width; +int num_keypoints; + +void initPoseNet() { + if ([parts_ids count] == 0) { + for (int i = 0; i < [part_names count]; ++i) + [parts_ids setValue:[NSNumber numberWithInt:i] forKey:part_names[i]]; + + for (int i = 0; i < [pose_chain count]; ++i) { + [parent_to_child_edges addObject:parts_ids[pose_chain[i][1]]]; + [child_to_parent_edges addObject:parts_ids[pose_chain[i][0]]]; + } + } +} + +bool scoreIsMaximumInLocalWindow(int keypoint_id, + float score, + int heatmap_y, + int heatmap_x, + int local_maximum_radius, + float* scores) { + bool local_maxium = true; + + int y_start = MAX(heatmap_y - local_maximum_radius, 0); + int y_end = MIN(heatmap_y + local_maximum_radius + 1, height); + for (int y_current = y_start; y_current < y_end; ++y_current) { + int x_start = MAX(heatmap_x - 
local_maximum_radius, 0); + int x_end = MIN(heatmap_x + local_maximum_radius + 1, width); + for (int x_current = x_start; x_current < x_end; ++x_current) { + if (sigmoid(scores[(y_current * width + x_current) * num_keypoints + keypoint_id]) > score) { + local_maxium = false; + break; + } + } + if (!local_maxium) { + break; + } + } + return local_maxium; +} + +typedef std::priority_queue, +std::vector>, +std::less>> PriorityQueue; + +PriorityQueue buildPartWithScoreQueue(float* scores, + float threshold, + int local_maximum_radius) { + PriorityQueue pq; + for (int heatmap_y = 0; heatmap_y < height; ++heatmap_y) { + for (int heatmap_x = 0; heatmap_x < width; ++heatmap_x) { + for (int keypoint_id = 0; keypoint_id < num_keypoints; ++keypoint_id) { + float score = sigmoid(scores[(heatmap_y * width + heatmap_x) * + num_keypoints + keypoint_id]); + if (score < threshold) continue; + + if (scoreIsMaximumInLocalWindow(keypoint_id, score, heatmap_y, heatmap_x, + local_maximum_radius, scores)) { + NSMutableDictionary* res = [NSMutableDictionary dictionary]; + [res setValue:[NSNumber numberWithFloat:score] forKey:@"score"]; + [res setValue:[NSNumber numberWithInt:heatmap_y] forKey:@"y"]; + [res setValue:[NSNumber numberWithInt:heatmap_x] forKey:@"x"]; + [res setValue:[NSNumber numberWithInt:keypoint_id] forKey:@"partId"]; + pq.push(std::pair(score, res)); + } + } + } + } + return pq; +} + +void getImageCoords(float* res, + NSMutableDictionary* keypoint, + float* offsets) { + int heatmap_y = [keypoint[@"y"] intValue]; + int heatmap_x = [keypoint[@"x"] intValue]; + int keypoint_id = [keypoint[@"partId"] intValue]; + + int offset = (heatmap_y * width + heatmap_x) * num_keypoints * 2 + keypoint_id; + float offset_y = offsets[offset]; + float offset_x = offsets[offset + num_keypoints]; + res[0] = heatmap_y * output_stride + offset_y; + res[1] = heatmap_x * output_stride + offset_x; +} + + +bool withinNmsRadiusOfCorrespondingPoint(NSMutableArray* poses, + float squared_nms_radius, + 
float y, + float x, + int keypoint_id, + int input_size) { + for (NSMutableDictionary* pose in poses) { + NSMutableDictionary* keypoints = pose[@"keypoints"]; + NSMutableDictionary* correspondingKeypoint = keypoints[[NSNumber numberWithInt:keypoint_id]]; + float _x = [correspondingKeypoint[@"x"] floatValue] * input_size - x; + float _y = [correspondingKeypoint[@"y"] floatValue] * input_size - y; + float squaredDistance = _x * _x + _y * _y; + if (squaredDistance <= squared_nms_radius) + return true; + } + return false; +} + +void getStridedIndexNearPoint(int* res, float _y, float _x) { + int y_ = round(_y / output_stride); + int x_ = round(_x / output_stride); + int y = y_ < 0 ? 0 : y_ > height - 1 ? height - 1 : y_; + int x = x_ < 0 ? 0 : x_ > width - 1 ? width - 1 : x_; + res[0] = y; + res[1] = x; +} + +void getDisplacement(float* res, int edgeId, int* keypoint, float* displacements) { + int num_edges = (int)[parent_to_child_edges count]; + int y = keypoint[0]; + int x = keypoint[1]; + int offset = (y * width + x) * num_edges * 2 + edgeId; + res[0] = displacements[offset]; + res[1] = displacements[offset + num_edges]; +} + +float getInstanceScore(NSMutableDictionary* keypoints) { + float scores = 0; + for (NSMutableDictionary* keypoint in keypoints.allValues) + scores += [keypoint[@"score"] floatValue]; + return scores / num_keypoints; +} + +NSMutableDictionary* traverseToTargetKeypoint(int edge_id, + NSMutableDictionary* source_keypoint, + int target_keypoint_id, + float* scores, + float* offsets, + float* displacements, + int input_size) { + float source_keypoint_y = [source_keypoint[@"y"] floatValue] * input_size; + float source_keypoint_x = [source_keypoint[@"x"] floatValue] * input_size; + + int source_keypoint_indices[2]; + getStridedIndexNearPoint(source_keypoint_indices, source_keypoint_y, source_keypoint_x); + + float displacement[2]; + getDisplacement(displacement, edge_id, source_keypoint_indices, displacements); + + float displaced_point[2]; + 
displaced_point[0] = source_keypoint_y + displacement[0]; + displaced_point[1] = source_keypoint_x + displacement[1]; + + float* target_keypoint = displaced_point; + + int offset_refine_step = 2; + for (int i = 0; i < offset_refine_step; i++) { + int target_keypoint_indices[2]; + getStridedIndexNearPoint(target_keypoint_indices, target_keypoint[0], target_keypoint[1]); + + int target_keypoint_y = target_keypoint_indices[0]; + int target_keypoint_x = target_keypoint_indices[1]; + + int offset = (target_keypoint_y * width + target_keypoint_x) * num_keypoints * 2 + target_keypoint_id; + float offset_y = offsets[offset]; + float offset_x = offsets[offset + num_keypoints]; + + target_keypoint[0] = target_keypoint_y * output_stride + offset_y; + target_keypoint[1] = target_keypoint_x * output_stride + offset_x; + } + + int target_keypoint_indices[2]; + getStridedIndexNearPoint(target_keypoint_indices, target_keypoint[0], target_keypoint[1]); + + float score = sigmoid(scores[(target_keypoint_indices[0] * width + + target_keypoint_indices[1]) * num_keypoints + target_keypoint_id]); + + NSMutableDictionary* keypoint = [NSMutableDictionary dictionary]; + [keypoint setValue:[NSNumber numberWithFloat:score] forKey:@"score"]; + [keypoint setValue:[NSNumber numberWithFloat:target_keypoint[0] / input_size] forKey:@"y"]; + [keypoint setValue:[NSNumber numberWithFloat:target_keypoint[1] / input_size] forKey:@"x"]; + [keypoint setValue:part_names[target_keypoint_id] forKey:@"part"]; + return keypoint; +} + +NSMutableArray* parsePoseNet(int num_results, float threshold, int nms_radius, int input_size) { + initPoseNet(); + + assert(interpreter->outputs().size() == 4); + TfLiteTensor* scores_tensor = interpreter->tensor(interpreter->outputs()[0]); + height = scores_tensor->dims->data[1]; + width = scores_tensor->dims->data[2]; + num_keypoints = scores_tensor->dims->data[3]; + + float* scores = interpreter->typed_output_tensor(0); + float* offsets = interpreter->typed_output_tensor(1); 
+ float* displacements_fwd = interpreter->typed_output_tensor(2); + float* displacements_bwd = interpreter->typed_output_tensor(3); + + PriorityQueue pq = buildPartWithScoreQueue(scores, threshold, local_maximum_radius); + + int num_edges = (int)[parent_to_child_edges count]; + int sqared_nms_radius = nms_radius * nms_radius; + + NSMutableArray* results = [NSMutableArray array]; + + while([results count] < num_results && !pq.empty()) { + NSMutableDictionary* root = pq.top().second; + pq.pop(); + + float root_point[2]; + getImageCoords(root_point, root, offsets); + + if (withinNmsRadiusOfCorrespondingPoint(results, sqared_nms_radius, root_point[0], root_point[1], + [root[@"partId"] intValue], input_size)) + continue; + + NSMutableDictionary* keypoint = [NSMutableDictionary dictionary]; + [keypoint setValue:[NSNumber numberWithFloat:[root[@"score"] floatValue]] forKey:@"score"]; + [keypoint setValue:[NSNumber numberWithFloat:root_point[0] / input_size] forKey:@"y"]; + [keypoint setValue:[NSNumber numberWithFloat:root_point[1] / input_size] forKey:@"x"]; + [keypoint setValue:part_names[[root[@"partId"] intValue]] forKey:@"part"]; + + NSMutableDictionary* keypoints = [NSMutableDictionary dictionary]; + [keypoints setObject:keypoint forKey:root[@"partId"]]; + + for (int edge = num_edges - 1; edge >= 0; --edge) { + int source_keypoint_id = [parent_to_child_edges[edge] intValue]; + int target_keypoint_id = [child_to_parent_edges[edge] intValue]; + if (keypoints[[NSNumber numberWithInt:source_keypoint_id]] && + !(keypoints[[NSNumber numberWithInt:target_keypoint_id]])) { + keypoint = traverseToTargetKeypoint(edge, keypoints[[NSNumber numberWithInt:source_keypoint_id]], + target_keypoint_id, scores, offsets, displacements_bwd, input_size); + [keypoints setObject:keypoint forKey:[NSNumber numberWithInt:target_keypoint_id]]; + } + } + + for (int edge = 0; edge < num_edges; ++edge) { + int source_keypoint_id = [child_to_parent_edges[edge] intValue]; + int target_keypoint_id = 
[parent_to_child_edges[edge] intValue]; + if (keypoints[[NSNumber numberWithInt:source_keypoint_id]] && + !(keypoints[[NSNumber numberWithInt:target_keypoint_id]])) { + keypoint = traverseToTargetKeypoint(edge, keypoints[[NSNumber numberWithInt:source_keypoint_id]], + target_keypoint_id, scores, offsets, displacements_fwd, input_size); + [keypoints setObject:keypoint forKey:[NSNumber numberWithInt:target_keypoint_id]]; + } + } + + NSMutableDictionary* result = [NSMutableDictionary dictionary]; + [result setObject:keypoints forKey:@"keypoints"]; + [result setValue:[NSNumber numberWithFloat:getInstanceScore(keypoints)] forKey:@"score"]; + [results addObject:result]; + } + + return results; +} + +RCT_EXPORT_METHOD(runPoseNetOnImage:(NSString*)image_path + mean:(float)input_mean + std:(float)input_std + numResults:(int)num_results + threshold:(float)threshold + nmsRadius:(int)nms_radius + callback:(RCTResponseSenderBlock)callback) { + if (!interpreter) { + NSLog(@"Failed to construct interpreter."); + callback(@[@"Failed to construct interpreter."]); + } + + image_path = [image_path stringByReplacingOccurrencesOfString:@"file://" withString:@""]; + int input_size; + feedInputTensorImage(image_path, input_mean, input_std, &input_size); + + if (interpreter->Invoke() != kTfLiteOk) { + NSLog(@"Failed to invoke!"); + callback(@[@"Failed to invoke!"]); + } + + NSMutableArray* output = parsePoseNet(num_results, threshold, nms_radius, input_size); + callback(@[[NSNull null], output]); +} + RCT_EXPORT_METHOD(close) { interpreter = NULL; diff --git a/package.json b/package.json index e0e186c..abed7d8 100644 --- a/package.json +++ b/package.json @@ -1,6 +1,6 @@ { "name": "tflite-react-native", - "version": "0.0.4", + "version": "0.0.5", "description": "A react native library for accessing TensorFlow Lite API. Supports Classification and Object Detection on both iOS and Android.", "main": "index.js", "scripts": {