Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

cargo +nightly fmt #1017

Merged
merged 17 commits into from
Dec 12, 2023
7 changes: 4 additions & 3 deletions burn-autodiff/src/graph/traversal.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,10 @@ impl BreadthFirstSearch {
let mut visited = HashSet::with_capacity(root.order);
let mut parents = Vec::with_capacity(root.order);
let mut steps = graph.steps();
let root_step = steps
.remove(&root.id)
.expect("Root node should have a step registered, did you forget to call `Tensor::register_grad` on the tensor where you need gradients?");
let root_step = steps.remove(&root.id).expect(
"Root node should have a step registered, did you forget to call \
`Tensor::register_grad` on the tensor where you need gradients?",
);

visited.insert(root.id.clone());
parents.append(&mut root.parents.clone());
Expand Down
324 changes: 154 additions & 170 deletions burn-autodiff/src/ops/module.rs
Original file line number Diff line number Diff line change
Expand Up @@ -103,49 +103,45 @@ impl<B: Backend> ModuleOps<Autodiff<B>> for Autodiff<B> {
}

match bias {
Some(bias) => {
match Conv2DWithBias
.prepare(
[x.node, weight.node, bias.node],
[x.graph, weight.graph, bias.graph],
)
.stateful()
{
OpsKind::Tracked(prep) => prep.finish(
(
x.primitive.clone(),
weight.primitive.clone(),
bias.primitive.clone(),
options.clone(),
),
B::conv2d(x.primitive, weight.primitive, Some(bias.primitive), options),
Some(bias) => match Conv2DWithBias
.prepare(
[x.node, weight.node, bias.node],
[x.graph, weight.graph, bias.graph],
)
.stateful()
{
OpsKind::Tracked(prep) => prep.finish(
(
x.primitive.clone(),
weight.primitive.clone(),
bias.primitive.clone(),
options.clone(),
),
OpsKind::UnTracked(prep) => prep.finish(B::conv2d(
x.primitive,
weight.primitive,
Some(bias.primitive),
options,
)),
}
}
None => {
match Conv2DNoBias
.prepare([x.node, weight.node], [x.graph, weight.graph])
.stateful()
{
OpsKind::Tracked(prep) => prep.finish(
(
x.primitive.clone(),
weight.primitive.clone(),
options.clone(),
),
B::conv2d(x.primitive, weight.primitive, None, options),
B::conv2d(x.primitive, weight.primitive, Some(bias.primitive), options),
),
OpsKind::UnTracked(prep) => prep.finish(B::conv2d(
x.primitive,
weight.primitive,
Some(bias.primitive),
options,
)),
},
None => match Conv2DNoBias
.prepare([x.node, weight.node], [x.graph, weight.graph])
.stateful()
{
OpsKind::Tracked(prep) => prep.finish(
(
x.primitive.clone(),
weight.primitive.clone(),
options.clone(),
),
OpsKind::UnTracked(prep) => {
prep.finish(B::conv2d(x.primitive, weight.primitive, None, options))
}
B::conv2d(x.primitive, weight.primitive, None, options),
),
OpsKind::UnTracked(prep) => {
prep.finish(B::conv2d(x.primitive, weight.primitive, None, options))
}
}
},
}
}

Expand Down Expand Up @@ -211,57 +207,53 @@ impl<B: Backend> ModuleOps<Autodiff<B>> for Autodiff<B> {
}

match bias {
Some(bias) => {
match ConvTranspose2DWithBias
.prepare(
[x.node, weight.node, bias.node],
[x.graph, weight.graph, bias.graph],
)
.stateful()
{
OpsKind::Tracked(prep) => prep.finish(
(
x.primitive.clone(),
weight.primitive.clone(),
bias.primitive.clone(),
options.clone(),
),
B::conv_transpose2d(
x.primitive,
weight.primitive,
Some(bias.primitive),
options,
),
Some(bias) => match ConvTranspose2DWithBias
.prepare(
[x.node, weight.node, bias.node],
[x.graph, weight.graph, bias.graph],
)
.stateful()
{
OpsKind::Tracked(prep) => prep.finish(
(
x.primitive.clone(),
weight.primitive.clone(),
bias.primitive.clone(),
options.clone(),
),
OpsKind::UnTracked(prep) => prep.finish(B::conv_transpose2d(
B::conv_transpose2d(
x.primitive,
weight.primitive,
Some(bias.primitive),
options,
)),
}
}
None => {
match ConvTranspose2DNoBias
.prepare([x.node, weight.node], [x.graph, weight.graph])
.stateful()
{
OpsKind::Tracked(prep) => prep.finish(
(
x.primitive.clone(),
weight.primitive.clone(),
options.clone(),
),
B::conv_transpose2d(x.primitive, weight.primitive, None, options),
),
OpsKind::UnTracked(prep) => prep.finish(B::conv_transpose2d(
x.primitive,
weight.primitive,
None,
options,
)),
}
}
),
OpsKind::UnTracked(prep) => prep.finish(B::conv_transpose2d(
x.primitive,
weight.primitive,
Some(bias.primitive),
options,
)),
},
None => match ConvTranspose2DNoBias
.prepare([x.node, weight.node], [x.graph, weight.graph])
.stateful()
{
OpsKind::Tracked(prep) => prep.finish(
(
x.primitive.clone(),
weight.primitive.clone(),
options.clone(),
),
B::conv_transpose2d(x.primitive, weight.primitive, None, options),
),
OpsKind::UnTracked(prep) => prep.finish(B::conv_transpose2d(
x.primitive,
weight.primitive,
None,
options,
)),
},
}
}

Expand Down Expand Up @@ -322,49 +314,45 @@ impl<B: Backend> ModuleOps<Autodiff<B>> for Autodiff<B> {
}
}
match bias {
Some(bias) => {
match Conv1DWithBias
.prepare(
[x.node, weight.node, bias.node],
[x.graph, weight.graph, bias.graph],
)
.stateful()
{
OpsKind::Tracked(prep) => prep.finish(
(
x.primitive.clone(),
weight.primitive.clone(),
bias.primitive.clone(),
options.clone(),
),
B::conv1d(x.primitive, weight.primitive, Some(bias.primitive), options),
Some(bias) => match Conv1DWithBias
.prepare(
[x.node, weight.node, bias.node],
[x.graph, weight.graph, bias.graph],
)
.stateful()
{
OpsKind::Tracked(prep) => prep.finish(
(
x.primitive.clone(),
weight.primitive.clone(),
bias.primitive.clone(),
options.clone(),
),
OpsKind::UnTracked(prep) => prep.finish(B::conv1d(
x.primitive,
weight.primitive,
Some(bias.primitive),
options,
)),
}
}
None => {
match Conv1DNoBias
.prepare([x.node, weight.node], [x.graph, weight.graph])
.stateful()
{
OpsKind::Tracked(prep) => prep.finish(
(
x.primitive.clone(),
weight.primitive.clone(),
options.clone(),
),
B::conv1d(x.primitive, weight.primitive, None, options),
B::conv1d(x.primitive, weight.primitive, Some(bias.primitive), options),
),
OpsKind::UnTracked(prep) => prep.finish(B::conv1d(
x.primitive,
weight.primitive,
Some(bias.primitive),
options,
)),
},
None => match Conv1DNoBias
.prepare([x.node, weight.node], [x.graph, weight.graph])
.stateful()
{
OpsKind::Tracked(prep) => prep.finish(
(
x.primitive.clone(),
weight.primitive.clone(),
options.clone(),
),
OpsKind::UnTracked(prep) => {
prep.finish(B::conv1d(x.primitive, weight.primitive, None, options))
}
B::conv1d(x.primitive, weight.primitive, None, options),
),
OpsKind::UnTracked(prep) => {
prep.finish(B::conv1d(x.primitive, weight.primitive, None, options))
}
}
},
}
}

Expand Down Expand Up @@ -430,57 +418,53 @@ impl<B: Backend> ModuleOps<Autodiff<B>> for Autodiff<B> {
}

match bias {
Some(bias) => {
match ConvTranspose1DWithBias
.prepare(
[x.node, weight.node, bias.node],
[x.graph, weight.graph, bias.graph],
)
.stateful()
{
OpsKind::Tracked(prep) => prep.finish(
(
x.primitive.clone(),
weight.primitive.clone(),
bias.primitive.clone(),
options.clone(),
),
B::conv_transpose1d(
x.primitive,
weight.primitive,
Some(bias.primitive),
options,
),
Some(bias) => match ConvTranspose1DWithBias
.prepare(
[x.node, weight.node, bias.node],
[x.graph, weight.graph, bias.graph],
)
.stateful()
{
OpsKind::Tracked(prep) => prep.finish(
(
x.primitive.clone(),
weight.primitive.clone(),
bias.primitive.clone(),
options.clone(),
),
OpsKind::UnTracked(prep) => prep.finish(B::conv_transpose1d(
B::conv_transpose1d(
x.primitive,
weight.primitive,
Some(bias.primitive),
options,
)),
}
}
None => {
match ConvTranspose1DNoBias
.prepare([x.node, weight.node], [x.graph, weight.graph])
.stateful()
{
OpsKind::Tracked(prep) => prep.finish(
(
x.primitive.clone(),
weight.primitive.clone(),
options.clone(),
),
B::conv_transpose1d(x.primitive, weight.primitive, None, options),
),
OpsKind::UnTracked(prep) => prep.finish(B::conv_transpose1d(
x.primitive,
weight.primitive,
None,
options,
)),
}
}
),
OpsKind::UnTracked(prep) => prep.finish(B::conv_transpose1d(
x.primitive,
weight.primitive,
Some(bias.primitive),
options,
)),
},
None => match ConvTranspose1DNoBias
.prepare([x.node, weight.node], [x.graph, weight.graph])
.stateful()
{
OpsKind::Tracked(prep) => prep.finish(
(
x.primitive.clone(),
weight.primitive.clone(),
options.clone(),
),
B::conv_transpose1d(x.primitive, weight.primitive, None, options),
),
OpsKind::UnTracked(prep) => prep.finish(B::conv_transpose1d(
x.primitive,
weight.primitive,
None,
options,
)),
},
}
}

Expand Down
5 changes: 4 additions & 1 deletion burn-core/src/nn/conv/checks.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,9 @@ pub(crate) fn checks_channels_div_groups(channels_in: usize, channels_out: usize
let channels_out_div_by_group = channels_out % groups == 0;

if !channels_in_div_by_group && !channels_out_div_by_group {
panic!("Both channels must be divisible by the number of groups. Got channels_in={channels_in}, channels_out={channels_out}, groups={groups}");
panic!(
"Both channels must be divisible by the number of groups. Got \
channels_in={channels_in}, channels_out={channels_out}, groups={groups}"
);
}
}
Loading