diff --git a/apps/llm/app/_layout.tsx b/apps/llm/app/_layout.tsx index 5ece80f1f..4ab010693 100644 --- a/apps/llm/app/_layout.tsx +++ b/apps/llm/app/_layout.tsx @@ -89,6 +89,14 @@ export default function _layout() { headerTitleStyle: { color: ColorPalette.primary }, }} /> + Voice Chat + router.navigate('multimodal_llm/')} + > + Multimodal LLM (VLM) + ); diff --git a/apps/llm/app/multimodal_llm/index.tsx b/apps/llm/app/multimodal_llm/index.tsx new file mode 100644 index 000000000..1781684a0 --- /dev/null +++ b/apps/llm/app/multimodal_llm/index.tsx @@ -0,0 +1,310 @@ +import { useContext, useEffect, useRef, useState } from 'react'; +import { + Image, + Keyboard, + KeyboardAvoidingView, + Platform, + StyleSheet, + Text, + TextInput, + TouchableOpacity, + TouchableWithoutFeedback, + View, +} from 'react-native'; +import { launchImageLibrary } from 'react-native-image-picker'; +import { useIsFocused } from '@react-navigation/native'; +import { useLLM, LFM2_VL_1_6B_QUANTIZED } from 'react-native-executorch'; +import SendIcon from '../../assets/icons/send_icon.svg'; +import PauseIcon from '../../assets/icons/pause_icon.svg'; +import ColorPalette from '../../colors'; +import Messages from '../../components/Messages'; +import Spinner from '../../components/Spinner'; +import { GeneratingContext } from '../../context'; + +export default function MultimodalLLMScreenWrapper() { + const isFocused = useIsFocused(); + return isFocused ? 
: null; +} + +function MultimodalLLMScreen() { + const [imageUri, setImageUri] = useState(null); + const [userInput, setUserInput] = useState(''); + const [isTextInputFocused, setIsTextInputFocused] = useState(false); + const textInputRef = useRef(null); + const { setGlobalGenerating } = useContext(GeneratingContext); + + const vlm = useLLM({ + model: LFM2_VL_1_6B_QUANTIZED, + }); + + useEffect(() => { + setGlobalGenerating(vlm.isGenerating); + }, [vlm.isGenerating, setGlobalGenerating]); + + useEffect(() => { + if (vlm.error) console.error('MultimodalLLM error:', vlm.error); + }, [vlm.error]); + + const pickImage = async () => { + const result = await launchImageLibrary({ mediaType: 'photo' }); + if (result.assets && result.assets.length > 0) { + const uri = result.assets[0]?.uri; + if (uri) setImageUri(uri); + } + }; + + const sendMessage = async () => { + if (!userInput.trim() || vlm.isGenerating) return; + const text = userInput.trim(); + setUserInput(''); + textInputRef.current?.clear(); + Keyboard.dismiss(); + const currentImageUri = imageUri; + setImageUri(null); + try { + await vlm.sendMessage( + text, + currentImageUri ? { imagePath: currentImageUri } : undefined + ); + } catch (e) { + console.error('Generation error:', e); + } + }; + + if (!vlm.isReady) { + return ( + + ); + } + + return ( + + + + {vlm.messageHistory.length ? ( + + + + ) : ( + + Hello! 👋 + + Pick an image and ask me anything about it. + + + )} + + {/* Image thumbnail strip */} + {imageUri && ( + + + Tap to change + + )} + + + {/* Image picker button */} + + 📷 + + + setIsTextInputFocused(true)} + onBlur={() => setIsTextInputFocused(false)} + style={[ + styles.textInput, + { + borderColor: isTextInputFocused + ? ColorPalette.blueDark + : ColorPalette.blueLight, + }, + ]} + placeholder={imageUri ? 
'Ask about the image…' : 'Your message'} + placeholderTextColor="#C1C6E5" + multiline + onChangeText={setUserInput} + /> + + {userInput.trim() && !vlm.isGenerating && ( + + + + )} + {vlm.isGenerating && ( + + + + )} + + + + + ); +} + +const styles = StyleSheet.create({ + // Setup phase + setupContainer: { + flex: 1, + padding: 24, + backgroundColor: '#fff', + justifyContent: 'center', + }, + setupTitle: { + fontSize: 20, + fontFamily: 'medium', + color: ColorPalette.primary, + marginBottom: 8, + }, + setupHint: { + fontSize: 13, + fontFamily: 'regular', + color: ColorPalette.blueDark, + marginBottom: 32, + lineHeight: 18, + }, + filePickerRow: { + flexDirection: 'row', + alignItems: 'center', + borderWidth: 1, + borderColor: ColorPalette.blueLight, + borderRadius: 10, + padding: 14, + marginBottom: 12, + backgroundColor: '#fafbff', + }, + filePickerInfo: { flex: 1 }, + filePickerLabel: { + fontSize: 12, + fontFamily: 'medium', + color: ColorPalette.blueDark, + marginBottom: 2, + }, + filePickerValue: { fontSize: 14, fontFamily: 'regular' }, + filePickerValueSet: { color: ColorPalette.primary }, + filePickerValueEmpty: { color: ColorPalette.blueLight }, + filePickerChevron: { + fontSize: 24, + color: ColorPalette.blueLight, + marginLeft: 8, + }, + loadButton: { + marginTop: 16, + backgroundColor: ColorPalette.strongPrimary, + borderRadius: 10, + padding: 14, + alignItems: 'center', + }, + loadButtonDisabled: { backgroundColor: ColorPalette.blueLight }, + loadButtonText: { color: '#fff', fontFamily: 'medium', fontSize: 15 }, + + // Chat phase + container: { flex: 1 }, + chatContainer: { flex: 10, width: '100%' }, + helloMessageContainer: { + flex: 10, + width: '100%', + alignItems: 'center', + justifyContent: 'center', + }, + helloText: { + fontFamily: 'medium', + fontSize: 30, + color: ColorPalette.primary, + }, + bottomHelloText: { + fontFamily: 'regular', + fontSize: 20, + lineHeight: 28, + textAlign: 'center', + color: ColorPalette.primary, + paddingHorizontal: 
24, + }, + imageThumbnailContainer: { + flexDirection: 'row', + alignItems: 'center', + paddingHorizontal: 16, + paddingVertical: 6, + gap: 8, + }, + imageThumbnail: { + width: 48, + height: 48, + borderRadius: 8, + borderWidth: 1, + borderColor: ColorPalette.blueLight, + }, + imageThumbnailHint: { + fontSize: 12, + fontFamily: 'regular', + color: ColorPalette.blueDark, + }, + bottomContainer: { + height: 100, + width: '100%', + flexDirection: 'row', + justifyContent: 'space-between', + alignItems: 'center', + paddingHorizontal: 16, + }, + imageButton: { + width: 40, + height: 40, + justifyContent: 'center', + alignItems: 'center', + marginRight: 4, + }, + imageButtonText: { fontSize: 22 }, + textInput: { + flex: 1, + borderWidth: 1, + borderRadius: 8, + lineHeight: 19.6, + fontFamily: 'regular', + fontSize: 14, + color: ColorPalette.primary, + padding: 16, + }, + sendChatTouchable: { + height: '100%', + width: 48, + justifyContent: 'center', + alignItems: 'flex-end', + }, +}); diff --git a/apps/llm/components/MessageItem.tsx b/apps/llm/components/MessageItem.tsx index c4d7d549e..58da5074c 100644 --- a/apps/llm/components/MessageItem.tsx +++ b/apps/llm/components/MessageItem.tsx @@ -4,6 +4,7 @@ import { StyleSheet, TouchableOpacity, Text, + Image, Platform, } from 'react-native'; import MarkdownComponent from './MarkdownComponent'; @@ -17,19 +18,31 @@ interface MessageItemProps { } const MessageItem = memo(({ message, deleteMessage }: MessageItemProps) => { - return ( - - {message.role === 'assistant' && ( + if (message.role === 'assistant') { + return ( + - )} - + + + + ); + } + + return ( + + + {message.mediaPath && ( + + )} + + ); }); @@ -64,17 +77,26 @@ const styles = StyleSheet.create({ marginVertical: 8, alignItems: 'center', }, - userMessage: { + userMessageWrapper: { flexDirection: 'row-reverse', - paddingHorizontal: 12, - paddingVertical: 8, marginRight: 8, marginVertical: 8, maxWidth: '75%', + alignSelf: 'flex-end', + alignItems: 'flex-start', + }, + 
userMessageBubble: { + flexDirection: 'column', + paddingHorizontal: 12, + paddingVertical: 8, borderRadius: 8, backgroundColor: ColorPalette.seaBlueLight, - alignSelf: 'flex-end', - alignItems: 'center', + }, + userMessageImage: { + width: 200, + height: 200, + borderRadius: 6, + marginBottom: 6, }, aiMessageIconContainer: { backgroundColor: ColorPalette.seaBlueLight, diff --git a/apps/llm/package.json b/apps/llm/package.json index f58bc8127..d0fbb6401 100644 --- a/apps/llm/package.json +++ b/apps/llm/package.json @@ -19,6 +19,7 @@ "expo-brightness": "~14.0.8", "expo-calendar": "~15.0.8", "expo-constants": "~18.0.11", + "expo-document-picker": "~13.0.3", "expo-font": "~14.0.10", "expo-linking": "~8.0.10", "expo-router": "~6.0.17", @@ -30,6 +31,7 @@ "react-native-device-info": "^15.0.2", "react-native-executorch": "workspace:*", "react-native-gesture-handler": "~2.28.0", + "react-native-image-picker": "^7.2.2", "react-native-loading-spinner-overlay": "^3.0.1", "react-native-markdown-display": "^7.0.2", "react-native-reanimated": "~4.1.1", diff --git a/packages/react-native-executorch/common/rnexecutorch/host_objects/ModelHostObject.h b/packages/react-native-executorch/common/rnexecutorch/host_objects/ModelHostObject.h index 7712b2b9d..e6e21e278 100644 --- a/packages/react-native-executorch/common/rnexecutorch/host_objects/ModelHostObject.h +++ b/packages/react-native-executorch/common/rnexecutorch/host_objects/ModelHostObject.h @@ -45,7 +45,9 @@ template class ModelHostObject : public JsiHostObject { "getInputShape")); } - if constexpr (meta::HasGenerate) { + // LLM::generate and LLM::generateMultimodal registered explicitly below + if constexpr (meta::HasGenerate && + !meta::SameAs) { addFunctions(JSI_EXPORT_FUNCTION(ModelHostObject, promiseHostFunction<&Model::generate>, "generate")); @@ -98,6 +100,10 @@ template class ModelHostObject : public JsiHostObject { } if constexpr (meta::SameAs) { + addFunctions(JSI_EXPORT_FUNCTION(ModelHostObject, + 
promiseHostFunction<&Model::generate>, + "generate")); + addFunctions(JSI_EXPORT_FUNCTION( ModelHostObject, synchronousHostFunction<&Model::interrupt>, "interrupt")); @@ -144,6 +150,16 @@ template class ModelHostObject : public JsiHostObject { addFunctions(JSI_EXPORT_FUNCTION(ModelHostObject, synchronousHostFunction<&Model::reset>, "reset")); + + addFunctions( + JSI_EXPORT_FUNCTION(ModelHostObject, + promiseHostFunction<&Model::generateMultimodal>, + "generateMultimodal")); + + addFunctions(JSI_EXPORT_FUNCTION( + ModelHostObject, + synchronousHostFunction<&Model::getVisualTokenCount>, + "getVisualTokenCount")); } if constexpr (meta::SameAs) { diff --git a/packages/react-native-executorch/common/rnexecutorch/models/llm/LLM.cpp b/packages/react-native-executorch/common/rnexecutorch/models/llm/LLM.cpp index 4a9d40033..03afd4ed0 100644 --- a/packages/react-native-executorch/common/rnexecutorch/models/llm/LLM.cpp +++ b/packages/react-native-executorch/common/rnexecutorch/models/llm/LLM.cpp @@ -2,23 +2,42 @@ #include #include +#include #include +#include #include +#include +#include +#include namespace rnexecutorch::models::llm { namespace llm = ::executorch::extension::llm; namespace fs = std::filesystem; using namespace facebook; -using executorch::extension::TensorPtr; using executorch::extension::module::Module; using executorch::runtime::Error; LLM::LLM(const std::string &modelSource, const std::string &tokenizerSource, + std::vector capabilities, std::shared_ptr callInvoker) - : BaseModel(modelSource, callInvoker, Module::LoadMode::File), - runner( - std::make_unique(module_.get(), tokenizerSource)) { - auto loadResult = runner->load(); + : BaseModel(modelSource, callInvoker, Module::LoadMode::File) { + + if (capabilities.empty()) { + runner_ = + std::make_unique(std::move(module_), tokenizerSource); + } else { + std::map> encoders; + for (const auto &cap : capabilities) { + if (cap == "vision") { + encoders[llm::MultimodalType::Image] = + 
std::make_unique(*module_); + } + } + runner_ = std::make_unique( + std::move(module_), tokenizerSource, std::move(encoders)); + } + + auto loadResult = runner_->load(); if (loadResult != Error::Ok) { throw RnExecutorchError(loadResult, "Failed to load LLM runner"); } @@ -27,17 +46,13 @@ LLM::LLM(const std::string &modelSource, const std::string &tokenizerSource, fs::file_size(fs::path(tokenizerSource)); } -// TODO: add a way to manipulate the generation config with params std::string LLM::generate(std::string input, std::shared_ptr callback) { - if (!runner || !runner->is_loaded()) { + if (!runner_ || !runner_->is_loaded()) { throw RnExecutorchError(RnExecutorchErrorCode::ModuleNotLoaded, "Runner is not loaded"); } - std::string output; - - // Create a native callback that accumulates tokens and optionally invokes JS auto nativeCallback = [this, callback, &output](const std::string &token) { output += token; if (callback && callInvoker) { @@ -48,51 +63,135 @@ std::string LLM::generate(std::string input, }; auto config = llm::GenerationConfig{.echo = false, .warming = false}; - auto error = runner->generate(input, config, nativeCallback, {}); - if (error != executorch::runtime::Error::Ok) { + auto error = runner_->generate(input, config, nativeCallback, {}); + if (error != Error::Ok) { throw RnExecutorchError(error, "Failed to generate text"); } + return output; +} + +std::string LLM::generateMultimodal(std::string prompt, + std::vector imagePaths, + std::string imageToken, + std::shared_ptr callback) { + if (!runner_ || !runner_->is_loaded()) { + throw RnExecutorchError(RnExecutorchErrorCode::ModuleNotLoaded, + "Runner is not loaded"); + } + if (!runner_->is_multimodal()) { + throw RnExecutorchError( + RnExecutorchErrorCode::InvalidUserInput, + "This model does not support multimodal input. 
Use generate(prompt, " + "callback) for text-only generation."); + } + if (imageToken.empty()) { + throw RnExecutorchError( + RnExecutorchErrorCode::InvalidUserInput, + "imageToken must not be empty. Pass the model's image token (e.g. " + "from tokenizer_config.json)."); + } + + const size_t kImageTokenLen = imageToken.size(); + + std::vector inputs; + size_t imageIdx = 0; + size_t searchPos = 0; + + while (true) { + size_t found = prompt.find(imageToken, searchPos); + if (found == std::string::npos) { + if (searchPos < prompt.size()) { + inputs.push_back(llm::make_text_input(prompt.substr(searchPos))); + } + break; + } + // Text segment before this placeholder + if (found > searchPos) { + inputs.push_back( + llm::make_text_input(prompt.substr(searchPos, found - searchPos))); + } + // Image at this position + if (imageIdx >= imagePaths.size()) { + throw RnExecutorchError( + RnExecutorchErrorCode::InvalidUserInput, + "More '" + imageToken + + "' placeholders in prompt than image paths provided"); + } + inputs.push_back(llm::make_image_input(imagePaths[imageIdx++])); + searchPos = found + kImageTokenLen; + } + + if (imageIdx < imagePaths.size()) { + throw RnExecutorchError(RnExecutorchErrorCode::InvalidUserInput, + "More image paths provided than '" + imageToken + + "' placeholders in prompt"); + } + + if (inputs.empty()) { + throw RnExecutorchError(RnExecutorchErrorCode::InvalidUserInput, + "No inputs to generate from"); + } + + std::string output; + auto nativeCallback = [this, callback, &output](const std::string &token) { + output += token; + if (callback && callInvoker) { + callInvoker->invokeAsync([callback, token](jsi::Runtime &runtime) { + callback->call(runtime, jsi::String::createFromUtf8(runtime, token)); + }); + } + }; + + auto error = runner_->generate(inputs, nativeCallback); + if (error != Error::Ok) { + throw RnExecutorchError(error, "Failed to generate multimodal response"); + } return output; } void LLM::interrupt() { - if (!runner || 
!runner->is_loaded()) { + if (!runner_ || !runner_->is_loaded()) { throw RnExecutorchError(RnExecutorchErrorCode::ModuleNotLoaded, "Can't interrupt a model that's not loaded"); } - runner->stop(); + runner_->stop(); } void LLM::reset() { - if (!runner || !runner->is_loaded()) { + if (!runner_ || !runner_->is_loaded()) { throw RnExecutorchError(RnExecutorchErrorCode::ModuleNotLoaded, "Can't reset a model that's not loaded"); } - runner->reset(); + runner_->reset(); } size_t LLM::getGeneratedTokenCount() const noexcept { - if (!runner || !runner->is_loaded()) { + if (!runner_ || !runner_->is_loaded()) return 0; - } - return runner->stats_.num_generated_tokens; + return runner_->stats_.num_generated_tokens; } size_t LLM::getPromptTokenCount() const noexcept { - if (!runner || !runner->is_loaded()) { + if (!runner_ || !runner_->is_loaded()) + return 0; + return runner_->stats_.num_prompt_tokens; +} + +int32_t LLM::getVisualTokenCount() const { + if (!runner_ || !runner_->is_loaded()) { return 0; } - return runner->stats_.num_prompt_tokens; + return runner_->get_visual_token_count(); } int32_t LLM::countTextTokens(std::string text) const { - if (!runner || !runner->is_loaded()) { + if (!runner_ || !runner_->is_loaded()) { throw RnExecutorchError( RnExecutorchErrorCode::ModuleNotLoaded, "Can't count tokens from a model that's not loaded"); } - return runner->count_text_tokens(text); + return runner_->count_text_tokens(text); } size_t LLM::getMemoryLowerBound() const noexcept { @@ -100,7 +199,7 @@ size_t LLM::getMemoryLowerBound() const noexcept { } void LLM::setCountInterval(size_t countInterval) { - if (!runner || !runner->is_loaded()) { + if (!runner_ || !runner_->is_loaded()) { throw RnExecutorchError(RnExecutorchErrorCode::ModuleNotLoaded, "Can't configure a model that's not loaded"); } @@ -108,11 +207,11 @@ void LLM::setCountInterval(size_t countInterval) { throw RnExecutorchError(RnExecutorchErrorCode::InvalidConfig, "Count interval must be greater than 0"); } - 
runner->set_count_interval(countInterval); + runner_->set_count_interval(countInterval); } void LLM::setTimeInterval(size_t timeInterval) { - if (!runner || !runner->is_loaded()) { + if (!runner_ || !runner_->is_loaded()) { throw RnExecutorchError(RnExecutorchErrorCode::ModuleNotLoaded, "Can't configure a model that's not loaded"); } @@ -120,11 +219,11 @@ void LLM::setTimeInterval(size_t timeInterval) { throw RnExecutorchError(RnExecutorchErrorCode::InvalidConfig, "Time interval must be greater than 0"); } - runner->set_time_interval(timeInterval); + runner_->set_time_interval(timeInterval); } void LLM::setTemperature(float temperature) { - if (!runner || !runner->is_loaded()) { + if (!runner_ || !runner_->is_loaded()) { throw RnExecutorchError(RnExecutorchErrorCode::ModuleNotLoaded, "Can't configure a model that's not loaded"); } @@ -132,11 +231,11 @@ void LLM::setTemperature(float temperature) { throw RnExecutorchError(RnExecutorchErrorCode::InvalidConfig, "Temperature must be non-negative"); } - runner->set_temperature(temperature); + runner_->set_temperature(temperature); } void LLM::setTopp(float topp) { - if (!runner || !runner->is_loaded()) { + if (!runner_ || !runner_->is_loaded()) { throw RnExecutorchError(RnExecutorchErrorCode::ModuleNotLoaded, "Can't configure a model that's not loaded"); } @@ -144,18 +243,18 @@ void LLM::setTopp(float topp) { throw RnExecutorchError(RnExecutorchErrorCode::InvalidConfig, "Top-p must be between 0.0 and 1.0"); } - runner->set_topp(topp); + runner_->set_topp(topp); } int32_t LLM::getMaxContextLength() const { - if (!runner || !runner->is_loaded()) { + if (!runner_ || !runner_->is_loaded()) { throw RnExecutorchError( RnExecutorchErrorCode::ModuleNotLoaded, "Can't get context length from a model that's not loaded"); } - return runner->get_max_context_length(); + return runner_->get_max_context_length(); } -void LLM::unload() noexcept { runner.reset(nullptr); } +void LLM::unload() noexcept { runner_.reset(nullptr); } } // 
namespace rnexecutorch::models::llm diff --git a/packages/react-native-executorch/common/rnexecutorch/models/llm/LLM.h b/packages/react-native-executorch/common/rnexecutorch/models/llm/LLM.h index 99daaf6f5..fcb93d0c1 100644 --- a/packages/react-native-executorch/common/rnexecutorch/models/llm/LLM.h +++ b/packages/react-native-executorch/common/rnexecutorch/models/llm/LLM.h @@ -2,11 +2,12 @@ #include #include +#include #include #include #include -#include +#include namespace rnexecutorch { namespace models::llm { @@ -16,16 +17,23 @@ class LLM : public BaseModel { public: explicit LLM(const std::string &modelSource, const std::string &tokenizerSource, + std::vector capabilities, std::shared_ptr callInvoker); - std::string generate(std::string input, + std::string generate(std::string prompt, std::shared_ptr callback); + std::string generateMultimodal(std::string prompt, + std::vector imagePaths, + std::string imageToken, + std::shared_ptr callback); + void interrupt(); void reset(); void unload() noexcept; size_t getGeneratedTokenCount() const noexcept; size_t getPromptTokenCount() const noexcept; int32_t countTextTokens(std::string text) const; + int32_t getVisualTokenCount() const; size_t getMemoryLowerBound() const noexcept; void setCountInterval(size_t countInterval); void setTemperature(float temperature); @@ -34,10 +42,11 @@ class LLM : public BaseModel { int32_t getMaxContextLength() const; private: - std::unique_ptr runner; + std::unique_ptr<::executorch::extension::llm::BaseLLMRunner> runner_; }; } // namespace models::llm REGISTER_CONSTRUCTOR(models::llm::LLM, std::string, std::string, + std::vector, std::shared_ptr); } // namespace rnexecutorch diff --git a/packages/react-native-executorch/common/rnexecutorch/tests/CMakeLists.txt b/packages/react-native-executorch/common/rnexecutorch/tests/CMakeLists.txt index e2a8c16bf..159f00159 100644 --- a/packages/react-native-executorch/common/rnexecutorch/tests/CMakeLists.txt +++ 
b/packages/react-native-executorch/common/rnexecutorch/tests/CMakeLists.txt @@ -210,12 +210,17 @@ add_rn_test(SpeechToTextTests integration/SpeechToTextTest.cpp add_rn_test(LLMTests integration/LLMTest.cpp SOURCES ${RNEXECUTORCH_DIR}/models/llm/LLM.cpp - ${COMMON_DIR}/runner/runner.cpp + ${COMMON_DIR}/runner/base_llm_runner.cpp + ${COMMON_DIR}/runner/text_runner.cpp + ${COMMON_DIR}/runner/multimodal_runner.cpp + ${COMMON_DIR}/runner/multimodal_prefiller.cpp ${COMMON_DIR}/runner/text_prefiller.cpp ${COMMON_DIR}/runner/text_decoder_runner.cpp ${COMMON_DIR}/runner/sampler.cpp ${COMMON_DIR}/runner/arange_util.cpp - LIBS tokenizers_deps + ${COMMON_DIR}/runner/encoders/vision_encoder.cpp + ${IMAGE_UTILS_SOURCES} + LIBS tokenizers_deps opencv_deps ) add_rn_test(TextToImageTests integration/TextToImageTest.cpp diff --git a/packages/react-native-executorch/common/rnexecutorch/tests/integration/LLMTest.cpp b/packages/react-native-executorch/common/rnexecutorch/tests/integration/LLMTest.cpp index e79294090..5ebb96fbe 100644 --- a/packages/react-native-executorch/common/rnexecutorch/tests/integration/LLMTest.cpp +++ b/packages/react-native-executorch/common/rnexecutorch/tests/integration/LLMTest.cpp @@ -6,6 +6,7 @@ #include #include #include +#include using namespace rnexecutorch; using namespace rnexecutorch::models::llm; @@ -37,12 +38,12 @@ template <> struct ModelTraits { using ModelType = LLM; static ModelType createValid() { - return ModelType(kValidModelPath, kValidTokenizerPath, + return ModelType(kValidModelPath, kValidTokenizerPath, {}, rnexecutorch::createMockCallInvoker()); } static ModelType createInvalid() { - return ModelType("nonexistent.pte", kValidTokenizerPath, + return ModelType("nonexistent.pte", kValidTokenizerPath, {}, rnexecutorch::createMockCallInvoker()); } @@ -67,18 +68,24 @@ class LLMTest : public ::testing::Test { }; TEST(LLMCtorTests, InvalidTokenizerPathThrows) { - EXPECT_THROW(LLM(kValidModelPath, "nonexistent_tokenizer.json", + 
EXPECT_THROW(LLM(kValidModelPath, "nonexistent_tokenizer.json", {}, createMockCallInvoker()), RnExecutorchError); } +TEST(LLMCtorTests, WrongCapabilitiesThrowsClearError) { + EXPECT_THROW(LLM(kValidModelPath, kValidTokenizerPath, {"vision"}, + createMockCallInvoker()), + rnexecutorch::RnExecutorchError); +} + TEST_F(LLMTest, GetGeneratedTokenCountInitiallyZero) { - LLM model(kValidModelPath, kValidTokenizerPath, mockInvoker_); + LLM model(kValidModelPath, kValidTokenizerPath, {}, mockInvoker_); EXPECT_EQ(model.getGeneratedTokenCount(), 0); } TEST_F(LLMTest, SetTemperature) { - LLM model(kValidModelPath, kValidTokenizerPath, mockInvoker_); + LLM model(kValidModelPath, kValidTokenizerPath, {}, mockInvoker_); // Should not throw for valid values EXPECT_NO_THROW(model.setTemperature(0.5f)); EXPECT_NO_THROW(model.setTemperature(1.0f)); @@ -86,43 +93,43 @@ TEST_F(LLMTest, SetTemperature) { } TEST_F(LLMTest, SetTemperatureNegativeThrows) { - LLM model(kValidModelPath, kValidTokenizerPath, mockInvoker_); + LLM model(kValidModelPath, kValidTokenizerPath, {}, mockInvoker_); EXPECT_THROW(model.setTemperature(-0.1f), RnExecutorchError); } TEST_F(LLMTest, SetTopp) { - LLM model(kValidModelPath, kValidTokenizerPath, mockInvoker_); + LLM model(kValidModelPath, kValidTokenizerPath, {}, mockInvoker_); EXPECT_NO_THROW(model.setTopp(0.9f)); EXPECT_NO_THROW(model.setTopp(0.5f)); EXPECT_NO_THROW(model.setTopp(1.0f)); } TEST_F(LLMTest, SetToppInvalidThrows) { - LLM model(kValidModelPath, kValidTokenizerPath, mockInvoker_); + LLM model(kValidModelPath, kValidTokenizerPath, {}, mockInvoker_); EXPECT_THROW(model.setTopp(-0.1f), RnExecutorchError); EXPECT_THROW(model.setTopp(1.1f), RnExecutorchError); } TEST_F(LLMTest, SetCountInterval) { - LLM model(kValidModelPath, kValidTokenizerPath, mockInvoker_); + LLM model(kValidModelPath, kValidTokenizerPath, {}, mockInvoker_); EXPECT_NO_THROW(model.setCountInterval(5)); EXPECT_NO_THROW(model.setCountInterval(10)); } TEST_F(LLMTest, 
SetTimeInterval) { - LLM model(kValidModelPath, kValidTokenizerPath, mockInvoker_); + LLM model(kValidModelPath, kValidTokenizerPath, {}, mockInvoker_); EXPECT_NO_THROW(model.setTimeInterval(100)); EXPECT_NO_THROW(model.setTimeInterval(500)); } TEST_F(LLMTest, InterruptThrowsWhenUnloaded) { - LLM model(kValidModelPath, kValidTokenizerPath, mockInvoker_); + LLM model(kValidModelPath, kValidTokenizerPath, {}, mockInvoker_); model.unload(); EXPECT_THROW(model.interrupt(), RnExecutorchError); } TEST_F(LLMTest, SettersThrowWhenUnloaded) { - LLM model(kValidModelPath, kValidTokenizerPath, mockInvoker_); + LLM model(kValidModelPath, kValidTokenizerPath, {}, mockInvoker_); model.unload(); // All setters should throw when model is unloaded EXPECT_THROW(model.setTemperature(0.5f), RnExecutorchError); @@ -132,7 +139,7 @@ TEST_F(LLMTest, SettersThrowWhenUnloaded) { } TEST_F(LLMTest, GenerateProducesValidOutput) { - LLM model(kValidModelPath, kValidTokenizerPath, mockInvoker_); + LLM model(kValidModelPath, kValidTokenizerPath, {}, mockInvoker_); model.setTemperature(0.0f); std::string prompt = formatChatML(kSystemPrompt, "Repeat exactly this: `naszponcilem testy`"); @@ -141,7 +148,7 @@ TEST_F(LLMTest, GenerateProducesValidOutput) { } TEST_F(LLMTest, GenerateUpdatesTokenCount) { - LLM model(kValidModelPath, kValidTokenizerPath, mockInvoker_); + LLM model(kValidModelPath, kValidTokenizerPath, {}, mockInvoker_); EXPECT_EQ(model.getGeneratedTokenCount(), 0); std::string prompt = formatChatML(kSystemPrompt, "Repeat exactly this: 'naszponcilem testy'"); @@ -150,6 +157,58 @@ TEST_F(LLMTest, GenerateUpdatesTokenCount) { } TEST_F(LLMTest, EmptyPromptThrows) { - LLM model(kValidModelPath, kValidTokenizerPath, mockInvoker_); + LLM model(kValidModelPath, kValidTokenizerPath, {}, mockInvoker_); EXPECT_THROW((void)model.generate("", nullptr), RnExecutorchError); } + +TEST(VisionEncoderTest, LoadFailsWithClearErrorWhenMethodMissing) { + // smolLm2_135M_8da4w.pte has no vision_encoder method + 
auto module = std::make_unique<::executorch::extension::Module>( + "smolLm2_135M_8da4w.pte", + ::executorch::extension::Module::LoadMode::File); + + auto encoder = + std::make_unique(module.get()); + + EXPECT_THROW(encoder->load(), rnexecutorch::RnExecutorchError); +} + +#include + +// Minimal concrete subclass — only used in tests to verify base class behavior +class StubRunner : public rnexecutorch::llm::runner::BaseLLMRunner { +public: + using BaseLLMRunner::BaseLLMRunner; + bool is_loaded() const override { return loaded_; } + ::executorch::runtime::Error load_subcomponents() override { + loaded_ = true; + return ::executorch::runtime::Error::Ok; + } + ::executorch::runtime::Error generate_internal( + const std::vector<::executorch::extension::llm::MultimodalInput> &, + std::function) override { + return ::executorch::runtime::Error::Ok; + } + void stop_impl() override {} + void set_temperature_impl(float t) override { last_temp_ = t; } + void set_topp_impl(float) override {} + void set_count_interval_impl(size_t) override {} + void set_time_interval_impl(size_t) override {} + + bool loaded_ = false; + float last_temp_ = -1.f; +}; + +TEST(BaseLLMRunnerTest, SetTemperatureWritesConfigAndCallsImpl) { + StubRunner runner(nullptr, "dummy_tokenizer.json"); + runner.set_temperature(0.5f); + EXPECT_FLOAT_EQ(runner.config_.temperature, 0.5f); + EXPECT_FLOAT_EQ(runner.last_temp_, 0.5f); +} + +TEST(BaseLLMRunnerTest, ResetZerosPos) { + StubRunner runner(nullptr, "dummy_tokenizer.json"); + runner.pos_ = 42; + runner.reset(); + EXPECT_EQ(runner.pos_, 0); +} diff --git a/packages/react-native-executorch/common/rnexecutorch/threads/GlobalThreadPool.h b/packages/react-native-executorch/common/rnexecutorch/threads/GlobalThreadPool.h index 50025eeeb..aeea242f4 100644 --- a/packages/react-native-executorch/common/rnexecutorch/threads/GlobalThreadPool.h +++ b/packages/react-native-executorch/common/rnexecutorch/threads/GlobalThreadPool.h @@ -35,8 +35,6 @@ class GlobalThreadPool { 
} numThreads = std::max(numThreads.value(), 2u); - log(rnexecutorch::LOG_LEVEL::Info, "Initializing global thread pool with", - numThreads, "threads"); instance = std::make_unique(numThreads.value(), config); // Disable OpenCV's internal threading to prevent it from overriding our diff --git a/packages/react-native-executorch/common/runner/base_llm_runner.cpp b/packages/react-native-executorch/common/runner/base_llm_runner.cpp new file mode 100644 index 000000000..9c8d83cc4 --- /dev/null +++ b/packages/react-native-executorch/common/runner/base_llm_runner.cpp @@ -0,0 +1,171 @@ +// common/runner/base_llm_runner.cpp +#include "base_llm_runner.h" +#include "constants.h" +#include +#include +#include + +namespace executorch::extension::llm { + +using ::executorch::extension::Module; +using ::executorch::runtime::Error; + +BaseLLMRunner::BaseLLMRunner(std::unique_ptr module, + const std::string &tokenizer_path, + const GenerationConfig &config) + : config_(config), module_(std::move(module)), + tokenizer_path_(tokenizer_path), + tokenizer_(std::make_unique()), + metadata_({ + {kEnableDynamicShape, false}, + {kMaxSeqLen, 128}, + {kMaxContextLen, 128}, + {kUseKVCache, true}, + }) {} + +Error BaseLLMRunner::load() { + if (is_loaded()) + return Error::Ok; + + auto status = tokenizer_->load(tokenizer_path_); + if (status != tokenizers::Error::Ok) { + throw rnexecutorch::RnExecutorchError( + rnexecutorch::RnExecutorchErrorCode::TokenizerError, + "Unexpected issue occurred while loading tokenizer (error code: " + + std::to_string(static_cast(status)) + ")"); + } + + const auto method_names = + ET_UNWRAP(module_->method_names(), "Failed reading method names"); + + metadata_[kVocabSize] = tokenizer_->vocab_size(); + for (auto &pair : metadata_) { + const auto &method_name = pair.first; + auto &value = pair.second; + if (method_names.count(method_name)) { + value = ET_UNWRAP(module_->get(method_name)) + .toScalar() + .to(); + } + } + + if (config_.max_seq_len < 0) + 
config_.max_seq_len = static_cast(metadata_.at(kMaxSeqLen)); + if (config_.max_context_length < 0) { + config_.max_context_length = + method_names.count(kMaxContextLen) + ? static_cast(metadata_.at(kMaxContextLen)) + : static_cast(metadata_.at(kMaxSeqLen)); + } + if (config_.max_new_tokens < 0) + config_.max_new_tokens = + std::min(config_.max_seq_len, config_.max_context_length); + config_.enable_dynamic_shape = + static_cast(metadata_.at(kEnableDynamicShape)); + config_.enable_kv_cache = static_cast(metadata_.at(kUseKVCache)); + + eos_ids_ = std::make_unique>(); + if (method_names.count(kEosIds)) { + for (const auto &eos_id : ET_UNWRAP(module_->execute(kEosIds))) { + eos_ids_->emplace(static_cast(eos_id.toScalar().to())); + } + } + if (eos_ids_->empty()) { + eos_ids_->emplace(7); // fallback <|im_end|> + } + + io_manager_ = std::make_unique(*module_); + + return load_subcomponents(); +} + +Error BaseLLMRunner::generate( + const std::string &prompt, const GenerationConfig &generation_config, + std::function token_callback, + std::function stats_callback) { + + ET_CHECK_MSG(!prompt.empty(), "Prompt cannot be null"); + + std::vector inputs = {make_text_input(prompt)}; + auto err = generate_internal(inputs, token_callback); + + if (stats_callback) + stats_callback(stats_); + + return err; +} + +Error BaseLLMRunner::generate( + const std::vector &inputs, + std::function token_callback, + std::function stats_callback) { + + auto err = generate_internal(inputs, token_callback); + + if (stats_callback) + stats_callback(stats_); + + return err; +} + +void BaseLLMRunner::stop() { stop_impl(); } + +void BaseLLMRunner::reset() { + stats_.reset(); + pos_ = 0; +} + +int32_t BaseLLMRunner::count_text_tokens(const std::string &text) const { + auto encodeResult = + tokenizer_->encode(text, numOfAddedBoSTokens, numOfAddedEoSTokens); + if (!encodeResult.ok()) { + throw rnexecutorch::RnExecutorchError( + rnexecutorch::RnExecutorchErrorCode::TokenizerError, + "Encoding failed during 
token count check."); + } + return static_cast(encodeResult.get().size()); +} + +int32_t BaseLLMRunner::get_max_context_length() const { + if (!is_loaded()) + return static_cast(metadata_.at(kMaxContextLen)); + return config_.max_context_length; +} + +void BaseLLMRunner::set_temperature(float temperature) noexcept { + config_.temperature = temperature; + set_temperature_impl(temperature); +} + +void BaseLLMRunner::set_topp(float topp) noexcept { + config_.topp = topp; + set_topp_impl(topp); +} + +void BaseLLMRunner::set_count_interval(size_t count_interval) { + set_count_interval_impl(count_interval); +} + +void BaseLLMRunner::set_time_interval(size_t time_interval) { + set_time_interval_impl(time_interval); +} + +int32_t BaseLLMRunner::resolve_max_new_tokens(int32_t num_prompt_tokens, + int32_t max_seq_len, + int32_t max_context_len, + int32_t max_new_tokens) const { + int32_t result; + if (max_seq_len == -1 && max_new_tokens == -1) + result = max_context_len - num_prompt_tokens; + else if (max_seq_len == -1) + result = std::min(max_new_tokens, max_context_len - num_prompt_tokens); + else if (max_new_tokens == -1) + result = std::min(max_seq_len, max_context_len) - num_prompt_tokens; + else + result = + std::min(std::min(max_seq_len, max_context_len) - num_prompt_tokens, + max_new_tokens); + return std::max(0, result); +} + +} // namespace executorch::extension::llm diff --git a/packages/react-native-executorch/common/runner/base_llm_runner.h b/packages/react-native-executorch/common/runner/base_llm_runner.h new file mode 100644 index 000000000..3924aa3d7 --- /dev/null +++ b/packages/react-native-executorch/common/runner/base_llm_runner.h @@ -0,0 +1,85 @@ +// common/runner/base_llm_runner.h +#pragma once + +#include "io_manager.h" +#include "irunner.h" +#include "multimodal_input.h" +#include "stats.h" +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace executorch::extension::llm { +class BaseLLMRunner { +public: 
+ explicit BaseLLMRunner( + std::unique_ptr<::executorch::extension::Module> module, + const std::string &tokenizer_path, + const GenerationConfig &config = {.temperature = 0.8F, .topp = 0.9F}); + + virtual ~BaseLLMRunner() = default; + + virtual bool is_loaded() const = 0; + + virtual ::executorch::runtime::Error load(); + + ::executorch::runtime::Error + generate(const std::string &prompt, + const GenerationConfig &generation_config = {}, + std::function token_callback = {}, + std::function stats_callback = {}); + + ::executorch::runtime::Error + generate(const std::vector &inputs, + std::function token_callback = {}, + std::function stats_callback = {}); + + virtual ::executorch::runtime::Error generate_internal( + const std::vector &inputs, + std::function token_callback) = 0; + + void stop(); + void reset(); + int32_t count_text_tokens(const std::string &text) const; + int32_t get_max_context_length() const; + virtual bool is_multimodal() const { return false; } + virtual int32_t get_visual_token_count() const { return 0; } + + void set_temperature(float temperature) noexcept; + void set_topp(float topp) noexcept; + void set_count_interval(size_t count_interval); + void set_time_interval(size_t time_interval); + + Stats stats_; + + // Public for test access + GenerationConfig config_; + int64_t pos_{0}; + +protected: + virtual ::executorch::runtime::Error load_subcomponents() = 0; + virtual void stop_impl() = 0; + virtual void set_temperature_impl(float temperature) = 0; + virtual void set_topp_impl(float topp) = 0; + virtual void set_count_interval_impl(size_t count_interval) = 0; + virtual void set_time_interval_impl(size_t time_interval) = 0; + + int32_t resolve_max_new_tokens(int32_t num_prompt_tokens, int32_t max_seq_len, + int32_t max_context_len, + int32_t max_new_tokens = -1) const; + + std::unique_ptr<::executorch::extension::Module> module_; + std::string tokenizer_path_; + std::unique_ptr tokenizer_; + std::unordered_map metadata_; + std::unique_ptr 
io_manager_; + std::unique_ptr> eos_ids_; +}; + +} // namespace executorch::extension::llm diff --git a/packages/react-native-executorch/common/runner/constants.h b/packages/react-native-executorch/common/runner/constants.h index e75466829..f1fee2347 100644 --- a/packages/react-native-executorch/common/runner/constants.h +++ b/packages/react-native-executorch/common/runner/constants.h @@ -17,7 +17,6 @@ inline constexpr auto kMaxSeqLen = "get_max_seq_len"; inline constexpr auto kMaxContextLen = "get_max_context_len"; inline constexpr auto kVocabSize = "get_vocab_size"; inline constexpr auto kUseKVCache = "use_kv_cache"; -inline constexpr auto kUseSDPAWithKVCache = "use_sdpa_with_kv_cache"; // Multimodal method name conventions inline constexpr auto kVisionEncoderMethod = "vision_encoder"; diff --git a/packages/react-native-executorch/common/runner/encoders/iencoder.h b/packages/react-native-executorch/common/runner/encoders/iencoder.h new file mode 100644 index 000000000..ae9bb203c --- /dev/null +++ b/packages/react-native-executorch/common/runner/encoders/iencoder.h @@ -0,0 +1,25 @@ +// common/runner/encoders/iencoder.h +#pragma once + +#include +#include +#include +#include + +namespace executorch::extension::llm { + +class IEncoder { +public: + virtual ~IEncoder() = default; + virtual ::executorch::runtime::Error load() = 0; + virtual bool is_loaded() const noexcept = 0; + + virtual ::executorch::runtime::Result<::executorch::runtime::EValue> + encode(const MultimodalInput &input) = 0; + + // Returns the number of tokens produced per encoded input (e.g. visual + // tokens per image). Returns 0 if not loaded or unknown. 
+ virtual int32_t encoderTokenCount() const { return 0; } +}; + +} // namespace executorch::extension::llm diff --git a/packages/react-native-executorch/common/runner/encoders/vision_encoder.cpp b/packages/react-native-executorch/common/runner/encoders/vision_encoder.cpp new file mode 100644 index 000000000..b28281493 --- /dev/null +++ b/packages/react-native-executorch/common/runner/encoders/vision_encoder.cpp @@ -0,0 +1,129 @@ +// common/runner/encoders/vision_encoder.cpp +#include "vision_encoder.h" + +#include +#include +#include +#include + +#include +#include + +namespace executorch::extension::llm { + +using ::executorch::runtime::Error; +using ::executorch::runtime::EValue; +using ::executorch::runtime::Result; + +VisionEncoder::VisionEncoder(::executorch::extension::Module &module) + : module_(&module) {} + +Error VisionEncoder::load() { + if (is_loaded()) { + return Error::Ok; + } + auto method_names_result = module_->method_names(); + if (!method_names_result.ok()) { + return method_names_result.error(); + } + + if (method_names_result->count(kVisionEncoderMethod) == 0) { + throw rnexecutorch::RnExecutorchError( + rnexecutorch::RnExecutorchErrorCode::InvalidConfig, + "Model does not support vision: 'vision_encoder' method not found. 
" + "Check that the .pte file matches the declared capabilities."); + } + return module_->load_method(kVisionEncoderMethod); +} + +bool VisionEncoder::is_loaded() const noexcept { + return module_->is_method_loaded(kVisionEncoderMethod); +} + +int32_t VisionEncoder::encoderTokenCount() const noexcept { + if (!is_loaded()) { + return 0; + } + auto meta_result = module_->method_meta(kVisionEncoderMethod); + if (!meta_result.ok()) { + return 0; + } + auto output_meta = meta_result->output_tensor_meta(0); + if (!output_meta.ok()) { + return 0; + } + // Output shape is [1, num_visual_tokens, embed_dim] + auto sizes = output_meta->sizes(); + if (sizes.size() < 2) { + return 0; + } + return static_cast(sizes[1]); +} + +Result VisionEncoder::getInputShape() const { + auto method_meta = ET_UNWRAP(module_->method_meta(kVisionEncoderMethod)); + auto input_meta = ET_UNWRAP(method_meta.input_tensor_meta(0)); + auto dims = input_meta.sizes(); + const bool with_batch = dims.size() == 4; + const int32_t offset = with_batch ? 
1 : 0; + return ImageShape{ + .channels = static_cast(dims[offset]), + .height = static_cast(dims[offset + 1]), + .width = static_cast(dims[offset + 2]), + .with_batch = with_batch, + }; +} + +std::vector +VisionEncoder::preprocessImage(const std::string &path, + const ImageShape &targetShape) const { + cv::Mat mat = rnexecutorch::image_processing::readImage(path); + cv::resize(mat, mat, cv::Size(targetShape.width, targetShape.height)); + cv::cvtColor(mat, mat, cv::COLOR_BGR2RGB); + + const int32_t pixelCount = targetShape.height * targetShape.width; + std::vector chw(targetShape.channels * pixelCount); + for (int32_t i = 0; i < pixelCount; ++i) { + cv::Vec3b px = + mat.at(i / targetShape.width, i % targetShape.width); + for (int32_t c = 0; c < targetShape.channels; ++c) { + chw[c * pixelCount + i] = static_cast(px[c]); + } + } + return chw; +} + +Result VisionEncoder::encode(const MultimodalInput &input) { + if (!is_loaded()) { + return Error::InvalidState; + } + if (!input.is_image()) { + return Error::InvalidArgument; + } + + const std::string &path = input.get_image_path(); + + auto it = embedding_cache_.find(path); + if (it != embedding_cache_.end()) { + return it->second; + } + + auto shape = ET_UNWRAP(getInputShape()); + auto chw = preprocessImage(path, shape); + + std::vector<::executorch::aten::SizesType> sizes = { + shape.channels, shape.height, shape.width}; + if (shape.with_batch) { + sizes.insert(sizes.begin(), 1); + } + + auto image_tensor = ::executorch::extension::from_blob( + chw.data(), sizes, ::executorch::aten::ScalarType::Float); + + auto result = ET_UNWRAP(module_->execute(kVisionEncoderMethod, image_tensor)); + auto embedding = result[0]; + embedding_cache_.emplace(path, embedding); + return embedding; +} + +} // namespace executorch::extension::llm diff --git a/packages/react-native-executorch/common/runner/encoders/vision_encoder.h b/packages/react-native-executorch/common/runner/encoders/vision_encoder.h new file mode 100644 index 
000000000..1a42ff13a --- /dev/null +++ b/packages/react-native-executorch/common/runner/encoders/vision_encoder.h @@ -0,0 +1,38 @@ +// common/runner/encoders/vision_encoder.h +#pragma once + +#include "iencoder.h" +#include +#include +#include +#include +#include + +namespace executorch::extension::llm { + +class VisionEncoder : public IEncoder { +public: + explicit VisionEncoder(::executorch::extension::Module &module); + + ::executorch::runtime::Error load() override; + bool is_loaded() const noexcept override; + ::executorch::runtime::Result<::executorch::runtime::EValue> + encode(const MultimodalInput &input) override; + int32_t encoderTokenCount() const noexcept override; + +private: + struct ImageShape { + int32_t channels, height, width; + bool with_batch; + }; + + ::executorch::runtime::Result getInputShape() const; + std::vector preprocessImage(const std::string &path, + const ImageShape &targetShape) const; + + ::executorch::extension::Module *module_; + std::unordered_map + embedding_cache_; +}; + +} // namespace executorch::extension::llm diff --git a/packages/react-native-executorch/common/runner/multimodal_decoder_runner.h b/packages/react-native-executorch/common/runner/multimodal_decoder_runner.h new file mode 100644 index 000000000..3b6fe4660 --- /dev/null +++ b/packages/react-native-executorch/common/runner/multimodal_decoder_runner.h @@ -0,0 +1,63 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +// Ported from executorch/extension/llm/runner/multimodal_decoder_runner.h + +#pragma once + +#include "constants.h" +#include "text_decoder_runner.h" + +namespace executorch::extension::llm { +class MultimodalDecoderRunner : public TextDecoderRunner { +public: + explicit MultimodalDecoderRunner(Module *module, IOManager *io_manager) + : TextDecoderRunner(module, io_manager) {} + + inline ::executorch::runtime::Result<::executorch::aten::Tensor> + step(TensorPtr &tokens, int64_t start_pos) override { + auto embed_result = module_->execute(kTokenEmbeddingMethod, tokens); + if (!embed_result.ok()) { + return embed_result.error(); + } + return decode((*embed_result)[0], start_pos); + } + + inline ::executorch::runtime::Result<::executorch::aten::Tensor> + decode(const ::executorch::runtime::EValue &embeddings, int64_t start_pos) { + auto start_pos_tensor = ::executorch::extension::from_blob( + &start_pos, {1}, ::executorch::aten::ScalarType::Long); + auto outputs_result = + module_->execute(kTextModelMethod, {embeddings, start_pos_tensor}); + if (!outputs_result.ok()) { + return outputs_result.error(); + } + auto &outputs = *outputs_result; + ET_CHECK_MSG(outputs.size() == 1, + "Expected 1 output from text_decoder, got %zu", + outputs.size()); + ET_CHECK_MSG(outputs[0].isTensor(), "text_decoder output is not a tensor"); + return outputs[0].toTensor(); + } + + inline ::executorch::runtime::Error load() override { + if (is_method_loaded()) { + return ::executorch::runtime::Error::Ok; + } + ET_CHECK_OK_OR_RETURN_ERROR(module_->load_method(kTokenEmbeddingMethod)); + ET_CHECK_OK_OR_RETURN_ERROR(module_->load_method(kTextModelMethod)); + return ::executorch::runtime::Error::Ok; + } + + inline bool is_method_loaded() override { + return module_->is_method_loaded(kTokenEmbeddingMethod) && + module_->is_method_loaded(kTextModelMethod); + } +}; + +} // namespace executorch::extension::llm diff --git a/packages/react-native-executorch/common/runner/multimodal_input.h 
b/packages/react-native-executorch/common/runner/multimodal_input.h new file mode 100644 index 000000000..1dce55de0 --- /dev/null +++ b/packages/react-native-executorch/common/runner/multimodal_input.h @@ -0,0 +1,71 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +// Ported from executorch/extension/llm/runner/multimodal_input.h +// Audio support stripped — only text and image are used by LFM2-VL. + +#pragma once + +#include +#include +#include + +namespace executorch::extension::llm { +struct ImagePath { + std::string path; +}; + +class MultimodalInput { +public: + explicit MultimodalInput(const std::string &text) : data_(text) {} + explicit MultimodalInput(std::string &&text) : data_(std::move(text)) {} + explicit MultimodalInput(const std::vector &tokens) + : data_(tokens) {} + explicit MultimodalInput(std::vector &&tokens) + : data_(std::move(tokens)) {} + explicit MultimodalInput(ImagePath image_path) + : data_(std::move(image_path)) {} + + MultimodalInput(const MultimodalInput &) = default; + MultimodalInput &operator=(const MultimodalInput &) = default; + MultimodalInput(MultimodalInput &&) noexcept = default; + MultimodalInput &operator=(MultimodalInput &&) noexcept = default; + + bool is_text() const noexcept { + return std::holds_alternative(data_); + } + bool is_tokens() const noexcept { + return std::holds_alternative>(data_); + } + bool is_image() const noexcept { + return std::holds_alternative(data_); + } + + const std::string &get_text() const & { return std::get(data_); } + const std::vector &get_tokens() const & { + return std::get>(data_); + } + const std::string &get_image_path() const & { + return std::get(data_).path; + } + +private: + std::variant, ImagePath> data_; +}; + +inline MultimodalInput make_text_input(const std::string &text) noexcept { + return 
MultimodalInput(text); +} +inline MultimodalInput make_text_input(std::string &&text) noexcept { + return MultimodalInput(std::move(text)); +} +inline MultimodalInput make_image_input(std::string path) noexcept { + return MultimodalInput(ImagePath{std::move(path)}); +} + +} // namespace executorch::extension::llm diff --git a/packages/react-native-executorch/common/runner/multimodal_prefiller.cpp b/packages/react-native-executorch/common/runner/multimodal_prefiller.cpp new file mode 100644 index 000000000..358f0e0cd --- /dev/null +++ b/packages/react-native-executorch/common/runner/multimodal_prefiller.cpp @@ -0,0 +1,149 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +// Ported from executorch/extension/llm/runner/multimodal_prefiller.cpp +// with our token-embedding padding fix and LFM2-VL adaptations. + +#include "multimodal_prefiller.h" +#include "constants.h" +#include "util.h" + +namespace executorch::extension::llm { + +using ::executorch::aten::SizesType; +using ::executorch::runtime::Error; +using ::executorch::runtime::EValue; +using ::executorch::runtime::Result; + +MultimodalPrefiller::MultimodalPrefiller( + Module *module, MultimodalDecoderRunner *decoder_runner, + tokenizers::HFTokenizer *tokenizer, IOManager *io_manager, + IEncoder *image_encoder) + : module_(module), decoder_runner_(decoder_runner), tokenizer_(tokenizer), + io_manager_(io_manager), image_encoder_(image_encoder) {} + +Result MultimodalPrefiller::prefill(const MultimodalInput &input, + int64_t &start_pos) { + EValue encoder_output; + std::vector padded_tokens_storage; + TensorPtr sliced_embed_storage; + + if (input.is_image()) { + ET_CHECK_OR_RETURN_ERROR(image_encoder_ != nullptr, InvalidState, + "No image encoder registered"); + auto encode_result = image_encoder_->encode(input); + 
ET_CHECK_OK_OR_RETURN_ERROR(encode_result.error(), "Image encoding failed"); + encoder_output = *encode_result; + + } else if (input.is_text() || input.is_tokens()) { + std::vector tokens; + if (input.is_text()) { + auto encode_result = tokenizer_->encode(input.get_text()); + if (!encode_result.ok()) { + ET_LOG(Error, "Tokenizer encode error %d", + static_cast(encode_result.error())); + return Error::InvalidArgument; + } + tokens = std::move(*encode_result); + } else { + tokens = input.get_tokens(); + } + + const auto actual_seq_len = static_cast(tokens.size()); + + // The token_embedding PTE has a fixed MAX_SEQ_LEN input buffer. + // Pad with zeros, run embedding, then slice output back to actual length. + int64_t max_seq_len = actual_seq_len; // fallback: no padding needed + auto max_seq_len_result = module_->get(kMaxSeqLen); + if (max_seq_len_result.error() == Error::Ok) { + max_seq_len = max_seq_len_result->toScalar().to(); + } + + padded_tokens_storage.assign(max_seq_len, 0); + std::copy(tokens.begin(), tokens.end(), padded_tokens_storage.begin()); + + auto text_tensor = ::executorch::extension::from_blob( + padded_tokens_storage.data(), {1, static_cast(max_seq_len)}, + ::executorch::aten::ScalarType::Long); + + auto embed_result = module_->execute(kTokenEmbeddingMethod, text_tensor); + ET_CHECK_OK_OR_RETURN_ERROR(embed_result.error()); + + auto full_embed = (*embed_result)[0].toTensor(); + const auto embed_dim = static_cast(full_embed.size(2)); + sliced_embed_storage = ::executorch::extension::from_blob( + full_embed.mutable_data_ptr(), {1, actual_seq_len, embed_dim}, + ::executorch::aten::ScalarType::Float); + encoder_output = EValue(*sliced_embed_storage); + + } else { + ET_LOG(Error, "Unsupported MultimodalInput type"); + return Error::NotSupported; + } + + // Run text_decoder for prefill. 
+ int64_t seq_len = encoder_output.toTensor().size(1); + if (seq_len == 0) { + ET_LOG(Error, "Encoder returned empty output"); + return Error::InvalidState; + } + + std::vector cache_positions; + auto cache_pos_result = populate_start_pos_or_cache_position( + module_, start_pos, cache_positions, seq_len, kTextModelMethod); + ET_CHECK_OK_OR_RETURN_ERROR(cache_pos_result.error()); + + auto prefill_result = + module_->execute(kTextModelMethod, {encoder_output, *cache_pos_result}); + ET_CHECK_OK_OR_RETURN_ERROR(prefill_result.error()); + + auto &prefill_outputs = *prefill_result; + ET_CHECK_OR_RETURN_ERROR(!prefill_outputs.empty(), InvalidState, + "text_decoder returned no outputs during prefill"); + + auto logits = prefill_outputs[0].toTensor(); + start_pos += seq_len; + + return static_cast(decoder_runner_->logits_to_token(logits)); +} + +Error MultimodalPrefiller::load() { + if (is_method_loaded()) { + return Error::Ok; + } + ET_CHECK_OK_OR_RETURN_ERROR(module_->load_method(kTokenEmbeddingMethod)); + ET_CHECK_OK_OR_RETURN_ERROR(module_->load_method(kTextModelMethod)); + + auto method_names_result = module_->method_names(); + ET_CHECK_OK_OR_RETURN_ERROR(method_names_result.error(), + "Failed to get method names"); + const auto &methods = *method_names_result; + + if (methods.find(kVisionEncoderMethod) != methods.end()) { + ET_CHECK_OK_OR_RETURN_ERROR(module_->load_method(kVisionEncoderMethod)); + } + return Error::Ok; +} + +bool MultimodalPrefiller::is_method_loaded() { + auto methods_res = module_->method_names(); + if (methods_res.error() != Error::Ok) { + return false; + } + if (!module_->is_method_loaded(kTokenEmbeddingMethod) || + !module_->is_method_loaded(kTextModelMethod)) { + return false; + } + const auto &methods = *methods_res; + if (methods.find(kVisionEncoderMethod) != methods.end()) { + return module_->is_method_loaded(kVisionEncoderMethod); + } + return true; +} + +} // namespace executorch::extension::llm diff --git 
a/packages/react-native-executorch/common/runner/multimodal_prefiller.h b/packages/react-native-executorch/common/runner/multimodal_prefiller.h new file mode 100644 index 000000000..5f1978943 --- /dev/null +++ b/packages/react-native-executorch/common/runner/multimodal_prefiller.h @@ -0,0 +1,45 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +// Ported from executorch/extension/llm/runner/multimodal_prefiller.h + +#pragma once + +#include "multimodal_decoder_runner.h" +#include "multimodal_input.h" +#include +#include +#include + +namespace executorch::extension::llm { + +class MultimodalPrefiller { +public: + explicit MultimodalPrefiller(Module *module, + MultimodalDecoderRunner *decoder_runner, + tokenizers::HFTokenizer *tokenizer, + IOManager *io_manager, + IEncoder *image_encoder = nullptr); + + // Prefill one input segment. Updates start_pos in-place. + // Returns the first predicted token after this segment. 
+ ::executorch::runtime::Result prefill(const MultimodalInput &input, + int64_t &start_pos); + + ::executorch::runtime::Error load(); + bool is_method_loaded(); + +private: + Module *module_; + MultimodalDecoderRunner *decoder_runner_; + tokenizers::HFTokenizer *tokenizer_; + IOManager *io_manager_; + IEncoder *image_encoder_; +}; + +} // namespace executorch::extension::llm diff --git a/packages/react-native-executorch/common/runner/multimodal_runner.cpp b/packages/react-native-executorch/common/runner/multimodal_runner.cpp new file mode 100644 index 000000000..91ab6b181 --- /dev/null +++ b/packages/react-native-executorch/common/runner/multimodal_runner.cpp @@ -0,0 +1,133 @@ +// common/runner/multimodal_runner.cpp +#include "multimodal_runner.h" +#include "constants.h" +#include "util.h" +#include +#include + +namespace executorch::extension::llm { + +using ::executorch::extension::Module; +using ::executorch::runtime::Error; + +MultimodalRunner::MultimodalRunner( + std::unique_ptr module, const std::string &tokenizer_path, + std::map> encoders, + const GenerationConfig &config) + : BaseLLMRunner(std::move(module), tokenizer_path, config), + encoders_(std::move(encoders)) {} + +int32_t MultimodalRunner::get_visual_token_count() const { + auto it = encoders_.find(MultimodalType::Image); + if (it == encoders_.end()) { + return 0; + } + return it->second->encoderTokenCount(); +} + +bool MultimodalRunner::is_loaded() const { + if (!mm_prefiller_ || !mm_token_generator_) + return false; + if (!mm_prefiller_->is_method_loaded() || !mm_token_generator_->is_loaded()) + return false; + for (const auto &[type, encoder] : encoders_) { + if (!encoder->is_loaded()) + return false; + } + return true; +} + +Error MultimodalRunner::load_subcomponents() { + for (auto &[type, encoder] : encoders_) { + ET_CHECK_OK_OR_RETURN_ERROR(encoder->load()); + } + + Stats *stats_ptr = &stats_; + + mm_decoder_runner_ = std::make_unique( + module_.get(), io_manager_.get()); + IEncoder 
*image_encoder = nullptr; + auto enc_it = encoders_.find(MultimodalType::Image); + if (enc_it != encoders_.end()) { + image_encoder = enc_it->second.get(); + } + mm_prefiller_ = std::make_unique( + module_.get(), mm_decoder_runner_.get(), tokenizer_.get(), + io_manager_.get(), image_encoder); + mm_token_generator_ = std::make_unique( + tokenizer_.get(), mm_decoder_runner_.get(), /*use_kv_cache=*/true, + std::move(eos_ids_), stats_ptr); + + ET_CHECK_OK_OR_RETURN_ERROR(mm_prefiller_->load()); + ET_CHECK_OK_OR_RETURN_ERROR(mm_token_generator_->load()); + + return Error::Ok; +} + +Error MultimodalRunner::generate_internal( + const std::vector &inputs, + std::function token_callback) { + + if (inputs.empty()) + return Error::InvalidArgument; + if (!is_loaded()) + ET_CHECK_OK_OR_RETURN_ERROR(load()); + + stats_.inference_start_ms = time_in_ms(); + + uint64_t prefill_next_token = 0; + for (const auto &input : inputs) { + auto prefill_result = mm_prefiller_->prefill(input, pos_); + if (!prefill_result.ok()) + return prefill_result.error(); + prefill_next_token = prefill_result.get(); + } + + stats_.first_token_ms = time_in_ms(); + stats_.prompt_eval_end_ms = time_in_ms(); + stats_.num_prompt_tokens = pos_; + + int32_t resolved_max_new = resolve_max_new_tokens( + static_cast(pos_), config_.max_seq_len, + config_.max_context_length, config_.max_new_tokens); + + std::vector seed_tokens = {prefill_next_token}; + auto wrapped_callback = [&](const std::string &piece) { + safe_printf(piece.c_str()); + fflush(stdout); + if (token_callback) + token_callback(piece); + }; + + auto generate_result = mm_token_generator_->generate( + seed_tokens, pos_, + static_cast(std::max(0, resolved_max_new - 1)), + config_.temperature, config_.topp, wrapped_callback); + + if (!generate_result.ok()) + return generate_result.error(); + + int64_t num_generated = generate_result.get(); + pos_ += num_generated; + stats_.inference_end_ms = time_in_ms(); + stats_.num_generated_tokens = num_generated; + + 
return Error::Ok; +} + +void MultimodalRunner::stop_impl() { + if (mm_token_generator_) + mm_token_generator_->stop(); +} + +void MultimodalRunner::set_count_interval_impl(size_t count_interval) { + if (mm_token_generator_) + mm_token_generator_->set_count_interval(count_interval); +} + +void MultimodalRunner::set_time_interval_impl(size_t time_interval) { + if (mm_token_generator_) + mm_token_generator_->set_time_interval(time_interval); +} + +} // namespace executorch::extension::llm diff --git a/packages/react-native-executorch/common/runner/multimodal_runner.h b/packages/react-native-executorch/common/runner/multimodal_runner.h new file mode 100644 index 000000000..3c31c0165 --- /dev/null +++ b/packages/react-native-executorch/common/runner/multimodal_runner.h @@ -0,0 +1,45 @@ +#pragma once + +#include "base_llm_runner.h" +#include "encoders/iencoder.h" +#include "multimodal_decoder_runner.h" +#include "multimodal_input.h" +#include "multimodal_prefiller.h" +#include "text_token_generator.h" +#include + +namespace executorch::extension::llm { + +enum class MultimodalType { Image }; + +class MultimodalRunner : public BaseLLMRunner { +public: + explicit MultimodalRunner( + std::unique_ptr module, const std::string &tokenizer_path, + std::map> encoders, + const GenerationConfig &config = {.temperature = 0.8F, .topp = 0.9F}); + + bool is_loaded() const override; + bool is_multimodal() const override { return true; } + int32_t get_visual_token_count() const override; + + ::executorch::runtime::Error generate_internal( + const std::vector &inputs, + std::function token_callback) override; + +protected: + ::executorch::runtime::Error load_subcomponents() override; + void stop_impl() override; + void set_temperature_impl(float) override {} + void set_topp_impl(float) override {} + void set_count_interval_impl(size_t count_interval) override; + void set_time_interval_impl(size_t time_interval) override; + +private: + std::map> encoders_; + std::unique_ptr 
mm_decoder_runner_; + std::unique_ptr mm_prefiller_; + std::unique_ptr mm_token_generator_; +}; + +} // namespace executorch::extension::llm diff --git a/packages/react-native-executorch/common/runner/runner.cpp b/packages/react-native-executorch/common/runner/runner.cpp deleted file mode 100644 index 8e4660ac5..000000000 --- a/packages/react-native-executorch/common/runner/runner.cpp +++ /dev/null @@ -1,391 +0,0 @@ -/* - * Copyright (c) Meta Platforms, Inc. and affiliates. - * All rights reserved. - * - * This source code is licensed under the BSD-style license found in the - * LICENSE file in the root directory of this source tree. - * @lint-ignore-every CLANGTIDY facebook-hte-Deprecated - */ - -// A simple llama2 runner that includes preprocessing and post processing logic. -// The module takes in a string as input and emits a string as output. - -#include "runner.h" -#include "constants.h" -#include "util.h" -#include -#include -#include - -namespace example { - -using namespace executorch::extension::llm; -using ::executorch::extension::Module; -using ::executorch::runtime::Error; -using ::executorch::runtime::Result; - -Runner::Runner(Module *module, const std::string &tokenizer_path, - const llm::GenerationConfig &config) - : config_(config), module_(module), tokenizer_path_(tokenizer_path), - tokenizer_(std::make_unique()), - metadata_({ - {kEnableDynamicShape, false}, - {kMaxSeqLen, 128}, - {kMaxContextLen, 128}, - {kUseKVCache, true}, - {kUseSDPAWithKVCache, false}, - }) {} - -bool Runner::is_loaded() const { - return module_->is_loaded() && tokenizer_->is_loaded() && - text_decoder_runner_ && text_prefiller_ && text_token_generator_; -} - -Error Runner::load() { - if (is_loaded()) { - return Error::Ok; - } - - auto status = tokenizer_->load(tokenizer_path_); - - if (status != tokenizers::Error::Ok) { - throw rnexecutorch::RnExecutorchError( - rnexecutorch::RnExecutorchErrorCode::TokenizerError, - "Unexpected issue occured while loading tokenizer"); - }; 
- - ET_CHECK_OK_OR_RETURN_ERROR(module_->load_method("forward")); - - ET_LOG(Info, "Reading metadata from model"); - - auto eos_ids = std::make_unique>(); - metadata_[kVocabSize] = tokenizer_->vocab_size(); - - // Load model metadata - const auto method_names = - ET_UNWRAP(module_->method_names(), "Failed reading method names"); - for (auto &pair : metadata_) { - const auto &method_name = pair.first; - auto &value = pair.second; - if (method_names.count(method_name)) { - value = ET_UNWRAP(module_->get(method_name)) - .toScalar() - .to(); - } else { - ET_LOG(Info, "Method %s not found, using the default value %" PRId64, - method_name.c_str(), value); - } - ET_LOG(Info, "Metadata: %s = %" PRId64, method_name.c_str(), value); - } - - // Load EOS token ids - if (method_names.count(kEosIds)) { - eos_ids->clear(); - for (const auto &eos_id : ET_UNWRAP(module_->execute(kEosIds))) { - auto value = eos_id.toScalar().to(); - eos_ids->emplace(value); - ET_LOG(Info, "eos_id = %" PRId64, value); - } - } - - // Determine missing config values - // If user does not directly specify configuration parameters such as - // max_seq_len (i.e. leaves them as default values), they are determined by - // reading the exported model's methods. 
- if (config_.max_seq_len < 0) - config_.max_seq_len = static_cast(metadata_.at(kMaxSeqLen)); - if (config_.max_context_length < 0) - config_.max_context_length = - static_cast(metadata_.at(kMaxContextLen)); - if (config_.max_new_tokens < 0) - config_.max_new_tokens = - std::min(config_.max_seq_len, config_.max_context_length); - if (config_.enable_dynamic_shape) - config_.enable_dynamic_shape = - static_cast(metadata_.at(kEnableDynamicShape)); - if (config_.enable_kv_cache) - config_.enable_kv_cache = static_cast(metadata_.at(kUseKVCache)); - - io_manager_ = std::make_unique(*module_); - text_decoder_runner_ = std::make_unique( - module_, io_manager_.get(), config_.temperature, config_.topp); - text_prefiller_ = std::make_unique( - text_decoder_runner_.get(), config_.enable_kv_cache, - config_.enable_dynamic_shape, config_.max_seq_len); - - text_token_generator_ = std::make_unique( - tokenizer_.get(), text_decoder_runner_.get(), config_.enable_kv_cache, - std::move(eos_ids), &stats_); - - return Error::Ok; -} - -// Don't print with the same priority during warmup -#define RUNNER_ET_LOG(warmup, format, ...) \ - if (warmup) { \ - ET_LOG(Debug, format, __VA_ARGS__); \ - } else { \ - ET_LOG(Info, format, __VA_ARGS__); \ - } - -Error Runner::generate(const std::string &prompt, - const llm::GenerationConfig &generation_config, - std::function token_callback, - std::function stats_callback) { - // Prepare the inputs. - // Use ones-initialized inputs. 
- ET_CHECK_MSG(!prompt.empty(), "Prompt cannot be null"); - if (!is_loaded()) { - stats_.model_load_start_ms = llm::time_in_ms(); - ET_CHECK_OK_OR_RETURN_ERROR(load()); - stats_.model_load_end_ms = llm::time_in_ms(); - } - - if (generation_config.warming) { - ET_LOG(Info, "Doing a warmup run..."); - } - - RUNNER_ET_LOG(generation_config.warming, - "RSS after loading model: %f MiB (0 if unsupported)", - llm::get_rss_bytes() / 1024.0 / 1024.0); - - // Wrap the token_callback with print function - std::function wrapped_callback = - [token_callback, &generation_config](const std::string &piece) { - if (!generation_config.warming) { - llm::safe_printf(piece.c_str()); - fflush(stdout); - } - if (token_callback) { - token_callback(piece); - } - }; - // First token time only measures the time it takes to encode the prompt and - // return a response token. - - stats_.inference_start_ms = llm::time_in_ms(); - shouldStop_ = false; - - // Override main config fields with given generation config if specified - int32_t max_seq_len = generation_config.max_seq_len >= 0 - ? generation_config.max_seq_len - : config_.max_seq_len; - int32_t max_context_length = generation_config.max_context_length >= 0 - ? generation_config.max_context_length - : config_.max_context_length; - int32_t new_tokens_limit = generation_config.max_new_tokens >= 0 - ? generation_config.max_new_tokens - : config_.max_new_tokens; - float temperature = generation_config.temperature >= 0.F - ? generation_config.temperature - : config_.temperature; - float topp = - generation_config.topp >= 0.F ? 
generation_config.topp : config_.topp; - - int64_t context_len_left = static_cast(max_context_length) - pos_; - - // If the used tokenizer.json has defined post_processor field, - // setting any of bos or eos arguments to value other than provided constant - // ( which is 0) will result in running the post_processor with - // 'add_special_token' flag - auto encodeResult = - tokenizer_->encode(prompt, numOfAddedBoSTokens, numOfAddedEoSTokens); - if (!encodeResult.ok()) { - throw rnexecutorch::RnExecutorchError( - rnexecutorch::RnExecutorchErrorCode::TokenizerError, - "Unexpected issue occured while encoding: " + - std::to_string(static_cast(encodeResult.error()))); - } - std::vector prompt_tokens = encodeResult.get(); - - std::vector prompt_tokens_uint64(prompt_tokens.begin(), - prompt_tokens.end()); - - // encode the (string) prompt into tokens sequence - int num_prompt_tokens = prompt_tokens.size(); - - ET_CHECK_OR_RETURN_ERROR(num_prompt_tokens >= 1, InvalidArgument, - "Expected at least 1 prompt token"); - ET_CHECK_OR_RETURN_ERROR(num_prompt_tokens < max_seq_len, InvalidArgument, - "num_prompt_tokens %d >= max_context_len %" PRId32 - ", Max seq length exceeded - please increase max " - "seq len value in your export script", - num_prompt_tokens, max_seq_len); - - // Determine max_new_tokens using the GenerationConfig's resolve method, - // then subtract pos_ for max_new_tokens. - int32_t max_new_tokens = resolve_max_new_tokens( - num_prompt_tokens, max_seq_len, static_cast(context_len_left), - new_tokens_limit); - - ET_LOG(Info, - "Max new tokens resolved: %d, given pos_ %" PRId64 - ", num_prompt_tokens %zu, max_context_len %" PRId64, - max_new_tokens, pos_, prompt_tokens.size(), - static_cast(max_context_length)); - ET_CHECK_OR_RETURN_ERROR(max_new_tokens > 0, InvalidArgument, - "Max new tokens %d is less than or equal to 0", - max_new_tokens); - - // Prefill first - // Here feed all tokens to the model and get the next predicted token - // after the prompt. 
After that we will enter generate loop. - - // print prompts - if (generation_config.echo) { - wrapped_callback(prompt); - } - auto prefill_res = text_prefiller_->prefill(prompt_tokens_uint64, pos_); - stats_.first_token_ms = llm::time_in_ms(); - stats_.prompt_eval_end_ms = llm::time_in_ms(); - ET_CHECK_OK_OR_RETURN_ERROR(prefill_res.error()); - uint64_t cur_token = prefill_res.get(); - auto decodeResult = tokenizer_->decode({cur_token}); - if (!decodeResult.ok()) { - throw rnexecutorch::RnExecutorchError( - rnexecutorch::RnExecutorchErrorCode::TokenizerError, - "Unexpected issue occured while decoding: " + - std::to_string(static_cast(decodeResult.error()))); - } - const std::string cur_decoded = decodeResult.get(); - RUNNER_ET_LOG(generation_config.warming, - "RSS after prompt prefill: %f MiB (0 if unsupported)", - llm::get_rss_bytes() / 1024.0 / 1024.0); - - // start the main loop - prompt_tokens_uint64.push_back(cur_token); - int64_t num_generated_tokens = ET_UNWRAP(text_token_generator_->generate( - prompt_tokens_uint64, pos_, max_new_tokens - 1, temperature, topp, - wrapped_callback)); - - pos_ += num_generated_tokens; - - stats_.inference_end_ms = llm::time_in_ms(); - if (!generation_config.warming) { - printf("\n"); - } - RUNNER_ET_LOG( - generation_config.warming, - "RSS after finishing text generation: %f MiB (0 if unsupported)", - llm::get_rss_bytes() / 1024.0 / 1024.0); - - if (num_generated_tokens == max_new_tokens) { - RUNNER_ET_LOG(generation_config.warming, "Max new tokens %i reached!", - max_new_tokens); - } - - stats_.num_prompt_tokens = num_prompt_tokens; - stats_.num_generated_tokens = num_generated_tokens; - - if (generation_config.warming) { - ET_LOG(Info, "Warmup run finished!"); - } else { - // Do not print report during warmup -#ifndef TEST_BUILD - ::executorch::llm::print_report(stats_); -#endif - } - if (stats_callback) { - stats_callback(stats_); - } - - return Error::Ok; -} - -Error Runner::warmup(const std::string &prompt) { - // 
Create a GenerationConfig for warmup - llm::GenerationConfig config{.echo = false, .warming = true}; - - // Call generate with the warmup config - Error err = generate(prompt, config, - /*token_callback=*/nullptr, - /*stats_callbak=*/nullptr); - - // Reset stats after warmup - reset(); - - return err; -} - -void Runner::stop() { - if (is_loaded()) { - text_token_generator_->stop(); - } else { - ET_LOG(Error, "Token generator is not loaded, cannot stop"); - } -} - -void Runner::reset() { - stats_.reset(); - pos_ = 0; -} - -void Runner::set_count_interval(size_t count_interval) { - text_token_generator_->set_count_interval(count_interval); -} - -void Runner::set_time_interval(size_t time_interval) { - text_token_generator_->set_time_interval(time_interval); -} - -void Runner::set_temperature(float temperature) noexcept { - config_.temperature = temperature; - if (text_decoder_runner_) { - text_decoder_runner_->set_temperature(temperature); - } -} - -void Runner::set_topp(float topp) noexcept { - config_.topp = topp; - if (text_decoder_runner_) { - text_decoder_runner_->set_topp(topp); - } -} - -int32_t Runner::get_max_context_length() const { - if (!is_loaded()) { - return metadata_.at(kMaxContextLen); - } - return config_.max_context_length; -} - -int32_t Runner::count_text_tokens(const std::string &text) const { - auto encodeResult = - tokenizer_->encode(text, numOfAddedBoSTokens, numOfAddedEoSTokens); - - if (!encodeResult.ok()) { - throw rnexecutorch::RnExecutorchError( - rnexecutorch::RnExecutorchErrorCode::TokenizerError, - "Encoding failed during token count check."); - } - - return encodeResult.get().size(); -} - -int32_t Runner::resolve_max_new_tokens(int32_t num_prompt_tokens, - int32_t max_seq_len, - int32_t max_context_len, - int32_t max_new_tokens) const { - int32_t result; - - if (max_seq_len == -1 && max_new_tokens == -1) { - // Both are -1, use max context len minus prompt tokens - result = max_context_len - num_prompt_tokens; - } else if (max_seq_len 
== -1 && max_new_tokens != -1) { - // Only max_new_tokens is specified - result = std::min(max_new_tokens, max_context_len - num_prompt_tokens); - } else if (max_seq_len != -1 && max_new_tokens == -1) { - // Only seq_len is specified - result = std::min(max_seq_len, max_context_len) - num_prompt_tokens; - } else { - // Both are specified - result = - std::min(std::min(max_seq_len, max_context_len) - num_prompt_tokens, - max_new_tokens); - } - - // Ensure result is not negative - return std::max(0, result); -} - -} // namespace example diff --git a/packages/react-native-executorch/common/runner/runner.h b/packages/react-native-executorch/common/runner/runner.h deleted file mode 100644 index 03dff39bc..000000000 --- a/packages/react-native-executorch/common/runner/runner.h +++ /dev/null @@ -1,87 +0,0 @@ -/* - * Copyright (c) Meta Platforms, Inc. and affiliates. - * All rights reserved. - * - * This source code is licensed under the BSD-style license found in the - * LICENSE file in the root directory of this source tree. - */ - -// A simple llama2 runner that includes preprocessing and post processing logic. -// The module takes in a string as input and emits a string as output. 
- -#pragma once - -#include "irunner.h" -#include "stats.h" -#include "text_decoder_runner.h" -#include "text_prefiller.h" -#include "text_token_generator.h" -#include -#include -#include -#include -#include -#include -#include -#include - -namespace example { - -namespace llm = ::executorch::extension::llm; - -class Runner : public llm::IRunner { -public: - explicit Runner(::executorch::extension::Module *module, - const std::string &tokenizer_path, - const llm::GenerationConfig &config = { - .temperature = 0.8F, .topp = 0.9F}); // The main config - - bool is_loaded() const override; - ::executorch::runtime::Error load() override; - ::executorch::runtime::Error generate( - const std::string &prompt, - const llm::GenerationConfig &generation_config = - {}, // An extra config which temporarily overrides previous model - // settings - std::function token_callback = {}, - std::function stats_callback = {}) override; - ::executorch::runtime::Error warmup(const std::string &prompt); - void set_count_interval(size_t count_interval); - void set_time_interval(size_t time_interval); - void set_temperature(float temperature) noexcept; - void set_topp(float topp) noexcept; - int32_t count_text_tokens(const std::string &text) const; - int32_t get_max_context_length() const; - - void stop() override; - void reset() override; - - llm::Stats stats_; - -private: - // Helper functions - int32_t resolve_max_new_tokens(int32_t num_prompt_tokens, int32_t max_seq_len, - int32_t max_context_len, - int32_t max_new_tokens = -1) const; - - // Main config - llm::GenerationConfig config_; - - // Flow control - bool shouldStop_{false}; - int64_t pos_ = 0; // The position in KV cache of the input, starting from 0. 
- - // Main model - ::executorch::extension::Module *module_; - - // Subcomponents - std::string tokenizer_path_; - std::unique_ptr tokenizer_; - std::unordered_map metadata_; - std::unique_ptr io_manager_; - std::unique_ptr text_decoder_runner_; - std::unique_ptr text_prefiller_; - std::unique_ptr text_token_generator_; -}; - -} // namespace example diff --git a/packages/react-native-executorch/common/runner/text_runner.cpp b/packages/react-native-executorch/common/runner/text_runner.cpp new file mode 100644 index 000000000..5b9dd441e --- /dev/null +++ b/packages/react-native-executorch/common/runner/text_runner.cpp @@ -0,0 +1,152 @@ +// common/runner/text_runner.cpp +#include "text_runner.h" +#include "constants.h" +#include "util.h" +#include +#include +#include + +namespace executorch::extension::llm { + +using ::executorch::extension::Module; +using ::executorch::runtime::Error; + +TextRunner::TextRunner(std::unique_ptr module, + const std::string &tokenizer_path, + const GenerationConfig &config) + : BaseLLMRunner(std::move(module), tokenizer_path, config) {} + +bool TextRunner::is_loaded() const { + return module_ && module_->is_loaded() && tokenizer_ && + tokenizer_->is_loaded() && text_decoder_runner_ && text_prefiller_ && + text_token_generator_; +} + +Error TextRunner::load_subcomponents() { + ET_CHECK_OK_OR_RETURN_ERROR(module_->load_method("forward")); + + Stats *stats_ptr = &stats_; + + text_decoder_runner_ = std::make_unique( + module_.get(), io_manager_.get(), config_.temperature, config_.topp); + text_prefiller_ = std::make_unique( + text_decoder_runner_.get(), config_.enable_kv_cache, + config_.enable_dynamic_shape, config_.max_seq_len); + text_token_generator_ = std::make_unique( + tokenizer_.get(), text_decoder_runner_.get(), config_.enable_kv_cache, + std::move(eos_ids_), stats_ptr); + + return Error::Ok; +} + +Error TextRunner::generate_internal( + const std::vector &inputs, + std::function token_callback) { + + if (inputs.empty()) { + return 
Error::InvalidArgument; + } + + const std::string &prompt = inputs[0].get_text(); + ET_CHECK_MSG(!prompt.empty(), "Prompt cannot be null"); + + if (!is_loaded()) { + stats_.model_load_start_ms = time_in_ms(); + ET_CHECK_OK_OR_RETURN_ERROR(load()); + stats_.model_load_end_ms = time_in_ms(); + } + + std::function wrapped_callback = + [token_callback](const std::string &piece) { + safe_printf(piece.c_str()); + fflush(stdout); + if (token_callback) + token_callback(piece); + }; + + stats_.inference_start_ms = time_in_ms(); + + int64_t context_len_left = + static_cast(config_.max_context_length) - pos_; + + auto encodeResult = + tokenizer_->encode(prompt, numOfAddedBoSTokens, numOfAddedEoSTokens); + if (!encodeResult.ok()) { + throw rnexecutorch::RnExecutorchError( + rnexecutorch::RnExecutorchErrorCode::TokenizerError, + "Unexpected issue occurred while encoding: " + + std::to_string(static_cast(encodeResult.error()))); + } + std::vector prompt_tokens = encodeResult.get(); + int num_prompt_tokens = prompt_tokens.size(); + + ET_CHECK_OR_RETURN_ERROR(num_prompt_tokens >= 1, InvalidArgument, + "Expected at least 1 prompt token"); + ET_CHECK_OR_RETURN_ERROR(num_prompt_tokens < config_.max_seq_len, + InvalidArgument, + "num_prompt_tokens %d >= max_seq_len %" PRId32, + num_prompt_tokens, config_.max_seq_len); + + int32_t max_new_tokens = resolve_max_new_tokens( + num_prompt_tokens, config_.max_seq_len, + static_cast(context_len_left), config_.max_new_tokens); + + ET_CHECK_OR_RETURN_ERROR(max_new_tokens > 0, InvalidArgument, + "Max new tokens %d is <= 0", max_new_tokens); + + if (config_.echo) + wrapped_callback(prompt); + + auto prefill_res = text_prefiller_->prefill(prompt_tokens, pos_); + stats_.first_token_ms = time_in_ms(); + stats_.prompt_eval_end_ms = time_in_ms(); + ET_CHECK_OK_OR_RETURN_ERROR(prefill_res.error()); + + uint64_t cur_token = prefill_res.get(); + auto decodeResult = tokenizer_->decode({cur_token}); + if (!decodeResult.ok()) { + throw 
rnexecutorch::RnExecutorchError( + rnexecutorch::RnExecutorchErrorCode::TokenizerError, + "Unexpected issue occurred while decoding: " + + std::to_string(static_cast(decodeResult.error()))); + } + + prompt_tokens.push_back(cur_token); + int64_t num_generated = ET_UNWRAP(text_token_generator_->generate( + prompt_tokens, pos_, max_new_tokens - 1, config_.temperature, + config_.topp, wrapped_callback)); + + pos_ += num_generated; + stats_.inference_end_ms = time_in_ms(); + stats_.num_prompt_tokens = num_prompt_tokens; + stats_.num_generated_tokens = num_generated; + + return Error::Ok; +} + +void TextRunner::stop_impl() { + if (text_token_generator_) + text_token_generator_->stop(); +} + +void TextRunner::set_temperature_impl(float temperature) { + if (text_decoder_runner_) + text_decoder_runner_->set_temperature(temperature); +} + +void TextRunner::set_topp_impl(float topp) { + if (text_decoder_runner_) + text_decoder_runner_->set_topp(topp); +} + +void TextRunner::set_count_interval_impl(size_t count_interval) { + if (text_token_generator_) + text_token_generator_->set_count_interval(count_interval); +} + +void TextRunner::set_time_interval_impl(size_t time_interval) { + if (text_token_generator_) + text_token_generator_->set_time_interval(time_interval); +} + +} // namespace executorch::extension::llm diff --git a/packages/react-native-executorch/common/runner/text_runner.h b/packages/react-native-executorch/common/runner/text_runner.h new file mode 100644 index 000000000..4fce0e815 --- /dev/null +++ b/packages/react-native-executorch/common/runner/text_runner.h @@ -0,0 +1,38 @@ +// common/runner/text_runner.h +#pragma once + +#include "base_llm_runner.h" +#include "text_decoder_runner.h" +#include "text_prefiller.h" +#include "text_token_generator.h" + +namespace executorch::extension::llm { + +class TextRunner : public BaseLLMRunner { +public: + explicit TextRunner(std::unique_ptr<::executorch::extension::Module> module, + const std::string &tokenizer_path, + 
const GenerationConfig &config = {.temperature = 0.8F, + .topp = 0.9F}); + + bool is_loaded() const override; + + ::executorch::runtime::Error generate_internal( + const std::vector &inputs, + std::function token_callback) override; + +protected: + ::executorch::runtime::Error load_subcomponents() override; + void stop_impl() override; + void set_temperature_impl(float temperature) override; + void set_topp_impl(float topp) override; + void set_count_interval_impl(size_t count_interval) override; + void set_time_interval_impl(size_t time_interval) override; + +private: + std::unique_ptr text_decoder_runner_; + std::unique_ptr text_prefiller_; + std::unique_ptr text_token_generator_; +}; + +} // namespace executorch::extension::llm diff --git a/packages/react-native-executorch/src/constants/modelUrls.ts b/packages/react-native-executorch/src/constants/modelUrls.ts index 499abf63a..a59c1d7f3 100644 --- a/packages/react-native-executorch/src/constants/modelUrls.ts +++ b/packages/react-native-executorch/src/constants/modelUrls.ts @@ -371,6 +371,22 @@ export const LFM2_5_1_2B_INSTRUCT_QUANTIZED = { tokenizerConfigSource: LFM2_5_1_2B_TOKENIZER_CONFIG, }; +// LFM2.5-VL-1.6B (Vision-Language) +const LFM2_VL_1_6B_QUANTIZED_MODEL = `https://huggingface.co/nklockiewicz/lfm2-vl-et/resolve/main/lfm2_5_vl_quantized_xnnpack_v2.pte`; +const LFM2_VL_TOKENIZER = `https://huggingface.co/nklockiewicz/lfm2-vl-et/resolve/main/tokenizer_2.5.json`; +const LFM2_VL_TOKENIZER_CONFIG = `https://huggingface.co/nklockiewicz/lfm2-vl-et/resolve/main/tokenizer_config_2_5.json`; + +/** + * @category Models - VLM + */ +export const LFM2_VL_1_6B_QUANTIZED = { + modelName: 'lfm2.5-vl-1.6b-quantized', + capabilities: ['vision'] as const, + modelSource: LFM2_VL_1_6B_QUANTIZED_MODEL, + tokenizerSource: LFM2_VL_TOKENIZER, + tokenizerConfigSource: LFM2_VL_TOKENIZER_CONFIG, +}; + // Classification const EFFICIENTNET_V2_S_MODEL = Platform.OS === `ios` diff --git 
a/packages/react-native-executorch/src/controllers/LLMController.ts b/packages/react-native-executorch/src/controllers/LLMController.ts index 702a00c45..7422c936c 100644 --- a/packages/react-native-executorch/src/controllers/LLMController.ts +++ b/packages/react-native-executorch/src/controllers/LLMController.ts @@ -5,6 +5,7 @@ import { DEFAULT_CHAT_CONFIG } from '../constants/llmDefaults'; import { ChatConfig, GenerationConfig, + LLMCapability, LLMTool, Message, SPECIAL_TOKENS, @@ -24,7 +25,6 @@ export class LLMController { private _isReady = false; private _isGenerating = false; private _messageHistory: Message[] = []; - // User callbacks private tokenCallback: (token: string) => void; private messageHistoryCallback: (messageHistory: Message[]) => void; @@ -75,11 +75,13 @@ export class LLMController { modelSource, tokenizerSource, tokenizerConfigSource, + capabilities, onDownloadProgressCallback, }: { modelSource: ResourceSource; tokenizerSource: ResourceSource; tokenizerConfigSource: ResourceSource; + capabilities?: readonly LLMCapability[]; onDownloadProgressCallback?: (downloadProgress: number) => void; }) { // reset inner state when loading new model @@ -118,7 +120,11 @@ export class LLMController { this.tokenizerConfig = JSON.parse( await ResourceFetcher.fs.readAsString(tokenizerConfigPath!) ); - this.nativeModule = global.loadLLM(modelPath, tokenizerPath); + this.nativeModule = global.loadLLM( + modelPath, + tokenizerPath, + capabilities ?? 
[] + ); this.isReadyCallback(true); this.onToken = (data: string) => { if (!data) { @@ -212,7 +218,7 @@ export class LLMController { this.isGeneratingCallback(false); } - public async forward(input: string): Promise { + public async forward(input: string, imagePaths?: string[]): Promise { if (!this._isReady) { throw new RnExecutorchError( RnExecutorchErrorCode.ModuleNotLoaded, @@ -228,7 +234,15 @@ export class LLMController { try { this.isGeneratingCallback(true); this.nativeModule.reset(); - const response = await this.nativeModule.generate(input, this.onToken); + const response = + imagePaths && imagePaths.length > 0 + ? await this.nativeModule.generateMultimodal( + input, + imagePaths, + this.tokenizerConfig?.image_token ?? '', + this.onToken + ) + : await this.nativeModule.generate(input, this.onToken); return this.filterSpecialTokens(response); } catch (e) { throw parseUnknownError(e); @@ -273,7 +287,8 @@ export class LLMController { public async generate( messages: Message[], - tools?: LLMTool[] + tools?: LLMTool[], + imagePaths?: string[] ): Promise { if (!this._isReady) { throw new RnExecutorchError( @@ -301,16 +316,35 @@ export class LLMController { { tools_in_user_message: false, add_generation_prompt: true } ); - return await this.forward(renderedChat); + return await this.forward(renderedChat, imagePaths); } - public async sendMessage(message: string): Promise { - const updatedHistory = [ - ...this._messageHistory, - { content: message, role: 'user' as const }, - ]; + public async sendMessage( + message: string, + media?: { imagePath?: string } + ): Promise { + const mediaPath = media?.imagePath; + const newMessage: Message = { + content: message, + role: 'user', + ...(mediaPath ? { mediaPath } : {}), + }; + const updatedHistory = [...this._messageHistory, newMessage]; this.messageHistoryCallback(updatedHistory); + const historyForTemplate = updatedHistory.map((m) => + m.mediaPath + ? 
{ + ...m, + content: [ + { type: 'image' }, + { type: 'text', text: m.content }, + ] as any, + } + : m + ); + + const visualTokenCount = this.nativeModule.getVisualTokenCount(); const countTokensCallback = (messages: Message[]) => { const rendered = this.applyChatTemplate( messages, @@ -319,20 +353,27 @@ export class LLMController { // eslint-disable-next-line camelcase { tools_in_user_message: false, add_generation_prompt: true } ); - return this.nativeModule.countTextTokens(rendered); + const textTokens = this.nativeModule.countTextTokens(rendered); + const imageCount = messages.filter((m) => m.mediaPath).length; + return textTokens + imageCount * (visualTokenCount - 1); }; const maxContextLength = this.nativeModule.getMaxContextLength(); const messageHistoryWithPrompt = this.chatConfig.contextStrategy.buildContext( this.chatConfig.systemPrompt, - updatedHistory, + historyForTemplate, maxContextLength, countTokensCallback ); + const imagePaths = messageHistoryWithPrompt + .filter((m) => m.mediaPath) + .map((m) => m.mediaPath!); + const response = await this.generate( messageHistoryWithPrompt, - this.toolsConfig?.tools + this.toolsConfig?.tools, + imagePaths.length > 0 ? 
imagePaths : undefined ); if (!this.toolsConfig || this.toolsConfig.displayToolCalls) { @@ -341,24 +382,23 @@ export class LLMController { { content: response, role: 'assistant' }, ]); } - if (!this.toolsConfig) { - return response; - } - const toolCalls = parseToolCall(response); - - for (const toolCall of toolCalls) { - this.toolsConfig - .executeToolCallback(toolCall) - .then((toolResponse: string | null) => { - if (toolResponse) { - this.messageHistoryCallback([ - ...this._messageHistory, - { content: toolResponse, role: 'assistant' }, - ]); - } - }); + if (this.toolsConfig) { + const toolCalls = parseToolCall(response); + for (const toolCall of toolCalls) { + this.toolsConfig + .executeToolCallback(toolCall) + .then((toolResponse: string | null) => { + if (toolResponse) { + this.messageHistoryCallback([ + ...this._messageHistory, + { content: toolResponse, role: 'assistant' }, + ]); + } + }); + } } + return response; } diff --git a/packages/react-native-executorch/src/hooks/natural_language_processing/useLLM.ts b/packages/react-native-executorch/src/hooks/natural_language_processing/useLLM.ts index 5578c1de7..72c7f4d96 100644 --- a/packages/react-native-executorch/src/hooks/natural_language_processing/useLLM.ts +++ b/packages/react-native-executorch/src/hooks/natural_language_processing/useLLM.ts @@ -1,9 +1,11 @@ import { useCallback, useEffect, useState } from 'react'; import { + LLMCapability, LLMConfig, LLMProps, LLMTool, LLMType, + LLMTypeMultimodal, Message, } from '../../types/llm'; import { LLMController } from '../../controllers/LLMController'; @@ -14,9 +16,16 @@ import { RnExecutorchError, parseUnknownError } from '../../errors/errorUtils'; * * @category Hooks * @param model - Object containing model, tokenizer, and tokenizer config sources. - * @returns An object implementing the `LLMType` interface for interacting with the LLM. 
+ * @returns An object implementing the `LLMTypeMultimodal` interface when `model.capabilities` is provided, otherwise `LLMType`. */ -export const useLLM = ({ model, preventLoad = false }: LLMProps): LLMType => { +export function useLLM( + props: LLMProps & { model: { capabilities: C } } +): LLMTypeMultimodal; +export function useLLM(props: LLMProps): LLMType; +export function useLLM({ + model, + preventLoad = false, +}: LLMProps): LLMType | LLMTypeMultimodal { const [token, setToken] = useState(''); const [response, setResponse] = useState(''); const [messageHistory, setMessageHistory] = useState([]); @@ -24,6 +33,7 @@ export const useLLM = ({ model, preventLoad = false }: LLMProps): LLMType => { const [isGenerating, setIsGenerating] = useState(false); const [downloadProgress, setDownloadProgress] = useState(0); const [error, setError] = useState(null); + const capabilitiesKey = model.capabilities?.join(',') ?? ''; const tokenCallback = useCallback((newToken: string) => { setToken(newToken); @@ -52,6 +62,7 @@ export const useLLM = ({ model, preventLoad = false }: LLMProps): LLMType => { modelSource: model.modelSource, tokenizerSource: model.tokenizerSource, tokenizerConfigSource: model.tokenizerConfigSource!, + capabilities: model.capabilities, onDownloadProgressCallback: setDownloadProgress, }); } catch (e) { @@ -64,11 +75,13 @@ export const useLLM = ({ model, preventLoad = false }: LLMProps): LLMType => { controllerInstance.delete(); } }; + // eslint-disable-next-line react-hooks/exhaustive-deps }, [ controllerInstance, model.modelSource, model.tokenizerSource, model.tokenizerConfigSource, + capabilitiesKey, // intentional: serialized string to avoid array reference re-runs preventLoad, ]); @@ -84,17 +97,17 @@ export const useLLM = ({ model, preventLoad = false }: LLMProps): LLMType => { ); const generate = useCallback( - (messages: Message[], tools?: LLMTool[]) => { + (messages: Message[], tools?: LLMTool[], imagePaths?: string[]) => { setResponse(''); - return 
controllerInstance.generate(messages, tools); + return controllerInstance.generate(messages, tools, imagePaths); }, [controllerInstance] ); const sendMessage = useCallback( - (message: string) => { + (message: string, media?: { imagePath?: string; audioPath?: string }) => { setResponse(''); - return controllerInstance.sendMessage(message); + return controllerInstance.sendMessage(message, media); }, [controllerInstance] ); @@ -141,4 +154,4 @@ export const useLLM = ({ model, preventLoad = false }: LLMProps): LLMType => { deleteMessage: deleteMessage, interrupt: interrupt, }; -}; +} diff --git a/packages/react-native-executorch/src/index.ts b/packages/react-native-executorch/src/index.ts index dd7557ca2..aa1e2b14d 100644 --- a/packages/react-native-executorch/src/index.ts +++ b/packages/react-native-executorch/src/index.ts @@ -4,6 +4,7 @@ import { ResourceFetcherAdapter, } from './utils/ResourceFetcher'; import { Triple } from './types/common'; +import { LLMCapability } from './types/llm'; /** * Configuration that goes to the `initExecutorch`. * You can pass either bare React Native or Expo configuration. 
@@ -48,7 +49,11 @@ declare global { var loadImageEmbeddings: (source: string) => any; var loadVAD: (source: string) => any; var loadTextEmbeddings: (modelSource: string, tokenizerSource: string) => any; - var loadLLM: (modelSource: string, tokenizerSource: string) => any; + var loadLLM: ( + modelSource: string, + tokenizerSource: string, + capabilities: readonly LLMCapability[] + ) => any; var loadTextToImage: ( tokenizerSource: string, encoderSource: string, diff --git a/packages/react-native-executorch/src/types/llm.ts b/packages/react-native-executorch/src/types/llm.ts index 25d87e248..ac57355f1 100644 --- a/packages/react-native-executorch/src/types/llm.ts +++ b/packages/react-native-executorch/src/types/llm.ts @@ -1,6 +1,19 @@ import { RnExecutorchError } from '../errors/errorUtils'; import { ResourceSource } from './common'; +/** + * Capabilities a multimodal LLM can have. + * @category Types + */ +export type LLMCapability = 'vision'; + +/** + * Derives the media argument shape for `sendMessage` from a capabilities tuple. + * @category Types + */ +export type MediaArg<C extends readonly LLMCapability[]> = + 'vision' extends C[number] ? { imagePath?: string } : object; + /** * Properties for initializing and configuring a Large Language Model (LLM) instance. * @@ -19,7 +32,13 @@ export interface LLMProps { /** * `ResourceSource` pointing to the JSON file which contains the tokenizer config. */ - tokenizerConfigSource?: ResourceSource; + tokenizerConfigSource: ResourceSource; + /** + * Optional list of modality capabilities the model supports. + * Determines the type of the `media` argument in `sendMessage`. + * Example: `['vision']` enables `sendMessage(text, { imagePath })`. + */ + capabilities?: readonly LLMCapability[]; }; /** * Boolean that can prevent automatic model loading (and downloading the data if you load it for the first time) after running the hook. @@ -28,11 +47,11 @@ export interface LLMProps { } /** - * React hook for managing a Large Language Model (LLM) instance. 
+ * Base return type for `useLLM`. Contains all fields except `sendMessage`. * * @category Types */ -export interface LLMType { +export interface LLMTypeBase { /** * History containing all messages in conversation. This field is updated after model responds to sendMessage. */ @@ -89,9 +108,13 @@ export interface LLMType { * @param tools - Optional array of tools that can be used during generation. * @returns The generated tokens as `string`. */ - generate: (messages: Message[], tools?: LLMTool[]) => Promise; + generate: ( + messages: Message[], + tools?: LLMTool[], + imagePaths?: string[] + ) => Promise; /** - * Returns the number of total tokens from the previous generation.This is a sum of prompt tokens and generated tokens. + * Returns the number of total tokens from the previous generation. This is a sum of prompt tokens and generated tokens. * * @returns The count of prompt and generated tokens. */ @@ -103,15 +126,6 @@ export interface LLMType { */ getPromptTokenCount: () => number; - /** - * Function to add user message to conversation. - * After model responds, `messageHistory` will be updated with both user message and model response. - * - * @param message - The message string to send. - * @returns The model's response as a `string`. - */ - sendMessage: (message: string) => Promise; - /** * Deletes all messages starting with message on `index` position. After deletion `messageHistory` will be updated. * @@ -125,6 +139,43 @@ export interface LLMType { interrupt: () => void; } +/** + * Return type for `useLLM` when `model.capabilities` is provided. + * `sendMessage` accepts a typed `media` object based on declared capabilities. + * @category Types + */ +export interface LLMTypeMultimodal< + C extends readonly LLMCapability[] = readonly LLMCapability[], +> extends LLMTypeBase { + /** + * Function to add user message to conversation. + * Pass a `media` object whose shape is determined by the declared capabilities. 
+ * After model responds, `messageHistory` will be updated. + * + * @param message - The message string to send. + * @param media - Optional media object (e.g. `{ imagePath }` for vision). + * @returns The model's response as a `string`. + */ + sendMessage: (message: string, media?: MediaArg<C>) => Promise<string>; +} + +/** + * Return type for `useLLM` when `model.capabilities` is absent. + * `sendMessage` accepts only text. + * + * @category Types + */ +export interface LLMType extends LLMTypeBase { + /** + * Function to add user message to conversation. + * After model responds, `messageHistory` will be updated. + * + * @param message - The message string to send. + * @returns The model's response as a `string`. + */ + sendMessage: (message: string) => Promise<string>; +} + +/** + * Configuration object for initializing and customizing a Large Language Model (LLM) instance. + * + * @category Types + */ @@ -184,6 +235,11 @@ export type MessageRole = 'user' | 'assistant' | 'system'; export interface Message { role: MessageRole; content: string; + /** + * Optional local file path to media (image, audio, etc.). + * Only valid on `user` messages. 
+ */ + mediaPath?: string; } /** diff --git a/yarn.lock b/yarn.lock index f839c07a6..c2f2e609c 100644 --- a/yarn.lock +++ b/yarn.lock @@ -8721,6 +8721,15 @@ __metadata: languageName: node linkType: hard +"expo-document-picker@npm:~13.0.3": + version: 13.0.3 + resolution: "expo-document-picker@npm:13.0.3" + peerDependencies: + expo: "*" + checksum: 10/a336310e6327d26f36ac19b5867e2ef453dd59a0e30f7b2854c34bc1f874d967f92ced4e0b5fddc2b193ba1d88059033e6f3b076980c060169b191f4af184f90 + languageName: node + linkType: hard + "expo-file-system@npm:^19.0.20, expo-file-system@npm:~19.0.21": version: 19.0.21 resolution: "expo-file-system@npm:19.0.21" @@ -11450,6 +11459,7 @@ __metadata: expo-brightness: "npm:~14.0.8" expo-calendar: "npm:~15.0.8" expo-constants: "npm:~18.0.11" + expo-document-picker: "npm:~13.0.3" expo-font: "npm:~14.0.10" expo-linking: "npm:~8.0.10" expo-router: "npm:~6.0.17" @@ -11461,6 +11471,7 @@ __metadata: react-native-device-info: "npm:^15.0.2" react-native-executorch: "workspace:*" react-native-gesture-handler: "npm:~2.28.0" + react-native-image-picker: "npm:^7.2.2" react-native-loading-spinner-overlay: "npm:^3.0.1" react-native-markdown-display: "npm:^7.0.2" react-native-reanimated: "npm:~4.1.1"