diff --git a/README.md b/README.md index eef67a3..2bf6bd6 100644 --- a/README.md +++ b/README.md @@ -26,7 +26,6 @@ To turn up your first dashboard and test if your installation works, type: ktdashboard -demo tests/test_cache_1000.json ``` This creates a KTdashboard using a test cache with about a 1000 different benchmark configurations in it. -The ``-demo`` switch enables demo mode, which means that KTdashboard mimicks a live tuning run. KTdashboard uses Kernel Tuner's cachefiles to visualize the auto-tuning results as they come in. Cache files are used within Kernel Tuner to record all information about all benchmarked kernel configurations. This allows the tuner to do several @@ -41,6 +40,15 @@ allows you to monitor the tuner's progress using: ktdashboard my_cache_filename.json ``` +You can choose the backend to visualize results. Example: +``` +# Launch the Panel + Bokeh dashboard (default) +ktdashboard my_cache_filename.json + +# Launch the Streamlit + Plotly dashboard +ktdashboard --backend streamlit my_cache_filename.json +``` + ## License, contributions, citation KTdashboard is considered part of the Kernel Tuner project, for licensing, contribution guide, and citation information please see diff --git a/ktdashboard/dashboard.py b/ktdashboard/dashboard.py new file mode 100644 index 0000000..4623465 --- /dev/null +++ b/ktdashboard/dashboard.py @@ -0,0 +1,139 @@ +from typing import Dict, Any, List +import json +import pandas as pd + + +class Dashboard: + def __init__(self, cache_file: str): + self.cache_file = cache_file + + # Open cachefile for reading appended results + self.cache_file_handle = open(cache_file, "r") + filestr = self.cache_file_handle.read().strip() + + # Try to be tolerant for trailing missing brackets or commas + if filestr and not filestr.endswith("}\n}"): + if filestr[-1] == ",": + filestr = filestr[:-1] + filestr = filestr + "}\n}" + + cached_data = json.loads(filestr) if filestr else {} + + self.kernel_name = 
cached_data.get("kernel_name", "") + self.device_name = cached_data.get("device_name", "") + self.objective = cached_data.get("objective", "time") + + # raw performance records + data = list(cached_data.get("cache", {}).values()) + data = [ + d + for d in data + if d.get(self.objective) != 1e20 + and not isinstance(d.get(self.objective), str) + ] + + self.index = len(data) + self._all_data = data + + # tune parameters + self.tune_params_keys = cached_data.get("tune_params_keys", []) + self.all_tune_params: Dict[str, List[Any]] = {} + for key in self.tune_params_keys: + values = cached_data.get("tune_params", {}).get(key, []) + for row in data: + if row.get(key) not in values: + values = sorted(values + [row.get(key)]) + self.all_tune_params[key] = values + + # find keys + self.single_value_tune_param_keys = [ + k for k in self.tune_params_keys if len(self.all_tune_params[k]) == 1 + ] + self.tune_param_keys = [ + k + for k in self.tune_params_keys + if k not in self.single_value_tune_param_keys + ] + + scalar_value_keys = [ + k + for k in (data[0].keys() if data else []) + if not isinstance(data[0][k], list) + and k not in self.single_value_tune_param_keys + ] + self.scalar_value_keys = scalar_value_keys + self.output_keys = [ + k for k in scalar_value_keys if k not in self.tune_param_keys + ] + self.float_keys = [ + k for k in self.output_keys if isinstance(data[0].get(k), float) if data + ] + + # prepare DataFrame + self.data_df = ( + pd.DataFrame(data)[self.scalar_value_keys] + if len(self.scalar_value_keys) > 0 + else pd.DataFrame() + ) + self.data_df = self.data_df.reset_index(drop=True) + if not self.data_df.empty: + self.data_df.insert(0, "index", self.data_df.index.astype(int)) + + # categorical conversion for tune params + for key in self.tune_param_keys: + if key in self.data_df.columns: + self.data_df[key] = pd.Categorical( + self.data_df[key], + categories=self.all_tune_params[key], + ordered=True, + ) + + # selections for filtering + 
self.selected_tune_params = { + key: self.all_tune_params[key].copy() for key in self.tune_param_keys + } + + def close(self): + try: + self.cache_file_handle.close() + except Exception: + pass + + def get_filtered_df(self) -> pd.DataFrame: + """Return filtered DataFrame based on selected_tune_params.""" + mask = pd.Series(True, index=self.data_df.index) + for k, v in self.selected_tune_params.items(): + mask &= self.data_df[k].isin(v) + return self.data_df[mask] + + def update_selection(self, key: str, values: List[Any]): + """Update selection for a tune parameter.""" + self.selected_tune_params[key] = values + + def get_stream_for_index(self, i: int) -> Dict[str, List[Any]]: + """Return a stream dict for a single element at index i.""" + element = self._all_data[i] + stream_dict = { + k: [v] + for k, v in dict(element, index=i).items() + if k in ["index"] + self.scalar_value_keys + } + return stream_dict + + def read_new_contents(self) -> List[Dict[str, List[Any]]]: + """Read appended JSON from the cachefile and return list of stream_dicts for new entries.""" + new_contents = self.cache_file_handle.read().strip() + stream_dicts = [] + if new_contents: + # process new contents (parse as JSON, make into dict that goes into source.stream) + new_contents_json = "{" + new_contents[:-1] + "}" + new_data = list(json.loads(new_contents_json).values()) + for i, element in enumerate(new_data): + stream_dict = { + k: [v] + for k, v in dict(element, index=self.index + i).items() + if k in ["index"] + self.scalar_value_keys + } + stream_dicts.append(stream_dict) + self.index += len(new_data) + return stream_dicts diff --git a/ktdashboard/ktdashboard.py b/ktdashboard/ktdashboard.py index c262aa8..5048dfa 100644 --- a/ktdashboard/ktdashboard.py +++ b/ktdashboard/ktdashboard.py @@ -1,306 +1,61 @@ #!/usr/bin/env python -import json -import sys -import os - -import panel as pn -import panel.widgets as pnw -import pandas as pd -import bokeh.palettes -from bokeh.models.ranges import 
FactorRange -from bokeh.transform import jitter -from bokeh.models import HoverTool, LinearColorMapper, CategoricalColorMapper -from bokeh.plotting import ColumnDataSource, figure - - -class KTdashboard: - """ Main object to instantiate to hold everything related to a running dashboard""" - - def __init__(self, cache_file, demo=False, default_key=None): - self.demo = demo - self.cache_file = cache_file - - # read in the cachefile - self.cache_file_handle = open(cache_file, "r") - filestr = self.cache_file_handle.read().strip() - # if file was not properly closed, pretend it was properly closed - if not filestr[-3:] == "}\n}": - # remove the trailing comma if any, and append closing brackets - if filestr[-1] == ",": - filestr = filestr[:-1] - filestr = filestr + "}\n}" - - cached_data = json.loads(filestr) - self.kernel_name = cached_data["kernel_name"] - self.device_name = cached_data["device_name"] - if "objective" in cached_data: - self.objective = cached_data["objective"] - else: - self.objective = "time" - - # get the performance data - data = list(cached_data["cache"].values()) - data = [d for d in data if d[self.objective] != 1e20 and not isinstance(d[self.objective], str)] - - # use all data or just the first 1000 records in demo mode - self.index = len(data) - if self.demo: - self.index = min(len(data), 1000) - - all_tune_param_keys = cached_data["tune_params_keys"] - all_tune_params = dict() - - for key in all_tune_param_keys: - values = cached_data["tune_params"][key] - for row in data: - if row[key] not in values: - values = sorted(values + [row[key]]) - - all_tune_params[key] = values - - # figure out which keys are interesting - single_value_tune_param_keys = [key for key in all_tune_param_keys if len(all_tune_params[key]) == 1] - tune_param_keys = [key for key in all_tune_param_keys if key not in single_value_tune_param_keys] - scalar_value_keys = [key for key in data[0].keys() if not isinstance(data[0][key],list) and key not in 
single_value_tune_param_keys] - output_keys = [key for key in scalar_value_keys if key not in tune_param_keys] - float_keys = [key for key in output_keys if isinstance(data[0][key], float)] - - self.single_value_tune_param_keys = single_value_tune_param_keys - self.tune_param_keys = tune_param_keys - self.scalar_value_keys = scalar_value_keys - self.output_keys = output_keys - self.float_keys = float_keys - - # Convert to a data frame - data_df = pd.DataFrame(data[:self.index])[scalar_value_keys] - - # Replace all column that are objects by categorical - for column, dtype in data_df.dtypes.items(): - if column in tune_param_keys and dtype == "object": - data_df[column] = pd.Categorical( - data_df[column], - categories=all_tune_params[column], - ordered=True) - - self.data = data - self.data_df = data_df - self.source = ColumnDataSource(data=self.data_df) - self.selected_tune_params = {key: all_tune_params[key].copy() for key in tune_param_keys} - - self.plot_width = 900 - self.plot_height = 600 - plot_options=dict(width=self.plot_width, min_width=self.plot_width, height=self.plot_height, min_height=self.plot_height) - plot_options['tools'] = [HoverTool(tooltips=[(k, "@{"+k+"}" + ("{0.00}" if k in float_keys else "")) for k in scalar_value_keys]), "box_select,box_zoom,save,reset"] - - self.plot_options = plot_options - - # find default key - if default_key is None: - default_key = 'GFLOP/s' - if default_key not in scalar_value_keys: - default_key = 'time' # Check if time is defined - - if default_key not in scalar_value_keys: - default_key = scalar_value_keys[0] - - # setup widgets - self.yvariable = pnw.Select(name='Y', value=default_key, options=scalar_value_keys) - self.xvariable = pnw.Select(name='X', value='index', options=['index']+scalar_value_keys) - self.colorvariable = pnw.Select(name='Color By', value=default_key, options=scalar_value_keys) - self.xscale = pnw.RadioButtonGroup(name="xscale", options=["linear", "log"]) - self.yscale = 
pnw.RadioButtonGroup(name="yscale", options=["linear", "log"]) - - # connect widgets with the function that draws the scatter plot - self.scatter = pn.bind( - self.make_scatter, - xvariable=self.xvariable, - yvariable=self.yvariable, - color_by=self.colorvariable, - xscale=self.xscale, - yscale=self.yscale) - - # actually build up the dashboard - self.dashboard = pn.template.BootstrapTemplate(title='Kernel Tuner Dashboard') - self.dashboard.main.append(self.scatter) - self.dashboard.sidebar.append(pn.Column( - self.yvariable, - self.xvariable, - self.colorvariable)) - - self.dashboard.sidebar.append(pn.layout.Divider()) - - self.dashboard.sidebar.append(pn.Row( - pn.pane.Markdown("X axis"), - self.xscale - )) - - self.dashboard.sidebar.append(pn.Row( - pn.pane.Markdown("Y axis"), - self.yscale - )) - - self.dashboard.sidebar.append(pn.layout.Divider()) - - self.multi_choice = list() - for tune_param in self.tune_param_keys: - values = all_tune_params[tune_param] - - multi_choice = pnw.MultiChoice(name=tune_param, value=values, options=values) - self.dashboard.sidebar.append(multi_choice) - - row = pn.bind(self.update_data_selection, tune_param, multi_choice) - self.dashboard.sidebar.append(row) - - def __del__(self): - self.cache_file_handle.close() - - def notebook(self): - """ Return a static version of the dashboard without the template """ - return pn.Row(pn.Column(self.yvariable, self.xvariable, self.colorvariable), self.scatter) - - def update_data_selection(self, tune_param, multi_choice): - """ Update view according to values selected by the user """ - selection_key = tune_param - selection_values = multi_choice - - # The idea here is to remember multiple selections across different tunable parameters - # but also allowing these to shrink or grow over time - # this is why the mask is recomputed every time the selection changes - self.selected_tune_params[selection_key] = selection_values - - # Cross selection based on all selections in all tunable 
parameters - mask = pd.Series(True, index=self.data_df.index) - for k,v in self.selected_tune_params.items(): - mask &= self.data_df[k].isin(v) - - index = self.data_df.index[mask].values - self.index = index - - data_df = self.data_df[mask] - self.source.data = data_df - - def update_colors(self, color_by): - dtype = self.data_df.dtypes[color_by] - - if dtype == "category": - factors = dtype.categories - if len(factors) < 10: - palette = bokeh.palettes.Category10[10] - else: - palette = bokeh.palettes.Category20[20] - - - color_mapper = CategoricalColorMapper(palette=palette, factors=factors) - - else: - color_mapper = LinearColorMapper(palette='Viridis256', low=min(self.data_df[color_by]), - high=max(self.data_df[color_by])) - - color = {'field': color_by, 'transform': color_mapper} - return color - - def make_scatter(self, xvariable, yvariable, color_by, xscale, yscale): - color = self.update_colors(color_by) - - x = xvariable - y = yvariable - - plot_options = dict(self.plot_options) - plot_options["x_axis_type"] = xscale - plot_options["y_axis_type"] = yscale - - # For categorical data, we add some jitter - dtype = self.data_df.dtypes.get(xvariable) - if dtype == "category": - plot_options["x_range"] = list(dtype.categories) - x = jitter(xvariable, width=0.02, distribution="normal", - range=FactorRange(*dtype.categories)) - - dtype = self.data_df.dtypes.get(yvariable) - if dtype == "category": - plot_options["y_range"] = list(dtype.categories) - x = jitter(yvariable, width=0.02, distribution="normal", - range=FactorRange(*dtype.categories)) - - f = figure(**plot_options) - f.scatter(x, y, size=5, color=color, alpha=0.5, source=self.source) - f.xaxis.axis_label = xvariable - f.yaxis.axis_label = yvariable - - bokeh_pane = pn.pane.Bokeh(object=f, min_width=self.plot_width, min_height=self.plot_height, max_width=self.plot_width, max_height=self.plot_height) - - pane = pn.Column(pn.pane.Markdown(f"## Auto-tuning {self.kernel_name} on {self.device_name}"), 
bokeh_pane) - - return pane - - def update_plot(self, i): - stream_dict = {k:[v] for k,v in dict(self.data[i], index=i).items() if k in ['index']+self.scalar_value_keys} - self.source.stream(stream_dict) - - def update_data(self): - if not self.demo: - new_contents = self.cache_file_handle.read().strip() - if new_contents: - - # process new contents (parse as JSON, make into dict that goes into source.stream) - new_contents_json = "{" + new_contents[:-1] + "}" - new_data = list(json.loads(new_contents_json).values()) - - for i,element in enumerate(new_data): - - stream_dict = {k:[v] for k,v in dict(element, index=self.index+i).items() if k in ['index']+self.scalar_value_keys} - self.source.stream(stream_dict) - - self.index += len(new_data) - - if self.demo: - if self.index < (len(self.data)-1): - self.update_plot(self.index) - self.index += 1 - - - -def print_usage(): - print("Usage: ./dashboard.py [-demo] filename") - print(" -demo option to enable demo mode that mimicks a running Kernel Tuner session") - print(" filename name of the cachefile") - exit(0) - - - -def cli(): - """ implements the command-line interface to start the dashboard """ - - if len(sys.argv) < 2: - print_usage() - - filename = "" - demo = False - if len(sys.argv) == 2: - if os.path.isfile(sys.argv[1]): - filename = sys.argv[1] - else: +import argparse, subprocess, sys, os + + +def main(): + parser = argparse.ArgumentParser(prog="ktdashboard") + parser.add_argument( + "--backend", + choices=["panel", "streamlit"], + default="panel", + help="Backend to use for visualization", + ) + parser.add_argument( + "filename", nargs="?", help="Path to cache JSON file (optional for streamlit)" + ) + + args = parser.parse_args() + + if args.backend == "panel": + if not args.filename: + print("Cachefile is required for the 'panel' backend") + exit(1) + if not os.path.isfile(args.filename): print("Cachefile not found") exit(1) - elif len(sys.argv) == 3: - if sys.argv[1] == "-demo": - demo = True - else: - 
print_usage()
-        if os.path.isfile(sys.argv[2]):
-            filename = sys.argv[2]
-
-    db = KTdashboard(filename, demo=demo)
-
-    db.dashboard.servable()
-
-    def dashboard_f():
-        """ wrapper function to add the callback, doesn't work without this construct """
-        pn.state.add_periodic_callback(db.update_data, 1000)
-        return db.dashboard
-    server = pn.serve(dashboard_f, show=False)
+    if (
+        args.backend == "streamlit"
+        and args.filename
+        and not os.path.isfile(args.filename)
+    ):
+        print("Cachefile not found")
+        exit(1)
+
+    if args.backend == "streamlit":
+        script_path = os.path.join(os.path.dirname(__file__), "streamlit_dashboard.py")
+        cmd = [
+            sys.executable,
+            "-m",
+            "streamlit",
+            "run",
+            script_path,
+            "--",
+        ]
+        # pass filename only when provided
+        if args.filename:
+            cmd.append(args.filename)
+        subprocess.run(cmd)
+        return
+
+    if args.backend == "panel":
+        # relative import: this module runs as part of the ktdashboard package
+        # (console-script entry point), so an absolute "panel_dashboard" would
+        # raise ModuleNotFoundError
+        from .panel_dashboard import serve_panel
+
+        serve_panel(args.filename)
+        return
+
+    exit(1)
 
 
 if __name__ == "__main__":
-    cli()
+    main()
diff --git a/ktdashboard/panel_dashboard.py b/ktdashboard/panel_dashboard.py
new file mode 100644
index 0000000..8766cf2
--- /dev/null
+++ b/ktdashboard/panel_dashboard.py
@@ -0,0 +1,260 @@
+from typing import Optional
+import panel as pn
+import panel.widgets as pnw
+import pandas as pd
+import bokeh.palettes
+from bokeh.models.ranges import FactorRange
+from bokeh.transform import jitter
+from bokeh.models import (
+    DataTable,
+    CategoricalColorMapper,
+    HoverTool,
+    LinearColorMapper,
+    TableColumn,
+)
+from bokeh.plotting import ColumnDataSource, figure
+
+from .dashboard import Dashboard
+
+
+class PanelDashboard:
+    def __init__(
+        self, cachefile: str, default_key: Optional[str] = None):
+        self.model = Dashboard(cachefile)
+
+        # local copies for UI
+        self.data_df = self.model.data_df
+        self.scalar_value_keys = self.model.scalar_value_keys
+        self.tune_param_keys = self.model.tune_param_keys
+        self.all_tune_params = self.model.all_tune_params
+        self.source = 
ColumnDataSource(data=self._df_categorical_to_str(self.data_df)) + self.selected_tune_params = { + key: self.all_tune_params[key].copy() for key in self.tune_param_keys + } + + # layout parameters + self.plot_height = 600 + plot_options = dict( + height=self.plot_height, + min_height=self.plot_height, + sizing_mode="stretch_width", + ) + float_keys = [k for k in self.scalar_value_keys if k in self.model.float_keys] + plot_options["tools"] = [ + HoverTool( + tooltips=[ + (k, "@{" + k + "}" + ("{0.00}" if k in float_keys else "")) + for k in self.scalar_value_keys + ] + ), + "box_select,box_zoom,save,reset", + ] + self.plot_options = plot_options + + # find default key + if default_key is None: + default_key = "GFLOP/s" + if default_key not in self.scalar_value_keys: + default_key = ( + "time" + if "time" in self.scalar_value_keys + else (self.scalar_value_keys[0] if self.scalar_value_keys else None) + ) + + # Widgets + self.yvariable = pnw.Select( + name="Y", value=default_key, options=self.scalar_value_keys + ) + self.xvariable = pnw.Select( + name="X", value="index", options=["index"] + self.scalar_value_keys + ) + self.colorvariable = pnw.Select( + name="Color By", value=default_key, options=self.scalar_value_keys + ) + self.xscale = pnw.RadioButtonGroup(name="xscale", options=["linear", "log"]) + self.yscale = pnw.RadioButtonGroup(name="yscale", options=["linear", "log"]) + + # checkbox to show/hide the data table, toggling re-renders the pane + self.show_table_checkbox = pnw.Checkbox(name="Show table", value=False) + + # connect widgets + self.scatter = pn.bind( + self.make_pane, + xvariable=self.xvariable, + yvariable=self.yvariable, + color_by=self.colorvariable, + xscale=self.xscale, + yscale=self.yscale, + show_table=self.show_table_checkbox, + ) + + # build up the dashboard + self.dashboard = pn.template.BootstrapTemplate(title="Kernel Tuner Dashboard") + self.dashboard.main.append(self.scatter) + self.dashboard.sidebar.append( + pn.Column(self.yvariable, 
self.xvariable, self.colorvariable) + ) + self.dashboard.sidebar.append(pn.layout.Divider()) + self.dashboard.sidebar.append(pn.Row(pn.pane.Markdown("X axis"), self.xscale)) + self.dashboard.sidebar.append(pn.Row(pn.pane.Markdown("Y axis"), self.yscale)) + self.dashboard.sidebar.append(pn.layout.Divider()) + self.dashboard.sidebar.append(self.show_table_checkbox) + self.dashboard.sidebar.append(pn.layout.Divider()) + + for tune_param in self.tune_param_keys: + values = self.all_tune_params[tune_param] + multi_choice = pnw.MultiChoice( + name=tune_param, value=values, options=values + ) + self.dashboard.sidebar.append(multi_choice) + row = pn.bind(self.update_data_selection, tune_param, multi_choice) + self.dashboard.sidebar.append(row) + + def _df_categorical_to_str(self, df: pd.DataFrame) -> pd.DataFrame: + """Return a copy of `df` where categorical columns are converted to strings.""" + df2 = df.copy() + for c in df2.columns: + if pd.api.types.is_categorical_dtype(df2[c]): + df2[c] = df2[c].astype(str) + return df2 + + def _convert_stream_dict(self, sd: dict) -> dict: + """Convert values inside a stream-dict to string for categorical columns.""" + sd2 = {} + for k, v in sd.items(): + if k in self.data_df.columns and pd.api.types.is_categorical_dtype( + self.data_df[k] + ): + sd2[k] = [str(x) for x in v] + else: + sd2[k] = v + return sd2 + + def update_data_selection(self, tune_param, multi_choice): + self.selected_tune_params[tune_param] = multi_choice + # Cross selection based on all selections in all tunable parameters + mask = pd.Series(True, index=self.data_df.index) + for k, v in self.selected_tune_params.items(): + mask &= self.data_df[k].isin(v) + data_df = self.data_df[mask] + self.source.data = self._df_categorical_to_str(data_df) + + def update_colors(self, color_by): + dtype = self.data_df.dtypes[color_by] + + if dtype == "category": + factors = [str(f) for f in dtype.categories] + if len(factors) < 10: + palette = bokeh.palettes.Category10[10] + else: 
+ palette = bokeh.palettes.Category20[20] + color_mapper = CategoricalColorMapper(palette=palette, factors=factors) + else: + color_mapper = LinearColorMapper( + palette="Viridis256", + low=min(self.data_df[color_by]), + high=max(self.data_df[color_by]), + ) + + color = {"field": color_by, "transform": color_mapper} + return color + + def make_pane(self, xvariable, yvariable, color_by, xscale, yscale, show_table: bool = True): + color = self.update_colors(color_by) + + x = xvariable + y = yvariable + + plot_options = dict(self.plot_options) + plot_options["x_axis_type"] = xscale + plot_options["y_axis_type"] = yscale + + # If the table is disabled we want the plot to take the full page height + if not show_table: + plot_options.pop("height", None) + plot_options.pop("min_height", None) + plot_options["sizing_mode"] = "stretch_both" + + dtype = self.data_df.dtypes.get(xvariable) + if pd.api.types.is_categorical_dtype(dtype): + x_factors = [str(f) for f in dtype.categories] + plot_options["x_range"] = x_factors + x = jitter( + xvariable, + width=0.02, + distribution="normal", + range=FactorRange(*x_factors), + ) + + dtype = self.data_df.dtypes.get(yvariable) + if pd.api.types.is_categorical_dtype(dtype): + y_factors = [str(f) for f in dtype.categories] + plot_options["y_range"] = y_factors + y = jitter( + yvariable, + width=0.02, + distribution="normal", + range=FactorRange(*y_factors), + ) + + f = figure(**plot_options) + f.scatter(x, y, size=5, color=color, alpha=0.5, source=self.source) + f.xaxis.axis_label = xvariable + f.yaxis.axis_label = yvariable + + # DataTable showing the raw data + columns = [TableColumn(field=c, title=c) for c in self.source.column_names][1:] + data_table = DataTable( + source=self.source, + columns=columns, + selectable=True, + sizing_mode="stretch_width", + ) + + pane_title = ( + f"## Auto-tuning {self.model.kernel_name} on {self.model.device_name}" + ) + + if show_table: + bokeh_pane = pn.pane.Bokeh( + object=f, + 
sizing_mode="stretch_width", + min_height=self.plot_height, + max_height=self.plot_height, + ) + pane_children = [ + pn.pane.Markdown(pane_title), + bokeh_pane, + pn.layout.Divider(), + data_table, + ] + else: + bokeh_pane = pn.pane.Bokeh(object=f, sizing_mode="stretch_both") + pane_children = [ + pn.pane.Markdown(pane_title), + bokeh_pane, + ] + + pane = pn.Column(*pane_children) + return pane + + def update_plot(self, i): + sd = self.model.get_stream_for_index(i) + self.source.stream(self._convert_stream_dict(sd)) + + def update_data(self): + stream_dicts = self.model.read_new_contents() + for sd in stream_dicts: + self.source.stream(self._convert_stream_dict(sd)) + + +def serve_panel(cachefile: str, show_table: bool = True) -> None: + ui = PanelDashboard(cachefile) + + ui.dashboard.servable() + + def dashboard_f(): + pn.state.add_periodic_callback(ui.update_data, 1000) + return ui.dashboard + + pn.serve(dashboard_f, show=False) diff --git a/ktdashboard/streamlit_dashboard.py b/ktdashboard/streamlit_dashboard.py new file mode 100644 index 0000000..516d396 --- /dev/null +++ b/ktdashboard/streamlit_dashboard.py @@ -0,0 +1,204 @@ +import argparse +from typing import Any +import streamlit as st +import pandas as pd +import plotly.express as px +import numpy as np +import tempfile + +from dashboard import Dashboard + + +class StreamlitDashboard: + def __init__(self, cachefile: str | None = None, show_table: bool = True): + self.cachefile = cachefile + self.model = Dashboard(cachefile) if cachefile else None + self.df = self.model.data_df if self.model else pd.DataFrame() + self.plot_height = 600 + self.show_table = show_table + + def load_from_uploaded_file(self, uploaded_file) -> None: + tf = tempfile.NamedTemporaryFile(delete=False, suffix=".json") + content = uploaded_file.read() + tf.write(content) + tf.flush() + tf.close() + self.cachefile = tf.name + self.model = Dashboard(self.cachefile) + self.df = self.model.data_df + + def plot_scatter( + self, + df: 
pd.DataFrame, + x: str, + y: str, + color: str, + xscale: str, + yscale: str, + palette: str = "Viridis", + ) -> Any: + df_plot = df.copy() + + def jitter(col): + dtype = df_plot[col].dtype + if isinstance(dtype, pd.CategoricalDtype) or dtype == object: + categories = list(pd.Categorical(df_plot[col]).categories) + mapping = {c: i for i, c in enumerate(categories)} + arr = df_plot[col].map(mapping).astype(float) + arr += np.random.normal(scale=0.15, size=len(arr)) + return arr, categories + else: + return df_plot[col], None + + x_vals, x_cats = jitter(x) + y_vals, y_cats = jitter(y) + + df_plot["_x"] = x_vals + df_plot["_y"] = y_vals + + color_arg = color if color in df_plot.columns else None + + seq = getattr(px.colors.sequential, palette, None) + color_kwargs = {"color_continuous_scale": seq} + + fig = px.scatter( + df_plot, + x="_x", + y="_y", + color=color_arg, + hover_data=df_plot.columns, + height=self.plot_height, + labels={"_x": x, "_y": y}, + **color_kwargs, + ) + + if x_cats is not None: + fig.update_xaxes( + tickmode="array", tickvals=list(range(len(x_cats))), ticktext=x_cats + ) + if y_cats is not None: + fig.update_yaxes( + tickmode="array", tickvals=list(range(len(y_cats))), ticktext=y_cats + ) + + if xscale == "log": + fig.update_xaxes(type="log") + if yscale == "log": + fig.update_yaxes(type="log") + + return fig + + def render(self): + st.set_page_config(layout="wide", page_title="Kernel Tuner Dashboard") + + if self.model is None: + uploaded = st.sidebar.file_uploader("Upload a cache file", type=["json"]) + if uploaded is None: + st.info("Upload a cache JSON file via the sidebar to get started.") + return + try: + self.load_from_uploaded_file(uploaded) + except Exception as e: + st.error(f"Failed to read uploaded file: {e}") + return + + kernel_name = self.model.kernel_name + device_name = self.model.device_name + + st.sidebar.markdown(f"**Kernel:** {kernel_name}") + st.sidebar.markdown(f"**Device:** {device_name}") + + scalar_value_keys = 
self.model.scalar_value_keys + tune_param_keys = self.model.tune_param_keys + all_tune_params = self.model.all_tune_params + + default_key = ( + "GFLOP/s" + if "GFLOP/s" in scalar_value_keys + else ("time" if "time" in scalar_value_keys else scalar_value_keys[0]) + ) + + yvariable = st.sidebar.selectbox( + "Y", options=scalar_value_keys, index=scalar_value_keys.index(default_key) + ) + xvariable = st.sidebar.selectbox( + "X", options=["index"] + scalar_value_keys, index=0 + ) + colorvariable = st.sidebar.selectbox( + "Color By", + options=scalar_value_keys, + index=scalar_value_keys.index(default_key), + ) + xscale = st.sidebar.radio("X axis scale", options=["linear", "log"], index=0) + yscale = st.sidebar.radio("Y axis scale", options=["linear", "log"], index=0) + + # Show table control + show_table = st.sidebar.checkbox("Show table", value=self.show_table) + + # Color palette chooser (sequential palettes only) + seq_names = [ + name + for name in dir(px.colors.sequential) + if not name.startswith("_") + and isinstance(getattr(px.colors.sequential, name), list) + ] + seq_names = sorted(seq_names) + default_idx = seq_names.index("Viridis") if "Viridis" in seq_names else 0 + palette = st.sidebar.selectbox( + "Color palette", options=seq_names, index=default_idx + ) + + # tune param multi-selects + selections = {} + for tp in tune_param_keys: + selections[tp] = st.sidebar.multiselect( + tp, options=all_tune_params[tp], default=list(all_tune_params[tp]) + ) + self.model.update_selection(tp, selections[tp]) + + filtered_df = self.model.get_filtered_df() + + st.markdown(f"### Auto-tuning {kernel_name} on {device_name}") + + plot_height = self.plot_height + if not show_table: + plot_height = int(plot_height * 1.5) + + fig = self.plot_scatter( + filtered_df, + xvariable, + yvariable, + colorvariable, + xscale, + yscale, + palette=palette, + ) + + st.plotly_chart(fig, height=plot_height, width="stretch") + + if show_table: + st.markdown("---") + + if 
pd.api.types.is_numeric_dtype(filtered_df[yvariable]):
+            sorted_df = filtered_df.sort_values(yvariable)
+            st.dataframe(sorted_df)
+        else:
+            st.dataframe(filtered_df)
+
+
+def serve_streamlit(cachefile: str | None = None, show_table: bool = True) -> None:
+    sd = StreamlitDashboard(cachefile, show_table=show_table)
+    sd.render()
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(prog="ktdashboard")
+    parser.add_argument("filename", nargs="?", help="Path to cache JSON file (optional)")
+    parser.add_argument(
+        "--table",
+        action=argparse.BooleanOptionalAction,
+        default=True,
+        help="Enable data table (default: enabled)",
+    )
+    args = parser.parse_args()
+    serve_streamlit(args.filename if args.filename else None, show_table=args.table)
diff --git a/setup.py b/setup.py
index c50cd83..c6d0f3d 100644
--- a/setup.py
+++ b/setup.py
@@ -27,6 +27,6 @@
       'Topic :: System :: Distributed Computing',
       'Development Status :: 3 - Alpha ',
       ],
-      install_requires=['bokeh','pandas','panel'],
-      entry_points={'console_scripts': ['ktdashboard = ktdashboard.ktdashboard:cli']},
+      install_requires=['bokeh','pandas','panel','streamlit','plotly'],
+      entry_points={'console_scripts': ['ktdashboard = ktdashboard.ktdashboard:main']},
 )