From 7656d4c301ebed5db90dc23af9d2a361e68dad1e Mon Sep 17 00:00:00 2001 From: jeremiah Date: Thu, 30 Apr 2026 17:45:22 -0700 Subject: [PATCH] rework FittingData --- casm/project/fit/_FittingData.py | 166 +++++++++++++++++++------------ 1 file changed, 101 insertions(+), 65 deletions(-) diff --git a/casm/project/fit/_FittingData.py b/casm/project/fit/_FittingData.py index 81d2602..284539b 100644 --- a/casm/project/fit/_FittingData.py +++ b/casm/project/fit/_FittingData.py @@ -55,8 +55,8 @@ def __init__(self, proj: "Project", id: str): """dict: A description of the fit, read from `meta.json`.""" self.names = None - """Optional[list[str]]: Names of the configurations, a length `n_configs` list, - if given.""" + """Optional[np.ndarray]: Names of the configurations, a shape=(n_configs,) + array of strings, if given.""" self.parametric_compositions = None """Optional[np.ndarray]: Parametric compositions of all the configurations, a @@ -77,8 +77,27 @@ def __init__(self, proj: "Project", id: str): # load data self.load() + def from_dict(self, data): + """Set fitting data attributes from a dictionary. + + Parameters + ---------- + data : dict + A dictionary containing `names`, `parametric_compositions`, + `mol_compositions`, `correlations_per_unitcell`, and optionally + `formation_energies`. 
+ """ + self.names = np.array(data["names"]) + self.parametric_compositions = np.array(data["parametric_compositions"]) + self.mol_compositions = np.array(data["mol_compositions"]) + self.correlations_per_unitcell = np.array(data["correlations_per_unitcell"]) + if data.get("formation_energies") is not None: + self.formation_energies = np.array(data["formation_energies"]) + else: + self.formation_energies = None + def load(self): - """Read meta.json + """Read meta.json and fitting_data.json This will replace the current contents of this FittingData object with the contents of the associated files, or set the current contents to None if the @@ -89,8 +108,14 @@ def load(self): path = self.fit_dir / "meta.json" self.meta = read_optional(path, default=dict()) + # read fitting_data.json if it exists + path = self.fit_dir / "fitting_data.json" + data = read_optional(path, default=None) + if data is not None: + self.from_dict(data) + def commit(self, verbose: bool = True): - """Write meta.json + """Write meta.json and fitting_data.json If the data does not exist in this object, this will erase the associated files if they do exist. 
@@ -114,6 +139,14 @@ def commit(self, verbose: bool = True): elif path.exists(): path.unlink() + # write fitting_data.json + path = self.fit_dir / "fitting_data.json" + if self.names is not None: + data = self.to_dict() + safe_dump(data=data, path=path, quiet=quiet, force=True) + elif path.exists(): + path.unlink() + def clear(self): """Clear fitting data""" # TODO @@ -132,39 +165,6 @@ def __repr__(self): return s.strip() - @staticmethod - def from_dict(data): - """Construct FittingData from a dictionary - - Parameters - ---------- - data : dict - A dictionary containing `names`, `parametric_compositions` - `mol_compositions`, `correlations_per_unitcell` and `formation_energies` - of the configurations - Note that `formation_energies` can be None - - Returns - ------- - fitting_data : FittingData - :class:`FittingData` with `names`, `parametric_compositions`, - `mol_compositions`, `correlations_per_unitcell` and `formation_energies` - filled in for all the configurations - - - """ - fitting_data = FittingData() - - fitting_data.names = data["names"] - fitting_data.parametric_compositions = np.array(data["parametric_compositions"]) - fitting_data.mol_compositions = np.array(data["mol_compositions"]) - fitting_data.correlations_per_unitcell = np.array( - data["correlations_per_unitcell"] - ) - fitting_data.formation_energies = np.array(data["formation_energies"]) - - return fitting_data - def to_dict(self): """Turn `FittingData` into a dictionary with `names`, `parametric_compositions`, `mol_compositions`, `correlations_per_unitcell` @@ -175,13 +175,19 @@ def to_dict(self): data : dict """ - return dict( - names=self.names, + names=( + self.names if isinstance(self.names, list) + else self.names.tolist() + ), parametric_compositions=self.parametric_compositions.tolist(), mol_compositions=self.mol_compositions.tolist(), correlations_per_unitcell=self.correlations_per_unitcell.tolist(), - formation_energies=self.formation_energies.tolist(), + formation_energies=( + 
self.formation_energies.tolist() + if self.formation_energies is not None + else None + ), ) @@ -271,6 +277,9 @@ def make_calculated_fitting_data( composition_converter: comp.CompositionConverter, clexulator: clex.Clexulator, prim_neighbor_list: clex.PrimNeighborList, + proj: "Project", + id: str, + names: list[str] = None, ) -> FittingData: """For a given `config_props` list, constructs FittingData which which holds compositions, correlations per unitcell, formation energies @@ -295,6 +304,14 @@ def make_calculated_fitting_data( A :class:`~libcasm.clexulator.PrimNeighborList` which will be used to construct the :class:`~libcasm.clexulator.SuperNeighborList` for every configuration and will be used while obtaining correlations + proj: casm.project.Project + The CASM project + id: str + The fit identifier. Fitting data is stored in the + fits directory at `<project>/fits/fit.<id>/`. + names: Optional[list[str]] + Names of the configurations. If None (default), names are + auto-generated as ``"config.0"``, ``"config.1"``, etc. Returns ------- @@ -302,7 +319,7 @@ def make_calculated_fitting_data( """ - names = [] + _names = [] parametric_compositions = [] mol_compositions = [] correlations_per_unitcell = [] @@ -327,7 +344,7 @@ def make_calculated_fitting_data( composition_converter=composition_converter, ) - names.append("config." + str(config_id)) + _names.append("config." + str(config_id)) correlations_per_unitcell.append(corr_per_unitcell.tolist()) mol_compositions.append(mol_comp.tolist()) parametric_compositions.append(param_comp.tolist()) @@ -336,12 +353,16 @@ def make_calculated_fitting_data( # in config props. Should it be like this?? 
formation_energies.append(config_prop["formation_energy"]) - fitting_data = FittingData() - fitting_data.names = names - fitting_data.correlations_per_unitcell = np.array(correlations_per_unitcell) - fitting_data.mol_compositions = np.array(mol_compositions) - fitting_data.parametric_compositions = np.array(parametric_compositions) - fitting_data.formation_energies = np.array(formation_energies) + fitting_data = FittingData(proj, id) + fitting_data.from_dict( + dict( + names=names if names is not None else _names, + parametric_compositions=parametric_compositions, + mol_compositions=mol_compositions, + correlations_per_unitcell=correlations_per_unitcell, + formation_energies=formation_energies, + ) + ) return fitting_data @@ -352,6 +373,9 @@ def make_uncalculated_fitting_data( composition_converter: comp.CompositionConverter, clexulator: clex.Clexulator, prim_neighbor_list: clex.PrimNeighborList, + proj: "Project", + id: str, + names: list[str] = None, ) -> FittingData: """For a given `config_list` list, constructs FittingData which which holds compositions, correlations per unitcell of all the configurations @@ -363,10 +387,10 @@ def make_uncalculated_fitting_data( ---------- xtal_prim : xtal.Prim Prim of the project - config_props : list[dict] - A list containing results of mapping/import - composition_converter : libcasm.composition.CompositionCalculator - A :class:`~libcasm.composition.CompositionCalculator` object with + config_list : list[dict] + A list containing results of enumeration + composition_converter : libcasm.composition.CompositionConverter + A :class:`~libcasm.composition.CompositionConverter` object with the warranted composition axes set, which will be used to obtain mol and parametric compostions clexulator : libcasm.clexulator.Clexulator @@ -376,46 +400,58 @@ def make_uncalculated_fitting_data( A :class:`~libcasm.clexulator.PrimNeighborList` which will be used to construct the :class:`~libcasm.clexulator.SuperNeighborList` for every 
configuration and will be used while obtaining correlations + proj: casm.project.Project + The CASM project + id: str + The fit identifier. Fitting data is stored in the + fits directory at `<project>/fits/fit.<id>/`. + names: Optional[list[str]] + Names of the configurations. If None (default), names are + auto-generated as ``"config.0"``, ``"config.1"``, etc. Returns ------- FittingData """ - names = [] + _names = [] parametric_compositions = [] mol_compositions = [] correlations_per_unitcell = [] supercell_set = casmconfig.SupercellSet(casmconfig.Prim(xtal_prim)) for config_id, config in enumerate(config_list): - config_with_properties = casmconfig.Configuration.from_dict( + configuration = casmconfig.Configuration.from_dict( config["configuration_with_properties"], supercell_set ) # Extract correlations corr_per_unitcell = _extract_correlations_for_configuration( - configuration=config_with_properties.configuration, + configuration=configuration, clexulator=clexulator, prim_neighbor_list=prim_neighbor_list, ) # Extract mol and param compositions mol_comp, param_comp = _extract_mol_and_param_comp_for_configuration( - configuration=config_with_properties.configuration, + configuration=configuration, xtal_prim=xtal_prim, composition_converter=composition_converter, ) - names.append("config." + str(config_id)) - correlations_per_unitcell.append(corr_per_unitcell) - mol_compositions.append(mol_comp) - parametric_compositions.append(param_comp) + _names.append("config." 
+ str(config_id)) + correlations_per_unitcell.append(corr_per_unitcell.tolist()) + mol_compositions.append(mol_comp.tolist()) + parametric_compositions.append(param_comp.tolist()) - fitting_data = FittingData() - fitting_data.names = names - fitting_data.correlations_per_unitcell = correlations_per_unitcell - fitting_data.mol_compositions = mol_compositions - fitting_data.parametric_compositions = parametric_compositions + fitting_data = FittingData(proj, id) + fitting_data.from_dict( + dict( + names=names if names is not None else _names, + parametric_compositions=parametric_compositions, + mol_compositions=mol_compositions, + correlations_per_unitcell=correlations_per_unitcell, + ) + ) return fitting_data