Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Identifying routines internally by id instead of name #100

Merged
merged 12 commits into from
Oct 28, 2024
Merged
8 changes: 4 additions & 4 deletions src/badger/actions/routine.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,18 +14,18 @@ def show_routine(args):
return

# List routines
if args.routine_name is None:
routines = list_routine()[0]
if args.routine_id is None:
routines = list_routine()[1]
if routines:
yprint(routines)
else:
print('No routine has been saved yet')
return

try:
routine, _ = load_routine(args.routine_name)
routine, _ = load_routine(args.routine_id)
if routine is None:
print(f'Routine {args.routine_name} not found')
print(f'Routine {args.routine_id} not found')
return
except Exception as e:
print(e)
Expand Down
2 changes: 1 addition & 1 deletion src/badger/core_subprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ def run_routine_subprocess(
print(f"Error in subprocess: {type(e).__name__}, {str(e)}")

# set required arguments
routine, _ = load_routine(args["routine_name"])
routine, _ = load_routine(args["routine_id"])

# TODO look into this bug with serializing of turbo. Fix might be needed in Xopt
try:
Expand Down
116 changes: 43 additions & 73 deletions src/badger/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
logger = logging.getLogger(__name__)
import yaml
import sqlite3
import uuid
from .routine import Routine
from .settings import read_value
from .utils import get_yaml_string
Expand All @@ -20,7 +21,6 @@
logger.info(
f'Badger database root {BADGER_DB_ROOT} created')


def maybe_create_routines_db(func):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I know this isn't a change from this PR but could be come up with a better name for this @zhe-slac @shamin-slac

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe ensure_routines_db_exists or require_routines_db?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

so what is the purpose of this decorator?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

To create the routines database if it doesn't already exist.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In that case I'm liking the ensure_routines_db_exists name and I would add a small docstring to the decorator that describes what it does


def func_safe(*args, **kwargs):
Expand All @@ -29,7 +29,7 @@ def func_safe(*args, **kwargs):
con = sqlite3.connect(db_routine)
cur = con.cursor()

cur.execute('create table if not exists routine (name not null primary key, config, savedAt timestamp)')
cur.execute('create table if not exists routine (id text primary key, name text, config, savedAt timestamp)')

con.commit()
con.close()
Expand All @@ -47,7 +47,7 @@ def func_safe(*args, **kwargs):
con = sqlite3.connect(db_run)
cur = con.cursor()

cur.execute('create table if not exists run (id integer primary key, savedAt timestamp, finishedAt timestamp, routine, filename)')
cur.execute('create table if not exists run (id integer primary key, savedAt timestamp, finishedAt timestamp, routine_id, filename)')

con.commit()
con.close()
Expand All @@ -61,7 +61,7 @@ def filter_routines(records, tags):
records_filtered = []
for record in records:
try:
_tags = yaml.safe_load(record[1])['config']['tags']
_tags = yaml.safe_load(record[3])['config']['tags']
if tags.items() <= _tags.items():
records_filtered.append(record)
except:
Expand All @@ -75,7 +75,7 @@ def extract_metadata(records):
descr_list = []
for record in records:
try:
metadata = yaml.safe_load(record[1])
metadata = yaml.safe_load(record[2])
env = metadata['environment']['name']
env_list.append(env)
descr = metadata['description']
Expand All @@ -94,65 +94,47 @@ def save_routine(routine: Routine):
con = sqlite3.connect(db_routine)
cur = con.cursor()

cur.execute('select * from routine where name=:name',
{'name': routine.name})
record = cur.fetchone()

runs = get_runs_by_routine(routine.name)

if record and len(runs) == 0: # update the record
cur.execute('update routine set config = ?, savedAt = ? where name = ?',
(routine.yaml(), datetime.now(), routine.name))
else: # insert a record
cur.execute('insert into routine values (?, ?, ?)',
(routine.name, routine.yaml(), datetime.now()))
id = str(uuid.uuid4())
routine.id = id
cur.execute('insert into routine values (?, ?, ?, ?)',
(routine.id, routine.name, routine.yaml(), datetime.now()))

con.commit()
con.close()

return routine.id


# This function is not safe and might break database! Use with caution!
@maybe_create_routines_db
@maybe_create_runs_db
def update_routine(routine: Routine, old_name=''):
def update_routine(routine: Routine):

db_routine = os.path.join(BADGER_DB_ROOT, 'routines.db')

con = sqlite3.connect(db_routine)
cur = con.cursor()

name = old_name if old_name else routine.name
cur.execute('select * from routine where name=:name',
{'name': name})
cur.execute('select * from routine where id=:id',
{'id': routine.id})
record = cur.fetchone()

if record: # update the record
cur.execute('update routine set name = ?, config = ?, savedAt = ? where name = ?',
(routine.name, routine.yaml(), datetime.now(), name))

if old_name:
db_run = os.path.join(BADGER_DB_ROOT, 'runs.db')

con_run = sqlite3.connect(db_run, timeout=30.0)
cur_run = con_run.cursor()

cur_run.execute('update run set routine = ? where routine = ?',(routine.name, old_name))

con_run.commit()
con_run.close()
cur.execute('update routine set name = ?, config = ?, savedAt = ? where id = ?',
(routine.name, routine.yaml(), datetime.now(), routine.id))

con.commit()
con.close()


@maybe_create_routines_db
@maybe_create_runs_db
def remove_routine(name, remove_runs=False):
def remove_routine(id: str, remove_runs=False):
db_routine = os.path.join(BADGER_DB_ROOT, 'routines.db')

con = sqlite3.connect(db_routine)
cur = con.cursor()

cur.execute(f'delete from routine where name = "{name}"')
cur.execute(f'delete from routine where id = "{id}"')

con.commit()
con.close()
Expand All @@ -164,26 +146,26 @@ def remove_routine(name, remove_runs=False):
con = sqlite3.connect(db_run)
cur = con.cursor()

cur.execute(f'delete from run where routine = "{name}"')
cur.execute(f'delete from run where routine_id = "{id}"')

con.commit()
con.close()


@maybe_create_routines_db
def load_routine(name: str):
def load_routine(id: str):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

you said that the routine ID is optional? How is that handled in this function?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This function just takes an id, not a whole routine. In principle, the only place where there is a routine without an id is when a routine is composed from the routine editor (this is the main motivation to have id as an optional and not required field in the Routine class). Then in the save_routine() function, the id is generated and stored in that routine which is then saved in the database (and then it is subsequently retrieved for use in Badger). That being said, there are a couple places where I can imagine it's not impossible that there is no id, so I've modified the function to handle this possibility.

db_routine = os.path.join(BADGER_DB_ROOT, 'routines.db')
con = sqlite3.connect(db_routine)
cur = con.cursor()

cur.execute('select * from routine where name=:name', {'name': name})
cur.execute('select * from routine where id=:id', {'id': id})

records = cur.fetchall()
con.close()

if len(records) == 1:
# return yaml.safe_load(records[0][1]), records[0][2]
routine_dict = yaml.safe_load(records[0][1])
routine_dict = yaml.safe_load(records[0][2])
# routine_dict['evaluator'] = None
with warnings.catch_warnings(record=True) as caught_warnings:
routine = Routine(**routine_dict)
Expand All @@ -195,13 +177,13 @@ def load_routine(name: str):
else:
print(f"Caught user warning: {warning.message}")

return routine, records[0][2]
return routine, records[0][3]
elif len(records) == 0:
# logger.warning(f'Routine {name} not found in the database!')
return None, None
else:
raise BadgerDBError(
f'Multiple routines with name {name} found in the database!')
f'Multiple routines with id {id} found in the database!')


@maybe_create_routines_db
Expand All @@ -210,16 +192,17 @@ def list_routine(keyword='', tags={}):
con = sqlite3.connect(db_routine)
cur = con.cursor()

cur.execute(f'select name, config, savedAt from routine where name like "%{keyword}%" order by savedAt desc')
cur.execute(f'select id, name, config, savedAt from routine where name like "%{keyword}%" order by savedAt desc')
records = cur.fetchall()
if tags:
records = filter_routines(records, tags)
names = [record[0] for record in records]
timestamps = [record[2] for record in records]
ids = [record[0] for record in records]
names = [record[1] for record in records]
timestamps = [record[3] for record in records]
environments, descriptions = extract_metadata(records)
con.close()

return names, timestamps, environments, descriptions
return ids, names, timestamps, environments, descriptions


@maybe_create_runs_db
Expand All @@ -230,7 +213,7 @@ def save_run(run):
cur = con.cursor()

# Insert or update a record
routine_name = run['routine'].name
routine_id = run['routine'].id
run_filename = run['filename']
timestamps = run['data']['timestamp']
time_start = datetime.fromtimestamp(timestamps[0])
Expand All @@ -246,7 +229,7 @@ def save_run(run):
rid = existing_row[0]
else:
cur.execute('insert into run values (?, ?, ?, ?, ?)',
(None, time_start, time_finish, routine_name, run_filename))
(None, time_start, time_finish, routine_id, run_filename))
rid = cur.lastrowid

con.commit()
Expand All @@ -256,13 +239,13 @@ def save_run(run):


@maybe_create_runs_db
def get_runs_by_routine(routine: str):
def get_runs_by_routine(routine_id: str):
db_run = os.path.join(BADGER_DB_ROOT, 'runs.db')

con = sqlite3.connect(db_run)
cur = con.cursor()

cur.execute(f'select filename from run where routine = "{routine}" order by savedAt desc')
cur.execute(f'select filename from run where routine_id = "{routine_id}" order by savedAt desc')
records = cur.fetchall()
con.close()

Expand All @@ -289,13 +272,13 @@ def get_runs():


@maybe_create_runs_db
def remove_run_by_filename(name):
def remove_run_by_filename(filename):
db_run = os.path.join(BADGER_DB_ROOT, 'runs.db')

con = sqlite3.connect(db_run)
cur = con.cursor()

cur.execute(f'delete from run where filename = "{name}"')
cur.execute(f'delete from run where filename = "{filename}"')

con.commit()
con.close()
Expand All @@ -313,26 +296,13 @@ def remove_run_by_id(rid):
con.commit()
con.close()

@maybe_create_runs_db
def get_routine_name_by_filename(filename):
db_run = os.path.join(BADGER_DB_ROOT, 'runs.db')

con = sqlite3.connect(db_run)
cur = con.cursor()

cur.execute(f'select routine from run where filename = "{filename}"')
routine_name = cur.fetchone()[0]
con.close()

return routine_name


def import_routines(filename):
con = sqlite3.connect(filename)
cur = con.cursor()

# Deal with empty db file
cur.execute('create table if not exists routine (name not null primary key, config, savedAt timestamp)')
cur.execute('create table if not exists routine (id text primary key, name text, config, savedAt timestamp)')

db_routine = os.path.join(BADGER_DB_ROOT, 'routines.db')
con_db = sqlite3.connect(db_routine)
Expand All @@ -344,7 +314,7 @@ def import_routines(filename):
failed_list = []
for record in records:
try:
cur_db.execute('insert into routine values (?, ?, ?)', record)
cur_db.execute('insert into routine values (?, ?, ?, ?)', record)
except:
failed_list.append(record[0])

Expand All @@ -357,22 +327,22 @@ def import_routines(filename):
raise BadgerDBError(get_yaml_string(failed_list))


def export_routines(filename, routine_name_list):
def export_routines(filename, routine_id_list):
con = sqlite3.connect(filename)
cur = con.cursor()

cur.execute('create table if not exists routine (name not null primary key, config, savedAt timestamp)')
cur.execute('create table if not exists routine (id text primary key, name text, config, savedAt timestamp)')

db_routine = os.path.join(BADGER_DB_ROOT, 'routines.db')
con_db = sqlite3.connect(db_routine)
cur_db = con_db.cursor()

for name in routine_name_list:
cur_db.execute('select * from routine where name=:name', {'name': name})
for id in routine_id_list:
cur_db.execute('select * from routine where id=:id', {'id': id})
records = cur_db.fetchall()
record = records[0] # should only have one hit

cur.execute('insert into routine values (?, ?, ?)', record)
cur.execute('insert into routine values (?, ?, ?, ?)', record)

con_db.close()

Expand Down
4 changes: 0 additions & 4 deletions src/badger/gui/default/components/routine_editor.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,10 +83,6 @@ def set_routine(self, routine: Routine, silent=False):
def edit_routine(self):
self.stacks.setCurrentIndex(1)

def del_routine(self):
if self.routine_page.delete() == 0:
self.sig_deleted.emit()

def cancel_create_routine(self):
self.sig_canceled.emit()

Expand Down
10 changes: 4 additions & 6 deletions src/badger/gui/default/components/routine_item.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,13 +61,15 @@


class BadgerRoutineItem(QWidget):
# add id so sig_del gets id
sig_del = pyqtSignal(str)

def __init__(self, name, timestamp, environment, env_dict, description='', parent=None):
def __init__(self, id, name, timestamp, environment, env_dict, description='', parent=None):
super().__init__(parent)

self.activated = False
self.hover = False
self.id = id
self.name = name
self.timestamp = timestamp
self.description = description
Expand Down Expand Up @@ -187,7 +189,7 @@ def delete_routine(self):
if reply != QMessageBox.Yes:
return

self.sig_del.emit(self.name)
self.sig_del.emit(self.id)

def update_tooltip(self):
_timestamp = datetime.fromisoformat(self.timestamp)
Expand All @@ -197,7 +199,3 @@ def update_tooltip(self):
def update_description(self, descr):
self.description = descr
self.update_tooltip()

def update_name(self, name):
self.name = name
self.update_tooltip()
Loading
Loading