#!/usr/bin/env python3
"""
Script to download Geneious Cloud workspace export parts.
Usage: python script.py <API_KEY> <WORKSPACE_EXPORT_ID>
"""

import sys
import argparse
import requests
import boto3
from botocore.exceptions import ClientError, NoCredentialsError
import json
from typing import Dict, List, Any
import random
import time


class WorkspaceExportDownloader:
    """Download the parts of a Geneious Cloud workspace export.

    The downloader lists the export's parts through the REST API, requests
    the S3 location and temporary credentials for each part, and fetches
    the object with boto3. Failures detected by this class are reported on
    stderr and terminate the process with exit status 1.
    """

    # HTTP status codes treated as transient and therefore retried.
    _RETRY_STATUS_CODES = frozenset({502, 503, 504})

    def __init__(self, api_key: str, workspace_export_id: str, base_url: str = "https://api.geneious.eu", from_part_number: int = 1):
        """Store the export settings and prepare an authenticated HTTP session.

        Args:
            api_key: API key sent as the ``X-Api-Key`` header on every request.
            workspace_export_id: ID of the workspace export to download.
            base_url: Root URL of the API; trailing slashes are stripped.
            from_part_number: Lowest part number to download.
        """
        self.api_key = api_key
        self.workspace_export_id = workspace_export_id
        self.base_url = base_url.rstrip('/')
        self.from_part_number = from_part_number
        self.session = requests.Session()
        self.session.headers.update({
            'X-Api-Key': api_key
        })

    @staticmethod
    def _backoff_delay(attempt: int) -> float:
        """Return the wait in seconds before retrying after 1-based ``attempt``.

        The delay doubles each attempt (1s, 2s, 4s, ...) and a small random
        jitter is added so that concurrent clients do not retry in lockstep.
        """
        return 2 ** (attempt - 1) + random.uniform(0.1, 0.5)

    def make_api_request(self, url: str, max_retries: int = 3) -> Dict[str, Any]:
        """GET ``url`` and return the decoded JSON body, retrying transient failures.

        HTTP 502/503/504 responses and connection-level errors are retried
        with exponential backoff, up to ``max_retries`` attempts in total.
        Every other failure prints a message to stderr and exits with status 1.

        Args:
            url: Absolute URL to request.
            max_retries: Total number of attempts (not additional retries).

        Returns:
            The parsed JSON response body.
        """
        # Exception types that should trigger a retry
        retry_exceptions = (
            requests.exceptions.ConnectionError,
            requests.exceptions.Timeout,
            requests.exceptions.ChunkedEncodingError,
            requests.exceptions.ContentDecodingError,
        )

        def handle_permanent_error(resp: requests.Response) -> None:
            """Report a non-retryable HTTP error and exit; never returns."""
            if resp.status_code == 401:
                print("Error: Invalid API key (401 Unauthorized)", file=sys.stderr)
            elif resp.status_code == 403:
                print(f"Error: Access denied to workspace export with ID '{self.workspace_export_id}' (403 Forbidden)", file=sys.stderr)
            else:
                print(f"Error: API request failed with HTTP code {resp.status_code}", file=sys.stderr)
                print(f"Response: {resp.text}", file=sys.stderr)
            sys.exit(1)

        for attempt in range(1, max_retries + 1):
            is_last_attempt = attempt >= max_retries

            try:
                print(f"Making request to: {url} (attempt {attempt}/{max_retries})")
                response = self.session.get(url, timeout=30)
            except retry_exceptions as e:
                if is_last_attempt:
                    print(f"Error: Max retries exceeded due to connection issues: {str(e)}", file=sys.stderr)
                    sys.exit(1)
                delay = self._backoff_delay(attempt)
                print(f"Connection error ({type(e).__name__}), retrying in {delay:.1f} seconds...")
                time.sleep(delay)
                continue
            except requests.exceptions.RequestException as e:
                # Non-retryable request exceptions
                print(f"Error: Request failed: {str(e)}", file=sys.stderr)
                sys.exit(1)

            if response.ok:
                try:
                    return response.json()
                except ValueError:
                    # Every JSON decode error requests can raise (stdlib json,
                    # simplejson, or its own wrapper) derives from ValueError.
                    print("Error: Invalid JSON response from API", file=sys.stderr)
                    sys.exit(1)

            if response.status_code not in self._RETRY_STATUS_CODES:
                # 401, 403 and any other 4xx/5xx outside the retry list.
                handle_permanent_error(response)

            if is_last_attempt:
                print(f"Error: Max retries exceeded. Last HTTP code: {response.status_code}", file=sys.stderr)
                print(f"Response: {response.text}", file=sys.stderr)
                sys.exit(1)

            delay = self._backoff_delay(attempt)
            print(f"Server error {response.status_code}, retrying in {delay:.1f} seconds...")
            time.sleep(delay)

        # Only reachable when max_retries < 1, i.e. no attempt was made at all.
        print("Error: Unexpected failure in retry logic", file=sys.stderr)
        sys.exit(1)

    def _request_data(self, url: str, description: str) -> Dict[str, Any]:
        """Return the ``data`` object of the API response at ``url``, or exit.

        Exits with status 1 when the response is not a JSON object or has no
        ``data`` object, instead of failing later with an AttributeError.
        """
        payload = self.make_api_request(url)
        data = payload.get('data') if isinstance(payload, dict) else None
        if not isinstance(data, dict):
            print(f"Error: API response for {description} does not contain a 'data' object", file=sys.stderr)
            sys.exit(1)
        return data

    def get_parts_list(self) -> List[Dict[str, Any]]:
        """Return the export's parts whose number is at least ``from_part_number``.

        Parts that lack a ``partNumber`` are kept so that ``process_parts``
        can report them as invalid rather than this filter raising KeyError.

        Returns:
            The matching part dictionaries, or an empty list when none match.
        """
        print("Fetching parts list...")
        url = f"{self.base_url}/api/nucleus/v2/workspace-exports/{self.workspace_export_id}"
        response_data = self._request_data(url, "the parts list")

        # `or []` also covers an explicit null for dataParts.
        all_parts = response_data.get('dataParts') or []
        parts = [
            part for part in all_parts
            if part.get('partNumber') is None or part['partNumber'] >= self.from_part_number
        ]

        if not parts:
            print(f"No parts found for workspace export ID: {self.workspace_export_id} starting from part number {self.from_part_number}")
            return []

        print(f"Found {len(parts)} parts to download")
        return parts

    def get_s3_details(self, part_number: int) -> Dict[str, str]:
        """Fetch the S3 location and temporary credentials for one part.

        Args:
            part_number: Number of the part, as listed by ``get_parts_list``.

        Returns:
            A dict with the keys ``bucket_name``, ``object_key``,
            ``s3_endpoint``, ``aws_access_key_id``, ``aws_secret_access_key``
            and ``aws_session_token``. Exits with status 1 if the API
            response omits any of them or sets one to null.
        """
        url = f"{self.base_url}/api/nucleus/v2/workspace-exports/{self.workspace_export_id}/parts/{part_number}/download"
        response_data = self._request_data(url, f"part {part_number}")

        # Maps our internal key -> field name used by the API response.
        required_fields = {
            'bucket_name': 's3BucketName',
            'object_key': 's3Key',
            's3_endpoint': 's3Endpoint',
            'aws_access_key_id': 'accessKeyID',
            'aws_secret_access_key': 'secretAccessKey',
            'aws_session_token': 'sessionToken'
        }

        s3_details = {}
        for field, response_field in required_fields.items():
            # .get() so an absent key is reported like a null one instead of
            # raising KeyError before the check below.
            value = response_data.get(response_field)

            if value is None:
                print(f"Error: Missing required field '{response_field}' in S3 response for part {part_number}", file=sys.stderr)
                sys.exit(1)

            s3_details[field] = value

        return s3_details

    def download_from_s3(self, s3_details: Dict[str, str], filename: str) -> None:
        """Download one S3 object to the local path ``filename``.

        Args:
            s3_details: Dict as returned by ``get_s3_details``.
            filename: Local destination path passed to boto3's ``download_file``.

        Any failure is reported on stderr and exits with status 1.
        """
        print(f"Downloading {filename} from S3...")

        try:
            # Create S3 client with temporary credentials
            s3_client = boto3.client(
                's3',
                endpoint_url=s3_details['s3_endpoint'],
                aws_access_key_id=s3_details['aws_access_key_id'],
                aws_secret_access_key=s3_details['aws_secret_access_key'],
                aws_session_token=s3_details['aws_session_token']
            )

            # Download the file
            s3_client.download_file(
                s3_details['bucket_name'],
                s3_details['object_key'],
                filename
            )

            print(f"Successfully downloaded: {filename}")

        except NoCredentialsError:
            print(f"Error: Invalid AWS credentials for downloading {filename}", file=sys.stderr)
            sys.exit(1)
        except ClientError as e:
            # Tolerate an error response without the usual Error/Code entries.
            error_code = e.response.get('Error', {}).get('Code')
            if error_code == 'NoSuchBucket':
                print(f"Error: S3 bucket '{s3_details['bucket_name']}' does not exist", file=sys.stderr)
            elif error_code == 'NoSuchKey':
                print(f"Error: S3 object '{s3_details['object_key']}' does not exist", file=sys.stderr)
            elif error_code == 'AccessDenied':
                print(f"Error: Access denied to S3 object '{s3_details['object_key']}'", file=sys.stderr)
            else:
                print(f"Error: S3 download failed: {str(e)}", file=sys.stderr)
            sys.exit(1)
        except Exception as e:
            # Last-resort boundary: report anything unexpected and stop.
            print(f"Error: Failed to download {filename} from S3: {str(e)}", file=sys.stderr)
            sys.exit(1)

    def process_parts(self) -> None:
        """Download every selected part, in the order the API lists them.

        Exits with status 1 on the first part that lacks ``partNumber`` or
        ``originalFilename``, or whose download fails.
        """
        parts = self.get_parts_list()

        if not parts:
            return

        for i, part in enumerate(parts, 1):
            part_number = part.get('partNumber')
            filename = part.get('originalFilename')

            # `is None` rather than truthiness so part number 0 is not
            # mistaken for a missing value.
            if part_number is None or not filename:
                print(f"Error: Part {i} missing required fields (partNumber or originalFilename)", file=sys.stderr)
                print(f"Part data: {json.dumps(part, indent=2)}", file=sys.stderr)
                sys.exit(1)

            print(f"Processing part {i}/{len(parts)}: {part_number} -> {filename}")

            # Get S3 details for this part
            s3_details = self.get_s3_details(part_number)

            # NOTE(review): filename comes straight from the API response and is
            # used as the local path — confirm it cannot contain directory parts.
            self.download_from_s3(s3_details, filename)

        print("All downloads completed successfully!")


def main():
    """Parse the command line and download the requested workspace export.

    Exits with status 1 when ``api_key`` or ``workspace_export_id`` is empty
    or whitespace-only. All other work is delegated to
    ``WorkspaceExportDownloader.process_parts``.
    """
    parser = argparse.ArgumentParser(
        description="Download Geneious Cloud workspace export parts",
        formatter_class=argparse.RawDescriptionHelpFormatter
    )
    parser.add_argument("api_key", help="The Geneious Cloud API key for authentication")
    parser.add_argument("workspace_export_id", help="The id of the workspace export to download")
    parser.add_argument(
        "--geneious-cloud-api-url",
        default="https://api.geneious.eu",
        help="Base URL for the Geneious Cloud API (default: https://api.geneious.eu)"
    )
    parser.add_argument(
        "--from-part-number",
        type=int,
        default=1,
        help="The workspace export part number to start downloading from (default is to start from the first part)"
    )

    args = parser.parse_args()

    # argparse accepts "" or "   " as a positional value, so reject those here.
    if not args.api_key.strip():
        print("Error: api_key cannot be empty", file=sys.stderr)
        sys.exit(1)

    if not args.workspace_export_id.strip():
        print("Error: workspace_export_id cannot be empty", file=sys.stderr)
        sys.exit(1)

    # No runtime check for requests/boto3 here: both are imported at module
    # level, so a missing library already fails before main() can run.

    print(f"Starting download process for workspace export ID: {args.workspace_export_id}")

    # Create downloader and process parts
    downloader = WorkspaceExportDownloader(args.api_key, args.workspace_export_id, args.geneious_cloud_api_url, args.from_part_number)
    downloader.process_parts()


# Run the command-line interface only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()