Skip to content
This repository has been archived by the owner on Mar 17, 2022. It is now read-only.

Commit

Permalink
use flask-sqlalchemy session scope for saq.db in wsgi apps
Browse files Browse the repository at this point in the history
  • Loading branch information
unixfreak0037 committed Feb 14, 2019
1 parent a2a7d6f commit 20f5e0f
Show file tree
Hide file tree
Showing 5 changed files with 18 additions and 22 deletions.
6 changes: 4 additions & 2 deletions ace.wsgi
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,10 @@ import saq
saq.initialize(saq_home=saq_home, config_paths=None, logging_config_path=logging_config_path, relative_dir=saq_home)

# initialize flask
from app import create_app
application = create_app() # fix this hard coded string
import app
application = app.create_app() # fix this hard coded string
# tell ACE to use the session scope provided by the sqlalchemy-flask extension
saq.db = app.db.session

# add the "do" template command
application.jinja_env.add_extension('jinja2.ext.do')
8 changes: 5 additions & 3 deletions api.wsgi
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,10 @@ logging_config_path = os.path.join(saq_home, 'etc', 'api_logging.ini')
# initialize saq
# note that config paths are determined by the env vars we dug out above
import saq
saq.initialize(saq_home=saq_home, config_paths=None, logging_config_path=logging_config_path, relative_dir=saq_home, use_flask=True)
saq.initialize(saq_home=saq_home, config_paths=None, logging_config_path=logging_config_path, relative_dir=saq_home)

# initialize flask
from api import create_app
application = create_app()
import api
application = api.create_app()
# tell ACE to use the session scope provided by the sqlalchemy-flask extension
saq.db = api.db.session
5 changes: 2 additions & 3 deletions lib/saq/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,8 +165,7 @@ def initialize(saq_home=None,
logging_config_path=None,
args=None,
relative_dir=None,
unittest=False,
use_flask=False):
unittest=False):

from saq.database import initialize_database, initialize_node

Expand Down Expand Up @@ -460,7 +459,7 @@ def initialize(saq_home=None,
YSS_SOCKET_DIR = os.path.join(YSS_BASE_DIR, CONFIG['yara']['yss_socket_dir'])

# initialize the database connection
initialize_database(use_flask=use_flask)
initialize_database()

# initialize fallback semaphores
initialize_fallback_semaphores()
Expand Down
14 changes: 3 additions & 11 deletions lib/saq/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -2001,10 +2001,8 @@ def clear_delayed_analysis_requests(root, db, c):
"""Clears all delayed analysis requests for the given RootAnalysis object."""
execute_with_retry(db, c, "DELETE FROM delayed_analysis WHERE uuid = %s", (root.uuid,), commit=True)

def initialize_database(use_flask=False):
"""Initializes database connections by creating the SQLAlchemy engine and session objects.
:param bool use_flask: If this flag is set to True then we use configure database session to be in sync with
Flask's request objects. Otherwise the default "thread local" session_scope is used. """
def initialize_database():
"""Initializes database connections by creating the SQLAlchemy engine and session objects."""

global DatabaseSession
from config import config
Expand All @@ -2014,13 +2012,7 @@ def initialize_database(use_flask=False):
**config[saq.CONFIG['global']['instance_type']].SQLALCHEMY_DATABASE_OPTIONS)

DatabaseSession = sessionmaker(bind=engine)

if use_flask:
import flask
saq.db = scoped_session(DatabaseSession, scopefunc=flask._app_ctx_stack.__ident_func__)
logging.debug("using flask for session scoping")
else:
saq.db = scoped_session(DatabaseSession)
saq.db = scoped_session(DatabaseSession)

@use_db
def initialize_node(db, c):
Expand Down
7 changes: 4 additions & 3 deletions lib/saq/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -740,10 +740,10 @@ def execute_api_server(self, listen_address=None, listen_port=None, ssl_cert=Non
# this is a bit weird because I want the urls to be the same as they
# are configured for apache, where they are all starting with /api

from api import create_app
import api
from saq.database import initialize_database

app = create_app(testing=True)
app = api.create_app(testing=True)
from werkzeug.serving import run_simple
from werkzeug.wsgi import DispatcherMiddleware
from flask import Flask
Expand All @@ -761,7 +761,8 @@ def execute_api_server(self, listen_address=None, listen_port=None, ssl_cert=Non
saq.CONFIG.get('api', 'ssl_cert') if ssl_cert is None else ssl_cert,
saq.CONFIG.get('api', 'ssl_key') if ssl_key is None else ssl_key )

initialize_database(use_flask=True)
initialize_database()
saq.db = api.db.session

logging.info(f"starting api server on {listen_address} port {listen_port}")
run_simple(listen_address, listen_port, application, ssl_context=ssl_context, use_reloader=False)
Expand Down

0 comments on commit 20f5e0f

Please sign in to comment.