diff --git a/src/core/net/mdns_server.cpp b/src/core/net/mdns_server.cpp index 1582a359a4c..bace80f83a8 100644 --- a/src/core/net/mdns_server.cpp +++ b/src/core/net/mdns_server.cpp @@ -1280,8 +1280,6 @@ Error MdnsServer::UpdateServiceContent(Service *aService, if (aTxtEntries != nullptr) { - VerifyOrExit(memcmp(aService->mTxtData.GetBytes(), aTxtEntries, aTxtEntries->mValueLength), - error = kErrorDuplicated); error = kErrorNone; @@ -1308,10 +1306,13 @@ Error MdnsServer::UpdateServiceContent(Service *aService, txtBufferOffset += aTxtEntries[i].mValueLength; } } + VerifyOrExit(memcmp(aService->mTxtData.GetBytes(), txtBuffer, txtBufferOffset), error = kErrorDuplicated); VerifyOrExit(aService->mTxtData.SetFrom(txtBuffer, txtBufferOffset) == kErrorNone, error = kErrorFailed); } else { + VerifyOrExit(memcmp(aService->mTxtData.GetBytes(), aTxtEntries, aTxtEntries->mValueLength), + error = kErrorDuplicated); VerifyOrExit(aService->mTxtData.SetFrom(aTxtEntries->mValue, aTxtEntries->mValueLength) == kErrorNone, error = kErrorFailed); } @@ -1727,8 +1728,11 @@ void MdnsServer::HandleAnnouncerFinished() { OutstandingUpdate *update = nullptr; update = mOutstandingUpdates.GetHead(); - IgnoreError(mOutstandingUpdates.Remove(*update)); - update->Free(); + if(update) + { + IgnoreError(mOutstandingUpdates.Remove(*update)); + update->Free(); + } for(Service &service : mServices) { @@ -1769,7 +1773,6 @@ void MdnsServer::MdnsProbingHandler() } else { - //update->SetState(OutstandingUpdate::State::kStateProbing); break; } } @@ -1787,18 +1790,21 @@ void MdnsServer::MdnsAnnouncingHandler() { //update->SetState(OutstandingUpdate::State::kStateAnnouncing); Message *announceMessage = nullptr; - if(!update->GetHost()) + if(!update->GetHostName()) { announceMessage = CreateHostAndServicesAnnounceMessage(update); } else { - announceMessage = CreateSrpAnnounceMessage(update->GetHost()); + announceMessage = CreateSrpAnnounceMessage(update->GetHostName()); } + VerifyOrExit(announceMessage != nullptr); Get().EnqueueAnnounceMessage(*announceMessage); Get().StartAnnouncing(); } +exit: + return; } void MdnsServer::CheckForOutstandingUpdates() @@ -2082,7 +2088,11 @@ void MdnsServer::Announcer::Stop(void) Error MdnsServer::OutstandingUpdate::Init(uint32_t aId, otSrpServerHost *aHost, Type aType) { mId = aId; - mHost = aHost; + if(aHost != nullptr) + { + mHost = aHost; + mHostName.Set(AsCoreType(aHost).GetFullName()); + } mType = aType; mState = kStateIdle; @@ -2535,7 +2545,7 @@ void MdnsServer::SrpAdvertisingProxyHandler(otSrpServerServiceUpdateId aId, cons update = OutstandingUpdate::AllocateAndInit(aId, AsNonConst(aHost), OutstandingUpdate::kTypeSrpServiceRemoved); VerifyOrExit(update != nullptr); update->SetService(service); - ExitNow(); + break; } } update = OutstandingUpdate::AllocateAndInit(aId, AsNonConst(aHost), OutstandingUpdate::kTypeProbeAndAnnounce); @@ -2594,24 +2604,17 @@ Message* MdnsServer::CreateSrpPublishMessage(const otSrpServerHost *aHost) Header header; Message *message = nullptr; - Message *QSectionMsg = nullptr; - Message *AuthSectionMsg = nullptr; - Question question(ResourceRecord::kTypeAny, ResourceRecord::kClassInternet); bool shouldPublishHost = true; const otSrpServerHost *host = nullptr; VerifyOrExit((message = mSocket.NewMessage(0)) != nullptr, error = kErrorNoBufs); - VerifyOrExit((QSectionMsg = mSocket.NewMessage(0)) != nullptr, error = kErrorNoBufs); - VerifyOrExit((AuthSectionMsg = mSocket.NewMessage(0)) != nullptr, error = kErrorNoBufs); question.SetQuQuestion(); // Allocate space for DNS header SuccessOrExit(error = message->SetLength(sizeof(Header))); - SuccessOrExit(error = QSectionMsg->SetLength(sizeof(Header))); - SuccessOrExit(error = AuthSectionMsg->SetLength(sizeof(Header))); // Setup initial DNS response header header.SetType(Header::kTypeQuery); @@ -2632,14 +2635,32 @@ Message* MdnsServer::CreateSrpPublishMessage(const otSrpServerHost *aHost) { // Hostname SuccessOrExit(error = - Get().AppendHostName(*QSectionMsg, name, compressInfo)); - QSectionMsg->Append(question); + Get().AppendHostName(*message, name, compressInfo)); + message->Append(question); header.SetQuestionCount(header.GetQuestionCount() + 1); + } + while ((service = AsCoreType(aHost).FindNextService(service, OT_SRP_SERVER_FLAGS_BASE_TYPE_SERVICE_ONLY, nullptr, + nullptr)) != nullptr) + { + char serviceName[Name::kMaxNameSize] = {0}; + + if (!service->IsDeleted()) + { + ConvertDomainName(serviceName, service->GetInstanceName(), kThreadDefaultDomainName, kDefaultDomainName); + SuccessOrExit(error = + Get().AppendInstanceName(*message, serviceName, compressInfo)); + message->Append(question); + header.SetQuestionCount(header.GetQuestionCount() + 1); + } + } + + if (shouldPublishHost) + { // AAAA Resource Record for (uint8_t i = 0; i < addrNum; i++) { - SuccessOrExit(error = Get().AppendAaaaRecord(*AuthSectionMsg, name, + SuccessOrExit(error = Get().AppendAaaaRecord(*message, name, addrs[i], hostTtl, compressInfo)); header.SetAuthorityRecordCount(header.GetAuthorityRecordCount() + 1); } @@ -2653,12 +2674,7 @@ Message* MdnsServer::CreateSrpPublishMessage(const otSrpServerHost *aHost) if (!service->IsDeleted()) { ConvertDomainName(serviceName, service->GetInstanceName(), kThreadDefaultDomainName, kDefaultDomainName); - SuccessOrExit(error = - Get().AppendInstanceName(*QSectionMsg, serviceName, compressInfo)); - QSectionMsg->Append(question); - header.SetQuestionCount(header.GetQuestionCount() + 1); - - SuccessOrExit(error = Get().AppendSrvRecord(*AuthSectionMsg, serviceName, + SuccessOrExit(error = Get().AppendSrvRecord(*message, serviceName, name, service->GetTtl(), service->GetPriority(), service->GetWeight(), service->GetPort(), compressInfo)); @@ -2666,25 +2682,9 @@ Message* MdnsServer::CreateSrpPublishMessage(const otSrpServerHost *aHost) } } - if (header.GetQuestionCount()) - { - SuccessOrExit(error = message->AppendBytesFromMessage(*QSectionMsg, sizeof(Header), - (QSectionMsg->GetLength() - sizeof(Header)) - - QSectionMsg->GetOffset())); - } - - if (header.GetAuthorityRecordCount()) - { - SuccessOrExit(error = message->AppendBytesFromMessage(*AuthSectionMsg, sizeof(Header), - (AuthSectionMsg->GetLength() - sizeof(Header)) - - AuthSectionMsg->GetOffset())); - } header.SetResponseCode(Header::kResponseSuccess); message->Write(0, header); - QSectionMsg->Free(); - AuthSectionMsg->Free(); - return message; exit: @@ -2708,26 +2708,42 @@ bool MdnsServer::AddressIsFromLocalSubnet(const Ip6::Address &srcAddr) return false; } -Message* MdnsServer::CreateSrpAnnounceMessage(const otSrpServerHost *aHost) +Message* MdnsServer::CreateSrpAnnounceMessage(const char *aHostName) { Error error = kErrorNone; NameCompressInfo compressInfo(kDefaultDomainName); char name[Name::kMaxNameSize]; - uint8_t addrNum; - const Ip6::Address *addrs = AsCoreType(aHost).GetAddresses(addrNum); - uint32_t hostTtl = TimeMilli::MsecToSec(AsCoreType(aHost).GetExpireTime() - TimerMilli::GetNow()); + uint8_t addrNum; + const Ip6::Address *addrs; + uint32_t hostTtl; const Srp::Server::Service *service = nullptr; Message *message = nullptr; Header header; + const otSrpServerHost *host = nullptr; + + while ((host = otSrpServerGetNextHost(reinterpret_cast(&InstanceLocator::GetInstance()), host)) != + nullptr) + { + if (StringMatch(AsCoreType(host).GetFullName(), aHostName, kStringCaseInsensitiveMatch)) + { + break; + } + } + + VerifyOrExit(host != nullptr, error = kErrorAbort); + + addrs = AsCoreType(host).GetAddresses(addrNum); + hostTtl = TimeMilli::MsecToSec(AsCoreType(host).GetExpireTime() - TimerMilli::GetNow()); + VerifyOrExit((message = mSocket.NewMessage(0)) != nullptr, error = kErrorNoBufs); SuccessOrExit(error = message->SetLength(sizeof(Header))); header.SetType(Header::kTypeResponse); - Get().ConvertDomainName(name, AsCoreType(aHost).GetFullName(), kThreadDefaultDomainName, kDefaultDomainName); + Get().ConvertDomainName(name, AsCoreType(host).GetFullName(), kThreadDefaultDomainName, kDefaultDomainName); // AAAA Resource Record for (uint8_t i = 0; i < addrNum; i++) @@ -2737,7 +2753,7 @@ Message* MdnsServer::CreateSrpAnnounceMessage(const otSrpServerHost *aHost) Server::IncResourceRecordCount(header, false); } - while ((service = AsCoreType(aHost).FindNextService(service, OT_SRP_SERVER_FLAGS_BASE_TYPE_SERVICE_ONLY, nullptr, + while ((service = AsCoreType(host).FindNextService(service, OT_SRP_SERVER_FLAGS_BASE_TYPE_SERVICE_ONLY, nullptr, nullptr)) != nullptr) { char serviceName[Name::kMaxNameSize] = {0}; @@ -2775,12 +2791,15 @@ Error MdnsServer::PublishFromSrp(const otSrpServerHost *aHost) Error error = kErrorNone; Message *message = CreateSrpPublishMessage(aHost); + OutstandingUpdate *update = mOutstandingUpdates.GetHead(); + VerifyOrExit(update != nullptr, error = kErrorFailed); VerifyOrExit(message != nullptr, error = kErrorNoBufs); if (message->GetLength() == sizeof(Header)) { Get().HandleServiceUpdateResult(mOutstandingUpdates.GetHead()->GetId(), kErrorNone); - mOutstandingUpdates.Remove(*mOutstandingUpdates.GetHead()); + mOutstandingUpdates.Remove(*update); + update->Free(); ExitNow(); } diff --git a/src/core/net/mdns_server.hpp b/src/core/net/mdns_server.hpp index df9a1e9e79d..96718a9ceaf 100644 --- a/src/core/net/mdns_server.hpp +++ b/src/core/net/mdns_server.hpp @@ -445,6 +445,7 @@ class MdnsServer : public InstanceLocator, private NonCopyable bool Matches(const uint32_t aId) const { return mId == aId; } uint32_t GetId(void) { return mId; } const otSrpServerHost *GetHost(void) { return mHost; } + const char *GetHostName() {return mHostName.AsCString();} void SetService(const otSrpServerService *aService) { mService = aService; } LinkedList GetServiceList(void) { return mServiceList; } const otSrpServerService *GetService(void) { return mService; } @@ -459,6 +460,7 @@ class MdnsServer : public InstanceLocator, private NonCopyable private: uint32_t mId; const otSrpServerHost *mHost; + Heap::String mHostName; const otSrpServerService *mService; LinkedList mServiceList; State mState; @@ -723,7 +725,7 @@ class MdnsServer : public InstanceLocator, private NonCopyable Message *CreateHostAndServicesAnnounceMessage(OutstandingUpdate *aUpdate); Message *CreateHostAndServicesPublishMessage(OutstandingUpdate *aUpdate); Error PublishHostAndServices(OutstandingUpdate *aUpdate); - Message *CreateSrpAnnounceMessage(const otSrpServerHost *aHost); + Message *CreateSrpAnnounceMessage(const char *aHostName); Message *CreateSrpPublishMessage(const otSrpServerHost *aHost); Error PublishFromSrp(const otSrpServerHost *aHost); bool AddressIsFromLocalSubnet(const Ip6::Address &srcAddr); diff --git a/src/core/net/srp_server.cpp b/src/core/net/srp_server.cpp index 10f2fbb52ec..9699d84aece 100644 --- a/src/core/net/srp_server.cpp +++ b/src/core/net/srp_server.cpp @@ -492,7 +492,6 @@ void Server::CommitSrpUpdate(Error aError, else if (existingHost != nullptr) { SuccessOrExit(aError = existingHost->MergeServicesAndResourcesFrom(aHost)); - shouldFreeHost = false; } else {