fix: address PR #296 code review feedback

- Clamp ML predictions between physics floor (raw airtime) and ceiling
  (worst-case formula) so model can never produce unsafe timeouts
- Replace hourOfDay feature with secondsSinceLastRx for network activity
- Remove unused _ContactStats.stdDev and dead model persistence code
- Debounce observation writes (2s) instead of writing on every delivery
- Skip recording observations when pathLength is null to avoid corrupting
  training data
- Add comment explaining global (not per-contact) RX time tracking
- Remove notifyListeners from retrain to avoid unnecessary widget rebuilds
- Run dart format
This commit is contained in:
zjs81
2026-03-14 17:32:08 -07:00
parent 2ee2358ecc
commit b336aedbc5
6 changed files with 187 additions and 142 deletions
+13 -5
View File
@@ -74,14 +74,20 @@ class MessageRetryService extends ChangeNotifier {
required Function(Message) updateMessageCallback,
Function(Contact)? clearContactPathCallback,
Function(Contact, Uint8List, int)? setContactPathCallback,
Function(int pathLength, int messageBytes, {String? contactKey})? calculateTimeoutCallback,
Function(int pathLength, int messageBytes, {String? contactKey})?
calculateTimeoutCallback,
Uint8List? Function()? getSelfPublicKeyCallback,
String Function(Contact, String)? prepareContactOutboundTextCallback,
AppSettingsService? appSettingsService,
AppDebugLogService? debugLogService,
Function(String, PathSelection, bool, int?)? recordPathResultCallback,
Function(String contactKey, int pathLength, int messageBytes, int tripTimeMs)?
onDeliveryObservedCallback,
Function(
String contactKey,
int pathLength,
int messageBytes,
int tripTimeMs,
)?
onDeliveryObservedCallback,
}) {
_sendMessageCallback = sendMessageCallback;
_addMessageCallback = addMessageCallback;
@@ -750,10 +756,12 @@ class MessageRetryService extends ChangeNotifier {
true,
tripTimeMs,
);
if (_onDeliveryObservedCallback != null && tripTimeMs > 0) {
if (_onDeliveryObservedCallback != null &&
tripTimeMs > 0 &&
message.pathLength != null) {
_onDeliveryObservedCallback!(
contact.publicKeyHex,
message.pathLength ?? 0,
message.pathLength!,
message.text.length,
tripTimeMs,
);
+1 -20
View File
@@ -8,7 +8,6 @@ class StorageService {
static const String _pendingMessagesKey = 'pending_messages';
static const String _repeaterPasswordsKey = 'repeater_passwords';
static const String _deliveryObservationsKey = 'delivery_observations';
static const String _timeoutModelKey = 'timeout_ml_model';
Future<void> savePathHistory(
String contactPubKeyHex,
@@ -143,10 +142,7 @@ class StorageService {
try {
final list = jsonDecode(jsonStr) as List;
return list
.map(
(e) =>
DeliveryObservation.fromJson(e as Map<String, dynamic>),
)
.map((e) => DeliveryObservation.fromJson(e as Map<String, dynamic>))
.toList();
} catch (e) {
return [];
@@ -157,19 +153,4 @@ class StorageService {
final prefs = PrefsManager.instance;
await prefs.remove(_deliveryObservationsKey);
}
Future<void> saveTimeoutModel(String modelJson) async {
final prefs = PrefsManager.instance;
await prefs.setString(_timeoutModelKey, modelJson);
}
Future<String?> loadTimeoutModel() async {
final prefs = PrefsManager.instance;
return prefs.getString(_timeoutModelKey);
}
Future<void> clearTimeoutModel() async {
final prefs = PrefsManager.instance;
await prefs.remove(_timeoutModelKey);
}
}
+20 -21
View File
@@ -1,5 +1,4 @@
import 'dart:convert';
import 'dart:math';
import 'dart:async';
import 'package:flutter/foundation.dart';
import 'package:ml_algo/ml_algo.dart';
import 'package:ml_dataframe/ml_dataframe.dart';
@@ -9,16 +8,13 @@ import 'storage_service.dart';
class _ContactStats {
int count = 0;
double _sum = 0;
double _sumSq = 0;
void add(double ms) {
count++;
_sum += ms;
_sumSq += ms * ms;
}
double get mean => _sum / count;
double get stdDev => sqrt((_sumSq / count) - (mean * mean));
}
class TimeoutPredictionService extends ChangeNotifier {
@@ -27,9 +23,10 @@ class TimeoutPredictionService extends ChangeNotifier {
static const int minObservations = 10;
static const int maxObservations = 100;
static const int _retrainInterval = 5;
// 1.5x multiplier on raw prediction to account for variance in delivery
// times — tight enough to improve on worst-case physics, loose enough
// to avoid premature timeouts from model noise.
static const double _safetyMargin = 1.5;
static const int _minTimeoutMs = 2000;
static const int _maxTimeoutMs = 120000;
static const int _minContactObservations = 10;
List<DeliveryObservation> _observations = [];
@@ -37,6 +34,7 @@ class TimeoutPredictionService extends ChangeNotifier {
List<String> _activeFeatures = [];
int _observationsSinceLastTrain = 0;
final Map<String, _ContactStats> _contactStats = {};
Timer? _persistTimer;
TimeoutPredictionService(StorageService storage) : _storage = storage;
TimeoutPredictionService.noStorage() : _storage = null;
@@ -89,7 +87,10 @@ class TimeoutPredictionService extends ChangeNotifier {
_trainModel();
}
_storage?.saveDeliveryObservations(_observations);
_persistTimer?.cancel();
_persistTimer = Timer(const Duration(seconds: 2), () {
_storage?.saveDeliveryObservations(_observations);
});
debugPrint(
'TimeoutPrediction: recorded ${tripTimeMs}ms for $pathLength hops '
'(${_observations.length} total)',
@@ -123,7 +124,9 @@ class TimeoutPredictionService extends ChangeNotifier {
final prediction = _model!.predict(features);
final rawValue = prediction.rows.first.first;
var predictedMs = (rawValue is double) ? rawValue : (rawValue as num).toDouble();
var predictedMs = (rawValue is double)
? rawValue
: (rawValue as num).toDouble();
debugPrint(
'TimeoutPrediction: raw prediction=$predictedMs for '
@@ -142,8 +145,8 @@ class TimeoutPredictionService extends ChangeNotifier {
}
}
final timeout =
(predictedMs * _safetyMargin).ceil().clamp(_minTimeoutMs, _maxTimeoutMs);
// Connector clamps this between physics min/max bounds
final timeout = (predictedMs * _safetyMargin).ceil();
debugPrint(
'TimeoutPrediction: ML timeout ${timeout}ms '
'(raw: ${predictedMs.round()}ms, contact: $contactKey)',
@@ -174,7 +177,9 @@ class TimeoutPredictionService extends ChangeNotifier {
}
if (_activeFeatures.isEmpty) {
debugPrint('TimeoutPrediction: no features with variance, skipping training');
debugPrint(
'TimeoutPrediction: no features with variance, skipping training',
);
return;
}
@@ -190,25 +195,19 @@ class TimeoutPredictionService extends ChangeNotifier {
return row;
});
final data = DataFrame(
[header, ...rows],
headerExists: true,
);
final data = DataFrame([header, ...rows], headerExists: true);
_model = LinearRegressor(data, 'deliveryMs');
_observationsSinceLastTrain = 0;
// Log training summary with sample predictions
final avgMs = _observations.map((o) => o.deliveryMs).reduce((a, b) => a + b) /
final avgMs =
_observations.map((o) => o.deliveryMs).reduce((a, b) => a + b) /
_observations.length;
debugPrint(
'TimeoutPrediction: trained on ${_observations.length} observations '
'(avg: ${avgMs.round()}ms, features: $_activeFeatures)',
);
final modelJson = jsonEncode(_model!.toJson());
_storage?.saveTimeoutModel(modelJson);
notifyListeners();
} catch (e) {
debugPrint('TimeoutPrediction: training failed: $e');
}