Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove schain firewall config during cleaner stage (beta) #1141

Merged
merged 12 commits into from
Jan 23, 2025
16 changes: 14 additions & 2 deletions core/schains/checks.py
Original file line number Diff line number Diff line change
Expand Up @@ -301,6 +301,11 @@ def volume(self) -> CheckRes:
@property
def firewall_rules(self) -> CheckRes:
"""Checks that firewall rules are set correctly"""
data = {
'inited': False,
'rules': False,
'persistent': False,
}
if self.config:
conf = self.cfm.skaled_config
base_port = get_base_port_from_config(conf)
Expand All @@ -311,8 +316,15 @@ def firewall_rules(self) -> CheckRes:
base_port=base_port, own_ip=own_ip, node_ips=node_ips, sync_ip_ranges=ranges
)
logger.debug(f'Rule controller {self.rc.expected_rules()}')
return CheckRes(self.rc.is_rules_synced())
return CheckRes(False)
data.update({
'inited': self.rc.is_inited(),
'rules': self.rc.is_rules_synced(),
'persistent': self.rc.is_persistent(),
})
logger.debug('Firewall rules check: %s', data)
status = all(data.values())
return CheckRes(status=status, data=data)
return CheckRes(status=False, data=data)

@property
def skaled_container(self) -> CheckRes:
Expand Down
54 changes: 33 additions & 21 deletions core/schains/cleaner.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,33 +17,33 @@
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.

import glob
import logging
import os
import shutil
from multiprocessing import Process
from pathlib import Path
from typing import Optional

from sgx import SgxClient
from skale import Skale

from core.node import get_current_nodes, get_skale_node_version
from core.schains.checks import SChainChecks
from core.schains.config.file_manager import ConfigFileManager
from core.schains.config.directory import schain_config_dir
from core.schains.dkg.utils import get_secret_key_share_filepath
from core.schains.firewall.utils import get_default_rule_controller
from core.schains.config.helper import (
get_base_port_from_config,
get_node_ips_from_config,
get_own_ip_from_config,
)
from core.schains.firewall.utils import cleanup_firewall_for_schain, get_default_rule_controller
from core.schains.process import ProcessReport, terminate_process
from core.schains.runner import get_container_name, is_exited
from core.schains.external_config import ExternalConfig
from core.schains.types import ContainerType
from core.schains.firewall.utils import get_sync_agent_ranges

from tools.configs import SGX_CERTIFICATES_FOLDER, SYNC_NODE
from tools.configs import (
NFT_CHAIN_CONFIG_WILDCARD,
SGX_CERTIFICATES_FOLDER,
SYNC_NODE
)
from tools.configs.schains import SCHAINS_DIR_PATH
from tools.configs.containers import SCHAIN_CONTAINER, IMA_CONTAINER, SCHAIN_STOP_TIMEOUT
from tools.docker_utils import DockerUtils
Expand Down Expand Up @@ -136,18 +136,36 @@ def get_schains_with_containers(dutils=None):
]


def get_schains_firewall_configs() -> list:
return list(map(lambda path: Path(path).stem, glob.glob(NFT_CHAIN_CONFIG_WILDCARD)))


def get_schains_on_node(dutils=None):
dutils = dutils or DockerUtils()
schains_with_dirs = os.listdir(SCHAINS_DIR_PATH)
schains_with_container = get_schains_with_containers(dutils)
schains_active_records = get_schains_names()
schains_firewall_configs = list(
map(
lambda name: name.removeprefix('skale-'),
get_schains_firewall_configs()
)
)
logger.info(
'dirs %s, containers: %s, records: %s',
'dirs %s, containers: %s, records: %s, firewall configs: %s',
schains_with_dirs,
schains_with_container,
schains_active_records
schains_active_records,
schains_firewall_configs
)
return sorted(
merged_unique(
schains_with_dirs,
schains_with_container,
schains_active_records,
schains_firewall_configs
)
)
return sorted(merged_unique(schains_with_dirs, schains_with_container, schains_active_records))


def schain_names_to_ids(skale, schain_names):
Expand Down Expand Up @@ -258,16 +276,10 @@ def cleanup_schain(
remove_schain_container(schain_name, dutils=dutils)
if check_status['volume']:
remove_schain_volume(schain_name, dutils=dutils)
if check_status['firewall_rules']:
conf = ConfigFileManager(schain_name).skaled_config
base_port = get_base_port_from_config(conf)
own_ip = get_own_ip_from_config(conf)
node_ips = get_node_ips_from_config(conf)
ranges = []
if estate is not None:
ranges = estate.ranges
rc.configure(base_port=base_port, own_ip=own_ip, node_ips=node_ips, sync_ip_ranges=ranges)
rc.cleanup()
if any(checks.firewall_rules.data):
logger.info('Cleaning firewall for %s', schain_name)
cleanup_firewall_for_schain(schain_name)

if estate is not None and estate.ima_linked:
if check_status.get('ima_container', False) or is_exited(
schain_name, container_type=ContainerType.ima, dutils=dutils
Expand Down
21 changes: 17 additions & 4 deletions core/schains/firewall/firewall_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,19 +87,32 @@ def remove_rules(self, rules: Iterable[SChainRule]) -> None:
for rule in rules:
self.host_controller.remove_rule(rule)

def flush(self) -> None:
self.remove_rules(self.rules)
self.host_controller.cleanup()


class IptablesSChainFirewallManager(SChainFirewallManager):
def create_host_controller(self) -> IptablesController:
return IptablesController()

def cleanup(self) -> None:
self.remove_rules(self.rules)


class NFTSchainFirewallManager(SChainFirewallManager):
def create_host_controller(self) -> NFTablesController:
nc_controller = NFTablesController(chain=self.name)
nc_controller.create_table()
nc_controller.create_chain(self.first_port, self.last_port)
return nc_controller

def rules_saved(self) -> bool:
saved = self.host_controller.get_saved_rules()
if saved == '':
return False
return saved == self.host_controller.get_plain_chain_rules()

def base_config_applied(self) -> bool:
return self.host_controller.has_chain(self.host_controller.chain) and \
self.host_controller.has_drop_rule(self.first_port, self.last_port)

def cleanup(self) -> None:
self.host_controller.cleanup()
self.host_controller.remove_saved_rules()
60 changes: 52 additions & 8 deletions core/schains/firewall/nftables.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
# -*- coding: utf-8 -*-
#
# This file is part of SKALE Admin
#
# Copyright (C) 2024 SKALE Labs
Expand Down Expand Up @@ -45,12 +44,13 @@ class NFTablesController(IHostFirewallController):
plock = multiprocessing.Lock()
FAMILY = 'inet'

def __init__(self, table: str = TABLE, chain: str = CHAIN) -> None:
def __init__(self, chain: str, table: str = TABLE) -> None:
self.table = table
self.chain = f'skale-{chain}'
self._nftables = importlib.import_module('nftables')
self.nft = self._nftables.Nftables()
self.nft.set_json_output(True)
self.nft.set_stateless_output(True)

@classmethod
def rule_to_expr(cls, rule: SChainRule, counter: bool = True) -> list:
Expand Down Expand Up @@ -124,7 +124,7 @@ def create_table(self) -> None:
if not self.has_table(self.table):
return self.run_cmd(f'add table inet {self.table}')

def add_schain_drop_rule(self, first_port: int, last_port: int) -> None:
def has_drop_rule(self, first_port: int, last_port: int) -> bool:
expr = [
{
'match': {
Expand All @@ -137,7 +137,22 @@ def add_schain_drop_rule(self, first_port: int, last_port: int) -> None:
{'drop': None},
]

if self.expr_to_rule(expr) not in self.get_rules_by_policy(policy='drop'):
return self.expr_to_rule(expr) in self.get_rules_by_policy(policy='drop')

def add_schain_drop_rule(self, first_port: int, last_port: int) -> None:
if not self.has_drop_rule(first_port, last_port):
expr = [
{
'match': {
'op': '==',
'left': {'payload': {'protocol': 'tcp', 'field': 'dport'}},
'right': {'range': [first_port, last_port]},
}
},
{'counter': None},
{'drop': None},
]

cmd = {
'nftables': [
{
Expand Down Expand Up @@ -177,7 +192,8 @@ def create_chain(self, first_port: int, last_port: int) -> None:
]
)
)
self.add_schain_drop_rule(first_port, last_port)
self.add_schain_drop_rule(first_port, last_port)
self.save_rules()

def delete_chain(self) -> None:
if self.has_chain(self.chain):
Expand Down Expand Up @@ -329,19 +345,47 @@ def get_plain_chain_rules(self) -> str:
self.nft.set_json_output(False)
output = ''
try:
rc, output, error = self.run_cmd(f'list chain {self.FAMILY} {self.table} {self.chain}')
rc, output, error = self.run_cmd(
f'list chain {self.FAMILY} {self.table} {self.chain}'
)
if rc != 0:
raise NFTablesCmdFailedError(f"Failed to get table content: {error}")
finally:
self.nft.set_json_output(True)

lines = output.split('\n')
# cleanup table header
if lines[-1] == '':
lines = lines[1:-2]
else:
lines = lines[1:-1]

# remove leading tab
lines = list(map(lambda line: line[1:], lines))
# Adding new line at the end to prevent validation failure
lines.append('')
output = '\n'.join(lines)
return output

@property
def nft_chain_path(self) -> str:
return os.path.join(NFT_CHAIN_BASE_PATH, f'{self.chain}.conf')

def save_rules(self) -> None:
logger.info('Saving the firewall rules for chain %s', self.chain)
chain_rules = self.get_plain_chain_rules()
nft_chain_path = os.path.join(NFT_CHAIN_BASE_PATH, f'{self.chain}.conf')
with open(nft_chain_path, 'w') as nft_chain_file:
with open(self.nft_chain_path, 'w') as nft_chain_file:
nft_chain_file.write(chain_rules)

def get_saved_rules(self) -> str:
if not os.path.isfile(self.nft_chain_path):
return ''
with open(self.nft_chain_path, 'r') as nft_chain_file:
return nft_chain_file.read()

def remove_saved_rules(self) -> None:
if os.path.isfile(self.nft_chain_path):
os.remove(self.nft_chain_path)

def cleanup(self) -> None:
self.delete_chain()
26 changes: 23 additions & 3 deletions core/schains/firewall/rule_controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,9 +202,6 @@ def sync(self) -> None:
logger.debug('Syncing firewall rules with %s', erules)
self.firewall_manager.update_rules(erules)

def cleanup(self) -> None:
self.firewall_manager.flush()


class IptablesSChainRuleController(SChainRuleController):
@configured_only
Expand All @@ -215,6 +212,18 @@ def create_firewall_manager(self) -> IptablesSChainFirewallManager:
self.base_port + self.ports_per_schain - 1 # type: ignore
)

@configured_only
def is_persistent(self) -> bool:
return True

@configured_only
def is_inited(self) -> bool:
return True

@configured_only
def cleanup(self) -> None:
self.firewall_manager.cleanup()


class NFTSchainRuleController(SChainRuleController):
@configured_only
Expand All @@ -224,3 +233,14 @@ def create_firewall_manager(self) -> NFTSchainFirewallManager:
self.base_port, # type: ignore
self.base_port + self.ports_per_schain - 1 # type: ignore
)

@configured_only
def is_persistent(self) -> bool:
return self.firewall_manager.rules_saved()

@configured_only
def is_inited(self) -> bool:
return self.firewall_manager.base_config_applied()

def cleanup(self) -> None:
self.firewall_manager.cleanup()
10 changes: 9 additions & 1 deletion core/schains/firewall/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ def update_rules(self, rules: Iterable[SChainRule]) -> None: # pragma: no cover
pass

@abstractmethod
def flush(self) -> None: # pragma: no cover # noqa
def cleanup(self) -> None: # pragma: no cover # noqa
pass


Expand Down Expand Up @@ -139,3 +139,11 @@ def sync(self) -> None: # pragma: no cover
@abstractmethod
def cleanup(self) -> None: # pragma: no cover
pass

@abstractmethod
def is_persistent(self) -> bool: # pragma: no cover
pass

@abstractmethod
def is_inited(self) -> bool: # pragma: no cover
pass
7 changes: 7 additions & 0 deletions core/schains/firewall/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from skale import Skale

from .types import IpRange
from .nftables import NFTablesController
from .rule_controller import IptablesSChainRuleController, NFTSchainRuleController


Expand Down Expand Up @@ -101,3 +102,9 @@ def save_sync_ranges(sync_agent_ranges: List[IpRange], path: str) -> None:

def ranges_from_plain_tuples(plain_ranges: List[Tuple]) -> List[IpRange]:
return list(sorted(map(lambda r: IpRange(*r), plain_ranges)))


def cleanup_firewall_for_schain(schain_name: str) -> None:
nft = NFTablesController(chain=schain_name)
nft.cleanup()
nft.remove_saved_rules()
10 changes: 10 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -631,3 +631,13 @@ def ncli_status(_schain_name):
yield init_node_cli_status(_schain_name)
finally:
shutil.rmtree(schain_dir_path, ignore_errors=True)


@pytest.fixture()
def nft_chain_folder():
path = '/etc/nft.conf.d/skale/chains'
try:
os.makedirs(path)
yield path
finally:
shutil.rmtree(path)
Loading
Loading