-
Notifications
You must be signed in to change notification settings - Fork 0
/
ffm.h
94 lines (69 loc) · 1.92 KB
/
ffm.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
#ifndef _LIBFFM_H
#define _LIBFFM_H

/*
 * Public interface of the ffm library: problem/model/parameter structs and
 * the load, save, train, predict and cross-validation entry points.
 *
 * NOTE(review): despite the __cplusplus guards and the extern "C" block,
 * this header can only be compiled as C++:
 *   - ffm_model holds a std::unordered_set;
 *   - struct types are named without the `struct` keyword;
 *   - `bool` is used without <stdbool.h>;
 *   - ffm_save_model takes a C++ reference.
 * The extern "C" only affects the linkage names of the declared functions.
 */

#ifdef __cplusplus
#include <unordered_set>

extern "C"
{
namespace ffm
{

// Kept for source compatibility: code inside namespace ffm may rely on
// unqualified std names becoming visible through this header.
using namespace std;
#endif

typedef float ffm_float;
typedef double ffm_double;
typedef int ffm_int;
typedef long long ffm_long;

/*
 * One sparse entry of an instance.
 * NOTE(review): presumably f = field index, j = feature index,
 * v = feature value; confirm against the implementation.
 */
struct ffm_node
{
    ffm_int f;
    ffm_int j;
    ffm_float v;
};

/*
 * A data set of sparse instances.
 * NOTE(review): presumably n = number of features, l = number of
 * instances, m = number of fields. X would then be the concatenated nodes
 * of all instances, P the per-instance offsets into X, and Y the labels.
 * TODO confirm against the implementation.
 */
struct ffm_problem
{
    ffm_int n;
    ffm_int l;
    ffm_int m;
    ffm_node *X;
    ffm_long *P;
    ffm_float *Y;
};

/*
 * A trained model.
 * NOTE(review): presumably n/m mirror ffm_problem (features/fields), k is
 * the latent dimension, and W is the weight array. tr_feat_idx appears to
 * hold feature indices seen in training, and tr_avg/normalization record
 * preprocessing. Confirm against the implementation.
 */
struct ffm_model
{
    ffm_int n;
    ffm_int m;
    ffm_int k;
    ffm_float *W;
    std::unordered_set<ffm_int> tr_feat_idx;
    ffm_float tr_avg;
    bool normalization;
};

/*
 * Training hyper-parameters; ffm_get_default_param() returns a filled-in
 * instance.
 * NOTE(review): by name, eta is presumably the learning rate and lambda the
 * regularization strength. Exact semantics of random/auto_stop are defined
 * by the implementation; verify there.
 */
struct ffm_parameter
{
    ffm_float eta;
    ffm_float lambda;
    ffm_int nr_iters;
    ffm_int k;
    ffm_int nr_threads;
    bool quiet;
    bool normalization;
    bool random;
    bool auto_stop;
};

/* Data set I/O. The destroy function takes a pointer-to-pointer,
 * presumably so it can release the object and reset the caller's handle. */
ffm_problem* ffm_read_problem(char const *path);
int ffm_read_problem_to_disk(char const *txt_path, char const *bin_path);
void ffm_destroy_problem(struct ffm_problem **prob);

/* Model persistence.
 * NOTE(review): param_k is taken by non-const reference, so the
 * implementation may read and/or update it; confirm the contract. */
ffm_int ffm_save_model(ffm_model *model, char const *path, ffm_int &param_k);
ffm_model* ffm_load_model(char const *path);
void ffm_destroy_model(struct ffm_model **model);

ffm_parameter ffm_get_default_param();

/* Training entry points: in-memory problems or on-disk data sets, with or
 * without a validation set. Each also receives an existing model pointer
 * (raw_model); how it is used is defined by the implementation. */
ffm_model* ffm_train_with_validation(struct ffm_problem *Tr, struct ffm_problem *Va, struct ffm_parameter param, ffm_model *raw_model);
ffm_model* ffm_train(struct ffm_problem *prob, struct ffm_parameter param, ffm_model *raw_model);
ffm_model* ffm_train_with_validation_on_disk(char const *Tr_path, char const *Va_path, struct ffm_parameter param, ffm_model *raw_model);
ffm_model* ffm_train_on_disk(char const *path, struct ffm_parameter param, ffm_model *raw_model);

/* Scores one instance given as a node range.
 * NOTE(review): presumably half-open [begin, end); confirm. */
ffm_float ffm_predict(ffm_node *begin, ffm_node *end, ffm_model *raw_model);

ffm_float ffm_cross_validation(struct ffm_problem *prob, ffm_int nr_folds, struct ffm_parameter param, ffm_model *raw_model);

#ifdef __cplusplus
} // namespace ffm
} // extern "C"
#endif

#endif // _LIBFFM_H