diff --git a/src/business_tools.py b/src/business_tools.py index c2762d1a..b118c867 100644 --- a/src/business_tools.py +++ b/src/business_tools.py @@ -21,3 +21,45 @@ def get_sales_from_csv(filename: str) -> float: def calculate_commission(premiums: list[float], rate: float = 0.1) -> float: """Return total commission in USD rounded to two decimals.""" return round(sum(premiums) * rate, 2) + + +def load_insurance_sales(filename: str) -> list[dict[str, str]]: + """Return all rows from an insurance sales CSV as dictionaries. + + Args: + filename: Path to ``insurance_sales.csv``. + + Returns: + A list of dictionaries, one per CSV row. + """ + with open(filename, newline="") as csvfile: + reader = csv.DictReader(csvfile) + return list(reader) + + +def total_commission(records: list[dict[str, str]]) -> float: + """Return the sum of the ``Commission`` column from insurance records. + + Args: + records: Rows loaded via :func:`load_insurance_sales`. + + Returns: + Total commission as a float. + """ + total = 0.0 + for row in records: + total += float(row["Commission"]) + return total + + +def filter_by_state(records: list[dict[str, str]], state: str) -> list[dict[str, str]]: + """Return only the rows matching a given state code. + + Args: + records: Insurance sale rows. + state: Two-letter state abbreviation. + + Returns: + Filtered list containing rows where ``State`` equals ``state``. + """ + return [row for row in records if row["State"] == state] diff --git a/tests/test_business_tools.py b/tests/test_business_tools.py index 7316b78a..a4b50b75 100644 --- a/tests/test_business_tools.py +++ b/tests/test_business_tools.py @@ -1,13 +1,19 @@ import os import sys -import pytest # Ensure the repository root is on sys.path so business_tools can be imported REPO_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), os.pardir)) sys.path.insert(0, os.path.join(REPO_ROOT, "src")) -from business_tools import calculate_profit, get_sales_from_csv, calculate_commission +from business_tools import ( # noqa: E402 + calculate_profit, + get_sales_from_csv, + calculate_commission, + load_insurance_sales, + total_commission, + filter_by_state, +) def test_calculate_profit(): @@ -26,3 +32,23 @@ def test_get_sales_from_csv(tmp_path): def test_calculate_commission(): premiums = [300, 700, 200] assert calculate_commission(premiums, rate=0.1) == 120.0 + + +def test_load_insurance_sales_and_total_commission(tmp_path): + src = os.path.join(REPO_ROOT, "data", "insurance_sales.csv") + dst = tmp_path / "insurance_sales.csv" + with open(src, "r") as fsrc, open(dst, "w") as fdst: + fdst.write(fsrc.read()) + records = load_insurance_sales(str(dst)) + assert len(records) == 15 + assert total_commission(records) == 2545.0 + + +def test_filter_by_state(tmp_path): + src = os.path.join(REPO_ROOT, "data", "insurance_sales.csv") + dst = tmp_path / "insurance_sales.csv" + with open(src, "r") as fsrc, open(dst, "w") as fdst: + fdst.write(fsrc.read()) + records = load_insurance_sales(str(dst)) + ca_records = filter_by_state(records, "CA") + assert len(ca_records) == 4