Skip to content

Commit

Permalink
Towards TSP
Browse files Browse the repository at this point in the history
  • Loading branch information
Vishal Paudel committed Nov 1, 2023
1 parent 4d5558d commit 71d2c62
Show file tree
Hide file tree
Showing 8 changed files with 62 additions and 49 deletions.
5 changes: 1 addition & 4 deletions src/main.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,10 @@
#!/usr/bin/env python3

from searchViz.Game import Game
from searchViz.constants import NUM_NODES, SEARCH_METHOD


def main():
search = SEARCH_METHOD()
aGame = Game(search=search, num_nodes=NUM_NODES)

aGame = Game()
aGame.run()


Expand Down
14 changes: 7 additions & 7 deletions src/searchViz/Game.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,28 +12,28 @@
RED,
NODE_RADIUS,
SEARCH_RATE,
SEARCH_METHOD,
NUM_NODES
)


from .Graph import Graph
from ._typing import NodeCount

from .Search import Search


class Game:
def __init__(self, search: Search, num_nodes: NodeCount | int) -> None:
def __init__(self) -> None:
pg.init()

# main attributes of the game
self.graph = Graph(num_nodes)
self.search = search
self.graph = Graph(n=NUM_NODES)

self.search = SEARCH_METHOD()
self.search.graph = self.graph

# pg initialization
self.screen = pg.display.set_mode(SCR_SIZE)
self.graph_surf = pg.Surface(self.screen.get_size(), pg.SRCALPHA)
pg.display.set_caption(f"Search Method: {search.name}")
pg.display.set_caption(f"Search Method: {self.search.name}")

# more helper attributes
self.font = pg.font.Font(None, 36)
Expand Down
42 changes: 29 additions & 13 deletions src/searchViz/Graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,8 @@ def __init__(self, n: NodeCount | int):

# start and end nodes
start, goal = random.randint(0, self.N_num, size=2, dtype=NodeCount)
self.start_node = Node(start)
self.goal_node = Node(goal)
self.start = start
self.goal = goal

self.N_colors[start] = colors["YELLOW"]
self.N_colors[goal] = colors["RED"]
Expand Down Expand Up @@ -64,29 +64,45 @@ def MoveGen(self, state: NodeType) -> NodeList:
"""
neighbors = []

id = NodeCount(0)
id = 0
while id < self.N_num:
if (
self.edge_connections[id, state.id]
or self.edge_connections[state.id, id]
self.edge_connections[id, state]
or self.edge_connections[state, id]
):
neighbors.append(Node(id))
neighbors.append(id)

id += 1

# print("neighbors", neighbors)
# exit(0)

# # updating the node-ui : will not work, as need to know closed for not re-entering node colors
# for node in neighbors:
# id = node.id
# self.N_colors[id] = colors["BLUE"]
# self.N_radii[id] = 2 * NODE_RADIUS

return neighbors

def GoalTest(self, state: NodeType) -> bool:
foundGoal = state.id == self.goal_node.id
foundGoal = state == self.goal

# # updating the node-ui
# self.N_colors[id] = colors["RED"]
# self.N_radii[id] = 2 * NODE_RADIUS

return foundGoal

def update_nodes(self):
for open_ids in self.open_ids:
self.N_colors[open_ids[0].id] = colors["BLUE"]
self.N_radii[open_ids[0].id] = 4 * NODE_RADIUS
self.N_colors[self.open_ids] = colors["BLUE"]
self.N_radii[self.open_ids] = 2 * NODE_RADIUS

for id in self.closed_ids:
self.N_colors[id] = colors["RED"]
self.N_radii[id] = 1.5 * NODE_RADIUS

for closed_ids in self.closed_ids:
self.N_colors[closed_ids[0].id] = colors["RED"]
self.N_radii[closed_ids[0].id] = 1.5 * NODE_RADIUS
return


def create_nodes(n: NodeCount) -> NodeLocs:
Expand Down
31 changes: 16 additions & 15 deletions src/searchViz/Search.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,52 +2,55 @@

from .search_utils import ReconstructPath, MakePairs, RemoveSeen

from abc import ABC, abstractmethod
from typing import TYPE_CHECKING

if TYPE_CHECKING:
from .Graph import Graph


class Search:
def __init__(self, name, graph: Graph):
class Search(ABC):
def __init__(self, name: str, graph):
self.name = name
# self.search = types.MethodType(search, Search)
self.graph = graph
self.graph: Graph = graph

@abstractmethod
def search(self):
raise NotImplementedError


class depthfirstsearch(Search):
def __init__(self):
self.name = "DFS"
self.graph: Graph
self.graph : Graph

self.i = 0

def search(self):
open = [(self.graph.start_node, None)]
open = [(self.graph.start, None)]
closed = []

print(open[0][0].id)

while len(open) > 0:
nodePair = open[0]
node = nodePair[0]
if self.graph.GoalTest(node):
print("Found goal!")
path = [node.id for node in ReconstructPath(nodePair, closed)]
path = [node for node in ReconstructPath(nodePair, closed)]
print(path)
return path
else:
closed = [nodePair] + closed
closed.insert(0, nodePair)
self.graph.closed_ids.insert(0, nodePair[0])

# the following methods are time-hogs because of stack creations because of function calls
children = self.graph.MoveGen(node)
noLoops = RemoveSeen(children, open, closed)
new = MakePairs(noLoops, node)
open = new + open[1:] # The only change from DFS

self.graph.open_ids = open
self.graph.closed_ids = closed
open[0:1] = new
self.graph.open_ids[0:1] = [node[0] for node in new]


yield
print("No path found")
Expand All @@ -62,11 +65,9 @@ def __init__(self):
self.i = 0

def search(self):
open = [(self.graph.start_node, None)]
open = [(self.graph.start, None)]
closed = []

print(open[0][0].id)

while len(open) > 0:
nodePair = open[0]
node = nodePair[0]
Expand Down
2 changes: 1 addition & 1 deletion src/searchViz/_typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
# pass


NodeType = Node
NodeType = np.uint16
NodeLocs = NDArray[np.float64]
NodeList = typing.List[Node]
NodeCount = np.uint16
Expand Down
14 changes: 7 additions & 7 deletions src/searchViz/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,22 +12,22 @@

# Search configuration
SEARCH_METHOD = depthfirstsearch
SEARCH_RATE = 0.05
SEARCH_RATE = 0 #.0000000000000000005

# Graph configuration
NUM_NODES = 2000
NUM_NODES = 5000
NODE_RADIUS = 2


# MODEL CONFIGURATIONS
# distribution of nodes, available options: uniform, gaussian
NODE_X_MODEL = "gaussian"
NODE_Y_MODEL = "gaussian"
NODE_X_MODEL = "uniform"
NODE_Y_MODEL = "uniform"

# available options threshold, exponential
E_MODEL = "threshold"
E_MAX_LEN = 20
E_MIN_LEN = 15
E_MAX_LEN = 24
E_MIN_LEN = 13

# OTHERS
N_X_DISTRIBUTION = N_Distbn(SCREEN_WIDTH, model=NODE_X_MODEL)
Expand All @@ -39,7 +39,7 @@
RED = (255, 0, 0, 255)
GREEN = (0, 255, 0, 255)
BLUE = (0, 255, 255, 255)
WHITE = (255, 255, 255, 255)
WHITE = (255, 255, 255, 120)
YELLOW = (255, 255, 153, 200)
BLACK = (0, 0, 0, 255)

Expand Down
2 changes: 1 addition & 1 deletion src/searchViz/search_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ def RemoveSeen(nodeList, openList, closedList):

def OccursIn(node, listOfPairs):
for pair in listOfPairs:
if node.id == pair[0].id:
if node == pair[0]:
return True
return False

Expand Down
1 change: 0 additions & 1 deletion src/searchViz/utils.py
Original file line number Diff line number Diff line change
@@ -1 +0,0 @@
import time

0 comments on commit 71d2c62

Please sign in to comment.