diff --git a/.coveragerc b/.coveragerc new file mode 100644 index 0000000..4370e6b --- /dev/null +++ b/.coveragerc @@ -0,0 +1,16 @@ +[run] +branch = True +source = ./ + +[report] +exclude_lines = + if self.debug: + pragma: no cover + raise NotImplementedError + if __name__ == .__main__.: + def get_args + def main +ignore_errors = True +omit = + tests/* + examples/* diff --git a/.travis.yml b/.travis.yml new file mode 100644 index 0000000..9e296ac --- /dev/null +++ b/.travis.yml @@ -0,0 +1,42 @@ +sudo: required +dist: "bionic" + +language: python + +python: + - "3.6" + +install: + # Travis recently added systemd-resolvd to their VMs. Since full Geneva often runs its own DNS + # server to test DNS strategies, we need to disable system-resolvd. + # First disable the service + - sudo systemctl disable systemd-resolved.service + # Stop the service + - sudo systemctl stop systemd-resolved + # With systemd not running, our own hostname won't resolve - this causes issues with sudo. + # Add back our hostname to /etc/hosts/ so sudo does not complain + - echo $(hostname -I | cut -d\ -f1) $(hostname) | sudo tee -a /etc/hosts + # Replace the 127.0.0.53 nameserver with Google's + - sudo sed 's/nameserver.*/nameserver 8.8.8.8/' /etc/resolv.conf > /tmp/resolv.conf.new + - sudo mv /tmp/resolv.conf.new /etc/resolv.conf + # Now that systemd-resolv.conf is safely disabled, we can now setup for Geneva + - sudo apt-get clean # travis having mirror sync issues + # Install dependencies + - sudo apt-get update + - sudo apt-get -y install libnetfilter-queue-dev python3 python3-pip python3-setuptools graphviz + # Since sudo is required but travis does not set up the root environment, we must override the + # secure_path in sudoers in order for travis's setup to take effect for sudo commands + - printf "Defaults\tenv_reset\nDefaults\tmail_badpass\nDefaults\tsecure_path="/home/travis/virtualenv/python3.6.7/bin/:/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin:/snap/bin"\nroot\tALL=(ALL:ALL) ALL\n#includedir /etc/sudoers.d\n" > /tmp/sudoers.tmp + # Verify the sudoers file + - sudo visudo -c -f /tmp/sudoers.tmp + # Copy in the sudoers file + - sudo cp /tmp/sudoers.tmp /etc/sudoers + # Now that sudo is good to go, finish installing dependencies + - sudo python3 -m pip install -r requirements.txt + - sudo python3 -m pip install slackclient pytest-cov + +script: + - sudo python3 -m pytest --cov=./ -sv tests/ --tb=short + +after_script: + - bash <(curl -s https://codecov.io/bash) -t 83a45966-78ce-44c2-80b3-964ecab4a53d || echo "Codecov did not collect coverage reports" diff --git a/README.md b/README.md index e6a7302..f04ae1b 100644 --- a/README.md +++ b/README.md @@ -1,4 +1,4 @@ -# Geneva +# Geneva [![Build Status](https://travis-ci.com/Kkevsterrr/geneva.svg?branch=master)](https://travis-ci.com/Kkevsterrr/geneva) [![codecov](https://codecov.io/gh/Kkevsterrr/geneva/branch/master/graph/badge.svg)](https://codecov.io/gh/Kkevsterrr/geneva) Geneva is an artificial intelligence tool that defeats censorship by exploiting bugs in censors, such as those in China, India, and Kazakhstan. Unlike many other anti-censorship solutions which require assistance from outside the censoring regime (Tor, VPNs, etc.), Geneva runs strictly on the client. diff --git a/actions/duplicate.py b/actions/duplicate.py index 9f9e75c..17036a6 100644 --- a/actions/duplicate.py +++ b/actions/duplicate.py @@ -13,9 +13,3 @@ def run(self, packet, logger): """ logger.debug(" - Duplicating given packet %s" % str(packet)) return packet, packet.copy() - - def mutate(self, environment_id=None): - """ - Swaps its left and right child - """ - self.left, self.right = self.right, self.left diff --git a/actions/fragment.py b/actions/fragment.py index e16a94e..d7f80c3 100644 --- a/actions/fragment.py +++ b/actions/fragment.py @@ -196,22 +196,3 @@ def parse(self, string, logger): self.correct_order = False return True - - def mutate(self, environment_id=None): - """ - Mutates the fragment action - it either chooses a new segment offset, - switches the packet order, and/or changes whether it segments or fragments. - """ - self.correct_order = self.get_rand_order() - self.segment = random.choice([True, True, True, False]) - if self.segment: - if random.random() < 0.5: - self.fragsize = int(random.uniform(1, 60)) - else: - self.fragsize = -1 - else: - if random.random() < 0.2: - self.fragsize = int(random.uniform(1, 50)) - else: - self.fragsize = -1 - return self diff --git a/actions/layer.py b/actions/layer.py index 8b976c5..1f16991 100644 --- a/actions/layer.py +++ b/actions/layer.py @@ -4,7 +4,8 @@ import os import urllib.parse -from scapy.all import IP, RandIP, UDP, Raw, TCP, fuzz +from scapy.all import IP, RandIP, UDP, DNS, DNSQR, Raw, TCP, fuzz + class Layer(): """ @@ -179,6 +180,12 @@ def set_load(self, packet, field, value): value = urllib.parse.unquote(value) value = value.encode('utf-8') + # Add support for injecting arbitrary protocol payloads if requested + dns_payload = b"\x009ib\x81\x80\x00\x01\x00\x01\x00\x00\x00\x01\x08examples\x03com\x00\x00\x01\x00\x01\xc0\x0c\x00\x01\x00\x01\x00\x00\x01+\x00\x04\xc7\xbf2I\x00\x00)\x02\x00\x00\x00\x00\x00\x00\x00" + http_payload = b"GET / HTTP/1.1\r\nHost: www.example.com\r\n\r\n" + + value = value.replace(b"__DNS_REQUEST__", dns_payload) + value = value.replace(b"__HTTP_REQUEST__", http_payload) self.layer.payload = Raw(value) @@ -592,3 +599,237 @@ def __init__(self, layer): self.generators = { 'load' : self.gen_load, } + + +class DNSLayer(Layer): + """ + Defines an interface to access DNS header fields. + """ + name = "DNS" + protocol = DNS + _fields = [ + "id", + "qr", + "opcode", + "aa", + "tc", + "rd", + "ra", + "z", + "ad", + "cd", + "qd", + "rcode", + "qdcount", + "ancount", + "nscount", + "arcount" + ] + fields = _fields + def __init__(self, layer): + """ + Initializes the DNS layer. + """ + Layer.__init__(self, layer) + + self.getters = { + "qr" : self.get_bitfield, + "aa" : self.get_bitfield, + "tc" : self.get_bitfield, + "rd" : self.get_bitfield, + "ra" : self.get_bitfield, + "z" : self.get_bitfield, + "ad" : self.get_bitfield, + "cd" : self.get_bitfield + } + + self.setters = { + "qr" : self.set_bitfield, + "aa" : self.set_bitfield, + "tc" : self.set_bitfield, + "rd" : self.set_bitfield, + "ra" : self.set_bitfield, + "z" : self.set_bitfield, + "ad" : self.set_bitfield, + "cd" : self.set_bitfield + } + + self.generators = { + "id" : self.gen_id, + "qr" : self.gen_bitfield, + "opcode" : self.gen_opcode, + "aa" : self.gen_bitfield, + "tc" : self.gen_bitfield, + "rd" : self.gen_bitfield, + "ra" : self.gen_bitfield, + "z" : self.gen_bitfield, + "ad" : self.gen_bitfield, + "cd" : self.gen_bitfield, + "rcode" : self.gen_rcode, + "qdcount" : self.gen_count, + "ancount" : self.gen_count, + "nscount" : self.gen_count, + "arcount" : self.gen_count + } + + def get_bitfield(self, field): + """""" + return int(getattr(self.layer, field)) + + def set_bitfield(self, packet, field, value): + """""" + return setattr(self.layer, field, int(value)) + + def gen_bitfield(self, field): + """""" + return random.choice([0,1]) + + def gen_id(self, field): + return random.randint(0, 65535) + + def gen_opcode(self, field): + return random.randint(0, 15) + + def gen_rcode(self, field): + return random.randint(0, 15) + + def gen_count(self, field): + return random.randint(0, 65535) + + @staticmethod + def dns_decompress(packet, logger): + """ + Performs DNS decompression on the given scapy packet, if applicable. + Note that DNS compression/decompression must be done on the boundaries + of a label, so DNS compression does not support arbitrary offsets. + """ + # If this is a TCP packet + if packet.haslayer("TCP"): + raise NotImplementedError + + # Perform no action if this is not a DNS or DNSRQ packet + if not packet.haslayer("DNS") or not packet.haslayer("DNSQR"): + return packet + + # Extract the query from the DNSQR layer + query = packet["DNSQR"].qname.decode() + if query[len(query) - 1] != '.': + query += '.' + + # Split the query by label + labels = query.split(".") + + # Collect the first and second half of the query + fhalf = labels[0] + shalf = ".".join(labels[1:]) + + # Build the first DNS query directly. The format of this a byte string like this: + # b'\x07minghui\xc0\x1a\x00\x01\x00\x01' + # \x07 = the length of the label in this DNSQR + # minghui = the portion of the domain we will request in the first DNSQR + # \xc0\x1a = offset into the DNS packet where the rest of the query will be. The actual offset + # here is the \x1a - DNS mandates that if compression is used, the first two bits be 11 + # to differentiate them from the rest. \x1A = 26, which is the length of the DNS header + # plus the length of this DNSQR. + # \x00\x01 = type A record + # \x00\x01 = IN + length = bytes([len(fhalf)]) + label = fhalf.encode() + + # Since the domain will include an extra ".", add 1 + # 2 * 6 is the DNS header + # 1 is the byte that determines the length of the label + # len(label) is the length of the label + # 2 is the offset pointer + # 4 - other record information (class, IN) + packet_offset = 2 * 6 + 1 + len(label) + 2 + 2 + 2 + + # The word must start with binary 11, so OR the offset with 0xC000. + offset = (0xc000 | packet_offset).to_bytes(2, byteorder='big') + request = b'\x00\x01\x00\x01' + + dns_qr1 = length + label + offset + request + + # Build the second DNS query directly. The format of the byte string is the same as above + # b'\x02ca\x00\x00\x01\x00\x01' + # \x02 = length of the remaining domain + # ca = portion of the domain in this DNSQR + # \x00 = null byte to signify the end of the query + # \x00\x01 = type A record + # \x00\x01 = IN + # Since the second half could potentially contain many labels, this is done in a list comprehension + dns_qr2 = b"".join([bytes([len(tld)]) + tld.encode() for tld in shalf.split(".")]) + b"\x00\x01\x00\x01" + + # Next, we must rebuild the DNS packet itself. If we try to have scapy parse either dns_qr1 or dns_qr2, they + # will look malformed, since neither contains a complete request. Therefore, we must build the entire + # DNS packet at once. First, we must remove the original DNSQR, since this contains the original request + del packet["DNS"].qd + + # Once the DNSQR is removed, scapy automatically sets the qdcount to 0. Adjust it to 2 + packet["DNS"].qdcount = 2 + + # Extract the DNS header standalone now for building + dns_header = bytes(packet["DNS"]) + + dns_packet = DNS(dns_header + dns_qr1 + dns_qr2) + + del packet["DNS"] + packet = packet / dns_packet + + # Since the size and data of the packet have changed, force scapy to recalculate the important fields + # in below layers, if applicable + if packet.haslayer("IP"): + del packet["IP"].chksum + del packet["IP"].len + if packet.haslayer("UDP"): + del packet["UDP"].chksum + del packet["UDP"].len + + return packet + + +class DNSQRLayer(Layer): + """ + Defines an interface to access DNSQR header fields. + """ + name = "DNSQR" + protocol = DNSQR + _fields = [ + "qname", + "qtype", + "qclass" + ] + fields = _fields + + def __init__(self, layer): + """ + Initializes the DNS layer. + """ + Layer.__init__(self, layer) + self.getters = { + "qname" : self.get_qname + } + self.generators = { + "qname" : self.gen_qname + } + + def get_qname(self, field): + """ + Returns decoded qname from packet. + """ + return self.layer.qname.decode('utf-8') + + def gen_qname(self, field): + """ + Generates domain name. + """ + return "example.com." + + @classmethod + def name_matches(cls, name): + """ + Scapy returns the name of DNSQR as _both_ DNSQR and "DNS Question Record", + which breaks parsing. Override the name_matches method to handle that case + here. + """ + return name.upper() in ["DNSQR", "DNS QUESTION RECORD"] diff --git a/actions/packet.py b/actions/packet.py index ed1c575..f40b6d1 100644 --- a/actions/packet.py +++ b/actions/packet.py @@ -7,7 +7,9 @@ _SUPPORTED_LAYERS = [ actions.layer.IPLayer, actions.layer.TCPLayer, - actions.layer.UDPLayer + actions.layer.UDPLayer, + actions.layer.DNSLayer, + actions.layer.DNSQRLayer ] SUPPORTED_LAYERS = _SUPPORTED_LAYERS @@ -64,9 +66,25 @@ def _str_packet(packet): @staticmethod def _str_load(packet, protocol): """ - Prints packet payload - """ - return str(packet[protocol].payload) + Prints DNS header for now + """ + if packet.haslayer("DNS") and packet.haslayer("DNSQR"): + res = "%s:%s:%s " % ( + packet["DNSQR"].qname.decode('utf8'), + str(packet["DNSQR"].qtype), + str(packet["DNSQR"].qclass)) + DNS_res = "" + for i in range(packet["DNS"].ancount): + dnsrr = packet["DNS"].an[i] + DNS_res += " " + ':'.join([str(dnsrr.rrname.decode('utf8')), + str(dnsrr.type), + str(dnsrr.rclass), + str(dnsrr.ttl), + str(dnsrr.rdlen), + str(dnsrr.rdata)]) + return "%s %s" % (res, DNS_res) + else: + return str(packet[protocol].payload) def __bytes__(self): """ @@ -238,3 +256,77 @@ def get_supported_protocol(protocol): return layer return None + + @staticmethod + def reset_restrictions(): + """ + Removes layer and field restrictions. + """ + global SUPPORTED_LAYERS, _SUPPORTED_LAYERS + + SUPPORTED_LAYERS = _SUPPORTED_LAYERS + for layer in SUPPORTED_LAYERS: + layer.reset_restrictions() + + @staticmethod + def restrict_fields(logger, filter_protocols, filter_fields, disable_fields): + """ + Validates input arguments. Used by evolve.py to restrict the scope + of this evolution. + """ + global SUPPORTED_LAYERS + + if not disable_fields: + disable_fields = [] + + # First, apply a field whitelist if it was requested + valid = [] + if filter_fields: + for layer in SUPPORTED_LAYERS: + new_fields = [] + for field in filter_fields: + if field in layer.fields: + new_fields.append(field) + valid.append(field) + layer.fields = new_fields + + if valid and logger: + logger.info("Strategies will only be allowed to use fields: %s" % ", ".join(list(set(valid)))) + elif logger: + logger.error("None of the given fields exist in the packet headers of given protocols.") + + # Apply a field blacklist if it was requested + for field in disable_fields: + for layer in SUPPORTED_LAYERS: + layer.fields = [f for f in layer.fields if f not in disable_fields] + + if disable_fields and logger: + logger.info("Strategies will not be allowed to use fields %s" % ", ".join(disable_fields)) + + allowed_layers = [] + # Finally, filter protocols + for protocol in filter_protocols: + allowed_layer = Packet.get_supported_protocol(protocol) + if not allowed_layer: + if logger: + logger.error("%s not a supported protocol." % protocol) + continue + + # Only keep the layer allowed if it contains allowed fields + if allowed_layer.fields: + allowed_layers.append(allowed_layer) + + assert allowed_layers, "Cannot evolve with no available packet layers!" + + SUPPORTED_LAYERS = allowed_layers + + if logger and allowed_layers: + logger.info("Strategies will only be allowed to use protocols: %s" % ", ".join([l.name for l in allowed_layers])) + + def dns_decompress(self, logger): + """ + Performs DNS decompression, if applicable. Returns a new packet. + """ + self.packet = actions.layer.DNSLayer.dns_decompress(self.packet, logger) + self.layers = self.setup_layers() + return self diff --git a/actions/sleep.py b/actions/sleep.py index baf33f2..1d3c479 100644 --- a/actions/sleep.py +++ b/actions/sleep.py @@ -2,7 +2,7 @@ class SleepAction(Action): def __init__(self, time=1, environment_id=None): - Action.__init__(self, "sleep", "out") + Action.__init__(self, "sleep", "both") self.terminal = False self.branching = False self.time = time diff --git a/actions/sniffer.py b/actions/sniffer.py deleted file mode 100644 index d4fc543..0000000 --- a/actions/sniffer.py +++ /dev/null @@ -1,85 +0,0 @@ -import threading -import os - -import actions.packet -from scapy.all import sniff -from scapy.utils import PcapWriter - - -class Sniffer(): - """ - The sniffer class lets the user begin and end sniffing whenever in a given location with a port to filter on. - Call start_sniffing to begin sniffing and stop_sniffing to stop sniffing. - """ - - def __init__(self, location, port, logger): - """ - Intializes a sniffer object. - Needs a location and a port to filter on. - """ - self.stop_sniffing_flag = False - self.location = location - self.port = port - self.pcap_thread = None - self.packet_dumper = None - self.logger = logger - full_path = os.path.dirname(location) - assert port, "Need to specify a port in order to launch a sniffer" - if not os.path.exists(full_path): - os.makedirs(full_path) - - def __packet_callback(self, scapy_packet): - """ - This callback is called whenever a packet is applied. - Returns true if it should finish, otherwise, returns false. - """ - packet = actions.packet.Packet(scapy_packet) - for proto in ["TCP", "UDP"]: - if(packet.haslayer(proto) and ((packet[proto].sport == self.port) or (packet[proto].dport == self.port))): - break - else: - return self.stop_sniffing_flag - - self.logger.debug(str(packet)) - self.packet_dumper.write(scapy_packet) - return self.stop_sniffing_flag - - def __spawn_sniffer(self): - """ - Saves pcaps to a file. Should be run as a thread. - Ends when the stop_sniffing_flag is set. Should not be called by user - """ - self.packet_dumper = PcapWriter(self.location, append=True, sync=True) - while(self.stop_sniffing_flag == False): - sniff(stop_filter=self.__packet_callback, timeout=1) - - def start_sniffing(self): - """ - Starts sniffing. Should be called by user. - """ - self.stop_sniffing_flag = False - self.pcap_thread = threading.Thread(target=self.__spawn_sniffer) - self.pcap_thread.start() - self.logger.debug("Sniffer starting to port %d" % self.port) - - def __enter__(self): - """ - Defines a context manager for this sniffer; simply starts sniffing. - """ - self.start_sniffing() - return self - - def __exit__(self, exc_type, exc_value, tb): - """ - Defines exit context manager behavior for this sniffer; simply stops sniffing. - """ - self.stop_sniffing() - - def stop_sniffing(self): - """ - Stops the sniffer by setting the flag and calling join - """ - if(self.pcap_thread): - self.stop_sniffing_flag = True - self.pcap_thread.join() - self.logger.debug("Sniffer stopping") diff --git a/actions/tamper.py b/actions/tamper.py index 9deaf47..480cb33 100644 --- a/actions/tamper.py +++ b/actions/tamper.py @@ -2,17 +2,25 @@ TamperAction One of the four packet-level primitives supported by Geneva. Responsible for any packet-level -modifications (particularly header modifications). It supports replace and corrupt mode - -in replace mode, it changes a packet field to a fixed value; in corrupt mode, it changes a packet -field to a randomly generated value each time it is run. +modifications (particularly header modifications). It supports the following primitives: + - no operation: it returns the packet given + - replace: it changes a packet field to a fixed value + - corrupt: it changes a packet field to a randomly generated value each time it is run + - add: adds a given value to the value in a field + - compress: performs DNS decompression on the packet (if applicable) """ from actions.action import Action import actions.utils +from actions.layer import DNSLayer import random +# All supported tamper primitives +SUPPORTED_PRIMITIVES = ["corrupt", "replace", "add", "compress"] + + class TamperAction(Action): """ Defines the TamperAction for Geneva. @@ -23,10 +31,7 @@ def __init__(self, environment_id=None, field=None, tamper_type=None, tamper_val self.tamper_value = tamper_value self.tamper_proto = actions.utils.string_to_protocol(tamper_proto) self.tamper_proto_str = tamper_proto - self.tamper_type = tamper_type - if not self.tamper_type: - self.tamper_type = random.choice(["corrupt", "replace"]) def tamper(self, packet, logger): """ @@ -41,8 +46,19 @@ def tamper(self, packet, logger): new_value = self.tamper_value # If corrupting the packet field, generate a value for it - if self.tamper_type == "corrupt": - new_value = packet.gen(self.tamper_proto_str, self.field) + try: + if self.tamper_type == "corrupt": + new_value = packet.gen(self.tamper_proto_str, self.field) + elif self.tamper_type == "add": + new_value = int(self.tamper_value) + int(old_value) + elif self.tamper_type == "compress": + return packet.dns_decompress(logger) + except NotImplementedError: + # If a primitive does not support the type of packet given + return packet + except Exception: + # If an unexpected error has occurred + return packet logger.debug(" - Tampering %s field `%s` (%s) by %s (to %s)" % (self.tamper_proto_str, self.field, str(old_value), self.tamper_type, str(new_value))) @@ -67,8 +83,10 @@ def __str__(self): s = Action.__str__(self) if self.tamper_type == "corrupt": s += "{%s:%s:%s}" % (self.tamper_proto_str, self.field, self.tamper_type) - elif self.tamper_type in ["replace"]: + elif self.tamper_type in ["replace", "add"]: s += "{%s:%s:%s:%s}" % (self.tamper_proto_str, self.field, self.tamper_type, self.tamper_value) + elif self.tamper_type == "compress": + s += "{%s:%s:compress}" % ("DNS", "qd", ) return s diff --git a/actions/trace.py b/actions/trace.py index 3d7af87..c1df0d7 100644 --- a/actions/trace.py +++ b/actions/trace.py @@ -65,14 +65,15 @@ def parse(self, string, logger): """ Parses a string representation for this object. """ + if not string: + return False try: - if string: - self.start_ttl, self.end_ttl = string.split(":") - self.start_ttl = int(self.start_ttl) - self.end_ttl = int(self.end_ttl) - if self.start_ttl > self.end_ttl: - logger.error("Cannot use a trace with a start ttl greater than end_ttl (%d > %d)" % (self.start_ttl, self.end_ttl)) - return False + self.start_ttl, self.end_ttl = string.split(":") + self.start_ttl = int(self.start_ttl) + self.end_ttl = int(self.end_ttl) + if self.start_ttl > self.end_ttl: + logger.error("Cannot use a trace with a start ttl greater than end_ttl (%d > %d)" % (self.start_ttl, self.end_ttl)) + return False except ValueError: logger.exception("Cannot parse ttls from given data %s" % string) return False diff --git a/actions/tree.py b/actions/tree.py index 4303fe6..c757bde 100644 --- a/actions/tree.py +++ b/actions/tree.py @@ -30,22 +30,6 @@ def __init__(self, direction, trigger=None): self.environment_id = None self.ran = False - def initialize(self, num_actions, environment_id, allow_terminal=True, disabled=None): - """ - Sets up this action tree with a given number of random actions. - Note that the returned action trees may have less actions than num_actions - if terminal actions are used. - """ - self.environment_id = environment_id - self.trigger = actions.trigger.Trigger(None, None, None, environment_id=environment_id) - if not allow_terminal or random.random() > 0.1: - allow_terminal = False - - for _ in range(num_actions): - new_action = self.get_rand_action(self.direction, disabled=disabled) - self.add_action(new_action) - return self - def __iter__(self): """ Sets up a preoder iterator for the tree. diff --git a/actions/trigger.py b/actions/trigger.py index 5453270..5815b6d 100644 --- a/actions/trigger.py +++ b/actions/trigger.py @@ -27,20 +27,6 @@ def __init__(self, trigger_type, trigger_field, trigger_proto, trigger_value=0, self.bomb_trigger = bool(gas and gas < 0) self.ran = False - @staticmethod - def get_gas(): - """ - Returns a random value for gas for this trigger. - """ - if GAS_ENABLED and random.random() < 0.2: - # Use gas in 20% of scenarios - # Pick a number for gas between 0 - 5 - gas_remaining = int(random.random() * 5) - else: - # Do not use gas - gas_remaining = None - return gas_remaining - def is_applicable(self, packet, logger): """ Checks if this trigger is applicable to a given packet. diff --git a/actions/utils.py b/actions/utils.py index 5520846..a795442 100644 --- a/actions/utils.py +++ b/actions/utils.py @@ -119,7 +119,7 @@ def get_logger(basepath, log_dir, logger_name, log_name, environment_id, log_lev ch = logging.StreamHandler() ch.setFormatter(formatter) ch.setLevel(log_level) - CONSOLE_LOG_LEVEL = log_level + CONSOLE_LOG_LEVEL = ch.level logger.addHandler(ch) return logger @@ -135,34 +135,6 @@ def close_logger(logger): handler.close() -class Logger(): - """ - Logging class context manager, as a thin wrapper around the logging class to help - handle closing open file descriptors. - """ - def __init__(self, log_dir, logger_name, log_name, environment_id, log_level=logging.DEBUG): - self.log_dir = log_dir - self.logger_name = logger_name - self.log_name = log_name - self.environment_id = environment_id - self.log_level = log_level - self.logger = None - - def __enter__(self): - """ - Sets up a logger. - """ - self.logger = get_logger(PROJECT_ROOT, self.log_dir, self.logger_name, self.log_name, self.environment_id, log_level=self.log_level) - return self.logger - - def __exit__(self, exc_type, exc_value, tb): - """ - Closes file handles. - """ - close_logger(self.logger) - - - def get_console_log_level(): """ returns log level of console handler @@ -205,18 +177,6 @@ def setup_dirs(output_dir): return ga_log_dir -def get_from_fuzzed_or_real_packet(environment_id, real_packet_probability, enable_options=True, enable_load=True): - """ - Retrieves a protocol, field, and value from a fuzzed or real packet, depending on - the given probability and if given packets is not None. - """ - packets = actions.utils.read_packets(environment_id) - if packets and random.random() < real_packet_probability: - packet = random.choice(packets) - return packet.get_random() - return actions.packet.Packet().gen_random() - - def get_interface(): """ Chooses an interface on the machine to use for socket testing. diff --git a/engine.py b/engine.py index d3b5879..0f6cf08 100644 --- a/engine.py +++ b/engine.py @@ -67,8 +67,6 @@ def __init__(self, server_port, string_strategy, environment_id=None, output_dir self.server_port = server_port self.seen_packets = [] # Set up the directory and ID for logging - if not output_directory: - output_directory = "trials" actions.utils.setup_dirs(output_directory) if not environment_id: environment_id = actions.utils.get_id() @@ -452,9 +450,8 @@ def in_callback(self, nfpacket): # Run the given strategy packets = self.strategy.act_on_packet(packet, self.logger, direction="in") - # GFW will send RA packets to disrupt a TCP stream + # Censors will often send RA packets to disrupt a TCP stream - record this if packet.haslayer("TCP") and packet.get("TCP", "flags") == "RA": - self.logger.debug("Detected GFW censorship - strategy failed.") self.censorship_detected = True # Branching is disabled for the in direction, so we can only ever get diff --git a/tests/test_engine.py b/tests/test_engine.py new file mode 100644 index 0000000..88584c1 --- /dev/null +++ b/tests/test_engine.py @@ -0,0 +1,72 @@ +import os +import sys + +# Add the path to the engine so we can import it +BASEPATH = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) +sys.path.append(BASEPATH) + +import engine + +def test_engine(): + """ + Basic engine test + """ + # Port to run the engine on + port = 80 + # Strategy to use + strategy = "[TCP:flags:A]-duplicate(tamper{TCP:flags:replace:R}(tamper{TCP:chksum:corrupt},),)-| \/" + + # Create the engine in debug mode + with engine.Engine(port, strategy, log_level="debug") as eng: + os.system("curl http://example.com?q=ultrasurf") + + +def test_engine_sleep(): + """ + Basic engine test with sleep action + """ + # Port to run the engine on + port = 80 + # Strategy to use + strategy = "[TCP:flags:S]-sleep{1}-|" + + # Create the engine in debug mode + with engine.Engine(port, strategy, log_level="info") as eng: + os.system("curl http://example.com?q=ultrasurf") + + # Strategy to use in opposite direction + strategy = "\/ [TCP:flags:SA]-sleep{1}-|" + + # Create the engine in debug mode + with engine.Engine(port, strategy, log_level="debug") as eng: + os.system("curl http://example.com?q=ultrasurf") + + + +def test_engine_trace(): + """ + Basic engine test with trace + """ + # Port to run the engine on + port = 80 + # Strategy to use + strategy = "[TCP:flags:PA]-trace{2:10}-|" + + # Create the engine in debug mode + with engine.Engine(port, strategy, log_level="debug") as eng: + os.system("curl -m 5 http://example.com?q=ultrasurf") + + +def test_engine_drop(): + """ + Basic engine test with drop + """ + # Port to run the engine on + port = 80 + # Strategy to use + strategy = "\/ [TCP:flags:SA]-drop-|" + + # Create the engine in debug mode + with engine.Engine(port, strategy, log_level="debug") as eng: + os.system("curl -m 3 http://example.com?q=ultrasurf") + diff --git a/tests/test_fragment.py b/tests/test_fragment.py new file mode 100644 index 0000000..3e7bcee --- /dev/null +++ b/tests/test_fragment.py @@ -0,0 +1,220 @@ +import logging +import pytest +import sys +# Include the root of the project +sys.path.append("..") + +import actions.fragment +import actions.packet +import actions.strategy +import actions.utils + +from scapy.all import IP, TCP, UDP + +logger = logging.getLogger("test") + + +def test_segment(): + """ + Tests the duplicate action primitive. + """ + fragment = actions.fragment.FragmentAction(correct_order=True) + assert str(fragment) == "fragment{tcp:-1:True}", "Fragment returned incorrect string representation: %s" % str(fragment) + + packet = actions.packet.Packet(IP(src="127.0.0.1", dst="127.0.0.1")/TCP()/("data")) + packet1, packet2 = fragment.run(packet, logger) + + assert id(packet1) != id(packet2), "Duplicate aliased packet objects" + + assert packet1["Raw"].load != packet2["Raw"].load, "Packets were not different" + assert packet1["Raw"].load == b'da', "Left packet incorrectly fragmented" + assert packet2["Raw"].load == b"ta", "Right packet incorrectly fragmented" + + +def test_segment_reverse(): + """ + Tests the duplicate action primitive in reverse! + """ + fragment = actions.fragment.FragmentAction(correct_order=False) + assert str(fragment) == "fragment{tcp:-1:False}", "Fragment returned incorrect string representation: %s" % str(fragment) + + packet = actions.packet.Packet(IP(src="127.0.0.1", dst="127.0.0.1")/TCP()/("data")) + packet1, packet2 = fragment.run(packet, logger) + + assert id(packet1) != id(packet2), "Duplicate aliased packet objects" + + assert packet1["Raw"].load != packet2["Raw"].load, "Packets were not different" + assert packet1["Raw"].load == b'ta', "Left packet incorrectly fragmented" + assert packet2["Raw"].load == b"da", "Right packet incorrectly fragmented" + + +def test_odd_fragment(): + """ + Tests long IP fragmentation + """ + + fragment = actions.fragment.FragmentAction(correct_order=True, segment=False) + assert str(fragment) == "fragment{ip:-1:True}", "Fragment returned incorrect string representation: %s" % str(fragment) + + packet = actions.packet.Packet(IP(src="127.0.0.1", dst="127.0.0.1", proto=0x06)/TCP(sport=2222, dport=3333, seq=100, ack=100, flags="S")/("dataisodd")) + packet1, packet2 = fragment.run(packet, logger) + + assert id(packet1) != id(packet2), "Duplicate aliased packet objects" + + assert str(packet1["Raw"].load) != str(packet2["Raw"].load), "Packets were not different" + assert packet1["Raw"].load == b'\x08\xae\r\x05\x00\x00\x00d', "Left packet incorrectly fragmented" + assert packet2["Raw"].load == b'\x00\x00\x00dP\x02 \x00e\xc1\x00\x00dataisodd', "Right packet incorrectly fragmented" + assert packet1["Raw"].load + packet2["Raw"].load == b'\x08\xae\r\x05\x00\x00\x00d\x00\x00\x00dP\x02 \x00e\xc1\x00\x00dataisodd', "Packets fragmentation was incorrect" + + +def test_custom_fragment(): + """ + Tests IP fragments with custom sized lengths + """ + + fragment = actions.fragment.FragmentAction(correct_order=True, fragsize=3, segment=False) + assert str(fragment) == "fragment{ip:3:True}", "Fragment returned incorrect string representation: %s" % str(fragment) + + packet = actions.packet.Packet(IP(src="127.0.0.1", dst="127.0.0.1", proto=0x06)/TCP(sport=2222, dport=3333, seq=100, ack=100, flags="S")/("thisissomedata")) + packet1, packet2 = fragment.run(packet, logger) + + assert id(packet1) != id(packet2), "Duplicate aliased packet objects" + assert str(packet1["Raw"].load) != str(packet2["Raw"].load), "Packets were not different" + assert packet1["Raw"].load == b'\x08\xae\r\x05\x00\x00\x00d\x00\x00\x00dP\x02 \x00zp\x00\x00this', "Left packet incorrectly fragmented" + assert packet2["Raw"].load == b'issomedata', "Right packet incorrectly fragmented" + assert packet1["Raw"].load + packet2["Raw"].load == b'\x08\xae\r\x05\x00\x00\x00d\x00\x00\x00dP\x02 \x00zp\x00\x00thisissomedata', "Packets fragmentation was incorrect" + + +def test_reverse_fragment(): + """ + Tests fragmentation with reversed packets + """ + + fragment = actions.fragment.FragmentAction(correct_order=False, fragsize=3, segment=False) + assert str(fragment) == "fragment{ip:3:False}", "Fragment returned incorrect string representation: %s" % str(fragment) + + packet = actions.packet.Packet(IP(src="127.0.0.1", dst="127.0.0.1", proto=0x06)/TCP(sport=2222, dport=3333, seq=100, ack=100, flags="S")/("thisissomedata")) + packet1, packet2 = fragment.run(packet, logger) + + assert id(packet1) != id(packet2), "Duplicate aliased packet objects" + assert str(packet1["Raw"].load) != str(packet2["Raw"].load), "Packets were not different" + assert packet2["Raw"].load == b'\x08\xae\r\x05\x00\x00\x00d\x00\x00\x00dP\x02 \x00zp\x00\x00this', "Left packet incorrectly fragmented" + assert packet1["Raw"].load == b'issomedata', "Right packet incorrectly fragmented" + assert packet2["Raw"].load + packet1["Raw"].load == b'\x08\xae\r\x05\x00\x00\x00d\x00\x00\x00dP\x02 \x00zp\x00\x00thisissomedata', "Packets fragmentation was incorrect" + + +def test_udp_fragment(): + """ + Tests fragmentation with reversed packets + """ + + fragment = actions.fragment.FragmentAction(correct_order=False, fragsize=2, segment=False) + assert str(fragment) == "fragment{ip:2:False}", "Fragment returned incorrect string representation: %s" % str(fragment) + + packet = actions.packet.Packet(IP(src="127.0.0.1", dst="127.0.0.1", proto=0x06)/UDP(sport=2222, dport=3333, chksum=0x4444)/("thisissomedata")) + packet1, packet2 = fragment.run(packet, logger) + + assert id(packet1) != id(packet2), "Duplicate aliased packet objects" + assert str(packet1["Raw"].load) != str(packet2["Raw"].load), "Packets were not different" + + +def test_parse(): + """ + Tests parsing. + """ + fragment = actions.fragment.FragmentAction(correct_order=False, fragsize=2, segment=False) + assert str(fragment) == "fragment{ip:2:False}", "Fragment returned incorrect string representation: %s" % str(fragment) + + fragment.parse("fragment{tcp:5:False}", logger) + assert fragment.correct_order == False + assert fragment.fragsize == 5 + assert fragment.segment == True + + with pytest.raises(Exception): + fragment.parse("fragment{tcp:5}", logger) + + with pytest.raises(Exception): + fragment.parse("fragment{tcp:a:True}", logger) + + assert fragment.correct_order == False + assert fragment.fragsize == 5 + assert fragment.segment == True + + fragment = actions.fragment.FragmentAction() + assert fragment.correct_order in [True, False] + packet = actions.packet.Packet(IP(src="127.0.0.1", dst="127.0.0.1")/TCP(sport=2222, dport=3333, seq=100, ack=100, flags="S")) + + strat = actions.utils.parse("[IP:proto:6:0]-tamper{IP:proto:replace:6}(fragment{ip:-1:True}(tamper{TCP:dataofs:replace:8}(duplicate,),tamper{IP:frag:replace:0}),)-| [IP:tos:0:0]-duplicate-| \/", logger) + strat.act_on_packet(packet, logger) + + packet = actions.packet.Packet(IP(src="127.0.0.1", dst="127.0.0.1")/UDP(sport=2222, dport=3333, chksum=0x4444)) + strat = actions.utils.parse("[IP:proto:6:0]-tamper{IP:proto:replace:6}(fragment{ip:-1:True}(tamper{TCP:dataofs:replace:8}(duplicate,),tamper{IP:frag:replace:0}),)-| [IP:tos:0:0]-duplicate-| \/", logger) + strat.act_on_packet(packet, logger) + + packet = actions.packet.Packet(IP(src="127.0.0.1", dst="127.0.0.1")/TCP(sport=2222, dport=3333, chksum=0x4444)) + strat = actions.utils.parse("[TCP:urgptr:0]-tamper{TCP:options-altchksumopt:corrupt}(fragment{tcp:-1:True}(tamper{IP:proto:corrupt},tamper{TCP:seq:replace:654077552}),)-| \/", logger) + strat.act_on_packet(packet, logger) + + strat = actions.utils.parse("[TCP:options-mss:]-tamper{TCP:load:replace:}(fragment{tcp:-1:True},)-| \/", logger) + strat.act_on_packet(packet, logger) + + strat = actions.utils.parse("[TCP:options-mss:]-tamper{IP:frag:replace:1353}(tamper{TCP:load:replace:}(fragment{tcp:-1:True},),)-| \/", logger) + strat.act_on_packet(packet, logger) + + strat = actions.utils.parse("[IP:ihl:5]-duplicate-| [TCP:options-mss:]-tamper{IP:frag:replace:1353}(fragment{tcp:-1:True}(tamper{TCP:load:replace:}(fragment{tcp:-1:False},),tamper{DNSQR:qtype:replace:45416}),)-| \/", logger) + strat.act_on_packet(packet, logger) + + strat = actions.utils.parse("[DNSQR:qclass:25989]-duplicate(duplicate(tamper{DNSQR:qtype:replace:30882},),tamper{UDP:sport:replace:42042})-| [TCP:options-nop:]-tamper{TCP:options-nop:corrupt}(tamper{TCP:load:replace:mjkuskjzgy}(tamper{IP:frag:replace:410}(fragment{tcp:-1:True},),),)-| \/", logger) + strat.act_on_packet(packet, logger) + + +def test_fallback(): + """ + Tests fallback behavior. + """ + fragment = actions.fragment.FragmentAction(correct_order=False, fragsize=2, segment=False) + assert str(fragment) == "fragment{ip:2:False}", "Fragment returned incorrect string representation: %s" % str(fragment) + + fragment.parse("fragment{ip:0:False}", logger) + packet = actions.packet.Packet(IP(src="127.0.0.1", dst="127.0.0.1", proto=0x06)/UDP(sport=2222, dport=3333, chksum=0x4444)/("thisissomedata")) + packet1, packet2 = fragment.run(packet, logger) + assert id(packet1) != id(packet2), "Duplicate aliased packet objects" + assert str(packet1) == str(packet2) + + fragment.parse("fragment{tcp:-1:False}", logger) + packet = actions.packet.Packet(IP(src="127.0.0.1", dst="127.0.0.1", proto=0x06)/UDP(sport=2222, dport=3333, chksum=0x4444)/("thisissomedata")) + packet1, packet2 = fragment.run(packet, logger) + assert id(packet1) != id(packet2), "Duplicate aliased packet objects" + assert str(packet1) == str(packet2) + + fragment.parse("fragment{tcp:-1:False}", logger) + packet = actions.packet.Packet(IP(src="127.0.0.1", dst="127.0.0.1", proto=0x06)/TCP(sport=2222, dport=3333, chksum=0x4444)) + packet1, packet2 = fragment.run(packet, logger) + assert id(packet1) != id(packet2), "Duplicate aliased packet objects" + assert str(packet1) == str(packet2) + + fragment.parse("fragment{ip:-1:False}", logger) + packet = actions.packet.Packet(IP(src="127.0.0.1", dst="127.0.0.1", proto=0x06)) + packet1, packet2 = fragment.run(packet, logger) + assert id(packet1) != id(packet2), "Duplicate aliased packet objects" + assert str(packet1) == str(packet2) + + +def test_ip_only_fragment(): + """ + Tests fragmentation without higher protocols. + """ + + fragment = actions.fragment.FragmentAction(correct_order=True) + fragment.parse("fragment{ip:-1:True}", logger) + + packet = actions.packet.Packet(IP(src="127.0.0.1", dst="127.0.0.1")/("datadata11datadata")) + packet1, packet2 = fragment.run(packet, logger) + + assert id(packet1) != id(packet2), "Duplicate aliased packet objects" + + assert packet1["Raw"].load != packet2["Raw"].load, "Packets were not different" + assert packet1["Raw"].load == b'datadata', "Left packet incorrectly fragmented" + assert packet2["Raw"].load == b"11datadata", "Right packet incorrectly fragmented" + + diff --git a/tests/test_packet.py b/tests/test_packet.py new file mode 100644 index 0000000..63ba0eb --- /dev/null +++ b/tests/test_packet.py @@ -0,0 +1,544 @@ +import logging +import pytest + +import actions.packet +import actions.layer + +from scapy.all import IP, TCP, UDP, DNS, DNSQR, Raw, DNSRR + +logger = logging.getLogger("test") + + +def test_parse_layers(): + """ + Tests layer parsing. + """ + pkt = IP()/TCP()/Raw("") + packet = actions.packet.Packet(pkt) + layers = list(packet.read_layers()) + assert layers[0].name == "IP" + assert layers[1].name == "TCP" + + layers_dict = packet.setup_layers() + assert layers_dict["IP"] + assert layers_dict["TCP"] + + +def test_get_random(): + """ + Tests get random + """ + + tcplayer = actions.layer.TCPLayer(TCP()) + field, value = tcplayer.get_random() + assert field in actions.layer.TCPLayer.fields + + +def test_gen_random(): + """ + Tests gen random + """ + for i in range(0, 2000): + layer, field, value = actions.packet.Packet().gen_random() + assert layer in [DNS, TCP, UDP, IP, DNSQR] + + +def test_dnsqr(): + """ + Tests DNSQR. + """ + pkt = UDP()/DNS(ancount=1)/DNSQR() + pkt.show() + packet = actions.packet.Packet(pkt) + packet.show() + assert len(packet.layers) == 3 + assert "UDP" in packet.layers + assert "DNS" in packet.layers + assert "DNSQR" in packet.layers + pkt = IP()/UDP()/DNS()/DNSQR() + packet = actions.packet.Packet(pkt) + assert str(packet) + + +def test_load(): + """ + Tests loads. + """ + tcp = actions.layer.TCPLayer(TCP()) + assert tcp.gen("load") + pkt = IP()/"datadata" + p = actions.packet.Packet(pkt) + assert p.get("IP", "load") == "datadata" + p2 = actions.packet.Packet(IP(bytes(p))) + assert p2.get("IP", "load") == "datadata" + p2.set("IP", "load", "data2") + # Check p is unchanged + assert p.get("IP", "load") == "datadata" + assert p2.get("IP", "load") == "data2" + p2.show2() + # Check that we can dump + assert p2.show2(dump=True) + # Check that we can dump + assert p2.show(dump=True) + assert p2.get("IP", "chksum") == None + + pkt = IP()/TCP()/"datadata" + p = actions.packet.Packet(pkt) + assert p.get("TCP", "load") == "datadata" + p2 = actions.packet.Packet(IP(bytes(p))) + assert p2.get("TCP", "load") == "datadata" + p2.set("TCP", "load", "data2") + # Check p is unchanged + assert p.get("TCP", "load") == "datadata" + assert p2.get("TCP", "load") == "data2" + p2.show2() + assert p2.get("IP", "chksum") == None + + +def test_parse_load(): + """ + Tests load parsing. + """ + pkt = actions.packet.Packet(IP()/TCP()/"TYPE A\r\n") + print("Parsed: %s" % pkt.get("TCP", "load")) + + strat = actions.utils.parse("[TCP:load:TYPE%20A%0D%0A]-drop-| \/", logger) + results = strat.act_on_packet(pkt, logger) + assert not results + + value = pkt.gen("TCP", "load") + " " + pkt.gen("TCP", "load") + pkt.set("TCP", "load", value) + assert " " not in pkt.get("TCP", "load"), "%s contained a space!" % pkt.get("TCP", "load") + + +def test_dns(): + """ + Tests DNS layer. + """ + dns = actions.layer.DNSLayer(DNS()) + print(dns.gen("id")) + assert dns.gen("id") + + p = actions.packet.Packet(DNS(id=0xabcd)) + p2 = actions.packet.Packet(DNS(bytes(p))) + assert p.get("DNS", "id") == 0xabcd + assert p2.get("DNS", "id") == 0xabcd + + p2.set("DNS", "id", 0x4321) + assert p.get("DNS", "id") == 0xabcd # Check p is unchanged + assert p2.get("DNS", "id") == 0x4321 + + dns = actions.packet.Packet(DNS(aa=1)) + assert dns.get("DNS", "aa") == 1 + aa = dns.gen("DNS", "aa") + assert aa == 0 or aa == 1 + assert dns.get("DNS", "aa") == 1 # Original value unchanged + + dns = actions.packet.Packet(DNS(opcode=15)) + assert dns.get("DNS", "opcode") == 15 + opcode = dns.gen("DNS", "opcode") + assert opcode >= 0 and opcode <= 15 + assert dns.get("DNS", "opcode") == 15 # Original value unchanged + + dns.set("DNS", "opcode", 3) + assert dns.get("DNS", "opcode") == 3 + + dns = actions.packet.Packet(DNS(qr=0)) + assert dns.get("DNS", "qr") == 0 + qr = dns.gen("DNS", "qr") + assert qr == 0 or qr == 1 + assert dns.get("DNS", "qr") == 0 # Original value unchanged + + dns.set("DNS", "qr", 1) + assert dns.get("DNS", "qr") == 1 + + dns = actions.packet.Packet(DNS(arcount=0xAABB)) + assert dns.get("DNS", "arcount") == 0xAABB + arcount = dns.gen("DNS", "arcount") + assert arcount >= 0 and arcount <= 0xffff + assert dns.get("DNS", "arcount") == 0xAABB # Original value unchanged + + dns.set("DNS", "arcount", 65432) + assert dns.get("DNS", "arcount") == 65432 + + dns = actions.layer.DNSLayer(DNS()/DNSQR(qname="example.com")) + assert isinstance(dns.get_next_layer(), DNSQR) + print(dns.gen("id")) + assert dns.gen("id") + + p = actions.packet.Packet(DNS(id=0xabcd)) + p2 = actions.packet.Packet(DNS(bytes(p))) + assert p.get("DNS", "id") == 0xabcd + assert p2.get("DNS", "id") == 0xabcd + + +def test_read_layers(): + """ + Tests the ability to read each layer + """ + packet = IP() / UDP() / TCP() / DNS() / DNSQR(qname="example.com") / DNSQR(qname="example2.com") / DNSQR(qname="example3.com") + packet_geneva = actions.packet.Packet(packet) + packet_geneva.setup_layers() + + i = 0 + for layer in packet_geneva.read_layers(): + if i == 0: + assert isinstance(layer, actions.layer.IPLayer) + elif i == 1: + assert isinstance(layer, actions.layer.UDPLayer) + elif i == 2: + assert isinstance(layer, actions.layer.TCPLayer) + elif i == 3: + assert isinstance(layer, actions.layer.DNSLayer) + elif i == 4: + assert isinstance(layer, actions.layer.DNSQRLayer) + assert layer.layer.qname == b"example.com" + elif i == 5: + assert isinstance(layer, actions.layer.DNSQRLayer) + assert layer.layer.qname == b"example2.com" + elif i == 6: + assert isinstance(layer, actions.layer.DNSQRLayer) + assert layer.layer.qname == b"example3.com" + i += 1 + +def test_multi_opts(): + """ + Tests various option getting/setting. + """ + pkt = IP()/TCP(options=[('MSS', 1460), ('SAckOK', b''), ('Timestamp', (4154603075, 0)), ('NOP', None), ('WScale', 7)]) + packet = actions.packet.Packet(pkt) + assert packet.get("TCP", "options-sackok") == '' + assert packet.get("TCP", "options-mss") == 1460 + assert packet.get("TCP", "options-timestamp") == 4154603075 + assert packet.get("TCP", "options-wscale") == 7 + packet.set("TCP", "options-timestamp", 400000000) + assert packet.get("TCP", "options-sackok") == '' + assert packet.get("TCP", "options-mss") == 1460 + assert packet.get("TCP", "options-timestamp") == 400000000 + assert packet.get("TCP", "options-wscale") == 7 + pkt = IP()/TCP(options=[('SAckOK', b''), ('Timestamp', (4154603075, 0)), ('NOP', None), ('WScale', 7)]) + packet = actions.packet.Packet(pkt) + # If the option isn't present, it will be returned as an empty string + assert packet.get("TCP", "options-mss") == '' + packet.set("TCP", "options-mss", "") + assert packet.get("TCP", "options-mss") == 0 + + +def test_options_eol(): + """ + Tests options-eol. + """ + pkt = TCP(options=[("EOL", None)]) + p = actions.packet.Packet(pkt) + assert p.get("TCP", "options-eol") == "" + p2 = actions.packet.Packet(TCP(bytes(p))) + assert p2.get("TCP", "options-eol") == "" + p = actions.packet.Packet(IP()/TCP(options=[])) + assert p.get("TCP", "options-eol") == "" + p.set("TCP", "options-eol", "") + p.show() + assert len(p["TCP"].options) == 1 + assert any(k == "EOL" for k, v in p["TCP"].options) + value = p.gen("TCP", "options-eol") + assert value == "", "eol cannot store data" + p.set("TCP", "options-eol", value) + p2 = TCP(bytes(p)) + assert any(k == "EOL" for k, v in p2["TCP"].options) + + +def test_options_mss(): + """ + Tests options-eol. + """ + pkt = TCP(options=[("MSS", 1440)]) + p = actions.packet.Packet(pkt) + assert p.get("TCP", "options-mss") == 1440 + p2 = actions.packet.Packet(TCP(bytes(p))) + assert p2.get("TCP", "options-mss") == 1440 + p = actions.packet.Packet(TCP(options=[])) + assert p.get("TCP", "options-mss") == "" + p.set("TCP", "options-mss", 2880) + p.show() + assert len(p["TCP"].options) == 1 + assert any(k == "MSS" for k, v in p["TCP"].options) + value = p.gen("TCP", "options-mss") + p.set("TCP", "options-mss", value) + p2 = TCP(bytes(p)) + assert any(k == "MSS" for k, v in p2["TCP"].options) + + +def check_get(protocol, field, value): + """ + Checks if the get method worked for this protocol, field, and value. + """ + pkt = protocol() + setattr(pkt, field, value) + packet = actions.packet.Packet(pkt) + assert packet.get(protocol.__name__, field) == value + + +def get_test_configs(): + """ + Generates test configurations for the getters. + """ + return [ + (IP, 'version', 4), + (IP, 'version', 6), + (IP, 'version', 0), + (IP, 'ihl', 0), + (IP, 'tos', 0), + (IP, 'len', 50), + (IP, 'len', 6), + (IP, 'flags', 'MF'), + (IP, 'flags', 'DF'), + (IP, 'flags', 'MF+DF'), + (IP, 'ttl', 25), + (IP, 'proto', 4), + (IP, 'chksum', 0x4444), + (IP, 'src', '127.0.0.1'), + (IP, 'dst', '127.0.0.1'), + (TCP, 'sport', 12345), + (TCP, 'dport', 55555), + (TCP, 'seq', 123123123), + (TCP, 'ack', 181818181), + (TCP, 'dataofs', 5), + (TCP, 'dataofs', 0), + (TCP, 'dataofs', 15), + (TCP, 'reserved', 0), + (TCP, 'window', 100), + (TCP, 'chksum', 0x4444), + (TCP, 'urgptr', 1), + + (DNS, 'id', 0xabcd), + (DNS, 'qr', 1), + (DNS, 'opcode', 9), + (DNS, 'aa', 0), + (DNS, 'tc', 1), + (DNS, 'rd', 0), + (DNS, 'ra', 1), + (DNS, 'z', 0), + (DNS, 'ad', 1), + (DNS, 'cd', 0), + (DNS, 'qdcount', 0x1234), + (DNS, 'ancount', 12345), + (DNS, 'nscount', 49870), + (DNS, 'arcount', 0xABCD), + + (DNSQR, 'qname', 'example.com.'), + (DNSQR, 'qtype', 1), + (DNSQR, 'qclass', 0), + ] + + +def get_custom_configs(): + """ + Generates test configurations that can use the custom getters. + """ + return [ + (IP, 'flags', ''), + (TCP, 'options-eol', ''), + (TCP, 'options-nop', ''), + (TCP, 'options-mss', 0), + (TCP, 'options-mss', 1440), + (TCP, 'options-mss', 5000), + (TCP, 'options-wscale', 20), + (TCP, 'options-sackok', ''), + (TCP, 'options-sack', ''), + (TCP, 'options-timestamp', 12345678), + (TCP, 'options-altchksum', 0x44), + (TCP, 'options-altchksumopt', ''), + (TCP, 'options-uto', 1), + #(TCP, 'options-md5header', 'deadc0ffee') + ] + + +@pytest.mark.parametrize("config", get_test_configs(), + ids=['%s-%s-%s' % (proto.__name__, field, str(val)) for proto, field, val in get_test_configs()]) +def test_get(config): + """ + Tests value retrieval. + """ + proto, field, val = config + check_get(proto, field, val) + + +def check_set_get(protocol, field, value): + """ + Checks if the get method worked for this protocol, field, and value. + """ + pkt = actions.packet.Packet(protocol()) + pkt.set(protocol.__name__, field, value) + assert pkt.get(protocol.__name__, field) == value + # Rebuild the packet to confirm the type survived + pkt2 = actions.packet.Packet(protocol(bytes(pkt))) + assert pkt2.get(protocol.__name__, field) == value, "Value %s for header %s didn't survive packet parsing." % (value, field) + + +@pytest.mark.parametrize("config", get_test_configs() + get_custom_configs(), + ids=['%s-%s-%s' % (proto.__name__, field, str(val)) for proto, field, val in get_test_configs() + get_custom_configs()]) +def test_set_get(config): + """ + Tests value retrieval. + """ + proto, field, value = config + check_set_get(proto, field, value) + + +def check_gen_set_get(protocol, field): + """ + Checks if the get method worked for this protocol, field, and value. + """ + pkt = actions.packet.Packet(protocol()) + new_value = pkt.gen(protocol.__name__, field) + pkt.set(protocol.__name__, field, new_value) + assert pkt.get(protocol.__name__, field) == new_value + # Rebuild the packet to confirm the type survived + pkt2 = actions.packet.Packet(protocol(bytes(pkt))) + assert pkt2.get(protocol.__name__, field) == new_value + + +@pytest.mark.parametrize("config", get_test_configs() + get_custom_configs(), + ids=['%s-%s' % (proto.__name__, field) for proto, field, _ in get_test_configs() + get_custom_configs()]) +def test_gen_set_get(config): + """ + Tests value retrieval. + """ + # Test each generator 50 times to hit a range of values + for i in range(0, 50): + proto, field, _ = config + check_gen_set_get(proto, field) + + +def test_custom_get(): + """ + Tests value retrieval for custom getters. + """ + pkt = IP()/TCP()/Raw(load="AAAA") + tcp = actions.packet.Packet(pkt) + assert tcp.get("TCP", "load") == "AAAA" + + +def test_restrict_fields(): + """ + Tests packet field restriction. + """ + actions.packet.SUPPORTED_LAYERS = [ + actions.layer.IPLayer, + actions.layer.TCPLayer, + actions.layer.UDPLayer + ] + tcpfields = actions.layer.TCPLayer.fields + udpfields = actions.layer.UDPLayer.fields + ipfields = actions.layer.IPLayer.fields + + actions.packet.Packet.restrict_fields(logger, ["TCP", "UDP"], [], []) + assert len(actions.packet.SUPPORTED_LAYERS) == 2 + assert actions.layer.TCPLayer in actions.packet.SUPPORTED_LAYERS + assert actions.layer.UDPLayer in actions.packet.SUPPORTED_LAYERS + assert not actions.layer.IPLayer in actions.packet.SUPPORTED_LAYERS + + pkt = IP()/TCP() + packet = actions.packet.Packet(pkt) + assert "TCP" in packet.layers + assert not "IP" in packet.layers + assert len(packet.layers) == 1 + + for i in range(0, 2000): + layer, proto, field = actions.packet.Packet().gen_random() + assert layer in [TCP, UDP] + + # Check we can't retrieve any IP fields + for field in actions.layer.IPLayer.fields: + with pytest.raises(AssertionError): + packet.get("IP", field) + + # Check we can get all the TCP fields + for field in actions.layer.TCPLayer.fields: + packet.get("TCP", field) + + actions.packet.Packet.restrict_fields(logger, ["TCP", "UDP"], ["flags"], []) + packet = actions.packet.Packet(pkt) + assert len(actions.packet.SUPPORTED_LAYERS) == 1 + assert actions.layer.TCPLayer in actions.packet.SUPPORTED_LAYERS + assert not actions.layer.UDPLayer in actions.packet.SUPPORTED_LAYERS + assert not actions.layer.IPLayer in actions.packet.SUPPORTED_LAYERS + assert actions.layer.TCPLayer.fields == ["flags"] + assert not actions.layer.UDPLayer.fields + + # Check we can't retrieve any IP fields + for field in actions.layer.IPLayer.fields: + with pytest.raises(AssertionError): + packet.get("IP", field) + + # Check we can get all the TCP fields + for field in tcpfields: + if field == "flags": + packet.get("TCP", field) + else: + with pytest.raises(AssertionError): + packet.get("TCP", field) + + for i in range(0, 2000): + layer, field, value = actions.packet.Packet().gen_random() + assert layer == TCP + assert field == "flags" + + actions.packet.Packet.reset_restrictions() + actions.packet.SUPPORTED_LAYERS = [ + actions.layer.IPLayer, + actions.layer.TCPLayer, + actions.layer.UDPLayer + ] + actions.packet.Packet.restrict_fields(logger, ["TCP", "IP"], [], ["sport", "dport", "seq", "src"]) + packet = actions.packet.Packet(pkt) + packet = packet.copy() + assert packet.has_supported_layers() + assert len(actions.packet.SUPPORTED_LAYERS) == 2 + assert actions.layer.TCPLayer in actions.packet.SUPPORTED_LAYERS + assert not actions.layer.UDPLayer in actions.packet.SUPPORTED_LAYERS + assert actions.layer.IPLayer in actions.packet.SUPPORTED_LAYERS + assert set(actions.layer.TCPLayer.fields) == set([f for f in tcpfields if f not in ["sport", "dport", "seq"]]) + assert set(actions.layer.IPLayer.fields) == set([f for f in ipfields if f not in ["src"]]) + + # Check we can't retrieve any IP fields + for field in actions.layer.IPLayer.fields: + if field == "src": + with pytest.raises(AssertionError): + packet.get("IP", field) + else: + packet.get("IP", field) + + # Check we can get all the TCP fields + for field in tcpfields: + if field in ["sport", "dport", "seq"]: + with pytest.raises(AssertionError): + packet.get("TCP", field) + else: + packet.get("TCP", field) + + for i in range(0, 2000): + layer, field, value = actions.packet.Packet().gen_random() + assert layer in [TCP, IP] + assert field not in ["sport", "dport", "seq", "src"] + + actions.packet.Packet.reset_restrictions() + actions.packet.SUPPORTED_LAYERS = [ + actions.layer.IPLayer, + actions.layer.TCPLayer, + actions.layer.UDPLayer + ] + + actions.packet.Packet.restrict_fields(logger, ["IP", "UDP", "DNS"], [], ["version"]) + packet = actions.packet.Packet(pkt) + proto, field, value = packet.get_random() + assert proto.__name__ in ["IP", "UDP"] + assert len(actions.packet.SUPPORTED_LAYERS) == 2 + assert not actions.layer.TCPLayer in actions.packet.SUPPORTED_LAYERS + assert actions.layer.UDPLayer in actions.packet.SUPPORTED_LAYERS + assert actions.layer.IPLayer in actions.packet.SUPPORTED_LAYERS + assert set(actions.layer.IPLayer.fields) == set([f for f in ipfields if f not in ["version"]]) + assert set(actions.layer.UDPLayer.fields) == set(udpfields) + + actions.packet.Packet.reset_restrictions() + for layer in actions.packet.SUPPORTED_LAYERS: + assert layer.fields, '%s has no fields - reset failed!' % str(layer) diff --git a/tests/test_strategy.py b/tests/test_strategy.py new file mode 100644 index 0000000..9e129e9 --- /dev/null +++ b/tests/test_strategy.py @@ -0,0 +1,94 @@ +import logging +import pytest + +import actions.tree +import actions.drop +import actions.tamper +import actions.trace +import actions.duplicate +import actions.sleep +import actions.utils +import actions.strategy + +from scapy.all import IP, TCP + +logger = logging.getLogger("test") + + +def test_run(): + """ + Tests strategy execution. + """ + strat1 = actions.utils.parse("[TCP:flags:R]-duplicate-| \/", logger) + strat2 = actions.utils.parse("[TCP:flags:S]-drop-| \/", logger) + strat3 = actions.utils.parse("[TCP:flags:A]-duplicate(tamper{TCP:dataofs:replace:0},)-| \/", logger) + strat4 = actions.utils.parse("[TCP:flags:A]-duplicate(tamper{TCP:flags:replace:R}(tamper{TCP:chksum:replace:15239},),duplicate(tamper{TCP:flags:replace:S}(tamper{TCP:chksum:replace:14539}(tamper{TCP:seq:corrupt},),),))-| \/", logger) + + p1 = actions.packet.Packet(IP(src="127.0.0.1", dst="127.0.0.1")/TCP(sport=2222, dport=3333, seq=100, ack=100, flags="S")) + packets = strat1.act_on_packet(p1, logger, direction="out") + assert packets, "Strategy dropped SYN packets" + assert len(packets) == 1 + assert packets[0]["TCP"].flags == "S" + + p1 = actions.packet.Packet(IP(src="127.0.0.1", dst="127.0.0.1")/TCP(sport=2222, dport=3333, seq=100, ack=100, flags="S")) + packets = strat2.act_on_packet(p1, logger, direction="out") + assert not packets, "Strategy failed to drop SYN packets" + + p1 = actions.packet.Packet(IP(src="127.0.0.1", dst="127.0.0.1")/TCP(sport=2222, dport=3333, seq=100, ack=100, flags="A", dataofs=5)) + packets = strat3.act_on_packet(p1, logger, direction="out") + assert packets, "Strategy dropped packets" + assert len(packets) == 2, "Incorrect number of packets emerged from forest" + assert packets[0]["TCP"].dataofs == 0, "Packet tamper failed" + assert packets[1]["TCP"].dataofs == 5, "Duplicate packet was tampered" + + p1 = actions.packet.Packet(IP(src="127.0.0.1", dst="127.0.0.1")/TCP(sport=2222, dport=3333, seq=100, ack=100, flags="A", dataofs=5, chksum=100)) + packets = strat4.act_on_packet(p1, logger, direction="out") + assert packets, "Strategy dropped packets" + assert len(packets) == 3, "Incorrect number of packets emerged from forest" + assert packets[0]["TCP"].flags == "R", "Packet tamper failed" + assert packets[0]["TCP"].chksum != p1["TCP"].chksum, "Packet tamper failed" + assert packets[1]["TCP"].flags == "S", "Packet tamper failed" + assert packets[1]["TCP"].chksum != p1["TCP"].chksum, "Packet tamper failed" + assert packets[1]["TCP"].seq != p1["TCP"].seq, "Packet tamper failed" + assert packets[2]["TCP"].flags == "A", "Duplicate failed" + + strat4 = actions.utils.parse("[TCP:load:]-tamper{TCP:load:replace:mhe76jm0bd}(fragment{ip:-1:True}(tamper{IP:load:corrupt},drop),)-| \/ ", logger) + p1 = actions.packet.Packet(IP(src="127.0.0.1", dst="127.0.0.1")/TCP(sport=2222, dport=3333, seq=100, ack=100, flags="S")) + packets = strat4.act_on_packet(p1, logger) + + # Will fail with scapy 2.4.2 if packet is reparsed + strat5 = actions.utils.parse("\"[TCP:options-eol:]-tamper{TCP:load:replace:o}(tamper{TCP:dataofs:replace:11},)-| \/\"", logger) + p1 = actions.packet.Packet(IP(src="127.0.0.1", dst="127.0.0.1")/TCP(sport=2222, dport=3333, seq=100, ack=100, flags="S")) + packets = strat5.act_on_packet(p1, logger) + + +def test_pretty_print(): + """ + Tests if the string representation of this strategy is correct + """ + logger = logging.getLogger("test") + strat = actions.utils.parse("[TCP:flags:A]-duplicate(tamper{TCP:flags:replace:R}(tamper{TCP:chksum:corrupt},),)-| \/ ", logger) + correct = "TCP:flags:A\nduplicate\n├── tamper{TCP:flags:replace:R}\n│ └── tamper{TCP:chksum:corrupt}\n│ └── ===> \n└── ===> \n \n \/ \n " + assert strat.pretty_print() == correct + + +def test_sleep_parse_handling(): + """ + Tests that the sleep action handles bad parsing. + """ + + print("Testing incorrect parsing:") + assert not actions.sleep.SleepAction().parse("THISHSOULDFAIL", logger) + + assert actions.sleep.SleepAction().parse("10.5", logger) + + +def test_trace_parse_handling(): + """ + Tests that the sleep action handles bad parsing. + """ + + print("Testing incorrect parsing:") + assert not actions.trace.TraceAction().parse("5:4", logger) + assert not actions.trace.TraceAction().parse("THISHOULDFAIL", logger) + assert not actions.trace.TraceAction().parse("", logger) diff --git a/tests/test_tamper.py b/tests/test_tamper.py new file mode 100644 index 0000000..e5fbeee --- /dev/null +++ b/tests/test_tamper.py @@ -0,0 +1,389 @@ +import copy +import logging +import sys +import pytest +import random +# Include the root of the project +sys.path.append("..") + +import actions.strategy +import actions.packet +import actions.utils +import actions.tamper +import actions.layer + +from scapy.all import IP, TCP, UDP, DNS, DNSQR, sr1 + + +logger = logging.getLogger("test") + + +def test_tamper(): + """ + Tests tampering with replace + """ + packet = actions.packet.Packet(IP(src="127.0.0.1", dst="127.0.0.1")/TCP(sport=2222, dport=3333, seq=100, ack=100, flags="S")) + original = copy.deepcopy(packet) + tamper = actions.tamper.TamperAction(None, field="flags", tamper_type="replace", tamper_value="R") + lpacket, rpacket = tamper.run(packet, logger) + assert not rpacket, "Tamper must not return right child" + assert lpacket, "Tamper must give a left child" + assert id(lpacket) == id(packet), "Tamper must edit in place" + + # Confirm tamper replaced the field it was supposed to + assert packet[TCP].flags == "R", "Tamper did not replace flags." + new_value = packet[TCP].flags + + # Must run this check repeatedly - if a scapy fuzz-ed value is not properly + # ._fix()-ed, it will return different values each time it's requested + for _ in range(0, 5): + assert packet[TCP].flags == new_value, "Replaced value is not stable" + + # Confirm tamper didn't corrupt anything else in the TCP header + assert confirm_unchanged(packet, original, TCP, ["flags"]) + + # Confirm tamper didn't corrupt anything in the IP header + assert confirm_unchanged(packet, original, IP, []) + + +def test_tamper_ip(): + """ + Tests tampering with IP + """ + packet = actions.packet.Packet(IP(src='127.0.0.1', dst='127.0.0.1')/TCP(sport=2222, dport=3333, seq=100, ack=100, flags="S")) + original = copy.deepcopy(packet) + tamper = actions.tamper.TamperAction(None, field="src", tamper_type="replace", tamper_value="192.168.1.1", tamper_proto="IP") + lpacket, rpacket = tamper.run(packet, logger) + assert not rpacket, "Tamper must not return right child" + assert lpacket, "Tamper must give a left child" + assert id(lpacket) == id(packet), "Tamper must edit in place" + + # Confirm tamper replaced the field it was supposed to + assert packet[IP].src == "192.168.1.1", "Tamper did not replace flags." + + # Confirm tamper didn't corrupt anything in the TCP header + assert confirm_unchanged(packet, original, TCP, []) + + # Confirm tamper didn't corrupt anything else in the IP header + assert confirm_unchanged(packet, original, IP, ["src"]) + + +def test_tamper_udp(): + """ + Tests tampering with UDP + """ + packet = actions.packet.Packet(IP(src='127.0.0.1', dst='127.0.0.1')/UDP(sport=2222, dport=53)) + original = copy.deepcopy(packet) + tamper = actions.tamper.TamperAction(None, field="chksum", tamper_type="replace", tamper_value=4444, tamper_proto="UDP") + lpacket, rpacket = tamper.run(packet, logger) + assert not rpacket, "Tamper must not return right child" + assert lpacket, "Tamper must give a left child" + assert id(lpacket) == id(packet), "Tamper must edit in place" + + # Confirm tamper replaced the field it was supposed to + assert packet[UDP].chksum == 4444, "Tamper did not replace flags." + + # Confirm tamper didn't corrupt anything in the TCP header + assert confirm_unchanged(packet, original, UDP, ["chksum"]) + + # Confirm tamper didn't corrupt anything else in the IP header + assert confirm_unchanged(packet, original, IP, []) + + +def test_tamper_ip_ident(): + """ + Tests tampering with IP and that the checksum is correctly changed + """ + + packet = actions.packet.Packet(IP(src='127.0.0.1', dst='127.0.0.1')/TCP(sport=2222, dport=3333, seq=100, ack=100, flags="S")) + original = copy.deepcopy(packet) + tamper = actions.tamper.TamperAction(None, field='id', tamper_type='replace', tamper_value=3333, tamper_proto="IP") + lpacket, rpacket = tamper.run(packet, logger) + assert not rpacket, "Tamper must not return right child" + assert lpacket, "Tamper must give a left child" + assert id(lpacket) == id(packet), "Tamper must edit in place" + + # Confirm tamper replaced the field it was supposed to + assert packet[IP].id == 3333, "Tamper did not replace flags." + + # Confirm tamper didn't corrupt anything in the TCP header + assert confirm_unchanged(packet, original, TCP, []) + + # Confirm tamper didn't corrupt anything else in the IP header + assert confirm_unchanged(packet, original, IP, ["id"]) + + +def confirm_unchanged(packet, original, protocol, changed): + """ + Checks that no other field besides the given array of changed fields + are different between these two packets. + """ + for header in packet.layers: + if packet.layers[header].protocol != protocol: + continue + for field in packet.layers[header].fields: + # Skip checking the field we just changed + if field in changed or field == "load": + continue + assert packet.get(protocol.__name__, field) == original.get(protocol.__name__, field), "Tamper changed %s field %s." % (str(protocol), field) + return True + + +def test_parse_parameters(): + """ + Tests that tamper properly rejects malformed tamper actions + """ + with pytest.raises(Exception): + actions.tamper.TamperAction().parse("this:has:too:many:parameters", logger) + with pytest.raises(Exception): + actions.tamper.TamperAction().parse("not:enough", logger) + + +def test_corrupt(): + """ + Tests the tamper 'corrupt' primitive. + """ + tamper = actions.tamper.TamperAction(None, field="flags", tamper_type="corrupt", tamper_value="R") + assert tamper.field == "flags", "Tamper action changed fields." + assert tamper.tamper_type == "corrupt", "Tamper action changed types." + assert str(tamper) == "tamper{TCP:flags:corrupt}", "Tamper returned incorrect string representation: %s" % str(tamper) + + packet = actions.packet.Packet(IP(src="127.0.0.1", dst="127.0.0.1")/TCP(sport=2222, dport=3333, seq=100, ack=100, flags="S")) + original = copy.deepcopy(packet) + tamper.tamper(packet, logger) + + new_value = packet[TCP].flags + + # Must run this check repeatedly - if a scapy fuzz-ed value is not properly + # ._fix()-ed, it will return different values each time it's requested + for _ in range(0, 5): + assert packet[TCP].flags == new_value, "Corrupted value is not stable" + + # Confirm tamper didn't corrupt anything else in the TCP header + assert confirm_unchanged(packet, original, TCP, ["flags"]) + + # Confirm tamper didn't corrupt anything else in the IP header + assert confirm_unchanged(packet, original, IP, []) + + +def test_add(): + """ + Tests the tamper 'add' primitive. + """ + tamper = actions.tamper.TamperAction(None, field="seq", tamper_type="add", tamper_value=10) + assert tamper.field == "seq", "Tamper action changed fields." + assert tamper.tamper_type == "add", "Tamper action changed types." + assert str(tamper) == "tamper{TCP:seq:add:10}", "Tamper returned incorrect string representation: %s" % str(tamper) + + packet = actions.packet.Packet(IP(src="127.0.0.1", dst="127.0.0.1")/TCP(sport=2222, dport=3333, seq=100, ack=100, flags="S")) + original = copy.deepcopy(packet) + tamper.tamper(packet, logger) + + new_value = packet[TCP].seq + assert new_value == 110, "Tamper did not add" + + # Must run this check repeatedly - if a scapy fuzz-ed value is not properly + # ._fix()-ed, it will return different values each time it's requested + for _ in range(0, 5): + assert packet[TCP].seq == new_value, "Corrupted value is not stable" + + # Confirm tamper didn't corrupt anything else in the TCP header + assert confirm_unchanged(packet, original, TCP, ["seq"]) + + # Confirm tamper didn't corrupt anything else in the IP header + assert confirm_unchanged(packet, original, IP, []) + + +def test_decompress(): + """ + Tests the tamper 'decompress' primitive. + """ + tamper = actions.tamper.TamperAction(None, field="qd", tamper_type="compress", tamper_value=10, tamper_proto="DNS") + assert tamper.field == "qd", "Tamper action changed fields." + assert tamper.tamper_type == "compress", "Tamper action changed types." + assert str(tamper) == "tamper{DNS:qd:compress}", "Tamper returned incorrect string representation: %s" % str(tamper) + + packet = actions.packet.Packet(IP(dst="8.8.8.8")/UDP(dport=53)/DNS(qd=DNSQR(qname="minghui.ca."))) + original = packet.copy() + tamper.tamper(packet, logger) + assert bytes(packet["DNS"]) == b'\x00\x00\x01\x00\x00\x02\x00\x00\x00\x00\x00\x00\x07minghui\xc0\x1a\x00\x01\x00\x01\x02ca\x00\x00\x01\x00\x01' + resp = sr1(packet.packet) + assert resp["DNS"] + assert resp["DNS"].rcode != 1 + assert resp["DNSQR"] + assert resp["DNSRR"].rdata + assert confirm_unchanged(packet, original, IP, ["len"]) + print(resp.summary()) + + packet = actions.packet.Packet(IP(dst="8.8.8.8")/UDP(dport=53)/DNS(qd=DNSQR(qname="maps.google.com"))) + original = packet.copy() + tamper.tamper(packet, logger) + assert bytes(packet["DNS"]) == b'\x00\x00\x01\x00\x00\x02\x00\x00\x00\x00\x00\x00\x04maps\xc0\x17\x00\x01\x00\x01\x06google\x03com\x00\x00\x01\x00\x01' + resp = sr1(packet.packet) + assert resp["DNS"] + assert resp["DNS"].rcode != 1 + assert resp["DNSQR"] + assert resp["DNSRR"].rdata + assert confirm_unchanged(packet, original, IP, ["len"]) + print(resp.summary()) + + # Confirm this is a NOP on normal packets + packet = actions.packet.Packet(IP()/UDP()) + original = packet.copy() + tamper.tamper(packet, logger) + assert packet.packet.summary() == original.packet.summary() + + # Confirm tamper didn't corrupt anything else in the TCP header + assert confirm_unchanged(packet, original, UDP, []) + + # Confirm tamper didn't corrupt anything else in the IP header + assert confirm_unchanged(packet, original, IP, []) + + packet = actions.packet.Packet(IP(dst="8.8.8.8")/TCP(dport=53)/DNS(qd=DNSQR(qname="maps.google.com"))) + original = packet.copy() + tamper.tamper(packet, logger) + assert bytes(packet) == bytes(original) + + + +def test_corrupt_chksum(): + """ + Tests the tamper 'replace' primitive. + """ + tamper = actions.tamper.TamperAction(None, field="chksum", tamper_type="corrupt", tamper_value="R") + assert tamper.field == "chksum", "Tamper action changed checksum." + assert tamper.tamper_type == "corrupt", "Tamper action changed types." + assert str(tamper) == "tamper{TCP:chksum:corrupt}", "Tamper returned incorrect string representation: %s" % str(tamper) + + packet = actions.packet.Packet(IP(src="127.0.0.1", dst="127.0.0.1")/TCP(sport=2222, dport=3333, seq=100, ack=100, flags="S")) + original = copy.deepcopy(packet) + tamper.tamper(packet, logger) + + # Confirm tamper actually corrupted the checksum + assert packet[TCP].chksum != 0 + new_value = packet[TCP].chksum + + # Must run this check repeatedly - if a scapy fuzz-ed value is not properly + # ._fix()-ed, it will return different values each time it's requested + for _ in range(0, 5): + assert packet[TCP].chksum == new_value, "Corrupted value is not stable" + + # Confirm tamper didn't corrupt anything else in the TCP header + assert confirm_unchanged(packet, original, TCP, ["chksum"]) + + # Confirm tamper didn't corrupt anything else in the IP header + assert confirm_unchanged(packet, original, IP, []) + + +def test_corrupt_dataofs(): + """ + Tests the tamper 'replace' primitive. + """ + packet = actions.packet.Packet(IP(src="127.0.0.1", dst="127.0.0.1")/TCP(sport=2222, dport=3333, seq=100, ack=100, flags="S", dataofs="6L")) + original = copy.deepcopy(packet) + tamper = actions.tamper.TamperAction(None, field="dataofs", tamper_type="corrupt") + + tamper.tamper(packet, logger) + + # Confirm tamper actually corrupted the checksum + assert packet[TCP].dataofs != "0" + new_value = packet[TCP].dataofs + + # Must run this check repeatedly - if a scapy fuzz-ed value is not properly + # ._fix()-ed, it will return different values each time it's requested + for _ in range(0, 5): + assert packet[TCP].dataofs == new_value, "Corrupted value is not stable" + + # Confirm tamper didn't corrupt anything else in the TCP header + assert confirm_unchanged(packet, original, TCP, ["dataofs"]) + + # Confirm tamper didn't corrupt anything in the IP header + assert confirm_unchanged(packet, original, IP, []) + + +def test_replace(): + """ + Tests the tamper 'replace' primitive. + """ + tamper = actions.tamper.TamperAction(None, field="flags", tamper_type="replace", tamper_value="R") + + assert tamper.field == "flags", "Tamper action changed fields." + assert tamper.tamper_type == "replace", "Tamper action changed types." + + packet = actions.packet.Packet(IP(src="127.0.0.1", dst="127.0.0.1")/TCP(sport=2222, dport=3333, seq=100, ack=100, flags="S")) + original = copy.deepcopy(packet) + tamper.tamper(packet, logger) + + # Confirm tamper replaced the field it was supposed to + assert packet[TCP].flags == "R", "Tamper did not replace flags." + # Confirm tamper didn't replace anything else in the TCP header + assert confirm_unchanged(packet, original, TCP, ["flags"]) + + # Confirm tamper didn't replace anything else in the IP header + assert confirm_unchanged(packet, original, IP, []) + + # chksums must be handled specially by tamper, so run a second check on this value + tamper.field = "chksum" + tamper.tamper_value = 0x4444 + original = copy.deepcopy(packet) + tamper.tamper(packet, logger) + assert packet[TCP].chksum == 0x4444, "Tamper failed to change chksum." + # Confirm tamper didn't replace anything else in the TCP header + assert confirm_unchanged(packet, original, TCP, ["chksum"]) + # Confirm tamper didn't replace anything else in the IP header + assert confirm_unchanged(packet, original, IP, []) + + +def test_parse_flags(): + """ + Tests the tamper 'replace' primitive. + """ + tamper = actions.tamper.TamperAction(None, field="flags", tamper_type="replace", tamper_value="FRAPUN") + assert tamper.field == "flags", "Tamper action changed checksum." + assert tamper.tamper_type == "replace", "Tamper action changed types." + assert str(tamper) == "tamper{TCP:flags:replace:FRAPUN}", "Tamper returned incorrect string representation: %s" % str(tamper) + + packet = actions.packet.Packet(IP(src="127.0.0.1", dst="127.0.0.1")/TCP(sport=2222, dport=3333, seq=100, ack=100, flags="S")) + tamper.tamper(packet, logger) + assert packet[TCP].flags == "FRAPUN", "Tamper failed to change flags." + + +@pytest.mark.parametrize("test_type", ["parsed", "direct"]) +@pytest.mark.parametrize("value", ["EOL", "NOP", "Timestamp", "MSS", "WScale", "SAckOK", "SAck", "Timestamp", "AltChkSum", "AltChkSumOpt", "UTO"]) +def test_options(value, test_type): + """ + Tests tampering options + """ + if test_type == "direct": + tamper = actions.tamper.TamperAction(None, field="options-%s" % value.lower(), tamper_type="corrupt", tamper_value=bytes([12])) + else: + tamper = actions.tamper.TamperAction(None) + assert tamper.parse("TCP:options-%s:replace:" % value.lower(), logger) + assert tamper.parse("TCP:options-%s:corrupt" % value.lower(), logger) + + packet = actions.packet.Packet(IP(src="127.0.0.1", dst="127.0.0.1")/TCP(sport=2222, dport=3333, seq=100, ack=100, flags="S")) + tamper.run(packet, logger) + opts_dict_lookup = value.lower().replace(" ", "_") + + for optname, optval in packet["TCP"].options: + if optname == value: + break + elif optname == actions.layer.TCPLayer.options_names[opts_dict_lookup]: + break + else: + pytest.fail("Failed to find %s in options" % value) + assert len(packet["TCP"].options) == 1 + raw_p = bytes(packet) + assert raw_p, "options broke scapy bytes" + p2 = actions.packet.Packet(IP(bytes(raw_p))) + assert p2.haslayer("IP") + assert p2.haslayer("TCP") + # EOLs might be added for padding, so just check >= 1 + assert len(p2["TCP"].options) >= 1 + for optname, optval in p2["TCP"].options: + if optname == value: + break + elif optname == actions.layer.TCPLayer.options_names[opts_dict_lookup]: + break + else: + pytest.fail("Failed to find %s in options" % value) diff --git a/tests/test_tree.py b/tests/test_tree.py new file mode 100644 index 0000000..4c8e74b --- /dev/null +++ b/tests/test_tree.py @@ -0,0 +1,549 @@ +import logging +import os + +from scapy.all import IP, TCP +import actions.tree +import actions.drop +import actions.tamper +import actions.duplicate +import actions.utils + + +def test_init(): + """ + Tests initialization + """ + print(actions.action.Action.get_actions("out")) + + +def test_count_leaves(): + """ + Tests leaf count is correct. + """ + a = actions.tree.ActionTree("out") + logger = logging.getLogger("test") + + assert not a.parse("TCP:reserved:0tamper{TCP:flags:replace:S}-|", logger), "Tree parsed malformed DNA" + a.parse("[TCP:reserved:0]-tamper{TCP:flags:replace:S}-|", logger) + duplicate = actions.duplicate.DuplicateAction() + duplicate2 = actions.duplicate.DuplicateAction() + drop = actions.drop.DropAction() + + assert a.count_leaves() == 1 + assert a.remove_one() + a.add_action(duplicate) + assert a.count_leaves() == 1 + duplicate.left = duplicate2 + assert a.count_leaves() == 1 + duplicate.right = drop + assert a.count_leaves() == 2 + + +def test_check(): + """ + Tests action tree check function. + """ + a = actions.tree.ActionTree("out") + logger = logging.getLogger("test") + a.parse("[TCP:flags:RA]-tamper{TCP:flags:replace:S}-|", logger) + p = actions.packet.Packet(IP()/TCP(flags="A")) + assert not a.check(p, logger) + p = actions.packet.Packet(IP(ttl=64)/TCP(flags="RA")) + assert a.check(p, logger) + assert a.remove_one() + assert a.check(p, logger) + a.parse("[TCP:reserved:0]-tamper{TCP:flags:replace:S}-|", logger) + assert a.check(p, logger) + a.parse("[IP:ttl:64]-tamper{TCP:flags:replace:S}-|", logger) + assert a.check(p, logger) + p = actions.packet.Packet(IP(ttl=15)/TCP(flags="RA")) + assert not a.check(p, logger) + + +def test_scapy(): + """ + Tests misc. scapy aspects relevant to strategies. + """ + a = actions.tree.ActionTree("out") + logger = logging.getLogger("test") + a.parse("[TCP:reserved:0]-tamper{TCP:flags:replace:S}-|", logger) + p = actions.packet.Packet(IP()/TCP(flags="A")) + assert a.check(p, logger) + packets = a.run(p, logger) + assert packets[0][TCP].flags == "S" + p = actions.packet.Packet(IP()/TCP(flags="A")) + assert a.check(p, logger) + a.parse("[TCP:reserved:0]-tamper{TCP:chksum:corrupt}-|", logger) + packets = a.run(p, logger) + assert packets[0][TCP].chksum + assert a.check(p, logger) + + +def test_str(): + """ + Tests string representation. + """ + logger = logging.getLogger("test") + + t = actions.trigger.Trigger("field", "flags", "TCP") + a = actions.tree.ActionTree("out", trigger=t) + assert str(a).strip() == "[%s]-|" % str(t) + tamper = actions.tamper.TamperAction(field="flags", tamper_type="replace", tamper_value="S") + tamper2 = actions.tamper.TamperAction(field="flags", tamper_type="replace", tamper_value="R") + assert a.add_action(tamper) + assert str(a).strip() == "[TCP:flags:0]-tamper{TCP:flags:replace:S}-|" + # Tree will not add a duplicate action + assert not a.add_action(tamper) + assert str(a).strip() == "[TCP:flags:0]-tamper{TCP:flags:replace:S}-|" + assert a.add_action(tamper2) + assert str(a).strip() == "[TCP:flags:0]-tamper{TCP:flags:replace:S}(tamper{TCP:flags:replace:R},)-|" + assert a.add_action(actions.duplicate.DuplicateAction()) + assert str(a).strip() == "[TCP:flags:0]-tamper{TCP:flags:replace:S}(tamper{TCP:flags:replace:R}(duplicate,),)-|" + drop = actions.drop.DropAction() + assert a.add_action(drop) + assert str(a).strip() == "[TCP:flags:0]-tamper{TCP:flags:replace:S}(tamper{TCP:flags:replace:R}(duplicate(drop,),),)-|" or \ + str(a).strip() == "[TCP:flags:0]-tamper{TCP:flags:replace:S}(tamper{TCP:flags:replace:R}(duplicate(,drop),),)-|" + assert a.remove_action(drop) + assert str(a).strip() == "[TCP:flags:0]-tamper{TCP:flags:replace:S}(tamper{TCP:flags:replace:R}(duplicate,),)-|" + # Cannot remove action that is not present + assert not a.remove_action(drop) + assert str(a).strip() == "[TCP:flags:0]-tamper{TCP:flags:replace:S}(tamper{TCP:flags:replace:R}(duplicate,),)-|" + + a = actions.tree.ActionTree("out", trigger=t) + orig = "[TCP:urgptr:15963]-duplicate(,drop)-|" + a.parse(orig, logger) + assert a.remove_one() + assert orig != str(a) + assert str(a) in ["[TCP:urgptr:15963]-drop-|", "[TCP:urgptr:15963]-duplicate-|"] + + +def test_pretty_print_send(): + t = actions.trigger.Trigger("field", "flags", "TCP") + a = actions.tree.ActionTree("out", trigger=t) + duplicate = actions.duplicate.DuplicateAction() + a.add_action(duplicate) + correct_string = "TCP:flags:0\nduplicate\n├── ===> \n└── ===> " + assert a.pretty_print() == correct_string + + +def test_pretty_print(): + """ + Print complex tree, although difficult to test + """ + t = actions.trigger.Trigger("field", "flags", "TCP") + a = actions.tree.ActionTree("out", trigger=t) + tamper = actions.tamper.TamperAction(field="flags", tamper_type="replace", tamper_value="S") + tamper2 = actions.tamper.TamperAction(field="flags", tamper_type="replace", tamper_value="R") + duplicate = actions.duplicate.DuplicateAction() + duplicate2 = actions.duplicate.DuplicateAction() + duplicate3 = actions.duplicate.DuplicateAction() + duplicate4 = actions.duplicate.DuplicateAction() + duplicate5 = actions.duplicate.DuplicateAction() + drop = actions.drop.DropAction() + drop2 = actions.drop.DropAction() + drop3 = actions.drop.DropAction() + drop4 = actions.drop.DropAction() + + duplicate.left = duplicate2 + duplicate.right = duplicate3 + duplicate2.left = tamper + duplicate2.right = drop + duplicate3.left = duplicate4 + duplicate3.right = drop2 + duplicate4.left = duplicate5 + duplicate4.right = drop3 + duplicate5.left = drop4 + duplicate5.right = tamper2 + + a.add_action(duplicate) + correct_string = "TCP:flags:0\nduplicate\n├── duplicate\n│ ├── tamper{TCP:flags:replace:S}\n│ │ └── ===> \n│ └── drop\n└── duplicate\n ├── duplicate\n │ ├── duplicate\n │ │ ├── drop\n │ │ └── tamper{TCP:flags:replace:R}\n │ │ └── ===> \n │ └── drop\n └── drop" + assert a.pretty_print() == correct_string + assert a.pretty_print(visual=True) + assert os.path.exists("tree.png") + os.remove("tree.png") + a.parse("[TCP:flags:0]-|", logging.getLogger("test")) + a.pretty_print(visual=True) # Empty action tree + assert not os.path.exists("tree.png") + +def test_pretty_print_order(): + """ + Tests the left/right ordering by reading in a new tree + """ + logger = logging.getLogger("test") + a = actions.tree.ActionTree("out") + assert a.parse("[TCP:flags:A]-duplicate(tamper{TCP:flags:replace:R}(tamper{TCP:chksum:replace:14239},),duplicate(tamper{TCP:flags:replace:S}(tamper{TCP:chksum:replace:14239},),))-|", logger) + correct_pretty_print = "TCP:flags:A\nduplicate\n├── tamper{TCP:flags:replace:R}\n│ └── tamper{TCP:chksum:replace:14239}\n│ └── ===> \n└── duplicate\n ├── tamper{TCP:flags:replace:S}\n │ └── tamper{TCP:chksum:replace:14239}\n │ └── ===> \n └── ===> " + assert a.pretty_print() == correct_pretty_print + +def test_parse(): + """ + Tests string parsing. + """ + logger = logging.getLogger("test") + t = actions.trigger.Trigger("field", "flags", "TCP") + a = actions.tree.ActionTree("out", trigger=t) + + base_t = actions.trigger.Trigger("field", "flags", "TCP") + base_a = actions.tree.ActionTree("out", trigger=base_t) + tamper = actions.tamper.TamperAction(field="flags", tamper_type="replace", tamper_value="S") + tamper2 = actions.tamper.TamperAction(field="flags", tamper_type="replace", tamper_value="R") + tamper3 = actions.tamper.TamperAction(field="flags", tamper_type="replace", tamper_value="S") + tamper4 = actions.tamper.TamperAction(field="flags", tamper_type="replace", tamper_value="R") + a.parse("[TCP:flags:0]-|", logger) + assert str(a) == str(base_a) + assert len(a) == 0 + + base_a.add_action(tamper) + + assert a.parse("[TCP:flags:0]-tamper{TCP:flags:replace:S}-|", logger) + + assert str(a) == str(base_a) + assert len(a) == 1 + assert a.parse("[TCP:flags:0]-tamper{TCP:flags:replace:S}(tamper{TCP:flags:replace:R},)-|", logging.getLogger("test")) + base_a.add_action(tamper2) + assert str(a) == str(base_a) + assert len(a) == 2 + + base_a.add_action(tamper3) + base_a.add_action(tamper4) + assert a.parse("[TCP:flags:0]-tamper{TCP:flags:replace:S}(tamper{TCP:flags:replace:R}(tamper{TCP:flags:replace:S}(tamper{TCP:flags:replace:R},),),)-|", logging.getLogger("test")) + assert str(a) == str(base_a) + assert len(a) == 4 + + base_t = actions.trigger.Trigger("field", "flags", "TCP") + base_a = actions.tree.ActionTree("out", trigger=base_t) + duplicate = actions.duplicate.DuplicateAction() + assert a.parse("[TCP:flags:0]-duplicate-|", logger) + base_a.add_action(duplicate) + assert str(a) == str(base_a) + tamper = actions.tamper.TamperAction(field="flags", tamper_type="replace", tamper_value="S") + tamper2 = actions.tamper.TamperAction(field="flags", tamper_type="replace", tamper_value="R") + tamper3 = actions.tamper.TamperAction(field="flags", tamper_type="replace", tamper_value="A") + tamper4 = actions.tamper.TamperAction(field="flags", tamper_type="replace", tamper_value="R") + duplicate.left = tamper + assert a.parse("[TCP:flags:0]-duplicate(tamper{TCP:flags:replace:S},)-|", logger) + assert str(a) == str(base_a) + + duplicate.right = tamper2 + assert a.parse("[TCP:flags:0]-duplicate(tamper{TCP:flags:replace:S},tamper{TCP:flags:replace:R})-|", logger) + assert str(a) == str(base_a) + + tamper2.left = tamper3 + assert a.parse("[TCP:flags:0]-duplicate(tamper{TCP:flags:replace:S},tamper{TCP:flags:replace:R}(tamper{TCP:flags:replace:A},))-|", logger) + assert str(a) == str(base_a) + + strategy = actions.utils.parse("[TCP:flags:0]-duplicate(tamper{TCP:flags:replace:S},tamper{TCP:flags:replace:R})-| \/", logger) + assert strategy + assert len(strategy.out_actions[0]) == 3 + assert len(strategy.in_actions) == 0 + + assert not a.parse("[]", logger) # No valid trigger + assert not a.parse("[TCP:flags:0]-", logger) # No valid ending "|" + assert not a.parse("[TCP:]-|", logger) # invalid trigger + assert not a.parse("[TCP:flags:0]-foo-|", logger) # Non-existent action + assert not a.parse("[TCP:flags:0]--|", logger) # Empty action + assert not a.parse("[TCP:flags:0]-duplicate(,,,)-|", logger) # Bad tree + assert not a.parse("[TCP:flags:0]-duplicate()))-|", logger) # Bad tree + assert not a.parse("[TCP:flags:0]-duplicate(((()-|", logger) # Bad tree + assert not a.parse("[TCP:flags:0]-duplicate(,))))-|", logger) # Bad tree + assert not a.parse("[TCP:flags:0]-drop(duplicate,)-|", logger) # Terminal action with children + assert not a.parse("[TCP:flags:0]-drop(duplicate,duplicate)-|", logger) # Terminal action with children + assert not a.parse("[TCP:flags:0]-tamper{TCP:flags:replace:S}(,duplicate)-|", logger) # Non-branching action with right child + assert not a.parse("[TCP:flags:0]-tamper{TCP:flags:replace:S}(drop,duplicate)-|", logger) # Non-branching action with children + + +def test_tree(): + """ + Tests basic tree functionality. + """ + t = actions.trigger.Trigger(None, None, None) + a = actions.tree.ActionTree("out", trigger=t) + tamper = actions.tamper.TamperAction() + tamper2 = actions.tamper.TamperAction() + duplicate = actions.duplicate.DuplicateAction() + assert a.get_parent(None) == (None, None) + + a.add_action(None) + a.add_action(tamper) + assert a.get_slots() == 1 + a.add_action(tamper2) + assert a.get_parent(tamper2) == (tamper, "left") + assert a.get_slots() == 1 + a.add_action(duplicate) + assert a.get_slots() == 2 + + t = actions.trigger.Trigger(None, None, None) + a = actions.tree.ActionTree("out", trigger=t) + drop = actions.drop.DropAction() + a.add_action(drop) + assert a.get_parent(drop) == (None, None) + assert a.get_slots() == 0 + add_success = a.add_action(tamper) + assert not add_success + assert a.get_slots() == 0 + + rep = "" + for s in a.string_repr(a.action_root): + rep += s + assert rep == "drop" + + print(str(a)) + + assert a.parse("[TCP:flags:A]-duplicate(tamper{TCP:seq:corrupt},)-|", logging.getLogger("test")) + for act in a: + print(str(a)) + assert len(a) == 2 + assert a.get_slots() == 2 + + +def test_remove(): + """ + Tests remove + """ + t = actions.trigger.Trigger(None, None, None) + a = actions.tree.ActionTree("out", trigger=t) + tamper = actions.tamper.TamperAction() + tamper2 = actions.tamper.TamperAction() + tamper3 = actions.tamper.TamperAction() + assert not a.remove_action(tamper) + a.add_action(tamper) + assert a.remove_action(tamper) + a.add_action(tamper) + a.add_action(tamper2) + a.add_action(tamper3) + assert a.remove_action(tamper2) + assert tamper2 not in a + assert tamper.left == tamper3 + assert not tamper.right + assert len(a) == 2 + a = actions.tree.ActionTree("out", trigger=t) + duplicate = actions.duplicate.DuplicateAction() + tamper = actions.tamper.TamperAction() + tamper2 = actions.tamper.TamperAction() + tamper3 = actions.tamper.TamperAction() + a.add_action(tamper) + assert a.action_root == tamper + duplicate.left = tamper2 + duplicate.right = tamper3 + a.add_action(duplicate) + assert a.get_parent(tamper3) == (duplicate, "right") + assert len(a) == 4 + assert a.remove_action(duplicate) + assert duplicate not in a + assert tamper.left == tamper2 + assert not tamper.right + assert len(a) == 2 + + a.parse("[TCP:flags:A]-|", logging.getLogger("test")) + assert not a.remove_one(), "Cannot remove one with no action root" + + +def test_len(): + """ + Tests length calculation. + """ + t = actions.trigger.Trigger(None, None, None) + a = actions.tree.ActionTree("out", trigger=t) + tamper = actions.tamper.TamperAction() + tamper2 = actions.tamper.TamperAction() + assert len(a) == 0, "__len__ returned wrong length" + a.add_action(tamper) + assert len(a) == 1, "__len__ returned wrong length" + a.add_action(tamper) + assert len(a) == 1, "__len__ returned wrong length" + a.add_action(tamper2) + assert len(a) == 2, "__len__ returned wrong length" + duplicate = actions.duplicate.DuplicateAction() + a.add_action(duplicate) + assert len(a) == 3, "__len__ returned wrong length" + + +def test_contains(): + """ + Tests contains method + """ + t = actions.trigger.Trigger(None, None, None) + a = actions.tree.ActionTree("out", trigger=t) + tamper = actions.tamper.TamperAction() + tamper2 = actions.tamper.TamperAction() + tamper3 = actions.tamper.TamperAction() + + assert not a.contains(tamper), "contains incorrect behavior" + assert not a.contains(tamper2), "contains incorrect behavior" + a.add_action(tamper) + assert a.contains(tamper), "contains incorrect behavior" + assert not a.contains(tamper2), "contains incorrect behavior" + add_success = a.add_action(tamper) + assert not add_success, "added duplicate action" + assert a.contains(tamper), "contains incorrect behavior" + assert not a.contains(tamper2), "contains incorrect behavior" + a.add_action(tamper2) + assert a.contains(tamper), "contains incorrect behavior" + assert a.contains(tamper2), "contains incorrect behavior" + a.remove_action(tamper2) + assert a.contains(tamper), "contains incorrect behavior" + assert not a.contains(tamper2), "contains incorrect behavior" + a.add_action(tamper2) + assert a.contains(tamper), "contains incorrect behavior" + assert a.contains(tamper2), "contains incorrect behavior" + remove_success = a.remove_action(tamper) + assert remove_success + assert not a.contains(tamper), "contains incorrect behavior" + assert a.contains(tamper2), "contains incorrect behavior" + a.add_action(tamper3) + assert a.contains(tamper3), "contains incorrect behavior" + assert len(a) == 2, "len incorrect return" + remove_success = a.remove_action(tamper2) + assert remove_success + + +def test_iter(): + """ + Tests iterator. + """ + t = actions.trigger.Trigger(None, None, None) + a = actions.tree.ActionTree("out", trigger=t) + tamper = actions.tamper.TamperAction(field="flags", tamper_type="replace", tamper_value="S") + tamper2 = actions.tamper.TamperAction(field="flags", tamper_type="replace", tamper_value="R") + + assert a.add_action(tamper) + assert a.add_action(tamper2) + assert not a.add_action(tamper) + for node in a: + print(node) + + +def test_run(): + """ + Tests running packets through the chain. + """ + logger = logging.getLogger("test") + t = actions.trigger.Trigger(None, None, None) + a = actions.tree.ActionTree("out", trigger=t) + tamper = actions.tamper.TamperAction(field="flags", tamper_type="replace", tamper_value="S") + tamper2 = actions.tamper.TamperAction(field="flags", tamper_type="replace", tamper_value="R") + duplicate = actions.duplicate.DuplicateAction() + duplicate2 = actions.duplicate.DuplicateAction() + drop = actions.drop.DropAction() + + packet = actions.packet.Packet(IP()/TCP()) + a.add_action(tamper) + packets = a.run(packet, logging.getLogger("test")) + assert len(packets) == 1 + assert None not in packets + assert packets[0].get("TCP", "flags") == "S" + a.add_action(tamper2) + print(str(a)) + + packet = actions.packet.Packet(IP()/TCP()) + assert not a.add_action(tamper), "tree added duplicate action" + packets = a.run(packet, logging.getLogger("test")) + assert len(packets) == 1 + assert None not in packets + assert packets[0].get("TCP", "flags") == "R" + print(str(a)) + + a.remove_action(tamper2) + a.remove_action(tamper) + a.add_action(duplicate) + packet = actions.packet.Packet(IP()/TCP(flags="RA")) + packets = a.run(packet, logging.getLogger("test")) + assert len(packets) == 2 + assert None not in packets + assert packets[0][TCP].flags == "RA" + assert packets[1][TCP].flags == "RA" + print(str(a)) + + duplicate.left = tamper + duplicate.right = tamper2 + packet = actions.packet.Packet(IP()/TCP(flags="RA")) + packets = a.run(packet, logging.getLogger("test")) + assert len(packets) == 2 + assert None not in packets + print(str(a)) + print(str(packets[0])) + print(str(packets[1])) + assert packets[0][TCP].flags == "S" + assert packets[1][TCP].flags == "R" + print(str(a)) + + tamper.left = duplicate2 + packet = actions.packet.Packet(IP()/TCP(flags="RA")) + packets = a.run(packet, logging.getLogger("test")) + assert len(packets) == 3 + assert None not in packets + assert packets[0][TCP].flags == "S" + assert packets[1][TCP].flags == "S" + assert packets[2][TCP].flags == "R" + print(str(a)) + + tamper2.left = drop + packet = actions.packet.Packet(IP()/TCP(flags="RA")) + packets = a.run(packet, logging.getLogger("test")) + assert len(packets) == 2 + assert None not in packets + assert packets[0][TCP].flags == "S" + assert packets[1][TCP].flags == "S" + print(str(a)) + + assert a.remove_action(duplicate2) + tamper.left = actions.drop.DropAction() + packet = actions.packet.Packet(IP()/TCP(flags="RA")) + packets = a.run(packet, logger ) + assert len(packets) == 0 + print(str(a)) + + a.parse("[TCP:flags:A]-duplicate(tamper{TCP:flags:replace:R}(tamper{TCP:chksum:replace:14239},),duplicate(tamper{TCP:flags:replace:S},))-|", logger) + packet = actions.packet.Packet(IP()/TCP(flags="A")) + assert a.check(packet, logger) + packets = a.run(packet, logger) + assert len(packets) == 3 + assert packets[0][TCP].flags == "R" + assert packets[1][TCP].flags == "S" + assert packets[2][TCP].flags == "A" + + +def test_index(): + """ + Tests index + """ + a = actions.tree.ActionTree("out") + tamper = actions.tamper.TamperAction(field="flags", tamper_type="replace", tamper_value="S") + tamper2 = actions.tamper.TamperAction(field="flags", tamper_type="replace", tamper_value="R") + tamper3 = actions.tamper.TamperAction(field="flags", tamper_type="replace", tamper_value="F") + + assert a.add_action(tamper) + assert a[0] == tamper + assert not a[1] + assert a.add_action(tamper2) + assert a[0] == tamper + assert a[1] == tamper2 + assert a[-1] == tamper2 + assert not a[10] + assert a.add_action(tamper3) + assert a[-1] == tamper3 + assert not a[-11] + + +def test_choose_one(): + """ + Tests choose_one functionality + """ + a = actions.tree.ActionTree("out") + drop = actions.drop.DropAction() + assert not a.choose_one() + assert a.add_action(drop) + assert a.choose_one() == drop + assert a.remove_action(drop) + assert not a.choose_one() + duplicate = actions.duplicate.DuplicateAction() + a.add_action(duplicate) + assert a.choose_one() == duplicate + duplicate.left = drop + assert a.choose_one() in [duplicate, drop] + # Make sure that both actions get chosen + chosen = set() + for i in range(0, 10000): + act = a.choose_one() + chosen.add(act) + assert chosen == set([duplicate, drop]) diff --git a/tests/test_trigger.py b/tests/test_trigger.py new file mode 100644 index 0000000..062c2ec --- /dev/null +++ b/tests/test_trigger.py @@ -0,0 +1,151 @@ +import logging +import sys +# Include the root of the project +sys.path.append("..") + +import actions.packet +import actions.strategy +import actions.tamper +import actions.utils + +from scapy.all import IP, TCP + +logger = logging.getLogger("test") + + +def test_trigger_gas(): + """ + Tests triggers having gas, including changing that gas while in use + """ + + packet = actions.packet.Packet(IP(src="127.0.0.1", dst="127.0.0.1")/TCP(sport=2222, dport=3333, seq=100, ack=100, flags="SA")) + trigger = actions.trigger.Trigger("field", "flags", "TCP", trigger_value="SA", gas=1) + print(trigger) + assert trigger.is_applicable(packet, logger) + assert not trigger.is_applicable(packet, logger) + print(trigger) + # test add gas # + trigger.add_gas(3) + assert trigger.is_applicable(packet, logger) + assert trigger.is_applicable(packet, logger) + assert trigger.is_applicable(packet, logger) + assert not trigger.is_applicable(packet, logger) + + # Test disable, set, and enable gas # + trigger.disable_gas() + assert trigger.is_applicable(packet, logger) + trigger.set_gas(3) + assert trigger.is_applicable(packet, logger) + assert trigger.is_applicable(packet, logger) + assert trigger.is_applicable(packet, logger) + trigger.enable_gas() + trigger.set_gas(2) + assert trigger.is_applicable(packet, logger) + assert trigger.is_applicable(packet, logger) + assert not trigger.is_applicable(packet, logger) + + +def test_bomb_trigger_gas(): + """ + Tests triggers having bomb gas, including changing that gas while in use + """ + + packet = actions.packet.Packet(IP(src="127.0.0.1", dst="127.0.0.1")/TCP(sport=2222, dport=3333, seq=100, ack=100, flags="SA")) + trigger = actions.trigger.Trigger("field", "flags", "TCP", trigger_value="SA", gas=-1) + print(trigger) + assert not trigger.is_applicable(packet, logger), "trigger should not fire on first run" + assert trigger.is_applicable(packet, logger), "trigger should fire on second run" + print(trigger) + # test add gas # + trigger.add_gas(-3) + assert not trigger.is_applicable(packet, logger) + assert not trigger.is_applicable(packet, logger) + assert not trigger.is_applicable(packet, logger) + assert trigger.is_applicable(packet, logger) + + # Test disable, set, and enable gas # + trigger.disable_gas() + assert trigger.is_applicable(packet, logger) + trigger.set_gas(-3) + assert not trigger.is_applicable(packet, logger) + assert not trigger.is_applicable(packet, logger) + assert not trigger.is_applicable(packet, logger) + assert trigger.is_applicable(packet, logger) + trigger.enable_gas() + trigger.set_gas(-2) + assert not trigger.is_applicable(packet, logger) + assert not trigger.is_applicable(packet, logger) + assert trigger.is_applicable(packet, logger) + + +def test_trigger_parse_gas(): + """ + Tests triggers having gas, including changing that gas while in use + """ + + packet = actions.packet.Packet(IP(src="127.0.0.1", dst="127.0.0.1")/TCP(sport=2222, dport=3333, seq=100, ack=100, flags="SA")) + + + # parse a trigger with 1 gas + trigger = actions.trigger.Trigger.parse("TCP:flags:SA:1") + assert trigger.is_applicable(packet, logger) + assert not trigger.is_applicable(packet, logger) + + # parse a trigger with no gas left + trigger = actions.trigger.Trigger.parse("TCP:flags:SA:0") + assert not trigger.is_applicable(packet, logger) + + # parse a trigger not using gas + trigger = actions.trigger.Trigger.parse("TCP:flags:SA") + assert trigger.is_applicable(packet, logger) + # Check that adding gas while gas is disabled does not work + trigger.add_gas(10) + assert trigger.gas_remaining == None + + trigger.enable_gas() + trigger.set_gas(2) + + assert trigger.is_applicable(packet, logger) + assert trigger.is_applicable(packet, logger) + assert not trigger.is_applicable(packet, logger) + + # Test that it can handle leading/trailing [] + trigger = actions.trigger.Trigger.parse("[TCP:flags:SA]") + assert trigger.is_applicable(packet, logger) + + +def test_bomb_trigger_parse_gas(): + """ + Tests bomb triggers having gas, including changing that gas while in use + """ + packet = actions.packet.Packet(IP(src="127.0.0.1", dst="127.0.0.1")/TCP(sport=2222, dport=3333, seq=100, ack=100, flags="SA")) + + # parse a bomb trigger with 1 gas + trigger = actions.trigger.Trigger.parse("TCP:flags:SA:-1") + assert not trigger.is_applicable(packet, logger) + assert trigger.is_applicable(packet, logger) + + # parse a trigger with no gas left + trigger = actions.trigger.Trigger.parse("TCP:flags:SA:0") + assert not trigger.is_applicable(packet, logger) + + trigger = actions.trigger.Trigger.parse("TCP:flags:SA:-1") + assert not trigger.is_applicable(packet, logger) + + # parse a trigger not using gas + trigger = actions.trigger.Trigger.parse("TCP:flags:SA") + assert trigger.is_applicable(packet, logger) + # Check that adding gas while gas is disabled does not work + trigger.add_gas(10) + assert trigger.gas_remaining == None + + trigger.enable_gas() + trigger.set_gas(2) + + assert trigger.is_applicable(packet, logger) + assert trigger.is_applicable(packet, logger) + assert not trigger.is_applicable(packet, logger) + + # Test that it can handle leading/trailing [] + trigger = actions.trigger.Trigger.parse("[TCP:flags:SA]") + assert trigger.is_applicable(packet, logger) diff --git a/tests/test_utils.py b/tests/test_utils.py new file mode 100644 index 0000000..6f84891 --- /dev/null +++ b/tests/test_utils.py @@ -0,0 +1,42 @@ +import sys +import pytest +# Include the root of the project +sys.path.append("..") + +import actions.action +import actions.strategy +import actions.utils +import actions.duplicate + +import logging + +logger = logging.getLogger("test") + + +def get_test_configs(): + """ + Sets up the tests + """ + tests = [ + ("both", True, ['DuplicateAction', 'DropAction', 'SleepAction', 'TraceAction', 'TamperAction', 'FragmentAction']), + ("in", True, ['DropAction', 'TamperAction', 'SleepAction']), + ("out", True, ['DropAction', 'TamperAction', 'TraceAction', 'SleepAction', 'DuplicateAction', 'FragmentAction']), + ("both", False, ['DuplicateAction', 'SleepAction', 'TamperAction', 'FragmentAction']), + ("in", False, ['TamperAction', 'SleepAction']), + ("out", False, ['TamperAction', 'SleepAction', 'DuplicateAction', 'FragmentAction']), + ] + # To ensure caching is not breaking anything, double the tests + return tests + tests + + +@pytest.mark.parametrize("direction,allow_terminal,supported_actions", get_test_configs()) +def test_get_actions(direction, allow_terminal, supported_actions): + """ + Tests the duplicate action primitive. + """ + collected_actions = actions.action.Action.get_actions(direction, allow_terminal=allow_terminal) + names = [] + for name, action_class in collected_actions: + names.append(name) + assert set(names) == set(supported_actions) + assert len(names) == len(supported_actions)