from django.utils import timezone
from datetime import datetime, timedelta
from typing import List, Dict, Optional
from .models import UserPattern, MLModel, AnomalyAlert
from .push_service import PushNotificationService
from stamps.models import Stamp
import logging

# Module-level logger named after this module; used by MLService to report
# errors that are caught and converted into error results.
logger = logging.getLogger(__name__)

class MLService:
    """Learn per-user stamp-time patterns and flag anomalous stamps.

    Patterns are learned per ``stamp_function`` value of a user's ``Stamp``
    rows and stored on ``UserPattern`` under the ``question_type`` field.
    New stamps can then be checked for repeated-action sequence violations
    and for times that deviate strongly from the learned average.

    A ``PushNotificationService`` is created at construction time; it is not
    referenced by the methods defined below.
    """

    # Stamp functions that take part in pattern learning. Any other value
    # (for example 'none') is ignored when categorizing stamps.
    _TRACKED_FUNCTIONS = ('clock_in', 'lunch_out', 'lunch_in', 'clock_out')

    # Invalid (previous function, new function) pairs mapped to the hint
    # returned for them. Single source of truth for both the sequence check
    # and get_sequence_suggestion.
    _SEQUENCE_SUGGESTIONS = {
        ('clock_in', 'clock_in'): "You're already clocked in. Did you mean to go to lunch?",
        ('lunch_out', 'lunch_out'): "You're already on lunch break. Did you mean to return?",
        ('lunch_in', 'lunch_in'): "You're already back from lunch. Did you mean to clock out?",
        ('clock_out', 'clock_out'): "You're already clocked out. Did you mean to clock in for a new day?",
    }

    _SECONDS_PER_DAY = 24 * 60 * 60

    # Thresholds used by the analysis and anomaly checks.
    MIN_STAMPS_FOR_ANALYSIS = 5
    MIN_SAMPLES_PER_PATTERN = 3
    ANOMALY_Z_SCORE = 2.0
    HIGH_SEVERITY_Z_SCORE = 3.0

    def __init__(self):
        """Create the service together with its push notification client."""
        self.push_service = PushNotificationService()

    def analyze_user_patterns(self, user_id: str) -> Dict:
        """Learn time patterns for one user and persist them.

        Args:
            user_id: Identifier used to filter ``Stamp`` rows (``user__id``).

        Returns:
            A dict with ``user_id``, ``analysis_date``, ``total_stamps``,
            per-function counts in ``categorized_stamps`` and the learned
            ``patterns``; or ``{'error': <message>}`` when the user has fewer
            than ``MIN_STAMPS_FOR_ANALYSIS`` stamps or any exception occurs.
        """
        try:
            # Evaluate the queryset exactly once: the rows are counted here
            # and iterated again during categorization.
            stamps = list(Stamp.objects.filter(user__id=user_id))

            if len(stamps) < self.MIN_STAMPS_FOR_ANALYSIS:
                return {'error': 'Insufficient data for ML analysis'}

            categorized_stamps = self.categorize_stamps_by_question_type(stamps)

            results = {
                'user_id': user_id,
                'analysis_date': timezone.now().isoformat(),
                'total_stamps': len(stamps),
                'categorized_stamps': {k: len(v) for k, v in categorized_stamps.items()},
                'patterns': {}
            }

            # Only categories with enough samples yield a pattern.
            for question_type, stamp_list in categorized_stamps.items():
                if len(stamp_list) >= self.MIN_SAMPLES_PER_PATTERN:
                    results['patterns'][question_type] = self.analyze_pattern_type(
                        question_type, stamp_list
                    )

            self.save_patterns_to_database(user_id, results)

            return results

        except Exception as e:
            # Best-effort API: callers get an error dict, the log keeps the
            # full traceback.
            logger.exception("Error analyzing user patterns: %s", e)
            return {'error': str(e)}

    def categorize_stamps_by_question_type(self, stamps) -> Dict:
        """Group stamps by their ``stamp_function`` value.

        Args:
            stamps: Iterable of objects exposing ``stamp_function``.

        Returns:
            A dict with one list per tracked function (always present, in the
            order clock_in, lunch_out, lunch_in, clock_out). Stamps whose
            function is not tracked are dropped.
        """
        patterns = {name: [] for name in self._TRACKED_FUNCTIONS}

        for stamp in stamps:
            bucket = patterns.get(stamp.stamp_function)
            if bucket is not None:
                bucket.append(stamp)

        return patterns

    def analyze_pattern_type(self, question_type: str, stamps: List) -> Dict:
        """Compute time-of-day statistics for the stamps of one function.

        Args:
            question_type: Label copied into the result.
            stamps: Non-empty list of objects exposing a ``time`` attribute
                with ``hour``, ``minute`` and ``second``.

        Returns:
            A dict with ``question_type``, ``average_time`` (a
            ``datetime.time``), ``variance_minutes`` (population standard
            deviation truncated to whole minutes), ``confidence_score``,
            ``sample_size`` and ``pattern_strength``.

        Raises:
            ValueError: If ``stamps`` is empty.
        """
        import numpy as np

        if not stamps:
            raise ValueError("Cannot analyze a pattern without stamps")

        times = [stamp.time for stamp in stamps]

        # NOTE(review): mean and deviation are computed on a linear scale of
        # seconds since midnight, so samples that straddle midnight are not
        # averaged circularly.
        times_seconds = [self.time_to_seconds(t) for t in times]
        mean_seconds = np.mean(times_seconds)
        std_seconds = np.std(times_seconds)

        return {
            'question_type': question_type,
            'average_time': self.seconds_to_time(int(mean_seconds)),
            'variance_minutes': int(std_seconds / 60),
            'confidence_score': self.calculate_confidence(len(times), std_seconds),
            'sample_size': len(times),
            'pattern_strength': self.calculate_pattern_strength(times)
        }

    def save_patterns_to_database(self, user_id: str, results: Dict):
        """Upsert one ``UserPattern`` per learned pattern.

        Args:
            user_id: Value stored in ``UserPattern.user_id``.
            results: Analysis result; nothing is written when it has no
                ``'patterns'`` key.
        """
        for question_type, pattern_data in results.get('patterns', {}).items():
            UserPattern.objects.update_or_create(
                user_id=user_id,
                question_type=question_type,
                defaults={
                    'average_time': pattern_data['average_time'],
                    'variance_minutes': pattern_data['variance_minutes'],
                    'confidence_score': pattern_data['confidence_score'],
                    'sample_size': pattern_data['sample_size']
                }
            )

    def detect_anomalies_for_new_stamp(self, user_id: str, new_stamp: "Stamp") -> Optional[Dict]:
        """Check a new stamp for sequence and time anomalies.

        The sequence check runs first; when it reports a violation an
        ``AnomalyAlert`` row is created and the violation is returned.
        Otherwise the time check is run and its result returned.

        Args:
            user_id: Identifier of the stamping user.
            new_stamp: The stamp to check.

        Returns:
            The anomaly description dict, or ``None`` when nothing was
            detected or an exception occurred (the exception is logged).
        """
        try:
            sequence_anomaly = self.detect_sequence_anomalies(user_id, new_stamp)
            if sequence_anomaly:
                AnomalyAlert.objects.create(
                    user_id=user_id,
                    stamp_id=new_stamp.id,
                    question_type=new_stamp.stamp_function,
                    anomaly_score=1.0,  # High score for sequence violations
                    algorithm_used='sequence_detection',
                    severity=sequence_anomaly['severity']
                )
                return sequence_anomaly

            # NOTE(review): unlike sequence violations, time anomalies are
            # returned without creating an AnomalyAlert - confirm intended.
            time_anomaly = self.detect_time_anomaly(user_id, new_stamp)
            if time_anomaly:
                return time_anomaly

            return None

        except Exception as e:
            logger.exception("Error detecting anomalies: %s", e)
            return None

    def detect_sequence_anomalies(self, user_id: str, new_stamp: "Stamp") -> Optional[Dict]:
        """Flag a stamp that repeats the user's previous stamp function.

        The previous stamp is the user's most recent one by ``date`` then
        ``time``, excluding ``new_stamp`` itself.

        Args:
            user_id: Identifier used to filter ``Stamp`` rows (``user__id``).
            new_stamp: The stamp to check.

        Returns:
            A dict describing the violation (``is_anomaly``,
            ``anomaly_type``, ``message``, ``severity``, ``suggestion``), or
            ``None`` when there is no previous stamp, the pair is valid, or
            an exception occurred (the exception is logged).
        """
        try:
            last_stamp = Stamp.objects.filter(
                user__id=user_id
            ).exclude(id=new_stamp.id).order_by('-date', '-time').first()

            if not last_stamp:
                return None

            current_question_type = new_stamp.stamp_function
            last_question_type = last_stamp.stamp_function

            if (last_question_type, current_question_type) not in self._SEQUENCE_SUGGESTIONS:
                return None

            return {
                'is_anomaly': True,
                'anomaly_type': 'sequence_violation',
                'message': f"Invalid sequence: {last_question_type} → {current_question_type}",
                'severity': 'high',
                'suggestion': self.get_sequence_suggestion(last_question_type, current_question_type)
            }

        except Exception as e:
            logger.exception("Error detecting sequence anomalies: %s", e)
            return None

    def detect_time_anomaly(self, user_id: str, new_stamp: "Stamp") -> Optional[Dict]:
        """Flag a stamp whose time deviates strongly from the learned pattern.

        The deviation is the absolute difference in seconds between the
        stamp time and the pattern's average time, divided by the pattern's
        stored spread (``variance_minutes`` * 60).

        Args:
            user_id: Value matched against ``UserPattern.user_id``.
            new_stamp: The stamp to check.

        Returns:
            A dict (``is_anomaly``, ``anomaly_type``, ``message``,
            ``severity``, ``z_score``) when the deviation exceeds
            ``ANOMALY_Z_SCORE``; ``None`` when no pattern exists, the stored
            spread is zero, or the deviation is within bounds.
        """
        try:
            pattern = UserPattern.objects.get(
                user_id=user_id,
                question_type=new_stamp.stamp_function
            )
        except UserPattern.DoesNotExist:
            return None

        variance_seconds = pattern.variance_minutes * 60

        # A zero spread would divide by zero, so the check is skipped.
        # NOTE(review): analyze_pattern_type truncates the spread to whole
        # minutes, so a pattern with less than one minute of deviation is
        # never checked here - confirm this is acceptable.
        if variance_seconds <= 0:
            return None

        typical_time_seconds = self.time_to_seconds(pattern.average_time)
        new_time_seconds = self.time_to_seconds(new_stamp.time)
        z_score = abs(new_time_seconds - typical_time_seconds) / variance_seconds

        if z_score <= self.ANOMALY_Z_SCORE:
            return None

        stamp_label = new_stamp.time.strftime('%H:%M')
        typical_label = pattern.average_time.strftime('%H:%M')
        return {
            'is_anomaly': True,
            'anomaly_type': 'time_deviation',
            'message': f"Unusual time for {new_stamp.stamp_function}: {stamp_label} (typical: {typical_label})",
            'severity': 'high' if z_score > self.HIGH_SEVERITY_Z_SCORE else 'medium',
            'z_score': z_score
        }

    def get_sequence_suggestion(self, last_type: str, current_type: str) -> str:
        """Return the hint for an invalid (last, current) function pair.

        Falls back to a generic message for pairs without a specific hint.
        """
        return self._SEQUENCE_SUGGESTIONS.get(
            (last_type, current_type), "Please check your sequence."
        )

    def get_user_patterns_summary(self, user_id: str) -> Dict:
        """Summarize the patterns stored for a user.

        Args:
            user_id: Value matched against ``UserPattern.user_id``.

        Returns:
            A dict with ``user_id``, ``total_patterns`` and a ``patterns``
            list of per-pattern dicts; or ``{'error': <message>}`` when an
            exception occurs (the exception is logged).
        """
        try:
            # Evaluate once so the count and the listed rows always agree.
            patterns = list(UserPattern.objects.filter(user_id=user_id))

            return {
                'user_id': user_id,
                'total_patterns': len(patterns),
                'patterns': [
                    {
                        'question_type': pattern.question_type,
                        'question_type_display': pattern.get_question_type_display(),
                        'average_time': pattern.average_time.strftime('%H:%M'),
                        'variance_minutes': pattern.variance_minutes,
                        'confidence_score': pattern.confidence_score,
                        'sample_size': pattern.sample_size,
                        'last_updated': pattern.last_updated.isoformat()
                    }
                    for pattern in patterns
                ]
            }

        except Exception as e:
            logger.exception("Error getting user patterns summary: %s", e)
            return {'error': str(e)}

    # Helper methods
    def time_to_seconds(self, t) -> int:
        """Convert a time-like object to seconds since midnight.

        Sub-second precision is ignored.
        """
        return t.hour * 3600 + t.minute * 60 + t.second

    def seconds_to_time(self, seconds: int):
        """Convert seconds since midnight to a ``datetime.time``.

        Values outside a single day (negative, or 86400 and above) wrap
        around midnight instead of raising; fractional seconds are
        truncated.
        """
        from datetime import time

        within_day = int(seconds) % self._SECONDS_PER_DAY
        hours, remainder = divmod(within_day, 3600)
        minutes, secs = divmod(remainder, 60)
        return time(hours, minutes, secs)

    def calculate_confidence(self, sample_size: int, std_deviation: float) -> float:
        """Score a pattern between 0 and 1 from its size and consistency.

        The score is the mean of a consistency term (1 at zero deviation,
        falling linearly to 0 at one hour) and a sample term (reaching 1 at
        20 samples).

        Returns:
            A built-in ``float``, even when ``std_deviation`` is a numpy
            scalar.
        """
        consistency = max(0.0, 1.0 - std_deviation / 3600)
        sample_score = min(1.0, max(0, sample_size) / 20)
        return float((consistency + sample_score) / 2)

    def calculate_pattern_strength(self, times: List) -> str:
        """Classify the spread of a list of times.

        Returns:
            ``'very_strong'`` below 15 minutes of population standard
            deviation, ``'strong'`` below 30 minutes, ``'moderate'`` below
            one hour, otherwise ``'weak'`` (also for an empty list).
        """
        import numpy as np

        if not times:
            # No data means no pattern; avoids numpy's empty-slice warning.
            return 'weak'

        std_deviation = np.std([self.time_to_seconds(t) for t in times])

        if std_deviation < 900:  # Less than 15 minutes
            return 'very_strong'
        if std_deviation < 1800:  # Less than 30 minutes
            return 'strong'
        if std_deviation < 3600:  # Less than 1 hour
            return 'moderate'
        return 'weak'
