From 4eeb62a25cc2d4d6bd6625584fa6f5d2190866eb Mon Sep 17 00:00:00 2001 From: Jhen Date: Sat, 20 May 2023 13:31:30 +0800 Subject: [PATCH 1/3] add react-native-blob-jsi-helper --- package.json | 1 + yarn.lock | 5 +++++ 2 files changed, 6 insertions(+) diff --git a/package.json b/package.json index f865b0d..c1074a5 100644 --- a/package.json +++ b/package.json @@ -25,6 +25,7 @@ "path-browserify": "^1.0.1", "react": "18.2.0", "react-native": "0.71.8", + "react-native-blob-jsi-helper": "^0.3.0", "react-native-fs": "^2.20.0", "react-native-image-picker": "^5.3.1", "react-native-quick-base64": "^2.0.6", diff --git a/yarn.lock b/yarn.lock index d614973..39a5f86 100644 --- a/yarn.lock +++ b/yarn.lock @@ -6610,6 +6610,11 @@ react-is@^17.0.1: resolved "https://registry.yarnpkg.com/react-is/-/react-is-17.0.2.tgz#e691d4a8e9c789365655539ab372762b0efb54f0" integrity sha512-w2GsyukL62IJnlaff/nRegPQR94C/XXamvMWmSHRJ4y7Ts/4ocGRmTHvOs8PSE6pB3dWOrD/nueuU5sduBsQ4w== +react-native-blob-jsi-helper@^0.3.0: + version "0.3.0" + resolved "https://registry.yarnpkg.com/react-native-blob-jsi-helper/-/react-native-blob-jsi-helper-0.3.0.tgz#a57a8467d9b08d620db1d9e546dbbef45e2996d2" + integrity sha512-9ez/zdiHEcuI86ufxSAWqiPEMjhtCW89DHlG3nVPhQ1vBi7cb7/jsrMYILVaNzGsxsW7vPPcMAs9Cd8hxo7M0w== + react-native-codegen@^0.71.5: version "0.71.5" resolved "https://registry.yarnpkg.com/react-native-codegen/-/react-native-codegen-0.71.5.tgz#454a42a891cd4ca5fc436440d301044dc1349c14" From ee1e715c78c0384b55b73efb3f24f820c1e1608d Mon Sep 17 00:00:00 2001 From: Jhen Date: Sat, 20 May 2023 13:32:04 +0800 Subject: [PATCH 2/3] ONNX RN: use blobId / size as output --- patches/onnxruntime-react-native+1.14.0.patch | 121 +++++++++++++++++- 1 file changed, 119 insertions(+), 2 deletions(-) diff --git a/patches/onnxruntime-react-native+1.14.0.patch b/patches/onnxruntime-react-native+1.14.0.patch index 71cc993..117f3c2 100644 --- a/patches/onnxruntime-react-native+1.14.0.patch +++ b/patches/onnxruntime-react-native+1.14.0.patch @@ -12,11 +12,75 @@ index 4c8a318..65b58c1 100644 + implementation project(":onnxruntime-patched") + } +diff --git a/node_modules/onnxruntime-react-native/android/src/main/java/ai/onnxruntime/reactnative/OnnxruntimeModule.java b/node_modules/onnxruntime-react-native/android/src/main/java/ai/onnxruntime/reactnative/OnnxruntimeModule.java +index fe59cef..41c1dd2 100644 +--- a/node_modules/onnxruntime-react-native/android/src/main/java/ai/onnxruntime/reactnative/OnnxruntimeModule.java ++++ b/node_modules/onnxruntime-react-native/android/src/main/java/ai/onnxruntime/reactnative/OnnxruntimeModule.java +@@ -39,6 +39,8 @@ import java.util.Set; + import java.util.stream.Collectors; + import java.util.stream.Stream; + ++import com.facebook.react.modules.blob.BlobModule; ++ + @RequiresApi(api = Build.VERSION_CODES.N) + public class OnnxruntimeModule extends ReactContextBaseJavaModule { + private static ReactApplicationContext reactContext; +@@ -165,6 +167,8 @@ public class OnnxruntimeModule extends ReactContextBaseJavaModule { + throw new Exception("Model is not loaded: " + key); + } + ++ BlobModule blobModule = reactContext.getNativeModule(BlobModule.class); ++ + RunOptions runOptions = parseRunOptions(options); + + long startTime = System.currentTimeMillis(); +@@ -217,7 +221,7 @@ public class OnnxruntimeModule extends ReactContextBaseJavaModule { + Log.d("Duration", "inference: " + duration); + + startTime = System.currentTimeMillis(); +- WritableMap resultMap = TensorHelper.createOutputTensor(result); ++ WritableMap resultMap = TensorHelper.createOutputTensor(blobModule, result); + duration = System.currentTimeMillis() - startTime; + Log.d("Duration", "createOutputTensor: " + duration); + diff --git a/node_modules/onnxruntime-react-native/android/src/main/java/ai/onnxruntime/reactnative/TensorHelper.java b/node_modules/onnxruntime-react-native/android/src/main/java/ai/onnxruntime/reactnative/TensorHelper.java -index 500141a..49b3abd 100644 +index 500141a..20c680f 100644 --- a/node_modules/onnxruntime-react-native/android/src/main/java/ai/onnxruntime/reactnative/TensorHelper.java +++ b/node_modules/onnxruntime-react-native/android/src/main/java/ai/onnxruntime/reactnative/TensorHelper.java -@@ -164,7 +164,11 @@ public class TensorHelper { +@@ -29,6 +29,8 @@ import java.util.Objects; + import java.util.stream.Collectors; + import java.util.stream.Stream; + ++import com.facebook.react.modules.blob.BlobModule; ++ + public class TensorHelper { + /** + * Supported tensor data type +@@ -80,7 +82,7 @@ public class TensorHelper { + * It creates an output map from an output tensor. + * a data array is encoded as base64 string. + */ +- public static WritableMap createOutputTensor(OrtSession.Result result) throws Exception { ++ public static WritableMap createOutputTensor(BlobModule blobModule, OrtSession.Result result) throws Exception { + WritableMap outputTensorMap = Arguments.createMap(); + + Iterator> iterator = result.iterator(); +@@ -115,8 +117,12 @@ public class TensorHelper { + } + outputTensor.putArray("data", dataArray); + } else { +- String data = createOutputTensor(onnxTensor); +- outputTensor.putString("data", data); ++ // Blob ++ byte[] bufferArray = createOutputTensor(onnxTensor); ++ String blobId = blobModule.store(bufferArray); ++ int size = bufferArray.length; ++ outputTensor.putString("data", blobId); ++ outputTensor.putInt("size", size); + } + + outputTensorMap.putMap(outputName, outputTensor); +@@ -164,7 +170,11 @@ public class TensorHelper { tensor = OnnxTensor.createTensor(ortEnvironment, buffer, dims, OnnxJavaType.UINT8); break; } @@ -29,3 +93,56 @@ index 500141a..49b3abd 100644 case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16: case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16: case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32: +@@ -177,7 +187,7 @@ public class TensorHelper { + return tensor; + } + +- private static String createOutputTensor(OnnxTensor onnxTensor) throws Exception { ++ private static byte[] createOutputTensor(OnnxTensor onnxTensor) throws Exception { + TensorInfo tensorInfo = onnxTensor.getInfo(); + ByteBuffer buffer = null; + +@@ -224,8 +234,7 @@ public class TensorHelper { + throw new IllegalStateException("Unexpected type: " + tensorInfo.onnxType.toString()); + } + +- String data = Base64.encodeToString(buffer.array(), Base64.DEFAULT); +- return data; ++ return buffer.array(); + } + + private static final Map JsTensorTypeToOnnxTensorTypeMap = +diff --git a/node_modules/onnxruntime-react-native/lib/backend.ts b/node_modules/onnxruntime-react-native/lib/backend.ts +index 4ebc364..7aee5a0 100644 +--- a/node_modules/onnxruntime-react-native/lib/backend.ts ++++ b/node_modules/onnxruntime-react-native/lib/backend.ts +@@ -4,6 +4,7 @@ + import {Buffer} from 'buffer'; + import {Backend, InferenceSession, SessionHandler, Tensor,} from 'onnxruntime-common'; + import {Platform} from 'react-native'; ++import {getArrayBufferForBlob} from 'react-native-blob-jsi-helper'; + + import {binding, Binding} from './binding'; + +@@ -98,7 +99,20 @@ class OnnxruntimeSessionHandler implements SessionHandler { + } + } + const input = this.encodeFeedsType(feeds); +- const results: Binding.ReturnType = await this.#inferenceSession.run(this.#key, input, outputNames, options); ++ let results: Binding.ReturnType = await this.#inferenceSession.run(this.#key, input, outputNames, options); ++ results = Object.entries(results).reduce((acc, [name, result]) => { ++ acc[name] = { ++ ...result, ++ data: getArrayBufferForBlob({ ++ _data: { ++ blobId: result.data, ++ offset: 0, ++ size: result.size, ++ } ++ }), ++ }; ++ return acc; ++ }, {}) + const output = this.decodeReturnType(results); + return output; + } From a9c24cdd23f25ae001dff610e9bd241401a41623 Mon Sep 17 00:00:00 2001 From: Jhen Date: Sun, 21 May 2023 10:53:33 +0800 Subject: [PATCH 3/3] add ios patch --- patches/onnxruntime-react-native+1.14.0.patch | 46 +++++++++++++++++++ 1 file changed, 46 insertions(+) diff --git a/patches/onnxruntime-react-native+1.14.0.patch b/patches/onnxruntime-react-native+1.14.0.patch index 117f3c2..4a35452 100644 --- a/patches/onnxruntime-react-native+1.14.0.patch +++ b/patches/onnxruntime-react-native+1.14.0.patch @@ -112,6 +112,52 @@ index 500141a..20c680f 100644 } private static final Map JsTensorTypeToOnnxTensorTypeMap = +diff --git a/node_modules/onnxruntime-react-native/ios/TensorHelper.mm b/node_modules/onnxruntime-react-native/ios/TensorHelper.mm +index 00c1c79..ed6c81c 100644 +--- a/node_modules/onnxruntime-react-native/ios/TensorHelper.mm ++++ b/node_modules/onnxruntime-react-native/ios/TensorHelper.mm +@@ -2,6 +2,8 @@ + // Licensed under the MIT License. + + #import "TensorHelper.h" ++#import ++#import + #import + + @implementation TensorHelper +@@ -109,8 +111,11 @@ + (NSDictionary *)createOutputTensor:(const std::vector &)outputNa + } + outputTensor[@"data"] = buffer; + } else { +- NSString *data = [self createOutputTensor:value]; +- outputTensor[@"data"] = data; ++ NSData *buffer = [self createOutputTensor:value]; ++ RCTBlobManager* blobManager = [[RCTBridge currentBridge] moduleForClass:RCTBlobManager.class]; ++ NSString* blobId = [blobManager store:buffer]; ++ outputTensor[@"data"] = blobId; ++ outputTensor[@"size"] = [NSNumber numberWithUnsignedInteger:buffer.length]; + } + + outputTensorMap[[NSString stringWithUTF8String:outputName]] = outputTensor; +@@ -170,15 +175,15 @@ + (NSDictionary *)createOutputTensor:(const std::vector &)outputNa + } + } + +-template static NSString *createOutputTensorT(const Ort::Value &tensor) { ++template static NSData *createOutputTensorT(const Ort::Value &tensor) { + const auto data = tensor.GetTensorData(); + NSData *buffer = [NSData dataWithBytesNoCopy:(void *)data + length:tensor.GetTensorTypeAndShapeInfo().GetElementCount() * sizeof(T) + freeWhenDone:false]; +- return [buffer base64EncodedStringWithOptions:0]; ++ return buffer; + } + +-+ (NSString *)createOutputTensor:(const Ort::Value &)tensor { +++ (NSData *)createOutputTensor:(const Ort::Value &)tensor { + ONNXTensorElementDataType tensorType = tensor.GetTensorTypeAndShapeInfo().GetElementType(); + + switch (tensorType) { diff --git a/node_modules/onnxruntime-react-native/lib/backend.ts b/node_modules/onnxruntime-react-native/lib/backend.ts index 4ebc364..7aee5a0 100644 --- a/node_modules/onnxruntime-react-native/lib/backend.ts