5#ifndef IDOL_SQUAREMATRIX_H
6#define IDOL_SQUAREMATRIX_H
10#include "MatrixIndices.h"
17class idol::SquareMatrix {
18 const MatrixIndices* m_indices;
19 Eigen::MatrixXd m_matrix;
21 SquareMatrix(
const MatrixIndices& t_indices, Eigen::MatrixXd t_matrix) : m_indices(&t_indices), m_matrix(std::move(t_matrix)) {}
23 explicit SquareMatrix(
const MatrixIndices& t_indices) : m_indices(&t_indices), m_matrix(Eigen::MatrixXd::Zero(t_indices.n_indices(), t_indices.n_indices())) {}
25 void set(
const Var& t_var1,
const Var& t_var2,
double t_value) {
26 const unsigned int i = m_indices->get(t_var1);
27 const unsigned int j = m_indices->get(t_var2);
28 m_matrix(i, j) = t_value;
29 m_matrix(j, i) = t_value;
32 [[nodiscard]]
double get(
unsigned int t_i,
unsigned int t_j)
const {
return m_matrix(t_i, t_j); }
34 [[nodiscard]]
double get(
const Var& t_i,
const Var& t_j)
const {
return m_matrix(m_indices->get(t_i), m_indices->get(t_j)); }
36 [[nodiscard]]
unsigned int size()
const {
return m_indices->n_indices(); }
38 [[nodiscard]] SquareMatrix operator*(
const SquareMatrix& t_matrix)
const {
40 result.m_matrix *= t_matrix.m_matrix;
44 [[nodiscard]] SquareMatrix transpose()
const {
45 return { *m_indices, m_matrix.transpose() };
48 [[nodiscard]] SquareMatrix sqrt()
const {
49 assert(m_matrix.isDiagonal());
51 for (
const auto& [var, index] : m_indices->indices()) {
52 result.m_matrix(index, index) = std::sqrt(result.m_matrix(index, index));
57 [[nodiscard]] std::pair<SquareMatrix, SquareMatrix> eigen_decomposition()
const {
58 Eigen::EigenSolver<Eigen::MatrixXd> solver(m_matrix);
59 return std::make_pair(
60 SquareMatrix(*m_indices, solver.pseudoEigenvectors()),
61 SquareMatrix(*m_indices, solver.pseudoEigenvalueMatrix())
65 [[nodiscard]] SquareMatrix cholesky()
const {
66 Eigen::LLT<Eigen::MatrixXd> solver(m_matrix);
67 return SquareMatrix(*m_indices, solver.matrixL());
73 static std::ostream &operator<<(std::ostream &t_os,
const SquareMatrix &t_matrix) {
74 for (
unsigned int i = 0, n = t_matrix.size(); i < n; ++i) {
75 for (
unsigned int j = 0; j < n; ++j) {
76 t_os << t_matrix.get(i, j) <<
", ";