Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix IR conversion when an Event selector is accessed #2589

Merged
merged 4 commits into from
Oct 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion slither/core/solidity_types/function_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ def storage_size(self) -> Tuple[int, bool]:
def is_dynamic(self) -> bool:
return False

def __str__(self):
def __str__(self) -> str:
# Use x.type
# x.name may be empty
params = ",".join([str(x.type) for x in self._params])
Expand Down
7 changes: 4 additions & 3 deletions slither/core/variables/variable.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,14 +9,15 @@

if TYPE_CHECKING:
from slither.core.expressions.expression import Expression
from slither.core.declarations import Function

# pylint: disable=too-many-instance-attributes
class Variable(SourceMapping):
def __init__(self) -> None:
super().__init__()
self._name: Optional[str] = None
self._initial_expression: Optional["Expression"] = None
self._type: Optional[Type] = None
self._type: Optional[Union[List, Type, "Function", str]] = None
self._initialized: Optional[bool] = None
self._visibility: Optional[str] = None
self._is_constant = False
Expand Down Expand Up @@ -77,7 +78,7 @@ def name(self, name: str) -> None:
self._name = name

@property
def type(self) -> Optional[Type]:
def type(self) -> Optional[Union[List, Type, "Function", str]]:
return self._type

@type.setter
Expand Down Expand Up @@ -120,7 +121,7 @@ def visibility(self) -> Optional[str]:
def visibility(self, v: str) -> None:
self._visibility = v

def set_type(self, t: Optional[Union[List, Type, str]]) -> None:
def set_type(self, t: Optional[Union[List, Type, "Function", str]]) -> None:
if isinstance(t, str):
self._type = ElementaryType(t)
return
Expand Down
51 changes: 26 additions & 25 deletions slither/slithir/convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,6 @@
)
from slither.core.solidity_types.type import Type
from slither.core.solidity_types.type_alias import TypeAliasTopLevel, TypeAlias
from slither.core.variables.function_type_variable import FunctionTypeVariable
from slither.core.variables.state_variable import StateVariable
from slither.core.variables.variable import Variable
from slither.slithir.exceptions import SlithIRError
Expand Down Expand Up @@ -81,7 +80,7 @@
from slither.slithir.tmp_operations.tmp_new_structure import TmpNewStructure
from slither.slithir.variables import Constant, ReferenceVariable, TemporaryVariable
from slither.slithir.variables import TupleVariable
from slither.utils.function import get_function_id
from slither.utils.function import get_function_id, get_event_id
from slither.utils.type import export_nested_types_from_variable
from slither.utils.using_for import USING_FOR
from slither.visitors.slithir.expression_to_slithir import ExpressionToSlithIR
Expand Down Expand Up @@ -279,20 +278,6 @@ def is_temporary(ins: Operation) -> bool:
)


def _make_function_type(func: Function) -> FunctionType:
parameters = []
returns = []
for parameter in func.parameters:
v = FunctionTypeVariable()
v.name = parameter.name
parameters.append(v)
for return_var in func.returns:
v = FunctionTypeVariable()
v.name = return_var.name
returns.append(v)
return FunctionType(parameters, returns)


# endregion
###################################################################################
###################################################################################
Expand Down Expand Up @@ -793,12 +778,29 @@ def propagate_types(ir: Operation, node: "Node"): # pylint: disable=too-many-lo
assignment.set_node(ir.node)
assignment.lvalue.set_type(ElementaryType("bytes4"))
return assignment
if ir.variable_right == "selector" and isinstance(
ir.variable_left.type, (Function)
if ir.variable_right == "selector" and isinstance(ir.variable_left, (Event)):
# the event selector returns a bytes32, which is different from the error/function selector
# which returns a bytes4
assignment = Assignment(
ir.lvalue,
Constant(
str(get_event_id(ir.variable_left.full_name)), ElementaryType("bytes32")
),
ElementaryType("bytes32"),
)
assignment.set_expression(ir.expression)
assignment.set_node(ir.node)
assignment.lvalue.set_type(ElementaryType("bytes32"))
return assignment
if ir.variable_right == "selector" and (
isinstance(ir.variable_left.type, (Function))
):
assignment = Assignment(
ir.lvalue,
Constant(str(get_function_id(ir.variable_left.type.full_name))),
Constant(
str(get_function_id(ir.variable_left.type.full_name)),
ElementaryType("bytes4"),
),
ElementaryType("bytes4"),
)
assignment.set_expression(ir.expression)
Expand Down Expand Up @@ -826,10 +828,9 @@ def propagate_types(ir: Operation, node: "Node"): # pylint: disable=too-many-lo
targeted_function = next(
(x for x in ir_func.contract.functions if x.name == str(ir.variable_right))
)
t = _make_function_type(targeted_function)
ir.lvalue.set_type(t)
ir.lvalue.set_type(targeted_function)
elif isinstance(left, (Variable, SolidityVariable)):
t = ir.variable_left.type
t = left.type
elif isinstance(left, (Contract, Enum, Structure)):
t = UserDefinedType(left)
# can be None due to temporary operation
Expand All @@ -846,10 +847,10 @@ def propagate_types(ir: Operation, node: "Node"): # pylint: disable=too-many-lo
ir.lvalue.set_type(elems[elem].type)
else:
assert isinstance(type_t, Contract)
# Allow type propagtion as a Function
# Allow type propagation as a Function
# Only for reference variables
# This allows to track the selector keyword
# We dont need to check for function collision, as solc prevents the use of selector
# We don't need to check for function collision, as solc prevents the use of selector
# if there are multiple functions with the same name
f = next(
(f for f in type_t.functions if f.name == ir.variable_right),
Expand All @@ -858,7 +859,7 @@ def propagate_types(ir: Operation, node: "Node"): # pylint: disable=too-many-lo
if f:
ir.lvalue.set_type(f)
else:
# Allow propgation for variable access through contract's name
# Allow propagation for variable access through contract's name
# like Base_contract.my_variable
v = next(
(
Expand Down
13 changes: 13 additions & 0 deletions slither/utils/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,3 +12,16 @@ def get_function_id(sig: str) -> int:
digest = keccak.new(digest_bits=256)
digest.update(sig.encode("utf-8"))
return int("0x" + digest.hexdigest()[:8], 16)


def get_event_id(sig: str) -> int:
"""'
Return the event id of the given signature
Args:
sig (str)
Return:
(int)
"""
digest = keccak.new(digest_bits=256)
digest.update(sig.encode("utf-8"))
return int("0x" + digest.hexdigest(), 16)
47 changes: 47 additions & 0 deletions tests/unit/slithir/test_data/selector.sol
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
interface I{
function testFunction(uint a) external ;
}

contract A{
function testFunction() public{}
}

contract Test{
event TestEvent();
struct St{
uint a;
}
error TestError();

function testFunction(uint a) public {}


function testFunctionStructure(St memory s) public {}

function returnEvent() public returns (bytes32){
return TestEvent.selector;
}

function returnError() public returns (bytes4){
return TestError.selector;
}


function returnFunctionFromContract() public returns (bytes4){
return I.testFunction.selector;
}


function returnFunction() public returns (bytes4){
return this.testFunction.selector;
}

function returnFunctionWithStructure() public returns (bytes4){
return this.testFunctionStructure.selector;
}

function returnFunctionThroughLocaLVar() public returns(bytes4){
A a;
return a.testFunction.selector;
}
}
32 changes: 32 additions & 0 deletions tests/unit/slithir/test_selector.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
from pathlib import Path
from slither import Slither
from slither.slithir.operations import Assignment
from slither.slithir.variables import Constant

TEST_DATA_DIR = Path(__file__).resolve().parent / "test_data"


func_to_results = {
"returnEvent()": "16700440330922901039223184000601971290390760458944929668086539975128325467771",
"returnError()": "224292994",
"returnFunctionFromContract()": "890000139",
"returnFunction()": "890000139",
"returnFunctionWithStructure()": "1430834845",
"returnFunctionThroughLocaLVar()": "3781905051",
}


def test_enum_max_min(solc_binary_path) -> None:
solc_path = solc_binary_path("0.8.19")
slither = Slither(Path(TEST_DATA_DIR, "selector.sol").as_posix(), solc=solc_path)

contract = slither.get_contract_from_name("Test")[0]

for func_name, value in func_to_results.items():
f = contract.get_function_from_signature(func_name)
assignment = f.slithir_operations[0]
assert (
isinstance(assignment, Assignment)
and isinstance(assignment.rvalue, Constant)
and assignment.rvalue.value == value
)