-
Notifications
You must be signed in to change notification settings - Fork 2
/
mf.h
118 lines (94 loc) · 3.32 KB
/
mf.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
#include <vector>
#include <algorithm>
#include <string>
#include <memory>
#include <string>
#include <fstream>
#include <sstream>
#include <iostream>
#include <cmath>
#include <cstdio>
#include <iomanip>
#include <random>
#include "omp.h"
using namespace std;
// Scalar aliases used throughout this header.
// ImpFloat and ImpDouble both name `double`; ImpInt and ImpLong both name
// `long long`.
using ImpFloat = double;
using ImpDouble = double;
using ImpInt = long long;
using ImpLong = long long;
// Run-time settings. A default-constructed Parameter carries the values
// written next to each member; the two path strings start out empty.
// NOTE(review): the meaning of each field is not visible in this header —
// the names suggest regularisation weights (lambda_u, lambda_i), iteration
// count (nr_pass), a dimension (k) and thread count (nr_threads); confirm
// against the code that reads them.
class Parameter {
public:
    ImpFloat lambda_u = 0.1;
    ImpFloat lambda_i = 0.1;
    ImpFloat w = 1;
    ImpFloat a = 0;

    ImpInt nr_pass = 20;
    ImpInt k = 10;
    ImpInt nr_threads = 1;
    ImpInt scheme = 0;

    std::string model_path;
    std::string predict_path;

    // User-provided on purpose: keeps the class a non-aggregate, exactly as
    // it was when the defaults lived in the constructor's initializer list.
    Parameter() {}
};
struct smat {
vector<ImpLong> row_ptr;
vector<ImpLong> col_idx;
vector<ImpDouble> val;
};
class ImpData {
public:
string file_name;
ImpLong l, m, n;
ImpLong m_real, n_real;
smat R;
smat RT;
ImpData(string file_name): file_name(file_name), l(0), m(0), n(0) {};
void read();
void print_data_info();
class Compare {
public:
const ImpLong *row_idx;
const ImpLong *col_idx;
Compare(const ImpLong *row_idx_, const ImpLong *col_idx_) {
row_idx = row_idx_;
col_idx = col_idx_;
}
bool operator()(size_t x, size_t y) const {
return (row_idx[x] < row_idx[y]) || ((row_idx[x] == row_idx[y]) && (col_idx[x]<= col_idx[y]));
}
};
};
class ImpProblem {
public:
shared_ptr<ImpData> data;
shared_ptr<ImpData> test_data;
shared_ptr<Parameter> param;
ImpProblem(shared_ptr<ImpData> &data, shared_ptr<Parameter> ¶m)
:data(data), param(param) {};
ImpProblem(shared_ptr<ImpData> &data, shared_ptr<ImpData> &test_data, shared_ptr<Parameter> ¶m)
:data(data), test_data(test_data), param(param) {};
ImpFloat *W, *H;
ImpFloat *WT, *HT;
vector<ImpDouble> p, q;
ImpInt t;
ImpDouble obj, reg, tr_loss;
vector<ImpDouble> va_loss;
ImpFloat start_time;
ImpFloat U_time, C_time, W_time, H_time, I_time, R_time;
ImpFloat sum, sq;
vector<ImpFloat> gamma_w, gamma_h;
ImpDouble cal_loss(ImpLong &l, smat &R);
ImpDouble cal_reg();
ImpDouble cal_tr_loss(ImpLong &l, smat &R);
void update(const smat &R, ImpLong i, vector<ImpFloat> &gamma, ImpDouble *u, ImpDouble *v, const ImpFloat lambda, const ImpDouble w_p, const vector<ImpDouble> &w_q );
void save();
void load();
void initialize();
void init_va_loss(ImpInt size);
void set_weight(const smat &R, const ImpInt m, vector<ImpDouble> &p, const ImpInt scheme);
void solve();
void update_R(ImpFloat *wt, ImpFloat *ht, bool add);
void validate(const vector<ImpInt> &topks);
void validate_ndcg(const vector<ImpInt> &topks);
void predict_candidates(const ImpFloat* w, vector<ImpFloat> &Z);
ImpLong precision_k(vector<ImpFloat> &Z, ImpLong i, const vector<ImpInt> &topks, vector<ImpLong> &hit_counts);
ImpDouble ndcg_k(vector<ImpFloat> &Z, ImpLong i, const vector<ImpInt> &topks, vector<double> &ndcgs);
void cache(ImpDouble* WT, ImpDouble* H, vector<ImpFloat> &gamma, ImpDouble *ut, ImpLong m, ImpLong n, const vector<ImpDouble> &w_q);
void update_coordinates();
void print_epoch_info();
void print_header_info(vector<ImpInt> &topks);
bool is_hit(const smat &R, ImpLong i, ImpLong argmax);
};