Skip to content

Commit

Permalink
🎨 Type hinted the heck out of the project
Browse files Browse the repository at this point in the history
Some things are still left to type hint.
  • Loading branch information
vishalpaudel committed Oct 28, 2023
1 parent c337a11 commit 44e317b
Show file tree
Hide file tree
Showing 9 changed files with 183 additions and 123 deletions.
6 changes: 5 additions & 1 deletion src/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,15 @@
from searchViz.Search import depthfirstsearch
from searchViz.constants import NUM_NODES

from searchViz._typing import NodeCount

# to hide the message


def main():
aGame = Game(search_method=depthfirstsearch, num_nodes=NUM_NODES)
dfs = depthfirstsearch()

aGame = Game(search=dfs, num_nodes=NodeCount(NUM_NODES))

aGame.run()

Expand Down
47 changes: 30 additions & 17 deletions src/searchViz/Game.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,19 +9,24 @@


from .Graph import Graph
from ._typing import NodeCount, SearchMethod

from .Search import Search


class Game:
def __init__(self, search_method: search_method, num_nodes: int) -> None:
def __init__(self, search: Search, num_nodes: NodeCount) -> None:
pg.init()

# main attributes of the game
self.Graph = Graph(num_nodes)
self.graph = Graph(num_nodes)
self.search = search
self.search.graph = self.graph

# pg initialization
self.screen = pg.display.set_mode(SCR_SIZE)
self.graph_surf = pg.Surface(self.screen.get_size(), pg.SRCALPHA)
pg.display.set_caption(f"Search Method: {search_method.name}")
pg.display.set_caption(f"Search Method: {search.name}")

# more helper attributes
self.font = pg.font.Font(None, 36)
Expand Down Expand Up @@ -60,15 +65,21 @@ def handle_events(self) -> None:

def draw_graph(self) -> None:
_start_time = time.time()
print("\tDrawing the Graph")
print("\n🕰\tDrawing the Graph")

# Unpack frequently used variables
graph_surf = self.graph_surf
num_nodes = self.Graph.N_num
nodes = self.Graph.N_locs
colors = self.Graph.N_colors
radius = self.Graph.N_radii
edges = self.Graph.edge_connections
num_nodes = self.graph.N_num
nodes = self.graph.N_locs
colors = self.graph.N_colors
radius = self.graph.N_radii
edges = self.graph.edge_connections

# Draw edges
edge_indices = np.transpose(np.where(edges))
for i, j in edge_indices:
# TODO: tuple type conversion overhead? Probably not
pg.draw.line(graph_surf, WHITE, tuple(nodes[i]), tuple(nodes[j]))

# Draw nodes
# TODO: vectorization of this possible? Probably not
Expand All @@ -77,15 +88,12 @@ def draw_graph(self) -> None:
graph_surf, color=colors[i], center=nodes[i], radius=radius[i]
)

# Draw edges
edge_indices = np.transpose(np.where(edges))
for i, j in edge_indices:
# TODO: tuple type conversion overhead? Probably not
pg.draw.line(graph_surf, WHITE, tuple(nodes[i]), tuple(nodes[j]))
_txt = self.font.render(f"{i}", 1, (255, 255, 255))
graph_surf.blit(_txt, nodes[i])

_end_time = time.time()
_elapsed_time = _end_time - _start_time
print("✏️\tCompleted the drawing!")
print("\tCompleted the drawing!\n")

print(f"\t🕰️ Took {_elapsed_time:.3f} seconds.\n")

Expand All @@ -95,6 +103,8 @@ def run(self) -> None:
last_time = pg.time.get_ticks()
step = 0
self.draw_graph()

generated_open = self.search.search()
while self.running:
cur_time = pg.time.get_ticks()

Expand All @@ -114,7 +124,10 @@ def run(self) -> None:
step += 1
last_time = cur_time
# APPLY SEARCH HERE

self.Search.run()
try:
print(next(generated_open))
except StopIteration:
print("Goal Not found")
self.start_search = False

pg.quit()
57 changes: 28 additions & 29 deletions src/searchViz/Graph.py
Original file line number Diff line number Diff line change
@@ -1,42 +1,41 @@
#!/usr/bin/env python

from .constants import NODES_X_DISTRIBUTION, NODES_Y_DISTRIBUTION
from .constants import NODES_X_DISTRIBUTION, NODES_Y_DISTRIBUTION, RED, YELLOW
from .constants import EDGE_CONFIDENCE, BLUE, NODE_RADIUS, WHITE

import numpy as np
import time
from typing import List, Callable

import threading

from ._typing import NodeType, NodeList, NodeLocs, NodeCount
from .Node import Node

class Node:
def __init__(self, id: np.uint16):
"""
id: the id of the node, np.uint16 for maximum
of 2^16 (= 65535 + 1) nodes on the screen
"""
self.id = id
import threading


class Graph:
def __init__(self, n: int):
def __init__(self, n: NodeCount):
self.N_num = n

_start_time = time.time()

# start and end nodes
start, end = np.random.randint(0, self.N_num, size=2, dtype=np.uint16)
self.start_node = Node(start)
self.end_node = Node(end)

# Node stuff
self.N = [Node(np.uint16(i)) for i in range(self.N_num)]
self.N = np.arange(0, self.N_num, dtype=NodeCount)
self.N_locs = create_nodes(self.N_num)

self.N_colors = np.full((self.N_num, 4), BLUE, dtype=np.uint8)
self.N_radii = np.full((self.N_num,), NODE_RADIUS, dtype=np.uint8)

# start and end nodes
start, goal = np.random.randint(0, self.N_num, size=2, dtype=NodeCount)
self.start_node = Node(start)
self.goal_node = Node(goal)

self.N_colors[self.start_node.id] = YELLOW
self.N_colors[self.goal_node.id] = RED

self.N_radii[self.start_node.id] = 5 * NODE_RADIUS
self.N_radii[self.goal_node.id] = 5 * NODE_RADIUS

# Edge stuff
self.edge_connections, self.edge_colors = create_edges(self.N_locs)

Expand All @@ -45,34 +44,34 @@ def __init__(self, n: int):

print(f"\t🕰️ Took {_elapsed_time:.3f} seconds.\n")

# The search function
self.run: Callable[[Node], List[Node]]

########################################################################
# THE MOVEGEN AND GOALTEST FUNCTIONS
########################################################################

def MoveGen(self, state: Node) -> List[Node]:
def MoveGen(self, state: NodeType) -> NodeList:
"""
Takes a state and returns an array of neighbors
"""
neighbors = []

id = np.uint16(0)
while id < self.N_num:
if self.edge_connections[id, state.id] == 1:
if (
self.edge_connections[id, state.id]
or self.edge_connections[state.id, id]
):
neighbors.append(Node(id))
id += 1

print()
return neighbors

def GoalTest(self, state: Node) -> bool:
print(state)
return True
def GoalTest(self, state: NodeType) -> bool:
print(state.id, self.goal_node.id)
return state.id == self.goal_node.id


def create_nodes(n: int):
def create_nodes(n: NodeCount) -> NodeLocs:
"""
Returns a vector of `n` points, colors, radii:
[
Expand All @@ -85,8 +84,8 @@ def create_nodes(n: int):
"""

print(f"🕰️\tCreating {n} nodes... timer started...")
x_values = NODES_X_DISTRIBUTION(n)
y_values = NODES_Y_DISTRIBUTION(n)
x_values = NODES_X_DISTRIBUTION(int(n))
y_values = NODES_Y_DISTRIBUTION(int(n))

node_locs = np.column_stack((x_values, y_values))

Expand Down
11 changes: 11 additions & 0 deletions src/searchViz/Node.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
import numpy as np


# Seperated from the Graph.py file because of circular imports in _typing.py
class Node:
def __init__(self, id: np.uint16):
"""
id: the id of the node, np.uint16 for maximum
of 2^16 (= 65535 + 1) nodes on the screen
"""
self.id = id
65 changes: 40 additions & 25 deletions src/searchViz/Search.py
Original file line number Diff line number Diff line change
@@ -1,30 +1,45 @@
#!/usr/bin/env python

from .search_utils import ReconstructPath, MakePairs, RemoveSeen
from .Graph import Graph


def depthfirstsearch(self, start_state):
open = [(start_state, None)]
closed = []

while len(open) > 0 and len(closed) < 50:
print(f"open: {len(open)}")
print(f"closed: {len(closed)}")
nodePair = open[0]
node = nodePair[0]
# print(node)
# print(f"closed {closed}")
# print()
if self.GoalTest(node) is True:
print("Found goal!")
return ReconstructPath(nodePair, closed)
else:
closed = [nodePair] + closed
children = self.MoveGen(node)
noLoops = RemoveSeen(children, open, closed)
new = MakePairs(noLoops, node)
open = new + open[1:] # The only change from DFS
yield open

print("No path found")
return -1
class Search:
def __init__(self, name, graph: Graph):
self.name = name
# self.search = types.MethodType(search, Search)
self.graph = graph

def search(self):
raise NotImplementedError


class depthfirstsearch(Search):
def __init__(self):
self.name = "DFS"
self.graph: Graph

self.i = 0

def search(self):
open = [(self.graph.start_node, None)]
print(open[0][0].id)
closed = []

while len(open) > 0:
nodePair = open[0]
node = nodePair[0]
if self.graph.GoalTest(node) is True:
# print("Found goal!")
# print(path = ReconstructPath(nodePair, closed)])
return
else:
closed = [nodePair] + closed
children = self.graph.MoveGen(node)
noLoops = RemoveSeen(children, open, closed)
new = MakePairs(noLoops, node)
open = new + open[1:] # The only change from DFS
yield [nodepair[0].id for nodepair in open]

# print("No path found")
return -1
21 changes: 21 additions & 0 deletions src/searchViz/_typing.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
#!/usr/bin/env python3

from .Node import Node

import numpy as np

import typing
from numpy.typing import NDArray

# TODO: put Node.py Node class inside Graph file? Circular imports?
# from __future__ import annotations
# from typing import TYPE_CHECKING
# if TYPE_CHECKING:
# pass


NodeType = Node
NodeLocs = NDArray[np.float64]
NodeList = typing.List[Node]
NodeCount = np.uint16
SearchMethod = typing.Callable[[], typing.List[Node]]
25 changes: 11 additions & 14 deletions src/searchViz/constants.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
#!/usr/bin/env python

from .models import nodes_dist_wrapper, edge_dist_wrapper
from .models import N_Distbn, E_Distbn


# Screen configuration
Expand All @@ -11,29 +11,26 @@


# Search configuration
NUM_NODES = 5000
NUM_NODES = 15
NODE_RADIUS = 2

SEARCH_RATE = 0.5
SEARCH_RATE = 2

# MODEL CONFIGURATIONS
# distribution of nodes, available options: uniform, gaussian
NODE_X_MODEL = "uniform"
NODE_Y_MODEL = "uniform"
NODE_X_MODEL = "gaussian"
NODE_Y_MODEL = "gaussian"

# available options threshold, exponential
EDGE_MODEL = "threshold"
MAX_EDGE_LEN = 20
MIN_EDGE_LEN = 15

E_MODEL = "threshold"
E_MAX_LEN = 200
E_MIN_LEN = 15

# OTHERS
NODES_X_DISTRIBUTION = nodes_dist_wrapper(SCREEN_WIDTH, model=NODE_X_MODEL)
NODES_Y_DISTRIBUTION = nodes_dist_wrapper(SCREEN_HEIGHT, model=NODE_Y_MODEL)
NODES_X_DISTRIBUTION = N_Distbn(SCREEN_WIDTH, model=NODE_X_MODEL)
NODES_Y_DISTRIBUTION = N_Distbn(SCREEN_HEIGHT, model=NODE_Y_MODEL)

EDGE_CONFIDENCE = edge_dist_wrapper(
max_edge_len=MAX_EDGE_LEN, min_edge_len=MIN_EDGE_LEN, model=EDGE_MODEL
)
EDGE_CONFIDENCE = E_Distbn(max_len=E_MAX_LEN, min_len=E_MIN_LEN, model=E_MODEL)

# Colours (R G B A), A is the opacity (255 is opaque)
RED = (255, 0, 0, 255)
Expand Down
Loading

0 comments on commit 44e317b

Please sign in to comment.