diff --git a/src/app/simple_front_end/plotting/base_plot_object.py b/src/app/simple_front_end/plotting/base_plot_object.py index 5e1aac0..d8af0df 100644 --- a/src/app/simple_front_end/plotting/base_plot_object.py +++ b/src/app/simple_front_end/plotting/base_plot_object.py @@ -5,6 +5,8 @@ import plotly.graph_objs as go from plotly.graph_objs import Scatter +from src.models.colors import Color + @dataclass(frozen=True) class Point: @@ -54,7 +56,7 @@ def centre(self) -> Point: @property @abstractmethod - def color(self) -> str: + def color(self) -> Color: pass @property @@ -71,6 +73,11 @@ def data_dict(self) -> dict[str, str]: def text_locations(self) -> list[Point]: return [self.centre] + @staticmethod + def deactivate_color(c: Color) -> Color: + h, s, v = c.hsv + return Color(x=(h, round(s / 2), round(v / 2)), color_model="hsv") + def render_hover_text(self) -> Scatter: hover_template = ( f"{self.title}

" @@ -82,7 +89,7 @@ def render_hover_text(self) -> Scatter: mode="markers", marker={ "size": 10, - "color": self.color, + "color": self.color.rgb_hex_str, "symbol": "circle", "line": {"width": 0.0}, "opacity": 0.0, # Make the marker invisible diff --git a/src/app/simple_front_end/plotting/colors.py b/src/app/simple_front_end/plotting/colors.py deleted file mode 100644 index 3836d43..0000000 --- a/src/app/simple_front_end/plotting/colors.py +++ /dev/null @@ -1,16 +0,0 @@ -def get_contrasting_color(color: str) -> str: - """ - Given a color in hex format, return a contrasting color (black or white). - """ - if not color.startswith("#") or len(color) != 7: - raise ValueError(f"Invalid color format: {color}. Expected hex format #RRGGBB.") - - r = int(color[1:3], 16) - g = int(color[3:5], 16) - b = int(color[5:7], 16) - - # Calculate the brightness of the color - brightness = (r * 299 + g * 587 + b * 114) / 1000 - - # Return black for light colors and white for dark colors - return "#000000" if brightness > 128 else "#FFFFFF" diff --git a/src/app/simple_front_end/plotting/po_asset.py b/src/app/simple_front_end/plotting/po_asset.py index 1f02221..02c2558 100644 --- a/src/app/simple_front_end/plotting/po_asset.py +++ b/src/app/simple_front_end/plotting/po_asset.py @@ -6,9 +6,9 @@ from plotly.graph_objs import Scatter from src.app.simple_front_end.plotting.base_plot_object import Point, PlotObject -from src.app.simple_front_end.plotting.colors import get_contrasting_color from src.app.simple_front_end.plotting.po_bus import PlotBus from src.models.assets import AssetInfo, AssetType +from src.models.colors import get_contrasting_color, Color from src.models.player import Player @@ -33,8 +33,8 @@ def title(self) -> str: return title @property - def color(self) -> str: - return self.owner.color + def color(self) -> Color: + return self.owner.color_obj @property def data_dict(self) -> dict[str, str]: @@ -67,15 +67,21 @@ def render_shape(self) -> Scatter: else: raise ValueError(f"Unknown asset type: {self.asset.asset_type}") + if self.asset.is_active: + color = self.color + else: + color = self.deactivate_color(self.color) + contrast_color = get_contrasting_color(color) + main = go.Scatter( x=x, y=y, mode="lines+text", text=[""] * (len(x) - 1) + [text], fill="toself", - fillcolor=self.color, + fillcolor=color.rgb_hex_str, line={"width": 0.0}, hoverinfo="skip", - textfont={"size": 10, "color": get_contrasting_color(self.color)}, + textfont={"size": 10, "color": contrast_color.rgb_hex_str}, ) return main diff --git a/src/app/simple_front_end/plotting/po_bus.py b/src/app/simple_front_end/plotting/po_bus.py index 9460d9e..edfd221 100644 --- a/src/app/simple_front_end/plotting/po_bus.py +++ b/src/app/simple_front_end/plotting/po_bus.py @@ -8,6 +8,7 @@ from src.app.simple_front_end.plotting.base_plot_object import Point, PlotObject from src.models.buses import Bus +from src.models.colors import Color from src.models.player import Player SocketSide = Literal["tr", "bl"] # Top Right or Bottom Left @@ -80,8 +81,8 @@ def title(self) -> str: return f"Bus{self.bus.id}" @property - def color(self) -> str: - return self.owner.color + def color(self) -> Color: + return self.owner.color_obj @property def data_dict(self) -> dict[str, str]: @@ -95,7 +96,7 @@ def render_shape(self) -> Scatter: x=[p.x for p in points], y=[p.y for p in points], fill="toself", - fillcolor=self.owner.color, + fillcolor=self.color.rgb_hex_str, line=dict(color="black", width=1), mode="lines", hoverinfo="skip", diff --git a/src/app/simple_front_end/plotting/po_line.py b/src/app/simple_front_end/plotting/po_line.py index 78b11a7..6d0d891 100644 --- a/src/app/simple_front_end/plotting/po_line.py +++ b/src/app/simple_front_end/plotting/po_line.py @@ -6,6 +6,7 @@ from src.app.simple_front_end.plotting.base_plot_object import Point, PlotObject, point_linspace from src.app.simple_front_end.plotting.po_bus import PlotBus +from src.models.colors import Color from src.models.player import Player from src.models.transmission import TransmissionInfo @@ -21,8 +22,8 @@ def title(self) -> str: return f"Line{self.line.id}" @property - def color(self) -> str: - return self.owner.color + def color(self) -> Color: + return self.owner.color_obj @property def data_dict(self) -> dict[str, str]: @@ -69,10 +70,16 @@ def text_locations(self) -> list[Point]: def render_shape(self) -> Scatter: points = self.vertices + + if self.line.is_active: + color = self.color + else: + color = self.deactivate_color(self.color) + scatter = go.Scatter( x=[p.x for p in points], y=[p.y for p in points], - line=dict(color=self.owner.color, width=3), + line=dict(color=color.rgb_hex_str, width=3), opacity=0.8, mode="lines", hoverinfo="skip", diff --git a/src/app/simple_front_end/plotting/po_player_legend.py b/src/app/simple_front_end/plotting/po_player_legend.py index a80bd88..da09062 100644 --- a/src/app/simple_front_end/plotting/po_player_legend.py +++ b/src/app/simple_front_end/plotting/po_player_legend.py @@ -1,14 +1,11 @@ from dataclasses import dataclass from functools import cached_property -from typing import Literal, Optional -import numpy as np import plotly.graph_objects as go from plotly.graph_objs import Scatter from src.app.simple_front_end.plotting.base_plot_object import Point, PlotObject -from src.app.simple_front_end.plotting.colors import get_contrasting_color -from src.models.buses import Bus +from src.models.colors import get_contrasting_color, Color from src.models.player import Player @@ -24,8 +21,8 @@ def title(self) -> str: return self.player.name @property - def color(self) -> str: - return self.player.color + def color(self) -> Color: + return self.player.color_obj @property def data_dict(self) -> dict[str, str]: @@ -42,9 +39,9 @@ def render_shape(self) -> Scatter: mode="lines+text", text=[""] * (len(points) - 1) + [self.player.name], fill="toself", - fillcolor=self.color, + fillcolor=self.color.rgb_hex_str, line=dict(color="black", width=0), - textfont={"size": 10, "color": get_contrasting_color(self.color)}, + textfont={"size": 10, "color": get_contrasting_color(self.color).rgb_hex_str}, hoverinfo="skip", ) return scatter diff --git a/src/models/assets.py b/src/models/assets.py index a62881d..dda7be9 100644 --- a/src/models/assets.py +++ b/src/models/assets.py @@ -27,6 +27,7 @@ class AssetInfo(LightDc): marginal_price: float = 0.0 bid_price: float = 0.0 is_ice_cream: bool = False # This is a special type of load + is_active: bool = True def __post_init__(self) -> None: if self.is_ice_cream: diff --git a/src/models/colors.py b/src/models/colors.py new file mode 100644 index 0000000..eb358f2 --- /dev/null +++ b/src/models/colors.py @@ -0,0 +1,127 @@ +import colorsys +from functools import cached_property +from typing import Union, Literal + +import numpy as np + + +class Color: + def __init__( + self, + x: Union[str, tuple[int, int, int], Literal["red", "green", "blue", "black", "gray", "white"]], + color_model: Literal["rgb", "hsv", "hls"] = "rgb", + ): + """ + You can create a color using a pre-defined color name, e.g. "red", "green", "blue", "black", "gray" "white". + >>> Color("red") + + Or pass an RGB hex string, e.g. "#FF5733". + >>> Color("#FF5733") + + You can pass a tuple of unsigned 8bit integers if you prefer + >>> Color((255, 87, 51)) + + The default model is RGB but if you want you can specify the color as hsv or hls + >>> Color((255, 87, 51), color_model="hsv") + """ + assert color_model in ["rgb", "hsv", "hls"], f"Invalid color model: {color_model}." + + expected_format = {"rgb": "#RRGGBB", "hsv": "#HHSSVV", "hls": "#HHLLSS"}[color_model] + + if isinstance(x, str): + if not x.startswith("#"): + x = { + "red": "#FF0000", + "blue": "#0000FF", + "green": "#00FF00", + "black": "#000000", + "gray": "#808080", + "white": "#FFFFFF", + }[x] + if not len(x) == 7: + raise ValueError(f"Invalid hex color format: {x}. Expected format {expected_format}.") + a = int(x[1:3], 16) + b = int(x[3:5], 16) + c = int(x[5:7], 16) + else: + assert len(x) == 3, f"{color_model.upper()} tuple must have exactly three elements. Received {x}." + assert all( + isinstance(value, int) for value in x + ), f"{color_model.upper()} values must be integers. Received {x}." + a, b, c = x + + abc = (a, b, c) + assert all(0 <= value <= 255 for value in abc), f"Values must be between 0 and 255. Received {abc}." + + self._color_model = color_model + self._abc = abc + + def __str__(self) -> str: + al, bl, cl = self._color_model + a, b, c = self._abc + return f"" + + def __repr__(self) -> str: + return str(self) + + def __eq__(self, other: "Color") -> bool: + if not isinstance(other, Color): + return False + return self.rgb_hex_str == other.rgb_hex_str + + def calculate_distance_factor(self, other: "Color") -> float: + # Returns a number between 0 and 1, where 0 means the colors are identical + assert isinstance(other, Color), f"Expected a Color instance, got {type(other)}." + diff = (np.array(self.rgb) - np.array(other.rgb)) / 255 + return float(np.linalg.norm(x=diff, ord=2) / np.sqrt(3)) + + @cached_property + def rgb(self) -> tuple[int, int, int]: + if self._color_model == "rgb": + r, g, b = self._abc + return r, g, b + if self._color_model == "hsv": + h, s, v = self._abc + r, g, b = colorsys.hsv_to_rgb(h / 255, s / 255, v / 255) + else: + h, l, s = self._abc + r, g, b = colorsys.hls_to_rgb(h / 255, l / 255, s / 255) + r, g, b = round(r * 255), round(g * 255), round(b * 255) + return r, g, b + + @cached_property + def hsv(self) -> tuple[int, int, int]: + if self._color_model == "hsv": + h, s, v = self._abc + else: + r, g, b = self.rgb + h, s, v = colorsys.rgb_to_hsv(r / 255, g / 255, b / 255) + h, s, v = round(h * 255), round(s * 255), round(v * 255) + return h, s, v + + @cached_property + def hls(self) -> tuple[int, int, int]: + if self._color_model == "hls": + h, l, s = self._abc + else: + r, g, b = self.rgb + h, l, s = colorsys.rgb_to_hls(r / 255, g / 255, b / 255) + h, l, s = round(h * 255), round(l * 255), round(s * 255) + return h, l, s + + @cached_property + def rgb_hex_str(self) -> str: + r, g, b = self.rgb + return f"#{r:02X}{g:02X}{b:02X}" + + @cached_property + def brightness_factor(self) -> float: + # A number between 0 and 1, where 0 is black and 1 is white. + return self.hls[1] / Color("white").hls[1] + + +def get_contrasting_color(color: Color) -> Color: + if color.brightness_factor < 0.5: + return Color("#FFFFFF") + else: + return Color("#000000") diff --git a/src/models/player.py b/src/models/player.py index adeb98d..250f056 100644 --- a/src/models/player.py +++ b/src/models/player.py @@ -1,6 +1,8 @@ from dataclasses import dataclass +from functools import cached_property from typing import Self, Callable +from src.models.colors import Color from src.models.data.ldc_repo import LdcRepo from src.models.data.light_dc import LightDc from src.models.ids import PlayerId @@ -19,6 +21,11 @@ def __post_init__(self) -> None: self.color.startswith("#") and len(self.color) == 7 and int(self.color[1:], 16) < 0xFFFFFF ), "Invalid color format" + @cached_property + def color_obj(self) -> Color: + # TODO change the main color property to a color object + return Color(x=self.color) + class PlayerRepo(LdcRepo[Player]): @classmethod diff --git a/src/models/transmission.py b/src/models/transmission.py index 6b0b8d2..7ae6bc7 100644 --- a/src/models/transmission.py +++ b/src/models/transmission.py @@ -17,6 +17,7 @@ class TransmissionInfo(LightDc): operating_cost: float = 0.0 is_for_sale: bool = False purchase_cost: float = 0.0 # 0 = Not for sale + is_active: bool = True def __post_init__(self) -> None: assert self.bus2 > self.bus1, f"bus2 must be greater than bus1. Got {self.bus2} and {self.bus1}" diff --git a/tests/test_models/test_colors.py b/tests/test_models/test_colors.py new file mode 100644 index 0000000..044084d --- /dev/null +++ b/tests/test_models/test_colors.py @@ -0,0 +1,68 @@ +from unittest import TestCase + +from src.models.colors import Color + + +class TestColors(TestCase): + def test_making_colours(self) -> None: + red = Color("red") + self.assertEqual(red.rgb_hex_str, "#FF0000") + self.assertEqual(red.rgb, (255, 0, 0)) + self.assertEqual(red.hsv, (0, 255, 255)) + + other_reds = [ + Color((255, 0, 0)), + Color(x=(0, 255, 255), color_model="hsv"), + Color(x="#00FFFF", color_model="hsv"), + ] + for r in other_reds: + self.assertEqual(red, r) + + black = Color("black") + self.assertEqual(black.rgb_hex_str, "#000000") + self.assertEqual(black.rgb, (0, 0, 0)) + self.assertEqual(black.hsv, (0, 0, 0)) + + white = Color("white") + self.assertEqual(white.rgb_hex_str, "#FFFFFF") + self.assertEqual(white.rgb, (255, 255, 255)) + self.assertEqual(white.hsv, (0, 0, 255)) + + def test_brightness_factor(self) -> None: + black = Color("black") + gray = Color("gray") + white = Color("white") + red = Color("red") + green = Color("green") + blue = Color("blue") + + self.assertEqual(black.brightness_factor, 0.0) + self.assertAlmostEqual(gray.brightness_factor, 0.5, places=2) + self.assertEqual(white.brightness_factor, 1.0) + + for color in [red, green, blue]: + self.assertAlmostEqual(color.brightness_factor, 0.5, places=2) + + def test_color_distance(self) -> None: + red = Color("red") + blue = Color("blue") + green = Color("green") + black = Color("black") + white = Color("white") + + self.assertEqual(black.calculate_distance_factor(other=black), 0.0) + self.assertEqual(black.calculate_distance_factor(other=white), 1.0) + + mid_distance = black.calculate_distance_factor(other=red) + self.assertTrue(0.4 < mid_distance < 0.6, f"Expected mid distance to be around 0.5, got {mid_distance}") + self.assertEqual(black.calculate_distance_factor(other=green), mid_distance) + self.assertEqual(black.calculate_distance_factor(other=blue), mid_distance) + + other_distance = white.calculate_distance_factor(other=red) + self.assertTrue(0.7 < other_distance < 0.9, f"Expected other distance to be around 0.8, got {other_distance}") + self.assertEqual(white.calculate_distance_factor(other=green), other_distance) + self.assertEqual(white.calculate_distance_factor(other=blue), other_distance) + + self.assertEqual(red.calculate_distance_factor(other=green), other_distance) + self.assertEqual(red.calculate_distance_factor(other=blue), other_distance) + self.assertEqual(green.calculate_distance_factor(other=blue), other_distance) diff --git a/tests/utils/repo_maker.py b/tests/utils/repo_maker.py index 9a9aaae..2437ecf 100644 --- a/tests/utils/repo_maker.py +++ b/tests/utils/repo_maker.py @@ -6,6 +6,7 @@ from src.models.assets import AssetRepo, AssetInfo, AssetType from src.models.buses import Bus, BusRepo +from src.models.colors import Color from src.models.data.ldc_repo import T_LdcRepo from src.models.data.light_dc import T_LightDc from src.models.ids import PlayerId, AssetId, BusId, TransmissionId @@ -117,10 +118,15 @@ def make_quick(cls, n: int = 3) -> PlayerRepo: def _make_dc(self) -> Player: player_id = next(self.id_counter) + hue = np.random.randint(0, 255) + saturation = np.random.randint(200, 255) + value = 200 + + color = Color(x=(hue, saturation, value), color_model="hsv").rgb_hex_str return Player( id=PlayerId(player_id), name=f"Player {player_id}", - color=f"#{np.random.randint(0, 0xFFFFFF):06X}", + color=color, money=float(np.random.rand() * 100), # Just an example of money is_having_turn=False, ) @@ -235,6 +241,7 @@ def _make_dc( marginal_price=marginal_price, bid_price=bid_price, is_ice_cream=is_icecream, + is_active=np.random.rand() > 0.2, ) def _get_repo_type(self) -> type[AssetRepo]: @@ -288,6 +295,7 @@ def _make_dc( operating_cost=float(np.random.rand() * 100), is_for_sale=random_choice([True, False]), purchase_cost=float(np.random.rand() * 1000) if random_choice([True, False]) else 0.0, + is_active=np.random.rand() > 0.2, ) def _get_repo_type(self) -> type[TransmissionRepo]: