import argparse
import json
from pathlib import Path
from typing import List, Dict

TASKS = [
    "filter-langs",
    "merge-paragraphs-into-documents",
    "split-documents-into-paragraphs",
    "verify-submission",
]
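
# Example invocations (an illustrative sketch: the script filename "blindset_tools.py"
# and all file names below are placeholders, and --langs expects language codes as they
# appear in the blindset's "src_lang"/"tgt_lang" fields):
#
#   python blindset_tools.py --task filter-langs \
#       --input-blindset blindset.jsonl --langs en de --output-blindset blindset.filtered.jsonl
#   python blindset_tools.py --task split-documents-into-paragraphs \
#       --input-blindset blindset.filtered.jsonl --output-blindset blindset.paragraphs.jsonl
#   python blindset_tools.py --task merge-paragraphs-into-documents \
#       --input-submission submission.paragraphs.jsonl --output-submission submission.jsonl
#   python blindset_tools.py --task verify-submission \
#       --input-blindset blindset.jsonl --input-submission submission.jsonl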


def read_arguments():
    parser = argparse.ArgumentParser(
        description="Filter a blindset for specific languages, split documents into paragraphs, merge paragraphs into documents."
    )

    parser.add_argument(
        "--task",
        type=str,
        choices=TASKS,
        required=True,
        help="Task to perform: filter-langs, merge-paragraphs-into-documents, split-documents-into-paragraphs, or verify-submission",
    )

    parser.add_argument(
        "--langs",
        type=str,
        nargs="+",
        default=[],
        help="List of languages to include in the filtered blindset",
    )
    parser.add_argument(
        "--input-blindset",
        type=str,
        help="Path to the input blindset file",
    )
    parser.add_argument(
        "--input-submission",
        type=str,
        help="Path to the input submission file",
    )
    parser.add_argument(
        "--output-blindset",
        type=str,
        help="File to save the output blindset after processing",
    )
    parser.add_argument(
        "--output-submission",
        type=str,
        help="File to save the output submission after processing",
    )

    return parser.parse_args()


def read_input_data(file_path: str) -> List[Dict]:
    """Read a JSONL file (one JSON object per line) into a list of dictionaries."""
    documents = []
    with open(file_path, "r", encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if line:
                documents.append(json.loads(line))
    return documents
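
# Illustrative record shapes (made-up placeholder values; only the fields this script
# reads are shown, and any other fields in a record are carried through unchanged):
#   blindset:   {"doc_id": "d1", "domain": "news", "src_lang": "en", "tgt_lang": "de",
#                "src_text": "First paragraph.\n\nSecond paragraph."}
#   submission: {"doc_id": "d1", "hypothesis": "Erster Absatz.\n\nZweiter Absatz."}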


def split_documents_into_paragraphs(blindset: List[Dict]) -> List[Dict]:
    """
    Extract paragraphs (separated by double newlines) from the blindset documents.
    The paragraphs are stored under the "source_paragraphs" key as a list of strings,
    with each paragraph stripped of leading and trailing whitespace.
    This function modifies the blindset in place and returns it.
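
    Illustrative example (the document contents are made-up placeholder values):

    >>> docs = [{"doc_id": "d1", "src_text": "Para one.\\n\\nPara two."}]
    >>> split_documents_into_paragraphs(docs)[0]["source_paragraphs"]
    ['Para one.', 'Para two.']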
    """

    for document in blindset:
        src_text = document.get("src_text", None)
        if not src_text:
            raise ValueError(
                f"Document with ID {document.get('doc_id', None)} has no 'src_text' field."
            )
        src_paragraphs = src_text.strip().split("\n\n")

        document["source_paragraphs"] = [para.strip() for para in src_paragraphs]

    return blindset


def merge_paragraphs_into_documents(submission: List[Dict]) -> List[Dict]:
    """
    Merge the paragraph-level hypotheses of a submission into document-level hypotheses.
    Each document's "hypothesis" list is joined with double newlines and the merged text
    replaces the list. Documents whose "hypothesis" is already a string are left unchanged.
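
    Illustrative example (the document contents are made-up placeholder values):

    >>> sub = [{"doc_id": "d1", "hypothesis": ["Para one. ", "Para two."]}]
    >>> merged = merge_paragraphs_into_documents(sub)
    >>> merged[0]["hypothesis"]
    'Para one.\\n\\nPara two.'
    >>> merged[0]["granularity"]
    'paragraph-level'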
    """

    for document in submission:
        paragraphs = document.get("hypothesis", [])
        if not paragraphs:
            raise ValueError(
                f"Document with ID {document.get('doc_id', None)} has no hypothesis."
            )

        if isinstance(paragraphs, str):
            continue
        elif isinstance(paragraphs, list):
            # Join the stripped, non-empty paragraphs with double newlines
            hypothesis = "\n\n".join(
                para.strip() for para in paragraphs if para.strip()
            )

            # Record that the translation was produced at paragraph granularity
            document["granularity"] = "paragraph-level"
            document["hypothesis"] = hypothesis
        else:
            raise ValueError(
                f"Document with ID {document.get('doc_id', None)} has 'hypothesis' field of unexpected type: {type(paragraphs)}."
            )

    return submission


def filter_blindset_by_languages(
    blindset: List[Dict], languages: List[str]
) -> List[Dict]:
    """
    Filter the blindset to include only documents involving the specified languages.
    A document is kept if its source or target language code matches any of the
    specified languages (case-insensitive).
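
    Illustrative example (language codes and IDs are made-up placeholder values):

    >>> docs = [
    ...     {"doc_id": "d1", "src_lang": "en", "tgt_lang": "de"},
    ...     {"doc_id": "d2", "src_lang": "en", "tgt_lang": "fr"},
    ... ]
    >>> [d["doc_id"] for d in filter_blindset_by_languages(docs, ["DE"])]
    ['d1']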
    """
    filtered_blindset = []
    languages = [lang.lower() for lang in languages]

    for document in blindset:
        src_lang_code = document.get("src_lang", "").lower()
        tgt_lang_code = document.get("tgt_lang", "").lower()

        # Check if any of the language fields match the specified languages
        if src_lang_code in languages or tgt_lang_code in languages:
            filtered_blindset.append(document)

    return filtered_blindset


def verify_submission(input_blindset: List[Dict], input_submission: List[Dict]):
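    """
    Check that document IDs are unique, report documents that are missing from the
    submission (broken down by domain, source language, and target language), and
    check that each document-level hypothesis contains the same number of paragraphs
    as its source text. Prints a human-readable report instead of returning a value.
    """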

    print(f"Verifying submission against blindset...")

    # Assert that document ids are unique (len(set) == len(list))
    assert len(set(doc["doc_id"] for doc in input_blindset if "doc_id" in doc)) == len(
        input_blindset
    ), "Document IDs in blindset are not unique or not all documents have a 'doc_id' field."
    assert len(
        set(doc["doc_id"] for doc in input_submission if "doc_id" in doc)
    ) == len(
        input_submission
    ), "Document IDs in submission are not unique or not all documents have a 'doc_id' field."

    # Create mappings for quick lookup
    blindset_docs = {doc["doc_id"]: doc for doc in input_blindset}
    submission_docs = {doc["doc_id"]: doc for doc in input_submission}

    # Check for missing documents
    missing_doc_ids = set(blindset_docs.keys()) - set(submission_docs.keys())
    missing_docs = [blindset_docs[doc_id] for doc_id in missing_doc_ids]

    # Track missing documents statistics
    missing_by_domain = {}
    missing_by_src_lang = {}
    missing_by_tgt_lang = {}
    total_by_domain = {}
    total_by_src_lang = {}
    total_by_tgt_lang = {}

    # Count totals by categories
    for doc in input_blindset:
        domain = doc.get("domain", "unknown")
        src_lang = doc.get("src_lang", "unknown")
        tgt_lang = doc.get("tgt_lang", "unknown")

        total_by_domain[domain] = total_by_domain.get(domain, 0) + 1
        total_by_src_lang[src_lang] = total_by_src_lang.get(src_lang, 0) + 1
        total_by_tgt_lang[tgt_lang] = total_by_tgt_lang.get(tgt_lang, 0) + 1

    # Count missing documents by categories
    for doc in missing_docs:
        domain = doc.get("domain", "unknown")
        src_lang = doc.get("src_lang", "unknown")
        tgt_lang = doc.get("tgt_lang", "unknown")

        missing_by_domain[domain] = missing_by_domain.get(domain, 0) + 1
        missing_by_src_lang[src_lang] = missing_by_src_lang.get(src_lang, 0) + 1
        missing_by_tgt_lang[tgt_lang] = missing_by_tgt_lang.get(tgt_lang, 0) + 1

    # Check paragraph counts for documents in both sets
    paragraph_mismatches = []
    for doc_id, blindset_doc in blindset_docs.items():
        if doc_id in submission_docs:
            submission_doc = submission_docs[doc_id]

            # Count paragraphs in blindset document
            blindset_text = blindset_doc.get("src_text", "").strip()
            blindset_paragraphs = blindset_text.split("\n\n")
            blindset_paragraph_count = len(blindset_paragraphs)

            # Count paragraphs in submission document
            hypothesis = submission_doc.get("hypothesis", None)
            if hypothesis is None:
                raise ValueError(
                    f"Document {doc_id} in submission is missing the 'hypothesis' field."
                )

            if isinstance(hypothesis, list):
                raise ValueError(
                    f"Document {doc_id} in submission has 'hypothesis' as a list, but it should be a string. Use the 'merge-paragraphs-into-documents' task of this script to merge paragraphs into documents."
                )
            elif isinstance(hypothesis, str):
                submission_paragraphs = hypothesis.strip().split("\n\n")
                submission_paragraph_count = len(submission_paragraphs)
            else:
                raise ValueError(
                    f"Unexpected type for hypothesis in document {doc_id}: {type(hypothesis)}. Expected a string."
                )

            # Check if paragraph counts match
            if blindset_paragraph_count != submission_paragraph_count:
                paragraph_mismatches.append(
                    {
                        "doc_id": doc_id,
                        "blindset_paragraphs": blindset_paragraph_count,
                        "submission_paragraphs": submission_paragraph_count,
                    }
                )

    # Print verification report
    print("\nVerification Report:")
    print(f"Total documents in blindset: {len(input_blindset)}")
    print(f"Total documents in submission: {len(input_submission)}")
    print(f"Missing documents: {len(missing_docs)}")

    if missing_docs:
        print("\nMissing documents by domain:")
        for domain, count in missing_by_domain.items():
            total = total_by_domain.get(domain, 0)
            print(f"- {count}/{total} documents from the {domain} domain are missing")

        print("\nMissing documents by source language:")
        for lang, count in missing_by_src_lang.items():
            total = total_by_src_lang.get(lang, 0)
            print(
                f"- {count}/{total} documents with source language {lang} are missing"
            )

        print("\nMissing documents by target language:")
        for lang, count in missing_by_tgt_lang.items():
            total = total_by_tgt_lang.get(lang, 0)
            print(
                f"- {count}/{total} documents with target language {lang} are missing"
            )

    print(f"\nParagraph count mismatches: {len(paragraph_mismatches)}")
    if paragraph_mismatches:
        print("Documents with paragraph count mismatches:")
        for mismatch in paragraph_mismatches[:10]:  # Show first 10 mismatches
            print(
                f"- Document ID: {mismatch['doc_id']}, Blindset paragraphs: {mismatch['blindset_paragraphs']}, Submission paragraphs: {mismatch['submission_paragraphs']}"
            )
        if len(paragraph_mismatches) > 10:
            print(f"  ... and {len(paragraph_mismatches) - 10} more mismatches")

    if not missing_docs and not paragraph_mismatches:
        print(
            "✓ Verification successful! The submission contains all documents with the correct number of paragraphs."
        )
    else:
        print("✗ Verification failed. Please check the report above for details.")


def write_blindset(blindset: List[Dict], output_file: Path):
    """Write a list of documents to a JSONL file, creating parent directories as needed."""
    output_file.parent.mkdir(parents=True, exist_ok=True)

    with open(output_file, "w", encoding="utf-8") as f:
        for document in blindset:
            f.write(json.dumps(document, ensure_ascii=False) + "\n")


def main():
    args = read_arguments()

    if args.task != "merge-paragraphs-into-documents":
        if not args.input_blindset:
            raise ValueError(
                f"Input blindset file must be specified for the '{args.task}' task."
            )
        input_blindset_path = Path(args.input_blindset)
        if not input_blindset_path.exists():
            raise FileNotFoundError(
                f"Input blindset file {input_blindset_path} does not exist. Please provide a valid file path."
            )

        input_blindset = read_input_data(input_blindset_path)

    if args.task == "filter-langs":
        if not args.langs:
            raise ValueError(
                "At least one language must be specified with --langs for the 'filter-langs' task."
            )
        if not args.output_blindset:
            raise ValueError(
                "Output blindset file must be specified for the 'filter-langs' task."
            )
        output_blindset_path = Path(args.output_blindset)
        print(f"Filtering blindset for languages: {', '.join(args.langs)}")
        print(f"Original blindset: {len(input_blindset)} documents")

        output_blindset = filter_blindset_by_languages(input_blindset, args.langs)
        print(f"Filtered blindset: {len(output_blindset)} documents")

        write_blindset(output_blindset, output_blindset_path)
        print(f"Filtered blindset saved to {output_blindset_path}")

    elif args.task == "merge-paragraphs-into-documents":

        if not args.input_submission or not args.output_submission:
            raise ValueError(
                "Input submission and output submission files must be specified for the 'merge-paragraphs-into-documents' task."
            )

        input_submission_path = Path(args.input_submission)
        output_submission_path = Path(args.output_submission)

        if not input_submission_path.exists():
            raise FileNotFoundError(
                f"Input submission file {input_submission_path} does not exist."
            )

        input_submission = read_input_data(input_submission_path)
        output_submission = merge_paragraphs_into_documents(input_submission)
        write_blindset(output_submission, output_submission_path)
        print(f"Merged submission saved to {output_submission_path}")

    elif args.task == "split-documents-into-paragraphs":

        if not args.output_blindset:
            raise ValueError(
                "Output blindset file must be specified for the 'split-documents-into-paragraphs' task."
            )
        output_blindset_path = Path(args.output_blindset)
        print("Splitting documents into paragraphs...")
        output_blindset = split_documents_into_paragraphs(input_blindset)
        write_blindset(output_blindset, output_blindset_path)
        print(
            f"Blindset with documents split into paragraphs saved to {output_blindset_path}"
        )

    elif args.task == "verify-submission":
        if not args.input_submission:
            raise ValueError(
                "Input submission file must be specified for the 'verify-submission' task."
            )

        input_submission_path = Path(args.input_submission)

        if not input_submission_path.exists():
            raise FileNotFoundError(
                f"Input submission file {input_submission_path} does not exist."
            )

        # Read the submission data
        input_submission = read_input_data(input_submission_path)
        verify_submission(input_blindset, input_submission)

    else:
        print(
            f"Error: Unknown task '{args.task}'. Available tasks are: {', '.join(TASKS)}"
        )
        return


if __name__ == "__main__":
    main()
