#!/usr/bin/env python3
#                           PUBLIC DOMAIN NOTICE
#              National Center for Biotechnology Information
#  
# This software is a "United States Government Work" under the
# terms of the United States Copyright Act.  It was written as part of
# the authors' official duties as United States Government employees and
# thus cannot be copyrighted.  This software is freely available
# to the public for use.  The National Library of Medicine and the U.S.
# Government have not placed any restriction on its use or reproduction.
#   
# Although all reasonable efforts have been taken to ensure the accuracy
# and reliability of the software and data, the NLM and the U.S.
# Government do not and cannot warrant the performance or results that
# may be obtained by using this software or data.  The NLM and the U.S.
# Government disclaim all warranties, express or implied, including
# warranties of performance, merchantability or fitness for any particular
# purpose.
#   
# Please cite NCBI in any work or product based on this material.

"""
Split FASTA file into smaller chunks

File can be on local filesystem, available by URL,
or reside in AWS bucket (S3). File can be compressed and/or
archived (the contents of all files in the archive are treated as
one large merged file). The following combinations are recognized:
.gz, .tar, .tar.gz, .tgz, .tar.bz2 .

Author: Victor Joukov joukovv@ncbi.nlm.nih.gov
"""

import sys
import argparse
import os
import io
import gzip, tarfile, re, tempfile, shutil
import logging
import urllib.request
from string import digits
from random import sample
import boto3  # type: ignore
from botocore import UNSIGNED as BotoUNSIGNED  # type: ignore
from botocore.config import Config as BotoConfig  # type: ignore
from typing import Dict, IO, Generator, Union, List, Iterable, TextIO, Tuple

# Default for -l/--part-size: part size in bases/residues
DEFAULT_PART_LEN    = 5000000
# Default for -o/--output: directory for the split FASTA files
DEFAULT_OUT_PATH     = 'parts'

# Starts as stdout; main() rebinds it to the manifest file handle when a
# manifest path other than '-' is given.
manifest_file = sys.stdout

# botocore Config passed to every boto3 call in this file; assigned in main()
# (unsigned when --no-creds is given).
boto_config = None


def parse_arguments():
    """ Define the command line interface and parse the process arguments.

    Returns the argparse namespace with attributes input, n_parts,
    part_size, output, manifest, count, no_creds and dry_run.
    """
    cli = argparse.ArgumentParser(description="Split FASTA file")
    add = cli.add_argument
    # Positional input
    add('input', help='FASTA file, possible gzipped')
    # How to split
    add('-N', '--n-parts', type=int, default=-1,
        help='Divide into N_PARTS parts, default: disabled')
    add('-l', '--part-size', type=int, default=DEFAULT_PART_LEN,
        help=f'Divide into parts of size PART_SIZE, default: {DEFAULT_PART_LEN}')
    # Where to write results and reports
    add('-o', '--output', default=DEFAULT_OUT_PATH,
        help=f'Output directory for split FASTA files, default: {DEFAULT_OUT_PATH}')
    add('-m', '--manifest', default='',
        help='Manifest file containing list of split FASTA files generated')
    add('-c', '--count', default='',
        help='File to report total number of bases/residues in input file')
    # Behaviour switches
    add("--no-creds", action='store_true', default=False,
        help="Read from public buckets without credentials")
    add("-n", "--dry-run", action='store_true', default=False,
        help="Do not run any commands, just show what would be executed")
    return cli.parse_args()

def main():
    """ Script entry point: split the input FASTA file and report results.

    Parses the command line, splits the input into batch files under the
    output path, optionally writes the total residue count and the manifest
    of generated files, then uploads any files staged for S3.

    Returns the process exit code:
        0 - success
        2 - FileNotFoundError (input file)
        3 - PermissionError
        4 - UnicodeDecodeError (likely compressed input without extension)
        5 - other OSError (e.g. .gz extension on a file that is not gzipped)
        6 - tarfile.ReadError
        7 - NotImplementedError (e.g. gs: paths)
        8 - any other exception (e.g. empty input file)
    """
    global manifest_file
    global boto_config
    args = parse_arguments()
    input_path   = args.input
    out_path     = args.output
    n_parts      = args.n_parts
    part_size    = args.part_size
    manifest     = args.manifest
    count_file   = args.count
    dry_run      = args.dry_run
    # The boto configuration is shared through a module-level global
    if args.no_creds:
        boto_config = BotoConfig(signature_version=BotoUNSIGNED)
    else:
        boto_config = BotoConfig()
    total_count = 0
    try:
        strict = True
        if n_parts > 0:
            # Splitting into a fixed number of parts needs the total length
            # first, so the input is read twice: once to count, once to cut.
            with open_for_read(input_path) as s:
                reader = FASTAReader(s, part_size, out_path)
                total_count = reader.read_and_count()
                # Part size is total/n_parts rounded to the nearest integer
                part_size = int(total_count/n_parts + 0.5)
                strict = False

        with open_for_read(input_path) as s:
            reader = FASTAReader(s, part_size, out_path, strict)
            total_count, queries = reader.read_and_cut()
        if count_file:
            # '-' means report to stdout
            if count_file == '-':
                sys.stdout.write(str(total_count)+'\n')
            else:
                with open_for_write(count_file) as f:
                    f.write(str(total_count))
        if manifest:
            manifest_text = '\n'.join(queries)+'\n'
            # '-' means report to stdout
            if manifest == '-':
                sys.stdout.write(manifest_text)
            else:
                with open_for_write(manifest) as manifest_file:
                    manifest_file.write(manifest_text)
    # NOTE: handler order matters. FileNotFoundError and PermissionError are
    # subclasses of OSError and must be listed before it to get their own
    # exit codes.
    except FileNotFoundError as e:
        print(e, "for input file", file=sys.stderr)
        return 2
    except PermissionError as e:
        print(e, file=sys.stderr)
        return 3
    except UnicodeDecodeError as e:
        print(e, "\nPossibly missing .gz or .tar extension for compressed or archived file", file=sys.stderr)
        return 4
    except OSError as e:
        # If .gz extension is present on not gzipped file
        print(e, file=sys.stderr)
        return 5
    except tarfile.ReadError as e:
        print(e, "\nProbably not a tar file", file=sys.stderr)
        return 6
    except NotImplementedError as e:
        print(e, file=sys.stderr)
        return 7
    except Exception as e:
        # If the file is empty
        print(e, file=sys.stderr)
        return 8
    # Files written to s3: paths were staged in local temp dirs by
    # open_for_write; upload them now. dry_run is only honoured here.
    copy_to_bucket(dry_run)
    return 0

#######################################################################
"""
Module elb.split

Split FASTA file into smaller chunks

Author: Victor Joukov joukovv@ncbi.nlm.nih.gov
"""

def make_full_name(out_path, nchunk, suffix):
    """ Build the path of a chunk file: <out_path>/batch_<NNN>.<suffix>,
    where NNN is the chunk number zero-padded to at least three digits.
    """
    file_name = 'batch_{:03d}.{}'.format(nchunk, suffix)
    return os.path.join(out_path, file_name)


def write_chunk(out_path, nchunk, buffer) -> str:
    """ Save one batch to disk (or to the S3 staging area).

    The strings in buffer are concatenated and written to the .fa file
    numbered nchunk under out_path. Returns the name of the file written.
    """
    chunk_name = make_full_name(out_path, nchunk, 'fa')
    with open_for_write(chunk_name) as chunk_file:
        chunk_file.write(''.join(buffer))
    return chunk_name


class FASTAReader():
    """ Class for reading FASTA sequences from one or more text streams and
    cutting them into chunks (batches) of length no longer than threshold,
    if possible.
    Sequences longer than threshold are written in their own chunks without
    breaks mid-sequence.
    """
    def __init__(self, f: Union[Iterable[TextIO], TextIO], batch_len: int,
                 out_path: str, strict=True):
        """Initialize an object
        Arguments:
            f: Open file handle or stream or an Iterable of open file handles
               or streams.
            batch_len: Batch length in bases/residues
            out_path: Output directory to save batches
            strict: Make parts strictly less or equal in size to batch_len
        """
        self.file: Union[Iterable[TextIO], TextIO]
        # A single text stream is wrapped in a list so that the reading
        # methods can always iterate over a collection of streams.
        if isinstance(f, io.TextIOBase):
            self.file = [f]
        else:
            self.file = f
        self.batch_len = batch_len
        self.out_path = out_path
        self.strict = strict
        self.queries: List[str] = []

        self.nchunk = 0
        self.buffer: List[str] = []
        self.seq_buffer: List[str] = []
        self.total_count = 0 # count of base/residue in all processed files
        self.chunk_count = 0 # running base/residue count for chunk
        self.seq_count   = 0 # base/residue counter in current sequence
        # To have accurate split into predefined number of parts
        # we need to have a running threshold for next part
        self.chunk_threshold = self.batch_len

    @staticmethod
    def _residue_count(line: str) -> int:
        """ Number of characters in a sequence line, excluding the line
        terminator. Works for LF and CRLF endings and for a final line
        that has no terminator at all.
        """
        return len(line.rstrip('\r\n'))

    def process_chunk(self):
        """ Write the accumulated chunk buffer to the next batch file and
        reset the per-chunk state. Does nothing if the buffer is empty.
        """
        if not self.buffer: return
        query_fqn = write_chunk(self.out_path, self.nchunk, self.buffer)
        self.queries.append(query_fqn)
        self.nchunk += 1
        self.buffer = []
        self.total_count += self.chunk_count
        self.chunk_count = 0
        self.chunk_threshold += self.batch_len

    def process_new_sequence(self):
        """ Move the completed sequence from seq_buffer into the chunk
        buffer, flushing a chunk when the size limit is reached.

        In strict mode the current chunk is flushed *before* adding a
        sequence that would push it over batch_len. In non-strict mode the
        sequence is added first and the chunk is flushed once the running
        total passes the running threshold.
        """
        if self.strict:
            if self.chunk_count + self.seq_count > self.batch_len:
                self.process_chunk()
                self.buffer = self.seq_buffer
                self.chunk_count = self.seq_count
            else:
                self.buffer += self.seq_buffer
                self.chunk_count += self.seq_count
        else:
            self.buffer += self.seq_buffer
            self.chunk_count += self.seq_count
            if self.total_count + self.chunk_count > self.chunk_threshold:
                self.process_chunk()
        self.seq_buffer = []
        self.seq_count  = 0

    def read_and_count(self) -> int:
        """ Read List of streams, parse it as FASTA, and count sequence length.
        Return the total number of bases/residues in the input
        """
        count = 0
        for f in self.file:
            for line in f:
                if not line: continue
                # Lines starting with '>' are headers, everything else
                # is counted as sequence
                if line[0] != '>':
                    count += self._residue_count(line)
        return count

    def read_and_cut(self) -> Tuple[int, List[str]]:
        """ Read List of streams, parse it as FASTA, and write sequences into
        batches.
        Return the total number
        of bases/residues in the input and list of query files written
        Raises FileNotFoundError if no lines were read and an error is
        registered for the last stream, Exception if the input is empty.
        """
        nline = 0
        # Remember the last stream explicitly: the loop variable is unbound
        # if the iterable of streams is empty.
        last_stream = None
        for f in self.file:
            last_stream = f
            for line in f:
                nline += 1
                if not line: continue
                if line[0] == '>':
                    # A header starts a new sequence, so the previous one
                    # is complete and can be assigned to a chunk
                    self.process_new_sequence()
                else:
                    self.seq_count += self._residue_count(line)
                self.seq_buffer.append(line)
            # Files are merged, so make sure a file without a final newline
            # does not run into the first line of the next file
            if self.seq_buffer and not self.seq_buffer[-1].endswith('\n'):
                self.seq_buffer.append('\n')
        self.process_new_sequence()
        self.process_chunk()
        if not nline:
            error = get_error(last_stream)
            if error:
                raise FileNotFoundError(error)
            raise Exception("Empty input file")
        return self.total_count, self.queries


#######################################################################
"""
Module filehelper

Facilitates reads and writes of text files to/from remote filesystems and read
compressed/archived text files.

Implemented variants:
  read from local, AWS S3, http(s)/ftp URL
  write to local and AWS S3
  read gzip, tar/tgz/tar.gz/tar.bz2 (all files in archive merged into one)

Author: Victor Joukov joukovv@ncbi.nlm.nih.gov
"""

# Files destined for a bucket are first written to a local temp directory and
# uploaded later; this maps a bucket location to the temp dir created for it
# by open_for_write
bucket_temp_dirs: Dict[str, str] = {}

def copy_to_bucket(dry_run: bool = False):
    """ Copy files open in temp local dirs to corresponding places in cloud filesystem.
    Works in concert with open_for_write.

    Arguments:
        dry_run: if True, only log the bucket prefixes that would be
                 written to, do not upload anything
    Raises:
        NotImplementedError for gs: destinations
        ValueError for destinations with any other non-s3 prefix
    """
    # FIXME: remove global variables from library code
    # The S3 resource is created only when there is something to upload,
    # so dry runs and runs with local output never touch boto3.
    s3 = None
    for bucket_dir, tempdir in bucket_temp_dirs.items():
        if bucket_dir.startswith('gs:'):
            raise NotImplementedError
        if not bucket_dir.startswith('s3:'):
            raise ValueError(f'Incorrect bucket prefix {bucket_dir}')
        if dry_run:
            logging.info(f'Copy to bucket prefix {bucket_dir}')
            continue
        if s3 is None:
            s3 = boto3.resource('s3', config=boto_config)
        bucket_name, prefix = parse_bucket_name_key(bucket_dir)
        bucket = s3.Bucket(bucket_name)
        # List the directory once so the count and the loop agree
        file_names = os.listdir(path=tempdir)
        num_files = len(file_names)
        for i, fn in enumerate(file_names):
            full_name = prefix+'/'+fn if prefix else fn
            local_name = os.path.join(tempdir, fn)
            perc_done = i / num_files * 100.
            logging.debug(f'Uploading {local_name} to s3://{bucket_name}/{full_name} file # {i} of {num_files} {perc_done:.2f}% done')
            bucket.upload_file(local_name, full_name)


def random_filename():
    """ Return a hidden probe file name: '.random-probe-' followed by the
    ten decimal digits in random order.
    """
    shuffled_digits = ''.join(sample(digits, 10))
    return '.random-probe-' + shuffled_digits


def check_dir_for_write(dirname: str, dry_run=False) -> None:
    """ Check that a directory on the local filesystem can be written to.

    The check creates and removes a probe file with a random name.
    S3 paths are accepted without any check for now.

    Arguments:
        dirname: directory path, local or with s3: prefix
        dry_run: if True, additionally log the probe file name; the probe
                 is still performed
    Raises:
        NotImplementedError for gs: paths
        PermissionError if write is not possible
    """
    # GS
    if dirname.startswith('gs:'):
        raise NotImplementedError
    # AWS
    if dirname.startswith('s3:'):
        # TODO: implement the write test, see EB-491
        # try to write a probe file, if not possible raise PermissionError
        return
    # Local file system
    test_file_name = os.path.join(dirname, random_filename())
    if dry_run:
        logging.info(f'Trying to write file {test_file_name}')
    try:
        with open(test_file_name, 'w'): pass
        os.remove(test_file_name)
    except Exception as err:
        # Report every failure as PermissionError, as callers expect, but
        # keep the original cause and let KeyboardInterrupt/SystemExit through
        raise PermissionError(f'Cannot write to {dirname}: {err}') from err


def open_for_write(fname):
    """ Open file on either local (no prefix) or AWS S3 (s3:// prefix) filesystem
    for write in text mode.

    Files for S3 are created in a local temp directory (one per bucket
    directory) and are uploaded later by copy_to_bucket. For local files
    missing parent directories are created.

    Returns an open text file object.
    Raises:
        NotImplementedError for gs: paths
        ValueError for an s3: path without any '/'
    """
    if fname.startswith('gs:'):
        raise NotImplementedError
    if fname.startswith('s3:'):
        # for the same bucket path open files in temp dir and put it into
        # bucket_temp_dirs dictionary, copy through to bucket in copy_to_bucket later
        last_slash = fname.rfind('/')
        if last_slash == -1:
            # Raising a plain string is a TypeError in Python 3
            raise ValueError("Incorrect bucket path %s" % fname)
        bucket_dir = fname[:last_slash]
        filename = fname[last_slash+1:]
        tempdir = bucket_temp_dirs.get(bucket_dir)
        if tempdir is None:
            tempdir = tempfile.mkdtemp()
            bucket_temp_dirs[bucket_dir] = tempdir
        return open(os.path.join(tempdir, filename), 'wt')
    # file on a regular filesystem
    last_sep = fname.rfind('/')
    if last_sep > 0:
        os.makedirs(fname[:last_sep], exist_ok=True)
    return open(fname, 'wt')


def tar_reader(tar):
    """ Generator yielding the text lines of every regular file in an open
    tar archive, one file after another, as a single stream of lines.
    Members that are not regular files are skipped.
    """
    for member in tar:
        if member.isfile():
            raw = tar.extractfile(member)
            # TextIOWrapper needs seekable() on the object it wraps;
            # supply one that reports a non-seekable stream
            raw.seekable = lambda : False
            text = io.TextIOWrapper(raw)
            yield from text


class TarMerge(io.TextIOWrapper):
    """ Wrapper for tar_reader generator, implements
    - iterator for reading as if it is a file, and
    - context manager so it can release resources when
      used in 'with' construct

    NOTE: io.TextIOWrapper.__init__ is not called, so only the methods
    defined here are supported; the class derives from TextIOWrapper to be
    recognized as a text stream by isinstance checks.
    """
    def __init__(self, tar):
        """ tar: open tarfile.TarFile object to read from """
        self.tar = tar
        self.reader = tar_reader(tar)

    def __iter__(self):
        return self

    def __next__(self):
        return next(self.reader)

    def __enter__(self):
        return self

    def __exit__(self, *args):
        self.close()
        return False

    def close(self):
        """ Close the underlying tar file. Safe to call more than once.
        Overrides the inherited close(), which cannot work on an object
        whose TextIOWrapper part was never initialized.
        """
        tar = getattr(self, 'tar', None)
        if tar is not None:
            tar.close()

    def read(self):
        """ Return all remaining text as a single string """
        return ''.join(self)

    def readline(self):
        """ Return the next line, or '' at end of input as file objects do
        (rather than letting StopIteration escape).
        """
        return next(self.reader, '')


# FIXME: this function returns object of three possible classes and there
# are typing problems, which may indicate incompatibility between classes.
# We need tests checking downstream logic for these different return types.
# (EB-340)
def unpack_stream(s:IO, gzipped:bool, tarred:bool) -> IO:
    """ Helper function which inserts uncompressing/unarchiving
    transformers as needed depending on detected file type

    Arguments:
        s: open stream; binary for gzipped/tarred input
        gzipped: the stream is gzip-compressed (ignored if tarred)
        tarred: the stream is a tar archive, possibly compressed
    Returns a text stream: TarMerge for archives, a TextIOWrapper for
    gzipped or plain binary streams, or s itself if it is already text.
    """
    if tarred:
        # 'r|*' reads the archive as a non-seekable stream with
        # transparent compression detection
        tar = tarfile.open(fileobj=s, mode='r|*')
        return TarMerge(tar)
    if gzipped:
        # type checking in the line below is ignored because io.TextIOWrapper
        # conflicts with gzip.GzipFile, but duck-typing-wise everything seems fine
        return io.TextIOWrapper(gzip.GzipFile(fileobj=s))   #type: ignore
    if isinstance(s, io.TextIOBase):
        return s
    # Plain binary stream (e.g. an HTTP response): decode it, otherwise
    # downstream FASTA parsing would receive bytes instead of text lines
    return io.TextIOWrapper(s)   #type: ignore


def check_for_read(fname: str, dry_run=False) -> None:
    """ Check that path on local, AWS S3 or URL-available filesystem can be read from.

    Arguments:
        fname: local path, s3: URI, or http(s)/ftp URL
        dry_run: for S3 and URLs only log the intended access and return
    Raises:
        NotImplementedError for gs: paths
        FileNotFoundError if a local file does not exist or a URL cannot
            be opened
    Errors raised by boto3 for S3 objects propagate unchanged.
    """
    if fname.startswith('gs:'):
        raise NotImplementedError
    if fname.startswith('s3:'):
        if dry_run:
            logging.info(f'Open S3 file {fname}')
            return
        s3 = boto3.resource('s3', config=boto_config)
        bucket, key = parse_bucket_name_key(fname)
        obj = s3.Object(bucket, key)
        obj.load()
        return
    if fname.startswith('http') or fname.startswith('ftp:'):
        if dry_run:
            logging.info(f'Open URL request for {fname}')
            return
        req = urllib.request.Request(fname, method='HEAD')
        try:
            # Close the response right away, only reachability matters
            with urllib.request.urlopen(req):
                pass
        except Exception as err:
            raise FileNotFoundError(f'Cannot open URL {fname}: {err}') from err
        return
    # Local file: open and immediately close, do not leak the handle
    with open(fname, 'r'):
        pass


def get_length(fname: str, dry_run=False) -> int:
    """ Get length of a path on local, AWS S3, or URL-available filesystem.

    Arguments:
        fname: local path, s3: URI, or http(s)/ftp URL
        dry_run: for S3 and URLs only log the intended access and return
                 the placeholder length 10000
    Returns the size in bytes (stored size, not unpacked size).
    Raises:
        NotImplementedError for gs: paths
        FileNotFoundError if there is no such file, or if the length of an
            S3 object or URL cannot be obtained for any other reason
    """
    if fname.startswith('gs:'):
        raise NotImplementedError
    if fname.startswith('s3:'):
        if dry_run:
            logging.info(f'Check length of S3 file {fname}')
            return 10000
        s3 = boto3.resource('s3', config=boto_config)
        bucket, key = parse_bucket_name_key(fname)
        try:
            obj = s3.Object(bucket, key)
            obj.load()
            return obj.content_length
        except Exception as err:
            raise FileNotFoundError(f'Cannot get length of {fname}: {err}') from err
    if fname.startswith('http') or fname.startswith('ftp:'):
        if dry_run:
            logging.info(f'Check length of URL {fname}')
            return 10000
        req = urllib.request.Request(fname, method='HEAD')
        try:
            with urllib.request.urlopen(req) as resp:
                # A missing Content-Length header also ends up below
                return int(resp.headers['Content-Length'])
        except Exception as err:
            raise FileNotFoundError(f'Cannot get length of {fname}: {err}') from err
    return os.stat(fname).st_size

# Maps a file object to a zero-argument callable returning an error message;
# consulted by get_error(). No code visible in this file adds entries to it.
error_report_funcs = {}

def open_for_read(fname):
    """ Open a local path, an s3: URI or an http(s)/ftp URL for reading.

    The name decides the handling: a name ending in .gz is gunzipped, a
    name ending in .tar, .tar.gz, .tar.bz2 or .tgz is read as a tar
    archive. gs: paths raise NotImplementedError.
    """
    is_gzip = fname.endswith(".gz")
    tar_match = re.match(r'^.*\.(tar(|\.gz|\.bz2)|tgz)$', fname)

    if fname.startswith('gs:'):
        raise NotImplementedError

    if fname.startswith('s3:'):
        client = boto3.client('s3', config=boto_config)
        bucket, key = parse_bucket_name_key(fname)
        body = client.get_object(Bucket=bucket, Key=key)['Body']
        # Give the S3 body the attributes of a read-only, non-seekable
        # file so that it can be wrapped by the io/gzip/tarfile readers
        body.readable = lambda: True
        body.writable = lambda: False
        body.seekable = lambda: False
        body.closed = False
        body.flush = lambda: None
        if not (tar_match or is_gzip):
            return io.TextIOWrapper(body)
        return unpack_stream(body, is_gzip, tar_match)

    if fname.startswith(('http', 'ftp:')):
        return unpack_stream(urllib.request.urlopen(fname), is_gzip, tar_match)

    # regular file: compressed and archived files must be opened as binary
    open_mode = 'rb' if (is_gzip or tar_match) else 'rt'
    return unpack_stream(open(fname, open_mode), is_gzip, tar_match)


def open_for_read_iter(fnames: Iterable[str]) -> Generator[TextIO, None, None]:
    """Lazily open each of the given paths/uris for reading.

    Arguments:
        fnames: an iterable with paths to open

    Returns:
        Generator of files open for reading; each file is closed when the
        generator is advanced past it"""
    for path in fnames:
        with open_for_read(path) as handle:
            yield handle
        

def get_error(fileobj):
    """ Return the error message registered for fileobj in
    error_report_funcs, or an empty string if there is none.
    """
    reporter = error_report_funcs.get(fileobj)
    return reporter() if reporter else ''


def parse_bucket_name_key(fname: str) -> Tuple[str, str]:
    """ Split an S3 or GS object name into bucket name and key.
    Parameters:
        fname - full object name, with or without the s3:// or gs://
                prefix, e.g. s3://test-bucket/file_name.ext or
                test-bucket/file_name.ext
    Returns:
        tuple (bucket, key): everything before the first '/' and everything
        after it, e.g. ('test-bucket', 'file_name.ext'); the key is empty
        if there is no '/'
    """
    has_scheme = fname.startswith(('s3://', 'gs://'))
    # Both recognized prefixes are five characters long
    location = fname[5:] if has_scheme else fname
    bucket, _, key = location.partition('/')
    return bucket, key


# Code from https://stackoverflow.com/posts/63045786/revisions
import struct
def get_unpacked_size_s3_gz(bucket, key):
    """ Get the uncompressed size of a gzipped S3 object without
    downloading it, by reading the ISIZE field in the last four bytes.

    Arguments:
        bucket: S3 bucket name
        key: object key
    Returns the size recorded in the gzip trailer. The gzip format stores
    the size modulo 2**32, and for a multi-member gzip file the trailer
    describes the last member only.
    Raises ValueError if the object is shorter than four bytes.
    """
    s3_client = boto3.client("s3", config=boto_config)

    # HEAD request: only the size is needed, no body stream is opened
    compressed_size = s3_client.head_object(Bucket=bucket, Key=key)["ContentLength"]
    if compressed_size < 4:
        raise ValueError(f"s3://{bucket}/{key} is too short to be a gzip file")
    # HTTP byte ranges are inclusive on both ends
    last_four_bytes = s3_client.get_object(
        Bucket=bucket,
        Key=key,
        Range=f"bytes={compressed_size-4}-{compressed_size-1}"
    )["Body"]
    # ISIZE is stored little-endian regardless of the host byte order
    return struct.unpack("<I", last_four_bytes.read(4))[0]


if __name__ == "__main__":
    sys.exit(main())
