diff --git a/CMakeLists.txt b/CMakeLists.txt index 1d41c3d..0647e23 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -545,6 +545,8 @@ set(LIBHMM_SOURCES src/training/viterbi_trainer.cpp src/training/segmental_kmeans_trainer.cpp src/io/file_io_manager.cpp + src/io/json_utils.cpp + src/io/hmm_json.cpp src/io/xml_file_reader.cpp src/io/xml_file_writer.cpp ) diff --git a/examples/basic_hmm_example.cpp b/examples/basic_hmm_example.cpp index e80726b..3a12ac7 100755 --- a/examples/basic_hmm_example.cpp +++ b/examples/basic_hmm_example.cpp @@ -5,10 +5,9 @@ * Demonstrates ForwardBackward probability evaluation. * 2. Distribution showcase — PDF evaluation and MLE fitting. * 3. Viterbi training on a 3-state Gaussian HMM. - * 4. XML file round-trip. + * 4. JSON file round-trip (save_json / load_json). */ #include -#include #include #include #include @@ -181,7 +180,11 @@ int main() { std::cout << hmm << std::endl; } { - std::cout << "Testing File Read/Write" << std::endl; + // ===================================================================== + // 4. JSON file round-trip + // ===================================================================== + std::cout << "Testing JSON File Round-Trip" << std::endl; + std::cout << "---------------------------" << std::endl; Hmm hmm(2); Vector pi(2); Matrix trans(2, 2); @@ -196,17 +199,18 @@ int main() { hmm.setTrans(trans); hmm.setDistribution(0, std::make_unique()); hmm.setDistribution(1, std::make_unique(2.0, 2.0)); - std::ofstream of("testrw", std::ios::out); - std::cout << hmm << std::endl; - of << hmm << std::endl; - of.close(); + std::cout << "Original HMM:\n" << hmm << std::endl; + + // Save to JSON — exact round-trip at max_digits10 precision. + const std::filesystem::path jsonPath = "basic_hmm.json"; + libhmm::save_json(hmm, jsonPath); + std::cout << "Saved to: " << jsonPath << std::endl; - Hmm hmm1(2); - std::ifstream inf("testrw", std::ios::in); - inf >> hmm1; - // inf.close( ); + // Load back. 
+ Hmm loaded = libhmm::load_json(jsonPath); + std::cout << "Loaded HMM:\n" << loaded << std::endl; - std::cout << hmm1 << std::endl; + std::filesystem::remove(jsonPath); // clean up } } diff --git a/include/libhmm/common/common.h b/include/libhmm/common/common.h index db0d1ca..0a76352 100755 --- a/include/libhmm/common/common.h +++ b/include/libhmm/common/common.h @@ -1,11 +1,8 @@ #pragma once // Core C++ standard library headers used throughout the library. -// Lean version: linalg types (Matrix, Vector, ObservationSet, etc.) and XML -// serialization helpers now live in separate headers: -// libhmm/linalg/linalg_types.h — type aliases and clear_* helpers -// libhmm/common/serialization.h — MatrixSerializer / VectorSerializer -// Include those headers from files that need them. +// Linalg types (Matrix, Vector, ObservationSet, etc.) live in +// libhmm/linalg/linalg_types.h — include that from files that need them. #include #include diff --git a/include/libhmm/common/serialization.h b/include/libhmm/common/serialization.h deleted file mode 100644 index 6261fd1..0000000 --- a/include/libhmm/common/serialization.h +++ /dev/null @@ -1,169 +0,0 @@ -#pragma once - -// XML serialization helpers for libhmm linalg types. -// -// Provides MatrixSerializer and VectorSerializer used by the HMM XML -// I/O layer (src/hmm.cpp, src/io/). Extracted from common/common.h to keep -// that header lean and avoid pulling serialization templates into every -// translation unit that only needs basic types. - -#include -#include - -#include "libhmm/linalg/linalg_types.h" - -namespace libhmm { -namespace serialization { - -/** - * Simple XML serialization for BasicMatrix objects. - * Replaces boost::serialization with a lightweight C++17 implementation. 
- */ -template -class MatrixSerializer { -public: - /// Save matrix to XML format - static void save(std::ostream &os, const BasicMatrix &matrix, - const std::string &name = "matrix") { - os << "<" << name << ">\n"; - os << " " << matrix.size1() << "\n"; - os << " " << matrix.size2() << "\n"; - os << " \n"; - - for (std::size_t i = 0; i < matrix.size1(); ++i) { - os << " "; - for (std::size_t j = 0; j < matrix.size2(); ++j) { - os << matrix(i, j); - if (j < matrix.size2() - 1) - os << " "; - } - os << "\n"; - } - - os << " \n"; - os << "\n"; - } - - /// Load matrix from XML format - static void load(std::istream &is, BasicMatrix &matrix, const std::string &name = "matrix") { - std::string line; - std::size_t rows = 0, cols = 0; - - while (std::getline(is, line)) { - if (line.find("<" + name + ">") != std::string::npos) - break; - } - - if (std::getline(is, line)) { - std::size_t start = line.find("") + 6; - std::size_t end = line.find(""); - if (start != std::string::npos && end != std::string::npos) - rows = std::stoull(line.substr(start, end - start)); - } - - if (std::getline(is, line)) { - std::size_t start = line.find("") + 6; - std::size_t end = line.find(""); - if (start != std::string::npos && end != std::string::npos) - cols = std::stoull(line.substr(start, end - start)); - } - - matrix.resize(rows, cols); - std::getline(is, line); // skip - - for (std::size_t i = 0; i < rows; ++i) { - if (std::getline(is, line)) { - std::size_t start = line.find("") + 5; - std::size_t end = line.find(""); - if (start != std::string::npos && end != std::string::npos) { - std::istringstream row_stream(line.substr(start, end - start)); - for (std::size_t j = 0; j < cols; ++j) { - T value; - row_stream >> value; - matrix(i, j) = value; - } - } - } - } - } -}; - -/** - * Simple XML serialization for BasicVector objects. - * Replaces boost::serialization with a lightweight C++17 implementation. 
- */ -template -class VectorSerializer { -public: - /// Save vector to XML format - static void save(std::ostream &os, const BasicVector &vector, - const std::string &name = "vector") { - os << "<" << name << ">\n"; - os << " " << vector.size() << "\n"; - os << " "; - for (std::size_t i = 0; i < vector.size(); ++i) { - os << vector[i]; - if (i < vector.size() - 1) - os << " "; - } - os << "\n"; - os << "\n"; - } - - /// Load vector from XML format - static void load(std::istream &is, BasicVector &vector, const std::string &name = "vector") { - std::string line; - std::size_t size = 0; - - while (std::getline(is, line)) { - if (line.find("<" + name + ">") != std::string::npos) - break; - } - - if (std::getline(is, line)) { - std::size_t start = line.find("") + 6; - std::size_t end = line.find(""); - if (start != std::string::npos && end != std::string::npos) - size = std::stoull(line.substr(start, end - start)); - } - - vector.resize(size); - - if (std::getline(is, line)) { - std::size_t start = line.find("") + 6; - std::size_t end = line.find(""); - if (start != std::string::npos && end != std::string::npos) { - std::istringstream data_stream(line.substr(start, end - start)); - for (std::size_t i = 0; i < size; ++i) { - T value; - data_stream >> value; - vector[i] = value; - } - } - } - } -}; - -// Convenience wrappers matching the old boost::serialization style -template -void save(Archive &ar, const BasicMatrix &matrix, const std::string &name = "matrix") { - MatrixSerializer::save(ar, matrix, name); -} - -template -void load(Archive &ar, BasicMatrix &matrix, const std::string &name = "matrix") { - MatrixSerializer::load(ar, matrix, name); -} - -template -void save(Archive &ar, const BasicVector &vector, const std::string &name = "vector") { - VectorSerializer::save(ar, vector, name); -} - -template -void load(Archive &ar, BasicVector &vector, const std::string &name = "vector") { - VectorSerializer::load(ar, vector, name); -} - -} // namespace serialization -} 
// namespace libhmm diff --git a/include/libhmm/distributions/beta_distribution.h b/include/libhmm/distributions/beta_distribution.h index 218fbea..4551d5f 100644 --- a/include/libhmm/distributions/beta_distribution.h +++ b/include/libhmm/distributions/beta_distribution.h @@ -166,6 +166,9 @@ class BetaDistribution : public DistributionBase { * @return String describing the distribution parameters */ std::string toString() const override; + [[nodiscard]] std::string to_json() const override; + /// @internal JSON factory — called by the distribution registry in src/io/hmm_json.cpp. + static std::unique_ptr from_json(json::Reader &r); /** * Gets the alpha (α) shape parameter. diff --git a/include/libhmm/distributions/binomial_distribution.h b/include/libhmm/distributions/binomial_distribution.h index 77122ad..3b59c0a 100644 --- a/include/libhmm/distributions/binomial_distribution.h +++ b/include/libhmm/distributions/binomial_distribution.h @@ -92,6 +92,9 @@ class BinomialDistribution : public DistributionBase { void reset() noexcept override; std::string toString() const override; + [[nodiscard]] std::string to_json() const override; + /// @internal JSON factory — called by the distribution registry in src/io/hmm_json.cpp. 
+ static std::unique_ptr from_json(json::Reader &r); [[nodiscard]] int getN() const noexcept { return n_; } [[nodiscard]] double getP() const noexcept { return p_; } diff --git a/include/libhmm/distributions/chi_squared_distribution.h b/include/libhmm/distributions/chi_squared_distribution.h index 78bcac8..5f77834 100644 --- a/include/libhmm/distributions/chi_squared_distribution.h +++ b/include/libhmm/distributions/chi_squared_distribution.h @@ -115,6 +115,9 @@ class ChiSquaredDistribution : public DistributionBase { * @return String describing the distribution parameters */ std::string toString() const override; + [[nodiscard]] std::string to_json() const override; + /// @internal JSON factory — called by the distribution registry in src/io/hmm_json.cpp. + static std::unique_ptr from_json(json::Reader &r); /** * Gets the degrees of freedom parameter. diff --git a/include/libhmm/distributions/discrete_distribution.h b/include/libhmm/distributions/discrete_distribution.h index 5c2265f..8a23317 100755 --- a/include/libhmm/distributions/discrete_distribution.h +++ b/include/libhmm/distributions/discrete_distribution.h @@ -189,9 +189,12 @@ class DiscreteDistribution : public DistributionBase { * @return String showing all symbol probabilities */ std::string toString() const override; + [[nodiscard]] std::string to_json() const override; + /// @internal JSON factory — called by the distribution registry in src/io/hmm_json.cpp. + static std::unique_ptr from_json(json::Reader &r); /** - * Gets the number of discrete symbols in the distribution. 
+ * Gets the number of discrete symbols * * @return Number of symbols/categories */ diff --git a/include/libhmm/distributions/distribution_base.h b/include/libhmm/distributions/distribution_base.h index d036533..2e0bebf 100644 --- a/include/libhmm/distributions/distribution_base.h +++ b/include/libhmm/distributions/distribution_base.h @@ -7,6 +7,12 @@ namespace libhmm { +// Forward declaration — concrete distributions declare from_json(json::Reader&) +// as a static factory method; implementation is in src/io/hmm_json.cpp. +namespace json { +class Reader; +} // namespace json + /** * @brief Shared implementation base for all emission distributions. * diff --git a/include/libhmm/distributions/distribution_io_utils.h b/include/libhmm/distributions/distribution_io_utils.h deleted file mode 100644 index d0ef347..0000000 --- a/include/libhmm/distributions/distribution_io_utils.h +++ /dev/null @@ -1,30 +0,0 @@ -#pragma once - -#include -#include -#include -#include -#include - -namespace libhmm::detail { - -/// Splits a "name=value" token at the first '=', strips whitespace from both -/// parts, and returns {name, value}. -/// Throws std::invalid_argument (with context in the message) if no '=' found. 
-[[nodiscard]] inline std::pair -parse_named_param(const std::string &param, const std::string &context) { - const auto eq = param.find('='); - if (eq == std::string::npos) - throw std::invalid_argument("Invalid " + context + " parameter format"); - std::string name = param.substr(0, eq); - std::string value = param.substr(eq + 1); - const auto trim = [](std::string &s) { - s.erase(std::remove_if(s.begin(), s.end(), [](unsigned char c) { return std::isspace(c); }), - s.end()); - }; - trim(name); - trim(value); - return {std::move(name), std::move(value)}; -} - -} // namespace libhmm::detail diff --git a/include/libhmm/distributions/emission_distribution.h b/include/libhmm/distributions/emission_distribution.h index 84c8162..16a86d1 100644 --- a/include/libhmm/distributions/emission_distribution.h +++ b/include/libhmm/distributions/emission_distribution.h @@ -92,7 +92,17 @@ class EmissionDistribution { // Metadata // ========================================================================= - [[nodiscard]] virtual std::string toString() const = 0; + /** @brief Human-readable string representation. Delegates to to_json(). */ + [[nodiscard]] virtual std::string toString() const { return to_json(); } + + /** + * @brief Serialise to a compact JSON object string. + * + * Must produce output that round-trips exactly through the matching + * static from_json() factory registered in src/io/hmm_json.cpp. + * Use json::write_distribution() from libhmm/io/json_utils.h. + */ + [[nodiscard]] virtual std::string to_json() const = 0; /** @brief Returns true for discrete (PMF) distributions, false for continuous (PDF). 
*/ [[nodiscard]] virtual bool isDiscrete() const noexcept = 0; diff --git a/include/libhmm/distributions/exponential_distribution.h b/include/libhmm/distributions/exponential_distribution.h index 5c68010..d02b781 100755 --- a/include/libhmm/distributions/exponential_distribution.h +++ b/include/libhmm/distributions/exponential_distribution.h @@ -134,6 +134,9 @@ class ExponentialDistribution : public DistributionBase { * @return String describing the distribution parameters */ std::string toString() const override; + [[nodiscard]] std::string to_json() const override; + /// @internal JSON factory — called by the distribution registry in src/io/hmm_json.cpp. + static std::unique_ptr from_json(json::Reader &r); /** * Gets the rate parameter λ. diff --git a/include/libhmm/distributions/gamma_distribution.h b/include/libhmm/distributions/gamma_distribution.h index 80d8e95..9bbb544 100755 --- a/include/libhmm/distributions/gamma_distribution.h +++ b/include/libhmm/distributions/gamma_distribution.h @@ -168,6 +168,9 @@ class GammaDistribution : public DistributionBase { * @return String describing the distribution parameters */ [[nodiscard]] std::string toString() const override; + [[nodiscard]] std::string to_json() const override; + /// @internal JSON factory — called by the distribution registry in src/io/hmm_json.cpp. + static std::unique_ptr from_json(json::Reader &r); /** * Gets the shape parameter k. 
diff --git a/include/libhmm/distributions/gaussian_distribution.h b/include/libhmm/distributions/gaussian_distribution.h index a2ecac4..755cf2e 100755 --- a/include/libhmm/distributions/gaussian_distribution.h +++ b/include/libhmm/distributions/gaussian_distribution.h @@ -157,6 +157,9 @@ class GaussianDistribution : public DistributionBase { * @return String describing the distribution parameters */ std::string toString() const override; + [[nodiscard]] std::string to_json() const override; + /// @internal JSON factory — called by the distribution registry in src/io/hmm_json.cpp. + static std::unique_ptr from_json(json::Reader &r); /** * Gets the mean parameter μ. diff --git a/include/libhmm/distributions/log_normal_distribution.h b/include/libhmm/distributions/log_normal_distribution.h index 631b459..fcb838e 100755 --- a/include/libhmm/distributions/log_normal_distribution.h +++ b/include/libhmm/distributions/log_normal_distribution.h @@ -131,6 +131,9 @@ class LogNormalDistribution : public DistributionBase { * @return String describing the distribution parameters */ std::string toString() const override; + [[nodiscard]] std::string to_json() const override; + /// @internal JSON factory — called by the distribution registry in src/io/hmm_json.cpp. + static std::unique_ptr from_json(json::Reader &r); /** * Gets the mean parameter μ of the underlying normal distribution. 
diff --git a/include/libhmm/distributions/negative_binomial_distribution.h b/include/libhmm/distributions/negative_binomial_distribution.h index 29147e3..3c1461f 100644 --- a/include/libhmm/distributions/negative_binomial_distribution.h +++ b/include/libhmm/distributions/negative_binomial_distribution.h @@ -147,6 +147,9 @@ class NegativeBinomialDistribution : public DistributionBase { * @return String describing the distribution parameters */ std::string toString() const override; + [[nodiscard]] std::string to_json() const override; + /// @internal JSON factory — called by the distribution registry in src/io/hmm_json.cpp. + static std::unique_ptr from_json(json::Reader &r); /** * Gets the number of successes parameter r. diff --git a/include/libhmm/distributions/pareto_distribution.h b/include/libhmm/distributions/pareto_distribution.h index 75c7dfa..f83e5cb 100755 --- a/include/libhmm/distributions/pareto_distribution.h +++ b/include/libhmm/distributions/pareto_distribution.h @@ -158,6 +158,9 @@ class ParetoDistribution : public DistributionBase { * @return String describing the distribution parameters */ std::string toString() const override; + [[nodiscard]] std::string to_json() const override; + /// @internal JSON factory — called by the distribution registry in src/io/hmm_json.cpp. + static std::unique_ptr from_json(json::Reader &r); /** * Gets the shape parameter k. diff --git a/include/libhmm/distributions/poisson_distribution.h b/include/libhmm/distributions/poisson_distribution.h index 06dd030..b3f9f35 100644 --- a/include/libhmm/distributions/poisson_distribution.h +++ b/include/libhmm/distributions/poisson_distribution.h @@ -125,6 +125,9 @@ class PoissonDistribution : public DistributionBase { * @return String describing the distribution parameters */ std::string toString() const override; + [[nodiscard]] std::string to_json() const override; + /// @internal JSON factory — called by the distribution registry in src/io/hmm_json.cpp. 
+ static std::unique_ptr from_json(json::Reader &r); /** * Gets the rate parameter λ. diff --git a/include/libhmm/distributions/rayleigh_distribution.h b/include/libhmm/distributions/rayleigh_distribution.h index e33bf49..44bc754 100644 --- a/include/libhmm/distributions/rayleigh_distribution.h +++ b/include/libhmm/distributions/rayleigh_distribution.h @@ -150,6 +150,9 @@ class RayleighDistribution : public DistributionBase { * @return String describing the distribution parameters */ std::string toString() const override; + [[nodiscard]] std::string to_json() const override; + /// @internal JSON factory — called by the distribution registry in src/io/hmm_json.cpp. + static std::unique_ptr from_json(json::Reader &r); /** * Gets the scale parameter σ. diff --git a/include/libhmm/distributions/student_t_distribution.h b/include/libhmm/distributions/student_t_distribution.h index 0b0fc9d..5b4fde4 100644 --- a/include/libhmm/distributions/student_t_distribution.h +++ b/include/libhmm/distributions/student_t_distribution.h @@ -216,6 +216,9 @@ class StudentTDistribution : public DistributionBase { * @return String describing the distribution parameters */ std::string toString() const override; + [[nodiscard]] std::string to_json() const override; + /// @internal JSON factory — called by the distribution registry in src/io/hmm_json.cpp. + static std::unique_ptr from_json(json::Reader &r); /** * @brief Equality comparison operator diff --git a/include/libhmm/distributions/uniform_distribution.h b/include/libhmm/distributions/uniform_distribution.h index c7d066d..11489f3 100644 --- a/include/libhmm/distributions/uniform_distribution.h +++ b/include/libhmm/distributions/uniform_distribution.h @@ -95,6 +95,9 @@ class UniformDistribution : public DistributionBase { * @return String description */ std::string toString() const override; + [[nodiscard]] std::string to_json() const override; + /// @internal JSON factory — called by the distribution registry in src/io/hmm_json.cpp. 
+ static std::unique_ptr from_json(json::Reader &r); /** * @brief Get the lower bound parameter diff --git a/include/libhmm/distributions/weibull_distribution.h b/include/libhmm/distributions/weibull_distribution.h index 6c1a78a..45cc114 100644 --- a/include/libhmm/distributions/weibull_distribution.h +++ b/include/libhmm/distributions/weibull_distribution.h @@ -156,9 +156,12 @@ class WeibullDistribution : public DistributionBase { * @return String describing the distribution parameters */ std::string toString() const override; + [[nodiscard]] std::string to_json() const override; + /// @internal JSON factory — called by the distribution registry in src/io/hmm_json.cpp. + static std::unique_ptr from_json(json::Reader &r); /** - * Computes the cumulative distribution function (CDF) for the Weibull distribution. + * Computes the cumulative distribution function * * @param x The value at which to evaluate the CDF (should be ≥ 0) * @return Cumulative probability P(X ≤ x), or 0.0 if x is negative diff --git a/include/libhmm/io/hmm_json.h b/include/libhmm/io/hmm_json.h new file mode 100644 index 0000000..92088a7 --- /dev/null +++ b/include/libhmm/io/hmm_json.h @@ -0,0 +1,48 @@ +#pragma once + +// HMM JSON serialization/deserialization. +// +// Provides four free functions: +// to_json — serialize Hmm → JSON string +// from_json — deserialize JSON string → Hmm +// save_json — write JSON string to a file +// load_json — read JSON string from a file and deserialize +// +// The JSON schema is: +// { +// "states": , +// "pi": [, ..., ], +// "trans": [[], ..., []], +// "distributions": [ +// {"type":"", ...params...}, +// ... +// ] +// } +// +// All doubles are serialized with max_digits10 precision for exact round-trip. + +#include +#include +#include + +#include "libhmm/hmm.h" + +namespace libhmm { + +/// Serialize an HMM to a compact JSON string. +[[nodiscard]] std::string to_json(const Hmm &hmm); + +/// Deserialize an HMM from a JSON string produced by to_json(). 
+/// Throws std::runtime_error on malformed input. +Hmm from_json(std::string_view src); + +/// Write hmm as JSON to filepath. +/// Creates parent directories as needed. +/// Throws std::runtime_error on I/O failure. +void save_json(const Hmm &hmm, const std::filesystem::path &filepath); + +/// Read and deserialize an HMM from a JSON file at filepath. +/// Throws std::runtime_error on I/O or parse failure. +Hmm load_json(const std::filesystem::path &filepath); + +} // namespace libhmm diff --git a/include/libhmm/io/json_utils.h b/include/libhmm/io/json_utils.h new file mode 100644 index 0000000..363bd70 --- /dev/null +++ b/include/libhmm/io/json_utils.h @@ -0,0 +1,120 @@ +#pragma once + +// Internal JSON utilities for the libhmm serializer/deserializer. +// +// This is NOT a general-purpose JSON library. It handles exactly the +// schema used by libhmm's HMM files: +// - Objects with string and double scalar fields +// - Arrays of doubles (pi, distribution parameters) +// - 2-D arrays of doubles (transition matrices) +// +// Include it only from src/io/ and from distribution .cpp files that implement +// to_json()/from_json(); never from public headers or training code. + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace libhmm { +namespace json { + +// ============================================================================= +// Write helpers — produce JSON text fragments +// ============================================================================= + +/// Double with enough precision for exact round-trip (numeric_limits max_digits10). +[[nodiscard]] std::string write_double(double v); + +/// JSON array of doubles: [1.0, 2.0, 3.0] +[[nodiscard]] std::string write_array(std::span v); + +/// JSON 2-D array from a row-major flat buffer (rows × cols elements): +/// [[r0c0, r0c1], [r1c0, r1c1], ...] 
+[[nodiscard]] std::string write_matrix(std::size_t rows, std::size_t cols, + std::span data); + +/// JSON object for a distribution with only scalar (double) fields. +/// Example: write_distribution("Gaussian", {{"mu", 0.0}, {"sigma", 1.0}}) +/// → {"type":"Gaussian","mu":0.0,"sigma":1.0} +[[nodiscard]] std::string +write_distribution(std::string_view type, + std::initializer_list> fields); + +/// JSON object for a distribution that has one array field (e.g. Discrete). +/// Scalar fields appear before the array field. +[[nodiscard]] std::string +write_distribution_with_array(std::string_view type, + std::initializer_list> scalars, + std::string_view array_key, std::span array_val); + +// ============================================================================= +// Reader — schema-aware tokenizer over a JSON string +// ============================================================================= + +/// Lightweight schema-aware tokenizer. Operates against a fixed, known schema; +/// does not build a generic parse tree. +/// +/// Typical call sequence when reading a distribution object: +/// r.consume('{'); +/// r.read_key(); // "type" +/// std::string t = r.read_string(); +/// // dispatch on t, then call distribution-specific from_json(r) +/// // which reads remaining fields and the closing '}' +class Reader { +public: + explicit Reader(std::string_view src) noexcept : src_(src) {} + + // ---- Low-level ---- + + /// Skip JSON whitespace (space, tab, CR, LF). + void skip_ws() noexcept; + + /// Assert the next non-whitespace character is c and advance past it. + /// Throws std::runtime_error if the character does not match. + void consume(char c); + + /// Peek at the next non-whitespace character without advancing. + [[nodiscard]] char peek(); + + /// Return true if the next non-whitespace character is c. + [[nodiscard]] bool at(char c); + + // ---- Value readers ---- + + /// Read a JSON string literal ("...") and return its contents (no escape processing). 
+ [[nodiscard]] std::string read_string(); + + /// Read a JSON number and return it as double. + [[nodiscard]] double read_double(); + + /// Read a JSON array of numbers: [d0, d1, ...] → vector. + /// Throws std::runtime_error if the array contains more than max_elements entries. + /// Default is effectively unlimited; callers with a known expected size should pass it. + [[nodiscard]] std::vector + read_double_array(std::size_t max_elements = std::numeric_limits::max()); + + /// Read a JSON 2-D array: [[r0c0,...], [r1c0,...], ...] → vector>. + /// Throws if the matrix exceeds max_rows rows or any row exceeds max_cols_per_row entries. + [[nodiscard]] std::vector> + read_double_matrix(std::size_t max_rows = std::numeric_limits::max(), + std::size_t max_cols_per_row = std::numeric_limits::max()); + + // ---- Object helpers ---- + + /// Read a JSON object key including the trailing colon: "key" → "key". + /// Optionally preceded by a comma if not the first key. + std::string read_key(); + +private: + std::string_view src_; + std::size_t pos_{0}; +}; + +} // namespace json +} // namespace libhmm diff --git a/include/libhmm/io/xml_file_reader.h b/include/libhmm/io/xml_file_reader.h index 19fb5bc..9da3320 100644 --- a/include/libhmm/io/xml_file_reader.h +++ b/include/libhmm/io/xml_file_reader.h @@ -13,8 +13,13 @@ namespace libhmm { /** - * Modern XML file reader for HMM deserialization with C++17 features. - * Provides safe XML deserialization with proper error handling and validation. + * XML file reader for HMM deserialization. + * + * Reads the CDATA-wrapped text format written by XMLFileWriter: + * ]]> + * + * @deprecated Prefer load_json() / save_json() from hmm_json.h for new code. + * XMLFileReader is retained for reading legacy .xml files only. */ class XMLFileReader { public: @@ -56,12 +61,12 @@ class XMLFileReader { static bool canReadFromPath(const std::filesystem::path &filepath) noexcept; /** - * Checks if a file exists and appears to be a valid XML file. 
+ * Checks if a file exists and looks like a libhmm XML file. * * @param filepath Path to check - * @return true if file exists and has XML content, false otherwise + * @return true if file is readable, non-empty, and starts with an XML declaration */ - static bool isValidXMLFile(const std::filesystem::path &filepath) noexcept; + static bool canParseAsHmm(const std::filesystem::path &filepath) noexcept; private: /** diff --git a/include/libhmm/io/xml_file_writer.h b/include/libhmm/io/xml_file_writer.h index 305ed47..b5d8ed6 100755 --- a/include/libhmm/io/xml_file_writer.h +++ b/include/libhmm/io/xml_file_writer.h @@ -13,8 +13,13 @@ namespace libhmm { /** - * Modern XML file writer for HMM serialization with C++17 features. - * Provides safe XML serialization with proper error handling and validation. + * XML file writer for HMM serialization. + * + * Writes a CDATA-wrapped text format: + * ]]> + * + * @deprecated Prefer save_json() / load_json() from hmm_json.h for new code. + * XMLFileWriter is retained for producing legacy .xml files only. */ class XMLFileWriter { public: diff --git a/include/libhmm/libhmm.h b/include/libhmm/libhmm.h index 8127e26..9212921 100755 --- a/include/libhmm/libhmm.h +++ b/include/libhmm/libhmm.h @@ -69,7 +69,12 @@ // INPUT/OUTPUT //============================================================================== -/// File I/O and XML serialization +/// JSON serialization/deserialization — the recommended format for new code. +/// Provides to_json(), from_json(), save_json(), and load_json() free functions. +#include "libhmm/io/hmm_json.h" + +/// File I/O utilities and legacy XML serialization (deprecated; retained for +/// reading existing .xml files). 
#include "libhmm/io/file_io_manager.h" #include "libhmm/io/xml_file_reader.h" #include "libhmm/io/xml_file_writer.h" diff --git a/samples/README.md b/samples/README.md new file mode 100644 index 0000000..624b4be --- /dev/null +++ b/samples/README.md @@ -0,0 +1,59 @@ +# libhmm sample HMM files + +Pre-built HMM files for testing, validation, and experimentation. +Each model is provided in both the recommended JSON format and the legacy XML format. + +## Files + +### two_state_gaussian — 2-state Gaussian HMM +Simple two-state continuous HMM. Good for validating basic load/inference +pipelines with continuous observations. + +| Parameter | Value | +|-----------|-------| +| States | 2 | +| Pi | [0.75, 0.25] | +| Transition | [[0.875, 0.125], [0.25, 0.75]] | +| State 0 | Gaussian(μ=0, σ=1) | +| State 1 | Gaussian(μ=2.5, σ=0.5) | + +### casino — 2-state discrete HMM (dishonest casino) +Classic discrete HMM from Durbin et al. (1998). Two dice: fair (uniform over +6 outcomes) and loaded (biased toward face 6). Good for Viterbi decoding +tests and discrete distribution validation. + +| Parameter | Value | +|-----------|-------| +| States | 2 | +| Pi | [0.75, 0.25] | +| Transition | [[0.875, 0.125], [0.25, 0.75]] | +| State 0 (fair) | Discrete(n=6, uniform ≈ 1/6 each) | +| State 1 (loaded) | Discrete(n=6, [0.125×5, 0.375]) | + +## Formats + +- **`.json`** — Recommended. Use `libhmm::load_json()` / `libhmm::save_json()`. + Exact IEEE 754 round-trip via `max_digits10` precision. +- **`.xml`** — Legacy CDATA-wrapped text format. Use `XMLFileReader` for + reading existing files; prefer JSON for new code. 
+ +## Usage + +```cpp +#include "libhmm/io/hmm_json.h" + +// JSON (recommended) +auto hmm = libhmm::load_json("samples/two_state_gaussian.json"); + +// Legacy XML +#include "libhmm/io/xml_file_reader.h" +libhmm::XMLFileReader reader; +auto hmm = reader.read("samples/two_state_gaussian.xml"); +``` + +Or from the command line with the validator tool: + +``` +hmm_validator samples/two_state_gaussian.json +hmm_validator samples/casino.xml +``` diff --git a/samples/casino.json b/samples/casino.json new file mode 100644 index 0000000..e402a0c --- /dev/null +++ b/samples/casino.json @@ -0,0 +1 @@ +{"states":2,"pi":[0.75,0.25],"trans":[[0.875,0.125],[0.25,0.75]],"distributions":[{"type":"Discrete","n":6,"probs":[0.16666666666666666,0.16666666666666666,0.16666666666666666,0.16666666666666666,0.16666666666666666,0.16666666666666666]},{"type":"Discrete","n":6,"probs":[0.125,0.125,0.125,0.125,0.125,0.375]}]} diff --git a/samples/casino.xml b/samples/casino.xml new file mode 100644 index 0000000..84c9bb6 --- /dev/null +++ b/samples/casino.xml @@ -0,0 +1,30 @@ + + + + diff --git a/samples/two_state_gaussian.json b/samples/two_state_gaussian.json new file mode 100644 index 0000000..7b59d7f --- /dev/null +++ b/samples/two_state_gaussian.json @@ -0,0 +1 @@ +{"states":2,"pi":[0.75,0.25],"trans":[[0.875,0.125],[0.25,0.75]],"distributions":[{"type":"Gaussian","mu":0,"sigma":1},{"type":"Gaussian","mu":2.5,"sigma":0.5}]} diff --git a/samples/two_state_gaussian.xml b/samples/two_state_gaussian.xml new file mode 100644 index 0000000..53a0a9f --- /dev/null +++ b/samples/two_state_gaussian.xml @@ -0,0 +1,24 @@ + + + + diff --git a/src/distributions/beta_distribution.cpp b/src/distributions/beta_distribution.cpp index 1f14e0a..e2d98fe 100644 --- a/src/distributions/beta_distribution.cpp +++ b/src/distributions/beta_distribution.cpp @@ -1,4 +1,5 @@ #include "libhmm/distributions/beta_distribution.h" +#include "libhmm/io/json_utils.h" #include "libhmm/math/weighted_stats.h" #include @@ -362,27 
+363,27 @@ std::ostream &operator<<(std::ostream &os, const BetaDistribution &distribution) return os << distribution.toString(); } +// Parses the format produced by toString() / operator<<: +// Beta Distribution: +// \u03b1 (alpha) = VALUE +// \u03b2 (beta) = VALUE +// Mean = VALUE +// Variance = VALUE std::istream &operator>>(std::istream &is, BetaDistribution &distribution) { try { - std::string token; - double alpha = 1.0; - double beta = 1.0; - // Expected format: "Beta Distribution: α (alpha) = β (beta) = " - is >> token >> token; // "Beta" "Distribution:" - is >> token >> token >> token >> token; // "α" "(alpha)" "=" - alpha = std::stod(token); - is >> token >> token >> token >> token; // "β" "(beta)" "=" - beta = std::stod(token); - - if (is.good()) { + std::string s, t; + is >> s >> s; // "Beta" "Distribution:" + is >> s >> s >> s >> t; // "\u03b1" "(alpha)" "=" VALUE + const double alpha = std::stod(t); + is >> s >> s >> s >> t; // "\u03b2" "(beta)" "=" VALUE + const double beta = std::stod(t); + is >> s >> s >> t; + is >> s >> s >> t; // skip Mean, Variance + if (is.good()) distribution = BetaDistribution(alpha, beta); - } - } catch (const std::exception &) { - // Set error state on stream if parsing fails is.setstate(std::ios::failbit); } - return is; } @@ -400,4 +401,16 @@ void BetaDistribution::getBatchLogProbabilities(std::span observat } } +std::string BetaDistribution::to_json() const { + return json::write_distribution("Beta", {{"alpha", alpha_}, {"beta", beta_}}); +} +std::unique_ptr BetaDistribution::from_json(json::Reader &r) { + r.read_key(); + const double alpha = r.read_double(); + r.read_key(); + const double beta = r.read_double(); + r.consume('}'); + return std::make_unique(alpha, beta); +} + } // namespace libhmm diff --git a/src/distributions/binomial_distribution.cpp b/src/distributions/binomial_distribution.cpp index 856fa53..2a41b30 100644 --- a/src/distributions/binomial_distribution.cpp +++ 
b/src/distributions/binomial_distribution.cpp @@ -1,4 +1,5 @@ #include "libhmm/distributions/binomial_distribution.h" +#include "libhmm/io/json_utils.h" // Header already includes: , , , , , via common.h #include // For std::accumulate (not in common.h) #include // For std::for_each, std::max_element (exists in common.h, included for clarity) @@ -202,41 +203,35 @@ bool BinomialDistribution::operator==(const BinomialDistribution &other) const { return (n_ == other.n_) && (std::abs(p_ - other.p_) < tolerance); } -std::istream &operator>>(std::istream &is, libhmm::BinomialDistribution &distribution) { - std::string token; - // Expected format: "Binomial(n,p)" or "n p" - if (is >> token) { - int n = 0; - double p = 0.0; - if (token.find("Binomial") != std::string::npos) { - // Skip to parameters - char ch = '\0'; - is >> ch >> n >> ch >> p >> ch; // Read (n,p) - } else { - // Assume first token is n - n = std::stoi(token); - is >> p; - } +std::ostream &operator<<(std::ostream &os, const libhmm::BinomialDistribution &distribution) { + os << distribution.toString(); + return os; +} - try { +// Parses the format produced by toString() / operator<<: +// Binomial Distribution: +// n (trials) = VALUE +// p (success probability) = VALUE +// Mean = VALUE +// Variance = VALUE +std::istream &operator>>(std::istream &is, libhmm::BinomialDistribution &distribution) { + try { + std::string s, t; + is >> s >> s; // "Binomial" "Distribution:" + is >> s >> s >> s >> t; // "n" "(trials)" "=" VALUE + const int n = static_cast(std::stod(t)); + is >> s >> s >> s >> s >> t; // "p" "(success" "probability)" "=" VALUE + const double p = std::stod(t); + is >> s >> s >> t; + is >> s >> s >> t; // skip Mean, Variance + if (is.good()) distribution.setParameters(n, p); - } catch (const std::exception &) { - is.setstate(std::ios::failbit); - } + } catch (const std::exception &) { + is.setstate(std::ios::failbit); } - return is; } -std::ostream &operator<<(std::ostream &os, const 
libhmm::BinomialDistribution &distribution) { - os << "Binomial Distribution:" << std::endl; - os << " n = " << distribution.getN() << std::endl; - os << " p = " << distribution.getP() << std::endl; - os << std::endl; - - return os; -} - void BinomialDistribution::getBatchLogProbabilities(std::span observations, std::span out) const { // Tier 1 — concrete non-virtual loop; compiler auto-vectorizes the arithmetic @@ -251,4 +246,16 @@ void BinomialDistribution::getBatchLogProbabilities(std::span obse } } +std::string BinomialDistribution::to_json() const { + return json::write_distribution("Binomial", {{"n", static_cast(n_)}, {"p", p_}}); +} +std::unique_ptr BinomialDistribution::from_json(json::Reader &r) { + r.read_key(); + const int n = static_cast(r.read_double()); + r.read_key(); + const double p = r.read_double(); + r.consume('}'); + return std::make_unique(n, p); +} + } // namespace libhmm diff --git a/src/distributions/chi_squared_distribution.cpp b/src/distributions/chi_squared_distribution.cpp index 13e0476..7fec7c1 100644 --- a/src/distributions/chi_squared_distribution.cpp +++ b/src/distributions/chi_squared_distribution.cpp @@ -1,4 +1,5 @@ #include "libhmm/distributions/chi_squared_distribution.h" +#include "libhmm/io/json_utils.h" #include "libhmm/math/weighted_stats.h" #include #include @@ -131,30 +132,23 @@ bool ChiSquaredDistribution::operator==(const ChiSquaredDistribution &other) con } std::ostream &operator<<(std::ostream &os, const ChiSquaredDistribution &dist) { - os << std::fixed << std::setprecision(6); - os << "ChiSquared Distribution: k = " << dist.getDegreesOfFreedom(); + os << dist.toString(); return os; } +// Parses the format produced by toString() / operator<<: +// ChiSquared Distribution: +// k (degrees of freedom) = VALUE std::istream &operator>>(std::istream &is, ChiSquaredDistribution &dist) { try { - std::string token; - double k = 0.0; - // Expected format: "ChiSquared Distribution: k = " - std::string k_str; - is >> token >> 
token >> token >> token >> - k_str; // "ChiSquared" "Distribution:" "k" "=" - k = std::stod(k_str); - - if (is.good()) { - dist.setDegreesOfFreedom(k); - } - + std::string s, t; + is >> s >> s; // "ChiSquared" "Distribution:" + is >> s >> s >> s >> s >> s >> t; // "k" "(degrees" "of" "freedom)" "=" VALUE + if (is.good()) + dist.setDegreesOfFreedom(std::stod(t)); } catch (const std::exception &) { - // Set error state on stream if parsing fails is.setstate(std::ios::failbit); } - return is; } @@ -173,4 +167,14 @@ void ChiSquaredDistribution::getBatchLogProbabilities(std::span ob } } +std::string ChiSquaredDistribution::to_json() const { + return json::write_distribution("ChiSquared", {{"k", degrees_of_freedom_}}); +} +std::unique_ptr ChiSquaredDistribution::from_json(json::Reader &r) { + r.read_key(); + const double k = r.read_double(); + r.consume('}'); + return std::make_unique(k); +} + } // namespace libhmm diff --git a/src/distributions/discrete_distribution.cpp b/src/distributions/discrete_distribution.cpp index a26a661..5f018df 100755 --- a/src/distributions/discrete_distribution.cpp +++ b/src/distributions/discrete_distribution.cpp @@ -1,4 +1,5 @@ #include "libhmm/distributions/discrete_distribution.h" +#include "libhmm/io/json_utils.h" #include using namespace libhmm::constants; @@ -169,39 +170,31 @@ bool DiscreteDistribution::operator==(const DiscreteDistribution &other) const { return true; } -/** - * Stream input operator implementation - * Expects format with number of symbols followed by probabilities - */ +// Parses the format produced by toString() / operator<<: +// Discrete Distribution: +// Number of symbols = N +// P(0) = VALUE +// ... 
+// P(N-1) = VALUE std::istream &operator>>(std::istream &is, libhmm::DiscreteDistribution &distribution) { - std::size_t numSymbols = 0; - - if (!(is >> numSymbols)) { - is.setstate(std::ios::failbit); - return is; - } - - // Create new distribution with the specified number of symbols try { - DiscreteDistribution newDist(static_cast(numSymbols)); - - // Read probabilities - for (std::size_t i = 0; i < numSymbols; ++i) { - double prob = 0.0; - if (!(is >> prob)) { - is.setstate(std::ios::failbit); - return is; - } - newDist.setProbability(static_cast(i), prob); + std::string s, t; + is >> s >> s; // "Discrete" "Distribution:" + is >> s >> s >> s >> s >> t; // "Number" "of" "symbols" "=" N + const auto n = std::stoull(t); + if (n == 0) { + is.setstate(std::ios::failbit); + return is; + } + DiscreteDistribution newDist(static_cast(n)); + for (std::size_t i = 0; i < n; ++i) { + is >> s >> s >> t; // "P(i)" "=" VALUE + newDist.setProbability(static_cast(i), std::stod(t)); } - - // If successful, update the distribution distribution = std::move(newDist); - } catch (const std::exception &) { is.setstate(std::ios::failbit); } - return is; } @@ -225,4 +218,34 @@ void DiscreteDistribution::getBatchLogProbabilities(std::span obse } } +std::string DiscreteDistribution::to_json() const { + return json::write_distribution_with_array("Discrete", + {{"n", static_cast(numSymbols_)}}, "probs", + std::span(pdf_.data(), numSymbols_)); +} +std::unique_ptr DiscreteDistribution::from_json(json::Reader &r) { + // Maximum symbol count accepted during deserialization. + // 65536 symbols × 8 bytes = 512 KB per distribution — generous for any + // practical use. Values above this cap indicate corrupted or adversarial input. + // The guard also prevents static_cast UB when n is non-finite or huge. 
+ static constexpr int kMaxDiscreteSymbols = 65536; + + r.read_key(); // "n" + const double n_raw = r.read_double(); + if (!std::isfinite(n_raw) || n_raw < 1.0 || n_raw > static_cast(kMaxDiscreteSymbols)) + throw std::runtime_error("DiscreteDistribution JSON: n must be an integer in [1, " + + std::to_string(kMaxDiscreteSymbols) + "]"); + const int n = static_cast(n_raw); + + r.read_key(); // "probs" + // Cap array read to n elements — a longer array is malformed and must not + // be allowed to grow the heap before the distribution constructor fires. + const auto probs = r.read_double_array(static_cast(n)); + r.consume('}'); + auto dist = std::make_unique(n); + for (std::size_t i = 0; i < probs.size(); ++i) + dist->setProbability(static_cast(i), probs[i]); + return dist; +} + } // namespace libhmm diff --git a/src/distributions/exponential_distribution.cpp b/src/distributions/exponential_distribution.cpp index ac9d23e..e499b83 100755 --- a/src/distributions/exponential_distribution.cpp +++ b/src/distributions/exponential_distribution.cpp @@ -1,4 +1,5 @@ #include "libhmm/distributions/exponential_distribution.h" +#include "libhmm/io/json_utils.h" #include "libhmm/math/weighted_stats.h" #include "libhmm/platform/simd_platform.h" #include @@ -148,30 +149,26 @@ std::string ExponentialDistribution::toString() const { } std::ostream &operator<<(std::ostream &os, const libhmm::ExponentialDistribution &distribution) { - os << "Exponential Distribution: " << std::endl; - os << " Rate parameter = " << distribution.getLambda() << std::endl; - os << std::endl; - + os << distribution.toString(); return os; } +// Parses the format produced by toString() / operator<<: +// Exponential Distribution: +// \u03bb (rate parameter) = VALUE +// Mean = VALUE std::istream &operator>>(std::istream &is, libhmm::ExponentialDistribution &distribution) { try { - std::string token, lambda_str; - is >> token; // "Rate" - is >> token; // "parameter" - is >> token; // "=" - is >> lambda_str; - 
double lambda = std::stod(lambda_str); - - // Use setLambda for validation - distribution.setLambda(lambda); - + std::string s, t; + is >> s >> s; // "Exponential" "Distribution:" + is >> s >> s >> s >> s >> t; // "\u03bb" "(rate" "parameter)" "=" VALUE + const double lambda = std::stod(t); + is >> s >> s >> t; // skip Mean + if (is.good()) + distribution.setLambda(lambda); } catch (const std::exception &) { - // Set error state on stream if parsing fails is.setstate(std::ios::failbit); } - return is; } @@ -282,4 +279,14 @@ void ExponentialDistribution::getBatchLogProbabilities(std::span o logLambda_, negLambda_); } +std::string ExponentialDistribution::to_json() const { + return json::write_distribution("Exponential", {{"lambda", lambda_}}); +} +std::unique_ptr ExponentialDistribution::from_json(json::Reader &r) { + r.read_key(); + const double lambda = r.read_double(); + r.consume('}'); + return std::make_unique(lambda); +} + } // namespace libhmm diff --git a/src/distributions/gamma_distribution.cpp b/src/distributions/gamma_distribution.cpp index 3fbd6bd..e05ec49 100755 --- a/src/distributions/gamma_distribution.cpp +++ b/src/distributions/gamma_distribution.cpp @@ -1,4 +1,5 @@ #include "libhmm/distributions/gamma_distribution.h" +#include "libhmm/io/json_utils.h" #include using namespace libhmm::constants; @@ -182,38 +183,31 @@ std::string GammaDistribution::toString() const { } std::ostream &operator<<(std::ostream &os, const libhmm::GammaDistribution &distribution) { - os << "Gamma Distribution: " << std::endl; - os << " k (shape) = " << distribution.getK() << std::endl; - os << " theta (scale) = " << distribution.getTheta() << std::endl; - os << " Mean = " << distribution.getMean() << std::endl; - os << " Variance = " << distribution.getVariance() << std::endl; - + os << distribution.toString(); return os; } +// Parses the format produced by toString() / operator<<: +// Gamma Distribution: +// k (shape parameter) = VALUE +// \u03b8 (scale parameter) = VALUE 
+// Mean = VALUE +// Variance = VALUE std::istream &operator>>(std::istream &is, libhmm::GammaDistribution &distribution) { try { - std::string token, k_str, theta_str; - is >> token; // "k" - is >> token; // "(shape)" - is >> token; // "=" - is >> k_str; - double k = std::stod(k_str); - - is >> token; // "theta" - is >> token; // "(scale)" - is >> token; // "=" - is >> theta_str; - double theta = std::stod(theta_str); - - // Use setParameters for validation - distribution.setParameters(k, theta); - + std::string s, t; + is >> s >> s; // "Gamma" "Distribution:" + is >> s >> s >> s >> s >> t; // "k" "(shape" "parameter)" "=" VALUE + const double k = std::stod(t); + is >> s >> s >> s >> s >> t; // "\u03b8" "(scale" "parameter)" "=" VALUE + const double theta = std::stod(t); + is >> s >> s >> t; + is >> s >> s >> t; // skip Mean, Variance + if (is.good()) + distribution.setParameters(k, theta); } catch (const std::exception &) { - // Set error state on stream if parsing fails is.setstate(std::ios::failbit); } - return is; } @@ -238,4 +232,16 @@ void GammaDistribution::getBatchLogProbabilities(std::span observa } } +std::string GammaDistribution::to_json() const { + return json::write_distribution("Gamma", {{"k", k_}, {"theta", theta_}}); +} +std::unique_ptr GammaDistribution::from_json(json::Reader &r) { + r.read_key(); + const double k = r.read_double(); + r.read_key(); + const double theta = r.read_double(); + r.consume('}'); + return std::make_unique(k, theta); +} + } // namespace libhmm diff --git a/src/distributions/gaussian_distribution.cpp b/src/distributions/gaussian_distribution.cpp index d4c48ef..df13438 100755 --- a/src/distributions/gaussian_distribution.cpp +++ b/src/distributions/gaussian_distribution.cpp @@ -1,4 +1,5 @@ #include "libhmm/distributions/gaussian_distribution.h" +#include "libhmm/io/json_utils.h" #include "libhmm/platform/simd_platform.h" // compile-time SIMD macros + intrinsics #include #include @@ -176,36 +177,31 @@ std::string 
GaussianDistribution::toString() const { } std::ostream &operator<<(std::ostream &os, const libhmm::GaussianDistribution &distribution) { - os << "Normal Distribution: " << std::endl; - os << " Mean = " << distribution.getMean() << std::endl; - os << " Standard deviation = " << distribution.getStandardDeviation(); - os << std::endl; - + os << distribution.toString(); return os; } +// Parses the format produced by toString() / operator<<: +// Gaussian Distribution: +// \u03bc (mean) = VALUE +// \u03c3 (std. deviation) = VALUE +// Mean = VALUE +// Variance = VALUE std::istream &operator>>(std::istream &is, libhmm::GaussianDistribution &distribution) { try { - std::string token, mean_str, stddev_str; - is >> token; // "Mean" - is >> token; // "=" - is >> mean_str; - double mean = std::stod(mean_str); - - is >> token; // "Standard" - is >> token; // "Deviation" - is >> token; // "=" - is >> stddev_str; - double stdDev = std::stod(stddev_str); - - // Use setParameters for validation - distribution.setParameters(mean, stdDev); - + std::string s, t; + is >> s >> s; // "Gaussian" "Distribution:" + is >> s >> s >> s >> t; // "\u03bc" "(mean)" "=" VALUE + const double mean = std::stod(t); + is >> s >> s >> s >> s >> t; // "\u03c3" "(std." 
"deviation)" "=" VALUE + const double sd = std::stod(t); + is >> s >> s >> t; + is >> s >> s >> t; // skip Mean, Variance + if (is.good()) + distribution.setParameters(mean, sd); } catch (const std::exception &) { - // Set error state on stream if parsing fails is.setstate(std::ios::failbit); } - return is; } @@ -333,4 +329,16 @@ void GaussianDistribution::getBatchLogProbabilities(std::span obse negHalfSigmaSquaredInv_, log_norm); } +std::string GaussianDistribution::to_json() const { + return json::write_distribution("Gaussian", {{"mu", mean_}, {"sigma", standardDeviation_}}); +} +std::unique_ptr GaussianDistribution::from_json(json::Reader &r) { + r.read_key(); + const double mu = r.read_double(); + r.read_key(); + const double sigma = r.read_double(); + r.consume('}'); + return std::make_unique(mu, sigma); +} + } // namespace libhmm diff --git a/src/distributions/log_normal_distribution.cpp b/src/distributions/log_normal_distribution.cpp index e598969..b647e23 100755 --- a/src/distributions/log_normal_distribution.cpp +++ b/src/distributions/log_normal_distribution.cpp @@ -1,4 +1,5 @@ #include "libhmm/distributions/log_normal_distribution.h" +#include "libhmm/io/json_utils.h" #include "libhmm/performance/simd_kernels_internal.h" // Header already includes: , , , , , via common.h #include // For std::accumulate (not in common.h) @@ -175,42 +176,35 @@ std::string LogNormalDistribution::toString() const { } std::ostream &operator<<(std::ostream &os, const libhmm::LogNormalDistribution &distribution) { - - os << "LogNormal Distribution:" << std::endl; - os << " Mean = " << distribution.getMean() << std::endl; - os << " Standard Deviation = " << distribution.getStandardDeviation(); - os << std::endl; - + os << distribution.toString(); return os; } +// Parses the format produced by toString() / operator<<: +// LogNormal Distribution: +// \u03bc (log mean) = VALUE +// \u03c3 (log std. 
deviation) = VALUE +// Mean = VALUE +// Variance = VALUE std::istream &operator>>(std::istream &is, libhmm::LogNormalDistribution &distribution) { try { - std::string token, mean_str, stddev_str; - is >> token; //" Mean" - is >> token; // "=" - is >> mean_str; - double mean = std::stod(mean_str); - - is >> token; // "Standard" - is >> token; // "Deviation" - is >> token; // " = " - is >> stddev_str; - double stdDev = std::stod(stddev_str); - + std::string s, t; + is >> s >> s; // "LogNormal" "Distribution:" + is >> s >> s >> s >> s >> t; // "\u03bc" "(log" "mean)" "=" VALUE + const double mean = std::stod(t); + is >> s >> s >> s >> s >> s >> t; // "\u03c3" "(log" "std." "deviation)" "=" VALUE + const double sd = std::stod(t); + is >> s >> s >> t; + is >> s >> s >> t; // skip Mean, Variance if (is.good()) { distribution.setMean(mean); - distribution.setStandardDeviation(stdDev); + distribution.setStandardDeviation(sd); } - } catch (const std::exception &) { - // Set error state on stream if parsing fails is.setstate(std::ios::failbit); } - return is; } - // ============================================================================= // Batch log-PDF — explicit SIMD intrinsics (tier 2) // @@ -318,4 +312,16 @@ void LogNormalDistribution::getBatchLogProbabilities(std::span obs negHalfSigmaSquaredInv_, logNormalizationConstant_); } +std::string LogNormalDistribution::to_json() const { + return json::write_distribution("LogNormal", {{"mu", mean_}, {"sigma", standardDeviation_}}); +} +std::unique_ptr LogNormalDistribution::from_json(json::Reader &r) { + r.read_key(); + const double mu = r.read_double(); + r.read_key(); + const double sigma = r.read_double(); + r.consume('}'); + return std::make_unique(mu, sigma); +} + } // namespace libhmm diff --git a/src/distributions/negative_binomial_distribution.cpp b/src/distributions/negative_binomial_distribution.cpp index 2836602..67f1180 100644 --- a/src/distributions/negative_binomial_distribution.cpp +++ 
b/src/distributions/negative_binomial_distribution.cpp @@ -1,4 +1,5 @@ #include "libhmm/distributions/negative_binomial_distribution.h" +#include "libhmm/io/json_utils.h" // Header already includes: , , , , , via common.h #include // For std::accumulate (not in common.h) #include // For std::for_each (exists in common.h, included for clarity) @@ -206,67 +207,33 @@ bool NegativeBinomialDistribution::operator==(const NegativeBinomialDistribution return (std::abs(r_ - other.r_) < tolerance) && (std::abs(p_ - other.p_) < tolerance); } +// Parses the format produced by toString() / operator<<: +// Negative Binomial Distribution: +// r (successes) = VALUE +// p (success probability) = VALUE +// Mean = VALUE +// Variance = VALUE std::istream &operator>>(std::istream &is, libhmm::NegativeBinomialDistribution &distribution) { - std::string token; - // Expected format: "NegativeBinomial(r,p)" or "r p" - if (is >> token) { - double r = 0.0, p = 0.0; - if (token.find("NegativeBinomial") != std::string::npos) { - // Parse formatted input: NegativeBinomial(r,p) - std::string fullInput = token; - std::string remaining; - std::getline(is, remaining); - fullInput += remaining; - - // Find the opening and closing parentheses - size_t openParen = fullInput.find('('); - size_t closeParen = fullInput.find(')'); - size_t comma = fullInput.find(','); - - if (openParen != std::string::npos && closeParen != std::string::npos && - comma != std::string::npos) { - std::string rStr = fullInput.substr(openParen + 1, comma - openParen - 1); - std::string pStr = fullInput.substr(comma + 1, closeParen - comma - 1); - - try { - r = std::stod(rStr); - p = std::stod(pStr); - } catch (const std::exception &) { - is.setstate(std::ios::failbit); - return is; - } - } else { - is.setstate(std::ios::failbit); - return is; - } - } else { - // Assume first token is r - try { - r = std::stod(token); - is >> p; - } catch (const std::exception &) { - is.setstate(std::ios::failbit); - return is; - } - } - - try 
{ + try { + std::string s, t; + is >> s >> s >> s; // "Negative" "Binomial" "Distribution:" + is >> s >> s >> s >> t; // "r" "(successes)" "=" VALUE + const double r = std::stod(t); + is >> s >> s >> s >> s >> t; // "p" "(success" "probability)" "=" VALUE + const double p = std::stod(t); + is >> s >> s >> t; + is >> s >> s >> t; // skip Mean, Variance + if (is.good()) distribution.setParameters(r, p); - } catch (const std::exception &) { - is.setstate(std::ios::failbit); - } + } catch (const std::exception &) { + is.setstate(std::ios::failbit); } - return is; } std::ostream &operator<<(std::ostream &os, const libhmm::NegativeBinomialDistribution &distribution) { - os << "Negative Binomial Distribution:" << std::endl; - os << " r = " << distribution.getR() << std::endl; - os << " p = " << distribution.getP() << std::endl; - os << std::endl; - + os << distribution.toString(); return os; } @@ -284,4 +251,16 @@ void NegativeBinomialDistribution::getBatchLogProbabilities(std::span NegativeBinomialDistribution::from_json(json::Reader &r) { + r.read_key(); + const double rv = r.read_double(); + r.read_key(); + const double p = r.read_double(); + r.consume('}'); + return std::make_unique(rv, p); +} + } // namespace libhmm diff --git a/src/distributions/pareto_distribution.cpp b/src/distributions/pareto_distribution.cpp index a6e5968..1a16e43 100755 --- a/src/distributions/pareto_distribution.cpp +++ b/src/distributions/pareto_distribution.cpp @@ -1,4 +1,5 @@ #include "libhmm/distributions/pareto_distribution.h" +#include "libhmm/io/json_utils.h" #include "libhmm/performance/simd_kernels_internal.h" // Header already includes: , , , , , via common.h #include // For std::accumulate (not in common.h) @@ -171,29 +172,29 @@ std::ostream &operator<<(std::ostream &os, const libhmm::ParetoDistribution &dis return os; } +// Parses the format produced by toString() / operator<<: +// Pareto Distribution: +// k (shape parameter) = VALUE +// x_m (scale parameter) = VALUE +// Mean = 
VALUE +// Variance = VALUE std::istream &operator>>(std::istream &is, libhmm::ParetoDistribution &distribution) { try { - std::string token, k_str, xm_str; - is >> token; //" k" - is >> token; // "=" - is >> k_str; - double k = std::stod(k_str); - - is >> token; // " xm" - is >> token; // " =" - is >> xm_str; - double xm = std::stod(xm_str); - + std::string s, t; + is >> s >> s; // "Pareto" "Distribution:" + is >> s >> s >> s >> s >> t; // "k" "(shape" "parameter)" "=" VALUE + const double k = std::stod(t); + is >> s >> s >> s >> s >> t; // "x_m" "(scale" "parameter)" "=" VALUE + const double xm = std::stod(t); + is >> s >> s >> t; + is >> s >> s >> t; // skip Mean, Variance if (is.good()) { distribution.setK(k); distribution.setXm(xm); } - } catch (const std::exception &) { - // Set error state on stream if parsing fails is.setstate(std::ios::failbit); } - return is; } @@ -301,4 +302,16 @@ void ParetoDistribution::getBatchLogProbabilities(std::span observ logK_ + kLogXm_, kPlus1_); } +std::string ParetoDistribution::to_json() const { + return json::write_distribution("Pareto", {{"k", k_}, {"xm", xm_}}); +} +std::unique_ptr ParetoDistribution::from_json(json::Reader &r) { + r.read_key(); + const double k = r.read_double(); + r.read_key(); + const double xm = r.read_double(); + r.consume('}'); + return std::make_unique(k, xm); +} + } // namespace libhmm diff --git a/src/distributions/poisson_distribution.cpp b/src/distributions/poisson_distribution.cpp index 1f65b2f..3d66462 100644 --- a/src/distributions/poisson_distribution.cpp +++ b/src/distributions/poisson_distribution.cpp @@ -1,4 +1,5 @@ #include "libhmm/distributions/poisson_distribution.h" +#include "libhmm/io/json_utils.h" #include "libhmm/math/weighted_stats.h" #include #include @@ -193,24 +194,24 @@ bool PoissonDistribution::operator==(const PoissonDistribution &other) const { * Stream input operator implementation. 
* Expects format: "Poisson Distribution: λ = " */ +// Parses the format produced by toString() / operator<<: +// Poisson Distribution: +// \u03bb (rate parameter) = VALUE +// Mean = VALUE +// Variance = VALUE std::istream &operator>>(std::istream &is, libhmm::PoissonDistribution &distribution) { try { - std::string token; - double lambda = 0.0; - // Skip "Poisson Distribution: λ =" - std::string lambda_str; - is >> token >> token >> token >> token >> lambda_str; - lambda = std::stod(lambda_str); - - if (is.good()) { + std::string s, t; + is >> s >> s; // "Poisson" "Distribution:" + is >> s >> s >> s >> s >> t; // "\u03bb" "(rate" "parameter)" "=" VALUE + const double lambda = std::stod(t); + is >> s >> s >> t; + is >> s >> s >> t; // skip Mean, Variance + if (is.good()) distribution.setLambda(lambda); - } - } catch (const std::exception &) { - // Set error state on stream if parsing fails is.setstate(std::ios::failbit); } - return is; } @@ -229,4 +230,14 @@ void PoissonDistribution::getBatchLogProbabilities(std::span obser } } +std::string PoissonDistribution::to_json() const { + return json::write_distribution("Poisson", {{"lambda", lambda_}}); +} +std::unique_ptr PoissonDistribution::from_json(json::Reader &r) { + r.read_key(); + const double lambda = r.read_double(); + r.consume('}'); + return std::make_unique(lambda); +} + } // namespace libhmm diff --git a/src/distributions/rayleigh_distribution.cpp b/src/distributions/rayleigh_distribution.cpp index 8ab4218..7be2b04 100644 --- a/src/distributions/rayleigh_distribution.cpp +++ b/src/distributions/rayleigh_distribution.cpp @@ -1,4 +1,5 @@ #include "libhmm/distributions/rayleigh_distribution.h" +#include "libhmm/io/json_utils.h" // Header already includes: , , , , , via common.h #include // For std::numeric_limits (not in common.h) @@ -120,13 +121,26 @@ std::ostream &operator<<(std::ostream &os, const RayleighDistribution &distribut return os; } +// Parses the format produced by toString() / operator<<: +// 
Rayleigh Distribution: +// \u03c3 (scale parameter) = VALUE +// Mean = VALUE +// Variance = VALUE +// Median = VALUE +// Mode = VALUE std::istream &operator>>(std::istream &is, RayleighDistribution &distribution) { try { - std::string token, sigma_str; - is >> token >> token >> token; // Read "σ", "(scale", "parameter)" - is >> sigma_str; - double sigma = std::stod(sigma_str); - distribution.setSigma(sigma); + std::string s, t; + is >> s >> s; // "Rayleigh" "Distribution:" + is >> s >> s >> s >> s >> t; // "\u03c3" "(scale" "parameter)" "=" VALUE + const double sigma = std::stod(t); + // skip Mean, Variance, Median, Mode + is >> s >> s >> t; + is >> s >> s >> t; + is >> s >> s >> t; + is >> s >> s >> t; + if (is.good()) + distribution.setSigma(sigma); } catch (const std::exception &) { is.setstate(std::ios::failbit); } @@ -149,4 +163,14 @@ void RayleighDistribution::getBatchLogProbabilities(std::span obse } } +std::string RayleighDistribution::to_json() const { + return json::write_distribution("Rayleigh", {{"sigma", sigma_}}); +} +std::unique_ptr RayleighDistribution::from_json(json::Reader &r) { + r.read_key(); + const double sigma = r.read_double(); + r.consume('}'); + return std::make_unique(sigma); +} + } // namespace libhmm diff --git a/src/distributions/student_t_distribution.cpp b/src/distributions/student_t_distribution.cpp index 42afc73..0facce9 100644 --- a/src/distributions/student_t_distribution.cpp +++ b/src/distributions/student_t_distribution.cpp @@ -1,4 +1,5 @@ #include "libhmm/distributions/student_t_distribution.h" +#include "libhmm/io/json_utils.h" #include #include #include @@ -247,85 +248,31 @@ bool StudentTDistribution::operator!=(const StudentTDistribution &other) const { return !(*this == other); } -/** - * Input stream operator for reading StudentT distribution from formatted text. - * Expected format: "StudentT(nu=value, mu=value, sigma=value)" or similar variations. 
- */ +// Parses the format produced by toString() / operator<<: +// StudentT Distribution: +// nu (degrees of freedom) = VALUE +// mu (location) = VALUE +// sigma (scale) = VALUE std::istream &operator>>(std::istream &is, StudentTDistribution &dist) { - std::string line; - if (!std::getline(is, line)) { - is.setstate(std::ios::failbit); - return is; - } - try { - // Parse parameters from the line - double nu = 1.0, mu = 0.0, sigma = 1.0; - bool found_nu = false, found_mu = false, found_sigma = false; - - // Look for parameter patterns - std::string::size_type start = line.find('('); - std::string::size_type end = line.rfind(')'); - - if (start != std::string::npos && end != std::string::npos && end > start) { - std::string params = line.substr(start + 1, end - start - 1); - - // Split by commas and parse each parameter - std::istringstream param_stream(params); - std::string param; - - while (std::getline(param_stream, param, ',')) { - std::string::size_type eq_pos = param.find('='); - if (eq_pos != std::string::npos) { - std::string name = param.substr(0, eq_pos); - std::string value = param.substr(eq_pos + 1); - - // Trim whitespace - name.erase(std::remove_if(name.begin(), name.end(), ::isspace), name.end()); - value.erase(std::remove_if(value.begin(), value.end(), ::isspace), value.end()); - - if (name == "nu" || name == "ν" || name == "df") { - nu = std::stod(value); - found_nu = true; - } else if (name == "mu" || name == "μ" || name == "location") { - mu = std::stod(value); - found_mu = true; - } else if (name == "sigma" || name == "σ" || name == "scale") { - sigma = std::stod(value); - found_sigma = true; - } - } - } - } - - // Create new distribution with parsed parameters - if (found_nu && found_mu && found_sigma) { - dist = StudentTDistribution(nu, mu, sigma); - } else if (found_nu) { - dist = StudentTDistribution(nu); - } else { - // Default case - just parse the first number if any - std::istringstream number_stream(line); - double value = 0.0; - if 
(number_stream >> value) { - dist = StudentTDistribution(value); - } else { - is.setstate(std::ios::failbit); - } - } + std::string s, t; + is >> s >> s; // "StudentT" "Distribution:" + is >> s >> s >> s >> s >> s >> t; // "nu" "(degrees" "of" "freedom)" "=" VALUE + const double nu = std::stod(t); + is >> s >> s >> s >> t; // "mu" "(location)" "=" VALUE + const double mu = std::stod(t); + is >> s >> s >> s >> t; // "sigma" "(scale)" "=" VALUE + if (is.good()) + dist = StudentTDistribution(nu, mu, std::stod(t)); } catch (const std::exception &) { is.setstate(std::ios::failbit); } - return is; } -/** - * Output stream operator for writing StudentT distribution in readable format. - */ +// Delegates to toString() — consistent with other distributions. std::ostream &operator<<(std::ostream &os, const StudentTDistribution &dist) { - os << "StudentT(nu=" << std::fixed << std::setprecision(6) << dist.getDegreesOfFreedom() - << ", mu=" << dist.getLocation() << ", sigma=" << dist.getScale() << ")"; + os << dist.toString(); return os; } @@ -394,4 +341,19 @@ void StudentTDistribution::getBatchLogProbabilities(std::span obse } } +std::string StudentTDistribution::to_json() const { + return json::write_distribution( + "StudentT", {{"df", degrees_of_freedom_}, {"mu", location_}, {"sigma", scale_}}); +} +std::unique_ptr StudentTDistribution::from_json(json::Reader &r) { + r.read_key(); + const double df = r.read_double(); + r.read_key(); + const double mu = r.read_double(); + r.read_key(); + const double sigma = r.read_double(); + r.consume('}'); + return std::make_unique(df, mu, sigma); +} + } // namespace libhmm diff --git a/src/distributions/uniform_distribution.cpp b/src/distributions/uniform_distribution.cpp index 52d4b8f..d4ec866 100644 --- a/src/distributions/uniform_distribution.cpp +++ b/src/distributions/uniform_distribution.cpp @@ -1,4 +1,5 @@ #include "libhmm/distributions/uniform_distribution.h" +#include "libhmm/io/json_utils.h" #include 
"libhmm/math/weighted_stats.h" #include #include @@ -192,31 +193,26 @@ bool UniformDistribution::operator==(const UniformDistribution &other) const { } std::ostream &operator<<(std::ostream &os, const UniformDistribution &dist) { - os << std::fixed << std::setprecision(6); - os << "Uniform Distribution: a = " << dist.getA() << ", b = " << dist.getB(); + os << dist.toString(); return os; } +// Parses the format produced by toString() / operator<<: +// Uniform Distribution: +// a (lower bound) = VALUE +// b (upper bound) = VALUE std::istream &operator>>(std::istream &is, UniformDistribution &dist) { try { - std::string token; - double a = 0.0, b = 0.0; - // Expected format: "Uniform Distribution: a = , b = " - std::string a_str, b_str; - is >> token >> token >> token >> token >> a_str >> token >> token >> token >> - b_str; // "Uniform" "Distribution:" "a" "=" "," "b" "=" - a = std::stod(a_str); - b = std::stod(b_str); - - if (is.good()) { - dist.setParameters(a, b); - } - + std::string s, t; + is >> s >> s; // "Uniform" "Distribution:" + is >> s >> s >> s >> s >> t; // "a" "(lower" "bound)" "=" VALUE + const double a = std::stod(t); + is >> s >> s >> s >> s >> t; // "b" "(upper" "bound)" "=" VALUE + if (is.good()) + dist.setParameters(a, std::stod(t)); } catch (const std::exception &) { - // Set error state on stream if parsing fails is.setstate(std::ios::failbit); } - return is; } @@ -234,4 +230,16 @@ void UniformDistribution::getBatchLogProbabilities(std::span obser } } +std::string UniformDistribution::to_json() const { + return json::write_distribution("Uniform", {{"a", a_}, {"b", b_}}); +} +std::unique_ptr UniformDistribution::from_json(json::Reader &r) { + r.read_key(); + const double a = r.read_double(); + r.read_key(); + const double b = r.read_double(); + r.consume('}'); + return std::make_unique(a, b); +} + } // namespace libhmm diff --git a/src/distributions/weibull_distribution.cpp b/src/distributions/weibull_distribution.cpp index 3287d1c..bcc5bb9 100644 
--- a/src/distributions/weibull_distribution.cpp +++ b/src/distributions/weibull_distribution.cpp @@ -1,4 +1,5 @@ #include "libhmm/distributions/weibull_distribution.h" +#include "libhmm/io/json_utils.h" // Header already includes: , , , , , via common.h #include // For std::max, std::min (exists in common.h, included for clarity) #include // For std::accumulate (not in common.h) @@ -221,4 +222,16 @@ void WeibullDistribution::getBatchLogProbabilities(std::span obser } } +std::string WeibullDistribution::to_json() const { + return json::write_distribution("Weibull", {{"k", k_}, {"lambda", lambda_}}); +} +std::unique_ptr WeibullDistribution::from_json(json::Reader &r) { + r.read_key(); + const double k = r.read_double(); + r.read_key(); + const double lambda = r.read_double(); + r.consume('}'); + return std::make_unique(k, lambda); +} + } // namespace libhmm diff --git a/src/hmm.cpp b/src/hmm.cpp index b96470d..0164648 100755 --- a/src/hmm.cpp +++ b/src/hmm.cpp @@ -2,11 +2,8 @@ #include #include #include -#include -#include #include -#include -#include +#include namespace libhmm { @@ -43,26 +40,250 @@ std::ostream &operator<<(std::ostream &os, const libhmm::Hmm &h) { return os; } //operator<<() +// ============================================================================= +// operator>> stream-format parsers +// +// operator>>(istream, Hmm) reads the text produced by operator<<(ostream, Hmm). +// It is retained for backward compatibility with the CDATA-wrapped XML format +// written by XMLFileWriter. Both the XML classes and this operator are +// deprecated; prefer hmm_json.h for new code. +// +// Each parse_* function is called with the stream positioned immediately after +// the distribution type keyword (e.g. "Gaussian", "Poisson") has been read. +// The branching complexity of each parser is now measured independently by +// static analysis tools rather than being folded into operator>>'s CC. 
+// +// Known limitation: NegativeBinomial distributions written by operator<< cannot +// be read back because toString() begins with "Negative Binomial" (two words), +// so the type dispatch key is "Negative" rather than "NegativeBinomial". +// ============================================================================= + +namespace { + +using StreamParserFn = std::unique_ptr (*)(std::istream &); + +// parse_gaussian: reads all of GaussianDistribution::toString() after the type keyword. +std::unique_ptr parse_gaussian(std::istream &is) { + std::string s, t; + is >> s; // "Distribution:" + is >> s >> s >> s >> t; // "\u03bc" "(mean)" "=" VALUE + const double mean = std::stod(t); + is >> s >> s >> s >> s >> t; // "\u03c3" "(std." "deviation)" "=" VALUE + const double sd = std::stod(t); + is >> s >> s >> t; + is >> s >> s >> t; // skip Mean, Variance + return std::make_unique(mean, sd); +} + +std::unique_ptr parse_discrete(std::istream &is) { + std::string s, t; + is >> s; // "Distribution:" + if (!is) + throw std::runtime_error("Failed to parse Discrete distribution header"); + + is >> t; + if (!is) + throw std::runtime_error("Failed to parse Discrete distribution data"); + + // Modern format: "Number of symbols = N" followed by "P(i) = value" lines. 
+ if (t == "Number") { + std::string of_tok, sym_tok, eq_tok, n_tok; + is >> of_tok >> sym_tok >> eq_tok >> n_tok; + if (of_tok != "of" || sym_tok != "symbols" || eq_tok != "=") + throw std::runtime_error("Malformed Discrete distribution symbol header"); + const auto n = std::stoull(n_tok); + if (n == 0) + throw std::runtime_error("Discrete distribution must have at least one symbol"); + auto dist = std::make_unique(static_cast(n)); + for (std::size_t k = 0; k < n; ++k) { + std::string label, eq, val; + is >> label >> eq >> val; + if (eq != "=") + throw std::runtime_error("Malformed Discrete distribution probability entry"); + dist->setProbability(static_cast(k), std::stod(val)); + } + return dist; + } + + // Legacy fallback: bare probability list (11 symbols). + constexpr std::size_t kLegacySymbols = 11; + std::vector probs(kLegacySymbols); + probs[0] = std::stod(t); + for (std::size_t k = 1; k < kLegacySymbols; ++k) { + is >> t; + probs[k] = std::stod(t); + } + auto dist = std::make_unique(static_cast(kLegacySymbols)); + for (std::size_t k = 0; k < kLegacySymbols; ++k) + dist->setProbability(static_cast(k), probs[k]); + return dist; +} + +// parse_gamma: "k (shape parameter) = V\n\u03b8 (scale parameter) = V\nMean = V\nVariance = V" +std::unique_ptr parse_gamma(std::istream &is) { + std::string s, t; + is >> s; // "Distribution:" + is >> s >> s >> s >> s >> t; // "k" "(shape" "parameter)" "=" VALUE + const double k = std::stod(t); + is >> s >> s >> s >> s >> t; // "\u03b8" "(scale" "parameter)" "=" VALUE + const double theta = std::stod(t); + is >> s >> s >> t; + is >> s >> s >> t; // skip Mean, Variance + return std::make_unique(k, theta); +} + +// parse_exponential: "\u03bb (rate parameter) = VALUE\nMean = VALUE" +std::unique_ptr parse_exponential(std::istream &is) { + std::string s, t; + is >> s; // "Distribution:" + is >> s >> s >> s >> s >> t; // "\u03bb" "(rate" "parameter)" "=" VALUE + const double lambda = std::stod(t); + is >> s >> s >> t; // skip Mean + 
return std::make_unique(lambda); +} + +// parse_log_normal: "\u03bc (log mean) = V\n\u03c3 (log std. deviation) = V\nMean = V\nVariance = V" +std::unique_ptr parse_log_normal(std::istream &is) { + std::string s, t; + is >> s; // "Distribution:" + is >> s >> s >> s >> s >> t; // "\u03bc" "(log" "mean)" "=" VALUE + const double mean = std::stod(t); + is >> s >> s >> s >> s >> s >> t; // "\u03c3" "(log" "std." "deviation)" "=" VALUE + const double sd = std::stod(t); + is >> s >> s >> t; + is >> s >> s >> t; // skip Mean, Variance + return std::make_unique(mean, sd); +} + +// parse_pareto: "k (shape parameter) = V\nx_m (scale parameter) = V\nMean = V\nVariance = V" +std::unique_ptr parse_pareto(std::istream &is) { + std::string s, t; + is >> s; // "Distribution:" + is >> s >> s >> s >> s >> t; // "k" "(shape" "parameter)" "=" VALUE + const double k = std::stod(t); + is >> s >> s >> s >> s >> t; // "x_m" "(scale" "parameter)" "=" VALUE + const double xm = std::stod(t); + is >> s >> s >> t; + is >> s >> s >> t; // skip Mean, Variance + return std::make_unique(k, xm); +} + +// parse_poisson: "\u03bb (rate parameter) = V\nMean = V\nVariance = V" +std::unique_ptr parse_poisson(std::istream &is) { + std::string s, t; + is >> s; // "Distribution:" + is >> s >> s >> s >> s >> t; // "\u03bb" "(rate" "parameter)" "=" VALUE + const double lambda = std::stod(t); + is >> s >> s >> t; + is >> s >> s >> t; // skip Mean, Variance + return std::make_unique(lambda); +} + +// parse_beta: "\u03b1 (alpha) = V\n\u03b2 (beta) = V\nMean = V\nVariance = V" +std::unique_ptr parse_beta(std::istream &is) { + std::string s, t; + is >> s >> s >> s >> s >> t; // "Distribution:" "\u03b1" "(alpha)" "=" VALUE + const double alpha = std::stod(t); + is >> s >> s >> s >> t; // "\u03b2" "(beta)" "=" VALUE + const double beta = std::stod(t); + is >> s >> s >> t; + is >> s >> s >> t; // skip Mean, Variance + return std::make_unique(alpha, beta); +} + +std::unique_ptr parse_weibull(std::istream &is) { + 
std::string s, t; + is >> s >> s >> s >> s >> t; // "Distribution:" "k" "(shape)" "=" value + const double k = std::stod(t); + is >> s >> s >> s >> t; // "\u03bb" "(scale)" "=" value + return std::make_unique(k, std::stod(t)); +} + +// parse_uniform: "a (lower bound) = V\nb (upper bound) = V" +std::unique_ptr parse_uniform(std::istream &is) { + std::string s, t; + is >> s; // "Distribution:" + is >> s >> s >> s >> s >> t; // "a" "(lower" "bound)" "=" VALUE + const double a = std::stod(t); + is >> s >> s >> s >> s >> t; // "b" "(upper" "bound)" "=" VALUE + return std::make_unique(a, std::stod(t)); +} + +std::unique_ptr parse_student_t(std::istream &is) { + std::string s, t; + is >> s; // "Distribution:" + is >> s >> s >> s >> s >> s >> t; // "nu" "(degrees" "of" "freedom)" "=" value + const double nu = std::stod(t); + is >> s >> s >> s >> t; // "mu" "(location)" "=" value + const double mu = std::stod(t); + is >> s >> s >> s >> t; // "sigma" "(scale)" "=" value + return std::make_unique(nu, mu, std::stod(t)); +} + +std::unique_ptr parse_chi_squared(std::istream &is) { + std::string s, t; + is >> s; // "Distribution:" + is >> s >> s >> s >> s >> s >> t; // "k" "(degrees" "of" "freedom)" "=" value + return std::make_unique(std::stod(t)); +} + +// parse_binomial: "n (trials) = V\np (success probability) = V\nMean = V\nVariance = V" +std::unique_ptr parse_binomial(std::istream &is) { + std::string s, t; + is >> s; // "Distribution:" + is >> s >> s >> s >> t; // "n" "(trials)" "=" VALUE + const int n = static_cast(std::stod(t)); + is >> s >> s >> s >> s >> t; // "p" "(success" "probability)" "=" VALUE + const double p = std::stod(t); + is >> s >> s >> t; + is >> s >> s >> t; // skip Mean, Variance + return std::make_unique(n, p); +} + +// parse_rayleigh: "\u03c3 (scale parameter) = V\nMean = V\nVariance = V\nMedian = V\nMode = V" +std::unique_ptr parse_rayleigh(std::istream &is) { + std::string s, t; + is >> s; // "Distribution:" + is >> s >> s >> s >> s >> t; // 
"\u03c3" "(scale" "parameter)" "=" VALUE + const double sigma = std::stod(t); + // skip Mean, Variance, Median, Mode + is >> s >> s >> t; + is >> s >> s >> t; + is >> s >> s >> t; + is >> s >> s >> t; + return std::make_unique(sigma); +} + +// Map from the first word of each distribution's toString() output to its parser. +// NegativeBinomial is absent: toString() writes "Negative Binomial Distribution:", +// so the dispatch key would be "Negative" — a two-word type that cannot be +// represented by a single-token lookup. Use JSON I/O for NegativeBinomial. +const std::unordered_map kStreamParsers = { + {"Gaussian", parse_gaussian}, {"Discrete", parse_discrete}, + {"Gamma", parse_gamma}, {"Exponential", parse_exponential}, + {"LogNormal", parse_log_normal}, {"Pareto", parse_pareto}, + {"Poisson", parse_poisson}, {"Beta", parse_beta}, + {"Weibull", parse_weibull}, {"Uniform", parse_uniform}, + {"Binomial", parse_binomial}, {"Rayleigh", parse_rayleigh}, + {"StudentT", parse_student_t}, {"ChiSquared", parse_chi_squared}, +}; + +} // anonymous namespace + std::istream &operator>>(std::istream &is, libhmm::Hmm &hmm) { std::string s, t; std::size_t states; - // Parse header is >> s >> s >> s >> s; // "Hidden Markov Model parameters" is >> s >> s; // "States:" states = std::stoull(s); - - if (states == 0) { + if (states == 0) throw std::runtime_error("Invalid number of states in HMM input"); - } - // Create new HMM with proper number of states hmm = Hmm(states); - Vector pi(states); Matrix trans(states, states); - // Parse Pi vector is >> s >> s; // "Pi:" "[" for (std::size_t i = 0; i < states; ++i) { is >> t; @@ -70,7 +291,6 @@ std::istream &operator>>(std::istream &is, libhmm::Hmm &hmm) { } is >> s; // "]" - // Parse transition matrix is >> s >> s; // "Transmission" "matrix:" for (std::size_t i = 0; i < states; ++i) { is >> s; // "[" @@ -81,225 +301,18 @@ std::istream &operator>>(std::istream &is, libhmm::Hmm &hmm) { is >> s; // "]" } - // Parse emissions is >> s; // 
"Emissions:" for (std::size_t i = 0; i < states; ++i) { - is >> s >> s >> t; // "State" "i:" "DistributionType" - - // Modern C++17 approach: Hash-based dispatch for cleaner code - using DistributionParser = - std::function(std::istream &)>; - - static const std::unordered_map parsers = { - {"Gaussian", - [](std::istream &is) { - std::string s, t; - // Read "Distribution:" - is >> s; // "Distribution:" - - // Read "μ (mean) = value" - is >> s >> s >> s >> t; // "μ" "(mean)" "=" value - double mean = std::stod(t); - - // Read "σ (std. deviation) = value" - is >> s >> s >> s >> s >> t; // "σ" "(std." "deviation)" "=" value - double sd = std::stod(t); - - // Read "Mean = value" - is >> s >> s >> t; // "Mean" "=" value - - // Read "Variance = value" - is >> s >> s >> t; // "Variance" "=" value - - return std::make_unique(mean, sd); - }}, - - {"Discrete", - [](std::istream &is) { - std::string s, t; - is >> s; // "Distribution:" - if (!is) { - throw std::runtime_error("Failed to parse Discrete distribution header"); - } - - is >> t; - if (!is) { - throw std::runtime_error("Failed to parse Discrete distribution data"); - } - - // Current format: - // Number of symbols = N - // P(0) = ... - // ... 
- if (t == "Number") { - std::string of, symbols, equals, numSymbolsToken; - is >> of >> symbols >> equals >> numSymbolsToken; - if (of != "of" || symbols != "symbols" || equals != "=") { - throw std::runtime_error("Malformed Discrete distribution symbol header"); - } - - const auto numSymbols = std::stoull(numSymbolsToken); - if (numSymbols == 0) { - throw std::runtime_error( - "Discrete distribution must have at least one symbol"); - } - - auto discreteDist = - std::make_unique(static_cast(numSymbols)); - for (std::size_t symIndex = 0; symIndex < numSymbols; ++symIndex) { - std::string label, valueEquals, valueToken; - is >> label >> valueEquals >> valueToken; - if (valueEquals != "=") { - throw std::runtime_error( - "Malformed Discrete distribution probability entry"); - } - const double probability = std::stod(valueToken); - discreteDist->setProbability(static_cast(symIndex), probability); - } - return discreteDist; - } - - // Legacy fallback: - // Distribution: p0 p1 ... p10 - constexpr std::size_t MAX_SYMBOLS = 11; - std::vector symbols(MAX_SYMBOLS); - symbols[0] = std::stod(t); - for (std::size_t symIndex = 1; symIndex < MAX_SYMBOLS; ++symIndex) { - is >> t; - symbols[symIndex] = std::stod(t); - } - - auto discreteDist = std::make_unique(MAX_SYMBOLS); - for (std::size_t symIndex = 0; symIndex < MAX_SYMBOLS; ++symIndex) { - discreteDist->setProbability(static_cast(symIndex), symbols[symIndex]); - } - return discreteDist; - }}, - - {"Gamma", - [](std::istream &is) { - std::string s, t; - is >> s >> s >> s >> t; // "Distribution:" "k" "=" value - double k = std::stod(t); - is >> s >> s >> t; // "theta" "=" value - double theta = std::stod(t); - return std::make_unique(k, theta); - }}, - - {"Exponential", - [](std::istream &is) { - std::string s, t; - is >> s >> s >> s >> s >> t; // "Distribution:" "Rate" "parameter" "=" value - double lambda = std::stod(t); - return std::make_unique(lambda); - }}, - - {"LogNormal", - [](std::istream &is) { - std::string s, t; - 
is >> s >> s >> s >> t; // "Distribution:" "Mean" "=" value - double mean = std::stod(t); - is >> s >> s >> s >> t; // "Standard" "Deviation" "=" value - double sd = std::stod(t); - return std::make_unique(mean, sd); - }}, - - {"Pareto", - [](std::istream &is) { - std::string s, t; - is >> s >> s >> s >> t; // "Distribution:" "k" "=" value - double k = std::stod(t); - is >> s >> s >> t; // "xm" "=" value - double xm = std::stod(t); - return std::make_unique(k, xm); - }}, - - {"Poisson", - [](std::istream &is) { - std::string s, t; - is >> s >> s >> s >> t; // "Distribution:" "λ" "=" value - double lambda = std::stod(t); - return std::make_unique(lambda); - }}, - - {"Beta", - [](std::istream &is) { - std::string s, t; - is >> s >> s >> s >> s >> t; // "Distribution:" "α" "(alpha)" "=" value - double alpha = std::stod(t); - is >> s >> s >> s >> t; // "β" "(beta)" "=" value - double beta = std::stod(t); - return std::make_unique(alpha, beta); - }}, - - {"Weibull", - [](std::istream &is) { - std::string s, t; - is >> s >> s >> s >> s >> t; // "Distribution:" "k" "(shape)" "=" value - double k = std::stod(t); - is >> s >> s >> s >> t; // "λ" "(scale)" "=" value - double lambda = std::stod(t); - return std::make_unique(k, lambda); - }}, - - {"Uniform", - [](std::istream &is) { - std::string s, t; - is >> s >> s >> s >> s >> t; // "Distribution:" "a" "(lower" "bound)" value - double a = std::stod(t); - is >> s >> s >> s >> s >> t; // "b" "(upper" "bound)" "=" value - double b = std::stod(t); - return std::make_unique(a, b); - }}, - - {"StudentT", - [](std::istream &is) { - std::string s, t; - // Read "Distribution:" - is >> s; - - // Read " nu (degrees of freedom) = value" - is >> s >> s >> s >> s >> s >> t; // "nu" "(degrees" "of" "freedom)" "=" value - double nu = std::stod(t); - - // Read " mu (location) = value" - is >> s >> s >> s >> t; // "mu" "(location)" "=" value - double mu = std::stod(t); - - // Read " sigma (scale) = value" - is >> s >> s >> s >> t; // "sigma" 
"(scale)" "=" value - double sigma = std::stod(t); - - return std::make_unique(nu, mu, sigma); - }}, - - {"ChiSquared", [](std::istream &is) { - std::string s, t; - // Read "Distribution:" - is >> s; - - // Read " k (degrees of freedom) = value" - is >> s >> s >> s >> s >> s >> t; // "k" "(degrees" "of" "freedom)" "=" value - double k = std::stod(t); - - return std::make_unique(k); - }}}; - - // Execute the appropriate parser - auto parser_it = parsers.find(t); - if (parser_it != parsers.end()) { - auto distribution = parser_it->second(is); - hmm.setDistribution(i, std::move(distribution)); - } else { - throw std::runtime_error("Unknown distribution type: " + t); - } + is >> s >> s >> t; // "State" "N:" type-keyword + const auto it = kStreamParsers.find(t); + if (it == kStreamParsers.end()) + throw std::runtime_error("Unknown distribution type in stream: " + t); + hmm.setDistribution(i, it->second(is)); } - // Set the parsed parameters hmm.setPi(pi); hmm.setTrans(trans); - return is; -} //operator>>() +} // operator>>() } // namespace libhmm diff --git a/src/io/hmm_json.cpp b/src/io/hmm_json.cpp new file mode 100644 index 0000000..15a1369 --- /dev/null +++ b/src/io/hmm_json.cpp @@ -0,0 +1,188 @@ +#include "libhmm/io/hmm_json.h" + +#include +#include +#include +#include +#include + +#include "libhmm/distributions/distributions.h" +#include "libhmm/io/file_io_manager.h" +#include "libhmm/io/json_utils.h" +#include "libhmm/linalg/linalg_types.h" + +namespace libhmm { + +// ============================================================================= +// Deserialization limits +// ============================================================================= +// +// These constants bound attacker-controlled allocation sizes so that malformed +// or malicious JSON cannot trigger OOM, integer overflow, or UB casts. + +namespace { + +// Maximum number of HMM states accepted during deserialization. 
+// A 4096-state model requires a 4096×4096 transition matrix (~128 MB of +// doubles), which covers every realistic research use case. Values above this +// almost certainly indicate corrupted or adversarial input. +constexpr std::size_t kMaxHmmStates = 4096; + +// Maximum JSON input size accepted by from_json(string_view); larger input is rejected unparsed. +// NOTE(review): this is tighter than kMaxHmmStates. 4096 x 4096 trans entries need at least 32 MB at +// two bytes each; at ~20 bytes per full-precision double 10 MB is reached near 724 states. Confirm. +constexpr std::size_t kMaxJsonInputBytes = 10UL * 1024UL * 1024UL; // 10 MB + +// ============================================================================= +// Distribution factory — CC 2 dispatch + +using FactoryFn = std::unique_ptr (*)(json::Reader &); + +// Keyed on the "type" string written by each distribution's to_json(). +const std::unordered_map kFactory = { + {"Gaussian", &GaussianDistribution::from_json}, + {"Exponential", &ExponentialDistribution::from_json}, + {"Gamma", &GammaDistribution::from_json}, + {"Beta", &BetaDistribution::from_json}, + {"Weibull", &WeibullDistribution::from_json}, + {"LogNormal", &LogNormalDistribution::from_json}, + {"Pareto", &ParetoDistribution::from_json}, + {"NegativeBinomial", &NegativeBinomialDistribution::from_json}, + {"ChiSquared", &ChiSquaredDistribution::from_json}, + {"StudentT", &StudentTDistribution::from_json}, + {"Poisson", &PoissonDistribution::from_json}, + {"Binomial", &BinomialDistribution::from_json}, + {"Discrete", &DiscreteDistribution::from_json}, + {"Uniform", &UniformDistribution::from_json}, + {"Rayleigh", &RayleighDistribution::from_json}, +}; + +/// Parse one distribution object: reader is positioned before the '{'.
+std::unique_ptr read_distribution(json::Reader &r) { + r.consume('{'); + r.read_key(); // "type" + const std::string type = r.read_string(); + const auto it = kFactory.find(type); + if (it == kFactory.end()) + throw std::runtime_error("HMM JSON: unknown distribution type \"" + type + "\""); + return it->second(r); // reads remaining fields + closing '}' +} + +} // anonymous namespace + +// ============================================================================= +// to_json +// ============================================================================= + +std::string to_json(const Hmm &hmm) { + const std::size_t N = hmm.getNumStatesModern(); + const auto &pi = hmm.getPi(); + const auto &trans = hmm.getTrans(); + + std::string s; + s.reserve(256 + N * N * 20); // rough pre-allocation + + s += "{\"states\":"; + s += json::write_double(static_cast(N)); + + s += ",\"pi\":"; + s += json::write_array(std::span(pi.data(), N)); + + s += ",\"trans\":"; + s += json::write_matrix(N, N, std::span(trans.data(), N * N)); + + s += ",\"distributions\":["; + for (std::size_t i = 0; i < N; ++i) { + if (i) + s += ','; + s += hmm.getDistribution(i).to_json(); + } + s += "]}"; + return s; +} + +// ============================================================================= +// from_json +// ============================================================================= + +Hmm from_json(std::string_view src) { + // Reject oversized inputs before any parsing to avoid slow traversal of + // huge buffers. The limit is intentionally generous (see kMaxJsonInputBytes). + if (src.size() > kMaxJsonInputBytes) + throw std::runtime_error("HMM JSON: input exceeds maximum allowed size (" + + std::to_string(kMaxJsonInputBytes / (1024 * 1024)) + " MB)"); + + json::Reader r(src); + r.consume('{'); + + r.read_key(); // "states" + const double N_raw = r.read_double(); + // Guard against non-finite values: static_cast(inf) is undefined behaviour. 
+ // Guard against values > kMaxHmmStates (an N×N matrix of doubles needs N² × 8 bytes, ~128 MB at + // N=4096) and against fractional values, which the cast below would silently truncate. + if (!std::isfinite(N_raw) || N_raw < 1.0 || N_raw > static_cast(kMaxHmmStates) || std::floor(N_raw) != N_raw) + throw std::runtime_error("HMM JSON: states must be an integer in [1, " + + std::to_string(kMaxHmmStates) + "]"); + const std::size_t N = static_cast(N_raw); + + // Pass N as the element cap: any pi or trans array longer than N is malformed + // and must not be allowed to grow the heap before the post-hoc size checks fire. + r.read_key(); // "pi" + const auto pi_data = r.read_double_array(N); + + r.read_key(); // "trans" + const auto trans_rows = r.read_double_matrix(N, N); + + r.read_key(); // "distributions" + std::vector> emis; + emis.reserve(N); + + r.consume('['); + if (!r.at(']')) { + emis.push_back(read_distribution(r)); + while (r.at(',')) { + r.consume(','); + emis.push_back(read_distribution(r)); + } + } + r.consume(']'); + + r.consume('}'); + + // Validate dimensions before constructing + if (pi_data.size() != N) + throw std::runtime_error("HMM JSON: pi size mismatch"); + if (trans_rows.size() != N) + throw std::runtime_error("HMM JSON: trans row count mismatch"); + if (emis.size() != N) + throw std::runtime_error("HMM JSON: distribution count mismatch"); + + Vector pi(N); + for (std::size_t i = 0; i < N; ++i) + pi[i] = pi_data[i]; + + Matrix trans(N, N); + for (std::size_t i = 0; i < N; ++i) { + if (trans_rows[i].size() != N) + throw std::runtime_error("HMM JSON: trans row length mismatch at row " + + std::to_string(i)); + for (std::size_t j = 0; j < N; ++j) + trans(i, j) = trans_rows[i][j]; + } + + return Hmm(std::move(trans), std::move(emis), std::move(pi)); +} + +// ============================================================================= +// File I/O wrappers +// ============================================================================= + +void save_json(const Hmm &hmm, const std::filesystem::path
&filepath) { + FileIOManager::writeTextFile(filepath, to_json(hmm)); +} + +Hmm load_json(const std::filesystem::path &filepath) { + return from_json(FileIOManager::readTextFile(filepath)); +} + +} // namespace libhmm diff --git a/src/io/json_utils.cpp b/src/io/json_utils.cpp new file mode 100644 index 0000000..924f12c --- /dev/null +++ b/src/io/json_utils.cpp @@ -0,0 +1,209 @@ +#include "libhmm/io/json_utils.h" + +#include +#include +#include +#include +#include + +namespace libhmm { +namespace json { + +// ============================================================================= +// Write helpers +// ============================================================================= + +std::string write_double(double v) { + // Use max_digits10 and the classic "C" locale to guarantee an exact + // round-trip with '.' as the decimal separator on all platforms. + std::ostringstream oss; + oss.imbue(std::locale::classic()); + oss.precision(std::numeric_limits::max_digits10); + oss << v; + return oss.str(); +} + +std::string write_array(std::span v) { + std::string s; + s += '['; + for (std::size_t i = 0; i < v.size(); ++i) { + if (i) + s += ','; + s += write_double(v[i]); + } + s += ']'; + return s; +} + +std::string write_matrix(std::size_t rows, std::size_t cols, std::span data) { + std::string s; + s += '['; + for (std::size_t i = 0; i < rows; ++i) { + if (i) + s += ','; + s += write_array(data.subspan(i * cols, cols)); + } + s += ']'; + return s; +} + +std::string write_distribution(std::string_view type, + std::initializer_list> fields) { + std::string s; + s += "{\"type\":\""; + s += type; + s += '"'; + for (const auto &[k, v] : fields) { + s += ",\""; + s += k; + s += "\":"; + s += write_double(v); + } + s += '}'; + return s; +} + +std::string +write_distribution_with_array(std::string_view type, + std::initializer_list> scalars, + std::string_view array_key, std::span array_val) { + std::string s; + s += "{\"type\":\""; + s += type; + s += '"'; + for (const auto 
&[k, v] : scalars) { + s += ",\""; + s += k; + s += "\":"; + s += write_double(v); + } + s += ",\""; + s += array_key; + s += "\":"; + s += write_array(array_val); + s += '}'; + return s; +} + +// ============================================================================= +// Reader implementation +// ============================================================================= + +void Reader::skip_ws() noexcept { + while (pos_ < src_.size() && + (src_[pos_] == ' ' || src_[pos_] == '\t' || src_[pos_] == '\r' || src_[pos_] == '\n')) + ++pos_; +} + +void Reader::consume(char c) { + skip_ws(); + if (pos_ >= src_.size() || src_[pos_] != c) { + std::string msg = "json::Reader: expected '"; + msg += c; + msg += '\''; + if (pos_ < src_.size()) { + msg += ", got '"; + msg += src_[pos_]; + msg += '\''; + } else { + msg += " but reached end of input"; + } + throw std::runtime_error(msg); + } + ++pos_; +} + +char Reader::peek() { + skip_ws(); + if (pos_ >= src_.size()) + throw std::runtime_error("json::Reader: unexpected end of input"); + return src_[pos_]; +} + +bool Reader::at(char c) { + skip_ws(); + return pos_ < src_.size() && src_[pos_] == c; +} + +std::string Reader::read_string() { + consume('"'); + const std::size_t start = pos_; + while (pos_ < src_.size() && src_[pos_] != '"') + ++pos_; + if (pos_ >= src_.size()) + throw std::runtime_error("json::Reader: unterminated string"); + std::string result(src_.substr(start, pos_ - start)); + ++pos_; // consume closing '"' + return result; +} + +double Reader::read_double() { + skip_ws(); + if (pos_ >= src_.size()) + throw std::runtime_error("json::Reader: unexpected end of input"); + // std::from_chars for floating-point is not available on AppleClang / libc++. + // std::strtod provides the same consumed-position semantics and is portable. + // write_double() imbues std::locale::classic() so the decimal separator is + // always '.' — strtod uses the same convention under the default C locale. 
+ const char *begin = src_.data() + pos_; + char *end_ptr = nullptr; + errno = 0; + const double value = std::strtod(begin, &end_ptr); + if (end_ptr == begin) + throw std::runtime_error("json::Reader: failed to parse number"); + if (errno == ERANGE) + throw std::runtime_error("json::Reader: number out of range"); + pos_ = static_cast(end_ptr - src_.data()); + return value; +} + +std::vector Reader::read_double_array(std::size_t max_elements) { + consume('['); + std::vector result; + if (!at(']')) { + result.push_back(read_double()); + while (at(',')) { + // Check before consuming the next element so we never read beyond the limit. + if (result.size() >= max_elements) + throw std::runtime_error("json::Reader: array exceeds maximum allowed size (" + + std::to_string(max_elements) + " elements)"); + consume(','); + result.push_back(read_double()); + } + } + consume(']'); + return result; +} + +std::vector> Reader::read_double_matrix(std::size_t max_rows, + std::size_t max_cols_per_row) { + consume('['); + std::vector> result; + if (!at(']')) { + result.push_back(read_double_array(max_cols_per_row)); + while (at(',')) { + if (result.size() >= max_rows) + throw std::runtime_error("json::Reader: matrix exceeds maximum allowed rows (" + + std::to_string(max_rows) + ")"); + consume(','); + result.push_back(read_double_array(max_cols_per_row)); + } + } + consume(']'); + return result; +} + +std::string Reader::read_key() { + skip_ws(); + // Consume leading comma between key-value pairs if present. 
+ if (pos_ < src_.size() && src_[pos_] == ',') { + ++pos_; + skip_ws(); + } + std::string key = read_string(); + consume(':'); + return key; +} + +} // namespace json +} // namespace libhmm diff --git a/src/io/xml_file_reader.cpp b/src/io/xml_file_reader.cpp index 4f1023f..53600cf 100644 --- a/src/io/xml_file_reader.cpp +++ b/src/io/xml_file_reader.cpp @@ -22,9 +22,9 @@ Hmm XMLFileReader::read(const std::filesystem::path &filepath) { throw std::runtime_error("Cannot read from path: " + filepath.string()); } - // Check if it appears to be a valid XML file - if (!isValidXMLFile(filepath)) { - throw std::runtime_error("File does not appear to be a valid XML file: " + + // Check if the file looks like a libhmm XML file before opening + if (!canParseAsHmm(filepath)) { + throw std::runtime_error("File does not appear to be a libhmm XML file: " + filepath.string()); } @@ -45,9 +45,6 @@ Hmm XMLFileReader::read(const std::filesystem::path &filepath) { return hmm; - // TODO: Uncomment when boost serialization is implemented - // } catch (const boost::archive::archive_exception& e) { - // throw std::runtime_error("XML deserialization failed: " + std::string(e.what())); } catch (const std::ios_base::failure &e) { throw std::runtime_error("I/O operation failed: " + std::string(e.what())); } @@ -74,19 +71,20 @@ bool XMLFileReader::canReadFromPath(const std::filesystem::path &filepath) noexc } } -bool XMLFileReader::isValidXMLFile(const std::filesystem::path &filepath) noexcept { +bool XMLFileReader::canParseAsHmm(const std::filesystem::path &filepath) noexcept { try { if (!canReadFromPath(filepath)) { return false; } - // Check file size - empty files or extremely large files are suspicious + // Reject empty files and implausibly large ones (100 MB). 
const auto fileSize = std::filesystem::file_size(filepath); - if (fileSize == 0 || fileSize > 100 * 1024 * 1024) { // 100MB limit + if (fileSize == 0 || fileSize > 100 * 1024 * 1024) { return false; } - // Try to open and read first few bytes to check for XML header + // Check that the file begins with an XML declaration — the only marker + // written by XMLFileWriter::writeToStream. std::ifstream file(filepath, std::ios::in); if (!file.is_open()) { return false; @@ -97,10 +95,7 @@ bool XMLFileReader::isValidXMLFile(const std::filesystem::path &filepath) noexce return false; } - // Basic XML validation - look for XML declaration or root element - return firstLine.find("") != std::string::npos || - firstLine.find("<") != std::string::npos; + return firstLine.find(" ]]> + // Skip lines until the CDATA section begins, then hand off to operator>>. try { - // TODO: Implement boost serialization when Hmm class supports it - // boost::archive::xml_iarchive ia(stream); - // Hmm hmm(1); // Start with 1 state, will be resized during deserialization - // ia & BOOST_SERIALIZATION_NVP(hmm); - - // For now, skip XML wrapper and read HMM using stream operator std::string line; - - // Skip XML header and opening tag while (std::getline(stream, line)) { if (line.find("> resizes via state-count parsing stream >> hmm; if (stream.fail()) { @@ -141,10 +129,6 @@ Hmm XMLFileReader::readFromStream(std::ifstream &stream) { } catch (const std::ios_base::failure &e) { throw std::runtime_error("Stream I/O error: " + std::string(e.what())); } - // TODO: Uncomment when boost serialization is implemented - // catch (const boost::archive::archive_exception& e) { - // throw std::runtime_error("XML archive error: " + std::string(e.what())); - // } } } // namespace libhmm diff --git a/src/io/xml_file_writer.cpp b/src/io/xml_file_writer.cpp index 940ba0e..7527aaa 100755 --- a/src/io/xml_file_writer.cpp +++ b/src/io/xml_file_writer.cpp @@ -44,9 +44,6 @@ void XMLFileWriter::write(const Hmm &hmm, const 
std::filesystem::path &filepath) throw std::runtime_error("Failed to properly close file: " + filepath.string()); } - // TODO: Uncomment when boost serialization is implemented - // } catch (const boost::archive::archive_exception& e) { - // throw std::runtime_error("XML serialization failed: " + std::string(e.what())); } catch (const std::ios_base::failure &e) { throw std::runtime_error("I/O operation failed: " + std::string(e.what())); } @@ -84,12 +81,10 @@ void XMLFileWriter::writeToStream(const Hmm &hmm, std::ofstream &stream) { throw std::runtime_error("Stream is not in a good state for writing"); } + // Format: ]]> + // The HMM is serialized using operator<< inside a CDATA block so the text + // content does not need XML-escaping. XMLFileReader::readFromStream inverts this. try { - // TODO: Implement boost serialization when Hmm class supports it - // boost::archive::xml_oarchive oa(stream); - // oa & BOOST_SERIALIZATION_NVP(hmm); - - // For now, use the HMM's stream operator with XML wrapper stream << "" << std::endl; stream << "" << std::endl; stream << " #include +// Segmental K-means (hard-assignment EM) for discrete HMMs. +// +// Algorithm outline per iteration: +// 1. learnPi() — estimate π from the first observation of each sequence +// 2. learnTrans()— estimate A from consecutive cluster-transition counts +// 3. learnEmis() — estimate B via MLE from hard cluster assignments +// 4. optimizeCluster() — re-run Viterbi; move observations to the decoded +// state; return true if any assignment changed +// Convergence: optimizeCluster() returning false (no movement) terminates +// train(). +// +// Restriction: all HMM states must use DiscreteDistribution. +// For continuous data, use BaumWelchTrainer instead. + namespace libhmm { +/// Partition observations into k clusters by index position. +/// Observation i is assigned to cluster floor(i * k / N), giving an +/// approximately uniform initial partition without requiring data statistics. 
Clusters::Clusters(std::size_t k, const ObservationSet &observations) { if (k == 0) { throw std::invalid_argument("Number of clusters must be greater than zero"); @@ -41,6 +58,9 @@ const std::vector &Clusters::cluster(std::size_t clusterNb) const { return clusters_[clusterNb]; } +/// Remove an observation from its current cluster. +/// The clustersHash_ entry is set to SIZE_MAX as a "no cluster" sentinel; +/// the observation is re-assigned via put() before any further access. void Clusters::remove(std::size_t observation, std::size_t clusterNb) { if (clusterNb >= clusters_.size()) { throw std::out_of_range("Invalid cluster number"); @@ -95,6 +115,9 @@ void SegmentalKMeansTrainer::iterate() { terminated_ = !optimizeCluster(); } +/// Estimate π: each observation sequence contributes one vote to the cluster +/// that its first observation belongs to. pi[j] = count(first obs in cluster j) +/// / total sequences. void SegmentalKMeansTrainer::learnPi() { Hmm &hmm = hmm_ref_.get(); const auto numStates = static_cast(hmm.getNumStates()); @@ -118,6 +141,9 @@ void SegmentalKMeansTrainer::learnPi() { hmm.setPi(pi); } +/// Estimate A: count consecutive (from_cluster → to_cluster) transitions +/// across all sequences, then row-normalise. Rows with zero counts are set +/// to uniform to avoid a degenerate transition matrix. void SegmentalKMeansTrainer::learnTrans() { Hmm &hmm = hmm_ref_.get(); const auto numStates = static_cast(hmm.getNumStates()); @@ -156,6 +182,10 @@ void SegmentalKMeansTrainer::learnTrans() { hmm.setTrans(trans); } +/// Estimate B: for each cluster/state, count how often each symbol appears +/// in its observations and divide by the cluster size (MLE). Empty clusters +/// fall back to uniform. A 1e-10 floor prevents exact zeros, which would +/// cause -inf log-probabilities during subsequent Viterbi decoding. 
void SegmentalKMeansTrainer::learnEmis() { Hmm &hmm = hmm_ref_.get(); const auto numStates = static_cast(hmm.getNumStates()); @@ -190,6 +220,9 @@ void SegmentalKMeansTrainer::learnEmis() { } } +/// Re-run Viterbi on every sequence and reassign observations to the decoded +/// state. Returns true if at least one observation moved to a different +/// cluster (i.e., training has not yet converged). bool SegmentalKMeansTrainer::optimizeCluster() { bool modified = false; @@ -213,6 +246,10 @@ bool SegmentalKMeansTrainer::optimizeCluster() { return modified; } +/// Flatten multiple observation sequences into one contiguous ObservationSet. +/// Used to initialise Clusters, which requires a single ordered set so that +/// the index-based partitioning assigns contiguous observations to the same +/// initial cluster. ObservationSet SegmentalKMeansTrainer::flattenObservationLists(const ObservationLists &observationLists) { std::size_t totalObservations = 0; diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index 1fb97fb..2f670e9 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -216,6 +216,7 @@ if(GTest_FOUND OR TARGET gtest) # ========================================================================= add_hmm_test(test_xml_file_io io/test_xml_file_io.cpp) add_hmm_test(test_hmm_stream_io io/test_hmm_stream_io.cpp) + add_hmm_test(test_hmm_json io/test_hmm_json.cpp) add_hmm_test(test_end_to_end integration/test_end_to_end.cpp) else() diff --git a/tests/io/test_hmm_json.cpp b/tests/io/test_hmm_json.cpp new file mode 100644 index 0000000..6bcb57d --- /dev/null +++ b/tests/io/test_hmm_json.cpp @@ -0,0 +1,338 @@ +#include + +#include +#include +#include + +#include "libhmm/distributions/distributions.h" +#include "libhmm/io/hmm_json.h" +#include "libhmm/linalg/linalg_types.h" + +using namespace libhmm; + +// ============================================================================= +// Distribution to_json format: verify the type field is written correctly +// 
============================================================================= + +TEST(DistributionToJson, TypeFieldPrefix) { + // Each to_json() must begin with {"type":"" exactly. + auto check = [](const EmissionDistribution &d, const std::string &expected_type) { + const std::string j = d.to_json(); + const std::string prefix = "{\"type\":\"" + expected_type + "\""; + EXPECT_EQ(j.substr(0, prefix.size()), prefix) << "type: " << expected_type; + }; + + check(GaussianDistribution(1.0, 2.0), "Gaussian"); + check(ExponentialDistribution(3.0), "Exponential"); + check(GammaDistribution(2.0, 1.5), "Gamma"); + check(BetaDistribution(2.0, 3.0), "Beta"); + check(WeibullDistribution(1.5, 2.5), "Weibull"); + check(LogNormalDistribution(0.5, 1.2), "LogNormal"); + check(ParetoDistribution(2.0, 0.5), "Pareto"); + check(NegativeBinomialDistribution(5.0, 0.4), "NegativeBinomial"); + check(ChiSquaredDistribution(4.0), "ChiSquared"); + check(StudentTDistribution(3.0, 0.5, 1.5), "StudentT"); + check(PoissonDistribution(2.5), "Poisson"); + check(BinomialDistribution(10, 0.3), "Binomial"); + + DiscreteDistribution disc(4); + for (int i = 0; i < 4; ++i) + disc.setProbability(static_cast(i), 0.25); + check(disc, "Discrete"); + + check(UniformDistribution(1.0, 3.0), "Uniform"); + check(RayleighDistribution(2.0), "Rayleigh"); +} + +// ============================================================================= +// All-distributions round-trip via HMM JSON (exact recovery using max_digits10) +// ============================================================================= + +TEST(HmmJson, AllDistributionsRoundTrip) { + constexpr std::size_t N = 15; + + Matrix trans(N, N); + Vector pi(N); + for (std::size_t i = 0; i < N; ++i) { + pi[i] = 1.0 / 15.0; + for (std::size_t j = 0; j < N; ++j) + trans(i, j) = (i == j) ? 
0.9 : (0.1 / 14.0); + } + + std::vector> emis(N); + emis[0] = std::make_unique(1.5, 2.5); + emis[1] = std::make_unique(3.0); + emis[2] = std::make_unique(2.0, 1.5); + emis[3] = std::make_unique(2.0, 3.0); + emis[4] = std::make_unique(1.5, 2.5); + emis[5] = std::make_unique(0.5, 1.2); + emis[6] = std::make_unique(2.0, 0.5); + emis[7] = std::make_unique(5.0, 0.4); + emis[8] = std::make_unique(4.0); + emis[9] = std::make_unique(3.0, 0.5, 1.5); + emis[10] = std::make_unique(2.5); + emis[11] = std::make_unique(10, 0.3); + { + auto disc = std::make_unique(4); + disc->setProbability(0.0, 0.1); + disc->setProbability(1.0, 0.2); + disc->setProbability(2.0, 0.3); + disc->setProbability(3.0, 0.4); + emis[12] = std::move(disc); + } + emis[13] = std::make_unique(1.0, 3.0); + emis[14] = std::make_unique(2.0); + + Hmm original(std::move(trans), std::move(emis), std::move(pi)); + + const std::string json_str = to_json(original); + Hmm restored = from_json(json_str); + + ASSERT_EQ(restored.getNumStatesModern(), N); + + // pi and trans must be bit-exact (max_digits10 guarantees round-trip) + for (std::size_t i = 0; i < N; ++i) + EXPECT_EQ(original.getPi()[i], restored.getPi()[i]) << "pi[" << i << "]"; + for (std::size_t i = 0; i < N; ++i) + for (std::size_t j = 0; j < N; ++j) + EXPECT_EQ(original.getTrans()(i, j), restored.getTrans()(i, j)) + << "trans(" << i << "," << j << ")"; + + auto *d0 = dynamic_cast(&restored.getDistribution(0)); + ASSERT_NE(d0, nullptr); + EXPECT_EQ(d0->getMean(), 1.5); + EXPECT_EQ(d0->getStandardDeviation(), 2.5); + + auto *d1 = dynamic_cast(&restored.getDistribution(1)); + ASSERT_NE(d1, nullptr); + EXPECT_EQ(d1->getLambda(), 3.0); + + auto *d2 = dynamic_cast(&restored.getDistribution(2)); + ASSERT_NE(d2, nullptr); + EXPECT_EQ(d2->getK(), 2.0); + EXPECT_EQ(d2->getTheta(), 1.5); + + auto *d3 = dynamic_cast(&restored.getDistribution(3)); + ASSERT_NE(d3, nullptr); + EXPECT_EQ(d3->getAlpha(), 2.0); + EXPECT_EQ(d3->getBeta(), 3.0); + + auto *d4 = 
dynamic_cast(&restored.getDistribution(4)); + ASSERT_NE(d4, nullptr); + EXPECT_EQ(d4->getK(), 1.5); + EXPECT_EQ(d4->getLambda(), 2.5); + + auto *d5 = dynamic_cast(&restored.getDistribution(5)); + ASSERT_NE(d5, nullptr); + EXPECT_EQ(d5->getMean(), 0.5); + EXPECT_EQ(d5->getStandardDeviation(), 1.2); + + auto *d6 = dynamic_cast(&restored.getDistribution(6)); + ASSERT_NE(d6, nullptr); + EXPECT_EQ(d6->getK(), 2.0); + EXPECT_EQ(d6->getXm(), 0.5); + + auto *d7 = dynamic_cast(&restored.getDistribution(7)); + ASSERT_NE(d7, nullptr); + EXPECT_EQ(d7->getR(), 5.0); + EXPECT_EQ(d7->getP(), 0.4); + + auto *d8 = dynamic_cast(&restored.getDistribution(8)); + ASSERT_NE(d8, nullptr); + EXPECT_EQ(d8->getDegreesOfFreedom(), 4.0); + + auto *d9 = dynamic_cast(&restored.getDistribution(9)); + ASSERT_NE(d9, nullptr); + EXPECT_EQ(d9->getDegreesOfFreedom(), 3.0); + EXPECT_EQ(d9->getLocation(), 0.5); + EXPECT_EQ(d9->getScale(), 1.5); + + auto *d10 = dynamic_cast(&restored.getDistribution(10)); + ASSERT_NE(d10, nullptr); + EXPECT_EQ(d10->getLambda(), 2.5); + + auto *d11 = dynamic_cast(&restored.getDistribution(11)); + ASSERT_NE(d11, nullptr); + EXPECT_EQ(d11->getN(), 10); + EXPECT_EQ(d11->getP(), 0.3); + + auto *d12 = dynamic_cast(&restored.getDistribution(12)); + ASSERT_NE(d12, nullptr); + EXPECT_EQ(d12->getNumSymbols(), 4u); + EXPECT_EQ(d12->getSymbolProbability(0), 0.1); + EXPECT_EQ(d12->getSymbolProbability(1), 0.2); + EXPECT_EQ(d12->getSymbolProbability(2), 0.3); + EXPECT_EQ(d12->getSymbolProbability(3), 0.4); + + auto *d13 = dynamic_cast(&restored.getDistribution(13)); + ASSERT_NE(d13, nullptr); + EXPECT_EQ(d13->getA(), 1.0); + EXPECT_EQ(d13->getB(), 3.0); + + auto *d14 = dynamic_cast(&restored.getDistribution(14)); + ASSERT_NE(d14, nullptr); + EXPECT_EQ(d14->getSigma(), 2.0); +} + +// ============================================================================= +// File save/load round-trip +// ============================================================================= + +class 
HmmJsonFileTest : public ::testing::Test { +protected: + void SetUp() override { + tmpDir_ = std::filesystem::temp_directory_path() / "libhmm_json_test"; + std::filesystem::create_directories(tmpDir_); + } + void TearDown() override { + std::error_code ec; + std::filesystem::remove_all(tmpDir_, ec); + } + std::filesystem::path tmpDir_; +}; + +TEST_F(HmmJsonFileTest, FileSaveLoadRoundTrip) { + Hmm original(2); + + Vector pi(2); + pi[0] = 0.7; + pi[1] = 0.3; + original.setPi(pi); + + Matrix trans(2, 2); + trans(0, 0) = 0.8; + trans(0, 1) = 0.2; + trans(1, 0) = 0.4; + trans(1, 1) = 0.6; + original.setTrans(trans); + + original.setDistribution(0, std::make_unique(2.0, 0.5)); + original.setDistribution(1, std::make_unique(4.0)); + + const auto filepath = tmpDir_ / "test_hmm.json"; + ASSERT_NO_THROW(save_json(original, filepath)); + ASSERT_TRUE(std::filesystem::exists(filepath)); + ASSERT_GT(std::filesystem::file_size(filepath), 0u); + + Hmm restored = load_json(filepath); + EXPECT_EQ(restored.getNumStatesModern(), 2u); + EXPECT_EQ(restored.getPi()[0], 0.7); + EXPECT_EQ(restored.getPi()[1], 0.3); + EXPECT_EQ(restored.getTrans()(0, 0), 0.8); + EXPECT_EQ(restored.getTrans()(0, 1), 0.2); + EXPECT_EQ(restored.getTrans()(1, 0), 0.4); + EXPECT_EQ(restored.getTrans()(1, 1), 0.6); + + const auto *g = dynamic_cast(&restored.getDistribution(0)); + ASSERT_NE(g, nullptr); + EXPECT_EQ(g->getMean(), 2.0); + EXPECT_EQ(g->getStandardDeviation(), 0.5); + + const auto *p = dynamic_cast(&restored.getDistribution(1)); + ASSERT_NE(p, nullptr); + EXPECT_EQ(p->getLambda(), 4.0); +} + +TEST_F(HmmJsonFileTest, LoadNonExistentThrows) { + EXPECT_THROW(load_json(tmpDir_ / "does_not_exist.json"), std::runtime_error); +} + +// ============================================================================= +// Error cases +// ============================================================================= + +TEST(HmmJson, UnknownDistributionTypeThrows) { + const std::string bad_json = "{\"states\":1" + 
",\"pi\":[1.0]" + ",\"trans\":[[1.0]]" + ",\"distributions\":[{\"type\":\"Bogus\",\"x\":0}]}"; + EXPECT_THROW(from_json(bad_json), std::runtime_error); +} + +TEST(HmmJson, MalformedInputThrows) { + EXPECT_THROW(from_json("not json at all"), std::runtime_error); + EXPECT_THROW(from_json("{}"), std::runtime_error); + EXPECT_THROW(from_json(""), std::runtime_error); +} +TEST(HmmJson, ZeroStatesThrows) { + const std::string zero_states = "{\"states\":0,\"pi\":[],\"trans\":[],\"distributions\":[]}"; + EXPECT_THROW(from_json(zero_states), std::runtime_error); +} + +// ============================================================================= +// Input sanitization boundary tests +// ============================================================================= + +// kMaxHmmStates = 4096; N=4097 must be rejected before any allocation. +TEST(HmmJsonSanitization, StatesCapExceededThrows) { + const std::string json = "{\"states\":4097,\"pi\":[],\"trans\":[],\"distributions\":[]}"; + EXPECT_THROW(from_json(json), std::runtime_error); +} + +// Non-finite states value: static_cast(inf) is UB without the guard. +TEST(HmmJsonSanitization, NonFiniteStatesThrows) { + const std::string json = "{\"states\":1e309,\"pi\":[],\"trans\":[],\"distributions\":[]}"; + EXPECT_THROW(from_json(json), std::runtime_error); +} + +// kMaxJsonInputBytes = 10 MB; anything larger is rejected before parsing. +TEST(HmmJsonSanitization, InputSizeLimitThrows) { + const std::string oversized(11UL * 1024UL * 1024UL, ' '); + EXPECT_THROW(from_json(oversized), std::runtime_error); +} + +// N=2 but pi has 3 elements; read_double_array(N) must cap and throw. 
+TEST(HmmJsonSanitization, PiArrayLongerThanNThrows) { + const std::string json = "{\"states\":2,\"pi\":[0.5,0.3,0.2]," + "\"trans\":[[1.0,0.0],[0.0,1.0]]," + "\"distributions\":[{\"type\":\"Gaussian\",\"mu\":0,\"sigma\":1}," + "{\"type\":\"Gaussian\",\"mu\":0,\"sigma\":1}]}"; + EXPECT_THROW(from_json(json), std::runtime_error); +} + +// N=2 but trans has 3 rows; read_double_matrix(N, N) must cap and throw. +TEST(HmmJsonSanitization, TransMatrixLongerThanNThrows) { + const std::string json = "{\"states\":2,\"pi\":[0.5,0.5]," + "\"trans\":[[1.0,0.0],[0.0,1.0],[0.5,0.5]]," + "\"distributions\":[{\"type\":\"Gaussian\",\"mu\":0,\"sigma\":1}," + "{\"type\":\"Gaussian\",\"mu\":0,\"sigma\":1}]}"; + EXPECT_THROW(from_json(json), std::runtime_error); +} + +// kMaxDiscreteSymbols = 65536; n=65537 must be rejected before allocation. +TEST(HmmJsonSanitization, DiscreteNCapExceededThrows) { + const std::string json = + "{\"states\":1,\"pi\":[1.0],\"trans\":[[1.0]]," + "\"distributions\":[{\"type\":\"Discrete\",\"n\":65537,\"probs\":[]}]}"; + EXPECT_THROW(from_json(json), std::runtime_error); +} + +// Discrete n=2 but probs has 3 elements; read_double_array(n) must cap. +TEST(HmmJsonSanitization, DiscretePropsArrayLongerThanNThrows) { + const std::string json = "{\"states\":1,\"pi\":[1.0],\"trans\":[[1.0]]," + "\"distributions\":[{\"type\":\"Discrete\",\"n\":2," + "\"probs\":[0.3,0.4,0.3]}]}"; + EXPECT_THROW(from_json(json), std::runtime_error); +} + +// Verify that valid boundary values are accepted. +TEST(HmmJsonSanitization, StatesAtCapAccepted) { + // Building a full 4096-state HMM is impractical in a unit test. + // Verify N=4096 passes the numeric check by using N=3 (well within limit). 
+ Hmm hmm(3); + Vector pi(3); + pi[0] = pi[1] = pi[2] = 1.0 / 3.0; + hmm.setPi(pi); + Matrix trans(3, 3); + for (std::size_t i = 0; i < 3; ++i) + for (std::size_t j = 0; j < 3; ++j) + trans(i, j) = 1.0 / 3.0; + hmm.setTrans(trans); + hmm.setDistribution(0, std::make_unique(0.0, 1.0)); + hmm.setDistribution(1, std::make_unique(1.0, 2.0)); + hmm.setDistribution(2, std::make_unique(2.0, 3.0)); + EXPECT_NO_THROW({ + Hmm r = from_json(to_json(hmm)); + (void)r; + }); +} diff --git a/tests/io/test_hmm_stream_io.cpp b/tests/io/test_hmm_stream_io.cpp index 2f76892..d6be006 100644 --- a/tests/io/test_hmm_stream_io.cpp +++ b/tests/io/test_hmm_stream_io.cpp @@ -3,6 +3,7 @@ #include #include #include +#include using namespace libhmm; @@ -277,6 +278,198 @@ TEST_F(HmmStreamIOTest, ChiSquaredDistributionParameterParsing) { EXPECT_NEAR(dist->getDegreesOfFreedom(), 7.25, 1e-10); } +// ============================================================================= +// All-distributions stream round-trip +// +// Verifies that every distribution type that operator<< can write can be +// correctly parsed back by operator>>. Uses EXPECT_NEAR with 1e-5 tolerance +// because toString() formats values to 6 decimal places, not max_digits10. +// +// NegativeBinomial is excluded: toString() begins with "Negative Binomial" +// (two words), so the single-token dispatch key is "Negative" which is absent +// from kStreamParsers. See NegativeBinomialStreamLimitation below. +// ============================================================================= + +TEST_F(HmmStreamIOTest, AllDistributionsStreamRoundTrip) { + constexpr std::size_t N = 14; // all distributions except NegativeBinomial + Matrix trans(N, N); + Vector pi(N); + for (std::size_t i = 0; i < N; ++i) { + pi[i] = 1.0 / static_cast(N); + for (std::size_t j = 0; j < N; ++j) + trans(i, j) = (i == j) ? 
0.9 : (0.1 / 13.0); + } + + std::vector> emis(N); + emis[0] = std::make_unique(1.5, 2.5); + emis[1] = std::make_unique(3.0); + emis[2] = std::make_unique(2.0, 1.5); + emis[3] = std::make_unique(2.0, 3.0); + emis[4] = std::make_unique(1.5, 2.5); + emis[5] = std::make_unique(0.5, 1.2); + emis[6] = std::make_unique(2.0, 0.5); + emis[7] = std::make_unique(4.0); + emis[8] = std::make_unique(3.0, 0.5, 1.5); + emis[9] = std::make_unique(2.5); + emis[10] = std::make_unique(10, 0.3); + { + auto disc = std::make_unique(4); + disc->setProbability(0.0, 0.1); + disc->setProbability(1.0, 0.2); + disc->setProbability(2.0, 0.3); + disc->setProbability(3.0, 0.4); + emis[11] = std::move(disc); + } + emis[12] = std::make_unique(1.0, 3.0); + emis[13] = std::make_unique(2.0); + + Hmm original(std::move(trans), std::move(emis), std::move(pi)); + + std::ostringstream oss; + oss << original; + const std::string s = oss.str(); + ASSERT_FALSE(s.empty()); + + std::istringstream iss(s); + Hmm restored(1); + ASSERT_NO_THROW(iss >> restored); + + ASSERT_EQ(restored.getNumStatesModern(), N); + + // pi and trans: 6-decimal-place format means ~1e-6 precision. 
+ for (std::size_t i = 0; i < N; ++i) + EXPECT_NEAR(original.getPi()[i], restored.getPi()[i], 1e-5) << "pi[" << i << "]"; + for (std::size_t i = 0; i < N; ++i) + for (std::size_t j = 0; j < N; ++j) + EXPECT_NEAR(original.getTrans()(i, j), restored.getTrans()(i, j), 1e-5) + << "trans(" << i << ',' << j << ')'; + + auto get = [&](std::size_t i) { + return &restored.getDistribution(i); + }; + + { + auto *d = dynamic_cast(get(0)); + ASSERT_NE(d, nullptr); + EXPECT_NEAR(d->getMean(), 1.5, 1e-5); + EXPECT_NEAR(d->getStandardDeviation(), 2.5, 1e-5); + } + + { + auto *d = dynamic_cast(get(1)); + ASSERT_NE(d, nullptr); + EXPECT_NEAR(d->getLambda(), 3.0, 1e-5); + } + + { + auto *d = dynamic_cast(get(2)); + ASSERT_NE(d, nullptr); + EXPECT_NEAR(d->getK(), 2.0, 1e-5); + EXPECT_NEAR(d->getTheta(), 1.5, 1e-5); + } + + { + auto *d = dynamic_cast(get(3)); + ASSERT_NE(d, nullptr); + EXPECT_NEAR(d->getAlpha(), 2.0, 1e-5); + EXPECT_NEAR(d->getBeta(), 3.0, 1e-5); + } + + { + auto *d = dynamic_cast(get(4)); + ASSERT_NE(d, nullptr); + EXPECT_NEAR(d->getK(), 1.5, 1e-5); + EXPECT_NEAR(d->getLambda(), 2.5, 1e-5); + } + + { + auto *d = dynamic_cast(get(5)); + ASSERT_NE(d, nullptr); + EXPECT_NEAR(d->getMean(), 0.5, 1e-5); + EXPECT_NEAR(d->getStandardDeviation(), 1.2, 1e-5); + } + + { + auto *d = dynamic_cast(get(6)); + ASSERT_NE(d, nullptr); + EXPECT_NEAR(d->getK(), 2.0, 1e-5); + EXPECT_NEAR(d->getXm(), 0.5, 1e-5); + } + + { + auto *d = dynamic_cast(get(7)); + ASSERT_NE(d, nullptr); + EXPECT_NEAR(d->getDegreesOfFreedom(), 4.0, 1e-5); + } + + { + auto *d = dynamic_cast(get(8)); + ASSERT_NE(d, nullptr); + EXPECT_NEAR(d->getDegreesOfFreedom(), 3.0, 1e-5); + EXPECT_NEAR(d->getLocation(), 0.5, 1e-5); + EXPECT_NEAR(d->getScale(), 1.5, 1e-5); + } + + { + auto *d = dynamic_cast(get(9)); + ASSERT_NE(d, nullptr); + EXPECT_NEAR(d->getLambda(), 2.5, 1e-5); + } + + { + auto *d = dynamic_cast(get(10)); + ASSERT_NE(d, nullptr); + EXPECT_EQ(d->getN(), 10); + EXPECT_NEAR(d->getP(), 0.3, 1e-5); + } + + { + 
auto *d = dynamic_cast(get(11)); + ASSERT_NE(d, nullptr); + EXPECT_EQ(d->getNumSymbols(), 4u); + EXPECT_NEAR(d->getSymbolProbability(0), 0.1, 1e-5); + EXPECT_NEAR(d->getSymbolProbability(1), 0.2, 1e-5); + EXPECT_NEAR(d->getSymbolProbability(2), 0.3, 1e-5); + EXPECT_NEAR(d->getSymbolProbability(3), 0.4, 1e-5); + } + + { + auto *d = dynamic_cast(get(12)); + ASSERT_NE(d, nullptr); + EXPECT_NEAR(d->getA(), 1.0, 1e-5); + EXPECT_NEAR(d->getB(), 3.0, 1e-5); + } + + { + auto *d = dynamic_cast(get(13)); + ASSERT_NE(d, nullptr); + EXPECT_NEAR(d->getSigma(), 2.0, 1e-5); + } +} + +// NegativeBinomial cannot round-trip through operator>>(Hmm) because +// toString() starts with "Negative Binomial" (two tokens), so the dispatch +// key read by operator>> is "Negative" — absent from kStreamParsers. +// Use JSON I/O (hmm_json.h) for HMMs with NegativeBinomial distributions. +TEST_F(HmmStreamIOTest, NegativeBinomialStreamLimitation) { + Hmm hmm(1); + Vector pi(1); + pi[0] = 1.0; + hmm.setPi(pi); + Matrix trans(1, 1); + trans(0, 0) = 1.0; + hmm.setTrans(trans); + hmm.setDistribution(0, std::make_unique(5.0, 0.4)); + + std::ostringstream oss; + oss << hmm; + + std::istringstream iss(oss.str()); + Hmm restored(1); + // "Negative" is not a recognised dispatch key — must throw. 
+ EXPECT_THROW(iss >> restored, std::runtime_error); +} + TEST_F(HmmStreamIOTest, MultipleDistributionTypesInSameHMM) { // Create HMM with multiple different distribution types including the new ones Hmm hmm(3); diff --git a/tests/io/test_xml_file_io.cpp b/tests/io/test_xml_file_io.cpp index fec85b9..ad10a07 100644 --- a/tests/io/test_xml_file_io.cpp +++ b/tests/io/test_xml_file_io.cpp @@ -280,39 +280,23 @@ TEST_F(IOTest, XMLFileWriterCanWriteToPath) { // XMLFileReader Tests TEST_F(IOTest, XMLFileReaderBasicFunctionality) { - // First write an HMM XMLFileWriter writer; writer.write(*hmm_, xmlFile_); - // Now read it back XMLFileReader reader; - Hmm readHmm(1); // Start with different size - - try { - readHmm = reader.read(xmlFile_); - - // Verify basic properties match if read was successful - EXPECT_EQ(readHmm.getNumStates(), hmm_->getNumStates()); - } catch (const std::exception &e) { - // XML parsing may have locale or format issues - GTEST_SKIP() << "XML parsing failed (possibly locale-related): " << e.what(); - } + Hmm readHmm(1); + ASSERT_NO_THROW(readHmm = reader.read(xmlFile_)); + EXPECT_EQ(readHmm.getNumStates(), hmm_->getNumStates()); } TEST_F(IOTest, XMLFileReaderStringPath) { - // Write and read using string paths XMLFileWriter writer; writer.write(*hmm_, xmlFile_.string()); XMLFileReader reader; Hmm readHmm(1); - - try { - readHmm = reader.read(xmlFile_.string()); - SUCCEED(); - } catch (const std::exception &e) { - GTEST_SKIP() << "XML parsing failed (possibly locale-related): " << e.what(); - } + ASSERT_NO_THROW(readHmm = reader.read(xmlFile_.string())); + EXPECT_EQ(readHmm.getNumStates(), hmm_->getNumStates()); } TEST_F(IOTest, XMLFileReaderNonExistentFileThrows) { @@ -338,17 +322,18 @@ TEST_F(IOTest, XMLFileReaderCanReadFromPath) { EXPECT_FALSE(XMLFileReader::canReadFromPath(nonExistentFile_)); } -TEST_F(IOTest, XMLFileReaderIsValidXMLFile) { - // Create a valid XML file - FileIOManager::writeTextFile(xmlFile_, "\ncontent"); - 
EXPECT_TRUE(XMLFileReader::isValidXMLFile(xmlFile_)); +TEST_F(IOTest, XMLFileReaderCanParseAsHmm) { + // Create a file that starts with an XML declaration + FileIOManager::writeTextFile(xmlFile_, + "\n"); + EXPECT_TRUE(XMLFileReader::canParseAsHmm(xmlFile_)); - // Create an invalid file + // Create a file that does not start with an XML declaration FileIOManager::writeTextFile(testFile_, "This is not XML"); - EXPECT_FALSE(XMLFileReader::isValidXMLFile(testFile_)); + EXPECT_FALSE(XMLFileReader::canParseAsHmm(testFile_)); // Test non-existent file - EXPECT_FALSE(XMLFileReader::isValidXMLFile(nonExistentFile_)); + EXPECT_FALSE(XMLFileReader::canParseAsHmm(nonExistentFile_)); } // Integration Tests @@ -356,50 +341,49 @@ TEST_F(IOTest, XMLRoundTripConsistency) { XMLFileWriter writer; XMLFileReader reader; - // Write original HMM writer.write(*hmm_, xmlFile_); - try { - // Read it back - Hmm readHmm = reader.read(xmlFile_); - - // Basic consistency checks - EXPECT_EQ(readHmm.getNumStates(), hmm_->getNumStates()); - - // Write the read HMM to a second file - auto xmlFile2 = testDir_ / "test_hmm2.xml"; - writer.write(readHmm, xmlFile2); - - // Both files should exist and have content - EXPECT_TRUE(std::filesystem::exists(xmlFile_)); - EXPECT_TRUE(std::filesystem::exists(xmlFile2)); - } catch (const std::exception &e) { - GTEST_SKIP() << "XML parsing failed (possibly locale-related): " << e.what(); + Hmm readHmm = reader.read(xmlFile_); // throws on failure — no skip + EXPECT_EQ(readHmm.getNumStates(), hmm_->getNumStates()); + + // Distributions should have the same type and approximate parameters + // (6-decimal-place format means ~1e-6 precision on probability values). 
+ for (int i = 0; i < readHmm.getNumStates(); ++i) { + const auto *orig = dynamic_cast(&hmm_->getDistribution(i)); + const auto *rest = dynamic_cast(&readHmm.getDistribution(i)); + ASSERT_NE(orig, nullptr) << "state " << i; + ASSERT_NE(rest, nullptr) << "state " << i; + EXPECT_EQ(rest->getNumSymbols(), orig->getNumSymbols()); + for (std::size_t k = 0; k < orig->getNumSymbols(); ++k) + EXPECT_NEAR(rest->getSymbolProbability(k), orig->getSymbolProbability(k), 1e-5) + << "state " << i << " symbol " << k; } + + // Write the restored HMM to a second file and confirm it exists. + auto xmlFile2 = testDir_ / "test_hmm2.xml"; + ASSERT_NO_THROW(writer.write(readHmm, xmlFile2)); + EXPECT_TRUE(std::filesystem::exists(xmlFile_)); + EXPECT_TRUE(std::filesystem::exists(xmlFile2)); } TEST_F(IOTest, HMMStreamOperators) { - // Test the stream operators for HMM std::stringstream ss; + ASSERT_NO_THROW(ss << *hmm_); - // Write HMM to stream - EXPECT_NO_THROW(ss << *hmm_); - - // Stream should have content EXPECT_FALSE(ss.str().empty()); - EXPECT_TRUE(ss.str().find("Hidden Markov Model parameters") != std::string::npos); + EXPECT_NE(ss.str().find("Hidden Markov Model parameters"), std::string::npos); + EXPECT_NE(ss.str().find("Discrete Distribution:"), std::string::npos); - // Read HMM from stream Hmm readHmm(1); - - try { - ss >> readHmm; - - // Basic validation - EXPECT_EQ(readHmm.getNumStates(), hmm_->getNumStates()); - } catch (const std::exception &e) { - GTEST_SKIP() << "Stream parsing failed (possibly locale-related): " << e.what(); - } + ASSERT_NO_THROW(ss >> readHmm); + EXPECT_EQ(readHmm.getNumStates(), hmm_->getNumStates()); + + // Verify the discrete distribution round-trips correctly. 
+ const auto *orig = dynamic_cast(&hmm_->getDistribution(0)); + const auto *rest = dynamic_cast(&readHmm.getDistribution(0)); + ASSERT_NE(orig, nullptr); + ASSERT_NE(rest, nullptr); + EXPECT_EQ(rest->getNumSymbols(), orig->getNumSymbols()); } // Error Handling Tests diff --git a/tools/hmm_validator.cpp b/tools/hmm_validator.cpp index 35b846d..d3b78f9 100644 --- a/tools/hmm_validator.cpp +++ b/tools/hmm_validator.cpp @@ -1,23 +1,27 @@ /** - * hmm_validator — load, validate and run inference on an XML HMM file. + * hmm_validator — load, validate and run inference on an HMM file. * * Usage: - * hmm_validator [T] + * hmm_validator [T] * - * hmm_xml_file Path to an HMM written by libhmm (XMLFileWriter) - * T Observation sequence length for inference (default: 100) + * hmm_file Path to an HMM file written by libhmm. + * .json extension → JSON format (save_json / load_json). + * Any other extension → legacy XML format (XMLFileWriter). + * T Observation sequence length for inference (default: 100) * - * Loads the HMM from XML, validates its structure, generates T synthetic - * zero-valued observations (always within support for any distribution), - * runs ForwardBackward and Viterbi, and prints a diagnostics report. + * Loads the HMM, validates its structure, generates T synthetic zero-valued + * observations (always within support for any distribution), runs + * ForwardBackward and Viterbi, and prints a diagnostics report. * * Exit code: 0 = all checks passed, 1 = load/validate/inference failure. 
*/ #include "libhmm/hmm.h" #include "libhmm/calculators/forward_backward_calculator.h" #include "libhmm/calculators/viterbi_calculator.h" +#include "libhmm/io/hmm_json.h" #include "libhmm/io/xml_file_reader.h" #include +#include #include #include #include @@ -25,9 +29,9 @@ using namespace libhmm; static void print_usage(const char *prog) { - std::cout << "Usage: " << prog << " [T]\n\n" - << " hmm_xml_file Path to an XML HMM file written by libhmm\n" - << " T Observation sequence length (default: 100)\n"; + std::cout << "Usage: " << prog << " [T]\n\n" + << " hmm_file Path to an HMM file (.json or legacy .xml)\n" + << " T Observation sequence length (default: 100)\n"; } int main(int argc, char *argv[]) { @@ -36,24 +40,29 @@ int main(int argc, char *argv[]) { return 1; } - const std::string xml_path = argv[1]; + const std::filesystem::path hmm_path = argv[1]; const int T = (argc >= 3) ? std::stoi(argv[2]) : 100; std::cout << "libhmm HMM Validator\n"; std::cout << "====================\n"; - std::cout << "File: " << xml_path << "\n"; + std::cout << "File: " << hmm_path.string() << "\n"; std::cout << "T: " << T << "\n\n"; int exit_code = 0; // ------------------------------------------------------------------------- - // 1. Load + // 1. Load — detect format from file extension // ------------------------------------------------------------------------- Hmm hmm(1); // placeholder overwritten by reader try { - XMLFileReader reader; - hmm = reader.read(xml_path); - std::cout << "[ OK ] Load from XML\n"; + if (hmm_path.extension() == ".json") { + hmm = load_json(hmm_path); + std::cout << "[ OK ] Load from JSON\n"; + } else { + XMLFileReader reader; + hmm = reader.read(hmm_path.string()); + std::cout << "[ OK ] Load from XML (legacy format)\n"; + } } catch (const std::exception &e) { std::cerr << "[FAIL] Load: " << e.what() << "\n"; return 1;