From 1904b9463cb10cbfd03b053bdfc19bb4cfa2b1b6 Mon Sep 17 00:00:00 2001 From: Xunmo Yang Date: Fri, 15 May 2026 10:49:58 -0700 Subject: [PATCH] No public description PiperOrigin-RevId: 916081867 --- meterstick_custom_metrics.ipynb | 2671 ++++++++++++++++--------------- metrics.py | 911 +++++++---- models.py | 1 + operations.py | 1145 +++++++++---- sql.py | 373 +++-- utils.py | 131 +- 6 files changed, 3022 insertions(+), 2210 deletions(-) diff --git a/meterstick_custom_metrics.ipynb b/meterstick_custom_metrics.ipynb index 8fc08f9..5257fe3 100644 --- a/meterstick_custom_metrics.ipynb +++ b/meterstick_custom_metrics.ipynb @@ -1,100 +1,127 @@ { "cells": [ { + "id": "28c0e400", "cell_type": "markdown", - "metadata": { - "id": "fbocqoWfpVPg" - }, "source": [ "The Meterstick package provides a concise and flexible syntax to describe and execute\n", "routine data analysis tasks. This notebooks explains how to implement custom `Metric`s or `Operation`s in Meterstick." - ] + ], + "metadata": { + "id": "fbocqoWfpVPg" + } }, { + "id": "df5ace76", "cell_type": "markdown", - "metadata": { - "id": "1KBTLgUrzmS7" - }, "source": [ "# For External users\n", "\n", "You can open this notebook in [Google Colab](https://colab.research.google.com/github/google/meterstick/blob/master/meterstick_custom_metrics.ipynb)." - ] + ], + "metadata": { + "id": "1KBTLgUrzmS7" + } }, { + "id": "71b48521", "cell_type": "markdown", - "metadata": { - "id": "H9ojnghz0b2N" - }, "source": [ "## Installation\n", "\n", "You can install from pip for the stable version" - ] + ], + "metadata": { + "id": "H9ojnghz0b2N" + } }, { + "id": "21b471d0", "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "ohmnh0qRz6bS" - }, - "outputs": [], "source": [ "!pip install meterstick" - ] + ], + "metadata": { + "id": "ohmnh0qRz6bS" + }, + "execution_count": null, + "outputs": [] }, { + "id": "dc686842", "cell_type": "markdown", - "metadata": { - "id": "MZXKtCHy0CEo" - }, "source": [ "or from GitHub for the latest version." - ] + ], + "metadata": { + "id": "MZXKtCHy0CEo" + } }, { + "id": "ee246255", "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "uQRaNJ2h0NvF" - }, - "outputs": [], "source": [ "!git clone https://github.com/google/meterstick.git\n", "import sys, os\n", "sys.path.append(os.getcwd())" - ] + ], + "metadata": { + "id": "uQRaNJ2h0NvF" + }, + "execution_count": null, + "outputs": [] }, { + "id": "59c850a2", "cell_type": "markdown", - "metadata": { - "id": "te-lKCw20P41" - }, "source": [ "# Setup" - ] + ], + "metadata": { + "id": "te-lKCw20P41" + } }, { + "id": "e1ec9e64", "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "0UI9rAtZnBUG" - }, - "outputs": [], "source": [ "import matplotlib.pyplot as plt\n", "import numpy as np\n", "import pandas as pd\n", "import scipy\n", + "from typing import List, Optional, Text, Union\n", "\n", "from meterstick import *" - ] + ], + "metadata": { + "id": "0UI9rAtZnBUG" + }, + "execution_count": null, + "outputs": [] }, { + "id": "93bdd936", "cell_type": "code", - "execution_count": null, + "source": [ + "np.random.seed(0)\n", + "platform = ('Desktop', 'Mobile', 'Tablet')\n", + "exprs = ('ctrl', 'expr')\n", + "country = ('US', 'non-US')\n", + "size = 80\n", + "impressions = np.random.randint(10, 20, size)\n", + "clicks = impressions * 0.1 * np.random.random(size)\n", + "df = pd.DataFrame({'impressions': impressions, 'clicks': clicks})\n", + "df['platform'] = np.random.choice(platform, size=size)\n", + "df['expr_id'] = np.random.choice(exprs, size=size)\n", + "df['country'] = np.random.choice(country, size=size)\n", + "df['cookie'] = np.random.choice(range(3), size=size)\n", + "\n", + "df.loc[df.country == 'US', 'clicks'] *= 2\n", + "df.loc[(df.country == 'US') \u0026 (df.platform == 'Desktop'), 'impressions'] *= 4\n", + "df.head()" + ], "metadata": { "colab": { "height": 206 @@ -112,15 +139,16 @@ "id": "AqPQnAdJU2wd", "outputId": "18641488-870f-48bc-c032-c288f534c7e2" }, + "execution_count": null, "outputs": [ { "data": { "text/html": [ "\n", - "
\n", - "
\n", - "
\n", - "\n", - "\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
impressionsclicksplatformexpr_idcountrycookie
0152.985899MobileexprUS0
1101.163701TabletctrlUS1
2521.077358DesktopexprUS2
3521.234214DesktopexprUS0
4171.059967Mobilectrlnon-US0
\n", - "
\n", - " \n", + " \u003csvg xmlns=\"http://www.w3.org/2000/svg\" height=\"24px\"viewBox=\"0 0 24 24\"\n", + " width=\"24px\"\u003e\n", + " \u003cpath d=\"M0 0h24v24H0V0z\" fill=\"none\"/\u003e\n", + " \u003cpath d=\"M18.56 5.44l.94 2.06.94-2.06 2.06-.94-2.06-.94-.94-2.06-.94 2.06-2.06.94zm-11 1L8.5 8.5l.94-2.06 2.06-.94-2.06-.94L8.5 2.5l-.94 2.06-2.06.94zm10 10l.94 2.06.94-2.06 2.06-.94-2.06-.94-.94-2.06-.94 2.06-2.06.94z\"/\u003e\u003cpath d=\"M17.41 7.96l-1.37-1.37c-.4-.4-.92-.59-1.43-.59-.52 0-1.04.2-1.43.59L10.3 9.45l-7.72 7.72c-.78.78-.78 2.05 0 2.83L4 21.41c.39.39.9.59 1.41.59.51 0 1.02-.2 1.41-.59l7.78-7.78 2.81-2.81c.8-.78.8-2.07 0-2.86zM5.41 20L4 18.59l7.72-7.72 1.47 1.35L5.41 20z\"/\u003e\n", + " \u003c/svg\u003e\n", + " \u003c/button\u003e\n", " \n", " \n", - "
\n", - " \n", - "
\n", + " \u003csvg xmlns=\"http://www.w3.org/2000/svg\" height=\"24px\"viewBox=\"0 0 24 24\"\n", + " width=\"24px\"\u003e\n", + " \u003cg\u003e\n", + " \u003cpath d=\"M19 3H5c-1.1 0-2 .9-2 2v14c0 1.1.9 2 2 2h14c1.1 0 2-.9 2-2V5c0-1.1-.9-2-2-2zM9 17H7v-7h2v7zm4 0h-2V7h2v10zm4 0h-2v-4h2v4z\"/\u003e\n", + " \u003c/g\u003e\n", + " \u003c/svg\u003e\n", + " \u003c/button\u003e\n", + " \u003c/div\u003e\n", " \n", - " \n", + " \u003c/style\u003e\n", "\n", - " \n", - " \n", + " \u003c/style\u003e\n", "\n", - " \n", - "
\n", - "
\n", + " \u003c/script\u003e\n", + " \u003c/div\u003e\n", + " \u003c/div\u003e\n", " " ], "text/plain": [ @@ -343,31 +371,11 @@ "metadata": {}, "output_type": "execute_result" } - ], - "source": [ - "np.random.seed(0)\n", - "platform = ('Desktop', 'Mobile', 'Tablet')\n", - "exprs = ('ctrl', 'expr')\n", - "country = ('US', 'non-US')\n", - "size = 80\n", - "impressions = np.random.randint(10, 20, size)\n", - "clicks = impressions * 0.1 * np.random.random(size)\n", - "df = pd.DataFrame({'impressions': impressions, 'clicks': clicks})\n", - "df['platform'] = np.random.choice(platform, size=size)\n", - "df['expr_id'] = np.random.choice(exprs, size=size)\n", - "df['country'] = np.random.choice(country, size=size)\n", - "df['cookie'] = np.random.choice(range(3), size=size)\n", - "\n", - "df.loc[df.country == 'US', 'clicks'] *= 2\n", - "df.loc[(df.country == 'US') & (df.platform == 'Desktop'), 'impressions'] *= 4\n", - "df.head()" ] }, { + "id": "34e4b138", "cell_type": "markdown", - "metadata": { - "id": "YVuQS1w4JGOX" - }, "source": [ "# Level of Caching\n", "\n", @@ -376,11 +384,28 @@ "1. Class level: different `Metric` classes don't share the same result while different instances of the same class might, *if and only if they have the same attributes that matters*.\n", "\n", "Class level caching is a bit tricky so by default it's only enabled for built-in `Metric`s, and custom `Metric`s only enjoy instance level caching. If you want to enable class level caching on a custom `Metric`, you need to manually set `cache_across_instances` to `True` and register all attributes that have impat on the result into `additional_fingerprint_attrs`." - ] + ], + "metadata": { + "id": "YVuQS1w4JGOX" + } }, { + "id": "fe6a864a", "cell_type": "code", - "execution_count": null, + "source": [ + "# SumWithTrace has instance level caching.\n", + "class SumWithTrace(Sum):\n", + " def __init__(self, *args, **kwargs):\n", + " super(SumWithTrace, self).__init__(*args, **kwargs)\n", + "\n", + " def compute_through(self, data, split_by):\n", + " print('Computing %s...' % self.name)\n", + " return super(SumWithTrace, self).compute_through(data, split_by)\n", + "\n", + "sum_clicks = SumWithTrace('clicks')\n", + "ctr = sum_clicks / SumWithTrace('impressions')\n", + "MetricList((sum_clicks, ctr)).compute_on(df)" + ], "metadata": { "colab": { "height": 115 @@ -398,6 +423,7 @@ "id": "bGfZTsiUZk2k", "outputId": "6c8194b7-7edf-45d3-9dc1-9ebce28a8363" }, + "execution_count": null, "outputs": [ { "name": "stdout", @@ -411,10 +437,10 @@ "data": { "text/html": [ "\n", - "
\n", - "
\n", - "
\n", - "\n", - "\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
sum(clicks)sum(clicks) / sum(impressions)
097.2493680.056871
\n", - "
\n", - " \n", + " \u003csvg xmlns=\"http://www.w3.org/2000/svg\" height=\"24px\"viewBox=\"0 0 24 24\"\n", + " width=\"24px\"\u003e\n", + " \u003cpath d=\"M0 0h24v24H0V0z\" fill=\"none\"/\u003e\n", + " \u003cpath d=\"M18.56 5.44l.94 2.06.94-2.06 2.06-.94-2.06-.94-.94-2.06-.94 2.06-2.06.94zm-11 1L8.5 8.5l.94-2.06 2.06-.94-2.06-.94L8.5 2.5l-.94 2.06-2.06.94zm10 10l.94 2.06.94-2.06 2.06-.94-2.06-.94-.94-2.06-.94 2.06-2.06.94z\"/\u003e\u003cpath d=\"M17.41 7.96l-1.37-1.37c-.4-.4-.92-.59-1.43-.59-.52 0-1.04.2-1.43.59L10.3 9.45l-7.72 7.72c-.78.78-.78 2.05 0 2.83L4 21.41c.39.39.9.59 1.41.59.51 0 1.02-.2 1.41-.59l7.78-7.78 2.81-2.81c.8-.78.8-2.07 0-2.86zM5.41 20L4 18.59l7.72-7.72 1.47 1.35L5.41 20z\"/\u003e\n", + " \u003c/svg\u003e\n", + " \u003c/button\u003e\n", " \n", " \n", - "
\n", - " \n", - "
\n", + " \u003csvg xmlns=\"http://www.w3.org/2000/svg\" height=\"24px\"viewBox=\"0 0 24 24\"\n", + " width=\"24px\"\u003e\n", + " \u003cg\u003e\n", + " \u003cpath d=\"M19 3H5c-1.1 0-2 .9-2 2v14c0 1.1.9 2 2 2h14c1.1 0 2-.9 2-2V5c0-1.1-.9-2-2-2zM9 17H7v-7h2v7zm4 0h-2V7h2v10zm4 0h-2v-4h2v4z\"/\u003e\n", + " \u003c/g\u003e\n", + " \u003c/svg\u003e\n", + " \u003c/button\u003e\n", + " \u003c/div\u003e\n", " \n", - " \n", + " \u003c/style\u003e\n", "\n", - " \n", - " \n", + " \u003c/style\u003e\n", "\n", - " \n", - "
\n", - "
\n", + " \u003c/script\u003e\n", + " \u003c/div\u003e\n", + " \u003c/div\u003e\n", " " ], "text/plain": [ @@ -589,25 +615,18 @@ "metadata": {}, "output_type": "execute_result" } - ], - "source": [ - "# SumWithTrace has instance level caching.\n", - "class SumWithTrace(Sum):\n", - " def __init__(self, *args, **kwargs):\n", - " super(SumWithTrace, self).__init__(*args, **kwargs)\n", - "\n", - " def compute_through(self, data, split_by):\n", - " print('Computing %s...' % self.name)\n", - " return super(SumWithTrace, self).compute_through(data, split_by)\n", - "\n", - "sum_clicks = SumWithTrace('clicks')\n", - "ctr = sum_clicks / SumWithTrace('impressions')\n", - "MetricList((sum_clicks, ctr)).compute_on(df)" ] }, { + "id": "385ad42d", "cell_type": "code", - "execution_count": null, + "source": [ + "# SumWithTrace doesn't have class level caching. Below sum of clicks is computed\n", + "# twice because we have two instances of it.\n", + "sum_clicks = SumWithTrace('clicks')\n", + "ctr = SumWithTrace('clicks') / SumWithTrace('impressions')\n", + "MetricList((sum_clicks, ctr)).compute_on(df)" + ], "metadata": { "colab": { "height": 133 @@ -625,6 +644,7 @@ "id": "QGewAuqaVS6y", "outputId": "7e0d9999-d6ba-4d17-f137-e258c4ff17ba" }, + "execution_count": null, "outputs": [ { "name": "stdout", @@ -639,10 +659,10 @@ "data": { "text/html": [ "\n", - "
\n", - "
\n", - "
\n", - "\n", - "\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
sum(clicks)sum(clicks) / sum(impressions)
097.2493680.056871
\n", - "
\n", - " \n", + " \u003csvg xmlns=\"http://www.w3.org/2000/svg\" height=\"24px\"viewBox=\"0 0 24 24\"\n", + " width=\"24px\"\u003e\n", + " \u003cpath d=\"M0 0h24v24H0V0z\" fill=\"none\"/\u003e\n", + " \u003cpath d=\"M18.56 5.44l.94 2.06.94-2.06 2.06-.94-2.06-.94-.94-2.06-.94 2.06-2.06.94zm-11 1L8.5 8.5l.94-2.06 2.06-.94-2.06-.94L8.5 2.5l-.94 2.06-2.06.94zm10 10l.94 2.06.94-2.06 2.06-.94-2.06-.94-.94-2.06-.94 2.06-2.06.94z\"/\u003e\u003cpath d=\"M17.41 7.96l-1.37-1.37c-.4-.4-.92-.59-1.43-.59-.52 0-1.04.2-1.43.59L10.3 9.45l-7.72 7.72c-.78.78-.78 2.05 0 2.83L4 21.41c.39.39.9.59 1.41.59.51 0 1.02-.2 1.41-.59l7.78-7.78 2.81-2.81c.8-.78.8-2.07 0-2.86zM5.41 20L4 18.59l7.72-7.72 1.47 1.35L5.41 20z\"/\u003e\n", + " \u003c/svg\u003e\n", + " \u003c/button\u003e\n", " \n", " \n", - "
\n", - " \n", - "
\n", + " \u003csvg xmlns=\"http://www.w3.org/2000/svg\" height=\"24px\"viewBox=\"0 0 24 24\"\n", + " width=\"24px\"\u003e\n", + " \u003cg\u003e\n", + " \u003cpath d=\"M19 3H5c-1.1 0-2 .9-2 2v14c0 1.1.9 2 2 2h14c1.1 0 2-.9 2-2V5c0-1.1-.9-2-2-2zM9 17H7v-7h2v7zm4 0h-2V7h2v10zm4 0h-2v-4h2v4z\"/\u003e\n", + " \u003c/g\u003e\n", + " \u003c/svg\u003e\n", + " \u003c/button\u003e\n", + " \u003c/div\u003e\n", " \n", - " \n", + " \u003c/style\u003e\n", "\n", - " \n", - " \n", + " \u003c/style\u003e\n", "\n", - " \n", - "
\n", - "
\n", + " \u003c/script\u003e\n", + " \u003c/div\u003e\n", + " \u003c/div\u003e\n", " " ], "text/plain": [ @@ -817,18 +837,36 @@ "metadata": {}, "output_type": "execute_result" } - ], - "source": [ - "# SumWithTrace doesn't have class level caching. Below sum of clicks is computed\n", - "# twice because we have two instances of it.\n", - "sum_clicks = SumWithTrace('clicks')\n", - "ctr = SumWithTrace('clicks') / SumWithTrace('impressions')\n", - "MetricList((sum_clicks, ctr)).compute_on(df)" ] }, { + "id": "6321b529", "cell_type": "code", - "execution_count": null, + "source": [ + "# With class level caching enabled, SumWithClassLevelCaching('clicks') will only\n", + "# be computed once.\n", + "class SumWithClassLevelCaching(Metric):\n", + "\n", + " def __init__(self, var):\n", + " self.var = var\n", + " # Register all attributes that have impact to the result number.\n", + " super(SumWithClassLevelCaching, self).__init__(\n", + " name=f'Sum of {var}', additional_fingerprint_attrs=['var']\n", + " )\n", + " # Enable class level caching.\n", + " self.cache_across_instances = True\n", + "\n", + " def compute(self, data):\n", + " print('Computing %s...' % self.name)\n", + " return df[self.var].sum()\n", + "\n", + "\n", + "sum_clicks = SumWithClassLevelCaching('clicks')\n", + "ctr = SumWithClassLevelCaching('clicks') / SumWithClassLevelCaching(\n", + " 'impressions'\n", + ")\n", + "MetricList((sum_clicks, ctr)).compute_on(df)" + ], "metadata": { "colab": { "height": 115 @@ -846,6 +884,7 @@ "id": "oOKLtf22amAR", "outputId": "99d55826-892b-4f19-9c27-d2408cb606c2" }, + "execution_count": null, "outputs": [ { "name": "stdout", @@ -859,10 +898,10 @@ "data": { "text/html": [ "\n", - "
\n", - "
\n", - "
\n", - "\n", - "\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
Sum of clicksSum of clicks / Sum of impressions
097.2493680.056871
\n", - "
\n", - " \n", + " \u003csvg xmlns=\"http://www.w3.org/2000/svg\" height=\"24px\"viewBox=\"0 0 24 24\"\n", + " width=\"24px\"\u003e\n", + " \u003cpath d=\"M0 0h24v24H0V0z\" fill=\"none\"/\u003e\n", + " \u003cpath d=\"M18.56 5.44l.94 2.06.94-2.06 2.06-.94-2.06-.94-.94-2.06-.94 2.06-2.06.94zm-11 1L8.5 8.5l.94-2.06 2.06-.94-2.06-.94L8.5 2.5l-.94 2.06-2.06.94zm10 10l.94 2.06.94-2.06 2.06-.94-2.06-.94-.94-2.06-.94 2.06-2.06.94z\"/\u003e\u003cpath d=\"M17.41 7.96l-1.37-1.37c-.4-.4-.92-.59-1.43-.59-.52 0-1.04.2-1.43.59L10.3 9.45l-7.72 7.72c-.78.78-.78 2.05 0 2.83L4 21.41c.39.39.9.59 1.41.59.51 0 1.02-.2 1.41-.59l7.78-7.78 2.81-2.81c.8-.78.8-2.07 0-2.86zM5.41 20L4 18.59l7.72-7.72 1.47 1.35L5.41 20z\"/\u003e\n", + " \u003c/svg\u003e\n", + " \u003c/button\u003e\n", " \n", " \n", - "
\n", - " \n", - "
\n", + " \u003csvg xmlns=\"http://www.w3.org/2000/svg\" height=\"24px\"viewBox=\"0 0 24 24\"\n", + " width=\"24px\"\u003e\n", + " \u003cg\u003e\n", + " \u003cpath d=\"M19 3H5c-1.1 0-2 .9-2 2v14c0 1.1.9 2 2 2h14c1.1 0 2-.9 2-2V5c0-1.1-.9-2-2-2zM9 17H7v-7h2v7zm4 0h-2V7h2v10zm4 0h-2v-4h2v4z\"/\u003e\n", + " \u003c/g\u003e\n", + " \u003c/svg\u003e\n", + " \u003c/button\u003e\n", + " \u003c/div\u003e\n", " \n", - " \n", + " \u003c/style\u003e\n", "\n", - " \n", - " \n", + " \u003c/style\u003e\n", "\n", - " \n", - "
\n", - "
\n", + " \u003c/script\u003e\n", + " \u003c/div\u003e\n", + " \u003c/div\u003e\n", " " ], "text/plain": [ @@ -1037,36 +1076,15 @@ "metadata": {}, "output_type": "execute_result" } - ], - "source": [ - "# With class level caching enabled, SumWithClassLevelCaching('clicks') will only\n", - "# be computed once.\n", - "class SumWithClassLevelCaching(Metric):\n", - "\n", - " def __init__(self, var):\n", - " self.var = var\n", - " # Register all attributes that have impact to the result number.\n", - " super(SumWithClassLevelCaching, self).__init__(\n", - " name=f'Sum of {var}', additional_fingerprint_attrs=['var']\n", - " )\n", - " # Enable class level caching.\n", - " self.cache_across_instances = True\n", - "\n", - " def compute(self, data):\n", - " print('Computing %s...' % self.name)\n", - " return df[self.var].sum()\n", - "\n", - "\n", - "sum_clicks = SumWithClassLevelCaching('clicks')\n", - "ctr = SumWithClassLevelCaching('clicks') / SumWithClassLevelCaching(\n", - " 'impressions'\n", - ")\n", - "MetricList((sum_clicks, ctr)).compute_on(df)" ] }, { + "id": "57548e42", "cell_type": "code", - "execution_count": null, + "source": [ + "# Check additional_fingerprint_attrs for what attributes are registered.\n", + "sum_clicks.additional_fingerprint_attrs" + ], "metadata": { "executionInfo": { "elapsed": 55, @@ -1081,6 +1099,7 @@ "id": "oqOUyKxERqeu", "outputId": "6c7d9e69-20f0-4d34-85be-7d5965901e4e" }, + "execution_count": null, "outputs": [ { "data": { @@ -1092,28 +1111,22 @@ "metadata": {}, "output_type": "execute_result" } - ], - "source": [ - "# Check additional_fingerprint_attrs for what attributes are registered.\n", - "sum_clicks.additional_fingerprint_attrs" ] }, { + "id": "eeafbd45", "cell_type": "markdown", - "metadata": { - "id": "AQjJAr3YcQB2" - }, "source": [ "#Custom Metric\n", "We provide many Metrics out of box but we understand there are cases you need more, so we make it easy for you to write you own Metrics.\n", "First you need to understand the dataflow of a DataFrame when it's passed to compute_on(). The dataflow looks like this.\n", "\n", - "\t\t <-------------------------------------------compute_on(handles caching)---------------------------------------------->\n", - "\t\t <-------------------------------------compute_through-----------------------------------> |\n", - "\t\t | <------compute_slices------> | |\n", - "\t\t | |-> slice1 -> compute | | | |\n", - "\t\tdf -> df.query(where) -> precompute -> split_data -|-> slice2 -> compute | -> pd.concat -> postcompute -> manipulate -> final_compute\n", - "\t\t |-> ...\t\t\t |\n", + "\t\t \u003c-------------------------------------------compute_on(handles caching)----------------------------------------------\u003e\n", + "\t\t \u003c-------------------------------------compute_through-----------------------------------\u003e |\n", + "\t\t | \u003c------compute_slices------\u003e | |\n", + "\t\t | |-\u003e slice1 -\u003e compute | | | |\n", + "\t\tdf -\u003e df.query(where) -\u003e precompute -\u003e split_data -|-\u003e slice2 -\u003e compute | -\u003e pd.concat -\u003e postcompute -\u003e manipulate -\u003e final_compute\n", + "\t\t |-\u003e ...\t\t\t |\n", "\n", "In summary, compute() operates on a slice of data and hence only takes one arg, df. While precompute(), postcompute(), compute_slices(), compute_through() and final_compute() operate on the whole DataFrame so they take the df that has been processed by the dataflow till them and the split_by passed to compute_on(). final_compute() also has access to the original df passed to compute_on() for you to make additional manipulation. manipulate() does common data manipulation like melting and cleaning. Besides wrapping all the computations above, compute_on() also caches the result from compute_through(). Please refer to the section of Caching for more details.\n", "\n", @@ -1128,20 +1141,41 @@ "Also there are some requirements.\n", "1. Your Metric shouldn't change the input DataFrame inplace or it might not work with other Metrics.\n", "2. Your Metric shouldn't rely on the index of the input DataFrame if you want it to work with Jackknife. The reason is Jackknife might reset the index.\n" - ] + ], + "metadata": { + "id": "AQjJAr3YcQB2" + } }, { + "id": "eb84e7db", "cell_type": "markdown", - "metadata": { - "id": "a_imRCi1gYa6" - }, "source": [ "## No Vectorization" - ] + ], + "metadata": { + "id": "a_imRCi1gYa6" + } }, { + "id": "5cc017f4", "cell_type": "code", - "execution_count": null, + "source": [ + "class CustomSum(Metric):\n", + "\n", + " def __init__(self, var):\n", + " name = 'custom sum(%s)' % var\n", + " super(CustomSum, self).__init__(name, additional_fingerprint_attrs=['var'])\n", + " self.var = var\n", + " # For custom Metrics, class-level caching needs to be manually enabled.\n", + " # See the Caching section for more information.\n", + " self.cache_across_instances = True\n", + "\n", + " def compute(self, df):\n", + " return df[self.var].sum()\n", + "\n", + "\n", + "CustomSum('clicks').compute_on(df, 'country')" + ], "metadata": { "colab": { "height": 143 @@ -1159,15 +1193,16 @@ "id": "uAgrTxLDfh3z", "outputId": "a96e488f-1ecf-4fa0-f05f-4597b0a1578d" }, + "execution_count": null, "outputs": [ { "data": { "text/html": [ "\n", - "
\n", - "
\n", - "
\n", - "\n", - "\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
custom sum(clicks)
country
US69.095494
non-US28.153874
\n", - "
\n", - " \n", + " \u003csvg xmlns=\"http://www.w3.org/2000/svg\" height=\"24px\"viewBox=\"0 0 24 24\"\n", + " width=\"24px\"\u003e\n", + " \u003cpath d=\"M0 0h24v24H0V0z\" fill=\"none\"/\u003e\n", + " \u003cpath d=\"M18.56 5.44l.94 2.06.94-2.06 2.06-.94-2.06-.94-.94-2.06-.94 2.06-2.06.94zm-11 1L8.5 8.5l.94-2.06 2.06-.94-2.06-.94L8.5 2.5l-.94 2.06-2.06.94zm10 10l.94 2.06.94-2.06 2.06-.94-2.06-.94-.94-2.06-.94 2.06-2.06.94z\"/\u003e\u003cpath d=\"M17.41 7.96l-1.37-1.37c-.4-.4-.92-.59-1.43-.59-.52 0-1.04.2-1.43.59L10.3 9.45l-7.72 7.72c-.78.78-.78 2.05 0 2.83L4 21.41c.39.39.9.59 1.41.59.51 0 1.02-.2 1.41-.59l7.78-7.78 2.81-2.81c.8-.78.8-2.07 0-2.86zM5.41 20L4 18.59l7.72-7.72 1.47 1.35L5.41 20z\"/\u003e\n", + " \u003c/svg\u003e\n", + " \u003c/button\u003e\n", " \n", " \n", - "
\n", - " \n", - "
\n", + " \u003csvg xmlns=\"http://www.w3.org/2000/svg\" height=\"24px\"viewBox=\"0 0 24 24\"\n", + " width=\"24px\"\u003e\n", + " \u003cg\u003e\n", + " \u003cpath d=\"M19 3H5c-1.1 0-2 .9-2 2v14c0 1.1.9 2 2 2h14c1.1 0 2-.9 2-2V5c0-1.1-.9-2-2-2zM9 17H7v-7h2v7zm4 0h-2V7h2v10zm4 0h-2v-4h2v4z\"/\u003e\n", + " \u003c/g\u003e\n", + " \u003c/svg\u003e\n", + " \u003c/button\u003e\n", + " \u003c/div\u003e\n", " \n", - " \n", + " \u003c/style\u003e\n", "\n", - " \n", - " \n", + " \u003c/style\u003e\n", "\n", - " \n", - "
\n", - "
\n", + " \u003c/script\u003e\n", + " \u003c/div\u003e\n", + " \u003c/div\u003e\n", " " ], "text/plain": [ @@ -1350,28 +1385,14 @@ "metadata": {}, "output_type": "execute_result" } - ], + ] + }, + { + "id": "55df961f", + "cell_type": "code", "source": [ - "class CustomSum(Metric):\n", - "\n", - " def __init__(self, var):\n", - " name = 'custom sum(%s)' % var\n", - " super(CustomSum, self).__init__(name, additional_fingerprint_attrs=['var'])\n", - " self.var = var\n", - " # For custom Metrics, class-level caching needs to be manually enabled.\n", - " # See the Caching section for more information.\n", - " self.cache_across_instances = True\n", - "\n", - " def compute(self, df):\n", - " return df[self.var].sum()\n", - "\n", - "\n", - "CustomSum('clicks').compute_on(df, 'country')" - ] - }, - { - "cell_type": "code", - "execution_count": null, + "Sum('clicks').compute_on(df, 'country')" + ], "metadata": { "colab": { "height": 143 @@ -1389,15 +1410,16 @@ "id": "XgYXgTOTgPwo", "outputId": "5ee040aa-203e-4982-951a-5c61b0638fc8" }, + "execution_count": null, "outputs": [ { "data": { "text/html": [ "\n", - "
\n", - "
\n", - "
\n", - "\n", - "\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
sum(clicks)
country
US69.095494
non-US28.153874
\n", - "
\n", - " \n", + " \u003csvg xmlns=\"http://www.w3.org/2000/svg\" height=\"24px\"viewBox=\"0 0 24 24\"\n", + " width=\"24px\"\u003e\n", + " \u003cpath d=\"M0 0h24v24H0V0z\" fill=\"none\"/\u003e\n", + " \u003cpath d=\"M18.56 5.44l.94 2.06.94-2.06 2.06-.94-2.06-.94-.94-2.06-.94 2.06-2.06.94zm-11 1L8.5 8.5l.94-2.06 2.06-.94-2.06-.94L8.5 2.5l-.94 2.06-2.06.94zm10 10l.94 2.06.94-2.06 2.06-.94-2.06-.94-.94-2.06-.94 2.06-2.06.94z\"/\u003e\u003cpath d=\"M17.41 7.96l-1.37-1.37c-.4-.4-.92-.59-1.43-.59-.52 0-1.04.2-1.43.59L10.3 9.45l-7.72 7.72c-.78.78-.78 2.05 0 2.83L4 21.41c.39.39.9.59 1.41.59.51 0 1.02-.2 1.41-.59l7.78-7.78 2.81-2.81c.8-.78.8-2.07 0-2.86zM5.41 20L4 18.59l7.72-7.72 1.47 1.35L5.41 20z\"/\u003e\n", + " \u003c/svg\u003e\n", + " \u003c/button\u003e\n", " \n", " \n", - "
\n", - " \n", - "
\n", + " \u003csvg xmlns=\"http://www.w3.org/2000/svg\" height=\"24px\"viewBox=\"0 0 24 24\"\n", + " width=\"24px\"\u003e\n", + " \u003cg\u003e\n", + " \u003cpath d=\"M19 3H5c-1.1 0-2 .9-2 2v14c0 1.1.9 2 2 2h14c1.1 0 2-.9 2-2V5c0-1.1-.9-2-2-2zM9 17H7v-7h2v7zm4 0h-2V7h2v10zm4 0h-2v-4h2v4z\"/\u003e\n", + " \u003c/g\u003e\n", + " \u003c/svg\u003e\n", + " \u003c/button\u003e\n", + " \u003c/div\u003e\n", " \n", - " \n", + " \u003c/style\u003e\n", "\n", - " \n", - " \n", + " \u003c/style\u003e\n", "\n", - " \n", - "
\n", - "
\n", + " \u003c/script\u003e\n", + " \u003c/div\u003e\n", + " \u003c/div\u003e\n", " " ], "text/plain": [ @@ -1580,23 +1602,26 @@ "metadata": {}, "output_type": "execute_result" } - ], - "source": [ - "Sum('clicks').compute_on(df, 'country')" ] }, { + "id": "f778d8ee", "cell_type": "markdown", - "metadata": { - "id": "-hvgLLKbglrP" - }, "source": [ "CustomSum doesn't have vectorization. It loops through the DataFrame and sum on every slice. As the result, it's slower than vectorized summation." - ] + ], + "metadata": { + "id": "-hvgLLKbglrP" + } }, { + "id": "654fc59c", "cell_type": "code", - "execution_count": null, + "source": [ + + "%%timeit\n", + "CustomSum('clicks').compute_on(df, 'country')" + ], "metadata": { "executionInfo": { "elapsed": 9776, @@ -1611,6 +1636,7 @@ "id": "6MVo9EtTgg0e", "outputId": "25631988-0fac-4161-cb6d-b6970a4a911e" }, + "execution_count": null, "outputs": [ { "name": "stdout", @@ -1619,16 +1645,16 @@ "1000 loops, best of 5: 1.56 ms per loop\n" ] } - ], - "source": [ - - "%%timeit\n", - "CustomSum('clicks').compute_on(df, 'country')" ] }, { + "id": "e343c867", "cell_type": "code", - "execution_count": null, + "source": [ + + "%%timeit\n", + "Sum('clicks').compute_on(df, 'country')" + ], "metadata": { "executionInfo": { "elapsed": 5126, @@ -1643,6 +1669,7 @@ "id": "cLrEAl2Tgi_W", "outputId": "403f93dc-ccc7-4281-a409-1fd15eceee51" }, + "execution_count": null, "outputs": [ { "name": "stdout", @@ -1651,16 +1678,16 @@ "1000 loops, best of 5: 758 µs per loop\n" ] } - ], - "source": [ - - "%%timeit\n", - "Sum('clicks').compute_on(df, 'country')" ] }, { + "id": "24d08cc1", "cell_type": "code", - "execution_count": null, + "source": [ + + "%%timeit\n", + "df.groupby('country')['clicks'].sum()" + ], "metadata": { "executionInfo": { "elapsed": 2968, @@ -1675,6 +1702,7 @@ "id": "qdhczptcg6VO", "outputId": "059361eb-85fe-4c8f-fe8a-8cf20c448eef" }, + "execution_count": null, "outputs": [ { "name": "stdout", @@ -1683,27 +1711,44 @@ "1000 loops, best of 5: 467 µs per loop\n" ] } - ], - "source": [ - - "%%timeit\n", - "df.groupby('country')['clicks'].sum()" ] }, { + "id": "67d97160", "cell_type": "markdown", - "metadata": { - "id": "5tUfc996hNBm" - }, "source": [ "## With Vectorization\n", "\n", "We can do better. Let's implement a Sum with vectorization." - ] + ], + "metadata": { + "id": "5tUfc996hNBm" + } }, { + "id": "b3e11294", "cell_type": "code", - "execution_count": null, + "source": [ + "class VectorizedSum(Metric):\n", + "\n", + " def __init__(self, var):\n", + " name = 'vectorized sum(%s)' % var\n", + " super(VectorizedSum, self).__init__(\n", + " name=name, additional_fingerprint_attrs=['var']\n", + " )\n", + " self.var = var\n", + " # For custom Metrics, class-level caching needs to be manually enabled.\n", + " # See the Caching section for more information.\n", + " self.cache_across_instances = True\n", + "\n", + " def compute_slices(self, df, split_by):\n", + " if split_by:\n", + " return df.groupby(split_by)[self.var].sum()\n", + " return df[self.var].sum()\n", + "\n", + "\n", + "VectorizedSum('clicks').compute_on(df, 'country')" + ], "metadata": { "colab": { "height": 143 @@ -1721,15 +1766,16 @@ "id": "1ny3uHTuhXAJ", "outputId": "80959aff-9d30-4da6-bf71-e03ef1d1dd55" }, + "execution_count": null, "outputs": [ { "data": { "text/html": [ "\n", - "
\n", - "
\n", - "
\n", - "\n", - "\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
vectorized sum(clicks)
country
US69.095494
non-US28.153874
\n", - "
\n", - " \n", + " \u003csvg xmlns=\"http://www.w3.org/2000/svg\" height=\"24px\"viewBox=\"0 0 24 24\"\n", + " width=\"24px\"\u003e\n", + " \u003cpath d=\"M0 0h24v24H0V0z\" fill=\"none\"/\u003e\n", + " \u003cpath d=\"M18.56 5.44l.94 2.06.94-2.06 2.06-.94-2.06-.94-.94-2.06-.94 2.06-2.06.94zm-11 1L8.5 8.5l.94-2.06 2.06-.94-2.06-.94L8.5 2.5l-.94 2.06-2.06.94zm10 10l.94 2.06.94-2.06 2.06-.94-2.06-.94-.94-2.06-.94 2.06-2.06.94z\"/\u003e\u003cpath d=\"M17.41 7.96l-1.37-1.37c-.4-.4-.92-.59-1.43-.59-.52 0-1.04.2-1.43.59L10.3 9.45l-7.72 7.72c-.78.78-.78 2.05 0 2.83L4 21.41c.39.39.9.59 1.41.59.51 0 1.02-.2 1.41-.59l7.78-7.78 2.81-2.81c.8-.78.8-2.07 0-2.86zM5.41 20L4 18.59l7.72-7.72 1.47 1.35L5.41 20z\"/\u003e\n", + " \u003c/svg\u003e\n", + " \u003c/button\u003e\n", " \n", " \n", - "
\n", - " \n", - "
\n", + " \u003csvg xmlns=\"http://www.w3.org/2000/svg\" height=\"24px\"viewBox=\"0 0 24 24\"\n", + " width=\"24px\"\u003e\n", + " \u003cg\u003e\n", + " \u003cpath d=\"M19 3H5c-1.1 0-2 .9-2 2v14c0 1.1.9 2 2 2h14c1.1 0 2-.9 2-2V5c0-1.1-.9-2-2-2zM9 17H7v-7h2v7zm4 0h-2V7h2v10zm4 0h-2v-4h2v4z\"/\u003e\n", + " \u003c/g\u003e\n", + " \u003c/svg\u003e\n", + " \u003c/button\u003e\n", + " \u003c/div\u003e\n", " \n", - " \n", + " \u003c/style\u003e\n", "\n", - " \n", - " \n", + " \u003c/style\u003e\n", "\n", - " \n", - "
\n", - "
\n", + " \u003c/script\u003e\n", + " \u003c/div\u003e\n", + " \u003c/div\u003e\n", " " ], "text/plain": [ @@ -1912,32 +1958,16 @@ "metadata": {}, "output_type": "execute_result" } - ], - "source": [ - "class VectorizedSum(Metric):\n", - "\n", - " def __init__(self, var):\n", - " name = 'vectorized sum(%s)' % var\n", - " super(VectorizedSum, self).__init__(\n", - " name=name, additional_fingerprint_attrs=['var']\n", - " )\n", - " self.var = var\n", - " # For custom Metrics, class-level caching needs to be manually enabled.\n", - " # See the Caching section for more information.\n", - " self.cache_across_instances = True\n", - "\n", - " def compute_slices(self, df, split_by):\n", - " if split_by:\n", - " return df.groupby(split_by)[self.var].sum()\n", - " return df[self.var].sum()\n", - "\n", - "\n", - "VectorizedSum('clicks').compute_on(df, 'country')" ] }, { + "id": "5dd18a9f", "cell_type": "code", - "execution_count": null, + "source": [ + + "%%timeit\n", + "VectorizedSum('clicks').compute_on(df, 'country')" + ], "metadata": { "executionInfo": { "elapsed": 4752, @@ -1952,6 +1982,7 @@ "id": "VekZlay-hoZ6", "outputId": "a7a21a4d-a5e5-41de-cf0f-d8101a489e11" }, + "execution_count": null, "outputs": [ { "name": "stdout", @@ -1960,27 +1991,48 @@ "1000 loops, best of 5: 749 µs per loop\n" ] } - ], - "source": [ - - "%%timeit\n", - "VectorizedSum('clicks').compute_on(df, 'country')" ] }, { + "id": "9aca4193", "cell_type": "markdown", - "metadata": { - "id": "FtRTh1lLiA2W" - }, "source": [ "## Precompute, postcompute and final_compute\n", "\n", "They are useful when you need to preprocess and postprocess the data." - ] + ], + "metadata": { + "id": "FtRTh1lLiA2W" + } }, { + "id": "58d3c16f", "cell_type": "code", - "execution_count": null, + "source": [ + "class USOnlySum(Sum):\n", + "\n", + " def precompute(self, df, split_by):\n", + " return df[df.country == 'US']\n", + "\n", + " def postcompute(self, data, split_by):\n", + " print('Inside postcompute():')\n", + " print('Input data: ', data)\n", + " print('Input split_by: ', split_by)\n", + " print('\\n')\n", + " return data\n", + "\n", + " def final_compute(self, res, melted, return_dataframe, split_by, df):\n", + " # res is the result processed by the dataflow till now. df is the original\n", + " # DataFrme passed to compute_on().\n", + " print('Inside final_compute():')\n", + " for country in df.country.unique():\n", + " if country not in res.index:\n", + " print('Country \"%s\" is missing!' % country)\n", + " return res\n", + "\n", + "\n", + "USOnlySum('clicks').compute_on(df, 'country')" + ], "metadata": { "colab": { "height": 268 @@ -1998,6 +2050,7 @@ "id": "OrcVt-gviQv5", "outputId": "48f4d594-3eae-4900-e80e-b9769434e127" }, + "execution_count": null, "outputs": [ { "name": "stdout", @@ -2018,10 +2071,10 @@ "data": { "text/html": [ "\n", - "
\n", - "
\n", - "
\n", - "\n", - "\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
sum(clicks)
country
US69.095494
\n", - "
\n", - " \n", + " \u003csvg xmlns=\"http://www.w3.org/2000/svg\" height=\"24px\"viewBox=\"0 0 24 24\"\n", + " width=\"24px\"\u003e\n", + " \u003cpath d=\"M0 0h24v24H0V0z\" fill=\"none\"/\u003e\n", + " \u003cpath d=\"M18.56 5.44l.94 2.06.94-2.06 2.06-.94-2.06-.94-.94-2.06-.94 2.06-2.06.94zm-11 1L8.5 8.5l.94-2.06 2.06-.94-2.06-.94L8.5 2.5l-.94 2.06-2.06.94zm10 10l.94 2.06.94-2.06 2.06-.94-2.06-.94-.94-2.06-.94 2.06-2.06.94z\"/\u003e\u003cpath d=\"M17.41 7.96l-1.37-1.37c-.4-.4-.92-.59-1.43-.59-.52 0-1.04.2-1.43.59L10.3 9.45l-7.72 7.72c-.78.78-.78 2.05 0 2.83L4 21.41c.39.39.9.59 1.41.59.51 0 1.02-.2 1.41-.59l7.78-7.78 2.81-2.81c.8-.78.8-2.07 0-2.86zM5.41 20L4 18.59l7.72-7.72 1.47 1.35L5.41 20z\"/\u003e\n", + " \u003c/svg\u003e\n", + " \u003c/button\u003e\n", " \n", " \n", - "
\n", - " \n", - "
\n", + " \u003csvg xmlns=\"http://www.w3.org/2000/svg\" height=\"24px\"viewBox=\"0 0 24 24\"\n", + " width=\"24px\"\u003e\n", + " \u003cg\u003e\n", + " \u003cpath d=\"M19 3H5c-1.1 0-2 .9-2 2v14c0 1.1.9 2 2 2h14c1.1 0 2-.9 2-2V5c0-1.1-.9-2-2-2zM9 17H7v-7h2v7zm4 0h-2V7h2v10zm4 0h-2v-4h2v4z\"/\u003e\n", + " \u003c/g\u003e\n", + " \u003c/svg\u003e\n", + " \u003c/button\u003e\n", + " \u003c/div\u003e\n", " \n", - " \n", + " \u003c/style\u003e\n", "\n", - " \n", - " \n", + " \u003c/style\u003e\n", "\n", - " \n", - "
\n", - "
\n", + " \u003c/script\u003e\n", + " \u003c/div\u003e\n", + " \u003c/div\u003e\n", " " ], "text/plain": [ @@ -2199,38 +2252,11 @@ "metadata": {}, "output_type": "execute_result" } - ], - "source": [ - "class USOnlySum(Sum):\n", - "\n", - " def precompute(self, df, split_by):\n", - " return df[df.country == 'US']\n", - "\n", - " def postcompute(self, data, split_by):\n", - " print('Inside postcompute():')\n", - " print('Input data: ', data)\n", - " print('Input split_by: ', split_by)\n", - " print('\\n')\n", - " return data\n", - "\n", - " def final_compute(self, res, melted, return_dataframe, split_by, df):\n", - " # res is the result processed by the dataflow till now. df is the original\n", - " # DataFrme passed to compute_on().\n", - " print('Inside final_compute():')\n", - " for country in df.country.unique():\n", - " if country not in res.index:\n", - " print('Country \"%s\" is missing!' % country)\n", - " return res\n", - "\n", - "\n", - "USOnlySum('clicks').compute_on(df, 'country')" ] }, { + "id": "295f72b9", "cell_type": "markdown", - "metadata": { - "id": "8TJkDyF6o2aW" - }, "source": [ "##Custom Operation\n", "Writing a custom `Operation` is more complex. Typically an `Operation` needs to compute some util `Metric`s. A common one is its child `Metric`. The tricky part is how to make sure the additional computations interact correctly with the cache. First take a look at the Caching section above to understand how caching works in `Meterstick`. Then here is a decision tree to help you.\n", @@ -2250,7 +2276,7 @@ " +-------------------------------+ ↓ \n", " ↓ +----------------------+\n", " | |Call compute_on() or |\n", - " |------------N----->|compute_on_sql() on |\n", + " |------------N-----\u003e|compute_on_sql() on |\n", " | |the util Metric. Set |\n", " Y |precomputable_in_jk_bs|\n", " | |to False. |\n", @@ -2298,22 +2324,21 @@ " * set attribute `precomputable_in_jk_bs` to `False`, which will stop `Jackknife`/`Bootstrap` from taking the shortcuts, and make the computation slower.\n", "1. An `Operation` might not be precomputable even all the leaf `Metric`s are `Sum` and/or `Count`. It's not easy to decide. The easiest way to check is just set the `precomputable_in_jk_bs` to `True` and try `Metric`s like\n", "\n", - " - `Jackknife(..., Operation(Dot('x', 'y', where='x>2')))` and\n", - " - `Jackknife(..., Operation(Dot('x', 'y', where='x>2')), enable_optimization=False)`.\n", + " - `Jackknife(..., Operation(Dot('x', 'y', where='x\u003e2')))` and\n", + " - `Jackknife(..., Operation(Dot('x', 'y', where='x\u003e2')), enable_optimization=False)`.\n", "\n", " If the first one computes and gives the same result to the second one, the `Operation` is precomputable. See the doc of `Operation` for the attribute for\n", "more details if you're curious.\n", "\n", "That's a lot to digest. Let's see two examples. Below we implement `Distribution` with and without vectorization." - ] + ], + "metadata": { + "id": "8TJkDyF6o2aW" + } }, { + "id": "4e8cb4d6", "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "QWsivJZMpvgO" - }, - "outputs": [], "source": [ "class DistributionWithVectorization(Operation):\n", " \"\"\"Computes the normalized values of a Metric over column(s).\"\"\"\n", @@ -2346,15 +2371,16 @@ " children.groupby(level=split_by).sum() if split_by else children.sum()\n", " )\n", " return children / total" - ] + ], + "metadata": { + "id": "QWsivJZMpvgO" + }, + "execution_count": null, + "outputs": [] }, { + "id": "7ccef624", "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "PxdJ-ZX_NFVu" - }, - "outputs": [], "source": [ "class DistributionNoVectorization(Operation):\n", " \"\"\"Computes the normalized values of a Metric over column(s).\"\"\"\n", @@ -2374,26 +2400,28 @@ " def compute(self, slice_of_children):\n", " total = slice_of_children.sum()\n", " return slice_of_children / total" - ] + ], + "metadata": { + "id": "PxdJ-ZX_NFVu" + }, + "execution_count": null, + "outputs": [] }, { + "id": "fcc5ea01", "cell_type": "markdown", - "metadata": { - "id": "5adfiJUjxIfa" - }, "source": [ "## SQL Generation\n", "\n", "If you want the custom Metric to generate SQL query, you need to implement to_sql() or get_sql_and_with_clause(). The latter is more common and recommended. Please refer to built-in Metrics to see how it should be implemented. Here we show two examples, one for Metric and the other for Operation." - ] + ], + "metadata": { + "id": "5adfiJUjxIfa" + } }, { + "id": "92bacfa9", "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "Neiqbr-wxlej" - }, - "outputs": [], "source": [ "class SumWithSQL(SimpleMetric):\n", "\n", @@ -2437,11 +2465,25 @@ " columns = Column(self.var, 'SUM({})', self.name, local_filter)\n", " # Returns a Sql instance and the WITH clause it needs.\n", " return Sql(columns, table, global_filter, split_by), with_data" - ] + ], + "metadata": { + "id": "Neiqbr-wxlej" + }, + "execution_count": null, + "outputs": [] }, { + "id": "d946325d", "cell_type": "code", - "execution_count": null, + "source": [ + + "m = Sum('clicks') - SumWithSQL('clicks', 'custom_sum')\n", + "m.compute_on_sql(\n", + " 'T',\n", + " 'platform',\n", + " execute=lambda sql: pd.read_sql(text(sql), engine.connect())\n", + ")" + ], "metadata": { "colab": { "height": 175 @@ -2459,15 +2501,16 @@ "id": "rEFXXMFtCgh5", "outputId": "ed2eec40-38f9-4b5b-e968-700cc6b47bcd" }, + "execution_count": null, "outputs": [ { "data": { "text/html": [ "\n", - "
\n", - "
\n", - "
\n", - "\n", - "\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
sum(clicks) - custom_sum
platform
Desktop0.0
Mobile0.0
Tablet0.0
\n", - "
\n", - " \n", + " \u003csvg xmlns=\"http://www.w3.org/2000/svg\" height=\"24px\"viewBox=\"0 0 24 24\"\n", + " width=\"24px\"\u003e\n", + " \u003cpath d=\"M0 0h24v24H0V0z\" fill=\"none\"/\u003e\n", + " \u003cpath d=\"M18.56 5.44l.94 2.06.94-2.06 2.06-.94-2.06-.94-.94-2.06-.94 2.06-2.06.94zm-11 1L8.5 8.5l.94-2.06 2.06-.94-2.06-.94L8.5 2.5l-.94 2.06-2.06.94zm10 10l.94 2.06.94-2.06 2.06-.94-2.06-.94-.94-2.06-.94 2.06-2.06.94z\"/\u003e\u003cpath d=\"M17.41 7.96l-1.37-1.37c-.4-.4-.92-.59-1.43-.59-.52 0-1.04.2-1.43.59L10.3 9.45l-7.72 7.72c-.78.78-.78 2.05 0 2.83L4 21.41c.39.39.9.59 1.41.59.51 0 1.02-.2 1.41-.59l7.78-7.78 2.81-2.81c.8-.78.8-2.07 0-2.86zM5.41 20L4 18.59l7.72-7.72 1.47 1.35L5.41 20z\"/\u003e\n", + " \u003c/svg\u003e\n", + " \u003c/button\u003e\n", " \n", " \n", - "
\n", - " \n", - "
\n", + " \u003csvg xmlns=\"http://www.w3.org/2000/svg\" height=\"24px\"viewBox=\"0 0 24 24\"\n", + " width=\"24px\"\u003e\n", + " \u003cg\u003e\n", + " \u003cpath d=\"M19 3H5c-1.1 0-2 .9-2 2v14c0 1.1.9 2 2 2h14c1.1 0 2-.9 2-2V5c0-1.1-.9-2-2-2zM9 17H7v-7h2v7zm4 0h-2V7h2v10zm4 0h-2v-4h2v4z\"/\u003e\n", + " \u003c/g\u003e\n", + " \u003c/svg\u003e\n", + " \u003c/button\u003e\n", + " \u003c/div\u003e\n", " \n", - " \n", + " \u003c/style\u003e\n", "\n", - " \n", - " \n", + " \u003c/style\u003e\n", "\n", - " \n", - "
\n", - "
\n", + " \u003c/script\u003e\n", + " \u003c/div\u003e\n", + " \u003c/div\u003e\n", " " ], "text/plain": [ @@ -2655,64 +2698,21 @@ "metadata": {}, "output_type": "execute_result" } - ], - "source": [ - - "m = Sum('clicks') - SumWithSQL('clicks', 'custom_sum')\n", - "m.compute_on_sql(\n", - " 'T',\n", - " 'platform',\n", - " execute=lambda sql: pd.read_sql(text(sql), engine.connect())\n", - ")" ] }, { + "id": "3bb27740", "cell_type": "markdown", - "metadata": { - "id": "rzUpMSEr8n4z" - }, "source": [ "For an Operation, you ususally call the child metrics' get_sql_and_with_clause() to get the subquery you need." - ] + ], + "metadata": { + "id": "rzUpMSEr8n4z" + } }, { + "id": "266f0483", "cell_type": "code", - "execution_count": null, - "metadata": { - "executionInfo": { - "elapsed": 59, - "status": "ok", - "timestamp": 1752749803027, - "user": { - "displayName": "Xunmo Yang", - "userId": "12474546967758012552" - }, - "user_tz": 420 - }, - "id": "5dKXoRHZ_Xi4", - "outputId": "50aa82e5-e70c-48e3-ee48-cfe6ec7ab880" - }, - "outputs": [ - { - "data": { - "text/plain": [ - "WITH\n", - "DistributionRaw AS (SELECT\n", - " country,\n", - " SUM(clicks) AS sum_clicks\n", - "FROM T\n", - "GROUP BY country)\n", - "SELECT\n", - " country,\n", - " SAFE_DIVIDE(sum_clicks, SUM(sum_clicks) OVER ()) AS Distribution_of_sum_clicks\n", - "FROM DistributionRaw" - ] - }, - "execution_count": 18, - "metadata": {}, - "output_type": "execute_result" - } - ], "source": [ "class DistributionWithSQL(Operation):\n", "\n", @@ -2792,22 +2792,57 @@ "\n", "m = DistributionWithSQL('country', Sum('clicks'))\n", "m.to_sql('T')" + ], + "metadata": { + "executionInfo": { + "elapsed": 59, + "status": "ok", + "timestamp": 1752749803027, + "user": { + "displayName": "Xunmo Yang", + "userId": "12474546967758012552" + }, + "user_tz": 420 + }, + "id": "5dKXoRHZ_Xi4", + "outputId": "50aa82e5-e70c-48e3-ee48-cfe6ec7ab880" + }, + "execution_count": null, + "outputs": [ + { + "data": { + "text/plain": [ + "WITH\n", + "DistributionRaw AS (SELECT\n", + " country,\n", + " SUM(clicks) AS sum_clicks\n", + "FROM T\n", + "GROUP BY country)\n", + "SELECT\n", + " country,\n", + " SAFE_DIVIDE(sum_clicks, SUM(sum_clicks) OVER ()) AS Distribution_of_sum_clicks\n", + "FROM DistributionRaw" + ] + }, + "execution_count": 18, + "metadata": {}, + "output_type": "execute_result" + } ] }, { + "id": "db4a9946", "cell_type": "markdown", - "metadata": { - "id": "H6MJwaLlZ79Y" - }, "source": [ "# More Examples" - ] + ], + "metadata": { + "id": "H6MJwaLlZ79Y" + } }, { + "id": "832d5674", "cell_type": "markdown", - "metadata": { - "id": "XzlSJO0OrBQE" - }, "source": [ "## Operation with confidence interval\n", "\n", @@ -2864,24 +2899,24 @@ "```\n", "\n", "Below we implement an `Operation` that computes the standard deviation from t distribution." - ] + ], + "metadata": { + "id": "XzlSJO0OrBQE" + } }, { + "id": "5858433f", "cell_type": "markdown", - "metadata": { - "id": "LyOjLkutSr1q" - }, "source": [ "### Student t distribution" - ] + ], + "metadata": { + "id": "LyOjLkutSr1q" + } }, { + "id": "1f24488a", "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "g5_106uBrIAr" - }, - "outputs": [], "source": [ "class TDistribution(MetricWithCI):\n", "\n", @@ -2926,11 +2961,21 @@ " def get_stderrs(bucket_estimates):\n", " dof = bucket_estimates.count(axis=1) - 1\n", " return bucket_estimates.sem(1), dof" - ] + ], + "metadata": { + "id": "g5_106uBrIAr" + }, + "execution_count": null, + "outputs": [] }, { + "id": "fdfb0212", "cell_type": "code", - "execution_count": null, + "source": [ + + "m = TDistribution(('cookie', 'platform'), Mean('clicks'), 0.9)\n", + "m.compute_on(df).display()" + ], "metadata": { "colab": { "height": 142 @@ -2948,21 +2993,22 @@ "id": "6hGiqcXDTt-w", "outputId": "c3acc646-5339-4d3a-9caa-bad939e05edd" }, + "execution_count": null, "outputs": [ { "data": { "text/html": [ - "
" + " \u003c/script\u003e" ], "text/plain": [ - "" + "\u003cIPython.core.display.HTML object\u003e" ] }, "metadata": {}, @@ -2995,7 +3041,7 @@ { "data": { "text/html": [ - "" + "\u003c/style\u003e" ], "text/plain": [ - "" + "\u003cIPython.core.display.HTML object\u003e" ] }, "metadata": {}, @@ -3128,25 +3174,25 @@ { "data": { "text/html": [ - "" + "\u003cstyle\u003e\u003c/style\u003e" ], "text/plain": [ - "" + "\u003cIPython.core.display.HTML object\u003e" ] }, "metadata": {}, "output_type": "display_data" } - ], - "source": [ - - "m = TDistribution(('cookie', 'platform'), Mean('clicks'), 0.9)\n", - "m.compute_on(df).display()" ] }, { + "id": "5a22370d", "cell_type": "code", - "execution_count": null, + "source": [ + "# The confidence interval matches t distribution.\n", + "df_mean = df.groupby(['cookie', 'platform']).mean(numeric_only=True).clicks\n", + "tuple(np.round(scipy.stats.t.interval(0.9, len(df_mean) - 1, df_mean.mean(), df_mean.sem()), 10))" + ], "metadata": { "executionInfo": { "elapsed": 53, @@ -3161,6 +3207,7 @@ "outputId": "2151a8cc-f3aa-47e6-f4f8-23e986cbf514", "id": "5efY83v_QyqW" }, + "execution_count": null, "outputs": [ { "data": { @@ -3172,27 +3219,27 @@ "metadata": {}, "output_type": "execute_result" } - ], - "source": [ - "# The confidence interval matches t distribution.\n", - "df_mean = df.groupby(['cookie', 'platform']).mean(numeric_only=True).clicks\n", - "tuple(np.round(scipy.stats.t.interval(0.9, len(df_mean) - 1, df_mean.mean(), df_mean.sem()), 10))" ] }, { + "id": "29ab30ac", "cell_type": "markdown", - "metadata": { - "id": "cK3efhANrto5" - }, "source": [ "### Paired t-test\n", "\n", "When you apply TDistribution to AbsoluteChange, it's equivalent to a paired t-test." - ] + ], + "metadata": { + "id": "cK3efhANrto5" + } }, { + "id": "6238b3a8", "cell_type": "code", - "execution_count": null, + "source": [ + "m = TDistribution('cookie', AbsoluteChange('expr_id', 'ctrl', Sum('clicks')), .9)\n", + "m.compute_on(df)" + ], "metadata": { "colab": { "height": 143 @@ -3210,15 +3257,16 @@ "id": "FPjhdvVQUM9W", "outputId": "a33c1dcf-9020-48e1-830b-9f6f0bbfb409" }, + "execution_count": null, "outputs": [ { "data": { "text/html": [ "\n", - "
\n", - "
\n", - "
\n", - "\n", - "\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
Metricsum(clicks) Absolute Change
Valuet-distribution CI-lowert-distribution CI-upper
expr_id
expr3.707903-8.88832116.304128
\n", - "
\n", - " \n", + " \u003csvg xmlns=\"http://www.w3.org/2000/svg\" height=\"24px\"viewBox=\"0 0 24 24\"\n", + " width=\"24px\"\u003e\n", + " \u003cpath d=\"M0 0h24v24H0V0z\" fill=\"none\"/\u003e\n", + " \u003cpath d=\"M18.56 5.44l.94 2.06.94-2.06 2.06-.94-2.06-.94-.94-2.06-.94 2.06-2.06.94zm-11 1L8.5 8.5l.94-2.06 2.06-.94-2.06-.94L8.5 2.5l-.94 2.06-2.06.94zm10 10l.94 2.06.94-2.06 2.06-.94-2.06-.94-.94-2.06-.94 2.06-2.06.94z\"/\u003e\u003cpath d=\"M17.41 7.96l-1.37-1.37c-.4-.4-.92-.59-1.43-.59-.52 0-1.04.2-1.43.59L10.3 9.45l-7.72 7.72c-.78.78-.78 2.05 0 2.83L4 21.41c.39.39.9.59 1.41.59.51 0 1.02-.2 1.41-.59l7.78-7.78 2.81-2.81c.8-.78.8-2.07 0-2.86zM5.41 20L4 18.59l7.72-7.72 1.47 1.35L5.41 20z\"/\u003e\n", + " \u003c/svg\u003e\n", + " \u003c/button\u003e\n", " \n", " \n", - "
\n", - " \n", - "
\n", + " \u003csvg xmlns=\"http://www.w3.org/2000/svg\" height=\"24px\"viewBox=\"0 0 24 24\"\n", + " width=\"24px\"\u003e\n", + " \u003cg\u003e\n", + " \u003cpath d=\"M19 3H5c-1.1 0-2 .9-2 2v14c0 1.1.9 2 2 2h14c1.1 0 2-.9 2-2V5c0-1.1-.9-2-2-2zM9 17H7v-7h2v7zm4 0h-2V7h2v10zm4 0h-2v-4h2v4z\"/\u003e\n", + " \u003c/g\u003e\n", + " \u003c/svg\u003e\n", + " \u003c/button\u003e\n", + " \u003c/div\u003e\n", " \n", - " \n", + " \u003c/style\u003e\n", "\n", - " \n", - " \n", + " \u003c/style\u003e\n", "\n", - " \n", - "
\n", - "
\n", + " \u003c/script\u003e\n", + " \u003c/div\u003e\n", + " \u003c/div\u003e\n", " " ], "text/plain": [ @@ -3411,15 +3459,17 @@ "metadata": {}, "output_type": "execute_result" } - ], - "source": [ - "m = TDistribution('cookie', AbsoluteChange('expr_id', 'ctrl', Sum('clicks')), .9)\n", - "m.compute_on(df)" ] }, { + "id": "8adae31a", "cell_type": "code", - "execution_count": null, + "source": [ + + "df_mean = df.groupby(['expr_id', 'cookie']).sum(numeric_only=True).clicks\n", + "t = scipy.stats.ttest_rel(df_mean.expr, df_mean.ctrl)\n", + "t.confidence_interval(0.9)" + ], "metadata": { "colab": { "base_uri": "https://localhost:8080/" @@ -3437,6 +3487,7 @@ "id": "wWRbG4vpyhSd", "outputId": "22dfc0d4-c8d8-4012-b6cf-2a5832624230" }, + "execution_count": null, "outputs": [ { "data": { @@ -3448,32 +3499,23 @@ "metadata": {}, "output_type": "execute_result" } - ], - "source": [ - - "df_mean = df.groupby(['expr_id', 'cookie']).sum(numeric_only=True).clicks\n", - "t = scipy.stats.ttest_rel(df_mean.expr, df_mean.ctrl)\n", - "t.confidence_interval(0.9)" ] }, { + "id": "f44ef5dc", "cell_type": "markdown", - "metadata": { - "id": "GqiuYVJGGsPT" - }, "source": [ "## Linear Regression\n", "\n", "Here we fit a univariate linear regression on mean values of groups. We show two versions, the former delegates computations to Mean so its Jackknife is faster than the latter which doesn't delegate." - ] + ], + "metadata": { + "id": "GqiuYVJGGsPT" + } }, { + "id": "31123075", "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "yPn6aNBzAbC_" - }, - "outputs": [], "source": [ "np.random.seed(42)\n", "size = 1000000\n", @@ -3482,11 +3524,21 @@ "df_lin['y'] = 2 * df_lin.x + np.random.random(size=size)\n", "df_lin['cookie'] = np.random.choice(range(20), size=size)\n", "df_lin_mean = df_lin.groupby('grp').mean()" - ] + ], + "metadata": { + "id": "yPn6aNBzAbC_" + }, + "execution_count": null, + "outputs": [] }, { + "id": "935f8843", "cell_type": "code", - "execution_count": null, + "source": [ + + "plt.scatter(df_lin_mean.x, df_lin_mean.y)\n", + "plt.show()" + ], "metadata": { "colab": { "height": 265 @@ -3504,27 +3556,57 @@ "id": "qQZV_quPGvct", "outputId": "54f7456b-2308-4b58-9def-0f46053b8584" }, + "execution_count": null, "outputs": [ { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXoAAAD4CAYAAADiry33AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90\nbGliIHZlcnNpb24zLjYuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/av/WaAAAACXBIWXMAAAsT\nAAALEwEAmpwYAAATiUlEQVR4nO3df2xdd33G8eeZ62qX0sllMaV2MlKmyiPQta6srCxaVSjFSVbR\nrEJbso11DCmA2q1MyFvMJJj2T5E8GBtFVFnbtWglhYFjqi3UrQpSYYIOJw44JXjNSqG+zhoX5LaM\nK5GEz/7wsbHNvbHvD/vcfO/7JVn3nM/59fFV/eT2e849xxEhAEC6finvBgAAa4ugB4DEEfQAkDiC\nHgASR9ADQOIuyLuBcjZs2BCbN2/Ouw0AOG8cPnz4hYjoLLesKYN+8+bNGhsby7sNADhv2P5+pWUr\nDt3Y3mT7K7aP237K9h1Z/VW2H7P9dPZ6SYXtt9uetH3C9r7afw0AQC1WM0Z/RtIHIuL1kq6VdJvt\nLZL2SXo8Iq6Q9Hg2v4TtNkmflLRD0hZJe7JtAQDrZMWgj4iTEXEkm35Z0nFJ3ZJulvRAttoDknaV\n2XyrpBMR8UxE/FTSQ9l2AIB1UtVVN7Y3S+qV9KSkSyPipDT3j4GkV5fZpFvSc4vmp7JauX3vtT1m\ne2xmZqaatgAA57DqoLf9SklfkPT+iHhptZuVqZW9uU5E7I+Ivojo6+wse+IYAFCDVV11Y7tdcyH/\nYEQMZ+XnbV8WESdtXybpVJlNpyRtWjS/UdJ0PQ0DQGpGxosaGp3U9GxJXR0FDfT3aFdv2cGPmqzm\nqhtLulfS8Yj42KJFD0u6NZu+VdIXy2z+TUlX2L7c9oWSdmfbAQA0F/KDwxMqzpYUkoqzJQ0OT2hk\nvNiwY6xm6GabpHdKeovto9nPTkkfkXSj7acl3ZjNy3aX7UOSFBFnJN0uaVRzJ3E/FxFPNax7ADjP\nDY1OqnT67JJa6fRZDY1ONuwYKw7dRMTXVH6sXZJuKLP+tKSdi+YPSTpUa4MAkLLp2VJV9VpwrxsA\nyFFXR6Gqei0IegDI0UB/jwrtbUtqhfY2DfT3NOwYTXmvGwBoFfNX16zlVTcEPQDkbFdvd0ODfTmG\nbgAgcQQ9ACSOoAeAxBH0AJA4gh4AEkfQA0DiCHoASBxBDwCJI+gBIHEEPQAkjqAHgMQR9ACQOIIe\nABJH0ANA4la8TbHt+yTdJOlURLwxq31W0vxd8TskzUbE1WW2fVbSy5LOSjoTEX0N6RoAGmBkvLim\n94FvFqu5H/39ku6S9On5QkT8wfy07Y9KevEc2785Il6otUEAWAsj40UNDk8sPJi7OFvS4PCEJCUX\n9isO3UTEE5J+VG6ZbUv6fUkHGtwXAKypodHJhZCfVzp9VkOjkzl1tHbqHaP/HUnPR8TTFZaHpEdt\nH7a991w7sr3X9pjtsZmZmTrbAoBzm54tVVU/n9Ub9Ht07k/z2yLiGkk7JN1m+7pKK0bE/ojoi4i+\nzs7OOtsCgHPr6ihUVT+f1Rz0ti+QdIukz1ZaJyKms9dTkg5K2lrr8QCgkQb6e1Rob1tSK7S3aaC/\np8IW5696PtG/VdJ3I2Kq3ELbF9m+eH5a0tskHavjeADQMLt6u3XnLVequ6MgS+ruKOjOW65M7kSs\ntLrLKw9Iul7SBttTkj4cEfdK2q1lwza2uyTdExE7JV0q6eDc+VpdIOkzEfFIY9sHgNrt6u1OMtiX\nWzHoI2JPhfqflqlNS9qZTT8j6ao6+wMA1IlvxgJA4gh6AEgcQQ8AiSPoASBxBD0AJI6gB4DEEfQA\nkDiCHgASR9ADQOIIegBIHEEPAIkj6AEgcQQ9ACRuNQ8HB4CGGxkvamh0UtOzJXV1FDTQ39MStwzO\nA0EPYN2NjBc1ODyx8HDu4mxJg8MTkkTYrwGGbgCsu6HRyYWQn1c6fVZDo5M5dZQ2gh7AupueLVVV\nR30IegDrrqujUFUd9Vkx6G3fZ/uU7WOLan9ru2j7aPazs8K2221P2j5he18jGwdw/hro71GhvW1J\nrdDepoH+npw6SttqPtHfL2l7mfo/RMTV2c+h5Qttt0n6pKQdkrZI2mN7Sz3NAkjDrt5u3XnLleru\nKMiSujsKuvOWKzkRu0ZW83DwJ2xvrmHfWyWdyB4SLtsPSbpZ0ndq2BeAxOzq7SbY10k9Y/S32/52\nNrRzSZnl3ZKeWzQ/ldUAAOuo1qD/lKRfl3S1pJOSPlpmHZepRaUd2t5re8z22MzMTI1tAQCWqyno\nI+L5iDgbET+T9M+aG6ZZbkrSpkXzGyVNn2Of+yOiLyL6Ojs7a2kLAFBGTUFv+7JFs78n6ViZ1b4p\n6Qrbl9u+UNJuSQ/XcjwAQO1WPBlr+4Ck6yVtsD0l6cOSrrd9teaGYp6V9J5s3S5J90TEzog4Y/t2\nSaOS2iTdFxFPrcUvAQCozBEVh81z09fXF2NjY3m3AQDnDduHI6Kv3DK+GQsAiSPoASBxBD0AJI6g\nB4DEEfQAkDiCHgASR9ADQOIIegBIHA8HB1rMyHhRQ6OTmp4tqaujoIH+Hm4XnDiCHmghI+NFDQ5P\nLDyYuzhb0uDwhCQR9glj6AZoIUOjkwshP690+qyGRidz6gjrgaAHWsj0bKmqOtJA0AMtpKujUFUd\naSDogRYy0N+jQnvbklqhvU0D/T05dYT1wMlYoIXMn3DlqpvWQtADLWZXbzfB3mIYugGAxBH0AJA4\ngh4AErdi0Nu+z/Yp28cW1YZsf9f2t20ftN1RYdtnbU/YPmqbh8ACQA5W84n+fknbl9Uek/TGiPhN\nSf8tafAc2785Iq6u9NBaAMDaWjHoI+IJST9aVns0Is5ks9+QtHENegMANEAjxuj/TNKXKiwLSY/a\nPmx777l2Ynuv7THbYzMzMw1oCwAg1Rn0tv9G0hlJD1ZYZVtEXCNph6TbbF9XaV8RsT8i+iKir7Oz\ns562AACL1Bz0tm+VdJOkP4qIKLdORExnr6ckHZS0tdbjAQBqU1PQ294u6a8lvT0iflJhnYtsXzw/\nLeltko6VWxcAsHZWc3nlAUlfl9Rje8r2uyXdJeliSY9ll07ena3bZftQtumlkr5m+1uS/kvSf0TE\nI2vyWwAAKlrxXjcRsadM+d4K605L2plNPyPpqrq6AwDUjW/GAkDiCHoASBy3KQbWych4kfvAIxcE\nPbAORsaLGhyeWHgwd3G2pMHhCUki7LHmGLoB1sHQ6ORCyM8rnT6rodHJnDpCKyHogXUwPVuqqg40\nEkEPrIOujkJVdaCRCHpgHQz096jQ3rakVmhv00B/T04doZVwMhZYB/MnXLnqBnkg6IF1squ3m2BH\nLhi6AYDEEfQAkDiCHgASR9ADQOIIegBIHEEPAIkj6AEgcQQ9ACRuNc+Mvc/2KdvHFtVeZfsx209n\nr5dU2Ha77UnbJ2zva2TjAIDVWc0n+vslbV9W2yfp8Yi4QtLj2fwSttskfVLSDklbJO2xvaWubgEA\nVVsx6CPiCUk/Wla+WdID2fQDknaV2XSrpBMR8UxE/FTSQ9l2AIB1VOsY/aURcVKSstdXl1mnW9Jz\ni+anshoAYB2t5clYl6lFxZXtvbbHbI/NzMysYVsA0FpqDfrnbV8mSdnrqTLrTEnatGh+o6TpSjuM\niP0R0RcRfZ2dnTW2BQBYrtbbFD8s6VZJH8lev1hmnW9KusL25ZKKknZL+sMajwfUbGS8yH3g0dJW\nc3nlAUlfl9Rje8r2uzUX8DfaflrSjdm8bHfZPiRJEXFG0u2SRiUdl/S5iHhqbX4NoLyR8aIGhydU\nnC0pJBVnSxocntDIeDHv1oB144iKw+a56evri7GxsbzbQAK2feTLKpZ5AHd3R0H/ue8tOXQErA3b\nhyOir9wyvhmLpE2XCflz1YEUEfRIWldHoao6kCKCHkkb6O9Rob1tSa3Q3qaB/p6cOgLWHw8HR9Lm\nr67hqhu0MoIeydvV202wo6UxdAMAiSPoASBxBD0AJI6gB4DEEfQAkDiCHgASR9ADQOIIegBIHEEP\nAIkj6AEgcQQ9ACSOoAeAxBH0AJA4gh4AElfzbYpt90j67KLS6yR9KCI+vmid6yV9UdL3stJwRPxd\nrcfE+WVkvMh94IEmUHPQR8SkpKslyXabpKKkg2VW/WpE3FTrcXB+GhkvanB4QqXTZyVJxdmSBocn\nJImwB9ZZo4ZubpD0PxHx/QbtD+e5odHJhZCfVzp9VkOjkzl1BLSuRgX9bkkHKix7k+1v2f6S7TdU\n2oHtvbbHbI/NzMw0qC3kZXq2VFUdwNqpO+htXyjp7ZL+rcziI5JeGxFXSfqEpJFK+4mI/RHRFxF9\nnZ2d9baFnHV1FKqqA1g7jfhEv0PSkYh4fvmCiHgpIn6cTR+S1G57QwOOiSY30N+jQnvbklqhvU0D\n/T05dQS0rkY8HHyPKgzb2H6NpOcjImxv1dw/LD9swDHR5OZPuHLVDZC/uoLe9isk3SjpPYtq75Wk\niLhb0jskvc/2GUklSbsjIuo5Js4fu3q7CXagCdQV9BHxE0m/uqx296LpuyTdVc8xAAD14ZuxAJA4\ngh4AEkfQA0DiCHoASBxBDwCJI+gBIHEEPQAkjqAHgMQR9ACQOIIeABJH0ANA4gh6AEgcQQ8AiWvE\n/ejRZEbGi9wHHsACgj4xI+NFDQ5PLDyYuzhb0uDwhCQR9kCLYugmMUOjkwshP690+qyGRidz6ghA\n3gj6xEzPlqqqA0gfQZ+Yro5CVXUA6asr6G0/a3vC9lHbY2WW2/Y/2T5h+9u2r6nneFjZQH+PCu1t\nS2qF9jYN9Pfk1BGAvDXiZOybI+KFCst2SLoi+/ktSZ/KXrFG5k+4ctUNgHlrfdXNzZI+HREh6Ru2\nO2xfFhEn1/i4LW1XbzfBDmBBvWP0IelR24dt7y2zvFvSc4vmp7LaL7C91/aY7bGZmZk62wIAzKs3\n6LdFxDWaG6K5zfZ1y5a7zDZRbkcRsT8i+iKir7Ozs862AADz6gr6iJjOXk9JOihp67JVpiRtWjS/\nUdJ0PccEAFSn5qC3fZHti+enJb1N0rFlqz0s6U+yq2+ulfQi4/MAsL7qORl7qaSDtuf385mIeMT2\neyUpIu6WdEjSTkknJP1E0rvqaxcAUK2agz4inpF0VZn63YumQ9JttR4DAFA/vhkLAIkj6AEgcQQ9\nACSOoAeAxBH0AJA4gh4AEkfQA0DiCHoASBwPB2+wkfEi94IH0FQI+gYaGS9qcHhi4eHcxdmSBocn\nJImwB5Abhm4aaGh0ciHk55VOn9XQ6GROHQEAQd9Q07OlquoAsB4I+gbq6ihUVQeA9UDQN9BAf48K\n7W1LaoX2Ng309+TUEQBwMrah5k+4ctUNgGZC0DfYrt5ugh1AU2HoBgASR9ADQOIIegBIXM1Bb3uT\n7a/YPm77Kdt3lFnnetsv2j6a/XyovnYBANWq52TsGUkfiIgjti+WdNj2YxHxnWXrfTUibqrjOACA\nOtT8iT4iTkbEkWz6ZUnHJXG5CQA0mYaM0dveLKlX0pNlFr/J9rdsf8n2G86xj722x2yPzczMNKIt\nAIAaEPS2XynpC5LeHxEvLVt8RNJrI+IqSZ+QNFJpPxGxPyL6IqKvs7Oz3rYAAJm6gt52u+ZC/sGI\nGF6+PCJeiogfZ9OHJLXb3lDPMQEA1annqhtLulfS8Yj4WIV1XpOtJ9tbs+P9sNZjAgCqV89VN9sk\nvVPShO2jWe2Dkn5NkiLibknvkPQ+22cklSTtjoio45gAgCrVHPQR8TVJXmGduyTdVesxAAD145ux\nAJA4gh4AEpfMbYpHxovcBx4Aykgi6EfGixocnlh4MHdxtqTB4QlJIuwBtLwkhm6GRicXQn5e6fRZ\nDY1O5tQRADSPJIJ+erZUVR0AWkkSQd/VUaiqDgCtJImgH+jvUaG9bUmt0N6mgf6enDoCgOaRxMnY\n+ROuXHUDAL8oiaCX5sKeYAeAX5TE0A0AoDKCHgASR9ADQOIIegBIHEEPAIlzMz4HxPaMpO/n3UcT\n2CDphbybaBK8F0vxfvwc78Wc10ZE2QduN2XQY47tsYjoy7uPZsB7sRTvx8/xXqyMoRsASBxBDwCJ\nI+ib2/68G2givBdL8X78HO/FChijB4DE8YkeABJH0ANA4gj6JmN7k+2v2D5u+ynbd+TdU95st9ke\nt/3vefeSN9sdtj9v+7vZfyNvyrunPNn+y+zv5JjtA7Z/Oe+emhFB33zOSPpARLxe0rWSbrO9Jeee\n8naHpON5N9Ek/lHSIxHxG5KuUgu/L7a7Jf2FpL6IeKOkNkm78+2qORH0TSYiTkbEkWz6Zc39Ibfs\njfZtb5T0u5LuybuXvNn+FUnXSbpXkiLipxExm2tT+btAUsH2BZJeIWk6536aEkHfxGxvltQr6cmc\nW8nTxyX9laSf5dxHM3idpBlJ/5INZd1j+6K8m8pLRBQl/b2kH0g6KenFiHg0366aE0HfpGy/UtIX\nJL0/Il7Ku5882L5J0qmIOJx3L03iAknXSPpURPRK+j9J+/JtKT+2L5F0s6TLJXVJusj2H+fbVXMi\n6JuQ7XbNhfyDETGcdz852ibp7baflfSQpLfY/td8W8rVlKSpiJj/P7zPay74W9VbJX0vImYi4rSk\nYUm/nXNPTYmgbzK2rbkx2OMR8bG8+8lTRAxGxMaI2Ky5k2xfjoiW/cQWEf8r6TnbPVnpBknfybGl\nvP1A0rW2X5H93dygFj45fS7JPBw8IdskvVPShO2jWe2DEXEov5bQRP5c0oO2L5T0jKR35dxPbiLi\nSdufl3REc1erjYvbIZTFLRAAIHEM3QBA4gh6AEgcQQ8AiSPoASBxBD0AJI6gB4DEEfQAkLj/B+Oq\n7PqIX/OQAAAAAElFTkSuQmCC\n", "text/plain": [ - "
" + "\u003cFigure size 600x400 with 1 Axes\u003e" ] }, "metadata": {}, "output_type": "display_data" } - ], - "source": [ - - "plt.scatter(df_lin_mean.x, df_lin_mean.y)\n", - "plt.show()" ] }, { + "id": "11b68c97", "cell_type": "code", - "execution_count": null, + "source": [ + "from sklearn import linear_model\n", + "\n", + "\n", + "class UnivarLinearReg(Operation):\n", + " def __init__(self, x, y, grp):\n", + " self.lm = linear_model.LinearRegression()\n", + " # Delegate most of the computations to Mean Metrics.\n", + " child = MetricList((Mean(x), Mean(y)))\n", + " # Register grp as the extra_index.\n", + " super(UnivarLinearReg, self).__init__(child, '%s ~ %s' % (y, x), grp)\n", + " # For custom Metrics, class-level caching needs to be manually enabled.\n", + " # See the Caching section for more information.\n", + " self.cache_across_instances = True\n", + "\n", + " def split_data(self, df, split_by=None):\n", + " \"\"\"The 1st element in yield will be passed to compute().\"\"\"\n", + " if not split_by:\n", + " yield self.compute_child(df, self.extra_split_by), None\n", + " else:\n", + " # grp needs to come after split_by.\n", + " child = self.compute_child(df, split_by + self.extra_split_by)\n", + " keys, indices = list(zip(*child.groupby(split_by).groups.items()))\n", + " for i, idx in enumerate(indices):\n", + " yield child.loc[idx.unique()].droplevel(split_by), keys[i]\n", + "\n", + " def compute(self, df):\n", + " self.lm.fit(df.iloc[:, [0]], df.iloc[:, 1])\n", + " return pd.DataFrame([self.lm.coef_[0], self.lm.intercept_])\n", + "\n", + "\n", + "lr = UnivarLinearReg('x', 'y', 'grp')\n", + "Jackknife('cookie', lr, 0.95).compute_on(df_lin)" + ], "metadata": { "colab": { "height": 143 @@ -3542,15 +3624,16 @@ "id": "q7InRdD7HsoD", "outputId": "c9834321-2e30-422a-c53b-137685e05c86" }, + "execution_count": null, "outputs": [ { "data": { "text/html": [ "\n", - "
\n", - "
\n", - "
\n", - "\n", - "\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
Metricy ~ x
ValueJackknife CI-lowerJackknife CI-upper
02.0000651.9999282.000203
10.4995770.4983790.500775
\n", - "
\n", - " \n", + " \u003csvg xmlns=\"http://www.w3.org/2000/svg\" height=\"24px\"viewBox=\"0 0 24 24\"\n", + " width=\"24px\"\u003e\n", + " \u003cpath d=\"M0 0h24v24H0V0z\" fill=\"none\"/\u003e\n", + " \u003cpath d=\"M18.56 5.44l.94 2.06.94-2.06 2.06-.94-2.06-.94-.94-2.06-.94 2.06-2.06.94zm-11 1L8.5 8.5l.94-2.06 2.06-.94-2.06-.94L8.5 2.5l-.94 2.06-2.06.94zm10 10l.94 2.06.94-2.06 2.06-.94-2.06-.94-.94-2.06-.94 2.06-2.06.94z\"/\u003e\u003cpath d=\"M17.41 7.96l-1.37-1.37c-.4-.4-.92-.59-1.43-.59-.52 0-1.04.2-1.43.59L10.3 9.45l-7.72 7.72c-.78.78-.78 2.05 0 2.83L4 21.41c.39.39.9.59 1.41.59.51 0 1.02-.2 1.41-.59l7.78-7.78 2.81-2.81c.8-.78.8-2.07 0-2.86zM5.41 20L4 18.59l7.72-7.72 1.47 1.35L5.41 20z\"/\u003e\n", + " \u003c/svg\u003e\n", + " \u003c/button\u003e\n", " \n", " \n", - "
\n", - " \n", - "
\n", + " \u003csvg xmlns=\"http://www.w3.org/2000/svg\" height=\"24px\"viewBox=\"0 0 24 24\"\n", + " width=\"24px\"\u003e\n", + " \u003cg\u003e\n", + " \u003cpath d=\"M19 3H5c-1.1 0-2 .9-2 2v14c0 1.1.9 2 2 2h14c1.1 0 2-.9 2-2V5c0-1.1-.9-2-2-2zM9 17H7v-7h2v7zm4 0h-2V7h2v10zm4 0h-2v-4h2v4z\"/\u003e\n", + " \u003c/g\u003e\n", + " \u003c/svg\u003e\n", + " \u003c/button\u003e\n", + " \u003c/div\u003e\n", " \n", - " \n", + " \u003c/style\u003e\n", "\n", - " \n", - " \n", + " \u003c/style\u003e\n", "\n", - " \n", - "
\n", - "
\n", + " \u003c/script\u003e\n", + " \u003c/div\u003e\n", + " \u003c/div\u003e\n", " " ], "text/plain": [ @@ -3739,45 +3822,46 @@ "metadata": {}, "output_type": "execute_result" } - ], + ] + }, + { + "id": "aeee15f2", + "cell_type": "code", "source": [ - "from sklearn import linear_model\n", - "\n", + "class UnivarLinearRegSlow(Metric):\n", "\n", - "class UnivarLinearReg(Operation):\n", " def __init__(self, x, y, grp):\n", " self.lm = linear_model.LinearRegression()\n", - " # Delegate most of the computations to Mean Metrics.\n", - " child = MetricList((Mean(x), Mean(y)))\n", - " # Register grp as the extra_index.\n", - " super(UnivarLinearReg, self).__init__(child, '%s ~ %s' % (y, x), grp)\n", + " # Doesn't delegate.\n", + " self.x = x\n", + " self.y = y\n", + " self.grp = grp\n", + " super(UnivarLinearRegSlow, self).__init__(\n", + " '%s ~ %s' % (y, x), additional_fingerprint_attrs=['x', 'y', 'grp']\n", + " )\n", " # For custom Metrics, class-level caching needs to be manually enabled.\n", " # See the Caching section for more information.\n", " self.cache_across_instances = True\n", "\n", " def split_data(self, df, split_by=None):\n", " \"\"\"The 1st element in yield will be passed to compute().\"\"\"\n", + " idx = split_by + [self.grp] if split_by else self.grp\n", + " mean = df.groupby(idx).mean()\n", " if not split_by:\n", - " yield self.compute_child(df, self.extra_split_by), None\n", + " yield mean, None\n", " else:\n", - " # grp needs to come after split_by.\n", - " child = self.compute_child(df, split_by + self.extra_split_by)\n", - " keys, indices = list(zip(*child.groupby(split_by).groups.items()))\n", + " keys, indices = list(zip(*mean.groupby(split_by).groups.items()))\n", " for i, idx in enumerate(indices):\n", - " yield child.loc[idx.unique()].droplevel(split_by), keys[i]\n", + " yield mean.loc[idx.unique()].droplevel(split_by), keys[i]\n", "\n", " def compute(self, df):\n", " self.lm.fit(df.iloc[:, [0]], df.iloc[:, 1])\n", - " return pd.DataFrame([self.lm.coef_[0], self.lm.intercept_])\n", + " return pd.Series((self.lm.coef_[0], self.lm.intercept_))\n", "\n", "\n", - "lr = UnivarLinearReg('x', 'y', 'grp')\n", - "Jackknife('cookie', lr, 0.95).compute_on(df_lin)" - ] - }, - { - "cell_type": "code", - "execution_count": null, + "lr_slow = UnivarLinearRegSlow('x', 'y', 'grp')\n", + "Jackknife('cookie', lr_slow, 0.95).compute_on(df_lin)" + ], "metadata": { "colab": { "height": 143 @@ -3795,15 +3879,16 @@ "id": "6PdjGmjqTmrl", "outputId": "c37e93b5-e93f-4193-be6b-9d7641e86413" }, + "execution_count": null, "outputs": [ { "data": { "text/html": [ "\n", - "
\n", - "
\n", - "
\n", - "\n", - "\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
Metricy ~ x
ValueJackknife CI-lowerJackknife CI-upper
02.0000651.9999282.000203
10.4995770.4983790.500775
\n", - "
\n", - " \n", + " \u003csvg xmlns=\"http://www.w3.org/2000/svg\" height=\"24px\"viewBox=\"0 0 24 24\"\n", + " width=\"24px\"\u003e\n", + " \u003cpath d=\"M0 0h24v24H0V0z\" fill=\"none\"/\u003e\n", + " \u003cpath d=\"M18.56 5.44l.94 2.06.94-2.06 2.06-.94-2.06-.94-.94-2.06-.94 2.06-2.06.94zm-11 1L8.5 8.5l.94-2.06 2.06-.94-2.06-.94L8.5 2.5l-.94 2.06-2.06.94zm10 10l.94 2.06.94-2.06 2.06-.94-2.06-.94-.94-2.06-.94 2.06-2.06.94z\"/\u003e\u003cpath d=\"M17.41 7.96l-1.37-1.37c-.4-.4-.92-.59-1.43-.59-.52 0-1.04.2-1.43.59L10.3 9.45l-7.72 7.72c-.78.78-.78 2.05 0 2.83L4 21.41c.39.39.9.59 1.41.59.51 0 1.02-.2 1.41-.59l7.78-7.78 2.81-2.81c.8-.78.8-2.07 0-2.86zM5.41 20L4 18.59l7.72-7.72 1.47 1.35L5.41 20z\"/\u003e\n", + " \u003c/svg\u003e\n", + " \u003c/button\u003e\n", " \n", " \n", - "
\n", - " \n", - "
\n", + " \u003csvg xmlns=\"http://www.w3.org/2000/svg\" height=\"24px\"viewBox=\"0 0 24 24\"\n", + " width=\"24px\"\u003e\n", + " \u003cg\u003e\n", + " \u003cpath d=\"M19 3H5c-1.1 0-2 .9-2 2v14c0 1.1.9 2 2 2h14c1.1 0 2-.9 2-2V5c0-1.1-.9-2-2-2zM9 17H7v-7h2v7zm4 0h-2V7h2v10zm4 0h-2v-4h2v4z\"/\u003e\n", + " \u003c/g\u003e\n", + " \u003c/svg\u003e\n", + " \u003c/button\u003e\n", + " \u003c/div\u003e\n", " \n", - " \n", + " \u003c/style\u003e\n", "\n", - " \n", - " \n", + " \u003c/style\u003e\n", "\n", - " \n", - "
\n", - "
\n", + " \u003c/script\u003e\n", + " \u003c/div\u003e\n", + " \u003c/div\u003e\n", " " ], "text/plain": [ @@ -3992,46 +4077,16 @@ "metadata": {}, "output_type": "execute_result" } - ], - "source": [ - "class UnivarLinearRegSlow(Metric):\n", - "\n", - " def __init__(self, x, y, grp):\n", - " self.lm = linear_model.LinearRegression()\n", - " # Doesn't delegate.\n", - " self.x = x\n", - " self.y = y\n", - " self.grp = grp\n", - " super(UnivarLinearRegSlow, self).__init__(\n", - " '%s ~ %s' % (y, x), additional_fingerprint_attrs=['x', 'y', 'grp']\n", - " )\n", - " # For custom Metrics, class-level caching needs to be manually enabled.\n", - " # See the Caching section for more information.\n", - " self.cache_across_instances = True\n", - "\n", - " def split_data(self, df, split_by=None):\n", - " \"\"\"The 1st element in yield will be passed to compute().\"\"\"\n", - " idx = split_by + [self.grp] if split_by else self.grp\n", - " mean = df.groupby(idx).mean()\n", - " if not split_by:\n", - " yield mean, None\n", - " else:\n", - " keys, indices = list(zip(*mean.groupby(split_by).groups.items()))\n", - " for i, idx in enumerate(indices):\n", - " yield mean.loc[idx.unique()].droplevel(split_by), keys[i]\n", - "\n", - " def compute(self, df):\n", - " self.lm.fit(df.iloc[:, [0]], df.iloc[:, 1])\n", - " return pd.Series((self.lm.coef_[0], self.lm.intercept_))\n", - "\n", - "\n", - "lr_slow = UnivarLinearRegSlow('x', 'y', 'grp')\n", - "Jackknife('cookie', lr_slow, 0.95).compute_on(df_lin)" ] }, { + "id": "a0712d3e", "cell_type": "code", - "execution_count": null, + "source": [ + + "%%timeit\n", + "Jackknife('cookie', lr, 0.95).compute_on(df_lin)" + ], "metadata": { "executionInfo": { "elapsed": 2244, @@ -4046,6 +4101,7 @@ "id": "aNLbzk-CUgqm", "outputId": "252c3218-9a82-4b48-90c5-df3d938bd236" }, + "execution_count": null, "outputs": [ { "name": "stdout", @@ -4054,16 +4110,16 @@ "1 loops, best of 5: 357 ms per loop\n" ] } - ], - "source": [ - - "%%timeit\n", - "Jackknife('cookie', lr, 0.95).compute_on(df_lin)" ] }, { + "id": "0ca94456", "cell_type": "code", - "execution_count": null, + "source": [ + + "%%timeit\n", + "Jackknife('cookie', lr_slow, 0.95).compute_on(df_lin)" + ], "metadata": { "executionInfo": { "elapsed": 3649, @@ -4078,6 +4134,7 @@ "id": "a3BBhzbdUUwo", "outputId": "c6728da1-d04c-46d3-8b04-5fb62811b125" }, + "execution_count": null, "outputs": [ { "name": "stdout", @@ -4086,29 +4143,21 @@ "1 loops, best of 5: 586 ms per loop\n" ] } - ], - "source": [ - - "%%timeit\n", - "Jackknife('cookie', lr_slow, 0.95).compute_on(df_lin)" ] }, { + "id": "a617e872", "cell_type": "markdown", - "metadata": { - "id": "QFjhj96EdK-r" - }, "source": [ "## LOWESS" - ] + ], + "metadata": { + "id": "QFjhj96EdK-r" + } }, { + "id": "a48d4dc6", "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "2e_ttZzta7JH" - }, - "outputs": [], "source": [ "# Mimics that measurements, y, are taken repeatedly at a fixed grid, x.\n", "np.random.seed(42)\n", @@ -4116,15 +4165,16 @@ "x = list(range(5))\n", "df_sin = pd.DataFrame({'x': x * size, 'cookie': np.repeat(range(size), len(x))})\n", "df_sin['y'] = np.sin(df_sin.x) + np.random.normal(scale=0.5, size=len(df_sin.x))" - ] + ], + "metadata": { + "id": "2e_ttZzta7JH" + }, + "execution_count": null, + "outputs": [] }, { + "id": "8534a5fd", "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "Eei8Kd0wd-Gt" - }, - "outputs": [], "source": [ "import statsmodels.api as sm\n", "\n", @@ -4149,11 +4199,19 @@ " lowess(data[self.y], data[self.x]), columns=[self.x, self.y]\n", " )\n", " return lowess_fit.drop_duplicates().reset_index(drop=True)" - ] + ], + "metadata": { + "id": "Eei8Kd0wd-Gt" + }, + "execution_count": null, + "outputs": [] }, { + "id": "06fd1856", "cell_type": "code", - "execution_count": null, + "source": [ + "Lowess('x', 'y') | compute_on(df_sin)" + ], "metadata": { "colab": { "height": 206 @@ -4171,15 +4229,16 @@ "id": "qvp8ihsnlY-d", "outputId": "e5728324-6bb1-4281-8afc-9aae71cf1565" }, + "execution_count": null, "outputs": [ { "data": { "text/html": [ "\n", - "
\n", - "
\n", - "
\n", - "\n", - "\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
xy
00.00.016953
11.00.592061
22.00.575784
33.0-0.109848
44.0-1.096083
\n", - "
\n", - " \n", + " \u003csvg xmlns=\"http://www.w3.org/2000/svg\" height=\"24px\"viewBox=\"0 0 24 24\"\n", + " width=\"24px\"\u003e\n", + " \u003cpath d=\"M0 0h24v24H0V0z\" fill=\"none\"/\u003e\n", + " \u003cpath d=\"M18.56 5.44l.94 2.06.94-2.06 2.06-.94-2.06-.94-.94-2.06-.94 2.06-2.06.94zm-11 1L8.5 8.5l.94-2.06 2.06-.94-2.06-.94L8.5 2.5l-.94 2.06-2.06.94zm10 10l.94 2.06.94-2.06 2.06-.94-2.06-.94-.94-2.06-.94 2.06-2.06.94z\"/\u003e\u003cpath d=\"M17.41 7.96l-1.37-1.37c-.4-.4-.92-.59-1.43-.59-.52 0-1.04.2-1.43.59L10.3 9.45l-7.72 7.72c-.78.78-.78 2.05 0 2.83L4 21.41c.39.39.9.59 1.41.59.51 0 1.02-.2 1.41-.59l7.78-7.78 2.81-2.81c.8-.78.8-2.07 0-2.86zM5.41 20L4 18.59l7.72-7.72 1.47 1.35L5.41 20z\"/\u003e\n", + " \u003c/svg\u003e\n", + " \u003c/button\u003e\n", " \n", " \n", - "
\n", - " \n", - "
\n", + " \u003csvg xmlns=\"http://www.w3.org/2000/svg\" height=\"24px\"viewBox=\"0 0 24 24\"\n", + " width=\"24px\"\u003e\n", + " \u003cg\u003e\n", + " \u003cpath d=\"M19 3H5c-1.1 0-2 .9-2 2v14c0 1.1.9 2 2 2h14c1.1 0 2-.9 2-2V5c0-1.1-.9-2-2-2zM9 17H7v-7h2v7zm4 0h-2V7h2v10zm4 0h-2v-4h2v4z\"/\u003e\n", + " \u003c/g\u003e\n", + " \u003c/svg\u003e\n", + " \u003c/button\u003e\n", + " \u003c/div\u003e\n", " \n", - " \n", + " \u003c/style\u003e\n", "\n", - " \n", - " \n", + " \u003c/style\u003e\n", "\n", - " \n", - "
\n", - "
\n", + " \u003c/script\u003e\n", + " \u003c/div\u003e\n", + " \u003c/div\u003e\n", " " ], "text/plain": [ @@ -4378,14 +4437,27 @@ "metadata": {}, "output_type": "execute_result" } - ], - "source": [ - "Lowess('x', 'y') | compute_on(df_sin)" ] }, { + "id": "74f8cdcc", "cell_type": "code", - "execution_count": null, + "source": [ + + "jk = Lowess('x', 'y') | Jackknife('cookie', confidence=0.9) | compute_on(df_sin)\n", + "point_est = jk[('y', 'Value')]\n", + "ci_lower = jk[('y', 'Jackknife CI-lower')]\n", + "ci_upper = jk[('y', 'Jackknife CI-upper')]\n", + "\n", + "plt.scatter(df_sin.x, df_sin.y)\n", + "plt.plot(x, point_est, c='g')\n", + "plt.fill_between(\n", + " x, ci_lower,\n", + " ci_upper,\n", + " color='g',\n", + " alpha=0.5)\n", + "plt.show()" + ], "metadata": { "colab": { "height": 265 @@ -4403,33 +4475,18 @@ "id": "3eR0fV8feyqu", "outputId": "9d9cacaf-b344-4eda-91d9-7a4626b92839" }, + "execution_count": null, "outputs": [ { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXwAAAD4CAYAAADvsV2wAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90\nbGliIHZlcnNpb24zLjYuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/av/WaAAAACXBIWXMAAAsT\nAAALEwEAmpwYAAA460lEQVR4nO3deVQc173o++/uppkRIDGDGDQBmtGARmuKbMmSZUuybMdJ7Dgn\ntuIkPvfel3WcxE4c+9x3s+LzfO9ZcZKT5yFx5sSJJ0UvdqxYxpI8RGjGmq0BkJgEQsw00MN+f0BL\ngBrokQb691mLpe7q6qpNqfn1rl2/+m2ltUYIIcTYZwh0A4QQQgwPCfhCCBEkJOALIUSQkIAvhBBB\nQgK+EEIEiZBAN2AwCQkJOjs7O9DNEEKIUePw4cNXtdaJzl4b0QE/OzubQ4cOBboZQggxaiilygd6\nTYZ0hBAiSEjAF0KIICEBXwghgoQEfCGECBIS8IUQIkiM6Cwd4X87jlby3K6zVDWaSYuL4PF1uWwu\nSA90s4QQfiABP4jtOFrJE28ex2yxAVDZaOaJN48DSNAXYgySIZ0g9tyus9eDvYPZYuO5XWcD1CIh\nhD9JwA9iVY1mt5YLIUY3CfhBLC0uwq3lQojRTQJ+EHt8XS4RJmOfZREmI4+vyw1Qi4QQ/iQXbYOY\n48KsZOkIERwk4Ae5zQXpEuCFCBIS8IOc5OELETwk4AcxycMXIrjIRdsgJnn4QgQXCfhBTPLwhQgu\nEvCDmOThCxFcJOAHMcnDFyK4yEXbICZ5+EIEFwn4QU7y8IUIHj4Z0lFKvaKUqlVKnRjg9VVKqSal\n1LGenx/4Yr9CCCFc56se/q+BnwG/HWSdD7XWd/hof0IIIdzkkx6+1nofcM0X2xJCCOEfwzmGv0Qp\nVQJUAf+mtT7pbCWl1HZgO0BmZuYwNk+IoUkpCjGaDVda5hEgS2s9B/gpsGOgFbXWL2mtF2itFyQm\nJg5T84QYmqMURWWjGc2NUhQ7jlYGumlCuGRYAr7Wullr3drz+B3ApJRKGI59C+ErUopCjHbDEvCV\nUilKKdXzuLBnv/XDsW8hfEVKUYjRzidj+EqpPwGrgASlVAXwNGAC0Fq/AGwDvq6UsgJm4PNaa+2L\nfQsxXNLiIqh0EtylFIUYLXwS8LXW9w/x+s/oTtsUYtR6fF1un3LSIKUoxOgid9oK4SIpRSFGOwn4\nQrhBSlGI0UyqZQohRJCQHr4QbpAbr8RoJgFfCBfJHMBitJMhHSFcJDdeidFOAr4QLpIbr8RoJwFf\nCBfFRZrcWi7ESCMBXwgXDXRvuNwzLkYLCfhCuKjJbHFruRAjjQR8IVw0UM0cqaUjRgsJ+EK46PF1\nuUSYjH2WSS0dMZpIHr4QLpJaOmK0k4AvhBuklo4YzSTgC+EGKa0gRjMJ+EK4SEoriNFOLtoK4SIp\nrSBGOwn4QrjI2fSGgy0XYqSRgC+Ei4xKubVciJFGAr4QLrINUENhoOVCjDQ+CfhKqVeUUrVKqRMD\nvK6UUj9RSp1XSn2qlJrni/0KMZykhy9GO1/18H8NrB/k9duBqT0/24H/10f7FWLYSA9fjHY+Cfha\n633AtUFWuQv4re62H4hTSqX6Yt9CDJf0AWrmDLRciJFmuMbw04HLvZ5X9Cy7iVJqu1LqkFLqUF1d\n3bA0LpjtOFrJsmeLyPnu2yx7togdRysD3aQR6/F1uZgMfYdvTAYltXTEqDFcAd/ZIKfT82Ct9Uta\n6wVa6wWJiYl+blZw23G0ksdfK6Gy0YymO73w8ddKJOgPpv8nWYbvxSgyXAG/ApjY63kGUDVM+xYD\neGbnSSz2vt+7FrvmmZ0nA9Sike25XWex2PodL5uWG68GIWeQI8twBfydwIM92TqLgSatdfUw7VsM\noHGAiTsGWh7s5MYr9zhKUfQ+g3zizeMS9APIV2mZfwL+CeQqpSqUUl9VSj2qlHq0Z5V3gIvAeeBl\n4Bu+2K8Qw0nSMt0jpShGHp8UT9Na3z/E6xr4pi/2JXwnPtJEQ/vNvfl4mZTbKUnLdE/VAGc+Ay0X\n/id32gaxpzfNwGTsl3ViVDy9aUaAWjSyxUU4/yIcaHmwkykhRx4J+EFsc0E6z22bQ3pcBIrufPLn\nts2RUr8DGGjkRkZ0nJMpIUceqYcf5GQGJ9c1Ohn+Gmx5sJMpIUceCfhCuCgtLsJpRo4MUQxMOhQj\niwzpBDnJk3adDFGI0U56+EFMpuxzjwxRiNFuzAV8mWTadYPlScsxc06GKMRoNqYCvvRY3SN50kIE\nlzE1hi939rkndoD88YGWByOtNTa7DYvNQoe1g3ZLO61drbRb2gPdNCHcNqZ6+NJjdY+v8sq11ti0\nDbu2Y9d2bPYbj+3aPuBrni632q19fmx2243nuudfmxWbtl1/3fG4/3ObtnU/71lms9uwait2+439\nqV4lMVWvg7M6ezUbpm5gQuQETw6/EMNuTAV8SZtzT2O7BTsddKlSugzn0MqCxkZTl51/3/NPrNp6\nPRha7dbrwdambX0CpUajUGg0BnXjpLF3oHQ8VKg+hbE1GqW639t/+fXHPaULHOsopbp/6PuvQRlu\nWjbUawZlIMQQgslgcro+9A3yDla7lb3le9lTtod1U9axbvI6YsNjffC/IoT/jKmA//i63D5j+CBp\nc87Y7DbOXTuHIbqIus7jgB2lw1EYAEV0WCjNnc19Ap/JaLopUPYOkMEmxBBCxrgMLDYLu87v4r0L\n77Fx2kbWTlpLdGh0oJsnhFNjKuBL2tzgalprOFB5gPcvvk9zZzNT0i00lk3AZrsRsEOMBlZMSSIm\nLCaALR09TEYTE2Mn0mXrYufZnfz93N+5M/dOVuesJtIUGejmCdGH0iO40t+CBQv0oUOHAt2MUa2t\nq41Pr3zKexffo7ShFKUUiZGJRJi6h7mKzlzheGUzWncPrcxKH8eavOQAt3r06rR2UtNaQ6Qpki15\nW7gl6xbCQsIC3SwRRJRSh7XWC5y9NqZ6+KKbY8jmw/IPKa4sxma3MS5sHJmxmX2GX85UN3OiqvnG\nGLnWnKhqJi02grzUcYFq/qgWFhJGVlwWZouZPxz/Azs/28m26dtYnLGYUGNooJsngpwE/DGk/5BN\nqDGUlOgUQgzO/5v3fFaHvd8Uh3a7Zs9ndRLwvRRhiiArLou2rjZeOfoKO87s4J7p97AwfeGA/x9C\n+Jt88ka5tq42SmpKeK/0Pcoayq4P2cRHxA/53o6ei9s2mrCr9p6sFEWrVXG5KbTPRVuDMly/UDvU\nMsfz3pkxwSoqNIqo0Chau1p58fCLvHXmLe6bcR9zU+ZiNBiH3oAQPiQBfxSy2W18Vv8ZH176kAOV\nBwYcshmMXduxqAo6DWexGWpvev2VY75tc+8vAFe/OAZdRq/XHF8wTpYNuT/6fUG5uCwlOoWEyASX\nf//o0GiiQ6Np6mjiJ8U/IWNcBvfNvI+ZSTP7pLIK4U8S8EcJrfX1IZui0iKXhmyc6bR2cqzmGMWV\nxbSHNKB0JOG2uRh1At2J8BpTiGLjrBS01t03PWG//lhrjR37jce9ll1f351l3NiGq/uz2+1YsfZt\nQ6/tONv2QMt65/u7a3bSbFZmr2R8xHiX3xMbHktseCwN5gb+zyf/h0nxk7h3xr3kJeQF9ZmQGB6S\npTPCtXW1cazmGLtLd1PWUIZBGUiITLieZeOqBnMDByoPcLTmKJ22TiaOm8jEqNmcuhyD1jcCjUEp\nbpueHDRj+FprNPqmLw5nXyaOZTZto+RKCQcqD2DXdgpSCliRtYJxYe4dM60118zXaO5sJj8xn23T\ntzE5frIEfuGVwbJ0fBLwlVLrgecBI/ALrfWz/V5fBfwVKO1Z9KbW+n8Otd1gDfi9h2yKK4qxa3t3\nzzAs1q1goLXmUtMliiuLOXP1DEoppidOZ3H6YtLHdd+bcKa6mY8v1NPSYSEm3MSyyROCJth7q6Wz\nhX2X9nGk+ggGZWBh2kKWZy53O/9ea83V9qu0WlqZkzyHrflbyY7L9k+jxZjn14CvlDICnwG3AhXA\nQeB+rfWpXuusAv5Na32HO9sOpoDfe8jm/dL3aelsISwkjITIBLezOmx2GyfrTrK/Yj/VrdWEh4Qz\nP3U+hemFbvdCxdAazA3sLd9LyZUSQo2hLMlYwpKMJW7n32utudJ2BbPFTGF6IZvzNl//YhbCVf7O\nwy8EzmutL/bs7FXgLuDUoO8SQK8hm4u7KW0sxaiMJEQmuDUu3Htbh6sPc7DqIK1drSREJrBx6kbm\nJM/BZJQKmP4SHxHP5rzNLJ24lD1le9hbvpcDlQdYnrmchWkLXT72SilSolOwazufXvmUg1UHWTZx\nGZtyN5ESneLn30IEA1/08LcB67XWD/c8fwBYpLV+rNc6q4A36D4DqKK7t39ygO1tB7YDZGZmzi8v\nL/eqfSORY8hmX/m+6+PAngzZONS21bK/Yj/Ha49jtVuZHD+ZxRmLZTw4QCqbK/mg7AMuNFwgJjSG\nFVkrKEgpcDsN067tVLdUY9VWVmetZsO0DW5lBong5O8hnXuAdf0CfqHW+l97rTMOsGutW5VSG4Dn\ntdZTh9r2WBrS6T1ks/viblq7WgkLCSMxMtGjfGytNeevnWd/5X4uNlwkxBDC7OTZLE5fTGJUolvb\nstltNHQ0oFAYDUaMyojRYMSgDNcfB3s+vSfKGssoKi3icvNl4sPjWZW9yqM0TJvdRnVrNXZt59bJ\nt3L7lNuJC4/zT6PFqOfvgL8EeEZrva7n+RMAWusfDfKeMmCB1vrqYNseCwG/tauVkpqSm4Zs3M2y\nceiydVFSU0JxZTH15npiQmNYmL6Q+anz3b5Y6AgkNruN2cmzMRqMmC1mOqwddNg66LR2dv/YOrHa\nrX2CvlLdZY4dZZGBPtku/b8wjKrnuZMvFEeu+1jk+GIuKiuiprWGxMhEVuesJm+C+2mYVruV6pZq\nlFJsmLqBWyfdKkXuxE38HfBD6L5o+zmgku6Ltl/oPWSjlEoBrmittVKqEHgdyNJD7Hy0BnxfD9kA\nNHU0caDqAEeqj9Bh7SAtJo3F6YuZnjjd7TMEm91GTWsNVruV5ZnLuWPaHSRHD14wzWa3YbFb6LJ1\nYbH1/DvIc7PVTLul/foXiNna86/FTKetkw5r9xdKh7UDi91yo9Ryrxr0WnfX2bdjB26kUDpunhro\nbKT/l8tI+DLRWnOq7hQflH1AvbmetJg01mSvYVL8JLfb12XroqalBpPRxKbcTazOXk1UaJSfWi5G\nm+FIy9wA/JjutMxXtNY/VEo9CqC1fkEp9RjwdcAKmIFvaa0/GWq7ngT8QE1i7hiyKa4s5v2L73s9\nZONQ0VzB/or9nKrrvgaen5DP4ozFZIzLcDtQ9A/0G6dtHBEXA+3ajsVmcfkLxfHFYbaar//bYe3o\n89Np7aTD1kGXtev6JCu9v1Cg79mJHTvd86voQc9GHMtCDCEe/b/atZ2SKyXsLdtLU2cT2bHZrMlZ\nw8TYiW5vy1GZMzwknC15W1iRvYLwkHC3tyPGFr8HfH9xN+D3n8QcuidA+dHWWX4L+r2HbMoaPb8x\nqjeb3cbpq6fZX7GfypZKwoxhzEudR2F6oUdjt70D/bLMZWycupHUmFSP2zeaaK2x2q1DnpE4nnda\nO6+fnTi+WBxnKI7XOm2dtHS2YNM2UqJSPCp/bLVbOVx9mA/LP6TN0sa08dNYnbPaoy/gDmsHNa01\nRIdGs236NpZOXCqVOYNY0AT8Zc8WOZ3iMD0ugo+/u8Zn7fLHkA2A2WK+nlbZ3NnM+IjxLEpfxJzk\nOR4FFbu2U9NaQ5eti6UTl7Jp2qabAn2gzohGu3ZLO0WlRew8uxOr3UpqdKpHqa9dti6KK4v55PIn\ndFg7mJk4k1XZqzyaJ7fd0k5tay1xEXHcO+NeCtMLpTJnEAqagJ/z3bedVkZRQOmzG71qi7+GbACu\ntl+luLKYkpoSLHYLOXE5LEpfxLQJ0zz6Aukf6O+YdgdpMWk3rReIM6KxprmzmV3nd/Hu+XdRSpEa\nnerR58FsMfNJxScUVxRjtVuZmzKXlVkrPZont7WrlavtV0mITOC+GfcxL3WeVOYMIkET8P3Rw/fH\nkA10f4FcbLjI/sr9nL92HqMyMit5FovTFw95AXUgjrxti93CkowlbMrd5DTQOwzXGVEwqG+v5+3P\n3uaDsg8wGU2kRKd4VAWztauVDy99yOGqwwAsSFvALZm3eHRRtrmzmXpzPenR6dw38z5mJc8a9sqc\ncgY5/IIm4Puqx2q1W68P2RysPOizIRsAi83Cp7WfUlxRTF17HVGmKBamLWRB2gKPMy3s2k5NSw1d\n9i6WZCzhjml3uHRLvj/PiIJVdUs1O87soLiymPCQcJKjkj36zDR1NLG3fC/Hao4RYghhccZilk5c\n6tFF2caORhrMDWTHZXPvjHuZnjh9WDKX5AwyMIIm4IPnPQqtNdWt1RRXFFNUWuTTIRvoLrR1sOog\nh6oOYbaaSYlOYVH6ImYmzfR4nLX30M2i9EXcmXunW7VXpIfvP+WN5bx+6nU+vfIpMWExTIiY4FGQ\nvdp+lT1lezhZd5LwkHCWTVxGYXqh2xdltdY0dDTQ1NHEtAnTuGfGPUwdP9WvgV8+X4Ehc9oOorWr\nlWM1x3jv4nuUN5ZjVEYSoxI9umjmTFVLFfsr9nOy7iR2bSdvQh6LMhaRFZvl8R9b70BfmF7Inbl3\nkjEuw+3trM5L5Pf7LzldLryTFZfFt5Z8i3PXzvHnE3/m/LXzxIXHuTQTWW8JkQlsm76N5a3LKSot\n4v3S99lfsZ9bsm5hfup8lzsLSinGR4wnPjyeqpYqfrjvh8xKnsXd+XeTE5/jya84pConwX6w5cL/\nxlQP39VTSH8O2UB3QD5z9QzFFcVcar5EqDGUgpQCCtMLPSqK1nu7V1qv0GnrpDC9kE3TNnmUv+0g\nPbDhobXmRO0JXj3xKhUtFSREJHh8h+ylpksUlRZR3lRObFgsK7NXMid5jttj81prattqabe0Mz9t\nPlvytnj1WXJGPl+BETRDOoN9wD76zmq/DtlAdz700eqjHKg6QGNHI3HhcRSmF1KQUuDVDTGOQN9h\n66AwrbtH74s/ThnDH142u40j1Ud49cSrXDVfJTkq2e1yGHDjgn9RWRFVLVVMiJjA6pzVTE9wf2y+\ndydiScYS7sy902f3aMgYfmAETcB3FsDsmOkyXORrt7X1GbLx5R2J18zXKK4s5ljNMbpsXWTFZrEo\nYxG5E3K9yoroHegXpC7grry7yIzN9Fm7pQcWGBabhU8uf8Lrp16npauFlOgUjz6PWmvO1J/hg9IP\nqGuvIyU6hTXZa5gyfopHgb+mtQaLzcKKrBXcMe0Ot4vwOSNZOsMvaAK+I4BpNFZVQYfhOJ2Gz4gM\nNfKFhXk+G7KB7j+2ssYyiiuLOVt/FoMyMDNpJovSFw2aCukKu7ZT21qL2WpmQZrvA72D9MACq8Pa\nwZ6yPew4s4NOWyep0ake3SFr13aO1x5nT9keGjsayRyXyZqcNWTFZbm9rd6VOdfmrOX2qbe7fd1B\nBFbQBHxHAGu1XKXB9GuUNhFqjOPWvFSfTdtntVs5UXuC/RX7udJ2hUhTJPNT57MwbaHXlQt7z3g0\nP20+d+Xe5dEfrTukBxZ4rV2tvHfhPd4+9zZaa1JjUj3K3LLZbRypOcK+8n20drUyOX4ya3LWeNQB\n6V2Zc/2U9dw2+TaZLW2UCJqAD90B7H+9+xFn239JXFiqz+Zobe1q5VDVIQ5VHaLN0kZSVBKL0hcx\nK2mW17NJOQJ9h7WDuSlz2Zy3WeY0DUIN5gb+fu7v7C7djVEZSYlO8ej6ksVm4WDVQT669BFmq5n8\nhHxWZ6/2aIjGYrNQ01qDwWBg07RNrMlZQ3RotNvbEcMnqAI+dM849IMPfuCTC5s1rTUUVxRzvPY4\nNm1j6vipLM5YTE5cjtfDQ7179HNT5rIlf4sEekFtWy07z+7ko0sfEWYMIyU6xaPPWoe1g/0V+/ln\nxT+x2CzMTp7NyqyVHg3ROEoyh4aEsjlvMyuzVnp9t7nwDwn4brJrO+fqz7G/cj9ljWWYDCbmpMxh\nUfoin0wxdz0lztrO3OQbPfqRULddjBwVzRW8efpNDlcfJsoURWJkokefkXZLOx9d+ogDlQfQaOan\nzueWzFs8GoLsXZlza/5Wlk1c5lFhP+E/EvBd1Gnt5NiVYxRXFNPQ0cC4sHEUphcyL2WeT3ozvQP9\nnOQ5bMnbIoFeDMqRgvnaqdc4XXea2PBY4sPjPfrMNHc2s698H0drjmJQBgrTC1k2cZlHqaFmi5kr\nbVeIDYtl2/RtLM5Y7PXQpvANCfhDaOxopLiymKPVR+m0dZIxLoPF6YvJS8jzSY5+75tcZiXPYmv+\nVp8MCYngobXm9NXT/PnEnylrLGN8xHiPKmlCdxrx3rK9fFr7KWHGMJZMXMLi9MUe9dTbutqobau9\nXplzftp8qcwZYBLwndBac7n5Mvsr9nPm6hkAZiTOYFHGIo/KFAy0j7r2Otot7cxMmsmWvC0eTWkn\nhINd2zlWfYxXT77KldYrJEYlenwRtbatlg9KP+BM/RkiTZEsz1zOwrSFHmUItXS2UN9eT0p0CvfO\nuJe5qXOHvTKn6BZUAX+oLB2b3cbJupMUVxZT1VJFeEg481PnU5he6LO0M0egb7O0MTNxJlvzt0qg\nFz5ltVs5UHmAv5z8C40djSRHJXs87FjZXElRaREXGy8yLmwcK7JWMDd5rkc99aaOJq6Zr5EZm8l9\nM+9jRuIM+dwPs6AJ+I48/BbLFRpNfyBEJxBiNLA2L4nMhBAOVR3iYNVBWrtamRAxgcUZi5mdPNtn\n08FdD/RdbcxImsHW/K1Mjp8sH3jhN53WTj6+/DFvnHqDdmu7x1MuApQ2lFJUWkRFSwXx4fGszl7N\nzKSZbn9+tdY0djTS2NHIlPFTuGfGPeROyJW/g2ESNAHfcaetlavXA76NJnToOTpVKVa7lcnxk1mU\nvsij288H0jvQT0+cztb8rT7dvhBD8dWUi1przl07R1FpEVfarpAUlcSa7DUezb6mtabeXE9LZwsz\nkmbwwOwHgmYu5UDye8BXSq0HngeMwC+01s/2e131vL4BaAce0lofGWq7ntbSsVBLfehPsalqrIYa\n0EbmpXWnVSZFJbnxmw1Oa83V9qu0WlrJT8jn7vy7JdCLgGrqaGLXhV3sOr/LqykXtdacrDvJB2Uf\ncM18jfSYdD6X8zm3SymfqW7mo/NXaeq8SmSY5tHCB3jq1ntlrl0/8mvAV0oZgc+AW4EK4CBwv9b6\nVK91NgD/SnfAXwQ8r7VeNNS2PenhX26spzrsv2E11KB0BKH2qYw35bH9lnz3frFBXA/0Xa3kJeRx\n9/S7/T6ZhBDuqG+v52+f/Y09ZXu8mnLRZrdRcqWEveV7ae5sJicuhzU5a1xKbDhT3czuM7VYbXYA\nNBYwNrAxv4D/vfH/8rrmlHDO3wF/CfCM1npdz/MnALTWP+q1zovAHq31n3qenwVWaa2rB9u2p2P4\nl/nfWNUVwuzTMRlNrM1L8kl5hd6BPjchl7vz7/Z4onEhhkN1SzVvnXmLA5UHvJpy0Wq3cqjqEB9e\n+pB2Szu5E3JZnb160PmXf/lRKS0dlj7LNJqIsHY2zJ7AvTPuZe2ktdLb9zF/z3iVDlzu9byC7l78\nUOukAzcFfKXUdmA7QGamexUiHUW//te7j3C2/ZeMCw/3SS2d3mOREujFaJIak8o3Fn6DjVM3ejXl\nomNe3Xmp89hfsZ9PLn/CC4dfYFbSLFZlr3I6sU//YA+gUHR0RpESncKfjv+J4opiHp73sFtTcwrP\n+SLgO/vU9D9tcGWd7oVavwS8BN09fHcbs7kgnYWTl/ODD/7hdS2d3oF+6oSpPFb4mGQbBLnRWl3U\nV1MuhhpDWZG1goVpC/n48scUVxZzovYEBakFrMxa2Se1OSbc5DTox4SbCDWGkh2XzZW2Kzz1wVPc\nnX8366ask96+n/ni6FYAvSNrBlDlwTojRu9AP2X8FL658JvkJeRJoA9y/ecPqGw088SbxwFGRdBX\nSjFtwjS+v+L7HK89zqsnXqW0sdSjKRcjTBGsnbSWRemL+PDShxyuPkxJTQkL0xeyfOJyokKjWDZ5\nAv84fQW7/Ua/zWBQLJs84Xp7kqKS6LJ18ZeTf+FA5QEenvewz6daFDf4IuAfBKYqpXKASuDzwBf6\nrbMTeEwp9Srdwz1NQ43fB4LWmmvmazR3NkugFzd5btfZPpPFAJgtNp7bdXZUBHwHpRSzk2czI3EG\nh6sO8+eTf6a0sdSjKRdjwmLYMHUDSycuZW/ZXoorijlSfYTFGYuZEDLz5vN4J+fsocZQcuJzqGur\n4wd7fsDded29fanN43teB3yttVUp9Riwi+60zFe01ieVUo/2vP4C8A7dGTrn6U7L/Iq3+x3IjTtt\nLxMXZnVpDL93oJ88fjJfX/h18hPyJdCLPqqcTAc52PKRzmgwUphRSEFqwfUpF2vbaj2acjEuPI67\n8u5i6cSl7Cnfw77yfRjYT6jKJ1RPQ/WEGrvWfHyh3unfZGJUIhabhddOvUZxZTGPzH/ELzO9BbMx\ndePVYHfaOvuA9Q70k+InsW36NqYnuj8RtAgOY30OYF9NuQjd2UGvHPr/sBqqUTqccNscTDoH1XM5\n73+snTbo++vaumtQbcnfwvop6312N3ww8HeWzojh7JTbarPf1KPQWtPQ0UBTRxM58Tk8uuBRCfRi\nSI+vy3U6B/Dj63ID2CrfCQ8JZ/2U9SzPXO71lIupMamkmG6lobOSDkMJ5pBirPYaImwLGRc+9LCR\no7f/5qk3r4/ty+RA3htTAX+gU2tHpoAj0Dd2NDIpfhLb529neuJ0qeonXOIYpx+NWTruiA6NZkv+\nFlZlr+oz5WJqTKpbfys5CZG0VCQRZVtLpz5Fp+E4tpBrTI5d59L7TUYT2fHZXG2/yjN7nuGu3LvY\nOG2j9Pa9MKaGdJzV0gGIDgth64I4mjqayIrL4p7p9zAjaYYEeiFc4OmUi/1vvLKqWtqNn6BVF3dM\nu515KfNcPqu22CxUtlSSFpPGI/MecbvEQzAJmuJp/cfwjXoCBmM7czNN3DJpugR6IbxQ0VzBG6fe\n4EjNEZemXPzx7s9uWmanA7Pxn1gNNcxKmsUd0+5wq8de315PS1cLd+beycapG2V6RSeCJuADfH/H\ncX5XfJSrIa9gwMictCn8fOs3mZk0UwK9EF5yZ8pFZ6UVoPuMe1pmBXvK9jA+Yjz3TL9n0BIN/Vnt\nViqaK0iJTmH7/O1Mip/k1e801gwW8MdUBNxxtJI3DleidRih9inEWDdzteouLlZNkGAvhA8opZg8\nfjLfWfYdvrP8O8SFx1HaWEpTR9NN6y6bPIEQY9+/uxCjgeVTEliRtYIH5zxIp62TXxz9BUeqj+Bq\n5zPEEEJ2XDbtlnb+fc+/89rJ1+i0dvrk9xvrxlQPf6ynzQkx0vSecrG2tZaEqIQ+Uy6eqW7m4wv1\ntHRYiAk33XRfTGtXK2+dfouLjReZnTybjVPduyhrtVupbK4kKSqJR+Y/wpTxU3z6+41GQTOk46iH\n358CSp/d6LN2CSH68mbKRbu282H5h+wp30NCZAL3TL/H7Xkrrpmv0dTZxIYpG7gr7y63bxwbS4Jm\nSCc2wvmt2AMtF0L4RoghhKUTl/Ifa/+DB2Y/QEtXC+VN5S4NtRiUgZXZK3lw9oOYLWZePvIyx2qO\nubX/8RHjmThuIu9eeJenip7is/qbLxiLMdbDn/7U32m32G9aHmkycOr/vt2XTRNCDKLd0k7RxSJ+\n8smfKKmop6NzHOPCI4YsddLa1cqbp9+ktLGUOclz2DB1g9t59w3m7ntt1k1Zx5a8LR5P7j5aBU0P\n31mwH2y5EMI/Ik2RWNsLKLu4Fbt5LjbVQGPHNXafqeVMdfOA74sOjeZLs7/EyqyVlFwp4RdHfkFd\nW51b+46PiCczNpP3LrzH94u+z9mrZ739dcaMMRXwhRAjx3O7ztJlCSfKfgtx1i+gMNBhv8rHF+oH\nfZ9BGViVvYoHZj9Au6Wdl4+8TElNiVv7NhqMZMVlYdM2fvjhD/nDp3+g3dLuza8zLHYcrWTZs0Xk\nfPdtlj1bxI6jlT7d/pgK+PGRzsfqB1ouhPCf3qVOQnQSsdYvEKITaOyscikFc1L8JL42/2ukx6Sz\n4+wO/nr2r1hsN+f1DyYuPI6s2Cx2l+7m+0Xf53Tdabd/j+HiuHG0stGM5sZ8C74M+mMq4G+cnerW\nciGE/6TF9R07NxJNrPVuEsNmcLHhIla7dchtxITF8MCcB1iRuYJjNcd4+cjLbg/xGA1GsmKz0Frz\n7EfP8tuS39LW1ebWNobDYPMt+MqYCvgfnHH+QRhouRDCf1bnJd60TBHK/TO+wt3T76a8sRyzZei5\nBAzKwOqc1Xxp1pdos7Tx8pGX+fTKp263JzY8lszYTPaU7uF7Rd/jVN0pt7fhT8Mx38KYCvhjbYIK\nIUazgTpae87WszlvM48VPkZdex2NHY0ubW/y+Mk8Ov9R0mLSeOvMW+w8u9PtIR6jwUhmXCYKxbMf\nPcuvj/56xPT2+58RDbXcE2Mq4A/HARNCuGaoDtiijEV875bvYbVbqWmtcWmbMWExPDjnQW7JvIWj\nNUf5xdFfcLX9qtttiw2PJSs2i32X9vFk0ZMcv3Lc5dIO/vL4ulwiTMY+y3w938KYCvjDccCEEK5x\npQM2efxknl75NBMiJ3C56bJLQdegDKzJWcMXZ32R1q5WXjr8EsevHHe7fUaDkczYTIzKyHOfPMev\njv2K1q5Wt7fjK5sL0vnR1lmkx0Wg6C4J86Ots3w638KYuvEKuq90j/UJKoQYDRxZJ/1nCHMWxNot\n7bx0+CWOVh/tDsIGY//NOdXc2cwbp97gUvMl5qXOY/3k9R5Nfm7XdiqbK4kMjeThgoeZnTx71M6A\nFzS1dIQQI4s7HTCr3cprp17jnc/eIWNchsu17u3aTlFpER9f/pjkqGTumX4PEyIneNTels4W6trr\nWJG5gvtm3kdMWIxH2/GULzqsfgv4SqnxwJ+BbKAMuFdr3eBkvTKgBbAB1oEa058EfCGCi9aaveV7\n+fWxXzMhYoJbAfdc/TneOvMWNm1j07RNzEya6VEb7NpOVUsVESERfLXgq8xJmTMsvX13zogG48+A\n//8A17TWzyqlvgvEa62/42S9MmCB1tqtqysS8IUY3TztsZ6uO83zxc8TokJIiEpweX9NHU28cfoN\nLjdfZn7qfNZPWe/2BOwOjt7+sonLuH/W/YwLG7gGkC/4qry7P2vp3AX8pufxb4DNXm5PCDFGeHPn\naH5iPk+vfJqI0AgqmytdzqCJDY/ly3O+zNKJSzlcfZhfHv0l18zXPGp/TFgM2XHZHKw8yJPvP8mR\nKtcnafHEaMjDT9ZaVwP0/DtQEWsN/EMpdVgptX2wDSqltiulDimlDtXVyQ1TQoxW3t45mhqTylMr\nnmLK+CmUNZVh164VQTQajNw66Vbun3k/TR1NvHj4RU7WnXS7/dCdEZQRm0F4SDg/3v9jXjj0gtPZ\nvXxhROThK6V2K6VOOPm5y439LNNazwNuB76plFox0Ipa65e01gu01gsSE2++U08IMTr4osc6Lmwc\n31ryLVZnraasscytG62mTZjG1+Z/jaSoJF4/9TrvnHvHpXIOzkSHRpMdn83h6sM88f4THKw86PPe\n/nCklQ85uKW1XjvQa0qpK0qpVK11tVIqFagdYBtVPf/WKqXeAgqBfR62WQgxCsRFmmhovzlAx7lZ\nzNBkNPHluV8mJSaFPx7/I8lRyUSFRrn03tjwWB6a8xDvl77PPyv+SUVzBdumb2N8xHi32gA9vf1x\nGbR1tfHTAz+lML2QL83+EnHhcW5vyxnHtQ1/ppV7djXjhp3Al4Fne/79a/8VlFJRgEFr3dLz+Dbg\nf3q53wFJHr4QI8NAHWBPOsZKKdZPWU9KVAr/dfC/6LR1uhy0jQYjt02+jazYLHac3cFLh1/iztw7\nmZ443f2GAFGhUeSYciipKeFk3UkemvMQhemFPsnk2VyQ7td45e0Y/rPArUqpc8CtPc9RSqUppd7p\nWScZ+EgpVQIcAN7WWr/r5X6dGo7yokII1zSZnQ+/DLTcFXNT5/LUyqcwKAPVLdVuvTc3IZevzf8a\nCZEJvHbqNf5+/u8eD/EopUgfl060KZr/Ovhf/OzAz2gw35SRPuKMqRuvfJXWJITwnj//HhvMDfz0\nwE8pbSglMzbTrd61zW5jd+lu9lfsJy0mjW3524iPiPe4LVprqlqqMBqMPDT3IRalLwroXbpBM8Wh\nVMsUYuTw50XI+Ih4vr3s2xSmF1LaWOpWT91oMLJu8jrum3Ef18zXePHwi5y+6vnEKI7e/riwcfz8\nwM95vvh5j1NBZcYrN0i1TCFGDn8XAwsPCedrC77G1vytLtfW7y0vIY/t87YzIXICfzn5F949/y42\nu23oNw4g0hRJTnwOJ2tP8sT7T/DxpY/dyuQZjiHpMTWk46tbk4UQo0txRTEvHn6R2LBYYsNj3Xqv\nzW7jvYvvUVxZTFpMGvdMv8frzJt2Szs1rTUUpBTw4JwHXartMxrutB1RhqO8qBBi5HHU1rfYLVxp\nveLWe40GI+unrOfe6fdS317Pi4df5MzVM161J9IUSU5cDqevnubJ95/kw/IPh7xxbDiGpMdUD18I\nEdzq2ur4cfGPqWmpIWNchtsXTxvMDbx26jWqW6tZnL6YtZPWulyqeSBmi5nq1mpmJ8/mobkPkRDp\nvDaQ9PCFEMINiVGJfO+W7zEreRaljaVuj8nHR8TzLwX/wsK0heyv3M+vjv3K5SkYBxJhiiAnLodz\n9ed48v0n2Vu+12lvfzjutJUevhBizPG0tn5vp+pOsfPsTpRSbM7dTG6C94HXbDFT1VLFrORZfGXu\nV0iM6ls+ZkTXw/c3CfhCCE85auv/6tivSIhI8Ggyk2vma7x+6nWqW6tZkrGEz+V8zushHq01Na01\n2LWdL87+IiuyVmBQvhtskYAvhAhajtr6RmW8qUftCqvdyq4LuzhUdYiMmAy2Td/mdiaQMx3WDqpa\nqpiROIOH5j5EcnSy19sECfhCiAAZKbWtqluq+c/9/0mjuZG0mDSP7oQ9WXuSnZ/txKiMbM7bzLQJ\n07xul6O3b9M27p95P6uyV3l9BiEBXwgx7EbafTHNnc38/ODPOX31NFmxWR4No9S31/P66depaa1h\n6cSlrMle43WAhhu9/fyEfL5S8BVSolM83pZk6Qghhp23E6D4mqO2/qqsVZQ2lLpVW99hQuQEvlrw\nVeanzueTy5/wm5Lf+GRClPCQcHLicihvKuf7Rd9n94XdfpldSwK+EMIvRmJtq1BjKA/NfYgvzv4i\nl5sv09bV5vY2Qgwh3DHtDu7Ov5srbVd48fCLnKs/53XblFKkRKeQGJnI747/zuN6PIORgC+E8IuR\nWtvKUVv/W4u/RUNHg8eBdWbSTLbP2864sHH88cQf2X1xt8vTMA7kTHUzv99fxZuHqrn9+X1SPE0I\nMToMx41E3pibOpcfrPyBR7X1HRxDPPNS5/Hx5Y/5zbHf0NzZ7NG2zlQ3849TV2jpsKCB6iYzj79W\n4tOgLwFfCOEXo6G2VWZsJk+vfJq0mDTKG8s9Gjc3GU1smraJLXlbqG6t5sXDL3L+2nm3t7Pnszrs\n/fZvsWue2enZBOzOeDvFoRBCDMjfU/b5QnxEPN9Z/h1+dfRX/LPin2TGZhJicD80zk6eTVpMGq+d\neo0/HP8DyzOXszp7tcvZQB0W52UgGr2YIaw/6eELIYKeo7b+lrwtHtXWd0iITODhgocpSCngo0sf\n8duS39LS2eLj1npOAr4QQgAGZWBL/ha+sfAb1LXXeZxuaTKauDP3TjbnbaaqpYoXDr/AhWsXhnxf\nuMl5Pn98pMmjdjgjAV8IIXpZMnEJT97ypEe19XubkzyHR+Y9QpQpit8f/z0flH4waBbPqmmJGAx9\n7wA2GRVPb5rhcRv68yrgK6XuUUqdVErZlVJO7+zqWW+9UuqsUuq8Uuq73uxTCCH8bcr4KTy98mnG\nR47nctNlj2+CSoxK5JF5jzA3ZS77Lu3jdyW/G3CIJy91HLflJxMTbkIBqbERPLdtjk+vgXhVWkEp\nlQ/YgReBf9Na31QHQSllBD4DbgUqgIPA/VrrU0NtX0orCCECqa2rjZcOv8TRmqNkxWZ5VUbhWM0x\n3j73NmHGMLbmb2VS/KQB173cfJn/WPsfA06WMhi/lVbQWp/WWg91n3QhcF5rfVFr3QW8CtzlzX6F\nEGI4RIVG8a+L/pX1U9ZT1lhGp7XT423NTZnLI/MeIdIUye8+/R17yvbcNMRzprqZX35Uyl8OXGbD\n8x+Oyhuv0oHLvZ5X9CxzSim1XSl1SCl1qK6uzu+NE0KIwYQYQrh/5v18dd5XqWqt8irrJikqiYfn\nPcyc5DnsLd/L7z/9Pa1drUB3sN99prbPjVdPvHl8eG+8UkrtVkqdcPLjai/dWR3SAceRtNYvaa0X\naK0XJCa6X7taCCF8TSnFquxVfHfZd2m1tFLX5nlnNNQYyua8zdyZeyeXmy/zwqEXKG0o5eML9Vht\nfXv8vi42N2TA11qv1VrPdPLzVxf3UQFM7PU8A6jypLFCCBFI+Yn5PL3yaSJDI6lsrvSqomVBSgGP\nzHuECFMEv/v0d9R1HUVzcxaPL4vNDceQzkFgqlIqRykVCnwe2DkM+xVCCJ9Li0njqRVPMXn8ZMqa\nyrwqmJYUlcQj8x5hVtIsOo3HaTfuwU7fAO/LYnPepmVuUUpVAEuAt5VSu3qWpyml3gHQWluBx4Bd\nwGngL1pr3xWHEEKIYeaL2voOjiGehcm3YlVXaQ15F5vqruDp62JzMuOVEEJ4SGvNrgu7+OPxP5Ic\nlUxUaJRX29tfdoH3y3di1Wbmhf6ap25f6XYe/mBpmVI8TQghPOSorZ8SlcLPDv6MTlsn4yPGe7y9\nxdmTKcj4BifrTvLKXbd5lIc/GCmtIIQQXupdW7+mpcarbYWFhJEY5Z8MRQn4QgjhA47a+qkxqVxq\nvOSXOWm9JQFfCCF8xFFbf2H6QkobS7HarYFuUh8S8IUQwod8VVvfHyTgCyGEj/WurV/bXutxbX1f\nk4AvhBB+smTiEr53y/fosnd5VVvfVyTgCyGEH00ZP4VnVj7D+Ajvauv7ggR8IYTws8SoRJ685Ulm\nJs2krLEMm935hOX+JgFfCCGGgaO2/rop6yhvKveqtr6nJOALIcQwcdTW/5eCf/G6tr4nJOALIcQw\nctTW//bSb3tdW99dEvCFECIAZiTN6K6tb/K+tr6rJOALIUSApMWk8dTK7tr65U3lXtXWd4UEfCGE\nCCBHbf2VWSspayjzqrb+UKQ8shBCBFioMZSH5j5Eakwqfzz+R78N70jAF0KIEcBRWz85Kpnffvpb\njMro831IwBdCiBGkILWAvIQ8Iky+m8vWQQK+EEKMEDuOVvLcrrNUNZpJi4vg8XW5bk9xOBgJ+EII\nMQLsOFrJE28ex2zpLrtQ2WjmiTePA/gs6HuVpaOUukcpdVIpZVdKOZ00t2e9MqXUcaXUMaWUzEou\nhBD9PLfr7PVg72C22Hhu11mf7cPbHv4JYCvwogvrrtZaX/Vyf0IIMSZVNjqfKGWg5Z7wKuBrrU9D\n99VlIYQQnjMqhc1JOqbRh/F1uG680sA/lFKHlVLbB1tRKbVdKXVIKXWorm74akwIIUQgOQv2gy33\nxJA9fKXUbiDFyUvf01r/1cX9LNNaVymlkoD3lFJntNb7nK2otX4JeAlgwYIFI2/adyGE8IP0uAin\nwzfpcb5Lzxwy4Gut13q7E611Vc+/tUqpt4BCwGnAF0KIYPT4utw+WToAESYjj6/L9dk+/D6ko5SK\nUkrFOB4Dt9F9sVcIIUSPzQXp/GjrLNLjIlB09+x/tHWWT/PwvU3L3KKUqgCWAG8rpXb1LE9TSr3T\ns1oy8JFSqgQ4ALyttX7Xm/0KIYRwn7dZOm8BbzlZXgVs6Hl8EZjjzX6EEGKs23G0ksdfK8Fi7750\nWdlo5vHXSoARcuOVEEII33hm58nrwd7BYtc8s/Okz/YhAV8IIUaARrPzOvgDLfeEBHwhhAgSEvCF\nEGIEiI80ubXcExLwhRBiBHh60wxMxr5lFExGxdObZvhsH1IeWQghRgBHJo7UwxdCiCCwuSDdpwG+\nPxnSEUKIICEBXwghgoQEfCGECBIS8IUQIkhIwBdCiCChtA9nU/E1pVQdUO7h2xOAkTiHrrTLPdIu\n90i73DMW25WltU509sKIDvjeUEod0lovCHQ7+pN2uUfa5R5pl3uCrV0ypCOEEEFCAr4QQgSJsRzw\nXwp0AwYg7XKPtMs90i73BFW7xuwYvhBCiL7Gcg9fCCFELxLwhRAiSIzqgK+UWq+UOquUOq+U+q6T\n15VS6ic9r3+qlJo3Qtq1SinVpJQ61vPzg2Fq1ytKqVql1IkBXg/U8RqqXYE6XhOVUh8opU4rpU4q\npf67k3WG/Zi52K5hP2ZKqXCl1AGlVElPu/7dyTqBOF6utCsgn7GefRuVUkeVUn9z8ppvj5fWelT+\nAEbgAjAJCAVKgOn91tkA/B1QwGKgeIS0axXwtwAcsxXAPODEAK8P+/FysV2BOl6pwLyexzHAZyPk\nM+ZKu4b9mPUcg+iexyagGFg8Ao6XK+0KyGesZ9/fAv7obP++Pl6juYdfCJzXWl/UWncBrwJ39Vvn\nLuC3utt+IE4plToC2hUQWut9wLVBVgnE8XKlXQGhta7WWh/pedwCnAb6Fysf9mPmYruGXc8xaO15\naur56Z8VEojj5Uq7AkIplQFsBH4xwCo+PV6jOeCnA5d7Pa/g5g+9K+sEol0AS3pOMf+ulPLdHGbe\nCcTxclVAj5dSKhsooLt32FtAj9kg7YIAHLOe4YljQC3wntZ6RBwvF9oFgfmM/Rj4NmAf4HWfHq/R\nHPCVk2X9v7VdWcfXXNnnEbrrXcwBfgrs8HObXBWI4+WKgB4vpVQ08AbwP7TWzf1fdvKWYTlmQ7Qr\nIMdMa23TWs8FMoBCpdTMfqsE5Hi50K5hP15KqTuAWq314cFWc7LM4+M1mgN+BTCx1/MMoMqDdYa9\nXVrrZscpptb6HcCklErwc7tcEYjjNaRAHi+llInuoPoHrfWbTlYJyDEbql2B/oxprRuBPcD6fi8F\n9DM2ULsCdLyWAXcqpcroHvpdo5T6fb91fHq8RnPAPwhMVUrlKKVCgc8DO/utsxN4sOdK92KgSWtd\nHeh2KaVSlFKq53Eh3f8P9X5ulysCcbyGFKjj1bPPXwKntdb/OcBqw37MXGlXII6ZUipRKRXX8zgC\nWAuc6bdaII7XkO0KxPHSWj+htc7QWmfTHSeKtNZf6reaT4/XqJ3EXGttVUo9BuyiOzPmFa31SaXU\noz2vvwC8Q/dV7vNAO/CVEdKubcDXlVJWwAx8XvdckvcnpdSf6M5GSFBKVQBP030BK2DHy8V2BeR4\n0d0DewA43jP+C/AkkNmrbYE4Zq60KxDHLBX4jVLKSHfA/IvW+m+B/pt0sV2B+ozdxJ/HS0orCCFE\nkBjNQzpCCCHcIAFfCCGChAR8IYQIEhLwhRAiSEjAF0KIICEBXwghgoQEfCGECBL/Px4t6r0Hy2Aw\nAAAAAElFTkSuQmCC\n", "text/plain": [ - "
" + "\u003cFigure size 600x400 with 1 Axes\u003e" ] }, "metadata": {}, "output_type": "display_data" } - ], - "source": [ - - "jk = Lowess('x', 'y') | Jackknife('cookie', confidence=0.9) | compute_on(df_sin)\n", - "point_est = jk[('y', 'Value')]\n", - "ci_lower = jk[('y', 'Jackknife CI-lower')]\n", - "ci_upper = jk[('y', 'Jackknife CI-upper')]\n", - "\n", - "plt.scatter(df_sin.x, df_sin.y)\n", - "plt.plot(x, point_est, c='g')\n", - "plt.fill_between(\n", - " x, ci_lower,\n", - " ci_upper,\n", - " color='g',\n", - " alpha=0.5)\n", - "plt.show()" ] } ], @@ -4525,6 +4582,6 @@ "name": "python3" } }, - "nbformat": 4, - "nbformat_minor": 0 + "nbformat_minor": 0, + "nbformat": 4 } diff --git a/metrics.py b/metrics.py index d2c0313..f423f53 100644 --- a/metrics.py +++ b/metrics.py @@ -12,29 +12,35 @@ # See the License for the specific language governing permissions and # limitations under the License. """Base classes for Meterstick.""" - from __future__ import absolute_import +from __future__ import annotations from __future__ import division from __future__ import print_function +from collections.abc import Callable, Iterable, Iterator, Sequence import copy import datetime import itertools -from typing import Any, Iterable, List, Optional, Sequence, Text, Union +from typing import Any, TYPE_CHECKING from meterstick import sql from meterstick import utils import numpy as np import pandas as pd +if TYPE_CHECKING: + import apache_beam + -def compute_on(df, - split_by=None, - melted=False, - return_dataframe=True, - cache_key=None, - cache=None, - **kwargs): +def compute_on( + df: pd.DataFrame, + split_by: utils.StrOrList | None = None, + melted: bool = False, + return_dataframe: bool = True, + cache_key: Any = None, + cache: dict | None = None, + **kwargs, +) -> Callable[[Metric], utils.ReturnType]: # pylint: disable=g-long-lambda return lambda x: x.compute_on(df, split_by, melted, return_dataframe, cache_key, cache, **kwargs) @@ -43,15 +49,16 @@ def compute_on(df, # pylint: disable=g-long-lambda def compute_on_sql( - table, - split_by=None, - execute=None, - melted=False, - mode=None, - cache_key=None, - cache=None, - return_dataframe=True, - **kwargs): + table: utils.TableType, + split_by: utils.StrOrList | None = None, + execute: Callable[[str], pd.DataFrame] | None = None, + melted: bool = False, + mode: str | None = None, + cache_key: Any = None, + cache: dict | None = None, + return_dataframe: bool = True, + **kwargs, +) -> Callable[[Metric], utils.ReturnType]: """A wrapper that metric | compute_on_sql() === metric.compute_on_sql().""" return lambda m: m.compute_on_sql( table, @@ -66,20 +73,20 @@ def compute_on_sql( def compute_on_beam( - table, - split_by=None, - execute=None, - melted=False, - mode=None, - cache_key=None, - cache=None, - sql_transform_kwargs=None, - dialect=None, + pcol: apache_beam.pvalue.PCollection, + split_by: utils.StrOrList | None = None, + execute: Any = None, + melted: bool = False, + mode: str | None = None, + cache_key: Any = None, + cache: dict | None = None, + sql_transform_kwargs: dict | None = None, + dialect: str | None = None, **kwargs, -): +) -> Callable[[Metric], pd.DataFrame]: """A wrapper for metric.compute_on_beam().""" return lambda m: m.compute_on_beam( - table, + pcol, split_by, execute, melted, @@ -95,7 +102,9 @@ def compute_on_beam( # pylint: enable=g-long-lambda -def to_sql(table, split_by=None): +def to_sql( + table: utils.TableType, split_by: utils.StrOrList | None = None +) -> Callable[[Metric], sql.Sql]: return lambda metric: metric.to_sql(table, split_by) @@ -229,15 +238,16 @@ class Metric(object): cache_key: The key currently being used in computation. """ - def __init__(self, - name: Text, - children: Optional[Union['Metric', Sequence[Union['Metric', int, - float]]]] = (), - where: Optional[Union[Text, Sequence[Text]]] = None, - name_tmpl=None, - extra_split_by: Optional[Union[Text, Iterable[Text]]] = None, - extra_index: Optional[Union[Text, Iterable[Text]]] = None, - additional_fingerprint_attrs: Optional[List[str]] = None): + def __init__( + self, + name: str, + children: Metric | Sequence[Metric | utils.Number] | None = (), + where: str | Sequence[str] | None = None, + name_tmpl: str | None = None, + extra_split_by: str | Iterable[str] | None = None, + extra_index: str | Iterable[str] | None = None, + additional_fingerprint_attrs: list[str] | None = None, + ): self.name = name self.cache = {} self.cache_key = None @@ -259,7 +269,7 @@ def __init__(self, self.cache_key = None @property - def where(self): + def where(self) -> str | None: if isinstance(self.where_, (list, tuple)): where_ = self.where_ if len(where_) > 1: @@ -268,7 +278,7 @@ def where(self): return self.where_ @where.setter - def where(self, where): + def where(self, where: str | Sequence[str] | None): if where is None: self.where_ = None elif isinstance(where, str): @@ -276,7 +286,7 @@ def where(self, where): else: self.where_ = tuple(where) - def add_where(self, where): + def add_where(self, where: str | Sequence[str] | None) -> 'Metric': if where is None: return self where = [where] if isinstance(where, str) else list(where) or [] @@ -288,17 +298,17 @@ def add_where(self, where): def _compute_with_caching_and_postprocessing( self, - compute_fn, - df, - split_by, - melted, - return_dataframe, - apply_name_tmpl, - cache_key, - cache, - *args, - **kwargs, - ): + compute_fn: Callable[..., Any], + df: Any, + split_by: str | Sequence[str] | None, + melted: bool, + return_dataframe: bool, + apply_name_tmpl: bool | None, + cache_key: Any, + cache: dict | None, + *args: Any, + **kwargs: Any, + ) -> Any: """Wraps computation logic with caching and common postprocessing. This function does: @@ -349,37 +359,48 @@ def _compute_with_caching_and_postprocessing( if cache_key is None: # Only root metric can have None as cache_key self.clean_up_cache() - def wrap_cache_key(self, key, split_by=None, where=None, slice_val=None): + def wrap_cache_key( + self, + key: Any, + split_by: Sequence[str] | None = None, + where: str | Sequence[str] | None = None, + slice_val: dict | None = None, + ) -> utils.CacheKey: if key and not isinstance(key, utils.CacheKey) and self.cache_key: key = self.cache_key.replace_key(key) key = key or self.cache_key return utils.CacheKey(self, key, where or self.where_, split_by, slice_val) - def save_to_cache(self, key, val, split_by=None): + def save_to_cache( + self, + key: utils.CacheKey | Any, + val: Any, + split_by: Sequence[str] | None = None, + ) -> None: if not isinstance(key, utils.CacheKey): key = self.wrap_cache_key(key, split_by) val = val.copy() if isinstance(val, (pd.Series, pd.DataFrame)) else val self.cache[key] = val - def get_cached(self, key): + def get_cached(self, key: utils.CacheKey | Any) -> Any: key = key if isinstance(key, utils.CacheKey) else self.wrap_cache_key(key) return self.cache[key] - def in_cache(self, key): + def in_cache(self, key: utils.CacheKey | Any) -> bool: key = key if isinstance(key, utils.CacheKey) else self.wrap_cache_key(key) return key in self.cache - def find_all_in_cache_by_metric_type(self, metric): + def find_all_in_cache_by_metric_type(self, metric: type) -> dict: """Retrieves results from a certain type of metric from cache.""" return {k: v for k, v in self.cache.items() if k.metric.__class__ == metric} def manipulate( self, - res, + res: Any, melted: bool = False, return_dataframe: bool = True, - apply_name_tmpl=None, - ): + apply_name_tmpl: bool | None = None, + ) -> Any: """Common adhoc data manipulation. It does @@ -414,18 +435,25 @@ def manipulate( res = utils.apply_name_tmpl(self.name_tmpl, res, melted) return utils.remove_empty_level(res) - def to_dataframe(self, res): + def to_dataframe(self, res: Any) -> pd.DataFrame: if isinstance(res, pd.DataFrame): return res elif isinstance(res, pd.Series): return pd.DataFrame(res) return pd.DataFrame({self.name: [res]}) - def final_compute(self, res, melted, return_dataframe, split_by, df): + def final_compute( + self, + res: Any, + melted: bool, + return_dataframe: bool, + split_by: Sequence[str] | None, + df: Any, + ) -> Any: del melted, return_dataframe, split_by, df # Useful in derived classes. return res - def clean_up_cache(self): + def clean_up_cache(self) -> None: """Flushes the cache when a Metric tree has been computed. A Metric and all the descendants form a tree. When a computation is started @@ -439,17 +467,18 @@ def clean_up_cache(self): """ self.cache.clear() for m in self.traverse(): - m.cache_key = None + if isinstance(m, Metric): + m.cache_key = None def compute_on( self, df: pd.DataFrame, - split_by: Optional[Union[Text, List[Text]]] = None, + split_by: utils.StrOrList | None = None, melted: bool = False, return_dataframe: bool = True, cache_key: Any = None, - cache=None, - ): + cache: dict | None = None, + ) -> utils.ReturnType: """Key API of Metric. This is what you should call to use Metric. As caching is the shared part of @@ -483,22 +512,30 @@ def compute_on( cache, ) - def compute_through(self, df, split_by: Optional[List[Text]] = None): + def compute_through( + self, df: pd.DataFrame, split_by: list[str] | None = None + ) -> Any: """Precomputes df -> split df and apply compute() -> postcompute.""" df = df.query(self.where) if df is not None and self.where else df res = self.precompute(df, split_by) res = self.compute_slices(res, split_by) return self.postcompute(res, split_by) - def precompute(self, df, split_by): + def precompute( + self, df: pd.DataFrame, split_by: list[str] | None + ) -> utils.ReturnType: del split_by # Useful in derived classes. return df - def postcompute(self, df, split_by): + def postcompute( + self, df: utils.ReturnType, split_by: list[str] | None + ) -> utils.ReturnType: del split_by # Useful in derived classes. return df - def compute_slices(self, df, split_by: Optional[List[Text]] = None): + def compute_slices( + self, df: utils.ReturnType, split_by: list[str] | None = None + ) -> Any: """Applies compute() to all slices. Each slice needs a unique cache_key.""" if self.children: try: @@ -551,11 +588,18 @@ def compute_slices(self, df, split_by: Optional[List[Text]] = None): return self.compute_with_split_by(df) def compute_children( - self, df, split_by, melted=False, return_dataframe=True, cache_key=None - ): + self, + df: pd.DataFrame, + split_by: list[str] | None, + melted: bool = False, + return_dataframe: bool = True, + cache_key: Any = None, + ) -> Any: raise NotImplementedError - def compute_on_children(self, children, split_by): + def compute_on_children( + self, children: Any, split_by: list[str] | None + ) -> Any: """Computes the return using the result returned by children Metrics. Args: @@ -577,34 +621,41 @@ def compute_on_children(self, children, split_by): return pd.concat(result, keys=slices, names=split_by, sort=False) @staticmethod - def split_data(df, split_by=None): + def split_data( + df: utils.ReturnType, split_by: list[str] | None = None + ) -> Iterable[tuple[utils.ReturnType, Any]]: if not split_by: yield df, None else: + assert isinstance(df, (pd.DataFrame, pd.Series)) for k, idx in df.groupby(split_by, observed=True).indices.items(): # Use iloc rather than loc because indexes can have duplicates. yield df.iloc[idx], k def compute_with_split_by( - self, df, split_by: Optional[List[Text]] = None, slice_value=None - ): + self, + df: utils.ReturnType, + split_by: list[str] | None = None, + slice_value: Any = None, + ) -> utils.ReturnType: del split_by, slice_value # In case users need them in derived classes. return self.compute(df) - def compute(self, df): + def compute(self, df: utils.ReturnType) -> utils.ReturnType: raise NotImplementedError + return df def compute_on_sql( self, - table, - split_by=None, - execute=None, - melted=False, - mode=None, - cache_key=None, - cache=None, - return_dataframe=True, - ): + table: utils.TableType, + split_by: utils.StrOrList | None = None, + execute: Callable[[str], pd.DataFrame] | None = None, + melted: bool = False, + mode: str | None = None, + cache_key: Any = None, + cache: dict | None = None, + return_dataframe: bool = True, + ) -> utils.ReturnType: """Computes self in pure SQL or a mixed of SQL and Python. Args: @@ -648,7 +699,13 @@ def compute_on_sql( mode, ) - def compute_through_sql(self, table, split_by, execute, mode): + def compute_through_sql( + self, + table: utils.TableType, + split_by: list[str] | None, + execute: Callable[[str], pd.DataFrame], + mode: str | None, + ) -> Any: """Delegeates the computation to different modes.""" if mode not in (None, 'mixed', 'magic'): raise ValueError('Mode %s is not supported!' % mode) @@ -677,10 +734,10 @@ def compute_through_sql(self, table, split_by, execute, mode): else: raise - def to_series_or_number_if_not_operation(self, df): + def to_series_or_number_if_not_operation(self, df: Any) -> utils.ReturnType: return self.to_series_or_number(df) if not self.is_operation else df - def to_series_or_number(self, df): + def to_series_or_number(self, df: Any) -> utils.ReturnType: if not isinstance(df, pd.DataFrame): return df df = df.squeeze(axis=1) # squeeze to a Series if possible @@ -692,7 +749,12 @@ def to_series_or_number(self, df): df = df.squeeze() return df - def compute_on_sql_sql_mode(self, table, split_by=None, execute=None): + def compute_on_sql_sql_mode( + self, + table: utils.TableType, + split_by: list[str] | None = None, + execute: Callable[[str], pd.DataFrame] | None = None, + ) -> pd.DataFrame: """Executes the query from to_sql() and process the result.""" query = self.to_sql(table, split_by) res = execute(str(query)) @@ -709,9 +771,9 @@ def compute_on_sql_sql_mode(self, table, split_by=None, execute=None): def to_sql( self, - table, - split_by: Optional[Union[Text, List[Text]]] = None, - ): + table: utils.TableType, + split_by: utils.StrOrList | None = None, + ) -> sql.Sql: """Generates SQL query for the metric. Args: @@ -751,10 +813,15 @@ def to_sql( return query return query - def get_sql_and_with_clause(self, table: sql.Datasource, - split_by: sql.Columns, global_filter: sql.Filters, - indexes: sql.Columns, local_filter: sql.Filters, - with_data: sql.Datasources): + def get_sql_and_with_clause( + self, + table: sql.Datasource, + split_by: sql.Columns | None, + global_filter: sql.Filters | None, + indexes: sql.Columns | None, + local_filter: sql.Filters | None, + with_data: sql.Datasources | None, + ) -> tuple[sql.Sql, sql.Datasources]: """Gets the SQL query for metric and its WITH clause separately. Args: @@ -788,14 +855,26 @@ def get_sql_and_with_clause(self, table: sql.Datasource, raise NotImplementedError('SQL generator is not implemented for %s.' % type(self)) - def compute_on_sql_mixed_mode(self, table, split_by, execute, mode=None): + def compute_on_sql_mixed_mode( + self, + table: utils.TableType, + split_by: list[str] | None, + execute: Callable[[str], pd.DataFrame], + mode: str | None = None, + ) -> Any: """Computes the child in SQL and the rest in Python.""" children = self.compute_children_sql(table, split_by, execute, mode) return self.compute_on_children(children, split_by) def compute_children_sql( - self, table, split_by, execute, mode, *args, **kwargs - ): + self, + table: utils.TableType, + split_by: list[str] | None, + execute: Callable[[str], pd.DataFrame], + mode: str | None, + *args: Any, + **kwargs: Any, + ) -> Any: """The return should be similar to compute_children().""" del args, kwargs # unused children = [] @@ -812,17 +891,17 @@ def compute_children_sql( def compute_on_beam( self, - pcol, - split_by=None, - execute=None, - melted=False, - mode=None, - cache_key=None, - cache=None, - sql_transform_kwargs=None, - dialect=None, - **kwargs, - ): + pcol: apache_beam.pvalue.PCollection, + split_by: list[str] | None = None, + execute: Callable[..., pd.DataFrame] | None = None, + melted: bool = False, + mode: str | None = None, + cache_key: Any = None, + cache: dict | None = None, + sql_transform_kwargs: dict | None = None, + dialect: str | None = None, + **kwargs: Any, + ) -> utils.ReturnType: """Computes on an Apache Beam PCollection input. Args: @@ -887,7 +966,9 @@ def e(q): finally: sql.set_dialect(current_dialect) - def compute_equivalent(self, df, split_by=None): + def compute_equivalent( + self, df: pd.DataFrame, split_by: list[str] | None = None + ) -> Any: equiv, df = utils.get_fully_expanded_equivalent_metric_tree(self, df) return self.compute_util_metric_on( equiv, df, split_by, return_dataframe=False @@ -895,13 +976,13 @@ def compute_equivalent(self, df, split_by=None): def compute_util_metric_on( self, - metric, - df, - split_by, - melted=False, - return_dataframe=True, - cache_key=None, - ): + metric: 'Metric', + df: pd.DataFrame, + split_by: list[str] | None, + melted: bool = False, + return_dataframe: bool = True, + cache_key: Any = None, + ) -> Any: """Computes a util metric with caching and filtering handled correctly.""" cache_key = self.wrap_cache_key(cache_key, split_by) return metric.compute_on( @@ -910,15 +991,15 @@ def compute_util_metric_on( def compute_util_metric_on_sql( self, - metric, - table, - split_by=None, - execute=None, - melted=False, - mode=None, - cache_key=None, - return_dataframe=True, - ): + metric: 'Metric', + table: utils.TableType, + split_by: list[str] | None = None, + execute: Callable[[str], pd.DataFrame] | None = None, + melted: bool = False, + mode: str | None = None, + cache_key: Any = None, + return_dataframe: bool = True, + ) -> Any: """Computes a util metric with caching and filtering handled correctly.""" cache_key = self.wrap_cache_key(cache_key, split_by) return metric.compute_on_sql( @@ -932,7 +1013,7 @@ def compute_util_metric_on_sql( return_dataframe, ) - def get_equivalent(self, *auxiliary_cols): + def get_equivalent(self, *auxiliary_cols: Any) -> Metric | None: """Gets a Metric that is equivalent to self.""" res = self.get_equivalent_without_filter(*auxiliary_cols) # pylint: disable=assignment-from-none if res: @@ -940,12 +1021,14 @@ def get_equivalent(self, *auxiliary_cols): res.add_where(self.where_) return res - def get_equivalent_without_filter(self, *auxiliary_cols): + def get_equivalent_without_filter( + self, *auxiliary_cols: Any + ) -> Metric | None: """Gets a Metric that is equivalent to self but ignoring the filter.""" del auxiliary_cols # might be used in derived classes return - def get_auxiliary_cols(self): + def get_auxiliary_cols(self) -> tuple: """Returns the auxiliary columns required by the equivalent Metric. See utils.add_auxiliary_cols() for the format of the return. @@ -953,10 +1036,12 @@ def get_auxiliary_cols(self): return () @staticmethod - def group(df, split_by=None): + def group(df: pd.DataFrame, split_by: list[str] | None = None) -> Any: return df.groupby(split_by, observed=True) if split_by else df - def visualize_metric_tree(self, rendering_fn, strict=True): + def visualize_metric_tree( + self, rendering_fn: Callable[..., Any], strict: bool = True + ) -> None: """Renders the Metric tree. Args: @@ -967,7 +1052,7 @@ def visualize_metric_tree(self, rendering_fn, strict=True): """ rendering_fn(self.to_dot(strict)) - def to_dot(self, strict=True): + def to_dot(self, strict: bool = True) -> str: """Represents the Metric in DOT language. Args: @@ -995,7 +1080,7 @@ def add_edges(metric): add_edges(self) return dot.to_string() - def get_extra_idx(self, return_superset=False): + def get_extra_idx(self, return_superset: bool = False) -> tuple: """Collects the extra indexes added by self and its descendants. Args: @@ -1021,7 +1106,9 @@ def get_extra_idx(self, return_superset=False): extra_idx += list(children_idx[0]) return tuple(extra_idx) - def traverse(self, include_self=True, include_constants=False): + def traverse( + self, include_self: bool = True, include_constants: bool = False + ) -> Iterable[Metric | utils.Number]: ms = [self] if include_self else list(self.children) while ms: m = ms.pop(0) @@ -1031,52 +1118,52 @@ def traverse(self, include_self=True, include_constants=False): elif include_constants: yield m - def __or__(self, fn): + def __or__(self, fn: Callable[[Metric], Any]) -> Any: """Overwrites the '|' operator to enable pipeline chaining.""" return fn(self) - def __add__(self, other): + def __add__(self, other: Metric | utils.Number) -> CompositeMetric: return CompositeMetric(lambda x, y: x + y, '{} + {}', (self, other)) - def __radd__(self, other): + def __radd__(self, other: Metric | utils.Number) -> CompositeMetric: return CompositeMetric(lambda x, y: x + y, '{} + {}', (other, self)) - def __sub__(self, other): + def __sub__(self, other: Metric | utils.Number) -> CompositeMetric: return CompositeMetric(lambda x, y: x - y, '{} - {}', (self, other)) - def __rsub__(self, other): + def __rsub__(self, other: Metric | utils.Number) -> CompositeMetric: return CompositeMetric(lambda x, y: x - y, '{} - {}', (other, self)) - def __mul__(self, other): + def __mul__(self, other: Metric | utils.Number) -> CompositeMetric: return CompositeMetric(lambda x, y: x * y, '{} * {}', (self, other)) - def __rmul__(self, other): + def __rmul__(self, other: Metric | utils.Number) -> CompositeMetric: return CompositeMetric(lambda x, y: x * y, '{} * {}', (other, self)) - def __neg__(self): + def __neg__(self) -> CompositeMetric: return CompositeMetric(lambda x, _: -x, '-{}', (self, -1)) - def __div__(self, other): + def __div__(self, other: Metric | utils.Number) -> CompositeMetric: return CompositeMetric(lambda x, y: x / y, '{} / {}', (self, other)) - def __truediv__(self, other): + def __truediv__(self, other: Metric | utils.Number) -> CompositeMetric: return self.__div__(other) - def __rdiv__(self, other): + def __rdiv__(self, other: Metric | utils.Number) -> CompositeMetric: return CompositeMetric(lambda x, y: x / y, '{} / {}', (other, self)) - def __rtruediv__(self, other): + def __rtruediv__(self, other: Metric | utils.Number) -> CompositeMetric: return self.__rdiv__(other) - def __pow__(self, other): + def __pow__(self, other: Metric | utils.Number) -> CompositeMetric: if isinstance(other, float) and other == 0.5: return CompositeMetric(lambda x, y: x**y, 'sqrt({})', (self, other)) return CompositeMetric(lambda x, y: x**y, '{} ^ {}', (self, other)) - def __rpow__(self, other): + def __rpow__(self, other: Metric | utils.Number) -> CompositeMetric: return CompositeMetric(lambda x, y: x**y, '{} ^ {}', (other, self)) - def __eq__(self, other): + def __eq__(self, other: Any) -> bool: if not isinstance(other, type(self)) or not isinstance(self, type(other)): return False if self.name != other.name: @@ -1092,10 +1179,10 @@ def __eq__(self, other): return False return True - def __hash__(self): + def __hash__(self) -> int: return hash((self.name, self.get_fingerprint())) - def get_fingerprint(self, attr_to_exclude=()): + def get_fingerprint(self, attr_to_exclude: Iterable[str] = ()) -> tuple: """Returns attributes that uniquely identify the Metric. Metrics with the same fingerprint will compute to the same numbers on the @@ -1142,14 +1229,14 @@ def get_fingerprint(self, attr_to_exclude=()): fingerprint[k] = tuple(list(v)) return tuple(sorted(fingerprint.items())) - def __str__(self): + def __str__(self) -> str: where = f' where {self.where}' if self.where else '' return self.name + where - def __repr__(self): + def __repr__(self) -> str: return self.__str__() - def __deepcopy__(self, memo): + def __deepcopy__(self, memo: dict) -> 'Metric': # We don't copy self.cache, for two reasons. # 1. The copied Metric can share the same cache to maximize caching. # 2. When deepcopy a Metric, its cache refers to CacheKey and CacheKey @@ -1183,12 +1270,14 @@ class MetricList(Metric): And all other attributes inherited from Metric. """ - def __init__(self, - children: Sequence[Metric], - where: Optional[Union[Text, Sequence[Text]]] = None, - children_return_dataframe: bool = True, - name_tmpl=None, - rename_columns=None): + def __init__( + self, + children: Sequence[Metric], + where: str | Sequence[str] | None = None, + children_return_dataframe: bool = True, + name_tmpl: str | None = None, + rename_columns: Sequence[str] | None = None, + ) -> None: for m in children: if not isinstance(m, Metric): raise ValueError('%s is not a Metric.' % m) @@ -1203,7 +1292,10 @@ def __init__(self, self.names = [m.name for m in children] self.columns = rename_columns - def compute_slices(self, df, split_by=None): + def compute_slices( + self, df: utils.ReturnType, split_by: list[str] | None = None + ) -> list[Any]: + assert isinstance(df, pd.DataFrame) """Computes all Metrics with caching. We know df is not going to change so we can safely enable caching with an @@ -1232,7 +1324,9 @@ def compute_slices(self, df, split_by=None): print('Warning: %s failed for reason %s.' % (m.name, repr(e))) return res - def compute_on_children(self, children, split_by): + def compute_on_children( + self, children: Any, split_by: list[str] | None + ) -> utils.ReturnType: if isinstance(children, list): children = self.to_dataframe(children) if isinstance(children, pd.DataFrame): @@ -1249,12 +1343,12 @@ def compute_on_children(self, children, split_by): children.columns = self.columns return children - def manipulate( # pytype: disable=annotation-type-mismatch + def manipulate( self, - res: pd.Series, + res: Any, melted: bool = False, return_dataframe: bool = True, - apply_name_tmpl: bool = None, + apply_name_tmpl: bool | None = None, ): """Rename columns if asked in addition to original manipulation.""" res = super(MetricList, self).manipulate( @@ -1275,7 +1369,7 @@ def manipulate( # pytype: disable=annotation-type-mismatch res = utils.melt(res) return res - def to_dataframe(self, res): + def to_dataframe(self, res: Any) -> pd.DataFrame: if not isinstance(res, (list, tuple)): return super(MetricList, self).to_dataframe(res) res_all = pd.concat(res, axis=1, sort=False) @@ -1283,7 +1377,7 @@ def to_dataframe(self, res): res_all.index.names = res[0].index.names return res_all - def unwrap(self) -> List[Metric]: + def unwrap(self) -> list[Metric]: """Unwraps a MetricList and returns a list of all child Metrics. It recursively removes the MetricList wrapper and collects all children @@ -1305,7 +1399,7 @@ def unwrap(self) -> List[Metric]: result.append(child) return result - def rename_columns(self, rename_columns: List[Text]): + def rename_columns(self, rename_columns: list[str]) -> None: """Rename the columns of the MetricList. Useful for instances where you have Metrics in the MetricList that are @@ -1323,15 +1417,15 @@ def rename_columns(self, rename_columns: List[Text]): def compute_on_sql( self, - table, - split_by=None, - execute=None, - melted=False, - mode=None, - cache_key=None, - cache=None, - return_dataframe=True, - ): + table: utils.TableType, + split_by: utils.StrOrList | None = None, + execute: Callable[[str], pd.DataFrame] | None = None, + melted: bool = False, + mode: str | None = None, + cache_key: Any = None, + cache: dict | None = None, + return_dataframe: bool = True, + ) -> Any: if return_dataframe: return super(MetricList, self).compute_on_sql( table, split_by, execute, melted, mode, cache_key, cache @@ -1350,7 +1444,13 @@ def compute_on_sql( mode, ) - def compute_children_sql(self, table, split_by, execute, mode=None): + def compute_children_sql( + self, + table: utils.TableType, + split_by: list[str] | None, + execute: Callable[[str], pd.DataFrame], + mode: str | None = None, + ) -> list[Any]: """The return should be similar to compute_children().""" children = [] for c in self.children: @@ -1370,8 +1470,15 @@ def compute_children_sql(self, table, split_by, execute, mode=None): ) return children - def get_sql_and_with_clause(self, table, split_by, global_filter, indexes, - local_filter, with_data): + def get_sql_and_with_clause( + self, + table: sql.Datasource, + split_by: sql.Columns | None, + global_filter: sql.Filters | None, + indexes: sql.Columns | None, + local_filter: sql.Filters | None, + with_data: sql.Datasources | None, + ) -> tuple[sql.Sql, sql.Datasources]: """Gets the SQL query and WITH clause. The query is constructed by @@ -1396,6 +1503,8 @@ def get_sql_and_with_clause(self, table, split_by, global_filter, indexes, The global with_data which holds all datasources we need in the WITH clause. """ + if with_data is None: + with_data = sql.Datasources() self.get_extra_idx() # Check if indexes are compatible. local_filter = ( sql.Filters(self.where_).add(local_filter).remove(global_filter) @@ -1461,15 +1570,15 @@ def get_sql_and_with_clause(self, table, split_by, global_filter, indexes, return query, with_data - def __iter__(self): + def __iter__(self) -> Iterator[Metric]: for m in self.children: - yield m + yield m # type: ignore - def __len__(self): + def __len__(self) -> int: return len(self.children) - def __getitem__(self, key): - return self.children[key] + def __getitem__(self, key: int | slice) -> Metric | list[Metric]: + return self.children[key] # type: ignore class CompositeMetric(Metric): @@ -1493,18 +1602,20 @@ class CompositeMetric(Metric): And all other attributes inherited from Metric. """ - def __init__(self, - op, - name_tmpl: Text, - children: Sequence[Union[Metric, int, float]], - rename_columns=None, - where: Optional[Text] = None): + def __init__( + self, + op: Callable[[Any, Any], Any], + name_tmpl: str, + children: Sequence[Metric | utils.Number], + rename_columns: Sequence[str] | None = None, + where: str | None = None, + ) -> None: if len(children) != 2: raise ValueError('CompositeMetric must take two children!') - if not isinstance(children[0], (Metric, int, float)): + if not isinstance(children[0], (Metric,) + utils.NumberTypes): raise ValueError('%s is not a Metric or a number!' % utils.get_name(children[0])) - if not isinstance(children[1], (Metric, int, float)): + if not isinstance(children[1], (Metric,) + utils.NumberTypes): raise ValueError('%s is not a Metric or a number!' % utils.get_name(children[1])) if not isinstance(children[0], Metric) and not isinstance( @@ -1521,17 +1632,24 @@ def __init__(self, self.op = op self.columns = rename_columns - def rename_columns(self, rename_columns): + def rename_columns( + self, rename_columns: Sequence[str] | None + ) -> CompositeMetric: self.columns = rename_columns return self - def set_name(self, name): + def set_name(self, name: str) -> CompositeMetric: self.name = name return self def compute_children( - self, df, split_by, melted=False, return_dataframe=True, cache_key=None - ): + self, + df: pd.DataFrame, + split_by: list[str] | None, + melted: bool = False, + return_dataframe: bool = True, + cache_key: Any = None, + ) -> list[Any]: del melted, return_dataframe, cache_key # not used if len(self.children) != 2: raise ValueError('CompositeMetric can only have two children.') @@ -1549,7 +1667,9 @@ def compute_children( children.append(m) return children - def compute_on_children(self, children, split_by): + def compute_on_children( + self, children: list[Any], split_by: list[str] | None + ) -> utils.ReturnType: """Computes the result based on the results from the children. Computations between two DataFrames require columns to match. It makes @@ -1613,8 +1733,15 @@ def compute_on_children(self, children, split_by): res.columns = self.columns return res - def get_sql_and_with_clause(self, table, split_by, global_filter, indexes, - local_filter, with_data): + def get_sql_and_with_clause( + self, + table: sql.Datasource, + split_by: sql.Columns | None, + global_filter: sql.Filters | None, + indexes: sql.Columns | None, + local_filter: sql.Filters | None, + with_data: sql.Datasources | None, + ) -> tuple[sql.Sql, sql.Datasources]: """Gets the SQL query and WITH clause. A CompositeMetric has two children and at least one of them is a Metric. The @@ -1641,6 +1768,8 @@ def get_sql_and_with_clause(self, table, split_by, global_filter, indexes, The global with_data which holds all datasources we need in the WITH clause. """ + if with_data is None: + with_data = sql.Datasources() local_filter = ( sql.Filters(self.where_).add(local_filter).remove(global_filter) ) @@ -1741,7 +1870,7 @@ def get_sql_and_with_clause(self, table, split_by, global_filter, indexes, return query, with_data - def get_fingerprint(self, attr_to_exclude=()): + def get_fingerprint(self, attr_to_exclude: Iterable[str] = ()) -> tuple: # Make Sum(x) / Count(x) indistinguishable to Mean(x) in cache. s = self.children[0] c = self.children[1] @@ -1754,11 +1883,13 @@ def get_fingerprint(self, attr_to_exclude=()): class Ratio(CompositeMetric): """Syntactic sugar for Sum('A') / Sum('B').""" - def __init__(self, - numerator: Text, - denominator: Text, - name: Optional[Text] = None, - where: Optional[Text] = None): + def __init__( + self, + numerator: str, + denominator: str, + name: str | None = None, + where: str | None = None, + ): super(Ratio, self).__init__( lambda x, y: x / y, '{} / {}', (Sum(numerator), Sum(denominator)), @@ -1767,7 +1898,7 @@ def __init__(self, self.denominator = denominator self.name = name or self.name - def get_fingerprint(self, attr_to_exclude=()): + def get_fingerprint(self, attr_to_exclude: Iterable[str] = ()) -> tuple: # Make the fingerprint same as the equivalent CompositeMetric for caching. util = self.children[0] / self.children[1] util.where = self.where_ # pytype: disable=not-writable @@ -1777,12 +1908,14 @@ def get_fingerprint(self, attr_to_exclude=()): class SimpleMetric(Metric): """Base class for common built-in aggregate functions of df.group_by().""" - def __init__(self, - var: Text, - name: Optional[Text] = None, - name_tmpl=None, - where: Optional[Union[Text, Sequence[Text]]] = None, - additional_fingerprint_attrs: Optional[List[str]] = None): + def __init__( + self, + var: str, + name: str | None = None, + name_tmpl=None, + where: str | Sequence[str] | None = None, + additional_fingerprint_attrs: list[str] | None = None, + ): name = name or name_tmpl.format(var) self.var = var additional_fingerprint_attrs = ['var', 'var2'] + ( @@ -1794,8 +1927,15 @@ def __init__(self, name_tmpl, additional_fingerprint_attrs=additional_fingerprint_attrs) - def get_sql_and_with_clause(self, table, split_by, global_filter, indexes, - local_filter, with_data): + def get_sql_and_with_clause( + self, + table: sql.Datasource, + split_by: sql.Columns | None, + global_filter: sql.Filters | None, + indexes: sql.Columns | None, + local_filter: sql.Filters | None, + with_data: sql.Datasources | None, + ) -> tuple[sql.Sql, sql.Datasources]: local_filter = ( sql.Filters(self.where_).add(local_filter).remove(global_filter) ) @@ -1806,7 +1946,7 @@ def get_sql_and_with_clause(self, table, split_by, global_filter, indexes, return equiv.get_sql_and_with_clause(table, split_by, global_filter, indexes, local_filter, with_data) - def get_sql_columns(self, local_filter): + def get_sql_columns(self, local_filter: sql.Filters | None) -> sql.Column: del local_filter # unused raise ValueError('get_sql_columns is not implemented for %s.' % type(self)) @@ -1823,21 +1963,26 @@ class Count(SimpleMetric): And all other attributes inherited from Metric. """ - def __init__(self, - var: Text, - name: Optional[Text] = None, - where: Optional[Union[Text, Sequence[Text]]] = None, - distinct: bool = False): + def __init__( + self, + var: str, + name: str | None = None, + where: str | Sequence[str] | None = None, + distinct: bool = False, + ): self.distinct = distinct if distinct: name = name or 'count(distinct %s)' % str(var) super(Count, self).__init__(var, name, 'count({})', where, ['distinct']) - def compute_slices(self, df, split_by=None): + def compute_slices( + self, df: utils.ReturnType, split_by: list[str] | None = None + ) -> utils.ReturnType: + assert isinstance(df, pd.DataFrame) grped = self.group(df, split_by)[self.var] return grped.nunique() if self.distinct else grped.count() - def get_sql_columns(self, local_filter): + def get_sql_columns(self, local_filter: sql.Filters | None) -> sql.Column: if self.distinct: return sql.Column(self.var, 'COUNT(DISTINCT {})', self.name, local_filter) else: @@ -1855,16 +2000,21 @@ class Sum(SimpleMetric): And all other attributes inherited from SimpleMetric. """ - def __init__(self, - var: Text, - name: Optional[Text] = None, - where: Optional[Union[Text, Sequence[Text]]] = None): + def __init__( + self, + var: str, + name: str | None = None, + where: str | Sequence[str] | None = None, + ): super(Sum, self).__init__(var, name, 'sum({})', where) - def compute_slices(self, df, split_by=None): + def compute_slices( + self, df: utils.ReturnType, split_by: list[str] | None = None + ) -> utils.ReturnType: + assert isinstance(df, pd.DataFrame) return self.group(df, split_by)[self.var].sum() - def get_sql_columns(self, local_filter): + def get_sql_columns(self, local_filter: sql.Filters | None) -> sql.Column: return sql.Column(self.var, 'SUM({})', self.name, local_filter) @@ -1881,18 +2031,23 @@ class Dot(SimpleMetric): And all other attributes inherited from SimpleMetric. """ - def __init__(self, - var1: Text, - var2: Text, - normalize=False, - name: Optional[Text] = None, - where: Optional[Union[Text, Sequence[Text]]] = None): + def __init__( + self, + var1: str, + var2: str, + normalize: bool = False, + name: str | None = None, + where: str | Sequence[str] | None = None, + ) -> None: self.var2 = var2 self.normalize = normalize name_tmpl = ('mean({} * %s)' if normalize else 'sum({} * %s)') % str(var2) super(Dot, self).__init__(var1, name, name_tmpl, where, ['normalize']) - def compute_slices(self, df, split_by=None): + def compute_slices( + self, df: utils.ReturnType, split_by: list[str] | None = None + ) -> utils.ReturnType: + assert isinstance(df, pd.DataFrame) if not split_by: prod = (df[self.var] * df[self.var2]) return prod.mean() if self.normalize else prod.sum() @@ -1902,19 +2057,21 @@ def compute_slices(self, df, split_by=None): fn = lambda df: (df[self.var] * df[self.var2]).sum() return df.groupby(split_by, observed=True).apply(fn) - def get_equivalent_without_filter(self, *auxiliary_cols): + def get_equivalent_without_filter( + self, *auxiliary_cols: Any + ) -> Metric | None: if self.normalize: return Sum(auxiliary_cols[0]) / Count(auxiliary_cols[0]) return Sum(auxiliary_cols[0]) - def get_auxiliary_cols(self): + def get_auxiliary_cols(self) -> tuple: return ((self.var, '*', self.var2),) - def get_sql_columns(self, local_filter): + def get_sql_columns(self, local_filter: sql.Filters | None) -> sql.Column: tmpl = 'AVG({} * {})' if self.normalize else 'SUM({} * {})' return sql.Column((self.var, self.var2), tmpl, self.name, local_filter) - def get_fingerprint(self, attr_to_exclude=()): + def get_fingerprint(self, attr_to_exclude: Iterable[str] = ()) -> tuple: if str(self.var) > str(self.var2): util = copy.deepcopy(self) util.var = self.var2 @@ -1935,21 +2092,26 @@ class Mean(SimpleMetric): And all other attributes inherited from SimpleMetric. """ - def __init__(self, - var: Text, - weight: Optional[Text] = None, - name: Optional[Text] = None, - where: Optional[Union[Text, Sequence[Text]]] = None): + def __init__( + self, + var: str, + weight: str | None = None, + name: str | None = None, + where: str | Sequence[str] | None = None, + ): name_tmpl = '%s-weighted mean({})' % str(weight) if weight else 'mean({})' super(Mean, self).__init__(var, name, name_tmpl, where, ['weight']) self.weight = weight - def compute_slices(self, df, split_by=None): + def compute_slices( + self, df: utils.ReturnType, split_by: list[str] | None = None + ) -> utils.ReturnType: + assert isinstance(df, pd.DataFrame) if self.weight: return self.compute_equivalent(df, split_by) return self.group(df, split_by)[self.var].mean() - def get_sql_columns(self, local_filter): + def get_sql_columns(self, local_filter: sql.Filters | None) -> Any: if not self.weight: return sql.Column(self.var, 'AVG({})', self.name, local_filter) else: @@ -1958,7 +2120,9 @@ def get_sql_columns(self, local_filter): res /= sql.Column(self.weight, 'SUM({})', 'total_weight', local_filter) return res.set_alias(self.name) - def get_equivalent_without_filter(self, *auxiliary_cols): + def get_equivalent_without_filter( + self, *auxiliary_cols: Any + ) -> Metric | None: del auxiliary_cols # unused if not self.weight: return Sum(self.var) / Count(self.var) @@ -1976,16 +2140,21 @@ class Max(SimpleMetric): And all other attributes inherited from SimpleMetric. """ - def __init__(self, - var: Text, - name: Optional[Text] = None, - where: Optional[Union[Text, Sequence[Text]]] = None): + def __init__( + self, + var: str, + name: str | None = None, + where: str | Sequence[str] | None = None, + ): super(Max, self).__init__(var, name, 'max({})', where) - def compute_slices(self, df, split_by=None): + def compute_slices( + self, df: utils.ReturnType, split_by: list[str] | None = None + ) -> utils.ReturnType: + assert isinstance(df, pd.DataFrame) return self.group(df, split_by)[self.var].max() - def get_sql_columns(self, local_filter): + def get_sql_columns(self, local_filter: sql.Filters | None) -> sql.Column: return sql.Column(self.var, 'MAX({})', self.name, local_filter) @@ -2000,16 +2169,21 @@ class Min(SimpleMetric): And all other attributes inherited from SimpleMetric. """ - def __init__(self, - var: Text, - name: Optional[Text] = None, - where: Optional[Union[Text, Sequence[Text]]] = None): + def __init__( + self, + var: str, + name: str | None = None, + where: str | Sequence[str] | None = None, + ): super(Min, self).__init__(var, name, 'min({})', where) - def compute_slices(self, df, split_by=None): + def compute_slices( + self, df: utils.ReturnType, split_by: list[str] | None = None + ) -> utils.ReturnType: + assert isinstance(df, pd.DataFrame) return self.group(df, split_by)[self.var].min() - def get_sql_columns(self, local_filter): + def get_sql_columns(self, local_filter: sql.Filters | None) -> sql.Column: return sql.Column(self.var, 'MIN({})', self.name, local_filter) @@ -2030,15 +2204,15 @@ class Nth(SimpleMetric): def __init__( self, - var: Text, + var: str, n: int, - sort_by: Text, + sort_by: str, ascending: bool = True, dropna: bool = True, - name: Optional[Text] = None, - where: Optional[Union[Text, Sequence[Text]]] = None, - additional_fingerprint_attrs: Optional[List[str]] = None, - ): + name: str | None = None, + where: str | Sequence[str] | None = None, + additional_fingerprint_attrs: list[str] | None = None, + ) -> None: if not isinstance(n, int): raise ValueError('n must be an integer.') if n < 0: @@ -2073,7 +2247,10 @@ def __init__( additional_fingerprint_attrs=additional_fingerprint_attrs ) - def compute_slices(self, df, split_by=None): + def compute_slices( + self, df: utils.ReturnType, split_by: list[str] | None = None + ) -> utils.ReturnType: + assert isinstance(df, pd.DataFrame) if self.dropna: df = df.dropna(subset=[self.var]) df = df.sort_values(self.sort_by, ascending=self.ascending) @@ -2084,8 +2261,15 @@ def compute_slices(self, df, split_by=None): return np.nan return df[self.var].values[self.n] - def get_sql_and_with_clause(self, table, split_by, global_filter, indexes, - local_filter, with_data): + def get_sql_and_with_clause( + self, + table: sql.Datasource, + split_by: sql.Columns | None, + global_filter: sql.Filters | None, + indexes: sql.Columns | None, + local_filter: sql.Filters | None, + with_data: sql.Datasources | None, + ) -> tuple[sql.Sql, sql.Datasources]: """Gets the SQL query and WITH clause. If there is no local filter, the metric can be expressed in one line like @@ -2110,6 +2294,8 @@ def get_sql_and_with_clause(self, table, split_by, global_filter, indexes, The global with_data which holds all datasources we need in the WITH clause. """ + if with_data is None: + with_data = sql.Datasources() local_filter = ( sql.Filters(self.where_).add(local_filter).remove(global_filter) ) @@ -2143,7 +2329,7 @@ def get_sql_and_with_clause(self, table, split_by, global_filter, indexes, filtered_table_alias, split_by.aliases, None, indexes, None, with_data ) - def get_sql_columns(self, local_filter): + def get_sql_columns(self, local_filter: sql.Filters | None) -> sql.Column: if local_filter: raise ValueError( 'This case should be handled by get_sql_and_with_clause() already.' @@ -2176,13 +2362,15 @@ class Quantile(SimpleMetric): And all other attributes inherited from SimpleMetric. """ - def __init__(self, - var: Text, - quantile: Union[float, int, Sequence[Union[float, int]]] = 0.5, - weight: Optional[Text] = None, - interpolation='linear', - name: Optional[Text] = None, - where: Optional[Union[Text, Sequence[Text]]] = None): + def __init__( + self, + var: str, + quantile: float | int | Sequence[float | int] = 0.5, + weight: str | None = None, + interpolation: str = 'linear', + name: str | None = None, + where: str | Sequence[str] | None = None, + ) -> None: if isinstance(quantile, (int, float)): self.one_quantile = True else: @@ -2206,7 +2394,10 @@ def __init__(self, super(Quantile, self).__init__(var, name, name_tmpl, where, ['quantile', 'weight', 'interpolation']) - def compute_slices(self, df, split_by=None): + def compute_slices( + self, df: utils.ReturnType, split_by: list[str] | None = None + ) -> utils.ReturnType: + assert isinstance(df, pd.DataFrame) if self.weight: # Adapted from https://stackoverflow.com/a/29677616/12728137. def interp(d): @@ -2243,7 +2434,7 @@ def interp(d): res.columns = [self.name_tmpl.format(self.var, c[0]) for c in res] return res - def get_sql_columns(self, local_filter): + def get_sql_columns(self, local_filter: sql.Filters | None) -> Any: """Get SQL columns.""" if self.weight: raise ValueError('SQL for weighted quantile should already be handled!') @@ -2265,8 +2456,15 @@ def get_sql_columns(self, local_filter): sql.Column(self.var, sql.QUANTILE_FN(q), alias, local_filter)) return sql.Columns(quantiles) - def get_sql_and_with_clause(self, table, split_by, global_filter, indexes, - local_filter, with_data): + def get_sql_and_with_clause( + self, + table: sql.Datasource, + split_by: sql.Columns | None, + global_filter: sql.Filters | None, + indexes: sql.Columns | None, + local_filter: sql.Filters | None, + with_data: sql.Datasources | None, + ) -> tuple[sql.Sql, sql.Datasources]: """Gets the SQL for weighted quantile. The query is constructed as following. @@ -2319,6 +2517,8 @@ def get_sql_and_with_clause(self, table, split_by, global_filter, indexes, The global with_data which holds all datasources we need in the WITH clause. """ + if with_data is None: + with_data = sql.Datasources() if not self.weight: # Fall back to get_sql_columns(). return super(Quantile, self).get_sql_and_with_clause( table, split_by, global_filter, indexes, local_filter, with_data @@ -2349,7 +2549,7 @@ def get_sql_and_with_clause(self, table, split_by, global_filter, indexes, v = split_by_and_value.aliases[-1] w = weight.alias - split_by = sql.Columns(split_by.aliases) + split_by = sql.Columns(sql.Columns(split_by).aliases) split_by_and_value = sql.Columns(split_by_and_value.aliases) total_weight = sql.Column(w, 'SUM({})', partition=split_by) cum_weight = sql.Column( @@ -2437,24 +2637,31 @@ class Variance(SimpleMetric): SimpleMetric. """ - def __init__(self, - var: Text, - unbiased: bool = True, - weight: Optional[Text] = None, - name: Optional[Text] = None, - where: Optional[Union[Text, Sequence[Text]]] = None): + def __init__( + self, + var: str, + unbiased: bool = True, + weight: str | None = None, + name: str | None = None, + where: str | Sequence[str] | None = None, + ): self.ddof = 1 if unbiased else 0 self.weight = weight name_tmpl = '%s-weighted var({})' % str(weight) if weight else 'var({})' super(Variance, self).__init__(var, name, name_tmpl, where, ['ddof', 'weight']) - def compute_slices(self, df, split_by=None): + def compute_slices( + self, df: utils.ReturnType, split_by: list[str] | None = None + ) -> utils.ReturnType: + assert isinstance(df, pd.DataFrame) if self.weight: return self.compute_equivalent(df, split_by) return self.group(df, split_by)[self.var].var(ddof=self.ddof) - def get_equivalent_without_filter(self, *auxiliary_cols): + def get_equivalent_without_filter( + self, *auxiliary_cols: Any + ) -> Metric | None: if not self.weight: return Cov(self.var, self.var, ddof=self.ddof) numer = Dot(auxiliary_cols[0], @@ -2463,12 +2670,12 @@ def get_equivalent_without_filter(self, *auxiliary_cols): # ddof is invalid if it makes the denom negative so we use ((denom)^0.5)^2. return numer / (denom**0.5)**2 - def get_auxiliary_cols(self): + def get_auxiliary_cols(self) -> tuple: if self.weight: return ((self.var, '**', 2),) return () - def get_sql_columns(self, local_filter): + def get_sql_columns(self, local_filter: sql.Filters | None) -> sql.Column | None: if self.weight: return if self.ddof == 1: @@ -2491,28 +2698,35 @@ class StandardDeviation(SimpleMetric): SimpleMetric. """ - def __init__(self, - var: Text, - unbiased: bool = True, - weight: Optional[Text] = None, - name: Optional[Text] = None, - where: Optional[Union[Text, Sequence[Text]]] = None): + def __init__( + self, + var: str, + unbiased: bool = True, + weight: str | None = None, + name: str | None = None, + where: str | Sequence[str] | None = None, + ): self.ddof = 1 if unbiased else 0 self.weight = weight name_tmpl = '%s-weighted sd({})' % str(weight) if weight else 'sd({})' super(StandardDeviation, self).__init__(var, name, name_tmpl, where, ['ddof', 'weight']) - def compute_slices(self, df, split_by=None): + def compute_slices( + self, df: utils.ReturnType, split_by: list[str] | None = None + ) -> utils.ReturnType: + assert isinstance(df, pd.DataFrame) if self.weight: return self.compute_equivalent(df, split_by) return self.group(df, split_by)[self.var].std(ddof=self.ddof) - def get_equivalent_without_filter(self, *auxiliary_cols): + def get_equivalent_without_filter( + self, *auxiliary_cols: Any + ) -> Metric | None: del auxiliary_cols # unused return Variance(self.var, bool(self.ddof), self.weight) ** 0.5 - def get_sql_columns(self, local_filter): + def get_sql_columns(self, local_filter: sql.Filters | None) -> sql.Column | None: if self.weight: return if self.ddof == 1: @@ -2534,23 +2748,30 @@ class CV(SimpleMetric): SimpleMetric. """ - def __init__(self, - var: Text, - unbiased: bool = True, - name: Optional[Text] = None, - where: Optional[Union[Text, Sequence[Text]]] = None): + def __init__( + self, + var: str, + unbiased: bool = True, + name: str | None = None, + where: str | Sequence[str] | None = None, + ): self.ddof = 1 if unbiased else 0 super(CV, self).__init__(var, name, 'cv({})', where, ['ddof']) - def compute_slices(self, df, split_by=None): + def compute_slices( + self, df: utils.ReturnType, split_by: list[str] | None = None + ) -> utils.ReturnType: + assert isinstance(df, pd.DataFrame) var_grouped = self.group(df, split_by)[self.var] return var_grouped.std(ddof=self.ddof) / var_grouped.mean() - def get_equivalent_without_filter(self, *auxiliary_cols): + def get_equivalent_without_filter( + self, *auxiliary_cols: Any + ) -> Metric | None: del auxiliary_cols # unused return StandardDeviation(self.var, bool(self.ddof)) / Mean(self.var) - def get_sql_columns(self, local_filter): + def get_sql_columns(self, local_filter: sql.Filters | None) -> Any: if self.ddof == 1: res = sql.Column(self.var, sql.STDDEV_SAMP_FN, self.name, local_filter) / sql.Column( @@ -2577,13 +2798,15 @@ class Correlation(SimpleMetric): And all other attributes inherited from SimpleMetric. """ - def __init__(self, - var1: Text, - var2: Text, - weight: Optional[Text] = None, - name: Optional[Text] = None, - method='pearson', - where: Optional[Union[Text, Sequence[Text]]] = None): + def __init__( + self, + var1: str, + var2: str, + weight: str | None = None, + name: str | None = None, + method='pearson', + where: str | Sequence[str] | None = None, + ): name_tmpl = 'corr({}, {})' if weight: name_tmpl = '%s-weighted corr({}, {})' % str(weight) @@ -2595,7 +2818,10 @@ def __init__(self, super(Correlation, self).__init__(var1, name, name_tmpl, where, ['method', 'weight']) - def compute_slices(self, df, split_by=None): + def compute_slices( + self, df: utils.ReturnType, split_by: list[str] | None = None + ) -> utils.ReturnType: + assert isinstance(df, pd.DataFrame) if self.weight and self.method != 'pearson': raise NotImplementedError( 'Only Pearson correlation is supported in weighted Correlation!' @@ -2615,7 +2841,9 @@ def compute_slices(self, df, split_by=None): return self.group(df, split_by)[self.var].corr( df[self.var2], method=self.method) - def get_equivalent_without_filter(self, *auxiliary_cols): + def get_equivalent_without_filter( + self, *auxiliary_cols: Any + ) -> Metric | None: del auxiliary_cols # unused if self.method == 'pearson': return ( @@ -2624,7 +2852,7 @@ def get_equivalent_without_filter(self, *auxiliary_cols): / StandardDeviation(self.var2, False, self.weight) ) - def get_sql_columns(self, local_filter): + def get_sql_columns(self, local_filter: sql.Filters | None) -> sql.Column | None: if self.weight: return if self.method != 'pearson': @@ -2632,7 +2860,7 @@ def get_sql_columns(self, local_filter): return sql.Column((self.var, self.var2), sql.CORR_FN, self.name, local_filter) - def get_fingerprint(self, attr_to_exclude=()): + def get_fingerprint(self, attr_to_exclude: Iterable[str] = ()) -> tuple: if str(self.var) > str(self.var2): util = copy.deepcopy(self) util.var = self.var2 @@ -2660,15 +2888,17 @@ class Cov(SimpleMetric): SimpleMetric. """ - def __init__(self, - var1: Text, - var2: Text, - bias: bool = False, - ddof: Optional[int] = None, - weight: Optional[Text] = None, - fweight: Optional[Text] = None, - name: Optional[Text] = None, - where: Optional[Union[Text, Sequence[Text]]] = None): + def __init__( + self, + var1: str, + var2: str, + bias: bool = False, + ddof: int | None = None, + weight: str | None = None, + fweight: str | None = None, + name: str | None = None, + where: str | Sequence[str] | None = None, + ): name_tmpl = 'cov({}, {})' if weight: name_tmpl = '%s-weighted %s' % (str(weight), name_tmpl) @@ -2685,10 +2915,15 @@ def __init__(self, super(Cov, self).__init__(var1, name, name_tmpl, where, ['bias', 'ddof', 'weight', 'fweight']) - def compute_slices(self, df, split_by=None): + def compute_slices( + self, df: utils.ReturnType, split_by: list[str] | None = None + ) -> utils.ReturnType: + assert isinstance(df, pd.DataFrame) return self.compute_equivalent(df, split_by) - def get_equivalent_without_filter(self, *auxiliary_cols): + def get_equivalent_without_filter( + self, *auxiliary_cols: Any + ) -> Metric | None: """Gets the equivalent Metric for Cov.""" # See https://numpy.org/doc/stable/reference/generated/numpy.cov.html. ddof = self.ddof if self.ddof is not None else int(not self.bias) @@ -2723,7 +2958,7 @@ def get_equivalent_without_filter(self, *auxiliary_cols): res /= ((1 - ddof / v1) ** 0.5) ** 2 return res - def get_auxiliary_cols(self): + def get_auxiliary_cols(self) -> tuple: if not self.weight and not self.fweight: return () if not self.weight or not self.fweight: @@ -2733,7 +2968,7 @@ def get_auxiliary_cols(self): (self.fweight, '*', self.weight), ) - def get_sql_columns(self, local_filter): + def get_sql_columns(self, local_filter: sql.Filters | None) -> sql.Column | None: """Get SQL columns.""" if self.weight or self.fweight: return @@ -2750,7 +2985,7 @@ def get_sql_columns(self, local_filter): ) return - def get_fingerprint(self, attr_to_exclude=()): + def get_fingerprint(self, attr_to_exclude: Iterable[str] = ()) -> tuple: if str(self.var) > str(self.var2): util = copy.deepcopy(self) util.var = self.var2 diff --git a/models.py b/models.py index fed3dcf..ddf5342 100644 --- a/models.py +++ b/models.py @@ -1086,6 +1086,7 @@ def compute_on_sql_magic_mode(self, table, split_by, execute): n_y = n_y.compute_on_sql( y, y.groupby.aliases[len(self.group_by):], execute ) + assert isinstance(n_y, pd.DataFrame) if (n_y.values != 2).any(): raise ValueError( f'Magic mode only support two classes but got {n_y} distinct y' diff --git a/operations.py b/operations.py index 10b08ad..26dd7e8 100644 --- a/operations.py +++ b/operations.py @@ -12,15 +12,16 @@ # See the License for the specific language governing permissions and # limitations under the License. """Operation classes for Meterstick.""" - from __future__ import absolute_import +from __future__ import annotations from __future__ import division from __future__ import print_function +from collections.abc import Callable, Iterable, Sequence import copy import inspect import types -from typing import Any, Iterable, List, Literal, Optional, Sequence, Text, Tuple, Type, Union +from typing import Any, Literal import warnings from meterstick import confidence_interval_display @@ -32,7 +33,7 @@ from scipy import stats -def count_features(m: metrics.Metric): +def count_features(m: metrics.Metric) -> int: """Gets the width of the result of m.compute_on().""" if not m: return 0 @@ -110,15 +111,17 @@ class Operation(metrics.Metric): all other attributes inherited from Metric. """ - def __init__(self, - child: Optional[metrics.Metric] = None, - name_tmpl: Optional[Text] = None, - extra_split_by: Optional[Union[Text, Iterable[Text]]] = None, - extra_index: Optional[Union[Text, Iterable[Text]]] = None, - name: Optional[Text] = None, - where: Optional[Union[Text, Sequence[Text]]] = None, - additional_fingerprint_attrs: Optional[List[str]] = None, - **kwargs): + def __init__( + self, + child: metrics.Metric | None = None, + name_tmpl: str | None = None, + extra_split_by: str | Iterable[str] | None = None, + extra_index: str | Iterable[str] | None = None, + name: str | None = None, + where: str | Sequence[str] | None = None, + additional_fingerprint_attrs: list[str] | None = None, + **kwargs, + ) -> None: if name_tmpl and not name: name = name_tmpl.format(utils.get_name(child)) super(Operation, @@ -127,7 +130,10 @@ def __init__(self, self.precomputable_in_jk_bs = True self.is_operation = True - def compute_slices(self, df, split_by: Optional[List[Text]] = None): + def compute_slices( + self, df: utils.ReturnType, split_by: list[str] | None = None + ) -> utils.ReturnType: + assert isinstance(df, pd.DataFrame) try: children = self.compute_children(df, split_by + self.extra_split_by) res = self.compute_on_children(children, split_by) @@ -137,42 +143,56 @@ def compute_slices(self, df, split_by: Optional[List[Text]] = None): except NotImplementedError: return super(Operation, self).compute_slices(df, split_by) - def compute_children(self, - df: pd.DataFrame, - split_by=None, - melted=False, - return_dataframe=True, - cache_key=None): + def compute_children( + self, + df: pd.DataFrame, + split_by: list[str] | None = None, + melted: bool = False, + return_dataframe: bool = True, + cache_key: Any = None, + ) -> Any: return self.compute_child(df, split_by, melted, return_dataframe, cache_key) - def compute_child(self, - df: pd.DataFrame, - split_by=None, - melted=False, - return_dataframe=True, - cache_key=None): + def compute_child( + self, + df: pd.DataFrame, + split_by: list[str] | None = None, + melted: bool = False, + return_dataframe: bool = True, + cache_key: Any = None, + ) -> Any: child = self.children[0] return self.compute_util_metric_on(child, df, split_by, melted, return_dataframe, cache_key) - def compute_child_sql(self, - table, - split_by, - execute, - melted=False, - mode=None, - cache_key=None): + def compute_child_sql( + self, + table: Any, + split_by: list[str] | None, + execute: Callable[[str], pd.DataFrame], + melted: bool = False, + mode: str | None = None, + cache_key: Any = None, + ) -> Any: child = self.children[0] cache_key = self.wrap_cache_key(cache_key, split_by) return self.compute_util_metric_on_sql(child, table, split_by, execute, melted, mode, cache_key) - def compute_on_sql_mixed_mode(self, table, split_by, execute, mode=None): + def compute_on_sql_mixed_mode( + self, + table: Any, + split_by: list[str] | None, + execute: Callable[[str], pd.DataFrame], + mode: str | None = None, + ) -> Any: res = super(Operation, self).compute_on_sql_mixed_mode(table, split_by, execute, mode) return utils.apply_name_tmpl(self.name_tmpl, res) - def split_data(self, df, split_by=None): + def split_data( + self, df: pd.DataFrame, split_by: list[str] | None = None + ) -> Iterable[tuple[pd.DataFrame, Any]]: """Splits the DataFrame returned by the children.""" for k, idx in df.groupby(split_by, observed=True).indices.items(): # split_by will be added back later during the concatenation. @@ -181,17 +201,17 @@ def split_data(self, df, split_by=None): def manipulate( self, - res, + res: Any, melted: bool = False, return_dataframe: bool = True, - apply_name_tmpl=None, - ): + apply_name_tmpl: bool | None = None, + ) -> Any: apply_name_tmpl = True if apply_name_tmpl is None else apply_name_tmpl return super(Operation, self).manipulate( res, melted, return_dataframe, apply_name_tmpl ) - def __call__(self, child: metrics.Metric): + def __call__(self, child: metrics.Metric) -> 'Operation': op = copy.deepcopy(self) if self.children else self op.name = op.name_tmpl.format(utils.get_name(child)) op.children = (child,) @@ -207,14 +227,18 @@ class Distribution(Operation): And all other attributes inherited from Operation. """ - def __init__(self, - over: Union[Text, List[Text]], - child: Optional[metrics.Metric] = None, - name_tmpl: Text = 'Distribution of {}', - **kwargs): + def __init__( + self, + over: utils.StrOrList, + child: metrics.Metric | None = None, + name_tmpl: str = 'Distribution of {}', + **kwargs, + ) -> None: super(Distribution, self).__init__(child, name_tmpl, over, **kwargs) - def compute_on_children(self, children, split_by): + def compute_on_children( + self, children: Any, split_by: list[str] | None + ) -> Any: total = ( children.groupby(level=split_by, observed=True).sum() if split_by @@ -226,8 +250,15 @@ def compute_on_children(self, children, split_by): return res.reorder_levels(children.index.names) return res - def get_sql_and_with_clause(self, table, split_by, global_filter, indexes, - local_filter, with_data): + def get_sql_and_with_clause( + self, + table: sql.Datasource, + split_by: sql.Columns | None, + global_filter: sql.Filters | None, + indexes: sql.Columns | None, + local_filter: sql.Filters | None, + with_data: sql.Datasources | None, + ) -> tuple[sql.Sql, sql.Datasources]: """Gets the SQL query and WITH clause. The query is constructed by @@ -250,6 +281,17 @@ def get_sql_and_with_clause(self, table, split_by, global_filter, indexes, The global with_data which holds all datasources we need in the WITH clause. """ + if split_by is None: + split_by = sql.Columns() + if indexes is None: + indexes = sql.Columns() + if global_filter is None: + global_filter = sql.Filters() + if local_filter is None: + local_filter = sql.Filters() + if with_data is None: + with_data = sql.Datasources() + local_filter = ( sql.Filters(self.where_).add(local_filter).remove(global_filter) ) @@ -292,16 +334,16 @@ class CumulativeDistribution(Distribution): def __init__( self, - over: Text, - child: Optional[metrics.Metric] = None, - order=None, + over: str, + child: metrics.Metric | None = None, + order: Iterable[Any] | None = None, ascending: bool = True, sort_by_values: bool = False, - name_tmpl: Text = 'Cumulative Distribution of {}', - additional_fingerprint_attrs=None, + name_tmpl: str = 'Cumulative Distribution of {}', + additional_fingerprint_attrs: list[str] | None = None, **kwargs, - ): - self.order = order + ) -> None: + self.order = list(order) if order is not None else None self.ascending = ascending self.sort_by_values = sort_by_values super(CumulativeDistribution, self).__init__( @@ -319,7 +361,9 @@ def __init__( if order and sort_by_values: raise ValueError('Custom order is not allowed when sorting by values!') - def compute_on_children(self, children, split_by): + def compute_on_children( + self, children: Any, split_by: list[str] | None + ) -> Any: dist = super(CumulativeDistribution, self).compute_on_children( children, split_by ) @@ -344,8 +388,15 @@ def compute_on_children(self, children, split_by): res.sort_index(level=split_by, sort_remaining=False, inplace=True) return res - def get_sql_and_with_clause(self, table, split_by, global_filter, indexes, - local_filter, with_data): + def get_sql_and_with_clause( + self, + table: sql.Datasource, + split_by: sql.Columns | None, + global_filter: sql.Filters | None, + indexes: sql.Columns | None, + local_filter: sql.Filters | None, + with_data: sql.Datasources | None, + ) -> tuple[sql.Sql, sql.Datasources]: """Gets the SQL query and WITH clause. The query is constructed by @@ -368,6 +419,17 @@ def get_sql_and_with_clause(self, table, split_by, global_filter, indexes, The global with_data which holds all datasources we need in the WITH clause. """ + if split_by is None: + split_by = sql.Columns() + if indexes is None: + indexes = sql.Columns() + if global_filter is None: + global_filter = sql.Filters() + if local_filter is None: + local_filter = sql.Filters() + if with_data is None: + with_data = sql.Datasources() + dist_sql, with_data = super( CumulativeDistribution, self ).get_sql_and_with_clause( @@ -376,7 +438,7 @@ def get_sql_and_with_clause(self, table, split_by, global_filter, indexes, child_table = sql.Datasource(dist_sql, 'CumulativeDistributionRaw') child_table_alias = with_data.merge(child_table) columns = sql.Columns(indexes.aliases) - order = list(self.get_extra_idx(self)) + order = list(self.get_extra_idx()) order = [ sql.Column(self.get_ordered_col(sql.Column(o).alias), auto_alias=False) for o in order @@ -396,7 +458,7 @@ def get_sql_and_with_clause(self, table, split_by, global_filter, indexes, columns.add(col) return sql.Sql(columns, child_table_alias), with_data - def get_ordered_col(self, over): + def get_ordered_col(self, over: str) -> str: if self.order: over = 'CASE %s\n' % over tmpl = 'WHEN %s THEN %s' @@ -407,7 +469,7 @@ def get_ordered_col(self, over): return over if self.ascending else over + ' DESC' -def _format_to_condition(val): +def _format_to_condition(val: Any) -> str: if isinstance(val, str) and not val.startswith('$'): # Use single quotes instead of double quotes for string literals as it's # compatible with more SQL engines. @@ -418,14 +480,16 @@ def _format_to_condition(val): class Comparison(Operation): """Base class for comparisons like percent/absolute change.""" - def __init__(self, - condition_column, - baseline_key, - child: Optional[metrics.Metric] = None, - include_base: bool = False, - name_tmpl: Optional[Text] = None, - additional_fingerprint_attrs=None, - **kwargs): + def __init__( + self, + condition_column: str | Iterable[str], + baseline_key: Any, + child: metrics.Metric | None = None, + include_base: bool = False, + name_tmpl: str | None = None, + additional_fingerprint_attrs: list[str] | None = None, + **kwargs, + ) -> None: self.baseline_key = baseline_key self.include_base = include_base additional_fingerprint_attrs = additional_fingerprint_attrs or [] @@ -438,18 +502,25 @@ def __init__(self, **kwargs) @property - def stratified_by(self): + def stratified_by(self) -> list[str]: return self.extra_split_by[len(self.extra_index):] @stratified_by.setter - def stratified_by(self, stratified_by): + def stratified_by(self, stratified_by: utils.StrOrList) -> None: stratified_by = ( stratified_by if isinstance(stratified_by, list) else [stratified_by] ) self.extra_split_by[len(self.extra_index):] = stratified_by - def get_sql_and_with_clause(self, table, split_by, global_filter, indexes, - local_filter, with_data): + def get_sql_and_with_clause( + self, + table: sql.Datasource, + split_by: sql.Columns | None, + global_filter: sql.Filters | None, + indexes: sql.Columns | None, + local_filter: sql.Filters | None, + with_data: sql.Datasources | None, + ) -> tuple[sql.Sql, sql.Datasources]: """Gets the SQL for PercentChange or AbsoluteChange. The query is constructed by @@ -502,6 +573,17 @@ def get_sql_and_with_clause(self, table, split_by, global_filter, indexes, The global with_data which holds all datasources we need in the WITH clause. """ + if split_by is None: + split_by = sql.Columns() + if indexes is None: + indexes = sql.Columns() + if global_filter is None: + global_filter = sql.Filters() + if local_filter is None: + local_filter = sql.Filters() + if with_data is None: + with_data = sql.Datasources() + cond_cols = sql.Columns(self.extra_index) raw_table_sql, with_data = self.get_change_raw_sql( table, split_by, global_filter, indexes, local_filter, with_data @@ -547,8 +629,14 @@ def get_sql_and_with_clause(self, table, split_by, global_filter, indexes, ) def get_change_raw_sql( - self, table, split_by, global_filter, indexes, local_filter, with_data - ): + self, + table: sql.Datasource, + split_by: sql.Columns, + global_filter: sql.Filters, + indexes: sql.Columns, + local_filter: sql.Filters, + with_data: sql.Datasources, + ) -> tuple[sql.Sql, sql.Datasources]: """Gets the query where the comparison will be carried out.""" local_filter = ( sql.Filters(self.where_).add(local_filter).remove(global_filter) @@ -559,7 +647,9 @@ def get_change_raw_sql( ) return raw_table_sql, with_data - def get_sql_template_for_comparison(self, raw_table_alias, base_table_alias): + def get_sql_template_for_comparison( + self, raw_table_alias: str, base_table_alias: str + ) -> str: """Gets a string template to compute the comparison between columns. The template needs to use "%(r)s" to represent the column from @@ -592,17 +682,21 @@ class PercentChange(Comparison): And all other attributes inherited from Operation. """ - def __init__(self, - condition_column: Text, - baseline_key, - child: Optional[metrics.Metric] = None, - include_base: bool = False, - name_tmpl: Text = '{} Percent Change', - **kwargs): + def __init__( + self, + condition_column: str, + baseline_key: Any, + child: metrics.Metric | None = None, + include_base: bool = False, + name_tmpl: str = '{} Percent Change', + **kwargs, + ) -> None: super(PercentChange, self).__init__(condition_column, baseline_key, child, include_base, name_tmpl, **kwargs) - def compute_on_children(self, children, split_by): + def compute_on_children( + self, children: Any, split_by: list[str] | None + ) -> Any: level = None if split_by: level = self.extra_index[0] if len( @@ -618,7 +712,9 @@ def compute_on_children(self, children, split_by): res = res[~idx_to_match.isin([self.baseline_key])] return res - def get_sql_template_for_comparison(self, raw_table_alias, base_table_alias): + def get_sql_template_for_comparison( + self, raw_table_alias: str, base_table_alias: str + ) -> str: return ( sql.SAFE_DIVIDE_FN( numer=f'{raw_table_alias}.%(r)s', @@ -643,17 +739,21 @@ class AbsoluteChange(Comparison): And all other attributes inherited from Operation. """ - def __init__(self, - condition_column: Text, - baseline_key, - child: Optional[metrics.Metric] = None, - include_base: bool = False, - name_tmpl: Text = '{} Absolute Change', - **kwargs): + def __init__( + self, + condition_column: str, + baseline_key: Any, + child: metrics.Metric | None = None, + include_base: bool = False, + name_tmpl: str = '{} Absolute Change', + **kwargs, + ) -> None: super(AbsoluteChange, self).__init__(condition_column, baseline_key, child, include_base, name_tmpl, **kwargs) - def compute_on_children(self, children, split_by): + def compute_on_children( + self, children: Any, split_by: list[str] | None + ) -> Any: level = None if split_by: level = self.extra_index[0] if len( @@ -670,11 +770,13 @@ def compute_on_children(self, children, split_by): res = res[~idx_to_match.isin([self.baseline_key])] return res - def get_sql_template_for_comparison(self, raw_table_alias, base_table_alias): + def get_sql_template_for_comparison( + self, raw_table_alias: str, base_table_alias: str + ) -> str: return f'{raw_table_alias}.%(r)s - {base_table_alias}.%(b)s' -def _check_covariates_match_base(base, cov): +def _check_covariates_match_base(base: Any, cov: Any) -> None: len_base = len(base) if isinstance(base, metrics.MetricList) else 1 len_cov = len(cov) if isinstance(cov, metrics.MetricList) else 1 if len_cov != len_base: @@ -722,19 +824,25 @@ class PrePostChange(PercentChange): And all other attributes inherited from Operation. """ - def __init__(self, - condition_column, - baseline_key, - child=None, - covariates=None, - stratified_by=None, - include_base=False, - multiple_covariates=True, - name_tmpl: Text = '{} PrePost Percent Change', - **kwargs): - if isinstance(child, (List, Tuple)): + def __init__( + self, + condition_column: utils.StrOrList, + baseline_key: Any, + child: ( + metrics.Metric | list[metrics.Metric] | tuple[metrics.Metric] | None + ) = None, + covariates: ( + metrics.Metric | list[metrics.Metric] | tuple[metrics.Metric] | None + ) = None, + stratified_by: utils.StrOrList | None = None, + include_base: bool = False, + multiple_covariates: bool = True, + name_tmpl: str = '{} PrePost Percent Change', + **kwargs, + ) -> None: + if isinstance(child, (list, tuple)): child = metrics.MetricList(child) - if isinstance(covariates, (List, Tuple)): + if isinstance(covariates, (list, tuple)): covariates = metrics.MetricList(covariates) if child and covariates: if not multiple_covariates: @@ -767,14 +875,17 @@ def child(self): return self.children[0][0] if self.children else None @property - def covariates(self): + def covariates(self) -> metrics.Metric | None: return self.children[0][1] if self.children else None @property def k_covariates(self) -> int: return count_features(self.covariates) - def compute_slices(self, df, split_by=None): + def compute_slices( + self, df: utils.ReturnType, split_by: list[str] | None = None + ) -> utils.ReturnType: + assert isinstance(df, pd.DataFrame) if self.multiple_covariates: return super(PrePostChange, self).compute_slices(df, split_by) equiv, _ = utils.get_equivalent_metric(self) @@ -785,12 +896,12 @@ def compute_slices(self, df, split_by=None): def compute_children( self, - df, - split_by=None, - melted=False, - return_dataframe=True, - cache_key=None, - ): + df: pd.DataFrame, + split_by: list[str] | None = None, + melted: bool = False, + return_dataframe: bool = True, + cache_key: Any = None, + ) -> Any: if not self.multiple_covariates: raise NotImplementedError # shouldn't be called. child, covariates = super(PrePostChange, self).compute_children( @@ -798,7 +909,12 @@ def compute_children( original_split_by = [s for s in split_by if s not in self.extra_split_by] return self.adjust_value(child, covariates, original_split_by) - def adjust_value(self, child, covariates, split_by): + def adjust_value( + self, + child: pd.DataFrame, + covariates: pd.DataFrame, + split_by: list[str] | None, + ) -> pd.DataFrame: """Adjust the raw value by controlling for Pre-metrics. As described in the class doc, PrePost fits a linear model, @@ -851,7 +967,10 @@ class Adjust(metrics.Metric): c = avg(child) - θ * avg(covariate). """ - def compute_slices(self, df, split_by: Optional[List[Text]] = None): + def compute_slices( + self, df: utils.ReturnType, split_by: list[str] | None = None + ) -> utils.ReturnType: + assert isinstance(df, pd.DataFrame) child = df.iloc[:, :len_child] prefix = utils.get_unique_prefix(child) df.columns = list(child.columns) + [ @@ -870,7 +989,7 @@ def compute_slices(self, df, split_by: Optional[List[Text]] = None): adjusted[c] = adjusted[c] - covariate_adjusted * theta return adjusted.iloc[:, :-1] - def compute(self, df_slice): + def compute(self, df_slice: pd.DataFrame) -> Any: child_slice = df_slice.iloc[:, :len_child] covariate = df_slice.iloc[:, len_child:] adjusted = [ @@ -880,7 +999,13 @@ def compute(self, df_slice): return Adjust('').compute_on(aligned, split_by + self.extra_index) - def compute_through_sql(self, table, split_by, execute, mode): + def compute_through_sql( + self, + table: Any, + split_by: list[str] | None, + execute: Callable[[str], pd.DataFrame], + mode: str | None, + ) -> Any: if self.multiple_covariates: return super(PrePostChange, self).compute_through_sql( table, split_by, execute, mode @@ -895,7 +1020,13 @@ def compute_through_sql(self, table, split_by, execute, mode): res.columns = [self.name_tmpl.format(self.children[0][0].name)] return res - def compute_children_sql(self, table, split_by, execute, mode=None): + def compute_children_sql( + self, + table: utils.TableType, + split_by: list[str] | None, + execute: Callable[[str], pd.DataFrame], + mode: str | None = None, + ) -> Any: if not self.multiple_covariates: raise NotImplementedError # shouldn't be called. child = super(PrePostChange, @@ -904,8 +1035,15 @@ def compute_children_sql(self, table, split_by, execute, mode=None): child = child.iloc[:, :-self.k_covariates] return self.adjust_value(child, covariates, split_by) - def get_sql_and_with_clause(self, table, split_by, global_filter, indexes, - local_filter, with_data): + def get_sql_and_with_clause( + self, + table: sql.Datasource, + split_by: sql.Columns | None, + global_filter: sql.Filters | None, + indexes: sql.Columns | None, + local_filter: sql.Filters | None, + with_data: sql.Datasources | None, + ) -> tuple[sql.Sql, sql.Datasources]: if self.multiple_covariates: return super(PrePostChange, self).get_sql_and_with_clause( table, split_by, global_filter, indexes, local_filter, with_data @@ -916,8 +1054,14 @@ def get_sql_and_with_clause(self, table, split_by, global_filter, indexes, ) def get_change_raw_sql( - self, table, split_by, global_filter, indexes, local_filter, with_data - ): + self, + table: sql.Datasource, + split_by: sql.Columns, + global_filter: sql.Filters, + indexes: sql.Columns, + local_filter: sql.Filters, + with_data: sql.Datasources, + ) -> tuple[sql.Sql, sql.Datasources]: """Generates PrePost-adjusted values for PercentChange computation. This function generates subqueries like @@ -1006,7 +1150,9 @@ def get_change_raw_sql( with_data, ) - def get_equivalent_without_filter(self, *auxiliary_cols): + def get_equivalent_without_filter( + self, *auxiliary_cols: Any + ) -> metrics.Metric | None: del auxiliary_cols # unused if self.multiple_covariates: return @@ -1073,19 +1219,25 @@ class CUPED(AbsoluteChange): And all other attributes inherited from Operation. """ - def __init__(self, - condition_column, - baseline_key, - child=None, - covariates=None, - stratified_by=None, - include_base=False, - multiple_covariates=True, - name_tmpl: Text = '{} CUPED Change', - **kwargs): - if isinstance(child, (List, Tuple)): + def __init__( + self, + condition_column: utils.StrOrList, + baseline_key: Any, + child: ( + metrics.Metric | list[metrics.Metric] | tuple[metrics.Metric] | None + ) = None, + covariates: ( + metrics.Metric | list[metrics.Metric] | tuple[metrics.Metric] | None + ) = None, + stratified_by: utils.StrOrList | None = None, + include_base: bool = False, + multiple_covariates: bool = True, + name_tmpl: str = '{} CUPED Change', + **kwargs, + ) -> None: + if isinstance(child, (list, tuple)): child = metrics.MetricList(child) - if isinstance(covariates, (List, Tuple)): + if isinstance(covariates, (list, tuple)): covariates = metrics.MetricList(covariates) if child and covariates: if not multiple_covariates: @@ -1114,18 +1266,21 @@ def __init__(self, self.extra_index = condition_column @property - def child(self): + def child(self) -> metrics.Metric | None: return self.children[0][0] if self.children else None @property - def covariates(self): + def covariates(self) -> metrics.Metric | None: return self.children[0][1] if self.children else None @property def k_covariates(self) -> int: return count_features(self.covariates) - def compute_slices(self, df, split_by=None): + def compute_slices( + self, df: utils.ReturnType, split_by: list[str] | None = None + ) -> utils.ReturnType: + assert isinstance(df, pd.DataFrame) if self.multiple_covariates: return super(CUPED, self).compute_slices(df, split_by) equiv, _ = utils.get_equivalent_metric(self) @@ -1136,12 +1291,12 @@ def compute_slices(self, df, split_by=None): def compute_children( self, - df, - split_by=None, - melted=False, - return_dataframe=True, - cache_key=None, - ): + df: pd.DataFrame, + split_by: list[str] | None = None, + melted: bool = False, + return_dataframe: bool = True, + cache_key: Any = None, + ) -> Any: if not self.multiple_covariates: raise NotImplementedError # shouldn't be called. child, covariates = super(CUPED, self).compute_children( @@ -1149,7 +1304,12 @@ def compute_children( original_split_by = [s for s in split_by if s not in self.extra_split_by] return self.adjust_value(child, covariates, original_split_by) - def adjust_value(self, child, covariates, split_by): + def adjust_value( + self, + child: pd.DataFrame, + covariates: pd.DataFrame, + split_by: list[str] | None, + ) -> utils.ReturnType: """Adjust the raw value by controlling for Pre-metrics. Args: @@ -1189,7 +1349,10 @@ class Adjust(metrics.Metric): Covariance(child, covariate) / Var(covariate) """ - def compute_slices(self, df, split_by: Optional[List[Text]] = None): + def compute_slices( + self, df: utils.ReturnType, split_by: list[str] | None = None + ) -> utils.ReturnType: + assert isinstance(df, pd.DataFrame) child = df.iloc[:, :len_child] prefix = utils.get_unique_prefix(child) df.columns = list(child.columns) + [ @@ -1208,7 +1371,7 @@ def compute_slices(self, df, split_by: Optional[List[Text]] = None): adjusted[c] = adjusted[c] - covariate_adjusted * theta return adjusted.iloc[:, :-1] - def compute(self, df_slice): + def compute(self, df_slice: pd.DataFrame) -> Any: child_slice = df_slice.iloc[:, :len_child] covariate = df_slice.iloc[:, len_child:] adjusted = df_slice.groupby(extra_index, observed=True).mean() @@ -1219,7 +1382,13 @@ def compute(self, df_slice): return Adjust('').compute_on(aligned, split_by) - def compute_through_sql(self, table, split_by, execute, mode): + def compute_through_sql( + self, + table: Any, + split_by: list[str] | None, + execute: Callable[[str], pd.DataFrame], + mode: str | None, + ) -> Any: if self.multiple_covariates: return super(CUPED, self).compute_through_sql( table, split_by, execute, mode @@ -1234,7 +1403,13 @@ def compute_through_sql(self, table, split_by, execute, mode): res.columns = [self.name_tmpl.format(self.children[0][0].name)] return res - def compute_children_sql(self, table, split_by, execute, mode=None): + def compute_children_sql( + self, + table: utils.TableType, + split_by: list[str] | None, + execute: Callable[[str], pd.DataFrame], + mode: str | None = None, + ) -> Any: if not self.multiple_covariates: raise NotImplementedError # shouldn't be called. child = super(CUPED, self).compute_children_sql(table, split_by, execute, @@ -1243,8 +1418,15 @@ def compute_children_sql(self, table, split_by, execute, mode=None): child = child.iloc[:, :-self.k_covariates] return self.adjust_value(child, covariates, split_by) - def get_sql_and_with_clause(self, table, split_by, global_filter, indexes, - local_filter, with_data): + def get_sql_and_with_clause( + self, + table: sql.Datasource, + split_by: sql.Columns | None, + global_filter: sql.Filters | None, + indexes: sql.Columns | None, + local_filter: sql.Filters | None, + with_data: sql.Datasources | None, + ) -> tuple[sql.Sql, sql.Datasources]: if self.multiple_covariates: return super(CUPED, self).get_sql_and_with_clause( table, split_by, global_filter, indexes, local_filter, with_data @@ -1410,14 +1592,16 @@ class MH(Comparison): And all other attributes inherited from Operation. """ - def __init__(self, - condition_column: Union[Text, List[Text]], - baseline_key: Any, - stratified_by: Union[Text, List[Text]], - metric: Optional[metrics.Metric] = None, - include_base: bool = False, - name_tmpl: Text = '{} MH Ratio', - **kwargs): + def __init__( + self, + condition_column: utils.StrOrList, + baseline_key: Any, + stratified_by: utils.StrOrList, + metric: metrics.Metric | None = None, + include_base: bool = False, + name_tmpl: str = '{} MH Ratio', + **kwargs, + ) -> None: stratified_by = ( stratified_by if isinstance(stratified_by, list) else [stratified_by] ) @@ -1432,7 +1616,7 @@ def __init__(self, extra_index=condition_column, **kwargs) - def check_is_ratio(self, metric, allow_metriclist=True): + def check_is_ratio(self, metric: Any, allow_metriclist: bool = True) -> None: if isinstance(metric, metrics.MetricList) and allow_metriclist: for m in metric: self.check_is_ratio(m, False) @@ -1445,12 +1629,14 @@ def check_is_ratio(self, metric, allow_metriclist=True): ' Got %s.' % metric ) - def compute_children(self, - df: pd.DataFrame, - split_by=None, - melted=False, - return_dataframe=True, - cache_key=None): + def compute_children( + self, + df: pd.DataFrame, + split_by: list[str] | None = None, + melted: bool = False, + return_dataframe: bool = True, + cache_key: Any = None, + ) -> Any: child = self.children[0] self.check_is_ratio(child) if isinstance(child, metrics.MetricList): @@ -1467,7 +1653,9 @@ def compute_children(self, return self.compute_util_metric_on( util_metric, df, split_by, cache_key=cache_key) - def compute_on_children(self, children, split_by): + def compute_on_children( + self, children: Any, split_by: list[str] | None + ) -> Any: child = self.children[0] if isinstance(child, metrics.MetricList): res = [ @@ -1477,7 +1665,12 @@ def compute_on_children(self, children, split_by): return pd.concat(res, axis=1, sort=False) return self.compute_mh(child, children, split_by) - def compute_mh(self, child, df_all, split_by): + def compute_mh( + self, + child: metrics.Metric, + df_all: pd.DataFrame, + split_by: list[str] | None, + ) -> pd.DataFrame: """Computes MH statistics for one Metric.""" level = self.extra_index[0] if len( self.extra_index) == 1 else self.extra_index @@ -1505,7 +1698,13 @@ def compute_mh(self, child, df_all, split_by): res = res[~idx_to_match.isin([self.baseline_key])] return pd.DataFrame(res.sort_index(level=split_by + self.extra_index)) - def compute_children_sql(self, table, split_by=None, execute=None, mode=None): + def compute_children_sql( + self, + table: utils.TableType, + split_by: list[str] | None = None, + execute: Callable[[str], pd.DataFrame] | None = None, + mode: str | None = None, + ) -> Any: child = self.children[0] self.check_is_ratio(child) if isinstance(child, metrics.MetricList): @@ -1526,8 +1725,15 @@ def compute_children_sql(self, table, split_by=None, execute=None, mode=None): return self.compute_util_metric_on_sql( util_metric, table, split_by + self.extra_split_by, execute, mode=mode) - def get_sql_and_with_clause(self, table, split_by, global_filter, indexes, - local_filter, with_data): + def get_sql_and_with_clause( + self, + table: sql.Datasource, + split_by: sql.Columns | None, + global_filter: sql.Filters | None, + indexes: sql.Columns | None, + local_filter: sql.Filters | None, + with_data: sql.Datasources | None, + ) -> tuple[sql.Sql, sql.Datasources]: """Gets the SQL query and WITH clause. The query is constructed in a similar way to AbsoluteChange except that we @@ -1586,6 +1792,15 @@ def get_sql_and_with_clause(self, table, split_by, global_filter, indexes, The global with_data which holds all datasources we need in the WITH clause. """ + if indexes is None: + indexes = sql.Columns() + if global_filter is None: + global_filter = sql.Filters() + if local_filter is None: + local_filter = sql.Filters() + if with_data is None: + with_data = sql.Datasources() + child = self.children[0] self.check_is_ratio(child) local_filter = ( @@ -1698,12 +1913,14 @@ def get_sql_and_with_clause(self, table, split_by, global_filter, indexes, ) -def get_display_fn(name, - split_by=None, - value='Value', - condition_column: Optional[List[Text]] = None, - ctrl_id=None, - default_metric_formats=None): +def get_display_fn( + name: str, + split_by: list[str] | None = None, + value: str = 'Value', + condition_column: list[str] | None = None, + ctrl_id: Any = None, + default_metric_formats: dict | None = None, +) -> Callable[..., Any]: """Returns a function that displays confidence interval nicely. Args: @@ -1834,7 +2051,7 @@ def display( return display -def get_comparison_child(op): +def get_comparison_child(op: metrics.Metric) -> Comparison | None: """Checks if `op` wraps a `PercentChange` or `AbsoluteChange` metric. This function recursively checks if `op` is a `MetricWithCI` with a single @@ -1907,16 +2124,16 @@ class MetricWithCI(Operation): def __init__( self, - unit: Optional[Text], - child: Optional[metrics.Metric] = None, - confidence: Optional[float] = None, - name_tmpl: Optional[Text] = None, - prefix: Optional[Text] = None, - additional_fingerprint_attrs=None, - sql_batch_size=None, - enable_optimization=True, + unit: str | None, + child: metrics.Metric | None = None, + confidence: float | None = None, + name_tmpl: str | None = None, + prefix: str | None = None, + additional_fingerprint_attrs: list[str] | None = None, + sql_batch_size: int | None = None, + enable_optimization: bool = True, **kwargs, - ): + ) -> None: if confidence and not 0 < confidence < 1: raise ValueError('Confidence must be in (0, 1).') self.unit = unit @@ -1950,10 +2167,12 @@ def compute_slices(self, df, split_by=None): base = self.compute_change_base(df, split_by) return self.add_base_to_res(res, base) - def compute_point_estimate(self, df, split_by): + def compute_point_estimate( + self, df: pd.DataFrame, split_by: list[str] | None + ) -> pd.DataFrame: return self.compute_child(df, split_by, melted=True) - def compute_ci(self, res): + def compute_ci(self, res: pd.DataFrame) -> pd.DataFrame: """Constructs the confidence interval. Args: @@ -1976,12 +2195,14 @@ def compute_ci(self, res): res[self.prefix + ' CI-upper'] += res.iloc[:, 0] return res - def compute_change_base(self, - df, - split_by, - execute=None, - mode=None, - cache_key=None): + def compute_change_base( + self, + df: pd.DataFrame | utils.TableType, + split_by: list[str] | None, + execute: Callable[[str], pd.DataFrame] | None = None, + mode: str | None = None, + cache_key: Any = None, + ) -> pd.DataFrame | None: """Computes the base values for Change. It's used in res.display().""" if not self.confidence: return None @@ -1997,9 +2218,11 @@ def compute_change_base(self, to_split = ( split_by + change.extra_index if split_by else change.extra_index) if execute is None: + assert isinstance(df, pd.DataFrame) base = self.compute_util_metric_on( util_metric, df, to_split, cache_key=cache_key) else: + assert isinstance(df, (str, sql.Datasource, sql.Sql)) base = self.compute_util_metric_on_sql( util_metric, df, to_split, execute, mode=mode, cache_key=cache_key) base.columns = [change.name_tmpl.format(c) for c in base.columns] @@ -2008,7 +2231,9 @@ def compute_change_base(self, return base @staticmethod - def add_base_to_res(res, base): + def add_base_to_res( + res: pd.DataFrame, base: pd.DataFrame | None + ) -> pd.DataFrame: with warnings.catch_warnings(): warnings.simplefilter(action='ignore', category=UserWarning) res.meterstick_change_base = base @@ -2017,11 +2242,11 @@ def add_base_to_res(res, base): def compute_children( self, df: pd.DataFrame, - split_by=None, - melted=False, - return_dataframe=True, - cache_key=None, - ): + split_by: list[str] | None = None, + melted: bool = False, + return_dataframe: bool = True, + cache_key: Any = None, + ) -> Any: del melted, return_dataframe, cache_key # unused return self.compute_on_samples(self.get_samples(df, split_by), split_by) @@ -2029,8 +2254,10 @@ def get_samples(self, df, split_by=None): raise NotImplementedError def compute_on_samples( - self, keyed_samples: Iterable[Tuple[Any, pd.DataFrame]], split_by=None - ): + self, + keyed_samples: Iterable[tuple[Any, pd.DataFrame | None]], + split_by: list[str] | None = None, + ) -> list[Any]: """Iters through sample DataFrames and collects results. Args: @@ -2067,7 +2294,9 @@ def compute_on_children(self, children, split_by): bucket_estimates = pd.concat(children, axis=1, sort=False) return self.get_stderrs_or_ci_half_width(bucket_estimates) - def get_stderrs_or_ci_half_width(self, bucket_estimates): + def get_stderrs_or_ci_half_width( + self, bucket_estimates: pd.DataFrame + ) -> pd.DataFrame: """Returns confidence interval information in an unmelted DataFrame.""" stderrs, dof = self.get_stderrs(bucket_estimates) if self.confidence: @@ -2083,7 +2312,9 @@ def get_stderrs(bucket_estimates): dof = bucket_estimates.count(axis=1) - 1 return bucket_estimates.std(1), dof - def get_ci_width(self, stderrs, dof): + def get_ci_width( + self, stderrs: pd.Series, dof: pd.Series + ) -> tuple[pd.Series, pd.Series]: """You can return asymmetrical confidence interval.""" dof = dof.fillna(0).astype(int) # Scipy might not recognize the Int64 type. half_width = stderrs * stats.t.ppf((1 + self.confidence) / 2, dof) @@ -2110,12 +2341,14 @@ def manipulate( ) return self.add_base_to_res(res, base) if self.confidence else res - def final_compute(self, - res, - melted: bool = False, - return_dataframe: bool = True, - split_by: Optional[List[Text]] = None, - df=None): + def final_compute( + self, + res: Any, + melted: bool = False, + return_dataframe: bool = True, + split_by: Sequence[str] | None = None, + df: Any = None, + ) -> Any: """Add a display function if confidence is specified.""" del return_dataframe # unused if self.confidence: @@ -2134,7 +2367,9 @@ def final_compute(self, res.display = types.MethodType(warn, res) return res - def add_display_fn(self, res, split_by, melted): + def add_display_fn( + self, res: pd.DataFrame, split_by: list[str] | None, melted: bool + ) -> pd.DataFrame: """Bounds a display function to res so res.display() works.""" value = res.columns[0] if melted else res.columns[0][1] ctrl_id = None @@ -2155,16 +2390,17 @@ def add_display_fn(self, res, split_by, melted): def compute_on_sql( self, - table, - split_by=None, - execute=None, - melted=False, - mode=None, - cache_key=None, - cache=None, - batch_size=None, - return_dataframe=True, - ): + table: utils.TableType, + split_by: list[str] | None = None, + execute: Callable[[str], pd.DataFrame] | None = None, + melted: bool = False, + mode: str | None = None, + cache_key: Any = None, + cache: dict | None = None, + batch_size: int | None = None, + return_dataframe: bool = True, + engine: str = 'f1', + ) -> utils.ReturnType: """Computes self in pure SQL or a mixed of SQL and Python. Args: @@ -2205,7 +2441,13 @@ def compute_on_sql( finally: self._runtime_batch_size = None - def compute_through_sql(self, table, split_by, execute, mode): + def compute_through_sql( + self, + table: Any, + split_by: list[str] | None, + execute: Callable[[str], pd.DataFrame], + mode: str | None, + ) -> Any: try: return super(MetricWithCI, self).compute_through_sql( table, split_by, execute, mode @@ -2223,7 +2465,12 @@ def compute_through_sql(self, table, split_by, execute, mode): 'the query being too large/complex, you can try %s' % msg ) from e - def compute_on_sql_sql_mode(self, table, split_by=None, execute=None): + def compute_on_sql_sql_mode( + self, + table: Any, + split_by: list[str] | None = None, + execute: Callable[[str], pd.DataFrame] | None = None, + ) -> Any: """Computes self in a SQL query and process the result. We first execute the SQL query then process the result. @@ -2300,7 +2547,13 @@ def compute_on_sql_sql_mode(self, table, split_by=None, execute=None): res = pd.concat((sub_dfs), axis=1, keys=metric_names, names=['Metric']) return self.add_base_to_res(res, base) - def compute_on_sql_mixed_mode(self, table, split_by, execute, mode=None): + def compute_on_sql_mixed_mode( + self, + table: Any, + split_by: list[str] | None, + execute: Callable[[str], pd.DataFrame] | None, + mode: str | None = None, + ) -> Any: """Computes the child in SQL and the rest in Python. There are two parts. First we compute the standard errors. Then we join it @@ -2372,7 +2625,7 @@ def compute_on_sql_mixed_mode(self, table, split_by, execute, mode=None): ] preagg = sql.Sql(cols, table, self.where_, all_split_by) equiv = get_preaggregated_metric_tree(expanded) - equiv.unit = sql.Column(equiv.unit).alias + equiv.unit = sql.Column(equiv.unit).alias # type: ignore split_by = sql.Columns(split_by).aliases for m in equiv.traverse(): if isinstance(m, metrics.Metric): @@ -2387,16 +2640,20 @@ def compute_on_sql_mixed_mode(self, table, split_by, execute, mode=None): equiv.has_local_filter = any([l.where for l in leaf]) return equiv.compute_on_sql_mixed_mode(preagg, split_by, execute, mode) - def compute_children_sql(self, - table, - split_by, - execute, - mode=None, - batch_size=None): + def compute_children_sql( + self, + table: utils.TableType, + split_by: list[str] | None, + execute: Callable[[str], pd.DataFrame] | None, + mode: str | None = None, + batch_size: int | None = None, + ) -> Any: """The return should be similar to compute_children().""" raise NotImplementedError - def to_sql(self, table, split_by=None): + def to_sql( + self, table: utils.TableType, split_by: utils.StrOrList | None = None + ) -> sql.Sql: """Generates SQL query for the metric. The SQL generation is actually delegated to get_sql_and_with_clause(). This @@ -2446,7 +2703,7 @@ def to_sql(self, table, split_by=None): ] preagg = sql.Sql(cols, table, self.where_, all_split_by) equiv = get_preaggregated_metric_tree(expanded) - equiv.unit = sql.Column(equiv.unit).alias + equiv.unit = sql.Column(equiv.unit).alias # type: ignore split_by = sql.Columns(split_by).aliases for m in equiv.traverse(): if isinstance(m, metrics.Metric): @@ -2461,8 +2718,14 @@ def to_sql(self, table, split_by=None): return equiv.to_sql(preagg, split_by) def get_sql_and_with_clause( - self, table, split_by, global_filter, indexes, local_filter, with_data - ): + self, + table: sql.Datasource, + split_by: sql.Columns | None, + global_filter: sql.Filters | None, + indexes: sql.Columns | None, + local_filter: sql.Filters | None, + with_data: sql.Datasources | None, + ) -> tuple[sql.Sql, sql.Datasources]: """Gets the SQL for Jackknife or Bootstrap. The query is constructed by @@ -2492,6 +2755,15 @@ def get_sql_and_with_clause( The global with_data which holds all datasources we need in the WITH clause. """ + if indexes is None: + indexes = sql.Columns() + if global_filter is None: + global_filter = sql.Filters() + if local_filter is None: + local_filter = sql.Filters() + if with_data is None: + with_data = sql.Datasources() + # Confidence interval cannot be computed in SQL completely so the SQL # generated below doesn't work correctly if self is not a root node. if self.confidence and not self._is_root_node: @@ -2580,12 +2852,12 @@ def get_sql_and_with_clause( def get_se_sql( self, - table, - split_by, - global_filter, - indexes, - with_data, - ): + table: Any, + split_by: list[str] | None, + global_filter: sql.Filters | None, + indexes: sql.Columns, + with_data: sql.Datasources, + ) -> tuple[sql.Sql, sql.Datasources]: """Gets the SQL query that computes the standard error and dof if needed.""" global_filter = sql.Filters(global_filter).add(self.where_) self_copy = copy.deepcopy(self) # self_copy might get modified in-place. @@ -2611,15 +2883,15 @@ def get_se_sql( def get_resampled_data_sql( self, - table, - split_by, - global_filter, - indexes, - with_data, - ): + table: Any, + split_by: list[str] | None, + global_filter: sql.Filters | None, + indexes: sql.Columns, + with_data: sql.Datasources, + ) -> tuple[sql.Datasource, sql.Datasources]: raise NotImplementedError - def can_precompute(self): + def can_precompute(self) -> bool: return False @@ -2639,13 +2911,15 @@ class Jackknife(MetricWithCI): And all other attributes inherited from Operation. """ - def __init__(self, - unit: Text, - child: Optional[metrics.Metric] = None, - confidence: Optional[float] = None, - enable_optimization=True, - name_tmpl: Text = '{} Jackknife', - **kwargs): + def __init__( + self, + unit: str, + child: metrics.Metric | None = None, + confidence: float | None = None, + enable_optimization: bool = True, + name_tmpl: str = '{} Jackknife', + **kwargs, + ) -> None: super(Jackknife, self).__init__( unit, child, @@ -2655,7 +2929,10 @@ def __init__(self, **kwargs, ) - def compute_slices(self, df, split_by=None): + def compute_slices( + self, df: utils.ReturnType, split_by: list[str] | None = None + ) -> utils.ReturnType: + assert isinstance(df, pd.DataFrame) """Computes Jackknife with precomputation when possible. For Sum, Count, it's possible to compute the LOO estimates in a vectorized @@ -2689,6 +2966,7 @@ def compute_slices(self, df, split_by=None): precomputed.update( self.find_all_in_cache_by_metric_type(metric=metrics.Count) ) + assert self.cache_key is not None precomputed = { k: v for k, v in precomputed.items() if k.key == self.cache_key.key } @@ -2704,8 +2982,13 @@ def compute_slices(self, df, split_by=None): ) return super(Jackknife, self).compute_slices(df, split_by) - def precompute_sum_or_count_for_jackknife(self, cache_key, each_bucket, - original_split_by, df): + def precompute_sum_or_count_for_jackknife( + self, + cache_key: Any, + each_bucket: pd.DataFrame, + original_split_by: list[str] | None, + df: pd.DataFrame, + ) -> None: """Caches point estimate and leave-one-out (LOO) results for Sum and Count. For Sum, Count, it's possible to compute the LOO estimates in a vectorized @@ -2726,7 +3009,7 @@ def precompute_sum_or_count_for_jackknife(self, cache_key, each_bucket, None. Two additional results are saved to the cache. 1. The total sum/count, which is each_bucket summed over self.unit. 2. The LOO estimates, which is saved under key - ('_RESERVED', 'Jackknife', self.unit). + ('_RESERVED', Jackknife, self.unit). """ if not cache_key.split_by: return @@ -2742,7 +3025,7 @@ def precompute_sum_or_count_for_jackknife(self, cache_key, each_bucket, key = cache_key.replace_split_by(split_by) self.save_to_cache(key, total) - key = cache_key.replace_key(('_RESERVED', 'Jackknife', self.unit)) + key = cache_key.replace_key(('_RESERVED', Jackknife, self.unit)) if not self.in_cache(key): each_bucket = utils.adjust_slices_for_loo(each_bucket, original_split_by, df) @@ -2755,11 +3038,11 @@ def precompute_sum_or_count_for_jackknife(self, cache_key, each_bucket, def compute_children( self, df: pd.DataFrame, - split_by=None, - melted=False, - return_dataframe=True, - cache_key=None, - ): + split_by: list[str] | None = None, + melted: bool = False, + return_dataframe: bool = True, + cache_key: Any = None, + ) -> Any: if not self.can_precompute(): return super(Jackknife, self).compute_children( df, split_by, melted, return_dataframe, cache_key @@ -2768,11 +3051,16 @@ def compute_children( df, split_by + [self.unit], True, - cache_key=('_RESERVED', 'Jackknife', self.unit), + cache_key=('_RESERVED', Jackknife, self.unit), ) return [replicates.unstack(self.unit)] - def get_samples(self, df, split_by=None, return_cache_key=False): + def get_samples( + self, + df: pd.DataFrame, + split_by: list[str] | None = None, + return_cache_key: bool = False, + ) -> Iterable[tuple[Any, pd.DataFrame | None]]: """Yields leave-one-out (LOO) DataFrame with level value. If self.can_precompute(), this function will not get triggered because the @@ -2784,7 +3072,7 @@ def get_samples(self, df, split_by=None, return_cache_key=False): return_cache_key: If to return a cache key. Yields: - ('_RESERVED', 'Jackknife', unit, i) if return_cache_key else None, and the + ('_RESERVED', Jackknife, unit, i) if return_cache_key else None, and the leave-i-out DataFrame. """ levels = df[self.unit].unique() @@ -2793,7 +3081,7 @@ def get_samples(self, df, split_by=None, return_cache_key=False): if not split_by: for lvl in levels: - key = ('_RESERVED', 'Jackknife', self.unit, lvl) + key = ('_RESERVED', Jackknife, self.unit, lvl) yield key if return_cache_key else None, df[df[self.unit] != lvl] else: df = df.set_index(split_by) @@ -2804,17 +3092,24 @@ def get_samples(self, df, split_by=None, return_cache_key=False): if len(unique_slice_val) != max_slices: # Keep only the slices that appeared in the dropped bucket. df_rest = df_rest[df_rest.index.isin(unique_slice_val)] - key = ('_RESERVED', 'Jackknife', self.unit, lvl) + key = ('_RESERVED', Jackknife, self.unit, lvl) yield key if return_cache_key else None, df_rest.reset_index() @staticmethod - def get_stderrs(bucket_estimates): + def get_stderrs( + bucket_estimates: pd.DataFrame, + ) -> tuple[pd.Series, pd.Series]: stderrs, dof = super(Jackknife, Jackknife).get_stderrs(bucket_estimates) return stderrs * dof / np.sqrt(dof + 1), dof def compute_children_sql( - self, table, split_by, execute, mode=None, batch_size=None - ): + self, + table: utils.TableType, + split_by: list[str] | None, + execute: Callable[[str], pd.DataFrame] | None, + mode: str | None = None, + batch_size: int | None = None, + ) -> Any: """Compute the children on leave-one-out data in SQL. When @@ -2913,7 +3208,7 @@ def compute_children_sql( if pd.api.types.is_numeric_dtype(slice_and_units[self.unit]): loo_where = '%s != %s' % (self.unit, unit) loo_sql.where = sql.Filters(where).add(loo_where) - key = ('_RESERVED', 'Jackknife', self.unit, unit) + key = ('_RESERVED', Jackknife, self.unit, unit) loo = self.compute_child_sql(loo_sql, split_by, execute, False, mode, key) # If a slice doesn't have the unit in the input data, we should exclude @@ -2934,7 +3229,7 @@ def compute_children_sql( ).join(table, on='meterstick_resample_idx != %s' % self.unit), self.where_, ) - key = ('_RESERVED', 'Jackknife', self.unit, tuple(units)) + key = ('_RESERVED', Jackknife, self.unit, tuple(units)) loo = self.compute_child_sql( loo, split_by + ['meterstick_resample_idx'], @@ -2981,6 +3276,8 @@ def can_precompute(self): def is_metric_precomputable(metric: MetricWithCI) -> bool: """If metric is precomputable in Jackknife or Bootstrap.""" for m in metric.traverse(include_self=False): + if isinstance(m, utils.NumberTypes): + continue if isinstance(m, Operation) and not m.precomputable_in_jk_bs: return False if isinstance(m, metrics.Count) and m.distinct: @@ -3030,15 +3327,15 @@ class Bootstrap(MetricWithCI): def __init__( self, - unit: Optional[Text] = None, - child: Optional[metrics.Metric] = None, + unit: str | None = None, + child: metrics.Metric | None = None, n_replicates: int = 10000, - confidence: Optional[float] = None, - enable_optimization=True, - name_tmpl: Text = '{} Bootstrap', + confidence: float | None = None, + enable_optimization: bool = True, + name_tmpl: str = '{} Bootstrap', ci_method: Literal['std', 'percentile'] = 'std', **kwargs, - ): + ) -> None: if ci_method not in ('std', 'percentile'): raise ValueError('ci_method must be either "std" or "percentile"') if ci_method == 'percentile' and not confidence: @@ -3056,7 +3353,12 @@ def __init__( self.n_replicates = n_replicates self.ci_method = ci_method - def compute_on_sql_sql_mode(self, table, split_by=None, execute=None): + def compute_on_sql_sql_mode( + self, + table: Any, + split_by: list[str] | None = None, + execute: Callable[[str], pd.DataFrame] | None = None, + ) -> Any: """Computes self in a SQL query and processes the result. It behaves identically to MetricWithCI.compute_on_sql_sql_mode when @@ -3153,7 +3455,7 @@ def select_percentiles(self) -> dict[str, float]: def compute_on_children( self, children: list[pd.DataFrame], - split_by: Text | Optional[List[Text]] | None = None, + split_by: utils.StrOrList | None = None, ) -> pd.DataFrame: """Computes stderrs or percentiles on the bootstrap replicates.""" if self.ci_method == 'std': @@ -3176,7 +3478,10 @@ def compute_ci(self, res: pd.DataFrame) -> pd.DataFrame: return res return super(Bootstrap, self).compute_ci(res) - def compute_slices(self, df, split_by=None): + def compute_slices( + self, df: utils.ReturnType, split_by: list[str] | None = None + ) -> utils.ReturnType: + assert isinstance(df, pd.DataFrame) """Computes Bootstrap with unit with precomputation when possible. For Bootstrap with unit, if all leafs can be expressed as Sum or Count, we @@ -3224,7 +3529,9 @@ def compute_slices(self, df, split_by=None): preagg, preagg_df = get_preaggregated_data(self, df, split_by) return self.compute_util_metric_on(preagg, preagg_df, split_by) - def get_samples(self, df, split_by=None): + def get_samples( + self, df: pd.DataFrame, split_by: list[str] | None = None + ) -> Iterable[tuple[Any, pd.DataFrame | None]]: """Resamples for Bootstrap. When samples are likely to repeat, cache.""" # If there is no extra split_by added, each unit will correspond to one row # in the preaggregated data so we can just sample by rows. @@ -3257,8 +3564,13 @@ def get_samples(self, df, split_by=None): yield None, df.loc[resampled].reset_index() def compute_children_sql( - self, table, split_by, execute, mode=None, batch_size=None - ): + self, + table: utils.TableType, + split_by: list[str] | None, + execute: Callable[[str], pd.DataFrame] | None, + mode: str | None = None, + batch_size: int | None = None, + ) -> Any: """Compute the children on resampled data in SQL. We compute the child on bootstrapped data in a batched way. We bootstrap for @@ -3318,8 +3630,13 @@ def compute_children_sql( return replicates def get_resampled_data_sql( - self, table, split_by, global_filter, indexes, with_data - ): + self, + table: Any, + split_by: list[str] | None, + global_filter: sql.Filters | None, + indexes: sql.Columns, + with_data: sql.Datasources, + ) -> tuple[sql.Datasource, sql.Datasources]: """Gets self.n_replicates bootstrap resamples.""" del indexes # not used if not self.unit: @@ -3332,8 +3649,8 @@ def get_resampled_data_sql( with_data, ) - def can_precompute(self): - return ( + def can_precompute(self) -> bool: + return bool( self.unit and self.enable_optimization and is_metric_precomputable(self) ) @@ -3370,15 +3687,15 @@ class PoissonBootstrap(Bootstrap): def __init__( self, - unit: Optional[Text] = None, - child: Optional[metrics.Metric] = None, + unit: str | None = None, + child: metrics.Metric | None = None, n_replicates: int = 10000, - confidence: Optional[float] = None, - enable_optimization=True, - name_tmpl: Text = '{} Poisson Bootstrap', + confidence: float | None = None, + enable_optimization: bool = True, + name_tmpl: str = '{} Poisson Bootstrap', ci_method: Literal['std', 'percentile'] = 'std', **kwargs, - ): + ) -> None: super(PoissonBootstrap, self).__init__( unit, child, @@ -3390,7 +3707,9 @@ def __init__( **kwargs, ) - def get_samples(self, df, split_by=None): + def get_samples( + self, df: pd.DataFrame, split_by: list[str] | None = None + ) -> Iterable[tuple[Any, pd.DataFrame | None]]: """Resamples for PoissonBootstrap. There are three cases here. @@ -3483,12 +3802,17 @@ def get_samples(self, df, split_by=None): ) yield cache_key, resampled.reset_index() - def get_sample_weight(self, n): + def get_sample_weight(self, n: int) -> np.ndarray: return np.random.poisson(size=n) def get_resampled_data_sql( - self, table, split_by, global_filter, indexes, with_data - ): + self, + table: Any, + split_by: list[str] | None, + global_filter: sql.Filters | None, + indexes: sql.Columns, + with_data: sql.Datasources, + ) -> tuple[sql.Datasource, sql.Datasources]: """Gets self.n_replicates Poisson bootstrap resamples. The function makes three or four subqueries. The first one adds a uniformly @@ -3602,9 +3926,9 @@ def get_resampled_data_sql( sql.RAND_FN(), alias='poisson_bootstrap_uniform_var' ) split_by_cols = ( - split_by.aliases + sql.Columns(split_by).aliases if self.has_been_preaggregated - else split_by.original_columns + else sql.Columns(split_by).original_columns ) if self.unit: cols = ', '.join( @@ -3712,7 +4036,9 @@ def get_resampled_data_sql( return poisson_sampled_table_alias, with_data -def get_preaggregated_data(m, df, split_by): +def get_preaggregated_data( + m: Bootstrap, df: pd.DataFrame, split_by: list[str] | None +) -> tuple[metrics.Metric, pd.DataFrame]: """Gets the preaggegated Metric and data. Read the doc of Bootstrap.compute_slices() first. @@ -3747,12 +4073,13 @@ def get_preaggregated_data(m, df, split_by): filter_in_leaf = utils.push_filters_to_leaf(m) leafs = metrics.MetricList(tuple(set(utils.get_leaf_metrics(filter_in_leaf)))) preagg_df = m.compute_util_metric_on(leafs, df, all_split_by) - preagg_leafs = get_preaggregated_metric_tree(leafs) + preagg_leafs: Any = get_preaggregated_metric_tree(leafs) preagg_df.columns = [c.var for c in preagg_leafs] preagg_df = preagg_df.loc[:, ~preagg_df.columns.duplicated()].copy() preagg_df = preagg_df.reindex(original_idx.index) if all_split_by: preagg_df.reset_index(all_split_by, inplace=True) + assert m.cache_key is not None for l, p in zip(leafs, preagg_leafs): key = utils.CacheKey(l, m.cache_key, l.where_, all_split_by) res = m.get_cached(key) @@ -3762,7 +4089,7 @@ def get_preaggregated_data(m, df, split_by): return preagg, preagg_df -def get_preaggregated_metric_tree(m): +def get_preaggregated_metric_tree(m: metrics.Metric) -> metrics.Metric: """Gets the equivalent Metric of m on the preaggregated data.""" if not isinstance(m, metrics.Metric): return m @@ -3774,7 +4101,7 @@ def get_preaggregated_metric_tree(m): return m -def get_preaggregated_metric(m): +def get_preaggregated_metric(m: metrics.Metric) -> metrics.Metric: """Gets the equivalent metric of on the preaggregated data if m is a leaf.""" var = get_preaggregated_metric_var(m) if isinstance(m, metrics.Max): @@ -3784,7 +4111,7 @@ def get_preaggregated_metric(m): return metrics.Sum(var, name=m.name) -def get_preaggregated_metric_var(m: metrics.Metric): +def get_preaggregated_metric_var(m: metrics.Metric) -> str: """Gets the new column name for leaf metric m in the preaggregated data.""" if not isinstance(m, (metrics.Sum, metrics.Count, metrics.Max, metrics.Min)): raise ValueError( @@ -3802,8 +4129,13 @@ def get_preaggregated_metric_var(m: metrics.Metric): def get_se_sql( - metric, table, split_by, global_filter, indexes, with_data -): + metric: MetricWithCI, + table: Any, + split_by: list[str] | None, + global_filter: sql.Filters | None, + indexes: sql.Columns, + with_data: sql.Datasources, +) -> tuple[sql.Sql, sql.Datasources]: """Gets the SQL query that computes the standard error and dof if needed.""" samples, with_data = metric.children[0].get_sql_and_with_clause( table, @@ -3828,6 +4160,8 @@ def get_se_sql( alias = c.alias ci_method = getattr(metric, 'ci_method', 'std') if ci_method == 'percentile': + if not isinstance(metric, Bootstrap): + raise ValueError('Percentile CI is only supported for Bootstrap.') for k, v in metric.select_percentiles().items(): pct_col = sql.Column(alias, sql.QUANTILE_FN(v), f'{c.alias_raw} {k}') @@ -3852,7 +4186,7 @@ def get_se_sql( return sql.Sql(columns, samples_alias, groupby=groupby), with_data -def adjust_indexes_for_jk_fast(indexes): +def adjust_indexes_for_jk_fast(indexes: sql.Columns) -> sql.Columns: """For the indexes that get renamed, only keep the alias. For a Jackknife that only has Sum and Count as leaf Metrics, we cut the corner @@ -3876,8 +4210,12 @@ def adjust_indexes_for_jk_fast(indexes): def get_jackknife_data_general( - metric, table, split_by, global_filter, with_data -): + metric: Jackknife, + table: Any, + split_by: list[str] | None, + global_filter: sql.Filters | None, + with_data: sql.Datasources, +) -> tuple[sql.Datasource, sql.Datasources]: """Gets jackknife samples. If the leave-one-out estimates can be precomputed, see the doc of @@ -3928,6 +4266,7 @@ def get_jackknife_data_general( (sql.Column(unit, alias='meterstick_resample_idx')), distinct=True ) if split_by: + split_by = sql.Columns(split_by) groupby = sql.Columns( (sql.Column(c.expression, alias='jk_%s' % c.alias) for c in split_by) ) @@ -3963,8 +4302,13 @@ def get_jackknife_data_general( def get_jackknife_data_fast( - metric, table, split_by, global_filter, indexes, with_data -): + metric: Jackknife, + table: Any, + split_by: list[str] | None, + global_filter: sql.Filters | None, + indexes: sql.Columns, + with_data: sql.Datasources, +) -> tuple[sql.Datasource, sql.Datasources]: """Gets jackknife samples in a fast way for precomputable Jackknife. If all the leaf Metrics are Sum and/or Count, we can compute the @@ -4079,12 +4423,12 @@ def get_jackknife_data_fast( def modify_descendants_for_jackknife_fast( - metric, - columns_to_preagg, - columns_in_loo, - global_filter, - needs_adjustment, -): + metric: metrics.Metric, + columns_to_preagg: sql.Columns, + columns_in_loo: sql.Columns, + global_filter: sql.Filters | None, + needs_adjustment: bool, +) -> metrics.Metric: """Gets the columns for leaf Metrics and modify them for fast Jackknife SQL. See the doc of get_jackknife_data_fast() first. Here we @@ -4158,8 +4502,13 @@ def modify_descendants_for_jackknife_fast( def bootstrap_by_row( - metric, table, split_by, global_filter, with_data, columns_in_table=None -): + metric: Bootstrap, + table: Any, + split_by: list[str] | None, + global_filter: sql.Filters | None, + with_data: sql.Datasources, + columns_in_table: list[str] | None = None, +) -> tuple[sql.Datasource, sql.Datasources]: """Gets metric.n_replicates bootstrap resamples for Bootstrap without unit. The SQL is constructed as @@ -4209,6 +4558,7 @@ def bootstrap_by_row( The alias of the table in the WITH clause that has all resampled data. The global with_data which holds all datasources we need in the WITH clause. """ + split_by = sql.Columns(split_by) columns = sql.Columns(['meterstick_resample_idx']) if columns_in_table: columns = columns.add(columns_in_table) @@ -4274,7 +4624,13 @@ def bootstrap_by_row( return table, with_data -def bootstrap_by_unit(metric, table, split_by, global_filter, with_data): +def bootstrap_by_unit( + metric: Bootstrap, + table: Any, + split_by: list[str] | None, + global_filter: sql.Filters | None, + with_data: sql.Datasources, +) -> tuple[sql.Datasource, sql.Datasources]: """Gets metric.n_replicates bootstrap resamples. The SQL is constructed as @@ -4305,6 +4661,7 @@ def bootstrap_by_unit(metric, table, split_by, global_filter, with_data): The alias of the table in the WITH clause that has all resampled data. The global with_data which holds all datasources we need in the WITH clause. """ + split_by = sql.Columns(split_by) columns = sql.Columns(split_by).add(metric.unit) units = sql.Sql(columns, table, global_filter, columns) units_alias = with_data.add(sql.Datasource(units, 'Candidates')) @@ -4337,7 +4694,9 @@ def bootstrap_by_unit(metric, table, split_by, global_filter, with_data): return table, with_data -def copy_meterstick_metadata(original_df, new_df): +def copy_meterstick_metadata( + original_df: pd.DataFrame, new_df: pd.DataFrame +) -> pd.DataFrame: """Copies meterstick metadata attributes from original_df to new_df. The metadata attributes include `display` and `meterstick_change_base`, @@ -4369,19 +4728,43 @@ class MetricFunction(Operation): name_tmpl: The template to generate the name from child Metric's name. """ - def __init__(self, child, func, sql_func, name_tmpl, **kwargs): + def __init__( + self, + child: metrics.Metric, + func: Callable[..., Any], + sql_func: str | None, + name_tmpl: str, + **kwargs, + ) -> None: super().__init__(child, name_tmpl, **kwargs) self.func = func self.sql_func = sql_func - def compute_on_children(self, children, split_by): + def compute_on_children( + self, children: pd.DataFrame, split_by: list[str] | None + ) -> pd.DataFrame: new_df = self.func(children) new_df = copy_meterstick_metadata(children, new_df) return new_df def get_sql_and_with_clause( - self, table, split_by, global_filter, indexes, local_filter, with_data - ): + self, + table: sql.Datasource, + split_by: sql.Columns | None, + global_filter: sql.Filters | None, + indexes: sql.Columns | None, + local_filter: sql.Filters | None, + with_data: sql.Datasources | None, + ) -> tuple[sql.Sql, sql.Datasources]: + if indexes is None: + indexes = sql.Columns() + if global_filter is None: + global_filter = sql.Filters() + if local_filter is None: + local_filter = sql.Filters() + if with_data is None: + with_data = sql.Datasources() + if not self.sql_func: raise NotImplementedError( f'SQL generation not supported for {type(self)}.' @@ -4404,13 +4787,24 @@ def get_sql_and_with_clause( return child_sql, with_data def manipulate( - self, res, melted=False, return_dataframe=True, apply_name_tmpl=None - ): + self, + res: Any, + melted: bool = False, + return_dataframe: bool = True, + apply_name_tmpl: bool | None = None, + ) -> Any: new_res = super().manipulate(res, melted, return_dataframe, apply_name_tmpl) new_res = copy_meterstick_metadata(res, new_res) return new_res - def final_compute(self, res, melted, return_dataframe, split_by, df): + def final_compute( + self, + res: Any, + melted: bool, + return_dataframe: bool, + split_by: Sequence[str] | None, + df: Any, + ) -> Any: new_res = super().final_compute(res, melted, return_dataframe, split_by, df) new_res = copy_meterstick_metadata(res, new_res) return new_res @@ -4427,7 +4821,13 @@ class LogTransform(MetricFunction): name_tmpl: The template to generate the name from child Metric's name. """ - def __init__(self, child=None, base: str = 'ln', name_tmpl=None, **kwargs): + def __init__( + self, + child: metrics.Metric | None = None, + base: str = 'ln', + name_tmpl: str | None = None, + **kwargs, + ) -> None: if base not in ('ln', 'log10'): raise ValueError("base must be 'ln' or 'log10'") self.base = base @@ -4448,7 +4848,12 @@ def __init__(self, child=None, base: str = 'ln', name_tmpl=None, **kwargs): class ExponentialTransform(MetricFunction): """Base class for applying exponential transformations to Metric.""" - def __init__(self, child=None, name_tmpl='Exp({})', **kwargs): + def __init__( + self, + child: metrics.Metric | None = None, + name_tmpl: str = 'Exp({})', + **kwargs, + ) -> None: sql_func = 'EXP({})' super().__init__( child, @@ -4491,7 +4896,13 @@ class ExponentialPercentTransform(MetricFunction): name_tmpl: The template to generate the name from child Metric's name. """ - def __init__(self, child=None, base: str = 'ln', name_tmpl=None, **kwargs): + def __init__( + self, + child: metrics.Metric | None = None, + base: str = 'ln', + name_tmpl: str | None = None, + **kwargs, + ) -> None: """Initializes an ExponentialPercentTransform. Args: @@ -4523,7 +4934,7 @@ def __init__(self, child=None, base: str = 'ln', name_tmpl=None, **kwargs): self._check_and_update_for_log_transformed_abs_change() ) - def _check_and_update_for_log_transformed_abs_change(self): + def _check_and_update_for_log_transformed_abs_change(self) -> bool: """Checks if child is MetricWithCI(AbsoluteChange(LogTransform(...))). If the child structure matches, it means we are calculating percentage @@ -4571,16 +4982,20 @@ def _check_and_update_for_log_transformed_abs_change(self): self.children = tuple([ci_method(ab(log_transform))]) return True - def __call__(self, *args, **kwargs): + def __call__(self, *args, **kwargs) -> Any: res = super().__call__(*args, **kwargs) res._has_log_transformed_abs_change = ( - res._check_and_update_for_log_transformed_abs_change() + res._check_and_update_for_log_transformed_abs_change() # type: ignore ) return res def manipulate( - self, res, melted=False, return_dataframe=True, apply_name_tmpl=None - ): + self, + res: Any, + melted: bool = False, + return_dataframe: bool = True, + apply_name_tmpl: bool | None = None, + ) -> Any: """Transforms base value back to original scale if needed.""" new_res = super().manipulate(res, melted, return_dataframe, apply_name_tmpl) if self._has_log_transformed_abs_change and hasattr( @@ -4615,16 +5030,16 @@ class LogTransformedPercentChangeWithCI(Operation): def __init__( self, - condition_column: Text, - baseline_key, - unit: Text, + condition_column: str, + baseline_key: Any, + unit: str, confidence: float, - child: Optional[metrics.Metric] = None, - ci_method_type: Type[MetricWithCI] = Jackknife, + child: metrics.Metric | None = None, + ci_method_type: type[MetricWithCI] = Jackknife, name_tmpl: str = '{}', include_base: bool = False, **kwargs, - ): + ) -> None: """Initializes a LogTransformedPercentChangeWithCI. Args: @@ -4649,11 +5064,11 @@ def __init__( additional_fingerprint_attrs=['change', 'ci_method'], ) - def __call__(self, child: metrics.Metric): + def __call__(self, child: metrics.Metric) -> Any: op = super().__call__(child) return op - def _get_equiv(self, child): + def _get_equiv(self, child: metrics.Metric) -> metrics.Metric: return ( LogTransform(child) | self.change @@ -4661,25 +5076,45 @@ def _get_equiv(self, child): | ExponentialPercentTransform() ) - def compute_slices(self, df, split_by=None): + def compute_slices( + self, df: utils.ReturnType, split_by: list[str] | None = None + ) -> utils.ReturnType: + assert isinstance(df, pd.DataFrame) """Computes CI on log-scale and transform back to percent change.""" equiv = self._get_equiv(self.children[0]) return self.compute_util_metric_on(equiv, df, split_by) - def compute_through_sql(self, table, split_by, execute, mode): + def compute_through_sql( + self, + table: Any, + split_by: list[str] | None, + execute: Callable[[str], pd.DataFrame], + mode: str | None, + ) -> Any: equiv = self._get_equiv(self.children[0]) return self.compute_util_metric_on_sql( - equiv, table, split_by, execute, mode + equiv, table, split_by, execute, mode=mode ) def manipulate( - self, res, melted=False, return_dataframe=True, apply_name_tmpl=None - ): + self, + res: Any, + melted: bool = False, + return_dataframe: bool = True, + apply_name_tmpl: bool | None = None, + ) -> Any: new_res = super().manipulate(res, melted, return_dataframe, apply_name_tmpl) new_res = copy_meterstick_metadata(res, new_res) return new_res - def final_compute(self, res, melted, return_dataframe, split_by, df): + def final_compute( + self, + res: Any, + melted: bool, + return_dataframe: bool, + split_by: Sequence[str] | None, + df: Any, + ) -> Any: new_res = super().final_compute(res, melted, return_dataframe, split_by, df) new_res = copy_meterstick_metadata(res, new_res) return new_res diff --git a/sql.py b/sql.py index 80a155d..a189fe0 100644 --- a/sql.py +++ b/sql.py @@ -12,8 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. """Module to generate SQL scripts for Metrics.""" - from __future__ import absolute_import +from __future__ import annotations from __future__ import division from __future__ import print_function @@ -22,7 +22,7 @@ import copy import functools import re -from typing import Any, Iterable, List, Optional, Text, Union +from typing import Any DEFAULT_DIALECT = 'GoogleSQL' @@ -63,15 +63,15 @@ COVAR_SAMP_FN = None -def drop_table_if_exists(alias: str): +def drop_table_if_exists(alias: str) -> str: return f'DROP TABLE IF EXISTS {alias};' -def drop_temp_table_if_exists(alias: str): +def drop_temp_table_if_exists(alias: str) -> str: return f'DROP TEMPORARY TABLE IF EXISTS {alias};' -def drop_table_if_exists_then_create_temp_table(alias: str, query: str): +def drop_table_if_exists_then_create_temp_table(alias: str, query: str) -> str: """Drops a table if it exists then creates a temporary table.""" return ( drop_table_if_exists(alias) @@ -79,7 +79,9 @@ def drop_table_if_exists_then_create_temp_table(alias: str, query: str): ) -def drop_temp_table_if_exists_then_create_temp_table(alias: str, query: str): +def drop_temp_table_if_exists_then_create_temp_table( + alias: str, query: str +) -> str: """Drops a table if it exists then creates a temporary table.""" return ( drop_temp_table_if_exists(alias) @@ -87,12 +89,12 @@ def drop_temp_table_if_exists_then_create_temp_table(alias: str, query: str): ) -def create_temp_table_fn_not_implemented(alias: str, query: str): +def create_temp_table_fn_not_implemented(alias: str, query: str) -> Any: del alias, query # Unused raise NotImplementedError('CREATE TEMP TABLE is not implemented.') -def sql_server_rand_fn_not_implemented(): +def sql_server_rand_fn_not_implemented() -> Any: raise NotImplementedError( "SQL Server's RAND() without a seed parameter will return the same value" " for every row within the same SELECT statement, which doesn't work" @@ -100,37 +102,37 @@ def sql_server_rand_fn_not_implemented(): ) -def safe_divide_fn_default(numer: str, denom: str): +def safe_divide_fn_default(numer: str, denom: str) -> str: return ( f'CASE WHEN {{denom}} = 0 THEN NULL ELSE {FLOAT_CAST_FN("{numer}")} /' f' {FLOAT_CAST_FN("{denom}")} END'.format(numer=numer, denom=denom) ) -def approx_quantiles_fn(percentile): +def approx_quantiles_fn(percentile: float) -> str: p = int(100 * percentile) return f'APPROX_QUANTILES({{}}, 100)[SAFE_OFFSET({p})]' -def percentile_cont_fn(percentile): +def percentile_cont_fn(percentile: float) -> str: return f'PERCENTILE_CONT({percentile}) WITHIN GROUP (ORDER BY {{}})' -def approx_percentile_fn(percentile): +def approx_percentile_fn(percentile: float) -> str: return f'APPROX_PERCENTILE({{}}, {percentile})' -def quantile_fn_not_implemented(percentile): +def quantile_fn_not_implemented(percentile: float) -> Any: del percentile # Unused raise NotImplementedError('Quantile is not implemented.') def array_agg_fn_googlesql( - sort_by: Optional[str], - ascending: Optional[bool], - dropna: Optional[bool], - limit: Optional[int], -): + sort_by: str | None, + ascending: bool | None, + dropna: bool | None, + limit: int | None, +) -> str: """Uses GoogleSQL's ARRAY_AGG to aggregate arrays.""" dropna = ' IGNORE NULLS' if dropna else '' order_by = f' ORDER BY {sort_by}' if sort_by else '' @@ -141,11 +143,11 @@ def array_agg_fn_googlesql( def array_agg_fn_no_use_filter_no_limit( - sort_by: Optional[str], - ascending: Optional[bool], - dropna: Optional[bool], - limit: Optional[int], -): + sort_by: str | None, + ascending: bool | None, + dropna: bool | None, + limit: int | None, +) -> str: """Uses ARRAY_AGG to aggregate arrays. Use FILTER to filter out NULLs.""" del limit # LIMIT is not supported in PostgreSQL so just skip. dropna = ' FILTER (WHERE {} IS NOT NULL)' if dropna else '' @@ -156,11 +158,11 @@ def array_agg_fn_no_use_filter_no_limit( def json_array_agg_fn( - sort_by: Optional[str], - ascending: Optional[bool], - dropna: Optional[bool], - limit: Optional[int], -): + sort_by: str | None, + ascending: bool | None, + dropna: bool | None, + limit: int | None, +) -> str: """Uses JSON_ARRAYAGG to aggregate arrays.""" del limit # LIMIT is not supported in PostgreSQL so just skip. if not dropna: @@ -172,47 +174,47 @@ def json_array_agg_fn( def array_agg_fn_not_implemented( - sort_by: Optional[str], - ascending: Optional[bool], - dropna: Optional[bool], - limit: Optional[int], -): + sort_by: str | None, + ascending: bool | None, + dropna: bool | None, + limit: int | None, +) -> str: del sort_by, ascending, dropna, limit # Unused raise NotImplementedError('ARRAY_AGG is not implemented.') -def array_index_safe_offset_fn(array: str, zero_based_idx: int): +def array_index_safe_offset_fn(array: str, zero_based_idx: int) -> str: return f'{array}[SAFE_OFFSET({zero_based_idx})]' -def array_subscript_fn(array: str, zero_based_idx: int): +def array_subscript_fn(array: str, zero_based_idx: int) -> str: return f'({array})[{zero_based_idx + 1}]' -def element_at_index_fn(array: str, zero_based_idx: int): +def element_at_index_fn(array: str, zero_based_idx: int) -> str: return f'element_at({array}, {zero_based_idx + 1})' -def json_extract_fn(array: str, zero_based_idx: int): +def json_extract_fn(array: str, zero_based_idx: int) -> str: return f"JSON_EXTRACT({array}, '$[{zero_based_idx}]')" -def json_value_fn(array: str, zero_based_idx: int): +def json_value_fn(array: str, zero_based_idx: int) -> str: return f"JSON_VALUE({array}, '$[{zero_based_idx}]')" -def array_index_fn_not_implemented(array: str, zero_based_idx: int): +def array_index_fn_not_implemented(array: str, zero_based_idx: int) -> str: del array, zero_based_idx # Unused raise NotImplementedError('ARRAY_INDEX is not implemented.') def nth_fn_default( zero_based_idx: int, - sort_by: Optional[str], - ascending: Optional[bool], - dropna: Optional[bool], - limit: Optional[int], -): + sort_by: str | None, + ascending: bool | None, + dropna: bool | None, + limit: int | None, +) -> str: try: array = ARRAY_AGG_FN(sort_by, ascending, dropna, limit) return ARRAY_INDEX_FN(array, zero_based_idx) @@ -220,16 +222,16 @@ def nth_fn_default( raise NotImplementedError('Nth value is not implemented.') from e -def uniform_mapping_fn_not_implemented(_): +def uniform_mapping_fn_not_implemented(_: Any) -> Any: raise NotImplementedError('Uniform mapping is not implemented.') def unnest_array_with_offset_fn( array: str, - alias: Optional[str] = None, - offset: Optional[int] = None, - limit: Optional[int] = None, -): + alias: str | None = None, + offset: int | None = None, + limit: int | None = None, +) -> str: """Unnests an array in GoogleSQL.""" if alias is None: return f'UNNEST({array})' @@ -241,10 +243,10 @@ def unnest_array_with_offset_fn( def unnest_array_with_ordinality_fn( array: str, - alias: Optional[str] = None, - offset: Optional[int] = None, - limit: Optional[int] = None, -): + alias: str | None = None, + offset: int | None = None, + limit: int | None = None, +) -> str: """Unnests an array in PostgreSQL.""" if alias is None: return f'UNNEST({array})' @@ -258,10 +260,10 @@ def unnest_array_with_ordinality_fn( def unnest_json_array_fn( array: str, - alias: Optional[str] = None, - offset: Optional[int] = None, - limit: Optional[int] = None, -): + alias: str | None = None, + offset: int | None = None, + limit: int | None = None, +) -> str: """Unnests a JSON_ARRAY in Oracle SQL.""" where = f' WHERE {offset} < {limit + 1}' if limit else '' return f'''JSON_TABLE({array}, '$[*]' @@ -274,38 +276,42 @@ def unnest_json_array_fn( def unnest_array_fn_not_implemented( array: str, - alias: Optional[str] = None, - offset: Optional[int] = None, - limit: Optional[int] = None, -): + alias: str | None = None, + offset: int | None = None, + limit: int | None = None, +) -> str: del array, alias, offset, limit # Unused raise NotImplementedError('UNNEST is not implemented.') -def unnest_array_literal_fn_googlesql(array: List[Any], alias: str = ''): +def unnest_array_literal_fn_googlesql(array: list[Any], alias: str = '') -> str: return f'UNNEST({array}) {alias}'.strip() -def unnest_array_literal_fn_postgresql(array: List[Any], alias: str = ''): +def unnest_array_literal_fn_postgresql( + array: list[Any], alias: str = '' +) -> str: return f'UNNEST(ARRAY{array}) {alias}'.strip() -def unnest_array_literal_fn_not_implemented(array, alias=''): +def unnest_array_literal_fn_not_implemented( + array: list[Any], alias: str = '' +) -> str: del array, alias # Unused raise NotImplementedError('UNNEST with literal array is not implemented.') -def generate_array_fn(n): +def generate_array_fn(n: int) -> str: """Generates an array of n elements using GENERATE_ARRAY.""" return f'GENERATE_ARRAY(1, {n})' -def generate_series_fn(n): +def generate_series_fn(n: int) -> str: """Generates an array of n elements using GENERATE_SERIES.""" return f'GENERATE_SERIES(1, {n})' -def generate_sequence_fn_mariadb(n): +def generate_sequence_fn_mariadb(n: int) -> str: """Generates an array of n elements using sequence in MariaDB.""" try: n = int(n) @@ -321,7 +327,7 @@ def generate_sequence_fn_mariadb(n): ) from e -def generate_array_fn_oracle(n, alias: str = '_'): +def generate_array_fn_oracle(n: int, alias: str = '_') -> str: """Generates an array of n elements using sequence in Oracle.""" try: return f'SELECT LEVEL AS {alias} FROM DUAL CONNECT BY LEVEL <= {int(n)}' @@ -331,75 +337,83 @@ def generate_array_fn_oracle(n, alias: str = '_'): ) from e -def generate_sequence_fn_trino(n): +def generate_sequence_fn_trino(n: int) -> str: """Generates an array of n elements using sequence in Trino.""" return f'SEQUENCE(1, {n})' -def generate_array_fn_not_implemented(n): +def generate_array_fn_not_implemented(n: int) -> Any: del n # Unused raise NotImplementedError( 'GENERATE_ARRAY/GENERATE_SERIES is not implemented.' ) -def unnest_generated_array(n, alias: Optional[str] = None): +def unnest_generated_array(n: int, alias: str | None = None) -> str: """Unnest a generated array, used to duplicate data.""" return UNNEST_ARRAY_FN(GENERATE_ARRAY_FN(n), alias) -def implicitly_unnest_generated_array(n, alias: Optional[str] = None): +def implicitly_unnest_generated_array( + n: int, alias: str | None = None +) -> str: """Unnest a generated series, used to duplicate data.""" if not alias: return GENERATE_ARRAY_FN(n) return f'{GENERATE_ARRAY_FN(n)} {alias}' -def implicitly_unnest_generated_sequence(n, alias: Optional[str] = None): +def implicitly_unnest_generated_sequence( + n: int, alias: str | None = None +) -> str: """Unnest a generated series, used to duplicate data.""" if not alias: return GENERATE_ARRAY_FN(n) return f'(SELECT seq AS {alias} FROM {GENERATE_ARRAY_FN(n)}) unnested' -def duplicate_data_n_times_oracle(n, alias: Optional[str] = None): +def duplicate_data_n_times_oracle( + n: int, alias: str | None = None +) -> str: if not alias: return generate_array_fn_oracle(n) return generate_array_fn_oracle(n, alias) -def duplicate_data_n_times_not_implemented(n, alias: Optional[str] = None): +def duplicate_data_n_times_not_implemented( + n: int, alias: str | None = None +) -> Any: del n, alias # Unused raise NotImplementedError( 'Duplicate data n times is not implemented.' ) -def stddev_pop_not_implemented(): +def stddev_pop_not_implemented() -> Any: raise NotImplementedError('STDDEV_POP is not implemented.') -def stddev_samp_not_implemented(): +def stddev_samp_not_implemented() -> Any: raise NotImplementedError('STDDEV_SAMP is not implemented.') -def variance_pop_not_implemented(): +def variance_pop_not_implemented() -> Any: raise NotImplementedError('VARIANCE_POP is not implemented.') -def variance_samp_not_implemented(): +def variance_samp_not_implemented() -> Any: raise NotImplementedError('VARIANCE_SAMP is not implemented.') -def corr_not_implemented(): +def corr_not_implemented() -> Any: raise NotImplementedError('CORR is not implemented.') -def covar_pop_not_implemented(): +def covar_pop_not_implemented() -> Any: raise NotImplementedError('COVAR_POP is not implemented.') -def covar_samp_not_implemented(): +def covar_samp_not_implemented() -> Any: raise NotImplementedError('COVAR_SAMP is not implemented.') @@ -608,7 +622,7 @@ def covar_samp_not_implemented(): } -def set_dialect(dialect: Optional[str]): +def set_dialect(dialect: str | None) -> None: """Sets the dialect of the SQL query.""" # You can manually override the options below. You can manually test it in # https://colab.research.google.com/drive/1y3UigzEby1anMM3-vXocBx7V8LVblIAp?usp=sharing. @@ -660,14 +674,14 @@ def set_dialect(dialect: Optional[str]): COVAR_SAMP_FN = _get_dialect_option(COVAR_SAMP_OPTIONS) -def _get_dialect_option(options: dict[str, Any]): +def _get_dialect_option(options: dict[str, Any]) -> Any: return options.get(DIALECT, options['Default']) set_dialect(DEFAULT_DIALECT) -def is_compatible(sql0, sql1): +def is_compatible(sql0: 'Sql', sql1: 'Sql') -> bool: """Checks if two datasources are compatible so their columns can be merged. Being compatible means datasources @@ -695,7 +709,7 @@ def is_compatible(sql0, sql1): ) -def add_suffix(alias): +def add_suffix(alias: str) -> str: """Adds an int suffix to alias.""" alias = alias.strip('`') m = re.search(r'([0-9]+)$', alias) @@ -707,7 +721,7 @@ def add_suffix(alias): return alias + '_1' -def rand_run_only_once_in_with_clause(execute): +def rand_run_only_once_in_with_clause(execute: abc.Callable[..., Any]) -> bool: """Check if the RAND() is only evaluated once in the WITH clause.""" d = execute( f'''WITH T AS (SELECT {RAND_FN()} AS r) @@ -717,7 +731,7 @@ def rand_run_only_once_in_with_clause(execute): return bool(d.iloc[0, 0] == 0) -def dep_on_rand_table(query, rand_tables): +def dep_on_rand_table(query: Any, rand_tables: abc.Iterable[str]) -> bool: """Returns if a SQL query depends on any stochastic table in rand_tables.""" for rand_table in rand_tables: if re.search(r'\b%s\b' % rand_table, str(query)): @@ -725,7 +739,7 @@ def dep_on_rand_table(query, rand_tables): return False -def get_temp_tables(with_data: 'Datasources'): +def get_temp_tables(with_data: 'Datasources') -> set[str]: """Gets all the subquery tables that need to be materialized. When generating the query, we assume that volatile functions like RAND() in @@ -781,11 +795,11 @@ def get_temp_tables(with_data: 'Datasources'): return tmp_tables -def get_alias(c): +def get_alias(c: Any) -> str: return getattr(c, 'alias_raw', c) -def escape_alias(alias): +def escape_alias(alias: str) -> str: """Replaces special characters in SQL column name alias.""" special = set(r""" `~!@#$%^&*()-=+[]{}\|;:'",.<>/?""") if not alias or not special.intersection(alias): @@ -820,46 +834,46 @@ def escape_alias(alias): class SqlComponent: """Base class for a SQL component like column, tabel and filter.""" - def __eq__(self, other): + def __eq__(self, other: Any) -> bool: return str(self) == str(other) - def __lt__(self, other): + def __lt__(self, other: Any) -> bool: return str(self) < other - def __repr__(self): + def __repr__(self) -> str: return str(self) - def __hash__(self): + def __hash__(self) -> int: return hash(str(self)) - def __bool__(self): + def __bool__(self) -> bool: return bool(str(self)) - def __nonzero__(self): + def __nonzero__(self) -> bool: return bool(str(self)) - def __add__(self, other): + def __add__(self, other: str) -> str: return str.__add__(str(self), other) - def __mul__(self, other): + def __mul__(self, other: int) -> str: return str.__mul__(str(self), other) - def __rmul__(self, other): + def __rmul__(self, other: int) -> str: return str.__rmul__(str(self), other) - def __getitem__(self, idx): + def __getitem__(self, idx: Any) -> str: return str(self)[idx] class SqlComponents(SqlComponent): """Base class for a bunch of SQL components like columns and filters.""" - def __init__(self, children=None): + def __init__(self, children: Any | None = None) -> None: super(SqlComponents, self).__init__() self.children = [] self.add(children) - def add(self, children): + def add(self, children: Any) -> 'SqlComponents': if not isinstance(children, str) and isinstance(children, abc.Iterable): for c in list(children): self.add(c) @@ -868,24 +882,24 @@ def add(self, children): self.children.append(children) return self - def __iter__(self): + def __iter__(self) -> abc.Iterable[Any]: for c in self.children: yield c - def __len__(self): + def __len__(self) -> int: return len(self.children) - def __getitem__(self, key): + def __getitem__(self, key: Any) -> Any: return self.children[key] - def __setitem__(self, key, value): + def __setitem__(self, key: Any, value: Any) -> None: self.children[key] = value class Filter(SqlComponent): """Represents single condition in SQL WHERE clause.""" - def __init__(self, cond: Optional[Text]): + def __init__(self, cond: str | None) -> None: super(Filter, self).__init__() self.cond = '' if isinstance(cond, Filter): @@ -893,7 +907,7 @@ def __init__(self, cond: Optional[Text]): elif cond: self.cond = cond.replace('==', '=') or '' - def __str__(self): + def __str__(self) -> str: if not self.cond: return '' return '(%s)' % self.cond if ' OR ' in self.cond.upper() else self.cond @@ -903,16 +917,16 @@ class Filters(SqlComponents): """Represents a bunch of SQL conditions.""" @property - def where(self): + def where(self) -> list[str]: return sorted((str(Filter(f)) for f in self.children)) - def remove(self, filters): + def remove(self, filters: Any) -> 'Filters': if not filters: return self self.children = [c for c in self.children if c not in Filters(filters)] return self - def __str__(self): + def __str__(self) -> str: return ' AND '.join(self.where) @@ -955,14 +969,14 @@ class Column(SqlComponent): def __init__( self, column, - fn: Text = '{}', - alias: Optional[Text] = None, + fn: str = '{}', + alias: str | None = None, filters=None, partition=None, order=None, window_frame=None, auto_alias=True, - ): + ) -> None: super(Column, self).__init__() self.column = [column] if isinstance(column, str) else column or [] self.fn = fn @@ -980,26 +994,26 @@ def __init__( self.suffix = 0 @property - def alias(self): + def alias(self) -> str: a = self.alias_raw if self.suffix: a = '%s_%s' % (a, self.suffix) return escape_alias(a) @alias.setter - def alias(self, alias): + def alias(self, alias: str) -> None: self.alias_raw = alias.strip('`') - def set_alias(self, alias): + def set_alias(self, alias: str) -> 'Column': self.alias = alias return self - def add_suffix(self): + def add_suffix(self) -> str: self.suffix += 1 return self.alias @property - def expression(self): + def expression(self) -> str: """Genereates the representation without the 'AS ...' part.""" over = None if not (self.partition is None and self.order is None and @@ -1028,7 +1042,7 @@ def expression(self): res = self.fn.format(*column) return res + over if over else res - def __str__(self): + def __str__(self) -> str: if not self.expression: return '' res = self.expression @@ -1036,42 +1050,42 @@ def __str__(self): return res return '%s AS %s' % (res, self.alias) - def __add__(self, other): + def __add__(self, other: Any) -> 'Column': return Column( '{} + {}'.format(*add_parenthesis_if_needed(self, other)), alias='%s + %s' % (self.alias_raw, get_alias(other))) - def __radd__(self, other): + def __radd__(self, other: Any) -> 'Column': alias = '%s + %s' % (get_alias(other), self.alias_raw) return Column( '{} + {}'.format(*add_parenthesis_if_needed(other, self)), alias=alias) - def __sub__(self, other): + def __sub__(self, other: Any) -> 'Column': return Column( '{} - {}'.format(*add_parenthesis_if_needed(self, other)), alias='%s - %s' % (self.alias_raw, get_alias(other))) - def __rsub__(self, other): + def __rsub__(self, other: Any) -> 'Column': alias = '%s - %s' % (get_alias(other), self.alias_raw) return Column( '{} - {}'.format(*add_parenthesis_if_needed(other, self)), alias=alias) - def __mul__(self, other): + def __mul__(self, other: Any) -> 'Column': return Column( '{} * {}'.format(*add_parenthesis_if_needed(self, other)), alias='%s * %s' % (self.alias_raw, get_alias(other))) - def __rmul__(self, other): + def __rmul__(self, other: Any) -> 'Column': alias = '%s * %s' % (get_alias(other), self.alias_raw) return Column( '{} * {}'.format(*add_parenthesis_if_needed(other, self)), alias=alias) - def __neg__(self): + def __neg__(self) -> 'Column': return Column( '-{}'.format(*add_parenthesis_if_needed(self)), alias='-%s' % self.alias_raw) - def __div__(self, other): + def __div__(self, other: Any) -> 'Column': return Column( SAFE_DIVIDE_FN( numer=self.expression, denom=getattr(other, 'expression', other) @@ -1079,10 +1093,10 @@ def __div__(self, other): alias='%s / %s' % (self.alias_raw, get_alias(other)), ) - def __truediv__(self, other): + def __truediv__(self, other: Any) -> 'Column': return self.__div__(other) - def __rdiv__(self, other): + def __rdiv__(self, other: Any) -> 'Column': alias = '%s / %s' % (get_alias(other), self.alias_raw) return Column( SAFE_DIVIDE_FN( @@ -1091,10 +1105,10 @@ def __rdiv__(self, other): alias=alias, ) - def __rtruediv__(self, other): + def __rtruediv__(self, other: Any) -> 'Column': return self.__rdiv__(other) - def __pow__(self, other): + def __pow__(self, other: Any) -> 'Column': if isinstance(other, float) and other == 0.5: return Column( 'SAFE.SQRT({})'.format(self.expression), @@ -1104,7 +1118,7 @@ def __pow__(self, other): getattr(other, 'expression', other)), alias='%s ^ %s' % (self.alias_raw, get_alias(other))) - def __rpow__(self, other): + def __rpow__(self, other: Any) -> 'Column': alias = '%s ^ %s' % (get_alias(other), self.alias_raw) return Column( 'SAFE.POWER({}, {})'.format( @@ -1112,7 +1126,7 @@ def __rpow__(self, other): alias=alias) -def add_parenthesis_if_needed(*columns): +def add_parenthesis_if_needed(*columns: Any) -> abc.Iterable[str]: for column in columns: if not isinstance(column, Column): yield column @@ -1127,7 +1141,9 @@ def add_parenthesis_if_needed(*columns): class Columns(SqlComponents): """Represents a bunch of SQL columns.""" - def __init__(self, columns=None, distinct=None): # pylint: disable=super-init-not-called + def __init__( + self, columns: Any | None = None, distinct: bool | None = None + ) -> None: super(Columns, self).__init__() self.add(columns) self.distinct = distinct @@ -1135,24 +1151,24 @@ def __init__(self, columns=None, distinct=None): # pylint: disable=super-init-n self.distinct = columns.distinct @property - def aliases(self): + def aliases(self) -> list[str]: return [c.alias for c in self] @property - def original_columns(self): + def original_columns(self) -> list[Any]: # Returns the original Column instances added. return [c.column[0] for c in self] - def get_matched_column(self, expression): + def get_matched_column(self, expression: str) -> Any | None: return next((c for c in self if c.expression == expression), None) - def get_column(self, alias): + def get_column(self, alias: str) -> Any | None: res = [c for c in self if c.alias == alias] if res: return res[0] return None - def add(self, children): + def add(self, children: Any) -> 'Columns': """Adds a Column if not existing. Renames it when necessary. @@ -1192,18 +1208,18 @@ def add(self, children): children.add_suffix() return self.add(children) - def difference(self, columns): + def difference(self, columns: Any) -> 'Columns': return Columns((c for c in self if c not in Columns(columns))) @property - def expression(self): + def expression(self) -> list[str]: return list(map(str, self)) @property - def expressions(self): + def expressions(self) -> list[str]: return [c.expression for c in self] - def get_columns(self, break_line=False, indent=True): + def get_columns(self, break_line: bool = False, indent: bool = True) -> str: delimiter = ',\n' if break_line else ', ' if indent: res = delimiter.join((' %s' % e for e in self.expression)) @@ -1211,17 +1227,19 @@ def get_columns(self, break_line=False, indent=True): res = delimiter.join(self.expression) return 'DISTINCT ' + res if self.distinct else res - def as_groupby(self): + def as_groupby(self) -> str: return GROUP_BY_FN(self) - def __str__(self): + def __str__(self) -> str: return self.get_columns(True) class Datasource(SqlComponent): """Represents a SQL datasource, could be a table name or a SQL query.""" - def __init__(self, table, alias=None): + def __init__( + self, table: str | SqlComponent, alias: str | None = None + ) -> None: super(Datasource, self).__init__() self.table = table self.alias = alias @@ -1235,7 +1253,7 @@ def __init__(self, table, alias=None): and 'WITH\n' not in str(self.table).upper() ) - def get_expression(self, form='FROM'): + def get_expression(self, form: str = 'FROM') -> str: """Gets the expression that can be used in a FROM or WITH clause.""" if form.upper() not in ('FROM', 'WITH'): raise ValueError('Unrecognized form for datasource!') @@ -1248,7 +1266,14 @@ def get_expression(self, form='FROM'): else: return str(self) - def join(self, other, on=None, using=None, join='', alias=None): + def join( + self, + other: str | SqlComponent, + on: str | Filter | abc.Iterable[str | Filter] | None = None, + using: str | Column | abc.Iterable[str | Column] | None = None, + join: str = '', + alias: str | None = None, + ) -> Join: return Join(self, other, on, using, join, alias) def get_source_prefix(self, col: Column) -> str: @@ -1256,7 +1281,7 @@ def get_source_prefix(self, col: Column) -> str: return self.table.get_source_prefix(col) return (self.alias or self.table) + '.{c}' - def __str__(self): + def __str__(self) -> str: table = self.table if self.is_table else '(%s)' % self.table # No "AS" between a table and its alias is supported by more dialects. return '%s %s' % (table, self.alias) if self.alias else str(table) @@ -1265,13 +1290,15 @@ def __str__(self): class Join(Datasource): """Represents a JOIN of two Datasources.""" - def __init__(self, - datasource1, - datasource2, - on=None, - using=None, - join='', - alias=None): + def __init__( + self, + datasource1: str | SqlComponent, + datasource2: str | SqlComponent, + on: str | Filter | abc.Iterable[str | Filter] | None = None, + using: str | Column | abc.Iterable[str | Column] | None = None, + join: str = '', + alias: str | None = None, + ) -> None: if on and using: raise ValueError('A JOIN cannot have both ON and USING condition!') if join.upper() not in ('', 'INNER', 'FULL', 'FULL OUTER', 'LEFT', @@ -1326,7 +1353,7 @@ def get_source_prefix(self, col: Column) -> str: return right return left - def __str__(self): + def __str__(self) -> str: if self.ds1 == self.ds2: return str(self.ds1) join_type = self.join_type @@ -1364,17 +1391,17 @@ def __str__(self): class Datasources(SqlComponents): """Represents a bunch of SQL datasources in a WITH clause.""" - def __init__(self, datasources=None): + def __init__(self, datasources: Any | None = None) -> None: super(Datasources, self).__init__() self.children = collections.OrderedDict() self.temp_tables = set() self.add(datasources) @property - def datasources(self): + def datasources(self) -> abc.Iterable[Datasource]: return (Datasource(v, k) for k, v in self.children.items()) - def merge(self, new_child: Union[Datasource, 'Datasources', 'Sql']): + def merge(self, new_child: Datasource | 'Datasources' | 'Sql') -> str | None: """Merges a datasource if possible. The difference between merge() and add() is that in add() we skip only when @@ -1430,7 +1457,7 @@ def merge(self, new_child: Union[Datasource, 'Datasources', 'Sql']): self.children[new_child.alias] = table return new_child.alias - def add(self, children: Union[Datasource, Iterable[Datasource]]): + def add(self, children: Datasource | abc.Iterable[Datasource]) -> str | None: """Adds a datasource if not existing. Renames it when necessary. @@ -1473,7 +1500,9 @@ def add(self, children: Union[Datasource, Iterable[Datasource]]): children.alias = add_suffix(alias) return self.add(children) - def add_temp_table(self, table: Union[str, 'Sql', Join, Datasource]): + def add_temp_table( + self, table: str | 'Sql' | Join | Datasource + ) -> 'Datasources': """Marks alias and all its data dependencies as temp tables.""" if isinstance(table, str): self.temp_tables.add(table) @@ -1490,7 +1519,7 @@ def add_temp_table(self, table: Union[str, 'Sql', Join, Datasource]): return self.add_temp_table(table.from_data) return self - def extend(self, other: 'Datasources'): + def extend(self, other: 'Datasources') -> 'Datasources': """Merge other to self. Adjust the query if a new alias is needed.""" datasources = list(other.datasources) while datasources: @@ -1504,7 +1533,7 @@ def extend(self, other: 'Datasources'): str(d2.table)) return self - def __str__(self): + def __str__(self) -> str: temp_tables = [] with_tables = [] for d in self.datasources: @@ -1526,12 +1555,12 @@ class Sql(SqlComponent): def __init__( self, columns, - from_data: Union[str, 'Sql', Datasource], + from_data: str | 'Sql' | Datasource, where=None, groupby=None, with_data=None, orderby=None, - ): + ) -> None: super(Sql, self).__init__() self.columns = Columns(columns) self.where = Filters(where) @@ -1563,7 +1592,7 @@ def __init__( self.from_data = from_data_table.from_data @property - def all_columns(self): + def all_columns(self) -> Columns: """Returns all columns in the SELECT clause.""" cols = Columns(self.groupby).add(self.columns) if ( @@ -1583,11 +1612,11 @@ def all_columns(self): res.append(c) return Columns(res) - def add(self, attr, values): + def add(self, attr: str, values: Any) -> 'Sql': getattr(self, attr).add(values) return self - def merge(self, other: 'Sql'): + def merge(self, other: 'Sql') -> bool: """Merges columns from other to self if possible. If self and other are compatible, we can merge their columns. The columns @@ -1608,7 +1637,7 @@ def merge(self, other: 'Sql'): self.columns.add(other.columns) return True - def __str__(self): + def __str__(self) -> str: with_clause = str(self.with_data) if self.with_data else None all_columns = self.all_columns or '*' select_clause = f'SELECT\n{all_columns}' diff --git a/utils.py b/utils.py index 6af2b98..f6ac144 100644 --- a/utils.py +++ b/utils.py @@ -14,28 +14,57 @@ """Utils functions for things like DataFrame manipulation.""" from __future__ import absolute_import +from __future__ import annotations from __future__ import division from __future__ import print_function +from collections.abc import Iterable import copy import datetime import glob import os -from typing import Iterable, List, Optional, Text, Union +from typing import Any, TYPE_CHECKING, TypeVar from meterstick import sql import pandas as pd +# Number can also be any type that can do arithmetic operations with numbers, +# e.g., fractions.Fraction. We use int | float as a representative. +Number = int | float -def get_name(obj): +# ReturnType represents the types that can be returned by Metric.compute_on. +# - pd.DataFrame: standard return type for multi-metric or multi-slice results. +# - pd.Series: returned when the result is 1D (e.g. single metric with +# split_by). +# - Number: returned when the result is a single scalar. +# - list[Any]: returned by some internal computations or MetricList when +# return_dataframe=False. +ReturnType = pd.DataFrame | pd.Series | Number | list[Any] + +# TableType represents the allowed types for SQL table arguments. +# - str: a table name. +# - sql.Datasource: a datasource representation. +# - sql.Sql: a subquery. +TableType = str | sql.Datasource | sql.Sql +StrOrList = str | list[str] +NumberTypes = (int, float) + +T = TypeVar('T', bound=ReturnType) + + +if TYPE_CHECKING: + import apache_beam + + +def get_name(obj: Any) -> str: return getattr(obj, 'name', str(obj)) -def is_metric(m): +def is_metric(m: Any) -> bool: return hasattr(m, 'compute_on') -def melt(df): +def melt(df: T) -> T: """Stacks the outermost comlumn level to the outermost index level. Similar to pd.stack(0) except @@ -76,7 +105,7 @@ def melt(df): return remove_empty_level(df) -def unmelt(df): +def unmelt(df: T) -> T: """Unstacks the outermost index level to the outermost column level. Similar to pd.unstack(0) except @@ -109,16 +138,15 @@ def unmelt(df): # It should be removed when future_stack becomes the default in pandas 3.0. df = pd.DataFrame(df.stack(0, future_stack=True)).T else: - df = pd.concat([df.loc[n] for n in names], - axis=1, - keys=names, - names=['Metric']) + df = pd.concat( + [df.loc[n] for n in names], axis=1, keys=names, names=['Metric'] + ) if single_value_col: return df.droplevel(1, axis=1) return df -def remove_empty_level(df): +def remove_empty_level(df: T) -> T: """Drops redundant levels in the index of df.""" if not isinstance(df, pd.DataFrame) or not isinstance(df.index, pd.MultiIndex): @@ -131,7 +159,11 @@ def remove_empty_level(df): return df.droplevel(drop) -def apply_name_tmpl(name_tmpl, res, melted=False): +def apply_name_tmpl( + name_tmpl: str | None, + res: T, + melted: bool = False, +) -> T: """Applies name_tmpl to all columns or pd.Series.name.""" if not name_tmpl: return res @@ -155,7 +187,9 @@ def apply_name_tmpl(name_tmpl, res, melted=False): return res -def get_extra_split_by(metric, return_superset=False): +def get_extra_split_by( + metric: Any, return_superset: bool = False +) -> tuple[str, ...]: """Collects the extra split_by added by Operations for the metric tree. Args: @@ -183,7 +217,9 @@ def get_extra_split_by(metric, return_superset=False): return tuple(extra_split_by) -def get_leaf_metrics(metric, include_constants=False): +def get_leaf_metrics( + metric: Any, include_constants: bool = False +) -> list[Any]: leaf = [] for m in metric.traverse(include_constants=include_constants): if not getattr(m, 'children', []): @@ -191,7 +227,7 @@ def get_leaf_metrics(metric, include_constants=False): return leaf -def get_global_filter(metric) -> sql.Filters: +def get_global_filter(metric: Any) -> sql.Filters: """Collects the filters that can be applied globally to the Metric tree.""" global_filter = sql.Filters() if metric.where: @@ -208,7 +244,7 @@ def get_global_filter(metric) -> sql.Filters: return global_filter -def push_filters_to_leaf(metric, is_root=True): +def push_filters_to_leaf(metric: Any, is_root: bool = True) -> Any: """Returns a Metric that all filters have been pushed to leaf nodes. Note that the return can differ subtly to the original metric when computing @@ -293,13 +329,15 @@ class CacheKey: fingerprint: The unique identifier of CacheKey. Used to hash. """ - def __init__(self, - metric, - key, - where: Optional[Union[Text, Iterable[Text]]] = None, - split_by: Optional[Union[Text, List[Text]]] = None, - slice_val=None, - extra_info=()): + def __init__( + self, + metric: Any, + key: Any, + where: str | Iterable[str] | None = None, + split_by: StrOrList | None = None, + slice_val: dict[str, Any] | None = None, + extra_info: tuple[Any, ...] = (), + ): """Wraps cache_key, split_by, filters and slice information. Args: @@ -351,23 +389,23 @@ def __init__(self, 'where': tuple(sorted(tuple(self.where))), } - def add_extra_info(self, extra_info: str): + def add_extra_info(self, extra_info: str) -> None: self.extra_info = tuple(list(self.extra_info) + [extra_info]) self.fingerprint['extra_info'] = self.extra_info - def replace_key(self, key): + def replace_key(self, key: Any) -> CacheKey: new_key = copy.deepcopy(self) new_key.key = key new_key.fingerprint['key'] = key return new_key - def replace_metric(self, new_metric): + def replace_metric(self, new_metric: Any) -> CacheKey: new_key = copy.deepcopy(self) new_key.metric = new_metric new_key.fingerprint['metric'] = new_metric.get_fingerprint() return new_key - def replace_split_by(self, split_by): + def replace_split_by(self, split_by: str | Iterable[str] | None) -> CacheKey: split_by = split_by or () split_by = (split_by,) if isinstance(split_by, str) else tuple(split_by) new_key = copy.deepcopy(self) @@ -375,7 +413,7 @@ def replace_split_by(self, split_by): new_key.fingerprint['split_by'] = split_by return new_key - def replace_where(self, where): + def replace_where(self, where: str | Iterable[str] | None) -> CacheKey: where = (where,) if isinstance(where, str) else tuple(sorted(where)) or () new_key = copy.deepcopy(self) new_key.where = where @@ -411,8 +449,10 @@ def __repr__(self): def adjust_slices_for_loo( - bucket_res: pd.Series, split_by: Optional[List[Text]] = None, df=None -): + bucket_res: pd.Series | pd.DataFrame, + split_by: list[str] | None, + df: pd.DataFrame, +) -> pd.Series | pd.DataFrame: """Corrects the slices in the bucketized result. Jackknife has a precomputation step where we precompute leave-one-out (LOO) @@ -465,6 +505,8 @@ def adjust_slices_for_loo( operation_lvl = unit_and_operation_lvl[1:] split_by_and_unit = indexes[: len(split_by) + 1] unit = split_by_and_unit[-1] + if df is None: + raise ValueError('df cannot be None') expected_units = ( df.groupby(split_by_and_unit, observed=True).first().iloc[:, [0]] ) @@ -492,7 +534,9 @@ def adjust_slices_for_loo( return bucket_res.reindex(expected_slices, fill_value=0) -def get_fully_expanded_equivalent_metric_tree(m, df=None): +def get_fully_expanded_equivalent_metric_tree( + m: Any, df: pd.DataFrame | None = None +) -> tuple[Any, pd.DataFrame | None]: """Gets a Metric that is equivalent to m, and cannot be further expanded. Some Metrics can be expressed by simpler Metrics like Sum and Count. Sum and @@ -526,7 +570,9 @@ def get_fully_expanded_equivalent_metric_tree(m, df=None): return curr, df -def get_equivalent_metric_tree(m, df=None, prefix=''): +def get_equivalent_metric_tree( + m: Any, df: pd.DataFrame | None = None, prefix: str = '' +) -> Any: """Replaces Metrics in the tree of m with equivalent Metrics.""" if not is_metric(m): return m @@ -542,7 +588,9 @@ def get_equivalent_metric_tree(m, df=None, prefix=''): return equiv -def get_equivalent_metric(m, df=None, prefix=''): +def get_equivalent_metric( + m: Any, df: pd.DataFrame | None = None, prefix: str = '' +) -> tuple[Any, pd.DataFrame | None]: """Gets the equivalent Metric of m and adds auxiliary columns to df.""" if df is not None and not prefix: prefix = get_unique_prefix(df) @@ -551,16 +599,18 @@ def get_equivalent_metric(m, df=None, prefix=''): return equiv, df -def get_unique_prefix(df): +def get_unique_prefix(df: pd.DataFrame) -> str: prefix = 'meterstick_tmp:' while any(str(c).startswith(prefix) for c in df.columns): prefix += ':' return prefix -def add_auxiliary_cols(auxiliary_cols, - df: Optional[pd.DataFrame] = None, - prefix: str = ''): +def add_auxiliary_cols( + auxiliary_cols: Iterable[Any], + df: pd.DataFrame | None = None, + prefix: str = '', +) -> tuple[pd.DataFrame | None, list[str]]: """Parses auxiliary_cols from Metric.get_auxiliary_cols and adds them to df. Some Metrics can be expressed by simpler Metrics. For example, Dot(x, y) is @@ -602,7 +652,9 @@ def add_auxiliary_cols(auxiliary_cols, return df, auxiliary_col_names -def parse_auxiliary_col(auxiliary_col, df: Optional[pd.DataFrame] = None): +def parse_auxiliary_col( + auxiliary_col: Any, df: pd.DataFrame | None = None +) -> tuple[str, Any]: """Parses an auxiliary_col and computes it. Args: @@ -653,7 +705,10 @@ def parse_auxiliary_col(auxiliary_col, df: Optional[pd.DataFrame] = None): def pcollection_to_df_via_file_io( - pcol, pipeline, output_dir: str, cleanup=False + pcol: apache_beam.pvalue.PCollection, + pipeline: apache_beam.Pipeline, + output_dir: str, + cleanup: bool = False, ) -> pd.DataFrame: """Evaluates a PCollection, saves result, reads back to a DataFrame.