From b336aedbc58e2646149c071509ac94264744c8b6 Mon Sep 17 00:00:00 2001 From: zjs81 Date: Sat, 14 Mar 2026 17:32:08 -0700 Subject: [PATCH] fix: address PR #296 code review feedback - Clamp ML predictions between physics floor (raw airtime) and ceiling (worst-case formula) so model can never produce unsafe timeouts - Replace hourOfDay feature with secondsSinceLastRx for network activity - Remove unused _ContactStats.stdDev and dead model persistence code - Debounce observation writes (2s) instead of writing on every delivery - Skip recording observations when pathLength is null to avoid corrupting training data - Add comment explaining global (not per-contact) RX time tracking - Remove notifyListeners from retrain to avoid unnecessary widget rebuilds - Run dart format --- lib/connector/meshcore_connector.dart | 107 ++++++++------ lib/services/message_retry_service.dart | 18 ++- lib/services/storage_service.dart | 21 +-- lib/services/timeout_prediction_service.dart | 41 +++--- test/services/ml_algo_sanity_test.dart | 136 +++++++++++------- .../timeout_prediction_service_test.dart | 6 +- 6 files changed, 187 insertions(+), 142 deletions(-) diff --git a/lib/connector/meshcore_connector.dart b/lib/connector/meshcore_connector.dart index d05a8f9c..33e5c48e 100644 --- a/lib/connector/meshcore_connector.dart +++ b/lib/connector/meshcore_connector.dart @@ -168,6 +168,8 @@ class MeshCoreConnector extends ChangeNotifier { bool _isLoadingChannels = false; bool _hasLoadedChannels = false; TimeoutPredictionService? _timeoutPredictionService; + // Intentionally global (not per-contact): tracks overall network activity. + // Frequent RX from any source indicates a busy network with more collisions. DateTime _lastRxTime = DateTime.now(); bool _batteryRequested = false; bool _awaitingSelfInfo = false; @@ -694,23 +696,28 @@ class MeshCoreConnector extends ChangeNotifier { updateMessageCallback: _updateMessage, clearContactPathCallback: clearContactPath, setContactPathCallback: setContactPath, - calculateTimeoutCallback: (pathLength, messageBytes, {String? contactKey}) => - calculateTimeout(pathLength: pathLength, messageBytes: messageBytes, contactKey: contactKey), + calculateTimeoutCallback: + (pathLength, messageBytes, {String? contactKey}) => calculateTimeout( + pathLength: pathLength, + messageBytes: messageBytes, + contactKey: contactKey, + ), getSelfPublicKeyCallback: () => _selfPublicKey, prepareContactOutboundTextCallback: prepareContactOutboundText, appSettingsService: appSettingsService, debugLogService: _appDebugLogService, recordPathResultCallback: _recordPathResult, - onDeliveryObservedCallback: (contactKey, pathLength, messageBytes, tripTimeMs) { - final secSinceRx = DateTime.now().difference(_lastRxTime).inSeconds; - _timeoutPredictionService?.recordObservation( - contactKey: contactKey, - pathLength: pathLength, - messageBytes: messageBytes, - tripTimeMs: tripTimeMs, - secondsSinceLastRx: secSinceRx, - ); - }, + onDeliveryObservedCallback: + (contactKey, pathLength, messageBytes, tripTimeMs) { + final secSinceRx = DateTime.now().difference(_lastRxTime).inSeconds; + _timeoutPredictionService?.recordObservation( + contactKey: contactKey, + pathLength: pathLength, + messageBytes: messageBytes, + tripTimeMs: tripTimeMs, + secondsSinceLastRx: secSinceRx, + ); + }, ); } @@ -2890,14 +2897,54 @@ class MeshCoreConnector extends ChangeNotifier { } } - /// Calculate timeout for a message based on radio settings and path length - /// Returns timeout in milliseconds, considering number of hops + /// Estimate single-packet airtime in ms from radio settings, or a fallback. + int _estimateAirtimeMs(int messageBytes) { + if (_currentFreqHz != null && + _currentBwHz != null && + _currentSf != null && + _currentCr != null) { + final cr = _currentCr! <= 4 ? _currentCr! : _currentCr! - 4; + return calculateLoRaAirtime( + payloadBytes: messageBytes, + spreadingFactor: _currentSf!, + bandwidthHz: _currentBwHz!, + codingRate: cr, + lowDataRateOptimize: _currentSf! >= 11, + ); + } + return 50; // fallback: ~SF7/BW125 for 100 bytes + } + + /// Physics-based worst-case timeout (ceiling). + int _physicsMaxTimeout(int pathLength, int airtime) { + if (pathLength < 0) { + return 500 + (16 * airtime); + } else { + return 500 + ((airtime * 6 + 250) * (pathLength + 1)); + } + } + + /// Physics-based minimum timeout (floor): raw traversal time. + int _physicsMinTimeout(int pathLength, int airtime) { + if (pathLength < 0) { + return airtime; + } else { + return airtime * (pathLength + 1); + } + } + + /// Calculate timeout for a message based on radio settings and path length. + /// Returns timeout in milliseconds, considering number of hops. int calculateTimeout({ required int pathLength, int messageBytes = 100, String? contactKey, }) { - // Try ML-based prediction first + final airtime = _estimateAirtimeMs(messageBytes); + final physicsMin = _physicsMinTimeout(pathLength, airtime); + final physicsMax = _physicsMaxTimeout(pathLength, airtime); + + // Try ML-based prediction, clamped between physics bounds final secSinceRx = DateTime.now().difference(_lastRxTime).inSeconds; final mlTimeout = _timeoutPredictionService?.predictTimeout( contactKey: contactKey, @@ -2905,35 +2952,11 @@ class MeshCoreConnector extends ChangeNotifier { messageBytes: messageBytes, secondsSinceLastRx: secSinceRx, ); - if (mlTimeout != null) return mlTimeout; - - // If we have radio settings, use them for accurate calculation - if (_currentFreqHz != null && - _currentBwHz != null && - _currentSf != null && - _currentCr != null) { - final cr = _currentCr! <= 4 ? _currentCr! : _currentCr! - 4; - return calculateMessageTimeout( - freqHz: _currentFreqHz!, - bwHz: _currentBwHz!, - sf: _currentSf!, - cr: cr, - pathLength: pathLength, - messageBytes: messageBytes, - ); + if (mlTimeout != null) { + return mlTimeout.clamp(physicsMin, physicsMax); } - // Fallback: Conservative estimates based on typical settings - // Assume SF7, BW125, which gives ~50ms airtime for 100 bytes - const estimatedAirtime = 50; - - if (pathLength < 0) { - // Flood mode: Base delay + 16× airtime - return 500 + (16 * estimatedAirtime); - } else { - // Direct path: Base delay + ((airtime×6 + 250ms)×(hops+1)) - return 500 + ((estimatedAirtime * 6 + 250) * (pathLength + 1)); - } + return physicsMax; } void _handleContact(Uint8List frame, {bool isContact = true}) { diff --git a/lib/services/message_retry_service.dart b/lib/services/message_retry_service.dart index d94b763b..b66ba51a 100644 --- a/lib/services/message_retry_service.dart +++ b/lib/services/message_retry_service.dart @@ -74,14 +74,20 @@ class MessageRetryService extends ChangeNotifier { required Function(Message) updateMessageCallback, Function(Contact)? clearContactPathCallback, Function(Contact, Uint8List, int)? setContactPathCallback, - Function(int pathLength, int messageBytes, {String? contactKey})? calculateTimeoutCallback, + Function(int pathLength, int messageBytes, {String? contactKey})? + calculateTimeoutCallback, Uint8List? Function()? getSelfPublicKeyCallback, String Function(Contact, String)? prepareContactOutboundTextCallback, AppSettingsService? appSettingsService, AppDebugLogService? debugLogService, Function(String, PathSelection, bool, int?)? recordPathResultCallback, - Function(String contactKey, int pathLength, int messageBytes, int tripTimeMs)? - onDeliveryObservedCallback, + Function( + String contactKey, + int pathLength, + int messageBytes, + int tripTimeMs, + )? + onDeliveryObservedCallback, }) { _sendMessageCallback = sendMessageCallback; _addMessageCallback = addMessageCallback; @@ -750,10 +756,12 @@ class MessageRetryService extends ChangeNotifier { true, tripTimeMs, ); - if (_onDeliveryObservedCallback != null && tripTimeMs > 0) { + if (_onDeliveryObservedCallback != null && + tripTimeMs > 0 && + message.pathLength != null) { _onDeliveryObservedCallback!( contact.publicKeyHex, - message.pathLength ?? 0, + message.pathLength!, message.text.length, tripTimeMs, ); diff --git a/lib/services/storage_service.dart b/lib/services/storage_service.dart index c591f648..a86c1f6d 100644 --- a/lib/services/storage_service.dart +++ b/lib/services/storage_service.dart @@ -8,7 +8,6 @@ class StorageService { static const String _pendingMessagesKey = 'pending_messages'; static const String _repeaterPasswordsKey = 'repeater_passwords'; static const String _deliveryObservationsKey = 'delivery_observations'; - static const String _timeoutModelKey = 'timeout_ml_model'; Future savePathHistory( String contactPubKeyHex, @@ -143,10 +142,7 @@ class StorageService { try { final list = jsonDecode(jsonStr) as List; return list - .map( - (e) => - DeliveryObservation.fromJson(e as Map), - ) + .map((e) => DeliveryObservation.fromJson(e as Map)) .toList(); } catch (e) { return []; @@ -157,19 +153,4 @@ class StorageService { final prefs = PrefsManager.instance; await prefs.remove(_deliveryObservationsKey); } - - Future saveTimeoutModel(String modelJson) async { - final prefs = PrefsManager.instance; - await prefs.setString(_timeoutModelKey, modelJson); - } - - Future loadTimeoutModel() async { - final prefs = PrefsManager.instance; - return prefs.getString(_timeoutModelKey); - } - - Future clearTimeoutModel() async { - final prefs = PrefsManager.instance; - await prefs.remove(_timeoutModelKey); - } } diff --git a/lib/services/timeout_prediction_service.dart b/lib/services/timeout_prediction_service.dart index 21e229e8..1f3d6ddf 100644 --- a/lib/services/timeout_prediction_service.dart +++ b/lib/services/timeout_prediction_service.dart @@ -1,5 +1,4 @@ -import 'dart:convert'; -import 'dart:math'; +import 'dart:async'; import 'package:flutter/foundation.dart'; import 'package:ml_algo/ml_algo.dart'; import 'package:ml_dataframe/ml_dataframe.dart'; @@ -9,16 +8,13 @@ import 'storage_service.dart'; class _ContactStats { int count = 0; double _sum = 0; - double _sumSq = 0; void add(double ms) { count++; _sum += ms; - _sumSq += ms * ms; } double get mean => _sum / count; - double get stdDev => sqrt((_sumSq / count) - (mean * mean)); } class TimeoutPredictionService extends ChangeNotifier { @@ -27,9 +23,10 @@ class TimeoutPredictionService extends ChangeNotifier { static const int minObservations = 10; static const int maxObservations = 100; static const int _retrainInterval = 5; + // 1.5x multiplier on raw prediction to account for variance in delivery + // times — tight enough to improve on worst-case physics, loose enough + // to avoid premature timeouts from model noise. static const double _safetyMargin = 1.5; - static const int _minTimeoutMs = 2000; - static const int _maxTimeoutMs = 120000; static const int _minContactObservations = 10; List _observations = []; @@ -37,6 +34,7 @@ class TimeoutPredictionService extends ChangeNotifier { List _activeFeatures = []; int _observationsSinceLastTrain = 0; final Map _contactStats = {}; + Timer? _persistTimer; TimeoutPredictionService(StorageService storage) : _storage = storage; TimeoutPredictionService.noStorage() : _storage = null; @@ -89,7 +87,10 @@ class TimeoutPredictionService extends ChangeNotifier { _trainModel(); } - _storage?.saveDeliveryObservations(_observations); + _persistTimer?.cancel(); + _persistTimer = Timer(const Duration(seconds: 2), () { + _storage?.saveDeliveryObservations(_observations); + }); debugPrint( 'TimeoutPrediction: recorded ${tripTimeMs}ms for $pathLength hops ' '(${_observations.length} total)', @@ -123,7 +124,9 @@ class TimeoutPredictionService extends ChangeNotifier { final prediction = _model!.predict(features); final rawValue = prediction.rows.first.first; - var predictedMs = (rawValue is double) ? rawValue : (rawValue as num).toDouble(); + var predictedMs = (rawValue is double) + ? rawValue + : (rawValue as num).toDouble(); debugPrint( 'TimeoutPrediction: raw prediction=$predictedMs for ' @@ -142,8 +145,8 @@ class TimeoutPredictionService extends ChangeNotifier { } } - final timeout = - (predictedMs * _safetyMargin).ceil().clamp(_minTimeoutMs, _maxTimeoutMs); + // Connector clamps this between physics min/max bounds + final timeout = (predictedMs * _safetyMargin).ceil(); debugPrint( 'TimeoutPrediction: ML timeout ${timeout}ms ' '(raw: ${predictedMs.round()}ms, contact: $contactKey)', @@ -174,7 +177,9 @@ class TimeoutPredictionService extends ChangeNotifier { } if (_activeFeatures.isEmpty) { - debugPrint('TimeoutPrediction: no features with variance, skipping training'); + debugPrint( + 'TimeoutPrediction: no features with variance, skipping training', + ); return; } @@ -190,25 +195,19 @@ class TimeoutPredictionService extends ChangeNotifier { return row; }); - final data = DataFrame( - [header, ...rows], - headerExists: true, - ); + final data = DataFrame([header, ...rows], headerExists: true); _model = LinearRegressor(data, 'deliveryMs'); _observationsSinceLastTrain = 0; // Log training summary with sample predictions - final avgMs = _observations.map((o) => o.deliveryMs).reduce((a, b) => a + b) / + final avgMs = + _observations.map((o) => o.deliveryMs).reduce((a, b) => a + b) / _observations.length; debugPrint( 'TimeoutPrediction: trained on ${_observations.length} observations ' '(avg: ${avgMs.round()}ms, features: $_activeFeatures)', ); - - final modelJson = jsonEncode(_model!.toJson()); - _storage?.saveTimeoutModel(modelJson); - notifyListeners(); } catch (e) { debugPrint('TimeoutPrediction: training failed: $e'); } diff --git a/test/services/ml_algo_sanity_test.dart b/test/services/ml_algo_sanity_test.dart index e4f980ed..427a8a6e 100644 --- a/test/services/ml_algo_sanity_test.dart +++ b/test/services/ml_algo_sanity_test.dart @@ -6,18 +6,22 @@ import 'package:ml_dataframe/ml_dataframe.dart'; void main() { test('LinearRegressor basic sanity check', () { // Simple: y = 2x + 100 - final data = DataFrame([ - [1.0, 102.0], - [2.0, 104.0], - [3.0, 106.0], - [4.0, 108.0], - [5.0, 110.0], - [10.0, 120.0], - [20.0, 140.0], - [50.0, 200.0], - [0.0, 100.0], - [100.0, 300.0], - ], headerExists: false, header: ['x', 'y']); + final data = DataFrame( + [ + [1.0, 102.0], + [2.0, 104.0], + [3.0, 106.0], + [4.0, 108.0], + [5.0, 110.0], + [10.0, 120.0], + [20.0, 140.0], + [50.0, 200.0], + [0.0, 100.0], + [100.0, 300.0], + ], + headerExists: false, + header: ['x', 'y'], + ); debugPrint('Training data columns: ${data.header}'); debugPrint('Training data rows: ${data.rows.length}'); @@ -25,7 +29,9 @@ void main() { final model = LinearRegressor(data, 'y'); final testDf = DataFrame( - [[25.0]], + [ + [25.0], + ], headerExists: false, header: ['x'], ); @@ -38,45 +44,63 @@ void main() { test('LinearRegressor multi-feature with constant column produces zeros', () { // isFlood=0 for all rows → zero-variance column → singular matrix - final data = DataFrame([ - [0.0, 50.0, 14.0, 0.0, 1900.0], - [0.0, 80.0, 14.0, 0.0, 2200.0], - [2.0, 50.0, 14.0, 0.0, 5000.0], - [4.0, 50.0, 14.0, 0.0, 9500.0], - ], headerExists: false, header: [ - 'pathLength', 'messageBytes', 'hourOfDay', 'isFlood', 'deliveryMs', - ]); + final data = DataFrame( + [ + [0.0, 50.0, 14.0, 0.0, 1900.0], + [0.0, 80.0, 14.0, 0.0, 2200.0], + [2.0, 50.0, 14.0, 0.0, 5000.0], + [4.0, 50.0, 14.0, 0.0, 9500.0], + ], + headerExists: false, + header: [ + 'pathLength', + 'messageBytes', + 'hourOfDay', + 'isFlood', + 'deliveryMs', + ], + ); final model = LinearRegressor(data, 'deliveryMs'); final testDf = DataFrame( - [[2.0, 50.0, 14.0, 0.0]], + [ + [2.0, 50.0, 14.0, 0.0], + ], headerExists: false, header: ['pathLength', 'messageBytes', 'hourOfDay', 'isFlood'], ); final pred = model.predict(testDf).rows.first.first; - debugPrint('With constant isFlood column: hops=2 → ${(pred as num).round()}ms (likely 0)'); + debugPrint( + 'With constant isFlood column: hops=2 → ${(pred as num).round()}ms (likely 0)', + ); }); test('LinearRegressor 2-feature works correctly', () { // Just pathLength + messageBytes → deliveryMs - final data = DataFrame([ - [0.0, 50.0, 1900.0], - [0.0, 80.0, 2200.0], - [2.0, 50.0, 5000.0], - [2.0, 80.0, 5500.0], - [4.0, 50.0, 9500.0], - [4.0, 80.0, 10000.0], - [0.0, 30.0, 1800.0], - [2.0, 30.0, 4800.0], - [4.0, 30.0, 9000.0], - [0.0, 60.0, 2000.0], - ], headerExists: false, header: ['pathLength', 'messageBytes', 'deliveryMs']); + final data = DataFrame( + [ + [0.0, 50.0, 1900.0], + [0.0, 80.0, 2200.0], + [2.0, 50.0, 5000.0], + [2.0, 80.0, 5500.0], + [4.0, 50.0, 9500.0], + [4.0, 80.0, 10000.0], + [0.0, 30.0, 1800.0], + [2.0, 30.0, 4800.0], + [4.0, 30.0, 9000.0], + [0.0, 60.0, 2000.0], + ], + headerExists: false, + header: ['pathLength', 'messageBytes', 'deliveryMs'], + ); final model = LinearRegressor(data, 'deliveryMs'); for (final hops in [0.0, 2.0, 4.0]) { final testDf = DataFrame( - [[hops, 50.0]], + [ + [hops, 50.0], + ], headerExists: false, header: ['pathLength', 'messageBytes'], ); @@ -87,20 +111,28 @@ void main() { test('LinearRegressor multi-feature with variance in all columns', () { // Mix flood and direct so isFlood has variance - final data = DataFrame([ - [0.0, 50.0, 14.0, 0.0, 1900.0], - [0.0, 80.0, 10.0, 0.0, 2200.0], - [2.0, 50.0, 16.0, 0.0, 5000.0], - [2.0, 80.0, 20.0, 0.0, 5500.0], - [4.0, 50.0, 8.0, 0.0, 9500.0], - [4.0, 80.0, 12.0, 0.0, 10000.0], - [-1.0, 40.0, 14.0, 1.0, 5000.0], - [-1.0, 60.0, 18.0, 1.0, 6500.0], - [-1.0, 30.0, 10.0, 1.0, 4000.0], - [-1.0, 80.0, 22.0, 1.0, 7000.0], - ], headerExists: false, header: [ - 'pathLength', 'messageBytes', 'hourOfDay', 'isFlood', 'deliveryMs', - ]); + final data = DataFrame( + [ + [0.0, 50.0, 14.0, 0.0, 1900.0], + [0.0, 80.0, 10.0, 0.0, 2200.0], + [2.0, 50.0, 16.0, 0.0, 5000.0], + [2.0, 80.0, 20.0, 0.0, 5500.0], + [4.0, 50.0, 8.0, 0.0, 9500.0], + [4.0, 80.0, 12.0, 0.0, 10000.0], + [-1.0, 40.0, 14.0, 1.0, 5000.0], + [-1.0, 60.0, 18.0, 1.0, 6500.0], + [-1.0, 30.0, 10.0, 1.0, 4000.0], + [-1.0, 80.0, 22.0, 1.0, 7000.0], + ], + headerExists: false, + header: [ + 'pathLength', + 'messageBytes', + 'hourOfDay', + 'isFlood', + 'deliveryMs', + ], + ); final model = LinearRegressor(data, 'deliveryMs'); @@ -116,7 +148,9 @@ void main() { header: ['pathLength', 'messageBytes', 'hourOfDay', 'isFlood'], ); final pred = model.predict(testDf).rows.first.first; - debugPrint('4-feature: hops=${tc[0]} flood=${tc[3]} → ${(pred as num).round()}ms'); + debugPrint( + '4-feature: hops=${tc[0]} flood=${tc[3]} → ${(pred as num).round()}ms', + ); } }); } diff --git a/test/services/timeout_prediction_service_test.dart b/test/services/timeout_prediction_service_test.dart index 46dc5dfd..dbd852d8 100644 --- a/test/services/timeout_prediction_service_test.dart +++ b/test/services/timeout_prediction_service_test.dart @@ -64,9 +64,9 @@ void main() { expect(direct4!, greaterThan(direct2!)); expect(direct2, greaterThan(direct0!)); - // All should be within the clamp range - expect(direct0, greaterThanOrEqualTo(2000)); - expect(direct4, lessThanOrEqualTo(120000)); + // All should be positive + expect(direct0, greaterThan(0)); + expect(direct4, greaterThan(0)); // Print predictions for visibility debugPrint('Predictions (with 1.5x safety margin):');