From 8c8c14e49360ff74c127ad8e64fa7fc3ae81ca73 Mon Sep 17 00:00:00 2001 From: Sasha Rush Date: Sat, 13 Feb 2021 21:45:15 -0500 Subject: [PATCH] Remove tests requiring genbmm (#97) * update tests * style * update setup * update --- setup.py | 8 +++++++- torch_struct/semirings/checkpoint.py | 4 +++- 2 files changed, 10 insertions(+), 2 deletions(-) diff --git a/setup.py b/setup.py index f9e514cd..9d367469 100644 --- a/setup.py +++ b/setup.py @@ -1,5 +1,8 @@ from setuptools import setup +with open("README.md", "r", encoding="utf-8") as fh: + long_description = fh.read() + setup( name="torch_struct", version="0.5", @@ -9,9 +12,12 @@ "torch_struct", "torch_struct.semirings", ], + long_description=long_description, package_data={"torch_struct": []}, - url="https://github.com/harvardnlp/pytorch_struct", + long_description_content_type="text/markdown", + url="https://github.com/harvardnlp/pytorch-struct", install_requires=["torch"], setup_requires=["pytest-runner"], tests_require=["pytest"], + python_requires='>=3.6', ) diff --git a/torch_struct/semirings/checkpoint.py b/torch_struct/semirings/checkpoint.py index b2dacba5..c4e10c4f 100644 --- a/torch_struct/semirings/checkpoint.py +++ b/torch_struct/semirings/checkpoint.py @@ -1,8 +1,10 @@ import torch +has_genbmm = False try: import genbmm from genbmm import BandedMatrix + has_genbmm = True except ImportError: pass @@ -52,7 +54,7 @@ def backward(ctx, grad_output): class _CheckpointSemiring(cls): @staticmethod def matmul(a, b): - if isinstance(a, genbmm.BandedMatrix): + if has_genbmm and isinstance(a, genbmm.BandedMatrix): lu = a.lu + b.lu ld = a.ld + b.ld c = _CheckBand.apply(a.data, a.lu, a.ld, b.data, b.lu, b.ld)