#!/usr/bin/env -S python3 -u
#                           PUBLIC DOMAIN NOTICE
#              National Center for Biotechnology Information
#
# This software is a "United States Government Work" under the
# terms of the United States Copyright Act.  It was written as part of
# the authors' official duties as United States Government employees and
# thus cannot be copyrighted.  This software is freely available
# to the public for use.  The National Library of Medicine and the U.S.
# Government have not placed any restriction on its use or reproduction.
#   
# Although all reasonable efforts have been taken to ensure the accuracy
# and reliability of the software and data, the NLM and the U.S.
# Government do not and cannot warrant the performance or results that
# may be obtained by using this software or data.  The NLM and the U.S.
# Government disclaim all warranties, express or implied, including
# warranties of performance, merchantability or fitness for any particular
# purpose.
#   
# Please cite NCBI in any work or product based on this material.

# Script to download a BLAST database and a query FASTA file,
# split the query into parts if needed, run a BLAST
# search, and copy the BLAST results to cloud storage.
#
# Author: Victor Joukov joukovv@ncbi.nlm.nih.gov


import os
import sys
import re
import time
import argparse
import subprocess
import shlex, shutil
import logging
import filelock
from hashlib import md5
from typing import Union, List
from dataclasses import dataclass
from pathlib import Path
import requests
from ec2_metadata import ec2_metadata
import tempfile

# Constants

# Script version; "$VERSION" looks like a placeholder substituted at
# build/release time - TODO confirm
VERSION = "$VERSION"

# Description shown in --help
DESC = 'Helper script to run BLAST on Elastic-BLAST nodes'
# Default --workdir; main() changes into it, so downloaded databases,
# queries and result files are created there
BLAST_WD = '/blast/blastdb'

# URL of the AWS EC2 instance metadata service, probed by is_aws_instance()
AWS_METADATA_URL = 'http://169.254.169.254/latest'
# Timeout in seconds to reach the URL above
TIMEOUT_AWS_METADATA_URL = 3

# Exit code of main() when downloading or checking the database fails
BLASTDB_ERROR = 2

# Upper limit for the --num_threads value passed to update_blastdb.pl
MAX_PROCS_TO_DOWNLOAD_DB = 12

# --loglevel choices mapped to logging module levels
log_levels = {
    'debug': logging.DEBUG,
    'info': logging.INFO,
    'warning': logging.WARNING,
    'error': logging.ERROR,
    'critical': logging.CRITICAL
}

# Values accepted by --program
SUPPORTED_PROGRAMS = [
    'blastp',
    'blastn',
    'megablast',
    'blastx',
    'psiblast',
    'rpsblast',
    'rpstblastn',
    'tblastn',
    'tblastx'
]

# Size limit passed to vmtouch as -m<N>G when touching database files
MAX_BLASTDB_FILE_SIZE = 5  # in GB

# Module state, set by parse_args()
# dry_run: print commands instead of executing them (--dry-run)
dry_run = False
# script_dir: directory of this script, used to locate fasta-split
script_dir = '.'


def is_aws_instance():
    """ Return true if running on an EC2 instance, otherwise return false.

        Sends a HEAD request to the EC2 instance metadata endpoint. Any
        request failure (timeout, refused or unreachable connection, ...)
        means "not on EC2".
    """
    try:
        r = requests.head(AWS_METADATA_URL, timeout=TIMEOUT_AWS_METADATA_URL)
    except requests.exceptions.RequestException:
        # Off EC2 the link-local address may time out, but it may equally
        # refuse or fail the connection; all of these mean "not on EC2"
        return False
    # NOTE(review): instances that require IMDSv2 tokens may answer a
    # token-less request with a non-200 status, making this False - confirm
    return r.status_code == 200


class SafeExecError(Exception):
    """Error raised by safe_exec when a command fails or cannot be launched.

    Attributes:
        returncode: exit code of the failed command, or the OS errno when
                    the command could not be started
        message: text describing the failure; also the str() of the error
    """

    def __init__(self, returncode: int, message: str):
        """Remember the error code and the message."""
        self.message = message
        self.returncode = returncode

    def __str__(self):
        """Render the error as its message."""
        text = self.message
        return text

@dataclass
class FakeProcess:
    """Placeholder for subprocess.CompletedProcess that safe_exec returns in
    dry-run mode: exit code 0 and empty captured output."""
    returncode: int = 0
    # Annotated so they are real dataclass fields (constructor arguments,
    # repr, equality) like returncode; bytes, matching output captured
    # with subprocess.PIPE, so callers can .decode() them
    stdout: bytes = b''
    stderr: bytes = b''


def safe_exec(cmd: Union[List[str], str], shell=False) -> subprocess.CompletedProcess:
    """Wrapper around subprocess.run that raises SafeExecError on errors from
    command line with error messages assembled from all available information.

    The command line is echoed to stderr first. When the module-level
    ``dry_run`` flag is set, nothing is executed and a FakeProcess with empty
    output is returned.

    Parameters:
        cmd - list of arguments, or a string that is split on whitespace
              (shell-style quoting is NOT honoured)
        shell - passed through to subprocess.run
    Returns:
        subprocess.CompletedProcess with captured stdout/stderr as bytes,
        or FakeProcess in dry-run mode
    Raises:
        ValueError - cmd is neither a list nor a string
        SafeExecError - the command exited with a non-zero code (returncode is
                        the exit code; message holds the command, its stderr
                        and its stdout), or it could not be started
                        (returncode is the OS errno)
    """
    if isinstance(cmd, str):
        cmd = cmd.split()
    if not isinstance(cmd, list):
        raise ValueError('safe_exec "cmd" argument must be a list or string')

    # Echo the command, quoting arguments that contain spaces
    print(' '.join(map(lambda x: "'"+x+"'" if ' ' in x else x, cmd)), file=sys.stderr)
    if dry_run:
        return FakeProcess()
    try:
        return subprocess.run(cmd, check=True, shell=shell, stdout=subprocess.PIPE,
                              stderr=subprocess.PIPE)
    except subprocess.CalledProcessError as e:
        # Always raise: the error must never be swallowed. Decode leniently so
        # that undecodable tool output cannot hide the original failure.
        stderr = e.stderr.decode(errors='replace') if e.stderr else ''
        stdout = e.stdout.decode(errors='replace') if e.stdout else ''
        msg = f'The command "{" ".join(map(str, e.cmd))}" returned with exit code {e.returncode}\n{stderr}\n{stdout}'
        raise SafeExecError(e.returncode, msg) from e
    except PermissionError as e:
        raise SafeExecError(e.errno, str(e)) from e
    except FileNotFoundError as e:
        raise SafeExecError(e.errno, e.strerror) from e


def parse_args(argv):
    """ Parse the command line and set the module-level dry_run and script_dir.
        Parameters:
            argv - full command line, argv[0] being the script path
        Returns:
            argparse.Namespace with the parsed options; the optional string
            options db_path, taxidlist and params are stripped of
            surrounding whitespace
    """
    global dry_run
    global script_dir
    # Directory of this script, used to locate helper scripts (fasta-split)
    script_dir = os.path.dirname(os.path.realpath(argv[0]))
    parser = argparse.ArgumentParser(prog=os.path.basename(os.path.splitext(argv[0])[0]),
                                     description=DESC)
    parser.add_argument('--db', type=str, required=True, help='BLAST database to search')
    parser.add_argument('--db-path', type=str, help='Path to the user database in the AWS S3')
    parser.add_argument('--source', type=str, required=True, help='Source for standard database: AWS, GCP, or NCBI')
    parser.add_argument('--query', type=str, required=True, help='Query path in AWS S3')
    parser.add_argument('--num-threads', type=int, required=True, help='Number of threads to use for search program')
    parser.add_argument('--program', type=str, required=True, help='BLAST program to run',
                        choices=SUPPORTED_PROGRAMS)
    parser.add_argument('--bucket', type=str, required=True, help='Bucket to put results to')
    parser.add_argument('--params', type=str, help='Search program parameters')
    parser.add_argument('--db-mol-type', type=str, required=True, help='Molecular type of the database',
                        choices=['prot', 'nucl'])
    parser.add_argument('--taxidlist', type=str, help='Taxonomy ID list to use')
    parser.add_argument('--verbose', default=False, action='store_true', help='Verbosity of debug output')
    parser.add_argument('--workdir', default=BLAST_WD, type=str, help='Working directory to use')
    parser.add_argument('--loglevel', type=str, choices=log_levels.keys(), help='Level of logging')
    parser.add_argument('--dry-run', default=False, action='store_true', help='Show actions but don\'t execute')
    parser.add_argument('--version', action='version',
                        version='%(prog)s ' + VERSION)

    experimental_opts = parser.add_argument_group("Experimental cloud query splitting options")
    experimental_opts.add_argument('--num-parts', type=int, default=-1, help='Number of parts to split query into, default=-1 (disabled)')
    experimental_opts.add_argument('--split-part', type=int, default=-1, help='Part to use in search, zero-based, default=-1 (disabled)')
    experimental_opts.add_argument('--search', default=True, dest='run_search', action='store_true', help="Run BLAST search")
    # Must share dest with --search: run_search() reads args.run_search, so a
    # separate dest would leave the search enabled
    experimental_opts.add_argument('--no-search', default=True, dest='run_search', action='store_false', help="Don't run BLAST search, for testing query splitting")

    testing_opts = parser.add_argument_group("Testing options")
    testing_opts.add_argument('--no-vmtouch', default=False, action='store_true', help="Don't run the vmtouch phase")
    testing_opts.add_argument('--no-creds', default=False, action='store_true', help="Public AWS access only, don't write the results")

    args = parser.parse_args(argv[1:])
    if args.loglevel:
        logging.basicConfig(level=log_levels[args.loglevel])
    dry_run = args.dry_run

    # Defensive act of cleaning up empty optional args
    if args.db_path:
        args.db_path = args.db_path.strip()
    if args.taxidlist:
        args.taxidlist = args.taxidlist.strip()
    if args.params:
        args.params = args.params.strip()
    return args


def log_args(args):
    """ Print the job configuration, EC2 instance details (when running on
        EC2) and the versions of the external tools found on PATH.
        Parameters:
            args - parsed command line arguments (see parse_args)
        Raises:
            SafeExecError - a version query command fails
    """
    print(f'DB: {args.db}')
    if args.db_path:
        print(f'DB_PATH: {args.db_path}')
    print(f'DB_SOURCE: {args.source}')
    print(f'QUERY_BATCH: {args.query}')
    print(f'ELB_NUM_CPUS: {args.num_threads}')
    print(f'PROGRAM: {args.program}')
    print(f'BUCKET: {args.bucket}')
    print(f'DB_MOL_TYPE: {args.db_mol_type}')
    if args.taxidlist:
        print(f'TAXIDLIST: {args.taxidlist}')
    if args.params:
        print(f'BLAST_PARAMS: {args.params}')
    if is_aws_instance():
        print(f'INSTANCE_ID: {ec2_metadata.instance_id}')
        print(f'INSTANCE_TYPE: {ec2_metadata.instance_type}')
        print(f'PUBLIC_HOSTNAME: {ec2_metadata.public_hostname}')

    # Print tool versions; tools that are not on PATH are skipped.
    # Each entry: (executable, version command, label,
    #              function extracting the version from the stripped output)
    version_probes = (
        ('aws', 'aws --version', 'AWS CLI version', lambda out: out),
        ('blastn', 'blastn -version', 'BLAST+ version', lambda out: out),
        # Only the last whitespace-separated token is reported
        ('update_blastdb.pl', 'update_blastdb.pl --version', 'update_blastdb.pl version',
         lambda out: out.split()[-1]),
    )
    for tool, command, label, extract in version_probes:
        if not shutil.which(tool):
            continue
        output = safe_exec(command).stdout.decode().strip()
        # Skip empty/whitespace-only output (e.g. dry-run mode); split()[-1]
        # would raise IndexError on it
        if output:
            print(f'{label}: {extract(output)}')

def log_disk_usage():
    """ Print the capacity, used and free space, in GB, of the file system
        holding the current working directory """
    gib = 1024 ** 3
    here = os.getcwd()
    usage = shutil.disk_usage(here)
    for label, amount in (('capacity', usage.total),
                          ('used', usage.used),
                          ('free', usage.free)):
        print(f'Disk {label} at {here}: {(float(amount) / gib):.2f} GB')


def _download_database(args, is_user, db_done):
    """ Download the database if it's not downloaded already,
        test database integrity
        Parameters:
            args - parsed command line arguments
            is_user - True for a user database copied from args.db_path with
                      the AWS CLI, False for a standard database fetched with
                      update_blastdb.pl
            db_done - marker file; when it exists nothing is done, and it is
                      created (outside dry-run mode) after a successful check
        Raises:
            SafeExecError - a download or database check command fails
    """
    if os.path.exists(db_done):
        return
    log_disk_usage()
    print('Start database download')
    verbose = ' --verbose --verbose --verbose --verbose --verbose --verbose' if args.verbose else ''
    creds = ' --no-sign-request' if args.no_creds else ''
    # Download threads: CPUs per search thread, capped at MAX_PROCS_TO_DOWNLOAD_DB.
    # Guard against os.cpu_count() returning None, a non-positive --num-threads,
    # and a result of 0 when --num-threads exceeds the CPU count.
    cpus = os.cpu_count() or 1
    nprocs_to_download_db = max(1, min(MAX_PROCS_TO_DOWNLOAD_DB, cpus // max(1, args.num_threads)))

    def run_and_echo(command):
        # Run a download command and relay its captured stdout and stderr
        p = safe_exec(command)
        print(p.stdout.decode(), end='')
        print(p.stderr.decode(), end='')

    # The taxonomy database is fetched for both standard and user databases
    run_and_echo(f"time update_blastdb.pl taxdb --decompress --source {args.source}{verbose} --num_threads {nprocs_to_download_db}")
    if is_user:
        # Join with '' to guarantee a trailing separator on the source path;
        # safe_exec uses no shell, so '*' reaches the AWS CLI unexpanded
        run_and_echo(f"time aws s3 cp --only-show-errors{creds} {os.path.join(args.db_path,'')} . --recursive --exclude * --include {args.db}.* --include taxdb.*")
    else:
        run_and_echo(f"time update_blastdb.pl {args.db} --decompress --source {args.source}{verbose} --num_threads {nprocs_to_download_db}")
    print('End database download')
    test_database(args)
    if not dry_run:
        # Written only after the database check passed
        with open(db_done, 'w'):
            pass


def download_database(args):
    """ Determine whether args describe a user database (args.db_path set and
        not the string 'None') or a standard one, derive the marker (.done)
        and lock (.lock) file names for it, and run the download while
        holding the lock so concurrent runs in the same directory do not
        download the same database simultaneously """
    is_user = bool(args.db_path) and args.db_path != 'None'
    if is_user:
        # User database: name the files after a hash of its full location
        location = os.path.join(args.db_path, args.db)
        stem = 'custom_blastdb_' + md5(location.encode()).hexdigest()
    else:
        # Standard BLAST database: name the files after the database itself
        stem = args.db
    with filelock.FileLock(stem + '.lock'):
        _download_database(args, is_user, stem + '.done')


def test_database(args):
    """ Check the downloaded database: print its blastdbcmd -info summary,
        then run blastdbcheck (with -verbosity 4 under --verbose), relaying
        each command's output.
        Raises:
            SafeExecError - either command fails
    """
    print('Start database check')
    extra = ' -verbosity 4' if args.verbose else ''
    checks = (
        f'blastdbcmd -info -db {args.db} -dbtype {args.db_mol_type}',
        f'blastdbcheck -db {args.db} -dbtype {args.db_mol_type} -no_isam -ends 5{extra}',
    )
    for command in checks:
        proc = safe_exec(command)
        print(proc.stdout.decode(), end='')
        print(proc.stderr.decode(), end='')
    print('End database check')


def vmtouch_database(args):
    """ Run vmtouch over the database volume files reported by blastdb_path,
        to check whether the database is memory-mapped (testing purposes).
        Skipped with --no-vmtouch.
    """
    if not args.no_vmtouch:
        print('Start database vmtouch')
        locate = safe_exec(f'blastdb_path -dbtype {args.db_mol_type} -db {args.db} -getvolumespath')
        volumes = locate.stdout.decode()
        report = safe_exec(f'vmtouch -m{MAX_BLASTDB_FILE_SIZE}G {volumes}')
        print(report.stdout.decode())
        print('End database vmtouch')


def download_taxidlist(args):
    """ Copy the taxonomy ID list file args.taxidlist into the current
        directory with the AWS CLI. Nothing happens when no list is given or
        it is 'NONE'. A marker file named after the list's location skips
        repeated copies, and a file lock serializes concurrent runs. In
        dry-run mode the command is only printed.
    """
    source = args.taxidlist
    if not source or source == 'NONE':
        return
    sign_opt = ' --no-sign-request' if args.no_creds else ''
    copy_cmd = f"aws s3 cp {source} . --quiet{sign_opt}"
    if dry_run:
        print('Start taxidlist copy')
        print(copy_cmd)
        print('End taxidlist copy')
        return
    # Unique, filesystem-safe stem for the marker and lock files
    stem = 'taxidlist_{}_{}'.format(re.sub(r'[^-A-Za-z0-9.]', '-', source),
                                    md5(source.encode()).hexdigest())
    marker = stem + '.done'
    with filelock.FileLock(stem + '.lock'):
        if os.path.exists(marker):
            return
        log_disk_usage()
        print('Start taxidlist copy')
        safe_exec(copy_cmd)
        print('End taxidlist copy')
        with open(marker, 'w'):
            pass


def _do_download_query(args):
    """ Copy the query from S3 into the current directory and gunzip it when
        its name ends in .gz
        Parameters:
            args - all arguments to the script, uses
                   args.query and args.no_creds
        Returns:
            name of the local (unzipped) query file
    """
    sign_opt = ' --no-sign-request' if args.no_creds else ''
    log_disk_usage()
    print("Start query download")
    safe_exec(f'aws s3 cp {args.query} .{sign_opt}')
    print("End query download")
    local_name = os.path.basename(args.query)
    stem, suffix = os.path.splitext(local_name)
    if suffix != '.gz':
        return local_name
    # Compressed query: decompress in place; the result is named without .gz
    print('Start query unpacking')
    safe_exec(f'gunzip -f {local_name}')
    print('End query unpacking')
    return stem


def _download_query(args):
    """ Download and, if needed, unzip the query while holding 'query.lock',
        so only one run in this directory fetches a query at a time
        Returns:
            name of the unzipped query file
    """
    with filelock.FileLock('query.lock'):
        return _do_download_query(args)


def prepare_query(args):
    """ Download and, if needed, unzip the query. When splitting is enabled
        (--split-part >= 0 and --num-parts > 0), also split the query with
        fasta-split; this happens once per directory, guarded by a lock and a
        marker file. In that case the path of this job's part is returned.
        Returns:
            path of the query file to search
    """
    query = _download_query(args)
    splitting = args.split_part >= 0 and args.num_parts > 0
    if not splitting:
        return query
    part_dir = 'parts'
    marker = 'query-split.done'
    with filelock.FileLock('query-split.lock'):
        if not os.path.exists(marker):
            print(f'Start query splitting, part {args.split_part} of {args.num_parts}')
            safe_exec(f'python3 {script_dir}/fasta-split --output {part_dir} --n-parts {args.num_parts} {query}')
            print('End query splitting')
            if not dry_run:
                with open(marker, 'w'):
                    pass
        # Assumes fasta-split writes parts named batch_NNN.fa - TODO confirm
        return f'{part_dir}/batch_{args.split_part:03d}.fa'


def run_search(args, query):
    """ Run the BLAST search on the query file and gzip the results.
        Parameters:
            args - parsed command line arguments (run_search, program, db,
                   num_threads, params)
            query - path of the query file
        Returns:
            name of the gzipped results file, or '' when args.run_search
            is false (no search is run)
    """
    if not args.run_search:
        # Keep very short jobs from overloading AWS Batch communication
        time.sleep(10)
        return ''
    log_disk_usage()
    name, _ = os.path.splitext(query)
    fn_out = f'{name}-{args.program}-{args.db}.out'
    # args.params is a shell-like string; shlex.split honours its quoting
    params = shlex.split(args.params) if args.params else []
    cmd = ['time', args.program, '-query', query, '-db', args.db,
           '-num_threads', str(args.num_threads), '-out', fn_out] + params
    print('Start blast search')
    p = safe_exec(cmd)
    print(p.stdout.decode(), end='')
    print(p.stderr.decode(), end='')
    print('End blast search')

    # Print the first lines of the results (testing purposes). Decode
    # leniently: sequence deflines may contain bytes that are invalid in the
    # locale encoding, and a decoding error here must not abort the job
    # before the results are compressed and uploaded.
    if not dry_run:
        with open(fn_out, 'rt', errors='replace') as f:
            for nline, line in enumerate(f, start=1):
                print(line, end='')
                if nline > 10:
                    break
    safe_exec(f'gzip -f {fn_out}')
    return fn_out + '.gz'


def upload_results(args, fn):
    """ Copy the results file fn to args.bucket with the AWS CLI.
        With --no-creds nothing is written to S3. Instead the upload is
        simulated by copying fn into a local directory named after the
        bucket, with any s3:// prefix removed.
    """
    if not args.no_creds:
        print('Start copy results')
        safe_exec(f'aws s3 cp {fn} {args.bucket}/')
        print('End copy results')
        return
    local_bucket = re.match(r'(?:s3://)?(.+)', args.bucket).group(1)
    target = os.path.join(local_bucket, os.path.basename(fn))
    safe_exec(f'mkdir -p {local_bucket}')
    safe_exec(f'cp {fn} {target}')


def upload_error_if_not_present(args, message):
    """Upload the error file to results bucket, unless one is already there.

    This is best-effort. It runs while another failure is being handled, so
    a failure to upload the report is printed rather than raised; raising
    would hide the original error and its exit code. With --no-creds
    ("public AWS access only, don't write the results") nothing is written.

    Parameters:
        args - parsed command line arguments (uses bucket and no_creds)
        message - error text to store in the report
    """
    # To avoid synchronization problems, and given that in many cases all
    # jobs will have the same error, upload the error file only if one is not
    # already present
    BUCKET_ERROR_FILE = 'metadata/FAILURE.txt'
    if args.no_creds:
        return
    cmd = f'aws s3 ls {args.bucket}/{BUCKET_ERROR_FILE}'
    try:
        safe_exec(cmd)
        return  # a report is already present
    except SafeExecError:
        pass  # not present (or listing failed): try to upload ours
    try:
        with tempfile.NamedTemporaryFile() as f:
            f.write(message.encode())
            f.flush()
            cmd = f'aws s3 cp --only-show-errors {f.name} {args.bucket}/{BUCKET_ERROR_FILE}'
            safe_exec(cmd)
    except SafeExecError as e:
        print(f'Failed to upload error report: {e.message}', file=sys.stderr)


def cleanup(*files2delete):
    """Delete every given path that names an existing regular file.

    Missing paths, directories and empty names are silently skipped.
    """
    for candidate in map(Path, files2delete):
        if not candidate.is_file():
            continue
        logging.debug(f"Deleting {candidate}")
        candidate.unlink()


def main(argv):
    """ Run the job: parse arguments, enter the working directory, download
        and check the database, prepare the query, run the search and
        upload the results.
        Parameters:
            argv - full command line, argv[0] being the script path
        Returns:
            process exit code: 0 on success, BLASTDB_ERROR when downloading
            or checking the database fails, otherwise the return code of the
            failed command
    """
    args = parse_args(argv)
    print('Start execution')
    logging.debug(' '.join(map(lambda x: "'"+x+"'" if ' ' in x else x, argv)))
    log_args(args)
    if dry_run:
        print(f'mkdir -p {args.workdir}')
        print(f'cd {args.workdir}')
    else:
        # Same semantics as the 'mkdir -p' printed in dry-run mode: create
        # missing parent directories and accept an existing directory
        os.makedirs(args.workdir, exist_ok=True)
        os.chdir(args.workdir)
    fn_result = ''
    local_query_file = ''
    try:
        download_database(args)
    except SafeExecError as e:
        print(e.message, file=sys.stderr)
        return BLASTDB_ERROR
    try:
        vmtouch_database(args)
        download_taxidlist(args)
        local_query_file = prepare_query(args)
        fn_result = run_search(args, local_query_file)
        if fn_result:
            upload_results(args, fn_result)
        cleanup(local_query_file, fn_result)
    except SafeExecError as e:
        print(e.message)
        upload_error_if_not_present(args, e.message)
        print('End execution, exception raised')
        cleanup(local_query_file, fn_result)
        return e.returncode
    print('End execution')
    return 0

if __name__ == '__main__':
    # Use main()'s return value as the process exit status
    raise SystemExit(main(sys.argv))

# vim: set syntax=python ts=4 et :
