Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
#pragma once

#include <algorithm>
#include <memory>
#include <span>
#include <string>
Expand Down Expand Up @@ -44,6 +45,14 @@ class DurationPredictor : public BaseModel {
// Returns maximum supported amount of input tokens.
size_t getTokensLimit() const;

// Returns the token count of the forward method that would be selected
// for a given input size. E.g., input 37 -> returns 64 (forward_64).
size_t getMethodTokenCount(size_t inputSize) const {
  // Guard: if no forward methods were discovered at construction,
  // calling .back() below would be undefined behavior. Fall back to
  // echoing the requested size so callers get a sane no-op value.
  if (forwardMethods_.empty()) {
    return inputSize;
  }
  // forwardMethods_ is sorted ascending by token capacity, so the first
  // entry whose capacity covers inputSize is the smallest viable method.
  const auto it = std::ranges::find_if(
      forwardMethods_,
      [inputSize](const auto &entry) { return entry.second >= inputSize; });
  // No method is large enough -> clamp to the largest available method.
  return (it != forwardMethods_.end()) ? it->second
                                       : forwardMethods_.back().second;
}

private:
// Helper function - duration scaling
// Performs integer scaling on the durations tensor to ensure the sum of
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,41 +35,50 @@ Kokoro::Kokoro(const std::string &lang, const std::string &taggerDataSource,

context_.inputTokensLimit = durationPredictor_.getTokensLimit();
context_.inputDurationLimit = synthesizer_.getDurationLimit();

// Cap effective token limit to prevent the Synthesizer's attention from
// drifting on longer sequences, which manifests as progressive speed-up
// in the generated audio. Shorter chunks keep timing faithful to the
// Duration Predictor's output.
static constexpr size_t kSafeTokensLimit = 60;
context_.inputTokensLimit =
std::min(context_.inputTokensLimit, kSafeTokensLimit);
}

void Kokoro::loadVoice(const std::string &voiceSource) {
constexpr size_t rows = static_cast<size_t>(constants::kMaxInputTokens);
constexpr size_t cols = static_cast<size_t>(constants::kVoiceRefSize); // 256
const size_t expectedCount = rows * cols;
const std::streamsize expectedBytes =
static_cast<std::streamsize>(expectedCount * sizeof(float));
constexpr size_t cols = static_cast<size_t>(constants::kVoiceRefSize);
constexpr size_t bytesPerRow = cols * sizeof(float);

std::ifstream in(voiceSource, std::ios::binary);
if (!in) {
throw RnExecutorchError(RnExecutorchErrorCode::FileReadFailed,
"[Kokoro::loadSingleVoice]: cannot open file: " +
"[Kokoro::loadVoice]: cannot open file: " +
voiceSource);
}

// Check the file size
// Determine number of rows from file size
in.seekg(0, std::ios::end);
const std::streamsize fileSize = in.tellg();
const auto fileSize = static_cast<size_t>(in.tellg());
in.seekg(0, std::ios::beg);
if (fileSize < expectedBytes) {

if (fileSize < bytesPerRow) {
throw RnExecutorchError(
RnExecutorchErrorCode::FileReadFailed,
"[Kokoro::loadSingleVoice]: file too small: expected at least " +
std::to_string(expectedBytes) + " bytes, got " +
"[Kokoro::loadVoice]: file too small: need at least " +
std::to_string(bytesPerRow) + " bytes for one row, got " +
std::to_string(fileSize));
}

// Read [rows, 1, cols] as contiguous floats directly into voice_
// ([rows][cols])
if (!in.read(reinterpret_cast<char *>(voice_.data()->data()),
expectedBytes)) {
const size_t rows = fileSize / bytesPerRow;
const auto readBytes = static_cast<std::streamsize>(rows * bytesPerRow);

// Resize voice vector to hold all rows from the file
voice_.resize(rows);

if (!in.read(reinterpret_cast<char *>(voice_.data()->data()), readBytes)) {
throw RnExecutorchError(
RnExecutorchErrorCode::FileReadFailed,
"[Kokoro::loadSingleVoice]: failed to read voice weights");
"[Kokoro::loadVoice]: failed to read voice weights");
}
}

Expand Down Expand Up @@ -98,13 +107,10 @@ std::vector<float> Kokoro::generate(std::string text, float speed) {
size_t pauseMs = params::kPauseValues.contains(lastPhoneme)
? params::kPauseValues.at(lastPhoneme)
: params::kDefaultPause;
std::vector<float> pause(pauseMs * constants::kSamplesPerMilisecond, 0.F);

// Add audio part and pause to the main audio vector
audio.insert(audio.end(), std::make_move_iterator(audioPart.begin()),
std::make_move_iterator(audioPart.end()));
audio.insert(audio.end(), std::make_move_iterator(pause.begin()),
std::make_move_iterator(pause.end()));
// Add audio part and silence pause to the main audio vector
audio.insert(audio.end(), audioPart.begin(), audioPart.end());
audio.resize(audio.size() + pauseMs * constants::kSamplesPerMilisecond, 0.F);
}

return audio;
Expand All @@ -118,12 +124,13 @@ void Kokoro::stream(std::string text, float speed,
}

// Build a full callback function
auto nativeCallback = [this, callback](const std::vector<float> &audioVec) {
auto nativeCallback = [this, callback](std::vector<float> audioVec) {
if (this->isStreaming_) {
this->callInvoker_->invokeAsync([callback, audioVec](jsi::Runtime &rt) {
callback->call(rt,
rnexecutorch::jsi_conversion::getJsiValue(audioVec, rt));
});
this->callInvoker_->invokeAsync(
[callback, audioVec = std::move(audioVec)](jsi::Runtime &rt) {
callback->call(
rt, rnexecutorch::jsi_conversion::getJsiValue(audioVec, rt));
});
}
};

Expand Down Expand Up @@ -166,14 +173,12 @@ void Kokoro::stream(std::string text, float speed,
size_t pauseMs = params::kPauseValues.contains(lastPhoneme)
? params::kPauseValues.at(lastPhoneme)
: params::kDefaultPause;
std::vector<float> pause(pauseMs * constants::kSamplesPerMilisecond, 0.F);

// Add pause to the audio vector
audioPart.insert(audioPart.end(), std::make_move_iterator(pause.begin()),
std::make_move_iterator(pause.end()));
// Append silence pause directly
audioPart.resize(audioPart.size() + pauseMs * constants::kSamplesPerMilisecond, 0.F);

// Push the audio right away to the JS side
nativeCallback(audioPart);
nativeCallback(std::move(audioPart));
}

// Mark the end of the streaming process
Expand All @@ -188,41 +193,62 @@ std::vector<float> Kokoro::synthesize(const std::u32string &phonemes,
return {};
}

// Clamp the input to not go beyond number of input token limits
// Note that 2 tokens are always reserved for pre- and post-fix padding,
// so we effectively take at most (maxNoInputTokens_ - 2) tokens.
size_t noTokens = std::clamp(phonemes.size() + 2, constants::kMinInputTokens,
// Clamp token count: phonemes + 2 padding tokens (leading + trailing zero)
size_t dpTokens = std::clamp(phonemes.size() + 2,
constants::kMinInputTokens,
context_.inputTokensLimit);

// Map phonemes to tokens
const auto tokens = utils::tokenize(phonemes, {noTokens});
// Map phonemes to tokens, padded to dpTokens
auto tokens = utils::tokenize(phonemes, {dpTokens});

// Select the appropriate voice vector
size_t voiceID = std::min(phonemes.size() - 1, noTokens);
size_t voiceID = std::min({phonemes.size() - 1, dpTokens - 1,
voice_.size() - 1});
auto &voice = voice_[voiceID];

// Initialize text mask
// Exclude all the paddings apart from first and last one.
size_t realInputLength = std::min(phonemes.size() + 2, noTokens);
std::vector<uint8_t> textMask(noTokens, false);
// Initialize text mask for DP
size_t realInputLength = std::min(phonemes.size() + 2, dpTokens);
std::vector<uint8_t> textMask(dpTokens, false);
std::fill(textMask.begin(), textMask.begin() + realInputLength, true);

// Inference 1 - DurationPredictor
// The resulting duration vector is already scaled at this point
auto [d, indices, effectiveDuration] = durationPredictor_.generate(
std::span(tokens),
std::span(reinterpret_cast<bool *>(textMask.data()), textMask.size()),
std::span(voice).last(constants::kVoiceRefHalfSize), speed);

// --- Synthesizer phase ---
// The Synthesizer may have different method sizes than the DP.
// Pad all inputs to the Synthesizer's selected method size.
size_t synthTokens = synthesizer_.getMethodTokenCount(dpTokens);
size_t dCols = d.sizes().back(); // 640

// Pad tokens and textMask to synthTokens (no-op when synthTokens == dpTokens)
tokens.resize(synthTokens, 0);
textMask.resize(synthTokens, false);

// Pad indices to the maximum duration limit
indices.resize(context_.inputDurationLimit, 0);

// Prepare duration data for Synthesizer.
// When sizes match, pass the DP tensor directly to avoid a 320KB copy.
size_t durSize = synthTokens * dCols;
std::vector<float> durPadded;
float *durPtr;
if (synthTokens == dpTokens) {
durPtr = d.mutable_data_ptr<float>();
} else {
durPadded.resize(durSize, 0.0f);
std::copy_n(d.const_data_ptr<float>(), dpTokens * dCols, durPadded.data());
durPtr = durPadded.data();
}

// Inference 2 - Synthesizer
auto decoding = synthesizer_.generate(
std::span(tokens),
std::span(reinterpret_cast<bool *>(textMask.data()), textMask.size()),
std::span(indices),
// Note that we reduce the size of d tensor to match the initial number of
// input tokens
std::span<float>(d.mutable_data_ptr<float>(),
noTokens * d.sizes().back()),
std::span<float>(durPtr, durSize),
std::span(voice));
auto audioTensor = decoding->at(0).toTensor();

Expand All @@ -233,9 +259,7 @@ std::vector<float> Kokoro::synthesize(const std::u32string &phonemes,
auto croppedAudio =
utils::stripAudio(audio, paddingMs * constants::kSamplesPerMilisecond);

std::vector<float> result(croppedAudio.begin(), croppedAudio.end());

return result;
return {croppedAudio.begin(), croppedAudio.end()};
}

std::size_t Kokoro::getMemoryLowerBound() const noexcept {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -58,12 +58,9 @@ class Kokoro {
DurationPredictor durationPredictor_;
Synthesizer synthesizer_;

// Voice array
// There is a separate voice vector for each of the possible numbers of input
// tokens.
std::array<std::array<float, constants::kVoiceRefSize>,
constants::kMaxInputTokens>
voice_;
// Voice array — dynamically sized to match the voice file.
// Each row is a style vector for a given input token count.
std::vector<std::array<float, constants::kVoiceRefSize>> voice_;

// Extra control variables
bool isStreaming_ = false;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,18 +13,34 @@ Synthesizer::Synthesizer(const std::string &modelSource,
const Context &modelContext,
std::shared_ptr<react::CallInvoker> callInvoker)
: BaseModel(modelSource, callInvoker), context_(modelContext) {
const auto inputTensors = getAllInputShapes("forward");
// Discover all forward methods (forward, forward_8, forward_32, etc.)
auto availableMethods = module_->method_names();
if (availableMethods.ok()) {
const auto &names = *availableMethods;
for (const auto &name : names) {
if (name.rfind("forward", 0) == 0) {
const auto inputTensors = getAllInputShapes(name);
CHECK_SIZE(inputTensors, 5);
CHECK_SIZE(inputTensors[0], 2);
CHECK_SIZE(inputTensors[1], 2);
CHECK_SIZE(inputTensors[2], 1);
size_t inputSize = inputTensors[0][1];
forwardMethods_.emplace_back(name, inputSize);
}
}
std::stable_sort(forwardMethods_.begin(), forwardMethods_.end(),
[](const auto &a, const auto &b) { return a.second < b.second; });
}

// Perform checks to validate model's compatibility with native code
CHECK_SIZE(inputTensors, 5);
CHECK_SIZE(
inputTensors[0],
2); // input tokens must be of shape {1, T}, where T is number of tokens
CHECK_SIZE(
inputTensors[1],
2); // text mask must be of shape {1, T}, where T is number of tokens
CHECK_SIZE(inputTensors[2],
1); // indices must be of shape {D}, where D is a maximum duration
// Fallback: if no methods discovered, validate "forward" directly
if (forwardMethods_.empty()) {
const auto inputTensors = getAllInputShapes("forward");
CHECK_SIZE(inputTensors, 5);
CHECK_SIZE(inputTensors[0], 2);
CHECK_SIZE(inputTensors[1], 2);
CHECK_SIZE(inputTensors[2], 1);
forwardMethods_.emplace_back("forward", inputTensors[0][1]);
}
}

Result<std::vector<EValue>> Synthesizer::generate(std::span<const Token> tokens,
Expand Down Expand Up @@ -54,14 +70,19 @@ Result<std::vector<EValue>> Synthesizer::generate(std::span<const Token> tokens,
auto voiceRefTensor = make_tensor_ptr({1, constants::kVoiceRefSize},
ref_s.data(), ScalarType::Float);

// Execute the appropriate "forward_xyz" method, based on given method name
auto results = forward(
// Select appropriate forward method based on token count
auto it = std::find_if(forwardMethods_.begin(), forwardMethods_.end(),
[noTokens](const auto &entry) { return static_cast<int32_t>(entry.second) >= noTokens; });
std::string selectedMethod = (it != forwardMethods_.end()) ? it->first : forwardMethods_.back().first;

// Execute the selected forward method
auto results = execute(selectedMethod,
{tokensTensor, textMaskTensor, indicesTensor, durTensor, voiceRefTensor});

if (!results.ok()) {
throw RnExecutorchError(
RnExecutorchErrorCode::InvalidModelOutput,
"[Kokoro::Synthesizer] Failed to execute method forward"
"[Kokoro::Synthesizer] Failed to execute method " + selectedMethod +
", error: " +
std::to_string(static_cast<uint32_t>(results.error())));
}
Expand All @@ -72,13 +93,12 @@ Result<std::vector<EValue>> Synthesizer::generate(std::span<const Token> tokens,
}

size_t Synthesizer::getTokensLimit() const {
// Returns tokens input (shape {1, T}) second dim
return getInputShape("forward", 0)[1];
return forwardMethods_.empty() ? 0 : forwardMethods_.back().second;
}

size_t Synthesizer::getDurationLimit() const {
// Returns indices vector first dim (shape {D})
return getInputShape("forward", 2)[0];
if (forwardMethods_.empty()) return 0;
return getInputShape(forwardMethods_.back().first, 2)[0];
}

} // namespace rnexecutorch::models::text_to_speech::kokoro
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
#pragma once

#include <algorithm>
#include <memory>
#include <span>
#include <string>
Expand Down Expand Up @@ -49,7 +50,17 @@ class Synthesizer : public BaseModel {
size_t getTokensLimit() const;
size_t getDurationLimit() const;

// Returns the token count of the forward method that would be selected
// for a given input size. E.g., input 37 -> returns 64 (forward_64).
size_t getMethodTokenCount(size_t inputSize) const {
  // Guard: forwardMethods_ may be empty if method discovery failed;
  // forwardMethods_.back() on an empty vector is undefined behavior.
  // Return the requested size unchanged in that case.
  if (forwardMethods_.empty()) {
    return inputSize;
  }
  // Entries are sorted ascending by token capacity, so the first entry
  // with capacity >= inputSize is the smallest method that fits.
  const auto it = std::ranges::find_if(
      forwardMethods_,
      [inputSize](const auto &entry) { return entry.second >= inputSize; });
  // Input exceeds every method -> clamp to the largest method's capacity.
  return (it != forwardMethods_.end()) ? it->second
                                       : forwardMethods_.back().second;
}

private:
// Forward methods discovered at construction (e.g. forward_8, forward_64, forward_128)
std::vector<std::pair<std::string, size_t>> forwardMethods_;
// Shared model context
// A const reference to singleton in Kokoro.
const Context &context_;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ std::vector<Token> tokenize(const std::u32string &phonemes,
? constants::kVocab.at(p)
: constants::kInvalidToken;
});
auto validSeqEnd = std::partition(
auto validSeqEnd = std::stable_partition(
tokens.begin() + 1, tokens.begin() + effNoTokens + 1,
[](Token t) -> bool { return t != constants::kInvalidToken; });
std::fill(validSeqEnd, tokens.begin() + effNoTokens + 1,
Expand Down
Binary file not shown.
Binary file not shown.