diff --git a/.gitignore b/.gitignore
index b7a4738..4cf2657 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,7 +1,11 @@
+poetry.lock
+
experiments/
rtf*checkpoints/
tests/realtabformer/data/
+*.DS_Store
+
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
diff --git a/Makefile b/Makefile
index a103366..aaa30e3 100644
--- a/Makefile
+++ b/Makefile
@@ -14,7 +14,10 @@ lint:
pre-commit run -a --hook-stage manual $(hook)
test:
- pytest tests --cov-config pyproject.toml --numprocesses 4 --dist loadfile
+ poetry run pytest tests --cov-config pyproject.toml --numprocesses 4 --dist loadfile
+
+bump-version:
+ poetry version patch
pip-compile:
pip-compile -q -o -
@@ -23,13 +26,16 @@ secret-scan:
trufflehog --max_depth 1 --exclude_paths trufflehog-ignore.txt .
package: clean install
-	python setup.py sdist bdist_wheel
+# Build sdist and wheel into dist/ with Poetry (setup.py is removed by this change).
+	poetry build
test-pypi-upload: package
-	twine upload --repository testpypi dist/*
+# dist/ is already built by the `package` prerequisite, so --build is not passed here.
+	poetry publish --repository testpypi
pypi-upload: package
-	twine upload dist/*
+# dist/ is already built by the `package` prerequisite, so --build is not passed here.
+	poetry publish
install-test-requirements:
pip install -r test_requirements.txt
diff --git a/README.md b/README.md
index 6e64231..0a31b9e 100644
--- a/README.md
+++ b/README.md
@@ -1,4 +1,4 @@
-
+ [![Total downloads](https://static.pepy.tech/badge/realtabformer)](https://pepy.tech/project/realtabformer) [![Downloads per month](https://static.pepy.tech/badge/realtabformer/month)](https://pepy.tech/project/realtabformer) [![Downloads per week](https://static.pepy.tech/badge/realtabformer/week)](https://pepy.tech/project/realtabformer)
# REaLTabFormer
diff --git a/pyproject.toml b/pyproject.toml
index dfb7786..13d4fdf 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,47 +1,54 @@
[build-system]
-requires = ["setuptools>=61.0"]
-build-backend = "setuptools.build_meta"
+requires = ["poetry-core>=1.0.0"]
+build-backend = "poetry.core.masonry.api"
-[project]
+[tool.poetry]
name = "REaLTabFormer"
-dynamic = ["version"]
-authors = [
- { name="Aivin V. Solatorio", email="asolatorio@worldbank.org" },
-]
description = "A novel method for generating tabular and relational data using language models."
+authors = ["Aivin V. Solatorio <asolatorio@worldbank.org>"]
readme = "README.md"
-license = { file="LICENSE" }
-requires-python = ">=3.7"
+license = "MIT"
+version = "0.2.0"
+homepage = "https://github.com/avsolatorio/REaLTabFormer"
+documentation = "https://worldbank.github.io/REaLTabFormer/"
+
classifiers = [
"Programming Language :: Python :: 3",
"License :: OSI Approved :: MIT License",
- "Operating System :: OS Independent",
+ "Operating System :: OS Independent"
]
+
keywords = [
"REaLTabFormer", "deep learning", "tabular data",
"transformers", "data generation", "seq2seq model",
"synthetic data", "pytorch", "language models",
"synthetic data generation"
]
-dependencies = [
- "accelerate >= 0.20.3",
- "datasets >= 2.6.1",
- "numpy >= 1.21.6", # "numpy >= 1.23.4",
- "pandas >= 1.3.5", # "pandas >= 1.5.1",
- "scikit_learn >= 1.0.2", # "scikit_learn >= 1.1.3",
- "torch >= 1.13.0",
- "tqdm >= 4.64.1",
- "transformers >= 4.24.0",
- "shapely >= 1.8.5.post1",
-]
-[project.urls]
-"Homepage" = "https://github.com/avsolatorio/REaLTabFormer"
-"Documentation" = "https://avsolatorio.github.io/REaLTabFormer/"
+[tool.poetry.dependencies]
+python = ">=3.8"
+datasets = ">=2.6.1"
+numpy = ">=1.21.6" # ">=1.23.4"
+pandas = ">=1.3.5" # ">=1.5.1"
+scikit-learn = ">=1.0.2" # ">=1.1.3"
+tqdm = ">=4.64.1"
+transformers = {extras = ["torch", "sentencepiece"], version = ">=4.41.0"}
+shapely = ">=1.8.5.post1"
+
+[tool.poetry.urls]
+Homepage = "https://github.com/avsolatorio/REaLTabFormer"
+Documentation = "https://avsolatorio.github.io/REaLTabFormer/"
+
+[tool.poetry.scripts]
+realtabformer = "realtabformer:main"
-[tool.setuptools.packages.find]
-where = ["src"]
+[tool.poetry_bumpversion.file."src/realtabformer/VERSION"]
-[tool.setuptools.dynamic]
-# version = {attr = "realtabformer.__version__"}
-version = {file = "src/realtabformer/VERSION"}
+[tool.poetry.group.dev.dependencies]
+ipykernel = "^6.29.4"
+pytest = "^8.2.2"
+isort = "^5.13.2"
+black = "^24.4.2"
+bandit = "^1.7.9"
+trufflehog = "^2.2.1"
+pytest-mock = "^3.14.0"
diff --git a/setup.py b/setup.py
deleted file mode 100644
index a7d273d..0000000
--- a/setup.py
+++ /dev/null
@@ -1,4 +0,0 @@
-import setuptools
-
-if __name__ == "__main__":
- setuptools.setup(setup_requires=["setuptools_scm"], include_package_data=True)
diff --git a/src/realtabformer/VERSION b/src/realtabformer/VERSION
index 1180819..0ea3a94 100644
--- a/src/realtabformer/VERSION
+++ b/src/realtabformer/VERSION
@@ -1 +1 @@
-0.1.7
+0.2.0
diff --git a/src/realtabformer/realtabformer.py b/src/realtabformer/realtabformer.py
index 5837908..1ec1688 100644
--- a/src/realtabformer/realtabformer.py
+++ b/src/realtabformer/realtabformer.py
@@ -822,10 +822,10 @@ def _train_with_sensitivity(
loaded_model_path = None
if not load_from_best_mean_sensitivity:
- if (bdm_path / "pytorch_model.bin").exists():
+ if (bdm_path / "pytorch_model.bin").exists() or (bdm_path / "model.safetensors").exists():
loaded_model_path = bdm_path
else:
- if (mean_closest_bdm_path / "pytorch_model.bin").exists():
+ if (mean_closest_bdm_path / "pytorch_model.bin").exists() or (mean_closest_bdm_path / "model.safetensors").exists():
loaded_model_path = mean_closest_bdm_path
if loaded_model_path is None: