diff --git a/src/sdf/context.cpp b/src/sdf/context.cpp index 1711318eb..6182c9934 100644 --- a/src/sdf/context.cpp +++ b/src/sdf/context.cpp @@ -22,6 +22,29 @@ #include "manifold/optional_assert.h" +struct AffineValue { + // value = var * a + b + int var; + double a; + double b; + + AffineValue(int var, double a, double b) : var(var), a(a), b(b) {} + AffineValue(double constant) + : var(std::numeric_limits::max()), a(0.0), b(constant) {} + bool operator==(const AffineValue &other) const { + return var == other.var && a == other.a && b == other.b; + } +}; + +template <> +struct std::hash { + size_t operator()(const AffineValue &value) const { + size_t h = std::hash()(value.var); + hash_combine(h, value.a, value.b); + return h; + } +}; + namespace manifold::sdf { void Context::dump() const { #ifdef MANIFOLD_DEBUG @@ -33,7 +56,8 @@ void Context::dump() const { if (operand.isResult()) std::cout << "r" << operand.toInstIndex(); else if (operand.isConst()) - std::cout << constants[operand.toConstIndex()]; + std::cout << constants[operand.toConstIndex()] << "(" << operand.id + << ")"; else std::cout << static_cast('X' - operand.id - 1); std::cout << " "; @@ -63,8 +87,6 @@ Operand Context::addInstruction(Instruction inst) { case OpCode::MIN: case OpCode::MAX: case OpCode::EQ: - case OpCode::AND: - case OpCode::OR: case OpCode::FMA: // first two operands commutative, sort them // makes it more likely to find common subexpressions @@ -101,7 +123,7 @@ std::optional Context::trySimplify(Instruction inst) { switch (op) { case OpCode::NOP: case OpCode::RETURN: - case OpCode::CONST: + case OpCode::CONSTANT: case OpCode::STORE: case OpCode::LOAD: break; @@ -128,8 +150,6 @@ std::optional Context::trySimplify(Instruction inst) { case OpCode::MAX: case OpCode::EQ: case OpCode::GT: - case OpCode::AND: - case OpCode::OR: result = EvalContext::handle_binary( op, constants[operands[0].toConstIndex()], constants[operands[1].toConstIndex()]); @@ -188,33 +208,11 @@ std::optional Context::trySimplify(Instruction inst) { return {}; } -Instruction Context::strengthReduction(Instruction inst) { - // strength reduction: reduce instructions to simpler variants - // not very helpful for point evaluation in a vm because instruction decoding - // is the most time consuming part. - // This can be useful if we want to do JIT, interval evaluation or bulk - // evaluation. - if (inst.op == OpCode::MUL && inst.operands[1].isConst() && - constants[inst.operands[1].toConstIndex()] == 2.0) { - // x * 2 => x + x - return {OpCode::ADD, {inst.operands[0], inst.operands[1], Operand::none()}}; - } - if (inst.op == OpCode::DIV && inst.operands[1].isConst()) { - // x / c => x * (1/c) - return {OpCode::MUL, - {inst.operands[0], - addConstant(1.0 / constants[inst.operands[1].toConstIndex()]), - Operand::none()}}; - } - return inst; -} - // bypass the cache because we don't expect to have more common subexpressions // after optimizations Operand Context::addInstructionNoCache(Instruction inst) { auto simplified = trySimplify(inst); if (simplified.has_value()) return simplified.value(); - inst = strengthReduction(inst); size_t i = instructions.size(); instructions.push_back(inst); @@ -234,7 +232,22 @@ Context::UsesVector::const_iterator findUse(const Context::UsesVector &uses, return std::lower_bound(uses.cbegin(), uses.cend(), inst); } -void Context::peephole() { +void Context::addUse(Operand operand, size_t inst) { + if (!operand.isResult() && !operand.isConst()) return; + auto uses = getUses(operand); + auto iter = findUse(*uses, inst); + if (iter == uses->cend() || *iter != inst) uses->insert(iter, inst); +} + +void Context::removeUse(Operand operand, size_t inst) { + if (!operand.isResult() && !operand.isConst()) return; + auto uses = getUses(operand); + auto iter = findUse(*uses, inst); + if (*iter == inst) uses->erase(iter); +} + +void Context::combineFMA() { + const auto none = Operand::none(); auto tryApply = [&](size_t i, Operand lhs, Operand rhs) { if (!lhs.isResult()) return false; auto lhsInst = lhs.toInstIndex(); @@ -244,19 +257,12 @@ void Context::peephole() { Operand b = instructions[lhsInst].operands[1]; instructions[i] = {OpCode::FMA, {a, b, rhs}}; // remove instruction - auto none = Operand::none(); instructions[lhsInst] = {OpCode::NOP, {none, none, none}}; // update uses, note that we need to maintain the order of the indices opUses[lhsInst].clear(); auto updateUses = [&](Operand x) { - if (!x.isResult() && !x.isConst()) return; - auto uses = getUses(x); - auto iter1 = findUse(*uses, lhsInst); - DEBUG_ASSERT(*iter1 == lhsInst, logicErr, "expected use"); - uses->erase(iter1); - auto iter2 = findUse(*uses, i); - // make sure there is no duplicate - if (iter2 == uses->cend() || *iter2 != i) uses->insert(iter2, i); + removeUse(x, lhsInst); + addUse(x, i); }; updateUses(a); if (a != b) updateUses(b); @@ -273,6 +279,281 @@ void Context::peephole() { } } +void Context::optimizeAffine() { + const auto none = Operand::none(); + std::vector affineValues; + affineValues.reserve(instructions.size()); + unordered_map avcache; + + auto getConstant = [&](Operand operand) -> std::optional { + if (operand.isConst()) return constants[operand.toConstIndex()]; + if (operand.isResult() && affineValues[operand.toInstIndex()].a == 0.0) + return affineValues[operand.toInstIndex()].b; + return {}; + }; + + auto replaceInst = [&](int from, int to) { + auto fromInst = Operand{from + 1}; + auto toInst = Operand{to + 1}; + for (auto use : opUses[from]) { + for (auto &operand : instructions[use].operands) + if (operand == fromInst) operand = toInst; + } + opUses[from].clear(); + instructions[from] = {OpCode::NOP, {none, none, none}}; + }; + + // abstract interpretation to figure out affine values for each instruction, + // and replace them as appropriate + // note that we still need constant propagation because this abstract + // interpretation can generate constants + for (size_t i = 0; i < instructions.size(); i++) { + auto &inst = instructions[i]; + AffineValue result = AffineValue(static_cast(i), 1, 0); + switch (inst.op) { + // notably, neg is special among these unary opcode + case OpCode::ABS: + case OpCode::EXP: + case OpCode::LOG: + case OpCode::SQRT: + case OpCode::FLOOR: + case OpCode::CEIL: + case OpCode::ROUND: + case OpCode::SIN: + case OpCode::COS: + case OpCode::TAN: + case OpCode::ASIN: + case OpCode::ACOS: + case OpCode::ATAN: { + auto x = getConstant(inst.operands[0]); + if (x.has_value()) + result = AffineValue( + EvalContext::handle_unary(inst.op, x.value())); + break; + } + case OpCode::NEG: + if (inst.operands[0].isConst()) + result = AffineValue(-constants[inst.operands[0].toConstIndex()]); + else if (inst.operands[0].isResult()) { + auto av = affineValues[inst.operands[0].toInstIndex()]; + result = AffineValue(av.var, -av.a, -av.b); + } + break; + case OpCode::DIV: { + // TODO: handle the case where lhs is divisible by rhs despite rhs is + // not a constant + auto rhs = getConstant(inst.operands[1]); + if (rhs.has_value()) { + if (inst.operands[0].isConst()) { + result = AffineValue(constants[inst.operands[0].toConstIndex()] / + rhs.value()); + } else if (inst.operands[0].isResult()) { + auto av = affineValues[inst.operands[0].toInstIndex()]; + result = + AffineValue(av.var, av.a / rhs.value(), av.b / rhs.value()); + } + } + break; + } + case OpCode::MOD: + case OpCode::MIN: + case OpCode::MAX: + case OpCode::EQ: + case OpCode::GT: { + // TODO: we can do better than just constant propagation... + auto lhs = getConstant(inst.operands[0]); + auto rhs = getConstant(inst.operands[1]); + if (lhs.has_value() && rhs.has_value()) + result = AffineValue(EvalContext::handle_binary( + inst.op, lhs.value(), rhs.value())); + break; + } + case OpCode::ADD: { + auto x = inst.operands[0]; + auto y = inst.operands[1]; + auto lhs = getConstant(x); + auto rhs = getConstant(y); + if (lhs.has_value() && rhs.has_value()) { + result = AffineValue(lhs.value() + rhs.value()); + } else if (lhs.has_value() && y.isResult()) { + result = affineValues[y.toInstIndex()]; + result.b += lhs.value(); + } else if (rhs.has_value() && x.isResult()) { + result = affineValues[x.toInstIndex()]; + result.b += rhs.value(); + } else if (x.isResult() && y.isResult()) { + if (affineValues[x.toInstIndex()].var == + affineValues[y.toInstIndex()].var) { + auto other = affineValues[y.toInstIndex()]; + result = affineValues[x.toInstIndex()]; + result.a += other.a; + result.b += other.b; + } + } + } + case OpCode::SUB: { + auto x = inst.operands[0]; + auto y = inst.operands[1]; + auto lhs = getConstant(x); + auto rhs = getConstant(y); + if (lhs.has_value() && rhs.has_value()) { + result = AffineValue(lhs.value() - rhs.value()); + } else if (lhs.has_value() && y.isResult()) { + result = affineValues[y.toInstIndex()]; + result.a = -result.a; + result.b = lhs.value() - result.b; + } else if (rhs.has_value() && x.isResult()) { + result = affineValues[x.toInstIndex()]; + result.b -= rhs.value(); + } else if (x.isResult() && y.isResult()) { + if (affineValues[x.toInstIndex()].var == + affineValues[y.toInstIndex()].var) { + auto other = affineValues[y.toInstIndex()]; + result = affineValues[x.toInstIndex()]; + result.a -= other.a; + result.b -= other.b; + } + } + break; + } + case OpCode::MUL: { + auto x = inst.operands[0]; + auto y = inst.operands[1]; + auto lhs = getConstant(x); + auto rhs = getConstant(y); + if (lhs.has_value() && rhs.has_value()) { + result = AffineValue(lhs.value() * rhs.value()); + } else if (lhs.has_value() && y.isResult()) { + result = affineValues[y.toInstIndex()]; + result.a *= lhs.value(); + result.b *= lhs.value(); + } else if (rhs.has_value() && x.isResult()) { + result = affineValues[x.toInstIndex()]; + result.a *= rhs.value(); + result.b *= rhs.value(); + } + break; + } + default: + // TODO: handle FMA as well? + break; + } + affineValues.push_back(result); + if (result.var != static_cast(i)) { + // we did evaluate something + auto pair = avcache.insert({result, static_cast(i)}); + if (!pair.second) { + // this result is being optimized away, replace uses with the value + replaceInst(static_cast(i), pair.first->second); + } else { + for (auto operand : inst.operands) removeUse(operand, i); + addUse(Operand{result.var + 1}, i); + // modify instruction + // FIXME: handle constant uses... + if (result.a == 1.0 && result.b == 0.0) { + // this result is being optimized away, replace uses with the value + pair.first->second = result.var; + replaceInst(static_cast(i), result.var); + } else if (result.a == 1.0) { + auto constant = addConstant(result.b); + addUse(constant, i); + instructions[i] = {OpCode::ADD, + {constant, Operand{result.var + 1}, none}}; + } else if (result.a == -1.0) { + auto constant = addConstant(result.b); + addUse(constant, i); + instructions[i] = {OpCode::SUB, + {constant, Operand{result.var + 1}, none}}; + } else if (result.b == 0.0) { + auto constant = addConstant(result.a); + addUse(constant, i); + instructions[i] = {OpCode::MUL, + {constant, Operand{result.var + 1}, none}}; + } else { + auto a = addConstant(result.a); + auto b = addConstant(result.b); + addUse(a, i); + addUse(b, i); + instructions[i] = {OpCode::FMA, {a, Operand{result.var + 1}, b}}; + } + } + } + } +} + +void Context::schedule() { + cache.clear(); + opUses.clear(); + for (auto &uses : constantUses) uses.clear(); + auto oldInstructions = std::move(this->instructions); + // compute depth in DG + std::vector levelMap; + levelMap.reserve(oldInstructions.size()); + for (size_t i = 0; i < oldInstructions.size(); i++) { + const auto &inst = oldInstructions[i]; + size_t maxLevel = 0; + for (auto operand : inst.operands) { + if (!operand.isResult()) continue; + maxLevel = std::max(maxLevel, levelMap[operand.toInstIndex()]); + } + levelMap.push_back(maxLevel + 1); + } + + std::vector computedInst(oldInstructions.size(), Operand::none()); + std::vector stack; + if (oldInstructions.back().operands[0].isResult()) + stack.push_back(oldInstructions.back().operands[0].toInstIndex()); + + auto requiresComputation = [&computedInst](Operand operand) { + return operand.isResult() && computedInst[operand.toInstIndex()].isNone(); + }; + auto toNewOperand = [&computedInst](Operand old) { + if (old.isResult()) return computedInst[old.toInstIndex()]; + return old; + }; + + while (!stack.empty()) { + int numResults = 0; + auto back = stack.back(); + if (!computedInst[back].isNone()) { + stack.pop_back(); + continue; + } + auto &inst = oldInstructions[back]; + std::array costs = {0, 0, 0}; + std::array ids = {0, 1, 2}; + for (auto i : ids) + if (requiresComputation(inst.operands[i])) { + numResults += 1; + costs[i] = levelMap[inst.operands[i].toInstIndex()]; + } + if (numResults > 0) { + std::sort(ids.begin(), ids.end(), + [&costs](size_t x, size_t y) { return costs[x] < costs[y]; }); + for (size_t x : ids) + if (requiresComputation(inst.operands[x])) + stack.push_back(inst.operands[x].toInstIndex()); + } else { + stack.pop_back(); + std::array newOperands; + for (int i : ids) newOperands[i] = toNewOperand(inst.operands[i]); + Operand result = addInstructionNoCache({inst.op, newOperands}); + computedInst[back] = result; + } + } + addInstructionNoCache( + {OpCode::RETURN, + {computedInst[oldInstructions.back().operands[0].toInstIndex()], + Operand::none(), Operand::none()}}); +} + +void Context::optimize() { + optimizeAffine(); + combineFMA(); + schedule(); + dump(); +} + struct RegEntry { size_t nextUse; Operand operand; @@ -353,7 +634,7 @@ std::pair, size_t> Context::genTape() { if (iter == spills.end()) { DEBUG_ASSERT(operand.isConst(), logicErr, "can only materialize constants"); - tape.insert(tape.end(), {static_cast(OpCode::CONST), reg}); + tape.insert(tape.end(), {static_cast(OpCode::CONSTANT), reg}); addImmediate(tape, constants[operand.toConstIndex()]); } else { tape.insert(tape.end(), {static_cast(OpCode::LOAD), reg}); @@ -391,7 +672,8 @@ std::pair, size_t> Context::genTape() { // insert it back with new next use // because it is not at the end of its lifetime, the incremented // iterator is guaranteed to be valid - insertRegCache({*(findUse(*uses, inst) + 1), instOperands[i], regs[i]}); + auto nextUse = *(findUse(*uses, inst) + 1); + insertRegCache({nextUse, instOperands[i], regs[i]}); } } return regs; @@ -399,11 +681,10 @@ std::pair, size_t> Context::genTape() { for (size_t i = 0; i < instructions.size(); i++) { auto &inst = instructions[i]; - if (inst.op == OpCode::NOP) continue; auto instOp = Operand{static_cast(i) + 1}; auto uses = getUses(instOp); - // avoid useless ops - if (inst.op != OpCode::RETURN && uses->empty()) continue; + if (inst.op == OpCode::NOP) continue; + // if (inst.op != OpCode::RETURN && uses->empty()) continue; auto tmp = handleOperands(inst.operands, i); if (inst.op == OpCode::RETURN) { tape.insert(tape.end(), {static_cast(inst.op), tmp[0]}); @@ -411,7 +692,11 @@ std::pair, size_t> Context::genTape() { } // note that we may spill the operand register, but that is fine uint8_t reg = allocateReg(); - insertRegCache({uses->front(), instOp, reg}); + if (uses->empty()) { + availableReg.push_back(reg); + } else { + insertRegCache({uses->front(), instOp, reg}); + } tape.insert(tape.end(), {static_cast(inst.op), reg}); for (size_t j : {0, 1, 2}) { if (inst.operands[j].isNone()) break; diff --git a/src/sdf/context.h b/src/sdf/context.h index 34d6187d9..f9c95870b 100644 --- a/src/sdf/context.h +++ b/src/sdf/context.h @@ -80,11 +80,11 @@ struct std::hash { namespace manifold::sdf { class Context { public: - using UsesVector = small_vector; + using UsesVector = std::vector; Operand addConstant(double d); Operand addInstruction(Instruction); - void peephole(); + void optimize(); void reschedule(); std::pair, size_t> genTape(); @@ -106,8 +106,12 @@ class Context { unordered_map cache; std::optional trySimplify(Instruction); - Instruction strengthReduction(Instruction); Operand addInstructionNoCache(Instruction); + void combineFMA(); + void optimizeAffine(); + void addUse(Operand operand, size_t inst); + void removeUse(Operand operand, size_t inst); + void schedule(); UsesVector* getUses(Operand operand) { if (operand.isResult()) { diff --git a/src/sdf/tape.h b/src/sdf/tape.h index f4c403f6a..77cba4611 100644 --- a/src/sdf/tape.h +++ b/src/sdf/tape.h @@ -27,7 +27,7 @@ namespace manifold::sdf { enum class OpCode : uint8_t { NOP, RETURN, - CONST, + CONSTANT, STORE, LOAD, @@ -54,8 +54,6 @@ enum class OpCode : uint8_t { MAX, EQ, GT, - AND, - OR, // fast binary operations ADD, @@ -115,7 +113,7 @@ struct EvalContext { Domain x = buffer[tape[i + 2]]; buffer[tape[i + 1]] = handle_unary(current, x); i += 3; - } else if (current == OpCode::CONST) { + } else if (current == OpCode::CONSTANT) { double x; std::memcpy(&x, tape.data() + i + 2, sizeof(x)); buffer[tape[i + 1]] = Domain(x); @@ -192,10 +190,6 @@ inline double EvalContext::handle_binary(OpCode op, double lhs, return lhs == rhs ? 1.0 : 0.0; case OpCode::GT: return lhs > rhs ? 1.0 : 0.0; - case OpCode::AND: - return (lhs == 1.0 && rhs == 1.0) ? 1.0 : 0.0; - case OpCode::OR: - return (lhs == 1.0 || rhs == 1.0) ? 1.0 : 0.0; default: return 0; } @@ -204,7 +198,7 @@ inline double EvalContext::handle_binary(OpCode op, double lhs, template <> inline double EvalContext::handle_choice(double cond, double lhs, double rhs) { - if (cond == 1.0) return lhs; + if (cond != 0.0) return lhs; return rhs; } @@ -272,10 +266,6 @@ inline Interval EvalContext>::handle_binary( return lhs == rhs; case OpCode::GT: return lhs > rhs; - case OpCode::AND: - return lhs.logical_and(rhs); - case OpCode::OR: - return lhs.logical_or(rhs); default: return {0.0, 0.0}; } @@ -285,7 +275,7 @@ template <> inline Interval EvalContext>::handle_choice( Interval cond, Interval lhs, Interval rhs) { if (cond.is_const()) { - if (cond.lower == 1.0) return lhs; + if (cond.lower != 0.0) return lhs; return rhs; } return lhs.merge(rhs); @@ -297,7 +287,7 @@ inline std::string dumpOpCode(OpCode op) { return "NOP"; case OpCode::RETURN: return "RETURN"; - case OpCode::CONST: + case OpCode::CONSTANT: return "CONST"; case OpCode::LOAD: return "LOAD"; @@ -343,10 +333,6 @@ inline std::string dumpOpCode(OpCode op) { return "EQ"; case OpCode::GT: return "GT"; - case OpCode::AND: - return "AND"; - case OpCode::OR: - return "OR"; case OpCode::ADD: return "ADD"; case OpCode::SUB: diff --git a/src/sdf/value.cpp b/src/sdf/value.cpp index cc9a7197e..b9f0606b8 100644 --- a/src/sdf/value.cpp +++ b/src/sdf/value.cpp @@ -14,6 +14,8 @@ #include "value.h" +#include + #include "../utils.h" #include "context.h" #include "tape.h" @@ -91,12 +93,12 @@ Value Value::operator>(const Value& other) const { Value Value::operator&&(const Value& other) const { return Value(ValueKind::OPERATION, std::make_shared( - OpCode::AND, *this, other, Invalid())); + OpCode::MUL, *this, other, Invalid())); } Value Value::operator||(const Value& other) const { return Value(ValueKind::OPERATION, std::make_shared( - OpCode::OR, *this, other, Invalid())); + OpCode::ADD, *this, other, Invalid())); } Value Value::abs() const { @@ -205,13 +207,8 @@ std::pair, size_t> Value::genTape() const { Context ctx; unordered_map cache; std::vector stack; - cache.reserve(128); - stack.reserve(128); - - if (kind == ValueKind::OPERATION) stack.push_back(std::get(v).get()); - - auto none = Operand::none(); + const auto none = Operand::none(); bool ready = true; auto getOperand = [&](const Value& x, bool pushStack) { switch (x.kind) { @@ -236,7 +233,13 @@ std::pair, size_t> Value::genTape() const { return none; } }; + + auto start = std::chrono::high_resolution_clock::now(); + if (kind == ValueKind::OPERATION) stack.push_back(std::get(v).get()); + + int count = 0; while (!stack.empty()) { + count++; ready = true; auto current = stack.back(); Operand a = getOperand(current->operands[0], true); @@ -252,8 +255,28 @@ std::pair, size_t> Value::genTape() const { Operand result = getOperand(*this, false); ctx.addInstruction({OpCode::RETURN, {result, none, none}}); - ctx.peephole(); - return ctx.genTape(); + auto end = std::chrono::high_resolution_clock::now(); + auto time = static_cast( + std::chrono::duration_cast(end - start) + .count()); + printf("serialization: %dus with %d nodes\n", time, count); + start = std::chrono::high_resolution_clock::now(); + ctx.optimize(); + end = std::chrono::high_resolution_clock::now(); + time = static_cast( + std::chrono::duration_cast(end - start) + .count()); + printf("optimize: %dus\n", time); + + start = std::chrono::high_resolution_clock::now(); + auto tape = ctx.genTape(); + end = std::chrono::high_resolution_clock::now(); + time = static_cast( + std::chrono::duration_cast(end - start) + .count()); + printf("codegen: %dus with length %ld\n", time, tape.first.size()); + + return tape; } } // namespace manifold::sdf diff --git a/test/sdf_tape_test.cpp b/test/sdf_tape_test.cpp index aef8c46b9..2b13a2171 100644 --- a/test/sdf_tape_test.cpp +++ b/test/sdf_tape_test.cpp @@ -81,7 +81,7 @@ TEST(TAPE, Gyroid) { ctxSimple.buffer[0] = x; ctxSimple.buffer[1] = y; ctxSimple.buffer[2] = z; - ASSERT_NEAR(ctxSimple.eval(), gyroid({x, y, z}), 1e-12); + ASSERT_NEAR(ctxSimple.eval(), gyroid({x, y, z}), 1e-6); } } } @@ -176,3 +176,26 @@ TEST(TAPE, Blobs) { .count()); printf("interval evaluation: %dus\n", time); } + +TEST(TAPE, Blobs2) { + auto lengthFn = [](Value x, Value y, Value z) { + return (x * x + y * y + z * z).sqrt(); + }; + auto smoothstepFn = [](Value edge0, Value edge1, Value a) { + auto x = ((a - edge0) / (edge1 - edge0)) + .min(Value::Constant(1)) + .max(Value::Constant(0)); + return x * x * (Value::Constant(3) - Value::Constant(2) * x); + }; + Value d = Value::Constant(0); + for (int i = 0; i < 1000; i++) { + auto f = double(i + 1); + auto tmp = smoothstepFn( + Value::Constant(-1), Value::Constant(1), + Value::Constant(f).abs() - lengthFn(Value::Constant(f) - Value::X(), + Value::Constant(f) - Value::Y(), + Value::Constant(f) - Value::Z())); + d = d + tmp; + } + auto tape = d.genTape(); +}