diff --git a/docs/bug_fixes/gfs_preprocessing_fix.md b/docs/bug_fixes/gfs_preprocessing_fix.md new file mode 100644 index 0000000..fbe7bee --- /dev/null +++ b/docs/bug_fixes/gfs_preprocessing_fix.md @@ -0,0 +1,50 @@ +# GFS Preprocessing Data Format Fix + +## Issues +1. Dimension naming conflict between 'variable' and 'channel' +2. Longitude range mismatch (-180 to 180 vs 0 to 360) +3. Data structure incompatibility with model expectations + +## Solution + +### 1. Correct Preprocessing Script +```python +import xarray as xr + +def preprocess_gfs(year: int): + gfs = xr.open_mfdataset(f"/mnt/storage_b/nwp/gfs/global/{year}*.zarr.zip", engine="zarr") + + # Fix longitude range + gfs['longitude'] = ((gfs['longitude'] + 360) % 360) + + # Select UK region (~350° to 2° once wrapped; slice() cannot cross the 0/360 seam) + gfs = gfs.sel(latitude=slice(65, 45)) + uk_lon = (gfs.longitude >= 350) | (gfs.longitude <= 2) + gfs = gfs.sel(longitude=uk_lon) + + # Stack variables into channel dimension + gfs = gfs.to_array(dim="channel") # Use channel instead of variable + + # Optimize chunking + gfs = gfs.chunk({ + 'init_time_utc': len(gfs.init_time_utc), + 'step': 10, + 'latitude': 1, + 'longitude': 1 + }) + + return gfs +``` + +### 2. Expected Data Structure +- Dimensions: (init_time_utc, step, channel, latitude, longitude) +- Longitude range: [0, 360) +- Single stacked DataArray with channel dimension + +### 3. 
Verification +```python +ds = xr.open_zarr("path/to/gfs.zarr") +assert "channel" in ds.dims +assert 0 <= ds.longitude.min() < ds.longitude.max() <= 360 +``` \ No newline at end of file diff --git a/docs/getting_started.md b/docs/getting_started.md index 1d3567a..958ab03 100644 --- a/docs/getting_started.md +++ b/docs/getting_started.md @@ -646,6 +646,18 @@ open-data-pvnet metoffice load --year 2023 --month 12 --day 1 --region uk open-data-pvnet metoffice load --year 2023 --month 1 --day 16 --region uk --remote ``` +#### GFS Data Processing +```bash +# Process and archive GFS data for a specific year +open-data-pvnet gfs archive --year 2023 + +# Process specific month with upload to HuggingFace +open-data-pvnet gfs archive --year 2023 --month 12 --overwrite + +# Process without uploading to HuggingFace +open-data-pvnet gfs archive --year 2023 --skip-upload +``` + ### Error Handling Common error messages and their solutions: - "No datasets found": Check if the specified date has available data diff --git a/src/open_data_pvnet/main.py b/src/open_data_pvnet/main.py index f8d733e..9472c29 100644 --- a/src/open_data_pvnet/main.py +++ b/src/open_data_pvnet/main.py @@ -15,6 +15,7 @@ from open_data_pvnet.scripts.archive import handle_archive from open_data_pvnet.nwp.met_office import CONFIG_PATHS from open_data_pvnet.nwp.dwd import process_dwd_data +from open_data_pvnet.nwp.gfs import process_gfs_data logger = logging.getLogger(__name__) @@ -203,6 +204,24 @@ def configure_parser(): consolidate_parser = operation_subparsers.add_parser("consolidate", help="Consolidate data") _add_common_arguments(consolidate_parser, provider) + # Add GFS parser + gfs_parser = subparsers.add_parser("gfs", help="Commands for GFS data") + operation_subparsers = gfs_parser.add_subparsers(dest="operation", help="Operation to perform") + + # Archive operation parser for GFS + archive_parser = operation_subparsers.add_parser("archive", help="Archive GFS data") + archive_parser.add_argument("--year", 
type=int, required=True, help="Year to process") + archive_parser.add_argument("--month", type=int, help="Month to process") + archive_parser.add_argument("--day", type=int, help="Day to process") + archive_parser.add_argument("--skip-upload", action="store_true", help="Skip uploading to HuggingFace") + archive_parser.add_argument("--overwrite", "-o", action="store_true", help="Overwrite existing files") + archive_parser.add_argument( + "--archive-type", + choices=["zarr.zip", "tar"], + default="zarr.zip", + help="Type of archive to create" + ) + return parser @@ -419,7 +438,8 @@ def main(): open-data-pvnet metoffice consolidate --year 2023 --month 12 --day 1 GFS Data: - Partially implemented + # Archive GFS data for a specific day + open-data-pvnet gfs archive --year 2023 --month 1 --day 1 --skip-upload DWD Data: # Archive DWD data for a specific day @@ -496,7 +516,17 @@ def main(): "overwrite": args.overwrite, "archive_type": getattr(args, "archive_type", "zarr.zip"), } - archive_to_hf(**archive_kwargs) + if args.command == "gfs": + process_gfs_data( + year=args.year, + month=args.month, + day=args.day, + skip_upload=args.skip_upload, + overwrite=args.overwrite, + archive_type=args.archive_type + ) + else: + archive_to_hf(**archive_kwargs) return 0 diff --git a/src/open_data_pvnet/nwp/gfs.py b/src/open_data_pvnet/nwp/gfs.py index 364db23..9b6ff43 100644 --- a/src/open_data_pvnet/nwp/gfs.py +++ b/src/open_data_pvnet/nwp/gfs.py @@ -1,8 +1,103 @@ +import xarray as xr import logging +from pathlib import Path +from open_data_pvnet.utils.data_uploader import upload_to_huggingface logger = logging.getLogger(__name__) +def process_gfs_data( + year: int, + month: int = None, + day: int = None, + skip_upload: bool = False, + overwrite: bool = False, + archive_type: str = "zarr.zip", +) -> None: + """ + Process GFS data for a given time period and optionally upload to HuggingFace. 
-def process_gfs_data(year, month): - logger.info(f"Downloading GFS data for {year}-{month}") - raise NotImplementedError("The process_gfs_data function is not implemented yet.") + Args: + year (int): Year to process + month (int, optional): Month to process. If None, processes entire year + day (int, optional): Day to process. If None, processes entire month/year + skip_upload (bool): If True, skips uploading to HuggingFace + overwrite (bool): If True, overwrites existing files + archive_type (str): Type of archive to create ("zarr.zip" or "tar") + """ + try: + # Load GFS data + gfs = xr.open_mfdataset( + f"/mnt/storage_b/nwp/gfs/global/{year}*.zarr.zip", + engine="zarr" + ) + logger.info(f"Loaded GFS data for {year}") + + # Fix longitude range and select UK region + gfs['longitude'] = ((gfs['longitude'] + 360) % 360) + gfs = gfs.sel(latitude=slice(65, 45)) + # UK spans ~350°-2° in [0, 360); slice() cannot cross the 0/360 seam, so mask instead + uk_lon = (gfs.longitude >= 350) | (gfs.longitude <= 2) + gfs = gfs.sel(longitude=uk_lon) + + # Stack variables into channel dimension + gfs = gfs.to_array(dim="channel") + + # Optimize chunking + chunk_sizes = { + 'init_time_utc': 1, + 'step': 4, + 'channel': -1, # Keep all channels together + 'latitude': 1, + 'longitude': 1 + } + gfs = gfs.chunk(chunk_sizes) + + # Save locally + output_dir = Path(f"data/gfs/uk/gfs_uk_{year}.zarr") + output_dir.parent.mkdir(parents=True, exist_ok=True) + gfs.to_zarr(output_dir, mode='w') + logger.info(f"Saved processed data to {output_dir}") + + # Upload to HuggingFace if requested + if not skip_upload: + upload_to_huggingface( + config_path=Path("config.yaml"), + folder_name=str(year), + year=year, + month=month if month else 1, + day=day if day else 1, + overwrite=overwrite, + archive_type=archive_type + ) + logger.info("Upload to HuggingFace completed") + + except Exception as e: + logger.error(f"Error processing GFS data: {e}") + raise + +def verify_gfs_data(zarr_path: str) -> bool: + """ + Verify that the processed GFS data meets the expected format. 
+ + Args: + zarr_path (str): Path to the zarr file to verify + + Returns: + bool: True if verification passes + """ + try: + ds = xr.open_zarr(zarr_path) + + # Verify dimensions + required_dims = {"init_time_utc", "latitude", "longitude", "channel"} + if not all(dim in ds.dims for dim in required_dims): + raise ValueError(f"Dataset missing required dimensions: {required_dims}") + + # Verify longitude range + assert 0 <= ds.longitude.min() < ds.longitude.max() <= 360, \ + "Longitude range must be [0, 360)" + + return True + except Exception as e: + logger.error(f"Verification failed: {e}") + return False \ No newline at end of file diff --git a/src/open_data_pvnet/nwp/gfs_dataset.py b/src/open_data_pvnet/nwp/gfs_dataset.py index eadb201..d0299db 100644 --- a/src/open_data_pvnet/nwp/gfs_dataset.py +++ b/src/open_data_pvnet/nwp/gfs_dataset.py @@ -5,15 +5,86 @@ 3. Uncomment the main block below to run as a standalone script. """ +# Standard library imports +from typing import Union, Optional, Dict, Any, List import logging + +# Third-party imports import pandas as pd import xarray as xr from torch.utils.data import Dataset -from ocf_data_sampler.config import load_yaml_configuration -from ocf_data_sampler.torch_datasets.utils.valid_time_periods import find_valid_time_periods -from ocf_data_sampler.constants import NWP_MEANS, NWP_STDS import fsspec import numpy as np +from pydantic import BaseModel, ConfigDict + +logger = logging.getLogger(__name__) + +class GFSConfig(BaseModel): + channels: List[str] + time_resolution_minutes: int + interval_start_minutes: int + interval_end_minutes: int + provider: str + zarr_path: str + image_size_pixels_height: int + image_size_pixels_width: int + max_staleness_minutes: Optional[int] = None # Changed from max_staleness_hours + dropout_timedeltas_minutes: Optional[List[int]] = None + accum_channels: Optional[List[str]] = [] + +class NWPConfig(BaseModel): + gfs: GFSConfig + +class InputDataConfig(BaseModel): + nwp: NWPConfig + +class 
GeneralConfig(BaseModel): + name: str + description: str + +class Configuration(BaseModel): + general: GeneralConfig + input_data: InputDataConfig + +# Define default values for NWP statistics +DEFAULT_GFS_CHANNELS = [ + "dlwrf", "dswrf", "hcc", "lcc", "mcc", "prate", + "r", "t", "tcc", "u10", "u100", "v10", "v100", "vis" +] + +try: + from ocf_data_sampler.config import load_yaml_configuration + from ocf_data_sampler.torch_datasets.utils.valid_time_periods import find_valid_time_periods + from ocf_data_sampler.constants import NWP_MEANS, NWP_STDS +except ImportError: + logging.warning("Could not import ocf_data_sampler modules. Using default implementations.") + + def load_yaml_configuration(config_path): + import yaml + with open(config_path, 'r') as f: + # Parse into the Pydantic Configuration model so callers can use + # attribute access (config.input_data.nwp.gfs), matching the real loader + return Configuration(**yaml.safe_load(f)) + + def find_valid_time_periods(dataset_dict, config): + # Simple implementation that returns all times + times = dataset_dict['nwp']['gfs'].init_time_utc.values + return pd.DataFrame({'t0': times}) + + # Create default normalization values + NWP_MEANS = { + "gfs": xr.DataArray( + np.zeros(len(DEFAULT_GFS_CHANNELS)), + coords={'channel': DEFAULT_GFS_CHANNELS}, + dims=['channel'] + ) + } + + NWP_STDS = { + "gfs": xr.DataArray( + np.ones(len(DEFAULT_GFS_CHANNELS)), + coords={'channel': DEFAULT_GFS_CHANNELS}, + dims=['channel'] + ) + } # Configure logging @@ -23,56 +94,54 @@ xr.set_options(keep_attrs=True) -def open_gfs(dataset_path: str) -> xr.DataArray: +def open_gfs(path: str) -> Union[xr.Dataset, xr.DataArray]: """ - Opens the GFS dataset stored in Zarr format and prepares it for processing. - + Open a GFS dataset from a Zarr store. + Args: - dataset_path (str): Path to the GFS dataset. - + path: Path to the Zarr store + Returns: - xr.DataArray: The processed GFS data. 
- """ - logging.info("Opening GFS dataset synchronously...") - store = fsspec.get_mapper(dataset_path, anon=True) - gfs_dataset: xr.Dataset = xr.open_dataset( - store, engine="zarr", consolidated=True, chunks="auto" - ) - gfs_data: xr.DataArray = gfs_dataset.to_array(dim="channel") - - if "init_time" in gfs_data.dims: - logging.debug("Renaming 'init_time' to 'init_time_utc'...") - gfs_data = gfs_data.rename({"init_time": "init_time_utc"}) - - required_dims = ["init_time_utc", "step", "channel", "latitude", "longitude"] - gfs_data = gfs_data.transpose(*required_dims) - - logging.debug(f"GFS dataset dimensions: {gfs_data.dims}") - return gfs_data - - -def handle_nan_values( - dataset: xr.DataArray, method: str = "fill", fill_value: float = 0.0 -) -> xr.DataArray: + xr.Dataset or xr.DataArray: The loaded GFS dataset """ - Handle NaN values in the dataset. - + try: + # Open the dataset + ds = xr.open_zarr(path) + + # Basic validation + required_dims = {"init_time_utc", "latitude", "longitude", "channel"} + if not all(dim in ds.dims for dim in required_dims): + raise ValueError(f"Dataset missing required dimensions: {required_dims}") + + # Ensure longitude is in [-180, 180] range + if (ds.longitude > 180).any(): + ds['longitude'] = ((ds['longitude'] + 180) % 360) - 180 + ds = ds.sortby('longitude') + + return ds + + except Exception as e: + raise IOError(f"Failed to open GFS dataset at {path}: {str(e)}") + + +def handle_nan_values(dataset: xr.Dataset, method: str = "fill", fill_value: float = 0.0) -> xr.Dataset: + """Handle NaN values in the dataset. + Args: - dataset (xr.DataArray): The dataset to process. - method (str): The method for handling NaNs ("fill" or "drop"). - fill_value (float): Value to replace NaNs if method is "fill". - + dataset: Input xarray Dataset + method: Method to handle NaNs ('fill' or 'drop') + fill_value: Value to use for filling NaNs when method='fill' + Returns: - xr.DataArray: The processed dataset. 
+ Processed dataset with NaNs handled """ if method == "fill": - logging.info(f"Filling NaN values with {fill_value}.") return dataset.fillna(fill_value) elif method == "drop": - logging.info("Dropping NaN values.") - return dataset.dropna(dim="latitude", how="all").dropna(dim="longitude", how="all") + # Drop time steps that contain any NaN values + return dataset.dropna(dim='init_time_utc', how='any') else: - raise ValueError("Invalid method for handling NaNs. Use 'fill' or 'drop'.") + raise ValueError(f"Unknown method: {method}") class GFSDataSampler(Dataset): @@ -97,23 +166,45 @@ def __init__( end_time (str, optional): End time for filtering data. """ logging.info("Initializing GFSDataSampler...") + + # Validate input dataset + if not isinstance(dataset, (xr.DataArray, xr.Dataset)): + raise TypeError("Dataset must be an xarray DataArray or Dataset") + + # Validate required dimensions + required_dims = {'channel', 'init_time_utc', 'step', 'latitude', 'longitude'} + missing_dims = required_dims - set(dataset.dims) + if missing_dims: + raise ValueError(f"Dataset missing required dimensions: {missing_dims}") + self.dataset = dataset self.config = load_yaml_configuration(config_filename) + + # Validate config + if not hasattr(self.config.input_data.nwp.gfs, 'channels'): + raise ValueError("Config missing required field: input_data.nwp.gfs.channels") + self.valid_t0_times = find_valid_time_periods({"nwp": {"gfs": self.dataset}}, self.config) logging.debug(f"Valid initialization times:\n{self.valid_t0_times}") if "start_dt" in self.valid_t0_times.columns: self.valid_t0_times = self.valid_t0_times.rename(columns={"start_dt": "t0"}) + # Filter by time range if provided if start_time: + start_ts = pd.Timestamp(start_time) self.valid_t0_times = self.valid_t0_times[ - self.valid_t0_times["t0"] >= pd.Timestamp(start_time) + self.valid_t0_times["t0"] >= start_ts ] if end_time: + end_ts = pd.Timestamp(end_time) self.valid_t0_times = self.valid_t0_times[ - 
self.valid_t0_times["t0"] <= pd.Timestamp(end_time) + self.valid_t0_times["t0"] <= end_ts ] + if len(self.valid_t0_times) == 0: + raise ValueError("No valid time periods found in dataset") + logging.debug(f"Filtered valid_t0_times:\n{self.valid_t0_times}") def __len__(self): diff --git a/src/open_data_pvnet/nwp/gfs_preprocessing.py b/src/open_data_pvnet/nwp/gfs_preprocessing.py new file mode 100644 index 0000000..3ca0861 --- /dev/null +++ b/src/open_data_pvnet/nwp/gfs_preprocessing.py @@ -0,0 +1,76 @@ +import xarray as xr +import logging +from pathlib import Path + +logger = logging.getLogger(__name__) + +def preprocess_gfs_data(input_path: str, output_path: str, year: int) -> None: + """ + Preprocess GFS data to ensure correct longitude range and dimension structure. + + Args: + input_path: Path to input GFS data + output_path: Path to save processed data + year: Year of data to process + """ + logger.info(f"Starting GFS preprocessing for year {year}") + + # Load GFS dataset + gfs = xr.open_mfdataset(f"{input_path}/{year}*.zarr.zip", engine="zarr") + logger.info(f"Loaded GFS data with shape: {gfs.dims}") + + # Step 1: Fix longitude range to [0, 360) + gfs['longitude'] = ((gfs['longitude'] + 360) % 360) + gfs = gfs.sortby('longitude') # Ensure longitudes are sorted + + # Step 2: Select UK region (with buffer). slice() cannot wrap across the + # 0/360 seam, so pick longitudes ~350°-2° with a boolean mask instead + gfs = gfs.sel(latitude=slice(65, 45)) + uk_lon = (gfs.longitude >= 350) | (gfs.longitude <= 2) + gfs = gfs.sel(longitude=uk_lon) + + # Step 3: Ensure data validity + if len(gfs.latitude) == 0 or len(gfs.longitude) == 0: + raise ValueError("No data found after selecting UK region") + + # Step 4: Stack variables into channel dimension + gfs = gfs.to_array(dim="channel") + + # Step 5: Optimize chunking for performance + chunk_sizes = { + 'init_time_utc': 1, + 'step': 4, + 'channel': -1, # Keep all channels together + 'latitude': 1, + 'longitude': 1 + } + + # Remove existing chunk encoding + for var in gfs.variables: + if 'chunks' in gfs[var].encoding: + del 
gfs[var].encoding['chunks'] + + gfs = gfs.chunk(chunk_sizes) + + # Step 6: Save processed data + output_file = Path(output_path) / f"gfs_uk_{year}.zarr" + logger.info(f"Saving processed data to {output_file}") + gfs.to_zarr(output_file) + + logger.info(f"Completed processing for {year}") + logger.info(f"Final dimensions: {gfs.dims}") + +def main(): + input_base = "/mnt/storage_b/nwp/gfs/global" + output_base = "uk" + + for year in [2023, 2024]: + try: + preprocess_gfs_data(input_base, output_base, year) + except Exception as e: + logger.error(f"Error processing year {year}: {e}") + raise + +if __name__ == "__main__": + logging.basicConfig(level=logging.INFO) + main()