diff --git a/app.py b/app.py index 3d008d0..ac77404 100644 --- a/app.py +++ b/app.py @@ -44,7 +44,7 @@ CAPTCHA_SECRET = os.getenv('CAPTCHA_SECRET', '') from init_app import db, init_app -from models import RegistrationKey, SiteBase, SiteIndicator, User +from models import RegistrationKey, SiteBase, SiteIndicator, User, Query, Result from modules.reference import DEFAULTS, ENGINES, LANGUAGES, COUNTRIES, LANGUAGES_YANDEX, LANGUAGES_YAHOO, COUNTRIES_YAHOO, COUNTRY_LANGUAGE_DUCKDUCKGO, DOMAINS_GOOGLE, INDICATOR_METADATA, MATCH_VALUES_TO_IGNORE # Import all your functions here from modules.crawler import crawl_one_or_more_urls, annotate_indicators @@ -91,17 +91,18 @@ def insert_sites_of_concern(local_domains): def insert_indicators(indicators): - app.logger.info("Inserting indicators: %s", indicators) - engine = db.session.get_bind() - with engine.connect() as conn: - conn.execute( - insert(SiteIndicator), - [{"domain": indicator['domain_name'], - "indicator_type": indicator['indicator_type'], - "indicator_content": str(indicator['indicator_content'])} - for indicator in indicators] - ) - conn.commit() + #app.logger.info("Inserting indicators: %s", indicators) + try: + db_indicators = [] + for indicator in indicators: + db_indicators.append(SiteIndicator(indicator_type=indicator['indicator_type'], indicator_content=str(indicator['indicator_content']), domain=indicator['domain_name'], indicator_annotation=indicator['indicator_annotation'])) + db.session.bulk_save_objects(db_indicators) + db.session.commit() + except Exception as e: + app.logger.error("Error inserting indicators: %s", e) + db.session.rollback() + + return None # TODO move to a utils or decorators file def clean_inputs(view_func): @@ -173,11 +174,12 @@ def login(request): if reg_key is not None and user is None: - reg_key_db = db.session.get(RegistrationKey, reg_key) - if reg_key_db is not None: + reg_key_db = db.session.get_one(RegistrationKey, reg_key) + if reg_key_db.registration_keys is not None: hashed_password = bcrypt.generate_password_hash(password).decode('utf-8') user = User(username=username, password=hashed_password) db.session.add(user) + db.session.commit() login_user(user) is_logged_in = True @@ -214,6 +216,7 @@ def register_gui(): hashed_password = bcrypt.generate_password_hash(password).decode('utf-8') user = User(username=username, password=hashed_password) db.session.add(user) + db.session.commit() login_user(user) is_logged_in = True else: @@ -243,7 +246,7 @@ def verify_captcha(request): } response = requests.post('https://www.google.com/recaptcha/api/siteverify', data=params) result = response.json() - if result['success'] and result['score'] >= 0.5: # You can adjust the score threshold + if result['success'] and result['score'] >= 0.3: # You can adjust the score threshold return True else: return False @@ -301,6 +304,9 @@ def find_indicators_and_matches(urls, run_urlscan = False, internal_only = False indicators_df = indicators_df[filter_mask] indicators_df = annotate_indicators(indicators_df) + #add indicators to DB + insert_indicators(indicators_df.to_dict('records')) + if internal_only: comparison_indicators = indicators_df else: @@ -868,36 +874,36 @@ def fetch_content_results(title_query, content_query, combineOperator, language, # Convert results to CSV csv_data = convert_results_to_csv(results) # Save the query to the database + try: + query = Query(title=str(title_query), content=str(content_query), combine_operator=str(combineOperator), language=str(language), country=str(country)) + db.session.add(query) + db.session.commit() - # db = get_db() - # cursor = db.cursor() - # cursor.execute('INSERT INTO content_queries (title_query, content_query, combine_operator, language, country) VALUES (?, ?, ?, ?, ?)', - # (title_query, content_query, combineOperator, language, country)) - # db.commit() - # # Get the last inserted row ID - # cq_id = cursor.lastrowid - - # results_list = [] - # for domain, data in results.items(): - # for link_data in data['links']: - # res = [ - # cq_id, - # domain, - # str(data['count']), - # link_data['title'], - # link_data['link'], - # str(link_data['count']), - # ', '.join(link_data['engines']) - # ] - # results_list.append(res) - - # # Insert data into the database - # # Prepare your SQL insert statement including the additional column - # insert_sql = 'INSERT INTO content_queries_results (cq_id, Domain, Occcurences, Title, Link, Link_Occurences, Engines) VALUES (?,?, ?, ?, ?, ?, ?)' - - # # Execute the insert command - # cursor.executemany(insert_sql, results_list) - # db.commit() + # Get the last inserted row ID + cq_id = query.id + + results_list = [] + for result in results: + res = Result( + query_id=cq_id, + domain=result['domain'], + url=result['url'], + title=result['title'], + snippet=result['snippet'], + link_count=result['link_count'], + engines=str(result['engines']), + domain_count=result['domain_count'], + score=result['score'] + + ) + results_list.append(res) + + db.session.bulk_save_objects(results_list) + db.session.commit() + + except Exception as e: + app.logger.error(f"Error saving query to database: {e}") + print(f"Error saving query to database: {e}") return results, csv_data diff --git a/config.py b/config.py index e6ef925..48cab98 100644 --- a/config.py +++ b/config.py @@ -6,9 +6,19 @@ class DevelopmentConfig(Config): DEVELOPMENT = True DEBUG = True SQLALCHEMY_DATABASE_URI = os.getenv("DEVELOPMENT_DATABASE_URL") + SQLALCHEMY_ENGINE_OPTIONS = { + 'connect_args': { + 'options': '-c statement_timeout=5000' + } + } class ProductionConfig(Config): DEBUG = False SQLALCHEMY_DATABASE_URI = os.getenv("PRODUCTION_DATABASE_URL") + SQLALCHEMY_ENGINE_OPTIONS = { + 'connect_args': { + 'options': '-c statement_timeout=5000' + } + } config = { "development": DevelopmentConfig, "production": ProductionConfig diff --git a/models.py b/models.py index bcbdb4d..836ed1f 100644 --- a/models.py +++ b/models.py @@ -13,8 +13,8 @@ class Query(db.Model): updated = db.Column(db.DateTime(timezone=True), default=datetime.now, onupdate=datetime.now) # Input by Query Fields: - title = db.Column(db.String(100), nullable=False, unique=False) - content = db.Column(db.String(100), nullable=False, unique=False) + title = db.Column(db.String(300), nullable=True, unique=False) + content = db.Column(db.String(300), nullable=True, unique=False) combine_operator = db.Column(db.String(100), nullable=True, unique=False) language = db.Column(db.String(100), nullable=True, unique=False) country = db.Column(db.String(100), nullable=True, unique=False) @@ -29,13 +29,16 @@ class Result(db.Model): updated = db.Column(db.DateTime(timezone=True), default=datetime.now, onupdate=datetime.now) # Input by Query Fields: - domain = db.Column(db.String(100), nullable=False, unique=False) - occurrences = db.Column(db.Integer(), nullable=False, unique=False) - title = db.Column(db.String(100), nullable=True, unique=False) - link = db.Column(db.String(100), nullable=True, unique=False) - link_occurrences = db.Column(db.Integer(), nullable=True, unique=False) - engines = db.Column(db.String(100), nullable=True, unique=False) - cq_id = db.Column(db.String(100), nullable=True, unique=False) + domain = db.Column(db.String(255), nullable=True, unique=False) + url = db.Column(db.String(255), nullable=True, unique=False) + title = db.Column(db.String(255), nullable=True, unique=False) + snippet = db.Column(db.Text, nullable=True, unique=False) + engine = db.Column(db.String(255), nullable=True, unique=False) + link_count = db.Column(db.Integer, nullable=True, unique=False) + domain_count = db.Column(db.Integer, nullable=True, unique=False) + engines = db.Column(db.String(255), nullable=True, unique=False) + score = db.Column(db.Float, nullable=True, unique=False) + query_id = db.Column(db.Integer, db.ForeignKey('content_queries.id'), nullable=False) class RegistrationKey(db.Model): @@ -53,8 +56,9 @@ class SiteIndicator(db.Model): # Input by Query Fields: indicator_type = db.Column(db.String(100), nullable=False, unique=False) - indicator_content = db.Column(db.String(100), nullable=False, unique=False) + indicator_content = db.Column(db.Text, nullable=True, unique=False) domain = db.Column(db.String(100), nullable=False, unique=False) + indicator_annotation = db.Column(db.String(100), nullable=True, unique=False) class SiteBase(db.Model): __tablename__ = 'sites_base'