"""Custom JWT authentication to handle UUID user IDs"""
import uuid

from rest_framework_simplejwt.authentication import JWTAuthentication
from rest_framework_simplejwt.exceptions import AuthenticationFailed, InvalidToken, TokenError
from rest_framework_simplejwt.serializers import TokenRefreshSerializer
from rest_framework_simplejwt.settings import api_settings
from rest_framework_simplejwt.tokens import RefreshToken, AccessToken

from user.models import User


class UUIDJWTAuthentication(JWTAuthentication):
    """JWT authentication that resolves the token's user by UUID primary key.

    A string ``user_id`` claim is parsed into a :class:`uuid.UUID` before the
    ``User`` lookup.
    """

    def get_user(self, validated_token):
        """Return the ``User`` identified by the token's ``user_id`` claim.

        Args:
            validated_token: Token object whose claims are read via ``get``.

        Returns:
            The ``User`` whose ``id`` matches the claim.

        Raises:
            InvalidToken: If the claim is missing or empty, is a string that
                is not a valid UUID, no user has that id, or the lookup fails
                for any other reason.
            AuthenticationFailed: If the user is inactive (skipped when the
                ``CHECK_USER_IS_ACTIVE`` setting exists and is false).
        """
        try:
            user_id = validated_token.get('user_id')
            if not user_id:
                raise InvalidToken('Token contained no recognizable user identification')

            # Convert to UUID if it's a string
            if isinstance(user_id, str):
                try:
                    user_id = uuid.UUID(user_id)
                except (ValueError, AttributeError):
                    raise InvalidToken('Token contained invalid user identification')

            try:
                user = User.objects.get(id=user_id)
            except User.DoesNotExist:
                raise InvalidToken('User not found')
        except (TokenError, AuthenticationFailed):
            # InvalidToken derives from AuthenticationFailed, not TokenError.
            # Without this clause the specific errors raised above would fall
            # through to the generic handler below and be re-wrapped, nesting
            # their detail inside a second, misleading message.
            raise
        except Exception as e:
            # Anything else (e.g. a claim value the UUID field rejects) is
            # reported as an invalid token rather than a server error.
            raise InvalidToken(f'Token contained invalid user identification: {str(e)}') from e

        # The stock JWTAuthentication.get_user rejects inactive accounts; an
        # override that omits this lets a deactivated user keep authenticating
        # until the token expires. getattr keeps this working on library
        # versions without CHECK_USER_IS_ACTIVE and on user models without an
        # is_active attribute.
        check_active = getattr(api_settings, 'CHECK_USER_IS_ACTIVE', True)
        if check_active and not getattr(user, 'is_active', True):
            raise AuthenticationFailed('User is inactive', code='user_inactive')

        return user


class UUIDRefreshToken(RefreshToken):
    """Refresh token whose derived access token is built from a UUID user lookup."""

    @classmethod
    def for_user(cls, user):
        """Create a refresh token for ``user`` by delegating to the parent class."""
        return super().for_user(user)

    @property
    def access_token(self):
        """Build an access token for the user named in this refresh token.

        The ``user_id`` claim is read from this token, parsed as a UUID when
        it is a string, and used to load the ``User``. A new access token is
        created with ``AccessToken.for_user`` and every claim of this token
        except ``token_type``, ``exp``, ``iat`` and ``jti`` is copied onto it.

        Raises:
            InvalidToken: If the claim is missing or empty, is a string that
                is not a valid UUID, or no user has that id.
        """
        claimed_id = self.get('user_id')
        if not claimed_id:
            raise InvalidToken('Token contained no recognizable user identification')

        # Only string claims need parsing; other values go to the ORM as-is.
        lookup_id = claimed_id
        if isinstance(claimed_id, str):
            try:
                lookup_id = uuid.UUID(claimed_id)
            except (ValueError, AttributeError):
                raise InvalidToken('Token contained invalid user identification')

        try:
            owner = User.objects.get(id=lookup_id)
        except User.DoesNotExist:
            raise InvalidToken('User not found')

        new_access = AccessToken.for_user(owner)

        # These claims belong to each individual token, so the values that
        # AccessToken.for_user produced are kept instead of the refresh token's.
        per_token_claims = ('token_type', 'exp', 'iat', 'jti')
        for name, value in self.payload.items():
            if name in per_token_claims:
                continue
            new_access[name] = value

        return new_access


class UUIDTokenRefreshSerializer(TokenRefreshSerializer):
    """Token refresh serializer that properly handles UUID user IDs"""
    token_class = UUIDRefreshToken

    def validate(self, attrs):
        """Validate a refresh token and return a rotated refresh/access pair.

        Args:
            attrs: Serializer input; ``attrs["refresh"]`` is the encoded
                refresh token.

        Returns:
            A dict with ``"refresh"`` (the same token re-issued with new
            ``jti``, ``exp`` and ``iat``) and ``"access"`` (a new access
            token), both encoded as strings.

        Raises:
            InvalidToken: If the token fails verification, its ``user_id``
                claim is missing or not a valid UUID string, or no user has
                that id.
        """
        # Constructing the token already decodes and verifies it, so the
        # construction itself has to sit inside the try; otherwise a bad token
        # raises TokenError before any handler is in place.
        try:
            refresh = self.token_class(attrs["refresh"])
        except TokenError as e:
            raise InvalidToken(str(e)) from e

        # Reject tokens without a resolvable user before anything is
        # blacklisted or rotated.
        user_id = refresh.get('user_id')
        if not user_id:
            raise InvalidToken('Token contained no recognizable user identification')
        if isinstance(user_id, str):
            try:
                user_id = uuid.UUID(user_id)
            except (ValueError, AttributeError):
                raise InvalidToken('Token contained invalid user identification')
        # Only existence matters here, so avoid loading the row.
        if not User.objects.filter(id=user_id).exists():
            raise InvalidToken('User not found')

        # Mirror the stock serializer: when the project opted into rotation
        # with blacklisting, retire the presented token so it cannot be
        # replayed after a new one has been handed out.
        if api_settings.ROTATE_REFRESH_TOKENS and api_settings.BLACKLIST_AFTER_ROTATION:
            try:
                refresh.blacklist()
            except AttributeError:
                # The blacklist app is not installed, so tokens have no
                # ``blacklist`` method; nothing to do.
                pass

        # Re-issue the refresh token with a fresh identity and lifetime.
        refresh.set_jti()
        refresh.set_exp()
        refresh.set_iat()

        return {
            "refresh": str(refresh),
            "access": str(refresh.access_token),
        }


def get_uuid_token_refresh_view():
    """Build and return the UUID-aware token refresh view class.

    The simplejwt views module is imported inside the function, not at module
    import time, to avoid a circular import. Each call defines and returns a
    new ``UUIDTokenRefreshView`` class.
    """
    from rest_framework_simplejwt import views as simplejwt_views

    class UUIDTokenRefreshView(simplejwt_views.TokenRefreshView):
        """Token refresh view that uses UUID-aware serializer"""
        serializer_class = UUIDTokenRefreshSerializer

    return UUIDTokenRefreshView
