import 'dart:convert';
import 'dart:math';

import 'package:flutter/foundation.dart';
import 'package:ml_algo/ml_algo.dart';
import 'package:ml_dataframe/ml_dataframe.dart';

import '../models/delivery_observation.dart';
import 'storage_service.dart';

/// Running delivery-time statistics (count, mean, population standard
/// deviation) for a single contact.
///
/// Stored as running sums so samples can be added and removed in O(1) as the
/// observation window slides.
class _ContactStats {
  int count = 0;
  double _sum = 0;
  double _sumSq = 0;

  /// Adds one delivery-time sample, in milliseconds.
  void add(double ms) {
    count++;
    _sum += ms;
    _sumSq += ms * ms;
  }

  /// Removes a sample previously passed to [add].
  void remove(double ms) {
    count--;
    _sum -= ms;
    _sumSq -= ms * ms;
  }

  /// Mean of the samples; 0 when there are none.
  double get mean => count == 0 ? 0.0 : _sum / count;

  /// Population standard deviation of the samples; 0 when there are none.
  double get stdDev {
    if (count == 0) return 0.0;
    // E[x²] − E[x]² can dip slightly below zero through floating-point
    // cancellation; clamp so sqrt never yields NaN.
    final variance = _sumSq / count - mean * mean;
    return sqrt(max(0.0, variance));
  }
}

/// Predicts message-delivery timeouts from previously observed round trips.
///
/// Keeps a sliding window of the most recent [maxObservations]
/// [DeliveryObservation]s and fits a linear regression of delivery time on
/// the message features. Once a contact has enough history, the prediction
/// is blended with that contact's mean. Observations and the serialized
/// model are persisted through the optional [StorageService].
class TimeoutPredictionService extends ChangeNotifier {
  final StorageService? _storage;

  /// Minimum number of observations before a model is trained.
  static const int minObservations = 10;

  /// Maximum number of observations retained; the oldest are evicted first.
  static const int maxObservations = 100;

  /// Retrain after this many new observations.
  static const int _retrainInterval = 5;

  /// Multiplier applied to the predicted delivery time to form the timeout.
  static const double _safetyMargin = 1.5;
  static const int _minTimeoutMs = 2000;
  static const int _maxTimeoutMs = 120000;

  /// Per-contact samples required before the contact mean is blended in.
  static const int _minContactObservations = 10;

  /// Candidate regression features, in column order, each paired with how it
  /// is read from a recorded observation. The keys must match the feature
  /// map built in [predictTimeout].
  static final Map<String, double Function(DeliveryObservation)>
      _featureExtractors = {
    'pathLength': (o) => o.pathLength.toDouble(),
    'messageBytes': (o) => o.messageBytes.toDouble(),
    'secSinceRx': (o) => o.secondsSinceLastRx.toDouble(),
    'isFlood': (o) => o.isFlood ? 1.0 : 0.0,
  };

  List<DeliveryObservation> _observations = [];
  LinearRegressor? _model;

  /// Feature columns [_model] was trained on; always replaced together with
  /// [_model].
  List<String> _activeFeatures = [];
  int _observationsSinceLastTrain = 0;
  final Map<String, _ContactStats> _contactStats = {};

  TimeoutPredictionService(StorageService storage) : _storage = storage;

  TimeoutPredictionService.noStorage() : _storage = null;

  /// Number of observations currently in the window.
  int get observationCount => _observations.length;

  /// Whether a trained model is available for [predictTimeout].
  bool get hasModel => _model != null;

  /// Loads persisted observations, rebuilds per-contact statistics and
  /// trains an initial model when enough data is available.
  Future<void> initialize() async {
    final loaded = await _storage?.loadDeliveryObservations();
    // Copy so the window is growable and owned by this service, even if
    // storage hands back an unmodifiable or shared list.
    _observations = List.of(loaded ?? const <DeliveryObservation>[]);
    // Enforce the window size on persisted data too (e.g. if it was saved
    // under a larger maxObservations).
    if (_observations.length > maxObservations) {
      _observations.removeRange(0, _observations.length - maxObservations);
    }
    _rebuildContactStats();
    if (_observations.length >= minObservations) {
      _trainModel();
    }
    debugPrint(
      'TimeoutPrediction: initialized with ${_observations.length} observations, '
      'model=${_model != null ? "ready" : "waiting for data"}',
    );
  }

  /// Records one measured delivery of [tripTimeMs] milliseconds.
  ///
  /// A negative [pathLength] marks the message as flood-routed. The oldest
  /// observations beyond [maxObservations] are evicted. The model is
  /// retrained every [_retrainInterval] observations once [minObservations]
  /// are available, and the window is then persisted.
  void recordObservation({
    required String contactKey,
    required int pathLength,
    required int messageBytes,
    required int tripTimeMs,
    int secondsSinceLastRx = 0,
  }) {
    final observation = DeliveryObservation(
      contactKey: contactKey,
      pathLength: pathLength,
      messageBytes: messageBytes,
      secondsSinceLastRx: secondsSinceLastRx,
      isFlood: pathLength < 0,
      deliveryMs: tripTimeMs,
      timestamp: DateTime.now(),
    );
    _observations.add(observation);
    _contactStats
        .putIfAbsent(contactKey, () => _ContactStats())
        .add(tripTimeMs.toDouble());
    _evictOverflow();

    _observationsSinceLastTrain++;
    if (_observationsSinceLastTrain >= _retrainInterval &&
        _observations.length >= minObservations) {
      _trainModel();
    }

    // NOTE(review): not awaited; if this returns a Future, storage failures
    // are not observed here.
    _storage?.saveDeliveryObservations(_observations);
    debugPrint(
      'TimeoutPrediction: recorded ${tripTimeMs}ms for $pathLength hops '
      '(${_observations.length} total)',
    );
  }

  /// Returns a timeout in milliseconds for a message with the given
  /// characteristics.
  ///
  /// The regression prediction is averaged 50/50 with the mean for
  /// [contactKey] when that contact has at least [_minContactObservations]
  /// samples. The result is multiplied by [_safetyMargin] and clamped to
  /// [[_minTimeoutMs], [_maxTimeoutMs]].
  ///
  /// Returns null in these cases, so the caller can fall back to its
  /// default:
  /// - no model has been trained;
  /// - the prediction is non-positive or non-finite;
  /// - prediction fails.
  int? predictTimeout({
    String? contactKey,
    required int pathLength,
    required int messageBytes,
    int secondsSinceLastRx = 0,
  }) {
    // Snapshot the pair so both come from the same training run.
    final model = _model;
    final features = _activeFeatures;
    if (model == null || features.isEmpty) return null;

    try {
      final allFeatures = <String, double>{
        'pathLength': pathLength.toDouble(),
        'messageBytes': messageBytes.toDouble(),
        'secSinceRx': secondsSinceLastRx.toDouble(),
        'isFlood': pathLength < 0 ? 1.0 : 0.0,
      };
      final row = [for (final name in features) allFeatures[name]!];
      final input = DataFrame(
        [row],
        headerExists: false,
        header: features,
      );
      final prediction = model.predict(input);
      final rawValue = prediction.rows.first.first;
      var predictedMs = (rawValue as num).toDouble();

      debugPrint(
        'TimeoutPrediction: raw prediction=$predictedMs for '
        'pathLength=$pathLength, messageBytes=$messageBytes, '
        'features=$features',
      );

      // Sanity check: non-positive, NaN or infinite predictions fall back.
      if (!predictedMs.isFinite || predictedMs <= 0) return null;

      // Blend with the per-contact mean once the contact has enough data.
      final stats = contactKey == null ? null : _contactStats[contactKey];
      if (stats != null && stats.count >= _minContactObservations) {
        predictedMs = 0.5 * predictedMs + 0.5 * stats.mean;
      }

      final timeout = (predictedMs * _safetyMargin)
          .ceil()
          .clamp(_minTimeoutMs, _maxTimeoutMs);
      debugPrint(
        'TimeoutPrediction: ML timeout ${timeout}ms '
        '(raw: ${predictedMs.round()}ms, contact: $contactKey)',
      );
      return timeout;
    } catch (e) {
      debugPrint('TimeoutPrediction: prediction failed: $e');
      return null;
    }
  }

  /// Drops the oldest observations until the window fits [maxObservations].
  ///
  /// Keeps [_contactStats] in step, so it always describes exactly the
  /// retained window, as [_rebuildContactStats] does.
  void _evictOverflow() {
    while (_observations.length > maxObservations) {
      final evicted = _observations.removeAt(0);
      final stats = _contactStats[evicted.contactKey];
      if (stats == null) continue;
      stats.remove(evicted.deliveryMs.toDouble());
      if (stats.count <= 0) _contactStats.remove(evicted.contactKey);
    }
  }

  /// Fits a new regression model on the current window.
  ///
  /// On success, replaces [_model] and [_activeFeatures] together, persists
  /// the serialized model and notifies listeners. If training is skipped or
  /// fails, the previous model and its feature list are left unchanged.
  void _trainModel() {
    try {
      // Exclude zero-variance columns: ml_algo's OLS produces all-zero
      // coefficients for the resulting singular matrix.
      final features = [
        for (final entry in _featureExtractors.entries)
          if (_observations.map(entry.value).toSet().length > 1) entry.key,
      ];
      if (features.isEmpty) {
        debugPrint(
            'TimeoutPrediction: no features with variance, skipping training');
        return;
      }

      final extractors = [for (final name in features) _featureExtractors[name]!];
      final data = DataFrame(
        [
          [...features, 'deliveryMs'],
          for (final o in _observations)
            [
              for (final extract in extractors) extract(o),
              o.deliveryMs.toDouble(),
            ],
        ],
        headerExists: true,
      );

      // NOTE(review): ml_algo's LinearRegressor defaults to
      // fitIntercept: false, so this fit has no intercept term. Confirm
      // that is intended.
      final model = LinearRegressor(data, 'deliveryMs');
      // Commit model and feature list together so predictTimeout never pairs
      // a model with another fit's columns.
      _model = model;
      _activeFeatures = features;
      _observationsSinceLastTrain = 0;

      // Log a short training summary.
      final avgMs = _observations.fold<double>(
              0, (sum, o) => sum + o.deliveryMs) /
          _observations.length;
      debugPrint(
        'TimeoutPrediction: trained on ${_observations.length} observations '
        '(avg: ${avgMs.round()}ms, features: $_activeFeatures)',
      );

      final modelJson = jsonEncode(model.toJson());
      _storage?.saveTimeoutModel(modelJson);
      notifyListeners();
    } catch (e) {
      debugPrint('TimeoutPrediction: training failed: $e');
    }
  }

  /// Recomputes [_contactStats] from scratch over the retained observations.
  void _rebuildContactStats() {
    _contactStats.clear();
    for (final obs in _observations) {
      _contactStats
          .putIfAbsent(obs.contactKey, () => _ContactStats())
          .add(obs.deliveryMs.toDouble());
    }
  }
}