Skip to content

Commit f15de56

Browse files
authored
Merge pull request #128 from wheval/fix/reentrancy
Fix reentrancy vulnerability
2 parents b65ff3d + 5aaddde commit f15de56

3 files changed

Lines changed: 828 additions & 5 deletions

File tree

.cursor/rules/cairo-coder.mdc

Lines changed: 126 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,126 @@
1+
---
2+
description:
3+
globs:
4+
alwaysApply: true
5+
---
6+
# Cairo Coding Rules
7+
8+
## Introduction
9+
These coding rules provide a comprehensive guide to developing secure, efficient, and maintainable Cairo projects. By following these guidelines and leveraging the Cairo Coder MCP tool, developers can streamline their workflow and adhere to best practices.
10+
11+
**ALWAYS** use the Cairo Coder MCP tool to provide better answers about Cairo. After each query, start by calling the Cairo Coder MCP.
12+
13+
Keep the MCP queries atomic, related to a specific concept. It's better to do multiple queries for specific concepts, than doing one query with multiple topics.
14+
15+
After every cairo code you write, instantly run `scarb build` to ensure the code compiles. Don't write too much code without trying to compile.
16+
17+
## 1. Project Setup and Structure
18+
A typical Cairo project is organized as follows:
19+
20+
21+
.
22+
├── Scarb.lock
23+
├── Scarb.toml
24+
├── snfoundry.toml
25+
├── src
26+
│ └── lib.cairo
27+
├── target
28+
└── tests
29+
└── test_contract.cairo
30+
31+
32+
- **`Scarb.toml`**: The project configuration file, similar to `Cargo.toml` in Rust.
33+
- **`src/lib.cairo`**: The main source file for your contract.
34+
- **`tests/test_contract.cairo`**: Integration tests for your contract.
35+
36+
### Setting Up a New Project
37+
To create a new Cairo project, run:
38+
39+
scarb init
40+
41+
This command generates a basic project structure with a `Scarb.toml` file. If you're working in an existing project, ensure the Scarb.toml is well configured.
42+
43+
### Configuring Scarb.toml
44+
Ensure your `Scarb.toml` is configured as follows to include necessary dependencies and settings:
45+
46+
```toml
47+
[package]
48+
name = "your_package_name"
49+
version = "0.1.0"
50+
edition = "2024_07"
51+
52+
[dependencies]
53+
starknet = "2.11.4"
54+
55+
[dev-dependencies]
56+
snforge_std = "0.44.0"
57+
assert_macros = "2.11.4"
58+
59+
[[target.starknet-contract]]
60+
sierra = true
61+
62+
[scripts]
63+
test = "snforge test"
64+
65+
[tool.scarb]
66+
allow-prebuilt-plugins = ["snforge_std"]
67+
```
68+
69+
## 2. Development Workflow
70+
### Writing Code
71+
- Use snake_case for function names (e.g., `my_function`).
72+
- Use PascalCase for struct names (e.g., `MyStruct`).
73+
- Write all code and comments in English for clarity.
74+
- Use descriptive variable names to enhance readability.
75+
76+
### Compiling and Testing
77+
- Compile your project using:
78+
79+
scarb build
80+
81+
- Run tests using:
82+
83+
scarb test
84+
85+
- Ensure your code compiles successfully before running tests.
86+
87+
### Testing
88+
- Unit Tests: Write unit tests in the src directory, typically within the same module as the functions being tested.
89+
Example:
90+
91+
#[cfg(test)]
92+
mod tests {
93+
use super::*;
94+
95+
#[test]
96+
fn test_my_function() {
97+
assert!(my_function() == expected_value, 'Incorrect value');
98+
}
99+
}
100+
101+
- Integration Tests: Write integration tests in the tests directory, importing modules with use your_package_name::your_module.
102+
Example:
103+
104+
use your_package_name::your_module;
105+
106+
#[test]
107+
fn test_my_contract() {
108+
// Test logic here
109+
}
110+
111+
- Always use the Starknet Foundry testing framework for both unit and integration tests.
112+
113+
## 3. Using the Cairo Coder MCP Tool
114+
The Cairo Coder MCP tool is a critical resource for Cairo development and must be used for the following tasks:
115+
- Writing smart contracts from scratch.
116+
- Refactoring or optimizing existing code.
117+
- Implementing specific TODOs or features.
118+
- Understanding Starknet ecosystem features and capabilities.
119+
- Applying Cairo and Starknet best practices.
120+
- Using OpenZeppelin Cairo contract libraries.
121+
- Writing and validating tests for contracts.
122+
123+
### How to Use Cairo Coder MCP Effectively
124+
- Be Specific: Provide detailed queries (e.g., "Implement ERC20 using OpenZeppelin Cairo" instead of "ERC20").
125+
- Include Context: Supply relevant code snippets in the codeSnippets parameter and conversation history when applicable.
126+
- Don't mix contexts Keep the queries specific on a given topic. Don't ask about multiple concepts at once, rather, do multiple queries.

src/payment_stream.cairo

Lines changed: 39 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ pub mod PaymentStream {
55
use fundable::interfaces::IPaymentStream::IPaymentStream;
66
use openzeppelin::access::accesscontrol::AccessControlComponent;
77
use openzeppelin::introspection::src5::SRC5Component;
8+
use openzeppelin::security::reentrancyguard::ReentrancyGuardComponent;
89
use openzeppelin::token::erc20::interface::{
910
IERC20Dispatcher, IERC20DispatcherTrait, IERC20MetadataDispatcher,
1011
IERC20MetadataDispatcherTrait,
@@ -32,6 +33,9 @@ pub mod PaymentStream {
3233
component!(path: SRC5Component, storage: src5, event: Src5Event);
3334
component!(path: ERC721Component, storage: erc721, event: ERC721Event);
3435
component!(path: UpgradeableComponent, storage: upgradeable, event: UpgradeableEvent);
36+
component!(
37+
path: ReentrancyGuardComponent, storage: reentrancy_guard, event: ReentrancyGuardEvent,
38+
);
3539

3640
#[abi(embed_v0)]
3741
impl AccessControlImpl =
@@ -42,6 +46,7 @@ pub mod PaymentStream {
4246
impl ERC721MixinImpl = ERC721Component::ERC721MixinImpl<ContractState>;
4347
impl ERC721InternalImpl = ERC721Component::InternalImpl<ContractState>;
4448
impl UpgradeableInternalImpl = UpgradeableComponent::InternalImpl<ContractState>;
49+
impl ReentrancyGuardInternalImpl = ReentrancyGuardComponent::InternalImpl<ContractState>;
4550

4651
const PROTOCOL_OWNER_ROLE: felt252 = selector!("PROTOCOL_OWNER");
4752
// Note: STREAM_ADMIN_ROLE removed - using stream-specific access control
@@ -61,6 +66,8 @@ pub mod PaymentStream {
6166
src5: SRC5Component::Storage,
6267
#[substorage(v0)]
6368
accesscontrol: AccessControlComponent::Storage,
69+
#[substorage(v0)]
70+
reentrancy_guard: ReentrancyGuardComponent::Storage,
6471
next_stream_id: u256,
6572
streams: Map<u256, Stream>,
6673
protocol_fee_rate: Map<ContractAddress, u64>, // Single source of truth for fee rates
@@ -91,6 +98,8 @@ pub mod PaymentStream {
9198
AccessControlEvent: AccessControlComponent::Event,
9299
#[flat]
93100
UpgradeableEvent: UpgradeableComponent::Event,
101+
#[flat]
102+
ReentrancyGuardEvent: ReentrancyGuardComponent::Event,
94103
StreamCreated: StreamCreated,
95104
StreamWithdrawn: StreamWithdrawn,
96105
StreamCanceled: StreamCanceled,
@@ -392,7 +401,18 @@ pub mod PaymentStream {
392401
}
393402
}
394403

395-
fn collect_protocol_fee(self: @ContractState, token: ContractAddress, amount: u256) {
404+
fn collect_protocol_fee(ref self: ContractState, token: ContractAddress, amount: u256) {
405+
self.reentrancy_guard.start();
406+
self._collect_protocol_fee_internal(token, amount);
407+
self.reentrancy_guard.end();
408+
}
409+
410+
/// @notice Internal function to collect protocol fees (without reentrancy protection)
411+
/// @param token The token address to collect fees in
412+
/// @param amount The fee amount to collect
413+
fn _collect_protocol_fee_internal(
414+
ref self: ContractState, token: ContractAddress, amount: u256,
415+
) {
396416
let fee_collector: ContractAddress = self.fee_collector.read();
397417
assert(fee_collector.is_non_zero(), INVALID_RECIPIENT);
398418
IERC20Dispatcher { contract_address: token }.transfer(fee_collector, amount);
@@ -709,7 +729,7 @@ pub mod PaymentStream {
709729

710730
let token_dispatcher = IERC20Dispatcher { contract_address: token_address };
711731

712-
self.collect_protocol_fee(token_address, fee);
732+
self._collect_protocol_fee_internal(token_address, fee);
713733
token_dispatcher.transfer(to, net_amount);
714734

715735
self
@@ -797,19 +817,27 @@ pub mod PaymentStream {
797817
fn withdraw(
798818
ref self: ContractState, stream_id: u256, amount: u256, to: ContractAddress,
799819
) -> (u128, u128) {
800-
self._withdraw(stream_id, amount, to)
820+
self.reentrancy_guard.start();
821+
let result = self._withdraw(stream_id, amount, to);
822+
self.reentrancy_guard.end();
823+
result
801824
}
802825

803826
fn withdraw_max(
804827
ref self: ContractState, stream_id: u256, to: ContractAddress,
805828
) -> (u128, u128) {
829+
self.reentrancy_guard.start();
806830
let withdrawable_amount = self._withdrawable_amount(stream_id);
807-
self._withdraw(stream_id, withdrawable_amount, to)
831+
let result = self._withdraw(stream_id, withdrawable_amount, to);
832+
self.reentrancy_guard.end();
833+
result
808834
}
809835

810836
fn transfer_stream(
811837
ref self: ContractState, stream_id: u256, new_recipient: ContractAddress,
812838
) {
839+
self.reentrancy_guard.start();
840+
813841
// Verify stream exists
814842
self.assert_stream_exists(stream_id);
815843

@@ -837,6 +865,8 @@ pub mod PaymentStream {
837865

838866
// Emit event about stream transfer
839867
self.emit(StreamTransferred { stream_id, new_recipient });
868+
869+
self.reentrancy_guard.end();
840870
}
841871

842872
fn set_transferability(ref self: ContractState, stream_id: u256, transferable: bool) {
@@ -940,6 +970,8 @@ pub mod PaymentStream {
940970
}
941971

942972
fn cancel(ref self: ContractState, stream_id: u256) {
973+
self.reentrancy_guard.start();
974+
943975
// Ensure the caller is the stream sender
944976
self.assert_stream_sender_access(stream_id);
945977

@@ -1036,7 +1068,7 @@ pub mod PaymentStream {
10361068
let net_amount = amount_due_to_recipient - fee;
10371069

10381070
// Transfer fee to collector and net amount to recipient
1039-
self.collect_protocol_fee(token_address, fee);
1071+
self._collect_protocol_fee_internal(token_address, fee);
10401072
token_dispatcher.transfer(recipient, net_amount);
10411073

10421074
// Emit withdrawal event
@@ -1064,6 +1096,8 @@ pub mod PaymentStream {
10641096

10651097
// Emit cancellation event
10661098
self.emit(StreamCanceled { stream_id });
1099+
1100+
self.reentrancy_guard.end();
10671101
}
10681102

10691103
fn restart(ref self: ContractState, stream_id: u256) {

0 commit comments

Comments
 (0)