Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/engine/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ def make_failed_response(failed_message: str) -> tuple[GameState, list[BuyAssetR
player = game_state.players[msg.player_id]

if not msg.asset_id in game_state.assets.asset_ids:
return make_failed_response("Asset {msg.asset_id} does not exist.")
return make_failed_response(f"Asset {msg.asset_id} does not exist.")

asset = game_state.assets[msg.asset_id]
if not asset.is_for_sale:
Expand Down
20 changes: 19 additions & 1 deletion src/models/buses.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from dataclasses import dataclass

from src.models.assets import AssetType
from src.models.data.ldc_repo import LdcRepo
from src.models.data.light_dc import LightDc
from src.models.ids import PlayerId, BusId
Expand All @@ -27,3 +26,22 @@ def _get_dc_type(cls) -> type[Bus]:
@property
def bus_ids(self) -> list[BusId]:
return [BusId(x) for x in self.df.index.tolist()]

@property
def npc_bus_ids(self) -> list[BusId]:
return self.filter({"player_id": PlayerId.get_npc()}).bus_ids

@property
def player_bus_ids(self) -> list[BusId]:
return self.filter(operator="not", condition={"player_id": PlayerId.get_npc()}).bus_ids

@property
def ice_cream_buses(self) -> list[Bus]:
"""Get all buses that are ice cream buses."""
return [self[b] for b in self.player_bus_ids]

def get_bus_for_player(self, player_id: PlayerId) -> Bus:
"""Get the bus for a specific player."""
player_buses = self.filter({"player_id": player_id})
assert len(player_buses) == 1
return player_buses.as_objs()[0]
70 changes: 44 additions & 26 deletions src/models/data/ldc_repo.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,31 +3,32 @@
Generic,
Any,
Iterator,
Self,
Callable,
overload,
Literal,
Optional,
Iterable,
TypeVar,
)

import pandas as pd

from src.models.data.light_dc import GenericLightDc
from src.models.data.light_dc import T_LightDc
from src.tools.serialization import simplify_type
from src.tools.typing import T

Condition = dict[str, Any] | Callable[[pd.Series], bool]
Operator = Literal["or", "and", "not", None]


class LdcRepo(Generic[GenericLightDc], ABC):
class LdcRepo(Generic[T_LightDc], ABC):
# A dataframe-based repo containing an indexed list of light dataclass objects

@classmethod
@abstractmethod
def _get_dc_type(cls) -> type[GenericLightDc]: ...
def _get_dc_type(cls) -> type[T_LightDc]: ...

def __init__(self, dcs: list[GenericLightDc] | pd.DataFrame) -> None:
def __init__(self, dcs: list[T_LightDc] | pd.DataFrame) -> None:
if isinstance(dcs, list):
assert len(dcs) > 0
assert [isinstance(dc, self._get_dc_type()) for dc in dcs]
Expand All @@ -50,12 +51,12 @@ def __init__(self, dcs: list[GenericLightDc] | pd.DataFrame) -> None:
self._df = df

@overload
def __getitem__(self, index: int) -> GenericLightDc: ...
def __getitem__(self, index: int) -> T_LightDc: ...

@overload
def __getitem__(self, index: str) -> pd.Series: ...

def __getitem__(self, x: int | str) -> GenericLightDc | pd.Series:
def __getitem__(self, x: int | str) -> T_LightDc | pd.Series:
if isinstance(x, str):
return self.df.loc[:, x]
assert isinstance(x, int)
Expand All @@ -65,7 +66,7 @@ def __getitem__(self, x: int | str) -> GenericLightDc | pd.Series:
row = self.df.loc[simple_x]
return self._get_dc_type().from_simple_dict({**row.to_dict(), "id": x})

def __iter__(self) -> Iterator[GenericLightDc]:
def __iter__(self) -> Iterator[T_LightDc]:
for dc_id in self._df.index:
yield self[dc_id]

Expand All @@ -75,7 +76,7 @@ def __str__(self) -> str:
def __repr__(self) -> str:
return f"{self.__class__.__name__}:\n{repr(self._df)}"

def __add__(self, other: Self | GenericLightDc) -> Self:
def __add__(self: T, other: T | T_LightDc) -> T:
if isinstance(other, self.__class__):
return self.__class__(pd.concat([self.df, other.df], axis=0))
elif isinstance(other, self._get_dc_type()):
Expand All @@ -87,11 +88,16 @@ def __add__(self, other: Self | GenericLightDc) -> Self:
def __len__(self) -> int:
return len(self._df)

def __eq__(self, other: "LdcRepo") -> bool:
if not type(self) == type(other):
return False
return self.df.equals(other.df)

# UPDATE
def add(self, x: Self | GenericLightDc) -> Self:
def add(self: T, x: T | T_LightDc) -> T:
return self + x

def update_frame(self, df: pd.DataFrame) -> Self:
def update_frame(self: T, df: pd.DataFrame) -> T:
return self.from_frame(df)

@property
Expand All @@ -110,7 +116,7 @@ def _condition_to_logical_indexer(self, condition: Condition) -> pd.Series:
def _filter(
self,
condition: Condition,
operator: Literal["or", "and"],
operator: Operator,
condition_2: Optional[Condition] = None,
) -> pd.Series:
"""
Expand All @@ -120,7 +126,7 @@ def _filter(
Advanced: A function that is called on the underlying series
Note that if a function is provided, the function is called on the underlying series which cannot contain
complex types. Make sure to convert to simple types before using them in the function
:param operator: "or" or "and" to combine two conditions
:param operator: "or", "and", "not" to combine two conditions or negate the first one
:param condition_2: A second condition to combine with the first one
:return: A logical indexer for the given condition or combination of conditions
>>> self._filter({"bus": BusId(1), "color": Color.Red})
Expand All @@ -131,20 +137,29 @@ def _filter(

"""
if condition_2 is None:
return self._condition_to_logical_indexer(condition)
elif operator == "and":
assert operator in ["not", None], f"Invalid operator for one condition: {operator}"
else:
assert operator in ["or", "and"], f"Invalid operator for two conditions: {operator}"

if condition_2 is None:
if operator == "not":
return ~self._condition_to_logical_indexer(condition)
else:
return self._condition_to_logical_indexer(condition)

if operator == "and":
return self._condition_to_logical_indexer(condition) & self._condition_to_logical_indexer(condition_2)
elif operator == "or":
return self._condition_to_logical_indexer(condition) | self._condition_to_logical_indexer(condition_2)
else:
raise ValueError(f"Invalid operator: {operator}. Use 'or' or 'and'.")
raise ValueError(f"Invalid operator {operator}")

def filter(
self,
self: T,
condition: Condition,
operator: Literal["or", "and"] = "or",
operator: Operator = None,
condition_2: Optional[Condition] = None,
) -> Self:
) -> T:
"""
Returns a copy of the repo filtered using the given condition
:return: The filtered LdcFrame
Expand All @@ -154,11 +169,11 @@ def filter(
return self.__class__(filtered_df)

def drop_items(
self,
self: T,
condition: Condition,
operator: Literal["or", "and"] = "or",
operator: Operator = None,
condition_2: Optional[Condition] = None,
) -> Self:
) -> T:
"""
Returns a copy of the repo with elements deleted using the given condition
:return: A new version of the LdcFrame with the items dropped
Expand All @@ -168,17 +183,17 @@ def drop_items(
index = logical_indexer.loc[logical_indexer].index
return self.from_frame(self.df.drop(index, axis=0))

def drop_by_ids(self, ids: Iterable[int]) -> Self:
def drop_by_ids(self: T, ids: Iterable[int]) -> T:
"""
:return: A copy of the repo with the elements with the given ids deleted
"""
simple_ids = [simplify_type(x) for x in ids]
return self.from_frame(self.df.drop(simple_ids, axis=0))

def drop_one(self, item: int) -> Self:
def drop_one(self: T, item: int) -> T:
return self.drop_by_ids([item])

def as_objs(self) -> list[GenericLightDc]:
def as_objs(self) -> list[T_LightDc]:
return list(self.__iter__())

def to_simple_dict(self) -> dict[str, Any]:
Expand All @@ -196,5 +211,8 @@ def from_simple_dict(cls: type[T], data: dict) -> T:
return cls(dcs)

@classmethod
def from_frame(cls, df: pd.DataFrame) -> Self:
def from_frame(cls: type[T], df: pd.DataFrame) -> T:
return cls(df)


T_LdcRepo = TypeVar("T_LdcRepo", bound=LdcRepo)
2 changes: 1 addition & 1 deletion src/models/data/light_dc.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,4 +24,4 @@ def from_simple_dict(cls, simple_dict: SimpleDict) -> Self:
return cls(**init_dict) # noqa


GenericLightDc = TypeVar("GenericLightDc", bound=LightDc)
T_LightDc = TypeVar("T_LightDc", bound=LightDc)
12 changes: 9 additions & 3 deletions src/models/game_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,14 @@

from src.models.assets import AssetRepo
from src.models.buses import BusRepo
from src.models.game_settings import GameSettings
from src.models.ids import (
PlayerId,
GameId,
)
from src.models.market_coupling_result import MarketCouplingResult
from src.models.player import PlayerRepo
from src.models.transmission import TransmissionRepo
from src.models.game_settings import GameSettings
from src.tools.serialization import (
simplify_type,
un_simplify_type,
Expand Down Expand Up @@ -50,7 +50,9 @@ def to_simple_dict(self) -> dict:
"buses": self.buses.to_simple_dict(),
"assets": self.assets.to_simple_dict(),
"transmission": self.transmission.to_simple_dict(),
"market_coupling_result": self.market_coupling_result.to_simple_dict(),
"market_coupling_result": (
self.market_coupling_result.to_simple_dict() if self.market_coupling_result else None
),
}

@classmethod
Expand All @@ -63,5 +65,9 @@ def from_simple_dict(cls, simple_dict: dict) -> Self:
buses=BusRepo.from_simple_dict(simple_dict["buses"]),
assets=AssetRepo.from_simple_dict(simple_dict["assets"]),
transmission=TransmissionRepo.from_simple_dict(simple_dict["transmission"]),
market_coupling_result=MarketCouplingResult.from_simple_dict(simple_dict["market_coupling_result"]),
market_coupling_result=(
MarketCouplingResult.from_simple_dict(simple_dict["market_coupling_result"])
if simple_dict.get("market_coupling_result")
else None
),
)
2 changes: 1 addition & 1 deletion src/models/transmission.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ class TransmissionInfo(LightDc):
purchase_cost: float = 0.0 # 0 = Not for sale

def __post_init__(self) -> None:
assert self.bus2 > self.bus1
assert self.bus2 > self.bus1, f"bus2 must be greater than bus1. Got {self.bus2} and {self.bus1}"
assert self.reactance > 0, f"Reactance must be positive. Got {self.reactance}"


Expand Down
Empty file added tests/test_engine/__init__.py
Empty file.
58 changes: 58 additions & 0 deletions tests/test_engine/test_engine.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
from typing import Callable
from unittest import TestCase

from src.engine.engine import Engine
from src.models.ids import PlayerId, AssetId
from src.models.message import PlayerToGameMessage, BuyAssetRequest, BuyAssetResponse
from src.models.player import Player
from tests.utils.comparisons import assert_game_states_are_equal, assert_game_states_are_not_equal
from tests.utils.game_state_maker import GameStateMaker
from tests.utils.repo_maker import PlayerRepoMaker


class DummyMessage(PlayerToGameMessage):
pass


class TestAssets(TestCase):
def test_bad_message(self) -> None:
game_state = GameStateMaker().make()
dumb_message = DummyMessage(player_id=PlayerId(5))
with self.assertRaises(NotImplementedError):
Engine.handle_update_bid_message(game_state=game_state, msg=dumb_message) # noqa

def test_update_bid_message(self) -> None:
player_repo = PlayerRepoMaker.make_quick()
rich_player = Player(id=PlayerId(100), name="Rich player", color="#000000", money=1000000, is_having_turn=True)
player_repo += rich_player
game_state = GameStateMaker().add_player_repo(player_repo).make()

is_for_sale_ids = game_state.assets.filter(condition={"is_for_sale": True}).asset_ids
not_for_sale_ids = game_state.assets.filter(condition={"is_for_sale": False}).asset_ids

def assert_fails_with_message_matching(request: BuyAssetRequest, x: Callable[[str], bool]) -> None:
new_game_state, msgs = Engine.handle_message(game_state=game_state, msg=request)
self.assertEqual(len(msgs), 1)
message = msgs[0]
self.assertIsInstance(message, BuyAssetResponse)
self.assertFalse(message.success)
self.assertTrue(x(message.message))
assert_game_states_are_equal(game_state1=game_state, game_state2=new_game_state)

msg = BuyAssetRequest(player_id=rich_player.id, asset_id=AssetId(-5))
assert_fails_with_message_matching(request=msg, x=lambda s: "asset" in s.lower())

msg = BuyAssetRequest(player_id=rich_player.id, asset_id=not_for_sale_ids[0])
assert_fails_with_message_matching(request=msg, x=lambda s: "for sale" in s.lower())

msg = BuyAssetRequest(player_id=rich_player.id, asset_id=is_for_sale_ids[0])
result_game_state, messages = Engine.handle_message(game_state=game_state, msg=msg)
self.assertEqual(len(messages), 1)
success_msg = messages[0]
self.assertIsInstance(success_msg, BuyAssetResponse)
self.assertTrue(success_msg.success)
assert_game_states_are_not_equal(game_state1=game_state, game_state2=result_game_state)

sold_asset = result_game_state.assets[is_for_sale_ids[0]]
self.assertEqual(sold_asset.owner_player, rich_player.id)
self.assertFalse(sold_asset.is_for_sale)
Loading