-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathInitializer101.py
54 lines (42 loc) · 2.65 KB
/
Initializer101.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
import pyspark.sql.types as T
import pyspark.sql.functions as F
from pyspark.sql.window import Window
from Initializer import Initializer
from pyspark import StorageLevel
from pyspark.mllib.linalg.distributed import MatrixEntry
import sys
sys.path.insert(1, '../utils')
sys.path.insert(1, '..')
import SparseDistributedVector as sdv
import SparseDistributedMatrix as sdm
import PandiNetwork as pn
class Initializer101(Initializer):
    """Initializer that seeds exact counts of infected and recovered vertices.

    Vertices are the integer ids ``0 .. nbr_vertices - 1``. A random subset of
    ``nbr_infected`` ids gets ``health_status`` 1.0, a disjoint random subset
    of ``nbr_recovered`` ids gets ``health_status`` -1.0, and every other id
    gets 0.0. Edges pair a random set of source ids one-to-one with a random
    set of destination ids drawn from the non-source vertices.
    """

    def __init__(self, nbr_vertices, nbr_edges, nbr_infected = 0, nbr_recovered = 0):
        """Store the requested sizes and delegate to ``Initializer``.

        Args:
            nbr_vertices: number of vertex ids to generate (``0 .. n - 1``).
            nbr_edges: maximum number of edges to generate.
            nbr_infected: number of vertices to mark with status 1.0.
            nbr_recovered: number of vertices to mark with status -1.0.
        """
        self.nbr_vertices = nbr_vertices
        self.nbr_edges = nbr_edges
        self.nbr_infected = nbr_infected
        self.nbr_recovered = nbr_recovered
        # NOTE(review): ``self.spark`` is used below but not set here; presumably
        # the base ``Initializer.__init__`` provides it — confirm.
        super().__init__(nbr_vertices, nbr_edges, nbr_infected, nbr_recovered)

    def initialize_vertices(self):
        """Create the vertex DataFrame and store it in ``self.vertices``.

        Every id receives one random sort key. Ids are ranked by that key in a
        single window pass, and the rank decides the status:

        * ranks ``1 .. nbr_infected`` -> ``health_status`` 1.0, ``score`` 1.0
        * the next ``nbr_recovered`` ranks -> ``health_status`` -1.0, ``score`` 0.0
        * all remaining ranks -> ``health_status`` 0.0, ``score`` 0.0

        If ``nbr_infected + nbr_recovered`` exceeds ``nbr_vertices``, the later
        categories simply receive fewer (possibly zero) vertices.

        Returns:
            DataFrame with columns ``id``, ``score``, ``health_status``,
            ordered by ``id`` and persisted with MEMORY_AND_DISK.
        """
        # All statuses are derived from ONE ranking. The previous approach
        # chained ``limit``/``exceptAll`` over a persisted random order: that
        # ``limit`` is not guaranteed to return the same rows on every
        # evaluation, and it was evaluated twice, so one id could end up both
        # infected and recovered (a duplicated vertex). It also left two
        # persisted intermediates that were never released.
        keyed = (
            self.spark.range(0, self.nbr_vertices, 1)
            .toDF('id')
            .withColumn('_order_key', F.rand())
        )
        # 'id' breaks (unlikely) ties in the random key so the order is total.
        # Like the windows in initialize_edges, this un-partitioned window
        # ranks all rows in a single partition.
        rank = F.row_number().over(Window.orderBy('_order_key', 'id'))
        ranked = keyed.withColumn('_rank', rank)

        nbr_affected = self.nbr_infected + self.nbr_recovered
        health_status = (
            F.when(F.col('_rank') <= F.lit(self.nbr_infected), F.lit(1.0))
            .when(F.col('_rank') <= F.lit(nbr_affected), F.lit(-1.0))
            .otherwise(F.lit(0.0))
        )
        # Recovered vertices (-1.0) score 0.0; everyone else scores their status.
        score = (
            F.when(F.col('health_status') == F.lit(-1.0), F.lit(0.0))
            .otherwise(F.col('health_status'))
        )

        self.vertices = (
            ranked.withColumn('health_status', health_status)
            .withColumn('score', score)
            .select('id', 'score', 'health_status')
            .orderBy('id')
            .persist(StorageLevel.MEMORY_AND_DISK)
        )
        return self.vertices

    def initialize_edges(self, vertices):
        """Create a random edge DataFrame and store it in ``self.edges``.

        Picks up to ``nbr_edges`` random ids as sources. It then picks up to
        ``nbr_edges`` random ids from the vertices NOT chosen as sources as
        destinations. The i-th source is paired with the i-th destination.
        Sources and destinations are disjoint. Assuming distinct ids, fewer
        than ``nbr_edges`` edges result when ``2 * nbr_edges`` exceeds the
        vertex count.

        Side effect: registers (or replaces) the session temp views ``src``
        and ``vertices``.

        Args:
            vertices: DataFrame with an ``id`` column (e.g. the result of
                ``initialize_vertices``).

        Returns:
            DataFrame with columns ``src`` and ``dst``, persisted with
            MEMORY_AND_DISK.
        """
        # Random sample of source ids, numbered 1..k so they can be zipped
        # with the destinations below.
        src = (
            vertices.select(F.col("id"))
            .orderBy(F.rand())
            .limit(self.nbr_edges)
            .withColumnRenamed("id", "src")
            .withColumn("id", F.row_number().over(Window.orderBy(F.monotonically_increasing_id())))
        )
        src.createOrReplaceTempView("src")
        vertices.createOrReplaceTempView("vertices")
        # Candidate destinations: every vertex id not used as a source
        # (MINUS is Spark SQL's alias for EXCEPT, i.e. distinct set difference).
        query = self.spark.sql("select vertices.id from vertices minus select src.src from src")
        dst = (
            query.orderBy(F.rand())
            .limit(self.nbr_edges)
            .withColumnRenamed("id", "dst")
            .withColumn("id", F.row_number().over(Window.orderBy(F.monotonically_increasing_id())))
        )
        # Inner join on the row numbers pairs source i with destination i; the
        # shorter of the two samples bounds the number of edges.
        self.edges = (
            src.join(dst, src.id == dst.id)
            .select(F.col('src'), F.col('dst'))
            .persist(StorageLevel.MEMORY_AND_DISK)
        )
        return self.edges

    def toPandiNetwork(self):
        """Wrap the generated vertices and edges in a ``PandiNetwork``.

        Reads ``self.vertices`` and ``self.edges``. These are set by
        ``initialize_vertices`` and ``initialize_edges``, which are expected
        to have run first (unless the base class sets them — not visible here).
        """
        return pn.PandiNetwork(self.vertices, self.edges, self.nbr_vertices)