Skip to content

Commit

Permalink
Allow seed words for #30
Browse files Browse the repository at this point in the history
  • Loading branch information
koheiw committed Aug 10, 2020
1 parent 9a50fbc commit c22bd43
Show file tree
Hide file tree
Showing 5 changed files with 52 additions and 12 deletions.
4 changes: 2 additions & 2 deletions R/RcppExports.R
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@ qatd_cpp_ca <- function(dfm, residual_floor) {
.Call(`_quanteda_textmodels_qatd_cpp_ca`, dfm, residual_floor)
}

qatd_cpp_lda <- function(mt, k, max_iter, alpha, beta, verbose) {
.Call(`_quanteda_textmodels_qatd_cpp_lda`, mt, k, max_iter, alpha, beta, verbose)
qatd_cpp_lda <- function(mt, k, max_iter, alpha, beta, seeds, verbose) {
.Call(`_quanteda_textmodels_qatd_cpp_lda`, mt, k, max_iter, alpha, beta, seeds, verbose)
}

qatd_cpp_tbb_enabled <- function() {
Expand Down
45 changes: 41 additions & 4 deletions R/textmodel_lda.R
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,16 @@ textmodel_lda.dfm <- function(
alpha <- -1.0
if (is.null(beta))
beta <- -1.0

result <- qatd_cpp_lda(x, k, max_iter, alpha, beta, verbose)

topic <- paste0("topic", seq_len(k))
if (is.null(seeds)) {
seeds <- as(Matrix::sparseMatrix(nfeat(x), k), "dgCMatrix")
topic <- paste0("topic", seq_len(k))
} else {
if (!identical(colnames(x), rownames(seeds)))
stop("seeds must have the same features")
k <- ncol(seeds)
topic <- colnames(seeds)
}
result <- qatd_cpp_lda(x, k, max_iter, alpha, beta, seeds, verbose)
dimnames(result$phi) <- list(topic, colnames(x))
dimnames(result$theta) <- list(rownames(x), topic)
result$alpha <- alpha
Expand Down Expand Up @@ -56,3 +62,34 @@ topics <- function(x) {
topics.textmodel_lda <- function(x) {
colnames(x$theta)[max.col(x$theta)]
}

#' Internal function to construct topic-feature matrix
#' @import Matrix quanteda
#' @noRd
tfm <- function(x, dictionary,
valuetype = c("glob", "regex", "fixed"),
case_insensitive = TRUE,
weight = 0.01, residual = TRUE) {

valuetype <- match.arg(valuetype)
if (weight < 0)
stop("weight must be pisitive a value")
id_key <- id_feat <- integer()
for (i in seq_along(dictionary)) {
f <- featnames(dfm_select(x, dictionary[i]))
id_key <- c(id_key, rep(i, length(f)))
id_feat <- c(id_feat, match(f, featnames(x)))
}
count <- rep(floor(sum(x) * weight), length(id_feat))
key <- names(dictionary)
if (residual)
key <- c(key, "other")
result <- Matrix::sparseMatrix(
i = id_feat,
j = id_key,
x = count,
dims = c(nfeat(x), length(key)),
dimnames = list(featnames(x), key)
)
as.dfm(result)
}
9 changes: 5 additions & 4 deletions src/RcppExports.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,8 @@ BEGIN_RCPP
END_RCPP
}
// qatd_cpp_lda
List qatd_cpp_lda(arma::sp_mat& mt, int k, int max_iter, double alpha, double beta, bool verbose);
RcppExport SEXP _quanteda_textmodels_qatd_cpp_lda(SEXP mtSEXP, SEXP kSEXP, SEXP max_iterSEXP, SEXP alphaSEXP, SEXP betaSEXP, SEXP verboseSEXP) {
List qatd_cpp_lda(arma::sp_mat& mt, int k, int max_iter, double alpha, double beta, arma::sp_mat& seeds, bool verbose);
RcppExport SEXP _quanteda_textmodels_qatd_cpp_lda(SEXP mtSEXP, SEXP kSEXP, SEXP max_iterSEXP, SEXP alphaSEXP, SEXP betaSEXP, SEXP seedsSEXP, SEXP verboseSEXP) {
BEGIN_RCPP
Rcpp::RObject rcpp_result_gen;
Rcpp::RNGScope rcpp_rngScope_gen;
Expand All @@ -29,8 +29,9 @@ BEGIN_RCPP
Rcpp::traits::input_parameter< int >::type max_iter(max_iterSEXP);
Rcpp::traits::input_parameter< double >::type alpha(alphaSEXP);
Rcpp::traits::input_parameter< double >::type beta(betaSEXP);
Rcpp::traits::input_parameter< arma::sp_mat& >::type seeds(seedsSEXP);
Rcpp::traits::input_parameter< bool >::type verbose(verboseSEXP);
rcpp_result_gen = Rcpp::wrap(qatd_cpp_lda(mt, k, max_iter, alpha, beta, verbose));
rcpp_result_gen = Rcpp::wrap(qatd_cpp_lda(mt, k, max_iter, alpha, beta, seeds, verbose));
return rcpp_result_gen;
END_RCPP
}
Expand Down Expand Up @@ -83,7 +84,7 @@ END_RCPP

static const R_CallMethodDef CallEntries[] = {
{"_quanteda_textmodels_qatd_cpp_ca", (DL_FUNC) &_quanteda_textmodels_qatd_cpp_ca, 2},
{"_quanteda_textmodels_qatd_cpp_lda", (DL_FUNC) &_quanteda_textmodels_qatd_cpp_lda, 6},
{"_quanteda_textmodels_qatd_cpp_lda", (DL_FUNC) &_quanteda_textmodels_qatd_cpp_lda, 7},
{"_quanteda_textmodels_qatd_cpp_tbb_enabled", (DL_FUNC) &_quanteda_textmodels_qatd_cpp_tbb_enabled, 0},
{"_quanteda_textmodels_qatd_cpp_wordfish_dense", (DL_FUNC) &_quanteda_textmodels_qatd_cpp_wordfish_dense, 7},
{"_quanteda_textmodels_qatd_cpp_wordfish", (DL_FUNC) &_quanteda_textmodels_qatd_cpp_wordfish, 9},
Expand Down
4 changes: 3 additions & 1 deletion src/lda.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ using namespace Rcpp;

// [[Rcpp::export]]
List qatd_cpp_lda(arma::sp_mat &mt, int k, int max_iter, double alpha, double beta,
bool verbose) {
arma::sp_mat &seeds, bool verbose) {
model lda;
lda.set_data(mt);

Expand All @@ -22,6 +22,8 @@ List qatd_cpp_lda(arma::sp_mat &mt, int k, int max_iter, double alpha, double be
if (verbose)
lda.verbose = verbose;
if (lda.init_est() == 0) {
if (arma::size(seeds) == arma::size(lda.nw) && arma::accu(seeds) > 0)
lda.nw = lda.nw + arma::conv_to<arma::umat>::from(arma::mat(seeds));
lda.estimate();
}

Expand Down
2 changes: 1 addition & 1 deletion src/lda.h
Original file line number Diff line number Diff line change
Expand Up @@ -212,7 +212,7 @@ int model::sampling(int m, int n, int w) {
// do multinomial sampling via cumulative method
for (int k = 0; k < K; k++) {
p[k] = (nw.at(w, k) + beta) / (nwsum[k] + Vbeta) *
(nd.at(m, k) + alpha) / (ndsum[m] + Kalpha);
(nd.at(m, k) + alpha) / (ndsum[m] + Kalpha);
}
// cumulate multinomial parameters
for (int k = 1; k < K; k++) {
Expand Down

0 comments on commit c22bd43

Please sign in to comment.