From 6af6ca52a2d96c8714c127d90152670ad78aaa33 Mon Sep 17 00:00:00 2001 From: Xu Si Yu Date: Thu, 25 Apr 2024 10:57:06 +0800 Subject: [PATCH] feat(mdns): add check of instance when handling PTR query --- components/mdns/mdns.c | 110 +++++++++++++++++++++++++++++++++++++++-- 1 file changed, 105 insertions(+), 5 deletions(-) diff --git a/components/mdns/mdns.c b/components/mdns/mdns.c index c341482a03..d2c67d02d0 100644 --- a/components/mdns/mdns.c +++ b/components/mdns/mdns.c @@ -1871,6 +1871,7 @@ static void _mdns_create_answer_from_parsed_packet(mdns_parsed_packet_t *parsed_ packet->id = parsed_packet->id; mdns_parsed_question_t *q = parsed_packet->questions; + uint32_t out_record_nums = 0; while (q) { shared = q->type == MDNS_TYPE_PTR || q->type == MDNS_TYPE_SDPTR || !parsed_packet->probe; if (q->type == MDNS_TYPE_SRV || q->type == MDNS_TYPE_TXT) { @@ -1878,14 +1879,36 @@ static void _mdns_create_answer_from_parsed_packet(mdns_parsed_packet_t *parsed_ if (service == NULL || !_mdns_create_answer_from_service(packet, service->service, q, shared, send_flush)) { _mdns_free_tx_packet(packet); return; + } else { + out_record_nums++; } } else if (q->service && q->proto) { mdns_srv_item_t *service = _mdns_server->services; while (service) { if (_mdns_service_match_ptr_question(service->service, q)) { - if (!_mdns_create_answer_from_service(packet, service->service, q, shared, send_flush)) { - _mdns_free_tx_packet(packet); - return; + mdns_parsed_record_t *r = parsed_packet->records; + bool is_record_exist = false; + while (r) { + if (service->service->instance && r->host) { + if (_mdns_service_match_instance(service->service, r->host, r->service, r->proto, NULL) && r->ttl > (MDNS_ANSWER_PTR_TTL / 2)) { + is_record_exist = true; + break; + } + } else if (!service->service->instance && !r->host) { + if (_mdns_service_match(service->service, r->service, r->proto, NULL) && r->ttl > (MDNS_ANSWER_PTR_TTL / 2)) { + is_record_exist = true; + break; + } + } + r = r->next; + } + if (!is_record_exist) { + if (!_mdns_create_answer_from_service(packet, service->service, q, shared, send_flush)) { + _mdns_free_tx_packet(packet); + return; + } else { + out_record_nums++; + } } } service = service->next; @@ -1894,22 +1917,31 @@ static void _mdns_create_answer_from_parsed_packet(mdns_parsed_packet_t *parsed_ if (!_mdns_create_answer_from_hostname(packet, q->host, send_flush)) { _mdns_free_tx_packet(packet); return; + } else { + out_record_nums++; } } else if (q->type == MDNS_TYPE_ANY) { if (!_mdns_append_host_list(&packet->answers, send_flush, false)) { _mdns_free_tx_packet(packet); return; + } else { + out_record_nums++; } #ifdef CONFIG_MDNS_RESPOND_REVERSE_QUERIES } else if (q->type == MDNS_TYPE_PTR) { mdns_host_item_t *host = mdns_get_host_item(q->host); if (!_mdns_alloc_answer(&packet->answers, MDNS_TYPE_PTR, NULL, host, send_flush, false)) { + _mdns_free_tx_packet(packet); return; + } else { + out_record_nums++; } #endif /* CONFIG_MDNS_RESPOND_REVERSE_QUERIES */ } else if (!_mdns_alloc_answer(&packet->answers, q->type, NULL, NULL, send_flush, false)) { _mdns_free_tx_packet(packet); return; + } else { + out_record_nums++; } if (parsed_packet->src_port != MDNS_SERVICE_PORT && // Repeat the queries only for "One-Shot mDNS queries" @@ -1943,6 +1975,10 @@ static void _mdns_create_answer_from_parsed_packet(mdns_parsed_packet_t *parsed_ } q = q->next; } + if (out_record_nums == 0) { + _mdns_free_tx_packet(packet); + return; + } if (unicast || !send_flush) { memcpy(&packet->dst, &parsed_packet->src, sizeof(esp_ip_addr_t)); packet->port = parsed_packet->src_port; @@ -3336,7 +3372,11 @@ static bool _mdns_question_matches(mdns_parsed_question_t *question, uint16_t ty && !strcasecmp(service->service->service, question->service) && !strcasecmp(service->service->proto, question->proto) && !strcasecmp(MDNS_DEFAULT_DOMAIN, question->domain)) { - return true; + if (!service->service->instance) { + return true; + } else if (service->service->instance && question->host && !strcasecmp(service->service->instance, question->host)) { + return true; + } } } else if (service && (type == MDNS_TYPE_SRV || type == MDNS_TYPE_TXT)) { const char *name = _mdns_get_service_instance_name(service->service); @@ -3635,6 +3675,7 @@ void mdns_parse_packet(mdns_rx_packet_t *packet) parsed_packet->id = header.id; esp_netif_ip_addr_copy(&parsed_packet->src, &packet->src); parsed_packet->src_port = packet->src_port; + parsed_packet->records = NULL; if (header.questions) { uint8_t qs = header.questions; @@ -3821,7 +3862,12 @@ void mdns_parse_packet(mdns_rx_packet_t *packet) _mdns_search_result_add_ptr(search_result, name->host, name->service, name->proto, packet->tcpip_if, packet->ip_protocol, ttl); } else if ((discovery || ours) && !name->sub && _mdns_name_is_ours(name)) { - if (discovery && (service = _mdns_get_service_item(name->service, name->proto, NULL))) { + if (name->host[0]) { + service = _mdns_get_service_item_instance(name->host, name->service, name->proto, NULL); + } else { + service = _mdns_get_service_item(name->service, name->proto, NULL); + } + if (discovery && service) { _mdns_remove_parsed_question(parsed_packet, MDNS_TYPE_SDPTR, service); } else if (service && parsed_packet->questions && !parsed_packet->probe) { _mdns_remove_parsed_question(parsed_packet, type, service); @@ -3831,6 +3877,45 @@ void mdns_parse_packet(mdns_rx_packet_t *packet) _mdns_remove_scheduled_answer(packet->tcpip_if, packet->ip_protocol, type, service); } } + if (service) { + mdns_parsed_record_t *record = malloc(sizeof(mdns_parsed_record_t)); + if (!record) { + HOOK_MALLOC_FAILED; + goto clear_rx_packet; + } + record->next = parsed_packet->records; + parsed_packet->records = record; + record->type = MDNS_TYPE_PTR; + record->record_type = MDNS_ANSWER; + record->ttl = ttl; + record->host = NULL; + record->service = NULL; + record->proto = NULL; + if (name->host[0]) { + record->host = malloc(MDNS_NAME_BUF_LEN); + if (!record->host) { + HOOK_MALLOC_FAILED; + goto clear_rx_packet; + } + memcpy(record->host, name->host, MDNS_NAME_BUF_LEN); + } + if (name->service[0]) { + record->service = malloc(MDNS_NAME_BUF_LEN); + if (!record->service) { + HOOK_MALLOC_FAILED; + goto clear_rx_packet; + } + memcpy(record->service, name->service, MDNS_NAME_BUF_LEN); + } + if (name->proto[0]) { + record->proto = malloc(MDNS_NAME_BUF_LEN); + if (!record->proto) { + HOOK_MALLOC_FAILED; + goto clear_rx_packet; + } + memcpy(record->proto, name->proto, MDNS_NAME_BUF_LEN); + } + } } } else if (type == MDNS_TYPE_SRV) { mdns_result_t *result = NULL; @@ -4161,6 +4246,21 @@ void mdns_parse_packet(mdns_rx_packet_t *packet) } free(question); } + while (parsed_packet->records) { + mdns_parsed_record_t *record = parsed_packet->records; + parsed_packet->records = parsed_packet->records->next; + if (record->host) { + free(record->host); + } + if (record->service) { + free(record->service); + } + if (record->proto) { + free(record->proto); + } + record->next = NULL; + free(record); + } free(parsed_packet); if (browse_result_instance) { free(browse_result_instance);