diff --git a/src/app/simple_front_end/plotting/base_plot_object.py b/src/app/simple_front_end/plotting/base_plot_object.py
index 5e1aac0..d8af0df 100644
--- a/src/app/simple_front_end/plotting/base_plot_object.py
+++ b/src/app/simple_front_end/plotting/base_plot_object.py
@@ -5,6 +5,8 @@
import plotly.graph_objs as go
from plotly.graph_objs import Scatter
+from src.models.colors import Color
+
@dataclass(frozen=True)
class Point:
@@ -54,7 +56,7 @@ def centre(self) -> Point:
@property
@abstractmethod
- def color(self) -> str:
+ def color(self) -> Color:
pass
@property
@@ -71,6 +73,11 @@ def data_dict(self) -> dict[str, str]:
def text_locations(self) -> list[Point]:
return [self.centre]
+ @staticmethod
+ def deactivate_color(c: Color) -> Color:
+ h, s, v = c.hsv
+ return Color(x=(h, round(s / 2), round(v / 2)), color_model="hsv")
+
def render_hover_text(self) -> Scatter:
hover_template = (
f"{self.title}
"
@@ -82,7 +89,7 @@ def render_hover_text(self) -> Scatter:
mode="markers",
marker={
"size": 10,
- "color": self.color,
+ "color": self.color.rgb_hex_str,
"symbol": "circle",
"line": {"width": 0.0},
"opacity": 0.0, # Make the marker invisible
diff --git a/src/app/simple_front_end/plotting/colors.py b/src/app/simple_front_end/plotting/colors.py
deleted file mode 100644
index 3836d43..0000000
--- a/src/app/simple_front_end/plotting/colors.py
+++ /dev/null
@@ -1,16 +0,0 @@
-def get_contrasting_color(color: str) -> str:
- """
- Given a color in hex format, return a contrasting color (black or white).
- """
- if not color.startswith("#") or len(color) != 7:
- raise ValueError(f"Invalid color format: {color}. Expected hex format #RRGGBB.")
-
- r = int(color[1:3], 16)
- g = int(color[3:5], 16)
- b = int(color[5:7], 16)
-
- # Calculate the brightness of the color
- brightness = (r * 299 + g * 587 + b * 114) / 1000
-
- # Return black for light colors and white for dark colors
- return "#000000" if brightness > 128 else "#FFFFFF"
diff --git a/src/app/simple_front_end/plotting/po_asset.py b/src/app/simple_front_end/plotting/po_asset.py
index 1f02221..02c2558 100644
--- a/src/app/simple_front_end/plotting/po_asset.py
+++ b/src/app/simple_front_end/plotting/po_asset.py
@@ -6,9 +6,9 @@
from plotly.graph_objs import Scatter
from src.app.simple_front_end.plotting.base_plot_object import Point, PlotObject
-from src.app.simple_front_end.plotting.colors import get_contrasting_color
from src.app.simple_front_end.plotting.po_bus import PlotBus
from src.models.assets import AssetInfo, AssetType
+from src.models.colors import get_contrasting_color, Color
from src.models.player import Player
@@ -33,8 +33,8 @@ def title(self) -> str:
return title
@property
- def color(self) -> str:
- return self.owner.color
+ def color(self) -> Color:
+ return self.owner.color_obj
@property
def data_dict(self) -> dict[str, str]:
@@ -67,15 +67,21 @@ def render_shape(self) -> Scatter:
else:
raise ValueError(f"Unknown asset type: {self.asset.asset_type}")
+ if self.asset.is_active:
+ color = self.color
+ else:
+ color = self.deactivate_color(self.color)
+ contrast_color = get_contrasting_color(color)
+
main = go.Scatter(
x=x,
y=y,
mode="lines+text",
text=[""] * (len(x) - 1) + [text],
fill="toself",
- fillcolor=self.color,
+ fillcolor=color.rgb_hex_str,
line={"width": 0.0},
hoverinfo="skip",
- textfont={"size": 10, "color": get_contrasting_color(self.color)},
+ textfont={"size": 10, "color": contrast_color.rgb_hex_str},
)
return main
diff --git a/src/app/simple_front_end/plotting/po_bus.py b/src/app/simple_front_end/plotting/po_bus.py
index 9460d9e..edfd221 100644
--- a/src/app/simple_front_end/plotting/po_bus.py
+++ b/src/app/simple_front_end/plotting/po_bus.py
@@ -8,6 +8,7 @@
from src.app.simple_front_end.plotting.base_plot_object import Point, PlotObject
from src.models.buses import Bus
+from src.models.colors import Color
from src.models.player import Player
SocketSide = Literal["tr", "bl"] # Top Right or Bottom Left
@@ -80,8 +81,8 @@ def title(self) -> str:
return f"Bus{self.bus.id}"
@property
- def color(self) -> str:
- return self.owner.color
+ def color(self) -> Color:
+ return self.owner.color_obj
@property
def data_dict(self) -> dict[str, str]:
@@ -95,7 +96,7 @@ def render_shape(self) -> Scatter:
x=[p.x for p in points],
y=[p.y for p in points],
fill="toself",
- fillcolor=self.owner.color,
+ fillcolor=self.color.rgb_hex_str,
line=dict(color="black", width=1),
mode="lines",
hoverinfo="skip",
diff --git a/src/app/simple_front_end/plotting/po_line.py b/src/app/simple_front_end/plotting/po_line.py
index 78b11a7..6d0d891 100644
--- a/src/app/simple_front_end/plotting/po_line.py
+++ b/src/app/simple_front_end/plotting/po_line.py
@@ -6,6 +6,7 @@
from src.app.simple_front_end.plotting.base_plot_object import Point, PlotObject, point_linspace
from src.app.simple_front_end.plotting.po_bus import PlotBus
+from src.models.colors import Color
from src.models.player import Player
from src.models.transmission import TransmissionInfo
@@ -21,8 +22,8 @@ def title(self) -> str:
return f"Line{self.line.id}"
@property
- def color(self) -> str:
- return self.owner.color
+ def color(self) -> Color:
+ return self.owner.color_obj
@property
def data_dict(self) -> dict[str, str]:
@@ -69,10 +70,16 @@ def text_locations(self) -> list[Point]:
def render_shape(self) -> Scatter:
points = self.vertices
+
+ if self.line.is_active:
+ color = self.color
+ else:
+ color = self.deactivate_color(self.color)
+
scatter = go.Scatter(
x=[p.x for p in points],
y=[p.y for p in points],
- line=dict(color=self.owner.color, width=3),
+ line=dict(color=color.rgb_hex_str, width=3),
opacity=0.8,
mode="lines",
hoverinfo="skip",
diff --git a/src/app/simple_front_end/plotting/po_player_legend.py b/src/app/simple_front_end/plotting/po_player_legend.py
index a80bd88..da09062 100644
--- a/src/app/simple_front_end/plotting/po_player_legend.py
+++ b/src/app/simple_front_end/plotting/po_player_legend.py
@@ -1,14 +1,11 @@
from dataclasses import dataclass
from functools import cached_property
-from typing import Literal, Optional
-import numpy as np
import plotly.graph_objects as go
from plotly.graph_objs import Scatter
from src.app.simple_front_end.plotting.base_plot_object import Point, PlotObject
-from src.app.simple_front_end.plotting.colors import get_contrasting_color
-from src.models.buses import Bus
+from src.models.colors import get_contrasting_color, Color
from src.models.player import Player
@@ -24,8 +21,8 @@ def title(self) -> str:
return self.player.name
@property
- def color(self) -> str:
- return self.player.color
+ def color(self) -> Color:
+ return self.player.color_obj
@property
def data_dict(self) -> dict[str, str]:
@@ -42,9 +39,9 @@ def render_shape(self) -> Scatter:
mode="lines+text",
text=[""] * (len(points) - 1) + [self.player.name],
fill="toself",
- fillcolor=self.color,
+ fillcolor=self.color.rgb_hex_str,
line=dict(color="black", width=0),
- textfont={"size": 10, "color": get_contrasting_color(self.color)},
+ textfont={"size": 10, "color": get_contrasting_color(self.color).rgb_hex_str},
hoverinfo="skip",
)
return scatter
diff --git a/src/models/assets.py b/src/models/assets.py
index a62881d..dda7be9 100644
--- a/src/models/assets.py
+++ b/src/models/assets.py
@@ -27,6 +27,7 @@ class AssetInfo(LightDc):
marginal_price: float = 0.0
bid_price: float = 0.0
is_ice_cream: bool = False # This is a special type of load
+ is_active: bool = True
def __post_init__(self) -> None:
if self.is_ice_cream:
diff --git a/src/models/colors.py b/src/models/colors.py
new file mode 100644
index 0000000..eb358f2
--- /dev/null
+++ b/src/models/colors.py
@@ -0,0 +1,127 @@
+import colorsys
+from functools import cached_property
+from typing import Union, Literal
+
+import numpy as np
+
+
+class Color:
+ def __init__(
+ self,
+ x: Union[str, tuple[int, int, int], Literal["red", "green", "blue", "black", "gray", "white"]],
+ color_model: Literal["rgb", "hsv", "hls"] = "rgb",
+ ):
+ """
+ You can create a color using a pre-defined color name, e.g. "red", "green", "blue", "black", "gray" "white".
+ >>> Color("red")
+
+ Or pass an RGB hex string, e.g. "#FF5733".
+ >>> Color("#FF5733")
+
+ You can pass a tuple of unsigned 8bit integers if you prefer
+ >>> Color((255, 87, 51))
+
+ The default model is RGB but if you want you can specify the color as hsv or hls
+ >>> Color((255, 87, 51), color_model="hsv")
+ """
+ assert color_model in ["rgb", "hsv", "hls"], f"Invalid color model: {color_model}."
+
+ expected_format = {"rgb": "#RRGGBB", "hsv": "#HHSSVV", "hls": "#HHLLSS"}[color_model]
+
+ if isinstance(x, str):
+ if not x.startswith("#"):
+ x = {
+ "red": "#FF0000",
+ "blue": "#0000FF",
+ "green": "#00FF00",
+ "black": "#000000",
+ "gray": "#808080",
+ "white": "#FFFFFF",
+ }[x]
+ if not len(x) == 7:
+ raise ValueError(f"Invalid hex color format: {x}. Expected format {expected_format}.")
+ a = int(x[1:3], 16)
+ b = int(x[3:5], 16)
+ c = int(x[5:7], 16)
+ else:
+ assert len(x) == 3, f"{color_model.upper()} tuple must have exactly three elements. Received {x}."
+ assert all(
+ isinstance(value, int) for value in x
+ ), f"{color_model.upper()} values must be integers. Received {x}."
+ a, b, c = x
+
+ abc = (a, b, c)
+ assert all(0 <= value <= 255 for value in abc), f"Values must be between 0 and 255. Received {abc}."
+
+ self._color_model = color_model
+ self._abc = abc
+
+ def __str__(self) -> str:
+ al, bl, cl = self._color_model
+ a, b, c = self._abc
+ return f""
+
+ def __repr__(self) -> str:
+ return str(self)
+
+ def __eq__(self, other: "Color") -> bool:
+ if not isinstance(other, Color):
+ return False
+ return self.rgb_hex_str == other.rgb_hex_str
+
+ def calculate_distance_factor(self, other: "Color") -> float:
+ # Returns a number between 0 and 1, where 0 means the colors are identical
+ assert isinstance(other, Color), f"Expected a Color instance, got {type(other)}."
+ diff = (np.array(self.rgb) - np.array(other.rgb)) / 255
+ return float(np.linalg.norm(x=diff, ord=2) / np.sqrt(3))
+
+ @cached_property
+ def rgb(self) -> tuple[int, int, int]:
+ if self._color_model == "rgb":
+ r, g, b = self._abc
+ return r, g, b
+ if self._color_model == "hsv":
+ h, s, v = self._abc
+ r, g, b = colorsys.hsv_to_rgb(h / 255, s / 255, v / 255)
+ else:
+ h, l, s = self._abc
+ r, g, b = colorsys.hls_to_rgb(h / 255, l / 255, s / 255)
+ r, g, b = round(r * 255), round(g * 255), round(b * 255)
+ return r, g, b
+
+ @cached_property
+ def hsv(self) -> tuple[int, int, int]:
+ if self._color_model == "hsv":
+ h, s, v = self._abc
+ else:
+ r, g, b = self.rgb
+ h, s, v = colorsys.rgb_to_hsv(r / 255, g / 255, b / 255)
+ h, s, v = round(h * 255), round(s * 255), round(v * 255)
+ return h, s, v
+
+ @cached_property
+ def hls(self) -> tuple[int, int, int]:
+ if self._color_model == "hls":
+ h, l, s = self._abc
+ else:
+ r, g, b = self.rgb
+ h, l, s = colorsys.rgb_to_hls(r / 255, g / 255, b / 255)
+ h, l, s = round(h * 255), round(l * 255), round(s * 255)
+ return h, l, s
+
+ @cached_property
+ def rgb_hex_str(self) -> str:
+ r, g, b = self.rgb
+ return f"#{r:02X}{g:02X}{b:02X}"
+
+ @cached_property
+ def brightness_factor(self) -> float:
+ # A number between 0 and 1, where 0 is black and 1 is white.
+ return self.hls[1] / Color("white").hls[1]
+
+
+def get_contrasting_color(color: Color) -> Color:
+ if color.brightness_factor < 0.5:
+ return Color("#FFFFFF")
+ else:
+ return Color("#000000")
diff --git a/src/models/player.py b/src/models/player.py
index adeb98d..250f056 100644
--- a/src/models/player.py
+++ b/src/models/player.py
@@ -1,6 +1,8 @@
from dataclasses import dataclass
+from functools import cached_property
from typing import Self, Callable
+from src.models.colors import Color
from src.models.data.ldc_repo import LdcRepo
from src.models.data.light_dc import LightDc
from src.models.ids import PlayerId
@@ -19,6 +21,11 @@ def __post_init__(self) -> None:
self.color.startswith("#") and len(self.color) == 7 and int(self.color[1:], 16) < 0xFFFFFF
), "Invalid color format"
+ @cached_property
+ def color_obj(self) -> Color:
+ # TODO change the main color property to a color object
+ return Color(x=self.color)
+
class PlayerRepo(LdcRepo[Player]):
@classmethod
diff --git a/src/models/transmission.py b/src/models/transmission.py
index 6b0b8d2..7ae6bc7 100644
--- a/src/models/transmission.py
+++ b/src/models/transmission.py
@@ -17,6 +17,7 @@ class TransmissionInfo(LightDc):
operating_cost: float = 0.0
is_for_sale: bool = False
purchase_cost: float = 0.0 # 0 = Not for sale
+ is_active: bool = True
def __post_init__(self) -> None:
assert self.bus2 > self.bus1, f"bus2 must be greater than bus1. Got {self.bus2} and {self.bus1}"
diff --git a/tests/test_models/test_colors.py b/tests/test_models/test_colors.py
new file mode 100644
index 0000000..044084d
--- /dev/null
+++ b/tests/test_models/test_colors.py
@@ -0,0 +1,68 @@
+from unittest import TestCase
+
+from src.models.colors import Color
+
+
+class TestColors(TestCase):
+ def test_making_colours(self) -> None:
+ red = Color("red")
+ self.assertEqual(red.rgb_hex_str, "#FF0000")
+ self.assertEqual(red.rgb, (255, 0, 0))
+ self.assertEqual(red.hsv, (0, 255, 255))
+
+ other_reds = [
+ Color((255, 0, 0)),
+ Color(x=(0, 255, 255), color_model="hsv"),
+ Color(x="#00FFFF", color_model="hsv"),
+ ]
+ for r in other_reds:
+ self.assertEqual(red, r)
+
+ black = Color("black")
+ self.assertEqual(black.rgb_hex_str, "#000000")
+ self.assertEqual(black.rgb, (0, 0, 0))
+ self.assertEqual(black.hsv, (0, 0, 0))
+
+ white = Color("white")
+ self.assertEqual(white.rgb_hex_str, "#FFFFFF")
+ self.assertEqual(white.rgb, (255, 255, 255))
+ self.assertEqual(white.hsv, (0, 0, 255))
+
+ def test_brightness_factor(self) -> None:
+ black = Color("black")
+ gray = Color("gray")
+ white = Color("white")
+ red = Color("red")
+ green = Color("green")
+ blue = Color("blue")
+
+ self.assertEqual(black.brightness_factor, 0.0)
+ self.assertAlmostEqual(gray.brightness_factor, 0.5, places=2)
+ self.assertEqual(white.brightness_factor, 1.0)
+
+ for color in [red, green, blue]:
+ self.assertAlmostEqual(color.brightness_factor, 0.5, places=2)
+
+ def test_color_distance(self) -> None:
+ red = Color("red")
+ blue = Color("blue")
+ green = Color("green")
+ black = Color("black")
+ white = Color("white")
+
+ self.assertEqual(black.calculate_distance_factor(other=black), 0.0)
+ self.assertEqual(black.calculate_distance_factor(other=white), 1.0)
+
+ mid_distance = black.calculate_distance_factor(other=red)
+ self.assertTrue(0.4 < mid_distance < 0.6, f"Expected mid distance to be around 0.5, got {mid_distance}")
+ self.assertEqual(black.calculate_distance_factor(other=green), mid_distance)
+ self.assertEqual(black.calculate_distance_factor(other=blue), mid_distance)
+
+ other_distance = white.calculate_distance_factor(other=red)
+ self.assertTrue(0.7 < other_distance < 0.9, f"Expected other distance to be around 0.8, got {other_distance}")
+ self.assertEqual(white.calculate_distance_factor(other=green), other_distance)
+ self.assertEqual(white.calculate_distance_factor(other=blue), other_distance)
+
+ self.assertEqual(red.calculate_distance_factor(other=green), other_distance)
+ self.assertEqual(red.calculate_distance_factor(other=blue), other_distance)
+ self.assertEqual(green.calculate_distance_factor(other=blue), other_distance)
diff --git a/tests/utils/repo_maker.py b/tests/utils/repo_maker.py
index 9a9aaae..2437ecf 100644
--- a/tests/utils/repo_maker.py
+++ b/tests/utils/repo_maker.py
@@ -6,6 +6,7 @@
from src.models.assets import AssetRepo, AssetInfo, AssetType
from src.models.buses import Bus, BusRepo
+from src.models.colors import Color
from src.models.data.ldc_repo import T_LdcRepo
from src.models.data.light_dc import T_LightDc
from src.models.ids import PlayerId, AssetId, BusId, TransmissionId
@@ -117,10 +118,15 @@ def make_quick(cls, n: int = 3) -> PlayerRepo:
def _make_dc(self) -> Player:
player_id = next(self.id_counter)
+ hue = np.random.randint(0, 255)
+ saturation = np.random.randint(200, 255)
+ value = 200
+
+ color = Color(x=(hue, saturation, value), color_model="hsv").rgb_hex_str
return Player(
id=PlayerId(player_id),
name=f"Player {player_id}",
- color=f"#{np.random.randint(0, 0xFFFFFF):06X}",
+ color=color,
money=float(np.random.rand() * 100), # Just an example of money
is_having_turn=False,
)
@@ -235,6 +241,7 @@ def _make_dc(
marginal_price=marginal_price,
bid_price=bid_price,
is_ice_cream=is_icecream,
+ is_active=np.random.rand() > 0.2,
)
def _get_repo_type(self) -> type[AssetRepo]:
@@ -288,6 +295,7 @@ def _make_dc(
operating_cost=float(np.random.rand() * 100),
is_for_sale=random_choice([True, False]),
purchase_cost=float(np.random.rand() * 1000) if random_choice([True, False]) else 0.0,
+ is_active=np.random.rand() > 0.2,
)
def _get_repo_type(self) -> type[TransmissionRepo]: