1- from typing import List , Union , Optional , Dict , MutableMapping , Any , Set
1+ from typing import List , Union , Optional , Dict , MutableMapping , Any , Set , Tuple
22import itertools
33import math
44import numbers
@@ -1239,15 +1239,35 @@ def __attrs_post_init__(self):
12391239 'if time_units!="generations", generation_time must be specified'
12401240 )
12411241
1242- def __getitem__ (self , deme_name ) :
1242+ def __getitem__ (self , deme_name : Name ) -> Deme :
12431243 """
1244- Return the :class:`.Deme` with the specified name.
1244+ Get the :class:`.Deme` with the specified name.
1245+
1246+ .. code::
1247+
1248+ graph = demes.load("gutenkunst_ooa.yml")
1249+ yri = graph["YRI"]
1250+ print(yri)
1251+
1252+ :param str deme_name: The name of the deme.
1253+ :rtype: Deme
1254+ :return: The deme.
12451255 """
12461256 return self ._deme_map [deme_name ]
12471257
1248- def __contains__ (self , deme_name ) :
1258+ def __contains__ (self , deme_name : Name ) -> bool :
12491259 """
12501260 Check if the graph contains a deme with the specified name.
1261+
1262+ .. code::
1263+
1264+ graph = demes.load("gutenkunst_ooa.yml")
1265+ if "CHB" in graph:
1266+ print("Deme CHB is in the graph")
1267+
1268+ :param str deme_name: The name of the deme.
1269+ :rtype: bool
1270+ :return: ``True`` if the deme is in the graph, ``False`` otherwise.
12511271 """
12521272 return deme_name in self ._deme_map
12531273
@@ -1611,10 +1631,52 @@ def _add_pulse(self, *, source, dest, proportion, time) -> Pulse:
16111631 self .pulses .append (new_pulse )
16121632 return new_pulse
16131633
1614- def _migration_matrices (self ):
1634+ def migration_matrices (self ) -> Tuple [ List [ List [ List [ float ]]], List [ Number ]] :
16151635 """
1616- Return a list of migration matrices, and a list of end times that
1617- partition them. The start time for the first matrix is inf.
1636+ Get the migration matrices and the end times that partition them.
1637+
1638+ Returns a list of matrices, one for each time interval
1639+ over which migration rates do not change, in time-descending
1640+ order (from most ancient to most recent). For a migration matrix list
1641+ :math:`M`, the migration rate is :math:`M[i][j][k]` from deme
1642+ :math:`k` into deme :math:`j` during the :math:`i` 'th time interval.
1643+ The order of the demes' indices in each matrix matches the
1644+ order of demes in the graph's deme list (I.e. deme :math:`j`
1645+ corresponds to ``Graph.demes[j]``).
1646+
1647+ There is always at least one migration matrix in the list, even when
1648+ the graph defines no migrations.
1649+
1650+ A list of end times to which the matrices apply is also
1651+ returned. The time intervals to which the migration rates apply are an
1652+ open-closed interval ``(start_time, end_time]``, where the start time
1653+ of the first matrix is ``inf`` and the start time of subsequent
1654+ matrices match the end time of the previous matrix in the list.
1655+
1656+ .. note::
1657+ The last entry of the list of end times is always ``0``,
1658+ even when all demes in the graph go extinct before time ``0``.
1659+
1660+
1661+ .. code::
1662+
1663+ graph = demes.load("gutenkunst_ooa.yml")
1664+ mm_list, end_times = graph.migration_matrices()
1665+ start_times = [math.inf] + end_times[:-1]
1666+ assert len(mm_list) == len(end_times) == len(start_times)
1667+ deme_ids = {deme.name: j for j, deme in enumerate(graph.demes)}
1668+ j = deme_ids["YRI"]
1669+ k = deme_ids["CEU"]
1670+ for mm, start_time, end_time in zip(mm_list, start_times, end_times):
1671+ print(
1672+ f"CEU -> YRI migration rate is {mm[j][k]} during the "
1673+ f"time interval ({start_time}, {end_time}]"
1674+ )
1675+
1676+ :return: A 2-tuple of ``(mm_list, end_times)``,
1677+ where ``mm_list`` is a list of migration matrices,
1678+ and ``end_times`` are a list of end times for each matrix.
1679+ :rtype: tuple[list[list[list[float]]], list[float]]
16181680 """
16191681 uniq_times = set (migration .start_time for migration in self .migrations )
16201682 uniq_times .update (migration .end_time for migration in self .migrations )
@@ -1624,7 +1686,7 @@ def _migration_matrices(self):
16241686 # Extend to t=0 even when there are no migrations.
16251687 end_times .append (0 )
16261688 n = len (self .demes )
1627- mm_list = [[[0 ] * n for _ in range (n )] for _ in range (len (end_times ))]
1689+ mm_list = [[[0.0 ] * n for _ in range (n )] for _ in range (len (end_times ))]
16281690 deme_id = {deme .name : j for j , deme in enumerate (self .demes )}
16291691 for migration in self .migrations :
16301692 start_time = math .inf
@@ -1640,7 +1702,7 @@ def _migration_matrices(self):
16401702 f"source={ migration .source } , dest={ migration .dest } "
16411703 f"between start_time={ start_time } , end_time={ end_time } "
16421704 )
1643- mm_list [k ][dest_id ][source_id ] = migration .rate
1705+ mm_list [k ][dest_id ][source_id ] = float ( migration .rate )
16441706 start_time = end_time
16451707 return mm_list , end_times
16461708
@@ -1650,7 +1712,7 @@ def _check_migration_rates(self):
16501712 deme in any interval of time.
16511713 """
16521714 start_time = math .inf
1653- mm_list , end_times = self ._migration_matrices ()
1715+ mm_list , end_times = self .migration_matrices ()
16541716 for migration_matrix , end_time in zip (mm_list , end_times ):
16551717 for j , row in enumerate (migration_matrix ):
16561718 row_sum = sum (row )
0 commit comments