import logging

import requests
from django.db import transaction
from django.db.models import Q
from django.shortcuts import get_object_or_404
from rest_framework import status, viewsets
from rest_framework.decorators import api_view, permission_classes, action
from rest_framework.permissions import AllowAny, IsAuthenticated
from rest_framework.response import Response

from .models import Company, CollectiveAgreement, WorkingTimePolicy, VacationPolicy, OnboardingProgress
from .serializers import (
    CompanySerializer, CompanySearchSerializer, CollectiveAgreementSerializer, WorkingTimePolicySerializer,
    VacationPolicySerializer, OnboardingProgressSerializer, OnboardingDataSerializer,
    Step1CompanyProfileSerializer, Step2CollectiveAgreementSerializer,
    Step3WorkingTimePolicySerializer, Step4VacationPolicySerializer
)

@api_view(['GET', 'POST'])
def onboarding_data(request):
    """Get or create onboarding data.

    GET returns the first Company (single-company setup for now) together with
    its related onboarding records: collective agreement, working time policy,
    vacation policy and onboarding progress. Each key is ``None`` when the
    related record is absent, and every key is ``None`` when no company exists
    yet. Unexpected errors are logged server-side and answered with a generic
    500 instead of echoing exception text to the client.

    POST validates ``OnboardingDataSerializer``, creates the Company from its
    ``company`` section plus an OnboardingProgress at ``company_profile``, and
    returns 201 with ``company_id`` (or 400 with the serializer errors).
    """
    if request.method == 'GET':
        # Response keys (in output order) paired with the serializer for the
        # Company attribute of the same name.
        related_serializers = (
            ('collective_agreement', CollectiveAgreementSerializer),
            ('working_time_policy', WorkingTimePolicySerializer),
            ('vacation_policy', VacationPolicySerializer),
            ('onboarding_progress', OnboardingProgressSerializer),
        )
        try:
            company = Company.objects.first()  # For now, get the first company
            data = dict.fromkeys(['company'] + [key for key, _ in related_serializers])
            if company is None:
                # Empty data structure for a brand-new onboarding.
                return Response(data)

            data['company'] = CompanySerializer(company).data
            for attr, serializer_class in related_serializers:
                # Like the previous hasattr() checks: getattr's default also
                # covers the AttributeError raised for a missing related object.
                related = getattr(company, attr, None)
                if related is not None:
                    data[attr] = serializer_class(related).data
            return Response(data)
        except Exception:
            # Keep the details in the server log; don't leak them to clients.
            logging.getLogger(__name__).exception('Failed to load onboarding data')
            return Response({'error': 'Failed to load onboarding data'},
                            status=status.HTTP_500_INTERNAL_SERVER_ERROR)

    # @api_view restricts methods to GET/POST, so this is the POST path.
    serializer = OnboardingDataSerializer(data=request.data)
    if not serializer.is_valid():
        return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST)

    # Create company and progress together so a failure cannot leave a
    # company without its progress record.
    with transaction.atomic():
        company = Company.objects.create(**serializer.validated_data['company'])
        OnboardingProgress.objects.create(company=company, current_step='company_profile')

    return Response({'message': 'Onboarding data created successfully', 'company_id': company.id},
                    status=status.HTTP_201_CREATED)

@api_view(['POST'])
def step1_company_profile(request):
    """Handle Step 1: Company Profile.

    Validates the payload with ``Step1CompanyProfileSerializer``, then updates
    the first existing Company in place or creates a new one. Makes sure the
    company has an OnboardingProgress record (new records start at
    ``collective_agreement``; an existing record is left untouched).

    Returns 200 with ``company_id``, or 400 with the serializer errors.
    """
    serializer = Step1CompanyProfileSerializer(data=request.data)
    if not serializer.is_valid():
        return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST)

    with transaction.atomic():
        company = Company.objects.first()
        if company:
            # Update existing company
            for key, value in serializer.validated_data.items():
                setattr(company, key, value)
            company.save()
        else:
            company = Company.objects.create(**serializer.validated_data)

        # A company created elsewhere (e.g. a default company created at login)
        # may have no progress record; later steps need one.
        OnboardingProgress.objects.get_or_create(
            company=company, defaults={'current_step': 'collective_agreement'}
        )

    return Response({'message': 'Company profile saved successfully', 'company_id': company.id},
                    status=status.HTTP_200_OK)

@api_view(['POST'])
def step2_collective_agreement(request):
    """Handle Step 2: Collective Agreement.

    Creates or updates the first Company's CollectiveAgreement from the
    validated payload and advances onboarding to ``working_time_policy``.

    Returns 200 on success, 404 when no company exists, or 400 with the
    serializer errors.
    """
    company = Company.objects.first()
    if not company:
        return Response({'error': 'Company not found'}, status=status.HTTP_404_NOT_FOUND)

    serializer = Step2CollectiveAgreementSerializer(data=request.data)
    if not serializer.is_valid():
        return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST)

    with transaction.atomic():
        # Update or create collective agreement
        if hasattr(company, 'collective_agreement'):
            collective_agreement = company.collective_agreement
            for key, value in serializer.validated_data.items():
                setattr(collective_agreement, key, value)
            collective_agreement.save()
        else:
            CollectiveAgreement.objects.create(company=company, **serializer.validated_data)

        # The progress record can be missing (company not created via step 1);
        # create it rather than failing with a 500.
        OnboardingProgress.objects.update_or_create(
            company=company, defaults={'current_step': 'working_time_policy'}
        )

    return Response({'message': 'Collective agreement saved successfully'}, status=status.HTTP_200_OK)

@api_view(['POST'])
def step3_working_time_policy(request):
    """Handle Step 3: Working Time Policy.

    Always creates a new WorkingTimePolicy for the first Company, named
    ``'Default Policy'`` unless the payload supplies a name. When the request is
    authenticated, the user is recorded as ``created_by`` and the policy is
    assigned to that user. Onboarding then advances to ``vacation_policy``.

    Returns 200 with ``policy_id``, 404 when no company exists, or 400 with the
    serializer errors.
    """
    company = Company.objects.first()
    if not company:
        return Response({'error': 'Company not found'}, status=status.HTTP_404_NOT_FOUND)

    serializer = Step3WorkingTimePolicySerializer(data=request.data)
    if not serializer.is_valid():
        return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST)

    # AnonymousUser also has an ``id`` attribute (None), so the old
    # hasattr(user, 'id') test let it through and the create() below failed.
    user = request.user if getattr(request.user, 'is_authenticated', False) else None

    policy_data = dict(serializer.validated_data)
    policy_data.setdefault('name', 'Default Policy')
    policy_data['company'] = company
    if user is not None:
        policy_data['created_by'] = user

    with transaction.atomic():
        policy = WorkingTimePolicy.objects.create(**policy_data)

        # Assign the new policy to the creating user, if they exist locally.
        if user is not None:
            from user.models import User
            try:
                policy_user = User.objects.get(id=user.id)
            except User.DoesNotExist:
                policy_user = None
            if policy_user is not None:
                policy_user.working_time_policy = policy
                policy_user.save()

        # The progress record can be missing; create it rather than failing.
        OnboardingProgress.objects.update_or_create(
            company=company, defaults={'current_step': 'vacation_policy'}
        )

    return Response({'message': 'Working time policy saved successfully', 'policy_id': str(policy.id)},
                    status=status.HTTP_200_OK)

@api_view(['POST'])
def step4_vacation_policy(request):
    """Handle Step 4: Vacation Policy.

    Creates or updates the first Company's VacationPolicy from the validated
    payload and marks onboarding as completed (``current_step='completed'``,
    ``is_completed=True``).

    Returns 200 on success, 404 when no company exists, or 400 with the
    serializer errors.
    """
    company = Company.objects.first()
    if not company:
        return Response({'error': 'Company not found'}, status=status.HTTP_404_NOT_FOUND)

    serializer = Step4VacationPolicySerializer(data=request.data)
    if not serializer.is_valid():
        return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST)

    with transaction.atomic():
        # Update or create vacation policy
        if hasattr(company, 'vacation_policy'):
            vacation_policy = company.vacation_policy
            for key, value in serializer.validated_data.items():
                setattr(vacation_policy, key, value)
            vacation_policy.save()
        else:
            VacationPolicy.objects.create(company=company, **serializer.validated_data)

        # The progress record can be missing; create it rather than failing.
        OnboardingProgress.objects.update_or_create(
            company=company,
            defaults={'current_step': 'completed', 'is_completed': True},
        )

    return Response({'message': 'Vacation policy saved successfully. Onboarding completed!'},
                    status=status.HTTP_200_OK)

@api_view(['GET'])
def onboarding_progress(request):
    """Report where the current (first) company is in the onboarding flow.

    Responds 404 when no company exists. When the company has no progress
    record yet, a synthetic "company_profile, not completed" state is returned.
    """
    company = Company.objects.first()
    if company is None:
        return Response({'error': 'Company not found'}, status=status.HTTP_404_NOT_FOUND)

    if not hasattr(company, 'onboarding_progress'):
        default_state = {'current_step': 'company_profile', 'is_completed': False}
        return Response(default_state, status=status.HTTP_200_OK)

    progress_payload = OnboardingProgressSerializer(company.onboarding_progress).data
    return Response(progress_payload)

@api_view(['GET'])
def onboarding_options(request):
    """Return the choice lists that populate the onboarding forms.

    Industries and countries are defined here as (value, label) pairs; company
    sizes, agreement types and vacation accrual types come from the models'
    choice constants.
    """
    industry_choices = [
        ('construction', 'Construction'),
        ('technology', 'Technology'),
        ('healthcare', 'Healthcare'),
        ('retail', 'Retail'),
        ('manufacturing', 'Manufacturing'),
        ('services', 'Services'),
        ('agriculture', 'Agriculture'),
        ('forestry', 'Forestry'),
        ('mining', 'Mining'),
        ('energy', 'Energy'),
        ('transportation', 'Transportation'),
        ('hospitality', 'Hospitality'),
        ('education', 'Education'),
        ('finance', 'Finance'),
        ('real_estate', 'Real Estate'),
        ('other', 'Other'),
    ]
    country_choices = [
        ('FI', 'Finland'),
        ('SE', 'Sweden'),
        ('NO', 'Norway'),
        ('DK', 'Denmark'),
    ]

    options = {
        'industries': industry_choices,
        'company_sizes': Company.COMPANY_SIZE_CHOICES,
        'agreement_types': CollectiveAgreement.AGREEMENT_TYPE_CHOICES,
        'vacation_accrual_types': VacationPolicy.VACATION_ACCRUAL_CHOICES,
        'countries': country_choices,
    }
    return Response(options)

@api_view(['GET'])
@permission_classes([AllowAny])
def search_companies(request):
    """Case-insensitive company name search (public endpoint).

    Reads the ``q`` query parameter; a blank query yields an empty list.
    At most 10 matches are returned, ordered by name.
    """
    search_term = request.GET.get('q', '').strip()

    if search_term:
        matches = Company.objects.filter(name__icontains=search_term).order_by('name')[:10]
        payload = CompanySearchSerializer(matches, many=True).data
    else:
        payload = []

    return Response({'companies': payload}, status=status.HTTP_200_OK)

# Hosts always treated as "this server" when local-endpoint detection is
# enabled. '192.168.1.104' is a development address carried over from the
# original pattern list.
_LOCAL_ENDPOINT_HOSTS = frozenset({'0.0.0.0', '127.0.0.1', '::1', 'localhost', '192.168.1.104'})


def _invalid_credentials_response():
    """Return the generic 401 used for every failed credential check."""
    return Response(
        {'error': 'Invalid email or password'},
        status=status.HTTP_401_UNAUTHORIZED
    )


def _endpoint_hostname(endpoint):
    """Return the lower-cased hostname of a URL or bare host string ('' if none)."""
    from urllib.parse import urlsplit

    value = (endpoint or '').strip()
    if '//' not in value:
        # urlsplit only recognises a host after '//'; accept "host:port/path" too.
        value = '//' + value
    try:
        return (urlsplit(value).hostname or '').lower()
    except ValueError:
        return ''


def _is_local_endpoint(request, endpoint):
    """Decide whether *endpoint* is authenticated against the local User table.

    The previous pattern list contained the endpoint itself, so every configured
    endpoint matched and the external-API flow never ran. That behaviour is kept
    as the default (settings.COMPANY_API_FORCE_LOCAL_AUTH unset or True). With
    the setting False, only loopback/development hosts and this server's own
    request host count as local, compared by hostname rather than by substring.
    """
    from django.conf import settings
    from django.core.exceptions import DisallowedHost

    if not endpoint:
        return False
    if getattr(settings, 'COMPANY_API_FORCE_LOCAL_AUTH', True):
        return True

    hostname = _endpoint_hostname(endpoint)
    if not hostname:
        return False
    if hostname in _LOCAL_ENDPOINT_HOSTS:
        return True
    try:
        own_hostname = _endpoint_hostname(request.get_host())
    except DisallowedHost:
        own_hostname = ''
    return bool(own_hostname) and hostname == own_hostname


def _authenticate_local_user(email, password):
    """Return the local User whose email/password match, or None."""
    from django.contrib.auth.hashers import check_password, make_password
    from user.models import User

    try:
        user = User.objects.get(email=email)
    except User.DoesNotExist:
        # Hash anyway so unknown emails cost about as much time as wrong
        # passwords, making account enumeration by timing harder.
        make_password(password)
        return None
    if not check_password(password, user.password):
        return None
    return user


def _consume_temporary_password(user):
    """Invalidate a one-time temporary password; return the must-change flag.

    The flag is captured before any change. When set, the stored hash is
    replaced by one for a random secret, so the temporary password works exactly
    once and the user must go through the forced password-change screen.
    """
    must_change_password = bool(user.must_change_password)
    if must_change_password:
        import secrets
        from django.contrib.auth.hashers import make_password

        user.password = make_password(secrets.token_urlsafe(32))
        user.save()
        # Refresh user from database to ensure we have latest state
        user.refresh_from_db()
    return must_change_password


def _login_success_response(request, user, company, must_change_password):
    """Issue JWTs for *user* and build the 200 login response plus auth cookies."""
    from rest_framework_simplejwt.tokens import RefreshToken
    from user.serializers import UserReadSerializer

    refresh = RefreshToken.for_user(user)
    access_token = str(refresh.access_token)
    refresh_token = str(refresh)

    # Serialized after any password invalidation so it reflects the saved state.
    user_data = UserReadSerializer(user).data
    user_data['must_change_password'] = must_change_password

    data = {
        'success': True,
        'user': user_data,
        'company': CompanySearchSerializer(company).data,
        'api_endpoint': company.api_endpoint or '',
        'access': access_token,
        'refresh': refresh_token,
        'must_change_password': must_change_password,  # Also at top level for easy access
    }

    resp = Response(data, status=status.HTTP_200_OK)
    # Cookies remain readable by the frontend (not HttpOnly), as before; they
    # are marked Secure whenever the request itself arrived over HTTPS.
    secure = request.is_secure()
    resp.set_cookie('access_token', access_token, samesite='Lax', secure=secure)
    resp.set_cookie('refresh_token', refresh_token, samesite='Lax', secure=secure)
    return resp


def _login_without_company(request, email, password):
    """Web-frontend login: local credentials plus supervisor rights required."""
    from supervisorgroup.models import SupervisorGroup

    user = _authenticate_local_user(email, password)
    if user is None:
        return _invalid_credentials_response()

    if not SupervisorGroup.objects.filter(user_id=user.id).exists():
        return Response(
            {'error': 'Access not allowed. Supervisor rights required.'},
            status=status.HTTP_403_FORBIDDEN
        )

    must_change_password = _consume_temporary_password(user)

    # Get or create a default company
    company = Company.objects.first()
    if not company:
        company = Company.objects.create(
            name='Default Company',
            industry='other',
            company_size='1-5',
            primary_location='Unknown',
            api_endpoint=''
        )
    return _login_success_response(request, user, company, must_change_password)


def _login_via_company_api(request, company, email, password):
    """Authenticate through the company's external API and mirror the user locally.

    Only a 200 response whose body is a JSON object counts as success. Its
    optional ``firstname``/``lastname`` keys update the local User, which is
    created when missing. The submitted password is stored locally as a hash.
    """
    from django.contrib.auth.hashers import make_password
    from user.models import User

    try:
        auth_response = requests.post(
            company.api_endpoint,
            json={'email': email, 'password': password},
            timeout=10,
            headers={'Content-Type': 'application/json'}
        )
    except requests.exceptions.RequestException as e:
        return Response(
            {'error': f'Failed to connect to company API: {str(e)}'},
            status=status.HTTP_503_SERVICE_UNAVAILABLE
        )

    if auth_response.status_code != 200:
        return _invalid_credentials_response()

    try:
        auth_data = auth_response.json()
    except ValueError:  # also covers requests' JSONDecodeError
        auth_data = None
    if not isinstance(auth_data, dict):
        # Never treat an unparseable/unexpected body as a successful login.
        return Response(
            {'error': 'Invalid response from company API'},
            status=status.HTTP_502_BAD_GATEWAY
        )

    try:
        user = User.objects.get(email=email)
    except User.DoesNotExist:
        # Company API authenticated a user we don't know yet: create it.
        user = User.objects.create(
            email=email,
            password=make_password(password),
            firstname=auth_data.get('firstname', ''),
            lastname=auth_data.get('lastname', '')
        )
    else:
        user.password = make_password(password)
        if 'firstname' in auth_data:
            user.firstname = auth_data.get('firstname', user.firstname)
        if 'lastname' in auth_data:
            user.lastname = auth_data.get('lastname', user.lastname)
        user.save()

    return _login_success_response(request, user, company, bool(user.must_change_password))


@api_view(['POST'])
@permission_classes([AllowAny])
def login(request):
    """Login user using company API endpoint. Company ID is optional - if not provided, authenticates directly from local User database.

    Without ``company_id`` (the web frontend): local credentials are checked and
    supervisor rights are required (403 otherwise); the first Company is used,
    and a default one is created when none exists.

    With ``company_id``: 404 for an unknown/invalid id and 400 when the company
    has no API endpoint. Endpoints considered local (see _is_local_endpoint) use
    the local User table; others are authenticated through the company API
    (503 when it is unreachable, 502 when its reply is not a JSON object).

    On success returns 200 with user/company data, ``api_endpoint``, JWT
    ``access``/``refresh`` tokens (also set as cookies) and
    ``must_change_password``. Bad credentials give 401; missing email or
    password gives 400.
    """
    company_id = request.data.get('company_id')
    email = request.data.get('email')
    password = request.data.get('password')

    if not email or not password:
        return Response(
            {'error': 'email and password are required'},
            status=status.HTTP_400_BAD_REQUEST
        )

    if not company_id:
        return _login_without_company(request, email, password)

    from django.core.exceptions import ValidationError as DjangoValidationError

    try:
        company = Company.objects.get(id=company_id)
    except (Company.DoesNotExist, ValueError, DjangoValidationError):
        # Malformed ids (e.g. non-numeric / non-UUID) are "not found", not a 500.
        return Response(
            {'error': 'Company not found'},
            status=status.HTTP_404_NOT_FOUND
        )

    if not company.api_endpoint:
        return Response(
            {'error': 'Company API endpoint not configured'},
            status=status.HTTP_400_BAD_REQUEST
        )

    if _is_local_endpoint(request, company.api_endpoint):
        user = _authenticate_local_user(email, password)
        if user is None:
            return _invalid_credentials_response()
        must_change_password = _consume_temporary_password(user)
        return _login_success_response(request, user, company, must_change_password)

    return _login_via_company_api(request, company, email, password)


class WorkingTimePolicyViewSet(viewsets.ModelViewSet):
    """ViewSet for managing Working Time Policies.

    Policies are scoped to the requesting user's ``company_id``. When the user
    has none, the first Company is used for backward compatibility.

    ``create`` and ``update`` also accept the company's vacation-policy fields
    (``VACATION_FIELDS``) in the same payload. They are removed before the
    working-time-policy serializer runs and are then written, best-effort, to
    the company's VacationPolicy. Failures there are logged and do not fail the
    working-time-policy request.
    """
    serializer_class = WorkingTimePolicySerializer
    permission_classes = [IsAuthenticated]

    # Payload keys that belong to the company's VacationPolicy rather than to
    # the WorkingTimePolicy being created/updated.
    VACATION_FIELDS = (
        'vacation_accrual_type', 'custom_accrual_rate', 'vacation_year_start',
        'vacation_period_start', 'vacation_period_end', 'track_vacation_days',
        'track_flex_time', 'track_time_bank', 'track_overtime_balance', 'track_toil',
        'self_certification_days', 'sick_pay_percentage',
        'include_finnish_holidays', 'include_swedish_holidays',
    )

    def _resolve_company_id(self):
        """Return the user's company id, else the first Company's id, else None."""
        company_id = getattr(self.request.user, 'company_id', None)
        if not company_id:
            # Fallback: first company (backward compatibility)
            company = Company.objects.first()
            if company:
                company_id = company.id
        return company_id

    def get_queryset(self):
        """Filter policies by the resolved company, newest first."""
        company_id = self._resolve_company_id()
        if not company_id:
            return WorkingTimePolicy.objects.none()
        return WorkingTimePolicy.objects.filter(company_id=company_id).order_by('-created_at')

    def get_serializer_context(self):
        """Add request to serializer context"""
        context = super().get_serializer_context()
        context['request'] = self.request
        return context

    def perform_create(self, serializer):
        """Save under the resolved company with the requesting user as ``created_by``.

        Raises:
            ValidationError: when no company can be resolved.
        """
        user = self.request.user
        company_id = self._resolve_company_id()
        if not company_id:
            from rest_framework.exceptions import ValidationError
            raise ValidationError({'company': 'No company found. Please ensure a company exists in the system.'})
        serializer.save(company_id=company_id, created_by=user if hasattr(user, 'id') else None)

    def _split_vacation_data(self, data):
        """Split request data into (working-time-policy data, vacation-policy data).

        Works on a mutable copy so ``request.data`` itself is not modified.
        """
        if hasattr(data, 'copy'):
            remaining = data.copy()
        else:
            remaining = dict(data) if isinstance(data, dict) else {}

        vacation_data = {}
        if hasattr(remaining, 'keys'):  # skip non-mapping payloads (e.g. a JSON list)
            for field in self.VACATION_FIELDS:
                if field in remaining:
                    # Index + del instead of pop(): QueryDict.pop() (form-encoded
                    # requests) returns a *list* of values, whereas indexing
                    # returns the single value for both dicts and QueryDicts.
                    vacation_data[field] = remaining[field]
                    del remaining[field]
        return remaining, vacation_data

    def _save_vacation_policy(self, vacation_data, action_label):
        """Best-effort create/update of the company's VacationPolicy from *vacation_data*.

        Any error is logged with its traceback and swallowed so the working time
        policy request still succeeds. ``action_label`` is 'creating' or 'updating'.
        """
        if not vacation_data:
            return
        try:
            company_id = self._resolve_company_id()
            if not company_id:
                return
            company = Company.objects.get(id=company_id)
            # Update or create vacation policy
            if hasattr(company, 'vacation_policy'):
                vacation_policy = company.vacation_policy
                for key, value in vacation_data.items():
                    setattr(vacation_policy, key, value)
                vacation_policy.save()
            else:
                VacationPolicy.objects.create(company=company, **vacation_data)
        except Exception as e:
            logging.getLogger(__name__).exception('Error %s vacation policy: %s', action_label, e)

    def create(self, request, *args, **kwargs):
        """Create a working time policy, then apply any vacation policy fields."""
        request_data, vacation_data = self._split_vacation_data(request.data)

        serializer = self.get_serializer(data=request_data)
        serializer.is_valid(raise_exception=True)
        self.perform_create(serializer)
        headers = self.get_success_headers(serializer.data)
        response = Response(serializer.data, status=status.HTTP_201_CREATED, headers=headers)

        self._save_vacation_policy(vacation_data, 'creating')
        return response

    def update(self, request, *args, **kwargs):
        """Update a working time policy, then apply any vacation policy fields."""
        request_data, vacation_data = self._split_vacation_data(request.data)

        partial = kwargs.pop('partial', False)
        instance = self.get_object()
        serializer = self.get_serializer(instance, data=request_data, partial=partial)
        serializer.is_valid(raise_exception=True)
        self.perform_update(serializer)

        if getattr(instance, '_prefetched_objects_cache', None):
            # Same as DRF's UpdateModelMixin: drop stale prefetched relations.
            instance._prefetched_objects_cache = {}

        response = Response(serializer.data)

        self._save_vacation_policy(vacation_data, 'updating')
        return response

    @action(detail=True, methods=['get'])
    def assigned_users(self, request, pk=None):
        """Get list of users assigned to this policy"""
        policy = self.get_object()
        from user.serializers import UserReadSerializer
        users = policy.assigned_users.all()
        serializer = UserReadSerializer(users, many=True)
        return Response(serializer.data)