diff --git a/memcached.c b/memcached.c index dc1f636ed4..459dc9045a 100644 --- a/memcached.c +++ b/memcached.c @@ -246,6 +246,8 @@ static void settings_init(void) { settings.crawls_persleep = 1000; settings.logger_watcher_buf_size = LOGGER_WATCHER_BUF_SIZE; settings.logger_buf_size = LOGGER_BUF_SIZE; + settings.local_write_only = false; + settings.anonymous_reads = false; } /* @@ -433,6 +435,8 @@ void conn_close_idle(conn *c) { if (settings.idle_timeout > 0 && (current_time - c->last_cmd_time) > settings.idle_timeout) { if (c->state != conn_new_cmd && c->state != conn_read) { + settings.local_write_only = false; + settings.anonymous_reads = false; if (settings.verbose > 1) fprintf(stderr, "fd %d wants to timeout, but isn't in read state", c->sfd); @@ -464,7 +468,8 @@ void conn_worker_readd(conn *c) { conn *conn_new(const int sfd, enum conn_states init_state, const int event_flags, const int read_buffer_size, enum network_transport transport, - struct event_base *base) { + struct event_base *base, + bool local) { conn *c; assert(sfd >= 0 && sfd < max_fds); @@ -522,6 +527,7 @@ conn *conn_new(const int sfd, enum conn_states init_state, c->transport = transport; c->protocol = settings.binding_protocol; + c->local = local; /* unix socket mode doesn't need this, so zeroed out. but why * is this done for every command? presumably for UDP @@ -1971,6 +1977,39 @@ static void process_bin_complete_sasl_auth(conn *c) { } } +static bool local_write_check(conn *c) +{ + assert(settings.local_write_only); + bool rv = c->local; + if (!c->local){ + switch (c->cmd){ + /* SASL fall-through */ + case PROTOCOL_BINARY_CMD_SASL_LIST_MECHS: + case PROTOCOL_BINARY_CMD_SASL_AUTH: + case PROTOCOL_BINARY_CMD_SASL_STEP: + case PROTOCOL_BINARY_CMD_VERSION: + /* anonymous reads */ + case PROTOCOL_BINARY_CMD_GET: + case PROTOCOL_BINARY_CMD_GETQ: + case PROTOCOL_BINARY_CMD_GETK: + case PROTOCOL_BINARY_CMD_GETKQ: + case PROTOCOL_BINARY_CMD_STAT: + case PROTOCOL_BINARY_CMD_NOOP: + case PROTOCOL_BINARY_CMD_QUIT: + case PROTOCOL_BINARY_CMD_QUITQ: + rv = true; + break; + } + } + + if (settings.verbose > 1){ + fprintf(stderr, "%d: local_write_check() in cmd 0x%02x is %s\n", + c->sfd, c->cmd, rv ? "true" : "false"); + } + return rv; + +} + static bool authenticated(conn *c) { assert(settings.sasl); bool rv = false; @@ -2006,6 +2045,13 @@ static void dispatch_bin_command(conn *c) { c->write_and_go = conn_closing; return; } + + if (settings.local_write_only && !local_write_check(c)){ + write_bin_error(c, PROTOCOL_BINARY_RESPONSE_AUTH_ERROR, NULL, 0); + c->write_and_go = conn_closing; + return; + } + MEMCACHED_PROCESS_COMMAND_START(c->sfd, c->rcurr, c->rbytes); c->noreply = true; @@ -3793,8 +3839,28 @@ static void process_memlimit_command(conn *c, token_t *tokens, const size_t ntok } } -static void process_command(conn *c, char *command) { +static bool local_write_ascii_check(conn *c, char *cmdToken){ + assert(settings.local_write_only); + bool rv = c->local; + if (!c->local){ + if (strcmp(cmdToken, "get" ) == 0 + || strcmp(cmdToken, "gets" ) == 0 + || strcmp(cmdToken, "bget" ) == 0 + || strcmp(cmdToken, "quit" ) == 0 + || strcmp(cmdToken, "stats" ) == 0 + || strcmp(cmdToken, "version" ) == 0 ){ + rv = true; + } + } + if (settings.verbose > 1){ + fprintf(stderr, "%d: local_write_ascii_check() in cmd '%s' is %s\n", + c-> sfd, cmdToken, rv ? "true" : "false"); + } + return rv; +} + +static void process_command(conn *c, char *command) { token_t tokens[MAX_TOKENS]; size_t ntokens; int comm; @@ -3820,6 +3886,13 @@ static void process_command(conn *c, char *command) { } ntokens = tokenize_command(command, tokens, MAX_TOKENS); + + if (settings.local_write_only + && !local_write_ascii_check(c, tokens[COMMAND_TOKEN].value)){ + out_string(c, "AUTH_ERROR Remote write not allowed"); + return; + } + if (ntokens >= 3 && ((strcmp(tokens[COMMAND_TOKEN].value, "get") == 0) || (strcmp(tokens[COMMAND_TOKEN].value, "bget") == 0))) { @@ -4577,7 +4650,7 @@ static void drive_machine(conn *c) { STATS_UNLOCK(); } else { dispatch_conn_new(sfd, conn_new_cmd, EV_READ | EV_PERSIST, - DATA_BUFFER_SIZE, c->transport); + DATA_BUFFER_SIZE, c->transport, c->local); } stop = true; @@ -4930,6 +5003,20 @@ static void maximize_sndbuf(const int sfd) { fprintf(stderr, "<%d send buffer was %d, now %d\n", sfd, old_size, last_good); } +const char * LOCALHOST_IPV4 = "127.0.0.1"; +const char * LOCALHOST_IPV6 = "::1"; + +static bool is_local_host(const char *iface) +{ + if (iface == NULL) + { + return false; + } + return strcmp(iface, "localhost") == 0 + || strcmp(iface, LOCALHOST_IPV4) == 0 + || strcmp(iface, LOCALHOST_IPV6) == 0; +} + /** * Create a socket and bind it to a specific port number * @param interface the interface to bind to @@ -4968,7 +5055,7 @@ static int server_socket(const char *interface, perror("getaddrinfo()"); return 1; } - + for (next= ai; next; next= next->ai_next) { conn *listen_conn_add; if ((sfd = new_socket(next)) == -1) { @@ -5065,12 +5152,13 @@ static int server_socket(const char *interface, int per_thread_fd = c ? dup(sfd) : sfd; dispatch_conn_new(per_thread_fd, conn_read, EV_READ | EV_PERSIST, - UDP_READ_BUFFER_SIZE, transport); + UDP_READ_BUFFER_SIZE, transport, false); } } else { if (!(listen_conn_add = conn_new(sfd, conn_listening, EV_READ | EV_PERSIST, 1, - transport, main_base))) { + transport, main_base, + is_local_host(interface)))) { fprintf(stderr, "failed to create listening connection\n"); exit(EXIT_FAILURE); } @@ -5085,6 +5173,7 @@ static int server_socket(const char *interface, return success == 0; } + static int server_sockets(int port, enum network_transport transport, FILE *portnumber_file) { if (settings.inter == NULL) { @@ -5217,7 +5306,7 @@ static int server_socket_unix(const char *path, int access_mask) { } if (!(listen_conn = conn_new(sfd, conn_listening, EV_READ | EV_PERSIST, 1, - local_transport, main_base))) { + local_transport, main_base,false))) { fprintf(stderr, "failed to create listening connection\n"); exit(EXIT_FAILURE); } @@ -5669,7 +5758,9 @@ int main (int argc, char **argv) { SLAB_SIZES, SLAB_CHUNK_MAX, TRACK_SIZES, - MODERN + MODERN, + LOCAL_WRITE_ONLY, + ANONYMOUS_READ }; char *const subopts_tokens[] = { [MAXCONNS_FAST] = "maxconns_fast", @@ -5692,6 +5783,8 @@ int main (int argc, char **argv) { [SLAB_CHUNK_MAX] = "slab_chunk_max", [TRACK_SIZES] = "track_sizes", [MODERN] = "modern", + [LOCAL_WRITE_ONLY] = "local_write_only", + [ANONYMOUS_READ] = "anonymous_read", NULL }; @@ -6121,6 +6214,18 @@ int main (int argc, char **argv) { start_lru_crawler = true; start_lru_maintainer = true; break; + case LOCAL_WRITE_ONLY: + settings.local_write_only = true; + break; + case ANONYMOUS_READ: + #ifndef ENABLE_SASL + fprintf(stderr, + "This server is not built with SASL support.\n"); + exit(EX_USAGE); + #endif + settings.anonymous_reads = true; + break; + default: printf("Illegal suboption \"%s\"\n", subopts_value); return 1; @@ -6134,6 +6239,9 @@ int main (int argc, char **argv) { return 1; } } + if(settings.verbose > 0 && settings.local_write_only){ + fprintf(stderr, "Only allowing writing via local loop back"); + } if (settings.slab_chunk_size_max > settings.item_size_max) { fprintf(stderr, "slab_chunk_max (bytes: %d) cannot be larger than -I (item_size_max %d)\n", diff --git a/memcached.h b/memcached.h index f7f4cd4a3e..b7fd2f76de 100644 --- a/memcached.h +++ b/memcached.h @@ -369,6 +369,8 @@ struct settings { int idle_timeout; /* Number of seconds to let connections idle */ unsigned int logger_watcher_buf_size; /* size of logger's per-watcher buffer */ unsigned int logger_buf_size; /* size of per-thread logger buffer */ + bool local_write_only; + bool anonymous_reads; }; extern struct stats stats; @@ -488,6 +490,7 @@ struct conn { char *rcurr; /** but if we parsed some already, this is where we stopped */ int rsize; /** total allocated size of rbuf */ int rbytes; /** how much data, starting from rcur, do we have unparsed */ + bool local; /** connection via local loopback **/ char *wbuf; char *wcurr; @@ -597,7 +600,7 @@ enum delta_result_type do_add_delta(conn *c, const char *key, const int64_t delta, char *buf, uint64_t *cas, const uint32_t hv); enum store_item_type do_store_item(item *item, int comm, conn* c, const uint32_t hv); -conn *conn_new(const int sfd, const enum conn_states init_state, const int event_flags, const int read_buffer_size, enum network_transport transport, struct event_base *base); +conn *conn_new(const int sfd, const enum conn_states init_state, const int event_flags, const int read_buffer_size, enum network_transport transport, struct event_base *base, bool local); void conn_worker_readd(conn *c); extern int daemonize(int nochdir, int noclose); @@ -622,7 +625,7 @@ extern int daemonize(int nochdir, int noclose); void memcached_thread_init(int nthreads); void redispatch_conn(conn *c); -void dispatch_conn_new(int sfd, enum conn_states init_state, int event_flags, int read_buffer_size, enum network_transport transport); +void dispatch_conn_new(int sfd, enum conn_states init_state, int event_flags, int read_buffer_size, enum network_transport transport, bool local); void sidethread_conn_close(conn *c); /* Lock wrappers for cache functions that are called from main loop. */ diff --git a/testapp.c b/testapp.c index de43b3da01..86e4c413a6 100644 --- a/testapp.c +++ b/testapp.c @@ -17,6 +17,7 @@ #include #include #include +#include #include "config.h" #include "cache.h" @@ -292,7 +293,7 @@ static enum test_return test_safe_strtol(void) { * as a daemon process * @return the pid of the memcached server */ -static pid_t start_server(in_port_t *port_out, bool daemon, int timeout) { +static pid_t start_server(in_port_t *port_out, bool daemon, int timeout, bool localWriteOnly, char * extra_args ) { char environment[80]; snprintf(environment, sizeof(environment), "MEMCACHED_PORT_FILENAME=/tmp/ports.%lu", (long)getpid()); @@ -350,8 +351,16 @@ static pid_t start_server(in_port_t *port_out, bool daemon, int timeout) { argv[arg++] = pid_file; } #ifdef MESSAGE_DEBUG - argv[arg++] = "-vvv"; + argv[arg++] = "-vvv"; #endif + if( localWriteOnly ){ + argv[arg++] = "-o"; + argv[arg++] = "local_write_only"; + } + if ( extra_args != NULL ){ + argv[arg++] = extra_args; + } + argv[arg++] = NULL; assert(execv(argv[0], argv) != -1); } @@ -412,7 +421,7 @@ static pid_t start_server(in_port_t *port_out, bool daemon, int timeout) { static enum test_return test_issue_44(void) { in_port_t port; - pid_t pid = start_server(&port, true, 15); + pid_t pid = start_server(&port, true, 15, false, NULL); assert(kill(pid, SIGHUP) == 0); sleep(1); assert(kill(pid, SIGTERM) == 0); @@ -561,6 +570,19 @@ static void read_ascii_response(char *buffer, size_t size) { } while (need_more); } +static int read_multiline_ascii_response(char **buffers, size_t max_num_lines, size_t size){ + size_t i = 0; + while ( i < max_num_lines){ + read_ascii_response(buffers[i], size); + if ((strncmp(buffers[i++], "END", strlen("END")) == 0 )) + { + return i; + } + } + return -1 * i; +} + + static enum test_return test_issue_92(void) { char buffer[1024]; @@ -630,7 +652,7 @@ static enum test_return test_issue_102(void) { } static enum test_return start_memcached_server(void) { - server_pid = start_server(&port, false, 600); + server_pid = start_server(&port, false, 600,false,NULL); sock = connect_server("127.0.0.1", port, false); return TEST_PASS; } @@ -1043,7 +1065,7 @@ static enum test_return test_binary_noop(void) { return TEST_PASS; } -static enum test_return test_binary_quit_impl(uint8_t cmd) { +static enum test_return test_binary_quit_impl_no_reconnect(uint8_t cmd) { union { protocol_binary_request_no_extras request; protocol_binary_response_no_extras response; @@ -1062,6 +1084,11 @@ static enum test_return test_binary_quit_impl(uint8_t cmd) { /* Socket should be closed now, read should return 0 */ assert(read(sock, buffer.bytes, sizeof(buffer.bytes)) == 0); close(sock); + return TEST_PASS; +} + +static enum test_return test_binary_quit_impl(uint8_t cmd) { + assert(test_binary_quit_impl_no_reconnect(cmd) == TEST_PASS); sock = connect_server("127.0.0.1", port, false); return TEST_PASS; @@ -1155,6 +1182,44 @@ static enum test_return test_binary_add_impl(const char *key, uint8_t cmd) { return TEST_PASS; } +static enum test_return test_binary_addauthfailure_impl(const char *key, uint8_t cmd) { + uint64_t value = 0xdeadbeefdeadcafe; + union { + protocol_binary_request_no_extras request; + protocol_binary_response_no_extras response; + char bytes[1024]; + } send, receive; + size_t len = storage_command(send.bytes, sizeof(send.bytes), cmd, key, + strlen(key), &value, sizeof(value), + 0, 0); + + /* Add should never work*/ + safe_send(send.bytes, len, false); + safe_recv_packet(receive.bytes, sizeof(receive.bytes)); + validate_response_header(&receive.response, cmd, + PROTOCOL_BINARY_RESPONSE_AUTH_ERROR); + + return TEST_PASS; +} + +static enum test_return test_read_single_bin_value( char *key, uint8_t cmd ) +{ + union { + protocol_binary_request_no_extras request; + protocol_binary_response_no_extras response; + char bytes[1024]; + } send, receive; + size_t len = ext_command(send.bytes, sizeof(send.bytes), PROTOCOL_BINARY_CMD_GET, + NULL, 0, key, strlen(key), NULL, 0); + + safe_send(send.bytes, len, false); + safe_recv_packet(receive.bytes, sizeof(receive.bytes)); + validate_response_header(&receive.response, cmd, + PROTOCOL_BINARY_RESPONSE_SUCCESS); + return TEST_PASS; +} + + static enum test_return test_binary_add(void) { return test_binary_add_impl("test_binary_add", PROTOCOL_BINARY_CMD_ADD); } @@ -1874,7 +1939,7 @@ static enum test_return test_issue_101(void) { const char *command = "stats\r\nstats\r\nstats\r\nstats\r\nstats\r\n"; size_t cmdlen = strlen(command); - server_pid = start_server(&port, false, 1000); + server_pid = start_server(&port, false, 1000, false, NULL); for (ii = 0; ii < max; ++ii) { fds[ii] = connect_server("127.0.0.1", port, true); @@ -1929,6 +1994,147 @@ static enum test_return test_issue_101(void) { return ret; } +static char *get_first_nonlocal_addr(){ + struct ifaddrs *addrs; + getifaddrs(&addrs); + struct ifaddrs *tmp = addrs; + + char * result = (char *) malloc(17); + result[0] = '\0'; + while(tmp){ + if(tmp->ifa_addr && tmp->ifa_addr->sa_family == AF_INET){ + struct sockaddr_in *pAddr = (struct sockaddr_in *)tmp->ifa_addr; + char *ipAddr = inet_ntoa(pAddr->sin_addr); + if ( strlen(ipAddr) > 6 + && strncmp(ipAddr, "169", 3) != 0 + && strncmp(ipAddr, "127", 3) != 0){ + sprintf(result, "%s", ipAddr); + break; + } + } + tmp = tmp->ifa_next; + } + freeifaddrs(addrs); + return result; +} + +static enum test_return test_send_ascii_command(const char * command, + const char * expectedResponse ) +{ + send_ascii_command(command); + char response_buffer[80]; + read_ascii_response(response_buffer, sizeof(response_buffer)); + assert(strncmp(response_buffer, expectedResponse, strlen(expectedResponse)) == 0); + return TEST_PASS; +} + +static enum test_return test_ascii_quit() +{ + send_ascii_command("quit\n"); + char response_buffer[1]; + ssize_t nr = read(sock, response_buffer, 1); + assert(nr == 0); + return TEST_PASS; + +} + +static enum test_return test_ascii_get(const char *command, const char **expectedResults, size_t num_lines, size_t line_len ) +{ + char *buffer = (char*) malloc(num_lines * line_len); + char *multiline_response[num_lines]; + for ( size_t i = 0; i < num_lines; ++i ) + { + multiline_response[i] = buffer + i * line_len; + } + + send_ascii_command(command); + int read_lines = + read_multiline_ascii_response(multiline_response, num_lines, line_len); + assert( read_lines == num_lines ); + for ( size_t i = 0; i < num_lines; ++i ) + { + assert( strncmp(multiline_response[i], expectedResults[i], + strlen(expectedResults[i])) == 0 ); + } + free(buffer); + return TEST_PASS; +} + +static enum test_return test_ascii_read_write(){ + sock = connect_server("127.0.0.1", port, false); + test_send_ascii_command("add ascii_key 0 60 5\r\n12345\r\n", "STORED"); + + const char *expectedResult[3]; + expectedResult[0] = "VALUE ascii_key 0 5"; + expectedResult[1] = "12345"; + expectedResult[2] = "END"; + + assert(test_ascii_get("get ascii_key\r\n", expectedResult, 3, 40) == TEST_PASS); + close(sock); + return TEST_PASS; +} + +static enum test_return test_local_write_only(void){ + char *addr = get_first_nonlocal_addr(); + assert(strlen(addr) > 6); + char * extra_args = (char *) malloc(50); + snprintf(extra_args, 50, "-l127.0.0.1:11213,%s:11214", addr); + pid_t server_pid = start_server(&port, false, 60, true, extra_args); + free(extra_args); + + /* locally, should be able to read and write */ + sock = connect_server("127.0.0.1", 11213, false); + assert(test_binary_add_impl("test_binary_add", PROTOCOL_BINARY_CMD_ADD) == TEST_PASS); + assert(test_read_single_bin_value("test_binary_add", PROTOCOL_BINARY_CMD_GET) == TEST_PASS); + assert(test_binary_quit_impl_no_reconnect(PROTOCOL_BINARY_CMD_QUITQ) == TEST_PASS); + + /* remote socket - read only */ + sock = connect_server(addr, 11214, false); + assert(test_binary_addauthfailure_impl("test_binary_add", PROTOCOL_BINARY_CMD_ADD) == TEST_PASS); + close(sock); + + sock = connect_server(addr, 11214, false); + assert(test_read_single_bin_value("test_binary_add", PROTOCOL_BINARY_CMD_GET) == TEST_PASS); + assert(test_binary_quit_impl_no_reconnect(PROTOCOL_BINARY_CMD_QUITQ) == TEST_PASS); + + + free(addr); + assert(kill(server_pid, SIGTERM) == 0); + return TEST_PASS; +} + +static enum test_return test_local_write_only_ascii(void){ + char *addr = get_first_nonlocal_addr(); + assert(strlen(addr) > 6); + char * extra_args = (char *) malloc(50); + snprintf(extra_args, 50, "-l127.0.0.1:11215,%s:11216", addr); + pid_t server_pid = start_server(&port, false, 60, true, extra_args); + free(extra_args); + + /* ascii */ + sock = connect_server("127.0.0.1", 11215, false); + test_send_ascii_command("add ascii_key 0 60 5\r\n12345\r\n", "STORED"); + + const char *expectedResult[3]; + expectedResult[0] = "VALUE ascii_key 0 5"; + expectedResult[1] = "12345"; + expectedResult[2] = "END"; + + assert(test_ascii_get("get ascii_key\r\n", expectedResult, 3, 40) == TEST_PASS); + test_ascii_quit(); + + /* ascii */ + sock = connect_server(addr, 11216, false); + assert( test_send_ascii_command("ADD ascii_key2 0 60 5\n", "AUTH_ERROR Remote write not allowed")); + + assert(test_ascii_get("get ascii_key\r\n", expectedResult, 3, 40) == TEST_PASS); + test_ascii_quit(); + + free(addr); + assert(kill(server_pid, SIGTERM) == 0); + return TEST_PASS; +} + typedef enum test_return (*TEST_FUNC)(void); struct testcase { const char *description; @@ -1987,8 +2193,12 @@ struct testcase testcases[] = { { "binary_stat", test_binary_stat }, { "binary_illegal", test_binary_illegal }, { "binary_pipeline_hickup", test_binary_pipeline_hickup }, + { "ascii_read_write", test_ascii_read_write }, { "shutdown", shutdown_memcached_server }, { "stop_server", stop_memcached_server }, + /* This test needs to start its own server with different options */ + { "local_write_only", test_local_write_only }, + { "local_write_only_ascii", test_local_write_only_ascii }, { NULL, NULL } }; diff --git a/thread.c b/thread.c index b6231806fc..6d8a52fdfd 100644 --- a/thread.c +++ b/thread.c @@ -25,6 +25,7 @@ struct conn_queue_item { int read_buffer_size; enum network_transport transport; conn *c; + bool local; CQ_ITEM *next; }; @@ -405,7 +406,7 @@ static void thread_libevent_process(int fd, short which, void *arg) { if (NULL != item) { conn *c = conn_new(item->sfd, item->init_state, item->event_flags, item->read_buffer_size, item->transport, - me->base); + me->base, item->local); if (c == NULL) { if (IS_UDP(item->transport)) { fprintf(stderr, "Can't listen for events on UDP socket\n"); @@ -456,7 +457,8 @@ static int last_thread = -1; * of an incoming connection. */ void dispatch_conn_new(int sfd, enum conn_states init_state, int event_flags, - int read_buffer_size, enum network_transport transport) { + int read_buffer_size, enum network_transport transport, + bool local) { CQ_ITEM *item = cqi_new(); char buf[1]; if (item == NULL) { @@ -477,6 +479,7 @@ void dispatch_conn_new(int sfd, enum conn_states init_state, int event_flags, item->event_flags = event_flags; item->read_buffer_size = read_buffer_size; item->transport = transport; + item->local = local; cq_push(thread->new_conn_queue, item); @@ -504,6 +507,7 @@ void redispatch_conn(conn *c) { item->sfd = c->sfd; item->init_state = conn_new_cmd; item->c = c; + item->local = c->local; cq_push(thread->new_conn_queue, item);