-
Notifications
You must be signed in to change notification settings - Fork 1
Optimize kscore with reusable bitset #21
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||
|---|---|---|---|---|---|---|---|---|---|---|
| @@ -1,8 +1,89 @@ | ||||||||||
| from typing import Dict, Optional | ||||||||||
|
|
||||||||||
|
|
||||||||||
| def kmers(seq, k=4): | ||||||||||
| n_kmers = len(seq) - k + 1 | ||||||||||
| for i in range(n_kmers): | ||||||||||
| yield seq[i : (i + k)] | ||||||||||
|
|
||||||||||
|
|
||||||||||
| _BASE_TO_BITS: Dict[str, int] = { | ||||||||||
| "A": 0, | ||||||||||
| "C": 1, | ||||||||||
| "G": 2, | ||||||||||
| "T": 3, | ||||||||||
| "a": 0, | ||||||||||
| "c": 1, | ||||||||||
| "g": 2, | ||||||||||
| "t": 3, | ||||||||||
| } | ||||||||||
|
|
||||||||||
|
|
||||||||||
| _BITSET_CACHE: Dict[int, bytearray] = {} | ||||||||||
| _MAX_BITSET_SIZE = 1 << 22 # 4 Mi entries (4 MiB) per cached bitset | ||||||||||
|
|
||||||||||
|
|
||||||||||
| def _kscore_fallback(seq: str, k: int) -> float: | ||||||||||
| seq_len = len(seq) | ||||||||||
| if seq_len == 0: | ||||||||||
| return 0.0 | ||||||||||
| return len(set(kmers(seq, k=k))) / seq_len | ||||||||||
|
|
||||||||||
|
|
||||||||||
| def _bitset_for_k(k: int) -> Optional[bytearray]: | ||||||||||
| size = 1 << (2 * k) | ||||||||||
| if size > _MAX_BITSET_SIZE: | ||||||||||
| return None | ||||||||||
|
|
||||||||||
| bitset = _BITSET_CACHE.get(k) | ||||||||||
| if bitset is None or len(bitset) != size: | ||||||||||
| bitset = bytearray(size) | ||||||||||
| _BITSET_CACHE[k] = bitset | ||||||||||
| return bitset | ||||||||||
|
|
||||||||||
|
|
||||||||||
| def kscore(seq, k=4): | ||||||||||
| return len(set(kmers(seq))) / len(seq) | ||||||||||
| seq_len = len(seq) | ||||||||||
| if seq_len == 0: | ||||||||||
| return 0.0 | ||||||||||
|
|
||||||||||
| if k <= 0: | ||||||||||
| return _kscore_fallback(seq, k) | ||||||||||
|
|
||||||||||
| if seq_len < k: | ||||||||||
| return 0.0 | ||||||||||
|
|
||||||||||
| bitset = _bitset_for_k(k) | ||||||||||
| if bitset is None: | ||||||||||
| return _kscore_fallback(seq, k) | ||||||||||
|
|
||||||||||
| mask = (1 << (2 * k)) - 1 | ||||||||||
| touched = [] | ||||||||||
| unique = 0 | ||||||||||
| rolling_value = 0 | ||||||||||
| window_len = 0 | ||||||||||
| use_fallback = False | ||||||||||
|
|
||||||||||
| for base in seq: | ||||||||||
| bits = _BASE_TO_BITS.get(base) | ||||||||||
| if bits is None: | ||||||||||
| use_fallback = True | ||||||||||
| break | ||||||||||
|
|
||||||||||
|
||||||||||
| # Shift the previous rolling value left by 2 bits (to make room for the new base), | |
| # OR in the new base bits, and apply the mask to keep only k bases. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The magic number calculation `(1 << (2 * k)) - 1` lacks explanation. Consider adding a comment explaining that this creates a bitmask for k bases where each base uses 2 bits (e.g., for k=4, this creates 0xFF to keep the last 8 bits).