Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
79 changes: 74 additions & 5 deletions include/LinearLib/Matrix.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,12 +26,19 @@ namespace LinearLib {
Matrix() = default;

static Matrix identity() {
return eye(0);
}

static Matrix eye(int const offset) {

assert(offset < static_cast<int>(C) && "Offset absolute value must be less than matrix dimension");
assert(offset > -static_cast<int>(C) && "Offset absolute value must be less than matrix dimension");

Matrix res;

for (std::size_t i = 0; i < R; i++) {
for (std::size_t j = 0; j < C; j++) {
if (i == j) {
if (i + offset == j) {
res.data[i][j] = T{1};
} else {
res.data[i][j] = T{0};
Expand All @@ -42,6 +49,18 @@ namespace LinearLib {
return res;
}

static Matrix zeros() {
Matrix res;

for (std::size_t i = 0; i < R; i++) {
for (std::size_t j = 0; j < C; j++) {
res.data[i][j] = T{0};
}
}

return res;
}

template<std::size_t SubR, std::size_t SubC>
Matrix<SubR, SubC, T> splice(std::ranges::range auto& rows, std::ranges::range auto& cols) const {
assert(std::ranges::distance(rows) == SubR && std::format("Row range must dimensionally cover %i rows", SubR).c_str());
Expand Down Expand Up @@ -99,7 +118,8 @@ namespace LinearLib {
}

T determinant() const {
assert(R == C && "Determinant is only defined for square matrices");

assert(isSquare() && "Determinant is only defined for square matrices");

// Base case 1
if (R == 1) {
Expand Down Expand Up @@ -137,6 +157,27 @@ namespace LinearLib {
return res;
}

[[nodiscard]] static bool isSquare() {
return R == C;
}

[[nodiscard]] bool isSymmetric() const {

if (!isSquare()) {
return false;
}

for (std::size_t i = 0; i < R; i++) {
for (std::size_t j = 0; j < C; j++) {
if (data[i][j] != data[j][i]) {
return false;
}
}
}

return true;
}

std::array<std::array<T, C>, R> getData() const {
return data;
}
Expand Down Expand Up @@ -187,23 +228,30 @@ namespace LinearLib {
return res;
}

/**
* Matrix dot product
*/
template<std::size_t I>
Matrix<R, I, T> operator*(const Matrix<C, I, T>& other) const {

Matrix<R, I, T> res;

for (std::size_t i = 0; i < R; i++) {
for (std::size_t j = 0; j < I; j++) {
for (std::size_t j = 0; j < C; j++) {
T sum = T{};
for (std::size_t k = 0; k < C; k++) {
sum += data[i][k] * other.data[k][j];
sum += data[i][k] * other[k][j];
}
res.data[i][j] = sum;
res[i][j] = sum;
}
}

return res;
}

/**
* Scalar Multiplication
*/
Matrix operator*(const T& scalar) const {
Matrix res;

Expand All @@ -215,5 +263,26 @@ namespace LinearLib {

return res;
}

/**
* Matrix Multiplication
*/
template<std::size_t I>
Matrix<R, I, T> operator&(const Matrix<C, I, T>& other) const {

Matrix<R, I, T> res;

for (std::size_t i = 0; i < R; i++) {
for (std::size_t j = 0; j < I; j++) {
T sum = T{};
for (std::size_t k = 0; k < C; k++) {
sum += data[i][k] * other.data[k][j];
}
res.data[i][j] = sum;
}
}

return res;
}
};
}
24 changes: 23 additions & 1 deletion include/LinearLib/Vector.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
#include <initializer_list>
#include <type_traits>

#include "Matrix.hpp"

namespace LinearLib {
template<std::size_t N, typename T>
requires std::is_arithmetic_v<T>
Expand All @@ -20,6 +22,26 @@ namespace LinearLib {
// Default constructor is still needed
Vector() = default;

T magnitude() const {
T res = 0;

for (std::size_t i = 0; i < N; i++) {
res += std::pow(data[i], 2);
}

return std::sqrt(res);
}

Matrix<1, N, T> asMatrix() {
Matrix<1, N, T> res;

for (std::size_t i = 0; i < N; i++) {
res[0][i] = data[i];
}

return res;
}

std::array<T, N> getData() const {
return data;
}
Expand Down Expand Up @@ -101,4 +123,4 @@ namespace LinearLib {
return res;
}
};
}
}
81 changes: 80 additions & 1 deletion tests/MatrixTests.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,33 @@ TEST_CASE("Matrix operations", "[vector]") {
REQUIRE(result[1][1] == 6);
}

SECTION("Matrix dot product") {
const Matrix<2, 2, int> m1 {{1, 2},{3,4}};
const Matrix<2, 2, int> m2 {{5, 6},{7,8}};

const Matrix<2, 2, int> result1 = m1 * m2;

REQUIRE(result1[0][0] == 19);
REQUIRE(result1[0][1] == 22);
REQUIRE(result1[1][0] == 43);
REQUIRE(result1[1][1] == 50);

const Matrix<3, 3, int> m3 {{1, 1, 3},
{-1, 4, 1},
{0, 2, -2}};
const Matrix<3, 1, int> m4 {{1},
{0},
{-1}};

const Matrix<3, 1, int> result2 = m3 * m4;

REQUIRE(result2[0][0] == -2);
REQUIRE(result2[1][0] == -2);
REQUIRE(result2[2][0] == 2);


}

SECTION("Matrix multiplication") {
const Matrix<2, 4, int> m1 {{5, 3, 3, 4},
{2, 4, 4, 3}};
Expand All @@ -44,7 +71,7 @@ TEST_CASE("Matrix operations", "[vector]") {
{3, 2},
{7, 1}};

const Matrix<2, 2, int> result = m1 * m2;
const Matrix<2, 2, int> result = m1 & m2;

REQUIRE(result[0][0] == 67);
REQUIRE(result[0][1] == 46);
Expand Down Expand Up @@ -118,4 +145,56 @@ TEST_CASE("Matrix operations", "[vector]") {
REQUIRE(result[3][2] == 30);
REQUIRE(result[3][3] == 32);
}

SECTION("Zeros") {
const Matrix<4, 4, float> zeros = Matrix<4, 4, float>::zeros();

REQUIRE(zeros[0][0] == 0);
REQUIRE(zeros[0][1] == 0);
REQUIRE(zeros[0][2] == 0);
REQUIRE(zeros[0][3] == 0);
REQUIRE(zeros[1][0] == 0);
REQUIRE(zeros[1][1] == 0);
REQUIRE(zeros[1][2] == 0);
}

SECTION("Eye") {
const Matrix<4, 4, int> eye1 = Matrix<4, 4, int>::eye(1);

REQUIRE(eye1[0][0] == 0);
REQUIRE(eye1[0][1] == 1);
REQUIRE(eye1[0][2] == 0);
REQUIRE(eye1[0][3] == 0);
REQUIRE(eye1[1][0] == 0);
REQUIRE(eye1[1][1] == 0);
REQUIRE(eye1[1][2] == 1);
REQUIRE(eye1[1][3] == 0);
REQUIRE(eye1[2][0] == 0);
REQUIRE(eye1[2][1] == 0);
REQUIRE(eye1[2][2] == 0);
REQUIRE(eye1[2][3] == 1);
REQUIRE(eye1[3][0] == 0);
REQUIRE(eye1[3][1] == 0);
REQUIRE(eye1[3][2] == 0);
REQUIRE(eye1[3][3] == 0);

const Matrix<4, 4, int> eye2 = Matrix<4, 4, int>::eye(-2);

REQUIRE(eye2[0][0] == 0);
REQUIRE(eye2[0][1] == 0);
REQUIRE(eye2[0][2] == 0);
REQUIRE(eye2[0][3] == 0);
REQUIRE(eye2[1][0] == 0);
REQUIRE(eye2[1][1] == 0);
REQUIRE(eye2[1][2] == 0);
REQUIRE(eye2[1][3] == 0);
REQUIRE(eye2[2][0] == 1);
REQUIRE(eye2[2][1] == 0);
REQUIRE(eye2[2][2] == 0);
REQUIRE(eye2[2][3] == 0);
REQUIRE(eye2[3][0] == 0);
REQUIRE(eye2[3][1] == 1);
REQUIRE(eye2[3][2] == 0);
REQUIRE(eye2[3][3] == 0);
}
}
19 changes: 19 additions & 0 deletions tests/VectorTests.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,25 @@
using namespace LinearLib;

TEST_CASE("Vector operations", "[vector]") {

SECTION("Vector magnitude") {
Vector<2, int> v {3, 4};

int result = v.magnitude();

REQUIRE(result == 5);
}

SECTION("As Matrix") {
Vector<3, int> v {1, 2, 3};

Matrix<1, 3, int> m = v.asMatrix();

REQUIRE(m[0][0] == 1);
REQUIRE(m[0][1] == 2);
REQUIRE(m[0][2] == 3);
}

SECTION("Vector addition") {

Vector<3, int> v1 {1, 2, 3};
Expand Down