-
Notifications
You must be signed in to change notification settings - Fork 150
Latent Dirichlet Allocation #172
base: master
Are you sure you want to change the base?
Changes from 6 commits
4a06bb8
dd18418
8e21fba
0f7db31
2bea906
057c2b8
e34f191
96b2896
3d507c8
e52ee17
39bf12b
89b1ca9
8be4d20
8689743
fa1b0ec
b96e49b
174a522
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,191 @@ | ||
/// An example of how Latent Dirichlet Allocation (LDA) can be used. This example begins by
/// generating a distribution of words to categories. This distribution is created so that
/// there are 10 topics. Each of the 25 words is assigned to two topics with equal probability.
/// (The distribution of words is printed to the screen as a chart. Each entry in the chart
/// corresponds to a word in the vocabulary, arranged into a square for easy viewing.) Documents
/// are then generated based on these distributions (each topic is assumed equally likely to be
/// assigned to a document, but each document has only one topic).
///
/// Once the documents are created, the example uses LDA to attempt to reverse engineer the
/// distribution of words, and prints the results to the screen for comparison.
|
||
extern crate rusty_machine; | ||
extern crate rand; | ||
extern crate rulinalg; | ||
|
||
use rusty_machine::linalg::{Matrix, BaseMatrix, Vector}; | ||
use rusty_machine::learning::UnSupModel; | ||
use rusty_machine::learning::lda::LDA; | ||
|
||
use rand::{thread_rng, Rng}; | ||
use rand::distributions::{gamma, IndependentSample}; | ||
|
||
use std::cmp::max; | ||
|
||
/// Given `topic_count` topics, this function will create a distrbution of words for each | ||
/// topic. For simplicity, this function assumes that the total number of words in the corpus | ||
/// will be `(topic_count / 2)^2`. | ||
fn generate_word_distribution(topic_count: usize) -> Matrix<f64> { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Minor point: Potential issue here with assuming |
||
let width = topic_count / 2; | ||
let vocab_size = width * width; | ||
let initial_value = 1.0 / width as f64; | ||
Matrix::from_fn(topic_count, vocab_size, |col, row| { | ||
if row < width { | ||
// Horizontal topics | ||
if col / width == row { | ||
initial_value | ||
} else { | ||
0.0 | ||
} | ||
} else { | ||
//Vertical topics | ||
if col % width == (row - width) { | ||
initial_value | ||
} else { | ||
0.0 | ||
} | ||
} | ||
}) | ||
} | ||
|
||
/// Samples `count` times from a dirichlet distribution with alpha as given and | ||
/// beta 1.0. | ||
fn get_dirichlet(count: usize, alpha: f64) -> Vector<f64> { | ||
let mut rng = thread_rng(); | ||
let g_dist = gamma::Gamma::new(alpha, 1.0); | ||
let result = Vector::from_fn(count, |_| { | ||
g_dist.ind_sample(&mut rng) | ||
}); | ||
let sum = result.sum(); | ||
result / sum | ||
} | ||
|
||
/// Generates a document based on a word distributiion as given. The topics are randomly sampled | ||
/// from a dirichlet distribution and then the word sampled from the selected topic. | ||
fn generate_document(word_distribution: &Matrix<f64>, topic_count:usize, vocab_size: usize, document_length: usize, alpha: f64) -> Vec<usize> { | ||
let mut document = vec![0; vocab_size]; | ||
let topic_distribution = get_dirichlet(topic_count, alpha); | ||
for _ in 0..document_length { | ||
let topic = choose_from(&topic_distribution); | ||
let word = choose_from(&word_distribution.row(topic).into()); | ||
document[word] += 1; | ||
} | ||
document | ||
} | ||
|
||
/// Generate a collection of documents based on the word distribution | ||
fn generate_documents(word_distribution: &Matrix<f64>, topic_count: usize, vocab_size: usize, document_count: usize, document_length: usize, alpha: f64) -> Matrix<usize> { | ||
let mut documents = Vec::with_capacity(vocab_size * document_count); | ||
for _ in 0..document_count { | ||
documents.append(&mut generate_document(word_distribution, topic_count, vocab_size, document_length, alpha)); | ||
} | ||
Matrix::new(document_count, vocab_size, documents) | ||
} | ||
|
||
/// Chooses from a vector of probailities. | ||
fn choose_from(probability: &Vector<f64>) -> usize { | ||
let mut rng = thread_rng(); | ||
let selection:f64 = rng.next_f64(); | ||
let mut total:f64 = 0.0; | ||
for (index, p) in probability.iter().enumerate() { | ||
total += *p; | ||
if total >= selection { | ||
return index; | ||
} | ||
} | ||
return probability.size() - 1; | ||
} | ||
|
||
/// Displays the distrbution of words to a topic as a square graph | ||
fn topic_to_string(topic: &Vector<f64>, width: usize, topic_index: usize) -> String { | ||
let max = topic.iter().fold(0.0, |a, b|{ | ||
if a > *b { | ||
a | ||
} else { | ||
*b | ||
} | ||
}); | ||
let mut result = String::with_capacity(topic.size() * (topic.size()/width) + 18); | ||
result.push_str(&format!("Topic {}\n", topic_index)); | ||
result.push_str("-------\n"); | ||
for (index, element) in topic.iter().enumerate() { | ||
let col = index % width; | ||
let out = element / max * 9.0; | ||
if out >= 1.0 { | ||
result.push_str(&(out as u32).to_string()); | ||
} else { | ||
result.push('.'); | ||
} | ||
if col == width - 1 { | ||
result.push('\n'); | ||
} | ||
} | ||
result | ||
} | ||
|
||
|
||
/// Prints a collection of multiline strings in columns | ||
/// Prints a collection of multi-line strings side by side in fixed-width
/// columns. Columns with fewer lines than the longest simply stop
/// contributing; blank padding for an exhausted column is emitted only when
/// a later column on the same line still has content.
///
/// Takes `&[String]` rather than `&Vec<String>` (clippy `ptr_arg`); existing
/// callers passing `&Vec<String>` continue to work via deref coercion.
fn print_multi_line(o: &[String], column_width: usize) {
    // Split every input string into its lines up front.
    let columns: Vec<Vec<&str>> = o.iter().map(|col| col.split('\n').collect()).collect();
    let mut line_index = 0;
    loop {
        let mut printed_any = false;
        // Exhausted columns seen since the last printed one; their padding is
        // deferred so trailing blanks are never emitted.
        let mut pending_gap = 0;
        for col in &columns {
            if let Some(line) = col.get(line_index) {
                if pending_gap > 0 {
                    print!("{:width$}", "", width = column_width * pending_gap);
                    pending_gap = 0;
                }
                print!("{:width$}", line, width = column_width);
                printed_any = true;
            } else {
                pending_gap += 1;
            }
        }
        println!();
        if !printed_any {
            break;
        }
        line_index += 1;
    }
}
|
||
|
||
/// Prints the word distribution within topics | ||
fn print_topic_distribution(dist: &Matrix<f64>, topic_count: usize, width: usize) { | ||
let top_strings = &dist.row_iter().take(topic_count/2).enumerate().map(|(topic_index, topic)|topic_to_string(&topic.into(), width, topic_index + 1)).collect(); | ||
let bottom_strings = &dist.row_iter().skip(topic_count/2).enumerate().map(|(topic_index, topic)|topic_to_string(&topic.into(), width, topic_index + 1 + topic_count / 2)).collect(); | ||
|
||
print_multi_line(top_strings, max(12, width + 1)); | ||
print_multi_line(bottom_strings, max(12, width + 1)); | ||
} | ||
|
||
pub fn main() { | ||
// Set initial constants | ||
// These can be changed as you wish | ||
let topic_count = 28; | ||
let document_length = 100; | ||
let document_count = 500; | ||
let alpha = 0.1; | ||
|
||
// Don't change these though; they are calculated based on the above | ||
let width = topic_count / 2; | ||
let vocab_count = width * width; | ||
println!("Creating word distribution"); | ||
let word_distribution = generate_word_distribution(topic_count); | ||
println!("Distrbution generated:"); | ||
print_topic_distribution(&word_distribution, topic_count, width); | ||
println!("Generating documents"); | ||
let input = generate_documents(&word_distribution, topic_count, vocab_count, document_count, document_length, alpha); | ||
let lda = LDA::new(topic_count, alpha, 0.1); | ||
println!("Predicting word distrbution from generated documents"); | ||
let result = lda.predict(&(input, 30)).unwrap(); | ||
let dist = result.phi(); | ||
println!("Prediction completed. Predicted word distribution:"); | ||
println!("(Should be similar generated distribution above)", ); | ||
|
||
print_topic_distribution(&dist, topic_count, width); | ||
|
||
|
||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -44,9 +44,24 @@ pub fn accuracy<I>(outputs: I, targets: I) -> f64 | |
correct as f64 / len | ||
} | ||
|
||
|
||
/// Returns the fraction of outputs rows which match their target. | ||
pub fn row_accuracy(outputs: &Matrix<f64>, targets: &Matrix<f64>) -> f64 { | ||
accuracy(outputs.iter_rows(), targets.iter_rows()) | ||
pub fn row_accuracy<T: PartialEq>(outputs: &Matrix<T>, targets: &Matrix<T>) -> f64 { | ||
|
||
assert!(outputs.rows() == targets.rows()); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'm a little confused by this change. I'm guessing it comes from the There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Indeed, I am pretty certain I took this one directly from the linalg bump PR There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I've now merged a new linalg bump PR into master. I guess you'll want to rebase from that (or just C+P the relevant modules onto a new branch :) ) |
||
let len = outputs.rows() as f64; | ||
|
||
let correct = outputs.row_iter() | ||
.zip(targets.row_iter()) | ||
.filter(|&(ref x, ref y)| x.raw_slice() | ||
.iter() | ||
.zip(y.raw_slice()) | ||
.all(|(v1, v2)| v1 == v2)) | ||
.count(); | ||
correct as f64 / len | ||
|
||
// Row doesn't impl PartialEq | ||
// accuracy(outputs.row_iter(), targets.row_iter()) | ||
} | ||
|
||
/// Returns the precision score for 2 class classification. | ||
|
@@ -212,7 +227,8 @@ pub fn neg_mean_squared_error(outputs: &Matrix<f64>, targets: &Matrix<f64>) -> f | |
#[cfg(test)] | ||
mod tests { | ||
use linalg::Matrix; | ||
use super::{accuracy, precision, recall, f1, neg_mean_squared_error}; | ||
use super::{accuracy, precision, recall, f1, neg_mean_squared_error, row_accuracy}; | ||
|
||
|
||
#[test] | ||
fn test_accuracy() { | ||
|
@@ -331,6 +347,34 @@ mod tests { | |
f1(outputs.iter(), targets.iter()); | ||
} | ||
|
||
/// `row_accuracy` should report the fraction of exactly-matching rows, for
/// both integer and floating-point element types.
#[test]
fn test_row_accuracy() {
    // All rows identical -> perfect accuracy.
    let outputs = matrix![1, 0;
                          0, 1;
                          1, 0];
    let targets = matrix![1, 0;
                          0, 1;
                          1, 0];
    assert_eq!(row_accuracy(&outputs, &targets), 1.0);

    // First row differs -> two of three rows match.
    let outputs = matrix![1, 0;
                          0, 1;
                          1, 0];
    let targets = matrix![0, 1;
                          0, 1;
                          1, 0];
    assert_eq!(row_accuracy(&outputs, &targets), 2. / 3.);

    // Same mismatch with f64 elements.
    let outputs = matrix![1., 0.;
                          0., 1.;
                          1., 0.];
    let targets = matrix![0., 1.;
                          0., 1.;
                          1., 0.];
    assert_eq!(row_accuracy(&outputs, &targets), 2. / 3.);
}
|
||
|
||
#[test] | ||
fn test_neg_mean_squared_error_1d() { | ||
let outputs = Matrix::new(3, 1, vec![1f64, 2f64, 3f64]); | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Typo here:
distribution is CREATED
.