Add function calls #21

Open · wants to merge 6 commits into main
29 changes: 28 additions & 1 deletion dist_ir/executor/absint.py
@@ -78,6 +78,31 @@ def interpret_pmap(self, op: Op, state: AbstractState):
 
         return state
 
+    def interpret_function_call(self, op: Op, state: AbstractState):
+        # Find the op's inputs in state's environment and save environment
+        inputs = tuple(state.env[v] for v in op.inputs)
+        old_env = state.env
+        state.env = {}  # To enforce variable scoping
+
+        # Change state's function pointer to subfunction (TODO necessary?)
+        function = state.function
+        state.function = op.subfunctions[0]
+
+        # Interpret subfunction with appropriate inputs
+        self.interpret(op.subfunctions[0], inputs, state=state)
+
+        # Find the outputs from the state's env
+        results = tuple(state.env[v] for v in op.subfunctions[0].outputs)
+
+        # Put the results back into the state's environment
+        state.env = old_env
+        for x, val in zip(op.outputs, results):
+            state.env[x] = val
+        # Also reset state's function pointer
+        state.function = function
+
+        return state
+
     def interpret(
         self, function: Function, inputs: Sequence[Any], state: AbstractState = None
     ):
@@ -93,7 +118,9 @@ def interpret(
 
         # Execute ops in topological order:
         for op in function.ops:
-            if op.op_type == "Pmap":
+            if op.op_type == "FnCall":
+                self.interpret_function_call(op, state)
+            elif op.op_type == "Pmap":
                 self.interpret_pmap(op, state)
             else:
                 # Function dispatch:
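
The scoping in `interpret_function_call` is the usual save/restore pattern: look up arguments in the caller's environment, run the callee in a fresh environment, then restore the caller's environment and bind the callee's outputs. A minimal self-contained sketch of that pattern (plain dicts stand in for `AbstractState.env`; all names here are hypothetical, and `run_callee` stands in for `self.interpret` on the subfunction, which binds the callee's inputs itself):

```python
# Sketch of the save/restore scoping used by interpret_function_call above.
def call_scoped(env, op_inputs, op_outputs, callee_outputs, run_callee):
    args = tuple(env[v] for v in op_inputs)  # look up arguments in caller env
    saved = env                              # save the caller's environment
    callee_env = {}                          # fresh env enforces scoping
    run_callee(callee_env, args)             # interpret the subfunction
    results = tuple(callee_env[v] for v in callee_outputs)
    env = saved                              # restore the caller's environment
    for out, val in zip(op_outputs, results):
        env[out] = val                       # bind call results in the caller
    return env
```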
33 changes: 3 additions & 30 deletions dist_ir/executor/sequential_executor.py
@@ -1,4 +1,4 @@
-from typing import Any, Dict, List, Sequence
+from typing import Any, Sequence, Tuple
 
 from .absint import AbstractInterpreter, convert_impls_to_semantics
 from .backend_register import BackendRegister
@@ -12,42 +12,15 @@ def __init__(self, backend):
         semantics = convert_impls_to_semantics(BackendRegister[backend])
         self.interpreter = AbstractInterpreter(semantics=semantics)
 
-    def _compute_op(self, op: Op, inputs: List[Any]):
-        """Executes the given op and returns its outputs."""
-        op_type = op.op_type
-        if op_type == "Pmap":
-            # Zip the inputs so that we map over each corresponding value
-            inputs = zip(*inputs)
-            # Iterate over the inputs
-            results = []
-            for inps in inputs:
-                # Execute subfunction with appropriate inputs
-                outs = self.compute(op.subfunctions[0], inps)
-                # Match output names to output data using the function output order.
-                ordered_outs = [outs[e] for e in op.subfunctions[0].outputs]
-                results.append(ordered_outs)
-            # Unzip the results
-            results = tuple(zip(*results))
-            return results
-        if op_type not in BackendRegister[self._backend]:
-            raise NotImplementedError(
-                f"No {self._backend} implementation found for op {op_type}"
-            )
-        impl = BackendRegister[self._backend][op_type]
-        output_data = impl(op, inputs)
-        if not isinstance(output_data, tuple):
-            output_data = (output_data,)
-        return output_data
-
-    def compute(self, function: Function, inputs: Sequence[Any]) -> Dict[Value, Any]:
+    def compute(self, function: Function, inputs: Sequence[Any]) -> Tuple[Any, ...]:
         """Executes the function given the specified inputs and returns the final result.
 
         Args:
             function: The function to execute.
             inputs: A sequence of input data represented in the specified backend.
 
         Returns:
-            A map from output value to output data.
+            A tuple of outputs.
         """
         state = self.interpreter.interpret(function, inputs)
         return tuple(state.env[v] for v in function.outputs)
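
With `compute` now returning a positional tuple (in `function.outputs` order) rather than a `Dict[Value, Any]`, call sites unpack instead of indexing by `Value`. A sketch of the new contract, following the usage in the tests below (`fn`, `x_data`, and `w_data` are made-up names):

```python
# Hypothetical call site under the new tuple-returning contract.
ex = SequentialExecutor("numpy")
(y,) = ex.compute(fn, [x_data, w_data])  # single output: unpack a 1-tuple
# Multiple outputs arrive positionally, in fn.outputs order:
# y1, y2 = ex.compute(fn2, inputs)
```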
25 changes: 13 additions & 12 deletions dist_ir/executor/simulator.py
@@ -23,10 +23,10 @@ def __init__(self, function: Function, inputs: Sequence[Any]):
         self.consumers = defaultdict(int)
         self.trace = []
 
-    def add_trace_event(self, op_name, device, start_time, duration):
+    def add_trace_event(self, op_type, device, start_time, duration):
         self.trace.append(
             {
-                "name": op_name,
+                "name": op_type,
                 "ph": "X",
                 "ts": start_time,
                 "dur": duration,
@@ -70,30 +70,31 @@ def _simulate_op(
     # Update the trace and timestamps
     for device in costs:
         state.add_trace_event(
-            op.name,
+            op.op_type,
             device,
             state.timestamps[device],
             costs[device],
         )
         state.timestamps[device] += costs[device]
 
     # Update the live memory.
-    for out_edge in op.outputs:
-        state.consumers[out_edge] = len(state.function.consumers[out_edge])
-        # Output value could live on multiple devices (e.g. scatter) so
-        # update memory on all devices:
-        output_devices = out_edge.type.get_all_devices()
-        for output_device in output_devices:
-            state.live_memory[output_device] += out_edge.type.size()
+    for out_val, conc_val in zip(op.outputs, outputs):
+        if isinstance(conc_val, Type):
+            state.consumers[conc_val] = len(state.function.consumers[out_val])
+            # Output value could live on multiple devices (e.g. scatter) so
+            # update memory on all devices:
+            output_devices = conc_val.get_all_devices()
+            for output_device in output_devices:
+                state.live_memory[output_device] += conc_val.size()
     # TODO: Can we optimize this using a priority queue?
     for value in state.consumers:
         # TODO we are missing a decrement of state.consumers[value] somewhere
         if state.consumers[value] == 0 and all(
             value != v for v in state.function.inputs
         ):
-            value_devices = value.type.get_all_devices()
+            value_devices = value.get_all_devices()
             for device in value_devices:
-                state.live_memory[device] -= value.type.size()
+                state.live_memory[device] -= value.size()
 
     # Update the peak memory.
     for device in state.live_memory:
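
The live-memory update above is reference counting over abstract values: each output adds its `size()` to every device it lives on, and a value's memory is released once its consumer count reaches zero (the in-code TODO notes the decrement is still missing). A distilled single-device sketch with hypothetical names; `sizes` and `num_consumers` stand in for `Type.size()` and `state.function.consumers`, and the decrement the TODO mentions is included:

```python
# Single-device distillation of the accounting in _simulate_op.
def update_live_memory(live, consumers, produced, consumed, sizes, num_consumers):
    for v in produced:
        consumers[v] = num_consumers[v]   # how many ops still need v
        live += sizes[v]                  # allocate v's output buffer
    for v in consumed:
        consumers[v] -= 1                 # the decrement the TODO refers to
        if consumers[v] == 0:
            live -= sizes[v]              # last consumer done: free v
    return live
```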
7 changes: 4 additions & 3 deletions dist_ir/executor/type_inference.py
@@ -287,15 +287,16 @@ def _type_function(function: Function, type_map: Dict[Value, Type]) -> Function:
         # Invariant: inputs of op are already typed (as ops are toposorted)
         typed_inputs = tuple(value_map[inp] for inp in op.inputs)
 
-        # Recursively convert the subfunctions:
-        subfunctions = tuple(_type_function(fn, type_map) for fn in op.subfunctions)
+        # Recursively convert the subfunctions?
+        # TODO how to handle multiple calls to function with varying types/shapes?
+        # subfunctions = tuple(_type_function(fn, type_map) for fn in op.subfunctions)
 
         new_op = Op(
             op_type=op.op_type,
             name=op.name,
             inputs=typed_inputs,
             attributes=op.attributes,
-            subfunctions=subfunctions,
+            subfunctions=op.subfunctions,
             output_names=tuple(v.name for v in op.outputs),
             # Look up output types from type_map
             output_types=tuple(type_map[v] for v in op.outputs),
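
One possible answer to the TODO above is to specialize the callee per call site (monomorphization): type a fresh copy of the subfunction against each call's input types, so two `FnCall`s into the same function with different shapes never share one typed copy. This is only a sketch of that direction, not what this diff implements; `_type_function`'s behavior is assumed from the hunk above, and the helper is hypothetical:

```python
# Hypothetical per-call-site specialization for the TODO above.
def specialize_callee(callee, typed_inputs, type_map):
    # Seed the callee's own input Values with this call site's types,
    # then type the body as usual. Each call site gets its own copy.
    callee_types = dict(type_map)
    for v, typed in zip(callee.inputs, typed_inputs):
        callee_types[v] = typed.type
    return _type_function(callee, callee_types)

# subfunctions = tuple(specialize_callee(fn, typed_inputs, type_map)
#                      for fn in op.subfunctions)
```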
12 changes: 11 additions & 1 deletion dist_ir/ir/op.py
@@ -22,7 +22,17 @@ class Op:
     output_types: InitVar[Tuple[Type]] = None
 
     def __post_init__(self, output_names, output_types):
-        if self.op_type == "Pmap":
+        if self.op_type == "FnCall":
+            # Function calls. Subfunction 0 is the called function
+            assert len(self.subfunctions) == 1
+            # Number of inputs is arbitrary but positive
+            assert len(self.inputs) > 0
+            # Number of inputs matches subfunction
+            assert len(self.inputs) == len(self.subfunctions[0].inputs)
+            # Number of outputs is given by subfunction
+            num_outputs = len(self.subfunctions[0].outputs)
+
+        elif self.op_type == "Pmap":
             # Handle pmap specially
             assert len(self.subfunctions) == 1
             # Number of inputs is arbitrary but positive
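
The new `FnCall` branch validates arity at construction time. A quick illustration of the invariant it enforces, using the same `FunctionMaker` API as the tests below (assuming `add_op` constructs the `Op` immediately, so `__post_init__` runs at the call):

```python
# A two-input callee, as in the tests below.
layer = FunctionMaker()
x = layer.add_input_value("x", None)
w = layer.add_input_value("w", None)
_ = layer.add_op("MatMul", inputs=[x, w])
layer = layer.finalize()

caller = FunctionMaker()
a = caller.add_input_value("a", None)
# Passing one input to a two-input subfunction trips the new assertion:
# assert len(self.inputs) == len(self.subfunctions[0].inputs)
caller.add_op("FnCall", inputs=[a], subfunctions=[layer])  # AssertionError
```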
23 changes: 23 additions & 0 deletions test/test_sequential_executor.py
@@ -331,3 +331,26 @@ def test_pmap_dp():
     (res,) = ex.compute(function, [(x_0, x_1), (_wA, _wA), (_wB, _wB)])
     assert np.array_equal(res[0], np.matmul(np.matmul(x_0, _wA), _wB))
     assert np.array_equal(res[1], np.matmul(np.matmul(x_1, _wA), _wB))
+
+
+def test_function_call():
+    layer = FunctionMaker()
+    x = layer.add_input_value("x", None)
+    w = layer.add_input_value("w", None)
+    _ = layer.add_op("MatMul", inputs=[x, w])
+    layer = layer.finalize()
+    fn = FunctionMaker()
+    x = fn.add_input_value("x", None)
+    w1 = fn.add_input_value("w1", None)
+    w2 = fn.add_input_value("w2", None)
+    a1 = fn.add_op("FnCall", inputs=[x, w1], subfunctions=[layer])
+    _ = fn.add_op("FnCall", inputs=[a1, w2], subfunctions=[layer])
+    fn = fn.finalize()
+    cpprint(fn)
+
+    ex = SequentialExecutor("numpy")
+    _x = np.arange(16 * 4).reshape((16, 4))
+    _w1 = np.ones((4, 2))
+    _w2 = np.ones((2, 1))
+    (res,) = ex.compute(fn, [_x, _w1, _w2])
+    assert np.array_equal(res, np.matmul(np.matmul(_x, _w1), _w2))
24 changes: 24 additions & 0 deletions test/test_simulator.py
@@ -104,3 +104,27 @@ def test_chrome_trace():
         transformed_function, (v.type for v in transformed_function.inputs)
     )
     simulation.dump_chrome_trace("test/trace.json")
+
+
+def test_function_call():
+    topology = Topology()
+    d0 = topology.add_device("gpu")
+
+    layer = FunctionMaker()
+    x = layer.add_input_value("x", None)
+    w = layer.add_input_value("w", None)
+    _ = layer.add_op("MatMul", inputs=[x, w])
+    layer = layer.finalize()
+    fn = FunctionMaker()
+    x = fn.add_input_value("x", Tensor(Float(), (4, 5), device=d0))
+    w1 = fn.add_input_value("w1", Tensor(Float(), (5, 6), device=d0))
+    w2 = fn.add_input_value("w2", Tensor(Float(), (6, 2), device=d0))
+    a1 = fn.add_op("FnCall", inputs=[x, w1], subfunctions=[layer])
+    _ = fn.add_op("FnCall", inputs=[a1, w2], subfunctions=[layer])
+    fn = fn.finalize()
+    fn = infer_types(fn, fn.inputs)
+
+    device_speeds = {"gpu": 1.0e13}
+    simulator = Simulator(CostModel(topology, device_speeds))
+    simulation = simulator.interpret(fn, (v.type for v in fn.inputs))
+    simulation.dump_chrome_trace("test/trace.json")
19 changes: 19 additions & 0 deletions test/test_type_inference.py
@@ -201,5 +201,24 @@ def test_scatter():
     assert xs.type.types[1].device == d1
 
 
+def test_function_call():
+    layer = FunctionMaker()
+    x = layer.add_input_value("x", None)
+    w = layer.add_input_value("w", None)
+    _ = layer.add_op("MatMul", inputs=[x, w])
+    layer = layer.finalize()
+    fn = FunctionMaker()
+    x = fn.add_input_value("x", Tensor(Float(), (4, 5)))
+    w1 = fn.add_input_value("w1", Tensor(Float(), (5, 6)))
+    w2 = fn.add_input_value("w2", Tensor(Float(), (6, 2)))
+    a1 = fn.add_op("FnCall", inputs=[x, w1], subfunctions=[layer])
+    _ = fn.add_op("FnCall", inputs=[a1, w2], subfunctions=[layer])
+    fn = fn.finalize()
+
+    fn = infer_types(fn, [x, w1, w2])
+    y = fn.outputs[0]
+    assert y.type == Tensor(Float(), (4, 2))
+
+
 if __name__ == "__main__":
     test_pmap()