"""
Burnout Detection Service
Uses ML and pattern analysis to detect burnout risk
"""

from datetime import datetime, timedelta
from django.utils import timezone
from stamps.models import Stamp
from worktimeservice.models import WorkBalance
from functions.models import Function
from company.models import WorkingTimePolicy
from userSettings.models import UserSettings


class BurnoutDetector:
    """Score a user's burnout risk from their recent stamps and work balances.

    The score is a rule-based sum of weighted risk factors (long hours,
    overtime, missing breaks, long working streaks, weekend work, short
    breaks), capped at 100, accompanied by balance-based recommendations.
    """

    def __init__(self, user_id):
        self.user_id = user_id
        # Memo of Function id -> question_type, so each stamp function is fetched
        # at most once per detector instead of once per stamp in every pass.
        self._question_type_cache = {}

    def analyze_user(self, days=30):
        """
        Analyze user for burnout risk over the last ``days`` calendar days.

        The window ends today (``timezone.now().date()``), includes it, and
        covers exactly ``days`` dates.

        Returns:
            None when the user has no UserSettings, has burnout notifications
            disabled, or no longer exists; otherwise dict: {
                'risk_score': float (0-100),
                'severity': str ('low', 'medium', 'high', 'critical'),
                'risk_factors': dict,
                'recommendations': list
            }
        """
        try:
            user_settings = UserSettings.objects.get(user__id=self.user_id)
        except UserSettings.DoesNotExist:
            return None
        if not user_settings.burnout_notifications_enabled:
            return None

        # Both bounds are inclusive downstream (date__gte / date__lte), so step
        # back days - 1 to get exactly `days` dates rather than days + 1.
        end_date = timezone.now().date()
        start_date = end_date - timedelta(days=days - 1)

        features = self._extract_features(start_date, end_date)
        if features is None:
            # No User row: there is nothing to score (and _calculate_risk
            # cannot index into None).
            return None

        risk_analysis = self._calculate_risk(features)
        recommendations = self._get_recommendations(risk_analysis['risk_score'])

        return {
            'risk_score': risk_analysis['risk_score'],
            'severity': risk_analysis['severity'],
            'risk_factors': risk_analysis['risk_factors'],
            'recommendations': recommendations,
        }

    def _extract_features(self, start_date, end_date):
        """Aggregate the features consumed by ``_calculate_risk``.

        Args:
            start_date: First date of the window (inclusive).
            end_date: Last date of the window (inclusive).

        Returns:
            dict of feature name -> number, or None if no User with
            ``self.user_id`` exists.
        """
        from user.models import User

        try:
            user = User.objects.get(id=self.user_id)
        except User.DoesNotExist:
            return None

        stamps = Stamp.objects.filter(
            user=user,
            date__gte=start_date,
            date__lte=end_date,
        ).order_by('date', 'time')

        balances = WorkBalance.objects.filter(
            user_id=self.user_id,
            date__gte=start_date,
            date__lte=end_date,
        )

        # Group stamps in memory once instead of issuing one query per balance date.
        stamps_by_date = self._group_stamps_by_date(stamps)
        daily_features = [
            self._analyze_day(balance, stamps_by_date.get(balance.date, []))
            for balance in balances
        ]
        day_count = len(daily_features)

        def daily_mean(key):
            # Mean of one per-day metric over days that have a balance row.
            return sum(day[key] for day in daily_features) / day_count if day_count else 0

        return {
            'avg_daily_hours': daily_mean('total_hours'),
            'avg_overtime_hours': daily_mean('overtime_hours'),
            'avg_break_duration': self._calculate_avg_break_duration(stamps),
            'break_compliance': self._calculate_break_compliance(stamps),
            'consecutive_days': self._calculate_consecutive_days(balances),
            'weekend_work_frequency': self._calculate_weekend_work_frequency(stamps),
            'total_break_time': self._calculate_total_break_time(stamps),
        }

    def _analyze_day(self, balance, stamps):
        """Summarize one day's balance and breaks.

        Args:
            balance: The day's WorkBalance; falsy ``*_seconds`` fields count as 0.
            stamps: The day's stamps (any iterable).

        Returns:
            dict with ``total_hours``, ``overtime_hours``, ``flex_hours`` (hours),
            ``break_duration`` (minutes) and ``num_breaks``.
        """
        total_hours = balance.total_work_seconds / 3600 if balance.total_work_seconds else 0
        overtime_hours = balance.overtime_seconds / 3600 if balance.overtime_seconds else 0

        breaks = self._get_break_durations_for_day(stamps)

        return {
            'total_hours': total_hours,
            'overtime_hours': overtime_hours,
            'flex_hours': balance.flex_seconds / 3600 if balance.flex_seconds else 0,
            'break_duration': sum(b['duration'] for b in breaks),
            'num_breaks': len(breaks),
        }

    def _get_break_durations_for_day(self, stamps):
        """Pair one day's out/in stamps into short breaks and lunch breaks.

        A stamp's kind is the ``question_type`` of the Function referenced by its
        ``stamp_function`` ('break_out', 'break_in', 'lunch_out', 'lunch_in');
        stamps of any other (or unresolvable) kind are ignored.

        Returns:
            list of dicts with keys ``type`` ('short_break' / 'lunch_break'),
            ``duration`` (minutes), ``out_time`` and ``in_time``. All short
            breaks come first, then lunch breaks, each in chronological order.
        """
        stamps_by_kind = {kind: [] for kind in ('break_out', 'break_in', 'lunch_out', 'lunch_in')}
        for stamp in stamps:
            kind = self._question_type(stamp.stamp_function)
            if kind in stamps_by_kind:
                stamps_by_kind[kind].append(stamp)

        return (
            self._pair_breaks(stamps_by_kind['break_out'], stamps_by_kind['break_in'], 'short_break')
            + self._pair_breaks(stamps_by_kind['lunch_out'], stamps_by_kind['lunch_in'], 'lunch_break')
        )

    def _question_type(self, function_id):
        """Return the ``question_type`` of Function ``function_id`` (memoized).

        Returns None when the Function does not exist or the id is malformed.
        """
        if function_id not in self._question_type_cache:
            try:
                question_type = Function.objects.get(id=function_id).question_type
            except (Function.DoesNotExist, ValueError, TypeError):
                # Missing or malformed function id: treat the stamp as an unknown kind.
                question_type = None
            self._question_type_cache[function_id] = question_type
        return self._question_type_cache[function_id]

    @classmethod
    def _pair_breaks(cls, out_stamps, in_stamps, break_type):
        """Match out-stamps with in-stamps chronologically into breaks of ``break_type``.

        Each in-stamp closes the most recent still-open out-stamp strictly before
        it, and closes at most one break. An out-stamp followed by another
        out-stamp (i.e. a missing in-stamp) is discarded, as is an in-stamp with
        no open out-stamp, so a single in-stamp can never be double-counted.
        """
        # Sort key (time, is_out): at equal times the in-stamp is processed first,
        # so an out and an in with the same time never pair (a break needs in > out).
        events = sorted(
            [(stamp, True) for stamp in out_stamps] + [(stamp, False) for stamp in in_stamps],
            key=lambda event: (event[0].time, event[1]),
        )

        breaks = []
        open_out = None
        for stamp, is_out in events:
            if is_out:
                open_out = stamp
            elif open_out is not None:
                breaks.append({
                    'type': break_type,
                    'duration': cls._minutes_between(open_out.time, stamp.time),
                    'out_time': open_out.time,
                    'in_time': stamp.time,
                })
                open_out = None
        return breaks

    @staticmethod
    def _minutes_between(start, end):
        """Return the minutes elapsed from ``start`` to ``end``.

        Accepts datetimes or plain ``datetime.time`` values; the latter do not
        support subtraction, so they are anchored on a common date first.
        """
        from datetime import date, time

        if isinstance(start, time) and isinstance(end, time):
            start = datetime.combine(date.min, start)
            end = datetime.combine(date.min, end)
        return (end - start).total_seconds() / 60

    @staticmethod
    def _group_stamps_by_date(stamps):
        """Return ``{date: [stamp, ...]}`` preserving the iteration order of ``stamps``."""
        grouped = {}
        for stamp in stamps:
            grouped.setdefault(stamp.date, []).append(stamp)
        return grouped

    def _daily_breaks(self, stamps):
        """Return one list of breaks (see ``_get_break_durations_for_day``) per stamped date."""
        # Grouping in Python avoids values_list('date').distinct() on a queryset
        # ordered by ('date', 'time'): Django adds the ordering columns to the
        # SELECT DISTINCT, which yields one row per (date, time) and repeats dates.
        return [
            self._get_break_durations_for_day(day_stamps)
            for day_stamps in self._group_stamps_by_date(stamps).values()
        ]

    def _calculate_avg_break_duration(self, stamps):
        """Return the mean break duration in minutes over all breaks, or 0 if none."""
        durations = [b['duration'] for breaks in self._daily_breaks(stamps) for b in breaks]
        return sum(durations) / len(durations) if durations else 0

    def _calculate_break_compliance(self, stamps):
        """Return the fraction of stamped days that have both a lunch and a short break.

        Returns 1.0 when no WorkingTimePolicy is found for the user's company or
        when there are no stamped days. Only the policy's existence matters; its
        contents are not consulted.
        """
        try:
            policy = WorkingTimePolicy.objects.filter(
                company__user__id=self.user_id
            ).order_by('-created_at').first()
        except Exception:
            # Best effort: a failing lookup falls back to "no policy", but is
            # logged rather than silently swallowed.
            import logging
            logging.getLogger(__name__).warning(
                "Working time policy lookup failed for user %s", self.user_id, exc_info=True
            )
            policy = None

        if not policy:
            return 1.0  # No policy, assume compliance

        daily_breaks = self._daily_breaks(stamps)
        if not daily_breaks:
            return 1.0

        compliant_days = sum(
            1
            for breaks in daily_breaks
            if any(b['type'] == 'lunch_break' for b in breaks)
            and any(b['type'] == 'short_break' for b in breaks)
        )
        return compliant_days / len(daily_breaks)

    def _calculate_total_break_time(self, stamps):
        """Return the total break time in minutes across all stamped days."""
        return sum(b['duration'] for breaks in self._daily_breaks(stamps) for b in breaks)

    def _calculate_consecutive_days(self, balances):
        """Return the longest run of consecutive calendar dates among ``balances``.

        Duplicate dates are collapsed so they do not break a streak; returns 0
        when there are no balances.
        """
        dates = sorted({b.date for b in balances})
        if not dates:
            return 0

        longest = current = 1
        for previous, following in zip(dates, dates[1:]):
            current = current + 1 if (following - previous).days == 1 else 1
            longest = max(longest, current)
        return longest

    def _calculate_weekend_work_frequency(self, stamps):
        """Return the number of distinct weekend dates with at least one stamp."""
        return (
            # Django's week_day lookup: 1 = Sunday ... 7 = Saturday.
            stamps.filter(date__week_day__in=[1, 7])
            # Clear ordering so DISTINCT applies to the date column alone.
            .order_by()
            .values_list('date', flat=True)
            .distinct()
            .count()
        )

    def _calculate_risk(self, features):
        """Score burnout risk from the feature dict built by ``_extract_features``.

        Returns:
            dict with ``risk_score`` (0-100), ``severity`` derived from that
            score, and ``risk_factors`` (factor name -> points, triggered only).
        """
        risk_factors = {}

        # Factor 1: Excessive Hours (30 points max).
        # NOTE(review): fires above 9h but scores from an 8h baseline, so it
        # starts at >10 points when it first triggers — confirm intent.
        if features['avg_daily_hours'] > 9:
            risk_factors['excessive_hours'] = min(30, (features['avg_daily_hours'] - 8) * 10)

        # Factor 2: High Overtime (20 points max).
        # NOTE(review): with the > 1h trigger, x * 20 always exceeds 20, so this
        # factor is effectively all-or-nothing.
        if features['avg_overtime_hours'] > 1:
            risk_factors['overtime_frequency'] = min(20, features['avg_overtime_hours'] * 20)

        # Factor 3: Poor Break Compliance (20 points max).
        if features['break_compliance'] < 0.5:
            risk_factors['poor_break_compliance'] = (1 - features['break_compliance']) * 20

        # Factor 4: Many Consecutive Days (15 points max).
        if features['consecutive_days'] > 10:
            risk_factors['consecutive_days'] = min(15, (features['consecutive_days'] - 10) * 3)

        # Factor 5: Weekend Work (10 points max).
        # NOTE(review): > 2 weekend days means at least 15 raw points, so this is
        # always the 10-point cap when triggered.
        if features['weekend_work_frequency'] > 2:
            risk_factors['weekend_work'] = min(10, features['weekend_work_frequency'] * 5)

        # Factor 6: Insufficient Total Breaks (at most 4.5 points, at 0 minutes).
        if features['avg_break_duration'] < 45:  # Less than 45 minutes average
            risk_factors['insufficient_breaks'] = (45 - features['avg_break_duration']) / 10

        # Clamp first so severity always agrees with the reported score.
        risk_score = min(100, sum(risk_factors.values()))

        if risk_score < 40:
            severity = 'low'
        elif risk_score < 70:
            severity = 'medium'
        elif risk_score < 85:
            severity = 'high'
        else:
            severity = 'critical'

        return {
            'risk_score': risk_score,
            'severity': severity,
            'risk_factors': risk_factors,
        }

    def _get_recommendations(self, risk_score):
        """Build balance-based recommendations for the given risk score.

        Uses the user's most recent WorkBalance; returns an empty list when the
        user has none. Missing (None) second counts are treated as 0.
        """
        recommendations = []

        latest_balance = WorkBalance.objects.filter(
            user_id=self.user_id
        ).order_by('-date').first()

        if latest_balance:
            flex_seconds = latest_balance.flex_seconds or 0
            overtime_seconds = latest_balance.overtime_seconds or 0

            # Check flex time
            if flex_seconds > 0:
                recommendations.append({
                    'balance_type': 'flex_time',
                    'message': f'Use {flex_seconds / 3600:.1f}h flex time to reduce workload',
                    'priority': 'high' if risk_score > 60 else 'medium',
                    'reason': 'Available balance to offset excess hours',
                    'action': 'Use flex time to finish early or start late',
                })

            # Check overtime balance (comp time)
            if overtime_seconds > 0:
                recommendations.append({
                    'balance_type': 'comp_time',
                    'message': f'Take {overtime_seconds / 3600:.1f}h comp time to recharge',
                    'priority': 'high' if risk_score > 70 else 'medium',
                    'reason': 'Available comp time from overtime',
                    'action': 'Schedule a day off using comp time',
                })

            # Suggest vacation for critical risk
            if risk_score > 80:
                recommendations.append({
                    'balance_type': 'vacation',
                    'message': 'Take vacation days to prevent burnout',
                    'priority': 'critical',
                    'reason': 'High burnout risk + no recovery days',
                    'action': 'Plan vacation to recover',
                })

        return recommendations

