diff --git a/src/activities/settings/OtaUpdateActivity.cpp b/src/activities/settings/OtaUpdateActivity.cpp index 0393847..86dcf2a 100644 --- a/src/activities/settings/OtaUpdateActivity.cpp +++ b/src/activities/settings/OtaUpdateActivity.cpp @@ -97,7 +97,7 @@ void OtaUpdateActivity::onExit() { void OtaUpdateActivity::displayTaskLoop() { while (true) { - if (updateRequired) { + if (updateRequired || updater.getRender()) { updateRequired = false; xSemaphoreTake(renderingMutex, portMAX_DELAY); render(); @@ -115,8 +115,9 @@ void OtaUpdateActivity::render() { float updaterProgress = 0; if (state == UPDATE_IN_PROGRESS) { - Serial.printf("[%lu] [OTA] Update progress: %d / %d\n", millis(), updater.processedSize, updater.totalSize); - updaterProgress = static_cast(updater.processedSize) / static_cast(updater.totalSize); + Serial.printf("[%lu] [OTA] Update progress: %d / %d\n", millis(), updater.getProcessedSize(), + updater.getTotalSize()); + updaterProgress = static_cast(updater.getProcessedSize()) / static_cast(updater.getTotalSize()); // Only update every 2% at the most if (static_cast(updaterProgress * 50) == lastUpdaterPercentage / 2) { return; @@ -154,7 +155,7 @@ void OtaUpdateActivity::render() { (std::to_string(static_cast(updaterProgress * 100)) + "%").c_str()); renderer.drawCenteredText( UI_10_FONT_ID, 440, - (std::to_string(updater.processedSize) + " / " + std::to_string(updater.totalSize)).c_str()); + (std::to_string(updater.getProcessedSize()) + " / " + std::to_string(updater.getTotalSize())).c_str()); renderer.displayBuffer(); return; } @@ -194,7 +195,7 @@ void OtaUpdateActivity::loop() { xSemaphoreGive(renderingMutex); updateRequired = true; vTaskDelay(10 / portTICK_PERIOD_MS); - const auto res = updater.installUpdate([this](const size_t, const size_t) { updateRequired = true; }); + const auto res = updater.installUpdate(); if (res != OtaUpdater::OK) { Serial.printf("[%lu] [OTA] Update failed: %d\n", millis(), res); diff --git a/src/network/OtaUpdater.cpp b/src/network/OtaUpdater.cpp index d831af0..1733e13 100644 --- a/src/network/OtaUpdater.cpp +++ b/src/network/OtaUpdater.cpp @@ -1,38 +1,123 @@ #include "OtaUpdater.h" #include -#include -#include + +#include "esp_http_client.h" +#include "esp_https_ota.h" +#include "esp_wifi.h" namespace { constexpr char latestReleaseUrl[] = "https://api.github.com/repos/crosspoint-reader/crosspoint-reader/releases/latest"; + +/* This is buffer and size holder to keep upcoming data from latestReleaseUrl */ +char* local_buf; +int output_len; + +/* + * When esp_crt_bundle.h included, it is pointing wrong header file + * which is something under WifiClientSecure because of our framework based on arduno platform. + * To manage this obstacle, don't include anything, just extern and it will point correct one. + */ +extern "C" { +extern esp_err_t esp_crt_bundle_attach(void* conf); } +esp_err_t http_client_set_header_cb(esp_http_client_handle_t http_client) { + return esp_http_client_set_header(http_client, "User-Agent", "CrossPoint-ESP32-" CROSSPOINT_VERSION); +} + +esp_err_t event_handler(esp_http_client_event_t* event) { + /* We do interested in only HTTP_EVENT_ON_DATA event only */ + if (event->event_id != HTTP_EVENT_ON_DATA) return ESP_OK; + + if (!esp_http_client_is_chunked_response(event->client)) { + int content_len = esp_http_client_get_content_length(event->client); + int copy_len = 0; + + if (local_buf == NULL) { + /* local_buf life span is tracked by caller checkForUpdate */ + local_buf = static_cast(calloc(content_len + 1, sizeof(char))); + output_len = 0; + if (local_buf == NULL) { + Serial.printf("[%lu] [OTA] HTTP Client Out of Memory Failed, Allocation %d\n", millis(), content_len); + return ESP_ERR_NO_MEM; + } + } + copy_len = min(event->data_len, (content_len - output_len)); + if (copy_len) { + memcpy(local_buf + output_len, event->data, copy_len); + } + output_len += copy_len; + } else { + /* Code might be hits here, It happened once (for version checking) but I need more logs to handle that */ + int chunked_len; + esp_http_client_get_chunk_length(event->client, &chunked_len); + Serial.printf("[%lu] [OTA] esp_http_client_is_chunked_response failed, chunked_len: %d\n", millis(), chunked_len); + } + + return ESP_OK; +} /* event_handler */ +} /* namespace */ + OtaUpdater::OtaUpdaterError OtaUpdater::checkForUpdate() { - const std::unique_ptr client(new WiFiClientSecure); - client->setInsecure(); - HTTPClient http; + JsonDocument filter; + esp_err_t esp_err; + JsonDocument doc; - Serial.printf("[%lu] [OTA] Fetching: %s\n", millis(), latestReleaseUrl); + esp_http_client_config_t client_config = { + .url = latestReleaseUrl, + .event_handler = event_handler, + /* Default HTTP client buffer size 512 byte only */ + .buffer_size = 8192, + .buffer_size_tx = 8192, + .skip_cert_common_name_check = true, + .crt_bundle_attach = esp_crt_bundle_attach, + .keep_alive_enable = true, + }; - http.begin(*client, latestReleaseUrl); - http.addHeader("User-Agent", "CrossPoint-ESP32-" CROSSPOINT_VERSION); + /* To track life time of local_buf, dtor will be called on exit from that function */ + struct localBufCleaner { + char** bufPtr; + ~localBufCleaner() { + if (*bufPtr) { + free(*bufPtr); + *bufPtr = NULL; + } + } + } localBufCleaner = {&local_buf}; - const int httpCode = http.GET(); - if (httpCode != HTTP_CODE_OK) { - Serial.printf("[%lu] [OTA] HTTP error: %d\n", millis(), httpCode); - http.end(); + esp_http_client_handle_t client_handle = esp_http_client_init(&client_config); + if (!client_handle) { + Serial.printf("[%lu] [OTA] HTTP Client Handle Failed\n", millis()); + return INTERNAL_UPDATE_ERROR; + } + + esp_err = esp_http_client_set_header(client_handle, "User-Agent", "CrossPoint-ESP32-" CROSSPOINT_VERSION); + if (esp_err != ESP_OK) { + Serial.printf("[%lu] [OTA] esp_http_client_set_header Failed : %s\n", millis(), esp_err_to_name(esp_err)); + esp_http_client_cleanup(client_handle); + return INTERNAL_UPDATE_ERROR; + } + + esp_err = esp_http_client_perform(client_handle); + if (esp_err != ESP_OK) { + Serial.printf("[%lu] [OTA] esp_http_client_perform Failed : %s\n", millis(), esp_err_to_name(esp_err)); + esp_http_client_cleanup(client_handle); return HTTP_ERROR; } - JsonDocument doc; - JsonDocument filter; + /* esp_http_client_close will be called inside cleanup as well*/ + esp_err = esp_http_client_cleanup(client_handle); + if (esp_err != ESP_OK) { + Serial.printf("[%lu] [OTA] esp_http_client_cleanupp Failed : %s\n", millis(), esp_err_to_name(esp_err)); + return INTERNAL_UPDATE_ERROR; + } + filter["tag_name"] = true; filter["assets"][0]["name"] = true; filter["assets"][0]["browser_download_url"] = true; filter["assets"][0]["size"] = true; - const DeserializationError error = deserializeJson(doc, *client, DeserializationOption::Filter(filter)); - http.end(); + const DeserializationError error = deserializeJson(doc, local_buf, DeserializationOption::Filter(filter)); if (error) { Serial.printf("[%lu] [OTA] JSON parse failed: %s\n", millis(), error.c_str()); return JSON_PARSE_ERROR; @@ -42,6 +127,7 @@ OtaUpdater::OtaUpdaterError OtaUpdater::checkForUpdate() { Serial.printf("[%lu] [OTA] No tag_name found\n", millis()); return JSON_PARSE_ERROR; } + if (!doc["assets"].is()) { Serial.printf("[%lu] [OTA] No assets found\n", millis()); return JSON_PARSE_ERROR; @@ -104,67 +190,74 @@ bool OtaUpdater::isUpdateNewer() const { const std::string& OtaUpdater::getLatestVersion() const { return latestVersion; } -OtaUpdater::OtaUpdaterError OtaUpdater::installUpdate(const std::function& onProgress) { +OtaUpdater::OtaUpdaterError OtaUpdater::installUpdate() { if (!isUpdateNewer()) { return UPDATE_OLDER_ERROR; } - const std::unique_ptr client(new WiFiClientSecure); - client->setInsecure(); - HTTPClient http; + esp_https_ota_handle_t ota_handle = NULL; + esp_err_t esp_err; + /* Signal for OtaUpdateActivity */ + render = false; - Serial.printf("[%lu] [OTA] Fetching: %s\n", millis(), otaUrl.c_str()); + esp_http_client_config_t client_config = { + .url = otaUrl.c_str(), + .timeout_ms = 15000, + /* Default HTTP client buffer size 512 byte only + * not sufficent to handle URL redirection cases or + * parsing of large HTTP headers. + */ + .buffer_size = 8192, + .buffer_size_tx = 8192, + .skip_cert_common_name_check = true, + .crt_bundle_attach = esp_crt_bundle_attach, + .keep_alive_enable = true, + }; - http.begin(*client, otaUrl.c_str()); - http.setFollowRedirects(HTTPC_STRICT_FOLLOW_REDIRECTS); - http.addHeader("User-Agent", "CrossPoint-ESP32-" CROSSPOINT_VERSION); - const int httpCode = http.GET(); + esp_https_ota_config_t ota_config = { + .http_config = &client_config, + .http_client_init_cb = http_client_set_header_cb, + }; - if (httpCode != HTTP_CODE_OK) { - Serial.printf("[%lu] [OTA] Download failed: %d\n", millis(), httpCode); - http.end(); + /* For better timing and connectivity, we disable power saving for WiFi */ + esp_wifi_set_ps(WIFI_PS_NONE); + + esp_err = esp_https_ota_begin(&ota_config, &ota_handle); + if (esp_err != ESP_OK) { + Serial.printf("[%lu] [OTA] HTTP OTA Begin Failed: %s\n", millis(), esp_err_to_name(esp_err)); + return INTERNAL_UPDATE_ERROR; + } + + do { + esp_err = esp_https_ota_perform(ota_handle); + processedSize = esp_https_ota_get_image_len_read(ota_handle); + /* Sent signal to OtaUpdateActivity */ + render = true; + vTaskDelay(10 / portTICK_PERIOD_MS); + } while (esp_err == ESP_ERR_HTTPS_OTA_IN_PROGRESS); + + /* Return back to default power saving for WiFi in case of failing */ + esp_wifi_set_ps(WIFI_PS_MIN_MODEM); + + if (esp_err != ESP_OK) { + Serial.printf("[%lu] [OTA] esp_https_ota_perform Failed: %s\n", millis(), esp_err_to_name(esp_err)); + esp_https_ota_finish(ota_handle); return HTTP_ERROR; } - // 2. Get length and stream - const size_t contentLength = http.getSize(); - - if (contentLength != otaSize) { - Serial.printf("[%lu] [OTA] Invalid content length\n", millis()); - http.end(); - return HTTP_ERROR; - } - - // 3. Begin the ESP-IDF Update process - if (!Update.begin(otaSize)) { - Serial.printf("[%lu] [OTA] Not enough space. Error: %s\n", millis(), Update.errorString()); - http.end(); + if (!esp_https_ota_is_complete_data_received(ota_handle)) { + Serial.printf("[%lu] [OTA] esp_https_ota_is_complete_data_received Failed: %s\n", millis(), + esp_err_to_name(esp_err)); + esp_https_ota_finish(ota_handle); return INTERNAL_UPDATE_ERROR; } - this->totalSize = otaSize; - Serial.printf("[%lu] [OTA] Update started\n", millis()); - Update.onProgress([this, onProgress](const size_t progress, const size_t total) { - this->processedSize = progress; - this->totalSize = total; - onProgress(progress, total); - }); - const size_t written = Update.writeStream(*client); - http.end(); - - if (written == otaSize) { - Serial.printf("[%lu] [OTA] Successfully written %u bytes\n", millis(), written); - } else { - Serial.printf("[%lu] [OTA] Written only %u/%u bytes. Error: %s\n", millis(), written, otaSize, - Update.errorString()); + esp_err = esp_https_ota_finish(ota_handle); + if (esp_err != ESP_OK) { + Serial.printf("[%lu] [OTA] esp_https_ota_finish Failed: %s\n", millis(), esp_err_to_name(esp_err)); return INTERNAL_UPDATE_ERROR; } - if (Update.end() && Update.isFinished()) { - Serial.printf("[%lu] [OTA] Update complete\n", millis()); - return OK; - } else { - Serial.printf("[%lu] [OTA] Error Occurred: %s\n", millis(), Update.errorString()); - return INTERNAL_UPDATE_ERROR; - } + Serial.printf("[%lu] [OTA] Update completed\n", millis()); + return OK; } diff --git a/src/network/OtaUpdater.h b/src/network/OtaUpdater.h index 817f24b..24e04cf 100644 --- a/src/network/OtaUpdater.h +++ b/src/network/OtaUpdater.h @@ -8,6 +8,9 @@ class OtaUpdater { std::string latestVersion; std::string otaUrl; size_t otaSize = 0; + size_t processedSize = 0; + size_t totalSize = 0; + bool render = false; public: enum OtaUpdaterError { @@ -19,12 +22,18 @@ class OtaUpdater { INTERNAL_UPDATE_ERROR, OOM_ERROR, }; - size_t processedSize = 0; - size_t totalSize = 0; + + size_t getOtaSize() const { return otaSize; } + + size_t getProcessedSize() const { return processedSize; } + + size_t getTotalSize() const { return totalSize; } + + bool getRender() const { return render; } OtaUpdater() = default; bool isUpdateNewer() const; const std::string& getLatestVersion() const; OtaUpdaterError checkForUpdate(); - OtaUpdaterError installUpdate(const std::function& onProgress); + OtaUpdaterError installUpdate(); };