This repository has been archived by the owner on Jul 16, 2021. It is now read-only.

Latent Dirichlet Allocation #172

Status: Open · wants to merge 17 commits into `master`
7 changes: 4 additions & 3 deletions Cargo.toml
@@ -7,6 +7,7 @@ description = "A machine learning library."
 repository = "https://github.com/AtheMathmo/rusty-machine"
 documentation = "https://AtheMathmo.github.io/rusty-machine/"
 keywords = ["machine","learning","stats","data","machine-learning"]
+categories = ["science"]
 readme = "README.md"
 license = "MIT"
@@ -15,6 +16,6 @@ stats = []
 datasets = []
 
 [dependencies]
-num = { version = "0.1.35", default-features = false }
-rand = "0.3.14"
-rulinalg = "0.3.7"
+num = { version = "0.1.36", default-features = false }
+rand = "0.3.15"
+rulinalg = "0.4.2"
2 changes: 1 addition & 1 deletion benches/examples/cross_validation.rs
@@ -28,7 +28,7 @@ struct DummyModel {
 impl SupModel<Matrix<f64>, Matrix<f64>> for DummyModel {
     fn predict(&self, inputs: &Matrix<f64>) -> LearningResult<Matrix<f64>> {
         let predictions: Vec<f64> = inputs
-            .iter_rows()
+            .row_iter()
             .map(|row| { self.sum + sum(row.iter()) })
             .collect();
         Ok(Matrix::new(inputs.rows(), 1, predictions))
4 changes: 2 additions & 2 deletions benches/examples/k_means.rs
@@ -20,10 +20,10 @@ fn generate_data(centroids: &Matrix<f64>, points_per_centroid: usize, noise: f64
 
     for _ in 0..points_per_centroid {
         // Generate points from each centroid
-        for centroid in centroids.iter_rows() {
+        for centroid in centroids.row_iter() {
             // Generate a point randomly around the centroid
             let mut point = Vec::with_capacity(centroids.cols());
-            for feature in centroid {
+            for feature in centroid.iter() {
                 point.push(feature + normal_rv.ind_sample(&mut rng));
             }
 
69 changes: 69 additions & 0 deletions examples/README.md
@@ -9,6 +9,7 @@ This directory gathers fully-fledged programs, each using a piece of
 * [SVM](#svm)
 * [Neural Networks](#neural-networks)
 * [Naïve Bayes](#naïve-bayes)
+* [LDA](#lda)
 
 ## The Examples

@@ -165,3 +166,71 @@ Predicted: White; Actual: White; Accurate? true
Predicted: Red; Actual: Red; Accurate? true
Accuracy: 822/1000 = 82.2%
```

### LDA

#### Word distribution

The [word distribution](lda_gen.rs) example starts by generating a distribution
of words over topics, then generating documents based on a distribution of
topics. The example then tries to estimate the distribution of words based only
on the generated documents.

The generated distribution (G) of words is visualized as a grid, with each cell
in the grid corresponding to the probability of a particular word being
selected. Following this, documents (D) are generated based on a distribution
over these topics.

The distribution for each topic is shown, then Latent Dirichlet Allocation is
used to estimate the distribution (E) of words per topic, based solely on the
generated documents (D).

The resulting word distribution (E) is printed. The order may not be the same,
but for each estimated topic in (E), there should be a corresponding generated
distribution in (G).
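
The estimation step itself is a single transform fit. Below is a minimal
sketch of that call sequence, assuming the `LDAFitter` transform introduced in
this PR (the `estimate_topics` helper is illustrative, not part of the
example):

```rust
// Illustrative sketch of the fitting step the example performs.
// `input` is a document-by-word count matrix.
use rusty_machine::linalg::Matrix;
use rusty_machine::data::transforms::{TransformFitter, LDAFitter};

fn estimate_topics(input: &Matrix<usize>) -> Matrix<f64> {
    // 10 topics, alpha = 0.1, the example's fixed 0.1 for the second
    // concentration parameter, and 300 sampling iterations.
    let lda = LDAFitter::new(10, 0.1, 0.1, 300);
    let result = lda.fit(input).expect("LDA fitting failed");
    // Estimated topic-by-word probability matrix.
    result.word_distribution()
}
```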

Sample run:
```
$ cargo run --example lda_gen
...
Creating word distribution
Distribution generated:
Topic 1 Topic 2 Topic 3 Topic 4 Topic 5
------- ------- ------- ------- -------
99999 ..... ..... ..... .....
..... 99999 ..... ..... .....
..... ..... 99999 ..... .....
..... ..... ..... 99999 .....
..... ..... ..... ..... 99999


Topic 6 Topic 7 Topic 8 Topic 9 Topic 10
------- ------- ------- ------- -------
9.... .9... ..9.. ...9. ....9
9.... .9... ..9.. ...9. ....9
9.... .9... ..9.. ...9. ....9
9.... .9... ..9.. ...9. ....9
9.... .9... ..9.. ...9. ....9


Generating documents
Predicting word distribution from generated documents
Prediction completed. Predicted word distribution:
(Should be similar to generated distribution above)
Topic 1 Topic 2 Topic 3 Topic 4 Topic 5
------- ------- ------- ------- -------
..8.. ..... ..... ....8 8....
..8.. ..... ..... ....8 8....
..9.. 98888 ..... ....9 8....
..8.. ..... ..... ....8 8....
..8.. ..... 88988 ....8 9....


Topic 6 Topic 7 Topic 8 Topic 9 Topic 10
------- ------- ------- ------- -------
...8. ..... .8... ..... 89888
...8. ..... .8... 88889 .....
...8. ..... .9... ..... .....
...8. 88889 .8... ..... .....
...9. ..... .8... ..... .....
```
4 changes: 2 additions & 2 deletions examples/k-means_generating_cluster.rs
@@ -24,10 +24,10 @@ fn generate_data(centroids: &Matrix<f64>,
 
     for _ in 0..points_per_centroid {
         // Generate points from each centroid
-        for centroid in centroids.iter_rows() {
+        for centroid in centroids.row_iter() {
             // Generate a point randomly around the centroid
             let mut point = Vec::with_capacity(centroids.cols());
-            for feature in centroid {
+            for feature in centroid.iter() {
                 point.push(feature + normal_rv.ind_sample(&mut rng));
             }
 
193 changes: 193 additions & 0 deletions examples/lda_gen.rs
@@ -0,0 +1,193 @@
/// An example of how Latent Dirichlet Allocation (LDA) can be used. This example begins by
/// generating a distribution of words over topics. This distribution is created so that
/// there are 10 topics. Each of the 25 words is assigned to two topics with equal probability.
/// (The distribution of words is printed to the screen as a chart. Each entry in the chart
/// corresponds to a word in the vocabulary, arranged into a square for easy viewing.) Documents
/// are then generated from these distributions: each topic is equally likely a priori, and
/// each document draws its own topic mixture, which the small `ALPHA` concentrates on only a
/// few topics.
///
/// Once the documents are created, the example uses LDA to attempt to reverse-engineer the
/// distribution of words, and prints the results to the screen for comparison.

extern crate rusty_machine;
extern crate rand;
extern crate rulinalg;

use rusty_machine::linalg::{Matrix, BaseMatrix, Vector};
use rusty_machine::data::transforms::{TransformFitter, LDAFitter};

use rand::{thread_rng, Rng};
use rand::distributions::{gamma, IndependentSample};

use std::cmp::max;

// These constants control the generation algorithm. You can set them how you wish,
// although very large values of TOPIC_COUNT will cause problems.

// TOPIC_COUNT should be even
const TOPIC_COUNT: usize = 10;
const DOCUMENT_LENGTH: usize = 100;
const DOCUMENT_COUNT: usize = 500;
const ALPHA: f64 = 0.1;
const ITERATION_COUNT: usize = 300;

/// Given `topic_count` topics, this function will create a distribution of words for each
/// topic. For simplicity, this function assumes that the total number of words in the corpus
/// will be `(topic_count / 2)^2`.
fn generate_word_distribution(topic_count: usize) -> Matrix<f64> {
> **Owner:** Minor point: potential issue here with assuming `topic_count` is divisible by 2. From what I can tell this shouldn't break anything, but I wanted to draw your attention to it regardless.

    let width = topic_count / 2;
    let vocab_size = width * width;
    let initial_value = 1.0 / width as f64;
    Matrix::from_fn(topic_count, vocab_size, |col, row| {
        if row < width {
            // Horizontal topics
            if col / width == row {
                initial_value
            } else {
                0.0
            }
        } else {
            // Vertical topics
            if col % width == (row - width) {
                initial_value
            } else {
                0.0
            }
        }
    })
}

/// Draws one sample from a `count`-dimensional symmetric Dirichlet distribution with
/// concentration `alpha`, by normalizing `count` independent Gamma(alpha, 1) draws.
fn get_dirichlet(count: usize, alpha: f64) -> Vector<f64> {
    let mut rng = thread_rng();
    let g_dist = gamma::Gamma::new(alpha, 1.0);
    let result = Vector::from_fn(count, |_| g_dist.ind_sample(&mut rng));
    let sum = result.sum();
    result / sum
}

/// Generates a document (as word counts) from the given word distribution. A topic
/// mixture is sampled from a Dirichlet distribution, then each word is sampled from
/// the topic selected for it.
fn generate_document(word_distribution: &Matrix<f64>, topic_count: usize, vocab_size: usize, document_length: usize, alpha: f64) -> Vec<usize> {
    let mut document = vec![0; vocab_size];
    let topic_distribution = get_dirichlet(topic_count, alpha);
    for _ in 0..document_length {
        let topic = choose_from(&topic_distribution);
        let word = choose_from(&word_distribution.row(topic).into());
        document[word] += 1;
    }
    document
}

/// Generate a collection of documents based on the word distribution
fn generate_documents(word_distribution: &Matrix<f64>, topic_count: usize, vocab_size: usize, document_count: usize, document_length: usize, alpha: f64) -> Matrix<usize> {
    let mut documents = Vec::with_capacity(vocab_size * document_count);
    for _ in 0..document_count {
        documents.append(&mut generate_document(word_distribution, topic_count, vocab_size, document_length, alpha));
    }
    Matrix::new(document_count, vocab_size, documents)
}

/// Chooses an index at random from a vector of probabilities.
fn choose_from(probability: &Vector<f64>) -> usize {
    let mut rng = thread_rng();
    let selection: f64 = rng.next_f64();
    let mut total: f64 = 0.0;
    for (index, p) in probability.iter().enumerate() {
        total += *p;
        if total >= selection {
            return index;
        }
    }
    probability.size() - 1
}

/// Displays the distribution of words in a topic as a square grid
fn topic_to_string(topic: &Vector<f64>, width: usize, topic_index: usize) -> String {
    let max = topic.iter().fold(0.0, |a, &b| if a > b { a } else { b });
    let mut result = String::with_capacity(topic.size() * (topic.size() / width) + 18);
    result.push_str(&format!("Topic {}\n", topic_index));
    result.push_str("-------\n");
    for (index, element) in topic.iter().enumerate() {
        let col = index % width;
        let out = element / max * 9.0;
        if out >= 1.0 {
            result.push_str(&(out as u32).to_string());
        } else {
            result.push('.');
        }
        if col == width - 1 {
            result.push('\n');
        }
    }
    result
}


/// Prints a collection of multiline strings in columns
fn print_multi_line(o: &Vec<String>, column_width: usize) {
    let o_split: Vec<_> = o.iter().map(|col| col.split('\n').collect::<Vec<_>>()).collect();
    let mut still_printing = true;
    let mut line_index = 0;
    while still_printing {
        let mut gap = 0;
        still_printing = false;
        for col in o_split.iter() {
            if col.len() > line_index {
                if gap > 0 {
                    print!("{:width$}", "", width = column_width * gap);
                    gap = 0;
                }
                let line = col[line_index];
                print!("{:width$}", line, width = column_width);
                still_printing = true;
            } else {
                gap += 1;
            }
        }
        println!();
        line_index += 1;
    }
}


/// Prints the word distribution within topics
fn print_topic_distribution(dist: &Matrix<f64>, topic_count: usize, width: usize) {
    let top_strings = &dist.row_iter()
        .take(topic_count / 2)
        .enumerate()
        .map(|(topic_index, topic)| topic_to_string(&topic.into(), width, topic_index + 1))
        .collect();
    let bottom_strings = &dist.row_iter()
        .skip(topic_count / 2)
        .enumerate()
        .map(|(topic_index, topic)| topic_to_string(&topic.into(), width, topic_index + 1 + topic_count / 2))
        .collect();

    print_multi_line(top_strings, max(12, width + 1));
    print_multi_line(bottom_strings, max(12, width + 1));
}

pub fn main() {
    let width = TOPIC_COUNT / 2;
    let vocab_count = width * width;
    println!("Creating word distribution");
    let word_distribution = generate_word_distribution(TOPIC_COUNT);
    println!("Distribution generated:");
    print_topic_distribution(&word_distribution, TOPIC_COUNT, width);
    println!("Generating documents");
    let input = generate_documents(&word_distribution, TOPIC_COUNT, vocab_count, DOCUMENT_COUNT, DOCUMENT_LENGTH, ALPHA);
    let lda = LDAFitter::new(TOPIC_COUNT, ALPHA, 0.1, ITERATION_COUNT);
    println!("Predicting word distribution from generated documents");
    let result = lda.fit(&input).unwrap();
    let dist = result.word_distribution();

    println!("Prediction completed. Predicted word distribution:");
    println!("(Should be similar to generated distribution above)");

    print_topic_distribution(&dist, TOPIC_COUNT, width);
}
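
For reference, the generative process `generate_document` implements, with $\phi_k$ the fixed per-topic word distributions built by `generate_word_distribution`, is the standard LDA model:

```latex
% For each document d, draw a topic mixture from a symmetric Dirichlet;
% for each word slot n, draw a topic z, then a word w from that topic.
\theta_d \sim \operatorname{Dirichlet}(\alpha, \dots, \alpha), \qquad
z_{d,n} \mid \theta_d \sim \operatorname{Categorical}(\theta_d), \qquad
w_{d,n} \mid z_{d,n} \sim \operatorname{Categorical}(\phi_{z_{d,n}})
```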
10 changes: 5 additions & 5 deletions examples/naive_bayes_dogs.rs
@@ -135,16 +135,16 @@ fn main() {
     // Score how well we did.
     let mut hits = 0;
     let unprinted_total = test_set_size.saturating_sub(10) as usize;
-    for (dog, prediction) in test_dogs.iter().zip(predictions.iter_rows()).take(unprinted_total) {
-        evaluate_prediction(&mut hits, dog, prediction);
+    for (dog, prediction) in test_dogs.iter().zip(predictions.row_iter()).take(unprinted_total) {
+        evaluate_prediction(&mut hits, dog, prediction.raw_slice());
     }
 
     if unprinted_total > 0 {
         println!("...");
     }
 
-    for (dog, prediction) in test_dogs.iter().zip(predictions.iter_rows()).skip(unprinted_total) {
-        let (actual_color, accurate) = evaluate_prediction(&mut hits, dog, prediction);
+    for (dog, prediction) in test_dogs.iter().zip(predictions.row_iter()).skip(unprinted_total) {
+        let (actual_color, accurate) = evaluate_prediction(&mut hits, dog, prediction.raw_slice());
         println!("Predicted: {:?}; Actual: {:?}; Accurate? {:?}",
                  dog.color, actual_color, accurate);
     }
10 changes: 6 additions & 4 deletions src/analysis/score.rs
@@ -31,9 +31,10 @@ use learning::toolkit::cost_fn::{CostFunc, MeanSqError};
 /// # Panics
 ///
 /// - outputs and targets have different length
-pub fn accuracy<I>(outputs: I, targets: I) -> f64
-    where I: ExactSizeIterator,
-          I::Item: PartialEq
+pub fn accuracy<I1, I2, T>(outputs: I1, targets: I2) -> f64
+    where T: PartialEq,
+          I1: ExactSizeIterator + Iterator<Item=T>,
+          I2: ExactSizeIterator + Iterator<Item=T>
 {
     assert!(outputs.len() == targets.len(), "outputs and targets must have the same length");
     let len = outputs.len() as f64;
@@ -46,7 +47,8 @@ pub fn accuracy<I>(outputs: I, targets: I) -> f64

 /// Returns the fraction of outputs rows which match their target.
 pub fn row_accuracy(outputs: &Matrix<f64>, targets: &Matrix<f64>) -> f64 {
-    accuracy(outputs.iter_rows(), targets.iter_rows())
+    accuracy(outputs.row_iter().map(|r| r.raw_slice()),
+             targets.row_iter().map(|r| r.raw_slice()))
 }
 
 /// Returns the precision score for 2 class classification.
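
The relaxed bounds on `accuracy` allow the two arguments to be different iterator types, which is exactly what the new `row_accuracy` relies on. A small hypothetical test sketch (not part of this PR):

```rust
#[test]
fn accuracy_accepts_mixed_iterator_types() {
    // `outputs` and `targets` are different iterator types here; the old
    // single-parameter signature required them to be identical.
    let outputs = vec![1u32, 2, 3].into_iter();
    let targets = [1u32, 2, 4].iter().cloned();
    // Two of the three predictions match.
    assert!((accuracy(outputs, targets) - 2.0 / 3.0).abs() < 1e-12);
}
```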