diff --git a/chartly/chartly.py b/chartly/chartly.py index f15d40d..c06af93 100644 --- a/chartly/chartly.py +++ b/chartly/chartly.py @@ -177,6 +177,52 @@ def add_overlay(self, plot, data, axes_labels=None, customs=None): } ) + # pylint: disable=too-many-arguments,too-many-positional-arguments + def add_subplots( + self, + plots, + data=None, + data_list=None, + axes_labels_list=None, + customs_list=None, + ): + """Add multiple subplots in a single call. + + :param list plots: List of plot names. + :param data: Shared data for all plots. + :param list data_list: Optional list of datasets, one per plot. + :param list axes_labels_list: Optional list of axes_labels dicts. + :param list customs_list: Optional list of customs dicts. + """ + if data is None and data_list is None: + raise ValueError("Either 'data' or 'data_list' must be provided.") + + if data is not None and data_list is not None: + raise ValueError("Provide only one of 'data' or 'data_list'.") + + if data_list is not None and len(data_list) != len(plots): + raise ValueError("'data_list' must have the same length as 'plots'.") + + if axes_labels_list is not None and len(axes_labels_list) != len(plots): + raise ValueError("'axes_labels_list' must have the same length as 'plots'.") + + if customs_list is not None and len(customs_list) != len(plots): + raise ValueError("'customs_list' must have the same length as 'plots'.") + + for idx, plot_name in enumerate(plots): + plot_data = data_list[idx] if data_list is not None else data + axes_labels = ( + axes_labels_list[idx] if axes_labels_list is not None else None + ) + customs = customs_list[idx] if customs_list is not None else None + + self.add_subplot( + plot_name, + plot_data, + axes_labels=axes_labels, + customs=customs, + ) + def add_basemap(self, lon, lat, values, customs=None): """Add a basemap subplot from raw longitude, latitude, and value grids.""" customs = {} if customs is None else customs diff --git a/chartly/tests/test_chartly.py b/chartly/tests/test_chartly.py index d9cfe61..0e86a19 100644 --- a/chartly/tests/test_chartly.py +++ b/chartly/tests/test_chartly.py @@ -22,6 +22,7 @@ from ..utilities import PlotUtilities +# pylint: disable=too-many-instance-attributes class TestPlotting(unittest.TestCase): """Test the plotting module.""" @@ -30,8 +31,8 @@ def setUp(self): self.util = PlotUtilities() # Create a data list - self.dataset_one = np.random.randint(50, size=(20)) - self.dataset_two = np.random.randint(50, size=(20)) + self.dataset_one = np.random.randint(50, size=20) + self.dataset_two = np.random.randint(50, size=20) self.data = [self.dataset_one, self.dataset_two] # Create a dictionary of arguments @@ -46,8 +47,9 @@ def setUp(self): # Create a dictionary of multiplot arguments args = { "super_title": " Test Title", - "super_x_label": "Test X Label", - "super_y_label": "Test Y Label", + "super_xlabel": "Test X Label", + "super_ylabel": "Test Y Label", + "show": False, } # Create a multiplot object self.multiplot = Chart(args) @@ -70,7 +72,7 @@ def test_gen_plot_data_length(self): def test_standardize_data(self): """Test that the data is standardized correctly.""" - data = [val for val in range(10, 100, 20)] + data = list(range(10, 100, 20)) expected_std_data = [-1.4, -0.7, 0, 0.7, 1.4] std_data = [np.round(val, 1) for val in self.util.standardize_dataset(data)] self.assertEqual(std_data, expected_std_data) @@ -103,6 +105,58 @@ def test_multi_clear_axis(self): # Test that the subplot count is now 0 self.assertEqual(self.multiplot.subplot_count, 0) + def test_add_subplots_shared_data(self): + """Test that add_subplots can add multiple subplots with shared data.""" + plots = ["histogram", "boxplot"] + + self.multiplot.add_subplots(plots, self.dataset_one) + + self.assertEqual(self.multiplot.subplot_count, 2) + self.assertEqual(len(self.multiplot.subplots), 1) + self.assertEqual(len(self.multiplot.current_subplot), 1) + + self.multiplot.render() + self.assertEqual(self.multiplot.subplot_count, 0) + + def test_add_subplots_per_plot_data(self): + """Test that add_subplots can use a separate dataset for each plot.""" + plots = ["line_plot", "normal_cdf"] + data_list = [ + [self.dataset_one, self.dataset_two], + self.dataset_one, + ] + + self.multiplot.add_subplots(plots, data_list=data_list) + + self.assertEqual(self.multiplot.subplot_count, 2) + self.assertEqual(self.multiplot.subplots[0][0][0], "line_plot") + self.assertEqual(self.multiplot.current_subplot[0][0], "normal_cdf") + + self.multiplot.render() + self.assertEqual(self.multiplot.subplot_count, 0) + + def test_add_subplots_axes_labels_length_validation(self): + """Test that add_subplots validates axes_labels_list length.""" + plots = ["histogram", "boxplot"] + + with self.assertRaises(ValueError): + self.multiplot.add_subplots( + plots, + self.dataset_one, + axes_labels_list=[{"title": "histogram"}], + ) + + def test_add_subplots_customs_length_validation(self): + """Test that add_subplots validates customs_list length.""" + plots = ["histogram", "boxplot"] + + with self.assertRaises(ValueError): + self.multiplot.add_subplots( + plots, + self.dataset_one, + customs_list=[{"color": "navy"}], + ) + def test_contour_data_length(self): """Test that the contour plot throws an error if the data lengths are unequal.""" # Test that the contour plot throws an error when a user does not send 3 datasets @@ -111,13 +165,16 @@ def test_contour_data_length(self): self.contour() # test that the contour plot does not throw an error when a user sends 3 datasets - X, Y = np.meshgrid(np.linspace(-5, 5, 100), np.linspace(-5, 5, 100)) - Z = np.sin(X) * np.cos(Y) - self.contour.data = [X, Y, Z] + x_grid, y_grid = np.meshgrid( + np.linspace(-5, 5, 100), + np.linspace(-5, 5, 100), + ) + z_grid = np.sin(x_grid) * np.cos(y_grid) + self.contour.data = [x_grid, y_grid, z_grid] self.assertIsNone(self.contour()) # Test that the contour plot throws an error when the data sets are not 2D - self.contour.data = [X, Y, Z[0]] + self.contour.data = [x_grid, y_grid, z_grid[0]] with self.assertRaises(AssertionError): self.contour() diff --git a/docs/source/Multiplots.rst b/docs/source/Multiplots.rst index 52c3311..9a2cfae 100644 --- a/docs/source/Multiplots.rst +++ b/docs/source/Multiplots.rst @@ -1,47 +1,45 @@ Multiple Plot Charts with Chartly Examples ========================================== -Chartly allows users to create multiple plots on the same figure using the `overlay` and `new_subplot` methods. The `overlay` method allows users to overlay multiple plots on a single subplot. The `new_subplot` method allows users to create a new subplot on the figure. - +Chartly allows users to create multiple plots on the same figure using a +simplified interface with ``add_subplot(...)``, ``add_subplots(...)``, +``add_overlay(...)``, and ``render()``. The ``add_subplot(...)`` method +allows users to create a new subplot on the figure. The +``add_subplots(...)`` method allows users to create multiple subplots in a +single call. The ``add_overlay(...)`` method allows users to overlay +additional plots on the current subplot. The ``render()`` method is used to +display the final figure once all plots have been added. Overlay Plots ~~~~~~~~~~~~~ -The `overlay` method allows users to overlay multiple plots on a single subplot. The overlay method requires a dictionary of arguments to be passed to the method. The dictionary should contain the following - -- `data`: The data that will be plotted. -- `plot`: The type of plot to be created. - -Users can also customize and label the plots by including the following keys in the dictionary: - -- `axes_labels`: A dictionary containing the labels of the subplot. -- `customs`: A dictionary containing the customization options of the plot. +To overlay multiple plots on a single subplot, first add the subplot with +``add_subplot(...)`` and then add additional plots to that same subplot with +``add_overlay(...)``. .. code-block:: python import chartly + import numpy as np # define main figure labels - args = {"super_title": "Overlay Example", "super_xlabel": "X", "super_ylabel": "Y", "share_axes": False} + args = { + "super_title": "Overlay Example", + "super_xlabel": "X", + "super_ylabel": "Y", + "share_axes": False, + } multi = chartly.Chart(args) # Define Some Data data = np.random.normal(loc=2, scale=1, size=1000) - # Create a subplot - multi.new_subplot() - - plots = ["histogram", "density"] - - for plot in plots: - # set up overlay payload - overlay_payload = {"plot": plot, "data": data, "axes_labels": {}} + # Add a subplot and overlay a second plot + multi.add_subplot("histogram", data) + multi.add_overlay("density", data) - # Overlay a histogram - multi.overlay(overlay_payload) - - multi() + multi.render() .. image:: https://chartly.s3.amazonaws.com/static/img/overlay_hetero_eg.jpg @@ -53,35 +51,53 @@ Users can also customize and label the plots by including the following keys in Subplots ~~~~~~~~ -The `new_subplot` method allows users to create a new subplot on the figure. The new_subplot method requires no arguments to be passed to the method. When a user is finished creating subplots, they can call the Charts instance to render the figure. - +To create multiple subplots on the same figure, add each subplot directly +with ``add_subplot(...)`` and render the figure with ``render()`` once all +subplots have been added. .. code-block:: python import chartly + import numpy as np # define main figure labels - args = {"super_title": "Subplots Example", "super_xlabel": "X", "super_ylabel": "Y", "share_axes": False} + args = { + "super_title": "Subplots Example", + "super_xlabel": "X", + "super_ylabel": "Y", + "share_axes": False, + } multi = chartly.Chart(args) # Define Some Data data = np.random.normal(loc=0.8, scale=2, size=50) - # Define Plots - plots = ["histogram", "density", "probability_plot", "line_plot", "normal_cdf"] - - for plot in plots: - # Create a subplot - multi.new_subplot() - axes_labels = {"xlabel": " ", "ylabel": " ", "title": plot} - - overlay_payload = {"plot": plot, "data": data, "axes_labels": axes_labels} - multi.overlay(overlay_payload) - - multi.overlay(overlay_payload) - - multi() + # Define plots + plots = [ + "histogram", + "density", + "probability_plot", + "line_plot", + "normal_cdf", + ] + + axes_labels_list = [ + {"title": "histogram"}, + {"title": "density"}, + {"title": "prob_plot"}, + {"title": "gen_plot"}, + {"title": "norm_cdf"}, + ] + + # Add all subplots in one call + multi.add_subplots( + plots, + data, + axes_labels_list=axes_labels_list, + ) + + multi.render() .. image:: https://chartly.s3.amazonaws.com/static/img/subplots_eg.jpg :alt: SubplotsExample diff --git a/docs/source/index.rst b/docs/source/index.rst index bb2b0c1..14442b9 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -49,7 +49,7 @@ The Chartly package currently has eight (8) available scientific plots that can - Density Plot - Box Plot -Chartly allows users to build plots by first creating a main figure and then adding subplots to the figure. To initialize a main figure, users can create a `Chart` instance. Users can also label and customize the main figure my passing an optional dictionary. The dictionary should contain the following keys: +Chartly allows users to build plots by first creating a main figure and then adding subplots to the figure. To initialize a main figure, users can create a `Chart` instance. Users can also label and customize the main figure by passing an optional dictionary. The dictionary should contain the following keys: - `super_title` (str): The title of the main figure. - `super_xlabel` (str): The x-axis label of the main figure. @@ -68,15 +68,10 @@ Chartly allows users to build plots by first creating a main figure and then add plot = chartly.Chart(super_axes_labels) -To create a plot, a user must create a subplot by calling the `new_subplot` method and passing it an optional dictionary of arguments. The dictionary should contain the following keys: - -- `data`: The data that will be plotted. -- `plot`: The type of plot to be created. - -Users can also customize and label the plots by including the following keys in the dictionary: - -- `axes_labels`: A dictionary containing the labels of the subplot. -- `customs`: A dictionary containing the customization options of the plot. +To create a plot, users can directly add a subplot with ``add_subplot(...)``. +Additional plots can be added to the same subplot with ``add_overlay(...)``. +This keeps the public interface simpler by avoiding manual payload +dictionaries. .. code-block:: python @@ -84,44 +79,28 @@ Users can also customize and label the plots by including the following keys in # 3. Define Some Data data = np.random.randn(100) - # 4. Build the plot dictionary - plot_payload = { - "plot": "histogram", - "data": data, - } - - # 5. Plot the data - plot.new_subplot(plot_payload) - + # 4. Add a subplot directly + plot.add_subplot("histogram", data) -To overlay a new plot onto the current subplot, a user can call the `overlay` method and pass it a dictionary of arguments, similar to what is shown above: +To overlay a new plot onto the current subplot, users can call +``add_overlay(...)`` and pass the plot type and data directly. .. code-block:: python - # 6. build the overlay plot dictionary - plot_payload = { - "plot": "density", - "data": data, - } + # 5. Overlay another plot + plot.add_overlay("density", data) - # 7. Overlay the plot - plot.overlay(plot_payload) - - -To add a new subplot, users can call the `new_subplot` method again and pass it a dictionary of arguments. +To add multiple subplots at once, users can call ``add_subplots(...)``. .. code-block:: python - # 8. build the plot dictionary - plot_payload = { - "plot": "boxplot", - "data": data, - } - - # 9. Plot data onto new subplot - plot.new_subplot(plot_payload) + # 6. Add multiple subplots at once + plot.add_subplots( + ["boxplot", "normal_cdf"], + data + ) Users can also customize the axes of each subplot. @@ -131,37 +110,33 @@ Users can also customize the axes of each subplot. .. code-block:: python - # 10. Define a random exponential function + # 7. Define a random exponential function exp_func = lambda x: np.e ** (-500 * x + 2) x = np.linspace(0, 1, num=100) y = list(map(exp_func, x)) - # 11. build the plot dictionary - plot_payload = { - "plot": "line_plot", - "data": y, - "axes_labels": {"scale": "semilogy", "base": 10, "linelabel": "Semilogy Line"}, - } - - # 12. Plot exponential function - plot.new_subplot(plot_payload) - + # 8. Add customized subplot + plot.add_subplot( + "line_plot", + y, + axes_labels={"scale": "semilogy", "base": 10, "linelabel": "Semilogy Line"}, + ) -Finally, the figure can be rendered by calling the `Chart` instance. +Finally, the figure can be rendered by calling ``render()``. .. code-block:: python - # 13. Render the main figure - plot() + # 9. Render the main figure + plot.render() To save the figure that was rendered, users can call the `save` method. The default file format is `eps` and the default file name is `chartly_plot`. To change the file format and name, update the plot's properties. .. code-block:: python - # 14. Save the figure with a different file format and name + # 10. Save the figure with a different file format and name plot.format = "jpg" plot.fname = "my_plot" plot.save()