Commit

formatting?

Mcilie committed Nov 2, 2023
1 parent a7420d1 commit 107194a
Showing 4 changed files with 29 additions and 36 deletions.
31 changes: 10 additions & 21 deletions src/prompt_systematic_review/pipeline.py
@@ -4,8 +4,6 @@
import os




"""
READ THIS
https://docs.github.com/en/actions/security-guides/using-secrets-in-github-actions
@@ -14,34 +12,31 @@
"""




class Pipeline:

def __init__(self, token=None,revision="main"):
def __init__(self, token=None, revision="main"):
self.token = token
self.root = f"hf://datasets/PromptSystematicReview/Prompt_Systematic_Review_Dataset@{revision}/"
if token is not None:
self.fs = HfFileSystem(token=token)
login(token=token)
else:
self.fs = HfFileSystem()
self.revision=revision
self.revision = revision

def is_logged_in(self):
return self.token is not None

def get_revision(self):
return self.revision
def set_revision(self,revision):

def set_revision(self, revision):
try:
assert revision.isalnum()
self.revision=revision
self.revision = revision
self.root = f"hf://datasets/PromptSystematicReview/Prompt_Systematic_Review_Dataset@{revision}/"
except:
raise ValueError("Revision must be alphanumeric")

def login(self, token):
if self.token is not None:
raise ValueError("Already Logged In")
@@ -51,22 +46,16 @@ def login(self, token):
self.token = token

def get_all_files(self):
return self.fs.ls(self.root, detail=False,revision=self.revision)
return self.fs.ls(self.root, detail=False, revision=self.revision)

def get_all_data_files(self):
return self.fs.glob(self.root+"**.csv",revision=self.revision)
return self.fs.glob(self.root + "**.csv", revision=self.revision)

def read_from_file(self, fileName):
return pd.read_csv(os.path.join(self.root, fileName))


def write_to_file(self, fileName, dataFrame):
if not self.is_logged_in():
raise ValueError("Not Logged In")
path = os.path.join(self.root, fileName)
dataFrame.to_csv(path, index=False)





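For orientation, the Pipeline class reformatted above is the dataset client that the tests below exercise. A minimal usage sketch, assuming the package is importable as in the tests and that HF_AUTH_TOKEN is exported in the environment (per the GitHub Actions secrets link in the module docstring); the copy filename is illustrative:

import os
from prompt_systematic_review.pipeline import Pipeline

# Token comes from the HF_AUTH_TOKEN environment variable, mirroring tests/test_pipeline.py.
client = Pipeline(token=os.environ["HF_AUTH_TOKEN"], revision="test")
assert client.is_logged_in()

print(client.get_all_data_files())         # every *.csv at the pinned revision
df = client.read_from_file("test.csv")     # load one CSV as a pandas DataFrame
client.write_to_file("test_copy.csv", df)  # hypothetical filename; writing requires a token

write_to_file raises "Not Logged In" when the client was constructed without a token, which is why the test fixture below passes the token up front.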
3 changes: 0 additions & 3 deletions tests/test_paper.py
@@ -4,7 +4,6 @@
from prompt_systematic_review.utils import process_paper_title



def test_paper():
paper1 = Paper(
"How to write a paper",
@@ -32,7 +31,6 @@ def test_paper():
assert paper1 != paper2 and paper2 != alsoPaper1



def test_arxiv_source():
# test that arXiv source returns papers properly
arxiv_source = ArXivSource()
@@ -60,4 +58,3 @@ def test_arxiv_source():
assert paper.keywords == [
"foundational models in medical imaging: a comprehensive survey and future vision"
]

30 changes: 19 additions & 11 deletions tests/test_pipeline.py
@@ -11,46 +11,54 @@
def hashString(bytes):
return str(hashlib.md5(bytes).hexdigest())


@pytest.fixture
def client():
return Pipeline(token=os.environ['HF_AUTH_TOKEN'],revision="test")
return Pipeline(token=os.environ["HF_AUTH_TOKEN"], revision="test")


def test_login():
testClient = Pipeline(revision="test")
assert testClient.is_logged_in() == False
testClient.login(os.environ['HF_AUTH_TOKEN'])
testClient.login(os.environ["HF_AUTH_TOKEN"])
assert testClient.is_logged_in() == True


def test_get_all_files(client):
assert len(client.get_all_files()) > 0


def test_get_all_data_files(client):
assert len(client.get_all_data_files()) > 0
assert all([x.endswith(".csv") for x in client.get_all_data_files()])


def test_read_from_file(client):
assert len(client.read_from_file("test.csv")) > 0
assert len(client.read_from_file("test.csv").columns) == 2
assert client.read_from_file("test.csv")["Age"].mean() == 21


def test_write_to_file(client):
lenOfFiles = len(client.get_all_files())
randString = random.randbytes(100) + str(time.time()).encode()
randHash = hashString(randString)
csvDict = {"test":[1,3], "test2":[2,4]}
csvDict = {"test": [1, 3], "test2": [2, 4]}
print(client.revision)
client.write_to_file(f"{randHash[:10]}_test.csv", pd.DataFrame(csvDict))
print(client.revision)
time.sleep(1)
# assert client.revision == "main"
df = client.read_from_file(f"{randHash[:10]}_test.csv")
df = client.read_from_file(f"{randHash[:10]}_test.csv")
assert df["test"].sum() == 4
assert df["test2"].sum() == 6
#time.sleep(1)
print(client.root+f"{randHash[:10]}_test.csv")
delete_file(f"{randHash[:10]}_test.csv", "PromptSystematicReview/Prompt_Systematic_Review_Dataset", repo_type="dataset", revision="test")

assert len(client.get_all_files()) == lenOfFiles
# time.sleep(1)
print(client.root + f"{randHash[:10]}_test.csv")
delete_file(
f"{randHash[:10]}_test.csv",
"PromptSystematicReview/Prompt_Systematic_Review_Dataset",
repo_type="dataset",
revision="test",
)



assert len(client.get_all_files()) == lenOfFiles
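None of the tests above touch set_revision, so, for completeness, a small sketch of how it behaves based on the implementation in pipeline.py; the revision names are only illustrative and must exist on the Hub for subsequent reads to succeed:

from prompt_systematic_review.pipeline import Pipeline

client = Pipeline(revision="test")      # anonymous, read-only client
assert client.get_revision() == "test"

client.set_revision("main")             # root now ends in ...@main/
print(client.root)

# Non-alphanumeric names are rejected:
# client.set_revision("my-branch")      # raises ValueError("Revision must be alphanumeric")

Because set_revision wraps its assert in a bare try/except, any failure inside it surfaces as the same "Revision must be alphanumeric" ValueError.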
1 change: 0 additions & 1 deletion tests/test_utils.py
@@ -2,6 +2,5 @@
from prompt_systematic_review.utils import process_paper_title



def test_process_paper_title():
assert process_paper_title("Laws of the\n Wildabeest") == "laws of the wildabeest"
