Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 46 additions & 0 deletions chartly/chartly.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,52 @@ def add_overlay(self, plot, data, axes_labels=None, customs=None):
}
)

# pylint: disable=too-many-arguments,too-many-positional-arguments
def add_subplots(
self,
plots,
data=None,
data_list=None,
axes_labels_list=None,
customs_list=None,
):
"""Add multiple subplots in a single call.

:param list plots: List of plot names.
:param data: Shared data for all plots.
:param list data_list: Optional list of datasets, one per plot.
:param list axes_labels_list: Optional list of axes_labels dicts.
:param list customs_list: Optional list of customs dicts.
"""
if data is None and data_list is None:
raise ValueError("Either 'data' or 'data_list' must be provided.")

if data is not None and data_list is not None:
raise ValueError("Provide only one of 'data' or 'data_list'.")

if data_list is not None and len(data_list) != len(plots):
raise ValueError("'data_list' must have the same length as 'plots'.")

if axes_labels_list is not None and len(axes_labels_list) != len(plots):
raise ValueError("'axes_labels_list' must have the same length as 'plots'.")

if customs_list is not None and len(customs_list) != len(plots):
raise ValueError("'customs_list' must have the same length as 'plots'.")

for idx, plot_name in enumerate(plots):
plot_data = data_list[idx] if data_list is not None else data
axes_labels = (
axes_labels_list[idx] if axes_labels_list is not None else None
)
customs = customs_list[idx] if customs_list is not None else None

self.add_subplot(
plot_name,
plot_data,
axes_labels=axes_labels,
customs=customs,
)

def add_basemap(self, lon, lat, values, customs=None):
"""Add a basemap subplot from raw longitude, latitude, and value grids."""
customs = {} if customs is None else customs
Expand Down
75 changes: 66 additions & 9 deletions chartly/tests/test_chartly.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from ..utilities import PlotUtilities


# pylint: disable=too-many-instance-attributes
class TestPlotting(unittest.TestCase):
"""Test the plotting module."""

Expand All @@ -30,8 +31,8 @@ def setUp(self):
self.util = PlotUtilities()

# Create a data list
self.dataset_one = np.random.randint(50, size=(20))
self.dataset_two = np.random.randint(50, size=(20))
self.dataset_one = np.random.randint(50, size=20)
self.dataset_two = np.random.randint(50, size=20)
self.data = [self.dataset_one, self.dataset_two]

# Create a dictionary of arguments
Expand All @@ -46,8 +47,9 @@ def setUp(self):
# Create a dictionary of multiplot arguments
args = {
"super_title": " Test Title",
"super_x_label": "Test X Label",
"super_y_label": "Test Y Label",
"super_xlabel": "Test X Label",
"super_ylabel": "Test Y Label",
"show": False,
}
# Create a multiplot object
self.multiplot = Chart(args)
Expand All @@ -70,7 +72,7 @@ def test_gen_plot_data_length(self):

def test_standardize_data(self):
"""Test that the data is standardized correctly."""
data = [val for val in range(10, 100, 20)]
data = list(range(10, 100, 20))
expected_std_data = [-1.4, -0.7, 0, 0.7, 1.4]
std_data = [np.round(val, 1) for val in self.util.standardize_dataset(data)]
self.assertEqual(std_data, expected_std_data)
Expand Down Expand Up @@ -103,6 +105,58 @@ def test_multi_clear_axis(self):
# Test that the subplot count is now 0
self.assertEqual(self.multiplot.subplot_count, 0)

def test_add_subplots_shared_data(self):
"""Test that add_subplots can add multiple subplots with shared data."""
plots = ["histogram", "boxplot"]

self.multiplot.add_subplots(plots, self.dataset_one)

self.assertEqual(self.multiplot.subplot_count, 2)
self.assertEqual(len(self.multiplot.subplots), 1)
self.assertEqual(len(self.multiplot.current_subplot), 1)

self.multiplot.render()
self.assertEqual(self.multiplot.subplot_count, 0)

def test_add_subplots_per_plot_data(self):
"""Test that add_subplots can use a separate dataset for each plot."""
plots = ["line_plot", "normal_cdf"]
data_list = [
[self.dataset_one, self.dataset_two],
self.dataset_one,
]

self.multiplot.add_subplots(plots, data_list=data_list)

self.assertEqual(self.multiplot.subplot_count, 2)
self.assertEqual(self.multiplot.subplots[0][0][0], "line_plot")
self.assertEqual(self.multiplot.current_subplot[0][0], "normal_cdf")

self.multiplot.render()
self.assertEqual(self.multiplot.subplot_count, 0)

def test_add_subplots_axes_labels_length_validation(self):
"""Test that add_subplots validates axes_labels_list length."""
plots = ["histogram", "boxplot"]

with self.assertRaises(ValueError):
self.multiplot.add_subplots(
plots,
self.dataset_one,
axes_labels_list=[{"title": "histogram"}],
)

def test_add_subplots_customs_length_validation(self):
"""Test that add_subplots validates customs_list length."""
plots = ["histogram", "boxplot"]

with self.assertRaises(ValueError):
self.multiplot.add_subplots(
plots,
self.dataset_one,
customs_list=[{"color": "navy"}],
)

def test_contour_data_length(self):
"""Test that the contour plot throws an error if the data lengths are unequal."""
# Test that the contour plot throws an error when a user does not send 3 datasets
Expand All @@ -111,13 +165,16 @@ def test_contour_data_length(self):
self.contour()

# test that the contour plot does not throw an error when a user sends 3 datasets
X, Y = np.meshgrid(np.linspace(-5, 5, 100), np.linspace(-5, 5, 100))
Z = np.sin(X) * np.cos(Y)
self.contour.data = [X, Y, Z]
x_grid, y_grid = np.meshgrid(
np.linspace(-5, 5, 100),
np.linspace(-5, 5, 100),
)
z_grid = np.sin(x_grid) * np.cos(y_grid)
self.contour.data = [x_grid, y_grid, z_grid]
self.assertIsNone(self.contour())

# Test that the contour plot throws an error when the data sets are not 2D
self.contour.data = [X, Y, Z[0]]
self.contour.data = [x_grid, y_grid, z_grid[0]]
with self.assertRaises(AssertionError):
self.contour()

Expand Down
98 changes: 57 additions & 41 deletions docs/source/Multiplots.rst
Original file line number Diff line number Diff line change
@@ -1,47 +1,45 @@
Multiple Plot Charts with Chartly Examples
==========================================

Chartly allows users to create multiple plots on the same figure using the `overlay` and `new_subplot` methods. The `overlay` method allows users to overlay multiple plots on a single subplot. The `new_subplot` method allows users to create a new subplot on the figure.

Chartly allows users to create multiple plots on the same figure using a
simplified interface with ``add_subplot(...)``, ``add_subplots(...)``,
``add_overlay(...)``, and ``render()``. The ``add_subplot(...)`` method
allows users to create a new subplot on the figure. The
``add_subplots(...)`` method allows users to create multiple subplots in a
single call. The ``add_overlay(...)`` method allows users to overlay
additional plots on the current subplot. The ``render()`` method is used to
display the final figure once all plots have been added.

Overlay Plots
~~~~~~~~~~~~~

The `overlay` method allows users to overlay multiple plots on a single subplot. The overlay method requires a dictionary of arguments to be passed to the method. The dictionary should contain the following

- `data`: The data that will be plotted.
- `plot`: The type of plot to be created.

Users can also customize and label the plots by including the following keys in the dictionary:

- `axes_labels`: A dictionary containing the labels of the subplot.
- `customs`: A dictionary containing the customization options of the plot.
To overlay multiple plots on a single subplot, first add the subplot with
``add_subplot(...)`` and then add additional plots to that same subplot with
``add_overlay(...)``.
Comment thread
k-alphonse marked this conversation as resolved.

.. code-block:: python

import chartly
import numpy as np

# define main figure labels
args = {"super_title": "Overlay Example", "super_xlabel": "X", "super_ylabel": "Y", "share_axes": False}
args = {
"super_title": "Overlay Example",
"super_xlabel": "X",
"super_ylabel": "Y",
"share_axes": False,
}

multi = chartly.Chart(args)

# Define Some Data
data = np.random.normal(loc=2, scale=1, size=1000)

# Create a subplot
multi.new_subplot()

plots = ["histogram", "density"]

for plot in plots:
# set up overlay payload
overlay_payload = {"plot": plot, "data": data, "axes_labels": {}}
# Add a subplot and overlay a second plot
multi.add_subplot("histogram", data)
multi.add_overlay("density", data)

# Overlay a histogram
multi.overlay(overlay_payload)

multi()
multi.render()


.. image:: https://chartly.s3.amazonaws.com/static/img/overlay_hetero_eg.jpg
Expand All @@ -53,35 +51,53 @@ Users can also customize and label the plots by including the following keys in
Subplots
~~~~~~~~

The `new_subplot` method allows users to create a new subplot on the figure. The new_subplot method requires no arguments to be passed to the method. When a user is finished creating subplots, they can call the Charts instance to render the figure.

To create multiple subplots on the same figure, add each subplot directly
with ``add_subplot(...)`` and render the figure with ``render()`` once all
subplots have been added.

.. code-block:: python

import chartly
import numpy as np

# define main figure labels
args = {"super_title": "Subplots Example", "super_xlabel": "X", "super_ylabel": "Y", "share_axes": False}
args = {
"super_title": "Subplots Example",
"super_xlabel": "X",
"super_ylabel": "Y",
"share_axes": False,
}

multi = chartly.Chart(args)

# Define Some Data
data = np.random.normal(loc=0.8, scale=2, size=50)

# Define Plots
plots = ["histogram", "density", "probability_plot", "line_plot", "normal_cdf"]

for plot in plots:
# Create a subplot
multi.new_subplot()
axes_labels = {"xlabel": " ", "ylabel": " ", "title": plot}

overlay_payload = {"plot": plot, "data": data, "axes_labels": axes_labels}
multi.overlay(overlay_payload)

multi.overlay(overlay_payload)

multi()
# Define plots
plots = [
"histogram",
"density",
"probability_plot",
"line_plot",
"normal_cdf",
]

axes_labels_list = [
{"title": "histogram"},
{"title": "density"},
{"title": "prob_plot"},
{"title": "gen_plot"},
{"title": "norm_cdf"},
]

# Add all subplots in one call
multi.add_subplots(
plots,
data,
axes_labels_list=axes_labels_list,
)

multi.render()

.. image:: https://chartly.s3.amazonaws.com/static/img/subplots_eg.jpg
:alt: SubplotsExample
Expand Down
Loading
Loading