Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 17 additions & 9 deletions pallets/subtensor/src/extensions/subtensor.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use crate::{
Call, CheckColdkeySwap, CheckDelegateTake, CheckEvmKeyAssociation, CheckRateLimits,
CheckServingEndpoints, CheckWeights, Config, Error,
CheckServingEndpoints, CheckWeights, Config, Error, guards::applicable_call,
};
use codec::{Decode, DecodeWithMemTracking, Encode};
use frame_support::{
Expand Down Expand Up @@ -89,15 +89,23 @@ impl<T: Config + Send + Sync + TypeInfo> SubtensorTransactionExtension<T> {

CheckColdkeySwap::<T>::check(who, call)?;

let Some(call) = call.is_sub_type() else {
return Ok(());
};
if let Some(call) = applicable_call(call, CheckWeights::<T>::applies_to) {
CheckWeights::<T>::check(who, call)?;
}
if let Some(call) = applicable_call(call, CheckRateLimits::<T>::applies_to) {
CheckRateLimits::<T>::check(who, call)?;
}
if let Some(call) = applicable_call(call, CheckDelegateTake::<T>::applies_to) {
CheckDelegateTake::<T>::check(who, call)?;
}
if let Some(call) = applicable_call(call, CheckServingEndpoints::<T>::applies_to) {
CheckServingEndpoints::<T>::check(who, call)?;
}
if let Some(call) = applicable_call(call, CheckEvmKeyAssociation::<T>::applies_to) {
CheckEvmKeyAssociation::<T>::check(who, call)?;
}

CheckWeights::<T>::check(who, call)?;
CheckRateLimits::<T>::check(who, call)?;
CheckDelegateTake::<T>::check(who, call)?;
CheckServingEndpoints::<T>::check(who, call)?;
CheckEvmKeyAssociation::<T>::check(who, call)
Ok(())
}
}

Expand Down
25 changes: 20 additions & 5 deletions pallets/subtensor/src/guards/check_coldkey_swap.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
use super::{CallOf, DispatchableOriginOf};
use crate::weights::WeightInfo;
use crate::{Call, ColdkeySwapAnnouncements, ColdkeySwapDisputes, Config, Error};
use frame_support::{
Expand All @@ -8,9 +9,6 @@ use frame_support::{
use sp_runtime::traits::Dispatchable;
use sp_std::marker::PhantomData;

type CallOf<T> = <T as frame_system::Config>::RuntimeCall;
type DispatchableOriginOf<T> = <CallOf<T> as Dispatchable>::RuntimeOrigin;

/// Dispatch extension that blocks most calls when a coldkey swap is active.
///
/// When a coldkey swap has been announced for the signing account:
Expand Down Expand Up @@ -96,9 +94,14 @@ where
#[allow(clippy::expect_used, clippy::unwrap_used)]
mod tests {
use super::CheckColdkeySwap;
use crate::{ColdkeySwapAnnouncements, ColdkeySwapDisputes, Error, tests::mock::*};
use crate::{
ColdkeySwapAnnouncements, ColdkeySwapDisputes, Error, tests::mock::*,
weights::WeightInfo as _,
};
use frame_support::{
BoundedVec, assert_ok, dispatch::DispatchResultWithPostInfo, traits::ExtendedDispatchable,
BoundedVec, assert_ok,
dispatch::{DispatchExtension, DispatchResultWithPostInfo},
traits::ExtendedDispatchable,
};
use frame_system::Call as SystemCall;
use pallet_subtensor_proxy::Call as ProxyCall;
Expand Down Expand Up @@ -176,6 +179,18 @@ mod tests {
)
}

#[test]
fn weight_charges_all_calls_because_swap_state_can_block_any_signed_call() {
let expected = <Test as crate::Config>::WeightInfo::check_coldkey_swap_extension();

for call in forbidden_calls().into_iter().chain(authorized_calls()) {
assert_eq!(
<CheckColdkeySwap<Test> as DispatchExtension<RuntimeCall>>::weight(&call),
expected
);
}
}

#[test]
fn no_active_swap_allows_calls() {
new_test_ext(1).execute_with(|| {
Expand Down
57 changes: 50 additions & 7 deletions pallets/subtensor/src/guards/check_delegate_take.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
use super::{CallOf, DispatchableOriginOf, applicable_call};
use crate::weights::WeightInfo;
use crate::{Call, Config, Error, Pallet};
use frame_support::{
Expand All @@ -8,16 +9,20 @@ use frame_support::{
use sp_runtime::traits::Dispatchable;
use sp_std::marker::PhantomData;

type CallOf<T> = <T as frame_system::Config>::RuntimeCall;
type DispatchableOriginOf<T> = <CallOf<T> as Dispatchable>::RuntimeOrigin;

/// Dispatch extension for delegate-take bounds and ownership preconditions.
///
/// Signed increase/decrease take calls are checked before dispatch; unrelated
/// calls and non-signed origins pass through.
pub struct CheckDelegateTake<T: Config>(PhantomData<T>);

impl<T: Config> CheckDelegateTake<T> {
pub(crate) fn applies_to(call: &Call<T>) -> bool {
matches!(
call,
Call::increase_take { .. } | Call::decrease_take { .. }
)
}

pub fn check(who: &T::AccountId, call: &Call<T>) -> Result<(), Error<T>> {
match call {
Call::increase_take { hotkey, take } | Call::decrease_take { hotkey, take } => {
Expand All @@ -42,8 +47,10 @@ where
{
type Pre = ();

fn weight(_call: &CallOf<T>) -> Weight {
<T as Config>::WeightInfo::check_delegate_take_extension()
fn weight(call: &CallOf<T>) -> Weight {
applicable_call(call, Self::applies_to)
.map(|_| <T as Config>::WeightInfo::check_delegate_take_extension())
.unwrap_or(Weight::zero())
}

fn pre_dispatch(
Expand All @@ -54,7 +61,7 @@ where
return Ok(());
};

let Some(call) = call.is_sub_type() else {
let Some(call) = applicable_call(call, Self::applies_to) else {
return Ok(());
};

Expand All @@ -68,7 +75,10 @@ mod tests {
use super::*;
use crate::{Error, tests::mock::*};
use frame_support::{
assert_ok, dispatch::DispatchResultWithPostInfo, traits::ExtendedDispatchable,
assert_ok,
dispatch::{DispatchExtension, DispatchResultWithPostInfo},
traits::ExtendedDispatchable,
weights::Weight,
};
use sp_core::U256;
use sp_runtime::DispatchError;
Expand All @@ -91,6 +101,39 @@ mod tests {
result.unwrap_err().error
}

fn add_stake_call() -> RuntimeCall {
RuntimeCall::SubtensorModule(SubtensorCall::add_stake {
hotkey: U256::from(1),
netuid: 1u16.into(),
amount_staked: 1_000u64.into(),
})
}

#[test]
fn weight_only_charges_delegate_take_calls() {
let expected = <Test as crate::Config>::WeightInfo::check_delegate_take_extension();

for call in [
RuntimeCall::System(frame_system::Call::remark { remark: vec![] }),
add_stake_call(),
] {
assert_eq!(
<CheckDelegateTake<Test> as DispatchExtension<RuntimeCall>>::weight(&call),
Weight::zero()
);
}

for call in [
increase_take_call(U256::from(1), 0),
decrease_take_call(U256::from(1), 0),
] {
assert_eq!(
<CheckDelegateTake<Test> as DispatchExtension<RuntimeCall>>::weight(&call),
expected
);
}
}

#[test]
fn accepts_owner_with_valid_take() {
new_test_ext(0).execute_with(|| {
Expand Down
54 changes: 46 additions & 8 deletions pallets/subtensor/src/guards/check_evm_key_association.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
use super::{CallOf, DispatchableOriginOf, applicable_call};
use crate::weights::WeightInfo;
use crate::{Call, Config, Error, Pallet};
use frame_support::{
Expand All @@ -8,16 +9,17 @@ use frame_support::{
use sp_runtime::traits::Dispatchable;
use sp_std::marker::PhantomData;

type CallOf<T> = <T as frame_system::Config>::RuntimeCall;
type DispatchableOriginOf<T> = <CallOf<T> as Dispatchable>::RuntimeOrigin;

/// Dispatch extension for EVM-key association preconditions.
///
/// Signed EVM-key association calls are checked for subnet registration and
/// cooldown before dispatch; unrelated calls and non-signed origins pass through.
pub struct CheckEvmKeyAssociation<T: Config>(PhantomData<T>);

impl<T: Config> CheckEvmKeyAssociation<T> {
pub(crate) fn applies_to(call: &Call<T>) -> bool {
matches!(call, Call::associate_evm_key { .. })
}

pub fn check(who: &T::AccountId, call: &Call<T>) -> Result<(), Error<T>> {
match call {
Call::associate_evm_key { netuid, .. } => {
Expand All @@ -40,8 +42,10 @@ where
{
type Pre = ();

fn weight(_call: &CallOf<T>) -> Weight {
<T as Config>::WeightInfo::check_evm_key_association_extension()
fn weight(call: &CallOf<T>) -> Weight {
applicable_call(call, Self::applies_to)
.map(|_| <T as Config>::WeightInfo::check_evm_key_association_extension())
.unwrap_or(Weight::zero())
}

fn pre_dispatch(
Expand All @@ -52,7 +56,7 @@ where
return Ok(());
};

let Some(call) = call.is_sub_type() else {
let Some(call) = applicable_call(call, Self::applies_to) else {
return Ok(());
};

Expand All @@ -64,10 +68,13 @@ where
#[allow(clippy::unwrap_used, clippy::arithmetic_side_effects)]
mod tests {
use super::CheckEvmKeyAssociation;
use crate::{AssociatedEvmAddress, Error, tests::mock::*};
use crate::{AssociatedEvmAddress, Error, tests::mock::*, weights::WeightInfo as _};
use codec::Encode;
use frame_support::{
assert_ok, dispatch::DispatchResultWithPostInfo, traits::ExtendedDispatchable,
assert_ok,
dispatch::{DispatchExtension, DispatchResultWithPostInfo},
traits::ExtendedDispatchable,
weights::Weight,
};
use frame_system::Call as SystemCall;
use sp_core::{H160, Pair, U256, ecdsa, keccak_256};
Expand Down Expand Up @@ -139,6 +146,37 @@ mod tests {
)
}

fn add_stake_call() -> RuntimeCall {
RuntimeCall::SubtensorModule(SubtensorCall::add_stake {
hotkey: U256::from(1),
netuid: 1u16.into(),
amount_staked: 1_000u64.into(),
})
}

#[test]
fn weight_only_charges_evm_key_association_calls() {
let netuid = NetUid::from(1);
let expected = <Test as crate::Config>::WeightInfo::check_evm_key_association_extension();

for call in [
RuntimeCall::System(SystemCall::remark { remark: vec![] }),
add_stake_call(),
] {
assert_eq!(
<CheckEvmKeyAssociation<Test> as DispatchExtension<RuntimeCall>>::weight(&call),
Weight::zero()
);
}

assert_eq!(
<CheckEvmKeyAssociation<Test> as DispatchExtension<RuntimeCall>>::weight(
&dummy_associate_call(netuid)
),
expected
);
}

#[test]
fn unrelated_calls_pass_through() {
new_test_ext(0).execute_with(|| {
Expand Down
Loading
Loading