# TkTorchWindow.py (forked from harskish/ganspace)
# Copyright 2020 Erik Härkönen. All rights reserved.
# This file is licensed to you under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. You may obtain a copy
# of the License at http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software distributed under
# the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR REPRESENTATIONS
# OF ANY KIND, either express or implied. See the License for the specific language
# governing permissions and limitations under the License.
import tkinter as tk
import numpy as np
import time
from contextlib import contextmanager
import pycuda.driver
from pycuda.gl import graphics_map_flags
from glumpy import gloo, gl
from pyopengltk import OpenGLFrame
import torch
from torch.autograd import Variable

# TkInter widget that can draw torch tensors directly from GPU memory
@contextmanager
def cuda_activate(img):
    """Context manager simplifying use of pycuda.gl.RegisteredImage"""
    mapping = img.map()
    yield mapping.array(0,0)
    mapping.unmap()
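
# Hypothetical usage sketch of cuda_activate (the real call site is in
# TorchImageView.draw below): map the registered image, copy into it, unmap.
#
#     with cuda_activate(cuda_buffer) as ary:
#         cpy = pycuda.driver.Memcpy2D()
#         cpy.set_dst_array(ary)  # `ary` is the mapped CUDA array backing the texture
#         ...                     # set source pointer and pitches, then cpy(aligned=False)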

def create_shared_texture(w, h, c=4,
                          map_flags=graphics_map_flags.WRITE_DISCARD,
                          dtype=np.uint8):
    """Create and return a Texture2D with gloo and pycuda views."""
    tex = np.zeros((h,w,c), dtype).view(gloo.Texture2D)
    tex.activate()  # force gloo to create on GPU
    tex.deactivate()
    cuda_buffer = pycuda.gl.RegisteredImage(
        int(tex.handle), tex.target, map_flags)
    return tex, cuda_buffer

# Shape batch as square if possible
def get_grid_dims(B):
    S = int(B**0.5 + 0.5)
    while B % S != 0:
        S -= 1
    return (B // S, S)
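# For instance (values added here for illustration): get_grid_dims(12) returns
# (4, 3), i.e. a 4-column by 3-row grid, while a prime batch size such as 7
# falls back to a single row of 7 tiles, (7, 1).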

def create_gl_texture(tensor_shape):
    if len(tensor_shape) != 4:
        raise RuntimeError('Please provide a tensor of shape NCHW')
    N, C, H, W = tensor_shape
    cols, rows = get_grid_dims(N)
    tex, cuda_buffer = create_shared_texture(W*cols, H*rows, 4)
    return tex, cuda_buffer

# Create window with OpenGL context
class TorchImageView(OpenGLFrame):
    def __init__(self, root=None, show_fps=True, **kwargs):
        self.root = root or tk.Tk()
        self.width = kwargs.get('width', 512)
        self.height = kwargs.get('height', 512)
        self.show_fps = show_fps
        self.pycuda_initialized = False
        self.animate = 0  # disable internal main loop
        OpenGLFrame.__init__(self, self.root, **kwargs)  # parent to the (possibly newly created) root

    # Called by pyopengltk.BaseOpenGLFrame
    # when the frame goes onto the screen
    def initgl(self):
        """Initialize GL state when the frame is created"""
        if not self.pycuda_initialized:
            self.setup_gl(self.width, self.height)
            self.pycuda_initialized = True
        gl.glViewport(0, 0, self.width, self.height)
        gl.glClearColor(0.0, 0.0, 0.0, 0.0)
        self.dt_history = [1/60]  # seed with one nominal 60 Hz frame time (seconds)
        self.t0 = time.time()
        self.t_last = self.t0
        self.nframes = 0

    def setup_gl(self, width, height):
        # setup pycuda and torch
        import pycuda.gl.autoinit
        import pycuda.gl
        assert torch.cuda.is_available(), "PyTorch: CUDA is not available"
        print('Using GPU {}'.format(torch.cuda.current_device()))

        # Create tensor to be shared between GL and CUDA
        # Always overwritten so no sharing is necessary
        dummy = torch.cuda.FloatTensor((1))
        dummy.uniform_()
        dummy = Variable(dummy)

        # Create a buffer with pycuda and gloo views (NCHW shape, hence height before width)
        self.tex, self.cuda_buffer = create_gl_texture((1, 3, height, width))

        # Create a shader program to draw the texture to the screen
        vertex = """
            uniform float scale;
            attribute vec2 position;
            attribute vec2 texcoord;
            varying vec2 v_texcoord;
            void main()
            {
                v_texcoord = texcoord;
                gl_Position = vec4(scale*position, 0.0, 1.0);
            } """
        fragment = """
            uniform sampler2D tex;
            varying vec2 v_texcoord;
            void main()
            {
                gl_FragColor = texture2D(tex, v_texcoord);
            } """

        # Build the program and corresponding buffers (with 4 vertices)
        self.screen = gloo.Program(vertex, fragment, count=4)

        # NDC coordinates:      Texcoords:      Vertex order,
        # (-1, +1)   (+1, +1)   (0,0)   (1,0)   triangle strip:
        #    +---------+           +----+          1----3
        #    |   NDC   |           |    |          |  / |
        #    |  SPACE  |           |    |          | /  |
        #    +---------+           +----+          2----4
        # (-1, -1)   (+1, -1)   (0,1)   (1,1)

        # Upload data to GPU
        self.screen['position'] = [(-1,+1), (-1,-1), (+1,+1), (+1,-1)]
        self.screen['texcoord'] = [(0,0), (0,1), (1,0), (1,1)]
        self.screen['scale'] = 1.0
        self.screen['tex'] = self.tex

    # Don't call directly, use update() instead
    def redraw(self):
        t_now = time.time()
        dt = t_now - self.t_last
        self.t_last = t_now
        self.dt_history = ([dt] + self.dt_history)[:50]
        dt_mean = sum(self.dt_history) / len(self.dt_history)
        if self.show_fps and self.nframes % 60 == 0:
            self.master.title('FPS: {:.0f}'.format(1 / dt_mean))
        self.nframes += 1  # count frames so the FPS title only refreshes periodically

    def draw(self, img):
        assert len(img.shape) == 4, "Please provide an NCHW image tensor"
        assert img.device.type == "cuda", "Please provide a CUDA tensor"

        if img.dtype.is_floating_point:
            img = (255*img).byte()

        # Tile batched images into a single grid image
        N, C, H, W = img.shape
        if N > 1:
            cols, rows = get_grid_dims(N)
            img = img.reshape(cols, rows, C, H, W)
            img = img.permute(2, 1, 3, 0, 4)  # [C, rows, H, cols, W]
            img = img.reshape(1, C, rows*H, cols*W)

        tensor = img.squeeze().permute(1, 2, 0).data  # CHW => HWC
        if C == 3:
            tensor = torch.cat((tensor, tensor[:,:,0:1]), 2)  # add an alpha channel
            tensor[:,:,3] = 255  # fully opaque (uint8)
        tensor = tensor.contiguous()

        # Recreate the shared texture if the incoming image size has changed
        tex_h, tex_w, _ = self.tex.shape
        tensor_h, tensor_w, _ = tensor.shape
        if (tex_h, tex_w) != (tensor_h, tensor_w):
            print(f'Resizing texture to {tensor_w}*{tensor_h}')
            self.tex, self.cuda_buffer = create_gl_texture((N, C, H, W))  # original NCHW shape
            self.screen['tex'] = self.tex

        # Copy from the torch tensor into the shared texture
        assert self.tex.nbytes == tensor.numel()*tensor.element_size(), "Tensor and texture shape mismatch!"
        with cuda_activate(self.cuda_buffer) as ary:
            cpy = pycuda.driver.Memcpy2D()
            cpy.set_src_device(tensor.data_ptr())
            cpy.set_dst_array(ary)
            cpy.width_in_bytes = cpy.src_pitch = cpy.dst_pitch = self.tex.nbytes // tensor_h
            cpy.height = tensor_h
            cpy(aligned=False)
            torch.cuda.synchronize()

        # Draw the textured quad to the back buffer
        self.screen.draw(gl.GL_TRIANGLE_STRIP)

    def update(self):
        self.update_idletasks()
        self.tkMakeCurrent()
        self.redraw()
        self.tkSwapBuffers()

# USAGE:
# root = tk.Tk()
# iv = TorchImageView(root, width=512, height=512)
# iv.pack(fill='both', expand=True)
# while True:
#     iv.draw(nchw_tensor)
#     root.update()
#     iv.update()
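
# A minimal runnable sketch of the usage pattern above, guarded so importing this
# module stays side-effect free. It assumes a CUDA-capable GPU and simply animates
# random noise; the window size and tensor shape are illustrative, not part of the
# original API. Closing the window typically ends the loop with a tk.TclError.
if __name__ == '__main__':
    root = tk.Tk()
    iv = TorchImageView(root, width=512, height=512)
    iv.pack(fill='both', expand=True)
    root.update()  # let the frame map so initgl()/setup_gl() run before the first draw
    while True:
        noise = torch.rand(1, 3, 512, 512, device='cuda')  # NCHW float in [0, 1]
        iv.draw(noise)  # upload to the shared texture and draw to the back buffer
        root.update()   # process Tk events
        iv.update()     # FPS bookkeeping + swap buffers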