From 04e39b35187af64cf2aab9bfc25f707326bfbb01 Mon Sep 17 00:00:00 2001 From: donutsoft Date: Fri, 15 May 2026 19:20:41 -0700 Subject: [PATCH] Add packet filtering support to Meshcore --- build_as_lib.py | 1 + docs/cli_commands.md | 119 ++++++++ examples/simple_repeater/MyMesh.cpp | 10 + examples/simple_repeater/MyMesh.h | 3 + platformio.ini | 1 + src/GroupChannel.cpp | 37 +++ src/GroupChannel.h | 17 ++ src/Mesh.h | 7 +- src/helpers/BaseChatMesh.cpp | 10 +- src/helpers/PacketFilter.cpp | 281 ++++++++++++++++++ src/helpers/PacketFilter.h | 34 +++ src/helpers/SimpleBloomFilter.cpp | 140 +++++++++ src/helpers/SimpleBloomFilter.h | 37 +++ src/helpers/SimplePatternMatcher.cpp | 33 ++ src/helpers/SimplePatternMatcher.h | 9 + .../AdvertRateLimitClassifier.cpp | 182 ++++++++++++ .../packet_filter/AdvertRateLimitClassifier.h | 67 +++++ .../BasePacketFilterClassifier.h | 69 +++++ .../ChannelMessageClassifier.cpp | 69 +++++ .../packet_filter/ChannelMessageClassifier.h | 30 ++ src/helpers/packet_filter/PacketFilterRule.h | 96 ++++++ 21 files changed, 1238 insertions(+), 14 deletions(-) create mode 100644 src/GroupChannel.cpp create mode 100644 src/GroupChannel.h create mode 100644 src/helpers/PacketFilter.cpp create mode 100644 src/helpers/PacketFilter.h create mode 100644 src/helpers/SimpleBloomFilter.cpp create mode 100644 src/helpers/SimpleBloomFilter.h create mode 100644 src/helpers/SimplePatternMatcher.cpp create mode 100644 src/helpers/SimplePatternMatcher.h create mode 100644 src/helpers/packet_filter/AdvertRateLimitClassifier.cpp create mode 100644 src/helpers/packet_filter/AdvertRateLimitClassifier.h create mode 100644 src/helpers/packet_filter/BasePacketFilterClassifier.h create mode 100644 src/helpers/packet_filter/ChannelMessageClassifier.cpp create mode 100644 src/helpers/packet_filter/ChannelMessageClassifier.h create mode 100644 src/helpers/packet_filter/PacketFilterRule.h diff --git a/build_as_lib.py b/build_as_lib.py index d8e95378eb..d6cb1d4dde 100644 --- a/build_as_lib.py +++ b/build_as_lib.py @@ -6,6 +6,7 @@ src_filter = [ '+<*.cpp>', '+', + '+', '+', '+', '+', diff --git a/docs/cli_commands.md b/docs/cli_commands.md index 99dced3658..8de1481241 100644 --- a/docs/cli_commands.md +++ b/docs/cli_commands.md @@ -7,6 +7,7 @@ This document provides an overview of CLI commands that can be sent to MeshCore - [Operational](#operational) - [Neighbors](#neighbors-repeater-only) - [Statistics](#statistics) +- [Packet Filtering](#packet-filtering-repeater-only) - [Logging](#logging) - [Information](#info) - [Configuration](#configuration) @@ -144,6 +145,124 @@ This document provides an overview of CLI commands that can be sent to MeshCore --- +## Packet Filtering (Repeater Only) + +Packet filters let a repeater decline to forward selected flood packets. Rules are evaluated before forwarding; packets generated by the local node are not affected. Rules are persisted on the node and reloaded at boot. + +The maximum number of active rules is set by `PACKET_FILTER_MAX_RULES` at build time. The default is `20`. + +### Show packet filter help +**Usage:** `block help` + +**Response:** +``` +block list [page] +block del +block clear +block stats +block advert help +block channelmessage help +``` + +--- + +### List packet filter rules +**Usage:** `block list [page]` + +**Parameters:** +- `page`: Optional page number. Defaults to `1`. + +**Response:** One numbered rule per line, or `-none-` when no rules are active. + +If there are more rules than fit in one response packet, the output ends with `..N`, where `N` is the next page number. + +**Examples:** +``` +block list +block list 2 +``` + +Example paged response: +``` +1: advert repeater 60 2 +2: channelmessage #test *spam* +..2 +``` + +--- + +### Delete a packet filter rule +**Usage:** `block del ` + +**Parameters:** +- `index`: Rule index shown by `block list`. + +Deletes the selected rule and updates the persisted rules file. The next added rule reuses the first inactive rule slot. + +**Example:** +``` +block del 1 +``` + +--- + +### Clear packet filter rules +**Usage:** `block clear` + +Clears all active rules and updates the persisted rules file. + +--- + +### Show packet filter statistics +**Usage:** `block stats` + +Shows the number of packets blocked since boot or since the filter statistics were reset by firmware. + +--- + +### Rate-limit forwarded adverts +**Usage:** `block advert [max]` + +**Parameters:** +- `advert_type`: `companion`, `repeater`, `room`, or a numeric advert type from `1` to `15` +- `period_minutes`: Length of the rate-limit window in minutes (`1`-`65535`) +- `max`: Optional maximum adverts per sender public key in each window. Defaults to `1`. + +**Example:** +``` +block advert repeater 60 2 +``` + +Allows each repeater to have at most two forwarded adverts per 60-minute window. + +**Notes:** +- This rule only matches node advertisement packets with the selected advert type. +- The limiter uses fixed windows, not a rolling window. For a 60-minute period, the per-sender count resets when the next 60-minute bucket starts. A sender near a bucket boundary can therefore have up to `max` adverts forwarded at the end of one period and another `max` at the start of the next period. +- Sender public keys are tracked with salted Bloom filters. If a filter saturates, the rule temporarily fails open for the rest of the window instead of blocking packets incorrectly. + +--- + +### Block public channel messages by pattern +**Usage:** `block channelmessage ""` + +**Parameters:** +- `channel_name`: A public channel name whose secret can be derived from the name. Examples: `Public`, `#test`. +- `pattern`: Text pattern to block. `*` matches zero or more characters and `?` matches one character. Matching is case-sensitive and must match the complete channel message text. + +**Examples:** +``` +block channelmessage #test "*spam*" +block channelmessage Public "badword*" +block channelmessage #test "BadUser: *" +block channelmessage #test "*: *BadWord*" +``` + +**Notes:** +- This rule decrypts only public channel messages whose channel secret can be derived from the channel name. It does not block private channels with random secrets. +- Channel message text includes the sender prefix used by MeshCore group text messages, normally `: `. To block all messages from a specific displayed sender name, match that prefix explicitly, such as `BadUser: *`. To block messages containing a word regardless of sender, include the separator in the pattern, such as `*: *BadWord*`. + +--- + ## Logging ### Begin capture of rx log to node storage diff --git a/examples/simple_repeater/MyMesh.cpp b/examples/simple_repeater/MyMesh.cpp index 666f79fc5c..0c00f828f5 100644 --- a/examples/simple_repeater/MyMesh.cpp +++ b/examples/simple_repeater/MyMesh.cpp @@ -447,6 +447,10 @@ bool MyMesh::allowPacketForward(const mesh::Packet *packet) { return false; } } + if (!packet_filter.shouldRepeatPacket(packet)) { + MESH_DEBUG_PRINTLN("allowPacketForward: packet blocked by packet filter"); + return false; + } return true; } @@ -845,6 +849,7 @@ MyMesh::MyMesh(mesh::MainBoard &board, mesh::Radio &radio, mesh::MillisecondCloc mesh::RTCClock &rtc, mesh::MeshTables &tables) : mesh::Mesh(radio, ms, rng, rtc, *new StaticPoolPacketManager(32), tables), region_map(key_store), temp_map(key_store), + packet_filter(&rng), _cli(board, rtc, sensors, region_map, acl, &_prefs, this), telemetry(MAX_PACKET_PAYLOAD - 4), discover_limiter(4, 120), // max 4 every 2 minutes @@ -926,6 +931,7 @@ void MyMesh::begin(FILESYSTEM *fs) { acl.load(_fs, self_id); // TODO: key_store.begin(); region_map.load(_fs); + packet_filter.load(_fs, PACKET_FILTER_RULES_FILE); // establish default-scope { @@ -1163,6 +1169,7 @@ void MyMesh::clearStats() { radio_driver.resetStats(); resetStats(); ((SimpleMeshTables *)getTables())->resetStats(); + packet_filter.resetStats(); } void MyMesh::handleCommand(uint32_t sender_timestamp, char *command, char *reply) { @@ -1251,6 +1258,9 @@ void MyMesh::handleCommand(uint32_t sender_timestamp, char *command, char *reply sendNodeDiscoverReq(); strcpy(reply, "OK - Discover sent"); } + } else if (memcmp(command, "block", 5) == 0 && (command[5] == 0 || command[5] == ' ')) { + // Packet filtering coommands. + packet_filter.handleBlockCommand(_fs, PACKET_FILTER_RULES_FILE, command, reply, 160); } else{ _cli.handleCommand(sender_timestamp, command, reply); // common CLI commands } diff --git a/examples/simple_repeater/MyMesh.h b/examples/simple_repeater/MyMesh.h index 8ed0317e69..10bbac5153 100644 --- a/examples/simple_repeater/MyMesh.h +++ b/examples/simple_repeater/MyMesh.h @@ -33,6 +33,7 @@ #include #include #include +#include #include "RateLimiter.h" #ifdef WITH_BRIDGE @@ -79,6 +80,7 @@ struct NeighbourInfo { #define FIRMWARE_ROLE "repeater" #define PACKET_LOG_FILE "/packet_log" +#define PACKET_FILTER_RULES_FILE "/packet_filter" class MyMesh : public mesh::Mesh, public CommonCLICallbacks { FILESYSTEM* _fs; @@ -95,6 +97,7 @@ class MyMesh : public mesh::Mesh, public CommonCLICallbacks { uint8_t reply_path_hash_size; TransportKeyStore key_store; RegionMap region_map, temp_map; + PacketFilter packet_filter; RegionEntry* load_stack[8]; RegionEntry* recv_pkt_region; TransportKey default_scope; diff --git a/platformio.ini b/platformio.ini index 864e5e1ffe..b33c217556 100644 --- a/platformio.ini +++ b/platformio.ini @@ -48,6 +48,7 @@ build_flags = -w -DNDEBUG -DRADIOLIB_STATIC_ONLY=1 -DRADIOLIB_GODMODE=1 build_src_filter = +<*.cpp> + + + + + + diff --git a/src/GroupChannel.cpp b/src/GroupChannel.cpp new file mode 100644 index 0000000000..c898da75b1 --- /dev/null +++ b/src/GroupChannel.cpp @@ -0,0 +1,37 @@ +#include "GroupChannel.h" +#include "Utils.h" + +#define DEFAULT_PUBLIC_CHANNEL_SECRET_HEX "8b3387e9c5cdea6ac9e5edbaa115cd72" + +namespace mesh { + +bool GroupChannel::deriveHash(int secret_len) { + if (secret_len != 16 && secret_len != 32) return false; + + Utils::sha256(hash, PATH_HASH_SIZE, secret, secret_len); + return true; +} + +void GroupChannel::deriveHash() { + static uint8_t zeroes[] = { 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 }; + if (memcmp(&secret[16], zeroes, 16) == 0) { + deriveHash(16); + } else { + deriveHash(32); + } +} + +bool GroupChannel::derivePublicSecret(const char* name, uint8_t* dest_secret) { + if (name == NULL || dest_secret == NULL) return false; + + memset(dest_secret, 0, PUB_KEY_SIZE); + + if (strcmp(name, "Public") == 0) { + return Utils::fromHex(dest_secret, 16, DEFAULT_PUBLIC_CHANNEL_SECRET_HEX); + } + + Utils::sha256(dest_secret, 16, (const uint8_t*)name, strlen(name)); + return true; +} + +} diff --git a/src/GroupChannel.h b/src/GroupChannel.h new file mode 100644 index 0000000000..1a60793231 --- /dev/null +++ b/src/GroupChannel.h @@ -0,0 +1,17 @@ +#pragma once + +#include +#include + +namespace mesh { + +class GroupChannel { +public: + uint8_t hash[PATH_HASH_SIZE]; + uint8_t secret[PUB_KEY_SIZE]; + + void deriveHash(); + bool deriveHash(int secret_len); + static bool derivePublicSecret(const char* name, uint8_t* dest_secret); +}; +} diff --git a/src/Mesh.h b/src/Mesh.h index f9f8786320..e7cc269a5f 100644 --- a/src/Mesh.h +++ b/src/Mesh.h @@ -1,15 +1,10 @@ #pragma once #include +#include namespace mesh { -class GroupChannel { -public: - uint8_t hash[PATH_HASH_SIZE]; - uint8_t secret[PUB_KEY_SIZE]; -}; - /** * An abstraction of the data tables needed to be maintained */ diff --git a/src/helpers/BaseChatMesh.cpp b/src/helpers/BaseChatMesh.cpp index 7ddc461d29..3f64059bac 100644 --- a/src/helpers/BaseChatMesh.cpp +++ b/src/helpers/BaseChatMesh.cpp @@ -860,7 +860,7 @@ ChannelDetails* BaseChatMesh::addChannel(const char* name, const char* psk_base6 memset(dest->channel.secret, 0, sizeof(dest->channel.secret)); int len = decode_base64((unsigned char *) psk_base64, strlen(psk_base64), dest->channel.secret); if (len == 32 || len == 16) { - mesh::Utils::sha256(dest->channel.hash, sizeof(dest->channel.hash), dest->channel.secret, len); + dest->channel.deriveHash(len); StrHelper::strncpy(dest->name, name, sizeof(dest->name)); num_channels++; return dest; @@ -876,15 +876,9 @@ bool BaseChatMesh::getChannel(int idx, ChannelDetails& dest) { return false; } bool BaseChatMesh::setChannel(int idx, const ChannelDetails& src) { - static uint8_t zeroes[] = { 0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0 }; - if (idx >= 0 && idx < MAX_GROUP_CHANNELS) { channels[idx] = src; - if (memcmp(&src.channel.secret[16], zeroes, 16) == 0) { - mesh::Utils::sha256(channels[idx].channel.hash, sizeof(channels[idx].channel.hash), src.channel.secret, 16); // 128-bit key - } else { - mesh::Utils::sha256(channels[idx].channel.hash, sizeof(channels[idx].channel.hash), src.channel.secret, 32); // 256-bit key - } + channels[idx].channel.deriveHash(); return true; } return false; diff --git a/src/helpers/PacketFilter.cpp b/src/helpers/PacketFilter.cpp new file mode 100644 index 0000000000..e757082b90 --- /dev/null +++ b/src/helpers/PacketFilter.cpp @@ -0,0 +1,281 @@ +#include "PacketFilter.h" + +#include + +static File openPacketFilterRead(FILESYSTEM* fs, const char* path) { +#if defined(RP2040_PLATFORM) + return fs->open(path, "r"); +#else + return fs->open(path); +#endif +} + +static File openPacketFilterWrite(FILESYSTEM* fs, const char* path) { +#if defined(NRF52_PLATFORM) || defined(STM32_PLATFORM) + fs->remove(path); + return fs->open(path, FILE_O_WRITE); +#elif defined(RP2040_PLATFORM) + return fs->open(path, "w"); +#else + return fs->open(path, "w", true); +#endif +} + +static bool appendRuleLine(char* reply, size_t reply_len, size_t* used, const char* line) { + int written = snprintf(&reply[*used], reply_len - *used, "%s%s", *used == 0 ? "" : "\n", line); + if (written < 0) return false; + if ((size_t)written >= reply_len - *used) { + reply[reply_len - 1] = 0; + return false; + } + *used += written; + return true; +} + +static bool canAppendPagedLine(size_t used, size_t reply_len, const char* line) { + const size_t continuation_reserve = 8; // "\n..255" plus terminator + size_t line_len = strlen(line); + size_t needed = line_len + (used == 0 ? 0 : 1); + + return reply_len > continuation_reserve && needed < reply_len - continuation_reserve - used; +} + +static void consumePagedLine(size_t* used, const char* line) { + *used += strlen(line) + (*used == 0 ? 0 : 1); +} + +static bool appendContinuation(char* reply, size_t reply_len, size_t* used, uint8_t next_page) { + char line[8]; + snprintf(line, sizeof(line), "..%u", next_page); + return appendRuleLine(reply, reply_len, used, line); +} + +PacketFilter::PacketFilter(mesh::RNG* rng) { + _rule_count = 0; + _filtered_count = 0; + _rng = rng; +} + +bool PacketFilter::shouldRepeatPacket(const mesh::Packet* packet) { + for (uint8_t i = 0; i < PACKET_FILTER_MAX_RULES; i++) { + if (!_rules[i].shouldRepeatPacket(packet)) { + _filtered_count++; + return false; + } + } + return true; +} + +bool PacketFilter::block(const char* classifier_name, char* args) { + int empty_rule_index = -1; + for (uint8_t i = 0; i < PACKET_FILTER_MAX_RULES; i++) { + if (!_rules[i].isActive()) { + empty_rule_index = i; + break; + } + } + if (empty_rule_index < 0) return false; + + PacketFilterRule& new_rule = _rules[empty_rule_index]; + if (!new_rule.configure(classifier_name, args, nullptr, _rng)) return false; + + char new_rule_text[160]; + char existing_rule_text[160]; + new_rule.formatRule(new_rule_text, sizeof(new_rule_text)); + + for (uint8_t i = 0; i < PACKET_FILTER_MAX_RULES; i++) { + if (i == empty_rule_index || !_rules[i].isActive()) continue; + + _rules[i].formatRule(existing_rule_text, sizeof(existing_rule_text)); + if (strcmp(new_rule_text, existing_rule_text) == 0) { + new_rule.clear(); + return false; + } + } + + _rule_count++; + return true; +} + +bool PacketFilter::deleteRule(uint8_t rule_index) { + if (rule_index == 0 || rule_index > PACKET_FILTER_MAX_RULES) return false; + + PacketFilterRule& rule = _rules[rule_index - 1]; + if (!rule.isActive()) return false; + + rule.clear(); + if (_rule_count > 0) _rule_count--; + return true; +} + +void PacketFilter::clear() { + for (uint8_t i = 0; i < PACKET_FILTER_MAX_RULES; i++) { + _rules[i].clear(); + } + _rule_count = 0; +} + +void PacketFilter::formatRules(char* reply, size_t reply_len, uint8_t page) const { + if (reply_len == 0) return; + reply[0] = 0; + + if (page == 0) page = 1; + + size_t used = 0; + uint8_t current_page = 1; + bool has_active_rules = false; + bool has_page_rules = false; + char rule_text[128]; + char line[144]; + + for (uint8_t i = 0; i < PACKET_FILTER_MAX_RULES; i++) { + if (!_rules[i].isActive()) continue; + has_active_rules = true; + + _rules[i].formatRule(rule_text, sizeof(rule_text)); + snprintf(line, sizeof(line), "%u: %s", (uint32_t)i + 1, rule_text); + + while (!canAppendPagedLine(used, reply_len, line)) { + if (current_page == page) { + appendContinuation(reply, reply_len, &used, current_page + 1); + return; + } + + current_page++; + used = 0; + } + + if (current_page == page) { + if (!appendRuleLine(reply, reply_len, &used, line)) return; + has_page_rules = true; + } else { + consumePagedLine(&used, line); + } + } + + if (!has_active_rules) { + strncpy(reply, "-none-", reply_len - 1); + reply[reply_len - 1] = 0; + } else if (!has_page_rules) { + snprintf(reply, reply_len, "Err - no such page"); + } +} + +static void formatBlockHelp(char* reply, size_t reply_len) { + snprintf(reply, reply_len, "block list [page]\nblock del \nblock clear\nblock stats\nblock advert help\nblock channelmessage help"); +} + +static bool addBlockRule(void* context, const char* classifier_name, char* args) { + if (context == NULL) return false; + return ((PacketFilter*)context)->block(classifier_name, args); +} + +void PacketFilter::handleBlockCommand(FILESYSTEM* fs, const char* path, char* command, char* reply, size_t reply_len) { + if (reply_len == 0) return; + + char* args = command + 5; + char* rule = BasePacketFilterClassifier::nextToken(args); + + bool should_save = false; + if (rule == NULL || strcmp(rule, "help") == 0) { + formatBlockHelp(reply, reply_len); + return; + } else if (strcmp(rule, "list") == 0) { + char* page_token = BasePacketFilterClassifier::nextToken(args); + char* extra = BasePacketFilterClassifier::nextToken(args); + unsigned long page = 1; + if (extra != NULL || (page_token != NULL && !BasePacketFilterClassifier::parseUnsigned(page_token, 1, 255, &page))) { + snprintf(reply, reply_len, "Usage: block list [page]"); + return; + } + formatRules(reply, reply_len, (uint8_t)page); + return; + } else if (strcmp(rule, "del") == 0) { + char* index_token = BasePacketFilterClassifier::nextToken(args); + char* extra = BasePacketFilterClassifier::nextToken(args); + unsigned long rule_index = 0; + if (extra != NULL || !BasePacketFilterClassifier::parseUnsigned(index_token, 1, PACKET_FILTER_MAX_RULES, &rule_index)) { + snprintf(reply, reply_len, "Usage: block del "); + return; + } + if (deleteRule((uint8_t)rule_index)) { + snprintf(reply, reply_len, "OK - packet filter rule deleted"); + should_save = true; + } else { + snprintf(reply, reply_len, "Err - packet filter rule not found"); + } + } else if (strcmp(rule, "clear") == 0) { + clear(); + snprintf(reply, reply_len, "OK - packet filter rules cleared"); + should_save = true; + } else if (strcmp(rule, "stats") == 0) { + snprintf(reply, reply_len, "Filtered packets: %u", _filtered_count); + } else if (strcmp(rule, "channelmessage") == 0) { + should_save = ChannelMessageClassifier::handleBlockCommand(args, reply, reply_len, addBlockRule, this); + } else if (strcmp(rule, "advert") == 0) { + should_save = AdvertRateLimitClassifier::handleBlockCommand(args, reply, reply_len, addBlockRule, this); + } else { + snprintf(reply, reply_len, "Err - unknown block rule"); + } + + if (should_save && !save(fs, path)) { + snprintf(reply, reply_len, "Err - packet filter rule save failed"); + } +} + +static void loadPacketFilterRuleLine(PacketFilter* packet_filter, char* line) { + char* args = line; + char* classifier_name = BasePacketFilterClassifier::nextToken(args); + if (classifier_name == NULL) return; + + packet_filter->block(classifier_name, args); +} + +bool PacketFilter::load(FILESYSTEM* fs, const char* path) { + clear(); + if (fs == NULL || path == NULL || !fs->exists(path)) return true; + + File file = openPacketFilterRead(fs, path); + if (!file) return false; + + char line[160]; + size_t line_len = 0; + + while (file.available()) { + int c = file.read(); + if (c == '\r') continue; + if (c == '\n') { + line[line_len] = 0; + loadPacketFilterRuleLine(this, line); + line_len = 0; + } else if (line_len < sizeof(line) - 1) { + line[line_len++] = (char)c; + } + } + + if (line_len > 0) { + line[line_len] = 0; + loadPacketFilterRuleLine(this, line); + } + + file.close(); + return true; +} + +bool PacketFilter::save(FILESYSTEM* fs, const char* path) const { + if (fs == NULL || path == NULL) return false; + + File file = openPacketFilterWrite(fs, path); + if (!file) return false; + + char line[160]; + for (uint8_t i = 0; i < PACKET_FILTER_MAX_RULES; i++) { + if (!_rules[i].isActive()) continue; + _rules[i].formatRule(line, sizeof(line)); + file.print(line); + file.print('\n'); + } + + file.close(); + return true; +} diff --git a/src/helpers/PacketFilter.h b/src/helpers/PacketFilter.h new file mode 100644 index 0000000000..6824069416 --- /dev/null +++ b/src/helpers/PacketFilter.h @@ -0,0 +1,34 @@ +#pragma once + +#include +#include +#include + +#ifndef PACKET_FILTER_MAX_RULES +#define PACKET_FILTER_MAX_RULES 20 +#endif + +class PacketFilter { + PacketFilterRule _rules[PACKET_FILTER_MAX_RULES]; + uint8_t _rule_count; + uint32_t _filtered_count; + mesh::RNG* _rng; + +public: + PacketFilter(mesh::RNG* rng = nullptr); + + bool shouldRepeatPacket(const mesh::Packet* packet); + bool block(const char* classifier_name, char* args); + bool deleteRule(uint8_t rule_index); + void clear(); + + uint32_t getFilteredCount() const { return _filtered_count; } + void resetStats() { _filtered_count = 0; } + + uint8_t getRuleCount() const { return _rule_count; } + void formatRules(char* reply, size_t reply_len, uint8_t page = 1) const; + void handleBlockCommand(FILESYSTEM* fs, const char* path, char* command, char* reply, size_t reply_len); + + bool load(FILESYSTEM* fs, const char* path); + bool save(FILESYSTEM* fs, const char* path) const; +}; diff --git a/src/helpers/SimpleBloomFilter.cpp b/src/helpers/SimpleBloomFilter.cpp new file mode 100644 index 0000000000..756b6a8025 --- /dev/null +++ b/src/helpers/SimpleBloomFilter.cpp @@ -0,0 +1,140 @@ +#include + +#include +#include + +#define BLOOM_GOLDEN_RATIO_32 0x9E3779B9UL + +SimpleBloomFilter::SimpleBloomFilter() { + _data = NULL; + _byte_count = 0; + _bit_count = 0; + _hash_count = 0; + _salt = NULL; + _salt_len = 0; + _set_bits = 0; +} + +bool SimpleBloomFilter::begin( + uint8_t* storage, + uint16_t storage_len, + uint32_t bit_count, + uint8_t hash_count +) { + if (storage == NULL) return false; + if (storage_len == 0) return false; + if (bit_count == 0) return false; + if (hash_count == 0) return false; + if (bit_count > ((uint32_t)storage_len * 8UL)) return false; + + _data = storage; + _byte_count = storage_len; + _bit_count = bit_count; + _hash_count = hash_count; + + clear(); + return true; +} + +void SimpleBloomFilter::setSalt(const uint8_t* salt, uint8_t salt_len) { + _salt = salt; + _salt_len = salt == NULL ? 0 : salt_len; +} + +void SimpleBloomFilter::clear() { + if (_data != NULL && _byte_count > 0) { + memset(_data, 0, _byte_count); + } + + _set_bits = 0; +} + +bool SimpleBloomFilter::getBit(uint32_t bit) const { + return (_data[bit >> 3] & (uint8_t)(1U << (bit & 7))) != 0; +} + +void SimpleBloomFilter::setBit(uint32_t bit) { + uint8_t mask = (uint8_t)(1U << (bit & 7)); + uint8_t* byte = &_data[bit >> 3]; + + if ((*byte & mask) == 0) { + *byte |= mask; + _set_bits++; + } +} + +void SimpleBloomFilter::calculateDigest( + uint8_t digest[], + const uint8_t* data, + size_t data_len +) const { + SHA256 sha; + + if (_salt != NULL && _salt_len > 0) { + sha.update(_salt, _salt_len); + } + + sha.update(data, data_len); + sha.finalize(digest, SHA256::HASH_SIZE); +} + +uint32_t SimpleBloomFilter::calculateBit( + const uint8_t digest[], + uint8_t index +) const { + uint32_t h1; + uint32_t h2; + + memcpy(&h1, digest, sizeof(h1)); + memcpy(&h2, digest + sizeof(h1), sizeof(h2)); + + if (h2 == 0) { + h2 = BLOOM_GOLDEN_RATIO_32; + } + + uint32_t i = (uint32_t)index; + uint32_t mixed = h1 + (i * h2) + (i * i); + + return mixed % _bit_count; +} + +bool SimpleBloomFilter::contains(const uint8_t* data, size_t data_len) const { + if (_data == NULL) return false; + if (_bit_count == 0) return false; + if (_hash_count == 0) return false; + if (data == NULL) return false; + + uint8_t digest[SHA256::HASH_SIZE]; + calculateDigest(digest, data, data_len); + + for (uint8_t i = 0; i < _hash_count; i++) { + uint32_t bit = calculateBit(digest, i); + + if (!getBit(bit)) { + return false; + } + } + + return true; +} + +void SimpleBloomFilter::add(const uint8_t* data, size_t data_len) { + if (_data == NULL) return; + if (_bit_count == 0) return; + if (_hash_count == 0) return; + if (data == NULL) return; + + uint8_t digest[SHA256::HASH_SIZE]; + calculateDigest(digest, data, data_len); + + for (uint8_t i = 0; i < _hash_count; i++) { + uint32_t bit = calculateBit(digest, i); + setBit(bit); + } +} + +bool SimpleBloomFilter::isSaturated(uint8_t max_fill_percent) const { + if (_bit_count == 0) return true; + + return (_set_bits * 100UL) >= (_bit_count * (uint32_t)max_fill_percent); +} \ No newline at end of file diff --git a/src/helpers/SimpleBloomFilter.h b/src/helpers/SimpleBloomFilter.h new file mode 100644 index 0000000000..ee698dcfa9 --- /dev/null +++ b/src/helpers/SimpleBloomFilter.h @@ -0,0 +1,37 @@ +#pragma once + +#include +#include +#include + +class SimpleBloomFilter { +public: + SimpleBloomFilter(); + + bool begin(uint8_t* storage, uint16_t storage_len, uint32_t bit_count, uint8_t hash_count); + + void setSalt(const uint8_t* salt, uint8_t salt_len); + void clear(); + + bool contains(const uint8_t* data, size_t data_len) const; + void add(const uint8_t* data, size_t data_len); + + bool isSaturated(uint8_t max_fill_percent) const; + +private: + uint8_t* _data; + uint16_t _byte_count; + uint32_t _bit_count; + uint8_t _hash_count; + + const uint8_t* _salt; + uint8_t _salt_len; + + uint32_t _set_bits; + + bool getBit(uint32_t bit) const; + void setBit(uint32_t bit); + + void calculateDigest(uint8_t digest[], const uint8_t* data, size_t data_len) const; + uint32_t calculateBit(const uint8_t digest[], uint8_t index) const; +}; \ No newline at end of file diff --git a/src/helpers/SimplePatternMatcher.cpp b/src/helpers/SimplePatternMatcher.cpp new file mode 100644 index 0000000000..5c8870bea1 --- /dev/null +++ b/src/helpers/SimplePatternMatcher.cpp @@ -0,0 +1,33 @@ +#include "SimplePatternMatcher.h" + +bool SimplePatternMatcher::matches(const char* pattern, const char* text) { + if (pattern == nullptr || text == nullptr) return false; + return matchHere(pattern, text); +} + +bool SimplePatternMatcher::matchHere(const char* pattern, const char* text) { + const char* starPattern = nullptr; + const char* starText = nullptr; + + while (*text) { + if (*pattern == '*') { + while (*pattern == '*') pattern++; + if (*pattern == '\0') return true; + + starPattern = pattern; + starText = text; + } else if (*pattern == '?' || *pattern == *text) { + pattern++; + text++; + } else if (starPattern != nullptr) { + pattern = starPattern; + text = ++starText; + } else { + return false; + } + } + + while (*pattern == '*') pattern++; + + return *pattern == '\0'; +} \ No newline at end of file diff --git a/src/helpers/SimplePatternMatcher.h b/src/helpers/SimplePatternMatcher.h new file mode 100644 index 0000000000..2b2e1bfa9e --- /dev/null +++ b/src/helpers/SimplePatternMatcher.h @@ -0,0 +1,9 @@ +#pragma once + +class SimplePatternMatcher { +public: + static bool matches(const char* pattern, const char* text); + +private: + static bool matchHere(const char* pattern, const char* text); +}; \ No newline at end of file diff --git a/src/helpers/packet_filter/AdvertRateLimitClassifier.cpp b/src/helpers/packet_filter/AdvertRateLimitClassifier.cpp new file mode 100644 index 0000000000..37721b74d3 --- /dev/null +++ b/src/helpers/packet_filter/AdvertRateLimitClassifier.cpp @@ -0,0 +1,182 @@ +#include "helpers/packet_filter/AdvertRateLimitClassifier.h" + +#include +#include +#include + +AdvertRateLimitClassifier::AdvertRateLimitClassifier(mesh::RNG* rng) { + _rng = rng; + _saturated = false; + _advert_type = ADV_TYPE_REPEATER; + _period_minutes = 0; + _max_per_period = 0; + _window_id = 0xFFFFFFFFUL; + memset(_salts, 0, sizeof(_salts)); + configureFilters(); +} + +void AdvertRateLimitClassifier::generateSalts() { + for (uint8_t i = 0; i < PACKET_FILTER_ADVERT_MAX_PER_PERIOD; i++) { + if (_rng != nullptr) { + _rng->random(_salts[i], sizeof(_salts[i])); + } + _filters[i].setSalt(_salts[i], sizeof(_salts[i])); + } +} + +void AdvertRateLimitClassifier::clearFilters() { + for (uint8_t i = 0; i < PACKET_FILTER_ADVERT_MAX_PER_PERIOD; i++) { + _filters[i].clear(); + } + _saturated = false; +} + +void AdvertRateLimitClassifier::configureFilters() { + for (uint8_t i = 0; i < PACKET_FILTER_ADVERT_MAX_PER_PERIOD; i++) { + _filters[i].begin(_filter_storage[i], sizeof(_filter_storage[i]), PACKET_FILTER_ADVERT_BLOOM_BITS, PACKET_FILTER_ADVERT_BLOOM_HASH_COUNT); + _filters[i].setSalt(_salts[i], sizeof(_salts[i])); + } + clearFilters(); +} + +bool AdvertRateLimitClassifier::configureParsed(uint8_t advert_type, uint16_t period_minutes, uint8_t max_per_period) { + if (period_minutes == 0 || max_per_period == 0) return false; + if (max_per_period > PACKET_FILTER_ADVERT_MAX_PER_PERIOD) return false; + if (advert_type > 0x0F) return false; + if (PACKET_FILTER_ADVERT_BLOOM_BITS == 0 || PACKET_FILTER_ADVERT_BLOOM_HASH_COUNT == 0) return false; + + _advert_type = advert_type; + _period_minutes = period_minutes; + _max_per_period = max_per_period; + _window_id = 0xFFFFFFFFUL; + generateSalts(); + clearFilters(); + return true; +} + +bool AdvertRateLimitClassifier::configure(char* args) { + char* advert_type = nextToken(args); + char* period = nextToken(args); + char* max_count = nextToken(args); + char* extra = nextToken(args); + + unsigned long period_minutes = 0; + unsigned long max_per_period = 1; + uint8_t parsed_type; + if (!parseAdvertType(advert_type, &parsed_type)) return false; + if (!parseUnsigned(period, 1, 65535UL, &period_minutes)) return false; + if (max_count != NULL && !parseUnsigned(max_count, 1, PACKET_FILTER_ADVERT_MAX_PER_PERIOD, &max_per_period)) return false; + if (extra != NULL) return false; + + return configureParsed(parsed_type, (uint16_t)period_minutes, (uint8_t)max_per_period); +} + +void AdvertRateLimitClassifier::formatRule(char* dest, size_t dest_len) const { + char type_name[4]; + snprintf( + dest, + dest_len, + "advert %s %u %u", + formatAdvertType(_advert_type, type_name, sizeof(type_name)), + _period_minutes, + _max_per_period + ); +} + +void AdvertRateLimitClassifier::resetWindowIfNeeded() { + uint32_t period_ms = ((uint32_t)_period_minutes) * 60UL * 1000UL; + uint32_t window_id = period_ms == 0 ? 0 : millis() / period_ms; + if (window_id != _window_id) { + generateSalts(); + clearFilters(); + _window_id = window_id; + } +} + +bool AdvertRateLimitClassifier::filterContains(uint8_t filter_idx, const uint8_t* cert) const { + return _filters[filter_idx].contains(cert, PUB_KEY_SIZE); +} + +void AdvertRateLimitClassifier::addToFilter(uint8_t filter_idx, const uint8_t* cert) { + _filters[filter_idx].add(cert, PUB_KEY_SIZE); + if (isFilterSaturated(filter_idx)) _saturated = true; +} + +bool AdvertRateLimitClassifier::isFilterSaturated(uint8_t filter_idx) const { + return _filters[filter_idx].isSaturated(PACKET_FILTER_ADVERT_BLOOM_MAX_FILL_PERCENT); +} + +bool AdvertRateLimitClassifier::shouldRepeatPacket(const mesh::Packet* packet) { + if (packet == NULL || packet->getPayloadType() != PAYLOAD_TYPE_ADVERT) return true; + + const uint16_t app_data_offset = PUB_KEY_SIZE + 4 + SIGNATURE_SIZE; + if (packet->payload_len <= app_data_offset) return true; + + AdvertDataParser parser(&packet->payload[app_data_offset], packet->payload_len - app_data_offset); + if (!parser.isValid() || parser.getType() != _advert_type) return true; + + resetWindowIfNeeded(); + if (_saturated) return true; + + uint8_t seen_count = 0; + uint8_t first_available = 0xFF; + for (uint8_t i = 0; i < _max_per_period; i++) { + if (filterContains(i, packet->payload)) { + seen_count++; + } else if (first_available == 0xFF) { + first_available = i; + } + } + + if (seen_count >= _max_per_period) return false; + if (first_available != 0xFF) addToFilter(first_available, packet->payload); + return true; +} + +bool AdvertRateLimitClassifier::handleBlockCommand(char* args, char* reply, size_t reply_len, BlockRuleFn block_rule, void* block_context) { + return BasePacketFilterClassifier::handleBlockCommand( + "advert", + args, + reply, + reply_len, + "Usage: block advert repeater 60 [max]", + "OK - advert block added", + block_rule, + block_context + ); +} + +bool AdvertRateLimitClassifier::parseAdvertType(const char* name, uint8_t* advert_type) { + if (name == NULL || advert_type == NULL || name[0] == 0) return false; + + if (strcmp(name, "companion") == 0) { + *advert_type = ADV_TYPE_CHAT; + } else if (strcmp(name, "repeater") == 0) { + *advert_type = ADV_TYPE_REPEATER; + } else if (strcmp(name, "room") == 0) { + *advert_type = ADV_TYPE_ROOM; + } else { + char* end = NULL; + long value = strtol(name, &end, 10); + if (end == name || *end != 0 || value <= 0 || value > 15) return false; + *advert_type = (uint8_t)value; + } + + return true; +} + +const char* AdvertRateLimitClassifier::advertTypeName(uint8_t advert_type) { + switch (advert_type) { + case ADV_TYPE_CHAT: return "companion"; + case ADV_TYPE_REPEATER: return "repeater"; + case ADV_TYPE_ROOM: return "room"; + default: return "?"; + } +} + +const char* AdvertRateLimitClassifier::formatAdvertType(uint8_t advert_type, char* dest, size_t dest_len) { + const char* name = advertTypeName(advert_type); + if (strcmp(name, "?") != 0) return name; + snprintf(dest, dest_len, "%u", advert_type); + return dest; +} diff --git a/src/helpers/packet_filter/AdvertRateLimitClassifier.h b/src/helpers/packet_filter/AdvertRateLimitClassifier.h new file mode 100644 index 0000000000..70daa5d460 --- /dev/null +++ b/src/helpers/packet_filter/AdvertRateLimitClassifier.h @@ -0,0 +1,67 @@ +#pragma once + +#include "BasePacketFilterClassifier.h" + +#include +#include +#include + +#ifndef PACKET_FILTER_ADVERT_BLOOM_BITS +#define PACKET_FILTER_ADVERT_BLOOM_BITS 1024 +#endif + +#ifndef PACKET_FILTER_ADVERT_BLOOM_STORAGE_BYTES +#define PACKET_FILTER_ADVERT_BLOOM_STORAGE_BYTES ((PACKET_FILTER_ADVERT_BLOOM_BITS + 7) / 8) +#endif + +#ifndef PACKET_FILTER_ADVERT_MAX_PER_PERIOD +#define PACKET_FILTER_ADVERT_MAX_PER_PERIOD 4 +#endif + +#ifndef PACKET_FILTER_ADVERT_BLOOM_MAX_FILL_PERCENT +#define PACKET_FILTER_ADVERT_BLOOM_MAX_FILL_PERCENT 75 +#endif + +#ifndef PACKET_FILTER_ADVERT_BLOOM_HASH_COUNT +#define PACKET_FILTER_ADVERT_BLOOM_HASH_COUNT 4 +#endif + +#ifndef PACKET_FILTER_ADVERT_BLOOM_SALT_BYTES +#define PACKET_FILTER_ADVERT_BLOOM_SALT_BYTES 16 +#endif + +class AdvertRateLimitClassifier : public BasePacketFilterClassifier { + bool _saturated; + uint8_t _advert_type; + uint16_t _period_minutes; + uint8_t _max_per_period; + uint32_t _window_id; + mesh::RNG* _rng; + uint8_t _salts[PACKET_FILTER_ADVERT_MAX_PER_PERIOD][PACKET_FILTER_ADVERT_BLOOM_SALT_BYTES]; + uint8_t _filter_storage[PACKET_FILTER_ADVERT_MAX_PER_PERIOD][PACKET_FILTER_ADVERT_BLOOM_STORAGE_BYTES]; + SimpleBloomFilter _filters[PACKET_FILTER_ADVERT_MAX_PER_PERIOD]; + + void generateSalts(); + void clearFilters(); + void configureFilters(); + void resetWindowIfNeeded(); + bool filterContains(uint8_t filter_idx, const uint8_t* cert) const; + void addToFilter(uint8_t filter_idx, const uint8_t* cert); + bool isFilterSaturated(uint8_t filter_idx) const; + bool configureParsed(uint8_t advert_type, uint16_t period_minutes, uint8_t max_per_period); + +public: + AdvertRateLimitClassifier(mesh::RNG* rng = nullptr); + + bool configure(char* args); + + void formatRule(char* dest, size_t dest_len) const; + bool shouldRepeatPacket(const mesh::Packet* packet); + + static bool handleBlockCommand(char* args, char* reply, size_t reply_len, BlockRuleFn block_rule, void* block_context); + +private: + static bool parseAdvertType(const char* name, uint8_t* advert_type); + static const char* advertTypeName(uint8_t advert_type); + static const char* formatAdvertType(uint8_t advert_type, char* dest, size_t dest_len); +}; diff --git a/src/helpers/packet_filter/BasePacketFilterClassifier.h b/src/helpers/packet_filter/BasePacketFilterClassifier.h new file mode 100644 index 0000000000..c42f45d6e8 --- /dev/null +++ b/src/helpers/packet_filter/BasePacketFilterClassifier.h @@ -0,0 +1,69 @@ +#pragma once + +#include +#include +#include + +class BasePacketFilterClassifier { +public: + typedef bool (*BlockRuleFn)(void* context, const char* classifier_name, char* args); + + static bool isArgSpace(char c) { + return c == ' ' || c == '\t'; + } + + static char* nextToken(char*& cursor) { + if (cursor == NULL) return NULL; + while (isArgSpace(*cursor)) cursor++; + if (*cursor == 0) return NULL; + + char* token = cursor; + while (*cursor && !isArgSpace(*cursor)) cursor++; + if (*cursor) *cursor++ = 0; + return token; + } + + static char* remainingArgs(char*& cursor) { + if (cursor == NULL) return NULL; + while (isArgSpace(*cursor)) cursor++; + return cursor; + } + + static char* unquote(char* value) { + if (value == NULL || *value != '"') return value; + + value++; + char* end = strrchr(value, '"'); + if (end != NULL) *end = 0; + return value; + } + + static bool parseUnsigned(const char* token, unsigned long min_value, unsigned long max_value, unsigned long* value) { + if (token == NULL || value == NULL) return false; + + char* end = NULL; + unsigned long parsed = strtoul(token, &end, 10); + if (end == token || *end != 0 || parsed < min_value || parsed > max_value) return false; + + *value = parsed; + return true; + } + + static bool handleBlockCommand(const char* classifier_name, char* args, char* reply, size_t reply_len, + const char* usage, const char* success_message, + BlockRuleFn block_rule, void* block_context) { + char* classifier_args = remainingArgs(args); + if (classifier_args == NULL || *classifier_args == 0 || strcmp(classifier_args, "help") == 0) { + snprintf(reply, reply_len, "%s", usage); + return false; + } + + if (block_rule != NULL && block_rule(block_context, classifier_name, classifier_args)) { + snprintf(reply, reply_len, "%s", success_message); + return true; + } + + snprintf(reply, reply_len, "Err - packet filter rule invalid or full"); + return false; + } +}; diff --git a/src/helpers/packet_filter/ChannelMessageClassifier.cpp b/src/helpers/packet_filter/ChannelMessageClassifier.cpp new file mode 100644 index 0000000000..282961e19f --- /dev/null +++ b/src/helpers/packet_filter/ChannelMessageClassifier.cpp @@ -0,0 +1,69 @@ +#include "helpers/packet_filter/ChannelMessageClassifier.h" + +#include +#include +#include +#include + +ChannelMessageClassifier::ChannelMessageClassifier() { + memset(_channel_name, 0, sizeof(_channel_name)); + memset(_pattern, 0, sizeof(_pattern)); + memset(_channel_hash, 0, sizeof(_channel_hash)); + memset(_channel_secret, 0, sizeof(_channel_secret)); +} + +bool ChannelMessageClassifier::configure(char* args) { + char* channel_name = nextToken(args); + char* pattern = unquote(remainingArgs(args)); + + if (channel_name == NULL || pattern == NULL || channel_name[0] == 0 || pattern[0] == 0) return false; + if (strlen(channel_name) >= sizeof(_channel_name) || strlen(pattern) >= sizeof(_pattern)) return false; + + uint8_t secret[PUB_KEY_SIZE]; + if (!mesh::GroupChannel::derivePublicSecret(channel_name, secret)) return false; + + memset(_channel_name, 0, sizeof(_channel_name)); + memset(_pattern, 0, sizeof(_pattern)); + strncpy(_channel_name, channel_name, sizeof(_channel_name) - 1); + strncpy(_pattern, pattern, sizeof(_pattern) - 1); + memcpy(_channel_secret, secret, sizeof(_channel_secret)); + mesh::GroupChannel channel; + memcpy(channel.secret, _channel_secret, sizeof(channel.secret)); + channel.deriveHash(); + memcpy(_channel_hash, channel.hash, sizeof(_channel_hash)); + return true; +} + +void ChannelMessageClassifier::formatRule(char* dest, size_t dest_len) const { + snprintf(dest, dest_len, "channelmessage %s %s", _channel_name, _pattern); +} + +bool ChannelMessageClassifier::shouldRepeatPacket(const mesh::Packet* packet) { + if (packet == NULL || packet->getPayloadType() != PAYLOAD_TYPE_GRP_TXT || packet->payload_len <= 1) return true; + if (packet->payload[0] != _channel_hash[0]) return true; + + uint8_t data[MAX_PACKET_PAYLOAD + 1]; + int len = mesh::Utils::MACThenDecrypt(_channel_secret, data, &packet->payload[1], packet->payload_len - 1); + if (len <= 5) return true; + + uint8_t txt_type = data[4] >> 2; + if (txt_type != TXT_TYPE_PLAIN) return true; + + data[len] = 0; + const char* text = (const char*)&data[5]; + return !SimplePatternMatcher::matches(_pattern, text); +} + +bool ChannelMessageClassifier::handleBlockCommand(char* args, char* reply, size_t reply_len, + BlockRuleFn block_rule, void* block_context) { + return BasePacketFilterClassifier::handleBlockCommand( + "channelmessage", + args, + reply, + reply_len, + "Usage: block channelmessage #channel \"pattern\"", + "OK - channel message block added", + block_rule, + block_context + ); +} diff --git a/src/helpers/packet_filter/ChannelMessageClassifier.h b/src/helpers/packet_filter/ChannelMessageClassifier.h new file mode 100644 index 0000000000..6d81230751 --- /dev/null +++ b/src/helpers/packet_filter/ChannelMessageClassifier.h @@ -0,0 +1,30 @@ +#pragma once + +#include "BasePacketFilterClassifier.h" + +#include + +#ifndef PACKET_FILTER_CHANNEL_NAME_SIZE +#define PACKET_FILTER_CHANNEL_NAME_SIZE 32 +#endif + +#ifndef PACKET_FILTER_PATTERN_SIZE +#define PACKET_FILTER_PATTERN_SIZE 80 +#endif + +class ChannelMessageClassifier : public BasePacketFilterClassifier { + char _channel_name[PACKET_FILTER_CHANNEL_NAME_SIZE]; + char _pattern[PACKET_FILTER_PATTERN_SIZE]; + uint8_t _channel_hash[PATH_HASH_SIZE]; + uint8_t _channel_secret[PUB_KEY_SIZE]; + +public: + ChannelMessageClassifier(); + + bool configure(char* args); + + void formatRule(char* dest, size_t dest_len) const; + bool shouldRepeatPacket(const mesh::Packet* packet); + + static bool handleBlockCommand(char* args, char* reply, size_t reply_len, BlockRuleFn block_rule, void* block_context); +}; diff --git a/src/helpers/packet_filter/PacketFilterRule.h b/src/helpers/packet_filter/PacketFilterRule.h new file mode 100644 index 0000000000..71b83544b8 --- /dev/null +++ b/src/helpers/packet_filter/PacketFilterRule.h @@ -0,0 +1,96 @@ +#pragma once + +#include "AdvertRateLimitClassifier.h" +#include "ChannelMessageClassifier.h" + +#include +#include + +class PacketFilterRule { +public: + enum Type { + RULE_NONE, + RULE_CHANNEL_MESSAGE, + RULE_ADVERT + }; + +private: + Type _type; + + union Classifier { + ChannelMessageClassifier channel_message; + AdvertRateLimitClassifier advert; + + Classifier() {} + ~Classifier() {} + } _classifier; + +public: + PacketFilterRule() : _type(RULE_NONE) {} + ~PacketFilterRule() { clear(); } + + void clear() { + switch (_type) { + case RULE_CHANNEL_MESSAGE: + _classifier.channel_message.~ChannelMessageClassifier(); + break; + case RULE_ADVERT: + _classifier.advert.~AdvertRateLimitClassifier(); + break; + default: + break; + } + _type = RULE_NONE; + } + + bool configure(const char* classifier_name, char* args, const mesh::Identity* self = NULL, mesh::RNG* rng = NULL) { + (void)self; + if (classifier_name == NULL) return false; + + if (strcmp(classifier_name, "channelmessage") == 0) { + ChannelMessageClassifier next; + if (!next.configure(args)) return false; + clear(); + new (&_classifier.channel_message) ChannelMessageClassifier(next); + _type = RULE_CHANNEL_MESSAGE; + return true; + } else if (strcmp(classifier_name, "advert") == 0) { + AdvertRateLimitClassifier next(rng); + if (!next.configure(args)) return false; + clear(); + new (&_classifier.advert) AdvertRateLimitClassifier(next); + _type = RULE_ADVERT; + return true; + } + + return false; + } + + bool isActive() const { return _type != RULE_NONE; } + Type getType() const { return _type; } + + void formatRule(char* dest, size_t dest_len) const { + switch (_type) { + case RULE_CHANNEL_MESSAGE: + _classifier.channel_message.formatRule(dest, dest_len); + break; + case RULE_ADVERT: + _classifier.advert.formatRule(dest, dest_len); + break; + default: + if (dest_len > 0) dest[0] = 0; + break; + } + } + + bool shouldRepeatPacket(const mesh::Packet* packet) { + switch (_type) { + case RULE_CHANNEL_MESSAGE: + return _classifier.channel_message.shouldRepeatPacket(packet); + case RULE_ADVERT: + return _classifier.advert.shouldRepeatPacket(packet); + default: + return true; + } + } +};