diff --git a/src/Arduino_ESP32_OTA.cpp b/src/Arduino_ESP32_OTA.cpp index d500961..8a4cbda 100644 --- a/src/Arduino_ESP32_OTA.cpp +++ b/src/Arduino_ESP32_OTA.cpp @@ -22,8 +22,6 @@ #include #include "Arduino_ESP32_OTA.h" #include "tls/amazon_root_ca.h" -#include "decompress/lzss.h" -#include "decompress/utility.h" #include "esp_ota_ops.h" /****************************************************************************** @@ -31,10 +29,9 @@ ******************************************************************************/ Arduino_ESP32_OTA::Arduino_ESP32_OTA() -:_client{nullptr} -,_ota_header{0} -,_ota_size(0) -,_crc32(0) +: _context(nullptr) +, _client(nullptr) +, _http_client(nullptr) ,_ca_cert{amazon_root_ca} ,_ca_cert_bundle{nullptr} ,_magic(0) @@ -42,18 +39,16 @@ Arduino_ESP32_OTA::Arduino_ESP32_OTA() } +Arduino_ESP32_OTA::~Arduino_ESP32_OTA(){ + clean(); +} + /****************************************************************************** PUBLIC MEMBER FUNCTIONS ******************************************************************************/ Arduino_ESP32_OTA::Error Arduino_ESP32_OTA::begin(uint32_t magic) { - /* initialize private variables */ - otaInit(); - - /* ... initialize CRC ... */ - crc32Init(); - /* ... configure board Magic number */ setMagic(magic); @@ -93,38 +88,29 @@ void Arduino_ESP32_OTA::setMagic(uint32_t magic) _magic = magic; } -uint8_t Arduino_ESP32_OTA::read_byte_from_network() -{ - bool is_http_data_timeout = false; - for(unsigned long const start = millis();;) - { - is_http_data_timeout = (millis() - start) > ARDUINO_ESP32_OTA_BINARY_BYTE_RECEIVE_TIMEOUT_ms; - if (is_http_data_timeout) { - DEBUG_ERROR("%s: timeout waiting data", __FUNCTION__); - return -1; - } - if (_client->available()) { - const uint8_t data = _client->read(); - crc32Update(data); - return data; - } - } -} - void Arduino_ESP32_OTA::write_byte_to_flash(uint8_t data) { Update.write(&data, 1); } -int Arduino_ESP32_OTA::download(const char * ota_url) +int Arduino_ESP32_OTA::startDownload(const char * ota_url) { - URI url(ota_url); - int port = 0; + assert(_context == nullptr); + assert(_client == nullptr); + assert(_http_client == nullptr); + + Error err = Error::None; + int statusCode; + int res; - if (url.protocol_ == "http") { + _context = new Context(ota_url, [this](uint8_t data){ + _context->writtenBytes++; + write_byte_to_flash(data); + }); + + if(strcmp(_context->parsed_url.schema(), "http") == 0) { _client = new WiFiClient(); - port = 80; - } else if (url.protocol_ == "https") { + } else if(strcmp(_context->parsed_url.schema(), "https") == 0) { _client = new WiFiClientSecure(); if (_ca_cert != nullptr) { static_cast(_client)->setCACert(_ca_cert); @@ -133,152 +119,206 @@ int Arduino_ESP32_OTA::download(const char * ota_url) } else { DEBUG_VERBOSE("%s: CA not configured for download client"); } - port = 443; } else { - DEBUG_ERROR("%s: Failed to parse OTA URL %s", __FUNCTION__, ota_url); - return static_cast(Error::UrlParseError); + err = Error::UrlParseError; + goto exit; } - if (!_client->connect(url.host_.c_str(), port)) - { - DEBUG_ERROR("%s: Connection failure with OTA storage server %s", __FUNCTION__, url.host_.c_str()); - delete _client; - _client = nullptr; - return static_cast(Error::ServerConnectError); + _http_client = new HttpClient(*_client, _context->parsed_url.host(), _context->parsed_url.port()); + + res= _http_client->get(_context->parsed_url.path()); + + if(res == HTTP_ERROR_CONNECTION_FAILED) { + DEBUG_VERBOSE("OTA ERROR: http client error connecting to server \"%s:%d\"", + _context->parsed_url.host(), _context->parsed_url.port()); + err = Error::ServerConnectError; + goto exit; + } else if(res == HTTP_ERROR_TIMED_OUT) { + DEBUG_VERBOSE("OTA ERROR: http client timeout \"%s\"", _context->url); + err = Error::OtaHeaderTimeout; + goto exit; + } else if(res != HTTP_SUCCESS) { + DEBUG_VERBOSE("OTA ERROR: http client returned %d on get \"%s\"", res, _context->url); + err = Error::OtaDownload; + goto exit; } - _client->println(String("GET ") + url.path_.c_str() + " HTTP/1.1"); - _client->println(String("Host: ") + url.host_.c_str()); - _client->println("Connection: close"); - _client->println(); - - /* Receive HTTP header. */ - String http_header; - bool is_header_complete = false, - is_http_header_timeout = false; - for (unsigned long const start = millis(); !is_header_complete;) - { - is_http_header_timeout = (millis() - start) > ARDUINO_ESP32_OTA_HTTP_HEADER_RECEIVE_TIMEOUT_ms; - if (is_http_header_timeout) break; - - if (_client->available()) - { - char const c = _client->read(); - - http_header += c; - if (http_header.endsWith("\r\n\r\n")) - is_header_complete = true; - } + statusCode = _http_client->responseStatusCode(); + + if(statusCode != 200) { + DEBUG_VERBOSE("OTA ERROR: get response on \"%s\" returned status %d", _context->url, statusCode); + err = Error::HttpResponse; + goto exit; } - if (!is_header_complete) - { - DEBUG_ERROR("%s: Error receiving HTTP header %s", __FUNCTION__, is_http_header_timeout ? "(timeout)":""); - delete _client; - _client = nullptr; - return static_cast(Error::HttpHeaderError); + // The following call is required to save the header value , keep it + if(_http_client->contentLength() == HttpClient::kNoContentLengthHeader) { + DEBUG_VERBOSE("OTA ERROR: the response header doesn't contain \"ContentLength\" field"); + err = Error::HttpHeaderError; + goto exit; } - /* Check HTTP response status code */ - char const * http_response_ptr = strstr(http_header.c_str(), "HTTP/1.1"); - if (!http_response_ptr) - { - DEBUG_ERROR("%s: Failure to extract http response from header", __FUNCTION__); - return static_cast(Error::ParseHttpHeader); +exit: + if(err != Error::None) { + clean(); + return static_cast(err); + } else { + return _http_client->contentLength(); } - /* Find start of numerical value. */ - char * ptr = const_cast(http_response_ptr); - for (ptr += strlen("HTTP/1.1"); (*ptr != '\0') && !isDigit(*ptr); ptr++) { } - /* Extract numerical value. */ - String http_response_str; - for (; isDigit(*ptr); ptr++) http_response_str += *ptr; - int const http_response = atoi(http_response_str.c_str()); - - if (http_response != 200) { - DEBUG_ERROR("%s: HTTP response status code = %d", __FUNCTION__, http_response); - return static_cast(Error::HttpResponse); +} + +int Arduino_ESP32_OTA::progressDownload() +{ + int http_res = static_cast(Error::None);; + int res = 0; + + if(_http_client->available() == 0) { + goto exit; } - /* Extract content length from HTTP header. A typical entry looks like - * "Content-Length: 123456" - */ - char const * content_length_ptr = strstr(http_header.c_str(), "Content-Length"); - if (!content_length_ptr) - { - DEBUG_ERROR("%s: Failure to extract content length from http header", __FUNCTION__); - delete _client; - _client = nullptr; - return static_cast(Error::ParseHttpHeader); + http_res = _http_client->read(_context->buffer, _context->buf_len); + + if(http_res < 0) { + DEBUG_VERBOSE("OTA ERROR: Download read error %d", http_res); + res = static_cast(Error::OtaDownload); + goto exit; } - /* Find start of numerical value. */ - ptr = const_cast(content_length_ptr); - for (; (*ptr != '\0') && !isDigit(*ptr); ptr++) { } - /* Extract numerical value. */ - String content_length_str; - for (; isDigit(*ptr); ptr++) content_length_str += *ptr; - int const content_length_val = atoi(content_length_str.c_str()); - DEBUG_VERBOSE("%s: Length of OTA binary according to HTTP header = %d bytes", __FUNCTION__, content_length_val); - - /* Read the OTA header ... */ - bool is_ota_header_timeout = false; - unsigned long const start = millis(); - for (int i = 0; i < sizeof(OtaHeader);) - { - is_ota_header_timeout = (millis() - start) > ARDUINO_ESP32_OTA_BINARY_HEADER_RECEIVE_TIMEOUT_ms; - if (is_ota_header_timeout) break; - - if (_client->available()) - { - _ota_header.buf[i++] = _client->read(); + + for(uint8_t* cursor=(uint8_t*)_context->buffer; cursor<_context->buffer+http_res; ) { + switch(_context->downloadState) { + case OtaDownloadHeader: { + uint32_t copied = http_res < sizeof(_context->header.buf) ? http_res : sizeof(_context->header.buf); + memcpy(_context->header.buf+_context->headerCopiedBytes, _context->buffer, copied); + cursor += copied; + _context->headerCopiedBytes += copied; + + // when finished go to next state + if(sizeof(_context->header.buf) == _context->headerCopiedBytes) { + _context->downloadState = OtaDownloadFile; + + _context->calculatedCrc32 = crc_update( + _context->calculatedCrc32, + &(_context->header.header.magic_number), + sizeof(_context->header) - offsetof(OtaHeader, header.magic_number) + ); + + if(_context->header.header.magic_number != _magic) { + _context->downloadState = OtaDownloadMagicNumberMismatch; + res = static_cast(Error::OtaHeaderMagicNumber); + + goto exit; + } + } + + break; + } + case OtaDownloadFile: + _context->decoder.decompress(cursor, http_res - (cursor-_context->buffer)); // TODO verify return value + + _context->calculatedCrc32 = crc_update( + _context->calculatedCrc32, + cursor, + http_res - (cursor-_context->buffer) + ); + + cursor += http_res - (cursor-_context->buffer); + _context->downloadedSize += (cursor-_context->buffer); + + // TODO there should be no more bytes available when the download is completed + if(_context->downloadedSize == _http_client->contentLength()) { + _context->downloadState = OtaDownloadCompleted; + res = 1; + } + + if(_context->downloadedSize > _http_client->contentLength()) { + _context->downloadState = OtaDownloadError; + res = static_cast(Error::OtaDownload); + } + // TODO fail if we exceed a timeout? and available is 0 (client is broken) + break; + case OtaDownloadCompleted: + res = 1; + goto exit; + default: + _context->downloadState = OtaDownloadError; + res = static_cast(Error::OtaDownload); + goto exit; } } - /* ... check for header download timeout ... */ - if (is_ota_header_timeout) { - delete _client; - _client = nullptr; - return static_cast(Error::OtaHeaderTimeout); +exit: + if(_context->downloadState == OtaDownloadError || + _context->downloadState == OtaDownloadMagicNumberMismatch) { + clean(); // need to clean everything because the download failed + } else if(_context->downloadState == OtaDownloadCompleted) { + // only need to delete clients and not the context, since it will be needed + if(_client != nullptr) { + delete _client; + _client = nullptr; + } + + if(_http_client != nullptr) { + delete _http_client; + _http_client = nullptr; + } } - /* ... then check if OTA header length field matches HTTP content length... */ - if (_ota_header.header.len != (content_length_val - sizeof(_ota_header.header.len) - sizeof(_ota_header.header.crc32))) { - delete _client; - _client = nullptr; - return static_cast(Error::OtaHeaderLength); + return res; +} + +int Arduino_ESP32_OTA::downloadProgress() +{ + if(_context->error != Error::None) { + return static_cast(_context->error); + } else { + return _context->downloadedSize; } +} - /* ... and OTA magic number */ - if (_ota_header.header.magic_number != _magic) - { - delete _client; - _client = nullptr; - return static_cast(Error::OtaHeaterMagicNumber); +size_t Arduino_ESP32_OTA::downloadSize() +{ + return _http_client!=nullptr ? _http_client->contentLength() : 0; +} + +int Arduino_ESP32_OTA::download(const char * ota_url) +{ + int err = startDownload(ota_url); + + if(err < 0) { + return err; } - /* ... start CRC32 from OTA MAGIC ... */ - _crc32 = crc_update(_crc32, &_ota_header.header.magic_number, 12); + int res = 0; + while((res = progressDownload()) <= 0); - /* Download and decode OTA file */ - _ota_size = lzss_download(this, content_length_val - sizeof(_ota_header)); + return res == 1? _context->writtenBytes : res; +} - if(_ota_size <= content_length_val - sizeof(_ota_header)) - { +void Arduino_ESP32_OTA::clean() +{ + if(_client != nullptr) { delete _client; _client = nullptr; - return static_cast(Error::OtaDownload); } - delete _client; - _client = nullptr; - return _ota_size; + if(_http_client != nullptr) { + delete _http_client; + _http_client = nullptr; + } + + if(_context != nullptr) { + delete _context; + _context = nullptr; + } } Arduino_ESP32_OTA::Error Arduino_ESP32_OTA::update() { /* ... then finalize ... */ - crc32Finalize(); + _context->calculatedCrc32 ^= 0xFFFFFFFF; - if(!crc32Verify()) { + /* Verify the crc */ + if(_context->header.header.crc32 != _context->calculatedCrc32) { DEBUG_ERROR("%s: CRC32 mismatch", __FUNCTION__); return Error::OtaHeaderCrc; } @@ -288,6 +328,8 @@ Arduino_ESP32_OTA::Error Arduino_ESP32_OTA::update() return Error::OtaStorageEnd; } + clean(); + return Error::None; } @@ -307,28 +349,20 @@ bool Arduino_ESP32_OTA::isCapable() PROTECTED MEMBER FUNCTIONS ******************************************************************************/ -void Arduino_ESP32_OTA::otaInit() -{ - _ota_size = 0; - _ota_header = {0}; -} - -void Arduino_ESP32_OTA::crc32Init() -{ - _crc32 = 0xFFFFFFFF; -} - -void Arduino_ESP32_OTA::crc32Update(const uint8_t data) -{ - _crc32 = crc_update(_crc32, &data, 1); -} - -void Arduino_ESP32_OTA::crc32Finalize() -{ - _crc32 ^= 0xFFFFFFFF; -} +Arduino_ESP32_OTA::Context::Context( + const char* url, std::function putc) + : url((char*)malloc(strlen(url)+1)) + , parsed_url(url) + , downloadState(OtaDownloadHeader) + , calculatedCrc32(0xFFFFFFFF) + , headerCopiedBytes(0) + , downloadedSize(0) + , error(Error::None) + , decoder(putc) { + strcpy(this->url, url); + } -bool Arduino_ESP32_OTA::crc32Verify() -{ - return (_crc32 == _ota_header.header.crc32); -} +Arduino_ESP32_OTA::Context::~Context(){ + free(url); + url = nullptr; +} \ No newline at end of file diff --git a/src/Arduino_ESP32_OTA.h b/src/Arduino_ESP32_OTA.h index 804a50b..d965821 100644 --- a/src/Arduino_ESP32_OTA.h +++ b/src/Arduino_ESP32_OTA.h @@ -25,6 +25,10 @@ #include #include #include "decompress/utility.h" +#include "decompress/lzss.h" +#include +#include +#include /****************************************************************************** DEFINES @@ -69,37 +73,86 @@ class Arduino_ESP32_OTA HttpResponse = -14 }; + enum OTADownloadState: uint8_t { + OtaDownloadHeader, + OtaDownloadFile, + OtaDownloadCompleted, + OtaDownloadMagicNumberMismatch, + OtaDownloadError + }; + Arduino_ESP32_OTA(); - virtual ~Arduino_ESP32_OTA() { } + virtual ~Arduino_ESP32_OTA(); Arduino_ESP32_OTA::Error begin(uint32_t magic = ARDUINO_ESP32_OTA_MAGIC); void setMagic(uint32_t magic); void setCACert(const char *rootCA); void setCACertBundle(const uint8_t * bundle); + + // blocking version for the download + // returns the size of the downloaded binary int download(const char * ota_url); - uint8_t read_byte_from_network(); + + // start a download in a non blocking fashion + // call progressDownload, until it returns OtaDownloadCompleted + // returns the value in content-length http header + int startDownload(const char * ota_url); + + // This function is used to make the download progress. + // it returns 0, if the download is in progress + // it returns 1, if the download is completed + // it returns <0 if an error occurred, following Error enum values + virtual int progressDownload(); + + // this function is used to get the progress of the download + // it returns a positive value when the download is progressing correctly + // it returns a negative value on error following Error enum values + int downloadProgress(); + + // this function is used to get the size of the download + // 0 if no download is in progress + size_t downloadSize(); + virtual void write_byte_to_flash(uint8_t data); Arduino_ESP32_OTA::Error update(); void reset(); static bool isCapable(); protected: + struct Context { + Context( + const char* url, + std::function putc); + + ~Context(); + + char* url; + ParsedUrl parsed_url; + OtaHeader header; + OTADownloadState downloadState; + uint32_t calculatedCrc32; + uint32_t headerCopiedBytes; + uint32_t downloadedSize; + uint32_t writtenBytes; + + // If an error occurred during download it is reported in this field + Error error; + + // LZSS decoder + LZSSDecoder decoder; - void otaInit(); - void crc32Init(); - void crc32Update(const uint8_t data); - void crc32Finalize(); - bool crc32Verify(); + const size_t buf_len = 64; + uint8_t buffer[64]; + } *_context; private: Client * _client; - OtaHeader _ota_header; - size_t _ota_size; - uint32_t _crc32; + HttpClient* _http_client; const char * _ca_cert; const uint8_t * _ca_cert_bundle; uint32_t _magic; + void clean(); }; #endif /* ARDUINO_ESP32_OTA_H_ */