From 03bddda93a487ea0d19eab6a5aa86b7633bc3b2c Mon Sep 17 00:00:00 2001 From: jucordero Date: Wed, 11 Mar 2026 18:31:36 +0000 Subject: [PATCH 1/7] configuration parser and cli script --- agrifoodpy/pipeline/cli.py | 58 +++++++++++++++++++++++++++++++++ agrifoodpy/pipeline/pipeline.py | 38 +++++++++++++++++---- setup.py | 15 ++++++--- 3 files changed, 99 insertions(+), 12 deletions(-) create mode 100644 agrifoodpy/pipeline/cli.py diff --git a/agrifoodpy/pipeline/cli.py b/agrifoodpy/pipeline/cli.py new file mode 100644 index 0000000..427c025 --- /dev/null +++ b/agrifoodpy/pipeline/cli.py @@ -0,0 +1,58 @@ +import argparse +import json +from .pipeline import Pipeline + + +def main(): + parser = argparse.ArgumentParser( + description="Run an AgriFoodPy pipeline from a configuration file." + ) + + parser.add_argument( + "config", + help="Pipeline configuration YAML file" + ) + + parser.add_argument( + "--output", + help="Optional output file to store the datablock as JSON", + default=None + ) + + parser.add_argument( + "--show-datablock", + action="store_true", + help="Print the final datablock to stdout" + ) + + + parser.add_argument( + "--show-nodes", + action="store_true", + help="Print the final datablock to stdout" + ) + + parser.add_argument( + "--norun", + action="store_true", + help="Do not run the pipeline" + ) + + args = parser.parse_args() + + pipeline = Pipeline.read(args.config) + + if args.show_nodes: + pipeline.print_nodes() + + if not args.norun: + pipeline.run() + + datablock = pipeline.datablock + + if args.show_datablock: + print(json.dumps(datablock, indent=2, default=str)) + + if args.output: + with open(args.output, "w") as f: + json.dump(datablock, f, indent=2, default=str) \ No newline at end of file diff --git a/agrifoodpy/pipeline/pipeline.py b/agrifoodpy/pipeline/pipeline.py index 914d71d..908690b 100644 --- a/agrifoodpy/pipeline/pipeline.py +++ b/agrifoodpy/pipeline/pipeline.py @@ -8,11 +8,14 @@ from functools import wraps from inspect import 
signature import time +import yaml +import importlib class Pipeline(): '''Class for constructing and running pipelines of functions with individual sets of parameters.''' + def __init__(self, datablock=None): self.nodes = [] self.params = [] @@ -22,9 +25,16 @@ def __init__(self, datablock=None): else: self.datablock = {} + @staticmethod + def _load_function(path): + """Load a function from a dotted path.""" + module_path, func_name = path.rsplit(".", 1) + module = importlib.import_module(module_path) + return getattr(module, func_name) + @classmethod def read(cls, filename): - """Read a pipeline from a configuration file + """Read a pipeline configuration from a YAML file Parameters ---------- @@ -36,7 +46,22 @@ def read(cls, filename): pipeline : Pipeline The pipeline object. """ - raise NotImplementedError("This method is not yet implemented.") + + with open(filename, "r") as f: + config = yaml.safe_load(f) + + pipeline = cls() + + for step in config["nodes"]: + func = cls._load_function(step["function"]) + params = step.get("params", {}) + name = step.get("name", func.__name__) + + pipeline.nodes.append(func) + pipeline.params.append(params) + pipeline.names.append(name) + + return pipeline def datablock_write(self, path, value): """Writes a single value to the datablock at the specified path. @@ -118,7 +143,6 @@ def remove_node(self, node): del self.params[index] del self.names[index] - def run(self, from_node=0, to_node=None, skip=None, timing=False): """Runs the pipeline @@ -130,7 +154,7 @@ def run(self, from_node=0, to_node=None, skip=None, timing=False): to_node : int, optional The index of the last node to be executed. If not provided, all nodes will be executed - + skip : list of int, optional List of node indices to skip during execution. Defaults to None. @@ -175,18 +199,17 @@ def run(self, from_node=0, to_node=None, skip=None, timing=False): def print_nodes(self, show_params=True): """Prints the list of nodes associated with a Pipeline instance. 
- + Parameters ---------- show_params : bool, optional If True, displays the parameters associated with each node. """ - if not self.nodes: print("Pipeline is empty.") return - + print("Pipeline nodes:") for i, (name, node, params) in enumerate(zip(self.names, self.nodes, self.params)): node_name = getattr(node, "__name__", repr(node)) @@ -195,6 +218,7 @@ def print_nodes(self, show_params=True): for k, v in params.items(): print(f" {k} = {v}") + def standalone(input_keys, return_keys): """ Decorator to make a pipeline node available as a standalone function diff --git a/setup.py b/setup.py index 046c7f3..7217e01 100644 --- a/setup.py +++ b/setup.py @@ -15,10 +15,10 @@ LONG_DESC_TYPE = "text/markdown" INSTALL_REQUIRES = [ - 'numpy', - 'pandas', - 'xarray', - 'matplotlib', + 'numpy', + 'pandas', + 'xarray', + 'matplotlib', ] setup(name=PACKAGE_NAME, @@ -31,5 +31,10 @@ author_email=AUTHOR_EMAIL, url=URL, install_requires=INSTALL_REQUIRES, - packages=find_packages() + packages=find_packages(), + entry_points={ + 'console_scripts': [ + 'agrifoodpy = agrifoodpy.pipeline.cli:main', + ] + } ) From 8434d548230f7e34b63a50c8e194f6e679ac43d0 Mon Sep 17 00:00:00 2001 From: jucordero Date: Thu, 12 Mar 2026 15:27:22 +0000 Subject: [PATCH 2/7] command line script and tests --- agrifoodpy/pipeline/cli.py | 57 +++++++++++++------ agrifoodpy/pipeline/pipeline.py | 17 +++--- .../pipeline/tests/data/empty_config.yml | 0 .../pipeline/tests/data/test_config.yml | 6 ++ agrifoodpy/pipeline/tests/test_cli.py | 27 +++++++++ 5 files changed, 81 insertions(+), 26 deletions(-) create mode 100644 agrifoodpy/pipeline/tests/data/empty_config.yml create mode 100644 agrifoodpy/pipeline/tests/data/test_config.yml create mode 100644 agrifoodpy/pipeline/tests/test_cli.py diff --git a/agrifoodpy/pipeline/cli.py b/agrifoodpy/pipeline/cli.py index 427c025..eb31670 100644 --- a/agrifoodpy/pipeline/cli.py +++ b/agrifoodpy/pipeline/cli.py @@ -1,9 +1,10 @@ import argparse import json +import sys from 
.pipeline import Pipeline -def main(): +def main(args=None): parser = argparse.ArgumentParser( description="Run an AgriFoodPy pipeline from a configuration file." ) @@ -14,45 +15,65 @@ def main(): ) parser.add_argument( + "-o", "--output", help="Optional output file to store the datablock as JSON", default=None ) parser.add_argument( - "--show-datablock", + "--nodes", action="store_true", - help="Print the final datablock to stdout" + help="Print the nodes and parameters to stdout" ) + parser.add_argument( + "--no-run", + action="store_false", + help="Do not run the pipeline" + ) parser.add_argument( - "--show-nodes", - action="store_true", - help="Print the final datablock to stdout" + "--from-node", + type=int, + help="Index of the first node to be executed" ) parser.add_argument( - "--norun", - action="store_true", - help="Do not run the pipeline" + "--to-node", + type=int, + help="Index of the last node to be executed" ) - args = parser.parse_args() + parser.add_argument( + "--skip-nodes", + nargs="+", + type=int, + help="List of nodes to be skipped in the pipeline execution" + ) + + # get system args if none passed + if args is None: + args = sys.argv[1:] + + args = parser.parse_args(args or ['--help']) pipeline = Pipeline.read(args.config) - if args.show_nodes: + if args.nodes: pipeline.print_nodes() - if not args.norun: - pipeline.run() + from_node = args.from_node if args.from_node is not None else 0 + to_node = args.to_node if args.to_node is not None else len(pipeline.nodes) + skip_nodes = args.skip_nodes if args.skip_nodes is not None else None - datablock = pipeline.datablock + if args.no_run: + pipeline.run(from_node=from_node, to_node=to_node, skip=skip_nodes) - if args.show_datablock: - print(json.dumps(datablock, indent=2, default=str)) + datablock = pipeline.datablock - if args.output: + if args.output is not None: with open(args.output, "w") as f: - json.dump(datablock, f, indent=2, default=str) \ No newline at end of file + json.dump(datablock, 
f, indent=2, default=str) + + return 0 \ No newline at end of file diff --git a/agrifoodpy/pipeline/pipeline.py b/agrifoodpy/pipeline/pipeline.py index 908690b..1dc13c2 100644 --- a/agrifoodpy/pipeline/pipeline.py +++ b/agrifoodpy/pipeline/pipeline.py @@ -52,14 +52,15 @@ def read(cls, filename): pipeline = cls() - for step in config["nodes"]: - func = cls._load_function(step["function"]) - params = step.get("params", {}) - name = step.get("name", func.__name__) - - pipeline.nodes.append(func) - pipeline.params.append(params) - pipeline.names.append(name) + if config is not None: + for step in config["nodes"]: + func = cls._load_function(step["function"]) + params = step.get("params", {}) + name = step.get("name", func.__name__) + + pipeline.nodes.append(func) + pipeline.params.append(params) + pipeline.names.append(name) return pipeline diff --git a/agrifoodpy/pipeline/tests/data/empty_config.yml b/agrifoodpy/pipeline/tests/data/empty_config.yml new file mode 100644 index 0000000..e69de29 diff --git a/agrifoodpy/pipeline/tests/data/test_config.yml b/agrifoodpy/pipeline/tests/data/test_config.yml new file mode 100644 index 0000000..6f8c8a1 --- /dev/null +++ b/agrifoodpy/pipeline/tests/data/test_config.yml @@ -0,0 +1,6 @@ +nodes: + - function: agrifoodpy.utils.nodes.write_to_datablock + params: + key: test_write + value: 200 + name: writing to datablock \ No newline at end of file diff --git a/agrifoodpy/pipeline/tests/test_cli.py b/agrifoodpy/pipeline/tests/test_cli.py new file mode 100644 index 0000000..fe3628a --- /dev/null +++ b/agrifoodpy/pipeline/tests/test_cli.py @@ -0,0 +1,27 @@ +from agrifoodpy.pipeline.cli import main +import pytest +import os + +def test_cli(tmp_path): + + # Test with no arguments + with pytest.raises(SystemExit) as e: + main() + assert e.value.code == 0 + + # Argparse help + with pytest.raises(SystemExit) as e: + main(['--help']) + assert e.value.code == 0 + + # Process test config file + script_dir = os.path.dirname(__file__) + 
config_filename = os.path.join(script_dir, "data/empty_config.yml") + output_filename = str(tmp_path / 'empty.json') + assert main([config_filename, '-o', output_filename]) == 0 + + # Process test config file + script_dir = os.path.dirname(__file__) + config_filename = os.path.join(script_dir, "data/test_config.yml") + output_filename = str(tmp_path / 'test.json') + assert main([config_filename, '-o', output_filename]) == 0 From 7c69d5d578f759aed139c2fa1ad0bc04bbf9d046 Mon Sep 17 00:00:00 2001 From: jucordero Date: Thu, 12 Mar 2026 19:46:09 +0000 Subject: [PATCH 3/7] CLI documentation --- agrifoodpy/pipeline/pipeline.py | 2 +- docs/conf.py | 2 +- docs/index.rst | 2 +- examples/README.rst | 2 +- examples/cli/README.rst | 13 ++++++ examples/cli/scaling_food_supply.yml | 66 ++++++++++++++++++++++++++++ 6 files changed, 83 insertions(+), 4 deletions(-) create mode 100644 examples/cli/README.rst create mode 100644 examples/cli/scaling_food_supply.yml diff --git a/agrifoodpy/pipeline/pipeline.py b/agrifoodpy/pipeline/pipeline.py index 1dc13c2..9104a61 100644 --- a/agrifoodpy/pipeline/pipeline.py +++ b/agrifoodpy/pipeline/pipeline.py @@ -48,7 +48,7 @@ def read(cls, filename): """ with open(filename, "r") as f: - config = yaml.safe_load(f) + config = yaml.load(f, Loader=yaml.FullLoader) pipeline = cls() diff --git a/docs/conf.py b/docs/conf.py index 9b60950..c4ad751 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -36,6 +36,6 @@ sphinx_gallery_conf = { 'examples_dirs': '../examples', # path to examples scripts 'gallery_dirs': 'examples', # path to gallery generated examples - 'run_stale_examples': True, + 'run_stale_examples': False, 'download_all_examples': False, } \ No newline at end of file diff --git a/docs/index.rst b/docs/index.rst index 2728d15..bc904ca 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -12,7 +12,7 @@ Welcome to agrifoodpy's documentation! install examples/index - + config_file contributing .. 
include:: readme.md diff --git a/examples/README.rst b/examples/README.rst index d8c6771..762e5d4 100644 --- a/examples/README.rst +++ b/examples/README.rst @@ -3,4 +3,4 @@ Examples AgriFoodPy implements models and table manipulation methods to analyse and process agrifood data. In these examples we demonstrate the use of some of these -methods and models using python scripts. \ No newline at end of file +methods and models using python scripts, or a console script. \ No newline at end of file diff --git a/examples/cli/README.rst b/examples/cli/README.rst new file mode 100644 index 0000000..14b0bd1 --- /dev/null +++ b/examples/cli/README.rst @@ -0,0 +1,13 @@ +Command line tool +----------------- + +``agrifoodpy`` is a command line script that runs a pipeline of functions +defined in a config file. + +.. code-block:: console + + $ agrifoodpy examples/cli/scaling_food_supply.yml -o results.json + +For more information on how to configure a pipeline using configuration files +and execute them using the ``agrifoodpy`` command line tool, see +:ref:`config_file`. 
\ No newline at end of file diff --git a/examples/cli/scaling_food_supply.yml b/examples/cli/scaling_food_supply.yml new file mode 100644 index 0000000..d042a0a --- /dev/null +++ b/examples/cli/scaling_food_supply.yml @@ -0,0 +1,66 @@ +nodes: + - function: agrifoodpy.utils.nodes.load_dataset + name: "Load dataset" + params: + datablock_path: "food" + module: "agrifoodpy_data.food" + data_attr: "FAOSTAT" + coords: { + Item: [2731, 2511], + Year: [2019, 2020], + Region: 229} + + - function: agrifoodpy.utils.nodes.write_to_datablock + name: "Add convertion factor from 1000 tonnes to kg" + params: + key: "tonnes_to_kg" + value: 1000000 + + - function: agrifoodpy.food.model.fbs_convert + name: "Convert fbs from 1000 tonnes to kg" + params: + fbs: "food" + convertion_arr: "tonnes_to_kg" + + - function: agrifoodpy.food.model.SSR + name: "Calculate SSR" + params: + fbs: "food" + out_key: "SSR" + + - function: agrifoodpy.food.model.IDR + name: "Calculate IDR" + params: + fbs: "food" + out_key: "IDR" + + - function: agrifoodpy.utils.nodes.print_datablock + name: "Print SSR" + params: + key: "SSR" + method: "to_numpy" + preffix: "SSR: values " + + - function: agrifoodpy.utils.nodes.add_items + name: "Add item names to the datablock" + params: + dataset: "food" + copy_from: 2731 + items: { + "Item": 5000, + "Item_name": "Cultured meat", + "Item_group": "Cultured products", + "Item_origin": "Synthetic origin", + } + + - function: agrifoodpy.utils.nodes.add_years + name: "Add years to the fbs" + params: + dataset: "food" + years: [2021, 2025, 2030] + projection: [1.1, 1.5, 10] + + - function: agrifoodpy.utils.nodes.print_datablock + name: "Print FBS" + params: + key: "food" \ No newline at end of file From 6c001689bda58dbe1967d838d47d780a163f0f90 Mon Sep 17 00:00:00 2001 From: jucordero Date: Thu, 12 Mar 2026 19:52:39 +0000 Subject: [PATCH 4/7] missing config_file --- docs/config_file.rst | 53 ++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 53 insertions(+) create 
mode 100644 docs/config_file.rst diff --git a/docs/config_file.rst b/docs/config_file.rst new file mode 100644 index 0000000..63eac7b --- /dev/null +++ b/docs/config_file.rst @@ -0,0 +1,53 @@ +.. _config_file: + +Command line tool +================= + +The ``agrifoodpy`` command line tool allows you to run a pipeline of functions +defined in a configuration file. This is useful for automating workflows and +reproducibility. You can specify the configuration file and an output file for +the results. + +Executing the command line tool +------------------------------- + +To execute the command line tool, use the following syntax: + +.. code-block:: console + + $ agrifoodpy CONFIG_FILE -o OUTPUT_FILE + +The following options are available for the command line tool: + +:\-o \-\-output: Specify the output file for the results. The output will be saved in JSON format. + +:\-\-nodes: Print the nodes and parameters to stdout + +:\-\-no-run: Do not run the pipeline + +:\-\-from-node: Index of the first node to be executed + +:\-\-to-node: Index of the last node to be executed + +:\-\-skip-nodes: List of nodes to be skipped in the pipeline execution + + +Configuration files +------------------- + +Configuration files are YAML files that define a pipeline of functions to be +executed by the ``agrifoodpy`` command line tool. Each function is specified +with its name and parameters, and the functions are executed in the order they +are defined. + +.. literalinclude:: ../examples/cli/scaling_food_supply.yml + :language: YAML + :caption: Example of a configuration file for scaling a food balance sheet. + +Each node is defined by a function to execute, its parameters and, +optionally, a name. The function is specified in the format ``module.function``, +where ``module`` is the name of the module containing the function, +and ``function`` is the name of the function to be executed. +The parameters are specified as a dictionary of key-value pairs, +where the keys are the parameter names.
+ From 07a0d69eeb3ad0fea69550fa8bc195e89db693fc Mon Sep 17 00:00:00 2001 From: jucordero Date: Mon, 30 Mar 2026 12:49:32 +0100 Subject: [PATCH 5/7] pyyaml dependency --- setup.py | 1 + 1 file changed, 1 insertion(+) diff --git a/setup.py b/setup.py index 7217e01..6ab8122 100644 --- a/setup.py +++ b/setup.py @@ -19,6 +19,7 @@ 'pandas', 'xarray', 'matplotlib', + 'pyyaml' ] setup(name=PACKAGE_NAME, From 9734d3188ecd855d0e9f1e98136bf6a4162a403c Mon Sep 17 00:00:00 2001 From: jucordero Date: Wed, 1 Apr 2026 13:24:48 +0100 Subject: [PATCH 6/7] nodes to write outputs --- agrifoodpy/pipeline/cli.py | 10 ++- agrifoodpy/utils/nodes.py | 171 +++++++++++++++++++++++++++++++++++++ 2 files changed, 178 insertions(+), 3 deletions(-) diff --git a/agrifoodpy/pipeline/cli.py b/agrifoodpy/pipeline/cli.py index eb31670..1f9cdf4 100644 --- a/agrifoodpy/pipeline/cli.py +++ b/agrifoodpy/pipeline/cli.py @@ -58,20 +58,24 @@ def main(args=None): args = parser.parse_args(args or ['--help']) + # read pipeline configuration and set pipeline object pipeline = Pipeline.read(args.config) - if args.nodes: - pipeline.print_nodes() - from_node = args.from_node if args.from_node is not None else 0 to_node = args.to_node if args.to_node is not None else len(pipeline.nodes) skip_nodes = args.skip_nodes if args.skip_nodes is not None else None + # print the nodes and parameters if requested + if args.nodes: + pipeline.print_nodes() + + # run the pipeline if not skipped if args.no_run: pipeline.run(from_node=from_node, to_node=to_node, skip=skip_nodes) datablock = pipeline.datablock + # write outputs if args.output is not None: with open(args.output, "w") as f: json.dump(datablock, f, indent=2, default=str) diff --git a/agrifoodpy/utils/nodes.py b/agrifoodpy/utils/nodes.py index 77a7dfe..d8982ad 100644 --- a/agrifoodpy/utils/nodes.py +++ b/agrifoodpy/utils/nodes.py @@ -1,5 +1,8 @@ import copy +import json +import os import xarray as xr +import numpy as np import importlib from ..pipeline import 
standalone @@ -287,3 +290,171 @@ def load_dataset( datablock[datablock_path] = dataset * scale return datablock + +def _tuple_to_str(tup): + return ".".join(str(x) for x in tup) + +def write_json(datablock, key, path, indent=2): + """Writes a datablock value to a JSON file. + + Parameters + ---------- + datablock : dict + The datablock to read from. + key : str, tuple or list + List of datablock keys (or tuple of keys for nested access) of the + value to write. + path : str + Output file path. + indent : int, optional + JSON indentation level. Defaults to 2. + + Returns + ------- + datablock : dict + Unmodified datablock. + """ + + def _default(o): + if isinstance(o, (xr.Dataset, xr.DataArray)): + return o.to_dict() + if isinstance(o, np.ndarray): + return o.tolist() + if isinstance(o, np.integer): + return int(o) + if isinstance(o, np.floating): + return float(o) + return str(o) + + if isinstance(key, str): + key = [key] + elif isinstance(key, tuple): + key = [key] + + obj_dict = {} + + for obj_key in key: + if isinstance(obj_key, tuple): + obj_dict[_tuple_to_str(obj_key)] = get_dict(datablock, obj_key) + else: + obj_dict[obj_key] = get_dict(datablock, obj_key) + + os.makedirs(os.path.dirname(os.path.abspath(path)), exist_ok=True) + + with open(path, "w") as f: + json.dump(obj_dict, f, indent=indent, default=_default) + + return datablock + + +def write_netcdf(datablock, key, path): + """Writes a datablock xarray Dataset or DataArray to a NetCDF file. + + Parameters + ---------- + datablock : dict + The datablock to read from. + key : str or tuple + Datablock key (or tuple of keys for nested access) of the dataset to + write. + path : str + Output file path. + + Returns + ------- + datablock : dict + Unmodified datablock. + """ + + obj = get_dict(datablock, key) + + if not isinstance(obj, (xr.Dataset, xr.DataArray)): + raise TypeError( + f"write_netcdf only supports xarray Dataset or DataArray objects. " + f"Got {type(obj).__name__} instead." 
+ ) + + os.makedirs(os.path.dirname(os.path.abspath(path)), exist_ok=True) + obj.to_netcdf(path) + + return datablock + + +def write_csv(datablock, key, path, index=True): + """Writes a datablock value to a CSV file. + + Supports xarray Dataset, xarray DataArray, pandas DataFrame and Series. + xarray objects are converted to a DataFrame before writing; the + multi-index produced by that conversion (which encodes coordinates) is + written when index=True. + + Parameters + ---------- + datablock : dict + The datablock to read from. + key : str, list or tuple + Datablock key (or tuple of keys for nested access) of the value to write. + path : str + Output file path. + index : bool, optional + Whether to write the row index. Defaults to True. + + Returns + ------- + datablock : dict + Unmodified datablock. + """ + import pandas as pd + + obj = get_dict(datablock, key) + os.makedirs(os.path.dirname(os.path.abspath(path)), exist_ok=True) + + if isinstance(obj, (xr.Dataset, xr.DataArray)): + try: + obj.to_dataframe().to_csv(path, index=index) + except ValueError: + # Unnamed DataArrays raise a ValueError. + + if isinstance(key, tuple): + obj.name = key[-1] + else: + obj.name = str(key) + + obj.to_dataframe().to_csv(path, index=index) + + elif isinstance(obj, (pd.DataFrame, pd.Series)): + obj.to_csv(path, index=index) + else: + raise TypeError( + f"write_csv does not support objects of type {type(obj).__name__}. " + "Expected xr.Dataset, xr.DataArray, pd.DataFrame, or pd.Series." + ) + + return datablock + + +def write_text(datablock, key, path): + """Writes the string representation of a datablock value to a text file. + + Parameters + ---------- + datablock : dict + The datablock to read from. + key : str or tuple + Datablock key (or tuple of keys for nested access) of the value to write. + path : str + Output file path. + + Returns + ------- + datablock : dict + Unmodified datablock. 
+ """ + + obj = get_dict(datablock, key) + os.makedirs(os.path.dirname(os.path.abspath(path)), exist_ok=True) + + with open(path, "w") as f: + f.write(str(obj)) + + return datablock From 13e9096f27b8236f504102c6d27c06c97439c537 Mon Sep 17 00:00:00 2001 From: jucordero Date: Wed, 1 Apr 2026 15:28:19 +0100 Subject: [PATCH 7/7] units tests for write to disk node functions --- agrifoodpy/utils/nodes.py | 3 +- agrifoodpy/utils/tests/test_write_to_disk.py | 127 +++++++++++++++++++ 2 files changed, 128 insertions(+), 2 deletions(-) create mode 100644 agrifoodpy/utils/tests/test_write_to_disk.py diff --git a/agrifoodpy/utils/nodes.py b/agrifoodpy/utils/nodes.py index d8982ad..5e3cefb 100644 --- a/agrifoodpy/utils/nodes.py +++ b/agrifoodpy/utils/nodes.py @@ -414,12 +414,11 @@ def write_csv(datablock, key, path, index=True): obj.to_dataframe().to_csv(path, index=index) except ValueError: # Unnamed DataArrays raise a ValueError. - if isinstance(key, tuple): obj.name = key[-1] else: obj.name = str(key) - + obj.to_dataframe().to_csv(path, index=index) elif isinstance(obj, (pd.DataFrame, pd.Series)): diff --git a/agrifoodpy/utils/tests/test_write_to_disk.py b/agrifoodpy/utils/tests/test_write_to_disk.py new file mode 100644 index 0000000..7afd509 --- /dev/null +++ b/agrifoodpy/utils/tests/test_write_to_disk.py @@ -0,0 +1,127 @@ +def test_write_csv(): + from agrifoodpy.utils.nodes import write_csv + import pandas as pd + import xarray as xr + + # Test unsupported type + datablock_unsupported = {"data": set([1, 2, 3])} + try: + write_csv(datablock_unsupported, key="data", path="test.csv") + except TypeError as e: + assert str(e) == "write_csv does not support objects of type set. Expected xr.Dataset, xr.DataArray, pd.DataFrame, or pd.Series." 
+ + # Test Pandas DataFrame + df = pd.DataFrame( + { + "Item": ["Beef", "Beef", "Apples", "Apples"], + "Year": [2020, 2021, 2020, 2021], + "production": [15, 20, 6, 7] + } + ) + + datablock_df = {"data": df} + path = "test_df_output.csv" + + write_csv(datablock_df, key="data", path=path) + + df = pd.read_csv(path, index_col=0, header=0) + assert df.shape == (4, 3) + assert list(df.columns) == ["Item", "Year", "production"] + assert df.iloc[0]["Item"] == "Beef" + assert df.iloc[0]["Year"] == 2020 + assert df.iloc[0]["production"] == 15 + + assert df.iloc[1]["Item"] == "Beef" + assert df.iloc[1]["Year"] == 2021 + assert df.iloc[1]["production"] == 20 + + # Test Pandas Series + series = pd.Series( + [15, 6, 30], + index=["Beef", "Apples", "Poultry"], + name="production", + ) + + datablock_series = {"data": series} + path = "test_series_output.csv" + + write_csv(datablock_series, key="data", path=path) + + df = pd.read_csv(path, index_col=0, header=0) + assert df.shape == (3, 1) + assert list(df.columns) == ["production"] + assert df.loc["Beef", "production"] == 15 + assert df.loc["Apples", "production"] == 6 + assert df.loc["Poultry", "production"] == 30 + + # Test xarray Dataset + ds = xr.Dataset( + { + "production": (("Item", "Year"), [[15, 20], [6, 7]]), + "imports": (("Item", "Year"), [[4, 5], [1, 2]]), + }, + coords={ + "Item": ["Beef", "Apples"], + "Year": ["2020", "2021"], + }, + ) + + datablock_ds = {"data": ds} + path = "test_ds_output.csv" + + write_csv(datablock_ds, key="data", path=path) + + df = pd.read_csv(path, header=0) + + assert df.shape == (4, 4) + assert list(df.columns) == ["Item", "Year", "production", "imports"] + assert df.iloc[0]["Item"] == "Beef" + assert df.iloc[0]["Year"] == 2020 + assert df.iloc[0]["production"] == 15 + assert df.iloc[0]["imports"] == 4 + + # Test xarray DataArray + da = xr.DataArray( + [[15, 20], [6, 7]], + coords={ + "Item": ["Beef", "Apples"], + "Year": ["2020", "2021"], + }, + dims=["Item", "Year"], + 
name="production", + ) + + datablock_da = {"data": da} + path = "test_da_output.csv" + + write_csv(datablock_da, key="data", path=path) + + df = pd.read_csv(path, header=0) + + assert df.shape == (4, 3) + assert list(df.columns) == ["Item", "Year", "production"] + assert df.iloc[0]["Item"] == "Beef" + assert df.iloc[0]["Year"] == 2020 + assert df.iloc[0]["production"] == 15 + + # Test unnamed xarray DataArray + datablock_key = "data" + + da_unnamed = xr.DataArray( + [[15, 20], [6, 7]], + coords={ + "Item": ["Beef", "Apples"], + "Year": ["2020", "2021"], + }, + dims=["Item", "Year"], + ) + + datablock_da_unnamed = {datablock_key: da_unnamed} + path = "test_da_unnamed_output.csv" + + write_csv(datablock_da_unnamed, key=datablock_key, path=path) + + df = pd.read_csv(path, header=0) + + assert df.shape == (4, 3) + assert list(df.columns) == ["Item", "Year", datablock_key] \ No newline at end of file