Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 34 additions & 0 deletions .github/workflows/ci-tests.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
# This workflow will install Python dependencies, run tests and lint with a variety of Python versions
# For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python

name: CI Testing

on:
push:
branches: [ "master" ]
pull_request:
branches: [ "master" ]

jobs:

testing:
runs-on: ubuntu-latest
strategy:
fail-fast: false
matrix:
python-version: ["3.8", "3.10"]

steps:
- uses: actions/checkout@v3
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v3
with:
python-version: ${{ matrix.python-version }}
- name: Install dependencies
run: |
pip install -r requirements.txt -f "https://download.pytorch.org/whl/cpu/torch_stable.html"
pip list
#- name: Test with pytest
# run: pytest -v .
- name: Runing main
run: python main.py --epochs=2
4 changes: 3 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
Attention-based Deep Multiple Instance Learning
================================================

by Maximilian Ilse (<ilse.maximilian@gmail.com>), Jakub M. Tomczak (<jakubmkt@gmail.com>) and Max Welling
by [Maximilian Ilse](ilse.maximilian@gmail.com), [Jakub M. Tomczak](jakubmkt@gmail.com) and Max Welling

[![CI Testing](https://github.com/Borda/AttentionDeepMIL/actions/workflows/ci-tests.yml/badge.svg?event=push)](https://github.com/Borda/AttentionDeepMIL/actions/workflows/ci-tests.yml)

Overview
--------
Expand Down
2 changes: 1 addition & 1 deletion dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ def _create_bags(self):
labels_list = []

for i in range(self.num_bag):
bag_length = np.int(self.r.normal(self.mean_bag_length, self.var_bag_length, 1))
bag_length = int(self.r.normal(self.mean_bag_length, self.var_bag_length, 1))
if bag_length < 1:
bag_length = 1

Expand Down
5 changes: 2 additions & 3 deletions mnist_bags_loader.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
"""Pytorch Dataset object that loads perfectly balanced MNIST dataset in bag form."""

import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.utils.data as data_utils
from torchvision import datasets, transforms
Expand Down Expand Up @@ -47,7 +46,7 @@ def _form_bags(self):
labels = batch_data[1]

while valid_bags_counter < self.num_bag:
bag_length = np.int(self.r.normal(self.mean_bag_length, self.var_bag_length, 1))
bag_length = int(self.r.normal(self.mean_bag_length, self.var_bag_length, 1))
if bag_length < 1:
bag_length = 1
indices = torch.LongTensor(self.r.randint(0, self.num_in_train, bag_length))
Expand Down Expand Up @@ -99,7 +98,7 @@ def _form_bags(self):
labels = batch_data[1]

while valid_bags_counter < self.num_bag:
bag_length = np.int(self.r.normal(self.mean_bag_length, self.var_bag_length, 1))
bag_length = int(self.r.normal(self.mean_bag_length, self.var_bag_length, 1))
if bag_length < 1:
bag_length = 1
indices = torch.LongTensor(self.r.randint(0, self.num_in_test, bag_length))
Expand Down
3 changes: 3 additions & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
numpy
torch >=2.0.0
torchvision