diff --git a/include/LinearLib/Matrix.hpp b/include/LinearLib/Matrix.hpp index 3c4306e..531077b 100644 --- a/include/LinearLib/Matrix.hpp +++ b/include/LinearLib/Matrix.hpp @@ -26,12 +26,19 @@ namespace LinearLib { Matrix() = default; static Matrix identity() { + return eye(0); + } + + static Matrix eye(int const offset) { + + assert(offset < static_cast(C) && "Offset absolute value must be less than matrix dimension"); + assert(offset > -static_cast(C) && "Offset absolute value must be less than matrix dimension"); Matrix res; for (std::size_t i = 0; i < R; i++) { for (std::size_t j = 0; j < C; j++) { - if (i == j) { + if (i + offset == j) { res.data[i][j] = T{1}; } else { res.data[i][j] = T{0}; @@ -42,6 +49,18 @@ namespace LinearLib { return res; } + static Matrix zeros() { + Matrix res; + + for (std::size_t i = 0; i < R; i++) { + for (std::size_t j = 0; j < C; j++) { + res.data[i][j] = T{0}; + } + } + + return res; + } + template Matrix splice(std::ranges::range auto& rows, std::ranges::range auto& cols) const { assert(std::ranges::distance(rows) == SubR && std::format("Row range must dimensionally cover %i rows", SubR).c_str()); @@ -99,7 +118,8 @@ namespace LinearLib { } T determinant() const { - assert(R == C && "Determinant is only defined for square matrices"); + + assert(isSquare() && "Determinant is only defined for square matrices"); // Base case 1 if (R == 1) { @@ -137,6 +157,27 @@ namespace LinearLib { return res; } + [[nodiscard]] static bool isSquare() { + return R == C; + } + + [[nodiscard]] bool isSymmetric() const { + + if (!isSquare()) { + return false; + } + + for (std::size_t i = 0; i < R; i++) { + for (std::size_t j = 0; j < C; j++) { + if (data[i][j] != data[j][i]) { + return false; + } + } + } + + return true; + } + std::array, R> getData() const { return data; } @@ -187,23 +228,30 @@ namespace LinearLib { return res; } + /** + * Matrix dot product + */ template Matrix operator*(const Matrix& other) const { + Matrix res; for (std::size_t i = 0; i < R; i++) { - for (std::size_t j = 0; j < I; j++) { + for (std::size_t j = 0; j < C; j++) { T sum = T{}; for (std::size_t k = 0; k < C; k++) { - sum += data[i][k] * other.data[k][j]; + sum += data[i][k] * other[k][j]; } - res.data[i][j] = sum; + res[i][j] = sum; } } return res; } + /** + * Scalar Multiplication + */ Matrix operator*(const T& scalar) const { Matrix res; @@ -215,5 +263,26 @@ namespace LinearLib { return res; } + + /** + * Matrix Multiplication + */ + template + Matrix operator&(const Matrix& other) const { + + Matrix res; + + for (std::size_t i = 0; i < R; i++) { + for (std::size_t j = 0; j < I; j++) { + T sum = T{}; + for (std::size_t k = 0; k < C; k++) { + sum += data[i][k] * other.data[k][j]; + } + res.data[i][j] = sum; + } + } + + return res; + } }; } diff --git a/include/LinearLib/Vector.hpp b/include/LinearLib/Vector.hpp index dd1ce18..f8a9f8d 100644 --- a/include/LinearLib/Vector.hpp +++ b/include/LinearLib/Vector.hpp @@ -5,6 +5,8 @@ #include #include +#include "Matrix.hpp" + namespace LinearLib { template requires std::is_arithmetic_v @@ -20,6 +22,26 @@ namespace LinearLib { // Default constructor is still needed Vector() = default; + T magnitude() const { + T res = 0; + + for (std::size_t i = 0; i < N; i++) { + res += std::pow(data[i], 2); + } + + return std::sqrt(res); + } + + Matrix<1, N, T> asMatrix() { + Matrix<1, N, T> res; + + for (std::size_t i = 0; i < N; i++) { + res[0][i] = data[i]; + } + + return res; + } + std::array getData() const { return data; } @@ -101,4 +123,4 @@ namespace LinearLib { return res; } }; -} \ No newline at end of file +} diff --git a/tests/MatrixTests.cpp b/tests/MatrixTests.cpp index bff4913..aa2a441 100644 --- a/tests/MatrixTests.cpp +++ b/tests/MatrixTests.cpp @@ -35,6 +35,33 @@ TEST_CASE("Matrix operations", "[vector]") { REQUIRE(result[1][1] == 6); } + SECTION("Matrix dot product") { + const Matrix<2, 2, int> m1 {{1, 2},{3,4}}; + const Matrix<2, 2, int> m2 {{5, 6},{7,8}}; + + const Matrix<2, 2, int> result1 = m1 * m2; + + REQUIRE(result1[0][0] == 19); + REQUIRE(result1[0][1] == 22); + REQUIRE(result1[1][0] == 43); + REQUIRE(result1[1][1] == 50); + + const Matrix<3, 3, int> m3 {{1, 1, 3}, + {-1, 4, 1}, + {0, 2, -2}}; + const Matrix<3, 1, int> m4 {{1}, + {0}, + {-1}}; + + const Matrix<3, 1, int> result2 = m3 * m4; + + REQUIRE(result2[0][0] == -2); + REQUIRE(result2[1][0] == -2); + REQUIRE(result2[2][0] == 2); + + + } + SECTION("Matrix multiplication") { const Matrix<2, 4, int> m1 {{5, 3, 3, 4}, {2, 4, 4, 3}}; @@ -44,7 +71,7 @@ TEST_CASE("Matrix operations", "[vector]") { {3, 2}, {7, 1}}; - const Matrix<2, 2, int> result = m1 * m2; + const Matrix<2, 2, int> result = m1 & m2; REQUIRE(result[0][0] == 67); REQUIRE(result[0][1] == 46); @@ -118,4 +145,56 @@ TEST_CASE("Matrix operations", "[vector]") { REQUIRE(result[3][2] == 30); REQUIRE(result[3][3] == 32); } + + SECTION("Zeros") { + const Matrix<4, 4, float> zeros = Matrix<4, 4, float>::zeros(); + + REQUIRE(zeros[0][0] == 0); + REQUIRE(zeros[0][1] == 0); + REQUIRE(zeros[0][2] == 0); + REQUIRE(zeros[0][3] == 0); + REQUIRE(zeros[1][0] == 0); + REQUIRE(zeros[1][1] == 0); + REQUIRE(zeros[1][2] == 0); + } + + SECTION("Eye") { + const Matrix<4, 4, int> eye1 = Matrix<4, 4, int>::eye(1); + + REQUIRE(eye1[0][0] == 0); + REQUIRE(eye1[0][1] == 1); + REQUIRE(eye1[0][2] == 0); + REQUIRE(eye1[0][3] == 0); + REQUIRE(eye1[1][0] == 0); + REQUIRE(eye1[1][1] == 0); + REQUIRE(eye1[1][2] == 1); + REQUIRE(eye1[1][3] == 0); + REQUIRE(eye1[2][0] == 0); + REQUIRE(eye1[2][1] == 0); + REQUIRE(eye1[2][2] == 0); + REQUIRE(eye1[2][3] == 1); + REQUIRE(eye1[3][0] == 0); + REQUIRE(eye1[3][1] == 0); + REQUIRE(eye1[3][2] == 0); + REQUIRE(eye1[3][3] == 0); + + const Matrix<4, 4, int> eye2 = Matrix<4, 4, int>::eye(-2); + + REQUIRE(eye2[0][0] == 0); + REQUIRE(eye2[0][1] == 0); + REQUIRE(eye2[0][2] == 0); + REQUIRE(eye2[0][3] == 0); + REQUIRE(eye2[1][0] == 0); + REQUIRE(eye2[1][1] == 0); + REQUIRE(eye2[1][2] == 0); + REQUIRE(eye2[1][3] == 0); + REQUIRE(eye2[2][0] == 1); + REQUIRE(eye2[2][1] == 0); + REQUIRE(eye2[2][2] == 0); + REQUIRE(eye2[2][3] == 0); + REQUIRE(eye2[3][0] == 0); + REQUIRE(eye2[3][1] == 1); + REQUIRE(eye2[3][2] == 0); + REQUIRE(eye2[3][3] == 0); + } } \ No newline at end of file diff --git a/tests/VectorTests.cpp b/tests/VectorTests.cpp index 054ceab..2da2ff5 100644 --- a/tests/VectorTests.cpp +++ b/tests/VectorTests.cpp @@ -4,6 +4,25 @@ using namespace LinearLib; TEST_CASE("Vector operations", "[vector]") { + + SECTION("Vector magnitude") { + Vector<2, int> v {3, 4}; + + int result = v.magnitude(); + + REQUIRE(result == 5); + } + + SECTION("As Matrix") { + Vector<3, int> v {1, 2, 3}; + + Matrix<1, 3, int> m = v.asMatrix(); + + REQUIRE(m[0][0] == 1); + REQUIRE(m[0][1] == 2); + REQUIRE(m[0][2] == 3); + } + SECTION("Vector addition") { Vector<3, int> v1 {1, 2, 3};