From b749a634a62bd409e13777deac404e32d6585975 Mon Sep 17 00:00:00 2001 From: Loris Moulin Date: Mon, 29 Jun 2026 13:34:40 -0300 Subject: [PATCH 1/2] Apply dispatch extension conditionally --- .../src/guards/check_coldkey_swap.rs | 25 ++- .../src/guards/check_delegate_take.rs | 57 +++++- .../src/guards/check_evm_key_association.rs | 54 ++++- .../subtensor/src/guards/check_rate_limits.rs | 81 +++++++- .../src/guards/check_serving_endpoints.rs | 61 +++++- pallets/subtensor/src/guards/check_weights.rs | 190 ++++++++++++------ pallets/subtensor/src/guards/mod.rs | 19 ++ 7 files changed, 384 insertions(+), 103 deletions(-) diff --git a/pallets/subtensor/src/guards/check_coldkey_swap.rs b/pallets/subtensor/src/guards/check_coldkey_swap.rs index 5f124be219..907fed1d3b 100644 --- a/pallets/subtensor/src/guards/check_coldkey_swap.rs +++ b/pallets/subtensor/src/guards/check_coldkey_swap.rs @@ -1,3 +1,4 @@ +use super::{CallOf, DispatchableOriginOf}; use crate::weights::WeightInfo; use crate::{Call, ColdkeySwapAnnouncements, ColdkeySwapDisputes, Config, Error}; use frame_support::{ @@ -8,9 +9,6 @@ use frame_support::{ use sp_runtime::traits::Dispatchable; use sp_std::marker::PhantomData; -type CallOf = ::RuntimeCall; -type DispatchableOriginOf = as Dispatchable>::RuntimeOrigin; - /// Dispatch extension that blocks most calls when a coldkey swap is active. /// /// When a coldkey swap has been announced for the signing account: @@ -96,9 +94,14 @@ where #[allow(clippy::expect_used, clippy::unwrap_used)] mod tests { use super::CheckColdkeySwap; - use crate::{ColdkeySwapAnnouncements, ColdkeySwapDisputes, Error, tests::mock::*}; + use crate::{ + ColdkeySwapAnnouncements, ColdkeySwapDisputes, Error, tests::mock::*, + weights::WeightInfo as _, + }; use frame_support::{ - BoundedVec, assert_ok, dispatch::DispatchResultWithPostInfo, traits::ExtendedDispatchable, + BoundedVec, assert_ok, + dispatch::{DispatchExtension, DispatchResultWithPostInfo}, + traits::ExtendedDispatchable, }; use frame_system::Call as SystemCall; use pallet_subtensor_proxy::Call as ProxyCall; @@ -176,6 +179,18 @@ mod tests { ) } + #[test] + fn weight_charges_all_calls_because_swap_state_can_block_any_signed_call() { + let expected = ::WeightInfo::check_coldkey_swap_extension(); + + for call in forbidden_calls().into_iter().chain(authorized_calls()) { + assert_eq!( + as DispatchExtension>::weight(&call), + expected + ); + } + } + #[test] fn no_active_swap_allows_calls() { new_test_ext(1).execute_with(|| { diff --git a/pallets/subtensor/src/guards/check_delegate_take.rs b/pallets/subtensor/src/guards/check_delegate_take.rs index c9f54d4cb5..c80d969afc 100644 --- a/pallets/subtensor/src/guards/check_delegate_take.rs +++ b/pallets/subtensor/src/guards/check_delegate_take.rs @@ -1,3 +1,4 @@ +use super::{CallOf, DispatchableOriginOf, applicable_call}; use crate::weights::WeightInfo; use crate::{Call, Config, Error, Pallet}; use frame_support::{ @@ -8,9 +9,6 @@ use frame_support::{ use sp_runtime::traits::Dispatchable; use sp_std::marker::PhantomData; -type CallOf = ::RuntimeCall; -type DispatchableOriginOf = as Dispatchable>::RuntimeOrigin; - /// Dispatch extension for delegate-take bounds and ownership preconditions. /// /// Signed increase/decrease take calls are checked before dispatch; unrelated @@ -18,6 +16,13 @@ type DispatchableOriginOf = as Dispatchable>::RuntimeOrigin; pub struct CheckDelegateTake(PhantomData); impl CheckDelegateTake { + pub(crate) fn applies_to(call: &Call) -> bool { + matches!( + call, + Call::increase_take { .. } | Call::decrease_take { .. } + ) + } + pub fn check(who: &T::AccountId, call: &Call) -> Result<(), Error> { match call { Call::increase_take { hotkey, take } | Call::decrease_take { hotkey, take } => { @@ -42,8 +47,10 @@ where { type Pre = (); - fn weight(_call: &CallOf) -> Weight { - ::WeightInfo::check_delegate_take_extension() + fn weight(call: &CallOf) -> Weight { + applicable_call(call, Self::applies_to) + .map(|_| ::WeightInfo::check_delegate_take_extension()) + .unwrap_or(Weight::zero()) } fn pre_dispatch( @@ -54,7 +61,7 @@ where return Ok(()); }; - let Some(call) = call.is_sub_type() else { + let Some(call) = applicable_call(call, Self::applies_to) else { return Ok(()); }; @@ -68,7 +75,10 @@ mod tests { use super::*; use crate::{Error, tests::mock::*}; use frame_support::{ - assert_ok, dispatch::DispatchResultWithPostInfo, traits::ExtendedDispatchable, + assert_ok, + dispatch::{DispatchExtension, DispatchResultWithPostInfo}, + traits::ExtendedDispatchable, + weights::Weight, }; use sp_core::U256; use sp_runtime::DispatchError; @@ -91,6 +101,39 @@ mod tests { result.unwrap_err().error } + fn add_stake_call() -> RuntimeCall { + RuntimeCall::SubtensorModule(SubtensorCall::add_stake { + hotkey: U256::from(1), + netuid: 1u16.into(), + amount_staked: 1_000u64.into(), + }) + } + + #[test] + fn weight_only_charges_delegate_take_calls() { + let expected = ::WeightInfo::check_delegate_take_extension(); + + for call in [ + RuntimeCall::System(frame_system::Call::remark { remark: vec![] }), + add_stake_call(), + ] { + assert_eq!( + as DispatchExtension>::weight(&call), + Weight::zero() + ); + } + + for call in [ + increase_take_call(U256::from(1), 0), + decrease_take_call(U256::from(1), 0), + ] { + assert_eq!( + as DispatchExtension>::weight(&call), + expected + ); + } + } + #[test] fn accepts_owner_with_valid_take() { new_test_ext(0).execute_with(|| { diff --git a/pallets/subtensor/src/guards/check_evm_key_association.rs b/pallets/subtensor/src/guards/check_evm_key_association.rs index d7e2847e99..d9b69e1a7d 100644 --- a/pallets/subtensor/src/guards/check_evm_key_association.rs +++ b/pallets/subtensor/src/guards/check_evm_key_association.rs @@ -1,3 +1,4 @@ +use super::{CallOf, DispatchableOriginOf, applicable_call}; use crate::weights::WeightInfo; use crate::{Call, Config, Error, Pallet}; use frame_support::{ @@ -8,9 +9,6 @@ use frame_support::{ use sp_runtime::traits::Dispatchable; use sp_std::marker::PhantomData; -type CallOf = ::RuntimeCall; -type DispatchableOriginOf = as Dispatchable>::RuntimeOrigin; - /// Dispatch extension for EVM-key association preconditions. /// /// Signed EVM-key association calls are checked for subnet registration and @@ -18,6 +16,10 @@ type DispatchableOriginOf = as Dispatchable>::RuntimeOrigin; pub struct CheckEvmKeyAssociation(PhantomData); impl CheckEvmKeyAssociation { + pub(crate) fn applies_to(call: &Call) -> bool { + matches!(call, Call::associate_evm_key { .. }) + } + pub fn check(who: &T::AccountId, call: &Call) -> Result<(), Error> { match call { Call::associate_evm_key { netuid, .. } => { @@ -40,8 +42,10 @@ where { type Pre = (); - fn weight(_call: &CallOf) -> Weight { - ::WeightInfo::check_evm_key_association_extension() + fn weight(call: &CallOf) -> Weight { + applicable_call(call, Self::applies_to) + .map(|_| ::WeightInfo::check_evm_key_association_extension()) + .unwrap_or(Weight::zero()) } fn pre_dispatch( @@ -52,7 +56,7 @@ where return Ok(()); }; - let Some(call) = call.is_sub_type() else { + let Some(call) = applicable_call(call, Self::applies_to) else { return Ok(()); }; @@ -64,10 +68,13 @@ where #[allow(clippy::unwrap_used, clippy::arithmetic_side_effects)] mod tests { use super::CheckEvmKeyAssociation; - use crate::{AssociatedEvmAddress, Error, tests::mock::*}; + use crate::{AssociatedEvmAddress, Error, tests::mock::*, weights::WeightInfo as _}; use codec::Encode; use frame_support::{ - assert_ok, dispatch::DispatchResultWithPostInfo, traits::ExtendedDispatchable, + assert_ok, + dispatch::{DispatchExtension, DispatchResultWithPostInfo}, + traits::ExtendedDispatchable, + weights::Weight, }; use frame_system::Call as SystemCall; use sp_core::{H160, Pair, U256, ecdsa, keccak_256}; @@ -139,6 +146,37 @@ mod tests { ) } + fn add_stake_call() -> RuntimeCall { + RuntimeCall::SubtensorModule(SubtensorCall::add_stake { + hotkey: U256::from(1), + netuid: 1u16.into(), + amount_staked: 1_000u64.into(), + }) + } + + #[test] + fn weight_only_charges_evm_key_association_calls() { + let netuid = NetUid::from(1); + let expected = ::WeightInfo::check_evm_key_association_extension(); + + for call in [ + RuntimeCall::System(SystemCall::remark { remark: vec![] }), + add_stake_call(), + ] { + assert_eq!( + as DispatchExtension>::weight(&call), + Weight::zero() + ); + } + + assert_eq!( + as DispatchExtension>::weight( + &dummy_associate_call(netuid) + ), + expected + ); + } + #[test] fn unrelated_calls_pass_through() { new_test_ext(0).execute_with(|| { diff --git a/pallets/subtensor/src/guards/check_rate_limits.rs b/pallets/subtensor/src/guards/check_rate_limits.rs index d2c021dd5d..e12c9d064b 100644 --- a/pallets/subtensor/src/guards/check_rate_limits.rs +++ b/pallets/subtensor/src/guards/check_rate_limits.rs @@ -1,3 +1,4 @@ +use super::{CallOf, DispatchableOriginOf, applicable_call}; use crate::weights::WeightInfo; use crate::{Call, Config, Error, Pallet, TransactionType}; use frame_support::{ @@ -9,9 +10,6 @@ use sp_runtime::traits::Dispatchable; use sp_std::marker::PhantomData; use subtensor_runtime_common::{NetUid, NetUidStorageIndex}; -type CallOf = ::RuntimeCall; -type DispatchableOriginOf = as Dispatchable>::RuntimeOrigin; - /// Dispatch extension for rate-limit checks that are safe to reject before dispatch. /// /// Signed weight and network-registration calls are checked before dispatch; @@ -19,6 +17,17 @@ type DispatchableOriginOf = as Dispatchable>::RuntimeOrigin; pub struct CheckRateLimits(PhantomData); impl CheckRateLimits { + pub(crate) fn applies_to(call: &Call) -> bool { + matches!( + call, + Call::commit_weights { .. } + | Call::commit_mechanism_weights { .. } + | Call::set_weights { .. } + | Call::set_mechanism_weights { .. } + | Call::register_network { .. } + ) + } + fn check_weights_rate_limit( who: &T::AccountId, netuid: NetUid, @@ -89,8 +98,10 @@ where { type Pre = (); - fn weight(_call: &CallOf) -> Weight { - ::WeightInfo::check_rate_limits_extension() + fn weight(call: &CallOf) -> Weight { + applicable_call(call, Self::applies_to) + .map(|_| ::WeightInfo::check_rate_limits_extension()) + .unwrap_or(Weight::zero()) } fn pre_dispatch( @@ -101,7 +112,7 @@ where return Ok(()); }; - let Some(call) = call.is_sub_type() else { + let Some(call) = applicable_call(call, Self::applies_to) else { return Ok(()); }; @@ -113,9 +124,12 @@ where #[allow(clippy::unwrap_used)] mod tests { use super::CheckRateLimits; - use crate::{Error, tests::mock::*}; + use crate::{Error, tests::mock::*, weights::WeightInfo as _}; use frame_support::{ - assert_ok, dispatch::DispatchResultWithPostInfo, traits::ExtendedDispatchable, + assert_ok, + dispatch::{DispatchExtension, DispatchResultWithPostInfo}, + traits::ExtendedDispatchable, + weights::Weight, }; use frame_system::Call as SystemCall; use sp_core::U256; @@ -155,6 +169,57 @@ mod tests { add_balance_to_coldkey_account(&coldkey, amount); } + fn add_stake_call() -> RuntimeCall { + RuntimeCall::SubtensorModule(SubtensorCall::add_stake { + hotkey: U256::from(1), + netuid: 1u16.into(), + amount_staked: 1_000u64.into(), + }) + } + + #[test] + fn weight_only_charges_rate_limited_calls() { + let netuid = NetUid::from(1); + let expected = ::WeightInfo::check_rate_limits_extension(); + let charged_calls = [ + RuntimeCall::SubtensorModule(SubtensorCall::commit_weights { + netuid, + commit_hash: sp_core::H256::zero(), + }), + RuntimeCall::SubtensorModule(SubtensorCall::commit_mechanism_weights { + netuid, + mecid: MechId::MAIN, + commit_hash: sp_core::H256::zero(), + }), + set_weights_call(netuid, 0), + RuntimeCall::SubtensorModule(SubtensorCall::set_mechanism_weights { + netuid, + mecid: MechId::MAIN, + dests: vec![0], + weights: vec![1], + version_key: 0, + }), + register_network_call(U256::from(1)), + ]; + + for call in [ + RuntimeCall::System(SystemCall::remark { remark: vec![] }), + add_stake_call(), + ] { + assert_eq!( + as DispatchExtension>::weight(&call), + Weight::zero() + ); + } + + for call in charged_calls { + assert_eq!( + as DispatchExtension>::weight(&call), + expected + ); + } + } + #[test] fn unrelated_calls_pass_through() { new_test_ext(0).execute_with(|| { diff --git a/pallets/subtensor/src/guards/check_serving_endpoints.rs b/pallets/subtensor/src/guards/check_serving_endpoints.rs index 46304d337f..f8b2da64ed 100644 --- a/pallets/subtensor/src/guards/check_serving_endpoints.rs +++ b/pallets/subtensor/src/guards/check_serving_endpoints.rs @@ -1,3 +1,4 @@ +use super::{CallOf, DispatchableOriginOf, applicable_call}; use crate::weights::WeightInfo; use crate::{Call, Config, Error, Pallet}; use frame_support::{ @@ -8,9 +9,6 @@ use frame_support::{ use sp_runtime::traits::Dispatchable; use sp_std::marker::PhantomData; -type CallOf = ::RuntimeCall; -type DispatchableOriginOf = as Dispatchable>::RuntimeOrigin; - /// Dispatch extension for axon/prometheus endpoint validation. /// /// Signed serving calls are checked before dispatch; unrelated calls and @@ -18,6 +16,13 @@ type DispatchableOriginOf = as Dispatchable>::RuntimeOrigin; pub struct CheckServingEndpoints(PhantomData); impl CheckServingEndpoints { + pub(crate) fn applies_to(call: &Call) -> bool { + matches!( + call, + Call::serve_axon { .. } | Call::serve_axon_tls { .. } | Call::serve_prometheus { .. } + ) + } + pub fn check(who: &T::AccountId, call: &Call) -> Result<(), Error> { match call { Call::serve_axon { @@ -74,8 +79,10 @@ where { type Pre = (); - fn weight(_call: &CallOf) -> Weight { - ::WeightInfo::check_serving_endpoints_extension() + fn weight(call: &CallOf) -> Weight { + applicable_call(call, Self::applies_to) + .map(|_| ::WeightInfo::check_serving_endpoints_extension()) + .unwrap_or(Weight::zero()) } fn pre_dispatch( @@ -86,7 +93,7 @@ where return Ok(()); }; - let Some(call) = call.is_sub_type() else { + let Some(call) = applicable_call(call, Self::applies_to) else { return Ok(()); }; @@ -98,9 +105,12 @@ where #[allow(clippy::unwrap_used)] mod tests { use super::CheckServingEndpoints; - use crate::{Error, tests::mock::*}; + use crate::{Error, tests::mock::*, weights::WeightInfo as _}; use frame_support::{ - assert_ok, dispatch::DispatchResultWithPostInfo, traits::ExtendedDispatchable, + assert_ok, + dispatch::{DispatchExtension, DispatchResultWithPostInfo}, + traits::ExtendedDispatchable, + weights::Weight, }; use frame_system::Call as SystemCall; use sp_core::U256; @@ -160,6 +170,41 @@ mod tests { register_ok_neuron(netuid, hotkey, coldkey, 0); } + fn add_stake_call() -> RuntimeCall { + RuntimeCall::SubtensorModule(SubtensorCall::add_stake { + hotkey: U256::from(1), + netuid: 1u16.into(), + amount_staked: 1_000u64.into(), + }) + } + + #[test] + fn weight_only_charges_serving_endpoint_calls() { + let netuid = NetUid::from(1); + let expected = ::WeightInfo::check_serving_endpoints_extension(); + + for call in [ + RuntimeCall::System(SystemCall::remark { remark: vec![] }), + add_stake_call(), + ] { + assert_eq!( + as DispatchExtension>::weight(&call), + Weight::zero() + ); + } + + for call in [ + serve_axon_call(netuid), + serve_axon_tls_call(netuid), + serve_prometheus_call(netuid), + ] { + assert_eq!( + as DispatchExtension>::weight(&call), + expected + ); + } + } + #[test] fn unrelated_calls_pass_through() { new_test_ext(0).execute_with(|| { diff --git a/pallets/subtensor/src/guards/check_weights.rs b/pallets/subtensor/src/guards/check_weights.rs index d10e7b8b8c..c116071ba6 100644 --- a/pallets/subtensor/src/guards/check_weights.rs +++ b/pallets/subtensor/src/guards/check_weights.rs @@ -1,3 +1,4 @@ +use super::{CallOf, DispatchableOriginOf, applicable_call}; use crate::weights::WeightInfo; use crate::{Call, Config, Error, Pallet, WeightCommits}; use frame_support::{ @@ -10,8 +11,6 @@ use sp_runtime::traits::Dispatchable; use sp_std::{collections::vec_deque::VecDeque, marker::PhantomData, vec::Vec}; use subtensor_runtime_common::{NetUid, NetUidStorageIndex}; -type CallOf = ::RuntimeCall; -type DispatchableOriginOf = as Dispatchable>::RuntimeOrigin; type WeightCommitQueue = VecDeque<(H256, u64, u64, u64)>; /// Dispatch extension for weight-setting preconditions. @@ -21,6 +20,24 @@ type WeightCommitQueue = VecDeque<(H256, u64, u64, u64)>; pub struct CheckWeights(PhantomData); impl CheckWeights { + pub(crate) fn applies_to(call: &Call) -> bool { + matches!( + call, + Call::batch_commit_weights { .. } + | Call::batch_reveal_weights { .. } + | Call::batch_set_weights { .. } + | Call::commit_weights { .. } + | Call::commit_mechanism_weights { .. } + | Call::reveal_weights { .. } + | Call::reveal_mechanism_weights { .. } + | Call::set_weights { .. } + | Call::set_mechanism_weights { .. } + | Call::commit_timelocked_weights { .. } + | Call::commit_timelocked_mechanism_weights { .. } + | Call::commit_crv3_mechanism_weights { .. } + ) + } + pub fn check(who: &T::AccountId, call: &Call) -> Result<(), Error> { Self::check_input_lengths(call)?; Self::check_min_stake(who, call)?; @@ -227,8 +244,10 @@ where { type Pre = (); - fn weight(_call: &CallOf) -> Weight { - ::WeightInfo::check_weights_extension() + fn weight(call: &CallOf) -> Weight { + applicable_call(call, Self::applies_to) + .map(|_| ::WeightInfo::check_weights_extension()) + .unwrap_or(Weight::zero()) } fn pre_dispatch( @@ -239,7 +258,7 @@ where return Ok(()); }; - let Some(call) = call.is_sub_type() else { + let Some(call) = applicable_call(call, Self::applies_to) else { return Ok(()); }; @@ -251,11 +270,14 @@ where #[allow(clippy::unwrap_used)] mod tests { use super::CheckWeights; - use crate::{Error, MAX_CRV3_COMMIT_SIZE_BYTES, tests::mock::*}; + use crate::{Error, MAX_CRV3_COMMIT_SIZE_BYTES, tests::mock::*, weights::WeightInfo as _}; use codec::Compact; use frame_support::{ - BoundedVec, assert_ok, dispatch::DispatchResultWithPostInfo, traits::ConstU32, + BoundedVec, assert_ok, + dispatch::{DispatchExtension, DispatchResultWithPostInfo}, + traits::ConstU32, traits::ExtendedDispatchable, + weights::Weight, }; use frame_system::Call as SystemCall; use pallet_drand::LastStoredRound; @@ -309,6 +331,99 @@ mod tests { }) } + fn add_stake_call() -> RuntimeCall { + RuntimeCall::SubtensorModule(SubtensorCall::add_stake { + hotkey: U256::from(1), + netuid: 1u16.into(), + amount_staked: 1_000u64.into(), + }) + } + + fn checked_weight_calls(netuid: NetUid) -> Vec { + let bounded_commit = + BoundedVec::>::try_from(vec![0]).unwrap(); + + vec![ + set_weights_call(netuid, 0), + RuntimeCall::SubtensorModule(SubtensorCall::set_mechanism_weights { + netuid, + mecid: MechId::MAIN, + dests: vec![0], + weights: vec![1], + version_key: 0, + }), + RuntimeCall::SubtensorModule(SubtensorCall::batch_set_weights { + netuids: vec![Compact(netuid)], + weights: vec![vec![(Compact(0_u16), Compact(1_u16))]], + version_keys: vec![Compact(0_u64)], + }), + RuntimeCall::SubtensorModule(SubtensorCall::commit_weights { + netuid, + commit_hash: H256::zero(), + }), + RuntimeCall::SubtensorModule(SubtensorCall::commit_mechanism_weights { + netuid, + mecid: MechId::MAIN, + commit_hash: H256::zero(), + }), + RuntimeCall::SubtensorModule(SubtensorCall::batch_commit_weights { + netuids: vec![Compact(netuid)], + commit_hashes: vec![H256::zero()], + }), + reveal_weights_call(netuid), + reveal_mechanism_weights_call(netuid, MechId::MAIN), + RuntimeCall::SubtensorModule(SubtensorCall::batch_reveal_weights { + netuid, + uids_list: vec![vec![0]], + values_list: vec![vec![1]], + salts_list: vec![vec![1]], + version_keys: vec![0], + }), + RuntimeCall::SubtensorModule(SubtensorCall::commit_timelocked_weights { + netuid, + commit: bounded_commit.clone(), + reveal_round: 0, + commit_reveal_version: 0, + }), + RuntimeCall::SubtensorModule(SubtensorCall::commit_timelocked_mechanism_weights { + netuid, + mecid: MechId::MAIN, + commit: bounded_commit.clone(), + reveal_round: 0, + commit_reveal_version: 0, + }), + RuntimeCall::SubtensorModule(SubtensorCall::commit_crv3_mechanism_weights { + netuid, + mecid: MechId::MAIN, + commit: bounded_commit, + reveal_round: 0, + }), + ] + } + + #[test] + fn weight_only_charges_weight_related_calls() { + let netuid = NetUid::from(1); + let expected = ::WeightInfo::check_weights_extension(); + + for call in [ + RuntimeCall::System(SystemCall::remark { remark: vec![] }), + add_stake_call(), + ] { + assert_eq!( + as DispatchExtension>::weight(&call), + Weight::zero() + ); + } + + for call in checked_weight_calls(netuid) { + assert_eq!( + as DispatchExtension>::weight(&call), + expected + ); + } + } + #[test] fn unrelated_calls_pass_through() { new_test_ext(0).execute_with(|| { @@ -360,72 +475,13 @@ mod tests { let netuid = NetUid::from(1); let hotkey = U256::from(1); let coldkey = U256::from(2); - let bounded_commit = - BoundedVec::>::try_from(vec![0]).unwrap(); add_network_disable_commit_reveal(netuid, 1, 0); setup_reserves(netuid, DEFAULT_RESERVE.into(), DEFAULT_RESERVE.into()); SubtensorModule::append_neuron(netuid, &hotkey, 0); crate::Owner::::insert(hotkey, coldkey); SubtensorModule::set_stake_threshold(1_000_000_000_000_u64); - let calls = [ - set_weights_call(netuid, 0), - RuntimeCall::SubtensorModule(SubtensorCall::set_mechanism_weights { - netuid, - mecid: MechId::MAIN, - dests: vec![0], - weights: vec![1], - version_key: 0, - }), - RuntimeCall::SubtensorModule(SubtensorCall::batch_set_weights { - netuids: vec![Compact(netuid)], - weights: vec![vec![(Compact(0_u16), Compact(1_u16))]], - version_keys: vec![Compact(0_u64)], - }), - RuntimeCall::SubtensorModule(SubtensorCall::commit_weights { - netuid, - commit_hash: H256::zero(), - }), - RuntimeCall::SubtensorModule(SubtensorCall::commit_mechanism_weights { - netuid, - mecid: MechId::MAIN, - commit_hash: H256::zero(), - }), - RuntimeCall::SubtensorModule(SubtensorCall::batch_commit_weights { - netuids: vec![Compact(netuid)], - commit_hashes: vec![H256::zero()], - }), - reveal_weights_call(netuid), - reveal_mechanism_weights_call(netuid, MechId::MAIN), - RuntimeCall::SubtensorModule(SubtensorCall::batch_reveal_weights { - netuid, - uids_list: vec![vec![0]], - values_list: vec![vec![1]], - salts_list: vec![vec![1]], - version_keys: vec![0], - }), - RuntimeCall::SubtensorModule(SubtensorCall::commit_timelocked_weights { - netuid, - commit: bounded_commit.clone(), - reveal_round: 0, - commit_reveal_version: 0, - }), - RuntimeCall::SubtensorModule(SubtensorCall::commit_timelocked_mechanism_weights { - netuid, - mecid: MechId::MAIN, - commit: bounded_commit.clone(), - reveal_round: 0, - commit_reveal_version: 0, - }), - RuntimeCall::SubtensorModule(SubtensorCall::commit_crv3_mechanism_weights { - netuid, - mecid: MechId::MAIN, - commit: bounded_commit, - reveal_round: 0, - }), - ]; - - for call in calls { + for call in checked_weight_calls(netuid) { assert_eq!( err(dispatch_with_ext(call, RuntimeOrigin::signed(hotkey))), Error::::NotEnoughStakeToSetWeights.into() diff --git a/pallets/subtensor/src/guards/mod.rs b/pallets/subtensor/src/guards/mod.rs index 485fc65a04..3865352858 100644 --- a/pallets/subtensor/src/guards/mod.rs +++ b/pallets/subtensor/src/guards/mod.rs @@ -5,9 +5,28 @@ mod check_rate_limits; mod check_serving_endpoints; mod check_weights; +use crate::{Call, Config}; +use frame_support::traits::IsSubType; +use sp_runtime::traits::Dispatchable; + pub use check_coldkey_swap::*; pub use check_delegate_take::*; pub use check_evm_key_association::*; pub use check_rate_limits::*; pub use check_serving_endpoints::*; pub use check_weights::*; + +pub(crate) type CallOf = ::RuntimeCall; +pub(crate) type DispatchableOriginOf = as Dispatchable>::RuntimeOrigin; + +pub(crate) fn applicable_call( + call: &CallOf, + applies_to: impl FnOnce(&Call) -> bool, +) -> Option<&Call> +where + T: Config, + CallOf: IsSubType>, +{ + let call = call.is_sub_type()?; + applies_to(call).then_some(call) +} From ebe1ac631e04498b3e295ac87030e0c68f566585 Mon Sep 17 00:00:00 2001 From: Loris Moulin Date: Mon, 29 Jun 2026 13:34:53 -0300 Subject: [PATCH 2/2] Apply subtensor extension conditionally --- pallets/subtensor/src/extensions/subtensor.rs | 26 ++++++++++++------- 1 file changed, 17 insertions(+), 9 deletions(-) diff --git a/pallets/subtensor/src/extensions/subtensor.rs b/pallets/subtensor/src/extensions/subtensor.rs index ea91c87c6e..7899ed855e 100644 --- a/pallets/subtensor/src/extensions/subtensor.rs +++ b/pallets/subtensor/src/extensions/subtensor.rs @@ -1,6 +1,6 @@ use crate::{ Call, CheckColdkeySwap, CheckDelegateTake, CheckEvmKeyAssociation, CheckRateLimits, - CheckServingEndpoints, CheckWeights, Config, Error, + CheckServingEndpoints, CheckWeights, Config, Error, guards::applicable_call, }; use codec::{Decode, DecodeWithMemTracking, Encode}; use frame_support::{ @@ -89,15 +89,23 @@ impl SubtensorTransactionExtension { CheckColdkeySwap::::check(who, call)?; - let Some(call) = call.is_sub_type() else { - return Ok(()); - }; + if let Some(call) = applicable_call(call, CheckWeights::::applies_to) { + CheckWeights::::check(who, call)?; + } + if let Some(call) = applicable_call(call, CheckRateLimits::::applies_to) { + CheckRateLimits::::check(who, call)?; + } + if let Some(call) = applicable_call(call, CheckDelegateTake::::applies_to) { + CheckDelegateTake::::check(who, call)?; + } + if let Some(call) = applicable_call(call, CheckServingEndpoints::::applies_to) { + CheckServingEndpoints::::check(who, call)?; + } + if let Some(call) = applicable_call(call, CheckEvmKeyAssociation::::applies_to) { + CheckEvmKeyAssociation::::check(who, call)?; + } - CheckWeights::::check(who, call)?; - CheckRateLimits::::check(who, call)?; - CheckDelegateTake::::check(who, call)?; - CheckServingEndpoints::::check(who, call)?; - CheckEvmKeyAssociation::::check(who, call) + Ok(()) } }