Compare commits

...

3 Commits

Author SHA1 Message Date
zjs81 fffcff3b74 fix: cancel persist timer on dispose to prevent post-dispose writes 2026-03-14 17:39:01 -07:00
zjs81 b336aedbc5 fix: address PR #296 code review feedback
- Clamp ML predictions between physics floor (raw airtime) and ceiling
  (worst-case formula) so model can never produce unsafe timeouts
- Replace hourOfDay feature with secondsSinceLastRx for network activity
- Remove unused _ContactStats.stdDev and dead model persistence code
- Debounce observation writes (2s) instead of writing on every delivery
- Skip recording observations when pathLength is null to avoid corrupting
  training data
- Add comment explaining global (not per-contact) RX time tracking
- Remove notifyListeners from retrain to avoid unnecessary widget rebuilds
- Run dart format
2026-03-14 17:32:08 -07:00
zjs81 2ee2358ecc feat: add ML-based adaptive timeout prediction using LinearRegressor
Train a linear regression model on actual message delivery times to
predict tighter timeouts, replacing worst-case physics estimates.
Features: path length, message bytes, seconds since last RX, flood mode.
Global model with per-contact blending after 10+ observations per contact.
Falls back to existing physics formula when model has insufficient data.
2026-03-14 16:56:11 -07:00
9 changed files with 752 additions and 38 deletions
+74 -21
View File
@@ -19,6 +19,7 @@ import '../services/message_retry_service.dart';
import '../services/path_history_service.dart'; import '../services/path_history_service.dart';
import '../services/app_settings_service.dart'; import '../services/app_settings_service.dart';
import '../services/background_service.dart'; import '../services/background_service.dart';
import '../services/timeout_prediction_service.dart';
import '../services/notification_service.dart'; import '../services/notification_service.dart';
import 'meshcore_connector_usb.dart'; import 'meshcore_connector_usb.dart';
import 'meshcore_connector_tcp.dart'; import 'meshcore_connector_tcp.dart';
@@ -166,6 +167,10 @@ class MeshCoreConnector extends ChangeNotifier {
bool _isLoadingContacts = false; bool _isLoadingContacts = false;
bool _isLoadingChannels = false; bool _isLoadingChannels = false;
bool _hasLoadedChannels = false; bool _hasLoadedChannels = false;
TimeoutPredictionService? _timeoutPredictionService;
// Intentionally global (not per-contact): tracks overall network activity.
// Frequent RX from any source indicates a busy network with more collisions.
DateTime _lastRxTime = DateTime.now();
bool _batteryRequested = false; bool _batteryRequested = false;
bool _awaitingSelfInfo = false; bool _awaitingSelfInfo = false;
bool _hasReceivedDeviceInfo = false; bool _hasReceivedDeviceInfo = false;
@@ -668,6 +673,7 @@ class MeshCoreConnector extends ChangeNotifier {
BleDebugLogService? bleDebugLogService, BleDebugLogService? bleDebugLogService,
AppDebugLogService? appDebugLogService, AppDebugLogService? appDebugLogService,
BackgroundService? backgroundService, BackgroundService? backgroundService,
TimeoutPredictionService? timeoutPredictionService,
}) { }) {
_retryService = retryService; _retryService = retryService;
_pathHistoryService = pathHistoryService; _pathHistoryService = pathHistoryService;
@@ -675,6 +681,7 @@ class MeshCoreConnector extends ChangeNotifier {
_bleDebugLogService = bleDebugLogService; _bleDebugLogService = bleDebugLogService;
_appDebugLogService = appDebugLogService; _appDebugLogService = appDebugLogService;
_backgroundService = backgroundService; _backgroundService = backgroundService;
_timeoutPredictionService = timeoutPredictionService;
_usbManager.setDebugLogService(_appDebugLogService); _usbManager.setDebugLogService(_appDebugLogService);
_tcpConnector.setDebugLogService(_appDebugLogService); _tcpConnector.setDebugLogService(_appDebugLogService);
@@ -689,13 +696,28 @@ class MeshCoreConnector extends ChangeNotifier {
updateMessageCallback: _updateMessage, updateMessageCallback: _updateMessage,
clearContactPathCallback: clearContactPath, clearContactPathCallback: clearContactPath,
setContactPathCallback: setContactPath, setContactPathCallback: setContactPath,
calculateTimeoutCallback: (pathLength, messageBytes) => calculateTimeoutCallback:
calculateTimeout(pathLength: pathLength, messageBytes: messageBytes), (pathLength, messageBytes, {String? contactKey}) => calculateTimeout(
pathLength: pathLength,
messageBytes: messageBytes,
contactKey: contactKey,
),
getSelfPublicKeyCallback: () => _selfPublicKey, getSelfPublicKeyCallback: () => _selfPublicKey,
prepareContactOutboundTextCallback: prepareContactOutboundText, prepareContactOutboundTextCallback: prepareContactOutboundText,
appSettingsService: appSettingsService, appSettingsService: appSettingsService,
debugLogService: _appDebugLogService, debugLogService: _appDebugLogService,
recordPathResultCallback: _recordPathResult, recordPathResultCallback: _recordPathResult,
onDeliveryObservedCallback:
(contactKey, pathLength, messageBytes, tripTimeMs) {
final secSinceRx = DateTime.now().difference(_lastRxTime).inSeconds;
_timeoutPredictionService?.recordObservation(
contactKey: contactKey,
pathLength: pathLength,
messageBytes: messageBytes,
tripTimeMs: tripTimeMs,
secondsSinceLastRx: secSinceRx,
);
},
); );
} }
@@ -2498,6 +2520,7 @@ class MeshCoreConnector extends ChangeNotifier {
void _handleFrame(List<int> data) { void _handleFrame(List<int> data) {
if (data.isEmpty) return; if (data.isEmpty) return;
_lastRxTime = DateTime.now();
final frame = Uint8List.fromList(data); final frame = Uint8List.fromList(data);
_receivedFramesController.add(frame); _receivedFramesController.add(frame);
@@ -2874,38 +2897,68 @@ class MeshCoreConnector extends ChangeNotifier {
} }
} }
/// Calculate timeout for a message based on radio settings and path length /// Estimate single-packet airtime in ms from radio settings, or a fallback.
/// Returns timeout in milliseconds, considering number of hops int _estimateAirtimeMs(int messageBytes) {
int calculateTimeout({required int pathLength, int messageBytes = 100}) {
// If we have radio settings, use them for accurate calculation
if (_currentFreqHz != null && if (_currentFreqHz != null &&
_currentBwHz != null && _currentBwHz != null &&
_currentSf != null && _currentSf != null &&
_currentCr != null) { _currentCr != null) {
final cr = _currentCr! <= 4 ? _currentCr! : _currentCr! - 4; final cr = _currentCr! <= 4 ? _currentCr! : _currentCr! - 4;
return calculateMessageTimeout( return calculateLoRaAirtime(
freqHz: _currentFreqHz!, payloadBytes: messageBytes,
bwHz: _currentBwHz!, spreadingFactor: _currentSf!,
sf: _currentSf!, bandwidthHz: _currentBwHz!,
cr: cr, codingRate: cr,
pathLength: pathLength, lowDataRateOptimize: _currentSf! >= 11,
messageBytes: messageBytes,
); );
} }
return 50; // fallback: ~SF7/BW125 for 100 bytes
}
// Fallback: Conservative estimates based on typical settings /// Physics-based worst-case timeout (ceiling).
// Assume SF7, BW125, which gives ~50ms airtime for 100 bytes int _physicsMaxTimeout(int pathLength, int airtime) {
const estimatedAirtime = 50;
if (pathLength < 0) { if (pathLength < 0) {
// Flood mode: Base delay + 16× airtime return 500 + (16 * airtime);
return 500 + (16 * estimatedAirtime);
} else { } else {
// Direct path: Base delay + ((airtime×6 + 250ms)×(hops+1)) return 500 + ((airtime * 6 + 250) * (pathLength + 1));
return 500 + ((estimatedAirtime * 6 + 250) * (pathLength + 1));
} }
} }
/// Physics-based minimum timeout (floor): raw traversal time.
int _physicsMinTimeout(int pathLength, int airtime) {
if (pathLength < 0) {
return airtime;
} else {
return airtime * (pathLength + 1);
}
}
/// Calculate timeout for a message based on radio settings and path length.
/// Returns timeout in milliseconds, considering number of hops.
int calculateTimeout({
required int pathLength,
int messageBytes = 100,
String? contactKey,
}) {
final airtime = _estimateAirtimeMs(messageBytes);
final physicsMin = _physicsMinTimeout(pathLength, airtime);
final physicsMax = _physicsMaxTimeout(pathLength, airtime);
// Try ML-based prediction, clamped between physics bounds
final secSinceRx = DateTime.now().difference(_lastRxTime).inSeconds;
final mlTimeout = _timeoutPredictionService?.predictTimeout(
contactKey: contactKey,
pathLength: pathLength,
messageBytes: messageBytes,
secondsSinceLastRx: secSinceRx,
);
if (mlTimeout != null) {
return mlTimeout.clamp(physicsMin, physicsMax);
}
return physicsMax;
}
void _handleContact(Uint8List frame, {bool isContact = true}) { void _handleContact(Uint8List frame, {bool isContact = true}) {
final contact = Contact.fromFrame(frame); final contact = Contact.fromFrame(frame);
if (contact != null) { if (contact != null) {
+8
View File
@@ -19,6 +19,7 @@ import 'services/app_debug_log_service.dart';
import 'services/background_service.dart'; import 'services/background_service.dart';
import 'services/map_tile_cache_service.dart'; import 'services/map_tile_cache_service.dart';
import 'services/chat_text_scale_service.dart'; import 'services/chat_text_scale_service.dart';
import 'services/timeout_prediction_service.dart';
import 'storage/prefs_manager.dart'; import 'storage/prefs_manager.dart';
import 'utils/app_logger.dart'; import 'utils/app_logger.dart';
@@ -39,6 +40,7 @@ void main() async {
final backgroundService = BackgroundService(); final backgroundService = BackgroundService();
final mapTileCacheService = MapTileCacheService(); final mapTileCacheService = MapTileCacheService();
final chatTextScaleService = ChatTextScaleService(); final chatTextScaleService = ChatTextScaleService();
final timeoutPredictionService = TimeoutPredictionService(storage);
// Load settings // Load settings
await appSettingsService.loadSettings(); await appSettingsService.loadSettings();
@@ -56,6 +58,7 @@ void main() async {
_registerThirdPartyLicenses(); _registerThirdPartyLicenses();
await chatTextScaleService.initialize(); await chatTextScaleService.initialize();
await timeoutPredictionService.initialize();
// Wire up connector with services // Wire up connector with services
connector.initialize( connector.initialize(
@@ -65,6 +68,7 @@ void main() async {
bleDebugLogService: bleDebugLogService, bleDebugLogService: bleDebugLogService,
appDebugLogService: appDebugLogService, appDebugLogService: appDebugLogService,
backgroundService: backgroundService, backgroundService: backgroundService,
timeoutPredictionService: timeoutPredictionService,
); );
await connector.loadContactCache(); await connector.loadContactCache();
@@ -86,6 +90,7 @@ void main() async {
appDebugLogService: appDebugLogService, appDebugLogService: appDebugLogService,
mapTileCacheService: mapTileCacheService, mapTileCacheService: mapTileCacheService,
chatTextScaleService: chatTextScaleService, chatTextScaleService: chatTextScaleService,
timeoutPredictionService: timeoutPredictionService,
), ),
); );
} }
@@ -121,6 +126,7 @@ class MeshCoreApp extends StatelessWidget {
final AppDebugLogService appDebugLogService; final AppDebugLogService appDebugLogService;
final MapTileCacheService mapTileCacheService; final MapTileCacheService mapTileCacheService;
final ChatTextScaleService chatTextScaleService; final ChatTextScaleService chatTextScaleService;
final TimeoutPredictionService timeoutPredictionService;
const MeshCoreApp({ const MeshCoreApp({
super.key, super.key,
@@ -133,6 +139,7 @@ class MeshCoreApp extends StatelessWidget {
required this.appDebugLogService, required this.appDebugLogService,
required this.mapTileCacheService, required this.mapTileCacheService,
required this.chatTextScaleService, required this.chatTextScaleService,
required this.timeoutPredictionService,
}); });
@override @override
@@ -148,6 +155,7 @@ class MeshCoreApp extends StatelessWidget {
ChangeNotifierProvider.value(value: chatTextScaleService), ChangeNotifierProvider.value(value: chatTextScaleService),
Provider.value(value: storage), Provider.value(value: storage),
Provider.value(value: mapTileCacheService), Provider.value(value: mapTileCacheService),
ChangeNotifierProvider.value(value: timeoutPredictionService),
], ],
child: Consumer<AppSettingsService>( child: Consumer<AppSettingsService>(
builder: (context, settingsService, child) { builder: (context, settingsService, child) {
+43
View File
@@ -0,0 +1,43 @@
class DeliveryObservation {
final String contactKey;
final int pathLength;
final int messageBytes;
final int secondsSinceLastRx;
final bool isFlood;
final int deliveryMs;
final DateTime timestamp;
DeliveryObservation({
required this.contactKey,
required this.pathLength,
required this.messageBytes,
required this.secondsSinceLastRx,
required this.isFlood,
required this.deliveryMs,
required this.timestamp,
});
Map<String, dynamic> toJson() {
return {
'contact_key': contactKey,
'path_length': pathLength,
'message_bytes': messageBytes,
'seconds_since_last_rx': secondsSinceLastRx,
'is_flood': isFlood,
'delivery_ms': deliveryMs,
'timestamp': timestamp.toIso8601String(),
};
}
factory DeliveryObservation.fromJson(Map<String, dynamic> json) {
return DeliveryObservation(
contactKey: json['contact_key'] as String,
pathLength: json['path_length'] as int,
messageBytes: json['message_bytes'] as int,
secondsSinceLastRx: json['seconds_since_last_rx'] as int? ?? 0,
isFlood: json['is_flood'] as bool,
deliveryMs: json['delivery_ms'] as int,
timestamp: DateTime.parse(json['timestamp'] as String),
);
}
}
+45 -17
View File
@@ -58,12 +58,13 @@ class MessageRetryService extends ChangeNotifier {
Function(Message)? _updateMessageCallback; Function(Message)? _updateMessageCallback;
Function(Contact)? _clearContactPathCallback; Function(Contact)? _clearContactPathCallback;
Function(Contact, Uint8List, int)? _setContactPathCallback; Function(Contact, Uint8List, int)? _setContactPathCallback;
Function(int, int)? _calculateTimeoutCallback; Function(int, int, {String? contactKey})? _calculateTimeoutCallback;
Uint8List? Function()? _getSelfPublicKeyCallback; Uint8List? Function()? _getSelfPublicKeyCallback;
String Function(Contact, String)? _prepareContactOutboundTextCallback; String Function(Contact, String)? _prepareContactOutboundTextCallback;
AppSettingsService? _appSettingsService; AppSettingsService? _appSettingsService;
AppDebugLogService? _debugLogService; AppDebugLogService? _debugLogService;
Function(String, PathSelection, bool, int?)? _recordPathResultCallback; Function(String, PathSelection, bool, int?)? _recordPathResultCallback;
Function(String, int, int, int)? _onDeliveryObservedCallback;
MessageRetryService(); MessageRetryService();
@@ -73,12 +74,20 @@ class MessageRetryService extends ChangeNotifier {
required Function(Message) updateMessageCallback, required Function(Message) updateMessageCallback,
Function(Contact)? clearContactPathCallback, Function(Contact)? clearContactPathCallback,
Function(Contact, Uint8List, int)? setContactPathCallback, Function(Contact, Uint8List, int)? setContactPathCallback,
Function(int pathLength, int messageBytes)? calculateTimeoutCallback, Function(int pathLength, int messageBytes, {String? contactKey})?
calculateTimeoutCallback,
Uint8List? Function()? getSelfPublicKeyCallback, Uint8List? Function()? getSelfPublicKeyCallback,
String Function(Contact, String)? prepareContactOutboundTextCallback, String Function(Contact, String)? prepareContactOutboundTextCallback,
AppSettingsService? appSettingsService, AppSettingsService? appSettingsService,
AppDebugLogService? debugLogService, AppDebugLogService? debugLogService,
Function(String, PathSelection, bool, int?)? recordPathResultCallback, Function(String, PathSelection, bool, int?)? recordPathResultCallback,
Function(
String contactKey,
int pathLength,
int messageBytes,
int tripTimeMs,
)?
onDeliveryObservedCallback,
}) { }) {
_sendMessageCallback = sendMessageCallback; _sendMessageCallback = sendMessageCallback;
_addMessageCallback = addMessageCallback; _addMessageCallback = addMessageCallback;
@@ -91,6 +100,7 @@ class MessageRetryService extends ChangeNotifier {
_appSettingsService = appSettingsService; _appSettingsService = appSettingsService;
_debugLogService = debugLogService; _debugLogService = debugLogService;
_recordPathResultCallback = recordPathResultCallback; _recordPathResultCallback = recordPathResultCallback;
_onDeliveryObservedCallback = onDeliveryObservedCallback;
} }
/// Compute expected ACK hash using same algorithm as firmware: /// Compute expected ACK hash using same algorithm as firmware:
@@ -423,25 +433,33 @@ class MessageRetryService extends ChangeNotifier {
); );
} }
// Use device-provided timeout, or calculate from radio settings if timeout is 0 or invalid // Calculate timeout: prefer ML prediction, then device-provided, then physics fallback
int pathLengthValue;
if (selection != null) {
pathLengthValue = selection.useFlood ? -1 : selection.hopCount;
if (pathLengthValue < 0) pathLengthValue = contact.pathLength;
} else if (message.pathLength != null) {
pathLengthValue = message.pathLength!;
} else {
pathLengthValue = contact.pathLength;
}
int actualTimeout = timeoutMs; int actualTimeout = timeoutMs;
if (timeoutMs <= 0 && _calculateTimeoutCallback != null) { if (_calculateTimeoutCallback != null) {
int pathLengthValue; final calculated = _calculateTimeoutCallback!(
if (selection != null) {
pathLengthValue = selection.useFlood ? -1 : selection.hopCount;
if (pathLengthValue < 0) pathLengthValue = contact.pathLength;
} else if (message.pathLength != null) {
pathLengthValue = message.pathLength!;
} else {
pathLengthValue = contact.pathLength;
}
actualTimeout = _calculateTimeoutCallback!(
pathLengthValue, pathLengthValue,
message.text.length, message.text.length,
contactKey: contact.publicKeyHex,
); );
debugPrint( // calculateTimeout tries ML first, falls back to physics.
'Using calculated timeout: ${actualTimeout}ms for path length $pathLengthValue', // Use calculated value if device didn't provide one, or if ML
); // produced a tighter prediction than the device's estimate.
if (timeoutMs <= 0 || calculated < timeoutMs) {
actualTimeout = calculated;
debugPrint(
'Using calculated timeout: ${actualTimeout}ms for path length $pathLengthValue',
);
}
} }
final updatedMessage = message.copyWith( final updatedMessage = message.copyWith(
@@ -738,6 +756,16 @@ class MessageRetryService extends ChangeNotifier {
true, true,
tripTimeMs, tripTimeMs,
); );
if (_onDeliveryObservedCallback != null &&
tripTimeMs > 0 &&
message.pathLength != null) {
_onDeliveryObservedCallback!(
contact.publicKeyHex,
message.pathLength!,
message.text.length,
tripTimeMs,
);
}
_onMessageResolved(matchedMessageId, contact.publicKeyHex); _onMessageResolved(matchedMessageId, contact.publicKeyHex);
} }
+31
View File
@@ -1,4 +1,5 @@
import 'dart:convert'; import 'dart:convert';
import '../models/delivery_observation.dart';
import '../models/path_history.dart'; import '../models/path_history.dart';
import '../storage/prefs_manager.dart'; import '../storage/prefs_manager.dart';
@@ -6,6 +7,7 @@ class StorageService {
static const String _pathHistoryPrefix = 'path_history_'; static const String _pathHistoryPrefix = 'path_history_';
static const String _pendingMessagesKey = 'pending_messages'; static const String _pendingMessagesKey = 'pending_messages';
static const String _repeaterPasswordsKey = 'repeater_passwords'; static const String _repeaterPasswordsKey = 'repeater_passwords';
static const String _deliveryObservationsKey = 'delivery_observations';
Future<void> savePathHistory( Future<void> savePathHistory(
String contactPubKeyHex, String contactPubKeyHex,
@@ -122,4 +124,33 @@ class StorageService {
final prefs = PrefsManager.instance; final prefs = PrefsManager.instance;
await prefs.remove(_repeaterPasswordsKey); await prefs.remove(_repeaterPasswordsKey);
} }
Future<void> saveDeliveryObservations(
List<DeliveryObservation> observations,
) async {
final prefs = PrefsManager.instance;
final jsonStr = jsonEncode(observations.map((o) => o.toJson()).toList());
await prefs.setString(_deliveryObservationsKey, jsonStr);
}
Future<List<DeliveryObservation>> loadDeliveryObservations() async {
final prefs = PrefsManager.instance;
final jsonStr = prefs.getString(_deliveryObservationsKey);
if (jsonStr == null) return [];
try {
final list = jsonDecode(jsonStr) as List;
return list
.map((e) => DeliveryObservation.fromJson(e as Map<String, dynamic>))
.toList();
} catch (e) {
return [];
}
}
Future<void> clearDeliveryObservations() async {
final prefs = PrefsManager.instance;
await prefs.remove(_deliveryObservationsKey);
}
} }
@@ -0,0 +1,229 @@
import 'dart:async';
import 'package:flutter/foundation.dart';
import 'package:ml_algo/ml_algo.dart';
import 'package:ml_dataframe/ml_dataframe.dart';
import '../models/delivery_observation.dart';
import 'storage_service.dart';
class _ContactStats {
int count = 0;
double _sum = 0;
void add(double ms) {
count++;
_sum += ms;
}
double get mean => _sum / count;
}
class TimeoutPredictionService extends ChangeNotifier {
final StorageService? _storage;
static const int minObservations = 10;
static const int maxObservations = 100;
static const int _retrainInterval = 5;
// 1.5x multiplier on raw prediction to account for variance in delivery
// times — tight enough to improve on worst-case physics, loose enough
// to avoid premature timeouts from model noise.
static const double _safetyMargin = 1.5;
static const int _minContactObservations = 10;
List<DeliveryObservation> _observations = [];
LinearRegressor? _model;
List<String> _activeFeatures = [];
int _observationsSinceLastTrain = 0;
final Map<String, _ContactStats> _contactStats = {};
Timer? _persistTimer;
TimeoutPredictionService(StorageService storage) : _storage = storage;
TimeoutPredictionService.noStorage() : _storage = null;
int get observationCount => _observations.length;
bool get hasModel => _model != null;
Future<void> initialize() async {
_observations = await _storage?.loadDeliveryObservations() ?? [];
_rebuildContactStats();
if (_observations.length >= minObservations) {
_trainModel();
}
debugPrint(
'TimeoutPrediction: initialized with ${_observations.length} observations, '
'model=${_model != null ? "ready" : "waiting for data"}',
);
}
void recordObservation({
required String contactKey,
required int pathLength,
required int messageBytes,
required int tripTimeMs,
int secondsSinceLastRx = 0,
}) {
final observation = DeliveryObservation(
contactKey: contactKey,
pathLength: pathLength,
messageBytes: messageBytes,
secondsSinceLastRx: secondsSinceLastRx,
isFlood: pathLength < 0,
deliveryMs: tripTimeMs,
timestamp: DateTime.now(),
);
_observations.add(observation);
if (_observations.length > maxObservations) {
_observations.removeAt(0);
}
_contactStats.putIfAbsent(contactKey, () => _ContactStats());
_contactStats[contactKey]!.add(tripTimeMs.toDouble());
_observationsSinceLastTrain++;
if (_observationsSinceLastTrain >= _retrainInterval &&
_observations.length >= minObservations) {
_trainModel();
}
_persistTimer?.cancel();
_persistTimer = Timer(const Duration(seconds: 2), () {
_storage?.saveDeliveryObservations(_observations);
});
debugPrint(
'TimeoutPrediction: recorded ${tripTimeMs}ms for $pathLength hops '
'(${_observations.length} total)',
);
}
int? predictTimeout({
String? contactKey,
required int pathLength,
required int messageBytes,
int secondsSinceLastRx = 0,
}) {
if (_model == null) return null;
try {
if (_activeFeatures.isEmpty) return null;
final allFeatures = {
'pathLength': pathLength.toDouble(),
'messageBytes': messageBytes.toDouble(),
'secSinceRx': secondsSinceLastRx.toDouble(),
'isFlood': pathLength < 0 ? 1.0 : 0.0,
};
final row = _activeFeatures.map((f) => allFeatures[f]!).toList();
final features = DataFrame(
[row],
headerExists: false,
header: _activeFeatures,
);
final prediction = _model!.predict(features);
final rawValue = prediction.rows.first.first;
var predictedMs = (rawValue is double)
? rawValue
: (rawValue as num).toDouble();
debugPrint(
'TimeoutPrediction: raw prediction=$predictedMs for '
'pathLength=$pathLength, messageBytes=$messageBytes, '
'features=$_activeFeatures',
);
// Sanity check: if prediction is negative or zero, fall back
if (predictedMs <= 0) return null;
// Blend with per-contact mean if enough data
if (contactKey != null) {
final stats = _contactStats[contactKey];
if (stats != null && stats.count >= _minContactObservations) {
predictedMs = 0.5 * predictedMs + 0.5 * stats.mean;
}
}
// Connector clamps this between physics min/max bounds
final timeout = (predictedMs * _safetyMargin).ceil();
debugPrint(
'TimeoutPrediction: ML timeout ${timeout}ms '
'(raw: ${predictedMs.round()}ms, contact: $contactKey)',
);
return timeout;
} catch (e) {
debugPrint('TimeoutPrediction: prediction failed: $e');
return null;
}
}
void _trainModel() {
try {
// Build feature columns, then exclude any with zero variance
// (ml_algo's OLS produces all-zero coefficients for singular matrices)
final allNames = ['pathLength', 'messageBytes', 'secSinceRx', 'isFlood'];
final allExtractors = <double Function(DeliveryObservation)>[
(o) => o.pathLength.toDouble(),
(o) => o.messageBytes.toDouble(),
(o) => o.secondsSinceLastRx.toDouble(),
(o) => o.isFlood ? 1.0 : 0.0,
];
_activeFeatures = [];
for (var i = 0; i < allNames.length; i++) {
final values = _observations.map(allExtractors[i]).toSet();
if (values.length > 1) _activeFeatures.add(allNames[i]);
}
if (_activeFeatures.isEmpty) {
debugPrint(
'TimeoutPrediction: no features with variance, skipping training',
);
return;
}
final header = [..._activeFeatures, 'deliveryMs'];
final rows = _observations.map((o) {
final row = <double>[];
for (var i = 0; i < allNames.length; i++) {
if (_activeFeatures.contains(allNames[i])) {
row.add(allExtractors[i](o));
}
}
row.add(o.deliveryMs.toDouble());
return row;
});
final data = DataFrame([header, ...rows], headerExists: true);
_model = LinearRegressor(data, 'deliveryMs');
_observationsSinceLastTrain = 0;
// Log training summary with sample predictions
final avgMs =
_observations.map((o) => o.deliveryMs).reduce((a, b) => a + b) /
_observations.length;
debugPrint(
'TimeoutPrediction: trained on ${_observations.length} observations '
'(avg: ${avgMs.round()}ms, features: $_activeFeatures)',
);
} catch (e) {
debugPrint('TimeoutPrediction: training failed: $e');
}
}
@override
void dispose() {
_persistTimer?.cancel();
super.dispose();
}
void _rebuildContactStats() {
_contactStats.clear();
for (final obs in _observations) {
_contactStats.putIfAbsent(obs.contactKey, () => _ContactStats());
_contactStats[obs.contactKey]!.add(obs.deliveryMs.toDouble());
}
}
}
+2
View File
@@ -69,6 +69,8 @@ dependencies:
material_symbols_icons: ^4.2906.0 material_symbols_icons: ^4.2906.0
web: ^1.1.1 web: ^1.1.1
flutter_svg: ^2.0.10+1 flutter_svg: ^2.0.10+1
ml_algo: ^16.0.0
ml_dataframe: ^1.0.0
dev_dependencies: dev_dependencies:
flutter_test: flutter_test:
+156
View File
@@ -0,0 +1,156 @@
import 'package:flutter/foundation.dart';
import 'package:flutter_test/flutter_test.dart';
import 'package:ml_algo/ml_algo.dart';
import 'package:ml_dataframe/ml_dataframe.dart';
void main() {
test('LinearRegressor basic sanity check', () {
// Simple: y = 2x + 100
final data = DataFrame(
[
[1.0, 102.0],
[2.0, 104.0],
[3.0, 106.0],
[4.0, 108.0],
[5.0, 110.0],
[10.0, 120.0],
[20.0, 140.0],
[50.0, 200.0],
[0.0, 100.0],
[100.0, 300.0],
],
headerExists: false,
header: ['x', 'y'],
);
debugPrint('Training data columns: ${data.header}');
debugPrint('Training data rows: ${data.rows.length}');
final model = LinearRegressor(data, 'y');
final testDf = DataFrame(
[
[25.0],
],
headerExists: false,
header: ['x'],
);
final prediction = model.predict(testDf);
final value = prediction.rows.first.first;
debugPrint('Predict x=25 → y=$value (expected ~150)');
expect((value as num).toDouble(), closeTo(150, 5));
});
test('LinearRegressor multi-feature with constant column produces zeros', () {
// isFlood=0 for all rows → zero-variance column → singular matrix
final data = DataFrame(
[
[0.0, 50.0, 14.0, 0.0, 1900.0],
[0.0, 80.0, 14.0, 0.0, 2200.0],
[2.0, 50.0, 14.0, 0.0, 5000.0],
[4.0, 50.0, 14.0, 0.0, 9500.0],
],
headerExists: false,
header: [
'pathLength',
'messageBytes',
'hourOfDay',
'isFlood',
'deliveryMs',
],
);
final model = LinearRegressor(data, 'deliveryMs');
final testDf = DataFrame(
[
[2.0, 50.0, 14.0, 0.0],
],
headerExists: false,
header: ['pathLength', 'messageBytes', 'hourOfDay', 'isFlood'],
);
final pred = model.predict(testDf).rows.first.first;
debugPrint(
'With constant isFlood column: hops=2 → ${(pred as num).round()}ms (likely 0)',
);
});
test('LinearRegressor 2-feature works correctly', () {
// Just pathLength + messageBytes → deliveryMs
final data = DataFrame(
[
[0.0, 50.0, 1900.0],
[0.0, 80.0, 2200.0],
[2.0, 50.0, 5000.0],
[2.0, 80.0, 5500.0],
[4.0, 50.0, 9500.0],
[4.0, 80.0, 10000.0],
[0.0, 30.0, 1800.0],
[2.0, 30.0, 4800.0],
[4.0, 30.0, 9000.0],
[0.0, 60.0, 2000.0],
],
headerExists: false,
header: ['pathLength', 'messageBytes', 'deliveryMs'],
);
final model = LinearRegressor(data, 'deliveryMs');
for (final hops in [0.0, 2.0, 4.0]) {
final testDf = DataFrame(
[
[hops, 50.0],
],
headerExists: false,
header: ['pathLength', 'messageBytes'],
);
final pred = model.predict(testDf).rows.first.first;
debugPrint('2-feature: hops=$hops${(pred as num).round()}ms');
}
});
test('LinearRegressor multi-feature with variance in all columns', () {
// Mix flood and direct so isFlood has variance
final data = DataFrame(
[
[0.0, 50.0, 14.0, 0.0, 1900.0],
[0.0, 80.0, 10.0, 0.0, 2200.0],
[2.0, 50.0, 16.0, 0.0, 5000.0],
[2.0, 80.0, 20.0, 0.0, 5500.0],
[4.0, 50.0, 8.0, 0.0, 9500.0],
[4.0, 80.0, 12.0, 0.0, 10000.0],
[-1.0, 40.0, 14.0, 1.0, 5000.0],
[-1.0, 60.0, 18.0, 1.0, 6500.0],
[-1.0, 30.0, 10.0, 1.0, 4000.0],
[-1.0, 80.0, 22.0, 1.0, 7000.0],
],
headerExists: false,
header: [
'pathLength',
'messageBytes',
'hourOfDay',
'isFlood',
'deliveryMs',
],
);
final model = LinearRegressor(data, 'deliveryMs');
for (final tc in [
[0.0, 50.0, 14.0, 0.0],
[2.0, 50.0, 14.0, 0.0],
[4.0, 50.0, 14.0, 0.0],
[-1.0, 50.0, 14.0, 1.0],
]) {
final testDf = DataFrame(
[tc],
headerExists: false,
header: ['pathLength', 'messageBytes', 'hourOfDay', 'isFlood'],
);
final pred = model.predict(testDf).rows.first.first;
debugPrint(
'4-feature: hops=${tc[0]} flood=${tc[3]}${(pred as num).round()}ms',
);
}
});
}
@@ -0,0 +1,164 @@
import 'package:flutter/foundation.dart';
import 'package:flutter_test/flutter_test.dart';
import 'package:meshcore_open/models/delivery_observation.dart';
import 'package:meshcore_open/services/timeout_prediction_service.dart';
void main() {
late TimeoutPredictionService service;
setUp(() {
service = TimeoutPredictionService.noStorage();
});
test('trains on sample data and predicts sensible timeouts', () {
// Simulate realistic delivery data:
// Direct 0-hop messages: ~1500-2500ms
// 2-hop messages: ~4000-6000ms
// 4-hop messages: ~8000-12000ms
// Flood messages: ~3000-8000ms
final sampleData = [
// 0-hop direct
_obs(pathLength: 0, messageBytes: 20, deliveryMs: 1800),
_obs(pathLength: 0, messageBytes: 50, deliveryMs: 2100),
_obs(pathLength: 0, messageBytes: 80, deliveryMs: 2400),
_obs(pathLength: 0, messageBytes: 30, deliveryMs: 1925),
// 2-hop direct
_obs(pathLength: 2, messageBytes: 40, deliveryMs: 4500),
_obs(pathLength: 2, messageBytes: 60, deliveryMs: 5200),
_obs(pathLength: 2, messageBytes: 25, deliveryMs: 4100),
// 4-hop direct
_obs(pathLength: 4, messageBytes: 50, deliveryMs: 9800),
_obs(pathLength: 4, messageBytes: 30, deliveryMs: 8500),
_obs(pathLength: 4, messageBytes: 70, deliveryMs: 10570),
// Flood
_obs(pathLength: -1, messageBytes: 40, deliveryMs: 5000),
_obs(pathLength: -1, messageBytes: 60, deliveryMs: 6500),
];
// Feed all observations
for (final obs in sampleData) {
service.recordObservation(
contactKey: obs.contactKey,
pathLength: obs.pathLength,
messageBytes: obs.messageBytes,
tripTimeMs: obs.deliveryMs,
);
}
expect(service.hasModel, isTrue);
expect(service.observationCount, equals(12));
// Predict for different scenarios
final direct0 = service.predictTimeout(pathLength: 0, messageBytes: 50);
final direct2 = service.predictTimeout(pathLength: 2, messageBytes: 50);
final direct4 = service.predictTimeout(pathLength: 4, messageBytes: 50);
final flood = service.predictTimeout(pathLength: -1, messageBytes: 50);
// All should return non-null (model is trained)
expect(direct0, isNotNull);
expect(direct2, isNotNull);
expect(direct4, isNotNull);
expect(flood, isNotNull);
// More hops should predict longer timeouts
expect(direct4!, greaterThan(direct2!));
expect(direct2, greaterThan(direct0!));
// All should be positive
expect(direct0, greaterThan(0));
expect(direct4, greaterThan(0));
// Print predictions for visibility
debugPrint('Predictions (with 1.5x safety margin):');
debugPrint(' 0-hop direct: ${direct0}ms');
debugPrint(' 2-hop direct: ${direct2}ms');
debugPrint(' 4-hop direct: ${direct4}ms');
debugPrint(' flood: ${flood}ms');
});
test('returns null before minimum observations', () {
for (var i = 0; i < TimeoutPredictionService.minObservations - 1; i++) {
service.recordObservation(
contactKey: 'abc',
pathLength: 0,
messageBytes: 50,
tripTimeMs: 2000,
);
}
expect(service.hasModel, isFalse);
expect(service.predictTimeout(pathLength: 0, messageBytes: 50), isNull);
});
test('caps observations at maxObservations', () {
for (var i = 0; i < TimeoutPredictionService.maxObservations + 20; i++) {
service.recordObservation(
contactKey: 'abc',
pathLength: 0,
messageBytes: 50,
tripTimeMs: 2000 + i,
);
}
expect(
service.observationCount,
equals(TimeoutPredictionService.maxObservations),
);
});
test('blends per-contact stats after enough observations', () {
// Train with mixed contacts and varied features:
// contactA is fast (0-hop), contactB is slow (2-hop)
for (var i = 0; i < 12; i++) {
service.recordObservation(
contactKey: 'contactA',
pathLength: 0,
messageBytes: 30 + i,
tripTimeMs: 1500,
);
service.recordObservation(
contactKey: 'contactB',
pathLength: 2,
messageBytes: 30 + i,
tripTimeMs: 8000,
);
}
final predA = service.predictTimeout(
contactKey: 'contactA',
pathLength: 0,
messageBytes: 50,
);
final predB = service.predictTimeout(
contactKey: 'contactB',
pathLength: 0,
messageBytes: 50,
);
expect(predA, isNotNull);
expect(predB, isNotNull);
// Contact B (slow) should have a higher predicted timeout than A (fast)
expect(predB!, greaterThan(predA!));
debugPrint('Per-contact blending:');
debugPrint(' contactA (fast): ${predA}ms');
debugPrint(' contactB (slow): ${predB}ms');
});
}
DeliveryObservation _obs({
required int pathLength,
required int messageBytes,
required int deliveryMs,
String contactKey = 'test_contact',
}) {
return DeliveryObservation(
contactKey: contactKey,
pathLength: pathLength,
messageBytes: messageBytes,
secondsSinceLastRx: 5,
isFlood: pathLength < 0,
deliveryMs: deliveryMs,
timestamp: DateTime.now(),
);
}