Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

OTA: chunked download #520

Merged
merged 18 commits into from
Dec 16, 2024
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions src/ArduinoIoTCloudTCP.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,9 @@ int ArduinoIoTCloudTCP::begin(bool const enable_watchdog, String brokerAddress,

#if OTA_ENABLED && !defined(OFFLOADED_DOWNLOAD)
_ota.setClient(&_otaClient);
if (_connection->getInterface() == NetworkAdapter::ETHERNET) {
_ota.setFetchMode(OTADefaultCloudProcessInterface::OtaFetchChunk);
}
#endif // OTA_ENABLED && !defined(OFFLOADED_DOWNLOAD)

#if OTA_ENABLED && defined(OTA_BASIC_AUTH)
Expand Down
132 changes: 89 additions & 43 deletions src/ota/interface/OTAInterfaceDefault.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ OTADefaultCloudProcessInterface::OTADefaultCloudProcessInterface(MessageStream *
, client(client)
, http_client(nullptr)
, username(nullptr), password(nullptr)
, fetchMode(OtaFetchTime)
, context(nullptr) {
}

Expand All @@ -41,57 +42,43 @@ OTACloudProcessInterface::State OTADefaultCloudProcessInterface::startOTA() {
}
);

// make the http get request
// check url
if(strcmp(context->parsed_url.schema(), "https") == 0) {
http_client = new HttpClient(*client, context->parsed_url.host(), context->parsed_url.port());
} else {
return UrlParseErrorFail;
}

http_client->beginRequest();
auto res = http_client->get(context->parsed_url.path());

if(username != nullptr && password != nullptr) {
http_client->sendBasicAuth(username, password);
}

http_client->endRequest();

if(res == HTTP_ERROR_CONNECTION_FAILED) {
DEBUG_VERBOSE("OTA ERROR: http client error connecting to server \"%s:%d\"",
context->parsed_url.host(), context->parsed_url.port());
return ServerConnectErrorFail;
} else if(res == HTTP_ERROR_TIMED_OUT) {
DEBUG_VERBOSE("OTA ERROR: http client timeout \"%s\"", OTACloudProcessInterface::context->url);
return OtaHeaderTimeoutFail;
} else if(res != HTTP_SUCCESS) {
DEBUG_VERBOSE("OTA ERROR: http client returned %d on get \"%s\"", res, OTACloudProcessInterface::context->url);
return OtaDownloadFail;
}

int statusCode = http_client->responseStatusCode();

if(statusCode != 200) {
DEBUG_VERBOSE("OTA ERROR: get response on \"%s\" returned status %d", OTACloudProcessInterface::context->url, statusCode);
return HttpResponseFail;
}
// make the http get request
requestOta(OtaFetchTime);
andreagilardoni marked this conversation as resolved.
Show resolved Hide resolved

// The following call is required to save the header value , keep it
if(http_client->contentLength() == HttpClient::kNoContentLengthHeader) {
context->contentLength = http_client->contentLength();
if(context->contentLength == HttpClient::kNoContentLengthHeader) {
DEBUG_VERBOSE("OTA ERROR: the response header doesn't contain \"ContentLength\" field");
return HttpHeaderErrorFail;
}

context->lastReportTime = millis();

DEBUG_VERBOSE("OTA file length: %d", context->contentLength);
return Fetch;
}

OTACloudProcessInterface::State OTADefaultCloudProcessInterface::fetch() {
OTACloudProcessInterface::State res = Fetch;
int http_res = 0;
uint32_t start = millis();

if(fetchMode == OtaFetchChunk) {
res = requestOta(OtaFetchChunk);
}

context->downloadedChunkSize = 0;
context->downloadedChunkStartTime = millis();

if(res != Fetch) {
goto exit;
}

/* download chunked or timed */
do {
if(!http_client->connected()) {
res = OtaDownloadFail;
Expand All @@ -104,7 +91,7 @@ OTACloudProcessInterface::State OTADefaultCloudProcessInterface::fetch() {
continue;
}

http_res = http_client->read(context->buffer, context->buf_len);
int http_res = http_client->read(context->buffer, context->bufLen);

if(http_res < 0) {
DEBUG_VERBOSE("OTA ERROR: Download read error %d", http_res);
Expand All @@ -119,8 +106,10 @@ OTACloudProcessInterface::State OTADefaultCloudProcessInterface::fetch() {
res = ErrorWriteUpdateFileFail;
goto exit;
}
} while((context->downloadState == OtaDownloadFile || context->downloadState == OtaDownloadHeader) &&
millis() - start < downloadTime);

context->downloadedChunkSize += http_res;

} while(context->downloadState < OtaDownloadCompleted && fetchMore());

// TODO verify that the information present in the ota header match the info in context
if(context->downloadState == OtaDownloadCompleted) {
Expand Down Expand Up @@ -153,13 +142,69 @@ OTACloudProcessInterface::State OTADefaultCloudProcessInterface::fetch() {
return res;
}

void OTADefaultCloudProcessInterface::parseOta(uint8_t* buffer, size_t buf_len) {
OTACloudProcessInterface::State OTADefaultCloudProcessInterface::requestOta(OTAFetchMode mode) {
int http_res = 0;

/* stop connected client */
http_client->stop();

/* request chunk */
http_client->beginRequest();
http_res = http_client->get(context->parsed_url.path());

if(username != nullptr && password != nullptr) {
http_client->sendBasicAuth(username, password);
}

if(mode == OtaFetchChunk) {
char range[128] = {0};
size_t rangeSize = context->downloadedSize + maxChunkSize > context->contentLength ? context->contentLength - context->downloadedSize : maxChunkSize;
sprintf(range, "bytes=%d-%d", context->downloadedSize, context->downloadedSize + rangeSize);
DEBUG_VERBOSE("OTA downloading range: %s", range);
http_client->sendHeader("Range", range);
}

http_client->endRequest();

if(http_res == HTTP_ERROR_CONNECTION_FAILED) {
DEBUG_VERBOSE("OTA ERROR: http client error connecting to server \"%s:%d\"",
context->parsed_url.host(), context->parsed_url.port());
return ServerConnectErrorFail;
} else if(http_res == HTTP_ERROR_TIMED_OUT) {
DEBUG_VERBOSE("OTA ERROR: http client timeout \"%s\"", OTACloudProcessInterface::context->url);
return OtaHeaderTimeoutFail;
} else if(http_res != HTTP_SUCCESS) {
DEBUG_VERBOSE("OTA ERROR: http client returned %d on get \"%s\"", http_res, OTACloudProcessInterface::context->url);
return OtaDownloadFail;
}

int statusCode = http_client->responseStatusCode();

if(((mode == OtaFetchChunk) && (statusCode != 206)) || ((mode == OtaFetchTime) && (statusCode != 200))) {
DEBUG_VERBOSE("OTA ERROR: get response on \"%s\" returned status %d", OTACloudProcessInterface::context->url, statusCode);
return HttpResponseFail;
}

http_client->skipResponseHeaders();
andreagilardoni marked this conversation as resolved.
Show resolved Hide resolved

return Fetch;
}

bool OTADefaultCloudProcessInterface::fetchMore() {
if (fetchMode == OtaFetchChunk) {
return context->downloadedChunkSize < maxChunkSize;
} else {
return (millis() - context->downloadedChunkStartTime) < downloadTime;
}
}

void OTADefaultCloudProcessInterface::parseOta(uint8_t* buffer, size_t bufLen) {
assert(context != nullptr); // This should never fail

for(uint8_t* cursor=(uint8_t*)buffer; cursor<buffer+buf_len; ) {
for(uint8_t* cursor=(uint8_t*)buffer; cursor<buffer+bufLen; ) {
switch(context->downloadState) {
case OtaDownloadHeader: {
const uint32_t headerLeft = context->headerCopiedBytes + buf_len <= sizeof(context->header.buf) ? buf_len : sizeof(context->header.buf) - context->headerCopiedBytes;
const uint32_t headerLeft = context->headerCopiedBytes + bufLen <= sizeof(context->header.buf) ? bufLen : sizeof(context->header.buf) - context->headerCopiedBytes;
memcpy(context->header.buf+context->headerCopiedBytes, buffer, headerLeft);
cursor += headerLeft;
context->headerCopiedBytes += headerLeft;
Expand All @@ -184,8 +229,7 @@ void OTADefaultCloudProcessInterface::parseOta(uint8_t* buffer, size_t buf_len)
break;
}
case OtaDownloadFile: {
const uint32_t contentLength = http_client->contentLength();
const uint32_t dataLeft = buf_len - (cursor-buffer);
const uint32_t dataLeft = bufLen - (cursor-buffer);
context->decoder.decompress(cursor, dataLeft); // TODO verify return value

context->calculatedCrc32 = crc_update(
Expand All @@ -198,18 +242,18 @@ void OTADefaultCloudProcessInterface::parseOta(uint8_t* buffer, size_t buf_len)
context->downloadedSize += dataLeft;

if((millis() - context->lastReportTime) > 10000) { // Report the download progress each X millisecond
DEBUG_VERBOSE("OTA Download Progress %d/%d", context->downloadedSize, contentLength);
DEBUG_VERBOSE("OTA Download Progress %d/%d", context->downloadedSize, context->contentLength);

reportStatus(context->downloadedSize);
context->lastReportTime = millis();
}

// TODO there should be no more bytes available when the download is completed
if(context->downloadedSize == contentLength) {
if(context->downloadedSize == context->contentLength) {
context->downloadState = OtaDownloadCompleted;
}

if(context->downloadedSize > contentLength) {
if(context->downloadedSize > context->contentLength) {
context->downloadState = OtaDownloadError;
}
// TODO fail if we exceed a timeout? and available is 0 (client is broken)
Expand Down Expand Up @@ -250,7 +294,9 @@ OTADefaultCloudProcessInterface::Context::Context(
, headerCopiedBytes(0)
, downloadedSize(0)
, lastReportTime(0)
, contentLength(0)
, writeError(false)
, downloadedChunkSize(0)
, decoder(putc) { }

static const uint32_t crc_table[256] = {
Expand Down
20 changes: 18 additions & 2 deletions src/ota/interface/OTAInterfaceDefault.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,24 +35,36 @@ class OTADefaultCloudProcessInterface: public OTACloudProcessInterface {
this->password = password;
}

enum OTAFetchMode: uint8_t {
OtaFetchTime,
OtaFetchChunk
};

inline virtual void setFetchMode(OTAFetchMode mode) { this->fetchMode = mode; }

protected:
State startOTA();
State fetch();
void reset();
virtual int writeFlash(uint8_t* const buffer, size_t len) = 0;

private:
void parseOta(uint8_t* buffer, size_t buf_len);
void parseOta(uint8_t* buffer, size_t bufLen);
State requestOta(OTAFetchMode mode);
bool fetchMore();

Client* client;
HttpClient* http_client;

const char *username, *password;
OTAFetchMode fetchMode;

// The amount of time that each iteration of Fetch has to take at least
// This mitigate the issues arising from tasks run in main loop that are using all the computing time
static constexpr uint32_t downloadTime = 2000;

static constexpr size_t maxChunkSize = 1024 * 10;

enum OTADownloadState: uint8_t {
OtaDownloadHeader,
OtaDownloadFile,
Expand All @@ -74,12 +86,16 @@ class OTADefaultCloudProcessInterface: public OTACloudProcessInterface {
uint32_t headerCopiedBytes;
uint32_t downloadedSize;
uint32_t lastReportTime;
uint32_t contentLength;
bool writeError;

uint32_t downloadedChunkStartTime;
uint32_t downloadedChunkSize;

// LZSS decoder
LZSSDecoder decoder;

const size_t buf_len = 64;
const size_t bufLen = 64;
uint8_t buffer[64];
} *context;
};
Expand Down
Loading