diff --git a/packages/transformers/jest.config.mjs b/packages/transformers/jest.config.mjs index 7f97a64b0..4127621ba 100644 --- a/packages/transformers/jest.config.mjs +++ b/packages/transformers/jest.config.mjs @@ -84,7 +84,10 @@ export default { // ], // A map from regular expressions to module names or to arrays of module names that allow to stub out resources with a single module - // moduleNameMapper: {}, + moduleNameMapper: { + 'native-universal-fs': 'node:fs/promises', + 'react-native': new URL('./tests/react-native.mock.js', import.meta.url).pathname, + }, // An array of regexp pattern strings, matched against all module paths before considered 'visible' to the module loader // modulePathIgnorePatterns: [], diff --git a/packages/transformers/package.json b/packages/transformers/package.json index ce6cbdf20..2c56ffa3f 100644 --- a/packages/transformers/package.json +++ b/packages/transformers/package.json @@ -1,7 +1,7 @@ { "name": "@huggingface/transformers", "version": "4.0.0-next.8", - "description": "State-of-the-art Machine Learning for the web. Run 🤗 Transformers directly in your browser, with no need for a server!", + "description": "State-of-the-art Machine Learning for the web. Run \ud83e\udd17 Transformers directly in your browser, with no need for a server!", "main": "./dist/transformers.node.cjs", "types": "./types/transformers.d.ts", "type": "module", @@ -9,16 +9,19 @@ "node": { "import": { "types": "./types/transformers.d.ts", - "default": "./dist/transformers.node.mjs" + "default": "./dist/transformers.node.mjs", + "react-native": "./dist/transformers.native.mjs" }, "require": { "types": "./types/transformers.d.ts", - "default": "./dist/transformers.node.cjs" + "default": "./dist/transformers.node.cjs", + "react-native": "./dist/transformers.native.cjs" } }, "default": { "types": "./types/transformers.d.ts", - "default": "./dist/transformers.web.js" + "default": "./dist/transformers.web.js", + "react-native": "./dist/transformers.native.mjs" } }, "scripts": { @@ -59,7 +62,9 @@ "@huggingface/tokenizers": "^0.1.3", "onnxruntime-node": "1.24.3", "onnxruntime-web": "1.25.0-dev.20260307-d626b568e0", - "sharp": "^0.34.5" + "sharp": "^0.34.5", + "native-universal-fs": "^0.2.0", + "onnxruntime-react-native": "1.22.0" }, "devDependencies": { "@types/jest": "^30.0.0", @@ -69,7 +74,8 @@ "jest": "^30.2.0", "jest-environment-node": "^30.2.0", "jsdoc-to-markdown": "^9.1.3", - "typescript": "5.9.3" + "typescript": "5.9.3", + "path-browserify": "^1.0.1" }, "files": [ "src", @@ -83,5 +89,6 @@ "access": "public" }, "jsdelivr": "./dist/transformers.min.js", - "unpkg": "./dist/transformers.min.js" + "unpkg": "./dist/transformers.min.js", + "react-native": "./dist/transformers.native.mjs" } diff --git a/packages/transformers/scripts/build/constants.mjs b/packages/transformers/scripts/build/constants.mjs index f8b2a3373..23430e625 100644 --- a/packages/transformers/scripts/build/constants.mjs +++ b/packages/transformers/scripts/build/constants.mjs @@ -12,6 +12,8 @@ export const NODE_EXTERNAL_MODULES = [ export const WEB_IGNORE_MODULES = ["onnxruntime-node", "sharp", "fs", "path", "url", "stream", "stream/promises"]; export const WEB_EXTERNAL_MODULES = ["onnxruntime-common", "onnxruntime-web"]; +export const REACT_NATIVE_IGNORE_MODULES = ["onnxruntime-node", "sharp", "fs", "stream", "stream/promises", "url"]; +export const REACT_NATIVE_EXTERNAL_MODULES = ["onnxruntime-common", "onnxruntime-react-native", "react-native"]; const __dirname = path.dirname(fileURLToPath(import.meta.url)); export const ROOT_DIR = path.join(__dirname, "../.."); diff --git a/packages/transformers/scripts/build/targets.mjs b/packages/transformers/scripts/build/targets.mjs index fcbf4383d..1f1b22ffa 100644 --- a/packages/transformers/scripts/build/targets.mjs +++ b/packages/transformers/scripts/build/targets.mjs @@ -1,4 +1,4 @@ -import { NODE_IGNORE_MODULES, NODE_EXTERNAL_MODULES, WEB_IGNORE_MODULES, WEB_EXTERNAL_MODULES } from "./constants.mjs"; +import { NODE_IGNORE_MODULES, NODE_EXTERNAL_MODULES, WEB_IGNORE_MODULES, WEB_EXTERNAL_MODULES, REACT_NATIVE_IGNORE_MODULES, REACT_NATIVE_EXTERNAL_MODULES } from "./constants.mjs"; /** * Build target configuration @@ -26,6 +26,28 @@ export const BUILD_TARGETS = [ usePostBuild: false, }, }, + { + name: "React Native Build (ESM)", + config: { + name: ".native", + suffix: ".mjs", + format: "esm", + ignoreModules: REACT_NATIVE_IGNORE_MODULES, + externalModules: REACT_NATIVE_EXTERNAL_MODULES, + usePostBuild: false, + }, + }, + { + name: "React Native Build (CJS)", + config: { + name: ".native", + suffix: ".cjs", + format: "cjs", + ignoreModules: REACT_NATIVE_IGNORE_MODULES, + externalModules: REACT_NATIVE_EXTERNAL_MODULES, + usePostBuild: false, + }, + }, { name: "Node Build (ESM)", config: { diff --git a/packages/transformers/src/backends/onnx.js b/packages/transformers/src/backends/onnx.js index 13b1a7482..c715118dc 100644 --- a/packages/transformers/src/backends/onnx.js +++ b/packages/transformers/src/backends/onnx.js @@ -22,6 +22,7 @@ import { env, apis, LogLevel } from '../env.js'; // In either case, we select the default export if it exists, otherwise we use the named export. import * as ONNX_NODE from 'onnxruntime-node'; import * as ONNX_WEB from 'onnxruntime-web/webgpu'; +import * as ONNX_REACT_NATIVE from 'onnxruntime-react-native'; import { loadWasmBinary, loadWasmFactory } from './utils/cacheWasm.js'; import { isBlobURL, toAbsoluteURL } from '../utils/hub/utils.js'; import { logger } from '../utils/logger.js'; @@ -41,6 +42,8 @@ const DEVICE_TO_EXECUTION_PROVIDER_MAPPING = Object.freeze({ cuda: 'cuda', // CUDA dml: 'dml', // DirectML coreml: 'coreml', // CoreML + xnnpack: 'xnnpack', // XNNPACK + nnapi: 'nnapi', // NNAPI webnn: { name: 'webnn', deviceType: 'cpu' }, // WebNN (default) 'webnn-npu': { name: 'webnn', deviceType: 'npu' }, // WebNN NPU @@ -105,6 +108,10 @@ const ORT_SYMBOL = Symbol.for('onnxruntime'); if (ORT_SYMBOL in globalThis) { // If the JS runtime exposes their own ONNX runtime, use it ONNX = globalThis[ORT_SYMBOL]; +} else if (apis.IS_REACT_NATIVE_ENV) { + ONNX = ONNX_REACT_NATIVE.default ?? ONNX_REACT_NATIVE; + supportedDevices.push('xnnpack', 'cpu', 'nnapi', 'coreml'); + defaultDevices = ['xnnpack', 'cpu']; } else if (apis.IS_NODE_ENV) { ONNX = ONNX_NODE; @@ -167,7 +174,7 @@ export function deviceToExecutionProviders(device = null) { case 'auto': return supportedDevices; case 'gpu': - return supportedDevices.filter((x) => ['webgpu', 'cuda', 'dml', 'webnn-gpu'].includes(x)); + return supportedDevices.filter((x) => ['webgpu', 'cuda', 'dml', 'webnn-gpu', 'xnnpack', 'nnapi', 'coreml'].includes(x)); } if (supportedDevices.includes(device)) { diff --git a/packages/transformers/src/env.js b/packages/transformers/src/env.js index c3fd2e06a..91d85533d 100644 --- a/packages/transformers/src/env.js +++ b/packages/transformers/src/env.js @@ -22,16 +22,19 @@ * @module env */ +import * as NativeFS from 'native-universal-fs'; import fs from 'node:fs'; -import path from 'node:path'; +import nodePath from 'node:path'; +import path from 'path-browserify'; import url from 'node:url'; const VERSION = '4.0.0-next.8'; const HAS_SELF = typeof self !== 'undefined'; +const IS_REACT_NATIVE_ENV = typeof navigator !== 'undefined' && navigator.product === 'ReactNative'; -const IS_FS_AVAILABLE = !isEmpty(fs); -const IS_PATH_AVAILABLE = !isEmpty(path); +const IS_FS_AVAILABLE = !isEmpty(fs) || !isEmpty(NativeFS); +const IS_PATH_AVAILABLE = !isEmpty(nodePath) || !isEmpty(path); const IS_WEB_CACHE_AVAILABLE = HAS_SELF && 'caches' in self; // Runtime detection @@ -105,6 +108,9 @@ export const apis = Object.freeze({ /** Whether we are running in a web-like environment (browser, web worker, or Deno web runtime) */ IS_WEB_ENV, + /** Whether we are running in a React Native environment */ + IS_REACT_NATIVE_ENV, + /** Whether we are running in a service worker environment */ IS_SERVICE_WORKER_ENV, @@ -145,16 +151,18 @@ export const apis = Object.freeze({ const RUNNING_LOCALLY = IS_FS_AVAILABLE && IS_PATH_AVAILABLE; let dirname__ = './'; -if (RUNNING_LOCALLY) { +if (IS_REACT_NATIVE_ENV) { + dirname__ = NativeFS.DocumentDirectoryPath; +} else if (RUNNING_LOCALLY) { // NOTE: We wrap `import.meta` in a call to `Object` to prevent Webpack from trying to bundle it in CommonJS. // Although we get the warning: "Accessing import.meta directly is unsupported (only property access or destructuring is supported)", // it is safe to ignore since the bundled value (`{}`) isn't used for CommonJS environments (we use __dirname instead). const _import_meta_url = Object(import.meta).url; if (_import_meta_url) { - dirname__ = path.dirname(path.dirname(url.fileURLToPath(_import_meta_url))); // ESM + dirname__ = nodePath.dirname(nodePath.dirname(url.fileURLToPath(_import_meta_url))); // ESM } else if (typeof __dirname !== 'undefined') { - dirname__ = path.dirname(__dirname); // CommonJS + dirname__ = nodePath.dirname(__dirname); // CommonJS } } @@ -262,6 +270,7 @@ export const env = { allowLocalModels: !(IS_BROWSER_ENV || IS_WEBWORKER_ENV || IS_DENO_WEB_RUNTIME), // Default to true for non-web environments, false for web environments localModelPath: localModelPath, useFS: IS_FS_AVAILABLE, + rnUseCanvas: true, /////////////////// Cache settings /////////////////// useBrowserCache: IS_WEB_CACHE_AVAILABLE, diff --git a/packages/transformers/src/utils/audio.js b/packages/transformers/src/utils/audio.js index be64a13de..18b5943c6 100644 --- a/packages/transformers/src/utils/audio.js +++ b/packages/transformers/src/utils/audio.js @@ -12,8 +12,17 @@ import { FFT, max } from './maths.js'; import { calculateReflectOffset } from './core.js'; import { saveBlob } from './io.js'; import { Tensor, matmul } from './tensor.js'; +import { Buffer } from 'buffer'; +import fs from 'node:fs'; +import * as NativeFS from 'native-universal-fs'; import { logger } from './logger.js'; +/** + * Helper function to read audio from a path/URL. + * @param {string|URL} url The path/URL to load the audio from. + * @param {number} sampling_rate The sampling rate to use when decoding the audio. + * @returns {Promise} The decoded audio as a `Float32Array`. + */ /** * Helper function to read audio from a path/URL. * @param {string|URL} url The path/URL to load the audio from. @@ -856,6 +865,14 @@ export class RawAudio { * @returns {Promise} */ async save(path) { - return saveBlob(path, this.toBlob()); + if (apis.IS_REACT_NATIVE_ENV) { + const buffer = await this.toBlob().arrayBuffer(); + await NativeFS.writeFile(path, Buffer.from(buffer).toString('base64'), 'base64'); + return; + } + if (apis.IS_WEB_ENV || apis.IS_FS_AVAILABLE) { + return saveBlob(path, this.toBlob()); + } + throw new Error('Unable to save because filesystem is disabled in this environment.'); } } diff --git a/packages/transformers/src/utils/cache/FileCache.js b/packages/transformers/src/utils/cache/FileCache.js index 9e81d4a8e..2aacc03cf 100644 --- a/packages/transformers/src/utils/cache/FileCache.js +++ b/packages/transformers/src/utils/cache/FileCache.js @@ -1,5 +1,8 @@ +import * as NativeFS from 'native-universal-fs'; +import { Buffer } from 'buffer'; import fs from 'node:fs'; -import path from 'node:path'; +import nodePath from 'node:path'; +import path from 'path-browserify'; import { FileResponse } from '../hub/FileResponse.js'; import { Random } from '../random.js'; @@ -28,13 +31,8 @@ export class FileCache { */ async match(request) { let filePath = path.join(this.path, request); - let file = new FileResponse(filePath); - - if (file.exists) { - return file; - } else { - return undefined; - } + let file = apis.IS_REACT_NATIVE_ENV ? await FileResponse.create(filePath) : new FileResponse(filePath); + return file.exists ? file : undefined; } /** @@ -60,44 +58,45 @@ export class FileCache { const total = parseInt(contentLength ?? '0'); let loaded = 0; - await fs.promises.mkdir(path.dirname(filePath), { recursive: true }); - const fileStream = fs.createWriteStream(tmpPath); const reader = response.body.getReader(); - - while (true) { - const { done, value } = await reader.read(); - if (done) { - break; + if (apis.IS_REACT_NATIVE_ENV) { + await NativeFS.mkdir(path.dirname(filePath)); + const chunks = []; + while (true) { + const { done, value } = await reader.read(); + if (done) break; + chunks.push(value); + loaded += value.length; + const progress = total ? (loaded / total) * 100 : 0; + progress_callback?.({ progress, loaded, total }); } - - await new Promise((resolve, reject) => { - fileStream.write(value, (err) => { - if (err) { - reject(err); - return; - } - resolve(); + const combined = new Uint8Array(loaded); + let offset = 0; + for (const chunk of chunks) { combined.set(chunk, offset); offset += chunk.length; } + await NativeFS.writeFile(tmpPath, Buffer.from(combined).toString('base64'), 'base64'); + await NativeFS.moveFile(tmpPath, filePath); + } else { + await fs.promises.mkdir(nodePath.dirname(filePath), { recursive: true }); + const fileStream = fs.createWriteStream(tmpPath); + while (true) { + const { done, value } = await reader.read(); + if (done) break; + await new Promise((resolve, reject) => { + fileStream.write(value, (err) => err ? reject(err) : resolve()); }); + loaded += value.length; + const progress = total ? (loaded / total) * 100 : 0; + progress_callback?.({ progress, loaded, total }); + } + await new Promise((resolve, reject) => { + fileStream.close((err) => (err ? reject(err) : resolve())); }); - - loaded += value.length; - const progress = total ? (loaded / total) * 100 : 0; - - progress_callback?.({ progress, loaded, total }); + await fs.promises.rename(tmpPath, filePath); } - - await new Promise((resolve, reject) => { - fileStream.close((err) => (err ? reject(err) : resolve())); - }); - - // Atomically move the completed temp file to the final path so that - // concurrent readers (other processes or other in-process calls) - // never observe a partially-written file. - await fs.promises.rename(tmpPath, filePath); } catch (error) { // Clean up the temp file if an error occurred during download try { - await fs.promises.unlink(tmpPath); + if (apis.IS_REACT_NATIVE_ENV) await NativeFS.unlink(tmpPath); else await fs.promises.unlink(tmpPath); } catch {} throw error; } @@ -112,7 +111,7 @@ export class FileCache { let filePath = path.join(this.path, request); try { - await fs.promises.unlink(filePath); + if (apis.IS_REACT_NATIVE_ENV) await NativeFS.unlink(filePath); else await fs.promises.unlink(filePath); return true; } catch (error) { // File doesn't exist or couldn't be deleted diff --git a/packages/transformers/src/utils/devices.js b/packages/transformers/src/utils/devices.js index 613e7e556..ea2a3c5ca 100644 --- a/packages/transformers/src/utils/devices.js +++ b/packages/transformers/src/utils/devices.js @@ -12,13 +12,15 @@ export const DEVICE_TYPES = Object.freeze({ cuda: 'cuda', // CUDA dml: 'dml', // DirectML coreml: 'coreml', // CoreML + xnnpack: 'xnnpack', // XNNPACK + nnapi: 'nnapi', // NNAPI webnn: 'webnn', // WebNN (default) 'webnn-npu': 'webnn-npu', // WebNN NPU 'webnn-gpu': 'webnn-gpu', // WebNN GPU 'webnn-cpu': 'webnn-cpu', // WebNN CPU }); -const DEFAULT_DEVICE = apis.IS_NODE_ENV ? 'cpu' : 'wasm'; +const DEFAULT_DEVICE = apis.IS_NODE_ENV ? 'cpu' : (apis.IS_REACT_NATIVE_ENV ? 'xnnpack' : 'wasm'); /** * @typedef {keyof typeof DEVICE_TYPES} DeviceType diff --git a/packages/transformers/src/utils/hub.js b/packages/transformers/src/utils/hub.js index cc15d792c..37b557af9 100755 --- a/packages/transformers/src/utils/hub.js +++ b/packages/transformers/src/utils/hub.js @@ -59,7 +59,13 @@ export { MAX_EXTERNAL_DATA_CHUNKS } from './hub/constants.js'; */ export async function getFile(urlOrPath) { if (env.useFS && !isValidUrl(urlOrPath, ['http:', 'https:', 'blob:'])) { - return new FileResponse( + return apis.IS_REACT_NATIVE_ENV ? await FileResponse.create( + urlOrPath instanceof URL + ? urlOrPath.protocol === 'file:' + ? urlOrPath.pathname + : urlOrPath.toString() + : urlOrPath, + ) : new FileResponse( urlOrPath instanceof URL ? urlOrPath.protocol === 'file:' ? urlOrPath.pathname @@ -67,7 +73,7 @@ export async function getFile(urlOrPath) { : urlOrPath, ); } else { - return env.fetch(urlOrPath, { + return (apis.IS_REACT_NATIVE_ENV ? fetch : env.fetch)(urlOrPath, { headers: getFetchHeaders(urlOrPath), }); } diff --git a/packages/transformers/src/utils/hub/FileResponse.js b/packages/transformers/src/utils/hub/FileResponse.js index d0102dd79..8d7f8eb5f 100644 --- a/packages/transformers/src/utils/hub/FileResponse.js +++ b/packages/transformers/src/utils/hub/FileResponse.js @@ -1,5 +1,9 @@ +import * as NativeFS from 'native-universal-fs'; +import { Buffer } from 'buffer'; import fs from 'node:fs'; +import { apis } from '../../env.js'; + /** * Mapping from file extensions to MIME types. */ @@ -22,9 +26,10 @@ export class FileResponse { */ constructor(filePath) { this.filePath = filePath; + this.url = String(filePath).startsWith('file://') ? String(filePath) : `file://${filePath}`; this.headers = new Headers(); - this.exists = fs.existsSync(filePath); + this.exists = apis.IS_REACT_NATIVE_ENV ? false : fs.existsSync(filePath); if (this.exists) { this.status = 200; this.statusText = 'OK'; @@ -67,6 +72,25 @@ export class FileResponse { * Clone the current FileResponse object. * @returns {FileResponse} A new FileResponse object with the same properties as the current object. */ + static async create(filePath) { + const response = new FileResponse(filePath); + if (apis.IS_REACT_NATIVE_ENV) { + response.exists = await NativeFS.exists(String(response.url)); + if (response.exists) { + response.status = 200; + response.statusText = 'OK'; + const stats = await NativeFS.stat(String(response.url)); + response.headers.set('content-length', String(stats.size)); + response.updateContentType(); + } else { + response.status = 404; + response.statusText = 'Not Found'; + response.body = null; + } + } + return response; + } + clone() { let response = new FileResponse(this.filePath); response.exists = this.exists; @@ -83,6 +107,10 @@ export class FileResponse { * @throws {Error} If the file cannot be read. */ async arrayBuffer() { + if (apis.IS_REACT_NATIVE_ENV) { + const data = Buffer.from(await NativeFS.readFile(String(this.url), 'base64'), 'base64'); + return /** @type {ArrayBuffer} */ (data.buffer.slice(data.byteOffset, data.byteOffset + data.byteLength)); + } const data = await fs.promises.readFile(this.filePath); return /** @type {ArrayBuffer} */ (data.buffer); } @@ -94,7 +122,7 @@ export class FileResponse { * @throws {Error} If the file cannot be read. */ async blob() { - const data = await fs.promises.readFile(this.filePath); + const data = apis.IS_REACT_NATIVE_ENV ? Buffer.from(await NativeFS.readFile(String(this.url), 'base64'), 'base64') : await fs.promises.readFile(this.filePath); return new Blob([/** @type {any} */ (data)], { type: this.headers.get('content-type') }); } @@ -105,6 +133,9 @@ export class FileResponse { * @throws {Error} If the file cannot be read. */ async text() { + if (apis.IS_REACT_NATIVE_ENV) { + return await NativeFS.readFile(String(this.url), 'utf8'); + } return await fs.promises.readFile(this.filePath, 'utf8'); } diff --git a/packages/transformers/src/utils/hub/utils.js b/packages/transformers/src/utils/hub/utils.js index e05ee27ee..36f86cb3e 100644 --- a/packages/transformers/src/utils/hub/utils.js +++ b/packages/transformers/src/utils/hub/utils.js @@ -1,3 +1,4 @@ +import { apis } from '../../env.js'; import { ERROR_MAPPING, REPO_ID_REGEX } from './constants.js'; import { logger } from '../logger.js'; @@ -29,6 +30,16 @@ export function pathJoin(...parts) { * @returns {boolean} True if the string is a valid URL, false otherwise. */ export function isValidUrl(string, protocols = null, validHosts = null) { + if (apis.IS_REACT_NATIVE_ENV) { + const str = String(string); + if (!/^\w+:\/\//.test(str)) return false; + if (protocols && !protocols.some((protocol) => str.startsWith(protocol))) return false; + if (validHosts) { + const match = str.match(/^(\w+\:)\/\/(([^:\/?#]*)(?:\:([0-9]+))?)/); + if (!match || !validHosts.includes(match[3])) return false; + } + return true; + } let url; try { url = new URL(string); diff --git a/packages/transformers/src/utils/image.js b/packages/transformers/src/utils/image.js index d7d8c7fcd..d37f71be8 100644 --- a/packages/transformers/src/utils/image.js +++ b/packages/transformers/src/utils/image.js @@ -9,8 +9,9 @@ import { isNullishDimension } from './core.js'; import { getFile } from './hub.js'; -import { apis } from '../env.js'; +import { apis, env } from '../env.js'; import { Tensor } from './tensor.js'; +import { interpolate_data, permute_data } from './maths.js'; import { saveBlob } from './io.js'; // Will be empty (or not used) if running in browser or web-worker @@ -20,7 +21,39 @@ import { logger } from './logger.js'; let createCanvasFunction; let ImageDataClass; let loadImageFunction; -if (apis.IS_WEB_ENV) { +if (apis.IS_REACT_NATIVE_ENV) { + // Optional Support gcanvas or skia with web polyfill for better performance + const offscreenCanvasExists = typeof OffscreenCanvas !== 'undefined'; + if (typeof Image !== 'undefined' && (typeof document !== 'undefined' || offscreenCanvasExists)) { + createCanvasFunction = (/** @type {number} */ width, /** @type {number} */ height) => { + if (offscreenCanvasExists) { + return new OffscreenCanvas(width, height); + } else { + const canvas = document.createElement('canvas'); + canvas.width = width; + canvas.height = height; + return canvas; + } + }; + loadImageFunction = async (/**@type {URL|string}*/url) => + await new Promise((resolve, reject) => { + const image = new Image(); + image.onload = () => { + const canvas = createCanvasFunction(image.width, image.height); + const ctx = canvas.getContext('2d'); + ctx.drawImage(image, 0, 0); + const { data } = ctx.getImageData(0, 0, image.width, image.height); + resolve(new RawImage(data, image.width, image.height, 4)); + // @ts-expect-error TS2339 + image.dispose?.(); + canvas.dispose?.(); + } + image.onerror = reject; + image.src = String(url); + }); + ImageDataClass = global.ImageData; + } +} else if (apis.IS_WEB_ENV) { // Running in browser or web-worker createCanvasFunction = (/** @type {number} */ width, /** @type {number} */ height) => { if (!self.OffscreenCanvas) { @@ -150,12 +183,20 @@ export class RawImage { * @returns {Promise} The image object. */ static async fromURL(url) { - const response = await getFile(url); - if (response.status !== 200) { - throw new Error(`Unable to read image from "${url}" (${response.status} ${response.statusText})`); + if (apis.IS_REACT_NATIVE_ENV) { + if (env.rnUseCanvas && loadImageFunction) { + return await loadImageFunction(url); + } else { + throw new Error('fromURL() is not supported React Native without OffscreenCanvas.'); + } + } else { + const response = await getFile(url); + if (response.status !== 200) { + throw new Error(`Unable to read image from "${url}" (${response.status} ${response.statusText})`); + } + const blob = await response.blob(); + return this.fromBlob(blob); } - const blob = await response.blob(); - return this.fromBlob(blob); } /** @@ -378,7 +419,40 @@ export class RawImage { height = (width / this.width) * this.height; } - if (apis.IS_WEB_ENV) { + if (apis.IS_REACT_NATIVE_ENV) { + if (createCanvasFunction !== undefined && env.rnUseCanvas) { + // Running in environment with canvas + let canvas = createCanvasFunction(this.width, this.height); + let ctx = canvas.getContext('2d'); + let imageData = this.toImageData(); + ctx.putImageData(imageData, 0, 0); + ctx.drawImage(canvas, 0, 0, this.width, this.height, 0, 0, width, height); + let newImageData = ctx.getImageData(0, 0, width, height); + const resized = new RawImage(newImageData.data, width, height, 4); + canvas.dispose?.(); + return resized.convert(this.channels); + } else { + // Running in environment without canvas + // WHC -> CHW + const [trsnsposed] = permute_data( + this.data, + [this.width, this.height, this.channels], + [2, 0, 1] + ); + const resized = interpolate_data( + trsnsposed, + [this.channels, this.height, this.width], + [height, width] + ); + // CHW -> WHC + const [newData] = permute_data( + resized, + [this.channels, height, width], + [1, 2, 0] + ); + return new RawImage(newData, width, height, this.channels); + } + } else if (apis.IS_WEB_ENV) { // TODO use `resample` in browser environment // Store number of channels before resizing @@ -452,7 +526,38 @@ export class RawImage { return this; } - if (apis.IS_WEB_ENV) { + if (apis.IS_REACT_NATIVE_ENV) { + if (createCanvasFunction !== undefined && env.rnUseCanvas) { + // Running in environment with canvas + let newWidth = this.width + left + right; + let newHeight = this.height + top + bottom; + let canvas = createCanvasFunction(newWidth, newHeight); + let ctx = canvas.getContext('2d'); + let imageData = this.toImageData(); + ctx.putImageData(imageData, left, top); + let newImageData = ctx.getImageData(0, 0, newWidth, newHeight); + const padded = new RawImage(newImageData.data, newWidth, newHeight, 4); + canvas.dispose?.(); + return padded.convert(this.channels); + } else { + // Running in environment without canvas + const channels = this.channels; + const data = this.data; + const width = this.width + left + right; + const height = this.height + top + bottom; + const paddedData = new Uint8ClampedArray(width * height * channels); + for (let i = 0; i < data.length; i += channels) { + const x = Math.floor(i / channels) % this.width; + const y = Math.floor(i / channels / this.width); + const pixelIndex = (y * width + x) * channels; + for (let j = 0; j < channels; j++) { + paddedData[pixelIndex + j] = data[i + j]; + } + } + return new RawImage(paddedData, width, height, channels); + } + + } else if (apis.IS_WEB_ENV) { // Store number of channels before padding const numChannels = this.channels; @@ -494,7 +599,34 @@ export class RawImage { const crop_width = x_max - x_min + 1; const crop_height = y_max - y_min + 1; - if (apis.IS_WEB_ENV) { + if (apis.IS_REACT_NATIVE_ENV) { + if (createCanvasFunction !== undefined && env.rnUseCanvas) { + // Running in environment with canvas + let canvas = createCanvasFunction(crop_width, crop_height); + let ctx = canvas.getContext('2d'); + let imageData = this.toImageData(); + ctx.putImageData(imageData, -x_min, -y_min); + let newImageData = ctx.getImageData(0, 0, crop_width, crop_height); + const cropped = new RawImage(newImageData.data, crop_width, crop_height, 4); + canvas.dispose?.(); + return cropped.convert(this.channels); + } else { + // Running in environment without canvas + let channels = this.channels; + let data = this.data; + let croppedData = new Uint8ClampedArray(crop_width * crop_height * channels); + for (let i = 0; i < croppedData.length; i += channels) { + const x = Math.floor(i / channels) % crop_width; + const y = Math.floor(i / channels / crop_width); + const pixelIndex = ((y + y_min) * this.width + (x + x_min)) * channels; + for (let j = 0; j < channels; j++) { + croppedData[i + j] = data[pixelIndex + j]; + } + } + return new RawImage(croppedData, crop_width, crop_height, channels); + } + + } else if (apis.IS_WEB_ENV) { // Store number of channels before resizing const numChannels = this.channels; @@ -541,7 +673,13 @@ export class RawImage { const width_offset = (this.width - crop_width) / 2; const height_offset = (this.height - crop_height) / 2; - if (apis.IS_WEB_ENV) { + if (apis.IS_REACT_NATIVE_ENV) { + return this.crop([ + width_offset, height_offset, + width_offset + crop_width - 1, height_offset + crop_height - 1 + ]); + + } else if (apis.IS_WEB_ENV) { // Store number of channels before resizing const numChannels = this.channels; @@ -648,6 +786,16 @@ export class RawImage { } } + toImageData() { + if (apis.IS_REACT_NATIVE_ENV && ImageDataClass === undefined) + throw new Error('toImageData() is only supported in browser environments.'); + // Clone, and convert data to RGBA before create ImageData object. + // This is because the ImageData API only supports RGBA + let cloned = this.clone().rgba(); + + return new ImageDataClass(cloned.data, cloned.width, cloned.height); + } + async toBlob(type = 'image/png', quality = 1) { if (!apis.IS_WEB_ENV) { throw new Error('toBlob() is only supported in browser environments.'); @@ -672,7 +820,7 @@ export class RawImage { } toCanvas() { - if (!apis.IS_WEB_ENV) { + if ((!apis.IS_WEB_ENV && !apis.IS_REACT_NATIVE_ENV) || !createCanvasFunction) { throw new Error('toCanvas() is only supported in browser environments.'); } @@ -773,7 +921,13 @@ export class RawImage { * @returns {Promise} */ async save(path) { - if (apis.IS_WEB_ENV) { + const extension = path.split('.').pop().toLowerCase(); + const mime = CONTENT_TYPE_MAP.get(extension) ?? 'image/png'; + + if (apis.IS_REACT_NATIVE_ENV) { + throw new Error('save() is not supported in React Native environments.'); + + } else if (apis.IS_WEB_ENV) { if (apis.IS_WEBWORKER_ENV) { throw new Error('Unable to save an image from a Web Worker.'); } @@ -798,7 +952,7 @@ export class RawImage { * @returns {import('sharp').Sharp} The Sharp instance. */ toSharp() { - if (apis.IS_WEB_ENV) { + if (apis.IS_WEB_ENV || apis.IS_REACT_NATIVE_ENV) { throw new Error('toSharp() is only supported in server-side environments.'); } diff --git a/packages/transformers/src/utils/io.js b/packages/transformers/src/utils/io.js index 4aa868391..3f48295c4 100644 --- a/packages/transformers/src/utils/io.js +++ b/packages/transformers/src/utils/io.js @@ -1,3 +1,5 @@ +import * as NativeFS from 'native-universal-fs'; +import { Buffer } from 'buffer'; import fs from 'node:fs'; import { Readable } from 'node:stream'; import { pipeline as pipe } from 'node:stream/promises'; @@ -33,6 +35,10 @@ export async function saveBlob(path, blob) { // Revoke the Object URL to free up memory URL.revokeObjectURL(dataURL); + } else if (apis.IS_REACT_NATIVE_ENV) { + const arrayBuffer = await blob.arrayBuffer(); + const base64 = Buffer.from(arrayBuffer).toString('base64'); + await NativeFS.writeFile(path, base64, 'base64'); } else if (apis.IS_FS_AVAILABLE) { // Convert Blob to a Node.js Readable Stream const webStream = blob.stream(); diff --git a/packages/transformers/src/utils/model-loader.js b/packages/transformers/src/utils/model-loader.js index 599aefe9d..3252b8745 100644 --- a/packages/transformers/src/utils/model-loader.js +++ b/packages/transformers/src/utils/model-loader.js @@ -45,7 +45,7 @@ export async function getCoreModelFile(pretrained_model_name_or_path, fileName, const baseName = `${fileName}${suffix}.onnx`; const fullPath = `${options.subfolder ?? ''}/${baseName}`; - return await getModelFile(pretrained_model_name_or_path, fullPath, true, options, apis.IS_NODE_ENV); + return await getModelFile(pretrained_model_name_or_path, fullPath, true, options, apis.IS_NODE_ENV || apis.IS_REACT_NATIVE_ENV); } /** @@ -68,7 +68,7 @@ export async function getModelDataFiles( session_options = {}, ) { const baseName = `${fileName}${suffix}.onnx`; - const return_path = apis.IS_NODE_ENV; + const return_path = apis.IS_NODE_ENV || apis.IS_REACT_NATIVE_ENV; /** @type {Promise[]} */ let externalDataPromises = []; diff --git a/packages/transformers/tests/react-native.mock.js b/packages/transformers/tests/react-native.mock.js new file mode 100644 index 000000000..d38d2aa4b --- /dev/null +++ b/packages/transformers/tests/react-native.mock.js @@ -0,0 +1 @@ +// empty mock for non-react-native environments