Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 42 additions & 0 deletions src/business_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,3 +21,45 @@ def get_sales_from_csv(filename: str) -> float:
def calculate_commission(premiums: list[float], rate: float = 0.1) -> float:
"""Return total commission in USD rounded to two decimals."""
return round(sum(premiums) * rate, 2)


def load_insurance_sales(filename: str) -> list[dict[str, str]]:
"""Return all rows from an insurance sales CSV as dictionaries.

Args:
filename: Path to ``insurance_sales.csv``.

Returns:
A list of dictionaries, one per CSV row.
"""
with open(filename, newline="") as csvfile:
reader = csv.DictReader(csvfile)
return list(reader)


def total_commission(records: list[dict[str, str]]) -> float:
"""Return the sum of the ``Commission`` column from insurance records.

Args:
records: Rows loaded via :func:`load_insurance_sales`.

Returns:
Total commission as a float.
"""
total = 0.0
for row in records:
total += float(row["Commission"])
return total


def filter_by_state(records: list[dict[str, str]], state: str) -> list[dict[str, str]]:
"""Return only the rows matching a given state code.

Args:
records: Insurance sale rows.
state: Two-letter state abbreviation.

Returns:
Filtered list containing rows where ``State`` equals ``state``.
"""
return [row for row in records if row["State"] == state]
30 changes: 28 additions & 2 deletions tests/test_business_tools.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,19 @@

import os
import sys
import pytest

# Ensure the repository root is on sys.path so business_tools can be imported
REPO_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), os.pardir))
sys.path.insert(0, os.path.join(REPO_ROOT, "src"))

from business_tools import calculate_profit, get_sales_from_csv, calculate_commission
from business_tools import ( # noqa: E402
calculate_profit,
get_sales_from_csv,
calculate_commission,
load_insurance_sales,
total_commission,
filter_by_state,
)


def test_calculate_profit():
Expand All @@ -26,3 +32,23 @@ def test_get_sales_from_csv(tmp_path):
def test_calculate_commission():
premiums = [300, 700, 200]
assert calculate_commission(premiums, rate=0.1) == 120.0


def test_load_insurance_sales_and_total_commission(tmp_path):
src = os.path.join(REPO_ROOT, "data", "insurance_sales.csv")
dst = tmp_path / "insurance_sales.csv"
with open(src, "r") as fsrc, open(dst, "w") as fdst:
fdst.write(fsrc.read())
records = load_insurance_sales(str(dst))
assert len(records) == 15
assert total_commission(records) == 2545.0


def test_filter_by_state(tmp_path):
src = os.path.join(REPO_ROOT, "data", "insurance_sales.csv")
dst = tmp_path / "insurance_sales.csv"
with open(src, "r") as fsrc, open(dst, "w") as fdst:
fdst.write(fsrc.read())
records = load_insurance_sales(str(dst))
ca_records = filter_by_state(records, "CA")
assert len(ca_records) == 4