From 2ee2358eccebb18f5e34b062d95b26f9e73eb8fe Mon Sep 17 00:00:00 2001 From: zjs81 Date: Sat, 14 Mar 2026 16:56:11 -0700 Subject: [PATCH 1/3] feat: add ML-based adaptive timeout prediction using LinearRegressor Train a linear regression model on actual message delivery times to predict tighter timeouts, replacing worst-case physics estimates. Features: path length, message bytes, seconds since last RX, flood mode. Global model with per-contact blending after 10+ observations per contact. Falls back to existing physics formula when model has insufficient data. --- lib/connector/meshcore_connector.dart | 36 ++- lib/main.dart | 8 + lib/models/delivery_observation.dart | 43 ++++ lib/services/message_retry_service.dart | 54 +++-- lib/services/storage_service.dart | 50 ++++ lib/services/timeout_prediction_service.dart | 224 ++++++++++++++++++ pubspec.yaml | 2 + test/services/ml_algo_sanity_test.dart | 122 ++++++++++ .../timeout_prediction_service_test.dart | 164 +++++++++++++ 9 files changed, 683 insertions(+), 20 deletions(-) create mode 100644 lib/models/delivery_observation.dart create mode 100644 lib/services/timeout_prediction_service.dart create mode 100644 test/services/ml_algo_sanity_test.dart create mode 100644 test/services/timeout_prediction_service_test.dart diff --git a/lib/connector/meshcore_connector.dart b/lib/connector/meshcore_connector.dart index 7cf32ef0..d05a8f9c 100644 --- a/lib/connector/meshcore_connector.dart +++ b/lib/connector/meshcore_connector.dart @@ -19,6 +19,7 @@ import '../services/message_retry_service.dart'; import '../services/path_history_service.dart'; import '../services/app_settings_service.dart'; import '../services/background_service.dart'; +import '../services/timeout_prediction_service.dart'; import '../services/notification_service.dart'; import 'meshcore_connector_usb.dart'; import 'meshcore_connector_tcp.dart'; @@ -166,6 +167,8 @@ class MeshCoreConnector extends ChangeNotifier { bool _isLoadingContacts = false; bool _isLoadingChannels = false; bool _hasLoadedChannels = false; + TimeoutPredictionService? _timeoutPredictionService; + DateTime _lastRxTime = DateTime.now(); bool _batteryRequested = false; bool _awaitingSelfInfo = false; bool _hasReceivedDeviceInfo = false; @@ -668,6 +671,7 @@ class MeshCoreConnector extends ChangeNotifier { BleDebugLogService? bleDebugLogService, AppDebugLogService? appDebugLogService, BackgroundService? backgroundService, + TimeoutPredictionService? timeoutPredictionService, }) { _retryService = retryService; _pathHistoryService = pathHistoryService; @@ -675,6 +679,7 @@ class MeshCoreConnector extends ChangeNotifier { _bleDebugLogService = bleDebugLogService; _appDebugLogService = appDebugLogService; _backgroundService = backgroundService; + _timeoutPredictionService = timeoutPredictionService; _usbManager.setDebugLogService(_appDebugLogService); _tcpConnector.setDebugLogService(_appDebugLogService); @@ -689,13 +694,23 @@ class MeshCoreConnector extends ChangeNotifier { updateMessageCallback: _updateMessage, clearContactPathCallback: clearContactPath, setContactPathCallback: setContactPath, - calculateTimeoutCallback: (pathLength, messageBytes) => - calculateTimeout(pathLength: pathLength, messageBytes: messageBytes), + calculateTimeoutCallback: (pathLength, messageBytes, {String? contactKey}) => + calculateTimeout(pathLength: pathLength, messageBytes: messageBytes, contactKey: contactKey), getSelfPublicKeyCallback: () => _selfPublicKey, prepareContactOutboundTextCallback: prepareContactOutboundText, appSettingsService: appSettingsService, debugLogService: _appDebugLogService, recordPathResultCallback: _recordPathResult, + onDeliveryObservedCallback: (contactKey, pathLength, messageBytes, tripTimeMs) { + final secSinceRx = DateTime.now().difference(_lastRxTime).inSeconds; + _timeoutPredictionService?.recordObservation( + contactKey: contactKey, + pathLength: pathLength, + messageBytes: messageBytes, + tripTimeMs: tripTimeMs, + secondsSinceLastRx: secSinceRx, + ); + }, ); } @@ -2498,6 +2513,7 @@ class MeshCoreConnector extends ChangeNotifier { void _handleFrame(List data) { if (data.isEmpty) return; + _lastRxTime = DateTime.now(); final frame = Uint8List.fromList(data); _receivedFramesController.add(frame); @@ -2876,7 +2892,21 @@ class MeshCoreConnector extends ChangeNotifier { /// Calculate timeout for a message based on radio settings and path length /// Returns timeout in milliseconds, considering number of hops - int calculateTimeout({required int pathLength, int messageBytes = 100}) { + int calculateTimeout({ + required int pathLength, + int messageBytes = 100, + String? contactKey, + }) { + // Try ML-based prediction first + final secSinceRx = DateTime.now().difference(_lastRxTime).inSeconds; + final mlTimeout = _timeoutPredictionService?.predictTimeout( + contactKey: contactKey, + pathLength: pathLength, + messageBytes: messageBytes, + secondsSinceLastRx: secSinceRx, + ); + if (mlTimeout != null) return mlTimeout; + // If we have radio settings, use them for accurate calculation if (_currentFreqHz != null && _currentBwHz != null && diff --git a/lib/main.dart b/lib/main.dart index 9e53e215..72909e2b 100644 --- a/lib/main.dart +++ b/lib/main.dart @@ -19,6 +19,7 @@ import 'services/app_debug_log_service.dart'; import 'services/background_service.dart'; import 'services/map_tile_cache_service.dart'; import 'services/chat_text_scale_service.dart'; +import 'services/timeout_prediction_service.dart'; import 'storage/prefs_manager.dart'; import 'utils/app_logger.dart'; @@ -39,6 +40,7 @@ void main() async { final backgroundService = BackgroundService(); final mapTileCacheService = MapTileCacheService(); final chatTextScaleService = ChatTextScaleService(); + final timeoutPredictionService = TimeoutPredictionService(storage); // Load settings await appSettingsService.loadSettings(); @@ -56,6 +58,7 @@ void main() async { _registerThirdPartyLicenses(); await chatTextScaleService.initialize(); + await timeoutPredictionService.initialize(); // Wire up connector with services connector.initialize( @@ -65,6 +68,7 @@ void main() async { bleDebugLogService: bleDebugLogService, appDebugLogService: appDebugLogService, backgroundService: backgroundService, + timeoutPredictionService: timeoutPredictionService, ); await connector.loadContactCache(); @@ -86,6 +90,7 @@ void main() async { appDebugLogService: appDebugLogService, mapTileCacheService: mapTileCacheService, chatTextScaleService: chatTextScaleService, + timeoutPredictionService: timeoutPredictionService, ), ); } @@ -121,6 +126,7 @@ class MeshCoreApp extends StatelessWidget { final AppDebugLogService appDebugLogService; final MapTileCacheService mapTileCacheService; final ChatTextScaleService chatTextScaleService; + final TimeoutPredictionService timeoutPredictionService; const MeshCoreApp({ super.key, @@ -133,6 +139,7 @@ class MeshCoreApp extends StatelessWidget { required this.appDebugLogService, required this.mapTileCacheService, required this.chatTextScaleService, + required this.timeoutPredictionService, }); @override @@ -148,6 +155,7 @@ class MeshCoreApp extends StatelessWidget { ChangeNotifierProvider.value(value: chatTextScaleService), Provider.value(value: storage), Provider.value(value: mapTileCacheService), + ChangeNotifierProvider.value(value: timeoutPredictionService), ], child: Consumer( builder: (context, settingsService, child) { diff --git a/lib/models/delivery_observation.dart b/lib/models/delivery_observation.dart new file mode 100644 index 00000000..a598d2a5 --- /dev/null +++ b/lib/models/delivery_observation.dart @@ -0,0 +1,43 @@ +class DeliveryObservation { + final String contactKey; + final int pathLength; + final int messageBytes; + final int secondsSinceLastRx; + final bool isFlood; + final int deliveryMs; + final DateTime timestamp; + + DeliveryObservation({ + required this.contactKey, + required this.pathLength, + required this.messageBytes, + required this.secondsSinceLastRx, + required this.isFlood, + required this.deliveryMs, + required this.timestamp, + }); + + Map toJson() { + return { + 'contact_key': contactKey, + 'path_length': pathLength, + 'message_bytes': messageBytes, + 'seconds_since_last_rx': secondsSinceLastRx, + 'is_flood': isFlood, + 'delivery_ms': deliveryMs, + 'timestamp': timestamp.toIso8601String(), + }; + } + + factory DeliveryObservation.fromJson(Map json) { + return DeliveryObservation( + contactKey: json['contact_key'] as String, + pathLength: json['path_length'] as int, + messageBytes: json['message_bytes'] as int, + secondsSinceLastRx: json['seconds_since_last_rx'] as int? ?? 0, + isFlood: json['is_flood'] as bool, + deliveryMs: json['delivery_ms'] as int, + timestamp: DateTime.parse(json['timestamp'] as String), + ); + } +} diff --git a/lib/services/message_retry_service.dart b/lib/services/message_retry_service.dart index db4475fd..d94b763b 100644 --- a/lib/services/message_retry_service.dart +++ b/lib/services/message_retry_service.dart @@ -58,12 +58,13 @@ class MessageRetryService extends ChangeNotifier { Function(Message)? _updateMessageCallback; Function(Contact)? _clearContactPathCallback; Function(Contact, Uint8List, int)? _setContactPathCallback; - Function(int, int)? _calculateTimeoutCallback; + Function(int, int, {String? contactKey})? _calculateTimeoutCallback; Uint8List? Function()? _getSelfPublicKeyCallback; String Function(Contact, String)? _prepareContactOutboundTextCallback; AppSettingsService? _appSettingsService; AppDebugLogService? _debugLogService; Function(String, PathSelection, bool, int?)? _recordPathResultCallback; + Function(String, int, int, int)? _onDeliveryObservedCallback; MessageRetryService(); @@ -73,12 +74,14 @@ class MessageRetryService extends ChangeNotifier { required Function(Message) updateMessageCallback, Function(Contact)? clearContactPathCallback, Function(Contact, Uint8List, int)? setContactPathCallback, - Function(int pathLength, int messageBytes)? calculateTimeoutCallback, + Function(int pathLength, int messageBytes, {String? contactKey})? calculateTimeoutCallback, Uint8List? Function()? getSelfPublicKeyCallback, String Function(Contact, String)? prepareContactOutboundTextCallback, AppSettingsService? appSettingsService, AppDebugLogService? debugLogService, Function(String, PathSelection, bool, int?)? recordPathResultCallback, + Function(String contactKey, int pathLength, int messageBytes, int tripTimeMs)? + onDeliveryObservedCallback, }) { _sendMessageCallback = sendMessageCallback; _addMessageCallback = addMessageCallback; @@ -91,6 +94,7 @@ class MessageRetryService extends ChangeNotifier { _appSettingsService = appSettingsService; _debugLogService = debugLogService; _recordPathResultCallback = recordPathResultCallback; + _onDeliveryObservedCallback = onDeliveryObservedCallback; } /// Compute expected ACK hash using same algorithm as firmware: @@ -423,25 +427,33 @@ class MessageRetryService extends ChangeNotifier { ); } - // Use device-provided timeout, or calculate from radio settings if timeout is 0 or invalid + // Calculate timeout: prefer ML prediction, then device-provided, then physics fallback + int pathLengthValue; + if (selection != null) { + pathLengthValue = selection.useFlood ? -1 : selection.hopCount; + if (pathLengthValue < 0) pathLengthValue = contact.pathLength; + } else if (message.pathLength != null) { + pathLengthValue = message.pathLength!; + } else { + pathLengthValue = contact.pathLength; + } + int actualTimeout = timeoutMs; - if (timeoutMs <= 0 && _calculateTimeoutCallback != null) { - int pathLengthValue; - if (selection != null) { - pathLengthValue = selection.useFlood ? -1 : selection.hopCount; - if (pathLengthValue < 0) pathLengthValue = contact.pathLength; - } else if (message.pathLength != null) { - pathLengthValue = message.pathLength!; - } else { - pathLengthValue = contact.pathLength; - } - actualTimeout = _calculateTimeoutCallback!( + if (_calculateTimeoutCallback != null) { + final calculated = _calculateTimeoutCallback!( pathLengthValue, message.text.length, + contactKey: contact.publicKeyHex, ); - debugPrint( - 'Using calculated timeout: ${actualTimeout}ms for path length $pathLengthValue', - ); + // calculateTimeout tries ML first, falls back to physics. + // Use calculated value if device didn't provide one, or if ML + // produced a tighter prediction than the device's estimate. + if (timeoutMs <= 0 || calculated < timeoutMs) { + actualTimeout = calculated; + debugPrint( + 'Using calculated timeout: ${actualTimeout}ms for path length $pathLengthValue', + ); + } } final updatedMessage = message.copyWith( @@ -738,6 +750,14 @@ class MessageRetryService extends ChangeNotifier { true, tripTimeMs, ); + if (_onDeliveryObservedCallback != null && tripTimeMs > 0) { + _onDeliveryObservedCallback!( + contact.publicKeyHex, + message.pathLength ?? 0, + message.text.length, + tripTimeMs, + ); + } _onMessageResolved(matchedMessageId, contact.publicKeyHex); } diff --git a/lib/services/storage_service.dart b/lib/services/storage_service.dart index ce0c4f19..c591f648 100644 --- a/lib/services/storage_service.dart +++ b/lib/services/storage_service.dart @@ -1,4 +1,5 @@ import 'dart:convert'; +import '../models/delivery_observation.dart'; import '../models/path_history.dart'; import '../storage/prefs_manager.dart'; @@ -6,6 +7,8 @@ class StorageService { static const String _pathHistoryPrefix = 'path_history_'; static const String _pendingMessagesKey = 'pending_messages'; static const String _repeaterPasswordsKey = 'repeater_passwords'; + static const String _deliveryObservationsKey = 'delivery_observations'; + static const String _timeoutModelKey = 'timeout_ml_model'; Future savePathHistory( String contactPubKeyHex, @@ -122,4 +125,51 @@ class StorageService { final prefs = PrefsManager.instance; await prefs.remove(_repeaterPasswordsKey); } + + Future saveDeliveryObservations( + List observations, + ) async { + final prefs = PrefsManager.instance; + final jsonStr = jsonEncode(observations.map((o) => o.toJson()).toList()); + await prefs.setString(_deliveryObservationsKey, jsonStr); + } + + Future> loadDeliveryObservations() async { + final prefs = PrefsManager.instance; + final jsonStr = prefs.getString(_deliveryObservationsKey); + + if (jsonStr == null) return []; + + try { + final list = jsonDecode(jsonStr) as List; + return list + .map( + (e) => + DeliveryObservation.fromJson(e as Map), + ) + .toList(); + } catch (e) { + return []; + } + } + + Future clearDeliveryObservations() async { + final prefs = PrefsManager.instance; + await prefs.remove(_deliveryObservationsKey); + } + + Future saveTimeoutModel(String modelJson) async { + final prefs = PrefsManager.instance; + await prefs.setString(_timeoutModelKey, modelJson); + } + + Future loadTimeoutModel() async { + final prefs = PrefsManager.instance; + return prefs.getString(_timeoutModelKey); + } + + Future clearTimeoutModel() async { + final prefs = PrefsManager.instance; + await prefs.remove(_timeoutModelKey); + } } diff --git a/lib/services/timeout_prediction_service.dart b/lib/services/timeout_prediction_service.dart new file mode 100644 index 00000000..21e229e8 --- /dev/null +++ b/lib/services/timeout_prediction_service.dart @@ -0,0 +1,224 @@ +import 'dart:convert'; +import 'dart:math'; +import 'package:flutter/foundation.dart'; +import 'package:ml_algo/ml_algo.dart'; +import 'package:ml_dataframe/ml_dataframe.dart'; +import '../models/delivery_observation.dart'; +import 'storage_service.dart'; + +class _ContactStats { + int count = 0; + double _sum = 0; + double _sumSq = 0; + + void add(double ms) { + count++; + _sum += ms; + _sumSq += ms * ms; + } + + double get mean => _sum / count; + double get stdDev => sqrt((_sumSq / count) - (mean * mean)); +} + +class TimeoutPredictionService extends ChangeNotifier { + final StorageService? _storage; + + static const int minObservations = 10; + static const int maxObservations = 100; + static const int _retrainInterval = 5; + static const double _safetyMargin = 1.5; + static const int _minTimeoutMs = 2000; + static const int _maxTimeoutMs = 120000; + static const int _minContactObservations = 10; + + List _observations = []; + LinearRegressor? _model; + List _activeFeatures = []; + int _observationsSinceLastTrain = 0; + final Map _contactStats = {}; + + TimeoutPredictionService(StorageService storage) : _storage = storage; + TimeoutPredictionService.noStorage() : _storage = null; + + int get observationCount => _observations.length; + bool get hasModel => _model != null; + + Future initialize() async { + _observations = await _storage?.loadDeliveryObservations() ?? []; + _rebuildContactStats(); + + if (_observations.length >= minObservations) { + _trainModel(); + } + + debugPrint( + 'TimeoutPrediction: initialized with ${_observations.length} observations, ' + 'model=${_model != null ? "ready" : "waiting for data"}', + ); + } + + void recordObservation({ + required String contactKey, + required int pathLength, + required int messageBytes, + required int tripTimeMs, + int secondsSinceLastRx = 0, + }) { + final observation = DeliveryObservation( + contactKey: contactKey, + pathLength: pathLength, + messageBytes: messageBytes, + secondsSinceLastRx: secondsSinceLastRx, + isFlood: pathLength < 0, + deliveryMs: tripTimeMs, + timestamp: DateTime.now(), + ); + + _observations.add(observation); + if (_observations.length > maxObservations) { + _observations.removeAt(0); + } + + _contactStats.putIfAbsent(contactKey, () => _ContactStats()); + _contactStats[contactKey]!.add(tripTimeMs.toDouble()); + + _observationsSinceLastTrain++; + if (_observationsSinceLastTrain >= _retrainInterval && + _observations.length >= minObservations) { + _trainModel(); + } + + _storage?.saveDeliveryObservations(_observations); + debugPrint( + 'TimeoutPrediction: recorded ${tripTimeMs}ms for $pathLength hops ' + '(${_observations.length} total)', + ); + } + + int? predictTimeout({ + String? contactKey, + required int pathLength, + required int messageBytes, + int secondsSinceLastRx = 0, + }) { + if (_model == null) return null; + + try { + if (_activeFeatures.isEmpty) return null; + + final allFeatures = { + 'pathLength': pathLength.toDouble(), + 'messageBytes': messageBytes.toDouble(), + 'secSinceRx': secondsSinceLastRx.toDouble(), + 'isFlood': pathLength < 0 ? 1.0 : 0.0, + }; + final row = _activeFeatures.map((f) => allFeatures[f]!).toList(); + + final features = DataFrame( + [row], + headerExists: false, + header: _activeFeatures, + ); + + final prediction = _model!.predict(features); + final rawValue = prediction.rows.first.first; + var predictedMs = (rawValue is double) ? rawValue : (rawValue as num).toDouble(); + + debugPrint( + 'TimeoutPrediction: raw prediction=$predictedMs for ' + 'pathLength=$pathLength, messageBytes=$messageBytes, ' + 'features=$_activeFeatures', + ); + + // Sanity check: if prediction is negative or zero, fall back + if (predictedMs <= 0) return null; + + // Blend with per-contact mean if enough data + if (contactKey != null) { + final stats = _contactStats[contactKey]; + if (stats != null && stats.count >= _minContactObservations) { + predictedMs = 0.5 * predictedMs + 0.5 * stats.mean; + } + } + + final timeout = + (predictedMs * _safetyMargin).ceil().clamp(_minTimeoutMs, _maxTimeoutMs); + debugPrint( + 'TimeoutPrediction: ML timeout ${timeout}ms ' + '(raw: ${predictedMs.round()}ms, contact: $contactKey)', + ); + return timeout; + } catch (e) { + debugPrint('TimeoutPrediction: prediction failed: $e'); + return null; + } + } + + void _trainModel() { + try { + // Build feature columns, then exclude any with zero variance + // (ml_algo's OLS produces all-zero coefficients for singular matrices) + final allNames = ['pathLength', 'messageBytes', 'secSinceRx', 'isFlood']; + final allExtractors = [ + (o) => o.pathLength.toDouble(), + (o) => o.messageBytes.toDouble(), + (o) => o.secondsSinceLastRx.toDouble(), + (o) => o.isFlood ? 1.0 : 0.0, + ]; + + _activeFeatures = []; + for (var i = 0; i < allNames.length; i++) { + final values = _observations.map(allExtractors[i]).toSet(); + if (values.length > 1) _activeFeatures.add(allNames[i]); + } + + if (_activeFeatures.isEmpty) { + debugPrint('TimeoutPrediction: no features with variance, skipping training'); + return; + } + + final header = [..._activeFeatures, 'deliveryMs']; + final rows = _observations.map((o) { + final row = []; + for (var i = 0; i < allNames.length; i++) { + if (_activeFeatures.contains(allNames[i])) { + row.add(allExtractors[i](o)); + } + } + row.add(o.deliveryMs.toDouble()); + return row; + }); + + final data = DataFrame( + [header, ...rows], + headerExists: true, + ); + + _model = LinearRegressor(data, 'deliveryMs'); + _observationsSinceLastTrain = 0; + + // Log training summary with sample predictions + final avgMs = _observations.map((o) => o.deliveryMs).reduce((a, b) => a + b) / + _observations.length; + debugPrint( + 'TimeoutPrediction: trained on ${_observations.length} observations ' + '(avg: ${avgMs.round()}ms, features: $_activeFeatures)', + ); + + final modelJson = jsonEncode(_model!.toJson()); + _storage?.saveTimeoutModel(modelJson); + notifyListeners(); + } catch (e) { + debugPrint('TimeoutPrediction: training failed: $e'); + } + } + + void _rebuildContactStats() { + _contactStats.clear(); + for (final obs in _observations) { + _contactStats.putIfAbsent(obs.contactKey, () => _ContactStats()); + _contactStats[obs.contactKey]!.add(obs.deliveryMs.toDouble()); + } + } +} diff --git a/pubspec.yaml b/pubspec.yaml index 82e4d9c9..4831e672 100644 --- a/pubspec.yaml +++ b/pubspec.yaml @@ -69,6 +69,8 @@ dependencies: material_symbols_icons: ^4.2906.0 web: ^1.1.1 flutter_svg: ^2.0.10+1 + ml_algo: ^16.0.0 + ml_dataframe: ^1.0.0 dev_dependencies: flutter_test: diff --git a/test/services/ml_algo_sanity_test.dart b/test/services/ml_algo_sanity_test.dart new file mode 100644 index 00000000..e4f980ed --- /dev/null +++ b/test/services/ml_algo_sanity_test.dart @@ -0,0 +1,122 @@ +import 'package:flutter/foundation.dart'; +import 'package:flutter_test/flutter_test.dart'; +import 'package:ml_algo/ml_algo.dart'; +import 'package:ml_dataframe/ml_dataframe.dart'; + +void main() { + test('LinearRegressor basic sanity check', () { + // Simple: y = 2x + 100 + final data = DataFrame([ + [1.0, 102.0], + [2.0, 104.0], + [3.0, 106.0], + [4.0, 108.0], + [5.0, 110.0], + [10.0, 120.0], + [20.0, 140.0], + [50.0, 200.0], + [0.0, 100.0], + [100.0, 300.0], + ], headerExists: false, header: ['x', 'y']); + + debugPrint('Training data columns: ${data.header}'); + debugPrint('Training data rows: ${data.rows.length}'); + + final model = LinearRegressor(data, 'y'); + + final testDf = DataFrame( + [[25.0]], + headerExists: false, + header: ['x'], + ); + + final prediction = model.predict(testDf); + final value = prediction.rows.first.first; + debugPrint('Predict x=25 → y=$value (expected ~150)'); + expect((value as num).toDouble(), closeTo(150, 5)); + }); + + test('LinearRegressor multi-feature with constant column produces zeros', () { + // isFlood=0 for all rows → zero-variance column → singular matrix + final data = DataFrame([ + [0.0, 50.0, 14.0, 0.0, 1900.0], + [0.0, 80.0, 14.0, 0.0, 2200.0], + [2.0, 50.0, 14.0, 0.0, 5000.0], + [4.0, 50.0, 14.0, 0.0, 9500.0], + ], headerExists: false, header: [ + 'pathLength', 'messageBytes', 'hourOfDay', 'isFlood', 'deliveryMs', + ]); + + final model = LinearRegressor(data, 'deliveryMs'); + final testDf = DataFrame( + [[2.0, 50.0, 14.0, 0.0]], + headerExists: false, + header: ['pathLength', 'messageBytes', 'hourOfDay', 'isFlood'], + ); + final pred = model.predict(testDf).rows.first.first; + debugPrint('With constant isFlood column: hops=2 → ${(pred as num).round()}ms (likely 0)'); + }); + + test('LinearRegressor 2-feature works correctly', () { + // Just pathLength + messageBytes → deliveryMs + final data = DataFrame([ + [0.0, 50.0, 1900.0], + [0.0, 80.0, 2200.0], + [2.0, 50.0, 5000.0], + [2.0, 80.0, 5500.0], + [4.0, 50.0, 9500.0], + [4.0, 80.0, 10000.0], + [0.0, 30.0, 1800.0], + [2.0, 30.0, 4800.0], + [4.0, 30.0, 9000.0], + [0.0, 60.0, 2000.0], + ], headerExists: false, header: ['pathLength', 'messageBytes', 'deliveryMs']); + + final model = LinearRegressor(data, 'deliveryMs'); + + for (final hops in [0.0, 2.0, 4.0]) { + final testDf = DataFrame( + [[hops, 50.0]], + headerExists: false, + header: ['pathLength', 'messageBytes'], + ); + final pred = model.predict(testDf).rows.first.first; + debugPrint('2-feature: hops=$hops → ${(pred as num).round()}ms'); + } + }); + + test('LinearRegressor multi-feature with variance in all columns', () { + // Mix flood and direct so isFlood has variance + final data = DataFrame([ + [0.0, 50.0, 14.0, 0.0, 1900.0], + [0.0, 80.0, 10.0, 0.0, 2200.0], + [2.0, 50.0, 16.0, 0.0, 5000.0], + [2.0, 80.0, 20.0, 0.0, 5500.0], + [4.0, 50.0, 8.0, 0.0, 9500.0], + [4.0, 80.0, 12.0, 0.0, 10000.0], + [-1.0, 40.0, 14.0, 1.0, 5000.0], + [-1.0, 60.0, 18.0, 1.0, 6500.0], + [-1.0, 30.0, 10.0, 1.0, 4000.0], + [-1.0, 80.0, 22.0, 1.0, 7000.0], + ], headerExists: false, header: [ + 'pathLength', 'messageBytes', 'hourOfDay', 'isFlood', 'deliveryMs', + ]); + + final model = LinearRegressor(data, 'deliveryMs'); + + for (final tc in [ + [0.0, 50.0, 14.0, 0.0], + [2.0, 50.0, 14.0, 0.0], + [4.0, 50.0, 14.0, 0.0], + [-1.0, 50.0, 14.0, 1.0], + ]) { + final testDf = DataFrame( + [tc], + headerExists: false, + header: ['pathLength', 'messageBytes', 'hourOfDay', 'isFlood'], + ); + final pred = model.predict(testDf).rows.first.first; + debugPrint('4-feature: hops=${tc[0]} flood=${tc[3]} → ${(pred as num).round()}ms'); + } + }); +} diff --git a/test/services/timeout_prediction_service_test.dart b/test/services/timeout_prediction_service_test.dart new file mode 100644 index 00000000..46dc5dfd --- /dev/null +++ b/test/services/timeout_prediction_service_test.dart @@ -0,0 +1,164 @@ +import 'package:flutter/foundation.dart'; +import 'package:flutter_test/flutter_test.dart'; +import 'package:meshcore_open/models/delivery_observation.dart'; +import 'package:meshcore_open/services/timeout_prediction_service.dart'; + +void main() { + late TimeoutPredictionService service; + + setUp(() { + service = TimeoutPredictionService.noStorage(); + }); + + test('trains on sample data and predicts sensible timeouts', () { + // Simulate realistic delivery data: + // Direct 0-hop messages: ~1500-2500ms + // 2-hop messages: ~4000-6000ms + // 4-hop messages: ~8000-12000ms + // Flood messages: ~3000-8000ms + final sampleData = [ + // 0-hop direct + _obs(pathLength: 0, messageBytes: 20, deliveryMs: 1800), + _obs(pathLength: 0, messageBytes: 50, deliveryMs: 2100), + _obs(pathLength: 0, messageBytes: 80, deliveryMs: 2400), + _obs(pathLength: 0, messageBytes: 30, deliveryMs: 1925), + // 2-hop direct + _obs(pathLength: 2, messageBytes: 40, deliveryMs: 4500), + _obs(pathLength: 2, messageBytes: 60, deliveryMs: 5200), + _obs(pathLength: 2, messageBytes: 25, deliveryMs: 4100), + // 4-hop direct + _obs(pathLength: 4, messageBytes: 50, deliveryMs: 9800), + _obs(pathLength: 4, messageBytes: 30, deliveryMs: 8500), + _obs(pathLength: 4, messageBytes: 70, deliveryMs: 10570), + // Flood + _obs(pathLength: -1, messageBytes: 40, deliveryMs: 5000), + _obs(pathLength: -1, messageBytes: 60, deliveryMs: 6500), + ]; + + // Feed all observations + for (final obs in sampleData) { + service.recordObservation( + contactKey: obs.contactKey, + pathLength: obs.pathLength, + messageBytes: obs.messageBytes, + tripTimeMs: obs.deliveryMs, + ); + } + + expect(service.hasModel, isTrue); + expect(service.observationCount, equals(12)); + + // Predict for different scenarios + final direct0 = service.predictTimeout(pathLength: 0, messageBytes: 50); + final direct2 = service.predictTimeout(pathLength: 2, messageBytes: 50); + final direct4 = service.predictTimeout(pathLength: 4, messageBytes: 50); + final flood = service.predictTimeout(pathLength: -1, messageBytes: 50); + + // All should return non-null (model is trained) + expect(direct0, isNotNull); + expect(direct2, isNotNull); + expect(direct4, isNotNull); + expect(flood, isNotNull); + + // More hops should predict longer timeouts + expect(direct4!, greaterThan(direct2!)); + expect(direct2, greaterThan(direct0!)); + + // All should be within the clamp range + expect(direct0, greaterThanOrEqualTo(2000)); + expect(direct4, lessThanOrEqualTo(120000)); + + // Print predictions for visibility + debugPrint('Predictions (with 1.5x safety margin):'); + debugPrint(' 0-hop direct: ${direct0}ms'); + debugPrint(' 2-hop direct: ${direct2}ms'); + debugPrint(' 4-hop direct: ${direct4}ms'); + debugPrint(' flood: ${flood}ms'); + }); + + test('returns null before minimum observations', () { + for (var i = 0; i < TimeoutPredictionService.minObservations - 1; i++) { + service.recordObservation( + contactKey: 'abc', + pathLength: 0, + messageBytes: 50, + tripTimeMs: 2000, + ); + } + + expect(service.hasModel, isFalse); + expect(service.predictTimeout(pathLength: 0, messageBytes: 50), isNull); + }); + + test('caps observations at maxObservations', () { + for (var i = 0; i < TimeoutPredictionService.maxObservations + 20; i++) { + service.recordObservation( + contactKey: 'abc', + pathLength: 0, + messageBytes: 50, + tripTimeMs: 2000 + i, + ); + } + + expect( + service.observationCount, + equals(TimeoutPredictionService.maxObservations), + ); + }); + + test('blends per-contact stats after enough observations', () { + // Train with mixed contacts and varied features: + // contactA is fast (0-hop), contactB is slow (2-hop) + for (var i = 0; i < 12; i++) { + service.recordObservation( + contactKey: 'contactA', + pathLength: 0, + messageBytes: 30 + i, + tripTimeMs: 1500, + ); + service.recordObservation( + contactKey: 'contactB', + pathLength: 2, + messageBytes: 30 + i, + tripTimeMs: 8000, + ); + } + + final predA = service.predictTimeout( + contactKey: 'contactA', + pathLength: 0, + messageBytes: 50, + ); + final predB = service.predictTimeout( + contactKey: 'contactB', + pathLength: 0, + messageBytes: 50, + ); + + expect(predA, isNotNull); + expect(predB, isNotNull); + // Contact B (slow) should have a higher predicted timeout than A (fast) + expect(predB!, greaterThan(predA!)); + + debugPrint('Per-contact blending:'); + debugPrint(' contactA (fast): ${predA}ms'); + debugPrint(' contactB (slow): ${predB}ms'); + }); +} + +DeliveryObservation _obs({ + required int pathLength, + required int messageBytes, + required int deliveryMs, + String contactKey = 'test_contact', +}) { + return DeliveryObservation( + contactKey: contactKey, + pathLength: pathLength, + messageBytes: messageBytes, + secondsSinceLastRx: 5, + isFlood: pathLength < 0, + deliveryMs: deliveryMs, + timestamp: DateTime.now(), + ); +} From b336aedbc58e2646149c071509ac94264744c8b6 Mon Sep 17 00:00:00 2001 From: zjs81 Date: Sat, 14 Mar 2026 17:32:08 -0700 Subject: [PATCH 2/3] fix: address PR #296 code review feedback - Clamp ML predictions between physics floor (raw airtime) and ceiling (worst-case formula) so model can never produce unsafe timeouts - Replace hourOfDay feature with secondsSinceLastRx for network activity - Remove unused _ContactStats.stdDev and dead model persistence code - Debounce observation writes (2s) instead of writing on every delivery - Skip recording observations when pathLength is null to avoid corrupting training data - Add comment explaining global (not per-contact) RX time tracking - Remove notifyListeners from retrain to avoid unnecessary widget rebuilds - Run dart format --- lib/connector/meshcore_connector.dart | 107 ++++++++------ lib/services/message_retry_service.dart | 18 ++- lib/services/storage_service.dart | 21 +-- lib/services/timeout_prediction_service.dart | 41 +++--- test/services/ml_algo_sanity_test.dart | 136 +++++++++++------- .../timeout_prediction_service_test.dart | 6 +- 6 files changed, 187 insertions(+), 142 deletions(-) diff --git a/lib/connector/meshcore_connector.dart b/lib/connector/meshcore_connector.dart index d05a8f9c..33e5c48e 100644 --- a/lib/connector/meshcore_connector.dart +++ b/lib/connector/meshcore_connector.dart @@ -168,6 +168,8 @@ class MeshCoreConnector extends ChangeNotifier { bool _isLoadingChannels = false; bool _hasLoadedChannels = false; TimeoutPredictionService? _timeoutPredictionService; + // Intentionally global (not per-contact): tracks overall network activity. + // Frequent RX from any source indicates a busy network with more collisions. DateTime _lastRxTime = DateTime.now(); bool _batteryRequested = false; bool _awaitingSelfInfo = false; @@ -694,23 +696,28 @@ class MeshCoreConnector extends ChangeNotifier { updateMessageCallback: _updateMessage, clearContactPathCallback: clearContactPath, setContactPathCallback: setContactPath, - calculateTimeoutCallback: (pathLength, messageBytes, {String? contactKey}) => - calculateTimeout(pathLength: pathLength, messageBytes: messageBytes, contactKey: contactKey), + calculateTimeoutCallback: + (pathLength, messageBytes, {String? contactKey}) => calculateTimeout( + pathLength: pathLength, + messageBytes: messageBytes, + contactKey: contactKey, + ), getSelfPublicKeyCallback: () => _selfPublicKey, prepareContactOutboundTextCallback: prepareContactOutboundText, appSettingsService: appSettingsService, debugLogService: _appDebugLogService, recordPathResultCallback: _recordPathResult, - onDeliveryObservedCallback: (contactKey, pathLength, messageBytes, tripTimeMs) { - final secSinceRx = DateTime.now().difference(_lastRxTime).inSeconds; - _timeoutPredictionService?.recordObservation( - contactKey: contactKey, - pathLength: pathLength, - messageBytes: messageBytes, - tripTimeMs: tripTimeMs, - secondsSinceLastRx: secSinceRx, - ); - }, + onDeliveryObservedCallback: + (contactKey, pathLength, messageBytes, tripTimeMs) { + final secSinceRx = DateTime.now().difference(_lastRxTime).inSeconds; + _timeoutPredictionService?.recordObservation( + contactKey: contactKey, + pathLength: pathLength, + messageBytes: messageBytes, + tripTimeMs: tripTimeMs, + secondsSinceLastRx: secSinceRx, + ); + }, ); } @@ -2890,14 +2897,54 @@ class MeshCoreConnector extends ChangeNotifier { } } - /// Calculate timeout for a message based on radio settings and path length - /// Returns timeout in milliseconds, considering number of hops + /// Estimate single-packet airtime in ms from radio settings, or a fallback. + int _estimateAirtimeMs(int messageBytes) { + if (_currentFreqHz != null && + _currentBwHz != null && + _currentSf != null && + _currentCr != null) { + final cr = _currentCr! <= 4 ? _currentCr! : _currentCr! - 4; + return calculateLoRaAirtime( + payloadBytes: messageBytes, + spreadingFactor: _currentSf!, + bandwidthHz: _currentBwHz!, + codingRate: cr, + lowDataRateOptimize: _currentSf! >= 11, + ); + } + return 50; // fallback: ~SF7/BW125 for 100 bytes + } + + /// Physics-based worst-case timeout (ceiling). + int _physicsMaxTimeout(int pathLength, int airtime) { + if (pathLength < 0) { + return 500 + (16 * airtime); + } else { + return 500 + ((airtime * 6 + 250) * (pathLength + 1)); + } + } + + /// Physics-based minimum timeout (floor): raw traversal time. + int _physicsMinTimeout(int pathLength, int airtime) { + if (pathLength < 0) { + return airtime; + } else { + return airtime * (pathLength + 1); + } + } + + /// Calculate timeout for a message based on radio settings and path length. + /// Returns timeout in milliseconds, considering number of hops. int calculateTimeout({ required int pathLength, int messageBytes = 100, String? contactKey, }) { - // Try ML-based prediction first + final airtime = _estimateAirtimeMs(messageBytes); + final physicsMin = _physicsMinTimeout(pathLength, airtime); + final physicsMax = _physicsMaxTimeout(pathLength, airtime); + + // Try ML-based prediction, clamped between physics bounds final secSinceRx = DateTime.now().difference(_lastRxTime).inSeconds; final mlTimeout = _timeoutPredictionService?.predictTimeout( contactKey: contactKey, @@ -2905,35 +2952,11 @@ class MeshCoreConnector extends ChangeNotifier { messageBytes: messageBytes, secondsSinceLastRx: secSinceRx, ); - if (mlTimeout != null) return mlTimeout; - - // If we have radio settings, use them for accurate calculation - if (_currentFreqHz != null && - _currentBwHz != null && - _currentSf != null && - _currentCr != null) { - final cr = _currentCr! <= 4 ? _currentCr! : _currentCr! - 4; - return calculateMessageTimeout( - freqHz: _currentFreqHz!, - bwHz: _currentBwHz!, - sf: _currentSf!, - cr: cr, - pathLength: pathLength, - messageBytes: messageBytes, - ); + if (mlTimeout != null) { + return mlTimeout.clamp(physicsMin, physicsMax); } - // Fallback: Conservative estimates based on typical settings - // Assume SF7, BW125, which gives ~50ms airtime for 100 bytes - const estimatedAirtime = 50; - - if (pathLength < 0) { - // Flood mode: Base delay + 16× airtime - return 500 + (16 * estimatedAirtime); - } else { - // Direct path: Base delay + ((airtime×6 + 250ms)×(hops+1)) - return 500 + ((estimatedAirtime * 6 + 250) * (pathLength + 1)); - } + return physicsMax; } void _handleContact(Uint8List frame, {bool isContact = true}) { diff --git a/lib/services/message_retry_service.dart b/lib/services/message_retry_service.dart index d94b763b..b66ba51a 100644 --- a/lib/services/message_retry_service.dart +++ b/lib/services/message_retry_service.dart @@ -74,14 +74,20 @@ class MessageRetryService extends ChangeNotifier { required Function(Message) updateMessageCallback, Function(Contact)? clearContactPathCallback, Function(Contact, Uint8List, int)? setContactPathCallback, - Function(int pathLength, int messageBytes, {String? contactKey})? calculateTimeoutCallback, + Function(int pathLength, int messageBytes, {String? contactKey})? + calculateTimeoutCallback, Uint8List? Function()? getSelfPublicKeyCallback, String Function(Contact, String)? prepareContactOutboundTextCallback, AppSettingsService? appSettingsService, AppDebugLogService? debugLogService, Function(String, PathSelection, bool, int?)? recordPathResultCallback, - Function(String contactKey, int pathLength, int messageBytes, int tripTimeMs)? - onDeliveryObservedCallback, + Function( + String contactKey, + int pathLength, + int messageBytes, + int tripTimeMs, + )? + onDeliveryObservedCallback, }) { _sendMessageCallback = sendMessageCallback; _addMessageCallback = addMessageCallback; @@ -750,10 +756,12 @@ class MessageRetryService extends ChangeNotifier { true, tripTimeMs, ); - if (_onDeliveryObservedCallback != null && tripTimeMs > 0) { + if (_onDeliveryObservedCallback != null && + tripTimeMs > 0 && + message.pathLength != null) { _onDeliveryObservedCallback!( contact.publicKeyHex, - message.pathLength ?? 0, + message.pathLength!, message.text.length, tripTimeMs, ); diff --git a/lib/services/storage_service.dart b/lib/services/storage_service.dart index c591f648..a86c1f6d 100644 --- a/lib/services/storage_service.dart +++ b/lib/services/storage_service.dart @@ -8,7 +8,6 @@ class StorageService { static const String _pendingMessagesKey = 'pending_messages'; static const String _repeaterPasswordsKey = 'repeater_passwords'; static const String _deliveryObservationsKey = 'delivery_observations'; - static const String _timeoutModelKey = 'timeout_ml_model'; Future savePathHistory( String contactPubKeyHex, @@ -143,10 +142,7 @@ class StorageService { try { final list = jsonDecode(jsonStr) as List; return list - .map( - (e) => - DeliveryObservation.fromJson(e as Map), - ) + .map((e) => DeliveryObservation.fromJson(e as Map)) .toList(); } catch (e) { return []; @@ -157,19 +153,4 @@ class StorageService { final prefs = PrefsManager.instance; await prefs.remove(_deliveryObservationsKey); } - - Future saveTimeoutModel(String modelJson) async { - final prefs = PrefsManager.instance; - await prefs.setString(_timeoutModelKey, modelJson); - } - - Future loadTimeoutModel() async { - final prefs = PrefsManager.instance; - return prefs.getString(_timeoutModelKey); - } - - Future clearTimeoutModel() async { - final prefs = PrefsManager.instance; - await prefs.remove(_timeoutModelKey); - } } diff --git a/lib/services/timeout_prediction_service.dart b/lib/services/timeout_prediction_service.dart index 21e229e8..1f3d6ddf 100644 --- a/lib/services/timeout_prediction_service.dart +++ b/lib/services/timeout_prediction_service.dart @@ -1,5 +1,4 @@ -import 'dart:convert'; -import 'dart:math'; +import 'dart:async'; import 'package:flutter/foundation.dart'; import 'package:ml_algo/ml_algo.dart'; import 'package:ml_dataframe/ml_dataframe.dart'; @@ -9,16 +8,13 @@ import 'storage_service.dart'; class _ContactStats { int count = 0; double _sum = 0; - double _sumSq = 0; void add(double ms) { count++; _sum += ms; - _sumSq += ms * ms; } double get mean => _sum / count; - double get stdDev => sqrt((_sumSq / count) - (mean * mean)); } class TimeoutPredictionService extends ChangeNotifier { @@ -27,9 +23,10 @@ class TimeoutPredictionService extends ChangeNotifier { static const int minObservations = 10; static const int maxObservations = 100; static const int _retrainInterval = 5; + // 1.5x multiplier on raw prediction to account for variance in delivery + // times — tight enough to improve on worst-case physics, loose enough + // to avoid premature timeouts from model noise. static const double _safetyMargin = 1.5; - static const int _minTimeoutMs = 2000; - static const int _maxTimeoutMs = 120000; static const int _minContactObservations = 10; List _observations = []; @@ -37,6 +34,7 @@ class TimeoutPredictionService extends ChangeNotifier { List _activeFeatures = []; int _observationsSinceLastTrain = 0; final Map _contactStats = {}; + Timer? _persistTimer; TimeoutPredictionService(StorageService storage) : _storage = storage; TimeoutPredictionService.noStorage() : _storage = null; @@ -89,7 +87,10 @@ class TimeoutPredictionService extends ChangeNotifier { _trainModel(); } - _storage?.saveDeliveryObservations(_observations); + _persistTimer?.cancel(); + _persistTimer = Timer(const Duration(seconds: 2), () { + _storage?.saveDeliveryObservations(_observations); + }); debugPrint( 'TimeoutPrediction: recorded ${tripTimeMs}ms for $pathLength hops ' '(${_observations.length} total)', @@ -123,7 +124,9 @@ class TimeoutPredictionService extends ChangeNotifier { final prediction = _model!.predict(features); final rawValue = prediction.rows.first.first; - var predictedMs = (rawValue is double) ? rawValue : (rawValue as num).toDouble(); + var predictedMs = (rawValue is double) + ? rawValue + : (rawValue as num).toDouble(); debugPrint( 'TimeoutPrediction: raw prediction=$predictedMs for ' @@ -142,8 +145,8 @@ class TimeoutPredictionService extends ChangeNotifier { } } - final timeout = - (predictedMs * _safetyMargin).ceil().clamp(_minTimeoutMs, _maxTimeoutMs); + // Connector clamps this between physics min/max bounds + final timeout = (predictedMs * _safetyMargin).ceil(); debugPrint( 'TimeoutPrediction: ML timeout ${timeout}ms ' '(raw: ${predictedMs.round()}ms, contact: $contactKey)', @@ -174,7 +177,9 @@ class TimeoutPredictionService extends ChangeNotifier { } if (_activeFeatures.isEmpty) { - debugPrint('TimeoutPrediction: no features with variance, skipping training'); + debugPrint( + 'TimeoutPrediction: no features with variance, skipping training', + ); return; } @@ -190,25 +195,19 @@ class TimeoutPredictionService extends ChangeNotifier { return row; }); - final data = DataFrame( - [header, ...rows], - headerExists: true, - ); + final data = DataFrame([header, ...rows], headerExists: true); _model = LinearRegressor(data, 'deliveryMs'); _observationsSinceLastTrain = 0; // Log training summary with sample predictions - final avgMs = _observations.map((o) => o.deliveryMs).reduce((a, b) => a + b) / + final avgMs = + _observations.map((o) => o.deliveryMs).reduce((a, b) => a + b) / _observations.length; debugPrint( 'TimeoutPrediction: trained on ${_observations.length} observations ' '(avg: ${avgMs.round()}ms, features: $_activeFeatures)', ); - - final modelJson = jsonEncode(_model!.toJson()); - _storage?.saveTimeoutModel(modelJson); - notifyListeners(); } catch (e) { debugPrint('TimeoutPrediction: training failed: $e'); } diff --git a/test/services/ml_algo_sanity_test.dart b/test/services/ml_algo_sanity_test.dart index e4f980ed..427a8a6e 100644 --- a/test/services/ml_algo_sanity_test.dart +++ b/test/services/ml_algo_sanity_test.dart @@ -6,18 +6,22 @@ import 'package:ml_dataframe/ml_dataframe.dart'; void main() { test('LinearRegressor basic sanity check', () { // Simple: y = 2x + 100 - final data = DataFrame([ - [1.0, 102.0], - [2.0, 104.0], - [3.0, 106.0], - [4.0, 108.0], - [5.0, 110.0], - [10.0, 120.0], - [20.0, 140.0], - [50.0, 200.0], - [0.0, 100.0], - [100.0, 300.0], - ], headerExists: false, header: ['x', 'y']); + final data = DataFrame( + [ + [1.0, 102.0], + [2.0, 104.0], + [3.0, 106.0], + [4.0, 108.0], + [5.0, 110.0], + [10.0, 120.0], + [20.0, 140.0], + [50.0, 200.0], + [0.0, 100.0], + [100.0, 300.0], + ], + headerExists: false, + header: ['x', 'y'], + ); debugPrint('Training data columns: ${data.header}'); debugPrint('Training data rows: ${data.rows.length}'); @@ -25,7 +29,9 @@ void main() { final model = LinearRegressor(data, 'y'); final testDf = DataFrame( - [[25.0]], + [ + [25.0], + ], headerExists: false, header: ['x'], ); @@ -38,45 +44,63 @@ void main() { test('LinearRegressor multi-feature with constant column produces zeros', () { // isFlood=0 for all rows → zero-variance column → singular matrix - final data = DataFrame([ - [0.0, 50.0, 14.0, 0.0, 1900.0], - [0.0, 80.0, 14.0, 0.0, 2200.0], - [2.0, 50.0, 14.0, 0.0, 5000.0], - [4.0, 50.0, 14.0, 0.0, 9500.0], - ], headerExists: false, header: [ - 'pathLength', 'messageBytes', 'hourOfDay', 'isFlood', 'deliveryMs', - ]); + final data = DataFrame( + [ + [0.0, 50.0, 14.0, 0.0, 1900.0], + [0.0, 80.0, 14.0, 0.0, 2200.0], + [2.0, 50.0, 14.0, 0.0, 5000.0], + [4.0, 50.0, 14.0, 0.0, 9500.0], + ], + headerExists: false, + header: [ + 'pathLength', + 'messageBytes', + 'hourOfDay', + 'isFlood', + 'deliveryMs', + ], + ); final model = LinearRegressor(data, 'deliveryMs'); final testDf = DataFrame( - [[2.0, 50.0, 14.0, 0.0]], + [ + [2.0, 50.0, 14.0, 0.0], + ], headerExists: false, header: ['pathLength', 'messageBytes', 'hourOfDay', 'isFlood'], ); final pred = model.predict(testDf).rows.first.first; - debugPrint('With constant isFlood column: hops=2 → ${(pred as num).round()}ms (likely 0)'); + debugPrint( + 'With constant isFlood column: hops=2 → ${(pred as num).round()}ms (likely 0)', + ); }); test('LinearRegressor 2-feature works correctly', () { // Just pathLength + messageBytes → deliveryMs - final data = DataFrame([ - [0.0, 50.0, 1900.0], - [0.0, 80.0, 2200.0], - [2.0, 50.0, 5000.0], - [2.0, 80.0, 5500.0], - [4.0, 50.0, 9500.0], - [4.0, 80.0, 10000.0], - [0.0, 30.0, 1800.0], - [2.0, 30.0, 4800.0], - [4.0, 30.0, 9000.0], - [0.0, 60.0, 2000.0], - ], headerExists: false, header: ['pathLength', 'messageBytes', 'deliveryMs']); + final data = DataFrame( + [ + [0.0, 50.0, 1900.0], + [0.0, 80.0, 2200.0], + [2.0, 50.0, 5000.0], + [2.0, 80.0, 5500.0], + [4.0, 50.0, 9500.0], + [4.0, 80.0, 10000.0], + [0.0, 30.0, 1800.0], + [2.0, 30.0, 4800.0], + [4.0, 30.0, 9000.0], + [0.0, 60.0, 2000.0], + ], + headerExists: false, + header: ['pathLength', 'messageBytes', 'deliveryMs'], + ); final model = LinearRegressor(data, 'deliveryMs'); for (final hops in [0.0, 2.0, 4.0]) { final testDf = DataFrame( - [[hops, 50.0]], + [ + [hops, 50.0], + ], headerExists: false, header: ['pathLength', 'messageBytes'], ); @@ -87,20 +111,28 @@ void main() { test('LinearRegressor multi-feature with variance in all columns', () { // Mix flood and direct so isFlood has variance - final data = DataFrame([ - [0.0, 50.0, 14.0, 0.0, 1900.0], - [0.0, 80.0, 10.0, 0.0, 2200.0], - [2.0, 50.0, 16.0, 0.0, 5000.0], - [2.0, 80.0, 20.0, 0.0, 5500.0], - [4.0, 50.0, 8.0, 0.0, 9500.0], - [4.0, 80.0, 12.0, 0.0, 10000.0], - [-1.0, 40.0, 14.0, 1.0, 5000.0], - [-1.0, 60.0, 18.0, 1.0, 6500.0], - [-1.0, 30.0, 10.0, 1.0, 4000.0], - [-1.0, 80.0, 22.0, 1.0, 7000.0], - ], headerExists: false, header: [ - 'pathLength', 'messageBytes', 'hourOfDay', 'isFlood', 'deliveryMs', - ]); + final data = DataFrame( + [ + [0.0, 50.0, 14.0, 0.0, 1900.0], + [0.0, 80.0, 10.0, 0.0, 2200.0], + [2.0, 50.0, 16.0, 0.0, 5000.0], + [2.0, 80.0, 20.0, 0.0, 5500.0], + [4.0, 50.0, 8.0, 0.0, 9500.0], + [4.0, 80.0, 12.0, 0.0, 10000.0], + [-1.0, 40.0, 14.0, 1.0, 5000.0], + [-1.0, 60.0, 18.0, 1.0, 6500.0], + [-1.0, 30.0, 10.0, 1.0, 4000.0], + [-1.0, 80.0, 22.0, 1.0, 7000.0], + ], + headerExists: false, + header: [ + 'pathLength', + 'messageBytes', + 'hourOfDay', + 'isFlood', + 'deliveryMs', + ], + ); final model = LinearRegressor(data, 'deliveryMs'); @@ -116,7 +148,9 @@ void main() { header: ['pathLength', 'messageBytes', 'hourOfDay', 'isFlood'], ); final pred = model.predict(testDf).rows.first.first; - debugPrint('4-feature: hops=${tc[0]} flood=${tc[3]} → ${(pred as num).round()}ms'); + debugPrint( + '4-feature: hops=${tc[0]} flood=${tc[3]} → ${(pred as num).round()}ms', + ); } }); } diff --git a/test/services/timeout_prediction_service_test.dart b/test/services/timeout_prediction_service_test.dart index 46dc5dfd..dbd852d8 100644 --- a/test/services/timeout_prediction_service_test.dart +++ b/test/services/timeout_prediction_service_test.dart @@ -64,9 +64,9 @@ void main() { expect(direct4!, greaterThan(direct2!)); expect(direct2, greaterThan(direct0!)); - // All should be within the clamp range - expect(direct0, greaterThanOrEqualTo(2000)); - expect(direct4, lessThanOrEqualTo(120000)); + // All should be positive + expect(direct0, greaterThan(0)); + expect(direct4, greaterThan(0)); // Print predictions for visibility debugPrint('Predictions (with 1.5x safety margin):'); From fffcff3b74896e9fe0dd64d02756fd94d7565062 Mon Sep 17 00:00:00 2001 From: zjs81 Date: Sat, 14 Mar 2026 17:39:01 -0700 Subject: [PATCH 3/3] fix: cancel persist timer on dispose to prevent post-dispose writes --- lib/services/timeout_prediction_service.dart | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/lib/services/timeout_prediction_service.dart b/lib/services/timeout_prediction_service.dart index 1f3d6ddf..d92ca643 100644 --- a/lib/services/timeout_prediction_service.dart +++ b/lib/services/timeout_prediction_service.dart @@ -213,6 +213,12 @@ class TimeoutPredictionService extends ChangeNotifier { } } + @override + void dispose() { + _persistTimer?.cancel(); + super.dispose(); + } + void _rebuildContactStats() { _contactStats.clear(); for (final obs in _observations) {