From bf79e850039f72dbf393194409681ce3ff4d0802 Mon Sep 17 00:00:00 2001 From: v-heli1 Date: Wed, 15 Jun 2022 17:36:59 +0800 Subject: [PATCH 01/10] Add reset test in sc scene --- .../supply_chain/case_04/test_case_04.csv | 12 + ...consumer_unit.py => test_consumer_unit.py} | 0 ...tion_unit.py => test_distribution_unit.py} | 0 tests/supply_chain/test_env_reset.py | 448 ++++++++++++++++++ ...cture_unit.py => test_manufacture_unit.py} | 0 ...ly_chain_readdata.py => test_read_data.py} | 0 ...ain_seller_unit.py => test_seller_unit.py} | 2 +- ...chain_state_only.py => test_state_only.py} | 0 ...n_storage_unit.py => test_storage_unit.py} | 0 ...teraction.py => test_units_interaction.py} | 0 10 files changed, 461 insertions(+), 1 deletion(-) rename tests/supply_chain/{test_supply_chain_consumer_unit.py => test_consumer_unit.py} (100%) rename tests/supply_chain/{test_supply_chain_distribution_unit.py => test_distribution_unit.py} (100%) create mode 100644 tests/supply_chain/test_env_reset.py rename tests/supply_chain/{test_supply_chain_manufacture_unit.py => test_manufacture_unit.py} (100%) rename tests/supply_chain/{test_supply_chain_readdata.py => test_read_data.py} (100%) rename tests/supply_chain/{test_supply_chain_seller_unit.py => test_seller_unit.py} (99%) rename tests/supply_chain/{test_supply_chain_state_only.py => test_state_only.py} (100%) rename tests/supply_chain/{test_supply_chain_storage_unit.py => test_storage_unit.py} (100%) rename tests/supply_chain/{test_supply_chain_units_interaction.py => test_units_interaction.py} (100%) diff --git a/tests/data/supply_chain/case_04/test_case_04.csv b/tests/data/supply_chain/case_04/test_case_04.csv index 8f77f53c7..acbc2b922 100644 --- a/tests/data/supply_chain/case_04/test_case_04.csv +++ b/tests/data/supply_chain/case_04/test_case_04.csv @@ -4,11 +4,23 @@ food_1,2021/1/2,43.1,33.39,20 food_1,2021/1/3,43.2,33.39,30 food_1,2021/1/4,43.3,33.39,40 food_1,2021/1/5,43.4,33.39,50 +food_1,2021/1/6,43.4,33.39,60 +food_1,2021/1/7,43.4,33.39,70 +food_1,2021/1/8,43.4,33.39,80 +food_1,2021/1/9,43.4,33.39,90 +food_1,2021/1/10,43.4,33.39,100 +food_1,2021/1/11,43.4,33.39,110 hobby_1,2021/1/1,28.32,21.79,100 hobby_1,2021/1/2,28.32,21.79,200 hobby_1,2021/1/3,28.32,21.79,300 hobby_1,2021/1/4,28.32,21.79,400 hobby_1,2021/1/5,28.32,21.79,500 +hobby_1,2021/1/6,28.32,21.79,600 +hobby_1,2021/1/7,28.32,21.79,700 +hobby_1,2021/1/8,28.32,21.79,800 +hobby_1,2021/1/9,28.32,21.79,900 +hobby_1,2021/1/10,28.32,21.79,1000 +hobby_1,2021/1/11,28.32,21.79,1100 household_1,2022/5/14,17.35974446,14.4,146 household_1,2022/5/15,17.29592816,14.4,164 household_1,2022/5/16,17.43129521,14.4,127 diff --git a/tests/supply_chain/test_supply_chain_consumer_unit.py b/tests/supply_chain/test_consumer_unit.py similarity index 100% rename from tests/supply_chain/test_supply_chain_consumer_unit.py rename to tests/supply_chain/test_consumer_unit.py diff --git a/tests/supply_chain/test_supply_chain_distribution_unit.py b/tests/supply_chain/test_distribution_unit.py similarity index 100% rename from tests/supply_chain/test_supply_chain_distribution_unit.py rename to tests/supply_chain/test_distribution_unit.py diff --git a/tests/supply_chain/test_env_reset.py b/tests/supply_chain/test_env_reset.py new file mode 100644 index 000000000..0b9959323 --- /dev/null +++ b/tests/supply_chain/test_env_reset.py @@ -0,0 +1,448 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license +import random +import unittest +from collections import defaultdict +from typing import Dict, List + +import numpy as np + +from maro.simulator.scenarios.supply_chain import FacilityBase, ConsumerAction, ManufactureAction, StorageUnit +from maro.simulator.scenarios.supply_chain.business_engine import SupplyChainBusinessEngine +from maro.simulator.scenarios.supply_chain.order import Order + +from tests.supply_chain.common import build_env, SKU3_ID, FOOD_1_ID, get_product_dict_from_storage + + +class MyTestCase(unittest.TestCase): + """ + . consumer unit test + . distribution unit test + . manufacture unit test + . seller unit test + . storage unit test + """ + + def test_consumer_unit_reset(self) -> None: + """Test whether reset updates the consumer unit completely""" + env = build_env("case_01", 500) + be = env.business_engine + assert isinstance(be, SupplyChainBusinessEngine) + + env.step(None) + + supplier_1: FacilityBase = be.world._get_facility_by_name("Supplier_SKU1") + supplier_3: FacilityBase = be.world._get_facility_by_name("Supplier_SKU3") + sku3_consumer_unit = supplier_1.products[SKU3_ID].consumer + + consumer_node_index = sku3_consumer_unit.data_model_index + + features = ("id", "facility_id", "sku_id", "order_base_cost", "purchased", "received", "order_product_cost", + "latest_consumptions", "in_transit_quantity") + + # ##################################### Before reset ##################################### + consumer_nodes = env.snapshot_list["consumer"] + action = ConsumerAction(sku3_consumer_unit.id, SKU3_ID, supplier_3.id, 1, "train") + expect_tick = 100 + + # Save the env.metric of each tick into env_metric_1 + env_metric_1: Dict[int, dict] = defaultdict(dict) + + # Store the information about the snapshot consumer unit of each tick in states_1 + states_1: Dict[int, list] = defaultdict(list) + + for i in range(expect_tick): + env.step([action]) + env_metric_1[i] = env.metrics + states_1[i] = consumer_nodes[i:consumer_node_index:features].flatten().astype(np.int) + + # ############### Test whether reset updates the consumer unit completely ################ + env.reset() + env.step(None) + + # snapshot should reset after env.reset() + states = consumer_nodes[1:consumer_node_index:features].flatten().astype(np.int) + self.assertEqual([0, 0, 0, 0, 0, 0, 0, 0, 0], list(states)) + + expect_tick = 100 + + # Save the env.metric of each tick into env_metric_2 + env_metric_2: Dict[int, dict] = defaultdict(dict) + + # Store the information about the snapshot consumer unit of each tick in states_2 + states_2: Dict[int, list] = defaultdict(list) + for i in range(expect_tick): + env.step([action]) + env_metric_2[i] = env.metrics + states_2[i] = consumer_nodes[i:consumer_node_index:features].flatten().astype(np.int) + + expect_tick = 100 + for i in range(expect_tick): + self.assertEqual(list(states_1[i]), list(states_2[i])) + self.assertEqual(list(env_metric_1[i].values()), list(env_metric_2[i].values())) + + def test_distribution_unit_reset(self) -> None: + """Test initial state of the DistributionUnit of Supplier_SKU3.Test distribution unit reset""" + env = build_env("case_02", 100) + be = env.business_engine + assert isinstance(be, SupplyChainBusinessEngine) + + env.step(None) + + supplier_3 = be.world._get_facility_by_name("Supplier_SKU3") + warehouse_1 = be.world._get_facility_by_name("Warehouse_001") + + distribution_unit = supplier_3.distribution + distribution_node_index = distribution_unit.data_model_index + distribution_nodes = env.snapshot_list["distribution"] + + features = ("id", "facility_id", "pending_order_number", "pending_product_quantity") + + # ##################################### Before reset ##################################### + + order_1 = Order(src_facility=supplier_3, + dest_facility=warehouse_1, + sku_id=SKU3_ID, + quantity=10, + vehicle_type="train", + creation_tick=env.tick, + expected_finish_tick=env.tick + 7, ) + + # There are 2 "train" in total, and 1 left after scheduling this order. + distribution_unit.place_order(order_1) + distribution_unit.try_schedule_orders(env.tick) + self.assertEqual(0, len(distribution_unit._order_queues["train"])) + self.assertEqual(0, sum([order.quantity for order in distribution_unit._order_queues["train"]])) + + order_2 = Order( + src_facility=supplier_3, + dest_facility=warehouse_1, + sku_id=SKU3_ID, + quantity=10, + vehicle_type="train", + creation_tick=env.tick, + expected_finish_tick=env.tick + 7, + ) + + distribution_unit.place_order(order_2) + distribution_unit.try_schedule_orders(env.tick) + self.assertEqual(0, len(distribution_unit._order_queues["train"])) + self.assertEqual(0, sum([order.quantity for order in distribution_unit._order_queues["train"]])) + + # 3rd order, will cause the pending order increase + order_3 = Order( + src_facility=supplier_3, + dest_facility=warehouse_1, + sku_id=SKU3_ID, + quantity=10, + vehicle_type="train", + creation_tick=env.tick, + expected_finish_tick=env.tick + 7, + ) + distribution_unit.place_order(order_3) + distribution_unit.try_schedule_orders(env.tick) + self.assertEqual(1, len(distribution_unit._order_queues["train"])) + self.assertEqual(10, sum([order.required_quantity for order in distribution_unit._order_queues["train"]])) + + env.step(None) + + # The purpose is to randomly perform the order operation + random_tick: List[int] = [] + for j in range(10): + random_tick.append(random.randint(5, 100)) + + expect_tick = 100 + + # Save the env.metric of each tick into env_metric_1 + env_metric_1: Dict[int, dict] = defaultdict(dict) + + # Store the information about the snapshot distribution unit of each tick in states_1 + states_1: Dict[int, list] = defaultdict(list) + + for i in range(expect_tick): + if i in random_tick: + order = Order( + src_facility=supplier_3, + dest_facility=warehouse_1, + sku_id=SKU3_ID, + quantity=10, + vehicle_type="train", + creation_tick=env.tick, + expected_finish_tick=env.tick + 7, + ) + distribution_unit.place_order(order) + distribution_unit.try_schedule_orders(env.tick) + env.step(None) + env_metric_1[i] = env.metrics + states_1[i] = distribution_nodes[i:distribution_node_index:features].flatten().astype(np.int) + + # ####################### Test whether reset updates the distribution unit completely ################ + env.reset() + env.step(None) + + distribution_nodes = env.snapshot_list["distribution"] + + # snapshot should reset after env.reset(). + states = distribution_nodes[1:distribution_node_index:features].flatten().astype(np.int) + self.assertEqual([0, 0, 0, 0], list(states)) + + # Do the same as before env.reset(). + distribution_unit.place_order(order_1) + distribution_unit.try_schedule_orders(env.tick) + + distribution_unit.place_order(order_2) + distribution_unit.try_schedule_orders(env.tick) + self.assertEqual(0, len(distribution_unit._order_queues["train"])) + self.assertEqual(0, sum([order.quantity for order in distribution_unit._order_queues["train"]])) + + distribution_unit.place_order(order_3) + distribution_unit.try_schedule_orders(env.tick) + self.assertEqual(1, len(distribution_unit._order_queues["train"])) + self.assertEqual(10, sum([order.required_quantity for order in distribution_unit._order_queues["train"]])) + + env.step(None) + + expect_tick = 100 + # Save the env.metric of each tick into env_metric_2. + env_metric_2: Dict[int, dict] = defaultdict(dict) + + # Store the information about the snapshot distribution unit of each tick in states_2. + states_2: Dict[int, list] = defaultdict(list) + + for i in range(expect_tick): + if i in random_tick: + order = Order( + src_facility=supplier_3, + dest_facility=warehouse_1, + sku_id=SKU3_ID, + quantity=10, + vehicle_type="train", + creation_tick=env.tick, + expected_finish_tick=env.tick + 7, + ) + distribution_unit.place_order(order) + distribution_unit.try_schedule_orders(env.tick) + env.step(None) + env_metric_2[i] = env.metrics + states_2[i] = distribution_nodes[i:distribution_node_index:features].flatten().astype(np.int) + + expect_tick = 100 + for i in range(expect_tick): + self.assertEqual(list(states_1[i]), list(states_2[i])) + self.assertEqual(list(env_metric_1[i].values()), list(env_metric_2[i].values())) + + def test_manufacture_unit_reset(self) -> None: + """Test sku3 manufacturing. -- Supplier_SKU3.Test manufacture unit reset""" + env = build_env("case_01", 100) + be = env.business_engine + assert isinstance(be, SupplyChainBusinessEngine) + + supplier_3: FacilityBase = be.world._get_facility_by_name("Supplier_SKU3") + sku3_storage_index = supplier_3.storage.data_model_index + manufacture_sku3_unit = supplier_3.products[SKU3_ID].manufacture + sku3_manufacture_index = manufacture_sku3_unit.data_model_index + + storage_nodes = env.snapshot_list["storage"] + + manufacture_features = ( + "id", "facility_id", "start_manufacture_quantity", "sku_id", "in_pipeline_quantity", "finished_quantity", + "product_unit_id", + ) + + # ############################### TICK: 0 ###################################### + + # tick 0 passed, no product manufacturing. + env.step(None) + + capacities = storage_nodes[env.frame_index:sku3_storage_index:"capacity"].flatten().astype(np.int) + remaining_spaces = storage_nodes[env.frame_index:sku3_storage_index:"remaining_space"].flatten().astype(np.int) + + # there should be 80 units been taken at the beginning according to the config file. + # so remaining space should be 20 + self.assertEqual(20, remaining_spaces.sum()) + # capacity is 100 by config + self.assertEqual(100, capacities.sum()) + + product_dict = get_product_dict_from_storage(env, env.frame_index, sku3_storage_index) + + # The product quantity should be same as configuration at beginning. + # 80 sku3 + self.assertEqual(80, product_dict[SKU3_ID]) + + # all the id is greater than 0 + self.assertGreater(manufacture_sku3_unit.id, 0) + + # ###################################################################### + + # pass an action to start manufacturing for this tick. + action = ManufactureAction(manufacture_sku3_unit.id, 1) + + expect_tick = 30 + env_metric_1: Dict[int, dict] = defaultdict(dict) + states_1: Dict[int, list] = defaultdict(list) + random_tick: List[int] = [] + manufacture_nodes = env.snapshot_list["manufacture"] + + for i in range(10): + random_tick.append(random.randint(1, 30)) + + for i in range(expect_tick): + env.step([action]) + if i in random_tick: + env.step([ManufactureAction(manufacture_sku3_unit.id, 0)]) + env_metric_1[i] = env.metrics + states_1[i] = manufacture_nodes[i:sku3_manufacture_index:manufacture_features].flatten().astype(np.int) + + # ############################### Test whether reset updates the distribution unit completely ################ + env.reset() + env.step(None) + + states = manufacture_nodes[1:sku3_manufacture_index:manufacture_features].flatten().astype(np.int) + self.assertEqual([0, 0, 0, 0, 0, 0, 0], list(states)) + + capacities = storage_nodes[env.frame_index:sku3_storage_index:"capacity"].flatten().astype(np.int) + remaining_spaces = storage_nodes[env.frame_index:sku3_storage_index:"remaining_space"].flatten().astype(np.int) + + # there should be 80 units been taken at the beginning according to the config file. + # so remaining space should be 20 + self.assertEqual(20, remaining_spaces.sum()) + # capacity is 100 by config + self.assertEqual(100, capacities.sum()) + + product_dict = get_product_dict_from_storage(env, env.frame_index, sku3_storage_index) + + # The product quantity should be same as configuration at beginning. + # 80 sku3 + self.assertEqual(80, product_dict[SKU3_ID]) + + # all the id is greater than 0 + self.assertGreater(manufacture_sku3_unit.id, 0) + + expect_tick = 30 + env_metric_2: Dict[int, dict] = defaultdict(dict) + states_2: Dict[int, list] = defaultdict(list) + manufacture_nodes = env.snapshot_list["manufacture"] + + for i in range(expect_tick): + env.step([action]) + if i in random_tick: + env.step([ManufactureAction(manufacture_sku3_unit.id, 0)]) + env_metric_2[i] = env.metrics + states_2[i] = manufacture_nodes[i:sku3_manufacture_index:manufacture_features].flatten().astype(np.int) + + expect_tick = 30 + for i in range(expect_tick): + self.assertEqual(list(states_1[i]), list(states_2[i])) + self.assertEqual(list(env_metric_1[i].values()), list(env_metric_2[i].values())) + + def test_seller_unit_dynamics_sampler(self): + """Tested the store_001 Interaction between seller unit and dynamics csv data. + The data file of this test is test_case_ 04.csv""" + env = build_env("case_04", 600) + be = env.business_engine + assert isinstance(be, SupplyChainBusinessEngine) + + env.step(None) + Store_001: FacilityBase = be.world._get_facility_by_name("Store_001") + seller_unit = Store_001.products[FOOD_1_ID].seller + + seller_node_index = seller_unit.data_model_index + + seller_nodes = env.snapshot_list["seller"] + + features = ("sold", "demand", "total_sold", "total_demand", "backlog_ratio", "facility_id", "product_unit_id",) + + self.assertEqual(20, seller_unit.sku_id) + + # NOTE: this simple seller unit return demands that same as current tick + + # Tick 0 will have demand == 10.first row of data after preprocessing data. + # from sample_preprocessed.csv + self.assertEqual(10, seller_unit._sold) + self.assertEqual(10, seller_unit._demand) + self.assertEqual(10, seller_unit._total_sold) + + expect_tick = 12 + env_metric_1: Dict[int, dict] = defaultdict(dict) + states_1: Dict[int, list] = defaultdict(list) + for i in range(expect_tick): + env.step(None) + env_metric_1[i] = env.metrics + states_1[i] = seller_nodes[i:seller_node_index:features].flatten().astype(np.int) + + # ############################### Test whether reset updates the distribution unit completely ################ + env.reset() + env.step(None) + states = seller_nodes[1:seller_node_index:features].flatten().astype(np.int) + self.assertEqual([0, 0, 0, 0, 0, 0, 0], list(states)) + + expect_tick = 12 + + env_metric_2: Dict[int, dict] = defaultdict(dict) + states_2: Dict[int, list] = defaultdict(list) + for i in range(expect_tick): + env.step(None) + env_metric_2[i] = env.metrics + states_2[i] = seller_nodes[i:seller_node_index:features].flatten().astype(np.int) + + for i in range(expect_tick): + self.assertEqual(list(states_1[i]), list(states_2[i])) + self.assertEqual(list(env_metric_1[i].values()), list(env_metric_2[i].values())) + + def test_storage_unit_reset(self) -> None: + """Facility with single SKU. -- Supplier_SKU3""" + env = build_env("case_01", 100) + be = env.business_engine + assert isinstance(be, SupplyChainBusinessEngine) + + env.step(None) + + supplier_3: FacilityBase = be.world._get_facility_by_name("Supplier_SKU3") + storage_unit: StorageUnit = supplier_3.storage + storage_node_index = storage_unit.data_model_index + + storage_nodes = env.snapshot_list["storage"] + + features = ("id", "facility_id",) + + # ############################### Take more than existing ###################################### + + # which this setting, it will return false, as no enough product for ous + + expect_tick = 10 + env_metric_1: Dict[int, dict] = defaultdict(dict) + states_1: Dict[int, list] = defaultdict(list) + for i in range(expect_tick): + env.step(None) + env_metric_1[i] = env.metrics + states_1[i] = list(storage_nodes[i:storage_node_index:features].flatten().astype(np.int)) + states_1[i].append(storage_nodes[i:storage_node_index:"product_id_list"].flatten().astype(np.int).sum()) + states_1[i].append(storage_nodes[i:storage_node_index:"product_quantity"].flatten().astype(np.int).sum()) + states_1[i].append(storage_nodes[i:storage_node_index:"remaining_space"].flatten().astype(np.int).sum()) + + # ############################### Test whether reset updates the distribution unit completely ################ + env.reset() + env.step(None) + + states = storage_nodes[1:storage_node_index:features].flatten().astype(np.int) + self.assertEqual([0, 0], list(states)) + + expect_tick = 10 + + env_metric_2: Dict[int, dict] = defaultdict(dict) + states_2: Dict[int, list] = defaultdict(list) + for i in range(expect_tick): + env.step(None) + env_metric_2[i] = env.metrics + states_2[i] = list(storage_nodes[i:storage_node_index:features].flatten().astype(np.int)) + states_2[i].append(storage_nodes[i:storage_node_index:"product_id_list"].flatten().astype(np.int).sum()) + states_2[i].append(storage_nodes[i:storage_node_index:"product_quantity"].flatten().astype(np.int).sum()) + states_2[i].append(storage_nodes[i:storage_node_index:"remaining_space"].flatten().astype(np.int).sum()) + + for i in range(expect_tick): + self.assertEqual(list(states_1[i]), list(states_2[i])) + self.assertEqual(list(env_metric_1[i].values()), list(env_metric_2[i].values())) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/supply_chain/test_supply_chain_manufacture_unit.py b/tests/supply_chain/test_manufacture_unit.py similarity index 100% rename from tests/supply_chain/test_supply_chain_manufacture_unit.py rename to tests/supply_chain/test_manufacture_unit.py diff --git a/tests/supply_chain/test_supply_chain_readdata.py b/tests/supply_chain/test_read_data.py similarity index 100% rename from tests/supply_chain/test_supply_chain_readdata.py rename to tests/supply_chain/test_read_data.py diff --git a/tests/supply_chain/test_supply_chain_seller_unit.py b/tests/supply_chain/test_seller_unit.py similarity index 99% rename from tests/supply_chain/test_supply_chain_seller_unit.py rename to tests/supply_chain/test_seller_unit.py index f74e4a320..3a78e1a52 100644 --- a/tests/supply_chain/test_supply_chain_seller_unit.py +++ b/tests/supply_chain/test_seller_unit.py @@ -192,7 +192,7 @@ def test_seller_unit_dynamics_sampler(self): # NOTE: this simple seller unit return demands that same as current tick - # Tick 0 will have demand == 25.first row of data after preprocessing data. + # Tick 0 will have demand == 10.first row of data after preprocessing data. # from sample_preprocessed.csv self.assertEqual(10, seller_unit._sold) self.assertEqual(10, seller_unit._demand) diff --git a/tests/supply_chain/test_supply_chain_state_only.py b/tests/supply_chain/test_state_only.py similarity index 100% rename from tests/supply_chain/test_supply_chain_state_only.py rename to tests/supply_chain/test_state_only.py diff --git a/tests/supply_chain/test_supply_chain_storage_unit.py b/tests/supply_chain/test_storage_unit.py similarity index 100% rename from tests/supply_chain/test_supply_chain_storage_unit.py rename to tests/supply_chain/test_storage_unit.py diff --git a/tests/supply_chain/test_supply_chain_units_interaction.py b/tests/supply_chain/test_units_interaction.py similarity index 100% rename from tests/supply_chain/test_supply_chain_units_interaction.py rename to tests/supply_chain/test_units_interaction.py From c4068e375ece54b1f51adba122d06320b5649781 Mon Sep 17 00:00:00 2001 From: v-heli1 Date: Thu, 16 Jun 2022 11:55:43 +0800 Subject: [PATCH 02/10] Add reset test in sc scene and fix the problem found by reset test Add reset test in sc scene and fix the problem found by reset test --- .../supply_chain/datamodels/distribution.py | 3 + .../supply_chain/units/distribution.py | 1 + tests/supply_chain/test_env_reset.py | 78 ++++++++++++------- 3 files changed, 56 insertions(+), 26 deletions(-) diff --git a/maro/simulator/scenarios/supply_chain/datamodels/distribution.py b/maro/simulator/scenarios/supply_chain/datamodels/distribution.py index 051c458d3..c2f3d45c2 100644 --- a/maro/simulator/scenarios/supply_chain/datamodels/distribution.py +++ b/maro/simulator/scenarios/supply_chain/datamodels/distribution.py @@ -19,5 +19,8 @@ class DistributionDataModel(DataModelBase): def __init__(self) -> None: super(DistributionDataModel, self).__init__() + def initialize(self) -> None: + self.reset() + def reset(self) -> None: super(DistributionDataModel, self).reset() diff --git a/maro/simulator/scenarios/supply_chain/units/distribution.py b/maro/simulator/scenarios/supply_chain/units/distribution.py index 053a54ecf..2405d127a 100644 --- a/maro/simulator/scenarios/supply_chain/units/distribution.py +++ b/maro/simulator/scenarios/supply_chain/units/distribution.py @@ -100,6 +100,7 @@ def initialize(self) -> None: self._busy_vehicle_num[vehicle_type] = 0 # TODO: add vehicle patient setting if needed + self.data_model.initialize() for sku_id in self.facility.products.keys(): self._unit_delay_order_penalty[sku_id] = self.facility.skus[sku_id].unit_delay_order_penalty diff --git a/tests/supply_chain/test_env_reset.py b/tests/supply_chain/test_env_reset.py index 0b9959323..1cbffa1f5 100644 --- a/tests/supply_chain/test_env_reset.py +++ b/tests/supply_chain/test_env_reset.py @@ -100,13 +100,14 @@ def test_distribution_unit_reset(self) -> None: # ##################################### Before reset ##################################### - order_1 = Order(src_facility=supplier_3, - dest_facility=warehouse_1, - sku_id=SKU3_ID, - quantity=10, - vehicle_type="train", - creation_tick=env.tick, - expected_finish_tick=env.tick + 7, ) + order_1 = Order( + src_facility=supplier_3, + dest_facility=warehouse_1, + sku_id=SKU3_ID, + quantity=10, + vehicle_type="train", + creation_tick=env.tick, + expected_finish_tick=env.tick + 7, ) # There are 2 "train" in total, and 1 left after scheduling this order. distribution_unit.place_order(order_1) @@ -243,15 +244,15 @@ def test_manufacture_unit_reset(self) -> None: sku3_manufacture_index = manufacture_sku3_unit.data_model_index storage_nodes = env.snapshot_list["storage"] + manufacture_nodes = env.snapshot_list["manufacture"] manufacture_features = ( "id", "facility_id", "start_manufacture_quantity", "sku_id", "in_pipeline_quantity", "finished_quantity", "product_unit_id", ) - # ############################### TICK: 0 ###################################### + # ##################################### Before reset ##################################### - # tick 0 passed, no product manufacturing. env.step(None) capacities = storage_nodes[env.frame_index:sku3_storage_index:"capacity"].flatten().astype(np.int) @@ -272,17 +273,19 @@ def test_manufacture_unit_reset(self) -> None: # all the id is greater than 0 self.assertGreater(manufacture_sku3_unit.id, 0) - # ###################################################################### - - # pass an action to start manufacturing for this tick. action = ManufactureAction(manufacture_sku3_unit.id, 1) expect_tick = 30 + + # Save the env.metric of each tick into env_metric_1 env_metric_1: Dict[int, dict] = defaultdict(dict) + + # Store the information about the snapshot manufacture unit of each tick in states_1 states_1: Dict[int, list] = defaultdict(list) + random_tick: List[int] = [] - manufacture_nodes = env.snapshot_list["manufacture"] + # The purpose is to randomly perform the order operation for i in range(10): random_tick.append(random.randint(1, 30)) @@ -293,13 +296,17 @@ def test_manufacture_unit_reset(self) -> None: env_metric_1[i] = env.metrics states_1[i] = manufacture_nodes[i:sku3_manufacture_index:manufacture_features].flatten().astype(np.int) - # ############################### Test whether reset updates the distribution unit completely ################ + # ############################### Test whether reset updates the manufacture unit completely ################ env.reset() env.step(None) + # snapshot should reset after env.reset(). states = manufacture_nodes[1:sku3_manufacture_index:manufacture_features].flatten().astype(np.int) self.assertEqual([0, 0, 0, 0, 0, 0, 0], list(states)) + storage_nodes = env.snapshot_list["storage"] + manufacture_nodes = env.snapshot_list["manufacture"] + capacities = storage_nodes[env.frame_index:sku3_storage_index:"capacity"].flatten().astype(np.int) remaining_spaces = storage_nodes[env.frame_index:sku3_storage_index:"remaining_space"].flatten().astype(np.int) @@ -319,9 +326,12 @@ def test_manufacture_unit_reset(self) -> None: self.assertGreater(manufacture_sku3_unit.id, 0) expect_tick = 30 + + # Save the env.metric of each tick into env_metric_2 env_metric_2: Dict[int, dict] = defaultdict(dict) + + # Store the information about the snapshot manufacture unit of each tick in states_2 states_2: Dict[int, list] = defaultdict(list) - manufacture_nodes = env.snapshot_list["manufacture"] for i in range(expect_tick): env.step([action]) @@ -344,13 +354,13 @@ def test_seller_unit_dynamics_sampler(self): env.step(None) Store_001: FacilityBase = be.world._get_facility_by_name("Store_001") - seller_unit = Store_001.products[FOOD_1_ID].seller + seller_unit = Store_001.products[FOOD_1_ID].seller seller_node_index = seller_unit.data_model_index - seller_nodes = env.snapshot_list["seller"] - features = ("sold", "demand", "total_sold", "total_demand", "backlog_ratio", "facility_id", "product_unit_id",) + features = ("sold", "demand", "total_sold", "id", "total_demand", "backlog_ratio", "facility_id", "product_unit_id",) + # ##################################### Before reset ##################################### self.assertEqual(20, seller_unit.sku_id) @@ -363,23 +373,33 @@ def test_seller_unit_dynamics_sampler(self): self.assertEqual(10, seller_unit._total_sold) expect_tick = 12 + + # Save the env.metric of each tick into env_metric_1 env_metric_1: Dict[int, dict] = defaultdict(dict) + + # Store the information about the snapshot seller unit of each tick in states_1 states_1: Dict[int, list] = defaultdict(list) for i in range(expect_tick): env.step(None) env_metric_1[i] = env.metrics states_1[i] = seller_nodes[i:seller_node_index:features].flatten().astype(np.int) - # ############################### Test whether reset updates the distribution unit completely ################ + # ################# Test whether reset updates the seller unit completely ################ env.reset() env.step(None) + + # snapshot should reset after env.reset(). states = seller_nodes[1:seller_node_index:features].flatten().astype(np.int) - self.assertEqual([0, 0, 0, 0, 0, 0, 0], list(states)) + self.assertEqual([0, 0, 0, 0, 0, 0, 0, 0], list(states)) expect_tick = 12 + # Save the env.metric of each tick into env_metric_2 env_metric_2: Dict[int, dict] = defaultdict(dict) + + # Store the information about the snapshot seller unit of each tick in states_2 states_2: Dict[int, list] = defaultdict(list) + for i in range(expect_tick): env.step(None) env_metric_2[i] = env.metrics @@ -398,19 +418,20 @@ def test_storage_unit_reset(self) -> None: env.step(None) supplier_3: FacilityBase = be.world._get_facility_by_name("Supplier_SKU3") + storage_unit: StorageUnit = supplier_3.storage storage_node_index = storage_unit.data_model_index - storage_nodes = env.snapshot_list["storage"] - features = ("id", "facility_id",) - # ############################### Take more than existing ###################################### - - # which this setting, it will return false, as no enough product for ous + # ##################################### Before reset ##################################### expect_tick = 10 + + # Save the env.metric of each tick into env_metric_1 env_metric_1: Dict[int, dict] = defaultdict(dict) + + # Store the information about the snapshot storage unit of each tick in states_1 states_1: Dict[int, list] = defaultdict(list) for i in range(expect_tick): env.step(None) @@ -420,17 +441,22 @@ def test_storage_unit_reset(self) -> None: states_1[i].append(storage_nodes[i:storage_node_index:"product_quantity"].flatten().astype(np.int).sum()) states_1[i].append(storage_nodes[i:storage_node_index:"remaining_space"].flatten().astype(np.int).sum()) - # ############################### Test whether reset updates the distribution unit completely ################ + # ############################### Test whether reset updates the storage unit completely ################ env.reset() env.step(None) + # snapshot should reset after env.reset(). states = storage_nodes[1:storage_node_index:features].flatten().astype(np.int) self.assertEqual([0, 0], list(states)) expect_tick = 10 + # Save the env.metric of each tick into env_metric_2 env_metric_2: Dict[int, dict] = defaultdict(dict) + + # Store the information about the snapshot storage unit of each tick in states_2 states_2: Dict[int, list] = defaultdict(list) + for i in range(expect_tick): env.step(None) env_metric_2[i] = env.metrics From 08744df7b2b351f8ffa35d48b81a2a6c19bb0be0 Mon Sep 17 00:00:00 2001 From: v-heli1 Date: Tue, 21 Jun 2022 11:25:58 +0800 Subject: [PATCH 03/10] Modify test according to comments Modify test according to comments --- .../supply_chain/units/distribution.py | 1 + tests/data/supply_chain/case_01/config.yml | 2 + tests/data/supply_chain/case_05/config.yml | 25 +- tests/supply_chain/test_action_reset.py | 593 ++++++++++++++++++ tests/supply_chain/test_env_reset.py | 71 ++- tests/supply_chain/test_state_only.py | 4 +- 6 files changed, 672 insertions(+), 24 deletions(-) create mode 100644 tests/supply_chain/test_action_reset.py diff --git a/maro/simulator/scenarios/supply_chain/units/distribution.py b/maro/simulator/scenarios/supply_chain/units/distribution.py index 2405d127a..29ea8a263 100644 --- a/maro/simulator/scenarios/supply_chain/units/distribution.py +++ b/maro/simulator/scenarios/supply_chain/units/distribution.py @@ -100,6 +100,7 @@ def initialize(self) -> None: self._busy_vehicle_num[vehicle_type] = 0 # TODO: add vehicle patient setting if needed + self.data_model.initialize() for sku_id in self.facility.products.keys(): diff --git a/tests/data/supply_chain/case_01/config.yml b/tests/data/supply_chain/case_01/config.yml index 7730960e4..0d21ac8c5 100644 --- a/tests/data/supply_chain/case_01/config.yml +++ b/tests/data/supply_chain/case_01/config.yml @@ -88,6 +88,7 @@ world: sku3: init_stock: 80 has_manufacture: True + has_consumer: True max_manufacture_rate: 50 manufacture_leading_time: 1 unit_product_cost: 1 @@ -105,6 +106,7 @@ world: init_stock: 96 has_manufacture: True max_manufacture_rate: 50 + has_consumer: True manufacture_leading_time: 1 unit_product_cost: 1 price: 100 diff --git a/tests/data/supply_chain/case_05/config.yml b/tests/data/supply_chain/case_05/config.yml index 93f2fdf51..93c663234 100644 --- a/tests/data/supply_chain/case_05/config.yml +++ b/tests/data/supply_chain/case_05/config.yml @@ -146,6 +146,25 @@ world: unit_delay_order_penalty: 20 unit_order_cost: 0 + - name: "Supplier_SKU1" + definition_ref: "SupplierFacility" + skus: + sku3: + init_stock: 80 + has_manufacture: True + max_manufacture_rate: 50 + manufacture_leading_time: 1 + unit_product_cost: 1 + price: 10 + unit_delay_order_penalty: 10 + has_consumer: True + children: + storage: *small_storage + distribution: *normal_distribution + config: + unit_delay_order_penalty: 20 + unit_order_cost: 0 + - name: "Warehouse_001" definition_ref: "WarehouseFacility" skus: @@ -232,7 +251,7 @@ world: storage: *single_storage config: unit_order_cost: 200 - file_path: "tests/data/supply_chain/case_05/test_case_05.csv" + file_path: "tests/data/supply_chain/case_04/test_case_04.csv" topology: @@ -258,6 +277,10 @@ world: Warehouse_001: sku3: + "Supplier_SKU1": + "train": + vlt: 3 + cost: 1 "Supplier_SKU3": "train": vlt: 3 diff --git a/tests/supply_chain/test_action_reset.py b/tests/supply_chain/test_action_reset.py new file mode 100644 index 000000000..d87c6b914 --- /dev/null +++ b/tests/supply_chain/test_action_reset.py @@ -0,0 +1,593 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +import random +import unittest +from collections import defaultdict +from typing import Dict, List + +import numpy as np + +from maro.simulator.scenarios.supply_chain import ( + ConsumerAction, + ConsumerUnit, + FacilityBase, + ManufactureAction, + ManufactureUnit, + StorageUnit, +) +from maro.simulator.scenarios.supply_chain.business_engine import SupplyChainBusinessEngine +from maro.simulator.scenarios.supply_chain.order import Order + +from tests.supply_chain.common import SKU1_ID, SKU3_ID, build_env, get_product_dict_from_storage + + +class MyTestCase(unittest.TestCase): + """ + . consumer unit test + . distribution unit test + . manufacture unit test + . seller unit test + . storage unit test + """ + + def test_env_reset_with_none_action(self) -> None: + """test_env_reset_with_none_action""" + env = build_env("case_05", 500) + be = env.business_engine + assert isinstance(be, SupplyChainBusinessEngine) + + env.step(None) + + supplier_1: FacilityBase = be.world._get_facility_by_name("Supplier_SKU1") + warehouse_1 = be.world._get_facility_by_name("Warehouse_001") + Store_001: FacilityBase = be.world._get_facility_by_name("Store_001") + consumer_unit: ConsumerUnit = supplier_1.products[SKU3_ID].consumer + storage_unit: StorageUnit = supplier_1.storage + seller_unit = Store_001.products[SKU3_ID].seller + manufacture_unit = supplier_1.products[SKU3_ID].manufacture + distribution_unit = supplier_1.distribution + + consumer_nodes = env.snapshot_list["consumer"] + storage_nodes = env.snapshot_list["storage"] + seller_nodes = env.snapshot_list["seller"] + manufacture_nodes = env.snapshot_list["manufacture"] + distribution_nodes = env.snapshot_list["distribution"] + + consumer_node_index = consumer_unit.data_model_index + storage_node_index = storage_unit.data_model_index + seller_node_index = seller_unit.data_model_index + manufacture_node_index = manufacture_unit.data_model_index + distribution_node_index = distribution_unit.data_model_index + + consumer_features = ( + "id", + "facility_id", + "sku_id", + "order_base_cost", + "purchased", + "received", + "order_product_cost", + "latest_consumptions", + "in_transit_quantity", + ) + + storage_features = ("id", "facility_id") + + seller_features = ( + "sold", + "demand", + "total_sold", + "id", + "total_demand", + "backlog_ratio", + "facility_id", + "product_unit_id", + ) + + manufacture_features = ( + "id", + "facility_id", + "start_manufacture_quantity", + "sku_id", + "in_pipeline_quantity", + "finished_quantity", + "product_unit_id", + ) + + distribution_features = ("id", "facility_id", "pending_order_number", "pending_product_quantity") + + # ##################################### Before reset ##################################### + + expect_tick = 10 + + # Save the env.metric of each tick into env_metric_1 + env_metric_1: Dict[int, dict] = defaultdict(dict) + + random_tick: List[int] = [] + + # The purpose is to randomly perform the order operation + for i in range(10): + random_tick.append(random.randint(1, 30)) + + # Store the information about the snapshot of each tick in states_1_x + states_1_consumer: Dict[int, list] = defaultdict(list) + states_1_storage: Dict[int, list] = defaultdict(list) + states_1_seller: Dict[int, list] = defaultdict(list) + states_1_manufacture: Dict[int, list] = defaultdict(list) + states_1_distribution: Dict[int, list] = defaultdict(list) + + for i in range(expect_tick): + env.step(None) + if i in random_tick: + order = Order( + src_facility=supplier_1, + dest_facility=warehouse_1, + sku_id=SKU3_ID, + quantity=10, + vehicle_type="train", + creation_tick=env.tick, + expected_finish_tick=env.tick + 7, + ) + distribution_unit.place_order(order) + distribution_unit.try_schedule_orders(env.tick) + env_metric_1[i] = env.metrics + states_1_consumer[i] = consumer_nodes[i:consumer_node_index:consumer_features].flatten().astype(np.int) + states_1_manufacture[i] = ( + manufacture_nodes[i:manufacture_node_index:manufacture_features] + .flatten() + .astype( + np.int, + ) + ) + env_metric_1[i] = env.metrics + states_1_storage[i] = list(storage_nodes[i:storage_node_index:storage_features].flatten().astype(np.int)) + states_1_storage[i].append( + storage_nodes[i:storage_node_index:"product_id_list"].flatten().astype(np.int).sum(), + ) + states_1_storage[i].append( + storage_nodes[i:storage_node_index:"product_quantity"].flatten().astype(np.int).sum(), + ) + states_1_storage[i].append( + storage_nodes[i:storage_node_index:"remaining_space"].flatten().astype(np.int).sum(), + ) + states_1_seller[i] = seller_nodes[i:seller_node_index:seller_features].flatten().astype(np.int) + states_1_manufacture[i] = ( + manufacture_nodes[i:manufacture_node_index:manufacture_features] + .flatten() + .astype( + np.int, + ) + ) + states_1_distribution[i] = ( + distribution_nodes[i:distribution_node_index:distribution_features] + .flatten() + .astype( + np.int, + ) + ) + + # ############################### Test whether reset updates the storage unit completely ################ + env.reset() + env.step(None) + + # snapshot should reset after env.reset(). + consumer_states = consumer_nodes[1:consumer_node_index:consumer_features].flatten().astype(np.int) + storage_states = storage_nodes[1:storage_node_index:storage_features].flatten().astype(np.int) + seller_states = seller_nodes[1:seller_node_index:seller_features].flatten().astype(np.int) + manufacture_states = manufacture_nodes[1:manufacture_node_index:manufacture_features].flatten().astype(np.int) + distribution_states = ( + distribution_nodes[1:distribution_node_index:distribution_features] + .flatten() + .astype( + np.int, + ) + ) + + self.assertEqual([0, 0, 0, 0, 0, 0, 0, 0, 0], list(consumer_states)) + self.assertEqual([0, 0], list(storage_states)) + self.assertEqual([0, 0, 0, 0, 0, 0, 0, 0], list(seller_states)) + self.assertEqual([0, 0, 0, 0, 0, 0, 0], list(manufacture_states)) + self.assertEqual([0, 0, 0, 0], list(distribution_states)) + + expect_tick = 10 + + # Save the env.metric of each tick into env_metric_2 + env_metric_2: Dict[int, dict] = defaultdict(dict) + + # Store the information about the snapshot storage unit of each tick in states_2 + + states_2_consumer: Dict[int, list] = defaultdict(list) + states_2_storage: Dict[int, list] = defaultdict(list) + states_2_seller: Dict[int, list] = defaultdict(list) + states_2_manufacture: Dict[int, list] = defaultdict(list) + states_2_distribution: Dict[int, list] = defaultdict(list) + + for i in range(expect_tick): + env.step(None) + if i in random_tick: + order = Order( + src_facility=supplier_1, + dest_facility=warehouse_1, + sku_id=SKU3_ID, + quantity=10, + vehicle_type="train", + creation_tick=env.tick, + expected_finish_tick=env.tick + 7, + ) + distribution_unit.place_order(order) + distribution_unit.try_schedule_orders(env.tick) + env_metric_2[i] = env.metrics + states_2_consumer[i] = consumer_nodes[i:consumer_node_index:consumer_features].flatten().astype(np.int) + states_2_storage[i] = list(storage_nodes[i:storage_node_index:storage_features].flatten().astype(np.int)) + states_2_storage[i].append( + storage_nodes[i:storage_node_index:"product_id_list"].flatten().astype(np.int).sum(), + ) + states_2_storage[i].append( + storage_nodes[i:storage_node_index:"product_quantity"].flatten().astype(np.int).sum(), + ) + states_2_storage[i].append( + storage_nodes[i:storage_node_index:"remaining_space"].flatten().astype(np.int).sum(), + ) + states_2_seller[i] = seller_nodes[i:seller_node_index:seller_features].flatten().astype(np.int) + states_2_manufacture[i] = ( + manufacture_nodes[i:manufacture_node_index:manufacture_features] + .flatten() + .astype( + np.int, + ) + ) + states_2_distribution[i] = ( + distribution_nodes[i:distribution_node_index:distribution_features].flatten().astype(np.int) + ) + + for i in range(expect_tick): + self.assertEqual(list(states_1_consumer[i]), list(states_2_consumer[i])) + self.assertEqual(list(states_1_storage[i]), list(states_2_storage[i])) + self.assertEqual(list(states_1_seller[i]), list(states_2_seller[i])) + self.assertEqual(list(states_1_manufacture[i]), list(states_2_manufacture[i])) + self.assertEqual(list(states_1_distribution[i]), list(states_2_distribution[i])) + self.assertEqual(list(env_metric_1[i].values()), list(env_metric_2[i].values())) + + def test_env_reset_with_ManufactureAction_only(self) -> None: + """test env reset with ManufactureAction only""" + env = build_env("case_01", 100) + be = env.business_engine + assert isinstance(be, SupplyChainBusinessEngine) + + supplier_3: FacilityBase = be.world._get_facility_by_name("Supplier_SKU3") + sku3_storage_index = supplier_3.storage.data_model_index + manufacture_sku3_unit = supplier_3.products[SKU3_ID].manufacture + sku3_manufacture_index = manufacture_sku3_unit.data_model_index + + storage_nodes = env.snapshot_list["storage"] + manufacture_nodes = env.snapshot_list["manufacture"] + manufacture_features = ( + "id", + "facility_id", + "start_manufacture_quantity", + "sku_id", + "in_pipeline_quantity", + "finished_quantity", + "product_unit_id", + ) + # ##################################### Before reset ##################################### + + env.step(None) + + capacities = storage_nodes[env.frame_index : sku3_storage_index : "capacity"].flatten().astype(np.int) + remaining_spaces = ( + storage_nodes[env.frame_index : sku3_storage_index : "remaining_space"].flatten().astype(np.int) + ) + + # there should be 80 units been taken at the beginning according to the config file. + # so remaining space should be 20 + self.assertEqual(20, remaining_spaces.sum()) + # capacity is 100 by config + self.assertEqual(100, capacities.sum()) + + product_dict = get_product_dict_from_storage(env, env.frame_index, sku3_storage_index) + + # The product quantity should be same as configuration at beginning. + # 80 sku3 + self.assertEqual(80, product_dict[SKU3_ID]) + + # all the id is greater than 0 + self.assertGreater(manufacture_sku3_unit.id, 0) + + action = ManufactureAction(manufacture_sku3_unit.id, 1) + + expect_tick = 30 + + # Save the env.metric of each tick into env_metric_1 + env_metric_1: Dict[int, dict] = defaultdict(dict) + + # Store the information about the snapshot manufacture unit of each tick in states_1 + states_1: Dict[int, list] = defaultdict(list) + + random_tick: List[int] = [] + + # The purpose is to randomly perform the order operation + for i in range(10): + random_tick.append(random.randint(1, 30)) + + for i in range(expect_tick): + env.step([action]) + if i in random_tick: + env.step([ManufactureAction(manufacture_sku3_unit.id, 1)]) + env_metric_1[i] = env.metrics + states_1[i] = manufacture_nodes[i:sku3_manufacture_index:manufacture_features].flatten().astype(np.int) + + # ############################### Test whether reset updates the manufacture unit completely ################ + env.reset() + env.step(None) + + # snapshot should reset after env.reset(). + states = manufacture_nodes[1:sku3_manufacture_index:manufacture_features].flatten().astype(np.int) + self.assertEqual([0, 0, 0, 0, 0, 0, 0], list(states)) + + storage_nodes = env.snapshot_list["storage"] + manufacture_nodes = env.snapshot_list["manufacture"] + + capacities = storage_nodes[env.frame_index : sku3_storage_index : "capacity"].flatten().astype(np.int) + remaining_spaces = ( + storage_nodes[env.frame_index : sku3_storage_index : "remaining_space"].flatten().astype(np.int) + ) + + # there should be 80 units been taken at the beginning according to the config file. + # so remaining space should be 20 + self.assertEqual(20, remaining_spaces.sum()) + # capacity is 100 by config + self.assertEqual(100, capacities.sum()) + + product_dict = get_product_dict_from_storage(env, env.frame_index, sku3_storage_index) + + # The product quantity should be same as configuration at beginning. + # 80 sku3 + self.assertEqual(80, product_dict[SKU3_ID]) + + # all the id is greater than 0 + self.assertGreater(manufacture_sku3_unit.id, 0) + + expect_tick = 30 + + # Save the env.metric of each tick into env_metric_2 + env_metric_2: Dict[int, dict] = defaultdict(dict) + + # Store the information about the snapshot manufacture unit of each tick in states_2 + states_2: Dict[int, list] = defaultdict(list) + + for i in range(expect_tick): + env.step([action]) + if i in random_tick: + env.step([ManufactureAction(manufacture_sku3_unit.id, 1)]) + env_metric_2[i] = env.metrics + states_2[i] = manufacture_nodes[i:sku3_manufacture_index:manufacture_features].flatten().astype(np.int) + + expect_tick = 30 + for i in range(expect_tick): + self.assertEqual(list(states_1[i]), list(states_2[i])) + self.assertEqual(list(env_metric_1[i].values()), list(env_metric_2[i].values())) + + def test_env_reset_with_ConsumerAction_only(self) -> None: + """ "test env reset with ConsumerAction only""" + env = build_env("case_01", 500) + be = env.business_engine + assert isinstance(be, SupplyChainBusinessEngine) + + env.step(None) + + supplier_1: FacilityBase = be.world._get_facility_by_name("Supplier_SKU1") + supplier_3: FacilityBase = be.world._get_facility_by_name("Supplier_SKU3") + sku3_consumer_unit = supplier_1.products[SKU3_ID].consumer + + consumer_node_index = sku3_consumer_unit.data_model_index + + features = ( + "id", + "facility_id", + "sku_id", + "order_base_cost", + "purchased", + "received", + "order_product_cost", + "latest_consumptions", + "in_transit_quantity", + ) + + # ##################################### Before reset ##################################### + consumer_nodes = env.snapshot_list["consumer"] + action = ConsumerAction(sku3_consumer_unit.id, SKU3_ID, supplier_3.id, 1, "train") + expect_tick = 100 + + # Save the env.metric of each tick into env_metric_1 + env_metric_1: Dict[int, dict] = defaultdict(dict) + + # Store the information about the snapshot consumer unit of each tick in states_1 + states_1: Dict[int, list] = defaultdict(list) + + for i in range(expect_tick): + env.step([action]) + env_metric_1[i] = env.metrics + states_1[i] = consumer_nodes[i:consumer_node_index:features].flatten().astype(np.int) + + # ############### Test whether reset updates the consumer unit completely ################ + env.reset() + env.step(None) + + # snapshot should reset after env.reset() + states = consumer_nodes[1:consumer_node_index:features].flatten().astype(np.int) + self.assertEqual([0, 0, 0, 0, 0, 0, 0, 0, 0], list(states)) + + expect_tick = 100 + + # Save the env.metric of each tick into env_metric_2 + env_metric_2: Dict[int, dict] = defaultdict(dict) + + # Store the information about the snapshot consumer unit of each tick in states_2 + states_2: Dict[int, list] = defaultdict(list) + for i in range(expect_tick): + env.step([action]) + env_metric_2[i] = env.metrics + states_2[i] = consumer_nodes[i:consumer_node_index:features].flatten().astype(np.int) + + expect_tick = 100 + for i in range(expect_tick): + self.assertEqual(list(states_1[i]), list(states_2[i])) + self.assertEqual(list(env_metric_1[i].values()), list(env_metric_2[i].values())) + + def test_env_reset_with_both_ManufactureAction_and_ConsumerAction(self) -> None: + """test env reset with both ManufactureAction and ConsumerAction""" + env = build_env("case_01", 100) + be = env.business_engine + assert isinstance(be, SupplyChainBusinessEngine) + + env.step(None) + + supplier_1: FacilityBase = be.world._get_facility_by_name("Supplier_SKU1") + supplier_3: FacilityBase = be.world._get_facility_by_name("Supplier_SKU3") + consumer_unit: ConsumerUnit = supplier_1.products[SKU3_ID].consumer + manufacture_unit: ManufactureUnit = supplier_1.products[SKU1_ID].manufacture + storage_unit: StorageUnit = supplier_1.storage + + consumer_node_index = consumer_unit.data_model_index + manufacture_node_index = manufacture_unit.data_model_index + storage_node_index = storage_unit.data_model_index + + consumer_features = ( + "id", + "facility_id", + "sku_id", + "order_base_cost", + "purchased", + "received", + "order_product_cost", + "latest_consumptions", + "in_transit_quantity", + ) + + manufacture_features = ( + "id", + "facility_id", + "start_manufacture_quantity", + "sku_id", + "in_pipeline_quantity", + "finished_quantity", + "product_unit_id", + ) + storage_features = ("id", "facility_id") + + consumer_nodes = env.snapshot_list["consumer"] + manufacture_nodes = env.snapshot_list["manufacture"] + storage_nodes = env.snapshot_list["storage"] + + # ##################################### Before reset ##################################### + action_consumer = ConsumerAction(consumer_unit.id, SKU3_ID, supplier_3.id, 20, "train") + action_manufacture = ManufactureAction(manufacture_unit.id, 5) + + expect_tick = 100 + + # Save the env.metric of each tick into env_metric_1 + env_metric_1: Dict[int, dict] = defaultdict(dict) + + random_tick: List[int] = [] + + # The purpose is to randomly perform the order operation + for i in range(30): + random_tick.append(random.randint(0, 90)) + + # Store the information about the snapshot unit of each tick in states_1 + states_1_consumer: Dict[int, list] = defaultdict(list) + states_1_manufacture: Dict[int, list] = defaultdict(list) + states_1_storage: Dict[int, list] = defaultdict(list) + + for i in range(expect_tick): + + if i in random_tick: + env.step([action_manufacture]) + i += 1 + states_1_manufacture[i] = list( + manufacture_nodes[i:manufacture_node_index:manufacture_features] + .flatten() + .astype( + np.int, + ), + ) + env_metric_1[i] = env.metrics + continue + + env.step([action_consumer]) + env_metric_1[i] = env.metrics + states_1_consumer[i] = list( + consumer_nodes[i:consumer_node_index:consumer_features].flatten().astype(np.int), + ) + + states_1_storage[i] = list(storage_nodes[i:storage_node_index:storage_features].flatten().astype(np.int)) + states_1_storage[i].append( + list(storage_nodes[i:storage_node_index:"product_quantity"].flatten().astype(np.int)), + ) + states_1_storage[i].append( + list(storage_nodes[i:storage_node_index:"remaining_space"].flatten().astype(np.int)), + ) + + # ############### Test whether reset updates the consumer unit completely ################ + env.reset() + env.step(None) + + # snapshot should reset after env.reset() + consumer_states = consumer_nodes[1:consumer_node_index:consumer_features].flatten().astype(np.int) + manufacture_states = manufacture_nodes[1:manufacture_node_index:manufacture_features].flatten().astype(np.int) + self.assertEqual([0, 0, 0, 0, 0, 0, 0, 0, 0], list(consumer_states)) + self.assertEqual([0, 0, 0, 0, 0, 0, 0], list(manufacture_states)) + + expect_tick = 100 + + # Save the env.metric of each tick into env_metric_2 + env_metric_2: Dict[int, dict] = defaultdict(dict) + + # Store the information about the snapshot consumer unit of each tick in states_2 + states_2_consumer: Dict[int, list] = defaultdict(list) + states_2_manufacture: Dict[int, list] = defaultdict(list) + states_2_storage: Dict[int, list] = defaultdict(list) + + for i in range(expect_tick): + + if i in random_tick: + env.step([action_manufacture]) + i += 1 + states_2_manufacture[i] = list( + manufacture_nodes[i:manufacture_node_index:manufacture_features] + .flatten() + .astype( + np.int, + ), + ) + env_metric_2[i] = env.metrics + continue + + env.step([action_consumer]) + env_metric_2[i] = env.metrics + states_2_consumer[i] = list( + consumer_nodes[i:consumer_node_index:consumer_features].flatten().astype(np.int), + ) + + states_2_storage[i] = list(storage_nodes[i:storage_node_index:storage_features].flatten().astype(np.int)) + states_2_storage[i].append( + list( + storage_nodes[i:storage_node_index:"product_quantity"].flatten().astype(np.int), + ), + ) + states_2_storage[i].append( + list(storage_nodes[i:storage_node_index:"remaining_space"].flatten().astype(np.int)), + ) + + expect_tick = 100 + for i in range(expect_tick): + self.assertEqual(list(states_1_consumer[i]), list(states_2_consumer[i])) + self.assertEqual(list(states_1_manufacture[i]), list(states_2_manufacture[i])) + self.assertEqual(list(states_1_storage[i]), list(states_2_storage[i])) + self.assertEqual(list(env_metric_1[i].values()), list(env_metric_2[i].values())) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/supply_chain/test_env_reset.py b/tests/supply_chain/test_env_reset.py index 1cbffa1f5..2a5ab982d 100644 --- a/tests/supply_chain/test_env_reset.py +++ b/tests/supply_chain/test_env_reset.py @@ -1,5 +1,6 @@ # Copyright (c) Microsoft Corporation. -# Licensed under the MIT license +# Licensed under the MIT license. + import random import unittest from collections import defaultdict @@ -7,21 +8,21 @@ import numpy as np -from maro.simulator.scenarios.supply_chain import FacilityBase, ConsumerAction, ManufactureAction, StorageUnit +from maro.simulator.scenarios.supply_chain import ConsumerAction, FacilityBase, ManufactureAction, StorageUnit from maro.simulator.scenarios.supply_chain.business_engine import SupplyChainBusinessEngine from maro.simulator.scenarios.supply_chain.order import Order -from tests.supply_chain.common import build_env, SKU3_ID, FOOD_1_ID, get_product_dict_from_storage +from tests.supply_chain.common import FOOD_1_ID, SKU3_ID, build_env, get_product_dict_from_storage class MyTestCase(unittest.TestCase): """ - . consumer unit test - . distribution unit test - . manufacture unit test - . seller unit test - . storage unit test - """ + . consumer unit test + . distribution unit test + . manufacture unit test + . seller unit test + . storage unit test + """ def test_consumer_unit_reset(self) -> None: """Test whether reset updates the consumer unit completely""" @@ -37,8 +38,17 @@ def test_consumer_unit_reset(self) -> None: consumer_node_index = sku3_consumer_unit.data_model_index - features = ("id", "facility_id", "sku_id", "order_base_cost", "purchased", "received", "order_product_cost", - "latest_consumptions", "in_transit_quantity") + features = ( + "id", + "facility_id", + "sku_id", + "order_base_cost", + "purchased", + "received", + "order_product_cost", + "latest_consumptions", + "in_transit_quantity", + ) # ##################################### Before reset ##################################### consumer_nodes = env.snapshot_list["consumer"] @@ -107,7 +117,8 @@ def test_distribution_unit_reset(self) -> None: quantity=10, vehicle_type="train", creation_tick=env.tick, - expected_finish_tick=env.tick + 7, ) + expected_finish_tick=env.tick + 7, + ) # There are 2 "train" in total, and 1 left after scheduling this order. distribution_unit.place_order(order_1) @@ -247,7 +258,12 @@ def test_manufacture_unit_reset(self) -> None: manufacture_nodes = env.snapshot_list["manufacture"] manufacture_features = ( - "id", "facility_id", "start_manufacture_quantity", "sku_id", "in_pipeline_quantity", "finished_quantity", + "id", + "facility_id", + "start_manufacture_quantity", + "sku_id", + "in_pipeline_quantity", + "finished_quantity", "product_unit_id", ) @@ -255,8 +271,10 @@ def test_manufacture_unit_reset(self) -> None: env.step(None) - capacities = storage_nodes[env.frame_index:sku3_storage_index:"capacity"].flatten().astype(np.int) - remaining_spaces = storage_nodes[env.frame_index:sku3_storage_index:"remaining_space"].flatten().astype(np.int) + capacities = storage_nodes[env.frame_index : sku3_storage_index : "capacity"].flatten().astype(np.int) + remaining_spaces = ( + storage_nodes[env.frame_index : sku3_storage_index : "remaining_space"].flatten().astype(np.int) + ) # there should be 80 units been taken at the beginning according to the config file. # so remaining space should be 20 @@ -307,8 +325,10 @@ def test_manufacture_unit_reset(self) -> None: storage_nodes = env.snapshot_list["storage"] manufacture_nodes = env.snapshot_list["manufacture"] - capacities = storage_nodes[env.frame_index:sku3_storage_index:"capacity"].flatten().astype(np.int) - remaining_spaces = storage_nodes[env.frame_index:sku3_storage_index:"remaining_space"].flatten().astype(np.int) + capacities = storage_nodes[env.frame_index : sku3_storage_index : "capacity"].flatten().astype(np.int) + remaining_spaces = ( + storage_nodes[env.frame_index : sku3_storage_index : "remaining_space"].flatten().astype(np.int) + ) # there should be 80 units been taken at the beginning according to the config file. # so remaining space should be 20 @@ -347,7 +367,7 @@ def test_manufacture_unit_reset(self) -> None: def test_seller_unit_dynamics_sampler(self): """Tested the store_001 Interaction between seller unit and dynamics csv data. - The data file of this test is test_case_ 04.csv""" + The data file of this test is test_case_ 04.csv""" env = build_env("case_04", 600) be = env.business_engine assert isinstance(be, SupplyChainBusinessEngine) @@ -359,7 +379,16 @@ def test_seller_unit_dynamics_sampler(self): seller_node_index = seller_unit.data_model_index seller_nodes = env.snapshot_list["seller"] - features = ("sold", "demand", "total_sold", "id", "total_demand", "backlog_ratio", "facility_id", "product_unit_id",) + features = ( + "sold", + "demand", + "total_sold", + "id", + "total_demand", + "backlog_ratio", + "facility_id", + "product_unit_id", + ) # ##################################### Before reset ##################################### self.assertEqual(20, seller_unit.sku_id) @@ -422,7 +451,7 @@ def test_storage_unit_reset(self) -> None: storage_unit: StorageUnit = supplier_3.storage storage_node_index = storage_unit.data_model_index storage_nodes = env.snapshot_list["storage"] - features = ("id", "facility_id",) + features = ("id", "facility_id") # ##################################### Before reset ##################################### @@ -470,5 +499,5 @@ def test_storage_unit_reset(self) -> None: self.assertEqual(list(env_metric_1[i].values()), list(env_metric_2[i].values())) -if __name__ == '__main__': +if __name__ == "__main__": unittest.main() diff --git a/tests/supply_chain/test_state_only.py b/tests/supply_chain/test_state_only.py index 0dfa90d50..93aa95231 100644 --- a/tests/supply_chain/test_state_only.py +++ b/tests/supply_chain/test_state_only.py @@ -185,7 +185,7 @@ def test_distribution_state_only_bigger_vlt(self) -> None: warehouse_1 = be.world._get_facility_by_name("Warehouse_001") retailer_1: FacilityBase = be.world._get_facility_by_name("Retailer_001") - warehouse_1_id, retailer_1_id = 6, 13 + warehouse_1_id, retailer_1_id = 12, 19 warehouse_1_distribution_unit = warehouse_1.distribution self.assertEqual(0, len(warehouse_1_distribution_unit._order_queues["train"])) @@ -375,7 +375,7 @@ def test_distribution_state_only(self) -> None: distribution_unit.place_order(order) self.assertEqual(1, len(distribution_unit._order_queues["train"])) self.assertEqual(20, sum([order.required_quantity for order in distribution_unit._order_queues["train"]])) - supplier_3_id, warehouse_1_id, retailer_1_id = 1, 6, 13 + supplier_3_id, warehouse_1_id, retailer_1_id = 1, 12, 19 env.step(None) From 0d59470348d8392a9b87288b586363a523a84e43 Mon Sep 17 00:00:00 2001 From: v-heli1 Date: Tue, 21 Jun 2022 14:48:26 +0800 Subject: [PATCH 04/10] Modify test according to comments --- tests/supply_chain/test_state_only.py | 35 ++++++++++++--------------- 1 file changed, 15 insertions(+), 20 deletions(-) diff --git a/tests/supply_chain/test_state_only.py b/tests/supply_chain/test_state_only.py index 93aa95231..d6584e4a9 100644 --- a/tests/supply_chain/test_state_only.py +++ b/tests/supply_chain/test_state_only.py @@ -54,7 +54,6 @@ def test_distribution_state_only_small_vlt(self) -> None: warehouse_1 = be.world._get_facility_by_name("Warehouse_001") distribution_unit = supplier_3.distribution - warehouse_1.products[SKU3_ID].consumer env.step(None) # vlt is greater than len(pending_order_len), which will cause the pending order to increase @@ -185,15 +184,12 @@ def test_distribution_state_only_bigger_vlt(self) -> None: warehouse_1 = be.world._get_facility_by_name("Warehouse_001") retailer_1: FacilityBase = be.world._get_facility_by_name("Retailer_001") - warehouse_1_id, retailer_1_id = 12, 19 warehouse_1_distribution_unit = warehouse_1.distribution self.assertEqual(0, len(warehouse_1_distribution_unit._order_queues["train"])) env.step(None) - retailer_1.products[SKU2_ID].consumer - order_1 = Order(warehouse_1, retailer_1, SKU2_ID, 1, "train", env.tick, None) warehouse_1_distribution_unit.place_order(order_1) # The vlt configuration of this topology is 5. @@ -230,7 +226,7 @@ def test_distribution_state_only_bigger_vlt(self) -> None: list(env.metrics["products"][retailer_1.products[SKU2_ID].id]["pending_order_daily"]), ) - self.assertEqual(3, env.metrics["facilities"][retailer_1_id]["in_transit_orders"][SKU2_ID]) + self.assertEqual(3, env.metrics["facilities"][retailer_1.id]["in_transit_orders"][SKU2_ID]) # There are a total of two trains in the configuration, and they have all been dispatched now. self.assertEqual(0, len(warehouse_1_distribution_unit._order_queues["train"])) @@ -245,9 +241,9 @@ def test_distribution_state_only_bigger_vlt(self) -> None: list(env.metrics["products"][retailer_1.products[SKU2_ID].id]["pending_order_daily"]), ) - self.assertEqual(6, env.metrics["facilities"][retailer_1_id]["in_transit_orders"][SKU2_ID]) + self.assertEqual(6, env.metrics["facilities"][retailer_1.id]["in_transit_orders"][SKU2_ID]) - self.assertEqual(0, env.metrics["facilities"][retailer_1_id]["in_transit_orders"][SKU3_ID]) + self.assertEqual(0, env.metrics["facilities"][retailer_1.id]["in_transit_orders"][SKU3_ID]) # After env.step runs, where tick is 5. order_1 arrives after env.step. # order_2 will arrive at tick=6.order_3 is expected to arrive at tick=8 under normal circumstances. @@ -267,7 +263,7 @@ def test_distribution_state_only_bigger_vlt(self) -> None: # When order_1 arrives at the next step, the in_transit_orders of retailer_1 should be the negative number # 1+2+3-1 of the arrival order of retailer_1. - self.assertEqual(5, env.metrics["facilities"][retailer_1_id]["in_transit_orders"][SKU2_ID]) + self.assertEqual(5, env.metrics["facilities"][retailer_1.id]["in_transit_orders"][SKU2_ID]) # After env.step runs, where tick is 7. order_2 arrives after env.step. # There are empty cars at this time, order_3 will arrive at tick = 11. @@ -275,7 +271,7 @@ def test_distribution_state_only_bigger_vlt(self) -> None: # When order_2 arrives at the next step, the in_transit_orders of retailer_1 should be the negative number # 1+2+3-1-2 of the arrival order of retailer_1. - self.assertEqual(3, env.metrics["facilities"][retailer_1_id]["in_transit_orders"][SKU2_ID]) + self.assertEqual(3, env.metrics["facilities"][retailer_1.id]["in_transit_orders"][SKU2_ID]) order_4 = Order(warehouse_1, retailer_1, SKU2_ID, 4, "train", env.tick, None) warehouse_1_distribution_unit.place_order(order_4) @@ -375,7 +371,6 @@ def test_distribution_state_only(self) -> None: distribution_unit.place_order(order) self.assertEqual(1, len(distribution_unit._order_queues["train"])) self.assertEqual(20, sum([order.required_quantity for order in distribution_unit._order_queues["train"]])) - supplier_3_id, warehouse_1_id, retailer_1_id = 1, 12, 19 env.step(None) @@ -384,7 +379,7 @@ def test_distribution_state_only(self) -> None: [0, 0, 20, 0], list(env.metrics["products"][warehouse_1.products[SKU3_ID].id]["pending_order_daily"]), ) - self.assertEqual(20, env.metrics["facilities"][warehouse_1_id]["in_transit_orders"][SKU3_ID]) + self.assertEqual(20, env.metrics["facilities"][warehouse_1.id]["in_transit_orders"][SKU3_ID]) # add another order, it would be successfully scheduled. order = Order(supplier_3, warehouse_1, SKU3_ID, 25, "train", env.tick, None) @@ -398,7 +393,7 @@ def test_distribution_state_only(self) -> None: self.assertEqual(2, len(distribution_unit._order_queues["train"])) self.assertEqual(25 + 30, sum([order.required_quantity for order in distribution_unit._order_queues["train"]])) - self.assertEqual(25 + 30, env.metrics["facilities"][supplier_3_id]["pending_order"][SKU3_ID]) + self.assertEqual(25 + 30, env.metrics["facilities"][supplier_3.id]["pending_order"][SKU3_ID]) self.assertEqual(25 + 30, distribution_unit.pending_product_quantity[SKU3_ID]) warehouse_1_distribution_unit = warehouse_1.distribution @@ -409,16 +404,16 @@ def test_distribution_state_only(self) -> None: order_3 = Order(warehouse_1, retailer_1, SKU3_ID, 5, "train", env.tick, None) warehouse_1_distribution_unit.place_order(order_3) - self.assertEqual(5 + 5, env.metrics["facilities"][warehouse_1_id]["pending_order"][SKU3_ID]) + self.assertEqual(5 + 5, env.metrics["facilities"][warehouse_1.id]["pending_order"][SKU3_ID]) self.assertEqual(5 + 5, warehouse_1_distribution_unit.pending_product_quantity[SKU3_ID]) order_4 = Order(warehouse_1, retailer_1, SKU3_ID, 5, "train", env.tick, None) warehouse_1_distribution_unit.place_order(order_4) - self.assertEqual(5 + 5 + 5, env.metrics["facilities"][warehouse_1_id]["pending_order"][SKU3_ID]) + self.assertEqual(5 + 5 + 5, env.metrics["facilities"][warehouse_1.id]["pending_order"][SKU3_ID]) self.assertEqual(5 + 5 + 5, warehouse_1_distribution_unit.pending_product_quantity[SKU3_ID]) # There is no place_order for the distribution of supplier_3, there should be no change - self.assertEqual(25 + 30, env.metrics["facilities"][supplier_3_id]["pending_order"][SKU3_ID]) + self.assertEqual(25 + 30, env.metrics["facilities"][supplier_3.id]["pending_order"][SKU3_ID]) self.assertEqual(25 + 30, distribution_unit.pending_product_quantity[SKU3_ID]) start_tick = env.tick @@ -429,7 +424,7 @@ def test_distribution_state_only(self) -> None: [0, 20, 25, 0], list(env.metrics["products"][warehouse_1.products[SKU3_ID].id]["pending_order_daily"]), ) - self.assertEqual(20 + 25 + 30, env.metrics["facilities"][warehouse_1_id]["in_transit_orders"][SKU3_ID]) + self.assertEqual(20 + 25 + 30, env.metrics["facilities"][warehouse_1.id]["in_transit_orders"][SKU3_ID]) while env.tick < expected_supplier_tick - 1: env.step(None) @@ -438,20 +433,20 @@ def test_distribution_state_only(self) -> None: [20, 25, 0, 0], list(env.metrics["products"][warehouse_1.products[SKU3_ID].id]["pending_order_daily"]), ) - self.assertEqual(20 + 25 + 30, env.metrics["facilities"][warehouse_1_id]["in_transit_orders"][SKU3_ID]) + self.assertEqual(20 + 25 + 30, env.metrics["facilities"][warehouse_1.id]["in_transit_orders"][SKU3_ID]) env.step(None) self.assertEqual( [25, 0, 0, 30], list(env.metrics["products"][warehouse_1.products[SKU3_ID].id]["pending_order_daily"]), ) - self.assertEqual(25 + 30, env.metrics["facilities"][warehouse_1_id]["in_transit_orders"][SKU3_ID]) + self.assertEqual(25 + 30, env.metrics["facilities"][warehouse_1.id]["in_transit_orders"][SKU3_ID]) # will arrive at the end of this tick, still on the way. assert env.tick == expected_supplier_tick self.assertEqual(0, len(distribution_unit._order_queues["train"])) self.assertEqual(0, sum([order.required_quantity for order in distribution_unit._order_queues["train"]])) - self.assertEqual(5, env.metrics["facilities"][warehouse_1_id]["pending_order"][SKU3_ID]) + self.assertEqual(5, env.metrics["facilities"][warehouse_1.id]["pending_order"][SKU3_ID]) self.assertEqual(5, warehouse_1_distribution_unit.pending_product_quantity[SKU3_ID]) env.step(None) @@ -459,7 +454,7 @@ def test_distribution_state_only(self) -> None: self.assertEqual(0, len(distribution_unit._order_queues["train"])) self.assertEqual(0, sum([order.required_quantity for order in distribution_unit._order_queues["train"]])) - self.assertEqual(5, env.metrics["facilities"][warehouse_1_id]["pending_order"][SKU3_ID]) + self.assertEqual(5, env.metrics["facilities"][warehouse_1.id]["pending_order"][SKU3_ID]) self.assertEqual(5, warehouse_1_distribution_unit.pending_product_quantity[SKU3_ID]) From d064c96c4d28990d42e86e5b5b5615235e1c740f Mon Sep 17 00:00:00 2001 From: v-heli1 Date: Wed, 22 Jun 2022 13:56:13 +0800 Subject: [PATCH 05/10] Modify test according to comments ++ Modify test according to comments ++ --- tests/data/supply_chain/case_01/config.yml | 2 - tests/data/supply_chain/case_05/config.yml | 15 +- tests/supply_chain/test_action_reset.py | 657 +++++++++++++-------- 3 files changed, 438 insertions(+), 236 deletions(-) diff --git a/tests/data/supply_chain/case_01/config.yml b/tests/data/supply_chain/case_01/config.yml index 0d21ac8c5..7730960e4 100644 --- a/tests/data/supply_chain/case_01/config.yml +++ b/tests/data/supply_chain/case_01/config.yml @@ -88,7 +88,6 @@ world: sku3: init_stock: 80 has_manufacture: True - has_consumer: True max_manufacture_rate: 50 manufacture_leading_time: 1 unit_product_cost: 1 @@ -106,7 +105,6 @@ world: init_stock: 96 has_manufacture: True max_manufacture_rate: 50 - has_consumer: True manufacture_leading_time: 1 unit_product_cost: 1 price: 100 diff --git a/tests/data/supply_chain/case_05/config.yml b/tests/data/supply_chain/case_05/config.yml index 93c663234..08089864e 100644 --- a/tests/data/supply_chain/case_05/config.yml +++ b/tests/data/supply_chain/case_05/config.yml @@ -149,8 +149,8 @@ world: - name: "Supplier_SKU1" definition_ref: "SupplierFacility" skus: - sku3: - init_stock: 80 + sku1: + init_stock: 20 has_manufacture: True max_manufacture_rate: 50 manufacture_leading_time: 1 @@ -158,6 +158,11 @@ world: price: 10 unit_delay_order_penalty: 10 has_consumer: True + sku3: + init_stock: 80 + unit_product_cost: 1 + price: 10 + unit_delay_order_penalty: 10 children: storage: *small_storage distribution: *normal_distribution @@ -173,6 +178,7 @@ world: sub_storage_id: 1 storage_upper_bound: 40 price: 100 + has_consumer: True sku2: init_stock: 12 sub_storage_id: 1 @@ -251,7 +257,7 @@ world: storage: *single_storage config: unit_order_cost: 200 - file_path: "tests/data/supply_chain/case_04/test_case_04.csv" + file_path: "tests/data/supply_chain/case_05/test_case_05.csv" topology: @@ -276,11 +282,12 @@ world: cost: 0.6 Warehouse_001: - sku3: + sku1: "Supplier_SKU1": "train": vlt: 3 cost: 1 + sku3: "Supplier_SKU3": "train": vlt: 3 diff --git a/tests/supply_chain/test_action_reset.py b/tests/supply_chain/test_action_reset.py index d87c6b914..cd84ffe2d 100644 --- a/tests/supply_chain/test_action_reset.py +++ b/tests/supply_chain/test_action_reset.py @@ -14,10 +14,11 @@ FacilityBase, ManufactureAction, ManufactureUnit, + RetailerFacility, StorageUnit, + WarehouseFacility, ) from maro.simulator.scenarios.supply_chain.business_engine import SupplyChainBusinessEngine -from maro.simulator.scenarios.supply_chain.order import Order from tests.supply_chain.common import SKU1_ID, SKU3_ID, build_env, get_product_dict_from_storage @@ -42,11 +43,12 @@ def test_env_reset_with_none_action(self) -> None: supplier_1: FacilityBase = be.world._get_facility_by_name("Supplier_SKU1") warehouse_1 = be.world._get_facility_by_name("Warehouse_001") Store_001: FacilityBase = be.world._get_facility_by_name("Store_001") - consumer_unit: ConsumerUnit = supplier_1.products[SKU3_ID].consumer + + consumer_unit: ConsumerUnit = warehouse_1.products[SKU3_ID].consumer storage_unit: StorageUnit = supplier_1.storage - seller_unit = Store_001.products[SKU3_ID].seller - manufacture_unit = supplier_1.products[SKU3_ID].manufacture - distribution_unit = supplier_1.distribution + Store_001.products[SKU3_ID].seller + supplier_1.products[SKU1_ID].manufacture + supplier_1.distribution consumer_nodes = env.snapshot_list["consumer"] storage_nodes = env.snapshot_list["storage"] @@ -54,12 +56,6 @@ def test_env_reset_with_none_action(self) -> None: manufacture_nodes = env.snapshot_list["manufacture"] distribution_nodes = env.snapshot_list["distribution"] - consumer_node_index = consumer_unit.data_model_index - storage_node_index = storage_unit.data_model_index - seller_node_index = seller_unit.data_model_index - manufacture_node_index = manufacture_unit.data_model_index - distribution_node_index = distribution_unit.data_model_index - consumer_features = ( "id", "facility_id", @@ -104,91 +100,59 @@ def test_env_reset_with_none_action(self) -> None: # Save the env.metric of each tick into env_metric_1 env_metric_1: Dict[int, dict] = defaultdict(dict) - random_tick: List[int] = [] - - # The purpose is to randomly perform the order operation - for i in range(10): - random_tick.append(random.randint(1, 30)) - # Store the information about the snapshot of each tick in states_1_x - states_1_consumer: Dict[int, list] = defaultdict(list) - states_1_storage: Dict[int, list] = defaultdict(list) - states_1_seller: Dict[int, list] = defaultdict(list) - states_1_manufacture: Dict[int, list] = defaultdict(list) - states_1_distribution: Dict[int, list] = defaultdict(list) + states_1_consumer: Dict[int, dict] = defaultdict(dict) + states_1_storage: Dict[int, dict] = defaultdict(dict) + states_1_seller: Dict[int, dict] = defaultdict(dict) + states_1_manufacture: Dict[int, dict] = defaultdict(dict) + states_1_distribution: Dict[int, dict] = defaultdict(dict) for i in range(expect_tick): env.step(None) - if i in random_tick: - order = Order( - src_facility=supplier_1, - dest_facility=warehouse_1, - sku_id=SKU3_ID, - quantity=10, - vehicle_type="train", - creation_tick=env.tick, - expected_finish_tick=env.tick + 7, - ) - distribution_unit.place_order(order) - distribution_unit.try_schedule_orders(env.tick) - env_metric_1[i] = env.metrics - states_1_consumer[i] = consumer_nodes[i:consumer_node_index:consumer_features].flatten().astype(np.int) - states_1_manufacture[i] = ( - manufacture_nodes[i:manufacture_node_index:manufacture_features] - .flatten() - .astype( - np.int, - ) - ) env_metric_1[i] = env.metrics - states_1_storage[i] = list(storage_nodes[i:storage_node_index:storage_features].flatten().astype(np.int)) - states_1_storage[i].append( - storage_nodes[i:storage_node_index:"product_id_list"].flatten().astype(np.int).sum(), - ) - states_1_storage[i].append( - storage_nodes[i:storage_node_index:"product_quantity"].flatten().astype(np.int).sum(), - ) - states_1_storage[i].append( - storage_nodes[i:storage_node_index:"remaining_space"].flatten().astype(np.int).sum(), - ) - states_1_seller[i] = seller_nodes[i:seller_node_index:seller_features].flatten().astype(np.int) - states_1_manufacture[i] = ( - manufacture_nodes[i:manufacture_node_index:manufacture_features] - .flatten() - .astype( - np.int, - ) - ) - states_1_distribution[i] = ( - distribution_nodes[i:distribution_node_index:distribution_features] - .flatten() - .astype( - np.int, - ) - ) + + for idx in range(len(consumer_nodes)): + states_1_consumer[i][idx] = consumer_nodes[i:idx:consumer_features].flatten().astype(np.int) + + for idx in range(len(storage_nodes)): + states_1_storage[i][idx] = list(storage_nodes[i:idx:storage_features].flatten().astype(np.int)) + states_1_storage[i][idx].append(storage_nodes[i:idx:"product_id_list"].flatten().astype(np.int)) + states_1_storage[i][idx].append(storage_nodes[i:idx:"product_quantity"].flatten().astype(np.int)) + states_1_storage[i][idx].append(storage_nodes[i:idx:"remaining_space"].flatten().astype(np.int)) + + for idx in range(len(manufacture_nodes)): + states_1_manufacture[i][idx] = manufacture_nodes[i:idx:manufacture_features].flatten().astype(np.int) + + for idx in range(len(distribution_nodes)): + states_1_distribution[i][idx] = distribution_nodes[i:idx:distribution_features].flatten().astype(np.int) + + for idx in range(len(seller_nodes)): + states_1_seller[i][idx] = seller_nodes[i:idx:seller_features].flatten().astype(np.int) # ############################### Test whether reset updates the storage unit completely ################ env.reset() env.step(None) # snapshot should reset after env.reset(). - consumer_states = consumer_nodes[1:consumer_node_index:consumer_features].flatten().astype(np.int) - storage_states = storage_nodes[1:storage_node_index:storage_features].flatten().astype(np.int) - seller_states = seller_nodes[1:seller_node_index:seller_features].flatten().astype(np.int) - manufacture_states = manufacture_nodes[1:manufacture_node_index:manufacture_features].flatten().astype(np.int) - distribution_states = ( - distribution_nodes[1:distribution_node_index:distribution_features] - .flatten() - .astype( - np.int, - ) - ) + for idx in range(len(manufacture_nodes)): + states = manufacture_nodes[1:idx:manufacture_features].flatten().astype(np.int) + self.assertEqual([0, 0, 0, 0, 0, 0, 0], list(states)) - self.assertEqual([0, 0, 0, 0, 0, 0, 0, 0, 0], list(consumer_states)) - self.assertEqual([0, 0], list(storage_states)) - self.assertEqual([0, 0, 0, 0, 0, 0, 0, 0], list(seller_states)) - self.assertEqual([0, 0, 0, 0, 0, 0, 0], list(manufacture_states)) - self.assertEqual([0, 0, 0, 0], list(distribution_states)) + for idx in range(len(storage_nodes)): + states = storage_nodes[1:idx:storage_features].flatten().astype(np.int) + self.assertEqual([0, 0], list(states)) + + for idx in range(len(distribution_nodes)): + states = distribution_nodes[1:idx:distribution_features].flatten().astype(np.int) + self.assertEqual([0, 0, 0, 0], list(states)) + + for idx in range(len(consumer_nodes)): + states = consumer_nodes[1:idx:consumer_features].flatten().astype(np.int) + self.assertEqual([0, 0, 0, 0, 0, 0, 0, 0, 0], list(states)) + + for idx in range(len(seller_nodes)): + states = seller_nodes[1:idx:seller_features].flatten().astype(np.int) + self.assertEqual([0, 0, 0, 0, 0, 0, 0, 0], list(states)) expect_tick = 10 @@ -197,49 +161,33 @@ def test_env_reset_with_none_action(self) -> None: # Store the information about the snapshot storage unit of each tick in states_2 - states_2_consumer: Dict[int, list] = defaultdict(list) - states_2_storage: Dict[int, list] = defaultdict(list) - states_2_seller: Dict[int, list] = defaultdict(list) - states_2_manufacture: Dict[int, list] = defaultdict(list) - states_2_distribution: Dict[int, list] = defaultdict(list) + states_2_consumer: Dict[int, dict] = defaultdict(dict) + states_2_storage: Dict[int, dict] = defaultdict(dict) + states_2_seller: Dict[int, dict] = defaultdict(dict) + states_2_manufacture: Dict[int, dict] = defaultdict(dict) + states_2_distribution: Dict[int, dict] = defaultdict(dict) for i in range(expect_tick): env.step(None) - if i in random_tick: - order = Order( - src_facility=supplier_1, - dest_facility=warehouse_1, - sku_id=SKU3_ID, - quantity=10, - vehicle_type="train", - creation_tick=env.tick, - expected_finish_tick=env.tick + 7, - ) - distribution_unit.place_order(order) - distribution_unit.try_schedule_orders(env.tick) env_metric_2[i] = env.metrics - states_2_consumer[i] = consumer_nodes[i:consumer_node_index:consumer_features].flatten().astype(np.int) - states_2_storage[i] = list(storage_nodes[i:storage_node_index:storage_features].flatten().astype(np.int)) - states_2_storage[i].append( - storage_nodes[i:storage_node_index:"product_id_list"].flatten().astype(np.int).sum(), - ) - states_2_storage[i].append( - storage_nodes[i:storage_node_index:"product_quantity"].flatten().astype(np.int).sum(), - ) - states_2_storage[i].append( - storage_nodes[i:storage_node_index:"remaining_space"].flatten().astype(np.int).sum(), - ) - states_2_seller[i] = seller_nodes[i:seller_node_index:seller_features].flatten().astype(np.int) - states_2_manufacture[i] = ( - manufacture_nodes[i:manufacture_node_index:manufacture_features] - .flatten() - .astype( - np.int, - ) - ) - states_2_distribution[i] = ( - distribution_nodes[i:distribution_node_index:distribution_features].flatten().astype(np.int) - ) + + for idx in range(len(consumer_nodes)): + states_2_consumer[i][idx] = consumer_nodes[i:idx:consumer_features].flatten().astype(np.int) + + for idx in range(len(storage_nodes)): + states_2_storage[i][idx] = list(storage_nodes[i:idx:storage_features].flatten().astype(np.int)) + states_2_storage[i][idx].append(storage_nodes[i:idx:"product_id_list"].flatten().astype(np.int)) + states_2_storage[i][idx].append(storage_nodes[i:idx:"product_quantity"].flatten().astype(np.int)) + states_2_storage[i][idx].append(storage_nodes[i:idx:"remaining_space"].flatten().astype(np.int)) + + for idx in range(len(manufacture_nodes)): + states_2_manufacture[i][idx] = manufacture_nodes[i:idx:manufacture_features].flatten().astype(np.int) + + for idx in range(len(distribution_nodes)): + states_2_distribution[i][idx] = distribution_nodes[i:idx:distribution_features].flatten().astype(np.int) + + for idx in range(len(seller_nodes)): + states_2_seller[i][idx] = seller_nodes[i:idx:seller_features].flatten().astype(np.int) for i in range(expect_tick): self.assertEqual(list(states_1_consumer[i]), list(states_2_consumer[i])) @@ -251,17 +199,51 @@ def test_env_reset_with_none_action(self) -> None: def test_env_reset_with_ManufactureAction_only(self) -> None: """test env reset with ManufactureAction only""" - env = build_env("case_01", 100) + env = build_env("case_02", 100) be = env.business_engine assert isinstance(be, SupplyChainBusinessEngine) supplier_3: FacilityBase = be.world._get_facility_by_name("Supplier_SKU3") - sku3_storage_index = supplier_3.storage.data_model_index - manufacture_sku3_unit = supplier_3.products[SKU3_ID].manufacture - sku3_manufacture_index = manufacture_sku3_unit.data_model_index + warehouse_1: WarehouseFacility = be.world._get_facility_by_name("Warehouse_001") + retailer_1: RetailerFacility = be.world._get_facility_by_name("Retailer_001") + + storage_unit = supplier_3.storage + warehouse_1.products[SKU3_ID].consumer + manufacture_unit = supplier_3.products[SKU3_ID].manufacture + supplier_3.distribution + retailer_1.products[SKU3_ID].seller + consumer_nodes = env.snapshot_list["consumer"] storage_nodes = env.snapshot_list["storage"] + seller_nodes = env.snapshot_list["seller"] manufacture_nodes = env.snapshot_list["manufacture"] + distribution_nodes = env.snapshot_list["distribution"] + + consumer_features = ( + "id", + "facility_id", + "sku_id", + "order_base_cost", + "purchased", + "received", + "order_product_cost", + "latest_consumptions", + "in_transit_quantity", + ) + + storage_features = ("id", "facility_id") + + seller_features = ( + "sold", + "demand", + "total_sold", + "id", + "total_demand", + "backlog_ratio", + "facility_id", + "product_unit_id", + ) + manufacture_features = ( "id", "facility_id", @@ -271,13 +253,17 @@ def test_env_reset_with_ManufactureAction_only(self) -> None: "finished_quantity", "product_unit_id", ) + + distribution_features = ("id", "facility_id", "pending_order_number", "pending_product_quantity") + # ##################################### Before reset ##################################### env.step(None) - capacities = storage_nodes[env.frame_index : sku3_storage_index : "capacity"].flatten().astype(np.int) + storage_node_index = storage_unit.data_model_index + capacities = storage_nodes[env.frame_index : storage_node_index : "capacity"].flatten().astype(np.int) remaining_spaces = ( - storage_nodes[env.frame_index : sku3_storage_index : "remaining_space"].flatten().astype(np.int) + storage_nodes[env.frame_index : storage_node_index : "remaining_space"].flatten().astype(np.int) ) # there should be 80 units been taken at the beginning according to the config file. @@ -286,24 +272,25 @@ def test_env_reset_with_ManufactureAction_only(self) -> None: # capacity is 100 by config self.assertEqual(100, capacities.sum()) - product_dict = get_product_dict_from_storage(env, env.frame_index, sku3_storage_index) + product_dict = get_product_dict_from_storage(env, env.frame_index, storage_node_index) # The product quantity should be same as configuration at beginning. # 80 sku3 self.assertEqual(80, product_dict[SKU3_ID]) - # all the id is greater than 0 - self.assertGreater(manufacture_sku3_unit.id, 0) - - action = ManufactureAction(manufacture_sku3_unit.id, 1) + ManufactureAction(manufacture_unit.id, 1) expect_tick = 30 # Save the env.metric of each tick into env_metric_1 env_metric_1: Dict[int, dict] = defaultdict(dict) - # Store the information about the snapshot manufacture unit of each tick in states_1 - states_1: Dict[int, list] = defaultdict(list) + # Store the information about the snapshot unit of each tick in states_1 + states_1_consumer: Dict[int, dict] = defaultdict(dict) + states_1_storage: Dict[int, dict] = defaultdict(dict) + states_1_seller: Dict[int, dict] = defaultdict(dict) + states_1_manufacture: Dict[int, dict] = defaultdict(dict) + states_1_distribution: Dict[int, dict] = defaultdict(dict) random_tick: List[int] = [] @@ -312,26 +299,58 @@ def test_env_reset_with_ManufactureAction_only(self) -> None: random_tick.append(random.randint(1, 30)) for i in range(expect_tick): - env.step([action]) - if i in random_tick: - env.step([ManufactureAction(manufacture_sku3_unit.id, 1)]) + env.step([ManufactureAction(manufacture_unit.id, 1)]) env_metric_1[i] = env.metrics - states_1[i] = manufacture_nodes[i:sku3_manufacture_index:manufacture_features].flatten().astype(np.int) + + for idx in range(len(consumer_nodes)): + states_1_consumer[i][idx] = consumer_nodes[i:idx:consumer_features].flatten().astype(np.int) + + for idx in range(len(storage_nodes)): + states_1_storage[i][idx] = list(storage_nodes[i:idx:storage_features].flatten().astype(np.int)) + states_1_storage[i][idx].append(storage_nodes[i:idx:"product_id_list"].flatten().astype(np.int)) + states_1_storage[i][idx].append(storage_nodes[i:idx:"product_quantity"].flatten().astype(np.int)) + states_1_storage[i][idx].append(storage_nodes[i:idx:"remaining_space"].flatten().astype(np.int)) + + for idx in range(len(manufacture_nodes)): + states_1_manufacture[i][idx] = manufacture_nodes[i:idx:manufacture_features].flatten().astype(np.int) + + for idx in range(len(distribution_nodes)): + states_1_distribution[i][idx] = distribution_nodes[i:idx:distribution_features].flatten().astype(np.int) + + for idx in range(len(seller_nodes)): + states_1_seller[i][idx] = seller_nodes[i:idx:seller_features].flatten().astype(np.int) + + if i in random_tick: + env.step([ManufactureAction(manufacture_unit.id, 0)]) # ############################### Test whether reset updates the manufacture unit completely ################ env.reset() env.step(None) # snapshot should reset after env.reset(). - states = manufacture_nodes[1:sku3_manufacture_index:manufacture_features].flatten().astype(np.int) - self.assertEqual([0, 0, 0, 0, 0, 0, 0], list(states)) + for idx in range(len(manufacture_nodes)): + states = manufacture_nodes[1:idx:manufacture_features].flatten().astype(np.int) + self.assertEqual([0, 0, 0, 0, 0, 0, 0], list(states)) - storage_nodes = env.snapshot_list["storage"] - manufacture_nodes = env.snapshot_list["manufacture"] + for idx in range(len(storage_nodes)): + states = storage_nodes[1:idx:storage_features].flatten().astype(np.int) + self.assertEqual([0, 0], list(states)) + + for idx in range(len(distribution_nodes)): + states = distribution_nodes[1:idx:distribution_features].flatten().astype(np.int) + self.assertEqual([0, 0, 0, 0], list(states)) - capacities = storage_nodes[env.frame_index : sku3_storage_index : "capacity"].flatten().astype(np.int) + for idx in range(len(consumer_nodes)): + states = consumer_nodes[1:idx:consumer_features].flatten().astype(np.int) + self.assertEqual([0, 0, 0, 0, 0, 0, 0, 0, 0], list(states)) + + for idx in range(len(seller_nodes)): + states = seller_nodes[1:idx:seller_features].flatten().astype(np.int) + self.assertEqual([0, 0, 0, 0, 0, 0, 0, 0], list(states)) + + capacities = storage_nodes[env.frame_index : storage_node_index : "capacity"].flatten().astype(np.int) remaining_spaces = ( - storage_nodes[env.frame_index : sku3_storage_index : "remaining_space"].flatten().astype(np.int) + storage_nodes[env.frame_index : storage_node_index : "remaining_space"].flatten().astype(np.int) ) # there should be 80 units been taken at the beginning according to the config file. @@ -340,50 +359,81 @@ def test_env_reset_with_ManufactureAction_only(self) -> None: # capacity is 100 by config self.assertEqual(100, capacities.sum()) - product_dict = get_product_dict_from_storage(env, env.frame_index, sku3_storage_index) + product_dict = get_product_dict_from_storage(env, env.frame_index, storage_node_index) # The product quantity should be same as configuration at beginning. # 80 sku3 self.assertEqual(80, product_dict[SKU3_ID]) # all the id is greater than 0 - self.assertGreater(manufacture_sku3_unit.id, 0) + self.assertGreater(manufacture_unit.id, 0) expect_tick = 30 # Save the env.metric of each tick into env_metric_2 env_metric_2: Dict[int, dict] = defaultdict(dict) - # Store the information about the snapshot manufacture unit of each tick in states_2 - states_2: Dict[int, list] = defaultdict(list) + # Store the information about the snapshot unit of each tick in states_2 + + states_2_consumer: Dict[int, dict] = defaultdict(dict) + states_2_storage: Dict[int, dict] = defaultdict(dict) + states_2_seller: Dict[int, dict] = defaultdict(dict) + states_2_manufacture: Dict[int, dict] = defaultdict(dict) + states_2_distribution: Dict[int, dict] = defaultdict(dict) for i in range(expect_tick): - env.step([action]) - if i in random_tick: - env.step([ManufactureAction(manufacture_sku3_unit.id, 1)]) + env.step(None) env_metric_2[i] = env.metrics - states_2[i] = manufacture_nodes[i:sku3_manufacture_index:manufacture_features].flatten().astype(np.int) + + for idx in range(len(consumer_nodes)): + states_2_consumer[i][idx] = consumer_nodes[i:idx:consumer_features].flatten().astype(np.int) + + for idx in range(len(storage_nodes)): + states_2_storage[i][idx] = list(storage_nodes[i:idx:storage_features].flatten().astype(np.int)) + states_2_storage[i][idx].append(storage_nodes[i:idx:"product_id_list"].flatten().astype(np.int)) + states_2_storage[i][idx].append(storage_nodes[i:idx:"product_quantity"].flatten().astype(np.int)) + states_2_storage[i][idx].append(storage_nodes[i:idx:"remaining_space"].flatten().astype(np.int)) + + for idx in range(len(manufacture_nodes)): + states_2_manufacture[i][idx] = manufacture_nodes[i:idx:manufacture_features].flatten().astype(np.int) + + for idx in range(len(distribution_nodes)): + states_2_distribution[i][idx] = distribution_nodes[i:idx:distribution_features].flatten().astype(np.int) + + for idx in range(len(seller_nodes)): + states_2_seller[i][idx] = seller_nodes[i:idx:seller_features].flatten().astype(np.int) + + if i in random_tick: + env.step([ManufactureAction(manufacture_unit.id, 0)]) expect_tick = 30 for i in range(expect_tick): - self.assertEqual(list(states_1[i]), list(states_2[i])) + self.assertEqual(list(states_1_consumer[i]), list(states_2_consumer[i])) + self.assertEqual(list(states_1_storage[i]), list(states_2_storage[i])) + self.assertEqual(list(states_1_seller[i]), list(states_2_seller[i])) + self.assertEqual(list(states_1_manufacture[i]), list(states_2_manufacture[i])) + self.assertEqual(list(states_1_distribution[i]), list(states_2_distribution[i])) self.assertEqual(list(env_metric_1[i].values()), list(env_metric_2[i].values())) def test_env_reset_with_ConsumerAction_only(self) -> None: """ "test env reset with ConsumerAction only""" - env = build_env("case_01", 500) + env = build_env("case_05", 500) be = env.business_engine assert isinstance(be, SupplyChainBusinessEngine) env.step(None) - supplier_1: FacilityBase = be.world._get_facility_by_name("Supplier_SKU1") + warehouse_1 = be.world._get_facility_by_name("Warehouse_001") supplier_3: FacilityBase = be.world._get_facility_by_name("Supplier_SKU3") - sku3_consumer_unit = supplier_1.products[SKU3_ID].consumer + consumer_unit = warehouse_1.products[SKU3_ID].consumer - consumer_node_index = sku3_consumer_unit.data_model_index + consumer_nodes = env.snapshot_list["consumer"] + storage_nodes = env.snapshot_list["storage"] + seller_nodes = env.snapshot_list["seller"] + manufacture_nodes = env.snapshot_list["manufacture"] + distribution_nodes = env.snapshot_list["distribution"] - features = ( + consumer_features = ( "id", "facility_id", "sku_id", @@ -395,50 +445,151 @@ def test_env_reset_with_ConsumerAction_only(self) -> None: "in_transit_quantity", ) + storage_features = ("id", "facility_id") + + seller_features = ( + "sold", + "demand", + "total_sold", + "id", + "total_demand", + "backlog_ratio", + "facility_id", + "product_unit_id", + ) + + manufacture_features = ( + "id", + "facility_id", + "start_manufacture_quantity", + "sku_id", + "in_pipeline_quantity", + "finished_quantity", + "product_unit_id", + ) + + distribution_features = ("id", "facility_id", "pending_order_number", "pending_product_quantity") + # ##################################### Before reset ##################################### - consumer_nodes = env.snapshot_list["consumer"] - action = ConsumerAction(sku3_consumer_unit.id, SKU3_ID, supplier_3.id, 1, "train") + action = ConsumerAction(consumer_unit.id, SKU3_ID, supplier_3.id, 1, "train") expect_tick = 100 # Save the env.metric of each tick into env_metric_1 env_metric_1: Dict[int, dict] = defaultdict(dict) - # Store the information about the snapshot consumer unit of each tick in states_1 - states_1: Dict[int, list] = defaultdict(list) + # Store the information about the snapshot unit of each tick in states_1 + states_1_consumer: Dict[int, dict] = defaultdict(dict) + states_1_storage: Dict[int, dict] = defaultdict(dict) + states_1_seller: Dict[int, dict] = defaultdict(dict) + states_1_manufacture: Dict[int, dict] = defaultdict(dict) + states_1_distribution: Dict[int, dict] = defaultdict(dict) for i in range(expect_tick): env.step([action]) env_metric_1[i] = env.metrics - states_1[i] = consumer_nodes[i:consumer_node_index:features].flatten().astype(np.int) + + for idx in range(len(consumer_nodes)): + states_1_consumer[i][idx] = consumer_nodes[i:idx:consumer_features].flatten().astype(np.int) + + for idx in range(len(storage_nodes)): + states_1_storage[i][idx] = list(storage_nodes[i:idx:storage_features].flatten().astype(np.int)) + states_1_storage[i][idx].append(storage_nodes[i:idx:"product_id_list"].flatten().astype(np.int)) + states_1_storage[i][idx].append(storage_nodes[i:idx:"product_quantity"].flatten().astype(np.int)) + states_1_storage[i][idx].append(storage_nodes[i:idx:"remaining_space"].flatten().astype(np.int)) + + for idx in range(len(manufacture_nodes)): + states_1_manufacture[i][idx] = ( + manufacture_nodes[i:idx:manufacture_features] + .flatten() + .astype( + np.int, + ) + ) + + for idx in range(len(distribution_nodes)): + states_1_distribution[i][idx] = ( + distribution_nodes[i:idx:distribution_features] + .flatten() + .astype( + np.int, + ) + ) + + for idx in range(len(seller_nodes)): + states_1_seller[i][idx] = seller_nodes[i:idx:seller_features].flatten().astype(np.int) # ############### Test whether reset updates the consumer unit completely ################ env.reset() env.step(None) # snapshot should reset after env.reset() - states = consumer_nodes[1:consumer_node_index:features].flatten().astype(np.int) - self.assertEqual([0, 0, 0, 0, 0, 0, 0, 0, 0], list(states)) + for idx in range(len(manufacture_nodes)): + states = manufacture_nodes[1:idx:manufacture_features].flatten().astype(np.int) + self.assertEqual([0, 0, 0, 0, 0, 0, 0], list(states)) + + for idx in range(len(storage_nodes)): + states = storage_nodes[1:idx:storage_features].flatten().astype(np.int) + self.assertEqual([0, 0], list(states)) + + for idx in range(len(distribution_nodes)): + states = distribution_nodes[1:idx:distribution_features].flatten().astype(np.int) + self.assertEqual([0, 0, 0, 0], list(states)) + + for idx in range(len(consumer_nodes)): + states = consumer_nodes[1:idx:consumer_features].flatten().astype(np.int) + self.assertEqual([0, 0, 0, 0, 0, 0, 0, 0, 0], list(states)) + + for idx in range(len(seller_nodes)): + states = seller_nodes[1:idx:seller_features].flatten().astype(np.int) + self.assertEqual([0, 0, 0, 0, 0, 0, 0, 0], list(states)) expect_tick = 100 # Save the env.metric of each tick into env_metric_2 env_metric_2: Dict[int, dict] = defaultdict(dict) - # Store the information about the snapshot consumer unit of each tick in states_2 - states_2: Dict[int, list] = defaultdict(list) + # Store the information about the snapshot unit of each tick in states_2 + states_2_consumer: Dict[int, dict] = defaultdict(dict) + states_2_storage: Dict[int, dict] = defaultdict(dict) + states_2_seller: Dict[int, dict] = defaultdict(dict) + states_2_manufacture: Dict[int, dict] = defaultdict(dict) + states_2_distribution: Dict[int, dict] = defaultdict(dict) + for i in range(expect_tick): env.step([action]) env_metric_2[i] = env.metrics - states_2[i] = consumer_nodes[i:consumer_node_index:features].flatten().astype(np.int) + + for idx in range(len(consumer_nodes)): + states_2_consumer[i][idx] = consumer_nodes[i:idx:consumer_features].flatten().astype(np.int) + + for idx in range(len(storage_nodes)): + states_2_storage[i][idx] = list(storage_nodes[i:idx:storage_features].flatten().astype(np.int)) + states_2_storage[i][idx].append(storage_nodes[i:idx:"product_id_list"].flatten().astype(np.int)) + states_2_storage[i][idx].append(storage_nodes[i:idx:"product_quantity"].flatten().astype(np.int)) + states_2_storage[i][idx].append(storage_nodes[i:idx:"remaining_space"].flatten().astype(np.int)) + + for idx in range(len(manufacture_nodes)): + states_2_manufacture[i][idx] = manufacture_nodes[i:idx:manufacture_features].flatten().astype(np.int) + + for idx in range(len(distribution_nodes)): + states_2_distribution[i][idx] = distribution_nodes[i:idx:distribution_features].flatten().astype(np.int) + + for idx in range(len(seller_nodes)): + states_2_seller[i][idx] = seller_nodes[i:idx:seller_features].flatten().astype(np.int) expect_tick = 100 + for i in range(expect_tick): - self.assertEqual(list(states_1[i]), list(states_2[i])) + self.assertEqual(list(states_1_consumer[i]), list(states_2_consumer[i])) + self.assertEqual(list(states_1_storage[i]), list(states_2_storage[i])) + self.assertEqual(list(states_1_seller[i]), list(states_2_seller[i])) + self.assertEqual(list(states_1_manufacture[i]), list(states_2_manufacture[i])) + self.assertEqual(list(states_1_distribution[i]), list(states_2_distribution[i])) self.assertEqual(list(env_metric_1[i].values()), list(env_metric_2[i].values())) def test_env_reset_with_both_ManufactureAction_and_ConsumerAction(self) -> None: """test env reset with both ManufactureAction and ConsumerAction""" - env = build_env("case_01", 100) + env = build_env("case_05", 100) be = env.business_engine assert isinstance(be, SupplyChainBusinessEngine) @@ -446,13 +597,14 @@ def test_env_reset_with_both_ManufactureAction_and_ConsumerAction(self) -> None: supplier_1: FacilityBase = be.world._get_facility_by_name("Supplier_SKU1") supplier_3: FacilityBase = be.world._get_facility_by_name("Supplier_SKU3") - consumer_unit: ConsumerUnit = supplier_1.products[SKU3_ID].consumer + warehouse_1: RetailerFacility = be.world._get_facility_by_name("Warehouse_001") + consumer_unit: ConsumerUnit = warehouse_1.products[SKU1_ID].consumer manufacture_unit: ManufactureUnit = supplier_1.products[SKU1_ID].manufacture storage_unit: StorageUnit = supplier_1.storage consumer_node_index = consumer_unit.data_model_index manufacture_node_index = manufacture_unit.data_model_index - storage_node_index = storage_unit.data_model_index + storage_unit.data_model_index consumer_features = ( "id", @@ -466,6 +618,19 @@ def test_env_reset_with_both_ManufactureAction_and_ConsumerAction(self) -> None: "in_transit_quantity", ) + storage_features = ("id", "facility_id") + + seller_features = ( + "sold", + "demand", + "total_sold", + "id", + "total_demand", + "backlog_ratio", + "facility_id", + "product_unit_id", + ) + manufacture_features = ( "id", "facility_id", @@ -475,21 +640,31 @@ def test_env_reset_with_both_ManufactureAction_and_ConsumerAction(self) -> None: "finished_quantity", "product_unit_id", ) - storage_features = ("id", "facility_id") + + distribution_features = ("id", "facility_id", "pending_order_number", "pending_product_quantity") consumer_nodes = env.snapshot_list["consumer"] - manufacture_nodes = env.snapshot_list["manufacture"] storage_nodes = env.snapshot_list["storage"] + seller_nodes = env.snapshot_list["seller"] + manufacture_nodes = env.snapshot_list["manufacture"] + distribution_nodes = env.snapshot_list["distribution"] # ##################################### Before reset ##################################### - action_consumer = ConsumerAction(consumer_unit.id, SKU3_ID, supplier_3.id, 20, "train") - action_manufacture = ManufactureAction(manufacture_unit.id, 5) + action_consumer = ConsumerAction(consumer_unit.id, SKU1_ID, supplier_1.id, 5, "train") + action_manufacture = ManufactureAction(manufacture_unit.id, 1) expect_tick = 100 # Save the env.metric of each tick into env_metric_1 env_metric_1: Dict[int, dict] = defaultdict(dict) + # Store the information about the snapshot unit of each tick in states_1 + states_1_consumer: Dict[int, dict] = defaultdict(dict) + states_1_storage: Dict[int, dict] = defaultdict(dict) + states_1_seller: Dict[int, dict] = defaultdict(dict) + states_1_manufacture: Dict[int, dict] = defaultdict(dict) + states_1_distribution: Dict[int, dict] = defaultdict(dict) + random_tick: List[int] = [] # The purpose is to randomly perform the order operation @@ -497,39 +672,49 @@ def test_env_reset_with_both_ManufactureAction_and_ConsumerAction(self) -> None: random_tick.append(random.randint(0, 90)) # Store the information about the snapshot unit of each tick in states_1 - states_1_consumer: Dict[int, list] = defaultdict(list) - states_1_manufacture: Dict[int, list] = defaultdict(list) - states_1_storage: Dict[int, list] = defaultdict(list) + states_1_consumer: Dict[int, dict] = defaultdict(dict) + states_1_storage: Dict[int, dict] = defaultdict(dict) + states_1_seller: Dict[int, dict] = defaultdict(dict) + states_1_manufacture: Dict[int, dict] = defaultdict(dict) + states_1_distribution: Dict[int, dict] = defaultdict(dict) for i in range(expect_tick): - if i in random_tick: env.step([action_manufacture]) - i += 1 - states_1_manufacture[i] = list( - manufacture_nodes[i:manufacture_node_index:manufacture_features] + continue + + env.step([action_consumer]) + env_metric_1[i] = env.metrics + + for idx in range(len(consumer_nodes)): + states_1_consumer[i][idx] = consumer_nodes[i:idx:consumer_features].flatten().astype(np.int) + + for idx in range(len(storage_nodes)): + states_1_storage[i][idx] = list(storage_nodes[i:idx:storage_features].flatten().astype(np.int)) + states_1_storage[i][idx].append(storage_nodes[i:idx:"product_id_list"].flatten().astype(np.int)) + states_1_storage[i][idx].append(storage_nodes[i:idx:"product_quantity"].flatten().astype(np.int)) + states_1_storage[i][idx].append(storage_nodes[i:idx:"remaining_space"].flatten().astype(np.int)) + + for idx in range(len(manufacture_nodes)): + states_1_manufacture[i][idx] = ( + manufacture_nodes[i:idx:manufacture_features] .flatten() .astype( np.int, - ), + ) ) - env_metric_1[i] = env.metrics - continue - env.step([action_consumer]) - env_metric_1[i] = env.metrics - states_1_consumer[i] = list( - consumer_nodes[i:consumer_node_index:consumer_features].flatten().astype(np.int), - ) - - states_1_storage[i] = list(storage_nodes[i:storage_node_index:storage_features].flatten().astype(np.int)) - states_1_storage[i].append( - list(storage_nodes[i:storage_node_index:"product_quantity"].flatten().astype(np.int)), - ) - states_1_storage[i].append( - list(storage_nodes[i:storage_node_index:"remaining_space"].flatten().astype(np.int)), - ) + for idx in range(len(distribution_nodes)): + states_1_distribution[i][idx] = ( + distribution_nodes[i:idx:distribution_features] + .flatten() + .astype( + np.int, + ) + ) + for idx in range(len(seller_nodes)): + states_1_seller[i][idx] = seller_nodes[i:idx:seller_features].flatten().astype(np.int) # ############### Test whether reset updates the consumer unit completely ################ env.reset() env.step(None) @@ -546,47 +731,59 @@ def test_env_reset_with_both_ManufactureAction_and_ConsumerAction(self) -> None: env_metric_2: Dict[int, dict] = defaultdict(dict) # Store the information about the snapshot consumer unit of each tick in states_2 - states_2_consumer: Dict[int, list] = defaultdict(list) - states_2_manufacture: Dict[int, list] = defaultdict(list) - states_2_storage: Dict[int, list] = defaultdict(list) + states_2_consumer: Dict[int, dict] = defaultdict(dict) + states_2_storage: Dict[int, dict] = defaultdict(dict) + states_2_seller: Dict[int, dict] = defaultdict(dict) + states_2_manufacture: Dict[int, dict] = defaultdict(dict) + states_2_distribution: Dict[int, dict] = defaultdict(dict) for i in range(expect_tick): - if i in random_tick: env.step([action_manufacture]) - i += 1 - states_2_manufacture[i] = list( - manufacture_nodes[i:manufacture_node_index:manufacture_features] + continue + + env.step([action_consumer]) + env_metric_2[i] = env.metrics + + for idx in range(len(consumer_nodes)): + states_2_consumer[i][idx] = consumer_nodes[i:idx:consumer_features].flatten().astype(np.int) + + for idx in range(len(storage_nodes)): + states_2_storage[i][idx] = list(storage_nodes[i:idx:storage_features].flatten().astype(np.int)) + states_2_storage[i][idx].append(storage_nodes[i:idx:"product_id_list"].flatten().astype(np.int)) + states_2_storage[i][idx].append(storage_nodes[i:idx:"product_quantity"].flatten().astype(np.int)) + states_2_storage[i][idx].append(storage_nodes[i:idx:"remaining_space"].flatten().astype(np.int)) + + for idx in range(len(manufacture_nodes)): + states_2_manufacture[i][idx] = ( + manufacture_nodes[i:idx:manufacture_features] .flatten() .astype( np.int, - ), + ) ) - env_metric_2[i] = env.metrics - continue - env.step([action_consumer]) - env_metric_2[i] = env.metrics - states_2_consumer[i] = list( - consumer_nodes[i:consumer_node_index:consumer_features].flatten().astype(np.int), - ) - - states_2_storage[i] = list(storage_nodes[i:storage_node_index:storage_features].flatten().astype(np.int)) - states_2_storage[i].append( - list( - storage_nodes[i:storage_node_index:"product_quantity"].flatten().astype(np.int), - ), - ) - states_2_storage[i].append( - list(storage_nodes[i:storage_node_index:"remaining_space"].flatten().astype(np.int)), - ) + for idx in range(len(distribution_nodes)): + states_2_distribution[i][idx] = ( + distribution_nodes[i:idx:distribution_features] + .flatten() + .astype( + np.int, + ) + ) + + for idx in range(len(seller_nodes)): + states_2_seller[i][idx] = seller_nodes[i:idx:seller_features].flatten().astype(np.int) expect_tick = 100 for i in range(expect_tick): - self.assertEqual(list(states_1_consumer[i]), list(states_2_consumer[i])) - self.assertEqual(list(states_1_manufacture[i]), list(states_2_manufacture[i])) - self.assertEqual(list(states_1_storage[i]), list(states_2_storage[i])) - self.assertEqual(list(env_metric_1[i].values()), list(env_metric_2[i].values())) + for unit_id, unit in be.world.units.items(): + self.assertEqual(list(states_1_consumer[i]), list(states_2_consumer[i])) + self.assertEqual(list(states_1_storage[i]), list(states_2_storage[i])) + self.assertEqual(list(states_1_seller[i]), list(states_2_seller[i])) + self.assertEqual(list(states_1_manufacture[i]), list(states_2_manufacture[i])) + self.assertEqual(list(states_1_distribution[i]), list(states_2_distribution[i])) + self.assertEqual(list(env_metric_1[i].values()), list(env_metric_2[i].values())) if __name__ == "__main__": From 645d5e2478ef241a20c538f275e59ed9b40f93b8 Mon Sep 17 00:00:00 2001 From: v-heli1 Date: Wed, 22 Jun 2022 13:56:13 +0800 Subject: [PATCH 06/10] Modify test according to comments ++ --- tests/data/supply_chain/case_04/config.yml | 2 +- tests/data/supply_chain/case_05/config.yml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/data/supply_chain/case_04/config.yml b/tests/data/supply_chain/case_04/config.yml index af3085d6e..3e4465b53 100644 --- a/tests/data/supply_chain/case_04/config.yml +++ b/tests/data/supply_chain/case_04/config.yml @@ -244,7 +244,7 @@ world: "food_1": "Warehouse_001": "air": - vlt: 2 + vlt: 7 cost: 3.0 # Unit transportation cost per product per day "train": vlt: 7 diff --git a/tests/data/supply_chain/case_05/config.yml b/tests/data/supply_chain/case_05/config.yml index 08089864e..9c000cb3c 100644 --- a/tests/data/supply_chain/case_05/config.yml +++ b/tests/data/supply_chain/case_05/config.yml @@ -257,7 +257,7 @@ world: storage: *single_storage config: unit_order_cost: 200 - file_path: "tests/data/supply_chain/case_05/test_case_05.csv" + file_path: "tests/data/supply_chain/case_04/test_case_05.csv" topology: From 863d4ab9b77c22343842b113a859f89122faf2b2 Mon Sep 17 00:00:00 2001 From: v-heli1 Date: Fri, 24 Jun 2022 15:38:16 +0800 Subject: [PATCH 07/10] Modify test according to comments +++ Modify test according to comments +++ --- tests/data/supply_chain/case_04/config.yml | 2 +- tests/data/supply_chain/case_05/config.yml | 2 +- tests/supply_chain/common.py | 118 +++ tests/supply_chain/test_action_reset.py | 790 --------------------- tests/supply_chain/test_env_reset.py | 675 +++++++++--------- 5 files changed, 437 insertions(+), 1150 deletions(-) delete mode 100644 tests/supply_chain/test_action_reset.py diff --git a/tests/data/supply_chain/case_04/config.yml b/tests/data/supply_chain/case_04/config.yml index 3e4465b53..af3085d6e 100644 --- a/tests/data/supply_chain/case_04/config.yml +++ b/tests/data/supply_chain/case_04/config.yml @@ -244,7 +244,7 @@ world: "food_1": "Warehouse_001": "air": - vlt: 7 + vlt: 2 cost: 3.0 # Unit transportation cost per product per day "train": vlt: 7 diff --git a/tests/data/supply_chain/case_05/config.yml b/tests/data/supply_chain/case_05/config.yml index 9c000cb3c..08089864e 100644 --- a/tests/data/supply_chain/case_05/config.yml +++ b/tests/data/supply_chain/case_05/config.yml @@ -257,7 +257,7 @@ world: storage: *single_storage config: unit_order_cost: 200 - file_path: "tests/data/supply_chain/case_04/test_case_05.csv" + file_path: "tests/data/supply_chain/case_05/test_case_05.csv" topology: diff --git a/tests/supply_chain/common.py b/tests/supply_chain/common.py index a3ec84689..3d8b8be5d 100644 --- a/tests/supply_chain/common.py +++ b/tests/supply_chain/common.py @@ -2,6 +2,8 @@ # Licensed under the MIT license. import os +from collections import defaultdict +from typing import Dict import numpy as np @@ -23,9 +25,125 @@ def get_product_dict_from_storage(env: Env, frame_index: int, node_index: int): return {sku_id: quantity for sku_id, quantity in zip(sku_id_list, product_quantity)} +def snapshot_query(env: Env, i: int): + consumer_nodes = env.snapshot_list["consumer"] + storage_nodes = env.snapshot_list["storage"] + seller_nodes = env.snapshot_list["seller"] + manufacture_nodes = env.snapshot_list["manufacture"] + distribution_nodes = env.snapshot_list["distribution"] + + states_consumer: Dict[int, list] = defaultdict(list) + states_storage: Dict[int, list] = defaultdict(list) + states_seller: Dict[int, list] = defaultdict(list) + states_manufacture: Dict[int, list] = defaultdict(list) + states_distribution: Dict[int, list] = defaultdict(list) + env_metric: Dict[int, list] = defaultdict(list) + + env_metric = env.metrics + + for idx in range(len(consumer_nodes)): + states_consumer[idx] = consumer_nodes[i:idx:consumer_features].flatten().astype(np.float) + + for idx in range(len(storage_nodes)): + states_storage[idx] = list(storage_nodes[i:idx:storage_features].flatten().astype(np.float)) + states_storage[idx].append(storage_nodes[i:idx:"product_id_list"].flatten().astype(np.int)) + states_storage[idx].append(storage_nodes[i:idx:"product_quantity"].flatten().astype(np.float)) + states_storage[idx].append(storage_nodes[i:idx:"remaining_space"].flatten().astype(np.float)) + + for idx in range(len(manufacture_nodes)): + states_manufacture[idx] = ( + manufacture_nodes[i:idx:manufacture_features] + .flatten() + .astype( + np.float, + ) + ) + + for idx in range(len(distribution_nodes)): + states_distribution[idx] = ( + distribution_nodes[i:idx:distribution_features] + .flatten() + .astype( + np.float, + ) + ) + + for idx in range(len(seller_nodes)): + states_seller[idx] = seller_nodes[i:idx:seller_features].flatten().astype(np.float) + + return env_metric, states_consumer, states_storage, states_seller, states_manufacture, states_distribution + + +def test_env_reset_snapshot_query(env: Env, action_1, action_2, expect_tick: int, random_tick: list): + + states_consumer: Dict[int, dict] = defaultdict(dict) + states_storage: Dict[int, dict] = defaultdict(dict) + states_seller: Dict[int, dict] = defaultdict(dict) + states_manufacture: Dict[int, dict] = defaultdict(dict) + states_distribution: Dict[int, dict] = defaultdict(dict) + env_metric: Dict[int, dict] = defaultdict(dict) + + for i in range(expect_tick): + ( + env_metric[i], + states_consumer[i], + states_storage[i], + states_seller[i], + states_manufacture[i], + states_distribution[i], + ) = snapshot_query( + env, + i, + ) + env.step(action_1) + + if random_tick is not None: + if i in random_tick: + env.step(action_2) + + return env_metric, states_consumer, states_storage, states_seller, states_manufacture, states_distribution + + SKU1_ID = 1 SKU2_ID = 2 SKU3_ID = 3 SKU4_ID = 4 FOOD_1_ID = 20 HOBBY_1_ID = 30 + +consumer_features = ( + "id", + "facility_id", + "sku_id", + "order_base_cost", + "purchased", + "received", + "order_product_cost", + "latest_consumptions", + "in_transit_quantity", +) + +storage_features = ("id", "facility_id") + +seller_features = ( + "sold", + "demand", + "total_sold", + "id", + "total_demand", + "backlog_ratio", + "facility_id", + "product_unit_id", +) + +manufacture_features = ( + "id", + "facility_id", + "start_manufacture_quantity", + "sku_id", + "in_pipeline_quantity", + "finished_quantity", + "product_unit_id", +) + +distribution_features = ("id", "facility_id", "pending_order_number", "pending_product_quantity") diff --git a/tests/supply_chain/test_action_reset.py b/tests/supply_chain/test_action_reset.py deleted file mode 100644 index cd84ffe2d..000000000 --- a/tests/supply_chain/test_action_reset.py +++ /dev/null @@ -1,790 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT license. - -import random -import unittest -from collections import defaultdict -from typing import Dict, List - -import numpy as np - -from maro.simulator.scenarios.supply_chain import ( - ConsumerAction, - ConsumerUnit, - FacilityBase, - ManufactureAction, - ManufactureUnit, - RetailerFacility, - StorageUnit, - WarehouseFacility, -) -from maro.simulator.scenarios.supply_chain.business_engine import SupplyChainBusinessEngine - -from tests.supply_chain.common import SKU1_ID, SKU3_ID, build_env, get_product_dict_from_storage - - -class MyTestCase(unittest.TestCase): - """ - . consumer unit test - . distribution unit test - . manufacture unit test - . seller unit test - . storage unit test - """ - - def test_env_reset_with_none_action(self) -> None: - """test_env_reset_with_none_action""" - env = build_env("case_05", 500) - be = env.business_engine - assert isinstance(be, SupplyChainBusinessEngine) - - env.step(None) - - supplier_1: FacilityBase = be.world._get_facility_by_name("Supplier_SKU1") - warehouse_1 = be.world._get_facility_by_name("Warehouse_001") - Store_001: FacilityBase = be.world._get_facility_by_name("Store_001") - - consumer_unit: ConsumerUnit = warehouse_1.products[SKU3_ID].consumer - storage_unit: StorageUnit = supplier_1.storage - Store_001.products[SKU3_ID].seller - supplier_1.products[SKU1_ID].manufacture - supplier_1.distribution - - consumer_nodes = env.snapshot_list["consumer"] - storage_nodes = env.snapshot_list["storage"] - seller_nodes = env.snapshot_list["seller"] - manufacture_nodes = env.snapshot_list["manufacture"] - distribution_nodes = env.snapshot_list["distribution"] - - consumer_features = ( - "id", - "facility_id", - "sku_id", - "order_base_cost", - "purchased", - "received", - "order_product_cost", - "latest_consumptions", - "in_transit_quantity", - ) - - storage_features = ("id", "facility_id") - - seller_features = ( - "sold", - "demand", - "total_sold", - "id", - "total_demand", - "backlog_ratio", - "facility_id", - "product_unit_id", - ) - - manufacture_features = ( - "id", - "facility_id", - "start_manufacture_quantity", - "sku_id", - "in_pipeline_quantity", - "finished_quantity", - "product_unit_id", - ) - - distribution_features = ("id", "facility_id", "pending_order_number", "pending_product_quantity") - - # ##################################### Before reset ##################################### - - expect_tick = 10 - - # Save the env.metric of each tick into env_metric_1 - env_metric_1: Dict[int, dict] = defaultdict(dict) - - # Store the information about the snapshot of each tick in states_1_x - states_1_consumer: Dict[int, dict] = defaultdict(dict) - states_1_storage: Dict[int, dict] = defaultdict(dict) - states_1_seller: Dict[int, dict] = defaultdict(dict) - states_1_manufacture: Dict[int, dict] = defaultdict(dict) - states_1_distribution: Dict[int, dict] = defaultdict(dict) - - for i in range(expect_tick): - env.step(None) - env_metric_1[i] = env.metrics - - for idx in range(len(consumer_nodes)): - states_1_consumer[i][idx] = consumer_nodes[i:idx:consumer_features].flatten().astype(np.int) - - for idx in range(len(storage_nodes)): - states_1_storage[i][idx] = list(storage_nodes[i:idx:storage_features].flatten().astype(np.int)) - states_1_storage[i][idx].append(storage_nodes[i:idx:"product_id_list"].flatten().astype(np.int)) - states_1_storage[i][idx].append(storage_nodes[i:idx:"product_quantity"].flatten().astype(np.int)) - states_1_storage[i][idx].append(storage_nodes[i:idx:"remaining_space"].flatten().astype(np.int)) - - for idx in range(len(manufacture_nodes)): - states_1_manufacture[i][idx] = manufacture_nodes[i:idx:manufacture_features].flatten().astype(np.int) - - for idx in range(len(distribution_nodes)): - states_1_distribution[i][idx] = distribution_nodes[i:idx:distribution_features].flatten().astype(np.int) - - for idx in range(len(seller_nodes)): - states_1_seller[i][idx] = seller_nodes[i:idx:seller_features].flatten().astype(np.int) - - # ############################### Test whether reset updates the storage unit completely ################ - env.reset() - env.step(None) - - # snapshot should reset after env.reset(). - for idx in range(len(manufacture_nodes)): - states = manufacture_nodes[1:idx:manufacture_features].flatten().astype(np.int) - self.assertEqual([0, 0, 0, 0, 0, 0, 0], list(states)) - - for idx in range(len(storage_nodes)): - states = storage_nodes[1:idx:storage_features].flatten().astype(np.int) - self.assertEqual([0, 0], list(states)) - - for idx in range(len(distribution_nodes)): - states = distribution_nodes[1:idx:distribution_features].flatten().astype(np.int) - self.assertEqual([0, 0, 0, 0], list(states)) - - for idx in range(len(consumer_nodes)): - states = consumer_nodes[1:idx:consumer_features].flatten().astype(np.int) - self.assertEqual([0, 0, 0, 0, 0, 0, 0, 0, 0], list(states)) - - for idx in range(len(seller_nodes)): - states = seller_nodes[1:idx:seller_features].flatten().astype(np.int) - self.assertEqual([0, 0, 0, 0, 0, 0, 0, 0], list(states)) - - expect_tick = 10 - - # Save the env.metric of each tick into env_metric_2 - env_metric_2: Dict[int, dict] = defaultdict(dict) - - # Store the information about the snapshot storage unit of each tick in states_2 - - states_2_consumer: Dict[int, dict] = defaultdict(dict) - states_2_storage: Dict[int, dict] = defaultdict(dict) - states_2_seller: Dict[int, dict] = defaultdict(dict) - states_2_manufacture: Dict[int, dict] = defaultdict(dict) - states_2_distribution: Dict[int, dict] = defaultdict(dict) - - for i in range(expect_tick): - env.step(None) - env_metric_2[i] = env.metrics - - for idx in range(len(consumer_nodes)): - states_2_consumer[i][idx] = consumer_nodes[i:idx:consumer_features].flatten().astype(np.int) - - for idx in range(len(storage_nodes)): - states_2_storage[i][idx] = list(storage_nodes[i:idx:storage_features].flatten().astype(np.int)) - states_2_storage[i][idx].append(storage_nodes[i:idx:"product_id_list"].flatten().astype(np.int)) - states_2_storage[i][idx].append(storage_nodes[i:idx:"product_quantity"].flatten().astype(np.int)) - states_2_storage[i][idx].append(storage_nodes[i:idx:"remaining_space"].flatten().astype(np.int)) - - for idx in range(len(manufacture_nodes)): - states_2_manufacture[i][idx] = manufacture_nodes[i:idx:manufacture_features].flatten().astype(np.int) - - for idx in range(len(distribution_nodes)): - states_2_distribution[i][idx] = distribution_nodes[i:idx:distribution_features].flatten().astype(np.int) - - for idx in range(len(seller_nodes)): - states_2_seller[i][idx] = seller_nodes[i:idx:seller_features].flatten().astype(np.int) - - for i in range(expect_tick): - self.assertEqual(list(states_1_consumer[i]), list(states_2_consumer[i])) - self.assertEqual(list(states_1_storage[i]), list(states_2_storage[i])) - self.assertEqual(list(states_1_seller[i]), list(states_2_seller[i])) - self.assertEqual(list(states_1_manufacture[i]), list(states_2_manufacture[i])) - self.assertEqual(list(states_1_distribution[i]), list(states_2_distribution[i])) - self.assertEqual(list(env_metric_1[i].values()), list(env_metric_2[i].values())) - - def test_env_reset_with_ManufactureAction_only(self) -> None: - """test env reset with ManufactureAction only""" - env = build_env("case_02", 100) - be = env.business_engine - assert isinstance(be, SupplyChainBusinessEngine) - - supplier_3: FacilityBase = be.world._get_facility_by_name("Supplier_SKU3") - warehouse_1: WarehouseFacility = be.world._get_facility_by_name("Warehouse_001") - retailer_1: RetailerFacility = be.world._get_facility_by_name("Retailer_001") - - storage_unit = supplier_3.storage - warehouse_1.products[SKU3_ID].consumer - manufacture_unit = supplier_3.products[SKU3_ID].manufacture - supplier_3.distribution - retailer_1.products[SKU3_ID].seller - - consumer_nodes = env.snapshot_list["consumer"] - storage_nodes = env.snapshot_list["storage"] - seller_nodes = env.snapshot_list["seller"] - manufacture_nodes = env.snapshot_list["manufacture"] - distribution_nodes = env.snapshot_list["distribution"] - - consumer_features = ( - "id", - "facility_id", - "sku_id", - "order_base_cost", - "purchased", - "received", - "order_product_cost", - "latest_consumptions", - "in_transit_quantity", - ) - - storage_features = ("id", "facility_id") - - seller_features = ( - "sold", - "demand", - "total_sold", - "id", - "total_demand", - "backlog_ratio", - "facility_id", - "product_unit_id", - ) - - manufacture_features = ( - "id", - "facility_id", - "start_manufacture_quantity", - "sku_id", - "in_pipeline_quantity", - "finished_quantity", - "product_unit_id", - ) - - distribution_features = ("id", "facility_id", "pending_order_number", "pending_product_quantity") - - # ##################################### Before reset ##################################### - - env.step(None) - - storage_node_index = storage_unit.data_model_index - capacities = storage_nodes[env.frame_index : storage_node_index : "capacity"].flatten().astype(np.int) - remaining_spaces = ( - storage_nodes[env.frame_index : storage_node_index : "remaining_space"].flatten().astype(np.int) - ) - - # there should be 80 units been taken at the beginning according to the config file. - # so remaining space should be 20 - self.assertEqual(20, remaining_spaces.sum()) - # capacity is 100 by config - self.assertEqual(100, capacities.sum()) - - product_dict = get_product_dict_from_storage(env, env.frame_index, storage_node_index) - - # The product quantity should be same as configuration at beginning. - # 80 sku3 - self.assertEqual(80, product_dict[SKU3_ID]) - - ManufactureAction(manufacture_unit.id, 1) - - expect_tick = 30 - - # Save the env.metric of each tick into env_metric_1 - env_metric_1: Dict[int, dict] = defaultdict(dict) - - # Store the information about the snapshot unit of each tick in states_1 - states_1_consumer: Dict[int, dict] = defaultdict(dict) - states_1_storage: Dict[int, dict] = defaultdict(dict) - states_1_seller: Dict[int, dict] = defaultdict(dict) - states_1_manufacture: Dict[int, dict] = defaultdict(dict) - states_1_distribution: Dict[int, dict] = defaultdict(dict) - - random_tick: List[int] = [] - - # The purpose is to randomly perform the order operation - for i in range(10): - random_tick.append(random.randint(1, 30)) - - for i in range(expect_tick): - env.step([ManufactureAction(manufacture_unit.id, 1)]) - env_metric_1[i] = env.metrics - - for idx in range(len(consumer_nodes)): - states_1_consumer[i][idx] = consumer_nodes[i:idx:consumer_features].flatten().astype(np.int) - - for idx in range(len(storage_nodes)): - states_1_storage[i][idx] = list(storage_nodes[i:idx:storage_features].flatten().astype(np.int)) - states_1_storage[i][idx].append(storage_nodes[i:idx:"product_id_list"].flatten().astype(np.int)) - states_1_storage[i][idx].append(storage_nodes[i:idx:"product_quantity"].flatten().astype(np.int)) - states_1_storage[i][idx].append(storage_nodes[i:idx:"remaining_space"].flatten().astype(np.int)) - - for idx in range(len(manufacture_nodes)): - states_1_manufacture[i][idx] = manufacture_nodes[i:idx:manufacture_features].flatten().astype(np.int) - - for idx in range(len(distribution_nodes)): - states_1_distribution[i][idx] = distribution_nodes[i:idx:distribution_features].flatten().astype(np.int) - - for idx in range(len(seller_nodes)): - states_1_seller[i][idx] = seller_nodes[i:idx:seller_features].flatten().astype(np.int) - - if i in random_tick: - env.step([ManufactureAction(manufacture_unit.id, 0)]) - - # ############################### Test whether reset updates the manufacture unit completely ################ - env.reset() - env.step(None) - - # snapshot should reset after env.reset(). - for idx in range(len(manufacture_nodes)): - states = manufacture_nodes[1:idx:manufacture_features].flatten().astype(np.int) - self.assertEqual([0, 0, 0, 0, 0, 0, 0], list(states)) - - for idx in range(len(storage_nodes)): - states = storage_nodes[1:idx:storage_features].flatten().astype(np.int) - self.assertEqual([0, 0], list(states)) - - for idx in range(len(distribution_nodes)): - states = distribution_nodes[1:idx:distribution_features].flatten().astype(np.int) - self.assertEqual([0, 0, 0, 0], list(states)) - - for idx in range(len(consumer_nodes)): - states = consumer_nodes[1:idx:consumer_features].flatten().astype(np.int) - self.assertEqual([0, 0, 0, 0, 0, 0, 0, 0, 0], list(states)) - - for idx in range(len(seller_nodes)): - states = seller_nodes[1:idx:seller_features].flatten().astype(np.int) - self.assertEqual([0, 0, 0, 0, 0, 0, 0, 0], list(states)) - - capacities = storage_nodes[env.frame_index : storage_node_index : "capacity"].flatten().astype(np.int) - remaining_spaces = ( - storage_nodes[env.frame_index : storage_node_index : "remaining_space"].flatten().astype(np.int) - ) - - # there should be 80 units been taken at the beginning according to the config file. - # so remaining space should be 20 - self.assertEqual(20, remaining_spaces.sum()) - # capacity is 100 by config - self.assertEqual(100, capacities.sum()) - - product_dict = get_product_dict_from_storage(env, env.frame_index, storage_node_index) - - # The product quantity should be same as configuration at beginning. - # 80 sku3 - self.assertEqual(80, product_dict[SKU3_ID]) - - # all the id is greater than 0 - self.assertGreater(manufacture_unit.id, 0) - - expect_tick = 30 - - # Save the env.metric of each tick into env_metric_2 - env_metric_2: Dict[int, dict] = defaultdict(dict) - - # Store the information about the snapshot unit of each tick in states_2 - - states_2_consumer: Dict[int, dict] = defaultdict(dict) - states_2_storage: Dict[int, dict] = defaultdict(dict) - states_2_seller: Dict[int, dict] = defaultdict(dict) - states_2_manufacture: Dict[int, dict] = defaultdict(dict) - states_2_distribution: Dict[int, dict] = defaultdict(dict) - - for i in range(expect_tick): - env.step(None) - env_metric_2[i] = env.metrics - - for idx in range(len(consumer_nodes)): - states_2_consumer[i][idx] = consumer_nodes[i:idx:consumer_features].flatten().astype(np.int) - - for idx in range(len(storage_nodes)): - states_2_storage[i][idx] = list(storage_nodes[i:idx:storage_features].flatten().astype(np.int)) - states_2_storage[i][idx].append(storage_nodes[i:idx:"product_id_list"].flatten().astype(np.int)) - states_2_storage[i][idx].append(storage_nodes[i:idx:"product_quantity"].flatten().astype(np.int)) - states_2_storage[i][idx].append(storage_nodes[i:idx:"remaining_space"].flatten().astype(np.int)) - - for idx in range(len(manufacture_nodes)): - states_2_manufacture[i][idx] = manufacture_nodes[i:idx:manufacture_features].flatten().astype(np.int) - - for idx in range(len(distribution_nodes)): - states_2_distribution[i][idx] = distribution_nodes[i:idx:distribution_features].flatten().astype(np.int) - - for idx in range(len(seller_nodes)): - states_2_seller[i][idx] = seller_nodes[i:idx:seller_features].flatten().astype(np.int) - - if i in random_tick: - env.step([ManufactureAction(manufacture_unit.id, 0)]) - - expect_tick = 30 - for i in range(expect_tick): - self.assertEqual(list(states_1_consumer[i]), list(states_2_consumer[i])) - self.assertEqual(list(states_1_storage[i]), list(states_2_storage[i])) - self.assertEqual(list(states_1_seller[i]), list(states_2_seller[i])) - self.assertEqual(list(states_1_manufacture[i]), list(states_2_manufacture[i])) - self.assertEqual(list(states_1_distribution[i]), list(states_2_distribution[i])) - self.assertEqual(list(env_metric_1[i].values()), list(env_metric_2[i].values())) - - def test_env_reset_with_ConsumerAction_only(self) -> None: - """ "test env reset with ConsumerAction only""" - env = build_env("case_05", 500) - be = env.business_engine - assert isinstance(be, SupplyChainBusinessEngine) - - env.step(None) - - warehouse_1 = be.world._get_facility_by_name("Warehouse_001") - supplier_3: FacilityBase = be.world._get_facility_by_name("Supplier_SKU3") - consumer_unit = warehouse_1.products[SKU3_ID].consumer - - consumer_nodes = env.snapshot_list["consumer"] - storage_nodes = env.snapshot_list["storage"] - seller_nodes = env.snapshot_list["seller"] - manufacture_nodes = env.snapshot_list["manufacture"] - distribution_nodes = env.snapshot_list["distribution"] - - consumer_features = ( - "id", - "facility_id", - "sku_id", - "order_base_cost", - "purchased", - "received", - "order_product_cost", - "latest_consumptions", - "in_transit_quantity", - ) - - storage_features = ("id", "facility_id") - - seller_features = ( - "sold", - "demand", - "total_sold", - "id", - "total_demand", - "backlog_ratio", - "facility_id", - "product_unit_id", - ) - - manufacture_features = ( - "id", - "facility_id", - "start_manufacture_quantity", - "sku_id", - "in_pipeline_quantity", - "finished_quantity", - "product_unit_id", - ) - - distribution_features = ("id", "facility_id", "pending_order_number", "pending_product_quantity") - - # ##################################### Before reset ##################################### - action = ConsumerAction(consumer_unit.id, SKU3_ID, supplier_3.id, 1, "train") - expect_tick = 100 - - # Save the env.metric of each tick into env_metric_1 - env_metric_1: Dict[int, dict] = defaultdict(dict) - - # Store the information about the snapshot unit of each tick in states_1 - states_1_consumer: Dict[int, dict] = defaultdict(dict) - states_1_storage: Dict[int, dict] = defaultdict(dict) - states_1_seller: Dict[int, dict] = defaultdict(dict) - states_1_manufacture: Dict[int, dict] = defaultdict(dict) - states_1_distribution: Dict[int, dict] = defaultdict(dict) - - for i in range(expect_tick): - env.step([action]) - env_metric_1[i] = env.metrics - - for idx in range(len(consumer_nodes)): - states_1_consumer[i][idx] = consumer_nodes[i:idx:consumer_features].flatten().astype(np.int) - - for idx in range(len(storage_nodes)): - states_1_storage[i][idx] = list(storage_nodes[i:idx:storage_features].flatten().astype(np.int)) - states_1_storage[i][idx].append(storage_nodes[i:idx:"product_id_list"].flatten().astype(np.int)) - states_1_storage[i][idx].append(storage_nodes[i:idx:"product_quantity"].flatten().astype(np.int)) - states_1_storage[i][idx].append(storage_nodes[i:idx:"remaining_space"].flatten().astype(np.int)) - - for idx in range(len(manufacture_nodes)): - states_1_manufacture[i][idx] = ( - manufacture_nodes[i:idx:manufacture_features] - .flatten() - .astype( - np.int, - ) - ) - - for idx in range(len(distribution_nodes)): - states_1_distribution[i][idx] = ( - distribution_nodes[i:idx:distribution_features] - .flatten() - .astype( - np.int, - ) - ) - - for idx in range(len(seller_nodes)): - states_1_seller[i][idx] = seller_nodes[i:idx:seller_features].flatten().astype(np.int) - - # ############### Test whether reset updates the consumer unit completely ################ - env.reset() - env.step(None) - - # snapshot should reset after env.reset() - for idx in range(len(manufacture_nodes)): - states = manufacture_nodes[1:idx:manufacture_features].flatten().astype(np.int) - self.assertEqual([0, 0, 0, 0, 0, 0, 0], list(states)) - - for idx in range(len(storage_nodes)): - states = storage_nodes[1:idx:storage_features].flatten().astype(np.int) - self.assertEqual([0, 0], list(states)) - - for idx in range(len(distribution_nodes)): - states = distribution_nodes[1:idx:distribution_features].flatten().astype(np.int) - self.assertEqual([0, 0, 0, 0], list(states)) - - for idx in range(len(consumer_nodes)): - states = consumer_nodes[1:idx:consumer_features].flatten().astype(np.int) - self.assertEqual([0, 0, 0, 0, 0, 0, 0, 0, 0], list(states)) - - for idx in range(len(seller_nodes)): - states = seller_nodes[1:idx:seller_features].flatten().astype(np.int) - self.assertEqual([0, 0, 0, 0, 0, 0, 0, 0], list(states)) - - expect_tick = 100 - - # Save the env.metric of each tick into env_metric_2 - env_metric_2: Dict[int, dict] = defaultdict(dict) - - # Store the information about the snapshot unit of each tick in states_2 - states_2_consumer: Dict[int, dict] = defaultdict(dict) - states_2_storage: Dict[int, dict] = defaultdict(dict) - states_2_seller: Dict[int, dict] = defaultdict(dict) - states_2_manufacture: Dict[int, dict] = defaultdict(dict) - states_2_distribution: Dict[int, dict] = defaultdict(dict) - - for i in range(expect_tick): - env.step([action]) - env_metric_2[i] = env.metrics - - for idx in range(len(consumer_nodes)): - states_2_consumer[i][idx] = consumer_nodes[i:idx:consumer_features].flatten().astype(np.int) - - for idx in range(len(storage_nodes)): - states_2_storage[i][idx] = list(storage_nodes[i:idx:storage_features].flatten().astype(np.int)) - states_2_storage[i][idx].append(storage_nodes[i:idx:"product_id_list"].flatten().astype(np.int)) - states_2_storage[i][idx].append(storage_nodes[i:idx:"product_quantity"].flatten().astype(np.int)) - states_2_storage[i][idx].append(storage_nodes[i:idx:"remaining_space"].flatten().astype(np.int)) - - for idx in range(len(manufacture_nodes)): - states_2_manufacture[i][idx] = manufacture_nodes[i:idx:manufacture_features].flatten().astype(np.int) - - for idx in range(len(distribution_nodes)): - states_2_distribution[i][idx] = distribution_nodes[i:idx:distribution_features].flatten().astype(np.int) - - for idx in range(len(seller_nodes)): - states_2_seller[i][idx] = seller_nodes[i:idx:seller_features].flatten().astype(np.int) - - expect_tick = 100 - - for i in range(expect_tick): - self.assertEqual(list(states_1_consumer[i]), list(states_2_consumer[i])) - self.assertEqual(list(states_1_storage[i]), list(states_2_storage[i])) - self.assertEqual(list(states_1_seller[i]), list(states_2_seller[i])) - self.assertEqual(list(states_1_manufacture[i]), list(states_2_manufacture[i])) - self.assertEqual(list(states_1_distribution[i]), list(states_2_distribution[i])) - self.assertEqual(list(env_metric_1[i].values()), list(env_metric_2[i].values())) - - def test_env_reset_with_both_ManufactureAction_and_ConsumerAction(self) -> None: - """test env reset with both ManufactureAction and ConsumerAction""" - env = build_env("case_05", 100) - be = env.business_engine - assert isinstance(be, SupplyChainBusinessEngine) - - env.step(None) - - supplier_1: FacilityBase = be.world._get_facility_by_name("Supplier_SKU1") - supplier_3: FacilityBase = be.world._get_facility_by_name("Supplier_SKU3") - warehouse_1: RetailerFacility = be.world._get_facility_by_name("Warehouse_001") - consumer_unit: ConsumerUnit = warehouse_1.products[SKU1_ID].consumer - manufacture_unit: ManufactureUnit = supplier_1.products[SKU1_ID].manufacture - storage_unit: StorageUnit = supplier_1.storage - - consumer_node_index = consumer_unit.data_model_index - manufacture_node_index = manufacture_unit.data_model_index - storage_unit.data_model_index - - consumer_features = ( - "id", - "facility_id", - "sku_id", - "order_base_cost", - "purchased", - "received", - "order_product_cost", - "latest_consumptions", - "in_transit_quantity", - ) - - storage_features = ("id", "facility_id") - - seller_features = ( - "sold", - "demand", - "total_sold", - "id", - "total_demand", - "backlog_ratio", - "facility_id", - "product_unit_id", - ) - - manufacture_features = ( - "id", - "facility_id", - "start_manufacture_quantity", - "sku_id", - "in_pipeline_quantity", - "finished_quantity", - "product_unit_id", - ) - - distribution_features = ("id", "facility_id", "pending_order_number", "pending_product_quantity") - - consumer_nodes = env.snapshot_list["consumer"] - storage_nodes = env.snapshot_list["storage"] - seller_nodes = env.snapshot_list["seller"] - manufacture_nodes = env.snapshot_list["manufacture"] - distribution_nodes = env.snapshot_list["distribution"] - - # ##################################### Before reset ##################################### - action_consumer = ConsumerAction(consumer_unit.id, SKU1_ID, supplier_1.id, 5, "train") - action_manufacture = ManufactureAction(manufacture_unit.id, 1) - - expect_tick = 100 - - # Save the env.metric of each tick into env_metric_1 - env_metric_1: Dict[int, dict] = defaultdict(dict) - - # Store the information about the snapshot unit of each tick in states_1 - states_1_consumer: Dict[int, dict] = defaultdict(dict) - states_1_storage: Dict[int, dict] = defaultdict(dict) - states_1_seller: Dict[int, dict] = defaultdict(dict) - states_1_manufacture: Dict[int, dict] = defaultdict(dict) - states_1_distribution: Dict[int, dict] = defaultdict(dict) - - random_tick: List[int] = [] - - # The purpose is to randomly perform the order operation - for i in range(30): - random_tick.append(random.randint(0, 90)) - - # Store the information about the snapshot unit of each tick in states_1 - states_1_consumer: Dict[int, dict] = defaultdict(dict) - states_1_storage: Dict[int, dict] = defaultdict(dict) - states_1_seller: Dict[int, dict] = defaultdict(dict) - states_1_manufacture: Dict[int, dict] = defaultdict(dict) - states_1_distribution: Dict[int, dict] = defaultdict(dict) - - for i in range(expect_tick): - if i in random_tick: - env.step([action_manufacture]) - continue - - env.step([action_consumer]) - env_metric_1[i] = env.metrics - - for idx in range(len(consumer_nodes)): - states_1_consumer[i][idx] = consumer_nodes[i:idx:consumer_features].flatten().astype(np.int) - - for idx in range(len(storage_nodes)): - states_1_storage[i][idx] = list(storage_nodes[i:idx:storage_features].flatten().astype(np.int)) - states_1_storage[i][idx].append(storage_nodes[i:idx:"product_id_list"].flatten().astype(np.int)) - states_1_storage[i][idx].append(storage_nodes[i:idx:"product_quantity"].flatten().astype(np.int)) - states_1_storage[i][idx].append(storage_nodes[i:idx:"remaining_space"].flatten().astype(np.int)) - - for idx in range(len(manufacture_nodes)): - states_1_manufacture[i][idx] = ( - manufacture_nodes[i:idx:manufacture_features] - .flatten() - .astype( - np.int, - ) - ) - - for idx in range(len(distribution_nodes)): - states_1_distribution[i][idx] = ( - distribution_nodes[i:idx:distribution_features] - .flatten() - .astype( - np.int, - ) - ) - - for idx in range(len(seller_nodes)): - states_1_seller[i][idx] = seller_nodes[i:idx:seller_features].flatten().astype(np.int) - # ############### Test whether reset updates the consumer unit completely ################ - env.reset() - env.step(None) - - # snapshot should reset after env.reset() - consumer_states = consumer_nodes[1:consumer_node_index:consumer_features].flatten().astype(np.int) - manufacture_states = manufacture_nodes[1:manufacture_node_index:manufacture_features].flatten().astype(np.int) - self.assertEqual([0, 0, 0, 0, 0, 0, 0, 0, 0], list(consumer_states)) - self.assertEqual([0, 0, 0, 0, 0, 0, 0], list(manufacture_states)) - - expect_tick = 100 - - # Save the env.metric of each tick into env_metric_2 - env_metric_2: Dict[int, dict] = defaultdict(dict) - - # Store the information about the snapshot consumer unit of each tick in states_2 - states_2_consumer: Dict[int, dict] = defaultdict(dict) - states_2_storage: Dict[int, dict] = defaultdict(dict) - states_2_seller: Dict[int, dict] = defaultdict(dict) - states_2_manufacture: Dict[int, dict] = defaultdict(dict) - states_2_distribution: Dict[int, dict] = defaultdict(dict) - - for i in range(expect_tick): - if i in random_tick: - env.step([action_manufacture]) - continue - - env.step([action_consumer]) - env_metric_2[i] = env.metrics - - for idx in range(len(consumer_nodes)): - states_2_consumer[i][idx] = consumer_nodes[i:idx:consumer_features].flatten().astype(np.int) - - for idx in range(len(storage_nodes)): - states_2_storage[i][idx] = list(storage_nodes[i:idx:storage_features].flatten().astype(np.int)) - states_2_storage[i][idx].append(storage_nodes[i:idx:"product_id_list"].flatten().astype(np.int)) - states_2_storage[i][idx].append(storage_nodes[i:idx:"product_quantity"].flatten().astype(np.int)) - states_2_storage[i][idx].append(storage_nodes[i:idx:"remaining_space"].flatten().astype(np.int)) - - for idx in range(len(manufacture_nodes)): - states_2_manufacture[i][idx] = ( - manufacture_nodes[i:idx:manufacture_features] - .flatten() - .astype( - np.int, - ) - ) - - for idx in range(len(distribution_nodes)): - states_2_distribution[i][idx] = ( - distribution_nodes[i:idx:distribution_features] - .flatten() - .astype( - np.int, - ) - ) - - for idx in range(len(seller_nodes)): - states_2_seller[i][idx] = seller_nodes[i:idx:seller_features].flatten().astype(np.int) - - expect_tick = 100 - for i in range(expect_tick): - for unit_id, unit in be.world.units.items(): - self.assertEqual(list(states_1_consumer[i]), list(states_2_consumer[i])) - self.assertEqual(list(states_1_storage[i]), list(states_2_storage[i])) - self.assertEqual(list(states_1_seller[i]), list(states_2_seller[i])) - self.assertEqual(list(states_1_manufacture[i]), list(states_2_manufacture[i])) - self.assertEqual(list(states_1_distribution[i]), list(states_2_distribution[i])) - self.assertEqual(list(env_metric_1[i].values()), list(env_metric_2[i].values())) - - -if __name__ == "__main__": - unittest.main() diff --git a/tests/supply_chain/test_env_reset.py b/tests/supply_chain/test_env_reset.py index 2a5ab982d..64cfa806d 100644 --- a/tests/supply_chain/test_env_reset.py +++ b/tests/supply_chain/test_env_reset.py @@ -3,277 +3,163 @@ import random import unittest -from collections import defaultdict -from typing import Dict, List +from typing import List import numpy as np -from maro.simulator.scenarios.supply_chain import ConsumerAction, FacilityBase, ManufactureAction, StorageUnit +from maro.simulator.scenarios.supply_chain import ( + ConsumerAction, + ConsumerUnit, + FacilityBase, + ManufactureAction, + ManufactureUnit, + RetailerFacility, + StorageUnit, + WarehouseFacility, +) from maro.simulator.scenarios.supply_chain.business_engine import SupplyChainBusinessEngine -from maro.simulator.scenarios.supply_chain.order import Order -from tests.supply_chain.common import FOOD_1_ID, SKU3_ID, build_env, get_product_dict_from_storage +from tests.supply_chain.common import ( + SKU1_ID, + SKU3_ID, + build_env, + get_product_dict_from_storage, + snapshot_query, + test_env_reset_snapshot_query, +) class MyTestCase(unittest.TestCase): """ - . consumer unit test - . distribution unit test - . manufacture unit test - . seller unit test - . storage unit test + . test env reset with none action + . with ManufactureAction only + . with ConsumerAction only + . with both ManufactureAction and ConsumerAction """ - def test_consumer_unit_reset(self) -> None: - """Test whether reset updates the consumer unit completely""" - env = build_env("case_01", 500) + def test_env_reset_with_none_action(self) -> None: + """test_env_reset_with_none_action""" + env = build_env("case_05", 500) be = env.business_engine assert isinstance(be, SupplyChainBusinessEngine) env.step(None) supplier_1: FacilityBase = be.world._get_facility_by_name("Supplier_SKU1") - supplier_3: FacilityBase = be.world._get_facility_by_name("Supplier_SKU3") - sku3_consumer_unit = supplier_1.products[SKU3_ID].consumer - - consumer_node_index = sku3_consumer_unit.data_model_index - - features = ( - "id", - "facility_id", - "sku_id", - "order_base_cost", - "purchased", - "received", - "order_product_cost", - "latest_consumptions", - "in_transit_quantity", - ) - - # ##################################### Before reset ##################################### - consumer_nodes = env.snapshot_list["consumer"] - action = ConsumerAction(sku3_consumer_unit.id, SKU3_ID, supplier_3.id, 1, "train") - expect_tick = 100 - - # Save the env.metric of each tick into env_metric_1 - env_metric_1: Dict[int, dict] = defaultdict(dict) - - # Store the information about the snapshot consumer unit of each tick in states_1 - states_1: Dict[int, list] = defaultdict(list) - - for i in range(expect_tick): - env.step([action]) - env_metric_1[i] = env.metrics - states_1[i] = consumer_nodes[i:consumer_node_index:features].flatten().astype(np.int) - - # ############### Test whether reset updates the consumer unit completely ################ - env.reset() - env.step(None) - - # snapshot should reset after env.reset() - states = consumer_nodes[1:consumer_node_index:features].flatten().astype(np.int) - self.assertEqual([0, 0, 0, 0, 0, 0, 0, 0, 0], list(states)) - - expect_tick = 100 - - # Save the env.metric of each tick into env_metric_2 - env_metric_2: Dict[int, dict] = defaultdict(dict) - - # Store the information about the snapshot consumer unit of each tick in states_2 - states_2: Dict[int, list] = defaultdict(list) - for i in range(expect_tick): - env.step([action]) - env_metric_2[i] = env.metrics - states_2[i] = consumer_nodes[i:consumer_node_index:features].flatten().astype(np.int) - - expect_tick = 100 - for i in range(expect_tick): - self.assertEqual(list(states_1[i]), list(states_2[i])) - self.assertEqual(list(env_metric_1[i].values()), list(env_metric_2[i].values())) - - def test_distribution_unit_reset(self) -> None: - """Test initial state of the DistributionUnit of Supplier_SKU3.Test distribution unit reset""" - env = build_env("case_02", 100) - be = env.business_engine - assert isinstance(be, SupplyChainBusinessEngine) - - env.step(None) - - supplier_3 = be.world._get_facility_by_name("Supplier_SKU3") warehouse_1 = be.world._get_facility_by_name("Warehouse_001") + Store_001: FacilityBase = be.world._get_facility_by_name("Store_001") - distribution_unit = supplier_3.distribution - distribution_node_index = distribution_unit.data_model_index - distribution_nodes = env.snapshot_list["distribution"] + consumer_unit: ConsumerUnit = warehouse_1.products[SKU3_ID].consumer + storage_unit: StorageUnit = supplier_1.storage + Store_001.products[SKU3_ID].seller + supplier_1.products[SKU1_ID].manufacture + supplier_1.distribution - features = ("id", "facility_id", "pending_order_number", "pending_product_quantity") + env.snapshot_list["consumer"] + env.snapshot_list["storage"] + env.snapshot_list["seller"] + env.snapshot_list["manufacture"] + env.snapshot_list["distribution"] # ##################################### Before reset ##################################### - order_1 = Order( - src_facility=supplier_3, - dest_facility=warehouse_1, - sku_id=SKU3_ID, - quantity=10, - vehicle_type="train", - creation_tick=env.tick, - expected_finish_tick=env.tick + 7, - ) - - # There are 2 "train" in total, and 1 left after scheduling this order. - distribution_unit.place_order(order_1) - distribution_unit.try_schedule_orders(env.tick) - self.assertEqual(0, len(distribution_unit._order_queues["train"])) - self.assertEqual(0, sum([order.quantity for order in distribution_unit._order_queues["train"]])) - - order_2 = Order( - src_facility=supplier_3, - dest_facility=warehouse_1, - sku_id=SKU3_ID, - quantity=10, - vehicle_type="train", - creation_tick=env.tick, - expected_finish_tick=env.tick + 7, - ) - - distribution_unit.place_order(order_2) - distribution_unit.try_schedule_orders(env.tick) - self.assertEqual(0, len(distribution_unit._order_queues["train"])) - self.assertEqual(0, sum([order.quantity for order in distribution_unit._order_queues["train"]])) - - # 3rd order, will cause the pending order increase - order_3 = Order( - src_facility=supplier_3, - dest_facility=warehouse_1, - sku_id=SKU3_ID, - quantity=10, - vehicle_type="train", - creation_tick=env.tick, - expected_finish_tick=env.tick + 7, - ) - distribution_unit.place_order(order_3) - distribution_unit.try_schedule_orders(env.tick) - self.assertEqual(1, len(distribution_unit._order_queues["train"])) - self.assertEqual(10, sum([order.required_quantity for order in distribution_unit._order_queues["train"]])) - - env.step(None) - - # The purpose is to randomly perform the order operation - random_tick: List[int] = [] - for j in range(10): - random_tick.append(random.randint(5, 100)) - - expect_tick = 100 + expect_tick = 10 # Save the env.metric of each tick into env_metric_1 - env_metric_1: Dict[int, dict] = defaultdict(dict) - - # Store the information about the snapshot distribution unit of each tick in states_1 - states_1: Dict[int, list] = defaultdict(list) + # Store the information about the snapshot unit of each tick in states_1_unit + ( + env_metric_1, + states_1_consumer, + states_1_storage, + states_1_seller, + states_1_manufacture, + states_1_distribution, + ) = test_env_reset_snapshot_query( + env, + None, + None, + expect_tick, + None, + ) - for i in range(expect_tick): - if i in random_tick: - order = Order( - src_facility=supplier_3, - dest_facility=warehouse_1, - sku_id=SKU3_ID, - quantity=10, - vehicle_type="train", - creation_tick=env.tick, - expected_finish_tick=env.tick + 7, - ) - distribution_unit.place_order(order) - distribution_unit.try_schedule_orders(env.tick) - env.step(None) - env_metric_1[i] = env.metrics - states_1[i] = distribution_nodes[i:distribution_node_index:features].flatten().astype(np.int) - - # ####################### Test whether reset updates the distribution unit completely ################ + # ############################### Test whether reset updates the storage unit completely ################ env.reset() env.step(None) - distribution_nodes = env.snapshot_list["distribution"] - - # snapshot should reset after env.reset(). - states = distribution_nodes[1:distribution_node_index:features].flatten().astype(np.int) - self.assertEqual([0, 0, 0, 0], list(states)) - - # Do the same as before env.reset(). - distribution_unit.place_order(order_1) - distribution_unit.try_schedule_orders(env.tick) - - distribution_unit.place_order(order_2) - distribution_unit.try_schedule_orders(env.tick) - self.assertEqual(0, len(distribution_unit._order_queues["train"])) - self.assertEqual(0, sum([order.quantity for order in distribution_unit._order_queues["train"]])) - - distribution_unit.place_order(order_3) - distribution_unit.try_schedule_orders(env.tick) - self.assertEqual(1, len(distribution_unit._order_queues["train"])) - self.assertEqual(10, sum([order.required_quantity for order in distribution_unit._order_queues["train"]])) - - env.step(None) - - expect_tick = 100 - # Save the env.metric of each tick into env_metric_2. - env_metric_2: Dict[int, dict] = defaultdict(dict) - - # Store the information about the snapshot distribution unit of each tick in states_2. - states_2: Dict[int, list] = defaultdict(list) + # Check snapshot initial state after env.reset() + ( + env_metric_initial, + states_consumer_initial, + states_storage_initial, + states_seller_initial, + states_manufacture_initial, + states_distribution_initial, + ) = snapshot_query( + env, + 0, + ) + self.assertEqual(list(states_1_consumer[0]), list(states_consumer_initial)) + self.assertEqual(list(states_1_storage[0]), list(states_storage_initial)) + self.assertEqual(list(states_1_seller[0]), list(states_seller_initial)) + self.assertEqual(list(states_1_manufacture[0]), list(states_manufacture_initial)) + self.assertEqual(list(states_1_distribution[0]), list(states_distribution_initial)) + self.assertEqual(list(env_metric_1[0].values()), list(env_metric_initial.values())) - for i in range(expect_tick): - if i in random_tick: - order = Order( - src_facility=supplier_3, - dest_facility=warehouse_1, - sku_id=SKU3_ID, - quantity=10, - vehicle_type="train", - creation_tick=env.tick, - expected_finish_tick=env.tick + 7, - ) - distribution_unit.place_order(order) - distribution_unit.try_schedule_orders(env.tick) - env.step(None) - env_metric_2[i] = env.metrics - states_2[i] = distribution_nodes[i:distribution_node_index:features].flatten().astype(np.int) + # Save the env.metric of each tick into env_metric_2 + # Store the information about the snapshot unit of each tick in states_2_unit + ( + env_metric_2, + states_2_consumer, + states_2_storage, + states_2_seller, + states_2_manufacture, + states_2_distribution, + ) = test_env_reset_snapshot_query( + env, + None, + None, + expect_tick, + None, + ) - expect_tick = 100 for i in range(expect_tick): - self.assertEqual(list(states_1[i]), list(states_2[i])) + self.assertEqual(list(states_1_consumer[i]), list(states_2_consumer[i])) + self.assertEqual(list(states_1_storage[i]), list(states_2_storage[i])) + self.assertEqual(list(states_1_seller[i]), list(states_2_seller[i])) + self.assertEqual(list(states_1_manufacture[i]), list(states_2_manufacture[i])) + self.assertEqual(list(states_1_distribution[i]), list(states_2_distribution[i])) self.assertEqual(list(env_metric_1[i].values()), list(env_metric_2[i].values())) - def test_manufacture_unit_reset(self) -> None: - """Test sku3 manufacturing. -- Supplier_SKU3.Test manufacture unit reset""" - env = build_env("case_01", 100) + def test_env_reset_with_ManufactureAction_only(self) -> None: + """test env reset with ManufactureAction only""" + env = build_env("case_02", 100) be = env.business_engine assert isinstance(be, SupplyChainBusinessEngine) supplier_3: FacilityBase = be.world._get_facility_by_name("Supplier_SKU3") - sku3_storage_index = supplier_3.storage.data_model_index - manufacture_sku3_unit = supplier_3.products[SKU3_ID].manufacture - sku3_manufacture_index = manufacture_sku3_unit.data_model_index + warehouse_1: WarehouseFacility = be.world._get_facility_by_name("Warehouse_001") + retailer_1: RetailerFacility = be.world._get_facility_by_name("Retailer_001") + + storage_unit = supplier_3.storage + warehouse_1.products[SKU3_ID].consumer + manufacture_unit = supplier_3.products[SKU3_ID].manufacture + supplier_3.distribution + retailer_1.products[SKU3_ID].seller + env.snapshot_list["consumer"] storage_nodes = env.snapshot_list["storage"] - manufacture_nodes = env.snapshot_list["manufacture"] - - manufacture_features = ( - "id", - "facility_id", - "start_manufacture_quantity", - "sku_id", - "in_pipeline_quantity", - "finished_quantity", - "product_unit_id", - ) + env.snapshot_list["seller"] + env.snapshot_list["manufacture"] + env.snapshot_list["distribution"] # ##################################### Before reset ##################################### env.step(None) - capacities = storage_nodes[env.frame_index : sku3_storage_index : "capacity"].flatten().astype(np.int) + storage_node_index = storage_unit.data_model_index + capacities = storage_nodes[env.frame_index : storage_node_index : "capacity"].flatten().astype(np.int) remaining_spaces = ( - storage_nodes[env.frame_index : sku3_storage_index : "remaining_space"].flatten().astype(np.int) + storage_nodes[env.frame_index : storage_node_index : "remaining_space"].flatten().astype(np.int) ) # there should be 80 units been taken at the beginning according to the config file. @@ -282,24 +168,18 @@ def test_manufacture_unit_reset(self) -> None: # capacity is 100 by config self.assertEqual(100, capacities.sum()) - product_dict = get_product_dict_from_storage(env, env.frame_index, sku3_storage_index) + product_dict = get_product_dict_from_storage(env, env.frame_index, storage_node_index) # The product quantity should be same as configuration at beginning. # 80 sku3 self.assertEqual(80, product_dict[SKU3_ID]) - # all the id is greater than 0 - self.assertGreater(manufacture_sku3_unit.id, 0) - - action = ManufactureAction(manufacture_sku3_unit.id, 1) + ManufactureAction(manufacture_unit.id, 1) expect_tick = 30 - # Save the env.metric of each tick into env_metric_1 - env_metric_1: Dict[int, dict] = defaultdict(dict) - - # Store the information about the snapshot manufacture unit of each tick in states_1 - states_1: Dict[int, list] = defaultdict(list) + action_1 = ManufactureAction(manufacture_unit.id, 1) + action_2 = ManufactureAction(manufacture_unit.id, 0) random_tick: List[int] = [] @@ -307,27 +187,49 @@ def test_manufacture_unit_reset(self) -> None: for i in range(10): random_tick.append(random.randint(1, 30)) - for i in range(expect_tick): - env.step([action]) - if i in random_tick: - env.step([ManufactureAction(manufacture_sku3_unit.id, 0)]) - env_metric_1[i] = env.metrics - states_1[i] = manufacture_nodes[i:sku3_manufacture_index:manufacture_features].flatten().astype(np.int) + # Save the env.metric of each tick into env_metric_1 + # Store the information about the snapshot unit of each tick in states_1_unit + ( + env_metric_1, + states_1_consumer, + states_1_storage, + states_1_seller, + states_1_manufacture, + states_1_distribution, + ) = test_env_reset_snapshot_query( + env, + action_1, + action_2, + expect_tick, + random_tick, + ) # ############################### Test whether reset updates the manufacture unit completely ################ env.reset() env.step(None) - # snapshot should reset after env.reset(). - states = manufacture_nodes[1:sku3_manufacture_index:manufacture_features].flatten().astype(np.int) - self.assertEqual([0, 0, 0, 0, 0, 0, 0], list(states)) - - storage_nodes = env.snapshot_list["storage"] - manufacture_nodes = env.snapshot_list["manufacture"] - - capacities = storage_nodes[env.frame_index : sku3_storage_index : "capacity"].flatten().astype(np.int) + # Check snapshot initial state after env.reset() + ( + env_metric_initial, + states_consumer_initial, + states_storage_initial, + states_seller_initial, + states_manufacture_initial, + states_distribution_initial, + ) = snapshot_query( + env, + 0, + ) + self.assertEqual(list(states_1_consumer[0]), list(states_consumer_initial)) + self.assertEqual(list(states_1_storage[0]), list(states_storage_initial)) + self.assertEqual(list(states_1_seller[0]), list(states_seller_initial)) + self.assertEqual(list(states_1_manufacture[0]), list(states_manufacture_initial)) + self.assertEqual(list(states_1_distribution[0]), list(states_distribution_initial)) + self.assertEqual(list(env_metric_1[0].values()), list(env_metric_initial.values())) + + capacities = storage_nodes[env.frame_index : storage_node_index : "capacity"].flatten().astype(np.int) remaining_spaces = ( - storage_nodes[env.frame_index : sku3_storage_index : "remaining_space"].flatten().astype(np.int) + storage_nodes[env.frame_index : storage_node_index : "remaining_space"].flatten().astype(np.int) ) # there should be 80 units been taken at the beginning according to the config file. @@ -336,166 +238,223 @@ def test_manufacture_unit_reset(self) -> None: # capacity is 100 by config self.assertEqual(100, capacities.sum()) - product_dict = get_product_dict_from_storage(env, env.frame_index, sku3_storage_index) + product_dict = get_product_dict_from_storage(env, env.frame_index, storage_node_index) # The product quantity should be same as configuration at beginning. # 80 sku3 self.assertEqual(80, product_dict[SKU3_ID]) # all the id is greater than 0 - self.assertGreater(manufacture_sku3_unit.id, 0) + self.assertGreater(manufacture_unit.id, 0) expect_tick = 30 # Save the env.metric of each tick into env_metric_2 - env_metric_2: Dict[int, dict] = defaultdict(dict) - - # Store the information about the snapshot manufacture unit of each tick in states_2 - states_2: Dict[int, list] = defaultdict(list) - - for i in range(expect_tick): - env.step([action]) - if i in random_tick: - env.step([ManufactureAction(manufacture_sku3_unit.id, 0)]) - env_metric_2[i] = env.metrics - states_2[i] = manufacture_nodes[i:sku3_manufacture_index:manufacture_features].flatten().astype(np.int) + # Store the information about the snapshot unit of each tick in states_2_unit + ( + env_metric_2, + states_2_consumer, + states_2_storage, + states_2_seller, + states_2_manufacture, + states_2_distribution, + ) = test_env_reset_snapshot_query( + env, + action_1, + action_2, + expect_tick, + random_tick, + ) - expect_tick = 30 for i in range(expect_tick): - self.assertEqual(list(states_1[i]), list(states_2[i])) + self.assertEqual(list(states_1_consumer[i]), list(states_2_consumer[i])) + self.assertEqual(list(states_1_storage[i]), list(states_2_storage[i])) + self.assertEqual(list(states_1_seller[i]), list(states_2_seller[i])) + self.assertEqual(list(states_1_manufacture[i]), list(states_2_manufacture[i])) + self.assertEqual(list(states_1_distribution[i]), list(states_2_distribution[i])) self.assertEqual(list(env_metric_1[i].values()), list(env_metric_2[i].values())) - def test_seller_unit_dynamics_sampler(self): - """Tested the store_001 Interaction between seller unit and dynamics csv data. - The data file of this test is test_case_ 04.csv""" - env = build_env("case_04", 600) + def test_env_reset_with_ConsumerAction_only(self) -> None: + """ "test env reset with ConsumerAction only""" + env = build_env("case_05", 500) be = env.business_engine assert isinstance(be, SupplyChainBusinessEngine) env.step(None) - Store_001: FacilityBase = be.world._get_facility_by_name("Store_001") - - seller_unit = Store_001.products[FOOD_1_ID].seller - seller_node_index = seller_unit.data_model_index - seller_nodes = env.snapshot_list["seller"] - - features = ( - "sold", - "demand", - "total_sold", - "id", - "total_demand", - "backlog_ratio", - "facility_id", - "product_unit_id", - ) - # ##################################### Before reset ##################################### - self.assertEqual(20, seller_unit.sku_id) - - # NOTE: this simple seller unit return demands that same as current tick + warehouse_1 = be.world._get_facility_by_name("Warehouse_001") + supplier_3: FacilityBase = be.world._get_facility_by_name("Supplier_SKU3") + consumer_unit = warehouse_1.products[SKU3_ID].consumer - # Tick 0 will have demand == 10.first row of data after preprocessing data. - # from sample_preprocessed.csv - self.assertEqual(10, seller_unit._sold) - self.assertEqual(10, seller_unit._demand) - self.assertEqual(10, seller_unit._total_sold) + env.snapshot_list["consumer"] + env.snapshot_list["storage"] + env.snapshot_list["seller"] + env.snapshot_list["manufacture"] + env.snapshot_list["distribution"] - expect_tick = 12 + # ##################################### Before reset ##################################### + action = ConsumerAction(consumer_unit.id, SKU3_ID, supplier_3.id, 1, "train") + expect_tick = 100 # Save the env.metric of each tick into env_metric_1 - env_metric_1: Dict[int, dict] = defaultdict(dict) - - # Store the information about the snapshot seller unit of each tick in states_1 - states_1: Dict[int, list] = defaultdict(list) - for i in range(expect_tick): - env.step(None) - env_metric_1[i] = env.metrics - states_1[i] = seller_nodes[i:seller_node_index:features].flatten().astype(np.int) + # Store the information about the snapshot unit of each tick in states_1_unit + ( + env_metric_1, + states_1_consumer, + states_1_storage, + states_1_seller, + states_1_manufacture, + states_1_distribution, + ) = test_env_reset_snapshot_query( + env, + action, + None, + expect_tick, + None, + ) - # ################# Test whether reset updates the seller unit completely ################ + # ############### Test whether reset updates the consumer unit completely ################ env.reset() env.step(None) - # snapshot should reset after env.reset(). - states = seller_nodes[1:seller_node_index:features].flatten().astype(np.int) - self.assertEqual([0, 0, 0, 0, 0, 0, 0, 0], list(states)) - - expect_tick = 12 + # Check snapshot initial state after env.reset() + ( + env_metric_initial, + states_consumer_initial, + states_storage_initial, + states_seller_initial, + states_manufacture_initial, + states_distribution_initial, + ) = snapshot_query( + env, + 0, + ) + self.assertEqual(list(states_1_consumer[0]), list(states_consumer_initial)) + self.assertEqual(list(states_1_storage[0]), list(states_storage_initial)) + self.assertEqual(list(states_1_seller[0]), list(states_seller_initial)) + self.assertEqual(list(states_1_manufacture[0]), list(states_manufacture_initial)) + self.assertEqual(list(states_1_distribution[0]), list(states_distribution_initial)) + self.assertEqual(list(env_metric_1[0].values()), list(env_metric_initial.values())) # Save the env.metric of each tick into env_metric_2 - env_metric_2: Dict[int, dict] = defaultdict(dict) - - # Store the information about the snapshot seller unit of each tick in states_2 - states_2: Dict[int, list] = defaultdict(list) - - for i in range(expect_tick): - env.step(None) - env_metric_2[i] = env.metrics - states_2[i] = seller_nodes[i:seller_node_index:features].flatten().astype(np.int) + # Store the information about the snapshot unit of each tick in states_2_unit + ( + env_metric_2, + states_2_consumer, + states_2_storage, + states_2_seller, + states_2_manufacture, + states_2_distribution, + ) = test_env_reset_snapshot_query( + env, + action, + None, + expect_tick, + None, + ) for i in range(expect_tick): - self.assertEqual(list(states_1[i]), list(states_2[i])) + self.assertEqual(list(states_1_consumer[i]), list(states_2_consumer[i])) + self.assertEqual(list(states_1_storage[i]), list(states_2_storage[i])) + self.assertEqual(list(states_1_seller[i]), list(states_2_seller[i])) + self.assertEqual(list(states_1_manufacture[i]), list(states_2_manufacture[i])) + self.assertEqual(list(states_1_distribution[i]), list(states_2_distribution[i])) self.assertEqual(list(env_metric_1[i].values()), list(env_metric_2[i].values())) - def test_storage_unit_reset(self) -> None: - """Facility with single SKU. -- Supplier_SKU3""" - env = build_env("case_01", 100) + def test_env_reset_with_both_ManufactureAction_and_ConsumerAction(self) -> None: + """test env reset with both ManufactureAction and ConsumerAction""" + env = build_env("case_05", 100) be = env.business_engine assert isinstance(be, SupplyChainBusinessEngine) env.step(None) - supplier_3: FacilityBase = be.world._get_facility_by_name("Supplier_SKU3") + supplier_1: FacilityBase = be.world._get_facility_by_name("Supplier_SKU1") + warehouse_1: RetailerFacility = be.world._get_facility_by_name("Warehouse_001") + consumer_unit: ConsumerUnit = warehouse_1.products[SKU1_ID].consumer + manufacture_unit: ManufactureUnit = supplier_1.products[SKU1_ID].manufacture - storage_unit: StorageUnit = supplier_3.storage - storage_node_index = storage_unit.data_model_index - storage_nodes = env.snapshot_list["storage"] - features = ("id", "facility_id") + consumer_unit.data_model_index + manufacture_unit.data_model_index + + env.snapshot_list["consumer"] + env.snapshot_list["manufacture"] # ##################################### Before reset ##################################### + action_consumer = ConsumerAction(consumer_unit.id, SKU1_ID, supplier_1.id, 5, "train") + action_manufacture = ManufactureAction(manufacture_unit.id, 1) - expect_tick = 10 + expect_tick = 100 - # Save the env.metric of each tick into env_metric_1 - env_metric_1: Dict[int, dict] = defaultdict(dict) + random_tick: List[int] = [] - # Store the information about the snapshot storage unit of each tick in states_1 - states_1: Dict[int, list] = defaultdict(list) - for i in range(expect_tick): - env.step(None) - env_metric_1[i] = env.metrics - states_1[i] = list(storage_nodes[i:storage_node_index:features].flatten().astype(np.int)) - states_1[i].append(storage_nodes[i:storage_node_index:"product_id_list"].flatten().astype(np.int).sum()) - states_1[i].append(storage_nodes[i:storage_node_index:"product_quantity"].flatten().astype(np.int).sum()) - states_1[i].append(storage_nodes[i:storage_node_index:"remaining_space"].flatten().astype(np.int).sum()) + # The purpose is to randomly perform the order operation + for i in range(30): + random_tick.append(random.randint(0, 90)) - # ############################### Test whether reset updates the storage unit completely ################ + # Save the env.metric of each tick into env_metric_1 + # Store the information about the snapshot unit of each tick in states_1_unit + ( + env_metric_1, + states_1_consumer, + states_1_storage, + states_1_seller, + states_1_manufacture, + states_1_distribution, + ) = test_env_reset_snapshot_query( + env, + action_consumer, + action_manufacture, + expect_tick, + random_tick, + ) + + # ############### Test whether reset updates the consumer unit completely ################ env.reset() env.step(None) - # snapshot should reset after env.reset(). - states = storage_nodes[1:storage_node_index:features].flatten().astype(np.int) - self.assertEqual([0, 0], list(states)) - - expect_tick = 10 + # Check snapshot initial state after env.reset() + ( + env_metric_initial, + states_consumer_initial, + states_storage_initial, + states_seller_initial, + states_manufacture_initial, + states_distribution_initial, + ) = snapshot_query( + env, + 0, + ) + self.assertEqual(list(states_1_consumer[0]), list(states_consumer_initial)) + self.assertEqual(list(states_1_storage[0]), list(states_storage_initial)) + self.assertEqual(list(states_1_seller[0]), list(states_seller_initial)) + self.assertEqual(list(states_1_manufacture[0]), list(states_manufacture_initial)) + self.assertEqual(list(states_1_distribution[0]), list(states_distribution_initial)) + self.assertEqual(list(env_metric_1[0].values()), list(env_metric_initial.values())) # Save the env.metric of each tick into env_metric_2 - env_metric_2: Dict[int, dict] = defaultdict(dict) - - # Store the information about the snapshot storage unit of each tick in states_2 - states_2: Dict[int, list] = defaultdict(list) - - for i in range(expect_tick): - env.step(None) - env_metric_2[i] = env.metrics - states_2[i] = list(storage_nodes[i:storage_node_index:features].flatten().astype(np.int)) - states_2[i].append(storage_nodes[i:storage_node_index:"product_id_list"].flatten().astype(np.int).sum()) - states_2[i].append(storage_nodes[i:storage_node_index:"product_quantity"].flatten().astype(np.int).sum()) - states_2[i].append(storage_nodes[i:storage_node_index:"remaining_space"].flatten().astype(np.int).sum()) + # Store the information about the snapshot unit of each tick in states_2_unit + ( + env_metric_2, + states_2_consumer, + states_2_storage, + states_2_seller, + states_2_manufacture, + states_2_distribution, + ) = test_env_reset_snapshot_query( + env, + action_consumer, + action_manufacture, + expect_tick, + random_tick, + ) for i in range(expect_tick): - self.assertEqual(list(states_1[i]), list(states_2[i])) + self.assertEqual(list(states_1_consumer[i]), list(states_2_consumer[i])) + self.assertEqual(list(states_1_storage[i]), list(states_2_storage[i])) + self.assertEqual(list(states_1_seller[i]), list(states_2_seller[i])) + self.assertEqual(list(states_1_manufacture[i]), list(states_2_manufacture[i])) + self.assertEqual(list(states_1_distribution[i]), list(states_2_distribution[i])) self.assertEqual(list(env_metric_1[i].values()), list(env_metric_2[i].values())) From 5800acce55d112a71f3603a27d9e09cd75ba2ff2 Mon Sep 17 00:00:00 2001 From: v-heli1 Date: Fri, 24 Jun 2022 16:32:39 +0800 Subject: [PATCH 08/10] Modify test according to comments +++ --- tests/supply_chain/test_env_reset.py | 40 ---------------------------- 1 file changed, 40 deletions(-) diff --git a/tests/supply_chain/test_env_reset.py b/tests/supply_chain/test_env_reset.py index 64cfa806d..ff511a2c8 100644 --- a/tests/supply_chain/test_env_reset.py +++ b/tests/supply_chain/test_env_reset.py @@ -14,8 +14,6 @@ ManufactureAction, ManufactureUnit, RetailerFacility, - StorageUnit, - WarehouseFacility, ) from maro.simulator.scenarios.supply_chain.business_engine import SupplyChainBusinessEngine @@ -45,22 +43,6 @@ def test_env_reset_with_none_action(self) -> None: env.step(None) - supplier_1: FacilityBase = be.world._get_facility_by_name("Supplier_SKU1") - warehouse_1 = be.world._get_facility_by_name("Warehouse_001") - Store_001: FacilityBase = be.world._get_facility_by_name("Store_001") - - consumer_unit: ConsumerUnit = warehouse_1.products[SKU3_ID].consumer - storage_unit: StorageUnit = supplier_1.storage - Store_001.products[SKU3_ID].seller - supplier_1.products[SKU1_ID].manufacture - supplier_1.distribution - - env.snapshot_list["consumer"] - env.snapshot_list["storage"] - env.snapshot_list["seller"] - env.snapshot_list["manufacture"] - env.snapshot_list["distribution"] - # ##################################### Before reset ##################################### expect_tick = 10 @@ -137,20 +119,10 @@ def test_env_reset_with_ManufactureAction_only(self) -> None: assert isinstance(be, SupplyChainBusinessEngine) supplier_3: FacilityBase = be.world._get_facility_by_name("Supplier_SKU3") - warehouse_1: WarehouseFacility = be.world._get_facility_by_name("Warehouse_001") - retailer_1: RetailerFacility = be.world._get_facility_by_name("Retailer_001") storage_unit = supplier_3.storage - warehouse_1.products[SKU3_ID].consumer manufacture_unit = supplier_3.products[SKU3_ID].manufacture - supplier_3.distribution - retailer_1.products[SKU3_ID].seller - - env.snapshot_list["consumer"] storage_nodes = env.snapshot_list["storage"] - env.snapshot_list["seller"] - env.snapshot_list["manufacture"] - env.snapshot_list["distribution"] # ##################################### Before reset ##################################### @@ -286,12 +258,6 @@ def test_env_reset_with_ConsumerAction_only(self) -> None: supplier_3: FacilityBase = be.world._get_facility_by_name("Supplier_SKU3") consumer_unit = warehouse_1.products[SKU3_ID].consumer - env.snapshot_list["consumer"] - env.snapshot_list["storage"] - env.snapshot_list["seller"] - env.snapshot_list["manufacture"] - env.snapshot_list["distribution"] - # ##################################### Before reset ##################################### action = ConsumerAction(consumer_unit.id, SKU3_ID, supplier_3.id, 1, "train") expect_tick = 100 @@ -374,12 +340,6 @@ def test_env_reset_with_both_ManufactureAction_and_ConsumerAction(self) -> None: consumer_unit: ConsumerUnit = warehouse_1.products[SKU1_ID].consumer manufacture_unit: ManufactureUnit = supplier_1.products[SKU1_ID].manufacture - consumer_unit.data_model_index - manufacture_unit.data_model_index - - env.snapshot_list["consumer"] - env.snapshot_list["manufacture"] - # ##################################### Before reset ##################################### action_consumer = ConsumerAction(consumer_unit.id, SKU1_ID, supplier_1.id, 5, "train") action_manufacture = ManufactureAction(manufacture_unit.id, 1) From da28e6a9541194f824379441a4617a16582fde6b Mon Sep 17 00:00:00 2001 From: Huoran Li Date: Wed, 31 Aug 2022 13:38:39 +0800 Subject: [PATCH 09/10] Optimize code style --- tests/supply_chain/common.py | 59 ++++---- tests/supply_chain/test_env_reset.py | 198 +++++++++++++-------------- 2 files changed, 115 insertions(+), 142 deletions(-) diff --git a/tests/supply_chain/common.py b/tests/supply_chain/common.py index 3d8b8be5d..3be8507fc 100644 --- a/tests/supply_chain/common.py +++ b/tests/supply_chain/common.py @@ -3,41 +3,39 @@ import os from collections import defaultdict -from typing import Dict +from typing import Dict, List, Tuple import numpy as np from maro.simulator import Env -def build_env(case_name: str, durations: int): +def build_env(case_name: str, durations: int) -> Env: case_folder = os.path.join("tests", "data", "supply_chain", case_name) + return Env(scenario="supply_chain", topology=case_folder, durations=durations) - env = Env(scenario="supply_chain", topology=case_folder, durations=durations) - return env - - -def get_product_dict_from_storage(env: Env, frame_index: int, node_index: int): +def get_product_dict_from_storage(env: Env, frame_index: int, node_index: int) -> Dict[int, int]: sku_id_list = env.snapshot_list["storage"][frame_index:node_index:"sku_id_list"].flatten().astype(np.int) product_quantity = env.snapshot_list["storage"][frame_index:node_index:"product_quantity"].flatten().astype(np.int) - return {sku_id: quantity for sku_id, quantity in zip(sku_id_list, product_quantity)} + return dict(zip(sku_id_list, product_quantity)) -def snapshot_query(env: Env, i: int): +def snapshot_query(env: Env, i: int) -> Tuple[ + Dict[int, list], Dict[int, list], Dict[int, list], Dict[int, list], Dict[int, list], Dict[int, list] +]: consumer_nodes = env.snapshot_list["consumer"] storage_nodes = env.snapshot_list["storage"] seller_nodes = env.snapshot_list["seller"] manufacture_nodes = env.snapshot_list["manufacture"] distribution_nodes = env.snapshot_list["distribution"] - states_consumer: Dict[int, list] = defaultdict(list) - states_storage: Dict[int, list] = defaultdict(list) - states_seller: Dict[int, list] = defaultdict(list) - states_manufacture: Dict[int, list] = defaultdict(list) - states_distribution: Dict[int, list] = defaultdict(list) - env_metric: Dict[int, list] = defaultdict(list) + states_consumer: Dict[int, list] = {} + states_storage: Dict[int, list] = {} + states_seller: Dict[int, list] = {} + states_manufacture: Dict[int, list] = {} + states_distribution: Dict[int, list] = {} env_metric = env.metrics @@ -74,34 +72,23 @@ def snapshot_query(env: Env, i: int): return env_metric, states_consumer, states_storage, states_seller, states_manufacture, states_distribution -def test_env_reset_snapshot_query(env: Env, action_1, action_2, expect_tick: int, random_tick: list): - - states_consumer: Dict[int, dict] = defaultdict(dict) - states_storage: Dict[int, dict] = defaultdict(dict) - states_seller: Dict[int, dict] = defaultdict(dict) - states_manufacture: Dict[int, dict] = defaultdict(dict) - states_distribution: Dict[int, dict] = defaultdict(dict) - env_metric: Dict[int, dict] = defaultdict(dict) - +def test_env_reset_snapshot_query( + env: Env, + action_1: object, + action_2: object, + expect_tick: int, + random_tick: list = None, +) -> List[tuple]: + snapshots: List[tuple] = [] # List of (env_metric, states_consumer, ..., states_distribution) for i in range(expect_tick): - ( - env_metric[i], - states_consumer[i], - states_storage[i], - states_seller[i], - states_manufacture[i], - states_distribution[i], - ) = snapshot_query( - env, - i, - ) + snapshots.append(snapshot_query(env, i)) env.step(action_1) if random_tick is not None: if i in random_tick: env.step(action_2) - return env_metric, states_consumer, states_storage, states_seller, states_manufacture, states_distribution + return list(zip(*snapshots)) SKU1_ID = 1 diff --git a/tests/supply_chain/test_env_reset.py b/tests/supply_chain/test_env_reset.py index ff511a2c8..0866bd519 100644 --- a/tests/supply_chain/test_env_reset.py +++ b/tests/supply_chain/test_env_reset.py @@ -57,11 +57,11 @@ def test_env_reset_with_none_action(self) -> None: states_1_manufacture, states_1_distribution, ) = test_env_reset_snapshot_query( - env, - None, - None, - expect_tick, - None, + env=env, + action_1=None, + action_2=None, + expect_tick=expect_tick, + random_tick=None, ) # ############################### Test whether reset updates the storage unit completely ################ @@ -76,16 +76,13 @@ def test_env_reset_with_none_action(self) -> None: states_seller_initial, states_manufacture_initial, states_distribution_initial, - ) = snapshot_query( - env, - 0, - ) - self.assertEqual(list(states_1_consumer[0]), list(states_consumer_initial)) - self.assertEqual(list(states_1_storage[0]), list(states_storage_initial)) - self.assertEqual(list(states_1_seller[0]), list(states_seller_initial)) - self.assertEqual(list(states_1_manufacture[0]), list(states_manufacture_initial)) - self.assertEqual(list(states_1_distribution[0]), list(states_distribution_initial)) - self.assertEqual(list(env_metric_1[0].values()), list(env_metric_initial.values())) + ) = snapshot_query(env, 0) + self.assertListEqual(list(states_1_consumer[0]), list(states_consumer_initial)) + self.assertListEqual(list(states_1_storage[0]), list(states_storage_initial)) + self.assertListEqual(list(states_1_seller[0]), list(states_seller_initial)) + self.assertListEqual(list(states_1_manufacture[0]), list(states_manufacture_initial)) + self.assertListEqual(list(states_1_distribution[0]), list(states_distribution_initial)) + self.assertListEqual(list(env_metric_1[0].values()), list(env_metric_initial.values())) # Save the env.metric of each tick into env_metric_2 # Store the information about the snapshot unit of each tick in states_2_unit @@ -97,20 +94,20 @@ def test_env_reset_with_none_action(self) -> None: states_2_manufacture, states_2_distribution, ) = test_env_reset_snapshot_query( - env, - None, - None, - expect_tick, - None, + env=env, + action_1=None, + action_2=None, + expect_tick=expect_tick, + random_tick=None, ) for i in range(expect_tick): - self.assertEqual(list(states_1_consumer[i]), list(states_2_consumer[i])) - self.assertEqual(list(states_1_storage[i]), list(states_2_storage[i])) - self.assertEqual(list(states_1_seller[i]), list(states_2_seller[i])) - self.assertEqual(list(states_1_manufacture[i]), list(states_2_manufacture[i])) - self.assertEqual(list(states_1_distribution[i]), list(states_2_distribution[i])) - self.assertEqual(list(env_metric_1[i].values()), list(env_metric_2[i].values())) + self.assertListEqual(list(states_1_consumer[i]), list(states_2_consumer[i])) + self.assertListEqual(list(states_1_storage[i]), list(states_2_storage[i])) + self.assertListEqual(list(states_1_seller[i]), list(states_2_seller[i])) + self.assertListEqual(list(states_1_manufacture[i]), list(states_2_manufacture[i])) + self.assertListEqual(list(states_1_distribution[i]), list(states_2_distribution[i])) + self.assertListEqual(list(env_metric_1[i].values()), list(env_metric_2[i].values())) def test_env_reset_with_ManufactureAction_only(self) -> None: """test env reset with ManufactureAction only""" @@ -146,8 +143,6 @@ def test_env_reset_with_ManufactureAction_only(self) -> None: # 80 sku3 self.assertEqual(80, product_dict[SKU3_ID]) - ManufactureAction(manufacture_unit.id, 1) - expect_tick = 30 action_1 = ManufactureAction(manufacture_unit.id, 1) @@ -169,11 +164,11 @@ def test_env_reset_with_ManufactureAction_only(self) -> None: states_1_manufacture, states_1_distribution, ) = test_env_reset_snapshot_query( - env, - action_1, - action_2, - expect_tick, - random_tick, + env=env, + action_1=action_1, + action_2=action_2, + expect_tick=expect_tick, + random_tick=random_tick, ) # ############################### Test whether reset updates the manufacture unit completely ################ @@ -188,16 +183,13 @@ def test_env_reset_with_ManufactureAction_only(self) -> None: states_seller_initial, states_manufacture_initial, states_distribution_initial, - ) = snapshot_query( - env, - 0, - ) - self.assertEqual(list(states_1_consumer[0]), list(states_consumer_initial)) - self.assertEqual(list(states_1_storage[0]), list(states_storage_initial)) - self.assertEqual(list(states_1_seller[0]), list(states_seller_initial)) - self.assertEqual(list(states_1_manufacture[0]), list(states_manufacture_initial)) - self.assertEqual(list(states_1_distribution[0]), list(states_distribution_initial)) - self.assertEqual(list(env_metric_1[0].values()), list(env_metric_initial.values())) + ) = snapshot_query(env, 0) + self.assertListEqual(list(states_1_consumer[0]), list(states_consumer_initial)) + self.assertListEqual(list(states_1_storage[0]), list(states_storage_initial)) + self.assertListEqual(list(states_1_seller[0]), list(states_seller_initial)) + self.assertListEqual(list(states_1_manufacture[0]), list(states_manufacture_initial)) + self.assertListEqual(list(states_1_distribution[0]), list(states_distribution_initial)) + self.assertListEqual(list(env_metric_1[0].values()), list(env_metric_initial.values())) capacities = storage_nodes[env.frame_index : storage_node_index : "capacity"].flatten().astype(np.int) remaining_spaces = ( @@ -231,20 +223,20 @@ def test_env_reset_with_ManufactureAction_only(self) -> None: states_2_manufacture, states_2_distribution, ) = test_env_reset_snapshot_query( - env, - action_1, - action_2, - expect_tick, - random_tick, + env=env, + action_1=action_1, + action_2=action_2, + expect_tick=expect_tick, + random_tick=random_tick, ) for i in range(expect_tick): - self.assertEqual(list(states_1_consumer[i]), list(states_2_consumer[i])) - self.assertEqual(list(states_1_storage[i]), list(states_2_storage[i])) - self.assertEqual(list(states_1_seller[i]), list(states_2_seller[i])) - self.assertEqual(list(states_1_manufacture[i]), list(states_2_manufacture[i])) - self.assertEqual(list(states_1_distribution[i]), list(states_2_distribution[i])) - self.assertEqual(list(env_metric_1[i].values()), list(env_metric_2[i].values())) + self.assertListEqual(list(states_1_consumer[i]), list(states_2_consumer[i])) + self.assertListEqual(list(states_1_storage[i]), list(states_2_storage[i])) + self.assertListEqual(list(states_1_seller[i]), list(states_2_seller[i])) + self.assertListEqual(list(states_1_manufacture[i]), list(states_2_manufacture[i])) + self.assertListEqual(list(states_1_distribution[i]), list(states_2_distribution[i])) + self.assertListEqual(list(env_metric_1[i].values()), list(env_metric_2[i].values())) def test_env_reset_with_ConsumerAction_only(self) -> None: """ "test env reset with ConsumerAction only""" @@ -272,11 +264,11 @@ def test_env_reset_with_ConsumerAction_only(self) -> None: states_1_manufacture, states_1_distribution, ) = test_env_reset_snapshot_query( - env, - action, - None, - expect_tick, - None, + env=env, + action_1=action, + action_2=None, + expect_tick=expect_tick, + random_tick=None, ) # ############### Test whether reset updates the consumer unit completely ################ @@ -291,16 +283,13 @@ def test_env_reset_with_ConsumerAction_only(self) -> None: states_seller_initial, states_manufacture_initial, states_distribution_initial, - ) = snapshot_query( - env, - 0, - ) - self.assertEqual(list(states_1_consumer[0]), list(states_consumer_initial)) - self.assertEqual(list(states_1_storage[0]), list(states_storage_initial)) - self.assertEqual(list(states_1_seller[0]), list(states_seller_initial)) - self.assertEqual(list(states_1_manufacture[0]), list(states_manufacture_initial)) - self.assertEqual(list(states_1_distribution[0]), list(states_distribution_initial)) - self.assertEqual(list(env_metric_1[0].values()), list(env_metric_initial.values())) + ) = snapshot_query(env, 0) + self.assertListEqual(list(states_1_consumer[0]), list(states_consumer_initial)) + self.assertListEqual(list(states_1_storage[0]), list(states_storage_initial)) + self.assertListEqual(list(states_1_seller[0]), list(states_seller_initial)) + self.assertListEqual(list(states_1_manufacture[0]), list(states_manufacture_initial)) + self.assertListEqual(list(states_1_distribution[0]), list(states_distribution_initial)) + self.assertListEqual(list(env_metric_1[0].values()), list(env_metric_initial.values())) # Save the env.metric of each tick into env_metric_2 # Store the information about the snapshot unit of each tick in states_2_unit @@ -312,20 +301,20 @@ def test_env_reset_with_ConsumerAction_only(self) -> None: states_2_manufacture, states_2_distribution, ) = test_env_reset_snapshot_query( - env, - action, - None, - expect_tick, - None, + env=env, + action_1=action, + action_2=None, + expect_tick=expect_tick, + random_tick=None, ) for i in range(expect_tick): - self.assertEqual(list(states_1_consumer[i]), list(states_2_consumer[i])) - self.assertEqual(list(states_1_storage[i]), list(states_2_storage[i])) - self.assertEqual(list(states_1_seller[i]), list(states_2_seller[i])) - self.assertEqual(list(states_1_manufacture[i]), list(states_2_manufacture[i])) - self.assertEqual(list(states_1_distribution[i]), list(states_2_distribution[i])) - self.assertEqual(list(env_metric_1[i].values()), list(env_metric_2[i].values())) + self.assertListEqual(list(states_1_consumer[i]), list(states_2_consumer[i])) + self.assertListEqual(list(states_1_storage[i]), list(states_2_storage[i])) + self.assertListEqual(list(states_1_seller[i]), list(states_2_seller[i])) + self.assertListEqual(list(states_1_manufacture[i]), list(states_2_manufacture[i])) + self.assertListEqual(list(states_1_distribution[i]), list(states_2_distribution[i])) + self.assertListEqual(list(env_metric_1[i].values()), list(env_metric_2[i].values())) def test_env_reset_with_both_ManufactureAction_and_ConsumerAction(self) -> None: """test env reset with both ManufactureAction and ConsumerAction""" @@ -362,11 +351,11 @@ def test_env_reset_with_both_ManufactureAction_and_ConsumerAction(self) -> None: states_1_manufacture, states_1_distribution, ) = test_env_reset_snapshot_query( - env, - action_consumer, - action_manufacture, - expect_tick, - random_tick, + env=env, + action_1=action_consumer, + action_2=action_manufacture, + expect_tick=expect_tick, + random_tick=random_tick, ) # ############### Test whether reset updates the consumer unit completely ################ @@ -381,16 +370,13 @@ def test_env_reset_with_both_ManufactureAction_and_ConsumerAction(self) -> None: states_seller_initial, states_manufacture_initial, states_distribution_initial, - ) = snapshot_query( - env, - 0, - ) - self.assertEqual(list(states_1_consumer[0]), list(states_consumer_initial)) - self.assertEqual(list(states_1_storage[0]), list(states_storage_initial)) - self.assertEqual(list(states_1_seller[0]), list(states_seller_initial)) - self.assertEqual(list(states_1_manufacture[0]), list(states_manufacture_initial)) - self.assertEqual(list(states_1_distribution[0]), list(states_distribution_initial)) - self.assertEqual(list(env_metric_1[0].values()), list(env_metric_initial.values())) + ) = snapshot_query(env, 0) + self.assertListEqual(list(states_1_consumer[0]), list(states_consumer_initial)) + self.assertListEqual(list(states_1_storage[0]), list(states_storage_initial)) + self.assertListEqual(list(states_1_seller[0]), list(states_seller_initial)) + self.assertListEqual(list(states_1_manufacture[0]), list(states_manufacture_initial)) + self.assertListEqual(list(states_1_distribution[0]), list(states_distribution_initial)) + self.assertListEqual(list(env_metric_1[0].values()), list(env_metric_initial.values())) # Save the env.metric of each tick into env_metric_2 # Store the information about the snapshot unit of each tick in states_2_unit @@ -402,20 +388,20 @@ def test_env_reset_with_both_ManufactureAction_and_ConsumerAction(self) -> None: states_2_manufacture, states_2_distribution, ) = test_env_reset_snapshot_query( - env, - action_consumer, - action_manufacture, - expect_tick, - random_tick, + env=env, + action_1=action_consumer, + action_2=action_manufacture, + expect_tick=expect_tick, + random_tick=random_tick, ) for i in range(expect_tick): - self.assertEqual(list(states_1_consumer[i]), list(states_2_consumer[i])) - self.assertEqual(list(states_1_storage[i]), list(states_2_storage[i])) - self.assertEqual(list(states_1_seller[i]), list(states_2_seller[i])) - self.assertEqual(list(states_1_manufacture[i]), list(states_2_manufacture[i])) - self.assertEqual(list(states_1_distribution[i]), list(states_2_distribution[i])) - self.assertEqual(list(env_metric_1[i].values()), list(env_metric_2[i].values())) + self.assertListEqual(list(states_1_consumer[i]), list(states_2_consumer[i])) + self.assertListEqual(list(states_1_storage[i]), list(states_2_storage[i])) + self.assertListEqual(list(states_1_seller[i]), list(states_2_seller[i])) + self.assertListEqual(list(states_1_manufacture[i]), list(states_2_manufacture[i])) + self.assertListEqual(list(states_1_distribution[i]), list(states_2_distribution[i])) + self.assertListEqual(list(env_metric_1[i].values()), list(env_metric_2[i].values())) if __name__ == "__main__": From 988acc4ba237d8212d526e04d7a41772b2cde77f Mon Sep 17 00:00:00 2001 From: v-heli1 Date: Wed, 31 Aug 2022 16:21:02 +0800 Subject: [PATCH 10/10] Optimize code style[added] --- tests/supply_chain/test_env_reset.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/tests/supply_chain/test_env_reset.py b/tests/supply_chain/test_env_reset.py index 0866bd519..4e6227d4c 100644 --- a/tests/supply_chain/test_env_reset.py +++ b/tests/supply_chain/test_env_reset.py @@ -211,8 +211,6 @@ def test_env_reset_with_ManufactureAction_only(self) -> None: # all the id is greater than 0 self.assertGreater(manufacture_unit.id, 0) - expect_tick = 30 - # Save the env.metric of each tick into env_metric_2 # Store the information about the snapshot unit of each tick in states_2_unit (