# worktime/tasks.py
from celery import shared_task

from stamps.models import Stamp
from .models import WorkBalance
from .service import calculate_daily_balances_for_user_and_date
from datetime import date, timedelta, datetime
from django.core.cache import cache

# Cache key used as a mutex so only one schedule_catch_up_balances run
# enqueues work at a time.
LOCK_ID = 'schedule_catch_up_balances_lock'
# Lock time-to-live in seconds (6 minutes). The orchestrator deletes the key
# itself once enqueuing is done, so this TTL only matters if a run dies before
# releasing the lock.
# NOTE(review): the original note said this must exceed the periodic schedule
# interval — confirm against the beat schedule configuration.
LOCK_EXPIRE = 6 * 60

@shared_task(bind=True, max_retries=3, default_retry_delay=5)
def calculate_balance_task(self, user_id, date_iso):
    """
    Compute the daily work balance for one user on one date.

    Args:
        user_id: Identifier of the user whose balance is computed.
        date_iso: Target date as a string in 'YYYY-MM-DD' format.

    Returns:
        dict: ``{'status': 'ok', 'user_id': user_id, 'date': date_iso}``.

    Raises:
        ValueError: If ``date_iso`` is not a valid 'YYYY-MM-DD' date. This is
            raised immediately rather than retried, because a malformed input
            can never succeed on a later attempt.

    Any exception raised while computing the balance is handed to
    ``self.retry`` (at most 3 retries, 5 seconds apart).
    """
    # Parse before entering the retry guard: bad input is a caller error,
    # not a transient failure.
    target_date = datetime.strptime(date_iso, "%Y-%m-%d").date()

    try:
        calculate_daily_balances_for_user_and_date(user_id, target_date)
    except Exception as exc:
        raise self.retry(exc=exc)

    return {'status': 'ok', 'user_id': user_id, 'date': date_iso}

@shared_task(bind=True, max_retries=3, default_retry_delay=5)
def calculate_balance_tasks(self, user_id, start_date_iso=None):
    """
    Compute daily balances for a user for every date from a start date up to today.

    Args:
        user_id: Identifier of the user whose balances are computed.
        start_date_iso: Optional first date to compute, as 'YYYY-MM-DD'. When
            omitted, the range starts the day after the user's most recent
            WorkBalance, or at the user's earliest Stamp if no balance exists.

    Returns:
        dict: ``{'status': 'no_data', 'user_id': ...}`` when no start date was
        given and the user has neither balances nor stamps. Otherwise
        ``{'status': 'ok', 'user_id', 'start_date', 'end_date'}`` with ISO
        dates, where ``end_date`` is today. If the start date is after today,
        nothing is computed.

    Raises:
        ValueError: If ``start_date_iso`` is given but malformed. This is raised
            immediately instead of being retried.

    Any other exception (database lookups or balance computation) is handed
    to ``self.retry`` (at most 3 retries, 5 seconds apart). A retry re-runs the
    whole task.
    """
    # Parse caller input outside the retry guard; retrying cannot fix it.
    explicit_start = (
        datetime.strptime(start_date_iso, "%Y-%m-%d").date()
        if start_date_iso
        else None
    )

    try:
        # NOTE(review): date.today() uses the server's local date, not Django's
        # timezone.localdate(); confirm this matches the TIME_ZONE setting.
        today = date.today()

        if explicit_start is not None:
            start_date = explicit_start
        else:
            # Start from the last balance, or from the first stamp.
            last_balance = WorkBalance.objects.filter(user_id=user_id).order_by('-date').first()
            if last_balance:
                start_date = last_balance.date + timedelta(days=1)
            else:
                first_stamp = Stamp.objects.filter(user_id=user_id).order_by('date').first()
                if not first_stamp:
                    # Nothing to compute.
                    return {'status': 'no_data', 'user_id': user_id}
                start_date = first_stamp.date

        # Compute every date from start_date up to and including today.
        current_date = start_date
        while current_date <= today:
            calculate_daily_balances_for_user_and_date(user_id, current_date)
            current_date += timedelta(days=1)

        return {'status': 'ok', 'user_id': user_id, 'start_date': start_date.isoformat(), 'end_date': today.isoformat()}

    except Exception as exc:
        raise self.retry(exc=exc)


# --- Per-user worker: fills recent gaps in historical balances for a single user ---
@shared_task(bind=True, max_retries=3, default_retry_delay=30)
def catch_up_work_balances_for_user(self, user_id, lookback_days=7):
    """
    Fill in missing daily balances for one user, up to and including yesterday.

    The range starts the day after the user's most recent WorkBalance, or at the
    user's earliest Stamp when no balance exists. It is clamped so it never
    starts earlier than ``yesterday - lookback_days``. With the default of 7,
    at most 8 dates are computed (yesterday plus the 7 days before it). Older
    gaps are deliberately left unfilled to bound the work per run.

    Args:
        user_id: Identifier of the user whose balances are caught up.
        lookback_days: How many days before yesterday the catch-up may reach.
            Defaults to 7, the previously hard-coded limit.

    Returns:
        dict: ``{'status': 'no_data', 'user_id': ...}`` if the user has neither
        balances nor stamps. Otherwise ``{'status': 'ok', 'user_id': ...,
        'end_date': <yesterday ISO>}``. The 'ok' status is returned even when no
        date needed computing.

    Any exception is handed to ``self.retry`` (at most 3 retries, 30 seconds
    apart).
    """
    try:
        yesterday = date.today() - timedelta(days=1)

        # Bound the historical depth; gaps older than this are not repaired.
        hard_limit = yesterday - timedelta(days=lookback_days)

        # Determine the start date from the last balance or the first stamp.
        last_balance = WorkBalance.objects.filter(user_id=user_id).order_by('-date').first()
        if last_balance:
            start_date = last_balance.date + timedelta(days=1)
        else:
            first_stamp = Stamp.objects.filter(user_id=user_id).order_by('date').first()
            if not first_stamp:
                return {'status': 'no_data', 'user_id': user_id}
            start_date = first_stamp.date

        # Apply the lookback limit to the start date.
        start_date = max(start_date, hard_limit)

        current_date = start_date
        while current_date <= yesterday:
            calculate_daily_balances_for_user_and_date(user_id, current_date)
            current_date += timedelta(days=1)

        return {'status': 'ok', 'user_id': user_id, 'end_date': yesterday.isoformat()}

    except Exception as exc:
        raise self.retry(exc=exc)


@shared_task
def schedule_catch_up_balances():
    """
    Enqueue catch_up_work_balances_for_user once for every user who has a Stamp.

    Only one run at a time enqueues work. ``cache.add`` stores LOCK_ID only if
    it is absent, so it serves as the lock acquire. This relies on a shared
    cache backend (e.g. Redis). A per-process backend such as LocMemCache gives
    no cross-process exclusion.

    The lock is deleted as soon as enqueuing finishes, even on error.
    LOCK_EXPIRE is only a fallback for a run that dies before the ``finally``
    block executes. A run that finds the lock held logs a message and returns
    without enqueuing anything.
    """
    import logging

    logger = logging.getLogger(__name__)

    if not cache.add(LOCK_ID, 'true', LOCK_EXPIRE):
        # Another instance is already running; skip this iteration.
        logger.info("Catch-up orchestrator skipped due to existing lock.")
        return

    try:
        # .order_by() with no arguments clears any default Meta.ordering on
        # Stamp. Otherwise Django adds the ordering columns to SELECT DISTINCT
        # and the same user_id can come back several times, enqueuing
        # duplicate tasks.
        user_ids = (
            Stamp.objects.order_by()
            .values_list('user_id', flat=True)
            .distinct()
        )
        for user_id in user_ids:
            catch_up_work_balances_for_user.delay(user_id)
    finally:
        # Always release the lock.
        cache.delete(LOCK_ID)