fix: address PR #296 code review feedback

- Clamp ML predictions between physics floor (raw airtime) and ceiling
  (worst-case formula) so model can never produce unsafe timeouts
- Replace hourOfDay feature with secondsSinceLastRx for network activity
- Remove unused _ContactStats.stdDev and dead model persistence code
- Debounce observation writes (2s) instead of writing on every delivery
- Skip recording observations when pathLength is null to avoid corrupting
  training data
- Add comment explaining global (not per-contact) RX time tracking
- Remove notifyListeners from retrain to avoid unnecessary widget rebuilds
- Run dart format
This commit is contained in:
zjs81
2026-03-14 17:32:08 -07:00
parent 2ee2358ecc
commit b336aedbc5
6 changed files with 187 additions and 142 deletions
+65 -42
View File
@@ -168,6 +168,8 @@ class MeshCoreConnector extends ChangeNotifier {
bool _isLoadingChannels = false;
bool _hasLoadedChannels = false;
TimeoutPredictionService? _timeoutPredictionService;
// Intentionally global (not per-contact): tracks overall network activity.
// Frequent RX from any source indicates a busy network with more collisions.
DateTime _lastRxTime = DateTime.now();
bool _batteryRequested = false;
bool _awaitingSelfInfo = false;
@@ -694,23 +696,28 @@ class MeshCoreConnector extends ChangeNotifier {
updateMessageCallback: _updateMessage,
clearContactPathCallback: clearContactPath,
setContactPathCallback: setContactPath,
calculateTimeoutCallback: (pathLength, messageBytes, {String? contactKey}) =>
calculateTimeout(pathLength: pathLength, messageBytes: messageBytes, contactKey: contactKey),
calculateTimeoutCallback:
(pathLength, messageBytes, {String? contactKey}) => calculateTimeout(
pathLength: pathLength,
messageBytes: messageBytes,
contactKey: contactKey,
),
getSelfPublicKeyCallback: () => _selfPublicKey,
prepareContactOutboundTextCallback: prepareContactOutboundText,
appSettingsService: appSettingsService,
debugLogService: _appDebugLogService,
recordPathResultCallback: _recordPathResult,
onDeliveryObservedCallback: (contactKey, pathLength, messageBytes, tripTimeMs) {
final secSinceRx = DateTime.now().difference(_lastRxTime).inSeconds;
_timeoutPredictionService?.recordObservation(
contactKey: contactKey,
pathLength: pathLength,
messageBytes: messageBytes,
tripTimeMs: tripTimeMs,
secondsSinceLastRx: secSinceRx,
);
},
onDeliveryObservedCallback:
(contactKey, pathLength, messageBytes, tripTimeMs) {
final secSinceRx = DateTime.now().difference(_lastRxTime).inSeconds;
_timeoutPredictionService?.recordObservation(
contactKey: contactKey,
pathLength: pathLength,
messageBytes: messageBytes,
tripTimeMs: tripTimeMs,
secondsSinceLastRx: secSinceRx,
);
},
);
}
@@ -2890,14 +2897,54 @@ class MeshCoreConnector extends ChangeNotifier {
}
}
/// Calculate timeout for a message based on radio settings and path length
/// Returns timeout in milliseconds, considering number of hops
/// Estimate single-packet airtime in ms from radio settings, or a fallback.
int _estimateAirtimeMs(int messageBytes) {
if (_currentFreqHz != null &&
_currentBwHz != null &&
_currentSf != null &&
_currentCr != null) {
final cr = _currentCr! <= 4 ? _currentCr! : _currentCr! - 4;
return calculateLoRaAirtime(
payloadBytes: messageBytes,
spreadingFactor: _currentSf!,
bandwidthHz: _currentBwHz!,
codingRate: cr,
lowDataRateOptimize: _currentSf! >= 11,
);
}
return 50; // fallback: ~SF7/BW125 for 100 bytes
}
/// Physics-based worst-case timeout (ceiling).
int _physicsMaxTimeout(int pathLength, int airtime) {
if (pathLength < 0) {
return 500 + (16 * airtime);
} else {
return 500 + ((airtime * 6 + 250) * (pathLength + 1));
}
}
/// Physics-based minimum timeout (floor): raw traversal time.
int _physicsMinTimeout(int pathLength, int airtime) {
if (pathLength < 0) {
return airtime;
} else {
return airtime * (pathLength + 1);
}
}
/// Calculate timeout for a message based on radio settings and path length.
/// Returns timeout in milliseconds, considering number of hops.
int calculateTimeout({
required int pathLength,
int messageBytes = 100,
String? contactKey,
}) {
// Try ML-based prediction first
final airtime = _estimateAirtimeMs(messageBytes);
final physicsMin = _physicsMinTimeout(pathLength, airtime);
final physicsMax = _physicsMaxTimeout(pathLength, airtime);
// Try ML-based prediction, clamped between physics bounds
final secSinceRx = DateTime.now().difference(_lastRxTime).inSeconds;
final mlTimeout = _timeoutPredictionService?.predictTimeout(
contactKey: contactKey,
@@ -2905,35 +2952,11 @@ class MeshCoreConnector extends ChangeNotifier {
messageBytes: messageBytes,
secondsSinceLastRx: secSinceRx,
);
if (mlTimeout != null) return mlTimeout;
// If we have radio settings, use them for accurate calculation
if (_currentFreqHz != null &&
_currentBwHz != null &&
_currentSf != null &&
_currentCr != null) {
final cr = _currentCr! <= 4 ? _currentCr! : _currentCr! - 4;
return calculateMessageTimeout(
freqHz: _currentFreqHz!,
bwHz: _currentBwHz!,
sf: _currentSf!,
cr: cr,
pathLength: pathLength,
messageBytes: messageBytes,
);
if (mlTimeout != null) {
return mlTimeout.clamp(physicsMin, physicsMax);
}
// Fallback: Conservative estimates based on typical settings
// Assume SF7, BW125, which gives ~50ms airtime for 100 bytes
const estimatedAirtime = 50;
if (pathLength < 0) {
// Flood mode: Base delay + 16× airtime
return 500 + (16 * estimatedAirtime);
} else {
// Direct path: Base delay + ((airtime×6 + 250ms)×(hops+1))
return 500 + ((estimatedAirtime * 6 + 250) * (pathLength + 1));
}
return physicsMax;
}
void _handleContact(Uint8List frame, {bool isContact = true}) {
+13 -5
View File
@@ -74,14 +74,20 @@ class MessageRetryService extends ChangeNotifier {
required Function(Message) updateMessageCallback,
Function(Contact)? clearContactPathCallback,
Function(Contact, Uint8List, int)? setContactPathCallback,
Function(int pathLength, int messageBytes, {String? contactKey})? calculateTimeoutCallback,
Function(int pathLength, int messageBytes, {String? contactKey})?
calculateTimeoutCallback,
Uint8List? Function()? getSelfPublicKeyCallback,
String Function(Contact, String)? prepareContactOutboundTextCallback,
AppSettingsService? appSettingsService,
AppDebugLogService? debugLogService,
Function(String, PathSelection, bool, int?)? recordPathResultCallback,
Function(String contactKey, int pathLength, int messageBytes, int tripTimeMs)?
onDeliveryObservedCallback,
Function(
String contactKey,
int pathLength,
int messageBytes,
int tripTimeMs,
)?
onDeliveryObservedCallback,
}) {
_sendMessageCallback = sendMessageCallback;
_addMessageCallback = addMessageCallback;
@@ -750,10 +756,12 @@ class MessageRetryService extends ChangeNotifier {
true,
tripTimeMs,
);
if (_onDeliveryObservedCallback != null && tripTimeMs > 0) {
if (_onDeliveryObservedCallback != null &&
tripTimeMs > 0 &&
message.pathLength != null) {
_onDeliveryObservedCallback!(
contact.publicKeyHex,
message.pathLength ?? 0,
message.pathLength!,
message.text.length,
tripTimeMs,
);
+1 -20
View File
@@ -8,7 +8,6 @@ class StorageService {
static const String _pendingMessagesKey = 'pending_messages';
static const String _repeaterPasswordsKey = 'repeater_passwords';
static const String _deliveryObservationsKey = 'delivery_observations';
static const String _timeoutModelKey = 'timeout_ml_model';
Future<void> savePathHistory(
String contactPubKeyHex,
@@ -143,10 +142,7 @@ class StorageService {
try {
final list = jsonDecode(jsonStr) as List;
return list
.map(
(e) =>
DeliveryObservation.fromJson(e as Map<String, dynamic>),
)
.map((e) => DeliveryObservation.fromJson(e as Map<String, dynamic>))
.toList();
} catch (e) {
return [];
@@ -157,19 +153,4 @@ class StorageService {
final prefs = PrefsManager.instance;
await prefs.remove(_deliveryObservationsKey);
}
Future<void> saveTimeoutModel(String modelJson) async {
final prefs = PrefsManager.instance;
await prefs.setString(_timeoutModelKey, modelJson);
}
Future<String?> loadTimeoutModel() async {
final prefs = PrefsManager.instance;
return prefs.getString(_timeoutModelKey);
}
Future<void> clearTimeoutModel() async {
final prefs = PrefsManager.instance;
await prefs.remove(_timeoutModelKey);
}
}
+20 -21
View File
@@ -1,5 +1,4 @@
import 'dart:convert';
import 'dart:math';
import 'dart:async';
import 'package:flutter/foundation.dart';
import 'package:ml_algo/ml_algo.dart';
import 'package:ml_dataframe/ml_dataframe.dart';
@@ -9,16 +8,13 @@ import 'storage_service.dart';
class _ContactStats {
int count = 0;
double _sum = 0;
double _sumSq = 0;
void add(double ms) {
count++;
_sum += ms;
_sumSq += ms * ms;
}
double get mean => _sum / count;
double get stdDev => sqrt((_sumSq / count) - (mean * mean));
}
class TimeoutPredictionService extends ChangeNotifier {
@@ -27,9 +23,10 @@ class TimeoutPredictionService extends ChangeNotifier {
static const int minObservations = 10;
static const int maxObservations = 100;
static const int _retrainInterval = 5;
// 1.5x multiplier on raw prediction to account for variance in delivery
// times — tight enough to improve on worst-case physics, loose enough
// to avoid premature timeouts from model noise.
static const double _safetyMargin = 1.5;
static const int _minTimeoutMs = 2000;
static const int _maxTimeoutMs = 120000;
static const int _minContactObservations = 10;
List<DeliveryObservation> _observations = [];
@@ -37,6 +34,7 @@ class TimeoutPredictionService extends ChangeNotifier {
List<String> _activeFeatures = [];
int _observationsSinceLastTrain = 0;
final Map<String, _ContactStats> _contactStats = {};
Timer? _persistTimer;
TimeoutPredictionService(StorageService storage) : _storage = storage;
TimeoutPredictionService.noStorage() : _storage = null;
@@ -89,7 +87,10 @@ class TimeoutPredictionService extends ChangeNotifier {
_trainModel();
}
_storage?.saveDeliveryObservations(_observations);
_persistTimer?.cancel();
_persistTimer = Timer(const Duration(seconds: 2), () {
_storage?.saveDeliveryObservations(_observations);
});
debugPrint(
'TimeoutPrediction: recorded ${tripTimeMs}ms for $pathLength hops '
'(${_observations.length} total)',
@@ -123,7 +124,9 @@ class TimeoutPredictionService extends ChangeNotifier {
final prediction = _model!.predict(features);
final rawValue = prediction.rows.first.first;
var predictedMs = (rawValue is double) ? rawValue : (rawValue as num).toDouble();
var predictedMs = (rawValue is double)
? rawValue
: (rawValue as num).toDouble();
debugPrint(
'TimeoutPrediction: raw prediction=$predictedMs for '
@@ -142,8 +145,8 @@ class TimeoutPredictionService extends ChangeNotifier {
}
}
final timeout =
(predictedMs * _safetyMargin).ceil().clamp(_minTimeoutMs, _maxTimeoutMs);
// Connector clamps this between physics min/max bounds
final timeout = (predictedMs * _safetyMargin).ceil();
debugPrint(
'TimeoutPrediction: ML timeout ${timeout}ms '
'(raw: ${predictedMs.round()}ms, contact: $contactKey)',
@@ -174,7 +177,9 @@ class TimeoutPredictionService extends ChangeNotifier {
}
if (_activeFeatures.isEmpty) {
debugPrint('TimeoutPrediction: no features with variance, skipping training');
debugPrint(
'TimeoutPrediction: no features with variance, skipping training',
);
return;
}
@@ -190,25 +195,19 @@ class TimeoutPredictionService extends ChangeNotifier {
return row;
});
final data = DataFrame(
[header, ...rows],
headerExists: true,
);
final data = DataFrame([header, ...rows], headerExists: true);
_model = LinearRegressor(data, 'deliveryMs');
_observationsSinceLastTrain = 0;
// Log training summary with sample predictions
final avgMs = _observations.map((o) => o.deliveryMs).reduce((a, b) => a + b) /
final avgMs =
_observations.map((o) => o.deliveryMs).reduce((a, b) => a + b) /
_observations.length;
debugPrint(
'TimeoutPrediction: trained on ${_observations.length} observations '
'(avg: ${avgMs.round()}ms, features: $_activeFeatures)',
);
final modelJson = jsonEncode(_model!.toJson());
_storage?.saveTimeoutModel(modelJson);
notifyListeners();
} catch (e) {
debugPrint('TimeoutPrediction: training failed: $e');
}
+85 -51
View File
@@ -6,18 +6,22 @@ import 'package:ml_dataframe/ml_dataframe.dart';
void main() {
test('LinearRegressor basic sanity check', () {
// Simple: y = 2x + 100
final data = DataFrame([
[1.0, 102.0],
[2.0, 104.0],
[3.0, 106.0],
[4.0, 108.0],
[5.0, 110.0],
[10.0, 120.0],
[20.0, 140.0],
[50.0, 200.0],
[0.0, 100.0],
[100.0, 300.0],
], headerExists: false, header: ['x', 'y']);
final data = DataFrame(
[
[1.0, 102.0],
[2.0, 104.0],
[3.0, 106.0],
[4.0, 108.0],
[5.0, 110.0],
[10.0, 120.0],
[20.0, 140.0],
[50.0, 200.0],
[0.0, 100.0],
[100.0, 300.0],
],
headerExists: false,
header: ['x', 'y'],
);
debugPrint('Training data columns: ${data.header}');
debugPrint('Training data rows: ${data.rows.length}');
@@ -25,7 +29,9 @@ void main() {
final model = LinearRegressor(data, 'y');
final testDf = DataFrame(
[[25.0]],
[
[25.0],
],
headerExists: false,
header: ['x'],
);
@@ -38,45 +44,63 @@ void main() {
test('LinearRegressor multi-feature with constant column produces zeros', () {
// isFlood=0 for all rows → zero-variance column → singular matrix
final data = DataFrame([
[0.0, 50.0, 14.0, 0.0, 1900.0],
[0.0, 80.0, 14.0, 0.0, 2200.0],
[2.0, 50.0, 14.0, 0.0, 5000.0],
[4.0, 50.0, 14.0, 0.0, 9500.0],
], headerExists: false, header: [
'pathLength', 'messageBytes', 'hourOfDay', 'isFlood', 'deliveryMs',
]);
final data = DataFrame(
[
[0.0, 50.0, 14.0, 0.0, 1900.0],
[0.0, 80.0, 14.0, 0.0, 2200.0],
[2.0, 50.0, 14.0, 0.0, 5000.0],
[4.0, 50.0, 14.0, 0.0, 9500.0],
],
headerExists: false,
header: [
'pathLength',
'messageBytes',
'hourOfDay',
'isFlood',
'deliveryMs',
],
);
final model = LinearRegressor(data, 'deliveryMs');
final testDf = DataFrame(
[[2.0, 50.0, 14.0, 0.0]],
[
[2.0, 50.0, 14.0, 0.0],
],
headerExists: false,
header: ['pathLength', 'messageBytes', 'hourOfDay', 'isFlood'],
);
final pred = model.predict(testDf).rows.first.first;
debugPrint('With constant isFlood column: hops=2 → ${(pred as num).round()}ms (likely 0)');
debugPrint(
'With constant isFlood column: hops=2 → ${(pred as num).round()}ms (likely 0)',
);
});
test('LinearRegressor 2-feature works correctly', () {
// Just pathLength + messageBytes → deliveryMs
final data = DataFrame([
[0.0, 50.0, 1900.0],
[0.0, 80.0, 2200.0],
[2.0, 50.0, 5000.0],
[2.0, 80.0, 5500.0],
[4.0, 50.0, 9500.0],
[4.0, 80.0, 10000.0],
[0.0, 30.0, 1800.0],
[2.0, 30.0, 4800.0],
[4.0, 30.0, 9000.0],
[0.0, 60.0, 2000.0],
], headerExists: false, header: ['pathLength', 'messageBytes', 'deliveryMs']);
final data = DataFrame(
[
[0.0, 50.0, 1900.0],
[0.0, 80.0, 2200.0],
[2.0, 50.0, 5000.0],
[2.0, 80.0, 5500.0],
[4.0, 50.0, 9500.0],
[4.0, 80.0, 10000.0],
[0.0, 30.0, 1800.0],
[2.0, 30.0, 4800.0],
[4.0, 30.0, 9000.0],
[0.0, 60.0, 2000.0],
],
headerExists: false,
header: ['pathLength', 'messageBytes', 'deliveryMs'],
);
final model = LinearRegressor(data, 'deliveryMs');
for (final hops in [0.0, 2.0, 4.0]) {
final testDf = DataFrame(
[[hops, 50.0]],
[
[hops, 50.0],
],
headerExists: false,
header: ['pathLength', 'messageBytes'],
);
@@ -87,20 +111,28 @@ void main() {
test('LinearRegressor multi-feature with variance in all columns', () {
// Mix flood and direct so isFlood has variance
final data = DataFrame([
[0.0, 50.0, 14.0, 0.0, 1900.0],
[0.0, 80.0, 10.0, 0.0, 2200.0],
[2.0, 50.0, 16.0, 0.0, 5000.0],
[2.0, 80.0, 20.0, 0.0, 5500.0],
[4.0, 50.0, 8.0, 0.0, 9500.0],
[4.0, 80.0, 12.0, 0.0, 10000.0],
[-1.0, 40.0, 14.0, 1.0, 5000.0],
[-1.0, 60.0, 18.0, 1.0, 6500.0],
[-1.0, 30.0, 10.0, 1.0, 4000.0],
[-1.0, 80.0, 22.0, 1.0, 7000.0],
], headerExists: false, header: [
'pathLength', 'messageBytes', 'hourOfDay', 'isFlood', 'deliveryMs',
]);
final data = DataFrame(
[
[0.0, 50.0, 14.0, 0.0, 1900.0],
[0.0, 80.0, 10.0, 0.0, 2200.0],
[2.0, 50.0, 16.0, 0.0, 5000.0],
[2.0, 80.0, 20.0, 0.0, 5500.0],
[4.0, 50.0, 8.0, 0.0, 9500.0],
[4.0, 80.0, 12.0, 0.0, 10000.0],
[-1.0, 40.0, 14.0, 1.0, 5000.0],
[-1.0, 60.0, 18.0, 1.0, 6500.0],
[-1.0, 30.0, 10.0, 1.0, 4000.0],
[-1.0, 80.0, 22.0, 1.0, 7000.0],
],
headerExists: false,
header: [
'pathLength',
'messageBytes',
'hourOfDay',
'isFlood',
'deliveryMs',
],
);
final model = LinearRegressor(data, 'deliveryMs');
@@ -116,7 +148,9 @@ void main() {
header: ['pathLength', 'messageBytes', 'hourOfDay', 'isFlood'],
);
final pred = model.predict(testDf).rows.first.first;
debugPrint('4-feature: hops=${tc[0]} flood=${tc[3]}${(pred as num).round()}ms');
debugPrint(
'4-feature: hops=${tc[0]} flood=${tc[3]}${(pred as num).round()}ms',
);
}
});
}
@@ -64,9 +64,9 @@ void main() {
expect(direct4!, greaterThan(direct2!));
expect(direct2, greaterThan(direct0!));
// All should be within the clamp range
expect(direct0, greaterThanOrEqualTo(2000));
expect(direct4, lessThanOrEqualTo(120000));
// All should be positive
expect(direct0, greaterThan(0));
expect(direct4, greaterThan(0));
// Print predictions for visibility
debugPrint('Predictions (with 1.5x safety margin):');