Skip to content

Commit

Permalink
do some optimizations
Browse files Browse the repository at this point in the history
  • Loading branch information
pca006132 committed Dec 27, 2024
1 parent 36321cd commit 7cbc4de
Show file tree
Hide file tree
Showing 3 changed files with 116 additions and 47 deletions.
65 changes: 60 additions & 5 deletions src/sdf/context.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -82,9 +82,7 @@ Operand Context::addInstruction(Instruction inst) {
return result;
}

// bypass the cache because we don't expect to have more common subexpressions
// after optimizations
Operand Context::addInstructionNoCache(Instruction inst) {
std::optional<Operand> Context::trySimplify(Instruction inst) {
// constant choice
auto op = inst.op;
auto &operands = inst.operands;
Expand Down Expand Up @@ -161,11 +159,68 @@ Operand Context::addInstructionNoCache(Instruction inst) {
return addConstant(result);
}

// simple simplifications
if (op == OpCode::ADD) {
// add is commutative, so if there is a constant, it must be on the left
// 0 + x => x
if (operands[0].isConst() && constants[operands[0].toConstIndex()] == 0.0)
return operands[1];
}
if (op == OpCode::SUB) {
// x - 0 => x
if (operands[1].isConst() && constants[operands[1].toConstIndex()] == 0.0)
return operands[0];

Check warning on line 172 in src/sdf/context.cpp

View check run for this annotation

Codecov / codecov/patch

src/sdf/context.cpp#L172

Added line #L172 was not covered by tests
}
if (op == OpCode::MUL) {
// mul is commutative, so if there is a constant, it must be on the left
// 0 * x => 0
if (operands[0].isConst() && constants[operands[0].toConstIndex()] == 0.0)
return operands[0];

Check warning on line 178 in src/sdf/context.cpp

View check run for this annotation

Codecov / codecov/patch

src/sdf/context.cpp#L178

Added line #L178 was not covered by tests
// 1 * x => x
if (operands[0].isConst() && constants[operands[0].toConstIndex()] == 1.0)
return operands[1];

Check warning on line 181 in src/sdf/context.cpp

View check run for this annotation

Codecov / codecov/patch

src/sdf/context.cpp#L181

Added line #L181 was not covered by tests
}
if (op == OpCode::DIV) {
if (operands[1].isConst() && constants[operands[1].toConstIndex()] == 1.0)
return operands[0];

Check warning on line 185 in src/sdf/context.cpp

View check run for this annotation

Codecov / codecov/patch

src/sdf/context.cpp#L185

Added line #L185 was not covered by tests
}

return {};
}

Instruction Context::strengthReduction(Instruction inst) {
// strength reduction: reduce instructions to simpler variants
// not very helpful for point evaluation in a vm because instruction decoding
// is the most time consuming part.
// This can be useful if we want to do JIT, interval evaluation or bulk
// evaluation.
if (inst.op == OpCode::MUL && inst.operands[1].isConst() &&
constants[inst.operands[1].toConstIndex()] == 2.0) {

Check warning on line 198 in src/sdf/context.cpp

View check run for this annotation

Codecov / codecov/patch

src/sdf/context.cpp#L198

Added line #L198 was not covered by tests
// x * 2 => x + x
return {OpCode::ADD, {inst.operands[0], inst.operands[1], Operand::none()}};

Check warning on line 200 in src/sdf/context.cpp

View check run for this annotation

Codecov / codecov/patch

src/sdf/context.cpp#L200

Added line #L200 was not covered by tests
}
if (inst.op == OpCode::DIV && inst.operands[1].isConst()) {
// x / c => x * (1/c)
return {OpCode::MUL,
{inst.operands[0],
addConstant(1.0 / constants[inst.operands[1].toConstIndex()]),
Operand::none()}};
}
return inst;
}

// bypass the cache because we don't expect to have more common subexpressions
// after optimizations
Operand Context::addInstructionNoCache(Instruction inst) {
auto simplified = trySimplify(inst);
if (simplified.has_value()) return simplified.value();
inst = strengthReduction(inst);

size_t i = instructions.size();
instructions.push_back({op, operands});
instructions.push_back(inst);
opUses.emplace_back();
// update uses
for (auto operand : operands) {
for (auto operand : inst.operands) {
auto target = getUses(operand);
if (target == nullptr) continue;
// avoid duplicates
Expand Down
3 changes: 3 additions & 0 deletions src/sdf/context.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
// limitations under the License.
#pragma once

#include <optional>
#include <utility>
#include <vector>

Expand Down Expand Up @@ -104,6 +105,8 @@ class Context {
std::vector<UsesVector> opUses;
unordered_map<Instruction, Operand> cache;

std::optional<Operand> trySimplify(Instruction);
Instruction strengthReduction(Instruction);
Operand addInstructionNoCache(Instruction);

UsesVector* getUses(Operand operand) {

Check warning on line 112 in src/sdf/context.h

View check run for this annotation

Codecov / codecov/patch

src/sdf/context.h#L112

Added line #L112 was not covered by tests
Expand Down
95 changes: 53 additions & 42 deletions src/sdf/interval.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,9 @@ struct Interval {
Domain lower;
Domain upper;

static constexpr Domain zero = static_cast<Domain>(0);
static constexpr Domain one = static_cast<Domain>(1);

Interval()
: lower(-std::numeric_limits<Domain>::infinity()),
upper(std::numeric_limits<Domain>::infinity()) {}
Expand All @@ -42,14 +45,22 @@ struct Interval {

Interval operator-(const Interval &other) const { return *this + (-other); }

Check warning on line 46 in src/sdf/interval.h

View check run for this annotation

Codecov / codecov/patch

src/sdf/interval.h#L46

Added line #L46 was not covered by tests

Interval operator*(Domain d) const {

Check warning on line 48 in src/sdf/interval.h

View check run for this annotation

Codecov / codecov/patch

src/sdf/interval.h#L48

Added line #L48 was not covered by tests
if (d > zero) return {lower * d, upper * d};
return {upper * d, lower * d};
}

Interval operator*(const Interval &other) const {
if (is_const()) return other * lower;
if (other.is_const()) return *this * other.lower;

Domain a1b1 = lower * other.lower;
Domain a2b2 = upper * other.upper;
// we can write more "fast paths", but at some point it will become slower
// than just going the general path...
if (lower >= 0.0 && other.lower >= 0.0)
if (lower >= zero && other.lower >= zero)
return {a1b1, a2b2};
else if (upper <= 0.0 && other.upper <= 0.0)
else if (upper <= zero && other.upper <= zero)
return {a2b2, a1b1};

Domain a1b2 = lower * other.upper;
Expand All @@ -58,53 +69,47 @@ struct Interval {
std::max(std::max(a1b1, a1b2), std::max(a2b1, a2b2))};
}

Interval operator*(double d) const {
if (d > 0) return {lower * d, upper * d};
return {upper * d, lower * d};
}

Interval operator/(const Interval &other) const {
if (other.is_const()) return *this / other.lower;

Check warning on line 73 in src/sdf/interval.h

View check run for this annotation

Codecov / codecov/patch

src/sdf/interval.h#L72-L73

Added lines #L72 - L73 were not covered by tests
constexpr Domain zero = static_cast<Domain>(0);
constexpr Domain infty = std::numeric_limits<Domain>::infinity();
Interval reci;
if (other.lower >= zero || other.upper <= zero) {
reci.lower = other.upper == zero ? -infty : (1 / other.upper);
reci.upper = other.lower == zero ? infty : (1 / other.lower);
reci.lower = other.upper == zero ? -infty : (one / other.upper);
reci.upper = other.lower == zero ? infty : (one / other.lower);

Check warning on line 78 in src/sdf/interval.h

View check run for this annotation

Codecov / codecov/patch

src/sdf/interval.h#L76-L78

Added lines #L76 - L78 were not covered by tests
} else {
reci.lower = -infty;
reci.upper = infty;
}
return *this * reci;

Check warning on line 83 in src/sdf/interval.h

View check run for this annotation

Codecov / codecov/patch

src/sdf/interval.h#L83

Added line #L83 was not covered by tests
}

Interval operator/(double d) const {
if (d > 0) return {lower / d, upper / d};
Interval operator/(Domain d) const {
if (d > zero) return {lower / d, upper / d};
return {upper / d, lower / d};

Check warning on line 88 in src/sdf/interval.h

View check run for this annotation

Codecov / codecov/patch

src/sdf/interval.h#L86-L88

Added lines #L86 - L88 were not covered by tests
}

constexpr bool is_const() const { return lower == upper; }

Interval operator==(const Interval &other) const {
if (is_const() && other.is_const() && lower == other.lower)

Check warning on line 94 in src/sdf/interval.h

View check run for this annotation

Codecov / codecov/patch

src/sdf/interval.h#L93-L94

Added lines #L93 - L94 were not covered by tests
return constant(1); // must be equal
return constant(one); // must be equal
if (lower > other.upper || upper < other.lower)

Check warning on line 96 in src/sdf/interval.h

View check run for this annotation

Codecov / codecov/patch

src/sdf/interval.h#L96

Added line #L96 was not covered by tests
return constant(0); // disjoint, cannot possibly be equal
return {0, 1};
return constant(zero); // disjoint, cannot possibly be equal
return {zero, one};

Check warning on line 98 in src/sdf/interval.h

View check run for this annotation

Codecov / codecov/patch

src/sdf/interval.h#L98

Added line #L98 was not covered by tests
}

constexpr bool operator==(double d) const { return is_const() && lower == d; }
constexpr bool operator==(Domain d) const { return is_const() && lower == d; }

Interval operator>(const Interval &other) const {
if (lower > other.upper) return constant(1);
if (upper < other.lower) return constant(0);
return {0, 1};
if (lower > other.upper) return constant(one);
if (upper < other.lower) return constant(zero);
return {zero, one};

Check warning on line 106 in src/sdf/interval.h

View check run for this annotation

Codecov / codecov/patch

src/sdf/interval.h#L103-L106

Added lines #L103 - L106 were not covered by tests
}

Interval operator<(const Interval &other) const {
if (upper < other.lower) return constant(1);
if (lower > other.upper) return constant(0);
return {0, 1};
if (upper < other.lower) return constant(one);
if (lower > other.upper) return constant(zero);
return {zero, one};
}

Interval min(const Interval &other) const {

Check warning on line 115 in src/sdf/interval.h

View check run for this annotation

Codecov / codecov/patch

src/sdf/interval.h#L115

Added line #L115 was not covered by tests
Expand Down Expand Up @@ -132,70 +137,76 @@ struct Interval {
}

Interval abs() const {
if (lower >= 0) return *this;
if (upper <= 0) return {-upper, -lower};
return {0.0, std::max(-lower, upper)};
if (lower >= zero) return *this;
if (upper <= zero) return {-upper, -lower};
return {zero, std::max(-lower, upper)};

Check warning on line 142 in src/sdf/interval.h

View check run for this annotation

Codecov / codecov/patch

src/sdf/interval.h#L139-L142

Added lines #L139 - L142 were not covered by tests
}

Interval mod(double m) const {
Interval mod(Domain m) const {

Check warning on line 145 in src/sdf/interval.h

View check run for this annotation

Codecov / codecov/patch

src/sdf/interval.h#L145

Added line #L145 was not covered by tests
// FIXME: cannot deal with negative m right now...
Domain diff = std::fmod(lower, m);
if (diff < 0) diff += m;
if (diff < zero) diff += m;
Domain cycle_min = lower - diff;

Check warning on line 149 in src/sdf/interval.h

View check run for this annotation

Codecov / codecov/patch

src/sdf/interval.h#L147-L149

Added lines #L147 - L149 were not covered by tests
// may be disjoint intervals, but we don't deal with that...
if (upper - cycle_min >= m) return {0.0, m};
if (upper - cycle_min >= m) return {zero, m};
return {diff, upper - cycle_min};

Check warning on line 152 in src/sdf/interval.h

View check run for this annotation

Codecov / codecov/patch

src/sdf/interval.h#L151-L152

Added lines #L151 - L152 were not covered by tests
}

Interval logical_and(const Interval &other) const {
return {lower == 0.0 || other.lower == 0.0 ? 0.0 : 1.0,
upper == 1.0 && other.upper == 1.0 ? 1.0 : 0.0};
return {lower == zero || other.lower == zero ? zero : one,
upper == one && other.upper == one ? one : zero};

Check warning on line 157 in src/sdf/interval.h

View check run for this annotation

Codecov / codecov/patch

src/sdf/interval.h#L155-L157

Added lines #L155 - L157 were not covered by tests
}

Interval logical_or(const Interval &other) const {
return {lower == 0.0 && other.lower == 0.0 ? 0.0 : 1.0,
upper == 1.0 || other.upper == 1.0 ? 1.0 : 0.0};
return {lower == zero && other.lower == zero ? zero : one,
upper == one || other.upper == one ? one : zero};

Check warning on line 162 in src/sdf/interval.h

View check run for this annotation

Codecov / codecov/patch

src/sdf/interval.h#L160-L162

Added lines #L160 - L162 were not covered by tests
}

Interval sin() const {
if (is_const()) return constant(std::sin(lower));
// largely similar to cos
int64_t min_pis = static_cast<int64_t>(std::floor((lower - kHalfPi) / kPi));
int64_t max_pis = static_cast<int64_t>(std::floor((upper - kHalfPi) / kPi));
int64_t min_pis = static_cast<int64_t>(std::floor(
(lower - static_cast<Domain>(kHalfPi)) / static_cast<Domain>(kPi)));
int64_t max_pis = static_cast<int64_t>(std::floor(
(upper - static_cast<Domain>(kHalfPi)) / static_cast<Domain>(kPi)));

bool not_cross_pos_1 =
(min_pis % 2 == 0) ? max_pis - min_pis <= 1 : max_pis == min_pis;
bool not_cross_neg_1 =
(min_pis % 2 == 0) ? max_pis == min_pis : max_pis - min_pis <= 1;

Domain new_min =
not_cross_neg_1 ? std::min(std::sin(lower), std::sin(upper)) : -1.0;
not_cross_neg_1 ? std::min(std::sin(lower), std::sin(upper)) : -one;
Domain new_max =
not_cross_pos_1 ? std::max(std::sin(lower), std::sin(upper)) : 1.0;
not_cross_pos_1 ? std::max(std::sin(lower), std::sin(upper)) : one;
return {new_min, new_max};
}

Interval cos() const {
if (is_const()) return constant(std::cos(lower));
int64_t min_pis = static_cast<int64_t>(std::floor(lower / kPi));
int64_t max_pis = static_cast<int64_t>(std::floor(upper / kPi));
int64_t min_pis =
static_cast<int64_t>(std::floor(lower / static_cast<Domain>(kPi)));
int64_t max_pis =
static_cast<int64_t>(std::floor(upper / static_cast<Domain>(kPi)));

bool not_cross_pos_1 =
(min_pis % 2 == 0) ? max_pis - min_pis <= 1 : max_pis == min_pis;
bool not_cross_neg_1 =
(min_pis % 2 == 0) ? max_pis == min_pis : max_pis - min_pis <= 1;

Domain new_min =
not_cross_neg_1 ? std::min(std::cos(lower), std::cos(upper)) : -1.0;
not_cross_neg_1 ? std::min(std::cos(lower), std::cos(upper)) : -one;
Domain new_max =
not_cross_pos_1 ? std::max(std::cos(lower), std::cos(upper)) : 1.0;
not_cross_pos_1 ? std::max(std::cos(lower), std::cos(upper)) : one;
return {new_min, new_max};
}

Interval tan() const {
if (is_const()) return constant(std::tan(lower));
int64_t min_pis = static_cast<int64_t>(std::floor((lower + kHalfPi) / kPi));
int64_t max_pis = static_cast<int64_t>(std::floor((upper + kHalfPi) / kPi));
int64_t min_pis = static_cast<int64_t>(std::floor(
(lower + static_cast<Domain>(kHalfPi)) / static_cast<Domain>(kPi)));
int64_t max_pis = static_cast<int64_t>(std::floor(
(upper + static_cast<Domain>(kHalfPi)) / static_cast<Domain>(kPi)));
if (min_pis != max_pis)

Check warning on line 210 in src/sdf/interval.h

View check run for this annotation

Codecov / codecov/patch

src/sdf/interval.h#L204-L210

Added lines #L204 - L210 were not covered by tests
return {-std::numeric_limits<Domain>::infinity(),
std::numeric_limits<Domain>::infinity()};
Expand Down

0 comments on commit 7cbc4de

Please sign in to comment.