Skip to content

Commit

Permalink
[mdns] Fix handling of SRP client Update message
Browse files Browse the repository at this point in the history
Signed-off-by: Marius Preda <[email protected]>
  • Loading branch information
marius-preda committed Nov 6, 2023
1 parent 72cd9ee commit 0df7afb
Show file tree
Hide file tree
Showing 3 changed files with 70 additions and 50 deletions.
115 changes: 67 additions & 48 deletions src/core/net/mdns_server.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1280,8 +1280,6 @@ Error MdnsServer::UpdateServiceContent(Service *aService,

if (aTxtEntries != nullptr)
{
VerifyOrExit(memcmp(aService->mTxtData.GetBytes(), aTxtEntries, aTxtEntries->mValueLength),
error = kErrorDuplicated);

error = kErrorNone;

Expand All @@ -1308,10 +1306,13 @@ Error MdnsServer::UpdateServiceContent(Service *aService,
txtBufferOffset += aTxtEntries[i].mValueLength;
}
}
VerifyOrExit(memcmp(aService->mTxtData.GetBytes(), txtBuffer, txtBufferOffset), error = kErrorDuplicated);
VerifyOrExit(aService->mTxtData.SetFrom(txtBuffer, txtBufferOffset) == kErrorNone, error = kErrorFailed);
}
else
{
VerifyOrExit(memcmp(aService->mTxtData.GetBytes(), aTxtEntries, aTxtEntries->mValueLength),
error = kErrorDuplicated);
VerifyOrExit(aService->mTxtData.SetFrom(aTxtEntries->mValue, aTxtEntries->mValueLength) == kErrorNone,
error = kErrorFailed);
}
Expand Down Expand Up @@ -1727,8 +1728,11 @@ void MdnsServer::HandleAnnouncerFinished()
{
OutstandingUpdate *update = nullptr;
update = mOutstandingUpdates.GetHead();
IgnoreError(mOutstandingUpdates.Remove(*update));
update->Free();
if(update)
{
IgnoreError(mOutstandingUpdates.Remove(*update));
update->Free();
}

for(Service &service : mServices)
{
Expand Down Expand Up @@ -1769,7 +1773,6 @@ void MdnsServer::MdnsProbingHandler()
}
else
{
//update->SetState(OutstandingUpdate::State::kStateProbing);
break;
}
}
Expand All @@ -1787,18 +1790,21 @@ void MdnsServer::MdnsAnnouncingHandler()
{
//update->SetState(OutstandingUpdate::State::kStateAnnouncing);
Message *announceMessage = nullptr;
if(!update->GetHost())
if(!update->GetHostName())
{
announceMessage = CreateHostAndServicesAnnounceMessage(update);
}
else
{
announceMessage = CreateSrpAnnounceMessage(update->GetHost());
announceMessage = CreateSrpAnnounceMessage(update->GetHostName());
}

VerifyOrExit(announceMessage != nullptr);
Get<MdnsServer::Announcer>().EnqueueAnnounceMessage(*announceMessage);
Get<MdnsServer::Announcer>().StartAnnouncing();
}
exit:
return;
}

void MdnsServer::CheckForOutstandingUpdates()
Expand Down Expand Up @@ -2082,7 +2088,11 @@ void MdnsServer::Announcer::Stop(void)
Error MdnsServer::OutstandingUpdate::Init(uint32_t aId, otSrpServerHost *aHost, Type aType)
{
mId = aId;
mHost = aHost;
if(aHost != nullptr)
{
mHost = aHost;
mHostName.Set(AsCoreType(aHost).GetFullName());
}
mType = aType;
mState = kStateIdle;

Expand Down Expand Up @@ -2535,7 +2545,7 @@ void MdnsServer::SrpAdvertisingProxyHandler(otSrpServerServiceUpdateId aId, cons
update = OutstandingUpdate::AllocateAndInit(aId, AsNonConst(aHost), OutstandingUpdate::kTypeSrpServiceRemoved);
VerifyOrExit(update != nullptr);
update->SetService(service);
ExitNow();
break;
}
}
update = OutstandingUpdate::AllocateAndInit(aId, AsNonConst(aHost), OutstandingUpdate::kTypeProbeAndAnnounce);
Expand Down Expand Up @@ -2594,24 +2604,17 @@ Message* MdnsServer::CreateSrpPublishMessage(const otSrpServerHost *aHost)
Header header;

Message *message = nullptr;
Message *QSectionMsg = nullptr;
Message *AuthSectionMsg = nullptr;

Question question(ResourceRecord::kTypeAny, ResourceRecord::kClassInternet);

bool shouldPublishHost = true;
const otSrpServerHost *host = nullptr;

VerifyOrExit((message = mSocket.NewMessage(0)) != nullptr, error = kErrorNoBufs);
VerifyOrExit((QSectionMsg = mSocket.NewMessage(0)) != nullptr, error = kErrorNoBufs);
VerifyOrExit((AuthSectionMsg = mSocket.NewMessage(0)) != nullptr, error = kErrorNoBufs);

question.SetQuQuestion();

// Allocate space for DNS header
SuccessOrExit(error = message->SetLength(sizeof(Header)));
SuccessOrExit(error = QSectionMsg->SetLength(sizeof(Header)));
SuccessOrExit(error = AuthSectionMsg->SetLength(sizeof(Header)));

// Setup initial DNS response header
header.SetType(Header::kTypeQuery);
Expand All @@ -2632,14 +2635,32 @@ Message* MdnsServer::CreateSrpPublishMessage(const otSrpServerHost *aHost)
{
// Hostname
SuccessOrExit(error =
Get<Server>().AppendHostName(*QSectionMsg, name, compressInfo));
QSectionMsg->Append(question);
Get<Server>().AppendHostName(*message, name, compressInfo));
message->Append(question);
header.SetQuestionCount(header.GetQuestionCount() + 1);
}

while ((service = AsCoreType(aHost).FindNextService(service, OT_SRP_SERVER_FLAGS_BASE_TYPE_SERVICE_ONLY, nullptr,
nullptr)) != nullptr)
{
char serviceName[Name::kMaxNameSize] = {0};

if (!service->IsDeleted())
{
ConvertDomainName(serviceName, service->GetInstanceName(), kThreadDefaultDomainName, kDefaultDomainName);
SuccessOrExit(error =
Get<Server>().AppendInstanceName(*message, serviceName, compressInfo));
message->Append(question);
header.SetQuestionCount(header.GetQuestionCount() + 1);
}
}

if (shouldPublishHost)
{
// AAAA Resource Record
for (uint8_t i = 0; i < addrNum; i++)
{
SuccessOrExit(error = Get<Server>().AppendAaaaRecord(*AuthSectionMsg, name,
SuccessOrExit(error = Get<Server>().AppendAaaaRecord(*message, name,
addrs[i], hostTtl, compressInfo));
header.SetAuthorityRecordCount(header.GetAuthorityRecordCount() + 1);
}
Expand All @@ -2653,38 +2674,17 @@ Message* MdnsServer::CreateSrpPublishMessage(const otSrpServerHost *aHost)
if (!service->IsDeleted())
{
ConvertDomainName(serviceName, service->GetInstanceName(), kThreadDefaultDomainName, kDefaultDomainName);
SuccessOrExit(error =
Get<Server>().AppendInstanceName(*QSectionMsg, serviceName, compressInfo));
QSectionMsg->Append(question);
header.SetQuestionCount(header.GetQuestionCount() + 1);

SuccessOrExit(error = Get<Server>().AppendSrvRecord(*AuthSectionMsg, serviceName,
SuccessOrExit(error = Get<Server>().AppendSrvRecord(*message, serviceName,
name, service->GetTtl(),
service->GetPriority(), service->GetWeight(),
service->GetPort(), compressInfo));
header.SetAuthorityRecordCount(header.GetAuthorityRecordCount() + 1);
}
}

if (header.GetQuestionCount())
{
SuccessOrExit(error = message->AppendBytesFromMessage(*QSectionMsg, sizeof(Header),
(QSectionMsg->GetLength() - sizeof(Header)) -
QSectionMsg->GetOffset()));
}

if (header.GetAuthorityRecordCount())
{
SuccessOrExit(error = message->AppendBytesFromMessage(*AuthSectionMsg, sizeof(Header),
(AuthSectionMsg->GetLength() - sizeof(Header)) -
AuthSectionMsg->GetOffset()));
}
header.SetResponseCode(Header::kResponseSuccess);
message->Write(0, header);

QSectionMsg->Free();
AuthSectionMsg->Free();

return message;

exit:
Expand All @@ -2708,26 +2708,42 @@ bool MdnsServer::AddressIsFromLocalSubnet(const Ip6::Address &srcAddr)
return false;
}

Message* MdnsServer::CreateSrpAnnounceMessage(const otSrpServerHost *aHost)
Message* MdnsServer::CreateSrpAnnounceMessage(const char *aHostName)
{
Error error = kErrorNone;
NameCompressInfo compressInfo(kDefaultDomainName);
char name[Name::kMaxNameSize];

uint8_t addrNum;
const Ip6::Address *addrs = AsCoreType(aHost).GetAddresses(addrNum);
uint32_t hostTtl = TimeMilli::MsecToSec(AsCoreType(aHost).GetExpireTime() - TimerMilli::GetNow());
uint8_t addrNum;
const Ip6::Address *addrs;
uint32_t hostTtl;
const Srp::Server::Service *service = nullptr;

Message *message = nullptr;
Header header;

const otSrpServerHost *host = nullptr;

while ((host = otSrpServerGetNextHost(reinterpret_cast<otInstance *>(&InstanceLocator::GetInstance()), host)) !=
nullptr)
{
if (StringMatch(AsCoreType(host).GetFullName(), aHostName, kStringCaseInsensitiveMatch))
{
break;
}
}

VerifyOrExit(host != nullptr, error = kErrorAbort);

addrs = AsCoreType(host).GetAddresses(addrNum);
hostTtl = TimeMilli::MsecToSec(AsCoreType(host).GetExpireTime() - TimerMilli::GetNow());

VerifyOrExit((message = mSocket.NewMessage(0)) != nullptr, error = kErrorNoBufs);
SuccessOrExit(error = message->SetLength(sizeof(Header)));

header.SetType(Header::kTypeResponse);

Get<MdnsServer>().ConvertDomainName(name, AsCoreType(aHost).GetFullName(), kThreadDefaultDomainName, kDefaultDomainName);
Get<MdnsServer>().ConvertDomainName(name, AsCoreType(host).GetFullName(), kThreadDefaultDomainName, kDefaultDomainName);

// AAAA Resource Record
for (uint8_t i = 0; i < addrNum; i++)
Expand All @@ -2737,7 +2753,7 @@ Message* MdnsServer::CreateSrpAnnounceMessage(const otSrpServerHost *aHost)
Server::IncResourceRecordCount(header, false);
}

while ((service = AsCoreType(aHost).FindNextService(service, OT_SRP_SERVER_FLAGS_BASE_TYPE_SERVICE_ONLY, nullptr,
while ((service = AsCoreType(host).FindNextService(service, OT_SRP_SERVER_FLAGS_BASE_TYPE_SERVICE_ONLY, nullptr,
nullptr)) != nullptr)
{
char serviceName[Name::kMaxNameSize] = {0};
Expand Down Expand Up @@ -2775,12 +2791,15 @@ Error MdnsServer::PublishFromSrp(const otSrpServerHost *aHost)
Error error = kErrorNone;

Message *message = CreateSrpPublishMessage(aHost);
OutstandingUpdate *update = mOutstandingUpdates.GetHead();
VerifyOrExit(update != nullptr, error = kErrorFailed);
VerifyOrExit(message != nullptr, error = kErrorNoBufs);

if (message->GetLength() == sizeof(Header))
{
Get<Srp::Server>().HandleServiceUpdateResult(mOutstandingUpdates.GetHead()->GetId(), kErrorNone);
mOutstandingUpdates.Remove(*mOutstandingUpdates.GetHead());
mOutstandingUpdates.Remove(*update);
update->Free();
ExitNow();
}

Expand Down
4 changes: 3 additions & 1 deletion src/core/net/mdns_server.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -445,6 +445,7 @@ class MdnsServer : public InstanceLocator, private NonCopyable
bool Matches(const uint32_t aId) const { return mId == aId; }
uint32_t GetId(void) { return mId; }
const otSrpServerHost *GetHost(void) { return mHost; }
const char *GetHostName() {return mHostName.AsCString();}
void SetService(const otSrpServerService *aService) { mService = aService; }
LinkedList<MdnsServer::Service> GetServiceList(void) { return mServiceList; }
const otSrpServerService *GetService(void) { return mService; }
Expand All @@ -459,6 +460,7 @@ class MdnsServer : public InstanceLocator, private NonCopyable
private:
uint32_t mId;
const otSrpServerHost *mHost;
Heap::String mHostName;
const otSrpServerService *mService;
LinkedList<MdnsServer::Service> mServiceList;
State mState;
Expand Down Expand Up @@ -723,7 +725,7 @@ class MdnsServer : public InstanceLocator, private NonCopyable
Message *CreateHostAndServicesAnnounceMessage(OutstandingUpdate *aUpdate);
Message *CreateHostAndServicesPublishMessage(OutstandingUpdate *aUpdate);
Error PublishHostAndServices(OutstandingUpdate *aUpdate);
Message *CreateSrpAnnounceMessage(const otSrpServerHost *aHost);
Message *CreateSrpAnnounceMessage(const char *aHostName);
Message *CreateSrpPublishMessage(const otSrpServerHost *aHost);
Error PublishFromSrp(const otSrpServerHost *aHost);
bool AddressIsFromLocalSubnet(const Ip6::Address &srcAddr);
Expand Down
1 change: 0 additions & 1 deletion src/core/net/srp_server.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -492,7 +492,6 @@ void Server::CommitSrpUpdate(Error aError,
else if (existingHost != nullptr)
{
SuccessOrExit(aError = existingHost->MergeServicesAndResourcesFrom(aHost));
shouldFreeHost = false;
}
else
{
Expand Down

0 comments on commit 0df7afb

Please sign in to comment.