Skip to content

Commit

Permalink
Merge pull request #520 from pennam/main-ota-chunked
Browse files Browse the repository at this point in the history
OTA: chunked download
  • Loading branch information
pennam authored Dec 16, 2024
2 parents 07da25e + 0f53459 commit a5dcd3d
Show file tree
Hide file tree
Showing 5 changed files with 122 additions and 50 deletions.
13 changes: 11 additions & 2 deletions src/ArduinoIoTCloudTCP.h
Original file line number Diff line number Diff line change
Expand Up @@ -96,9 +96,18 @@ class ArduinoIoTCloudTCP: public ArduinoIoTCloudClass
_get_ota_confirmation = cb;

if(_get_ota_confirmation) {
_ota.setOtaPolicies(OTACloudProcessInterface::ApprovalRequired);
_ota.enableOtaPolicy(OTACloudProcessInterface::ApprovalRequired);
} else {
_ota.setOtaPolicies(OTACloudProcessInterface::None);
_ota.disableOtaPolicy(OTACloudProcessInterface::ApprovalRequired);
}
}

/* Slower but more reliable in some corner cases */
void setOTAChunkMode(bool enable = true) {
if(enable) {
_ota.enableOtaPolicy(OTACloudProcessInterface::ChunkDownload);
} else {
_ota.disableOtaPolicy(OTACloudProcessInterface::ChunkDownload);
}
}
#endif
Expand Down
4 changes: 2 additions & 2 deletions src/ota/interface/OTAInterface.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -167,10 +167,10 @@ OTACloudProcessInterface::State OTACloudProcessInterface::idle(Message* msg) {
OTACloudProcessInterface::State OTACloudProcessInterface::otaAvailable() {
// depending on the policy decided on this device the ota process can start immediately
// or wait for confirmation from the user
if((policies & (ApprovalRequired | Approved)) == ApprovalRequired ) {
if(getOtaPolicy(ApprovalRequired) && !getOtaPolicy(Approved)) {
return OtaAvailable;
} else {
policies &= ~Approved;
disableOtaPolicy(Approved);
return StartOTA;
} // TODO add an abortOTA command? in this case delete the context
}
Expand Down
9 changes: 7 additions & 2 deletions src/ota/interface/OTAInterface.h
Original file line number Diff line number Diff line change
Expand Up @@ -80,17 +80,22 @@ class OTACloudProcessInterface: public CloudProcess {
enum OtaFlags: uint16_t {
None = 0,
ApprovalRequired = 1,
Approved = 1<<1
Approved = 1<<1,
ChunkDownload = 1<<2
};

virtual void handleMessage(Message*);
// virtual CloudProcess::State getState();
// virtual void hook(State s, void* action);
virtual void update() { handleMessage(nullptr); }

inline void approveOta() { policies |= Approved; }
inline void approveOta() { this->policies |= Approved; }
inline void setOtaPolicies(uint16_t policies) { this->policies = policies; }

inline void enableOtaPolicy(OtaFlags policyFlag) { this->policies |= policyFlag; }
inline void disableOtaPolicy(OtaFlags policyFlag) { this->policies &= ~policyFlag; }
inline bool getOtaPolicy(OtaFlags policyFlag) { return (this->policies & policyFlag) != 0;}

inline State getState() { return state; }

virtual bool isOtaCapable() = 0;
Expand Down
130 changes: 89 additions & 41 deletions src/ota/interface/OTAInterfaceDefault.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -41,39 +41,17 @@ OTACloudProcessInterface::State OTADefaultCloudProcessInterface::startOTA() {
}
);

// make the http get request
// check url
if(strcmp(context->parsed_url.schema(), "https") == 0) {
http_client = new HttpClient(*client, context->parsed_url.host(), context->parsed_url.port());
} else {
return UrlParseErrorFail;
}

http_client->beginRequest();
auto res = http_client->get(context->parsed_url.path());

if(username != nullptr && password != nullptr) {
http_client->sendBasicAuth(username, password);
}

http_client->endRequest();

if(res == HTTP_ERROR_CONNECTION_FAILED) {
DEBUG_VERBOSE("OTA ERROR: http client error connecting to server \"%s:%d\"",
context->parsed_url.host(), context->parsed_url.port());
return ServerConnectErrorFail;
} else if(res == HTTP_ERROR_TIMED_OUT) {
DEBUG_VERBOSE("OTA ERROR: http client timeout \"%s\"", OTACloudProcessInterface::context->url);
return OtaHeaderTimeoutFail;
} else if(res != HTTP_SUCCESS) {
DEBUG_VERBOSE("OTA ERROR: http client returned %d on get \"%s\"", res, OTACloudProcessInterface::context->url);
return OtaDownloadFail;
}

int statusCode = http_client->responseStatusCode();

if(statusCode != 200) {
DEBUG_VERBOSE("OTA ERROR: get response on \"%s\" returned status %d", OTACloudProcessInterface::context->url, statusCode);
return HttpResponseFail;
// make the http get request
OTACloudProcessInterface::State res = requestOta();
if(res != Fetch) {
return res;
}

// The following call is required to save the header value , keep it
Expand All @@ -82,16 +60,27 @@ OTACloudProcessInterface::State OTADefaultCloudProcessInterface::startOTA() {
return HttpHeaderErrorFail;
}

context->contentLength = http_client->contentLength();
context->lastReportTime = millis();

DEBUG_VERBOSE("OTA file length: %d", context->contentLength);
return Fetch;
}

OTACloudProcessInterface::State OTADefaultCloudProcessInterface::fetch() {
OTACloudProcessInterface::State res = Fetch;
int http_res = 0;
uint32_t start = millis();

if(getOtaPolicy(ChunkDownload)) {
res = requestOta(ChunkDownload);
}

context->downloadedChunkSize = 0;
context->downloadedChunkStartTime = millis();

if(res != Fetch) {
goto exit;
}

/* download chunked or timed */
do {
if(!http_client->connected()) {
res = OtaDownloadFail;
Expand All @@ -104,7 +93,7 @@ OTACloudProcessInterface::State OTADefaultCloudProcessInterface::fetch() {
continue;
}

http_res = http_client->read(context->buffer, context->buf_len);
int http_res = http_client->read(context->buffer, context->bufLen);

if(http_res < 0) {
DEBUG_VERBOSE("OTA ERROR: Download read error %d", http_res);
Expand All @@ -119,8 +108,10 @@ OTACloudProcessInterface::State OTADefaultCloudProcessInterface::fetch() {
res = ErrorWriteUpdateFileFail;
goto exit;
}
} while((context->downloadState == OtaDownloadFile || context->downloadState == OtaDownloadHeader) &&
millis() - start < downloadTime);

context->downloadedChunkSize += http_res;

} while(context->downloadState < OtaDownloadCompleted && fetchMore());

// TODO verify that the information present in the ota header match the info in context
if(context->downloadState == OtaDownloadCompleted) {
Expand Down Expand Up @@ -153,13 +144,69 @@ OTACloudProcessInterface::State OTADefaultCloudProcessInterface::fetch() {
return res;
}

void OTADefaultCloudProcessInterface::parseOta(uint8_t* buffer, size_t buf_len) {
OTACloudProcessInterface::State OTADefaultCloudProcessInterface::requestOta(OtaFlags mode) {
int http_res = 0;

/* stop connected client */
http_client->stop();

/* request chunk */
http_client->beginRequest();
http_res = http_client->get(context->parsed_url.path());

if(username != nullptr && password != nullptr) {
http_client->sendBasicAuth(username, password);
}

if((mode & ChunkDownload) == ChunkDownload) {
char range[128] = {0};
size_t rangeSize = context->downloadedSize + maxChunkSize > context->contentLength ? context->contentLength - context->downloadedSize : maxChunkSize;
sprintf(range, "bytes=%" PRIu32 "-%" PRIu32, context->downloadedSize, context->downloadedSize + rangeSize);
DEBUG_VERBOSE("OTA downloading range: %s", range);
http_client->sendHeader("Range", range);
}

http_client->endRequest();

if(http_res == HTTP_ERROR_CONNECTION_FAILED) {
DEBUG_VERBOSE("OTA ERROR: http client error connecting to server \"%s:%d\"",
context->parsed_url.host(), context->parsed_url.port());
return ServerConnectErrorFail;
} else if(http_res == HTTP_ERROR_TIMED_OUT) {
DEBUG_VERBOSE("OTA ERROR: http client timeout \"%s\"", OTACloudProcessInterface::context->url);
return OtaHeaderTimeoutFail;
} else if(http_res != HTTP_SUCCESS) {
DEBUG_VERBOSE("OTA ERROR: http client returned %d on get \"%s\"", http_res, OTACloudProcessInterface::context->url);
return OtaDownloadFail;
}

int statusCode = http_client->responseStatusCode();

if((((mode & ChunkDownload) == ChunkDownload) && (statusCode != 206)) ||
(((mode & ChunkDownload) != ChunkDownload) && (statusCode != 200))) {
DEBUG_VERBOSE("OTA ERROR: get response on \"%s\" returned status %d", OTACloudProcessInterface::context->url, statusCode);
return HttpResponseFail;
}

http_client->skipResponseHeaders();
return Fetch;
}

bool OTADefaultCloudProcessInterface::fetchMore() {
if (getOtaPolicy(ChunkDownload)) {
return context->downloadedChunkSize < maxChunkSize;
} else {
return (millis() - context->downloadedChunkStartTime) < downloadTime;
}
}

void OTADefaultCloudProcessInterface::parseOta(uint8_t* buffer, size_t bufLen) {
assert(context != nullptr); // This should never fail

for(uint8_t* cursor=(uint8_t*)buffer; cursor<buffer+buf_len; ) {
for(uint8_t* cursor=(uint8_t*)buffer; cursor<buffer+bufLen; ) {
switch(context->downloadState) {
case OtaDownloadHeader: {
const uint32_t headerLeft = context->headerCopiedBytes + buf_len <= sizeof(context->header.buf) ? buf_len : sizeof(context->header.buf) - context->headerCopiedBytes;
const uint32_t headerLeft = context->headerCopiedBytes + bufLen <= sizeof(context->header.buf) ? bufLen : sizeof(context->header.buf) - context->headerCopiedBytes;
memcpy(context->header.buf+context->headerCopiedBytes, buffer, headerLeft);
cursor += headerLeft;
context->headerCopiedBytes += headerLeft;
Expand All @@ -184,8 +231,7 @@ void OTADefaultCloudProcessInterface::parseOta(uint8_t* buffer, size_t buf_len)
break;
}
case OtaDownloadFile: {
const uint32_t contentLength = http_client->contentLength();
const uint32_t dataLeft = buf_len - (cursor-buffer);
const uint32_t dataLeft = bufLen - (cursor-buffer);
context->decoder.decompress(cursor, dataLeft); // TODO verify return value

context->calculatedCrc32 = crc_update(
Expand All @@ -198,18 +244,18 @@ void OTADefaultCloudProcessInterface::parseOta(uint8_t* buffer, size_t buf_len)
context->downloadedSize += dataLeft;

if((millis() - context->lastReportTime) > 10000) { // Report the download progress each X millisecond
DEBUG_VERBOSE("OTA Download Progress %d/%d", context->downloadedSize, contentLength);
DEBUG_VERBOSE("OTA Download Progress %d/%d", context->downloadedSize, context->contentLength);

reportStatus(context->downloadedSize);
context->lastReportTime = millis();
}

// TODO there should be no more bytes available when the download is completed
if(context->downloadedSize == contentLength) {
if(context->downloadedSize == context->contentLength) {
context->downloadState = OtaDownloadCompleted;
}

if(context->downloadedSize > contentLength) {
if(context->downloadedSize > context->contentLength) {
context->downloadState = OtaDownloadError;
}
// TODO fail if we exceed a timeout? and available is 0 (client is broken)
Expand Down Expand Up @@ -250,7 +296,9 @@ OTADefaultCloudProcessInterface::Context::Context(
, headerCopiedBytes(0)
, downloadedSize(0)
, lastReportTime(0)
, contentLength(0)
, writeError(false)
, downloadedChunkSize(0)
, decoder(putc) { }

static const uint32_t crc_table[256] = {
Expand Down
16 changes: 13 additions & 3 deletions src/ota/interface/OTAInterfaceDefault.h
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,9 @@ class OTADefaultCloudProcessInterface: public OTACloudProcessInterface {
virtual int writeFlash(uint8_t* const buffer, size_t len) = 0;

private:
void parseOta(uint8_t* buffer, size_t buf_len);
void parseOta(uint8_t* buffer, size_t bufLen);
State requestOta(OtaFlags mode = None);
bool fetchMore();

Client* client;
HttpClient* http_client;
Expand All @@ -53,6 +55,10 @@ class OTADefaultCloudProcessInterface: public OTACloudProcessInterface {
// This mitigate the issues arising from tasks run in main loop that are using all the computing time
static constexpr uint32_t downloadTime = 2000;

// The amount of data that each iteration of Fetch has to take at least
// This should be enabled setting ChunkDownload OtaFlag to 1 and mitigate some Ota corner cases
static constexpr size_t maxChunkSize = 1024 * 10;

enum OTADownloadState: uint8_t {
OtaDownloadHeader,
OtaDownloadFile,
Expand All @@ -74,13 +80,17 @@ class OTADefaultCloudProcessInterface: public OTACloudProcessInterface {
uint32_t headerCopiedBytes;
uint32_t downloadedSize;
uint32_t lastReportTime;
uint32_t contentLength;
bool writeError;

uint32_t downloadedChunkStartTime;
uint32_t downloadedChunkSize;

// LZSS decoder
LZSSDecoder decoder;

const size_t buf_len = 64;
uint8_t buffer[64];
static constexpr size_t bufLen = 64;
uint8_t buffer[bufLen];
} *context;
};

Expand Down

0 comments on commit a5dcd3d

Please sign in to comment.