Skip to content

Commit

Permalink
Update requirements, add generated duet notebook, removed some tqdms
Browse files Browse the repository at this point in the history
  • Loading branch information
Mariel Pettee committed Apr 29, 2020
1 parent 4e0d115 commit 9c186f7
Show file tree
Hide file tree
Showing 5 changed files with 100,078 additions and 40,305 deletions.
21 changes: 12 additions & 9 deletions functions/seq_autoencoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,8 @@ def setup_gpus():
# tf.set_random_seed(1)
# np.random.seed(1)
# identify available GPU's
gpus = K.tensorflow_backend._get_available_gpus()
# gpus = tf.config.experimental.list_physical_devices('GPU')
gpus = K.tensorflow_backend._get_available_gpus() # works with TF 1 (?)
# gpus = tf.config.experimental.list_physical_devices('GPU') # works with TF 2
# allow dynamic GPU memory allocation
config = tf.compat.v1.ConfigProto()
config.gpu_options.allow_growth = True
Expand Down Expand Up @@ -283,10 +283,10 @@ def get_line_segments(seq, zcolor=None, cmap=None, cloud=False):
return xline

# put line segments on the given axis, with given colors
def put_lines(ax, segments, color=None, lw=2.5, alpha=None, cloud=False):
def put_lines(ax, segments, color=None, lw=2.5, alpha=None, cloud=False, cloud_alpha=None):
lines = []
### Main skeleton
for i in tqdm(range(len(skeleton_idxs)), desc="Skeleton lines"):
for i in range(len(skeleton_idxs)):
if isinstance(color, (list,tuple,np.ndarray)):
c = color[i]
else:
Expand All @@ -301,7 +301,7 @@ def put_lines(ax, segments, color=None, lw=2.5, alpha=None, cloud=False):

if cloud:
### Cloud of all-connected joints
for i in tqdm(range(len(skeleton_idxs),len(all_idxs)), desc="Cloud lines"):
for i in range(len(skeleton_idxs),len(all_idxs)):
if isinstance(color, (list,tuple,np.ndarray)):
c = color[i]
else:
Expand All @@ -310,7 +310,7 @@ def put_lines(ax, segments, color=None, lw=2.5, alpha=None, cloud=False):
np.linspace(segments[i,1,0],segments[i,1,1],2),
np.linspace(segments[i,2,0],segments[i,2,1],2),
color=c,
alpha=0.03,
alpha=cloud_alpha,
lw=lw)[0]
lines.append(l)
return lines
Expand All @@ -323,7 +323,7 @@ def put_lines(ax, segments, color=None, lw=2.5, alpha=None, cloud=False):
# `zcolor` may be an N-length array, where N is the number of vertices in seq, and will
# be used to color the vertices. Typically this is set to the avg. z-value of each vtx.
def animate_stick(seq, ghost=None, ghost_shift=0, figsize=None, zcolor=None, pointer=None, ax_lims=(-0.4,0.4), speed=45,
dot_size=20, dot_alpha=0.5, lw=2.5, cmap='cool_r', pointer_color='black', cloud=False):
dot_size=20, dot_alpha=0.5, lw=2.5, cmap='cool_r', pointer_color='black', cloud=False, birds_eye=False, cloud_alpha=0.035):
if zcolor is None:
zcolor = np.zeros(seq.shape[1])
fig = plt.figure(figsize=figsize)
Expand Down Expand Up @@ -353,17 +353,20 @@ def animate_stick(seq, ghost=None, ghost_shift=0, figsize=None, zcolor=None, poi
if ghost is not None:
pts_g = ax.scatter(ghost[0,:,0],ghost[0,:,1],ghost[0,:,2], c=ghost_color, s=dot_size, alpha=dot_alpha)

if birds_eye == True:
ax.view_init(elev=90., azim=-45.)

if ax_lims:
ax.set_xlim(*ax_lims)
ax.set_ylim(*ax_lims)
ax.set_zlim(0,ax_lims[1]-ax_lims[0])
plt.close(fig)
xline, colors = get_line_segments(seq, zcolor, cm)
lines = put_lines(ax, xline[0], colors, lw=lw, alpha=0.9, cloud=cloud)
lines = put_lines(ax, xline[0], colors, lw=lw, alpha=0.9, cloud=cloud, cloud_alpha=cloud_alpha)

if ghost is not None:
xline_g = get_line_segments(ghost)
lines_g = put_lines(ax, xline_g[0], ghost_color, lw=lw, alpha=1.0, cloud=cloud)
lines_g = put_lines(ax, xline_g[0], ghost_color, lw=lw, alpha=1.0, cloud=cloud, cloud_alpha=cloud_alpha)

if pointer is not None:
vR = 0.15
Expand Down
Loading

0 comments on commit 9c186f7

Please sign in to comment.