Commit

add more options to get and apply the gradient from/to the neural networks
qiauil committed Oct 29, 2024
1 parent abe10cd commit 9283e25
Showing 2 changed files with 81 additions and 17 deletions.
97 changes: 80 additions & 17 deletions conflictfree/utils.py
@@ -1,11 +1,11 @@
# usr/bin/python3
# -*- coding: UTF-8 -*-
from . import *
from warnings import warn
import numpy as np
from typing import Literal


def get_para_vector(network) -> torch.Tensor:
def get_para_vector(network: torch.nn.Module) -> torch.Tensor:
"""
Returns the parameter vector of the given network.
@@ -26,15 +26,22 @@ def get_para_vector(network) -> torch.Tensor:
return para_vec


def get_gradient_vector(network, jump_none=True) -> torch.Tensor:
def get_gradient_vector(
network: torch.nn.Module, none_grad_mode: Literal["raise", "zero", "skip"] = "skip"
) -> torch.Tensor:
"""
Returns the gradient vector of the given network.
Args:
network (torch.nn.Module): The network for which to compute the gradient vector.
jump_none (bool): Whether to skip the None gradients. default: True
This is useful when part of your neural network is frozen or not trainable.
You should set the same value to `apply_gradient_vector` when applying the gradient vector.
none_grad_mode (Literal['raise', 'zero', 'skip']): The mode to handle None gradients. default: 'skip'
- 'raise': Raise an error when the gradient of a parameter is None.
- 'zero': Replace the None gradient with a zero tensor.
- 'skip': Skip the None gradient.
A None gradient usually occurs when part of the network is not trainable (e.g., during fine-tuning)
or when a weight is not involved in computing the current loss (e.g., different parts of the network compute different losses).
If all of your losses are calculated using the same part of the network, you should set none_grad_mode to 'skip'.
If your losses are calculated using different parts of the network, you should set none_grad_mode to 'zero' to ensure the gradients have the same shape.
Returns:
torch.Tensor: The gradient vector of the network.
@@ -43,9 +50,16 @@ def get_gradient_vector(network, jump_none=True) -> torch.Tensor:
grad_vec = None
for par in network.parameters():
if par.grad is None:
if jump_none:
if none_grad_mode == "raise":
raise RuntimeError("None gradient detected.")
elif none_grad_mode == "zero":
viewed = torch.zeros_like(par.data.view(-1))
elif none_grad_mode == "skip":
continue
viewed = par.grad.data.view(-1)
else:
raise ValueError(f"Invalid none_grad_mode '{none_grad_mode}'.")
else:
viewed = par.grad.data.view(-1)
if grad_vec is None:
grad_vec = viewed
else:
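
To make the mode choice concrete, here is a minimal usage sketch (the two-head toy module below is illustrative, not part of the library) of how `none_grad_mode` affects the returned vector when only part of the network receives gradients:

```python
import torch
from conflictfree.utils import get_gradient_vector

# Toy model with two independent heads; a loss computed from one head leaves
# the other head's parameters with None gradients.
net = torch.nn.ModuleDict({
    "head_a": torch.nn.Linear(4, 1),
    "head_b": torch.nn.Linear(4, 1),
})
x = torch.randn(8, 4)

net.zero_grad(set_to_none=True)
net["head_a"](x).mean().backward()  # only head_a receives gradients

g_skip = get_gradient_vector(net, none_grad_mode="skip")  # only head_a's 5 entries
g_zero = get_gradient_vector(net, none_grad_mode="zero")  # all 10 entries, zeros for head_b
print(g_skip.shape, g_zero.shape)  # torch.Size([5]) torch.Size([10])
```

Vectors collected with 'skip' from losses that touch different parts of the network have different lengths and cannot be combined, which is why the docstring recommends 'zero' in that case.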
@@ -54,27 +68,74 @@ def apply_gradient_vector(


def apply_gradient_vector(
network: torch.nn.Module, grad_vec: torch.Tensor, jump_none=True
network: torch.nn.Module,
grad_vec: torch.Tensor,
none_grad_mode: Literal["zero", "skip"] = "skip",
zero_grad_mode: Literal["skip", "pad_zero", "pad_value"] = "pad_value",
) -> None:
"""
Applies a gradient vector to the network's parameters.
This function requires that the network already contains gradient information in order to apply the gradient vector.
If your network does not contain gradient information, consider using the `apply_gradient_vector_para_based` function instead.
Args:
network (torch.nn.Module): The network to apply the gradient vector to.
grad_vec (torch.Tensor): The gradient vector to apply.
jump_none (bool): Whether to skip the None gradients. default: True
This is useful when part of your neural network is frozen or not trainable.
You should set the same value to `get_gradient_vector` when applying the gradient vector.
none_grad_mode (Literal['zero', 'skip']): The mode to handle None gradients.
You should set this parameter to the same value as the one used in `get_gradient_vector`.
zero_grad_mode (Literal['skip', 'pad_zero', 'pad_value']): How to set the gradient of a parameter whose gradient was None, used when `none_grad_mode` is 'zero'. default: 'pad_value'
- 'skip': Leave the gradient as None.
- 'pad_zero': Replace the None gradient with a zero tensor.
- 'pad_value': Replace the None gradient with the corresponding values in `grad_vec`.
If you set `none_grad_mode` to 'zero', zeros were padded into `grad_vec` for every parameter whose gradient was None when the gradient vector was collected.
After the gradient operations you apply, the entries of `grad_vec` corresponding to those previously None gradients may no longer be zero.
Thus, you need to decide whether to keep those gradients as None ('skip'), set them to zero ('pad_zero'), or take the values from `grad_vec` ('pad_value').
If you are not sure what you are doing, it is safer to set it to 'pad_value'.
"""
if none_grad_mode == "zero" and zero_grad_mode == "pad_value":
apply_gradient_vector_para_based(network, grad_vec)
return  # the vector has already been applied parameter by parameter
with torch.no_grad():
start = 0
for par in network.parameters():
if par.grad is None:
if jump_none:
if none_grad_mode == "skip":
continue
end = start + par.grad.data.view(-1).shape[0]
par.grad.data = grad_vec[start:end].view(par.grad.data.shape)
elif none_grad_mode == "zero":
start = start + par.data.view(-1).shape[0]
if zero_grad_mode == "pad_zero":
par.grad = torch.zeros_like(par.data)
elif zero_grad_mode == "skip":
continue
else:
raise ValueError(f"Invalid zero_grad_mode '{zero_grad_mode}'.")
else:
raise ValueError(f"Invalid none_grad_mode '{none_grad_mode}'.")
else:
end = start + par.data.view(-1).shape[0]
par.grad.data = grad_vec[start:end].view(par.data.shape)
start = end
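
As a usage sketch of the round trip (the plain averaging below is a stand-in for whatever gradient operation you actually perform, e.g., a conflict-free combination), the modes passed here should mirror the ones used with `get_gradient_vector`:

```python
import torch
from conflictfree.utils import get_gradient_vector, apply_gradient_vector

net = torch.nn.Sequential(torch.nn.Linear(4, 8), torch.nn.ReLU(), torch.nn.Linear(8, 1))
opt = torch.optim.SGD(net.parameters(), lr=1e-2)
x, target = torch.randn(16, 4), torch.randn(16, 1)

grads = []
for loss_fn in (torch.nn.MSELoss(), torch.nn.L1Loss()):
    net.zero_grad(set_to_none=True)
    loss_fn(net(x), target).backward()
    # 'zero' keeps every vector at the full parameter length, so the vectors
    # can be stacked even if some loss skipped part of the network.
    grads.append(get_gradient_vector(net, none_grad_mode="zero"))

combined = torch.stack(grads).mean(dim=0)  # placeholder for the real update rule

# Write the combined vector back with matching modes, then take the step.
apply_gradient_vector(net, combined, none_grad_mode="zero", zero_grad_mode="pad_value")
opt.step()
```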


def apply_gradient_vector_para_based(
network: torch.nn.Module,
grad_vec: torch.Tensor,
) -> None:
"""
Applies a gradient vector to the network's parameters.
Please only use this function when you are sure that the length of `grad_vec` matches the total number of parameters in your network.
This is the case when you use `get_gradient_vector` with `none_grad_mode` set to 'zero',
or when `none_grad_mode` is 'skip' but all of the parameters in your network are involved in the loss calculation.
Args:
network (torch.nn.Module): The network to apply the gradient vector to.
grad_vec (torch.Tensor): The gradient vector to apply.
"""
with torch.no_grad():
start = 0
for par in network.parameters():
end = start + par.data.view(-1).shape[0]
par.grad = grad_vec[start:end].view(par.data.shape)
start = end
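
A small sketch of the case this variant is meant for (the Linear layer and the all-ones vector are placeholders for a real model and a gradient obtained elsewhere): the network has no gradient information yet, so only the length of the vector matters.

```python
import torch
from conflictfree.utils import get_para_vector, apply_gradient_vector_para_based

net = torch.nn.Linear(3, 2)  # 8 parameters in total (6 weights + 2 biases)

# No backward() has been run, so every parameter's grad is still None;
# with the default none_grad_mode='skip', apply_gradient_vector would
# silently skip every parameter here.
n_params = get_para_vector(net).numel()
grad_vec = torch.ones(n_params)  # must match the total parameter count

apply_gradient_vector_para_based(net, grad_vec)
assert all(p.grad is not None for p in net.parameters())
```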


@@ -109,7 +170,7 @@ def get_cos_similarity(vector1: torch.Tensor, vector2: torch.Tensor) -> torch.Te
return torch.dot(vector1, vector2) / vector1.norm() / vector2.norm()


def unit_vector(vector: torch.Tensor, warn_zero=False) -> torch.Tensor:
def unit_vector(vector: torch.Tensor, warn_zero: bool = False) -> torch.Tensor:
"""
Compute the unit vector of a given tensor.
@@ -259,7 +320,9 @@ def select(
Returns:
Tuple[Sequence,Union[float,Sequence]]: A tuple containing the indexes of the selected slice and the selected slice.
"""
assert n <= len(source_sequence), "n can not be larger than or equal to the length of the source sequence"
assert n <= len(
source_sequence
), "n cannot be larger than the length of the source sequence"
indexes = np.random.choice(len(source_sequence), n, replace=False)
if len(indexes) == 1:
return indexes, source_sequence[indexes[0]]
1 change: 1 addition & 0 deletions docs/api/utils.md
@@ -5,6 +5,7 @@ The `utils` module contains utility functions for the ConFIG algorithm.
::: conflictfree.utils.apply_para_vector
::: conflictfree.utils.get_gradient_vector
::: conflictfree.utils.apply_gradient_vector
::: conflictfree.utils.apply_gradient_vector_para_based

## Math Utility Functions
::: conflictfree.utils.get_cos_similarity
