Skip to content

Commit

Permalink
[tools] add brainpy.tools.compose and brainpy.tools.pipe
Browse files Browse the repository at this point in the history
  • Loading branch information
chaoming0625 committed Feb 18, 2024
1 parent fb08523 commit e4f857e
Show file tree
Hide file tree
Showing 4 changed files with 222 additions and 1 deletion.
192 changes: 192 additions & 0 deletions brainpy/_src/tools/functions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,192 @@
import inspect
from functools import partial
from operator import attrgetter
from types import MethodType

__all__ = [
'compose', 'pipe'
]


def identity(x):
""" Identity function. Return x
>>> identity(3)
3
"""
return x


def instanceproperty(fget=None, fset=None, fdel=None, doc=None, classval=None):
""" Like @property, but returns ``classval`` when used as a class attribute
>>> class MyClass(object):
... '''The class docstring'''
... @instanceproperty(classval=__doc__)
... def __doc__(self):
... return 'An object docstring'
... @instanceproperty
... def val(self):
... return 42
...
>>> MyClass.__doc__
'The class docstring'
>>> MyClass.val is None
True
>>> obj = MyClass()
>>> obj.__doc__
'An object docstring'
>>> obj.val
42
"""
if fget is None:
return partial(instanceproperty, fset=fset, fdel=fdel, doc=doc,
classval=classval)
return InstanceProperty(fget=fget, fset=fset, fdel=fdel, doc=doc,
classval=classval)


class InstanceProperty(property):
""" Like @property, but returns ``classval`` when used as a class attribute
Should not be used directly. Use ``instanceproperty`` instead.
"""

def __init__(self, fget=None, fset=None, fdel=None, doc=None,
classval=None):
self.classval = classval
property.__init__(self, fget=fget, fset=fset, fdel=fdel, doc=doc)

def __get__(self, obj, type=None):
if obj is None:
return self.classval
return property.__get__(self, obj, type)

def __reduce__(self):
state = (self.fget, self.fset, self.fdel, self.__doc__, self.classval)
return InstanceProperty, state


class Compose(object):
""" A composition of functions
See Also:
compose
"""
__slots__ = 'first', 'funcs'

def __init__(self, funcs):
funcs = tuple(reversed(funcs))
self.first = funcs[0]
self.funcs = funcs[1:]

def __call__(self, *args, **kwargs):
ret = self.first(*args, **kwargs)
for f in self.funcs:
ret = f(ret)
return ret

def __getstate__(self):
return self.first, self.funcs

def __setstate__(self, state):
self.first, self.funcs = state

@instanceproperty(classval=__doc__)
def __doc__(self):
def composed_doc(*fs):
"""Generate a docstring for the composition of fs.
"""
if not fs:
# Argument name for the docstring.
return '*args, **kwargs'

return '{f}({g})'.format(f=fs[0].__name__, g=composed_doc(*fs[1:]))

try:
return (
'lambda *args, **kwargs: ' +
composed_doc(*reversed((self.first,) + self.funcs))
)
except AttributeError:
# One of our callables does not have a `__name__`, whatever.
return 'A composition of functions'

@property
def __name__(self):
try:
return '_of_'.join(
(f.__name__ for f in reversed((self.first,) + self.funcs))
)
except AttributeError:
return type(self).__name__

def __repr__(self):
return '{.__class__.__name__}{!r}'.format(
self, tuple(reversed((self.first,) + self.funcs)))

def __eq__(self, other):
if isinstance(other, Compose):
return other.first == self.first and other.funcs == self.funcs
return NotImplemented

def __ne__(self, other):
equality = self.__eq__(other)
return NotImplemented if equality is NotImplemented else not equality

def __hash__(self):
return hash(self.first) ^ hash(self.funcs)

# Mimic the descriptor behavior of python functions.
# i.e. let Compose be called as a method when bound to a class.
# adapted from
# docs.python.org/3/howto/descriptor.html#functions-and-methods
def __get__(self, obj, objtype=None):
return self if obj is None else MethodType(self, obj)

# introspection with Signature is only possible from py3.3+
@instanceproperty
def __signature__(self):
base = inspect.signature(self.first)
last = inspect.signature(self.funcs[-1])
return base.replace(return_annotation=last.return_annotation)

__wrapped__ = instanceproperty(attrgetter('first'))


def compose(*funcs):
""" Compose functions to operate in series.
Returns a function that applies other functions in sequence.
Functions are applied from right to left so that
``compose(f, g, h)(x, y)`` is the same as ``f(g(h(x, y)))``.
If no arguments are provided, the identity function (f(x) = x) is returned.
>>> inc = lambda i: i + 1
>>> compose(str, inc)(3)
'4'
"""
if not funcs:
return identity
if len(funcs) == 1:
return funcs[0]
else:
return Compose(funcs)


def pipe(*funcs):
""" Pipe a value through a sequence of functions
I.e. ``pipe(f, g, h)(data)`` is equivalent to ``h(g(f(data)))``
We think of the value as progressing through a pipe of several
transformations, much like pipes in UNIX
>>> double = lambda i: 2 * i
>>> pipe(double, str)(3)
'6'
"""
return compose(*reversed(funcs))
24 changes: 24 additions & 0 deletions brainpy/_src/tools/tests/test_functions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@

import unittest

import brainpy as bp
import brainpy.math as bm


class TestFunction(unittest.TestCase):
def test_compose(self):
f = lambda a: a + 1
g = lambda a: a * 10
fun1 = bp.tools.compose(f, g)
fun2 = bp.tools.pipe(g, f)

arr = bm.random.randn(10)
r1 = fun1(arr)
r2 = fun2(arr)
groundtruth = f(g(arr))
self.assertTrue(bm.allclose(r1, r2))
self.assertTrue(bm.allclose(r1, groundtruth))
bm.clear_buffer_memory()



5 changes: 5 additions & 0 deletions brainpy/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,4 +45,9 @@
)


from brainpy._src.tools.functions import (
compose as compose,
pipe as pipe,
)


2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@

# version
here = os.path.abspath(os.path.dirname(__file__))
with open(os.path.join(here, 'brainpy', '__init__.py'), 'r') as f:
with open(os.path.join(here, 'brainpy', 'test_functions.py'), 'r') as f:
init_py = f.read()
version = re.search('__version__ = "(.*)"', init_py).groups()[0]
if len(sys.argv) > 2 and sys.argv[2] == '--python-tag=py3':
Expand Down

0 comments on commit e4f857e

Please sign in to comment.