-
Notifications
You must be signed in to change notification settings - Fork 63
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add support for Torch's Conv1d strides and ConvTranspose1d #145
Add support for Torch's Conv1d strides and ConvTranspose1d #145
Conversation
…spose1d stride support.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hello! These changes look great.
I think the "ConvTranspose" changes are good to go. I guess it would be cool to add some tests to double-check that it works correctly with the compile-time API as well, but I guess that's not 100% necessary.
For the Conv1D strides, the "skip" implementation looks correct. I would love to take a shot at implementing that in the compile-time implementations of the Conv1D layer as well... I have a rough idea how that should work. I'm also thinking about making a "wrapper" with a counter to keep track of whether forward()
or skip()
should be called.
Do you mind if I push any changes back to your branch?
|
||
TEST(TestTorchConvTranspose1D, modelOutputMatchesPythonImplementationForDoubles) | ||
{ | ||
testTorchConvTranspose1DModel<float,4,15,5,3,1,1,3>( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
testTorchConvTranspose1DModel<float,4,15,5,3,1,1,3>( | |
testTorchConvTranspose1DModel<double,4,15,5,3,1,1,3>( |
|
||
TEST(TestTorchConvTranspose1D, streaming_modelOutputMatchesPythonImplementationForDoubles) | ||
{ | ||
testStreamingTorchConvTranspose1DModel<float,4,15,5,3,1,1,3>( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
testStreamingTorchConvTranspose1DModel<float,4,15,5,3,1,1,3>( | |
testStreamingTorchConvTranspose1DModel<double,4,15,5,3,1,1,3>( |
Also linking issue #144 for visibility. |
Sounds good! The stride counter idea sounds great! Feel free to push changes and I'll also take a look at that! |
Codecov ReportAttention: Patch coverage is
Additional details and impacted files@@ Coverage Diff @@
## main #145 +/- ##
==========================================
- Coverage 95.70% 94.76% -0.95%
==========================================
Files 58 40 -18
Lines 3892 2578 -1314
==========================================
- Hits 3725 2443 -1282
+ Misses 167 135 -32
Flags with carried forward coverage won't be shown. Click here to find out more. ☔ View full report in Codecov by Sentry. |
Alright, I think I've made all the changes that I want to make... Still need to do a pass for cleanup and documentation, but if @fcaspe wants to have another look and make sure I didn't mess things up too bad, that would be much appreciated :). |
Just reviewed. The 'StridedConv1d` class is a good idea. The examples look good and cleaner than mine. I also tried on the conv models I am developing and they are working ok with this new version! |
@fcaspe Awesome! I'm going to go ahead and merge this PR. |
32b8664
into
jatinchowdhury18:main
Hi Jatin!
This library is really awesome, I have been using it for low latency inference of big convolutional autoencoders, so I have implemented the 1d Transposed Convolution and convolutional strides.
ConvTranspose1d is implemented with RTNeural's
Conv1D
class, but a different loading function has to be called,RTNeural::torch_helpers::loadConvTranspose1D
. Seetorch_convtranspose1d_test.cpp
for an example.Conv1d strides are implemented using a
.skip()
method that performs a single stride step. This just updates the circular buffer of theConv1D
layer with the new input we jump over. For example, if strides=2 is required, then.skip()
has to be called every time after a.forward()
call is made. Seetorch_conv1d_stride_test.cpp
for an example.I know these new functionalities are not fully incorporated into the library. For instance, strides are still missing in
Conv1DT
and the non-streaming versions ofConv1D
andConv2D
. Let me know what you think about these additions and I will be happy to improve them so that hopefully they can be integrated into the library!Best,
Franco