"""This script is for downloading all released PDB entries (of a single file type/format) from the Beta PDB Archive. 

It uses asynchronous aiohttp library to download multiple files asynchronously when performing bulk downloads.

It requires python 3.8 or higher and aiofiles, aiohttp packages. The aiofiles, aiohttp packages can be installed
 with the following commands:

   pip install aiofiles
   pip install aiohttp

The script requires two input arguments to run. The following example command line downloads all mmCIF files and stores
 them under the directory /home/my_user_id/download:

   python BetaArchiveBatchDownloader.py --file_type mmcif --output_dir /home/my_user_id/download

(Run one of the following commands to see all supported download file types:

   python BetaArchiveBatchDownloader.py 
or
   python BetaArchiveBatchDownloader.py -h
or
   python BetaArchiveBatchDownloader.py --help  
)

How the downloaded files are stored:

 Since the current Archive has more than 246,000 entries, it is not desirable to keep a quarter of a million files under
 a single directory.

 The script first creates a top-level sub-directory named after the file type (e.g. /home/my_user_id/download/mmcif),
 then creates two-character hash directories based on the PDB IDs and stores each downloaded file in its hash directory.

 For the above example command, the downloaded files are stored as follows:

   /home/my_user_id/download/mmcif/00/pdb_0000100d.cif.gz
   /home/my_user_id/download/mmcif/00/pdb_0000200d.cif.gz
   /home/my_user_id/download/mmcif/00/pdb_0000200l.cif.gz
   /home/my_user_id/download/mmcif/00/pdb_0000300d.cif.gz
   /home/my_user_id/download/mmcif/00/pdb_0000400d.cif.gz

   /home/my_user_id/download/mmcif/01/pdb_0000101d.cif.gz
   /home/my_user_id/download/mmcif/01/pdb_0000101m.cif.gz
   /home/my_user_id/download/mmcif/01/pdb_0000201d.cif.gz
   /home/my_user_id/download/mmcif/01/pdb_0000201l.cif.gz
   /home/my_user_id/download/mmcif/01/pdb_0000301d.cif.gz
   /home/my_user_id/download/mmcif/01/pdb_0000401d.cif.gz

   ......
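
 The two-character hash directory name is derived from the PDB ID: the script takes pdb_id[-3:-1], i.e. the two
 characters immediately before the last character, so pdb_0000101d is stored under the 01/ sub-directory.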
"""
import aiofiles
import aiohttp
import argparse
import asyncio
import gzip
import json
import os
import shutil
import sys
import textwrap

file_type_help=\
""" 
The supported file types for downloading are listed in the left column.
The corresponding file naming conventions are listed in the right column.

mmcif               : pdb_xxxxxxxx.cif.gz
pdb                 : pdb_xxxxxxxx.pdb.gz
assemblies          : pdb_xxxxxxxx-assembly#.cif.gz
XML                 : pdb_xxxxxxxx.xml.gz
XML-extatom         : pdb_xxxxxxxx-extatom.xml.gz
XML-noatom          : pdb_xxxxxxxx-noatom.xml.gz
structure_factors   : pdb_xxxxxxxx-sf.cif.gz
nmr_data_str        : pdb_xxxxxxxx_nmr-data.str.gz
nmr_data_nef        : pdb_xxxxxxxx_nmr-data.nef.gz
nmr_chemical_shifts : pdb_xxxxxxxx_cs.str.gz
nmr_restraints      : pdb_xxxxxxxx.mr.gz
nmr_restraints_v2   : pdb_xxxxxxxx_mr.str.gz
validation_cif      : pdb_xxxxxxxx_validation.cif.gz
validation_xml      : pdb_xxxxxxxx_validation.xml.gz
validation_pdf      : pdb_xxxxxxxx_validation.pdf.gz
full_validation_pdf : pdb_xxxxxxxx_full_validation.pdf.gz
"""

baseUrl = "https://files-beta.wwpdb.org"

MAX_RETRIES = 3  # maximum number of attempts per HTTP request
MAX_TASKS = 5    # maximum number of concurrent download requests

total_task_numbers = 0
finish_task_numbers = 0

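# Seconds to sleep before retrying, keyed by the (zero-based) index of the attempt that just failed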
wait_time = { 0: 0.5, 1: 2 }

async def fetch_gzipped_json(url):
    """ Fetches a gzipped JSON file from a URL, decompresses it, and parses the JSON.
    """
    for attempt in range(MAX_RETRIES):
        try:
            async with aiohttp.ClientSession() as session:
                async with session.get(url) as response:
                    # Raise an exception for bad status codes
                    response.raise_for_status()

                    # Read the gzipped content as bytes
                    gzipped_content = await response.read()

                    try:
                        # Decompress the content
                        decompressed_content = gzip.decompress(gzipped_content)

                        # Decode the bytes to a string (assuming UTF-8 encoding)
                        json_string = decompressed_content.decode('utf-8')

                        # Parse the JSON string
                        data = json.loads(json_string)
                        return data
                    except gzip.BadGzipFile:
                        print(f"Error: The file {url} is not a valid gzip file.")
                        return {}
                    except json.JSONDecodeError:
                        print("Error: Could not decode JSON from the decompressed content.")
                        return {}
                    #
                #
            #
        except Exception as e:
            print(f"Attempt {attempt + 1} for {url} failed: {e}")
            if attempt < MAX_RETRIES - 1:
                await asyncio.sleep(wait_time[attempt])
            else:
                print(f"All retries failed for {url}")
                return {}
            #
        #
    #

async def download_file(url, session, semaphore, filePath):
    """ Fetches a file from a URL, and writes out to a local file.
    """
    async with semaphore:
        for attempt in range(MAX_RETRIES):
            try:
                async with session.get(url) as response:
                    # Raise an exception for bad status codes
                    response.raise_for_status()

                    async with aiofiles.open(filePath, mode="wb") as fp:
                        # Stream the response body in chunks instead of reading it all into memory at once
                        async for chunk in response.content.iter_chunked(1024 * 1024):
                            await fp.write(chunk)
                        #
                    #
                    return "downloaded " + url
                #
            except Exception as e:
                print(f"Attempt {attempt + 1} for {url} failed: {e}")
                if attempt < MAX_RETRIES - 1:
                    await asyncio.sleep(wait_time[attempt])
                else:
                    print(f"All retries failed for {url}")
                    return "All retries failed for " + url
                #
            #
        #
    #

def callback_function(task):
    """ This method is called when the task is done. It receives the task object as its only argument.
    """
    global total_task_numbers
    global finish_task_numbers
    #
    try:
        # Retrieve the result so a cancellation or an exception raised inside the task surfaces here
        task.result()
        finish_task_numbers += 1
        if finish_task_numbers == 1:
            print(f"Finished downloading the first (total {total_task_numbers}) file.")
        elif finish_task_numbers == total_task_numbers:
            print(f"Finished downloading the last (total {total_task_numbers}) file.")
        elif finish_task_numbers % 1000 == 0:
            print(f"Finished downloading {finish_task_numbers} (total {total_task_numbers}) files.")
        #
    except asyncio.CancelledError:
        print("Task was cancelled.")
    except Exception as e:
        print(f"Task raised an exception: {e}")
    #

async def main(fileType, fileInfoTuple, topOutputPath):
    """ Main method for batch file downloading.
    """
    global total_task_numbers
    try:
        # Read total file size information from ${baseUrl}/pub/wwpdb/pdb/holdings/file_size.json.gz file
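        # (file_size.json.gz is expected to map file-type keys, e.g. "mmcif" or "validation_report", to the total
        #  required download size in GB for that set of files.)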
        sizeObj = await fetch_gzipped_json(baseUrl + "/pub/wwpdb/pdb/holdings/file_size.json.gz")
        if len(sizeObj) > 0:
            required_size = ""
            if fileType in sizeObj:
                required_size = sizeObj[fileType]
            elif fileInfoTuple[0] in sizeObj:
                required_size = sizeObj[fileInfoTuple[0]]
            #
            if required_size:
                _total, _used, free = shutil.disk_usage(topOutputPath)
                free_size_gb = float(free) / float(1024**3)
                if free_size_gb < float(required_size):
                    print(f"The local file system does not have enough disk space ({free_size_gb:.2f} GB available) for downloading '{fileType}' files: it requires {required_size} GB.")
                    return
                #
            #
        #
        # Read all released entry information from ${baseUrl}/pub/wwpdb/pdb/holdings/current_file_holdings.json.gz file.
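        # (current_file_holdings.json.gz is expected to map each PDB ID to an object whose keys are content types,
        #  e.g. "mmcif" or "validation_report", and whose values are lists of archive file paths for that entry.)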
        jsonObj = await fetch_gzipped_json(baseUrl + "/pub/wwpdb/pdb/holdings/current_file_holdings.json.gz")
        if len(jsonObj) == 0:
            return
        #

        hashIdList = []
        fileList = []
        count = 0
        for pdb_id,entryObj in jsonObj.items():
            if fileInfoTuple[0] not in entryObj:
                continue
            #
            hash_id = pdb_id[-3:-1]
            #
            for fileName in entryObj[fileInfoTuple[0]]: 
                baseFileName = os.path.basename(fileName)
                if fileInfoTuple[0] == "validation_report":
                    if baseFileName != (pdb_id + fileInfoTuple[2]):
                        continue
                    #
                #
                hashIdList.append(hash_id)
                #
                # Use full path URL
                fileUrl = baseUrl + "/pub" + fileName
                #
                # Download file path
                downloadFilePath = os.path.join(topOutputPath, fileType, hash_id, baseFileName)
                #
                fileList.append( ( fileUrl, downloadFilePath ) )
                #
                count += 1
            #
        #
        if (len(hashIdList) == 0) or (len(fileList) == 0):
            return
        #
        # Make file type sub directory
        if not os.access(os.path.join(topOutputPath, fileType), os.F_OK):
            os.makedirs(os.path.join(topOutputPath, fileType))
        #
        # Make hash sub directory
        hashIdList = sorted(list(set(hashIdList)))
        for hash_id in hashIdList:
            if not os.access(os.path.join(topOutputPath, fileType, hash_id), os.F_OK):
                os.makedirs(os.path.join(topOutputPath, fileType, hash_id))
            #
        #
        # Set script to download multiple files (max 5 requests) at any given time
        semaphore = asyncio.Semaphore(MAX_TASKS)

        total_task_numbers = len(fileList)

        # Use Asynchronous HTTP Client/Server Protocol to download files
        async with aiohttp.ClientSession() as session:
            tasks = []
            for fileTuple in fileList:
                task = asyncio.create_task(download_file(fileTuple[0], session, semaphore, fileTuple[1]))
                task.add_done_callback(callback_function)
                tasks.append(task)
            #
            return await asyncio.gather(*tasks)
        #
    except Exception as e:
        print(f"Error fetching data: {e}")
    #

if __name__ == "__main__":
    # The following dictionary defines the mapping between supported file types and the attribute key in current_file_holdings.json,
    # keyword in short link URL, & file naming suffix pattern (validation report files only).
    #
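    # For example, "validation_cif" maps to ( "validation_report", "validation/download", "_validation.cif.gz" ):
    # the holdings attribute key, the short-link keyword, and the file name suffix used to pick the matching report file.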
    file_type_map = { "mmcif" : ( "mmcif", "download" ), \
                      "pdb" : ( "pdb", "download" ), \
                      "assemblies" : ( "assembly_mmcif", "download" ), \
                      "XML" : ( "pdbml", "download" ), \
                      "XML-noatom" : ( "pdbml_noatom", "download" ), \
                      "XML-extatom" : ( "pdbml_extatom", "download" ), \
                      "structure_factors" : ( "structure_factors", "download" ), \
                      "nmr_data_str" : ( "combined_nmr_data_nmr-star", "download" ), \
                      "nmr_data_nef" : ( "combined_nmr_data_nef", "download" ), \
                      "nmr_chemical_shifts" : ( "nmr_chemical_shifts", "download" ), \
                      "nmr_restraints" : ( "nmr_restraints_v1", "download" ), \
                      "nmr_restraints_v2" : ( "nmr_restraints_v2", "download" ), \
                      "validation_cif" : ( "validation_report", "validation/download", "_validation.cif.gz" ), \
                      "validation_xml" : ( "validation_report", "validation/download", "_validation.xml.gz" ), \
                      "validation_pdf" : ( "validation_report", "validation/download", "_validation.pdf.gz" ), \
                      "full_validation_pdf" : ( "validation_report", "validation/download", "_full_validation.pdf.gz" ) }
    #
    scriptDescription = "The script is for downloading one type of files for all released PDB entries from Beta PDB Archive."
    parser = argparse.ArgumentParser(description=scriptDescription, formatter_class=argparse.RawTextHelpFormatter)
    parser.add_argument("--file_type", help=textwrap.dedent(file_type_help))
    parser.add_argument("--output_dir", help="The top directory path where the downloaded files are stored.")
    args = parser.parse_args()
    if (args.file_type is None) or (args.output_dir is None):
        parser.print_help()
        sys.exit(1)
    #
    if args.file_type not in file_type_map:
        print(f'The FILE_TYPE value "{args.file_type}" is not allowed. See below for the allowed FILE_TYPE values.\n')
        parser.print_help()
        sys.exit(1)
    #
    outputDirPath = os.path.abspath(args.output_dir)
    if not os.path.isdir(outputDirPath):
        if not os.access(outputDirPath, os.F_OK): 
            topDirPath = os.path.dirname(outputDirPath)
            if not os.access(topDirPath, os.W_OK):
                print(f"You do NOT have write access to the directory: {topDirPath}")
                sys.exit(1)
            else:
                os.makedirs(outputDirPath)
            #
        else:
            print(f"{outputDirPath} is not a directory.")
            sys.exit(1)
        #
    elif not os.access(outputDirPath, os.W_OK):
        print(f"You do NOT have write access to the directory: {outputDirPath}")
        sys.exit(1)
    #
    asyncio.run(main(args.file_type, file_type_map[args.file_type], outputDirPath))
