Skip to content

Commit

Permalink
Reinstate black-compatible isort (#450)
Browse files Browse the repository at this point in the history
* Reinstate isort

* Run isort

* Fix ops.AssociativeOp issue

* Ignore cuda warning during tests

* Attempt to fix pattern registration

* Another attempt

* Fix merge conflict

Co-authored-by: eb8680 <[email protected]>
  • Loading branch information
fritzo and eb8680 authored Feb 3, 2021
1 parent 8f55725 commit 155ad4c
Show file tree
Hide file tree
Showing 38 changed files with 74 additions and 57 deletions.
2 changes: 2 additions & 0 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,14 @@ docs: FORCE
lint: FORCE
flake8
black --check .
isort --check .

license: FORCE
python scripts/update_headers.py

format: FORCE
black .
isort .

test: lint FORCE
ifeq (${FUNSOR_BACKEND}, torch)
Expand Down
2 changes: 1 addition & 1 deletion examples/discrete_hmm.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@
import torch

import funsor
import funsor.torch.distributions as dist
import funsor.ops as ops
import funsor.torch.distributions as dist
from funsor.interpreter import interpretation, reinterpret
from funsor.optimizer import apply_optimizer
from funsor.terms import lazy
Expand Down
4 changes: 2 additions & 2 deletions examples/eeg_slds.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,13 @@
from urllib.request import urlopen

import numpy as np
import pyro
import torch
import torch.nn as nn
import pyro

import funsor
import funsor.torch.distributions as dist
import funsor.ops as ops
import funsor.torch.distributions as dist
from funsor.pyro.convert import (
funsor_to_cat_and_mvn,
funsor_to_mvn,
Expand Down
2 changes: 1 addition & 1 deletion examples/kalman_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@
import torch

import funsor
import funsor.torch.distributions as dist
import funsor.ops as ops
import funsor.torch.distributions as dist
from funsor.interpreter import interpretation, reinterpret
from funsor.optimizer import apply_optimizer
from funsor.terms import lazy
Expand Down
4 changes: 2 additions & 2 deletions examples/mixed_hmm/experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,15 +10,15 @@
import pyro
import pyro.poutine as poutine
import torch
from model import Guide, Model
from seal_data import prepare_fake, prepare_seal

import funsor
import funsor.ops as ops
from funsor.interpreter import interpretation
from funsor.optimizer import apply_optimizer
from funsor.sum_product import MarkovProduct, naive_sequential_sum_product, sum_product
from funsor.terms import lazy, to_funsor
from model import Guide, Model
from seal_data import prepare_fake, prepare_seal


def aic_num_parameters(model, guide=None):
Expand Down
2 changes: 1 addition & 1 deletion examples/mixed_hmm/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@
import torch
from torch.distributions import constraints

import funsor.torch.distributions as dist
import funsor.ops as ops
import funsor.torch.distributions as dist
from funsor.domains import Bint, Reals
from funsor.tensor import Tensor
from funsor.terms import Stack, Variable, to_funsor
Expand Down
4 changes: 2 additions & 2 deletions examples/sensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,8 @@
from torch.optim import Adam

import funsor
import funsor.torch.distributions as f_dist
import funsor.ops as ops
import funsor.torch.distributions as f_dist
from funsor.domains import Reals
from funsor.pyro.convert import dist_to_funsor, funsor_to_mvn
from funsor.tensor import Tensor, Variable
Expand Down Expand Up @@ -241,8 +241,8 @@ def main(args):
import matplotlib

matplotlib.use("Agg")
from matplotlib import pyplot
import numpy as np
from matplotlib import pyplot

seeds = set(seed for seed, _, _ in results)
X = args.num_frames
Expand Down
2 changes: 1 addition & 1 deletion examples/slds.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@
import torch

import funsor
import funsor.torch.distributions as dist
import funsor.ops as ops
import funsor.torch.distributions as dist


def main(args):
Expand Down
2 changes: 1 addition & 1 deletion examples/vae.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,8 @@
from torchvision import datasets, transforms

import funsor
import funsor.torch.distributions as dist
import funsor.ops as ops
import funsor.torch.distributions as dist
from funsor.domains import Bint, Reals

REPO_PATH = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
Expand Down
3 changes: 1 addition & 2 deletions funsor/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

from funsor.domains import Array, Bint, Domain, Real, Reals, bint, find_domain, reals
from funsor.integrate import Integrate
from funsor.interpreter import reinterpret, interpretation
from funsor.interpreter import interpretation, reinterpret
from funsor.sum_product import MarkovProduct
from funsor.tensor import Tensor, function
from funsor.terms import (
Expand Down Expand Up @@ -42,7 +42,6 @@
testing,
)


__all__ = [
"Array",
"Bint",
Expand Down
2 changes: 1 addition & 1 deletion funsor/affine.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

from funsor.interpreter import gensym
from funsor.tensor import Einsum, Tensor, get_default_prototype
from funsor.terms import Binary, Funsor, Lambda, Reduce, Unary, Variable, Bint
from funsor.terms import Binary, Bint, Funsor, Lambda, Reduce, Unary, Variable

from . import ops

Expand Down
1 change: 0 additions & 1 deletion funsor/distribution.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,6 @@
)
from funsor.util import broadcast_shape, get_backend, getargspec, lazy_property


BACKEND_TO_DISTRIBUTIONS_BACKEND = {
"torch": "funsor.torch.distributions",
"jax": "funsor.jax.distributions",
Expand Down
1 change: 0 additions & 1 deletion funsor/einsum/numpy_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
# SPDX-License-Identifier: Apache-2.0

import operator

from functools import reduce

import funsor.ops as ops
Expand Down
2 changes: 1 addition & 1 deletion funsor/integrate.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import funsor.ops as ops
from funsor.cnf import Contraction, GaussianMixture
from funsor.delta import Delta
from funsor.gaussian import Gaussian, align_gaussian, _mv, _trace_mm, _vv
from funsor.gaussian import Gaussian, _mv, _trace_mm, _vv, align_gaussian
from funsor.tensor import Tensor
from funsor.terms import (
Funsor,
Expand Down
15 changes: 9 additions & 6 deletions funsor/jax/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,20 +4,23 @@
from jax.core import Tracer
from jax.interpreters.xla import DeviceArray

import funsor.jax.distributions # noqa: F401
import funsor.jax.ops # noqa: F401
import funsor.ops as ops
from funsor.adjoint import adjoint_ops
from funsor.interpreter import children, recursion_reinterpret
from funsor.terms import Funsor, to_funsor
from funsor.ops import AssociativeOp
from funsor.tensor import Tensor, tensor_to_funsor
from funsor.terms import Funsor, to_funsor
from funsor.util import quote

from . import distributions as _
from . import ops as _

del _ # flake8


@adjoint_ops.register(
Tensor,
ops.AssociativeOp,
ops.AssociativeOp,
AssociativeOp,
AssociativeOp,
Funsor,
(DeviceArray, Tracer),
tuple,
Expand Down
4 changes: 2 additions & 2 deletions funsor/jax/distributions.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,11 @@

import numpyro.distributions as dist

import funsor.ops as ops
from funsor.cnf import Contraction
from funsor.distribution import (
Bernoulli,
FUNSOR_DIST_NAMES,
Bernoulli,
LogNormal,
backenddist_to_funsor,
eager_beta,
Expand Down Expand Up @@ -37,7 +38,6 @@
transformeddist_to_funsor,
)
from funsor.domains import Real, Reals
import funsor.ops as ops
from funsor.tensor import Tensor
from funsor.terms import Binary, Funsor, Reduce, Variable, eager, to_data, to_funsor
from funsor.util import methodof
Expand Down
2 changes: 1 addition & 1 deletion funsor/jax/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from jax.scipy.linalg import cho_solve, solve_triangular
from jax.scipy.special import expit, gammaln, logsumexp

import funsor.ops as ops
from .. import ops

################################################################################
# Register Ops
Expand Down
1 change: 1 addition & 0 deletions funsor/ops/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# Copyright Contributors to the Pyro project.
# SPDX-License-Identifier: Apache-2.0

from . import array, builtin, op
from .array import *
from .builtin import *
from .op import *
3 changes: 3 additions & 0 deletions funsor/ops/builtin.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,6 +181,9 @@ def sigmoid_log_abs_det_jacobian(x, y):
PRODUCT_INVERSES[add] = sub

__all__ = [
"AssociativeOp",
"GetitemOp",
"NullOp",
"abs",
"add",
"and_",
Expand Down
2 changes: 1 addition & 1 deletion funsor/pyro/distribution.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@
from collections import OrderedDict

import torch
from torch.distributions import constraints
from pyro.distributions import TorchDistribution
from torch.distributions import constraints

from funsor.cnf import Contraction
from funsor.delta import Delta
Expand Down
13 changes: 7 additions & 6 deletions funsor/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,12 @@
from multipledispatch.variadic import Variadic

import funsor
import funsor.ops as ops
from funsor.delta import Delta
from funsor.domains import Array, ArrayType, Bint, Product, Real, Reals, find_domain
from funsor.ops import GetitemOp, MatmulOp, Op, ReshapeOp
from funsor.terms import (

from . import ops
from .delta import Delta
from .domains import Array, ArrayType, Bint, Product, Real, Reals, find_domain
from .ops import GetitemOp, MatmulOp, Op, ReshapeOp
from .terms import (
Binary,
Funsor,
FunsorMeta,
Expand All @@ -34,7 +35,7 @@
to_data,
to_funsor,
)
from funsor.util import get_backend, get_tracing_state, getargspec, is_nn_module, quote
from .util import get_backend, get_tracing_state, getargspec, is_nn_module, quote


def get_default_prototype():
Expand Down
4 changes: 2 additions & 2 deletions funsor/testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,10 @@
import funsor.ops as ops
from funsor.cnf import Contraction
from funsor.delta import Delta
from funsor.domains import Domain, Bint, Real
from funsor.domains import Bint, Domain, Real
from funsor.gaussian import Gaussian
from funsor.terms import Funsor, Number
from funsor.tensor import Tensor
from funsor.terms import Funsor, Number
from funsor.util import get_backend


Expand Down
13 changes: 8 additions & 5 deletions funsor/torch/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,18 +4,21 @@
import torch
from multipledispatch import dispatch

import funsor.torch.distributions # noqa: F401
import funsor.torch.ops # noqa: F401
import funsor.ops as ops
from funsor.adjoint import adjoint_ops
from funsor.interpreter import children, recursion_reinterpret
from funsor.terms import Funsor, to_funsor
from funsor.ops import AssociativeOp
from funsor.tensor import Tensor, tensor_to_funsor
from funsor.terms import Funsor, to_funsor
from funsor.util import quote

from . import distributions as _
from . import ops as _

del _ # flake8


@adjoint_ops.register(
Tensor, ops.AssociativeOp, ops.AssociativeOp, Funsor, torch.Tensor, tuple, object
Tensor, AssociativeOp, AssociativeOp, Funsor, torch.Tensor, tuple, object
)
def adjoint_tensor(adj_redop, adj_binop, out_adj, data, inputs, dtype):
return {}
Expand Down
2 changes: 2 additions & 0 deletions funsor/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,12 +178,14 @@ def set_backend(backend):
_FUNSOR_BACKEND = "torch"

import torch # noqa: F401

import funsor.torch # noqa: F401
elif backend == "jax":
_FUNSOR_BACKEND = "jax"
_JAX_LOADED = True

import jax # noqa: F401

import funsor.jax # noqa: F401
else:
raise ValueError(
Expand Down
2 changes: 1 addition & 1 deletion scripts/update_headers.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
# Copyright Contributors to the Pyro project.
# SPDX-License-Identifier: Apache-2.0

import os
import glob
import os

root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
blacklist = ["/build/", "/dist/"]
Expand Down
5 changes: 5 additions & 0 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,11 @@ per-file-ignores =
funsor/jax/distributions.py:F821
funsor/torch/distributions.py:F821

[isort]
profile = black
known_first_party = funsor, test
known_third_party = opt_einsum, pyro, pyroapi, torch, torchvision

[tool:pytest]
filterwarnings = error
ignore:numpy.ufunc size changed:RuntimeWarning
Expand Down
2 changes: 2 additions & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@
"test": [
"black",
"flake8",
"isort>=5.0",
"pandas",
"pyro-api>=0.1.2",
"pytest==4.3.1",
Expand All @@ -61,6 +62,7 @@
"dev": [
"black",
"flake8",
"isort>=5.0",
"pandas",
"pytest==4.3.1",
"pytest-xdist==1.27.0",
Expand Down
2 changes: 1 addition & 1 deletion test/examples/test_bart.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@
import torch

import funsor
import funsor.torch.distributions as dist
import funsor.ops as ops
import funsor.torch.distributions as dist
from funsor.cnf import Contraction
from funsor.domains import Bint, Real, Reals
from funsor.gaussian import Gaussian
Expand Down
2 changes: 1 addition & 1 deletion test/examples/test_sensor_fusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@
import pytest
import torch

import funsor.torch.distributions as dist
import funsor.ops as ops
import funsor.torch.distributions as dist
from funsor.cnf import Contraction
from funsor.domains import Bint, Reals
from funsor.gaussian import Gaussian
Expand Down
Loading

0 comments on commit 155ad4c

Please sign in to comment.