Skip to content

Commit

Permalink
FIXED PIPELINE
Browse files Browse the repository at this point in the history
  • Loading branch information
Mcilie committed Nov 2, 2023
1 parent ff23b02 commit a7420d1
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 17 deletions.
1 change: 1 addition & 0 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ install_requires =
jellyfish
datasets
huggingface_hub
black

[options.packages.find]
where=src
31 changes: 21 additions & 10 deletions src/prompt_systematic_review/pipeline.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from huggingface_hub import HfFileSystem
from huggingface_hub import HfFileSystem, login
import pandas as pd
from io import StringIO
import os
Expand All @@ -18,23 +18,36 @@

class Pipeline:

# Pipeline constructor: stores the token, sets the dataset root path, creates
# the HfFileSystem handle (built with the token when one is supplied, and in
# that case also calls the module-level login imported from huggingface_hub),
# and records the revision.
# NOTE(review): the two signature lines and the two self.root assignments
# below look like the before/after sides of a rendered diff whose +/- markers
# were lost. Presumably the second of each pair (default revision "main",
# hf:// root with the revision after "@") is the current code; confirm
# against the repository before relying on either variant.
def __init__(self, token=None,revision="test"):
def __init__(self, token=None,revision="main"):
self.token = token
self.root = "datasets/PromptSystematicReview/Prompt_Systematic_Review_Dataset/"
self.root = f"hf://datasets/PromptSystematicReview/Prompt_Systematic_Review_Dataset@{revision}/"
if token is not None:
self.fs = HfFileSystem(token=token)
login(token=token)
else:
self.fs = HfFileSystem()
self.revision=revision

def is_logged_in(self):
    """Return True when a token has been stored on this instance."""
    return self.token is not None

def get_revision(self):
    """Return the revision name currently stored on this instance."""
    return self.revision

def set_revision(self, revision):
    """Switch this pipeline to another dataset revision.

    Updates ``self.revision`` and rebuilds ``self.root``, which embeds the
    revision after ``@``.

    Args:
        revision: Revision name; must be a non-empty alphanumeric string.

    Raises:
        ValueError: If ``revision`` is not an alphanumeric string. Neither
            attribute is modified in that case.
    """
    # Validate explicitly instead of with ``assert``: assertions are removed
    # under ``python -O``, and a bare ``except`` would also intercept
    # KeyboardInterrupt / SystemExit.
    if not isinstance(revision, str) or not revision.isalnum():
        raise ValueError("Revision must be alphanumeric")
    self.revision = revision
    self.root = f"hf://datasets/PromptSystematicReview/Prompt_Systematic_Review_Dataset@{revision}/"

def login(self, token):
    """Authenticate this pipeline with a Hugging Face token.

    Args:
        token: Access token used to build the filesystem handle and passed
            to the module-level ``login`` imported from ``huggingface_hub``.

    Raises:
        ValueError: If a token is already stored on this instance.
    """
    if self.token is not None:
        raise ValueError("Already Logged In")
    # Build the handle from the argument: ``self.token`` is still None on
    # this path, so using it would create the filesystem without the token.
    self.fs = HfFileSystem(token=token)
    login(token=token)
    self.token = token

def get_all_files(self):
Expand All @@ -44,17 +57,15 @@ def get_all_data_files(self):
return self.fs.glob(self.root+"**.csv",revision=self.revision)

# Read the CSV named fileName under self.root and return it as a pandas
# DataFrame.
# NOTE(review): two return statements follow one another, so the second is
# unreachable as written. This looks like the before/after sides of a rendered
# diff with its +/- markers lost: the read_text + StringIO path versus passing
# the joined path straight to pd.read_csv. Confirm which variant is live, and
# whether the self.fs reset lines belong to it.
def read_from_file(self, fileName):
self.fs = None
self.fs = HfFileSystem(token=self.token)
text = self.fs.read_text(os.path.join(self.root, fileName),revision=self.revision)
return pd.read_csv(StringIO(text))
return pd.read_csv(os.path.join(self.root, fileName))


# Write dataFrame as CSV (index column omitted) to fileName under self.root.
# Raises ValueError("Not Logged In") when no token is stored on the instance.
# NOTE(review): the body contains both a self.fs.write_text(...) call and a
# dataFrame.to_csv(path, ...) call for the same path. This looks like the
# before/after sides of a rendered diff with its +/- markers lost; confirm
# which write is live. The self.fs reset re-creates the filesystem handle with
# the stored token, presumably to drop cached state (TODO confirm).
def write_to_file(self, fileName, dataFrame):
if not self.is_logged_in():
raise ValueError("Not Logged In")
path = os.path.join(self.root, fileName)
self.fs.write_text(path, dataFrame.to_csv(index=False),revision=self.revision)
self.fs = None
self.fs = HfFileSystem(token=self.token)
dataFrame.to_csv(path, index=False)




Expand Down
15 changes: 8 additions & 7 deletions tests/test_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import pytest
import os
from prompt_systematic_review.pipeline import *
from huggingface_hub import delete_file
import random
import time
import hashlib
Expand All @@ -13,7 +14,7 @@ def hashString(bytes):
@pytest.fixture
def client():
    """Provide a Pipeline built with the HF_AUTH_TOKEN environment variable.

    The pipeline is pinned to the "test" revision. Raises KeyError when
    HF_AUTH_TOKEN is not set in the environment.
    """
    return Pipeline(token=os.environ["HF_AUTH_TOKEN"], revision="test")
'''

def test_login():
testClient = Pipeline(revision="test")
assert testClient.is_logged_in() == False
Expand All @@ -30,11 +31,9 @@ def test_get_all_data_files(client):
# Checks that test.csv loads with at least one row, exactly two columns, and
# a mean of 21 in the Age column.
# NOTE(review): the Age assertion appears twice with different column keys
# (" Age" with a leading space, then "Age"). This looks like the before/after
# sides of a rendered diff with its +/- markers lost; only one key can match
# the CSV header, so confirm which line is live.
def test_read_from_file(client):
assert len(client.read_from_file("test.csv")) > 0
assert len(client.read_from_file("test.csv").columns) == 2
assert client.read_from_file("test.csv")[" Age"].mean() == 21
assert client.read_from_file("test.csv")["Age"].mean() == 21

def test_write_to_file(client):
assert True
return True
lenOfFiles = len(client.get_all_files())
randString = random.randbytes(100) + str(time.time()).encode()
randHash = hashString(randString)
Expand All @@ -48,8 +47,10 @@ def test_write_to_file(client):
assert df["test"].sum() == 4
assert df["test2"].sum() == 6
#time.sleep(1)
#client.fs.delete(f"{randHash[:10]}_test.csv",revision="test")
#assert len(client.get_all_files()) == lenOfFiles
'''
print(client.root+f"{randHash[:10]}_test.csv")
delete_file(f"{randHash[:10]}_test.csv", "PromptSystematicReview/Prompt_Systematic_Review_Dataset", repo_type="dataset", revision="test")

assert len(client.get_all_files()) == lenOfFiles



0 comments on commit a7420d1

Please sign in to comment.