diff --git a/.DS_Store b/.DS_Store new file mode 100644 index 00000000..90ba71fe Binary files /dev/null and b/.DS_Store differ diff --git a/.vscode/launch.json b/.vscode/launch.json new file mode 100644 index 00000000..98dd0751 --- /dev/null +++ b/.vscode/launch.json @@ -0,0 +1,21 @@ +{ + "version": "0.2.0", + "configurations": [ + { + "name": "Python: tseda serve", + "type": "debugpy", + "request": "launch", + "module": "tseda", // Use the module 'tseda' directly + "args": [ + "serve", "tests/data/test.trees.tseda" // Arguments to the module + ], + "console": "integratedTerminal", + "env": { + "PYTHONPATH": "${workspaceFolder}" + }, + "justMyCode": true, + "python": "${workspaceFolder}/.venv/bin/python" + } + ] + } + \ No newline at end of file diff --git a/src/tseda/__main__.py b/src/tseda/__main__.py index a9ac30e8..64b2a325 100644 --- a/src/tseda/__main__.py +++ b/src/tseda/__main__.py @@ -49,8 +49,8 @@ def cli(): ), ) def preprocess(tszip_path, output): - """ - Preprocess a tskit tree sequence or tszip file, producing a .tseda file. + """Preprocess a tskit tree sequence or tszip file, producing a .tseda file. + Calls tsbrowse.preprocess.preprocess. """ tszip_path = pathlib.Path(tszip_path) @@ -79,9 +79,7 @@ def preprocess(tszip_path, output): help="Do not filter the output log (advanced debugging only)", ) def serve(path, port, show, log_level, no_log_filter, admin): - """ - Run the tseda datastore server, version based on View base class. 
- """ + """Run the tseda datastore server, version based on View base class.""" setup_logging(log_level, no_log_filter) tsm = TSModel(path) @@ -91,8 +89,8 @@ def serve(path, port, show, log_level, no_log_filter, admin): app_ = app.DataStoreApp( datastore=datastore.DataStore( tsm=tsm, - individuals_table=individuals_table, sample_sets_table=sample_sets_table, + individuals_table=individuals_table, ), title="TSEda Datastore App", views=[IndividualsTable], diff --git a/src/tseda/app.py b/src/tseda/app.py index d6dd4e11..9ac2c0c1 100644 --- a/src/tseda/app.py +++ b/src/tseda/app.py @@ -1,8 +1,8 @@ """Main application for tseda. -Provides the DataStoreApp class that is the main application for -tseda. The DataStoreApp subclasses the Viewer class from panel and -renders a panel.FastListTemplate object. +Provides the DataStoreApp class that is the main application for tseda. The +DataStoreApp subclasses the Viewer class from panel and renders a +panel.FastListTemplate object. """ import time @@ -20,7 +20,7 @@ RAW_CSS = """ .sidenav#sidebar { - background-color: #15E3AC; + background-color: WhiteSmoke; } .title { font-size: var(--type-ramp-plus-2-font-size); @@ -50,7 +50,20 @@ class DataStoreApp(Viewer): - """Main application class for tseda visualization app.""" + """Main application class for tseda visualization app. + + Attributes: + datastore (DataStore): The data store instance for accessing and + managing data. + title (str): The title of the application. + views (List[str]): A list of views to show on startup. + + Methods: + __init__(**params): Initializes the application, loads pages, and sets + up data update listeners. + view(): Creates the main application view, including a header selector + for switching between different pages. 
+ """ datastore = param.ClassSelector(class_=datastore.DataStore) @@ -71,21 +84,28 @@ def __init__(self, **params): logger.info(f"Initialised pages in {time.time() - t:.2f}s") updating = ( - self.datastore.individuals_table.data.rx.updating() - | self.datastore.sample_sets_table.data.rx.updating() + self.datastore.sample_sets_table.data.rx.updating() + | self.datastore.individuals_table.data.rx.updating() ) updating.rx.watch( - lambda updating: pn.state.curdoc.hold() - if updating - else pn.state.curdoc.unhold() + lambda updating: ( + pn.state.curdoc.hold() + if updating + else pn.state.curdoc.unhold() + ) ) @param.depends("views") def view(self): - """Main application view that renders a radio button group on - top with links to pages. Each page consists of a main content - page with plots and sidebars that provide user options for - configuring plots and outputs.""" + """Creates the main application view. Main application view that + renders a radio button group on top with links to pages. Each page + consists of a main content page with plots and sidebars that provide + user options for configuring plots and outputs. + + Returns: + pn.template.FastListTemplate: A Panel template containing the + header selector, sidebar, and main content. + """ page_titles = list(self.pages.keys()) header_selector = pn.widgets.RadioButtonGroup( options=page_titles, @@ -105,9 +125,11 @@ def get_sidebar(selected_page): yield self.pages[selected_page].sidebar self._template = pn.template.FastListTemplate( - title=self.datastore.tsm.name[:75] + "..." - if len(self.datastore.tsm.name) > 75 - else self.datastore.tsm.name, + title=( + self.datastore.tsm.name[:75] + "..." 
+ if len(self.datastore.tsm.name) > 75 + else self.datastore.tsm.name + ), header=[header_selector], sidebar=get_sidebar, main=get_content, diff --git a/src/tseda/cache.py b/src/tseda/cache.py index e519d66e..686405ee 100644 --- a/src/tseda/cache.py +++ b/src/tseda/cache.py @@ -1,3 +1,6 @@ +"""This module provides a caching mechanism for the TSeDA application, +utilizing the `diskcache` library.""" + import pathlib import appdirs @@ -7,7 +10,14 @@ logger = daiquiri.getLogger("cache") -def get_cache_dir(): +def get_cache_dir() -> pathlib.Path: + """Retrieves the user's cache directory for the TSeDA application. Creates + the directory if it doesn't exist, ensuring its creation along with any + necessary parent directories. + + Returns: + pathlib.Path: The path to the cache directory. + """ cache_dir = pathlib.Path(appdirs.user_cache_dir("tseda", "tseda")) cache_dir.mkdir(exist_ok=True, parents=True) return cache_dir diff --git a/src/tseda/config.py b/src/tseda/config.py index 83e4adea..5e88884d 100644 --- a/src/tseda/config.py +++ b/src/tseda/config.py @@ -1,3 +1,9 @@ +"""Config file. + +This file stores configurations for the entire application such as figure +dimensions and color schemes. +""" + import holoviews as hv from holoviews.plotting.util import process_cmap @@ -8,9 +14,9 @@ PLOT_COLOURS = ["#15E3AC", "#0FA57E", "#0D5160"] # VCard settings -SIDEBAR_BACKGROUND = "#15E3AC" +SIDEBAR_BACKGROUND = "#5CB85D" VCARD_STYLE = { - "background": "#15E3AC", + "background": "WhiteSmoke", } # Global color map diff --git a/src/tseda/datastore.py b/src/tseda/datastore.py index 5ddcea58..0dd09e3a 100644 --- a/src/tseda/datastore.py +++ b/src/tseda/datastore.py @@ -1,4 +1,35 @@ +"""This module provides a collection of classes and functions for analyzing and +visualizing population genetic data. It uses the `tsbrowse` library for working +with TreeSequence data and the `panel` library for creating interactive +visualizations. 
+ +Key Classes: + +- `SampleSetsTable`: Manages and displays information about sample sets, +including + their names, colors, and predefined status. +- `IndividualsTable`: Handles individual data, including their population, +sample set + assignments, and selection status. Enables filtering and modification of + individual + attributes. +- `DataStore`: Provides access to the underlying TreeSequence data, sample +sets, + and individuals data. Also includes methods for calculating haplotype + GNNs and + retrieving sample and population information. + +Methods: + +- `make_individuals_table`: Creates an `IndividualsTable` object from a given +TreeSequence. +- `make_sample_sets_table`: Creates a `SampleSetsTable` object from a given +TreeSequence. +- `preprocess`: Calls `make_individuals_table`and `make_sample_sets_table`. +""" + import random +from typing import Dict, List, Optional, Tuple import daiquiri import pandas as pd @@ -15,45 +46,383 @@ logger = daiquiri.getLogger("tseda") -def make_individuals_table(tsm): - result = [] - for ts_ind in tsm.ts.individuals(): - ind = Individual(individual=ts_ind) - result.append(ind) - return IndividualsTable(table=pd.DataFrame(result)) +class SampleSetsTable(Viewer): + """SampleSetsTable class represents a table for managing sample sets. + + Attributes: + columns (list): + The default columns displayed in the table (["name", "color", + "predefined"]). + editors (dict): + Dictionary specifying editor types for each column in the table. + formatters (dict): + Dictionary defining formatters for each column. + create_sample_set_textinput (String): + Parameter for entering a new sample set name (default=None). + create_sample_set_button (pn.widget.Button): + Button to create sample set. + sample_set_warning (pn.pane.Alert): + Warning alert for duplicate sample set names. + table (param.DataFrame): + Underlying DataFrame holding sample set data. + + Methods: + tooltip() -> pn.widgets.TooltipIcon: + Returns a tooltip for the table. 
+ def __panel__(): + Creates the main panel for the table with functionalities. + get_ids() -> List: + Returns a list of sample set IDs. + sidebar_table() -> pn.Column: + Generates a sidebar table with quick view functionalities. + sidebar() - > pn.Column: + Creates the sidebar with options for managing sample sets. + color_by_name (dict): + Returns a dictionary with sample set colors as key-value pairs + (name-color). + names (dict): + Returns a dictionary with sample set names as key-value pairs + (index-name). + loc(self, i: int) -> pd.core.series.Series: + Returns a pd.core.series.Series (row) of a dataframe for a + specific id + """ + + columns = ["name", "color", "predefined"] + editors = {k: None for k in columns} + editors["color"] = { + "type": "list", + "values": config.COLORS, + "valueLookup": True, + } + editors = { + "name": {"type": "input", "validator": "unique", "search": True}, + "color": { + "type": "list", + "values": [ + { + "value": color, + "label": ( + f'
' + ), + } + for color in config.COLORS + ], + }, + "predefined": {"type": "tickCross"}, + "valueLookup": True, + } + formatters = { + "color": {"type": "color"}, + "predefined": {"type": "tickCross"}, + } + create_sample_set_textinput = param.String( + doc="Enter name of new sample set.", + default=None, + label="Create new sample set", + ) + create_sample_set_button = pn.widgets.Button( + description="Create new sample set.", + name="Create", + button_type="success", + align="end", + ) + sample_set_warning = pn.pane.Alert( + "This sample set name already exists, pick a unique name.", + alert_type="warning", + visible=False, + ) + table = param.DataFrame() + def __init__(self, **params): + super().__init__(**params) + self.table.set_index(["sample_set_id"], inplace=True) + self.data = self.param.table.rx() -def make_sample_sets_table(tsm): - result = [] - for ts_pop in tsm.ts.populations(): - ss = SampleSet(id=ts_pop.id, population=ts_pop, predefined=True) - result.append(ss) - return SampleSetsTable(table=pd.DataFrame(result)) + @property + def tooltip(self) -> pn.widgets.TooltipIcon: + """Returns a TooltipIcon widget containing instructions for editing + sample set names and colors, and assigning individuals to sample sets. + + Returns: + pn.widgets.TooltipIcon: A TooltipIcon widget displaying the + instructions. + """ + return pn.widgets.TooltipIcon( + value=( + "The name and color of each sample set are editable. In the " + "color column, select a color from the dropdown list. In the " + "individuals table, you can assign individuals to sample sets." 
+ ), + ) + def create_new_sample_set(self): + """Creates a new sample set with the provided name in the + create_sample_set_textinput widget, if a name is entered and it's not + already in use.""" + if self.create_sample_set_textinput is not None: + previous_names = [ + self.table.name[i] for i in range(len(self.table)) + ] + if self.create_sample_set_textinput in previous_names: + self.sample_set_warning.visible = True + else: + previous_colors = [ + self.table.color[i] for i in range(len(self.table)) + ] + unused_colors = [ + color + for color in config.COLORS + if color not in previous_colors + ] + if len(unused_colors) != 0: + colors = unused_colors + else: + colors = config.COLORS + self.sample_set_warning.visible = False + i = max(self.param.table.rx.value.index) + 1 + self.param.table.rx.value.loc[i] = [ + self.create_sample_set_textinput, + colors[random.randint(0, len(colors) - 1)], + False, + ] + self.create_sample_set_textinput = None -def preprocess(tsm): - """Take a tsbrowse.TSModel and make individuals and sample sets tables.""" - logger.info( - "Preprocessing data: making individuals and sample sets tables" - ) - individuals_table = make_individuals_table(tsm) - sample_sets_table = make_sample_sets_table(tsm) - return individuals_table, sample_sets_table + def get_ids(self) -> List: + """Returns the sample set IDs. + + Returns: + List: A list of the sample set IDs as integers + + Raises: + TypeError: If the sample set table is not a valid + Dataframe (not yet populated) + """ + if isinstance(self.table, pd.DataFrame): + return self.table.index.values.tolist() + else: + raise TypeError("self.table is not a valid pandas DataFrame.") + + @property + def color_by_name(self) -> Dict[str, str]: + """Return the color of all sample sets as a dictionary with sample set + names as keys. 
+ + Returns: + Dict: dictionary of + """ + d = {} + for _, row in self.data.rx.value.iterrows(): + d[row["name"]] = row.color + return d + + @property + def names(self) -> Dict[int, str]: + # TODO: see why this is called 6 times in a row - unecessary + """Return the names of all sample sets as a dictionary. + + Returns: + Dict: dictionary of indices (int) as keys and + names (str) as values + """ + d = {} + for index, row in self.data.rx.value.iterrows(): + d[index] = row["name"] + + return d + + def loc(self, i: int) -> pd.core.series.Series: + """Returns sample set pd.core.series.Series object (dataframe row) by + index. + + Arguments: + i: Index for the sample set wanted + + Returns: + pd.core.series.Series object: + Containing name, color and predefined status. + """ + return self.data.rx.value.loc[i] + + @pn.depends("create_sample_set_button.value") + def __panel__(self) -> pn.Column: + """Returns the main content of the page which is retrieved from the + `datastore.tsm.ts` attribute. + + Returns: + pn.Column: The layout for the main content area. + """ + self.create_new_sample_set() + + table = pn.widgets.Tabulator( + self.data, + layout="fit_data_table", + selectable=True, + page_size=10, + pagination="remote", + margin=10, + formatters=self.formatters, + editors=self.editors, + configuration={ + "rowHeight": 40, + }, + height=500, + ) + return pn.Column( + self.tooltip, + table, + ) + + def sidebar_table(self) -> pn.Card: + """Generates a sidebar table with quick view functionalities. + + Returns: + pn.Card: The layout for the sidebar. 
+ """ + table = pn.widgets.Tabulator( + self.data, + layout="fit_data_table", + selectable=True, + page_size=10, + pagination="remote", + margin=10, + formatters=self.formatters, + editors=self.editors, + hidden_columns=["id"], + ) + return pn.Card( + pn.Column(self.tooltip, table), + title="Sample sets table quick view", + collapsed=True, + header_background=config.SIDEBAR_BACKGROUND, + active_header_background=config.SIDEBAR_BACKGROUND, + styles=config.VCARD_STYLE, + ) + + def sidebar(self) -> pn.Column: + """Returns the content of the sidebar. + + Returns: + pn.Column: The layout for the sidebar. + """ + return pn.Column( + pn.Card( + self.param.create_sample_set_textinput, + self.create_sample_set_button, + title="Sample sets table options", + collapsed=False, + header_background=config.SIDEBAR_BACKGROUND, + active_header_background=config.SIDEBAR_BACKGROUND, + styles=config.VCARD_STYLE, + ), + self.sample_set_warning, + ) class IndividualsTable(Viewer): - """Class to hold and view individuals and perform calculations to - change filters.""" + """Class represents a table for managing individuals and perform + calculations to change filters. + + Attributes: + sample_sets_table (param.ClassSelector): + ClassSelector for the SampleSetsTable class. + columns (list): + The default columns displayed in the table (["name", "color", + "predefined"]). + editors (dict): + Dictionary specifying editor types for each column in the table. + formatters (dict): + Dictionary defining formatters for each column. + filters (dict): + Filter configurations for the columns. + table (param.DataFrame): + Underlying data stored as a DataFrame. + page_size (param.Selector): + Number of rows per page to display. + sample_select (pn.widgets.MultiChoice): + Widget for selecting sample sets. + population_from (pn.widgets.Select): + Widget for selecting the original population ID. + sample_set_to (pn.widgets.Select): + Widget for selecting the new sample set ID. 
+ mod_update_button (pn.widgets.Button): + Button to apply reassignment of population IDs. + refresh_button (pn.widgets.Button): + Button to refresh the table view. + restore_button (pn.widgets.Button): + Button to restore data to its original state. + data_mod_warning (pn.pane.Alert): + Warning alert for invalid modifications. + + + Methods: + + tooltip() -> pn.widgets.TooltipIcon : + Returns a TooltipIcon widget containing information about the + individuals + table and how to edit it. + + sample_sets(only_selected: Optional[bool] = True): + Returns a dictionary with a sample set id to samples list mapping. + + get_population_ids() -> List[int]: + Returns a sorted list of unique population IDs present in the data. + + get_sample_set_ids() -> List[int]: + Returns a sorted list of unique sample set IDs present in the data. + This method combines IDs from two sources: + 1. Underlying data ("sample_set_id" column). + 2. Optional SampleSetsTable object (if defined). + + sample2ind -> Dict[int, int]: + Creates a dictionary mapping sample (tskit node) IDs to + individual IDs. + + samples(): + Yields all sample (tskit node) IDs present in the data. + + loc(i: int) -> pd.core.series.Series: + Returns the individual data (pd.Series) for a specific index (ID). + + reset_modification(): + Resets the "sample_set_id" column to the original values from + "population". + + combine_tables(individuals_table: param.reactive.rx) -> + pn.widgets.Tabulator: + Combines individuals and sample set data into a single table using + pandas.merge. + + __panel__ -> pn.Column: + The main content of the page, retrieved from `datastore.tsm.ts`. + Updates options based on button interactions and returns a Column + layout. + + options_sidebar() -> pn.Card: + Creates a Panel card containing options for the individuals table: + - Page size selector. + - Sample set selector. 
+ + modification_sidebar() -> pn.Column: + Creates a Panel column containing data modification options: + - Card with population from and sample set to selectors. + - Restore and update buttons. + - Warning message for invalid data. + """ + sample_sets_table = param.ClassSelector(class_=SampleSetsTable) columns = [ - "name", + "color", "population", "sample_set_id", - "selected", + "name_sample_set", + "name_individual", "longitude", "latitude", + "selected", ] - editors = {k: None for k in columns} # noqa + editors = {k: None for k in columns} editors["sample_set_id"] = { "type": "number", "valueLookup": True, @@ -63,36 +432,16 @@ class IndividualsTable(Viewer): "values": [False, True], "valuesLookup": True, } - formatters = {"selected": {"type": "tickCross"}} - - table = param.DataFrame() - - page_size = param.Selector( - objects=[10, 20, 50, 100, 200, 500], - default=20, - doc="Number of rows per page to display", - ) - sample_select = pn.widgets.MultiChoice( - name="Select sample sets", - description="Select samples based on the sample set ID.", - options=[], - ) - population_from = param.Integer( - label="Population ID", - default=None, - bounds=(0, None), - doc=("Reassign individuals with this population ID."), - ) - sample_set_to = param.Integer( - label="New sample set ID", - default=None, - bounds=(0, None), - doc=("Reassign individuals to this sample set ID."), - ) - mod_update_button = pn.widgets.Button(name="Update") - + formatters = { + "selected": {"type": "tickCross"}, + "color": {"type": "color"}, + } filters = { - "name": {"type": "input", "func": "like", "placeholder": "Enter name"}, + "name_individual": { + "type": "input", + "func": "like", + "placeholder": "Enter name", + }, "population": { "type": "input", "func": "like", @@ -107,19 +456,78 @@ class IndividualsTable(Viewer): "type": "tickCross", "tristate": True, "indeterminateValue": None, - "placeholder": "Enter True/False", + }, + "name_sample_set": { + "type": "input", + "func": "like", + 
"placeholder": "Enter name", }, } + table = param.DataFrame() + page_size = param.Selector( + objects=[10, 20, 50, 100, 200, 500], + default=20, + doc="Number of rows per page to display", + ) + sample_select = pn.widgets.MultiChoice( + name="Select sample sets", + description="Select samples based on the sample set ID.", + options=[], + ) + population_from = pn.widgets.Select( + name="Original population ID", + value=None, + sizing_mode="stretch_width", + description=("Reassign individuals with this population ID."), + ) + sample_set_to = pn.widgets.Select( + name="New sample set ID", + value=None, + sizing_mode="stretch_width", + description=("Reassign individuals to this sample set ID."), + ) + mod_update_button = pn.widgets.Button( + name="Reassign", + button_type="success", + margin=(10, 10), + description="Apply reassignment.", + ) + refresh_button = pn.widgets.Button( + name="Refresh", + button_type="success", + margin=(10, 0), + description="Refresh to apply updates to entire page.", + ) + restore_button = pn.widgets.Button( + name="Restore", + button_type="danger", + margin=(10, 10), + description="Restore sample sets to their original state.", + ) + data_mod_warning = pn.pane.Alert( + """Please enter a valid population ID and + a non-negative new sample set ID""", + alert_type="warning", + visible=False, + ) def __init__(self, **params): super().__init__(**params) self.table.set_index(["id"], inplace=True) self.data = self.param.table.rx() - self.sample_select.options = self.sample_set_indices() - self.sample_select.value = self.sample_set_indices() + all_sample_set_ids = self.get_sample_set_ids() + self.sample_select.options = all_sample_set_ids + self.sample_select.value = all_sample_set_ids @property - def tooltip(self): + def tooltip(self) -> pn.widgets.TooltipIcon: + """Returns a TooltipIcon widget containing information about the + individuals table and how to edit it. 
+ + Returns: + pn.widgets.TooltipIcon: A TooltipIcon widget displaying + information. + """ return pn.widgets.TooltipIcon( value=( "Individuals table with columns relevant for modifying plots. " @@ -136,34 +544,75 @@ def tooltip(self): ), ) - def sample_set_indices(self): - """Return indices of sample groups.""" - return sorted(self.data.rx.value["sample_set_id"].unique().tolist()) + def sample_sets(self, only_selected: Optional[bool] = True): + """Returns a dictionary with a sample set id to samples list mapping. + + Arguments: + only_selected (bool, optional): If True, only considers + individuals marked as selected in the table. Defaults to True. - def sample_sets(self): + Returns: + dict: A dictionary where keys are sample set IDs and values are + lists of samples (tskit node IDs) belonging to that set. If + `only_selected` is True, only samples marked as selected are + included in the lists. + """ sample_sets = {} - samples = [] inds = self.data.rx.value for _, ind in inds.iterrows(): - if not ind.selected: + if not ind.selected and only_selected: continue sample_set = ind.sample_set_id if sample_set not in sample_sets: sample_sets[sample_set] = [] sample_sets[sample_set].extend(ind.nodes) - samples.extend(ind.nodes) - return samples, sample_sets + return sample_sets - def get_sample_sets(self, indexes=None): - """Return list of sample sets and their samples.""" - samples, sample_sets = self.sample_sets() - if indexes: - return [sample_sets[i] for i in indexes] - return [sample_sets[i] for i in sample_sets] + def get_population_ids(self) -> List[int]: + """Returns a sorted list of unique population IDs present in the data. + + Returns: + list: A list containing all unique population IDs in the table. + """ + return sorted(self.data.rx.value["population"].unique().tolist()) + + def get_sample_set_ids(self) -> List[int]: + """Returns a sorted list of unique sample set IDs present in the data. + + This method combines sample set IDs from two sources: + + 1. 
Unique IDs from the "sample_set_id" column of the underlying data + (self.data.rx.value). + 2. (Optional) IDs retrieved from the SampleSetsTable object + (accessed through self.sample_sets_table) iff it is defined. + + Returns: + list: A sorted list containing all unique sample set IDs found in + the data and potentially from the `SampleSetsTable`. + """ + individuals_sets = sorted(self.data.rx.value["sample_set_id"].tolist()) + if self.sample_sets_table is not None: # Nonetype when not yet defined + individuals_sets = ( + individuals_sets + self.sample_sets_table.get_ids() + ) + return sorted(list(set(individuals_sets))) @property - def sample2ind(self): - """Map sample (tskit node) ids to individual ids""" + def sample2ind(self) -> Dict[int, int]: + """Creates a dictionary that maps sample (tskit node) IDs to individual + IDs. + + This method iterates through the underlying data and builds a + dictionary where: + Keys are sample (tskit node) IDs. Values are the corresponding + individual IDs + (indices) in the data. + + Returns: + dict: A dictionary mapping sample (tskit node) IDs to their + corresponding + individual IDs. + """ inds = self.data.rx.value d = {} for index, ind in inds.iterrows(): @@ -172,252 +621,248 @@ def sample2ind(self): return d def samples(self): - """Return all samples""" + """Yields all sample (tskit node) IDs present in the data. + + This method iterates through the underlying data and yields each sample + (tskit node) ID. + + Yields: + int: Sample (tskit node) ID from the data. + """ + for _, ind in self.data.rx.value.iterrows(): for node in ind.nodes: yield node - def loc(self, i): - """Return individual by index""" + def loc(self, i: int) -> pd.core.series.Series: + """Returns the individual data, pd.core.series.Series object, for a + specific index (ID) i. + + Arguments: + i (int): The index (ID) of the individual to retrieve. 
+ + Returns: + pd.core.series.Series: A pandas Series representing the individual + data corresponding to the provided index. + """ return self.data.rx.value.loc[i] - @pn.depends("page_size", "sample_select.value", "mod_update_button.value") - def __panel__(self): - self.data.rx.value["selected"] = False - if ( - isinstance(self.sample_select.value, list) - and self.sample_select.value - ): - for sample_set_id in self.sample_select.value: - self.data.rx.value.loc[ - self.data.rx.value.sample_set_id == sample_set_id, - "selected", - ] = True - if self.sample_set_to is not None: - if self.population_from is not None: - try: - self.table.loc[ - self.table["population"] == self.population_from, # pyright: ignore[reportIndexIssue] - "sample_set_id", - ] = self.sample_set_to - except IndexError: - logger.error("No such population %i", self.population_from) - else: - logger.info("No population defined") - data = self.data[self.columns] + def reset_modification(self): + """Resets the "sample_set_id" column of the underlying data + (`self.data.rx.value`) back to the original values from the + "population" column. + + This effectively undoes any modifications made to sample set + assignments. + """ + self.data.rx.value.sample_set_id = self.data.rx.value.population + + def combine_tables( + self, individuals_table: param.reactive.rx + ) -> pn.widgets.Tabulator: + """Combines individuals data and sample set data into a single table. + + This method merges the data from two sources: + + 1. The individuals data (`individuals_table.rx.value`) from the + provided + `individuals_table` argument. + 2. The sample set data (`self.sample_sets_table.data.rx.value`) + from the `SampleSetsTable` object (accessed through + `self.sample_sets_table`) + if it's defined. + + The merge is performed using pandas.merge based on the + "sample_set_id" column. 
+ The resulting table includes additional columns with suffixes + indicating their + origin (e.g., "_individual" for data from `individuals_table`). + + Arguments: + individuals_table (aram.reactive.rx object): + An object containing the individuals data table. + + Returns: + pn.widgets.Tabulator: + A Tabulator widget representing the combined individuals and + sample set data. + """ + + combined_df = pd.merge( + individuals_table.rx.value, + self.sample_sets_table.data.rx.value, + left_on="sample_set_id", + right_index=True, + suffixes=("_individual", "_sample_set"), + ) - table = pn.widgets.Tabulator( - data, + combined_df["id"] = combined_df.index + combined_df = combined_df[self.columns] + + combined_table = pn.widgets.Tabulator( + combined_df, pagination="remote", layout="fit_columns", selectable=True, page_size=self.page_size, formatters=self.formatters, editors=self.editors, + sorters=[ + {"field": "id", "dir": "asc"}, + {"field": "selected", "dir": "des"}, + ], margin=10, text_align={col: "right" for col in self.columns}, header_filters=self.filters, ) - return pn.Column(self.tooltip, table) - - def options_sidebar(self): - return pn.Card( - self.param.page_size, - self.sample_select, - collapsed=False, - title="Individuals table options", - header_background=config.SIDEBAR_BACKGROUND, - active_header_background=config.SIDEBAR_BACKGROUND, - styles=config.VCARD_STYLE, - ) - - modification_header = pn.pane.Markdown("#### Batch reassign indivuduals:") - - def modification_sidebar(self): - return pn.Card( - pn.Column( - self.modification_header, - pn.Row(self.param.population_from, self.param.sample_set_to), - self.mod_update_button, - ), - collapsed=False, - title="Data modification", - header_background=config.SIDEBAR_BACKGROUND, - active_header_background=config.SIDEBAR_BACKGROUND, - styles=config.VCARD_STYLE, - ) - - -class SampleSetsTable(Viewer): - default_columns = ["name", "color", "predefined"] - editors = {k: None for k in default_columns} - 
editors["color"] = { - "type": "list", - "values": config.COLORS, - "valueLookup": True, - } - editors["name"] = {"type": "input", "validator": "unique", "search": True} - formatters = { - "color": {"type": "color"}, - "predefined": {"type": "tickCross"}, - } - - create_sample_set_textinput = param.String( - doc="New sample set name. Press Enter (⏎) to create.", - default=None, - label="New sample set name", - ) - - warning_pane = pn.pane.Alert( - "This sample set name already exists, pick a unique name.", - alert_type="warning", - visible=False, + return combined_table + + @pn.depends( + "page_size", + "sample_select.value", + "mod_update_button.value", + "refresh_button.value", + "restore_button.value", + "sample_sets_table.create_sample_set_button.value", ) + def __panel__(self) -> pn.Column: + """Returns the main content of the page which is retrieved from the + `datastore.tsm.ts` attribute. + + Returns: + pn.Column: The layout for the main content area. + """ + self.population_from.options = self.get_population_ids() + all_sample_set_ids = self.get_sample_set_ids() + self.sample_set_to.options = all_sample_set_ids + self.sample_select.options = all_sample_set_ids + + if isinstance(self.sample_select.value, list): + self.data.rx.value["selected"] = False + for sample_set_id in self.sample_select.value: + self.data.rx.value.loc[ + self.data.rx.value.sample_set_id == sample_set_id, + "selected", + ] = True + if ( + isinstance(self.mod_update_button.value, bool) + and self.mod_update_button.value + ): + self.table.loc[ + self.table["population"] == self.population_from.value, # pyright: ignore[reportIndexIssue] + "sample_set_id", + ] = self.sample_set_to.value - page_size = param.Selector(objects=[10, 20, 50, 100], default=20) + if ( + isinstance(self.restore_button.value, bool) + and self.restore_button.value + ): + self.reset_modification() - table = param.DataFrame() + data = self.data - def __init__(self, **params): - super().__init__(**params) - 
self.table.set_index(["id"], inplace=True) - self.data = self.param.table.rx() + table = self.combine_tables(data) - @property - def tooltip(self): - return pn.widgets.TooltipIcon( - value=( - "The name and color of each sample set are editable. In the " - "color column, select a color from the dropdown list. In the " - "individuals table, you can assign individuals to sample sets." - ), - ) + return pn.Column(pn.Row(self.tooltip, align=("start", "end")), table) - @pn.depends("page_size", "create_sample_set_textinput") # , "columns") - def __panel__(self): - if self.create_sample_set_textinput is not None: - previous_names = [ - self.table.name[i] for i in range(len(self.table)) - ] - if self.create_sample_set_textinput in previous_names: - self.warning_pane.visible = True - else: - previous_colors = [ - self.table.color[i] for i in range(len(self.table)) - ] - unused_colors = [ - color - for color in config.COLORS - if color not in previous_colors - ] - if len(unused_colors) != 0: - colors = unused_colors - else: - colors = config.COLORS - self.warning_pane.visible = False - i = max(self.param.table.rx.value.index) + 1 - self.param.table.rx.value.loc[i] = [ - self.create_sample_set_textinput, - colors[random.randint(0, len(colors) - 1)], - False, - ] - self.create_sample_set_textinput = None - table = pn.widgets.Tabulator( - self.data, - layout="fit_data_table", - selectable=True, - page_size=self.page_size, - pagination="remote", - margin=10, - formatters=self.formatters, - editors=self.editors, - ) - return pn.Column(self.tooltip, table) + def options_sidebar(self) -> pn.Card: + """Creates a Panel card containing options for the individuals table. 
- def sidebar_table(self): - table = pn.widgets.Tabulator( - self.data, - layout="fit_data_table", - selectable=True, - page_size=10, - pagination="remote", - margin=10, - formatters=self.formatters, - editors=self.editors, - hidden_columns=["id"], - ) + Returns: + pn.Card: A Panel card containing the following options: + - Page size selector: Allows the user to adjust the number of + rows per page. + - Sample set selector: Allows the user to select specific + sample sets to filter the data. + """ return pn.Card( - pn.Column(self.tooltip, table), - title="Sample sets table quick view", - collapsed=True, + self.param.page_size, + self.sample_select, + collapsed=False, + title="Individuals table options", header_background=config.SIDEBAR_BACKGROUND, active_header_background=config.SIDEBAR_BACKGROUND, styles=config.VCARD_STYLE, ) - def sidebar(self): + def modification_sidebar(self) -> pn.Column: + """Creates a Panel column containing the data modification options. + + Returns: + pn.Column: A Panel column containing the following elements: + - A card with the following options: + - Population from selector: Allows the user to select the + original population ID. + - Sample set to selector: Allows the user to select the + new sample set ID. + - Restore button: Resets modifications. + - Update button: Applies the modifications. + - A warning message (`self.data_mod_warning`) that is displayed + when invalid data is entered. + """ + modification_header = pn.pane.HTML( + "

Batch reassign individuals

" + ) return pn.Column( pn.Card( - self.param.page_size, - self.param.create_sample_set_textinput, - title="Sample sets table options", + modification_header, + pn.Row(self.population_from, self.sample_set_to), + pn.Row( + pn.Spacer(width=120), + self.restore_button, + self.mod_update_button, + align="end", + ), collapsed=False, + title="Data modification", header_background=config.SIDEBAR_BACKGROUND, active_header_background=config.SIDEBAR_BACKGROUND, styles=config.VCARD_STYLE, ), - self.warning_pane, + self.data_mod_warning, ) - @property - def color(self): - """Return the color of all sample sets as a dictionary""" - d = {} - for index, row in self.data.rx.value.iterrows(): - d[index] = row.color - return d - - @property - def color_by_name(self): - """Return the color of all sample sets as a dictionary with - sample set names as keys""" - d = {} - for _, row in self.data.rx.value.iterrows(): - d[row["name"]] = row.color - return d - - @property - def names(self): - """Return the names of all sample sets as a dictionary""" - d = {} - for index, row in self.data.rx.value.iterrows(): - d[index] = row["name"] - return d - - @property - def names2id(self): - """Return the sample sets as dictionary with names as keys, - ids as values""" - d = {} - for index, row in self.data.rx.value.iterrows(): - d[row["name"]] = index - return d - - def loc(self, i): - """Return sample set by index""" - return self.data.rx.value.loc[i] - class DataStore(Viewer): + """Class representing a data store for managing and accessing data used for + analysis. This class provides access to various data sources and + functionalities related to individuals, sample sets, and the underlying + TreeSequenceModel. + + Attributes: + tsm (param.ClassSelector): + ClassSelector for the model.TSModel object holding the TreeSequence + data. + sample_sets_table (param.ClassSelector): + ClassSelector for the SampleSetsTable object managing sample set + information. 
+ individuals_table (param.ClassSelector): + ClassSelector for the IndividualsTable object handling individual + data and filtering. + views (param.List, constant=True): + A list of views to be displayed. + + Methods: + color(self) -> pd.core.series.Series: + Returns a pandas DataFrame containing the colors of selected + individuals + merged with their corresponding sample set names. + + haplotype_gnn(self, focal_ind, windows=None): + Calculates and returns the haplotype Genealogical Nearest + Neighbors (GNN) + for a specified focal individual and optional window sizes. + """ + tsm = param.ClassSelector(class_=model.TSModel) - individuals_table = param.ClassSelector(class_=IndividualsTable) sample_sets_table = param.ClassSelector(class_=SampleSetsTable) + individuals_table = param.ClassSelector(class_=IndividualsTable) views = param.List(constant=True) @property - def color(self): - """Return colors of selected individuals""" + def color(self) -> pd.core.series.Series: + """Return colors of selected individuals.""" color = pd.merge( self.individuals_table.data.rx.value, self.sample_sets_table.data.rx.value, @@ -426,8 +871,27 @@ def color(self): ) return color.loc[color.selected].color - def haplotype_gnn(self, focal_ind, windows=None): - samples, sample_sets = self.individuals_table.sample_sets() + def haplotype_gnn( + self, focal_ind: int, windows: Optional[List[int]] = None + ) -> pd.DataFrame: + """Calculates and returns the haplotype Genealogical Nearest Neighbors + (GNN) for a specified focal individual and optional window sizes. + + Arguments: + focal_ind (int): The index (ID) of the focal individual within the + individuals table. + windows (List[int], optional): A list of window sizes for + calculating + GNNs within those specific windows. If None, GNNs are + calculated + across the entire sequence length. + + Returns: + pandas.DataFrame: A DataFrame containing GNN information for each + haplotype. 
+ """ + print("ksbhflbsdfj", type(focal_ind), type(windows)) + sample_sets = self.individuals_table.sample_sets() ind = self.individuals_table.loc(focal_ind) hap = windowed_genealogical_nearest_neighbours( self.tsm.ts, ind.nodes, sample_sets, windows=windows @@ -457,9 +921,71 @@ def haplotype_gnn(self, focal_ind, windows=None): df.set_index(["haplotype", "start", "end"], inplace=True) return df - # Not needed? Never used? - def __panel__(self): - return pn.Row( - self.individuals_table, - self.sample_sets_table, + +def make_individuals_table(tsm: model.TSModel) -> IndividualsTable: + """Creates an IndividualsTable object from the data in the provided TSModel + object, by iterating through the individuals in the tree sequence and + creates an Individual object for each one, creating a Pandas DataFrame + populated with the individual level information. + + Arguments: + tsm (model.TSModel): The TSModel object containing the tree + sequence data. + + Returns: + IndividualsTable: An IndividualsTable object populated with + individual level information from the tree sequence. + """ + result = [] + for ts_ind in tsm.ts.individuals(): + ind = Individual(individual=ts_ind) + result.append(ind) + return IndividualsTable(table=pd.DataFrame(result)) + + +def make_sample_sets_table(tsm: model.TSModel) -> SampleSetsTable: + """Creates a SampleSetsTable object from the data in the provided TSModel + object, by iterating through the populations in the tree sequence and + creates a SampleSet object for each one, creating a Pandas DataFrame + populated with the population level information. + + Arguments: + tsm (model.TSModel): The TSModel object containing the tree + sequence data. + + Returns: + SampleSet: A SampleSet object populated with + population level information from the tree sequence. 
+ """ + result = [] + for ts_pop in tsm.ts.populations(): + ss = SampleSet( + sample_set_id=ts_pop.id, population=ts_pop, predefined=True ) + result.append(ss) + return SampleSetsTable(table=pd.DataFrame(result)) + + +def preprocess(tsm: model.TSModel) -> Tuple[IndividualsTable, SampleSetsTable]: + """Take a TSModel and creates IndividualsTable and SampleSetsTable objects + from the data in the provided TSModel object. + + Arguments: + tsm (model.TSModel): The TSModel object containing the tree sequence + data. + + Returns: + Tuple[IndividualsTable, SampleSetsTable]: A tuple containing two + elements: + IndividualsTable: An IndividualsTable object populated with + individual + information from the tree sequence. + SampleSetsTable: A SampleSetsTable object populated with population + information from the tree sequence. + """ + logger.info( + "Preprocessing data: making individuals and sample sets tables" + ) + sample_sets_table = make_sample_sets_table(tsm) + individuals_table = make_individuals_table(tsm) + return individuals_table, sample_sets_table diff --git a/src/tseda/gnn.py b/src/tseda/gnn.py index de27625a..e3d00567 100644 --- a/src/tseda/gnn.py +++ b/src/tseda/gnn.py @@ -1,10 +1,8 @@ -"""Helper code to compute genealogical nearest neighbours along -haplotypes. +"""Helper code to compute genealogical nearest neighbours along haplotypes. Based on code from tskit.tests.test_stats: https://github.com/tskit-dev/tskit/pull/683/files#diff-e5e589330499b325320b2e3c205eaf350660b50691d3e1655f8789683e49dca6R399 - """ import numpy as np diff --git a/src/tseda/main.py b/src/tseda/main.py index 3ab85f92..17f49bba 100644 --- a/src/tseda/main.py +++ b/src/tseda/main.py @@ -1,14 +1,16 @@ -"""Helper module for serving the TSEda app from the command line using -panel serve. +"""Helper module for serving the TSEda app from the command line using panel +serve. -This module is used to serve the TSEda app from the command line using -panel serve. 
One use case is for development purposes where the --dev -argument enables automated reloading of the app when the source code -changes. To launch the app from the command line run: +This module is used to serve the TSEda app from the command line using panel +serve. One use case is for development purposes where the --dev argument +enables automated reloading of the app when the source code changes. To launch +the app from the command line run: - $ panel serve --dev --admin --show --args path/to/tszip_file.zip +$ panel serve --dev --admin --show --args path/to/tszip_file.zip -See https://panel.holoviz.org/how_to/server/commandline.html for more +See +https://panel.holoviz.org/how_to/server/commandline.html +for more information. """ @@ -36,8 +38,8 @@ app_ = app.DataStoreApp( datastore=datastore.DataStore( tsm=tsm, - individuals_table=individuals_table, sample_sets_table=sample_sets_table, + individuals_table=individuals_table, ), title="TSEda Datastore App", views=[datastore.IndividualsTable], diff --git a/src/tseda/model.py b/src/tseda/model.py index 679c60f0..569b9dd3 100644 --- a/src/tseda/model.py +++ b/src/tseda/model.py @@ -13,7 +13,6 @@ - simplify haplotype_gnn function - cache computations! 
- """ import dataclasses @@ -32,9 +31,7 @@ class DataTypes(Enum): - """ - Enum for getter method data types - """ + """Enum for getter method data types.""" LIST = "list" DATAFRAME = "df" @@ -42,7 +39,7 @@ class DataTypes(Enum): def decode_metadata(obj): - """Decode metadata from bytes to dict""" + """Decode metadata from bytes to dict.""" if not hasattr(obj, "metadata"): return None if isinstance(obj.metadata, bytes): @@ -55,7 +52,7 @@ def decode_metadata(obj): def parse_metadata(obj, regex): - """Retrieve metadata value pairs based on key regex""" + """Retrieve metadata value pairs based on key regex.""" md = decode_metadata(obj) if md is None: return @@ -66,7 +63,7 @@ def parse_metadata(obj, regex): def palette(cmap=Set3[12], n=12, start=0, end=1): - """Make a small colorblind-friendly palette""" + """Make a small colorblind-friendly palette.""" import matplotlib linspace = np.linspace(start, end, n) @@ -80,13 +77,11 @@ def palette(cmap=Set3[12], n=12, start=0, end=1): @dataclasses.dataclass class SampleSet: - """ - A class to contain sample sets. - """ + """A class to contain sample sets.""" name_re = re.compile(r"^(name|Name|population|Population)$") - id: np.int32 + sample_set_id: np.int32 name: str = None color: str = None population: dataclasses.InitVar[tskit.Population | None] = None @@ -96,18 +91,16 @@ class SampleSet: def __post_init__(self, population): if self.color is None: - self.color = self.colormap[self.id % len(self.colormap)] + self.color = self.colormap[self.sample_set_id % len(self.colormap)] if population is not None: self.name = parse_metadata(population, self.name_re) if self.name is None: - self.name = f"SampleSet-{self.id}" + self.name = f"SampleSet-{self.sample_set_id}" @dataclasses.dataclass class Individual(tskit.Individual): - """ - A class to handle individuals. 
- """ + """A class to handle individuals.""" name_re = re.compile(r"^(name|Name|SM)$") longitude_re = re.compile(r"^(longitude|Longitude|lng|long)$") @@ -140,17 +133,17 @@ def __post_init__(self) -> None: @property def samples(self): - """Return samples (nodes) associated with individual""" + """Return samples (nodes) associated with individual.""" return self.nodes def toggle(self) -> None: - """Toggle selection status""" + """Toggle selection status.""" self.selected = not self.selected def select(self) -> None: - """Select individual""" + """Select individual.""" self.selected = True def deselect(self) -> None: - """Deselect individual""" + """Deselect individual.""" self.selected = False diff --git a/src/tseda/vpages/__init__.py b/src/tseda/vpages/__init__.py index a481db36..25f0ec95 100644 --- a/src/tseda/vpages/__init__.py +++ b/src/tseda/vpages/__init__.py @@ -2,7 +2,6 @@ ignn, individuals, overview, - sample_sets, stats, structure, trees, @@ -10,7 +9,6 @@ PAGES = [ overview.OverviewPage, - sample_sets.SampleSetsPage, individuals.IndividualsPage, structure.StructurePage, ignn.IGNNPage, diff --git a/src/tseda/vpages/core.py b/src/tseda/vpages/core.py index 544bf7ae..2b193abc 100644 --- a/src/tseda/vpages/core.py +++ b/src/tseda/vpages/core.py @@ -1,7 +1,7 @@ """Core vpages module. -Provides View helper class for panel plots and helper functions common -to pages. +Provides View helper class for panel plots and helper functions common to +pages. """ import numpy as np diff --git a/src/tseda/vpages/ignn.py b/src/tseda/vpages/ignn.py index dce4dd57..482b064f 100644 --- a/src/tseda/vpages/ignn.py +++ b/src/tseda/vpages/ignn.py @@ -15,6 +15,8 @@ - linked brushing between the map and the GNN plot """ +from typing import Any, Union + import holoviews as hv import hvplot.pandas # noqa import pandas as pd @@ -37,11 +39,30 @@ class GNNHaplotype(View): - """Make GNN haplotype plot.""" + """Make GNN haplotype plot. 
This class creates a Panel object that displays + a GNN haplotype plot for a selected individual. + + Attributes: + individual_id (int): the ID of the individual to visualize (0-indexed). + Defaults to None. + window_size (int): The size of the window to use for visualization. + Defaults to 10000. Must be greater than 0. + warning_pane (pn.Alert): a warning panel that is displayed if no + samples are selected. + individual_id_warning (pn.Alert): a warning panel that is displayed + if an invalid individual ID is entered. + + Methods: + plot(haplotype=0): makes the haplotype plot. + plot_haplotype0(): calls the plot function for haplotype 0. + plot_haplotype1(): calls the plot function for haplotype 1. + __panel__() -> pn.Column: Defines the layout of the main content area + or sends out a warning message if the user input isn't valid. + sidebar() -> pn.Card: Defines the layout of the sidebar content area. + """ individual_id = param.Integer( default=None, - bounds=(0, None), doc="Individual ID (0-indexed)", ) @@ -56,9 +77,32 @@ class GNNHaplotype(View): visible=False, ) - def plot(self, haplotype=0): + individual_id_warning = pn.pane.Alert( + "", + alert_type="warning", + visible=False, + ) + + def plot( + self, haplotype: int = 0 + ) -> Union[hv.core.overlay.NdOverlay, pn.pane.Markdown]: + """Creates the GNN Haplotype plot. + + Args: + haplotype (int): Can be either 0 or 1 and will be used to plot + haplotype 0 or haplotype 1. + + + Returns: + pn.pane.Markdown: A message directed to the user to enter a valid + correct sample ID. + pn.pane.Markdown: A placeholder pane in place to show the + warningmessage when a incorrect sample ID is entered. + hv.core.overlay.NdOverlay: A GNN Haplotype plot. 
+ """ + if self.individual_id is None: - return + return pn.pane.Markdown("Enter a sample ID") if self.window_size is not None: windows = make_windows( self.window_size, self.datastore.tsm.ts.sequence_length @@ -122,25 +166,121 @@ def plot(self, haplotype=0): ) return p + def plot_haplotype0( + self, + ) -> Union[hv.core.overlay.NdOverlay, pn.pane.Markdown]: + """Creates the GNN Haplotype plot for haplotype 0. + + Returns: + pn.pane.Markdown: A message directed to the user to enter a valid + correct sample ID. + pn.pane.Markdown: A placeholder pane in place to show the + warningmessage when a incorrect sample ID is entered. + hv.core.overlay.NdOverlay: A GNN Haplotype plot for haplotype 0. + """ + return self.plot(0) + + def plot_haplotype1( + self, + ) -> Union[hv.core.overlay.NdOverlay, pn.pane.Markdown]: + """Creates the GNN Haplotype plot for haplotype 1. + + Returns: + pn.pane.Markdown: A message directed to the user to enter a valid + correct sample ID. + pn.pane.Markdown: A placeholder pane in place to show the + warningmessage when a incorrect sample ID is entered. + hv.core.overlay.NdOverlay: A GNN Haplotype plot for haplotype 1. + """ + return self.plot(1) + + def check_inputs(self, inds: pd.core.frame.DataFrame) -> tuple: + """Checks the inputs to the GNN Haplotype plot. + + Args: + inds (pandas.core.frame.DataFrame): Contains the data in the + individuals table. + + Returns: + pn.pane.Column: If the input argument is valid this coloumn will + return the nodes of the index and an empty Coloumn. Otherwise it + will return a None value and a Coloumn telling the user to enter a + valid sample ID. 
+ """ + max_id = inds.index.max() + info_column = pn.Column( + pn.pane.Markdown( + "**Enter a valid sample id to see the GNN haplotype plot.**" + ), + ) + if self.individual_id is None: + self.individual_id_warning.visible = False # No warning for None + return None, info_column + + try: + if ( + not isinstance(self.individual_id, (int, float)) + or self.individual_id < 0 + ): + self.individual_id_warning.object = ( + "The individual ID does not exist. " + f"Valid IDs are in the range 0-{max_id}." + ) + self.individual_id_warning.visible = True + return (None, info_column) + else: + self.individual_id_warning.visible = False + nodes = inds.loc[self.individual_id].nodes + info_column = pn.Column(pn.pane.Markdown("")) + return (nodes, info_column) + except KeyError: + self.individual_id_warning.object = ( + "The individual ID does not exist. " + f"Valid IDs are in the range 0-{max_id}." + ) + self.individual_id_warning.visible = True + return (None, info_column) + @pn.depends("individual_id", "window_size") - def __panel__(self, **params): + def __panel__(self, **params) -> pn.Column: + """Returns the main content for the GNN Haplotype plot which is + retrieved from the `datastore.tsm.ts` attribute. + + Returns: + pn.Column: The layout for the main content area of the GNN + Haplotype plot or a warning message if the input isn't validated. + """ + inds = self.datastore.individuals_table.data.rx.value - if self.individual_id is None: - return pn.pane.Markdown("") - nodes = inds.loc[self.individual_id].nodes - return pn.Column( - pn.pane.Markdown(f"## Individual id {self.individual_id}"), - self.warning_pane, - pn.pane.Markdown(f"### Haplotype 0 (sample id {nodes[0]})"), - self.plot(0), - pn.pane.Markdown(f"### Haplotype 1 (sample id {nodes[1]})"), - self.plot(1), - ) + nodes = self.check_inputs(inds) + if nodes[0] is not None: + return pn.Column( + pn.pane.HTML( + "

" + f"- Individual id {self.individual_id}

", + sizing_mode="stretch_width", + ), + self.warning_pane, + pn.pane.Markdown(f"### Haplotype 0 (sample id {nodes[0][0]})"), + self.plot_haplotype0, + pn.pane.Markdown(f"### Haplotype 1 (sample id {nodes[0][1]})"), + self.plot_haplotype0, + ) + else: + return nodes[1] - def sidebar(self): + def sidebar(self) -> pn.Card: + """Returns the content of the sidbar options for the GNN Haplotype + plot. + + Returns: + pn.Card: The layout for the sidebar content area connected to the + GNN Haplotype plot. + """ return pn.Card( self.param.individual_id, self.param.window_size, + self.individual_id_warning, collapsed=False, title="GNN haplotype options", header_background=config.SIDEBAR_BACKGROUND, @@ -150,7 +290,24 @@ def sidebar(self): class VBar(View): - """Make VBar plot of GNN output.""" + """Make VBar plot of GNN output. This class creates a Panel object that + displays a VBar plot of the sample sets. + + Attributes: + sorting (pn.Selector): the selected population to base the sort order + on. + sort_order (pn.Selector): the selected sorting order + (Ascending/Descending) + warning_pane (pn.Alert): a warning panel that is displayed if no + samples are selected. + + Methods: + gnn() -> pd.DataFrame: gets the data for the GNN VBar plot. + __panel__() -> pn.panel: creates the panel containing the GNN VBar + plot. + sidebar() -> pn.Card: defines the layout of the sidebar content area + for the VBar options. + """ sorting = param.Selector( doc="Select what population to base the sort order on. Default is " @@ -173,9 +330,18 @@ class VBar(View): ) # TODO: move to DataStore class? - def gnn(self): + def gnn(self) -> pd.DataFrame: + """Creates the data for the GNN VBar plot. + + Returns: + pd.DataFrame: a dataframe containing all the information for the + GNN VBar plot. 
+ """ inds = self.datastore.individuals_table.data.rx.value - samples, sample_sets = self.datastore.individuals_table.sample_sets() + sample_sets = self.datastore.individuals_table.sample_sets() + samples = [ + sample for sublist in sample_sets.values() for sample in sublist + ] self.param.sorting.objects = [""] + list( self.datastore.sample_sets_table.names.values() ) @@ -196,9 +362,19 @@ def gnn(self): return df @pn.depends("sorting", "sort_order") - def __panel__(self): - samples, sample_sets = self.datastore.individuals_table.sample_sets() + def __panel__(self) -> Union[pn.pane.plot.Bokeh, pn.pane.Alert, Any]: + # TODO: Does not accept pn.panel so Any is included as quickfix + """Returns the main content of the plot which is retrieved from the + `datastore.tsm.ts` attribute by the gnn() function. + + Returns: + pn.pane.Alert: a warning pane telling the user that it needs to + select a sample. + pn.pane.plot.Bokeh: a panel with the GNN VBar plot. + """ + sample_sets = self.datastore.individuals_table.sample_sets() if len(list(sample_sets.keys())) < 1: + print(type(self.warning_pane)) return self.warning_pane df = self.gnn() sample_sets = self.datastore.sample_sets_table.data.rx.value @@ -233,6 +409,15 @@ def __panel__(self): ["sample_set_id"] + [self.sorting] + ["sample_id", "id"] # pyright: ignore[reportOperatorIssue] ) ascending = [True, False, False, False] + + columns = df.columns.tolist() + columns.remove(self.sorting) + id_index = columns.index("id") + columns.insert(id_index + 1, self.sorting) + df = df[columns] + sorting_index = groups.index(self.sorting) + groups[sorting_index], groups[0] = groups[0], groups[sorting_index] + color[sorting_index], color[0] = color[0], color[sorting_index] else: sort_by = ["sample_set_id", "sample_id", "id"] ascending = [True, False, False] @@ -288,10 +473,15 @@ def __panel__(self): fig.xaxis.separator_line_width = 2.0 fig.xaxis.separator_line_color = "grey" fig.xaxis.separator_line_alpha = 0.5 - return pn.panel(fig) def 
sidebar(self): + """Returns the content of the sidbar options for the VBar plot. + + Returns: + pn.Card: The layout for the sidebar content area connected to the + VBar plot. + """ return pn.Card( self.param.sorting, self.param.sort_order, @@ -304,6 +494,25 @@ def sidebar(self): class IGNNPage(View): + """Make the iGNN page. This class creates the iGNN page. + + Attributes: + key (str): A unique identifier for the iGNN instance. + title (str): The display title for the iGNN instance. + geomap (GeoMap): An instance of the GeoMap class, providing geographic + visualizations of genomic data. + vbar (VBar): An instance of the VBar class, providing bar plot + visualizations of genomic data. + gnnhaplotype (GNNHaplotype): An instance of the GNNHaplotype class, + handling GNN-based haplotype analysis. + sample_sets (pandas.DataFrame): A DataFrame containing information + about the available sample sets. + + Methods: + __panel__() -> pn.Column: Defines the layout of the main content area. + sidebar() -> pn.Column: Defines the layout of the sidebar content area. + """ + key = "iGNN" title = "iGNN" geomap = param.ClassSelector(class_=GeoMap) @@ -317,17 +526,62 @@ def __init__(self, **params): self.gnnhaplotype = GNNHaplotype(datastore=self.datastore) self.sample_sets = self.datastore.sample_sets_table - def __panel__(self): + def __panel__(self) -> pn.Column: + """Returns the main content of the page which is retrieved from the + `datastore.tsm.ts` attribute. + + Returns: + pn.Column: The layout for the main content area. 
+ """ + return pn.Column( - pn.Row( - self.geomap, + pn.Accordion( + pn.Column( + self.geomap, + pn.pane.Markdown( + "**Map** - Displays the geographical locations " + "where samples were collected and visually " + "represents their group sample affiliations " + "through colors.", + sizing_mode="stretch_width", + ), + name="Geomap", + ), + pn.Column( + self.vbar, + pn.pane.Markdown( + "**vBar** - Lorem ipsum", + sizing_mode="stretch_width", + ), + name="VBar Plot", + ), + pn.Column(self.gnnhaplotype, name="GNN Haplotype Plot"), + active=[0, 1, 2], ), - self.vbar, - self.gnnhaplotype, ) - def sidebar(self): + def sidebar(self) -> pn.Column: + """Returns the sidebar content of the page which is retrieved from the + `datastore.tsm.ts` attribute. + + Returns: + pn.Column: The layout for the sidebar content area. + """ + return pn.Column( + pn.pane.HTML( + "

iGNN

", sizing_mode="stretch_width" + ), + pn.pane.Markdown( + ( + "This section provides interactive visualizations for " + "**Genealogical Nearest Neighbors " + "(GNN)** analysis.

" + "Use the controls below to customize the plots and " + "adjust parameters." + ), + sizing_mode="stretch_width", + ), self.geomap.sidebar, self.vbar.sidebar, self.gnnhaplotype.sidebar, diff --git a/src/tseda/vpages/individuals.py b/src/tseda/vpages/individuals.py index 106d9e60..ca30f328 100644 --- a/src/tseda/vpages/individuals.py +++ b/src/tseda/vpages/individuals.py @@ -12,29 +12,143 @@ import panel as pn import param -from tseda.datastore import IndividualsTable +from tseda.datastore import IndividualsTable, SampleSetsTable from .core import View from .map import GeoMap class IndividualsPage(View): - key = "individuals" - title = "Individuals" + """This class represents a view for the individuals page. + + Attributes: + key (str): A unique identifier for this view. + title (str): The title displayed for this view. + sample_sets_table (param.ClassSelector): A reference to a + `SampleSetsTable` object containing information about sample sets. + individuals_table (param.ClassSelector): A reference to an + `IndividualsTable` object managing individual data and filtering + options. + geomap (param.ClassSelector): A reference to a `GeoMap` object + displaying + geographical locations and sample set affiliations (optional). + + Methods: + __panel__(): Defines the layout of the view using Panel components. + sidebar(): Defines the sidebar content with descriptions and controls. 
+ Contains: + sample_sets_accordion_toggled(event): Handles the toggling event + of the sample sets accordion + """ + + key = "individuals and sets" + title = "Individuals & sets" + sample_sets_table = param.ClassSelector(class_=SampleSetsTable) + individuals_table = param.ClassSelector(class_=IndividualsTable) geomap = param.ClassSelector(class_=GeoMap) - data = param.ClassSelector(class_=IndividualsTable) def __init__(self, **params): super().__init__(**params) self.geomap = GeoMap(datastore=self.datastore) - self.data = self.datastore.individuals_table + self.sample_sets_table = self.datastore.sample_sets_table + self.individuals_table = self.datastore.individuals_table + self.individuals_table.sample_sets_table = self.sample_sets_table + + @pn.depends( + "individuals_table.sample_select.value", + "individuals_table.refresh_button.value", + ) + def __panel__(self) -> pn.Column: + """Defines the layout of the view using Panel components. This method + is called dynamically when dependent parameters change. + + Returns: + pn.Column: A Panel column containing the layout of the view. + """ + sample_sets_accordion = pn.Accordion( + pn.Column( + self.sample_sets_table, + sizing_mode="stretch_width", + name="Sample Sets Table", + ), + max_width=400, + active=[0], + ) + + def sample_sets_accordion_toggled(event): + """Handles the toggling event of the sample sets accordion. + + This function dynamically adjusts the maximum width of the + accordion based on its active state. If the accordion is closed + (active state is an empty list), the width is set to 180 pixels. + Otherwise, when the accordion is open, the width is set to 400 + pixels. + + Arguments: + event (param.Event): The event object triggered by the + accordion's toggle. NOTE: event should not be provided, but + Panel + does not recognize the function without it. 
+ """ + if sample_sets_accordion.active == []: + sample_sets_accordion.max_width = 180 + else: + sample_sets_accordion.max_width = 400 + + sample_sets_accordion.param.watch( + sample_sets_accordion_toggled, "active" + ) + return pn.Column( + pn.Row( + pn.Accordion( + pn.Column( + self.geomap, + pn.pane.Markdown( + "**Map** - Displays the geographical locations " + "where samples were collected and visually " + "represents their group sample affiliations " + "through colors.", + sizing_mode="stretch_width", + ), + min_width=400, + name="Geomap", + ), + active=[0], + ), + pn.Spacer(sizing_mode="stretch_width", max_width=5), + sample_sets_accordion, + ), + pn.Accordion( + pn.Column(self.individuals_table, name="Individuals Table"), + active=[0], + ), + ) - def __panel__(self): - return pn.Column(self.geomap, self.data) + def sidebar(self) -> pn.Column: + """Defines the content for the sidebar of the view containing + descriptive text and control elements. - def sidebar(self): + Returns: + pn.Column: A Panel column containing the sidebar content. + """ return pn.Column( + pn.pane.HTML( + "

Individuals & sets

", + sizing_mode="stretch_width", + ), + pn.pane.Markdown( + ( + "This section allows you to manage and explore" + "individual samples in your dataset " + "and customize Sample Sets.

" + "Use the controls below to customize the plots," + "adjust parameters, and add new samples." + ), + sizing_mode="stretch_width", + ), + self.individuals_table.refresh_button, self.geomap.sidebar, - self.data.options_sidebar, - self.data.modification_sidebar, + self.sample_sets_table.sidebar, + self.individuals_table.options_sidebar, + self.individuals_table.modification_sidebar, ) diff --git a/src/tseda/vpages/map.py b/src/tseda/vpages/map.py index c306c19d..5d857c42 100644 --- a/src/tseda/vpages/map.py +++ b/src/tseda/vpages/map.py @@ -1,4 +1,4 @@ -"""Module for creating a map of the world with sample locations +"""Module for creating a map of the world with sample locations. Generate a hvplot map of the world with sample locations based on a GeoPandas representation of the individuals data. The map is @@ -9,16 +9,17 @@ - Add linked brushing between the map and other panel objects / widgets - Fix issue where map is rendered small and repeated tiles - """ import geopandas import hvplot.pandas # noqa +import pandas as pd import panel as pn import param import xyzservices.providers as xyz from tseda import config +from tseda.datastore import IndividualsTable from .core import View @@ -34,8 +35,24 @@ class GeoMap(View): - height = param.Integer(default=400, doc="Height of the map") - width = param.Integer(default=1200, doc="Width of the map") + """Make the Geomap plot. This class creates a hvplot that displays the map + where the different samples were collected. + + Attributes: + tiles_selector (pn.Selector): the selected tiles for the map + vizualisation. + tiles (str): the selected tile for the map. + individuals_table (IndividualsTable): An instance of the + IndividualsTable class, containing the information from the individuals + table. + + Methods: + __panel__() -> gdf.hvplot: Returns the Geomap as an Hvplot. + sidebar() -> pn.Card: Defines the layout of the sidebar options for + the Geomap. 
+ """ + + individuals_table = param.ClassSelector(class_=IndividualsTable) tiles_selector = param.Selector( default="WorldPhysical", @@ -44,8 +61,18 @@ class GeoMap(View): ) tiles = tiles_options[tiles_selector.default] - @pn.depends("tiles_selector", "height", "width") + def __init__(self, **params): + super().__init__(**params) + self.individuals_table = self.datastore.individuals_table + + @pn.depends("individuals_table.refresh_button.value") def __panel__(self): + """Returns the main content for the Geomap plot, built from the + selected rows of `datastore.individuals_table`. + + Returns: + gdf.hvplot: the geomap plot as a Hvplot. + """ self.tiles = tiles_options[self.tiles_selector] df = self.datastore.individuals_table.data.rx.value df = df.loc[df.selected] @@ -56,25 +83,51 @@ def __panel__(self): ) color = color.loc[~gdf.geometry.is_empty.values] gdf = gdf[~gdf.geometry.is_empty] - return gdf.hvplot.points( + + kw = { + "geo": True, + "tiles": self.tiles, + "tiles_opts": {"alpha": 0.5}, + "responsive": True, + "max_height": 200, + "min_height": 199, + "xlim": (-180, 180), + "ylim": (-60, 70), + "tools": ["wheel_zoom", "box_select", "tap", "pan", "reset"], + } + + if gdf.empty: + gdf = geopandas.GeoDataFrame( + pd.DataFrame(index=[0]), + geometry=geopandas.points_from_xy([0.0], [0.0]), + ) + return gdf.hvplot( + **kw, + hover_cols=None, + size=100, + color=None, + fill_alpha=0.0, + line_color=None, + ) + return gdf.hvplot( + **kw, hover_cols=["name", "population", "sample_set_id"], - geo=True, - tiles=self.tiles, - tiles_opts={"alpha": 0.5}, - max_height=self.height, - min_height=self.height, size=100, color=color, - tools=["wheel_zoom", "box_select", "tap", "pan", "reset"], fill_alpha=0.5, line_color="black", - responsive=True, ) def sidebar(self): + """Returns the content of the sidebar options for the Geomap plot. + + Returns: + pn.Card: The layout for the sidebar content area connected to the + Geomap plot. 
+ """ return pn.Card( self.param.tiles_selector, - collapsed=False, + collapsed=True, title="Map options", header_background=config.SIDEBAR_BACKGROUND, active_header_background=config.SIDEBAR_BACKGROUND, diff --git a/src/tseda/vpages/overview.py b/src/tseda/vpages/overview.py index 577aa162..b43a3e4a 100644 --- a/src/tseda/vpages/overview.py +++ b/src/tseda/vpages/overview.py @@ -1,11 +1,56 @@ +"""Overview page. + +This file contains the class for the application's overview page. The page +includes both the main content with the information about the data file given +to the application as well as a sidebar with a short description of the +application. +""" + import panel as pn from .core import View class OverviewPage(View): + """Represents the overview page of the tseda application. + + Attributes: + key (str): A unique identifier for this view within the application. + title (str): The title displayed on the page. + + Methods: + __panel__() -> pn.Column: Defines the layout of the main content area. + sidebar() -> pn.Column: Defines the layout of the sidebar content area. + """ + key = "overview" title = "Overview" - def __panel__(self): + def __panel__(self) -> pn.Column: + """Returns the main content of the page which is retrieved from the + `datastore.tsm.ts` attribute. + + Returns: + pn.Column: The layout for the main content area. + """ return pn.Column(pn.pane.HTML(self.datastore.tsm.ts)) + + def sidebar(self) -> pn.Column: + """Returns the content of the sidebar. + + Returns: + pn.Column: The layout for the sidebar. + """ + return pn.Column( + pn.pane.HTML( + "

Overview

", + sizing_mode="stretch_width", + ), + pn.pane.Markdown( + ( + "Welcome to **tseda**! This is a tool that " + "you can use to analyze your tskit data file." + ), + sizing_mode="stretch_width", + ), + ) diff --git a/src/tseda/vpages/sample_sets.py b/src/tseda/vpages/sample_sets.py deleted file mode 100644 index 0d16eca7..00000000 --- a/src/tseda/vpages/sample_sets.py +++ /dev/null @@ -1,37 +0,0 @@ -"""Sample sets editor page. - -Panel showing a simple sample set editor page. The page consists of -an editable table showing the sample sets. - -The sample sets table allows the user to edit the name and color of -each sample set. In addition, new sample sets can be added that allows -the user to reassign individuals to different sample sets in the -individuals table. - -TODO: - -- change from/to params to param.NumericTuple? -""" - -import panel as pn -import param - -from tseda.datastore import SampleSetsTable - -from .core import View - - -class SampleSetsPage(View): - key = "sample_sets" - title = "Sample Sets" - data = param.ClassSelector(class_=SampleSetsTable) - - def __init__(self, **params): - super().__init__(**params) - self.data = self.datastore.sample_sets_table - - def __panel__(self): - return pn.Column(self.data) - - def sidebar(self): - return pn.Column(self.data.sidebar) diff --git a/src/tseda/vpages/stats.py b/src/tseda/vpages/stats.py index 0966f8fb..f863e8b3 100644 --- a/src/tseda/vpages/stats.py +++ b/src/tseda/vpages/stats.py @@ -9,6 +9,8 @@ """ import ast +import itertools +from typing import Union import holoviews as hv import pandas as pd @@ -25,9 +27,10 @@ # TODO: make sure this is safe -def eval_sample_sets(sample_sets): - """Evaluate sample sets parameter.""" - return ast.literal_eval(sample_sets) +def eval_comparisons(comparisons): + """Evaluate comparisons parameter.""" + evaluated = ast.literal_eval(str(comparisons).replace(" & ", ",")) + return [tuple(map(int, item.split(","))) for item in evaluated] def eval_indexes(indexes): @@ -36,10 
+39,41 @@ def eval_indexes(indexes): class OnewayStats(View): + """This class defines a view for one-way population genetic statistics + plots. + + Attributes: + mode (param.Selector): + A parameter to select the calculation mode ("site" or "branch"). + Branch mode is only available for calibrated data. (default: "site") + statistic (param.Selector): + A parameter to select the statistic to calculate + (e.g., "Tajimas_D", "diversity"). + Names correspond to tskit method names. (default: "diversity") + window_size (param.Integer): + A parameter to define the size of the window for window-based + statistics. + (default: 10000, bounds=(1, None)) + sample_select_warning (pn.pane.Alert): + An alert panel displayed when no sample sets are selected. + tooltip (pn.widgets.TooltipIcon): + A tooltip icon providing information about the plot. + + Methods: + tooltip() -> pn.widgets.TooltipIcon: + Returns a tooltip for the plot. + __panel__() -> pn.Column: + Generates the view containing the one-way statistics plot. + Raises a warning if no sample sets are selected. + sidebar() -> pn.Card: + Creates the sidebar panel with controls for the plot. + """ + mode = param.Selector( - objects=["branch", "site"], + objects=["site"], default="site", - doc="Select mode for statistics.", + doc="""Select mode (site or branch) for statistics. + Branch mode is only available for calibrated data.""", ) statistic = param.Selector( objects=["Tajimas_D", "diversity"], @@ -49,13 +83,21 @@ class OnewayStats(View): window_size = param.Integer( default=10000, bounds=(1, None), doc="Size of window" ) - sample_sets = param.String( - default="[0,1]", - doc="Comma-separated list of sample sets (0-indexed) to plot.", + sample_select_warning = pn.pane.Alert( + """Select at least 1 sample set to see this plot. 
+ Sample sets are selected on the Individuals page""", + alert_type="warning", ) @property def tooltip(self): + """Returns a TooltipIcon widget containing information about the oneway + statistical plot and how to edit it. + + Returns: + pn.widgets.TooltipIcon: A TooltipIcon widget displaying + the information. + """ return pn.widgets.TooltipIcon( value=( "Oneway statistical plot. The colors can be modified " @@ -63,28 +105,38 @@ def tooltip(self): ) ) - @param.depends("mode", "statistic", "window_size", "sample_sets") - def __panel__(self): + def __init__(self, **params): + super().__init__(**params) + if self.datastore.tsm.ts.time_units != "uncalibrated": + self.param.mode.objects = ["branch", "site"] + + @param.depends("mode", "statistic", "window_size") + def __panel__(self) -> Union[pn.Column, pn.pane.Alert]: + """Returns the plot. + + Returns: + pn.Column: The layout for the plot. + """ data = None windows = make_windows( self.window_size, self.datastore.tsm.ts.sequence_length ) - sample_sets_list = eval_sample_sets(self.sample_sets) - try: - sample_sets = self.datastore.individuals_table.get_sample_sets( - sample_sets_list - ) - except KeyError: - return pn.pane.Alert("Sample set error. 
Check sample set indexes.") + sample_sets_dictionary = self.datastore.individuals_table.sample_sets() + sample_sets_ids = list(sample_sets_dictionary.keys()) + if len(sample_sets_ids) < 1: + return self.sample_select_warning + sample_sets_individuals = list(sample_sets_dictionary.values()) if self.statistic == "Tajimas_D": data = self.datastore.tsm.ts.Tajimas_D( - sample_sets, windows=windows, mode=self.mode + sample_sets_individuals, windows=windows, mode=self.mode ) + fig_text = "**Oneway Tajimas_D plot** - Lorem Ipsum" elif self.statistic == "diversity": data = self.datastore.tsm.ts.diversity( - sample_sets, windows=windows, mode=self.mode + sample_sets_individuals, windows=windows, mode=self.mode ) + fig_text = "**Oneway Diversity plot** - Lorem Ipsum" else: raise ValueError("Invalid statistic") @@ -92,7 +144,7 @@ def __panel__(self): data, columns=[ self.datastore.sample_sets_table.names[i] - for i in sample_sets_list + for i in sample_sets_ids ], ) position = hv.Dimension( @@ -110,17 +162,24 @@ def __panel__(self): } kdims = [hv.Dimension("ss", label="Sample set")] holomap = hv.HoloMap(data_dict, kdims=kdims) - return pn.panel( - holomap.overlay("ss").opts(legend_position="right"), - sizing_mode="stretch_width", + return pn.Column( + pn.panel( + holomap.overlay("ss").opts(legend_position="right"), + sizing_mode="stretch_width", + ), + pn.pane.Markdown(fig_text), ) - def sidebar(self): + def sidebar(self) -> pn.Card: + """Returns the content of the sidebar. + + Returns: + pn.Card: The layout for the sidebar. + """ return pn.Card( self.param.mode, self.param.statistic, self.param.window_size, - self.param.sample_sets, collapsed=False, title="Oneway statistics plotting options", header_background=config.SIDEBAR_BACKGROUND, @@ -130,10 +189,47 @@ def sidebar(self): class MultiwayStats(View): + """This class defines a view for multi-way population genetic statistics + plots. 
+ + Attributes: + mode (param.Selector): + A parameter to select the calculation mode ("site" or "branch"). + Branch mode is only available for calibrated data. (default: "site") + statistic (param.Selector): + A parameter to select the statistic to calculate (e.g., "Fst", + "divergence"). + Names correspond to tskit method names. (default: "Fst") + window_size (param.Integer): + A parameter to define the size of the window for window-based + statistics. + (default: 10000, bounds=(1, None)) + comparisons (pn.widgets.MultiChoice): + A multi-choice widget for selecting sample set pairs to compare. + sample_select_warning (pn.pane.Alert): + An alert panel displayed when no sample sets are selected. + cmaps (dict): + A dictionary containing available Holoviews colormaps. + colormap (param.Selector): + A parameter to select the colormap for the plot. + (default: "glasbey_dark") + + Methods: + set_multichoice_options(): + Updates the options for the comparisons multi-choice widget based + on available sample sets. + __panel__() -> pn.Column: + Generates the view containing the multiway statistics plot. + Raises a warning if no sample sets are selected. + sidebar() -> pn.Card: + Creates the sidebar panel with controls for the plot. + """ + mode = param.Selector( - objects=["branch", "site"], + objects=["site"], default="site", - doc="Select mode for statistics.", + doc="""Select mode (site or branch) for statistics. 
+ Branch mode is only available for calibrated data.""", ) statistic = param.Selector( objects=["Fst", "divergence"], @@ -143,16 +239,13 @@ class MultiwayStats(View): window_size = param.Integer( default=10000, bounds=(1, None), doc="Size of window" ) - sample_sets = param.String( - default="[0,1,2]", - doc="Comma-separated list of sample sets (0-indexed) to compare.", + comparisons = pn.widgets.MultiChoice( + name="Comparisons", description="Choose indexes to compare.", value=[] ) - indexes = param.String( - default="[(0,1), (0,2), (1,2)]", - doc=( - "Comma-separated list of tuples of sample sets " - "(0-indexed) indexes to compare." - ), + sample_select_warning = pn.pane.Alert( + """Select at least 2 sample sets to see this plot. + Sample sets are selected on the Individuals page""", + alert_type="warning", ) cmaps = { cm.name: cm @@ -167,8 +260,20 @@ class MultiwayStats(View): doc="Holoviews colormap for sample set pairs", ) + def __init__(self, **params): + super().__init__(**params) + if self.datastore.tsm.ts.time_units != "uncalibrated": + self.param.mode.objects = ["branch", "site"] + @property def tooltip(self): + """Returns a TooltipIcon widget containing information about the + multiway statistical plot and how to edit it. + + Returns: + pn.widgets.TooltipIcon: A TooltipIcon widget displaying + the information. + """ return pn.widgets.TooltipIcon( value=( "Multiway statistical plot. 
The colors can be modified " @@ -176,44 +281,81 @@ def tooltip(self): ) ) + def set_multichoice_options(self): + """This method dynamically populates the `comparisons` widget with a + list of possible sample set pairs based on the currently selected + sample sets in the `individuals_table`.""" + sample_sets = self.datastore.individuals_table.sample_sets() + all_comparisons = list( + f"{x} & {y}" + for x, y in itertools.combinations( + list(sample_sets.keys()), + 2, + ) + ) + self.comparisons.options = all_comparisons + @pn.depends( - "mode", - "statistic", - "window_size", - "sample_sets", - "indexes", - "colormap", + "mode", "statistic", "window_size", "colormap", "comparisons.value" ) def __panel__(self): + """Returns the multiway plot. + + Returns: + pn.Column: The layout for the main content area. + """ + self.set_multichoice_options() + data = None tsm = self.datastore.tsm - sample_sets_list = [] windows = [] - indexes_list = [] colormap_list = [] windows = make_windows(self.window_size, tsm.ts.sequence_length) - sample_sets_list = eval_sample_sets(self.sample_sets) - indexes_list = eval_indexes(self.indexes) - try: - sample_sets = self.datastore.individuals_table.get_sample_sets( - sample_sets_list + comparisons = eval_comparisons(self.comparisons.value) + + selected_sample_sets = self.datastore.individuals_table.sample_sets() + selected_sample_sets_ids = list(selected_sample_sets.keys()) + if len(selected_sample_sets_ids) < 2: + return self.sample_select_warning + elif self.comparisons.value == []: + return pn.pane.Markdown( + "**Select which sample sets to compare to see this plot.**" + ) + all_sample_sets = self.datastore.individuals_table.sample_sets( + only_selected=False + ) + all_sample_sets_sorted = { + key: all_sample_sets[key] for key in sorted(all_sample_sets) + } + sample_sets_individuals = list(all_sample_sets_sorted.values()) + comparisons_indexes = [ + ( + list(all_sample_sets_sorted.keys()).index(x), + 
list(all_sample_sets_sorted.keys()).index(y), + ) + for x, y in comparisons + if x in all_sample_sets_sorted and y in all_sample_sets_sorted + ] + if comparisons_indexes == []: + return pn.pane.Markdown( + "**Select which sample sets to compare to see this plot.**" ) - except KeyError: - return pn.pane.Alert("Sample set error. Check sample set indexes.") if self.statistic == "Fst": data = tsm.ts.Fst( - sample_sets, + sample_sets_individuals, windows=windows, - indexes=indexes_list, + indexes=comparisons_indexes, mode=self.mode, ) + fig_text = "**Multiway Fst plot** - Lorem Ipsum" elif self.statistic == "divergence": data = tsm.ts.divergence( - sample_sets, + sample_sets_individuals, windows=windows, - indexes=indexes_list, + indexes=comparisons_indexes, mode=self.mode, ) + fig_text = "**Multiway divergence plot** - Lorem Ipsum" else: raise ValueError("Invalid statistic") sample_sets_table = self.datastore.sample_sets_table @@ -226,7 +368,7 @@ def __panel__(self): sample_sets_table.loc(j)["name"], ] ) - for i, j in indexes_list + for i, j in comparisons_indexes ], ) position = hv.Dimension( @@ -245,18 +387,25 @@ def __panel__(self): } kdims = [hv.Dimension("sspair", label="Sample set combination")] holomap = hv.HoloMap(data_dict, kdims=kdims) - return pn.panel( - holomap.overlay("sspair").opts(legend_position="right"), - sizing_mode="stretch_width", + return pn.Column( + pn.panel( + holomap.overlay("sspair").opts(legend_position="right"), + sizing_mode="stretch_width", + ), + pn.pane.Markdown(fig_text), ) - def sidebar(self): + def sidebar(self) -> pn.Card: + """Returns the content of the sidebar. + + Returns: + pn.Card: The layout for the sidebar. 
+ """ return pn.Card( self.param.mode, self.param.statistic, self.param.window_size, - self.param.sample_sets, - self.param.indexes, + self.comparisons, self.param.colormap, collapsed=False, title="Multiway statistics plotting options", @@ -267,6 +416,28 @@ def sidebar(self): class StatsPage(View): + """This class defines a view for the "Statistics" page. + + Attributes: + key (str): + The unique key for the page (default: "stats"). + title (str): + The title of the page (default: "Statistics"). + oneway (param.ClassSelector): + A parameter to select the OnewayStats class for one-way plots. + multiway (param.ClassSelector): + A parameter to select the MultiwayStats class for multi-way plots. + sample_sets (SampleSetsTable): + The SampleSetsTable object for managing sample set information. + + Methods: + __panel__() -> pn.Column: + Generates the panel for the "Statistics" page with one-way and + multi-way plot accordions. + sidebar() -> pn.Column: + Creates the sidebar panel for the "Statistics" page. + """ key = "stats" title = "Statistics" oneway = param.ClassSelector(class_=OnewayStats) @@ -276,22 +447,52 @@ def __init__(self, **kwargs): super().__init__(**kwargs) self.oneway = OnewayStats(datastore=self.datastore) self.multiway = MultiwayStats(datastore=self.datastore) + self.sample_sets = self.datastore.sample_sets_table def __panel__(self): + """Returns the main content of the page. + + Returns: + pn.Column: The layout for the main content area. + """ return pn.Column( - pn.Column( - self.oneway.tooltip, - self.oneway, - ), - pn.Column( - self.multiway.tooltip, - self.multiway, + pn.Accordion( + pn.Column( + self.oneway.tooltip, + self.oneway, + name="Oneway Statistics Plot", + ), + pn.Column( + self.multiway.tooltip, + self.multiway, + name="Multiway Statistics Plot", + ), + active=[0, 1], ), ) def sidebar(self): + """Returns the content of the sidebar. + + Returns: + pn.Column: The layout for the sidebar. 
+ """ return pn.Column( - pn.pane.Markdown("# Statistics"), + pn.pane.HTML( + "

Statistics

", + sizing_mode="stretch_width", + ), + pn.pane.Markdown( + ( + "This section provides **population genetic " + "statistics** to analyze genetic variation " + "and divergence among sample sets.

" + "Use the controls below to customize the plots and " + "adjust parameters." + ), + sizing_mode="stretch_width", + ), self.oneway.sidebar, self.multiway.sidebar, + self.sample_sets.sidebar_table, ) diff --git a/src/tseda/vpages/structure.py b/src/tseda/vpages/structure.py index 1f55e579..4f9e4eb9 100644 --- a/src/tseda/vpages/structure.py +++ b/src/tseda/vpages/structure.py @@ -6,10 +6,10 @@ TODO: - add PCA - add parameter to subset sample sets of interest - """ import itertools +from typing import Union import colorcet as cc import holoviews as hv @@ -26,7 +26,15 @@ class GNN(View): - """Make aggregated GNN plot.""" + """Makes the GNN plot. + + Attributes: + warning_pane (pn.pane.Alert): A warning message that is activated + if fewer than two sample sets are selected. + + Methods: + __panel__() -> Union[pn.Column, pn.pane.Alert]: Defines the GNN plot. + """ warning_pane = pn.pane.Alert( """Please select at least 2 samples to visualize this graph. @@ -34,12 +42,20 @@ class GNN(View): alert_type="warning", ) - def __panel__(self): - samples, sample_sets = self.datastore.individuals_table.sample_sets() + def __panel__(self) -> Union[pn.Column, pn.pane.Alert]: + """Returns the GNN cluster plot as a heatmap or a warning message if + fewer than 2 sample sets are selected. + + Returns: + Union[pn.Column, pn.pane.Alert]: The layout for the GNN cluster + plot with a descriptive markdown element or a warning message. + """ + sample_sets = self.datastore.individuals_table.sample_sets() + samples = [ + sample for sublist in sample_sets.values() for sample in sublist + ] if len(sample_sets) <= 1: - return pn.Column( - pn.pane.Markdown("## GNN cluster plot\n"), self.warning_pane - ) + return self.warning_pane else: sstable = self.datastore.sample_sets_table.data.rx.value inds = self.datastore.individuals_table.data.rx.value @@ -62,16 +78,30 @@ def __panel__(self): mean_gnn = df.groupby("focal_population").mean() # Z-score normalization here! 
return pn.Column( - pn.pane.Markdown("## GNN cluster plot\n"), mean_gnn.hvplot.heatmap( cmap=cc.bgy, height=300, responsive=True ), + pn.pane.Markdown( + "**GNN cluster plot** - This heatmap visualizes the " + "genealogical relationships between individuals based on " + "the proportions of their genealogical nearest neighbors " + "(GNN).", + sizing_mode="stretch_width", + ), pn.pane.Markdown("FIXME: dendrogram and Z-score\n"), ) class Fst(View): - """Make Fst plot.""" + """Makes the Fst plot. + + Attributes: + warning_pane (pn.pane.Alert): A warning message that is activated + if fewer than two sample sets are selected. + + Methods: + __panel__() -> Union[pn.Column, pn.pane.Alert]: Defines the Fst plot. + """ warning_pane = pn.pane.Alert( """Please select at least 2 samples to visualize this graph. @@ -79,10 +109,17 @@ class Fst(View): alert_type="warning", ) - def __panel__(self): - samples, sample_sets = self.datastore.individuals_table.sample_sets() + def __panel__(self) -> Union[pn.Column, pn.pane.Alert]: + """Returns the Fst plot as a heatmap or a warning message if fewer than + 2 sample sets are selected. + + Returns: + Union[pn.Column, pn.pane.Alert]: The layout for the Fst + plot with a descriptive markdown element or a warning message. 
+ """ + sample_sets = self.datastore.individuals_table.sample_sets() if len(sample_sets) <= 1: - return pn.Column(pn.pane.Markdown("## Fst\n"), self.warning_pane) + return self.warning_pane else: sstable = self.datastore.sample_sets_table.data.rx.value ts = self.datastore.tsm.ts @@ -94,13 +131,29 @@ def __panel__(self): np.reshape(fst, newshape=(k, k)), columns=groups, index=groups ) return pn.Column( - pn.pane.Markdown("## Fst\n"), df.hvplot.heatmap(cmap=cc.bgy, height=300, responsive=True), + pn.pane.Markdown( + "**Fst Plot** - Shows the fixation index (Fst) between " + "different sample sets, allowing comparison of genetic " + "diversity across populations.", + sizing_mode="stretch_width", + ), ) class StructurePage(View): - """Make structure page.""" + """Represents the structure page of the tseda application. + + Attributes: + key (str): A unique identifier for this view within the application. + title (str): The title displayed on the page. + gnn (param.ClassSelector): The gnn plot. + fst (param.ClassSelector): The fst plot. + + Methods: + __panel__() -> pn.Column: Defines the layout of the main content area. + sidebar() -> pn.Column: Defines the layout of the sidebar content area. + """ key = "structure" title = "Structure" @@ -111,9 +164,40 @@ def __init__(self, **params): super().__init__(**params) self.gnn = GNN(datastore=self.datastore) self.fst = Fst(datastore=self.datastore) + self.sample_sets = self.datastore.sample_sets_table + + def __panel__(self) -> pn.Column: + """Returns the main content of the structure page. - def __panel__(self): + Returns: + pn.Column: The layout for the main content area. + """ return pn.Column( - self.gnn, - self.fst, + pn.Accordion( + pn.Column(self.gnn, name="GNN Cluster Plot"), + pn.Column(self.fst, name="Fst"), + active=[0, 1], + ) + ) + + def sidebar(self) -> pn.Column: + """Returns the content of the sidebar. + + Returns: + pn.Column: The layout for the sidebar. + """ + return pn.Column( + pn.pane.HTML( + "

Structure

", + sizing_mode="stretch_width", + ), + pn.pane.Markdown( + ( + "This section provides an analysis of the **population " + "structure** based on genomic data. " + "You can explore two types of plots." + ), + sizing_mode="stretch_width", + ), + self.sample_sets.sidebar_table, ) diff --git a/src/tseda/vpages/trees.py b/src/tseda/vpages/trees.py index af5782ae..11eeeebe 100644 --- a/src/tseda/vpages/trees.py +++ b/src/tseda/vpages/trees.py @@ -1,15 +1,18 @@ -"""Module to plot local trees +"""Tree page structure. -TODO: +This is a module to plot local trees. +TODO: - fix bounds of position / treeid parameters """ import ast +from typing import Tuple, Union import holoviews as hv import panel as pn import param +import tskit from tseda import config @@ -18,12 +21,86 @@ hv.extension("bokeh") -def eval_options(options): - """Evaluate options parameter.""" +def eval_options(options: str) -> dict: + """Converts the option string to a dictionary. + + Args: + options (str): The options inputted by the user. + + Returns: + dict: A dictionary containing the options. + """ return ast.literal_eval(options) class Tree(View): + """This class represents a panel component for visualizing tskit trees. + + Attributes: + search_by (pn.widgets.ToggleGroup): Select the method for searching + for trees. + tree_index (param.Integer): Get tree by zero-based index. + position (param.Integer): Get tree at genome position (bp). + position_index_warning (pn.pane.Alert): Warning message displayed + when position or tree index is invalid. + width (param.Integer): Width of the tree plot. + height (param.Integer): Height of the tree plot. + num_trees (pn.widgets.Select): Select the number of trees to display. + y_axis (pn.widgets.Checkbox): Toggle to include y-axis in the plot. + y_ticks (pn.widgets.Checkbox): Toggle to include y-axis ticks in the + plot. + x_axis (pn.widgets.Checkbox): Toggle to include x-axis in the plot. 
+ sites_mutations (pn.widgets.Checkbox): Toggle to include sites and + mutations + in the plot. + pack_unselected (pn.widgets.Checkbox): Toggle to pack unselected + sample sets + in the plot. + options_doc (pn.widgets.TooltipIcon): Tooltip explaining advanced + options. + symbol_size (param.Number): Size of the symbols representing tree + nodes. + node_labels (param.String): Dictionary specifying custom labels for + tree nodes. + additional_options (param.String): Dictionary specifying additional + plot options. + advanced_warning (pn.pane.Alert): Warning message displayed when + advanced options + are invalid. + next (param.Action): Action triggered by the "Next tree" button. + prev (param.Action): Action triggered by the "Previous tree" button. + slider (pn.widgets.IntSlider): Slider for selecting chromosome + position. + + Methods: + __init__(self, **params): Initializes the `Tree` class with provided + parameters. + default_css(self): Generates default CSS styles for tree nodes. + next_tree(self): Increments the tree index to display the next tree. + prev_tree(self): Decrements the tree index to display the previous + tree. + check_inputs(self): Raises a ValueError if position or tree index is + invalid. + handle_advanced(self): Processes advanced options for plotting. + update_slider(self): Updates the slider value based on the selected + position. + update_position(self): Updates the position based on the slider value. + plot_tree(self, tree, omit_sites, y_ticks, node_labels, + additional_options): Generates + the HTML plot for a single tree with specified options. + get_all_trees(self, trees): Constructs a panel layout displaying all + provided trees. + multiple_trees(self): Adjusts layout and options for displaying + multiple trees. + advanced_options(self): Defines the layout for the advanced options + in the sidebar. + __panel__(self): Defines the layout of the main content on the page. 
+ update_sidebar(self): Created the sidebar based on the chosen search + method. + sidebar(self): Calls the update_sidebar method whenever chosen search + method changes. + """ + search_by = pn.widgets.ToggleGroup( name="Search By", options=["Position", "Tree Index"], @@ -31,26 +108,81 @@ class Tree(View): button_type="primary", ) - tree_index = param.Integer(default=0, doc="Get tree by zero-based index") + tree_index = param.Integer( + default=0, + doc="""Get tree by zero-based index. If multiple trees are + shown, this is the index of the first tree.""", + ) position = param.Integer( - default=None, doc="Get tree at genome position (bp)" + default=None, + doc="""Get tree at genome position (bp). If multiple trees are + shown, this is the position of the first tree.""", ) - warning_pane = pn.pane.Alert( - "The input for position or tree index is out of bounds.", + position_index_warning = pn.pane.Alert( + """The input for position or tree index is + out of bounds for the specified number + of trees.""", alert_type="warning", visible=False, ) width = param.Integer(default=750, doc="Width of the tree plot") height = param.Integer(default=520, doc="Height of the tree plot") - options = param.String( - default="{'y_axis': 'time', 'node_labels': {}}", + + num_trees = pn.widgets.Select( + name="Number of trees", + options=[1, 2, 3, 4, 5, 6], + value=1, + description="""Select the number of trees to display. 
The first tree + will represent your selected chromosome position or tree index.""", + ) + + y_axis = pn.widgets.Checkbox(name="Include y-axis", value=True) + y_ticks = pn.widgets.Checkbox(name="Include y-ticks", value=True) + x_axis = pn.widgets.Checkbox(name="Include x-axis", value=False) + sites_mutations = pn.widgets.Checkbox( + name="Include sites and mutations", value=True + ) + pack_unselected = pn.widgets.Checkbox( + name="Pack unselected sample sets", value=False, width=197 + ) + options_doc = pn.widgets.TooltipIcon( + value=( + """Select various elements to include in your graph. + Pack unselected sample sets: Selecting this option + will allow large polytomies involving unselected + samples to be summarised as a dotted line. Selection + of samples and sample sets can be done on the + Individuals page.""" + ), + ) + + symbol_size = param.Number(default=8, bounds=(0, None), doc="Symbol size") + + node_labels = param.String( + default="{}", doc=( - "Additional options for configuring tree plot. " - "Must be a valid dictionary string." + """Show custom labels for the nodes (specified by ID). + Any nodes not present will not have a label. + Examle: {1: 'label1', 2: 'label2',...}""" ), ) + additional_options = param.String( + default="{}", + doc=( + """Add more options as specified by the documentation. + Must be a valid dictionary. 
+ Examle: {'title': 'My Tree',...}""" + ), + ) + + advanced_warning = pn.pane.Alert( + "The inputs for the advanced options are not valid.", + alert_type="warning", + visible=False, + ) + next = param.Action( lambda x: x.next_tree(), doc="Next tree", label="Next tree" ) @@ -58,85 +190,340 @@ class Tree(View): lambda x: x.prev_tree(), doc="Previous tree", label="Previous tree" ) - symbol_size = param.Number(default=8, bounds=(0, None), doc="Symbol size") + slider = pn.widgets.IntSlider(name="Chromosome Position") - def next_tree(self): - self.position = None - self.tree_index = min( - self.datastore.tsm.ts.num_trees - 1, int(self.tree_index) + 1 - ) - # pyright: ignore[reportOperatorIssue] - - def prev_tree(self): - self.position = None - self.tree_index = max(0, int(self.tree_index) - 1) - # pyright: ignore[reportOperatorIssue] + def __init__(self, **params): + super().__init__(**params) + self.slider.end = int(self.datastore.tsm.ts.sequence_length - 1) @property - def default_css(self): - """Default css styles for tree nodes""" + def default_css(self) -> str: + """Default css styles for tree nodes. + + Returns: + str: A string with the css styling. 
+ """ styles = [] sample_sets = self.datastore.sample_sets_table.data.rx.value individuals = self.datastore.individuals_table.data.rx.value sample2ind = self.datastore.individuals_table.sample2ind + selected_sample_sets = self.datastore.individuals_table.sample_sets() + selected_samples = [ + int(i) + for sublist in list(selected_sample_sets.values()) + for i in sublist + ] for n in self.datastore.individuals_table.samples(): ssid = individuals.loc[sample2ind[n]].sample_set_id ss = sample_sets.loc[ssid] - s = f".node.n{n} > .sym " + "{" + f"fill: {ss.color} " + "}" + if n in selected_samples: + s = ( + f".node.n{n} > .sym " + + "{" + + f"fill: {ss.color}; stroke: black; stroke-width: 2px;" + + "}" + ) + else: + s = f".node.n{n} > .sym " + "{" + f"fill: {ss.color} " + "}" styles.append(s) css_string = " ".join(styles) return css_string - @param.depends("position", "tree_index", watch=True) + def next_tree(self): + """Increments the tree index to display the next tree.""" + self.position = None + self.tree_index = min( + self.datastore.tsm.ts.num_trees - self.num_trees.value, + int(self.tree_index) + 1, + ) # pyright: ignore[reportOperatorIssue] + + def prev_tree(self): + """Decrements the tree index to display the previous tree.""" + self.position = None + self.tree_index = max(0, int(self.tree_index) - 1) # pyright: ignore[reportOperatorIssue] + def check_inputs(self): - if self.position is not None and ( - int(self.position) < 0 - or int(self.position) > self.datastore.tsm.ts.sequence_length - ): - self.warning_pane.visible = True - raise ValueError - if ( - self.tree_index is not None - and int(self.tree_index) < 0 - or int(self.tree_index) > self.datastore.tsm.ts.num_trees + """Checks the inputs for position and tree index. + + Raises + ValueError: If the position or tree index is invalid. 
+ """ + if self.position is not None: + if ( + int(self.position) < 0 + or int(self.position) >= self.datastore.tsm.ts.sequence_length + ): + raise ValueError + elif int( + self.datastore.tsm.ts.at(self.position).index + + self.num_trees.value + ) > int(self.datastore.tsm.ts.num_trees): + raise ValueError + if self.tree_index is not None and ( + int(self.tree_index) < 0 + or int(self.tree_index) + int(self.num_trees.value) + > self.datastore.tsm.ts.num_trees ): - self.warning_pane.visible = True raise ValueError else: - self.warning_pane.visible = False + self.position_index_warning.visible = False - @param.depends( - "width", "height", "position", "options", "symbol_size", "tree_index" - ) - def __panel__(self): - options = eval_options(self.options) - if self.position is not None: - tree = self.datastore.tsm.ts.at(self.position) - self.tree_index = tree.index + def handle_advanced(self) -> Tuple[bool, Union[dict, None]]: + """Handles advanced options so that they are returned in the correct + format. + + Returns + bool: Whether mutations & sites should be included in the tree. + Union[dict, None]: Specified the option for ticks on the y-axis. 
+ """ + if self.sites_mutations.value is True: + omit_sites = not self.sites_mutations.value + else: + omit_sites = True + if self.y_ticks.value is True: + y_ticks = None else: - tree = self.datastore.tsm.ts.at_index(self.tree_index) + y_ticks = {} + if self.node_labels == "": + self.node_labels = "{}" + if self.additional_options == "": + self.additional_options = "{}" + return omit_sites, y_ticks + + @param.depends("position", watch=True) + def update_slider(self): + """Updates the slider value based on the selected position.""" + if self.position is not None: + self.slider.value = self.position + + @param.depends("slider.value_throttled", watch=True) + def update_position(self): + """Updates the position based on the slider value.""" + self.position = self.slider.value + + def plot_tree( + self, + tree: tskit.trees.Tree, + omit_sites: bool, + y_ticks: Union[None, dict], + node_labels: dict, + additional_options: dict, + ) -> Union[pn.Accordion, pn.Column]: + """Plots a single tree. + + Arguments: + tree (tskit.trees.Tree): The tree to be plotted. + omit_sites (bool): If sites & mutaions should be included in the + plot. + y_ticks (Union[None, dict]): If y_ticks should be included in the + plot. + nodel_labels (dict): Any customised node labels. + additional_options (dict): Any additional plotting options. + + Returns: + Union[pn.Accordion, pn.Column]: A panel element containing the + tree. 
+ """ + try: + plot = tree.draw_svg( + size=(self.width, self.height), + symbol_size=self.symbol_size, + y_axis=self.y_axis.value, + x_axis=self.x_axis.value, + omit_sites=omit_sites, + node_labels=node_labels, + y_ticks=y_ticks, + pack_untracked_polytomies=self.pack_unselected.value, + style=self.default_css, + **additional_options, + ) + self.advanced_warning.visible = False + except (ValueError, SyntaxError, TypeError): + plot = tree.draw_svg( + size=(self.width, self.height), + y_axis=True, + node_labels={}, + style=self.default_css, + ) + self.advanced_warning.visible = True pos1 = int(tree.get_interval()[0]) pos2 = int(tree.get_interval()[1]) - 1 + if int(self.num_trees.value) > 1: + return pn.Accordion( + pn.Column( + pn.pane.HTML(plot), + name=f"Tree index {tree.index} (position {pos1} - {pos2})", + ), + active=[0], + ) + else: + return pn.Column( + pn.pane.HTML( + f"

Tree index {tree.index}" + f" (position {pos1} - {pos2})

", + sizing_mode="stretch_width", + ), + pn.pane.HTML(plot), + ) + + def get_all_trees(self, trees: list) -> Union[None, pn.Column]: + """Returns all trees in columns and rows. + + Arguments: + trees: A list of all trees to be displayed. + + Returns: + Union[None, pn.Column]: A column of rows with trees, + if there are any trees to display. + """ + if not trees: + return None + rows = [pn.Row(*trees[i : i + 2]) for i in range(0, len(trees), 2)] + return pn.Column(*rows) + + @param.depends("num_trees.value", watch=True) + def multiple_trees(self): + """Sets the default setting depending on if one or several trees are + displayed.""" + if int(self.num_trees.value) > 1: + self.width = 470 + self.height = 470 + self.y_axis.value = False + self.x_axis.value = False + self.y_ticks.value = False + self.sites_mutations.value = False + self.pack_unselected.value = True + self.symbol_size = 6 + else: + self.width = 750 + self.height = 520 + self.y_axis.value = True + self.x_axis.value = False + self.y_ticks.value = True + self.sites_mutations.value = True + self.pack_unselected.value = False + self.symbol_size = 8 + + def advanced_options(self): + """Defined the content of the advanced options card in the sidebar.""" + doc_link = ( + "https://tskit.dev/tskit/docs/" + "stable/python-api.html#tskit.TreeSequence.draw_svg" + ) + sidebar_content = pn.Column( + pn.Card( + pn.pane.HTML( + f"""See the + tskit documentation for more information + about these plotting options.""" + ), + self.num_trees, + pn.Row(pn.pane.HTML("Options", width=30), self.options_doc), + self.x_axis, + self.y_axis, + self.y_ticks, + self.sites_mutations, + self.pack_unselected, + self.param.symbol_size, + self.param.node_labels, + self.param.additional_options, + collapsed=True, + title="Advanced plotting options", + header_background=config.SIDEBAR_BACKGROUND, + active_header_background=config.SIDEBAR_BACKGROUND, + styles=config.VCARD_STYLE, + ), + self.advanced_warning, + ) + return sidebar_content + + 
@param.depends( + "width", + "height", + "position", + "symbol_size", + "tree_index", + "num_trees.value", + "y_axis.value", + "y_ticks.value", + "x_axis.value", + "sites_mutations.value", + "pack_unselected.value", + "node_labels", + "additional_options", + "slider.value_throttled", + ) + def __panel__(self) -> pn.Column: + """Returns the main content of the Trees page. + + Returns: + pn.Column: The layout for the main content area. + + Raises: + ValueError: If inputs are not in the correct format + """ + try: + self.check_inputs() + except ValueError: + self.position_index_warning.visible = True + raise ValueError("Inputs for position or tree index are not valid") + + sample_sets = self.datastore.individuals_table.sample_sets() + selected_samples = [ + int(i) for sublist in list(sample_sets.values()) for i in sublist + ] + if len(selected_samples) < 1: + self.pack_unselected.value = False + self.pack_unselected.disabled = True + else: + self.pack_unselected.disabled = False + omit_sites, y_ticks = self.handle_advanced() + try: + node_labels = eval_options(self.node_labels) + additional_options = eval_options(self.additional_options) + except (ValueError, SyntaxError, TypeError): + node_labels = None + additional_options = None + self.advanced_warning.visible = True + trees = [] + for i in range(self.num_trees.value): + if self.position is not None: + tree = self.datastore.tsm.ts.at(self.position) + self.tree_index = tree.index + tree = self.datastore.tsm.ts.at_index( + (tree.index + i), tracked_samples=selected_samples + ) + else: + tree = self.datastore.tsm.ts.at_index( + int(self.tree_index) + i, tracked_samples=selected_samples + ) + self.slider.value = int(tree.get_interval()[0]) + trees.append( + self.plot_tree( + tree, omit_sites, y_ticks, node_labels, additional_options + ) + ) + all_trees = self.get_all_trees(trees) return pn.Column( + all_trees, pn.pane.Markdown( - f"## Tree index {self.tree_index} (position {pos1} - {pos2})" - ), - pn.pane.HTML( - 
tree.draw_svg( - size=(self.width, self.height), - symbol_size=self.symbol_size, - style=self.default_css, - **options, - ), + """**Tree plot** - Lorem Ipsum... + Selected samples are marked with a black outline.""" ), + self.slider, pn.Row( self.param.prev, self.param.next, ), ) - def update_sidebar(self): - """Dynamically update the sidebar based on searchBy value.""" + def update_sidebar(self) -> pn.Column: + """Renders the content of the sidebar based on searchBy value. + + Returns: + pn.Column: The sidebar content. + """ if self.search_by.value == "Tree Index": self.position = None fields = [self.param.tree_index] @@ -149,24 +536,42 @@ def update_sidebar(self): *fields, self.param.width, self.param.height, - self.param.options, - self.param.symbol_size, collapsed=False, title="Tree plotting options", header_background=config.SIDEBAR_BACKGROUND, active_header_background=config.SIDEBAR_BACKGROUND, styles=config.VCARD_STYLE, ), - self.warning_pane, + self.position_index_warning, ) return sidebar_content @param.depends("search_by.value", watch=True) - def sidebar(self): + def sidebar(self) -> pn.Column: + """Makes sure the sidebar is updated whenever the search-by value is + toggled. + + Returns: + pn.Column: The sidebar content. + """ return self.update_sidebar() class TreesPage(View): + """Represents the trees page of the tseda application. + + Attributes: + key (str): A unique identifier for this view within the application. + title (str): The title displayed on the page. + data (param.ClassSelector): The main content of the page. + + Methods: + __init__(self, **params): Initializes the `TreesPage` class with + provided parameters. + __panel__() -> pn.Column: Defines the layout of the main content area. + sidebar() -> pn.Column: Defines the layout of the sidebar content area. 
+ """ + key = "trees" title = "Trees" data = param.ClassSelector(class_=Tree) @@ -177,12 +582,36 @@ def __init__(self, **params): self.sample_sets = self.datastore.sample_sets_table def __panel__(self): + """Returns the main content of the page. + + Returns: + pn.Column: The layout for the main content area. + """ return pn.Column( self.data, ) def sidebar(self): + """Returns the content of the sidebar. + + Returns: + pn.Column: The layout for the sidebar. + """ return pn.Column( + pn.pane.HTML( + "

Trees

", + sizing_mode="stretch_width", + ), + pn.pane.Markdown( + ( + "This section allows you to explore local genealogical " + "trees.

" + "Use the controls below to customize the plots and adjust" + "parameters." + ), + sizing_mode="stretch_width", + ), self.data.sidebar, + self.data.advanced_options, self.sample_sets.sidebar_table, ) diff --git a/tests/.DS_Store b/tests/.DS_Store new file mode 100644 index 00000000..7fbe4a19 Binary files /dev/null and b/tests/.DS_Store differ diff --git a/tests/test_datastore.py b/tests/test_datastore.py index 4f58cb37..df43b563 100644 --- a/tests/test_datastore.py +++ b/tests/test_datastore.py @@ -14,7 +14,7 @@ def test_datastore_preprocess(tsm): individuals_table, sample_sets_table = datastore.preprocess(tsm) assert individuals_table is not None assert sample_sets_table is not None - samples, sample_sets = individuals_table.sample_sets() + sample_sets = individuals_table.sample_sets() assert len(sample_sets) == 6 np.testing.assert_equal(sample_sets[1], np.arange(0, 12)) @@ -28,7 +28,7 @@ def test_individuals_table(individuals_table): assert ind.name == 5 assert individuals_table.sample2ind[ind.nodes[0]] == 5 assert individuals_table.sample2ind[ind.nodes[1]] == 5 - _, ss = individuals_table.sample_sets() + ss = individuals_table.sample_sets() assert len(ss) == 6 assert len(ss[0]) == 12 samples = list(individuals_table.samples()) diff --git a/tests/test_model.py b/tests/test_model.py index 90e78ba5..9c96ff7e 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -21,12 +21,12 @@ def test_sample_set_init(ts): for pop in ts.populations(): ss = model.SampleSet(pop.id, population=pop) assert ss is not None - assert ss.id == pop.id + assert ss.sample_set_id == pop.id assert ss.name == json.loads(pop.metadata.decode())["population"] assert ss.color == ss.colormap[pop.id] ss = model.SampleSet(0, name="test") assert ss is not None - assert ss.id == 0 + assert ss.sample_set_id == 0 assert ss.name == "test" assert ss.population is None diff --git a/tests/test_trees.py b/tests/test_trees.py index e292193f..c691eda1 100644 --- a/tests/test_trees.py +++ 
b/tests/test_trees.py @@ -19,4 +19,7 @@ def test_treespage(treespage): def test_tree(tree): - assert ".node.n26 > .sym {fill: #e4ae38 }" in tree.default_css + assert ( + ".node.n26 > .sym {fill: #e4ae38}" in tree.default_css + or ".node.n26 > .sym {fill: #e4ae38}; stroke: black; stroke-width: 2px;" + ) diff --git a/tests/test_ui.py b/tests/test_ui.py index 60359d5e..a0aacc18 100644 --- a/tests/test_ui.py +++ b/tests/test_ui.py @@ -23,24 +23,29 @@ def test_component(page, port, ds): page.set_viewport_size({"width": 1920, "height": 1080}) - page.get_by_role("button", name="Sample Sets").click() - expect(page.get_by_text("New sample set name")).to_be_visible() - expect(page.get_by_text("predefined")).to_be_visible() - - page.get_by_role("button", name="Individuals").click() - expect(page.get_by_text("Data modification")).to_be_visible() - expect(page.get_by_text("Population ID")).to_be_visible() + page.get_by_role("button", name="Individuals & sets").click() + time.sleep(10) + expect(page.get_by_text("Geomap").nth(0)).to_be_visible() + expect(page.get_by_text("Original population ID").nth(0)).to_be_visible() + expect(page.get_by_text("Create new sample set").nth(0)).to_be_visible() page.get_by_role("button", name="Structure").click() - expect(page.get_by_text("GNN cluster plot")).to_be_visible() + time.sleep(10) + expect(page.get_by_text("GNN cluster plot").nth(0)).to_be_visible() + expect(page.get_by_text("Structure").nth(0)).to_be_visible() page.get_by_role("button", name="iGNN").click() - expect(page.get_by_text("Sample sets table quick view")).to_be_visible() + time.sleep(10) + expect( + page.get_by_text("Sample sets table quick view").nth(0) + ).to_be_visible() page.get_by_role("button", name="Statistics").click() + time.sleep(10) expect( - page.get_by_text("Oneway statistics plotting options") + page.get_by_text("Oneway statistics plotting options").nth(0) ).to_be_visible() page.get_by_role("button", name="Trees").click() - expect(page.get_by_text("Tree plotting 
options")).to_be_visible() + time.sleep(10) + expect(page.get_by_text("Tree plotting options").nth(0)).to_be_visible()