-
Notifications
You must be signed in to change notification settings - Fork 94
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[tools] add
brainpy.tools.compose
and brainpy.tools.pipe
- Loading branch information
1 parent
fb08523
commit e4f857e
Showing
4 changed files
with
222 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,192 @@ | ||
import inspect | ||
from functools import partial | ||
from operator import attrgetter | ||
from types import MethodType | ||
|
||
__all__ = [ | ||
'compose', 'pipe' | ||
] | ||
|
||
|
||
def identity(x): | ||
""" Identity function. Return x | ||
>>> identity(3) | ||
3 | ||
""" | ||
return x | ||
|
||
|
||
def instanceproperty(fget=None, fset=None, fdel=None, doc=None, classval=None): | ||
""" Like @property, but returns ``classval`` when used as a class attribute | ||
>>> class MyClass(object): | ||
... '''The class docstring''' | ||
... @instanceproperty(classval=__doc__) | ||
... def __doc__(self): | ||
... return 'An object docstring' | ||
... @instanceproperty | ||
... def val(self): | ||
... return 42 | ||
... | ||
>>> MyClass.__doc__ | ||
'The class docstring' | ||
>>> MyClass.val is None | ||
True | ||
>>> obj = MyClass() | ||
>>> obj.__doc__ | ||
'An object docstring' | ||
>>> obj.val | ||
42 | ||
""" | ||
if fget is None: | ||
return partial(instanceproperty, fset=fset, fdel=fdel, doc=doc, | ||
classval=classval) | ||
return InstanceProperty(fget=fget, fset=fset, fdel=fdel, doc=doc, | ||
classval=classval) | ||
|
||
|
||
class InstanceProperty(property): | ||
""" Like @property, but returns ``classval`` when used as a class attribute | ||
Should not be used directly. Use ``instanceproperty`` instead. | ||
""" | ||
|
||
def __init__(self, fget=None, fset=None, fdel=None, doc=None, | ||
classval=None): | ||
self.classval = classval | ||
property.__init__(self, fget=fget, fset=fset, fdel=fdel, doc=doc) | ||
|
||
def __get__(self, obj, type=None): | ||
if obj is None: | ||
return self.classval | ||
return property.__get__(self, obj, type) | ||
|
||
def __reduce__(self): | ||
state = (self.fget, self.fset, self.fdel, self.__doc__, self.classval) | ||
return InstanceProperty, state | ||
|
||
|
||
class Compose(object): | ||
""" A composition of functions | ||
See Also: | ||
compose | ||
""" | ||
__slots__ = 'first', 'funcs' | ||
|
||
def __init__(self, funcs): | ||
funcs = tuple(reversed(funcs)) | ||
self.first = funcs[0] | ||
self.funcs = funcs[1:] | ||
|
||
def __call__(self, *args, **kwargs): | ||
ret = self.first(*args, **kwargs) | ||
for f in self.funcs: | ||
ret = f(ret) | ||
return ret | ||
|
||
def __getstate__(self): | ||
return self.first, self.funcs | ||
|
||
def __setstate__(self, state): | ||
self.first, self.funcs = state | ||
|
||
@instanceproperty(classval=__doc__) | ||
def __doc__(self): | ||
def composed_doc(*fs): | ||
"""Generate a docstring for the composition of fs. | ||
""" | ||
if not fs: | ||
# Argument name for the docstring. | ||
return '*args, **kwargs' | ||
|
||
return '{f}({g})'.format(f=fs[0].__name__, g=composed_doc(*fs[1:])) | ||
|
||
try: | ||
return ( | ||
'lambda *args, **kwargs: ' + | ||
composed_doc(*reversed((self.first,) + self.funcs)) | ||
) | ||
except AttributeError: | ||
# One of our callables does not have a `__name__`, whatever. | ||
return 'A composition of functions' | ||
|
||
@property | ||
def __name__(self): | ||
try: | ||
return '_of_'.join( | ||
(f.__name__ for f in reversed((self.first,) + self.funcs)) | ||
) | ||
except AttributeError: | ||
return type(self).__name__ | ||
|
||
def __repr__(self): | ||
return '{.__class__.__name__}{!r}'.format( | ||
self, tuple(reversed((self.first,) + self.funcs))) | ||
|
||
def __eq__(self, other): | ||
if isinstance(other, Compose): | ||
return other.first == self.first and other.funcs == self.funcs | ||
return NotImplemented | ||
|
||
def __ne__(self, other): | ||
equality = self.__eq__(other) | ||
return NotImplemented if equality is NotImplemented else not equality | ||
|
||
def __hash__(self): | ||
return hash(self.first) ^ hash(self.funcs) | ||
|
||
# Mimic the descriptor behavior of python functions. | ||
# i.e. let Compose be called as a method when bound to a class. | ||
# adapted from | ||
# docs.python.org/3/howto/descriptor.html#functions-and-methods | ||
def __get__(self, obj, objtype=None): | ||
return self if obj is None else MethodType(self, obj) | ||
|
||
# introspection with Signature is only possible from py3.3+ | ||
@instanceproperty | ||
def __signature__(self): | ||
base = inspect.signature(self.first) | ||
last = inspect.signature(self.funcs[-1]) | ||
return base.replace(return_annotation=last.return_annotation) | ||
|
||
__wrapped__ = instanceproperty(attrgetter('first')) | ||
|
||
|
||
def compose(*funcs): | ||
""" Compose functions to operate in series. | ||
Returns a function that applies other functions in sequence. | ||
Functions are applied from right to left so that | ||
``compose(f, g, h)(x, y)`` is the same as ``f(g(h(x, y)))``. | ||
If no arguments are provided, the identity function (f(x) = x) is returned. | ||
>>> inc = lambda i: i + 1 | ||
>>> compose(str, inc)(3) | ||
'4' | ||
""" | ||
if not funcs: | ||
return identity | ||
if len(funcs) == 1: | ||
return funcs[0] | ||
else: | ||
return Compose(funcs) | ||
|
||
|
||
def pipe(*funcs): | ||
""" Pipe a value through a sequence of functions | ||
I.e. ``pipe(f, g, h)(data)`` is equivalent to ``h(g(f(data)))`` | ||
We think of the value as progressing through a pipe of several | ||
transformations, much like pipes in UNIX | ||
>>> double = lambda i: 2 * i | ||
>>> pipe(double, str)(3) | ||
'6' | ||
""" | ||
return compose(*reversed(funcs)) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,24 @@ | ||
|
||
import unittest | ||
|
||
import brainpy as bp | ||
import brainpy.math as bm | ||
|
||
|
||
class TestFunction(unittest.TestCase): | ||
def test_compose(self): | ||
f = lambda a: a + 1 | ||
g = lambda a: a * 10 | ||
fun1 = bp.tools.compose(f, g) | ||
fun2 = bp.tools.pipe(g, f) | ||
|
||
arr = bm.random.randn(10) | ||
r1 = fun1(arr) | ||
r2 = fun2(arr) | ||
groundtruth = f(g(arr)) | ||
self.assertTrue(bm.allclose(r1, r2)) | ||
self.assertTrue(bm.allclose(r1, groundtruth)) | ||
bm.clear_buffer_memory() | ||
|
||
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -45,4 +45,9 @@ | |
) | ||
|
||
|
||
from brainpy._src.tools.functions import ( | ||
compose as compose, | ||
pipe as pipe, | ||
) | ||
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters