diff --git a/src/Arduino_ESP32_OTA.cpp b/src/Arduino_ESP32_OTA.cpp index b6335b4..13a7d92 100644 --- a/src/Arduino_ESP32_OTA.cpp +++ b/src/Arduino_ESP32_OTA.cpp @@ -76,7 +76,16 @@ Arduino_ESP32_OTA::Error Arduino_ESP32_OTA::begin() DEBUG_ERROR("%s: board is not capable to perform OTA", __FUNCTION__); return Error::NoOtaStorage; } - + + /* initialize private variables */ + _ota_size = 0; + _ota_header = {0}; + + if(Update.isRunning()) { + Update.abort(); + DEBUG_DEBUG("%s: Aborting running update", __FUNCTION__); + } + if(!Update.begin(UPDATE_SIZE_UNKNOWN)) { DEBUG_ERROR("%s: failed to initialize flash update", __FUNCTION__); return Error::OtaStorageInit; @@ -147,6 +156,8 @@ int Arduino_ESP32_OTA::download(const char * ota_url) if (!_client->connect(url.host_.c_str(), port)) { DEBUG_ERROR("%s: Connection failure with OTA storage server %s", __FUNCTION__, url.host_.c_str()); + delete _client; + _client = nullptr; return static_cast(Error::ServerConnectError); } @@ -177,6 +188,8 @@ int Arduino_ESP32_OTA::download(const char * ota_url) if (!is_header_complete) { DEBUG_ERROR("%s: Error receiving HTTP header %s", __FUNCTION__, is_http_header_timeout ? "(timeout)":""); + delete _client; + _client = nullptr; return static_cast(Error::HttpHeaderError); } @@ -207,6 +220,8 @@ int Arduino_ESP32_OTA::download(const char * ota_url) if (!content_length_ptr) { DEBUG_ERROR("%s: Failure to extract content length from http header", __FUNCTION__); + delete _client; + _client = nullptr; return static_cast(Error::ParseHttpHeader); } /* Find start of numerical value. */ @@ -234,17 +249,23 @@ int Arduino_ESP32_OTA::download(const char * ota_url) /* ... check for header download timeout ... */ if (is_ota_header_timeout) { + delete _client; + _client = nullptr; return static_cast(Error::OtaHeaderTimeout); } /* ... then check if OTA header length field matches HTTP content length... */ if (_ota_header.header.len != (content_length_val - sizeof(_ota_header.header.len) - sizeof(_ota_header.header.crc32))) { + delete _client; + _client = nullptr; return static_cast(Error::OtaHeaderLength); } /* ... and OTA magic number */ if (_ota_header.header.magic_number != ARDUINO_ESP32_OTA_MAGIC) { + delete _client; + _client = nullptr; return static_cast(Error::OtaHeaterMagicNumber); } @@ -256,9 +277,13 @@ int Arduino_ESP32_OTA::download(const char * ota_url) if(_ota_size <= content_length_val - sizeof(_ota_header)) { + delete _client; + _client = nullptr; return static_cast(Error::OtaDownload); } + delete _client; + _client = nullptr; return _ota_size; } diff --git a/src/decompress/lzss.cpp b/src/decompress/lzss.cpp index 72427ba..bce0e1c 100644 --- a/src/decompress/lzss.cpp +++ b/src/decompress/lzss.cpp @@ -30,6 +30,7 @@ int bit_buffer = 0, bit_mask = 128; unsigned char buffer[N * 2]; static size_t bytes_written_fputc = 0; +static size_t bytes_read_fgetc = 0; /************************************************************************************** PRIVATE FUNCTIONS @@ -45,8 +46,6 @@ void lzss_fputc(int const c) int lzss_fgetc() { - static size_t bytes_read_fgetc = 0; - /* lzss_file_size is set within SSUBoot:main * and contains the size of the LZSS file. Once * all those bytes have been read its time to return @@ -163,6 +162,8 @@ int lzss_download(ArduinoEsp32OtaReadByteFuncPointer read_byte, ArduinoEsp32OtaW read_byte_fptr = read_byte; write_byte_fptr = write_byte; LZSS_FILE_SIZE = lzss_file_size; + bytes_written_fputc = 0; + bytes_read_fgetc = 0; lzss_decode(); return bytes_written_fputc; }