diff --git a/.clang-format b/.clang-format new file mode 100644 index 0000000..d6d6395 --- /dev/null +++ b/.clang-format @@ -0,0 +1,9 @@ +--- +Language: Cpp +BasedOnStyle: Google + +# Headstage differences +ColumnLimit: 100 +SortIncludes: false +DerivePointerAlignment: false +PointerAlignment: Left diff --git a/.gitignore b/.gitignore index 0728111..e58fb04 100644 --- a/.gitignore +++ b/.gitignore @@ -1,5 +1,7 @@ *.pb.cc *.pb.h +*_pb2.py +*.pyc .cache/ .vscode/ @@ -76,4 +78,7 @@ Testing/ *.deb # cached permissions -.synapse_deploy_cache.json \ No newline at end of file +.synapse_deploy_cache.json +app-sdk/ +vcpkg_installed/ +*.desc diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 0000000..f5150be --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,45 @@ +repos: +- repo: https://github.com/pre-commit/pre-commit-hooks + rev: v5.0.0 + hooks: + - id: trailing-whitespace + - id: check-added-large-files + - id: check-executables-have-shebangs + - id: check-json + - id: check-shebang-scripts-are-executable + - id: check-yaml + - id: detect-private-key + - id: end-of-file-fixer + - id: mixed-line-ending + - id: trailing-whitespace + +- repo: https://github.com/pocc/pre-commit-hooks + rev: v1.3.5 + hooks: + - id: clang-format + args: + - --style=file + - -i + - id: cppcheck + name: cppcheck + language: system + entry: cppcheck + args: [ + "--enable=style,performance,warning", + "--check-level=exhaustive", + "--suppress=missingInclude", + "--suppress=missingIncludeSystem", + "--suppress=unusedFunction", + "--suppress=useStlAlgorithm", + "--inconclusive", + "--error-exitcode=1", + "--std=c++20", + "--platform=unix64", + "--language=c++" + ] + files: \.(c|cpp)$ + +- repo: https://github.com/koalaman/shellcheck-precommit + rev: v0.10.0 + hooks: + - id: shellcheck diff --git a/CMakeLists.txt b/CMakeLists.txt index 2bf0699..b6292e9 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -6,7 +6,6 @@ include("cmake/protos.cmake") option(BUILD_FOR_ARM64 "Build for ARM64 architecture" OFF) option(USE_LOCAL_SDK "Use locally built SDK instead of system installation" OFF) set(LOCAL_SDK_PATH "app-sdk" CACHE PATH "Path to local SDK build directory") - if (BUILD_FOR_ARM64 AND NOT CMAKE_HOST_SYSTEM_PROCESSOR MATCHES "aarch64|arm64") message(STATUS "Cross compiling for ARM64") set(CMAKE_SYSTEM_NAME Linux) @@ -27,13 +26,12 @@ if(USE_LOCAL_SDK) message(FATAL_ERROR "USE_LOCAL_SDK is ON but LOCAL_SDK_PATH is not set") endif() message(STATUS "Using local SDK from: ${LOCAL_SDK_PATH}") - + set(SDK_LIB_PATH "${LOCAL_SDK_PATH}") set(SDK_LIB_NAME "synapse-app-sdk") link_directories(${SDK_LIB_PATH}) - + set(SYNAPSE_APP_SDK_LIB ${SDK_LIB_NAME}) - include_directories("${LOCAL_SDK_PATH}/include") else() find_library(SYNAPSE_APP_SDK_LIB NAMES synapse-app-sdk libsynapse-app-sdk.so.0.1.0) @@ -45,9 +43,20 @@ add_executable(synapse-example-app ${CMAKE_CURRENT_SOURCE_DIR}/src/fixed_weight_decoder.cpp ) +# Generate synapse sdk api protobufs generate_protobufs( TARGET synapse-example-app OUT_PROTO_DIR PROTO_OUT_DIR + PROTO_DIRS "${CMAKE_CURRENT_SOURCE_DIR}/external/sciencecorp/synapse-api" +) + +# And also generate your custom app configuration protobuf +generate_protobufs( + TARGET synapse-example-app + OUT_PROTO_DIR APP_PROTO_OUT_DIR + PROTO_DIRS "${CMAKE_CURRENT_SOURCE_DIR}/proto" + GENERATE_PYTHON + PYTHON_OUT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/proto" ) # Be strict about the standard @@ -61,8 +70,9 @@ set_target_properties(synapse-example-app PROPERTIES target_include_directories(synapse-example-app PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}/src - PRIVATE + PRIVATE ${PROTO_OUT_DIR} + ${APP_PROTO_OUT_DIR} ) # Link dependencies diff --git a/Dockerfile b/Dockerfile index 65db546..991b5d8 100644 --- a/Dockerfile +++ b/Dockerfile @@ -98,7 +98,7 @@ RUN git clone https://github.com/microsoft/vcpkg.git "${VCPKG_ROOT}" && \ # copy project-specific ports and manifest before installing COPY vcpkg.json "${VCPKG_ROOT}/vcpkg.json" -COPY external "${VCPKG_ROOT}/external/" +COPY external/sciencecorp/vcpkg "${VCPKG_ROOT}/external/sciencecorp/vcpkg" RUN cd "${VCPKG_ROOT}" && \ ./vcpkg install \ @@ -109,7 +109,7 @@ RUN cd "${VCPKG_ROOT}" && \ # ----------------------------------------------------------------------------- # Install Synapse SDK from internal repository (same steps on both) # ----------------------------------------------------------------------------- -ARG SDK_VERSION=0.3.0 +ARG SDK_VERSION=0.4.4 COPY keys/science-repo-public.asc /usr/share/keyrings/scifi-repo-science-public.asc RUN set -eux; \ apt-get update && apt-get install -y --no-install-recommends ca-certificates; \ @@ -128,4 +128,4 @@ ENV VCPKG_INSTALLED_DIR="${VCPKG_ROOT}/build/host/vcpkg_installed" # Final workspace & entrypoint # ----------------------------------------------------------------------------- WORKDIR /home/workspace -CMD ["/bin/bash"] \ No newline at end of file +CMD ["/bin/bash"] diff --git a/README.md b/README.md index 80476f4..75c5b57 100644 --- a/README.md +++ b/README.md @@ -122,3 +122,16 @@ To listen to joystick output: ```bash python3 ${REPO_ROOT}/client/listen_to_joystick.py --device-ip ``` + +## Development +If you want, it is recommended to install and configure pre-commit to auto lint your files. + +```bash +pip install pre-commit + +pre-commit install + +# Now this will be run when you commit +# However, you can also run it manually like this +pre-commit run +``` diff --git a/client/listen_to_joystick.py b/client/listen_to_joystick.py old mode 100644 new mode 100755 diff --git a/cmake/protos.cmake b/cmake/protos.cmake index cbe1d06..3845e45 100644 --- a/cmake/protos.cmake +++ b/cmake/protos.cmake @@ -1,8 +1,8 @@ function(generate_protobufs) cmake_parse_arguments(PARSE_ARGV 0 "arg" - "" - "TARGET;OUT_PROTO_DIR" - "" + "GENERATE_PYTHON" + "TARGET;OUT_PROTO_DIR;PYTHON_OUT_DIR" + "PROTO_DIRS;PROTO_FILES" ) if(DEFINED arg_UNPARSED_ARGUMENTS) @@ -14,16 +14,56 @@ function(generate_protobufs) if(NOT DEFINED arg_TARGET) message(FATAL_ERROR "TARGET must be specified.") endif() + if(NOT DEFINED arg_PROTO_DIRS AND NOT DEFINED arg_PROTO_FILES) + message(FATAL_ERROR "At least one of PROTO_DIRS or PROTO_FILES must be specified.") + endif() + if(arg_GENERATE_PYTHON AND NOT DEFINED arg_PYTHON_OUT_DIR) + message(FATAL_ERROR "PYTHON_OUT_DIR must be specified when GENERATE_PYTHON is enabled.") + endif() set(PROTO_OUT_DIR ${CMAKE_CURRENT_BINARY_DIR}/include) file(MAKE_DIRECTORY ${PROTO_OUT_DIR}) + # Initialize empty lists for proto include dirs and proto files + set(PROTO_INCLUDE_DIRS "") + set(PROTOS "") + + # Include Synapse API proto directory in the include path get_filename_component(SYNAPSE_PROTO_INCLUDE_DIR ./external/sciencecorp/synapse-api REALPATH) - file(GLOB_RECURSE SYNAPSE_PROTOS ${SYNAPSE_PROTO_INCLUDE_DIR}/*.proto) + list(APPEND PROTO_INCLUDE_DIRS ${SYNAPSE_PROTO_INCLUDE_DIR}) + if(DEFINED SCIFI_PROTO_INCLUDE_DIR) + list(APPEND PROTO_INCLUDE_DIRS ${SCIFI_PROTO_INCLUDE_DIR}) + endif() + + # Add custom proto directories and files if provided + if(DEFINED arg_PROTO_DIRS) + foreach(DIR ${arg_PROTO_DIRS}) + get_filename_component(ABS_DIR ${DIR} REALPATH) + list(APPEND PROTO_INCLUDE_DIRS ${ABS_DIR}) + # If specific files not provided, include all protos in the directory + if(NOT DEFINED arg_PROTO_FILES) + file(GLOB_RECURSE DIR_PROTOS ${ABS_DIR}/*.proto) + list(APPEND PROTOS ${DIR_PROTOS}) + endif() + endforeach() + endif() + + # Add specific proto files if provided + if(DEFINED arg_PROTO_FILES) + foreach(PROTO_FILE ${arg_PROTO_FILES}) + get_filename_component(ABS_PROTO_FILE ${PROTO_FILE} REALPATH) + list(APPEND PROTOS ${ABS_PROTO_FILE}) + + # Also add the directory containing this proto file to include dirs + get_filename_component(PROTO_DIR ${ABS_PROTO_FILE} DIRECTORY) + list(APPEND PROTO_INCLUDE_DIRS ${PROTO_DIR}) + endforeach() + endif() - set(PROTO_INCLUDE_DIRS ${SYNAPSE_PROTO_INCLUDE_DIR} ${SCIFI_PROTO_INCLUDE_DIR}) - set(PROTOS ${SYNAPSE_PROTOS} ${SCIFI_PROTOS}) + # Remove duplicates from PROTO_INCLUDE_DIRS + list(REMOVE_DUPLICATES PROTO_INCLUDE_DIRS) + # Generate C++ protobufs (existing functionality) protobuf_generate( TARGET ${arg_TARGET} LANGUAGE cpp @@ -33,6 +73,52 @@ function(generate_protobufs) OUT_VAR PROTO_SOURCES ) + # Generate Python protobufs if requested + if(arg_GENERATE_PYTHON) + file(MAKE_DIRECTORY ${arg_PYTHON_OUT_DIR}) + + # Generate Python bindings + protobuf_generate( + TARGET ${arg_TARGET} + LANGUAGE python + IMPORT_DIRS ${PROTO_INCLUDE_DIRS} + PROTOS ${PROTOS} + PROTOC_OUT_DIR ${arg_PYTHON_OUT_DIR} + OUT_VAR PYTHON_PROTO_SOURCES + ) + + # Convert PROTO_INCLUDE_DIRS to --proto_path arguments BEFORE using them + set(PROTO_INCLUDE_DIRS_ARGS "") + foreach(INCLUDE_DIR ${PROTO_INCLUDE_DIRS}) + list(APPEND PROTO_INCLUDE_DIRS_ARGS --proto_path=${INCLUDE_DIR}) + endforeach() + + # Generate descriptor sets (.desc files) for runtime loading + set(DESC_OUT_DIR ${arg_PYTHON_OUT_DIR}) + foreach(PROTO_FILE ${PROTOS}) + get_filename_component(PROTO_NAME ${PROTO_FILE} NAME_WE) + set(DESC_FILE ${DESC_OUT_DIR}/${PROTO_NAME}.desc) + + add_custom_command( + OUTPUT ${DESC_FILE} + COMMAND ${Protobuf_PROTOC_EXECUTABLE} + ARGS --descriptor_set_out=${DESC_FILE} + --include_imports + ${PROTO_INCLUDE_DIRS_ARGS} + ${PROTO_FILE} + DEPENDS ${PROTO_FILE} + COMMENT "Generating descriptor set for ${PROTO_NAME}" + VERBATIM + ) + + # Add to target dependencies + add_custom_target(${arg_TARGET}_${PROTO_NAME}_desc DEPENDS ${DESC_FILE}) + add_dependencies(${arg_TARGET} ${arg_TARGET}_${PROTO_NAME}_desc) + endforeach() + + message(STATUS "Python protobufs and descriptors will be generated in: ${arg_PYTHON_OUT_DIR}") + endif() + # NOTE: Uncomment this to generate the gRPC code if we ever need it # protobuf_generate( # TARGET ${arg_TARGET} diff --git a/config/simulator_32ch.json b/config/simulator_32ch.json index 41a3b02..8b0e677 100644 --- a/config/simulator_32ch.json +++ b/config/simulator_32ch.json @@ -4,7 +4,19 @@ "type": "kApplication", "id": 2, "application": { - "name": "synapse-example-app" + "name": "synapse-example-app", + "parameters": { + "@type": "type.googleapis.com/app.ExampleAppConfig", + "low_cutoff_hz": 200.0, + "high_cutoff_hz": 5000.0, + "spike_threshold_uv": 50.0, + "waveform_size": 50, + "refractory_period_us": 1000, + "window_size": 5, + "max_expected_rate": 10.0, + "cursor_channels": [0, 7, 16, 30], + "enable_function_profiling": false + } } }, { diff --git a/external/sciencecorp/synapse-api b/external/sciencecorp/synapse-api index 1b01a82..58b73cd 160000 --- a/external/sciencecorp/synapse-api +++ b/external/sciencecorp/synapse-api @@ -1 +1 @@ -Subproject commit 1b01a8211e04709c22c169827eeb62ffcd10e1ac +Subproject commit 58b73cd76dcb1584066163d133e08f7a6084d328 diff --git a/manifest.json b/manifest.json index c164bbe..374c8e8 100644 --- a/manifest.json +++ b/manifest.json @@ -1,3 +1,5 @@ { - "name": "synapse-example-app" + "name": "synapse-example-app", + "proto_files": ["proto/example_app.proto"], + "device_config_path": "config/simulator_32ch.json" } diff --git a/proto/example_app.proto b/proto/example_app.proto new file mode 100644 index 0000000..8fba323 --- /dev/null +++ b/proto/example_app.proto @@ -0,0 +1,29 @@ +syntax = "proto3"; + +package app; + +message ExampleAppConfig { + float low_cutoff_hz = 1; + float high_cutoff_hz = 2; + + // Threshold in microvolts + float spike_threshold_uv = 3; + + // Samples per waveform + uint32 waveform_size = 4; + + // Refractory Period, in microseconds + uint32 refractory_period_us = 5; + + // Number of bins to use for firing rate estimation + uint32 window_size = 6; + + // max expected rate, used for normalization + float max_expected_rate = 7; + + // Cursor control channels to use, we expect there to be four channels + repeated int32 cursor_channels = 8; + + // Should we be enabling the function profiling and readouts + bool enable_function_profiling = 9; +} diff --git a/src/fixed_weight_decoder.cpp b/src/fixed_weight_decoder.cpp index 79871f3..badd342 100644 --- a/src/fixed_weight_decoder.cpp +++ b/src/fixed_weight_decoder.cpp @@ -1,38 +1,41 @@ #include "fixed_weight_decoder.hpp" +#include // for std::clamp +#include #include +#include // for parse_protobuf_message #include -#include -#include // for std::clamp -#include // for parse_protobuf_message - -namespace app -{ - // Helper function to clamp a value between min and max - template - T clamp(T value, T min, T max) - { - return (value < min) ? min : (value > max) ? max - : value; - } - FixedWeightDecoder::FixedWeightDecoder() : publish_rate_limiter_(kPublishRateSec) {} +namespace app { +// Helper function to clamp a value between min and max +template +T clamp(T value, T min, T max) { + return (value < min) ? min : (value > max) ? max : value; +} - bool FixedWeightDecoder::setup() - { - const uint32_t broadband_node_id = 1; - if (!setup_reader(broadband_node_id)) - { - spdlog::warn("Failed to set up reader for controller"); - return 1; - } +FixedWeightDecoder::FixedWeightDecoder() : publish_rate_limiter_(kPublishRateSec) {} - // Setup our output tap - if (!create_tap("joystick_out")) - { - spdlog::warn("Failed to create tap for joystick out"); - return false; - } +bool FixedWeightDecoder::setup() { + // Make sure our app configuration is valid + if (!get_app_parameters( + [this](const app::ExampleAppConfig& config) { return validate_configuration(config); }, + configuration_)) { + spdlog::error("Failed to validate app parameters from app configuration"); + return false; + } + + const uint32_t broadband_node_id = 1; + if (!setup_reader(broadband_node_id)) { + spdlog::warn("Failed to set up reader for controller"); + return 1; + } + // Setup our output tap + if (!create_tap("joystick_out")) { + spdlog::warn("Failed to create tap for joystick out"); + return false; + } + + if (configuration_.enable_function_profiling()) { // Enable performance monitoring function_profiler_manager_.add("full_loop"); @@ -41,373 +44,365 @@ namespace app spdlog::error("Failed to enable function profile monitoring"); return false; } - return true; } - void FixedWeightDecoder::main() - { - // Store our broadband frames here - const float bin_size_ms = 10; - std::vector broadband_frames; - - while (node_running_) - { - // Receive data from the node you configured - if (!wait_for_frames(broadband_frames, bin_size_ms)) - { - // No frames just go wait again - continue; - } + return true; +} - // Keep track of how long processing takes - start_profile("full_loop"); +void FixedWeightDecoder::main() { + // Store our broadband frames here + const float bin_size_ms = 10; + std::vector broadband_frames; - // You have a set of broadband frames now, you can do whatever you want - // 1. Initialize the filters and our state on the first full set of frames - const auto broadband_frame = broadband_frames.at(0); - if (!filters_initialized_) - { - const size_t channel_count = broadband_frame.frame_data_size(); - const float sample_rate_hz = broadband_frame.sample_rate_hz(); + while (node_running_) { + // Receive data from the node you configured + if (!wait_for_frames(broadband_frames, bin_size_ms)) { + // No frames just go wait again + continue; + } - // Store the sample rate for later use - sample_rate_hz_ = sample_rate_hz; + // Keep track of how long processing takes + start_profile("full_loop"); - initialize_filters(channel_count, sample_rate_hz, bin_size_ms); - initialize_spike_detectors(channel_count); + // You have a set of broadband frames now, you can do whatever you want + // 1. Initialize the filters and our state on the first full set of frames + const auto broadband_frame = broadband_frames.at(0); + if (!filters_initialized_) { + const size_t channel_count = broadband_frame.frame_data_size(); + const float sample_rate_hz = broadband_frame.sample_rate_hz(); - // Move to the next loop after init - continue; - } + // Store the sample rate for later use + sample_rate_hz_ = sample_rate_hz; - // Cleanup any previously detected spikes before processing new frames - cleanup_spike_events(); - - // 2. Filter the received frames - // We have a mapping of channel to filtered data with size of the frames we got - // TODO/NOTE: you could drop out early based on timestamps - std::vector> filtered_channel_data; - filtered_channel_data.resize(broadband_frames.at(0).frame_data_size()); - for (auto &channel_vector : filtered_channel_data) - { - channel_vector.reserve(broadband_frames.size()); - } + initialize_filters(channel_count, sample_rate_hz, bin_size_ms); + initialize_spike_detectors(channel_count); - // Create a vector to count spikes per channel in this batch - std::vector spike_counts(broadband_frames.at(0).frame_data_size(), 0); - - for (const auto &frame : broadband_frames) - { - const auto &frame_data = frame.frame_data(); - const uint64_t frame_timestamp_ns = frame.timestamp_ns(); - - for (int channel_id = 0; channel_id < frame_data.size(); ++channel_id) - { - // TODO: bounds checking - but we might not even want this way of doing things - auto &channel_filter = bandpass_filters_.at(channel_id); - const float filtered_data = channel_filter->filter(frame_data[channel_id]); - filtered_channel_data.at(channel_id).push_back(filtered_data); - - // 3. Detect spikes on the filtered data - if (spike_detectors_initialized_) - { - auto &spike_detector = spike_detectors_.at(channel_id); - - // Pass the filtered data to the spike detector along with the frame timestamp - // The detector handles the rest internally - synapse::SpikeEvent *spike_event = - spike_detector->detect(filtered_data, frame_timestamp_ns, channel_id); - - if (spike_event != nullptr) - { - // Store the detected spike for further processing - detected_spikes_.push_back(spike_event); - - // Increment the spike count for this channel - spike_counts[channel_id]++; - } - } - } - } + // Move to the next loop after init + continue; + } - // Add current binned spike counts to the window - spike_count_window_.push_back(spike_counts); + // Cleanup any previously detected spikes before processing new frames + cleanup_spike_events(); + + // 2. Filter the received frames + // We have a mapping of channel to filtered data with size of the frames we + // got + // TODO/NOTE: you could drop out early based on timestamps + std::vector> filtered_channel_data; + filtered_channel_data.resize(broadband_frames.at(0).frame_data_size()); + for (auto& channel_vector : filtered_channel_data) { + channel_vector.reserve(broadband_frames.size()); + } - // Keep window at fixed size - if (spike_count_window_.size() > window_size_) - { - spike_count_window_.pop_front(); - } + // Create a vector to count spikes per channel in this batch + std::vector spike_counts(broadband_frames.at(0).frame_data_size(), 0); - // Calculate cursor position based on the binned spike counts - float cursor_x = 0.0f; - float cursor_y = 0.0f; - - // Only calculate cursor position if we have enough data in the window - if (spike_count_window_.size() == window_size_) - { - // Calculate firing rates over the window for each cursor control channel - std::array firing_rates = {0.0f, 0.0f, 0.0f, 0.0f}; - - for (int i = 0; i < 4; i++) - { - size_t ch = cursor_channels_[i]; - for (const auto &bin_counts : spike_count_window_) - { - firing_rates[i] += bin_counts[ch]; + for (const auto& frame : broadband_frames) { + const auto& frame_data = frame.frame_data(); + const uint64_t frame_timestamp_ns = frame.timestamp_ns(); + + for (int channel_id = 0; channel_id < frame_data.size(); ++channel_id) { + // TODO: bounds checking - but we might not even want this way of doing + // things + auto& channel_filter = bandpass_filters_.at(channel_id); + const float filtered_data = channel_filter->filter(frame_data[channel_id]); + filtered_channel_data.at(channel_id).push_back(filtered_data); + + // 3. Detect spikes on the filtered data + if (spike_detectors_initialized_) { + auto& spike_detector = spike_detectors_.at(channel_id); + + // Pass the filtered data to the spike detector along with the frame + // timestamp The detector handles the rest internally + synapse::SpikeEvent* spike_event = + spike_detector->detect(filtered_data, frame_timestamp_ns, channel_id); + + if (spike_event != nullptr) { + // Store the detected spike for further processing + detected_spikes_.push_back(spike_event); + + // Increment the spike count for this channel + spike_counts[channel_id]++; } - firing_rates[i] /= window_size_; // Average over window } + } + } - // Calculate x-position based on first channel pair (differential) - cursor_x = firing_rates[1] - firing_rates[0]; // Positive = right, negative = left + // Add current binned spike counts to the window + spike_count_window_.push_back(spike_counts); - // Calculate y-position based on second channel pair (differential) - cursor_y = firing_rates[3] - firing_rates[2]; // Positive = up, negative = down + // Keep window at fixed size + if (spike_count_window_.size() > configuration_.window_size()) { + spike_count_window_.pop_front(); + } - // Normalize to reasonable range (-1 to 1) - cursor_x = clamp(cursor_x / max_expected_rate_, -1.0f, 1.0f); - cursor_y = clamp(cursor_y / max_expected_rate_, -1.0f, 1.0f); - } - else - { - // Not enough data in window yet, use default values - cursor_x = 0.0f; - cursor_y = 0.0f; - } + // Calculate cursor position based on the binned spike counts + float cursor_x = 0.0f; + float cursor_y = 0.0f; - // Create a tensor with the cursor position - synapse::Tensor output_tensor; - const auto tensor_shape = {2}; - output_tensor.mutable_shape()->Add(tensor_shape.begin(), tensor_shape.end()); - output_tensor.set_dtype(synapse::Tensor_DType_DT_FLOAT); - output_tensor.set_endianness(synapse::Tensor_Endianness_TENSOR_LITTLE_ENDIAN); - - // Use the calculated cursor position instead of raw data values - const std::vector tensor_data = {cursor_x, cursor_y}; - - // TODO: this could be a helper - // Get pointers for serialization - const char *data_ptr = reinterpret_cast(tensor_data.data()); - size_t data_size = tensor_data.size() * sizeof(float); - output_tensor.set_data(std::string(data_ptr, data_size)); - - const auto current_time_ns = synapse::get_steady_clock_now(); - output_tensor.set_timestamp_ns(current_time_ns.count()); - - // Then, send off your data using the publisher you configured earlier - // In this demo, we use a ZMQ publisher over tcp - if (publish_rate_limiter_.reset_if_elapsed()) - { - if (publish_tap("joystick_out", output_tensor)) - { - spdlog::info("Published tensor: [x,y]: [{},{}]", tensor_data[0], tensor_data[1]); - } - else - { - spdlog::warn("Failed to publish tensor data"); - } - stop_profile("full_loop"); + // Only calculate cursor position if we have enough data in the window + if (spike_count_window_.size() == configuration_.window_size()) { + // Calculate firing rates over the window for each cursor control channel + std::array firing_rates = {0.0f, 0.0f, 0.0f, 0.0f}; - // We can also get a debug print of the output - print_profile("full_loop"); + for (int i = 0; i < 4; i++) { + size_t ch = cursor_channels_[i]; + for (const auto& bin_counts : spike_count_window_) { + firing_rates[i] += bin_counts[ch]; + } + firing_rates[i] /= configuration_.window_size(); // Average over window } - // You can sleep here if you want, - // We busy wait up at the top if there is no data, so you don't need to here - } - } + // Calculate x-position based on first channel pair (differential) + cursor_x = firing_rates[1] - firing_rates[0]; // Positive = right, negative = left - bool FixedWeightDecoder::wait_for_frames(std::vector &frames, - float bin_size_ms) - { - if (bin_size_ms <= 0) - { - spdlog::warn("invalid bin size of: {}", bin_size_ms); - return false; + // Calculate y-position based on second channel pair (differential) + cursor_y = firing_rates[3] - firing_rates[2]; // Positive = up, negative = down + + // Normalize to reasonable range (-1 to 1) + cursor_x = clamp(cursor_x / configuration_.max_expected_rate(), -1.0f, 1.0f); + cursor_y = clamp(cursor_y / configuration_.max_expected_rate(), -1.0f, 1.0f); + } else { + // Not enough data in window yet, use default values + cursor_x = 0.0f; + cursor_y = 0.0f; } - const uint64_t target_bin_size_ns = static_cast(bin_size_ms * 1e6); - - // Prepare our output vector - frames.clear(); - - // Get the first timestamp - uint64_t first_timestamp_ns = 0; - - // TODO: We should consider having a timeout here - while (node_running_) - { - // In this example, we are listening to BroadbandFrame data - // TODO: the broadband node sends over the messages using multipart - // Figure out why this is the case - auto messages = data_reader_->receive_multipart(); - if (messages.empty()) - { - // Just keep trying - // TODO: We should have better signaling on the read failure - std::this_thread::sleep_for(std::chrono::microseconds(1)); - continue; + // Create a tensor with the cursor position + synapse::Tensor output_tensor; + const auto tensor_shape = {2}; + output_tensor.mutable_shape()->Add(tensor_shape.begin(), tensor_shape.end()); + output_tensor.set_dtype(synapse::Tensor_DType_DT_FLOAT); + output_tensor.set_endianness(synapse::Tensor_Endianness_TENSOR_LITTLE_ENDIAN); + + // Use the calculated cursor position instead of raw data values + output_tensor.set_data(synapse::pack_tensor_data({cursor_x, cursor_y})); + + const auto current_time_ns = synapse::get_steady_clock_now(); + output_tensor.set_timestamp_ns(current_time_ns.count()); + + // Then, send off your data using the publisher you configured earlier + // In this demo, we use a ZMQ publisher over tcp + if (publish_rate_limiter_.reset_if_elapsed()) { + if (publish_tap("joystick_out", output_tensor)) { + spdlog::info("Published tensor: [x,y]: [{},{}]", cursor_x, cursor_y); + } else { + spdlog::warn("Failed to publish tensor data"); } + stop_profile("full_loop"); - // Reserve space for these frames - frames.reserve(frames.size() + messages.size()); - - // Process each received message in this multipart - for (auto &message : messages) - { - // Parse the message into a BroadbandFrame - const auto maybe_frame = - synapse::parse_protobuf_message(std::move(message)); - if (!maybe_frame.has_value()) - { - spdlog::warn("Failed to parse broadband frame"); - // If we have no frames at all, return false - if (frames.empty()) - { - return false; - } - // Otherwise, return what we have so far - return true; - } + // We can also get a debug print of the output + print_profile("full_loop"); + } - const auto &broadband_frame = maybe_frame.value(); + // You can sleep here if you want, + // We busy wait up at the top if there is no data, so you don't need to here + } +} - // Check for dropped frames - const auto dropped_frames = - detect_dropped_frames(last_sequence_number_, broadband_frame.sequence_number()); - if (dropped_frames != 0) - { - spdlog::warn("Dropped: {} frames", dropped_frames); - } - last_sequence_number_ = broadband_frame.sequence_number(); +bool FixedWeightDecoder::wait_for_frames(std::vector& frames, + float bin_size_ms) { + if (bin_size_ms <= 0) { + spdlog::warn("invalid bin size of: {}", bin_size_ms); + return false; + } + + const uint64_t target_bin_size_ns = static_cast(bin_size_ms * 1e6); + + // Prepare our output vector + frames.clear(); + + // Get the first timestamp + uint64_t first_timestamp_ns = 0; + + // TODO: We should consider having a timeout here + while (node_running_) { + // In this example, we are listening to BroadbandFrame data + // TODO: the broadband node sends over the messages using multipart + // Figure out why this is the case + auto messages = data_reader_->receive_multipart(); + if (messages.empty()) { + // Just keep trying + // TODO: We should have better signaling on the read failure + std::this_thread::sleep_for(std::chrono::microseconds(1)); + continue; + } - // Record the first timestamp if this is our first frame - if (frames.empty()) - { - first_timestamp_ns = broadband_frame.timestamp_ns(); + // Reserve space for these frames + frames.reserve(frames.size() + messages.size()); + + // Process each received message in this multipart + for (auto& message : messages) { + // Parse the message into a BroadbandFrame + const auto maybe_frame = + synapse::parse_protobuf_message(std::move(message)); + if (!maybe_frame.has_value()) { + spdlog::warn("Failed to parse broadband frame"); + // If we have no frames at all, return false + if (frames.empty()) { + return false; } + // Otherwise, return what we have so far + return true; + } + + const auto& broadband_frame = maybe_frame.value(); - // Add the frame to our collection - frames.push_back(broadband_frame); + // Check for dropped frames + const auto dropped_frames = + detect_dropped_frames(last_sequence_number_, broadband_frame.sequence_number()); + if (dropped_frames != 0) { + spdlog::warn("Dropped: {} frames", dropped_frames); } + last_sequence_number_ = broadband_frame.sequence_number(); - // TODO: Instead, we could process the entire multipart? - // After processing this multipart, check if we've reached the bin size - if (!frames.empty()) - { - const auto &last_frame = frames.back(); - if (last_frame.timestamp_ns() - first_timestamp_ns >= target_bin_size_ns) - { - // We've collected enough frames to reach the bin size - return true; - } + // Record the first timestamp if this is our first frame + if (frames.empty()) { + first_timestamp_ns = broadband_frame.timestamp_ns(); } + + // Add the frame to our collection + frames.push_back(broadband_frame); } - return false; - } - int FixedWeightDecoder::detect_dropped_frames(const uint64_t last_sequence_number, - const uint64_t current_sequence_number) - { - const auto expected_sequence_number = last_sequence_number + 1; - return (current_sequence_number - expected_sequence_number); + // TODO: Instead, we could process the entire multipart? + // After processing this multipart, check if we've reached the bin size + if (!frames.empty()) { + const auto& last_frame = frames.back(); + if (last_frame.timestamp_ns() - first_timestamp_ns >= target_bin_size_ns) { + // We've collected enough frames to reach the bin size + return true; + } + } } + return false; +} - void FixedWeightDecoder::initialize_spike_detectors(const size_t channel_count) - { - // Create spike detectors for each channel - spike_detectors_.clear(); - spike_detectors_.reserve(channel_count); +int FixedWeightDecoder::detect_dropped_frames(const uint64_t last_sequence_number, + const uint64_t current_sequence_number) { + const auto expected_sequence_number = last_sequence_number + 1; + return (current_sequence_number - expected_sequence_number); +} - for (size_t channel_index = 0; channel_index < channel_count; ++channel_index) - { - auto detector_ptr = synapse::create_threshold_detector(spike_threshold_, waveform_size_, - refractory_period_us_, sample_rate_hz_); +void FixedWeightDecoder::initialize_spike_detectors(const size_t channel_count) { + // Create spike detectors for each channel + spike_detectors_.clear(); + spike_detectors_.reserve(channel_count); - if (detector_ptr == nullptr) - { - spdlog::error("Failed to create spike detector for channel: {}", channel_index); - } - spike_detectors_.push_back(std::move(detector_ptr)); + for (size_t channel_index = 0; channel_index < channel_count; ++channel_index) { + auto detector_ptr = synapse::create_threshold_detector( + configuration_.spike_threshold_uv(), configuration_.waveform_size(), + configuration_.refractory_period_us(), sample_rate_hz_); + + if (detector_ptr == nullptr) { + spdlog::error("Failed to create spike detector for channel: {}", channel_index); } + spike_detectors_.push_back(std::move(detector_ptr)); + } + + spdlog::info("Initialized spike detectors with threshold: {} μV, sample rate: {} Hz", + configuration_.spike_threshold_uv(), sample_rate_hz_); + spike_detectors_initialized_ = true; +} - spdlog::info("Initialized spike detectors with threshold: {} μV, sample rate: {} Hz", - spike_threshold_, sample_rate_hz_); - spike_detectors_initialized_ = true; +void FixedWeightDecoder::cleanup_spike_events() { + // Free memory for all detected spike events + for (auto spike_event : detected_spikes_) { + delete spike_event; } + detected_spikes_.clear(); +} - void FixedWeightDecoder::cleanup_spike_events() - { - // Free memory for all detected spike events - for (auto spike_event : detected_spikes_) - { - delete spike_event; - } - detected_spikes_.clear(); +void FixedWeightDecoder::initialize_filters(const size_t channel_count, const float sample_rate_hz, + const float bin_size_ms) { + if (!initialize_cursor_channels(channel_count)) { + return; } - void FixedWeightDecoder::initialize_filters(const size_t channel_count, - const float sample_rate_hz, - const float bin_size_ms) - { - if (!initialize_cursor_channels(channel_count)) - { - return; + // We have four channels selected, initialize our filters + spdlog::info("Initializing\tsample_rate={} Hz\tchannels={}\tbin_size={} ms", sample_rate_hz, + channel_count, bin_size_ms); + + // Create filters for each channel + bandpass_filters_.clear(); + bandpass_filters_.reserve(channel_count); + for (size_t channel_index = 0; channel_index < channel_count; ++channel_index) { + auto filter_ptr = synapse::create_bandpass_filter( + sample_rate_hz, configuration_.low_cutoff_hz(), configuration_.high_cutoff_hz()); + if (filter_ptr == nullptr) { + spdlog::error("Failed to create filter for channel: {}", channel_index); } + bandpass_filters_.push_back(std::move(filter_ptr)); + } + spdlog::info("Initialized filters"); + filters_initialized_ = true; +} - // We have four channels selected, initialize our filters - spdlog::info("Initializing\tsample_rate={} Hz\tchannels={}\tbin_size={} ms", sample_rate_hz, - channel_count, bin_size_ms); - - // Create filters for each channel - bandpass_filters_.clear(); - bandpass_filters_.reserve(channel_count); - for (size_t channel_index = 0; channel_index < channel_count; ++channel_index) - { - auto filter_ptr = synapse::create_bandpass_filter( - sample_rate_hz, low_cutoff_hz_, high_cutoff_hz_); - if (filter_ptr == nullptr) - { - spdlog::error("Failed to create filter for channel: {}", channel_index); - } - bandpass_filters_.push_back(std::move(filter_ptr)); - } - spdlog::info("Initialized filters"); - filters_initialized_ = true; +bool FixedWeightDecoder::initialize_cursor_channels(const size_t channel_count) { + if (channel_count < 4) { + spdlog::warn("Need at least four channels for joystick control"); + return false; } - bool FixedWeightDecoder::initialize_cursor_channels(const size_t channel_count) - { - if (channel_count < 4) - { - spdlog::warn("Need at least four channels for joystick control"); - return false; - } + // Copy values from the repeated field to the fixed-size array + for (int i = 0; i < 4 && i < configuration_.cursor_channels_size(); i++) { + cursor_channels_[i] = configuration_.cursor_channels(i); + } + + std::stringstream ss; + ss << "Using ["; + for (const auto& channel : cursor_channels_) { + ss << channel << ","; + } - // Select four random channels - // std::vector all_channels(channel_count); - // std::iota(all_channels.begin(), all_channels.end(), 0); + ss << "] for cursor control"; + spdlog::info("{}", ss.str()); + return true; +} - // // Randomly sample 4 channels - // std::random_device rd; - // std::mt19937 gen(rd()); - // std::sample(all_channels.begin(), all_channels.end(), cursor_channels_.begin(), 4, gen); +bool FixedWeightDecoder::validate_configuration(const app::ExampleAppConfig& config) { + // FIlters should be above zero, high cutoff should be above low cutoff + if (config.low_cutoff_hz() <= 0 || config.high_cutoff_hz() <= 0 || + config.low_cutoff_hz() >= config.high_cutoff_hz()) { + spdlog::error("Invalid filter configuration: low_cutoff={} Hz, high_cutoff={} Hz", + config.low_cutoff_hz(), config.high_cutoff_hz()); + return false; + } - std::stringstream ss; - ss << "Using ["; - for (const auto &channel : cursor_channels_) - { - ss << channel << ","; - } + // Spike threshold should be above zero + if (config.spike_threshold_uv() <= 0) { + spdlog::error("Invalid spike threshold: {}", config.spike_threshold_uv()); + return false; + } - ss << "] for cursor control"; - spdlog::info("{}", ss.str()); - return true; + // Refractory period should be above zero + if (config.refractory_period_us() <= 0) { + spdlog::error("Invalid refractory period: {}", config.refractory_period_us()); + return false; + } + + // Window size should be above zero + if (config.window_size() <= 0) { + spdlog::error("Invalid window size: {}", config.window_size()); + return false; + } + + // Max expected rate should be above zero + if (config.max_expected_rate() <= 0) { + spdlog::error("Invalid max expected rate: {}", config.max_expected_rate()); + return false; + } + + // Should only have 4 cursor channels + if (config.cursor_channels_size() != 4) { + spdlog::error("Invalid number of cursor channels: {}, expected 4", + config.cursor_channels_size()); + return false; } -} // namespace app -int main(const int, const char **) -{ - return synapse::Entrypoint(); + return true; } + +} // namespace app + +int main(const int, const char**) { return synapse::Entrypoint(); } diff --git a/src/fixed_weight_decoder.hpp b/src/fixed_weight_decoder.hpp index 3db9c5d..bea7d5a 100644 --- a/src/fixed_weight_decoder.hpp +++ b/src/fixed_weight_decoder.hpp @@ -1,90 +1,89 @@ #pragma once #include -#include #include +#include #include -#include -#include #include #include +#include +#include #include "api/datatype.pb.h" #include "api/nodes/broadband_source.pb.h" -namespace app -{ - // 10 hz - constexpr auto kPublishRateSec = 1.0 / 10.0; - class FixedWeightDecoder : public synapse::App - { - public: - FixedWeightDecoder(); - - virtual bool setup() override; - - protected: - virtual void main() override; - - private: - // Use this to detect if there is frame drops - uint64_t last_sequence_number_ = 0; - - // A timer to provide a consistent publishing cadence for joystick commands - synapse::Timer publish_rate_limiter_; - - // We want to filter the incoming broadband data, so do so here - std::atomic filters_initialized_{false}; - - // TODO(gilbert): This should probably be configurable? - const float low_cutoff_hz_ = 200.0; - const float high_cutoff_hz_ = 5000.0; - static constexpr int kSpectralFilterOrder = 2; - std::vector> bandpass_filters_; - - // Spike detection configuration and detectors - std::atomic spike_detectors_initialized_{false}; - const float spike_threshold_ = 50.0; // Threshold in microvolts - const uint32_t waveform_size_ = 50; // Total samples per waveform - const uint64_t refractory_period_us_ = 1000; // 1ms refractory period - float sample_rate_hz_ = 30000.0; // Will be updated during initialization - std::vector> spike_detectors_; - - // Collection of detected spikes - std::vector detected_spikes_; - - // Spike binning and cursor control parameters - static constexpr int window_size_ = 5; // Number of bins to use for firing rate estimation - static constexpr float max_expected_rate_ = 10.0f; // For normalization - std::deque> - spike_count_window_; // Window buffer to store binned spike counts - - // We will select 4 channels randomly for cursor control - std::array cursor_channels_ = {0, 7, 16, 30}; - - // Waits until a set of broadband frames are read from the node - // Returns false if there was an error reading - bool wait_for_frames(std::vector &frames, const float bin_size_ms); - - // If not zero, we dropped some frames, determine what to do - int detect_dropped_frames(const uint64_t last_sequence_number, - const uint64_t current_sequence_number); - - // Randomly select channels to use for cursor control - bool initialize_cursor_channels(const size_t channel_count); - - // Before starting, set up our filters. - // We can use the first broadband frame to do this initialization - void initialize_filters(const size_t channel_count, const float sample_rate_hz, - const float bin_size_ms); - - // Initialize spike detectors for each channel - void initialize_spike_detectors(const size_t channel_count); - - // Clean up any allocated spike events - void cleanup_spike_events(); - - // Calculate cursor position from spike counts - std::pair calculate_cursor_position(const std::vector &spike_counts); - }; -} // namespace app \ No newline at end of file +#include "example_app.pb.h" + +namespace app { +// 10 hz +constexpr auto kPublishRateSec = 1.0 / 10.0; +class FixedWeightDecoder : public synapse::App { + public: + FixedWeightDecoder(); + + virtual bool setup() override; + + protected: + virtual void main() override; + + private: + // Use this to detect if there is frame drops + uint64_t last_sequence_number_ = 0; + + // A timer to provide a consistent publishing cadence for joystick commands + synapse::Timer publish_rate_limiter_; + + // We want to filter the incoming broadband data, so do so here + std::atomic filters_initialized_{false}; + + // App parameters + app::ExampleAppConfig configuration_; + + // Filter configuration + static constexpr int kSpectralFilterOrder = 2; + std::vector> bandpass_filters_; + + // Spike detection configuration and detectors + std::atomic spike_detectors_initialized_{false}; + float sample_rate_hz_ = 30000.0; // Will be updated during initialization + std::vector> spike_detectors_; + + // Collection of detected spikes + std::vector detected_spikes_; + + // Spike binning and cursor control parameters + std::deque> + spike_count_window_; // Window buffer to store binned spike counts + + // We will select 4 channels randomly for cursor control + // Default to a random selection + std::array cursor_channels_ = {0, 7, 16, 30}; + + // Waits until a set of broadband frames are read from the node + // Returns false if there was an error reading + bool wait_for_frames(std::vector& frames, const float bin_size_ms); + + // If not zero, we dropped some frames, determine what to do + int detect_dropped_frames(const uint64_t last_sequence_number, + const uint64_t current_sequence_number); + + // Randomly select channels to use for cursor control + bool initialize_cursor_channels(const size_t channel_count); + + // Before starting, set up our filters. + // We can use the first broadband frame to do this initialization + void initialize_filters(const size_t channel_count, const float sample_rate_hz, + const float bin_size_ms); + + // Initialize spike detectors for each channel + void initialize_spike_detectors(const size_t channel_count); + + // Clean up any allocated spike events + void cleanup_spike_events(); + + // Calculate cursor position from spike counts + std::pair calculate_cursor_position(const std::vector& spike_counts); + + bool validate_configuration(const app::ExampleAppConfig& config); +}; +} // namespace app