-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathsparse_matrix.h
90 lines (71 loc) · 1.8 KB
/
sparse_matrix.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
// Author: Mingcheng Chen ([email protected])
#ifndef SPARSE_MATRIX_H_
#define SPARSE_MATRIX_H_
#include <vector>
// A sparse matrix of doubles stored row by row.
//
// For each row r, col_[r] holds the column indices of the stored entries and
// data_[r] holds their values at the same positions (parallel vectors).
// Entries within a row are kept in insertion order, not sorted by column.
// A value comparing equal to 0.0 (including -0.0) is never stored:
// set_element removes the entry instead, so num_elements() counts only
// explicitly-set nonzero entries.
//
// Indices are not bounds-checked: callers must pass 0 <= row < get_num_rows()
// and are expected to keep 0 <= col < get_num_cols().
class SparseMatrix {
 public:
  // Creates an empty 0 x 0 matrix.
  SparseMatrix() : num_rows_(0), num_cols_(0) {}

  // Creates an all-zero num_rows x num_cols matrix.
  SparseMatrix(int num_rows, int num_cols)
      : num_rows_(num_rows), num_cols_(num_cols) {
    data_.resize(num_rows);
    col_.resize(num_rows);
  }

  // Defined out of line (not in this header). Presumably computes
  // result = (this matrix) * column, reading get_num_cols() values from
  // `column` and writing get_num_rows() values to `result` -- confirm against
  // the definition.
  void multiply_column(const double *column, double *result) const;

  // Sets entry (row, col) to `value`.
  // Setting 0.0 removes any stored entry for (row, col) (no-op if absent).
  // A nonzero value overwrites an existing entry in place, or appends a new
  // entry at the end of the row. Cost is linear in the row's stored entries.
  void set_element(int row, int col, double value) {
    const size_type pos = find_entry(row, col);
    const bool present = pos < col_[row].size();
    if (value == 0.0) {
      // Keep the matrix sparse: zeros are represented by absence.
      if (present) {
        col_[row].erase(col_[row].begin() + pos);
        data_[row].erase(data_[row].begin() + pos);
      }
      return;
    }
    if (present) {
      data_[row][pos] = value;
    } else {
      col_[row].push_back(col);
      data_[row].push_back(value);
    }
  }

  // Returns entry (row, col), or 0.0 if it is not stored.
  double get_element(int row, int col) const {
    const size_type pos = find_entry(row, col);
    return pos < col_[row].size() ? data_[row][pos] : 0.0;
  }

  // Returns a new num_cols x num_rows matrix holding the transpose.
  SparseMatrix transpose() const {
    SparseMatrix result(num_cols_, num_rows_);
    for (int r = 0; r < num_rows_; r++) {
      const std::vector<int> &cols = col_[r];
      const std::vector<double> &vals = data_[r];
      for (size_type k = 0; k < cols.size(); k++) {
        // Each (r, c) is stored at most once and never with a zero value, so
        // the transposed entry (c, r) is always new and nonzero. Append it
        // directly rather than paying set_element's linear search; rows of
        // the result are filled in increasing r order, exactly as before.
        const int c = cols[k];
        result.col_[c].push_back(r);
        result.data_[c].push_back(vals[k]);
      }
    }
    return result;
  }

  // Returns the total number of stored (nonzero) entries.
  int num_elements() const {
    int result = 0;
    for (int r = 0; r < num_rows_; r++) {
      result += static_cast<int>(col_[r].size());
    }
    return result;
  }

  int get_num_rows() const {
    return num_rows_;
  }

  int get_num_cols() const {
    return num_cols_;
  }

 private:
  typedef std::vector<int>::size_type size_type;

  // Returns the position of column `col` within row `row`'s storage, or
  // col_[row].size() when no entry for (row, col) is stored.
  size_type find_entry(int row, int col) const {
    const std::vector<int> &cols = col_[row];
    size_type i = 0;
    while (i < cols.size() && cols[i] != col) {
      ++i;
    }
    return i;
  }

  std::vector<std::vector<double> > data_;  // per-row values
  std::vector<std::vector<int> > col_;      // per-row column indices
  int num_rows_, num_cols_;
};
#endif // SPARSE_MATRIX_H_