Message passing operation #59
base: master
Conversation
raimis commented Jun 7, 2022 (edited)
- Move the code from TorchMD-NET (Nearest neighbor operation #58)
- Integration
- Tests
- Documentation
@peastman could you review?
details.
messages: `torch.Tensor`
    Atom pair messages. The shape of the tensor is `(num_pairs, num_features)`.
    For efficiency, `num_features` has to be a multiple of 32 and <= 1024.
Are those limitations really necessary? It's very common for the number of features not to be a multiple of 32.
On the contrary, the number of internal features is always a multiple of 32 (e.g. in TorchMD-NET, I have seen 64, 128, and 256 used). The GPU computes in warps of 32 threads, so it is best to match that pattern for computational efficiency.
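For context, a minimal sketch of the thread mapping this constraint assumes: one block per atom pair and one thread per feature. The kernel name, neighbor layout, and flat indexing below are illustrative assumptions, not the code in this PR.

```cuda
#include <cstdint>

// Sketch only: one block per atom pair, one thread per feature, row-major
// layout. With num_features a multiple of 32 and <= 1024, blockDim.x can be
// set to num_features, so each warp touches 32 consecutive features and the
// loads/stores are coalesced.
__global__ void passMessagesSketch(const int32_t* __restrict__ neighbors, // (num_pairs): receiving atom of each pair
                                   const float* __restrict__ messages,    // (num_pairs, num_features)
                                   float* __restrict__ new_states,        // (num_atoms, num_features)
                                   int32_t num_features) {
    const int32_t i_neig = blockIdx.x;   // one atom pair per block
    const int32_t i_feat = threadIdx.x;  // one feature per thread
    const int32_t i_atom = neighbors[i_neig];
    atomicAdd(&new_states[i_atom * num_features + i_feat],
              messages[i_neig * num_features + i_feat]);
}

// Hypothetical launch: passMessagesSketch<<<num_pairs, num_features>>>(...);
```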
I frequently create models that don't satisfy those requirements, including in TorchMD-Net. For example, I've trained models with 48 or 80 features per layer.
Why?
I don't understand your question. Why not? The number of features is a hyperparameter. It's one of many hyperparameters you tune to balance training accuracy, overfitting, speed, and memory use. Why place arbitrary limits on it when there's no need to?
const int32_t i_feat = threadIdx.x;
atomicAdd(&new_states[i_atom][i_feat], messages[i_neig][i_feat]);
You can eliminate the limitations on number of features by just rewriting this as a loop.
for (int32_t i_feat = threadIdx.x; i_feat < num_features; i_feat += blockDim.x)
atomicAdd(&new_states[i_atom][i_feat], messages[i_neig][i_feat]);
Apart from solving a non-existent problem, this would make the memory access non-coalesced and reduce speed.
It would have no effect on speed at all. If the number of features happens to satisfy your current requirement, the behavior would be identical to what it currently does. The `atomicAdd()` would be executed once by every thread, with `i_feat` equal to `threadIdx.x`. The only change would be if the number doesn't satisfy your current requirements, either because it's not a multiple of 32 or it's more than 1024. In that case it would produce correct behavior, unlike the current code. So there's no downside at all.
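Reading the suggested loop with hypothetical feature counts makes this concrete (`i_atom`, `i_neig`, `new_states`, and `messages` are as in the snippet under review; the block sizes are assumed, not taken from the PR):

```cuda
// The suggested loop, annotated with hypothetical sizes (assuming the kernel
// is still launched with blockDim.x = min(num_features, 1024)):
//  - num_features = 128, blockDim.x = 128:  each thread runs exactly one
//    iteration with i_feat == threadIdx.x, identical to the current code.
//  - num_features = 80,  blockDim.x = 80:   still one iteration per thread;
//    the last warp is only partially filled, but the result is correct.
//  - num_features = 2048, blockDim.x = 1024: each thread runs two iterations,
//    so all features are covered, which the current code cannot do.
for (int32_t i_feat = threadIdx.x; i_feat < num_features; i_feat += blockDim.x)
    atomicAdd(&new_states[i_atom][i_feat], messages[i_neig][i_feat]);
```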
It will reduce the number of threads by the number of features. The reduced parallelism would result in reduced speed.
I'm not suggesting any change to the number of threads. The only thing I'm suggesting is wrapping the `atomicAdd()` in a loop as shown above. If `num_features` happens to match your current restrictions, nothing will change. Every thread will still call it exactly once.
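Putting the thread-mapping sketch and the suggested loop together, a self-contained sketch of the looped variant might look like this (names, neighbor layout, and launch shape are assumptions carried over from the earlier sketch, not the PR's actual code):

```cuda
#include <cstdint>

// Sketch of the looped variant. The launch is unchanged: one block per atom
// pair, blockDim.x threads per block (e.g. min(num_features, 1024)). When
// num_features <= blockDim.x, the loop body runs once per thread, exactly
// like the unlooped version; when num_features is larger, threads stride
// over the remaining features.
__global__ void passMessagesLoopSketch(const int32_t* __restrict__ neighbors, // (num_pairs)
                                       const float* __restrict__ messages,    // (num_pairs, num_features)
                                       float* __restrict__ new_states,        // (num_atoms, num_features)
                                       int32_t num_features) {
    const int32_t i_neig = blockIdx.x;
    const int32_t i_atom = neighbors[i_neig];
    for (int32_t i_feat = threadIdx.x; i_feat < num_features; i_feat += blockDim.x)
        atomicAdd(&new_states[i_atom * num_features + i_feat],
                  messages[i_neig * num_features + i_feat]);
}
```

Within each loop iteration, consecutive threads still access consecutive features, so the coalesced access pattern of the original mapping is preserved.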