diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml
index 21e3a4d..4859d07 100644
--- a/.github/workflows/main.yml
+++ b/.github/workflows/main.yml
@@ -23,16 +23,10 @@ jobs:
- name: Install dependencies
run: |
python -m pip install --upgrade pip
- pip install ruff pytest
+ pip install pytest
pip install git+https://github.com/ultralytics/ultralytics.git@embeddings
pip install -e .
-
- - name: Lint with ruff
- run: |
- # stop the build if there are Python syntax errors or undefined names
- ruff --format=github --select=E9,F63,F7,F82 --target-version=py37 --line-length=120 .
- # default set of ruff rules with GitHub Annotations
- ruff --format=github --target-version=py37 --line-length=120 .
+
- name: Test with pytest
run: |
pytest tests
diff --git a/.gitignore b/.gitignore
index a840fa3..3ffd235 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,10 +1,361 @@
-.pytest_cache
-env/
-run/
-*.pt
-__pycache__/
-.ipynb_checkpoints/
-*_updated.yaml
-*_updated.txt
+### JetBrains ###
+# Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio, WebStorm and Rider
+# Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839
+
+# User-specific stuff
+.idea/**/workspace.xml
+.idea/**/tasks.xml
+.idea/**/usage.statistics.xml
+.idea/**/dictionaries
+.idea/**/shelf
+
+# AWS User-specific
+.idea/**/aws.xml
+
+# Generated files
+.idea/**/contentModel.xml
+
+# Sensitive or high-churn files
+.idea/**/dataSources/
+.idea/**/dataSources.ids
+.idea/**/dataSources.local.xml
+.idea/**/sqlDataSources.xml
+.idea/**/dynamic.xml
+.idea/**/uiDesigner.xml
+.idea/**/dbnavigator.xml
+
+# Gradle
+.idea/**/gradle.xml
+.idea/**/libraries
+
+# Gradle and Maven with auto-import
+# When using Gradle or Maven with auto-import, you should exclude module files,
+# since they will be recreated, and may cause churn. Uncomment if using
+# auto-import.
+# .idea/artifacts
+# .idea/compiler.xml
+# .idea/jarRepositories.xml
+# .idea/modules.xml
+# .idea/*.iml
+# .idea/modules
+# *.iml
+# *.ipr
+
+# CMake
+cmake-build-*/
+
+# Mongo Explorer plugin
+.idea/**/mongoSettings.xml
+
+# File-based project format
+*.iws
+
+# IntelliJ
+out/
+
+# mpeltonen/sbt-idea plugin
+.idea_modules/
+
+# JIRA plugin
+atlassian-ide-plugin.xml
+
+# Cursive Clojure plugin
+.idea/replstate.xml
+
+# SonarLint plugin
+.idea/sonarlint/
+
+# Crashlytics plugin (for Android Studio and IntelliJ)
+com_crashlytics_export_strings.xml
+crashlytics.properties
+crashlytics-build.properties
+fabric.properties
+
+# Editor-based Rest Client
+.idea/httpRequests
+
+# Android studio 3.1+ serialized cache file
+.idea/caches/build_file_checksums.ser
+
+### JetBrains Patch ###
+# Comment Reason: https://github.com/joeblau/gitignore.io/issues/186#issuecomment-215987721
+
+# *.iml
+# modules.xml
+# .idea/misc.xml
+# *.ipr
+
+# Sonarlint plugin
+# https://plugins.jetbrains.com/plugin/7973-sonarlint
+.idea/**/sonarlint/
+
+# SonarQube Plugin
+# https://plugins.jetbrains.com/plugin/7238-sonarqube-community-plugin
+.idea/**/sonarIssues.xml
+
+# Markdown Navigator plugin
+# https://plugins.jetbrains.com/plugin/7896-markdown-navigator-enhanced
+.idea/**/markdown-navigator.xml
+.idea/**/markdown-navigator-enh.xml
+.idea/**/markdown-navigator/
+
+# Cache file creation bug
+# See https://youtrack.jetbrains.com/issue/JBR-2257
+.idea/$CACHE_FILE$
+
+# CodeStream plugin
+# https://plugins.jetbrains.com/plugin/12206-codestream
+.idea/codestream.xml
+
+# Azure Toolkit for IntelliJ plugin
+# https://plugins.jetbrains.com/plugin/8053-azure-toolkit-for-intellij
+.idea/**/azureSettings.xml
+
+### Linux ###
+*~
+
+# temporary files which can be created if a process still has a handle open of a deleted file
+.fuse_hidden*
+
+# KDE directory preferences
+.directory
+
+# Linux trash folder which might appear on any partition or disk
+.Trash-*
+
+# .nfs files are created when an open file is removed but is still being accessed
+.nfs*
+
+### macOS ###
+# General
.DS_Store
-*.egg-info/*
\ No newline at end of file
+.config
+*.egg-info/*
+
+.AppleDouble
+.LSOverride
+
+# Icon must end with two \r
+Icon
+
+
+# Thumbnails
+._*
+
+# Files that might appear in the root of a volume
+.DocumentRevisions-V100
+.fseventsd
+.Spotlight-V100
+.TemporaryItems
+.Trashes
+.VolumeIcon.icns
+.com.apple.timemachine.donotpresent
+
+# Directories potentially created on remote AFP share
+.AppleDB
+.AppleDesktop
+Network Trash Folder
+Temporary Items
+.apdisk
+
+### macOS Patch ###
+# iCloud generated files
+*.icloud
+
+### Python ###
+# Byte-compiled / optimized / DLL files
+__pycache__/
+*.py[cod]
+*$py.class
+
+# C extensions
+*.so
+
+# Distribution / packaging
+.Python
+build/
+develop-eggs/
+dist/
+downloads/
+eggs/
+.eggs/
+lib/
+lib64/
+parts/
+sdist/
+var/
+wheels/
+share/python-wheels/
+*.egg-info/
+.installed.cfg
+*.egg
+MANIFEST
+
+# PyInstaller
+# Usually these files are written by a python script from a template
+# before PyInstaller builds the exe, so as to inject date/other infos into it.
+*.manifest
+*.spec
+
+# Installer logs
+pip-log.txt
+pip-delete-this-directory.txt
+
+# Unit test / coverage reports
+htmlcov/
+.tox/
+.nox/
+.coverage
+.coverage.*
+.cache
+nosetests.xml
+coverage.xml
+*.cover
+*.py,cover
+.hypothesis/
+.pytest_cache/
+cover/
+
+# Translations
+*.mo
+*.pot
+
+# Django stuff:
+*.log
+local_settings.py
+db.sqlite3
+db.sqlite3-journal
+
+# Flask stuff:
+instance/
+.webassets-cache
+
+# Scrapy stuff:
+.scrapy
+
+# Sphinx documentation
+docs/_build/
+
+# PyBuilder
+.pybuilder/
+target/
+
+# Jupyter Notebook
+.ipynb_checkpoints
+
+# IPython
+profile_default/
+ipython_config.py
+
+# pyenv
+# For a library or package, you might want to ignore these files since the code is
+# intended to run in multiple environments; otherwise, check them in:
+# .python-version
+
+# pipenv
+# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+# However, in case of collaboration, if having platform-specific dependencies or dependencies
+# having no cross-platform support, pipenv may install dependencies that don't work, or not
+# install all needed dependencies.
+#Pipfile.lock
+
+# poetry
+# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
+# This is especially recommended for binary packages to ensure reproducibility, and is more
+# commonly ignored for libraries.
+# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
+#poetry.lock
+
+# pdm
+# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
+#pdm.lock
+# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
+# in version control.
+# https://pdm.fming.dev/#use-with-ide
+.pdm.toml
+
+# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
+__pypackages__/
+
+# Celery stuff
+celerybeat-schedule
+celerybeat.pid
+
+# SageMath parsed files
+*.sage.py
+
+# Environments
+.env
+.venv
+env/
+venv/
+ENV/
+env.bak/
+venv.bak/
+
+# Spyder project settings
+.spyderproject
+.spyproject
+
+# Rope project settings
+.ropeproject
+
+# mkdocs documentation
+/site
+
+# mypy
+.mypy_cache/
+.dmypy.json
+dmypy.json
+
+# Pyre type checker
+.pyre/
+
+# pytype static type analyzer
+.pytype/
+
+# Cython debug symbols
+cython_debug/
+
+# PyCharm
+# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
+# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
+# and can be added to the global gitignore or merged into this file. For a more nuclear
+# option (not recommended) you can uncomment the following to ignore the entire idea folder.
+#.idea/
+
+### Python Patch ###
+# Poetry local configuration file - https://python-poetry.org/docs/configuration/#local-configuration
+poetry.toml
+
+# ruff
+.ruff_cache/
+
+# LSP config files
+pyrightconfig.json
+
+### Windows ###
+# Windows thumbnail cache files
+Thumbs.db
+Thumbs.db:encryptable
+ehthumbs.db
+ehthumbs_vista.db
+
+# Dump file
+*.stackdump
+
+# Folder config file
+[Dd]esktop.ini
+
+# Recycle Bin used on file shares
+$RECYCLE.BIN/
+
+# Windows Installer files
+*.cab
+*.msi
+*.msix
+*.msm
+*.msp
+
+# Windows shortcuts
+*.lnk
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
new file mode 100644
index 0000000..7b3f55f
--- /dev/null
+++ b/.pre-commit-config.yaml
@@ -0,0 +1,36 @@
+# See https://pre-commit.com for more information
+# See https://pre-commit.com/hooks.html for more hooks
+
+default_language_version:
+ python: python3.8
+repos:
+- repo: https://github.com/pre-commit/pre-commit-hooks
+ rev: v4.4.0
+ hooks:
+ - id: check-added-large-files
+ - id: check-toml
+ - id: check-yaml
+ args:
+ - --unsafe
+ - id: end-of-file-fixer
+ - id: trailing-whitespace
+- repo: https://github.com/asottile/pyupgrade
+ rev: v3.7.0
+ hooks:
+ - id: pyupgrade
+ args:
+ - --py3-plus
+ - --keep-runtime-typing
+- repo: https://github.com/charliermarsh/ruff-pre-commit
+ rev: v0.0.275
+ hooks:
+ - id: ruff
+ args:
+ - --fix
+- repo: https://github.com/psf/black
+ rev: 23.3.0
+ hooks:
+ - id: black
+ci:
+ autofix_commit_msg: "fix(pre_commit): 🎨 auto format pre-commit hooks"
+ autoupdate_commit_msg: "fix(pre_commit): ⬆ pre_commit autoupdate"
diff --git a/MANIFEST.in b/MANIFEST.in
index 540b720..f9bd145 100644
--- a/MANIFEST.in
+++ b/MANIFEST.in
@@ -1 +1 @@
-include requirements.txt
\ No newline at end of file
+include requirements.txt
diff --git a/README.md b/README.md
index fa4668e..f0b469e 100644
--- a/README.md
+++ b/README.md
@@ -1,6 +1,6 @@
# YOLOExplorer
-Explore, manipulate and iterate on Computer Vision datasets with precision using simple APIs.
+Explore, manipulate and iterate on Computer Vision datasets with precision using simple APIs.
Supports SQL filters, vector similarity search, native interface with Pandas and more.
@@ -111,12 +111,12 @@ coco_exp.remove_imgs([100,120,300..n]) # Removes images at the given ids.
Adding data
For adding data from another dataset, you need an explorer object of that dataset with embeddings built. You can then pass that object along with the ids of the imgs that you'd like to add from that dataset.
```
-coco_exp.add_imgs(exp, idxs) #
+coco_exp.add_imgs(exp, idxs) #
```
Note: You can use SQL querying and/or similarity searches to get the desired ids from the datasets.
Persisting the Table: Create new dataset and start training
-After making the desired changes, you can persist the table to create the new dataset.
+After making the desired changes, you can persist the table to create the new dataset.
```
coco_exp.persist()
```
@@ -175,13 +175,3 @@ Pre-filtering will enable powerful queries like - "Show me images similar to None:
self.data = data
self.table = None
self.model = model
+ self.device = device
self.project = project
self.dataset_info = None
self.predictor = None
@@ -100,33 +103,25 @@ def build_embeddings(self, batch_size=1000, verbose=False, force=False):
self.trainset = trainset
self.verbose = verbose
- dataset = Dataset(
- img_path=trainset, data=self.dataset_info, augment=False, cache=False
- )
+ dataset = Dataset(img_path=trainset, data=self.dataset_info, augment=False, cache=False)
batch_size = dataset.ni # TODO: fix this hardcoding
db = self._connect()
if not force and self.table_name in db.table_names():
- LOGGER.info(
- "LanceDB embedding space already exists. Attempting to reuse it. Use force=True to overwrite."
- )
+ LOGGER.info("LanceDB embedding space already exists. Attempting to reuse it. Use force=True to overwrite.")
self.table = self._open_table(self.table_name)
self.version = self.table.version
if len(self.table) == dataset.ni:
return
else:
self.table = None
- LOGGER.info(
- "Table length does not match the number of images in the dataset. Building embeddings..."
- )
+ LOGGER.info("Table length does not match the number of images in the dataset. Building embeddings...")
table_data = defaultdict(list)
for idx, batch in enumerate(dataset):
batch.pop("img")
batch["id"] = idx
batch["cls"] = batch["cls"].flatten().int().tolist()
- box_cls_pair = sorted(
- zip(batch["bboxes"].tolist(), batch["cls"]), key=lambda x: x[1]
- )
+ box_cls_pair = sorted(zip(batch["bboxes"].tolist(), batch["cls"]), key=lambda x: x[1])
batch["bboxes"] = [box for box, _ in box_cls_pair]
batch["cls"] = [cls for _, cls in box_cls_pair]
batch["labels"] = [self.dataset_info["names"][i] for i in batch["cls"]]
@@ -143,15 +138,11 @@ def build_embeddings(self, batch_size=1000, verbose=False, force=False):
if len(table_data[key]) == batch_size or idx == dataset.ni - 1:
df = pd.DataFrame(table_data)
- df = with_embeddings(
- self._embedding_func, df, "img", batch_size=batch_size
- )
+ df = with_embeddings(self._embedding_func, df, "img", batch_size=batch_size)
if self.table:
self.table.add(table_data)
else:
- self.table = self._create_table(
- self.table_name, data=df, mode="overwrite"
- )
+ self.table = self._create_table(self.table_name, data=df, mode="overwrite")
self.version = self.table.version
table_data = defaultdict(list)
@@ -165,9 +156,7 @@ def plot_embeddings(self):
n_components (int, optional): number of components. Defaults to 2.
"""
if self.table is None:
- LOGGER.error(
- "No embedding space found. Please build the embedding space first."
- )
+ LOGGER.error("No embedding space found. Please build the embedding space first.")
return None
pca = PCA(n_components=2)
embeddings = np.array(self.table.to_arrow()["vector"].to_pylist())
@@ -188,9 +177,7 @@ def get_similar_imgs(self, img, n=10):
"""
embeddings = None
if self.table is None:
- LOGGER.error(
- "No embedding space found. Please build the embedding space first."
- )
+ LOGGER.error("No embedding space found. Please build the embedding space first.")
return None
if isinstance(img, int):
embeddings = self.table.to_pandas()["vector"][img]
@@ -198,12 +185,17 @@ def get_similar_imgs(self, img, n=10):
img = img
elif isinstance(img, bytes):
img = decode(img)
+ elif isinstance(img, list): # exceptional case for batch search from dash
+ df = self.table.to_pandas().set_index("path")
+ array = df.loc[img]["vector"].to_list()
+ embeddings = np.array(array)
+ if len(embeddings) > 1:
+ embeddings = np.mean(embeddings, axis=0)
+ else:
+ embeddings = np.squeeze(embeddings)
else:
- LOGGER.error(
- "img should be index from the table(int) or path of an image (str or Path)"
- )
+ LOGGER.error("img should be index from the table(int) or path of an image (str or Path)")
return
-
if embeddings is None:
embeddings = self.predictor.embed(img).squeeze().cpu().numpy()
sim = self.table.search(embeddings).limit(n).to_df()
@@ -226,11 +218,7 @@ def plot_imgs(self, ids=None, query=None, labels=True):
# Resize the images to the minimum and maximum width and height
resized_images = []
- df = (
- self.sql(query)
- if query
- else self.table.to_pandas().iloc[ids]
- )
+ df = self.sql(query) if query else self.table.to_pandas().iloc[ids]
for _, row in df.iterrows():
img = decode(row["img"])
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
@@ -247,7 +235,7 @@ def plot_imgs(self, ids=None, query=None, labels=True):
return
# Create a grid of the images
- cols = 10 if len(resized_images) > 10 else max(2,len(resized_images))
+ cols = 10 if len(resized_images) > 10 else max(2, len(resized_images))
rows = max(1, math.ceil(len(resized_images) / cols))
fig, axes = plt.subplots(nrows=rows, ncols=cols)
fig.subplots_adjust(hspace=0, wspace=0)
@@ -257,11 +245,8 @@ def plot_imgs(self, ids=None, query=None, labels=True):
ax.axis("off")
# Display the grid of images
plt.show()
-
- def get_similarity_index(
- self, top_k=0.01, sim_thres=0.90, reduce=False, sorted=False
- ):
+ def get_similarity_index(self, top_k=0.01, sim_thres=0.90, reduce=False, sorted=False):
"""
Args:
@@ -273,9 +258,7 @@ def get_similarity_index(
np.array: Similarity index
"""
if self.table is None:
- LOGGER.error(
- "No embedding space found. Please build the embedding space first."
- )
+ LOGGER.error("No embedding space found. Please build the embedding space first.")
return None
if top_k > 1.0:
LOGGER.warning("top_k should be between 0 and 1. Setting top_k to 1.0")
@@ -285,14 +268,10 @@ def get_similarity_index(
top_k = 0.0
if sim_thres is not None:
if sim_thres > 1.0:
- LOGGER.warning(
- "sim_thres should be between 0 and 1. Setting sim_thres to 1.0"
- )
+ LOGGER.warning("sim_thres should be between 0 and 1. Setting sim_thres to 1.0")
sim_thres = 1.0
if sim_thres < 0.0:
- LOGGER.warning(
- "sim_thres should be between 0 and 1. Setting sim_thres to 0.0"
- )
+ LOGGER.warning("sim_thres should be between 0 and 1. Setting sim_thres to 0.0")
sim_thres = 0.0
embs = np.array(self.table.to_arrow()["vector"].to_pylist())
self._sim_index = np.zeros(len(embs))
@@ -307,12 +286,8 @@ def get_similarity_index(
dim = embs.shape[1]
values = pa.array(embs.reshape(-1), type=pa.float32())
table_data = pa.FixedSizeListArray.from_arrays(values, dim)
- table = pa.table(
- [table_data, self.table.to_arrow()["id"]], names=["vector", "id"]
- )
- self._search_table = self._create_table(
- "reduced_embs", data=table, mode="overwrite"
- )
+ table = pa.table([table_data, self.table.to_arrow()["id"]], names=["vector", "id"])
+ self._search_table = self._create_table("reduced_embs", data=table, mode="overwrite")
# with multiprocessing.Pool() as pool: # multiprocessing doesn't do much. Need to revisit
# list(tqdm(pool.imap(build_index, iterable)))
@@ -327,9 +302,7 @@ def get_similarity_index(
return self._sim_index if not sorted else np.sort(self._sim_index)
- def plot_similarity_index(
- self, sim_thres=0.90, top_k=0.01, reduce=False, sorted=False
- ):
+ def plot_similarity_index(self, sim_thres=0.90, top_k=0.01, reduce=False, sorted=False):
"""
Plots the similarity index
@@ -366,10 +339,8 @@ def remove_imgs(self, idxs):
table = pa_table.filter(mask)
ids = [i for i in range(len(table))]
- table = table.set_column(0, 'id', [ids]) # TODO: Revisit this. This is a hack to fix the ids==dix
- self.table = self._create_table(
- self.temp_table_name, data=table, mode="overwrite"
- ) # work on a temporary table
+ table = table.set_column(0, "id", [ids]) # TODO: Revisit this. This is a hack to fix the ids==dix
+ self.table = self._create_table(self.temp_table_name, data=table, mode="overwrite") # work on a temporary table
self.log_status()
@@ -382,9 +353,7 @@ def add_imgs(self, exp, idxs):
"""
table_df = self.table.to_pandas()
data = exp.table.to_pandas().iloc[idxs]
- assert len(table_df["vector"].iloc[0]) == len(
- data["vector"].iloc[0]
- ), "Vector dimension mismatch"
+ assert len(table_df["vector"].iloc[0]) == len(data["vector"].iloc[0]), "Vector dimension mismatch"
table_df = pd.concat([table_df, data], ignore_index=True)
ids = [i for i in range(len(table_df))]
table_df["id"] = ids
@@ -441,25 +410,17 @@ def persist(self, name=None):
new_dataset_info = self.dataset_info.copy()
new_dataset_info.pop("yaml_file")
- new_dataset_info.pop(
- "path"
- ) # relative paths will get messed up when merging datasets
- new_dataset_info.pop(
- "download"
- ) # Assume all files are present offline, there is no way to store metadata yet
+ new_dataset_info.pop("path") # relative paths will get messed up when merging datasets
+ new_dataset_info.pop("download") # Assume all files are present offline, there is no way to store metadata yet
new_dataset_info["train"] = (path / train_txt).resolve().as_posix()
for key, value in new_dataset_info.items():
if isinstance(value, Path):
new_dataset_info[key] = value.as_posix()
- yaml.dump(
- new_dataset_info, open(path / datafile_name, "w")
- ) # update dataset.yaml file
+ yaml.dump(new_dataset_info, open(path / datafile_name, "w")) # update dataset.yaml file
# TODO: not sure if this should be called data_final to prevent overwriting the original data?
- self.table = self._create_table(
- self.table_name, data=self.table.to_arrow(), mode="overwrite"
- )
+ self.table = self._create_table(self.table_name, data=self.table.to_arrow(), mode="overwrite")
db.drop_table(self.temp_table_name)
LOGGER.info("Changes persisted to the dataset.")
@@ -487,6 +448,25 @@ def sql(self, query: str):
return result
+ def dash(self):
+ """
+ Launches a dashboard to visualize the dataset.
+ """
+ Path(TEMP_CONFIG_PATH).parent.mkdir(exist_ok=True, parents=True)
+ with open(TEMP_CONFIG_PATH, "w+") as file:
+ json.dump(self.config, file)
+
+ launch()
+
+ @property
+ def config(self):
+ return {
+ "project": self.project,
+ "model": self.model,
+ "device": self.device,
+ "data": self.data
+ }
+
def _log_training_cmd(self, data_path):
LOGGER.info(
f'{colorstr("LanceDB: ") }New dataset created successfully! Run the following command to train a model:'
@@ -528,17 +508,13 @@ def _copy_table_to_project(self, table_path):
name = Path(table_path).stem # lancedb doesn't need .lance extension
db = lancedb.connect(path)
table = db.open_table(name)
- return self._create_table(
- self.table_name, data=table.to_arrow(), mode="overwrite"
- )
+ return self._create_table(self.table_name, data=table.to_arrow(), mode="overwrite")
def _embedding_func(self, imgs):
embeddings = []
for img in tqdm(imgs):
img = decode(img)
- embeddings.append(
- self.predictor.embed(img, verbose=self.verbose).squeeze().cpu().numpy()
- )
+ embeddings.append(self.predictor.embed(img, verbose=self.verbose).squeeze().cpu().numpy())
return embeddings
def _setup_predictor(self, model, device=""):
diff --git a/yoloexplorer/frontend/__init__.py b/yoloexplorer/frontend/__init__.py
new file mode 100644
index 0000000..81634c2
--- /dev/null
+++ b/yoloexplorer/frontend/__init__.py
@@ -0,0 +1,3 @@
+from .layout import launch
+
+__all__ = ["launch"]
\ No newline at end of file
diff --git a/yoloexplorer/frontend/layout.py b/yoloexplorer/frontend/layout.py
new file mode 100644
index 0000000..aa40d26
--- /dev/null
+++ b/yoloexplorer/frontend/layout.py
@@ -0,0 +1,93 @@
+import json
+import subprocess
+
+import streamlit as st
+from streamlit_dash import image_select
+from yoloexplorer import config
+from yoloexplorer.frontend.states import init_states, update_state, INDEX_PAGE_QUERY_FORM_KEY, INDEX_PAGE_SIMILARITY_FORM_KEY
+
+@st.cache_data
+def _get_dataset():
+ from yoloexplorer import Explorer # function scope import
+
+ with open(config.TEMP_CONFIG_PATH) as json_file:
+ data = json.load(json_file)
+ exp = Explorer(**data)
+ exp.build_embeddings()
+
+ return exp
+
+def reset_to_init_state():
+ if st.session_state.get("EXPLORER") is None:
+ init_states()
+ exp = _get_dataset()
+ update_state("EXPLORER", exp)
+ update_state("IMGS", exp.table.to_pandas()["path"].to_list())
+
+def query_form():
+ with st.form(INDEX_PAGE_QUERY_FORM_KEY):
+ col1, col2 = st.columns([0.8, 0.2])
+ with col1:
+ query = st.text_input("Query", "", label_visibility="collapsed")
+ with col2:
+ submitted = st.form_submit_button("Query")
+ if submitted:
+ if query:
+ exp = st.session_state.EXPLORER
+ df = exp.sql(query)
+ update_state("IMGS", df["path"].to_list())
+
+def similarity_form(selected_imgs):
+ st.write("Similarity Search")
+ with st.form(INDEX_PAGE_SIMILARITY_FORM_KEY):
+ subcol1, subcol2 = st.columns([1,1])
+ with subcol1:
+ st.write("Limit")
+ limit = st.number_input("limit", min_value=None, max_value=None, value=25, label_visibility="collapsed")
+
+ with subcol2:
+ st.write("Selected: ", len(selected_imgs))
+ submitted = st.form_submit_button("Search")
+
+ if submitted:
+ find_similar_imgs(selected_imgs, limit=limit)
+
+def find_similar_imgs(imgs, limit=25):
+ exp = st.session_state.EXPLORER
+ df = exp.table.to_pandas()
+ _, idx = exp.get_similar_imgs(imgs, limit)
+ paths = df["path"][idx].to_list()
+ update_state("IMGS", paths)
+ st.experimental_rerun()
+ print("updated IMGS")
+
+
+
+def layout():
+ st.set_page_config(layout='wide')
+ col1, col2 = st.columns([0.75, 0.25], gap="small")
+
+ reset_to_init_state()
+ with col1:
+ subcol1, subcol2 = st.columns([0.2, 0.8])
+ with subcol1:
+ num = st.number_input("Max Images Displayed", min_value=0, max_value=len(st.session_state.IMGS), value=min(250, len(st.session_state.IMGS)))
+ query_form()
+
+ if st.session_state.IMGS:
+ selected_imgs = image_select(f"Total samples: {len(st.session_state.IMGS)}", images=st.session_state.IMGS[0:num], indices=st.session_state.SELECTED_IMGS, use_container_width=False) #noqa
+
+ with col2:
+ similarity_form(selected_imgs)
+ display_labels = st.checkbox("Labels", value=False) #noqa
+
+
+def launch():
+ cmd = ["streamlit", "run", __file__, "--server.maxMessageSize", "1024"]
+ try:
+ subprocess.run(cmd, check=True)
+ except Exception as e:
+ print(e)
+
+if __name__ == "__main__":
+ layout()
\ No newline at end of file
diff --git a/yoloexplorer/frontend/pages/embedding.py b/yoloexplorer/frontend/pages/embedding.py
new file mode 100644
index 0000000..e69de29
diff --git a/yoloexplorer/frontend/states.py b/yoloexplorer/frontend/states.py
new file mode 100644
index 0000000..c745e9a
--- /dev/null
+++ b/yoloexplorer/frontend/states.py
@@ -0,0 +1,12 @@
+import streamlit as st
+
+INDEX_PAGE_QUERY_FORM_KEY = "index_page_query_form"
+INDEX_PAGE_SIMILARITY_FORM_KEY = "index_page_similarity_form"
+
+def init_states():
+ st.session_state.EXPLORER = None
+ st.session_state.IMGS = []
+ st.session_state.SELECTED_IMGS = []
+
+def update_state(state, value):
+ st.session_state[state] = value
\ No newline at end of file