-
Notifications
You must be signed in to change notification settings - Fork 32
/
common.py
346 lines (264 loc) · 10.4 KB
/
common.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
# -*- coding: utf-8 -*-
"""
Created on Fri May 26 2023 16:03:59
Modified on 2023-5-26 16:03:59
@auther: HJ https://github.com/zhaohaojie1998
"""
# 算法共同组成部分
from typing import Union
import cv2
import time
import math
import numpy as np
import matplotlib.pyplot as plt
from queue import PriorityQueue
from dataclasses import dataclass, field
Number = Union[int, float]
__all__ = ['tic', 'toc', 'limit_angle', 'GridMap', 'PriorityQueuePro', 'ListQueue', 'SetQueue', 'Node']
# 坐标节点
@dataclass(eq=False)
class Node:
"""节点"""
x: int
y: int
cost: Number = 0 # F代价
parent: "Node" = None # 父节点指针
def __sub__(self, other) -> int:
"""计算节点与坐标的曼哈顿距离"""
if isinstance(other, Node):
return abs(self.x - other.x) + abs(self.y - other.y)
elif isinstance(other, (tuple, list)):
return abs(self.x - other[0]) + abs(self.y - other[1])
raise ValueError("other必须为坐标或Node")
def __add__(self, other: Union[tuple, list]) -> "Node":
"""生成新节点"""
x = self.x + other[0]
y = self.y + other[1]
cost = self.cost + math.sqrt(other[0]**2 + other[1]**2) # 欧式距离
return Node(x, y, cost, self)
def __eq__(self, other):
"""坐标x,y比较 -> node in list"""
if isinstance(other, Node):
return self.x == other.x and self.y == other.y
elif isinstance(other, (tuple, list)):
return self.x == other[0] and self.y == other[1]
return False
def __le__(self, other: "Node"):
"""代价<=比较 -> min(open_list)"""
return self.cost <= other.cost
def __lt__(self, other: "Node"):
"""代价<比较 -> min(open_list)"""
return self.cost < other.cost
def __hash__(self) -> int:
"""使可变对象可hash, 能放入set中 -> node in set"""
return hash((self.x, self.y)) # tuple可hash
# data in set 时间复杂度为 O(1), 但data必须可hash
# data in list 时间复杂度 O(n)
# Set版优先队列
@dataclass
class SetQueue:
"""节点优先存储队列 set 版"""
queue: set[Node] = field(default_factory=set)
# Queue容器增强
def __bool__(self):
"""判断: while Queue:"""
return bool(self.queue)
def __contains__(self, item):
"""包含: pos in Queue"""
return item in self.queue
#NOTE: in是值比较, 只看hash是否在集合, 不看id是否在集合
def __len__(self):
"""长度: len(Queue)"""
return len(self.queue)
# PriorityQueue操作
def get(self):
"""Queue 弹出代价最小节点"""
node = min(self.queue) # O(n)?
self.queue.remove(node) # O(1)
return node
def put(self, node: Node):
"""Queue 加入/更新节点"""
if node in self.queue: # O(1)
qlist = list(self.queue) # 索引元素, set无法索引需转换
idx = qlist.index(node) # O(n)
if node.cost < qlist[idx].cost: # 新节点代价更小则加入新节点
self.queue.remove(node) # O(1)
self.queue.add(node) # O(1) 移除node和加入node的hash相同, 但cost和parent不同
else:
self.queue.add(node) # O(1)
def empty(self):
"""Queue 是否为空"""
return len(self.queue) == 0
# List版优先队列
@dataclass
class ListQueue:
"""节点优先存储队列 list 版"""
queue: list[Node] = field(default_factory=list)
# Queue容器增强
def __bool__(self):
"""判断: while Queue:"""
return bool(self.queue)
def __contains__(self, item):
"""包含: pos in Queue"""
return item in self.queue
#NOTE: in是值比较, 只看value是否在列表, 不看id是否在列表
def __len__(self):
"""长度: len(Queue)"""
return len(self.queue)
def __getitem__(self, idx):
"""索引: Queue[i]"""
return self.queue[idx]
# List操作
def append(self, node: Node):
"""List 添加节点"""
self.queue.append(node) # O(1)
def pop(self, idx = -1):
"""List 弹出节点"""
return self.queue.pop(idx) # O(1) ~ O(n)
# PriorityQueue操作
def get(self):
"""Queue 弹出代价最小节点"""
idx = self.queue.index(min(self.queue)) # O(n) + O(n)
return self.queue.pop(idx) # O(1) ~ O(n)
def put(self, node: Node):
"""Queue 加入/更新节点"""
if node in self.queue: # O(n)
idx = self.queue.index(node) # O(n)
if node.cost < self.queue[idx].cost: # 新节点代价更小
self.queue[idx].cost = node.cost # O(1) 更新代价
self.queue[idx].parent = node.parent # O(1) 更新父节点
else:
self.queue.append(node) # O(1)
# NOTE try语法虽然时间复杂度更小, 但频繁抛出异常速度反而更慢
# try:
# idx = self.queue.index(node) # O(n)
# if node.cost < self.queue[idx].cost: # 新节点代价更小
# self.queue[idx].cost = node.cost # O(1) 更新代价
# self.queue[idx].parent = node.parent # O(1) 更新父节点
# except ValueError:
# self.queue.append(node) # O(1)
def empty(self):
"""Queue 是否为空"""
return len(self.queue) == 0
# 原版优先队列增强(原版也是list实现, 但get更快, put更慢)
class PriorityQueuePro(PriorityQueue):
"""节点优先存储队列 原版"""
# PriorityQueue操作
def put(self, item, block=True, timeout=None):
"""Queue 加入/更新节点"""
if item in self.queue: # O(n)
return # 修改数据会破坏二叉树结构, 就不存了
else:
super().put(item, block, timeout) # O(logn)
# Queue容器增强
def __bool__(self):
"""判断: while Queue:"""
return bool(self.queue)
def __contains__(self, item):
"""包含: pos in Queue"""
return item in self.queue
#NOTE: in是值比较, 只看value是否在列表, 不看id是否在列表
def __len__(self):
"""长度: len(Queue)"""
return len(self.queue)
def __getitem__(self, idx):
"""索引: Queue[i]"""
return self.queue[idx]
# 图像处理生成网格地图
class GridMap:
"""从图片中提取栅格地图"""
def __init__(
self,
img_path: str,
thresh: int,
high: int,
width: int,
):
"""提取栅格地图
Parameters
----------
img_path : str
原图片路径
thresh : int
图片二值化阈值, 大于阈值的部分被置为255, 小于部分被置为0
high : int
栅格地图高度
width : int
栅格地图宽度
"""
# 存储路径
self.__map_path = 'map.png' # 栅格地图路径
self.__path_path = 'path.png' # 路径规划结果路径
# 图像处理 # NOTE cv2 按 HWC 存储图片
image = cv2.imread(img_path, cv2.IMREAD_GRAYSCALE) # 读取原图 H,W,C
thresh, map_img = cv2.threshold(image, thresh, 255, cv2.THRESH_BINARY) # 地图二值化
map_img = cv2.resize(map_img, (width, high)) # 设置地图尺寸
cv2.imwrite(self.__map_path, map_img) # 存储二值地图
# 栅格地图属性
self.map_array = np.array(map_img)
"""ndarray地图, H*W, 0代表障碍物"""
self.high = high
"""ndarray地图高度"""
self.width = width
"""ndarray地图宽度"""
def show_path(self, path_list, *, save = False):
"""路径规划结果绘制
Parameters
----------
path_list : list[Node]
路径节点组成的列表, 要求Node有x,y属性
save : bool, optional
是否保存结果图片
"""
if not path_list:
print("\n传入空列表, 无法绘图\n")
return
if not hasattr(path_list[0], "x") or not hasattr(path_list[0], "y"):
print("\n路径节点中没有坐标x或坐标y属性, 无法绘图\n")
return
x, y = [], []
for p in path_list:
x.append(p.x)
y.append(p.y)
fig, ax = plt.subplots()
map_ = cv2.imread(self.__map_path)
map_ = cv2.resize(map_, (self.width, self.high))
#img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) # R G B
#img = img[:, :, ::-1] # R G B
map_ = map_[::-1] # 画出来的鸡哥是反的, 需要转过来
ax.imshow(map_, extent=[0, self.width, 0, self.high]) # extent[x_min, x_max, y_min, y_max]
ax.plot(x, y, c = 'r', label='path', linewidth=2)
ax.scatter(x[0], y[0], c='c', marker='o', label='start', s=40, linewidth=2)
ax.scatter(x[-1], y[-1], c='c', marker='x', label='end', s=40, linewidth=2)
ax.invert_yaxis() # 反转y轴
ax.legend().set_draggable(True)
plt.show()
if save:
plt.savefig(self.__path_path)
# matlab计时器
def tic():
'''计时开始'''
if 'global_tic_time' not in globals():
global global_tic_time
global_tic_time = []
global_tic_time.append(time.time())
def toc(name='', *, CN=True, digit=6):
'''计时结束'''
if 'global_tic_time' not in globals() or not global_tic_time: # 未设置全局变量或全局变量为[]
print('未设置tic' if CN else 'tic not set')
return
name = name+' ' if (name and not CN) else name
if CN:
print('%s历时 %f 秒。\n' % (name, round(time.time() - global_tic_time.pop(), digit)))
else:
print('%sElapsed time is %f seconds.\n' % (name, round(time.time() - global_tic_time.pop(), digit)))
# 角度归一化
def limit_angle(x, mode=1):
"""
mode1 : (-inf, inf) -> (-π, π]
mode2 : (-inf, inf) -> [0, 2π)
"""
x = x - x//(2*math.pi) * 2*math.pi # any -> [0, 2π)
if mode == 1 and x > math.pi:
return x - 2*math.pi # [0, 2π) -> (-π, π]
return x