Skip to content

Commit

Permalink
Fix error for changing precision of multiple input sets
Browse files Browse the repository at this point in the history
  • Loading branch information
mldiego committed Mar 20, 2024
1 parent bf29b32 commit 2521f2a
Showing 1 changed file with 4 additions and 2 deletions.
6 changes: 4 additions & 2 deletions code/nnv/engine/nn/NN.m
Original file line number Diff line number Diff line change
Expand Up @@ -952,7 +952,7 @@
function inputSet = consistentPrecision(obj, inputSet)
% (assume parameters have same precision across layers)
% approach: change input precision based on network parameters
inputPrecision = class(inputSet.V);
inputPrecision = class(inputSet(1).V);
netPrecision = 'double'; % default
for i=1:length(obj.Layers)
if isa(obj.Layers{i}, "FullyConnectedLayer") || isa(obj.Layers{i}, "Conv2DLayer")
Expand All @@ -963,7 +963,9 @@
if ~strcmp(inputPrecision, netPrecision)
% input and parameter precision does not match
warning("Changing input set precision to "+string(netPrecision));
inputSet = inputSet.changeVarsPrecision(netPrecision);
for i = 1:length(inputSet)
inputSet(i) = inputSet(i).changeVarsPrecision(netPrecision);
end
end
end

Expand Down

0 comments on commit 2521f2a

Please sign in to comment.