Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Payable Feature #2

Open
wants to merge 12 commits into
base: master
Choose a base branch
from
50 changes: 50 additions & 0 deletions src/main/scala/edu/berkeley/cs/rise/quartz/PayableExtractor.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
package edu.berkeley.cs.rise.quartz

object PayableExtractor {

def extractPayableVars(stateMachine: StateMachine): (Set[String], Set[(String, String)], Set[(String, String)]) = {
var fields: Set[String] = Set.empty[String]
var params: Set[(String, String)] = Set.empty[(String, String)]
val structFields: Set[(String, String)] = Set.empty[(String, String)]

var previousSize = 0
do {
previousSize = fields.size + params.size + structFields.size
stateMachine.transitions foreach { transition =>
val transParams = transition.parameters.getOrElse(Seq.empty[Variable]).map(row => row.name)
transition.body.getOrElse(Seq.empty[Statement]) foreach {
case Send(destination, _, _) => {
val destVar = identifyVariable(destination)
if (transParams.contains(destVar)) {
params += ((transition.name, destVar))
} else {
fields += destVar
}
}
case Assignment(left, right) if fields.contains(left.rootName) ||
params.contains((transition.name, left.rootName)) => {
val rightVar = identifyVariable(right)
if (transParams.contains(rightVar)) {
params += ((transition.name, rightVar))
} else {
fields += rightVar
}
}
case _ => "default"
}
}
} while (previousSize != fields.size + params.size + structFields.size)

println("Fields:", fields)
println("Params:", params)
println("Struct:", structFields)
return (fields, params, structFields)
}

def identifyVariable(expression: Expression): String = expression match {
case MappingRef(map, key) => identifyVariable(map)
case VarRef(name) => name
case SequenceSize(sequence) => identifyVariable(sequence)
case _ => ""
}
}
46 changes: 10 additions & 36 deletions src/main/scala/edu/berkeley/cs/rise/quartz/Solidity.scala
Original file line number Diff line number Diff line change
Expand Up @@ -34,19 +34,17 @@ object Solidity {
case Bool => "bool"
case Timespan => "uint"
case HashValue(_) => "bytes32"
case Mapping(keyType, valueType) => s"mapping(${writeType(keyType, payable)} => ${writeType(valueType, payable)})"
case Mapping(keyType, valueType) => s"mapping(${writeType(keyType, false)} => ${writeType(valueType, payable)})"
case Sequence(elementType) => s"${writeType(elementType, payable)}[]"
case Struct(name) => name
}

private def writeStructDefinition(name: String, fields: Map[String, DataType]): String = {
private def writeStructDefinition(name: String, fields: Map[String, DataType], payableFields: Set[(String, String)]): String = {
val builder = new StringBuilder()
appendLine(builder, s"struct $name {")

indentationLevel += 1
fields.foreach { case (fName, fTy) =>
// TODO Determine when field must be marked payable
appendLine(builder, s"${writeType(fTy, payable = false)} $fName;")
appendLine(builder, s"${writeType(fTy, payable = payableFields.contains((name, fName)))} $fName;")
}
indentationLevel -= 1
appendLine(builder, "}")
Expand Down Expand Up @@ -200,14 +198,14 @@ object Solidity {
builder.toString()
}

private def writeTransition(transition: Transition, useCall: Boolean = false): String = {
private def writeTransition(transition: Transition, useCall: Boolean = false, payableParams: Set[(String, String)]): String = {
val builder = new StringBuilder()

val paramsRepr = transition.parameters.fold("") { params =>
// Remove parameters that are used in the original source but are built in to Solidity
val effectiveParams = params.filter(p => !BUILTIN_PARAMS.contains(p.name))
val payableParams = extractPayableVars(transition.body.getOrElse(Seq.empty[Statement]), effectiveParams.map(_.name).toSet)
writeParameters(effectiveParams.zip(effectiveParams.map(p => payableParams.contains(p.name))))
val payables = payableParams.filter(_._1.equals(transition.name)).map(_._2)
writeParameters(effectiveParams.zip(effectiveParams.map(p => payables.contains(p.name))))
}

val payable = if (transition.parameters.getOrElse(Seq.empty[Variable]).exists(_.name == "tokens")) {
Expand Down Expand Up @@ -485,9 +483,9 @@ object Solidity {
appendLine(builder, s"contract $name {")
indentationLevel += 1

val payableFields = extractPayableVars(stateMachine.flattenStatements, stateMachine.fields.map(_.name).toSet)
var (fields, params, structFields) = PayableExtractor.extractPayableVars(stateMachine)

stateMachine.structs.foreach { case (name, fields) => builder.append(writeStructDefinition(name, fields)) }
stateMachine.structs.foreach { case (name, fields) => builder.append(writeStructDefinition(name, fields, structFields)) }

appendLine(builder, "enum State {")
indentationLevel += 1
Expand All @@ -497,12 +495,12 @@ object Solidity {
indentationLevel -= 1
appendLine(builder, "}")

stateMachine.fields.foreach(f => appendLine(builder, writeField(f, payableFields.contains(f.name)) + ";"))
stateMachine.fields.foreach(f => appendLine(builder, writeField(f, fields.contains(f.name)) + ";"))
appendLine(builder, s"State public $CURRENT_STATE_VAR;")
builder.append(writeAuthorizationFields(stateMachine))
builder.append("\n")

stateMachine.transitions foreach { t => builder.append(writeTransition(t, useCall)) }
stateMachine.transitions foreach { t => builder.append(writeTransition(t, useCall, params)) }
extractAllMembershipTypes(stateMachine).foreach(ty => builder.append(writeSequenceContainsTest(ty) + "\n"))
builder.append("\n")

Expand Down Expand Up @@ -557,28 +555,4 @@ object Solidity {
expressionChecks
}
}

private def extractVarNames(expression: Expression): Set[String] = expression match {
case MappingRef(map, key) => extractVarNames(map) ++ extractVarNames(key)
case VarRef(name) => Set(name)
case LogicalOperation(left, _, right) => extractVarNames(left) ++ extractVarNames(right)
case ArithmeticOperation(left, _, right) => extractVarNames(left) ++ extractVarNames(right)
case SequenceSize(sequence) => extractVarNames(sequence)
case _ => Set.empty[String]
}

private def extractPayableVars(statements: Seq[Statement], scope: Set[String] = Set.empty[String]): Set[String] = {
val names = statements.foldLeft(Set.empty[String]) { (current, statement) =>
statement match {
case Send(destination, _, _) => current.union(extractVarNames(destination))
case _ => current
}
}

if (scope.nonEmpty) {
names.intersect(scope)
} else {
names
}
}
}
2 changes: 0 additions & 2 deletions src/main/scala/edu/berkeley/cs/rise/quartz/StateMachine.scala
Original file line number Diff line number Diff line change
Expand Up @@ -154,8 +154,6 @@ case class StateMachine(structs: Map[String, Map[String, DataType]], fields: Seq
}

def flattenExpressions: Seq[Expression] = transitions.flatMap(_.flattenExpressions())

def flattenStatements: Seq[Statement] = transitions.flatMap(_.body.getOrElse(Seq.empty[Statement]))
}

object StateMachine {
Expand Down
19 changes: 19 additions & 0 deletions src/test/resources/payable/field.qtz
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
contract Field {
data {
A: Identity
B: Identity
C: Identity
}

initialize: -> open {
B = A
}

test1: open -> open {
send 0 to A
}

test2: open -> open {
C = B
}
}
19 changes: 19 additions & 0 deletions src/test/resources/payable/field2.qtz
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
contract Field {
data {
A: Identity
B: Identity
C: Identity
}

initialize: -> open {
A = B
}

test1: open -> open {
send 0 to A
}

test2: open -> open {
B = C
}
}
14 changes: 14 additions & 0 deletions src/test/resources/payable/mapping.qtz
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
contract Mapping {
data {
Map1: Mapping[Uint, Identity]
Map2: Mapping[Identity, Identity]
}

initialize: ->(id: Uint) open {
send 0 to Map1[id]
}

test1: open ->(id: Identity) open {
send 0 to Map2[id]
}
}
11 changes: 11 additions & 0 deletions src/test/resources/payable/mapping2.qtz
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
contract Mapping {
data {
Map1: Mapping[Uint, Mapping[Identity, Identity]]
Map2: Mapping[Identity, Mapping[Uint, Identity]]
}

initialize: ->(id: Uint, id2: Identity) open {
send 0 to Map1[id][id2]
send 0 to Map2[id2][id]
}
}
19 changes: 19 additions & 0 deletions src/test/resources/payable/param.qtz
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
contract Param {
data {
A: Identity
B: Identity
}

initialize: ->(id: Identity) open {
A = id
A = B
}

test1: open -> open {
send 0 to A
}

test2: open ->(id: Identity) open {
B = id
}
}
15 changes: 15 additions & 0 deletions src/test/resources/payable/param2.qtz
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
contract Param {
data {
A: Identity
}

initialize: ->(id: Identity, id2: Identity, id3: Identity) open {
A = id
id = id2
id3 = id2
}

test1: open -> open {
send 0 to A
}
}
33 changes: 33 additions & 0 deletions src/test/resources/payable/struct.qtz
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
contract Struct {
struct Wrapper1 {
payee: Identity
}

struct Wrapper2 {
payee2: Identity
wrap: Wrapper1
}

struct Wrapper3 {
payee3: Identity
wrap: Wrapper2
}

data {
wrap1: Wrapper1
wrap2: Wrapper2
wrap3: Wrapper3
}

initialize: -> open {
send 0 to wrap1.payee
}

test1: open -> open {
send 0 to wrap2.wrap.payee
}

test2: open -> open {
send 0 to wrap3.wrap.wrap.payee
}
}
31 changes: 31 additions & 0 deletions src/test/resources/payable/struct2.qtz
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
contract Struct {
struct Wrapper1 {
payee: Identity
}

struct Wrapper2 {
map: Mapping[Identity, Wrapper1]
map2: Mapping[Identity, Identity]
}

struct Wrapper3 {
seq: Sequence[Wrapper1]
}

struct Wrapper4 {
payeeMap: Wrapper2
payeeSeq: Wrapper3
}

data {
wrap2: Wrapper2
wrap3: Wrapper3
wrap4: Wrapper4
}

initialize: ->(id: Identity) open {
send 0 to wrap2.map[id].payee
send 0 to wrap2.map2[id]
send 0 to wrap4.payeeMap.map[id].payee
}
}
24 changes: 24 additions & 0 deletions src/test/resources/payable/struct3.qtz
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
contract Struct {
struct Wrapper {
payee: Identity
payee2: Identity
}

data {
wrap: Wrapper
A: Identity
}

initialize: -> open {
A = wrap.payee
}

test1: open -> open {
send 0 to A
}

test2: open ->(id: Identity) open {
wrap.payee2 = id
send 0 to wrap.payee2
}
}