Implementation question in Tensorflow #26
Hi, I have the same issue with similar code:

```python
import tensorflow as tf


class WeightStandardization(tf.keras.constraints.Constraint):
    def __call__(self, w):
        # Standardize each output channel over (kh, kw, in_channels).
        mean = tf.math.reduce_mean(w, axis=[0, 1, 2], keepdims=True)
        std = tf.math.reduce_std(w, axis=[0, 1, 2], keepdims=True)
        return (w - mean) / tf.maximum(std, tf.keras.backend.epsilon())
```

@jbp70, any progress, please?
Hi @markub3327. I tried using this code a long while ago, but I couldn't figure out how to implement it properly and ended up moving on. If you figure it out, do let me know!
@jbp70 The standardization must happen before the gradients are computed. I believe it has to be done in the forward pass (during inference), followed by a normalization. These constraints instead apply it after the gradients have been calculated. The kernel needs to be substituted with its standardized version, that version used for the prediction, and only after that should the gradients be applied.
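One way to follow this suggestion, sketched under the assumption of TensorFlow 2.x and NHWC inputs: keep the raw kernel as the trainable variable and standardize it inside `call()`, so the standardization is part of the forward pass and the gradient reaches the raw kernel through it. `WSConv2D` and its `eps` parameter are hypothetical names, not part of Keras; the sketch only handles square kernels, stride 1, and no activation, and it adds `eps` inside the square root rather than clamping the standard deviation as the constraints in this thread do.

```python
import numpy as np
import tensorflow as tf


class WSConv2D(tf.keras.layers.Layer):
    """Hypothetical Conv2D variant that standardizes its kernel in the forward pass.

    Assumes NHWC inputs, a square kernel, stride 1 and no activation.
    """

    def __init__(self, filters, kernel_size, padding="SAME", eps=1e-5, **kwargs):
        super().__init__(**kwargs)
        self.filters = filters
        self.kernel_size = kernel_size
        self.padding = padding
        self.eps = eps

    def build(self, input_shape):
        in_channels = int(input_shape[-1])
        # The raw, unstandardized kernel is the variable the optimizer updates.
        self.kernel = self.add_weight(
            name="kernel",
            shape=(self.kernel_size, self.kernel_size, in_channels, self.filters),
            initializer="he_normal",
            trainable=True,
        )
        self.bias = self.add_weight(
            name="bias", shape=(self.filters,), initializer="zeros", trainable=True
        )

    def call(self, inputs):
        # Per-output-channel statistics over (kh, kw, in_channels), recomputed on
        # every call so that gradients flow through the standardization.
        mean, var = tf.nn.moments(self.kernel, axes=[0, 1, 2], keepdims=True)
        kernel = (self.kernel - mean) / tf.sqrt(var + self.eps)
        y = tf.nn.conv2d(inputs, kernel, strides=1, padding=self.padding)
        return y + self.bias


layer = WSConv2D(filters=4, kernel_size=3)
x = tf.random.normal([2, 8, 8, 3])
with tf.GradientTape() as tape:
    y = layer(x)
    loss = tf.reduce_sum(tf.square(y))
grad = tape.gradient(loss, layer.kernel).numpy()

print(y.shape)  # (2, 8, 8, 4)
# Shifting a channel's raw kernel by a constant leaves the output unchanged, so
# the gradient sums to (numerically) zero over each output channel's slice.
print(np.abs(grad.sum(axis=(0, 1, 2))).max())
```

A kernel constraint, by contrast, never changes the gradient; it only rewrites the variable after the optimizer step.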
@jbp70 What about this implementation?

```python
import tensorflow as tf


class WeightStandardization(tf.keras.constraints.Constraint):
    def __call__(self, w):
        # Per-output-channel mean and variance over (kh, kw, in_channels).
        mean, variance = tf.nn.moments(w, axes=[0, 1, 2], keepdims=True)
        std = tf.sqrt(variance)
        epsilon = tf.keras.backend.epsilon()
        return (w - mean) / tf.maximum(std, epsilon)
```
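Both constraints in this thread reduce over axes `[0, 1, 2]`, which matches a `Conv2D` kernel of shape `(kh, kw, in, out)`. For a `Conv3D` kernel `(kd, kh, kw, in, out)`, such as those in the 3D U-Net mentioned in this thread, the same axes would compute statistics per (input, output) channel pair rather than per output channel. The sketch below reduces over every axis except the last, assuming the Keras convention of output channels on the last axis; `RankAgnosticWeightStandardization` is a hypothetical name. It is still a constraint, so it is applied after each update rather than in the forward pass.

```python
import tensorflow as tf


class RankAgnosticWeightStandardization(tf.keras.constraints.Constraint):
    """Standardize a conv kernel per output channel, for any kernel rank.

    Assumes output channels on the last axis, e.g. (kh, kw, in, out) for
    Conv2D or (kd, kh, kw, in, out) for Conv3D.
    """

    def __call__(self, w):
        axes = list(range(len(w.shape) - 1))  # every axis except output channels
        mean, variance = tf.nn.moments(w, axes=axes, keepdims=True)
        std = tf.sqrt(variance)
        return (w - mean) / tf.maximum(std, tf.keras.backend.epsilon())


# Usage with a 3D convolution.
conv3d = tf.keras.layers.Conv3D(
    16, 3, padding="same", kernel_constraint=RankAgnosticWeightStandardization()
)

# Check the constraint directly on a random Conv3D-shaped kernel.
w3d = tf.random.normal([3, 3, 3, 8, 16], mean=1.0, stddev=3.0)
ws3d = RankAgnosticWeightStandardization()(w3d)
mean, variance = tf.nn.moments(ws3d, axes=[0, 1, 2, 3])
print(mean.shape, variance.shape)  # (16,) (16,)
```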
Hi! I am currently trying to implement your code in TensorFlow 2.2, but I am running into errors. I get a strange error when I place the standardization directly before the convolution call (as you suggested in post #11).
For reference, the error is: `TypeError: An op outside of the function building code is being passed a "Graph" tensor.`
I wasn't able to figure out how to fix this, so I decided to write a custom kernel constraint and use it for the weight standardization.
```python
class CustomWeightStandardization(tf.keras.constraints.Constraint):
```
As far as I understand, this should be equivalent to how you have implemented weight standardization. I am training an image segmentation model with a 3D U-Net architecture, trained from scratch on my own medical imaging dataset. Unfortunately, enabling this kernel constraint makes the model perform worse than training without it. Do you have any ideas on how to fix this?
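Whether a constraint is equivalent to standardizing in the forward pass can be checked numerically. The sketch below assumes TensorFlow 2.x and uses a linear loss with arbitrary coefficients `c` purely for illustration (`standardize` is a hypothetical helper). It starts from a kernel that is already standardized, as a constraint would leave it, so both approaches give almost the same loss value; it then compares the gradient a constraint-based setup receives (with respect to the kernel used as-is) with the gradient when standardization sits inside the forward pass.

```python
import numpy as np
import tensorflow as tf


def standardize(w, eps=1e-5):
    # Per-output-channel standardization of a Conv2D kernel (kh, kw, in, out).
    mean, var = tf.nn.moments(w, axes=[0, 1, 2], keepdims=True)
    return (w - mean) / tf.sqrt(var + eps)


tf.random.set_seed(0)
# A kernel that is already standardized, as a constraint would leave it.
w = tf.Variable(standardize(tf.random.normal([3, 3, 4, 2])))
c = tf.random.normal([3, 3, 4, 2])  # arbitrary coefficients for a linear loss

with tf.GradientTape(persistent=True) as tape:
    loss_constraint = tf.reduce_sum(c * w)            # kernel used as-is
    loss_forward = tf.reduce_sum(c * standardize(w))  # standardized in forward pass

g_constraint = tape.gradient(loss_constraint, w).numpy()
g_forward = tape.gradient(loss_forward, w).numpy()
del tape

print(float(loss_constraint), float(loss_forward))  # nearly identical values
print(np.abs(g_constraint - g_forward).max())        # clearly non-zero
```

The forward values agree, but the gradients do not: the forward-pass version removes each channel's mean gradient and the component along the standardized kernel, which the constraint-based setup never sees.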