Skip to content

Commit

Permalink
cargo +nightly fmt (#1017)
Browse files Browse the repository at this point in the history
  • Loading branch information
AlexErrant authored Dec 12, 2023
1 parent 1a5f252 commit 610d640
Show file tree
Hide file tree
Showing 29 changed files with 424 additions and 323 deletions.
7 changes: 4 additions & 3 deletions burn-autodiff/src/graph/traversal.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,10 @@ impl BreadthFirstSearch {
let mut visited = HashSet::with_capacity(root.order);
let mut parents = Vec::with_capacity(root.order);
let mut steps = graph.steps();
let root_step = steps
.remove(&root.id)
.expect("Root node should have a step registered, did you forget to call `Tensor::register_grad` on the tensor where you need gradients?");
let root_step = steps.remove(&root.id).expect(
"Root node should have a step registered, did you forget to call \
`Tensor::register_grad` on the tensor where you need gradients?",
);

visited.insert(root.id.clone());
parents.append(&mut root.parents.clone());
Expand Down
324 changes: 154 additions & 170 deletions burn-autodiff/src/ops/module.rs
Original file line number Diff line number Diff line change
Expand Up @@ -103,49 +103,45 @@ impl<B: Backend> ModuleOps<Autodiff<B>> for Autodiff<B> {
}

match bias {
Some(bias) => {
match Conv2DWithBias
.prepare(
[x.node, weight.node, bias.node],
[x.graph, weight.graph, bias.graph],
)
.stateful()
{
OpsKind::Tracked(prep) => prep.finish(
(
x.primitive.clone(),
weight.primitive.clone(),
bias.primitive.clone(),
options.clone(),
),
B::conv2d(x.primitive, weight.primitive, Some(bias.primitive), options),
Some(bias) => match Conv2DWithBias
.prepare(
[x.node, weight.node, bias.node],
[x.graph, weight.graph, bias.graph],
)
.stateful()
{
OpsKind::Tracked(prep) => prep.finish(
(
x.primitive.clone(),
weight.primitive.clone(),
bias.primitive.clone(),
options.clone(),
),
OpsKind::UnTracked(prep) => prep.finish(B::conv2d(
x.primitive,
weight.primitive,
Some(bias.primitive),
options,
)),
}
}
None => {
match Conv2DNoBias
.prepare([x.node, weight.node], [x.graph, weight.graph])
.stateful()
{
OpsKind::Tracked(prep) => prep.finish(
(
x.primitive.clone(),
weight.primitive.clone(),
options.clone(),
),
B::conv2d(x.primitive, weight.primitive, None, options),
B::conv2d(x.primitive, weight.primitive, Some(bias.primitive), options),
),
OpsKind::UnTracked(prep) => prep.finish(B::conv2d(
x.primitive,
weight.primitive,
Some(bias.primitive),
options,
)),
},
None => match Conv2DNoBias
.prepare([x.node, weight.node], [x.graph, weight.graph])
.stateful()
{
OpsKind::Tracked(prep) => prep.finish(
(
x.primitive.clone(),
weight.primitive.clone(),
options.clone(),
),
OpsKind::UnTracked(prep) => {
prep.finish(B::conv2d(x.primitive, weight.primitive, None, options))
}
B::conv2d(x.primitive, weight.primitive, None, options),
),
OpsKind::UnTracked(prep) => {
prep.finish(B::conv2d(x.primitive, weight.primitive, None, options))
}
}
},
}
}

Expand Down Expand Up @@ -211,57 +207,53 @@ impl<B: Backend> ModuleOps<Autodiff<B>> for Autodiff<B> {
}

match bias {
Some(bias) => {
match ConvTranspose2DWithBias
.prepare(
[x.node, weight.node, bias.node],
[x.graph, weight.graph, bias.graph],
)
.stateful()
{
OpsKind::Tracked(prep) => prep.finish(
(
x.primitive.clone(),
weight.primitive.clone(),
bias.primitive.clone(),
options.clone(),
),
B::conv_transpose2d(
x.primitive,
weight.primitive,
Some(bias.primitive),
options,
),
Some(bias) => match ConvTranspose2DWithBias
.prepare(
[x.node, weight.node, bias.node],
[x.graph, weight.graph, bias.graph],
)
.stateful()
{
OpsKind::Tracked(prep) => prep.finish(
(
x.primitive.clone(),
weight.primitive.clone(),
bias.primitive.clone(),
options.clone(),
),
OpsKind::UnTracked(prep) => prep.finish(B::conv_transpose2d(
B::conv_transpose2d(
x.primitive,
weight.primitive,
Some(bias.primitive),
options,
)),
}
}
None => {
match ConvTranspose2DNoBias
.prepare([x.node, weight.node], [x.graph, weight.graph])
.stateful()
{
OpsKind::Tracked(prep) => prep.finish(
(
x.primitive.clone(),
weight.primitive.clone(),
options.clone(),
),
B::conv_transpose2d(x.primitive, weight.primitive, None, options),
),
OpsKind::UnTracked(prep) => prep.finish(B::conv_transpose2d(
x.primitive,
weight.primitive,
None,
options,
)),
}
}
),
OpsKind::UnTracked(prep) => prep.finish(B::conv_transpose2d(
x.primitive,
weight.primitive,
Some(bias.primitive),
options,
)),
},
None => match ConvTranspose2DNoBias
.prepare([x.node, weight.node], [x.graph, weight.graph])
.stateful()
{
OpsKind::Tracked(prep) => prep.finish(
(
x.primitive.clone(),
weight.primitive.clone(),
options.clone(),
),
B::conv_transpose2d(x.primitive, weight.primitive, None, options),
),
OpsKind::UnTracked(prep) => prep.finish(B::conv_transpose2d(
x.primitive,
weight.primitive,
None,
options,
)),
},
}
}

Expand Down Expand Up @@ -322,49 +314,45 @@ impl<B: Backend> ModuleOps<Autodiff<B>> for Autodiff<B> {
}
}
match bias {
Some(bias) => {
match Conv1DWithBias
.prepare(
[x.node, weight.node, bias.node],
[x.graph, weight.graph, bias.graph],
)
.stateful()
{
OpsKind::Tracked(prep) => prep.finish(
(
x.primitive.clone(),
weight.primitive.clone(),
bias.primitive.clone(),
options.clone(),
),
B::conv1d(x.primitive, weight.primitive, Some(bias.primitive), options),
Some(bias) => match Conv1DWithBias
.prepare(
[x.node, weight.node, bias.node],
[x.graph, weight.graph, bias.graph],
)
.stateful()
{
OpsKind::Tracked(prep) => prep.finish(
(
x.primitive.clone(),
weight.primitive.clone(),
bias.primitive.clone(),
options.clone(),
),
OpsKind::UnTracked(prep) => prep.finish(B::conv1d(
x.primitive,
weight.primitive,
Some(bias.primitive),
options,
)),
}
}
None => {
match Conv1DNoBias
.prepare([x.node, weight.node], [x.graph, weight.graph])
.stateful()
{
OpsKind::Tracked(prep) => prep.finish(
(
x.primitive.clone(),
weight.primitive.clone(),
options.clone(),
),
B::conv1d(x.primitive, weight.primitive, None, options),
B::conv1d(x.primitive, weight.primitive, Some(bias.primitive), options),
),
OpsKind::UnTracked(prep) => prep.finish(B::conv1d(
x.primitive,
weight.primitive,
Some(bias.primitive),
options,
)),
},
None => match Conv1DNoBias
.prepare([x.node, weight.node], [x.graph, weight.graph])
.stateful()
{
OpsKind::Tracked(prep) => prep.finish(
(
x.primitive.clone(),
weight.primitive.clone(),
options.clone(),
),
OpsKind::UnTracked(prep) => {
prep.finish(B::conv1d(x.primitive, weight.primitive, None, options))
}
B::conv1d(x.primitive, weight.primitive, None, options),
),
OpsKind::UnTracked(prep) => {
prep.finish(B::conv1d(x.primitive, weight.primitive, None, options))
}
}
},
}
}

Expand Down Expand Up @@ -430,57 +418,53 @@ impl<B: Backend> ModuleOps<Autodiff<B>> for Autodiff<B> {
}

match bias {
Some(bias) => {
match ConvTranspose1DWithBias
.prepare(
[x.node, weight.node, bias.node],
[x.graph, weight.graph, bias.graph],
)
.stateful()
{
OpsKind::Tracked(prep) => prep.finish(
(
x.primitive.clone(),
weight.primitive.clone(),
bias.primitive.clone(),
options.clone(),
),
B::conv_transpose1d(
x.primitive,
weight.primitive,
Some(bias.primitive),
options,
),
Some(bias) => match ConvTranspose1DWithBias
.prepare(
[x.node, weight.node, bias.node],
[x.graph, weight.graph, bias.graph],
)
.stateful()
{
OpsKind::Tracked(prep) => prep.finish(
(
x.primitive.clone(),
weight.primitive.clone(),
bias.primitive.clone(),
options.clone(),
),
OpsKind::UnTracked(prep) => prep.finish(B::conv_transpose1d(
B::conv_transpose1d(
x.primitive,
weight.primitive,
Some(bias.primitive),
options,
)),
}
}
None => {
match ConvTranspose1DNoBias
.prepare([x.node, weight.node], [x.graph, weight.graph])
.stateful()
{
OpsKind::Tracked(prep) => prep.finish(
(
x.primitive.clone(),
weight.primitive.clone(),
options.clone(),
),
B::conv_transpose1d(x.primitive, weight.primitive, None, options),
),
OpsKind::UnTracked(prep) => prep.finish(B::conv_transpose1d(
x.primitive,
weight.primitive,
None,
options,
)),
}
}
),
OpsKind::UnTracked(prep) => prep.finish(B::conv_transpose1d(
x.primitive,
weight.primitive,
Some(bias.primitive),
options,
)),
},
None => match ConvTranspose1DNoBias
.prepare([x.node, weight.node], [x.graph, weight.graph])
.stateful()
{
OpsKind::Tracked(prep) => prep.finish(
(
x.primitive.clone(),
weight.primitive.clone(),
options.clone(),
),
B::conv_transpose1d(x.primitive, weight.primitive, None, options),
),
OpsKind::UnTracked(prep) => prep.finish(B::conv_transpose1d(
x.primitive,
weight.primitive,
None,
options,
)),
},
}
}

Expand Down
5 changes: 4 additions & 1 deletion burn-core/src/nn/conv/checks.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,9 @@ pub(crate) fn checks_channels_div_groups(channels_in: usize, channels_out: usize
let channels_out_div_by_group = channels_out % groups == 0;

if !channels_in_div_by_group && !channels_out_div_by_group {
panic!("Both channels must be divisible by the number of groups. Got channels_in={channels_in}, channels_out={channels_out}, groups={groups}");
panic!(
"Both channels must be divisible by the number of groups. Got \
channels_in={channels_in}, channels_out={channels_out}, groups={groups}"
);
}
}
Loading

0 comments on commit 610d640

Please sign in to comment.