Skip to content

Commit

Permalink
added base class for sparsification modules
Browse files Browse the repository at this point in the history
  • Loading branch information
Jan Buethe committed Jan 11, 2024
1 parent 7afe3ae commit eb22c29
Show file tree
Hide file tree
Showing 6 changed files with 62 additions and 110 deletions.
2 changes: 2 additions & 0 deletions dnn/torch/dnntools/dnntools/sparsification/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@

import torch

debug=False

def sparsify_matrix(matrix : torch.tensor, density : float, block_size, keep_diagonal : bool=False, return_mask : bool=False):
""" sparsifies matrix with specified block size
Expand Down
42 changes: 15 additions & 27 deletions dnn/torch/dnntools/dnntools/sparsification/conv1d_sparsifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,11 @@

import torch

from .common import sparsify_matrix
from .base_sparsifier import BaseSparsifier
from .common import sparsify_matrix, debug


class Conv1dSparsifier:
class Conv1dSparsifier(BaseSparsifier):
def __init__(self, task_list, start, stop, interval, exponent=3):
""" Sparsifier for torch.nn.Conv1d modules
Expand Down Expand Up @@ -68,26 +69,21 @@ def __init__(self, task_list, start, stop, interval, exponent=3):
>>> for i in range(100):
... sparsifier.step()
"""
# just copying parameters...
self.start = start
self.stop = stop
self.interval = interval
self.exponent = exponent
self.task_list = task_list
super().__init__(task_list, start, stop, interval, exponent=exponent)

# ... and setting counter to 0
self.step_counter = 0
self.last_mask = None

self.last_masks = {key : None for key in ['W_ir', 'W_in', 'W_iz', 'W_hr', 'W_hn', 'W_hz']}

def step(self, verbose=False):
def sparsify(self, alpha, verbose=False):
""" carries out sparsification step
Call this function after optimizer.step in your
training loop.
Parameters:
----------
alpha : float
density interpolation parameter (1: dense, 0: target density)
verbose : bool
if true, densities are printed out
Expand All @@ -96,20 +92,6 @@ def step(self, verbose=False):
None
"""
# compute current interpolation factor
self.step_counter += 1

if self.step_counter < self.start:
return
elif self.step_counter < self.stop:
# update only every self.interval-th interval
if self.step_counter % self.interval:
return

alpha = ((self.stop - self.step_counter) / (self.stop - self.start)) ** self.exponent
else:
alpha = 0


with torch.no_grad():
for conv, params in self.task_list:
Expand All @@ -122,10 +104,16 @@ def step(self, verbose=False):
w = weight.permute(0, 2, 1).flatten(1)
target_density, block_size = params
density = alpha + (1 - alpha) * target_density
w = sparsify_matrix(w, density, block_size)
w, new_mask = sparsify_matrix(w, density, block_size, return_mask=True)
w = w.reshape(i, k, o).permute(0, 2, 1)
weight[:] = w

if self.last_mask is not None:
if not torch.all(self.last_mask * new_mask == new_mask) and debug:
print("weight resurrection in conv.weight")

self.last_mask = new_mask

if verbose:
print(f"conv1d_sparsier[{self.step_counter}]: {density=}")

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,12 @@

import torch

from .common import sparsify_matrix

from .base_sparsifier import BaseSparsifier
from .common import sparsify_matrix, debug

class ConvTranspose1dSparsifier:

class ConvTranspose1dSparsifier(BaseSparsifier):
def __init__(self, task_list, start, stop, interval, exponent=3):
""" Sparsifier for torch.nn.ConvTranspose1d modules
Expand Down Expand Up @@ -68,24 +70,21 @@ def __init__(self, task_list, start, stop, interval, exponent=3):
>>> for i in range(100):
... sparsifier.step()
"""
# just copying parameters...
self.start = start
self.stop = stop
self.interval = interval
self.exponent = exponent
self.task_list = task_list

# ... and setting counter to 0
self.step_counter = 0
super().__init__(task_list, start, stop, interval, exponent=exponent)

self.last_mask = None

def step(self, verbose=False):
def sparsify(self, alpha, verbose=False):
""" carries out sparsification step
Call this function after optimizer.step in your
training loop.
Parameters:
----------
alpha : float
density interpolation parameter (1: dense, 0: target density)
verbose : bool
if true, densities are printed out
Expand All @@ -94,20 +93,6 @@ def step(self, verbose=False):
None
"""
# compute current interpolation factor
self.step_counter += 1

if self.step_counter < self.start:
return
elif self.step_counter < self.stop:
# update only every self.interval-th interval
if self.step_counter % self.interval:
return

alpha = ((self.stop - self.step_counter) / (self.stop - self.start)) ** self.exponent
else:
alpha = 0


with torch.no_grad():
for conv, params in self.task_list:
Expand All @@ -120,10 +105,16 @@ def step(self, verbose=False):
w = weight.permute(2, 1, 0).reshape(k * o, i)
target_density, block_size = params
density = alpha + (1 - alpha) * target_density
w = sparsify_matrix(w, density, block_size)
w, new_mask = sparsify_matrix(w, density, block_size, return_mask=True)
w = w.reshape(k, o, i).permute(2, 1, 0)
weight[:] = w

if self.last_mask is not None:
if not torch.all(self.last_mask * new_mask == new_mask) and debug:
print("weight resurrection in conv.weight")

self.last_mask = new_mask

if verbose:
print(f"convtrans1d_sparsier[{self.step_counter}]: {density=}")

Expand Down
43 changes: 13 additions & 30 deletions dnn/torch/dnntools/dnntools/sparsification/gru_sparsifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,11 @@

import torch

from .common import sparsify_matrix
from .base_sparsifier import BaseSparsifier
from .common import sparsify_matrix, debug


class GRUSparsifier:
class GRUSparsifier(BaseSparsifier):
def __init__(self, task_list, start, stop, interval, exponent=3):
""" Sparsifier for torch.nn.GRUs
Expand Down Expand Up @@ -78,26 +79,20 @@ def __init__(self, task_list, start, stop, interval, exponent=3):
>>> for i in range(100):
... sparsifier.step()
"""
# just copying parameters...
self.start = start
self.stop = stop
self.interval = interval
self.exponent = exponent
self.task_list = task_list

# ... and setting counter to 0
self.step_counter = 0
super().__init__(task_list, start, stop, interval, exponent=exponent)

self.last_masks = {key : None for key in ['W_ir', 'W_in', 'W_iz', 'W_hr', 'W_hn', 'W_hz']}

def step(self, verbose=False):
def sparsify(self, alpha, verbose=False):
""" carries out sparsification step
Call this function after optimizer.step in your
training loop.
Parameters:
----------
alpha : float
density interpolation parameter (1: dense, 0: target density)
verbose : bool
if true, densities are printed out
Expand All @@ -106,20 +101,6 @@ def step(self, verbose=False):
None
"""
# compute current interpolation factor
self.step_counter += 1

if self.step_counter < self.start:
return
elif self.step_counter < self.stop:
# update only every self.interval-th interval
if self.step_counter % self.interval:
return

alpha = ((self.stop - self.step_counter) / (self.stop - self.start)) ** self.exponent
else:
alpha = 0


with torch.no_grad():
for gru, params in self.task_list:
Expand All @@ -145,8 +126,8 @@ def step(self, verbose=False):
)

if type(self.last_masks[key]) != type(None):
if not torch.all(self.last_masks[key] == new_mask) and self.step_counter > self.stop:
print(f"sparsification mask {key} changed for gru {gru}")
if not torch.all(self.last_masks[key] * new_mask == new_mask) and debug:
print("weight resurrection in weight_ih_l0_v")

self.last_masks[key] = new_mask

Expand All @@ -169,8 +150,8 @@ def step(self, verbose=False):
)

if type(self.last_masks[key]) != type(None):
if not torch.all(self.last_masks[key] == new_mask) and self.step_counter > self.stop:
print(f"sparsification mask {key} changed for gru {gru}")
if not torch.all(self.last_masks[key] * new_mask == new_mask) and debug:
print("weight resurrection in weight_hh_l0_v")

self.last_masks[key] = new_mask

Expand All @@ -193,3 +174,5 @@ def step(self, verbose=False):

for i in range(100):
sparsifier.step(verbose=True)

print(gru.weight_hh_l0)
40 changes: 14 additions & 26 deletions dnn/torch/dnntools/dnntools/sparsification/linear_sparsifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,11 @@

import torch

from .base_sparsifier import BaseSparsifier
from .common import sparsify_matrix, debug


class LinearSparsifier:
class LinearSparsifier(BaseSparsifier):
def __init__(self, task_list, start, stop, interval, exponent=3):
""" Sparsifier for torch.nn.Linear modules
Expand Down Expand Up @@ -68,26 +69,21 @@ def __init__(self, task_list, start, stop, interval, exponent=3):
>>> for i in range(100):
... sparsifier.step()
"""
# just copying parameters...
self.start = start
self.stop = stop
self.interval = interval
self.exponent = exponent
self.task_list = task_list

# ... and setting counter to 0
self.step_counter = 0
super().__init__(task_list, start, stop, interval, exponent=exponent)

self.last_masks = {key : None for key in ['W_ir', 'W_in', 'W_iz', 'W_hr', 'W_hn', 'W_hz']}
self.last_mask = None

def step(self, verbose=False):
def sparsify(self, alpha, verbose=False):
""" carries out sparsification step
Call this function after optimizer.step in your
training loop.
Parameters:
----------
alpha : float
density interpolation parameter (1: dense, 0: target density)
verbose : bool
if true, densities are printed out
Expand All @@ -96,20 +92,6 @@ def step(self, verbose=False):
None
"""
# compute current interpolation factor
self.step_counter += 1

if self.step_counter < self.start:
return
elif self.step_counter < self.stop:
# update only every self.interval-th interval
if self.step_counter % self.interval:
return

alpha = ((self.stop - self.step_counter) / (self.stop - self.start)) ** self.exponent
else:
alpha = 0


with torch.no_grad():
for linear, params in self.task_list:
Expand All @@ -119,7 +101,13 @@ def step(self, verbose=False):
weight = linear.weight
target_density, block_size = params
density = alpha + (1 - alpha) * target_density
weight[:] = sparsify_matrix(weight, density, block_size)
weight[:], new_mask = sparsify_matrix(weight, density, block_size, return_mask=True)

if self.last_mask is not None:
if not torch.all(self.last_mask * new_mask == new_mask) and debug:
print("weight resurrection in linear.weight")

self.last_mask = new_mask

if verbose:
print(f"linear_sparsifier[{self.step_counter}]: {density=}")
Expand Down
2 changes: 1 addition & 1 deletion dnn/torch/osce/train_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@
reply = input('continue? (y/n): ')

if reply == 'n':
os._exit()
os._exit(0)
else:
os.makedirs(args.output, exist_ok=True)

Expand Down

0 comments on commit eb22c29

Please sign in to comment.