Skip to content

Commit

Permalink
add jump none option to get/apply gradient functions
Browse files Browse the repository at this point in the history
  • Loading branch information
qiauil committed Sep 20, 2024
1 parent 9050c3f commit 36b5752
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 3 deletions.
16 changes: 14 additions & 2 deletions conflictfree/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,38 +24,50 @@ def get_para_vector(network) -> torch.Tensor:
para_vec = torch.cat((para_vec, viewed))
return para_vec

def get_gradient_vector(network)->torch.Tensor:
def get_gradient_vector(network, jump_none=True) -> torch.Tensor:
    """
    Returns the gradient vector of the given network.

    Args:
        network (torch.nn.Module): The network for which to compute the gradient vector.
        jump_none (bool): Whether to skip parameters whose gradient is None. default: True
            This is useful when part of your neural network is frozen or not trainable.
            You should set the same value to `apply_gradient_vector` when applying the gradient vector.

    Returns:
        torch.Tensor: The gradient vector of the network, i.e. all parameter
            gradients flattened and concatenated in `network.parameters()` order.
            Returns None if no parameter has a gradient.

    Raises:
        ValueError: If a parameter's gradient is None and `jump_none` is False.
    """
    with torch.no_grad():
        flat_grads = []
        for par in network.parameters():
            if par.grad is None:
                if jump_none:
                    continue
                # Fail with a clear message instead of the opaque AttributeError
                # that `par.grad.data` would otherwise raise on a None gradient.
                raise ValueError(
                    "The gradient of a parameter is None. "
                    "Run backward() first, or set jump_none=True to skip such parameters."
                )
            flat_grads.append(par.grad.data.view(-1))
        # Single cat at the end avoids quadratic re-concatenation; returning
        # None when no gradients exist preserves the original behavior.
        return torch.cat(flat_grads) if flat_grads else None

def apply_gradient_vector(network:torch.nn.Module,grad_vec:torch.Tensor)->None:
def apply_gradient_vector(network: torch.nn.Module, grad_vec: torch.Tensor, jump_none=True) -> None:
    """
    Applies a gradient vector to the network's parameters.

    Args:
        network (torch.nn.Module): The network to apply the gradient vector to.
        grad_vec (torch.Tensor): The gradient vector to apply, laid out as
            produced by `get_gradient_vector` (flattened parameter gradients
            concatenated in `network.parameters()` order).
        jump_none (bool): Whether to skip parameters whose gradient is None. default: True
            This is useful when part of your neural network is frozen or not trainable.
            You should set the same value to `get_gradient_vector` when applying the gradient vector.

    Raises:
        ValueError: If a parameter's gradient is None and `jump_none` is False.
    """
    with torch.no_grad():
        start = 0
        for par in network.parameters():
            if par.grad is None:
                if jump_none:
                    # Skipped parameters consume no slice of grad_vec, which
                    # matches get_gradient_vector(network, jump_none=True).
                    continue
                # Fail with a clear message instead of the opaque AttributeError
                # that `par.grad.data` would otherwise raise on a None gradient.
                raise ValueError(
                    "The gradient of a parameter is None. "
                    "Run backward() first, or set jump_none=True to skip such parameters."
                )
            end = start + par.grad.data.numel()
            par.grad.data = grad_vec[start:end].view(par.grad.data.shape)
            start = end
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ def get_install_requires():

setuptools.setup(
name="conflictfree",
version="0.1.4",
version="0.1.5",
author="Qiang Liu, Mengyu Chu, Nils Thuerey",
author_email="[email protected]",
description="Official implementation of Conflict-free Inverse Gradients method",
Expand Down

0 comments on commit 36b5752

Please sign in to comment.