From 32ce95d295119ebd47fce4c5eeda2596bcf5a352 Mon Sep 17 00:00:00 2001 From: Chad Cotton Date: Wed, 7 May 2025 13:14:05 -0500 Subject: [PATCH] Added random, uniform abstraction, and forEach. --- include/LinearLib/Matrix.hpp | 63 ++++++++++++++++++++++++++++++++++-- tests/MatrixTests.cpp | 50 +++++++++++++++++++++++++--- 2 files changed, 106 insertions(+), 7 deletions(-) diff --git a/include/LinearLib/Matrix.hpp b/include/LinearLib/Matrix.hpp index 531077b..659f4d7 100644 --- a/include/LinearLib/Matrix.hpp +++ b/include/LinearLib/Matrix.hpp @@ -4,6 +4,7 @@ #include #include #include +#include #include #include @@ -50,11 +51,43 @@ namespace LinearLib { } static Matrix zeros() { + return uniform(T{0}); + } + + static Matrix ones() { + return uniform(T{1}); + } + + static Matrix uniform(T const val) { Matrix res; for (std::size_t i = 0; i < R; i++) { for (std::size_t j = 0; j < C; j++) { - res.data[i][j] = T{0}; + res.data[i][j] = val; + } + } + + return res; + } + + static Matrix random(T const min, T const max, std::size_t const seed = 0) { + Matrix res; + + std::mt19937_64 rng(seed); + + if constexpr (std::is_integral_v) { + std::uniform_int_distribution dist(min, max); + for (std::size_t i = 0; i < R; i++) { + for (std::size_t j = 0; j < C; j++) { + res.data[i][j] = dist(rng); + } + } + } else if constexpr (std::is_floating_point_v) { + std::uniform_real_distribution dist(min, max); + for (std::size_t i = 0; i < R; i++) { + for (std::size_t j = 0; j < C; j++) { + res.data[i][j] = dist(rng); + } } } @@ -182,6 +215,31 @@ namespace LinearLib { return data; } + + void forEach(const std::function& func) { + for (std::size_t i = 0; i < R; i++) { + for (std::size_t j = 0; j < C; j++) { + func(); + } + } + } + + void forEach(const std::function& func) { + for (std::size_t i = 0; i < R; i++) { + for (std::size_t j = 0; j < C; j++) { + func(data[i][j]); + } + } + } + + void forEach(const std::function func) { + for (std::size_t i = 0; i < R; i++) { + for (std::size_t j = 0; j < C; j++) { + func(data[i][j], i, j); + } + } + } + bool operator==(const Matrix& other) const { for (std::size_t i = 0; i < R; i++) { for (std::size_t j = 0; j < C; j++) { @@ -237,7 +295,7 @@ namespace LinearLib { Matrix res; for (std::size_t i = 0; i < R; i++) { - for (std::size_t j = 0; j < C; j++) { + for (std::size_t j = 0; j < I; j++) { T sum = T{}; for (std::size_t k = 0; k < C; k++) { sum += data[i][k] * other[k][j]; @@ -284,5 +342,6 @@ namespace LinearLib { return res; } + }; } diff --git a/tests/MatrixTests.cpp b/tests/MatrixTests.cpp index aa2a441..da6acf7 100644 --- a/tests/MatrixTests.cpp +++ b/tests/MatrixTests.cpp @@ -27,12 +27,12 @@ TEST_CASE("Matrix operations", "[vector]") { const Matrix<2, 2, int> m2 {{3, 6}, {7, 2}}; - const Matrix<2, 2, int> result = m1 + m2; + const Matrix<2, 2, int> result = m1 - m2; - REQUIRE(result[0][0] == 8); - REQUIRE(result[0][1] == 9); - REQUIRE(result[1][0] == 9); - REQUIRE(result[1][1] == 6); + REQUIRE(result[0][0] == 2); + REQUIRE(result[0][1] == -3); + REQUIRE(result[1][0] == -5); + REQUIRE(result[1][1] == 2); } SECTION("Matrix dot product") { @@ -197,4 +197,44 @@ TEST_CASE("Matrix operations", "[vector]") { REQUIRE(eye2[3][2] == 0); REQUIRE(eye2[3][3] == 0); } + + SECTION("Random") { + size_t seed = 42; + const Matrix<2, 2, int> m1 = Matrix<2, 2, int>::random(1, 10, seed); + + REQUIRE(m1[0][0] == 8); + REQUIRE(m1[0][1] == 7); + REQUIRE(m1[1][0] == 8); + REQUIRE(m1[1][1] == 2); + + const Matrix<2, 2, float> m2 = Matrix<2, 2, float>::random(0.0, 100.0, seed); + REQUIRE(m2[0][0] == 75.515548706f); + REQUIRE(m2[0][1] == 63.903141022f); + REQUIRE(m2[1][0] == 75.214515686f); + REQUIRE(m2[1][1] == 13.627268791f); + } + + SECTION("ForEach") { + Matrix<2, 2, int> m {{1, 2}, {3, 4} }; + + m.forEach([](int& value) { + value *= 2; + }); + + REQUIRE(m[0][0] == 2); + REQUIRE(m[0][1] == 4); + REQUIRE(m[1][0] == 6); + REQUIRE(m[1][1] == 8); + + Matrix<2, 2, float> m2 = Matrix<2, 2, float>::uniform(10); + + m2.forEach([](float& value, std::size_t row, std::size_t col) { + value = row + col; + }); + + REQUIRE(m2[0][0] == 0.0f); + REQUIRE(m2[0][1] == 1.0f); + REQUIRE(m2[1][0] == 1.0f); + REQUIRE(m2[1][1] == 2.0f); + } } \ No newline at end of file