This project implements a value-based reinforcement learning approach using Deep Q-Networks (DQN) and Double DQN (DDQN) for detecting lesions in segmented brain images. The agent navigates a grid to accurately identify and localize lesions in MRI scans.
The project is organized into three main folders:
- DDQN Model:
  - Implements the Double DQN approach for more stable training and better lesion detection.
  - Key file: main.py
  - Run command: python main.py
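The core idea of Double DQN is to decouple action selection from action evaluation: the online network picks the greedy next action, and a separate target network scores it. A minimal sketch of that target computation, assuming PyTorch tensors (the function name `ddqn_target` is illustrative, not the project's API):

```python
import torch

def ddqn_target(online_net, target_net, rewards, next_states, dones, gamma=0.99):
    """Double DQN target: the online network selects the greedy next action,
    while the target network evaluates its value."""
    with torch.no_grad():
        # Action selection with the online network.
        next_actions = online_net(next_states).argmax(dim=1, keepdim=True)
        # Action evaluation with the target network.
        next_q = target_net(next_states).gather(1, next_actions).squeeze(1)
        # Standard bootstrapped target; terminal states contribute no future value.
        return rewards + gamma * next_q * (1.0 - dones)
```

This split reduces the overestimation bias that arises when a single network both selects and evaluates the maximizing action, which is what makes DDQN training more stable than plain DQN.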
- DQN Model:
  - Implements the standard DQN algorithm, where the agent uses a single model for both policy and evaluation.
  - Key file: main.py
  - Run command: python main.py
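In a standard DQN setup like this, the agent typically chooses grid moves with an epsilon-greedy rule: explore randomly with probability epsilon, otherwise take the action with the highest predicted Q-value. A minimal sketch, assuming a PyTorch model and a flat state tensor (`select_action` is a hypothetical name, not necessarily the project's):

```python
import random
import torch

def select_action(model, state, epsilon, n_actions):
    # Explore: pick a random grid move with probability epsilon.
    if random.random() < epsilon:
        return random.randrange(n_actions)
    # Exploit: pick the action with the highest predicted Q-value.
    with torch.no_grad():
        q_values = model(state.unsqueeze(0))  # add a batch dimension
        return int(q_values.argmax(dim=1).item())
```

Epsilon is usually annealed from a high value toward a small floor over training, shifting the agent from exploration to exploitation.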
- Prototype:
  - Contains the initial prototype of the lesion detection system, useful for testing and quick iterations.
Before running the project, ensure the following dependencies are installed:
pip install pygame torch matplotlib numpy opencv-python

Both models are set up to train an agent that navigates through segmented MRI images. To run either model, use the following commands:
cd DDQN_Model
python main.py

cd DQN_Model
python main.py

Each model is executed through main.py, which trains the agent and logs rewards and performance. Example usage:

python main.py

This command will:
- Train the agent on the MRI images.
- Save images labeled with detected lesions.
- Output episode durations and rewards.
- main.py: Main training script that initializes the environment and agent, runs the training loop, and logs results.
- environment.py: Defines the environment for the agent, where MRI images are loaded and the grid navigation is implemented.
- agent.py: Contains the logic for the DDQN agent, including action selection and training of the neural network.
- model.py: Defines the architecture of the Deep Q-Network (DQN) used by the agent.
- replay_memory.py: Implements experience replay, which stores past experiences and samples batches for training.
- test.py: Script to evaluate the trained model on new images and visualize the agent's performance.
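As a rough illustration of the experience replay that replay_memory.py is responsible for, a minimal buffer looks like this (a sketch under generic assumptions, not the project's actual implementation):

```python
import random
from collections import deque

class ReplayMemory:
    """Fixed-capacity buffer of (state, action, reward, next_state, done) tuples."""

    def __init__(self, capacity):
        # deque with maxlen silently evicts the oldest transition when full.
        self.buffer = deque(maxlen=capacity)

    def push(self, state, action, reward, next_state, done):
        self.buffer.append((state, action, reward, next_state, done))

    def sample(self, batch_size):
        # Uniform random sampling breaks the temporal correlation
        # between consecutive transitions.
        return random.sample(self.buffer, batch_size)

    def __len__(self):
        return len(self.buffer)
```

Sampling uniformly from a large buffer decorrelates training batches, which is one of the key stabilizers of both DQN and DDQN training.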
During training, the agent navigates a grid over an MRI image. The model outputs include:
- Rewards: Total rewards per episode.
- Losses: Average loss for each training episode.
- Labeled Images: Images saved with detected lesion regions highlighted.
The training script automatically generates plots:
- Real-Time Rewards: Shows how rewards evolve over training episodes.
- Cumulative Time: Tracks the cumulative training time for each milestone episode.
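A rewards-over-episodes plot like the one above can be produced with matplotlib along these lines (a sketch; the function name and output path are assumptions, not the script's actual interface):

```python
import matplotlib
matplotlib.use("Agg")  # render off-screen so no display is required
import matplotlib.pyplot as plt

def plot_rewards(episode_rewards, path="rewards.png"):
    # Plot total reward per episode and save the figure to disk.
    plt.figure()
    plt.plot(episode_rewards)
    plt.xlabel("Episode")
    plt.ylabel("Total reward")
    plt.title("Rewards over training episodes")
    plt.savefig(path)
    plt.close()
```

Saving to a file rather than calling plt.show() keeps the training loop non-blocking and works on headless machines.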
This project is licensed under the Apache License, Version 2.0. See the LICENSE file for details.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.