Skip to content

Commit

Permalink
fixed collection of sparse vectors (#297)
Browse files Browse the repository at this point in the history
* fixed collection of sparse vectors

* applied automatic formatting

Co-authored-by: Anthony Khong <anthony.kusumo.khong@gmail.com>
behrica and anthony-khong authored Nov 22, 2020
1 parent d0bc4e7 commit 482c4b9
Showing 4 changed files with 22 additions and 5 deletions.
7 changes: 6 additions & 1 deletion src/clojure/zero_one/geni/interop.clj
Original file line number Diff line number Diff line change
@@ -118,6 +118,11 @@
(defn vector->seq
  "Returns `seq` applied to the result of calling `.values` on
  `spark-vector` (presumably a Spark ML vector -- confirm against callers).
  As with any call to `seq`, an empty backing collection yields nil."
  [spark-vector]
  (seq (.values spark-vector)))

(defn sparse-vector->seq
  "Converts `spark-sparse-vector` into a Clojure map with three keys:
    :size    - the result of `.size` on the vector
    :indices - `seq` of the vector's `.indices`
    :values  - `seq` of the vector's `.values`
  Note that, despite the name, the result is a map rather than a seq."
  [spark-sparse-vector]
  (let [size    (.size spark-sparse-vector)
        indices (seq (.indices spark-sparse-vector))
        values  (seq (.values spark-sparse-vector))]
    {:size    size
     :indices indices
     :values  values}))

(defn matrix->seqs
  "Walks the rows of `matrix` (via `.rowIter` and `.toSeq`), converts the
  resulting Scala sequence with `scala-seq->vec`, and maps `vector->seq`
  over each row. The outer result is the lazy sequence produced by `map`."
  [matrix]
  (let [rows (scala-seq->vec (.toSeq (.rowIter matrix)))]
    (map vector->seq rows)))

@@ -139,7 +144,7 @@
(scala-map? value) (scala-map->map value)
(spark-row? value) (spark-row->map value)
(dense-vector? value) (vector->seq value)
(sparse-vector? value) (vector->seq value)
(sparse-vector? value) (sparse-vector->seq value)
(dense-matrix? value) (matrix->seqs value)
(scala-tuple2? value) [(->clojure (._1 value)) (->clojure (._2 value))]
(scala-tuple3? value) [(->clojure (._1 value))
6 changes: 5 additions & 1 deletion test/zero_one/geni/data_sources_test.clj
Original file line number Diff line number Diff line change
@@ -232,7 +232,11 @@
(let [temp-file (.toString (create-temp-file! ".libsvm"))
read-df (do (g/write-libsvm! (libsvm-df) temp-file {:mode "overwrite"})
(g/read-libsvm! temp-file))]
(g/collect (libsvm-df)) => (g/collect read-df)))
(map #(get-in % [:features :indices]) (g/collect (libsvm-df))) => (map #(get-in % [:features :indices]) (g/collect read-df))
(map #(get-in % [:features :values]) (g/collect (libsvm-df))) => (map #(get-in % [:features :values]) (g/collect read-df))

;; (map :indices (g/collect (libsvm-df))) => (map :indices (g/collect read-df))
))

(fact "Can read and write json"
(let [temp-file (.toString (create-temp-file! ".json"))
8 changes: 8 additions & 0 deletions test/zero_one/geni/dataset_test.clj
Original file line number Diff line number Diff line change
@@ -533,3 +533,11 @@
exploded (g/with-column agged "exploded" (g/explode "suburbs_list"))]
(g/count agged) => #(< % 20)
(g/count exploded) => 20)))

(facts "On sparse vector"
  (fact "collects sparse data"
    ;; Build a one-row dataframe whose only column, :test, holds a sparse
    ;; vector, then check that collecting the column yields the
    ;; {:size :indices :values} map representation.
    (let [sparse-row (g/row (g/sparse 4 [1 3] [3.0 4.0]))
          sparse-df  (g/create-dataframe [sparse-row] {:test :vector})]
      (g/collect-col sparse-df :test)
      => [{:size 4 :indices [1 3] :values [3.0 4.0]}])))
6 changes: 3 additions & 3 deletions test/zero_one/geni/ml_test.clj
Original file line number Diff line number Diff line change
@@ -259,15 +259,15 @@
(let [estimator (ml/random-forest-classifier {})
model (ml/fit (libsvm-df) estimator)]
(fact "Attributes are callable"
(ml/feature-importances model) => #(every? double? %)
(:values (ml/feature-importances model)) => #(every? double? %)
(ml/total-num-nodes model) => int?
(ml/trees model) => seq?)))

(facts "On gradient boosted tree classifier" :slow
(let [estimator (ml/gbt-classifier {:max-iter 2 :max-depth 2})
model (ml/fit (libsvm-df) estimator)]
(fact "Attributes are callable"
(ml/feature-importances model) => #(every? double? %)
(:values (ml/feature-importances model)) => #(every? double? %)
(ml/total-num-nodes model) => int?
(ml/trees model) => seq?
(ml/get-num-trees model) => int?
@@ -718,7 +718,7 @@
transformed (-> dataset
(ml/transform transformer)
(g/select "features"))]
(->> transformed g/collect-vals flatten) => #(every? double? %)
;; Assuming the collected "features" values are sparse-vector maps of the
;; form {:size .. :indices .. :values ..}: `flatten` leaves maps intact, so
;; the doubles must be pulled out of each map with `mapcat`. Applying
;; `:values` to the flattened sequence itself returns nil, and
;; `(every? double? nil)` is true, so the check would pass vacuously.
(->> transformed g/collect-vals flatten (mapcat :values)) => #(every? double? %)
(-> transformer ml/stages last ml/idf-vector) => #(every? double? %)))
(fact "should be able to fit the word2vec example" :slow
(let [dataset (g/table->dataset

0 comments on commit 482c4b9

Please sign in to comment.