diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml
index 9e38acb..ea2c4e3 100644
--- a/.github/workflows/ci.yml
+++ b/.github/workflows/ci.yml
@@ -58,9 +58,6 @@ jobs:
       - name: Setup
         uses: ./.github/actions/setup

-      - name: Get submodule
-        run: git submodule update --init --recursive
-
       - name: Cache turborepo for Android
        uses: actions/cache@v3
        with:
diff --git a/.gitmodules b/.gitmodules
index 5e9e16c..e69de29 100644
--- a/.gitmodules
+++ b/.gitmodules
@@ -1,3 +0,0 @@
-[submodule "libyuv"]
-	path = libyuv
-	url = https://android.googlesource.com/platform/external/libyuv/
diff --git a/android/build.gradle b/android/build.gradle
index 18bd321..c1a3d05 100644
--- a/android/build.gradle
+++ b/android/build.gradle
@@ -60,7 +60,7 @@ android {
   buildFeatures {
     prefab true
-  }
+  }

   ndkVersion getExtOrDefault("ndkVersion")
   compileSdkVersion getExtOrIntegerDefault("compileSdkVersion")
@@ -69,19 +69,6 @@ android {
     minSdkVersion getExtOrIntegerDefault("minSdkVersion")
     targetSdkVersion getExtOrIntegerDefault("targetSdkVersion")
-
-    externalNativeBuild {
-      cmake {
-        cppFlags "-O2 -frtti -fexceptions -Wall -fstack-protector-all"
-        abiFilters (*reactNativeArchitectures())
-        arguments "-DANDROID_STL=c++_shared"
-      }
-    }
-  }
-
-  externalNativeBuild {
-    cmake {
-      path "CMakeLists.txt"
-    }
   }

   buildTypes {
@@ -99,24 +86,6 @@ android {
     targetCompatibility JavaVersion.VERSION_1_8
   }

-  // packagingOptions {
-  //   excludes = [
-  //     "**/libc++_shared.so",
-  //     "**/libfbjni.so",
-  //     "**/libjsi.so",
-  //     "**/libfolly_json.so",
-  //     "**/libfolly_runtime.so",
-  //     "**/libglog.so",
-  //     "**/libhermes.so",
-  //     "**/libhermes-executor-debug.so",
-  //     "**/libhermes_executor.so",
-  //     "**/libreactnativejni.so",
-  //     "**/libturbomodulejsijni.so",
-  //     "**/libreact_nativemodule_core.so",
-  //     "**/libjscexecutor.so"
-  //   ]
-  // }
-
 }

 repositories {
@@ -134,5 +103,6 @@ dependencies {
   implementation "org.jetbrains.kotlin:kotlin-stdlib:$kotlin_version"
   implementation project(':react-native-vision-camera')
   implementation 'com.google.mediapipe:tasks-vision:0.10.2'
+  implementation 'androidx.camera:camera-core:1.3.3'
 }
diff --git a/android/src/main/cpp/JImage.cpp b/android/src/main/cpp/JImage.cpp
deleted file mode 100644
index 37532e6..0000000
--- a/android/src/main/cpp/JImage.cpp
+++ /dev/null
@@ -1,33 +0,0 @@
-//
-// Created by Marc Rousavy on 25.01.24.
-//
-
-#include "JImage.h"
-
-#include <fbjni/fbjni.h>
-#include <jni.h>
-
-namespace resizeconvert {
-
-using namespace facebook;
-using namespace jni;
-
-int JImage::getWidth() const {
-  auto method = getClass()->getMethod<jint()>("getWidth");
-  auto result = method(self());
-  return result;
-}
-
-int JImage::getHeight() const {
-  auto method = getClass()->getMethod<jint()>("getHeight");
-  auto result = method(self());
-  return result;
-}
-
-jni::local_ref<JArrayClass<JImagePlane>> JImage::getPlanes() const {
-  auto method = getClass()->getMethod<JArrayClass<JImagePlane>()>("getPlanes");
-  auto result = method(self());
-  return result;
-}
-
-} // namespace resizeconvert
\ No newline at end of file
diff --git a/android/src/main/cpp/JImage.h b/android/src/main/cpp/JImage.h
deleted file mode 100644
index b9acab0..0000000
--- a/android/src/main/cpp/JImage.h
+++ /dev/null
@@ -1,27 +0,0 @@
-//
-// Created by Marc Rousavy on 25.01.24.
-//
-
-#pragma once
-
-#include "JImagePlane.h"
-#include <fbjni/fbjni.h>
-#include <jni.h>
-
-namespace resizeconvert
-{
-
-  using namespace facebook;
-  using namespace jni;
-
-  struct JImage : public JavaClass<JImage>
-  {
-    static constexpr auto kJavaDescriptor = "Landroid/media/Image;";
-
-  public:
-    int getWidth() const;
-    int getHeight() const;
-    jni::local_ref<JArrayClass<JImagePlane>> getPlanes() const;
-  };
-
-} // namespace resizeconvert
diff --git a/android/src/main/cpp/JImagePlane.cpp b/android/src/main/cpp/JImagePlane.cpp
deleted file mode 100644
index b260dd6..0000000
--- a/android/src/main/cpp/JImagePlane.cpp
+++ /dev/null
@@ -1,34 +0,0 @@
-//
-// Created by Marc Rousavy on 25.01.24.
-//
-
-#include "JImagePlane.h"
-
-namespace resizeconvert
-{
-
-  using namespace facebook;
-  using namespace jni;
-
-  int JImagePlane::getPixelStride() const
-  {
-    auto method = getClass()->getMethod<jint()>("getPixelStride");
-    auto result = method(self());
-    return result;
-  }
-
-  int JImagePlane::getRowStride() const
-  {
-    auto method = getClass()->getMethod<jint()>("getRowStride");
-    auto result = method(self());
-    return result;
-  }
-
-  jni::local_ref<JByteBuffer> JImagePlane::getBuffer() const
-  {
-    auto method = getClass()->getMethod<JByteBuffer()>("getBuffer");
-    auto result = method(self());
-    return result;
-  }
-
-} // namespace resizeconvert
\ No newline at end of file
diff --git a/android/src/main/cpp/JImagePlane.h b/android/src/main/cpp/JImagePlane.h
deleted file mode 100644
index cc7b70e..0000000
--- a/android/src/main/cpp/JImagePlane.h
+++ /dev/null
@@ -1,27 +0,0 @@
-//
-// Created by Marc Rousavy on 25.01.24.
-//
-
-#pragma once
-
-#include <fbjni/ByteBuffer.h>
-#include <fbjni/fbjni.h>
-#include <jni.h>
-
-namespace resizeconvert
-{
-
-  using namespace facebook;
-  using namespace jni;
-
-  struct JImagePlane : public JavaClass<JImagePlane>
-  {
-    static constexpr auto kJavaDescriptor = "Landroid/media/Image$Plane;";
-
-  public:
-    jni::local_ref<JByteBuffer> getBuffer() const;
-    int getPixelStride() const;
-    int getRowStride() const;
-  };
-
-} // namespace resizeconvert
diff --git a/android/src/main/cpp/ResizeConvert.cpp b/android/src/main/cpp/ResizeConvert.cpp
deleted file mode 100644
index 02c802e..0000000
--- a/android/src/main/cpp/ResizeConvert.cpp
+++ /dev/null
@@ -1,427 +0,0 @@
-//
-// Created by Marc Rousavy on 25.01.24
-//
-
-#include "ResizeConvert.h"
-#include "libyuv.h"
-#include <android/log.h>
-#include <media/NdkImage.h>
-#include <string>
-#include <utility>
-
-namespace resizeconvert {
-
-  using namespace facebook;
-  using namespace facebook::jni;
-
-  void ResizeConvert::registerNatives() {
-    registerHybrid({
-        makeNativeMethod("initHybrid", ResizeConvert::initHybrid),
-        makeNativeMethod("resize", ResizeConvert::resize),
-    });
-  }
-
-  ResizeConvert::ResizeConvert(const jni::alias_ref<jhybridobject>& javaThis) {
-    _javaThis = jni::make_global(javaThis);
-  }
-
-  int getChannelCount(PixelFormat pixelFormat) {
-    switch (pixelFormat) {
-      case RGB:
-      case BGR:
-        return 3;
-      case ARGB:
-      case RGBA:
-      case BGRA:
-      case ABGR:
-        return 4;
-    }
-  }
-
-  int getBytesPerChannel(DataType type) {
-    switch (type) {
-      case UINT8:
-        return sizeof(uint8_t);
-      case FLOAT32:
-        return sizeof(float_t);
-    }
-  }
-
-  int getBytesPerPixel(PixelFormat pixelFormat, DataType type) {
-    return getChannelCount(pixelFormat) * getBytesPerChannel(type);
-  }
-
-  int FrameBuffer::bytesPerRow() {
-    size_t bytesPerPixel = getBytesPerPixel(pixelFormat, dataType);
-    return width * bytesPerPixel;
-  }
-
-  uint8_t* FrameBuffer::data() {
-    return buffer->getDirectBytes();
-  }
-
-  global_ref<JByteBuffer> ResizeConvert::allocateBuffer(size_t size, std::string debugName) {
-    __android_log_print(ANDROID_LOG_INFO, TAG, "Allocating %s Buffer with size %zu...",
-                        debugName.c_str(), size);
-    local_ref<JByteBuffer> buffer = JByteBuffer::allocateDirect(size);
-    buffer->order(JByteOrder::nativeOrder());
-    return make_global(buffer);
-  }
-
-  FrameBuffer ResizeConvert::imageToFrameBuffer(alias_ref<JImage> image) {
-    __android_log_write(ANDROID_LOG_INFO, TAG, "Converting YUV 4:2:0 -> ARGB 8888...");
-
-    jni::local_ref<JArrayClass<JImagePlane>> planes = image->getPlanes();
-
-    jni::local_ref<JImagePlane> yPlane = planes->getElement(0);
-    jni::local_ref<JByteBuffer> yBuffer = yPlane->getBuffer();
-    jni::local_ref<JImagePlane> uPlane = planes->getElement(1);
-    jni::local_ref<JByteBuffer> uBuffer = uPlane->getBuffer();
-    jni::local_ref<JImagePlane> vPlane = planes->getElement(2);
-    jni::local_ref<JByteBuffer> vBuffer = vPlane->getBuffer();
-
-    size_t uvPixelStride = uPlane->getPixelStride();
-    if (uPlane->getPixelStride() != vPlane->getPixelStride()) {
-      throw std::runtime_error(
-          "U and V planes do not have the same pixel stride! Are you sure this is a 4:2:0 YUV format?");
-    }
-
-    int width = image->getWidth();
-    int height = image->getHeight();
-
-    size_t channels = getChannelCount(PixelFormat::ARGB);
-    size_t channelSize = getBytesPerChannel(DataType::UINT8);
-    size_t argbSize = width * height * channels * channelSize;
-    if (_argbBuffer == nullptr || _argbBuffer->getDirectSize() != argbSize) {
-      _argbBuffer = allocateBuffer(argbSize, "_argbBuffer");
-    }
-    FrameBuffer destination = {
-        .width = width,
-        .height = height,
-        .pixelFormat = PixelFormat::ARGB,
-        .dataType = DataType::UINT8,
-        .buffer = _argbBuffer,
-    };
-
-    // 1. Convert from YUV -> ARGB
-    int status = libyuv::Android420ToARGB(yBuffer->getDirectBytes(), yPlane->getRowStride(),
-                                          uBuffer->getDirectBytes(), uPlane->getRowStride(),
-                                          vBuffer->getDirectBytes(), vPlane->getRowStride(),
-                                          uvPixelStride, destination.data(),
-                                          width * channels * channelSize, width, height);
-
-    if (status != 0) {
-      throw std::runtime_error(
-          "Failed to convert YUV 4:2:0 to ARGB! Error: " + std::to_string(status));
-    }
-
-    return destination;
-  }
-
-  std::string rectToString(int x, int y, int width, int height) {
-    return std::to_string(x) + ", " + std::to_string(y) + " @ " + std::to_string(width) + "x" +
-           std::to_string(height);
-  }
-
-  FrameBuffer ResizeConvert::cropARGBBuffer(resizeconvert::FrameBuffer frameBuffer, int x, int y,
-                                            int width, int height) {
-    if (width == frameBuffer.width && height == frameBuffer.height && x == 0 && y == 0) {
-      // already in correct size.
-      return frameBuffer;
-    }
-
-    auto rectString = rectToString(0, 0, frameBuffer.width, frameBuffer.height);
-    auto targetString = rectToString(x, y, width, height);
-    __android_log_print(ANDROID_LOG_INFO, TAG, "Cropping [%s] ARGB buffer to [%s]...",
-                        rectString.c_str(), targetString.c_str());
-
-    size_t channels = getChannelCount(PixelFormat::ARGB);
-    size_t channelSize = getBytesPerChannel(DataType::UINT8);
-    size_t argbSize = width * height * channels * channelSize;
-    if (_cropBuffer == nullptr || _cropBuffer->getDirectSize() != argbSize) {
-      _cropBuffer = allocateBuffer(argbSize, "_cropBuffer");
-    }
-    FrameBuffer destination = {
-        .width = width,
-        .height = height,
-        .pixelFormat = PixelFormat::ARGB,
-        .dataType = DataType::UINT8,
-        .buffer = _cropBuffer,
-    };
-
-    int status = libyuv::ConvertToARGB(frameBuffer.data(),
-                                       frameBuffer.height * frameBuffer.bytesPerRow(),
-                                       destination.data(), destination.bytesPerRow(), x, y,
-                                       frameBuffer.width, frameBuffer.height, width, height,
-                                       libyuv::kRotate0, libyuv::FOURCC_ARGB);
-    if (status != 0) {
-      throw std::runtime_error("Failed to crop ARGB Buffer! Status: " + std::to_string(status));
-    }
-
-    return destination;
-  }
-
-  FrameBuffer ResizeConvert::mirrorARGBBuffer(FrameBuffer frameBuffer, bool mirror) {
-    if (!mirror) {
-      return frameBuffer;
-    }
-
-    __android_log_print(ANDROID_LOG_INFO, TAG, "Mirroring ARGB buffer...");
-
-    size_t channels = getChannelCount(PixelFormat::ARGB);
-    size_t channelSize = getBytesPerChannel(DataType::UINT8);
-    size_t argbSize = frameBuffer.width * frameBuffer.height * channels * channelSize;
-    if (_mirrorBuffer == nullptr || _mirrorBuffer->getDirectSize() != argbSize) {
-      _mirrorBuffer = allocateBuffer(argbSize, "_mirrorBuffer");
-    }
-    FrameBuffer destination = {
-        .width = frameBuffer.width,
-        .height = frameBuffer.height,
-        .pixelFormat = PixelFormat::ARGB,
-        .dataType = DataType::UINT8,
-        .buffer = _mirrorBuffer,
-    };
-
-    int status = libyuv::ARGBMirror(frameBuffer.data(), frameBuffer.bytesPerRow(),
-                                    destination.data(), destination.bytesPerRow(),
-                                    frameBuffer.width, frameBuffer.height);
-    if (status != 0) {
-      throw std::runtime_error("Failed to mirror ARGB Buffer! Status: " + std::to_string(status));
-    }
-
-    return destination;
-  }
-
-  FrameBuffer ResizeConvert::rotateARGBBuffer(FrameBuffer frameBuffer, int rotation) {
-    if (rotation == 0) {
-      return frameBuffer;
-    }
-
-    int rotatedWidth = frameBuffer.width;
-    int rotatedHeight = frameBuffer.height;
-    if (rotation == 90 || rotation == 270) {
-      std::swap(rotatedWidth, rotatedHeight);
-    }
-
-    size_t channels = getChannelCount(PixelFormat::ARGB);
-    size_t channelSize = getBytesPerChannel(DataType::UINT8);
-    size_t destinationStride = rotation == 90 || rotation == 270
-                                   ? rotatedWidth * channels * channelSize
-                                   : frameBuffer.bytesPerRow();
-    size_t argbSize = rotatedWidth * rotatedHeight * channels * channelSize;
-
-    if (_rotatedBuffer == nullptr || _rotatedBuffer->getDirectSize() != argbSize) {
-      _rotatedBuffer = allocateBuffer(argbSize, "_rotatedBuffer");
-    }
-
-    FrameBuffer destination = {
-        .width = rotatedWidth,
-        .height = rotatedHeight,
-        .pixelFormat = PixelFormat::ARGB,
-        .dataType = DataType::UINT8,
-        .buffer = _rotatedBuffer,
-    };
-
-    int status = libyuv::ARGBRotate(frameBuffer.data(), frameBuffer.bytesPerRow(),
-                                    destination.data(), destinationStride, frameBuffer.width,
-                                    frameBuffer.height,
-                                    static_cast<libyuv::RotationMode>(rotation));
-    if (status != 0) {
-      throw std::runtime_error("Failed to rotate ARGB Buffer! Status: " + std::to_string(status));
-    }
-
-    return destination;
-  }
-
-  FrameBuffer ResizeConvert::scaleARGBBuffer(resizeconvert::FrameBuffer frameBuffer, int width,
-                                             int height) {
-    if (width == frameBuffer.width && height == frameBuffer.height) {
-      // already in correct size.
-      return frameBuffer;
-    }
-    auto rectString = rectToString(0, 0, frameBuffer.width, frameBuffer.height);
-    auto targetString = rectToString(0, 0, width, height);
-    __android_log_print(ANDROID_LOG_INFO, TAG, "Scaling [%s] ARGB buffer to [%s]...",
-                        rectString.c_str(), targetString.c_str());
-
-    size_t channels = getChannelCount(PixelFormat::ARGB);
-    size_t channelSize = getBytesPerChannel(DataType::UINT8);
-    size_t argbSize = width * height * channels * channelSize;
-    if (_scaleBuffer == nullptr || _scaleBuffer->getDirectSize() != argbSize) {
-      _scaleBuffer = allocateBuffer(argbSize, "_scaleBuffer");
-    }
-    FrameBuffer destination = {
-        .width = width,
-        .height = height,
-        .pixelFormat = PixelFormat::ARGB,
-        .dataType = DataType::UINT8,
-        .buffer = _scaleBuffer,
-    };
-
-    int status = libyuv::ARGBScale(frameBuffer.data(), frameBuffer.bytesPerRow(),
-                                   frameBuffer.width, frameBuffer.height, destination.data(),
-                                   destination.bytesPerRow(), width, height,
-                                   libyuv::FilterMode::kFilterBilinear);
-    if (status != 0) {
-      throw std::runtime_error("Failed to scale ARGB Buffer! Status: " + std::to_string(status));
-    }
-
-    return destination;
-  }
-
-  FrameBuffer ResizeConvert::convertARGBBufferTo(FrameBuffer frameBuffer, PixelFormat pixelFormat) {
-    if (frameBuffer.pixelFormat == pixelFormat) {
-      // Already in the correct format.
-      return frameBuffer;
-    }
-
-    __android_log_print(ANDROID_LOG_INFO, TAG, "Converting ARGB Buffer to Pixel Format %zu...",
-                        pixelFormat);
-
-    size_t bytesPerPixel = getBytesPerPixel(pixelFormat, frameBuffer.dataType);
-    size_t targetBufferSize = frameBuffer.width * frameBuffer.height * bytesPerPixel;
-    if (_customFormatBuffer == nullptr ||
-        _customFormatBuffer->getDirectSize() != targetBufferSize) {
-      _customFormatBuffer = allocateBuffer(targetBufferSize, "_customFormatBuffer");
-    }
-    FrameBuffer destination = {
-        .width = frameBuffer.width,
-        .height = frameBuffer.height,
-        .pixelFormat = pixelFormat,
-        .dataType = frameBuffer.dataType,
-        .buffer = _customFormatBuffer,
-    };
-
-    int error = 0;
-    switch (pixelFormat) {
-      case PixelFormat::ARGB:
-        // do nothing, we're already in ARGB
-        return frameBuffer;
-      case RGB:
-        // RAW is [R, G, B] in libyuv memory layout
-        error = libyuv::ARGBToRAW(frameBuffer.data(), frameBuffer.bytesPerRow(),
-                                  destination.data(), destination.bytesPerRow(),
-                                  destination.width, destination.height);
-        break;
-      case BGR:
-        // RGB24 is [B, G, R] in libyuv memory layout
-        error = libyuv::ARGBToRGB24(frameBuffer.data(), frameBuffer.bytesPerRow(),
-                                    destination.data(), destination.bytesPerRow(),
-                                    destination.width, destination.height);
-        break;
-      case RGBA:
-        error = libyuv::ARGBToRGBA(frameBuffer.data(), frameBuffer.bytesPerRow(),
-                                   destination.data(), destination.bytesPerRow(),
-                                   destination.width, destination.height);
-        break;
-      case BGRA:
-        error = libyuv::ARGBToBGRA(frameBuffer.data(), frameBuffer.bytesPerRow(),
-                                   destination.data(), destination.bytesPerRow(),
-                                   destination.width, destination.height);
-        break;
-      case ABGR:
-        error = libyuv::ARGBToABGR(frameBuffer.data(), frameBuffer.bytesPerRow(),
-                                   destination.data(), destination.bytesPerRow(),
-                                   destination.width, destination.height);
-        break;
-    }
-
-    if (error != 0) {
-      throw std::runtime_error("Failed to convert ARGB Buffer to target Pixel Format! Error: " +
-                               std::to_string(error));
-    }
-
-    return destination;
-  }
-
-  FrameBuffer ResizeConvert::convertBufferToDataType(FrameBuffer frameBuffer, DataType dataType) {
-    if (frameBuffer.dataType == dataType) {
-      // Already in correct data-type
-      return frameBuffer;
-    }
-
-    __android_log_print(ANDROID_LOG_INFO, TAG, "Converting ARGB Buffer to Data Type %zu...",
-                        dataType);
-
-    size_t targetSize = frameBuffer.width * frameBuffer.height *
-                        getBytesPerPixel(frameBuffer.pixelFormat, dataType);
-    if (_customTypeBuffer == nullptr || _customTypeBuffer->getDirectSize() != targetSize) {
-      _customTypeBuffer = allocateBuffer(targetSize, "_customTypeBuffer");
-    }
-    size_t size = frameBuffer.buffer->getDirectSize();
-    FrameBuffer destination = {
-        .width = frameBuffer.width,
-        .height = frameBuffer.height,
-        .pixelFormat = frameBuffer.pixelFormat,
-        .dataType = dataType,
-        .buffer = _customTypeBuffer,
-    };
-
-    int status = 0;
-    switch (dataType) {
-      case UINT8:
-        // it's already uint8
-        return frameBuffer;
-      case FLOAT32: {
-        float* floatData = reinterpret_cast<float*>(destination.data());
-        status = libyuv::ByteToFloat(frameBuffer.data(), floatData, 1.0f / 255.0f, size);
-        break;
-      }
-    }
-
-    if (status != 0) {
-      throw std::runtime_error("Failed to convert Buffer to target Data Type! Error: " +
-                               std::to_string(status));
-    }
-
-    return destination;
-  }
-
-  jni::global_ref<JByteBuffer> ResizeConvert::resize(jni::alias_ref<JImage> image, int cropX,
-                                                     int cropY, int cropWidth, int cropHeight,
-                                                     int scaleWidth, int scaleHeight,
-                                                     int rotationOrdinal, bool mirror,
-                                                     int /* PixelFormat */ pixelFormatOrdinal,
-                                                     int /* DataType */ dataTypeOrdinal) {
-    PixelFormat pixelFormat = static_cast<PixelFormat>(pixelFormatOrdinal);
-    DataType dataType = static_cast<DataType>(dataTypeOrdinal);
-
-    // 1. Convert from YUV -> ARGB
-    FrameBuffer result = imageToFrameBuffer(image);
-
-    // 2. Crop ARGB
-    result = cropARGBBuffer(result, cropX, cropY, cropWidth, cropHeight);
-
-    // 3. Scale ARGB
-    result = scaleARGBBuffer(result, scaleWidth, scaleHeight);
-
-    // 4. Rotate ARGB
-    result = rotateARGBBuffer(result, rotationOrdinal);
-
-    // 5. Mirror ARGB if needed
-    result = mirrorARGBBuffer(result, mirror);
-
-    // 6. Convert from ARGB -> ????
-    result = convertARGBBufferTo(result, pixelFormat);
-
-    // 7. Convert from data type to other data type
-    result = convertBufferToDataType(result, dataType);
-
-    return result.buffer;
-  }
-
-  jni::local_ref<ResizeConvert::jhybriddata>
-  ResizeConvert::initHybrid(jni::alias_ref<jhybridobject> javaThis) {
-    return makeCxxInstance(javaThis);
-  }
-
-} // namespace resizeconvert
diff --git a/android/src/main/cpp/ResizeConvert.h b/android/src/main/cpp/ResizeConvert.h
deleted file mode 100644
index dc658a1..0000000
--- a/android/src/main/cpp/ResizeConvert.h
+++ /dev/null
@@ -1,89 +0,0 @@
-//
-// Created by Marc Rousavy on 25.01.24
-//
-
-#pragma once
-
-#include <fbjni/ByteBuffer.h>
-#include <fbjni/fbjni.h>
-#include <jni.h>
-#include <memory>
-#include <string>
-
-#include "JImage.h"
-
-namespace resizeconvert {
-
-  using namespace facebook;
-  using namespace jni;
-
-  enum PixelFormat {
-    RGB, BGR, ARGB, RGBA, BGRA, ABGR
-  };
-
-  enum DataType {
-    UINT8, FLOAT32
-  };
-
-  struct FrameBuffer {
-    int width;
-    int height;
-    PixelFormat pixelFormat;
-    DataType dataType;
-    global_ref<JByteBuffer> buffer;
-
-    uint8_t* data();
-
-    int bytesPerRow();
-  };
-
-  struct ResizeConvert : public HybridClass<ResizeConvert> {
-  public:
-    static auto constexpr kJavaDescriptor = "Lcom/reactnativemediapipe/shared/ResizeConvert;";
-    static void registerNatives();
-
-  private:
-    explicit ResizeConvert(const alias_ref<jhybridobject>& javaThis);
-
-    global_ref<JByteBuffer>
-    resize(alias_ref<JImage> image, int cropX, int cropY, int cropWidth, int cropHeight,
-           int scaleWidth,
-           int scaleHeight, int rotation, bool mirror, int /* PixelFormat */ pixelFormat,
-           int /* DataType */ dataType);
-
-    FrameBuffer imageToFrameBuffer(alias_ref<JImage> image);
-
-    FrameBuffer cropARGBBuffer(FrameBuffer frameBuffer, int x, int y, int width, int height);
-
-    FrameBuffer scaleARGBBuffer(FrameBuffer frameBuffer, int width, int height);
-
-    FrameBuffer convertARGBBufferTo(FrameBuffer frameBuffer, PixelFormat toFormat);
-
-    FrameBuffer convertBufferToDataType(FrameBuffer frameBuffer, DataType dataType);
-
-    FrameBuffer rotateARGBBuffer(FrameBuffer frameBuffer, int rotation);
-
-    FrameBuffer mirrorARGBBuffer(FrameBuffer frameBuffer, bool mirror);
-
-    global_ref<JByteBuffer> allocateBuffer(size_t size, std::string debugName);
-
-  private:
-    static auto constexpr TAG = "ResizeConvert";
-    friend HybridBase;
-    global_ref<jhybridobject> _javaThis;
-    // YUV (?x?) -> ARGB (?x?)
-    global_ref<JByteBuffer> _argbBuffer;
-    // ARGB (?x?) -> ARGB (!x!)
-    global_ref<JByteBuffer> _cropBuffer;
-    global_ref<JByteBuffer> _scaleBuffer;
-    global_ref<JByteBuffer> _rotatedBuffer;
-    global_ref<JByteBuffer> _mirrorBuffer;
-    // ARGB (?x?) -> !!!! (?x?)
-    global_ref<JByteBuffer> _customFormatBuffer;
-    // Custom Data Type (e.g. float32)
-    global_ref<JByteBuffer> _customTypeBuffer;
-
-    static local_ref<jhybriddata> initHybrid(alias_ref<jhybridobject> javaThis);
-  };
-
-} // namespace resizeconvert
diff --git a/android/src/main/cpp/ResizeConvertLib.cpp b/android/src/main/cpp/ResizeConvertLib.cpp
deleted file mode 100644
index 816d0f0..0000000
--- a/android/src/main/cpp/ResizeConvertLib.cpp
+++ /dev/null
@@ -1,11 +0,0 @@
-//
-// Created by Marc Rousavy on 25.01.24
-//
-
-#include "ResizeConvert.h"
-#include <fbjni/fbjni.h>
-#include <jni.h>
-
-JNIEXPORT jint JNICALL JNI_OnLoad(JavaVM* vm, void*) {
-  return facebook::jni::initialize(vm, [] { resizeconvert::ResizeConvert::registerNatives(); });
-}
diff --git a/android/src/main/java/com/reactnativemediapipe/objectdetection/ConvertHelpers.kt b/android/src/main/java/com/reactnativemediapipe/objectdetection/ConvertHelpers.kt
index 2f2b50b..bfa3eb5 100644
--- a/android/src/main/java/com/reactnativemediapipe/objectdetection/ConvertHelpers.kt
+++ b/android/src/main/java/com/reactnativemediapipe/objectdetection/ConvertHelpers.kt
@@ -7,6 +7,7 @@ import com.facebook.react.bridge.WritableMap
 import com.facebook.react.bridge.WritableNativeMap
 import com.google.mediapipe.tasks.components.containers.Category
 import com.google.mediapipe.tasks.components.containers.Detection
+import com.mrousavy.camera.core.types.Orientation
 import java.util.Optional

 // Assuming simplified representations based on your descriptions
@@ -75,3 +76,13 @@ fun convertResultBundleToWritableMap(resultBundle: ObjectDetectorHelper.ResultBu
   map.putDouble("inferenceTime", resultBundle.inferenceTime.toDouble())
   return map
 }
+
+
+fun orientationToDegrees(orientation: Orientation): Int =
+    when (orientation) {
+      Orientation.PORTRAIT -> 0
+      Orientation.LANDSCAPE_LEFT -> 90
+      Orientation.PORTRAIT_UPSIDE_DOWN -> 180
+      Orientation.LANDSCAPE_RIGHT -> -90
+    }
+
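The `orientationToDegrees` mapping above is consumed in two places in this PR: the frame-processor plugin rotates the camera bitmap by it, and the detector helper passes it to MediaPipe as rotation metadata. A minimal sketch of the bitmap path (it mirrors `rotateBitmap` in the plugin below; the function name here is illustrative):

```kotlin
import android.graphics.Bitmap
import android.graphics.Matrix
import com.mrousavy.camera.core.types.Orientation

// Rotate a camera frame upright using the Orientation -> degrees mapping above.
fun uprightBitmap(source: Bitmap, orientation: Orientation): Bitmap {
  val matrix = Matrix()
  matrix.postRotate(orientationToDegrees(orientation).toFloat())
  return Bitmap.createBitmap(source, 0, 0, source.width, source.height, matrix, true)
}
```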
diff --git a/android/src/main/java/com/reactnativemediapipe/objectdetection/ObjectDetectionFrameProcessorPlugin.kt b/android/src/main/java/com/reactnativemediapipe/objectdetection/ObjectDetectionFrameProcessorPlugin.kt
index 74c4c1e..96eb976 100644
--- a/android/src/main/java/com/reactnativemediapipe/objectdetection/ObjectDetectionFrameProcessorPlugin.kt
+++ b/android/src/main/java/com/reactnativemediapipe/objectdetection/ObjectDetectionFrameProcessorPlugin.kt
@@ -1,130 +1,50 @@
 package com.reactnativemediapipe.objectdetection

-import android.graphics.ImageFormat
-import android.util.Log
-import com.google.mediapipe.framework.image.ByteBufferImageBuilder
-import com.google.mediapipe.framework.image.MPImage
+import android.graphics.Bitmap
+import android.graphics.Matrix
+import androidx.camera.core.ImageProxy
+import com.google.mediapipe.framework.image.BitmapImageBuilder
+import com.mrousavy.camera.core.types.PixelFormat
 import com.mrousavy.camera.frameprocessors.Frame
 import com.mrousavy.camera.frameprocessors.FrameProcessorPlugin
-import com.reactnativemediapipe.shared.ResizeConvert

-class ObjectDetectionFrameProcessorPlugin() :
-    FrameProcessorPlugin() {
+class ObjectDetectionFrameProcessorPlugin() : FrameProcessorPlugin() {

   companion object {
     private const val TAG = "ObjectDetectionFrameProcessorPlugin"
   }
-  private val resizeConvert: ResizeConvert = ResizeConvert()

   override fun callback(frame: Frame, params: MutableMap<String, Any>?): Any? {
     val detectorHandle: Double = params!!["detectorHandle"] as Double
     val detector = ObjectDetectorMap.detectorMap[detectorHandle.toInt()] ?: return false

-    var cropWidth = frame.width
-    var cropHeight = frame.height
-    var cropX = 0
-    var cropY = 0
-    var scaleWidth = frame.width
-    var scaleHeight = frame.height
-
-    val rotationParam = params["rotation"]
-    val rotation: Rotation
-    if (rotationParam is String) {
-      rotation = Rotation.fromString(rotationParam)
-      Log.i(TAG, "Rotation: ${rotation.degrees}")
-    } else {
-      rotation = Rotation.Rotation0
-      Log.i(TAG, "Rotation not specified, defaulting to: ${rotation.degrees}")
-    }
-
-    val mirrorParam = params["mirror"]
-    val mirror: Boolean
-    if (mirrorParam is Boolean) {
-      mirror = mirrorParam
-      Log.i(TAG, "Mirror: $mirror")
-    } else {
-      mirror = false
-      Log.i(TAG, "Mirror not specified, defaulting to: $mirror")
+    // val mpImage = MediaImageBuilder(frame.image).build()
+    // detector.detectLivestreamFrame(mpImage,frame.orientation)
+    val bitmap = imageToBitmap(frame.imageProxy)
+    if (bitmap != null) {
+      val rotated = rotateBitmap(bitmap, orientationToDegrees(frame.orientation).toFloat())
+      val mpImage = BitmapImageBuilder(rotated).build()
+      detector.detectLivestreamFrame(mpImage, frame.orientation)
     }
+    return true
+  }

-    val scale = params["scale"]
-    if (scale != null) {
-      if (scale is Map<*, *>) {
-        val scaleWidthDouble = scale["width"] as? Double
-        val scaleHeightDouble = scale["height"] as? Double
-        if (scaleWidthDouble != null && scaleHeightDouble != null) {
-          scaleWidth = scaleWidthDouble.toInt()
-          scaleHeight = scaleHeightDouble.toInt()
-        } else {
-          throw Error("Failed to parse values in scale dictionary!")
-        }
-        Log.i(TAG, "Target scale: $scaleWidth x $scaleHeight")
-      } else if (scale is Double) {
-        scaleWidth = (scale * frame.width).toInt()
-        scaleHeight = (scale * frame.height).toInt()
-        Log.i(TAG, "Uniform scale factor applied: $scaleWidth x $scaleHeight")
-      } else {
-        throw Error("Scale must be either a map with width and height or a double value!")
-      }
+  private var bitmapBuffer: Bitmap? = null
+  private fun imageToBitmap(image: ImageProxy): Bitmap? {
+    if (bitmapBuffer == null) {
+      bitmapBuffer = Bitmap.createBitmap(image.width, image.height, Bitmap.Config.ARGB_8888)
     }
-
-    val crop = params["crop"] as? Map<*, *>
-    if (crop != null) {
-      val cropWidthDouble = crop["width"] as? Double
-      val cropHeightDouble = crop["height"] as? Double
-      val cropXDouble = crop["x"] as? Double
-      val cropYDouble = crop["y"] as?
Double - if (cropWidthDouble != null && cropHeightDouble != null && cropXDouble != null && cropYDouble != null) { - cropWidth = cropWidthDouble.toInt() - cropHeight = cropHeightDouble.toInt() - cropX = cropXDouble.toInt() - cropY = cropYDouble.toInt() - Log.i(TAG, "Target size: $cropWidth x $cropHeight") - } else { - throw Error("Failed to parse values in crop dictionary!") - } - } else { - if (scale != null) { - val aspectRatio = frame.width.toDouble() / frame.height.toDouble() - val targetAspectRatio = scaleWidth.toDouble() / scaleHeight.toDouble() - - if (aspectRatio > targetAspectRatio) { - cropWidth = (frame.height * targetAspectRatio).toInt() - cropHeight = frame.height - } else { - cropWidth = frame.width - cropHeight = (frame.width / targetAspectRatio).toInt() - } - cropX = (frame.width / 2) - (cropWidth / 2) - cropY = (frame.height / 2) - (cropHeight / 2) - Log.i(TAG, "Cropping to $cropWidth x $cropHeight at ($cropX, $cropY)") - } else { - Log.i(TAG, "Both scale and crop are null, using Frame's original dimensions.") - } + bitmapBuffer?.let { + val buffer = image.planes[0].buffer + it.copyPixelsFromBuffer(buffer) } + return bitmapBuffer + } - val image = frame.image - - if (image.format != ImageFormat.YUV_420_888) { - throw Error("Frame has invalid PixelFormat! Only YUV_420_888 is supported. Did you set pixelFormat=\"yuv\"?") - } - - val resized = resizeConvert.resize( - image, - cropX, cropY, - cropWidth, cropHeight, - scaleWidth, scaleHeight, - rotation.degrees, - mirror, - PixelFormat.RGB.ordinal, - DataType.UINT8.ordinal - ) - - val mpImage = - ByteBufferImageBuilder(resized, scaleWidth, scaleHeight, MPImage.IMAGE_FORMAT_RGB).build() - - detector.detectLivestreamFrame(mpImage) - return true + private fun rotateBitmap(source: Bitmap, angle: Float): Bitmap { + val matrix = Matrix() + matrix.postRotate(angle) + return Bitmap.createBitmap(source, 0, 0, source.width, source.height, matrix, true) } private enum class PixelFormat { @@ -138,15 +58,15 @@ class ObjectDetectionFrameProcessorPlugin() : companion object { fun fromString(string: String): PixelFormat = - when (string) { - "rgb" -> RGB - "rgba" -> RGBA - "argb" -> ARGB - "bgra" -> BGRA - "bgr" -> BGR - "abgr" -> ABGR - else -> throw Error("Invalid PixelFormat! ($string)") - } + when (string) { + "rgb" -> RGB + "rgba" -> RGBA + "argb" -> ARGB + "bgra" -> BGRA + "bgr" -> BGR + "abgr" -> ABGR + else -> throw Error("Invalid PixelFormat! ($string)") + } } } @@ -157,11 +77,11 @@ class ObjectDetectionFrameProcessorPlugin() : companion object { fun fromString(string: String): DataType = - when (string) { - "uint8" -> UINT8 - "float32" -> FLOAT32 - else -> throw Error("Invalid DataType! ($string)") - } + when (string) { + "uint8" -> UINT8 + "float32" -> FLOAT32 + else -> throw Error("Invalid DataType! ($string)") + } } } } @@ -174,12 +94,12 @@ private enum class Rotation(val degrees: Int) { companion object { fun fromString(value: String): Rotation = - when (value) { - "0deg" -> Rotation0 - "90deg" -> Rotation90 - "180deg" -> Rotation180 - "270deg" -> Rotation270 - else -> throw Error("Invalid rotation value! ($value)") - } + when (value) { + "0deg" -> Rotation0 + "90deg" -> Rotation90 + "180deg" -> Rotation180 + "270deg" -> Rotation270 + else -> throw Error("Invalid rotation value! 
($value)") + } } } diff --git a/android/src/main/java/com/reactnativemediapipe/objectdetection/ObjectDetectorHelper.kt b/android/src/main/java/com/reactnativemediapipe/objectdetection/ObjectDetectorHelper.kt index 07a770e..6f1bd22 100644 --- a/android/src/main/java/com/reactnativemediapipe/objectdetection/ObjectDetectorHelper.kt +++ b/android/src/main/java/com/reactnativemediapipe/objectdetection/ObjectDetectorHelper.kt @@ -2,23 +2,14 @@ package com.reactnativemediapipe.objectdetection import android.content.Context import android.graphics.Bitmap -import android.graphics.ImageFormat -import android.media.Image import android.media.MediaMetadataRetriever import android.net.Uri -import android.os.SystemClock -import android.renderscript.Allocation -import android.renderscript.Element -import android.renderscript.RenderScript -import android.renderscript.ScriptIntrinsicYuvToRGB -import android.renderscript.Type -import android.util.Log import android.os.Handler import android.os.Looper -import androidx.core.math.MathUtils.clamp +import android.os.SystemClock +import android.util.Log import com.facebook.react.common.annotations.VisibleForTesting import com.google.mediapipe.framework.image.BitmapImageBuilder -import com.google.mediapipe.framework.image.ByteBufferImageBuilder import com.google.mediapipe.framework.image.MPImage import com.google.mediapipe.tasks.core.BaseOptions import com.google.mediapipe.tasks.core.Delegate @@ -26,24 +17,24 @@ import com.google.mediapipe.tasks.vision.core.ImageProcessingOptions import com.google.mediapipe.tasks.vision.core.RunningMode import com.google.mediapipe.tasks.vision.objectdetector.ObjectDetector import com.google.mediapipe.tasks.vision.objectdetector.ObjectDetectorResult - +import com.mrousavy.camera.core.types.Orientation class ObjectDetectorHelper( - var threshold: Float = THRESHOLD_DEFAULT, - var maxResults: Int = MAX_RESULTS_DEFAULT, - var currentDelegate: Int = DELEGATE_CPU, - var currentModel: String, - var runningMode: RunningMode = RunningMode.IMAGE, - val context: Context, - // The listener is only used when running in RunningMode.LIVE_STREAM - var objectDetectorListener: DetectorListener? = null + var threshold: Float = THRESHOLD_DEFAULT, + var maxResults: Int = MAX_RESULTS_DEFAULT, + var currentDelegate: Int = DELEGATE_CPU, + var currentModel: String, + var runningMode: RunningMode = RunningMode.IMAGE, + val context: Context, + // The listener is only used when running in RunningMode.LIVE_STREAM + var objectDetectorListener: DetectorListener? = null ) { // For this example this needs to be a var so it can be reset on changes. If the ObjectDetector // will not change, a lazy val would be preferable. private var objectDetector: ObjectDetector? = null private var imageRotation = 0 - private lateinit var imageProcessingOptions: ImageProcessingOptions + // private lateinit var imageProcessingOptions: ImageProcessingOptions init { setupObjectDetector() @@ -55,10 +46,14 @@ class ObjectDetectorHelper( // that this is because the object detector is still doing some processing, and that // it is not safe to close it. So a better solution might be to mark it and then when // processing is complete, cause it to be closed. 
diff --git a/android/src/main/java/com/reactnativemediapipe/objectdetection/ObjectDetectorHelper.kt b/android/src/main/java/com/reactnativemediapipe/objectdetection/ObjectDetectorHelper.kt
index 07a770e..6f1bd22 100644
--- a/android/src/main/java/com/reactnativemediapipe/objectdetection/ObjectDetectorHelper.kt
+++ b/android/src/main/java/com/reactnativemediapipe/objectdetection/ObjectDetectorHelper.kt
@@ -2,23 +2,14 @@ package com.reactnativemediapipe.objectdetection

 import android.content.Context
 import android.graphics.Bitmap
-import android.graphics.ImageFormat
-import android.media.Image
 import android.media.MediaMetadataRetriever
 import android.net.Uri
-import android.os.SystemClock
-import android.renderscript.Allocation
-import android.renderscript.Element
-import android.renderscript.RenderScript
-import android.renderscript.ScriptIntrinsicYuvToRGB
-import android.renderscript.Type
-import android.util.Log
 import android.os.Handler
 import android.os.Looper
-import androidx.core.math.MathUtils.clamp
+import android.os.SystemClock
+import android.util.Log
 import com.facebook.react.common.annotations.VisibleForTesting
 import com.google.mediapipe.framework.image.BitmapImageBuilder
-import com.google.mediapipe.framework.image.ByteBufferImageBuilder
 import com.google.mediapipe.framework.image.MPImage
 import com.google.mediapipe.tasks.core.BaseOptions
 import com.google.mediapipe.tasks.core.Delegate
@@ -26,24 +17,24 @@ import com.google.mediapipe.tasks.vision.core.ImageProcessingOptions
 import com.google.mediapipe.tasks.vision.core.RunningMode
 import com.google.mediapipe.tasks.vision.objectdetector.ObjectDetector
 import com.google.mediapipe.tasks.vision.objectdetector.ObjectDetectorResult
-
+import com.mrousavy.camera.core.types.Orientation

 class ObjectDetectorHelper(
-    var threshold: Float = THRESHOLD_DEFAULT,
-    var maxResults: Int = MAX_RESULTS_DEFAULT,
-    var currentDelegate: Int = DELEGATE_CPU,
-    var currentModel: String,
-    var runningMode: RunningMode = RunningMode.IMAGE,
-    val context: Context,
-    // The listener is only used when running in RunningMode.LIVE_STREAM
-    var objectDetectorListener: DetectorListener? = null
+  var threshold: Float = THRESHOLD_DEFAULT,
+  var maxResults: Int = MAX_RESULTS_DEFAULT,
+  var currentDelegate: Int = DELEGATE_CPU,
+  var currentModel: String,
+  var runningMode: RunningMode = RunningMode.IMAGE,
+  val context: Context,
+  // The listener is only used when running in RunningMode.LIVE_STREAM
+  var objectDetectorListener: DetectorListener? = null
 ) {

   // For this example this needs to be a var so it can be reset on changes. If the ObjectDetector
   // will not change, a lazy val would be preferable.
   private var objectDetector: ObjectDetector? = null
   private var imageRotation = 0
-  private lateinit var imageProcessingOptions: ImageProcessingOptions
+  // private lateinit var imageProcessingOptions: ImageProcessingOptions

   init {
     setupObjectDetector()
   }
@@ -55,10 +46,14 @@ class ObjectDetectorHelper(
   // that this is because the object detector is still doing some processing, and that
   // it is not safe to close it. So a better solution might be to mark it and then when
   // processing is complete, cause it to be closed.
-    Handler(Looper.getMainLooper()).postDelayed({
-      objectDetector?.close()
-      objectDetector = null
-    }, 100)
+    Handler(Looper.getMainLooper())
+        .postDelayed(
+            {
+              objectDetector?.close()
+              objectDetector = null
+            },
+            100
+        )
   }

   // Initialize the object detector using current settings on the
@@ -74,7 +69,6 @@ class ObjectDetectorHelper(
       DELEGATE_CPU -> {
         baseOptionsBuilder.setDelegate(Delegate.CPU)
       }
-
       DELEGATE_GPU -> {
         // Is there a check for GPU being supported?
         baseOptionsBuilder.setDelegate(Delegate.GPU)
@@ -88,55 +82,42 @@ class ObjectDetectorHelper(
       RunningMode.LIVE_STREAM -> {
         if (objectDetectorListener == null) {
           throw IllegalStateException(
-            "objectDetectorListener must be set when runningMode is LIVE_STREAM."
+              "objectDetectorListener must be set when runningMode is LIVE_STREAM."
           )
         }
       }
-
       RunningMode.IMAGE, RunningMode.VIDEO -> {
         // no-op
       }
     }

     try {
-      val optionsBuilder = ObjectDetector.ObjectDetectorOptions.builder()
-        .setBaseOptions(baseOptionsBuilder.build())
-        .setScoreThreshold(threshold).setRunningMode(runningMode)
-        .setMaxResults(maxResults)
-
-      imageProcessingOptions = ImageProcessingOptions.builder()
-        .setRotationDegrees(imageRotation).build()
+      val optionsBuilder =
+          ObjectDetector.ObjectDetectorOptions.builder()
+              .setBaseOptions(baseOptionsBuilder.build())
+              .setScoreThreshold(threshold)
+              .setRunningMode(runningMode)
+              .setMaxResults(maxResults)

       when (runningMode) {
-        RunningMode.IMAGE, RunningMode.VIDEO -> optionsBuilder.setRunningMode(
-          runningMode
-        )
-
-        RunningMode.LIVE_STREAM -> optionsBuilder.setRunningMode(
-          runningMode
-        ).setResultListener(this::returnLivestreamResult)
-          .setErrorListener(this::returnLivestreamError)
+        RunningMode.IMAGE, RunningMode.VIDEO -> optionsBuilder.setRunningMode(runningMode)
+        RunningMode.LIVE_STREAM ->
+            optionsBuilder
+                .setRunningMode(runningMode)
+                .setResultListener(this::returnLivestreamResult)
+                .setErrorListener(this::returnLivestreamError)
       }

       val options = optionsBuilder.build()
       objectDetector = ObjectDetector.createFromOptions(context, options)
     } catch (e: IllegalStateException) {
-      objectDetectorListener?.onError(
-        "Object detector failed to initialize: " + e.message
-      )
+      objectDetectorListener?.onError("Object detector failed to initialize: " + e.message)
       Log.e(TAG, "TFLite failed to load model with error: " + e.message)
     } catch (e: RuntimeException) {
-      objectDetectorListener?.onError(
-        "Object detector failed to initialize: " + e.message
-      )
-      Log.e(
-        TAG,
-        "Object detector failed to load model with error: " + e.message
-      )
+      objectDetectorListener?.onError("Object detector failed to initialize: " + e.message)
+      Log.e(TAG, "Object detector failed to load model with error: " + e.message)
     } catch (e: Exception) {
-      objectDetectorListener?.onError(
-        "Object detector failed to initialize: " + e.message
-      )
+      objectDetectorListener?.onError("Object detector failed to initialize: " + e.message)
       Log.e(TAG, "TFLite failed to load model with error: " + e.message)
     }
   }
@@ -149,13 +130,11 @@ class ObjectDetectorHelper(
   // Accepts the URI for a video file loaded from the user's gallery and attempts to run
   // object detection inference on the video. This process will evaluate every frame in
   // the video and attach the results to a bundle that will be returned.
-  fun detectVideoFile(
-    videoUri: Uri, inferenceIntervalMs: Long
-  ): ResultBundle? {
+  fun detectVideoFile(videoUri: Uri, inferenceIntervalMs: Long): ResultBundle?
{ if (runningMode != RunningMode.VIDEO) { throw IllegalArgumentException( - "Attempting to call detectVideoFile" + " while not using RunningMode.VIDEO" + "Attempting to call detectVideoFile" + " while not using RunningMode.VIDEO" ) } @@ -171,8 +150,7 @@ class ObjectDetectorHelper( val retriever = MediaMetadataRetriever() retriever.setDataSource(context, videoUri) val videoLengthMs = - retriever.extractMetadata(MediaMetadataRetriever.METADATA_KEY_DURATION) - ?.toLong() + retriever.extractMetadata(MediaMetadataRetriever.METADATA_KEY_DURATION)?.toLong() // Note: We need to read width/height from frame instead of getting the width/height // of the video directly because MediaRetriever returns frames that are smaller than the @@ -192,39 +170,40 @@ class ObjectDetectorHelper( val timestampMs = i * inferenceIntervalMs // ms retriever.getFrameAtTime( - timestampMs * 1000, // convert from ms to micro-s - MediaMetadataRetriever.OPTION_CLOSEST - )?.let { frame -> - // Convert the video frame to ARGB_8888 which is required by the MediaPipe - val argb8888Frame = - if (frame.config == Bitmap.Config.ARGB_8888) frame - else frame.copy(Bitmap.Config.ARGB_8888, false) - - // Convert the input Bitmap object to an MPImage object to run inference - val mpImage = BitmapImageBuilder(argb8888Frame).build() - - // Run object detection using MediaPipe Object Detector API - objectDetector?.detectForVideo(mpImage, timestampMs) - ?.let { detectionResult -> - resultList.add(detectionResult) - } ?: { - didErrorOccurred = true - objectDetectorListener?.onError( - "ResultBundle could not be returned" + " in detectVideoFile" + timestampMs * 1000, // convert from ms to micro-s + MediaMetadataRetriever.OPTION_CLOSEST ) - } - } ?: run { - didErrorOccurred = true - objectDetectorListener?.onError( - "Frame at specified time could not be" + " retrieved when detecting in video." - ) - } + ?.let { frame -> + // Convert the video frame to ARGB_8888 which is required by the MediaPipe + val argb8888Frame = + if (frame.config == Bitmap.Config.ARGB_8888) frame + else frame.copy(Bitmap.Config.ARGB_8888, false) + + // Convert the input Bitmap object to an MPImage object to run inference + val mpImage = BitmapImageBuilder(argb8888Frame).build() + + // Run object detection using MediaPipe Object Detector API + objectDetector?.detectForVideo(mpImage, timestampMs)?.let { detectionResult -> + resultList.add(detectionResult) + } + ?: { + didErrorOccurred = true + objectDetectorListener?.onError( + "ResultBundle could not be returned" + " in detectVideoFile" + ) + } + } + ?: run { + didErrorOccurred = true + objectDetectorListener?.onError( + "Frame at specified time could not be" + " retrieved when detecting in video." + ) + } } retriever.release() - val inferenceTimePerFrameMs = - (SystemClock.uptimeMillis() - startTime).div(numberOfFrameToRead) + val inferenceTimePerFrameMs = (SystemClock.uptimeMillis() - startTime).div(numberOfFrameToRead) return if (didErrorOccurred) { null @@ -235,50 +214,43 @@ class ObjectDetectorHelper( // Runs object detection on live streaming cameras frame-by-frame and returns the results // asynchronously to the caller. 
- fun detectLivestreamFrame(mpImage: MPImage) { + fun detectLivestreamFrame(mpImage: MPImage, orientation: Orientation) { if (runningMode != RunningMode.LIVE_STREAM) { throw IllegalArgumentException( - "Attempting to call detectLivestreamFrame" + " while not using RunningMode.LIVE_STREAM" + "Attempting to call detectLivestreamFrame" + " while not using RunningMode.LIVE_STREAM" ) } val frameTime = SystemClock.uptimeMillis() - - detectAsync(mpImage, frameTime) + // this is a hack bc we use this in the callback and it might have changed + this.imageRotation = orientationToDegrees(orientation) + detectAsync(mpImage, frameTime, this.imageRotation) } // Run object detection using MediaPipe Object Detector API @VisibleForTesting - fun detectAsync(mpImage: MPImage, frameTime: Long) { + fun detectAsync(mpImage: MPImage, frameTime: Long, imageRotation: Int) { // As we're using running mode LIVE_STREAM, the detection result will be returned in // returnLivestreamResult function + val imageProcessingOptions = + ImageProcessingOptions.builder().setRotationDegrees(imageRotation).build() objectDetector?.detectAsync(mpImage, imageProcessingOptions, frameTime) } // Return the detection result to this ObjectDetectorHelper's caller - private fun returnLivestreamResult( - result: ObjectDetectorResult, input: MPImage - ) { + private fun returnLivestreamResult(result: ObjectDetectorResult, input: MPImage) { val finishTimeMs = SystemClock.uptimeMillis() val inferenceTime = finishTimeMs - result.timestampMs() objectDetectorListener?.onResults( - ResultBundle( - listOf(result), - inferenceTime, - input.height, - input.width, - imageRotation - ) + ResultBundle(listOf(result), inferenceTime, input.height, input.width, imageRotation) ) } // Return errors thrown during detection to this ObjectDetectorHelper's caller private fun returnLivestreamError(error: RuntimeException) { - objectDetectorListener?.onError( - error.message ?: "An unknown error has occurred" - ) + objectDetectorListener?.onError(error.message ?: "An unknown error has occurred") } // Accepted a Bitmap and runs object detection inference on it to return results back @@ -287,7 +259,7 @@ class ObjectDetectorHelper( if (runningMode != RunningMode.IMAGE) { throw IllegalArgumentException( - "Attempting to call detectImage" + " while not using RunningMode.IMAGE" + "Attempting to call detectImage" + " while not using RunningMode.IMAGE" ) } @@ -303,12 +275,7 @@ class ObjectDetectorHelper( // Run object detection using MediaPipe Object Detector API objectDetector?.detect(mpImage)?.also { detectionResult -> val inferenceTimeMs = SystemClock.uptimeMillis() - startTime - return ResultBundle( - listOf(detectionResult), - inferenceTimeMs, - image.height, - image.width - ) + return ResultBundle(listOf(detectionResult), inferenceTimeMs, image.height, image.width) } // If objectDetector?.detect() returns null, this is likely an error. 
Returning null
@@ -319,11 +286,11 @@ class ObjectDetectorHelper(
   // Wraps results from inference, the time it takes for inference to be performed, and
   // the input image and height for properly scaling UI to return back to callers
   data class ResultBundle(
-    val results: List<ObjectDetectorResult>,
-    val inferenceTime: Long,
-    val inputImageHeight: Int,
-    val inputImageWidth: Int,
-    val inputImageRotation: Int = 0
+      val results: List<ObjectDetectorResult>,
+      val inferenceTime: Long,
+      val inputImageHeight: Int,
+      val inputImageWidth: Int,
+      val inputImageRotation: Int = 0
   )

   companion object {
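With the `imageProcessingOptions` field gone, rotation now travels with each live-stream frame. For reference, a minimal hedged sketch of single-image use of this helper — the model file name is an assumption, and `detectImage` returns null on failure as the code above shows:

```kotlin
import android.content.Context
import android.graphics.Bitmap
import com.google.mediapipe.tasks.vision.core.RunningMode

// Sketch: one-off detection on a Bitmap. RunningMode.IMAGE needs no listener.
fun detectOnce(context: Context, bitmap: Bitmap): ObjectDetectorHelper.ResultBundle? {
  val helper = ObjectDetectorHelper(
      currentModel = "efficientdet-lite0.tflite", // assumption: a model bundled with the app
      runningMode = RunningMode.IMAGE,
      context = context
  )
  return helper.detectImage(bitmap)
}
```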
"unknown", diff --git a/libyuv b/libyuv deleted file mode 160000 index 488a2af..0000000 --- a/libyuv +++ /dev/null @@ -1 +0,0 @@ -Subproject commit 488a2af021e3e7473f083a9435b1472c0d411f3d diff --git a/src/__tests__/convert.test.ts b/src/__tests__/convert.test.ts new file mode 100644 index 0000000..22812e7 --- /dev/null +++ b/src/__tests__/convert.test.ts @@ -0,0 +1,261 @@ +import { + type Dims, + type Point, + type RectLTRB, + type RectXYWH, + framePointToView, + frameRectLTRBToView, + frameRectXYWHToView, + frameRectToView, + ltrbToXywh, + clampToDims, + type ResizeMode, +} from "../shared/convert"; // Adjust the import path to where your module is located + +describe("Image Frame Transformation Utilities", () => { + describe("framePointToView", () => { + it("should correctly transform a point in contain mode", () => { + const pointOrig: Point = { x: 10, y: 20 }; + const frameDims: Dims = { width: 100, height: 200 }; + const viewDims: Dims = { width: 200, height: 400 }; + const mode: ResizeMode = "contain"; + const mirrored = false; + + const transformed = framePointToView( + pointOrig, + frameDims, + viewDims, + mode, + mirrored + ); + expect(transformed).toEqual({ x: 20, y: 40 }); + }); + + it("should correctly transform and mirror a point in cover mode", () => { + const pointOrig: Point = { x: 10, y: 20 }; + const frameDims: Dims = { width: 100, height: 200 }; + const viewDims: Dims = { width: 200, height: 100 }; + const mode: ResizeMode = "cover"; + const mirrored = true; + + const transformed = framePointToView( + pointOrig, + frameDims, + viewDims, + mode, + mirrored + ); + expect(transformed).toEqual({ x: 180, y: -110 }); + }); + + // Additional tests for edge cases, etc. + }); + + describe("frameRectLTRBToView", () => { + it("should transform a LTRB rectangle correctly", () => { + const rect: RectLTRB = { left: 10, top: 10, right: 90, bottom: 190 }; + const frameDims: Dims = { width: 100, height: 200 }; + const viewDims: Dims = { width: 200, height: 400 }; + const mode: ResizeMode = "contain"; + const mirrored = false; + + const transformed = frameRectLTRBToView( + rect, + frameDims, + viewDims, + mode, + mirrored + ); + expect(transformed).toEqual({ + left: 20, + top: 20, + right: 180, + bottom: 380, + }); + }); + + // Additional tests for edge cases, etc. + }); + + describe("frameRectXYWHToView", () => { + it("should transform an XYWH rectangle correctly", () => { + const rect: RectXYWH = { x: 10, y: 10, width: 80, height: 180 }; + const frameDims: Dims = { width: 100, height: 200 }; + const viewDims: Dims = { width: 200, height: 400 }; + const mode: ResizeMode = "contain"; + const mirrored = false; + + const transformed = frameRectXYWHToView( + rect, + frameDims, + viewDims, + mode, + mirrored + ); + expect(transformed).toEqual({ x: 20, y: 20, width: 160, height: 360 }); + }); + + // Additional tests for edge cases, etc. 
+  });
+
+  describe("frameRectToView", () => {
+    it("should handle RectLTRB inputs correctly", () => {
+      const rect: RectLTRB = { left: 10, top: 20, right: 50, bottom: 100 };
+      const frameDims: Dims = { width: 100, height: 200 };
+      const viewDims: Dims = { width: 300, height: 600 };
+      const mode: ResizeMode = "contain";
+      const mirrored = false;
+
+      const transformed = frameRectToView(
+        rect,
+        frameDims,
+        viewDims,
+        mode,
+        mirrored
+      );
+      expect(transformed).toEqual({
+        left: 30,
+        top: 60,
+        right: 150,
+        bottom: 300,
+      });
+    });
+
+    it("should handle RectXYWH inputs correctly", () => {
+      const rect: RectXYWH = { x: 10, y: 20, width: 40, height: 80 };
+      const frameDims: Dims = { width: 100, height: 200 };
+      const viewDims: Dims = { width: 300, height: 600 };
+      const mode: ResizeMode = "cover";
+      const mirrored = true;
+
+      const transformed = frameRectToView(
+        rect,
+        frameDims,
+        viewDims,
+        mode,
+        mirrored
+      );
+      expect(transformed).toEqual({ x: 150, y: 60, width: 120, height: 240 });
+    });
+
+    // Additional tests for edge cases, etc.
+  });
+
+  describe("frameRectToView and ltrbToXywh", () => {
+    it("should transform a bounding box correctly", () => {
+      const data = {
+        frameSize: { width: 480, height: 640 },
+        viewSize: { height: 600, width: 360 },
+        bb0: { left: 0, top: 200, right: 480, bottom: 640 },
+        r: {
+          x: 0,
+          y: (200 * 600) / 640,
+          width: 360,
+          height: 600 - (200 * 600) / 640,
+        },
+        mirrored: true,
+      };
+      const transformed = clampToDims(
+        frameRectToView(
+          ltrbToXywh(data.bb0),
+          data.frameSize,
+          data.viewSize,
+          "cover",
+          data.mirrored
+        ),
+        data.viewSize
+      );
+      expect(transformed).toEqual(data.r);
+    });
+  });
+
+  describe("ltrbToXywh and clampToDims", () => {
+    it("should convert LTRB to XYWH correctly", () => {
+      const ltrb: RectLTRB = { left: 10, top: 10, right: 90, bottom: 100 };
+      const xywh = ltrbToXywh(ltrb);
+      expect(xywh).toEqual({ x: 10, y: 10, width: 80, height: 90 });
+    });
+
+    it("should clamp XYWH rectangle within dimensions", () => {
+      const rect: RectXYWH = { x: -10, y: 300, width: 500, height: 500 };
+      const dims: Dims = { width: 400, height: 400 };
+
+      const clamped = clampToDims(rect, dims);
+      expect(clamped).toEqual({ x: 0, y: 300, width: 400, height: 100 });
+    });
+
+    // Additional tests for edge cases, etc.
+ }); + + describe("Image Frame Transformation Tests", () => { + const testcases = [ + { + frame: { width: 400, height: 400 }, + view: { width: 200, height: 300 }, + mode: "cover" as ResizeMode, + rect: { x: 0, y: 0, width: 400, height: 400 }, + target: { x: 100, y: 200 }, + expected: { + lt: { x: -50, y: 0 }, + rb: { x: 250, y: 300 }, + trg: { x: 25, y: 150 }, + }, + }, + { + frame: { width: 200, height: 400 }, + view: { width: 200, height: 200 }, + mode: "cover" as ResizeMode, + rect: { x: 0, y: 0, width: 200, height: 400 }, + target: { x: 100, y: 100 }, + expected: { + lt: { x: 0, y: -100 }, + rb: { x: 200, y: 300 }, + trg: { x: 100, y: 0 }, + }, + }, + { + frame: { width: 400, height: 400 }, + view: { width: 200, height: 300 }, + mode: "contain" as ResizeMode, + rect: { x: 0, y: 0, width: 400, height: 400 }, + target: { x: 100, y: 200 }, + expected: { + lt: { x: 0, y: 50 }, + rb: { x: 200, y: 250 }, + trg: { x: 50, y: 150 }, + }, + }, + { + frame: { width: 200, height: 400 }, + view: { width: 200, height: 200 }, + mode: "contain" as ResizeMode, + rect: { x: 0, y: 0, width: 200, height: 400 }, + target: { x: 100, y: 100 }, + expected: { + lt: { x: 50, y: 0 }, + rb: { x: 150, y: 200 }, + trg: { x: 100, y: 50 }, + }, + }, + ]; + + test.each(testcases)( + "Testing transformations with frame and view dimensions", + ({ frame, view, mode, rect, target, expected }) => { + const lt = framePointToView({ x: 0, y: 0 }, frame, view, mode, false); + const trg = framePointToView(target, frame, view, mode, false); + const rb = framePointToView( + { x: rect.x + rect.width, y: rect.y + rect.height }, + frame, + view, + mode, + false + ); + + expect(lt).toEqual({ x: expected.lt.x, y: expected.lt.y }); + expect(trg).toEqual({ x: expected.trg.x, y: expected.trg.y }); + expect(rb).toEqual({ x: expected.rb.x, y: expected.rb.y }); + } + ); + }); +}); diff --git a/src/__tests__/index.test.tsx b/src/__tests__/index.test.tsx index bf84291..e18553d 100644 --- a/src/__tests__/index.test.tsx +++ b/src/__tests__/index.test.tsx @@ -1 +1,3 @@ -it.todo('write a test'); +it("write a test", () => { + expect(true).toBe(true); +}); diff --git a/src/objectDetection/index.ts b/src/objectDetection/index.ts index 7d12394..5d6dc08 100644 --- a/src/objectDetection/index.ts +++ b/src/objectDetection/index.ts @@ -2,11 +2,13 @@ import React from "react"; import { NativeEventEmitter, NativeModules, + Platform, type LayoutChangeEvent, } from "react-native"; import { VisionCameraProxy, useFrameProcessor, + type CameraDevice, } from "react-native-vision-camera"; import type { MediaPipeSolution } from "../shared/types"; import type { Dims } from "../shared/convert"; @@ -99,12 +101,17 @@ export interface ObjectDetectionOptions { threshold: number; maxResults: number; delegate: Delegate; - resize: { scale: number; aspect: "preserve" | "default" | number }; + mirrorMode: "no-mirror" | "mirror" | "mirror-front-only"; } export interface ObjectDetectionCallbacks { - onResults: (result: ResultBundleMap, viewSize: Dims) => void; + onResults: ( + result: ResultBundleMap, + viewSize: Dims, + mirrored: boolean + ) => void; onError: (error: ObjectDetectionError) => void; viewSize: Dims; + mirrored: boolean; } // TODO setup the general event callbacks @@ -114,7 +121,7 @@ eventEmitter.addListener( (args: { handle: number } & ResultBundleMap) => { const callbacks = detectorMap.get(args.handle); if (callbacks) { - callbacks.onResults(args, callbacks.viewSize); + callbacks.onResults(args, callbacks.viewSize, callbacks.mirrored); } } ); @@ -153,7 +160,23 
@@ export function useObjectDetection( }, [] ); - + const mirrorMode = + options?.mirrorMode ?? + Platform.select({ android: "mirror-front-only", default: "no-mirror" }); + const [cameraDevice, setCameraDevice] = React.useState< + CameraDevice | undefined + >(undefined); + const mirrored = React.useMemo((): boolean => { + if ( + (mirrorMode === "mirror-front-only" && + cameraDevice?.position === "front") || + mirrorMode === "mirror" + ) { + return true; + } else { + return false; + } + }, [cameraDevice?.position, mirrorMode]); // Remember the latest callback if it changes. React.useLayoutEffect(() => { if (detectorHandle !== undefined) { @@ -161,9 +184,10 @@ export function useObjectDetection( onResults, onError, viewSize: cameraViewDimensions, + mirrored, }); } - }, [onResults, onError, detectorHandle, cameraViewDimensions]); + }, [onResults, onError, detectorHandle, cameraViewDimensions, mirrored]); React.useEffect(() => { let newHandle: number | undefined; @@ -208,18 +232,15 @@ export function useObjectDetection( const frameProcessor = useFrameProcessor( (frame) => { "worklet"; - - plugin?.call(frame, { - detectorHandle, - pixelFormat: "rgb", - dataType: "uint8", - }); + // console.log(frame.orientation, frame.width, frame.height); + plugin?.call(frame, { detectorHandle }); }, [detectorHandle] ); return React.useMemo( (): MediaPipeSolution => ({ cameraViewLayoutChangeHandler, + cameraDeviceChangeHandler: setCameraDevice, cameraViewDimensions, frameProcessor, }), diff --git a/src/shared/convert.ts b/src/shared/convert.ts index be6719a..602c497 100644 --- a/src/shared/convert.ts +++ b/src/shared/convert.ts @@ -12,13 +12,17 @@ export type ResizeMode = "cover" | "contain"; // both cover and contain preserve aspect ratio. Cover will crop the image to fill the view, contain will show the whole image and add padding. // for cover, if the aspect ratio x/y of the frame is greater than export function framePointToView( - point: Point, + pointOrig: Point, frameDims: Dims, viewDims: Dims, - mode: ResizeMode + mode: ResizeMode, + mirrored: boolean ): Point { const frameRatio = frameDims.width / frameDims.height; const viewRatio = viewDims.width / viewDims.height; + const point = mirrored + ? { x: frameDims.width - pointOrig.x, y: pointOrig.y } + : pointOrig; let scale = 1; let xoffset = 0; let yoffset = 0; @@ -47,52 +51,63 @@ export function framePointToView( yoffset = (viewDims.height - frameDims.height * scale) / 2; } } - return { + const result = { x: point.x * scale + xoffset, y: point.y * scale + yoffset, }; + return result; } -function frameRectLTRBToView( +export function frameRectLTRBToView( rect: RectLTRB, frameDims: Dims, viewDims: Dims, - mode: ResizeMode + mode: ResizeMode, + mirrored: boolean ): RectLTRB { const lt = framePointToView( { x: rect.left, y: rect.top }, frameDims, viewDims, - mode + mode, + mirrored ); const rb = framePointToView( { x: rect.right, y: rect.bottom }, frameDims, viewDims, - mode + mode, + mirrored ); - return { left: lt.x, top: lt.y, right: rb.x, bottom: rb.y }; + const left = mirrored ? Math.min(lt.x, rb.x) : lt.x; + const right = mirrored ? 
Math.max(lt.x, rb.x) : rb.x;
+  return { left, top: lt.y, right, bottom: rb.y };
 }

-function frameRectXYWHToView(
+export function frameRectXYWHToView(
   rect: RectXYWH,
   frameDims: Dims,
   viewDims: Dims,
-  mode: ResizeMode
+  mode: ResizeMode,
+  mirrored: boolean
 ): RectXYWH {
   const lt = framePointToView(
     { x: rect.x, y: rect.y },
     frameDims,
     viewDims,
-    mode
+    mode,
+    mirrored
   );
   const rb = framePointToView(
     { x: rect.x + rect.width, y: rect.y + rect.height },
     frameDims,
     viewDims,
-    mode
+    mode,
+    mirrored
   );
-  return { x: lt.x, y: lt.y, width: rb.x - lt.x, height: rb.y - lt.y };
+  const width = mirrored ? Math.abs(rb.x - lt.x) : rb.x - lt.x;
+  const x = mirrored ? lt.x - width : lt.x;
+  return { x, y: lt.y, width, height: rb.y - lt.y };
 }

 function isRectLTRB(rect: unknown): rect is RectLTRB {
@@ -109,12 +124,25 @@ export function frameRectToView<TRect extends RectLTRB | RectXYWH>(
   rect: TRect,
   frameDims: Dims,
   viewDims: Dims,
-  mode: ResizeMode
+  mode: ResizeMode,
+  mirrored: boolean
 ): TRect {
   if (isRectLTRB(rect)) {
-    return frameRectLTRBToView(rect, frameDims, viewDims, mode) as TRect;
+    return frameRectLTRBToView(
+      rect,
+      frameDims,
+      viewDims,
+      mode,
+      mirrored
+    ) as TRect;
   } else {
-    return frameRectXYWHToView(rect, frameDims, viewDims, mode) as TRect;
+    return frameRectXYWHToView(
+      rect,
+      frameDims,
+      viewDims,
+      mode,
+      mirrored
+    ) as TRect;
   }
 }

@@ -126,3 +154,22 @@ export function ltrbToXywh(rect: RectLTRB): RectXYWH {
     height: rect.bottom - rect.top,
   };
 }
+
+export function clampToDims<TRect extends RectLTRB | RectXYWH>(
+  rect: TRect,
+  dims: Dims
+): TRect {
+  if (isRectLTRB(rect)) {
+    const left = Math.max(0, Math.min(rect.left, dims.width));
+    const top = Math.max(0, Math.min(rect.top, dims.height));
+    const right = Math.max(0, Math.min(rect.right, dims.width));
+    const bottom = Math.max(0, Math.min(rect.bottom, dims.height));
+    return { left, top, right, bottom } as TRect;
+  } else {
+    const x = Math.max(0, Math.min(rect.x, dims.width));
+    const y = Math.max(0, Math.min(rect.y, dims.height));
+    const width = Math.max(0, Math.min(rect.width, dims.width - x));
+    const height = Math.max(0, Math.min(rect.height, dims.height - y));
+    return { x, y, width, height } as TRect;
+  }
+}
diff --git a/src/shared/mediapipeCamera.tsx b/src/shared/mediapipeCamera.tsx
index de92d6d..9cbf146 100644
--- a/src/shared/mediapipeCamera.tsx
+++ b/src/shared/mediapipeCamera.tsx
@@ -1,5 +1,5 @@
 import React from "react";
-import { type ViewStyle, Text, Platform } from "react-native";
+import { type ViewStyle, Text } from "react-native";
 import {
   Camera,
   useCameraDevice,
@@ -19,22 +19,29 @@ export type MediapipeCameraProps = {

 export const MediapipeCamera: React.FC<MediapipeCameraProps> = ({
   style,
-  solution,
+  solution: {
+    cameraDeviceChangeHandler,
+    cameraViewLayoutChangeHandler,
+    frameProcessor,
+  },
   activeCamera = "front",
   orientation = "portrait",
   resizeMode = "cover",
 }) => {
   const device = useCameraDevice(activeCamera);
+  React.useEffect(() => {
+    cameraDeviceChangeHandler(device);
+  }, [cameraDeviceChangeHandler, device]);
   return device !== undefined ?
(
-    <Camera
-      style={style}
-      device={device}
-      pixelFormat={Platform.select({ ios: "rgb", android: "yuv" })}
-      isActive={true}
-      frameProcessor={solution.frameProcessor}
-      onLayout={solution.cameraViewLayoutChangeHandler}
-      resizeMode={resizeMode}
-      orientation={orientation}
-    />
+    <Camera
+      style={style}
+      device={device}
+      pixelFormat="rgba"
+      isActive={true}
+      frameProcessor={frameProcessor}
+      onLayout={cameraViewLayoutChangeHandler}
+      resizeMode={resizeMode}
+      orientation={orientation}
+    />
   ) : (
     <Text>no device</Text>
   );
diff --git a/src/shared/types.ts b/src/shared/types.ts
index d78441c..a4d7555 100644
--- a/src/shared/types.ts
+++ b/src/shared/types.ts
@@ -1,8 +1,12 @@
 import type { LayoutChangeEvent } from "react-native";
-import type { ReadonlyFrameProcessor } from "react-native-vision-camera";
+import type {
+  CameraDevice,
+  ReadonlyFrameProcessor,
+} from "react-native-vision-camera";

 export interface MediaPipeSolution {
   frameProcessor: ReadonlyFrameProcessor;
   cameraViewLayoutChangeHandler: (event: LayoutChangeEvent) => void;
+  cameraDeviceChangeHandler: (device: CameraDevice | undefined) => void;
   cameraViewDimensions: { width: number; height: number };
 }
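A worked check of the mirrored cover-mode math that the new tests pin down (a 480x640 frame shown in a 360x600 view): frameRatio = 480/640 = 0.75 exceeds viewRatio = 360/600 = 0.6, so the frame is scaled by height and cropped horizontally:

    s = 600/640 = 0.9375
    xoffset = (360 - 480*s)/2 = -45
    mirrored: x' = (480 - x)*s + xoffset,  y' = y*s

The bounding box (left 0, top 200, right 480, bottom 640) therefore lands at y' = 200*0.9375 = 187.5 (the test's 200*600/640), with x' spanning [-45, 405], which clampToDims trims to the view's [0, 360] — matching the expected rect in the test exactly.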